Skip to content
Snippets Groups Projects
Commit 69204d1a authored by Jan Zerdik's avatar Jan Zerdik
Browse files

Fix for comparision of ips in python3, identical behavior in python 2 and 3

parent e9f1fa04
No related branches found
No related tags found
No related merge requests found
...@@ -11,6 +11,7 @@ import socket ...@@ -11,6 +11,7 @@ import socket
import struct import struct
import numbers import numbers
import sys import sys
import operator
try: try:
basestring basestring
...@@ -147,6 +148,15 @@ class IPNetBase(IPBase): ...@@ -147,6 +148,15 @@ class IPNetBase(IPBase):
def __hash__(self): def __hash__(self):
return hash((self.base, self.mask)) return hash((self.base, self.mask))
def check_types(f):
def wrapper(a, b):
raise RuntimeError("AAAa")
if not isinstance(b, type(a)):
return NotImplemented
else:
return f(a, b)
_opperations = {'__lt__': '<', '__le__': '<=', '__gt__': '>', '__ge__': '>='}
class IPAddrBase(IPBase): class IPAddrBase(IPBase):
__slots__ = ("ip") __slots__ = ("ip")
...@@ -172,6 +182,25 @@ class IPAddrBase(IPBase): ...@@ -172,6 +182,25 @@ class IPAddrBase(IPBase):
def __hash__(self): return hash(self.ip) def __hash__(self): return hash(self.ip)
def _check_types(f):
def wrapper(self, other):
if not isinstance(other, type(self)) and not isinstance(self, type(other)):
raise TypeError("'{}' not supported between instances of '{}' and '{}'".format(_opperations.get(f.__name__), type(self).__name__, type(other).__name__))
return f(self, other)
return wrapper
@_check_types
def __lt__(self, other): return self.ip < other.ip
@_check_types
def __le__(self, other): return self.ip <= other.ip
@_check_types
def __gt__(self, other): return self.ip > other.ip
@_check_types
def __ge__(self, other): return self.ip >= other.ip
def low(self): return self.ip def low(self): return self.ip
def high(self): return self.ip def high(self): return self.ip
......
...@@ -275,6 +275,43 @@ class TestIPRange(unittest.TestCase): ...@@ -275,6 +275,43 @@ class TestIPRange(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
from_str("/") from_str("/")
def testComparing(self):
ip4_low1 = from_str("0.0.0.1")
ip4_low2 = from_str("0.0.0.1")
ip4_high = from_str("0.0.0.2")
self.assertGreater(ip4_high, ip4_low1)
self.assertGreaterEqual(ip4_high, ip4_low1)
self.assertGreaterEqual(ip4_low1, ip4_low2)
self.assertLess(ip4_low1, ip4_high)
self.assertLessEqual(ip4_low1, ip4_high)
self.assertLessEqual(ip4_low1, ip4_low2)
ip6_low1 = from_str("::1")
ip6_low2 = from_str("::1")
ip6_high = from_str("::2")
self.assertGreater(ip6_high, ip6_low1)
self.assertGreaterEqual(ip6_high, ip6_low1)
self.assertGreaterEqual(ip6_low1, ip6_low2)
self.assertLess(ip6_low1, ip6_high)
self.assertLessEqual(ip6_low1, ip6_high)
self.assertLessEqual(ip6_low1, ip6_low2)
with self.assertRaises(TypeError):
ip4_low1 < ip6_high
with self.assertRaises(TypeError):
ip6_high > ip4_low1
with self.assertRaises(TypeError):
ip4_low1 < "0.0.0.2"
with self.assertRaises(TypeError):
ip6_high > "::1"
class IP4Sub(IP4): pass
ip4_sub = IP4Sub(1)
self.assertGreater(ip4_high, ip4_sub)
with self.assertRaises(TypeError):
ip6_high > ip4_sub
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment