From 4cbb5a46e120b2fc07eb62224c81c9a4dc264e44 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pavel=20K=C3=A1cha?= <ph@cesnet.cz>
Date: Thu, 29 Jun 2017 16:02:29 +0200
Subject: [PATCH] Reworked database access (again) in attempt to handle
commits/rollbacks/retries correctly
---
warden3/warden_server/warden_server.py | 292 ++++++++++++++-----------
1 file changed, 164 insertions(+), 128 deletions(-)
diff --git a/warden3/warden_server/warden_server.py b/warden3/warden_server/warden_server.py
index 90f6211..3b51c50 100755
--- a/warden3/warden_server/warden_server.py
+++ b/warden3/warden_server/warden_server.py
@@ -21,6 +21,7 @@ from time import sleep
from urlparse import parse_qs
from os import path
from random import randint
+from contextlib import closing
# for local version of up to date jsonschema
sys.path.append(path.join(path.dirname(__file__), "..", "lib"))
@@ -482,6 +483,7 @@ class MySQL(ObjectBase):
self.port = port
self.retry_count = retry_count
self.retry_pause = retry_pause
+ self.retry_attempt = 0
self.event_size_limit = event_size_limit
self.catmap_filename = catmap_filename
self.tagmap_filename = tagmap_filename
@@ -496,8 +498,6 @@ class MySQL(ObjectBase):
self.con = None
- self.connect()
-
def __str__(self):
return "%s(req=%s, host='%s', user='%s', dbname='%s', port=%d, retry_count=%d, retry_pause=%d, catmap_filename=\"%s\", tagmap_filename=\"%s\")" % (
@@ -515,44 +515,84 @@ class MySQL(ObjectBase):
self.con.close()
except Exception:
pass
+ self.con = None
__del__ = close
- def query(self, *args, **kwargs):
- """ Execute query on self.con, reconnecting if necessary """
- countdown = self.retry_count
- commit = kwargs.pop("commit", False)
- crs = kwargs.pop("crs", None)
- while True:
+ def repeat(self):
+ """ Allows for graceful repeating of transactions self.retry_count
+ times. After an unsuccessful attempt it waits self.retry_pause
+ seconds before the next attempt.
+
+ Meant for usage with context manager:
+
+ for attempt in self.repeat():
+ with attempt as db:
+ crs = db.query(...)
+ # do something with crs
+
+ Note that it's not reentrant (and neither is the underlying
+ MySQL connection), so avoid nesting on the same MySQL object.
+ """
+ self.retry_attempt = self.retry_count
+ while self.retry_attempt:
+ if self.retry_attempt != self.retry_count:
+ sleep(self.retry_pause)
+ self.retry_attempt -= 1
+ yield self
+
+
+ def __enter__(self):
+ """ Context manager protocol. Guarantees that transaction will
+ get either committed or rolled back in case of database
+ exception. Can be used with self.repeat(), or alone as:
+
+ with self as db:
+ crs = db.query(...)
+ # do something with crs
+
+ Note that it's not reentrant (and neither is the underlying
+ MySQL connection), so avoid nesting on the same MySQL object.
+ """
+ if not self.retry_attempt:
+ self.retry_attempt = 0
+ return self
+
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """ Context manager protocol. If db exception is fired and
+ self.retry_attempt is not zero, it is only logged and
+ does not propagate; otherwise it propagates up. Also, the
+ open transaction is rolled back.
+ In case of no exception, the transaction gets committed.
+ """
+ if not exc_type:
+ self.con.commit()
+ self.retry_attempt = 0
+ else:
try:
- if self.con is None:
- self.connect()
- if crs is None:
- crs = self.con.cursor()
- self.log.debug("execute: %s %s" % (args, kwargs))
- crs.execute(*args, **kwargs)
- if commit:
- self.con.commit()
- return crs
- except my.OperationalError:
- if not countdown:
- raise
- self.log.info("execute: Database down, trying to reconnect (%d attempts left)..." % countdown)
- if countdown<self.retry_count:
- sleep(self.retry_pause) # no need to melt down server on longer outage
- try:
- crs.close()
- except Exception:
- pass
- try:
- self.close()
- except Exception:
- pass
- crs = None
- self.con = None
- countdown -= 1
+ if self.con:
+ self.con.rollback()
+ except my.Error:
+ pass
+ try:
+ self.close()
+ except my.Error:
+ pass
+ if self.retry_attempt:
+ self.log.info("Database error (%d attempts left): %s %s" % (self.retry_attempt, exc_type.__name__, exc_val))
+ return True
+
+
+ def query(self, *args, **kwargs):
+ if not self.con:
+ self.connect()
+ crs = self.con.cursor()
+ self.log.debug("execute: %s %s" % (args, kwargs))
+ crs.execute(*args, **kwargs)
+ return crs
def _get_comma_perc(self, l):
@@ -575,13 +615,17 @@ class MySQL(ObjectBase):
if cert_names:
query.append(" AND hostname IN (%s)" % self._get_comma_perc(cert_names))
params.extend(n.lower() for n in cert_names)
- rows = self.query("".join(query), params, commit=True).fetchall()
- if len(rows)>1:
- self.log.warn("get_client_by_name: query returned more than one result (cert_names = %s, name = %s, secret = %s): %s" % (cert_names, name, secret, ", ".join([str(Client(**row)) for row in rows])))
- return None
+ for attempt in self.repeat():
+ with attempt as db:
+ rows = db.query("".join(query), params).fetchall()
+
+ if len(rows)>1:
+ self.log.warn("get_client_by_name: query returned more than one result (cert_names = %s, name = %s, secret = %s): %s" % (cert_names, name, secret, ", ".join([str(Client(**row)) for row in rows])))
+ return None
+
+ return Client(**rows[0]) if rows else None
- return Client(**rows[0]) if rows else None
def get_clients(self, id=None):
@@ -591,8 +635,10 @@ class MySQL(ObjectBase):
query.append("WHERE id = %s")
params.append(id)
query.append("ORDER BY id")
- rows = self.query(" ".join(query), params, commit=True).fetchall()
- return [Client(**row) for row in rows]
+ for attempt in self.repeat():
+ with attempt as db:
+ rows = db.query(" ".join(query), params).fetchall()
+ return [Client(**row) for row in rows]
def add_modify_client(self, id=None, **kwargs):
@@ -617,19 +663,23 @@ class MySQL(ObjectBase):
if id is not None:
query.append("WHERE id = %s")
params.append(id)
- crs = self.query(" ".join(query), params, commit=True)
- newid = crs.lastrowid if id is None else id
- return newid
+ for attempt in self.repeat():
+ with attempt as db:
+ crs = db.query(" ".join(query), params)
+ newid = crs.lastrowid if id is None else id
+ return newid
def get_debug(self):
- rows = self.query("SELECT VERSION() AS VER", commit=True).fetchall()
- tablestat = self.query("SHOW TABLE STATUS", commit=True).fetchall()
- return {
- "db": "MySQL",
- "version": rows[0]["VER"],
- "tables": tablestat
- }
+ for attempt in self.repeat():
+ with attempt as db:
+ rows = db.query("SELECT VERSION() AS VER").fetchall()
+ tablestat = db.query("SHOW TABLE STATUS").fetchall()
+ return {
+ "db": "MySQL",
+ "version": rows[0]["VER"],
+ "tables": tablestat
+ }
def getMaps(self, section, variables):
@@ -693,7 +743,10 @@ class MySQL(ObjectBase):
query_string = "".join(query)
- row = self.query(query_string, params, commit=True).fetchall()
+ row = None
+ for attempt in self.repeat():
+ with attempt as db:
+ row = db.query(query_string, params).fetchall()
if row:
maxid = max(r['id'] for r in row)
@@ -720,30 +773,28 @@ class MySQL(ObjectBase):
def store_events(self, client, events, events_raw):
- crs = self.con.cursor()
try:
- for event, raw_event in zip(events, events_raw):
- self.query("INSERT INTO events (received,client_id,data) VALUES (NOW(), %s, %s)",
- (client.id, raw_event), crs=crs)
- lastid = crs.lastrowid
-
- catlist = event.get('Category', ["Other"])
- cats = set(catlist) | set(cat.split(".", 1)[0] for cat in catlist)
- for cat in cats:
- cat_id = self.catmap.get(cat, self.catmap_other)
- self.query("INSERT INTO event_category_mapping (event_id,category_id) VALUES (%s, %s)", (lastid, cat_id), crs=crs)
-
- nodes = event.get('Node', [])
- tags = []
- for node in nodes:
- tags.extend(node.get('Type', []))
- for tag in set(tags):
- tag_id = self.tagmap.get(tag, self.tagmap_other)
- self.query("INSERT INTO event_tag_mapping (event_id,tag_id) VALUES (%s, %s)", (lastid, tag_id), crs=crs)
- self.con.commit()
- return []
+ for attempt in self.repeat():
+ with attempt as db:
+ for event, raw_event in zip(events, events_raw):
+ lastid = db.query("INSERT INTO events (received,client_id,data) VALUES (NOW(), %s, %s)",
+ (client.id, raw_event)).lastrowid
+
+ catlist = event.get('Category', ["Other"])
+ cats = set(catlist) | set(cat.split(".", 1)[0] for cat in catlist)
+ for cat in cats:
+ cat_id = self.catmap.get(cat, self.catmap_other)
+ db.query("INSERT INTO event_category_mapping (event_id,category_id) VALUES (%s, %s)", (lastid, cat_id))
+
+ nodes = event.get('Node', [])
+ tags = []
+ for node in nodes:
+ tags.extend(node.get('Type', []))
+ for tag in set(tags):
+ tag_id = self.tagmap.get(tag, self.tagmap_other)
+ db.query("INSERT INTO event_tag_mapping (event_id,tag_id) VALUES (%s, %s)", (lastid, tag_id))
+ return []
except Exception as e:
- self.con.rollback()
exception = self.req.error(message="DB error", error=500, exc=sys.exc_info(), env=env)
exception.log(self.log)
return [{"error": 500, "message": "DB error %s" % type(e).__name__}]
@@ -751,90 +802,75 @@ class MySQL(ObjectBase):
def insertLastReceivedId(self, client, id):
self.log.debug("insertLastReceivedId: id %i for client %i(%s)" % (id, client.id, client.hostname))
- try:
- self.query("INSERT INTO last_events(client_id, event_id, timestamp) VALUES(%s, %s, NOW())", (client.id, id))
- self.con.commit()
- except Exception as e:
- self.con.rollback()
- raise
+ for attempt in self.repeat():
+ with attempt as db:
+ db.query("INSERT INTO last_events(client_id, event_id, timestamp) VALUES(%s, %s, NOW())", (client.id, id))
def getLastEventId(self):
- row = self.query("SELECT MAX(id) as id FROM events", commit=True).fetchall()[0]
-
- return row['id'] or 0
+ for attempt in self.repeat():
+ with attempt as db:
+ row = db.query("SELECT MAX(id) as id FROM events").fetchall()[0]
+ return row['id'] or 0
def getLastReceivedId(self, client):
- res = self.query("SELECT event_id as id FROM last_events WHERE client_id = %s ORDER BY last_events.id DESC LIMIT 1", (client.id,), commit=True).fetchall()
- try:
- row = res[0]
- except IndexError:
- id = None
- self.log.debug("getLastReceivedId: probably first access, unable to get id for client %i(%s)" % (client.id, client.hostname))
- else:
- id = row["id"]
- self.log.debug("getLastReceivedId: id %i for client %i(%s)" % (id, client.id, client.hostname))
+ for attempt in self.repeat():
+ with attempt as db:
+ res = db.query("SELECT event_id as id FROM last_events WHERE client_id = %s ORDER BY last_events.id DESC LIMIT 1", (client.id,)).fetchall()
+ try:
+ row = res[0]
+ except IndexError:
+ id = None
+ self.log.debug("getLastReceivedId: probably first access, unable to get id for client %i(%s)" % (client.id, client.hostname))
+ else:
+ id = row["id"]
+ self.log.debug("getLastReceivedId: id %i for client %i(%s)" % (id, client.id, client.hostname))
- return id
+ return id
def load_maps(self):
- crs = self.con.crs()
- try:
- self.query("DELETE FROM tags", crs=crs)
+ with self as db:
+ db.query("DELETE FROM tags")
for tag, num in self.tagmap.iteritems():
- self.query("INSERT INTO tags(id, tag) VALUES (%s, %s)", (num, tag), crs=crs)
- self.query("DELETE FROM categories", crs=crs)
+ db.query("INSERT INTO tags(id, tag) VALUES (%s, %s)", (num, tag))
+ db.query("DELETE FROM categories")
for cat_subcat, num in self.catmap.iteritems():
catsplit = cat_subcat.split(".", 1)
category = catsplit[0]
subcategory = catsplit[1] if len(catsplit)>1 else None
- self.query("INSERT INTO categories(id, category, subcategory, cat_subcat) VALUES (%s, %s, %s, %s)",
- (num, category, subcategory, cat_subcat), crs=crs)
- self.con.commit()
- except Exception as e:
- self.con.rollback()
- raise
+ db.query("INSERT INTO categories(id, category, subcategory, cat_subcat) VALUES (%s, %s, %s, %s)",
+ (num, category, subcategory, cat_subcat))
def purge_lastlog(self, days):
- try:
- crs = self.query(
+ with self as db:
+ return db.query(
"DELETE FROM last_events "
" USING last_events LEFT JOIN ("
" SELECT MAX(id) AS last FROM last_events"
" GROUP BY client_id"
" ) AS maxids ON last=id"
" WHERE timestamp < DATE_SUB(CURDATE(), INTERVAL %s DAY) AND last IS NULL",
- (days,))
- affected = crs.rowcount
- self.con.commit()
- except Exception as e:
- self.con.rollback()
- raise
- return affected
+ (days,)).rowcount
def purge_events(self, days):
- affected = 0
- try:
- id_ = self.query(
+ with self as db:
+ affected = 0
+ id_ = db.query(
"SELECT MAX(id) as id"
" FROM events"
" WHERE received < DATE_SUB(CURDATE(), INTERVAL %s DAY)",
- (days,),
- commit=True
+ (days,)
).fetchall()[0]["id"]
- crs = self.query("DELETE FROM events WHERE id <= %s", (id_,))
- affected = crs.rowcount
- self.query("DELETE FROM event_category_mapping WHERE event_id <= %s", (id_,))
- self.query("DELETE FROM event_tag_mapping WHERE event_id <= %s", (id_,))
- self.con.commit()
- except Exception as e:
- self.con.rollback()
- raise
- return affected
+ if id_ is None:
+ return 0
+ affected = db.query("DELETE FROM events WHERE id <= %s", (id_,)).rowcount
+ db.query("DELETE FROM event_category_mapping WHERE event_id <= %s", (id_,))
+ db.query("DELETE FROM event_tag_mapping WHERE event_id <= %s", (id_,))
+ return affected
@@ -1281,7 +1317,7 @@ param_def = {
"password": {"type": "str", "default": ""},
"dbname": {"type": "str", "default": "warden3"},
"port": {"type": "natural", "default": 3306},
- "retry_pause": {"type": "natural", "default": 5},
+ "retry_pause": {"type": "natural", "default": 3},
"retry_count": {"type": "natural", "default": 3},
"event_size_limit": {"type": "natural", "default": 5*1024*1024},
"catmap_filename": {"type": "filepath", "default": path.join(path.dirname(__file__), "catmap_mysql.json")},
--
GitLab