mirror of
https://github.com/yt-dlp/yt-dlp.git
synced 2025-06-27 17:08:32 +00:00
[test] add more tests and minor refactoring
This commit is contained in:
parent
0fab874507
commit
53a2cd4732
@ -1,14 +1,16 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor
|
||||
from yt_dlp.extractor.youtube._streaming.sabr.part import PoTokenStatusSabrPart
|
||||
|
||||
from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor, ProcessStreamProtectionStatusResult
|
||||
from yt_dlp.extractor.youtube._streaming.sabr.models import (
|
||||
AudioSelector,
|
||||
VideoSelector,
|
||||
CaptionSelector,
|
||||
)
|
||||
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId
|
||||
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo
|
||||
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId, StreamProtectionStatus
|
||||
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -233,7 +235,6 @@ def test_defaults(self, base_args):
|
||||
assert processor.live_segment_target_duration_sec == 5
|
||||
assert processor.live_segment_target_duration_tolerance_ms == 100
|
||||
assert processor.start_time_ms == 0
|
||||
assert processor.live_end_segment_tolerance == 10
|
||||
assert processor.post_live is False
|
||||
|
||||
def test_override_defaults(self, base_args):
|
||||
@ -242,11 +243,65 @@ def test_override_defaults(self, base_args):
|
||||
live_segment_target_duration_sec=8,
|
||||
live_segment_target_duration_tolerance_ms=42,
|
||||
start_time_ms=123,
|
||||
live_end_segment_tolerance=3,
|
||||
post_live=True,
|
||||
)
|
||||
assert processor.live_segment_target_duration_sec == 8
|
||||
assert processor.live_segment_target_duration_tolerance_ms == 42
|
||||
assert processor.start_time_ms == 123
|
||||
assert processor.live_end_segment_tolerance == 3
|
||||
assert processor.post_live is True
|
||||
|
||||
|
||||
class TestStreamProtectionStatusPart:
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'sps,po_token,expected_status',
|
||||
[
|
||||
(StreamProtectionStatus.Status.OK, None, PoTokenStatusSabrPart.PoTokenStatus.NOT_REQUIRED),
|
||||
(StreamProtectionStatus.Status.OK, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.OK),
|
||||
(StreamProtectionStatus.Status.ATTESTATION_PENDING, None, PoTokenStatusSabrPart.PoTokenStatus.PENDING_MISSING),
|
||||
(StreamProtectionStatus.Status.ATTESTATION_PENDING, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.PENDING),
|
||||
(StreamProtectionStatus.Status.ATTESTATION_REQUIRED, None, PoTokenStatusSabrPart.PoTokenStatus.MISSING),
|
||||
(StreamProtectionStatus.Status.ATTESTATION_REQUIRED, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.INVALID),
|
||||
],
|
||||
)
|
||||
def test_stream_protection_status_part(self, base_args, sps, po_token, expected_status):
|
||||
processor = SabrProcessor(**base_args, po_token=po_token)
|
||||
part = StreamProtectionStatus(status=sps)
|
||||
|
||||
result = processor.process_stream_protection_status(part)
|
||||
assert isinstance(result, ProcessStreamProtectionStatusResult)
|
||||
assert isinstance(result.sabr_part, PoTokenStatusSabrPart)
|
||||
assert result.sabr_part.status == expected_status
|
||||
assert processor.stream_protection_status == sps
|
||||
|
||||
def test_no_stream_protection_status(self, logger, base_args):
|
||||
processor = SabrProcessor(**base_args, po_token='valid_token')
|
||||
part = StreamProtectionStatus(status=None)
|
||||
|
||||
result = processor.process_stream_protection_status(part)
|
||||
assert isinstance(result, ProcessStreamProtectionStatusResult)
|
||||
assert result.sabr_part is None
|
||||
assert processor.stream_protection_status is None
|
||||
assert logger.warning.call_count == 1
|
||||
logger.warning.assert_called_with(
|
||||
'Received an unknown StreamProtectionStatus: StreamProtectionStatus(status=None, max_retries=None)',
|
||||
)
|
||||
|
||||
|
||||
class TestNextRequestPolicyPart:
|
||||
def test_next_request_policy_part(self, logger, base_args):
|
||||
processor = SabrProcessor(**base_args)
|
||||
next_request_policy = NextRequestPolicy(target_audio_readahead_ms=123)
|
||||
|
||||
result = processor.process_next_request_policy(next_request_policy)
|
||||
assert result is None
|
||||
assert processor.next_request_policy is next_request_policy
|
||||
|
||||
# Verify it is overridden in the processor on another call
|
||||
next_request_policy = NextRequestPolicy(target_video_readahead_ms=456)
|
||||
result = processor.process_next_request_policy(next_request_policy)
|
||||
assert result is None
|
||||
assert processor.next_request_policy is next_request_policy
|
||||
|
||||
# Check logger trace was called
|
||||
assert logger.trace.call_count == 2
|
||||
|
@ -97,8 +97,6 @@ def __init__(
|
||||
live_segment_target_duration_tolerance_ms: int | None = None,
|
||||
start_time_ms: int | None = None,
|
||||
po_token: str | None = None,
|
||||
live_end_wait_sec: int | None = None,
|
||||
live_end_segment_tolerance: int | None = None,
|
||||
post_live: bool = False,
|
||||
video_id: str | None = None,
|
||||
):
|
||||
@ -119,9 +117,6 @@ def __init__(
|
||||
if self.start_time_ms < 0:
|
||||
raise ValueError('start_time_ms must be greater than or equal to 0')
|
||||
|
||||
# TODO: move to SabrStream
|
||||
self.live_end_wait_sec = live_end_wait_sec or max(10, 3 * self.live_segment_target_duration_sec)
|
||||
self.live_end_segment_tolerance = live_end_segment_tolerance or 10
|
||||
self.post_live = post_live
|
||||
self._is_live = False
|
||||
self.video_id = video_id
|
||||
@ -498,6 +493,7 @@ def process_stream_protection_status(self, stream_protection_status: StreamProte
|
||||
else PoTokenStatusSabrPart.PoTokenStatus.MISSING
|
||||
)
|
||||
else:
|
||||
self.logger.warning(f'Received an unknown StreamProtectionStatus: {stream_protection_status}')
|
||||
result_status = None
|
||||
|
||||
sabr_part = PoTokenStatusSabrPart(status=result_status) if result_status is not None else None
|
||||
|
@ -141,8 +141,6 @@ def __init__(
|
||||
live_segment_target_duration_tolerance_ms=live_segment_target_duration_tolerance_ms,
|
||||
start_time_ms=start_time_ms,
|
||||
po_token=po_token,
|
||||
live_end_wait_sec=live_end_wait_sec,
|
||||
live_end_segment_tolerance=live_end_segment_tolerance,
|
||||
post_live=post_live,
|
||||
video_id=video_id,
|
||||
)
|
||||
@ -151,6 +149,8 @@ def __init__(
|
||||
self.pot_retries = pot_retries or 5
|
||||
self.host_fallback_threshold = host_fallback_threshold or 8
|
||||
self.max_empty_requests = max_empty_requests or 3
|
||||
self.live_end_wait_sec = live_end_wait_sec or max(10, self.max_empty_requests * self.processor.live_segment_target_duration_sec)
|
||||
self.live_end_segment_tolerance = live_end_segment_tolerance or 10
|
||||
self.expiry_threshold_sec = expiry_threshold_sec or 60 # 60 seconds
|
||||
if self.expiry_threshold_sec <= 0:
|
||||
raise ValueError('expiry_threshold_sec must be greater than 0')
|
||||
@ -346,7 +346,7 @@ def _validate_response_integrity(self):
|
||||
self.processor.live_metadata
|
||||
# TODO: generalize
|
||||
and self.processor.client_abr_state.player_time_ms >= (
|
||||
self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.processor.live_end_segment_tolerance))
|
||||
self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.live_end_segment_tolerance))
|
||||
):
|
||||
# Only log a warning if we are not near the head of a stream
|
||||
self.logger.debug(msg)
|
||||
@ -469,9 +469,9 @@ def _prepare_next_playback_time(self):
|
||||
if (
|
||||
self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
|
||||
and not self._is_retry
|
||||
and self._no_new_segments_tracker.timestamp_started < time.time() + self.processor.live_end_wait_sec
|
||||
and self._no_new_segments_tracker.timestamp_started < time.time() + self.live_end_wait_sec
|
||||
):
|
||||
self.logger.debug(f'No new segments received for at least {self.processor.live_end_wait_sec} seconds, assuming end of live stream')
|
||||
self.logger.debug(f'No new segments received for at least {self.live_end_wait_sec} seconds, assuming end of live stream')
|
||||
self._consumed = True
|
||||
else:
|
||||
wait_seconds = max(next_request_backoff_ms / 1000, self.processor.live_segment_target_duration_sec)
|
||||
@ -499,7 +499,7 @@ def _prepare_next_playback_time(self):
|
||||
initialized_format.consumed_ranges,
|
||||
key=lambda cr: cr.end_sequence_number,
|
||||
)[-1].end_sequence_number
|
||||
>= initialized_format.total_segments - self.processor.live_end_segment_tolerance
|
||||
>= initialized_format.total_segments - self.live_end_segment_tolerance
|
||||
)
|
||||
for initialized_format in enabled_initialized_formats
|
||||
)
|
||||
@ -515,16 +515,16 @@ def _prepare_next_playback_time(self):
|
||||
# Because of this, we should also check the player time against
|
||||
# the head segment time using the estimated segment duration.
|
||||
# xxx: consider also taking into account the max seekable timestamp
|
||||
request_player_time >= self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.processor.live_end_segment_tolerance)
|
||||
request_player_time >= self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.live_end_segment_tolerance)
|
||||
)
|
||||
)
|
||||
):
|
||||
if (
|
||||
not self._is_retry # allow us to sleep on a retry
|
||||
and self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
|
||||
and self._no_new_segments_tracker.timestamp_started < time.time() + self.processor.live_end_wait_sec
|
||||
and self._no_new_segments_tracker.timestamp_started < time.time() + self.live_end_wait_sec
|
||||
):
|
||||
self.logger.debug(f'No new segments received for at least {self.processor.live_end_wait_sec} seconds; assuming end of live stream')
|
||||
self.logger.debug(f'No new segments received for at least {self.live_end_wait_sec} seconds; assuming end of live stream')
|
||||
self._consumed = True
|
||||
elif self._no_new_segments_tracker.consecutive_requests >= 1:
|
||||
# Sometimes we can't get the head segment - rather tend to sit behind the head segment for the duration of the livestream
|
||||
|
Loading…
Reference in New Issue
Block a user