1
0
mirror of https://github.com/yt-dlp/yt-dlp.git synced 2025-06-28 01:18:30 +00:00

[test] add more tests and minor refactoring

This commit is contained in:
coletdjnz 2025-06-25 08:34:47 +12:00
parent 0fab874507
commit 53a2cd4732
No known key found for this signature in database
GPG Key ID: 91984263BB39894A
3 changed files with 71 additions and 20 deletions

View File

@ -1,14 +1,16 @@
import pytest import pytest
from unittest.mock import MagicMock from unittest.mock import MagicMock
from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor from yt_dlp.extractor.youtube._streaming.sabr.part import PoTokenStatusSabrPart
from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor, ProcessStreamProtectionStatusResult
from yt_dlp.extractor.youtube._streaming.sabr.models import ( from yt_dlp.extractor.youtube._streaming.sabr.models import (
AudioSelector, AudioSelector,
VideoSelector, VideoSelector,
CaptionSelector, CaptionSelector,
) )
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId from yt_dlp.extractor.youtube._proto.videostreaming import FormatId, StreamProtectionStatus
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy
@pytest.fixture @pytest.fixture
@ -233,7 +235,6 @@ def test_defaults(self, base_args):
assert processor.live_segment_target_duration_sec == 5 assert processor.live_segment_target_duration_sec == 5
assert processor.live_segment_target_duration_tolerance_ms == 100 assert processor.live_segment_target_duration_tolerance_ms == 100
assert processor.start_time_ms == 0 assert processor.start_time_ms == 0
assert processor.live_end_segment_tolerance == 10
assert processor.post_live is False assert processor.post_live is False
def test_override_defaults(self, base_args): def test_override_defaults(self, base_args):
@ -242,11 +243,65 @@ def test_override_defaults(self, base_args):
live_segment_target_duration_sec=8, live_segment_target_duration_sec=8,
live_segment_target_duration_tolerance_ms=42, live_segment_target_duration_tolerance_ms=42,
start_time_ms=123, start_time_ms=123,
live_end_segment_tolerance=3,
post_live=True, post_live=True,
) )
assert processor.live_segment_target_duration_sec == 8 assert processor.live_segment_target_duration_sec == 8
assert processor.live_segment_target_duration_tolerance_ms == 42 assert processor.live_segment_target_duration_tolerance_ms == 42
assert processor.start_time_ms == 123 assert processor.start_time_ms == 123
assert processor.live_end_segment_tolerance == 3
assert processor.post_live is True assert processor.post_live is True
class TestStreamProtectionStatusPart:
@pytest.mark.parametrize(
'sps,po_token,expected_status',
[
(StreamProtectionStatus.Status.OK, None, PoTokenStatusSabrPart.PoTokenStatus.NOT_REQUIRED),
(StreamProtectionStatus.Status.OK, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.OK),
(StreamProtectionStatus.Status.ATTESTATION_PENDING, None, PoTokenStatusSabrPart.PoTokenStatus.PENDING_MISSING),
(StreamProtectionStatus.Status.ATTESTATION_PENDING, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.PENDING),
(StreamProtectionStatus.Status.ATTESTATION_REQUIRED, None, PoTokenStatusSabrPart.PoTokenStatus.MISSING),
(StreamProtectionStatus.Status.ATTESTATION_REQUIRED, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.INVALID),
],
)
def test_stream_protection_status_part(self, base_args, sps, po_token, expected_status):
processor = SabrProcessor(**base_args, po_token=po_token)
part = StreamProtectionStatus(status=sps)
result = processor.process_stream_protection_status(part)
assert isinstance(result, ProcessStreamProtectionStatusResult)
assert isinstance(result.sabr_part, PoTokenStatusSabrPart)
assert result.sabr_part.status == expected_status
assert processor.stream_protection_status == sps
def test_no_stream_protection_status(self, logger, base_args):
processor = SabrProcessor(**base_args, po_token='valid_token')
part = StreamProtectionStatus(status=None)
result = processor.process_stream_protection_status(part)
assert isinstance(result, ProcessStreamProtectionStatusResult)
assert result.sabr_part is None
assert processor.stream_protection_status is None
assert logger.warning.call_count == 1
logger.warning.assert_called_with(
'Received an unknown StreamProtectionStatus: StreamProtectionStatus(status=None, max_retries=None)',
)
class TestNextRequestPolicyPart:
def test_next_request_policy_part(self, logger, base_args):
processor = SabrProcessor(**base_args)
next_request_policy = NextRequestPolicy(target_audio_readahead_ms=123)
result = processor.process_next_request_policy(next_request_policy)
assert result is None
assert processor.next_request_policy is next_request_policy
# Verify it is overridden in the processor on another call
next_request_policy = NextRequestPolicy(target_video_readahead_ms=456)
result = processor.process_next_request_policy(next_request_policy)
assert result is None
assert processor.next_request_policy is next_request_policy
# Check logger trace was called
assert logger.trace.call_count == 2

View File

@ -97,8 +97,6 @@ def __init__(
live_segment_target_duration_tolerance_ms: int | None = None, live_segment_target_duration_tolerance_ms: int | None = None,
start_time_ms: int | None = None, start_time_ms: int | None = None,
po_token: str | None = None, po_token: str | None = None,
live_end_wait_sec: int | None = None,
live_end_segment_tolerance: int | None = None,
post_live: bool = False, post_live: bool = False,
video_id: str | None = None, video_id: str | None = None,
): ):
@ -119,9 +117,6 @@ def __init__(
if self.start_time_ms < 0: if self.start_time_ms < 0:
raise ValueError('start_time_ms must be greater than or equal to 0') raise ValueError('start_time_ms must be greater than or equal to 0')
# TODO: move to SabrStream
self.live_end_wait_sec = live_end_wait_sec or max(10, 3 * self.live_segment_target_duration_sec)
self.live_end_segment_tolerance = live_end_segment_tolerance or 10
self.post_live = post_live self.post_live = post_live
self._is_live = False self._is_live = False
self.video_id = video_id self.video_id = video_id
@ -498,6 +493,7 @@ def process_stream_protection_status(self, stream_protection_status: StreamProte
else PoTokenStatusSabrPart.PoTokenStatus.MISSING else PoTokenStatusSabrPart.PoTokenStatus.MISSING
) )
else: else:
self.logger.warning(f'Received an unknown StreamProtectionStatus: {stream_protection_status}')
result_status = None result_status = None
sabr_part = PoTokenStatusSabrPart(status=result_status) if result_status is not None else None sabr_part = PoTokenStatusSabrPart(status=result_status) if result_status is not None else None

View File

@ -141,8 +141,6 @@ def __init__(
live_segment_target_duration_tolerance_ms=live_segment_target_duration_tolerance_ms, live_segment_target_duration_tolerance_ms=live_segment_target_duration_tolerance_ms,
start_time_ms=start_time_ms, start_time_ms=start_time_ms,
po_token=po_token, po_token=po_token,
live_end_wait_sec=live_end_wait_sec,
live_end_segment_tolerance=live_end_segment_tolerance,
post_live=post_live, post_live=post_live,
video_id=video_id, video_id=video_id,
) )
@ -151,6 +149,8 @@ def __init__(
self.pot_retries = pot_retries or 5 self.pot_retries = pot_retries or 5
self.host_fallback_threshold = host_fallback_threshold or 8 self.host_fallback_threshold = host_fallback_threshold or 8
self.max_empty_requests = max_empty_requests or 3 self.max_empty_requests = max_empty_requests or 3
self.live_end_wait_sec = live_end_wait_sec or max(10, self.max_empty_requests * self.processor.live_segment_target_duration_sec)
self.live_end_segment_tolerance = live_end_segment_tolerance or 10
self.expiry_threshold_sec = expiry_threshold_sec or 60 # 60 seconds self.expiry_threshold_sec = expiry_threshold_sec or 60 # 60 seconds
if self.expiry_threshold_sec <= 0: if self.expiry_threshold_sec <= 0:
raise ValueError('expiry_threshold_sec must be greater than 0') raise ValueError('expiry_threshold_sec must be greater than 0')
@ -346,7 +346,7 @@ def _validate_response_integrity(self):
self.processor.live_metadata self.processor.live_metadata
# TODO: generalize # TODO: generalize
and self.processor.client_abr_state.player_time_ms >= ( and self.processor.client_abr_state.player_time_ms >= (
self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.processor.live_end_segment_tolerance)) self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.live_end_segment_tolerance))
): ):
# Only log a warning if we are not near the head of a stream # Only log a warning if we are not near the head of a stream
self.logger.debug(msg) self.logger.debug(msg)
@ -469,9 +469,9 @@ def _prepare_next_playback_time(self):
if ( if (
self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
and not self._is_retry and not self._is_retry
and self._no_new_segments_tracker.timestamp_started < time.time() + self.processor.live_end_wait_sec and self._no_new_segments_tracker.timestamp_started < time.time() + self.live_end_wait_sec
): ):
self.logger.debug(f'No new segments received for at least {self.processor.live_end_wait_sec} seconds, assuming end of live stream') self.logger.debug(f'No new segments received for at least {self.live_end_wait_sec} seconds, assuming end of live stream')
self._consumed = True self._consumed = True
else: else:
wait_seconds = max(next_request_backoff_ms / 1000, self.processor.live_segment_target_duration_sec) wait_seconds = max(next_request_backoff_ms / 1000, self.processor.live_segment_target_duration_sec)
@ -499,7 +499,7 @@ def _prepare_next_playback_time(self):
initialized_format.consumed_ranges, initialized_format.consumed_ranges,
key=lambda cr: cr.end_sequence_number, key=lambda cr: cr.end_sequence_number,
)[-1].end_sequence_number )[-1].end_sequence_number
>= initialized_format.total_segments - self.processor.live_end_segment_tolerance >= initialized_format.total_segments - self.live_end_segment_tolerance
) )
for initialized_format in enabled_initialized_formats for initialized_format in enabled_initialized_formats
) )
@ -515,16 +515,16 @@ def _prepare_next_playback_time(self):
# Because of this, we should also check the player time against # Because of this, we should also check the player time against
# the head segment time using the estimated segment duration. # the head segment time using the estimated segment duration.
# xxx: consider also taking into account the max seekable timestamp # xxx: consider also taking into account the max seekable timestamp
request_player_time >= self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.processor.live_end_segment_tolerance) request_player_time >= self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.live_end_segment_tolerance)
) )
) )
): ):
if ( if (
not self._is_retry # allow us to sleep on a retry not self._is_retry # allow us to sleep on a retry
and self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests and self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
and self._no_new_segments_tracker.timestamp_started < time.time() + self.processor.live_end_wait_sec and self._no_new_segments_tracker.timestamp_started < time.time() + self.live_end_wait_sec
): ):
self.logger.debug(f'No new segments received for at least {self.processor.live_end_wait_sec} seconds; assuming end of live stream') self.logger.debug(f'No new segments received for at least {self.live_end_wait_sec} seconds; assuming end of live stream')
self._consumed = True self._consumed = True
elif self._no_new_segments_tracker.consecutive_requests >= 1: elif self._no_new_segments_tracker.consecutive_requests >= 1:
# Sometimes we can't get the head segment - rather tend to sit behind the head segment for the duration of the livestream # Sometimes we can't get the head segment - rather tend to sit behind the head segment for the duration of the livestream