"""
Common classes used by this package.
"""
import asyncio
from collections import namedtuple
import ipaddress
import logging
import shlex
import time

import aiohttp
from packaging import version

from .http_client import HttpClient

# Settings object that tells utility code whether IPv6 support is enabled.
# Functions (e.g. parse_ip_value) check its "ipv6_enabled" attribute.
_ipv6_settings = None


def set_ipv6_settings(obj):
    """Register *obj* as the module's source of IPv6 configuration.

    ``parse_ip_value`` reads ``obj.ipv6_enabled`` when *obj* is truthy;
    passing None makes it behave as if IPv6 support were disabled.
    """
    global _ipv6_settings
    _ipv6_settings = obj

# Ping code, return ping resposne time if pingable, otherwise return 0
# The # of concurrent “ping” processes we will have running at one time
_run_sem = asyncio.Semaphore(32)
ping_time = 0
async def _run(cmd):
    async with _run_sem:
        proc = await asyncio.create_subprocess_shell(
            cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE)
        try:
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=10)
            if stdout and proc.returncode == 0:
                # take avg value
                ping_time = str(stdout).rsplit('min/avg/max/mdev = ')[1].split('/')[1].split('/')[0]
            else:
                ping_time = 0

        except asyncio.TimeoutError:
            logging.warning("Asyncio timeout for %s command", cmd)
        return ping_time

# invoke 1 ping
async def ping(host):
    return await _run("ping -c 1 -w2 " + host)

async def http_probe(host, http_client_cls, methods=('https', 'http'), ports=None, **kwargs):
    """
    Probe a host to see if it's listening for http(s).

    One client is built per ``(method, port)`` combination, e.g.
    ``https://host:8443``.  When *ports* is empty or None, no port is
    appended.  Every client issues a single request for ``''`` concurrently,
    with no retries and without following redirects.

    This requires the caller to pass in the HttpClient class (for unit
    testing).  It is called as ``http_client_cls(url, verify_ssl=False,
    conn_timeout=5, read_timeout=5, connection_limit=<number of urls>,
    **kwargs)``.

    Returns:
        The first client, in *methods*-then-*ports* order, whose request did
        not raise, or None.  Every client that is not returned is closed.
    """

    async def handler(response):
        # Accept a redirect (redirects are not followed): it's listening.
        # NOTE(review): 400 itself is also accepted; confirm `< 400` wasn't meant.
        return 200 <= response.status <= 400

    urls = [f'{m}://{host}{":" if port else ""}{port}'
            for m in methods for port in ports or ['']]
    clients = [http_client_cls(url,
                               verify_ssl=False,
                               conn_timeout=5,
                               read_timeout=5,
                               connection_limit=len(urls),
                               **kwargs)
               for url in urls]
    result = None
    try:
        responses = await asyncio.gather(
            *[client.request('', num_retries=0, allow_redirects=False, handler=handler)
              for client in clients],
            return_exceptions=True)
        for client, response in zip(clients, responses):
            # BaseException (not Exception) so a child's CancelledError, which
            # gather() returns as a value here, counts as a failure.
            # NOTE(review): if http_client_cls.request returns the handler's
            # False instead of raising, that still counts as success here.
            if result is None and not isinstance(response, BaseException):
                result = client
            else:
                await client.close()
    except asyncio.CancelledError:
        # If the gather operation is cancelled, we have to clean up here.
        for client in clients:
            await client.close()
        raise
    return result

def normalize_mac(mac):
    """Try to return a normalized MAC address from input.

    The result is six lower-case, zero-padded hex octets joined by colons,
    e.g. ``'00:11:22:aa:bb:cc'``.  Accepted spellings:

    * plain hex digits such as ``'001122AABBCC'`` (shorter values are
      left-padded with zeros);
    * six groups separated by ``:``, ``-``, ``.`` or spaces, such as
      ``'00-11-22-aa-bb-cc'`` or ``'0:1:2:a:b:c'``;
    * three 4-digit groups such as ``'0011.22aa.bbcc'``.

    Surrounding whitespace is stripped and single quotes are ignored.

    Returns:
        The normalized address, or None if *mac* is empty or whitespace.

    Raises:
        ValueError: if *mac* is not a valid 48-bit MAC address.
    """
    mac = mac.strip()
    if not mac:
        return None
    invalid = 'Invalid MAC address "{}"'.format(mac)
    try:
        value = int(mac, 16)
    except ValueError:  # it probably has some delimiters
        cleaned = mac.replace('-', ':').replace('.', ':').replace(' ', ':').replace("'", '')
        groups = cleaned.split(':')
        # The group count fixes each group's digit width: 1 x 12 (quoted bare
        # hex), 3 x 4 (dotted aabb.ccdd.eeff), or 6 x 2 (one group per octet).
        width = {1: 12, 3: 4, 6: 2}.get(len(groups))
        if width is None:
            raise ValueError(invalid) from None
        try:
            digits = ''.join('%0*x' % (width, int(group, 16)) for group in groups)
        except ValueError:
            raise ValueError(invalid) from None
    else:
        # int() accepts a leading sign; a negative value is never a MAC.
        if value < 0:
            raise ValueError(invalid)
        digits = '%012x' % value
    # An over-wide group or value produces more than 12 digits.
    if len(digits) != 12:
        raise ValueError(invalid)
    return ':'.join(digits[i:i + 2] for i in range(0, 12, 2))

def _parse_ip_range(entry, ipv6_enabled):
    """Expand one ``first-last`` range entry into a list of IP strings.

    ``last`` may be a full address or, for IPv4, just the final octet
    (``10.0.0.5-20`` or ``10.0.0.5-.20``).  With *ipv6_enabled* the range is
    returned as summarized CIDR strings; otherwise every address is listed.
    Raises ValueError for any malformed range.
    """
    bounds = entry.split('-')
    if len(bounds) != 2:
        raise ValueError('Bad IP range: {}'.format(entry))
    first, last = (bound.strip() for bound in bounds)
    try:
        ip1 = ipaddress.ip_address(first)
        try:
            ip2 = ipaddress.ip_address(last)
        except ValueError:  # handle range for just last octet
            if ip1.version != 4:
                raise  # the octet shorthand only makes sense for IPv4
            octet = int(last[1:] if last.startswith('.') else last)
            ip2 = ipaddress.ip_address('.'.join(first.split('.')[:3] + [str(octet)]))
    except ValueError:
        raise ValueError('Bad IP range addresses: {}'.format(entry)) from None
    try:
        nets = list(ipaddress.summarize_address_range(ip1, ip2))
    except (TypeError, ValueError) as e:  # reversed bounds / mixed IPv4+IPv6
        raise ValueError('Bad IP range {}: {}'.format(entry, e)) from None
    if ipv6_enabled:
        return [str(net) for net in nets]
    return [str(addr) for net in nets for addr in net]


def parse_ip_value(val, max_width=None, strict=False):
    """Parse an IP address or CIDR and return a list of IP address strings.

    *val* is a comma-separated list; empty entries are skipped.  Each entry
    may be:

    * a range ``first-last`` (see ``_parse_ip_range``);
    * a single address, optionally with ``/32`` or ``/255.255.255.255``
      (anything after the first space is ignored); ``0.0.0.0`` is rejected;
    * a network in CIDR or address/netmask form.

    Args:
        val: the string to parse; a falsy value yields ``[]``.
        max_width: smallest prefix length accepted for a network entry;
            defaults to 20 when IPv6 support is enabled, else 24.
        strict: passed to ``ipaddress.ip_network`` (reject set host bits).

    Returns:
        A list of strings.  Single addresses are returned as given.  Networks
        are returned as CIDR strings when IPv6 support is enabled or the
        network is IPv6; otherwise they are expanded to their hosts plus the
        network address.

    Raises:
        ValueError: for any malformed entry or a network wider than
            *max_width*.
    """
    ipv6_enabled = _ipv6_settings.ipv6_enabled if _ipv6_settings else None
    if max_width is None:
        max_width = 20 if ipv6_enabled else 24  # STM-5079
    ip_addrs = []
    if not val:
        return ip_addrs
    for entry in val.split(','):
        if not entry.strip():
            continue
        if '-' in entry:  # range
            ip_addrs.extend(_parse_ip_range(entry, ipv6_enabled))
            continue
        ip_subnet = entry.lstrip().split(' ')[0].strip().split('/')
        if len(ip_subnet) < 2 or ip_subnet[1] in ('32', '255.255.255.255'):
            try:
                addr = ipaddress.ip_address(ip_subnet[0])
            except ValueError:
                addr = None
            if addr is None or addr.compressed == '0.0.0.0':
                raise ValueError('Invalid IP address: {}'.format(val))
            ip_addrs.append(ip_subnet[0])
            continue
        try:
            net = ipaddress.ip_network('/'.join(ip_subnet), strict=strict)
        except ValueError as e:
            raise ValueError('Error parsing IP network {} ({})'.format(entry, e)) from None
        if net.prefixlen < max_width:
            raise ValueError('Error parsing subnet {}, too big'.format(entry))
        if ipv6_enabled or net.version == 6:  # STM-4192 always map IPv6 subnets
            ip_addrs.append(str(net))
        else:
            ip_addrs.extend(str(x) for x in net.hosts())
            if str(net.network_address) not in ip_addrs:
                ip_addrs.append(str(net.network_address))  # STM-3988
    return ip_addrs


class TelemetryApiClient:
    """Async client for the telemetry API at https://api.preseem.com/v1/telemetry."""
    PRESEEM_DATETIME_FORMAT = '%Y-%m-%dT%H:%M:%SZ'

    # Sentinel returned by _get() when the API answers 404 (no stats present).
    _NOT_FOUND = object()

    def __init__(self, api_key, http_client_cls=HttpClient, company_id=None):
        """Create the client.

        Args:
            api_key: API key, passed to the HTTP client as ``username``.
            http_client_cls: HTTP client class to instantiate (injectable).
            company_id: company sent with requests when set.
        """
        self._client = http_client_cls('https://api.preseem.com/v1/telemetry',
                                       username=api_key)
        self.company_id = company_id

    async def close(self):
        """Close the underlying HTTP client."""
        await self._client.close()

    async def _get(self, path, params):
        """Request *path* with *params*; return ``_NOT_FOUND`` on HTTP 404."""
        try:
            return await self._client.request(path, params=params)
        except aiohttp.ClientResponseError as err:
            # ``status`` is the non-deprecated name of aiohttp's ``code``.
            if err.status == 404:  # no stats present
                return self._NOT_FOUND
            raise

    def _add_common_params(self, query, t0, t1):
        """Return *query* plus any t0/t1/company parameters that apply.

        The caller's mapping is never modified; it is copied only when
        parameters must be added.
        """
        extra = {}
        if t0:
            extra['t0'] = time.strftime(self.PRESEEM_DATETIME_FORMAT, t0)
        if t1:
            extra['t1'] = time.strftime(self.PRESEEM_DATETIME_FORMAT, t1)
        if self.company_id:
            extra['company'] = self.company_id
        if not extra:
            return query
        params = query.copy()
        params.update(extra)
        return params

    async def query(self, query, company=None):
        """Execute a raw query.

        Args:
            query: the query string, sent as the ``q`` parameter.
            company: company to query; defaults to ``self.company_id``.

        Returns:
            The response data from the HTTP client, or ``[]`` when it is
            empty or the API answers 404.
        """
        params = {'q': query}
        company = company or self.company_id
        if company:
            params['company'] = company
        data = await self._get('/query', params)
        if data is self._NOT_FOUND:
            return []
        return data or []

    async def get_events(self, measurement, query, t0=None, t1=None):
        """Fetch ``/events/<measurement>`` as a list of namedtuples.

        Args:
            measurement: event measurement name; also used as the namedtuple
                type name, so it must be a valid Python identifier.
            query: mapping of request parameters (not modified).
            t0, t1: optional time tuples formatted with
                PRESEEM_DATETIME_FORMAT.  NOTE(review): its literal ``Z``
                suffix suggests these should be UTC (``time.gmtime()``).

        Returns:
            One namedtuple per event.  Its fields are the columns of the first
            and last events, with '-' replaced by '_' and sorted.  Columns an
            event lacks are None.  ``[]`` when there are no events or on 404.
        """
        params = self._add_common_params(query, t0, t1)
        data = await self._get('/events/{}'.format(measurement), params)
        if data is self._NOT_FOUND or not data:
            return []
        # use first and last because they could have different cols (asc+desc)
        name_map = {x: x.replace('-', '_') for x in set(data[0]) | set(data[-1])}
        Event = namedtuple(measurement, sorted(name_map.values()))
        return [Event(**{p: x.get(f) for f, p in name_map.items()}) for x in data]

    async def get_series(self, measurement, query, t0=None, t1=None):
        """Fetch ``/series/<measurement>`` and convert it to namedtuples.

        Only the first entry of the response's ``results`` is used.  All of
        its series are assumed to share the first series' columns and tags.
        Arguments are as for :meth:`get_events`.

        Returns:
            ``[]`` on HTTP 404.  None when there are no results or the first
            result has no series.  A list of row namedtuples for the first
            series when series are untagged.  Otherwise a dict mapping a
            namedtuple of tag values (fields: sorted tag names) to that
            series' rows.
        """
        params = self._add_common_params(query, t0, t1)
        data = await self._get('/series/{}'.format(measurement), params)
        if data is self._NOT_FOUND:
            return []
        results = (data or {}).get('results') or []
        if not results:
            return None
        # only support a single result and assume all series same cols
        series = results[0].get('series')
        if not series:
            return None
        Row = namedtuple(series[0]['name'], series[0]['columns'])
        tags = series[0].get('tags')
        if not tags:
            return [Row(*row) for row in series[0]['values']]
        Dataset = namedtuple('Dataset_{}'.format(series[0]['name']), sorted(tags))
        datasets = {}
        for s in series:
            key = Dataset(*[value for _, value in sorted(s['tags'].items())])
            datasets[key] = [Row(*row) for row in s['values']]
        return datasets

class uint64(int):
    """Marker ``int`` subclass: hints to the GRPC code to create a uint.

    It adds no behaviour of its own; ``uint64(n)`` compares equal to ``n``.
    """

class counter32(int):
    """Helper class to manage 32-bit counter math.

    Feed successive raw readings of a 32-bit counter to :meth:`update`.
    ``int(counter)`` (or ``counter.value``) then gives the total increase
    since the first reading or the last :meth:`reset`, including increases
    across a wrap from 0xffffffff back to 0.

    NOTE: the object's own int value is always 0 (it is constructed without
    an argument), so read ``int(counter)`` rather than doing arithmetic on
    the object directly.
    """
    _MODULUS = 1 << 32  # a 32-bit counter wraps from 0xffffffff to 0

    def __init__(self):
        self.reset()

    def __int__(self):
        return self.value

    def reset(self):
        """Reset this counter to 0 and forget the previous reading."""
        self.value = 0
        self.prev = None

    def update(self, val):
        """Add the increase from the previous reading to the new reading *val*."""
        val = int(val)
        if self.prev is None:
            pass  # leave value at 0 for the initial set
        elif val < self.prev:  # handle rollover
            # prev -> 0xffffffff -> 0 -> val is (2**32 - prev) + val steps.
            self.value += self._MODULUS - self.prev + val
        else:  # typical case, counter incremented
            self.value += val - self.prev
        self.prev = val

async def get_node_versions(telemetry_api):
    """Return a dict of ``{node_id: version}`` from the telemetry API.

    It requests ``platform`` events with ``limit='1000,desc'`` and the
    ``node`` and ``netdev-manager`` fields, and keeps one row per node.  A
    later row in the response replaces an earlier one for the same node.
    The first whitespace-separated token of each kept row's
    ``netdev_manager`` is parsed with ``packaging.version.parse``.

    NOTE(review): if ``desc`` means newest-first, the row kept per node is
    the oldest one returned; confirm the intended ordering.
    """
    params = {'limit': '1000,desc', 'fields': 'node,netdev-manager'}
    rows = await telemetry_api.get_events('platform', params)
    row_by_node = {row.node: row for row in rows}
    return {
        node: version.parse(row.netdev_manager.split()[0])
        for node, row in row_by_node.items()
    }

def run_tool_script(coroutine):
    """Run a Preseem tool script - set up the environment properly etc.

    It configures INFO-level logging with timestamps, then runs *coroutine*
    as a task on the default event loop.  A new loop is created if there is
    none or it is closed.  Outcomes:

    * ``SystemExit`` from the script is swallowed.
    * On ``KeyboardInterrupt``, every pending task is cancelled and allowed
      to finish unwinding.
    * Any other ``Exception`` is printed as ``ERR <message>``.

    The event loop is closed before returning.

    NOTE(review): swallowing SystemExit discards the script's exit status
    (e.g. ``sys.exit(2)``); confirm this is intended.
    """
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                        level=logging.INFO)
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:  # no current event loop (newer Python versions)
        loop = None
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    task = loop.create_task(coroutine)
    try:
        loop.run_until_complete(task)
    except SystemExit:
        pass
    except KeyboardInterrupt:
        # asyncio.Task.all_tasks() was removed in Python 3.9; use the
        # module-level asyncio.all_tasks() (3.7+) when it exists.
        all_tasks = getattr(asyncio, 'all_tasks', None) or asyncio.Task.all_tasks
        pending = [t for t in all_tasks(loop) if not t.done()]
        for t in pending:
            t.cancel()
        if pending:
            # Run the loop again so the cancelled tasks can clean up instead
            # of being destroyed while still pending when the loop closes.
            loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
    except Exception as err:
        print("ERR", err)
    finally:
        try:
            # Presumably here to mark the task's exception as retrieved.
            task.result()
        except BaseException:
            pass
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()
