"""Test multiprocessing of network-poller"""
import asyncio
import multiprocessing
import os
import unittest
import time

from context import Context
from ne import NetworkElementRegistry
from preseem import NetworkMetadataReference

#import logging
#logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
#                        level=logging.INFO)

class FakeAp:
    """Stand-in network element handed to NetworkElementRegistry by the tests.

    It records its constructor arguments as attributes.  ``close`` releases the
    SNMP session object it was given; ``remove_ref_offline``, ``start`` and
    ``poll`` are no-ops, so no real polling ever happens.
    """
    ne_type = 'ap'

    def __init__(self, ctx, neid, name, site, snmp_ne, snmp_ne_v3, ne_type, holdoff, extra=None):
        self.ctx = ctx
        self.neid = neid
        self.name = name
        self.site = site
        self.snmp_ne = snmp_ne
        # Keep what we were given rather than discarding it, so the fake
        # faithfully reflects the arguments it was constructed with.
        self.snmp_ne_v3 = snmp_ne_v3
        # The per-instance type (supplied by the caller) shadows the class default.
        self.ne_type = ne_type
        self.holdoff = holdoff
        self.extra = extra

    async def close(self):
        """Close the SNMP session, if any, and drop the reference (idempotent)."""
        if self.snmp_ne:
            self.snmp_ne.close()
            self.snmp_ne = None
        # NOTE(review): snmp_ne_v3 is stored but not closed here -- confirm
        # whether a v3 session object ever needs closing in these tests.

    async def remove_ref_offline(self):
        pass

    def start(self):
        pass

    async def poll(self):
        pass

class FakeRouter(FakeAp):
    """Fake router element: identical to FakeAp apart from its class-level type."""
    ne_type = 'router'

class FakeSwitch(FakeAp):
    """Fake switch element: identical to FakeAp apart from its class-level type."""
    ne_type = 'switch'


class TestMp(unittest.TestCase):
    """Tests for running the network poller across multiple processes.

    Each test gets a fresh asyncio event loop.  ``start()`` builds and starts
    the main Context (test_mp_01 expects ``processes - 1`` child processes as a
    result) and can optionally create in-process sub-contexts whose registries
    the load-balancing tests inspect directly.
    """

    def setUp(self):
        self.ctx = self._task = None
        self.subctx = {}   # instance index -> in-process sub-Context
        self.subtask = {}  # instance index -> task running that sub-Context
        self.loop = asyncio.new_event_loop()  # needed to initialize asyncio
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        try:
            self.stop()
        finally:
            # Close the loop even if shutting the contexts down raised, so a
            # failing test does not leak an open event loop.
            self.loop.close()
        if multiprocessing.active_children():
            # STM-8031: give the exiting child processes a moment to stop.
            time.sleep(0.1)

    @staticmethod
    def _available_cores():
        """Return the number of CPU cores this process is allowed to run on."""
        return len(os.sched_getaffinity(os.getpid()))

    def _await(self, co):
        """Run `co` to completion on the test loop and return its result."""
        return self.loop.run_until_complete(co)

    def stop(self):
        """Cancel the context tasks and close every context that exists.

        Safe to call when ``start()`` was never called; only the contexts and
        tasks that were actually created are touched.
        """
        for idx, sub in self.subctx.items():
            self.subtask[idx].cancel()
            self._await(sub.close())
        if self._task:
            self._task.cancel()
        if self.ctx:
            self._await(self.ctx.close())

    def _install_fake_registries(self, target):
        """Give `target` AP/router/switch registries that build fake elements.

        NOTE(review): the registries are bound to the main context
        (``self.ctx``) even when `target` is a sub-context -- confirm this is
        intended rather than a copy/paste slip.
        """
        target.ap_registry = NetworkElementRegistry(self.ctx, self.ctx.defaults, FakeAp)
        target.router_registry = NetworkElementRegistry(self.ctx, self.ctx.defaults, FakeRouter)
        target.switch_registry = NetworkElementRegistry(self.ctx, self.ctx.defaults, FakeSwitch)

    def start(self, processes=None, in_proc=False, extra_config=None):
        """Start the main context and wait until it reports it is initialized.

        :param processes: number of poller processes to configure.
        :param in_proc: also create ``processes - 1`` sub-contexts in this
            process (``self.subctx[1..processes-1]``) so tests can inspect
            their registries directly.
        :param extra_config: additional netpoller configuration entries.

        The test is skipped when fewer than `processes` CPU cores are
        available: the Context caps its instance count at the core count (see
        test_mp_config_02), so the process-count and placement expectations
        of the calling test could not hold.
        """
        if processes is not None:
            cores = self._available_cores()
            if processes > cores:
                self.skipTest('needs %d CPU cores, only %d available' % (processes, cores))
        cfg = {'processes': processes}
        if extra_config:
            cfg.update(extra_config)
        self.ctx = Context('test', {}, None, {}, netpoller_cfg=cfg, fake=True)
        self._install_fake_registries(self.ctx)
        initialized = self.loop.create_future()
        self._task = self.loop.create_task(self.ctx.start(fut=initialized))
        if in_proc:
            # create "fake" subprocesses in-process by creating Context objects
            for i in range(1, processes):
                sub = self.subctx[i] = self.ctx.clone(i)
                self._install_fake_registries(sub)
                fut = self.loop.create_future()
                self.subtask[i] = self.loop.create_task(sub.start(fut=fut))
        self._await(initialized)

    def _all_contexts(self):
        """Return the main context followed by the sub-contexts in index order."""
        return [self.ctx, *self.subctx.values()]

    def del_ref(self, ref):
        """Delete a reference on all contexts, to fake the central metadata server"""
        for ctx in self._all_contexts():
            self._await(ctx.netmeta_model.del_ref(ref))

    def set_ref(self, ref):
        """Set a reference on all contexts, to fake the central metadata server"""
        for ctx in self._all_contexts():
            self._await(ctx.netmeta_model.set_ref(ref))

    def _assert_only_on(self, registry, name, index):
        """Assert element `name` is in `registry` on exactly one context.

        Contexts are indexed as returned by ``_all_contexts()``: 0 is the main
        context and i >= 1 is ``self.subctx[i]``.  Pass ``None`` for `index`
        to assert that the element is present on no context at all.
        """
        for i, ctx in enumerate(self._all_contexts()):
            ne = getattr(ctx, registry).nes.get(name)
            if i == index:
                self.assertIsNotNone(ne, '%s missing from context %d' % (name, i))
            else:
                self.assertIsNone(ne, '%s unexpectedly on context %d' % (name, i))

    def _run_lb_scenario(self, ne_type, registry, prefix, attrs, hash_key):
        """Run the shared load-balancing scenario for one element type.

        Starts three instances (one main + two in-process sub-contexts) and
        checks initial placement, re-placement after the `hash_key` attribute
        changes, deletion, and placement when `hash_key` is absent.  `attrs`
        is mutated in place as the scenario progresses.
        """
        self.start(3, in_proc=True)
        self.assertEqual(len(multiprocessing.active_children()), 2)
        first, second, third = ('%s%d' % (prefix, n) for n in (1, 2, 3))

        # two elements are split across different processes
        self.set_ref(NetworkMetadataReference(ne_type, first, attrs))
        self._assert_only_on(registry, first, 2)
        attrs[hash_key] = 'T3'
        self.set_ref(NetworkMetadataReference(ne_type, second, attrs))
        self._assert_only_on(registry, second, 1)

        # change the hash key to move an element
        attrs[hash_key] = 'T5'
        self.set_ref(NetworkMetadataReference(ne_type, first, attrs))
        self._assert_only_on(registry, first, 0)

        # delete an element
        self.del_ref(NetworkMetadataReference(ne_type, first, attrs))
        self._assert_only_on(registry, first, None)

        # load balance an element with no hash key attribute
        del attrs[hash_key]
        self.set_ref(NetworkMetadataReference(ne_type, third, attrs))
        self._assert_only_on(registry, third, 2)

    def test_mp_config_01(self):
        """Multiprocessing configuration test."""
        # Invalid or non-positive values give num_instances None; positive
        # values are honoured up to the number of available cores.
        cores = self._available_cores()
        cases = [('a', None), (0, None), (-1, None), (1, 1), (2, min(2, cores))]
        for processes, expected in cases:
            with self.subTest(processes=processes):
                self.ctx = Context('test', {}, None, {}, netpoller_cfg={'processes': processes})
                self.assertEqual(self.ctx.num_instances, expected)

    def test_mp_config_02(self):
        """Processes are limited to the available number of CPU cores."""
        cores = self._available_cores()
        self.ctx = Context('test', {}, None, {}, netpoller_cfg={'processes': cores + 1})
        self.assertEqual(self.ctx.num_instances, cores)

    def test_mp_01(self):
        """Multiple processes can be started by configuration."""
        self.start(2)
        children = multiprocessing.active_children()
        self.assertEqual(len(children), 1)

    # Using our fakes across processes is hard, so for these tests, we simply
    # create all of the contexts in the same process.  They still operate in
    # the same way and it simplifies unit testing.

    def test_lb_01(self):
        """Test load balancing across processes for APs."""
        attrs = {'topology2': 'T2', 'topology1': 'T1', 'ip_address': '192.0.2.1'}
        self._run_lb_scenario('ap', 'ap_registry', 'TESTAP', attrs, 'topology2')

    def test_lb_02(self):
        """Test load balancing across processes for routers."""
        attrs = {'site': 'T2', 'name': 'T1', 'host': '192.0.2.1'}
        self._run_lb_scenario('router', 'router_registry', 'TESTRTR', attrs, 'site')

    def test_lb_03(self):
        """Test load balancing across processes for switches."""
        attrs = {'site': 'T2', 'name': 'T1', 'host': '192.0.2.1'}
        self._run_lb_scenario('switch', 'switch_registry', 'TESTSW', attrs, 'site')

    def test_ap_config_lb(self):
        """STM-9640 test that ap_configs are handled properly."""
        self.start(3, in_proc=True)
        self.assertEqual(len(multiprocessing.active_children()), 2)

        # two APs are split across different processes
        attrs = {'topology2': 'T2', 'topology1': 'T1', 'ip_address': '192.0.2.1'}
        self.set_ref(NetworkMetadataReference('ap', 'TESTAP1', attrs))
        self._assert_only_on('ap_registry', 'TESTAP1', 2)
        attrs['topology2'] = 'T3'
        self.set_ref(NetworkMetadataReference('ap', 'TESTAP2', attrs))
        self._assert_only_on('ap_registry', 'TESTAP2', 1)

        # now add in an ap_config for the AP on process 2 and make sure
        # it stays on process 2.
        self.set_ref(NetworkMetadataReference('ap_config', 'TESTAP1', {'ap_snmp_community': 'public'}))
        self._assert_only_on('ap_registry', 'TESTAP1', 2)

    @staticmethod
    def _all_elements(ctx):
        """Return a NEW dict merging the AP, router and switch elements of `ctx`.

        A fresh dict is built so that the registries' own ``nes`` mappings
        are never modified by the test.
        """
        merged = {}
        for registry in (ctx.ap_registry, ctx.router_registry, ctx.switch_registry):
            merged.update(registry.nes)
        return merged

    def test_lb_by_type(self):
        """Test load balancing across processes by the element type."""
        self.start(3, in_proc=True, extra_config={'map_process_by_type': 'true'})
        self.assertEqual(len(multiprocessing.active_children()), 2)

        # Create two of each type, to make sure they get mapped properly.
        refs = [
            ('ap', 'TESTAP1', {'topology2': 'T2', 'topology1': 'AP 1', 'ip_address': '192.0.2.1'}),
            ('ap', 'TESTAP2', {'topology2': 'T2', 'topology1': 'AP 2', 'ip_address': '192.0.2.2'}),
            ('router', 'TESTRTR1', {'site': 'T2', 'name': 'Router 1', 'host': '192.0.2.11'}),
            ('router', 'TESTRTR2', {'site': 'T2', 'name': 'Router 2', 'host': '192.0.2.12'}),
            ('switch', 'TESTSW1', {'site': 'T2', 'name': 'Switch 1', 'host': '192.0.2.21'}),
            ('switch', 'TESTSW2', {'site': 'T2', 'name': 'Switch 2', 'host': '192.0.2.22'}),
        ]
        for ne_type, name, attrs in refs:
            self.set_ref(NetworkMetadataReference(ne_type, name, attrs))

        # Each type is assigned to a different process
        per_process = [self._all_elements(ctx) for ctx in self._all_contexts()]
        for i, elems in enumerate(per_process):
            self.assertEqual(len(elems), 2, 'context %d' % i)
        types = {tuple(ne.ne_type for ne in elems.values()) for elems in per_process}
        self.assertEqual(types, {('ap', 'ap'), ('router', 'router'), ('switch', 'switch')})
