"""Function to sync a Preseem network metadata store to the new Preseem model."""
import asyncio
from collections import namedtuple
from ipaddress import ip_address, ip_network
import logging
from uuid import UUID

from grpc import RpcError

from preseem import NetworkMetadataReference
from preseem.source_model import AccessPoint, Olt, Router, Service
from preseem_protobuf.model.common_pb2 import ElementStatus
from preseem_protobuf.model.service_pb2 import Attachment, SetAttachmentReq

# STM-6039: one module-wide lock serializes sync_preseem_model() and
# sync_preseem_elements() so two concurrent syncs cannot create duplicate
# entities.  It starts as None; those functions create it on first use.
sync_lock: "asyncio.Lock | None" = None


# This object is used to pass element items to the sync_preseem_model
# algorithm for a given source and intent role.
class ElementSyncItem:

    def __init__(self, id, name, site, host):
        self.id = id
        self.name = name
        self.site = site
        self.host = host
        self._site_uuid = None

    def __repr__(self):
        clsname = self.__class__.__name__
        attrs = ', '.join(f"{k}={v!r}" for k, v in self.__dict__.items() if k[0] != '_')
        return f"{clsname}({attrs})"


def _field(attr, empty_as_none=False):
    """Build a field-map mapping func that reads attribute ``attr`` of a bss object.

    With ``empty_as_none`` set, falsy values (0, '', None) map to None.
    """
    if empty_as_none:
        return lambda obj: getattr(obj, attr) or None
    return lambda obj: getattr(obj, attr)


# Each map is {model field name -> func(bss object) returning the value to sync}.
_account_field_map = {
    name: _field(name)
    for name in ('type', 'status', 'name', 'first_name', 'last_name')
}

_package_field_map = {
    'name': _field('name'),
    'download_rate': _field('download_rate', empty_as_none=True),
    'upload_rate': _field('upload_rate', empty_as_none=True),
}

_elem_disc_field_map = {
    'name': _field('name'),
    'management_ip': _field('host'),
    'site_uuid': _field('_site_uuid'),
    'intent_role_uuid': lambda x: x.role_uuid.bytes,
}

_service_field_map = {
    'account_uuid': _field('_account_uuid'),
    'package_uuid': _field('_package_uuid'),
    'download_rate': _field('_download_rate', empty_as_none=True),
    'upload_rate': _field('_upload_rate', empty_as_none=True),
    'attachment': lambda x: x._attachment.pb if x._attachment else None,
    'subscriber_identifier': _field('_subscriber_identifier'),
    'source_url': _field('_url'),
}

# Elements sync the same fields as element discoveries, plus the poller hash.
_elem_field_map = {
    **_elem_disc_field_map,
    'poller_hash': _field('poller_hash'),
}


def has_any_fields_set(message):
    """Return True if any (descriptor, value) pair from ``message.ListFields()``
    carries a truthy value, otherwise False."""
    return any(value for _descriptor, value in message.ListFields())


async def call_and_log_error(co):
    """Await the GRPC coroutine *co* and return its result.

    An RpcError, or any other Exception, is logged as a warning and None is
    returned instead.  Task cancellation is always re-raised to the caller.
    """
    try:
        result = await co
    except asyncio.CancelledError:
        raise
    except RpcError as rpc_err:
        logging.warning('sync_preseem_model: error calling service: %s (%s)',
                        rpc_err.code(), rpc_err.details())
    except Exception as exc:
        logging.warning('sync_preseem_model: %s error calling service: %s',
                        type(exc).__name__, exc, exc_info=True)
    else:
        return result
    return None


async def run_limited(coros, limit: int):
    """Await every coroutine in *coros*, allowing at most *limit* to run at once.

    Results come back in the same order as *coros*, as with asyncio.gather.
    """
    gate = asyncio.Semaphore(limit)

    async def _throttled(coro):
        async with gate:
            return await coro

    wrapped = [_throttled(coro) for coro in coros]
    return await asyncio.gather(*wrapped)


def _copy_model_field(dst, fdesc, value):
    """Overwrite field *fdesc* of protobuf message *dst* with *value*.

    Protobuf does not allow assigning to a repeated field with setattr (it
    raises AttributeError), so repeated fields have their contents replaced in
    place instead.
    """
    if fdesc.label == fdesc.LABEL_REPEATED:
        container = getattr(dst, fdesc.name)
        del container[:]
        container.extend(value)
    else:
        setattr(dst, fdesc.name, value)


def _diff_attachment(mval, bval, attachment_cfg):
    """Compare a model attachment (*mval*) with a bss attachment (*bval*).

    Either side may be None.  Returns ``(bval, changes)``:
    - bval is a private copy of the bss attachment (or None).  In it, fields
      configured with ``{'ignore': True}`` in *attachment_cfg* are reset to the
      model's current values.
    - changes lists human-readable per-field differences for logging.  An empty
      list means the attachments are not actually different.
    """
    if bval is not None:
        batt = Attachment()
        batt.CopyFrom(bval)  # make a copy so the caller's object is not modified
        bval = batt
    achgs = []
    if mval and bval:
        for sf in bval.DESCRIPTOR.fields:
            raw_mval = getattr(mval, sf.name)
            smval = raw_mval
            sbval = getattr(bval, sf.name)
            # Compare these list fields as sets so order (and, for prefixes,
            # textual formatting) does not produce spurious diffs.
            if sf.name == 'mac_addresses':
                smval = set(smval)
                sbval = set(sbval)
            if sf.name == 'network_prefixes':
                smval = {ip_network(x) for x in smval}
                sbval = {ip_network(x) for x in sbval}
            field_cfg = attachment_cfg.get(sf.name)
            if field_cfg and field_cfg.get('ignore'):
                # To leave the field as it was, we have to set it to what is in
                # model-server now, because the attachment is set as one full
                # object.  Use the raw (un-normalized) model value for this.
                _copy_model_field(bval, sf, raw_mval)
                sbval = smval
            if smval != sbval:
                achgs.append(f"{sf.name}[{smval}->{sbval}]")
    elif bval and not mval:
        # PPA-318 the model sets an empty attachment to no attachment, so
        # handle the comparison properly here.
        if has_any_fields_set(bval):
            achgs = [f"null->{bval}"]
    elif mval and not bval:
        achgs = [f"{mval}->null"]
    return bval, achgs


def _diff_object(mobj, bobj, field_map, attachment_cfg):
    """Diff model object *mobj* against bss object *bobj* using *field_map*.

    Returns ``(changes, mods)``.  changes holds descriptions of each difference
    for logging.  mods holds the ``{model field -> new value}`` kwargs for the
    update RPC.  An inactive model object is always reactivated.
    """
    changes = []
    mods = {}
    if mobj.inactive:
        changes.append('reactivated')
        mods['inactive'] = False
    for f, m in field_map.items():
        mval = getattr(mobj, f)
        bval = m(bobj)
        if mval != bval:
            if f == 'attachment':
                bval, achgs = _diff_attachment(mval, bval, attachment_cfg)
                if not achgs:
                    continue  # no further inspection, not actually a diff
                changes.append(f'{f}: {",".join(achgs)}')
            else:
                changes.append(f'{f}: {mval} -> {bval}')
            mods[f] = bval
    return changes, mods


async def _sync_objects(model,
                        company_uuid,
                        source,
                        bss_objs,
                        model_objs,
                        createrpc,
                        updaterpc,
                        field_map,
                        deactivate,
                        noact,
                        cfg=None,
                        init_fields=None,
                        concurrency=10):
    """
    Sync python bss-sourced objects to model objects.

    Pass in the model and the company_uuid and source we are working on.
    - bss_objs is a dict {source_id -> obj} where obj is a python bss object.
    - model_objs is a list of objects from the model; objects owned by other
      sources are ignored.
    - createrpc and updaterpc are the grpc create/update methods for the object.
    - field_map is a dict of {model field name -> mapping func} where the mapping
      func returns the value to sync from the bss object.
    - deactivate: when true, model objects of this source that have no matching
      bss object are set inactive.
    - noact: only log what would be done; no create/update RPCs are made.
    - cfg: optional config dict; ``cfg['attachment'][<field>]['ignore']`` keeps
      the model's current value for that attachment field.
    - init_fields: extra kwargs passed only to createrpc.
    - concurrency: maximum number of RPCs in flight at once.

    Modifications run first, then deactivations, then creates.  Returns the
    list of model objects that were successfully created.
    """
    # cfg may be None, or may have no 'attachment' section at all.
    attachment_cfg = (cfg or {}).get('attachment') or {}
    new_ids = set(bss_objs)
    modcos = []
    delcos = []
    for mobj in model_objs:
        if mobj.source != source:
            continue  # ignore objects owned by other sources
        bobj = bss_objs.get(mobj.source_id)
        if bobj:
            new_ids.discard(mobj.source_id)
            changes, mods = _diff_object(mobj, bobj, field_map, attachment_cfg)
            if not mods:
                continue
            objdesc = mobj
            if isinstance(bobj, Service):
                # more concise log message for Service
                objdesc = f"Service {UUID(bytes=mobj.uuid)}"
            logging.info('sync_preseem_model(%s): modify %s (%s)', source, objdesc,
                         ', '.join(changes))
            if not noact:
                if isinstance(bobj, Service) and 'attachment' in mods:
                    # special case for setting attachment: an empty request
                    # clears the attachment.
                    attachment = mods.pop('attachment')
                    set_att_req = SetAttachmentReq()
                    if attachment:
                        set_att_req.attachment.CopyFrom(attachment)
                    mods['attachment_req'] = set_att_req
                modcos.append(
                    call_and_log_error(
                        updaterpc(company_uuid=company_uuid, uuid=mobj.uuid, **mods)))
        elif deactivate and not mobj.inactive:  # sync deletes unless told not to
            logging.info("sync_preseem_model(%s): deactivate %s %s", source,
                         type(mobj).__name__.lower(), UUID(bytes=mobj.uuid))
            if not noact:
                delcos.append(
                    call_and_log_error(
                        updaterpc(company_uuid=company_uuid,
                                  uuid=mobj.uuid,
                                  inactive=True)))

    # Run modify and deactivate ops with bounded concurrency.
    if modcos:
        logging.info("Modify %d objects", len(modcos))
        await run_limited(modcos, concurrency)
        logging.info("Modified %d objects", len(modcos))
    if delcos:
        logging.info("Deactivate %d objects", len(delcos))
        await run_limited(delcos, concurrency)
        logging.info("Deactivated %d objects", len(delcos))

    # Creates also run with bounded concurrency; failed creates return None.
    addcos = []
    for obj_id in new_ids:
        bobj = bss_objs.get(obj_id)
        if not bobj:
            continue
        logging.info('sync_preseem_model(%s): create %s', source, bobj)
        if not noact:
            fields = {f: m(bobj) for f, m in field_map.items()}
            addcos.append(
                call_and_log_error(
                    createrpc(company_uuid=company_uuid,
                              source=source,
                              source_id=bobj.id,
                              **fields,
                              **(init_fields or {}))))

    new_objs = []
    if addcos:
        logging.info("Create %d objects", len(addcos))
        for mobj in await run_limited(addcos, concurrency):
            if mobj:
                logging.info('sync_preseem_model(%s): created %s "%s"', source,
                             type(mobj).__name__, UUID(bytes=mobj.uuid))
                new_objs.append(mobj)
        logging.info("Created %d of %d objects", len(new_objs), len(addcos))
    return new_objs


async def _sync_sites(model, company_uuid, source, objs, noact=False):
    """Make sure model has all sites needed by source objs, and set the site_uuid on each obj.

    - objs: iterable of objects with a ``site`` attribute holding a site name
      (falsy means "no site").  It is copied into a list because it is walked
      twice, so one-shot iterables are fine.
    - Inactive sites that are referenced are reactivated, with one Update per
      site.  Missing sites are created.  Neither happens when noact is set.
    - Each obj gets ``_site_uuid`` set to its site's uuid, or None when it has
      no site.  In noact mode, a site that does not exist yet also yields None,
      since it is only logged and not created.

    Raises RuntimeError if a needed site could not be created.  The underlying
    RPC error has already been logged by call_and_log_error.
    """
    objs = list(objs)
    sites = await model.Site.List(company_uuid=company_uuid)
    sites = {x.uuid: x for x in sites}

    # Create a map of name -> site.  Handle multiple sites with the same
    # name by preferring active sites, then taking the first sorted uuid.
    site_names = {}
    for site in sorted(sites.values(), key=lambda x: x.uuid):
        cur_site = site_names.get(site.name)
        if not cur_site or (cur_site.inactive and not site.inactive):
            site_names[site.name] = site

    new_sites = set()
    reactivated = set()
    for sobj in objs:
        if not sobj.site:
            continue
        site = site_names.get(sobj.site)
        if not site:
            new_sites.add(sobj.site)
        elif site.inactive and sobj.site not in reactivated:
            # Reactivate only once, even when many objects share the site.
            reactivated.add(sobj.site)
            logging.info('sync_sites(%s): reactivate site "%s"', source, sobj.site)
            if not noact:
                await call_and_log_error(
                    model.Site.Update(company_uuid=company_uuid,
                                      uuid=site.uuid,
                                      inactive=False))

    for site_name in new_sites:
        logging.info('sync_sites(%s): create site "%s"', source, site_name)
        if not noact:
            site = await call_and_log_error(
                model.Site.Create(company_uuid=company_uuid, name=site_name))
            if site:
                site_names[site_name] = site

    for sobj in objs:
        site = site_names.get(sobj.site) if sobj.site else None
        if site is None and sobj.site and not noact:
            # The Create call above failed; abort rather than silently
            # detaching the object from its site.
            raise RuntimeError(
                f'sync_sites({source}): site "{sobj.site}" could not be created')
        sobj._site_uuid = site.uuid if site is not None else None


def _validate_items(items, source, validate_hosts):
    """Validate element sync items before any model change is made.

    Raises ValueError for an invalid site, source_id or name, for duplicate
    ids, or (only when *validate_hosts* is true) for a host that is not an IP
    address.  Returns {item.id -> normalized IP address string} when
    *validate_hosts* is true, otherwise an empty dict.
    """
    item_ids = set()
    dup_ids = set()
    for item in items:
        if item.site is not None and (not isinstance(item.site, str)
                                      or item.site == ''):
            raise ValueError('invalid string value for site')
        if not isinstance(item.id, str) or item.id == '':
            raise ValueError('invalid string value for source_id')
        if not isinstance(item.name, str) or item.name == '':
            raise ValueError('invalid string value for name')
        if item.id in item_ids:
            dup_ids.add(item.id)
        else:
            item_ids.add(item.id)
    if dup_ids:
        logging.warning('sync_preseem_model(%s): duplicate IDs found: %s', source,
                        dup_ids)
        raise ValueError('duplicate IDs found')
    if not validate_hosts:
        return {}
    # ip_address() raises ValueError for an invalid address.
    return {item.id: str(ip_address(item.host)) for item in items}


def _validate_packages(packages):
    """Validate a {source_id -> package} dict; raise ValueError on bad data."""
    assert isinstance(packages, dict)
    for pkg in packages.values():
        if not isinstance(pkg.id, str) or pkg.id == '':
            raise ValueError('invalid string value for package source_id')
        if not isinstance(pkg.name, str) or pkg.name == '':
            raise ValueError('invalid string value for package name')
        if pkg.download_rate and (not isinstance(pkg.download_rate, int)
                                  or pkg.download_rate < 0):
            raise ValueError('invalid value for package down rate')
        if pkg.upload_rate and (not isinstance(pkg.upload_rate, int)
                                or pkg.upload_rate < 0):
            raise ValueError('invalid value for package up rate')


def _validate_accounts(accounts):
    """Validate a {source_id -> account} dict; raise ValueError on bad data."""
    assert isinstance(accounts, dict)
    for account in accounts.values():
        if not isinstance(account.id, str) or account.id == '':
            raise ValueError('invalid string value for account source_id')
        if account.type is not None and (not isinstance(account.type, int)
                                         or account.type <= 0):
            raise ValueError('invalid value for account type')
        if (account.status is None or not isinstance(account.status, int)
                or account.status <= 0):
            raise ValueError('invalid value for account status')
        # Optional string fields must be non-empty strings when present.
        for attr in ('name', 'first_name', 'last_name'):
            value = getattr(account, attr)
            if value is not None and (not isinstance(value, str) or value == ''):
                raise ValueError(f'invalid string value for account {attr}')


def _validate_elements(bss_objs, source_type):
    """Validate a {source_id -> obj} dict of *source_type* (AP/Router/OLT) objects.

    Raises ValueError if an object lacks the name or host that
    ElementDiscovery requires.
    """
    assert isinstance(bss_objs, dict)
    for sobj in bss_objs.values():
        assert isinstance(sobj, source_type)
        if not sobj.name:  # name is required for ElementDiscovery
            raise ValueError(f'invalid value for {source_type.typename} name')
        if not sobj.host:  # host is required for ElementDiscovery
            raise ValueError(f'invalid value for {source_type.typename} host')


def _group_by_source_id(objs):
    """Group model objects into {source_id -> [obj, ...]}."""
    groups = {}
    for obj in objs:
        groups.setdefault(obj.source_id, []).append(obj)
    return groups


def _resolve_ref(objs_by_id, ref_id, source):
    """Find the model object referenced by source id *ref_id*.

    Returns ``(obj, ambiguous)``.  A single match is used as-is.  With several
    matches, the one owned by *source* is used.  If none is owned by *source*,
    ``(None, True)`` is returned.  No *ref_id* or no match gives ``(None, False)``.
    """
    matches = objs_by_id.get(ref_id) if ref_id else None
    if not matches:
        return None, False
    if len(matches) == 1:
        return matches[0], False
    owned = next((x for x in matches if x.source == source), None)
    return owned, owned is None


def _resolve_service_refs(services, model_accounts, model_packages, source):
    """Set ``_account_uuid``/``_package_uuid`` on each bss service.

    A service's account_id/package_id are looked up by source_id among the
    given model accounts/packages.  Ambiguous references are reported through
    ``service.err``.
    """
    accounts_by_id = _group_by_source_id(model_accounts)
    packages_by_id = _group_by_source_id(model_packages)
    for service in services.values():
        account, ambiguous = _resolve_ref(accounts_by_id, service.account_id, source)
        if ambiguous:
            service.err('account_id', service.account_id,
                        ValueError("Multiple accounts match the account ID"))
        elif account:
            service._account_uuid = account.uuid
        package, ambiguous = _resolve_ref(packages_by_id, service.package_id, source)
        if ambiguous:
            service.err('package_id', service.package_id,
                        ValueError("Multiple packages match the package ID"))
        elif package:
            service._package_uuid = package.uuid


async def _sync_role_discs(preseem_model, company_uuid, source, role, items, hosts,
                           role_discs, nm_model, noact):
    """Old per-role ElementDiscovery sync (disabled by passing role=None, PPA-1512).

    Creates, modifies, reactivates and deactivates this source's discoveries
    for *role* so they match *items*.  *hosts* maps item.id to its normalized
    IP address.  Returns {item.id -> ElementDiscovery UUID}.
    """
    model_discs = {x.uuid: x for x in role_discs.get(role.bytes) or []}
    disc_ids = {(x.source, x.source_id): x for x in model_discs.values()}
    unref_discs = {
        disc_uuid for disc_uuid, disc in model_discs.items() if disc.source == source
    }
    uuid_map = {}
    for item in items:
        host = hosts[item.id]
        disc = disc_ids.get((source, item.id))
        if not disc:
            logging.info('sync_preseem_model(%s): create %s', source, item)
            if not noact:
                disc = await call_and_log_error(
                    preseem_model.Element.CreateElementDiscovery(
                        company_uuid=company_uuid,
                        name=item.name,
                        # Use the normalized address: modifications below
                        # compare against it, so a raw value would churn.
                        management_ip=host,
                        site_uuid=item._site_uuid,
                        source=source,
                        source_id=item.id,
                        intent_role_uuid=role.bytes))
                if disc:
                    disc_uuid = UUID(bytes=disc.uuid)
                    logging.info(
                        'sync_preseem_model(%s): created ElementDiscovery "%s"',
                        source, disc_uuid)
                    uuid_map[item.id] = disc_uuid
            continue

        uuid_map[item.id] = UUID(bytes=disc.uuid)
        unref_discs.discard(disc.uuid)
        changes = []
        mods = {}
        if disc.inactive:
            changes.append('reactivated')
            mods['inactive'] = False
            mods['element_uuid'] = None  # STM-9823
        if disc.name != item.name:
            changes.append(f'name: {disc.name} -> {item.name}')
            mods['name'] = item.name
        if disc.management_ip != host:
            changes.append(f'management_ip: {disc.management_ip} -> {host}')
            mods['management_ip'] = host
        if disc.site_uuid != item._site_uuid:
            changes.append(f'site: {disc.site_uuid} -> {item._site_uuid}')
            mods['site_uuid'] = item._site_uuid
        if not mods:
            continue
        logging.info('sync_preseem_model(%s): modify %s (%s)', source, item,
                     ','.join(changes))
        if not noact:
            if nm_model and 'reactivated' in changes and disc.element_uuid:
                # STM-9823 as part of clearing the element_uuid,
                # we also need to delete the elem_discovery ref.
                ref = NetworkMetadataReference('elem_discovery',
                                               str(UUID(bytes=disc.uuid)), {})
                logging.info(' delete reference: %s', ref)
                await nm_model.del_ref(ref)
            await call_and_log_error(
                preseem_model.Element.UpdateElementDiscovery(
                    company_uuid=company_uuid, uuid=disc.uuid, **mods))

    # NOTE(review): unlike _sync_objects, this deactivates unreferenced
    # discoveries regardless of the caller's `deactivate` flag; confirm intent.
    for disc_uuid in unref_discs:
        disc = model_discs.get(disc_uuid)
        if not disc.inactive:
            logging.info("sync_preseem_model(%s): deactivate %s", source,
                         UUID(bytes=disc_uuid))
            if not noact:
                await call_and_log_error(
                    preseem_model.Element.UpdateElementDiscovery(
                        company_uuid=company_uuid, uuid=disc_uuid, inactive=True))
    return uuid_map


async def sync_preseem_model(preseem_model,
                             company_uuid,
                             source,
                             role,
                             items,
                             packages=None,
                             accounts=None,
                             services=None,
                             routers=None,
                             olts=None,
                             aps=None,
                             nm_model=None,
                             deactivate=None,
                             noact=False,
                             cfg=None,
                             verbose=0):
    """
    Update the Preseem model to keep its objects in sync with another system.

    This method takes the full set of objects for a specific company and
    source.  It diffs them against the current model and applies any create,
    modify or delete (set inactive) operations.
    - role / items: with a role UUID, the old algorithm syncs ElementDiscovery
      objects of that intent role from *items* (ElementSyncItem-like objects).
      Pass role=None to disable it (PPA-1512).  items may be None.
    - packages, accounts, services: dicts {source_id -> bss object}.
    - aps, routers, olts: dicts {source_id -> AccessPoint/Router/Olt}; each
      object's ``_uuid`` is set to its ElementDiscovery uuid.
    - deactivate: unless explicitly False, model packages, accounts, services
      and discoveries of this source that are missing from the input are
      set inactive.
    - noact: only log what would be done.
    - cfg: optional config dict; ``cfg['service']`` is passed to the service
      sync.
    All input is validated before any change is made; invalid input raises
    ValueError (or AssertionError for wrong container/object types).

    Returns {item.id -> ElementDiscovery UUID} for *items*.  The map is empty
    when role is None.
    """
    global sync_lock
    if sync_lock is None:
        sync_lock = asyncio.Lock()

    assert isinstance(company_uuid, bytes)
    assert isinstance(source, str)
    if role:
        # PPA-1512 pass role=None to disable the old AP sync method.
        assert isinstance(role, UUID)

    if source == '':
        raise ValueError('source cannot be the empty string')

    # Validate everything up front so bad input is rejected before sites or
    # any other objects are modified.
    items = list(items) if items else []
    hosts = _validate_items(items, source, validate_hosts=bool(role))
    if packages is not None:
        _validate_packages(packages)
    if accounts is not None:
        _validate_accounts(accounts)
    if services is not None:
        assert isinstance(services, dict)
        for service in services.values():
            assert isinstance(service, Service)
    elem_sources = ((aps, AccessPoint), (routers, Router), (olts, Olt))
    for bss_objs, source_type in elem_sources:
        if bss_objs is not None:
            _validate_elements(bss_objs, source_type)

    sync_deletes = deactivate is not False  # sync deletes unless told not to

    async with sync_lock:
        # Collect all discovery objects so we can sync the sites.
        ne_objs = [
            obj for d, _type in elem_sources if d is not None for obj in d.values()
        ]
        ne_objs.extend(items)

        model_accounts, model_discs, model_services, model_packages, _ = await asyncio.gather(
            preseem_model.Account.List(company_uuid=company_uuid),
            preseem_model.Element.ListElementDiscovery(company_uuid=company_uuid),
            preseem_model.Service.List(company_uuid=company_uuid,
                                       include_attachment=True),
            preseem_model.Package.List(company_uuid=company_uuid),
            _sync_sites(preseem_model, company_uuid, source, ne_objs, noact))

        # Group the model's discoveries by intent role.
        role_discs = {}
        for disc in model_discs:
            role_discs.setdefault(disc.intent_role_uuid, []).append(disc)

        uuid_map = {}
        if role:  # old algorithm
            uuid_map = await _sync_role_discs(preseem_model, company_uuid, source,
                                              role, items, hosts, role_discs,
                                              nm_model, noact)

        new_packages = []
        if packages is not None:
            new_packages = await _sync_objects(preseem_model,
                                               company_uuid,
                                               source,
                                               packages,
                                               model_packages,
                                               preseem_model.Package.Create,
                                               preseem_model.Package.Update,
                                               _package_field_map,
                                               deactivate=sync_deletes,
                                               noact=noact)

        new_accounts = []
        if accounts is not None:
            new_accounts = await _sync_objects(preseem_model,
                                               company_uuid,
                                               source,
                                               accounts,
                                               model_accounts,
                                               preseem_model.Account.Create,
                                               preseem_model.Account.Update,
                                               _account_field_map,
                                               deactivate=sync_deletes,
                                               noact=noact)

        if services is not None:
            # Resolve the account and package UUIDs from the source IDs
            # (type/value checks are done in the source_model class).
            _resolve_service_refs(services, [*model_accounts, *new_accounts],
                                  [*model_packages, *new_packages], source)
            await _sync_objects(preseem_model,
                                company_uuid,
                                source,
                                services,
                                model_services,
                                preseem_model.Service.Create,
                                preseem_model.Service.Update,
                                _service_field_map,
                                deactivate=sync_deletes,
                                cfg=cfg.get('service') if cfg else None,
                                noact=noact)

        for bss_objs, source_type in elem_sources:
            if bss_objs is None:
                continue
            discs = [
                x for x in role_discs.get(source_type.role_uuid.bytes) or []
                if x.source == source
            ]
            new_objs = await _sync_objects(
                preseem_model,
                company_uuid,
                source,
                bss_objs,
                discs,
                preseem_model.Element.CreateElementDiscovery,
                preseem_model.Element.UpdateElementDiscovery,
                _elem_disc_field_map,
                deactivate=sync_deletes,
                noact=noact)
            for mobj in discs + new_objs:
                sobj = bss_objs.get(mobj.source_id)
                if sobj:
                    sobj._uuid = mobj.uuid

        return uuid_map


async def sync_preseem_elements(model, company_uuid, source, objs, noact=False):
    """
    Update Elements directly in the Preseem model from source objects such as
    APs, OLTs and Routers.

    This supports setups where another system, such as an NMS, is the source of
    truth for an element's data, and we will not be polling the element
    directly.  In that case the element may not even have (or we may not know)
    a management IP address, and there is no point in having the poller
    discover it.
    - objs is a dict {source_id -> source object}.  On return, each object's
      ``_uuid`` is set to the uuid of its matching (or newly created) Element.
    - noact: only log what would be done.
    Elements are never deactivated by this sync.
    """
    assert isinstance(company_uuid, bytes)
    assert isinstance(source, str)
    if source == '':
        raise ValueError('source cannot be the empty string')

    global sync_lock
    if sync_lock is None:
        sync_lock = asyncio.Lock()

    async with sync_lock:
        # There is no RPC to list elements by source, so fetch all of them and
        # filter locally while the sites are synced.
        all_elems, _ = await asyncio.gather(
            model.Element.List(company_uuid=company_uuid),
            _sync_sites(model, company_uuid, source, objs.values(), noact=noact))

        # Keep only this source's elements whose poller hash is in objs.  This
        # works because we don't currently disable elements anyway, and it
        # makes sure the right element is synced after a hardware swap.
        by_poller_hash = {sobj.poller_hash: sobj for sobj in objs.values()}

        elems = []
        for elem in all_elems:
            match = by_poller_hash.get(elem.poller_hash)
            if elem.source == source and match:
                elems.append(elem)
                if match.id != elem.source_id:
                    # Same poller hash but a new ID: keep the same element
                    # object and move it to the new source_id.
                    logging.info(
                        "sync_preseem_element: change element source_id %s: %s -> %s",
                        UUID(bytes=elem.uuid), elem.source_id, match.id)
                    if not noact:
                        await call_and_log_error(
                            model.Element.Update(company_uuid=company_uuid,
                                                 uuid=elem.uuid,
                                                 source_id=match.id))
                    elem.source_id = match.id

        # Newly created elements start with an unknown status.
        new_objs = await _sync_objects(
            model,
            company_uuid,
            source,
            objs,
            elems,
            model.Element.Create,
            model.Element.Update,
            _elem_field_map,
            deactivate=False,
            init_fields={'status': ElementStatus.ELEMENT_STATUS_UNKNOWN},
            noact=noact)
        for mobj in elems + new_objs:
            sobj = objs.get(mobj.source_id)
            if sobj:
                sobj._uuid = mobj.uuid
