Skip to content
Snippets Groups Projects
Commit 9468e8f1 authored by pharook's avatar pharook
Browse files

Initial commit, working LMDB and SQL backend

parents
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import string
import uuid
import datetime
import random
import pprint
import cProfile
import lmdb
import MySQLdb
from keyvalue_db import KeyValueDB
from lmdb_index import LMDBIndex
from sql_db import SQLDB, SQLIndex
from sql_db import DummyEnv
# Notes:
# * gen_random_idea is not complete, only single ip addresses
# * gen_random_idea generates lexically comparable ip adresses (zero padded)
# it's simpler for testing purposes
# (in reality the proper conversion machinery would be necessary)
def __test__():
    """Smoke test: KeyValueDB over an LMDB backend.

    Inserts a handful of documents and exercises the and_/or_/range
    query combinators, printing the results.
    """
    env = lmdb.Environment(
        "lmdb",
        max_dbs=128,
        sync=False,
        writemap=False)
    kvdb = KeyValueDB(
        ("name", "surname", "group.work", "group.home"),
        idxfactory=LMDBIndex, env=env)
    kvdb.clear()
    # Mixed shapes on purpose: lists of names and lists of subdocuments
    # must fan out into the indices.
    documents = [
        {"name": "Beda", "surname": "Travnicek",
         "group": {"work": "ceo", "home": "slave"}},
        {"name": ["Llamar", "Puno"], "surname": "Popovic",
         "group": {"work": "hr", "home": "free"}},
        {"name": "Zuno", "surname": "Popovic",
         "group": [{"work": "hr", "home": "complicated"}, {"work": "pr"}]},
        {"name": "Zuno", "surname": "Cimbal",
         "group": {"work": "nobody", "home": "complicated"}},
    ]
    for doc in documents:
        kvdb.insert(doc)
    pprint.pprint(kvdb.dump())
    # Conjunction of two equality lookups, ordered by surname.
    pprint.pprint(kvdb.query(
        kvdb.and_(
            kvdb.eq("group.work", "hr"),
            kvdb.eq("group.home", "complicated")),
        "surname"))
    # Disjunction of the same two lookups.
    pprint.pprint(kvdb.query(
        kvdb.or_(
            kvdb.eq("group.work", "hr"),
            kvdb.eq("group.home", "complicated")),
        "surname"))
    # Lexical range query on an indexed field, ordered by name.
    pprint.pprint(kvdb.query(
        kvdb.range("surname", "O", "S"),
        "name"))
def insert_idea(db, num=10):
    """Insert `num` freshly generated random IDEA events into `db`."""
    for _ in range(num):
        db.insert(gen_random_idea())
def ip4tolex(ipstr):
    """Zero-pad each octet of a dotted-quad so the strings sort lexically."""
    octets = ipstr.split(".")
    return ".".join("%03i" % int(octet) for octet in octets)
def ip6tolex(ipstr):
    """Zero-pad IPv6 groups so the strings sort lexically.

    BUGFIX/generalization: the original crashed on "::"-compressed
    addresses (int("", 16) raises ValueError). The compression is now
    expanded to the full eight groups first; already-full addresses
    behave exactly as before.
    """
    if "::" in ipstr:
        head, _sep, tail = ipstr.partition("::")
        left = head.split(":") if head else []
        right = tail.split(":") if tail else []
        groups = left + ["0"] * (8 - len(left) - len(right)) + right
    else:
        groups = ipstr.split(":")
    return ":".join("%04x" % int(g, 16) for g in groups)
def gen_random_idea(client_name="cz.example.warden.test"):
    """Build one pseudo-random IDEA-format event dictionary.

    client_name -- value for Node[0].Name (detector identity).
    Note: only single zero-padded IP addresses are generated (see the
    module notes); timestamps carry a fixed "+02:00" offset.
    """

    def stamp():
        # Naive local time with a hard-coded offset; fine for test data.
        return datetime.datetime.now().isoformat() + "+02:00"

    def random_ip4():
        # Zero-padded so addresses stay lexically comparable.
        return "192.000.002.%03i" % random.randint(1, 254)

    def random_ip6():
        suffix = ":".join("%04x" % random.randint(0, 65535) for _ in range(6))
        return "2001:0db8:%s" % suffix

    categories = ["Abusive.Spam", "Abusive.Harassment", "Malware",
                  "Fraud.Copyright", "Test", "Fraud.Phishing", "Fraud.Scam"]
    node_types = ["Data", "Protocol", "Honeypot", "Heuristic", "Log"]

    # Built field by field in the same order as the original dict literal
    # so the sequence of random draws is preserved.
    event = {}
    event["Format"] = "IDEA0"
    event["ID"] = str(uuid.uuid4())
    event["CreateTime"] = stamp()
    event["DetectTime"] = stamp()
    event["Category"] = [random.choice(categories)
                         for _ in range(random.randint(1, 3))]
    event["Note"] = "Random event"
    event["ConnCount"] = random.randint(0, 65535)
    event["Source"] = [{
        "Type": ["Phishing"],
        "IP4": [random_ip4() for _ in range(random.randrange(1, 5))],
        "IP6": [random_ip6() for _ in range(random.randrange(1, 5))],
        "Hostname": ["example.com"],
        "Port": [random.randint(1, 65535) for _ in range(random.randrange(1, 3))],
    }]
    event["Target"] = [{
        "IP4": [random_ip4() for _ in range(random.randrange(1, 5))],
        "IP6": [random_ip6() for _ in range(random.randrange(1, 5))],
        "Proto": ["tcp", "http"],
    }]
    event["Node"] = [{
        "Name": client_name,
        "Type": [random.choice(node_types)
                 for _ in range(random.randint(1, 3))],
        "SW": ["Kippo"],
    }]
    return event
def __main__():
    """Ad-hoc benchmark/demo entry point.

    Runs one range-and-equality query against whichever backend is
    uncommented below (currently the live MySQL backend) and prints the
    number of matching rows. The commented-out sections select the LMDB
    or the dummy (print-only) SQL backend instead.
    """
    # lmdb
    #~ env = lmdb.Environment(
    #~ "lmdb",
    #~ map_size=1024*1024*1024*1024, # 1TiB
    #~ max_dbs=128,
    #~ sync=False,
    #~ writemap=False)
    #~ idxfactory = LMDBIndex
    #~ dbfactory = KeyValueDB
    #~ # dummy sql
    #~ env = DummyEnv()
    #~ idxfactory = SQLIndex
    #~ dbfactory = SQLDB
    # sql
    # NOTE(review): database credentials are hard-coded in source; move
    # them to a config file or environment variables before this leaves
    # a throwaway test setup.
    env = MySQLdb.connect(
        host='localhost',
        user='root',
        passwd='l3nivec',
        db='dumbdb')
    idxfactory = SQLIndex
    dbfactory = SQLDB
    # Indexed dotted paths; IP columns rely on the zero-padded lexical
    # form produced by gen_random_idea/ip4tolex.
    db = dbfactory(
        ("ID", "DetectTime", "Category", "Node.Name",
        "Source.IP4", "Source.IP6", "Target.IP4", "Target.IP6"),
        idxfactory=idxfactory, env=env)
    #~ db.clear()
    #~ return
    #~ insert_idea(db, 80000)
    #~ env.commit()
    #~ pprint.pprint(db.dump())
    #~ return
    #~ res = db.query(
    #~ db.and_(
    #~ db.range("Source.IP4", '192.000.002.000', '192.000.002.255'),
    #~ db.range("Target.IP4", "192.000.002.000", "192.000.002.255"),
    #~ db.eq("Node.Name", "cz.example.warden.test")
    #~ ),
    #~ "DetectTime"
    #~ )
    # Paged query: top 30 rows ordered by detection time.
    res = db.query(
        db.and_(
            db.range("Source.IP4", '192.000.002.128', '192.000.002.255'),
            db.range("Target.IP4", "192.000.002.120", "192.000.002.255"),
            db.eq("Node.Name", "cz.example.warden.test")
        ),
        order="DetectTime",
        skip=0,
        limit=30
        #~ order=None
    )
    print len(res)
    #~ pprint.pprint(res)

# Module runs the benchmark on import/execution (no __name__ guard).
#~ cProfile.run("__main__()", sort="cumulative")
__main__()
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import collections
import cPickle
class Index(object):
    """Abstract base for one named index living inside a storage `env`.

    Concrete backends (LMDBIndex, SQLIndex) implement insert and the
    query_* primitives; every abstract method raises NotImplementedError.

    BUGFIX: the original used `raise NotImplemented` -- NotImplemented is
    a comparison-protocol singleton, not an exception, so raising it
    produces a confusing TypeError instead of NotImplementedError.
    """

    def __init__(self, name, env):
        self.name = name    # index/table name within the environment
        self.env = env      # backend environment (lmdb.Environment, DB connection, ...)

    def insert(self, key=None, value=None):
        raise NotImplementedError

    def query_eq_all(self, key):
        raise NotImplementedError

    def query_eq(self, key):
        raise NotImplementedError

    def query_ge(self, key):
        raise NotImplementedError

    def query_le(self, key):
        raise NotImplementedError

    def query_range(self, key1, key2):
        raise NotImplementedError

    def clear(self):
        raise NotImplementedError

    def dump(self):
        raise NotImplementedError
class DB(object):
    """Abstract document database: pickled documents plus secondary indices.

    For every dotted path in `indices` up to two index structures exist:
      rev.<path>  value -> document id  (duplicates allowed; for lookups)
      fwd.<path>  document id -> value  (for ordering query results)
    Subclasses toggle the `rev`/`fwd` class attributes and implement the
    query combinators (and_, or_, query).
    """

    # Class-level switches; subclasses override (e.g. SQLDB sets rev=False).
    rev = True
    fwd = True

    def __init__(self, indices, idxfactory, rev=True, fwd=True, *args, **kwargs):
        # NOTE(review): the `rev`/`fwd` parameters are accepted but never
        # consulted -- only the class attributes are. Kept as-is because
        # assigning them here would break subclasses that rely on their
        # class-attribute overrides.
        self.indices = indices
        self.env = kwargs.get("env")
        if self.rev:
            self.revkeys = {}
        if self.fwd:
            self.fwdkeys = {}
        for key in self.indices:
            if self.rev:
                self.revkeys[key] = idxfactory("rev." + key, dup=True, *args, **kwargs)
            if self.fwd:
                self.fwdkeys[key] = idxfactory("fwd." + key, dup=False, *args, **kwargs)
        # Main store: unique id -> pickled document.
        self.data = idxfactory("__data__", dup=False, *args, **kwargs)

    def insert(self, data):
        """Pickle `data` under a fresh id and index every configured path."""
        uniq = self.data.insert(None, cPickle.dumps(data))[0]
        for key in self.indices:
            values = self.get_value(data, key.split("."))
            for value in values:
                if self.rev:
                    self.revkeys[key].insert(value, uniq)
                if self.fwd:
                    self.fwdkeys[key].insert(uniq, value)

    def and_(self, *q):
        # BUGFIX (here and below): was `raise NotImplemented`, which raises
        # TypeError because NotImplemented is not an exception.
        raise NotImplementedError

    def or_(self, *q):
        raise NotImplementedError

    def query(self, q, order=None, reverse=False, skip=0, limit=1):
        raise NotImplementedError

    def get_value(self, data, path):
        """Return the set of values found at dotted `path` inside `data`.

        Lists along the path fan out (each element is descended into);
        a missing key or a too-shallow document yields an empty set.
        """
        if not path:
            if isinstance(data, set):
                return data
            else:
                return set([data])
        key = path[0]
        try:
            subdata = data[key]
        except (KeyError, TypeError):
            # TypeError: hit a scalar before the path was exhausted.
            return set()
        if isinstance(subdata, collections.Sequence) and not isinstance(subdata, basestring):
            res = set()
            for v in subdata:
                res |= self.get_value(v, path[1:])
        else:
            # BUGFIX: recurse on the value itself. The original wrapped it
            # in set([subdata]), which crashed on unhashable subdocuments
            # (dicts) and prevented descending into nested mappings such
            # as {"group": {"work": ...}}.
            res = self.get_value(subdata, path[1:])
        return res

    def dump(self):
        """Merge the dumps of the data store and every index (debugging)."""
        res = {}
        res.update(self.data.dump())
        if self.rev:
            for keyobj in self.revkeys.itervalues():
                res.update(keyobj.dump())
        if self.fwd:
            for keyobj in self.fwdkeys.itervalues():
                res.update(keyobj.dump())
        return res

    def clear(self):
        """Empty every index and the document store itself."""
        if self.rev:
            for keyobj in self.revkeys.itervalues():
                keyobj.clear()
        if self.fwd:
            for keyobj in self.fwdkeys.itervalues():
                keyobj.clear()
        self.data.clear()

    def _op(self, op, key, *args, **kwargs):
        # Dispatch a query primitive to the reverse index when available,
        # otherwise to the forward one.
        idxset = self.revkeys if self.rev else self.fwdkeys
        return getattr(idxset[key], op)(*args, **kwargs)

    def eq(self, key, data):
        return self._op("query_eq_all", key, data)

    def le(self, key, data):
        return self._op("query_le", key, data)

    def ge(self, key, data):
        return self._op("query_ge", key, data)

    def range(self, key, data1, data2):
        return self._op("query_range", key, data1, data2)
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import cPickle
from dumb_db import DB
class KeyValueDB(DB):
    """DB implementation whose query primitives return plain sets of ids.

    and_/or_ are set intersection/union; ordering happens in Python via
    the forward (id -> value) indices.
    """
    rev = True
    fwd = True

    def and_(self, *q):  # FIXME - intersection knows more
        """Intersect the id sets in `q` without mutating any operand.

        BUGFIX: the original used `res &= other`, which updates q[0]
        in place and clobbered the caller's query result set; it also
        shadowed the parameter `q` with the loop variable.
        """
        res = q[0]
        for other in q[1:]:
            res = res & other
        return res

    def or_(self, *q):
        """Union of the id sets in `q`; operands are left untouched."""
        res = q[0]
        for other in q[1:]:
            res = res | other
        return res

    def query(self, q, order=None, reverse=False, skip=0, limit=1):
        """Resolve the id set `q` to unpickled documents.

        order      -- index name whose fwd mapping provides the sort key
        reverse    -- descending order when True
        skip/limit -- slice applied after ordering
        """
        if order is not None:
            res = sorted(q, key=self.fwdkeys[order].query_eq, reverse=reverse)
        else:
            # BUGFIX: materialize to a list -- `q` is a set and sets
            # cannot be sliced below.
            res = list(q)
        if skip or limit:
            # Note that slicing makes copy, so this will lose efficiency
            # with big limits.
            # We could use itertools.islice, but it considers input as
            # a dumb iterator and will lose efficiency for big skips.
            # Tough call, but let's assume big skip and small limit
            # is more common.
            res = res[skip:skip+limit]
        return [cPickle.loads(self.data.query_eq(v)) for v in res]
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import lmdb
import random
from dumb_db import Index
class LMDBIndex(Index):
    """Index stored in one named LMDB sub-database.

    `dup=True` opens the database with dupsort (multiple values per key),
    which is how the reverse value->id indices are kept.
    """

    def __init__(self, name, env, dup=False):
        Index.__init__(self, name, env)
        self.dup = dup
        # One named sub-database per index inside the shared environment.
        self.handle = self.env.open_db(self.name, dupsort=self.dup)

    def insert(self, key=None, value=None):
        """Store (key, value); a random 64-bit hex key is generated when
        key is None. Returns the (key, value) pair actually stored."""
        if key is None:
            key = "%016x" % random.randint(0, 2**64-1) # may not be safe enough
        with self.env.begin(buffers=False, write=True) as txn:
            txn.put(key, value, db=self.handle, dupdata=self.dup, overwrite=True)
        return key, value

    # Note here, that we use set(iterator) construct in hope that
    # fast C implementation of lmdb cursor iterator connects directly
    # to the fast C implementation of the Python set(), without any
    # interpreted code intervention. And, as it seems, it does.
    def query_eq_all(self, key):
        # Set of all values stored under `key` (meaningful for dupsort DBs).
        with self.env.begin(buffers=False) as txn:
            with txn.cursor(db=self.handle) as crs:
                crs.set_key(key)
                return set(crs.iternext_dup(keys=False, values=True))

    def query_eq(self, key):
        # Single value stored under `key` (txn.get returns None when absent).
        with self.env.begin(buffers=False) as txn:
            return txn.get(key, db=self.handle)

    def query_ge(self, key):
        # set_range positions on the first key >= `key`; iterate to the end.
        with self.env.begin(buffers=False) as txn:
            with txn.cursor(db=self.handle) as crs:
                crs.set_range(key)
                return set(crs.iternext(keys=False, values=True))

    def query_le(self, key):
        # NOTE(review): set_range lands on the first key >= `key` and the
        # next() call below then skips that first yielded entry, so a key
        # exactly equal to `key` appears to be excluded -- this looks like
        # "<" rather than "<=". Verify against py-lmdb iterprev semantics.
        # (`it.next()` is Python 2 iterator protocol.)
        with self.env.begin(buffers=False) as txn:
            with txn.cursor(db=self.handle) as crs:
                crs.set_range(key)
                it = crs.iterprev(keys=False, values=True)
                it.next()
                return set(it)

    def query_range(self, key1, key2):
        # Not quite correct
        return self.query_ge(key1) & self.query_le(key2)

    def clear(self):
        # Empty the sub-database but keep the handle alive (delete=False).
        with self.env.begin(buffers=True, write=True) as txn:
            txn.drop(db=self.handle, delete=False)

    def dump(self):
        """Return {index_name: {key: value-or-list-of-values}} for debugging;
        dupsort databases collect each key's values into a list."""
        res = {}
        with self.env.begin(buffers=False) as txn:
            with txn.cursor(db=self.handle) as crs:
                crs.first()
                for key, value in crs.iternext(keys=True, values=True):
                    if self.dup:
                        res.setdefault(key, []).append(value)
                    else:
                        res[key] = value
        return {self.name: res}
sql_db.py 0 → 100644
#!/usr/bin/python
# -*- encoding: utf-8 -*-
from collections import namedtuple
from dumb_db import DB, Index
class DummyCursor(object):
    """Stand-in DB-API cursor: prints the SQL instead of executing it.

    FIX: `print q` statement turned into the function-call form, which
    behaves identically on Python 2 for a single argument and keeps the
    module importable under Python 3.
    NOTE(review): the parameterized form execute(sql, args) used by
    SQLIndex.insert is not supported by this dummy.
    """

    def execute(self, q):
        # Just show what would have been sent to the database.
        print(q)

    def fetchall(self):
        # There is no backend, so never any rows.
        return []

    def close(self):
        pass
class DummyEnv(object):
    """Fake database environment for dry runs.

    Satisfies the env interface SQLDB/SQLIndex expect (cursor, commit,
    escape) while only printing the SQL that would have been executed.
    """

    def __init__(self):
        self.crs = DummyCursor()

    def commit(self):
        # No real transaction, so nothing to commit.
        pass

    def escape(self, s):
        # Naive quoting; only adequate for the printed dummy output.
        return '"%s"' % s

    def cursor(self):
        # Fresh throwaway cursor per request, mirroring DB-API usage.
        return DummyCursor()
# A WHERE-clause fragment (`query`) paired with the set of index tables it
# references (`tables`); produced by SQLIndex.query_* and composed by SQLDB.
SubqueryTuple = namedtuple("SubqueryTuple", ["tables", "query"])
class SQLIndex(Index):
    """Index backed by one two-column SQL table `name` (id, val).

    The query_* methods do not hit the database; they return SubqueryTuple
    fragments that SQLDB composes into a single SELECT.

    NOTE(review): table names are interpolated directly into the SQL text;
    this is only safe while index names come from code, never from user
    input. Values go through env.escape (query fragments) or bound
    parameters (insert).
    """

    def __init__(self, name, env, dup=False):
        # `dup` is accepted for interface parity with LMDBIndex; SQL KEYs
        # allow duplicate values anyway, so it is not consulted here.
        Index.__init__(self, name, env)
        crs = self.env.cursor()
        crs.execute(
            """CREATE TABLE IF NOT EXISTS `%s` (
            `id` int(11) NOT NULL AUTO_INCREMENT,
            `val` longtext NOT NULL,
            KEY (`id`),
            KEY (`val`(255)),
            KEY (`id`, `val`(255))
            )""" % self.name)
        crs.close()
        self.env.commit()

    def insert(self, key=None, value=None):
        """Insert (key, value); with key None the AUTO_INCREMENT id is used.

        Returns the (key, value) pair actually stored. The values travel
        as bound parameters, so no manual escaping is needed here.
        """
        crs = self.env.cursor()
        crs.execute(
            "INSERT INTO `%s` (`id`, `val`) VALUES (%%s, %%s)" % self.name,
            (key, value))
        if key is None:
            key = crs.lastrowid
        crs.close()
        return key, value

    def query_eq_all(self, key):
        return SubqueryTuple(
            tables = set([self.name]),
            query = "`%s`.val = %s" % (self.name, self.env.escape(key))
        )
    # Equality is the same SQL fragment whether one or all matches are wanted.
    query_eq = query_eq_all

    def query_ge(self, key):
        return SubqueryTuple(
            tables = set([self.name]),
            query = "`%s`.val >= %s" % (self.name, self.env.escape(key))
        )

    def query_le(self, key):
        # BUGFIX: this emitted `>=` (copy-paste from query_ge), turning
        # every upper bound into a second lower bound.
        return SubqueryTuple(
            tables = set([self.name]),
            query = "`%s`.val <= %s" % (self.name, self.env.escape(key))
        )

    def query_range(self, key1, key2):
        # BETWEEN is inclusive on both ends.
        return SubqueryTuple(
            tables = set([self.name]),
            query = "`%s`.val BETWEEN %s AND %s" % (self.name, self.env.escape(key1), self.env.escape(key2))
        )

    def clear(self):
        # Drops the whole table; __init__ recreates it on next construction.
        crs = self.env.cursor()
        crs.execute("DROP TABLE `%s`" % self.name)
        crs.close()
        self.env.commit()

    def dump(self):
        """Return {table_name: {id: val}} with the full table contents."""
        res = {}
        crs = self.env.cursor()
        crs.execute("SELECT `id`, `val` FROM `%s`" % self.name)
        for key, value in crs.fetchall():
            res[key] = value
        crs.close()
        self.env.commit()
        return {self.name: res}
class SQLDB(DB):
    """DB flavour that compiles a whole query into one SQL SELECT.

    Only forward (id -> value) index tables are kept (rev=False); lookups
    are expressed in the WHERE clause instead of reverse indices.
    """
    rev = False
    fwd = True

    def _nary_op(self, op, *q):
        # Combine SubqueryTuples: union of their table sets, WHERE fragments
        # joined with `op` ("AND"/"OR"), each wrapped in parentheses.
        return SubqueryTuple(
            tables = set.union(*(single[0] for single in q)),
            query = (" %s " % op).join(("(%s)\n " % single[1] for single in q))
        )

    def and_(self, *q):
        return self._nary_op("AND", *q)

    def or_(self, *q):
        return self._nary_op("OR", *q)

    def _gen_join(self, name):
        # Each referenced index table is joined to __data__ on the shared id.
        return "INNER JOIN `%s` USING (id)" % name

    def query(self, q, order=None, reverse=False, skip=0, limit=1):
        """Execute SubqueryTuple `q` and return cursor.fetchall() rows.

        order      -- index name; its table's val column drives ORDER BY
        reverse    -- DESC instead of ASC
        skip/limit -- rendered as LIMIT/OFFSET (skipped when both falsy)
        Rows hold the `val` column as stored by DB.insert (a pickled
        document); callers are expected to unpickle -- TODO confirm.
        """
        base = "SELECT DISTINCT `__data__`.val\n FROM `__data__`\n %s\n WHERE\n %s %s %s"
        if order is not None:
            order_name = self.fwdkeys[order].name
            order_by = "ORDER BY `%s`.val %s" % (
                order_name,
                'DESC' if reverse else 'ASC'
            )
            # NOTE(review): mutates the caller's SubqueryTuple table set.
            q.tables.add(order_name)
        else:
            order_by = ""
        if skip or limit:
            limit_offset = "LIMIT %i OFFSET %i" % (limit, skip)
        else:
            limit_offset = ""
        join = "\n ".join((self._gen_join(idx) for idx in q.tables))
        query = base % (join, q.query, order_by, limit_offset)
        crs = self.env.cursor()
        # NOTE(review): debug leftover -- prints every executed statement.
        print query
        crs.execute(query)
        res = crs.fetchall()
        crs.close()
        self.env.commit()
        return res
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment