# Source code for asgi_tools.tests

"""Testing tools."""

from __future__ import annotations

import asyncio
import binascii
import io
import mimetypes
import os
import random
from collections import deque
from contextlib import asynccontextmanager, suppress
from functools import partial
from http.cookies import SimpleCookie
from json import loads
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Awaitable,
    Callable,
    Coroutine,
    Deque,
    Optional,
    Union,
    cast,
)
from urllib.parse import urlencode

from multidict import MultiDict
from yarl import URL

from ._compat import aio_cancel, aio_sleep, aio_spawn, aio_timeout, aio_wait
from .constants import BASE_ENCODING, DEFAULT_CHARSET
from .errors import ASGIConnectionClosedError, ASGIInvalidMessageError
from .response import Response, ResponseJSON, ResponseWebSocket, parse_websocket_msg
from .utils import CIMultiDict, parse_headers

if TYPE_CHECKING:
    from .types import TJSON, TASGIApp, TASGIMessage, TASGIReceive, TASGIScope, TASGISend


class TestResponse(Response):
    """Response for test client.

    Awaiting the instance (``await res(scope, receive, send)``) reads the
    ``http.response.start`` message and loads status, headers and cookies; the
    body is then read lazily with :meth:`stream`, :meth:`body`, :meth:`text`
    or :meth:`json`.
    """

    # Not a test case: keep pytest from trying to collect this ``Test*`` class
    # when it is imported into a test module.
    __test__ = False

    def __init__(self):
        super().__init__(b"")
        # Cached response body; ``None`` until :meth:`body` has drained the stream.
        self.content: Optional[bytes] = None

    async def __call__(self, _: TASGIScope, receive: TASGIReceive, send: TASGISend):  # noqa: ARG002
        """Read the response start message and load status, headers and cookies.

        Only ``receive`` is used; it is stored so the body can be read later.
        """
        self._receive = receive
        msg = await self._receive()
        assert msg.get("type") == "http.response.start", "Invalid Response"
        self.status_code = int(msg.get("status", 502))
        self.headers = cast(MultiDict, parse_headers(msg.get("headers", [])))
        self.content_type = self.headers.get("content-type")
        for cookie in self.headers.getall("set-cookie", []):
            self.cookies.load(cookie)

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """Stream the response.

        Yields the non-empty ``body`` of each ``http.response.body`` message
        (other message types are skipped) until one arrives without
        ``more_body``.
        """
        more_body = True
        while more_body:
            msg = await self._receive()
            if msg.get("type") == "http.response.body":
                chunk = msg.get("body")
                if chunk:
                    yield chunk
                more_body = msg.get("more_body", False)

    async def body(self) -> bytes:
        """Load response body (read once, then served from ``self.content``)."""
        if self.content is None:
            # Collect the chunks and join once: ``bytes +=`` in a loop copies the
            # accumulated body on every chunk (quadratic in the body size).
            self.content = b"".join([chunk async for chunk in self.stream()])

        return self.content

    async def text(self) -> str:
        """Return the body decoded with ``DEFAULT_CHARSET``."""
        body = await self.body()
        return body.decode(DEFAULT_CHARSET)

    async def json(self) -> TJSON:
        """Return the body parsed as JSON."""
        text = await self.text()
        return loads(text)


class TestWebSocketResponse(ResponseWebSocket):
    """Support websockets in tests.

    The test client builds it with the pipe's client-side ``receive`` and
    app-side ``send``, so :meth:`send` talks to the application and
    :meth:`receive` reads what the application sent.
    """

    # Not a test case: keep pytest from trying to collect this ``Test*`` class
    # when it is imported into a test module.
    __test__ = False

    def connect(self) -> Coroutine[TASGIMessage, Any, Any]:
        """Return a coroutine that sends ``websocket.connect`` to the application."""
        return self.send({"type": "websocket.connect"})

    async def disconnect(self):
        """Send ``websocket.disconnect`` (code 1005) and mark this side disconnected."""
        await self.send({"type": "websocket.disconnect", "code": 1005})
        self.state = self.STATES.DISCONNECTED

    def send(self, msg, msg_type="websocket.receive"):
        """Send a message to the application (``websocket.receive`` by default)."""
        return super().send(msg, msg_type=msg_type)

    async def receive(self, *, raw=False):
        """Receive the next message sent by the application.

        ``websocket.accept`` marks the application as connected and is skipped;
        ``websocket.close`` marks it disconnected and raises.

        :param raw: Return the raw ASGI message instead of the result of
            ``parse_websocket_msg``.
        :raises ASGIConnectionClosedError: the application closed (or had
            already closed) the connection.
        :raises ASGIInvalidMessageError: a non ``websocket.*`` message arrived.
        """
        if self.partner_state == self.STATES.DISCONNECTED:
            raise ASGIConnectionClosedError

        while True:
            msg = await self._receive()
            msg_type = msg["type"]
            if not msg_type.startswith("websocket."):
                raise ASGIInvalidMessageError(msg)

            if msg_type == "websocket.accept":
                # Handshake accepted; keep waiting for an actual data message.
                self.partner_state = self.STATES.CONNECTED
                continue

            if msg_type == "websocket.close":
                self.partner_state = self.STATES.DISCONNECTED
                raise ASGIConnectionClosedError

            return msg if raw else parse_websocket_msg(msg, charset=DEFAULT_CHARSET)


class ASGITestClient:
    """The test client allows you to make requests against an ASGI application.

    Features:

    * cookies
    * multipart/form-data
    * follow redirects
    * request streams
    * response streams
    * websocket support
    * lifespan management

    """

    def __init__(self, app: TASGIApp, base_url: str = "http://localhost"):
        self.app = app
        self.base_url = URL(base_url)
        # Cookie jar shared by all requests: filled from ``Set-Cookie`` response
        # headers and from per-request ``cookies=`` arguments.
        self.cookies: SimpleCookie = SimpleCookie()
        # Default request headers, used when a request is made without ``headers``.
        self.headers: dict[str, str] = {}

    def __getattr__(self, name: str) -> Callable[..., Awaitable]:
        """Turn ``client.get(...)``, ``client.post(...)``, ... into :meth:`request` calls.

        Only invoked for attributes not found normally. Names starting with an
        underscore raise :class:`AttributeError` instead, so introspection
        (``hasattr(client, "__wrapped__")``, copy/pickle/mock protocols) is not
        mistaken for an HTTP method.
        """
        if name.startswith("_"):
            msg = f"{type(self).__name__!r} object has no attribute {name!r}"
            raise AttributeError(msg)
        return partial(self.request, method=name.upper())

    async def request(
        self,
        path: str,
        method: str = "GET",
        *,
        query: Union[str, dict] = "",
        headers: Optional[dict[str, str]] = None,
        cookies: Optional[dict[str, str]] = None,
        data: Union[bytes, str, dict, AsyncGenerator[bytes, Any]] = b"",
        json: TJSON = None,
        follow_redirect: bool = True,
        timeout: float = 10.0,
    ) -> TestResponse:
        """Make a HTTP request.

        :param path: Request path (may contain a query string).
        :param method: HTTP method.
        :param query: Query string or mapping; replaces the query of ``path``.
        :param headers: Request headers; when empty, a copy of ``self.headers``
            is used. The mapping passed in is not modified.
        :param cookies: Cookies added to the client's cookie jar before sending.
        :param data: Request body: ``bytes``, ``str``, a ``dict`` (multipart when
            any value is an ``io.IOBase``, url-encoded otherwise) or an async
            iterator of ``bytes`` chunks (streamed request).
        :param json: JSON payload, used when ``data`` is neither ``str`` nor ``dict``.
        :param follow_redirect: Follow 301/302/303/307/308 responses with a GET.
        :param timeout: Seconds allowed for sending the body and running the app.
        :returns: The :class:`TestResponse`; its body is read lazily.
        """
        # Copy so neither the caller's mapping nor ``self.headers`` is mutated by
        # the Content-Type/Content-Length/Cookie defaults set below.
        headers = dict(headers) if headers else dict(self.headers)
        if isinstance(data, str):
            data = Response.process_content(data)

        elif isinstance(data, dict):
            is_multipart = any(isinstance(value, io.IOBase) for value in data.values())
            if is_multipart:
                data, headers["Content-Type"] = encode_multipart(data)

            else:
                headers["Content-Type"] = "application/x-www-form-urlencoded"
                data = urlencode(data).encode(DEFAULT_CHARSET)

        elif json is not None:
            headers["Content-Type"] = "application/json"
            data = ResponseJSON.process_content(json)

        pipe = Pipe()

        # Only a ``bytes`` body has a known length; async-iterator bodies are
        # streamed without one.
        if isinstance(data, bytes):
            headers.setdefault("Content-Length", str(len(data)))

        scope = self.build_scope(
            path,
            type="http",
            query=query,
            method=method,
            headers=headers,
            cookies=cookies,
        )

        async with aio_timeout(timeout):
            await aio_wait(
                pipe.stream(data),
                self.app(scope, pipe.receive_from_app, pipe.send_to_client),
            )

        res = TestResponse()
        await res(scope, pipe.receive_from_client, pipe.send_to_app)
        for n, v in res.cookies.items():
            self.cookies[n] = v

        if follow_redirect and res.status_code in {301, 302, 303, 307, 308}:
            # NOTE: redirects are always re-issued as a bodiless GET.
            return await self.get(res.headers["location"], timeout=timeout)

        return res

    # TODO: Timeouts for websockets
    @asynccontextmanager
    async def websocket(
        self,
        path: str,
        query: Union[str, dict, None] = None,
        headers: Optional[dict] = None,
        cookies: Optional[dict] = None,
    ):
        """Connect to a websocket.

        Starts the app with ``aio_spawn``, sends ``websocket.connect`` and yields
        a :class:`TestWebSocketResponse`; ``websocket.disconnect`` is sent when
        the ``async with`` body finishes without an exception.
        """
        pipe = Pipe()
        ci_headers = CIMultiDict(headers or {})
        scope = self.build_scope(
            path,
            headers=ci_headers,
            query=query,
            cookies=cookies,
            type="websocket",
            subprotocols=self._parse_subprotocols(ci_headers.get("Sec-WebSocket-Protocol", "")),
        )
        ws = TestWebSocketResponse(scope, pipe.receive_from_client, pipe.send_to_app)
        async with aio_spawn(
            self.app,
            scope,
            pipe.receive_from_app,
            pipe.send_to_client,
        ):
            await ws.connect()
            yield ws
            await ws.disconnect()

    def lifespan(self, timeout: float = 3e-2):
        """Manage `Lifespan <https://asgi.readthedocs.io/en/latest/specs/lifespan.html>`_ protocol."""
        return manage_lifespan(self.app, timeout=timeout)

    def build_scope(
        self,
        path: str,
        headers: Union[dict, CIMultiDict, None] = None,
        query: Union[str, dict, None] = None,
        cookies: Optional[dict] = None,
        **scope,
    ) -> TASGIScope:
        """Prepare a request scope.

        Adds default ``User-Agent``/``Host`` headers and a ``Cookie`` header from
        the client's cookie jar (after merging ``cookies`` into it). Extra
        keyword arguments are merged into (and override) the returned scope.
        """
        headers = headers or {}
        headers.setdefault("User-Agent", "ASGI-Tools-Test-Client")
        headers.setdefault("Host", self.base_url.host)
        if cookies:
            for c, v in cookies.items():
                self.cookies[c] = v

        if len(self.cookies):
            headers.setdefault("Cookie", self._format_cookie_header(self.cookies))

        url = URL(path)
        if query:
            url = url.with_query(query)

        # Setup client
        scope.setdefault("client", ("127.0.0.1", random.randint(1024, 65535)))  # noqa: S311

        if scope.get("type") == "http":
            scheme = self.base_url.scheme or "http"
        else:
            # Websocket scheme mirrors the security of the base URL.
            scheme = "wss" if self.base_url.scheme in {"https", "wss"} else "ws"

        return dict(
            {
                "asgi": {"version": "3.0"},
                "http_version": "1.1",
                "path": url.path,
                "query_string": url.raw_query_string.encode(),
                "raw_path": url.raw_path.encode(),
                "root_path": "",
                "scheme": scheme,
                "headers": [
                    (key.lower().encode(BASE_ENCODING), str(val).encode(BASE_ENCODING))
                    for key, val in (headers or {}).items()
                ],
                "server": ("127.0.0.1", self.base_url.port),
            },
            **scope,
        )

    @staticmethod
    def _format_cookie_header(cookies: SimpleCookie) -> str:
        """Render a cookie jar as a request ``Cookie`` header value.

        Only ``name=value`` pairs are emitted (sorted by name); attributes such
        as ``Path`` or ``Max-Age`` belong to ``Set-Cookie`` and are dropped.
        """
        return "; ".join(f"{key}={morsel.coded_value}" for key, morsel in sorted(cookies.items()))

    @staticmethod
    def _parse_subprotocols(value: Any) -> list[str]:
        """Split a ``Sec-WebSocket-Protocol`` value into stripped, non-empty names."""
        return [proto.strip() for proto in str(value).split(",") if proto.strip()]
def encode_multipart(data: dict) -> tuple[bytes, str]:
    """Encode *data* as a ``multipart/form-data`` request body.

    File-like values (anything with ``read``) are read; when their ``name`` is
    a path string, its basename becomes the part's ``filename`` and its guessed
    MIME type the part's ``Content-Type``. ``str`` values are UTF-8 encoded,
    bytes-like values are written as-is and any other value (e.g. an ``int``)
    is converted with ``str()`` first.

    Returns a ``(body, content_type)`` tuple; ``content_type`` carries the
    random boundary.
    """
    body = io.BytesIO()
    boundary = binascii.hexlify(os.urandom(16))
    for name, data_value in data.items():
        value = data_value
        headers = f'Content-Disposition: form-data; name="{name}"'
        if hasattr(value, "read"):
            filename = getattr(value, "name", None)
            if isinstance(filename, os.PathLike):
                filename = os.fspath(filename)
            # ``name`` may be an int (file opened from a descriptor): no filename then.
            if isinstance(filename, str) and filename:
                headers = f'{headers}; filename="{Path(filename).name}"'
                content_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
                headers = f"{headers}\r\nContent-Type: {content_type}"
            value = value.read()

        body.write(b"--%b\r\n" % boundary)
        body.write(headers.encode("utf-8"))
        body.write(b"\r\n\r\n")
        if isinstance(value, str):
            value = value.encode("utf-8")
        elif not isinstance(value, (bytes, bytearray, memoryview)):
            value = str(value).encode("utf-8")
        body.write(value)
        body.write(b"\r\n")

    body.write(b"--%b--\r\n" % boundary)
    return body.getvalue(), (b"multipart/form-data; boundary=%s" % boundary).decode()


class Pipe:
    """In-memory message channel between the test client and an ASGI app.

    ``send_to_app``/``receive_from_app`` carry messages towards the app;
    ``send_to_client``/``receive_from_client`` carry messages from the app.
    Receivers poll their queue every ``delay`` seconds until a message arrives.
    """

    __slots__ = (
        "delay",
        "app_is_closed",
        "client_is_closed",
        "app_queue",
        "client_queue",
    )

    def __init__(self, delay: float = 1e-3):
        self.delay = delay
        self.app_is_closed = False
        self.client_is_closed = False
        self.app_queue: Deque[TASGIMessage] = deque()
        self.client_queue: Deque[TASGIMessage] = deque()

    async def send_to_client(self, msg: TASGIMessage):
        """Queue a message from the app; raise if the client side is already closed.

        ``websocket.close`` or a final ``http.response.body`` closes the client side.
        """
        if self.client_is_closed:
            raise ASGIInvalidMessageError(msg.get("type"))

        if msg.get("type") == "websocket.close":
            self.client_is_closed = True
        elif msg.get("type") == "http.response.body":
            self.client_is_closed = not msg.get("more_body", False)

        self.client_queue.append(msg)

    async def send_to_app(self, msg: TASGIMessage):
        """Queue a message for the app; raise after ``http.disconnect`` was sent."""
        if self.app_is_closed:
            raise ASGIInvalidMessageError(msg.get("type"))

        if msg.get("type") == "http.disconnect":
            self.app_is_closed = True

        self.app_queue.append(msg)

    async def receive_from_client(self):
        """Wait (polling) for and return the next message sent by the app."""
        while not self.client_queue:
            await aio_sleep(self.delay)
        return self.client_queue.popleft()

    async def receive_from_app(self):
        """Wait (polling) for and return the next message queued for the app."""
        while not self.app_queue:
            await aio_sleep(self.delay)
        return self.app_queue.popleft()

    async def stream(self, data: Union[bytes, AsyncGenerator[bytes, Any]]):
        """Send a request body to the app as ``http.request`` messages.

        ``bytes`` go out as one final message; an async iterator is sent chunk
        by chunk (``more_body=True``) followed by an empty final message.
        """
        if isinstance(data, bytes):
            return await self.send_to_app(
                {"type": "http.request", "body": data, "more_body": False},
            )

        async for chunk in data:
            await self.send_to_app({"type": "http.request", "body": chunk, "more_body": True})

        await self.send_to_app({"type": "http.request", "body": b"", "more_body": False})
        return None


@asynccontextmanager
async def manage_lifespan(app, timeout: float = 3e-2):
    """Manage `Lifespan <https://asgi.readthedocs.io/en/latest/specs/lifespan.html>`_ protocol.

    Sends ``lifespan.startup`` on enter and ``lifespan.shutdown`` on exit,
    waiting at most ``timeout`` seconds for each reply. If the app replies
    ``lifespan.startup.failed`` its task is cancelled.
    """
    pipe = Pipe()

    scope = {"type": "lifespan"}

    async def safe_spawn():
        # Every exception raised by the app is deliberately suppressed here.
        with suppress(BaseException):
            await app(scope, pipe.receive_from_app, pipe.send_to_client)

    async with aio_spawn(safe_spawn) as task:
        await pipe.send_to_app({"type": "lifespan.startup"})

        with suppress(TimeoutError, asyncio.TimeoutError):  # python 39, 310
            async with aio_timeout(timeout):
                msg = await pipe.receive_from_client()
                if msg["type"] == "lifespan.startup.failed":
                    await aio_cancel(task)

        yield

        await pipe.send_to_app({"type": "lifespan.shutdown"})
        with suppress(TimeoutError, asyncio.TimeoutError):  # python 39, 310
            async with aio_timeout(timeout):
                await pipe.receive_from_client()