mirror of
				https://github.com/yt-dlp/yt-dlp.git
				synced 2025-10-31 14:45:14 +00:00 
			
		
		
		
	[test] Workaround websocket server hanging (#9467)
Authored by: coletdjnz
This commit is contained in:
		| @@ -32,8 +32,6 @@ from yt_dlp.networking.exceptions import ( | ||||
| ) | ||||
| from yt_dlp.utils.networking import HTTPHeaderDict | ||||
| 
 | ||||
| from test.conftest import validate_and_send | ||||
| 
 | ||||
| TEST_DIR = os.path.dirname(os.path.abspath(__file__)) | ||||
| 
 | ||||
| 
 | ||||
| @@ -66,7 +64,9 @@ def process_request(self, request): | ||||
| 
 | ||||
| def create_websocket_server(**ws_kwargs): | ||||
|     import websockets.sync.server | ||||
|     wsd = websockets.sync.server.serve(websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs) | ||||
|     wsd = websockets.sync.server.serve( | ||||
|         websocket_handler, '127.0.0.1', 0, | ||||
|         process_request=process_request, open_timeout=2, **ws_kwargs) | ||||
|     ws_port = wsd.socket.getsockname()[1] | ||||
|     ws_server_thread = threading.Thread(target=wsd.serve_forever) | ||||
|     ws_server_thread.daemon = True | ||||
| @@ -100,6 +100,19 @@ def create_mtls_wss_websocket_server(): | ||||
|     return create_websocket_server(ssl_context=sslctx) | ||||
| 
 | ||||
| 
 | ||||
| def ws_validate_and_send(rh, req): | ||||
|     rh.validate(req) | ||||
|     max_tries = 3 | ||||
|     for i in range(max_tries): | ||||
|         try: | ||||
|             return rh.send(req) | ||||
|         except TransportError as e: | ||||
|             if i < (max_tries - 1) and 'connection closed during handshake' in str(e): | ||||
|                 # websockets server sometimes hangs on new connections | ||||
|                 continue | ||||
|             raise | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers') | ||||
| class TestWebsSocketRequestHandlerConformance: | ||||
|     @classmethod | ||||
| @@ -119,7 +132,7 @@ class TestWebsSocketRequestHandlerConformance: | ||||
|     @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) | ||||
|     def test_basic_websockets(self, handler): | ||||
|         with handler() as rh: | ||||
|             ws = validate_and_send(rh, Request(self.ws_base_url)) | ||||
|             ws = ws_validate_and_send(rh, Request(self.ws_base_url)) | ||||
|             assert 'upgrade' in ws.headers | ||||
|             assert ws.status == 101 | ||||
|             ws.send('foo') | ||||
| @@ -131,7 +144,7 @@ class TestWebsSocketRequestHandlerConformance: | ||||
|     @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) | ||||
|     def test_send_types(self, handler, msg, opcode): | ||||
|         with handler() as rh: | ||||
|             ws = validate_and_send(rh, Request(self.ws_base_url)) | ||||
|             ws = ws_validate_and_send(rh, Request(self.ws_base_url)) | ||||
|             ws.send(msg) | ||||
|             assert int(ws.recv()) == opcode | ||||
|             ws.close() | ||||
| @@ -140,10 +153,10 @@ class TestWebsSocketRequestHandlerConformance: | ||||
|     def test_verify_cert(self, handler): | ||||
|         with handler() as rh: | ||||
|             with pytest.raises(CertificateVerifyError): | ||||
|                 validate_and_send(rh, Request(self.wss_base_url)) | ||||
|                 ws_validate_and_send(rh, Request(self.wss_base_url)) | ||||
| 
 | ||||
|         with handler(verify=False) as rh: | ||||
|             ws = validate_and_send(rh, Request(self.wss_base_url)) | ||||
|             ws = ws_validate_and_send(rh, Request(self.wss_base_url)) | ||||
|             assert ws.status == 101 | ||||
|             ws.close() | ||||
| 
 | ||||
| @@ -151,7 +164,7 @@ class TestWebsSocketRequestHandlerConformance: | ||||
|     def test_ssl_error(self, handler): | ||||
|         with handler(verify=False) as rh: | ||||
|             with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info: | ||||
|                 validate_and_send(rh, Request(self.bad_wss_host)) | ||||
|                 ws_validate_and_send(rh, Request(self.bad_wss_host)) | ||||
|             assert not issubclass(exc_info.type, CertificateVerifyError) | ||||
| 
 | ||||
|     @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) | ||||
| @@ -163,7 +176,7 @@ class TestWebsSocketRequestHandlerConformance: | ||||
|     ]) | ||||
|     def test_percent_encode(self, handler, path, expected): | ||||
|         with handler() as rh: | ||||
|             ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}')) | ||||
|             ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}')) | ||||
|             ws.send('path') | ||||
|             assert ws.recv() == expected | ||||
|             assert ws.status == 101 | ||||
| @@ -174,7 +187,7 @@ class TestWebsSocketRequestHandlerConformance: | ||||
|         with handler() as rh: | ||||
|             # This isn't a comprehensive test, | ||||
|             # but it should be enough to check whether the handler is removing dot segments | ||||
|             ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test')) | ||||
|             ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test')) | ||||
|             assert ws.status == 101 | ||||
|             ws.send('path') | ||||
|             assert ws.recv() == '/test' | ||||
| @@ -187,7 +200,7 @@ class TestWebsSocketRequestHandlerConformance: | ||||
|     def test_raise_http_error(self, handler, status): | ||||
|         with handler() as rh: | ||||
|             with pytest.raises(HTTPError) as exc_info: | ||||
|                 validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}')) | ||||
|                 ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}')) | ||||
|             assert exc_info.value.status == status | ||||
| 
 | ||||
|     @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) | ||||
| @@ -198,7 +211,7 @@ class TestWebsSocketRequestHandlerConformance: | ||||
|     def test_timeout(self, handler, params, extensions): | ||||
|         with handler(**params) as rh: | ||||
|             with pytest.raises(TransportError): | ||||
|                 validate_and_send(rh, Request(self.ws_base_url, extensions=extensions)) | ||||
|                 ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions)) | ||||
| 
 | ||||
|     @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) | ||||
|     def test_cookies(self, handler): | ||||
| @@ -210,18 +223,18 @@ class TestWebsSocketRequestHandlerConformance: | ||||
|             comment_url=None, rest={})) | ||||
| 
 | ||||
|         with handler(cookiejar=cookiejar) as rh: | ||||
|             ws = validate_and_send(rh, Request(self.ws_base_url)) | ||||
|             ws = ws_validate_and_send(rh, Request(self.ws_base_url)) | ||||
|             ws.send('headers') | ||||
|             assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' | ||||
|             ws.close() | ||||
| 
 | ||||
|         with handler() as rh: | ||||
|             ws = validate_and_send(rh, Request(self.ws_base_url)) | ||||
|             ws = ws_validate_and_send(rh, Request(self.ws_base_url)) | ||||
|             ws.send('headers') | ||||
|             assert 'cookie' not in json.loads(ws.recv()) | ||||
|             ws.close() | ||||
| 
 | ||||
|             ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar})) | ||||
|             ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar})) | ||||
|             ws.send('headers') | ||||
|             assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' | ||||
|             ws.close() | ||||
| @@ -231,7 +244,7 @@ class TestWebsSocketRequestHandlerConformance: | ||||
|         source_address = f'127.0.0.{random.randint(5, 255)}' | ||||
|         verify_address_availability(source_address) | ||||
|         with handler(source_address=source_address) as rh: | ||||
|             ws = validate_and_send(rh, Request(self.ws_base_url)) | ||||
|             ws = ws_validate_and_send(rh, Request(self.ws_base_url)) | ||||
|             ws.send('source_address') | ||||
|             assert source_address == ws.recv() | ||||
|             ws.close() | ||||
| @@ -240,7 +253,7 @@ class TestWebsSocketRequestHandlerConformance: | ||||
|     def test_response_url(self, handler): | ||||
|         with handler() as rh: | ||||
|             url = f'{self.ws_base_url}/something' | ||||
|             ws = validate_and_send(rh, Request(url)) | ||||
|             ws = ws_validate_and_send(rh, Request(url)) | ||||
|             assert ws.url == url | ||||
|             ws.close() | ||||
| 
 | ||||
| @@ -248,14 +261,14 @@ class TestWebsSocketRequestHandlerConformance: | ||||
|     def test_request_headers(self, handler): | ||||
|         with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh: | ||||
|             # Global Headers | ||||
|             ws = validate_and_send(rh, Request(self.ws_base_url)) | ||||
|             ws = ws_validate_and_send(rh, Request(self.ws_base_url)) | ||||
|             ws.send('headers') | ||||
|             headers = HTTPHeaderDict(json.loads(ws.recv())) | ||||
|             assert headers['test1'] == 'test' | ||||
|             ws.close() | ||||
| 
 | ||||
|             # Per request headers, merged with global | ||||
|             ws = validate_and_send(rh, Request( | ||||
|             ws = ws_validate_and_send(rh, Request( | ||||
|                 self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'})) | ||||
|             ws.send('headers') | ||||
|             headers = HTTPHeaderDict(json.loads(ws.recv())) | ||||
| @@ -288,7 +301,7 @@ class TestWebsSocketRequestHandlerConformance: | ||||
|             verify=False, | ||||
|             client_cert=client_cert | ||||
|         ) as rh: | ||||
|             validate_and_send(rh, Request(self.mtls_wss_base_url)).close() | ||||
|             ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close() | ||||
| 
 | ||||
| 
 | ||||
| def create_fake_ws_connection(raised): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 coletdjnz
					coletdjnz