Skip to content
Snippets Groups Projects
warden_server.py 59.4 KiB
Newer Older
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011-2015 Cesnet z.s.p.o
# Use of this source is governed by a 3-clause BSD-style license, see LICENSE file.

import sys
import os
import logging
import logging.handlers
import ConfigParser
from traceback import format_tb
import M2Crypto.X509
import json
Pavel Kácha's avatar
Pavel Kácha committed
import MySQLdb as my
import MySQLdb.cursors as mycursors
import re
import email.utils
from urlparse import parse_qs
from os import path

# A bundled, up-to-date copy of jsonschema lives in ../lib relative to this
# file. The directory is appended, so it is searched only after every entry
# already on sys.path (a jsonschema found earlier there would still win).
sys.path.append(path.join(path.dirname(__file__), os.pardir, "lib"))

Pavel Kácha's avatar
Pavel Kácha committed
# Version string of this server.
VERSION = "3.0-beta2"

class Error(Exception):
    """ Exception carrying one or more structured error descriptions.

        Every error is a dict built from keyword arguments. Keys with a
        special meaning: "error" (numeric code, used as HTTP status by
        get_http_err_msg), "message" (human readable text) and "exc"
        (a sys.exc_info() triple). Any remaining keys are reported as
        details by str_info().
    """

    def __init__(self, method=None, req_id=None, errors=None, **kwargs):
        self.method = method
        # Kept so that to_dict() can report it back to the client.
        self.req_id = req_id
        self.errors = [kwargs] if kwargs else []
        if errors:
            self.errors.extend(errors)


    def append(self, _events=None, **kwargs):
        """ Add one more error description, given as keyword arguments. """
        self.errors.append(kwargs)


    def get_http_err_msg(self):
        """ Fold all errors into a single (code, message) pair.

            Equal codes/messages pass through unchanged. Differing codes
            are rounded down to their hundred (400, 500, ...) and the
            highest one wins; differing messages become "Multiple errors".
            Newlines in messages are replaced by spaces.
        """
        if not self.errors:
            return 500, "There's NO self-destruction button! Ah, you've just found it..."
        # Apply the same defaults to the first error as to the others, so
        # that a lone error without "message" (or "error") is not reported
        # as "Multiple errors" with its code overridden to 500.
        first = self.errors[0]
        err = first.get("error", 500)
        msg = first.get("message", "Unknown error").replace("\n", " ")
        for e in self.errors[1:]:
            next_err = e.get("error", 500)
            if err != next_err:
                # errors not same, round to basic err code (400, 500)
                # and use the highest one
                err = max(err//100, next_err//100)*100
            next_msg = e.get("message", "Unknown error").replace("\n", " ")
            if msg != next_msg:
                msg = "Multiple errors"
        return err, msg


    def __str__(self):
        return "\n".join(self.str_err(e) for e in self.errors)


    def log(self, logger, prio=logging.ERROR):
        """ Log every error summary at `prio`; details go to INFO and
            tracebacks to DEBUG.
        """
        for e in self.errors:
            logger.log(prio, self.str_err(e))
            info = self.str_info(e)
            if info:
                logger.info(info)
            debug = self.str_debug(e)
            if debug:
                logger.debug(debug)


    def str_err(self, e):
        """ One-line summary: code, message and causing exception, if any. """
        out = []
        out.append("Error(%s) %s " % (e.get("error", 0), e.get("message", "Unknown error")))
        if "exc" in e and e["exc"]:
            out.append("(cause was %s: %s)" % (e["exc"][0].__name__, str(e["exc"][1])))
        return "".join(out)


    def str_info(self, e):
        """ JSON dump of the keys not covered by str_err(), or "" if none. """
        ecopy = dict(e)    # shallow copy
        ecopy.pop("req_id", None)
        ecopy.pop("method", None)
        ecopy.pop("error", None)
        ecopy.pop("message", None)
        ecopy.pop("exc", None)
        if ecopy:
            out = "Detail: %s" % (json.dumps(ecopy, default=lambda v: str(v)))
        else:
            out = ""
        return out


    def str_debug(self, e):
        """ Formatted traceback of the causing exception, or "" if none. """
        out = []
        if not "exc" in e or not e["exc"]:
            return ""
        exc_tb = e["exc"][2]
        if exc_tb:
            out.append("Traceback:\n")
            out.extend(format_tb(exc_tb))
        return "".join(out)


    def to_dict(self):
        """ Serializable form: method, req_id and errors without "exc". """
        errlist = []
        for e in self.errors:
            ecopy = dict(e)
            ecopy.pop("exc", None)
            errlist.append(ecopy)
        d = {
            "method": self.method,
            "req_id": self.req_id,
            "errors": errlist
        }
        return d



def get_clean_root_logger(level=logging.INFO):
    """ Attempts to get logging module into clean slate state

        Returns the root logger set to `level`, with every previously
        attached filter removed and every handler removed and closed.
    """

    # We want to be able to set up at least stderr logger before any
    # configuration is read, and then later get rid of it and set up
    # whatever administrator requires.
    # However, there can exist only one logger, but we want to get a clean
    # slate everytime we initialize StreamLogger or FileLogger... which
    # is not exactly supported by logging module.
    # So, we look directly inside logger class and clean up handlers/filters
    # manually.
    logger = logging.getLogger()
    logger.setLevel(level)
    while logger.handlers:
        handler = logger.handlers[0]
        logger.removeHandler(handler)
        # The handler is dropped for good here; close it so that file or
        # socket based handlers release their resources.
        handler.close()
    while logger.filters:
        logger.removeFilter(logger.filters[0])
    return logger


def StreamLogger(stream=sys.stderr, level=logging.DEBUG):
    """ Fallback handler just for setup, not meant to be used from
        configuration file because during wsgi query stdout/stderr
        is forbidden.

        Returns the logger from get_clean_root_logger(), now writing
        only to `stream`.
    """

    fhand = logging.StreamHandler(stream)
    fform = logging.Formatter('%(asctime)s %(filename)s[%(process)d]: (%(levelname)s) %(message)s')
    fhand.setFormatter(fform)
    logger = get_clean_root_logger(level)
    logger.addHandler(fhand)
    return logger


class LogRequestFilter(logging.Filter):
    """ Filter class, instance of which is added to logger class to add
        info about request automatically into every logline, no matter
        how it came into existence.

        It never drops records; it only sets record.req_preamble, which
        the FileLogger/SysLogger formats reference. The attribute is
        therefore set on every record (empty outside of a request).
    """

    def __init__(self, req):
        logging.Filter.__init__(self)
        self.req = req


    def filter(self, record):
        if self.req.env:
            # request id as fixed-width hex, followed by the request path
            record.req_preamble = "%08x/%s: " % (self.req.req_id or 0, self.req.path)
        else:
            # No request in progress; still define the attribute so that
            # formats using %(req_preamble)s do not fail.
            record.req_preamble = ""
        # A falsy return value would make logging discard the record.
        return True


def FileLogger(req, filename, level=logging.INFO):
    """ Reset logging (via get_clean_root_logger) to append into `filename`.

        Each line carries the request preamble produced by a
        LogRequestFilter bound to the shared `req` container.
        Returns the configured logger.
    """

    fhand = logging.FileHandler(filename)
    fform = logging.Formatter('%(asctime)s %(filename)s[%(process)d]: (%(levelname)s) %(req_preamble)s%(message)s')
    fhand.setFormatter(fform)
    # The format above needs record.req_preamble, which only this filter
    # sets. It is attached to the handler (not the logger) so that records
    # propagated from descendant loggers get it too.
    fhand.addFilter(LogRequestFilter(req))

    logger = get_clean_root_logger(level)
    logger.addHandler(fhand)
    logger.info("Initialized FileLogger(req=%s, filename=\"%s\", level=\"%d\")" % (type(req).__name__, filename, level))
    return logger


def SysLogger(req, socket="/dev/log", facility=logging.handlers.SysLogHandler.LOG_DAEMON, level=logging.INFO):
    """ Reset logging (via get_clean_root_logger) to send to syslog at
        `socket` with the given `facility`.

        Each line carries the request preamble produced by a
        LogRequestFilter bound to the shared `req` container.
        Returns the configured logger.
    """

    fhand = logging.handlers.SysLogHandler(address=socket, facility=facility)
    fform = logging.Formatter('%(filename)s[%(process)d]: (%(levelname)s) %(req_preamble)s%(message)s')
    fhand.setFormatter(fform)
    # The format above needs record.req_preamble, which only this filter
    # sets. It is attached to the handler (not the logger) so that records
    # propagated from descendant loggers get it too.
    fhand.addFilter(LogRequestFilter(req))

    logger = get_clean_root_logger(level)
    logger.addHandler(fhand)
    logger.info("Initialized SysLogger(req=%s, socket=\"%s\", facility=\"%d\", level=\"%d\")" % (type(req).__name__, socket, facility, level))
    return logger
    # NOTE(review): the start of this statement is not visible here (the
    # second line closes a ")" that is never opened). These look like the
    # field names of the client record returned by db.get_client_by_name;
    # the authenticators below read .secret, .debug, .read, .write and
    # .test. Confirm against the full file.
    ["id", "registered", "requestor", "hostname", "name",
    "secret", "valid", "read", "debug", "write", "test", "note"])
class Object(object):
    """ Common root of the server classes; renders as "ClassName()". """

    def __str__(self):
        cls_name = type(self).__name__
        return "{0}()".format(cls_name)
class Request(Object):
    """ Simple container for info about ongoing request.
        One instance gets created before server startup, and all other
        configured objects get it as parameter during instantiation.

        Server then takes care of populating this instance on the start
        of wsgi request (and resetting at the end). All other objects
        then can find this actual request info in their own self.req.

        However, only Server.wsgi_app, handler (WardenHandler) exposed
        methods and logging related objects should use self.req directly.
        All other objects should use self.req only as source of data for
        error/exception handling/logging, and should take/return
        necessary data as arguments/return values for clarity on
        which data their main codepaths work with.
    """

    def __init__(self):
        Object.__init__(self)
        self.reset()


    def __str__(self):
        return "%s(env=%s, client=%s)" % (type(self).__name__, str(self.env), str(self.client))


    def reset(self, env=None, client=None, path=None, req_id=None):
        """ Populate the container for a new request, or clear it when
            called without arguments.

            Unless req_id is given, a random 32-bit id is drawn when env
            is set, and 0 is used otherwise.
        """
        # Local import so that randint is in scope regardless of the
        # module-level imports.
        from random import randint

        self.env = env
        self.client = client
        self.path = path or ""
        if req_id is not None:
            self.req_id = req_id
        else:
            self.req_id = 0 if env is None else randint(0x00000000, 0xFFFFFFFF)


    def error(self, **kwargs):
        """ Build an Error tagged with the current request path and id. """
        return Error(self.path, self.req_id, **kwargs)



class ObjectBase(Object):
    """ Base for server components sharing the Request container
        (self.req) and a logger (self.log).
    """

    def __init__(self, req, log):
        Object.__init__(self)
        self.req = req
        self.log = log


    def __str__(self):
        return "%s(req=%s)" % (type(self).__name__, type(self.req).__name__)


class PlainAuthenticator(ObjectBase):
    """ Authenticates clients by name (and optionally secret) against the
        database, and authorizes them for individual handler methods.
    """

    def __init__(self, req, log, db):
        ObjectBase.__init__(self, req, log)
        self.db = db


    def __str__(self):
        return "%s(req=%s, db=%s)" % (type(self).__name__, type(self.req).__name__, type(self.db).__name__)


    def authenticate(self, env, args, hostnames = None, check_secret = True):
        """ Look up the client named by args["client"].

            args -- request arguments mapping names to lists of values
                (presumably parse_qs output; confirm with caller). On
                success, the "client" and "secret" entries are removed.
            hostnames -- passed through to db.get_client_by_name, e.g. the
                certificate names supplied by X509Authenticator.
            check_secret -- when False, the secret is neither read from
                args nor required.

            Returns the client record, or None when authentication fails.
        """
        name = args.get("client", [None])[0]
        secret = args.get("secret", [None])[0] if check_secret else None

        client = self.db.get_client_by_name(hostnames, name, secret)

        if not client:
            # Never write the secret itself into the log.
            self.log.info("authenticate: client not found by name: \"%s\", secret: %s, hostnames: %s" % (
                name, "<hidden>" if secret is not None else None, str(hostnames)))
            return None

        # Clients with 'secret' set must get authenticated by it.
        # No secret turns secret auth off for this particular client.
        if client.secret is not None and secret is None and check_secret:
            self.log.info("authenticate: missing secret argument")
            return None

        # NOTE(review): if the client record carries its stored secret,
        # str(client) logs it; consider redacting.
        self.log.info("authenticate: %s" % str(client))

        # These args are not for handler
        args.pop("client", None)
        args.pop("secret", None)

        return client


    def authorize(self, env, client, path, method):
        """ Check that `client` has every capability `method` requires
            (method.debug, method.read, method.write).

            Write access is granted to clients having either write or test
            enabled. Returns client when allowed, None otherwise.
        """
        if method.debug and not client.debug:
            self.log.info("authorize: failed, client does not have debug enabled")
            return None
        if method.read and not client.read:
            self.log.info("authorize: failed, client does not have read enabled")
            return None
        if method.write and not (client.write or client.test):
            self.log.info("authorize: failed, client is not allowed to write or test")
            return None
        return client
class X509Authenticator(PlainAuthenticator):
    """ Authenticator taking the client hostnames from the PEM certificate
        in env["SSL_CLIENT_CERT"] and then deferring to PlainAuthenticator.
    """

    def get_cert_dns_names(self, pem):
        """ Return the DNS names of the PEM certificate.

            The first subject CN comes first. It is followed by the
            remaining CNs and subjectAltName "DNS:" entries, without
            duplicates and in unspecified order. Raises IndexError when the
            subject has no CN.
        """

        cert = M2Crypto.X509.load_cert_string(pem)

        subj = cert.get_subject()
        commons = [n.get_data().as_text() for n in subj.get_entries_by_nid(subj.nid["CN"])]

        try:
            extstrs = cert.get_ext("subjectAltName").get_value().split(",")
        except LookupError:
            # no subjectAltName extension present
            extstrs = []
        extstrs = [val.strip() for val in extstrs]
        altnames = [val[4:] for val in extstrs if val.startswith("DNS:")]

        # bit of mangling to get rid of duplicates and leave commonname first
        firstcommon = commons[0]
        return [firstcommon] + list(set(altnames+commons) - set([firstcommon]))


    def authenticate(self, env, args):
        """ Return the authenticated client, or None.

            A missing or unparsable certificate is logged as a 403 error.
        """
        try:
            cert_names = self.get_cert_dns_names(env["SSL_CLIENT_CERT"])
        except Exception:
            exception = self.req.error(message="authenticate: cannot get or parse certificate from env", error=403, exc=sys.exc_info(), env=env)
            exception.log(self.log)
            return None
        return PlainAuthenticator.authenticate(self, env, args, hostnames = cert_names)
class X509NameAuthenticator(PlainAuthenticator):

    def authenticate(self, env, args):
        try:
            cert_name = env["SSL_CLIENT_S_DN_CN"]
        except:
            exception = self.req.error(message="authenticate: cannot get or parse certificate from env", error=403, exc=sys.exc_info(), env=env)
Pavel Kácha's avatar
Loading
Loading full blame...