Skip to content
Snippets Groups Projects
warden_server.py 61 KiB
Newer Older
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011-2015 Cesnet z.s.p.o
# Use of this source is governed by a 3-clause BSD-style license, see LICENSE file.

from __future__ import print_function

import io
import json
import logging
import logging.handlers
import os
import re
import sys
from collections import namedtuple
from os import path
from random import randint
from traceback import format_tb

import M2Crypto.X509
import MySQLdb as my
import MySQLdb.cursors as mycursors

if sys.version_info[0] >= 3:
    import configparser as ConfigParser
    from urllib.parse import parse_qs
    unicode = str

    def get_method_params(method):
        return method.__code__.co_varnames[:method.__code__.co_argcount]

else:
    import ConfigParser
    from urlparse import parse_qs
    def get_method_params(method):
        return method.func_code.co_varnames[:method.func_code.co_argcount]


# for local version of up to date jsonschema
sys.path.append(path.join(path.dirname(__file__), "..", "lib"))

Pavel Kácha's avatar
Pavel Kácha committed
# Version string of this server module.
VERSION = '3.0-beta3'
Pavel Kácha's avatar
Pavel Kácha committed

class Error(Exception):
    """ Exception carrying one or more structured error records.

        Every record is a dict built from keyword arguments. Keys with a
        special meaning here are "error" (numeric status code; rounded to its
        x00 class by get_http_err_msg when records disagree), "message", and
        "exc" (a sys.exc_info()-style (type, value, traceback) triple). Any
        other keys are treated as free-form detail.
    """

    def __init__(self, method=None, req_id=None, errors=None, **kwargs):
        """ `method` and `req_id` identify the request, `errors` is an
            optional list of further records, and the remaining keyword
            arguments (if any) form the first record.
        """
        self.method = method
        # Must be stored: to_dict() reads it.
        self.req_id = req_id
        self.errors = [kwargs] if kwargs else []
        if errors:
            self.errors.extend(errors)

    def append(self, _events=None, **kwargs):
        """ Add one more record built from keyword arguments.

            `_events` is accepted for call compatibility and ignored here.
        """
        self.errors.append(kwargs)

    @staticmethod
    def _single_line(text):
        """ Replace CR and LF by spaces so the text stays on one line. """
        # The summary is meant for an HTTP-level message (see
        # get_http_err_msg); a stray CR is as harmful there as LF.
        return text.replace("\r", " ").replace("\n", " ")

    def get_http_err_msg(self):
        """ Summarise all records into one (code, message) pair.

            When all records share the same code/message, those are returned.
            Differing codes are reduced to the highest class rounded down to
            x00 (400 and 404 -> 400; 404 and 500 -> 500); differing messages
            give "Multiple errors". A record without "error" counts as 500,
            one without "message" as "Unknown error". With no records at all
            a generic 500 is returned.
        """
        if not self.errors:
            return 500, "There's NO self-destruction button! Ah, you've just found it..."
        first = self.errors[0]
        err = first.get("error", 500)
        msg = self._single_line(first.get("message", "Unknown error"))
        for e in self.errors[1:]:
            next_err = e.get("error", 500)
            if err != next_err:
                # errors not same, round to basic err code (400, 500)
                # and use the highest one
                err = max(err // 100, next_err // 100) * 100
            if msg != self._single_line(e.get("message", "Unknown error")):
                msg = "Multiple errors"
        return err, msg

    def __str__(self):
        return "\n".join(self.str_err(e) for e in self.errors)

    def log(self, logger, prio=logging.ERROR):
        """ Log each record: summary at `prio`, detail at INFO, traceback at DEBUG. """
        for e in self.errors:
            logger.log(prio, self.str_err(e))
            info = self.str_info(e)
            if info:
                logger.info(info)
            debug = self.str_debug(e)
            if debug:
                logger.debug(debug)

    def str_err(self, e):
        """ One-line summary: "Error(<code>) <message> [(cause was <Exc>: <value>)]". """
        out = ["Error(%s) %s " % (e.get("error", 0), e.get("message", "Unknown error"))]
        exc = e.get("exc")
        if exc:
            out.append("(cause was %s: %s)" % (exc[0].__name__, str(exc[1])))
        return "".join(out)

    def str_info(self, e):
        """ JSON dump of the record's free-form keys as "Detail: ...", or "". """
        ecopy = dict(e)    # shallow copy, the record itself stays intact
        for key in ("req_id", "method", "error", "message", "exc"):
            ecopy.pop(key, None)
        if not ecopy:
            return ""
        return "Detail: %s" % json.dumps(ecopy, default=str)

    def str_debug(self, e):
        """ Formatted traceback of the record's "exc" triple, or "" if none. """
        exc = e.get("exc")
        if not exc or not exc[2]:
            return ""
        return "".join(["Traceback:\n"] + format_tb(exc[2]))

    def to_dict(self):
        """ Plain-dict view of the error (records without their "exc" triple). """
        errlist = []
        for e in self.errors:
            ecopy = dict(e)
            # exc holds live type/value/traceback objects; presumably the dict
            # is serialised for the client, so it is left out.
            ecopy.pop("exc", None)
            errlist.append(ecopy)
        return {
            "method": self.method,
            "req_id": self.req_id,
            "errors": errlist
        }


def get_clean_root_logger(level=logging.INFO):
    """ Attempts to get logging module into clean slate state.

        Removes (and closes) every handler and removes every filter of the
        root logger, sets its level to `level` and returns it, ready for the
        caller to attach its own handler.
    """

    # We want to be able to set up at least stderr logger before any
    # configuration is read, and then later get rid of it and set up
    # whatever administrator requires.
    # However, there can exist only one root logger, but we want to get a
    # clean slate every time we initialize StreamLogger or FileLogger... which
    # is not exactly supported by logging module.
    # So, we look directly inside logger class and clean up handlers/filters
    # manually.
    logger = logging.getLogger()
    logger.setLevel(level)
    while logger.handlers:
        handler = logger.handlers[0]
        logger.removeHandler(handler)
        # Release whatever the handler holds (e.g. FileHandler's open file,
        # SysLogHandler's socket); otherwise every re-initialisation leaks it.
        handler.close()
    while logger.filters:
        logger.removeFilter(logger.filters[0])
    return logger


def StreamLogger(stream=sys.stderr, level=logging.DEBUG):
    """ Fallback handler just for setup, not meant to be used from
        configuration file because during wsgi query stdout/stderr
        is forbidden.

        Resets the root logger, makes it write to `stream` at `level`
        and returns it.
    """

    fhand = logging.StreamHandler(stream)
    fform = logging.Formatter('%(asctime)s %(filename)s[%(process)d]: (%(levelname)s) %(message)s')
    fhand.setFormatter(fform)
    logger = get_clean_root_logger(level)
    logger.addHandler(fhand)
    # Callers get the configured logger back, like FileLogger/SysLogger.
    return logger


class LogRequestFilter(logging.Filter):
    """ Filter class, instance of which is added to logger class to add
        info about request automatically into every logline, no matter
        how it came into existence.

        It never drops a record; it only sets `record.req_preamble`, which
        the FileLogger/SysLogger format strings reference.
    """

    def __init__(self, req):
        logging.Filter.__init__(self)
        self.req = req

    def filter(self, record):
        if self.req.env:
            record.req_preamble = "%08x/%s: " % (self.req.req_id or 0, self.req.path)
        else:
            # Outside of a request there is nothing to prefix, but the
            # attribute must exist or "%(req_preamble)s" fails to format.
            record.req_preamble = ""
        # A falsy return value would make logging discard the record.
        return True


def FileLogger(req, filename, level=logging.INFO):
    """ Reset the root logger and make it append to `filename` at `level`.

        Every line is prefixed with request info produced by
        LogRequestFilter(req). Returns the configured root logger.
    """

    fhand = logging.FileHandler(filename)
    fform = logging.Formatter('%(asctime)s %(filename)s[%(process)d]: (%(levelname)s) %(req_preamble)s%(message)s')
    fhand.setFormatter(fform)
    # The format string needs record.req_preamble, which LogRequestFilter sets.
    # Attached to the handler (not the logger) so that records propagated
    # from child loggers, which skip the root logger's own filters, get it too.
    fhand.addFilter(LogRequestFilter(req))
    logger = get_clean_root_logger(level)
    logger.addHandler(fhand)
    logger.info("Initialized FileLogger(req=%r, filename=\"%s\", level=%s)" % (req, filename, level))
    return logger


def SysLogger(req, socket="/dev/log", facility=logging.handlers.SysLogHandler.LOG_DAEMON, level=logging.INFO):
    """ Reset the root logger and make it send to syslog at `socket`
        with `facility`, at `level`.

        Every line is prefixed with request info produced by
        LogRequestFilter(req). Returns the configured root logger.
    """

    fhand = logging.handlers.SysLogHandler(address=socket, facility=facility)
    fform = logging.Formatter('%(filename)s[%(process)d]: (%(levelname)s) %(req_preamble)s%(message)s')
    fhand.setFormatter(fform)
    # The format string needs record.req_preamble, which LogRequestFilter sets.
    # Attached to the handler so propagated child-logger records get it too.
    fhand.addFilter(LogRequestFilter(req))
    logger = get_clean_root_logger(level)
    logger.addHandler(fhand)
    logger.info("Initialized SysLogger(req=%r, socket=\"%s\", facility=\"%d\", level=%s)" % (req, socket, facility, level))
    return logger
Pavel Kácha's avatar
Pavel Kácha committed
# Immutable record describing one client; field order matters for anything
# constructing it positionally (presumably from DB rows — TODO confirm).
Client = namedtuple(
    "Client",
    "id registered requestor hostname name secret valid read debug write test note")
        # NOTE(review): the class and `def` lines that should enclose these
        # statements are missing from this view (likely lost in extraction) —
        # TODO restore from the full file. As written, they build a
        # "TypeName(param=value, ...)" string from the parameter names of
        # self.__init__.
        # [1:] skips the leading `self` parameter name.
        attrs = get_method_params(self.__init__)[1:]
Pavel Kácha's avatar
Pavel Kácha committed
        # getattr(..., None) tolerates parameters that are not stored as attributes.
        eq_str = ["%s=%r" % (attr, getattr(self, attr, None)) for attr in attrs]
        return "%s(%s)" % (type(self).__name__, ", ".join(eq_str))
class Request(Object):
    """ Plain holder for information about the request being served.

        A single instance is built before the server starts and is handed
        to every configured object at construction time. The server fills
        it in when a wsgi request begins and clears it when the request
        ends, so other objects always find the current request in their
        own self.req.

        Only Server.wsgi_app, the exposed handler (WardenHandler) methods
        and the logging machinery are supposed to read self.req directly.
        Everything else should treat self.req purely as a source of context
        for error/exception reporting and logging, and should receive and
        return its working data through arguments and return values, so it
        stays obvious which data the main code paths operate on.
    """

    def reset(self, env=None, client=None, path=None, req_id=None):
        """ Fill in (or, when called without arguments, clear) the request info. """
        self.env = env
        self.client = client
        self.path = path if path else ""
        if req_id is None:
            # No explicit id: no wsgi environment means 0, otherwise a
            # random 32-bit id.
            req_id = 0 if env is None else randint(0x00000000, 0xFFFFFFFF)
        self.req_id = req_id

    # Constructing and resetting share the same code path.
    __init__ = reset

    def error(self, **kwargs):
        """ Build an Error tagged with this request's path and id. """
        return Error(method=self.path, req_id=self.req_id, **kwargs)
class PlainAuthenticator(ObjectBase):
    """ Authenticates clients by the "client" name and optional "secret"
        query arguments, looked up through the `db` object, and authorizes
        them against per-method permission flags.
    """

    def __init__(self, req, log, db):
        ObjectBase.__init__(self, req, log)
        self.db = db

    def authenticate(self, env, args, hostnames=None, check_secret=True):
        """ Look up the client named by args["client"][0].

            `args` maps argument names to lists of values (parse_qs style).
            `hostnames` (e.g. names from a certificate) is passed through to
            the db lookup. With check_secret=False the secret is neither used
            nor required.

            Returns the client record on success, None on failure (the
            reason is logged). On success "client" and "secret" are removed
            from `args`.
        """
        name = args.get("client", [None])[0]
        secret = args.get("secret", [None])[0] if check_secret else None

        client = self.db.get_client_by_name(hostnames, name, secret)

        if not client:
            # Never write the (possibly mistyped) secret itself into logs.
            self.log.info("authenticate: client not found by name: \"%s\", secret: %s, hostnames: %s" % (
                name, "***" if secret is not None else None, str(hostnames)))
            return None

        # Clients with 'secret' set must get authenticated by it.
        # No secret turns secret auth off for this particular client.
        if client.secret is not None and secret is None and check_secret:
            self.log.info("authenticate: missing secret argument")
            return None

        # NOTE(review): str(client) may include the client's secret field —
        # consider masking it before logging.
        self.log.info("authenticate: %s" % str(client))

        # These args are not for handler
        args.pop("client", None)
        args.pop("secret", None)

        return client

    def authorize(self, env, client, path, method):
        """ Check that `client` has every permission `method` requires.

            method.debug needs client.debug, method.read needs client.read
            and method.write needs client.write or client.test. Returns
            `client` when all required permissions are present, None
            otherwise (the reason is logged). `env` and `path` are unused here.
        """
        if method.debug and not client.debug:
            self.log.info("authorize: failed, client does not have debug enabled")
            return None
        if method.read and not client.read:
            self.log.info("authorize: failed, client does not have read enabled")
            return None
        if method.write and not (client.write or client.test):
            self.log.info("authorize: failed, client is not allowed to write or test")
            return None
        return client
class X509Authenticator(PlainAuthenticator):
    """ Authenticates clients by the DNS names in their client certificate,
        provided the web server (Apache) has verified that certificate.
    """

    def get_cert_dns_names(self, pem):
        """ Return DNS names found in the PEM certificate `pem`.

            The first subject CN comes first, followed (deduplicated, in no
            particular order) by the remaining CNs and the "DNS:" entries of
            subjectAltName. Raises if the certificate cannot be parsed or has
            no CN (IndexError); callers treat any exception as failure.
        """
        cert = M2Crypto.X509.load_cert_string(pem)

        subj = cert.get_subject()
        commons = [n.get_data().as_text() for n in subj.get_entries_by_nid(subj.nid["CN"])]

        try:
            extstrs = cert.get_ext("subjectAltName").get_value().split(",")
        except LookupError:
            # Certificate carries no subjectAltName extension.
            extstrs = []
        extstrs = [val.strip() for val in extstrs]
        altnames = [val[4:] for val in extstrs if val.startswith("DNS:")]

        # bit of mangling to get rid of duplicates and leave commonname first
        firstcommon = commons[0]
        return [firstcommon] + list(set(altnames + commons) - set([firstcommon]))

    def is_verified_by_apache(self, env, args):
        """ True if env["SSL_CLIENT_VERIFY"] is "SUCCESS"; otherwise log a 403 and return False. """
        # Allows correct work while SSLVerifyClient both "optional" and "required"
        verify = env.get("SSL_CLIENT_VERIFY")
        if verify == "SUCCESS":
            return True
        exception = self.req.error(
            message="authenticate: certificate verification failed",
            error=403, args=args, ssl_client_verify=verify, cert=env.get("SSL_CLIENT_CERT"))
        exception.log(self.log)
        return False

    def authenticate(self, env, args):
        """ Authenticate by certificate hostnames; client record or None. """
        if not self.is_verified_by_apache(env, args):
            return None

        try:
            cert_names = self.get_cert_dns_names(env["SSL_CLIENT_CERT"])
        except Exception:
            # Missing cert in env, unparsable PEM, no CN, ...: reject, but do
            # not swallow KeyboardInterrupt/SystemExit like a bare except did.
            exception = self.req.error(
                message="authenticate: cannot get or parse certificate from env",
                error=403, exc=sys.exc_info(), env=env)
            exception.log(self.log)
            return None

        return PlainAuthenticator.authenticate(self, env, args, hostnames=cert_names)

class X509NameAuthenticator(X509Authenticator):
    """ Authenticates clients whose client name equals the CN of their
        Apache-verified certificate; no secret is required.
    """

    def authenticate(self, env, args):
        """ Client record on success, None on failure (reason is logged). """
        if not self.is_verified_by_apache(env, args):
            return None

        try:
            cert_name = env["SSL_CLIENT_S_DN_CN"]
        except KeyError:
            exception = self.req.error(
                message="authenticate: cannot get or parse certificate from env",
                error=403, exc=sys.exc_info(), env=env)
            exception.log(self.log)
            # Without a CN there is nothing to authenticate against.
            return None

        # A missing "client" argument defaults to the certificate CN.
        if cert_name != args.setdefault("client", [cert_name])[0]:
            exception = self.req.error(
                message="authenticate: client name does not correspond with certificate",
                error=403, cn=cert_name, args=args)
            exception.log(self.log)
            # Must reject: continuing would authenticate a name the
            # certificate does not vouch for.
            return None

        return PlainAuthenticator.authenticate(self, env, args, check_secret=False)
class X509MixMatchAuthenticator(X509Authenticator):
    """ Chooses, per request, between name-based (X509NameAuthenticator)
        and hostname-based (X509Authenticator) certificate authentication.
    """

    def __init__(self, req, log, db):
        PlainAuthenticator.__init__(self, req, log, db)
        self.hostname_auth = X509Authenticator(req, log, db)
        self.name_auth = X509NameAuthenticator(req, log, db)

    def authenticate(self, env, args):
        """ Delegate to the chosen authenticator; client record or None. """
        if not self.is_verified_by_apache(env, args):
            return None

        try:
            cert_name = env["SSL_CLIENT_S_DN_CN"]
        except KeyError:
            exception = self.req.error(
                message="authenticate: cannot get or parse certificate from env",
                error=403, exc=sys.exc_info(), env=env)
            exception.log(self.log)
            return None

        name = args.get("client", [None])[0]
        secret = args.get("secret", [None])[0]
        # Client names are in reverse notation than DNS, client name should
        # thus never be the same as machine hostname (if it is, client
        # admin does something very amiss).
        # So, if client sends the same name in query as in the certificate,
        # or sends no name or secret (which is necessary for hostname auth),
        # use X509NameAuthenticator. Otherwise (names are different and there
        # is name and/or secret in query) use (hostname) X509Authenticator.
        if name == cert_name or (name is None and secret is None):
            auth = self.name_auth
        else:
            auth = self.hostname_auth
        self.log.info("MixMatch is choosing %s (name: %s, cert_name: %s)" % (type(auth).__name__, name, cert_name))
        return auth.authenticate(env, args)


# NOTE(review): the header of this class was missing from the source view;
# name and base are reconstructed from JSONSchemaValidator(NoValidator) and
# the ObjectBase.__init__ call below — confirm against the full file.
class NoValidator(ObjectBase):
    """ Validator that accepts every event. """

    def __init__(self, req, log):
        ObjectBase.__init__(self, req, log)

    def check(self, event):
        # Empty list: no validation problems found (JSONSchemaValidator
        # returns its collected errors the same way).
        return []


class JSONSchemaValidator(NoValidator):

    def __init__(self, req, log, filename=None):
        """ Load the JSON schema from `filename` (default: the idea.schema
            file next to this module) and prepare a Draft 4 validator for it.
        """
        NoValidator.__init__(self, req, log)
        if filename:
            self.path = filename
        else:
            self.path = path.join(path.dirname(__file__), "idea.schema")
        with io.open(self.path, mode="r", encoding="utf-8") as schema_file:
            self.schema = json.load(schema_file)
        self.validator = Draft4Validator(self.schema)

    def check(self, event):

        def sortkey(k):
            """ Treat keys as lowercase, prefer keys with less path segments """
            return (len(k.path), "/".join(str(k.path)).lower())

        res = []
        for error in sorted(self.validator.iter_errors(event), key=sortkey):
Loading
Loading full blame...