mirror of
				https://github.com/yt-dlp/yt-dlp.git
				synced 2025-11-04 00:25:15 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			498 lines
		
	
	
		
			21 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			498 lines
		
	
	
		
			21 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#!/usr/bin/env python3
 | 
						|
 | 
						|
# Allow direct execution
 | 
						|
import os
 | 
						|
import sys
 | 
						|
import time
 | 
						|
 | 
						|
import pytest
 | 
						|
 | 
						|
from test.helper import verify_address_availability
 | 
						|
from yt_dlp.networking.common import Features, DEFAULT_TIMEOUT
 | 
						|
 | 
						|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 | 
						|
 | 
						|
import http.client
 | 
						|
import http.cookiejar
 | 
						|
import http.server
 | 
						|
import json
 | 
						|
import random
 | 
						|
import ssl
 | 
						|
import threading
 | 
						|
 | 
						|
from yt_dlp import socks, traverse_obj
 | 
						|
from yt_dlp.cookies import YoutubeDLCookieJar
 | 
						|
from yt_dlp.dependencies import websockets
 | 
						|
from yt_dlp.networking import Request
 | 
						|
from yt_dlp.networking.exceptions import (
 | 
						|
    CertificateVerifyError,
 | 
						|
    HTTPError,
 | 
						|
    ProxyError,
 | 
						|
    RequestError,
 | 
						|
    SSLError,
 | 
						|
    TransportError,
 | 
						|
)
 | 
						|
from yt_dlp.utils.networking import HTTPHeaderDict
 | 
						|
 | 
						|
# Absolute path of the directory that contains this test module
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
 | 
						|
 | 
						|
 | 
						|
def websocket_handler(websocket):
    """Connection handler of the test server: answer the first message, then return.

    A handful of magic messages get a special reply; anything else is echoed back:
      b'bytes'          -> '2'
      'str'             -> '1'
      'headers'         -> JSON dump of the handshake request headers
      'path'            -> path of the handshake request
      'source_address'  -> first element of the connection's remote address

    Returns whatever ``websocket.send`` returns, or None if no message arrives.
    """
    for incoming in websocket:
        # Echo by default; overridden below for the recognised commands
        reply = incoming
        if isinstance(incoming, bytes):
            if incoming == b'bytes':
                reply = '2'
        elif isinstance(incoming, str):
            if incoming == 'headers':
                reply = json.dumps(dict(websocket.request.headers))
            elif incoming == 'path':
                reply = websocket.request.path
            elif incoming == 'source_address':
                reply = websocket.remote_address[0]
            elif incoming == 'str':
                reply = '1'
        # Only a single message is ever handled per connection
        return websocket.send(reply)
 | 
						|
 | 
						|
 | 
						|
def process_request(self, request):
    """Handshake hook of the test server; decides how to answer the opening request.

    - ``/gen_<code>``: answer with HTTP status ``<code>`` instead of upgrading.
      ``<code>`` must be a value known to ``http.HTTPStatus``.
    - ``/get_cookie...``: accept the handshake and add a ``Set-Cookie`` header.
    - anything else: accept the handshake unchanged.
    """
    path = request.path

    if path.startswith('/gen_'):
        status = http.HTTPStatus(int(path[5:]))
        code, phrase = status.value, status.phrase
        # NOTE(review): this range only matches status 300; presumably the whole
        # 3xx range was intended - confirm before widening, as it changes how
        # 301/302/303 responses are generated.
        if not 300 <= code <= 300:
            return self.protocol.reject(code, phrase)
        location = websockets.datastructures.Headers([('Location', '/')])
        return websockets.http11.Response(code, phrase, location, b'')

    response = self.protocol.accept(request)
    if path.startswith('/get_cookie'):
        response.headers['Set-Cookie'] = 'test=ytdlp'
    return response
 | 
						|
 | 
						|
 | 
						|
def create_websocket_server(**ws_kwargs):
    """Start a websockets sync server on 127.0.0.1 in a daemon thread.

    The server binds to port 0, so the actual port is read back from the socket.
    Extra keyword arguments are forwarded to ``websockets.sync.server.serve``.

    Returns a ``(thread, port)`` tuple.
    """
    import websockets.sync.server

    server = websockets.sync.server.serve(
        websocket_handler, '127.0.0.1', 0,
        process_request=process_request, open_timeout=2, **ws_kwargs)
    port = server.socket.getsockname()[1]

    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    return thread, port
 | 
						|
 | 
						|
 | 
						|
def create_ws_websocket_server():
    """Start a plain (unencrypted) ``ws://`` test server; returns ``(thread, port)``."""
    server_info = create_websocket_server()
    return server_info
 | 
						|
 | 
						|
 | 
						|
def create_wss_websocket_server():
    """Start a ``wss://`` test server that presents ``testcert.pem``; returns ``(thread, port)``."""
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl=context)
 | 
						|
 | 
						|
 | 
						|
# Directory under TEST_DIR holding the CA and client certificate/key files used by the mTLS server and test_mtls
MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
 | 
						|
 | 
						|
 | 
						|
def create_mtls_wss_websocket_server():
    """Start a ``wss://`` test server that requires a client certificate.

    Client certificates are verified against ``ca.crt`` from MTLS_CERT_DIR;
    the server itself presents ``testcert.pem``. Returns ``(thread, port)``.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # Reject clients that do not present a certificate signed by our test CA
    context.verify_mode = ssl.CERT_REQUIRED
    context.load_verify_locations(cafile=os.path.join(MTLS_CERT_DIR, 'ca.crt'))
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl=context)
 | 
						|
 | 
						|
 | 
						|
def create_legacy_wss_websocket_server():
    """Start a ``wss://`` test server restricted to TLS <= 1.2 and a legacy cipher list.

    Used to exercise the handler's legacy SSL options. Returns ``(thread, port)``.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.maximum_version = ssl.TLSVersion.TLSv1_2
    context.set_ciphers('SHA1:AESCCM:aDSS:eNULL:aNULL')
    context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl=context)
 | 
						|
 | 
						|
 | 
						|
def ws_validate_and_send(rh, req):
    """Validate *req* with handler *rh*, then send it, retrying flaky handshakes.

    The request is validated once. Sending is attempted up to three times; a
    retry only happens when a TransportError mentioning
    'connection closed during handshake' is raised. Any other error, or the
    same error on the final attempt, propagates to the caller.

    Returns the response object produced by ``rh.send``.
    """
    rh.validate(req)
    retries_left = 3
    while True:
        retries_left -= 1
        try:
            return rh.send(req)
        except TransportError as err:
            # websockets server sometimes hangs on new connections
            if retries_left and 'connection closed during handshake' in str(err):
                continue
            raise
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsSocketRequestHandlerConformance:
    """Behavioural tests for websocket request handlers against local test servers.

    ``setup_class`` starts one server per scenario (plain ws, wss, wss with a
    certificate-less SSL context, mutual TLS and legacy TLS); each test connects
    to them through the parametrized ``handler`` fixture.
    """

    @classmethod
    def setup_class(cls):
        """Start all test servers and record their base URLs on the class."""
        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'

        # SSL context without any certificate loaded, so handshakes against it fail
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
        cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'

        cls.legacy_wss_thread, cls.legacy_wss_port = create_legacy_wss_websocket_server()
        cls.legacy_wss_host = f'wss://127.0.0.1:{cls.legacy_wss_port}'

    def test_basic_websockets(self, handler):
        """A plain connection upgrades (101) and echoes a text message."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            assert 'upgrade' in ws.headers
            assert ws.status == 101
            ws.send('foo')
            assert ws.recv() == 'foo'
            ws.close()

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    def test_send_types(self, handler, msg, opcode):
        """str and bytes payloads reach the server as str and bytes respectively."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send(msg)
            assert int(ws.recv()) == opcode
            ws.close()

    def test_verify_cert(self, handler):
        """The self-signed test certificate is rejected unless verify=False."""
        with handler() as rh:
            with pytest.raises(CertificateVerifyError):
                ws_validate_and_send(rh, Request(self.wss_base_url))

        with handler(verify=False) as rh:
            ws = ws_validate_and_send(rh, Request(self.wss_base_url))
            assert ws.status == 101
            ws.close()

    def test_ssl_error(self, handler):
        """A failing TLS handshake raises SSLError, not CertificateVerifyError."""
        with handler(verify=False) as rh:
            with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
                ws_validate_and_send(rh, Request(self.bad_wss_host))
            assert not issubclass(exc_info.type, CertificateVerifyError)

    def test_legacy_ssl_extension(self, handler):
        """The per-request 'legacy_ssl' extension enables connecting to the legacy server."""
        with handler(verify=False) as rh:
            ws = ws_validate_and_send(rh, Request(self.legacy_wss_host, extensions={'legacy_ssl': True}))
            assert ws.status == 101
            ws.close()

            # Ensure only applies to request extension
            with pytest.raises(SSLError):
                ws_validate_and_send(rh, Request(self.legacy_wss_host))

    def test_legacy_ssl_support(self, handler):
        """The handler-level legacy_ssl_support option enables connecting to the legacy server."""
        with handler(verify=False, legacy_ssl_support=True) as rh:
            ws = ws_validate_and_send(rh, Request(self.legacy_wss_host))
            assert ws.status == 101
            ws.close()

    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        """The request path seen by the server is percent-encoded as expected."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
            ws.send('path')
            assert ws.recv() == expected
            assert ws.status == 101
            ws.close()

    def test_remove_dot_segments(self, handler):
        """Dot segments in the URL path are removed before the request is sent."""
        with handler() as rh:
            # This isn't a comprehensive test,
            # but it should be enough to check whether the handler is removing dot segments
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
            assert ws.status == 101
            ws.send('path')
            assert ws.recv() == '/test'
            ws.close()

    # We are restricted to known HTTP status codes in http.HTTPStatus
    # Redirects are not supported for websockets
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        """A non-101 handshake response raises HTTPError carrying that status."""
        with handler() as rh:
            with pytest.raises(HTTPError) as exc_info:
                ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
            assert exc_info.value.status == status

    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': sys.float_info.min}, {}),
        ({}, {'timeout': sys.float_info.min}),
    ])
    def test_read_timeout(self, handler, params, extensions):
        """A minimal timeout (handler-level or per-request) results in TransportError."""
        with handler(**params) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))

    def test_connect_timeout(self, handler):
        """Connecting to an unreachable address fails faster than DEFAULT_TIMEOUT."""
        # nothing should be listening on this port
        connect_timeout_url = 'ws://10.255.255.255'
        with handler(timeout=0.01) as rh, pytest.raises(TransportError):
            now = time.time()
            ws_validate_and_send(rh, Request(connect_timeout_url))
        assert time.time() - now < DEFAULT_TIMEOUT

        # Per request timeout, should override handler timeout
        request = Request(connect_timeout_url, extensions={'timeout': 0.01})
        with handler() as rh, pytest.raises(TransportError):
            now = time.time()
            ws_validate_and_send(rh, request)
        assert time.time() - now < DEFAULT_TIMEOUT

    def test_cookies(self, handler):
        """Cookies from the handler cookiejar or the request extension are sent to the server."""
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={}))

        with handler(cookiejar=cookiejar) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert 'cookie' not in json.loads(ws.recv())
            ws.close()

            ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

    @pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
    def test_cookie_sync_only_cookiejar(self, handler):
        """A cookie set for one request's cookiejar must not leak into a request using another jar."""
        # Ensure that cookies are ONLY being handled by the cookiejar
        with handler() as rh:
            # Close the cookie-setting connection so it is not left dangling
            ws_validate_and_send(
                rh, Request(f'{self.ws_base_url}/get_cookie', extensions={'cookiejar': YoutubeDLCookieJar()})).close()
            ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': YoutubeDLCookieJar()}))
            ws.send('headers')
            assert 'cookie' not in json.loads(ws.recv())
            ws.close()

    @pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
    def test_cookie_sync_delete_cookie(self, handler):
        """A cookie received from the server is sent back until it is cleared from the cookiejar."""
        # Ensure that a cookie removed from the cookiejar is no longer sent
        cookiejar = YoutubeDLCookieJar()
        with handler(verbose=True, cookiejar=cookiejar) as rh:
            # Close the cookie-setting connection so it is not left dangling
            ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie')).close()
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()
            cookiejar.clear_session_cookies()
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert 'cookie' not in json.loads(ws.recv())
            ws.close()

    def test_source_address(self, handler):
        """The connection originates from the configured source_address."""
        source_address = f'127.0.0.{random.randint(5, 255)}'
        verify_address_availability(source_address)
        with handler(source_address=source_address) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('source_address')
            assert source_address == ws.recv()
            ws.close()

    def test_response_url(self, handler):
        """The response exposes the URL that was requested."""
        with handler() as rh:
            url = f'{self.ws_base_url}/something'
            ws = ws_validate_and_send(rh, Request(url))
            assert ws.url == url
            ws.close()

    def test_request_headers(self, handler):
        """Handler-level headers are sent, and per-request headers are merged over them."""
        with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
            # Global Headers
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            ws.close()

            # Per request headers, merged with global
            ws = ws_validate_and_send(rh, Request(
                self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            assert headers['test2'] == 'changed'
            assert headers['test3'] == 'test3'
            ws.close()

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        },
    ))
    def test_mtls(self, handler, client_cert):
        """Each supported client certificate configuration can connect to the mTLS server."""
        with handler(
            # Disable client-side validation of unacceptable self-signed testcert.pem
            # The test is of a check on the server side, so unaffected
            verify=False,
            client_cert=client_cert,
        ) as rh:
            ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()

    def test_request_disable_proxy(self, handler):
        """A request that sets the ws proxy to None bypasses the handler's ws proxy."""
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                # When a proxy is explicitly set to None for the request
                # (keyed by 'ws', the same scheme the handler proxy is configured for)
                ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'ws': None}))
                # Then no proxy should be used
                assert ws.status == 101
                ws.close()

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.NO_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support NO_PROXY')
    def test_noproxy(self, handler):
        """A 'no' proxy entry matching the request host bypasses the configured proxy."""
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given the handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                for no_proxy in (f'127.0.0.1:{self.ws_port}', '127.0.0.1', 'localhost'):
                    # When request no proxy includes the request url host
                    ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'no': no_proxy}))
                    # Then the proxy should not be used
                    assert ws.status == 101
                    ws.close()

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.ALL_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support ALL_PROXY')
    def test_allproxy(self, handler):
        """An 'all' proxy (handler-level or per-request) is used for ws requests."""
        supported_proto = traverse_obj(handler._SUPPORTED_PROXY_SCHEMES, 0, default='ws')
        # This is a bit of a hacky test, but it should be enough to check whether the handler is using the proxy.
        # 0.1s might not be enough of a timeout if proxy is not used in all cases, but should still get failures.
        with handler(proxies={'all': f'{supported_proto}://10.255.255.255'}, timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url)).close()

        with handler(timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(
                    rh, Request(self.ws_base_url, proxies={'all': f'{supported_proto}://10.255.255.255'})).close()
 | 
						|
 | 
						|
 | 
						|
def create_fake_ws_connection(raised):
    """Build a stand-in websockets client connection whose I/O always fails.

    *raised* is a zero-argument callable returning the exception to raise from
    ``send`` and ``recv``. The returned object subclasses
    ``websockets.sync.client.ClientConnection`` but skips the parent
    initializer, so no socket is involved; ``close`` does nothing.
    """
    import websockets.sync.client

    class FakeResponse:
        # Minimal handshake response: a successful upgrade with no headers or body
        body = b''
        headers = {}
        status_code = 101
        reason_phrase = 'test'

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        def __init__(self, *args, **kwargs):
            # Deliberately does not call super().__init__()
            self.response = FakeResponse()

        def send(self, *args, **kwargs):
            raise raised()

        def recv(self, *args, **kwargs):
            raise raised()

        def close(self, *args, **kwargs):
            return None

    fake_connection = FakeWsConnection()
    return fake_connection
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Checks that errors raised by the websockets library are mapped to yt-dlp networking errors."""

    @pytest.mark.parametrize('raised,expected', [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        """An exception raised while connecting surfaces as exactly the *expected* type."""
        import websockets.sync.client

        import yt_dlp.networking._websockets

        def raise_on_connect(*args, **kwargs):
            raise raised()

        with handler() as rh:
            # Skip real socket creation, then make the websockets connect call fail
            monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
            monkeypatch.setattr(websockets.sync.client, 'connect', raise_on_connect)
            with pytest.raises(expected) as exc_info:
                rh.send(Request('ws://fake-url'))
            assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """An exception raised by the connection's send() surfaces as exactly the *expected* type."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter

        connection = create_fake_ws_connection(raised)
        adapter = WebsocketsResponseAdapter(connection, url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            adapter.send('test')
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """An exception raised by the connection's recv() surfaces as exactly the *expected* type."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter

        connection = create_fake_ws_connection(raised)
        adapter = WebsocketsResponseAdapter(connection, url='ws://fake-url')
        with pytest.raises(expected, match=match) as exc_info:
            adapter.recv()
        assert exc_info.type is expected
 |