#!/usr/bin/python
# -*- encoding: utf-8 -*-

import lmdb
import random
import functools
import operator
import itertools
import struct

from dumb_db import Index
from keyvalue_db import SubqueryTuple

def preread_func(x):
    """Read iterable *x* completely into a set, print its size, and return it.

    :param x: any iterable of hashable items (e.g. an LMDB cursor iterator).
    :returns: a new ``set`` holding the distinct items of *x*.
    """
    unique = set()
    unique.update(x)
    print("len %i" % len(unique))
    return unique

def preread_all(x):
    """Return *x* unchanged.

    Hook applied to the result iterators built by the LMDB index queries.
    The identity default keeps results lazy.  For debugging, swap in the
    commented-out rebinding below to read each result eagerly and print
    its size.
    """
    return x
#preread_all = preread_func

class LMDBIndex(Index):
    """Index stored in a named sub-database of an LMDB environment.

    Multi-value queries return a ``SubqueryTuple`` built from a set holding
    the cursor they opened and a lazy iterator that reads from that cursor.
    The cursor is therefore left open here.
    NOTE(review): presumably the consumer of the SubqueryTuple closes the
    cursors once the iterator is done -- confirm in keyvalue_db.
    """


    def __init__(self, name, env, dup=False):
        """Open the sub-database *name* in *env*.

        :param name: sub-database name passed to ``env.open_db``.
        :param env: LMDB environment (assumes ``Index.__init__`` stores
            *name* and *env* as ``self.name`` / ``self.env``, which are
            used below -- TODO confirm in dumb_db).
        :param dup: open with ``dupsort`` so one key may hold several values.
        """
        Index.__init__(self, name, env)
        self.dup = dup
        self.handle = self.env.open_db(self.name, dupsort=self.dup)


    @staticmethod
    def int32_to_bin(i):
        """Encode *i* as an 8-digit, zero-padded lowercase hex string.

        The fixed width makes bytewise key order match numeric order for
        0 <= i <= 0xffffffff.  Larger values produce longer strings and
        break that ordering.
        """
        return "%08x" % i


    @staticmethod
    def bin_to_int32(b):
        """Decode a hex key (as made by ``int32_to_bin``) back to an int.

        :raises ValueError: if *b* is not a valid hex string.
        """
        return int(b, 16)


    def insert_orig(self, txn, key=None, value=None):
        """Legacy insert: store *value* under *key*, always overwriting.

        When *key* is None, a random 16-hex-digit key is generated.  A
        random collision would silently overwrite the existing record.

        :returns: ``(key, value)`` as passed to ``txn.put``.
        """
        if key is None:
            key = "%016x" % random.randint(0, 2**64-1)  # may not be safe enough
        txn.put(key, value, db=self.handle, dupdata=self.dup, overwrite=True)
        return key, value


    def insert(self, txn, key=None, value=None):
        """Store *value* under *key* using write transaction *txn*.

        When *key* is None, a new key is allocated as the current largest
        key (decoded as hex) plus one, re-encoded with ``int32_to_bin``.

        ``overwrite`` follows ``self.dup``.  A dupsort database needs
        overwrite=True to add another value to an existing key.  A plain
        database passes overwrite=False, so an existing key keeps its value.
        NOTE(review): the boolean returned by ``txn.put`` is ignored, so a
        put that stored nothing is not reported to the caller.  Auto-keys
        also rely on every stored key being fixed-width hex (see
        ``int32_to_bin``).

        :returns: ``(key, value)`` as passed to ``txn.put``.
        """
        if key is None:
            max_key = self.query_max_key(txn)
            int_key = self.bin_to_int32(max_key) + 1
            key = self.int32_to_bin(int_key)
        # Lmdb allows for only one writable transaction, so there should be no race.
        txn.put(key, value, db=self.handle, dupdata=self.dup, overwrite=self.dup)
        return key, value


    def query_max_key(self, txn):
        """Return the last (largest) key, or ``int32_to_bin(0)`` if empty."""
        with txn.cursor(db=self.handle) as crs:
            if crs.last():
                return crs.key()
            else:
                return self.int32_to_bin(0)


    # The query methods return iterators rather than lists, in the hope
    # that further up the stack the C implementation of the lmdb cursor
    # feeds the C implementation of Python's set() directly, with no
    # interpreted code in between.

    def query_eq_all(self, txn, key):
        """Return every value stored under *key* (meaningful for dupsort DBs).

        A missing key yields an empty iterator.
        """
        crs = txn.cursor(db=self.handle)
        print("eq_all")
        if crs.set_key(key):
            it = crs.iternext_dup(keys=False, values=True)
        else:
            # Key absent: make the empty result explicit rather than relying
            # on how iternext_dup() treats an unpositioned cursor.
            it = iter(())
        return SubqueryTuple(
            set([crs]),
            preread_all(it),
        )


    def query_eq(self, txn, key):
        """Return the value for *key*, or None if it is absent."""
        return txn.get(key, db=self.handle)


    def get_key_func(self, txn):
        """Return ``txn.get`` bound to this index's database handle."""
        return functools.partial(txn.get, db=self.handle)


    def query_ge(self, txn, key):
        """Return the values of every record with key >= *key*, ascending."""
        crs = txn.cursor(db=self.handle)
        print("ge")
        if crs.set_range(key):
            it = crs.iternext(keys=False, values=True)
        else:
            # No key >= key.  iternext() on an unpositioned cursor would
            # restart from the first record and return the whole database.
            it = iter(())
        return SubqueryTuple(
            set([crs]),
            preread_all(it),
        )


    def query_le(self, txn, key):
        """Return the values of every record with key <= *key*, descending."""
        # Reverse reading from underlying media may have limits.
        crs = txn.cursor(db=self.handle)
        print("le")
        if not crs.set_range(key):
            # Every key is < key (or the DB is empty).  iterprev() on an
            # unpositioned cursor starts at the last record, which belongs
            # in the result, so nothing must be skipped.
            it = crs.iterprev(keys=False, values=True)
        elif crs.key() == key:
            # Exact match: include all of key's values.  set_range() lands on
            # the first duplicate, so walk back from the last one.
            if self.dup:
                crs.last_dup()
            it = crs.iterprev(keys=False, values=True)
        else:
            # Landed on the first key > key: skip that single record.
            it = crs.iterprev(keys=False, values=True)
            next(it, None)
        return SubqueryTuple(
            set([crs]),
            preread_all(it),
        )


    def query_range_orig(self, txn, key1, key2):
        """Legacy ``query_range``: values of records with key1 <= key <= key2.

        Filters with a per-item Python lambda.
        """
        crs = txn.cursor(db=self.handle)
        print("range")
        # Position at the first key >= key1.  If there is none, iternext()
        # on the unpositioned cursor would restart from the first record.
        if not crs.set_range(key1):
            return SubqueryTuple(set([crs]), preread_all(iter(())))
        # Iterate over _both_ keys and values.
        # Keys are what is filtered on, values are event id's.
        lmdb_iter = crs.iternext(keys=True, values=True)
        # Stop at the first key past the upper bound.  This per-item Python
        # lambda is what query_range avoids.
        stop_iter = itertools.takewhile(lambda x: x[0] <= key2, lmdb_iter)
        # Hand back only the values.
        res_iter = itertools.imap(operator.itemgetter(1), stop_iter)
        return SubqueryTuple(
            set([crs]),
            preread_all(res_iter),
        )


    def query_range(self, txn, key1, key2):
        """Return the values of every record with key1 <= key <= key2.

        Avoids a per-item Python callback.  A first pass over the keys counts
        the matching records with a C-level comparator.  A second pass from
        the same start position then slices exactly that many values.
        """
        # comparator(k) == operator.ge(key2, k) == (key2 >= k)
        comparator = functools.partial(operator.ge, key2)
        crs = txn.cursor(db=self.handle)
        print("range")
        # Position at the first key >= key1.  If there is none, iternext()
        # on the unpositioned cursor would restart from the first record.
        if not crs.set_range(key1):
            return SubqueryTuple(set([crs]), preread_all(iter(())))
        # First pass: iterate keys only.
        # Keys are what is filtered on, values are event id's.
        lmdb_iter = crs.iternext(keys=True, values=False)
        # Take keys while they are still <= key2, and count them.
        keys = itertools.takewhile(comparator, lmdb_iter)
        length = len(list(keys))
        # Second pass: rewind to the lower bound and iterate values only.
        crs.set_range(key1)
        lmdb_iter = crs.iternext(keys=False, values=True)
        # Fetch exactly 'length' values.
        res_iter = itertools.islice(lmdb_iter, 0, length)
        return SubqueryTuple(
            set([crs]),
            preread_all(res_iter),
        )


    def itervalues(self, txn):
        """Return every value in the database, in key order."""
        crs = txn.cursor(db=self.handle)
        print("all vals")
        lmdb_iter = crs.iternext(keys=False, values=True)
        return SubqueryTuple(
            set([crs]),
            preread_all(lmdb_iter),
        )


    def iteritems(self, txn):
        """Return every ``(key, value)`` pair in the database, in key order."""
        crs = txn.cursor(db=self.handle)
        print("all items")
        lmdb_iter = crs.iternext(keys=True, values=True)
        return SubqueryTuple(
            set([crs]),
            preread_all(lmdb_iter),
        )


    def clear(self, txn):
        """Remove all records but keep the sub-database (``delete=False``)."""
        txn.drop(db=self.handle, delete=False)


    def dump(self, txn):
        """Return ``{self.name: contents}`` with the whole database.

        ``contents`` maps each key to its value.  For a dupsort database it
        maps each key to a list of its values, in iteration order.
        """
        res = {}
        with txn.cursor(db=self.handle) as crs:
            crs.first()
            for key, value in crs.iternext(keys=True, values=True):
                if self.dup:
                    res.setdefault(key, []).append(value)
                else:
                    res[key] = value
        return {self.name: res}