mirror of https://github.com/yt-dlp/yt-dlp.git synced 2025-06-27 17:08:32 +00:00

[test] add more tests and minor refactoring

Author: coletdjnz
Date:   2025-06-25 08:34:47 +12:00
Commit: 53a2cd4732 (parent 0fab874507)
3 changed files with 71 additions and 20 deletions

View File

@@ -1,14 +1,16 @@
 import pytest
 from unittest.mock import MagicMock
-from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor
+from yt_dlp.extractor.youtube._streaming.sabr.part import PoTokenStatusSabrPart
+from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor, ProcessStreamProtectionStatusResult
 from yt_dlp.extractor.youtube._streaming.sabr.models import (
     AudioSelector,
     VideoSelector,
     CaptionSelector,
 )
-from yt_dlp.extractor.youtube._proto.videostreaming import FormatId
-from yt_dlp.extractor.youtube._proto.innertube import ClientInfo
+from yt_dlp.extractor.youtube._proto.videostreaming import FormatId, StreamProtectionStatus
+from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy


 @pytest.fixture
@@ -233,7 +235,6 @@ def test_defaults(self, base_args):
         assert processor.live_segment_target_duration_sec == 5
         assert processor.live_segment_target_duration_tolerance_ms == 100
         assert processor.start_time_ms == 0
-        assert processor.live_end_segment_tolerance == 10
         assert processor.post_live is False

     def test_override_defaults(self, base_args):
@@ -242,11 +243,65 @@ def test_override_defaults(self, base_args):
             live_segment_target_duration_sec=8,
             live_segment_target_duration_tolerance_ms=42,
             start_time_ms=123,
-            live_end_segment_tolerance=3,
             post_live=True,
         )
         assert processor.live_segment_target_duration_sec == 8
         assert processor.live_segment_target_duration_tolerance_ms == 42
         assert processor.start_time_ms == 123
-        assert processor.live_end_segment_tolerance == 3
         assert processor.post_live is True
+
+
+class TestStreamProtectionStatusPart:
+    @pytest.mark.parametrize(
+        'sps,po_token,expected_status',
+        [
+            (StreamProtectionStatus.Status.OK, None, PoTokenStatusSabrPart.PoTokenStatus.NOT_REQUIRED),
+            (StreamProtectionStatus.Status.OK, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.OK),
+            (StreamProtectionStatus.Status.ATTESTATION_PENDING, None, PoTokenStatusSabrPart.PoTokenStatus.PENDING_MISSING),
+            (StreamProtectionStatus.Status.ATTESTATION_PENDING, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.PENDING),
+            (StreamProtectionStatus.Status.ATTESTATION_REQUIRED, None, PoTokenStatusSabrPart.PoTokenStatus.MISSING),
+            (StreamProtectionStatus.Status.ATTESTATION_REQUIRED, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.INVALID),
+        ],
+    )
+    def test_stream_protection_status_part(self, base_args, sps, po_token, expected_status):
+        processor = SabrProcessor(**base_args, po_token=po_token)
+        part = StreamProtectionStatus(status=sps)
+        result = processor.process_stream_protection_status(part)
+        assert isinstance(result, ProcessStreamProtectionStatusResult)
+        assert isinstance(result.sabr_part, PoTokenStatusSabrPart)
+        assert result.sabr_part.status == expected_status
+        assert processor.stream_protection_status == sps
+
+    def test_no_stream_protection_status(self, logger, base_args):
+        processor = SabrProcessor(**base_args, po_token='valid_token')
+        part = StreamProtectionStatus(status=None)
+        result = processor.process_stream_protection_status(part)
+        assert isinstance(result, ProcessStreamProtectionStatusResult)
+        assert result.sabr_part is None
+        assert processor.stream_protection_status is None
+        assert logger.warning.call_count == 1
+        logger.warning.assert_called_with(
+            'Received an unknown StreamProtectionStatus: StreamProtectionStatus(status=None, max_retries=None)',
+        )
+
+
+class TestNextRequestPolicyPart:
+    def test_next_request_policy_part(self, logger, base_args):
+        processor = SabrProcessor(**base_args)
+        next_request_policy = NextRequestPolicy(target_audio_readahead_ms=123)
+        result = processor.process_next_request_policy(next_request_policy)
+        assert result is None
+        assert processor.next_request_policy is next_request_policy
+
+        # Verify it is overridden in the processor on another call
+        next_request_policy = NextRequestPolicy(target_video_readahead_ms=456)
+        result = processor.process_next_request_policy(next_request_policy)
+        assert result is None
+        assert processor.next_request_policy is next_request_policy
+
+        # Check logger trace was called
+        assert logger.trace.call_count == 2
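
Note: the parametrized cases above fully determine a small decision table: the server's StreamProtectionStatus crossed with whether a PO token was supplied. The following is a minimal standalone sketch of that mapping, not the processor's actual implementation; the two enums here merely mirror the values the tests exercise.

from enum import Enum, auto


class Status(Enum):
    # Stand-in mirroring the StreamProtectionStatus.Status values used above
    OK = auto()
    ATTESTATION_PENDING = auto()
    ATTESTATION_REQUIRED = auto()


class PoTokenStatus(Enum):
    # Stand-in mirroring the PoTokenStatusSabrPart.PoTokenStatus values used above
    NOT_REQUIRED = auto()
    OK = auto()
    PENDING_MISSING = auto()
    PENDING = auto()
    MISSING = auto()
    INVALID = auto()


# (server status, PO token supplied?) -> reported PO token status
_PO_TOKEN_STATUS_TABLE = {
    (Status.OK, False): PoTokenStatus.NOT_REQUIRED,
    (Status.OK, True): PoTokenStatus.OK,
    (Status.ATTESTATION_PENDING, False): PoTokenStatus.PENDING_MISSING,
    (Status.ATTESTATION_PENDING, True): PoTokenStatus.PENDING,
    (Status.ATTESTATION_REQUIRED, False): PoTokenStatus.MISSING,
    (Status.ATTESTATION_REQUIRED, True): PoTokenStatus.INVALID,
}


def map_po_token_status(status: Status, po_token: str | None) -> PoTokenStatus:
    return _PO_TOKEN_STATUS_TABLE[(status, po_token is not None)]


assert map_po_token_status(Status.ATTESTATION_REQUIRED, 'valid_token') is PoTokenStatus.INVALID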

View File

@@ -97,8 +97,6 @@ def __init__(
         live_segment_target_duration_tolerance_ms: int | None = None,
         start_time_ms: int | None = None,
         po_token: str | None = None,
-        live_end_wait_sec: int | None = None,
-        live_end_segment_tolerance: int | None = None,
         post_live: bool = False,
         video_id: str | None = None,
     ):
@@ -119,9 +117,6 @@ def __init__(
         if self.start_time_ms < 0:
             raise ValueError('start_time_ms must be greater than or equal to 0')

-        # TODO: move to SabrStream
-        self.live_end_wait_sec = live_end_wait_sec or max(10, 3 * self.live_segment_target_duration_sec)
-        self.live_end_segment_tolerance = live_end_segment_tolerance or 10
         self.post_live = post_live
         self._is_live = False
         self.video_id = video_id
@@ -498,6 +493,7 @@ def process_stream_protection_status(self, stream_protection_status: StreamProte
                 else PoTokenStatusSabrPart.PoTokenStatus.MISSING
             )
         else:
+            self.logger.warning(f'Received an unknown StreamProtectionStatus: {stream_protection_status}')
             result_status = None

         sabr_part = PoTokenStatusSabrPart(status=result_status) if result_status is not None else None
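
Note: combined with test_no_stream_protection_status above, this hunk pins down the contract for unrecognized statuses: the processor warns once but still returns a result, just without a part. A runnable sketch of that shape follows, with simplified stand-in types; only the names ProcessStreamProtectionStatusResult and PoTokenStatusSabrPart come from the diff, the rest is illustrative.

from dataclasses import dataclass
from unittest.mock import MagicMock


@dataclass
class PoTokenStatusSabrPart:
    status: object


@dataclass
class ProcessStreamProtectionStatusResult:
    sabr_part: PoTokenStatusSabrPart | None = None


def process_stream_protection_status(stream_protection_status, result_status, logger):
    # result_status is the mapped PoTokenStatus, or None for an unknown
    # server status; only a known status produces a SABR part.
    if result_status is None:
        logger.warning(f'Received an unknown StreamProtectionStatus: {stream_protection_status}')
    sabr_part = PoTokenStatusSabrPart(status=result_status) if result_status is not None else None
    return ProcessStreamProtectionStatusResult(sabr_part=sabr_part)


logger = MagicMock()
result = process_stream_protection_status('StreamProtectionStatus(status=None, max_retries=None)', None, logger)
assert result.sabr_part is None
assert logger.warning.call_count == 1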

View File

@@ -141,8 +141,6 @@ def __init__(
             live_segment_target_duration_tolerance_ms=live_segment_target_duration_tolerance_ms,
             start_time_ms=start_time_ms,
             po_token=po_token,
-            live_end_wait_sec=live_end_wait_sec,
-            live_end_segment_tolerance=live_end_segment_tolerance,
             post_live=post_live,
             video_id=video_id,
         )
@@ -151,6 +149,8 @@ def __init__(
         self.pot_retries = pot_retries or 5
         self.host_fallback_threshold = host_fallback_threshold or 8
         self.max_empty_requests = max_empty_requests or 3
+        self.live_end_wait_sec = live_end_wait_sec or max(10, self.max_empty_requests * self.processor.live_segment_target_duration_sec)
+        self.live_end_segment_tolerance = live_end_segment_tolerance or 10
         self.expiry_threshold_sec = expiry_threshold_sec or 60  # 60 seconds
         if self.expiry_threshold_sec <= 0:
             raise ValueError('expiry_threshold_sec must be greater than 0')
@@ -346,7 +346,7 @@ def _validate_response_integrity(self):
                 self.processor.live_metadata
                 # TODO: generalize
                 and self.processor.client_abr_state.player_time_ms >= (
-                    self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.processor.live_end_segment_tolerance))
+                    self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.live_end_segment_tolerance))
             ):
                 # Only log a warning if we are not near the head of a stream
                 self.logger.debug(msg)
@@ -469,9 +469,9 @@ def _prepare_next_playback_time(self):
             if (
                 self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
                 and not self._is_retry
-                and self._no_new_segments_tracker.timestamp_started < time.time() + self.processor.live_end_wait_sec
+                and self._no_new_segments_tracker.timestamp_started < time.time() + self.live_end_wait_sec
             ):
-                self.logger.debug(f'No new segments received for at least {self.processor.live_end_wait_sec} seconds, assuming end of live stream')
+                self.logger.debug(f'No new segments received for at least {self.live_end_wait_sec} seconds, assuming end of live stream')
                 self._consumed = True
             else:
                 wait_seconds = max(next_request_backoff_ms / 1000, self.processor.live_segment_target_duration_sec)
@@ -499,7 +499,7 @@ def _prepare_next_playback_time(self):
                         initialized_format.consumed_ranges,
                         key=lambda cr: cr.end_sequence_number,
                     )[-1].end_sequence_number
-                    >= initialized_format.total_segments - self.processor.live_end_segment_tolerance
+                    >= initialized_format.total_segments - self.live_end_segment_tolerance
                 )
                 for initialized_format in enabled_initialized_formats
             )
@@ -515,16 +515,16 @@ def _prepare_next_playback_time(self):
                     # Because of this, we should also check the player time against
                     # the head segment time using the estimated segment duration.
                     # xxx: consider also taking into account the max seekable timestamp
-                    request_player_time >= self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.processor.live_end_segment_tolerance)
+                    request_player_time >= self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.live_end_segment_tolerance)
                 )
             )
         ):
             if (
                 not self._is_retry  # allow us to sleep on a retry
                 and self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
-                and self._no_new_segments_tracker.timestamp_started < time.time() + self.processor.live_end_wait_sec
+                and self._no_new_segments_tracker.timestamp_started < time.time() + self.live_end_wait_sec
             ):
-                self.logger.debug(f'No new segments received for at least {self.processor.live_end_wait_sec} seconds; assuming end of live stream')
+                self.logger.debug(f'No new segments received for at least {self.live_end_wait_sec} seconds; assuming end of live stream')
                 self._consumed = True
             elif self._no_new_segments_tracker.consecutive_requests >= 1:
                 # Sometimes we can't get the head segment - rather tend to sit behind the head segment for the duration of the livestream
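
Note: after this refactor both live-end knobs are owned by the stream rather than the processor. live_end_wait_sec now defaults to max(10, max_empty_requests * live_segment_target_duration_sec) instead of the processor's old hardcoded max(10, 3 * live_segment_target_duration_sec), and live_end_segment_tolerance keeps its default of 10. With the defaults visible in this diff (max_empty_requests = 3, a 5-second target segment duration), that works out to max(10, 15) = 15 seconds, and the "near the head" window to 5 * 1000 * 10 = 50000 ms. A standalone sketch of the two computations, with illustrative helper names that are not from the codebase:

def live_end_wait_sec(max_empty_requests: int = 3, target_duration_sec: int = 5,
                      override: int | None = None) -> int:
    # New default: scale the wait with how many empty requests are tolerated,
    # replacing the processor's old max(10, 3 * target_duration_sec).
    return override or max(10, max_empty_requests * target_duration_sec)


def near_live_head(player_time_ms: int, head_sequence_time_ms: int,
                   target_duration_sec: int = 5, live_end_segment_tolerance: int = 10) -> bool:
    # "Near the head" means within tolerance-many estimated segment durations
    # of the head segment's timestamp.
    window_ms = target_duration_sec * 1000 * live_end_segment_tolerance
    return player_time_ms >= head_sequence_time_ms - window_ms


assert live_end_wait_sec() == 15                    # max(10, 3 * 5)
assert near_live_head(955_000, 1_000_000) is True   # within the 50s window
assert near_live_head(940_000, 1_000_000) is False  # more than 50s behind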