#!/usr/bin/python
"""Warden3 Server Test Suite"""

from __future__ import print_function
import argparse
import getpass
import sys
import warnings
from os import path
from warden_server import build_server
import warden_server

if sys.version_info >= (3, 10):
    import unittest
else:
    import unittest2 as unittest
if sys.version_info[0] >= 3:
    from io import StringIO
else:
    from StringIO import StringIO

# Database account and database name used for the test server configuration
# and as the default arguments of the DBMS helper classes in this file.
USER = 'warden3test'
PASSWORD = 'h7w*D>4B)3omcvLM$oJp'
DB = 'w3test'

def setUpModule():      # pylint: disable = locally-disabled, invalid-name
    """Initialize the test database

    Prints the suite banner (the module docstring) and then asks the DBMS
    helper to build the test database.
    """
    print(__doc__)
    # NOTE(review): the body used to stop after the banner although the
    # docstring promises database initialisation. DBMS is presumably the
    # module-level helper instance (e.g. MySQL) selected elsewhere in this
    # file - confirm.
    DBMS.set_up()


# Checked by tearDownModule: the test database is purged only while this is False.
NO_PURGE = False

def tearDownModule():  # pylint: disable = locally-disabled, invalid-name
    """Clean up by purging the test database

    Nothing is purged when the module-level NO_PURGE flag is set.
    """
    if not NO_PURGE:
        # NOTE(review): the body of this branch was missing (a syntax error).
        # DBMS is presumably the module-level helper instance selected
        # elsewhere in this file - confirm.
        DBMS.tear_down()


class ReadableSTR(str):
    """Mission: To boldly quack like a buffer, like no str has quacked before

    A str that additionally offers ``read`` (and, where str lacks it,
    ``decode``), so it can stand in for the ``wsgi.input`` stream.
    """
    def read(self, content_length=None):    # pylint: disable = locally-disabled, unused-argument
        """Return own content

        ``content_length`` is accepted for compatibility with the file API and
        is ignored: the whole content is always returned. The result is again
        a ReadableSTR, so ``decode`` can be called on it.
        """
        return ReadableSTR(self)

    if getattr(str, 'decode', None) is None:
        def decode(self, encoding="UTF-8", errors="strict"):    # pylint: disable = locally-disabled, unused-argument
            """For Py3 return own content, no decoding necessary"""
            return self


class Request(object):
    """Abstraction layer to perform an WSGI request

    Creating an instance calls ``app`` once with an environ built by
    ``get_environ`` and records the outcome; calling the instance afterwards
    returns the recorded ``(status, headers, out)`` triple.
    """
    def __init__(self, app, uri, payload=""):
        """Perform the request against ``app``

        :param app: WSGI application, invoked as ``app(environ, start_response)``
        :param uri: request URI, optionally followed by ``?query``
        :param payload: request body, handed over as ``wsgi.input``
        """
        env = self.get_environ(uri, payload)
        self.status = None
        self.headers = None
        raw_out = app(env, self.start_response)
        # Every item returned by the application must offer decode(); the
        # response is kept as a list of decoded chunks.
        self.out = [item.decode('ascii') for item in raw_out]

    def __call__(self):
        """Return the recorded ``(status, headers, out)`` triple"""
        return self.status, self.headers, self.out

    @staticmethod
    def get_environ(uri, payload):
        """Prepares an (partial) environ for WSGI app, almost like an WSGI server would

        Only the first ``?`` separates the path from the query string, so a
        query that itself contains ``?`` is passed on intact instead of being
        dropped. PATH_INFO is the last path segment with a leading slash.
        """
        full_path, _, query_string = uri.partition('?')
        path_info = '/' + full_path.split('/')[-1]
        env = {
            "REQUEST_URI": uri,
            "PATH_INFO": path_info,
            "QUERY_STRING": query_string,
            "SSL_CLIENT_VERIFY": "SUCCESS",
            "SSL_CLIENT_S_DN_CN": "cz.cesnet.warden3test",
            "SSL_CLIENT_CERT": "-----BEGIN CERTIFICATE-----\nMIIDgDCCAmgCCQDEG431XDXZjDANBgkqhkiG9w0BAQsFADCBgTELMAkGA1UEBhMCQ1oxFzAVBgNVBAoMDkNFU05FVCwgYS5sLmUuMQwwCgYDVQQLDAM3MDkxJTAjBgNVBAMMHHRlc3Quc2VydmVyLndhcmRlbi5jZXNuZXQuY3oxJDAiBgkqhkiG9w0BCQEWFXdhcmRlbi1pbmZvQGNlc25ldC5jejAeFw0xODA3MjMxMzMyMjFaFw0xODA4MjIxMzMyMjFaMIGBMQswCQYDVQQGEwJDWjEXMBUGA1UECgwOQ0VTTkVULCBhLmwuZS4xDDAKBgNVBAsMAzcwOTElMCMGA1UEAwwcdGVzdC5zZXJ2ZXIud2FyZGVuLmNlc25ldC5jejEkMCIGCSqGSIb3DQEJARYVd2FyZGVuLWluZm9AY2VzbmV0LmN6MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAvgwOv1bv44hyWF7UDAPGdm+PqcbITi/6SVEfCENbMx6DAT+M3ZJlg7aOZyiZ16CRNxrjWizXYYY1H+NhOvlPZwsBcHFvnaBrcBciURMW6AQ+OiIHUONDUV7zqTcyiZ6NDMoNy472UpfNBMYXMtaUjPO33aRYwtl+QjoivU8bhzcSxyr/4P6WnZ7rW2nuHWfUNcGWGVxsRw7E2r4OY3Yr6M4SjKEDTEalByApoOYj2s3oEmeiNPjxKhN0wgD4h38+HcnpmKGZLNFbOEdT/7luA6IwzJ7l0p4ktjgCl/x3/Y6ZBrIZuFCNxjYrdfciD27LmcA5A6nEJ083fa4d+O/H8QIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQBc6EtV6FYnFBd735h4zwe2SIaFs2bu1d6COsOsaWe5loInI+oEATThaBlA9QiVamikkug3t2wgro8YcYhp0CMPN1gMxR6GstrBrKafprWp/Dv3+IP8RY+Z2lJ0ivw1MTMipqsCMiB+Lvs2wRVV3xBIXslgI3dbceZXos2bj6CPf3Frho7Z7oRaHetI+1a0T9QqZSug7dUSmYNCd9ZXQ8kFzU3eCFP0JKMqOy75KHIE00xowarDDFjTyyPoHmZviIOsY8ByKGNRDQz/WnZWzghAQjb+7tTFm2deOQua0XIyO7GSIU2xdGbTje4wA3/YiWhkpF8HWpCEAN8G6sMTDEXF\n-----END CERTIFICATE-----",      # pylint: disable = locally-disabled, line-too-long
            "wsgi.input": ReadableSTR(payload),
            "CONTENT_LENGTH": len(payload)
            }
        return env

    def start_response(self, status, headers):
        """Mocked start_response to record returned status and headers"""
        self.status = status
        self.headers = headers


class Warden3ServerTest(unittest.TestCase):
    """High level Warden3 Server tests"""

    getInfo_interface_tests_specific = [
        ("/getInfo", "403 I'm watching. Authenticate."),
        ("/getInfo?client=", "403 I'm watching. Authenticate."),
        ("/getInfo?client=cz.cesnet.warden3test", "403 I'm watching. Authenticate."),
        ("/getInfo?client=cz.cesnet.warden3test&secret=123", "403 I'm watching. Authenticate."),
        ("/getInfo?secret=123", "403 I'm watching. Authenticate."),
    ]
    getEvents_interface_tests_specific = [
        ("/getEvents", "403 I'm watching. Authenticate.", None),
        ("/getEvents?secret=123", "403 I'm watching. Authenticate.", None),
    ]

    @staticmethod
    def get_config():
        return {
            'log': {'level': 'debug'},
            'validator': {'type': 'NoValidator'},
            'auth': {'type': 'PlainAuthenticator'},
            'db': {'type': DBMS.name, 'user': USER, 'password': PASSWORD, 'dbname': DB},
            'handler': {'description': 'Warden Test Server'}
        }

    @classmethod
    def setUpClass(cls):
        """Pre-test cleanup"""
        cls.clean_lastid()
        cls.app = build_server(cls.get_config())

    @classmethod
    def clean_lastid(cls):
        """Cleans the lastid information for all clients"""

    def test_getInfo_interface(self):       # pylint: disable = locally-disabled, invalid-name
        """Tests the getInfo method invocation"""
        tests_common = [
            ("/getInfo?secret=abc", "200 OK"),
            ("/getInfo?secret=abc&evil=false", "200 OK"),       # RFC3514
            ("/getInfo?client=cz.cesnet.warden3test&secret=abc", "200 OK"),
            ("/getInfo?client=asdf.blefub", "403 I'm watching. Authenticate."),
            ("/getInfo?client=asdf.blefub&secret=abc", "403 I'm watching. Authenticate."),
            ("/getInfo?secret=abc&self=test", "200 OK"),        # Internal parameter
        ]
        for query, expected_status in tests_common + self.getInfo_interface_tests_specific:
            with self.subTest(query=query, expected_status=expected_status):
                status, _, _ = Request(self.app, query)()
                self.assertEqual(status, expected_status)

    def test_getEvents_interface(self):     # pylint: disable = locally-disabled, invalid-name
        """Tests the getEvents method invocation"""
        tests_common = [
            ("/getEvents?secret=abc", "200 OK", ['{"lastid": 1, "events": []}']),
            ("/getEvents?client=foo", "403 I'm watching. Authenticate.", None),
            ("/getEvents?secret=abc&foo=bar", "200 OK", ['{"lastid": 1, "events": []}']),
            ("/getEvents?secret=abc&lastid=1", "200 OK", ['{"lastid": 1, "events": []}']),
            ("/getEvents?secret=abc&lastid=0", "200 OK", ['{"lastid": 1, "events": []}']),
            ("/getEvents?secret=abc&lastid=9", "200 OK", ['{"lastid": 1, "events": []}']),
            ("/getEvents?secret=abc&cat=bflm", "422 Wrong tag or category used in query.", None),
            ("/getEvents?secret=abc&cat=Other", "200 OK", None),
            ("/getEvents?secret=abc&tag=Other", "200 OK", None),
            ("/getEvents?secret=abc&group=Other", "200 OK", None),
            ("/getEvents?secret=abc&cat=Other&nocat=Test", "422 Unrealizable conditions. Choose cat or nocat option.", None),
            ("/getEvents?secret=abc&tag=Other&notag=Test", "422 Unrealizable conditions. Choose tag or notag option.", None),
            ("/getEvents?secret=abc&group=Other&nogroup=Test", "422 Unrealizable conditions. Choose group or nogroup option.", None),
            ("/getEvents?client=cz.cesnet.warden3test&secret=abc&count=3&id=10", "200 OK", None)
        ]
        for query, expected_status, expected_response in tests_common + self.getEvents_interface_tests_specific:
            with self.subTest(query=query, expected_status=expected_status, expected_response=expected_response):
                status, _, out = Request(self.app, query)()
                self.assertEqual(status, expected_status)
                if expected_response is not None:
                    self.assertEqual(out, expected_response)

    def test_getDebug_interface(self):      # pylint: disable = locally-disabled, invalid-name
        """Tests the getDebug method invocation"""
        tests = [
            ("/getDebug?secret=abc", "200 OK"),
            ("/getDebug?client=cz.cesnet.warden3test&secret=abc", "200 OK"),
            ("/getDebug?secret=abc&self=test", "200 OK"),
        ]
        for query, expected_status in tests:
            with self.subTest(query=query, expected_status=expected_status):
                status, _, _ = Request(self.app, query)()
                self.assertEqual(status, expected_status)

    def test_methods(self):
        """Tests application behaviour in method parsing"""
        tests = [
            ("", "404 You've fallen off the cliff."),
            ("/blefub?client=client&secret=secret", "404 You've fallen off the cliff."),
            ("/?client=client&secret=secret", "404 You've fallen off the cliff."),
        ]
        for query, expected_status in tests:
            with self.subTest(query=query, expected_status=expected_status):
                status, _, _ = Request(self.app, query)()
                self.assertEqual(status, expected_status)

    def test_payload(self):
        """Tests parsing of transported data"""
        tests = [
            ("/getInfo?secret=abc", "", "200 OK", None),
            ("/getInfo?secret=abc", "[1]", "200 OK", None),
            ("/getInfo?secret=abc", "{#$%^", "200 OK", None),
            ("/sendEvents?secret=abc", "", "200 OK", ['{"saved": 0}']),
            ("/sendEvents?secret=abc", "{'test': 'true'}", "400 Deserialization error.", None),
            ("/sendEvents?secret=abc", '{"test": "true"}', "400 List of events expected.", None),
Jakub Maloštík's avatar
Jakub Maloštík committed
            ("/sendEvents?secret=abc", '[{"test": "true"}]', "422 Missing IDEA ID", None),
            ("/sendEvents?secret=abc", '[{"test": "true", "ID": "120820201142"}]', "422 Event does not bear valid Node attribute", None),
            ("/sendEvents?secret=abc", '[{"Node": ["test", "test2"], "ID": "120820201142"}]', "422 Event does not bear valid Node attribute", None),
            ("/sendEvents?secret=abc", '[{"Node": ["Name", "test"], "ID": "120820201142"}]', "422 Event does not bear valid Node attribute", None),
            ("/sendEvents?secret=abc", '[{"Node": [{"Name"}], "ID": "120820201142"}]', "400 Deserialization error.", None),
            ("/sendEvents?secret=abc", '[{"Node": [{"Name": "test"}], "ID": "120820201142"}]', "422 Node does not correspond with saving client", None),
            ("/sendEvents?secret=abc", '[{"Node": [{"Name": "test"}], "ID": "verylongideaidverylongideaidverylongideaidverylongideaidverylongideaid"}]', "422 The provided event ID is too long", None),
            ("/sendEvents?secret=abc", '[{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": "verylongideaidverylongideaidverylongideaidverylongideaidverylongideaid"}]', "422 The provided event ID is too long", None),
            ("/sendEvents?secret=abc", '[{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": "ideaidcontaininga\\u0000byte"}]', "422 IDEA ID cannot contain null bytes", None),
            ("/sendEvents?secret=abc", '[{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": "verylongideaidverylongideaid\\u0000verylongideaidverylongideaidverylongideaid"}]', "422 Multiple errors", None),
            ("/sendEvents?secret=abc", '[{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": "120820201142"}]', "200 OK", ['{"saved": 1}']),
            ("/sendEvents?secret=abc", '[{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": "120820201142"}]', "409 IDEA event with this ID already exists", None),
            (
                "/sendEvents?secret=abc",
                '['
                    '{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": "120820201142"}, '
                    '{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": "120820201143"}'
                ']',
                "409 IDEA event with this ID already exists",
                None
            ),
            ("/sendEvents?secret=abc", '[{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": "120820201143"}]', "409 IDEA event with this ID already exists", None),
            (
                "/sendEvents?secret=abc",
                '['
                    '{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": ""}, '
                    '{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": "120820201144"}'
                ']',
                "422 The provided IDEA ID is invalid",
                None
            ),
            ("/sendEvents?secret=abc", '[{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": "120820201144"}]', "409 IDEA event with this ID already exists", None),
            (
                "/sendEvents?secret=abc",
                '['
                    '{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": "120820201145"}, '
                    '{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": "120820201145"}'
                ']',
                "409 IDEA event with this ID already exists",
                None
            ),
            (
                "/sendEvents?secret=abc",
                '['
                    '{"Node": [{"Name": "cz.cesnet.warden3test"}]}, '
                    '{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": "verylongideaidverylongideaidverylongideaidverylongideaidverylongideaid"}'
                ']',
                '422 Multiple errors',
                None
            ),
            (
                "/sendEvents?secret=abc",
                '['
                    '{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": "120820201146"}, '
                    '{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": "120820201147"}, '
                    '{"Node": [{"Name": "cz.cesnet.warden3test"}], "ID": "120820201148"}'
                ']',
                "200 OK",
                ['{"saved": 3}']
            )
        ]
        for query, payload, expected_status, expected_response in tests:
            with self.subTest(query=query, payload=payload, expected_status=expected_status, expected_response=expected_response):
                status, _, out = Request(self.app, query, payload)()
                self.assertEqual(status, expected_status)
                if expected_response is not None:
                    self.assertEqual(out, expected_response)


class X509AuthenticatorTest(Warden3ServerTest):
    """Runs the basic test suite with the X509Authenticator configured"""

    @staticmethod
    def get_config():
        """Return the base class configuration with the auth type replaced"""
        settings = Warden3ServerTest.get_config()
        settings['auth'].update(type='X509Authenticator')
        return settings


class X509NameAuthenticatorTest(Warden3ServerTest):
    """Runs the basic test suite with the X509NameAuthenticator configured"""

    # With this authenticator the queries below are expected to succeed, so
    # the base class expectations are replaced.
    getInfo_interface_tests_specific = [
        ("/getInfo", "200 OK"),
        ("/getInfo?client=", "200 OK"),
        ("/getInfo?client=cz.cesnet.warden3test", "200 OK"),
        ("/getInfo?client=cz.cesnet.warden3test&secret=123", "200 OK"),
        ("/getInfo?secret=123", "200 OK"),
    ]
    getEvents_interface_tests_specific = [
        ("/getEvents", "200 OK", None),
        ("/getEvents?secret=123", "200 OK", None),
    ]

    @staticmethod
    def get_config():
        """Return the base class configuration with the auth type replaced"""
        settings = Warden3ServerTest.get_config()
        settings['auth'].update(type='X509NameAuthenticator')
        return settings


class WScliTest(unittest.TestCase):
    """Tester of the Warden Server command line interface"""
    @classmethod
    def setUpClass(cls):
        cls.config = {
            'log': {'level': 'debug'},
            'validator': {'type': 'NoValidator'},
            'auth': {'type': 'PlainAuthenticator'},
            'db': {'type': DBMS.name, 'user': USER, 'password': PASSWORD, 'dbname': DB},
            'handler': {'description': 'Warden Test Server'}
        }
        warden_server.server = build_server(cls.config)

    @staticmethod
    def do_cli(command_line):
        """Performs the command line action requested by argv and presents the results"""
        argv_backup = sys.argv
        sys.argv = command_line
        out = StringIO()
        err = StringIO()
        sys.stdout = out
        sys.stderr = err
        try:
            args = warden_server.get_args()
            command = args.command
            subargs = vars(args)
            del subargs["command"]
            del subargs["config"]
            ret = command(**subargs)
        except SystemExit as sys_exit:
            ret = sys_exit.code
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__
        sys.argv = argv_backup
        return ret, out.getvalue(), err.getvalue()

    def test_list(self):
        """Tests the list command line option"""
        tests = [
            (['list'], 0, 4),
            (['list', '--id=1'], 0, 4),
            (['list', '--id=1000'], 0, 3),
            (['list', '--id', '1'], 0, 4),
            (['list', '--id', '1000'], 0, 3),
        ]
        for supplied_arguments, expected_return, output_lines in tests:
            with self.subTest(supplied_arguments=supplied_arguments, expected_return=expected_return, output_lines=output_lines):
                ret, out, _ = self.do_cli(['./warden_server.py'] + supplied_arguments)
                self.assertEqual(ret, expected_return)
                self.assertEqual(len(out.split('\n')), output_lines)

    def test_register_modify(self):
        """Tests the client registration and its modification"""
        tests = [
            (['register', '-n', 'cz.cesnet.warden.test2', '-h', 'test2.warden.cesnet.cz', '-r', 'warden-info@cesnet.cz'], 0,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 1, 0, 0, 1, None),)),
            (['modify', '-i', 'CLIENT_ID', '--novalid'], 0,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 0, 1, 0, 0, 1, None),)),
            (['modify', '-i', 'CLIENT_ID', '--valid'], 0,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 1, 0, 0, 1, None),)),
            (['modify', '-i', 'CLIENT_ID', '--valid', '--novalid'], 2,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 1, 0, 0, 1, None),)),
            (['modify', '-i', 'CLIENT_ID', '--read'], 0,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 1, 0, 0, 1, None),)),
            (['modify', '-i', 'CLIENT_ID', '--noread', '--write'], 0,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 0, 0, 1, 1, None),)),
            (['modify', '-i', 'CLIENT_ID', '--debug', '--read'], 0,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 1, 1, 1, 1, None),)),
            (['modify', '-i', 'CLIENT_ID', '--notest', '--nodebug'], 0,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '--notest', '--nodebug'], 2,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', '1000', '--notest', '--nodebug'], 251,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-n', 'cz.cesnet.warden.test3'], 0,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-n', '..'], 254,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-h', 'test3.warden.cesnet.cz'], 0,
             (('warden-info@cesnet.cz', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-h', ''.zfill(256)], 253,
             (('warden-info@cesnet.cz', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-h', '..'], 253,
             (('warden-info@cesnet.cz', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'warden-info@cesnet.cz, info@cesnet.cz'], 0,
             (('warden-info@cesnet.cz, info@cesnet.cz', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'warden-info@cesnet.cz ,info@cesnet.cz'], 0,
             (('warden-info@cesnet.cz ,info@cesnet.cz', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'Warden Info <warden-info@cesnet.cz>'], 0,
             (('Warden Info <warden-info@cesnet.cz>', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'Other Info <other-info@x.cz'], 252,
             (('Warden Info <warden-info@cesnet.cz>', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'Other other@x.cz'], 252,
             (('Warden Info <warden-info@cesnet.cz>', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'a@b, '], 252,
             (('Warden Info <warden-info@cesnet.cz>', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'a@b'], 0,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', '@'], 252,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'abc'], 252,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'a@b@c'], 252,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-n', 'cz.cesnet.warden.test3'], 250,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-s', 'abc'], 249,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-s', 'top_secret'], 0,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', 'top_secret', 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-s', 'top_secret'], 249,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', 'top_secret', 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '--note', ''.zfill(1024)], 0,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', 'top_secret', 1, 1, 0, 1, 0, ''.zfill(1024)),)),
            (['modify', '-i', 'CLIENT_ID', '--note', 'Valid until: 18.01.2038'], 0,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', 'top_secret', 1, 1, 0, 1, 0, 'Valid until: 18.01.2038'),)),
            (['modify', '-i', 'CLIENT_ID', '--note', 'Valid until:', '20.1.2038'], 2,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', 'top_secret', 1, 1, 0, 1, 0, 'Valid until: 18.01.2038'),)),
        ]
        client_id = None
        for supplied_arguments, expected_return, expected_sql_result in tests:
            with self.subTest(supplied_arguments=supplied_arguments, expected_return=expected_return, expected_sql_result=expected_sql_result):
                supplied_arguments = [entry.replace('CLIENT_ID', str(client_id)) for entry in supplied_arguments]
                ret, out, _ = self.do_cli(['./warden_server.py'] + supplied_arguments)
                self.assertEqual(ret, expected_return)
                try:
                    client_id = int(out.split('\n')[-2].split(' ')[0])
                except IndexError:      # No modification was performed, keep the previous client_id
                    pass
                result = DBMS.do_sql_select(DBMS.reg_mod_test_query, (client_id,))
                self.assertEqual(result, expected_sql_result)


class MySQL:
    """MySQL flavour of the database helper used by the test suite

    Bundles the database work visible in this file: creating the database
    user, building and dropping the test database, emptying the event tables
    and reading rows back.
    """
    name = "MySQL"
    reg_mod_test_query = "SELECT requestor, hostname, name, secret, valid, clients.read, " \
                         "debug, clients.write, test, note FROM clients WHERE id = %s"

    def __init__(self, user=USER, password=PASSWORD, dbname=DB):
        """Remember the connection parameters and load the MySQLdb driver

        The driver is imported here rather than at module level, so it is
        only required once this class gets instantiated.
        """
        import MySQLdb as my
        self.my = my

        self.user = user
        self.password = password
        self.dbname = dbname

    def init_user(self):
        """DB user rights setup

        Connects as ``root`` (the password is asked for interactively),
        creates ``self.user@localhost`` if it does not exist and grants it
        the rights listed below. Exits the process if the connection fails or
        the prompt is interrupted.
        """
        conn = None
        try:
            conn = self.my.connect(user='root', passwd=getpass.getpass(
                'Enter MySQL Root password:'))
            with conn.cursor() as cur:
                cur.execute(
                    "CREATE USER IF NOT EXISTS %s@'localhost' IDENTIFIED BY %s",
                    (self.user, self.password)
                )
                cur.execute(
                    "GRANT SELECT, INSERT, UPDATE, CREATE, DELETE, DROP ON *.* TO %s@'localhost'",
                    (self.user,)
                )
            conn.commit()
            print("DB User set up successfuly")
        except self.my.OperationalError as ex:
            if conn:
                conn.rollback()
                conn.close()
                conn = None     # Already closed, keep the finally clause from closing it again
            print('Connection unsuccessful, bad password? Original exception: %s' % (str(ex)))
            # sys.exit instead of the exit() helper, which is only injected by the site module
            sys.exit()
        except KeyboardInterrupt:
            print("\nCancelled!")
            sys.exit()
        finally:
            if conn:
                conn.close()

    def set_up(self):
        """Create the test database from scratch

        Drops a possibly existing database ``self.dbname``, creates it again,
        runs the statements of ``warden_3.0_mysql.sql`` (expected next to this
        file) and inserts one client row. Exits the process on an
        OperationalError.
        """
        conn = None
        try:
            conn = self.my.connect(user=self.user, passwd=self.password)
            cur = conn.cursor()
            with warnings.catch_warnings(): # The database is not supposed to exist
                warnings.simplefilter("ignore")
                # The database name is interpolated into the statement text;
                # acceptable only because it is not user input here.
                cur.execute("DROP DATABASE IF EXISTS %s" % (self.dbname,))   # NOT SECURE
                cur.execute("CREATE DATABASE %s" % (self.dbname,))           # NOT SECURE
                cur.execute("USE %s" % (self.dbname,))                       # NOT SECURE
                with open(path.join(path.dirname(__file__), 'warden_3.0_mysql.sql')) as script:
                    # Drop '--' comment lines, join the rest into one string
                    # and split it into statements at ';' (the part after the
                    # last ';' is discarded).
                    statements = ''.join(
                        [line.replace('\n', '') for line in script if line[0:2] != '--']
                    ).split(';')[:-1]
                    for statement in statements:
                        cur.execute(statement)
                # The client the tests authenticate as; the meaning of the
                # individual columns is given by the schema script.
                cur.execute(
                    "INSERT INTO clients VALUES("
                    "NULL, NOW(), 'warden-info@cesnet.cz', 'test.server.warden.cesnet.cz',"
                    "NULL, 1, 'cz.cesnet.warden3test', 'abc', 1, 1, 1, 0"
                    ")"
                )
                conn.commit()
        except self.my.OperationalError as ex:
            if conn:
                conn.rollback()
                conn.close()
                conn = None     # Already closed, keep the finally clause from closing it again
            print('Setup failed, have you tried --init ? Original exception: %s' % (str(ex),))
            sys.exit()
        finally:
            if conn:
                conn.close()

    def do_sql_select(self, query, params):
        """Reads data from database

        :param query: SQL text with driver placeholders
        :param params: values bound to the placeholders
        :return: all rows as returned by the cursor's ``fetchall()``

        Cursor and connection are closed even when the query fails.
        """
        conn = self.my.connect(user=self.user, passwd=self.password, db=self.dbname)
        try:
            cur = conn.cursor()
            try:
                cur.execute(query, params)
                return cur.fetchall()
            finally:
                cur.close()
        finally:
            conn.close()

    def tear_down(self):
        """Clean up by purging the test database"""
        conn = self.my.connect(user=self.user, passwd=self.password)
        try:
            cur = conn.cursor()
            cur.execute("DROP DATABASE IF EXISTS %s" % (self.dbname,))       # NOT SECURE
            conn.commit()
        finally:
            conn.close()

    def clean_lastid(self):
        """Cleans the lastid information for all clients

        Empties the ``last_events`` and ``events`` tables. The connection is
        now closed afterwards; it used to be left open.
        """
        conn = self.my.connect(user=self.user, passwd=self.password, db=self.dbname)
        try:
            cur = conn.cursor()
            cur.execute("DELETE FROM last_events")
            cur.execute("DELETE FROM events")
            cur.close()
            conn.commit()
        finally:
            conn.close()


class PostgreSQL(object):
    """PostgreSQL backend helper for the test suite.

    Creates the test role, (re)builds and seeds the test database, reads
    data back for assertions and drops the database again.
    """

    name = "PostgreSQL"
    # Query selecting one row of the clients table by its id
    reg_mod_test_query = "SELECT requestor, hostname, name, secret, valid, clients.read, " \
                         "debug, clients.write, test, note FROM clients WHERE id = %s"

    def __init__(self, user=USER, password=PASSWORD, dbname=DB):
        """Store the test credentials and keep references to psycopg2.

        psycopg2 is imported here instead of at module level, so the driver
        is only required when this backend is actually instantiated.
        """
        import psycopg2 as ppg
        from psycopg2 import sql as ppgsql
        self.ppg = ppg
        self.ppgsql = ppgsql

        self.user = user
        self.password = password
        self.dbname = dbname

    def _connect(self, dbname):
        """Open a connection to ``dbname`` on localhost as the test user"""
        return self.ppg.connect(user=self.user, password=self.password,
                                dbname=dbname, host='localhost')

    def init_user(self):
        """DB user rights setup

        Connects as the ``postgres`` role and creates the test role with
        CREATEDB and LOGIN rights. The password of ``postgres`` is prompted
        for unless the script runs as the system user ``postgres``. Prints a
        message and exits on a connection error or on Ctrl-C.
        """
        running_as_postgres = getpass.getuser() == "postgres"
        conn = None
        try:
            password = None
            if not running_as_postgres:
                password = getpass.getpass("Enter PostgreSQL password for the user 'postgres':")
            conn = self.ppg.connect(user="postgres", password=password)
            with conn.cursor() as cur:
                # The role name is quoted as an identifier, the password is
                # passed as a bound parameter
                cur.execute(
                    self.ppgsql.SQL("CREATE ROLE {} PASSWORD {} CREATEDB LOGIN").format(
                        self.ppgsql.Identifier(self.user),
                        self.ppgsql.Placeholder()
                    ),
                    (self.password,)
                )
            conn.commit()
            print("DB User set up successfuly")
        except self.ppg.OperationalError as ex:
            if conn is not None:
                conn.rollback()
                conn.close()
                conn = None
            if running_as_postgres:
                print("Connection unsuccessful. Original exception: %s" % (str(ex)))
            else:
                print("Connection unsuccessful, bad password or meant to run as the user 'postgres'"
                      " (su postgres -c '%s --dbms PostgreSQL --init')? Original exception: %s" %
                      (path.join('.', path.normpath(sys.argv[0])), str(ex)))
            sys.exit()
        except KeyboardInterrupt:
            print("\nCancelled!")
            sys.exit()
        finally:
            if conn is not None:
                conn.close()

    def _load_tags(self, cur):
        """Fill the ``tags`` table from ``tagmap_db.json`` next to this file"""
        # Imported locally, like psycopg2 in __init__, so the helper does not
        # rely on a module-level import
        import json
        with open(path.join(path.dirname(__file__), "tagmap_db.json")) as tagmapf:
            tagmap = json.load(tagmapf)
        for tag, num in tagmap.items():
            cur.execute(
                "INSERT INTO tags(id, tag) VALUES (%s, %s)", (num, tag))

    def _load_cats(self, cur):
        """Fill the ``categories`` table from ``catmap_db.json`` next to this file"""
        # Imported locally, like psycopg2 in __init__, so the helper does not
        # rely on a module-level import
        import json
        with open(path.join(path.dirname(__file__), "catmap_db.json")) as catmapf:
            catmap = json.load(catmapf)
        for cat_subcat, num in catmap.items():
            # Keys look like "Category" or "Category.Subcategory"; only the
            # first dot separates the two parts
            catsplit = cat_subcat.split(".", 1)
            category = catsplit[0]
            subcategory = catsplit[1] if len(catsplit) > 1 else None
            cur.execute(
                "INSERT INTO categories(id, category, subcategory, cat_subcat) VALUES (%s, %s, %s, %s)",
                (num, category, subcategory, cat_subcat)
            )

    def _recreate_database(self):
        """Drop the test database if present and create it empty"""
        conn = None
        try:
            conn = self._connect('postgres')
            # PostgreSQL refuses DROP/CREATE DATABASE inside a transaction
            # block, so the statements run in autocommit mode
            conn.autocommit = True
            with conn.cursor() as cur:
                cur.execute(
                    self.ppgsql.SQL("DROP DATABASE IF EXISTS {}").format(
                        self.ppgsql.Identifier(self.dbname)
                    )
                )
                cur.execute(
                    self.ppgsql.SQL("CREATE DATABASE {}").format(
                        self.ppgsql.Identifier(self.dbname)
                    )
                )
        except self.ppg.OperationalError as ex:
            if conn is not None:
                conn.rollback()
                conn.close()
                conn = None
            print(
                'Setup failed, have you tried --init ? Original exception: %s' % (str(ex),))
            sys.exit()
        finally:
            if conn is not None:
                conn.close()

    def _populate_database(self):
        """Load the schema, the tag/category maps and the test client"""
        conn = None
        try:
            conn = self._connect(self.dbname)
            with conn.cursor() as cur:
                with open(path.join(path.dirname(__file__), 'warden_3.0_postgres.sql')) as script:
                    statements = script.read()
                # The whole schema script is sent to the server in one call
                cur.execute(statements)

                self._load_tags(cur)
                self._load_cats(cur)

                cur.execute(
                    "INSERT INTO clients "
                    "(registered, requestor, hostname, note, valid,"
                    " name, secret, read, debug, write, test) "
                    "VALUES(NOW(), 'warden-info@cesnet.cz', 'test.server.warden.cesnet.cz', "
                    "NULL, true, 'cz.cesnet.warden3test', 'abc', true, true, true, false)"
                )
            conn.commit()
        except self.ppg.OperationalError as ex:
            if conn is not None:
                conn.rollback()
                conn.close()
                conn = None
            print(
                'Something went wrong during database setup. Original exception: %s' % (str(ex),))
            sys.exit()
        finally:
            if conn is not None:
                conn.close()

    def set_up(self):
        """(Re)create the test database and fill it with schema and seed data

        Prints a message and exits when the server reports an
        OperationalError in either step.
        """
        self._recreate_database()
        self._populate_database()

    def do_sql_select(self, query, params):
        """Reads data from database

        Runs ``query`` with ``params`` on a fresh connection to the test
        database and returns the fetched rows as a tuple. The cursor and the
        connection are released even when the query raises.
        """
        conn = self._connect(self.dbname)
        try:
            with conn.cursor() as cur:
                cur.execute(query, params)
                return tuple(cur.fetchall())
        finally:
            conn.close()

    def tear_down(self):
        """Clean up by purging the test database"""
        conn = self._connect('postgres')
        try:
            # DROP DATABASE must not run inside a transaction block
            conn.autocommit = True
            with conn.cursor() as cur:
                cur.execute(
                    self.ppgsql.SQL("DROP DATABASE IF EXISTS {} WITH(FORCE)").format(
                        self.ppgsql.Identifier(self.dbname)
                    )
                )
        finally:
            conn.close()

    def clean_lastid(self):
        """Cleans the lastid information for all clients

        Deletes every row from ``last_events`` and ``events`` and commits;
        the connection is closed even when a statement fails.
        """
        conn = self._connect(self.dbname)
        try:
            with conn.cursor() as cur:
                cur.execute("DELETE FROM last_events")
                cur.execute("DELETE FROM events")
            conn.commit()
        finally:
            conn.close()


    'MySQL': MySQL,
    'PostgreSQL': PostgreSQL


def main():
    """Parses arguments and acts accordingly

    With ``--init`` only the database user is set up through the selected
    backend. Otherwise the unit tests are run; ``--nopurge`` sets the
    module-level NO_PURGE flag that the module teardown checks before
    purging the test database.
    """
    global DBMS    # pylint: disable = locally-disabled, global-statement
    global NO_PURGE    # pylint: disable = locally-disabled, global-statement

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '-d', '--dbms', default='MySQL', choices=database_types,
        help='Database management system to use for testing')
    parser.add_argument(
        '-i', '--init', action='store_true',
        help='Set up an user with rights to CREATE/DROP the test database')
    parser.add_argument(
        '-n', '--nopurge', action='store_true',
        help='Skip the database purge after running the tests')

    args = parser.parse_args()

    # Instantiate the backend chosen on the command line
    DBMS = database_types[args.dbms](USER, PASSWORD, DB)

    if args.init:
        DBMS.init_user()
    else:
        if args.nopurge:
            NO_PURGE = True
        # Hide our own options from unittest, which parses sys.argv itself
        sys.argv = [sys.argv[0]]
        unittest.main()


# Run the command-line entry point only when executed as a script, not on import
if __name__ == "__main__":
    main()