import asyncio
import time
import unittest

from preseem_grpc_model import metrics_pb2
from preseem import NetworkMetricsModel, FlexMetric

class FakeMetricsClient:
    """In-memory stand-in for the metrics client used by the model under test.

    Every request handed to ``push_flexes`` is appended to ``flexes``.  Assign
    an exception to ``error`` to make the next push raise it instead; the
    error is cleared as it is raised, so it fires only once.
    """

    def __init__(self):
        self.flexes = []   # requests received so far, oldest first
        self.error = None  # one-shot exception for the next push, if any

    async def push_flexes(self, req):
        """Record ``req``, or raise (and clear) a pending injected error."""
        if not self.error:
            self.flexes.append(req)
            return
        pending, self.error = self.error, None
        raise pending


class TestMetrics(unittest.TestCase):
    """Tests for NetworkMetricsModel's flex-metric batching and push logic.

    Each test runs on its own event loop and pushes through a
    FakeMetricsClient, so the requests the model sends can be inspected
    directly.  The model's holdoff time is lowered to 1 second to keep the
    timing-based tests short.
    """

    def setUp(self):
        # Install a fresh loop as the current loop before building the model
        # (the original note says this is needed to initialize asyncio;
        # presumably the model looks up the current loop when constructed).
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.client = FakeMetricsClient()
        self.model = NetworkMetricsModel(self.client)
        self.model.HOLDOFF_TIME = 1  # lower holdoff for faster tests

    def tearDown(self):
        # Always close the loop, even if shutting the model down fails, so a
        # failing test does not leak an open loop into later tests.  Also
        # uninstall it so nothing later picks up a closed loop as current.
        try:
            self.loop.run_until_complete(self.model.close())
            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
        finally:
            asyncio.set_event_loop(None)
            self.loop.close()

    def _push_flexes(self, flexes):
        """Run the model's push_flexes coroutine to completion on the test loop."""
        return self.loop.run_until_complete(self.model.push_flexes(flexes))

    def _wait_for_queue(self):
        """Wait for the flex model queue to be emptied."""
        self.loop.run_until_complete(self.model._flex_q.join())

    def _wait(self, sec):
        """Run the test loop for ``sec`` seconds so background tasks progress."""
        self.loop.run_until_complete(asyncio.sleep(sec))

    @staticmethod
    def _make_flexes(count):
        """Return ``count`` metrics named 'm' whose field 'f' is 1.0..count."""
        return [FlexMetric('m', time.time(), {'a': 'b'}, {'f': float(i + 1)})
                for i in range(count)]

    def check_req(self, pb, flexes):
        """Verify that a list of flexes is the same as a flex request."""
        req = metrics_pb2.FlexPushRequest()
        req.flexes.extend([self.model._flex_pb(x) for x in flexes])
        self.assertEqual(pb, req)

    def test_holdoff_1(self):
        """A small flex message is sent once the holdoff timer expires."""
        fm = FlexMetric('test', time.time(), {'a': 'b'}, {'f1': 100.0})
        self._push_flexes([fm])
        self._wait_for_queue()
        self._wait(1.1)  # just past the 1s test holdoff
        self.assertEqual(len(self.client.flexes), 1)
        self.check_req(self.client.flexes[0], [fm])

    def test_holdoff_2(self):
        """Multiple messages are put in one request by the holdoff timer"""
        fm1 = FlexMetric('test1', time.time(), {'a': 'b'}, {'f1': 100.0})
        self._push_flexes([fm1])
        self._wait_for_queue()
        self._wait(0.1)  # second push lands well inside the holdoff window
        fm2 = FlexMetric('test2', time.time(), {'a': 'c'}, {'f1': 101.0})
        self._push_flexes([fm2])
        self._wait(1.1)
        self.assertEqual(len(self.client.flexes), 1)
        self.check_req(self.client.flexes[0], [fm1, fm2])

    def test_large_push_1(self):
        """A push request with > 5000 messages is sent in two parts"""
        flexes = self._make_flexes(5001)
        self._push_flexes(flexes)
        self._wait_for_queue()
        self._wait(0.1)
        # The first full batch of 5000 goes out without waiting for holdoff...
        self.assertEqual(len(self.client.flexes), 1)
        self.check_req(self.client.flexes.pop(), flexes[:5000])
        # ...and the leftover metric is sent after the holdoff timer.
        self._wait(1.1)
        self.assertEqual(len(self.client.flexes), 1)
        self.check_req(self.client.flexes.pop(), flexes[5000:])

    def test_large_push_2(self):
        """A push request with 5000 messages is sent immediately"""
        flexes = self._make_flexes(5000)
        self._push_flexes(flexes)
        self._wait_for_queue()
        self._wait(0.1)
        self.assertEqual(len(self.client.flexes), 1)
        self.check_req(self.client.flexes.pop(), flexes)
        # Nothing is left over, so the holdoff timer sends nothing more.
        self._wait(1.1)
        self.assertEqual(len(self.client.flexes), 0)

    def test_large_push_3(self):
        """A push request larger than 2 MB is sent in two parts"""
        # Each metric carries a 10,000-character tag value; the expected split
        # after 208 metrics presumably reflects the model's encoded-size limit.
        flexes = [FlexMetric('m', time.time(), {'a': 'b', 'big': 'c' * 10000}, {'f': float(i + 1)}) for i in range(209)]
        self._push_flexes(flexes)
        self._wait_for_queue()
        self._wait(0.1)
        self.assertEqual(len(self.client.flexes), 1)
        req = self.client.flexes.pop()
        self.check_req(req, flexes[:208])
        self._wait(1.1)
        self.assertEqual(len(self.client.flexes), 1)
        self.check_req(self.client.flexes.pop(), flexes[208:])

    def test_large_push_4(self):
        """A second push request that goes over the limits is handled properly."""
        flexes = self._make_flexes(5001)
        # One metric is pending in the holdoff window when the big push
        # arrives; together they must still split at the 5000 boundary.
        self._push_flexes(flexes[:1])
        self._push_flexes(flexes[1:])
        self._wait_for_queue()
        self._wait(0.1)
        self.assertEqual(len(self.client.flexes), 1)
        self.check_req(self.client.flexes.pop(), flexes[:5000])
        self._wait(1.1)
        self.assertEqual(len(self.client.flexes), 1)
        self.check_req(self.client.flexes.pop(), flexes[5000:])

    def test_multi_batch_1(self):
        """A large request that has to be split into multiple batches works."""
        flexes = self._make_flexes(15001)
        self._push_flexes(flexes)
        self._wait_for_queue()
        self._wait(0.1)
        # Three full batches go out immediately, in order.
        self.assertEqual(len(self.client.flexes), 3)
        self.check_req(self.client.flexes.pop(0), flexes[:5000])
        self.check_req(self.client.flexes.pop(0), flexes[5000:10000])
        self.check_req(self.client.flexes.pop(0), flexes[10000:15000])
        # The single leftover metric waits for the holdoff timer.
        self._wait(1.1)
        self.assertEqual(len(self.client.flexes), 1)
        self.check_req(self.client.flexes.pop(), flexes[15000:])

    def test_multi_batch_2(self):
        """Try a very large set of metrics to exercise STM-2666"""
        flexes = self._make_flexes(50000)
        self._push_flexes(flexes)
        self._wait_for_queue()
        self._wait(0.1)
        # 50000 metrics divide into exactly 10 full batches of 5000.
        self.assertEqual(len(self.client.flexes), 10)

    def test_push_failure_1(self):
        """Make sure the push is retried if a failure occurs."""
        # The fake client raises on the first push attempt only.
        self.client.error = RuntimeError('TEST')
        fm = FlexMetric('test', time.time(), {'a': 'b'}, {'f1': 100.0})
        self._push_flexes([fm])
        self._wait_for_queue()
        self._wait(1.1)
        # The first attempt (after the holdoff) failed, so nothing is recorded.
        self.assertEqual(len(self.client.flexes), 0)
        # The retry is expected to succeed within about one more holdoff.
        self._wait(1)
        self.assertEqual(len(self.client.flexes), 1)
        self.check_req(self.client.flexes[0], [fm])
