1
0
mirror of https://github.com/yt-dlp/yt-dlp.git synced 2025-06-28 01:18:30 +00:00
yt-dlp/test/test_sabr/test_processor.py
2025-06-25 08:34:47 +12:00

308 lines
12 KiB
Python

import pytest
from unittest.mock import MagicMock
from yt_dlp.extractor.youtube._streaming.sabr.part import PoTokenStatusSabrPart
from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor, ProcessStreamProtectionStatusResult
from yt_dlp.extractor.youtube._streaming.sabr.models import (
AudioSelector,
VideoSelector,
CaptionSelector,
)
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId, StreamProtectionStatus
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy
@pytest.fixture
def logger():
    """Provide a fresh mock logger; tests inspect the calls recorded on it."""
    mock_logger = MagicMock()
    return mock_logger
@pytest.fixture
def client_info():
    """Provide a default-constructed ``ClientInfo`` for the processor under test."""
    info = ClientInfo()
    return info
@pytest.fixture
def base_args(logger, client_info):
    """Minimal keyword arguments needed to construct a ``SabrProcessor`` in these tests."""
    args = {
        'logger': logger,
        'client_info': client_info,
    }
    # 'dGVzdA==' is the base64 encoding of the placeholder string 'test'.
    args['video_playback_ustreamer_config'] = 'dGVzdA=='
    return args
def make_selector(selector_type, *, discard_media=False, format_ids=None):
    """Build a media selector of the requested type for use in tests.

    :param selector_type: one of ``'audio'``, ``'video'`` or ``'caption'``; also used
        as the selector's ``display_name``.
    :param discard_media: forwarded to the selector's ``discard_media`` field.
    :param format_ids: format ids to select. When ``None`` a single default id is used
        (itag 140 for audio, 248 for video, 386 for caption). An explicitly passed
        empty list is kept as-is instead of being replaced by the default.
    :raises ValueError: if ``selector_type`` is not one of the known types.
    """
    # selector_type -> (selector class, itag of the default format id)
    selectors = {
        'audio': (AudioSelector, 140),
        'video': (VideoSelector, 248),
        'caption': (CaptionSelector, 386),
    }
    entry = selectors.get(selector_type)
    if entry is None:
        raise ValueError(f'Unknown selector_type: {selector_type}')
    selector_cls, default_itag = entry

    # Use an explicit None check so that callers can request an empty selection.
    if format_ids is None:
        format_ids = [FormatId(itag=default_itag)]

    return selector_cls(
        display_name=selector_type,
        format_ids=format_ids,
        discard_media=discard_media,
    )
def selector_factory(selector_type, *, discard_media=False, format_ids=None):
    """Return a zero-argument callable that builds a new selector each time it is called.

    Parametrized tests store these callables rather than selector instances, so each
    test invocation constructs its own selector via ``make_selector``.
    """
    return lambda: make_selector(selector_type, discard_media=discard_media, format_ids=format_ids)
class TestSabrProcessorInitialization:
    """Tests for the state ``SabrProcessor`` derives from its constructor arguments."""

    @staticmethod
    def _build_selection(factory):
        """Call ``factory`` to create a fresh selector, or return None when no factory is given."""
        return factory() if factory else None

    @pytest.mark.parametrize(
        'audio_sel,video_sel,caption_sel,expected_bitfield',
        [
            pytest.param(
                selector_factory('audio'), selector_factory('video'), None, 0,
                id='audio_video',
            ),
            pytest.param(
                selector_factory('audio'),
                selector_factory('video'),
                selector_factory('caption', discard_media=True),
                0,
                id='audio_video_caption_discarded',
            ),
            pytest.param(
                selector_factory('audio'), None, None, 1,
                id='audio_only',
            ),
            pytest.param(
                selector_factory('audio'),
                selector_factory('video', discard_media=True),
                selector_factory('caption', discard_media=True),
                1,
                id='audio_only_manual_discard',
            ),
            pytest.param(
                selector_factory('audio'),
                selector_factory('video'),
                selector_factory('caption'),
                7,
                id='audio_video_caption',
            ),
            pytest.param(
                None, selector_factory('video'), None, 0,
                id='video_only',
            ),
            pytest.param(
                selector_factory('audio', discard_media=True),
                selector_factory('video'),
                selector_factory('caption', discard_media=True),
                0,
                id='video_only_manual_discard',
            ),
            pytest.param(
                None, None, selector_factory('caption'), 7,
                id='caption_only',
            ),
            pytest.param(
                selector_factory('audio', discard_media=True),
                selector_factory('video', discard_media=True),
                selector_factory('caption'),
                7,
                id='caption_only_manual_discard',
            ),
        ],
    )
    def test_client_abr_state_bitfield(
        self, base_args, audio_sel, video_sel, caption_sel, expected_bitfield,
    ):
        """The selector combination determines ``client_abr_state.enabled_track_types_bitfield``."""
        processor = SabrProcessor(
            **base_args,
            audio_selection=self._build_selection(audio_sel),
            video_selection=self._build_selection(video_sel),
            caption_selection=self._build_selection(caption_sel),
        )
        assert processor.client_abr_state.enabled_track_types_bitfield == expected_bitfield

    @pytest.mark.parametrize(
        'audio_sel,video_sel,caption_sel,expected_audio_ids,expected_video_ids,expected_caption_ids',
        [
            pytest.param(
                selector_factory('audio'), selector_factory('video'), None,
                [FormatId(itag=140)], [FormatId(itag=248)], [],
                id='audio_video',
            ),
            pytest.param(
                selector_factory('audio'), None, None,
                [FormatId(itag=140)], [], [],
                id='audio_only',
            ),
            pytest.param(
                None, selector_factory('video'), None,
                [], [FormatId(itag=248)], [],
                id='video_only',
            ),
            pytest.param(
                None, None, selector_factory('caption'),
                [], [], [FormatId(itag=386)],
                id='caption_only',
            ),
            pytest.param(
                selector_factory('audio'), selector_factory('video'), selector_factory('caption'),
                [FormatId(itag=140)], [FormatId(itag=248)], [FormatId(itag=386)],
                id='audio_video_caption',
            ),
            pytest.param(
                selector_factory('audio', format_ids=[FormatId(itag=140), FormatId(itag=141)]),
                selector_factory('video', format_ids=[FormatId(itag=248), FormatId(itag=249)]),
                selector_factory('caption', format_ids=[FormatId(itag=386), FormatId(itag=387)]),
                [FormatId(itag=140), FormatId(itag=141)],
                [FormatId(itag=248), FormatId(itag=249)],
                [FormatId(itag=386), FormatId(itag=387)],
                id='multiple_ids',
            ),
        ],
    )
    def test_selected_format_ids(
        self, base_args, audio_sel, video_sel, caption_sel,
        expected_audio_ids, expected_video_ids, expected_caption_ids,
    ):
        """Each selector's format ids are exposed on the matching ``selected_*_format_ids``."""
        processor = SabrProcessor(
            **base_args,
            audio_selection=self._build_selection(audio_sel),
            video_selection=self._build_selection(video_sel),
            caption_selection=self._build_selection(caption_sel),
        )
        assert processor.selected_audio_format_ids == expected_audio_ids
        assert processor.selected_video_format_ids == expected_video_ids
        assert processor.selected_caption_format_ids == expected_caption_ids

    @pytest.mark.parametrize(
        'start_time_ms,expected',
        [
            (None, 0),
            (0, 0),
            (12345, 12345),
        ],
    )
    def test_start_time_ms_initialization(self, base_args, start_time_ms, expected):
        """``start_time_ms`` (None treated as 0) seeds both the processor and its ABR player time."""
        processor = SabrProcessor(
            **base_args,
            start_time_ms=start_time_ms,
        )
        assert processor.start_time_ms == expected
        assert processor.client_abr_state.player_time_ms == expected

    @pytest.mark.parametrize('invalid_start_time_ms', [-1, -100])
    def test_start_time_ms_invalid(self, base_args, invalid_start_time_ms):
        """Negative start times are rejected with a ValueError."""
        with pytest.raises(ValueError, match='start_time_ms must be greater than or equal to 0'):
            SabrProcessor(
                **base_args,
                audio_selection=make_selector('audio'),
                video_selection=make_selector('video'),
                caption_selection=None,
                start_time_ms=invalid_start_time_ms,
            )

    @pytest.mark.parametrize(
        'duration_sec,tolerance_ms',
        [
            (10, 4999),  # just under half of the target duration
            (10, 0),
        ],
    )
    def test_live_segment_target_duration_tolerance_ms_valid(self, base_args, duration_sec, tolerance_ms):
        """Tolerances below half of the target segment duration are accepted."""
        # Should not raise
        SabrProcessor(
            **base_args,
            live_segment_target_duration_sec=duration_sec,
            live_segment_target_duration_tolerance_ms=tolerance_ms,
        )

    @pytest.mark.parametrize(
        'duration_sec,tolerance_ms',
        [
            (10, 5000),  # exactly half
            (10, 6000),  # more than half
        ],
    )
    def test_live_segment_target_duration_tolerance_ms_validation(self, base_args, duration_sec, tolerance_ms):
        """Tolerances of half the target segment duration or more raise a ValueError."""
        with pytest.raises(ValueError, match='live_segment_target_duration_tolerance_ms must be less than'):
            SabrProcessor(
                **base_args,
                live_segment_target_duration_sec=duration_sec,
                live_segment_target_duration_tolerance_ms=tolerance_ms,
            )

    def test_defaults(self, base_args):
        """Defaults applied when only the required constructor arguments are given."""
        processor = SabrProcessor(**base_args)
        assert processor.live_segment_target_duration_sec == 5
        assert processor.live_segment_target_duration_tolerance_ms == 100
        assert processor.start_time_ms == 0
        assert processor.post_live is False

    def test_override_defaults(self, base_args):
        """Explicit constructor arguments override the defaults."""
        processor = SabrProcessor(
            **base_args,
            live_segment_target_duration_sec=8,
            live_segment_target_duration_tolerance_ms=42,
            start_time_ms=123,
            post_live=True,
        )
        assert processor.live_segment_target_duration_sec == 8
        assert processor.live_segment_target_duration_tolerance_ms == 42
        assert processor.start_time_ms == 123
        assert processor.post_live is True
class TestStreamProtectionStatusPart:
    """Tests for ``SabrProcessor.process_stream_protection_status``."""

    # (status sent by the server, po_token given to the processor, expected PoTokenStatus)
    _STATUS_CASES = [
        (StreamProtectionStatus.Status.OK, None, PoTokenStatusSabrPart.PoTokenStatus.NOT_REQUIRED),
        (StreamProtectionStatus.Status.OK, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.OK),
        (StreamProtectionStatus.Status.ATTESTATION_PENDING, None, PoTokenStatusSabrPart.PoTokenStatus.PENDING_MISSING),
        (StreamProtectionStatus.Status.ATTESTATION_PENDING, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.PENDING),
        (StreamProtectionStatus.Status.ATTESTATION_REQUIRED, None, PoTokenStatusSabrPart.PoTokenStatus.MISSING),
        (StreamProtectionStatus.Status.ATTESTATION_REQUIRED, 'valid_token', PoTokenStatusSabrPart.PoTokenStatus.INVALID),
    ]

    @pytest.mark.parametrize('sps,po_token,expected_status', _STATUS_CASES)
    def test_stream_protection_status_part(self, base_args, sps, po_token, expected_status):
        """Each known status maps to a PoTokenStatus part and is stored on the processor."""
        processor = SabrProcessor(**base_args, po_token=po_token)
        outcome = processor.process_stream_protection_status(StreamProtectionStatus(status=sps))

        assert isinstance(outcome, ProcessStreamProtectionStatusResult)
        emitted = outcome.sabr_part
        assert isinstance(emitted, PoTokenStatusSabrPart)
        assert emitted.status == expected_status
        assert processor.stream_protection_status == sps

    def test_no_stream_protection_status(self, logger, base_args):
        """A missing status yields no part, stores nothing, and logs exactly one warning."""
        processor = SabrProcessor(**base_args, po_token='valid_token')
        outcome = processor.process_stream_protection_status(StreamProtectionStatus(status=None))

        assert isinstance(outcome, ProcessStreamProtectionStatusResult)
        assert outcome.sabr_part is None
        assert processor.stream_protection_status is None

        warn = logger.warning
        assert warn.call_count == 1
        warn.assert_called_with(
            'Received an unknown StreamProtectionStatus: StreamProtectionStatus(status=None, max_retries=None)',
        )
class TestNextRequestPolicyPart:
    """Tests for ``SabrProcessor.process_next_request_policy``."""

    def test_next_request_policy_part(self, logger, base_args):
        """Each policy replaces the previous one, returns None, and emits one trace log."""
        processor = SabrProcessor(**base_args)

        # The second policy must override the first on the processor.
        for policy_kwargs in ({'target_audio_readahead_ms': 123}, {'target_video_readahead_ms': 456}):
            policy = NextRequestPolicy(**policy_kwargs)
            assert processor.process_next_request_policy(policy) is None
            assert processor.next_request_policy is policy

        # One trace call per processed policy.
        assert logger.trace.call_count == 2