###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################

import asyncio
from asyncio import Future, iscoroutine
from collections import deque
from typing import Optional

import txaio

txaio.use_asyncio()  # noqa

from autobahn.asyncio.util import create_transport_details, transport_channel_id
from autobahn.util import hltype, public
from autobahn.wamp import websocket
from autobahn.websocket import protocol

# Names exported by a star-import of this module. The adapter base classes
# (``WebSocketAdapterProtocol``, ``WebSocketAdapterFactory``) and the
# ``yields`` helper defined below are not listed here.
__all__ = (
    "WampWebSocketClientFactory",
    "WampWebSocketClientProtocol",
    "WampWebSocketServerFactory",
    "WampWebSocketServerProtocol",
    "WebSocketClientFactory",
    "WebSocketClientProtocol",
    "WebSocketServerFactory",
    "WebSocketServerProtocol",
)


def yields(value):
    """
    Tell whether ``value`` is something that still has to be awaited.

    :param value: Any object, typically the return value of a user callback.

    :returns: ``True`` when ``value`` is an :class:`asyncio.Future` instance
        (which includes tasks) or a coroutine object, ``False`` otherwise.
    :rtype: bool

    .. seealso:: http://stackoverflow.com/questions/20730248/maybedeferred-analog-with-asyncio
    """
    if isinstance(value, Future):
        return True
    return iscoroutine(value)


class WebSocketAdapterProtocol(asyncio.Protocol):
    """
    Adapter class for asyncio-based WebSocket client and server protocols.

    Translates the asyncio :class:`asyncio.Protocol` callbacks
    (``connection_made``, ``data_received``, ``connection_lost``) into calls
    to ``_connectionMade``, ``_dataReceived`` and ``_connectionLost``, and
    forwards the ``_on*`` hooks to the user-level ``on*`` handlers. Those
    target methods are not defined in this class; presumably they come from
    the WebSocket protocol class this adapter is combined with.

    User handlers may return a future or coroutine; such results are
    scheduled on the event loop (see :meth:`_schedule_handler_result`).
    """

    log = txaio.make_logger()

    # peer address as reported by the transport details (set on connect)
    peer: Optional[str] = None
    # whether this protocol acts as the server side; ``None`` until a
    # subclass or instance sets it
    is_server: Optional[bool] = None

    def connection_made(self, transport):
        """
        asyncio entry point, called when the connection is established
        (either as a client or as a server).

        :param transport: The asyncio transport for this connection.
        """
        self.log.debug(
            "{func}(transport={transport})",
            func=hltype(self.connection_made),
            transport=transport,
        )

        self.transport = transport

        # determine preliminary transport details (what is known at this point)
        self._transport_details = create_transport_details(
            self.transport, self.is_server
        )

        # backward compatibility
        self.peer = self._transport_details.peer

        # incoming chunks are buffered here and drained by the callback
        # armed in _consume()
        self.receive_queue = deque()
        self._consume()

        self._connectionMade()

    def connection_lost(self, exc):
        """
        asyncio entry point, called when the connection is lost or closed.

        :param exc: The exception that caused the loss, or ``None``.
        """
        self._connectionLost(exc)
        # according to asyncio docs, connection_lost(None) is called
        # if something else called transport.close()
        if exc is not None:
            self.transport.close()
        self.transport = None

    def _consume(self):
        """
        Arm a fresh waiter future whose completion drains the receive queue.

        The future is resolved by :meth:`data_received`. asyncio runs done
        callbacks via the event loop rather than inline, so queued data is
        processed after ``data_received`` has returned.
        """
        self.waiter = Future(loop=self.factory.loop or txaio.config.loop)

        def process(_):
            while self.receive_queue:
                data = self.receive_queue.popleft()
                # data that arrives after the transport was cleared is dropped
                if self.transport:
                    self._dataReceived(data)
            # re-arm for the next batch of data
            self._consume()

        self.waiter.add_done_callback(process)

    def data_received(self, data):
        """
        asyncio entry point: queue ``data`` and wake up the pending waiter.

        :param data: The bytes received from the transport.
        """
        self.receive_queue.append(data)
        if not self.waiter.done():
            self.waiter.set_result(None)

    def _closeConnection(self, abort=False):
        """
        Close the underlying transport.

        :param abort: When true and the transport offers ``abort()``, abort
            the transport instead of closing it.
        :type abort: bool
        """
        if self.transport is None:
            # connection_lost() already cleared the transport, so there is
            # nothing left to close; a late close/abort must not raise
            return
        if abort and hasattr(self.transport, "abort"):
            self.transport.abort()
        else:
            self.transport.close()

    def _schedule_handler_result(self, res):
        """
        Schedule the result of a user handler if it is a future or coroutine.

        A strong reference to the scheduled task is kept until it finishes:
        the event loop only holds weak references to tasks, so a task nobody
        else references could be garbage-collected before it completes.

        :param res: Whatever the user handler returned.
        """
        if not yields(res):
            return
        task = asyncio.ensure_future(res)
        pending = getattr(self, "_pending_handler_tasks", None)
        if pending is None:
            pending = self._pending_handler_tasks = set()
        pending.add(task)
        task.add_done_callback(pending.discard)

    def _onOpen(self):
        """
        Finalize the transport details and invoke the user ``onOpen`` handler.
        """
        if self._transport_details.is_secure:
            # now that the TLS opening handshake is complete, the actual TLS channel ID
            # will be available. make sure to set it!
            channel_id = {
                "tls-unique": transport_channel_id(
                    self.transport, self._transport_details.is_server, "tls-unique"
                ),
            }
            self._transport_details.channel_id = channel_id

        self._schedule_handler_result(self.onOpen())

    def _onMessageBegin(self, isBinary):
        """Forward to ``onMessageBegin`` and schedule an awaitable result."""
        self._schedule_handler_result(self.onMessageBegin(isBinary))

    def _onMessageFrameBegin(self, length):
        """Forward to ``onMessageFrameBegin`` and schedule an awaitable result."""
        self._schedule_handler_result(self.onMessageFrameBegin(length))

    def _onMessageFrameData(self, payload):
        """Forward to ``onMessageFrameData`` and schedule an awaitable result."""
        self._schedule_handler_result(self.onMessageFrameData(payload))

    def _onMessageFrameEnd(self):
        """Forward to ``onMessageFrameEnd`` and schedule an awaitable result."""
        self._schedule_handler_result(self.onMessageFrameEnd())

    def _onMessageFrame(self, payload):
        """Forward to ``onMessageFrame`` and schedule an awaitable result."""
        self._schedule_handler_result(self.onMessageFrame(payload))

    def _onMessageEnd(self):
        """Forward to ``onMessageEnd`` and schedule an awaitable result."""
        self._schedule_handler_result(self.onMessageEnd())

    def _onMessage(self, payload, isBinary):
        """Forward to ``onMessage`` and schedule an awaitable result."""
        self._schedule_handler_result(self.onMessage(payload, isBinary))

    def _onPing(self, payload):
        """Forward to ``onPing`` and schedule an awaitable result."""
        self._schedule_handler_result(self.onPing(payload))

    def _onPong(self, payload):
        """Forward to ``onPong`` and schedule an awaitable result."""
        self._schedule_handler_result(self.onPong(payload))

    def _onClose(self, wasClean, code, reason):
        """Forward to ``onClose`` and schedule an awaitable result."""
        self._schedule_handler_result(self.onClose(wasClean, code, reason))

    def registerProducer(self, producer, streaming):
        """
        Producer registration is not supported by this adapter.

        :raises Exception: Always.
        """
        raise Exception("not implemented")

    def unregisterProducer(self):
        """
        Accepted as a no-op (see the comment in the body for the rationale).
        """
        # note that generic websocket/protocol.py code calls
        # .unregisterProducer whenever we dropConnection -- that's
        # correct behavior on Twisted so either we'd have to
        # try/except there, or special-case Twisted, ..or just make
        # this "not an error"
        pass


@public
class WebSocketServerProtocol(
    WebSocketAdapterProtocol, protocol.WebSocketServerProtocol
):
    """
    Base class for asyncio-based WebSocket server protocols.

    Combines :class:`WebSocketAdapterProtocol` with
    ``protocol.WebSocketServerProtocol``; the only member this class defines
    itself is its logger.

    Implements:

    * :class:`autobahn.websocket.interfaces.IWebSocketChannel`
    """

    # class-level txaio logger
    log = txaio.make_logger()


@public
class WebSocketClientProtocol(
    WebSocketAdapterProtocol, protocol.WebSocketClientProtocol
):
    """
    Base class for asyncio-based WebSocket client protocols.

    Built from :class:`WebSocketAdapterProtocol` and
    ``protocol.WebSocketClientProtocol``.

    Implements:

    * :class:`autobahn.websocket.interfaces.IWebSocketChannel`
    """

    log = txaio.make_logger()

    def _onConnect(self, response):
        """
        Call the user-level ``onConnect`` handler and, when it hands back a
        future or coroutine, schedule that on the event loop.

        :param response: Passed through unchanged to ``onConnect``.
        """
        outcome = self.onConnect(response)
        self.log.debug("{func}: {res}", func=hltype(self._onConnect), res=outcome)
        if not yields(outcome):
            return
        asyncio.ensure_future(outcome)

    def startTLS(self):
        """
        Not supported on asyncio.

        :raises Exception: Always.
        """
        raise Exception("WSS over explicit proxies not implemented")


class WebSocketAdapterFactory(object):
    """
    Adapter class for asyncio-based WebSocket client and server factories.

    Instances are callable, so a factory object can be used wherever a
    zero-argument callable that produces a protocol is expected.
    """

    log = txaio.make_logger()

    def __call__(self):
        """
        Build a new protocol from ``self.protocol`` and link it back to this
        factory through its ``factory`` attribute.

        :returns: The freshly created protocol instance.
        """
        instance = self.protocol()
        instance.factory = self
        return instance


@public
class WebSocketServerFactory(WebSocketAdapterFactory, protocol.WebSocketServerFactory):
    """
    Base class for asyncio-based WebSocket server factories.

    Implements:

    * :class:`autobahn.websocket.interfaces.IWebSocketServerChannelFactory`
    """

    log = txaio.make_logger()

    # protocol class instantiated by WebSocketAdapterFactory.__call__
    protocol = WebSocketServerProtocol

    def __init__(self, *args, **kwargs):
        """
        .. note::
            Besides every argument understood by the constructor of
            :meth:`autobahn.websocket.interfaces.IWebSocketServerChannelFactory`,
            a ``loop`` keyword argument may be given to select the asyncio
            event loop. Without it (or when it is falsy) the loop returned
            by ``asyncio.get_event_loop()`` is used.
        """
        event_loop = kwargs.pop("loop", None)
        if not event_loop:
            event_loop = asyncio.get_event_loop()
        self.loop = event_loop

        # ``loop`` has been removed from kwargs; the rest goes to the base
        protocol.WebSocketServerFactory.__init__(self, *args, **kwargs)


@public
class WebSocketClientFactory(WebSocketAdapterFactory, protocol.WebSocketClientFactory):
    """
    Base class for asyncio-based WebSocket client factories.

    Implements:

    * :class:`autobahn.websocket.interfaces.IWebSocketClientChannelFactory`
    """

    log = txaio.make_logger()

    def __init__(self, *args, **kwargs):
        """

        .. note::
            Besides every argument understood by the constructor of
            :meth:`autobahn.websocket.interfaces.IWebSocketClientChannelFactory`,
            a ``loop`` keyword argument may be given to select the asyncio
            event loop. Without it (or when it is falsy) the loop returned
            by ``asyncio.get_event_loop()`` is used.
        """
        event_loop = kwargs.pop("loop", None)
        if not event_loop:
            event_loop = asyncio.get_event_loop()
        self.loop = event_loop

        # ``loop`` has been removed from kwargs; the rest goes to the base
        protocol.WebSocketClientFactory.__init__(self, *args, **kwargs)


@public
class WampWebSocketServerProtocol(
    websocket.WampWebSocketServerProtocol, WebSocketServerProtocol
):
    """
    asyncio-based WAMP-over-WebSocket server protocol.

    Combines ``websocket.WampWebSocketServerProtocol`` with the asyncio
    :class:`WebSocketServerProtocol`; the only member this class defines
    itself is its logger.

    Implements:

    * :class:`autobahn.wamp.interfaces.ITransport`
    """

    # class-level txaio logger
    log = txaio.make_logger()


@public
class WampWebSocketServerFactory(
    websocket.WampWebSocketServerFactory, WebSocketServerFactory
):
    """
    asyncio-based WAMP-over-WebSocket server factory.
    """

    log = txaio.make_logger()

    # protocol class instantiated when this factory is called
    protocol = WampWebSocketServerProtocol

    def __init__(self, factory, *args, **kwargs):
        """

        :param factory: A callable that produces instances that implement
            :class:`autobahn.wamp.interfaces.ITransportHandler`
        :type factory: callable

        :param serializers: A list of WAMP serializers to use (or ``None``
            for all available serializers). Keyword-only; it is removed from
            ``kwargs`` before the WebSocket factory is initialized.
        :type serializers: list of objects implementing
            :class:`autobahn.wamp.interfaces.ISerializer`

        All remaining positional and keyword arguments are handed to
        :class:`WebSocketServerFactory`. A caller-supplied ``protocols``
        keyword is replaced by ``self._protocols``.
        """

        wamp_serializers = kwargs.pop("serializers", None)
        websocket.WampWebSocketServerFactory.__init__(self, factory, wamp_serializers)

        # ``_protocols`` is read only after the WAMP-level initializer ran
        kwargs.update(protocols=self._protocols)

        # noinspection PyCallByClass
        WebSocketServerFactory.__init__(self, *args, **kwargs)


@public
class WampWebSocketClientProtocol(
    websocket.WampWebSocketClientProtocol, WebSocketClientProtocol
):
    """
    asyncio-based WAMP-over-WebSocket client protocol.

    Combines ``websocket.WampWebSocketClientProtocol`` with the asyncio
    :class:`WebSocketClientProtocol`; the only member this class defines
    itself is its logger.

    Implements:

    * :class:`autobahn.wamp.interfaces.ITransport`
    """

    # class-level txaio logger
    log = txaio.make_logger()


@public
class WampWebSocketClientFactory(
    websocket.WampWebSocketClientFactory, WebSocketClientFactory
):
    """
    asyncio-based WAMP-over-WebSocket client factory.
    """

    log = txaio.make_logger()

    # protocol class instantiated when this factory is called
    protocol = WampWebSocketClientProtocol

    def __init__(self, factory, *args, **kwargs):
        """

        :param factory: A callable that produces instances that implement
            :class:`autobahn.wamp.interfaces.ITransportHandler`
        :type factory: callable

        :param serializers: Keyword-only. Removed from ``kwargs`` and passed
            on to ``websocket.WampWebSocketClientFactory.__init__``;
            defaults to ``None``. Note the plural spelling: this is the
            only serializer-related keyword this constructor consumes.
            (Presumably a list of objects implementing
            :class:`autobahn.wamp.interfaces.ISerializer`, as for the server
            factory -- confirm against the WAMP base class.)

        All remaining positional and keyword arguments are handed to
        :class:`WebSocketClientFactory`. A caller-supplied ``protocols``
        keyword is replaced by ``self._protocols``.
        """

        wamp_serializers = kwargs.pop("serializers", None)
        websocket.WampWebSocketClientFactory.__init__(self, factory, wamp_serializers)

        # ``_protocols`` is read only after the WAMP-level initializer ran
        kwargs.update(protocols=self._protocols)

        WebSocketClientFactory.__init__(self, *args, **kwargs)
