#!/usr/bin/env python
"""
Tool to explore the network via SNMP and other types of polling.
"""
import asyncio
import importlib
import logging
import os, os.path
import subprocess
import sys
import time

from tabulate import tabulate
import yaml

# Default configuration file path; used as the default for --config-file.
CFG_PATH = '/etc/preseem/netmeta.yaml'
# Directory appended to the module search path before the imports below run.
# Presumably this is where the netpoll modules (context, ap, ne, ...) are
# installed -- TODO confirm.
NETPOLL_LIB_PATH = '/usr/lib/preseem/netpoll'
sys.path.append(NETPOLL_LIB_PATH)

from preseem.sentry import SentryClient
from context import Context, SENTRY_KEY

import ap
import ap_data
from check import check
from ne import Element, NetworkElementRegistry
from preseem import parse_ip_value, ping
from preseem import NetworkMetricsClient, NetworkMetricsModel, FlexMetric
from preseem import FakeNetworkMetricsModel, FakeNetworkMetadataModel
from preseem import NetworkMetadataClient, PreseemNetworkMetadataModel, NetworkMetadataReference
from preseem import HttpClient
from preseem import TelemetryApiClient
import version
from preseem.netmeta_config_loader import load_config

# Context objects created by main(), keyed by instance name ('cli' for the
# temporary instance defined on the command line with --api-key).
_instances = {}


def suppress_stderr():
    """PPA-1411: Redirect stderr to the null device for cleaner --check output.

    The file descriptor behind ``sys.stderr`` is replaced (via ``dup2``) with
    one opened on ``os.devnull``, so anything written to it afterwards is
    discarded.  The redirection is not undone.

    Raises:
        OSError: if the null device cannot be opened or the descriptor
            cannot be duplicated.
    """
    sys.stderr.flush()  # Flush any pending output
    # Resolve the target descriptor first so that a failure here cannot leak
    # the descriptor opened below.
    stderr_fd = sys.stderr.fileno()
    devnull = os.open(os.devnull, os.O_WRONLY)
    try:
        os.dup2(devnull, stderr_fd)
    finally:
        # The temporary descriptor is no longer needed once duplicated, and
        # must not be leaked if dup2 fails.
        os.close(devnull)


def disable_httpx_in_vcr():
    """Prevent vcrpy from importing httpx by overriding the import hook.

    This is needed because httpx is broken in vcrpy (for streaming) so we
    have implemented this ourselves, but we still want to use vcrpy to
    record regular aiohttp data.

    After this call, ``import httpx`` raises ImportError when the importing
    module's ``__package__`` is ``'vcr'``; every other import is delegated
    to ``importlib.__import__`` unchanged.  The override is process-wide and
    is not undone.
    """
    # The hook is installed through the ``builtins`` module rather than the
    # ``__builtins__`` global: the latter is a module only in ``__main__``
    # and a plain dict in imported modules, where ``__builtins__.__dict__``
    # would raise AttributeError.
    import builtins

    def importer(name, globals=None, locals=None, fromlist=(), level=0):
        if name == 'httpx' and globals and globals.get('__package__') == 'vcr':
            raise ImportError()
        return importlib.__import__(name, globals, locals, fromlist, level)

    builtins.__import__ = importer


def _show_config(cfg, requested):
    """Print configuration values for --show-config and return the exit code.

    The 'snmp' section of *cfg* is overlaid with the 'network-poller'
    section.  With no names in *requested* the merged mapping is dumped as
    YAML.  Otherwise each entry has the format ``{name}[,{default}]`` and the
    values are printed one per line (falsy values print as an empty line).

    Returns:
        0 on success; 1 when specific names were requested and none of them
        produced a truthy value.
    """
    # The update is done in place on the loaded config; the caller exits
    # immediately afterwards so the mutation is never observed.
    snmp_config = cfg.get('snmp') or {}
    snmp_config.update(cfg.get('network-poller') or {})
    if not requested:
        print(yaml.dump(snmp_config))
        return 0
    vals = []
    for spec in requested:
        parts = spec.split(',')
        default = None if len(parts) == 1 else parts[1]
        vals.append(snmp_config.get(parts[0], default))
    print('\n'.join([str(x or '') for x in vals]))
    return 0 if any(vals) else 1


async def _register_cli_elements(ctx, args):
    """Register the APs/OLTs/routers/switches named on the command line.

    Each option is a comma-separated host list.  Every host is stored in the
    matching registry of *ctx* with 'noaction' set, and unless --daemon was
    given the context is switched to job mode.
    """
    cli_elements = (
        (args.aps, 'ap_registry'),
        (args.olts, 'olt_registry'),
        (args.routers, 'router_registry'),
        (args.switches, 'switch_registry'),
    )
    for hosts, registry_attr in cli_elements:
        if not hosts:
            continue
        if not args.daemon:
            ctx.job = True
        for host in hosts.split(','):
            await getattr(ctx, registry_attr).set(host, {
                'host': host,
                'noaction': True,
                'verbose': args.verbose
            })


def _print_test_mode_diffs(inst):
    """Print the cpe_mac reference diffs and the network metadata diffs
    between ``inst.src_model`` and ``inst.netmeta_model``."""
    old_cpe_macs = inst.src_model.refs.get('cpe_mac') or {}
    new_cpe_macs = inst.netmeta_model.refs.get('cpe_mac') or {}
    diffs = []
    for mac in sorted(set(old_cpe_macs) | set(new_cpe_macs)):
        old_cpe_mac = old_cpe_macs.get(mac)
        new_cpe_mac = new_cpe_macs.get(mac)
        if old_cpe_mac == new_cpe_mac:
            continue
        diffs.append((mac, old_cpe_mac.attributes if old_cpe_mac else '',
                      new_cpe_mac.attributes if new_cpe_mac else ''))
    print('CPE MAC DIFFS')
    if diffs:
        print(tabulate(diffs, headers=('CPE MAC', 'Old', 'New')))
    print('{} diffs'.format(len(diffs)))
    print('NETWORK METADATA DIFFS')
    diffs = inst.src_model.diff(inst.netmeta_model)
    if diffs:
        print(tabulate(diffs, headers=('IP Address', 'Old', 'New')))
    print('{} diffs'.format(len(diffs)))


async def _cancel_pending_tasks():
    """Cancel every other pending task and wait for it, for a cleaner
    shutdown.  Exceptions other than CancelledError raised by a cancelled
    task propagate to the caller."""
    try:
        current = asyncio.current_task()
        tasks = asyncio.all_tasks()
    except AttributeError:  # python < 3.7
        current = asyncio.Task.current_task()
        # Older python versions can return finished tasks https://github.com/python/cpython/pull/7174
        tasks = [t for t in asyncio.Task.all_tasks() if not t.done()]
    for t in tasks:
        if t is current:
            # Never cancel/await the task that is running this coroutine.
            continue
        t.cancel()
        try:
            await t
        except asyncio.CancelledError:
            pass


async def main():
    """Parse the command line, build the polling context(s) and run them.

    Steps, in order:
      * parse arguments and load the config file (--config-file);
      * for --show-config, print the requested values and exit;
      * create one Context per configured instance, or a single 'cli'
        instance when --api-key is given, recording them in ``_instances``;
      * register any elements named with --aps/--olts/--routers/--switches;
      * start the instance, optionally run --check and print --test-mode
        diffs, then close the context and cancel leftover tasks.

    Raises:
        RuntimeError: if 'instances' in the config is not a dict, or if no
            instance is available (no --api-key and none configured).
        SystemExit: for --show-config.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Netpoll Tool')
    parser.add_argument('-a', '--aps')
    parser.add_argument('-c', '--community')
    parser.add_argument('-d', '--daemon', action='store_true')
    parser.add_argument('-f', '--config-file', default=CFG_PATH)
    parser.add_argument('-j', '--job', action='store_true')
    parser.add_argument('-k', '--api-key')
    parser.add_argument('-l', '--libpath')
    parser.add_argument('-m', '--metrics', action='store_true')
    parser.add_argument('-M', '--load-metadata', action='store_true')
    parser.add_argument('-o', '--olts')
    parser.add_argument('-r', '--routers')
    parser.add_argument('-s', '--switches')
    parser.add_argument('-t', '--test-mode', action='store_true')
    parser.add_argument('-U', '--http-username')
    parser.add_argument('-P', '--http-password')
    parser.add_argument('-n', '--noaction', action='store_true')
    parser.add_argument('--check', action='store_true')
    parser.add_argument('--make-test-data', action='store_true')
    parser.add_argument('--make-test-data-partial', action='store_true')
    parser.add_argument('--show-config', nargs='*')
    parser.add_argument('--test-data')
    parser.add_argument('--vardir', default='/var/preseem')
    parser.add_argument('-v', '--verbose', help='verbose', action='count', default=0)
    parser.add_argument('command', nargs='?')
    args = parser.parse_args()

    apdd = ap_data.ApData(args.vardir + '/ap_info.yaml')

    cfg = load_config(args.config_file)
    if args.show_config is not None:
        sys.exit(_show_config(cfg, args.show_config))
    instances = cfg.get('instances')
    snmp_cfg = cfg.get('snmp') or {}
    source_cfg = cfg.get('source')
    netpoller_cfg = cfg.get('network-poller')

    # Command-line credentials override the ones from the config file.
    if args.community:
        snmp_cfg['community'] = args.community.split(',')
    if args.http_username:
        snmp_cfg['http_user'] = args.http_username
    if args.http_password:
        snmp_cfg['http_pass'] = args.http_password

    making_test_data = args.make_test_data or args.make_test_data_partial
    if making_test_data:
        disable_httpx_in_vcr()
    if args.test_mode or making_test_data or args.check:
        args.job = True  # test mode only makes sense for a job

    def make_context(name, inst_cfg):
        # Every instance shares the same global settings; only the name and
        # the per-instance config differ.
        return Context(
            name,
            inst_cfg,
            apdd,
            snmp_cfg,
            args.job,
            args.libpath,
            netpoller_cfg,
            source_cfg,
            test_mode=args.test_mode,
            make_test_data=args.make_test_data,
            make_test_data_partial=args.make_test_data_partial,
            test_data=args.test_data,
            version=version.__version__)

    inst = None
    if args.api_key:  # temporary instance defined on the command-line
        inst_cfg = {
            'api_key': args.api_key,
            'noaction': True,
            'verbose': args.verbose
        }
        inst = _instances['cli'] = make_context('cli', inst_cfg)
    elif instances:
        if not isinstance(instances, dict):
            raise RuntimeError('Error in config: instances must be a dict')
        # Naming elements on the command line or recording test data implies
        # noaction.  Coerced to bool so the config never carries a host-list
        # string as the flag value.
        noaction = bool(args.noaction or args.aps or args.olts
                        or args.routers or args.switches or making_test_data)
        for name, inst_cfg in instances.items():
            inst_cfg['noaction'] = noaction
            inst_cfg['verbose'] = max(inst_cfg.get('verbose') or 0, args.verbose or 0)
            inst = _instances[name] = make_context(name, inst_cfg)

    if inst is None:
        # Without this guard the lookup below fails with a bare IndexError.
        raise RuntimeError(
            'No instance available: use --api-key or define instances in the config')

    ctx = list(_instances.values())[0]

    await _register_cli_elements(ctx, args)

    load_elements = not (args.aps or args.olts or args.routers
                         or args.switches or args.check)
    load_metadata = args.load_metadata or load_elements
    # NOTE(review): `inst` is the LAST instance created above while `ctx` is
    # the FIRST one.  They are the same object when a single instance is
    # configured; with several instances only the last is started while the
    # first is checked and closed -- confirm whether that is intended.
    await inst.start(load_metadata=load_metadata, load_elements=load_elements)

    if args.check:
        await check(ctx)

    if args.test_mode:
        _print_test_mode_diffs(inst)

    await ctx.close()

    # cleaner shutdown
    await _cancel_pending_tasks()
    if args.check:
        suppress_stderr()


if __name__ == '__main__':
    # Script entry point: configure logging, then run main() on an event loop.
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                        level=logging.INFO)
    # Pick the runner by feature detection.  Wrapping asyncio.run() in
    # `except AttributeError` would also catch an AttributeError raised from
    # inside main() and then execute main() a second time on the legacy path.
    if hasattr(asyncio, 'run'):
        asyncio.run(main())
    else:  # python < 3.7
        loop = asyncio.get_event_loop()
        task = loop.create_task(main())
        try:
            loop.run_until_complete(task)
        except RuntimeError as e:
            logging.error(e)
            sys.exit(1)
        except KeyboardInterrupt:
            # Request cancellation of whatever is still scheduled before the
            # loop is shut down below.
            try:
                tasks = asyncio.all_tasks()
            except AttributeError:  # python < 3.7
                tasks = asyncio.Task.all_tasks()
            for t in tasks:
                t.cancel()
        finally:
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.close()
