From 9e83ea75b8852fdb6200f4eef6a9d3b0addc2fd0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Radoslav=20Bod=C3=B3?= <bodik@cesnet.cz>
Date: Tue, 16 Apr 2024 14:33:02 +0200
Subject: [PATCH] rwm: add backup exclusive lock to support cron use-cases

---
 rwm.py            | 48 +++++++++++++++++++++++++++++++++++++++++++++--
 tests/test_rwm.py | 23 +++++++++++++++++++++++
 2 files changed, 69 insertions(+), 2 deletions(-)

diff --git a/rwm.py b/rwm.py
index d90dced..b759b5b 100755
--- a/rwm.py
+++ b/rwm.py
@@ -11,6 +11,7 @@ import sys
 from argparse import ArgumentParser
 from dataclasses import dataclass
 from datetime import datetime
+from fcntl import flock, LOCK_EX, LOCK_NB, LOCK_UN
 from io import BytesIO
 from pathlib import Path
 from typing import List, Dict, Optional
@@ -122,6 +123,9 @@ class RWMConfig(BaseModel):
         retention:
             Dictionary containing retention policies for Restic repository.
             Keys and values corresponds to a `restic forget` command `--keep*` options without leading dashes.
+
+        lock_path:
+            Path for parallel execution exclusion lock. Defaults to `/var/run/rwm.lock`.
     """
 
     model_config = ConfigDict(extra='forbid')
@@ -133,6 +137,7 @@ class RWMConfig(BaseModel):
     restic_password: Optional[str] = None
     backups: Dict[str, BackupConfig] = {}
     retention: Dict[str, str] = {}
+    lock_path: str = "/var/run/rwm.lock"
 
 
 class RwmJSONEncoder(json.JSONEncoder):
@@ -453,6 +458,37 @@ class StorageManager:
         return 0
 
 
+class LockManager:
+    """parallel execution locker"""
+
+    def __init__(self, lock_path):
+        self.lock_path = lock_path
+        self.lock_instance = None
+
+    def lock(self):
+        """acquire lock"""
+
+        self.lock_instance = Path(  # pylint: disable=consider-using-with
+            self.lock_path
+        ).open(mode="w+", encoding="utf-8")
+        try:
+            flock(self.lock_instance, LOCK_EX | LOCK_NB)
+        except BlockingIOError:
+            logger.warning("failed to acquired lock")
+            self.lock_instance.close()
+            self.lock_instance = None
+            return 1
+
+        return 0
+
+    def unlock(self):
+        """release lock"""
+
+        flock(self.lock_instance, LOCK_UN)
+        self.lock_instance.close()
+        self.lock_instance = None
+
+
 class RWM:
     """rwm impl"""
 
@@ -463,6 +499,7 @@ class RWM:
             self.config.s3_access_key,
             self.config.s3_secret_key
         )
+        self.cronlock = LockManager(self.config.lock_path)
 
     def aws_cmd(self, args) -> subprocess.CompletedProcess:
         """aws cli wrapper"""
@@ -551,6 +588,9 @@ class RWM:
     def backup(self, backup_selector: str | list) -> int:
         """backup command. perform selected backup or all configured backups"""
 
+        if self.cronlock.lock():
+            return 1
+
         stats = []
         ret = 0
         selected_backups = backup_selector if isinstance(backup_selector, list) else [backup_selector]
@@ -586,6 +626,8 @@ class RWM:
 
         logger.info("backup results")
         print(tabulate([item.to_dict() for item in stats], headers="keys", numalign="left"))
+
+        self.cronlock.unlock()
         return ret
 
     def backup_all(self) -> int:
@@ -749,11 +791,13 @@ def main(argv=None):  # pylint: disable=too-many-branches
 
     if args.command == "backup":
         ret = rwmi.backup(args.name)
-        logger.info("backup finished with %s (ret %d)", "success" if ret == 0 else "errors", ret)
+        severity, result = (logging.INFO, "success") if ret == 0 else (logging.ERROR, "errors")
+        logger.log(severity, f"backup finished with {result} (ret {ret})")
 
     if args.command == "backup-all":
         ret = rwmi.backup_all()
-        logger.info("backup_all finished with %s (ret %d)", "success" if ret == 0 else "errors", ret)
+        severity, result = (logging.INFO, "success") if ret == 0 else (logging.ERROR, "errors")
+        logger.log(severity, f"backup-all finished with {result} (ret {ret})")
 
     if args.command == "storage-create":
         ret = rwmi.storage_create(args.bucket_name, args.target_username)
diff --git a/tests/test_rwm.py b/tests/test_rwm.py
index ee8168a..70dd160 100644
--- a/tests/test_rwm.py
+++ b/tests/test_rwm.py
@@ -210,8 +210,16 @@ def test_backup_error_handling(tmpworkdir: str):  # pylint: disable=unused-argum
     mock_ok = Mock(return_value=0)
     mock_fail = Mock(return_value=11)
 
+    # when lock fails
+    with (
+        patch.object(rwm.LockManager, "lock", mock_fail),
+    ):
+        assert rwm.RWM(rwm_conf).backup("dummycfg") == 1
+
+    # when invalid selector
     assert rwm.RWM(rwm_conf).backup("invalidselector") == 1
 
+    # when backup fails (also triggers warnings)
     with (
         patch.object(rwm.StorageManager, "storage_check_policy", mock_false),
         patch.object(rwm.StorageManager, "storage_check_selfowned", mock_true),
@@ -220,6 +228,7 @@ def test_backup_error_handling(tmpworkdir: str):  # pylint: disable=unused-argum
     ):
         assert rwm.RWM(rwm_conf).backup("dummycfg") == 11
 
+    # when forget fails
     with (
         patch.object(rwm.StorageManager, "storage_check_policy", mock_true),
         patch.object(rwm.StorageManager, "storage_check_selfowned", mock_false),
@@ -229,6 +238,7 @@ def test_backup_error_handling(tmpworkdir: str):  # pylint: disable=unused-argum
     ):
         assert rwm.RWM(rwm_conf).backup("dummycfg") == 11
 
+    # when save state fails
     with (
         patch.object(rwm.StorageManager, "storage_check_policy", mock_true),
         patch.object(rwm.StorageManager, "storage_check_selfowned", mock_false),
@@ -379,3 +389,16 @@ def test_storage_restore_state_restic(tmpworkdir: str, radosuser_admin: rwm.Stor
     assert len(snapshot_files) == 1
     assert "/testdatadir/testdata1.txt" == snapshot_files[0]
     assert trwm_restore.restic_cmd(["check"]).returncode == 0
+
+
+def test_locks(tmpworkdir: str):  # pylint: disable=unused-argument
+    """test LockManager"""
+
+    lock_path = "./test.lock"
+    locker1 = rwm.LockManager(lock_path)
+    locker1.lock()
+
+    locker2 = rwm.LockManager(lock_path)
+    assert locker2.lock() == 1
+
+    locker1.unlock()
-- 
GitLab