from datetime import datetime
from functools import partial
from uuid import UUID

from google.protobuf.timestamp_pb2 import Timestamp


def _py_result(model, r, del_uuid):
    """Convert the result of a model operation to a python value,
       and update the object cache in the model.

    Args:
        model: object providing ``_types`` (message name -> python class),
            ``_objs`` (uuid -> ``(protobuf, python object)`` cache) and the
            ``cache_enabled`` flag.
        r: a single protobuf message or a list of them.
        del_uuid: uuid of the object removed by a delete operation, or a
            falsy value when the operation was not a delete.

    Returns:
        * list result: a list of python objects, ``[]`` for an empty list,
          or the list unchanged when its message type has no registered
          python class.
        * single result with a registered class: the python object.
        * delete operation: the cached python object that was evicted, or
          None when nothing was cached under ``del_uuid``.
        * message type without any fields: None.
        * anything else: ``r`` unchanged.
    """

    def field_value(pb, name):
        """Read ``name`` from ``pb``; unset optional fields become None and
        Timestamp values become float seconds."""
        try:
            value = getattr(pb, name, None) if pb.HasField(name) else None
        except ValueError:  # HasField only works for optional and submessage
            value = getattr(pb, name, None)
        if isinstance(value, Timestamp):
            value = value.ToNanoseconds() / 1000000000
        return value

    def pb2py(pb, cls):
        """convert a protobuf to a python object."""
        fields = {}
        for x in pb.DESCRIPTOR.oneofs_by_name.values():
            if not x.name.startswith('_'):  # used for optional attributes, ignore
                fields[x.name] = pb.WhichOneof(x.name)  # handle "oneof" attributes
        for f in pb.DESCRIPTOR.fields_by_name:
            fields[f] = field_value(pb, f)
        uuid = fields.get('uuid')
        if uuid and model.cache_enabled:
            # handle this as a managed object
            cur = model._objs.get(uuid)
            if cur:
                curpb, curpy = cur
                for f, v in fields.items():
                    if v != field_value(curpb, f):
                        # The value differs from the cached protobuf.  The
                        # python attribute may have been modified locally
                        # (dirty); we may want different behavior here
                        # someday, for now it is always overwritten.
                        setattr(curpy, f, v)
                curpb.CopyFrom(pb)
            else:
                curpy = cls(**fields)
                model._objs[uuid] = (pb, curpy)
            return curpy
        # no uuid, just return the object unmanaged.
        return cls(**fields)

    if isinstance(r, list):
        if not r:
            return []
        # The class is chosen from the first element and used for all of them.
        cls = model._types.get(r[0].DESCRIPTOR.name)
        if not cls:
            # No python class registered for this message type: hand the
            # protobufs back untouched, like the single-result path does.
            return r
        return [pb2py(x, cls) for x in r]

    cls = model._types.get(r.DESCRIPTOR.name)
    if cls:
        return pb2py(r, cls)
    if del_uuid:
        # this was a delete operation; delete our cached object data
        cur = model._objs.pop(del_uuid, None)
        return cur[1] if cur else None
    if len(r.DESCRIPTOR.fields) == 0:
        # This is just an empty response, we simply return None.
        return None
    return r


def _pb_request(reqtype, **kwargs):
    """Setup a GRPC request protobuf for the indicated request and arguments.

    Each keyword is looked up by name in ``reqtype.DESCRIPTOR.fields_by_name``
    (so an unknown keyword fails at that lookup) and copied into a fresh
    ``reqtype()`` instance:

    * ``None`` is skipped, except for ``Nullable*`` wrapper messages where it
      sets the wrapper's ``null`` member.
    * scalar fields are assigned directly.
    * a ``datetime`` for a ``google.protobuf.Timestamp`` field is converted
      with ``FromDatetime``.
    * a ``list`` for a message field adds one element per item, copying the
      item's fields one by one.  Items may be protobuf messages or plain
      python objects exposing the same attribute names.
    * any other value for a message field is copied with ``CopyFrom``,
      falling back to ``MergeFrom``.

    Returns:
        The populated request protobuf.
    """
    req = reqtype()
    for k, v in kwargs.items():
        f = reqtype.DESCRIPTOR.fields_by_name[k]
        nullable = bool(f.message_type
                        and f.message_type.name.startswith('Nullable'))
        if v is None and not nullable:
            continue  # allow args to be passed a None to avoid setting them.
        if not f.message_type:  # simple type
            setattr(req, k, v)
            continue
        # From here on this is a nested message, not a simple type.
        if isinstance(v, datetime
                      ) and f.message_type.full_name == 'google.protobuf.Timestamp':
            getattr(req, k).FromDatetime(v)  # accept a python datetime object
        elif nullable:
            # special case where we have a OneOf for optional values
            if v is None:
                setattr(getattr(req, k), 'null', 0)  # NullValue.NULL_VALUE
            else:
                setattr(getattr(req, k), 'data', v)
        elif isinstance(v, list):
            # A list of a nested type
            for x in v:
                lv = getattr(req, k).add()
                for nk, nt in f.message_type.fields_by_name.items():
                    nv = getattr(x, nk)
                    # Plain python objects have no HasField(); only protobuf
                    # items can be asked whether an optional field is set.
                    has_field = getattr(x, 'HasField', None)
                    if has_field is not None:
                        try:
                            if not has_field(nk):
                                continue  # skip Optional fields that are not set
                        except ValueError:
                            pass  # non-optional field
                    if nv is None:
                        continue  # unset attribute on a python object
                    if nt.message_type:
                        getattr(lv, nk).MergeFrom(nv)
                    else:
                        setattr(lv, nk, nv)
        else:
            # A plain nested type
            try:
                getattr(req, k).CopyFrom(v)
            except AttributeError:
                getattr(req, k).MergeFrom(v)
    return req


async def op(model, op, reqtype, delete=False, _pythonize=True, **kwargs):
    """Run one GRPC operation against the model from asyncio code.

    The keyword arguments are packed into a ``reqtype`` request.  When a
    company uuid is available (``company_uuid``, or ``uuid`` for requests
    whose type lives in the ``company_pb2`` module) it is sent along as
    ``company-uuid`` call metadata.  The raw result is returned when
    ``_pythonize`` is false; otherwise it is converted to python objects,
    and for a delete the ``uuid`` argument identifies the cache entry to
    drop.
    """
    req = _pb_request(reqtype, **kwargs)

    company_uuid = kwargs.get('company_uuid')
    if not company_uuid and reqtype.__module__ == 'company_pb2':
        company_uuid = kwargs.get('uuid')

    metadata = None
    if company_uuid:
        metadata = (('company-uuid', str(UUID(bytes=company_uuid))), )

    result = await model._client._grpc_op(op,
                                          req,
                                          collect_results=True,
                                          metadata=metadata)
    if not _pythonize:
        return result
    removed_uuid = kwargs.get('uuid') if delete else None
    return _py_result(model, result, removed_uuid)


def op_sync(model, op, reqtype, delete=False, **kwargs):
    """Wrapper function to perform a GRPC operation on the model. (non-asyncio)

    Args:
        model: the model whose object cache is updated with the result.
        op: callable invoked as ``op(request, metadata=...)``.
        reqtype: request protobuf class the keyword arguments are packed into.
        delete: when true, ``kwargs['uuid']`` names the cached object to drop.
        **kwargs: request field values.

    Returns:
        The pythonized result of the operation.
    """
    company_uuid = kwargs.get('company_uuid') or (
        kwargs.get('uuid') if reqtype.__module__ == 'company_pb2' else None)
    md = (('company-uuid', str(UUID(bytes=company_uuid))), ) if company_uuid else None
    req = _pb_request(reqtype, **kwargs)
    r = op(req, metadata=md)
    try:
        iterator = iter(r)  # this is to detect whether its a sequence
    except TypeError:  # not an iterable result, ok
        pass
    else:
        # Drain the stream outside the try block so that a TypeError raised
        # while consuming it is not mistaken for "not iterable".
        r = list(iterator)  # this does network i/o
    del_uuid = kwargs.get('uuid') if delete else None
    return _py_result(model, r, del_uuid)


def _extract_service_name(stubcls):
    """Strip a trailing 'ServiceStub' or 'ApiStub' from a stub class name.

    'ServiceStub' is tried first; a name with neither suffix is returned
    unchanged.
    """
    for suffix in ('ServiceStub', 'ApiStub'):
        if stubcls.endswith(suffix):
            return stubcls[:-len(suffix)]
    return stubcls


def load_api(model, wrapper, pmod, gmod):
    """Expose the GRPC service found in ``gmod`` as a class on ``model``.

    The first attribute of ``gmod`` whose name ends in 'Stub' is instantiated
    on the model's client channel; nothing happens when there is none.  A
    python class named after the service is attached to ``model`` and
    registered in ``model._types``.  For every method of every service
    described by ``pmod`` the class receives:

    * an attribute named after the request message, holding a namespace
      class that in turn carries one class per nested request type (with
      every field defaulting to None);
    * an attribute named after the method, a ``partial`` of ``wrapper``
      bound to the model, the stub method, the request type and whether the
      method is called 'Delete'.

    Output message types whose name does not end in 'Response' also get a
    python class registered in ``model._types`` unless one already exists.
    """
    stub_name = next((name for name in dir(gmod) if name.endswith('Stub')), None)
    if stub_name is None:
        return

    stub = getattr(gmod, stub_name)(model._client._channel)
    service_name = _extract_service_name(stub_name)

    # The python class standing in for this model object type.
    model_cls = type(service_name, (BaseModelObject, ), {})
    setattr(model, service_name, model_cls)
    model._types[service_name] = model_cls

    # One python callable per service method.
    for service in pmod.DESCRIPTOR.services_by_name.values():
        for method in service.methods:
            request_cls = getattr(pmod, method.input_type.name)
            request_name = request_cls.DESCRIPTOR.name
            request_ns = type(request_name, (BaseModelObject, ), {})
            setattr(model_cls, request_name, request_ns)

            call = partial(wrapper,
                           model=model,
                           op=getattr(stub, method.name),
                           delete=method.name == 'Delete',
                           reqtype=request_cls)
            setattr(model_cls, method.name, call)

            output = method.output_type
            if (output and not output.name.endswith('Response')
                    and output.name not in model._types):
                model._types[output.name] = type(output.name,
                                                 (BaseModelObject, ), {})

            # Nested request types become classes inside the namespace.
            for nested in request_cls.DESCRIPTOR.nested_types:
                defaults = dict.fromkeys(nested.fields_by_name)
                nested_cls = type(nested.name, (BaseModelObject, ), defaults)
                setattr(request_ns, nested.name, nested_cls)


# XXX once we require Python 3.7, this should be a dataclass.
class BaseModelObject:
    """Base class for all pythonized model objects.

    Instances are plain attribute bags: every keyword passed to the
    constructor becomes an instance attribute.  Two objects are equal when
    they are of exactly the same class and hold equal attributes.
    """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __repr__(self):
        def shown(name, value):
            # Render 16-byte values of '*uuid' attributes in the canonical
            # text form; anything else is shown as-is so repr() never raises.
            if (name.endswith('uuid') and isinstance(value, (bytes, bytearray))
                    and len(value) == 16):
                return str(UUID(bytes=bytes(value)))
            return value

        attrs = {x: shown(x, y) for x, y in self.__dict__.items()}
        return f'{type(self).__name__}({attrs})'

    def __eq__(self, other):
        if other.__class__ is not self.__class__:
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __hash__(self):
        # Hash the attributes themselves (in name order) so that objects
        # comparing equal also hash equal.
        items = sorted(self.__dict__.items(), key=lambda item: item[0])
        try:
            return hash(tuple(items))
        except TypeError:
            # Some value is unhashable (e.g. a list); the attribute names
            # alone still give a hash that is consistent with __eq__.
            return hash(tuple(name for name, _ in items))

    def replace(self, **kwargs):
        """Return a copy of this object with some fields updated."""
        args = self.__dict__.copy()
        args.update(kwargs)
        return self.__class__(**args)
