diff --git a/test/test_sabr/test_processor.py b/test/test_sabr/test_processor.py index ff8eda1e24..5849b07abb 100644 --- a/test/test_sabr/test_processor.py +++ b/test/test_sabr/test_processor.py @@ -11,6 +11,7 @@ MediaSeekSabrPart, MediaSegmentInitSabrPart, MediaSegmentDataSabrPart, + MediaSegmentEndSabrPart, ) from yt_dlp.extractor.youtube._streaming.sabr.processor import ( @@ -21,6 +22,7 @@ ProcessSabrSeekResult, ProcessMediaHeaderResult, ProcessMediaResult, + ProcessMediaEndResult, ) from yt_dlp.extractor.youtube._streaming.sabr.models import ( AudioSelector, @@ -2165,3 +2167,357 @@ def test_valid_init_segment(self, base_args): segment_start_bytes=0, ) assert processor.partial_segments[media_header.header_id].received_data_length == len(example_payload) + + +class TestMediaEnd: + + def test_init_segment_media_end(self, base_args): + # Should process media end for init segment + selector = make_selector('audio') + processor = SabrProcessor( + **base_args, + audio_selection=selector, + ) + fim = make_format_im(selector) + processor.process_format_initialization_metadata(fim) + media_header = make_init_header(selector) + processor.process_media_header(media_header) + processor.process_media(media_header.header_id, media_header.content_length, io.BytesIO(b'example-init-data')) + segment = processor.partial_segments[media_header.header_id] + + result = processor.process_media_end(media_header.header_id) + + assert isinstance(result, ProcessMediaEndResult) + assert isinstance(result.sabr_part, MediaSegmentEndSabrPart) + assert result.sabr_part == MediaSegmentEndSabrPart( + format_selector=selector, + format_id=selector.format_ids[0], + sequence_number=None, + is_init_segment=True, + total_segments=fim.total_segments, + ) + assert result.is_new_segment is True + assert media_header.header_id not in processor.partial_segments + init_format = processor.initialized_formats[str(selector.format_ids[0])] + assert init_format.init_segment is segment + assert init_format.current_segment is None + assert not init_format.consumed_ranges + + def test_media_segment_media_end(self, base_args): + # Should process media end for a regular media segment + selector = make_selector('audio') + processor = SabrProcessor( + **base_args, + audio_selection=selector, + ) + fim = make_format_im(selector) + processor.process_format_initialization_metadata(fim) + media_header = make_media_header(selector, sequence_no=1) + processor.process_media_header(media_header) + processor.process_media(media_header.header_id, media_header.content_length, io.BytesIO(b'example-data')) + segment = processor.partial_segments[media_header.header_id] + + result = processor.process_media_end(media_header.header_id) + + assert isinstance(result, ProcessMediaEndResult) + assert isinstance(result.sabr_part, MediaSegmentEndSabrPart) + assert result.sabr_part == MediaSegmentEndSabrPart( + format_selector=selector, + format_id=selector.format_ids[0], + sequence_number=1, + is_init_segment=False, + total_segments=fim.total_segments, + ) + assert result.is_new_segment is True + assert media_header.header_id not in processor.partial_segments + init_format = processor.initialized_formats[str(selector.format_ids[0])] + assert init_format.init_segment is None + assert init_format.current_segment is segment + assert len(init_format.consumed_ranges) == 1 + assert init_format.consumed_ranges[0] == ConsumedRange( + start_sequence_number=1, + end_sequence_number=1, + start_time_ms=media_header.start_ms, + duration_ms=media_header.duration_ms, + ) + + def test_media_segment_update_consumed_range(self, base_args): + # Should update an existing consumed range the segment belongs to (at the end of) + selector = make_selector('audio') + processor = SabrProcessor( + **base_args, + audio_selection=selector, + ) + fim = make_format_im(selector) + processor.process_format_initialization_metadata(fim) + init_format = processor.initialized_formats[str(selector.format_ids[0])] + init_format.consumed_ranges.append( + ConsumedRange( + start_sequence_number=1, + end_sequence_number=3, + start_time_ms=20, + duration_ms=3000)) + # Unrelated consumed range + init_format.consumed_ranges.append( + ConsumedRange( + start_sequence_number=6, + end_sequence_number=10, + start_time_ms=6000, + duration_ms=3000)) + + media_header = make_media_header(selector, sequence_no=4) + media_header.start_ms = 3021 + media_header.duration_ms = 1050 + processor.process_media_header(media_header) + processor.process_media(media_header.header_id, media_header.content_length, io.BytesIO(b'example-data')) + segment = processor.partial_segments[media_header.header_id] + + result = processor.process_media_end(media_header.header_id) + + assert isinstance(result, ProcessMediaEndResult) + assert isinstance(result.sabr_part, MediaSegmentEndSabrPart) + assert result.sabr_part == MediaSegmentEndSabrPart( + format_selector=selector, + format_id=selector.format_ids[0], + sequence_number=4, + is_init_segment=False, + total_segments=fim.total_segments, + ) + assert result.is_new_segment is True + assert len(init_format.consumed_ranges) == 2 + assert init_format.consumed_ranges[0] == ConsumedRange( + start_sequence_number=1, + end_sequence_number=4, + start_time_ms=20, + duration_ms=4051, + ) + assert init_format.consumed_ranges[1] == ConsumedRange( + start_sequence_number=6, + end_sequence_number=10, + start_time_ms=6000, + duration_ms=3000, + ) + assert init_format.current_segment is segment + + def test_media_segment_discard(self, base_args): + # Should discard the segment if it is marked as discard. Consumed ranges should be updated. + selector = make_selector('audio', discard_media=True) + processor = SabrProcessor( + **base_args, + audio_selection=selector, + ) + fim = make_format_im(selector) + processor.process_format_initialization_metadata(fim) + # Clear consumed ranges. We want to also handle the case when we cannot mark the format as entirely consumed. + processor.initialized_formats[str(selector.format_ids[0])].consumed_ranges.clear() + media_header = make_media_header(selector, sequence_no=1) + processor.process_media_header(media_header) + processor.process_media(media_header.header_id, media_header.content_length, io.BytesIO(b'example-data')) + segment = processor.partial_segments[media_header.header_id] + + result = processor.process_media_end(media_header.header_id) + + assert isinstance(result, ProcessMediaEndResult) + assert result.sabr_part is None + # New segment created, but discarded. Not previously consumed. + assert result.is_new_segment is True + assert media_header.header_id not in processor.partial_segments + init_format = processor.initialized_formats[str(selector.format_ids[0])] + assert init_format.init_segment is None + assert init_format.current_segment is segment + assert len(init_format.consumed_ranges) == 1 + + def test_init_segment_discard(self, base_args): + # Should discard the init segment if it is marked as discard. + selector = make_selector('audio', discard_media=True) + processor = SabrProcessor( + **base_args, + audio_selection=selector, + ) + fim = make_format_im(selector) + processor.process_format_initialization_metadata(fim) + # Clear consumed ranges. We want to also handle the case when we cannot mark the format as entirely consumed. + processor.initialized_formats[str(selector.format_ids[0])].consumed_ranges.clear() + media_header = make_init_header(selector) + processor.process_media_header(media_header) + processor.process_media(media_header.header_id, media_header.content_length, io.BytesIO(b'example-init-data')) + segment = processor.partial_segments[media_header.header_id] + + result = processor.process_media_end(media_header.header_id) + + assert isinstance(result, ProcessMediaEndResult) + assert result.sabr_part is None + # New segment created, but discarded. Not previously consumed. + assert result.is_new_segment is True + assert media_header.header_id not in processor.partial_segments + init_format = processor.initialized_formats[str(selector.format_ids[0])] + assert init_format.init_segment is segment + assert init_format.current_segment is None + assert len(init_format.consumed_ranges) == 0 + + def test_media_segment_consumed(self, base_args): + # Should mark the segment as consumed (and discard) if it is already in consumed ranges + selector = make_selector('audio') + processor = SabrProcessor( + **base_args, + audio_selection=selector, + ) + fim = make_format_im(selector) + processor.process_format_initialization_metadata(fim) + media_header = make_media_header(selector, sequence_no=1) + # Simulate that the segment is already consumed + processor.initialized_formats[str(selector.format_ids[0])].consumed_ranges.append( + ConsumedRange( + start_sequence_number=1, + end_sequence_number=2, + start_time_ms=media_header.start_ms, + duration_ms=media_header.duration_ms + 500, + )) + processor.process_media_header(media_header) + processor.process_media(media_header.header_id, media_header.content_length, io.BytesIO(b'example-data')) + segment = processor.partial_segments[media_header.header_id] + + result = processor.process_media_end(media_header.header_id) + + assert isinstance(result, ProcessMediaEndResult) + assert result.sabr_part is None + assert result.is_new_segment is False + assert media_header.header_id not in processor.partial_segments + init_format = processor.initialized_formats[str(selector.format_ids[0])] + assert init_format.init_segment is None + assert init_format.current_segment is segment + assert len(init_format.consumed_ranges) == 1 + assert init_format.consumed_ranges[0] == ConsumedRange( + start_sequence_number=1, + end_sequence_number=2, + start_time_ms=media_header.start_ms, + duration_ms=media_header.duration_ms + 500, + ) + + def test_init_segment_consumed(self, base_args): + # Should mark the init segment as consumed (and discard) if already seen init segment + selector = make_selector('audio') + processor = SabrProcessor( + **base_args, + audio_selection=selector, + ) + fim = make_format_im(selector) + processor.process_format_initialization_metadata(fim) + media_header = make_init_header(selector) + # Simulate that the init segment is already consumed + processor.initialized_formats[str(selector.format_ids[0])].init_segment = Segment( + is_init_segment=True, + format_id=selector.format_ids[0], + sequence_number=None, + ) + processor.process_media_header(media_header) + processor.process_media(media_header.header_id, media_header.content_length, io.BytesIO(b'example-init-data')) + segment = processor.partial_segments[media_header.header_id] + + result = processor.process_media_end(media_header.header_id) + + assert isinstance(result, ProcessMediaEndResult) + assert result.sabr_part is None + assert result.is_new_segment is False + assert media_header.header_id not in processor.partial_segments + init_format = processor.initialized_formats[str(selector.format_ids[0])] + assert init_format.init_segment is segment # xxx: Should we be taking the new init segment we discarded? + assert init_format.current_segment is None + assert len(init_format.consumed_ranges) == 0 + + def test_media_end_no_content_length(self, base_args): + # Should not raise an error if segment does not have a content length + selector = make_selector('audio') + processor = SabrProcessor( + **base_args, + audio_selection=selector, + ) + fim = make_format_im(selector) + processor.process_format_initialization_metadata(fim) + media_header = make_media_header(selector, sequence_no=1) + media_header.content_length = None + processor.process_media_header(media_header) + processor.process_media(media_header.header_id, 500, io.BytesIO(b'example-data')) + + result = processor.process_media_end(media_header.header_id) + + assert isinstance(result, ProcessMediaEndResult) + assert isinstance(result.sabr_part, MediaSegmentEndSabrPart) + + def test_media_end_no_partial_segment(self, base_args): + # Should raise an error if no partial segment found for the header_id + selector = make_selector('audio') + processor = SabrProcessor( + **base_args, + audio_selection=selector, + ) + fim = make_format_im(selector) + processor.process_format_initialization_metadata(fim) + + with pytest.raises(SabrStreamError, match='Header ID 12345 not found in partial segments'): + processor.process_media_end(12345) + + def test_content_length_mismatch(self, base_args): + # Should raise an error if content length does not match the expected length + selector = make_selector('audio') + processor = SabrProcessor( + **base_args, + audio_selection=selector, + ) + fim = make_format_im(selector) + processor.process_format_initialization_metadata(fim) + media_header = make_media_header(selector, sequence_no=1) + processor.process_media_header(media_header) + + with pytest.raises(SabrStreamError, match='Content length mismatch'): + processor.process_media_end(media_header.header_id) + + assert media_header.header_id not in processor.partial_segments + init_format = processor.initialized_formats[str(selector.format_ids[0])] + assert init_format.init_segment is None + assert init_format.current_segment is None + + def test_estimated_content_length_mismatch(self, base_args, logger): + # Should not raise an error if estimated content length does not match, rather log in trace + selector = make_selector('audio') + processor = SabrProcessor( + **base_args, + audio_selection=selector, + ) + processor.is_live = True + fim = make_format_im(selector) + processor.process_format_initialization_metadata(fim) + media_header = make_media_header(selector, sequence_no=1) + media_header.content_length = None + media_header.bitrate_bps = 1000000 + media_header.duration_ms = 4000 + processor.process_media_header(media_header) + processor.process_media( + media_header.header_id, + content_length=500, # Mismatch between what is estimated + data=io.BytesIO(b'example-data'), + ) + segment = processor.partial_segments[media_header.header_id] + + result = processor.process_media_end(media_header.header_id) + + assert isinstance(result, ProcessMediaEndResult) + assert isinstance(result.sabr_part, MediaSegmentEndSabrPart) + assert result.sabr_part == MediaSegmentEndSabrPart( + format_selector=selector, + format_id=selector.format_ids[0], + sequence_number=1, + is_init_segment=False, + total_segments=fim.total_segments, + ) + assert result.is_new_segment is True + assert media_header.header_id not in processor.partial_segments + init_format = processor.initialized_formats[str(selector.format_ids[0])] + assert init_format.init_segment is None + assert init_format.current_segment is segment + assert init_format.current_segment.content_length_estimated is True + assert init_format.current_segment.content_length == 4000000 + assert init_format.current_segment.received_data_length == 500 + logger.trace.assert_called_with( + f'Content length for {segment.format_id} (sequence 1) was estimated, ' + f'estimated {segment.content_length} bytes, got {segment.received_data_length} bytes') diff --git a/yt_dlp/extractor/youtube/_streaming/sabr/processor.py b/yt_dlp/extractor/youtube/_streaming/sabr/processor.py index 28b04a8879..f8cd0a5b06 100644 --- a/yt_dlp/extractor/youtube/_streaming/sabr/processor.py +++ b/yt_dlp/extractor/youtube/_streaming/sabr/processor.py @@ -44,7 +44,7 @@ class ProcessMediaEndResult: def __init__(self, sabr_part: MediaSegmentEndSabrPart = None, is_new_segment: bool = False): - self.is_new_segment = is_new_segment + self.is_new_segment = is_new_segment # TODO: better name self.sabr_part = sabr_part @@ -355,10 +355,8 @@ def process_media_end(self, header_id: int) -> ProcessMediaEndResult: result = ProcessMediaEndResult() segment = self.partial_segments.pop(header_id, None) if not segment: - # Should only happen due to server issue, - # or we have an uninitialized format (which itself should not happen) - self.logger.warning(f'Received a MediaEnd for an unknown or already finished header ID {header_id}') - return result + self.logger.debug(f'Header ID {header_id} not found') + raise SabrStreamError(f'Header ID {header_id} not found in partial segments') self.logger.trace( f'MediaEnd for {segment.format_id} (sequence {segment.sequence_number}, data length = {segment.received_data_length})') @@ -374,10 +372,9 @@ def process_media_end(self, header_id: int) -> ProcessMediaEndResult: f'expected {segment.content_length} bytes, got {segment.received_data_length} bytes', ) - # Only count received segments as new segments if they are not discarded (consumed) - # or it was part of a format that was discarded (but not consumed). - # The latter can happen if the format is to be discarded but was not marked as fully consumed. - if not segment.discard or (segment.initialized_format.discard and not segment.consumed): + # Only count received segments as new segments if they are not consumed. + # Discarded segments that are not consumed are considered new segments. + if not segment.consumed: result.is_new_segment = True # Return the segment here instead of during MEDIA part(s) because: @@ -409,22 +406,19 @@ def process_media_end(self, header_id: int) -> ProcessMediaEndResult: segment.initialized_format.current_segment = segment - # Try to find a consumed range for this segment in sequence - consumed_range = next( - (cr for cr in segment.initialized_format.consumed_ranges if cr.end_sequence_number == segment.sequence_number - 1), - None, - ) - - if not consumed_range and any( - cr.start_sequence_number <= segment.sequence_number <= cr.end_sequence_number - for cr in segment.initialized_format.consumed_ranges - ): + if segment.consumed: # Segment is already consumed, do not create a new consumed range. It was probably discarded. # This can be expected to happen in the case of video-only, where we discard the audio track (and mark it as entirely buffered) # We still want to create/update consumed range for discarded media IF it is not already consumed self.logger.debug(f'{segment.format_id} segment {segment.sequence_number} already consumed, not creating or updating consumed range (discard={segment.discard})') return result + # Try to find a consumed range for this segment in sequence + consumed_range = next( + (cr for cr in segment.initialized_format.consumed_ranges if cr.end_sequence_number == segment.sequence_number - 1), + None, + ) + if not consumed_range: # Create a new consumed range starting from this segment segment.initialized_format.consumed_ranges.append(ConsumedRange( @@ -611,7 +605,7 @@ def process_sabr_context_update(self, sabr_ctx_update: SabrContextUpdate): 'This may cause issues with playback.') self.sabr_context_updates[sabr_ctx_update.type] = sabr_ctx_update - if sabr_ctx_update.send_by_default is True: + if sabr_ctx_update.send_by_default: self.sabr_contexts_to_send.add(sabr_ctx_update.type) self.logger.debug(f'Registered SabrContextUpdate {sabr_ctx_update}')