From a9b1c25ddc90fbae5f1e9409967855130d8386d2 Mon Sep 17 00:00:00 2001
From: coletdjnz
Date: Sat, 9 Aug 2025 15:18:14 +1200
Subject: [PATCH] [test] media part tests

---
Reviewer notes (text in this section is not part of the commit message):

* test_processor.py: adds a `TestMedia` class exercising
  `SabrProcessor.process_media` for four cases: consecutive media parts of
  one segment, an unknown header ID, a selector created with
  `discard_media=True`, and an init segment. Also renames
  `TestStreamProtectionStatusPart` to `TestStreamProtectionStatus` and
  imports `io`, `MediaSegmentDataSabrPart` and `ProcessMediaResult`.
* part.py: the `None`-defaulted fields `total_segments`, `content_length`
  and `segment_start_bytes` of `MediaSegmentDataSabrPart` are now annotated
  `int | None` instead of `int`.
* processor.py: `process_media` now raises `SabrStreamError` when the
  header ID has no entry in `partial_segments`, where it previously
  returned `result`.
* NOTE(review): line breaks and indentation below were reconstructed from a
  whitespace-collapsed copy. Hunk line counts agree with the `@@` headers
  and the diffstat, but context-line whitespace should be confirmed against
  the target tree before applying.

 test/test_sabr/test_processor.py              | 129 +++++++++++++++++-
 .../extractor/youtube/_streaming/sabr/part.py |   6 +-
 .../youtube/_streaming/sabr/processor.py      |   2 +-
 3 files changed, 132 insertions(+), 5 deletions(-)

diff --git a/test/test_sabr/test_processor.py b/test/test_sabr/test_processor.py
index 70752cc8ef..5f91ebcf2d 100644
--- a/test/test_sabr/test_processor.py
+++ b/test/test_sabr/test_processor.py
@@ -1,4 +1,5 @@
 import dataclasses
+import io
 import pytest
 from unittest.mock import MagicMock
 
@@ -9,6 +10,7 @@
     FormatInitializedSabrPart,
     MediaSeekSabrPart,
     MediaSegmentInitSabrPart,
+    MediaSegmentDataSabrPart,
 )
 
 from yt_dlp.extractor.youtube._streaming.sabr.processor import (
@@ -18,6 +20,7 @@
     ProcessLiveMetadataResult,
     ProcessSabrSeekResult,
     ProcessMediaHeaderResult,
+    ProcessMediaResult,
 )
 from yt_dlp.extractor.youtube._streaming.sabr.models import (
     AudioSelector,
@@ -319,7 +322,7 @@ def test_override_defaults(self, base_args):
         assert processor.post_live is True
 
 
-class TestStreamProtectionStatusPart:
+class TestStreamProtectionStatus:
 
     @pytest.mark.parametrize(
         'sps,po_token,expected_status',
@@ -2010,3 +2013,127 @@ def test_segment_mismatch(self, base_args):
             processor.process_media_header(media_header)
         assert exc_info.value.expected_sequence_number == 11
         assert exc_info.value.received_sequence_number == 1
+
+
+class TestMedia:
+    def test_valid_media_parts(self, base_args):
+        selector = make_selector('audio')
+        processor = SabrProcessor(
+            **base_args,
+            audio_selection=selector,
+        )
+
+        example_payload = b'example-data'
+        fim = make_format_im(selector)
+        processor.process_format_initialization_metadata(fim)
+        media_header = make_media_header(selector, sequence_no=1)
+        media_header.content_length = len(example_payload)
+        processor.process_media_header(media_header)
+
+        result = processor.process_media(
+            header_id=media_header.header_id,
+            content_length=media_header.content_length,
+            data=io.BytesIO(example_payload))
+
+        assert isinstance(result, ProcessMediaResult)
+        assert isinstance(result.sabr_part, MediaSegmentDataSabrPart)
+        assert result.sabr_part == MediaSegmentDataSabrPart(
+            format_selector=selector,
+            format_id=selector.format_ids[0],
+            sequence_number=1,
+            is_init_segment=False,
+            total_segments=fim.total_segments,
+            data=example_payload,
+            content_length=len(example_payload),
+            segment_start_bytes=0,
+        )
+        assert processor.partial_segments[media_header.header_id].received_data_length == len(example_payload)
+
+        # Subsequent call should increment received data length
+        result = processor.process_media(
+            header_id=media_header.header_id,
+            content_length=media_header.content_length,
+            data=io.BytesIO(example_payload))
+
+        assert isinstance(result, ProcessMediaResult)
+        assert isinstance(result.sabr_part, MediaSegmentDataSabrPart)
+        assert result.sabr_part == MediaSegmentDataSabrPart(
+            format_selector=selector,
+            format_id=selector.format_ids[0],
+            sequence_number=1,
+            is_init_segment=False,
+            total_segments=fim.total_segments,
+            data=example_payload,
+            content_length=len(example_payload),
+            segment_start_bytes=len(example_payload),
+        )
+        assert processor.partial_segments[media_header.header_id].received_data_length == len(example_payload) * 2
+
+    def test_no_matching_partial_segment(self, base_args):
+        # Should raise an error if no matching partial segment found
+        selector = make_selector('audio')
+        processor = SabrProcessor(
+            **base_args,
+            audio_selection=selector,
+        )
+        with pytest.raises(SabrStreamError, match='Header ID 12345 not found in partial segments'):
+            processor.process_media(
+                header_id=12345,  # Non-existent header ID
+                content_length=100,
+                data=io.BytesIO(b'example-data'),
+            )
+
+    def test_discarded_partial_segment(self, base_args):
+        # Should ignore the media part if the segment is marked as discard
+        selector = make_selector('audio', discard_media=True)
+        processor = SabrProcessor(
+            **base_args,
+            audio_selection=selector,
+        )
+        example_payload = b'example-data'
+        fim = make_format_im(selector)
+        processor.process_format_initialization_metadata(fim)
+        media_header = make_media_header(selector, sequence_no=1)
+        media_header.content_length = len(example_payload)
+        processor.process_media_header(media_header)
+
+        result = processor.process_media(
+            header_id=media_header.header_id,
+            content_length=media_header.content_length,
+            data=io.BytesIO(example_payload))
+
+        assert isinstance(result, ProcessMediaResult)
+        assert result.sabr_part is None
+        assert processor.partial_segments[media_header.header_id].received_data_length == len(example_payload)
+
+    def test_valid_init_segment(self, base_args):
+        # Should process init segment correctly and report as such
+        selector = make_selector('audio')
+        processor = SabrProcessor(
+            **base_args,
+            audio_selection=selector,
+        )
+        example_payload = b'example-init-data'
+        fim = make_format_im(selector)
+        processor.process_format_initialization_metadata(fim)
+        media_header = make_init_header(selector)
+        processor.process_media_header(media_header)
+
+        result = processor.process_media(
+            header_id=media_header.header_id,
+            content_length=len(example_payload),
+            data=io.BytesIO(example_payload))
+
+        assert isinstance(result, ProcessMediaResult)
+        assert isinstance(result.sabr_part, MediaSegmentDataSabrPart)
+        assert result.sabr_part == MediaSegmentDataSabrPart(
+            format_selector=selector,
+            format_id=selector.format_ids[0],
+            sequence_number=None,
+            is_init_segment=True,  # Init segment should be True
+            total_segments=fim.total_segments,
+            data=example_payload,
+            content_length=len(example_payload),
+            segment_start_bytes=0,
+        )
+        assert processor.partial_segments[media_header.header_id].received_data_length == len(example_payload)
diff --git a/yt_dlp/extractor/youtube/_streaming/sabr/part.py b/yt_dlp/extractor/youtube/_streaming/sabr/part.py
index 2dff2efe35..cb235aaea6 100644
--- a/yt_dlp/extractor/youtube/_streaming/sabr/part.py
+++ b/yt_dlp/extractor/youtube/_streaming/sabr/part.py
@@ -35,10 +35,10 @@ class MediaSegmentDataSabrPart(SabrPart):
     format_id: FormatId
     sequence_number: int | None = None
     is_init_segment: bool = False
-    total_segments: int = None
+    total_segments: int | None = None
     data: bytes = b''
-    content_length: int = None
-    segment_start_bytes: int = None
+    content_length: int | None = None
+    segment_start_bytes: int | None = None
 
 
 @dataclasses.dataclass
diff --git a/yt_dlp/extractor/youtube/_streaming/sabr/processor.py b/yt_dlp/extractor/youtube/_streaming/sabr/processor.py
index 44a399a2bf..69aae6de2f 100644
--- a/yt_dlp/extractor/youtube/_streaming/sabr/processor.py
+++ b/yt_dlp/extractor/youtube/_streaming/sabr/processor.py
@@ -330,7 +330,7 @@ def process_media(self, header_id: int, content_length: int, data: io.BufferedIO
         segment = self.partial_segments.get(header_id)
         if not segment:
             self.logger.debug(f'Header ID {header_id} not found')
-            return result
+            raise SabrStreamError(f'Header ID {header_id} not found in partial segments')
 
         segment_start_bytes = segment.received_data_length
         segment.received_data_length += content_length