diff --git a/warden_server/warden_server.py b/warden_server/warden_server.py index 4a7d95e02da83fc808d596e68851dee10e1b38f7..ec4bf8928cb5c1979b57b09cf22dac7a48b5999a 100755 --- a/warden_server/warden_server.py +++ b/warden_server/warden_server.py @@ -506,15 +506,20 @@ class DataBase(ObjectBase): self.tagmap = json.load(tagmap_fd) self.tagmap_other = self.tagmap["Other"] # Catch error soon, avoid lookup later + self.db = None self.con = None @override_required def connect(self): pass - @override_required def close(self): - pass + try: + if self.con: + self.con.close() + except Exception: + pass + self.con = None def __del__(self): self.close() @@ -555,7 +560,6 @@ class DataBase(ObjectBase): self.retry_attempt = 0 return self - @override_required def __exit__(self, exc_type, exc_val, exc_tb): """ Context manager protocol. If db exception is fired and self.retry_attempt is not zero, it is only logged and @@ -563,22 +567,66 @@ class DataBase(ObjectBase): open transaction is rolled back. In case of no exception, transaction gets commited. """ + if exc_type is None: + self.con.commit() + self.retry_attempt = 0 + else: + try: + if self.con is not None: + self.con.rollback() + except self.db.Error: + pass + try: + self.close() + except self.db.Error: + pass + if self.retry_attempt > 0: + self.log.info("Database error (%d attempts left): %s %s" % + (self.retry_attempt, exc_type.__name__, exc_val)) + return True + + def _query(self, *args, **kwargs): + if not self.con: + self.connect() + crs = self.con.cursor() + self.log.debug("execute: %s %s" % (args, kwargs)) + crs.execute(*args, **kwargs) + return crs + + def _query_multiple(self, query, params, ret, fetch): + res = None + for n, (q, p) in enumerate(zip(query, params)): + cur = self._query(q, p) + if n == ret: + res = fetch(cur) + if ret == -1: # fetch the result of the last query + res = fetch(cur) + return res - @override_required def execute(self, query, params, ret=None): """Execute the provided queries; discard the result""" + self._query_multiple(query, params, None, None) - @override_required def query_all(self, query, params, ret=-1): """Execute the provided queries; return list of all rows as dicts of the ret-th query (0 based)""" + return self._query_multiple(query, params, ret, lambda cur: cur.fetchall()) - @override_required - def query_one(self, query, prams, ret=-1): + def query_one(self, query, params, ret=-1): """Execute the provided queries; return the first result of the ret-th query (0 based)""" + return self._query_multiple(query, params, ret, lambda cur: cur.fetchone()) - @override_required def query_rowcount(self, query, params, ret=-1): """Execute provided query; return the number of affected rows or the number of returned rows of the ret-th query (0 based)""" + return self._query_multiple(query, params, ret, lambda cur: cur.rowcount) + + def _get_comma_perc(self, l): + return ",".join(repeat("%s", l if isinstance(l, int) else len(l))) + + def _get_comma_perc_n(self, n, l): + return ", ".join(repeat("(%s)" % self._get_comma_perc(n), len(l))) + + def _get_not(self, b): + return "" if b else "NOT" @override_required def _build_get_client_by_name(self, cert_names, name, secret): @@ -847,95 +895,7 @@ class DataBase(ObjectBase): return affected -class DataBaseAPIv2(DataBase): - - def __init__(self, req, log, host, user, password, dbname, port, retry_count, - retry_pause, event_size_limit, catmap_filename, tagmap_filename): - - super().__init__(req, log, host, user, password, dbname, port, retry_count, - retry_pause, event_size_limit, catmap_filename, tagmap_filename) - - self.db = None - self.con = None - - def close(self): - try: - if self.con: - self.con.close() - except Exception: - pass - self.con = None - - def __exit__(self, exc_type, exc_val, exc_tb): - """ Context manager protocol. If db exception is fired and - self.retry_attempt is not zero, it is only logged and - does not propagate, otherwise it propagates up. Also - open transaction is rolled back. - In case of no exception, transaction gets commited. - """ - if exc_type is None: - self.con.commit() - self.retry_attempt = 0 - else: - try: - if self.con is not None: - self.con.rollback() - except self.db.Error: - pass - try: - self.close() - except self.db.Error: - pass - if self.retry_attempt > 0: - self.log.info("Database error (%d attempts left): %s %s" % - (self.retry_attempt, exc_type.__name__, exc_val)) - return True - - def _query(self, *args, **kwargs): - if not self.con: - self.connect() - crs = self.con.cursor() - self.log.debug("execute: %s %s" % (args, kwargs)) - crs.execute(*args, **kwargs) - return crs - - def _query_multiple(self, query, params, ret, fetch): - res = None - for n, (q, p) in enumerate(zip(query, params)): - cur = self._query(q, p) - if n == ret: - res = fetch(cur) - if ret == -1: # fetch the result of the last query - res = fetch(cur) - return res - - def execute(self, query, params, ret=None): - """Execute the provided queries; discard the result""" - self._query_multiple(query, params, None, None) - - def query_all(self, query, params, ret=-1): - """Execute the provided queries; return list of all rows as dicts of the ret-th query (0 based)""" - return self._query_multiple(query, params, ret, lambda cur: cur.fetchall()) - - def query_one(self, query, params, ret=-1): - """Execute the provided queries; return the first result of the ret-th query (0 based)""" - return self._query_multiple(query, params, ret, lambda cur: cur.fetchone()) - - def query_rowcount(self, query, params, ret=-1): - """Execute provided query; return the number of affected rows or the number of returned rows of the ret-th query (0 based)""" - return self._query_multiple(query, params, ret, lambda cur: cur.rowcount) - - def _get_comma_perc(self, l): - return ",".join(repeat("%s", l if isinstance(l, int) else len(l))) - - def _get_comma_perc_n(self, n, l): - return ", ".join(repeat("(%s)" % self._get_comma_perc(n), len(l))) - - def _get_not(self, b): - return "" if b else "NOT" - - -class MySQL(DataBaseAPIv2): +class MySQL(DataBase): def __init__( self, req, log, host, user, password, dbname, port, retry_count, @@ -1187,7 +1147,7 @@ class MySQL(DataBaseAPIv2): ) -class PostgreSQL(DataBaseAPIv2): +class PostgreSQL(DataBase): def __init__( self, req, log, host, user, password, dbname, port, retry_count,