mirror of
https://github.com/yt-dlp/yt-dlp.git
synced 2025-11-13 13:05:13 +00:00
[networking] Ensure underlying file object is closed when fully read (#14935)
Fixes https://github.com/yt-dlp/yt-dlp/issues/14891 Authored by: coletdjnz
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
# Allow direct execution
|
# Allow direct execution
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -614,8 +615,11 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
|
|||||||
@pytest.mark.skip_handler('CurlCFFI', 'not supported by curl-cffi')
|
@pytest.mark.skip_handler('CurlCFFI', 'not supported by curl-cffi')
|
||||||
def test_gzip_trailing_garbage(self, handler):
|
def test_gzip_trailing_garbage(self, handler):
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
data = validate_and_send(rh, Request(f'http://localhost:{self.http_port}/trailing_garbage')).read().decode()
|
res = validate_and_send(rh, Request(f'http://localhost:{self.http_port}/trailing_garbage'))
|
||||||
|
data = res.read().decode()
|
||||||
assert data == '<html><video src="/vid.mp4" /></html>'
|
assert data == '<html><video src="/vid.mp4" /></html>'
|
||||||
|
# Should auto-close and mark the response adaptor as closed
|
||||||
|
assert res.closed
|
||||||
|
|
||||||
@pytest.mark.skip_handler('CurlCFFI', 'not applicable to curl-cffi')
|
@pytest.mark.skip_handler('CurlCFFI', 'not applicable to curl-cffi')
|
||||||
@pytest.mark.skipif(not brotli, reason='brotli support is not installed')
|
@pytest.mark.skipif(not brotli, reason='brotli support is not installed')
|
||||||
@@ -627,6 +631,8 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
|
|||||||
headers={'ytdl-encoding': 'br'}))
|
headers={'ytdl-encoding': 'br'}))
|
||||||
assert res.headers.get('Content-Encoding') == 'br'
|
assert res.headers.get('Content-Encoding') == 'br'
|
||||||
assert res.read() == b'<html><video src="/vid.mp4" /></html>'
|
assert res.read() == b'<html><video src="/vid.mp4" /></html>'
|
||||||
|
# Should auto-close and mark the response adaptor as closed
|
||||||
|
assert res.closed
|
||||||
|
|
||||||
def test_deflate(self, handler):
|
def test_deflate(self, handler):
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
@@ -636,6 +642,8 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
|
|||||||
headers={'ytdl-encoding': 'deflate'}))
|
headers={'ytdl-encoding': 'deflate'}))
|
||||||
assert res.headers.get('Content-Encoding') == 'deflate'
|
assert res.headers.get('Content-Encoding') == 'deflate'
|
||||||
assert res.read() == b'<html><video src="/vid.mp4" /></html>'
|
assert res.read() == b'<html><video src="/vid.mp4" /></html>'
|
||||||
|
# Should auto-close and mark the response adaptor as closed
|
||||||
|
assert res.closed
|
||||||
|
|
||||||
def test_gzip(self, handler):
|
def test_gzip(self, handler):
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
@@ -645,6 +653,8 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
|
|||||||
headers={'ytdl-encoding': 'gzip'}))
|
headers={'ytdl-encoding': 'gzip'}))
|
||||||
assert res.headers.get('Content-Encoding') == 'gzip'
|
assert res.headers.get('Content-Encoding') == 'gzip'
|
||||||
assert res.read() == b'<html><video src="/vid.mp4" /></html>'
|
assert res.read() == b'<html><video src="/vid.mp4" /></html>'
|
||||||
|
# Should auto-close and mark the response adaptor as closed
|
||||||
|
assert res.closed
|
||||||
|
|
||||||
def test_multiple_encodings(self, handler):
|
def test_multiple_encodings(self, handler):
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
@@ -655,6 +665,8 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
|
|||||||
headers={'ytdl-encoding': pair}))
|
headers={'ytdl-encoding': pair}))
|
||||||
assert res.headers.get('Content-Encoding') == pair
|
assert res.headers.get('Content-Encoding') == pair
|
||||||
assert res.read() == b'<html><video src="/vid.mp4" /></html>'
|
assert res.read() == b'<html><video src="/vid.mp4" /></html>'
|
||||||
|
# Should auto-close and mark the response adaptor as closed
|
||||||
|
assert res.closed
|
||||||
|
|
||||||
@pytest.mark.skip_handler('CurlCFFI', 'not supported by curl-cffi')
|
@pytest.mark.skip_handler('CurlCFFI', 'not supported by curl-cffi')
|
||||||
def test_unsupported_encoding(self, handler):
|
def test_unsupported_encoding(self, handler):
|
||||||
@@ -665,6 +677,8 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
|
|||||||
headers={'ytdl-encoding': 'unsupported', 'Accept-Encoding': '*'}))
|
headers={'ytdl-encoding': 'unsupported', 'Accept-Encoding': '*'}))
|
||||||
assert res.headers.get('Content-Encoding') == 'unsupported'
|
assert res.headers.get('Content-Encoding') == 'unsupported'
|
||||||
assert res.read() == b'raw'
|
assert res.read() == b'raw'
|
||||||
|
# Should auto-close and mark the response adaptor as closed
|
||||||
|
assert res.closed
|
||||||
|
|
||||||
def test_read(self, handler):
|
def test_read(self, handler):
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
@@ -672,9 +686,13 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
|
|||||||
rh, Request(f'http://127.0.0.1:{self.http_port}/headers'))
|
rh, Request(f'http://127.0.0.1:{self.http_port}/headers'))
|
||||||
assert res.readable()
|
assert res.readable()
|
||||||
assert res.read(1) == b'H'
|
assert res.read(1) == b'H'
|
||||||
|
# Ensure we don't close the adaptor yet
|
||||||
|
assert not res.closed
|
||||||
assert res.read(3) == b'ost'
|
assert res.read(3) == b'ost'
|
||||||
assert res.read().decode().endswith('\n\n')
|
assert res.read().decode().endswith('\n\n')
|
||||||
assert res.read() == b''
|
assert res.read() == b''
|
||||||
|
# Should auto-close and mark the response adaptor as closed
|
||||||
|
assert res.closed
|
||||||
|
|
||||||
def test_request_disable_proxy(self, handler):
|
def test_request_disable_proxy(self, handler):
|
||||||
for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['http']:
|
for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['http']:
|
||||||
@@ -875,11 +893,31 @@ class TestUrllibRequestHandler(TestRequestHandlerBase):
|
|||||||
|
|
||||||
with handler(enable_file_urls=True) as rh:
|
with handler(enable_file_urls=True) as rh:
|
||||||
res = validate_and_send(rh, req)
|
res = validate_and_send(rh, req)
|
||||||
assert res.read() == b'foobar'
|
assert res.read(1) == b'f'
|
||||||
res.close()
|
assert not res.fp.closed
|
||||||
|
assert res.read() == b'oobar'
|
||||||
|
# Should automatically close the underlying file object
|
||||||
|
assert res.fp.closed
|
||||||
|
|
||||||
os.unlink(tf.name)
|
os.unlink(tf.name)
|
||||||
|
|
||||||
|
def test_data_uri_auto_close(self, handler):
|
||||||
|
with handler() as rh:
|
||||||
|
res = validate_and_send(rh, Request('data:text/plain,hello%20world'))
|
||||||
|
assert res.read() == b'hello world'
|
||||||
|
# Should automatically close the underlying file object
|
||||||
|
assert res.fp.closed
|
||||||
|
assert res.closed
|
||||||
|
|
||||||
|
def test_http_response_auto_close(self, handler):
|
||||||
|
with handler() as rh:
|
||||||
|
res = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/gen_200'))
|
||||||
|
assert res.read() == b'<html></html>'
|
||||||
|
# Should automatically close the underlying file object in the HTTP Response
|
||||||
|
assert isinstance(res.fp, http.client.HTTPResponse)
|
||||||
|
assert res.fp.fp is None
|
||||||
|
assert res.closed
|
||||||
|
|
||||||
def test_http_error_returns_content(self, handler):
|
def test_http_error_returns_content(self, handler):
|
||||||
# urllib HTTPError will try close the underlying response if reference to the HTTPError object is lost
|
# urllib HTTPError will try close the underlying response if reference to the HTTPError object is lost
|
||||||
def get_response():
|
def get_response():
|
||||||
@@ -1012,6 +1050,14 @@ class TestRequestsRequestHandler(TestRequestHandlerBase):
|
|||||||
rh.close()
|
rh.close()
|
||||||
assert called
|
assert called
|
||||||
|
|
||||||
|
def test_http_response_auto_close(self, handler):
|
||||||
|
with handler() as rh:
|
||||||
|
res = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/gen_200'))
|
||||||
|
assert res.read() == b'<html></html>'
|
||||||
|
# Should automatically close the underlying file object in the HTTP Response
|
||||||
|
assert res.fp.closed
|
||||||
|
assert res.closed
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler', ['CurlCFFI'], indirect=True)
|
@pytest.mark.parametrize('handler', ['CurlCFFI'], indirect=True)
|
||||||
class TestCurlCFFIRequestHandler(TestRequestHandlerBase):
|
class TestCurlCFFIRequestHandler(TestRequestHandlerBase):
|
||||||
@@ -1177,6 +1223,14 @@ class TestCurlCFFIRequestHandler(TestRequestHandlerBase):
|
|||||||
assert res4.closed
|
assert res4.closed
|
||||||
assert res4._buffer == b''
|
assert res4._buffer == b''
|
||||||
|
|
||||||
|
def test_http_response_auto_close(self, handler):
|
||||||
|
with handler() as rh:
|
||||||
|
res = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/gen_200'))
|
||||||
|
assert res.read() == b'<html></html>'
|
||||||
|
# Should automatically close the underlying file object in the HTTP Response
|
||||||
|
assert res.fp.closed
|
||||||
|
assert res.closed
|
||||||
|
|
||||||
|
|
||||||
def run_validation(handler, error, req, **handler_kwargs):
|
def run_validation(handler, error, req, **handler_kwargs):
|
||||||
with handler(**handler_kwargs) as rh:
|
with handler(**handler_kwargs) as rh:
|
||||||
@@ -2032,6 +2086,30 @@ class TestResponse:
|
|||||||
assert res.info() is res.headers
|
assert res.info() is res.headers
|
||||||
assert res.getheader('test') == res.get_header('test')
|
assert res.getheader('test') == res.get_header('test')
|
||||||
|
|
||||||
|
def test_auto_close(self):
|
||||||
|
# Should mark the response as closed if the underlying file is closed
|
||||||
|
class AutoCloseBytesIO(io.BytesIO):
|
||||||
|
def read(self, size=-1, /):
|
||||||
|
data = super().read(size)
|
||||||
|
self.close()
|
||||||
|
return data
|
||||||
|
|
||||||
|
fp = AutoCloseBytesIO(b'test')
|
||||||
|
res = Response(fp, url='test://', headers={}, status=200)
|
||||||
|
assert not res.closed
|
||||||
|
res.read()
|
||||||
|
assert res.closed
|
||||||
|
|
||||||
|
def test_close(self):
|
||||||
|
# Should not call close() on the underlying file when already closed
|
||||||
|
fp = MagicMock()
|
||||||
|
fp.closed = False
|
||||||
|
res = Response(fp, url='test://', headers={}, status=200)
|
||||||
|
res.close()
|
||||||
|
fp.closed = True
|
||||||
|
res.close()
|
||||||
|
assert fp.close.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
class TestImpersonateTarget:
|
class TestImpersonateTarget:
|
||||||
@pytest.mark.parametrize('target_str,expected', [
|
@pytest.mark.parametrize('target_str,expected', [
|
||||||
|
|||||||
@@ -96,7 +96,10 @@ class CurlCFFIResponseAdapter(Response):
|
|||||||
|
|
||||||
def read(self, amt=None):
|
def read(self, amt=None):
|
||||||
try:
|
try:
|
||||||
return self.fp.read(amt)
|
res = self.fp.read(amt)
|
||||||
|
if self.fp.closed:
|
||||||
|
self.close()
|
||||||
|
return res
|
||||||
except curl_cffi.requests.errors.RequestsError as e:
|
except curl_cffi.requests.errors.RequestsError as e:
|
||||||
if e.code == CurlECode.PARTIAL_FILE:
|
if e.code == CurlECode.PARTIAL_FILE:
|
||||||
content_length = e.response and int_or_none(e.response.headers.get('Content-Length'))
|
content_length = e.response and int_or_none(e.response.headers.get('Content-Length'))
|
||||||
|
|||||||
@@ -119,17 +119,22 @@ class RequestsResponseAdapter(Response):
|
|||||||
|
|
||||||
self._requests_response = res
|
self._requests_response = res
|
||||||
|
|
||||||
|
def _real_read(self, amt: int | None = None) -> bytes:
|
||||||
|
# Work around issue with `.read(amt)` then `.read()`
|
||||||
|
# See: https://github.com/urllib3/urllib3/issues/3636
|
||||||
|
if amt is None:
|
||||||
|
# compat: py3.9: Python 3.9 preallocates the whole read buffer, read in chunks
|
||||||
|
read_chunk = functools.partial(self.fp.read, 1 << 20, decode_content=True)
|
||||||
|
return b''.join(iter(read_chunk, b''))
|
||||||
|
# Interact with urllib3 response directly.
|
||||||
|
return self.fp.read(amt, decode_content=True)
|
||||||
|
|
||||||
def read(self, amt: int | None = None):
|
def read(self, amt: int | None = None):
|
||||||
try:
|
try:
|
||||||
# Work around issue with `.read(amt)` then `.read()`
|
data = self._real_read(amt)
|
||||||
# See: https://github.com/urllib3/urllib3/issues/3636
|
if self.fp.closed:
|
||||||
if amt is None:
|
self.close()
|
||||||
# compat: py3.9: Python 3.9 preallocates the whole read buffer, read in chunks
|
return data
|
||||||
read_chunk = functools.partial(self.fp.read, 1 << 20, decode_content=True)
|
|
||||||
return b''.join(iter(read_chunk, b''))
|
|
||||||
# Interact with urllib3 response directly.
|
|
||||||
return self.fp.read(amt, decode_content=True)
|
|
||||||
|
|
||||||
# See urllib3.response.HTTPResponse.read() for exceptions raised on read
|
# See urllib3.response.HTTPResponse.read() for exceptions raised on read
|
||||||
except urllib3.exceptions.SSLError as e:
|
except urllib3.exceptions.SSLError as e:
|
||||||
raise SSLError(cause=e) from e
|
raise SSLError(cause=e) from e
|
||||||
|
|||||||
@@ -306,7 +306,25 @@ class UrllibResponseAdapter(Response):
|
|||||||
|
|
||||||
def read(self, amt=None):
|
def read(self, amt=None):
|
||||||
try:
|
try:
|
||||||
return self.fp.read(amt)
|
data = self.fp.read(amt)
|
||||||
|
underlying = getattr(self.fp, 'fp', None)
|
||||||
|
if isinstance(self.fp, http.client.HTTPResponse) and underlying is None:
|
||||||
|
# http.client.HTTPResponse automatically closes itself when fully read
|
||||||
|
self.close()
|
||||||
|
elif isinstance(self.fp, urllib.response.addinfourl) and underlying is not None:
|
||||||
|
# urllib's addinfourl does not close the underlying fp automatically when fully read
|
||||||
|
if isinstance(underlying, io.BytesIO):
|
||||||
|
# data URLs or in-memory responses (e.g. gzip/deflate/brotli decoded)
|
||||||
|
if underlying.tell() >= len(underlying.getbuffer()):
|
||||||
|
self.close()
|
||||||
|
elif isinstance(underlying, io.BufferedReader) and amt is None:
|
||||||
|
# file URLs.
|
||||||
|
# XXX: this will not mark the response as closed if it was fully read with amt.
|
||||||
|
self.close()
|
||||||
|
elif underlying is not None and underlying.closed:
|
||||||
|
# Catch-all for any cases where underlying file is closed
|
||||||
|
self.close()
|
||||||
|
return data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
handle_response_read_exceptions(e)
|
handle_response_read_exceptions(e)
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
@@ -554,12 +554,16 @@ class Response(io.IOBase):
|
|||||||
# Expected errors raised here should be of type RequestError or subclasses.
|
# Expected errors raised here should be of type RequestError or subclasses.
|
||||||
# Subclasses should redefine this method with more precise error handling.
|
# Subclasses should redefine this method with more precise error handling.
|
||||||
try:
|
try:
|
||||||
return self.fp.read(amt)
|
res = self.fp.read(amt)
|
||||||
|
if self.fp.closed:
|
||||||
|
self.close()
|
||||||
|
return res
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise TransportError(cause=e) from e
|
raise TransportError(cause=e) from e
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.fp.close()
|
if not self.fp.closed:
|
||||||
|
self.fp.close()
|
||||||
return super().close()
|
return super().close()
|
||||||
|
|
||||||
def get_header(self, name, default=None):
|
def get_header(self, name, default=None):
|
||||||
|
|||||||
Reference in New Issue
Block a user