#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2020 Cesnet z.s.p.o
# Use of this source is governed by a 3-clause BSD-style license, see LICENSE file.

"""
Wardenfiler output connector. Writes audit logs to Wardenfiler spool directory in IDEA format
"""

import os
import errno
import socket
import json
import hashlib
import logging
import string
from urllib.parse import urlparse
from time import time, gmtime, strftime
from datetime import datetime
from uuid import uuid4
from hashlib import sha1
from base64 import b64encode
from ipaddress import ip_address
from ipaddress import IPv4Network
from ipaddress import IPv6Network

from dionaea import IHandlerLoader
from dionaea.core import ihandler, connection
from dionaea.exception import LoaderError

# Module-level logger used by the loader and the handler defined in this file.
logger = logging.getLogger("log_wardenfiler")
# Let DEBUG-level records (and anything more severe) through this logger.
logger.setLevel(logging.DEBUG)

class Filer(object):
    """
    IDEA files creator.

    Files are first created under ``<directory>/tmp`` and, once completely
    written, renamed into ``<directory>/incoming``.
    """

    def __init__(self, directory):
        """
        Prepare the spool layout below *directory*.

        :param directory: base spool directory; it and its ``tmp`` and
            ``incoming`` subdirectories are created when missing.
        :raises OSError: if one of the directories cannot be created.
        """
        self.basedir = self._ensure_path(directory)
        self.tmp = self._ensure_path(os.path.join(self.basedir, "tmp"))
        self.incoming = self._ensure_path(os.path.join(self.basedir, "incoming"))
        self.hostname = socket.gethostname()
        self.pid = os.getpid()

    def _ensure_path(self, p):
        """
        Make sure directory *p* exists and return *p*.

        Missing parent directories are created as well, so a nested path
        such as ``var/spool/warden`` works on a fresh installation.

        :raises OSError: if *p* cannot be created and is not already a directory.
        """
        try:
            os.makedirs(p)
        except OSError:
            # Already existing directory (or a lost creation race) is fine,
            # anything else is a real error.
            if not os.path.isdir(p):
                raise
        return p

    def _get_new_name(self, fd=None):
        """
        Build a file name of the form ``hostname.pid.timestamp.device.inode``.

        :param fd: optional open file descriptor whose device and inode
            numbers are embedded in the name; zeros are used without it.
        """
        if fd is None:
            inode, device = 0, 0
        else:
            # Explicit None test above: descriptor 0 is a valid descriptor.
            st = os.fstat(fd)
            inode, device = st.st_ino, st.st_dev
        return "%s.%d.%f.%d.%d" % (
            self.hostname, self.pid, time(), device, inode)

    def create_unique_file(self):
        """
        Create a new, exclusively opened file in the ``tmp`` directory.

        :return: tuple ``(file object opened for text writing, file name)``;
            the name is relative to the ``tmp`` directory.
        :raises OSError: if the file cannot be created or renamed.
        """
        while True:
            tmpname = self._get_new_name()
            tmppath = os.path.join(self.tmp, tmpname)
            try:
                fd = os.open(tmppath, os.O_CREAT | os.O_RDWR | os.O_EXCL)
            except OSError as e:
                # Name collision: retry with a new timestamp.
                if e.errno != errno.EEXIST:
                    raise
                continue
            break

        try:
            # Rename to a name carrying the device and inode of the real file.
            newname = self._get_new_name(fd)
            os.rename(tmppath, os.path.join(self.tmp, newname))
        except Exception:
            # Do not leak the descriptor or leave the placeholder behind.
            os.close(fd)
            try:
                os.unlink(tmppath)
            except OSError:
                pass  # best-effort cleanup; the original error is re-raised
            raise
        return os.fdopen(fd, "w"), newname

    def publish_file(self, short_name):
        """
        Move the finished file *short_name* from ``tmp`` to ``incoming``.

        :raises OSError: if the rename fails.
        """
        os.rename(os.path.join(self.tmp, short_name), os.path.join(self.incoming, short_name))


class LogWardenfilerHandlerLoader(IHandlerLoader):
    """Loader that instantiates the Wardenfiler incident handler."""
    name = "log_wardenfiler"

    @classmethod
    def start(cls, config=None):
        """
        Create the handler for all incident paths (``"*"``).

        :param config: configuration object handed over to the handler.
        :return: the new handler, or ``None`` when its construction raised
            ``LoaderError`` (the error is logged).
        """
        handler = None
        try:
            handler = LogWardenfilerHandler("*", config=config)
        except LoaderError as err:
            logger.error(err.msg, *err.args)
        return handler

class LogWardenfilerHandler(ihandler):
    # Defaults of the options which start() may override from the config.

    # Reported as the "Name" of the IDEA "Node" entry.
    detector_name = None
    # When true, the target address of IPv4 connections is looked up at
    # the service listening on nat_host:nat_port.
    resolve_nat = False
    nat_host = "gateway"
    nat_port = 1456
    # Static addresses reported as attack target instead of the local one.
    # They need a default here because _register_connection reads them even
    # when the config does not define them.
    reported_ipv4 = None
    reported_ipv6 = None
    # Prefix lengths used to anonymise the target address; 32 / 128 keep
    # the address unchanged.
    anon_mask_4 = 32
    anon_mask_6 = 128
    # Length of the aggregation window in seconds.
    aggr_win = 5 * 60
    # When true, generated events carry the "Test" category.
    test_mode = True
    # Base directory handed over to Filer.
    output_dir = "var/spool/warden"
    # Only read from the config in this file; not used by the visible code.
    drop_malware = True
    # Timestamp of the start of the current aggregation window.
    win_start = None
    # Aggregated counters keyed by "src_ip,dst_ip,dst_port,transport".
    attackers = {}
    # Per-connection session data keyed by the connection object.
    sessions = {}

    def __init__(self, path, config = None):
        """
        Set up the handler for incident pattern *path*.

        :param path: incident pattern passed on to ``ihandler.__init__``.
        :param config: configuration object read later by ``start()``.
        """
        logger.debug("%s ready!", self.__class__.__name__)
        # Give every instance its own containers; the class-level dicts
        # would otherwise be shared (and mutated) by all instances.
        self.attackers = {}
        self.sessions = {}
        ihandler.__init__(self, path)
        self.path = path
        self._config = config

    def _bytes_to_str(self, s):
        """
        Return *s* as text.

        A ``str`` is returned unchanged; anything else is decoded as UTF-8
        with undecodable bytes replaced by backslash escapes.
        """
        already_text = isinstance(s, str)
        return s if already_text else str(s, "utf-8", "backslashreplace")

    def _fixup_event(self, event):
        """
        Convert a bytes ``database`` entry of *event* to text, in place.

        :return: the same *event* object.
        """
        database = event.get('database')
        if isinstance(database, bytes):
            event['database'] = self._bytes_to_str(database)
        return event

    def _save_event(self, event):
        """
        Write *event* as one JSON document into the spool and publish it.

        :param event: JSON-serialisable dictionary.
        :raises TypeError, ValueError: if *event* cannot be serialised.
        :raises OSError: if the spool file cannot be created or moved.
        """
        # Serialise first: a failure here must not leave an empty file
        # behind in the spool's temporary directory.
        payload = json.dumps(self._fixup_event(event), ensure_ascii = True)
        f, name = self.filer.create_unique_file()
        with f:
            f.write(payload)
        self.filer.publish_file(name)

    def _format_credentials(self, creds=None):
        """
        Convert internal credential records to the form stored in events.

        :param creds: iterable of dicts with ``User`` and ``Password`` keys;
            ``None`` is treated as no credentials.
        :return: list of dicts with ``Username`` and ``Password`` keys
            (values are ``None`` where the source key is missing).
        """
        return [
            {
                "Username": c.get("User"),
                "Password": c.get("Password")
            }
            for c in (creds or ())
        ]

    def start(self):
        """
        Apply the configuration and create the spool writer.

        Every known option present in the config replaces the class-level
        default of the same name. A missing config (``None``) leaves all
        defaults in place.

        :raises OSError: if the spool directories cannot be created.
        """
        # The loader passes config=None by default; treat it as empty.
        config = self._config or {}
        options = (
            'detector_name', 'resolve_nat', 'nat_host', 'nat_port',
            'reported_ipv4', 'reported_ipv6', 'anon_mask_4', 'anon_mask_6',
            'aggr_win', 'test_mode', 'output_dir', 'drop_malware',
        )
        for option in options:
            if option in config:
                setattr(self, option, config.get(option))

        self.filer = Filer(self.output_dir)

    def _aggregate(self):
        """
        Flush the per-attacker counters as aggregated IDEA events.

        The first call only opens the aggregation window. Once at least
        ``aggr_win`` seconds have passed since the window start, one event
        is saved for every attacker key seen more than once, the counters
        are cleared and a new window is started.
        """
        now = time()
        if self.win_start is None:
            # Without a stored start the window could never elapse.
            self.win_start = now
            return

        ws = self.win_start
        if now - ws < self.aggr_win:
            return

        logger.info("Counting attacks: %s", json.dumps(self.attackers, ensure_ascii = True))
        win_start_iso = datetime.utcfromtimestamp(ws).isoformat() + 'Z'
        win_end_iso = datetime.utcfromtimestamp(ws + self.aggr_win).isoformat() + 'Z'
        aggr_win_str = strftime("%H:%M:%S", gmtime(float(self.aggr_win)))
        # Kept for every event so that test traffic stays marked as such.
        base_category = ["Test"] if self.test_mode else []

        for key, a in self.attackers.items():
            count = a["count"]
            if count <= 1:
                continue

            src_ip, dst_ip, dst_port, transport = key.split(',')
            if a["creds"]:
                category = "Attempt.Login"
                description = "Successful logins to honeypoted service."
            else:
                category = "Recon.Scanning"
                description = "Connection attempts to IPs assigned to honeypot."

            af = "IP4" if ':' not in src_ip else "IP6"
            proto = [transport]
            if a["proto"]:
                proto.append(a["proto"])

            # A fresh dictionary per attacker: nothing (e.g. credentials)
            # may leak from one event into the next.
            sevent = {
                "Format": "IDEA0",
                "ID": str(uuid4()),
                "WinStartTime": win_start_iso,
                "WinEndTime": win_end_iso,
                "DetectTime": win_end_iso,
                "Category": base_category + [category],
                "Description": description,
                "ConnCount": count,
                "Source": [{"Proto": proto, af: [src_ip], "Port": list(a["sports"])}],
                "Target": [{"Proto": proto, af: [dst_ip], "Port": [int(dst_port)]}],
                "Node": [
                    {
                        "Name": self.detector_name,
                        "Type": ["Connection", "Auth", "Honeypot"],
                        "SW": ["Dionaea with Warden Filer output module"],
                        "AggrWin": aggr_win_str
                    }
                ]
            }

            # The mask matching the address family of the target decides.
            if ':' in dst_ip:
                anonymised = self.anon_mask_6 < 128
            else:
                anonymised = self.anon_mask_4 < 32
            if anonymised:
                sevent["Target"][0]["Anonymised"] = "true"

            if a["creds"]:
                sevent["Credentials"] = self._format_credentials(a["creds"])

            self._save_event(sevent)
            logger.info("sending scanning event for %s probing %s (%i times)", src_ip, dst_ip, count)

        self.attackers = {}
        self.win_start = time()

    def _make_idea(self, con):
        """
        Build the IDEA event describing the session of connection *con*.

        :param con: connection object used as key into ``self.sessions``.
        :return: event dictionary ready for ``_save_event``.
        :raises KeyError: if *con* has no (complete) session record.
        """
        s = self.sessions[con]
        proto = [s["trans"]]
        if s["proto"]:
            proto.append(s["proto"])
        event = {
            "Format": "IDEA0",
            "ID": s["id"],
            "DetectTime": s["dt"],
            # Copy: the categories appended below must not alter the session.
            "Category": list(s["cat"]),
            "Source": [{"Proto": proto, s["af"]: [s["src_ip"]], "Port": [s["src_port"]]}],
            "Target": [{ "Proto": proto, s["af"]: [s["dst_ip"]], "Port": [s["dst_port"]]}],
            "Node": [
                {
                    "Name": self.detector_name,
                    "Type": ["Connection", "Auth", "Honeypot"],
                    "SW": ["Dionaea with Warden Filer output module"],
                }
            ]
        }

        if s["anon"]:
            event["Target"][0]["Anonymised"] = "true"

        if s["creds"]:
            service_names = {
                "ftp": "FTP",
                "mysql": "MySQL",
                "ms-sql-s": "MSSQL",
            }
            event["Category"].append("Intrusion.UserCompromise")
            if s["proto"]:
                # Fall back to the raw protocol name (e.g. "ssl-tls")
                # instead of failing on a protocol missing from the map.
                service = service_names.get(s["proto"], s["proto"])
                event["Description"] = service + " successful login"
            else:
                event["Description"] = "Successful login attempt"
            event["Credentials"] = self._format_credentials(s["creds"])
        else:
            # login without password or similar thing
            event["Category"].append("Recon.Scanning")
            event["Description"] = "Connection"

        if s["cmds"]:
            # consider this an exploit only if there was a login attempt
            if s["creds"]:
                event["Category"].append("Attempt.Exploit")
            event["Description"] += " with command input"
            # Commands stored as bytes are decoded rather than rendered as
            # their "b'...'" representation.
            idata = "\n".join(
                self._bytes_to_str(c) if isinstance(c, (bytes, bytearray)) else str(c)
                for c in s["cmds"]
            )
            raw = idata.encode("utf-8")
            plain = all(ch in string.printable for ch in idata)
            attach = {
                "Type": ["Exploit"],
                "Hash": ["sha1:" + sha1(raw).hexdigest()],
                # Size in bytes, i.e. of the very data the hash covers.
                "Size": len(raw),
                "Description": "Commands entered by attacker during honeypot session",
                "Content": idata if plain else b64encode(raw).decode()
            }
            if not plain:
                attach["ContentEncoding"] = "base64"
            event["Attach"] = [attach]

        return event

    def _register_connection(self, con, proto = None, cred = None, cmd = None):
        """
        Record an incident on connection *con*.

        On the first incident of a connection a session record is created
        (addresses, ports, transport, anonymisation). Every call then
        updates the session and the aggregated per-attacker counters.

        :param con: connection object of the incident.
        :param proto: application protocol name to store, if known.
        :param cred: dict with ``User`` and ``Password`` to remember.
        :param cmd: command entered by the remote side to remember.
        :return: ``None``; nothing is recorded when NAT resolution is on
            and the NAT service has no translation for the connection.
        """
        if con not in self.sessions:
            src_ip = con.remote.host
            dst_ip = con.local.host
            # Drop the "::ffff:" prefix so the plain IPv4 address remains.
            if src_ip.startswith("::ffff:"):
                src_ip = src_ip[7:]
            if dst_ip.startswith("::ffff:"):
                dst_ip = dst_ip[7:]

            af = "IP4" if ':' not in src_ip else "IP6"

            # These options have no guaranteed default, hence getattr.
            reported_ipv4 = getattr(self, "reported_ipv4", None)
            reported_ipv6 = getattr(self, "reported_ipv6", None)

            # Test for static IP to report as attack target
            if af == "IP4" and reported_ipv4:
                dst_ip = reported_ipv4
            # Resolve NAT if instructed
            elif af == "IP4" and self.resolve_nat:
                # "with" closes the socket even when the lookup fails.
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                    sock.connect((self.nat_host, self.nat_port))
                    sock.sendall(','.join((src_ip, str(con.remote.port))).encode("utf-8"))
                    dst = sock.recv(50).decode("utf-8")
                if dst == "NE":
                    logger.warning("no translation for %s:%s" % (src_ip, con.remote.port))
                    # No session has been stored yet, so no half-filled
                    # record is left behind for later incidents.
                    return
                dst_ip = dst
            elif af == "IP6" and reported_ipv6:
                dst_ip = reported_ipv6

            # Only the mask of the target's own address family counts.
            if ':' in dst_ip:
                anon = self.anon_mask_6 < 128
                if anon:
                    network = IPv6Network("/".join((dst_ip, str(self.anon_mask_6))), False)
                    dst_ip = str(network.network_address)
            else:
                anon = self.anon_mask_4 < 32
                if anon:
                    network = IPv4Network("/".join((dst_ip, str(self.anon_mask_4))), False)
                    dst_ip = str(network.network_address)

            self.sessions[con] = {
                "id": str(uuid4()),
                "dt": datetime.utcnow().isoformat() + "Z",
                "cat": ["Test"] if self.test_mode else [],
                "af": af,
                "anon": anon,
                "src_ip": src_ip,
                "dst_ip": dst_ip,
                "src_port": con.remote.port,
                "dst_port": con.local.port,
                "trans": con.transport,
                "proto": None,
                "creds": [],
                "cmds": [],
            }

        session = self.sessions[con]
        aid = ','.join((session["src_ip"], session["dst_ip"], str(con.local.port), con.transport))

        attacker = self.attackers.setdefault(aid, {
            "count": 0,
            "sports": [],
            "creds": [],
            "proto": None
        })

        # NOTE(review): the counter grows with every incident (logins and
        # commands included), not only with new connections - confirm that
        # this is what "ConnCount" is meant to report.
        attacker["count"] += 1
        if con.remote.port not in attacker["sports"]:
            attacker["sports"].append(con.remote.port)
        if proto:
            session["proto"] = proto
            attacker["proto"] = proto
        if cred:
            session["creds"].append(cred)
            attacker["creds"].append(cred)
        if cmd:
            session["cmds"].append(cmd)

    def handle_incident(self, icd):
        """Generic incident hook; performs no action."""
        return None

    def handle_incident_dionaea_connection_tcp_listen(self, icd):
        """TCP listen incidents are ignored."""
        return None

    def handle_incident_dionaea_connection_tls_listen(self, icd):
        """TLS listen incidents are ignored."""
        return None

    def handle_incident_dionaea_connection_tcp_connect(self, icd):
        """Register the connection of a TCP connect incident and log it."""
        con = icd.con
        self._register_connection(con)
        remote = con.remote
        local = con.local
        endpoints = (remote.host, remote.hostname, remote.port, local.host, local.port)
        logger.info("connect connection to %s/%s:%i from %s:%i" % endpoints)

    def handle_incident_dionaea_connection_tls_connect(self, icd):
        """Register the connection of a TLS connect incident (protocol "ssl-tls") and log it."""
        con = icd.con
        self._register_connection(con, "ssl-tls")
        remote = con.remote
        local = con.local
        endpoints = (remote.host, remote.hostname, remote.port, local.host, local.port)
        logger.info("connect connection to %s/%s:%i from %s:%i" % endpoints)

    def handle_incident_dionaea_connection_udp_connect(self, icd):
        """Register the connection of a UDP connect incident and log it."""
        con = icd.con
        self._register_connection(con)
        remote = con.remote
        local = con.local
        endpoints = (remote.host, remote.hostname, remote.port, local.host, local.port)
        logger.info("connect connection to %s/%s:%i from %s:%i" % endpoints)

    def handle_incident_dionaea_connection_tcp_accept(self, icd):
        """Register the connection of a TCP accept incident and log it."""
        con = icd.con
        self._register_connection(con)
        remote = con.remote
        local = con.local
        endpoints = (remote.host, remote.port, local.host, local.port)
        logger.info("accepted connection from %s:%i to %s:%i" % endpoints)

    def handle_incident_dionaea_connection_tls_accept(self, icd):
        """Register the connection of a TLS accept incident (protocol "ssl-tls") and log it."""
        con = icd.con
        self._register_connection(con, "ssl-tls")
        remote = con.remote
        local = con.local
        endpoints = (remote.host, remote.port, local.host, local.port)
        logger.info("accepted connection from %s:%i to %s:%i" % endpoints)

    def handle_incident_dionaea_connection_tcp_reject(self, icd):
        """Register the connection of a TCP reject incident and log it."""
        con = icd.con
        self._register_connection(con)
        remote = con.remote
        local = con.local
        endpoints = (remote.host, remote.port, local.host, local.port)
        logger.info("reject connection from %s:%i to %s:%i" % endpoints)

    def handle_incident_dionaea_modules_python_ftp_command(self, icd):
        """
        Record an FTP command (with its arguments, if any) for the session.

        The command and the arguments come from the remote side, so they
        are decoded leniently: undecodable bytes become backslash escapes
        instead of raising ``UnicodeDecodeError``.
        """
        con = icd.con
        cmd = self._bytes_to_str(icd.command)
        arguments = getattr(icd, 'arguments', None)
        if arguments:
            # Arguments may be bytes or str; _bytes_to_str accepts both.
            cmd += " " + " ".join(self._bytes_to_str(a) for a in arguments)
        self._register_connection(con, "ftp", cmd = cmd)
        logger.info("new FTP command within connection from %s:%i to %s:%i" % (con.remote.host, con.remote.port, con.local.host, con.local.port))

    def handle_incident_dionaea_modules_python_mssql_cmd(self, icd):
        """Record the command of an MSSQL incident (protocol "ms-sql-s") for the session and log it."""
        con = icd.con
        command = icd.cmd
        self._register_connection(con, "ms-sql-s", cmd = command)
        remote = con.remote
        local = con.local
        endpoints = (remote.host, remote.port, local.host, local.port)
        logger.info("new MSSQL command within connection from %s:%i to %s:%i" % endpoints)

    def handle_incident_dionaea_modules_python_mysql_command(self, icd):
        """
        Record a MySQL command (with its arguments, if any) for the session.

        The command is stored as ``str(icd.command)``; each argument is
        appended on its own line.
        """
        con = icd.con
        cmd = str(icd.command)
        args = getattr(icd, 'args', None)
        if args:
            # Arguments may be bytes or str; joining raw bytes would raise.
            cmd += "\n" + "\n".join(self._bytes_to_str(a) for a in args)
        self._register_connection(con, "mysql", cmd = cmd)
        logger.info("new MYSQL command within connection from %s:%i to %s:%i" % (con.remote.host, con.remote.port, con.local.host, con.local.port))

    def handle_incident_dionaea_modules_python_ftp_login(self, icd):
        """Record the user name and password of an FTP login incident and log it."""
        con = icd.con
        credentials = {
            "User": self._bytes_to_str(icd.username),
            "Password": self._bytes_to_str(icd.password),
        }
        self._register_connection(con, "ftp", cred = credentials)
        remote = con.remote
        local = con.local
        endpoints = (remote.host, remote.port, local.host, local.port)
        logger.info("new FTP login within connection from %s:%i to %s:%i" % endpoints)

    def handle_incident_dionaea_modules_python_mssql_login(self, icd):
        """Record the user name and password of an MSSQL login incident and log it."""
        con = icd.con
        credentials = {
            "User": self._bytes_to_str(icd.username),
            "Password": self._bytes_to_str(icd.password),
        }
        self._register_connection(con, "ms-sql-s", cred = credentials)
        remote = con.remote
        local = con.local
        endpoints = (remote.host, remote.port, local.host, local.port)
        logger.info("new MSSQL login within connection from %s:%i to %s:%i" % endpoints)

    def handle_incident_dionaea_modules_python_mysql_login(self, icd):
        """Record the user name and password of a MySQL login incident and log it."""
        con = icd.con
        credentials = {
            "User": self._bytes_to_str(icd.username),
            "Password": self._bytes_to_str(icd.password),
        }
        self._register_connection(con, "mysql", cred = credentials)
        remote = con.remote
        local = con.local
        endpoints = (remote.host, remote.port, local.host, local.port)
        logger.info("new MySQL login within connection from %s:%i to %s:%i" % endpoints)

    def handle_incident_dionaea_modules_python_p0f(self, icd):
        """p0f incidents are ignored."""
        return None

    def handle_incident_dionaea_connection_free(self, icd):
        """
        Finish the session of a freed connection.

        Runs the periodic aggregation, then saves an IDEA event for the
        session when its source address is globally routable and commands
        were recorded. The session record is removed in every case, even
        when aggregation or event writing raises.
        """
        con = icd.con

        try:
            self._aggregate()

            s = self.sessions.get(con)
            # A missing or empty record means there is nothing to report.
            if not s:
                logger.warning("no attack data for %s:%s" % (con.remote.host, con.remote.port))
                return

            # Do not generate IDEA event for a source
            # which is not globally routable
            if not ip_address(s["src_ip"]).is_global:
                logger.info("not generating an event for connection from non-global IP %s:%s" % (con.remote.host, con.remote.port))
            elif s.get("cmds"):
                event = self._make_idea(con)
                self._save_event(event)
                logger.info("sending connection event from %s:%i to %s:%i" % (con.remote.host, con.remote.port, con.local.host, con.local.port))
        finally:
            # Always forget the connection so failed events cannot pile up
            # stale sessions.
            self.sessions.pop(con, None)

        logger.info("closing connection from %s:%i to %s:%i" % (con.remote.host, con.remote.port, con.local.host, con.local.port))