From 75fb53bccc84583f70238f6b6f32d9b0bf5d132c Mon Sep 17 00:00:00 2001 From: coletdjnz Date: Tue, 1 Jul 2025 07:45:58 +1200 Subject: [PATCH] [test] add more sabr processor tests --- test/test_sabr/test_processor.py | 437 +++++++++++++++++- .../youtube/_streaming/sabr/processor.py | 2 +- 2 files changed, 433 insertions(+), 6 deletions(-) diff --git a/test/test_sabr/test_processor.py b/test/test_sabr/test_processor.py index e43a15dc6a..9a2e63a8ed 100644 --- a/test/test_sabr/test_processor.py +++ b/test/test_sabr/test_processor.py @@ -4,18 +4,33 @@ from unittest.mock import MagicMock from yt_dlp.extractor.youtube._streaming.sabr.exceptions import SabrStreamError -from yt_dlp.extractor.youtube._streaming.sabr.part import PoTokenStatusSabrPart, FormatInitializedSabrPart +from yt_dlp.extractor.youtube._streaming.sabr.part import ( + PoTokenStatusSabrPart, + FormatInitializedSabrPart, + MediaSeekSabrPart, +) -from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor, ProcessStreamProtectionStatusResult, \ - ProcessFormatInitializationMetadataResult +from yt_dlp.extractor.youtube._streaming.sabr.processor import ( + SabrProcessor, + ProcessStreamProtectionStatusResult, + ProcessFormatInitializationMetadataResult, + ProcessLiveMetadataResult, ProcessSabrSeekResult, +) from yt_dlp.extractor.youtube._streaming.sabr.models import ( AudioSelector, VideoSelector, CaptionSelector, InitializedFormat, + Segment, +) +from yt_dlp.extractor.youtube._proto.videostreaming import ( + FormatId, + StreamProtectionStatus, + FormatInitializationMetadata, + LiveMetadata, + SabrContextUpdate, + SabrContextSendingPolicy, SabrSeek, ) -from yt_dlp.extractor.youtube._proto.videostreaming import FormatId, StreamProtectionStatus, \ - FormatInitializationMetadata, LiveMetadata from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy @@ -805,3 +820,415 @@ def test_total_segment_count_live_metadata(self, logger, base_args): processor.process_format_initialization_metadata(video_format_init_metadata_part) assert str(video_format_id) in processor.initialized_formats assert processor.initialized_formats[str(video_format_id)].total_segments == 9 + + +class TestLiveMetadata: + + def test_live_metadata_initialization(self, base_args): + processor = SabrProcessor(**base_args) + assert processor.live_metadata is None + + def test_live_metadata_update(self, base_args): + processor = SabrProcessor(**base_args) + live_metadata = LiveMetadata(head_sequence_number=10) + + result = processor.process_live_metadata(live_metadata) + assert isinstance(result, ProcessLiveMetadataResult) + assert len(result.seek_sabr_parts) == 0 + assert processor.live_metadata is live_metadata + + # Ensure new live_metadata replaces the old one + live_metadata = dataclasses.replace(live_metadata, head_sequence_number=20) + result = processor.process_live_metadata(live_metadata) + + assert isinstance(result, ProcessLiveMetadataResult) + assert len(result.seek_sabr_parts) == 0 + assert processor.live_metadata is live_metadata + + def test_live_metadata_no_head_sequence_time_ms(self, base_args): + processor = SabrProcessor(**base_args) + live_metadata = LiveMetadata(head_sequence_number=10, head_sequence_time_ms=None) + + processor.process_live_metadata(live_metadata) + assert processor.live_metadata is live_metadata + assert processor.total_duration_ms is None + + def test_live_metadata_with_head_sequence_time_ms(self, base_args): + processor = SabrProcessor(**base_args) + live_metadata = LiveMetadata(head_sequence_number=10, head_sequence_time_ms=5000) + + processor.process_live_metadata(live_metadata) + assert processor.live_metadata is live_metadata + assert processor.total_duration_ms == 5000 + + def test_update_izf_total_segments(self, base_args): + + live_metadata = LiveMetadata(head_sequence_number=10) + audio_selector = make_selector('audio') + video_selector = make_selector('video') + processor = SabrProcessor( + **base_args, + audio_selection=audio_selector, + video_selection=video_selector, + video_id=example_video_id) + + # Initialize both audio and video formats + audio_format_id = audio_selector.format_ids[0] + video_format_id = video_selector.format_ids[0] + + audio_format_init_metadata = FormatInitializationMetadata( + video_id=example_video_id, + format_id=audio_format_id, + mime_type='audio/mp4', + ) + + video_format_init_metadata = FormatInitializationMetadata( + video_id=example_video_id, + format_id=video_format_id, + mime_type='video/mp4', + ) + + processor.process_format_initialization_metadata(audio_format_init_metadata) + processor.process_format_initialization_metadata(video_format_init_metadata) + assert len(processor.initialized_formats) == 2 + + # Process live metadata + processor.process_live_metadata(live_metadata) + assert processor.live_metadata is live_metadata + + # Check that total_segments is updated in both formats + assert processor.initialized_formats[str(audio_format_id)].total_segments == 10 + assert processor.initialized_formats[str(video_format_id)].total_segments == 10 + + def test_min_seekable_time_ms_less_than_player_time_ms(self, base_args): + # If min_seekable_time_ms is greater or equal to player time, there should not be a seek + live_metadata = LiveMetadata( + head_sequence_number=10, + head_sequence_time_ms=10000, + min_seekable_time_ticks=50000, + min_seekable_timescale=10000) + + audio_selector = make_selector('audio') + video_selector = make_selector('video') + processor = SabrProcessor( + **base_args, + audio_selection=audio_selector, + video_selection=video_selector, + video_id=example_video_id, + start_time_ms=5001) + + assert processor.client_abr_state.player_time_ms == 5001 + + # Initialize both audio and video formats + audio_format_id = audio_selector.format_ids[0] + video_format_id = video_selector.format_ids[0] + + audio_format_init_metadata = FormatInitializationMetadata( + video_id=example_video_id, + format_id=audio_format_id, + mime_type='audio/mp4', + ) + + video_format_init_metadata = FormatInitializationMetadata( + video_id=example_video_id, + format_id=video_format_id, + mime_type='video/mp4', + ) + + processor.process_format_initialization_metadata(audio_format_init_metadata) + processor.process_format_initialization_metadata(video_format_init_metadata) + assert len(processor.initialized_formats) == 2 + + # Process live metadata + result = processor.process_live_metadata(live_metadata) + assert isinstance(result, ProcessLiveMetadataResult) + assert len(result.seek_sabr_parts) == 0 + assert processor.live_metadata is live_metadata + assert processor.client_abr_state.player_time_ms == 5001 + + def test_min_seekable_time_ms_greater_than_player_time_ms(self, base_args, logger): + # If min_seekable_time_ms is less than player time, there should be a seek + live_metadata = LiveMetadata( + head_sequence_number=10, + head_sequence_time_ms=10000, + min_seekable_time_ticks=50000, + min_seekable_timescale=10000, + ) + + audio_selector = make_selector('audio') + video_selector = make_selector('video') + processor = SabrProcessor( + **base_args, + audio_selection=audio_selector, + video_selection=video_selector, + video_id=example_video_id, + start_time_ms=4999) + + assert processor.client_abr_state.player_time_ms == 4999 + + # Initialize both audio and video formats + audio_format_id = audio_selector.format_ids[0] + video_format_id = video_selector.format_ids[0] + + audio_format_init_metadata = FormatInitializationMetadata( + video_id=example_video_id, + format_id=audio_format_id, + mime_type='audio/mp4', + ) + + video_format_init_metadata = FormatInitializationMetadata( + video_id=example_video_id, + format_id=video_format_id, + mime_type='video/mp4', + ) + + processor.process_format_initialization_metadata(audio_format_init_metadata) + processor.process_format_initialization_metadata(video_format_init_metadata) + assert len(processor.initialized_formats) == 2 + + # Add a dummy previous segment to each format - this should be cleared on seek + for izf in processor.initialized_formats.values(): + izf.current_segment = Segment( + format_id=izf.format_id, + ) + + # Process live metadata + result = processor.process_live_metadata(live_metadata) + assert isinstance(result, ProcessLiveMetadataResult) + assert len(result.seek_sabr_parts) == 2 + assert processor.live_metadata is live_metadata + assert processor.client_abr_state.player_time_ms == 5000 + + for seek_part in result.seek_sabr_parts: + assert isinstance(seek_part, MediaSeekSabrPart) + assert seek_part.format_id in (audio_format_id, video_format_id) + assert seek_part.format_selector in (audio_selector, video_selector) + + # Current segment should be cleared to indicate a seek + for izf in processor.initialized_formats.values(): + assert izf.current_segment is None + + logger.debug.assert_called_with('Player time 4999 is less than min seekable time 5000, simulating server seek') + + +class TestSabrContextUpdate: + def test_initialization(self, base_args): + processor = SabrProcessor(**base_args) + assert len(processor.sabr_context_updates) == 0 + assert len(processor.sabr_contexts_to_send) == 0 + + def test_invalid_sabr_context_update(self, logger, base_args): + processor = SabrProcessor(**base_args) + invalid_update = SabrContextUpdate() + processor.process_sabr_context_update(invalid_update) + + assert len(processor.sabr_context_updates) == 0 + assert len(processor.sabr_contexts_to_send) == 0 + logger.warning.assert_called_with('Received an invalid SabrContextUpdate, ignoring') + + def test_valid_sabr_context_update(self, logger, base_args): + processor = SabrProcessor(**base_args) + valid_update = SabrContextUpdate( + type=5, + scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS, + value=b'{"key": "value"}', + send_by_default=True, + write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE, + ) + processor.process_sabr_context_update(valid_update) + + assert len(processor.sabr_context_updates) == 1 + assert processor.sabr_context_updates[valid_update.type] == valid_update + assert len(processor.sabr_contexts_to_send) == 1 + assert valid_update.type in processor.sabr_contexts_to_send + logger.debug.assert_called_with(f'Registered SabrContextUpdate {valid_update}') + + def test_not_send_by_default(self, logger, base_args): + processor = SabrProcessor(**base_args) + valid_update = SabrContextUpdate( + type=5, + scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS, + value=b'{"key": "value"}', + send_by_default=False, + write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE, + ) + processor.process_sabr_context_update(valid_update) + assert len(processor.sabr_context_updates) == 1 + assert processor.sabr_context_updates[valid_update.type] == valid_update + assert len(processor.sabr_contexts_to_send) == 0 + + def test_write_policy_overwrite(self, logger, base_args): + processor = SabrProcessor(**base_args) + first_ctx_update = SabrContextUpdate( + type=5, + scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS, + value=b'{"key": "value"}', + send_by_default=True, + write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE, + ) + processor.process_sabr_context_update(first_ctx_update) + + assert len(processor.sabr_context_updates) == 1 + assert processor.sabr_context_updates[first_ctx_update.type] == first_ctx_update + assert len(processor.sabr_contexts_to_send) == 1 + assert first_ctx_update.type in processor.sabr_contexts_to_send + + second_ctx_update = SabrContextUpdate( + type=5, + scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS, + value=b'{"key": "new_value"}', + send_by_default=True, + write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE, + ) + + processor.process_sabr_context_update(second_ctx_update) + assert len(processor.sabr_context_updates) == 1 + assert processor.sabr_context_updates[second_ctx_update.type] == second_ctx_update + assert len(processor.sabr_contexts_to_send) == 1 + assert second_ctx_update.type in processor.sabr_contexts_to_send + + def test_write_policy_keep_existing(self, logger, base_args): + processor = SabrProcessor(**base_args) + first_ctx_update = SabrContextUpdate( + type=5, + scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS, + value=b'{"key": "value"}', + send_by_default=True, + write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_KEEP_EXISTING, + ) + processor.process_sabr_context_update(first_ctx_update) + + assert len(processor.sabr_context_updates) == 1 + assert processor.sabr_context_updates[first_ctx_update.type] == first_ctx_update + assert len(processor.sabr_contexts_to_send) == 1 + assert first_ctx_update.type in processor.sabr_contexts_to_send + + second_ctx_update = SabrContextUpdate( + type=5, + scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS, + value=b'{"key": "new_value"}', + send_by_default=True, + write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_KEEP_EXISTING, + ) + + processor.process_sabr_context_update(second_ctx_update) + assert len(processor.sabr_context_updates) == 1 + assert processor.sabr_context_updates[first_ctx_update.type] == first_ctx_update + assert len(processor.sabr_contexts_to_send) == 1 + assert first_ctx_update.type in processor.sabr_contexts_to_send + logger.debug.assert_called_with( + 'Received a SABR Context Update with write_policy=KEEP_EXISTING' + 'matching an existing SABR Context Update. Ignoring update') + + +class TestSabrContextUpdateSendingPolicy: + def test_set_sabr_context_update_sending_policy(self, base_args, logger): + processor = SabrProcessor(**base_args) + + processor.process_sabr_context_update(SabrContextUpdate( + type=3, + scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_PLAYBACK, + value=b'{"key": "value"}', + send_by_default=True, + write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE, + )) + + processor.process_sabr_context_update(SabrContextUpdate( + type=4, + scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_REQUEST, + value=b'{"key": "value"}', + send_by_default=True, + write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_KEEP_EXISTING, + )) + + processor.process_sabr_context_update(SabrContextUpdate( + type=5, + scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS, + value=b'{"key": "value"}', + send_by_default=False, + write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE, + )) + + assert len(processor.sabr_context_updates) == 3 + assert len(processor.sabr_contexts_to_send) == 2 + assert 3 in processor.sabr_contexts_to_send + assert 4 in processor.sabr_contexts_to_send + assert 5 not in processor.sabr_contexts_to_send + + # Sending policy should update what contexts are sent + processor.process_sabr_context_sending_policy( + SabrContextSendingPolicy( + start_policy=[5, 6], + stop_policy=[3, 0], + discard_policy=[4, 7])) + + assert len(processor.sabr_context_updates) == 2 + assert len(processor.sabr_contexts_to_send) == 3 + assert 5 in processor.sabr_contexts_to_send + assert 4 in processor.sabr_contexts_to_send # discarding does not remove from contexts to send + assert 6 in processor.sabr_contexts_to_send + assert all(n not in processor.sabr_contexts_to_send for n in [3, 0, 7]) + assert 4 not in processor.sabr_context_updates + + +class TestSabrSeek: + def test_invalid_sabr_seek(self, logger, base_args): + processor = SabrProcessor(**base_args) + invalid_seek = SabrSeek(seek_time_ticks=100, timescale=None) + with pytest.raises(SabrStreamError, match='Server sent a SabrSeek part that is missing required seek data'): + processor.process_sabr_seek(invalid_seek) + assert processor.client_abr_state.player_time_ms == 0 + + def test_sabr_seek(self, logger, base_args): + audio_selector = make_selector('audio') + video_selector = make_selector('video') + audio_format_id = audio_selector.format_ids[0] + video_format_id = video_selector.format_ids[0] + processor = SabrProcessor( + **base_args, + audio_selection=audio_selector, + video_selection=video_selector, + video_id=example_video_id) + + assert processor.client_abr_state.player_time_ms == 0 + + audio_format_init_metadata = FormatInitializationMetadata( + video_id=example_video_id, + format_id=audio_format_id, + mime_type='audio/mp4', + ) + video_format_init_metadata = FormatInitializationMetadata( + video_id=example_video_id, + format_id=video_format_id, + mime_type='video/mp4', + ) + + processor.process_format_initialization_metadata(audio_format_init_metadata) + processor.process_format_initialization_metadata(video_format_init_metadata) + assert len(processor.initialized_formats) == 2 + + # Add a dummy previous segment to each format - this should be cleared on seek + for izf in processor.initialized_formats.values(): + izf.current_segment = Segment( + format_id=izf.format_id, + ) + + sabr_seek = SabrSeek( + seek_time_ticks=56000, + timescale=10000, + ) + + result = processor.process_sabr_seek(sabr_seek) + assert isinstance(result, ProcessSabrSeekResult) + assert len(result.seek_sabr_parts) == 2 + assert processor.client_abr_state.player_time_ms == 5600 + for seek_part in result.seek_sabr_parts: + assert isinstance(seek_part, MediaSeekSabrPart) + assert seek_part.format_id in (audio_format_id, video_format_id) + assert seek_part.format_selector in (audio_selector, video_selector) + + # Current segment should be cleared to indicate a seek + for izf in processor.initialized_formats.values(): + assert izf.current_segment is None + + logger.debug.assert_called_with('Seeking to 5600ms') diff --git a/yt_dlp/extractor/youtube/_streaming/sabr/processor.py b/yt_dlp/extractor/youtube/_streaming/sabr/processor.py index d4373c9681..6acea99750 100644 --- a/yt_dlp/extractor/youtube/_streaming/sabr/processor.py +++ b/yt_dlp/extractor/youtube/_streaming/sabr/processor.py @@ -614,7 +614,7 @@ def process_sabr_context_update(self, sabr_ctx_update: SabrContextUpdate): def process_sabr_context_sending_policy(self, sabr_ctx_sending_policy: SabrContextSendingPolicy): for start_type in sabr_ctx_sending_policy.start_policy: - if start_type not in self.sabr_context_updates: + if start_type not in self.sabr_contexts_to_send: self.logger.debug(f'Server requested to enable SABR Context Update for type {start_type}') self.sabr_contexts_to_send.add(start_type)