"""Testing tools."""

from __future__ import annotations

import asyncio
import binascii
import io
import mimetypes
import os
import random
from collections import deque
from contextlib import asynccontextmanager, suppress
from functools import partial
from http.cookies import SimpleCookie
from json import loads
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Awaitable,
    Callable,
    Coroutine,
    Deque,
    Optional,
    Union,
    cast,
)
from urllib.parse import urlencode

from multidict import MultiDict
from yarl import URL

from ._compat import aio_cancel, aio_sleep, aio_spawn, aio_timeout, aio_wait
from .constants import BASE_ENCODING, DEFAULT_CHARSET
from .errors import ASGIConnectionClosedError, ASGIInvalidMessageError
from .response import Response, ResponseJSON, ResponseWebSocket, parse_websocket_msg
from .utils import CIMultiDict, parse_headers

if TYPE_CHECKING:
    from .types import TJSON, TASGIApp, TASGIMessage, TASGIReceive, TASGIScope, TASGISend


class TestResponse(Response):
    """Response for test client.

    Calling the instance reads the ``http.response.start`` message (status,
    headers, cookies). The body is read lazily through :meth:`stream` or
    :meth:`body`.
    """

    def __init__(self):
        super().__init__(b"")
        # Cached response body; ``None`` until :meth:`body` has drained the stream.
        self.content = None

    async def __call__(self, _: TASGIScope, receive: TASGIReceive, send: TASGISend):  # noqa: ARG002
        """Read the response start message sent by the application.

        :param receive: Awaitable callable returning the app's outgoing messages.
        :raises AssertionError: If the first message is not ``http.response.start``.
        """
        self._receive = receive
        msg = await self._receive()
        # NOTE(review): ``assert`` is stripped under ``python -O``.
        assert msg.get("type") == "http.response.start", "Invalid Response"
        self.status_code = int(msg.get("status", 502))
        self.headers = cast(MultiDict, parse_headers(msg.get("headers", [])))
        self.content_type = self.headers.get("content-type")
        for cookie in self.headers.getall("set-cookie", []):
            self.cookies.load(cookie)

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """Stream the response.

        Yields each non-empty ``http.response.body`` chunk until a body message
        arrives without ``more_body``. Messages of any other type are skipped.
        """
        more_body = True
        while more_body:
            msg = await self._receive()
            if msg.get("type") == "http.response.body":
                chunk = msg.get("body")
                if chunk:
                    yield chunk
                more_body = msg.get("more_body", False)

    async def body(self) -> bytes:
        """Load the whole response body and cache it for later calls."""
        if self.content is None:
            # Join once: ``body += chunk`` re-copies the accumulated bytes on
            # every chunk, which is quadratic in the number of chunks.
            self.content = b"".join([chunk async for chunk in self.stream()])
        return self.content

    async def text(self) -> str:
        """Return the response body decoded with ``DEFAULT_CHARSET``."""
        body = await self.body()
        return body.decode(DEFAULT_CHARSET)

    async def json(self) -> TJSON:
        """Return the response body parsed as JSON."""
        text = await self.text()
        return loads(text)


class TestWebSocketResponse(ResponseWebSocket):
    """Client side of a websocket connection, used in tests.

    Message directions are mirrored. :meth:`send` emits ``websocket.receive``
    messages for the application. :meth:`receive` consumes the application's
    ``websocket.accept``, ``websocket.close`` and data messages.
    """

    def connect(self) -> Coroutine[TASGIMessage, Any, Any]:
        """Send ``websocket.connect`` to the application (returns an awaitable)."""
        return self.send({"type": "websocket.connect"})

    async def disconnect(self):
        """Send ``websocket.disconnect`` (code 1005) and mark this side disconnected."""
        await self.send({"type": "websocket.disconnect", "code": 1005})
        self.state = self.STATES.DISCONNECTED

    def send(self, msg, msg_type="websocket.receive"):
        """Send a message to the application.

        ``msg_type`` defaults to ``websocket.receive`` because the application
        receives what the test client sends.
        """
        return super().send(msg, msg_type=msg_type)

    async def receive(self, *, raw=False):
        """Receive the next message sent by the application.

        A ``websocket.accept`` message marks the application as connected. It is
        then skipped, and the following message is returned instead.

        :param raw: Return the raw ASGI message instead of the parsed payload.
        :raises ASGIConnectionClosedError: If the application has closed the connection.
        :raises ASGIInvalidMessageError: If the message type is not ``websocket.*``.
        """
        if self.partner_state == self.STATES.DISCONNECTED:
            raise ASGIConnectionClosedError

        msg = await self._receive()
        if not msg["type"].startswith("websocket."):
            raise ASGIInvalidMessageError(msg)

        if msg["type"] == "websocket.accept":
            self.partner_state = self.STATES.CONNECTED
            return await self.receive(raw=raw)

        if msg["type"] == "websocket.close":
            self.partner_state = self.STATES.DISCONNECTED
            raise ASGIConnectionClosedError

        return msg if raw else parse_websocket_msg(msg, charset=DEFAULT_CHARSET)
class ASGITestClient:
    """The test client allows you to make requests against an ASGI application.

    Features:

    * cookies
    * multipart/form-data
    * follow redirects
    * request streams
    * response streams
    * websocket support
    * lifespan management

    """

    def __init__(self, app: TASGIApp, base_url: str = "http://localhost"):
        """Bind the client to an ASGI application.

        :param app: The ASGI application under test.
        :param base_url: Base URL for requests, parsed with ``yarl.URL``.
        """
        self.app = app
        self.base_url = URL(base_url)
        # Cookies set by responses are accumulated here across requests.
        self.cookies: SimpleCookie = SimpleCookie()
        # Default request headers, used when a request passes no headers of its own.
        self.headers: dict[str, str] = {}

    def __getattr__(self, name: str) -> Callable[..., Awaitable]:
        """Turn ``client.<verb>(...)`` into ``client.request(..., method="<VERB>")``.

        Names starting with ``_`` are not proxied. Returning a request factory
        for names such as ``__setstate__`` or ``__wrapped__`` breaks ``copy``,
        ``pickle``, ``inspect`` and mocking tools, which probe for them.

        :raises AttributeError: For private and dunder names.
        """
        if name.startswith("_"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return partial(self.request, method=name.upper())
[docs]asyncdefrequest(self,path:str,method:str="GET",*,query:Union[str,dict]="",headers:Optional[dict[str,str]]=None,cookies:Optional[dict[str,str]]=None,data:Union[bytes,str,dict,AsyncGenerator[Any,bytes]]=b"",json:TJSON=None,follow_redirect:bool=True,timeout:float=10.0,)->TestResponse:"""Make a HTTP requests."""headers=headersordict(self.headers)ifisinstance(data,str):data=Response.process_content(data)elifisinstance(data,dict):is_multipart=any(isinstance(value,io.IOBase)forvalueindata.values())ifis_multipart:data,headers["Content-Type"]=encode_multipart(data)else:headers["Content-Type"]="application/x-www-form-urlencoded"data=urlencode(data).encode(DEFAULT_CHARSET)elifjsonisnotNone:headers["Content-Type"]="application/json"data=ResponseJSON.process_content(json)pipe=Pipe()ifisinstance(data,bytes):headers.setdefault("Content-Length",str(len(data)))scope=self.build_scope(path,type="http",query=query,method=method,headers=headers,cookies=cookies,)asyncwithaio_timeout(timeout):awaitaio_wait(pipe.stream(data),self.app(scope,pipe.receive_from_app,pipe.send_to_client),)res=TestResponse()awaitres(scope,pipe.receive_from_client,pipe.send_to_app)forn,vinres.cookies.items():self.cookies[n]=viffollow_redirectandres.status_codein{301,302,303,307,308}:returnawaitself.get(res.headers["location"])returnres
# TODO: Timeouts for websockets
[docs]@asynccontextmanagerasyncdefwebsocket(self,path:str,query:Union[str,dict,None]=None,headers:Optional[dict]=None,cookies:Optional[dict]=None,):"""Connect to a websocket."""pipe=Pipe()ci_headers=CIMultiDict(headersor{})scope=self.build_scope(path,headers=ci_headers,query=query,cookies=cookies,type="websocket",subprotocols=str(ci_headers.get("Sec-WebSocket-Protocol","")).split(","),)ws=TestWebSocketResponse(scope,pipe.receive_from_client,pipe.send_to_app)asyncwithaio_spawn(self.app,scope,pipe.receive_from_app,pipe.send_to_client,):awaitws.connect()yieldwsawaitws.disconnect()