import asyncio
from importlib import import_module
import json
import unittest

import preseem.grpc_config_model
from preseem_protobuf.config import company_configs_pb2, powercode_pb2
from preseem import ConfigModel, PreseemGrpcClient
from google.protobuf.json_format import MessageToJson

class ConfigType:
    """Metadata bundle describing a single config type.

    ``name`` is the lowercase name of the enum, ``value`` its int value and
    ``pb`` the pb2 import.  The service stub (``svc``) and grpc type
    (``pb_type``) start out unset (None).
    """

    def __init__(self, name, value, pb):
        self.name, self.value, self.pb = name, value, pb
        self.svc = self.pb_type = None

def _make(nt, **kwargs):
    """Helper to make a new namedtuple object"""
    return nt(**{x: kwargs.get(x) for x in nt._fields})

class FakeConfigClient(PreseemGrpcClient):
    """Fake, in-memory config client used to drive ConfigModel in tests.

    Stored objects live in ``self.pbs``, keyed by
    ``(company_id, config_type, config_id)``.  Methods taking an optional
    ``company_id`` fall back to ``self.company_id`` when it is not given.
    Registered callbacks are awaited inline, so a test observes their effect
    as soon as the triggering call returns.
    """

    def __init__(self):
        super().__init__('TEST')
        self.pbs = {}                    # (company_id, config_type, config_id) -> stored object
        self.company_id = 'TestCompany'  # default company for every call
        self._cbk = None                 # company-wide event callback (set by sub_* methods)
        self._new_cbk = None
        self._obj_cbk = {}               # per-object callbacks, keyed like self.pbs
        self._pc_cbk = None

    def close_subscriptions(self):
        """Forget the company-wide event callback."""
        self._cbk = None

    async def get_config(self, config_type, config_id, company_id=None):
        """Return the stored object; raise KeyError if there is none.

        ``company_id`` defaults to ``self.company_id``, matching the other
        methods (objects are never stored under a None company).
        """
        company_id = company_id or self.company_id
        return self.pbs[(company_id, config_type, config_id)]

    async def del_config(self, config_type, config_id, company_id=None):
        """Delete a stored object and emit a DELETE event.

        Returns True if the object existed, False otherwise.  Exceptions
        raised by the company-wide callback are printed and otherwise
        ignored (best effort).
        """
        company_id = company_id or self.company_id  # have to have a company id
        try:
            del self.pbs[(company_id, config_type, config_id)]
        except KeyError:
            return False
        if self._cbk:
            Event = company_configs_pb2.CompanyConfigEvent
            ev = Event(company_id=company_id, config_id=config_id,
                       event_type=Event.DELETE, config_type=config_type)
            try:
                await self._cbk(ev, False)
            except Exception as err:
                print("BAD", err)
        return True

    async def sub_config(self, config_type, config_id, cbk, company_id=None):
        """Register a per-object callback and deliver the current value.

        The callback is invoked immediately only if the object already
        exists; otherwise it fires on the next ``set_config`` for that key.
        """
        company_id = company_id or self.company_id
        key = (company_id, config_type, config_id)
        self._obj_cbk[key] = cbk
        pb = self.pbs.get(key)
        if pb is not None:
            await cbk(company_id, config_type, pb)

    async def set_config(self, config_type, obj, company_id=None):
        """Store *obj* under ``obj.id`` and notify subscribers.

        The per-object callback (if any) receives the new object; the
        company-wide callback receives an UPDATE event when the key already
        existed and a SET event otherwise.  Always returns True.
        """
        company_id = company_id or self.company_id  # have to have a company id
        # Kept for its side effect: the lookup fails on an unknown config
        # type.  NOTE(review): presumably config_types is provided by
        # PreseemGrpcClient — confirm.
        self.config_types[config_type]
        key = (company_id, config_type, obj.id)
        existed = key in self.pbs
        self.pbs[key] = obj
        cbk = self._obj_cbk.get(key)
        if cbk:
            await cbk(company_id, config_type, obj)
        if self._cbk:
            Event = company_configs_pb2.CompanyConfigEvent
            op = Event.UPDATE if existed else Event.SET
            ev = Event(company_id=company_id, config_id=obj.id,
                       event_type=op, config_type=config_type)
            await self._cbk(ev, False)
        return True

    async def list_config_instances(self, company_id=None):
        """Return an InstancesListResponse of (type, id) for one company.

        NOTE(review): ``instances_pb2`` is not imported at the top of this
        file, so this raises NameError until the correct import is added.
        """
        company_id = company_id or self.company_id
        r = instances_pb2.InstancesListResponse(success=True)
        r.items.extend([instances_pb2.TypeID(type=config_type, id=config_id)
                        for cfg_company, config_type, config_id in self.pbs
                        if cfg_company == company_id])
        return r

    async def sub_config_instances(self, cbk, types, company_id=None):
        """Install *cbk* as the company-wide event callback."""
        self._cbk = cbk

    async def sub_company_configs(self, cbk, company_id=None, subscribe_all=None):
        """Subscribe *cbk* to company config events.

        Replays a SET event for every stored object, then a LOADED event, and
        returns a background task standing in for the live subscription.
        Cancelling that task clears the company-wide callback.
        ``company_id`` and ``subscribe_all`` are accepted for interface
        compatibility but are not used.
        """
        Event = company_configs_pb2.CompanyConfigEvent

        async def _subscription():
            # Idle until cancelled, mimicking a long-lived server stream.
            try:
                while True:
                    await asyncio.sleep(1)
            except asyncio.CancelledError:
                self._cbk = None
                raise

        self._cbk = cbk
        # Snapshot the keys: a callback may add or remove objects while we replay.
        for cfg_company, config_type, config_id in list(self.pbs):
            msg = Event(company_id=cfg_company, config_id=config_id,
                        event_type=Event.SET, config_type=config_type)
            await self._cbk(msg)
        await self._cbk(Event(event_type=Event.LOADED))
        # Schedule the idle coroutine so the caller gets a cancellable handle.
        return asyncio.ensure_future(_subscription())


class TestConfig(unittest.TestCase):
    """Exercise ConfigModel against the in-memory FakeConfigClient."""

    def setUp(self):
        # A fresh event loop per test; installing it as current is needed to
        # initialize asyncio before the client and model are created.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.client = FakeConfigClient()
        self.model = ConfigModel(self.client)

    def tearDown(self):
        self.model.close()
        self.loop.close()

    def _delete(self, cfg):
        """Synchronously delete *cfg* through the model."""
        return self.loop.run_until_complete(self.model.delete(cfg))

    def _list(self):
        """Synchronously list all configs through the model."""
        return self.loop.run_until_complete(self.model.list())

    def _listen(self, cbk=None):
        """Synchronously start listening for changes, with optional callback."""
        return self.loop.run_until_complete(self.model.listen(cbk))

    def _set(self, cfg):
        """Synchronously store *cfg* through the model."""
        return self.loop.run_until_complete(self.model.set(cfg))

    def _make(self, ty, **kwargs):
        """Return the default config of type *ty* with *kwargs* applied."""
        obj = self.model.get_default(ty)
        return obj._replace(**kwargs)

    def _make_full_powercode(self, cfg_id):
        """Return a powercode config with every tested attribute populated."""
        return self._make('powercode', id=cfg_id, enabled=True,
                          url='https://192.0.2.1', api_key='--apikey--',
                          block_non_active_customers=True, report_usage=True,
                          service_rate_multiplier=1.07,
                          extra={'map_session': True, 'cpe_categories': ['Other']})

    def test_list_1(self):
        """Test basic list operations"""
        self.assertEqual(self._list(), [])
        cfg = self._make('powercode', id='test')
        self._set(cfg)
        self.assertEqual(self._list(), [cfg])

    def test_del_1(self):
        """Test deletion."""
        cfg = self._make('powercode', id='test')
        self._set(cfg)
        r = self._list()
        self.assertEqual(len(r), 1)
        self.assertEqual(r[0].id, cfg.id)
        self._delete(cfg)
        self.assertEqual(len(self._list()), 0)

    def test_set_1(self):
        """Test a basic Set operation"""
        cfg = self._make('powercode', id='test')
        self._set(cfg)
        r = self._list()
        self.assertEqual(len(r), 1)
        self.assertEqual(r[0].id, cfg.id)

    def test_set_2(self):
        """An ID gets created if not passed in."""
        self._set(self._make('powercode'))
        r = self._list()
        self.assertEqual(len(r), 1)
        self.assertTrue(r[0].id)

    def test_set_powercode(self):
        """Test that Powercode attributes are set and obtained back via get."""
        cfg = self._make_full_powercode('test')
        self._set(cfg)
        r = self._list()
        self.assertEqual(len(r), 1)
        self.assertEqual(r[0], cfg)

    def test_subscribe_1(self):
        """The internal objs dict is kept in sync."""
        self._listen()
        cfg = self._make('powercode', id='test')
        self._set(cfg)
        self.assertEqual(len(self.model.objs), 1)
        cfg = self._make_full_powercode('test2')
        self._set(cfg)
        self.assertEqual(len(self.model.objs), 2)
        self.assertEqual(self.model.objs.get('test2'), cfg)
        self._delete(cfg)
        self.assertEqual(len(self.model.objs), 1)
        self.assertEqual(self.model.objs.get('test2'), None)

    def test_subscribe_2(self):
        """Callbacks are sent properly."""
        async def cbk(old, new):
            nonlocal cbk_old, cbk_new
            cbk_old = old
            cbk_new = new
        self._listen(cbk)
        cfg = self._make_full_powercode('test')

        # New object: old is None, new is the object.
        cbk_old = cbk_new = None
        self._set(cfg)
        self.assertEqual(cbk_old, None)
        self.assertEqual(cbk_new, cfg)

        # Changed object: both old and new are delivered.
        cbk_old = cbk_new = None
        newcfg = cfg._replace(enabled=False)
        self._set(newcfg)
        self.assertEqual(cbk_old, cfg)
        self.assertEqual(cbk_new, newcfg)

        # Unchanged object: no callback at all.
        cbk_old = cbk_new = None
        self._set(newcfg)
        self.assertEqual(cbk_old, None)
        self.assertEqual(cbk_new, None)

        # Deleted object: new is None.
        self._delete(newcfg)
        self.assertEqual(cbk_old, newcfg)
        self.assertEqual(cbk_new, None)

    def test_ap_csvfile_1(self):
        """Basic checks for the ap-csvfile config API."""
        # Get the default value and check it
        data = self.model.get_default('ap-csvfile')
        self.assertIsNotNone(data)
        for attr in ('url', 'google_id', 'delimiter', 'header',
                     'poll_interval', 'fields'):
            self.assertIsNotNone(getattr(data, attr, None), attr)
        fields = data.fields
        self.assertIsInstance(fields, dict)
        for key in ('id', 'sector', 'tower', 'ip_address'):
            self.assertIsNotNone(fields.get(key), key)

        # Set an object
        data = data._replace(url='https://preseem.com')
        data.fields['sector'] = 'ssid'
        self._set(data)

        # Verify it
        cfgs = self._list()
        self.assertIsNotNone(cfgs)
        self.assertEqual(len(cfgs), 1)
        self.assertEqual(cfgs[0], data)

    def test_customer_csvfile_map(self):
        """Test handling of the map in the customer-csvfile protobuf."""
        data = self.model.get_default('customer-csvfile')
        data = data._replace(account_status_config={"OFFLINE": {"kbps_down": "1", "kbps_up": "1", "account_status": "DELINQUENT"}})
        self._set(data)
        cfgs = self._list()
        self.assertEqual(len(cfgs), 1)
        self.maxDiff = 10000
        obj = cfgs[0]._asdict()
        obj.pop('id')  # don't compare this, it changes from run to run
        expected_fields = dict.fromkeys(
            ['id', 'secondary_id', 'account_id', 'name', 'cpe_mac', 'dev_mac',
             'cpe_ip_address', 'ip_address', 'kbps_down', 'kbps_up',
             'mbps_down', 'mbps_up', 'package_name', 'sector', 'imsi',
             'ip_address_2', 'serial_number', 'account_status'], '')
        self.assertEqual(obj, {
            'enabled': False,
            'extra': None,
            'url': '',
            'google_id': '',
            'delimiter': '',
            'fields': expected_fields,
            'header': [],
            'poll_interval': 0,
            'system': '',
            'service_rate_multiplier': 0.0,
            'account_status_config': ['OFFLINE'],
        })
