"""Tests for airtime calculation scenarios."""
import asyncio
from collections import namedtuple, Counter
import os, os.path
import sys
import time
import unittest
from uuid import UUID

from preseem import FakeNetworkMetadataModel, FakeNetworkMetricsModel, NetworkMetadataReference, NetworkMetadataEntity, Reference

sys.path.append(os.path.dirname(__file__))  # let code under test load stubs
import ap
import ap_data
from fake_snmp import FakeSnmpClient
from ne import NetworkElementRegistry
from preseem import network_element_update
from devices.ubnt import snmpBadMibError, snmpBadUpdateError, snmpUnsupportedFwError
from device_test import get_datafiles
from preseem_protobuf.network_poller import network_poller_pb2
from preseem.network_element_update import ErrorContext, NetworkElementUpdate, NetworkElement
from fake_context import FakeHttpClient, fake_ping, FakeContext

import pprint
def pp(name, d, indent=8):
    """Debug helper: print *d* as ``name = { ... }`` shifted right by *indent* spaces.

    The outer braces from ``pprint.pformat`` are stripped so that the body can
    be re-indented and the braces printed on their own lines.
    """
    pad = ' ' * indent
    # The extra space after the newline lines the first entry up with the rest:
    # when pformat wraps a dict it puts only ``indent - 1`` spaces after '{'.
    print('{}{} = {{\n '.format(pad, name), end='')
    body = pprint.pformat(d, indent=indent)[1:-1]
    for row in body.split('\n'):
        print(pad + row)
    print(pad + '}')

class Ap(ap.Ap):
    """``ap.Ap`` subclass that adds test synchronization around polling.

    ``FakeContext.ap_event`` is cleared when a poll starts and set again when
    it finishes, so tests can block until a poll has completed.
    """
    async def poll(self):
        FakeContext.ap_event.clear()
        try:
            await super().poll()
        finally:
            # Release waiters even if the poll raised; otherwise anything
            # blocked on ap_event would hang instead of seeing the failure.
            FakeContext.ap_event.set()

# What I want here is a module I can add to an Ap.
# Really I would rather not do any SNMP here: we're not testing that logic,
# and in theory SNMP would be optional anyway.
# I could monkey-patch ap.poll() - I sort of already did in Ap above... I could
# do something similar and just skip all the polling logic, which I think may
# remove any SNMP requirements?
# YES - then just call self.load() to run the AP algorithms, after I set
# self.module to my fake module.
#
# May not even need anything fake: just construct a couple of AP objects and
# call load on them?

# Ok so try this:
# 1. Create Ap object manually instead of going through registry.
# 2. Set module on it to a FakeModule that I can control to return whatever
# stations list I want, move those around.
class FakeModule:
    """Stand-in device module whose attributes are set directly by tests.

    ``poll`` does nothing. ``poll_station`` returns the station unchanged
    unless a ``station_poller`` callable has been installed, in which case
    the station and both clients are passed to it and its result is returned.
    """
    def __init__(self):
        self.stations = []          # station list exposed to the Ap under test
        self.model = None
        self.mode = "ap"
        self.station_poller = None  # optional callable(sta, snmp_client, http_client)

    async def poll(self):
        """No-op: there is nothing to poll on the fake module."""
        return None

    async def poll_station(self, sta, snmp_client, http_client):
        """Return *sta*, or whatever ``station_poller`` returns if one is set."""
        poller = self.station_poller
        if not poller:
            return sta
        return poller(sta, snmp_client, http_client)
    

class FakeClient:
    """Minimal client placeholder exposing a ``host`` and an empty ``dft`` dict."""
    def __init__(self, host=None):
        self.host = host
        self.dft = {}


class TestAP(unittest.TestCase):
    """Airtime-calculation scenarios for ``Ap``.

    Most tests build an ``Ap`` from a recorded SNMP data file and poll it. They
    then hand-edit the resulting ``network_element_update_msg`` (octet
    counters, link rates, uptime) between ``log_ne_update()`` calls. Finally,
    they check the airtime values written back into that message.
    """

    def setUp(self):
        apdd = ap_data.ApData('ap_info.yaml')
        self.datafiles = get_datafiles(apdd)
        self.ctx = FakeContext()
        self.loop = asyncio.new_event_loop() # needed to initialize asyncio
        self.reg = NetworkElementRegistry(self.ctx, {}, Ap)
        self.reg.ne_type = 'ap'
        self.ctx.company_uuid = UUID('4a24ad99-d502-3846-a8de-6c202c665a37')
        asyncio.set_event_loop(self.loop)
        self._await(self.ctx.start())

    def tearDown(self):
        self._await(self.reg.close())
        self._await(self.ctx.close())
        self.loop.close()

    def _await(self, co):
        """Run coroutine *co* to completion on the test loop and return its result."""
        return self.loop.run_until_complete(co)

    def wait_for_ap_poll(self):
        """Wait for an AP poll to complete (block until ``ctx.ap_event`` is set)."""
        self._await(self.ctx.ap_event.wait())

    def _make_ap(self, model):
        """Register 'TestNE' backed by the SNMP data file for *model*; return an ``Ap`` for it."""
        df = self.datafiles.get((model, 'snmp'))
        # Fail with a clear message instead of an AttributeError on df.path.
        self.assertIsNotNone(df, 'no SNMP data file for {!r}'.format(model))
        ne = self._await(self.reg.set('TestNE', {'name': 'Test Element', 'site': 'My Site', 'path': df.path}))
        return Ap(self.ctx, ne.id, ne.name, ne.site, ne.snmp_ne, ne.snmp_ne_v3, ne.ne_type, ne.holdoff)

    @staticmethod
    def _expected_airtime(octets, link_rate, interval=60):
        """Expected airtime used by these tests.

        *octets* are converted to bits and expressed as a percentage of what
        *link_rate* can carry in *interval* seconds.
        """
        return octets * 8 * 100 / (link_rate * interval)

    def test_airtime_calc(self):
        """Test the creation of network_element_update_msg and correctness of airtime calculations """
        ap = self._make_ap('cambium-epmp.epmp.ePMP 2000.4.5.6.01')
        total_rx_airtime = 0

        self._await(ap.poll())
        self.wait_for_ap_poll()
        self._await(ap.log_ne_update())
        ne_u = ap.module.network_element_update_msg
        ne_u.data.uptime += 60

        # Simulate 60 s of traffic with fixed octet deltas at known link rates.
        for n in range(len(ne_u.data.interfaces[1].links)):
            ne_u.data.interfaces[1].links[n].in_octets += 6000000
            ne_u.data.interfaces[1].links[n].out_octets += 3000000
            ne_u.data.interfaces[1].radios[0].streams[0].links[n].rx_link_rate = 100000000
            ne_u.data.interfaces[1].radios[0].streams[0].links[n].tx_link_rate = 200000000

        self._await(ap.log_ne_update())

        for link in ne_u.data.interfaces[1].radios[0].streams[0].links:
            link_rx_airtime = link.rx_airtime
            link_tx_airtime = link.tx_airtime
            self.assertEqual(round(link_rx_airtime, 5), self._expected_airtime(6000000, float(link.rx_link_rate)))
            self.assertEqual(round(link_tx_airtime, 5), self._expected_airtime(3000000, float(link.tx_link_rate)))
            total_rx_airtime += link_rx_airtime

        # port rf level tx_airtime equals reported tx_frame utilization and rx_airtime equals sum of stations' airtimes for ePMP models
        self.assertEqual(ne_u.data.interfaces[1].radios[0].rx_airtime, round(total_rx_airtime, 5))
        self.assertEqual(ne_u.data.interfaces[1].radios[0].tx_airtime, ne_u.data.interfaces[1].radios[0].tx_frutl)

    def test_airtime_calc2(self):
        """Test the creation of network_element_update_msg and correctness of airtime calculations in case of pre-existing SM level airtimes """
        ap = self._make_ap('ubnt.airmax-ac.Rocket 5AC PTMP.v8.5.0.01')
        total_tx_airtime = 0
        total_rx_airtime = 0

        self._await(ap.poll())
        self.wait_for_ap_poll()
        ne_u = ap.module.network_element_update_msg

        self._await(ap.log_ne_update())

        for link in ne_u.data.interfaces[0].radios[0].streams[0].links:
            link_rx_airtime = link.rx_airtime
            link_tx_airtime = link.tx_airtime
            self.assertIsNotNone(link_rx_airtime)
            self.assertIsNotNone(link_tx_airtime)
            total_rx_airtime += link_rx_airtime
            total_tx_airtime += link_tx_airtime

        self.assertEqual(round(ne_u.data.interfaces[0].radios[0].rx_airtime, 5), round(total_rx_airtime, 5))
        self.assertEqual(round(ne_u.data.interfaces[0].radios[0].tx_airtime, 5), round(total_tx_airtime, 5))

    def test_airtime_calc3(self):
        """Test the airtime calculations in case no conn_time available from Stations. We use AP's uptime instead"""
        ap = Ap(self.ctx, 'TestNE', 'Test Element', 'site', None, None, 'ap', None)
        self._await(ap.log_ne_update()) # Resolve race condition with NEU message
        ap.module = FakeModule()
        ap.errctx = getattr(ap.module, 'errctx', ErrorContext())
        ap.network_element_update_msg = getattr(ap.module, 'network_element_update_msg', NetworkElementUpdate(errctx=ap.errctx))
        Sta = namedtuple('Sta', ('mac_address',))
        ap.module.stations = [Sta('08:55:31:0a:d4:38')]

        # Hand-build one wireless interface with a single link and stream link.
        msg = ap.network_element_update_msg
        msg.data.uptime = 1000
        wlan = msg.data.interfaces.add()
        link, radio = wlan.links.add(), wlan.radios.add()
        link.in_octets = 402579957
        link.out_octets = 505139803
        link.poller_hash = b'\x01'
        stream = radio.streams.add()
        wlink = stream.links.add()
        wlink.tx_link_rate = 2310000000
        wlink.rx_link_rate = 2810000000
        wlink.poller_hash = b'\x01'

        total_tx_airtime = 0
        total_rx_airtime = 0
        ap.stations = ap.module.stations
        self._await(ap.log_ne_update())
        ne_u = ap.network_element_update_msg

        ne_u.data.uptime += 60
        for n in range(len(ne_u.data.interfaces[0].links)):
            ne_u.data.interfaces[0].links[n].in_octets += 6000000
            ne_u.data.interfaces[0].links[n].out_octets += 3000000
            ne_u.data.interfaces[0].radios[0].streams[0].links[n].rx_link_rate = 10000000
            ne_u.data.interfaces[0].radios[0].streams[0].links[n].tx_link_rate = 20000000
        self._await(ap.log_ne_update())

        for link in ne_u.data.interfaces[0].radios[0].streams[0].links:
            link_rx_airtime = link.rx_airtime
            link_tx_airtime = link.tx_airtime
            self.assertEqual(link_rx_airtime, self._expected_airtime(6000000, float(link.rx_link_rate)))
            self.assertEqual(link_tx_airtime, self._expected_airtime(3000000, float(link.tx_link_rate)))
            total_rx_airtime += link_rx_airtime
            total_tx_airtime += link_tx_airtime

        self.assertEqual(ne_u.data.interfaces[0].radios[0].rx_airtime, total_rx_airtime)
        self.assertEqual(ne_u.data.interfaces[0].radios[0].tx_airtime, total_tx_airtime)

    def test_airtime_calc4(self):
        """Test airtime for wireless links in case of multiple streams """
        ap = self._make_ap('cambium-canopy.450.PMP 450.15.1.01')

        self._await(ap.poll())
        self.wait_for_ap_poll()
        ne_u = ap.module.network_element_update_msg
        link_len = len(ne_u.data.interfaces[1].links)

        # To populate missing rx_link_rate in .snmp for airtime_calc
        for n in range(link_len):
            ne_u.data.interfaces[1].radios[0].streams[0].links[n].rx_link_rate = 10000000
            ne_u.data.interfaces[1].radios[0].streams[1].links[n].rx_link_rate = 10000000
            ne_u.data.interfaces[1].radios[0].streams[0].links[n].tx_link_rate = 20000000
            ne_u.data.interfaces[1].radios[0].streams[1].links[n].tx_link_rate = 20000000
        self._await(ap.log_ne_update())

        for n in range(link_len):
            ne_u.data.interfaces[1].radios[0].streams[0].links[n].rx_link_rate = 10000000
            ne_u.data.interfaces[1].radios[0].streams[1].links[n].rx_link_rate = 10000000
            ne_u.data.interfaces[1].radios[0].streams[0].links[n].tx_link_rate = 20000000
            ne_u.data.interfaces[1].radios[0].streams[1].links[n].tx_link_rate = 20000000
            ne_u.data.interfaces[1].links[n].in_octets += 6000000
            ne_u.data.interfaces[1].links[n].out_octets += 3000000
        ne_u.data.uptime += 60
        self._await(ap.log_ne_update())

        for n in range(link_len):
            # Stream 0 is named "v" and stream 1 "h" here; both get identical rates above.
            streamv_link = ne_u.data.interfaces[1].radios[0].streams[0].links[n]
            streamh_link = ne_u.data.interfaces[1].radios[0].streams[1].links[n]
            exp_rx_airtime = self._expected_airtime(6000000, streamv_link.rx_link_rate)
            exp_tx_airtime = self._expected_airtime(3000000, streamv_link.tx_link_rate)
            self.assertEqual(streamv_link.rx_airtime, exp_rx_airtime)
            self.assertEqual(streamh_link.rx_airtime, exp_rx_airtime)
            self.assertEqual(streamv_link.tx_airtime, exp_tx_airtime)
            self.assertEqual(streamh_link.tx_airtime, exp_tx_airtime)
            self.assertEqual(streamv_link.rx_airtime, streamh_link.rx_airtime)
            self.assertEqual(streamv_link.tx_airtime, streamh_link.tx_airtime)

        # port rf level airtime equals reported frame utilization for Canopy 450 (both directions); no multiplexing gain is available
        self.assertEqual(ne_u.data.interfaces[1].radios[0].tx_airtime, ne_u.data.interfaces[1].radios[0].tx_frutl)
        self.assertEqual(ne_u.data.interfaces[1].radios[0].rx_airtime, ne_u.data.interfaces[1].radios[0].rx_frutl)

    def test_airtime_calc5(self):
        """Test airtime for directly reported Frutl multiplied by Gain """
        ap = self._make_ap('cambium-canopy.450m.PMP 450m.CANOPY 20.3.1 AP.01')

        self._await(ap.poll())
        self.wait_for_ap_poll()
        self._await(ap.log_ne_update())
        radio = ap.module.network_element_update_msg.data.interfaces[1].radios[0]

        # port rf level airtime equals reported frame utilization for Canopy 450m with addition of multiplexing gain (both directions)
        self.assertEqual(round(radio.tx_airtime, 5), round(radio.tx_frutl * radio.tx_multiplexing_gain, 5))
        self.assertEqual(round(radio.rx_airtime, 5), round(radio.rx_frutl * radio.rx_multiplexing_gain, 5))

    def test_airtime_calc6(self):
        """Test the creation of network_element_update_msg and correctness of airtime calculations if one of the stations was reset/rebooted """
        ap = self._make_ap('cambium-epmp.epmp.ePMP 2000.4.5.01')
        total_rx_airtime = 0

        self._await(ap.poll())
        self.wait_for_ap_poll()
        self._await(ap.log_ne_update())
        ne_u = ap.module.network_element_update_msg

        for n in range(len(ne_u.data.interfaces[1].links)):
            ne_u.data.interfaces[1].links[n].in_octets += 6000000
            ne_u.data.interfaces[1].links[n].out_octets += 3000000
            if n == 0:
                # Simulate a station reset: its in_octets counter is set to a low value.
                ne_u.data.interfaces[1].links[0].in_octets = 1
            ne_u.data.interfaces[1].radios[0].streams[0].links[n].rx_link_rate = 100000000
            ne_u.data.interfaces[1].radios[0].streams[0].links[n].tx_link_rate = 200000000
        ne_u.data.uptime += 60
        self._await(ap.log_ne_update())

        for n, link in enumerate(ne_u.data.interfaces[1].radios[0].streams[0].links):
            link_rx_airtime = link.rx_airtime
            self.assertEqual(round(link.tx_airtime, 5), self._expected_airtime(3000000, float(link.tx_link_rate)))
            if n == 0:
                # The reset link must not produce an rx airtime.
                self.assertIsNone(link_rx_airtime)
            else:
                self.assertEqual(round(link_rx_airtime, 5), self._expected_airtime(6000000, float(link.rx_link_rate)))
                total_rx_airtime += link_rx_airtime

        # port rf level tx_airtime equals reported tx_frame utilization and rx_airtime equals sum of stations' airtimes for ePMP models
        self.assertEqual(ne_u.data.interfaces[1].radios[0].rx_airtime, round(total_rx_airtime, 5))
        self.assertEqual(ne_u.data.interfaces[1].radios[0].tx_airtime, ne_u.data.interfaces[1].radios[0].tx_frutl)

    def test_airtime_calc7(self):
        """Test the creation of network_element_update_msg and correctness of airtime calculations == None , No Mikrotik API """
        ap = self._make_ap('mikrotik.mikrotik-ap.RB411.6.25.01')

        self._await(ap.poll())
        self.wait_for_ap_poll()
        self._await(ap.log_ne_update())
        ne_u = ap.module.network_element_update_msg

        for n in range(len(ne_u.data.interfaces[1].links)):
            ne_u.data.interfaces[1].links[n].in_octets += 6000000
            ne_u.data.interfaces[1].links[n].out_octets += 3000000
            ne_u.data.interfaces[1].radios[0].streams[0].links[n].tx_link_rate = 20000000
            ne_u.data.interfaces[1].radios[0].streams[0].links[n].rx_link_rate = 10000000
        ne_u.data.uptime += 60
        self._await(ap.log_ne_update())

        for link in ne_u.data.interfaces[1].radios[0].streams[0].links:
            self.assertIsNone(link.rx_airtime)
            self.assertIsNone(link.tx_airtime)

        # Without the Mikrotik API no airtime is computed, so the port rf level airtime is None in both directions
        self.assertIsNone(ne_u.data.interfaces[1].radios[0].rx_airtime)
        self.assertIsNone(ne_u.data.interfaces[1].radios[0].tx_airtime)

    def test_airtime_calc8(self):
        """Test that airtime stays None on a Mikrotik AP (no Mikrotik API) when the link octet counters do not change between updates"""
        ap = self._make_ap('mikrotik.mikrotik-ap.RB411.6.25.01')
        ap.in_out_bytes = {}

        self._await(ap.poll())
        ne_u = ap.module.network_element_update_msg
        for n in range(len(ne_u.data.interfaces[1].links)):
            ne_u.data.interfaces[1].links[n].in_octets = 0
            ne_u.data.interfaces[1].links[n].out_octets = 0

        self.wait_for_ap_poll()
        self._await(ap.log_ne_update())

        # Octet counters stay at zero; only link rates and uptime advance.
        for n in range(len(ne_u.data.interfaces[1].links)):
            ne_u.data.interfaces[1].links[n].in_octets += 0
            ne_u.data.interfaces[1].links[n].out_octets += 0
            ne_u.data.interfaces[1].radios[0].streams[0].links[n].tx_link_rate = 20000000
            ne_u.data.interfaces[1].radios[0].streams[0].links[n].rx_link_rate = 10000000
        ne_u.data.uptime += 60
        self._await(ap.log_ne_update())

        for link in ne_u.data.interfaces[1].radios[0].streams[0].links:
            self.assertIsNone(link.rx_airtime)
            self.assertIsNone(link.tx_airtime)

        # No airtime is computed for this Mikrotik AP, so the port rf level airtime is None in both directions
        self.assertIsNone(ne_u.data.interfaces[1].radios[0].rx_airtime)
        self.assertIsNone(ne_u.data.interfaces[1].radios[0].tx_airtime)

    def test_airtime_calc9(self):
        """Test the creation of network_element_update_msg and correctness of airtime calculations None if no link_rate"""
        ap = self._make_ap('mikrotik.mikrotik-ap.RB411.6.25.01')
        ap.in_out_bytes = {}

        self._await(ap.poll())
        ne_u = ap.module.network_element_update_msg
        for n in range(len(ne_u.data.interfaces[1].links)):
            ne_u.data.interfaces[1].links[n].in_octets = 0
            ne_u.data.interfaces[1].links[n].out_octets = 0
            ne_u.data.interfaces[1].radios[0].streams[0].links[n].tx_link_rate = None
            ne_u.data.interfaces[1].radios[0].streams[0].links[n].rx_link_rate = None

        self.wait_for_ap_poll()
        self._await(ap.log_ne_update())

        for n in range(len(ne_u.data.interfaces[1].links)):
            ne_u.data.interfaces[1].links[n].in_octets += 0
            ne_u.data.interfaces[1].links[n].out_octets += 0
            ne_u.data.interfaces[1].radios[0].streams[0].links[n].tx_link_rate = None
            ne_u.data.interfaces[1].radios[0].streams[0].links[n].rx_link_rate = None
        ne_u.data.uptime += 60
        self._await(ap.log_ne_update())

        for link in ne_u.data.interfaces[1].radios[0].streams[0].links:
            self.assertIsNone(link.rx_airtime)
            self.assertIsNone(link.tx_airtime)

        # With no link rates no airtime can be computed, so the port rf level airtime is None in both directions
        self.assertIsNone(ne_u.data.interfaces[1].radios[0].rx_airtime)
        self.assertIsNone(ne_u.data.interfaces[1].radios[0].tx_airtime)

    def _run_counter_source_scenario(self, first_source, second_source):
        """Drive two polls of an ePMP 3000 AP and return the update message.

        ``peers[0].counter_source`` is *first_source* for the first poll and
        *second_source* for the second. Octet counters, the first stream
        link's connected_time and the uptime are advanced between the polls.
        """
        ap = self._make_ap('cambium-epmp.epmp3000.ePMP 3000.4.3.2.01')
        self._await(ap.log_ne_update())

        # Poll 1
        self._await(ap.poll())
        msg = ap.module.network_element_update_msg
        msg.peers[0].counter_source = first_source

        for n in range(len(msg.data.interfaces[1].links)):
            msg.data.interfaces[1].links[n].in_octets = 1000
            msg.data.interfaces[1].links[n].out_octets = 1000

        self.wait_for_ap_poll()
        self._await(ap.log_ne_update())

        msg.data.interfaces[1].radios[0].streams[0].links[0].connected_time += 60
        for n in range(len(msg.data.interfaces[1].links)):
            msg.data.interfaces[1].links[n].in_octets += 6000000
            msg.data.interfaces[1].links[n].out_octets += 3000000

        # Poll 2
        msg.data.uptime += 60
        msg.peers[0].counter_source = second_source
        self._await(ap.log_ne_update())

        self._await(ap.poll())
        self.wait_for_ap_poll()
        self._await(ap.log_ne_update())
        # NOTE(review): this is the message captured after poll 1. If poll()
        # replaces module.network_element_update_msg it may be stale; confirm.
        return msg

    def test_airtime_calc10(self):
        """Test the airtime calculations when the peers have the same AP counter_source"""
        msg = self._run_counter_source_scenario("AP", "AP")
        self.assertIsNotNone(msg.data.interfaces[1].radios[0].rx_airtime)
        self.assertIsNotNone(msg.data.interfaces[1].radios[0].tx_airtime)

    def test_airtime_calc11(self):
        """Test the airtime calculations when the peers have the same CPE counter_source"""
        msg = self._run_counter_source_scenario("CPE", "CPE")
        self.assertIsNotNone(msg.data.interfaces[1].radios[0].rx_airtime)
        self.assertIsNotNone(msg.data.interfaces[1].radios[0].tx_airtime)

    def test_airtime_calc12(self):
        """Test the airtime calculations when the peers AP and CPE have different counter_source"""
        msg = self._run_counter_source_scenario("AP", "CPE")
        self.assertIsNone(msg.data.interfaces[1].radios[0].rx_airtime)
        self.assertIsNone(msg.data.interfaces[1].radios[0].tx_airtime)

    def test_airtime_calc13(self):
        """Test the airtime calculations when the peers CPE and AP have different counter_source"""
        msg = self._run_counter_source_scenario("CPE", "AP")
        self.assertIsNone(msg.data.interfaces[1].radios[0].rx_airtime)
        self.assertIsNone(msg.data.interfaces[1].radios[0].tx_airtime)