From 5c84e9adc0f5ce1ab9a3463efeea8df69d453d07 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pavel=20K=C3=A1cha?= <ph@cesnet.cz>
Date: Tue, 19 Apr 2016 17:05:38 +0200
Subject: [PATCH] Reworked db layer, allowed for comitting after the whole
 bunch (req by Bodik)

---
 warden3/warden_server/warden_server.py | 137 +++++++++++++------------
 1 file changed, 70 insertions(+), 67 deletions(-)

diff --git a/warden3/warden_server/warden_server.py b/warden3/warden_server/warden_server.py
index bbf9f5c..13e4a6f 100755
--- a/warden3/warden_server/warden_server.py
+++ b/warden3/warden_server/warden_server.py
@@ -450,7 +450,7 @@ class MySQL(ObjectReq):
             self.tagmap = json.load(tagmap_fd)
             self.tagmap_other = self.catmap["Other"]    # Catch error soon, avoid lookup later
 
-        self.con = self.crs = None
+        self.con = None
 
         self.connect()
 
@@ -463,13 +463,10 @@ class MySQL(ObjectReq):
     def connect(self):
         self.con = my.connect(host=self.host, user=self.user, passwd=self.password,
             db=self.dbname, port=self.port, cursorclass=mycursors.DictCursor)
-        self.crs = self.con.cursor()
 
 
     def close(self):
         try:
-            if self.crs:
-                self.crs.close()
             if self.con:
                 self.con.close()
         except Exception:
@@ -479,28 +476,20 @@ class MySQL(ObjectReq):
     __del__ = close
     
     
-    def log_transactions(self):
-        self.crs.execute("SHOW ENGINE INNODB STATUS")
-        res = self.crs.fetchall()
-        self.con.commit()
-        tolog = [l for l in res[0]["Status"].split("\n") if "thread id" in l]
-        for l in tolog:
-            logging.debug(l)
-
-
     def query(self, *args, **kwargs):
         """ Execute query on self.con, reconnecting if necessary """
-        success = False
         countdown = self.retry_count
-        res = None
-        dml = kwargs.pop("dml", False)
-        while not success:
+        commit = kwargs.pop("commit", False)
+        crs = kwargs.pop("crs", None)
+        while True:
             try:
-                self.crs.execute(*args, **kwargs)
-                if not dml:
-                    res = self.crs.fetchall()
+                if crs is None:
+                    crs = self.con.cursor()
+                logging.debug("execute: %s %s" % (args, kwargs))
+                crs.execute(*args, **kwargs)
+                if commit:
                     self.con.commit()
-                success = True
+                return crs
             except my.OperationalError:
                 if not countdown:
                     raise
@@ -509,8 +498,9 @@ class MySQL(ObjectReq):
                     sleep(self.retry_pause)    # no need to melt down server on longer outage
                 self.close()
                 self.connect()
+                crs = None
                 countdown -= 1
-        return res
+
 
     def _get_comma_perc(self, l):
         return ','.join(['%s'] * len(l))
@@ -531,7 +521,7 @@ class MySQL(ObjectReq):
             params.append(secret)
         query.append(" AND hostname IN (%s)" % self._get_comma_perc(cert_names))
         params.extend(n.lower() for n in cert_names)
-        rows = self.query("".join(query), params)
+        rows = self.query("".join(query), params, commit=True).fetchall()
 
         if len(rows)>1:
             logging.warn("get_client_by_name: query returned more than one result: %s" % ", ".join(
@@ -548,7 +538,7 @@ class MySQL(ObjectReq):
             query.append("WHERE id = %s")
             params.append(id)
         query.append("ORDER BY id")
-        rows = self.query(" ".join(query), params)
+        rows = self.query(" ".join(query), params, commit=True).fetchall()
         return [Client(**row) for row in rows]
 
 
@@ -572,13 +562,15 @@ class MySQL(ObjectReq):
         if id is not None:
             query.append("WHERE id = %s")
             params.append(id)
-        self.query(" ".join(query), params)
-        return self.crs.lastrowid if id is None else id
+        crs = self.query(" ".join(query), params).fetchall()
+        newid = crs.lastrowid if id is None else id
+        self.con.commit()
+        return newid
 
 
     def get_debug(self):
-        rows = self.query("SELECT VERSION() AS VER")
-        tablestat = self.query("SHOW TABLE STATUS")
+        rows = self.query("SELECT VERSION() AS VER", commit=True).fetchall()
+        tablestat = self.query("SHOW TABLE STATUS", commit=True).fetchall()
         return {
             "db": "MySQL",
             "version": rows[0]["VER"],
@@ -646,10 +638,8 @@ class MySQL(ObjectReq):
         params.append(count)
 
         query_string = "".join(query)
-        logging.debug("fetch_events: query - %s" % query_string)
-        logging.debug("fetch_events: params - %s", str(params))
 
-        row = self.query(query_string, params)
+        row = self.query(query_string, params, commit=True).fetchall()
 
         if row:
             maxid = max(r['id'] for r in row)
@@ -675,40 +665,40 @@ class MySQL(ObjectReq):
         }
 
 
-    def store_event(self, client, event):
-        json_event = json.dumps(event)
-        if len(json_event) >= self.event_size_limit:
-            return [{"error": 413, "message": "Event too long (>%i B)" % self.event_size_limit}]
+    def store_events(self, client, events, events_raw):
+        crs = self.con.cursor()
         try:
-            self.query("INSERT INTO events (received,client_id,data) VALUES (NOW(), %s, %s)",
-                (client.id, json_event), dml=True)
-            lastid = self.crs.lastrowid
-
-            catlist = event.get('Category', ["Other"])
-            cats = set(catlist) | set(cat.split(".", 1)[0] for cat in catlist)
-            for cat in cats:
-                cat_id = self.catmap.get(cat, self.catmap_other)
-                self.query("INSERT INTO event_category_mapping (event_id,category_id) VALUES (%s, %s)", (lastid, cat_id), dml=True)
-                
-            nodes = event.get('Node', [])
-            tags = []
-            for node in nodes:
-                tags.extend(node.get('Type', []))
-            for tag in set(tags):
-                tag_id = self.tagmap.get(tag, self.tagmap_other)
-                self.query("INSERT INTO event_tag_mapping (event_id,tag_id) VALUES (%s, %s)", (lastid, tag_id), dml=True)
-
+            for event, raw_event in zip(events, events_raw):
+                self.query("INSERT INTO events (received,client_id,data) VALUES (NOW(), %s, %s)",
+                    (client.id, raw_event), crs=crs)
+                lastid = crs.lastrowid
+
+                catlist = event.get('Category', ["Other"])
+                cats = set(catlist) | set(cat.split(".", 1)[0] for cat in catlist)
+                for cat in cats:
+                    cat_id = self.catmap.get(cat, self.catmap_other)
+                    self.query("INSERT INTO event_category_mapping (event_id,category_id) VALUES (%s, %s)", (lastid, cat_id), crs=crs)
+
+                nodes = event.get('Node', [])
+                tags = []
+                for node in nodes:
+                    tags.extend(node.get('Type', []))
+                for tag in set(tags):
+                    tag_id = self.tagmap.get(tag, self.tagmap_other)
+                    self.query("INSERT INTO event_tag_mapping (event_id,tag_id) VALUES (%s, %s)", (lastid, tag_id), crs=crs)
             self.con.commit()
             return []
         except Exception as e:
             self.con.rollback()
-            return [{"error": 500, "message": type(e).__name__}]
+            exception = self.req.error(message="DB error", error=500, exc=sys.exc_info(), env=env)
+            exception.log(logging.getLogger())
+            return [{"error": 500, "message": "DB error %s" % type(e).__name__}]
 
 
     def insertLastReceivedId(self, client, id):
         logging.debug("insertLastReceivedId: id %i for client %i(%s)" % (id, client.id, client.hostname))
         try:
-            self.query("INSERT INTO last_events(client_id, event_id, timestamp) VALUES(%s, %s, NOW())", (client.id, id), dml=True)
+            self.query("INSERT INTO last_events(client_id, event_id, timestamp) VALUES(%s, %s, NOW())", (client.id, id))
             self.con.commit()
         except Exception as e:
             self.con.rollback()
@@ -716,13 +706,13 @@ class MySQL(ObjectReq):
 
 
     def getLastEventId(self):
-        row = self.query("SELECT MAX(id) as id FROM events")[0]
+        row = self.query("SELECT MAX(id) as id FROM events", commit=True).fetchall()[0]
 
         return row['id'] or 0
 
 
     def getLastReceivedId(self, client):
-        res = self.query("SELECT event_id as id FROM last_events WHERE client_id = %s ORDER BY last_events.id DESC LIMIT 1", client.id)
+        res = self.query("SELECT event_id as id FROM last_events WHERE client_id = %s ORDER BY last_events.id DESC LIMIT 1", client.id, commit=True).fetchall()
         try:
             row = res[0]
         except IndexError:
@@ -736,17 +726,18 @@ class MySQL(ObjectReq):
 
 
     def load_maps(self):
+        crs = self.con.crs()
         try:
-            self.query("DELETE FROM tags", dml=True)
+            self.query("DELETE FROM tags", crs=crs)
             for tag, num in self.tagmap.iteritems():
-                self.query("INSERT INTO tags(id, tag) VALUES (%s, %s)", (num, tag), dml=True)
-            self.query("DELETE FROM categories", dml=True)
+                self.query("INSERT INTO tags(id, tag) VALUES (%s, %s)", (num, tag), crs=crs)
+            self.query("DELETE FROM categories", crs=crs)
             for cat_subcat, num in self.catmap.iteritems():
                 catsplit = cat_subcat.split(".", 1)
                 category = catsplit[0]
                 subcategory = catsplit[1] if len(catsplit)>1 else None
                 self.query("INSERT INTO categories(id, category, subcategory, cat_subcat) VALUES (%s, %s, %s, %s)",
-                    (num, category, subcategory, cat_subcat), dml=True)
+                    (num, category, subcategory, cat_subcat), crs=crs)
             self.con.commit()
         except Exception as e:
             self.con.rollback()
@@ -762,7 +753,7 @@ class MySQL(ObjectReq):
                 "    GROUP BY client_id"
                 " ) AS maxids ON last=id"
                 " WHERE timestamp < DATE_SUB(CURDATE(), INTERVAL %s DAY) AND last IS NULL",
-                days, dml=True)
+                days)
             affected = self.con.affected_rows()
             self.con.commit()
         except Exception as e:
@@ -775,7 +766,7 @@ class MySQL(ObjectReq):
         try:
             self.query(
                 "DELETE FROM events WHERE received < DATE_SUB(CURDATE(), INTERVAL %s DAY)",
-                days, dml=True)
+                days)
             affected = self.con.affected_rows()
             self.con.commit()
         except Exception as e:
@@ -1051,6 +1042,9 @@ class WardenHandler(ObjectReq):
                       "send_events_limit": self.send_events_limit}]))
 
         saved = 0
+        events_tosend = []
+        events_raw = []
+        events_nums = []
         for i, event in enumerate(events[0:self.send_events_limit]):
             v_errs = self.validator.check(event)
             if v_errs:
@@ -1068,12 +1062,21 @@ class WardenHandler(ObjectReq):
                     "categories": event.get('Category', [])}]))
                 continue
 
-            db_errs = self.db.store_event(self.req.client, event)
-            if db_errs:
-                errs.extend(self.add_event_nums([i], events, db_errs))
+            raw_event = json.dumps(event)
+            if len(raw_event) >= self.db.event_size_limit:
+                errs.extend(self.add_event_nums([i], events, [{"error": 413, "message": "Event too long (>%i B)" % self.event_size_limit}]))
                 continue
 
-            saved += 1
+            events_tosend.append(event)
+            events_raw.append(raw_event)
+            events_nums.append(i)
+
+        db_errs = self.db.store_events(self.req.client, events_tosend, events_raw)
+        if db_errs:
+            errs.extend(self.add_event_nums(events_nums, events_tosend, db_errs))
+            saved = 0
+        else:
+            saved = len(events_tosend)
 
         logging.info("Saved %i events" % saved)
         if errs:
-- 
GitLab