# Patch summary (free-form text placed before the first `diff --git` header).
# NOTE(review): patch tools are expected to skip leading text like this -- confirm with your tooling.
#
# test/test_sabr/test_processor.py
#   * Adds imports for dataclasses, SabrStreamError, FormatInitializedSabrPart,
#     ProcessFormatInitializationMetadataResult, InitializedFormat,
#     FormatInitializationMetadata and LiveMetadata.
#   * Adds the module-level constant `example_video_id`.
#   * make_selector(): the default FormatId list is now used only when
#     `format_ids is None`. Previously `format_ids or [...]` also replaced an
#     explicitly passed empty list; the new tests pass `format_ids=[]` and
#     assert that the selector's format_ids stays empty.
#   * Adds class TestFormatInitialization (13 tests) that calls
#     SabrProcessor.process_format_initialization_metadata() and checks:
#     the returned result/sabr_part, the contents of processor.initialized_formats,
#     repeated initialization of the same format, format-selector matching by
#     format id and by mime type, discard_media handling (consumed_ranges),
#     processor.total_duration_ms, video id mismatch, an already-used selector,
#     and total_segments with and without processor.live_metadata.
#
# yt_dlp/extractor/youtube/_streaming/sabr/processor.py
#   * Swaps the two arguments of the ticks_to_ms() call so that duration_ticks
#     is passed first and duration_timescale second.
#     NOTE(review): ticks_to_ms() itself is not shown in this patch; presumably
#     its signature is (ticks, timescale) -- confirm.
#
diff --git a/test/test_sabr/test_processor.py b/test/test_sabr/test_processor.py
index 351952f239..e43a15dc6a 100644
--- a/test/test_sabr/test_processor.py
+++ b/test/test_sabr/test_processor.py
@@ -1,15 +1,21 @@
+import dataclasses
+
 import pytest
 from unittest.mock import MagicMock
 
-from yt_dlp.extractor.youtube._streaming.sabr.part import PoTokenStatusSabrPart
+from yt_dlp.extractor.youtube._streaming.sabr.exceptions import SabrStreamError
+from yt_dlp.extractor.youtube._streaming.sabr.part import PoTokenStatusSabrPart, FormatInitializedSabrPart
 
-from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor, ProcessStreamProtectionStatusResult
+from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor, ProcessStreamProtectionStatusResult, \
+    ProcessFormatInitializationMetadataResult
 from yt_dlp.extractor.youtube._streaming.sabr.models import (
     AudioSelector,
     VideoSelector,
     CaptionSelector,
+    InitializedFormat,
 )
-from yt_dlp.extractor.youtube._proto.videostreaming import FormatId, StreamProtectionStatus
+from yt_dlp.extractor.youtube._proto.videostreaming import FormatId, StreamProtectionStatus, \
+    FormatInitializationMetadata, LiveMetadata
 from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy
 
 
@@ -32,23 +38,26 @@ def base_args(logger, client_info):
     }
 
 
+example_video_id = 'example_video_id'
+
+
 def make_selector(selector_type, *, discard_media=False, format_ids=None):
     if selector_type == 'audio':
         return AudioSelector(
             display_name='audio',
-            format_ids=format_ids or [FormatId(itag=140)],
+            format_ids=format_ids if format_ids is not None else [FormatId(itag=140)],
             discard_media=discard_media,
         )
     elif selector_type == 'video':
         return VideoSelector(
             display_name='video',
-            format_ids=format_ids or [FormatId(itag=248)],
+            format_ids=format_ids if format_ids is not None else [FormatId(itag=248)],
             discard_media=discard_media,
         )
     elif selector_type == 'caption':
         return CaptionSelector(
             display_name='caption',
-            format_ids=format_ids or [FormatId(itag=386)],
+            format_ids=format_ids if format_ids is not None else [FormatId(itag=386)],
             discard_media=discard_media,
         )
     raise ValueError(f'Unknown selector_type: {selector_type}')
@@ -305,3 +314,494 @@ def test_next_request_policy_part(self, logger, base_args):
 
         # Check logger trace was called
         assert logger.trace.call_count == 2
+
+
+class TestFormatInitialization:
+
+    def test_initialize_format(self, logger, base_args):
+        selector = make_selector('audio')
+        format_id = selector.format_ids[0]
+        processor = SabrProcessor(**base_args, audio_selection=selector, video_id='test_video')
+        format_init_metadata_part = FormatInitializationMetadata(
+            video_id='test_video',
+            format_id=format_id,
+            end_time_ms=10000,
+            total_segments=5,
+            mime_type='audio/mp4',
+            duration_ticks=10000,
+            duration_timescale=1000,
+        )
+
+        result = processor.process_format_initialization_metadata(format_init_metadata_part)
+
+        assert isinstance(result, ProcessFormatInitializationMetadataResult)
+        assert isinstance(result.sabr_part, FormatInitializedSabrPart)
+        assert result.sabr_part.format_selector is selector
+        assert result.sabr_part.format_id == format_id
+        assert len(processor.initialized_formats) == 1
+        assert str(format_id) in processor.initialized_formats
+        initialized_format = processor.initialized_formats[str(format_id)]
+        expected_initialized_format = InitializedFormat(
+            format_id=format_id,
+            video_id='test_video',
+            mime_type='audio/mp4',
+            duration_ms=10000,
+            total_segments=5,
+            end_time_ms=10000,
+            format_selector=selector,
+            discard=False,
+        )
+        assert initialized_format == expected_initialized_format
+        logger.debug.assert_called_with(
+            f'Initialized Format: {expected_initialized_format}',
+        )
+
+    def test_initialize_format_already_initialized(self, logger, base_args):
+        selector = make_selector('audio')
+        format_id = selector.format_ids[0]
+        processor = SabrProcessor(**base_args, audio_selection=selector, video_id='test_video')
+        format_init_metadata_part = FormatInitializationMetadata(
+            video_id='test_video',
+            format_id=format_id,
+            end_time_ms=10000,
+            total_segments=5,
+            mime_type='audio/mp4',
+            duration_ticks=10000,
+            duration_timescale=1000,
+        )
+
+        assert processor.process_format_initialization_metadata(format_init_metadata_part)
+
+        # Now try to initialize it again
+        result = processor.process_format_initialization_metadata(
+            dataclasses.replace(format_init_metadata_part, total_segments=10))
+
+        assert isinstance(result, ProcessFormatInitializationMetadataResult)
+        assert result.sabr_part is None
+        logger.trace.assert_called_with(f'Format {format_id} already initialized')
+        assert len(processor.initialized_formats) == 1
+        assert str(format_id) in processor.initialized_formats
+        initialized_format = processor.initialized_formats[str(format_id)]
+        assert initialized_format.total_segments == 5
+
+    def test_initialize_multiple_formats(self, logger, base_args):
+        audio_selector = make_selector('audio')
+        video_selector = make_selector('video')
+
+        audio_format_id = audio_selector.format_ids[0]
+        video_format_id = video_selector.format_ids[0]
+
+        processor = SabrProcessor(
+            **base_args,
+            audio_selection=audio_selector,
+            video_selection=video_selector,
+            video_id=example_video_id,
+        )
+
+        audio_format_init_metadata = FormatInitializationMetadata(
+            video_id=example_video_id,
+            format_id=audio_format_id,
+            end_time_ms=10000,
+            total_segments=5,
+            mime_type='audio/mp4',
+            duration_ticks=10000,
+            duration_timescale=1000,
+        )
+
+        video_format_init_metadata = FormatInitializationMetadata(
+            video_id=example_video_id,
+            format_id=video_format_id,
+            end_time_ms=10000,
+            total_segments=20,
+            mime_type='video/mp4',
+            duration_ticks=10000,
+            duration_timescale=1000,
+        )
+
+        # Process audio format initialization
+        audio_result = processor.process_format_initialization_metadata(audio_format_init_metadata)
+        assert audio_result.sabr_part.format_selector is audio_selector
+        assert audio_result.sabr_part.format_id == audio_format_id
+        assert len(processor.initialized_formats) == 1
+        assert str(audio_format_id) in processor.initialized_formats
+        assert processor.initialized_formats[str(audio_format_id)].format_id == audio_format_id
+
+        # Process video format initialization
+        video_result = processor.process_format_initialization_metadata(video_format_init_metadata)
+        assert video_result.sabr_part.format_selector is video_selector
+        assert video_result.sabr_part.format_id == video_format_id
+        assert len(processor.initialized_formats) == 2
+        assert str(video_format_id) in processor.initialized_formats
+        assert processor.initialized_formats[str(video_format_id)].format_id == video_format_id
+
+    def test_initialized_format_not_match_selector(self, logger, base_args):
+        selector = make_selector('audio', format_ids=[FormatId(140)])
+        processor = SabrProcessor(
+            **base_args,
+            video_id=example_video_id,
+            audio_selection=selector)
+
+        format_id = FormatId(itag=141)  # Different format id than the selector
+        format_init_metadata_part = FormatInitializationMetadata(
+            video_id=example_video_id,
+            format_id=format_id,
+            end_time_ms=10000,
+            total_segments=5,
+            mime_type='audio/mp4',
+            duration_ticks=10000,
+            duration_timescale=1000,
+        )
+
+        with pytest.raises(SabrStreamError, match='does not match any format selector'):
+            processor.process_format_initialization_metadata(format_init_metadata_part)
+
+        assert len(processor.initialized_formats) == 0
+
+    def test_initialized_format_match_mimetype(self, logger, base_args):
+        selector = make_selector('audio', format_ids=[])
+        assert len(selector.format_ids) == 0
+        processor = SabrProcessor(
+            **base_args,
+            video_id=example_video_id,
+            audio_selection=selector,
+            video_selection=make_selector('caption'),
+        )
+
+        format_id = FormatId(itag=251)
+        format_init_metadata_part = FormatInitializationMetadata(
+            video_id=example_video_id,
+            format_id=format_id,
+            end_time_ms=10000,
+            total_segments=5,
+            mime_type='audio/mp4',
+            duration_ticks=10000,
+            duration_timescale=1000,
+        )
+
+        result = processor.process_format_initialization_metadata(format_init_metadata_part)
+        assert isinstance(result, ProcessFormatInitializationMetadataResult)
+        assert isinstance(result.sabr_part, FormatInitializedSabrPart)
+        assert result.sabr_part.format_selector is selector
+        assert result.sabr_part.format_id == format_id
+        assert len(processor.initialized_formats) == 1
+
+        # If mimetype does not match any selector, it should raise an error
+        bad_fmt_init_metadata = dataclasses.replace(
+            format_init_metadata_part, mime_type='video/mp4', format_id=FormatId(itag=248))
+        with pytest.raises(SabrStreamError, match='does not match any format selector'):
+            processor.process_format_initialization_metadata(bad_fmt_init_metadata)
+
+        assert len(processor.initialized_formats) == 1
+
+    def test_discard_media(self, logger, base_args):
+        # Discard and only match by format id
+        audio_selector = make_selector('audio', discard_media=True)
+        audio_format_id = audio_selector.format_ids[0]
+
+        # Discard any video
+        video_selector = make_selector('video', format_ids=[], discard_media=True)
+        assert len(video_selector.format_ids) == 0
+        video_format_id = FormatId(itag=248)
+
+        # xxx: Caption selector not specified, should be discarded by default
+        processor = SabrProcessor(
+            **base_args,
+            video_id=example_video_id,
+            audio_selection=audio_selector,
+            video_selection=video_selector,
+        )
+
+        audio_format_init_metadata = FormatInitializationMetadata(
+            video_id=example_video_id,
+            format_id=audio_format_id,
+            end_time_ms=10000,
+            total_segments=5,
+            mime_type='audio/mp4',
+            duration_ticks=10000,
+            duration_timescale=1000,
+        )
+
+        # Process audio format initialization
+        audio_result = processor.process_format_initialization_metadata(audio_format_init_metadata)
+        # When discarding, should not return a sabr_part
+        assert isinstance(audio_result, ProcessFormatInitializationMetadataResult)
+        assert audio_result.sabr_part is None
+        assert len(processor.initialized_formats) == 1
+        assert str(audio_format_id) in processor.initialized_formats
+        assert processor.initialized_formats[str(audio_format_id)].format_id == audio_format_id
+        assert processor.initialized_formats[str(audio_format_id)].discard is True
+
+        # The format should be marked as completely buffered
+        assert len(processor.initialized_formats[str(audio_format_id)].consumed_ranges) == 1
+        consumed_range = processor.initialized_formats[str(audio_format_id)].consumed_ranges[0]
+        assert consumed_range.start_sequence_number == 0
+        assert consumed_range.end_sequence_number >= 5
+        assert consumed_range.start_time_ms == 0
+        assert consumed_range.duration_ms >= 10000
+
+        # Process video format initialization. This should match the selector but be discarded.
+        video_format_init_metadata = FormatInitializationMetadata(
+            video_id=example_video_id,
+            format_id=video_format_id,
+            end_time_ms=10000,
+            total_segments=20,
+            mime_type='video/mp4',
+            duration_ticks=10000,
+            duration_timescale=1000,
+        )
+
+        video_result = processor.process_format_initialization_metadata(video_format_init_metadata)
+        assert video_result.sabr_part is None
+        assert len(processor.initialized_formats) == 2
+        assert str(video_format_id) in processor.initialized_formats
+        assert processor.initialized_formats[str(video_format_id)].format_id == video_format_id
+        assert processor.initialized_formats[str(video_format_id)].discard is True
+
+        # The format should be marked as completely buffered
+        assert len(processor.initialized_formats[str(video_format_id)].consumed_ranges) == 1
+        consumed_range = processor.initialized_formats[str(video_format_id)].consumed_ranges[0]
+        assert consumed_range.start_sequence_number == 0
+        assert consumed_range.end_sequence_number >= 20
+        assert consumed_range.start_time_ms == 0
+        assert consumed_range.duration_ms >= 10000
+
+        # Process a caption format initialization. This should be discarded by default as no selector was specified.
+        caption_format_id = FormatId(itag=386)
+        # Simulate no duration data (for livestreams with no live_metadata)
+        caption_format_init_metadata = FormatInitializationMetadata(
+            video_id=example_video_id,
+            format_id=caption_format_id,
+            mime_type='text/mp4',
+        )
+
+        caption_result = processor.process_format_initialization_metadata(caption_format_init_metadata)
+        assert caption_result.sabr_part is None
+        assert len(processor.initialized_formats) == 3
+        assert str(caption_format_id) in processor.initialized_formats
+        assert processor.initialized_formats[str(caption_format_id)].format_id == caption_format_id
+        assert processor.initialized_formats[str(caption_format_id)].discard is True
+
+        # The format should be marked as completely buffered
+        assert len(processor.initialized_formats[str(caption_format_id)].consumed_ranges) == 1
+        consumed_range = processor.initialized_formats[str(caption_format_id)].consumed_ranges[0]
+        assert consumed_range.start_sequence_number == 0
+        assert consumed_range.end_sequence_number >= 99999999
+        assert consumed_range.start_time_ms == 0
+        assert consumed_range.duration_ms >= 99999999
+
+    def test_total_duration_ms(self, logger, base_args):
+        # Test the duration_ms calculation when end_time_ms and duration_ms are different
+        audio_selector = make_selector('audio')
+        video_selector = make_selector('video')
+        caption_selector = make_selector('caption')
+        audio_format_id = audio_selector.format_ids[0]
+        video_format_id = video_selector.format_ids[0]
+        caption_format_id = caption_selector.format_ids[0]
+        processor = SabrProcessor(
+            **base_args,
+            audio_selection=audio_selector,
+            video_selection=video_selector,
+            caption_selection=caption_selector,
+            video_id=example_video_id,
+        )
+
+        format_init_metadata_part = FormatInitializationMetadata(
+            video_id=example_video_id,
+            format_id=audio_format_id,
+            end_time_ms=15001,  # End time is slightly more than 15 seconds
+            total_segments=5,
+            mime_type='audio/mp4',
+            duration_ticks=15000,
+            duration_timescale=1000,
+        )
+
+        assert processor.total_duration_ms is None
+        processor.process_format_initialization_metadata(format_init_metadata_part)
+        assert str(audio_format_id) in processor.initialized_formats
+        assert processor.total_duration_ms == 15001
+
+        # But if duration_ticks is greater, then use that
+        video_format_init_metadata_part = FormatInitializationMetadata(
+            video_id=example_video_id,
+            format_id=video_format_id,
+            end_time_ms=15003,  # end time is slightly less
+            total_segments=10,
+            mime_type='video/mp4',
+            duration_ticks=150050,  # Duration ticks is greater than end_time_ms
+            duration_timescale=10000,  # slightly different timescale
+        )
+
+        processor.process_format_initialization_metadata(video_format_init_metadata_part)
+        assert str(video_format_id) in processor.initialized_formats
+        assert processor.total_duration_ms == 15005
+
+        # And if total_duration_ms is greater, use that
+        caption_format_init_metadata_part = FormatInitializationMetadata(
+            video_id=example_video_id,
+            format_id=caption_format_id,
+            end_time_ms=15004,
+            total_segments=20,
+            mime_type='text/mp4',
+            duration_ticks=15004,
+            duration_timescale=1000,
+        )
+
+        processor.process_format_initialization_metadata(caption_format_init_metadata_part)
+        assert str(caption_format_id) in processor.initialized_formats
+        assert processor.total_duration_ms == 15005  # should not change
+
+    def test_no_duration(self, logger, base_args):
+        # Test the case where no duration is provided
+        selector = make_selector('audio')
+        format_id = selector.format_ids[0]
+        processor = SabrProcessor(**base_args, audio_selection=selector, video_id=example_video_id)
+
+        format_init_metadata_part = FormatInitializationMetadata(
+            video_id=example_video_id,
+            format_id=format_id,
+            end_time_ms=None,  # No end time
+            total_segments=5,
+            mime_type='audio/mp4',
+            duration_ticks=None,
+            duration_timescale=None,
+        )
+
+        assert processor.total_duration_ms is None
+        processor.process_format_initialization_metadata(format_init_metadata_part)
+        assert str(format_id) in processor.initialized_formats
+        assert processor.total_duration_ms == 0  # TODO: should this be None or 0?
+
+    def test_no_duration_total_duration_ms_set(self, logger, base_args):
+        # Test the case where no duration is provided but total_duration_ms is set (by e.g. live_metadata)
+        selector = make_selector('audio')
+        format_id = selector.format_ids[0]
+        processor = SabrProcessor(
+            **base_args,
+            audio_selection=selector,
+            video_id=example_video_id,
+        )
+
+        processor.total_duration_ms = 10000
+
+        format_init_metadata_part = FormatInitializationMetadata(
+            video_id=example_video_id,
+            format_id=format_id,
+            end_time_ms=None,  # No end time
+            total_segments=5,
+            mime_type='audio/mp4',
+            duration_ticks=None,
+            duration_timescale=None,
+        )
+
+        processor.process_format_initialization_metadata(format_init_metadata_part)
+        assert str(format_id) in processor.initialized_formats
+        assert processor.total_duration_ms == 10000
+
+    def test_video_id_mismatch(self, logger, base_args):
+        selector = make_selector('audio')
+        format_id = selector.format_ids[0]
+        processor = SabrProcessor(**base_args, audio_selection=selector, video_id='video_1')
+
+        format_init_metadata_part = FormatInitializationMetadata(
+            video_id='video_2',
+            format_id=format_id,
+            end_time_ms=10000,
+            total_segments=5,
+            mime_type='audio/mp4',
+            duration_ticks=10000,
+            duration_timescale=1000,
+        )
+
+        with pytest.raises(SabrStreamError, match='Received unexpected Format Initialization Metadata for video video_2'):
+            processor.process_format_initialization_metadata(format_init_metadata_part)
+
+        assert len(processor.initialized_formats) == 0
+
+    def test_selector_consumed(self, logger, base_args):
+        # Test that if a format selector is already in use, it raises an error
+        selector = make_selector('audio', format_ids=[])
+        audio_format_id = FormatId(itag=140)
+        processor = SabrProcessor(**base_args, audio_selection=selector, video_id=example_video_id)
+
+        format_init_metadata_part = FormatInitializationMetadata(
+            video_id=example_video_id,
+            format_id=audio_format_id,
+            end_time_ms=10000,
+            total_segments=5,
+            mime_type='audio/mp4',
+            duration_ticks=10000,
+            duration_timescale=1000,
+        )
+
+        processor.process_format_initialization_metadata(format_init_metadata_part)
+        assert str(audio_format_id) in processor.initialized_formats
+        assert processor.initialized_formats[str(audio_format_id)].format_selector is selector
+
+        with pytest.raises(SabrStreamError, match='Changing formats is not currently supported'):
+            processor.process_format_initialization_metadata(
+                dataclasses.replace(format_init_metadata_part, format_id=FormatId(itag=141)))
+
+    def test_no_segment_count(self, logger, base_args):
+        selector = make_selector('audio')
+        format_id = selector.format_ids[0]
+        processor = SabrProcessor(**base_args, audio_selection=selector, video_id=example_video_id)
+
+        format_init_metadata_part = FormatInitializationMetadata(
+            video_id=example_video_id,
+            format_id=format_id,
+            end_time_ms=10000,
+            mime_type='audio/mp4',
+            duration_ticks=10000,
+            duration_timescale=1000,
+        )
+
+        processor.process_format_initialization_metadata(format_init_metadata_part)
+        assert str(format_id) in processor.initialized_formats
+        assert processor.initialized_formats[str(format_id)].total_segments is None
+
+    def test_total_segment_count_live_metadata(self, logger, base_args):
+        # Test that total_segments is set from live_metadata when not in the format
+        audio_selector = make_selector('audio')
+        video_selector = make_selector('video')
+        audio_format_id = audio_selector.format_ids[0]
+        video_format_id = video_selector.format_ids[0]
+        processor = SabrProcessor(
+            **base_args, audio_selection=audio_selector, video_selection=video_selector,
+            video_id=example_video_id)
+
+        format_init_metadata_part = FormatInitializationMetadata(
+            video_id=example_video_id,
+            format_id=audio_format_id,
+            end_time_ms=10000,
+            mime_type='audio/mp4',
+            duration_ticks=10000,
+            duration_timescale=1000,
+        )
+
+        processor.live_metadata = LiveMetadata(head_sequence_number=10)
+        processor.process_format_initialization_metadata(format_init_metadata_part)
+        assert str(audio_format_id) in processor.initialized_formats
+        assert processor.initialized_formats[str(audio_format_id)].total_segments == 10
+
+        # But live metadata should not override total_segments if it is present
+        # XXX: when live metadata is updated, it will update the total_segments.
+        # However, we can consider the total_segments
+        # from the format initialization metadata as the most-up-to-date value until then.
+
+        video_format_init_metadata_part = FormatInitializationMetadata(
+            video_id=example_video_id,
+            format_id=video_format_id,
+            end_time_ms=10000,
+            # This should take precedence over live_metadata.
+            # Generally, this should only ever be greater than the live_metadata value.
+            # Never seen this be present for livestreams at this time.
+            # TODO: add a guard to ensure total segments is > live_metadata.head_sequence_number?
+            total_segments=9,
+            mime_type='video/mp4',
+            duration_ticks=10000,
+            duration_timescale=1000,
+        )
+
+        processor.process_format_initialization_metadata(video_format_init_metadata_part)
+        assert str(video_format_id) in processor.initialized_formats
+        assert processor.initialized_formats[str(video_format_id)].total_segments == 9
diff --git a/yt_dlp/extractor/youtube/_streaming/sabr/processor.py b/yt_dlp/extractor/youtube/_streaming/sabr/processor.py
index fa462b02bc..d4373c9681 100644
--- a/yt_dlp/extractor/youtube/_streaming/sabr/processor.py
+++ b/yt_dlp/extractor/youtube/_streaming/sabr/processor.py
@@ -525,7 +525,7 @@ def process_format_initialization_metadata(self, format_init_metadata: FormatIni
             if izf.format_selector is format_selector:
                 raise SabrStreamError('Server changed format. Changing formats is not currently supported')
 
-        duration_ms = ticks_to_ms(format_init_metadata.duration_timescale, format_init_metadata.duration_ticks)
+        duration_ms = ticks_to_ms(format_init_metadata.duration_ticks, format_init_metadata.duration_timescale)
 
         total_segments = format_init_metadata.total_segments
         if not total_segments and self.live_metadata and self.live_metadata.head_sequence_number: