Skip to content
Snippets Groups Projects
Verified Commit 33655c28 authored by Rajmund Hruška's avatar Rajmund Hruška
Browse files

Fix: Allow subtypes of Mapping and Sequence in the searched data

parent 1c78479c
No related branches found
No related tags found
No related merge requests found
import re import re
from collections.abc import MutableSequence
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial from functools import partial
from numbers import Number from numbers import Number
...@@ -49,7 +50,7 @@ def _op_range_scalar(op: str, _range: tuple, scalar) -> tuple: ...@@ -49,7 +50,7 @@ def _op_range_scalar(op: str, _range: tuple, scalar) -> tuple:
return (binary_operation(op, start, scalar), binary_operation(op, end, scalar)) return (binary_operation(op, start, scalar), binary_operation(op, end, scalar))
def _op_scalar_list(op: str, scalar, t: list) -> list: def _op_scalar_list(op: str, scalar, t: MutableSequence) -> list:
""" """
Perform a binary operation between a scalar and each element in a list. Perform a binary operation between a scalar and each element in a list.
...@@ -64,7 +65,7 @@ def _op_scalar_list(op: str, scalar, t: list) -> list: ...@@ -64,7 +65,7 @@ def _op_scalar_list(op: str, scalar, t: list) -> list:
return [binary_operation(op, scalar, elem) for elem in t] return [binary_operation(op, scalar, elem) for elem in t]
def _op_list_scalar(op: str, t: list, scalar: Any) -> list: def _op_list_scalar(op: str, t: MutableSequence, scalar: Any) -> list:
""" """
Perform a binary operation between each element in a list and a scalar. Perform a binary operation between each element in a list and a scalar.
...@@ -115,7 +116,7 @@ def _comp_range_scalar(op: str, _range: tuple, scalar) -> bool: ...@@ -115,7 +116,7 @@ def _comp_range_scalar(op: str, _range: tuple, scalar) -> bool:
return binary_operation(op, start, scalar) or binary_operation(op, end, scalar) return binary_operation(op, start, scalar) or binary_operation(op, end, scalar)
def _comp_scalar_list(op: str, scalar: Any, t: list) -> bool: def _comp_scalar_list(op: str, scalar: Any, t: MutableSequence) -> bool:
""" """
Compare a scalar with a list using the specified operator. Compare a scalar with a list using the specified operator.
...@@ -131,7 +132,7 @@ def _comp_scalar_list(op: str, scalar: Any, t: list) -> bool: ...@@ -131,7 +132,7 @@ def _comp_scalar_list(op: str, scalar: Any, t: list) -> bool:
return any(map(partial(binary_operation, op, scalar), t)) return any(map(partial(binary_operation, op, scalar), t))
def _comp_list_scalar(op: str, t: list, scalar: Any) -> bool: def _comp_list_scalar(op: str, t: MutableSequence, scalar: Any) -> bool:
""" """
Compare each element of a list with a scalar using the specified operator. Compare each element of a list with a scalar using the specified operator.
...@@ -147,7 +148,7 @@ def _comp_list_scalar(op: str, t: list, scalar: Any) -> bool: ...@@ -147,7 +148,7 @@ def _comp_list_scalar(op: str, t: list, scalar: Any) -> bool:
return any(binary_operation(op, x, scalar) for x in t) return any(binary_operation(op, x, scalar) for x in t)
def _comp_list_list(op: str, t1: list, t2: list) -> bool: def _comp_list_list(op: str, t1: MutableSequence, t2: MutableSequence) -> bool:
""" """
Compare two lists using the specified operator. Compare two lists using the specified operator.
...@@ -193,7 +194,7 @@ def _comp_ip_ip(op: str, ip1: IP, ip2: IP) -> bool: ...@@ -193,7 +194,7 @@ def _comp_ip_ip(op: str, ip1: IP, ip2: IP) -> bool:
raise OperatorNotFoundError(op, ("ip", "ip"), (ip1, ip2)) raise OperatorNotFoundError(op, ("ip", "ip"), (ip1, ip2))
def _concat(a, b) -> str | list: def _concat(a, b) -> str | MutableSequence:
""" """
Concatenate two objects, either strings or lists. Concatenate two objects, either strings or lists.
...@@ -207,7 +208,7 @@ def _concat(a, b) -> str | list: ...@@ -207,7 +208,7 @@ def _concat(a, b) -> str | list:
return a + b return a + b
def _in_scalar_list(left, right: list) -> bool: def _in_scalar_list(left, right: MutableSequence) -> bool:
""" """
Check if a scalar value is contained within a list or iterable. Check if a scalar value is contained within a list or iterable.
...@@ -225,7 +226,7 @@ def _in_scalar_list(left, right: list) -> bool: ...@@ -225,7 +226,7 @@ def _in_scalar_list(left, right: list) -> bool:
the membership check is performed using the `binary_operation("in")`. the membership check is performed using the `binary_operation("in")`.
- For other types, the comparison is done using equality (`==`). - For other types, the comparison is done using equality (`==`).
""" """
iterable = (list, tuple, IP4Net, IP4Range, IP6Range, IP6Net) iterable = (MutableSequence, tuple, IP4Net, IP4Range, IP6Range, IP6Net)
return any( return any(
binary_operation("in", left, x) if isinstance(x, iterable) else left == x binary_operation("in", left, x) if isinstance(x, iterable) else left == x
for x in right for x in right
...@@ -248,7 +249,7 @@ def _in_scalar_range(scalar, _range: tuple) -> bool: ...@@ -248,7 +249,7 @@ def _in_scalar_range(scalar, _range: tuple) -> bool:
return binary_operation(">=", scalar, start) and binary_operation("<=", scalar, end) return binary_operation(">=", scalar, start) and binary_operation("<=", scalar, end)
def _in_list_tuple(t: list, _range: tuple) -> bool: def _in_list_tuple(t: MutableSequence, _range: tuple) -> bool:
""" """
Check if any element from the list is within a range. Check if any element from the list is within a range.
...@@ -262,7 +263,7 @@ def _in_list_tuple(t: list, _range: tuple) -> bool: ...@@ -262,7 +263,7 @@ def _in_list_tuple(t: list, _range: tuple) -> bool:
return any(_in_scalar_range(elem, _range) for elem in t) return any(_in_scalar_range(elem, _range) for elem in t)
def _in_list_list(left: list, right: list) -> bool: def _in_list_list(left: MutableSequence, right: MutableSequence) -> bool:
""" """
Check if any element of one list is in another list. Check if any element of one list is in another list.
...@@ -296,15 +297,15 @@ def _get_comp_dict(op: str, comp: Callable) -> dict[tuple[Operand, Operand], Cal ...@@ -296,15 +297,15 @@ def _get_comp_dict(op: str, comp: Callable) -> dict[tuple[Operand, Operand], Cal
(datetime, tuple): partial(_comp_scalar_range, op), (datetime, tuple): partial(_comp_scalar_range, op),
(tuple, Number): partial(_comp_range_scalar, op), (tuple, Number): partial(_comp_range_scalar, op),
(tuple, datetime): partial(_comp_range_scalar, op), (tuple, datetime): partial(_comp_range_scalar, op),
("ip", list): partial(_comp_scalar_list, op), ("ip", MutableSequence): partial(_comp_scalar_list, op),
(Number, list): partial(_comp_scalar_list, op), (Number, MutableSequence): partial(_comp_scalar_list, op),
(datetime, list): partial(_comp_scalar_list, op), (datetime, MutableSequence): partial(_comp_scalar_list, op),
(timedelta, list): partial(_comp_scalar_list, op), (timedelta, MutableSequence): partial(_comp_scalar_list, op),
(list, "ip"): partial(_comp_list_scalar, op), (MutableSequence, "ip"): partial(_comp_list_scalar, op),
(list, Number): partial(_comp_list_scalar, op), (MutableSequence, Number): partial(_comp_list_scalar, op),
(list, datetime): partial(_comp_list_scalar, op), (MutableSequence, datetime): partial(_comp_list_scalar, op),
(list, timedelta): partial(_comp_list_scalar, op), (MutableSequence, timedelta): partial(_comp_list_scalar, op),
(list, list): partial(_comp_list_list, op), (MutableSequence, MutableSequence): partial(_comp_list_list, op),
} }
...@@ -316,9 +317,9 @@ _operator_map: dict[str, dict[tuple[Operand, Operand], Callable]] = { ...@@ -316,9 +317,9 @@ _operator_map: dict[str, dict[tuple[Operand, Operand], Callable]] = {
(timedelta, timedelta): add, (timedelta, timedelta): add,
(Number, tuple): partial(_op_scalar_range, "+"), (Number, tuple): partial(_op_scalar_range, "+"),
(timedelta, tuple): partial(_op_scalar_range, "+"), (timedelta, tuple): partial(_op_scalar_range, "+"),
(Number, list): partial(_op_scalar_list, "+"), (Number, MutableSequence): partial(_op_scalar_list, "+"),
(datetime, list): partial(_op_scalar_list, "+"), (datetime, MutableSequence): partial(_op_scalar_list, "+"),
(timedelta, list): partial(_op_scalar_list, "+"), (timedelta, MutableSequence): partial(_op_scalar_list, "+"),
}, },
"-": { "-": {
(Number, Number): sub, (Number, Number): sub,
...@@ -334,35 +335,35 @@ _operator_map: dict[str, dict[tuple[Operand, Operand], Callable]] = { ...@@ -334,35 +335,35 @@ _operator_map: dict[str, dict[tuple[Operand, Operand], Callable]] = {
(tuple, timedelta): partial(_op_range_scalar, "-"), (tuple, timedelta): partial(_op_range_scalar, "-"),
(tuple, datetime): partial(_op_range_scalar, "-"), (tuple, datetime): partial(_op_range_scalar, "-"),
# x - list[x] # x - list[x]
(Number, list): partial(_op_scalar_list, "-"), (Number, MutableSequence): partial(_op_scalar_list, "-"),
(datetime, list): partial(_op_scalar_list, "-"), (datetime, MutableSequence): partial(_op_scalar_list, "-"),
(timedelta, list): partial(_op_scalar_list, "-"), (timedelta, MutableSequence): partial(_op_scalar_list, "-"),
# list[x] - x # list[x] - x
(list, Number): partial(_op_list_scalar, "-"), (MutableSequence, Number): partial(_op_list_scalar, "-"),
(list, datetime): partial(_op_list_scalar, "-"), (MutableSequence, datetime): partial(_op_list_scalar, "-"),
(list, timedelta): partial(_op_list_scalar, "-"), (MutableSequence, timedelta): partial(_op_list_scalar, "-"),
}, },
"*": { "*": {
(timedelta, Number): mul, (timedelta, Number): mul,
(Number, Number): mul, (Number, Number): mul,
(timedelta, list): partial(_op_scalar_list, "*"), (timedelta, MutableSequence): partial(_op_scalar_list, "*"),
(Number, list): partial(_op_scalar_list, "*"), (Number, MutableSequence): partial(_op_scalar_list, "*"),
}, },
"/": { "/": {
(timedelta, timedelta): truediv, (timedelta, timedelta): truediv,
(Number, Number): truediv, (Number, Number): truediv,
(timedelta, list): partial(_op_scalar_list, "/"), (timedelta, MutableSequence): partial(_op_scalar_list, "/"),
(Number, list): partial(_op_scalar_list, "/"), (Number, MutableSequence): partial(_op_scalar_list, "/"),
(list, timedelta): partial(_op_list_scalar, "/"), (MutableSequence, timedelta): partial(_op_list_scalar, "/"),
(list, Number): partial(_op_list_scalar, "/"), (MutableSequence, Number): partial(_op_list_scalar, "/"),
}, },
"%": { "%": {
(timedelta, timedelta): mod, (timedelta, timedelta): mod,
(Number, Number): mod, (Number, Number): mod,
(timedelta, list): partial(_op_scalar_list, "%"), (timedelta, MutableSequence): partial(_op_scalar_list, "%"),
(Number, list): partial(_op_scalar_list, "%"), (Number, MutableSequence): partial(_op_scalar_list, "%"),
(list, timedelta): partial(_op_list_scalar, "%"), (MutableSequence, timedelta): partial(_op_list_scalar, "%"),
(list, Number): partial(_op_list_scalar, "%"), (MutableSequence, Number): partial(_op_list_scalar, "%"),
}, },
">": _get_comp_dict(">", gt), ">": _get_comp_dict(">", gt),
">=": _get_comp_dict(">=", ge), ">=": _get_comp_dict(">=", ge),
...@@ -371,25 +372,25 @@ _operator_map: dict[str, dict[tuple[Operand, Operand], Callable]] = { ...@@ -371,25 +372,25 @@ _operator_map: dict[str, dict[tuple[Operand, Operand], Callable]] = {
"=": _get_comp_dict("=", eq), "=": _get_comp_dict("=", eq),
".": { ".": {
(str, str): _concat, (str, str): _concat,
(list, list): _concat, (MutableSequence, MutableSequence): _concat,
}, },
"contains": { "contains": {
(str, str): lambda value, pattern: pattern in value, (str, str): lambda value, pattern: pattern in value,
(list, str): lambda t, x: any(x in elem for elem in t), (MutableSequence, str): lambda t, x: any(x in elem for elem in t),
}, },
"like": {(str, str): lambda data, pattern: re.match(pattern, data) is not None}, "like": {(str, str): lambda data, pattern: re.match(pattern, data) is not None},
"in": { "in": {
("ip", "ip"): lambda left, right: left in right, ("ip", "ip"): lambda left, right: left in right,
(str, list): _in_scalar_list, (str, MutableSequence): _in_scalar_list,
(Number, list): _in_scalar_list, (Number, MutableSequence): _in_scalar_list,
(timedelta, list): _in_scalar_list, (timedelta, MutableSequence): _in_scalar_list,
(datetime, list): _in_scalar_list, (datetime, MutableSequence): _in_scalar_list,
("ip", list): _in_scalar_list, ("ip", MutableSequence): _in_scalar_list,
("ip", tuple): _in_scalar_range, ("ip", tuple): _in_scalar_range,
(Number, tuple): _in_scalar_range, (Number, tuple): _in_scalar_range,
(datetime, tuple): _in_scalar_range, (datetime, tuple): _in_scalar_range,
(list, tuple): _in_list_tuple, (MutableSequence, tuple): _in_list_tuple,
(list, list): _in_list_list, (MutableSequence, MutableSequence): _in_list_list,
}, },
} }
...@@ -408,6 +409,8 @@ def _resolve_type(value: Any) -> type | str: ...@@ -408,6 +409,8 @@ def _resolve_type(value: Any) -> type | str:
return Number return Number
if isinstance(value, (IP4, IP4Range, IP4Net, IP6, IP6Net, IP6Range)): if isinstance(value, (IP4, IP4Range, IP4Net, IP6, IP6Net, IP6Range)):
return "ip" return "ip"
if isinstance(value, MutableSequence):
return MutableSequence
return type(value) return type(value)
......
...@@ -23,6 +23,7 @@ Classes: ...@@ -23,6 +23,7 @@ Classes:
""" """
import re import re
from collections.abc import Mapping, MutableSequence
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
...@@ -648,7 +649,7 @@ class Filter(Interpreter): ...@@ -648,7 +649,7 @@ class Filter(Interpreter):
return var return var
def _get_data_value( def _get_data_value(
self, path: str, data: Optional[dict | list] self, path: str, data: Optional[Mapping | MutableSequence]
) -> Tuple[Any, bool]: ) -> Tuple[Any, bool]:
""" """
Retrieves a value from the data structure (dictionary or list) based on Retrieves a value from the data structure (dictionary or list) based on
...@@ -686,20 +687,20 @@ class Filter(Interpreter): ...@@ -686,20 +687,20 @@ class Filter(Interpreter):
key, *remaining_list = path.split(".", 1) key, *remaining_list = path.split(".", 1)
remaining_path = remaining_list[0] if remaining_list else "" remaining_path = remaining_list[0] if remaining_list else ""
if isinstance(data, dict): if isinstance(data, Mapping):
# Navigate dictionary # Navigate dictionary
if key not in data: if key not in data:
return None, False return None, False
return self._get_data_value(remaining_path, data[key]) return self._get_data_value(remaining_path, data[key])
elif isinstance(data, list): elif isinstance(data, MutableSequence):
# Aggregate results from all list elements # Aggregate results from all list elements
aggregated = [] aggregated = []
for item in data: for item in data:
if isinstance(item, (dict, list)): if isinstance(item, (Mapping, MutableSequence)):
result, _ = self._get_data_value(path, item) result, _ = self._get_data_value(path, item)
if result is not None: if result is not None:
if isinstance(result, list): if isinstance(result, MutableSequence):
aggregated.extend(result) aggregated.extend(result)
else: else:
aggregated.append(result) aggregated.append(result)
......
from collections.abc import MutableSequence
import pytest import pytest
from ipranges import IP4Range from ipranges import IP4Range
from ransack.operator import OperatorNotFoundError, _comp_ip_ip, binary_operation from ransack.exceptions import OperatorNotFoundError
from ransack.operator import _comp_ip_ip, binary_operation
def test_comp_ip_ip(): def test_comp_ip_ip():
...@@ -65,6 +68,10 @@ def test_binary_operation_operator_not_found(): ...@@ -65,6 +68,10 @@ def test_binary_operation_operator_not_found():
left_type = "Number" if isinstance(left, (int, float)) else str(type(left)) left_type = "Number" if isinstance(left, (int, float)) else str(type(left))
right_type = "Number" if isinstance(right, (int, float)) else str(type(right)) right_type = "Number" if isinstance(right, (int, float)) else str(type(right))
# Special handling for lists (list is resolved to MutableSequence
if isinstance(right, MutableSequence):
right_type = str(MutableSequence)
# Check if the exception message contains the correct operator and types # Check if the exception message contains the correct operator and types
assert f"Operator '{operator}' not found" in exc_message assert f"Operator '{operator}' not found" in exc_message
assert left_type in exc_message assert left_type in exc_message
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment