Skip to content
Snippets Groups Projects
warden_server.py 83.2 KiB
Newer Older
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011-2015 Cesnet z.s.p.o
# Use of this source is governed by a 3-clause BSD-style license, see LICENSE file.

from __future__ import print_function

import abc
import io
import json
import logging
import logging.handlers
import os
import re
import sys
from collections import namedtuple
from itertools import repeat
from os import path
from random import randint
from time import sleep
from traceback import format_tb
if sys.version_info[0] >= 3:
    import configparser as ConfigParser
    from urllib.parse import parse_qs
    unicode = str

    def get_method_params(method):
        """ Return names of the positional parameters of ``method``. """
        code = method.__code__
        return code.co_varnames[:code.co_argcount]

else:
    import ConfigParser
    from urlparse import parse_qs

    def get_method_params(method):
        """ Return names of the positional parameters of ``method``. """
        code = method.func_code
        return code.co_varnames[:code.co_argcount]


# for local version of up to date jsonschema
sys.path.append(path.join(path.dirname(__file__), "..", "lib"))

Pavel Kácha's avatar
Pavel Kácha committed
VERSION = "3.0-beta3"
Pavel Kácha's avatar
Pavel Kácha committed

class Error(Exception):
    """ Exception carrying one or more error descriptions of a request.

        Each description is a dict, normally with "error" (numeric code)
        and "message" keys. Optional "exc" holds a sys.exc_info() triple;
        any other keys are treated as details (see str_info).
    """

    def __init__(self, method=None, req_id=None, errors=None, **kwargs):
        """ :param method: method (request path) the error relates to
            :param req_id: id of the request the error relates to
            :param errors: list of further error description dicts
            :param kwargs: first error description, if any
        """
        self.method = method
        self.req_id = req_id    # needed by to_dict()
        self.errors = [kwargs] if kwargs else []
        if errors:
            self.errors.extend(errors)

    def append(self, _events=None, **kwargs):
        """ Add one more error description (``_events`` is ignored). """
        self.errors.append(kwargs)

    def get_http_err_msg(self):
        """ Summarize all errors into one (HTTP status code, message) pair.

            Differing codes get rounded to their hundreds and the highest
            one wins; differing messages become "Multiple errors".
        """
        try:
            err = self.errors[0]["error"]
            msg = self.errors[0]["message"].replace("\n", " ")
        except (IndexError, KeyError):
            err = 500
            msg = "There's NO self-destruction button! Ah, you've just found it..."
        for e in self.errors:
            next_err = e.get("error", 500)
            if err != next_err:
                # errors not same, round to basic err code (400, 500)
                # and use the highest one
                err = max(err//100, next_err//100)*100
            next_msg = e.get("message", "Unknown error").replace("\n", " ")
            if msg != next_msg:
                msg = "Multiple errors"
        return err, msg

    def __str__(self):
        return "\n".join(self.str_err(e) for e in self.errors)

    def log(self, logger, prio=logging.ERROR):
        """ Log every error at ``prio``, details at INFO, traceback at DEBUG. """
        for e in self.errors:
            logger.log(prio, self.str_err(e))
            info = self.str_info(e)
            if info:
                logger.info(info)
            debug = self.str_debug(e)
            if debug:
                logger.debug(debug)

    def str_err(self, e):
        """ Return one-line description of error dict ``e``, with its cause if known. """
        out = ["Error(%s) %s " % (e.get("error", 0), e.get("message", "Unknown error"))]
        exc = e.get("exc")
        # sys.exc_info() outside of an exception handler yields (None, None, None),
        # which is truthy but carries no exception class to describe.
        if exc and exc[0] is not None:
            out.append("(cause was %s: %s)" % (exc[0].__name__, str(exc[1])))
        return "".join(out)

    def str_info(self, e):
        """ Return "Detail: <json>" of the extra keys of ``e``, or "" if none. """
        ecopy = dict(e)    # shallow copy
        ecopy.pop("req_id", None)
        ecopy.pop("method", None)
        ecopy.pop("error", None)
        ecopy.pop("message", None)
        ecopy.pop("exc", None)
        if ecopy:
            out = "Detail: %s" % (json.dumps(ecopy, default=lambda v: str(v)))
        else:
            out = ""
        return out

    def str_debug(self, e):
        """ Return the formatted traceback of the cause of ``e``, or "". """
        if not e.get("exc"):
            return ""
        out = []
        exc_tb = e["exc"][2]
        if exc_tb:
            out.append("Traceback:\n")
            out.extend(format_tb(exc_tb))
        return "".join(out)

    def to_dict(self):
        """ Return JSON-serializable dict (method, req_id, errors without "exc"). """
        errlist = []
        for e in self.errors:
            ecopy = dict(e)
            ecopy.pop("exc", None)
            errlist.append(ecopy)
        d = {
            "method": self.method,
            "req_id": self.req_id,
            "errors": errlist
        }
        return d


def get_clean_root_logger(level=logging.INFO):
    """ Attempts to get logging module into clean slate state.

        Returns the root logger set to ``level``, with all its handlers
        and filters removed.
    """

    # We want to be able to set up at least stderr logger before any
    # configuration is read, and then later get rid of it and set up
    # whatever administrator requires.
    # However, there can exist only one logger, but we want to get a clean
    # slate everytime we initialize StreamLogger or FileLogger... which
    # is not exactly supported by logging module.
    # So, we look directly inside logger class and clean up handlers/filters
    # manually.
    logger = logging.getLogger()    # root logger
    logger.setLevel(level)
    while logger.handlers:
        logger.removeHandler(logger.handlers[0])
    while logger.filters:
        logger.removeFilter(logger.filters[0])
    # Callers (StreamLogger, FileLogger, SysLogger) attach handlers to the result.
    return logger
def StreamLogger(stream=sys.stderr, level=logging.DEBUG):
    """ Fallback handler just for setup, not meant to be used from
        configuration file because during wsgi query stdout/stderr
        is forbidden.

        Returns the reconfigured root logger, like the other logger
        factories do.
    """

    fhand = logging.StreamHandler(stream)
    fform = logging.Formatter('%(asctime)s %(filename)s[%(process)d]: (%(levelname)s) %(message)s')
    fhand.setFormatter(fform)
    logger = get_clean_root_logger(level)
    logger.addHandler(fhand)
    return logger
class LogRequestFilter(logging.Filter):
    """ Filter adding info about the ongoing request (``req_preamble``
        attribute) into every log record passing through it, no matter
        how the record came into existence.
    """

    def __init__(self, req):
        logging.Filter.__init__(self)
        self.req = req

    def filter(self, record):
        """ Set ``record.req_preamble`` and let the record through.

            The preamble is "<req_id as 8 hex digits>/<path>: " while a
            request is processed (req.env set) and "" otherwise, so
            formats using %(req_preamble)s always work.
        """
        if self.req.env:
            record.req_preamble = "%08x/%s: " % (self.req.req_id or 0, self.req.path)
        else:
            record.req_preamble = ""
        # A false return value would make logging drop the record.
        return True


def FileLogger(req, filename, level=logging.INFO):
    """ Reconfigure root logger to log into ``filename``, with request
        info in every line. Returns the logger.
    """

    fhand = logging.FileHandler(filename)
    fform = logging.Formatter('%(asctime)s %(filename)s[%(process)d]: (%(levelname)s) %(req_preamble)s%(message)s')
    fhand.setFormatter(fform)
    # The format needs %(req_preamble)s, which only LogRequestFilter sets.
    # Handler filters also see records propagated from child loggers.
    fhand.addFilter(LogRequestFilter(req))

    logger = get_clean_root_logger(level)
    logger.addHandler(fhand)
    logger.info("Initialized FileLogger(req=%r, filename=\"%s\", level=%s)" % (req, filename, level))
    return logger


def SysLogger(req, socket="/dev/log", facility=logging.handlers.SysLogHandler.LOG_DAEMON, level=logging.INFO):
    """ Reconfigure root logger to log into syslog at ``socket``, with
        request info in every line. Returns the logger.
    """

    fhand = logging.handlers.SysLogHandler(address=socket, facility=facility)
    fform = logging.Formatter('%(filename)s[%(process)d]: (%(levelname)s) %(req_preamble)s%(message)s')
    fhand.setFormatter(fform)
    # The format needs %(req_preamble)s, which only LogRequestFilter sets.
    fhand.addFilter(LogRequestFilter(req))

    logger = get_clean_root_logger(level)
    logger.addHandler(fhand)
    logger.info("Initialized SysLogger(req=%r, socket=\"%s\", facility=\"%d\", level=%s)" % (req, socket, facility, level))
    return logger
Pavel Kácha's avatar
Pavel Kácha committed
# One client record; built from database rows via Client(**row).
Client = namedtuple(
    "Client",
    "id registered requestor hostname name "
    "secret valid read debug write test note")
        attrs = get_method_params(self.__init__)[1:]
Pavel Kácha's avatar
Pavel Kácha committed
        eq_str = ["%s=%r" % (attr, getattr(self, attr, None)) for attr in attrs]
        return "%s(%s)" % (type(self).__name__, ", ".join(eq_str))
class Request(Object):
    """ Simple container for info about ongoing request.
        One instance gets created before server startup, and all other
        configured objects get it as parameter during instantiation.

        Server then takes care of populating this instance on the start
        of wsgi request (and resetting at the end). All other objects
        then can find this actual request info in their own self.req.

        However, only Server.wsgi_app, handler (WardenHandler) exposed
        methods and logging related objects should use self.req directly.
        All other objects should use self.req only as source of data for
        error/exception handling/logging, and should take/return
        necessary data as arguments/return values for clarity on
        which data their main codepaths work with.
    """
    def reset(self, env=None, client=None, path=None, req_id=None):
        """ (Re)populate with data of a request; no arguments mean "no request".

            Without explicit ``req_id``, a random 32-bit id is drawn when
            ``env`` is given, otherwise 0 is used.
        """
        self.env = env
        self.client = client
        self.path = path or ""
        if req_id is not None:
            self.req_id = req_id
        else:
            self.req_id = 0 if env is None else randint(0x00000000, 0xFFFFFFFF)

    __init__ = reset

    def error(self, **kwargs):
        """ Return an Error tagged with current path (as method) and request id. """
        return Error(self.path, self.req_id, **kwargs)
class PlainAuthenticator(ObjectBase):
    """ Authenticates clients by name (and secret) from query arguments
        against the database and authorizes them by their permission flags.
    """

    def __init__(self, req, log, db):
        ObjectBase.__init__(self, req, log)
        self.db = db

    def authenticate(self, env, args, hostnames=None, check_secret=True):
        """ Look up the client described by query ``args``.

            :param env: wsgi environment (not used here)
            :param args: query arguments, a dict mapping names to lists of
                values; "client" and "secret" get removed on success
            :param hostnames: hostnames the client must match, or None
            :param check_secret: whether "secret" is used and enforced
            :returns: Client on success, None on failure
        """
        name = args.get("client", [None])[0]
        secret = args.get("secret", [None])[0] if check_secret else None

        client = self.db.get_client_by_name(hostnames, name, secret)

        if not client:
            self.log.info("authenticate: client not found by name: \"%s\", secret: %s, hostnames: %s" % (
                name, secret, str(hostnames)))
            return None

        # Clients with 'secret' set must get authenticated by it.
        # No secret turns secret auth off for this particular client.
        if client.secret is not None and secret is None and check_secret:
            self.log.info("authenticate: missing secret argument")
            return None

        self.log.info("authenticate: %s" % str(client))
        # These args are not for handler
        args.pop("client", None)
        args.pop("secret", None)
        return client

    def authorize(self, env, client, path, method):
        """ Check that ``client`` has every permission ``method`` requires.

            :returns: ``client`` when allowed, None when denied
        """
        if method.debug and not client.debug:
            self.log.info("authorize: failed, client does not have debug enabled")
            return None
        if method.read and not client.read:
            self.log.info("authorize: failed, client does not have read enabled")
            return None
        if method.write and not (client.write or client.test):
            self.log.info("authorize: failed, client is not allowed to write or test")
            return None
        return client
class X509Authenticator(PlainAuthenticator):
    """ Authenticates clients by the DNS names in the client certificate
        verified by Apache (SSL_CLIENT_VERIFY / SSL_CLIENT_CERT in env).
    """

    def get_cert_dns_names(self, pem):
        """ Return DNS names from PEM certificate: the first subject CN first,
            then the remaining CNs and subjectAltName DNS entries (deduplicated).
            Raises IndexError when the certificate has no CN.
        """

        cert = M2Crypto.X509.load_cert_string(pem)

        subj = cert.get_subject()
        commons = [n.get_data().as_text() for n in subj.get_entries_by_nid(subj.nid["CN"])]

        try:
            extstrs = cert.get_ext("subjectAltName").get_value().split(",")
        except LookupError:
            # certificate has no subjectAltName extension
            extstrs = []
        extstrs = [val.strip() for val in extstrs]
        altnames = [val[4:] for val in extstrs if val.startswith("DNS:")]

        # bit of mangling to get rid of duplicates and leave commonname first
        firstcommon = commons[0]
        return [firstcommon] + list(set(altnames+commons) - set([firstcommon]))

    def is_verified_by_apache(self, env, args):
        """ Return True if Apache verified the client certificate, else log and return False. """
        # Allows correct work while SSLVerifyClient both "optional" and "required"
        verify = env.get("SSL_CLIENT_VERIFY")
        if verify == "SUCCESS":
            return True
        exception = self.req.error(
            message="authenticate: certificate verification failed",
            error=403, args=args, ssl_client_verify=verify, cert=env.get("SSL_CLIENT_CERT"))
        exception.log(self.log)
        return False

    def authenticate(self, env, args):
        """ Authenticate by certificate hostnames; return Client or None. """
        if not self.is_verified_by_apache(env, args):
            return None

        try:
            cert_names = self.get_cert_dns_names(env["SSL_CLIENT_CERT"])
        except Exception:
            exception = self.req.error(
                message="authenticate: cannot get or parse certificate from env",
                error=403, exc=sys.exc_info(), env=env)
            exception.log(self.log)
            return None

        return PlainAuthenticator.authenticate(self, env, args, hostnames=cert_names)

class X509NameAuthenticator(X509Authenticator):
    """ Authenticates clients by certificate subject CN used as client name. """

    def authenticate(self, env, args):
        """ Authenticate by SSL_CLIENT_S_DN_CN; return Client or None.

            A "client" query argument, if given, must equal the certificate CN.
        """
        if not self.is_verified_by_apache(env, args):
            return None

        try:
            cert_name = env["SSL_CLIENT_S_DN_CN"]
        except KeyError:
            exception = self.req.error(
                message="authenticate: cannot get or parse certificate from env",
                error=403, exc=sys.exc_info(), env=env)
            exception.log(self.log)
            return None

        # Missing "client" argument defaults to the certificate name.
        if cert_name != args.setdefault("client", [cert_name])[0]:
            exception = self.req.error(
                message="authenticate: client name does not correspond with certificate",
                error=403, cn=cert_name, args=args)
            exception.log(self.log)
            return None

        return PlainAuthenticator.authenticate(self, env, args, check_secret=False)
class X509MixMatchAuthenticator(X509Authenticator):
    """ Chooses per request between certificate name authentication
        (X509NameAuthenticator) and certificate hostname authentication
        (X509Authenticator).
    """

    def __init__(self, req, log, db):
        PlainAuthenticator.__init__(self, req, log, db)
        self.hostname_auth = X509Authenticator(req, log, db)
        self.name_auth = X509NameAuthenticator(req, log, db)

    def authenticate(self, env, args):
        """ Delegate to the name or hostname authenticator; return Client or None. """
        if not self.is_verified_by_apache(env, args):
            return None

        try:
            cert_name = env["SSL_CLIENT_S_DN_CN"]
        except KeyError:
            exception = self.req.error(
                message="authenticate: cannot get or parse certificate from env",
                error=403, exc=sys.exc_info(), env=env)
            exception.log(self.log)
            return None

        name = args.get("client", [None])[0]
        secret = args.get("secret", [None])[0]
        # Client names are in reverse notation than DNS, client name should
        # thus never be the same as machine hostname (if it is, client
        # admin does something very amiss).
        # So, if client sends the same name in query as in the certificate,
        # or sends no name or secret (which is necessary for hostname auth),
        # use X509NameAuthenticator. Otherwise (names are different and there
        # is name and/or secret in query) use (hostname) X509Authenticator.
        if name == cert_name or (name is None and secret is None):
            auth = self.name_auth
        else:
            auth = self.hostname_auth
        self.log.info("MixMatch is choosing %s (name: %s, cert_name: %s)" % (type(auth).__name__, name, cert_name))
        return auth.authenticate(env, args)


class NoValidator(ObjectBase):
    """ Event validator which accepts every event. """

    def __init__(self, req, log):
        ObjectBase.__init__(self, req, log)

    def check(self, event):
        """ Return list of validation error dicts for ``event``; always empty. """
        return []


class JSONSchemaValidator(NoValidator):
    """ Validates events against a JSON schema (Draft 4), by default
        "idea.schema" next to this file.
    """

    def __init__(self, req, log, filename=None):
        NoValidator.__init__(self, req, log)
        self.path = filename or path.join(path.dirname(__file__), "idea.schema")
        with io.open(self.path, "r", encoding="utf-8") as f:
            self.schema = json.load(f)
        self.validator = Draft4Validator(self.schema)

    def check(self, event):
        """ Return list of error dicts (code 460) for schema violations of ``event``. """

        def sortkey(k):
            """ Treat keys as lowercase, prefer keys with less path segments """
            return (len(k.path), "/".join(str(v) for v in k.path).lower())

        res = []
        for error in sorted(self.validator.iter_errors(event), key=sortkey):
            res.append({
                "error": 460,
                "message": "Validation error: key \"%s\", value \"%s\"" % (
                    "/".join(str(v) for v in error.path),
                    error.instance
                ),
                "expected": error.schema.get('description', 'no additional info')
            })
        return res
class DataBase(ObjectBase):
Pavel Kácha's avatar
Pavel Kácha committed

Pavel Kácha's avatar
Pavel Kácha committed
    def __init__(
            self, req, log, host, user, password, dbname, port, retry_count,
            retry_pause, event_size_limit, catmap_filename, tagmap_filename):
        ObjectBase.__init__(self, req, log)

        self.host = host
        self.user = user
        self.password = password
        self.dbname = dbname
        self.port = port
        self.retry_count = retry_count
        self.retry_pause = retry_pause
        self.event_size_limit = event_size_limit
        self.catmap_filename = catmap_filename
        self.tagmap_filename = tagmap_filename

        with io.open(catmap_filename, "r", encoding="utf-8") as catmap_fd:
            self.catmap = json.load(catmap_fd)
            self.catmap_other = self.catmap["Other"]    # Catch error soon, avoid lookup later

        with io.open(tagmap_filename, "r", encoding="utf-8") as tagmap_fd:
            self.tagmap = json.load(tagmap_fd)
            self.tagmap_other = self.tagmap["Other"]    # Catch error soon, avoid lookup later

        # Not connected yet; _query() connects on first use.
        self.con = None
        # Standalone "with self as db:" (outside repeat()) must not retry.
        self.retry_attempt = 0

    def close(self):
        """ Close the connection, ignoring errors; next query reconnects. """
        try:
            if self.con:
                self.con.close()
        except Exception:
            pass
        self.con = None

    def __del__(self):
        self.close()
Pavel Kácha's avatar
Pavel Kácha committed

    def repeat(self):
        """ Generator yielding self up to self.retry_count times, pausing
            self.retry_pause between attempts; a successful transaction
            (see __exit__) zeroes self.retry_attempt and ends the loop.

            Meant for usage with context manager:

            for attempt in self.repeat():
                with attempt as db:
                    res = db.query_all(...)

            Not reentrant (neither is the underlying MySQL connection),
            so avoid nesting on the same MySQL object.
        """
        self.retry_attempt = self.retry_count
        while self.retry_attempt:
            first_try = self.retry_attempt == self.retry_count
            if not first_try:
                sleep(self.retry_pause)
            self.retry_attempt -= 1
            yield self

    def __enter__(self):
        """ Context manager protocol. Guarantees that transaction will
            get either commited or rolled back in case of database
            exception. Can be used with self.repeat(), or alone as:

            with self as db:
                res = db.query_all(...)

            Note that it's not reentrant (as is not underlying MySQL
            connection), so avoid nesting on the same MySQL object.
        """
        # getattr: standalone use may come before repeat() ever set the attribute.
        if not getattr(self, "retry_attempt", None):
            self.retry_attempt = 0
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """ Context manager protocol. In case of no exception, transaction
            gets commited. Otherwise the open transaction is rolled back
            and the connection closed; then a database exception
            (self.db.Error) is only logged and does not propagate if
            self.retry_attempt is not zero, so that repeat() tries again.
            Any other exception always propagates.
        """
        if exc_type is None:
            try:
                if self.con is not None:    # body may have issued no query
                    self.con.commit()
            finally:
                self.retry_attempt = 0
            return None

        try:
            if self.con is not None:
                self.con.rollback()
        except self.db.Error:
            pass
        try:
            self.close()
        except self.db.Error:
            pass
        if self.retry_attempt > 0 and issubclass(exc_type, self.db.Error):
            self.log.info("Database error (%d attempts left): %s %s" %
                          (self.retry_attempt, exc_type.__name__, exc_val))
            return True
        # Propagating: make sure a later standalone "with self" does not retry.
        self.retry_attempt = 0
        return None

    def _query(self, *args, **kwargs):
        """ Execute one query (connecting first if needed) and return the cursor. """
        if not self.con:
            self.connect()
        cursor = self.con.cursor()
        self.log.debug("execute: %s %s" % (args, kwargs))
        cursor.execute(*args, **kwargs)
        return cursor

    def _query_multiple(self, query, params, ret, fetch):
        """ Execute queries paired with params; return ``fetch(cursor)`` of
            the ``ret``-th query (0 based, -1 means the last one), or None
            when ``ret`` matches no query.
        """
        res = None
        cur = None
        for n, (q, p) in enumerate(zip(query, params)):
            cur = self._query(q, p)
            if n == ret:
                res = fetch(cur)
        if ret == -1 and cur is not None:  # fetch the result of the last query
            res = fetch(cur)
        return res
    def execute(self, query, params, ret=None):
        """ Run the provided queries and throw any result away (``ret`` is unused). """
        self._query_multiple(query, params, None, None)

    def query_all(self, query, params, ret=-1):
        """ Run the provided queries; return all rows (as dicts) of the ret-th query (0 based). """
        return self._query_multiple(query, params, ret, lambda cursor: cursor.fetchall())

    def query_one(self, query, params, ret=-1):
        """ Run the provided queries; return the first row of the ret-th query (0 based). """
        return self._query_multiple(query, params, ret, lambda cursor: cursor.fetchone())

    def query_rowcount(self, query, params, ret=-1):
        """ Run the provided queries; return the affected/returned row count of the ret-th query (0 based). """
        return self._query_multiple(query, params, ret, lambda cursor: cursor.rowcount)

    def _get_comma_perc(self, l):
        """ Return "%s,%s,..." placeholders, ``l`` being a count or a sized collection. """
        count = l if isinstance(l, int) else len(l)
        return ",".join(["%s"] * count)

    def _get_comma_perc_n(self, n, l):
        """ Return "(%s,...,%s), (...)" - one group of n placeholders per item of ``l``. """
        group = "(%s)" % self._get_comma_perc(n)
        return ", ".join([group] * len(l))

    def _get_not(self, b):
        """ Return "NOT" for a false ``b`` and "" for a true one. """
        return "NOT" if not b else ""
    def _build_get_client_by_name(self, cert_names, name, secret):
        """Build query and params for client lookup"""

    def get_client_by_name(self, cert_names=None, name=None, secret=None):
        """ Return the Client matching name/secret/certificate hostnames.

            :returns: Client, or None when nothing or more than one client matches
        """
        query, params, ret = self._build_get_client_by_name(cert_names, name, secret)
        for attempt in self.repeat():
            with attempt as db:
                rows = db.query_all(query, params, ret)

                if len(rows) > 1:
                    # An ambiguous match must not authenticate anybody.
                    self.log.warning(
                        "get_client_by_name: query returned more than one result (cert_names = %s, name = %s, secret = %s): %s" %
                        (cert_names, name, secret, ", ".join([str(Client(**row)) for row in rows]))
                    )
                    return None

                return Client(**rows[0]) if rows else None
    def _build_get_clients(self, id):
        """Build query and params for client lookup by id"""

    def get_clients(self, id=None):
        """ Return list of Clients; restricted to the given ``id`` when set. """
        query, params, ret = self._build_get_clients(id)
        for attempt in self.repeat():
            with attempt as db:
                rows = db.query_all(query, params, ret=ret)
                return [Client(**row) for row in rows]
    def _build_add_modify_client(self, id, **kwargs):
        """Build query and params for adding/modifying client"""

    def add_modify_client(self, id=None, **kwargs):
        """ Insert a new client (``id`` is None) or update client ``id``.

            Only attributes given with a non-None value are written.

            :returns: id of the inserted or modified client
        """
        if id is not None and all(kwargs.get(attr, None) is None for attr in set(Client._fields) - {"id", "registered"}):
            # Nothing to modify (an UPDATE with no assignments would be invalid).
            return id

        query, params, ret = self._build_add_modify_client(id, **kwargs)
        for attempt in self.repeat():
            with attempt as db:
                res_id = db.query_one(query, params, ret=ret)["id"]
                newid = res_id if id is None else id
                return newid
    def _build_get_debug_version(self):
        """Build query and params for getting the database server version"""

    def _build_get_debug_tablestat(self):
        """Build query and params for getting table statistics"""

    def get_debug(self):
        """ Return dict with db class name, server version and table statistics. """
        vquery, vparams, vret = self._build_get_debug_version()
        tquery, tparams, tret = self._build_get_debug_tablestat()
        for attempt in self.repeat():
            with attempt as db:
                return {
                    "db": type(self).__name__,
                    "version": db.query_one(vquery, vparams, vret)["version"],
                    "tables": db.query_all(tquery, tparams, tret)
                }
    def getMaps(self, section, variables):
        """ Translate names (categories or tags) to ids using the ``section`` map.

            :param section: mapping of name to id (e.g. self.catmap, self.tagmap)
            :param variables: iterable of names requested in the query
            :returns: set of unique mapped ids
            :raises Error: 422 when a name is not in the map
        """
        maps = []
        for v in variables:
            try:
                mapped = section[v]
            except KeyError:
                raise self.req.error(
                    message="Wrong tag or category used in query.",
                    error=422, exc=sys.exc_info(), key=v
                )
            maps.append(mapped)
        return set(maps)    # unique
    def _build_fetch_events(
            self, client, id, count,
            cat, nocat, tag, notag, group, nogroup):
        """Build query and params for fetching events based on id, count and category, tag and group filters"""

    def _load_event_json(self, data):
        """Return decoded json from data loaded from database, if unable to decode, return None"""

    def fetch_events(
            self, client, id, count,
            cat=None, nocat=None,
            tag=None, notag=None,
            group=None, nogroup=None):
        """ Fetch events with id greater than ``id`` for ``client``,
            filtered by categories, tags and groups (either positive or
            negative variant of each). ``count`` is passed to
            _build_fetch_events, presumably as the event limit - TODO confirm.

            :returns: dict with "lastid" (highest fetched id, or the last
                stored event id if nothing was fetched) and "events" (list
                of decoded events; undecodable ones are skipped)
            :raises Error: 422 when both variants of the same filter are given
        """
        if cat and nocat:
            raise self.req.error(
                message="Unrealizable conditions. Choose cat or nocat option.",
                error=422, cat=cat, nocat=nocat)
        if tag and notag:
            raise self.req.error(
                message="Unrealizable conditions. Choose tag or notag option.",
                error=422, tag=tag, notag=notag)
        if group and nogroup:
            raise self.req.error(
                message="Unrealizable conditions. Choose group or nogroup option.",
                error=422, group=group, nogroup=nogroup)

        query, params, ret = self._build_fetch_events(
            client, id, count,
            cat, nocat,
            tag, notag,
            group, nogroup
        )
        row = None
        for attempt in self.repeat():
            with attempt as db:
                row = db.query_all(query, params, ret=ret)
        rows = row or ()    # repeat() yields nothing when retry_count is 0

        if rows:
            maxid = max(r['id'] for r in rows)
        else:
            maxid = self.getLastEventId()

        events = []
        for r in rows:
            e = self._load_event_json(r["data"])
            if e is None:  # null cannot be valid event JSON
                # Note that we use Error object just for proper formatting,
                # but do not raise it; from client perspective invalid
                # events get skipped silently.
                # No exc: _load_event_json already swallowed the exception,
                # so sys.exc_info() here would not describe the failure.
                err = self.req.error(
                    message="Unable to deserialize JSON event from db, id=%s" % r["id"],
                    error=500, id=r["id"])
                err.log(self.log, prio=logging.WARNING)
            else:
                events.append(e)

        return {
            "lastid": maxid,
            "events": events
        }
    def _build_store_events_event(self, client, event, raw_event):
        """Build query and params for event insertion"""

    def _build_store_events_categories(self, event_id, cat_ids):
        """Build query and params for insertion of event-categories mapping"""

    def _build_store_events_tags(self, event_id, tag_ids):
        """Build query and params for insertion of event-tags mapping"""

    def store_events(self, client, events, events_raw):
        """ Store ``events`` with their category and tag mappings in one
            transaction; ``events_raw`` are the events' raw forms, paired by order.

            :returns: [] on success, list with one error dict on failure
        """
        try:
            for attempt in self.repeat():
                with attempt as db:
                    for event, raw_event in zip(events, events_raw):
                        equery, eparams, eret = self._build_store_events_event(client, event, raw_event)
                        lastid = db.query_one(equery, eparams, ret=eret)["id"]

                        # Map full categories and their top level parts.
                        catlist = event.get("Category", ["Other"])
                        cats = set(catlist) | {cat.split(".", 1)[0] for cat in catlist}
                        cat_ids = [self.catmap.get(cat, self.catmap_other) for cat in cats]
                        cquery, cparams, _ = self._build_store_events_categories(lastid, cat_ids)
                        db.execute(cquery, cparams)

                        nodes = event.get("Node", [])
                        tags = {tag for node in nodes for tag in node.get('Type', [])}
                        if tags:
                            tag_ids = [self.tagmap.get(tag, self.tagmap_other) for tag in tags]
                            tquery, tparams, _ = self._build_store_events_tags(lastid, tag_ids)
                            db.execute(tquery, tparams)
            return []
        except Exception as e:
            exception = self.req.error(message="DB error", error=500, exc=sys.exc_info(), env=self.req.env)
            exception.log(self.log)
            return [{"error": 500, "message": "DB error %s" % type(e).__name__}]
    def _build_insert_last_received_id(self, client, id):
        """Build query and params for insertion of the last event id received by client"""

    def insertLastReceivedId(self, client, id):
        """ Record ``id`` as the last event id received by ``client``. """
        self.log.debug("insertLastReceivedId: id %i for client %i(%s)" % (id, client.id, client.hostname))
        query, params, _ = self._build_insert_last_received_id(client, id)
        for attempt in self.repeat():
            with attempt as db:
                db.execute(query, params)

    def _build_get_last_event_id(self):
        """Build query and params for querying the id of the last inserted event"""

    def getLastEventId(self):
        """ Return id of the last stored event, or 1 when there is none. """
        query, params, ret = self._build_get_last_event_id()
        for attempt in self.repeat():
            with attempt as db:
                id_ = db.query_one(query, params, ret=ret)["id"]
                return id_ or 1

    def _build_get_last_received_id(self, client):
        """Build query and params for querying the last event id received by client"""

    def getLastReceivedId(self, client):
        """ Return the last event id received by ``client``, or None when
            there is no record yet (e.g. first access).
        """
        query, params, ret = self._build_get_last_received_id(client)
        for attempt in self.repeat():
            with attempt as db:
                res = db.query_one(query, params, ret=ret)

                if res is None:
                    self.log.debug("getLastReceivedId: probably first access, unable to get id for client %i(%s)" %
                        (client.id, client.hostname))
                    return None

                last_id = res["id"] or 1
                self.log.debug("getLastReceivedId: id %i for client %i(%s)" %
                    (last_id, client.id, client.hostname))
                return last_id
    def _build_load_maps_tags(self):
        """Build query and params for updating the tag map"""

    def _build_load_maps_cats(self):
        """Build query and params for updating the category map"""

    def load_maps(self):
        """ Run the tag and category map update queries in one transaction. """
        tquery, tparams, _ = self._build_load_maps_tags()
        cquery, cparams, _ = self._build_load_maps_cats()
        with self as db:
            db.execute(tquery, tparams)
            db.execute(cquery, cparams)

    def _build_purge_lastlog(self, days):
        """Build query and params for purging stored client last event mapping older than days"""

    def purge_lastlog(self, days):
        """ Purge client last-event records older than ``days``; return affected row count. """
        query, params, ret = self._build_purge_lastlog(days)
        with self as db:
            return db.query_rowcount(query, params, ret=ret)

    def _build_purge_events_get_id(self, days):
        """Build query and params returning (as "id") the largest id of events older than days"""

    def _build_purge_events_events(self, id_):
        """Build query and params removing events up to id_ together with their mappings"""

    def purge_events(self, days):
        """ Remove events older than ``days``; return affected row count (0 if none). """
        id_query, id_params, id_ret = self._build_purge_events_get_id(days)
        with self as db:
            newest_old_id = db.query_one(id_query, id_params, ret=id_ret)["id"]
            if newest_old_id is None:
                return 0
            del_query, del_params, del_ret = self._build_purge_events_events(newest_old_id)
            return db.query_rowcount(del_query, del_params, ret=del_ret)


# Re-create DataBase with ABCMeta applied by a direct call instead of the
# class-statement metaclass syntax, which differs between Python 2 and 3.
DataBase = abc.ABCMeta(DataBase.__name__, (DataBase,), {})


class MySQL(DataBase):

    def __init__(
            self, req, log, host, user, password, dbname, port, retry_count,
            retry_pause, event_size_limit, catmap_filename, tagmap_filename):

        # super() must name this class; naming DataBase (the ABCMeta wrapper)
        # only worked because the wrapper defines no __init__ of its own.
        super(MySQL, self).__init__(req, log, host, user, password, dbname, port, retry_count,
            retry_pause, event_size_limit, catmap_filename, tagmap_filename)

        # Imported here so the module loads even without MySQLdb installed.
        import MySQLdb as db
        import MySQLdb.cursors as mycursors
        self.db = db
        self.mycursors = mycursors

    def connect(self):
        """ Open a new connection whose cursors return rows as dicts (DictCursor). """
        conn_args = dict(
            host=self.host, user=self.user, passwd=self.password,
            db=self.dbname, port=self.port, cursorclass=self.mycursors.DictCursor)
        self.con = self.db.connect(**conn_args)

    def _build_get_client_by_name(self, cert_names=None, name=None, secret=None):
        """ Build lookup of valid clients, narrowed by each of name, secret
            and certificate hostnames that is given (names compared lowercase).
        """
        parts = ["SELECT * FROM clients WHERE valid = 1"]
        values = []
        if name:
            parts.append(" AND name = %s")
            values.append(name.lower())
        if secret:
            parts.append(" AND secret = %s")
            values.append(secret)
        if cert_names:
            parts.append(" AND hostname IN (%s)" % self._get_comma_perc(cert_names))
            values.extend([hostname.lower() for hostname in cert_names])

        sql = "".join(parts)
        return [sql], [values], 0

    def _build_get_clients(self, id):
        """ Build listing of all clients ordered by id, or only of client ``id`` if truthy. """
        sql = "SELECT * FROM clients"
        values = []
        if id:
            sql += " WHERE id = %s"
            values.append(id)
        sql += " ORDER BY id"

        return [sql], [values], 0

    def _build_add_modify_client(self, id, **kwargs):
        """ Build INSERT (id is None) or UPDATE of a client from the non-None
            attributes in kwargs, followed by a query returning LAST_INSERT_ID()
            as "id". An empty "secret" is stored as NULL (disables secret).
        """
        inserting = id is None
        assignments = ["registered = now()"] if inserting else []
        values = []
        for attr in set(Client._fields) - {"id", "registered"}:
            val = kwargs.get(attr, None)
            if val is None:  # guaranteed at least one is not None
                continue
            if attr == "secret" and val == "":  # disable secret
                val = None
            assignments.append("`%s` = %%s" % attr)
            values.append(val)

        statement = ["INSERT INTO clients SET" if inserting else "UPDATE clients SET",
                     ", ".join(assignments)]
        if not inserting:
            statement.append("WHERE id = %s")
            values.append(id)
        return (
            [" ".join(statement), 'SELECT LAST_INSERT_ID() AS id'],
            [values, []],
            1
        )

    def _build_get_debug_version(self):
        """ Query the server version, as column "version". """
        queries = ["SELECT VERSION() AS version"]
        return queries, [()], 0

    def _build_get_debug_tablestat(self):
        """ Query per-table status information. """
        queries = ["SHOW TABLE STATUS"]
        return queries, [()], 0

    def _load_event_json(self, data):
        """ Decode JSON event data from the database; None if it cannot be decoded. """
        try:
            decoded = json.loads(data)
        except Exception:
            decoded = None
        return decoded

    def _build_fetch_events(
            self, client, id, count,
            cat, nocat, tag, notag, group, nogroup):
        query = ["SELECT e.id, e.data FROM clients c RIGHT JOIN events e ON c.id = e.client_id WHERE e.id > %s"]
        params = [id or 0]

        if cat or nocat:
            cats = self.getMaps(self.catmap, (cat or nocat))
            query.append(
                " AND e.id %s IN (SELECT event_id FROM event_category_mapping WHERE category_id IN (%s))" %
                (self._get_not(cat), self._get_comma_perc(cats))
            )
            params.extend(cats)

        if tag or notag:
            tags = self.getMaps(self.tagmap, (tag or notag))
            query.append(
                " AND e.id %s IN (SELECT event_id FROM event_tag_mapping WHERE tag_id IN (%s))" %
                (self._get_not(tag), self._get_comma_perc(tags))
            )
            params.extend(tags)

        if group or nogroup: