import asyncio
from collections import namedtuple
from itertools import chain
import logging
import socket
from sys import intern

import pytricia

from preseem_grpc_model import network_metadata_pb2
from preseem.grpc_client import PreseemGrpcClient

# A network metadata entity: the network prefix (or bare address) it is keyed
# by, its start time in seconds, and a dict of its attributes.
NetworkMetadataEntity = namedtuple('NetworkMetadataEntity',
                                   ('network_prefix', 'start', 'attributes'))
# A reference entity identified by (type, value), with a dict of attributes.
NetworkMetadataReference = namedtuple('NetworkMetadataReference',
                                      ('type', 'value', 'attributes'))
# Index configuration: index references of `type` by the value of `attribute`.
NetworkMetadataIndex = namedtuple('NetworkMetadataIndex', ('type', 'attribute'))
# Same as NetworkMetadataIndex, but the model lowercases the attribute value
# before using it as the index key.
NetworkMetadataLowercaseIndex = namedtuple('NetworkMetadataLowercaseIndex',
                                           ('type', 'attribute'))


class NetworkMetadataClient(PreseemGrpcClient):
    """gRPC client preconfigured for the network metadata API endpoint."""

    def __init__(self, api_key, host=None, port=None, beta=None):
        """
        When no host is given, the default API host is used; a truthy `beta`
        selects the "-beta" variant of that default.  Compression is always
        requested from the base client.
        """
        if not host:
            suffix = '-beta' if beta else ''
            host = 'nmapi{}.preseem.com'.format(suffix)
        super().__init__(api_key, host, port, compression=True)


def ip_sort_key(x):
    """Sort key for listing IP address strings.

    Plain IPv4 addresses map to their 4-byte packed form.  An IPv4 prefix in
    "address/length" notation maps to the packed address followed by one byte
    holding the prefix length, so prefixes sort next to (and right after) the
    address they start at instead of all collapsing to the same key.

    Anything that cannot be parsed this way (IPv6 strings, malformed input)
    maps to an empty bytes object and therefore sorts first.
    """
    addr, slash, plen = x.partition('/')
    try:
        key = socket.inet_aton(addr)
        if slash:
            # bytes([n]) rejects values outside 0..255; int() rejects garbage.
            key += bytes([int(plen)])
    except (OSError, ValueError):
        return bytes(0)
    return key


def _resolve_refs(refs, attrs, _seen=None):
    """Helper function to resolve reference attributes.

    `refs` is a ``{type: {value: reference}}`` map and `attrs` is an attribute
    dict.  Every attribute whose value is a `Reference` is looked up in `refs`
    (in sorted order) and the attributes of the referenced object, themselves
    resolved recursively, are merged into `attrs`.  Existing keys are never
    overwritten, so values set directly on `attrs` win over inherited ones and
    the alphabetically first reference wins among references.

    `attrs` is modified in place and also returned.  `_seen` is internal: it
    holds the (type, value) pairs on the current resolution path so that a
    reference cycle is cut instead of recursing without bound.
    """
    path = _seen if _seen is not None else frozenset()
    ref_attrs = sorted(
        (t, v) for t, v in attrs.items() if isinstance(v, Reference))

    for t, v in ref_attrs:
        if (t, v) in path:
            continue  # already being resolved further up: cyclic reference
        ref = (refs.get(t) or {}).get(v)
        if not ref:
            continue
        # Work on a copy so the stored reference object is left untouched.
        child_attrs = _resolve_refs(refs, ref.attributes.copy(),
                                    path | {(t, v)})
        for at, av in child_attrs.items():
            attrs.setdefault(at, av)
    return attrs


def _diff(src, dst, refs_by_type=False):
    """Helper function to return operations to make dst like src.

    With refs_by_type false, both arguments are flat ``{key: obj}`` maps.
    With refs_by_type true, they are ``{type: {id: obj}}`` maps and objects
    are compared per (type, id) pair.

    Returns a tuple ``(adds, dels, mods)`` of lists, each in sorted key order:
    src objects missing from dst, dst objects missing from src, and src
    objects whose dst counterpart compares unequal.
    """
    if refs_by_type:
        keys = {(rtype, rid)
                for table in (src, dst)
                for rtype, members in table.items()
                for rid in members.keys()}

        def fetch(table, key):
            return (table.get(key[0]) or {}).get(key[1])
    else:
        keys = set(src.keys()) | set(dst.keys())

        def fetch(table, key):
            return table.get(key)

    adds, dels, mods = [], [], []
    for key in sorted(keys):
        current = fetch(dst, key)
        wanted = fetch(src, key)
        if wanted is None:
            dels.append(current)
        elif current is None:
            adds.append(wanted)
        elif current != wanted:
            mods.append(wanted)
    return adds, dels, mods


class Reference(str):
    """String subclass marking an attribute value as a reference.

    The GRPC conversion code checks for this exact type to decide that a
    reference attribute (rather than a plain string) should be created.
    """


class uint64(int):
    """Integer subclass marking an attribute value as unsigned.

    The GRPC conversion code checks for this exact type to decide that a
    uint64 attribute (rather than an int64 one) should be created.
    """


def _attrs_to_dict(attrs):
    """Convert a map of GRPC attribute messages to a ``{name: value}`` dict.

    Each message's `type` selects which typed field is read.  Reference
    values are wrapped in `Reference` and unsigned values in `uint64` so that
    passing the result back through `_nm_set_attrs` recreates attributes of
    the same type.  Attribute names are interned.

    An attribute whose type is not one of the known kinds is skipped with a
    warning, rather than silently picking up the value of the attribute
    processed before it.
    """
    result = {}
    for attr, data in attrs.items():
        kinds = network_metadata_pb2.NetworkMetadataAttribute
        kind = data.type
        if kind == kinds.REFERENCE:
            val = Reference(data.refVal.value)
        elif kind == kinds.BOOL:
            val = data.boolVal
        elif kind == kinds.DOUBLE:
            val = data.doubleVal
        elif kind == kinds.STRING:
            val = data.stringVal
        elif kind == kinds.INT64:
            val = data.int64Val
        elif kind == kinds.UINT64:
            val = uint64(data.uint64Val)
        else:
            logging.warning('Ignoring nm attribute %s with unknown type %s',
                            attr, kind)
            continue
        result[intern(attr)] = val
    return result


def _nm_set_attrs(nm_attrs, py_attrs):
    """Populate a GRPC attribute map from a Python ``{name: value}`` dict.

    Values that are None are skipped.  The exact Python type of each value
    (subclasses are not accepted) selects the attribute type: Reference,
    bool, float, str, int (int64) or uint64.  Any other type raises
    RuntimeError.
    """
    for name, value in py_attrs.items():
        if value is None:
            continue  # ignore null values
        slot = nm_attrs[name]
        slot.name = name
        kinds = network_metadata_pb2.NetworkMetadataAttribute
        kind = type(value)
        if kind is Reference:
            slot.type = kinds.REFERENCE
            slot.refVal.type = name
            slot.refVal.value = str(value)
        elif kind is bool:
            slot.type = kinds.BOOL
            slot.boolVal = value
        elif kind is float:
            slot.type = kinds.DOUBLE
            slot.doubleVal = value
        elif kind is str:
            slot.type = kinds.STRING
            slot.stringVal = value
        elif kind is int:
            slot.type = kinds.INT64
            slot.int64Val = value
        elif kind is uint64:
            slot.type = kinds.UINT64
            slot.uint64Val = value
        else:
            raise RuntimeError('unknown attr type {}'.format(kind.__name__))


class PreseemNetworkMetadataModel(object):

    def __init__(self, client, company_id=None, cbk=None, refs_by_type=None,
                 indexes=None):
        """
        Pass in refs_by_type=True to use the newer API where the "refs"
        attribute is {type: {id: obj}}.  Default is {(type, id): obj}.
        Indexes are specified here as a list of NetworkMetadataIndex.
        cbk, when given, is an awaitable callback invoked as
        cbk(obj, deleted=bool) for updates.
        """
        # Collaborators and configuration.
        self._client = client
        self.company_id = company_id  # used for multi-company keys
        self._cbk = cbk
        self.refs_by_type = refs_by_type

        # Public state: current references and entities.
        self.refs = {}
        self.nms = {}
        self.refs_last_updated = {}  # ref_type -> eventloop.time() of last change
        self.pt4 = pytricia.PyTricia(32)

        # State that is only populated while a synchronize is in progress.
        self.syn = None
        self.nm_syn = None
        self.start_load = None

        # Subscription tasks and the primitives guarding their startup.
        self._ent_sub_task = None
        self._ref_sub_task = None
        self._ents_loaded = asyncio.Condition()
        self._refs_loaded = asyncio.Condition()
        self._ents_lock = asyncio.Lock()
        self._refs_lock = asyncio.Lock()

        # Index storage and configuration, both {ref type: {attribute: ...}}.
        specs = indexes or []
        self._index = {spec.type: {} for spec in specs}
        self._index_cfg = {spec.type: {} for spec in specs}
        for spec in specs:
            self._index[spec.type][spec.attribute] = {}
            self._index_cfg[spec.type][spec.attribute] = spec

    def close(self):
        self._client.close_subscriptions()
        if self._ent_sub_task:
            self._ent_sub_task.cancel()
            self._ent_sub_task = None
        if self._ref_sub_task:
            self._ref_sub_task.cancel()
            self._ref_sub_task = None

    async def handle_ent_update(self, msg, syn=False):
        """Apply one entity update message to the model.

        syn=True marks the start of a synchronize: entities received from
        then on are collected in `nm_syn` until a LOADED message arrives, at
        which point they replace `nms` and the IPv4 trie `pt4` is rebuilt.
        Outside a synchronize, SET and DELETE are applied to `nms` and `pt4`
        directly.

        If a callback was configured it is awaited for every applied entity;
        a callback failure is logged and does not interrupt processing.
        """
        actions = network_metadata_pb2.NetworkMetadataEntityAction

        async def notify(nm, deleted):
            # Best effort: a failing callback must not break the update stream.
            try:
                await self._cbk(nm, deleted=deleted)
            except Exception as err:
                logging.warning('nm entity callback failed for %s: %r',
                                nm.network_prefix, err)

        if syn:
            self.nm_syn = {}
        if msg.action == actions.LOADED:
            self.nms = self.nm_syn
            self.pt4 = pytricia.PyTricia(32)
            for ip, nm in self.nms.items():
                if ':' not in ip:  # the trie only holds IPv4 keys
                    try:
                        self.pt4[ip] = nm
                    except Exception as err:
                        logging.info('Error setting nm entity %s: %s', ip, err)
            self.nm_syn = None
            if self._cbk:
                for nm in self.nms.values():
                    await notify(nm, deleted=False)
            async with self._ents_loaded:
                self._ents_loaded.notify()
        elif msg.action == actions.SET:
            nm = self._make_nm(msg.entity)
            if self.nm_syn is None:
                self.nms[nm.network_prefix] = nm
                if ':' not in nm.network_prefix:
                    try:
                        self.pt4[nm.network_prefix] = nm
                    except Exception as err:
                        logging.info('Error setting nm entity %s: %s',
                                     nm.network_prefix, err)
                if self._cbk:
                    await notify(nm, deleted=False)
            else:
                self.nm_syn[nm.network_prefix] = nm
                if ':' not in nm.network_prefix:
                    # Inserting into the trie here rejects keys it cannot
                    # store; such entities are dropped from the pending set.
                    try:
                        self.pt4[nm.network_prefix] = nm
                    except Exception as err:
                        logging.info('Error setting nm entity %s: %s',
                                     nm.network_prefix, err)
                        del self.nm_syn[nm.network_prefix]
        elif msg.action == actions.DELETE:
            nm = self._make_nm(msg.entity)
            try:
                del self.nms[nm.network_prefix]
            except KeyError:
                pass
            # Same guard as the SET paths: IPv6 keys are never put in pt4.
            if ':' not in nm.network_prefix:
                try:
                    del self.pt4[nm.network_prefix]
                except KeyError:
                    pass
            if self._cbk:
                await notify(nm, deleted=True)

    def is_loading(self):
        """Returns true if we are currently loading references."""
        # start_load holds the load start time while a load is in progress.
        if self.start_load is None:
            return False
        return True

    async def handle_ref_update(self, msg, syn=False):
        """
        Convert an updated reference to a Python type and apply it.

        syn=True means this is the start of a synchronize, which will be
        terminated by a LOADED message.  References received in between are
        collected in `syn`; on LOADED they replace `refs`, the change time of
        every reference type is stamped, and (with refs_by_type) the
        configured indexes are rebuilt.  Outside a synchronize, SET and
        DELETE update `refs`, the indexes and `refs_last_updated` directly.

        This method is only called once at a time.
        Note, if cbk is defined we just pass all messages to the callback,
        no care is taken to suppress duplicates/resyncs etc.  A callback
        failure is logged and does not interrupt processing.
        """
        actions = network_metadata_pb2.NetworkMetadataReferenceEntityAction

        async def notify(ref, deleted):
            # Best effort: a failing callback must not break the update stream.
            try:
                await self._cbk(ref, deleted=deleted)
            except Exception as err:
                logging.warning('nm reference callback failed for %s/%s: %r',
                                ref.type, ref.value, err)

        now = asyncio.get_event_loop().time()
        if syn:
            self.syn = {}
            self.start_load = now
        if msg.action == actions.LOADED:
            self.refs = self.syn
            if self.refs_by_type:
                self.refs_last_updated = {ty: now for ty in self.refs}
            else:
                # Flat layout keys are (type, value); stamp by type so this
                # map is keyed the same way SET and DELETE key it.
                self.refs_last_updated = {key[0]: now for key in self.refs}
            started = self.start_load if self.start_load is not None else now
            load_time = round(asyncio.get_event_loop().time() - started)
            if load_time >= 5:  # log a long load time
                logging.info("Loaded nm refs in %s seconds", load_time)
            self.start_load = None
            if self.refs_by_type:
                for it, index in self._index.items():
                    for ia, ir in index.items():
                        ir.clear()
                        if ia == '_LOWER':  # lowercase value index
                            ir[None] = {
                                n.lower(): r
                                for n, r in (self.refs.get(it) or {}).items()
                            }
                            continue
                        is_lower = isinstance(self._index_cfg[it][ia],
                                              NetworkMetadataLowercaseIndex)
                        for ref in (self.refs.get(it) or {}).values():
                            attr = ref.attributes.get(ia)
                            if attr:
                                if is_lower:
                                    attr = attr.lower()
                                bucket = ir.get(attr)
                                if bucket is None:
                                    bucket = ir[attr] = {}
                                bucket[ref.value] = ref
            if self._cbk:
                if self.refs_by_type:
                    loaded = chain.from_iterable(
                        [x.values() for x in self.refs.values()])
                else:
                    loaded = self.refs.values()
                for ref in loaded:
                    await notify(ref, deleted=False)
            self.syn = None
            async with self._refs_loaded:
                self._refs_loaded.notify()
        elif msg.action == actions.SET:
            ref = NetworkMetadataReference(intern(msg.entity.type), msg.entity.value,
                                           _attrs_to_dict(msg.entity.attributes))
            if self.syn is None:
                if self.refs_by_type:
                    by_value = self.refs.get(ref.type)
                    if by_value is None:
                        by_value = self.refs[ref.type] = {}
                    cur = by_value.get(ref.value)
                    by_value[ref.value] = ref
                    index = self._index.get(ref.type)
                    if index is not None:
                        for ia, ir in index.items():
                            if ia == '_LOWER':
                                bucket = ir.get(None)
                                if bucket is None:
                                    bucket = ir[None] = {}
                                bucket[ref.value.lower()] = ref
                                continue
                            is_lower = isinstance(self._index_cfg[ref.type][ia],
                                                  NetworkMetadataLowercaseIndex)
                            if cur:
                                # Remove the old ref from any indexes
                                attr = cur.attributes.get(ia)
                                if attr:
                                    if is_lower:
                                        attr = attr.lower()
                                    bucket = ir.get(attr)
                                    if bucket:
                                        bucket.pop(ref.value, None)
                            attr = ref.attributes.get(ia)
                            if attr:
                                if is_lower:
                                    attr = attr.lower()
                                bucket = ir.get(attr)
                                if bucket is None:
                                    bucket = ir[attr] = {}
                                bucket[ref.value] = ref
                else:
                    self.refs[(ref.type, ref.value)] = ref
                if self._cbk:
                    await notify(ref, deleted=False)
                self.refs_last_updated[ref.type] = now
            else:
                if self.refs_by_type:
                    by_value = self.syn.get(ref.type)
                    if by_value is None:
                        by_value = self.syn[ref.type] = {}
                    by_value[ref.value] = ref
                else:
                    self.syn[(ref.type, ref.value)] = ref
        elif msg.action == actions.DELETE:
            ref = NetworkMetadataReference(intern(msg.entity.type), msg.entity.value,
                                           _attrs_to_dict(msg.entity.attributes))
            try:
                # use the saved ref so we have its attributes
                if self.refs_by_type:
                    ref = self.refs[ref.type].pop(ref.value)
                    if not self.refs[ref.type]:
                        del self.refs[ref.type]
                    index = self._index.get(ref.type)
                    if index is not None:
                        for ia, ir in index.items():
                            if ia == '_LOWER':
                                bucket = ir.get(None)
                                if bucket:
                                    bucket.pop(ref.value.lower(), None)
                                continue
                            is_lower = isinstance(self._index_cfg[ref.type][ia],
                                                  NetworkMetadataLowercaseIndex)
                            attr = ref.attributes.get(ia)
                            if attr:
                                if is_lower:
                                    attr = attr.lower()
                                bucket = ir.get(attr)
                                if bucket:
                                    bucket.pop(ref.value, None)
                else:
                    ref = self.refs.pop((ref.type, ref.value))
                self.refs_last_updated[ref.type] = now
            except KeyError:
                pass  # unknown reference: nothing stored to remove
            if self._cbk:
                await notify(ref, deleted=True)
        else:
            logging.warning('Invalid nm reference action: %s', msg.action)

    def _make_nm(self, msg, resolve_refs=False, refs=None):
        """Build a NetworkMetadataEntity from a GRPC entity message.

        A prefix without a length, or with a length of 32, is keyed by the
        bare address; any other prefix keeps its "address/length" form.

        With resolve_refs true, attributes inherited through references are
        merged into the entity's attributes.  `refs` is the
        ``{type: {value: reference}}`` map to resolve against; when omitted,
        the references currently held by this model are used.  Callers
        converting many entities should pass `refs`, since with the flat refs
        layout the map is otherwise regrouped on every call.
        """
        cidr = msg.network_prefix.split('/')
        net = cidr[0] if len(cidr) == 1 or cidr[1] == '32' else msg.network_prefix
        obj = NetworkMetadataEntity(net, msg.start.seconds,
                                    _attrs_to_dict(msg.attributes))
        if resolve_refs:
            if refs is None:
                if self.refs_by_type:
                    refs = self.refs
                else:
                    # Regroup the flat {(type, value): ref} layout by type.
                    refs = {}
                    for (rtype, rvalue), ref in self.refs.items():
                        refs.setdefault(rtype, {})[rvalue] = ref
            _resolve_refs(refs, obj.attributes)
        return obj

    def get_ref_index(self, reftype, attr, val=None):
        """Get a map of reference types of a type with a given attr value.

        Raises KeyError if no index was configured for (reftype, attr).
        Returns None when nothing is indexed under val.
        """
        by_attr = self._index[reftype]
        by_value = by_attr[attr]
        return by_value.get(val)

    async def list_nm(self, resolve_refs=False):
        result = []

        refs = {}
        if resolve_refs:
            for ref in await self.list_refs():
                refmap = refs.get(ref.type)
                if refmap is None:
                    refmap = refs[ref.type] = {}
                refmap[ref.value] = ref

        for msg in await self._client.nm_list(company_id=self.company_id):
            result.append(self._make_nm(msg, resolve_refs))
        return result

    async def set_nm(self, entity):
        """Send an entity to the server as a set operation."""
        assert isinstance(entity, NetworkMetadataEntity)
        stamp = network_metadata_pb2.google_dot_protobuf_dot_timestamp__pb2.Timestamp()
        # entity.start is in seconds; the timestamp is filled in nanoseconds.
        stamp.FromNanoseconds(int(entity.start * 1_000_000_000))
        message = network_metadata_pb2.NetworkMetadataEntity(
            address_space='default',
            network_prefix=entity.network_prefix,
            start=stamp)
        _nm_set_attrs(message.attributes, entity.attributes)
        await self._client.nm_set(message, company_id=self.company_id)

    async def del_nm(self, entity):
        """Send an entity to the server as a delete operation."""
        assert isinstance(entity, NetworkMetadataEntity)
        stamp = network_metadata_pb2.google_dot_protobuf_dot_timestamp__pb2.Timestamp()
        # entity.start is in seconds; the timestamp is filled in nanoseconds.
        stamp.FromNanoseconds(int(entity.start * 1_000_000_000))
        message = network_metadata_pb2.NetworkMetadataEntity(
            address_space='default',
            network_prefix=entity.network_prefix,
            start=stamp)
        _nm_set_attrs(message.attributes, entity.attributes)
        await self._client.nm_del(message, company_id=self.company_id)

    async def list_refs(self):
        result = []
        for msg in await self._client.nm_ref_list(company_id=self.company_id):
            ref = NetworkMetadataReference(intern(msg.type), msg.value,
                                           _attrs_to_dict(msg.attributes))
            result.append(ref)
        return result

    async def set_ref(self, ref):
        """Send a reference, including its attributes, as a set operation."""
        assert isinstance(ref, NetworkMetadataReference)
        message = network_metadata_pb2.NetworkMetadataReferenceEntity(
            type=ref.type, value=ref.value)
        _nm_set_attrs(message.attributes, ref.attributes)
        await self._client.nm_ref_set(message, company_id=self.company_id)

    async def del_ref(self, ref):
        """Send a reference (type and value only) as a delete operation."""
        assert isinstance(ref, NetworkMetadataReference)
        message = network_metadata_pb2.NetworkMetadataReferenceEntity(
            type=ref.type, value=ref.value)
        await self._client.nm_ref_del(message, company_id=self.company_id)

    async def _load_ents(self):
        """Start the entity subscription once and wait for its initial load."""
        async with self._ents_lock:
            if self._ent_sub_task is not None:
                return  # subscription already running
            subscription = self._client.nm_ent_subscribe(
                self.handle_ent_update, company_id=self.company_id)
            loop = asyncio.get_event_loop()
            task = loop.create_task(subscription)
            # handle_ent_update notifies this condition on LOADED.
            async with self._ents_loaded:
                await self._ents_loaded.wait()
            # Only recorded after the load, so callers that see a task set
            # know the initial state has been received.
            self._ent_sub_task = task

    async def _load_refs(self):
        """Start the reference subscription once and wait for its initial load."""
        async with self._refs_lock:
            if self._ref_sub_task is not None:
                return  # subscription already running
            subscription = self._client.nm_ref_subscribe(
                self.handle_ref_update, company_id=self.company_id)
            loop = asyncio.get_event_loop()
            task = loop.create_task(subscription)
            # handle_ref_update notifies this condition on LOADED.
            async with self._refs_loaded:
                await self._refs_loaded.wait()
            # Only recorded after the load, so callers that see a task set
            # know the initial state has been received.
            self._ref_sub_task = task

    async def update(self):
        """Start the model; sync with the central database.  Return when we are
           synchronized with the starting state."""
        tasks = []
        if self._ref_sub_task is None:
            tasks.append(self._load_refs())
        if self._ent_sub_task is None:
            tasks.append(self._load_ents())
        if tasks:
            await asyncio.gather(*tasks)

    async def copy(self, dst_model):
        """Perform operations to make a destination model the same as this.

        Both models are brought up to date first; references and entities are
        then copied concurrently.  For each kind, additions are sent first,
        then deletions, then modifications.
        """

        async def apply(summary, compute, setter, deleter):
            # The diff is computed when this coroutine starts running, i.e.
            # after both models have been synchronized.
            adds, dels, mods = compute()
            logging.info(summary.format(len(adds), len(dels), len(mods)))
            for obj in adds:
                await setter(obj)
            for obj in dels:
                await deleter(obj)
            for obj in mods:
                await setter(obj)

        await asyncio.gather(self.update(), dst_model.update())
        await asyncio.gather(
            apply('refs: {} adds, {} dels, {} mods',
                  lambda: _diff(self.refs, dst_model.refs, self.refs_by_type),
                  dst_model.set_ref, dst_model.del_ref),
            apply('ipas: {} adds, {} dels, {} mods',
                  lambda: _diff(self.nms, dst_model.nms),
                  dst_model.set_nm, dst_model.del_nm))

    def diff(self, rhs_model):
        """Compare this model to another one.

        Entities are compared by their attributes after merging in the
        attributes inherited through references (each model resolved against
        its own refs, using its own refs layout).

        Returns a list of ``(key, left, right)`` tuples ordered by address:
        ``(key, attrs, None)`` for entities only on this side,
        ``(key, None, attrs)`` for entities only on the other side, and for
        entities on both sides two dicts holding just the attributes whose
        values differ.  Entities with no resolved attributes on the only side
        they exist are not reported.
        """

        def resolve_refs(refs, attrs, by_type):
            """Resolve references.  return the combined attribute map."""
            result = {}
            visited = set()  # (type, value) refs already queued; stops cycles
            pending = [attrs]
            pos = 0
            while pos < len(pending):
                current = pending[pos]
                pos += 1
                for t, v in sorted(current.items()):
                    if v is None:
                        continue  # just ignore attrs set to None
                    if not isinstance(v, Reference):
                        # first value found (breadth first) wins
                        result.setdefault(t, v)
                        continue
                    if (t, v) in visited:
                        continue
                    visited.add((t, v))
                    if by_type:
                        ref = (refs.get(t) or {}).get(v)
                    else:
                        ref = refs.get((t, v))
                    if ref:
                        pending.append(ref.attributes)
            return result

        lhs_by_type = self.refs_by_type
        rhs_by_type = getattr(rhs_model, 'refs_by_type', lhs_by_type)

        diffs = []
        keys = set(self.nms.keys()) | set(rhs_model.nms.keys())
        # The key itself breaks ties so the output order is deterministic.
        for key in sorted(keys, key=lambda k: (ip_sort_key(k), k)):
            lobj = self.nms.get(key)
            robj = rhs_model.nms.get(key)
            lattrs = (resolve_refs(self.refs, lobj.attributes, lhs_by_type)
                      if lobj else None)
            rattrs = (resolve_refs(rhs_model.refs, robj.attributes, rhs_by_type)
                      if robj else None)
            if robj is None and lattrs:
                diffs.append((key, lattrs, None))
            elif lobj is None and rattrs:
                diffs.append((key, None, rattrs))
            elif robj and lobj and lattrs != rattrs:
                diffs.append((key, {
                    x: y
                    for x, y in lattrs.items() if rattrs.get(x) != y
                }, {
                    x: y
                    for x, y in rattrs.items() if lattrs.get(x) != y
                }))
        return diffs
