From 69204d1a71046cf646ec5f966c9abb12ea5034bf Mon Sep 17 00:00:00 2001
From: Jan Zerdik <zerdik@cesnet.cz>
Date: Fri, 14 Jun 2019 10:03:09 +0200
Subject: [PATCH] Fix for comparision of ips in python3, identical behavior in
python 2 and 3
---
ipranges.py | 29 +++++++++++++++++++++++++++++
test_ipranges.py | 37 +++++++++++++++++++++++++++++++++++++
2 files changed, 66 insertions(+)
diff --git a/ipranges.py b/ipranges.py
index bda70af..326ccc6 100644
--- a/ipranges.py
+++ b/ipranges.py
@@ -11,6 +11,7 @@ import socket
import struct
import numbers
import sys
+import operator
try:
basestring
@@ -147,6 +148,15 @@ class IPNetBase(IPBase):
def __hash__(self):
return hash((self.base, self.mask))
+def check_types(f):
+ def wrapper(a, b):
+ raise RuntimeError("AAAa")
+ if not isinstance(b, type(a)):
+ return NotImplemented
+ else:
+ return f(a, b)
+
+_opperations = {'__lt__': '<', '__le__': '<=', '__gt__': '>', '__ge__': '>='}
class IPAddrBase(IPBase):
__slots__ = ("ip")
@@ -172,6 +182,25 @@ class IPAddrBase(IPBase):
def __hash__(self): return hash(self.ip)
+ def _check_types(f):
+ def wrapper(self, other):
+ if not isinstance(other, type(self)) and not isinstance(self, type(other)):
+ raise TypeError("'{}' not supported between instances of '{}' and '{}'".format(_opperations.get(f.__name__), type(self).__name__, type(other).__name__))
+ return f(self, other)
+ return wrapper
+
+ @_check_types
+ def __lt__(self, other): return self.ip < other.ip
+
+ @_check_types
+ def __le__(self, other): return self.ip <= other.ip
+
+ @_check_types
+ def __gt__(self, other): return self.ip > other.ip
+
+ @_check_types
+ def __ge__(self, other): return self.ip >= other.ip
+
def low(self): return self.ip
def high(self): return self.ip
diff --git a/test_ipranges.py b/test_ipranges.py
index 8ad0955..abffa15 100755
--- a/test_ipranges.py
+++ b/test_ipranges.py
@@ -275,6 +275,43 @@ class TestIPRange(unittest.TestCase):
with self.assertRaises(ValueError):
from_str("/")
+ def testComparing(self):
+ ip4_low1 = from_str("0.0.0.1")
+ ip4_low2 = from_str("0.0.0.1")
+ ip4_high = from_str("0.0.0.2")
+ self.assertGreater(ip4_high, ip4_low1)
+ self.assertGreaterEqual(ip4_high, ip4_low1)
+ self.assertGreaterEqual(ip4_low1, ip4_low2)
+ self.assertLess(ip4_low1, ip4_high)
+ self.assertLessEqual(ip4_low1, ip4_high)
+ self.assertLessEqual(ip4_low1, ip4_low2)
+
+ ip6_low1 = from_str("::1")
+ ip6_low2 = from_str("::1")
+ ip6_high = from_str("::2")
+ self.assertGreater(ip6_high, ip6_low1)
+ self.assertGreaterEqual(ip6_high, ip6_low1)
+ self.assertGreaterEqual(ip6_low1, ip6_low2)
+ self.assertLess(ip6_low1, ip6_high)
+ self.assertLessEqual(ip6_low1, ip6_high)
+ self.assertLessEqual(ip6_low1, ip6_low2)
+
+ with self.assertRaises(TypeError):
+ ip4_low1 < ip6_high
+ with self.assertRaises(TypeError):
+ ip6_high > ip4_low1
+
+ with self.assertRaises(TypeError):
+ ip4_low1 < "0.0.0.2"
+ with self.assertRaises(TypeError):
+ ip6_high > "::1"
+
+ class IP4Sub(IP4): pass
+
+ ip4_sub = IP4Sub(1)
+ self.assertGreater(ip4_high, ip4_sub)
+ with self.assertRaises(TypeError):
+ ip6_high > ip4_sub
if __name__ == "__main__":
unittest.main()
--
GitLab