#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2016, CESNET, z. s. p. o.
# Use of this source is governed by an ISC license, see LICENSE file.

import unittest
from ipranges import IP4, IP6, IP4Range, IP6Range, IP4Net, IP6Net, from_str

class TestIPRange(unittest.TestCase):
    """Unit tests for the ipranges address, range and network types.

    Covers string parsing and round-tripping, rejection of malformed input,
    equality across representations, containment, iteration, indexing and
    slicing, conversions between types, and the ``from_str`` factory.
    """

    def _assertAllRaise(self, factory, values):
        """Assert that ``factory(value)`` raises ValueError for every value.

        Unlike a bare ``assertRaises`` inside a loop, the failure message
        names the offending input, so a regression is easy to pinpoint.
        Exceptions other than ValueError propagate as test errors, exactly as
        they would out of an ``assertRaises`` context.
        """
        name = getattr(factory, "__name__", repr(factory))
        for value in values:
            try:
                factory(value)
            except ValueError:
                continue
            self.fail("%s(%r) did not raise ValueError" % (name, value))

    def testIP4(self):
        # Canonical strings must survive a parse/format round trip unchanged.
        for ip in ["0.0.0.0", "192.0.2.100", "255.255.255.255"]:
            self.assertEqual(str(IP4(ip)), ip)

    def testIP4Fail(self):
        self._assertAllRaise(IP4, ["", "-", "/", "0", "123", "1.2.3.4.5"])

    def testIP6(self):
        for ip in ["::",
                   "2001:db8:220:1:248:1893:25c8:1946",
                   "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"]:
            self.assertEqual(str(IP6(ip)), ip)

    def testIP6Fail(self):
        self._assertAllRaise(IP6, ["", "-", "/", "0", "123", "1:2::3::"])

    def testIP4Range(self):
        for r in ["0.0.0.0-255.255.255.255",
                  "192.0.2.64-192.0.2.127",
                  "192.0.2.5-192.0.2.5"]:
            self.assertEqual(str(IP4Range(r)), r)

    def testIP4RangeFail(self):
        self._assertAllRaise(IP4Range, ["", "0.0.0.0", "asdf"])

    def testIP4RangeFail2(self):
        # Ranges with one or both endpoints missing.
        self._assertAllRaise(IP4Range, ["192.0.2.64-", "-192.0.2.5", "-"])

    def testIP6Range(self):
        for r in ["::-ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
                  "2001:db8:220:1:248:1893:25c8:1946-2001:db8:230:1:248:1893:25c8:1946",
                  "2001:db8:230::25c8:1946-2001:db8:230::25c8:1946"]:
            self.assertEqual(str(IP6Range(r)), r)

    def testIP6RangeFail(self):
        self._assertAllRaise(IP6Range, ["", "::", "asdf"])

    def testIP6RangeFail2(self):
        # Ranges with one or both endpoints missing.
        self._assertAllRaise(IP6Range, ["2001:db8:220:1:248:1893:25c8:1946-",
                                        "-2001:db8:220:1:248:1893:25c8:1946",
                                        "-"])

    def testIP4Net(self):
        for n in ["0.0.0.0/0", "192.0.2.64/26", "192.0.2.5/32"]:
            self.assertEqual(str(IP4Net(n)), n)

    def testIP4NetFail(self):
        self._assertAllRaise(IP4Net, ["", "0.0.0.0", "asdf", "192.0.2.64/", "192.0.2.64/?"])

    def testIP4NetFail2(self):
        # Prefix length without a base address.
        self._assertAllRaise(IP4Net, ["/26", "/"])

    def testIP6Net(self):
        for n in ["::/0", "2001:db8:220:1::/64", "2001:db8:230::25c8:1946/32"]:
            self.assertEqual(str(IP6Net(n)), n)

    def testIP6NetFail(self):
        # "0.0.0.0" checks that an IPv4 literal is rejected as an IPv6 network.
        self._assertAllRaise(IP6Net, ["", "0.0.0.0", "asdf",
                                      "2001:db8:220:1::/", "2001:db8:220:1::/?"])

    def testIP6NetFail2(self):
        # Prefix length without a base address.
        self._assertAllRaise(IP6Net, ["/26", "/"])

    def test4SameNetRange(self):
        # == and != are exercised separately, since each is its own operator.
        net1 = IP4Net("192.0.2.64/26")
        net2 = IP4Range("192.0.2.64-192.0.2.127")
        self.assertTrue(net1 == net2)
        self.assertFalse(net1 != net2)

    def test4SameOne(self):
        # A /32 net, a one-address range and a single IP all compare equal.
        ip1 = IP4Net("192.0.2.65/32")
        ip2 = IP4Range("192.0.2.65-192.0.2.65")
        ip3 = IP4("192.0.2.65")
        self.assertTrue(ip1 == ip2)
        self.assertTrue(ip2 == ip3)
        self.assertTrue(ip1 == ip3)
        self.assertFalse(ip1 != ip2)
        self.assertFalse(ip2 != ip3)
        self.assertFalse(ip1 != ip3)

    def test6SameNetRange(self):
        net1 = IP6Net("2001:db8:220:1::/64")
        net2 = IP6Range("2001:db8:220:1::-2001:db8:220:1:ffff:ffff:ffff:ffff")
        self.assertTrue(net1 == net2)
        self.assertFalse(net1 != net2)

    def test6SameOne(self):
        # A /128 net, a one-address range and a single IP all compare equal.
        ip1 = IP6Net("2001:db8:220:1:248:1893:25c8:1946/128")
        ip2 = IP6Range("2001:db8:220:1:248:1893:25c8:1946-2001:db8:220:1:248:1893:25c8:1946")
        ip3 = IP6("2001:db8:220:1:248:1893:25c8:1946")
        self.assertTrue(ip1 == ip2)
        self.assertTrue(ip2 == ip3)
        self.assertTrue(ip1 == ip3)
        self.assertFalse(ip1 != ip2)
        self.assertFalse(ip2 != ip3)
        self.assertFalse(ip1 != ip3)

    def test4Contains(self):
        # Containment must work for every pairing of Net, Range and single IP.
        self.assertIn(IP4Net("192.0.2.64/28"), IP4Net("192.0.2.64/26"))
        self.assertIn(IP4Net("192.0.2.64/28"), IP4Range("192.0.2.64-192.0.2.127"))
        self.assertIn(IP4Net("192.0.2.65/32"), IP4("192.0.2.65"))

        self.assertIn(IP4Range("192.0.2.65-192.0.2.126"), IP4Range("192.0.2.64-192.0.2.127"))
        self.assertIn(IP4Range("192.0.2.65-192.0.2.126"), IP4Net("192.0.2.64/26"))
        self.assertIn(IP4Range("192.0.2.65-192.0.2.65"), IP4("192.0.2.65"))

        self.assertIn(IP4("192.0.2.65"), IP4Range("192.0.2.64-192.0.2.127"))
        self.assertIn(IP4("192.0.2.65"), IP4Net("192.0.2.64/26"))
        self.assertIn(IP4("192.0.2.65"), IP4("192.0.2.65"))

    def test6Contains(self):
        # Containment must work for every pairing of Net, Range and single IP.
        net = "2001:db8:220:1::/64"
        rng = "2001:db8:220:1::-2001:db8:220:1:ffff:ffff:ffff:ffff"
        one = "2001:db8:220:1:248:1893:25c8:1946"

        self.assertIn(IP6Net(net), IP6Net(net))
        self.assertIn(IP6Net(net), IP6Range(rng))
        self.assertIn(IP6Net(one + "/128"), IP6(one))

        self.assertIn(IP6Range(rng), IP6Range(rng))
        self.assertIn(IP6Range(rng), IP6Net(net))
        self.assertIn(IP6Range(one + "-" + one), IP6(one))

        self.assertIn(IP6(one), IP6Range(rng))
        self.assertIn(IP6(one), IP6Net(net))
        self.assertIn(IP6(one), IP6(one))

    def test4Iter(self):
        # Iterating any of the three types yields its addresses in ascending order.
        self.assertEqual(
            tuple(str(ip) for ip in IP4Net("192.0.2.64/30")),
            ("192.0.2.64", "192.0.2.65", "192.0.2.66", "192.0.2.67"))
        self.assertEqual(
            tuple(str(ip) for ip in IP4Range("192.0.2.64-192.0.2.67")),
            ("192.0.2.64", "192.0.2.65", "192.0.2.66", "192.0.2.67"))
        self.assertEqual(
            tuple(str(ip) for ip in IP4("192.0.2.65")),
            ("192.0.2.65",))

    def test6Iter(self):
        self.assertEqual(
            tuple(str(ip) for ip in IP6Net("2001:db8:220:1:248:1893:25c8:1944/126")),
            ("2001:db8:220:1:248:1893:25c8:1944",
             "2001:db8:220:1:248:1893:25c8:1945",
             "2001:db8:220:1:248:1893:25c8:1946",
             "2001:db8:220:1:248:1893:25c8:1947"))
        self.assertEqual(
            tuple(str(ip) for ip in IP6Range("2001:db8:220:1:248:1893:25c8:1944-2001:db8:220:1:248:1893:25c8:1947")),
            ("2001:db8:220:1:248:1893:25c8:1944",
             "2001:db8:220:1:248:1893:25c8:1945",
             "2001:db8:220:1:248:1893:25c8:1946",
             "2001:db8:220:1:248:1893:25c8:1947"))
        self.assertEqual(
            tuple(str(ip) for ip in IP6("2001:db8:220:1:248:1893:25c8:1947")),
            ("2001:db8:220:1:248:1893:25c8:1947",))

    def testGetItem(self):
        # Direct indexing (including negative indices) must agree with
        # indexing a list built element by element via range(len(rng)).
        for rng in (
                IP4Range("192.0.2.64-192.0.2.67"),
                IP4Net("192.0.2.64/30"),
                IP4("192.0.2.65"),
                IP6Range("2001:db8:220:1:248:1893:25c8:1944-2001:db8:220:1:248:1893:25c8:1947"),
                IP6Net("2001:db8:220:1:248:1893:25c8:1944/126"),
                IP6("2001:db8:220:1:248:1893:25c8:1947")):
            for idx in (0, -1):
                res = [str(rng[i]) for i in range(len(rng))][idx]
                res2 = str(rng[idx])
                self.assertEqual(res, res2)

    def testGetSlice(self):
        # Slicing (full, negative bounds, with step) must agree with slicing
        # a list built element by element via range(len(rng)).
        for rng in (
                IP4Range("192.0.2.64-192.0.2.67"),
                IP4Net("192.0.2.64/30"),
                IP4("192.0.2.65"),
                IP6Range("2001:db8:220:1:248:1893:25c8:1944-2001:db8:220:1:248:1893:25c8:1947"),
                IP6Net("2001:db8:220:1:248:1893:25c8:1944/126"),
                IP6("2001:db8:220:1:248:1893:25c8:1947")):
            for idx in (slice(None, None, None), slice(-3, -1), slice(0, -1, 2)):
                res = [str(rng[i]) for i in range(len(rng))][idx]
                res2 = [str(ip) for ip in rng[idx]]
                self.assertEqual(res, res2)

    def testConvToIP(self):
        # Only single-address ranges/nets convert to a plain IP; wider ones raise.
        self.assertEqual(IP4(IP4Range("192.0.2.64-192.0.2.64")), IP4("192.0.2.64"))
        self.assertEqual(IP4(IP4Net("192.0.2.64/32")), IP4("192.0.2.64"))
        with self.assertRaises(ValueError):
            IP4(IP4Net("192.0.2.64/30"))

        self.assertEqual(IP6(IP6Range("2001:db8:220:1:248:1893:25c8:1944-2001:db8:220:1:248:1893:25c8:1944")), IP6("2001:db8:220:1:248:1893:25c8:1944"))
        self.assertEqual(IP6(IP6Net("2001:db8:220:1:248:1893:25c8:1944/128")), IP6("2001:db8:220:1:248:1893:25c8:1944"))
        with self.assertRaises(ValueError):
            IP6(IP6Net("2001:db8:220:1:248:1893:25c8:1944/126"))

    def testConvToNet(self):
        # Ranges that align to a CIDR block convert to a net; others raise.
        # Note 192.0.2.64-192.0.2.120 is not a power-of-two aligned block.
        self.assertEqual(IP4Net(IP4Range("192.0.2.64-192.0.2.127")), IP4Net("192.0.2.64/26"))
        self.assertEqual(IP4Net(IP4("192.0.2.64")), IP4Net("192.0.2.64/32"))
        with self.assertRaises(ValueError):
            IP4Net(IP4Range("192.0.2.64-192.0.2.120"))

        self.assertEqual(IP6Net(IP6Range("2001:db8:220:1:248:1893:25c8:1944-2001:db8:220:1:248:1893:25c8:1947")), IP6Net("2001:db8:220:1:248:1893:25c8:1944/126"))
        self.assertEqual(IP6Net(IP6("2001:db8:220:1:248:1893:25c8:1947")), IP6Net("2001:db8:220:1:248:1893:25c8:1947/128"))
        with self.assertRaises(ValueError):
            IP6Net(IP6Range("2001:db8:220:1:248:1893:25c8:1944-2001:db8:220:1:248:1893:25c8:1948"))

    def testFromStr(self):
        # from_str must pick the concrete type from the string's syntax.
        fs = from_str("192.0.2.64")
        obj = IP4("192.0.2.64")
        self.assertEqual(fs, obj)
        self.assertIsInstance(fs, IP4)

        fs = from_str("192.0.2.64-192.0.2.127")
        obj = IP4Range("192.0.2.64-192.0.2.127")
        self.assertEqual(fs, obj)
        self.assertIsInstance(fs, IP4Range)

        fs = from_str("192.0.2.64/26")
        obj = IP4Net("192.0.2.64/26")
        self.assertEqual(fs, obj)
        self.assertIsInstance(fs, IP4Net)

        fs = from_str("2001:db8:220:1:248:1893:25c8:1947")
        obj = IP6("2001:db8:220:1:248:1893:25c8:1947")
        self.assertEqual(fs, obj)
        self.assertIsInstance(fs, IP6)

        fs = from_str("2001:db8:220:1:248:1893:25c8:1944-2001:db8:220:1:248:1893:25c8:1947")
        obj = IP6Range("2001:db8:220:1:248:1893:25c8:1944-2001:db8:220:1:248:1893:25c8:1947")
        self.assertEqual(fs, obj)
        self.assertIsInstance(fs, IP6Range)

        fs = from_str("2001:db8:220:1:248:1893:25c8:1947/128")
        obj = IP6Net("2001:db8:220:1:248:1893:25c8:1947/128")
        self.assertEqual(fs, obj)
        self.assertIsInstance(fs, IP6Net)

    def testFromStrInvalid(self):
        self._assertAllRaise(from_str, ["192.0.2.500", ":::", "asdf", "-", "/"])

if __name__ == "__main__":
    # Script entry point: discover and run the TestCase classes defined in
    # this module when the file is executed directly.
    unittest.main(module="__main__")