From 53a2cd47323cdf2acc26c5af8116972f2236acb0 Mon Sep 17 00:00:00 2001 From: coletdjnz Date: Wed, 25 Jun 2025 08:34:47 +1200 Subject: [PATCH] [test] add more tests and minor refactoring --- test/test_sabr/test_processor.py | 67 +++++++++++++++++-- .../youtube/_streaming/sabr/processor.py | 6 +- .../youtube/_streaming/sabr/stream.py | 18 ++--- 3 files changed, 71 insertions(+), 20 deletions(-) diff --git a/test/test_sabr/test_processor.py b/test/test_sabr/test_processor.py index 064396b58c..351952f239 100644 --- a/test/test_sabr/test_processor.py +++ b/test/test_sabr/test_processor.py @@ -1,14 +1,16 @@ import pytest from unittest.mock import MagicMock -from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor +from yt_dlp.extractor.youtube._streaming.sabr.part import PoTokenStatusSabrPart + +from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor, ProcessStreamProtectionStatusResult from yt_dlp.extractor.youtube._streaming.sabr.models import ( AudioSelector, VideoSelector, CaptionSelector, ) -from yt_dlp.extractor.youtube._proto.videostreaming import FormatId -from yt_dlp.extractor.youtube._proto.innertube import ClientInfo +from yt_dlp.extractor.youtube._proto.videostreaming import FormatId, StreamProtectionStatus +from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy @pytest.fixture @@ -233,7 +235,6 @@ def test_defaults(self, base_args): assert processor.live_segment_target_duration_sec == 5 assert processor.live_segment_target_duration_tolerance_ms == 100 assert processor.start_time_ms == 0 - assert processor.live_end_segment_tolerance == 10 assert processor.post_live is False def test_override_defaults(self, base_args): @@ -242,11 +243,65 @@ def test_override_defaults(self, base_args): live_segment_target_duration_sec=8, live_segment_target_duration_tolerance_ms=42, start_time_ms=123, - live_end_segment_tolerance=3, post_live=True, ) assert processor.live_segment_target_duration_sec == 8 assert processor.live_segment_target_duration_tolerance_ms == 42 assert processor.start_time_ms == 123 - assert processor.live_end_segment_tolerance == 3 assert processor.post_live is True + + +class TestStreamProtectionStatusPart: + + @pytest.mark.parametrize( + 'sps,po_token,expected_status', + [ + (StreamProtectionStatus.Status.OK, None, PoTokenStatusSabrPart.PoTokenStatus.NOT_REQUIRED), + (StreamProtectionStatus.Status.OK, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.OK), + (StreamProtectionStatus.Status.ATTESTATION_PENDING, None, PoTokenStatusSabrPart.PoTokenStatus.PENDING_MISSING), + (StreamProtectionStatus.Status.ATTESTATION_PENDING, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.PENDING), + (StreamProtectionStatus.Status.ATTESTATION_REQUIRED, None, PoTokenStatusSabrPart.PoTokenStatus.MISSING), + (StreamProtectionStatus.Status.ATTESTATION_REQUIRED, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.INVALID), + ], + ) + def test_stream_protection_status_part(self, base_args, sps, po_token, expected_status): + processor = SabrProcessor(**base_args, po_token=po_token) + part = StreamProtectionStatus(status=sps) + + result = processor.process_stream_protection_status(part) + assert isinstance(result, ProcessStreamProtectionStatusResult) + assert isinstance(result.sabr_part, PoTokenStatusSabrPart) + assert result.sabr_part.status == expected_status + assert processor.stream_protection_status == sps + + def test_no_stream_protection_status(self, logger, base_args): + processor = SabrProcessor(**base_args, po_token='valid_token') + part = StreamProtectionStatus(status=None) + + result = processor.process_stream_protection_status(part) + assert isinstance(result, ProcessStreamProtectionStatusResult) + assert result.sabr_part is None + assert processor.stream_protection_status is None + assert logger.warning.call_count == 1 + logger.warning.assert_called_with( + 'Received an unknown StreamProtectionStatus: StreamProtectionStatus(status=None, max_retries=None)', + ) + + +class TestNextRequestPolicyPart: + def test_next_request_policy_part(self, logger, base_args): + processor = SabrProcessor(**base_args) + next_request_policy = NextRequestPolicy(target_audio_readahead_ms=123) + + result = processor.process_next_request_policy(next_request_policy) + assert result is None + assert processor.next_request_policy is next_request_policy + + # Verify it is overridden in the processor on another call + next_request_policy = NextRequestPolicy(target_video_readahead_ms=456) + result = processor.process_next_request_policy(next_request_policy) + assert result is None + assert processor.next_request_policy is next_request_policy + + # Check logger trace was called + assert logger.trace.call_count == 2 diff --git a/yt_dlp/extractor/youtube/_streaming/sabr/processor.py b/yt_dlp/extractor/youtube/_streaming/sabr/processor.py index c97efd7170..fa462b02bc 100644 --- a/yt_dlp/extractor/youtube/_streaming/sabr/processor.py +++ b/yt_dlp/extractor/youtube/_streaming/sabr/processor.py @@ -97,8 +97,6 @@ def __init__( live_segment_target_duration_tolerance_ms: int | None = None, start_time_ms: int | None = None, po_token: str | None = None, - live_end_wait_sec: int | None = None, - live_end_segment_tolerance: int | None = None, post_live: bool = False, video_id: str | None = None, ): @@ -119,9 +117,6 @@ def __init__( if self.start_time_ms < 0: raise ValueError('start_time_ms must be greater than or equal to 0') - # TODO: move to SabrStream - self.live_end_wait_sec = live_end_wait_sec or max(10, 3 * self.live_segment_target_duration_sec) - self.live_end_segment_tolerance = live_end_segment_tolerance or 10 self.post_live = post_live self._is_live = False self.video_id = video_id @@ -498,6 +493,7 @@ def process_stream_protection_status(self, stream_protection_status: StreamProte else PoTokenStatusSabrPart.PoTokenStatus.MISSING ) else: + self.logger.warning(f'Received an unknown StreamProtectionStatus: {stream_protection_status}') result_status = None sabr_part = PoTokenStatusSabrPart(status=result_status) if result_status is not None else None diff --git a/yt_dlp/extractor/youtube/_streaming/sabr/stream.py b/yt_dlp/extractor/youtube/_streaming/sabr/stream.py index 5e1096a9a8..d2ccd0fc4b 100644 --- a/yt_dlp/extractor/youtube/_streaming/sabr/stream.py +++ b/yt_dlp/extractor/youtube/_streaming/sabr/stream.py @@ -141,8 +141,6 @@ def __init__( live_segment_target_duration_tolerance_ms=live_segment_target_duration_tolerance_ms, start_time_ms=start_time_ms, po_token=po_token, - live_end_wait_sec=live_end_wait_sec, - live_end_segment_tolerance=live_end_segment_tolerance, post_live=post_live, video_id=video_id, ) @@ -151,6 +149,8 @@ def __init__( self.pot_retries = pot_retries or 5 self.host_fallback_threshold = host_fallback_threshold or 8 self.max_empty_requests = max_empty_requests or 3 + self.live_end_wait_sec = live_end_wait_sec or max(10, self.max_empty_requests * self.processor.live_segment_target_duration_sec) + self.live_end_segment_tolerance = live_end_segment_tolerance or 10 self.expiry_threshold_sec = expiry_threshold_sec or 60 # 60 seconds if self.expiry_threshold_sec <= 0: raise ValueError('expiry_threshold_sec must be greater than 0') @@ -346,7 +346,7 @@ def _validate_response_integrity(self): self.processor.live_metadata # TODO: generalize and self.processor.client_abr_state.player_time_ms >= ( - self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.processor.live_end_segment_tolerance)) + self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.live_end_segment_tolerance)) ): # Only log a warning if we are not near the head of a stream self.logger.debug(msg) @@ -469,9 +469,9 @@ def _prepare_next_playback_time(self): if ( self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests and not self._is_retry - and self._no_new_segments_tracker.timestamp_started < time.time() + self.processor.live_end_wait_sec + and self._no_new_segments_tracker.timestamp_started < time.time() + self.live_end_wait_sec ): - self.logger.debug(f'No new segments received for at least {self.processor.live_end_wait_sec} seconds, assuming end of live stream') + self.logger.debug(f'No new segments received for at least {self.live_end_wait_sec} seconds, assuming end of live stream') self._consumed = True else: wait_seconds = max(next_request_backoff_ms / 1000, self.processor.live_segment_target_duration_sec) @@ -499,7 +499,7 @@ def _prepare_next_playback_time(self): initialized_format.consumed_ranges, key=lambda cr: cr.end_sequence_number, )[-1].end_sequence_number - >= initialized_format.total_segments - self.processor.live_end_segment_tolerance + >= initialized_format.total_segments - self.live_end_segment_tolerance ) for initialized_format in enabled_initialized_formats ) @@ -515,16 +515,16 @@ def _prepare_next_playback_time(self): # Because of this, we should also check the player time against # the head segment time using the estimated segment duration. # xxx: consider also taking into account the max seekable timestamp - request_player_time >= self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.processor.live_end_segment_tolerance) + request_player_time >= self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.live_end_segment_tolerance) ) ) ): if ( not self._is_retry # allow us to sleep on a retry and self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests - and self._no_new_segments_tracker.timestamp_started < time.time() + self.processor.live_end_wait_sec + and self._no_new_segments_tracker.timestamp_started < time.time() + self.live_end_wait_sec ): - self.logger.debug(f'No new segments received for at least {self.processor.live_end_wait_sec} seconds; assuming end of live stream') + self.logger.debug(f'No new segments received for at least {self.live_end_wait_sec} seconds; assuming end of live stream') self._consumed = True elif self._no_new_segments_tracker.consecutive_requests >= 1: # Sometimes we can't get the head segment - rather tend to sit behind the head segment for the duration of the livestream