Skip to content
Snippets Groups Projects
warden_server.py 21.6 KiB
Newer Older
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011-2013 Cesnet z.s.p.o
# Use of this source is governed by a 3-clause BSD-style license, see LICENSE file.

import sys
import logging
import logging.handlers
import ConfigParser
from traceback import format_tb
import M2Crypto.X509
import json
from uuid import uuid4
from time import time, gmtime
from math import trunc
from io import BytesIO
from urlparse import parse_qs
from os import path

# for local version of up to date jsonschema
sys.path.append(path.join(path.dirname(__file__), "..", "lib"))

from jsonschema import Draft4Validator, FormatChecker


# Server version string, reported to clients by WardenHandler.getInfo
VERSION = "3.0-not-even-alpha"


class Error(Exception):
    """ Server error carrying numeric error code and optional context.

        Instances are rendered into log lines (__str__, info_str,
        debug_str) and into dictionary meant for JSON error reply
        (to_dict).
    """

    def __init__(self, message, error=500, method=None,
            detail=None, exc=(None, None, None)):
        """ message - short human readable description
            error   - numeric error code, coerced to int
            method  - name of the method, in which error occurred
            detail  - arbitrary additional data
            exc     - (type, value, traceback) triple of causing
                      exception. Default triple means "no cause";
                      sys.exc_info() is consulted only when falsy value
                      (for example None) is passed explicitly.
        """
        self.error = int(error)
        self.method = method
        self.message = message
        self.detail = detail
        (self.exctype, self.excval, self.exctb) = exc or sys.exc_info()
        self.cause = self.excval # compatibility with other exceptions


    def __str__(self):
        out = []
        out.append("Error(%s)" % (self.error))
        if self.method is not None:
            out.append(" in \"%s\"" % self.method)
        if self.message is not None:
            out.append(": %s" % self.message)
        if self.excval is not None:
            out.append(" - cause was %s: %s" % (type(self.excval).__name__, str(self.excval)))
        return "".join(out)


    def info_str(self):
        """ Return line describing detail, or empty string if there is none """
        if self.detail is None:
            return ""
        # One element tuple keeps tuple details from being unpacked
        # as separate formatting arguments
        return "Detail: %s" % (self.detail,)


    def debug_str(self):
        """ Return description of causing exception and its traceback,
            or empty string if there is no cause.
        """
        out = []
        if self.excval is not None:
            out.append("Exception %s: %s\n" % (type(self.excval).__name__, str(self.excval)))
        if self.exctb is not None:
            out.append("Traceback:\n%s" % "".join(format_tb(self.exctb)))
        return "".join(out)


    def to_dict(self):
        """ Return dictionary with all attributes, which are set. Causing
            exception, if any, is appended to the message.
        """
        d = {}
        if self.error is not None:
            d["error"] = self.error
        if self.method is not None:
            d["method"] = self.method
        if self.message is not None:
            d["message"] = self.message
        if self.detail is not None:
            d["detail"] = self.detail
        if self.excval is not None:
            cause = "cause was %s: %s" % (type(self.excval).__name__, str(self.excval))
            # Message may be missing, cause alone must not blow up
            if "message" in d:
                d["message"] = "%s, %s" % (d["message"], cause)
            else:
                d["message"] = cause
        return d



def get_clean_root_logger(level=logging.INFO):
    """ Return root logger set to level and stripped of all handlers
        and filters.
    """

    # Some logger (at least stderr one) has to exist before configuration
    # is read, and has to be thrown away later in favour of what
    # administrator asked for. Logging module keeps just one root logger
    # and offers no reset, so previously attached handlers and filters
    # are detached one by one here.
    root = logging.getLogger()  # existing root logger is reused
    root.setLevel(level)
    # Iterate over copies, removal modifies the original lists
    for handler in list(root.handlers):
        root.removeHandler(handler)
    for filt in list(root.filters):
        root.removeFilter(filt)
    return root



def StreamLogger(stream=sys.stderr, level=logging.INFO):
    """ Log into stream. Serves only as fallback during setup and is
        not meant to be selected from configuration file, because
        stdout/stderr is forbidden during wsgi query.
    """

    handler = logging.StreamHandler(stream)
    handler.setFormatter(logging.Formatter(
        '%(asctime)s %(filename)s[%(process)d]: (%(levelname)s) %(message)s'))
    # Whatever was attached to root logger before gets replaced
    get_clean_root_logger(level).addHandler(handler)



def FileLogger(filename, level=logging.INFO):
    """ Make root logger write into file filename at given level,
        replacing handlers and filters set up before.
    """

    fhand = logging.FileHandler(filename)
    fform = logging.Formatter('%(asctime)s %(filename)s[%(process)d]: (%(levelname)s) %(message)s')
    fhand.setFormatter(fform)
    logger = get_clean_root_logger(level)
    logger.addHandler(fhand)
    # Second value is the level, label it the same way SysLogger does
    logging.info("Initialized FileLogger(filename=\"%s\", level=\"%s\")" % (filename, level))



def SysLogger(socket="/dev/log", facility=logging.handlers.SysLogHandler.LOG_DAEMON, level=logging.INFO):
    """ Make root logger send messages to syslog at address socket with
        given facility and level, replacing handlers and filters set up
        before.
    """

    handler = logging.handlers.SysLogHandler(address=socket, facility=facility)
    # No timestamp in the format, unlike the stream and file loggers
    handler.setFormatter(logging.Formatter(
        '%(filename)s[%(process)d]: (%(levelname)s) %(message)s'))
    get_clean_root_logger(level).addHandler(handler)
    settings = (socket, facility, level)
    logging.info("Initialized SysLogger(socket=\"%s\", facility=\"%s\", level=\"%s\")" % settings)



class Object(object):
    """ Base class giving descendants constructor-like textual form """

    def __str__(self):
        # Name of the actual (possibly derived) class, empty arg list
        return "".join((type(self).__name__, "()"))



class NoAuthenticator(Object):
    """ Authenticator, which recognizes every request as one fixed client """

    def __init__(self):
        super(NoAuthenticator, self).__init__()

    def authenticate (self, env):
        # Constant client name; None would stand for "not authenticated"
        client = "anybody"
        return client


    def authorize(self, env, client, method, args):
        # Any authenticated client may call anything
        return not (client is None)



class X509Authenticator(NoAuthenticator):
    """ Authenticator, which identifies client by DNS names found in
        PEM encoded certificate taken from SSL_CLIENT_CERT key of request
        environment.
    """

    def __init__(self, db):
        self.db = db
        NoAuthenticator.__init__(self)


    def __str__(self):
        return "%s(db=%s)" % (type(self).__name__, type(self.db).__name__)


    def get_cert_dns_names(self, pem):
        """ Return list of unique names from PEM encoded certificate.

            First subject common name (if there is any) goes first,
            subjectAltName DNS entries follow, remaining common names
            come last. Certificate bearing no name at all yields empty
            list.
        """

        cert = M2Crypto.X509.load_cert_string(pem)

        subj = cert.get_subject()
        commons = [n.get_data().as_text() for n in subj.get_entries_by_nid(subj.nid["CN"])]

        try:
            ext = cert.get_ext("subjectAltName")
        except LookupError:
            # Extension is optional, certificate without it is no error
            altnames = []
        else:
            extstrs = [val.strip() for val in ext.get_value().split(",")]
            altnames = [val[4:] for val in extstrs if val.startswith("DNS:")]

        # Get rid of duplicates while keeping order stable and first
        # commonname in the lead
        names = []
        for name in commons[:1] + altnames + commons[1:]:
            if name not in names:
                names.append(name)
        return names


    def authenticate (self, env):
        """ Return name of the client, or None if request carries no
            certificate or certificate carries no name. List of all
            names found is stored into env["warden.x509_dns_names"].
        """
        pem = env.get("SSL_CLIENT_CERT")
        if not pem:
            # No certificate at all - nobody to authenticate
            return None
        names = self.get_cert_dns_names(pem)
        # FIXME: should probably fetch and return id from db, not textual username
        env["warden.x509_dns_names"] = names
        return names[0] if names else None


    def authorize(self, env, client, method, args):
        # Here we might choose with methods or args to (dis)allow for which
        # client.
        # FIXME: fetch reader/writer or better list of allowed methods from db
        return (client is not None)



class NoValidator(Object):
    """ Validator, which finds every event valid """

    def check(self, event):
        # Nothing gets checked, so list of errors is always empty
        no_errors = []
        return no_errors



class JSONSchemaValidator(NoValidator):
    """ Validator checking events against JSON schema (draft 4) """

    def __init__(self, filename=None):
        """ filename - path to the schema; file "idea.schema" next to
                       this script is used if not given
        """
        self.path = filename or path.join(path.dirname(__file__), "idea.schema")
        with open(self.path) as f:
            self.schema = json.load(f)
        self.validator = Draft4Validator(self.schema, format_checker=FormatChecker())


    def __str__(self):
        return "%s(filename=\"%s\")" % (type(self).__name__, self.path)


    def check(self, event):
        """ Return list of textual validation errors for event, empty
            list if event is valid.
        """

        def sortkey(k):
            """ Treat keys as lowercase, prefer keys with less path segments """
            # Join path segments themselves, not characters of textual
            # form of the whole path container
            return (len(k.path), "/".join(str(v) for v in k.path).lower())

        res = []
        for error in sorted(self.validator.iter_errors(event), key=sortkey):
            res.append(
                "Validation error: key \"%s\", value \"%s\", expected - %s, error message - %s\n" % (
                    u"/".join(str(v) for v in error.path),
                    error.instance,
                    error.schema.get('description', 'no additional info'),
                    error.message))

        return res



class Database(Object):
    """ Placeholder database, which generates random events on fetch
        and throws events away on store.
    """
    #FIXME: here database model will dictate methods, which other
    #       objects will use. This is only dull example.

    def __init__(self):
        # Will accept db configuration parameters, initialize connection, etc.
        pass

    def gen_random_idea(self):
        """ Return minimal test event with random ID and current time """

        def get_precise_timestamp():
            t = time()
            us = trunc((t-trunc(t))*1000000)
            g = gmtime(t)
            # Microseconds are fraction of second, so they must be zero
            # padded to six digits - 5000 us is .005000, not .5000
            iso = '%04d-%02d-%02dT%02d:%02d:%02d.%06dZ' % (g[0:6]+(us,))
            return iso

        return {
           "Format": "IDEA0",
           "ID": str(uuid4()),
           "DetectTime": get_precise_timestamp(),
           "Category": ["Test"],
        }


    def fetch_events(self, client, id, count,
            cat=None, nocat=None,
            tag=None, notag=None,
            group=None, nogroup=None):
        """ Return dict with count random events under "events" and id
            (None counts as 0) advanced by count under "lastid". Other
            arguments are ignored by this implementation.
        """
        return {
            "lastid": (id or 0)+count,
            "events": [self.gen_random_idea() for _ in range(count)]
        }


    def store_events(self, client, events):
        """ Store nothing, return empty list of errors """
        errs = []   # See sendEvents and validation, should return something similar
        return errs



def expose(meth):
    """ Decorator tagging meth with "exposed" attribute set to True;
        the very same object is returned.
    """
    setattr(meth, "exposed", True)
    return meth


class Server(Object):
    """ WSGI application, which takes request path as name of exposed
        handler method, calls it with arguments from query string and
        request body, and replies with JSON serialized result or error.
    """

    def __init__(self, auth, handler):
        """ auth    - object providing authenticate(env) and
                      authorize(env, client, method, args)
            handler - object, whose exposed methods get called
        """
        self.auth = auth
        self.handler = handler


    def __str__(self):
        return "%s(auth=%s, handler=%s)" % (type(self).__name__, type(self.auth).__name__, type(self.handler).__name__)


    def sanitize_args(self, path, func, args, exclude=("self", "_env", "_client")):
        """ Remove arguments func must not or can not get.

            path    - name of called method, used for logging only
            func    - callable about to be called
            args    - dict of arguments, cleaned up in place
            exclude - names of internal arguments, which must not come
                      from outside

            Returns the very same args dict.
        """
        # silently remove internal args, these should never be used
        # but if somebody does, we do not expose them by error message
        intargs = set(args).intersection(exclude)
        for a in intargs:
            del args[a]
        if intargs:
            logging.info("%s called with internal args: %s" % (path, ", ".join(intargs)))

        # silently remove surplus arguments - potential forward
        # compatibility (unknown args will get ignored)
        badargs = set(args)-set(func.func_code.co_varnames[0:func.func_code.co_argcount])
        for a in badargs:
            del args[a]
        if badargs:
            logging.info("%s called with superfluous args: %s" % (path, ", ".join(badargs)))

        return args


    def wsgi_app(self, environ, start_response, exc_info=None):
        """ WSGI entry point.

            Reads request body, looks up exposed handler method named
            by PATH_INFO, authenticates and authorizes the client, calls
            the method and serializes its result. Any failure along the
            way is turned into Error, logged and sent back as JSON with
            corresponding HTTP status.
        """
        path = environ.get("PATH_INFO", "").lstrip("/")
        output = ""
        status = "200 OK"
        headers = [('Content-type', 'application/json')]
        exception = None

        try:
            try:
                injson = environ['wsgi.input'].read()
            except Exception:
                # Not a bare except - interpreter exit and keyboard
                # interrupt must not be disguised as client errors
                raise Error("Data read error", 400, method=path, exc=sys.exc_info())

            try:
                method = getattr(self.handler, path)
                method.exposed    # dummy access to trigger AttributeError
            except Exception:
                raise Error("You've fallen of the cliff.", 404, method=path)

            client = self.auth.authenticate(environ)
            if not client:
                raise Error("I'm watching YOU.", 403, method=path)

            try:
                events = json.loads(injson) if injson else None
            except Exception:
                raise Error("Deserialization error", 400, method=path,
                    exc=sys.exc_info(), detail={"args": injson})

            args = parse_qs(environ.get('QUERY_STRING', ""))
            # parse_qs yields lists of values, only the first one is used
            for k, v in args.iteritems():
                args[k] = v[0]
            logging.debug("%s called with %s" % (path, str(args)))
            if events:
                args["events"] = events

            if not self.auth.authorize(environ, client, path, args):
                raise Error("I'm watching YOU.", 403, method=path, detail={"client": client})

            args = self.sanitize_args(path, method, args)
            result = method(_env=environ, _client=client, **args)   # call requested method

            try:
            # 'default': takes care of non JSON serializable objects,
            # which could (although shouldn't) appear in handler code
                output = json.dumps(result, default=lambda v: str(v))
            except Exception:
                raise Error("Serialization error", 500, method=path,
                    exc=sys.exc_info(), detail={"args": str(result)})

        except Error as e:
            exception = e
        except Exception:
            exception = Error("Server exception", 500, method=path, exc=sys.exc_info())

        if exception:
            status = "%d %s" % (exception.error, exception.message)
            result = exception.to_dict()
            try:
                output = json.dumps(result, default=lambda v: str(v))
            except Exception:
                # Here all bets are off, generate at least sane output
                output = '{"error": %d, "message": "%s"}' % (
                    exception.error, exception.message)

            logging.error(str(exception))
            i = exception.info_str()
            if i:
                logging.info(i)
            d = exception.debug_str()
            if d:
                logging.debug(d)

        headers.append(('Content-Length', str(len(output))))
        start_response(status, headers)
        return [output]


    __call__ = wsgi_app



class WardenHandler(Object):
    """ Implementation of methods exposed to clients """

    def __init__(self, validator, db,
            send_events_limit=100000, get_events_limit=100000,
            description=None):
        """ validator         - object with check(event) method returning
                                list of errors
            db                - object with fetch_events and store_events
                                methods
            send_events_limit - max number of events accepted in one
                                sendEvents call, falsy for no limit
            get_events_limit  - max number of events returned from one
                                getEvents call, falsy for no limit
            description       - optional text returned from getInfo
        """

        self.db = db
        self.validator = validator
        self.send_events_limit = send_events_limit
        self.get_events_limit = get_events_limit
        self.description = description


    def __str__(self):
        # Order of values must follow order of labels in the format
        return "%s(validator=%s, db=%s, send_events_limit=%s, get_events_limit=%s, description=\"%s\")" % (
            type(self).__name__, type(self.validator).__name__, type(self.db).__name__,
            self.send_events_limit, self.get_events_limit, self.description)


    @expose
    def getDebug(self, _env, _client):
        """ Return request environment as is """
        return _env


    @expose
    def getInfo(self, _env, _client):
        """ Return server version, limits and description (if set) """
        info = {
            "version": VERSION,
            "send_events_limit": self.send_events_limit,
            "get_events_limit": self.get_events_limit
        }
        if self.description:
            info["description"] = self.description
        return info


    @expose
    def getEvents(self, _env, _client, id=None, count=None,
            cat=None, nocat=None,
            tag=None, notag=None,
            group=None, nogroup=None):
        """ Return result of db.fetch_events.

            Unparsable or missing id falls back to 0, count to 1; count
            is capped by get_events_limit. Filter arguments are passed
            to the database untouched.
        """

        try:
            id = int(id)
        except (ValueError, TypeError):
            id=0

        try:
            count = int(count)
        except (ValueError, TypeError):
            count = 1

        if self.get_events_limit:
            count = min(count, self.get_events_limit)

        logging.debug("getEvents - count: %s" % count)
        res = self.db.fetch_events(_client, id, count, cat, nocat, tag, notag, group, nogroup)
        logging.info("getEvents(%d, %d, %s, %s, %s, %s, %s, %s): sending %d events" % (
            id, count, cat, nocat, tag, notag, group, nogroup, len(res["events"])))

        return res


    @expose
    def sendEvents(self, _env, _client, events=[]):
        """ Validate events and store the valid ones.

            Returns {"saved": number of stored events}. Raises Error
            (400) if events is not a list or exceeds send_events_limit,
            Error (500) with list of problems in detail if any event
            fails validation or storing.
        """
        # Default list is only read, never modified
        if not isinstance(events, list):
            raise Error("List of events expected", 400, method="sendEvents")

        # Falsy limit means "no limit", same as in getEvents
        if self.send_events_limit and len(events)>self.send_events_limit:
            raise Error("Too much events in one batch", 400, method="sendEvents",
                detail={"limit": self.send_events_limit})

        # FIXME: Maybe just croak on first bad event, save good ones so far
        # and make client deal with the rest? Would simplify server error
        # handling greatly.
        okevents = []
        valerrs = []
        for event in events:
            verrs = self.validator.check(event)
            if verrs:
                valerrs.append({"errors": verrs, "event": event})
            else:
                okevents.append(event)

        # Valid events are stored even if some others were refused
        dberrs = self.db.store_events(_client, okevents)

        if valerrs or dberrs:
            raise Error("Event storage error", 500, method="sendEvents",
                detail=valerrs+dberrs)

        logging.info("sendEvents(...): Saved %i events" % len(okevents))

        return {"saved": len(okevents)}



def read_ini(path):
    """ Read ini style config file, return dict of sections (names
        lowercased), each being dict of its options. Raises Error if
        the file can not be read.
    """
    parser = ConfigParser.RawConfigParser()
    parsed = parser.read(path)
    if not (parsed and path in parsed):
        # We don't have logging yet, hopefully this will go into webserver log
        raise Error("Unable to read config: %s" % path)
    data = {}
    for section in parser.sections():
        for option in parser.options(section):
            # Section dict comes into being with its first option
            data.setdefault(section.lower(), {})[option] = parser.get(section, option)
    return data


def read_cfg(path):
    """ Read JSON config file, in which lines starting with "#" are
        comments. Returns dict of sections, each being dict of options,
        with both section and option names lowercased.
    """
    with open(path, "r") as f:
        # Comment lines are blanked instead of dropped (and no extra
        # newlines are added), so line numbers in JSON parser error
        # messages match lines of the file
        stripcomments = "".join(
            "\n" if l.lstrip().startswith("#") else l for l in f)
    conf = json.loads(stripcomments)

    # Lowercase keys
    conf = dict((sect.lower(), dict(
        (subkey.lower(), val) for subkey, val in subsect.items())
    ) for sect, subsect in conf.items())

    return conf


def fallback_wsgi(environ, start_response, exc_info=None):
    """ Minimal WSGI application standing in for server, which failed
        to start - every request gets logged and answered with Warden
        JSON compliant 503 error.
    """
    code = 503
    text = "Server not running due to initialization error"

    logging.critical("Error(%d): %s" % (code, text))
    start_response("%d %s" % (code, text), [('Content-type', 'application/json')])
    return ['{"error": %d, "message": "%s"}' % (code, text)]


def build_server(conf):
    """ Build all server objects according to configuration.

        conf - dict of sections (lowercase names), each being dict of
               options; it is left unmodified

        Returns initialized Server object, or fallback_wsgi function if
        anything fails along the way, so the result can always be used
        as WSGI application.
    """

    # Functions for validation and conversion of config values
    def facility(name):
        return int(getattr(logging.handlers.SysLogHandler, "LOG_" + name.upper()))

    def loglevel(name):
        return int(getattr(logging, name.upper()))

    def natural(name):
        num = int(name)
        if num<1:
            raise ValueError("Not a natural number")
        return num

    def filepath(name):
        # Make paths relative to dir of this script
        return path.join(path.dirname(__file__), name)

    def objdef(name):
        return objects[name.lower()]

    obj = objdef    # Draw into local namespace for init_obj

    objects = {}    # Already initialized objects

    # List of sections and objects, configured by them
    # First object in each object list is the default one, otherwise
    # "type" keyword in section may be used to choose other
    section_def = {
        "log": ["FileLogger", "SysLogger"],
        "db": ["Database"],
        "auth": ["X509Authenticator", "NoAuthenticator"],
        "validator": ["JSONSchemaValidator", "NoValidator"],
        "handler": ["WardenHandler"],
        "server": ["Server"]
    }

    # Object parameter conversions and defaults
    param_def = {
        "FileLogger": {
            "filename": {"type": filepath, "default": path.join(path.dirname(__file__), path.splitext(path.split(__file__)[1])[0] + ".log")},
            "level": {"type": loglevel, "default": "info"},
        },
        "SysLogger": {
            "socket": {"type": filepath, "default": "/dev/log"},
            "facility": {"type": facility, "default": "daemon"},
            "level": {"type": loglevel, "default": "info"}
        },
        "NoAuthenticator": {},
        "X509Authenticator": {
            "db": {"type": obj, "default": "db"}
        },
        "NoValidator": {},
        "JSONSchemaValidator": {
            "filename": {"type": filepath, "default": path.join(path.dirname(__file__), "idea.schema")}
        },
        "Database": {},
        "WardenHandler": {
            "validator": {"type": obj, "default": "validator"},
            "db": {"type": obj, "default": "DB"},
            "send_events_limit": {"type": natural, "default": 10000},
            "get_events_limit": {"type": natural, "default": 10000},
            "description": {"type": str, "default": ""}
        },
        "Server": {
            "auth": {"type": obj, "default": "auth"},
            "handler": {"type": obj, "default": "handler"}
        }
    }

    def init_obj(sect_name):
        """ Create object configured by section sect_name, remember it
            in objects and return it. Any problem is reported as KeyError.
        """
        # Work on a copy - removing "type" from caller's dict would make
        # repeated build from the same configuration fall back to
        # default object type
        config = dict(conf.get(sect_name, {}))
        sect_name = sect_name.lower()
        sect_def = section_def[sect_name]

        try:    # Object type defined?
            objtype = config.pop("type")
        except KeyError:    # No, fetch default object type for this section
            objtype = sect_def[0]
        else:
            if not objtype in sect_def:
                raise KeyError("Unknown type %s in section %s" % (objtype, sect_name))

        params = param_def[objtype]

        # No surplus parameters? Disallow also 'obj' attributes, these are only
        # to provide default referenced section
        for name in config:
            if name not in params or params[name]["type"] is objdef:
                raise KeyError("Unknown key %s in section %s" % (name, sect_name))

        # Process parameters
        kwargs = {}
        for name, definition in params.iteritems():
            raw_val = config.get(name, definition["default"])
            try:
                val = definition["type"](raw_val)
            except Exception:
                raise KeyError("Bad value \"%s\" for %s in section %s" % (raw_val, name, sect_name))
            kwargs[name] = val

        cls = globals()[objtype]   # get class/function type
        try:
            obj = cls(**kwargs)         # run it
        except Exception as e:
            raise KeyError("Cannot initialize %s from section %s: %s" % (
                objtype, sect_name, str(e)))

        if isinstance(obj, Object):
            # Log only objects here, functions must take care of themselves
            logging.info("Initialized %s" % str(obj))

        objects[sect_name] = obj
        return obj

    # Init logging with at least simple stderr StreamLogger
    # Dunno if it's ok within wsgi, but we have no other choice, let's
    # hope it at least ends up in webserver error log
    StreamLogger()

    try:
        # Now try to init required objects; order matters, later ones
        # refer to earlier ones
        for o in ("log", "db", "auth", "validator", "handler", "server"):
            init_obj(o)
    except Exception as e:
        logging.critical(str(e))
        logging.debug("", exc_info=sys.exc_info())
        return fallback_wsgi

    logging.info("Ready to serve")

    return objects["server"]


if __name__ == "__main__":
    # FIXME: just development stuff - builds server from hardcoded
    # config file, when run as a script
    dev_conf = read_ini("warden3.cfg.wheezy-warden3")
    srv = build_server(dev_conf)