mirror of
https://github.com/yt-dlp/yt-dlp.git
synced 2025-08-14 16:38:29 +00:00
[test] media part tests
This commit is contained in:
parent
eaec0c84bf
commit
a9b1c25ddc
@ -1,4 +1,5 @@
|
||||
import dataclasses
|
||||
import io
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
@ -9,6 +10,7 @@
|
||||
FormatInitializedSabrPart,
|
||||
MediaSeekSabrPart,
|
||||
MediaSegmentInitSabrPart,
|
||||
MediaSegmentDataSabrPart,
|
||||
)
|
||||
|
||||
from yt_dlp.extractor.youtube._streaming.sabr.processor import (
|
||||
@ -18,6 +20,7 @@
|
||||
ProcessLiveMetadataResult,
|
||||
ProcessSabrSeekResult,
|
||||
ProcessMediaHeaderResult,
|
||||
ProcessMediaResult,
|
||||
)
|
||||
from yt_dlp.extractor.youtube._streaming.sabr.models import (
|
||||
AudioSelector,
|
||||
@ -319,7 +322,7 @@ def test_override_defaults(self, base_args):
|
||||
assert processor.post_live is True
|
||||
|
||||
|
||||
class TestStreamProtectionStatusPart:
|
||||
class TestStreamProtectionStatus:
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'sps,po_token,expected_status',
|
||||
@ -2010,3 +2013,127 @@ def test_segment_mismatch(self, base_args):
|
||||
processor.process_media_header(media_header)
|
||||
assert exc_info.value.expected_sequence_number == 11
|
||||
assert exc_info.value.received_sequence_number == 1
|
||||
|
||||
|
||||
class TestMedia:
|
||||
def test_valid_media_parts(self, base_args):
|
||||
selector = make_selector('audio')
|
||||
processor = SabrProcessor(
|
||||
**base_args,
|
||||
audio_selection=selector,
|
||||
)
|
||||
|
||||
example_payload = b'example-data'
|
||||
fim = make_format_im(selector)
|
||||
processor.process_format_initialization_metadata(fim)
|
||||
media_header = make_media_header(selector, sequence_no=1)
|
||||
media_header.content_length = len(example_payload)
|
||||
processor.process_media_header(media_header)
|
||||
|
||||
result = processor.process_media(
|
||||
header_id=media_header.header_id,
|
||||
content_length=media_header.content_length,
|
||||
data=io.BytesIO(example_payload))
|
||||
|
||||
assert isinstance(result, ProcessMediaResult)
|
||||
assert isinstance(result.sabr_part, MediaSegmentDataSabrPart)
|
||||
assert result.sabr_part == MediaSegmentDataSabrPart(
|
||||
format_selector=selector,
|
||||
format_id=selector.format_ids[0],
|
||||
sequence_number=1,
|
||||
is_init_segment=False,
|
||||
total_segments=fim.total_segments,
|
||||
data=example_payload,
|
||||
content_length=len(example_payload),
|
||||
segment_start_bytes=0,
|
||||
)
|
||||
assert processor.partial_segments[media_header.header_id].received_data_length == len(example_payload)
|
||||
|
||||
# Subsequent call should increment received data length
|
||||
result = processor.process_media(
|
||||
header_id=media_header.header_id,
|
||||
content_length=media_header.content_length,
|
||||
data=io.BytesIO(example_payload))
|
||||
|
||||
assert isinstance(result, ProcessMediaResult)
|
||||
assert isinstance(result.sabr_part, MediaSegmentDataSabrPart)
|
||||
assert result.sabr_part == MediaSegmentDataSabrPart(
|
||||
format_selector=selector,
|
||||
format_id=selector.format_ids[0],
|
||||
sequence_number=1,
|
||||
is_init_segment=False,
|
||||
total_segments=fim.total_segments,
|
||||
data=example_payload,
|
||||
content_length=len(example_payload),
|
||||
segment_start_bytes=len(example_payload),
|
||||
)
|
||||
assert processor.partial_segments[media_header.header_id].received_data_length == len(example_payload) * 2
|
||||
|
||||
def test_no_matching_partial_segment(self, base_args):
|
||||
# Should raise an error if no matching partial segment found
|
||||
selector = make_selector('audio')
|
||||
processor = SabrProcessor(
|
||||
**base_args,
|
||||
audio_selection=selector,
|
||||
)
|
||||
with pytest.raises(SabrStreamError, match='Header ID 12345 not found in partial segments'):
|
||||
processor.process_media(
|
||||
header_id=12345, # Non-existent header ID
|
||||
content_length=100,
|
||||
data=io.BytesIO(b'example-data'),
|
||||
)
|
||||
|
||||
def test_discarded_partial_segment(self, base_args):
|
||||
# Should ignore the media part if the segment is marked as discard
|
||||
selector = make_selector('audio', discard_media=True)
|
||||
processor = SabrProcessor(
|
||||
**base_args,
|
||||
audio_selection=selector,
|
||||
)
|
||||
example_payload = b'example-data'
|
||||
fim = make_format_im(selector)
|
||||
processor.process_format_initialization_metadata(fim)
|
||||
media_header = make_media_header(selector, sequence_no=1)
|
||||
media_header.content_length = len(example_payload)
|
||||
processor.process_media_header(media_header)
|
||||
|
||||
result = processor.process_media(
|
||||
header_id=media_header.header_id,
|
||||
content_length=media_header.content_length,
|
||||
data=io.BytesIO(example_payload))
|
||||
|
||||
assert isinstance(result, ProcessMediaResult)
|
||||
assert result.sabr_part is None
|
||||
assert processor.partial_segments[media_header.header_id].received_data_length == len(example_payload)
|
||||
|
||||
def test_valid_init_segment(self, base_args):
|
||||
# Should process init segment correctly and report as such
|
||||
selector = make_selector('audio')
|
||||
processor = SabrProcessor(
|
||||
**base_args,
|
||||
audio_selection=selector,
|
||||
)
|
||||
example_payload = b'example-init-data'
|
||||
fim = make_format_im(selector)
|
||||
processor.process_format_initialization_metadata(fim)
|
||||
media_header = make_init_header(selector)
|
||||
processor.process_media_header(media_header)
|
||||
|
||||
result = processor.process_media(
|
||||
header_id=media_header.header_id,
|
||||
content_length=len(example_payload),
|
||||
data=io.BytesIO(example_payload))
|
||||
|
||||
assert isinstance(result, ProcessMediaResult)
|
||||
assert isinstance(result.sabr_part, MediaSegmentDataSabrPart)
|
||||
assert result.sabr_part == MediaSegmentDataSabrPart(
|
||||
format_selector=selector,
|
||||
format_id=selector.format_ids[0],
|
||||
sequence_number=None,
|
||||
is_init_segment=True, # Init segment should be True
|
||||
total_segments=fim.total_segments,
|
||||
data=example_payload,
|
||||
content_length=len(example_payload),
|
||||
segment_start_bytes=0,
|
||||
)
|
||||
assert processor.partial_segments[media_header.header_id].received_data_length == len(example_payload)
|
||||
|
@ -35,10 +35,10 @@ class MediaSegmentDataSabrPart(SabrPart):
|
||||
format_id: FormatId
|
||||
sequence_number: int | None = None
|
||||
is_init_segment: bool = False
|
||||
total_segments: int = None
|
||||
total_segments: int | None = None
|
||||
data: bytes = b''
|
||||
content_length: int = None
|
||||
segment_start_bytes: int = None
|
||||
content_length: int | None = None
|
||||
segment_start_bytes: int | None = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -330,7 +330,7 @@ def process_media(self, header_id: int, content_length: int, data: io.BufferedIO
|
||||
segment = self.partial_segments.get(header_id)
|
||||
if not segment:
|
||||
self.logger.debug(f'Header ID {header_id} not found')
|
||||
return result
|
||||
raise SabrStreamError(f'Header ID {header_id} not found in partial segments')
|
||||
|
||||
segment_start_bytes = segment.received_data_length
|
||||
segment.received_data_length += content_length
|
||||
|
Loading…
Reference in New Issue
Block a user