Skip to content
Snippets Groups Projects
test_warden_server.py 24 KiB
Newer Older
#!/usr/bin/python
"""Warden3 Server Test Suite"""

from __future__ import print_function
import argparse
import getpass
import sys
import warnings
from os import path
from copy import deepcopy
import MySQLdb as my
from warden_server import build_server
import warden_server

if sys.version_info >= (3, 10):
    import unittest
else:
    import unittest2

if sys.version_info[0] >= 3:
    from io import StringIO
else:
    from StringIO import StringIO

# Credentials of the dedicated MySQL test account and the name of the scratch database.
# init_user() grants this account its rights; setUpModule() creates the database
# and tearDownModule() drops it again unless NO_PURGE is set.
USER = 'warden3test'
PASSWORD = 'h7w*D>4B)3omcvLM$oJp'
DB = 'w3test'


def setUpModule():      # pylint: disable = locally-disabled, invalid-name
    """Initialize the test database

    Connects as the test user, recreates the test database from scratch, loads
    the schema from ``warden_3.0.sql`` (looked up next to this file) and inserts
    the single client record the tests authenticate as.

    Raises:
        SystemExit: with status 1 when MySQL reports an OperationalError, after
            printing a hint, so that a failed setup is not reported as success.
    """
    print(__doc__)
    conn = None
    try:
        conn = my.connect(user=USER, passwd=PASSWORD)
        cur = conn.cursor()
        with warnings.catch_warnings():                         # The database is not supposed to exist
            warnings.simplefilter("ignore")
            cur.execute("DROP DATABASE IF EXISTS %s" % (DB,))   # NOT SECURE
            cur.execute("CREATE DATABASE %s" % (DB,))           # NOT SECURE
            cur.execute("USE %s" % (DB,))                       # NOT SECURE
            with open(path.join(path.dirname(__file__), 'warden_3.0.sql')) as script:
                # Drop "--" comment lines, glue the remaining lines together and split on ";".
                # The fragment after the last ";" is discarded.
                statements = ''.join([line.replace('\n', '') for line in script if line[0:2] != '--']).split(';')[:-1]
                for statement in statements:
                    cur.execute(statement)
            cur.execute("INSERT INTO clients VALUES(NULL, NOW(), 'warden-info@cesnet.cz', 'test.server.warden.cesnet.cz', NULL, 1, 'cz.cesnet.warden3test', 'abc', 1, 1, 1, 0)")
            conn.commit()
    except my.OperationalError as ex:
        if conn:
            conn.rollback()
        print('Setup failed, have you tried --init ? Original exception: %s' % (str(ex),))
        # sys.exit instead of the site-provided exit() helper; non-zero because the setup failed.
        sys.exit(1)
    finally:
        # Runs on success, on failure and while SystemExit propagates.
        if conn:
            conn.close()


NO_PURGE = False    # Set to True by main() for --nopurge; keeps the test database after the run


def tearDownModule():  # pylint: disable = locally-disabled, invalid-name
    """Clean up by purging the test database

    Does nothing when NO_PURGE is set.
    """
    if NO_PURGE:
        return
    conn = my.connect(user=USER, passwd=PASSWORD)
    try:
        cur = conn.cursor()
        cur.execute("DROP DATABASE IF EXISTS %s" % (DB,))       # NOT SECURE
        conn.commit()
    finally:
        # Release the connection even when the DROP fails.
        conn.close()


class ReadableSTR(str):
    """Mission: To boldly quack like a buffer, like no str has quacked before

    A str that additionally offers ``read`` (and ``decode`` where str lacks it),
    so that a plain string can be handed over where a readable input stream is
    expected. Unlike a real stream it keeps no position: every ``read`` starts
    at the beginning of the content.
    """
    def read(self, content_length=None):
        """Return own content

        Args:
            content_length: maximum number of characters to return; ``None`` or
                a negative value returns the whole content, as file objects do.

        Returns:
            ReadableSTR: the requested leading part of the content, wrapped
            again so the result still offers ``read`` and ``decode``.
        """
        if content_length is None or content_length < 0:
            return ReadableSTR(self)
        return ReadableSTR(self[0:content_length])

    # Python 3 str has no decode(); provide a no-op one so callers can decode unconditionally.
    if getattr(str, 'decode', None) is None:
        def decode(self, encoding="UTF-8", errors="strict"):    # pylint: disable = locally-disabled, unused-argument
            """For Py3 return own content, no decoding necessary"""
            return self


class Request(object):
    """Runs one WSGI request against an application and keeps the outcome

    The application is invoked in the constructor; calling the instance
    afterwards yields the recorded ``(status, headers, out)`` triple.
    """
    def __init__(self, app, uri, payload=""):
        environ = self.get_environ(uri, payload)
        self.status = None
        self.headers = None
        decoded = []
        for chunk in app(environ, self.start_response):
            decoded.append(chunk.decode('ascii'))
        self.out = decoded

    def __call__(self):
        return (self.status, self.headers, self.out)

    @staticmethod
    def get_environ(uri, payload):
        """Builds a (partial) WSGI environ for the given uri and payload

        The uri is split into path and query string only when it contains
        exactly one '?'; otherwise the whole uri is taken as the path.
        PATH_INFO is the last path segment prefixed with '/'.
        """
        pieces = uri.split('?')
        if len(pieces) == 2:
            full_path, query_string = pieces
        else:
            full_path, query_string = uri, ''
        last_segment = full_path.split('/')[-1]
        return {
            "REQUEST_URI": uri,
            "PATH_INFO": '/' + last_segment,
            "QUERY_STRING": query_string,
            "SSL_CLIENT_VERIFY": "SUCCESS",
            "SSL_CLIENT_S_DN_CN": "cz.cesnet.warden3test",
            "SSL_CLIENT_CERT": "-----BEGIN CERTIFICATE-----\nMIIDgDCCAmgCCQDEG431XDXZjDANBgkqhkiG9w0BAQsFADCBgTELMAkGA1UEBhMCQ1oxFzAVBgNVBAoMDkNFU05FVCwgYS5sLmUuMQwwCgYDVQQLDAM3MDkxJTAjBgNVBAMMHHRlc3Quc2VydmVyLndhcmRlbi5jZXNuZXQuY3oxJDAiBgkqhkiG9w0BCQEWFXdhcmRlbi1pbmZvQGNlc25ldC5jejAeFw0xODA3MjMxMzMyMjFaFw0xODA4MjIxMzMyMjFaMIGBMQswCQYDVQQGEwJDWjEXMBUGA1UECgwOQ0VTTkVULCBhLmwuZS4xDDAKBgNVBAsMAzcwOTElMCMGA1UEAwwcdGVzdC5zZXJ2ZXIud2FyZGVuLmNlc25ldC5jejEkMCIGCSqGSIb3DQEJARYVd2FyZGVuLWluZm9AY2VzbmV0LmN6MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAvgwOv1bv44hyWF7UDAPGdm+PqcbITi/6SVEfCENbMx6DAT+M3ZJlg7aOZyiZ16CRNxrjWizXYYY1H+NhOvlPZwsBcHFvnaBrcBciURMW6AQ+OiIHUONDUV7zqTcyiZ6NDMoNy472UpfNBMYXMtaUjPO33aRYwtl+QjoivU8bhzcSxyr/4P6WnZ7rW2nuHWfUNcGWGVxsRw7E2r4OY3Yr6M4SjKEDTEalByApoOYj2s3oEmeiNPjxKhN0wgD4h38+HcnpmKGZLNFbOEdT/7luA6IwzJ7l0p4ktjgCl/x3/Y6ZBrIZuFCNxjYrdfciD27LmcA5A6nEJ083fa4d+O/H8QIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQBc6EtV6FYnFBd735h4zwe2SIaFs2bu1d6COsOsaWe5loInI+oEATThaBlA9QiVamikkug3t2wgro8YcYhp0CMPN1gMxR6GstrBrKafprWp/Dv3+IP8RY+Z2lJ0ivw1MTMipqsCMiB+Lvs2wRVV3xBIXslgI3dbceZXos2bj6CPf3Frho7Z7oRaHetI+1a0T9QqZSug7dUSmYNCd9ZXQ8kFzU3eCFP0JKMqOy75KHIE00xowarDDFjTyyPoHmZviIOsY8ByKGNRDQz/WnZWzghAQjb+7tTFm2deOQua0XIyO7GSIU2xdGbTje4wA3/YiWhkpF8HWpCEAN8G6sMTDEXF\n-----END CERTIFICATE-----",      # pylint: disable = locally-disabled, line-too-long
            "wsgi.input": ReadableSTR(payload),
            "CONTENT_LENGTH": len(payload),
        }

    def start_response(self, status, headers):
        """Stand-in for the server's start_response; remembers status and headers"""
        self.status, self.headers = status, headers


class Warden3ServerTest(unittest.TestCase):
    """End-to-end checks of the Warden3 Server through its WSGI interface

    Subclasses swap the authenticator via ``config`` and may override the
    ``*_tests_specific`` lists with the results they expect instead.
    """
    config = {
        'log': {'level': 'debug'},
        'validator': {'type': 'NoValidator'},
        'auth': {'type': 'PlainAuthenticator'},
        'db': {'user': USER, 'password': PASSWORD, 'dbname': DB},
        'handler': {'description': 'Warden Test Server'},
    }

    # Cases whose expected outcome depends on the authenticator in use
    getInfo_interface_tests_specific = [
        ("/getInfo", "403 I'm watching. Authenticate."),
        ("/getInfo?client=", "403 I'm watching. Authenticate."),
        ("/getInfo?client=cz.cesnet.warden3test", "403 I'm watching. Authenticate."),
        ("/getInfo?client=cz.cesnet.warden3test&secret=123", "403 I'm watching. Authenticate."),
        ("/getInfo?secret=123", "403 I'm watching. Authenticate."),
    ]
    getEvents_interface_tests_specific = [
        ("/getEvents", "403 I'm watching. Authenticate.", None),
        ("/getEvents?secret=123", "403 I'm watching. Authenticate.", None),
    ]

    @classmethod
    def setUpClass(cls):
        """Empty the event tables, then build the application under test"""
        cls.clean_lastid()
        cls.app = build_server(cls.config)

    @classmethod
    def clean_lastid(cls):
        """Removes all stored events and the lastid records of all clients"""
        conn = my.connect(user=USER, passwd=PASSWORD, db=DB)
        cur = conn.cursor()
        for statement in ("DELETE FROM events", "DELETE FROM last_events"):
            cur.execute(statement)
        cur.close()
        conn.commit()
        conn.close()

    def _check_request(self, query, expected_status, expected_response=None, payload=""):
        """Performs one request and compares the status and, when given, the output"""
        status, _, out = Request(self.app, query, payload)()
        self.assertEqual(status, expected_status)
        if expected_response is not None:
            self.assertEqual(out, expected_response)

    def test_getInfo_interface(self):       # pylint: disable = locally-disabled, invalid-name
        """Checks how getInfo responds to various query strings"""
        shared_cases = [
            ("/getInfo?secret=abc", "200 OK"),
            ("/getInfo?secret=abc&evil=false", "200 OK"),       # RFC3514
            ("/getInfo?client=cz.cesnet.warden3test&secret=abc", "200 OK"),
            ("/getInfo?client=asdf.blefub", "403 I'm watching. Authenticate."),
            ("/getInfo?client=asdf.blefub&secret=abc", "403 I'm watching. Authenticate."),
            ("/getInfo?secret=abc&self=test", "200 OK"),        # Internal parameter
        ]
        for query, expected_status in shared_cases + self.getInfo_interface_tests_specific:
            with self.subTest(query=query, expected_status=expected_status):
                self._check_request(query, expected_status)

    def test_getEvents_interface(self):     # pylint: disable = locally-disabled, invalid-name
        """Checks how getEvents responds to various query strings"""
        shared_cases = [
            ("/getEvents?secret=abc", "200 OK", ['{"lastid": 1, "events": []}']),
            ("/getEvents?client=foo", "403 I'm watching. Authenticate.", None),
            ("/getEvents?secret=abc&foo=bar", "200 OK", ['{"lastid": 1, "events": []}']),
            ("/getEvents?secret=abc&lastid=1", "200 OK", ['{"lastid": 1, "events": []}']),
            ("/getEvents?secret=abc&lastid=0", "200 OK", ['{"lastid": 1, "events": []}']),
            ("/getEvents?secret=abc&lastid=9", "200 OK", ['{"lastid": 1, "events": []}']),
            ("/getEvents?secret=abc&cat=bflm", "422 Wrong tag or category used in query.", None),
            ("/getEvents?secret=abc&cat=Other", "200 OK", None),
            ("/getEvents?secret=abc&tag=Other", "200 OK", None),
            ("/getEvents?secret=abc&group=Other", "200 OK", None),
            ("/getEvents?secret=abc&cat=Other&nocat=Test", "422 Unrealizable conditions. Choose cat or nocat option.", None),
            ("/getEvents?secret=abc&tag=Other&notag=Test", "422 Unrealizable conditions. Choose tag or notag option.", None),
            ("/getEvents?secret=abc&group=Other&nogroup=Test", "422 Unrealizable conditions. Choose group or nogroup option.", None),
            ("/getEvents?client=cz.cesnet.warden3test&secret=abc&count=3&id=10", "200 OK", None)
        ]
        for query, expected_status, expected_response in shared_cases + self.getEvents_interface_tests_specific:
            with self.subTest(query=query, expected_status=expected_status, expected_response=expected_response):
                self._check_request(query, expected_status, expected_response)

    def test_getDebug_interface(self):      # pylint: disable = locally-disabled, invalid-name
        """Checks how getDebug responds to various query strings"""
        cases = [
            ("/getDebug?secret=abc", "200 OK"),
            ("/getDebug?client=cz.cesnet.warden3test&secret=abc", "200 OK"),
            ("/getDebug?secret=abc&self=test", "200 OK"),
        ]
        for query, expected_status in cases:
            with self.subTest(query=query, expected_status=expected_status):
                self._check_request(query, expected_status)

    def test_methods(self):
        """Checks that unknown or missing method names are rejected"""
        cases = [
            ("", "404 You've fallen off the cliff."),
            ("/blefub?client=client&secret=secret", "404 You've fallen off the cliff."),
            ("/?client=client&secret=secret", "404 You've fallen off the cliff."),
        ]
        for query, expected_status in cases:
            with self.subTest(query=query, expected_status=expected_status):
                self._check_request(query, expected_status)

    def test_payload(self):
        """Checks how request bodies are parsed and validated"""
        cases = [
            ("/getInfo?secret=abc", "", "200 OK", None),
            ("/getInfo?secret=abc", "[1]", "200 OK", None),
            ("/getInfo?secret=abc", "{#$%^", "200 OK", None),
            ("/sendEvents?secret=abc", "", "200 OK", ['{"saved": 0}']),
            ("/sendEvents?secret=abc", "{'test': 'true'}", "400 Deserialization error.", None),
            ("/sendEvents?secret=abc", '{"test": "true"}', "400 List of events expected.", None),
            ("/sendEvents?secret=abc", '[{"test": "true"}]', "422 Event does not bear valid Node attribute", None),
            ("/sendEvents?secret=abc", '[{"Node": ["test", "test2"]}]', "422 Event does not bear valid Node attribute", None),
            ("/sendEvents?secret=abc", '[{"Node": ["Name", "test"]}]', "422 Event does not bear valid Node attribute", None),
            ("/sendEvents?secret=abc", '[{"Node": [{"Name"}]}]', "400 Deserialization error.", None),
            ("/sendEvents?secret=abc", '[{"Node": [{"Name": "test"}]}]', "422 Node does not correspond with saving client", None),
            ("/sendEvents?secret=abc", '[{"Node": [{"Name": "cz.cesnet.warden3test"}]}]', "200 OK", ['{"saved": 1}']),
        ]
        for query, payload, expected_status, expected_response in cases:
            with self.subTest(query=query, payload=payload, expected_status=expected_status, expected_response=expected_response):
                self._check_request(query, expected_status, expected_response, payload)


class X509AuthenticatorTest(Warden3ServerTest):
    """Runs the inherited test suite with the X509Authenticator"""
    # Private copy, so the parent's configuration stays untouched
    config = deepcopy(Warden3ServerTest.config)
    config['auth'].update(type='X509Authenticator')


class X509NameAuthenticatorTest(Warden3ServerTest):
    """Runs the inherited test suite with the X509NameAuthenticator"""
    # Private copy, so the parent's configuration stays untouched
    config = deepcopy(Warden3ServerTest.config)
    config['auth'].update(type='X509NameAuthenticator')

    # The same queries as in the parent class, but expected to succeed here
    getInfo_interface_tests_specific = [
        ("/getInfo", "200 OK"),
        ("/getInfo?client=", "200 OK"),
        ("/getInfo?client=cz.cesnet.warden3test", "200 OK"),
        ("/getInfo?client=cz.cesnet.warden3test&secret=123", "200 OK"),
        ("/getInfo?secret=123", "200 OK"),
    ]
    getEvents_interface_tests_specific = [
        ("/getEvents", "200 OK", None),
        ("/getEvents?secret=123", "200 OK", None),
    ]


class WScliTest(unittest.TestCase):
    """Tester of the Warden Server command line interface"""
    @classmethod
    def setUpClass(cls):
        """Build a server from the test configuration and publish it as ``warden_server.server``"""
        cls.config = {'log': {'level': 'debug'}, 'validator': {'type': 'NoValidator'}, 'auth': {'type': 'PlainAuthenticator'},
                      'db': {'user': USER, 'password': PASSWORD, 'dbname': DB}, 'handler': {'description': 'Warden Test Server'}}
        warden_server.server = build_server(cls.config)

    @staticmethod
    def do_cli(command_line):
        """Performs the command line action requested by argv and presents the results

        Args:
            command_line: replacement for ``sys.argv``, program name included.

        Returns:
            tuple: ``(ret, stdout_text, stderr_text)`` where ``ret`` is the
            command's return value, or the exit code when SystemExit was raised.
        """
        # Remember the streams that are active right now (a test runner may have
        # replaced them) rather than assuming sys.__stdout__ / sys.__stderr__.
        argv_backup, stdout_backup, stderr_backup = sys.argv, sys.stdout, sys.stderr
        out = StringIO()
        err = StringIO()
        sys.argv = command_line
        sys.stdout = out
        sys.stderr = err
        try:
            args = warden_server.get_args()
            command = args.command
            subargs = vars(args)
            # Everything left after removing these two is passed to the command as keyword arguments
            del subargs["command"]
            del subargs["config"]
            ret = command(**subargs)
        except SystemExit as sys_exit:
            ret = sys_exit.code
        finally:
            # Restore the global state even when an unexpected exception propagates
            sys.argv = argv_backup
            sys.stdout = stdout_backup
            sys.stderr = stderr_backup
        return ret, out.getvalue(), err.getvalue()

    @staticmethod
    def do_sql_select(query, params):
        """Reads data from database

        Args:
            query: SQL statement with placeholders.
            params: values bound to the placeholders.

        Returns:
            All rows returned by the query.
        """
        conn = my.connect(user=USER, passwd=PASSWORD, db=DB)
        try:
            cur = conn.cursor()
            cur.execute(query, params)
            result = cur.fetchall()
            cur.close()
        finally:
            # Do not leak the connection when the query fails
            conn.close()
        return result

    def test_list(self):
        """Tests the list command line option"""
        # (arguments, expected return value, expected number of '\n'-separated output chunks)
        tests = [
            (['list'], 0, 4),
            (['list', '--id=1'], 0, 4),
            (['list', '--id=1000'], 0, 3),
            (['list', '--id', '1'], 0, 4),
            (['list', '--id', '1000'], 0, 3),
        ]
        for supplied_arguments, expected_return, output_lines in tests:
            with self.subTest(supplied_arguments=supplied_arguments, expected_return=expected_return, output_lines=output_lines):
                ret, out, _ = self.do_cli(['./warden_server.py'] + supplied_arguments)
                self.assertEqual(ret, expected_return)
                self.assertEqual(len(out.split('\n')), output_lines)

    def test_register_modify(self):
        """Tests the client registration and its modification

        The steps build on each other: the first one registers a client and each
        following one modifies (or fails to modify) that same client. After every
        step the client's database row is compared with the expected state.
        """
        # (arguments, expected return value, expected rows for test_sql below)
        tests = [
            (['register', '-n', 'cz.cesnet.warden.test2', '-h', 'test2.warden.cesnet.cz', '-r', 'warden-info@cesnet.cz'], 0,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 1, 0, 0, 1, None),)),
            (['modify', '-i', 'CLIENT_ID', '--novalid'], 0,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 0, 1, 0, 0, 1, None),)),
            (['modify', '-i', 'CLIENT_ID', '--valid'], 0,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 1, 0, 0, 1, None),)),
            (['modify', '-i', 'CLIENT_ID', '--valid', '--novalid'], 2,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 1, 0, 0, 1, None),)),
            (['modify', '-i', 'CLIENT_ID', '--read'], 0,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 1, 0, 0, 1, None),)),
            (['modify', '-i', 'CLIENT_ID', '--noread', '--write'], 0,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 0, 0, 1, 1, None),)),
            (['modify', '-i', 'CLIENT_ID', '--debug', '--read'], 0,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 1, 1, 1, 1, None),)),
            (['modify', '-i', 'CLIENT_ID', '--notest', '--nodebug'], 0,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '--notest', '--nodebug'], 2,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', '1000', '--notest', '--nodebug'], 251,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test2', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-n', 'cz.cesnet.warden.test3'], 0,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-n', '..'], 254,
             (('warden-info@cesnet.cz', 'test2.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-h', 'test3.warden.cesnet.cz'], 0,
             (('warden-info@cesnet.cz', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-h', ''.zfill(256)], 253,
             (('warden-info@cesnet.cz', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-h', '..'], 253,
             (('warden-info@cesnet.cz', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'warden-info@cesnet.cz, info@cesnet.cz'], 0,
             (('warden-info@cesnet.cz, info@cesnet.cz', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'warden-info@cesnet.cz ,info@cesnet.cz'], 0,
             (('warden-info@cesnet.cz ,info@cesnet.cz', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'Warden Info <warden-info@cesnet.cz>'], 0,
             (('Warden Info <warden-info@cesnet.cz>', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'Other Info <other-info@x.cz'], 252,
             (('Warden Info <warden-info@cesnet.cz>', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'Other other@x.cz'], 252,
             (('Warden Info <warden-info@cesnet.cz>', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'a@b, '], 252,
             (('Warden Info <warden-info@cesnet.cz>', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'a@b'], 0,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', '@'], 252,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'abc'], 252,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-r', 'a@b@c'], 252,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-n', 'cz.cesnet.warden.test3'], 250,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-s', 'abc'], 249,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', None, 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-s', 'top_secret'], 0,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', 'top_secret', 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '-s', 'top_secret'], 249,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', 'top_secret', 1, 1, 0, 1, 0, None),)),
            (['modify', '-i', 'CLIENT_ID', '--note', ''.zfill(1024)], 0,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', 'top_secret', 1, 1, 0, 1, 0, ''.zfill(1024)),)),
            (['modify', '-i', 'CLIENT_ID', '--note', 'Valid until: 18.01.2038'], 0,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', 'top_secret', 1, 1, 0, 1, 0, 'Valid until: 18.01.2038'),)),
            (['modify', '-i', 'CLIENT_ID', '--note', 'Valid until:', '20.1.2038'], 2,
             (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', 'top_secret', 1, 1, 0, 1, 0, 'Valid until: 18.01.2038'),)),
        ]
        test_sql = "SELECT requestor, hostname, name, secret, valid, clients.read, debug, clients.write, test, note FROM clients WHERE id = %s"
        client_id = None
        for supplied_arguments, expected_return, expected_sql_result in tests:
            with self.subTest(supplied_arguments=supplied_arguments, expected_return=expected_return, expected_sql_result=expected_sql_result):
                # Substitute the id learned from the output of an earlier step
                supplied_arguments = [entry.replace('CLIENT_ID', str(client_id)) for entry in supplied_arguments]
                ret, out, _ = self.do_cli(['./warden_server.py'] + supplied_arguments)
                self.assertEqual(ret, expected_return)
                try:
                    # The id is read from the first word of the second-to-last '\n'-separated chunk of the output
                    client_id = int(out.split('\n')[-2].split(' ')[0])
                except IndexError:      # No modification was performed, keep the previous client_id
                    pass
                result = self.do_sql_select(test_sql, (client_id,))
                self.assertEqual(result, expected_sql_result)


def init_user():
    """DB user rights setup

    Asks for the MySQL root password, connects as root and grants the test
    user the rights needed to create, fill and drop the test database.

    Raises:
        SystemExit: with status 1 when MySQL reports an OperationalError or
            when the password prompt is interrupted, after printing a message.
    """
    conn = None
    try:
        conn = my.connect(user='root', passwd=getpass.getpass('Enter MySQL Root password:'))
        # Explicit cursor, as everywhere else in this file; what "with conn" yields
        # is not the same across MySQLdb releases.
        cur = conn.cursor()
        # NOTE(review): GRANT ... IDENTIFIED BY is rejected by MySQL 8.0 and newer - confirm the targeted server version.
        cur.execute("GRANT SELECT, INSERT, UPDATE, CREATE, DELETE, DROP ON *.* TO %s@'localhost' IDENTIFIED BY %s", (USER, PASSWORD))
        conn.commit()
        print("DB User set up successfuly")
    except my.OperationalError as ex:
        if conn:
            conn.rollback()
        print('Connection unsuccessful, bad password? Original exception: %s' % (str(ex),))
        # sys.exit instead of the site-provided exit() helper; non-zero because nothing was set up.
        sys.exit(1)
    except KeyboardInterrupt:
        print("\nCancelled!")
        sys.exit(1)
    finally:
        # Runs on success, on failure and while SystemExit propagates.
        if conn:
            conn.close()


def main():
    """Command line entry point: either prepares the DB user or runs the tests"""
    global NO_PURGE    # pylint: disable = locally-disabled, global-statement

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('-i', '--init', action='store_true', help='Set up an user with rights to CREATE/DROP the test database')
    parser.add_argument('-n', '--nopurge', action='store_true', help='Skip the database purge after running the tests')
    options = parser.parse_args()

    if options.init:
        init_user()
        return

    if options.nopurge:
        NO_PURGE = True
    # Hide our own options from unittest, which parses sys.argv itself
    sys.argv = [sys.argv[0]]
    unittest.main()


# Run the command line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()