46 lines
1.5 KiB
Python
46 lines
1.5 KiB
Python
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, TypeVar
|
|
|
|
from aiohttp.web_request import Request
|
|
from aiohttp.web_response import StreamResponse
|
|
|
|
if TYPE_CHECKING:
    F = TypeVar("F", bound=Callable[..., Any])

    # For type checkers, ``middleware`` is an identity decorator: the decorated
    # function keeps its exact type.
    middleware: Callable[[F], F]
else:
    try:
        from aiohttp.web_middlewares import middleware
    except ImportError:
        # @middleware is deprecated and its behaviour is the default since aiohttp 4.0
        # so if it doesn't exist anymore, define a no-op for forward compatibility.
        def middleware(func: Any) -> Any:
            """Return *func* unchanged (stand-in for aiohttp's removed decorator)."""
            return func
|
|
|
|
# An aiohttp request handler: awaitable callable mapping a Request to a response.
Handler = Callable[
    [Request],
    Awaitable[StreamResponse],
]
# A middleware: receives the request plus the next handler in the chain and
# produces the (possibly modified) response.
Middleware = Callable[
    [Request, Handler],
    Awaitable[StreamResponse],
]
|
|
|
|
|
|
def cors(allow_headers: Iterable[str]) -> Middleware:
    """Build an aiohttp middleware that adds permissive CORS headers.

    Preflight requests (method ``OPTIONS`` with an
    ``Access-Control-Request-Method`` header) are answered directly with an
    empty ``StreamResponse``; the wrapped handler is not invoked for them.
    Every other request is passed to the handler. CORS headers are added only
    when the request carries a non-empty ``Origin`` header.

    :param allow_headers: Header names advertised in
        ``Access-Control-Allow-Headers`` on ``OPTIONS`` responses. The iterable
        is consumed once, when the middleware is built, so one-shot iterables
        such as generators are safe to pass.
    :return: The middleware coroutine function.
    """
    # Join once, up front: joining a single-pass iterable inside ``impl`` would
    # exhaust it on the first OPTIONS request and send an empty
    # Access-Control-Allow-Headers value on every subsequent one.
    allow_headers_value = ", ".join(allow_headers)
    allow_methods_value = ", ".join(("OPTIONS", "POST"))

    @middleware
    async def impl(request: Request, handler: Handler) -> StreamResponse:
        is_options = request.method == "OPTIONS"
        is_preflight = is_options and "Access-Control-Request-Method" in request.headers
        if is_preflight:
            # Reply to the preflight here; the handler is not called for it.
            resp = StreamResponse()
        else:
            resp = await handler(request)

        origin = request.headers.get("Origin")
        if not origin:
            # No Origin header: leave the response untouched.
            return resp

        resp.headers["Access-Control-Allow-Origin"] = "*"
        resp.headers["Access-Control-Expose-Headers"] = "*"
        if is_options:
            resp.headers["Access-Control-Allow-Headers"] = allow_headers_value
            resp.headers["Access-Control-Allow-Methods"] = allow_methods_value

        return resp

    return impl
|