diff --git a/test/test_sabr/test_processor.py b/test/test_sabr/test_processor.py index 9a2e63a8ed..1592a198d9 100644 --- a/test/test_sabr/test_processor.py +++ b/test/test_sabr/test_processor.py @@ -8,13 +8,16 @@ PoTokenStatusSabrPart, FormatInitializedSabrPart, MediaSeekSabrPart, + MediaSegmentInitSabrPart, ) from yt_dlp.extractor.youtube._streaming.sabr.processor import ( SabrProcessor, ProcessStreamProtectionStatusResult, ProcessFormatInitializationMetadataResult, - ProcessLiveMetadataResult, ProcessSabrSeekResult, + ProcessLiveMetadataResult, + ProcessSabrSeekResult, + ProcessMediaHeaderResult, ) from yt_dlp.extractor.youtube._streaming.sabr.models import ( AudioSelector, @@ -29,7 +32,9 @@ FormatInitializationMetadata, LiveMetadata, SabrContextUpdate, - SabrContextSendingPolicy, SabrSeek, + SabrContextSendingPolicy, + SabrSeek, + MediaHeader, ) from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy @@ -84,6 +89,43 @@ def factory(): return factory +def make_format_im(selector=None, video_id=None): + return FormatInitializationMetadata( + video_id=video_id or example_video_id, + format_id=selector.format_ids[0] if selector else FormatId(itag=140), + end_time_ms=10000, + total_segments=5, + mime_type=(selector.mime_prefix + '/mp4') if selector else 'audio/mp4', + duration_ticks=10000, + duration_timescale=1000, + ) + + +def make_init_header(selector=None, video_id=None): + return MediaHeader( + video_id=video_id or example_video_id, + format_id=selector.format_ids[0] if selector else FormatId(itag=140), + header_id=0, + is_init_segment=True, + start_data_range=0, + content_length=501, + ) + + +def make_media_header(selector=None, video_id=None, sequence_no=None, header_id=0): + return MediaHeader( + video_id=video_id or example_video_id, + format_id=selector.format_ids[0] if selector else FormatId(itag=140), + header_id=header_id, + start_data_range=502, + content_length=10000, + sequence_number=sequence_no, + is_init_segment=False, + duration_ms=2300, + start_ms=0, + ) + + class TestSabrProcessorInitialization: @pytest.mark.parametrize( 'audio_sel,video_sel,caption_sel,expected_bitfield', @@ -1120,8 +1162,6 @@ def test_write_policy_keep_existing(self, logger, base_args): 'Received a SABR Context Update with write_policy=KEEP_EXISTING' 'matching an existing SABR Context Update. Ignoring update') - -class TestSabrContextUpdateSendingPolicy: def test_set_sabr_context_update_sending_policy(self, base_args, logger): processor = SabrProcessor(**base_args) @@ -1232,3 +1272,107 @@ def test_sabr_seek(self, logger, base_args): assert izf.current_segment is None logger.debug.assert_called_with('Seeking to 5600ms') + + +class TestMediaHeader: + def test_media_header_init_segment(self, base_args): + selector = make_selector('audio') + processor = SabrProcessor( + **base_args, + audio_selection=selector, + ) + fim = make_format_im(selector) + processor.process_format_initialization_metadata(fim) + + media_header = make_init_header(selector) + + result = processor.process_media_header(media_header) + + assert isinstance(result, ProcessMediaHeaderResult) + assert isinstance(result.sabr_part, MediaSegmentInitSabrPart) + + part = result.sabr_part + + # TODO: confirm expected duration/start_ms settings for init segments + assert part == MediaSegmentInitSabrPart( + format_selector=selector, + format_id=selector.format_ids[0], + player_time_ms=0, + sequence_number=None, + total_segments=5, + is_init_segment=True, + content_length=501, + start_time_ms=0, + duration_ms=0, + start_bytes=0, + ) + assert media_header.header_id in processor.partial_segments + + segment = processor.partial_segments[media_header.header_id] + + # TODO: confirm expected duration/start_ms settings for init segments + assert segment == Segment( + format_id=selector.format_ids[0], + is_init_segment=True, + duration_ms=0, + start_ms=0, + start_data_range=0, + sequence_number=None, + content_length=501, + content_length_estimated=False, + initialized_format=processor.initialized_formats[str(selector.format_ids[0])], + duration_estimated=True, + discard=False, + consumed=False, + received_data_length=0, + ) + + def test_media_header_segment(self, base_args): + selector = make_selector('audio') + processor = SabrProcessor( + **base_args, + audio_selection=selector, + ) + fim = make_format_im(selector) + processor.process_format_initialization_metadata(fim) + + media_header = make_media_header(selector, sequence_no=1) + + result = processor.process_media_header(media_header) + + assert isinstance(result, ProcessMediaHeaderResult) + assert isinstance(result.sabr_part, MediaSegmentInitSabrPart) + + part = result.sabr_part + + assert part == MediaSegmentInitSabrPart( + format_selector=selector, + format_id=selector.format_ids[0], + player_time_ms=0, + sequence_number=1, + total_segments=5, + is_init_segment=False, + content_length=10000, + start_time_ms=0, + duration_ms=2300, + start_bytes=502, + ) + assert media_header.header_id in processor.partial_segments + + segment = processor.partial_segments[media_header.header_id] + + assert segment == Segment( + format_id=selector.format_ids[0], + is_init_segment=False, + duration_ms=2300, + start_ms=0, + start_data_range=502, + sequence_number=1, + content_length=10000, + content_length_estimated=False, + initialized_format=processor.initialized_formats[str(selector.format_ids[0])], + duration_estimated=False, + discard=False, + consumed=False, + received_data_length=0, + ) diff --git a/yt_dlp/extractor/youtube/_streaming/sabr/models.py b/yt_dlp/extractor/youtube/_streaming/sabr/models.py index 52edcc5ce4..abd600cae6 100644 --- a/yt_dlp/extractor/youtube/_streaming/sabr/models.py +++ b/yt_dlp/extractor/youtube/_streaming/sabr/models.py @@ -13,7 +13,7 @@ class Segment: duration_ms: int = 0 start_ms: int = 0 start_data_range: int = 0 - sequence_number: int = 0 + sequence_number: int | None = 0 content_length: int | None = None content_length_estimated: bool = False initialized_format: InitializedFormat = None @@ -62,36 +62,29 @@ class FormatSelector: display_name: str format_ids: list[FormatId] = dataclasses.field(default_factory=list) discard_media: bool = False + mime_prefix: str | None = None - def match(self, format_id: FormatId = None, **kwargs) -> bool: - return format_id in self.format_ids + def match(self, format_id: FormatId = None, mime_type: str | None = None, **kwargs) -> bool: + return ( + format_id in self.format_ids + or ( + not self.format_ids + and self.mime_prefix + and mime_type and mime_type.lower().startswith(self.mime_prefix) + ) + ) @dataclasses.dataclass class AudioSelector(FormatSelector): - - def match(self, format_id: FormatId = None, mime_type: str | None = None, **kwargs) -> bool: - return ( - super().match(format_id, mime_type=mime_type, **kwargs) - or (not self.format_ids and mime_type and mime_type.lower().startswith('audio')) - ) + mime_prefix: str = dataclasses.field(default='audio') @dataclasses.dataclass class VideoSelector(FormatSelector): - - def match(self, format_id: FormatId = None, mime_type: str | None = None, **kwargs) -> bool: - return ( - super().match(format_id, mime_type=mime_type, **kwargs) - or (not self.format_ids and mime_type and mime_type.lower().startswith('video')) - ) + mime_prefix: str = dataclasses.field(default='video') @dataclasses.dataclass class CaptionSelector(FormatSelector): - - def match(self, format_id: FormatId = None, mime_type: str | None = None, **kwargs) -> bool: - return ( - super().match(format_id, mime_type=mime_type, **kwargs) - or (not self.format_ids and mime_type and mime_type.lower().startswith('text')) - ) + mime_prefix: str = dataclasses.field(default='text') diff --git a/yt_dlp/extractor/youtube/_streaming/sabr/part.py b/yt_dlp/extractor/youtube/_streaming/sabr/part.py index 8c7d8dc0c3..2dff2efe35 100644 --- a/yt_dlp/extractor/youtube/_streaming/sabr/part.py +++ b/yt_dlp/extractor/youtube/_streaming/sabr/part.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import enum @@ -15,7 +17,7 @@ class SabrPart: class MediaSegmentInitSabrPart(SabrPart): format_selector: FormatSelector format_id: FormatId - sequence_number: int = None + sequence_number: int | None = None is_init_segment: bool = False total_segments: int = None start_time_ms: int = None @@ -31,7 +33,7 @@ class MediaSegmentInitSabrPart(SabrPart): class MediaSegmentDataSabrPart(SabrPart): format_selector: FormatSelector format_id: FormatId - sequence_number: int = None + sequence_number: int | None = None is_init_segment: bool = False total_segments: int = None data: bytes = b'' @@ -43,7 +45,7 @@ class MediaSegmentDataSabrPart(SabrPart): class MediaSegmentEndSabrPart(SabrPart): format_selector: FormatSelector format_id: FormatId - sequence_number: int = None + sequence_number: int | None = None is_init_segment: bool = False total_segments: int = None