"""
This module provides a client class to connect to a Preseem GRPC service and
execute operations on it.  GRPC specific aspects are abstracted here, and
Python methods are provided.  By design, this module does not look at any
of the model objects, just the request/response objects and the rpc calls.

A large part of the responsibility of this module is to adapt the blocking
third-party GRPC library to the asyncio concurrency model: blocking calls are
run on thread pools and their results awaited from the event loop.  This
includes support for streaming subscriptions, where the caller awaits a
Subscribe call while streaming messages are delivered into the event loop via
a callback.

This is designed to be easily replaceable by a stub client class for unit
testing.

All Preseem services are modeled here; I'm not super happy with that aspect
of the design, but for now it's left that way.
"""
import asyncio
import concurrent
import functools
from importlib import import_module
import logging
import time
import traceback

import grpc

from preseem_grpc_model import metrics_pb2, metrics_pb2_grpc
from preseem_grpc_model import network_metadata_pb2, network_metadata_pb2_grpc
from preseem_protobuf.config import company_configs_pb2, company_configs_pb2_grpc


class ConfigType:
    """Bundle describing one company-config type and its generated modules.

    The constructor records the identity of the type and the imported
    modules; the service stub, message type and request classes start out
    as None and are filled in later by the owner of the instance.
    """

    def __init__(self, name, prefix, value, pb, grpc):
        # identity of the config type
        self.name = name      # lowercase enum name
        self.value = value    # integer enum value
        self.prefix = prefix  # prefix used to build field/op names
        # generated code
        self.pb = pb          # the *_pb2 module
        self.grpc = grpc      # the *_pb2_grpc module
        # populated after construction: service stub, message type, and the
        # delete/get/set/subscribe request classes (in that order)
        self.svc = self.pb_type = self.del_req = self.get_req = self.set_req = \
            self.sub_req = None


class Subscription(object):
    """Helper class to manage a GRPC subscription from asyncio.

    The blocking GRPC stream is read on a dedicated single-worker thread
    pool.  Each received message is passed to an async callback, which is
    scheduled on the event loop captured at construction time; the reader
    thread waits for the callback to finish before reading the next message.
    """

    # Seconds to wait before retrying after an error.  A class attribute so
    # it can be tuned (or overridden per instance).
    _RETRY_DELAY = 10

    def __init__(self, client, op, req, cbk):
        """
        Manage a GRPC subscription in asyncio.  Pass in the client object
        and the operation to call that will return the subscription iterator.
        New results are handled by running a callback in the eventloop.

        :param client: object providing ``host`` and
            ``_grpc_op_thread(op, req)`` (normally a PreseemGrpcClient).
        :param op: GRPC operation returning the subscription iterator.
        :param req: request message passed to ``op``.
        :param cbk: coroutine function called as ``cbk(msg, syn)``.  The
            "syn" parameter is True for the first message of a newly
            (re-)established stream, i.e. the start of a new synchronize
            operation after a GRPC reconnection.
        """
        self._client = client
        self._op = op
        self._req = req
        self._cbk = cbk
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self._loop = asyncio.get_event_loop()
        self._iter = None
        self._closed = False

    def close(self):
        """Close the subscription.  This wakes up the blocked thread.

        The closed flag is set before cancelling so that a close arriving
        while no stream is open (e.g. during a reconnect delay) still stops
        the reader thread instead of being lost.
        """
        self._closed = True
        it = self._iter
        if it:
            it.cancel()

    def _run(self):
        """Read from the subscription, in a separate thread.

        Runs until the subscription is closed, the event loop shuts down, or
        the server reports NOT_FOUND.  Other errors, including the server
        ending the stream, drop the current stream and re-subscribe after
        ``_RETRY_DELAY`` seconds.
        """
        retry = False
        while not self._closed:
            # Setup the subscription
            syn = False
            if self._iter is None:
                try:
                    if retry:
                        logging.info('retry subscription to: %s', self._client.host)
                    else:
                        retry = True
                    it = self._client._grpc_op_thread(self._op, self._req)
                    if it is None:
                        # the client may return None (e.g. after exhausting
                        # its own retries); treat it like a creation error
                        raise RuntimeError('no subscription iterator returned')
                except Exception as e:
                    logging.warning('Error creating subscription: %s', e)
                    time.sleep(self._RETRY_DELAY)
                    continue
                self._iter = it
                if self._closed:
                    # close() ran while the stream was being created
                    it.cancel()
                    return
                syn = True

            try:
                msg = self._iter.next()
            except StopIteration:
                # The server ended the stream; start a new one.
                if self._closed:
                    return
                logging.info('Subscription stream to %s ended, reconnect after %ss',
                             self._client.host, self._RETRY_DELAY)
                self._iter = None
                time.sleep(self._RETRY_DELAY)
                continue
            except grpc.RpcError as e:
                # https://grpc.io/grpc/python/grpc.html?highlight=method#grpc.StatusCode
                if self._closed:
                    return  # cancelled via close()
                if e.code() == grpc.StatusCode.CANCELLED and e.details(
                ) == 'Locally cancelled by application!':
                    return  # cancelled by application
                if e.code() == grpc.StatusCode.NOT_FOUND:
                    # STM-4023 the thing we are subscribed to has gone away.
                    logging.warning('GRPC NotFound error: %s, exit subscription',
                                    e.details())
                    return
                # Disconnect, need to restart subscription.  Restart
                # the initializing state after this.
                logging.warning('GRPC ERROR: %s: %s for host: %s', e.code(),
                                e.details(), self._client.host)
                logging.info('Reconnect to %s after %ss', self._client.host,
                             self._RETRY_DELAY)
                self._iter = None
                time.sleep(self._RETRY_DELAY)  # TODO Backoff
                continue
            except Exception as e:
                if self._closed:
                    return
                logging.warning('Error reading from subscription iterator: %s (%s)', e,
                                type(e).__name__)
                # the stream is in an unknown state; start a fresh one
                self._iter = None
                time.sleep(self._RETRY_DELAY)  # TODO Backoff
                continue

            try:
                fut = asyncio.run_coroutine_threadsafe(self._cbk(msg, syn), self._loop)
                fut.result()  # wait so messages are delivered in order
            except concurrent.futures.CancelledError:  # eventloop shutdown
                return
            except Exception as e:
                # the message is dropped; keep reading after a pause
                logging.exception('Error calling callback: %s', e)
                time.sleep(self._RETRY_DELAY)

    async def run(self):
        """Read from the subscription and block the caller on the eventloop."""
        await self._loop.run_in_executor(self._executor, self._run)


class PreseemGrpcClient(object):
    """Implement the GRPC client to the Preseem API, adapted to asyncio.

    Blocking GRPC calls are run on a 10-worker thread pool so the event loop
    is not blocked; streaming subscriptions are driven by ``Subscription``.
    The connection is made lazily by the first operation that needs it.
    """

    def __init__(self, api_key, host=None, port=None):
        """
        :param api_key: API key sent as ``api_key`` call metadata over a TLS
            channel; if falsy, an insecure (plaintext) channel is used.
        :param host: server host name, default ``nmapi.preseem.com``.
        :param port: server port, default 443.
        """
        self.api_key = api_key
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)
        self.host = host or 'nmapi.preseem.com'
        self.port = port or 443
        self._channel = None
        self._metrics_svc = None
        self._nm_svc = None
        self._nm_ref_svc = None
        self._config2_svc = None

        self._nm_ent_subscription = None
        self._nm_ref_subscription = None
        self._cfg_subscriptions = {}  # config_id -> Subscription
        self._cfg_lock = None  # asyncio.Lock, created lazily by connect()

        # config type enum value -> ConfigType, for every type whose generated
        # *_pb2 / *_pb2_grpc modules can be imported.
        self.config_types = {}
        for ctname, ctid in company_configs_pb2.Type.items():
            if ctid < 2:
                continue  # ignore NODE configs for now
            try:
                cn = ctname.lower()
                pmod = import_module(f'preseem_protobuf.config.{cn}_pb2')
                gmod = import_module(f'preseem_protobuf.config.{cn}_pb2_grpc')
                ct = self.config_types[ctid] = ConfigType(cn, cn.replace('_', ''), ctid,
                                                          pmod, gmod)
                # assume naming conventions for the operations,
                # e.g. 'foo_bar' -> FooBar, FooBarGetRequest, ...
                capname = ''.join([x[:1].upper() + x[1:] for x in cn.split('_')])
                ct.pb_type = getattr(pmod, f'{capname}')
                ct.get_req = getattr(pmod, f'{capname}GetRequest')
                ct.set_req = getattr(pmod, f'{capname}SetRequest')
                ct.del_req = getattr(pmod, f'{capname}DeleteRequest')
                ct.sub_req = getattr(pmod, f'{capname}SubscribeRequest')
            except ModuleNotFoundError:
                # we don't support this, that's ok
                pass

    def close_subscriptions(self):
        """Close all subscriptions started through this client."""
        if self._nm_ent_subscription:
            self._nm_ent_subscription.close()
        if self._nm_ref_subscription:
            self._nm_ref_subscription.close()
        for sub in self._cfg_subscriptions.values():
            sub.close()
        self._cfg_subscriptions.clear()

    def _grpc_op_thread(self, op, *args, num_retries=5):
        """Call ``op(*args)`` synchronously, retrying on any exception.

        Meant to run on a worker thread: it sleeps 2, 4, 8, ... seconds between
        attempts.  No timeout is passed to ``op`` because this is used to open
        long-lived subscription streams.

        :returns: the result of ``op``, or None if every attempt failed.
        """
        for retry_num in range(num_retries + 1):
            if retry_num > 0:
                time.sleep(2**retry_num)
                logging.debug('Retrying after error result')
            try:
                return op(*args)
            except Exception:
                logging.exception('Error calling grpc service')
        return None

    def _grpc_op_collect(self, op, req, metadata):
        """Run a GRPC operation and collect the results in this thread.

        An iterable (streaming) result is drained into a list here, since
        that iteration does network i/o; any other result is returned as is.
        """
        r = op(req, metadata=metadata)
        try:
            iterator = iter(r)  # detect whether it's a sequence
        except TypeError:  # not an iterable result, ok
            return r
        return list(iterator)  # this does network i/o

    async def _grpc_op(self,
                       op,
                       req,
                       num_retries=5,
                       timeout=None,
                       collect_results=None,
                       metadata=None):
        """Run a blocking GRPC call on the thread pool and await its result.

        :param op: GRPC stub method, called as ``op(req, metadata=metadata)``.
        :param req: request message.
        :param num_retries: extra attempts after the first; waits 2, 4, 8, ...
            seconds between attempts.
        :param timeout: if set, seconds to wait for each attempt.  A timed-out
            attempt is logged and retried like any other failure (the worker
            thread itself is not interrupted).
        :param collect_results: if true, an iterable result is drained into a
            list on the worker thread.
        :param metadata: optional call metadata passed to ``op``.
        :raises grpc.RpcError: immediately for non-retryable status codes, or
            after the last attempt.
        """
        loop = asyncio.get_event_loop()
        if collect_results:
            call = functools.partial(self._grpc_op_collect, op, req, metadata)
        else:
            call = functools.partial(op, req, metadata=metadata)
        for retry_num in range(num_retries + 1):
            if retry_num > 0:
                await asyncio.sleep(2**retry_num)
                logging.debug('Retrying after error result')
            try:
                pending = loop.run_in_executor(self._executor, call)
                if not timeout:
                    return await pending
                try:
                    return await asyncio.wait_for(pending, timeout=timeout)
                except asyncio.TimeoutError:
                    logging.warning('GRPC operation timed out: %s', op)
                    raise
            except grpc.RpcError as err:
                logging.debug('Error calling service: %s (%s)', err.code(),
                              err.details())
                # errors that retrying will not fix
                if err.code() in (grpc.StatusCode.INTERNAL, grpc.StatusCode.NOT_FOUND,
                                  grpc.StatusCode.UNKNOWN,
                                  grpc.StatusCode.UNIMPLEMENTED,
                                  grpc.StatusCode.INVALID_ARGUMENT,
                                  grpc.StatusCode.PERMISSION_DENIED,
                                  grpc.StatusCode.RESOURCE_EXHAUSTED):
                    raise
                if retry_num == num_retries:
                    logging.warning('Error calling service: %s (%s)', err.code(),
                                    err.details())
                    raise
            except Exception as e:
                logging.warning('%s error calling service: %s', type(e).__name__, e)
                if retry_num == num_retries:
                    logging.exception('Error calling grpc service')
                    raise

    def _connect(self):
        """Connect to the server and authenticate. (blocking)

        Creates the channel and all service stubs, retrying every 30 seconds
        until that succeeds.  Does nothing if already connected.
        """

        class ApiKeyCallCredentials(grpc.AuthMetadataPlugin):
            """Attach the API key to each call as ``api_key`` metadata."""

            def __init__(self, api_key):
                self.api_key = api_key

            def __call__(self, context, callback):
                callback((('api_key', self.api_key), ), None)

        dst = '{}:{}'.format(self.host, self.port)
        # STM-3013, STM-3115 make grpc connection robust
        grpc_opts = [
            ('grpc.keepalive_time_ms', 30000),  # send keepalive every 30s
            ('grpc.keepalive_timeout_ms', 15000),  # keepalive ping timeout after 15s
            ('grpc.keepalive_permit_without_calls',
             True),  # send keepalives when no grpc calls
            ('grpc.http2.max_pings_without_data', 0),  # allow unlimited pings without data
        ]
        while self._nm_svc is None:
            try:
                if self.api_key:
                    call_creds = grpc.metadata_call_credentials(
                        ApiKeyCallCredentials(self.api_key))
                    try:
                        with open('/etc/pki/tls/certs/ca-bundle.crt', 'rb') as f:
                            ca_certs = f.read()
                    except OSError:  # not fedora?  use default built-in CA bundle
                        ca_certs = None
                    ssl_creds = grpc.ssl_channel_credentials(root_certificates=ca_certs)
                    channel_creds = grpc.composite_channel_credentials(
                        ssl_creds, call_creds)
                    channel = grpc.secure_channel(dst, channel_creds, options=grpc_opts)
                else:
                    channel = grpc.insecure_channel(dst, options=grpc_opts)
                metrics_svc = metrics_pb2_grpc.NetworkMetricsStub(channel)
                nm_svc = network_metadata_pb2_grpc.NetworkMetadataServiceStub(channel)
                nm_ref_svc = network_metadata_pb2_grpc.NetworkMetadataReferenceServiceStub(
                    channel)
                config2_svc = company_configs_pb2_grpc.CompanyConfigsServiceStub(channel)
                for ct in self.config_types.values():
                    stubs = [x for x in dir(ct.grpc) if x.endswith('Stub')]
                    if stubs:
                        ct.svc = getattr(ct.grpc, stubs[0])(channel)
                self._channel = channel
                self._metrics_svc = metrics_svc
                self._nm_ref_svc = nm_ref_svc
                self._config2_svc = config2_svc
                # assigned last: it is this loop's "connected" condition, so a
                # failure anywhere above must leave it unset and retry
                self._nm_svc = nm_svc
            except Exception:
                logging.exception('Error connecting to grpc service')
                time.sleep(30)

    async def connect(self):
        """Connect to the server and authenticate."""
        if self._cfg_lock is None:
            self._cfg_lock = asyncio.Lock()
        async with self._cfg_lock:
            await asyncio.get_event_loop().run_in_executor(self._executor,
                                                           self._connect)

    async def nm_ent_subscribe(self, cbk, company_id=None, resolve=False):
        """Block in this coroutine and call cbk as notifications arrive.

        :param cbk: coroutine function called as ``cbk(msg, syn)``.
        :param company_id: restrict to this company, if given.
        :param resolve: if false, ask the server not to resolve references.
        """
        if not self._nm_svc:
            await self.connect()
        req = network_metadata_pb2.NetworkMetadataSubscribeRequest(
            allow_ipv6_and_shorter_ipv4_prefixes=True)
        req.do_not_resolve_references = not resolve
        if company_id:
            req.company_id = company_id
        sub = self._nm_ent_subscription = Subscription(self, self._nm_svc.Subscribe, req,
                                                       cbk)
        try:
            await sub.run()
        except asyncio.CancelledError:
            # cancelling the await does not stop the reader thread; close it
            sub.close()
            return

    async def nm_ref_subscribe(self, cbk, company_id=None):
        """Subscribe to Network Metadata Reference subscriptions."""
        if not self._nm_ref_svc:
            await self.connect()
        req = network_metadata_pb2.NetworkMetadataReferenceSubscribeRequest()
        if company_id:
            req.company_id = company_id
        sub = self._nm_ref_subscription = Subscription(self, self._nm_ref_svc.Subscribe,
                                                       req, cbk)
        try:
            await sub.run()
        except asyncio.CancelledError:
            # cancelling the await does not stop the reader thread; close it
            sub.close()
            return

    async def nm_list(self, company_id=None):
        """List network metadata entities."""
        if not self._nm_svc:
            await self.connect()
        req = network_metadata_pb2.NetworkMetadataListRequest()
        if company_id:
            req.company_id = company_id
        return await self._grpc_op(self._nm_svc.List, req)

    async def nm_set(self, entity, company_id=None):
        """Set (create or update) a network metadata entity."""
        if not self._nm_svc:
            await self.connect()
        req = network_metadata_pb2.NetworkMetadataSetRequest(entity=entity)
        if company_id:
            req.company_id = company_id
        return await self._grpc_op(self._nm_svc.Set, req)

    async def nm_del(self, entity, company_id=None):
        """Delete a network metadata entity."""
        if not self._nm_svc:
            await self.connect()
        req = network_metadata_pb2.NetworkMetadataDeleteRequest(entity=entity)
        if company_id:
            req.company_id = company_id
        return await self._grpc_op(self._nm_svc.Delete, req)

    async def nm_ref_list(self, company_id=None):
        """List network metadata references."""
        if not self._nm_ref_svc:
            await self.connect()
        req = network_metadata_pb2.NetworkMetadataReferenceListRequest()
        if company_id:
            req.company_id = company_id
        return await self._grpc_op(self._nm_ref_svc.List, req)

    async def nm_ref_set(self, ref, company_id=None):
        """Set (create or update) a network metadata reference."""
        if not self._nm_ref_svc:
            await self.connect()
        req = network_metadata_pb2.NetworkMetadataReferenceSetRequest(entity=ref)
        if company_id:
            req.company_id = company_id
        return await self._grpc_op(self._nm_ref_svc.Set, req)

    async def nm_ref_del(self, ref, company_id=None):
        """Delete a network metadata reference."""
        if not self._nm_ref_svc:
            await self.connect()
        req = network_metadata_pb2.NetworkMetadataReferenceDeleteRequest(entity=ref)
        if company_id:
            req.company_id = company_id
        return await self._grpc_op(self._nm_ref_svc.Delete, req)

    async def push_flexes(self, req):
        """Push a metrics PushFlexes request (300s timeout per attempt)."""
        if not self._metrics_svc:
            await self.connect()
        return await self._grpc_op(self._metrics_svc.PushFlexes, req, timeout=300)

    async def get_config(self, config_type, config_id, company_id=None):
        """Get a config instance by type and id.

        :returns: the config message, or None if the server reported failure.
        :raises ValueError: if ``config_type`` is not supported.
        """
        ct = self.config_types.get(config_type)
        if ct is None:
            raise ValueError('Unsupported config type')
        if not ct.svc:
            await self.connect()
        req = ct.get_req()
        if company_id is not None:  # don't assign None to a protobuf field
            req.company_id = company_id
        setattr(req, f'{ct.prefix}_id', config_id)
        r = await self._grpc_op(ct.svc.Get, req)
        if bool(r.success):
            return getattr(r, ct.prefix)
        else:  # probably a race where it was deleted before we could get it
            logging.warning('Failed to get %s config "%s"', ct.name, config_id)
            return None

    async def set_config(self, config_type, obj, company_id=None):
        """Store a config instance; returns the server's success flag.

        :raises ValueError: if ``config_type`` is not supported.
        """
        ct = self.config_types.get(config_type)
        if ct is None:
            raise ValueError('Unsupported config type')
        if not ct.svc:
            await self.connect()
        req = ct.set_req()
        if company_id is not None:  # don't assign None to a protobuf field
            req.company_id = company_id
        getattr(req, ct.prefix).CopyFrom(obj)
        r = await self._grpc_op(ct.svc.Set, req)
        return bool(r.success)

    async def del_config(self, config_type, config_id, company_id=None):
        """Delete a config instance; returns the server's success flag.

        :raises ValueError: if ``config_type`` is not supported.
        """
        ct = self.config_types.get(config_type)
        if ct is None:
            raise ValueError('Unsupported config type')
        if not ct.svc:
            await self.connect()
        req = ct.del_req()
        if company_id is not None:  # don't assign None to a protobuf field
            req.company_id = company_id
        setattr(req, f'{ct.prefix}_id', config_id)
        r = await self._grpc_op(ct.svc.Delete, req)
        return bool(r.success)

    async def sub_config(self, config_type, config_id, cbk, company_id=None):
        """
        Subscribe to notifications for updates to a config instance.
        The callback takes args (company_id, config_type, message).

        :raises ValueError: if ``config_type`` is not supported.
        """

        async def _cbk(msg, _):
            await cbk(company_id, config_type, msg)

        ct = self.config_types.get(config_type)
        if ct is None:
            raise ValueError('Unsupported config type')
        if not ct.svc:
            await self.connect()
        req = ct.sub_req()
        if company_id is not None:  # don't assign None to a protobuf field
            req.company_id = company_id
        setattr(req, f'{ct.prefix}_id', config_id)
        sub = self._cfg_subscriptions[config_id] = Subscription(
            self, ct.svc.Subscribe, req, _cbk)
        try:
            await sub.run()
        except asyncio.CancelledError:
            sub.close()
        finally:
            # Remove only our own entry: a newer subscription for the same
            # config_id may have replaced it, or close_subscriptions() may
            # already have cleared it.
            if self._cfg_subscriptions.get(config_id) is sub:
                del self._cfg_subscriptions[config_id]

    async def sub_company_configs(self, cbk, company_id=None, subscribe_all=None):
        """Subscribe to company config updates.

        :param cbk: coroutine function called as ``cbk(msg, syn)``.
        :param company_id: restrict to this company, if given.
        :param subscribe_all: if no company_id is given, request all companies.
        """
        if not self._config2_svc:
            await self.connect()
        req = company_configs_pb2.CompanyConfigSubscribeRequest()
        if company_id:
            req.company_id = company_id
        elif subscribe_all:
            req.subscribe_all = True
        sub = Subscription(self, self._config2_svc.Subscribe, req, cbk)
        try:
            await sub.run()
        except asyncio.CancelledError:
            sub.close()
