diff --git a/warden3/warden_server/warden_server.py b/warden3/warden_server/warden_server.py index 90f62113958e9f6139e3f332789a81b22d6f61ce..3b51c50e7ead5066d1211ab7621b0487f7030970 100755 --- a/warden3/warden_server/warden_server.py +++ b/warden3/warden_server/warden_server.py @@ -21,6 +21,7 @@ from time import sleep from urlparse import parse_qs from os import path from random import randint +from contextlib import closing # for local version of up to date jsonschema sys.path.append(path.join(path.dirname(__file__), "..", "lib")) @@ -482,6 +483,7 @@ class MySQL(ObjectBase): self.port = port self.retry_count = retry_count self.retry_pause = retry_pause + self.retry_attempt = 0 self.event_size_limit = event_size_limit self.catmap_filename = catmap_filename self.tagmap_filename = tagmap_filename @@ -496,8 +498,6 @@ class MySQL(ObjectBase): self.con = None - self.connect() - def __str__(self): return "%s(req=%s, host='%s', user='%s', dbname='%s', port=%d, retry_count=%d, retry_pause=%d, catmap_filename=\"%s\", tagmap_filename=\"%s\")" % ( @@ -515,44 +515,84 @@ class MySQL(ObjectBase): self.con.close() except Exception: pass + self.con = None __del__ = close - def query(self, *args, **kwargs): - """ Execute query on self.con, reconnecting if necessary """ - countdown = self.retry_count - commit = kwargs.pop("commit", False) - crs = kwargs.pop("crs", None) - while True: + def repeat(self): + """ Allows for graceful repeating of transactions self.retry_count + times. Unsuccessful attempts wait for self.retry_pause until + next attempt. + + Meant for usage with context manager: + + for attempt in self.repeat(): + with attempt as db: + crs = db.query(...) + # do something with crs + + Note that it's not reentrant (as is not underlying MySQL + connection), so avoid nesting on the same MySQL object. + """ + self.retry_attempt = self.retry_count + while self.retry_attempt: + if self.retry_attempt != self.retry_count: + sleep(self.retry_pause) + self.retry_attempt -= 1 + yield self + + + def __enter__(self): + """ Context manager protocol. Guarantees that transaction will + get either commited or rolled back in case of database + exception. Can be used with self.repeat(), or alone as: + + with self as db: + crs = db.query(...) + # do something with crs + + Note that it's not reentrant (as is not underlying MySQL + connection), so avoid nesting on the same MySQL object. + """ + if not self.retry_attempt: + self.retry_attempt = 0 + return self + + + def __exit__(self, exc_type, exc_val, exc_tb): + """ Context manager protocol. If db exception is fired and + self.retry_attempt is not zero, it is only logged and + does not propagate, otherwise it propagates up. Also + open transaction is rolled back. + In case of no exception, transaction gets commited. + """ + if not exc_type: + self.con.commit() + self.retry_attempt = 0 + else: try: - if self.con is None: - self.connect() - if crs is None: - crs = self.con.cursor() - self.log.debug("execute: %s %s" % (args, kwargs)) - crs.execute(*args, **kwargs) - if commit: - self.con.commit() - return crs - except my.OperationalError: - if not countdown: - raise - self.log.info("execute: Database down, trying to reconnect (%d attempts left)..." % countdown) - if countdown<self.retry_count: - sleep(self.retry_pause) # no need to melt down server on longer outage - try: - crs.close() - except Exception: - pass - try: - self.close() - except Exception: - pass - crs = None - self.con = None - countdown -= 1 + if self.con: + self.con.rollback() + except my.Error: + pass + try: + self.close() + except my.Error: + pass + if self.retry_attempt: + self.log.info("Database error (%d attempts left): %s %s" % (self.retry_attempt, exc_type.__name__, exc_val)) + return True + + + def query(self, *args, **kwargs): + if not self.con: + self.connect() + crs = self.con.cursor() + self.log.debug("execute: %s %s" % (args, kwargs)) + crs.execute(*args, **kwargs) + return crs def _get_comma_perc(self, l): @@ -575,13 +615,17 @@ class MySQL(ObjectBase): if cert_names: query.append(" AND hostname IN (%s)" % self._get_comma_perc(cert_names)) params.extend(n.lower() for n in cert_names) - rows = self.query("".join(query), params, commit=True).fetchall() - if len(rows)>1: - self.log.warn("get_client_by_name: query returned more than one result (cert_names = %s, name = %s, secret = %s): %s" % (cert_names, name, secret, ", ".join([str(Client(**row)) for row in rows]))) - return None + for attempt in self.repeat(): + with attempt as db: + rows = db.query("".join(query), params).fetchall() + + if len(rows)>1: + self.log.warn("get_client_by_name: query returned more than one result (cert_names = %s, name = %s, secret = %s): %s" % (cert_names, name, secret, ", ".join([str(Client(**row)) for row in rows]))) + return None + + return Client(**rows[0]) if rows else None - return Client(**rows[0]) if rows else None def get_clients(self, id=None): @@ -591,8 +635,10 @@ class MySQL(ObjectBase): query.append("WHERE id = %s") params.append(id) query.append("ORDER BY id") - rows = self.query(" ".join(query), params, commit=True).fetchall() - return [Client(**row) for row in rows] + for attempt in self.repeat(): + with attempt as db: + rows = db.query(" ".join(query), params).fetchall() + return [Client(**row) for row in rows] def add_modify_client(self, id=None, **kwargs): @@ -617,19 +663,23 @@ class MySQL(ObjectBase): if id is not None: query.append("WHERE id = %s") params.append(id) - crs = self.query(" ".join(query), params, commit=True) - newid = crs.lastrowid if id is None else id - return newid + for attempt in self.repeat(): + with attempt as db: + crs = db.query(" ".join(query), params) + newid = crs.lastrowid if id is None else id + return newid def get_debug(self): - rows = self.query("SELECT VERSION() AS VER", commit=True).fetchall() - tablestat = self.query("SHOW TABLE STATUS", commit=True).fetchall() - return { - "db": "MySQL", - "version": rows[0]["VER"], - "tables": tablestat - } + for attempt in self.repeat(): + with attempt as db: + rows = db.query("SELECT VERSION() AS VER").fetchall() + tablestat = db.query("SHOW TABLE STATUS").fetchall() + return { + "db": "MySQL", + "version": rows[0]["VER"], + "tables": tablestat + } def getMaps(self, section, variables): @@ -693,7 +743,10 @@ class MySQL(ObjectBase): query_string = "".join(query) - row = self.query(query_string, params, commit=True).fetchall() + row = None + for attempt in self.repeat(): + with attempt as db: + row = db.query(query_string, params).fetchall() if row: maxid = max(r['id'] for r in row) @@ -720,30 +773,28 @@ class MySQL(ObjectBase): def store_events(self, client, events, events_raw): - crs = self.con.cursor() try: - for event, raw_event in zip(events, events_raw): - self.query("INSERT INTO events (received,client_id,data) VALUES (NOW(), %s, %s)", - (client.id, raw_event), crs=crs) - lastid = crs.lastrowid - - catlist = event.get('Category', ["Other"]) - cats = set(catlist) | set(cat.split(".", 1)[0] for cat in catlist) - for cat in cats: - cat_id = self.catmap.get(cat, self.catmap_other) - self.query("INSERT INTO event_category_mapping (event_id,category_id) VALUES (%s, %s)", (lastid, cat_id), crs=crs) - - nodes = event.get('Node', []) - tags = [] - for node in nodes: - tags.extend(node.get('Type', [])) - for tag in set(tags): - tag_id = self.tagmap.get(tag, self.tagmap_other) - self.query("INSERT INTO event_tag_mapping (event_id,tag_id) VALUES (%s, %s)", (lastid, tag_id), crs=crs) - self.con.commit() - return [] + for attempt in self.repeat(): + with attempt as db: + for event, raw_event in zip(events, events_raw): + lastid = db.query("INSERT INTO events (received,client_id,data) VALUES (NOW(), %s, %s)", + (client.id, raw_event)).lastrowid + + catlist = event.get('Category', ["Other"]) + cats = set(catlist) | set(cat.split(".", 1)[0] for cat in catlist) + for cat in cats: + cat_id = self.catmap.get(cat, self.catmap_other) + db.query("INSERT INTO event_category_mapping (event_id,category_id) VALUES (%s, %s)", (lastid, cat_id)) + + nodes = event.get('Node', []) + tags = [] + for node in nodes: + tags.extend(node.get('Type', [])) + for tag in set(tags): + tag_id = self.tagmap.get(tag, self.tagmap_other) + db.query("INSERT INTO event_tag_mapping (event_id,tag_id) VALUES (%s, %s)", (lastid, tag_id)) + return [] except Exception as e: - self.con.rollback() exception = self.req.error(message="DB error", error=500, exc=sys.exc_info(), env=env) exception.log(self.log) return [{"error": 500, "message": "DB error %s" % type(e).__name__}] @@ -751,90 +802,75 @@ class MySQL(ObjectBase): def insertLastReceivedId(self, client, id): self.log.debug("insertLastReceivedId: id %i for client %i(%s)" % (id, client.id, client.hostname)) - try: - self.query("INSERT INTO last_events(client_id, event_id, timestamp) VALUES(%s, %s, NOW())", (client.id, id)) - self.con.commit() - except Exception as e: - self.con.rollback() - raise + for attempt in self.repeat(): + with attempt as db: + db.query("INSERT INTO last_events(client_id, event_id, timestamp) VALUES(%s, %s, NOW())", (client.id, id)) def getLastEventId(self): - row = self.query("SELECT MAX(id) as id FROM events", commit=True).fetchall()[0] - - return row['id'] or 0 + for attempt in self.repeat(): + with attempt as db: + row = db.query("SELECT MAX(id) as id FROM events").fetchall()[0] + return row['id'] or 0 def getLastReceivedId(self, client): - res = self.query("SELECT event_id as id FROM last_events WHERE client_id = %s ORDER BY last_events.id DESC LIMIT 1", (client.id,), commit=True).fetchall() - try: - row = res[0] - except IndexError: - id = None - self.log.debug("getLastReceivedId: probably first access, unable to get id for client %i(%s)" % (client.id, client.hostname)) - else: - id = row["id"] - self.log.debug("getLastReceivedId: id %i for client %i(%s)" % (id, client.id, client.hostname)) + for attempt in self.repeat(): + with attempt as db: + res = db.query("SELECT event_id as id FROM last_events WHERE client_id = %s ORDER BY last_events.id DESC LIMIT 1", (client.id,)).fetchall() + try: + row = res[0] + except IndexError: + id = None + self.log.debug("getLastReceivedId: probably first access, unable to get id for client %i(%s)" % (client.id, client.hostname)) + else: + id = row["id"] + self.log.debug("getLastReceivedId: id %i for client %i(%s)" % (id, client.id, client.hostname)) - return id + return id def load_maps(self): - crs = self.con.crs() - try: - self.query("DELETE FROM tags", crs=crs) + with self as db: + db.query("DELETE FROM tags") for tag, num in self.tagmap.iteritems(): - self.query("INSERT INTO tags(id, tag) VALUES (%s, %s)", (num, tag), crs=crs) - self.query("DELETE FROM categories", crs=crs) + db.query("INSERT INTO tags(id, tag) VALUES (%s, %s)", (num, tag)) + db.query("DELETE FROM categories") for cat_subcat, num in self.catmap.iteritems(): catsplit = cat_subcat.split(".", 1) category = catsplit[0] subcategory = catsplit[1] if len(catsplit)>1 else None - self.query("INSERT INTO categories(id, category, subcategory, cat_subcat) VALUES (%s, %s, %s, %s)", - (num, category, subcategory, cat_subcat), crs=crs) - self.con.commit() - except Exception as e: - self.con.rollback() - raise + db.query("INSERT INTO categories(id, category, subcategory, cat_subcat) VALUES (%s, %s, %s, %s)", + (num, category, subcategory, cat_subcat)) def purge_lastlog(self, days): - try: - crs = self.query( + with self as db: + return db.query( "DELETE FROM last_events " " USING last_events LEFT JOIN (" " SELECT MAX(id) AS last FROM last_events" " GROUP BY client_id" " ) AS maxids ON last=id" " WHERE timestamp < DATE_SUB(CURDATE(), INTERVAL %s DAY) AND last IS NULL", - (days,)) - affected = crs.rowcount - self.con.commit() - except Exception as e: - self.con.rollback() - raise - return affected + (days,)).rowcount def purge_events(self, days): - affected = 0 - try: - id_ = self.query( + with self as db: + affected = 0 + id_ = db.query( "SELECT MAX(id) as id" " FROM events" " WHERE received < DATE_SUB(CURDATE(), INTERVAL %s DAY)", - (days,), - commit=True + (days,) ).fetchall()[0]["id"] - crs = self.query("DELETE FROM events WHERE id <= %s", (id_,)) - affected = crs.rowcount - self.query("DELETE FROM event_category_mapping WHERE event_id <= %s", (id_,)) - self.query("DELETE FROM event_tag_mapping WHERE event_id <= %s", (id_,)) - self.con.commit() - except Exception as e: - self.con.rollback() - raise - return affected + if id_ is None: + return 0 + affected = db.query("DELETE FROM events WHERE id <= %s", (id_,)).rowcount + db.query("DELETE FROM event_category_mapping WHERE event_id <= %s", (id_,)) + db.query("DELETE FROM event_tag_mapping WHERE event_id <= %s", (id_,)) + return affected @@ -1281,7 +1317,7 @@ param_def = { "password": {"type": "str", "default": ""}, "dbname": {"type": "str", "default": "warden3"}, "port": {"type": "natural", "default": 3306}, - "retry_pause": {"type": "natural", "default": 5}, + "retry_pause": {"type": "natural", "default": 3}, "retry_count": {"type": "natural", "default": 3}, "event_size_limit": {"type": "natural", "default": 5*1024*1024}, "catmap_filename": {"type": "filepath", "default": path.join(path.dirname(__file__), "catmap_mysql.json")},