"""
Mikrotik API Client.
https://wiki.mikrotik.com/wiki/Manual:API
"""
import asyncio
from binascii import hexlify, unhexlify
from collections import namedtuple
from hashlib import md5
import itertools
from keyword import iskeyword
import logging
import ssl
import struct


class MikrotikApiError(Exception):
    """Error reported by the RouterOS API (``!trap``/``!fatal``) or a dropped API connection."""


# TODO: make this client unit testable.  Let test code substitute canned test
# data, and add a way to write recorded data out as test data.
# Ideally test as much as possible by stubbing out only the socket read/write,
# or by running a small stub server on localhost.


class MikrotikApiClient:
    """Query the Mikrotik API.

    Commands are sent as sentences (lists of words) and replies are read until
    ``!done``.  :meth:`get` is the high-level entry point and returns rows as
    namedtuples.
    """
    # Shared by every SSL-enabled instance; created lazily by the first one.
    _ssl_context = None

    def __init__(self,
                 host,
                 username,
                 password,
                 enable_ssl=False,
                 port=None,
                 timeout=30,
                 ignore_permissions=False,
                 ignore_internal_errors=False,
                 record=False):
        """Configure the client; no connection is opened until first use.

        Args:
            host: router address passed to ``asyncio.open_connection``.
            username: API login name.
            password: API login password.
            enable_ssl: connect over TLS.  The shared SSL context does not
                check hostnames or certificates and allows all ciphers.
            port: TCP port; defaults to 8729 with SSL, otherwise 8728.
            timeout: seconds allowed for each attempt made by :meth:`get`.
            ignore_permissions: return ``[]`` instead of raising when the
                router reports "not enough permissions".
            ignore_internal_errors: return ``[]`` instead of raising on errors
                mentioning "contact MikroTik support".
            record: when true, ``self.record`` collects ``[words, sentences]``
                for every sentence sent by :meth:`talk`.
        """
        self.host = host
        self.port = port or (8729 if enable_ssl else 8728)
        self._user = username
        self._pass = password
        self._timeout = timeout
        self._ignore_internal_errors = ignore_internal_errors
        self._ignore_permissions = ignore_permissions
        self._logged_in = False
        self._reader = None
        self._writer = None
        self._ssl = None
        self.record = [] if record else None

        if enable_ssl:
            if not MikrotikApiClient._ssl_context:
                # PPA-1362 Change the way sslctx is configured
                sslctx = MikrotikApiClient._ssl_context = ssl.create_default_context()
                sslctx.check_hostname = False
                sslctx.verify_mode = ssl.CERT_NONE
                sslctx.set_ciphers('ALL:@SECLEVEL=0')
            self._ssl = MikrotikApiClient._ssl_context

        # map to filter special chars in api keys: '-' -> '_', drop '.' and '/'
        self.trans_map = str.maketrans('-', '_', './')
        self.error = None  # this will be set if we hit an error we can continue from

    def _field_name(self, key):
        """Turn an API attribute name into a candidate namedtuple field name."""
        name = key.translate(self.trans_map)
        # Field names cannot start with a digit; key[:1] also copes with ''.
        if key[:1].isdigit():
            name = f'x_{name}'
        return name.upper() if iskeyword(name) else name

    async def _get_results_data(self,
                                cmd,
                                include_disabled=False,
                                props=None,
                                query=None):
        """Run a command and return its ``!re`` replies as ``Row`` namedtuples.

        Args:
            cmd: API command word (first word of the sentence).
            include_disabled: keep rows whose ``disabled`` attribute is ``'true'``.
            props: comma-separated property names, sent as ``=.proplist=``.
            query: iterable of extra query words appended to the sentence.

        Returns:
            A list of ``Row`` namedtuples.  Their fields are the sorted union of
            attribute names across all rows, and missing values are ``None``.
            Returns ``[]`` when there are no rows or the router answers
            "no such command prefix".  Names that are still not valid or unique
            fields after translation are renamed positionally (``_0``, ``_1``…).

        Raises:
            MikrotikApiError: the router reported any other error.
            RuntimeError: a reply other than ``!re``/``!empty``/``!done`` arrived.
        """
        if not self._logged_in:
            await self.login()
        sentence = [cmd]
        if props:  # comma-separated list of property names to return
            sentence.append(f'=.proplist={props}')
        if query:  # query words
            sentence.extend(query)
        try:
            replies = await self.talk(sentence)
        except MikrotikApiError as err:
            if str(err) == 'no such command prefix':
                return []
            raise
        rows = []
        for reply, attrs in replies:
            if reply == '!done':
                break
            if reply == '!empty':
                # Newer RouterOS versions send '!empty' when there is no data.
                continue
            if reply != '!re':
                logging.warning('unexpected result for %s: %s', self.host,
                                (reply, attrs))
                raise RuntimeError('unexpected result')
            # Attribute keys look like '=name'; drop the leading character.
            row = {key[1:]: value for key, value in attrs.items()}
            if include_disabled or row.get('disabled') != 'true':
                rows.append(row)
        if not rows:
            return []
        keys = sorted(set(itertools.chain.from_iterable(rows)))
        # rename=True avoids ValueError when two keys translate to the same
        # field name or a key is still not a valid identifier.
        row_type = namedtuple('Row', [self._field_name(k) for k in keys],
                              rename=True)
        return [row_type(*[row.get(k) for k in keys]) for row in rows]

    async def _get_results(self,
                           cmd,
                           include_disabled=False,
                           props=None,
                           query=None,
                           num_retries=1):
        """Call :meth:`_get_results_data` with a timeout, retrying on failure.

        Each attempt is bounded by ``self._timeout``.  Retry *n* first sleeps
        ``2**n`` seconds, and the connection is closed after every failed
        attempt.

        Timeouts and cancellation are re-raised without retrying.  Errors
        containing "not enough permissions" or "contact MikroTik support" are
        not retried either: they are stored in ``self.error``, then either
        re-raised or, if the matching ``ignore_*`` flag is set, turned into ``[]``.

        Raises:
            asyncio.TimeoutError: an attempt exceeded ``self._timeout``.
            Exception: the last error seen once all retries are used up.
        """
        err = None
        # Clamp so a negative count still makes one attempt (never `raise None`).
        for retry_num in range(max(num_retries, 0) + 1):
            if retry_num > 0:
                await asyncio.sleep(2**retry_num)
            try:
                return await asyncio.wait_for(
                    self._get_results_data(cmd, include_disabled, props, query),
                    timeout=self._timeout)
            except asyncio.CancelledError:
                logging.info('Cancelled "%s" operation for %s', cmd, self.host)
                await self.close()
                raise
            except asyncio.TimeoutError:
                logging.info('Operation "%s" to host %s timed out', cmd, self.host)
                await self.close()
                raise
            except Exception as e:
                if isinstance(e, MikrotikApiError):
                    if 'not enough permissions' in str(e):
                        self.error = e
                        if self._ignore_permissions:
                            logging.info(
                                'Not enough permissions for "%s" on %s, returning empty result.',
                                cmd, self.host)
                            return []  # just return empty results
                        raise  # no need to retry this
                    elif 'contact MikroTik support' in str(e):
                        self.error = e
                        if self._ignore_internal_errors:
                            logging.info(
                                '%s running command "%s" on %s, returning empty result.',
                                e, cmd, self.host)
                            return []  # just return empty results
                        raise  # no need to retry this
                logging.info('Error calling "%s" on host %s: %s (%s)', cmd, self.host,
                             type(e).__name__, e)
                err = e
                await self.close()
        logging.info('Mikrotik get "%s" to %s retried out', cmd, self.host)
        raise err

    async def connect(self):
        """Open the TCP (optionally TLS) connection to ``self.host:self.port``."""
        self._reader, self._writer = await asyncio.open_connection(self.host,
                                                                   self.port,
                                                                   ssl=self._ssl)

    async def close(self):
        """Close the connection (best effort) and mark the client logged out."""
        if self._writer:
            try:
                self._writer.close()
            except Exception:
                # Best effort: the transport may already be broken.
                pass
            self._writer = self._reader = None
            self._logged_in = False

    def encode_word(self, w):
        """Encode a word as its length prefix followed by its bytes.

        ``str`` words are encoded as ISO-8859-1; other values are used as-is.
        The length prefix uses 1-4 bytes whose leading bits mark the size, or
        ``0xF0`` followed by a 4-byte length for words of 0x10000000+ bytes.
        """
        if isinstance(w, str):
            w = w.encode('iso-8859-1')
        size = len(w)
        if size < 0x80:
            prefix = struct.pack('>B', size)
        elif size < 0x4000:
            prefix = struct.pack('>H', size | 0x8000)
        elif size < 0x200000:
            prefix = struct.pack('>I', size | 0xC00000)[1:]
        elif size < 0x10000000:
            prefix = struct.pack('>I', size | 0xE0000000)
        else:
            # The 5-byte form is marked with 0xF0; 0xF8-0xFF are reserved
            # control bytes in the API spec, so 0xFF must not be used here.
            prefix = struct.pack('>BI', 0xF0, size)
        return prefix + w

    async def write(self, words):
        """Send *words* as one sentence, terminated by an empty word.

        Returns the number of words sent, not counting the terminator.
        """
        self._writer.write(b''.join(self.encode_word(w) for w in words) + b'\0')
        # Respect transport flow control instead of letting the buffer grow.
        await self._writer.drain()
        return len(words)

    async def login(self):
        """Log in to the Mikrotik API.

        The name and password are sent in plain form first.  If the first reply
        carries a ``=ret`` challenge (old-style login), the method answers it
        with ``=response=00`` plus the hex MD5 of a zero byte, the password and
        the decoded challenge.

        Raises:
            MikrotikApiError: the router rejected the login.
        """
        if self._logged_in:
            return
        replies = await self.talk(
            ["/login", f'=name={self._user}', f'=password={self._pass}'])
        if replies:
            _, attrs = replies[0]
            challenge = attrs.get('=ret')
            if challenge:  # old style challenge:
                md = md5()
                md.update(b'\x00')
                # NOTE(review): the password is UTF-8 encoded here, but words on
                # the wire are ISO-8859-1; they differ for non-ASCII passwords.
                md.update(self._pass.encode())
                md.update(unhexlify(challenge))
                await self.talk([
                    "/login", f'=name={self._user}',
                    "=response=00" + hexlify(md.digest()).decode('ascii')
                ])
        self._logged_in = True

    async def talk(self, words):
        """Send one sentence and return its replies up to and including ``!done``.

        Each reply is a ``(reply_word, attrs)`` tuple.  *attrs* maps an
        attribute's key to its value (``'=name=value'`` gives key ``'=name'``);
        words without a second ``'='`` map to ``''``.  Returns ``None`` if
        *words* is empty.

        When recording is enabled, ``[words, raw_sentences]`` is appended to
        ``self.record``.

        Raises:
            MikrotikApiError: the router replied ``!trap`` (raised once ``!done``
                arrives) or ``!fatal``.
        """
        if not self._reader:
            await self.connect()
        if await self.write(words) == 0:
            return None
        replies = []
        trap = None
        record = None if self.record is None else [words, []]
        if record is not None:
            self.record.append(record)
        while True:
            sentence = await self.read()
            if record is not None:
                record[1].append(sentence)
            if not sentence:
                continue
            reply = sentence[0]
            attrs = {}
            for word in sentence[1:]:
                # Split on the first '=' after the leading character.
                sep = word.find('=', 1)
                if sep == -1:
                    attrs[word] = ''
                else:
                    attrs[word[:sep]] = word[sep + 1:]
            replies.append((reply, attrs))
            if reply == '!done':
                if trap:
                    raise trap
                return replies
            if reply == '!trap':  # wait for done to return
                trap = MikrotikApiError(attrs.get('=message'))
            elif reply == '!fatal':
                # The reason arrives as a bare word, so it is stored as a key.
                reason = next(iter(attrs), 'fatal error')
                # The connection is unusable after !fatal either way.
                await self.close()
                if reason == "too many commands before login":
                    # PPA-923 we want to let the caller continue in this case.
                    # NOTE(review): this reconnects and resends *words* without
                    # logging in again; if the router keeps answering this way
                    # it recurses again.
                    return await self.talk(words)
                raise MikrotikApiError(reason)

    async def read(self):
        """Read one sentence: the words before the terminating empty word."""
        sentence = []
        while True:
            word = await self.read_word()
            if not word:
                return sentence
            sentence.append(word)

    async def read_word(self):
        """Read one length-prefixed word and decode it as ISO-8859-1.

        Raises:
            MikrotikApiError: the connection closed while reading the length.
            RuntimeError: the connection closed while reading the word body.
        """
        first = await self._reader.read(1)
        if not first:  # disconnected
            await self.close()
            raise MikrotikApiError('Connection closed by remote end')
        control = first[0]
        # The leading 1-bits of the first byte say how many more length bytes
        # follow; the remaining low bits are the high bits of the length.
        if control & 0x80 == 0x00:
            length, extra = control, 0
        elif control & 0xC0 == 0x80:
            length, extra = control & 0x3F, 1
        elif control & 0xE0 == 0xC0:
            length, extra = control & 0x1F, 2
        elif control & 0xF0 == 0xE0:
            length, extra = control & 0x0F, 3
        else:
            length, extra = 0, 4
        if extra:
            # readexactly: read() may legally return fewer bytes than asked,
            # which would silently corrupt the length.
            try:
                more = await self._reader.readexactly(extra)
            except asyncio.IncompleteReadError as exc:
                await self.close()
                raise MikrotikApiError('Connection closed by remote end') from exc
            for byte in more:
                length = (length << 8) | byte
        try:
            data = await self._reader.readexactly(length)
        except asyncio.IncompleteReadError as exc:
            await self.close()
            raise RuntimeError("connection closed by remote end") from exc
        return data.decode('iso-8859-1')

    async def get(self, cmd, include_disabled=False, props=None, query=None,
                  num_retries=1):
        """Run a command and get the results as a list of Row namedtuples.

        Args:
            cmd: API command word (first word of the sentence).
            include_disabled: keep rows whose ``disabled`` attribute is ``'true'``.
            props: comma-separated property names to return.
            query: iterable of extra query words.
            num_retries: extra attempts after a retryable failure.
        """
        return await self._get_results(cmd, include_disabled, props, query,
                                       num_retries)
