"""
Common classes and helper functions used by this package.
"""
import asyncio
from collections import namedtuple
from icmplib import async_ping
import ipaddress
import logging
import os
import time
import pkg_resources

import aiohttp
import aiohttp.web
from packaging import version

from .http_client import HttpClient
from .mikrotik import MikrotikApiError
from .ssh_client import SshClient
import asyncssh  # has to come after ssh_client import to suppress warnings
from prometheus_client.exposition import generate_latest

# Module-wide IPv6 settings object (None until set_ipv6_settings() is called).
# Utility functions in this module, e.g. parse_ip_value, read its
# "ipv6_enabled" attribute.
_ipv6_settings = None


def set_ipv6_settings(obj):
    """Install *obj* as the module-wide IPv6 settings object.

    Utility functions such as parse_ip_value consult ``obj.ipv6_enabled``.
    """
    global _ipv6_settings
    _ipv6_settings = obj


_ping_sem = None  # used to prevent too many open files error


async def ping(host):
    """Ping a host.  Returns rtt if we can ping it, 0.0 otherwise.

    Sends a single unprivileged ICMP echo with icmplib's ``async_ping`` and
    returns its ``avg_rtt``.  At most 200 pings run concurrently across the
    module (shared semaphore, created on first use).

    Any error other than cancellation is logged as a warning and reported as
    0.0, the same value returned for an unreachable host.
    """
    global _ping_sem
    if _ping_sem is None:
        _ping_sem = asyncio.Semaphore(200)
    try:
        async with _ping_sem:
            result = await async_ping(host, count=1, privileged=False)
        return result.avg_rtt
    except asyncio.CancelledError:
        raise
    except Exception as err:
        logging.warning('%s error pinging host %s: %s', type(err).__name__, host, err)
        return 0.0  # honour the documented contract (previously returned None)


async def http_probe(host,
                     http_client_cls,
                     methods=('https', 'http'),
                     ports=None,
                     **kwargs):
    """
    Probe a host to see if it's listening for http(s).

    One client is built for every (method, port) combination and all of them
    send a request concurrently.  The first combination, in ``methods`` then
    ``ports`` order, whose request returns a truthy, non-exception result is
    kept; every other client is closed.  A response counts as "listening" for
    statuses 200-400, except a plain-http redirect whose Location is https.

    This requires the caller to pass in the HttpClient class (for unit testing).

    Args:
        host: host name or address to probe.
        http_client_cls: client class, called as
            ``http_client_cls(url, verify_ssl=False, conn_timeout=5, ...)``.
        methods: URL schemes to try, in order of preference.
        ports: optional iterable of ports; when omitted or empty the URLs
            carry no explicit port.
        **kwargs: extra keyword arguments passed to ``http_client_cls``.

    Returns:
        The open ``http_client_cls`` object that answered, or None.
    """
    redirect_statuses = (301, 302, 303, 307, 308)

    async def handler(response):
        """Return True if *response* shows a usable listener."""
        if 200 <= response.status <= 400:  # accept a redirect, it's listening.
            if (response.status in redirect_statuses
                    and response.request_info.url.scheme == 'http'):
                loc = response.headers.get('location')
                if loc and loc.startswith('https'):
                    return False  # plain http only bounces us to https
            return True
        return False

    urls = [
        f'{m}://{host}{":" if port else ""}{port}' for m in methods
        for port in ports or ['']
    ]
    clients = [
        http_client_cls(url,
                        verify_ssl=False,
                        conn_timeout=5,
                        read_timeout=5,
                        connection_limit=len(urls),
                        **kwargs) for url in urls
    ]
    result = None
    try:
        responses = await asyncio.gather(
            *(client.request('', num_retries=0, allow_redirects=False, handler=handler)
              for client in clients),
            return_exceptions=True)
        for client, response in zip(clients, responses):
            # BaseException rather than Exception: a cancelled request comes
            # back as a CancelledError instance and must not count as success.
            if (response and result is None
                    and not isinstance(response, BaseException)):
                result = client
            else:
                await client.close()
    except asyncio.CancelledError:
        # If the gather operation is cancelled, we have to clean up here.
        for client in clients:
            await client.close()
        raise
    return result


def normalize_mac(mac):
    """Try to return a normalized MAC address from input.

    Accepts either a bare hex string (``int(..., 16)`` syntax, so an optional
    ``0x`` prefix is allowed) or octets separated by ``-``, ``.``, ``:`` or
    spaces (single quotes are ignored in that form).

    Returns:
        Six lower-case, zero-padded, colon-separated octets
        (``'00:11:22:33:44:55'``), or None when *mac* is empty or blank.

    Raises:
        ValueError: if *mac* is not a 48-bit MAC address.
    """
    mac_address = None
    mac = mac.strip()
    if mac:
        try:
            s = '%012x' % int(mac, 16)
            mac_address = ':'.join(a + b for a, b in zip(s[::2], s[1::2]))
        except (TypeError, ValueError):  # it probably has some delimiters
            try:
                m = (mac.replace('-', ':').replace('.', ':').replace(' ', ':')
                     .replace("'", ''))
                mac_address = ':'.join(['%02x' % int(x, 16) for x in m.split(':')])
            except (TypeError, ValueError) as err:
                raise ValueError('Invalid MAC address "{}"'.format(mac)) from err
        # Require exactly six two-digit hex octets.  int() accepts a sign and
        # variable-width fields, so inputs such as "-1" or "aaaa:bbb:cc:dd:ee"
        # would otherwise yield a 17-character string that is not a MAC.
        octets = mac_address.split(':')
        if len(octets) != 6 or not all(
                len(o) == 2 and all(c in '0123456789abcdef' for c in o)
                for o in octets):
            raise ValueError('Invalid MAC address "{}"'.format(mac))
    return mac_address


def _ip_range_bounds(item):
    """Return the (first, last) ip_address objects of a "first-last" item.

    The last address may be written in full or, for IPv4, as just its final
    octet ("10.0.0.1-9" or "10.0.0.1-.9").  Whitespace around either end is
    ignored.  Raises ValueError for anything malformed.
    """
    parts = item.split('-')
    if len(parts) != 2:
        raise ValueError('Bad IP range: {}'.format(item))
    first, last = parts[0].strip(), parts[1].strip()
    try:
        ip1 = ipaddress.ip_address(first)
        try:
            ip2 = ipaddress.ip_address(last)
        except ValueError:  # handle range for just last octet
            octet = int(last[1:] if last.startswith('.') else last)
            ipl = first.split('.')
            if len(ipl) != 4:  # the short form only makes sense for IPv4
                raise ValueError(item)
            ipl[3] = str(octet)
            ip2 = ipaddress.ip_address('.'.join(ipl))
    except ValueError as err:
        raise ValueError('Bad IP range addresses: {}'.format(item)) from err
    return ip1, ip2


def parse_ip_value(val, max_width=None, strict=False):
    """Parse an IP address or CIDR and return a list of IP address strings.

    *val* is a comma-separated list whose items may be:
      * a single address ("10.0.0.1"; text after the first space is ignored
        and a "/32" or "/255.255.255.255" suffix is accepted),
      * a network ("10.0.0.0/24"), or
      * a range ("10.0.0.1-10.0.0.9", "10.0.0.1-9" or "10.0.0.1-.9").

    Args:
        val: the string to parse; empty or None yields [].
        max_width: smallest prefix length accepted for a network item.  It
            defaults to 20 when the module IPv6 settings enable IPv6, else 24.
        strict: passed to ``ipaddress.ip_network`` (reject set host bits).

    Returns:
        A list of strings.  With IPv6 enabled, networks and summarised ranges
        stay in CIDR form.  Otherwise IPv4 networks expand to their hosts plus
        the network address, and ranges expand to every address.  IPv6
        networks are always kept in CIDR form.

    Raises:
        ValueError: if an item is malformed or a network is wider than
            *max_width*.
    """
    ipv6_enabled = _ipv6_settings.ipv6_enabled if _ipv6_settings else None
    if max_width is None:
        max_width = 20 if ipv6_enabled else 24  # STM-5079
    ip_addrs = []
    if not val:
        return ip_addrs
    for ip in val.split(','):
        if not ip.strip():
            continue
        if '-' in ip:  # range
            ip1, ip2 = _ip_range_bounds(ip)
            try:
                # summarize_address_range is lazy; materialise it so its
                # ValueError (reversed range) / TypeError (mixed IPv4/IPv6)
                # surfaces here.
                nets = list(ipaddress.summarize_address_range(ip1, ip2))
            except (TypeError, ValueError) as e:
                raise ValueError('Bad IP range {}: {}'.format(ip, e)) from e
            for net in nets:
                if ipv6_enabled:
                    ip_addrs.append(str(net))
                else:
                    ip_addrs.extend(str(x) for x in net)
            continue
        ip_subnet = ip.lstrip().split(' ')[0].strip().split('/')
        # NOTE(review): an IPv6 item with "/32" also takes this single-address
        # branch - confirm that is intended.
        if len(ip_subnet) < 2 or ip_subnet[1] in ('32', '255.255.255.255'):
            try:
                addr = ipaddress.ip_address(ip_subnet[0])
                if addr.compressed == '0.0.0.0':
                    raise ValueError('Bad IP address: {}'.format(ip_subnet[0]))
            except ValueError as err:
                raise ValueError('Invalid IP address: {}'.format(val)) from err
            ip_addrs.append(ip_subnet[0])
            continue
        try:
            net = ipaddress.ip_network('/'.join(ip_subnet), strict=strict)
        except ValueError as e:
            raise ValueError('Error parsing IP network {} ({})'.format(ip, e)) from e
        if net.prefixlen < max_width:
            raise ValueError('Error parsing subnet {}, too big'.format(ip))
        if ipv6_enabled or net.version == 6:  # STM-4192 map IPv6 subnets
            ip_addrs.append(str(net))
        else:
            ip_addrs.extend(str(x) for x in net.hosts())
            if str(net.network_address) not in ip_addrs:
                ip_addrs.append(str(net.network_address))  # STM-3988
    return ip_addrs


class TelemetryApiClient():
    """Async client for the Preseem telemetry API (``/v1/telemetry``).

    When ``company_id`` is set, requests are scoped to that company.
    """
    PRESEEM_DATETIME_FORMAT = '%Y-%m-%dT%H:%M:%SZ'

    def __init__(self, api_key, http_client_cls=HttpClient, company_id=None):
        # The API key is handed to the HTTP client as its username.
        self._client = http_client_cls('https://api.preseem.com/v1/telemetry',
                                       username=api_key)
        self.company_id = company_id

    async def close(self):
        """Close the underlying HTTP client."""
        await self._client.close()

    def _build_params(self, query, t0=None, t1=None):
        """Return a copy of *query* extended with t0/t1 and the company filter.

        A copy is made so the caller's dict is never modified.
        """
        params = dict(query or {})
        if t0:
            params['t0'] = time.strftime(self.PRESEEM_DATETIME_FORMAT, t0)
        if t1:
            params['t1'] = time.strftime(self.PRESEEM_DATETIME_FORMAT, t1)
        if self.company_id:
            params['company'] = self.company_id
        return params

    async def query(self, query, company=None):
        """Execute a raw query.

        Args:
            query: the query string, sent as the ``q`` parameter.
            company: company to scope to; defaults to ``self.company_id``.

        Returns:
            The decoded response, or [] when it is empty or the API answers 404.
        """
        params = {'q': query}
        company = company or self.company_id
        if company:
            params['company'] = company
        try:
            data = await self._client.request('/query', params=params)
        except aiohttp.ClientResponseError as err:
            if err.status == 404:  # no stats present
                return []
            raise
        return data or []

    async def get_events(self, measurement, query, t0=None, t1=None):
        """Fetch ``/events/<measurement>`` as a list of namedtuples.

        Args:
            measurement: event measurement name (also the namedtuple type name).
            query: dict of extra request parameters (not modified).
            t0, t1: optional ``time.struct_time`` bounds, formatted with
                PRESEEM_DATETIME_FORMAT.

        Returns:
            One namedtuple per event, with '-' in field names replaced by '_'.
            Fields missing from an event are None.  Returns [] when there is
            no data or the API answers 404.
        """
        params = self._build_params(query, t0, t1)
        try:
            data = await self._client.request('/events/{}'.format(measurement),
                                              params=params)
        except aiohttp.ClientResponseError as err:
            if err.status == 404:  # no stats present
                return []
            raise
        if not data:
            return []
        # use first and last because they could have different cols (asc+desc)
        name_map = {x: x.replace('-', '_') for x in set(data[0]) | set(data[-1])}
        nt = namedtuple(measurement, sorted(name_map.values()))
        return [nt(**{p: x.get(f) for f, p in name_map.items()}) for x in data]

    async def get_series(self, measurement, query, t0=None, t1=None):
        """Fetch ``/series/<measurement>`` and decode its first result.

        Args are as for get_events.

        Returns:
            - [] if the API answers 404.
            - None if the response has no results or no series.
            - A list of Row namedtuples (from the first series) when the
              series carry no tags.
            - Otherwise a dict mapping a ``Dataset`` namedtuple of tag values
              (sorted by tag name) to that series' list of rows.

            All series are assumed to share the first series' columns.
        """
        params = self._build_params(query, t0, t1)
        try:
            data = await self._client.request('/series/{}'.format(measurement),
                                              params=params)
        except aiohttp.ClientResponseError as err:
            if err.status == 404:  # no stats present
                return []
            raise
        results = (data or {}).get('results') or []
        if not results:
            return None
        # only support a single result and assume all series same cols
        series = results[0].get('series')
        if not series:
            return None
        Row = namedtuple(series[0]['name'], series[0]['columns'])
        tags = series[0].get('tags')
        if not tags:
            return [Row(*row) for row in series[0]['values']]
        Dataset = namedtuple('Dataset_{}'.format(series[0]['name']), sorted(tags))
        return {
            Dataset(*[value for _, value in sorted(s['tags'].items())]):
                [Row(*row) for row in s['values']]
            for s in series
        }


class PreseemApiKey(object):
    """Wrapper class for API key information from preseem-web.

    ``load()`` looks the key up via ``/api/v2/api-key-lookup/`` and reuses the
    answer for an hour.  The company_* properties read from the last lookup
    and are None until one has succeeded.
    """

    def __init__(self, api_key, token):
        self.api_key = api_key
        self._hdr = {'Authorization': f'token {token}'}
        self._loaded = None  # event-loop time of the last successful lookup
        self._data = {}
        self._api = None  # HTTP client, created lazily by load()

    async def close(self):
        """Close the HTTP client if one was created; safe to call repeatedly."""
        if self._api is not None:
            await self._api.close()
            self._api = None

    @property
    def company_id(self):
        return self._data.get('company_id')

    @property
    def company_name(self):
        return self._data.get('company_name')

    @property
    def company_uuid(self):
        return self._data.get('company_uuid')

    async def load(self):
        """Fetch the key's details unless a lookup under an hour old exists."""
        if self._api is None:
            # NOTE(review): certificate verification is disabled although the
            # request carries the Authorization token - confirm intended.
            self._api = HttpClient('https://app.preseem.com', verify_ssl=False)

        now = asyncio.get_running_loop().time()
        if self._loaded is not None and self._loaded + 3600 > now:
            return
        self._data = await self._api.request('/api/v2/api-key-lookup/',
                                             headers=self._hdr,
                                             params={'api_key': self.api_key})
        self._loaded = now


class uint64(int):
    """Integer subclass whose type tells the GRPC code to create a uint.

    It behaves exactly like ``int``; only the type carries the hint.
    """


class counter32(int):
    """Helper class to manage 32-bit counter math.

    Pass successive raw samples of a 32-bit counter to ``update()``.  The
    total increase since the first sample, including across wrap-arounds, is
    kept in ``value`` and returned by ``int(counter)``.

    Instances are constructed without arguments, so the object's own integer
    payload is always 0.  Read the count via ``int()`` or ``.value``.
    """

    _WRAP = 1 << 32  # number of distinct values a 32-bit counter can hold

    def __init__(self):
        self.reset()

    def __int__(self):
        return self.value

    def reset(self):
        """Reset this counter to 0 and forget the previous sample."""
        self.value = 0
        self.prev = None

    def update(self, val):
        """Update the current value with a new raw counter sample."""
        val = int(val)
        if self.prev is None:
            pass  # leave value at 0 for the initial set
        elif val < self.prev:  # handle rollover: prev .. 0xffffffff, 0 .. val
            # Going from 0xffffffff back to 0 is itself one step, so the
            # distance is (2**32 - prev) + val, not (0xffffffff - prev) + val.
            self.value += (self._WRAP - self.prev) + val
        else:  # typical case, counter incremented
            self.value += val - self.prev
        self.prev = val


async def get_node_versions(telemetry_api):
    """Get a dict of {node_id -> version} from the telemetry API.

    Requests up to 1000 ``platform`` events (``'limit': '1000,desc'``) with
    the ``node`` and ``netdev-manager`` fields.  For each node, the last row
    returned that has a non-empty ``netdev-manager`` value is used.  The
    first whitespace-separated token of that value is parsed with
    ``packaging.version.parse``.

    Args:
        telemetry_api: object with an async ``get_events(measurement, query)``
            returning rows with ``node``/``netdev_manager`` attributes
            (e.g. TelemetryApiClient).
    """
    rows = await telemetry_api.get_events('platform', {
        'limit': '1000,desc',
        'fields': 'node,netdev-manager'
    })
    latest = {}
    for row in rows:
        # Rows can carry None for a field the event lacks (TelemetryApiClient
        # fills gaps via dict.get); skip them rather than fail on .split().
        tokens = (getattr(row, 'netdev_manager', None) or '').split()
        if tokens:
            # NOTE(review): later rows overwrite earlier ones; with 'desc'
            # ordering this may keep the oldest event per node - confirm.
            latest[row.node] = tokens[0]
    return {node: version.parse(ver) for node, ver in latest.items()}


def running_in_container() -> bool:
    """Return True when hints suggest this process runs inside a container.

    The hints are: the process is PID 1, or Docker's ``/.dockerenv`` marker
    file exists.
    """
    if os.getpid() == 1:
        return True
    return os.path.exists('/.dockerenv')


class MetricsServer(object):
    """Run a basic server for metrics and keepalives.

    ``GET /metrics`` returns ``prometheus_client.generate_latest()`` as
    ``text/plain``.
    """

    def __init__(self, port=9150):
        self.port = port
        self.app_metrics = aiohttp.web.Application()
        self.app_metrics.router.add_route('GET', '/metrics', self.get)

    async def start(self):
        """Start listening on 0.0.0.0:``self.port``.

        The asyncio server is stored in ``self.server``; the aiohttp runner
        that owns the request-handler factory is stored in ``self.runner``.
        """
        # Application.make_handler() is deprecated; AppRunner is the supported
        # way to obtain the low-level handler factory (runner.server).
        self.runner = aiohttp.web.AppRunner(self.app_metrics, access_log=None)
        await self.runner.setup()
        self.server = await asyncio.get_running_loop().create_server(
            self.runner.server, '0.0.0.0', self.port)

    async def get(self, request):
        """Serve the current Prometheus metrics."""
        return aiohttp.web.Response(body=generate_latest(), content_type="text/plain")


def run_tool_script(coroutine):
    """Run a Preseem tool script - set up the environment properly etc.

    Configures INFO-level logging, runs *coroutine* to completion on the
    current event loop (a fresh one if there is none or it is closed), then
    shuts the loop down.

    Failures are reported rather than raised:
      * ``SystemExit`` is swallowed.
      * Other exceptions are printed as ``ERR <message>``.
      * On Ctrl-C, every pending task is cancelled and given a chance to run
        its cleanup before the loop closes.
    """
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                        level=logging.INFO)
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:  # newer Pythons no longer create a loop implicitly
        loop = None
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    task = loop.create_task(coroutine)
    try:
        loop.run_until_complete(task)
        task.result()
    except SystemExit:
        pass
    except KeyboardInterrupt:
        # asyncio.Task.all_tasks() was removed in Python 3.9.
        pending = asyncio.all_tasks(loop)
        for t in pending:
            t.cancel()
        if pending:
            # Let the cancelled tasks process the cancellation before close().
            loop.run_until_complete(
                asyncio.gather(*pending, return_exceptions=True))
    except Exception as err:
        print("ERR", err)
    finally:
        try:
            task.result()  # outcome already handled (or task never finished)
        except BaseException:
            pass
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()


def check_required_dependency(req_file: str) -> None:
    """Check the current environment has all dependencies required by a file.

    Args:
        req_file (str): The file path of the requirements.txt.

    Raises:
        pkg_resources.DistributionNotFound: a requirement is not installed.
        pkg_resources.VersionConflict: an installed version does not match.
    """
    # Use a context manager so the file handle is closed promptly.
    with open(req_file) as f:
        pkg_resources.require(f.read())


async def _mikrotik_probe(host,
                          mikrotik_api_cls,
                          creds,
                          raise_errors,
                          num_retries=1,
                          **kwargs):
    """Probe the mikrotik API of a host using (username, password) tuples.

    A single ``mikrotik_api_cls(host, None, None, **kwargs)`` instance is
    created; mikrotik_api_probe passes ``enable_ssl`` through ``kwargs``.
    Each attempt connects with a 5 second timeout.  It then tries every
    credential in order by setting the client's ``_user``/``_pass`` and
    calling ``login()`` with a 15 second timeout.  Up to ``num_retries``
    further attempts follow, each preceded by a ``2**retry_num`` second
    sleep.

    Args:
        host: host passed to ``mikrotik_api_cls``.
        mikrotik_api_cls: client class to instantiate (injectable for tests).
        creds: sequence of (username, password) pairs; when empty, nothing is
            attempted and None is returned.
        raise_errors: when true, the last recorded login error is raised from
            the ``finally`` block instead of being swallowed.
        num_retries: number of extra connect/login attempts.
        **kwargs: extra keyword arguments for ``mikrotik_api_cls``.

    Returns:
        The logged-in client, or None.  An instance that is not returned is
        closed before leaving.

    NOTE(review): please confirm these behaviours are intended:
      * Connection failures (the outer ``except``) are not stored in
        ``last_err``, so ``raise_errors`` does not surface them.
      * If every credential is rejected, control falls through to the next
        retry, which calls ``connect()`` again on the same instance without
        closing it first.
      * With ``raise_errors`` true, a cancellation that arrives while an
        earlier error is recorded (e.g. during the retry sleep) surfaces as
        that error instead of CancelledError, because the ``finally`` raises
        ``last_err``.
    """
    client = None
    last_err = None
    if creds:
        # credentials are filled in per attempt below via _user/_pass
        mikrotik_api = mikrotik_api_cls(host, None, None, **kwargs)
        try:
            for retry_num in range(num_retries + 1):
                if retry_num > 0:
                    await asyncio.sleep(2**retry_num)  # back off: 2s, 4s, ...
                try:
                    await asyncio.wait_for(mikrotik_api.connect(), 5.0)
                    for cred in creds:
                        last_err = None  # keep only the latest credential's error
                        mikrotik_api._user = cred[0]
                        mikrotik_api._pass = cred[1]
                        try:
                            await asyncio.wait_for(mikrotik_api.login(), 15.0)
                        except MikrotikApiError as err:
                            last_err = err
                            continue  # no match to username/password
                        except asyncio.CancelledError:
                            raise
                        except Exception as err:
                            # unexpected login failure (incl. timeout): give up
                            last_err = err
                            return None
                        client = mikrotik_api
                        return client
                except (MikrotikApiError, OSError, asyncio.TimeoutError) as err:
                    pass  # connection-level failure: retry (error not recorded)
        finally:
            if not client:
                await mikrotik_api.close()
            if last_err and raise_errors:
                raise last_err
    return client


async def mikrotik_api_probe(host,
                             mikrotik_api_cls,
                             creds,
                             raise_errors=False,
                             **kwargs):
    """Try to connect to a Mikrotik host using a list of (username, password) credentials.

    Two probes run concurrently, one with ``enable_ssl=True`` and one with
    ``enable_ssl=False``.  The first probe to return a client wins.  A client
    produced by the other probe is closed, and a probe that is still running
    is cancelled and awaited.

    Returns:
        The connected client, or None.

    Raises:
        If ``raise_errors`` is true, no client was obtained and the last probe
        to finish failed, that probe's exception is raised.  Cancellation
        propagates; on the way out the probes are cancelled and any client
        already obtained is closed.
    """
    pend = [
        asyncio.create_task(
            _mikrotik_probe(host, mikrotik_api_cls, creds, raise_errors,
                            enable_ssl=use_ssl, **kwargs))
        for use_ssl in (True, False)
    ]
    client = None
    try:
        while pend and client is None:
            done, pend = await asyncio.wait(pend, return_when=asyncio.FIRST_COMPLETED)
            while done:
                task = done.pop()
                try:
                    result = task.result()
                except asyncio.CancelledError:
                    raise
                except Exception:
                    if raise_errors and not client and not done and not pend:
                        raise  # this is the last result, all results were errors.
                    result = None
                if result:
                    if client:
                        await result.close()  # both probes succeeded; keep one
                    else:
                        client = result
        if pend:
            # We have a client: stop the other probe.  A cancelled
            # _mikrotik_probe may re-raise its last error instead of
            # CancelledError, so collect outcomes rather than let them escape
            # and leak the client.
            for task in pend:
                task.cancel()
            await asyncio.gather(*pend, return_exceptions=True)
    except BaseException:
        # Cancelled (or giving up with an error): don't leave probes running
        # and don't leak a connection we already hold.
        for task in pend:
            task.cancel()
        if client is not None:
            await client.close()
        raise
    return client


async def _ssh_probe(host, port, ssh_client_cls, creds, raise_errors, **kwargs):
    """Probe the SSL server of a host using (username, password) tuples."""
    client = None
    last_err = None
    for cred in creds:
        last_err = None
        ssh_client = ssh_client_cls(host,
                                    port=port,
                                    username=cred[0],
                                    password=cred[1],
                                    **kwargs)
        try:
            await asyncio.wait_for(ssh_client._open(), 5.0)
            client = ssh_client
            break
        except asyncssh.PermissionDenied as err:
            last_err = err
            await ssh_client.close()
        except (asyncssh.Error, OSError, asyncio.TimeoutError) as err:
            last_err = err
            await ssh_client.close()
            break  # connectivity error, no point in trying other creds
        except Exception:
            await ssh_client.close()
            raise
    if last_err and raise_errors:
        raise last_err
    return client


async def ssh_probe(host, ssh_client_cls, creds, raise_errors=False, **kwargs):
    """Try to connect to a SSH server using list of (username, password) credentials.

    Probes run as tasks (currently one, on the default port).  The first
    client returned wins; an extra client from another probe is closed, and
    still-running probes are cancelled and awaited.

    Returns:
        The connected client, or None.

    Raises:
        Any exception from a probe, and cancellation.  Before propagating,
        the remaining probes are cancelled and any client already obtained
        is closed.
    """
    # it's done this way to support multiple ports in the future
    pend = [
        asyncio.create_task(
            _ssh_probe(host, port, ssh_client_cls, creds, raise_errors, **kwargs))
        for port in (None,)
    ]
    client = None
    try:
        while pend and client is None:
            done, pend = await asyncio.wait(pend, return_when=asyncio.FIRST_COMPLETED)
            while done:
                result = done.pop().result()
                if result:
                    if client:
                        await result.close()
                    else:
                        client = result
        if pend:
            # We have a client: stop the remaining probes and let them clean up.
            for task in pend:
                task.cancel()
            await asyncio.gather(*pend, return_exceptions=True)
    except BaseException:
        # A probe failed or we were cancelled (asyncio.wait does not cancel
        # its tasks): don't leave probes running or leak a held connection.
        for task in pend:
            task.cancel()
        if client is not None:
            await client.close()
        raise
    return client
