"""Code shared between the API classes."""
import functools
import sys
from abc import ABC, abstractmethod
from typing import (Any, Callable, Generic, Iterator, List, Optional, Tuple, TypeVar,
                    Union, overload)

from msgpack import unpackb
if sys.version_info < (3, 8):
    from typing_extensions import Literal, Protocol
else:
    from typing import Literal, Protocol

from pynvim.compat import unicode_errors_default

# Empty: this module re-exports nothing through `from ... import *`.
__all__ = ()


# Generic type variable used by RemoteSequence, RemoteMap.get and
# decode_if_bytes.
T = TypeVar('T')
# Type of decode_if_bytes' `mode` argument: the literal True selects
# `unicode_errors_default`, any other string is passed directly as the
# codec `errors=` handler name.
TDecodeMode = Union[Literal[True], str]


class NvimError(Exception):
    """Exception representing an error reported back by Nvim.

    RemoteMap translates the 'Key not found:' and 'Invalid option name:'
    variants into KeyError via transform_keyerror.
    """


class IRemote(Protocol):
    """Structural type for objects that can issue Nvim API requests."""

    def request(self, name: str, *args: Any, **kwargs: Any) -> Any:
        """Issue the API call `name` with the given arguments.

        Must be provided by implementers; this placeholder always raises.
        """
        raise NotImplementedError()


class Remote(ABC):

    """Base class for Nvim objects (buffer/window/tabpage).

    Each kind of object has its own specialized subclass that layers API
    wrappers over the msgpack-rpc session. Equality and hashing are based on
    the remote object's `code_data`, so two wrappers for the same remote
    object compare equal.
    """

    def __init__(self, session: IRemote, code_data: Tuple[int, Any]):
        """Initialize from a session and an immutable `code_data` value.

        `code_data` holds the serialization information required for
        msgpack-rpc calls; its second element is unpacked with msgpack to
        obtain `handle`. It must be immutable for Buffer equality to work.
        """
        self._session = session
        self.code_data = code_data
        self.handle = unpackb(code_data[1])
        prefix = self._api_prefix
        self.api = RemoteApi(self, prefix)
        self.vars = RemoteMap(
            self,
            prefix + 'get_var',
            prefix + 'set_var',
            prefix + 'del_var',
        )
        self.options = RemoteMap(
            self,
            prefix + 'get_option',
            prefix + 'set_option',
        )

    @property
    @abstractmethod
    def _api_prefix(self) -> str:
        """API function-name prefix, e.g. combined with 'get_var'."""
        raise NotImplementedError()

    def __repr__(self) -> str:
        """Get text representation of the object."""
        return f'<{type(self).__name__}(handle={self.handle!r})>'

    def __eq__(self, other: Any) -> bool:
        """Return True if `other` wraps the same remote object as `self`."""
        if not hasattr(other, 'code_data'):
            return False
        return other.code_data == self.code_data

    def __hash__(self) -> int:
        """Return a hash derived from `code_data`, consistent with __eq__."""
        return hash(self.code_data)

    def request(self, name: str, *args: Any, **kwargs: Any) -> Any:
        """Forward a request to the session, passing `self` as first arg."""
        return self._session.request(name, self, *args, **kwargs)


class RemoteApi:
    """Expose Nvim API functions as attribute-style Python callables.

    Looking up ``api.foo`` yields a callable that sends the request named
    ``<api_prefix>foo`` through the wrapped object.
    """

    def __init__(self, obj: IRemote, api_prefix: str):
        """Store the request-capable object and the API name prefix."""
        self._obj = obj
        self._api_prefix = api_prefix

    def __getattr__(self, name: str) -> Callable[..., Any]:
        """Return a callable bound to the prefixed API method `name`."""
        send = self._obj.request
        full_name = self._api_prefix + name
        return functools.partial(send, full_name)


# Exception type variable for transform_keyerror, which returns its argument
# unchanged when no KeyError translation applies.
E = TypeVar('E', bound=Exception)


def transform_keyerror(exc: E) -> Union[E, KeyError]:
    """Map Nvim "missing key/option" errors onto KeyError.

    Args:
        exc: The exception to inspect.

    Returns:
        A new KeyError carrying the original message when `exc` is an
        NvimError whose first argument is a string starting with
        'Key not found:' or 'Invalid option name:'. Otherwise `exc` itself
        is returned unchanged.
    """
    if isinstance(exc, NvimError) and exc.args:
        message = exc.args[0]
        # NvimError accepts arbitrary (or no) arguments. Check the type
        # before calling str.startswith: raising IndexError/TypeError here
        # would happen while the caller is already handling `exc` and would
        # mask the original error.
        if isinstance(message, str) and message.startswith(
                ('Key not found:', 'Invalid option name:')):
            return KeyError(message)
    return exc


class RemoteMap:
    """Represents a string->object map stored in Nvim.

    This is the dict counterpart to the `RemoteSequence` class, but it is used
    as a generic way of retrieving values from the various map-like data
    structures present in Nvim.

    It is used to provide a dict-like API to vim variables and options.
    Assignment and deletion are only available when the corresponding API
    method names were supplied to the constructor. Otherwise they raise
    TypeError.
    """

    # Bound request callables for the optional setter/deleter. They stay None
    # (that operation is unsupported) unless a method name is supplied.
    _set: Optional[Callable[..., Any]] = None
    _del: Optional[Callable[..., Any]] = None

    def __init__(
        self,
        obj: IRemote,
        get_method: str,
        set_method: Optional[str] = None,
        del_method: Optional[str] = None
    ):
        """Initialize a RemoteMap from a request-capable object and API names.

        Args:
            obj: Object whose ``request`` method performs the API calls.
            get_method: API function used to read a key.
            set_method: Optional API function used to assign a key.
            del_method: Optional API function used to delete a key.
        """
        self._get = functools.partial(obj.request, get_method)
        if set_method:
            self._set = functools.partial(obj.request, set_method)
        if del_method:
            self._del = functools.partial(obj.request, del_method)

    def __getitem__(self, key: str) -> Any:
        """Return a map value by key.

        Raises:
            KeyError: Nvim reported the key as not found or the option name
                as invalid. Other NvimErrors propagate unchanged.
        """
        try:
            return self._get(key)
        except NvimError as exc:
            raise transform_keyerror(exc)

    def __setitem__(self, key: str, value: Any) -> None:
        """Set a map value by key (if the setter was provided).

        Raises:
            TypeError: no set method was configured for this map.
        """
        if not self._set:
            raise TypeError('This dict is read-only')
        self._set(key, value)

    def __delitem__(self, key: str) -> None:
        """Delete a map value by key (if the deleter was provided).

        Calls the configured delete API function with `key` and discards
        its result.

        Raises:
            TypeError: no delete method was configured for this map.
            KeyError: Nvim reported the key as not found or the option name
                as invalid. Other NvimErrors propagate unchanged.
        """
        if not self._del:
            raise TypeError('This dict is read-only')
        try:
            self._del(key)
        except NvimError as exc:
            raise transform_keyerror(exc)

    def __contains__(self, key: str) -> bool:
        """Check if key is present in the map by attempting to read it."""
        try:
            self._get(key)
            return True
        # NOTE(review): any exception, not only NvimError, is reported as
        # "absent" here; consider narrowing if other failures should surface.
        except Exception:
            return False

    @overload
    def get(self, key: str, default: T) -> T: ...

    @overload
    def get(self, key: str, default: Optional[T] = None) -> Optional[T]: ...

    def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
        """Return value for key if present, else `default`.

        Only KeyError (including Nvim missing-key errors translated by
        __getitem__) falls back to `default`; other errors propagate.
        """
        try:
            return self.__getitem__(key)
        except KeyError:
            return default


class RemoteSequence(Generic[T]):

    """Represents a sequence of objects stored in Nvim.

    This class wraps msgpack-rpc functions that return Nvim sequences (of
    lines, buffers, windows and tabpages) with an API that is similar to the
    one provided by the python-vim interface.

    For example, the 'windows' property of the `Nvim` class is a
    RemoteSequence instance, and the expression `nvim.windows[0]` is
    translated to session.request('nvim_list_wins')[0].

    Every method fetches the whole sequence with one request and then performs
    the manipulation (iteration, indexing, length, membership) locally on the
    fetched result.
    """

    def __init__(self, session: IRemote, method: str):
        """Initialize a RemoteSequence that fetches via `session.request(method)`."""
        self._fetch = functools.partial(session.request, method)

    def __len__(self) -> int:
        """Return the length of the remote sequence."""
        return len(self._fetch())

    @overload
    def __getitem__(self, idx: int) -> T: ...

    @overload
    def __getitem__(self, idx: slice) -> List[T]: ...

    def __getitem__(self, idx: Union[slice, int]) -> Union[T, List[T]]:
        """Return the item at `idx`, or a list of items for a slice.

        The slice object is applied unchanged, so start, stop and step are all
        honored (e.g. ``seq[::-1]`` reverses, ``seq[::2]`` takes every other).
        """
        return self._fetch()[idx]

    def __iter__(self) -> Iterator[T]:
        """Iterate over the sequence, fetched when iteration begins."""
        yield from self._fetch()

    def __contains__(self, item: T) -> bool:
        """Check if an item is present in the sequence."""
        return item in self._fetch()


@overload
def decode_if_bytes(obj: bytes, mode: TDecodeMode = True) -> str: ...


@overload
def decode_if_bytes(obj: T, mode: TDecodeMode = True) -> Union[T, str]: ...


def decode_if_bytes(obj: T, mode: TDecodeMode = True) -> Union[T, str]:
    """Return `obj` decoded from UTF-8 if it is bytes, otherwise `obj` as-is.

    `mode` is the codec error handler passed as ``errors=``. The value True
    selects ``unicode_errors_default``.
    """
    if not isinstance(obj, bytes):
        return obj
    errors = unicode_errors_default if mode is True else mode
    return obj.decode("utf-8", errors=errors)


def walk(fn: Callable[..., Any], obj: Any, *args: Any, **kwargs: Any) -> Any:
    """Recursively walk an object graph, applying `fn` to every leaf.

    Exact ``list`` and ``tuple`` instances are rebuilt as lists. Exact ``dict``
    instances are rebuilt with both keys and values walked. Any other object,
    including subclasses of those types, is a leaf and is replaced by
    ``fn(obj, *args, **kwargs)``.

    Args:
        fn: Callable applied to each leaf.
        obj: Root of the object graph.
        *args: Extra positional arguments forwarded to `fn` at every depth.
        **kwargs: Extra keyword arguments forwarded to `fn` at every depth.

    Returns:
        The transformed object graph.
    """
    obj_type = type(obj)
    if obj_type is list or obj_type is tuple:
        return [walk(fn, item, *args, **kwargs) for item in obj]
    if obj_type is dict:
        # Build from (key, value) pairs so each key is walked before its value
        # on every supported Python version.
        return dict((walk(fn, key, *args, **kwargs),
                     walk(fn, value, *args, **kwargs))
                    for key, value in obj.items())
    return fn(obj, *args, **kwargs)
