#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Cesnet z.s.p.o
# Use of this source is governed by a 3-clause BSD-style license, see LICENSE file.

"""
Wardenfiler output connector. Writes audit logs to Wardenfiler spool directory in IDEA format
"""

import os
import errno
import socket
import json
import string
from urllib.parse import urlparse
from time import time, gmtime, strftime
from datetime import datetime
from uuid import uuid4
from hashlib import sha1
from base64 import b64encode
from ipaddress import ip_address
from ipaddress import IPv4Network
from ipaddress import IPv6Network
from cowrie.core.config import CowrieConfig
import cowrie.core.output

class Filer(object):
    """
    IDEA files creator

    Manages a spool directory with two subdirectories: files are created and
    written in ``tmp`` and then moved into ``incoming`` by a rename, so a
    reader of ``incoming`` never sees a partially written file.
    """

    def __init__(self, directory):
        """
        Prepare the spool layout below *directory*.

        :param directory: base directory of the spool; it and its ``tmp`` and
            ``incoming`` subdirectories are created when missing.
        :raises OSError: when a directory cannot be created.
        """
        self.basedir = self._ensure_path(directory)
        self.tmp = self._ensure_path(os.path.join(self.basedir, "tmp"))
        self.incoming = self._ensure_path(os.path.join(self.basedir, "incoming"))
        # Both values become part of every generated file name.
        self.hostname = socket.gethostname()
        self.pid = os.getpid()

    def _ensure_path(self, p):
        """
        Make sure directory *p* exists and return it.

        Missing intermediate directories are created as well, so a nested
        path works on a fresh installation.

        :raises OSError: when *p* exists but is not a directory, or cannot be
            created.
        """
        os.makedirs(p, exist_ok=True)
        return p

    def _get_new_name(self, fd=None):
        """
        Build a file name of the form ``hostname.pid.timestamp.device.inode``.

        :param fd: optional open file descriptor; its device and inode numbers
            are used in the name. Without it both are reported as 0.
        """
        # Compare against None explicitly: descriptor 0 is a valid descriptor.
        if fd is not None:
            info = os.fstat(fd)
            inode, device = info.st_ino, info.st_dev
        else:
            inode, device = 0, 0
        return "%s.%d.%f.%d.%d" % (
            self.hostname, self.pid, time(), device, inode)

    def create_unique_file(self):
        """
        Create a new, exclusively owned file in the ``tmp`` directory.

        The file is first created under a provisional name and then renamed to
        a name that embeds its device and inode numbers.

        :return: tuple ``(file object opened for writing, file name)``; the
            name is relative to the ``tmp`` directory.
        :raises OSError: when the file cannot be created or renamed.
        """
        while True:
            tmpname = self._get_new_name()
            tmppath = os.path.join(self.tmp, tmpname)
            try:
                fd = os.open(tmppath, os.O_CREAT | os.O_RDWR | os.O_EXCL)
            except OSError as e:
                # A name clash just means we try again with a new timestamp.
                if e.errno != errno.EEXIST:
                    raise
                continue
            break
        try:
            newname = self._get_new_name(fd)
            os.rename(tmppath, os.path.join(self.tmp, newname))
        except OSError:
            # Do not leak the descriptor when the rename step fails.
            os.close(fd)
            raise
        return os.fdopen(fd, "w"), newname

    def publish_file(self, short_name):
        """
        Move the file *short_name* from ``tmp`` to ``incoming``.

        :param short_name: name returned by :meth:`create_unique_file`.
        :raises OSError: when the rename fails.
        """
        os.rename(os.path.join(self.tmp, short_name), os.path.join(self.incoming, short_name))

class Output(cowrie.core.output.Output):
    """
    Wardenfiler Output

    Converts Cowrie log entries into IDEA events and stores them through a
    ``Filer`` spool. Connections are counted per source/target pair and
    reported in aggregated form once per aggregation window; successful
    logins and captured files are reported as individual events.
    """
    # Reported as the "Name" of the IDEA node
    detector_name = None
    # Ask the service at nat_host:nat_port for the destination address
    resolve_nat = False
    # Replace the destination address for IPv4 / IPv6 sources when set
    reported_public_ipv4 = None
    reported_public_ipv6 = None
    # Replace the destination port when set
    reported_ssh_port = None
    nat_host = "gateway"
    nat_port = 1456
    # Prefix lengths applied to the destination address before reporting
    anon_mask_4 = 32
    anon_mask_6 = 128
    # Length of the aggregation window in seconds
    aggr_win = 5 * 60
    # Add the "Test" category to every event
    test_mode = True
    output_dir = "var/spool/warden"
    # Remove captured files from disk after they were read
    drop_malware = True
    # Timestamp at which the current aggregation window started
    win_start = None
    # connection count per aggregation ID ("src_ip,dst_ip")
    attackers = {}
    # aggregated credentials from failed attempts per IP
    attackers_creds = {}
    # connect entry per session ID
    sessions = {}
    # destination port translation applied to compromise reports
    port_xlat = {}

    # Default ports of the URL schemes reported for downloads
    _url_default_ports = {"http": 80, "https": 443, "ftp": 21}

    @staticmethod
    def _address_family(ip):
        """Return the IDEA address key ("IP4" or "IP6") for address *ip*."""
        return "IP6" if ':' in ip else "IP4"

    @staticmethod
    def _resolve_host(host):
        """
        Resolve *host* to one address.

        :return: tuple ``(address, "IP4" | "IP6")``, or ``(None, None)`` when
            *host* is empty or cannot be resolved.
        """
        if not host:
            return None, None
        try:
            info = socket.getaddrinfo(host, None)[0]
        except (socket.gaierror, UnicodeError):
            return None, None
        family = "IP6" if info[0] == socket.AddressFamily.AF_INET6 else "IP4"
        return info[4][0], family

    def save_event(self, event):
        """Write *event* as JSON into a new spool file and publish it."""
        f, name = self.filer.create_unique_file()
        with f:
            f.write(json.dumps(event, ensure_ascii=True))
        self.filer.publish_file(name)

    def start(self):
        """
        Read the ``output_wardenfiler`` configuration section and open the spool.

        Options missing from the configuration keep their class defaults.
        ``port_xlat`` is a whitespace separated list of ``from:to`` port pairs.
        """
        section = 'output_wardenfiler'
        readers = (
            ('detector_name', CowrieConfig.get),
            ('resolve_nat', CowrieConfig.getboolean),
            ('reported_public_ipv4', CowrieConfig.get),
            ('reported_public_ipv6', CowrieConfig.get),
            ('reported_ssh_port', CowrieConfig.getint),
            ('nat_host', CowrieConfig.get),
            ('nat_port', CowrieConfig.getint),
            ('anon_mask_4', CowrieConfig.getint),
            ('anon_mask_6', CowrieConfig.getint),
            ('aggr_win', CowrieConfig.getint),
            ('test_mode', CowrieConfig.getboolean),
            ('output_dir', CowrieConfig.get),
            ('drop_malware', CowrieConfig.getboolean),
        )
        for option, read in readers:
            if CowrieConfig.has_option(section, option):
                setattr(self, option, read(section, option))
        if CowrieConfig.has_option(section, 'port_xlat'):
            pairs = (item.split(':') for item in CowrieConfig.get(section, 'port_xlat').split())
            self.port_xlat = {int(src): int(dst) for src, dst in pairs}

        # Keep the aggregation state on the instance instead of mutating the
        # dictionaries shared at class level.
        self.attackers = {}
        self.attackers_creds = {}
        # NOTE(review): the base class may already keep a ``sessions`` mapping
        # on the instance - confirm; only create one when none exists yet.
        if "sessions" not in vars(self):
            self.sessions = {}

        self.filer = Filer(self.output_dir)

    def stop(self):
        """
        No actions needed on honeypot shutdown
        """

    def write(self, entry):
        """
        Process one Cowrie log entry.

        ``cowrie.session.connect`` starts tracking of a session; every other
        entry is ignored unless its session is tracked.

        :param entry: log entry dictionary; it is normalised in place (IPv4
            mapped prefix stripped, destination overridden as configured).
        """
        event = {
            "Format": "IDEA0",
            "ID": str(uuid4()),
            "DetectTime": entry['timestamp'],
            "Category": [],
            "Source": [{"Proto": ["tcp", "ssh"]}],
            "Target": [{"Proto": ["tcp", "ssh"]}],
            "Node": [
                {
                    "Name": self.detector_name,
                    "Type": ["Connection", "Auth", "Honeypot"],
                    "SW": ["Cowrie with Warden Filer output module"],
                }
            ]
        }

        if self.test_mode:
            event["Category"].append("Test")

        # Strip the IPv4-mapped IPv6 prefix
        if entry["src_ip"].startswith("::ffff:"):
            entry["src_ip"] = entry["src_ip"][7:]
        if entry.get("dst_ip") and entry["dst_ip"].startswith("::ffff:"):
            entry["dst_ip"] = entry["dst_ip"][7:]

        # detect IPv4 or IPv6
        src_af = self._address_family(entry["src_ip"])

        # If configured, override destination IP and port
        if entry.get("dst_ip"):
            if src_af == "IP4" and self.reported_public_ipv4:
                entry["dst_ip"] = self.reported_public_ipv4
            elif src_af == "IP6" and self.reported_public_ipv6:
                entry["dst_ip"] = self.reported_public_ipv6

        if entry.get("dst_port") and self.reported_ssh_port:
            entry["dst_port"] = self.reported_ssh_port

        eventid = entry["eventid"]
        if eventid == 'cowrie.session.connect':
            self._on_connect(entry, event)
            return

        if entry["session"] not in self.sessions:
            # We do not save sessions
            # that were created during previous Cowrie runs
            # and we should not care about them.
            return
        session = self.sessions[entry["session"]]

        if eventid == 'cowrie.login.success':
            u, p = entry["username"], entry["password"]
            session["input"] = []
            session["loggedin"] = True
            session["credentials"].append({"Username": u, "Password": p, "Type": ["AcceptedByServer"]})
        elif eventid == "cowrie.login.failed":
            u, p = entry["username"], entry["password"]
            session["credentials"].append({"Username": u, "Password": p})
        elif eventid == 'cowrie.command.input':
            session.setdefault("input", []).append(entry["input"])
        elif eventid == 'cowrie.session.file_download':
            self._on_file_download(entry, event, session, src_af)
        elif eventid == 'cowrie.session.file_upload':
            self._on_file_upload(entry, event, session, src_af)
        elif eventid == 'cowrie.session.closed':
            self._on_closed(entry, event, session, src_af)

    def _query_nat(self, src_ip, src_port):
        """Send "src_ip,src_port" to the NAT helper and return its reply."""
        # NOTE(review): the connection has no timeout, so an unresponsive
        # helper blocks this call - consider setting one.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as conn:
            conn.connect((self.nat_host, self.nat_port))
            conn.sendall(','.join((src_ip, str(src_port))).encode("utf-8"))
            return conn.recv(50).decode("utf-8")

    def _on_connect(self, entry, event):
        """Start tracking a session and account it in the aggregation window."""
        # Do not track a session for a source
        # which is not globally routable
        if not ip_address(entry["src_ip"]).is_global:
            return

        if self.resolve_nat:
            dst = self._query_nat(entry["src_ip"], entry["src_port"])
            # The reply "NE" means that the helper has no destination for us.
            if dst == "NE":
                return
            entry["dst_ip"] = dst

        # Anonymise the destination by cutting it to the configured prefix
        dst_ip = entry["dst_ip"]
        if ':' in dst_ip:
            network = IPv6Network("/".join((dst_ip, str(self.anon_mask_6))), False)
        else:
            network = IPv4Network("/".join((dst_ip, str(self.anon_mask_4))), False)
        entry["dst_ip"] = str(network.network_address)

        entry["loggedin"] = False
        entry["credentials"] = []
        # AID - aggregation ID
        entry["aid"] = aid = ','.join((entry["src_ip"], entry["dst_ip"]))
        self.sessions[entry["session"]] = entry

        now = time()
        ws = self.win_start or now
        # aggregated credentials from attempts
        self.attackers_creds.setdefault(aid, [])

        if now - ws < self.aggr_win:
            self.attackers[aid] = self.attackers.get(aid, 0) + 1
        else:
            # This flushes out ALL the aggregated events!
            self._flush_attackers(event, ws)
            # The current connection opens the next window.
            self.attackers = {aid: 1}
            self.attackers_creds = {aid: []}
            ws = now
        self.win_start = ws

    def _flush_attackers(self, event, ws):
        """Report one "Attempt.Login" event per aggregation ID of the window started at *ws*."""
        event["Node"][0]["AggrWin"] = strftime("%H:%M:%S", gmtime(float(self.aggr_win)))
        event["WinStartTime"] = datetime.utcfromtimestamp(ws).isoformat() + 'Z'
        event["WinEndTime"] = datetime.utcfromtimestamp(ws + self.aggr_win).isoformat() + 'Z'
        event["Category"].append("Attempt.Login")
        event["Description"] = "SSH login attempt"
        event["DetectTime"] = event["WinEndTime"]
        for aid, count in self.attackers.items():
            a_src_ip, a_dst_ip = aid.split(',')
            # Each address is filed under the family of that very address.
            dst_af = self._address_family(a_dst_ip)
            target = {"Proto": ["tcp", "ssh"], dst_af: [a_dst_ip]}
            # Only the mask of the target's own address family decides
            # whether the target was anonymised.
            if (dst_af == "IP4" and self.anon_mask_4 < 32) or (dst_af == "IP6" and self.anon_mask_6 < 128):
                target["Anonymised"] = True
            event["ID"] = str(uuid4())
            event["ConnCount"] = count
            event["Source"] = [{"Proto": ["tcp", "ssh"], self._address_family(a_src_ip): [a_src_ip]}]
            event["Target"] = [target]
            a_creds = self.attackers_creds.get(aid)
            if a_creds:
                event["Credentials"] = a_creds
            else:
                # The event dictionary is reused, drop the previous value.
                event.pop("Credentials", None)
            self.save_event(event)

    def _read_outfile(self, entry):
        """Return the content of the captured file of *entry* (None without a file); drop the file if configured."""
        path = entry.get("outfile")
        if not path or not os.path.exists(path):
            return None
        with open(path, "rb") as fp:
            content = fp.read()
        if self.drop_malware:
            os.remove(path)
        return content

    def _on_file_download(self, entry, event, session, src_af):
        """Report a file captured during the session (download, redirect or stdin)."""
        # deal with the file first (drop even if not reported)
        mware = self._read_outfile(entry)
        if not mware:
            return

        # TODO: Classify everything as Malware?
        event["Category"].append("Malware")
        event["Description"] = "Malware download during honeypot session"
        event["DetectTime"] = entry["timestamp"]

        fname = None
        if "url" in entry:
            url = urlparse(entry["url"])
            # Compare the parsed scheme; a plain prefix test would also
            # accept schemes like "ftps" that have no default port here.
            if url.scheme not in self._url_default_ports:
                # TODO: Some exotic protocol? Let's not worry with that now
                return
            url_host = url.hostname
            url_ip, url_af = self._resolve_host(url_host)

            fname = os.path.basename(entry["url"])
            if not fname and 'destfile' in entry:
                fname = os.path.basename(entry['destfile'])

            source = {"Type": ["Malware"], "URL": [entry["url"]]}
            # The host may not resolve any more; report the URL anyway.
            if url_ip:
                source[url_af] = [url_ip]
            source["Proto"] = ["tcp", url.scheme]
            source["Port"] = [url.port or self._url_default_ports[url.scheme]]
            if url_host and url_ip != url_host:
                source["Hostname"] = [url_host]
            del event["Target"]
            event["Source"][0] = source
        else:
            if "destfile" in entry:
                event["Description"] = "Redirected content during honeypot session"
                fname = os.path.basename(entry["destfile"])
            else:
                event["Description"] = "Stdin contents during honeypot session"
            # the source of the malicious activity is the host, we don't have further details to that
            event["Source"][0] = {
                "Type": ["Botnet"],
                src_af: [entry["src_ip"]],
                "Port": [session["src_port"]],
            }

        event["Attach"] = [{
            "Type": ["ShellCode"],
            "Hash": ["sha256:" + entry["shasum"]],
            "Size": len(mware),
            "Description": "Some probably malicious code downloaded during honeypot SSH session",
            "ContentEncoding": "base64",
            "Content": b64encode(mware).decode(),
        }]
        if fname:
            event["Attach"][0]["FileName"] = [fname]
        if "url" in entry:
            event["Attach"][0]["ExternalURI"] = [entry["url"]]
        self.save_event(event)

    def _on_file_upload(self, entry, event, session, src_af):
        """Report a file uploaded to the honeypot through SCP or SFTP."""
        # deal with the file first (drop even if not reported)
        mware = self._read_outfile(entry)
        if not mware:
            return

        event["Category"].append("Malware")
        event["Description"] = "Malware download during honeypot session"
        event["DetectTime"] = entry["timestamp"]
        # the source of the malicious activity is the host, we don't have further details to that
        event["Source"][0] = {
            "Type": ["Botnet"],
            src_af: [entry["src_ip"]],
            "Port": [session["src_port"]],
        }
        event["Attach"] = [{
            "Type": ["ShellCode"],
            "FileName": [entry["filename"]],
            "Hash": ["sha256:" + entry["shasum"]],
            "Size": len(mware),
            "Description": "Some probably malicious code downloaded during honeypot SSH session",
            "ContentEncoding": "base64",
            "Content": b64encode(mware).decode(),
        }]
        self.save_event(event)

    def _on_closed(self, entry, event, session, src_af):
        """Finish a session: keep its credentials for aggregation and report a compromise if it logged in."""
        # Discard the session first, so it is released even when reporting
        # fails below.
        self.sessions.pop(entry["session"], None)
        # Store attempted credentials (all) to the aggregation cache. The
        # cache may have been flushed since the session connected, so the
        # list is created again when missing.
        self.attackers_creds.setdefault(session["aid"], []).extend(session["credentials"])

        if not session["loggedin"]:
            return

        idata = '\n'.join(session["input"])
        plain = all(c in string.printable for c in idata)
        event["Category"].append("Intrusion.UserCompromise")
        event["Description"] = "SSH successful login" + (" with unauthorized command input" if idata else "")
        event["Source"][0][src_af] = [entry["src_ip"]]
        event["Source"][0]["Port"] = [session["src_port"]]
        dst_port = session["dst_port"]
        event["Target"][0][self._address_family(session["dst_ip"])] = [session["dst_ip"]]
        event["Target"][0]["Port"] = [self.port_xlat.get(dst_port, dst_port)]
        if idata:
            raw = idata.encode("utf-8")
            attach = {
                "Type": ["Exploit"],
                "Hash": ["sha1:" + sha1(raw).hexdigest()],
                "Size": len(idata),
                "Description": "Commands entered by attacker during honeypot SSH session",
                # Non-printable input is transported base64 encoded
                "Content": idata if plain else b64encode(raw).decode(),
            }
            if not plain:
                attach["ContentEncoding"] = "base64"
            event["Attach"] = [attach]
        if session["credentials"]:
            # Only the credentials that opened the session are reported here
            event["Credentials"] = [
                c for c in session["credentials"]
                if "AcceptedByServer" in c.get("Type", ())
            ]
        self.save_event(event)