"""Test NetworkPoller GRPC infrastructure"""
import asyncio
import os
import sys
import unittest

from preseem_protobuf.network_poller.network_poller_pb2 import NetworkElementUpdate
from preseem import FakeNetworkPollerGrpcClient

class TestNetworkPollerGrpc(unittest.TestCase):
    """Exercise FakeNetworkPollerGrpcClient message batching and filters."""

    def setUp(self):
        # Each test gets its own fresh event loop, installed as the current
        # loop so asyncio code run by the client can find it.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        # Created lazily by _post() on first use.
        self.client = None

    def tearDown(self):
        # Always close the loop, even if closing the client raises, so that
        # a failing close() does not leak an event loop into later tests.
        try:
            if self.client is not None:
                self._await(self.client.close())
        finally:
            self.loop.close()

    def _await(self, co):
        """Run coroutine *co* to completion on the test loop; return its result."""
        return self.loop.run_until_complete(co)

    async def _post(self, msg, asyn=False, filt=None):
        """Post *msg* through the fake client, creating the client if needed.

        Args:
            msg: Message to post (the tests pass a NetworkElementUpdate, or
                None to collect a pending batch without adding a message).
            asyn: If true, return immediately without waiting for delivery.
            filt: Optional filter registered with the client before posting.

        Returns:
            None when *asyn* is true. Otherwise, the result of the future
            returned by ``client.post``, awaited for at most 1 second. The
            tests treat this result as the delivered batch of messages.

        Raises:
            asyncio.TimeoutError: If the future does not resolve in time.
        """
        if self.client is None:
            self.client = FakeNetworkPollerGrpcClient()
        if filt:
            self.client.add_filter(filt)
        fut = await self.client.post(msg)
        if asyn:
            return None
        return await asyncio.wait_for(fut, 1)

    def post(self, msg, asyn=False, filt=None):
        """Run _post() synchronously on the test loop (same args and result)."""
        return self._await(self._post(msg, asyn, filt))

    def test_network_poller_grpc_01(self):
        """Post a message.  It is delivered."""
        msg = NetworkElementUpdate()
        msg.time.FromSeconds(1612927766)
        r = self.post(msg)
        self.assertIsNotNone(r)
        self.assertEqual(len(r), 1)
        self.assertEqual(r[0], msg)

    def test_network_poller_grpc_02(self):
        """Post two messages.  They are delivered as a batch."""
        msg1 = NetworkElementUpdate()
        msg1.time.FromSeconds(1612927766)
        # Don't wait: the message should be held for the next batch.
        self.post(msg1, True)
        msg2 = NetworkElementUpdate()
        msg2.time.FromSeconds(1612928035)
        r = self.post(msg2)
        self.assertIsNotNone(r)
        self.assertEqual(len(r), 2)
        self.assertEqual(r[0], msg1)
        self.assertEqual(r[1], msg2)

    def test_network_poller_grpc_03(self):
        """Test maximum batch message size."""
        start_time = 1612927766
        max_size = FakeNetworkPollerGrpcClient.MAX_MESSAGE_SIZE
        msg = NetworkElementUpdate()
        msg.data.name = '0123456789' * 1000  # 10KB string
        total_size = 0
        count = 0
        t = start_time
        # Keep posting (without waiting) until the cumulative size of all
        # posted messages exceeds the client's maximum batch size.
        while True:
            msg.time.FromSeconds(t)
            t += 1
            total_size += msg.ByteSize()
            count += 1
            if total_size > max_size:
                # This message overflows the batch: wait for delivery.
                r = self.post(msg)
                break
            self.post(msg, True)
        self.assertIsNotNone(r)
        # The overflowing message is not part of this batch; it follows later.
        self.assertEqual(len(r), count - 1)
        for i, rmsg in enumerate(r):
            msg.time.FromSeconds(start_time + i)
            self.assertEqual(rmsg, msg)
        # Collect the message that overflowed, which should arrive alone in a
        # separate batch (posting None is relied on here to fetch it).
        r = self.post(None)
        self.assertEqual(len(r), 1)
        self.assertEqual(r[0].data, msg.data)

    def test_network_poller_grpc_04(self):
        """Test oversized message is not part of the current or next batch."""
        start_time = 1612927766
        regular_name = '0123456789' * 1000  # 10 KB string
        max_size = FakeNetworkPollerGrpcClient.MAX_MESSAGE_SIZE

        msg = NetworkElementUpdate()
        msg.data.name = regular_name

        # Two regular-size messages, posted without waiting.
        for i in range(2):
            msg.time.FromSeconds(start_time + i)
            self.post(msg, True)

        # The 3rd message is larger than the maximum batch size by itself.
        # Check this explicitly so a larger limit fails the test instead of
        # silently testing nothing.
        msg.time.FromSeconds(start_time + 2)
        msg.data.name = 'OOVERSIZED' * 1000 * 500
        self.assertGreater(msg.ByteSize(), max_size)
        r = self.post(msg)

        self.assertIsNotNone(r)
        self.assertEqual(len(r), 2)  # the batch should stop at 2
        # Every message in the batch is regular-size with the right time.
        msg.data.name = regular_name
        for i, rmsg in enumerate(r):
            msg.time.FromSeconds(start_time + i)
            self.assertEqual(rmsg, msg)

        # Post one more regular-size message; the next batch must contain
        # only it, i.e. the oversized message was dropped, not deferred.
        msg.time.FromSeconds(start_time + 3)
        r = self.post(msg, False)
        self.assertEqual(len(r), 1)
        self.assertEqual(r[0].data, msg.data)

    def test_filter_01(self):
        """Test that a filter can be added that modifies the message."""
        def filt(msg, _):
            msg.instance = 'test'
            return msg
        msg = NetworkElementUpdate()
        msg.time.FromSeconds(1612927766)
        r = self.post(msg, filt=filt)
        self.assertIsNotNone(r)
        self.assertEqual(len(r), 1)
        msg.instance = 'test'
        self.assertEqual(r[0], msg)

    def test_filter_02(self):
        """Test that a filter can be added that suppresses the message."""
        def filt(msg, _):
            return None
        msg = NetworkElementUpdate()
        msg.time.FromSeconds(1612927766)
        # A suppressed message is never delivered, so waiting for it times out.
        self.assertRaises(asyncio.TimeoutError, self.post, msg, filt=filt)
