Skip to content
Snippets Groups Projects
Commit 9fefac12 authored by Jakub Maloštík's avatar Jakub Maloštík
Browse files

Generalized test_warden_server.py database access

parent 93b7cd42
No related branches found
No related tags found
No related merge requests found
...@@ -7,8 +7,6 @@ import getpass ...@@ -7,8 +7,6 @@ import getpass
import sys import sys
import warnings import warnings
from os import path from os import path
from copy import deepcopy
import MySQLdb as my
from warden_server import build_server from warden_server import build_server
import warden_server import warden_server
...@@ -26,48 +24,19 @@ USER = 'warden3test' ...@@ -26,48 +24,19 @@ USER = 'warden3test'
PASSWORD = 'h7w*D>4B)3omcvLM$oJp' PASSWORD = 'h7w*D>4B)3omcvLM$oJp'
DB = 'w3test' DB = 'w3test'
def setUpModule(): # pylint: disable = locally-disabled, invalid-name def setUpModule(): # pylint: disable = locally-disabled, invalid-name
"""Initialize the test database""" """Initialize the test database"""
print(__doc__) print(__doc__)
conn = None DBMS.set_up()
try:
conn = my.connect(user=USER, passwd=PASSWORD)
cur = conn.cursor()
with warnings.catch_warnings(): # The database is not supposed to exist
warnings.simplefilter("ignore")
cur.execute("DROP DATABASE IF EXISTS %s" % (DB,)) # NOT SECURE
cur.execute("CREATE DATABASE %s" % (DB,)) # NOT SECURE
cur.execute("USE %s" % (DB,)) # NOT SECURE
with open(path.join(path.dirname(__file__), 'warden_3.0_mysql.sql')) as script:
statements = ''.join([line.replace('\n', '') for line in script if line[0:2] != '--']).split(';')[:-1]
for statement in statements:
cur.execute(statement)
cur.execute("INSERT INTO clients VALUES(NULL, NOW(), 'warden-info@cesnet.cz', 'test.server.warden.cesnet.cz', NULL, 1, 'cz.cesnet.warden3test', 'abc', 1, 1, 1, 0)")
conn.commit()
except my.OperationalError as ex:
if conn:
conn.rollback()
conn.close()
conn = None
print('Setup failed, have you tried --init ? Original exception: %s' % (str(ex),))
exit()
finally:
if conn:
conn.close()
NO_PURGE = False NO_PURGE = False
DBMS = None
def tearDownModule(): # pylint: disable = locally-disabled, invalid-name def tearDownModule(): # pylint: disable = locally-disabled, invalid-name
"""Clean up by purging the test database""" """Clean up by purging the test database"""
if not NO_PURGE: if not NO_PURGE:
conn = my.connect(user=USER, passwd=PASSWORD) DBMS.tear_down()
cur = conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s" % (DB,)) # NOT SECURE
conn.commit()
conn.close()
class ReadableSTR(str): class ReadableSTR(str):
...@@ -123,8 +92,6 @@ class Request(object): ...@@ -123,8 +92,6 @@ class Request(object):
class Warden3ServerTest(unittest.TestCase): class Warden3ServerTest(unittest.TestCase):
"""High level Warden3 Server tests""" """High level Warden3 Server tests"""
config = {'log': {'level': 'debug'}, 'validator': {'type': 'NoValidator'}, 'auth': {'type': 'PlainAuthenticator'},
'db': {'user': USER, 'password': PASSWORD, 'dbname': DB}, 'handler': {'description': 'Warden Test Server'}}
getInfo_interface_tests_specific = [ getInfo_interface_tests_specific = [
("/getInfo", "403 I'm watching. Authenticate."), ("/getInfo", "403 I'm watching. Authenticate."),
...@@ -138,22 +105,26 @@ class Warden3ServerTest(unittest.TestCase): ...@@ -138,22 +105,26 @@ class Warden3ServerTest(unittest.TestCase):
("/getEvents?secret=123", "403 I'm watching. Authenticate.", None), ("/getEvents?secret=123", "403 I'm watching. Authenticate.", None),
] ]
@staticmethod
def get_config():
return {
'log': {'level': 'debug'},
'validator': {'type': 'NoValidator'},
'auth': {'type': 'PlainAuthenticator'},
'db': {'type': DBMS.name, 'user': USER, 'password': PASSWORD, 'dbname': DB},
'handler': {'description': 'Warden Test Server'}
}
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
"""Pre-test cleanup""" """Pre-test cleanup"""
cls.clean_lastid() cls.clean_lastid()
cls.app = build_server(cls.config) cls.app = build_server(cls.get_config())
@classmethod @classmethod
def clean_lastid(cls): def clean_lastid(cls):
"""Cleans the lastid information for all clients""" """Cleans the lastid information for all clients"""
conn = my.connect(user=USER, passwd=PASSWORD, db=DB) DBMS.clean_lastid()
cur = conn.cursor()
cur.execute("DELETE FROM events")
cur.execute("DELETE FROM last_events")
cur.close()
conn.commit()
conn.close()
def test_getInfo_interface(self): # pylint: disable = locally-disabled, invalid-name def test_getInfo_interface(self): # pylint: disable = locally-disabled, invalid-name
"""Tests the getInfo method invocation""" """Tests the getInfo method invocation"""
...@@ -245,14 +216,22 @@ class Warden3ServerTest(unittest.TestCase): ...@@ -245,14 +216,22 @@ class Warden3ServerTest(unittest.TestCase):
class X509AuthenticatorTest(Warden3ServerTest): class X509AuthenticatorTest(Warden3ServerTest):
"""Performs the basic test suite using the X509Authenticator""" """Performs the basic test suite using the X509Authenticator"""
config = deepcopy(Warden3ServerTest.config)
@staticmethod
def get_config():
config = Warden3ServerTest.get_config()
config['auth']['type'] = 'X509Authenticator' config['auth']['type'] = 'X509Authenticator'
return config
class X509NameAuthenticatorTest(Warden3ServerTest): class X509NameAuthenticatorTest(Warden3ServerTest):
"""Performs the basic test suite using the X509NameAuthenticator""" """Performs the basic test suite using the X509NameAuthenticator"""
config = deepcopy(Warden3ServerTest.config)
@staticmethod
def get_config():
config = Warden3ServerTest.get_config()
config['auth']['type'] = 'X509NameAuthenticator' config['auth']['type'] = 'X509NameAuthenticator'
return config
getInfo_interface_tests_specific = [ getInfo_interface_tests_specific = [
("/getInfo", "200 OK"), ("/getInfo", "200 OK"),
...@@ -271,8 +250,13 @@ class WScliTest(unittest.TestCase): ...@@ -271,8 +250,13 @@ class WScliTest(unittest.TestCase):
"""Tester of the Warden Server command line interface""" """Tester of the Warden Server command line interface"""
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.config = {'log': {'level': 'debug'}, 'validator': {'type': 'NoValidator'}, 'auth': {'type': 'PlainAuthenticator'}, cls.config = {
'db': {'user': USER, 'password': PASSWORD, 'dbname': DB}, 'handler': {'description': 'Warden Test Server'}} 'log': {'level': 'debug'},
'validator': {'type': 'NoValidator'},
'auth': {'type': 'PlainAuthenticator'},
'db': {'type': DBMS.name, 'user': USER, 'password': PASSWORD, 'dbname': DB},
'handler': {'description': 'Warden Test Server'}
}
warden_server.server = build_server(cls.config) warden_server.server = build_server(cls.config)
@staticmethod @staticmethod
...@@ -298,17 +282,6 @@ class WScliTest(unittest.TestCase): ...@@ -298,17 +282,6 @@ class WScliTest(unittest.TestCase):
sys.argv = argv_backup sys.argv = argv_backup
return ret, out.getvalue(), err.getvalue() return ret, out.getvalue(), err.getvalue()
@staticmethod
def do_sql_select(query, params):
"""Reads data from database"""
conn = my.connect(user=USER, passwd=PASSWORD, db=DB)
cur = conn.cursor()
cur.execute(query, params)
result = cur.fetchall()
cur.close()
conn.close()
return result
def test_list(self): def test_list(self):
"""Tests the list command line option""" """Tests the list command line option"""
tests = [ tests = [
...@@ -392,7 +365,6 @@ class WScliTest(unittest.TestCase): ...@@ -392,7 +365,6 @@ class WScliTest(unittest.TestCase):
(['modify', '-i', 'CLIENT_ID', '--note', 'Valid until:', '20.1.2038'], 2, (['modify', '-i', 'CLIENT_ID', '--note', 'Valid until:', '20.1.2038'], 2,
(('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', 'top_secret', 1, 1, 0, 1, 0, 'Valid until: 18.01.2038'),)), (('a@b', 'test3.warden.cesnet.cz', 'cz.cesnet.warden.test3', 'top_secret', 1, 1, 0, 1, 0, 'Valid until: 18.01.2038'),)),
] ]
test_sql = "SELECT requestor, hostname, name, secret, valid, clients.read, debug, clients.write, test, note FROM clients WHERE id = %s"
client_id = None client_id = None
for supplied_arguments, expected_return, expected_sql_result in tests: for supplied_arguments, expected_return, expected_sql_result in tests:
with self.subTest(supplied_arguments=supplied_arguments, expected_return=expected_return, expected_sql_result=expected_sql_result): with self.subTest(supplied_arguments=supplied_arguments, expected_return=expected_return, expected_sql_result=expected_sql_result):
...@@ -403,21 +375,41 @@ class WScliTest(unittest.TestCase): ...@@ -403,21 +375,41 @@ class WScliTest(unittest.TestCase):
client_id = int(out.split('\n')[-2].split(' ')[0]) client_id = int(out.split('\n')[-2].split(' ')[0])
except IndexError: # No modification was performed, keep the previous client_id except IndexError: # No modification was performed, keep the previous client_id
pass pass
result = self.do_sql_select(test_sql, (client_id,)) result = DBMS.do_sql_select(DBMS.reg_mod_test_query, (client_id,))
self.assertEqual(result, expected_sql_result) self.assertEqual(result, expected_sql_result)
def init_user(): class MySQL:
name = "MySQL"
reg_mod_test_query = "SELECT requestor, hostname, name, secret, valid, clients.read, " \
"debug, clients.write, test, note FROM clients WHERE id = %s"
def __init__(self, user=USER, password=PASSWORD, dbname=DB):
import MySQLdb as my
self.my = my
self.user = user
self.password = password
self.dbname = dbname
def init_user(self):
"""DB user rights setup""" """DB user rights setup"""
conn = None conn = None
try: try:
conn = my.connect(user='root', passwd=getpass.getpass('Enter MySQL Root password:')) conn = self.my.connect(user='root', passwd=getpass.getpass(
'Enter MySQL Root password:'))
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute("CREATE USER IF NOT EXISTS %s@'localhost' IDENTIFIED BY %s", (USER, PASSWORD)) cur.execute(
cur.execute("GRANT SELECT, INSERT, UPDATE, CREATE, DELETE, DROP ON *.* TO %s@'localhost'", (USER,)) "CREATE USER IF NOT EXISTS %s@'localhost' IDENTIFIED BY %s",
(self.user, self.password)
)
cur.execute(
"GRANT SELECT, INSERT, UPDATE, CREATE, DELETE, DROP ON *.* TO %s@'localhost'",
(self.user,)
)
conn.commit() conn.commit()
print("DB User set up successfuly") print("DB User set up successfuly")
except my.OperationalError as ex: except self.my.OperationalError as ex:
if conn: if conn:
conn.rollback() conn.rollback()
conn.close() conn.close()
...@@ -431,16 +423,88 @@ def init_user(): ...@@ -431,16 +423,88 @@ def init_user():
if conn: if conn:
conn.close() conn.close()
def set_up(self):
conn = None
try:
conn = self.my.connect(user=self.user, passwd=self.password)
cur = conn.cursor()
with warnings.catch_warnings(): # The database is not supposed to exist
warnings.simplefilter("ignore")
cur.execute("DROP DATABASE IF EXISTS %s" % (self.dbname,)) # NOT SECURE
cur.execute("CREATE DATABASE %s" % (self.dbname,)) # NOT SECURE
cur.execute("USE %s" % (self.dbname,)) # NOT SECURE
with open(path.join(path.dirname(__file__), 'warden_3.0_mysql.sql')) as script:
statements = ''.join(
[line.replace('\n', '') for line in script if line[0:2] != '--']
).split(';')[:-1]
for statement in statements:
cur.execute(statement)
cur.execute(
"INSERT INTO clients VALUES("
"NULL, NOW(), 'warden-info@cesnet.cz', 'test.server.warden.cesnet.cz',"
"NULL, 1, 'cz.cesnet.warden3test', 'abc', 1, 1, 1, 0"
")"
)
conn.commit()
except self.my.OperationalError as ex:
if conn:
conn.rollback()
conn.close()
conn = None
print('Setup failed, have you tried --init ? Original exception: %s' % (str(ex),))
exit()
finally:
if conn:
conn.close()
def do_sql_select(self, query, params):
"""Reads data from database"""
conn = self.my.connect(user=self.user, passwd=self.password, db=self.dbname)
cur = conn.cursor()
cur.execute(query, params)
result = cur.fetchall()
cur.close()
conn.close()
return result
def tear_down(self):
"""Clean up by purging the test database"""
conn = self.my.connect(user=self.user, passwd=self.password)
cur = conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s" % (self.dbname,)) # NOT SECURE
conn.commit()
conn.close()
def clean_lastid(self):
"""Cleans the lastid information for all clients"""
conn = self.my.connect(user=self.user, passwd=self.password, db=self.dbname)
cur = conn.cursor()
cur.execute("DELETE FROM last_events")
cur.execute("DELETE FROM events")
cur.close()
conn.commit()
conn.close()
database_types = {
'MySQL': MySQL
}
def main(): def main():
"""Parses arguments and acts accordingly""" """Parses arguments and acts accordingly"""
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('-d', '--dbms', default='MySQL', choices=database_types, help='Database management system to use for testing')
parser.add_argument('-i', '--init', action='store_true', help='Set up an user with rights to CREATE/DROP the test database') parser.add_argument('-i', '--init', action='store_true', help='Set up an user with rights to CREATE/DROP the test database')
parser.add_argument('-n', '--nopurge', action='store_true', help='Skip the database purge after running the tests') parser.add_argument('-n', '--nopurge', action='store_true', help='Skip the database purge after running the tests')
args = parser.parse_args() args = parser.parse_args()
global DBMS # pylint: disable = locally-disabled, global-statement
DBMS = database_types[args.dbms](USER, PASSWORD, DB)
if args.init: if args.init:
init_user() DBMS.init_user()
else: else:
if args.nopurge: if args.nopurge:
global NO_PURGE # pylint: disable = locally-disabled, global-statement global NO_PURGE # pylint: disable = locally-disabled, global-statement
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment