Source code for asgi_tools.middleware

"""ASGI-Tools Middlewares."""

from __future__ import annotations

import abc
from contextlib import suppress
from contextvars import ContextVar
from functools import partial
from inspect import isawaitable
from pathlib import Path
from typing import TYPE_CHECKING, Awaitable, Callable, Final, Mapping, Optional, Union

from http_router import Router

from .errors import ASGIError
from .logs import logger
from .request import Request
from .response import Response, ResponseError, ResponseFile, ResponseRedirect, parse_response
from .utils import to_awaitable

if TYPE_CHECKING:
    from .types import TASGIApp, TASGIMessage, TASGIReceive, TASGIScope, TASGISend


class BaseMiddeware(metaclass=abc.ABCMeta):
    """Base class for ASGI-Tools middlewares."""

    scopes: tuple[str, ...] = ("http", "websocket")

    def __init__(self, app: Optional[TASGIApp] = None) -> None:
        """Save ASGI App."""
        self.bind(app)

    def __call__(self, scope: TASGIScope, receive: TASGIReceive, send: TASGISend) -> Awaitable:
        """Handle ASGI call."""
        if scope["type"] in self.scopes:
            return self.__process__(scope, receive, send)

        return self.app(scope, receive, send)

    @abc.abstractmethod
    async def __process__(self, scope: TASGIScope, receive: TASGIReceive, send: TASGISend):
        """Do the middleware's logic."""
        raise NotImplementedError()

    @classmethod
    def setup(cls, **params) -> Callable:
        """Setup the middleware without an initialization."""
        return partial(cls, **params)  # type: ignore[abstract]

    def bind(self, app: Optional[TASGIApp] = None):
        """Rebind the middleware to an ASGI application if it has been inited already."""
        self.app = app or ResponseError.NOT_FOUND()
        return self


[docs] class ResponseMiddleware(BaseMiddeware): """Automatically convert ASGI_ apps results into responses :class:`~asgi_tools.Response` and send them to server as ASGI_ messages. .. code-block:: python from asgi_tools import ResponseMiddleware, ResponseText, ResponseRedirect async def app(scope, receive, send): # ResponseMiddleware catches ResponseError, ResponseRedirect and convert the exceptions # into HTTP response if scope['path'] == '/user': raise ResponseRedirect('/login') # Return ResponseHTML if scope['method'] == 'GET': return '<b>HTML is here</b>' # Return ResponseJSON if scope['method'] == 'POST': return {'json': 'here'} # Return any response explicitly if scope['method'] == 'PUT': return ResponseText('response is here') # Short form to responses: (status_code, body) or (status_code, body, headers) return 405, 'Unknown method' app = ResponseMiddleware(app) The conversion rules: * :class:`Response` objects will be directly returned from the view * ``dict``, ``list``, ``int``, ``bool``, ``None`` results will be converted into :class:`ResponseJSON` * ``str``, ``bytes`` results will be converted into :class:`ResponseHTML` * ``tuple[int, Any, dict]`` will be converted into a :class:`Response` with ``int`` status code, ``dict`` will be used as headers, ``Any`` will be used to define the response's type .. code-block:: python from asgi_tools import ResponseMiddleware # The result will be converted into HTML 404 response with the 'Not Found' body async def app(request, receive, send): return 404, 'Not Found' app = ResponseMiddleware(app) You are able to raise :class:`ResponseError` from yours ASGI_ apps and it will be catched and returned as a response """ async def __process__(self, scope: TASGIScope, receive: TASGIReceive, send: TASGISend): """Parse responses from callbacks.""" try: result = await self.app(scope, receive, self.send) response = parse_response(result) await response(scope, receive, send) except (ResponseError, ResponseRedirect) as exc: await exc(scope, receive, send) def send(self, _: TASGIMessage): raise RuntimeError("You can't use send() method in ResponseMiddleware") # noqa: TRY003 def bind(self, app: Optional[TASGIApp] = None): """Rebind the middleware to an ASGI application if it has been inited already.""" self.app = app or to_awaitable(lambda *_: ResponseError.NOT_FOUND()) return self
[docs] class RequestMiddleware(BaseMiddeware): """Automatically create :class:`asgi_tools.Request` from the scope and pass it to ASGI_ apps. .. code-block:: python from asgi_tools import RequestMiddleware, Response async def app(request, receive, send): content = f"{ request.method } { request.url.path }" response = Response(content) await response(scope, receive, send) app = RequestMiddleware(app) """ async def __process__(self, scope: TASGIScope, receive: TASGIReceive, send: TASGISend): """Replace scope with request object.""" return await self.app(Request(scope, receive, send), receive, send)
[docs] class LifespanMiddleware(BaseMiddeware): """Manage ASGI_ Lifespan events. :param ignore_errors: Ignore errors from startup/shutdown handlers :param on_startup: the list of callables to run when the app is starting :param on_shutdown: the list of callables to run when the app is finishing .. code-block:: python from asgi_tools import LifespanMiddleware, Response async def app(scope, receive, send): response = Response('OK') await response(scope, receive, send) app = lifespan = LifespanMiddleware(app) @lifespan.on_startup async def start(): print('The app is starting') @lifespan.on_shutdown async def start(): print('The app is finishing') Lifespan middleware may be used as an async context manager for testing purposes .. code-block: python async def test_my_app(): # ... # Registered startup/shutdown handlers will be called async with lifespan: # ... do something """ scopes = ("lifespan",) def __init__( self, app: Optional[TASGIApp] = None, *, logger=logger, ignore_errors: bool = False, on_startup: Union[Callable, list[Callable], None] = None, on_shutdown: Union[Callable, list[Callable], None] = None, ) -> None: """Prepare the middleware.""" super(LifespanMiddleware, self).__init__(app) self.ignore_errors = ignore_errors self.logger = logger self.__startup__: list[Callable] = [] self.__shutdown__: list[Callable] = [] self.__register__(on_startup, self.__startup__) self.__register__(on_shutdown, self.__shutdown__) async def __process__(self, _: TASGIScope, receive: TASGIReceive, send: TASGISend): """Manage lifespan cycle.""" while True: message = await receive() if message["type"] == "lifespan.startup": msg = await self.run("startup", send) await send(msg) elif message["type"] == "lifespan.shutdown": msg = await self.run("shutdown", send) await send(msg) break def __register__( self, handlers: Union[Callable, list[Callable], None], container: list[Callable] ) -> None: """Register lifespan handlers.""" if not handlers: return if not isinstance(handlers, list): handlers = [handlers] container += handlers async def __aenter__(self): """Use the lifespan middleware as a context manager.""" await self.run("startup") return self async def __aexit__(self, *_): """Use the lifespan middleware as a context manager.""" await self.run("shutdown") async def run(self, event: str, _: Optional[TASGISend] = None): """Run startup/shutdown handlers.""" assert event in {"startup", "shutdown"} handlers = getattr(self, f"__{event}__") for handler in handlers: try: res = handler() if isawaitable(res): await res except Exception as exc: # noqa: PERF203 self.logger.exception("%s method '%s' raises an exception.", event.title(), handler) if self.ignore_errors: continue self.logger.exception("Lifespans process failed") return {"type": f"lifespan.{event}.failed", "message": str(exc)} return {"type": f"lifespan.{event}.complete"}
[docs] def on_startup(self, fn: Callable) -> None: """Add a function to startup.""" self.__register__(fn, self.__startup__)
[docs] def on_shutdown(self, fn: Callable) -> None: """Add a function to shutdown.""" self.__register__(fn, self.__shutdown__)
[docs] class RouterMiddleware(BaseMiddeware): r"""Manage routing. .. code-block:: python from asgi_tools import RouterMiddleware, ResponseHTML, ResponseError async def default_app(scope, receive, send): response = ResponseError.NOT_FOUND() await response(scope, receive, send) app = router = RouterMiddleware(default_app) @router.route('/status', '/stat') async def status(scope, receive, send): response = ResponseHTML('STATUS OK') await response(scope, receive, send) # Bind methods # ------------ @router.route('/only-post', methods=['POST']) async def only_post(scope, receive, send): response = ResponseHTML('POST OK') await response(scope, receive, send) # Regexp paths # ------------ import re @router.route(re.compile(r'/\d+/?')) async def num(scope, receive, send): num = int(scope['path'].strip('/')) response = ResponseHTML(f'Number { num }') await response(scope, receive, send) # Dynamic paths # ------------- @router.route('/hello/{name}') async def hello(scope, receive, send): name = scope['path_params']['name'] response = ResponseHTML(f'Hello { name.title() }') await response(scope, receive, send) # Set regexp for params @router.route(r'/multiply/{first:\d+}/{second:\d+}') async def multiply(scope, receive, send): first, second = map(int, scope['path_params'].values()) response = ResponseHTML(str(first * second)) await response(scope, receive, send) Path parameters are made available in the request/scope, as the ``path_params`` dictionary. """ def __init__(self, app: Optional[TASGIApp] = None, router: Optional[Router] = None) -> None: """Initialize HTTP router.""" super().__init__(app) self.router = router or Router(validator=callable) async def __process__(self, scope: TASGIScope, *args): """Get an app and process.""" app, scope["path_params"] = self.__dispatch__(scope) return await app(scope, *args) def __dispatch__(self, scope: TASGIScope) -> tuple[Callable, Optional[Mapping]]: """Lookup for a callback.""" path = f"{scope.get('root_path', '')}{scope['path']}" try: match = self.router(path, scope["method"]) except self.router.RouterError: return self.app, {} else: return match.target, match.params # type: ignore[] def route(self, *args, **kwargs): """Register a route.""" return self.router.route(*args, **kwargs)
[docs] class StaticFilesMiddleware(BaseMiddeware): """Serve static files. :param url_prefix: an URL prefix for static files :type url_prefix: str, "/static" :param folders: Paths to folders with static files :type folders: list[str] .. code-block:: python from asgi_tools import StaticFilesMiddleware, ResponseHTML async def app(scope, receive, send): response = ResponseHTML('OK) await response(scope, receive, send) # Files from static folder will be served from /static app = StaticFilesMiddleware(app, folders=['static']) """ scopes = ("http",) def __init__( self, app: Optional[TASGIApp] = None, url_prefix: str = "/static", folders: Optional[list[Union[str, Path]]] = None, ) -> None: """Initialize the middleware.""" super().__init__(app) self.url_prefix = url_prefix folders = folders or [] self.folders: list[Path] = [Path(folder) for folder in folders] async def __process__(self, scope: TASGIScope, receive: TASGIReceive, send: TASGISend) -> None: """Serve static files for self url prefix.""" path = scope["path"] url_prefix = self.url_prefix if path.startswith(url_prefix): response: Optional[Response] = None filename = path[len(url_prefix) :].strip("/") for folder in self.folders: filepath = folder.joinpath(filename).resolve() with suppress(ASGIError): response = ResponseFile(filepath, headers_only=scope["method"] == "HEAD") break response = response or ResponseError(status_code=404) await response(scope, receive, send) else: await self.app(scope, receive, send)
BACKGROUND_TASK: Final = ContextVar[Optional[Awaitable]]("background_task", default=None)
[docs] class BackgroundMiddleware(BaseMiddeware): """Run background tasks. .. code-block:: python from asgi_tools import BackgroundMiddleware, ResponseText async def app(scope, receive, send): response = ResponseText('OK) # Schedule any awaitable for later execution BackgroundMiddleware.set_task(asyncio.sleep(1)) # Return response immediately await response(scope, receive, send) # The task will be executed after the response is sent app = BackgroundMiddleware(app) """ async def __process__(self, scope: TASGIScope, receive: TASGIReceive, send: TASGISend): """Run background tasks.""" await self.app(scope, receive, send) bgtask = BACKGROUND_TASK.get() if bgtask is not None and isawaitable(bgtask): await bgtask @staticmethod def set_task(task: Awaitable): """Set a task for background execution.""" BACKGROUND_TASK.set(task)