1
0
mirror of https://github.com/yt-dlp/yt-dlp.git synced 2025-07-18 03:08:31 +00:00

[test] initial media header tests

This commit is contained in:
coletdjnz 2025-07-08 07:55:36 +12:00
parent 75fb53bccc
commit 81a99ff14d
No known key found for this signature in database
GPG Key ID: 91984263BB39894A
3 changed files with 167 additions and 28 deletions

View File

@ -8,13 +8,16 @@
PoTokenStatusSabrPart, PoTokenStatusSabrPart,
FormatInitializedSabrPart, FormatInitializedSabrPart,
MediaSeekSabrPart, MediaSeekSabrPart,
MediaSegmentInitSabrPart,
) )
from yt_dlp.extractor.youtube._streaming.sabr.processor import ( from yt_dlp.extractor.youtube._streaming.sabr.processor import (
SabrProcessor, SabrProcessor,
ProcessStreamProtectionStatusResult, ProcessStreamProtectionStatusResult,
ProcessFormatInitializationMetadataResult, ProcessFormatInitializationMetadataResult,
ProcessLiveMetadataResult, ProcessSabrSeekResult, ProcessLiveMetadataResult,
ProcessSabrSeekResult,
ProcessMediaHeaderResult,
) )
from yt_dlp.extractor.youtube._streaming.sabr.models import ( from yt_dlp.extractor.youtube._streaming.sabr.models import (
AudioSelector, AudioSelector,
@ -29,7 +32,9 @@
FormatInitializationMetadata, FormatInitializationMetadata,
LiveMetadata, LiveMetadata,
SabrContextUpdate, SabrContextUpdate,
SabrContextSendingPolicy, SabrSeek, SabrContextSendingPolicy,
SabrSeek,
MediaHeader,
) )
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy
@ -84,6 +89,43 @@ def factory():
return factory return factory
def make_format_im(selector=None, video_id=None):
    """Build a ``FormatInitializationMetadata`` fixture.

    When ``selector`` is given, its first format id and its ``mime_prefix``
    (suffixed with ``/mp4``) are used; otherwise the fixture falls back to
    itag 140 with an ``audio/mp4`` mime type. ``video_id`` falls back to
    ``example_video_id`` when not provided.
    """
    if selector:
        format_id = selector.format_ids[0]
        mime_type = selector.mime_prefix + '/mp4'
    else:
        format_id = FormatId(itag=140)
        mime_type = 'audio/mp4'

    return FormatInitializationMetadata(
        video_id=video_id or example_video_id,
        format_id=format_id,
        end_time_ms=10000,
        total_segments=5,
        mime_type=mime_type,
        duration_ticks=10000,
        duration_timescale=1000,
    )
def make_init_header(selector=None, video_id=None):
    """Build a ``MediaHeader`` fixture describing an init segment.

    The header uses ``header_id=0``, starts at byte 0 and is 501 bytes long.
    The format id comes from ``selector`` when given, otherwise itag 140;
    ``video_id`` falls back to ``example_video_id``.
    """
    if selector:
        format_id = selector.format_ids[0]
    else:
        format_id = FormatId(itag=140)

    return MediaHeader(
        video_id=video_id or example_video_id,
        format_id=format_id,
        header_id=0,
        is_init_segment=True,
        start_data_range=0,
        content_length=501,
    )
def make_media_header(selector=None, video_id=None, sequence_no=None, header_id=0):
    """Build a ``MediaHeader`` fixture describing a regular (non-init) segment.

    The header starts at byte 502, is 10000 bytes long, begins at 0ms and
    lasts 2300ms. ``sequence_no`` and ``header_id`` are passed through as the
    header's sequence number and id. The format id comes from ``selector``
    when given, otherwise itag 140; ``video_id`` falls back to
    ``example_video_id``.
    """
    if selector:
        format_id = selector.format_ids[0]
    else:
        format_id = FormatId(itag=140)

    return MediaHeader(
        video_id=video_id or example_video_id,
        format_id=format_id,
        header_id=header_id,
        start_data_range=502,
        content_length=10000,
        sequence_number=sequence_no,
        is_init_segment=False,
        duration_ms=2300,
        start_ms=0,
    )
class TestSabrProcessorInitialization: class TestSabrProcessorInitialization:
@pytest.mark.parametrize( @pytest.mark.parametrize(
'audio_sel,video_sel,caption_sel,expected_bitfield', 'audio_sel,video_sel,caption_sel,expected_bitfield',
@ -1120,8 +1162,6 @@ def test_write_policy_keep_existing(self, logger, base_args):
'Received a SABR Context Update with write_policy=KEEP_EXISTING' 'Received a SABR Context Update with write_policy=KEEP_EXISTING'
'matching an existing SABR Context Update. Ignoring update') 'matching an existing SABR Context Update. Ignoring update')
class TestSabrContextUpdateSendingPolicy:
def test_set_sabr_context_update_sending_policy(self, base_args, logger): def test_set_sabr_context_update_sending_policy(self, base_args, logger):
processor = SabrProcessor(**base_args) processor = SabrProcessor(**base_args)
@ -1232,3 +1272,107 @@ def test_sabr_seek(self, logger, base_args):
assert izf.current_segment is None assert izf.current_segment is None
logger.debug.assert_called_with('Seeking to 5600ms') logger.debug.assert_called_with('Seeking to 5600ms')
class TestMediaHeader:
    """Checks of ``SabrProcessor.process_media_header`` for an audio selector."""

    @staticmethod
    def _initialized_processor(base_args, selector):
        """Create a processor for ``selector`` and feed it format initialization metadata."""
        processor = SabrProcessor(**base_args, audio_selection=selector)
        processor.process_format_initialization_metadata(make_format_im(selector))
        return processor

    def test_media_header_init_segment(self, base_args):
        """An init-segment header yields an init SABR part and a stored partial segment."""
        audio = make_selector('audio')
        processor = self._initialized_processor(base_args, audio)
        header = make_init_header(audio)

        outcome = processor.process_media_header(header)

        assert isinstance(outcome, ProcessMediaHeaderResult)
        assert isinstance(outcome.sabr_part, MediaSegmentInitSabrPart)

        # TODO: confirm expected duration/start_ms settings for init segments
        expected_part = MediaSegmentInitSabrPart(
            format_selector=audio,
            format_id=audio.format_ids[0],
            player_time_ms=0,
            sequence_number=None,
            total_segments=5,
            is_init_segment=True,
            content_length=501,
            start_time_ms=0,
            duration_ms=0,
            start_bytes=0,
        )
        assert outcome.sabr_part == expected_part

        # The processor keeps the in-flight segment keyed by the header id.
        assert header.header_id in processor.partial_segments

        # TODO: confirm expected duration/start_ms settings for init segments
        expected_segment = Segment(
            format_id=audio.format_ids[0],
            is_init_segment=True,
            duration_ms=0,
            start_ms=0,
            start_data_range=0,
            sequence_number=None,
            content_length=501,
            content_length_estimated=False,
            initialized_format=processor.initialized_formats[str(audio.format_ids[0])],
            duration_estimated=True,
            discard=False,
            consumed=False,
            received_data_length=0,
        )
        assert processor.partial_segments[header.header_id] == expected_segment

    def test_media_header_segment(self, base_args):
        """A regular segment header yields an init SABR part carrying the header's values."""
        audio = make_selector('audio')
        processor = self._initialized_processor(base_args, audio)
        header = make_media_header(audio, sequence_no=1)

        outcome = processor.process_media_header(header)

        assert isinstance(outcome, ProcessMediaHeaderResult)
        assert isinstance(outcome.sabr_part, MediaSegmentInitSabrPart)

        expected_part = MediaSegmentInitSabrPart(
            format_selector=audio,
            format_id=audio.format_ids[0],
            player_time_ms=0,
            sequence_number=1,
            total_segments=5,
            is_init_segment=False,
            content_length=10000,
            start_time_ms=0,
            duration_ms=2300,
            start_bytes=502,
        )
        assert outcome.sabr_part == expected_part

        # The processor keeps the in-flight segment keyed by the header id.
        assert header.header_id in processor.partial_segments

        expected_segment = Segment(
            format_id=audio.format_ids[0],
            is_init_segment=False,
            duration_ms=2300,
            start_ms=0,
            start_data_range=502,
            sequence_number=1,
            content_length=10000,
            content_length_estimated=False,
            initialized_format=processor.initialized_formats[str(audio.format_ids[0])],
            duration_estimated=False,
            discard=False,
            consumed=False,
            received_data_length=0,
        )
        assert processor.partial_segments[header.header_id] == expected_segment

View File

@ -13,7 +13,7 @@ class Segment:
duration_ms: int = 0 duration_ms: int = 0
start_ms: int = 0 start_ms: int = 0
start_data_range: int = 0 start_data_range: int = 0
sequence_number: int = 0 sequence_number: int | None = 0
content_length: int | None = None content_length: int | None = None
content_length_estimated: bool = False content_length_estimated: bool = False
initialized_format: InitializedFormat = None initialized_format: InitializedFormat = None
@ -62,36 +62,29 @@ class FormatSelector:
display_name: str display_name: str
format_ids: list[FormatId] = dataclasses.field(default_factory=list) format_ids: list[FormatId] = dataclasses.field(default_factory=list)
discard_media: bool = False discard_media: bool = False
mime_prefix: str | None = None
def match(self, format_id: FormatId = None, mime_type: str | None = None, **kwargs) -> bool:
    """Return whether this selector accepts the given format.

    A format is accepted when its ``format_id`` is one of ``self.format_ids``.
    If the selector lists no format ids at all, it falls back to comparing the
    lower-cased ``mime_type`` against ``self.mime_prefix``.

    Always returns a real ``bool``: a bare ``or``/``and`` chain would leak
    ``None`` or ``''`` to callers when ``mime_type`` or ``mime_prefix`` is unset.
    Extra keyword arguments are accepted and ignored.
    """
    if format_id in self.format_ids:
        return True

    # The mime-type fallback only applies when no explicit format ids were requested.
    if self.format_ids or not self.mime_prefix or not mime_type:
        return False

    return mime_type.lower().startswith(self.mime_prefix)
@dataclasses.dataclass @dataclasses.dataclass
class AudioSelector(FormatSelector): class AudioSelector(FormatSelector):
mime_prefix: str = dataclasses.field(default='audio')
def match(self, format_id: FormatId = None, mime_type: str | None = None, **kwargs) -> bool:
return (
super().match(format_id, mime_type=mime_type, **kwargs)
or (not self.format_ids and mime_type and mime_type.lower().startswith('audio'))
)
@dataclasses.dataclass @dataclasses.dataclass
class VideoSelector(FormatSelector): class VideoSelector(FormatSelector):
mime_prefix: str = dataclasses.field(default='video')
def match(self, format_id: FormatId = None, mime_type: str | None = None, **kwargs) -> bool:
return (
super().match(format_id, mime_type=mime_type, **kwargs)
or (not self.format_ids and mime_type and mime_type.lower().startswith('video'))
)
@dataclasses.dataclass @dataclasses.dataclass
class CaptionSelector(FormatSelector): class CaptionSelector(FormatSelector):
mime_prefix: str = dataclasses.field(default='text')
def match(self, format_id: FormatId = None, mime_type: str | None = None, **kwargs) -> bool:
return (
super().match(format_id, mime_type=mime_type, **kwargs)
or (not self.format_ids and mime_type and mime_type.lower().startswith('text'))
)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import dataclasses import dataclasses
import enum import enum
@ -15,7 +17,7 @@ class SabrPart:
class MediaSegmentInitSabrPart(SabrPart): class MediaSegmentInitSabrPart(SabrPart):
format_selector: FormatSelector format_selector: FormatSelector
format_id: FormatId format_id: FormatId
sequence_number: int = None sequence_number: int | None = None
is_init_segment: bool = False is_init_segment: bool = False
total_segments: int = None total_segments: int = None
start_time_ms: int = None start_time_ms: int = None
@ -31,7 +33,7 @@ class MediaSegmentInitSabrPart(SabrPart):
class MediaSegmentDataSabrPart(SabrPart): class MediaSegmentDataSabrPart(SabrPart):
format_selector: FormatSelector format_selector: FormatSelector
format_id: FormatId format_id: FormatId
sequence_number: int = None sequence_number: int | None = None
is_init_segment: bool = False is_init_segment: bool = False
total_segments: int = None total_segments: int = None
data: bytes = b'' data: bytes = b''
@ -43,7 +45,7 @@ class MediaSegmentDataSabrPart(SabrPart):
class MediaSegmentEndSabrPart(SabrPart): class MediaSegmentEndSabrPart(SabrPart):
format_selector: FormatSelector format_selector: FormatSelector
format_id: FormatId format_id: FormatId
sequence_number: int = None sequence_number: int | None = None
is_init_segment: bool = False is_init_segment: bool = False
total_segments: int = None total_segments: int = None