mirror of
https://github.com/yt-dlp/yt-dlp.git
synced 2025-06-28 01:18:30 +00:00
[test] add more tests and minor refactoring
This commit is contained in:
parent
0fab874507
commit
53a2cd4732
@ -1,14 +1,16 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor
|
from yt_dlp.extractor.youtube._streaming.sabr.part import PoTokenStatusSabrPart
|
||||||
|
|
||||||
|
from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor, ProcessStreamProtectionStatusResult
|
||||||
from yt_dlp.extractor.youtube._streaming.sabr.models import (
|
from yt_dlp.extractor.youtube._streaming.sabr.models import (
|
||||||
AudioSelector,
|
AudioSelector,
|
||||||
VideoSelector,
|
VideoSelector,
|
||||||
CaptionSelector,
|
CaptionSelector,
|
||||||
)
|
)
|
||||||
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId
|
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId, StreamProtectionStatus
|
||||||
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo
|
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -233,7 +235,6 @@ def test_defaults(self, base_args):
|
|||||||
assert processor.live_segment_target_duration_sec == 5
|
assert processor.live_segment_target_duration_sec == 5
|
||||||
assert processor.live_segment_target_duration_tolerance_ms == 100
|
assert processor.live_segment_target_duration_tolerance_ms == 100
|
||||||
assert processor.start_time_ms == 0
|
assert processor.start_time_ms == 0
|
||||||
assert processor.live_end_segment_tolerance == 10
|
|
||||||
assert processor.post_live is False
|
assert processor.post_live is False
|
||||||
|
|
||||||
def test_override_defaults(self, base_args):
|
def test_override_defaults(self, base_args):
|
||||||
@ -242,11 +243,65 @@ def test_override_defaults(self, base_args):
|
|||||||
live_segment_target_duration_sec=8,
|
live_segment_target_duration_sec=8,
|
||||||
live_segment_target_duration_tolerance_ms=42,
|
live_segment_target_duration_tolerance_ms=42,
|
||||||
start_time_ms=123,
|
start_time_ms=123,
|
||||||
live_end_segment_tolerance=3,
|
|
||||||
post_live=True,
|
post_live=True,
|
||||||
)
|
)
|
||||||
assert processor.live_segment_target_duration_sec == 8
|
assert processor.live_segment_target_duration_sec == 8
|
||||||
assert processor.live_segment_target_duration_tolerance_ms == 42
|
assert processor.live_segment_target_duration_tolerance_ms == 42
|
||||||
assert processor.start_time_ms == 123
|
assert processor.start_time_ms == 123
|
||||||
assert processor.live_end_segment_tolerance == 3
|
|
||||||
assert processor.post_live is True
|
assert processor.post_live is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamProtectionStatusPart:
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'sps,po_token,expected_status',
|
||||||
|
[
|
||||||
|
(StreamProtectionStatus.Status.OK, None, PoTokenStatusSabrPart.PoTokenStatus.NOT_REQUIRED),
|
||||||
|
(StreamProtectionStatus.Status.OK, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.OK),
|
||||||
|
(StreamProtectionStatus.Status.ATTESTATION_PENDING, None, PoTokenStatusSabrPart.PoTokenStatus.PENDING_MISSING),
|
||||||
|
(StreamProtectionStatus.Status.ATTESTATION_PENDING, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.PENDING),
|
||||||
|
(StreamProtectionStatus.Status.ATTESTATION_REQUIRED, None, PoTokenStatusSabrPart.PoTokenStatus.MISSING),
|
||||||
|
(StreamProtectionStatus.Status.ATTESTATION_REQUIRED, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.INVALID),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_stream_protection_status_part(self, base_args, sps, po_token, expected_status):
|
||||||
|
processor = SabrProcessor(**base_args, po_token=po_token)
|
||||||
|
part = StreamProtectionStatus(status=sps)
|
||||||
|
|
||||||
|
result = processor.process_stream_protection_status(part)
|
||||||
|
assert isinstance(result, ProcessStreamProtectionStatusResult)
|
||||||
|
assert isinstance(result.sabr_part, PoTokenStatusSabrPart)
|
||||||
|
assert result.sabr_part.status == expected_status
|
||||||
|
assert processor.stream_protection_status == sps
|
||||||
|
|
||||||
|
def test_no_stream_protection_status(self, logger, base_args):
|
||||||
|
processor = SabrProcessor(**base_args, po_token='valid_token')
|
||||||
|
part = StreamProtectionStatus(status=None)
|
||||||
|
|
||||||
|
result = processor.process_stream_protection_status(part)
|
||||||
|
assert isinstance(result, ProcessStreamProtectionStatusResult)
|
||||||
|
assert result.sabr_part is None
|
||||||
|
assert processor.stream_protection_status is None
|
||||||
|
assert logger.warning.call_count == 1
|
||||||
|
logger.warning.assert_called_with(
|
||||||
|
'Received an unknown StreamProtectionStatus: StreamProtectionStatus(status=None, max_retries=None)',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNextRequestPolicyPart:
|
||||||
|
def test_next_request_policy_part(self, logger, base_args):
|
||||||
|
processor = SabrProcessor(**base_args)
|
||||||
|
next_request_policy = NextRequestPolicy(target_audio_readahead_ms=123)
|
||||||
|
|
||||||
|
result = processor.process_next_request_policy(next_request_policy)
|
||||||
|
assert result is None
|
||||||
|
assert processor.next_request_policy is next_request_policy
|
||||||
|
|
||||||
|
# Verify it is overridden in the processor on another call
|
||||||
|
next_request_policy = NextRequestPolicy(target_video_readahead_ms=456)
|
||||||
|
result = processor.process_next_request_policy(next_request_policy)
|
||||||
|
assert result is None
|
||||||
|
assert processor.next_request_policy is next_request_policy
|
||||||
|
|
||||||
|
# Check logger trace was called
|
||||||
|
assert logger.trace.call_count == 2
|
||||||
|
@ -97,8 +97,6 @@ def __init__(
|
|||||||
live_segment_target_duration_tolerance_ms: int | None = None,
|
live_segment_target_duration_tolerance_ms: int | None = None,
|
||||||
start_time_ms: int | None = None,
|
start_time_ms: int | None = None,
|
||||||
po_token: str | None = None,
|
po_token: str | None = None,
|
||||||
live_end_wait_sec: int | None = None,
|
|
||||||
live_end_segment_tolerance: int | None = None,
|
|
||||||
post_live: bool = False,
|
post_live: bool = False,
|
||||||
video_id: str | None = None,
|
video_id: str | None = None,
|
||||||
):
|
):
|
||||||
@ -119,9 +117,6 @@ def __init__(
|
|||||||
if self.start_time_ms < 0:
|
if self.start_time_ms < 0:
|
||||||
raise ValueError('start_time_ms must be greater than or equal to 0')
|
raise ValueError('start_time_ms must be greater than or equal to 0')
|
||||||
|
|
||||||
# TODO: move to SabrStream
|
|
||||||
self.live_end_wait_sec = live_end_wait_sec or max(10, 3 * self.live_segment_target_duration_sec)
|
|
||||||
self.live_end_segment_tolerance = live_end_segment_tolerance or 10
|
|
||||||
self.post_live = post_live
|
self.post_live = post_live
|
||||||
self._is_live = False
|
self._is_live = False
|
||||||
self.video_id = video_id
|
self.video_id = video_id
|
||||||
@ -498,6 +493,7 @@ def process_stream_protection_status(self, stream_protection_status: StreamProte
|
|||||||
else PoTokenStatusSabrPart.PoTokenStatus.MISSING
|
else PoTokenStatusSabrPart.PoTokenStatus.MISSING
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
self.logger.warning(f'Received an unknown StreamProtectionStatus: {stream_protection_status}')
|
||||||
result_status = None
|
result_status = None
|
||||||
|
|
||||||
sabr_part = PoTokenStatusSabrPart(status=result_status) if result_status is not None else None
|
sabr_part = PoTokenStatusSabrPart(status=result_status) if result_status is not None else None
|
||||||
|
@ -141,8 +141,6 @@ def __init__(
|
|||||||
live_segment_target_duration_tolerance_ms=live_segment_target_duration_tolerance_ms,
|
live_segment_target_duration_tolerance_ms=live_segment_target_duration_tolerance_ms,
|
||||||
start_time_ms=start_time_ms,
|
start_time_ms=start_time_ms,
|
||||||
po_token=po_token,
|
po_token=po_token,
|
||||||
live_end_wait_sec=live_end_wait_sec,
|
|
||||||
live_end_segment_tolerance=live_end_segment_tolerance,
|
|
||||||
post_live=post_live,
|
post_live=post_live,
|
||||||
video_id=video_id,
|
video_id=video_id,
|
||||||
)
|
)
|
||||||
@ -151,6 +149,8 @@ def __init__(
|
|||||||
self.pot_retries = pot_retries or 5
|
self.pot_retries = pot_retries or 5
|
||||||
self.host_fallback_threshold = host_fallback_threshold or 8
|
self.host_fallback_threshold = host_fallback_threshold or 8
|
||||||
self.max_empty_requests = max_empty_requests or 3
|
self.max_empty_requests = max_empty_requests or 3
|
||||||
|
self.live_end_wait_sec = live_end_wait_sec or max(10, self.max_empty_requests * self.processor.live_segment_target_duration_sec)
|
||||||
|
self.live_end_segment_tolerance = live_end_segment_tolerance or 10
|
||||||
self.expiry_threshold_sec = expiry_threshold_sec or 60 # 60 seconds
|
self.expiry_threshold_sec = expiry_threshold_sec or 60 # 60 seconds
|
||||||
if self.expiry_threshold_sec <= 0:
|
if self.expiry_threshold_sec <= 0:
|
||||||
raise ValueError('expiry_threshold_sec must be greater than 0')
|
raise ValueError('expiry_threshold_sec must be greater than 0')
|
||||||
@ -346,7 +346,7 @@ def _validate_response_integrity(self):
|
|||||||
self.processor.live_metadata
|
self.processor.live_metadata
|
||||||
# TODO: generalize
|
# TODO: generalize
|
||||||
and self.processor.client_abr_state.player_time_ms >= (
|
and self.processor.client_abr_state.player_time_ms >= (
|
||||||
self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.processor.live_end_segment_tolerance))
|
self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.live_end_segment_tolerance))
|
||||||
):
|
):
|
||||||
# Only log a warning if we are not near the head of a stream
|
# Only log a warning if we are not near the head of a stream
|
||||||
self.logger.debug(msg)
|
self.logger.debug(msg)
|
||||||
@ -469,9 +469,9 @@ def _prepare_next_playback_time(self):
|
|||||||
if (
|
if (
|
||||||
self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
|
self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
|
||||||
and not self._is_retry
|
and not self._is_retry
|
||||||
and self._no_new_segments_tracker.timestamp_started < time.time() + self.processor.live_end_wait_sec
|
and self._no_new_segments_tracker.timestamp_started < time.time() + self.live_end_wait_sec
|
||||||
):
|
):
|
||||||
self.logger.debug(f'No new segments received for at least {self.processor.live_end_wait_sec} seconds, assuming end of live stream')
|
self.logger.debug(f'No new segments received for at least {self.live_end_wait_sec} seconds, assuming end of live stream')
|
||||||
self._consumed = True
|
self._consumed = True
|
||||||
else:
|
else:
|
||||||
wait_seconds = max(next_request_backoff_ms / 1000, self.processor.live_segment_target_duration_sec)
|
wait_seconds = max(next_request_backoff_ms / 1000, self.processor.live_segment_target_duration_sec)
|
||||||
@ -499,7 +499,7 @@ def _prepare_next_playback_time(self):
|
|||||||
initialized_format.consumed_ranges,
|
initialized_format.consumed_ranges,
|
||||||
key=lambda cr: cr.end_sequence_number,
|
key=lambda cr: cr.end_sequence_number,
|
||||||
)[-1].end_sequence_number
|
)[-1].end_sequence_number
|
||||||
>= initialized_format.total_segments - self.processor.live_end_segment_tolerance
|
>= initialized_format.total_segments - self.live_end_segment_tolerance
|
||||||
)
|
)
|
||||||
for initialized_format in enabled_initialized_formats
|
for initialized_format in enabled_initialized_formats
|
||||||
)
|
)
|
||||||
@ -515,16 +515,16 @@ def _prepare_next_playback_time(self):
|
|||||||
# Because of this, we should also check the player time against
|
# Because of this, we should also check the player time against
|
||||||
# the head segment time using the estimated segment duration.
|
# the head segment time using the estimated segment duration.
|
||||||
# xxx: consider also taking into account the max seekable timestamp
|
# xxx: consider also taking into account the max seekable timestamp
|
||||||
request_player_time >= self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.processor.live_end_segment_tolerance)
|
request_player_time >= self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.live_end_segment_tolerance)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
not self._is_retry # allow us to sleep on a retry
|
not self._is_retry # allow us to sleep on a retry
|
||||||
and self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
|
and self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
|
||||||
and self._no_new_segments_tracker.timestamp_started < time.time() + self.processor.live_end_wait_sec
|
and self._no_new_segments_tracker.timestamp_started < time.time() + self.live_end_wait_sec
|
||||||
):
|
):
|
||||||
self.logger.debug(f'No new segments received for at least {self.processor.live_end_wait_sec} seconds; assuming end of live stream')
|
self.logger.debug(f'No new segments received for at least {self.live_end_wait_sec} seconds; assuming end of live stream')
|
||||||
self._consumed = True
|
self._consumed = True
|
||||||
elif self._no_new_segments_tracker.consecutive_requests >= 1:
|
elif self._no_new_segments_tracker.consecutive_requests >= 1:
|
||||||
# Sometimes we can't get the head segment - rather tend to sit behind the head segment for the duration of the livestream
|
# Sometimes we can't get the head segment - rather tend to sit behind the head segment for the duration of the livestream
|
||||||
|
Loading…
Reference in New Issue
Block a user