mirror of
https://github.com/yt-dlp/yt-dlp.git
synced 2025-11-13 04:55:13 +00:00
[networking] Ensure underlying file object is closed when fully read (#14935)
Fixes https://github.com/yt-dlp/yt-dlp/issues/14891 Authored by: coletdjnz
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
# Allow direct execution
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -614,8 +615,11 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
|
||||
@pytest.mark.skip_handler('CurlCFFI', 'not supported by curl-cffi')
|
||||
def test_gzip_trailing_garbage(self, handler):
|
||||
with handler() as rh:
|
||||
data = validate_and_send(rh, Request(f'http://localhost:{self.http_port}/trailing_garbage')).read().decode()
|
||||
res = validate_and_send(rh, Request(f'http://localhost:{self.http_port}/trailing_garbage'))
|
||||
data = res.read().decode()
|
||||
assert data == '<html><video src="/vid.mp4" /></html>'
|
||||
# Should auto-close and mark the response adaptor as closed
|
||||
assert res.closed
|
||||
|
||||
@pytest.mark.skip_handler('CurlCFFI', 'not applicable to curl-cffi')
|
||||
@pytest.mark.skipif(not brotli, reason='brotli support is not installed')
|
||||
@@ -627,6 +631,8 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
|
||||
headers={'ytdl-encoding': 'br'}))
|
||||
assert res.headers.get('Content-Encoding') == 'br'
|
||||
assert res.read() == b'<html><video src="/vid.mp4" /></html>'
|
||||
# Should auto-close and mark the response adaptor as closed
|
||||
assert res.closed
|
||||
|
||||
def test_deflate(self, handler):
|
||||
with handler() as rh:
|
||||
@@ -636,6 +642,8 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
|
||||
headers={'ytdl-encoding': 'deflate'}))
|
||||
assert res.headers.get('Content-Encoding') == 'deflate'
|
||||
assert res.read() == b'<html><video src="/vid.mp4" /></html>'
|
||||
# Should auto-close and mark the response adaptor as closed
|
||||
assert res.closed
|
||||
|
||||
def test_gzip(self, handler):
|
||||
with handler() as rh:
|
||||
@@ -645,6 +653,8 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
|
||||
headers={'ytdl-encoding': 'gzip'}))
|
||||
assert res.headers.get('Content-Encoding') == 'gzip'
|
||||
assert res.read() == b'<html><video src="/vid.mp4" /></html>'
|
||||
# Should auto-close and mark the response adaptor as closed
|
||||
assert res.closed
|
||||
|
||||
def test_multiple_encodings(self, handler):
|
||||
with handler() as rh:
|
||||
@@ -655,6 +665,8 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
|
||||
headers={'ytdl-encoding': pair}))
|
||||
assert res.headers.get('Content-Encoding') == pair
|
||||
assert res.read() == b'<html><video src="/vid.mp4" /></html>'
|
||||
# Should auto-close and mark the response adaptor as closed
|
||||
assert res.closed
|
||||
|
||||
@pytest.mark.skip_handler('CurlCFFI', 'not supported by curl-cffi')
|
||||
def test_unsupported_encoding(self, handler):
|
||||
@@ -665,6 +677,8 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
|
||||
headers={'ytdl-encoding': 'unsupported', 'Accept-Encoding': '*'}))
|
||||
assert res.headers.get('Content-Encoding') == 'unsupported'
|
||||
assert res.read() == b'raw'
|
||||
# Should auto-close and mark the response adaptor as closed
|
||||
assert res.closed
|
||||
|
||||
def test_read(self, handler):
|
||||
with handler() as rh:
|
||||
@@ -672,9 +686,13 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
|
||||
rh, Request(f'http://127.0.0.1:{self.http_port}/headers'))
|
||||
assert res.readable()
|
||||
assert res.read(1) == b'H'
|
||||
# Ensure we don't close the adaptor yet
|
||||
assert not res.closed
|
||||
assert res.read(3) == b'ost'
|
||||
assert res.read().decode().endswith('\n\n')
|
||||
assert res.read() == b''
|
||||
# Should auto-close and mark the response adaptor as closed
|
||||
assert res.closed
|
||||
|
||||
def test_request_disable_proxy(self, handler):
|
||||
for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['http']:
|
||||
@@ -875,11 +893,31 @@ class TestUrllibRequestHandler(TestRequestHandlerBase):
|
||||
|
||||
with handler(enable_file_urls=True) as rh:
|
||||
res = validate_and_send(rh, req)
|
||||
assert res.read() == b'foobar'
|
||||
res.close()
|
||||
assert res.read(1) == b'f'
|
||||
assert not res.fp.closed
|
||||
assert res.read() == b'oobar'
|
||||
# Should automatically close the underlying file object
|
||||
assert res.fp.closed
|
||||
|
||||
os.unlink(tf.name)
|
||||
|
||||
def test_data_uri_auto_close(self, handler):
|
||||
with handler() as rh:
|
||||
res = validate_and_send(rh, Request('data:text/plain,hello%20world'))
|
||||
assert res.read() == b'hello world'
|
||||
# Should automatically close the underlying file object
|
||||
assert res.fp.closed
|
||||
assert res.closed
|
||||
|
||||
def test_http_response_auto_close(self, handler):
|
||||
with handler() as rh:
|
||||
res = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/gen_200'))
|
||||
assert res.read() == b'<html></html>'
|
||||
# Should automatically close the underlying file object in the HTTP Response
|
||||
assert isinstance(res.fp, http.client.HTTPResponse)
|
||||
assert res.fp.fp is None
|
||||
assert res.closed
|
||||
|
||||
def test_http_error_returns_content(self, handler):
|
||||
# urllib HTTPError will try close the underlying response if reference to the HTTPError object is lost
|
||||
def get_response():
|
||||
@@ -1012,6 +1050,14 @@ class TestRequestsRequestHandler(TestRequestHandlerBase):
|
||||
rh.close()
|
||||
assert called
|
||||
|
||||
def test_http_response_auto_close(self, handler):
|
||||
with handler() as rh:
|
||||
res = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/gen_200'))
|
||||
assert res.read() == b'<html></html>'
|
||||
# Should automatically close the underlying file object in the HTTP Response
|
||||
assert res.fp.closed
|
||||
assert res.closed
|
||||
|
||||
|
||||
@pytest.mark.parametrize('handler', ['CurlCFFI'], indirect=True)
|
||||
class TestCurlCFFIRequestHandler(TestRequestHandlerBase):
|
||||
@@ -1177,6 +1223,14 @@ class TestCurlCFFIRequestHandler(TestRequestHandlerBase):
|
||||
assert res4.closed
|
||||
assert res4._buffer == b''
|
||||
|
||||
def test_http_response_auto_close(self, handler):
|
||||
with handler() as rh:
|
||||
res = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/gen_200'))
|
||||
assert res.read() == b'<html></html>'
|
||||
# Should automatically close the underlying file object in the HTTP Response
|
||||
assert res.fp.closed
|
||||
assert res.closed
|
||||
|
||||
|
||||
def run_validation(handler, error, req, **handler_kwargs):
|
||||
with handler(**handler_kwargs) as rh:
|
||||
@@ -2032,6 +2086,30 @@ class TestResponse:
|
||||
assert res.info() is res.headers
|
||||
assert res.getheader('test') == res.get_header('test')
|
||||
|
||||
def test_auto_close(self):
|
||||
# Should mark the response as closed if the underlying file is closed
|
||||
class AutoCloseBytesIO(io.BytesIO):
|
||||
def read(self, size=-1, /):
|
||||
data = super().read(size)
|
||||
self.close()
|
||||
return data
|
||||
|
||||
fp = AutoCloseBytesIO(b'test')
|
||||
res = Response(fp, url='test://', headers={}, status=200)
|
||||
assert not res.closed
|
||||
res.read()
|
||||
assert res.closed
|
||||
|
||||
def test_close(self):
|
||||
# Should not call close() on the underlying file when already closed
|
||||
fp = MagicMock()
|
||||
fp.closed = False
|
||||
res = Response(fp, url='test://', headers={}, status=200)
|
||||
res.close()
|
||||
fp.closed = True
|
||||
res.close()
|
||||
assert fp.close.call_count == 1
|
||||
|
||||
|
||||
class TestImpersonateTarget:
|
||||
@pytest.mark.parametrize('target_str,expected', [
|
||||
|
||||
@@ -96,7 +96,10 @@ class CurlCFFIResponseAdapter(Response):
|
||||
|
||||
def read(self, amt=None):
|
||||
try:
|
||||
return self.fp.read(amt)
|
||||
res = self.fp.read(amt)
|
||||
if self.fp.closed:
|
||||
self.close()
|
||||
return res
|
||||
except curl_cffi.requests.errors.RequestsError as e:
|
||||
if e.code == CurlECode.PARTIAL_FILE:
|
||||
content_length = e.response and int_or_none(e.response.headers.get('Content-Length'))
|
||||
|
||||
@@ -119,17 +119,22 @@ class RequestsResponseAdapter(Response):
|
||||
|
||||
self._requests_response = res
|
||||
|
||||
def _real_read(self, amt: int | None = None) -> bytes:
|
||||
# Work around issue with `.read(amt)` then `.read()`
|
||||
# See: https://github.com/urllib3/urllib3/issues/3636
|
||||
if amt is None:
|
||||
# compat: py3.9: Python 3.9 preallocates the whole read buffer, read in chunks
|
||||
read_chunk = functools.partial(self.fp.read, 1 << 20, decode_content=True)
|
||||
return b''.join(iter(read_chunk, b''))
|
||||
# Interact with urllib3 response directly.
|
||||
return self.fp.read(amt, decode_content=True)
|
||||
|
||||
def read(self, amt: int | None = None):
|
||||
try:
|
||||
# Work around issue with `.read(amt)` then `.read()`
|
||||
# See: https://github.com/urllib3/urllib3/issues/3636
|
||||
if amt is None:
|
||||
# compat: py3.9: Python 3.9 preallocates the whole read buffer, read in chunks
|
||||
read_chunk = functools.partial(self.fp.read, 1 << 20, decode_content=True)
|
||||
return b''.join(iter(read_chunk, b''))
|
||||
# Interact with urllib3 response directly.
|
||||
return self.fp.read(amt, decode_content=True)
|
||||
|
||||
data = self._real_read(amt)
|
||||
if self.fp.closed:
|
||||
self.close()
|
||||
return data
|
||||
# See urllib3.response.HTTPResponse.read() for exceptions raised on read
|
||||
except urllib3.exceptions.SSLError as e:
|
||||
raise SSLError(cause=e) from e
|
||||
|
||||
@@ -306,7 +306,25 @@ class UrllibResponseAdapter(Response):
|
||||
|
||||
def read(self, amt=None):
|
||||
try:
|
||||
return self.fp.read(amt)
|
||||
data = self.fp.read(amt)
|
||||
underlying = getattr(self.fp, 'fp', None)
|
||||
if isinstance(self.fp, http.client.HTTPResponse) and underlying is None:
|
||||
# http.client.HTTPResponse automatically closes itself when fully read
|
||||
self.close()
|
||||
elif isinstance(self.fp, urllib.response.addinfourl) and underlying is not None:
|
||||
# urllib's addinfourl does not close the underlying fp automatically when fully read
|
||||
if isinstance(underlying, io.BytesIO):
|
||||
# data URLs or in-memory responses (e.g. gzip/deflate/brotli decoded)
|
||||
if underlying.tell() >= len(underlying.getbuffer()):
|
||||
self.close()
|
||||
elif isinstance(underlying, io.BufferedReader) and amt is None:
|
||||
# file URLs.
|
||||
# XXX: this will not mark the response as closed if it was fully read with amt.
|
||||
self.close()
|
||||
elif underlying is not None and underlying.closed:
|
||||
# Catch-all for any cases where underlying file is closed
|
||||
self.close()
|
||||
return data
|
||||
except Exception as e:
|
||||
handle_response_read_exceptions(e)
|
||||
raise e
|
||||
|
||||
@@ -554,12 +554,16 @@ class Response(io.IOBase):
|
||||
# Expected errors raised here should be of type RequestError or subclasses.
|
||||
# Subclasses should redefine this method with more precise error handling.
|
||||
try:
|
||||
return self.fp.read(amt)
|
||||
res = self.fp.read(amt)
|
||||
if self.fp.closed:
|
||||
self.close()
|
||||
return res
|
||||
except Exception as e:
|
||||
raise TransportError(cause=e) from e
|
||||
|
||||
def close(self):
|
||||
self.fp.close()
|
||||
if not self.fp.closed:
|
||||
self.fp.close()
|
||||
return super().close()
|
||||
|
||||
def get_header(self, name, default=None):
|
||||
|
||||
Reference in New Issue
Block a user