diff --git a/ipranges.py b/ipranges.py index bda70af326380ffdfd094223de49d580f75be864..326ccc672a0438fad3443dca85d549e050a4534f 100644 --- a/ipranges.py +++ b/ipranges.py @@ -11,6 +11,7 @@ import socket import struct import numbers import sys +import operator try: basestring @@ -147,6 +148,15 @@ class IPNetBase(IPBase): def __hash__(self): return hash((self.base, self.mask)) +def check_types(f): + def wrapper(a, b): + raise RuntimeError("AAAa") + if not isinstance(b, type(a)): + return NotImplemented + else: + return f(a, b) + +_opperations = {'__lt__': '<', '__le__': '<=', '__gt__': '>', '__ge__': '>='} class IPAddrBase(IPBase): __slots__ = ("ip") @@ -172,6 +182,25 @@ class IPAddrBase(IPBase): def __hash__(self): return hash(self.ip) + def _check_types(f): + def wrapper(self, other): + if not isinstance(other, type(self)) and not isinstance(self, type(other)): + raise TypeError("'{}' not supported between instances of '{}' and '{}'".format(_opperations.get(f.__name__), type(self).__name__, type(other).__name__)) + return f(self, other) + return wrapper + + @_check_types + def __lt__(self, other): return self.ip < other.ip + + @_check_types + def __le__(self, other): return self.ip <= other.ip + + @_check_types + def __gt__(self, other): return self.ip > other.ip + + @_check_types + def __ge__(self, other): return self.ip >= other.ip + def low(self): return self.ip def high(self): return self.ip diff --git a/test_ipranges.py b/test_ipranges.py index 8ad09556becd8b2a830ed9506c741a26300b0865..abffa15a0f29981f6ddb4852f0f6bb0591132eaf 100755 --- a/test_ipranges.py +++ b/test_ipranges.py @@ -275,6 +275,43 @@ class TestIPRange(unittest.TestCase): with self.assertRaises(ValueError): from_str("/") + def testComparing(self): + ip4_low1 = from_str("0.0.0.1") + ip4_low2 = from_str("0.0.0.1") + ip4_high = from_str("0.0.0.2") + self.assertGreater(ip4_high, ip4_low1) + self.assertGreaterEqual(ip4_high, ip4_low1) + self.assertGreaterEqual(ip4_low1, ip4_low2) + self.assertLess(ip4_low1, ip4_high) + self.assertLessEqual(ip4_low1, ip4_high) + self.assertLessEqual(ip4_low1, ip4_low2) + + ip6_low1 = from_str("::1") + ip6_low2 = from_str("::1") + ip6_high = from_str("::2") + self.assertGreater(ip6_high, ip6_low1) + self.assertGreaterEqual(ip6_high, ip6_low1) + self.assertGreaterEqual(ip6_low1, ip6_low2) + self.assertLess(ip6_low1, ip6_high) + self.assertLessEqual(ip6_low1, ip6_high) + self.assertLessEqual(ip6_low1, ip6_low2) + + with self.assertRaises(TypeError): + ip4_low1 < ip6_high + with self.assertRaises(TypeError): + ip6_high > ip4_low1 + + with self.assertRaises(TypeError): + ip4_low1 < "0.0.0.2" + with self.assertRaises(TypeError): + ip6_high > "::1" + + class IP4Sub(IP4): pass + + ip4_sub = IP4Sub(1) + self.assertGreater(ip4_high, ip4_sub) + with self.assertRaises(TypeError): + ip6_high > ip4_sub if __name__ == "__main__": unittest.main()