From b50a3fb0c12a72d115741f48d16076b8b4a90a0a Mon Sep 17 00:00:00 2001
From: Jan Zerdik <zerdik@cesnet.cz>
Date: Tue, 4 Feb 2020 12:42:45 +0100
Subject: [PATCH] Changes in filtering events.

(Redmine issue: #4489)
---
 lib/mentat/reports/event.py         | 164 +++++++++++++++++-----------
 lib/mentat/reports/utils.py         |  72 ++++++------
 lib/mentat/services/eventstorage.py |  13 +--
 3 files changed, 142 insertions(+), 107 deletions(-)

diff --git a/lib/mentat/reports/event.py b/lib/mentat/reports/event.py
index 838b79269..6596faa1e 100644
--- a/lib/mentat/reports/event.py
+++ b/lib/mentat/reports/event.py
@@ -24,6 +24,7 @@ import json
 import datetime
 import zipfile
 import csv
+from copy import deepcopy
 
 #
 # Custom libraries
@@ -44,6 +45,7 @@ from mentat.reports.utils import StorageThresholdingCache, NoThresholdingCache
 from mentat.datatype.sqldb import EventReportModel
 from mentat.emails.event import ReportEmail
 from mentat.reports.base import BaseReporter
+from mentat.services.eventstorage import record_to_idea
 
 
 REPORT_SUBJECT_SUMMARY = tr_("[{:s}] {:s} - Notice about possible problems in your network")
@@ -249,8 +251,8 @@ class EventReporter(BaseReporter):
             if not events_fetched:
                 break
 
-            # B: Perform event filtering according to custom group filters.
-            events_flt, fltlog = self.filter_events(events_fetched, abuse_group, settings)
+            # B: Perform event filtering according to custom group filters and aggregate by source.
+            events_flt, events_aggr, fltlog = self.filter_events(events_fetched, abuse_group, settings)
             result['evcount_flt'] = len(events_flt)
             result['evcount_flt_blk'] = len(events_fetched) - len(events_flt)
             result['filtering'] = fltlog
@@ -258,14 +260,14 @@ class EventReporter(BaseReporter):
                 break
 
             # C: Perform event thresholding.
-            events_thr = self.threshold_events(events_flt, abuse_group, severity, time_h)
+            events_thr, events_aggr = self.threshold_events(events_aggr, abuse_group, severity, time_h)
+
             result['evcount_thr'] = len(events_thr)
             result['evcount_thr_blk'] = len(events_flt) - len(events_thr)
             if not events_thr:
                 break
 
-            # D: Aggregate events by sources for further processing.
-            events_aggr = self.aggregate_events_by_source(events_thr, settings)
+            # D: Save aggregated events for further processing.
             events['regular'] = events_thr
             events['regular_aggr'] = events_aggr
 
@@ -280,7 +282,7 @@ class EventReporter(BaseReporter):
                 break
 
             # B: Aggregate events by sources for further processing.
-            events_aggr = self.aggregate_events_by_source(events_rel, settings)
+            events_rel, events_aggr = self.aggregate_relapsed_events(events_rel)
             events['relapsed'] = events_rel
             events['relapsed_aggr'] = events_aggr
 
@@ -521,6 +523,18 @@ class EventReporter(BaseReporter):
             )
         return events
 
+    @staticmethod
+    def _whois_filter(sources, src, whoismodule, whoismodule_cache):
+        """
+        Help method for filtering sources by abuse group's networks
+        """
+        if src not in whoismodule_cache:
+            # Source IP must belong to network range of given abuse group.
+            whoismodule_cache[src] = bool(whoismodule.lookup(src))
+        if whoismodule_cache[src]:
+            sources.add(src)
+        return sources
+
     def filter_events(self, events, abuse_group, settings):
         """
         Filter given list of IDEA events according to given abuse group settings.
@@ -528,19 +542,52 @@ class EventReporter(BaseReporter):
         :param list events: List of IDEA events as :py:class:`mentat.idea.internal.Idea` objects.
         :param mentat.datatype.sqldb.GroupModel: Abuse group.
         :param mentat.reports.event.ReportingSettings settings: Reporting settings.
-        :return: Tuple with list of events that passed filtering and filtering log as a dictionary.
+        :return: Tuple with list of events that passed filtering, aggregation of them and filtering log as a dictionary.
         :rtype: tuple
         """
+        whoismodule = mentat.services.whois.WhoisModule()
+        networks = settings.setup_networks()
+        whoismodule.setup(networks)
+        whoismodule_cache = {}
+
         filter_list = settings.setup_filters(self.filter_parser, self.filter_compiler)
         result = []
         fltlog = {}
+        aggregated_result = {}
         for event in events:
             match = self.filter_event(filter_list, event)
+            sources = set()
             if match:
-                fltlog[match] = fltlog.get(match, 0) + 1
-                self.logger.debug("Event matched filtering rule '%s'", match)
-                continue
-            result.append(event)
+                if len(jpath_values(event, 'Source.IP4') + jpath_values(event, 'Source.IP6')) > 1:
+                    event_copy = deepcopy(event)
+                    for s in event_copy["Source"]:
+                        s["IP4"] = []
+                        s["IP6"] = []
+                    for src in set(jpath_values(event, 'Source.IP4')):
+                        event_copy["Source"][0]["IP4"] = [src]
+                        if not self.filter_event(filter_list, event_copy, False):
+                            sources = self._whois_filter(sources, src, whoismodule, whoismodule_cache)
+                    event_copy["Source"][0]["IP4"] = []
+                    for src in set(jpath_values(event, 'Source.IP6')):
+                        event_copy["Source"][0]["IP6"] = [src]
+                        if not self.filter_event(filter_list, event_copy, False):
+                            sources = self._whois_filter(sources, src, whoismodule, whoismodule_cache)
+
+                if sources:
+                    self.logger.debug("Event matched filtering rule '%s', some sources allowed through", match)
+                else:
+                    self.logger.debug("Event matched filtering rule '%s', all sources filtered", match)
+                    fltlog[match] = fltlog.get(match, 0) + 1
+            else:
+                for src in set(jpath_values(event, 'Source.IP4') + jpath_values(event, 'Source.IP6')):
+                    sources = self._whois_filter(sources, src, whoismodule, whoismodule_cache)
+
+            if sources:
+                result.append(event)
+                for src in sources:
+                    if str(src) not in aggregated_result:
+                        aggregated_result[str(src)] = []
+                    aggregated_result[str(src)].append(event)
 
         if result:
             self.logger.info(
@@ -555,40 +602,48 @@ class EventReporter(BaseReporter):
                 abuse_group.name,
                 len(events)
             )
-        return result, fltlog
+        return result, aggregated_result, fltlog
 
-    def threshold_events(self, events, abuse_group, severity, time_h):
+    def threshold_events(self, events_aggr, abuse_group, severity, time_h):
         """
         Threshold given list of IDEA events according to given abuse group settings.
 
-        :param list events: List of IDEA events as :py:class:`mentat.idea.internal.Idea` objects.
+        :param dict events_aggr: Aggregation of IDEA events as :py:class:`mentat.idea.internal.Idea` objects by source.
         :param mentat.datatype.sqldb.GroupModel: Abuse group.
         :param str severity: Severity for which to perform reporting.
         :param datetime.datetime time_h: Upper reporting time threshold.
         :return: List of events that passed thresholding.
         :rtype: list
         """
-        result = []
-        for event in events:
-            if not self.tcache.event_is_thresholded(event, time_h):
-                result.append(event)
-            else:
-                self.tcache.threshold_event(event, abuse_group.name, severity, time_h)
-
+        result = {}
+        aggregated_result = {}
+        filtered = set()
+        for source, events in events_aggr.items():
+            for event in events:
+                if not self.tcache.event_is_thresholded(event, source, time_h):
+                    if source not in aggregated_result:
+                        aggregated_result[source] = []
+                    aggregated_result[source].append(event)
+                    result[event["ID"]] = event
+                else:
+                    filtered.add(event["ID"])
+                    self.tcache.threshold_event(event, source, abuse_group.name, severity, time_h)
+
+        filtered -= set(result.keys())
         if result:
             self.logger.info(
                 "%s: Thresholds let %d events through, %d blocked.",
                 abuse_group.name,
                 len(result),
-                (len(events) - len(result))
+                len(filtered)
             )
         else:
             self.logger.info(
                 "%s: Thresholds blocked all %d events, nothing to report.",
                 abuse_group.name,
-                len(events)
+                len(filtered)
             )
-        return result
+        return list(result.values()), aggregated_result
 
     def relapse_events(self, abuse_group, severity, time_h):
         """
@@ -622,6 +677,23 @@ class EventReporter(BaseReporter):
             )
         return events
 
+    def aggregate_relapsed_events(self, relapsed):
+        """
+        :param list relapsed: List of relapsed events with their threshold keys.
+        :return: Events aggregated by source.
+        :rtype: dict
+        """
+        result = []
+        aggregated_result = {}
+        for event in relapsed:
+            result.append(record_to_idea(event))
+            for key in event.keyids:
+                source = self.tcache.get_source_from_cache_key(key)
+                if source not in aggregated_result:
+                    aggregated_result[source] = []
+                aggregated_result[source].append(result[-1])
+        return result, aggregated_result
+
     def update_thresholding_cache(self, events, settings, severity, time_h):
         """
         :param dict events: Dictionary structure with IDEA events that were reported.
@@ -642,57 +714,23 @@ class EventReporter(BaseReporter):
     #---------------------------------------------------------------------------
 
 
-    def filter_event(self, filter_rules, event):
+    def filter_event(self, filter_rules, event, to_db=True):
         """
         Filter given event according to given list of filtering rules.
 
         :param list filter_rules: Filters to be used.
         :param mentat.idea.internal.Idea: Event to be filtered.
+        :param bool to_db: Save hit to db.
         :return: ``True`` in case any filter matched, ``False`` otherwise.
         :rtype: bool
         """
         for flt in filter_rules:
             if self.filter_worker.filter(flt[1], event):
-                flt[0].hits += 1
+                if to_db:
+                    flt[0].hits += 1
                 return flt[0].name
         return False
 
-    @staticmethod
-    def aggregate_events_by_source(events, settings):
-        """
-        Aggregate given list of events to dictionary structure according to the
-        IPv4 and IPv6 sources. The resulting structure contains event source
-        addresses (value of ``Source.IP4``  and ``Source.IP6`` attributes) as keys
-        and list of events with given source in no particular order. For example::
-
-            {
-                '192.168.1.1': [...],
-                '::1': [...],
-                ...
-            }
-
-        :param list events: List of events as :py:class:`mentat.idea.internal.Idea` objects.
-        :param mentat.reports.event.ReportingSettings settings: Reporting settings.
-        :return: Dictionary structure of aggregated events.
-        :rtype: dict
-        """
-        whoismodule = mentat.services.whois.WhoisModule()
-        networks = settings.setup_networks()
-        whoismodule.setup(networks)
-
-        result = {}
-        for event in events:
-            sources_ip4 = jpath_values(event, 'Source.IP4')
-            sources_ip6 = jpath_values(event, 'Source.IP6')
-            for src in sources_ip4 + sources_ip6:
-                # Source IP must belong to network range of given abuse group.
-                if not whoismodule.lookup(src):
-                    continue
-                if str(src) not in result:
-                    result[str(src)] = []
-                result[str(src)].append(event)
-        return result
-
     @staticmethod
     def aggregate_events(events):
         """
@@ -738,7 +776,7 @@ class EventReporter(BaseReporter):
                 ip_result["count"] += 1
                 # Name of last node for identify unique detector names
                 ip_result["detectors_count"][event.get("Node", [{}])[-1].get("Name")] = 1
-                ip_result["approx_conn_count"] += event["ConnCount"] if event.get("ConnCount") else int(event.get("FlowCount", 0) / 2) 
+                ip_result["approx_conn_count"] += event["ConnCount"] if event.get("ConnCount") else int(event.get("FlowCount", 0) / 2)
 
                 for data_key, idea_key in (("conn_count", "ConnCount"), ("flow_count", "FlowCount"), ("packet_count", "PacketCount"), ("byte_count", "ByteCount")):
                     ip_result[data_key] += event.get(idea_key, 0)
diff --git a/lib/mentat/reports/utils.py b/lib/mentat/reports/utils.py
index f691a5094..52e409677 100644
--- a/lib/mentat/reports/utils.py
+++ b/lib/mentat/reports/utils.py
@@ -306,22 +306,18 @@ class ThresholdingCache:
     reporting.
     """
 
-    def event_is_thresholded(self, event, ttl):
+    def event_is_thresholded(self, event, source, ttl):
         """
-        Check, that given event is thresholded within given TTL.
+        Check whether the given combination of event and source is thresholded within the given TTL.
 
         :param mentat.idea.internal.Idea event: IDEA event to check.
+        :param str source: Source to check.
         :param datetime.datetime ttl: TTL for the thresholding record.
         :return: ``True`` in case the event is thresholded, ``False`` otherwise.
         :rtype: bool
         """
-        cachekeys = self._generate_cache_keys(event)
-        result = True
-        for chk in cachekeys:
-            res = self.check(chk, ttl)
-            if not res:
-                result = False
-        return result
+        cachekey = self._generate_cache_key(event, source)
+        return self.check(cachekey, ttl)
 
     def set_threshold(self, event, source, thresholdtime, relapsetime, ttl):
         """
@@ -333,24 +329,21 @@ class ThresholdingCache:
         :param datetime.datetime relapsetime: Relapse window start time.
         :param datetime.datetime ttl: Record TTL.
         """
-        if source:
-            source = [source]
-        cachekeys = self._generate_cache_keys(event, source)
-        for chk in cachekeys:
-            self.set(chk, thresholdtime, relapsetime, ttl)
+        cachekey = self._generate_cache_key(event, source)
+        self.set(cachekey, thresholdtime, relapsetime, ttl)
 
-    def threshold_event(self, event, group_name, severity, createtime):
+    def threshold_event(self, event, source, group_name, severity, createtime):
         """
         Threshold given event with given TTL.
 
         :param mentat.idea.internal.Idea event: IDEA event to threshold.
+        :param str source: Source address for which to threshold the event.
         :param str group_name: Name of the group for which to threshold.
         :param str severity: Event severity.
         :param datetime.datetime createtime: Thresholding timestamp.
         """
-        cachekeys = self._generate_cache_keys(event)
-        for chk in cachekeys:
-            self.save(event.get_id(), chk, group_name, severity, createtime)
+        cachekey = self._generate_cache_key(event, source)
+        self.save(event.get_id(), cachekey, group_name, severity, createtime)
 
     #---------------------------------------------------------------------------
 
@@ -394,8 +387,8 @@ class ThresholdingCache:
         :param str group_name: Name of the abuse group.
         :param str severity: Event severity.
         :param datetime.datetime ttl: Record TTL time.
-        :return: List of relapsed events as :py:class:`mentat.idea.internal.Idea` objects.
-        :rtype: list
+        :return: Tuple with list of relapsed events as :py:class:`mentat.idea.internal.Idea` objects and their aggregation by keyid.
+        :rtype: tuple
         """
         raise NotImplementedError()
 
@@ -409,25 +402,29 @@ class ThresholdingCache:
 
     #---------------------------------------------------------------------------
 
-    def _generate_cache_keys(self, event, sources = None):
+    def _generate_cache_key(self, event, source):
         """
-        Generate list of thresholding/relapse cache keys for given event.
+        Generate cache key for given event and source.
 
         :param mentat.idea.internal.Idea event: Event to process.
-        :return: List of cache keys as strings.
-        :rtype: list
+        :param str source: Source to process.
+        :return: Cache key as a string.
+        :rtype: str
         """
-        keys = []
         event_class = jpath_value(event, '_CESNET.EventClass')
         if not event_class:
             event_class = '/'.join(jpath_values(event, 'Category'))
-        if not sources:
-            sources_ip4 = jpath_values(event, 'Source.IP4')
-            sources_ip6 = jpath_values(event, 'Source.IP6')
-            sources = sources_ip4 + sources_ip6
-        for src in sources:
-            keys.append('+++'.join((event_class, str(src))))
-        return keys
+        return '+++'.join((event_class, str(source)))
+
+    def get_source_from_cache_key(self, key):
+        """
+        Return the source from which the key was generated.
+
+        :param str key: Cache key.
+        :return: Cached source.
+        :rtype: str
+        """
+        return key.split('+++')[1] if key and len(key.split('+++')) > 1 else key
 
 
 class NoThresholdingCache(ThresholdingCache):
@@ -517,15 +514,16 @@ class SingleSourceThresholdingCache(SimpleMemoryThresholdingCache):
         super().__init__()
         self.source = source
 
-    def _generate_cache_keys(self, event, sources = None):
+    def _generate_cache_key(self, event, source):
         """
-        Generate list of thresholding/relapse cache keys for given event.
+        Generate cache key for given event and source.
 
         :param mentat.idea.internal.Idea event: Event to process.
-        :return: List of cache keys as strings.
-        :rtype: list
+        :param str source: Source to process.
+        :return: Cache key as a string.
+        :rtype: str
         """
-        return super()._generate_cache_keys(event, [self.source])
+        return super()._generate_cache_key(event, self.source)
 
 
 class StorageThresholdingCache(ThresholdingCache):
diff --git a/lib/mentat/services/eventstorage.py b/lib/mentat/services/eventstorage.py
index 4e884327d..b8e987ee5 100644
--- a/lib/mentat/services/eventstorage.py
+++ b/lib/mentat/services/eventstorage.py
@@ -519,7 +519,7 @@ class EventStorageCursor:
         query, params  = build_query(parameters, qtype = 'count', qname = qname)
         self.lastquery = self.cursor.mogrify(query, params)
         self.cursor.execute(query, params)
-        
+
         record = self.cursor.fetchone()
         if record:
             return record[0]
@@ -551,7 +551,7 @@ class EventStorageCursor:
 
     def delete_events(self, parameters = None, qname = None):
         """
-        Delete IDEA events in database according to given parameters. There is 
+        Delete IDEA events in database according to given parameters. There is
         an option to assign given unique name to the query, so that it can be
         identified within the ``pg_stat_activity`` table.
 
@@ -717,15 +717,14 @@ class EventStorageCursor:
         :param str group_name: Name of the abuse group.
         :param str severity: Event severity.
         :param datetime.datetime ttl: Record TTL time.
-        :return: List of relapsed events as :py:class:`mentat.idea.internal.Idea` objects.
+        :return: List of relapsed events as tuples of id, JSON of event data and list of threshold keys.
         :rtype: list
         """
         self.cursor.execute(
-            "SELECT events.*, events_json.event FROM events INNER JOIN events_json ON events.id = events_json.id WHERE events.id IN (SELECT DISTINCT eventid FROM events_thresholded WHERE keyid IN (SELECT keyid FROM events_thresholded INNER JOIN thresholds ON (events_thresholded.keyid = thresholds.id) WHERE events_thresholded.groupname = %s AND events_thresholded.eventseverity = %s AND events_thresholded.createtime >= thresholds.relapsetime AND thresholds.ttltime <= %s))",
+            "SELECT events_json.id, events_json.event, ARRAY_AGG(events_thresholded.keyid) AS keyids FROM events_json INNER JOIN events_thresholded ON events_json.id = events_thresholded.eventid INNER JOIN thresholds ON events_thresholded.keyid = thresholds.id WHERE events_thresholded.groupname = %s AND events_thresholded.eventseverity = %s AND events_thresholded.createtime >= thresholds.relapsetime AND thresholds.ttltime <= %s GROUP BY events_json.id",
             (group_name, severity, ttl)
         )
-        events_raw  = self.cursor.fetchall()
-        return [record_to_idea(event) for event in events_raw]
+        return self.cursor.fetchall()
 
     def thresholded_events_count(self):
         """
@@ -1102,7 +1101,7 @@ class EventStorageService:
             try:
                 self.cursor.insert_event(idea_event)
                 self.savepoint_create()
-                return 
+                return
 
             except psycopg2.DataError as err:
                 self.savepoint_rollback()
-- 
GitLab