diff --git a/lib/mentat/reports/event.py b/lib/mentat/reports/event.py
index 838b7926925210baff2039ae875513ba537aa8c7..6596faa1e3d54ce29590133a5f0dc17666fa7ecb 100644
--- a/lib/mentat/reports/event.py
+++ b/lib/mentat/reports/event.py
@@ -24,6 +24,7 @@ import json
 import datetime
 import zipfile
 import csv
+from copy import deepcopy
 
 #
 # Custom libraries
 #
@@ -44,6 +45,7 @@ from mentat.reports.utils import StorageThresholdingCache, NoThresholdingCache
 from mentat.datatype.sqldb import EventReportModel
 from mentat.emails.event import ReportEmail
 from mentat.reports.base import BaseReporter
+from mentat.services.eventstorage import record_to_idea
 
 
 REPORT_SUBJECT_SUMMARY = tr_("[{:s}] {:s} - Notice about possible problems in your network")
@@ -249,8 +251,8 @@ class EventReporter(BaseReporter):
             if not events_fetched:
                 break
 
-            # B: Perform event filtering according to custom group filters.
-            events_flt, fltlog = self.filter_events(events_fetched, abuse_group, settings)
+            # B: Perform event filtering according to custom group filters and aggregate by source.
+            events_flt, events_aggr, fltlog = self.filter_events(events_fetched, abuse_group, settings)
             result['evcount_flt'] = len(events_flt)
             result['evcount_flt_blk'] = len(events_fetched) - len(events_flt)
             result['filtering'] = fltlog
@@ -258,14 +260,14 @@ class EventReporter(BaseReporter):
                 break
 
             # C: Perform event thresholding.
-            events_thr = self.threshold_events(events_flt, abuse_group, severity, time_h)
+            events_thr, events_aggr = self.threshold_events(events_aggr, abuse_group, severity, time_h)
+
             result['evcount_thr'] = len(events_thr)
             result['evcount_thr_blk'] = len(events_flt) - len(events_thr)
             if not events_thr:
                 break
 
-            # D: Aggregate events by sources for further processing.
-            events_aggr = self.aggregate_events_by_source(events_thr, settings)
+            # D: Save aggregated events for further processing.
             events['regular'] = events_thr
             events['regular_aggr'] = events_aggr
@@ -280,7 +282,7 @@ class EventReporter(BaseReporter):
                 break
 
             # B: Aggregate events by sources for further processing.
-            events_aggr = self.aggregate_events_by_source(events_rel, settings)
+            events_rel, events_aggr = self.aggregate_relapsed_events(events_rel)
             events['relapsed'] = events_rel
             events['relapsed_aggr'] = events_aggr
@@ -521,6 +523,18 @@ class EventReporter(BaseReporter):
         )
         return events
 
+    @staticmethod
+    def _whois_filter(sources, src, whoismodule, whoismodule_cache):
+        """
+        Helper method for filtering sources by the abuse group's networks.
+        """
+        if src not in whoismodule_cache:
+            # Source IP must belong to network range of given abuse group.
+            whoismodule_cache[src] = bool(whoismodule.lookup(src))
+        if whoismodule_cache[src]:
+            sources.add(src)
+        return sources
+
     def filter_events(self, events, abuse_group, settings):
         """
         Filter given list of IDEA events according to given abuse group settings.
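
Note: the `_whois_filter` helper above memoizes whois lookups, so each distinct
source address is resolved at most once per `filter_events()` call even when it
appears in many events. A minimal sketch of that contract (hypothetical values,
assuming an already configured `mentat.services.whois.WhoisModule` instance;
not part of the patch):

    sources = set()
    cache = {}
    for src in ('10.0.2.1', '10.0.2.1', '192.0.2.7'):
        sources = EventReporter._whois_filter(sources, src, whoismodule, cache)
    # whoismodule.lookup() ran once per distinct address (two calls), and
    # 'sources' now holds only the addresses found in the group's networks.
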
@@ -528,19 +542,52 @@ class EventReporter(BaseReporter):
         :param list events: List of IDEA events as :py:class:`mentat.idea.internal.Idea` objects.
         :param mentat.datatype.sqldb.GroupModel: Abuse group.
         :param mentat.reports.event.ReportingSettings settings: Reporting settings.
-        :return: Tuple with list of events that passed filtering and filtering log as a dictionary.
+        :return: Tuple with the list of events that passed filtering, their aggregation by source and the filtering log as a dictionary.
         :rtype: tuple
         """
+        whoismodule = mentat.services.whois.WhoisModule()
+        networks = settings.setup_networks()
+        whoismodule.setup(networks)
+        whoismodule_cache = {}
+
         filter_list = settings.setup_filters(self.filter_parser, self.filter_compiler)
         result = []
         fltlog = {}
+        aggregated_result = {}
         for event in events:
             match = self.filter_event(filter_list, event)
+            sources = set()
             if match:
-                fltlog[match] = fltlog.get(match, 0) + 1
-                self.logger.debug("Event matched filtering rule '%s'", match)
-                continue
-            result.append(event)
+                if len(jpath_values(event, 'Source.IP4') + jpath_values(event, 'Source.IP6')) > 1:
+                    event_copy = deepcopy(event)
+                    for s in event_copy["Source"]:
+                        s["IP4"] = []
+                        s["IP6"] = []
+                    for src in set(jpath_values(event, 'Source.IP4')):
+                        event_copy["Source"][0]["IP4"] = [src]
+                        if not self.filter_event(filter_list, event_copy, False):
+                            sources = self._whois_filter(sources, src, whoismodule, whoismodule_cache)
+                    event_copy["Source"][0]["IP4"] = []
+                    for src in set(jpath_values(event, 'Source.IP6')):
+                        event_copy["Source"][0]["IP6"] = [src]
+                        if not self.filter_event(filter_list, event_copy, False):
+                            sources = self._whois_filter(sources, src, whoismodule, whoismodule_cache)
+
+                if sources:
+                    self.logger.debug("Event matched filtering rule '%s', some sources allowed through", match)
+                else:
+                    self.logger.debug("Event matched filtering rule '%s', all sources filtered", match)
+                    fltlog[match] = fltlog.get(match, 0) + 1
+            else:
+                for src in set(jpath_values(event, 'Source.IP4') + jpath_values(event, 'Source.IP6')):
+                    sources = self._whois_filter(sources, src, whoismodule, whoismodule_cache)
+
+            if sources:
+                result.append(event)
+                for src in sources:
+                    if str(src) not in aggregated_result:
+                        aggregated_result[str(src)] = []
+                    aggregated_result[str(src)].append(event)
 
         if result:
             self.logger.info(
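
Note: the rewritten `filter_events()` no longer drops a multi-source event
outright when a filter matches it. Each source is re-tested in isolation on a
single-source copy of the event, and only the sources that match no filter on
their own (and pass the whois check) are kept. A worked sketch with
hypothetical values:

    # An event with Source.IP4 == ['10.0.0.0/25', '10.0.2.1'] matches the
    # filter 'Source.IP4 IN [10.0.0.0/24]' as a whole; single-source copies
    # are then tested one by one:
    #   copy with IP4 == ['10.0.0.0/25'] -> still matches -> source dropped
    #   copy with IP4 == ['10.0.2.1']    -> no match      -> source kept
    # The event is reported and aggregated only under '10.0.2.1', and the
    # match is not counted in fltlog because some sources passed.
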
@@ -555,40 +602,48 @@ class EventReporter(BaseReporter):
                 abuse_group.name,
                 len(events)
             )
-        return result, fltlog
+        return result, aggregated_result, fltlog
 
-    def threshold_events(self, events, abuse_group, severity, time_h):
+    def threshold_events(self, events_aggr, abuse_group, severity, time_h):
         """
         Threshold given list of IDEA events according to given abuse group settings.
 
-        :param list events: List of IDEA events as :py:class:`mentat.idea.internal.Idea` objects.
+        :param dict events_aggr: Aggregation of IDEA events as :py:class:`mentat.idea.internal.Idea` objects by source.
         :param mentat.datatype.sqldb.GroupModel: Abuse group.
         :param str severity: Severity for which to perform reporting.
         :param datetime.datetime time_h: Upper reporting time threshold.
-        :return: List of events that passed thresholding.
-        :rtype: list
+        :return: Tuple with the list of events that passed thresholding and their aggregation by source.
+        :rtype: tuple
         """
-        result = []
-        for event in events:
-            if not self.tcache.event_is_thresholded(event, time_h):
-                result.append(event)
-            else:
-                self.tcache.threshold_event(event, abuse_group.name, severity, time_h)
-
+        result = {}
+        aggregated_result = {}
+        filtered = set()
+        for source, events in events_aggr.items():
+            for event in events:
+                if not self.tcache.event_is_thresholded(event, source, time_h):
+                    if source not in aggregated_result:
+                        aggregated_result[source] = []
+                    aggregated_result[source].append(event)
+                    result[event["ID"]] = event
+                else:
+                    filtered.add(event["ID"])
+                    self.tcache.threshold_event(event, source, abuse_group.name, severity, time_h)
+
+        filtered -= set(result.keys())
         if result:
             self.logger.info(
                 "%s: Thresholds let %d events through, %d blocked.",
                 abuse_group.name,
                 len(result),
-                (len(events) - len(result))
+                len(filtered)
             )
         else:
             self.logger.info(
                 "%s: Thresholds blocked all %d events, nothing to report.",
                 abuse_group.name,
-                len(events)
+                len(filtered)
             )
-        return result
+        return list(result.values()), aggregated_result
 
     def relapse_events(self, abuse_group, severity, time_h):
         """
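
Note: thresholding is now keyed by (event, source) pairs, so a single event can
be reported for one of its sources while staying suppressed for another; the
`filtered` set ends up counting only events that were blocked for all of their
sources. A sketch with hypothetical data:

    # events_aggr = {'10.0.2.1': [event], '10.0.0.0/22': [event]}
    # '10.0.2.1' already thresholded -> threshold_event() records the hit
    # '10.0.0.0/22' not thresholded  -> result[event['ID']] = event
    # filtered -= set(result.keys()) -> the event is not counted as blocked
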
@@ -622,6 +677,23 @@ class EventReporter(BaseReporter):
         )
         return events
 
+    def aggregate_relapsed_events(self, relapsed):
+        """
+        :param list relapsed: List of relapsed event records, each a tuple of event ID, event JSON and list of threshold cache keys.
+        :return: Tuple with the list of relapsed events as :py:class:`mentat.idea.internal.Idea` objects and their aggregation by source.
+        :rtype: tuple
+        """
+        result = []
+        aggregated_result = {}
+        for event in relapsed:
+            result.append(record_to_idea(event))
+            for key in event.keyids:
+                source = self.tcache.get_source_from_cache_key(key)
+                if source not in aggregated_result:
+                    aggregated_result[source] = []
+                aggregated_result[source].append(result[-1])
+        return result, aggregated_result
+
     def update_thresholding_cache(self, events, settings, severity, time_h):
         """
         :param dict events: Dictionary structure with IDEA events that were reported.
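
Note: relapsed events arrive from the storage backend as database rows carrying
the thresholding cache keys they were originally blocked under, and
`aggregate_relapsed_events()` recovers the source from each key to rebuild the
per-source aggregation. An illustrative mapping (hypothetical row):

    # A row with keyids ['class01+++10.0.0.1', 'class01+++10.0.2.1'] yields
    # one Idea object (via record_to_idea) that is appended under both
    # '10.0.0.1' and '10.0.2.1' in the returned aggregation, mirroring the
    # structure filter_events() produces for fresh events.
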
- """ - self.maxDiff = None - - abuse_group = self.sqlstorage.session.query(GroupModel).filter(GroupModel.name == 'abuse@cesnet.cz').one() - self.sqlstorage.session.commit() - - reporting_settings = mentat.reports.utils.ReportingSettings( - abuse_group - ) - - events_aggr = self.reporter.aggregate_events_by_source(self.ideas_obj, reporting_settings) - self.assertEqual(list(sorted(events_aggr.keys())), ['10.0.0.1']) - self.assertEqual(list(sorted(map(lambda x: x['ID'], events_aggr['10.0.0.1']))), ['msg01','msg02']) - - def test_06_filter_events(self): + def test_05_filter_events(self): """ Test :py:class:`mentat.reports.event.EventReporter.filter_events` function. """ @@ -358,28 +343,32 @@ class TestMentatReportsEvent(unittest.TestCase): reporting_settings = mentat.reports.utils.ReportingSettings( abuse_group ) - events, fltlog = self.reporter.filter_events(self.ideas_obj, abuse_group, reporting_settings) - self.assertEqual(events, []) - self.assertEqual(fltlog, {'FLT1': 1, 'FLT2': 1}) + events, aggr, fltlog = self.reporter.filter_events(self.ideas_obj, abuse_group, reporting_settings) + self.assertEqual(fltlog, {'FLT1': 1}) + self.assertEqual(len(aggr), 2) self.reporter.logger.assert_has_calls([ - call.debug("Event matched filtering rule '%s'", 'FLT1'), - call.debug("Event matched filtering rule '%s'", 'FLT2'), - call.info('%s: Filters blocked all %d events, nothing to report.', 'abuse@cesnet.cz', 2) + call.debug("Event matched filtering rule '%s', all sources filtered", 'FLT1'), + call.debug("Event matched filtering rule '%s', some sources allowed through", 'FLT2'), + call.info('%s: Filters let %d events through, %d blocked.', 'abuse@cesnet.cz', 1, 1) ]) self.sqlstorage.session.commit() - events, fltlog = self.reporter.filter_events(self.ideas_obj, abuse_group, reporting_settings) + events, aggr, fltlog = self.reporter.filter_events(self.ideas_obj, abuse_group, reporting_settings) self.sqlstorage.session.commit() flt1 = self.sqlstorage.session.query(FilterModel).filter(FilterModel.name == 'FLT1').one() self.assertEqual(flt1.hits, 2) - events, fltlog = self.reporter.filter_events(self.ideas_obj, abuse_group, reporting_settings) - events, fltlog = self.reporter.filter_events(self.ideas_obj, abuse_group, reporting_settings) + events, aggr, fltlog = self.reporter.filter_events(self.ideas_obj, abuse_group, reporting_settings) + events, aggr, fltlog = self.reporter.filter_events(self.ideas_obj, abuse_group, reporting_settings) self.sqlstorage.session.commit() flt1 = self.sqlstorage.session.query(FilterModel).filter(FilterModel.name == 'FLT1').one() self.assertEqual(flt1.hits, 4) - def test_07_fetch_severity_events(self): + aggr = self.reporter.aggregate_events(aggr) + self.assertEqual(list(sorted(aggr.keys())), ['anomaly-traffic']) + self.assertEqual(list(aggr['anomaly-traffic'].keys()), ['10.0.2.1', '10.0.0.0/22']) + + def test_06_fetch_severity_events(self): """ Test :py:class:`mentat.reports.event.EventReporter.fetch_severity_events` function. """ @@ -412,7 +401,7 @@ class TestMentatReportsEvent(unittest.TestCase): ) self.assertEqual(list(map(lambda x: x['ID'], events)), []) - def test_08_j2t_idea_path_valueset(self): + def test_07_j2t_idea_path_valueset(self): """ Test :py:class:`mentat.reports.event.EventReporter.j2t_idea_path_valueset` function. """ @@ -445,25 +434,7 @@ class TestMentatReportsEvent(unittest.TestCase): ['https', 'ssh'] ) - def test_09_aggregate_events(self): - """ - Test :py:class:`mentat.reports.event.EventReporter.aggregate_events` function. 
- """ - self.maxDiff = None - - abuse_group = self.sqlstorage.session.query(GroupModel).filter(GroupModel.name == 'abuse@cesnet.cz').one() - self.sqlstorage.session.commit() - - reporting_settings = mentat.reports.utils.ReportingSettings( - abuse_group - ) - - events_aggr = self.reporter.aggregate_events_by_source(self.ideas_obj, reporting_settings) - aggr = self.reporter.aggregate_events(events_aggr) - self.assertEqual(list(sorted(aggr.keys())), ['anomaly-traffic', 'class01']) - self.assertEqual(list(sorted([list(v.keys()) for v in aggr.values()])), [['10.0.0.1'], ['10.0.0.1']]) - - def test_10_render_report_summary(self): + def test_08_render_report_summary(self): """ Test :py:class:`mentat.reports.event.EventReporter.render_report_summary` function. """ @@ -504,7 +475,7 @@ class TestMentatReportsEvent(unittest.TestCase): self.assertTrue(report_txt) self.assertEqual(report_txt.split('\n')[0], 'Váženà kolegové.') - def test_11_render_report_extra(self): + def test_09_render_report_extra(self): """ Test :py:class:`mentat.reports.event.EventReporter.render_report_extra` function. """ @@ -596,7 +567,11 @@ class TestMentatReportsEvent(unittest.TestCase): report.statistics = mentat.stats.idea.truncate_evaluations( mentat.stats.idea.evaluate_events(self.ideas_obj) ) - events_aggr = self.reporter.aggregate_events_by_source(self.ideas_obj, self.reporting_settings) + + events_aggr = {} + for obj in self.ideas_obj: + for src in (jpath_values(obj, 'Source.IP4') + jpath_values(obj, 'Source.IP6')): + events_aggr[src] = [obj] report.structured_data = self.reporter.prepare_structured_data(events_aggr, events_aggr, self.reporting_settings) return report diff --git a/lib/mentat/reports/test_utils.py b/lib/mentat/reports/test_utils.py index d21e0b53d07b870b458888cbae5c5a8fc663bd3b..a749ad0be15b3053abc08cd9a1b3dc4e1f965ec4 100644 --- a/lib/mentat/reports/test_utils.py +++ b/lib/mentat/reports/test_utils.py @@ -25,6 +25,7 @@ import pprint # Custom libraries # from pynspect.gparser import PynspectFilterParser +from pynspect.jpath import jpath_values import mentat.const import mentat.services.sqlstorage @@ -175,23 +176,15 @@ class TestMentatReportsUtils(unittest.TestCase): """ self.maxDiff = None + for ip in (jpath_values(self.ideas_raw[0], "Source.IP4") + jpath_values(self.ideas_raw[0], "Source.IP6")): + key = self.stcache._generate_cache_key(self.ideas_raw[0], ip) + self.assertEqual(self.stcache.get_source_from_cache_key(key), ip) self.assertEqual( - self.stcache._generate_cache_keys(self.ideas_raw[0]), # pylint: disable=locally-disabled,protected-access - [ - 'class01+++192.168.0.2-192.168.0.5', - 'class01+++192.168.0.0/25', - 'class01+++10.0.0.1', - 'class01+++2001:db8::ff00:42:0/112' - ] - ) - self.assertEqual( - self.stcache._generate_cache_keys({ # pylint: disable=locally-disabled,protected-access + self.stcache._generate_cache_key({ # pylint: disable=locally-disabled,protected-access 'Category': ['Test', 'Value'], 'Source': [{'IP4': ['195.113.144.194']}] - }), - [ - 'Test/Value+++195.113.144.194' - ] + }, '195.113.144.194'), + 'Test/Value+++195.113.144.194' ) def test_02_no_thr_cache(self): @@ -205,13 +198,16 @@ class TestMentatReportsUtils(unittest.TestCase): thresholdtime = relapsetime - datetime.timedelta(seconds = 600) self.assertFalse( - self.ntcache.event_is_thresholded(self.ideas_obj[0], ttltime) + self.ntcache.event_is_thresholded(self.ideas_obj[0], '192.168.1.1', ttltime) ) self.ntcache.set_threshold(self.ideas_obj[0], '192.168.1.1', thresholdtime, relapsetime, ttltime) self.assertFalse( - 
-            self.ntcache.event_is_thresholded(self.ideas_obj[0], ttltime)
+            self.ntcache.event_is_thresholded(self.ideas_obj[0], '192.168.1.1', ttltime)
+        )
+        self.ntcache.threshold_event(self.ideas_obj[0], '192.168.1.1', 'Test', 'low', datetime.datetime.utcnow())
+        self.assertFalse(
+            self.ntcache.event_is_thresholded(self.ideas_obj[0], '192.168.1.1', ttltime)
         )
-        self.ntcache.threshold_event(self.ideas_obj[0], 'Test', 'low', datetime.datetime.utcnow())
 
     def test_03_storage_thr_cache(self):
         """
@@ -224,36 +220,33 @@ class TestMentatReportsUtils(unittest.TestCase):
         thrtime = reltime - datetime.timedelta(seconds = 300)
 
         self.assertFalse(
-            self.stcache.event_is_thresholded(self.ideas_obj[0], ttltime)
+            self.stcache.event_is_thresholded(self.ideas_obj[0], '192.168.0.2-192.168.0.5', ttltime)
         )
         self.stcache.set_threshold(self.ideas_obj[0], '192.168.0.2-192.168.0.5', thrtime, reltime, ttltime)
-        self.assertFalse(
-            self.stcache.event_is_thresholded(self.ideas_obj[0], ttltime - datetime.timedelta(seconds = 50))
-        )
-        self.stcache.set_threshold(self.ideas_obj[0], '192.168.0.0/25', thrtime, reltime, ttltime)
-        self.assertFalse(
-            self.stcache.event_is_thresholded(self.ideas_obj[0], ttltime - datetime.timedelta(seconds = 50))
+        self.assertTrue(
+            self.stcache.event_is_thresholded(self.ideas_obj[0], '192.168.0.2-192.168.0.5', ttltime - datetime.timedelta(seconds = 50))
         )
-        self.stcache.set_threshold(self.ideas_obj[0], '10.0.0.1', thrtime, reltime, ttltime)
         self.assertFalse(
-            self.stcache.event_is_thresholded(self.ideas_obj[0], ttltime - datetime.timedelta(seconds = 50))
+            self.stcache.event_is_thresholded(self.ideas_obj[0], '192.168.0.0/25', ttltime - datetime.timedelta(seconds = 50))
         )
-        self.stcache.set_threshold(self.ideas_obj[0], '2001:db8::ff00:42:0/112', thrtime, reltime, ttltime)
+        self.stcache.set_threshold(self.ideas_obj[0], '192.168.0.0/25', thrtime, reltime, ttltime)
         self.assertTrue(
-            self.stcache.event_is_thresholded(self.ideas_obj[0], ttltime - datetime.timedelta(seconds = 50))
+            self.stcache.event_is_thresholded(self.ideas_obj[0], '192.168.0.0/25', ttltime - datetime.timedelta(seconds = 50))
         )
-        self.stcache.threshold_event(self.ideas_obj[0], 'test@domain.org', 'low', ttltime - datetime.timedelta(seconds = 50))
+        self.stcache.threshold_event(self.ideas_obj[0], '192.168.0.2-192.168.0.5', 'test@domain.org', 'low', ttltime - datetime.timedelta(seconds = 50))
+        self.stcache.threshold_event(self.ideas_obj[0], '192.168.0.0/25', 'test@domain.org', 'low', ttltime - datetime.timedelta(seconds = 50))
 
         relapses = self.stcache.relapses('test@domain.org', 'low', ttltime + datetime.timedelta(seconds = 50))
         self.assertEqual(len(relapses), 1)
-        self.assertEqual(relapses[0]['ID'], 'msg01')
+        self.assertEqual(relapses[0][0], 'msg01')
+        self.assertEqual(len(relapses[0][2]), 2)
 
-        self.assertEqual(self.stcache.eventservice.thresholds_count(), 4)
-        self.assertEqual(self.stcache.eventservice.thresholded_events_count(), 4)
+        self.assertEqual(self.stcache.eventservice.thresholds_count(), 2)
+        self.assertEqual(self.stcache.eventservice.thresholded_events_count(), 2)
 
         self.assertEqual(
             self.stcache.cleanup(ttltime + datetime.timedelta(seconds = 50)),
-            {'thresholds': 4, 'events': 4}
+            {'thresholds': 2, 'events': 2}
         )
 
     def test_04_reporting_settings(self):
diff --git a/lib/mentat/reports/utils.py b/lib/mentat/reports/utils.py
index f691a50947178ffb92df737176543c24ea72a284..52e409677aa5093452c20e7a377da85f0711fff5 100644
--- a/lib/mentat/reports/utils.py
+++ b/lib/mentat/reports/utils.py
@@ -306,22 +306,18 @@ class ThresholdingCache:
         reporting.
         """
 
-    def event_is_thresholded(self, event, ttl):
+    def event_is_thresholded(self, event, source, ttl):
         """
-        Check, that given event is thresholded within given TTL.
+        Check that the given combination of event and source is thresholded within the given TTL.
 
         :param mentat.idea.internal.Idea event: IDEA event to check.
+        :param str source: Source to check.
         :param datetime.datetime ttl: TTL for the thresholding record.
         :return: ``True`` in case the event is thresholded, ``False`` otherwise.
         :rtype: bool
         """
-        cachekeys = self._generate_cache_keys(event)
-        result = True
-        for chk in cachekeys:
-            res = self.check(chk, ttl)
-            if not res:
-                result = False
-        return result
+        cachekey = self._generate_cache_key(event, source)
+        return self.check(cachekey, ttl)
 
     def set_threshold(self, event, source, thresholdtime, relapsetime, ttl):
         """
@@ -333,24 +329,21 @@ class ThresholdingCache:
         :param datetime.datetime relapsetime: Relapse window start time.
         :param datetime.datetime ttl: Record TTL.
         """
-        if source:
-            source = [source]
-        cachekeys = self._generate_cache_keys(event, source)
-        for chk in cachekeys:
-            self.set(chk, thresholdtime, relapsetime, ttl)
+        cachekey = self._generate_cache_key(event, source)
+        self.set(cachekey, thresholdtime, relapsetime, ttl)
 
-    def threshold_event(self, event, group_name, severity, createtime):
+    def threshold_event(self, event, source, group_name, severity, createtime):
         """
         Threshold given event with given TTL.
 
         :param mentat.idea.internal.Idea event: IDEA event to threshold.
+        :param str source: Source address for which to threshold the event.
         :param str group_name: Name of the group for which to threshold.
         :param str severity: Event severity.
         :param datetime.datetime createtime: Thresholding timestamp.
         """
-        cachekeys = self._generate_cache_keys(event)
-        for chk in cachekeys:
-            self.save(event.get_id(), chk, group_name, severity, createtime)
+        cachekey = self._generate_cache_key(event, source)
+        self.save(event.get_id(), cachekey, group_name, severity, createtime)
 
     #---------------------------------------------------------------------------
 
@@ -394,8 +387,8 @@ class ThresholdingCache:
         :param str group_name: Name of the abuse group.
         :param str severity: Event severity.
         :param datetime.datetime ttl: Record TTL time.
-        :return: List of relapsed events as :py:class:`mentat.idea.internal.Idea` objects.
-        :rtype: list
+        :return: List of relapsed events as tuples of event ID, event data and list of threshold cache keys.
+        :rtype: list
         """
         raise NotImplementedError()
 
@@ -409,25 +402,29 @@ class ThresholdingCache:
 
     #---------------------------------------------------------------------------
 
-    def _generate_cache_keys(self, event, sources = None):
+    def _generate_cache_key(self, event, source):
        """
-        Generate list of thresholding/relapse cache keys for given event.
+        Generate thresholding/relapse cache key for given event and source.
 
         :param mentat.idea.internal.Idea event: Event to process.
-        :return: List of cache keys as strings.
-        :rtype: list
+        :param str source: Source to process.
+        :return: Cache key as a string.
+        :rtype: str
         """
-        keys = []
         event_class = jpath_value(event, '_CESNET.EventClass')
         if not event_class:
             event_class = '/'.join(jpath_values(event, 'Category'))
-        if not sources:
-            sources_ip4 = jpath_values(event, 'Source.IP4')
-            sources_ip6 = jpath_values(event, 'Source.IP6')
-            sources = sources_ip4 + sources_ip6
-        for src in sources:
-            keys.append('+++'.join((event_class, str(src))))
-        return keys
+        return '+++'.join((event_class, str(source)))
+
+    def get_source_from_cache_key(self, key):
+        """
+        Return the source from which the given cache key was generated.
+
+        :param str key: Cache key.
+        :return: Cached source.
+        :rtype: str
+        """
+        return key.split('+++')[1] if key and len(key.split('+++')) > 1 else key
 
 
 class NoThresholdingCache(ThresholdingCache):
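
Note: cache keys concatenate the event class and a single source with a '+++'
separator, so `get_source_from_cache_key()` can recover the source by splitting
on it. A round-trip sketch (hypothetical event, any `ThresholdingCache`
subclass; recovery assumes the event class itself contains no '+++'):

    # key = cache._generate_cache_key(
    #     {'Category': ['Recon.Scanning'], 'Source': [{'IP4': ['10.0.2.1']}]},
    #     '10.0.2.1')
    # key == 'Recon.Scanning+++10.0.2.1'
    # cache.get_source_from_cache_key(key) == '10.0.2.1'
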
@@ -517,15 +514,16 @@ class SingleSourceThresholdingCache(SimpleMemoryThresholdingCache):
         super().__init__()
         self.source = source
 
-    def _generate_cache_keys(self, event, sources = None):
+    def _generate_cache_key(self, event, source):
         """
-        Generate list of thresholding/relapse cache keys for given event.
+        Generate thresholding/relapse cache key for given event and source.
 
         :param mentat.idea.internal.Idea event: Event to process.
-        :return: List of cache keys as strings.
-        :rtype: list
+        :param str source: Source to process.
+        :return: Cache key as a string.
+        :rtype: str
         """
-        return super()._generate_cache_keys(event, [self.source])
+        return super()._generate_cache_key(event, self.source)
 
 
 class StorageThresholdingCache(ThresholdingCache):
diff --git a/lib/mentat/services/eventstorage.py b/lib/mentat/services/eventstorage.py
index 4e884327dcc638e2d6aaefc03dffbe56daeb450a..b8e987ee5f359f8253fdfd14d605b9df79bf34a2 100644
--- a/lib/mentat/services/eventstorage.py
+++ b/lib/mentat/services/eventstorage.py
@@ -519,7 +519,7 @@ class EventStorageCursor:
         query, params = build_query(parameters, qtype = 'count', qname = qname)
         self.lastquery = self.cursor.mogrify(query, params)
         self.cursor.execute(query, params)
-
+
         record = self.cursor.fetchone()
         if record:
             return record[0]
@@ -551,7 +551,7 @@ class EventStorageCursor:
 
     def delete_events(self, parameters = None, qname = None):
         """
-        Delete IDEA events in database according to given parameters. There is
+        Delete IDEA events in database according to given parameters. There is
         an option to assign given unique name to the query, so that it can be
         identified within the ``pg_stat_activity`` table.
@@ -717,15 +717,14 @@ class EventStorageCursor:
         :param str group_name: Name of the abuse group.
         :param str severity: Event severity.
         :param datetime.datetime ttl: Record TTL time.
-        :return: List of relapsed events as :py:class:`mentat.idea.internal.Idea` objects.
+        :return: List of relapsed events as tuples of event ID, event JSON data and list of threshold keys.
         :rtype: list
         """
         self.cursor.execute(
-            "SELECT events.*, events_json.event FROM events INNER JOIN events_json ON events.id = events_json.id WHERE events.id IN (SELECT DISTINCT eventid FROM events_thresholded WHERE keyid IN (SELECT keyid FROM events_thresholded INNER JOIN thresholds ON (events_thresholded.keyid = thresholds.id) WHERE events_thresholded.groupname = %s AND events_thresholded.eventseverity = %s AND events_thresholded.createtime >= thresholds.relapsetime AND thresholds.ttltime <= %s))",
+            "SELECT events_json.id, events_json.event, ARRAY_AGG(events_thresholded.keyid) AS keyids FROM events_json INNER JOIN events_thresholded ON events_json.id = events_thresholded.eventid INNER JOIN thresholds ON events_thresholded.keyid = thresholds.id WHERE events_thresholded.groupname = %s AND events_thresholded.eventseverity = %s AND events_thresholded.createtime >= thresholds.relapsetime AND thresholds.ttltime <= %s GROUP BY events_json.id",
             (group_name, severity, ttl)
         )
-        events_raw = self.cursor.fetchall()
-        return [record_to_idea(event) for event in events_raw]
+        return self.cursor.fetchall()
 
     def thresholded_events_count(self):
         """
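
Note: the rewritten query returns one row per relapsed event with all of its
thresholding keys folded into a single array column, which
`aggregate_relapsed_events()` consumes directly. Illustrative row shape
(hypothetical values):

    # [('msg01', {...event JSON...}, ['class01+++192.168.0.2-192.168.0.5',
    #                                 'class01+++192.168.0.0/25'])]
    # i.e. (event ID, event JSON, ARRAY_AGG of cache keys); the keyids column
    # is unpacked back into a per-source aggregation via
    # get_source_from_cache_key().
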
@@ -1102,7 +1101,7 @@ class EventStorageService:
         try:
             self.cursor.insert_event(idea_event)
             self.savepoint_create()
-            return
+            return
         except psycopg2.DataError as err:
             self.savepoint_rollback()