import asyncio
import os

# because py.test tries to collect it as a test-case
from unittest.mock import Mock

import pytest
import txaio

from autobahn.asyncio.websocket import WebSocketServerFactory


async def echo_async(what, when):
    """Coroutine that hands back *what* once *when* seconds have elapsed.

    :param what: value returned unchanged.
    :param when: delay in seconds, forwarded to :func:`asyncio.sleep`.
    :returns: the very same object passed as *what*.
    """
    payload, delay = what, when
    await asyncio.sleep(delay)
    return payload


@pytest.mark.skipif(
    not os.environ.get("USE_ASYNCIO", False), reason="test runs on asyncio only"
)
@pytest.mark.asyncio
async def test_echo_async():
    """With a zero delay, echo_async must hand back its argument unchanged."""
    greeting = "Hello!"
    echoed = await echo_async(greeting, 0)
    assert echoed == greeting


@pytest.mark.skipif(
    not os.environ.get("USE_ASYNCIO", False), reason="test runs on asyncio only"
)
def test_websocket_custom_loop():
    """Check that a protocol built from a factory given an explicit ``loop``
    accepts a (mock) transport in ``connection_made``.

    The event loop is created solely for this test, so it is closed again
    afterwards. Otherwise it would leak its selector and trigger asyncio's
    "unclosed event loop" ResourceWarning when garbage-collected.
    """
    loop = asyncio.new_event_loop()
    try:
        factory = WebSocketServerFactory(loop=loop)
        server = factory()
        transport = Mock()
        server.connection_made(transport)
    finally:
        # The loop was never run, so close() cannot raise "running loop".
        loop.close()


@pytest.mark.skipif(
    not os.environ.get("USE_ASYNCIO", False), reason="test runs on asyncio only"
)
def test_async_on_connect_server():
    """Check that a coroutine ``onConnect`` handler runs to completion when the
    server processes an opening handshake.

    A raw HTTP upgrade request is placed in ``server.data`` and
    ``processHandshake()`` is called directly. The handler awaits another
    coroutine, records its result in ``values`` and resolves ``done``. The
    test then drives the event loop until ``done`` completes and checks the
    recorded value.
    """
    num = 42
    # Resolved only by on_connect (below); the test blocks on it at the end.
    done = txaio.create_future()
    # Receives the value computed inside on_connect.
    values = []

    async def foo(x):
        # Presumably here to force on_connect to genuinely suspend before it
        # finishes. NOTE(review): a 1 s sleep slows the suite; asyncio.sleep(0)
        # may be enough to yield -- confirm the intent before shortening.
        await asyncio.sleep(1)
        return x * x

    async def on_connect(req):
        v = await foo(num)
        values.append(v)
        txaio.resolve(done, req)

    factory = WebSocketServerFactory()
    server = factory()
    server.onConnect = on_connect
    transport = Mock()

    server.connection_made(transport)
    # Feed the raw request bytes directly; presumably processHandshake()
    # parses whatever is in server.data -- verify against autobahn.
    server.data = b"\r\n".join(
        [
            b"GET /ws HTTP/1.1",
            b"Host: www.example.com",
            b"Sec-WebSocket-Version: 13",
            b"Origin: http://www.example.com.malicious.com",
            b"Sec-WebSocket-Extensions: permessage-deflate",
            b"Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==",
            b"Connection: keep-alive, Upgrade",
            b"Upgrade: websocket",
            # join() adds no separator after the final header line. This
            # element supplies that line's \r\n plus the blank line, so the
            # request ends with b"\r\n\r\n".
            b"\r\n",
        ]
    )
    server.processHandshake()

    # NOTE(review): calling asyncio.get_event_loop() with no running loop is
    # deprecated on recent Python versions. It must also return the same loop
    # that txaio bound `done` to. There is no timeout, so if on_connect is
    # never invoked, `done` is never resolved and this call hangs.
    asyncio.get_event_loop().run_until_complete(done)

    # on_connect must have run exactly once and stored foo(num) == num ** 2.
    assert len(values) == 1
    assert values[0] == num * num
