1
0
mirror of https://github.com/yt-dlp/yt-dlp.git synced 2025-07-18 03:08:31 +00:00

[test] add format initialization tests

This commit is contained in:
coletdjnz 2025-06-29 18:32:18 +12:00
parent 1fdd744253
commit 8a0917584c
No known key found for this signature in database
GPG Key ID: 91984263BB39894A
2 changed files with 507 additions and 7 deletions

View File

@ -1,15 +1,21 @@
import dataclasses
import pytest
from unittest.mock import MagicMock
from yt_dlp.extractor.youtube._streaming.sabr.part import PoTokenStatusSabrPart
from yt_dlp.extractor.youtube._streaming.sabr.exceptions import SabrStreamError
from yt_dlp.extractor.youtube._streaming.sabr.part import PoTokenStatusSabrPart, FormatInitializedSabrPart
from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor, ProcessStreamProtectionStatusResult
from yt_dlp.extractor.youtube._streaming.sabr.processor import SabrProcessor, ProcessStreamProtectionStatusResult, \
ProcessFormatInitializationMetadataResult
from yt_dlp.extractor.youtube._streaming.sabr.models import (
AudioSelector,
VideoSelector,
CaptionSelector,
InitializedFormat,
)
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId, StreamProtectionStatus
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId, StreamProtectionStatus, \
FormatInitializationMetadata, LiveMetadata
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy
@ -32,23 +38,26 @@ def base_args(logger, client_info):
}
example_video_id = 'example_video_id'
def make_selector(selector_type, *, discard_media=False, format_ids=None):
if selector_type == 'audio':
return AudioSelector(
display_name='audio',
format_ids=format_ids or [FormatId(itag=140)],
format_ids=format_ids if format_ids is not None else [FormatId(itag=140)],
discard_media=discard_media,
)
elif selector_type == 'video':
return VideoSelector(
display_name='video',
format_ids=format_ids or [FormatId(itag=248)],
format_ids=format_ids if format_ids is not None else [FormatId(itag=248)],
discard_media=discard_media,
)
elif selector_type == 'caption':
return CaptionSelector(
display_name='caption',
format_ids=format_ids or [FormatId(itag=386)],
format_ids=format_ids if format_ids is not None else [FormatId(itag=386)],
discard_media=discard_media,
)
raise ValueError(f'Unknown selector_type: {selector_type}')
@ -305,3 +314,494 @@ def test_next_request_policy_part(self, logger, base_args):
# Check logger trace was called
assert logger.trace.call_count == 2
class TestFormatInitialization:
def test_initialize_format(self, logger, base_args):
selector = make_selector('audio')
format_id = selector.format_ids[0]
processor = SabrProcessor(**base_args, audio_selection=selector, video_id='test_video')
format_init_metadata_part = FormatInitializationMetadata(
video_id='test_video',
format_id=format_id,
end_time_ms=10000,
total_segments=5,
mime_type='audio/mp4',
duration_ticks=10000,
duration_timescale=1000,
)
result = processor.process_format_initialization_metadata(format_init_metadata_part)
assert isinstance(result, ProcessFormatInitializationMetadataResult)
assert isinstance(result.sabr_part, FormatInitializedSabrPart)
assert result.sabr_part.format_selector is selector
assert result.sabr_part.format_id == format_id
assert len(processor.initialized_formats) == 1
assert str(format_id) in processor.initialized_formats
initialized_format = processor.initialized_formats[str(format_id)]
expected_initialized_format = InitializedFormat(
format_id=format_id,
video_id='test_video',
mime_type='audio/mp4',
duration_ms=10000,
total_segments=5,
end_time_ms=10000,
format_selector=selector,
discard=False,
)
assert initialized_format == expected_initialized_format
logger.debug.assert_called_with(
f'Initialized Format: {expected_initialized_format}',
)
def test_initialize_format_already_initialized(self, logger, base_args):
selector = make_selector('audio')
format_id = selector.format_ids[0]
processor = SabrProcessor(**base_args, audio_selection=selector, video_id='test_video')
format_init_metadata_part = FormatInitializationMetadata(
video_id='test_video',
format_id=format_id,
end_time_ms=10000,
total_segments=5,
mime_type='audio/mp4',
duration_ticks=10000,
duration_timescale=1000,
)
assert processor.process_format_initialization_metadata(format_init_metadata_part)
# Now try to initialize it again
result = processor.process_format_initialization_metadata(
dataclasses.replace(format_init_metadata_part, total_segments=10))
assert isinstance(result, ProcessFormatInitializationMetadataResult)
assert result.sabr_part is None
logger.trace.assert_called_with(f'Format {format_id} already initialized')
assert len(processor.initialized_formats) == 1
assert str(format_id) in processor.initialized_formats
initialized_format = processor.initialized_formats[str(format_id)]
assert initialized_format.total_segments == 5
def test_initialize_multiple_formats(self, logger, base_args):
audio_selector = make_selector('audio')
video_selector = make_selector('video')
audio_format_id = audio_selector.format_ids[0]
video_format_id = video_selector.format_ids[0]
processor = SabrProcessor(
**base_args,
audio_selection=audio_selector,
video_selection=video_selector,
video_id=example_video_id,
)
audio_format_init_metadata = FormatInitializationMetadata(
video_id=example_video_id,
format_id=audio_format_id,
end_time_ms=10000,
total_segments=5,
mime_type='audio/mp4',
duration_ticks=10000,
duration_timescale=1000,
)
video_format_init_metadata = FormatInitializationMetadata(
video_id=example_video_id,
format_id=video_format_id,
end_time_ms=10000,
total_segments=20,
mime_type='video/mp4',
duration_ticks=10000,
duration_timescale=1000,
)
# Process audio format initialization
audio_result = processor.process_format_initialization_metadata(audio_format_init_metadata)
assert audio_result.sabr_part.format_selector is audio_selector
assert audio_result.sabr_part.format_id == audio_format_id
assert len(processor.initialized_formats) == 1
assert str(audio_format_id) in processor.initialized_formats
assert processor.initialized_formats[str(audio_format_id)].format_id == audio_format_id
# Process video format initialization
video_result = processor.process_format_initialization_metadata(video_format_init_metadata)
assert video_result.sabr_part.format_selector is video_selector
assert video_result.sabr_part.format_id == video_format_id
assert len(processor.initialized_formats) == 2
assert str(video_format_id) in processor.initialized_formats
assert processor.initialized_formats[str(video_format_id)].format_id == video_format_id
def test_initialized_format_not_match_selector(self, logger, base_args):
selector = make_selector('audio', format_ids=[FormatId(140)])
processor = SabrProcessor(
**base_args,
video_id=example_video_id,
audio_selection=selector)
format_id = FormatId(itag=141) # Different format id than the selector
format_init_metadata_part = FormatInitializationMetadata(
video_id=example_video_id,
format_id=format_id,
end_time_ms=10000,
total_segments=5,
mime_type='audio/mp4',
duration_ticks=10000,
duration_timescale=1000,
)
with pytest.raises(SabrStreamError, match='does not match any format selector'):
processor.process_format_initialization_metadata(format_init_metadata_part)
assert len(processor.initialized_formats) == 0
def test_initialized_format_match_mimetype(self, logger, base_args):
selector = make_selector('audio', format_ids=[])
assert len(selector.format_ids) == 0
processor = SabrProcessor(
**base_args,
video_id=example_video_id,
audio_selection=selector,
video_selection=make_selector('caption'),
)
format_id = FormatId(itag=251)
format_init_metadata_part = FormatInitializationMetadata(
video_id=example_video_id,
format_id=format_id,
end_time_ms=10000,
total_segments=5,
mime_type='audio/mp4',
duration_ticks=10000,
duration_timescale=1000,
)
result = processor.process_format_initialization_metadata(format_init_metadata_part)
assert isinstance(result, ProcessFormatInitializationMetadataResult)
assert isinstance(result.sabr_part, FormatInitializedSabrPart)
assert result.sabr_part.format_selector is selector
assert result.sabr_part.format_id == format_id
assert len(processor.initialized_formats) == 1
# If mimetype does not match any selector, it should raise an error
bad_fmt_init_metadata = dataclasses.replace(
format_init_metadata_part, mime_type='video/mp4', format_id=FormatId(itag=248))
with pytest.raises(SabrStreamError, match='does not match any format selector'):
processor.process_format_initialization_metadata(bad_fmt_init_metadata)
assert len(processor.initialized_formats) == 1
def test_discard_media(self, logger, base_args):
# Discard and only match by format id
audio_selector = make_selector('audio', discard_media=True)
audio_format_id = audio_selector.format_ids[0]
# Discard any video
video_selector = make_selector('video', format_ids=[], discard_media=True)
assert len(video_selector.format_ids) == 0
video_format_id = FormatId(itag=248)
# xxx: Caption selector not specified, should be discarded by default
processor = SabrProcessor(
**base_args,
video_id=example_video_id,
audio_selection=audio_selector,
video_selection=video_selector,
)
audio_format_init_metadata = FormatInitializationMetadata(
video_id=example_video_id,
format_id=audio_format_id,
end_time_ms=10000,
total_segments=5,
mime_type='audio/mp4',
duration_ticks=10000,
duration_timescale=1000,
)
# Process audio format initialization
audio_result = processor.process_format_initialization_metadata(audio_format_init_metadata)
# When discarding, should not return a sabr_part
assert isinstance(audio_result, ProcessFormatInitializationMetadataResult)
assert audio_result.sabr_part is None
assert len(processor.initialized_formats) == 1
assert str(audio_format_id) in processor.initialized_formats
assert processor.initialized_formats[str(audio_format_id)].format_id == audio_format_id
assert processor.initialized_formats[str(audio_format_id)].discard is True
# The format should be marked as completely buffered
assert len(processor.initialized_formats[str(audio_format_id)].consumed_ranges) == 1
consumed_range = processor.initialized_formats[str(audio_format_id)].consumed_ranges[0]
assert consumed_range.start_sequence_number == 0
assert consumed_range.end_sequence_number >= 5
assert consumed_range.start_time_ms == 0
assert consumed_range.duration_ms >= 10000
# Process video format initialization. This should match the selector but be discarded.
video_format_init_metadata = FormatInitializationMetadata(
video_id=example_video_id,
format_id=video_format_id,
end_time_ms=10000,
total_segments=20,
mime_type='video/mp4',
duration_ticks=10000,
duration_timescale=1000,
)
video_result = processor.process_format_initialization_metadata(video_format_init_metadata)
assert video_result.sabr_part is None
assert len(processor.initialized_formats) == 2
assert str(video_format_id) in processor.initialized_formats
assert processor.initialized_formats[str(video_format_id)].format_id == video_format_id
assert processor.initialized_formats[str(video_format_id)].discard is True
# The format should be marked as completely buffered
assert len(processor.initialized_formats[str(video_format_id)].consumed_ranges) == 1
consumed_range = processor.initialized_formats[str(video_format_id)].consumed_ranges[0]
assert consumed_range.start_sequence_number == 0
assert consumed_range.end_sequence_number >= 20
assert consumed_range.start_time_ms == 0
assert consumed_range.duration_ms >= 10000
# Process a caption format initialization. This should be discarded by default as no selector was specified.
caption_format_id = FormatId(itag=386)
# Simulate no duration data (for livestreams with no live_metadata)
caption_format_init_metadata = FormatInitializationMetadata(
video_id=example_video_id,
format_id=caption_format_id,
mime_type='text/mp4',
)
caption_result = processor.process_format_initialization_metadata(caption_format_init_metadata)
assert caption_result.sabr_part is None
assert len(processor.initialized_formats) == 3
assert str(caption_format_id) in processor.initialized_formats
assert processor.initialized_formats[str(caption_format_id)].format_id == caption_format_id
assert processor.initialized_formats[str(caption_format_id)].discard is True
# The format should be marked as completely buffered
assert len(processor.initialized_formats[str(caption_format_id)].consumed_ranges) == 1
consumed_range = processor.initialized_formats[str(caption_format_id)].consumed_ranges[0]
assert consumed_range.start_sequence_number == 0
assert consumed_range.end_sequence_number >= 99999999
assert consumed_range.start_time_ms == 0
assert consumed_range.duration_ms >= 99999999
def test_total_duration_ms(self, logger, base_args):
# Test the duration_ms calculation when end_time_ms and duration_ms are different
audio_selector = make_selector('audio')
video_selector = make_selector('video')
caption_selector = make_selector('caption')
audio_format_id = audio_selector.format_ids[0]
video_format_id = video_selector.format_ids[0]
caption_format_id = caption_selector.format_ids[0]
processor = SabrProcessor(
**base_args,
audio_selection=audio_selector,
video_selection=video_selector,
caption_selection=caption_selector,
video_id=example_video_id,
)
format_init_metadata_part = FormatInitializationMetadata(
video_id=example_video_id,
format_id=audio_format_id,
end_time_ms=15001, # End time is slightly more than 15 seconds
total_segments=5,
mime_type='audio/mp4',
duration_ticks=15000,
duration_timescale=1000,
)
assert processor.total_duration_ms is None
processor.process_format_initialization_metadata(format_init_metadata_part)
assert str(audio_format_id) in processor.initialized_formats
assert processor.total_duration_ms == 15001
# But if duration_ticks is greater, then use that
video_format_init_metadata_part = FormatInitializationMetadata(
video_id=example_video_id,
format_id=video_format_id,
end_time_ms=15003, # end time is slightly less
total_segments=10,
mime_type='video/mp4',
duration_ticks=150050, # Duration ticks is greater than end_time_ms
duration_timescale=10000, # slightly different timescale
)
processor.process_format_initialization_metadata(video_format_init_metadata_part)
assert str(video_format_id) in processor.initialized_formats
assert processor.total_duration_ms == 15005
# And if total_duration_ms is greater, use that
caption_format_init_metadata_part = FormatInitializationMetadata(
video_id=example_video_id,
format_id=caption_format_id,
end_time_ms=15004,
total_segments=20,
mime_type='text/mp4',
duration_ticks=15004,
duration_timescale=1000,
)
processor.process_format_initialization_metadata(caption_format_init_metadata_part)
assert str(caption_format_id) in processor.initialized_formats
assert processor.total_duration_ms == 15005 # should not change
def test_no_duration(self, logger, base_args):
# Test the case where no duration is provided
selector = make_selector('audio')
format_id = selector.format_ids[0]
processor = SabrProcessor(**base_args, audio_selection=selector, video_id=example_video_id)
format_init_metadata_part = FormatInitializationMetadata(
video_id=example_video_id,
format_id=format_id,
end_time_ms=None, # No end time
total_segments=5,
mime_type='audio/mp4',
duration_ticks=None,
duration_timescale=None,
)
assert processor.total_duration_ms is None
processor.process_format_initialization_metadata(format_init_metadata_part)
assert str(format_id) in processor.initialized_formats
assert processor.total_duration_ms == 0 # TODO: should this be None or 0?
def test_no_duration_total_duration_ms_set(self, logger, base_args):
# Test the case where no duration is provided but total_duration_ms is set (by e.g. live_metadata)
selector = make_selector('audio')
format_id = selector.format_ids[0]
processor = SabrProcessor(
**base_args,
audio_selection=selector,
video_id=example_video_id,
)
processor.total_duration_ms = 10000
format_init_metadata_part = FormatInitializationMetadata(
video_id=example_video_id,
format_id=format_id,
end_time_ms=None, # No end time
total_segments=5,
mime_type='audio/mp4',
duration_ticks=None,
duration_timescale=None,
)
processor.process_format_initialization_metadata(format_init_metadata_part)
assert str(format_id) in processor.initialized_formats
assert processor.total_duration_ms == 10000
def test_video_id_mismatch(self, logger, base_args):
selector = make_selector('audio')
format_id = selector.format_ids[0]
processor = SabrProcessor(**base_args, audio_selection=selector, video_id='video_1')
format_init_metadata_part = FormatInitializationMetadata(
video_id='video_2',
format_id=format_id,
end_time_ms=10000,
total_segments=5,
mime_type='audio/mp4',
duration_ticks=10000,
duration_timescale=1000,
)
with pytest.raises(SabrStreamError, match='Received unexpected Format Initialization Metadata for video video_2'):
processor.process_format_initialization_metadata(format_init_metadata_part)
assert len(processor.initialized_formats) == 0
def test_selector_consumed(self, logger, base_args):
# Test that if a format selector is already in use, it raises an error
selector = make_selector('audio', format_ids=[])
audio_format_id = FormatId(itag=140)
processor = SabrProcessor(**base_args, audio_selection=selector, video_id=example_video_id)
format_init_metadata_part = FormatInitializationMetadata(
video_id=example_video_id,
format_id=audio_format_id,
end_time_ms=10000,
total_segments=5,
mime_type='audio/mp4',
duration_ticks=10000,
duration_timescale=1000,
)
processor.process_format_initialization_metadata(format_init_metadata_part)
assert str(audio_format_id) in processor.initialized_formats
assert processor.initialized_formats[str(audio_format_id)].format_selector is selector
with pytest.raises(SabrStreamError, match='Changing formats is not currently supported'):
processor.process_format_initialization_metadata(
dataclasses.replace(format_init_metadata_part, format_id=FormatId(itag=141)))
def test_no_segment_count(self, logger, base_args):
selector = make_selector('audio')
format_id = selector.format_ids[0]
processor = SabrProcessor(**base_args, audio_selection=selector, video_id=example_video_id)
format_init_metadata_part = FormatInitializationMetadata(
video_id=example_video_id,
format_id=format_id,
end_time_ms=10000,
mime_type='audio/mp4',
duration_ticks=10000,
duration_timescale=1000,
)
processor.process_format_initialization_metadata(format_init_metadata_part)
assert str(format_id) in processor.initialized_formats
assert processor.initialized_formats[str(format_id)].total_segments is None
def test_total_segment_count_live_metadata(self, logger, base_args):
# Test that total_segments is set from live_metadata when not in the format
audio_selector = make_selector('audio')
video_selector = make_selector('video')
audio_format_id = audio_selector.format_ids[0]
video_format_id = video_selector.format_ids[0]
processor = SabrProcessor(
**base_args, audio_selection=audio_selector, video_selection=video_selector,
video_id=example_video_id)
format_init_metadata_part = FormatInitializationMetadata(
video_id=example_video_id,
format_id=audio_format_id,
end_time_ms=10000,
mime_type='audio/mp4',
duration_ticks=10000,
duration_timescale=1000,
)
processor.live_metadata = LiveMetadata(head_sequence_number=10)
processor.process_format_initialization_metadata(format_init_metadata_part)
assert str(audio_format_id) in processor.initialized_formats
assert processor.initialized_formats[str(audio_format_id)].total_segments == 10
# But live metadata should not override total_segments if it is present
# XXX: when live metadata is updated, it will update the total_segments.
# However, we can consider the total_segments
# from the format initialization metadata as the most-up-to-date value until then.
video_format_init_metadata_part = FormatInitializationMetadata(
video_id=example_video_id,
format_id=video_format_id,
end_time_ms=10000,
# This should take precedence over live_metadata.
# Generally, this should only ever be greater than the live_metadata value.
# Never seen this be present for livestreams at this time.
# TODO: add a guard to ensure total segments is > live_metadata.head_sequence_number?
total_segments=9,
mime_type='video/mp4',
duration_ticks=10000,
duration_timescale=1000,
)
processor.process_format_initialization_metadata(video_format_init_metadata_part)
assert str(video_format_id) in processor.initialized_formats
assert processor.initialized_formats[str(video_format_id)].total_segments == 9

View File

@ -525,7 +525,7 @@ def process_format_initialization_metadata(self, format_init_metadata: FormatIni
if izf.format_selector is format_selector:
raise SabrStreamError('Server changed format. Changing formats is not currently supported')
duration_ms = ticks_to_ms(format_init_metadata.duration_timescale, format_init_metadata.duration_ticks)
duration_ms = ticks_to_ms(format_init_metadata.duration_ticks, format_init_metadata.duration_timescale)
total_segments = format_init_metadata.total_segments
if not total_segments and self.live_metadata and self.live_metadata.head_sequence_number: