"""
Mikrotik API Client.
"""
import asyncio
from binascii import hexlify, unhexlify
from collections import namedtuple
from hashlib import md5
import itertools
import logging
import ssl
import struct

class MikrotikApiError(Exception):
    """Raised for errors reported by the router or a broken API connection."""

# TODO: make this unit-testable.  Let test code substitute canned data for the
# network traffic, and provide a way to capture real traffic as test data.
# Ideally test as much as possible by stubbing out only the socket read/write,
# or run a small stub server on localhost.

class MikrotikApiClient:
    """Asynchronous client for the MikroTik RouterOS API.

    The connection is opened lazily by the first request (port 8728, or 8729
    when ``enable_ssl`` is true), and :meth:`get` logs in automatically.

    :param host: router hostname or address.
    :param username: API user name.
    :param password: API password.
    :param enable_ssl: use TLS on port 8729.  Certificates and hostnames are
        NOT verified, so a man-in-the-middle would not be detected.
    :param timeout: seconds allowed for each :meth:`get` call (login included).
    :param record: when true, every :meth:`talk` exchange is appended to
        ``self.record`` as ``[sent_words, [received_word_lists...]]``.
    """
    def __init__(self, host, username, password, enable_ssl=False, timeout=30, record=False):
        self.host = host
        self._user = username
        self._pass = password
        self._timeout = timeout
        self._logged_in = False
        self._reader = None
        self._writer = None
        self._ssl = None
        self.record = [] if record else None
        if enable_ssl:
            self._ssl = ssl.create_default_context()
            # Routers typically use self-signed certificates, so verification is disabled.
            self._ssl.check_hostname = False
            self._ssl.verify_mode = ssl.CERT_NONE

    async def _get_results_data(self, cmd, include_disabled=False):
        """Run *cmd* and return its ``!re`` rows as a list of ``Row`` namedtuples.

        Rows whose ``disabled`` attribute is ``'true'`` are dropped unless
        *include_disabled* is set.  Field names are the attribute names with
        ``.`` removed and ``-`` replaced by ``_`` (``.id`` becomes ``id``).
        Any name that is still not a valid identifier (e.g. a keyword or a
        duplicate) is replaced by a positional name such as ``_1``.  Fields
        absent from a row are ``None``.  A command the router does not know
        ("no such command prefix") yields ``[]``.

        :raises RuntimeError: if a reply other than ``!re``/``!done`` arrives.
        """
        if not self._logged_in:
            await self.login()
        try:
            sentences = await self.talk([cmd])
        except MikrotikApiError as err:
            if str(err) == 'no such command prefix':
                return []
            raise
        rows = []
        for reply, attrs in sentences:
            if reply == '!done':
                break
            if reply != '!re':
                logging.warning('unexpected result for %s: %s', self.host, (reply, attrs))
                raise RuntimeError('unexpected result')
            # Attribute keys look like '=name' / '=.id'; drop the leading '='.
            row = {key[1:]: value for key, value in attrs.items()}
            if include_disabled or row.get('disabled') != 'true':
                rows.append(row)
        if not rows:
            return []
        # Rows can carry different attribute sets; use the union as the field list.
        keys = sorted(set(itertools.chain.from_iterable(rows)))
        row_type = namedtuple(
            'Row', [k.replace('.', '').replace('-', '_') for k in keys], rename=True)
        return [row_type(*(row.get(k) for k in keys)) for row in rows]

    async def _get_results(self, cmd, include_disabled=False):
        """Run :meth:`_get_results_data` under ``self._timeout``.

        On cancellation, timeout or any other error, the failure is logged,
        the connection is closed (forcing a fresh login next time), and the
        exception is re-raised.
        """
        try:
            return await asyncio.wait_for(self._get_results_data(cmd, include_disabled), timeout=self._timeout)
        except asyncio.CancelledError:
            logging.info('Cancelled "%s" operation for %s', cmd, self.host)
            await self.close()
            raise
        except asyncio.TimeoutError:
            logging.info('Operation "%s" to host %s timed out', cmd, self.host)
            await self.close()
            raise
        except Exception as err:
            logging.info('Error calling "%s" on host %s: %s (%s)', cmd, self.host, type(err).__name__, err)
            await self.close()
            raise

    async def connect(self):
        """Open the connection: TLS on port 8729 if SSL is enabled, else TCP 8728."""
        port = 8729 if self._ssl else 8728
        try:
            self._reader, self._writer = await asyncio.open_connection(self.host, port, ssl=self._ssl)
        except Exception as err:  # eg. TimeoutError
            logging.warning('Failed to connect to Mikrotik API at %s:%s (%s)', self.host, port, type(err).__name__)
            raise

    async def close(self):
        """Close the connection, if any, and reset the login state (best effort).

        The writer is closed but ``wait_closed()`` is not awaited.
        """
        if self._writer:
            try:
                self._writer.close()
            except Exception as err:  # the transport may already be broken
                logging.debug('Error closing connection to %s: %s', self.host, err)
            self._writer = self._reader = None
            self._logged_in = False

    def encode_word(self, w):
        """Encode one API word as a variable-length length prefix plus the bytes.

        ``str`` words are encoded as ISO-8859-1; ``bytes`` are sent as-is.
        """
        if isinstance(w, str):
            w = w.encode('iso-8859-1')
        length = len(w)
        if length < 0x80:
            prefix = struct.pack('>B', length)
        elif length < 0x4000:
            prefix = struct.pack('>H', length | 0x8000)
        elif length < 0x200000:
            prefix = struct.pack('>I', length | 0xC00000)[1:]
        elif length < 0x10000000:
            prefix = struct.pack('>I', length | 0xE0000000)
        else:
            # The RouterOS protocol marks a 4-byte length with a 0xF0 lead byte.
            prefix = struct.pack('>BI', 0xF0, length)
        return prefix + w

    async def write(self, words):
        """Send one sentence: each word encoded, followed by an empty terminator word.

        :returns: the number of words sent (the terminator is not counted).
        """
        payload = b''.join(self.encode_word(w) for w in words) + b'\0'
        self._writer.write(payload)
        # Respect transport flow control instead of buffering without bound.
        await self._writer.drain()
        return len(words)

    async def login(self):
        """Log in to the Mikrotik API.

        Name and password are sent first.  If the router instead replies with
        an ``=ret`` challenge (legacy scheme), the MD5 response is sent.
        """
        if self._logged_in:
            return
        sentences = await self.talk(["/login", f'=name={self._user}', f'=password={self._pass}'])
        if sentences:
            _reply, attrs = sentences[0]
            if '=ret' in attrs:  # old style challenge
                challenge = unhexlify(attrs['=ret'])
                # NOTE(review): password hashed as UTF-8 here but words are sent
                # as ISO-8859-1; these differ for non-ASCII passwords — confirm.
                digest = md5(b'\x00' + self._pass.encode() + challenge).hexdigest()
                await self.talk(["/login", f'=name={self._user}', "=response=00" + digest])
        self._logged_in = True

    async def talk(self, words):
        """Send a sentence and collect reply sentences up to and including ``!done``.

        Connects first if needed.

        :returns: a list of ``(reply_word, attrs)`` tuples.  *attrs* maps the
            part of each word before its first ``=`` after position 0
            (e.g. ``'=name'``) to the text after it.  A word without such an
            ``=`` maps to ``''``.  Returns ``None`` if *words* is empty.
        :raises MikrotikApiError: for a ``!trap`` (raised once ``!done``
            arrives, with the trap's ``=message``) or a ``!fatal`` reply.
        """
        if not self._reader:
            await self.connect()
        if await self.write(words) == 0:
            return None
        sentences = []
        err = None
        record = None if self.record is None else [words, []]
        if record is not None:
            self.record.append(record)
        while True:
            sentence = await self.read()
            if record is not None:
                record[1].append(sentence)
            if not sentence:
                continue
            reply = sentence[0]
            attrs = {}
            for word in sentence[1:]:
                # Start at 1 so the leading '=' of '=key=value' is skipped.
                j = word.find('=', 1)
                if j == -1:
                    attrs[word] = ''
                else:
                    attrs[word[:j]] = word[j + 1:]
            sentences.append((reply, attrs))
            if reply == '!done':
                if err:
                    raise err
                return sentences
            if reply == '!trap':  # wait for !done before raising
                err = MikrotikApiError(attrs.get('=message'))
            elif reply == '!fatal':
                raise MikrotikApiError(next(iter(attrs), 'fatal error'))

    async def read(self):
        """Read one sentence: words up to the empty terminator word."""
        words = []
        while True:
            word = await self.read_word()
            if word == '':
                return words
            words.append(word)

    # Lead-byte masks: a prefix of N one-bits means N further length bytes follow.
    _masks = [0x80, 0xc0, 0xe0, 0xf0]

    async def _read_exactly(self, n):
        """Read exactly *n* bytes; raise MikrotikApiError if the peer closes first."""
        try:
            return await self._reader.readexactly(n)
        except asyncio.IncompleteReadError:
            raise MikrotikApiError('Connection closed by remote end') from None

    async def read_word(self):
        """Read one length-prefixed word from the server, decoded as ISO-8859-1.

        :raises MikrotikApiError: if the connection closes mid-word.
        """
        first = (await self._read_exactly(1))[0]
        for extra, mask in enumerate(self._masks):
            if first & mask != mask:
                # Low bits of the lead byte are the high bits of the length.
                prefix = bytes([first & ~mask & 0xFF])
                break
        else:
            # 0xF0-style lead byte: the full length is in the next 4 bytes.
            extra, prefix = 4, b''
        raw_length = prefix + (await self._read_exactly(extra) if extra else b'')
        length = int.from_bytes(raw_length, 'big')
        return (await self._read_exactly(length)).decode('iso-8859-1')

    async def get(self, cmd, include_disabled=False):
        """Run a command and get the results as a list of Row namedtuples.

        :param include_disabled: also return rows whose ``disabled`` is ``'true'``.
        """
        return await self._get_results(cmd, include_disabled)
