From bec9a59e8ec82c18e3bf9268eaa436793dd52e35 Mon Sep 17 00:00:00 2001 From: bashonly <88596187+bashonly@users.noreply.github.com> Date: Sat, 4 May 2024 17:19:42 -0500 Subject: [PATCH] [networking] Add `extensions` attribute to `Response` (#9756) CurlCFFIRH now provides an `impersonate` field in its responses' extensions Authored by: bashonly --- test/test_networking.py | 19 +++++++++++++++++++ yt_dlp/networking/_curlcffi.py | 10 ++++++++++ yt_dlp/networking/common.py | 6 +++++- 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/test/test_networking.py b/test/test_networking.py index b50f70d086..d613cb5681 100644 --- a/test/test_networking.py +++ b/test/test_networking.py @@ -785,6 +785,25 @@ def test_supported_impersonate_targets(self, handler): assert res.status == 200 assert std_headers['user-agent'].lower() not in res.read().decode().lower() + def test_response_extensions(self, handler): + with handler() as rh: + for target in rh.supported_targets: + request = Request( + f'http://127.0.0.1:{self.http_port}/gen_200', extensions={'impersonate': target}) + res = validate_and_send(rh, request) + assert res.extensions['impersonate'] == rh._get_request_target(request) + + def test_http_error_response_extensions(self, handler): + with handler() as rh: + for target in rh.supported_targets: + request = Request( + f'http://127.0.0.1:{self.http_port}/gen_404', extensions={'impersonate': target}) + try: + validate_and_send(rh, request) + except HTTPError as e: + res = e.response + assert res.extensions['impersonate'] == rh._get_request_target(request) + class TestRequestHandlerMisc: """Misc generic tests for request handlers, not related to request or validation testing""" diff --git a/yt_dlp/networking/_curlcffi.py b/yt_dlp/networking/_curlcffi.py index 39d1f70fb0..10751a1050 100644 --- a/yt_dlp/networking/_curlcffi.py +++ b/yt_dlp/networking/_curlcffi.py @@ -132,6 +132,16 @@ def _check_extensions(self, extensions): extensions.pop('cookiejar', None) extensions.pop('timeout', None) + def send(self, request: Request) -> Response: + target = self._get_request_target(request) + try: + response = super().send(request) + except HTTPError as e: + e.response.extensions['impersonate'] = target + raise + response.extensions['impersonate'] = target + return response + def _send(self, request: Request): max_redirects_exceeded = False session: curl_cffi.requests.Session = self._get_instance( diff --git a/yt_dlp/networking/common.py b/yt_dlp/networking/common.py index 4c66ba66aa..a2217034c9 100644 --- a/yt_dlp/networking/common.py +++ b/yt_dlp/networking/common.py @@ -497,6 +497,7 @@ class Response(io.IOBase): @param headers: response headers. @param status: Response HTTP status code. Default is 200 OK. @param reason: HTTP status reason. Will use built-in reasons based on status code if not provided. + @param extensions: Dictionary of handler-specific response extensions. """ def __init__( @@ -505,7 +506,9 @@ def __init__( url: str, headers: Mapping[str, str], status: int = 200, - reason: str = None): + reason: str = None, + extensions: dict = None + ): self.fp = fp self.headers = Message() @@ -517,6 +520,7 @@ def __init__( self.reason = reason or HTTPStatus(status).phrase except ValueError: self.reason = None + self.extensions = extensions or {} def readable(self): return self.fp.readable()