"""Implement counter-to-interval conversion and interpolation as specified in STM-6055."""
from collections import namedtuple
import time

from preseem_protobuf.network_poller import network_poller_pb2

# Counter fields that are back-filled and converted to interval values on
# interfaces (see process_msg_to_send and process_element).
interface_counters = [
    "in_octets", "in_ucast_pkts", "in_nucast_pkts",
    "in_discards", "in_errors", "in_unknown_protos",
    "out_octets", "out_ucast_pkts", "out_nucast_pkts",
    "out_discards", "out_errors",
]

# Links carry every interface counter plus retransmits.  Built as a new list so
# the two constants never share storage.
link_counters = interface_counters + ["out_retransmits"]


def process_msg_to_send(msg):
    """Fill in missing interface-level counters from the interface's links, in place.

    HS-85: the parent side of a set of PTMP links sometimes lacks aggregate
    "port-level" metrics.  This pass covers every interface that has links,
    whether top-level or inside a module.  Each counter in ``interface_counters``
    that the interface does not set is populated with the sum of that counter over
    the links that do set it.  Returns the same ``msg`` object.
    """
    data = msg.data
    candidates = list(data.interfaces)
    for module in data.modules:
        candidates.extend(module.interfaces)

    for interface in candidates:
        links = interface.links
        if not links:
            continue  # nothing to aggregate from
        for field in interface_counters:
            if interface.HasField(field):
                continue  # device already reported a port-level value
            link_values = [getattr(link, field) for link in links if link.HasField(field)]
            if link_values:
                setattr(interface, field, sum(link_values))
    return msg


def _all_interfaces(elem):
    """Return elem's top-level interfaces followed by the interfaces of each module."""
    return list(elem.interfaces) + [i for m in elem.modules for i in m.interfaces]


def _interpolate_counters(cur_inc, cur_int, new_inc, new_int, fields, int_factor):
    """Convert the counters named in ``fields`` on one object into interval values.

    ``cur_inc``/``new_inc`` carry the counter values from the previous and the new
    poll.  ``cur_int``/``new_int`` are the interval copies being built for the
    previous and the new interval; they are updated in place.  ``cur_inc`` and
    ``cur_int`` may be None.

    When both counter values exist and did not go backwards, the delta is split:
    ``int(delta * int_factor)`` is added to ``cur_int``, and the remainder becomes
    ``new_int``'s value (when each has the field).  Otherwise the field is treated
    as a counter reset and cleared in both interval copies.
    """
    for field in fields:
        counter1 = getattr(cur_inc, field) if cur_inc and cur_inc.HasField(field) else None
        counter2 = getattr(new_inc, field) if new_inc.HasField(field) else None
        if counter1 is not None and counter2 is not None and counter1 <= counter2:
            delta = counter2 - counter1
            val0 = int(delta * int_factor)
            val1 = delta - val0
            # cur_int can be None even when cur_inc exists (the object is in the
            # previous raw poll but not in the interval being accumulated); in
            # that case there is nowhere to put the previous interval's share.
            if cur_int and cur_int.HasField(field):
                setattr(cur_int, field, getattr(cur_int, field) + val0)
            if new_int.HasField(field):
                setattr(new_int, field, val1)
        else:
            # We don't have both counter values, or the counter reset.
            # Treat this as a counter reset either way.
            if counter1 and counter2 == 0:
                # Special case where we got a zero counter: this _could_ be a
                # polling error and not a reset on the device side.  We still
                # null the intervals, but copy the previous value into the raw
                # message in case the counters continue where they left off.
                # If it was in fact a reset, it is handled when a non-zero
                # value arrives.
                setattr(new_inc, field, counter1)
            if cur_int:
                cur_int.ClearField(field)
            new_int.ClearField(field)


def process_element(cur_inc, cur_int, new_inc, new_int, int_factor):
    """Update interval counts given previous state.

    Every interface of ``new_inc`` (top-level and in modules) is matched by
    ``poller_hash`` against ``cur_inc`` (previous raw message), ``cur_int``
    (previous interval message) and ``new_int`` (new interval message).  The
    ``interface_counters`` on each interface and the ``link_counters`` on each of
    its links are then converted in place (see ``_interpolate_counters``).

    Args:
        cur_inc: previous raw element data, or None.
        cur_int: previous interval element data, or None.
        new_inc: new raw element data; may be modified for zero-counter polls.
        new_int: new interval element data; must contain every interface/link of
            ``new_inc`` (it is expected to be a copy of it).
        int_factor: fraction of each delta credited to ``cur_int``.
    """
    cur_inc_interfaces = {
        x.poller_hash: x for x in _all_interfaces(cur_inc)
    } if cur_inc else {}
    cur_int_interfaces = {
        x.poller_hash: x for x in _all_interfaces(cur_int)
    } if cur_int else {}
    new_int_interfaces = {x.poller_hash: x for x in _all_interfaces(new_int)}

    for new_inc_intf in _all_interfaces(new_inc):
        intf_key = new_inc_intf.poller_hash
        cur_inc_intf = cur_inc_interfaces.get(intf_key)
        cur_int_intf = cur_int_interfaces.get(intf_key)
        new_int_intf = new_int_interfaces.get(intf_key)
        _interpolate_counters(cur_inc_intf, cur_int_intf, new_inc_intf, new_int_intf,
                              interface_counters, int_factor)

        cur_inc_links = {
            x.poller_hash: x for x in cur_inc_intf.links
        } if cur_inc_intf else {}
        cur_int_links = {
            x.poller_hash: x for x in cur_int_intf.links
        } if cur_int_intf else {}
        new_int_links = {x.poller_hash: x for x in new_int_intf.links}
        for new_inc_link in new_inc_intf.links:
            link_key = new_inc_link.poller_hash
            _interpolate_counters(cur_inc_links.get(link_key), cur_int_links.get(link_key),
                                  new_inc_link, new_int_links.get(link_key),
                                  link_counters, int_factor)


class UpdateIntervalManager:
    """Class responsible for counter->interval conversion.

    Intervals are 60 seconds long and end on a multiple of 60 seconds.  For each
    element, the latest raw message and the interval message being built for it
    are held in ``elem_state``.  The interval message is released by
    ``handle_message`` once a message for a later interval arrives.
    """
    State = namedtuple('State', ('orig_msg', 'new_msg', 'time'))

    def __init__(self):
        # (poller_hash, id_key) -> State.  Entries are re-inserted on every
        # update, so dict order is least-recently-updated first; this lets
        # timeout_messages stop at the first non-expired entry.
        self.elem_state = {}

    @staticmethod
    def _process_peers(cur_inc_msg, cur_int_msg, new_inc_msg, new_int_msg, int_factor):
        """Run process_element for every peer of new_inc_msg, matched by poller_hash.

        ``cur_inc_msg`` / ``cur_int_msg`` may be None (no previous state).
        """
        cur_inc_peers = {
            x.poller_hash: x for x in cur_inc_msg.peers
        } if cur_inc_msg else {}
        cur_int_peers = {
            x.poller_hash: x for x in cur_int_msg.peers
        } if cur_int_msg else {}
        new_int_peers = {x.poller_hash: x for x in new_int_msg.peers}
        for peer in new_inc_msg.peers:
            key = peer.poller_hash
            process_element(cur_inc_peers.get(key), cur_int_peers.get(key), peer,
                            new_int_peers.get(key), int_factor)

    def handle_message(self, msg, id_key=None):
        """Handle a message from the network_poller_grpc post method.

        Messages without a timestamp or element ``poller_hash`` are returned
        unchanged.  Otherwise the message's counters are converted into interval
        values and the message is held.  The call returns the previously held
        interval message (passed through process_msg_to_send) when ``msg`` falls
        into a later interval, and None otherwise.

        Raises:
            RuntimeError: if ``msg.time`` is not later than the held message's time.
        """
        if not msg.time or not msg.data.poller_hash:
            return msg  # only apply this logic to messages with data.

        # we use a unique key from the message as well as an (optional) id_key
        # passed in.  The id_key lets us disambiguate cases where the poller
        # sends two messages for an element (e.g. it's entered as an AP but
        # is actually a CPE radio of another AP).
        elem_id = (msg.data.poller_hash, id_key)
        cur = self.elem_state.get(elem_id) or self.State(None, None, None)

        interval_end = msg.time.seconds - msg.time.seconds % 60 + 60
        interval_start = cur.new_msg.interval_end.seconds if cur.new_msg else interval_end - 60
        new = self.State(msg, network_poller_pb2.NetworkElementUpdate(),
                         time.monotonic())
        new.new_msg.CopyFrom(msg)
        new.new_msg.interval_start.FromSeconds(interval_start)
        new.new_msg.interval_end.FromSeconds(interval_end)

        delta_t = msg.time.seconds - cur.new_msg.time.seconds if cur.new_msg else None
        if delta_t is not None and delta_t <= 0.0:
            raise RuntimeError('Timestamp did not increment')
        # Share of the time since the previous poll that falls inside the
        # previous message's interval.
        int_factor = None if delta_t is None else (
            cur.new_msg.interval_end.seconds - cur.new_msg.time.seconds) / delta_t
        if cur.new_msg and cur.new_msg.interval_end == new.new_msg.interval_end:
            # The new message is within the interval of the current message.
            # Update the current message in progress (element data AND peers, so
            # no peer delta is lost) and keep the new message as the raw baseline.
            process_element(cur.orig_msg.data, cur.new_msg.data, msg.data,
                            new.new_msg.data, 1.0)
            self._process_peers(cur.orig_msg, cur.new_msg, msg, new.new_msg, 1.0)

            cur.new_msg.time.FromSeconds(msg.time.seconds)
            new = self.State(msg, cur.new_msg, new.time)
            cur = self.State(None, None, None)
        else:
            # Regular message that spans the interval.  Process it.
            process_element(cur.orig_msg.data if cur.orig_msg else None,
                            cur.new_msg.data if cur.new_msg else None, msg.data,
                            new.new_msg.data, int_factor)
            self._process_peers(cur.orig_msg, cur.new_msg, msg, new.new_msg, int_factor)

        self.elem_state.pop(elem_id, None)  # keep the dict in LRU order
        self.elem_state[elem_id] = new
        return process_msg_to_send(cur.new_msg) if cur.new_msg else None

    def timeout_messages(self, expire_time):
        """Remove any messages we are holding that are older than expire_time seconds.

        Age is measured with time.monotonic() from the last update of each entry.
        """
        now = time.monotonic()
        dels = []
        for elem_id, state in self.elem_state.items():
            # see how old time is.  If expired, delete it.
            if state.time + expire_time < now:
                dels.append(elem_id)
            else:
                break  # we're past any expired records in the dict
        for elem_id in dels:
            del self.elem_state[elem_id]
