#!/usr/bin/env python
"""
Tool to explore the network via SNMP and other types of polling.
"""
import asyncio
import importlib
import logging
import os, os.path
import subprocess
import sys
import time

from tabulate import tabulate
import yaml

# Default config file location; can be overridden with -f/--config-file.
CFG_PATH = '/etc/preseem/netmeta.yaml'
# Netpoll library directory. It is appended to sys.path before the
# project-local imports that follow (presumably context, ap, check, ne, ...
# live here — confirm), and its requirements.txt is used for the dependency
# check in main().
NETPOLL_LIB_PATH = '/usr/lib/preseem/netpoll'
sys.path.append(NETPOLL_LIB_PATH)

from preseem.sentry import SentryClient
from preseem.util import check_required_dependency
from context import Context, SENTRY_KEY

import ap
import ap_data
from check import check
from ne import Element, NetworkElementRegistry
from preseem import parse_ip_value, ping
from preseem import NetworkMetricsClient, NetworkMetricsModel, FlexMetric
from preseem import FakeNetworkMetricsModel, FakeNetworkMetadataModel
from preseem import NetworkMetadataClient, PreseemNetworkMetadataModel, NetworkMetadataReference
from preseem import HttpClient
from preseem import TelemetryApiClient
import version

# Context objects keyed by instance name ('cli' for an --api-key run);
# populated during main().
_instances = {}


def load_config(path):
    """Load the YAML configuration file at *path* and return it as a dict.

    Raises:
        RuntimeError: if *path* is not a regular file, or if its contents
            are empty or do not parse to a mapping (callers use ``cfg.get``).
    """
    if not os.path.isfile(path):
        raise RuntimeError('config file not found: {}'.format(path))
    # Open in binary mode so PyYAML detects the encoding (UTF-8/UTF-16 via
    # BOM) itself, instead of decoding with the locale's default codec, which
    # may be ASCII under a C/POSIX locale.
    with open(path, 'rb') as f:
        cfg = yaml.safe_load(f)
    # A list/scalar document would otherwise surface later as an
    # AttributeError on cfg.get(...).
    if not cfg or not isinstance(cfg, dict):
        raise RuntimeError('invalid config file: {}'.format(path))
    return cfg


def disable_httpx_in_vcr():
    """Prevent vcrpy from loading httpx by wrapping the builtin import hook.

    httpx support in vcrpy is broken (for streaming), so that recording is
    implemented ourselves, but vcrpy is still used to record regular aiohttp
    data.  After this call, ``import httpx`` issued from a module whose
    ``__package__`` is ``'vcr'`` raises ImportError; every other import is
    forwarded unchanged to the import hook that was active before the call.
    """
    import builtins

    original_import = builtins.__import__

    def importer(name, globals=None, locals=None, fromlist=(), level=0):
        if name == 'httpx' and globals and globals.get('__package__') == 'vcr':
            raise ImportError('httpx import disabled for vcr', name=name)
        return original_import(name, globals, locals, fromlist, level)

    # Patch the builtins module itself: the ``__builtins__`` global is the
    # module only in __main__ and a plain dict in imported modules (a CPython
    # implementation detail), so ``__builtins__.__dict__`` is not portable.
    builtins.__import__ = importer


def _parse_args():
    """Parse the netpoll command line and return the argparse namespace."""
    import argparse
    parser = argparse.ArgumentParser(description='Netpoll Tool')
    parser.add_argument('-a', '--aps')
    parser.add_argument('-c', '--community')
    parser.add_argument('-d', '--daemon', action='store_true')
    parser.add_argument('-f', '--config-file', default=CFG_PATH)
    parser.add_argument('-j', '--job', action='store_true')
    parser.add_argument('-k', '--api-key')
    parser.add_argument('-l', '--libpath')
    parser.add_argument('-m', '--metrics', action='store_true')
    parser.add_argument('-M', '--load-metadata', action='store_true')
    parser.add_argument('-r', '--routers')
    parser.add_argument('-s', '--switches')
    parser.add_argument('-t', '--test-mode', action='store_true')
    parser.add_argument('-U', '--http-username')
    parser.add_argument('-P', '--http-password')
    parser.add_argument('-n', '--noaction', action='store_true')
    parser.add_argument('--check', action='store_true')
    parser.add_argument('--make-test-data', action='store_true')
    parser.add_argument('--show-config', nargs='*')
    parser.add_argument('--test-data')
    parser.add_argument('--vardir', default='/var/preseem')
    parser.add_argument('-v', '--verbose', help='verbose', action='count', default=0)
    parser.add_argument('command', nargs='?')
    return parser.parse_args()


def _show_config(cfg, names):
    """Print the merged ``snmp``/``network-poller`` config and exit the process.

    With no *names*, dump the whole merged dict as YAML and exit 0.  Otherwise
    each entry has the form ``{name}[,{default}]``; print one value per line
    (empty for missing/falsy values) and exit 0 if any value is truthy, else 1.
    """
    # Copy so the merge does not mutate the loaded cfg['snmp'] dict.
    snmp_config = dict(cfg.get('snmp') or {})
    snmp_config.update(cfg.get('network-poller') or {})
    if not names:
        print(yaml.dump(snmp_config))
        sys.exit(0)
    vals = []
    for spec in names:
        parts = spec.split(',')
        default = None if len(parts) == 1 else parts[1]
        vals.append(snmp_config.get(parts[0], default))
    print('\n'.join(str(x or '') for x in vals))
    sys.exit(0 if any(vals) else 1)


def _create_instances(args, instances, apdd, snmp_cfg, netpoller_cfg, source_cfg):
    """Create Context objects in ``_instances`` and return the server name.

    With ``--api-key`` a single temporary ``cli`` instance is created (always
    noaction) and ``'cli'`` is returned.  Otherwise one Context is created per
    entry of the ``instances`` config dict and the last entry's name is
    returned; ``''`` is returned when there are no instances at all.

    Raises:
        RuntimeError: if ``instances`` is present but is not a dict.
    """

    def make_context(name, inst_cfg):
        # Build one Context sharing the command-line/global settings.
        return Context(name,
                       inst_cfg,
                       apdd,
                       snmp_cfg,
                       args.job,
                       args.libpath,
                       netpoller_cfg,
                       source_cfg,
                       test_mode=args.test_mode,
                       make_test_data=args.make_test_data,
                       test_data=args.test_data,
                       version=version.__version__)

    if args.api_key:  # temporary instance defined on the command-line
        _instances['cli'] = make_context('cli', {
            'api_key': args.api_key,
            'noaction': True,
            'verbose': args.verbose,
        })
        return 'cli'

    server_name = ''
    if instances:
        if not isinstance(instances, dict):
            raise RuntimeError('Error in config: instances must be a dict')
        # Force noaction when explicitly requested, when polling elements
        # named on the command line, or when only recording test data.
        noaction = bool(args.noaction or args.aps or args.routers or args.switches
                        or args.make_test_data)
        for name, inst_cfg in instances.items():
            server_name = name
            inst_cfg['noaction'] = noaction
            inst_cfg['verbose'] = max(inst_cfg.get('verbose') or 0, args.verbose or 0)
            _instances[name] = make_context(name, inst_cfg)
    return server_name


def _init_sentry(netpoller_cfg, server_name):
    """Set up Sentry and report missing Python dependencies to it.

    Skipped when there is no network-poller config, when ``sentry_enabled``
    is explicitly False, or for unnamed/``cli`` runs.  Sentry is best-effort:
    a failure to create the client is logged, not raised.
    """
    if not netpoller_cfg or netpoller_cfg.get('sentry_enabled') is False:
        return
    if server_name in ('', 'cli'):
        return
    try:
        sentry_client = SentryClient(sentry_key=SENTRY_KEY, server_name=server_name)
    except Exception as e:
        # There is no client to report this through, so just log it.
        logging.warning('unable to initialize sentry client: %s', e)
        return
    try:
        # use the requirements.txt under network-poller directory
        check_required_dependency(os.path.join(NETPOLL_LIB_PATH, "requirements.txt"))
    except Exception as e:
        sentry_client.error(description=e, frequency=86400, tags={})


async def _register_cli_elements(ctx, registry_attr, hosts, args):
    """Register the comma-separated *hosts* in ``ctx.<registry_attr>``.

    Unless running as a daemon the context is switched to job mode.  Blank
    entries (e.g. from a trailing comma) are skipped and surrounding
    whitespace is stripped.
    """
    if not args.daemon:
        ctx.job = True
    registry = getattr(ctx, registry_attr)
    for host in hosts.split(','):
        host = host.strip()
        if not host:
            continue
        await registry.set(host, {
            'host': host,
            'noaction': True,
            'verbose': args.verbose
        })


def _print_test_mode_diffs(ctx):
    """Print cpe_mac reference diffs and network metadata diffs (old vs new)."""
    old_cpe_macs = ctx.src_model.refs.get('cpe_mac') or {}
    new_cpe_macs = ctx.netmeta_model.refs.get('cpe_mac') or {}
    diffs = []
    for mac in sorted(set(old_cpe_macs) | set(new_cpe_macs)):
        old_cpe_mac = old_cpe_macs.get(mac)
        new_cpe_mac = new_cpe_macs.get(mac)
        if old_cpe_mac == new_cpe_mac:
            continue
        diffs.append((mac, old_cpe_mac.attributes if old_cpe_mac else '',
                      new_cpe_mac.attributes if new_cpe_mac else ''))
    print('CPE MAC DIFFS')
    if diffs:
        print(tabulate(diffs, headers=('CPE MAC', 'Old', 'New')))
    print('{} diffs'.format(len(diffs)))
    print('NETWORK METADATA DIFFS')
    diffs = ctx.src_model.diff(ctx.netmeta_model)
    if diffs:
        print(tabulate(diffs, headers=('IP Address', 'Old', 'New')))
    print('{} diffs'.format(len(diffs)))


async def _cancel_remaining_tasks():
    """Cancel every other pending task and wait for each to finish cancelling."""
    try:
        all_tasks, current_task = asyncio.all_tasks, asyncio.current_task
    except AttributeError:  # python < 3.7: only the Task classmethods exist
        all_tasks, current_task = asyncio.Task.all_tasks, asyncio.Task.current_task
    this_task = current_task()
    # Skip the task running this coroutine (a task cannot await itself) and
    # finished tasks, which older python versions can still return
    # https://github.com/python/cpython/pull/7174
    tasks = [t for t in all_tasks() if t is not this_task and not t.done()]
    for t in tasks:
        t.cancel()
        try:
            await t
        except asyncio.CancelledError:
            pass


async def main():
    """Run the netpoll tool from the command line.

    Parses arguments and loads the config file.  ``--show-config`` prints
    config values and exits.  Otherwise it builds the Context instance(s) and
    registers any APs/routers/switches given on the command line.  It then
    starts the first instance, optionally runs ``--check`` and the
    ``--test-mode`` diff report, closes the instance and cancels any tasks
    still pending.

    Raises:
        RuntimeError: on config errors, including when neither ``--api-key``
            nor an ``instances`` config section provides an instance to run.
    """
    args = _parse_args()

    apdd = ap_data.ApData(args.vardir + '/ap_info.yaml')

    cfg = load_config(args.config_file)
    if args.show_config is not None:
        _show_config(cfg, args.show_config)  # always exits
    instances = cfg.get('instances')
    snmp_cfg = cfg.get('snmp') or {}
    source_cfg = cfg.get('source')
    netpoller_cfg = cfg.get('network-poller')

    # Command-line credentials override the config file.
    if args.community:
        snmp_cfg['community'] = args.community.split(',')
    if args.http_username:
        snmp_cfg['http_user'] = args.http_username
    if args.http_password:
        snmp_cfg['http_pass'] = args.http_password

    if args.test_mode or args.make_test_data or args.check:
        if args.make_test_data:
            disable_httpx_in_vcr()
        args.job = True  # test mode only makes sense for a job

    server_name = _create_instances(args, instances, apdd, snmp_cfg,
                                    netpoller_cfg, source_cfg)
    _init_sentry(netpoller_cfg, server_name)

    if not _instances:
        raise RuntimeError('no instances configured in {} (use --api-key for a '
                           'temporary one)'.format(args.config_file))
    # The first instance drives the whole run: command-line elements are
    # registered on it, and it is the one started, checked and closed.
    ctx = next(iter(_instances.values()))

    if args.aps:  # define APs on command line
        await _register_cli_elements(ctx, 'ap_registry', args.aps, args)
    if args.routers:  # define routers on command line
        await _register_cli_elements(ctx, 'router_registry', args.routers, args)
    if args.switches:  # define switches on command line
        await _register_cli_elements(ctx, 'switch_registry', args.switches, args)

    load_elements = not (args.aps or args.routers or args.switches or args.check)
    load_metadata = args.load_metadata or load_elements
    await ctx.start(load_metadata=load_metadata, load_elements=load_elements)

    if args.check:
        await check(ctx)

    if args.test_mode:
        _print_test_mode_diffs(ctx)

    await ctx.close()

    # cleaner shutdown
    await _cancel_remaining_tasks()


if __name__ == '__main__':
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                        level=logging.INFO)
    # Feature-detect asyncio.run instead of catching AttributeError from the
    # call: an AttributeError raised inside main() would otherwise be taken
    # for an old python and main() would be run a second time.
    if hasattr(asyncio, 'run'):  # python >= 3.7
        try:
            asyncio.run(main())
        except RuntimeError as e:  # e.g. config errors: log without traceback
            logging.error(e)
            sys.exit(1)
    else:  # python < 3.7
        loop = asyncio.get_event_loop()
        task = loop.create_task(main())
        try:
            loop.run_until_complete(task)
        except RuntimeError as e:
            logging.error(e)
            sys.exit(1)
        except KeyboardInterrupt:
            # Cancel whatever is still pending and let the cancellations be
            # processed before the loop is closed.
            pending = [t for t in asyncio.Task.all_tasks(loop) if not t.done()]
            for t in pending:
                t.cancel()
            loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
        finally:
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.close()
