Mirror of https://github.com/yt-dlp/yt-dlp.git (synced 2025-10-31 14:45:14 +00:00)
			
		
		
		
	
		
			
				
	
	
		
			506 lines
		
	
	
		
			22 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			506 lines
		
	
	
		
			22 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | |
| 
 | |
| # Allow direct execution
 | |
| import os
 | |
| import sys
 | |
| import time
 | |
| 
 | |
| import pytest
 | |
| 
 | |
| from test.helper import verify_address_availability
 | |
| from yt_dlp.networking.common import Features, DEFAULT_TIMEOUT
 | |
| 
 | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 | |
| 
 | |
| import http.client
 | |
| import http.cookiejar
 | |
| import http.server
 | |
| import json
 | |
| import random
 | |
| import ssl
 | |
| import threading
 | |
| 
 | |
| from yt_dlp import socks, traverse_obj
 | |
| from yt_dlp.cookies import YoutubeDLCookieJar
 | |
| from yt_dlp.dependencies import websockets
 | |
| from yt_dlp.networking import Request
 | |
| from yt_dlp.networking.exceptions import (
 | |
|     CertificateVerifyError,
 | |
|     HTTPError,
 | |
|     ProxyError,
 | |
|     RequestError,
 | |
|     SSLError,
 | |
|     TransportError,
 | |
| )
 | |
| from yt_dlp.utils.networking import HTTPHeaderDict
 | |
| 
 | |
# Absolute directory of this test module; certificates and test data are resolved relative to it.
TEST_DIR = os.path.split(os.path.abspath(__file__))[0]
 | |
| 
 | |
| 
 | |
def websocket_handler(websocket):
    """Reply to the first message received on *websocket*, then return.

    Control messages get special replies:
      * ``b'bytes'`` -> ``'2'`` and ``'str'`` -> ``'1'`` (the frame opcode used),
      * ``'headers'`` -> JSON object of the raw handshake request headers,
      * ``'path'`` -> the handshake request path,
      * ``'source_address'`` -> the client's IP address.
    Any other message is echoed back unchanged.
    """
    # Replies are built lazily so only the requested one touches the connection state
    text_replies = {
        'headers': lambda: json.dumps(dict(websocket.request.headers.raw_items())),
        'path': lambda: websocket.request.path,
        'source_address': lambda: websocket.remote_address[0],
        'str': lambda: '1',
    }
    for message in websocket:
        if isinstance(message, bytes):
            reply = '2' if message == b'bytes' else message
        elif isinstance(message, str) and message in text_replies:
            reply = text_replies[message]()
        else:
            reply = message
        # Only the first message is ever answered
        return websocket.send(reply)
 | |
| 
 | |
| 
 | |
def process_request(self, request):
    """Handshake hook for the test server: choose the HTTP response by request path.

    *self* is the server-side connection; its ``protocol`` provides ``accept``/``reject``.

    - ``/gen_<code>``: refuse the upgrade with HTTP status ``<code>``. Redirect (3xx)
      statuses are sent with a ``Location: /`` header. ``/`` upgrades normally, so a
      client that followed the redirect would succeed instead of reporting the status.
    - ``/get_cookie``: upgrade normally, adding ``Set-Cookie: test=ytdlp``.
    - anything else: upgrade normally.
    """
    if request.path.startswith('/gen_'):
        status = http.HTTPStatus(int(request.path[5:]))
        # Half-open range so every redirect status (the tests use 301/302/303) gets a
        # Location header; a closed `300 <= value <= 300` range would only match 300.
        if 300 <= status.value < 400:
            return websockets.http11.Response(
                status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
        return self.protocol.reject(status.value, status.phrase)
    if request.path.startswith('/get_cookie'):
        response = self.protocol.accept(request)
        response.headers['Set-Cookie'] = 'test=ytdlp'
        return response
    return self.protocol.accept(request)
 | |
| 
 | |
| 
 | |
def create_websocket_server(**ws_kwargs):
    """Start a sync websockets test server on an ephemeral port of 127.0.0.1.

    Extra keyword arguments (e.g. ``ssl=``) are forwarded to
    ``websockets.sync.server.serve``. ``serve_forever`` runs in a daemon thread.

    Returns a ``(thread, port)`` tuple.
    """
    import websockets.sync.server

    server = websockets.sync.server.serve(
        websocket_handler, '127.0.0.1', 0,
        process_request=process_request, open_timeout=2, **ws_kwargs)
    # Port 0 asks the OS for a free port; read back the one actually bound
    port = server.socket.getsockname()[1]
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    return thread, port
 | |
| 
 | |
| 
 | |
def create_ws_websocket_server():
    """Start a plain (unencrypted) ws:// test server; returns ``(thread, port)``."""
    thread, port = create_websocket_server()
    return thread, port
 | |
| 
 | |
| 
 | |
def create_wss_websocket_server():
    """Start a wss:// test server presenting ``testcert.pem``; returns ``(thread, port)``."""
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl=context)
 | |
| 
 | |
| 
 | |
# Certificates (CA, client certs and keys) used by the mutual-TLS tests.
MTLS_CERT_DIR = os.path.join(os.path.join(TEST_DIR, 'testdata'), 'certificate')
 | |
| 
 | |
| 
 | |
def create_mtls_wss_websocket_server():
    """Start a wss:// test server that requires a client certificate (mutual TLS).

    Clients must present a certificate that verifies against ``ca.crt`` in
    ``MTLS_CERT_DIR``; the server itself presents ``testcert.pem``.
    Returns ``(thread, port)``.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.verify_mode = ssl.CERT_REQUIRED
    context.load_verify_locations(cafile=os.path.join(MTLS_CERT_DIR, 'ca.crt'))
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl=context)
 | |
| 
 | |
| 
 | |
def create_legacy_wss_websocket_server():
    """Start a wss:// test server limited to TLS 1.2 and a legacy cipher list.

    Used to check that clients only connect when legacy SSL support is enabled.
    Returns ``(thread, port)``.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.maximum_version = ssl.TLSVersion.TLSv1_2
    context.set_ciphers('SHA1:AESCCM:aDSS:eNULL:aNULL')
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl=context)
 | |
| 
 | |
| 
 | |
def ws_validate_and_send(rh, req):
    """Validate *req* with request handler *rh*, then send it, retrying flaky handshakes.

    Sending is attempted up to three times. A ``TransportError`` whose message
    contains ``'connection closed during handshake'`` is retried, because the
    websockets test server sometimes hangs on new connections. Any other error,
    or a failure on the final attempt, propagates to the caller.
    """
    rh.validate(req)
    attempts = 3
    for attempt in range(1, attempts + 1):
        try:
            return rh.send(req)
        except TransportError as err:
            is_last = attempt == attempts
            if is_last or 'connection closed during handshake' not in str(err):
                raise
 | |
| 
 | |
| 
 | |
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsSocketRequestHandlerConformance:
    """Conformance tests for websocket request handlers against live local test servers.

    Tests close each websocket they open once they are done with it, so connections
    do not linger on the shared servers started in ``setup_class``.
    """

    @classmethod
    def setup_class(cls):
        """Start one local test server per TLS configuration exercised below."""
        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'

        # TLS server context with no certificate loaded, so the TLS handshake fails
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
        cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'

        cls.legacy_wss_thread, cls.legacy_wss_port = create_legacy_wss_websocket_server()
        cls.legacy_wss_host = f'wss://127.0.0.1:{cls.legacy_wss_port}'

    def test_basic_websockets(self, handler):
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            assert 'upgrade' in ws.headers
            assert ws.status == 101
            ws.send('foo')
            assert ws.recv() == 'foo'
            ws.close()

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    def test_send_types(self, handler, msg, opcode):
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send(msg)
            assert int(ws.recv()) == opcode
            ws.close()

    def test_verify_cert(self, handler):
        with handler() as rh:
            with pytest.raises(CertificateVerifyError):
                ws_validate_and_send(rh, Request(self.wss_base_url))

        with handler(verify=False) as rh:
            ws = ws_validate_and_send(rh, Request(self.wss_base_url))
            assert ws.status == 101
            ws.close()

    def test_ssl_error(self, handler):
        with handler(verify=False) as rh:
            with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
                ws_validate_and_send(rh, Request(self.bad_wss_host))
            assert not issubclass(exc_info.type, CertificateVerifyError)

    def test_legacy_ssl_extension(self, handler):
        with handler(verify=False) as rh:
            ws = ws_validate_and_send(rh, Request(self.legacy_wss_host, extensions={'legacy_ssl': True}))
            assert ws.status == 101
            ws.close()

            # Ensure only applies to request extension
            with pytest.raises(SSLError):
                ws_validate_and_send(rh, Request(self.legacy_wss_host))

    def test_legacy_ssl_support(self, handler):
        with handler(verify=False, legacy_ssl_support=True) as rh:
            ws = ws_validate_and_send(rh, Request(self.legacy_wss_host))
            assert ws.status == 101
            ws.close()

    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
            ws.send('path')
            assert ws.recv() == expected
            assert ws.status == 101
            ws.close()

    def test_remove_dot_segments(self, handler):
        with handler() as rh:
            # This isn't a comprehensive test,
            # but it should be enough to check whether the handler is removing dot segments
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
            assert ws.status == 101
            ws.send('path')
            assert ws.recv() == '/test'
            ws.close()

    # We are restricted to known HTTP status codes in http.HTTPStatus
    # Redirects are not supported for websockets
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        with handler() as rh:
            with pytest.raises(HTTPError) as exc_info:
                ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
            assert exc_info.value.status == status

    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': sys.float_info.min}, {}),
        ({}, {'timeout': sys.float_info.min}),
    ])
    def test_read_timeout(self, handler, params, extensions):
        with handler(**params) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))

    def test_connect_timeout(self, handler):
        # nothing should be listening on this port
        connect_timeout_url = 'ws://10.255.255.255'
        with handler(timeout=0.01) as rh, pytest.raises(TransportError):
            now = time.time()
            ws_validate_and_send(rh, Request(connect_timeout_url))
        assert time.time() - now < DEFAULT_TIMEOUT

        # Per request timeout, should override handler timeout
        request = Request(connect_timeout_url, extensions={'timeout': 0.01})
        with handler() as rh, pytest.raises(TransportError):
            now = time.time()
            ws_validate_and_send(rh, request)
        assert time.time() - now < DEFAULT_TIMEOUT

    def test_cookies(self, handler):
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={}))

        with handler(cookiejar=cookiejar) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp'
            ws.close()

        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv()))
            ws.close()

            ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
            ws.send('headers')
            assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp'
            ws.close()

    @pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
    def test_cookie_sync_only_cookiejar(self, handler):
        # Ensure that cookies are ONLY being handled by the cookiejar
        with handler() as rh:
            # Close the /get_cookie connection too; only its Set-Cookie side effect matters
            ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie', extensions={'cookiejar': YoutubeDLCookieJar()})).close()
            ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': YoutubeDLCookieJar()}))
            ws.send('headers')
            assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv()))
            ws.close()

    @pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
    def test_cookie_sync_delete_cookie(self, handler):
        # Ensure that a cookie removed from the cookiejar is no longer sent
        cookiejar = YoutubeDLCookieJar()
        with handler(verbose=True, cookiejar=cookiejar) as rh:
            ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie')).close()
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp'
            ws.close()
            cookiejar.clear_session_cookies()
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv()))
            ws.close()

    def test_source_address(self, handler):
        source_address = f'127.0.0.{random.randint(5, 255)}'
        verify_address_availability(source_address)
        with handler(source_address=source_address) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('source_address')
            assert source_address == ws.recv()
            ws.close()

    def test_response_url(self, handler):
        with handler() as rh:
            url = f'{self.ws_base_url}/something'
            ws = ws_validate_and_send(rh, Request(url))
            assert ws.url == url
            ws.close()

    def test_request_headers(self, handler):
        with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
            # Global Headers
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            ws.close()

            # Per request headers, merged with global
            ws = ws_validate_and_send(rh, Request(
                self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            assert headers['test2'] == 'changed'
            assert headers['test3'] == 'test3'
            ws.close()

    def test_keep_header_casing(self, handler):
        with handler(headers=HTTPHeaderDict({'x-TeSt1': 'test'})) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url, headers={'x-TeSt2': 'test'}, extensions={'keep_header_casing': True}))
            ws.send('headers')
            # Plain dict (not HTTPHeaderDict) so the key lookups below are case-sensitive
            headers = json.loads(ws.recv())
            assert 'x-TeSt1' in headers
            assert 'x-TeSt2' in headers
            ws.close()

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        },
    ))
    def test_mtls(self, handler, client_cert):
        with handler(
            # Disable client-side validation of unacceptable self-signed testcert.pem
            # The test is of a check on the server side, so unaffected
            verify=False,
            client_cert=client_cert,
        ) as rh:
            ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()

    def test_request_disable_proxy(self, handler):
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                # When a proxy is explicitly set to None for the request.
                # The key must match the request URL scheme (ws), which is the one the handler proxy uses.
                ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'ws': None}))
                # Then no proxy should be used
                assert ws.status == 101
                ws.close()

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.NO_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support NO_PROXY')
    def test_noproxy(self, handler):
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given the handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                for no_proxy in (f'127.0.0.1:{self.ws_port}', '127.0.0.1', 'localhost'):
                    # When request no proxy includes the request url host
                    ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'no': no_proxy}))
                    # Then the proxy should not be used
                    assert ws.status == 101
                    ws.close()

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.ALL_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support ALL_PROXY')
    def test_allproxy(self, handler):
        supported_proto = traverse_obj(handler._SUPPORTED_PROXY_SCHEMES, 0, default='ws')
        # This is a bit of a hacky test, but it should be enough to check whether the handler is using the proxy.
        # 0.1s might not be enough of a timeout if proxy is not used in all cases, but should still get failures.
        with handler(proxies={'all': f'{supported_proto}://10.255.255.255'}, timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url)).close()

        with handler(timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(
                    rh, Request(self.ws_base_url, proxies={'all': f'{supported_proto}://10.255.255.255'})).close()
 | |
| 
 | |
| 
 | |
def create_fake_ws_connection(raised):
    """Build a fake websockets ``ClientConnection`` for error-mapping tests.

    The fake bypasses the real constructor and exposes a canned handshake response
    with status 101. Its ``send`` and ``recv`` raise ``raised()``; ``close`` does nothing.
    """
    import websockets.sync.client

    class FakeResponse:
        body = b''
        headers = {}
        status_code = 101
        reason_phrase = 'test'

    def _raise(self, *args, **kwargs):
        raise raised()

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        def __init__(self, *args, **kwargs):
            # Deliberately skip ClientConnection.__init__: no real socket/protocol is set up
            self.response = FakeResponse()

        send = _raise
        recv = _raise

        def close(self, *args, **kwargs):
            return

    return FakeWsConnection()
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Check that the Websockets handler maps library/OS exceptions onto yt-dlp networking errors."""

    @staticmethod
    def _fake_adapter(raised):
        """Wrap a fake connection (whose send/recv raise ``raised()``) in the response adapter."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter
        return WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')

    @pytest.mark.parametrize(('raised', 'expected'), [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        import websockets.sync.client

        import yt_dlp.networking._websockets

        def fake_connect(*args, **kwargs):
            raise raised()

        with handler() as rh:
            # Skip real socket setup and make the websockets handshake raise instead
            monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
            monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
            with pytest.raises(expected) as excinfo:
                rh.send(Request('ws://fake-url'))
            # Exact type required, not just a subclass of the expected error
            assert excinfo.type is expected

    @pytest.mark.parametrize(('raised', 'expected', 'match'), [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        adapter = self._fake_adapter(raised)
        with pytest.raises(expected, match=match) as excinfo:
            adapter.send('test')
        assert excinfo.type is expected

    @pytest.mark.parametrize(('raised', 'expected', 'match'), [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        adapter = self._fake_adapter(raised)
        with pytest.raises(expected, match=match) as excinfo:
            adapter.recv()
        assert excinfo.type is expected
 | 
