mirror of
https://github.com/yt-dlp/yt-dlp.git
synced 2025-07-18 03:08:31 +00:00
[test] add more sabr processor tests
This commit is contained in:
parent
8a0917584c
commit
75fb53bccc
@ -4,18 +4,33 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from yt_dlp.extractor.youtube._streaming.sabr.exceptions import SabrStreamError
|
||||
from yt_dlp.extractor.youtube._streaming.sabr.part import PoTokenStatusSabrPart, FormatInitializedSabrPart
|
||||
from yt_dlp.extractor.youtube._streaming.sabr.part import (
|
||||
PoTokenStatusSabrPart,
|
||||
FormatInitializedSabrPart,
|
||||
MediaSeekSabrPart,
|
||||
)
|
||||
|
||||
from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor, ProcessStreamProtectionStatusResult, \
|
||||
ProcessFormatInitializationMetadataResult
|
||||
from yt_dlp.extractor.youtube._streaming.sabr.processor import (
|
||||
SabrProcessor,
|
||||
ProcessStreamProtectionStatusResult,
|
||||
ProcessFormatInitializationMetadataResult,
|
||||
ProcessLiveMetadataResult, ProcessSabrSeekResult,
|
||||
)
|
||||
from yt_dlp.extractor.youtube._streaming.sabr.models import (
|
||||
AudioSelector,
|
||||
VideoSelector,
|
||||
CaptionSelector,
|
||||
InitializedFormat,
|
||||
Segment,
|
||||
)
|
||||
from yt_dlp.extractor.youtube._proto.videostreaming import (
|
||||
FormatId,
|
||||
StreamProtectionStatus,
|
||||
FormatInitializationMetadata,
|
||||
LiveMetadata,
|
||||
SabrContextUpdate,
|
||||
SabrContextSendingPolicy, SabrSeek,
|
||||
)
|
||||
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId, StreamProtectionStatus, \
|
||||
FormatInitializationMetadata, LiveMetadata
|
||||
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy
|
||||
|
||||
|
||||
@ -805,3 +820,415 @@ def test_total_segment_count_live_metadata(self, logger, base_args):
|
||||
processor.process_format_initialization_metadata(video_format_init_metadata_part)
|
||||
assert str(video_format_id) in processor.initialized_formats
|
||||
assert processor.initialized_formats[str(video_format_id)].total_segments == 9
|
||||
|
||||
|
||||
class TestLiveMetadata:
|
||||
|
||||
def test_live_metadata_initialization(self, base_args):
|
||||
processor = SabrProcessor(**base_args)
|
||||
assert processor.live_metadata is None
|
||||
|
||||
def test_live_metadata_update(self, base_args):
|
||||
processor = SabrProcessor(**base_args)
|
||||
live_metadata = LiveMetadata(head_sequence_number=10)
|
||||
|
||||
result = processor.process_live_metadata(live_metadata)
|
||||
assert isinstance(result, ProcessLiveMetadataResult)
|
||||
assert len(result.seek_sabr_parts) == 0
|
||||
assert processor.live_metadata is live_metadata
|
||||
|
||||
# Ensure new live_metadata replaces the old one
|
||||
live_metadata = dataclasses.replace(live_metadata, head_sequence_number=20)
|
||||
result = processor.process_live_metadata(live_metadata)
|
||||
|
||||
assert isinstance(result, ProcessLiveMetadataResult)
|
||||
assert len(result.seek_sabr_parts) == 0
|
||||
assert processor.live_metadata is live_metadata
|
||||
|
||||
def test_live_metadata_no_head_sequence_time_ms(self, base_args):
|
||||
processor = SabrProcessor(**base_args)
|
||||
live_metadata = LiveMetadata(head_sequence_number=10, head_sequence_time_ms=None)
|
||||
|
||||
processor.process_live_metadata(live_metadata)
|
||||
assert processor.live_metadata is live_metadata
|
||||
assert processor.total_duration_ms is None
|
||||
|
||||
def test_live_metadata_with_head_sequence_time_ms(self, base_args):
|
||||
processor = SabrProcessor(**base_args)
|
||||
live_metadata = LiveMetadata(head_sequence_number=10, head_sequence_time_ms=5000)
|
||||
|
||||
processor.process_live_metadata(live_metadata)
|
||||
assert processor.live_metadata is live_metadata
|
||||
assert processor.total_duration_ms == 5000
|
||||
|
||||
def test_update_izf_total_segments(self, base_args):
|
||||
|
||||
live_metadata = LiveMetadata(head_sequence_number=10)
|
||||
audio_selector = make_selector('audio')
|
||||
video_selector = make_selector('video')
|
||||
processor = SabrProcessor(
|
||||
**base_args,
|
||||
audio_selection=audio_selector,
|
||||
video_selection=video_selector,
|
||||
video_id=example_video_id)
|
||||
|
||||
# Initialize both audio and video formats
|
||||
audio_format_id = audio_selector.format_ids[0]
|
||||
video_format_id = video_selector.format_ids[0]
|
||||
|
||||
audio_format_init_metadata = FormatInitializationMetadata(
|
||||
video_id=example_video_id,
|
||||
format_id=audio_format_id,
|
||||
mime_type='audio/mp4',
|
||||
)
|
||||
|
||||
video_format_init_metadata = FormatInitializationMetadata(
|
||||
video_id=example_video_id,
|
||||
format_id=video_format_id,
|
||||
mime_type='video/mp4',
|
||||
)
|
||||
|
||||
processor.process_format_initialization_metadata(audio_format_init_metadata)
|
||||
processor.process_format_initialization_metadata(video_format_init_metadata)
|
||||
assert len(processor.initialized_formats) == 2
|
||||
|
||||
# Process live metadata
|
||||
processor.process_live_metadata(live_metadata)
|
||||
assert processor.live_metadata is live_metadata
|
||||
|
||||
# Check that total_segments is updated in both formats
|
||||
assert processor.initialized_formats[str(audio_format_id)].total_segments == 10
|
||||
assert processor.initialized_formats[str(video_format_id)].total_segments == 10
|
||||
|
||||
def test_min_seekable_time_ms_less_than_player_time_ms(self, base_args):
|
||||
# If min_seekable_time_ms is greater or equal to player time, there should not be a seek
|
||||
live_metadata = LiveMetadata(
|
||||
head_sequence_number=10,
|
||||
head_sequence_time_ms=10000,
|
||||
min_seekable_time_ticks=50000,
|
||||
min_seekable_timescale=10000)
|
||||
|
||||
audio_selector = make_selector('audio')
|
||||
video_selector = make_selector('video')
|
||||
processor = SabrProcessor(
|
||||
**base_args,
|
||||
audio_selection=audio_selector,
|
||||
video_selection=video_selector,
|
||||
video_id=example_video_id,
|
||||
start_time_ms=5001)
|
||||
|
||||
assert processor.client_abr_state.player_time_ms == 5001
|
||||
|
||||
# Initialize both audio and video formats
|
||||
audio_format_id = audio_selector.format_ids[0]
|
||||
video_format_id = video_selector.format_ids[0]
|
||||
|
||||
audio_format_init_metadata = FormatInitializationMetadata(
|
||||
video_id=example_video_id,
|
||||
format_id=audio_format_id,
|
||||
mime_type='audio/mp4',
|
||||
)
|
||||
|
||||
video_format_init_metadata = FormatInitializationMetadata(
|
||||
video_id=example_video_id,
|
||||
format_id=video_format_id,
|
||||
mime_type='video/mp4',
|
||||
)
|
||||
|
||||
processor.process_format_initialization_metadata(audio_format_init_metadata)
|
||||
processor.process_format_initialization_metadata(video_format_init_metadata)
|
||||
assert len(processor.initialized_formats) == 2
|
||||
|
||||
# Process live metadata
|
||||
result = processor.process_live_metadata(live_metadata)
|
||||
assert isinstance(result, ProcessLiveMetadataResult)
|
||||
assert len(result.seek_sabr_parts) == 0
|
||||
assert processor.live_metadata is live_metadata
|
||||
assert processor.client_abr_state.player_time_ms == 5001
|
||||
|
||||
def test_min_seekable_time_ms_greater_than_player_time_ms(self, base_args, logger):
|
||||
# If min_seekable_time_ms is less than player time, there should be a seek
|
||||
live_metadata = LiveMetadata(
|
||||
head_sequence_number=10,
|
||||
head_sequence_time_ms=10000,
|
||||
min_seekable_time_ticks=50000,
|
||||
min_seekable_timescale=10000,
|
||||
)
|
||||
|
||||
audio_selector = make_selector('audio')
|
||||
video_selector = make_selector('video')
|
||||
processor = SabrProcessor(
|
||||
**base_args,
|
||||
audio_selection=audio_selector,
|
||||
video_selection=video_selector,
|
||||
video_id=example_video_id,
|
||||
start_time_ms=4999)
|
||||
|
||||
assert processor.client_abr_state.player_time_ms == 4999
|
||||
|
||||
# Initialize both audio and video formats
|
||||
audio_format_id = audio_selector.format_ids[0]
|
||||
video_format_id = video_selector.format_ids[0]
|
||||
|
||||
audio_format_init_metadata = FormatInitializationMetadata(
|
||||
video_id=example_video_id,
|
||||
format_id=audio_format_id,
|
||||
mime_type='audio/mp4',
|
||||
)
|
||||
|
||||
video_format_init_metadata = FormatInitializationMetadata(
|
||||
video_id=example_video_id,
|
||||
format_id=video_format_id,
|
||||
mime_type='video/mp4',
|
||||
)
|
||||
|
||||
processor.process_format_initialization_metadata(audio_format_init_metadata)
|
||||
processor.process_format_initialization_metadata(video_format_init_metadata)
|
||||
assert len(processor.initialized_formats) == 2
|
||||
|
||||
# Add a dummy previous segment to each format - this should be cleared on seek
|
||||
for izf in processor.initialized_formats.values():
|
||||
izf.current_segment = Segment(
|
||||
format_id=izf.format_id,
|
||||
)
|
||||
|
||||
# Process live metadata
|
||||
result = processor.process_live_metadata(live_metadata)
|
||||
assert isinstance(result, ProcessLiveMetadataResult)
|
||||
assert len(result.seek_sabr_parts) == 2
|
||||
assert processor.live_metadata is live_metadata
|
||||
assert processor.client_abr_state.player_time_ms == 5000
|
||||
|
||||
for seek_part in result.seek_sabr_parts:
|
||||
assert isinstance(seek_part, MediaSeekSabrPart)
|
||||
assert seek_part.format_id in (audio_format_id, video_format_id)
|
||||
assert seek_part.format_selector in (audio_selector, video_selector)
|
||||
|
||||
# Current segment should be cleared to indicate a seek
|
||||
for izf in processor.initialized_formats.values():
|
||||
assert izf.current_segment is None
|
||||
|
||||
logger.debug.assert_called_with('Player time 4999 is less than min seekable time 5000, simulating server seek')
|
||||
|
||||
|
||||
class TestSabrContextUpdate:
|
||||
def test_initialization(self, base_args):
|
||||
processor = SabrProcessor(**base_args)
|
||||
assert len(processor.sabr_context_updates) == 0
|
||||
assert len(processor.sabr_contexts_to_send) == 0
|
||||
|
||||
def test_invalid_sabr_context_update(self, logger, base_args):
|
||||
processor = SabrProcessor(**base_args)
|
||||
invalid_update = SabrContextUpdate()
|
||||
processor.process_sabr_context_update(invalid_update)
|
||||
|
||||
assert len(processor.sabr_context_updates) == 0
|
||||
assert len(processor.sabr_contexts_to_send) == 0
|
||||
logger.warning.assert_called_with('Received an invalid SabrContextUpdate, ignoring')
|
||||
|
||||
def test_valid_sabr_context_update(self, logger, base_args):
|
||||
processor = SabrProcessor(**base_args)
|
||||
valid_update = SabrContextUpdate(
|
||||
type=5,
|
||||
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS,
|
||||
value=b'{"key": "value"}',
|
||||
send_by_default=True,
|
||||
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE,
|
||||
)
|
||||
processor.process_sabr_context_update(valid_update)
|
||||
|
||||
assert len(processor.sabr_context_updates) == 1
|
||||
assert processor.sabr_context_updates[valid_update.type] == valid_update
|
||||
assert len(processor.sabr_contexts_to_send) == 1
|
||||
assert valid_update.type in processor.sabr_contexts_to_send
|
||||
logger.debug.assert_called_with(f'Registered SabrContextUpdate {valid_update}')
|
||||
|
||||
def test_not_send_by_default(self, logger, base_args):
|
||||
processor = SabrProcessor(**base_args)
|
||||
valid_update = SabrContextUpdate(
|
||||
type=5,
|
||||
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS,
|
||||
value=b'{"key": "value"}',
|
||||
send_by_default=False,
|
||||
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE,
|
||||
)
|
||||
processor.process_sabr_context_update(valid_update)
|
||||
assert len(processor.sabr_context_updates) == 1
|
||||
assert processor.sabr_context_updates[valid_update.type] == valid_update
|
||||
assert len(processor.sabr_contexts_to_send) == 0
|
||||
|
||||
def test_write_policy_overwrite(self, logger, base_args):
|
||||
processor = SabrProcessor(**base_args)
|
||||
first_ctx_update = SabrContextUpdate(
|
||||
type=5,
|
||||
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS,
|
||||
value=b'{"key": "value"}',
|
||||
send_by_default=True,
|
||||
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE,
|
||||
)
|
||||
processor.process_sabr_context_update(first_ctx_update)
|
||||
|
||||
assert len(processor.sabr_context_updates) == 1
|
||||
assert processor.sabr_context_updates[first_ctx_update.type] == first_ctx_update
|
||||
assert len(processor.sabr_contexts_to_send) == 1
|
||||
assert first_ctx_update.type in processor.sabr_contexts_to_send
|
||||
|
||||
second_ctx_update = SabrContextUpdate(
|
||||
type=5,
|
||||
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS,
|
||||
value=b'{"key": "new_value"}',
|
||||
send_by_default=True,
|
||||
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE,
|
||||
)
|
||||
|
||||
processor.process_sabr_context_update(second_ctx_update)
|
||||
assert len(processor.sabr_context_updates) == 1
|
||||
assert processor.sabr_context_updates[second_ctx_update.type] == second_ctx_update
|
||||
assert len(processor.sabr_contexts_to_send) == 1
|
||||
assert second_ctx_update.type in processor.sabr_contexts_to_send
|
||||
|
||||
def test_write_policy_keep_existing(self, logger, base_args):
|
||||
processor = SabrProcessor(**base_args)
|
||||
first_ctx_update = SabrContextUpdate(
|
||||
type=5,
|
||||
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS,
|
||||
value=b'{"key": "value"}',
|
||||
send_by_default=True,
|
||||
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_KEEP_EXISTING,
|
||||
)
|
||||
processor.process_sabr_context_update(first_ctx_update)
|
||||
|
||||
assert len(processor.sabr_context_updates) == 1
|
||||
assert processor.sabr_context_updates[first_ctx_update.type] == first_ctx_update
|
||||
assert len(processor.sabr_contexts_to_send) == 1
|
||||
assert first_ctx_update.type in processor.sabr_contexts_to_send
|
||||
|
||||
second_ctx_update = SabrContextUpdate(
|
||||
type=5,
|
||||
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS,
|
||||
value=b'{"key": "new_value"}',
|
||||
send_by_default=True,
|
||||
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_KEEP_EXISTING,
|
||||
)
|
||||
|
||||
processor.process_sabr_context_update(second_ctx_update)
|
||||
assert len(processor.sabr_context_updates) == 1
|
||||
assert processor.sabr_context_updates[first_ctx_update.type] == first_ctx_update
|
||||
assert len(processor.sabr_contexts_to_send) == 1
|
||||
assert first_ctx_update.type in processor.sabr_contexts_to_send
|
||||
logger.debug.assert_called_with(
|
||||
'Received a SABR Context Update with write_policy=KEEP_EXISTING'
|
||||
'matching an existing SABR Context Update. Ignoring update')
|
||||
|
||||
|
||||
class TestSabrContextUpdateSendingPolicy:
|
||||
def test_set_sabr_context_update_sending_policy(self, base_args, logger):
|
||||
processor = SabrProcessor(**base_args)
|
||||
|
||||
processor.process_sabr_context_update(SabrContextUpdate(
|
||||
type=3,
|
||||
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_PLAYBACK,
|
||||
value=b'{"key": "value"}',
|
||||
send_by_default=True,
|
||||
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE,
|
||||
))
|
||||
|
||||
processor.process_sabr_context_update(SabrContextUpdate(
|
||||
type=4,
|
||||
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_REQUEST,
|
||||
value=b'{"key": "value"}',
|
||||
send_by_default=True,
|
||||
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_KEEP_EXISTING,
|
||||
))
|
||||
|
||||
processor.process_sabr_context_update(SabrContextUpdate(
|
||||
type=5,
|
||||
scope=SabrContextUpdate.SabrContextScope.SABR_CONTEXT_SCOPE_CONTENT_ADS,
|
||||
value=b'{"key": "value"}',
|
||||
send_by_default=False,
|
||||
write_policy=SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_OVERWRITE,
|
||||
))
|
||||
|
||||
assert len(processor.sabr_context_updates) == 3
|
||||
assert len(processor.sabr_contexts_to_send) == 2
|
||||
assert 3 in processor.sabr_contexts_to_send
|
||||
assert 4 in processor.sabr_contexts_to_send
|
||||
assert 5 not in processor.sabr_contexts_to_send
|
||||
|
||||
# Sending policy should update what contexts are sent
|
||||
processor.process_sabr_context_sending_policy(
|
||||
SabrContextSendingPolicy(
|
||||
start_policy=[5, 6],
|
||||
stop_policy=[3, 0],
|
||||
discard_policy=[4, 7]))
|
||||
|
||||
assert len(processor.sabr_context_updates) == 2
|
||||
assert len(processor.sabr_contexts_to_send) == 3
|
||||
assert 5 in processor.sabr_contexts_to_send
|
||||
assert 4 in processor.sabr_contexts_to_send # discarding does not remove from contexts to send
|
||||
assert 6 in processor.sabr_contexts_to_send
|
||||
assert all(n not in processor.sabr_contexts_to_send for n in [3, 0, 7])
|
||||
assert 4 not in processor.sabr_context_updates
|
||||
|
||||
|
||||
class TestSabrSeek:
|
||||
def test_invalid_sabr_seek(self, logger, base_args):
|
||||
processor = SabrProcessor(**base_args)
|
||||
invalid_seek = SabrSeek(seek_time_ticks=100, timescale=None)
|
||||
with pytest.raises(SabrStreamError, match='Server sent a SabrSeek part that is missing required seek data'):
|
||||
processor.process_sabr_seek(invalid_seek)
|
||||
assert processor.client_abr_state.player_time_ms == 0
|
||||
|
||||
def test_sabr_seek(self, logger, base_args):
|
||||
audio_selector = make_selector('audio')
|
||||
video_selector = make_selector('video')
|
||||
audio_format_id = audio_selector.format_ids[0]
|
||||
video_format_id = video_selector.format_ids[0]
|
||||
processor = SabrProcessor(
|
||||
**base_args,
|
||||
audio_selection=audio_selector,
|
||||
video_selection=video_selector,
|
||||
video_id=example_video_id)
|
||||
|
||||
assert processor.client_abr_state.player_time_ms == 0
|
||||
|
||||
audio_format_init_metadata = FormatInitializationMetadata(
|
||||
video_id=example_video_id,
|
||||
format_id=audio_format_id,
|
||||
mime_type='audio/mp4',
|
||||
)
|
||||
video_format_init_metadata = FormatInitializationMetadata(
|
||||
video_id=example_video_id,
|
||||
format_id=video_format_id,
|
||||
mime_type='video/mp4',
|
||||
)
|
||||
|
||||
processor.process_format_initialization_metadata(audio_format_init_metadata)
|
||||
processor.process_format_initialization_metadata(video_format_init_metadata)
|
||||
assert len(processor.initialized_formats) == 2
|
||||
|
||||
# Add a dummy previous segment to each format - this should be cleared on seek
|
||||
for izf in processor.initialized_formats.values():
|
||||
izf.current_segment = Segment(
|
||||
format_id=izf.format_id,
|
||||
)
|
||||
|
||||
sabr_seek = SabrSeek(
|
||||
seek_time_ticks=56000,
|
||||
timescale=10000,
|
||||
)
|
||||
|
||||
result = processor.process_sabr_seek(sabr_seek)
|
||||
assert isinstance(result, ProcessSabrSeekResult)
|
||||
assert len(result.seek_sabr_parts) == 2
|
||||
assert processor.client_abr_state.player_time_ms == 5600
|
||||
for seek_part in result.seek_sabr_parts:
|
||||
assert isinstance(seek_part, MediaSeekSabrPart)
|
||||
assert seek_part.format_id in (audio_format_id, video_format_id)
|
||||
assert seek_part.format_selector in (audio_selector, video_selector)
|
||||
|
||||
# Current segment should be cleared to indicate a seek
|
||||
for izf in processor.initialized_formats.values():
|
||||
assert izf.current_segment is None
|
||||
|
||||
logger.debug.assert_called_with('Seeking to 5600ms')
|
||||
|
@ -614,7 +614,7 @@ def process_sabr_context_update(self, sabr_ctx_update: SabrContextUpdate):
|
||||
|
||||
def process_sabr_context_sending_policy(self, sabr_ctx_sending_policy: SabrContextSendingPolicy):
|
||||
for start_type in sabr_ctx_sending_policy.start_policy:
|
||||
if start_type not in self.sabr_context_updates:
|
||||
if start_type not in self.sabr_contexts_to_send:
|
||||
self.logger.debug(f'Server requested to enable SABR Context Update for type {start_type}')
|
||||
self.sabr_contexts_to_send.add(start_type)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user