1
0
mirror of https://github.com/yt-dlp/yt-dlp.git synced 2025-07-18 03:08:31 +00:00

[test] add more sabr processor tests

This commit is contained in:
coletdjnz 2025-07-01 07:45:58 +12:00
parent 8a0917584c
commit 75fb53bccc
No known key found for this signature in database
GPG Key ID: 91984263BB39894A
2 changed files with 433 additions and 6 deletions

View File

@ -4,18 +4,33 @@
from unittest.mock import MagicMock from unittest.mock import MagicMock
from yt_dlp.extractor.youtube._streaming.sabr.exceptions import SabrStreamError from yt_dlp.extractor.youtube._streaming.sabr.exceptions import SabrStreamError
from yt_dlp.extractor.youtube._streaming.sabr.part import PoTokenStatusSabrPart, FormatInitializedSabrPart from yt_dlp.extractor.youtube._streaming.sabr.part import (
PoTokenStatusSabrPart,
FormatInitializedSabrPart,
MediaSeekSabrPart,
)
from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor, ProcessStreamProtectionStatusResult, \ from yt_dlp.extractor.youtube._streaming.sabr.processor import (
ProcessFormatInitializationMetadataResult SabrProcessor,
ProcessStreamProtectionStatusResult,
ProcessFormatInitializationMetadataResult,
ProcessLiveMetadataResult, ProcessSabrSeekResult,
)
from yt_dlp.extractor.youtube._streaming.sabr.models import ( from yt_dlp.extractor.youtube._streaming.sabr.models import (
AudioSelector, AudioSelector,
VideoSelector, VideoSelector,
CaptionSelector, CaptionSelector,
InitializedFormat, InitializedFormat,
Segment,
)
from yt_dlp.extractor.youtube._proto.videostreaming import (
FormatId,
StreamProtectionStatus,
FormatInitializationMetadata,
LiveMetadata,
SabrContextUpdate,
SabrContextSendingPolicy, SabrSeek,
) )
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId, StreamProtectionStatus, \
FormatInitializationMetadata, LiveMetadata
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy
@ -805,3 +820,415 @@ def test_total_segment_count_live_metadata(self, logger, base_args):
processor.process_format_initialization_metadata(video_format_init_metadata_part) processor.process_format_initialization_metadata(video_format_init_metadata_part)
assert str(video_format_id) in processor.initialized_formats assert str(video_format_id) in processor.initialized_formats
assert processor.initialized_formats[str(video_format_id)].total_segments == 9 assert processor.initialized_formats[str(video_format_id)].total_segments == 9
class TestLiveMetadata:
def test_live_metadata_initialization(self, base_args):
processor = SabrProcessor(**base_args)
assert processor.live_metadata is None
def test_live_metadata_update(self, base_args):
processor = SabrProcessor(**base_args)
live_metadata = LiveMetadata(head_sequence_number=10)
result = processor.process_live_metadata(live_metadata)
assert isinstance(result, ProcessLiveMetadataResult)
assert len(result.seek_sabr_parts) == 0
assert processor.live_metadata is live_metadata
# Ensure new live_metadata replaces the old one
live_metadata = dataclasses.replace(live_metadata, head_sequence_number=20)
result = processor.process_live_metadata(live_metadata)
assert isinstance(result, ProcessLiveMetadataResult)
assert len(result.seek_sabr_parts) == 0
assert processor.live_metadata is live_metadata
def test_live_metadata_no_head_sequence_time_ms(self, base_args):
processor = SabrProcessor(**base_args)
live_metadata = LiveMetadata(head_sequence_number=10, head_sequence_time_ms=None)
processor.process_live_metadata(live_metadata)
assert processor.live_metadata is live_metadata
assert processor.total_duration_ms is None
def test_live_metadata_with_head_sequence_time_ms(self, base_args):
processor = SabrProcessor(**base_args)
live_metadata = LiveMetadata(head_sequence_number=10, head_sequence_time_ms=5000)
processor.process_live_metadata(live_metadata)
assert processor.live_metadata is live_metadata
assert processor.total_duration_ms == 5000
def test_update_izf_total_segments(self, base_args):
live_metadata = LiveMetadata(head_sequence_number=10)
audio_selector = make_selector('audio')
video_selector = make_selector('video')
processor = SabrProcessor(
**base_args,
audio_selection=audio_selector,
video_selection=video_selector,
video_id=example_video_id)
# Initialize both audio and video formats
audio_format_id = audio_selector.format_ids[0]
video_format_id = video_selector.format_ids[0]
audio_format_init_metadata = FormatInitializationMetadata(
video_id=example_video_id,
format_id=audio_format_id,
mime_type='audio/mp4',
)
video_format_init_metadata = FormatInitializationMetadata(
video_id=example_video_id,
format_id=video_format_id,
mime_type='video/mp4',
)
processor.process_format_initialization_metadata(audio_format_init_metadata)
processor.process_format_initialization_metadata(video_format_init_metadata)
assert len(processor.initialized_formats) == 2
# Process live metadata
processor.process_live_metadata(live_metadata)
assert processor.live_metadata is live_metadata
# Check that total_segments is updated in both formats
assert processor.initialized_formats[str(audio_format_id)].total_segments == 10
assert processor.initialized_formats[str(video_format_id)].total_segments == 10
def test_min_seekable_time_ms_less_than_player_time_ms(self, base_args):
# If min_seekable_time_ms is greater or equal to player time, there should not be a seek
live_metadata = LiveMetadata(
head_sequence_number=10,
head_sequence_time_ms=10000,
min_seekable_time_ticks=50000,
min_seekable_timescale=10000)
audio_selector = make_selector('audio')
video_selector = make_selector('video')
processor = SabrProcessor(
**base_args,
audio_selection=audio_selector,
video_selection=video_selector,
video_id=example_video_id,
start_time_ms=5001)
assert processor.client_abr_state.player_time_ms == 5001
# Initialize both audio and video formats
audio_format_id = audio_selector.format_ids[0]
video_format_id = video_selector.format_ids[0]
audio_format_init_metadata = FormatInitializationMetadata(
video_id=example_video_id,
format_id=audio_format_id,
mime_type='audio/mp4',
)
video_format_init_metadata = FormatInitializationMetadata(
video_id=example_video_id,
format_id=video_format_id,
mime_type='video/mp4',
)
processor.process_format_initialization_metadata(audio_format_init_metadata)
processor.process_format_initialization_metadata(video_format_init_metadata)
assert len(processor.initialized_formats) == 2
# Process live metadata
result = processor.process_live_metadata(live_metadata)
assert isinstance(result, ProcessLiveMetadataResult)
assert len(result.seek_sabr_parts) == 0
assert processor.live_metadata is live_metadata
assert processor.client_abr_state.player_time_ms == 5001
def test_min_seekable_time_ms_greater_than_player_time_ms(self, base_args, logger):
# If min_seekable_time_ms is less than player time, there should be a seek
live_metadata = LiveMetadata(
head_sequence_number=10,
head_sequence_time_ms=10000,
min_seekable_time_ticks=50000,
min_seekable_timescale=10000,
)
audio_selector = make_selector('audio')
video_selector = make_selector('video')
processor = SabrProcessor(
**base_args,
audio_selection=audio_selector,
video_selection=video_selector,
video_id=example_video_id,
start_time_ms=4999)
assert processor.client_abr_state.player_time_ms == 4999
# Initialize both audio and video formats
audio_format_id = audio_selector.format_ids[0]
video_format_id = video_selector.format_ids[0]
audio_format_init_metadata = FormatInitializationMetadata(
video_id=example_video_id,
format_id=audio_format_id,
mime_type='audio/mp4',
)
video_format_init_metadata = FormatInitializationMetadata(
video_id=example_video_id,
format_id=video_format_id,
mime_type='video/mp4',
)
processor.process_format_initialization_metadata(audio_format_init_metadata)
processor.process_format_initialization_metadata(video_format_init_metadata)
assert len(processor.initialized_formats) == 2
# Add a dummy previous segment to each format - this should be cleared on seek
for izf in processor.initialized_formats.values():
izf.current_segment = Segment(
format_id=izf.format_id,
)
# Process live metadata
result = processor.process_live_metadata(live_metadata)
assert isinstance(result, ProcessLiveMetadataResult)
assert len(result.seek_sabr_parts) == 2
assert processor.live_metadata is live_metadata
assert processor.client_abr_state.player_time_ms == 5000
for seek_part in result.seek_sabr_parts:
assert isinstance(seek_part, MediaSeekSabrPart)
assert seek_part.format_id in (audio_format_id, video_format_id)
assert seek_part.format_selector in (audio_selector, video_selector)
# Current segment should be cleared to indicate a seek
for izf in processor.initialized_formats.values():
assert izf.current_segment is None
logger.debug.assert_called_with('Player time 4999 is less than min seekable time 5000, simulating server seek')
class TestSabrContextUpdate:
def test_initialization(self, base_args):
processor = SabrProcessor(**base_args)
assert len(processor.sabr_context_updates) == 0
assert len(processor.sabr_contexts_to_send) == 0
def test_invalid_sabr_context_update(self, logger, base_args):
processor = SabrProcessor(**base_args)
invalid_update = SabrContextUpdate()
processor.process_sabr_context_update(invalid_update)
assert len(processor.sabr_context_updates) == 0
assert len(processor.sabr_contexts_to_send) == 0
logger.warning.assert_called_with('Received an invalid SabrContextUpdate, ignoring')
def test_valid_sabr_context_update(self, logger, base_args):
processor = SabrProcessor(**base_args)
valid_update = SabrContextUpdate(
type=5,
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS,
value=b'{"key": "value"}',
send_by_default=True,
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE,
)
processor.process_sabr_context_update(valid_update)
assert len(processor.sabr_context_updates) == 1
assert processor.sabr_context_updates[valid_update.type] == valid_update
assert len(processor.sabr_contexts_to_send) == 1
assert valid_update.type in processor.sabr_contexts_to_send
logger.debug.assert_called_with(f'Registered SabrContextUpdate {valid_update}')
def test_not_send_by_default(self, logger, base_args):
processor = SabrProcessor(**base_args)
valid_update = SabrContextUpdate(
type=5,
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS,
value=b'{"key": "value"}',
send_by_default=False,
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE,
)
processor.process_sabr_context_update(valid_update)
assert len(processor.sabr_context_updates) == 1
assert processor.sabr_context_updates[valid_update.type] == valid_update
assert len(processor.sabr_contexts_to_send) == 0
def test_write_policy_overwrite(self, logger, base_args):
processor = SabrProcessor(**base_args)
first_ctx_update = SabrContextUpdate(
type=5,
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS,
value=b'{"key": "value"}',
send_by_default=True,
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE,
)
processor.process_sabr_context_update(first_ctx_update)
assert len(processor.sabr_context_updates) == 1
assert processor.sabr_context_updates[first_ctx_update.type] == first_ctx_update
assert len(processor.sabr_contexts_to_send) == 1
assert first_ctx_update.type in processor.sabr_contexts_to_send
second_ctx_update = SabrContextUpdate(
type=5,
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS,
value=b'{"key": "new_value"}',
send_by_default=True,
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE,
)
processor.process_sabr_context_update(second_ctx_update)
assert len(processor.sabr_context_updates) == 1
assert processor.sabr_context_updates[second_ctx_update.type] == second_ctx_update
assert len(processor.sabr_contexts_to_send) == 1
assert second_ctx_update.type in processor.sabr_contexts_to_send
def test_write_policy_keep_existing(self, logger, base_args):
processor = SabrProcessor(**base_args)
first_ctx_update = SabrContextUpdate(
type=5,
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS,
value=b'{"key": "value"}',
send_by_default=True,
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_KEEP_EXISTING,
)
processor.process_sabr_context_update(first_ctx_update)
assert len(processor.sabr_context_updates) == 1
assert processor.sabr_context_updates[first_ctx_update.type] == first_ctx_update
assert len(processor.sabr_contexts_to_send) == 1
assert first_ctx_update.type in processor.sabr_contexts_to_send
second_ctx_update = SabrContextUpdate(
type=5,
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS,
value=b'{"key": "new_value"}',
send_by_default=True,
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_KEEP_EXISTING,
)
processor.process_sabr_context_update(second_ctx_update)
assert len(processor.sabr_context_updates) == 1
assert processor.sabr_context_updates[first_ctx_update.type] == first_ctx_update
assert len(processor.sabr_contexts_to_send) == 1
assert first_ctx_update.type in processor.sabr_contexts_to_send
logger.debug.assert_called_with(
'Received a SABR Context Update with write_policy=KEEP_EXISTING'
'matching an existing SABR Context Update. Ignoring update')
class TestSabrContextUpdateSendingPolicy:
def test_set_sabr_context_update_sending_policy(self, base_args, logger):
processor = SabrProcessor(**base_args)
processor.process_sabr_context_update(SabrContextUpdate(
type=3,
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_PLAYBACK,
value=b'{"key": "value"}',
send_by_default=True,
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE,
))
processor.process_sabr_context_update(SabrContextUpdate(
type=4,
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_REQUEST,
value=b'{"key": "value"}',
send_by_default=True,
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_KEEP_EXISTING,
))
processor.process_sabr_context_update(SabrContextUpdate(
type=5,
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS,
value=b'{"key": "value"}',
send_by_default=False,
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE,
))
assert len(processor.sabr_context_updates) == 3
assert len(processor.sabr_contexts_to_send) == 2
assert 3 in processor.sabr_contexts_to_send
assert 4 in processor.sabr_contexts_to_send
assert 5 not in processor.sabr_contexts_to_send
# Sending policy should update what contexts are sent
processor.process_sabr_context_sending_policy(
SabrContextSendingPolicy(
start_policy=[5, 6],
stop_policy=[3, 0],
discard_policy=[4, 7]))
assert len(processor.sabr_context_updates) == 2
assert len(processor.sabr_contexts_to_send) == 3
assert 5 in processor.sabr_contexts_to_send
assert 4 in processor.sabr_contexts_to_send # discarding does not remove from contexts to send
assert 6 in processor.sabr_contexts_to_send
assert all(n not in processor.sabr_contexts_to_send for n in [3, 0, 7])
assert 4 not in processor.sabr_context_updates
class TestSabrSeek:
def test_invalid_sabr_seek(self, logger, base_args):
processor = SabrProcessor(**base_args)
invalid_seek = SabrSeek(seek_time_ticks=100, timescale=None)
with pytest.raises(SabrStreamError, match='Server sent a SabrSeek part that is missing required seek data'):
processor.process_sabr_seek(invalid_seek)
assert processor.client_abr_state.player_time_ms == 0
def test_sabr_seek(self, logger, base_args):
audio_selector = make_selector('audio')
video_selector = make_selector('video')
audio_format_id = audio_selector.format_ids[0]
video_format_id = video_selector.format_ids[0]
processor = SabrProcessor(
**base_args,
audio_selection=audio_selector,
video_selection=video_selector,
video_id=example_video_id)
assert processor.client_abr_state.player_time_ms == 0
audio_format_init_metadata = FormatInitializationMetadata(
video_id=example_video_id,
format_id=audio_format_id,
mime_type='audio/mp4',
)
video_format_init_metadata = FormatInitializationMetadata(
video_id=example_video_id,
format_id=video_format_id,
mime_type='video/mp4',
)
processor.process_format_initialization_metadata(audio_format_init_metadata)
processor.process_format_initialization_metadata(video_format_init_metadata)
assert len(processor.initialized_formats) == 2
# Add a dummy previous segment to each format - this should be cleared on seek
for izf in processor.initialized_formats.values():
izf.current_segment = Segment(
format_id=izf.format_id,
)
sabr_seek = SabrSeek(
seek_time_ticks=56000,
timescale=10000,
)
result = processor.process_sabr_seek(sabr_seek)
assert isinstance(result, ProcessSabrSeekResult)
assert len(result.seek_sabr_parts) == 2
assert processor.client_abr_state.player_time_ms == 5600
for seek_part in result.seek_sabr_parts:
assert isinstance(seek_part, MediaSeekSabrPart)
assert seek_part.format_id in (audio_format_id, video_format_id)
assert seek_part.format_selector in (audio_selector, video_selector)
# Current segment should be cleared to indicate a seek
for izf in processor.initialized_formats.values():
assert izf.current_segment is None
logger.debug.assert_called_with('Seeking to 5600ms')

View File

@ -614,7 +614,7 @@ def process_sabr_context_update(self, sabr_ctx_update: SabrContextUpdate):
def process_sabr_context_sending_policy(self, sabr_ctx_sending_policy: SabrContextSendingPolicy): def process_sabr_context_sending_policy(self, sabr_ctx_sending_policy: SabrContextSendingPolicy):
for start_type in sabr_ctx_sending_policy.start_policy: for start_type in sabr_ctx_sending_policy.start_policy:
if start_type not in self.sabr_context_updates: if start_type not in self.sabr_contexts_to_send:
self.logger.debug(f'Server requested to enable SABR Context Update for type {start_type}') self.logger.debug(f'Server requested to enable SABR Context Update for type {start_type}')
self.sabr_contexts_to_send.add(start_type) self.sabr_contexts_to_send.add(start_type)