"""
This module provides helper classes for common HTTP client operations.
"""
import asyncio
from collections import namedtuple
import logging
import ssl

import aiohttp
import httpx
import vcr
import yaml

# aiohttp 3.8 uses charset_normalizer; raise that logger's threshold to ERROR
# so its lower-level messages are not emitted.
logging.getLogger(name="charset_normalizer").setLevel(level=logging.ERROR)

class HttpClient(object):
    """Helper class to simplify HTTP operations on an open session.

    HTTP/1.x requests go through an aiohttp ``ClientSession`` (:meth:`request`).
    HTTP/2 requests go through an ``httpx.AsyncClient`` created on first use
    (:meth:`http2_request`, :meth:`http2_stream`).  While a vcr cassette is
    active (see :meth:`record`), HTTP/2 transactions are also appended to
    ``self.http2_record``.
    """
    def __init__(self, url, verify_ssl=True, username=None, password=None,
                 auth=None, cert=None, key=None, user_agent=None,
                 connection_limit=None, conn_timeout=None, read_timeout=None,
                 record=None):
        """Initialize the session pool and connector objects.

        :param url: base URL; request paths are appended to it verbatim.
        :param verify_ssl: when False, use an unverified SSL context that also
            allows TLSv1 and SECLEVEL=1 ciphers, plus an
            ``aiohttp.CookieJar(unsafe=True)``.
        :param username: if given, replaces ``auth`` with HTTP basic auth.
        :param password: basic-auth password used together with ``username``.
        :param auth: default aiohttp auth object for :meth:`request`.
        :param cert: client certificate file (only used when verify_ssl is False).
        :param key: client key file (only used when verify_ssl is False).
        :param user_agent: User-Agent for :meth:`request`
            (default 'preseem-netmeta-agent').
        :param connection_limit: passed as aiohttp ``limit_per_host``.
        :param conn_timeout: connect timeout in seconds (default 30).
        :param read_timeout: socket read timeout in seconds (default 300).
        :param record: if set, record HTTP transactions to this vcr cassette.

        We found enable_cleanup_closed reduced memory leaks (STM-3012).
        """
        work_around_aiohttp_3535()
        timeout = aiohttp.ClientTimeout(connect=conn_timeout or 30,
                                        sock_read=read_timeout or 300)
        if verify_ssl:
            # NOTE(review): cert/key are ignored on this path, so no client
            # certificate is presented when verify_ssl is True -- confirm intent.
            conn = aiohttp.TCPConnector(limit_per_host=connection_limit,
                                        enable_cleanup_closed=True)
            self._session = aiohttp.ClientSession(connector=conn, timeout=timeout)
        else:
            ssl_ctx = ssl._create_unverified_context()
            try:
                try:
                    ssl_ctx.minimum_version = ssl.TLSVersion.TLSv1  # STM-6823
                except AttributeError:
                    pass  # Python 3.6
                ssl_ctx.set_ciphers('DEFAULT@SECLEVEL=1')  # STM-4531
            except ssl.SSLError:
                pass  # some platforms don't allow this, ok
            if cert and key:
                ssl_ctx.load_cert_chain(cert, key)
            conn = aiohttp.TCPConnector(ssl_context=ssl_ctx,
                                        limit_per_host=connection_limit,
                                        enable_cleanup_closed=True)
            self._session = aiohttp.ClientSession(connector=conn, timeout=timeout,
                                                  cookie_jar=aiohttp.CookieJar(unsafe=True))
        self._url = url
        self.auth = auth
        if username:
            # Explicit credentials take precedence over an ``auth`` object.
            if password:
                self.auth = aiohttp.BasicAuth(username, password)
            else:
                self.auth = aiohttp.BasicAuth(username)
        self.user_agent = user_agent or 'preseem-netmeta-agent'

        self.vcr_cassette = None
        self.http2_client = None
        self.http2_record = None
        if record:  # record all http transactions to the provided name
            self.record(record)

    def record(self, name):
        """Start recording HTTP transactions to the vcr cassette ``name``.

        The cassette library dir is /tmp and record_mode is 'all'.  Does
        nothing if a cassette is already active; :meth:`close` finishes it.
        """
        # Compare against None: a vcr Cassette defines __len__, so an open but
        # still-empty cassette is falsy and must not be re-entered.
        if self.vcr_cassette is None:
            vcr_log = logging.getLogger("vcr")
            vcr_log.setLevel(logging.WARNING)
            self.vcr = vcr.VCR(record_mode='all', cassette_library_dir='/tmp')
            self.vcr_cm = self.vcr.use_cassette(name)
            self.vcr_cassette = self.vcr_cm.__enter__()

    @staticmethod
    def _http2_method(kwargs):
        """Infer the httpx method: POST if a body (data/json/content) is given, else GET."""
        if kwargs.get('data') or kwargs.get('json') or kwargs.get('content'):
            return 'POST'
        return 'GET'

    def _get_http2_client(self):
        """Return the shared HTTP/2 httpx client, creating it (verify=False) on first use."""
        if not self.http2_client:
            self.http2_client = httpx.AsyncClient(http2=True, verify=False)
        return self.http2_client

    def _append_http2_record(self, entry):
        """Append one HTTP/2 transaction to ``self.http2_record``, creating the list if needed."""
        if self.http2_record is None:
            self.http2_record = []
        self.http2_record.append(entry)

    async def http2_request(self, path, num_retries=5, **kwargs):
        """Make a http2 request.  Returns the httpx response object.

        Timeouts and ``httpx.RequestError`` are retried up to ``num_retries``
        more times, sleeping 2**n seconds before retry n.  Any HTTP status is
        returned to the caller unchanged.

        :raises RuntimeError: when every attempt fails.
        """
        client = self._get_http2_client()
        url = '{}{}'.format(self._url, path)
        method = self._http2_method(kwargs)
        err = None
        for retry_num in range(num_retries + 1):
            if retry_num > 0:
                await asyncio.sleep(2**retry_num)
                if retry_num > 1:
                    # don't log on the first retry, let it fail once silently
                    logging.warning('Retry request to {}, last error {}'.format(url, err))
            try:
                response = await client.request(method, url, **kwargs)
            except asyncio.CancelledError:
                raise
            except asyncio.TimeoutError:
                err = 'TimeoutError'
                continue
            except httpx.RequestError as e:
                err = e
                continue
            if self.vcr_cassette is not None:
                self._append_http2_record({
                    'request': {
                        'url': url,
                        **kwargs
                    },
                    'response': {
                        'status_code': response.status_code,
                        'reason_phrase': response.reason_phrase,
                        'http_version': response.http_version,
                        'url': str(response.url),
                        'headers': dict(response.headers),
                        'content': response.content,
                    }
                })
            return response
        raise RuntimeError(f'Error fetching {url}: {err}')

    async def http2_stream(self, path, **kwargs):
        """Make a http2 stream request.  Returns the httpx response object.

        The response is sent with ``stream=True``.  It gets a ``preseem_client``
        attribute pointing back to this client, so callers can feed chunks to
        :meth:`record_http2_stream_data`.  No retries are attempted.
        """
        url = '{}{}'.format(self._url, path)
        method = self._http2_method(kwargs)
        client = self._get_http2_client()
        req = client.build_request(method, url, **kwargs)
        r = await client.send(req, stream=True)
        r.preseem_client = self
        if self.vcr_cassette is not None:  # record this session
            self._append_http2_record({
                'request': {
                    'url': url,
                    **kwargs
                },
                'response': {
                    'status_code': r.status_code,
                    'reason_phrase': r.reason_phrase,
                    'http_version': r.http_version,
                    'url': str(r.url),
                    'headers': dict(r.headers),
                },
                'stream': []
            })
        return r

    def record_http2_stream_data(self, r, data):
        """Helper to facilitate recording data for http streams.  This is
           needed because I haven't found a good way to do it transparently.

        Appends ``data`` to the most recent recorded stream, but only if that
        record's request URL matches ``r.url``.  This is a no-op when not
        recording or when nothing has been recorded yet.
        """
        if self.vcr_cassette is not None and self.http2_record:
            cur = self.http2_record[-1]
            if 'stream' in cur and cur['request']['url'] == r.url:
                cur['stream'].append(data)

    async def close(self):
        """Close the HTTP/2 client, finish vcr recording and close the aiohttp session.

        Each step runs even if an earlier one raises, so nothing is left open.
        """
        try:
            if self.http2_client:
                await self.http2_client.aclose()
        finally:
            try:
                if self.vcr_cassette is not None:
                    self.vcr_cm.__exit__(None, None, None)
                    self.vcr_cassette = None
            finally:
                if self._session:
                    await self._session.close()

    async def request(self, path, num_retries=5, handler=None, response_encoding=None, **kwargs):
        """Post a request to the session and return the response body.

        :param path: appended to the base URL.
        :param num_retries: extra attempts after the first, sleeping 2**n
            seconds before retry n.  Timeouts, connection errors and 5xx
            responses are retried.
        :param handler: optional coroutine function called with the aiohttp
            response.  Its result is returned as-is, with no status check.
        :param response_encoding: encoding passed to ``json()``/``text()``.
        :param kwargs: passed to ``ClientSession.request``.  ``headers`` are
            merged over the default User-Agent.  ``method`` overrides the
            inferred method (POST if ``data``/``json`` is given, else GET).
            ``auth`` defaults to ``self.auth``.
        :returns: decoded JSON for 2xx ``application/json`` responses,
            otherwise the response text.
        :raises aiohttp.ClientResponseError: for 4xx responses.
        :raises RuntimeError: when every attempt fails.
        """
        hdrs = {
            'user-agent': self.user_agent
        }
        hdrs.update(kwargs.pop('headers', None) or {})
        if self.auth and 'auth' not in kwargs:
            kwargs['auth'] = self.auth
        # Resolve the method once, before the loop, so every retry uses the
        # same method (popping it inside the loop lost it after attempt one).
        if 'method' in kwargs:
            method = kwargs.pop('method')
        else:
            method = 'POST' if kwargs.get('data') or kwargs.get('json') else 'GET'
        url = '{}{}'.format(self._url, path)
        err = None
        for retry_num in range(num_retries + 1):
            if retry_num > 0:
                await asyncio.sleep(2**retry_num)
                if retry_num > 1:
                    # don't log on the first retry, let it fail once silently
                    logging.warning('Retry request to {}, last error {}'.format(url, err))
            try:
                async with self._session.request(method, url, headers=hdrs, **kwargs) as r:
                    if handler:
                        # Allow a custom response handler to be passed in
                        return await handler(r)
                    if 200 <= r.status < 300:
                        read = r.json if r.content_type == 'application/json' else r.text
                        if response_encoding:
                            return await read(encoding=response_encoding)
                        return await read()
                    msg = await r.text()
                    err = 'Status code={} text={}'.format(r.status, msg)
                    if r.status >= 500:
                        continue  # server error: retry, keeping the descriptive err
                    # Raises ClientResponseError for status >= 400.  Other
                    # non-2xx statuses (e.g. 3xx) fall through and are retried.
                    r.raise_for_status()
            except asyncio.TimeoutError:
                err = 'TimeoutError'
            except aiohttp.ServerDisconnectedError:
                # This is just a regular occurrence.
                # We only log an error on the second retry to avoid this.
                err = 'ServerDisconnectedError'
            except aiohttp.ClientConnectionError:
                err = 'ClientConnectionError'
        raise RuntimeError('Error fetching {}: {}'.format(url, err))


def ignore_aiohttp_ssl_eror(loop, aiohttpversion='3.8.1'):
    """Ignore aiohttp #3535 issue with SSL data after close
    There appears to be an issue on Python 3.7 and aiohttp SSL that throws a
    ssl.SSLError fatal error (ssl.SSLError: [SSL: KRB5_S_INIT] application data
    after close notify (_ssl.c:2609)) after we are already done with the
    connection. See GitHub issue aio-libs/aiohttp#3535

    Given a loop, this sets up an exception handler that ignores this specific
    exception, but passes everything else on to the previous exception handler
    this one replaces (or to the loop's default handler if none was set).

    If the current aiohttp version is not exactly equal to aiohttpversion
    nothing is done, assuming that the next version will have this bug fixed.
    This can be disabled by setting this parameter to None
    """
    if aiohttpversion is not None and aiohttp.__version__ != aiohttpversion:
        return

    # None means the loop is using its built-in default handler.
    orig_handler = loop.get_exception_handler()

    def ignore_ssl_error(loop, context):
        if context.get('message') == 'SSL error in data received':
            # validate we have the right exception, transport and protocol
            exception = context.get('exception')
            protocol = context.get('protocol')
            if (
                isinstance(exception, ssl.SSLError) and exception.reason == 'KRB5_S_INIT' and
                isinstance(protocol, asyncio.sslproto.SSLProtocol) and
                isinstance(protocol._app_protocol, aiohttp.client_proto.ResponseHandler)
            ):
                if loop.get_debug():
                    asyncio.log.logger.debug('Ignoring aiohttp SSL KRB5_S_INIT error')
                return
        # Choose the call form up front.  The default handler takes only
        # ``context``, while custom handlers take ``(loop, context)``.  Doing
        # this instead of catching TypeError means errors raised inside a
        # custom handler are not masked by a wrong-signature retry.
        if orig_handler is None:
            loop.default_exception_handler(context)
        else:
            orig_handler(loop, context)

    loop.set_exception_handler(ignore_ssl_error)

_worked_around_aiohttp_3535 = False  # set once the #3535 handler is installed
def work_around_aiohttp_3535():
    """Install the aiohttp #3535 SSL-error filter on the current loop, once.

    On Python 3.7+ the running loop is used.  If no loop is running (e.g. the
    caller is synchronous code), nothing is installed and the next call tries
    again.  Before this, ``get_running_loop()`` raised RuntimeError out of the
    caller.  On older Pythons, ``asyncio.get_event_loop()`` is used instead.
    """
    global _worked_around_aiohttp_3535
    if _worked_around_aiohttp_3535:
        return
    get_running_loop = getattr(asyncio, 'get_running_loop', None)
    if get_running_loop is None:  # predates python 3.7
        loop = asyncio.get_event_loop()
    else:
        try:
            loop = get_running_loop()
        except RuntimeError:
            return  # no running loop yet; try again on a later call
    try:
        ignore_aiohttp_ssl_eror(loop)
    finally:
        _worked_around_aiohttp_3535 = True

def _freeze_value(value):
    """Recursively convert dicts/lists/tuples into hashable frozensets/tuples."""
    if isinstance(value, dict):
        return frozenset((k, _freeze_value(v)) for k, v in value.items())
    if isinstance(value, (list, tuple)):
        return tuple(_freeze_value(v) for v in value)
    return value

def dict2fs(d):
    """Return a hashable, key-order-insensitive key for the dict ``d``.

    Nested dicts become frozensets of their items and lists/tuples become
    tuples, recursively.  This lets request kwargs such as
    ``json={'ids': [1, 2]}`` be used as lookup keys.  The previous version
    raised TypeError on such values.
    """
    return frozenset((k, _freeze_value(v)) for k, v in d.items())

class ReplayHttpClient(HttpClient):
    """Special HTTP client that replays a capture file.  Currently http2 only,
       we use vcrpy to replay other http.

    The capture file is YAML holding a list of records.  Each record has a
    ``request`` mapping (``url`` plus request kwargs) and a ``response``
    mapping.  Stream records also have a ``stream`` list of chunks.  This is
    presumably a dump of ``HttpClient.http2_record`` -- confirm against the
    writer.
    """
    file = None  # default capture file, used when __init__ gets no ``file``
    http2_client = None
    _session = None

    def __init__(self, url, file=None):
        """Load the capture file (``file`` or the ``file`` class attribute).

        Without a capture file (or with an empty one) there are no records,
        so every replayed request raises RuntimeError.
        """
        super().__init__(url)
        file = file or self.file
        self.data = {}  # dict2fs(request) -> full record
        if file:
            with open(file) as f:
                records = yaml.safe_load(f) or []  # empty YAML loads as None
            for rec in records:
                self.data[dict2fs(rec['request'])] = rec

    async def http2_request(self, path, num_retries=None, **kwargs):
        """Replay a http2 request.  Returns a httpx-compatible response object.

        The result is a namedtuple whose fields are the recorded response
        keys.  ``num_retries`` is accepted for signature compatibility and is
        ignored.

        :raises RuntimeError: if no record matches ``url`` + ``kwargs``.
        """
        url = f'{self._url}{path}'
        key = dict2fs({'url': url, **kwargs})
        rec = self.data.get(key)
        if not rec:
            raise RuntimeError(f'No matching record for {key}')
        # return an object that will behave like a httpx.Response
        response = rec['response']
        Response = namedtuple('Response', response.keys())
        return Response(*response.values())

    class FakeStreamResponse:
        """async iterator that works like the httpx stream response."""
        def __init__(self, chunks):
            self.chunks = chunks

        async def aclose(self):
            pass

        async def aiter_bytes(self):
            for chunk in self.chunks:
                yield chunk

    async def http2_stream(self, path, **kwargs):
        """Replay a http2 stream request.  Returns a fake response object.

        The recorded response fields are set as attributes on a
        FakeStreamResponse that yields the recorded chunks.

        :raises RuntimeError: if no matching stream record exists.
        """
        url = f'{self._url}{path}'
        key = dict2fs({'url': url, **kwargs})
        rec = self.data.get(key)
        if not rec or 'stream' not in rec:
            raise RuntimeError(f'No matching stream record for {key}')
        r = self.FakeStreamResponse(rec['stream'])
        for k, v in rec['response'].items():
            setattr(r, k, v)
        return r

    def record_http2_stream_data(self, r, data):
        """Streams are replayed, not recorded, so this is a no-op."""
        pass
