diff --git a/warden_server/test_warden_server.py b/warden_server/test_warden_server.py index a51ad0a7924e0052d4a3b357e255d8673ab010f6..09376e1e4fc44b0505c6edee1e56d7c0be75a753 100755 --- a/warden_server/test_warden_server.py +++ b/warden_server/test_warden_server.py @@ -7,8 +7,6 @@ import getpass import sys import warnings from os import path -from copy import deepcopy -import MySQLdb as my from warden_server import build_server import warden_server @@ -26,48 +24,19 @@ USER = 'warden3test' PASSWORD = 'h7w*D>4B)3omcvLM$oJp' DB = 'w3test' - def setUpModule(): # pylint: disable = locally-disabled, invalid-name """Initialize the test database""" print(__doc__) - conn = None - try: - conn = my.connect(user=USER, passwd=PASSWORD) - cur = conn.cursor() - with warnings.catch_warnings(): # The database is not supposed to exist - warnings.simplefilter("ignore") - cur.execute("DROP DATABASE IF EXISTS %s" % (DB,)) # NOT SECURE - cur.execute("CREATE DATABASE %s" % (DB,)) # NOT SECURE - cur.execute("USE %s" % (DB,)) # NOT SECURE - with open(path.join(path.dirname(__file__), 'warden_3.0_mysql.sql')) as script: - statements = ''.join([line.replace('\n', '') for line in script if line[0:2] != '--']).split(';')[:-1] - for statement in statements: - cur.execute(statement) - cur.execute("INSERT INTO clients VALUES(NULL, NOW(), 'warden-info@cesnet.cz', 'test.server.warden.cesnet.cz', NULL, 1, 'cz.cesnet.warden3test', 'abc', 1, 1, 1, 0)") - conn.commit() - except my.OperationalError as ex: - if conn: - conn.rollback() - conn.close() - conn = None - print('Setup failed, have you tried --init ? Original exception: %s' % (str(ex),)) - exit() - finally: - if conn: - conn.close() + DBMS.set_up() NO_PURGE = False - +DBMS = None def tearDownModule(): # pylint: disable = locally-disabled, invalid-name """Clean up by purging the test database""" if not NO_PURGE: - conn = my.connect(user=USER, passwd=PASSWORD) - cur = conn.cursor() - cur.execute("DROP DATABASE IF EXISTS %s" % (DB,)) # NOT SECURE - conn.commit() - conn.close() + DBMS.tear_down() class ReadableSTR(str): @@ -123,8 +92,6 @@ class Request(object): class Warden3ServerTest(unittest.TestCase): """High level Warden3 Server tests""" - config = {'log': {'level': 'debug'}, 'validator': {'type': 'NoValidator'}, 'auth': {'type': 'PlainAuthenticator'}, - 'db': {'user': USER, 'password': PASSWORD, 'dbname': DB}, 'handler': {'description': 'Warden Test Server'}} getInfo_interface_tests_specific = [ ("/getInfo", "403 I'm watching. Authenticate."), @@ -138,22 +105,26 @@ class Warden3ServerTest(unittest.TestCase): ("/getEvents?secret=123", "403 I'm watching. Authenticate.", None), ] + @staticmethod + def get_config(): + return { + 'log': {'level': 'debug'}, + 'validator': {'type': 'NoValidator'}, + 'auth': {'type': 'PlainAuthenticator'}, + 'db': {'type': DBMS.name, 'user': USER, 'password': PASSWORD, 'dbname': DB}, + 'handler': {'description': 'Warden Test Server'} + } + @classmethod def setUpClass(cls): """Pre-test cleanup""" cls.clean_lastid() - cls.app = build_server(cls.config) + cls.app = build_server(cls.get_config()) @classmethod def clean_lastid(cls): """Cleans the lastid information for all clients""" - conn = my.connect(user=USER, passwd=PASSWORD, db=DB) - cur = conn.cursor() - cur.execute("DELETE FROM events") - cur.execute("DELETE FROM last_events") - cur.close() - conn.commit() - conn.close() + DBMS.clean_lastid() def test_getInfo_interface(self): # pylint: disable = locally-disabled, invalid-name """Tests the getInfo method invocation""" @@ -245,14 +216,22 @@ class Warden3ServerTest(unittest.TestCase): class X509AuthenticatorTest(Warden3ServerTest): """Performs the basic test suite using the X509Authenticator""" - config = deepcopy(Warden3ServerTest.config) - config['auth']['type'] = 'X509Authenticator' + + @staticmethod + def get_config(): + config = Warden3ServerTest.get_config() + config['auth']['type'] = 'X509Authenticator' + return config class X509NameAuthenticatorTest(Warden3ServerTest): """Performs the basic test suite using the X509NameAuthenticator""" - config = deepcopy(Warden3ServerTest.config) - config['auth']['type'] = 'X509NameAuthenticator' + + @staticmethod + def get_config(): + config = Warden3ServerTest.get_config() + config['auth']['type'] = 'X509NameAuthenticator' + return config getInfo_interface_tests_specific = [ ("/getInfo", "200 OK"), @@ -271,8 +250,13 @@ class WScliTest(unittest.TestCase): """Tester of the Warden Server command line interface""" @classmethod def setUpClass(cls): - cls.config = {'log': {'level': 'debug'}, 'validator': {'type': 'NoValidator'}, 'auth': {'type': 'PlainAuthenticator'}, - 'db': {'user': USER, 'password': PASSWORD, 'dbname': DB}, 'handler': {'description': 'Warden Test Server'}} + cls.config = { + 'log': {'level': 'debug'}, + 'validator': {'type': 'NoValidator'}, + 'auth': {'type': 'PlainAuthenticator'}, + 'db': {'type': DBMS.name, 'user': USER, 'password': PASSWORD, 'dbname': DB}, + 'handler': {'description': 'Warden Test Server'} + } warden_server.server = build_server(cls.config) @staticmethod @@ -298,17 +282,6 @@ class WScliTest(unittest.TestCase): sys.argv = argv_backup return ret, out.getvalue(), err.getvalue() - @staticmethod - def do_sql_select(query, params): - """Reads data from database""" - conn = my.connect(user=USER, passwd=PASSWORD, db=DB) - cur = conn.cursor() - cur.execute(query, params) - result = cur.fetchall() - cur.close() - conn.close() - return result - def test_list(self): """Tests the list command line option""" tests = [ @@ -392,7 +365,6 @@ class WScliTest(unittest.TestCase): (['modify', '-i', 'CLIENT_ID', '--note', 'Valid until:', '20.1.2038'], 2, (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', 'top_secret', 1, 1, 0, 1, 0, 'Valid until: 18.01.2038'),)), ] - test_sql = "SELECT requestor, hostname, name, secret, valid, clients.read, debug, clients.write, test, note FROM clients WHERE id = %s" client_id = None for supplied_arguments, expected_return, expected_sql_result in tests: with self.subTest(supplied_arguments=supplied_arguments, expected_return=expected_return, expected_sql_result=expected_sql_result): @@ -403,44 +375,136 @@ class WScliTest(unittest.TestCase): client_id = int(out.split('\n')[-2].split(' ')[0]) except IndexError: # No modification was performed, keep the previous client_id pass - result = self.do_sql_select(test_sql, (client_id,)) + result = DBMS.do_sql_select(DBMS.reg_mod_test_query, (client_id,)) self.assertEqual(result, expected_sql_result) -def init_user(): - """DB user rights setup""" - conn = None - try: - conn = my.connect(user='root', passwd=getpass.getpass('Enter MySQL Root password:')) - with conn.cursor() as cur: - cur.execute("CREATE USER IF NOT EXISTS %s@'localhost' IDENTIFIED BY %s", (USER, PASSWORD)) - cur.execute("GRANT SELECT, INSERT, UPDATE, CREATE, DELETE, DROP ON *.* TO %s@'localhost'", (USER,)) +class MySQL: + name = "MySQL" + reg_mod_test_query = "SELECT requestor, hostname, name, secret, valid, clients.read, " \ + "debug, clients.write, test, note FROM clients WHERE id = %s" + + def __init__(self, user=USER, password=PASSWORD, dbname=DB): + import MySQLdb as my + self.my = my + + self.user = user + self.password = password + self.dbname = dbname + + def init_user(self): + """DB user rights setup""" + conn = None + try: + conn = self.my.connect(user='root', passwd=getpass.getpass( + 'Enter MySQL Root password:')) + with conn.cursor() as cur: + cur.execute( + "CREATE USER IF NOT EXISTS %s@'localhost' IDENTIFIED BY %s", + (self.user, self.password) + ) + cur.execute( + "GRANT SELECT, INSERT, UPDATE, CREATE, DELETE, DROP ON *.* TO %s@'localhost'", + (self.user,) + ) + conn.commit() + print("DB User set up successfuly") + except self.my.OperationalError as ex: + if conn: + conn.rollback() + conn.close() + conn = None + print('Connection unsuccessful, bad password? Original exception: %s' % (str(ex))) + exit() + except KeyboardInterrupt: + print("\nCancelled!") + exit() + finally: + if conn: + conn.close() + + def set_up(self): + conn = None + try: + conn = self.my.connect(user=self.user, passwd=self.password) + cur = conn.cursor() + with warnings.catch_warnings(): # The database is not supposed to exist + warnings.simplefilter("ignore") + cur.execute("DROP DATABASE IF EXISTS %s" % (self.dbname,)) # NOT SECURE + cur.execute("CREATE DATABASE %s" % (self.dbname,)) # NOT SECURE + cur.execute("USE %s" % (self.dbname,)) # NOT SECURE + with open(path.join(path.dirname(__file__), 'warden_3.0_mysql.sql')) as script: + statements = ''.join( + [line.replace('\n', '') for line in script if line[0:2] != '--'] + ).split(';')[:-1] + for statement in statements: + cur.execute(statement) + cur.execute( + "INSERT INTO clients VALUES(" + "NULL, NOW(), 'warden-info@cesnet.cz', 'test.server.warden.cesnet.cz'," + "NULL, 1, 'cz.cesnet.warden3test', 'abc', 1, 1, 1, 0" + ")" + ) + conn.commit() + except self.my.OperationalError as ex: + if conn: + conn.rollback() + conn.close() + conn = None + print('Setup failed, have you tried --init ? Original exception: %s' % (str(ex),)) + exit() + finally: + if conn: + conn.close() + + def do_sql_select(self, query, params): + """Reads data from database""" + conn = self.my.connect(user=self.user, passwd=self.password, db=self.dbname) + cur = conn.cursor() + cur.execute(query, params) + result = cur.fetchall() + cur.close() + conn.close() + return result + + def tear_down(self): + """Clean up by purging the test database""" + conn = self.my.connect(user=self.user, passwd=self.password) + cur = conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s" % (self.dbname,)) # NOT SECURE + conn.commit() + conn.close() + + def clean_lastid(self): + """Cleans the lastid information for all clients""" + conn = self.my.connect(user=self.user, passwd=self.password, db=self.dbname) + cur = conn.cursor() + cur.execute("DELETE FROM last_events") + cur.execute("DELETE FROM events") + cur.close() conn.commit() - print("DB User set up successfuly") - except my.OperationalError as ex: - if conn: - conn.rollback() - conn.close() - conn = None - print('Connection unsuccessful, bad password? Original exception: %s' % (str(ex))) - exit() - except KeyboardInterrupt: - print("\nCancelled!") - exit() - finally: - if conn: - conn.close() + conn.close() + + +database_types = { + 'MySQL': MySQL +} def main(): """Parses arguments and acts accordingly""" parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument('-d', '--dbms', default='MySQL', choices=database_types, help='Database management system to use for testing') parser.add_argument('-i', '--init', action='store_true', help='Set up an user with rights to CREATE/DROP the test database') parser.add_argument('-n', '--nopurge', action='store_true', help='Skip the database purge after running the tests') args = parser.parse_args() + + global DBMS # pylint: disable = locally-disabled, global-statement + DBMS = database_types[args.dbms](USER, PASSWORD, DB) + if args.init: - init_user() + DBMS.init_user() else: if args.nopurge: global NO_PURGE # pylint: disable = locally-disabled, global-statement