diff --git a/warden3/warden_server/warden_server.py b/warden3/warden_server/warden_server.py index 1cd1ba499bd2127a6b6bbf90853a1c6ea79b3a0f..2949a14be33724d8030287eadf4733a4a07663a0 100755 --- a/warden3/warden_server/warden_server.py +++ b/warden3/warden_server/warden_server.py @@ -470,15 +470,29 @@ class MySQL(ObjectReq): __del__ = close + + + def log_transactions(self): + self.crs.execute("SHOW ENGINE INNODB STATUS") + res = self.crs.fetchall() + self.con.commit() + tolog = [l for l in res[0]["Status"].split("\n") if "thread id" in l] + for l in tolog: + logging.debug(l) - def execute(self, *args, **kwargs): + def query(self, *args, **kwargs): """ Execute query on self.con, reconnecting if necessary """ success = False countdown = self.retry_count + res = None + dml = kwargs.pop("dml", False) while not success: try: self.crs.execute(*args, **kwargs) + if not dml: + res = self.crs.fetchall() + self.con.commit() success = True except my.OperationalError: if not countdown: @@ -489,7 +503,7 @@ class MySQL(ObjectReq): self.close() self.connect() countdown -= 1 - + return res def _get_comma_perc(self, l): return ','.join(['%s'] * len(l)) @@ -510,8 +524,7 @@ class MySQL(ObjectReq): params.append(secret) query.append(" AND hostname IN (%s)" % self._get_comma_perc(cert_names)) params.extend(cert_names) - self.execute("".join(query), params) - rows = self.crs.fetchall() + rows = self.query("".join(query), params) if len(rows)>1: logging.warn("get_client_by_name: query returned more than one result: %s" % ", ".join( @@ -522,13 +535,11 @@ class MySQL(ObjectReq): def get_debug(self): - self.execute("SELECT VERSION() AS VER") - row = self.crs.fetchone() - self.execute("SHOW TABLE STATUS") - tablestat = self.crs.fetchall() + rows = self.query("SELECT VERSION() AS VER") + tablestat = self.query("SHOW TABLE STATUS") return { "db": "MySQL", - "version": row["VER"], + "version": rows[0]["VER"], "tables": tablestat } @@ -596,8 +607,7 @@ class MySQL(ObjectReq): logging.debug("fetch_events: query - %s" % query_string) logging.debug("fetch_events: params - %s", str(params)) - self.execute(query_string, params) - row = self.crs.fetchall() + row = self.query(query_string, params) if row: maxid = max(r['id'] for r in row) @@ -614,14 +624,14 @@ class MySQL(ObjectReq): def store_event(self, client, event): try: - self.execute("INSERT INTO events (received,client_id,data) VALUES (NOW(), %s, %s)", (client.id, json.dumps(event))) + self.query("INSERT INTO events (received,client_id,data) VALUES (NOW(), %s, %s)", (client.id, json.dumps(event)), dml=True) lastid = self.crs.lastrowid catlist = event.get('Category', ["Other"]) cats = set(catlist) | set(cat.split(".", 1)[0] for cat in catlist) for cat in cats: cat_id = self.catmap.get(cat, self.catmap_other) - self.execute("INSERT INTO event_category_mapping (event_id,category_id) VALUES (%s, %s)", (lastid, cat_id)) + self.query("INSERT INTO event_category_mapping (event_id,category_id) VALUES (%s, %s)", (lastid, cat_id), dml=True) try: tags = event['Node'][0]['Tags'] @@ -630,7 +640,7 @@ class MySQL(ObjectReq): for tag in tags: tag_id = self.tagmap.get(tag, self.tagmap_other) - self.execute("INSERT INTO event_tag_mapping (event_id,tag_id) VALUES (%s, %s)", (lastid, tag_id)) + self.query("INSERT INTO event_tag_mapping (event_id,tag_id) VALUES (%s, %s)", (lastid, tag_id), dml=True) self.con.commit() return [] @@ -641,18 +651,16 @@ class MySQL(ObjectReq): def insertLastReceivedId(self, client, id): logging.debug("insertLastReceivedId: id %i for client %i(%s)" % (id, client.id, client.hostname)) - self.execute("INSERT INTO last_events(client_id, event_id, timestamp) VALUES(%s, %s, NOW())", (client.id, id)) + self.query("INSERT INTO last_events(client_id, event_id, timestamp) VALUES(%s, %s, NOW())", (client.id, id), dml=True) self.con.commit() def getLastEventId(self): - self.execute("SELECT MAX(id) as id FROM events") - row = self.crs.fetchone() + row = self.query("SELECT MAX(id) as id FROM events")[0] return row['id'] or 0 def getLastReceivedId(self, client): - self.execute("SELECT MAX(event_id) as id FROM last_events WHERE client_id = %s", client.id) - row = self.crs.fetchone() + row = self.query("SELECT MAX(event_id) as id FROM last_events WHERE client_id = %s", client.id)[0] id = row['id'] if row is not None else 0 logging.debug("getLastReceivedId: id %i for client %i(%s)" % (id, client.id, client.hostname))