diff --git a/test/test_networking.py b/test/test_networking.py
index afdd0c7aa7..e972f597b5 100644
--- a/test/test_networking.py
+++ b/test/test_networking.py
@@ -3,6 +3,7 @@
# Allow direct execution
import os
import sys
+from unittest.mock import MagicMock
import pytest
@@ -614,8 +615,11 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
@pytest.mark.skip_handler('CurlCFFI', 'not supported by curl-cffi')
def test_gzip_trailing_garbage(self, handler):
with handler() as rh:
- data = validate_and_send(rh, Request(f'http://localhost:{self.http_port}/trailing_garbage')).read().decode()
+ res = validate_and_send(rh, Request(f'http://localhost:{self.http_port}/trailing_garbage'))
+ data = res.read().decode()
assert data == ''
+ # Should auto-close and mark the response adaptor as closed
+ assert res.closed
@pytest.mark.skip_handler('CurlCFFI', 'not applicable to curl-cffi')
@pytest.mark.skipif(not brotli, reason='brotli support is not installed')
@@ -627,6 +631,8 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
headers={'ytdl-encoding': 'br'}))
assert res.headers.get('Content-Encoding') == 'br'
assert res.read() == b''
+ # Should auto-close and mark the response adaptor as closed
+ assert res.closed
def test_deflate(self, handler):
with handler() as rh:
@@ -636,6 +642,8 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
headers={'ytdl-encoding': 'deflate'}))
assert res.headers.get('Content-Encoding') == 'deflate'
assert res.read() == b''
+ # Should auto-close and mark the response adaptor as closed
+ assert res.closed
def test_gzip(self, handler):
with handler() as rh:
@@ -645,6 +653,8 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
headers={'ytdl-encoding': 'gzip'}))
assert res.headers.get('Content-Encoding') == 'gzip'
assert res.read() == b''
+ # Should auto-close and mark the response adaptor as closed
+ assert res.closed
def test_multiple_encodings(self, handler):
with handler() as rh:
@@ -655,6 +665,8 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
headers={'ytdl-encoding': pair}))
assert res.headers.get('Content-Encoding') == pair
assert res.read() == b''
+ # Should auto-close and mark the response adaptor as closed
+ assert res.closed
@pytest.mark.skip_handler('CurlCFFI', 'not supported by curl-cffi')
def test_unsupported_encoding(self, handler):
@@ -665,6 +677,8 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
headers={'ytdl-encoding': 'unsupported', 'Accept-Encoding': '*'}))
assert res.headers.get('Content-Encoding') == 'unsupported'
assert res.read() == b'raw'
+ # Should auto-close and mark the response adaptor as closed
+ assert res.closed
def test_read(self, handler):
with handler() as rh:
@@ -672,9 +686,13 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
rh, Request(f'http://127.0.0.1:{self.http_port}/headers'))
assert res.readable()
assert res.read(1) == b'H'
+ # Ensure we don't close the adaptor yet
+ assert not res.closed
assert res.read(3) == b'ost'
assert res.read().decode().endswith('\n\n')
assert res.read() == b''
+ # Should auto-close and mark the response adaptor as closed
+ assert res.closed
def test_request_disable_proxy(self, handler):
for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['http']:
@@ -875,11 +893,31 @@ class TestUrllibRequestHandler(TestRequestHandlerBase):
with handler(enable_file_urls=True) as rh:
res = validate_and_send(rh, req)
- assert res.read() == b'foobar'
- res.close()
+ assert res.read(1) == b'f'
+ assert not res.fp.closed
+ assert res.read() == b'oobar'
+ # Should automatically close the underlying file object
+ assert res.fp.closed
os.unlink(tf.name)
+ def test_data_uri_auto_close(self, handler):
+ with handler() as rh:
+ res = validate_and_send(rh, Request('data:text/plain,hello%20world'))
+ assert res.read() == b'hello world'
+ # Should automatically close the underlying file object
+ assert res.fp.closed
+ assert res.closed
+
+ def test_http_response_auto_close(self, handler):
+ with handler() as rh:
+ res = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/gen_200'))
+ assert res.read() == b''
+ # Should automatically close the underlying file object in the HTTP Response
+ assert isinstance(res.fp, http.client.HTTPResponse)
+ assert res.fp.fp is None
+ assert res.closed
+
def test_http_error_returns_content(self, handler):
# urllib HTTPError will try close the underlying response if reference to the HTTPError object is lost
def get_response():
@@ -1012,6 +1050,14 @@ class TestRequestsRequestHandler(TestRequestHandlerBase):
rh.close()
assert called
+ def test_http_response_auto_close(self, handler):
+ with handler() as rh:
+ res = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/gen_200'))
+ assert res.read() == b''
+ # Should automatically close the underlying file object in the HTTP Response
+ assert res.fp.closed
+ assert res.closed
+
@pytest.mark.parametrize('handler', ['CurlCFFI'], indirect=True)
class TestCurlCFFIRequestHandler(TestRequestHandlerBase):
@@ -1177,6 +1223,14 @@ class TestCurlCFFIRequestHandler(TestRequestHandlerBase):
assert res4.closed
assert res4._buffer == b''
+ def test_http_response_auto_close(self, handler):
+ with handler() as rh:
+ res = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/gen_200'))
+ assert res.read() == b''
+ # Should automatically close the underlying file object in the HTTP Response
+ assert res.fp.closed
+ assert res.closed
+
def run_validation(handler, error, req, **handler_kwargs):
with handler(**handler_kwargs) as rh:
@@ -2032,6 +2086,30 @@ class TestResponse:
assert res.info() is res.headers
assert res.getheader('test') == res.get_header('test')
+ def test_auto_close(self):
+ # Should mark the response as closed if the underlying file is closed
+ class AutoCloseBytesIO(io.BytesIO):
+ def read(self, size=-1, /):
+ data = super().read(size)
+ self.close()
+ return data
+
+ fp = AutoCloseBytesIO(b'test')
+ res = Response(fp, url='test://', headers={}, status=200)
+ assert not res.closed
+ res.read()
+ assert res.closed
+
+ def test_close(self):
+ # Should not call close() on the underlying file when already closed
+ fp = MagicMock()
+ fp.closed = False
+ res = Response(fp, url='test://', headers={}, status=200)
+ res.close()
+ fp.closed = True
+ res.close()
+ assert fp.close.call_count == 1
+
class TestImpersonateTarget:
@pytest.mark.parametrize('target_str,expected', [
diff --git a/yt_dlp/networking/_curlcffi.py b/yt_dlp/networking/_curlcffi.py
index 90570417bd..e6baf48780 100644
--- a/yt_dlp/networking/_curlcffi.py
+++ b/yt_dlp/networking/_curlcffi.py
@@ -96,7 +96,10 @@ class CurlCFFIResponseAdapter(Response):
def read(self, amt=None):
try:
- return self.fp.read(amt)
+ res = self.fp.read(amt)
+ if self.fp.closed:
+ self.close()
+ return res
except curl_cffi.requests.errors.RequestsError as e:
if e.code == CurlECode.PARTIAL_FILE:
content_length = e.response and int_or_none(e.response.headers.get('Content-Length'))
diff --git a/yt_dlp/networking/_requests.py b/yt_dlp/networking/_requests.py
index 06d3220f4a..c7629c7c63 100644
--- a/yt_dlp/networking/_requests.py
+++ b/yt_dlp/networking/_requests.py
@@ -119,17 +119,22 @@ class RequestsResponseAdapter(Response):
self._requests_response = res
+ def _real_read(self, amt: int | None = None) -> bytes:
+ # Work around issue with `.read(amt)` then `.read()`
+ # See: https://github.com/urllib3/urllib3/issues/3636
+ if amt is None:
+ # compat: py3.9: Python 3.9 preallocates the whole read buffer, read in chunks
+ read_chunk = functools.partial(self.fp.read, 1 << 20, decode_content=True)
+ return b''.join(iter(read_chunk, b''))
+ # Interact with urllib3 response directly.
+ return self.fp.read(amt, decode_content=True)
+
def read(self, amt: int | None = None):
try:
- # Work around issue with `.read(amt)` then `.read()`
- # See: https://github.com/urllib3/urllib3/issues/3636
- if amt is None:
- # compat: py3.9: Python 3.9 preallocates the whole read buffer, read in chunks
- read_chunk = functools.partial(self.fp.read, 1 << 20, decode_content=True)
- return b''.join(iter(read_chunk, b''))
- # Interact with urllib3 response directly.
- return self.fp.read(amt, decode_content=True)
-
+ data = self._real_read(amt)
+ if self.fp.closed:
+ self.close()
+ return data
# See urllib3.response.HTTPResponse.read() for exceptions raised on read
except urllib3.exceptions.SSLError as e:
raise SSLError(cause=e) from e
diff --git a/yt_dlp/networking/_urllib.py b/yt_dlp/networking/_urllib.py
index cb7a430bb3..34f2d45d3c 100644
--- a/yt_dlp/networking/_urllib.py
+++ b/yt_dlp/networking/_urllib.py
@@ -306,7 +306,25 @@ class UrllibResponseAdapter(Response):
def read(self, amt=None):
try:
- return self.fp.read(amt)
+ data = self.fp.read(amt)
+ underlying = getattr(self.fp, 'fp', None)
+ if isinstance(self.fp, http.client.HTTPResponse) and underlying is None:
+ # http.client.HTTPResponse automatically closes itself when fully read
+ self.close()
+ elif isinstance(self.fp, urllib.response.addinfourl) and underlying is not None:
+ # urllib's addinfourl does not close the underlying fp automatically when fully read
+ if isinstance(underlying, io.BytesIO):
+ # data URLs or in-memory responses (e.g. gzip/deflate/brotli decoded)
+ if underlying.tell() >= len(underlying.getbuffer()):
+ self.close()
+ elif isinstance(underlying, io.BufferedReader) and amt is None:
+ # file URLs.
+ # XXX: this will not mark the response as closed if it was fully read with amt.
+ self.close()
+ elif underlying is not None and underlying.closed:
+ # Catch-all for any cases where underlying file is closed
+ self.close()
+ return data
except Exception as e:
handle_response_read_exceptions(e)
raise e
diff --git a/yt_dlp/networking/common.py b/yt_dlp/networking/common.py
index 6680d1c7c8..22c002cc44 100644
--- a/yt_dlp/networking/common.py
+++ b/yt_dlp/networking/common.py
@@ -554,12 +554,16 @@ class Response(io.IOBase):
# Expected errors raised here should be of type RequestError or subclasses.
# Subclasses should redefine this method with more precise error handling.
try:
- return self.fp.read(amt)
+ res = self.fp.read(amt)
+ if self.fp.closed:
+ self.close()
+ return res
except Exception as e:
raise TransportError(cause=e) from e
def close(self):
- self.fp.close()
+ if not self.fp.closed:
+ self.fp.close()
return super().close()
def get_header(self, name, default=None):