diff --git a/datatest.py b/datatest.py
index 87b7ce27c7c9759a40e755dbf0f942b5bd2f79d2..ee5d41783d68b125020665f3607968b0dac1e50b 100755
--- a/datatest.py
+++ b/datatest.py
@@ -60,47 +60,110 @@ def __main__():
         "Source.Port", "Target.Port"),
         idxfactory=idxfactory, env=env)
-    #~ db.clear()
-    #~ return
-    #~ pprint.pprint(db.dump())
-    #~ return
-
-    # Import
-    for i, l in enumerate(mongo_idea_import.get_events(sys.stdin)):
-        if not i%1000:
-            print i
-            sys.stdout.flush()
-        db.insert(l, mongo_idea_import.json_default)
-
-    res = db.query(
-        db.and_(
-            db.range("Target.IP4.ip", '195.113.000.000', '195.113.255.255'),
-            db.range("Source.IP4.ip", "071.006.165.000", "071.006.165.255"),
-            db.eq("Node.Name", "cz.cesnet.mentat.warden_filer"),
-            db.eq("Node.Type", "Relay"),
-            db.eq("Target.Port", " 22")
-        ),
-        order="DetectTime",
-        skip=0,
-        limit=30
-    )
-
-    #~ res = db.query(
-        #~ db.eq("Source.IP4.ip", "071.006.165.200"),
-        #~ order="DetectTime",
-        #~ skip=0,
-        #~ limit=None
-    #~ )
-
-    #~ res = db.query(
-        #~ db.range("Source.IP4.ip", "000.000.000.000", "255.255.255.255"),
-        #~ order="DetectTime",
-        #~ skip=0,
-        #~ limit=None
-    #~ )
-
-    #~ pprint.pprint(res)
-    print len(res)
-
-#~ cProfile.run("__main__()", sort="cumulative")
+    if 0:
+        with db.transaction(write=True) as txn:
+            #~ db.clear()
+            #~ return
+            #~ pprint.pprint(db.dump())
+            #~ return
+
+            # Import
+            for i, l in enumerate(mongo_idea_import.get_events(sys.stdin)):
+                if not i%1000 and not i==0:
+                    print i
+                    sys.stdout.flush()
+                db.insert(txn, l, mongo_idea_import.json_default)
+
+
+    with db.transaction() as txn:
+        res = []
+
+        if 1:
+            res = db.query(txn,
+                db.range(txn, "Source.IP4.ip", "000.000.000.000", "255.255.255.255"),
+                order="DetectTime",
+                #order=None,
+                skip=0,
+                limit=30
+            )
+
+        if 0:
+            res = db.query(txn,
+                db.and_(txn,
+                    db.range(txn, "DetectTime", '2016-06-05T22:00:00Z', '2016-06-05T22:00:04Z'),
+                    db.range(txn, "Source.IP4.ip", "000.000.000.000", "255.255.255.255")
+                ),
+                order="DetectTime",
+                #order=None,
+                skip=0,
+                limit=30
+            )
+
+        if 0:
+            res = db.query(txn,
+                db.and_(txn,
+                    db.range(txn, "Target.IP4.ip", '195.113.000.000', '195.113.255.255'),
+                    db.range(txn, "Source.IP4.ip", "071.006.165.000", "071.006.165.255"),
+                    db.eq(txn, "Node.Name", "cz.cesnet.mentat.warden_filer"),
+                    db.eq(txn, "Node.Type", "Relay"),
+                    db.eq(txn, "Target.Port", " 22")
+                ),
+                order="DetectTime",
+                skip=0,
+                limit=30
+            )
+
+        if 0:
+            res = db.query(txn,
+                db.and_(txn,
+                    db.range(txn, "Target.IP4.ip", '147.230.000.000', '147.230.255.255'),
+                    db.eq(txn, "Category", "Recon.Scanning"),
+                    db.range(txn, "DetectTime", "2016-06-04T22:00:19Z", "2016-06-05T23:00:19Z"),
+                    db.eq(txn, "Node.Name", "cz.tul.ward.dionaea"),
+                    db.or_(txn,
+                        db.eq(txn, "Target.Port", " 23"),
+                        db.eq(txn, "Target.Port", " 6666")
+                    )
+                ),
+                order="DetectTime",
+                skip=0,
+                limit=30
+            )
+
+        if 0:
+            res = db.query(txn,
+                db.eq(txn, "Source.IP4.ip", "071.006.165.200"),
+                order="DetectTime",
+                skip=0,
+                limit=30
+            )
+
+        if 0:
+            res = db.query(txn,
+                db.range(txn, "Source.IP4.ip", "195.113.000.000", "195.113.255.255"),
+                order="DetectTime",
+                skip=0,
+                limit=30
+            )
+
+        if 0:
+            res = db.query(txn,
+                db.range(txn, "Target.IP4.ip", "195.113.000.000", "195.113.255.255"),
+                order="DetectTime",
+                #order=None,
+                skip=0,
+                limit=30
+            )
+
+        if res:
+            #~ pprint.pprint(res)
+            pprint.pprint(res[0])
+            pprint.pprint(res[-1])
+            print len(res)
+
+        if 0:
+            pprint.pprint(db.dump(txn))
+
+
+#cProfile.run("__main__()", sort="cumulative")
 
 __main__()
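The reworked driver above reduces to one writable transaction for the whole import and a separate read-only transaction for querying. A minimal usage sketch of that flow, assuming `db` is a KeyValueDB wired up with the same indices, idxfactory and env that datatest.py builds (that setup is elided here, and the event content is made up):

    # Sketch only: `db` is assumed to be constructed as in datatest.py above.
    with db.transaction(write=True) as txn:
        # lmdb allows a single writer, so the whole import shares one txn
        db.insert(txn, {"Node": [{"Name": "org.example.node", "Type": "Relay"}]}, None)

    with db.transaction() as txn:
        # subqueries take the txn explicitly and return SubqueryTuples
        res = db.query(txn,
            db.eq(txn, "Node.Type", "Relay"),
            order="DetectTime",
            skip=0,
            limit=10
        )
        print len(res)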
diff --git a/dumb_db.py b/dumb_db.py
index 2e1db6b87c797e442aecc95c269dcf2b2c682fde..b709010d54e0d0d7291a87630232ba0c8de511c7 100644
--- a/dumb_db.py
+++ b/dumb_db.py
@@ -3,6 +3,7 @@
 import collections
 import json
+import zlib
 
 
 class Index(object):
@@ -11,36 +12,36 @@ class Index(object):
         self.env = env
 
 
-    def insert(self, key=None, value=None):
-        raise NotImplemented
+    def insert(self, txn, key=None, value=None):
+        raise NotImplementedError
 
 
-    def query_eq_all(self, key):
-        raise NotImplemented
+    def query_eq_all(self, txn, key):
+        raise NotImplementedError
 
 
-    def query_eq(self, key):
-        raise NotImplemented
+    def query_eq(self, txn, key):
+        raise NotImplementedError
 
 
-    def query_ge(self, key):
-        raise NotImplemented
+    def query_ge(self, txn, key):
+        raise NotImplementedError
 
 
-    def query_le(self, key):
-        raise NotImplemented
+    def query_le(self, txn, key):
+        raise NotImplementedError
 
 
-    def query_range(self, key1, key2):
-        raise NotImplemented
+    def query_range(self, txn, key1, key2):
+        raise NotImplementedError
 
 
-    def clear(self):
-        raise NotImplemented
+    def clear(self, txn):
+        raise NotImplementedError
 
 
-    def dump(self):
-        raise NotImplemented
+    def dump(self, txn):
+        raise NotImplementedError
@@ -73,28 +74,28 @@ class DB(object):
         self.data = idxfactory("__data__", dup=False, *args, **kwargs)
 
 
-    def insert(self, data, json_default=None):
-        uniq = self.data.insert(None, json.dumps(data, ensure_ascii = True, default = json_default))[0]
+    def insert(self, txn, data, json_default=None):
+        uniq = self.data.insert(txn, None, zlib.compress(json.dumps(data, ensure_ascii = True, default = json_default)))[0]
         for key in self.indices:
-            values = self.get_value(data, key.split("."))
+            values = sorted(self.get_value(data, key.split(".")))
             for value in values:
                 bin_value = self.binarize_str(value)
                 if self.rev:
-                    self.revkeys[key].insert(bin_value, uniq)
+                    self.revkeys[key].insert(txn, bin_value, uniq)
                 if self.fwd:
-                    self.fwdkeys[key].insert(uniq, bin_value)
+                    self.fwdkeys[key].insert(txn, uniq, bin_value)
 
 
-    def and_(self, *q):
-        raise NotImplemented
+    def and_(self, txn, *q):
+        raise NotImplementedError
 
 
-    def or_(self, *q):
-        raise NotImplemented
+    def or_(self, txn, *q):
+        raise NotImplementedError
 
 
-    def query(self, q, order=None, reverse=False, skip=0, limit=1):
-        raise NotImplemented
+    def query(self, txn, q, order=None, reverse=False, skip=0, limit=1):
+        raise NotImplementedError
 
 
     def get_value(self, data, path):
@@ -119,19 +120,19 @@ class DB(object):
         return res
 
 
-    def dump(self):
+    def dump(self, txn):
         res = {}
-        res.update(self.data.dump())
+        res.update(self.data.dump(txn))
         if self.rev:
             for keyobj in self.revkeys.itervalues():
-                res.update(keyobj.dump())
+                res.update(keyobj.dump(txn))
         if self.fwd:
             for keyobj in self.fwdkeys.itervalues():
-                res.update(keyobj.dump())
+                res.update(keyobj.dump(txn))
         return res
 
 
-    def clear(self):
+    def clear(self, txn):
         if self.rev:
             for keyobj in self.revkeys.itervalues():
                 keyobj.clear()
@@ -141,22 +142,22 @@
         self.data.clear()
 
 
-    def _op(self, op, key, *args, **kwargs):
+    def _op(self, txn, op, key, *args, **kwargs):
         idxset = self.revkeys if self.rev else self.fwdkeys
-        return getattr(idxset[key], op)(*args, **kwargs)
+        return getattr(idxset[key], op)(txn, *args, **kwargs)
 
 
-    def eq(self, key, data):
-        return self._op("query_eq_all", key, data)
+    def eq(self, txn, key, data):
+        return self._op(txn, "query_eq_all", key, data)
 
 
-    def le(self, key, data):
-        return self._op("query_le", key, data)
+    def le(self, txn, key, data):
+        return self._op(txn, "query_le", key, data)
 
 
-    def ge(self, key, data):
-        return self._op("query_ge", key, data)
+    def ge(self, txn, key, data):
+        return self._op(txn, "query_ge", key, data)
 
 
-    def range(self, key, data1, data2):
-        return self._op("query_range", key, data1, data2)
+    def range(self, txn, key, data1, data2):
+        return self._op(txn, "query_range", key, data1, data2)
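The DB.insert change above means every record is JSON-serialized and then zlib-compressed before it hits the __data__ index, so every query path has to undo both steps. The round trip in isolation, as a standalone sketch (the event content is made up):

    import json
    import zlib

    event = {"Category": ["Recon.Scanning"], "Source": [{"IP4": ["198.51.100.7"]}]}

    raw = json.dumps(event, ensure_ascii=True)
    blob = zlib.compress(raw)          # what insert() stores
    print "raw %d bytes, stored %d bytes" % (len(raw), len(blob))

    assert json.loads(zlib.decompress(blob)) == event  # what the query_* methods undo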
self._op("query_ge", key, data) + def ge(self, txn, key, data): + return self._op(txn, "query_ge", key, data) - def range(self, key, data1, data2): - return self._op("query_range", key, data1, data2) + def range(self, txn, key, data1, data2): + return self._op(txn, "query_range", key, data1, data2) diff --git a/keyvalue_db.py b/keyvalue_db.py index 5be22c07b0aad163022d49c12c5d4884ceb85fae..e7e1b4f06375e2ec4ae6de5c7fa9bd0451b3c547 100644 --- a/keyvalue_db.py +++ b/keyvalue_db.py @@ -2,34 +2,139 @@ # -*- encoding: utf-8 -*- import json +import zlib +import collections +import heapq +import pandas +import itertools +import operator from dumb_db import DB +set_factory = set +set_types = (set, pandas.Index) + +# We are returning iterators from subqueries, wherever we can, and +# iterators may be read event at the very end of processing. So we +# have to also cache the corresponding opened cursors, so they are +# not closed by lmdb sooner than we need them. +SubqueryTuple = collections.namedtuple("SubqueryTuple", ["cursors", "result"]) class KeyValueDB(DB): - # This is very naïve implementation. Users would benefit - # from exposing of underlying transaction api in some form. + # This is very naïve implementation. rev = True fwd = True - def and_(self, *q): - return set.intersection(*q) + def and_(self, txn, *q): + result = q[0].result + cursors = q[0].cursors + #if not isinstance(result, set): + if not isinstance(result, set_types): + print "creating set" + result = set_factory(result) + print "set len %i" % len(result) + for next in q[1:]: + print "and" + cursors = cursors.union(next.cursors) + result = result.intersection(next.result) + print len(result) + return SubqueryTuple( + cursors = cursors, + result = result + ) + + + def or_(self, txn, *q): + print "or" + result = q[0].result + cursors = q[0].cursors + if not isinstance(result, set_types): + print "creating set" + result = set_factory(result) + print "set len %i" % len(result) + for next in q[1:]: + print "and", + cursors = cursors.union(next.cursors) + result = result.union(next.result) + print len(result) + return SubqueryTuple( + cursors = cursors, + result = result + ) + + + def transaction(self, write=False): + return self.env.begin(buffers=False, write=write) + + + def query_heap(self, txn, q, order=None, reverse=False, skip=0, limit=1): + if order is not None: + key_func = self.fwdkeys[order].get_key_func(txn) + print "decorating" + res = [(key_func(val), val) for val in q.result] + print "heapify" + heapq.heapify(res) + print "pop, undecorate" + res = [heapq.heappop(res)[1] for i in range(skip+limit)] + else: + res = list(q.result) + print "res len %d" % len(res) + if skip or limit: + res = res[skip:skip+limit] + print "res limited len %d" % len(res) + + print "loading data" + return [json.loads(self.data.query_eq(txn, v)) for v in res] + #return [self.data.query_eq(v) for v in res] + + + def query_basesort(self, txn, q, order=None, reverse=False, skip=0, limit=1): + if order is not None: + print "sorting" + key_func = self.fwdkeys[order].get_key_func(txn) + #res = sorted(q.result, key=lambda k: self.fwdkeys[order].query_eq(txn, k), reverse=reverse) + res = sorted(q.result, key=key_func, reverse=reverse) + else: + res = list(q.result) + print "res len %d" % len(res) + if skip or limit: + # Note that slicing makes copy, so this will lose efficiency + # with big limits. + # We could use itertools.islice, but it considers input as + # a dumb iterator and will lose efficiency for big skips. 
+            # Tough call, but let's assume big skip and small limit
+            # is more common.
+            res = res[skip:skip+limit]
+            print "res limited len %d" % len(res)
+
+        # Here some form of cursor api would be appropriate - a cursor would
+        # contain the list of resulting IDs for free skipping and limiting
+        # while fetching only actually read data.
+        print "loading data"
+        return [json.loads(zlib.decompress(self.data.query_eq(txn, v))) for v in res]
+        #return [self.data.query_eq(v) for v in res]
 
 
-    def or_(self, *q):
-        return set.union(*q)
+    def query_dictsort(self, txn, q, order=None, reverse=False, skip=0, limit=1):
+        def populate_cache(order):
+            print "caching"
+            if not hasattr(self, "memcache"):
+                self.memcache = {}
+            if not order in self.memcache:
+                self.memcache[order] = dict(self.fwdkeys[order].iteritems(txn).result)
+            return self.memcache[order]
 
-    def query(self, q, order=None, reverse=False, skip=0, limit=1):
-        # There is a bottleneck in sorting = query_eq is a python method,
-        # not C optimization. Maybe we could somehow draw out txn.get, maybe
-        # through functools.partial (implemented in C)
         if order is not None:
-            res = sorted(q, key=self.fwdkeys[order].query_eq, reverse=reverse)
+            order_memcache = populate_cache(order)
+            print "sorting"
+            #res = sorted(q.result, key=lambda k: self.fwdkeys[order].query_eq(txn, k), reverse=reverse)
+            res = sorted(q.result, key=order_memcache.get, reverse=reverse)
         else:
-            res = q
+            res = list(q.result)
+        print "res len %d" % len(res)
         if skip or limit:
             # Note that slicing makes copy, so this will lose efficiency
             # with big limits.
@@ -38,9 +143,75 @@ class KeyValueDB(DB):
             # Tough call, but let's assume big skip and small limit
             # is more common.
             res = res[skip:skip+limit]
+            print "res limited len %d" % len(res)
 
         # Here some form of cursor api would be appropriate - cursor would
         # contain list of resulting IDs for free skipping and limiting
         # while fetching only actualy read data.
-        return [json.loads(self.data.query_eq(v)) for v in res]
+        print "loading data"
+        return [json.loads(zlib.decompress(self.data.query_eq(txn, v))) for v in res]
+        #return [self.data.query_eq(v) for v in res]
+
+
+    def query_walksort(self, txn, q, order=None, reverse=False, skip=0, limit=1):
+        if not isinstance(q.result, set_types):
+            res_set = set_factory(q.result)
+        else:
+            res_set = q.result
+        print "res len %d" % len(res_set)
+        #if order is not None:
+        #    print "sorting"
+        #    res = []
+        #    enough = (skip or 0) + (limit or len(res_set)-skip)
+        #    for k in self.revkeys[order].iteritems(txn).result:
+        #        if k in res_set:
+        #            res.append(k)
+        #        if len(res) > enough:
+        #            break
+        if order is not None:
+            print "sorting"
+            res = []
+            enough = (skip or 0) + (limit or len(res_set)-skip)
+            res_iter = itertools.ifilter(res_set.__contains__, self.revkeys[order].itervalues(txn).result)
+            if skip or limit:
+                slice_iter = itertools.islice(res_iter, skip, skip+limit)
+            else:
+                slice_iter = res_iter
+            res = list(slice_iter)
+        else:
+            res = list(res_set)
+            if skip or limit:
+                res = res[skip:skip+limit]
+
+        print "res limited len %d" % len(res)
+
+        print "loading data"
+        return [json.loads(zlib.decompress(self.data.query_eq(txn, v))) for v in res]
+        #return [self.data.query_eq(v) for v in res]
+
+
+    def query_pandas(self, txn, q, order=None, reverse=False, skip=0, limit=1):
+
+        def gen_pandas_index(order):
+            return pandas.Index(self.revkeys[order].itervalues(txn).result)
+
+        if not isinstance(q.result, set_types):
+            res_set = set_factory(q.result)
+        else:
+            res_set = q.result
+        if order is not None:
+            print "sorting"
+            res = pandas.Index(q.result) & gen_pandas_index(order)
+        else:
+            res = list(res_set)
+        print "res len %d" % len(res)
+        if skip or limit:
+            res = res[skip:skip+limit]
+        print "res limited len %d" % len(res)
+
+        print "loading data"
+        return [json.loads(zlib.decompress(self.data.query_eq(txn, v))) for v in res]
+        #return [self.data.query_eq(v) for v in res]
+
+
+    #query = query_basesort
+    query = query_walksort
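query_walksort above never sorts explicitly: it walks the order index, which lmdb already keeps in key order, and keeps just the ids the subquery matched, so skip/limit can be applied by itertools.islice while everything streams. The same idea as a standalone sketch, with plain lists standing in for the lmdb cursors:

    import itertools

    ordered_ids = ["0001", "0002", "0003", "0004", "0005"]  # ids in index (key) order
    matched = set(["0005", "0002", "0003"])                 # ids from the subquery

    skip, limit = 1, 2
    res_iter = itertools.ifilter(matched.__contains__, iter(ordered_ids))
    print list(itertools.islice(res_iter, skip, skip + limit))  # ['0003', '0005']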
diff --git a/lmdb_index.py b/lmdb_index.py
index ccb0d48519455971a71d34db9d670f1910945e7d..3b40537922e79bfc4abacb5e7d1dc035fe202a84 100644
--- a/lmdb_index.py
+++ b/lmdb_index.py
@@ -3,95 +3,192 @@
 
 import lmdb
 import random
+import functools
+import operator
+import itertools
+import struct
 
 from dumb_db import Index
+from keyvalue_db import SubqueryTuple
+
+
+def preread_func(x):
+    res = set(x)
+    print "len %i" % len(res)
+    return res
+
+preread_all = lambda x: x
+#preread_all = preread_func
 
 
 class LMDBIndex(Index):
-
+
     def __init__(self, name, env, dup=False):
         Index.__init__(self, name, env)
         self.dup = dup
         self.handle = self.env.open_db(self.name, dupsort=self.dup)
-
-
-    def insert(self, key=None, value=None):
+
+
+    @staticmethod
+    def int32_to_bin(i):
+        return "%08x" % i
+
+
+    @staticmethod
+    def bin_to_int32(b):
+        return int(b, 16)
+
+
+    def insert_orig(self, txn, key=None, value=None):
         if key is None:
             key = "%016x" % random.randint(0, 2**64-1)  # may not be safe enough
-        with self.env.begin(buffers=False, write=True) as txn:
-            txn.put(key, value, db=self.handle, dupdata=self.dup, overwrite=True)
-            return key, value
+        txn.put(key, value, db=self.handle, dupdata=self.dup, overwrite=True)
+        return key, value
 
 
-    # Note here, that we use set(iterator) construct in hope that
-    # fast C implementation of lmdb cursor iterator connects directly
+    def insert(self, txn, key=None, value=None):
+        if key is None:
+            max_key = self.query_max_key(txn)
+            int_key = self.bin_to_int32(max_key) + 1
+            key = self.int32_to_bin(int_key)
+            # LMDB allows only one writable transaction, so there should be no race.
+        txn.put(key, value, db=self.handle, dupdata=self.dup, overwrite=self.dup)
+        return key, value
+
+
+    def query_max_key(self, txn):
+        with txn.cursor(db=self.handle) as crs:
+            if crs.last():
+                return crs.key()
+            else:
+                return self.int32_to_bin(0)
+
+
+    # Note that we return an iterator, in the hope that up the stack the
+    # fast C implementation of the lmdb cursor will connect directly
     # to the fast C implementation of the Python set(), without any
-    # interpreted code intervention. And, as it seems, it does.
-
-    def query_eq_all(self, key):
-        with self.env.begin(buffers=False) as txn:
-            with txn.cursor(db=self.handle) as crs:
-                crs.set_key(key)
-                return set(crs.iternext_dup(keys=False, values=True))
-
-
-    def query_eq(self, key):
-        with self.env.begin(buffers=False) as txn:
-            return txn.get(key, db=self.handle)
-
-
-    def query_ge(self, key):
-        with self.env.begin(buffers=False) as txn:
-            with txn.cursor(db=self.handle) as crs:
-                crs.set_range(key)
-                return set(crs.iternext(keys=False, values=True))
-
-
-    def query_le(self, key):
-        # Reverse reading from underlaying media may have limits
-        # Another implementation could be to start from the very first
-        # item and iterate until key is reached. However for problems
-        # with this approach see comments in query_range.
-        with self.env.begin(buffers=False) as txn:
-            with txn.cursor(db=self.handle) as crs:
-                crs.set_range(key)
-                it = crs.iterprev(keys=False, values=True)
-                try:
-                    next(it)
-                except StopIteration:
-                    return set()
-                return set(it)
-
-
-    def query_range(self, key1, key2):
-        # Not quite correct, may return events which contain
-        # one IP address greater than both keys and second IP
-        # address lower than both keys
-
-        # Possible correct implementations:
-        # * fetch and intersect keys, not values, then get ids for resulting keys
-        # * get query_ge iterator for key1, then fetch keys until key2 is
-        #   reached.
-        #   Problem is how to implement fast comparison with key2 without
-        #   sacrificing iter->set C speed.
-        #   Maybe operator.lt/gt (C) and itertools.takewhile or
-        #   itertools.ifilter (C)?
-        return self.query_ge(key1) & self.query_le(key2)
-
-
-    def clear(self):
-        with self.env.begin(buffers=True, write=True) as txn:
-            txn.drop(db=self.handle, delete=False)
-
-
-    def dump(self):
+    # interpreted code intervention.
+
+    def query_eq_all(self, txn, key):
+        crs = txn.cursor(db=self.handle)
+        print "eq_all"
+        crs.set_key(key)
+        return SubqueryTuple(
+            set([crs]),
+            preread_all(crs.iternext_dup(keys=False, values=True)),
+        )
+
+
+    def query_eq(self, txn, key):
+        return txn.get(key, db=self.handle)
+
+
+    def get_key_func(self, txn):
+        return functools.partial(txn.get, db=self.handle)
+
+
+    def query_ge(self, txn, key):
+        crs = txn.cursor(db=self.handle)
+        print "ge"
+        crs.set_range(key)
+        return SubqueryTuple(
+            set([crs]),
+            preread_all(crs.iternext(keys=False, values=True)),
+        )
+
+
+    def query_le(self, txn, key):
+        # Reverse reading from underlying media may have limits.
+        crs = txn.cursor(db=self.handle)
+        print "le"
+        crs.set_range(key)
+        it = crs.iterprev(keys=False, values=True)
+        try:
+            next(it)
+        except StopIteration:
+            it = iter(())
+        return SubqueryTuple(
+            set([crs]),
+            preread_all(it),
+        )
+
+
+    def query_range_orig(self, txn, key1, key2):
+        crs = txn.cursor(db=self.handle)
+        print "range"
+        # Set lmdb cursor to the lower bound
+        crs.set_range(key1)
+        # Get iterator for _both_ keys and values
+        # Keys are what is filtered on, values are event id's
+        lmdb_iter = crs.iternext(keys=True, values=True)
+        # Iterator wrapper, which stops when the upper bound is reached
+        # Note that this is a Python lambda, which I don't know how
+        # to replace with a faster implementation. :(
+        stop_iter = itertools.takewhile(lambda x: x[0]<=key2, lmdb_iter)
+        # Now return only the id iterator
+        res_iter = itertools.imap(operator.itemgetter(1), stop_iter)
+        return SubqueryTuple(
+            set([crs]),
+            preread_all(res_iter),
+        )
+
+
+    def query_range(self, txn, key1, key2):
+        # Build the comparator from C-implemented functools.partial and operator.ge
+        comparator = functools.partial(operator.ge, key2)
+        crs = txn.cursor(db=self.handle)
+        print "range"
+        # Set lmdb cursor to the lower bound
+        crs.set_range(key1)
+        # Get iterator for keys
+        # Keys are what is filtered on, values are event id's
+        lmdb_iter = crs.iternext(keys=True, values=False)
+        # Iterate while comparator returns True
+        keys = itertools.takewhile(comparator, lmdb_iter)
+        # Count length
+        length = len(list(keys))
+        # Set cursor back to the lower bound
+        crs.set_range(key1)
+        # Get iterator for values
+        lmdb_iter = crs.iternext(keys=False, values=True)
+        # Just fetch 'length' values
+        res_iter = itertools.islice(lmdb_iter, 0, length)
+        return SubqueryTuple(
+            set([crs]),
+            preread_all(res_iter),
+        )
+
+
+    def itervalues(self, txn):
+        crs = txn.cursor(db=self.handle)
+        print "all vals"
+        lmdb_iter = crs.iternext(keys=False, values=True)
+        return SubqueryTuple(
+            set([crs]),
+            preread_all(lmdb_iter),
+        )
+
+
+    def iteritems(self, txn):
+        crs = txn.cursor(db=self.handle)
+        print "all items"
+        lmdb_iter = crs.iternext(keys=True, values=True)
+        return SubqueryTuple(
+            set([crs]),
+            preread_all(lmdb_iter),
+        )
+
+
+    def clear(self, txn):
+        txn.drop(db=self.handle, delete=False)
+
+
+    def dump(self, txn):
         res = {}
-        with self.env.begin(buffers=False) as txn:
-            with txn.cursor(db=self.handle) as crs:
-                crs.first()
-                for key, value in crs.iternext(keys=True, values=True):
-                    if self.dup:
-                        res.setdefault(key, []).append(value)
-                    else:
-                        res[key] = value
+        with txn.cursor(db=self.handle) as crs:
+            crs.first()
+            for key, value in crs.iternext(keys=True, values=True):
+                if self.dup:
+                    res.setdefault(key, []).append(value)
+                else:
+                    res[key] = value
         return {self.name: res}
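The comparator trick in query_range above relies on functools.partial(operator.ge, key2) behaving like lambda k: key2 >= k while staying a C-level callable, so itertools.takewhile can consume the cursor without per-item interpreted code. A standalone sketch with a sorted list standing in for the lmdb cursor (set_range handles the lower bound in the real code; dropwhile fakes it here):

    import functools
    import itertools
    import operator

    keys = ["071.006.165.000", "071.006.165.100", "071.006.165.200", "072.000.000.000"]
    key1, key2 = "071.006.165.000", "071.006.165.255"

    comparator = functools.partial(operator.ge, key2)  # True while key2 >= k
    in_range = itertools.takewhile(comparator,
                                   itertools.dropwhile(lambda k: k < key1, keys))
    print list(in_range)  # the first three keys; "072.000.000.000" stops the scan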