Skip to content
Snippets Groups Projects
Commit 4cbb5a46 authored by Pavel Kácha's avatar Pavel Kácha
Browse files

Reworked database access (again) in attempt to handle commits/rollbacks/retries correctly

parent 3b7eaf04
No related branches found
No related tags found
No related merge requests found
...@@ -21,6 +21,7 @@ from time import sleep ...@@ -21,6 +21,7 @@ from time import sleep
from urlparse import parse_qs from urlparse import parse_qs
from os import path from os import path
from random import randint from random import randint
from contextlib import closing
# for local version of up to date jsonschema # for local version of up to date jsonschema
sys.path.append(path.join(path.dirname(__file__), "..", "lib")) sys.path.append(path.join(path.dirname(__file__), "..", "lib"))
...@@ -482,6 +483,7 @@ class MySQL(ObjectBase): ...@@ -482,6 +483,7 @@ class MySQL(ObjectBase):
self.port = port self.port = port
self.retry_count = retry_count self.retry_count = retry_count
self.retry_pause = retry_pause self.retry_pause = retry_pause
self.retry_attempt = 0
self.event_size_limit = event_size_limit self.event_size_limit = event_size_limit
self.catmap_filename = catmap_filename self.catmap_filename = catmap_filename
self.tagmap_filename = tagmap_filename self.tagmap_filename = tagmap_filename
...@@ -496,8 +498,6 @@ class MySQL(ObjectBase): ...@@ -496,8 +498,6 @@ class MySQL(ObjectBase):
self.con = None self.con = None
self.connect()
def __str__(self): def __str__(self):
return "%s(req=%s, host='%s', user='%s', dbname='%s', port=%d, retry_count=%d, retry_pause=%d, catmap_filename=\"%s\", tagmap_filename=\"%s\")" % ( return "%s(req=%s, host='%s', user='%s', dbname='%s', port=%d, retry_count=%d, retry_pause=%d, catmap_filename=\"%s\", tagmap_filename=\"%s\")" % (
...@@ -515,44 +515,84 @@ class MySQL(ObjectBase): ...@@ -515,44 +515,84 @@ class MySQL(ObjectBase):
self.con.close() self.con.close()
except Exception: except Exception:
pass pass
self.con = None
__del__ = close __del__ = close
def query(self, *args, **kwargs): def repeat(self):
""" Execute query on self.con, reconnecting if necessary """ """ Allows for graceful repeating of transactions self.retry_count
countdown = self.retry_count times. Unsuccessful attempts wait for self.retry_pause until
commit = kwargs.pop("commit", False) next attempt.
crs = kwargs.pop("crs", None)
while True: Meant for usage with context manager:
try:
if self.con is None: for attempt in self.repeat():
self.connect() with attempt as db:
if crs is None: crs = db.query(...)
crs = self.con.cursor() # do something with crs
self.log.debug("execute: %s %s" % (args, kwargs))
crs.execute(*args, **kwargs) Note that it's not reentrant (as is not underlying MySQL
if commit: connection), so avoid nesting on the same MySQL object.
"""
self.retry_attempt = self.retry_count
while self.retry_attempt:
if self.retry_attempt != self.retry_count:
sleep(self.retry_pause)
self.retry_attempt -= 1
yield self
def __enter__(self):
""" Context manager protocol. Guarantees that transaction will
get either commited or rolled back in case of database
exception. Can be used with self.repeat(), or alone as:
with self as db:
crs = db.query(...)
# do something with crs
Note that it's not reentrant (as is not underlying MySQL
connection), so avoid nesting on the same MySQL object.
"""
if not self.retry_attempt:
self.retry_attempt = 0
return self
def __exit__(self, exc_type, exc_val, exc_tb):
""" Context manager protocol. If db exception is fired and
self.retry_attempt is not zero, it is only logged and
does not propagate, otherwise it propagates up. Also
open transaction is rolled back.
In case of no exception, transaction gets commited.
"""
if not exc_type:
self.con.commit() self.con.commit()
return crs self.retry_attempt = 0
except my.OperationalError: else:
if not countdown:
raise
self.log.info("execute: Database down, trying to reconnect (%d attempts left)..." % countdown)
if countdown<self.retry_count:
sleep(self.retry_pause) # no need to melt down server on longer outage
try: try:
crs.close() if self.con:
except Exception: self.con.rollback()
except my.Error:
pass pass
try: try:
self.close() self.close()
except Exception: except my.Error:
pass pass
crs = None if self.retry_attempt:
self.con = None self.log.info("Database error (%d attempts left): %s %s" % (self.retry_attempt, exc_type.__name__, exc_val))
countdown -= 1 return True
def query(self, *args, **kwargs):
if not self.con:
self.connect()
crs = self.con.cursor()
self.log.debug("execute: %s %s" % (args, kwargs))
crs.execute(*args, **kwargs)
return crs
def _get_comma_perc(self, l): def _get_comma_perc(self, l):
...@@ -575,7 +615,10 @@ class MySQL(ObjectBase): ...@@ -575,7 +615,10 @@ class MySQL(ObjectBase):
if cert_names: if cert_names:
query.append(" AND hostname IN (%s)" % self._get_comma_perc(cert_names)) query.append(" AND hostname IN (%s)" % self._get_comma_perc(cert_names))
params.extend(n.lower() for n in cert_names) params.extend(n.lower() for n in cert_names)
rows = self.query("".join(query), params, commit=True).fetchall()
for attempt in self.repeat():
with attempt as db:
rows = db.query("".join(query), params).fetchall()
if len(rows)>1: if len(rows)>1:
self.log.warn("get_client_by_name: query returned more than one result (cert_names = %s, name = %s, secret = %s): %s" % (cert_names, name, secret, ", ".join([str(Client(**row)) for row in rows]))) self.log.warn("get_client_by_name: query returned more than one result (cert_names = %s, name = %s, secret = %s): %s" % (cert_names, name, secret, ", ".join([str(Client(**row)) for row in rows])))
...@@ -584,6 +627,7 @@ class MySQL(ObjectBase): ...@@ -584,6 +627,7 @@ class MySQL(ObjectBase):
return Client(**rows[0]) if rows else None return Client(**rows[0]) if rows else None
def get_clients(self, id=None): def get_clients(self, id=None):
query = ["SELECT * FROM clients"] query = ["SELECT * FROM clients"]
params = [] params = []
...@@ -591,7 +635,9 @@ class MySQL(ObjectBase): ...@@ -591,7 +635,9 @@ class MySQL(ObjectBase):
query.append("WHERE id = %s") query.append("WHERE id = %s")
params.append(id) params.append(id)
query.append("ORDER BY id") query.append("ORDER BY id")
rows = self.query(" ".join(query), params, commit=True).fetchall() for attempt in self.repeat():
with attempt as db:
rows = db.query(" ".join(query), params).fetchall()
return [Client(**row) for row in rows] return [Client(**row) for row in rows]
...@@ -617,14 +663,18 @@ class MySQL(ObjectBase): ...@@ -617,14 +663,18 @@ class MySQL(ObjectBase):
if id is not None: if id is not None:
query.append("WHERE id = %s") query.append("WHERE id = %s")
params.append(id) params.append(id)
crs = self.query(" ".join(query), params, commit=True) for attempt in self.repeat():
with attempt as db:
crs = db.query(" ".join(query), params)
newid = crs.lastrowid if id is None else id newid = crs.lastrowid if id is None else id
return newid return newid
def get_debug(self): def get_debug(self):
rows = self.query("SELECT VERSION() AS VER", commit=True).fetchall() for attempt in self.repeat():
tablestat = self.query("SHOW TABLE STATUS", commit=True).fetchall() with attempt as db:
rows = db.query("SELECT VERSION() AS VER").fetchall()
tablestat = db.query("SHOW TABLE STATUS").fetchall()
return { return {
"db": "MySQL", "db": "MySQL",
"version": rows[0]["VER"], "version": rows[0]["VER"],
...@@ -693,7 +743,10 @@ class MySQL(ObjectBase): ...@@ -693,7 +743,10 @@ class MySQL(ObjectBase):
query_string = "".join(query) query_string = "".join(query)
row = self.query(query_string, params, commit=True).fetchall() row = None
for attempt in self.repeat():
with attempt as db:
row = db.query(query_string, params).fetchall()
if row: if row:
maxid = max(r['id'] for r in row) maxid = max(r['id'] for r in row)
...@@ -720,18 +773,18 @@ class MySQL(ObjectBase): ...@@ -720,18 +773,18 @@ class MySQL(ObjectBase):
def store_events(self, client, events, events_raw): def store_events(self, client, events, events_raw):
crs = self.con.cursor()
try: try:
for attempt in self.repeat():
with attempt as db:
for event, raw_event in zip(events, events_raw): for event, raw_event in zip(events, events_raw):
self.query("INSERT INTO events (received,client_id,data) VALUES (NOW(), %s, %s)", lastid = db.query("INSERT INTO events (received,client_id,data) VALUES (NOW(), %s, %s)",
(client.id, raw_event), crs=crs) (client.id, raw_event)).lastrowid
lastid = crs.lastrowid
catlist = event.get('Category', ["Other"]) catlist = event.get('Category', ["Other"])
cats = set(catlist) | set(cat.split(".", 1)[0] for cat in catlist) cats = set(catlist) | set(cat.split(".", 1)[0] for cat in catlist)
for cat in cats: for cat in cats:
cat_id = self.catmap.get(cat, self.catmap_other) cat_id = self.catmap.get(cat, self.catmap_other)
self.query("INSERT INTO event_category_mapping (event_id,category_id) VALUES (%s, %s)", (lastid, cat_id), crs=crs) db.query("INSERT INTO event_category_mapping (event_id,category_id) VALUES (%s, %s)", (lastid, cat_id))
nodes = event.get('Node', []) nodes = event.get('Node', [])
tags = [] tags = []
...@@ -739,11 +792,9 @@ class MySQL(ObjectBase): ...@@ -739,11 +792,9 @@ class MySQL(ObjectBase):
tags.extend(node.get('Type', [])) tags.extend(node.get('Type', []))
for tag in set(tags): for tag in set(tags):
tag_id = self.tagmap.get(tag, self.tagmap_other) tag_id = self.tagmap.get(tag, self.tagmap_other)
self.query("INSERT INTO event_tag_mapping (event_id,tag_id) VALUES (%s, %s)", (lastid, tag_id), crs=crs) db.query("INSERT INTO event_tag_mapping (event_id,tag_id) VALUES (%s, %s)", (lastid, tag_id))
self.con.commit()
return [] return []
except Exception as e: except Exception as e:
self.con.rollback()
exception = self.req.error(message="DB error", error=500, exc=sys.exc_info(), env=env) exception = self.req.error(message="DB error", error=500, exc=sys.exc_info(), env=env)
exception.log(self.log) exception.log(self.log)
return [{"error": 500, "message": "DB error %s" % type(e).__name__}] return [{"error": 500, "message": "DB error %s" % type(e).__name__}]
...@@ -751,22 +802,22 @@ class MySQL(ObjectBase): ...@@ -751,22 +802,22 @@ class MySQL(ObjectBase):
def insertLastReceivedId(self, client, id): def insertLastReceivedId(self, client, id):
self.log.debug("insertLastReceivedId: id %i for client %i(%s)" % (id, client.id, client.hostname)) self.log.debug("insertLastReceivedId: id %i for client %i(%s)" % (id, client.id, client.hostname))
try: for attempt in self.repeat():
self.query("INSERT INTO last_events(client_id, event_id, timestamp) VALUES(%s, %s, NOW())", (client.id, id)) with attempt as db:
self.con.commit() db.query("INSERT INTO last_events(client_id, event_id, timestamp) VALUES(%s, %s, NOW())", (client.id, id))
except Exception as e:
self.con.rollback()
raise
def getLastEventId(self): def getLastEventId(self):
row = self.query("SELECT MAX(id) as id FROM events", commit=True).fetchall()[0] for attempt in self.repeat():
with attempt as db:
row = db.query("SELECT MAX(id) as id FROM events").fetchall()[0]
return row['id'] or 0 return row['id'] or 0
def getLastReceivedId(self, client): def getLastReceivedId(self, client):
res = self.query("SELECT event_id as id FROM last_events WHERE client_id = %s ORDER BY last_events.id DESC LIMIT 1", (client.id,), commit=True).fetchall() for attempt in self.repeat():
with attempt as db:
res = db.query("SELECT event_id as id FROM last_events WHERE client_id = %s ORDER BY last_events.id DESC LIMIT 1", (client.id,)).fetchall()
try: try:
row = res[0] row = res[0]
except IndexError: except IndexError:
...@@ -780,60 +831,45 @@ class MySQL(ObjectBase): ...@@ -780,60 +831,45 @@ class MySQL(ObjectBase):
def load_maps(self): def load_maps(self):
crs = self.con.crs() with self as db:
try: db.query("DELETE FROM tags")
self.query("DELETE FROM tags", crs=crs)
for tag, num in self.tagmap.iteritems(): for tag, num in self.tagmap.iteritems():
self.query("INSERT INTO tags(id, tag) VALUES (%s, %s)", (num, tag), crs=crs) db.query("INSERT INTO tags(id, tag) VALUES (%s, %s)", (num, tag))
self.query("DELETE FROM categories", crs=crs) db.query("DELETE FROM categories")
for cat_subcat, num in self.catmap.iteritems(): for cat_subcat, num in self.catmap.iteritems():
catsplit = cat_subcat.split(".", 1) catsplit = cat_subcat.split(".", 1)
category = catsplit[0] category = catsplit[0]
subcategory = catsplit[1] if len(catsplit)>1 else None subcategory = catsplit[1] if len(catsplit)>1 else None
self.query("INSERT INTO categories(id, category, subcategory, cat_subcat) VALUES (%s, %s, %s, %s)", db.query("INSERT INTO categories(id, category, subcategory, cat_subcat) VALUES (%s, %s, %s, %s)",
(num, category, subcategory, cat_subcat), crs=crs) (num, category, subcategory, cat_subcat))
self.con.commit()
except Exception as e:
self.con.rollback()
raise
def purge_lastlog(self, days): def purge_lastlog(self, days):
try: with self as db:
crs = self.query( return db.query(
"DELETE FROM last_events " "DELETE FROM last_events "
" USING last_events LEFT JOIN (" " USING last_events LEFT JOIN ("
" SELECT MAX(id) AS last FROM last_events" " SELECT MAX(id) AS last FROM last_events"
" GROUP BY client_id" " GROUP BY client_id"
" ) AS maxids ON last=id" " ) AS maxids ON last=id"
" WHERE timestamp < DATE_SUB(CURDATE(), INTERVAL %s DAY) AND last IS NULL", " WHERE timestamp < DATE_SUB(CURDATE(), INTERVAL %s DAY) AND last IS NULL",
(days,)) (days,)).rowcount
affected = crs.rowcount
self.con.commit()
except Exception as e:
self.con.rollback()
raise
return affected
def purge_events(self, days): def purge_events(self, days):
with self as db:
affected = 0 affected = 0
try: id_ = db.query(
id_ = self.query(
"SELECT MAX(id) as id" "SELECT MAX(id) as id"
" FROM events" " FROM events"
" WHERE received < DATE_SUB(CURDATE(), INTERVAL %s DAY)", " WHERE received < DATE_SUB(CURDATE(), INTERVAL %s DAY)",
(days,), (days,)
commit=True
).fetchall()[0]["id"] ).fetchall()[0]["id"]
crs = self.query("DELETE FROM events WHERE id <= %s", (id_,)) if id_ is None:
affected = crs.rowcount return 0
self.query("DELETE FROM event_category_mapping WHERE event_id <= %s", (id_,)) affected = db.query("DELETE FROM events WHERE id <= %s", (id_,)).rowcount
self.query("DELETE FROM event_tag_mapping WHERE event_id <= %s", (id_,)) db.query("DELETE FROM event_category_mapping WHERE event_id <= %s", (id_,))
self.con.commit() db.query("DELETE FROM event_tag_mapping WHERE event_id <= %s", (id_,))
except Exception as e:
self.con.rollback()
raise
return affected return affected
...@@ -1281,7 +1317,7 @@ param_def = { ...@@ -1281,7 +1317,7 @@ param_def = {
"password": {"type": "str", "default": ""}, "password": {"type": "str", "default": ""},
"dbname": {"type": "str", "default": "warden3"}, "dbname": {"type": "str", "default": "warden3"},
"port": {"type": "natural", "default": 3306}, "port": {"type": "natural", "default": 3306},
"retry_pause": {"type": "natural", "default": 5}, "retry_pause": {"type": "natural", "default": 3},
"retry_count": {"type": "natural", "default": 3}, "retry_count": {"type": "natural", "default": 3},
"event_size_limit": {"type": "natural", "default": 5*1024*1024}, "event_size_limit": {"type": "natural", "default": 5*1024*1024},
"catmap_filename": {"type": "filepath", "default": path.join(path.dirname(__file__), "catmap_mysql.json")}, "catmap_filename": {"type": "filepath", "default": path.join(path.dirname(__file__), "catmap_mysql.json")},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment