"""
Network Poller GRPC integration for sending NetworkElementUpdate and
NetworkElementDiscovery messages
"""
import asyncio
import logging

import grpc

from preseem import PreseemGrpcClient
from preseem_protobuf.network_poller.network_poller_pb2 import (
    NetworkElementUpdate, SendUpdateBatchRequest, GetUpdatesRequest,
    NetworkElementDiscovery, SendDiscoveryBatchRequest, GetDiscoveryRequest)
from preseem_protobuf.network_poller.network_poller_pb2_grpc import NetworkPollerServiceStub


class NetworkPollerGrpcClient:
    """Client for posting NetworkElementUpdate / NetworkElementDiscovery messages.

    Messages passed to post() are copied onto a bounded queue (one queue per
    message type).  A background task per queue groups the messages into
    batches, limited by MAX_MESSAGE_SIZE and BATCH_HOLDOFF_TIME, and posts
    each batch, retrying after RETRY_HOLDOFF_TIME when a post fails.

    When no host is given, no GRPC client is created and the request methods
    return without sending anything (used for testing).
    """
    BATCH_HOLDOFF_TIME = 10  # max seconds to wait to send a batch
    RETRY_HOLDOFF_TIME = 10  # time to wait to retry a failed post
    MAX_MESSAGE_SIZE = 1024 * 1024 * 3  # GRPC max message size is 4M
    MAX_MESSAGE_QLEN = 4000  # max number of messages we will queue
    POST_TIMEOUT = 30  # timeout post request after waiting 30s

    def __init__(self, api_key=None, host=None, port=None):
        if host:
            self._client = PreseemGrpcClient(api_key=api_key, host=host, port=port)
        else:
            self._client = None  # for testing
        self.host = host
        self._svc = None  # service stub, created on first use
        self._filters = []
        # The queues and their servicing tasks are created lazily by post().
        self._msg_q_neu = None
        self._task_neu = None
        self.drop_count_neu = 0  # NetworkElementUpdate messages dropped (queue full)
        self._msg_q_ned = None
        self._task_ned = None
        self.drop_count_ned = 0  # NetworkElementDiscovery messages dropped (queue full)

    @staticmethod
    async def _cancel_task(task):
        """Cancel a task and wait for it to finish, logging any error it raises."""
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as err:
            # A failing task must not prevent the rest of the shutdown.
            logging.warning('Error stopping network poller task: %s', err)

    async def close(self):
        """Close this object and cancel the client tasks.

        Both queue tasks are always cancelled and awaited; a task that ends
        with an error is logged instead of aborting the shutdown.
        """
        if self._task_neu:
            await self._cancel_task(self._task_neu)
            self._task_neu = None
        if self._task_ned:
            await self._cancel_task(self._task_ned)
            self._task_ned = None
        self._client = None

    async def _get_svc(self):
        """Return the service stub, connecting the client on first use."""
        if not self._svc:
            await self._client.connect()
            self._svc = NetworkPollerServiceStub(self._client._channel)
        return self._svc

    async def get_updates(self, company_uuid=None):
        """Issue a GetUpdates request, optionally for a specific company_uuid.

        Returns the result of the client's GRPC operation, or None when this
        object has no client.
        """
        if not self._client:
            return
        svc = await self._get_svc()
        req = GetUpdatesRequest()
        if company_uuid:
            req.company_uuid = company_uuid
        return await self._client._grpc_op(svc.GetUpdates,
                                           req,
                                           collect_results=True)

    async def get_discovery(self, company_uuid=None):
        """Issue a GetDiscovery request, optionally for a specific company_uuid.

        Returns the result of the client's GRPC operation, or None when this
        object has no client.
        """
        if not self._client:
            return
        svc = await self._get_svc()
        req = GetDiscoveryRequest()
        if company_uuid:
            req.company_uuid = company_uuid
        return await self._client._grpc_op(svc.GetDiscovery,
                                           req,
                                           collect_results=True)

    async def _post_neu(self, messages):
        """Post a list of NetworkElementUpdate messages"""
        if not self._client:
            return
        svc = await self._get_svc()
        req = SendUpdateBatchRequest()
        req.data.extend(messages)
        return await self._client._grpc_op(svc.SendUpdateBatch,
                                           req,
                                           timeout=self.POST_TIMEOUT)

    async def _post_ned(self, messages):
        """Post a list of NetworkElementDiscovery messages"""
        if not self._client:
            return
        svc = await self._get_svc()
        req = SendDiscoveryBatchRequest()
        req.data.extend(messages)
        return await self._client._grpc_op(svc.SendDiscoveryBatch,
                                           req,
                                           timeout=self.POST_TIMEOUT)

    async def _service_queue(self, msg_q, post_fn, type):
        """Service a message queue: build batches and post them until cancelled.

        msg_q   -- asyncio.Queue holding the messages to send
        post_fn -- coroutine function that takes a list of messages and posts it
        type    -- message class carried by the queue (used in log messages)

        A failed post is retried after RETRY_HOLDOFF_TIME, except that a batch
        rejected with RESOURCE_EXHAUSTED is dropped.  Cancellation ends the
        loop; the coroutine then returns normally.
        """
        try:
            loop = asyncio.get_event_loop()
            messages = []  # current batch: being built, or waiting to be posted
            next_message = None  # dequeued message that did not fit in the batch
            size = 0  # sum of ByteSize() over the current batch
            while True:
                if messages:
                    posted = dropped = False
                    try:
                        await post_fn(messages)
                        posted = True
                    except asyncio.CancelledError:
                        raise
                    except grpc.RpcError as err:
                        logging.warning(
                            'Error [%s] posting %s to %s: %s.  Queue size: %s.',
                            err.code(), type.__name__, self.host, err.details(),
                            msg_q.qsize())
                        # Shouldn't happen, but likely too large message:
                        # retrying cannot succeed, so give up on this batch.
                        dropped = err.code() == grpc.StatusCode.RESOURCE_EXHAUSTED
                    except Exception as err:
                        # if it fails, we'll retry until it succeeds.  So we're
                        # assuming the message is ok or we will never proceed.
                        logging.error('Error posting %s: %s', type.__name__, err)
                    if posted or dropped:
                        # Start the next batch; the size accounting is reset
                        # and the held-back message (if any) becomes its first
                        # entry so it can be neither lost nor reordered.
                        messages.clear()
                        size = 0
                        if next_message is not None:
                            messages.append(next_message)
                            size = next_message.ByteSize()
                            next_message = None
                    if not posted:
                        await asyncio.sleep(self.RETRY_HOLDOFF_TIME)
                        continue
                try:
                    # Deque messages to build a batch to post
                    deadline = loop.time() + self.BATCH_HOLDOFF_TIME
                    while True:
                        if messages:
                            # build the batch until the holdoff time
                            msg = await asyncio.wait_for(msg_q.get(),
                                                         deadline - loop.time())
                        else:
                            msg = await msg_q.get()
                        msg_q.task_done()
                        msg_size = msg.ByteSize()
                        # A message is only held back when there is a batch to
                        # send ahead of it.  A message that is too large on its
                        # own is sent alone instead of being parked, where a
                        # later oversized message would overwrite it.
                        if messages and size + msg_size > self.MAX_MESSAGE_SIZE:
                            next_message = msg
                            break
                        size += msg_size
                        messages.append(msg)
                except asyncio.TimeoutError:
                    pass  # we will send whatever we built on next loop iter
                except asyncio.CancelledError:
                    # Best effort: try to send the in-progress batch.  Nothing
                    # is held back in next_message while we are dequeuing.
                    try:
                        if messages:
                            await post_fn(messages)
                    except asyncio.CancelledError:
                        pass
                    except Exception as err:
                        logging.warning('Error posting %s during shutdown: %s',
                                        type.__name__, err)
                    raise
        except asyncio.CancelledError:
            pass

    def add_filter(self, f):
        """Add a filter to handle NetworkElementUpdate messages as they are posted.

        The filter is called as f(msg, id_key) and returns the message to
        queue, or None to discard it.
        """
        self._filters.append(f)

    @staticmethod
    def _log_drop(name, count):
        """Log a dropped message, but only when the drop count's leading digit changes."""
        if str(count)[0] != str(count - 1)[0]:
            # log when we're dropping, but don't spam the log
            logging.warning('%s queue full, %s dropped', name, count)

    async def post(self, msg, id_key=None):
        """Push a NetworkElementUpdate / NetworkElementDiscovery message.  Optionally pass in an id_key
           which will be passed along to any message filters.

           The message is copied and queued for the background task to send.
           If the queue is full the message is dropped and counted."""
        assert isinstance(msg, (NetworkElementUpdate, NetworkElementDiscovery))
        is_update = isinstance(msg, NetworkElementUpdate)
        # Start the queue servicing tasks on first use.
        if not self._task_neu:
            self._msg_q_neu = asyncio.Queue(self.MAX_MESSAGE_QLEN)
            self._task_neu = asyncio.get_event_loop().create_task(
                self._service_queue(self._msg_q_neu, self._post_neu,
                                    NetworkElementUpdate))
        if not self._task_ned:
            self._msg_q_ned = asyncio.Queue(self.MAX_MESSAGE_QLEN)
            self._task_ned = asyncio.get_event_loop().create_task(
                self._service_queue(self._msg_q_ned, self._post_ned,
                                    NetworkElementDiscovery))
        # Make a copy to make sure the caller can't change it after this
        msg_copy = NetworkElementUpdate() if is_update else NetworkElementDiscovery()
        msg_copy.CopyFrom(msg)
        if is_update:
            for f in self._filters:  # Apply filter only on NetworkElementUpdate
                msg_copy = f(msg_copy, id_key)
                if msg_copy is None:  # filter out this message
                    return
        try:
            if is_update:
                self._msg_q_neu.put_nowait(msg_copy)
            else:
                self._msg_q_ned.put_nowait(msg_copy)
        except asyncio.QueueFull:  # send queue is full, drop the request
            if is_update:
                self.drop_count_neu += 1
                self._log_drop('NetworkElementUpdate', self.drop_count_neu)
            else:
                self.drop_count_ned += 1
                self._log_drop('NetworkElementDiscovery', self.drop_count_ned)


class FakeNetworkPollerGrpcClient(NetworkPollerGrpcClient):
    """Test double for the network poller client.

    No GRPC client is created.  post() hands back a future which is resolved
    with a copy of the next batch of messages that gets "sent".  The verbose
    setting is stored but not yet acted upon.
    """
    BATCH_HOLDOFF_TIME = 0.1  # short delay for testing

    def __init__(self, verbose=0):
        super().__init__(host=None)
        self.verbose = verbose
        self.fut = None  # future waiting for the next posted batch, if any

    async def _post_neu(self, messages):
        """Deliver a batch of NetworkElementUpdate messages to the waiting future."""
        # TODO pretty-print messages if verbose > 0
        waiter = self.fut
        if not waiter:
            return
        waiter.set_result(messages.copy())
        self.fut = None

    async def _post_ned(self, messages):
        """Deliver a batch of NetworkElementDiscovery messages to the waiting future."""
        # TODO pretty-print messages if verbose > 0
        waiter = self.fut
        if not waiter:
            return
        waiter.set_result(messages.copy())
        self.fut = None

    async def post(self, msg, id_key=None):
        """Queue msg (when given) and return the future for the next posted batch.

        A null message queues nothing; it just lets the caller wait for
        another result.
        """
        if msg:
            await super().post(msg, id_key)
        # Reuse the outstanding future if there is one, else make a new one.
        if not self.fut:
            self.fut = asyncio.get_event_loop().create_future()
        return self.fut
