1
0
mirror of https://github.com/yt-dlp/yt-dlp.git synced 2025-08-14 16:38:29 +00:00

[test] media part tests

This commit is contained in:
coletdjnz 2025-08-09 15:18:14 +12:00
parent eaec0c84bf
commit a9b1c25ddc
No known key found for this signature in database
GPG Key ID: 91984263BB39894A
3 changed files with 132 additions and 5 deletions

View File

@ -1,4 +1,5 @@
import dataclasses import dataclasses
import io
import pytest import pytest
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -9,6 +10,7 @@
FormatInitializedSabrPart, FormatInitializedSabrPart,
MediaSeekSabrPart, MediaSeekSabrPart,
MediaSegmentInitSabrPart, MediaSegmentInitSabrPart,
MediaSegmentDataSabrPart,
) )
from yt_dlp.extractor.youtube._streaming.sabr.processor import ( from yt_dlp.extractor.youtube._streaming.sabr.processor import (
@ -18,6 +20,7 @@
ProcessLiveMetadataResult, ProcessLiveMetadataResult,
ProcessSabrSeekResult, ProcessSabrSeekResult,
ProcessMediaHeaderResult, ProcessMediaHeaderResult,
ProcessMediaResult,
) )
from yt_dlp.extractor.youtube._streaming.sabr.models import ( from yt_dlp.extractor.youtube._streaming.sabr.models import (
AudioSelector, AudioSelector,
@ -319,7 +322,7 @@ def test_override_defaults(self, base_args):
assert processor.post_live is True assert processor.post_live is True
class TestStreamProtectionStatusPart: class TestStreamProtectionStatus:
@pytest.mark.parametrize( @pytest.mark.parametrize(
'sps,po_token,expected_status', 'sps,po_token,expected_status',
@ -2010,3 +2013,127 @@ def test_segment_mismatch(self, base_args):
processor.process_media_header(media_header) processor.process_media_header(media_header)
assert exc_info.value.expected_sequence_number == 11 assert exc_info.value.expected_sequence_number == 11
assert exc_info.value.received_sequence_number == 1 assert exc_info.value.received_sequence_number == 1
class TestMedia:
def test_valid_media_parts(self, base_args):
selector = make_selector('audio')
processor = SabrProcessor(
**base_args,
audio_selection=selector,
)
example_payload = b'example-data'
fim = make_format_im(selector)
processor.process_format_initialization_metadata(fim)
media_header = make_media_header(selector, sequence_no=1)
media_header.content_length = len(example_payload)
processor.process_media_header(media_header)
result = processor.process_media(
header_id=media_header.header_id,
content_length=media_header.content_length,
data=io.BytesIO(example_payload))
assert isinstance(result, ProcessMediaResult)
assert isinstance(result.sabr_part, MediaSegmentDataSabrPart)
assert result.sabr_part == MediaSegmentDataSabrPart(
format_selector=selector,
format_id=selector.format_ids[0],
sequence_number=1,
is_init_segment=False,
total_segments=fim.total_segments,
data=example_payload,
content_length=len(example_payload),
segment_start_bytes=0,
)
assert processor.partial_segments[media_header.header_id].received_data_length == len(example_payload)
# Subsequent call should increment received data length
result = processor.process_media(
header_id=media_header.header_id,
content_length=media_header.content_length,
data=io.BytesIO(example_payload))
assert isinstance(result, ProcessMediaResult)
assert isinstance(result.sabr_part, MediaSegmentDataSabrPart)
assert result.sabr_part == MediaSegmentDataSabrPart(
format_selector=selector,
format_id=selector.format_ids[0],
sequence_number=1,
is_init_segment=False,
total_segments=fim.total_segments,
data=example_payload,
content_length=len(example_payload),
segment_start_bytes=len(example_payload),
)
assert processor.partial_segments[media_header.header_id].received_data_length == len(example_payload) * 2
def test_no_matching_partial_segment(self, base_args):
# Should raise an error if no matching partial segment found
selector = make_selector('audio')
processor = SabrProcessor(
**base_args,
audio_selection=selector,
)
with pytest.raises(SabrStreamError, match='Header ID 12345 not found in partial segments'):
processor.process_media(
header_id=12345, # Non-existent header ID
content_length=100,
data=io.BytesIO(b'example-data'),
)
def test_discarded_partial_segment(self, base_args):
# Should ignore the media part if the segment is marked as discard
selector = make_selector('audio', discard_media=True)
processor = SabrProcessor(
**base_args,
audio_selection=selector,
)
example_payload = b'example-data'
fim = make_format_im(selector)
processor.process_format_initialization_metadata(fim)
media_header = make_media_header(selector, sequence_no=1)
media_header.content_length = len(example_payload)
processor.process_media_header(media_header)
result = processor.process_media(
header_id=media_header.header_id,
content_length=media_header.content_length,
data=io.BytesIO(example_payload))
assert isinstance(result, ProcessMediaResult)
assert result.sabr_part is None
assert processor.partial_segments[media_header.header_id].received_data_length == len(example_payload)
def test_valid_init_segment(self, base_args):
# Should process init segment correctly and report as such
selector = make_selector('audio')
processor = SabrProcessor(
**base_args,
audio_selection=selector,
)
example_payload = b'example-init-data'
fim = make_format_im(selector)
processor.process_format_initialization_metadata(fim)
media_header = make_init_header(selector)
processor.process_media_header(media_header)
result = processor.process_media(
header_id=media_header.header_id,
content_length=len(example_payload),
data=io.BytesIO(example_payload))
assert isinstance(result, ProcessMediaResult)
assert isinstance(result.sabr_part, MediaSegmentDataSabrPart)
assert result.sabr_part == MediaSegmentDataSabrPart(
format_selector=selector,
format_id=selector.format_ids[0],
sequence_number=None,
is_init_segment=True, # Init segment should be True
total_segments=fim.total_segments,
data=example_payload,
content_length=len(example_payload),
segment_start_bytes=0,
)
assert processor.partial_segments[media_header.header_id].received_data_length == len(example_payload)

View File

@ -35,10 +35,10 @@ class MediaSegmentDataSabrPart(SabrPart):
format_id: FormatId format_id: FormatId
sequence_number: int | None = None sequence_number: int | None = None
is_init_segment: bool = False is_init_segment: bool = False
total_segments: int = None total_segments: int | None = None
data: bytes = b'' data: bytes = b''
content_length: int = None content_length: int | None = None
segment_start_bytes: int = None segment_start_bytes: int | None = None
@dataclasses.dataclass @dataclasses.dataclass

View File

@ -330,7 +330,7 @@ def process_media(self, header_id: int, content_length: int, data: io.BufferedIO
segment = self.partial_segments.get(header_id) segment = self.partial_segments.get(header_id)
if not segment: if not segment:
self.logger.debug(f'Header ID {header_id} not found') self.logger.debug(f'Header ID {header_id} not found')
return result raise SabrStreamError(f'Header ID {header_id} not found in partial segments')
segment_start_bytes = segment.received_data_length segment_start_bytes = segment.received_data_length
segment.received_data_length += content_length segment.received_data_length += content_length