from uuid import uuid4
import re
from time import gmtime, strftime


class StixGenerator(object):
    """Build STIX 2 object dicts from alert records.

    The record keys read below (``DetectTime``, ``EventTime``, ``CeaseTime``,
    ``ConnCount``, ``Source``, ``Target``, ``Ref``, ``Category`` and node
    entries with ``Name``/``Type``/``SW``) look like the IDEA alert format --
    NOTE(review): confirm against the caller. Every method returns plain
    Python dicts/lists; nothing is serialized here.
    """

    def sighting_object(self, identity, conn_count, observed_data, alert_type):
        """Return a STIX ``sighting`` dict with a fresh random id.

        :param identity: identity id, used both as ``created_by_ref`` and as
            the single entry of ``where_sighted_refs``.
        :param conn_count: stored as the sighting ``count``.
        :param observed_data: id placed in ``observed_data_refs``.
        :param alert_type: id stored as ``sighting_of_ref``.
        """
        return {
            'type': "sighting",
            'id': "sighting--" + str(uuid4()),
            'created_by_ref': identity,
            # STIX timestamps need a four-digit year and the UTC "Z" suffix.
            'created': strftime("%Y-%m-%dT%H:%M:%SZ", gmtime()),
            'count': conn_count,
            'sighting_of_ref': alert_type,
            'observed_data_refs': [observed_data],
            'where_sighted_refs': [identity]
        }

    def identity_object(self, node):
        """Return a STIX ``identity`` dict built from ``node[0]``.

        :param node: sequence of node dicts; only the first one is read, via
            its ``Name``, ``Type`` and optional ``SW`` keys.
        """
        first = node[0]
        software = first.get('SW', [])
        # SW is expected to be a list of software names: join them with a
        # separator so several names stay readable. A plain string is kept
        # as-is instead of being split into characters.
        description = software if isinstance(software, str) else ", ".join(software)
        return {
            'type': "identity",
            'id': "identity--" + str(uuid4()),
            'name': first['Name'],
            'labels': first['Type'],
            'description': description,
            'identity_class': "technology"
        }

    def ipvx_addr_object(self, addr_objects, object_counter, source=False):
        """Create ipv4-addr/ipv6-addr objects for a list of address entries.

        :param addr_objects: list of dicts with optional ``IP4``/``IP6``
            address lists, an optional ``Port`` list and (read only when
            *source* is true) an optional ``Proto`` value.
        :param object_counter: list whose last element is the next free
            object key; it is not modified.
        :param source: when true, ``Proto`` is recorded for each port group.
        :returns: ``(ip_references, next_counter, objects)`` where
            ``ip_references`` maps ``tuple(Port)`` to
            ``[list_of_object_ids, [Proto] or []]``, ``next_counter`` is a
            one-element list holding the next free key, and ``objects`` maps
            ``str(id)`` to the created address objects.
        """
        objects = {}
        ip_references = {}
        next_id = object_counter[-1]
        for ip_dict in addr_objects:
            ref_ids = []
            # Inspect each entry's own families; an entry may carry both.
            for af, stix_type in (("IP4", "ipv4-addr"), ("IP6", "ipv6-addr")):
                for ip_addr in ip_dict.get(af) or ():
                    objects[str(next_id)] = {'type': stix_type, 'value': ip_addr}
                    ref_ids.append(next_id)
                    next_id += 1
            if not ref_ids:
                # Nothing to reference (e.g. an entry without IP addresses).
                continue
            proto = [ip_dict['Proto']] if source and 'Proto' in ip_dict else []
            # Entries sharing a port tuple are merged instead of overwritten,
            # so every created address object stays referenced.
            entry = ip_references.setdefault(tuple(ip_dict.get('Port') or ()), [[], proto])
            entry[0].extend(ref_ids)
            if not entry[1]:
                entry[1] = proto
        return ip_references, [next_id], objects

    def one_network_traffic_object(self, src_ip_references=None, dst_ip_references=None):
        """Return one ``network-traffic`` dict for the given port groups.

        Both arguments use the mapping shape returned by
        ``ipvx_addr_object``. ``all_network_traffic_objects`` passes at most
        one group per side; if more are given, the last group wins.
        """
        network_traffic = {
            'type': "network-traffic"
        }
        for port, refs in (src_ip_references or {}).items():
            network_traffic['src_ref'] = [str(ip_key) for ip_key in refs[0]]
            if refs[1]:
                network_traffic['protocols'] = refs[1][0]
            if port:
                network_traffic['src_port'] = port[0] if len(port) == 1 else port
        for port, refs in (dst_ip_references or {}).items():
            network_traffic['dst_ref'] = [str(ip_key) for ip_key in refs[0]]
            if port:
                network_traffic['dst_port'] = port[0] if len(port) == 1 else port
        return network_traffic

    def all_network_traffic_objects(self, src_ip_references, dst_ip_references, object_counter):
        """Build network-traffic objects keyed by ``str(id)`` from *object_counter* up.

        Pairing rules:
        - both sides have several port groups: one object per group on each
          side, with no cross links;
        - no target groups, or more source than target groups: one object per
          source group, each linked to all target groups;
        - no source groups, or more target than source groups: one object per
          target group, each linked to all source groups;
        - otherwise (one group on each side): a single combined object.
        """
        n_src, n_dst = len(src_ip_references), len(dst_ip_references)
        if n_src > 1 and n_dst > 1:
            state = "src_dst"
        elif n_dst == 0 or n_src > n_dst:
            state = "src"
        elif n_src == 0 or n_dst > n_src:
            state = "dst"
        else:
            return {str(object_counter): self.one_network_traffic_object(src_ip_references,
                                                                         dst_ip_references)}
        objects = {}
        if state in ("src", "src_dst"):
            partner = dst_ip_references if state == "src" else None
            for port, list_of_src_ip in src_ip_references.items():
                objects[str(object_counter)] = self.one_network_traffic_object({port: list_of_src_ip},
                                                                               partner)
                object_counter += 1
        if state in ("dst", "src_dst"):
            partner = src_ip_references if state == "dst" else None
            for port, list_of_dst_ip in dst_ip_references.items():
                objects[str(object_counter)] = self.one_network_traffic_object(partner,
                                                                               {port: list_of_dst_ip})
                object_counter += 1
        return objects

    def external_references(self, refs):
        """Convert reference strings into STIX external-reference dicts.

        ``"url:<u>"`` and bare ``http(s)://`` strings become ``{'url': ...}``;
        ``"<source>:<id>"`` becomes ``{'source_name', 'external_id'}`` with
        everything after the first colon kept as the id; a string without a
        colon becomes ``{'source_name': record}``.
        """
        ext_references = []
        for record in refs:
            if record.startswith("url:"):
                ext_references.append({'url': record[4:]})
            elif record.startswith(("http://", "https://")):
                ext_references.append({'url': record})
            else:
                source_name, sep, external_id = record.partition(":")
                if sep:
                    ext_references.append({'source_name': source_name,
                                           'external_id': external_id})
                else:
                    ext_references.append({'source_name': record})
        return ext_references

    def observed_data_object(self, identity, data, labels=False):
        """Return a STIX ``observed-data`` dict for one alert record.

        :param identity: id stored as ``created_by_ref``.
        :param data: record dict; ``DetectTime`` is required, and
            ``EventTime``/``CeaseTime``/``ConnCount``/``Ref``/``Source``/
            ``Target`` are optional (``Category`` is read when *labels*).
        :param labels: when true, copy ``data['Category']`` into ``labels``.
        """
        detect_time = data['DetectTime']
        observed_data = {
            'type': "observed-data",
            'id': "observed-data--" + str(uuid4()),
            'created_by_ref': identity,
            'created': detect_time,
            'first_observed': data.get('EventTime') or detect_time,
            'last_observed': data.get('CeaseTime') or detect_time,
            # STIX 2 spells this property with an underscore.
            'number_observed': data.get('ConnCount') or 1
        }
        if data.get('Ref'):
            observed_data['external_references'] = self.external_references(data['Ref'])
        if labels:
            observed_data['labels'] = data['Category']

        # Source addresses get the lowest object keys, then targets continue
        # the numbering, then the network-traffic objects follow.
        src_ip_references, dst_ip_references, objects = {}, {}, {}
        object_counter = [0]
        if data.get('Source'):
            src_ip_references, object_counter, src_objects = self.ipvx_addr_object(data['Source'],
                                                                                   object_counter, True)
            objects.update(src_objects)
        if data.get('Target'):
            dst_ip_references, object_counter, dst_objects = self.ipvx_addr_object(data['Target'],
                                                                                   object_counter)
            objects.update(dst_objects)
        if objects:
            network_objects = self.all_network_traffic_objects(src_ip_references, dst_ip_references,
                                                               object_counter[-1])
            observed_data['objects'] = {**objects, **network_objects}
        return observed_data

    def alert_object(self, category, ref):
        """Return a STIX ``vulnerability`` or ``malware`` dict for *category*.

        A category string containing ``"Vulnerability"`` yields a
        vulnerability (with external references built from *ref* when
        given); anything else yields a malware object labelled with the
        category.
        """
        if "Vulnerability" in category:
            vulnerability = {
                # STIX object type names are lowercase.
                'type': "vulnerability",
                'id': "vulnerability--" + str(uuid4()),
                'name': "unknown"
            }
            if ref:
                vulnerability['external_references'] = self.external_references(ref)
            return vulnerability
        return {
            'type': "malware",
            'id': "malware--" + str(uuid4()),
            'name': "unknown",
            'labels': [category]
        }