diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e2411ecfa..8012ebb8c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -208,7 +208,7 @@ jobs: python3.9 -m pip install -U pip wheel 'setuptools>=71.0.2' # XXX: Keep this in sync with pyproject.toml (it can't be accessed at this stage) and exclude secretstorage python3.9 -m pip install -U Pyinstaller mutagen pycryptodomex brotli certifi cffi \ - 'requests>=2.32.2,<3' 'urllib3>=1.26.17,<3' 'websockets>=13.0' + 'requests>=2.32.2,<3' 'urllib3>=1.26.17,<3' 'websockets>=13.0' 'protobug==0.3.0' run: | cd repo diff --git a/README.md b/README.md index 0f9a7d556..9a1057db4 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,7 @@ ### Metadata ### Misc +* [**protobug**](https://github.com/yt-dlp/protobug)\* - Protobuf library, for serializing and deserializing protobuf data. Licensed under [Unlicense](https://github.com/yt-dlp/protobug/blob/main/LICENSE) * [**pycryptodomex**](https://github.com/Legrandin/pycryptodome)\* - For decrypting AES-128 HLS streams and various other data. Licensed under [BSD-2-Clause](https://github.com/Legrandin/pycryptodome/blob/master/LICENSE.rst) * [**phantomjs**](https://github.com/ariya/phantomjs) - Used in extractors where javascript needs to be run. Licensed under [BSD-3-Clause](https://github.com/ariya/phantomjs/blob/master/LICENSE.BSD) * [**secretstorage**](https://github.com/mitya57/secretstorage)\* - For `--cookies-from-browser` to access the **Gnome** keyring while decrypting cookies of **Chromium**-based browsers on **Linux**. Licensed under [BSD-3-Clause](https://github.com/mitya57/secretstorage/blob/master/LICENSE) @@ -1808,6 +1809,7 @@ #### youtube * `innertube_host`: Innertube API host to use for all API requests; e.g. `studio.youtube.com`, `youtubei.googleapis.com`. Note that cookies exported from one subdomain will not work on others * `innertube_key`: Innertube API key to use for all API requests. By default, no API key is used * `raise_incomplete_data`: `Incomplete Data Received` raises an error instead of reporting a warning +* `sabr_log_level`: Set the log level of the SABR downloader. One of `TRACE`, `DEBUG` or `INFO` (default) * `data_sync_id`: Overrides the account Data Sync ID used in Innertube API requests. This may be needed if you are using an account with `youtube:player_skip=webpage,configs` or `youtubetab:skip=webpage` * `visitor_data`: Overrides the Visitor Data used in Innertube API requests. This should be used with `player_skip=webpage,configs` and without cookies. Note: this may have adverse effects if used improperly. If a session from a browser is wanted, you should pass cookies instead (which contain the Visitor ID) * `po_token`: Proof of Origin (PO) Token(s) to use. Comma seperated list of PO Tokens in the format `CLIENT.CONTEXT+PO_TOKEN`, e.g. `youtube:po_token=web.gvs+XXX,web.player=XXX,web_safari.gvs+YYY`. 
Context can be any of `gvs` (Google Video Server URLs), `player` (Innertube player request) or `subs` (Subtitles) diff --git a/pyproject.toml b/pyproject.toml index 3775251e1..af7543fe8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ default = [ "requests>=2.32.2,<3", "urllib3>=1.26.17,<3", "websockets>=13.0", + "protobug==0.3.0", ] curl-cffi = [ "curl-cffi>=0.5.10,!=0.6.*,!=0.7.*,!=0.8.*,!=0.9.*,<0.11; implementation_name=='cpython'", diff --git a/test/test_sabr/test_ump.py b/test/test_sabr/test_ump.py new file mode 100644 index 000000000..38514c216 --- /dev/null +++ b/test/test_sabr/test_ump.py @@ -0,0 +1,102 @@ +import io +import pytest + +from yt_dlp.extractor.youtube._streaming.ump import varint_size, read_varint, UMPDecoder, UMPPartId + + +@pytest.mark.parametrize('data, expected', [ + (0x01, 1), + (0x4F, 1), + (0x80, 2), + (0xBF, 2), + (0xC0, 3), + (0xDF, 3), + (0xE0, 4), + (0xEF, 4), + (0xF0, 5), + (0xFF, 5), +]) +def test_varint_size(data, expected): + assert varint_size(data) == expected + + +@pytest.mark.parametrize('data, expected', [ + # 1 byte long varint + (b'\x01', 1), + (b'\x4F', 79), + # 2 byte long varint + (b'\x80\x01', 64), + (b'\x8A\x7F', 8138), + (b'\xBF\x7F', 8191), + # 3 byte long varint + (b'\xC0\x80\x01', 12288), + (b'\xDF\x7F\xFF', 2093055), + # 4 byte long varint + (b'\xE0\x80\x80\x01', 1574912), + (b'\xEF\x7F\xFF\xFF', 268433407), + # 5 byte long varint + (b'\xF0\x80\x80\x80\x01', 25198720), + (b'\xFF\x7F\xFF\xFF\xFF', 4294967167), +], +) +def test_readvarint(data, expected): + assert read_varint(io.BytesIO(data)) == expected + + +class TestUMPDecoder: + EXAMPLE_PART_DATA = [ + { + # Part 1: Part type of 20, part size of 127 + 'part_type_bytes': b'\x14', + 'part_size_bytes': b'\x7F', + 'part_data_bytes': b'\x01' * 127, + 'part_id': UMPPartId.MEDIA_HEADER, + 'part_size': 127, + }, + # Part 2, Part type of 4294967295, part size of 0 + { + 'part_type_bytes': b'\xFF\xFF\xFF\xFF\xFF', + 'part_size_bytes': b'\x00', + 'part_data_bytes': b'', + 'part_id': UMPPartId.UNKNOWN, + 'part_size': 0, + }, + # Part 3: Part type of 21, part size of 1574912 + { + 'part_type_bytes': b'\x15', + 'part_size_bytes': b'\xE0\x80\x80\x01', + 'part_data_bytes': b'\x01' * 1574912, + 'part_id': UMPPartId.MEDIA, + 'part_size': 1574912, + }, + ] + + COMBINED_PART_DATA = b''.join(part['part_type_bytes'] + part['part_size_bytes'] + part['part_data_bytes'] for part in EXAMPLE_PART_DATA) + + def test_iter_parts(self): + # Create a mock file-like object + mock_file = io.BytesIO(self.COMBINED_PART_DATA) + + # Create an instance of UMPDecoder with the mock file + decoder = UMPDecoder(mock_file) + + # Iterate over the parts and check the values + for idx, part in enumerate(decoder.iter_parts()): + assert part.part_id == self.EXAMPLE_PART_DATA[idx]['part_id'] + assert part.size == self.EXAMPLE_PART_DATA[idx]['part_size'] + assert part.data.read() == self.EXAMPLE_PART_DATA[idx]['part_data_bytes'] + + assert mock_file.closed + + def test_unexpected_eof(self): + # Unexpected bytes at the end of the file + mock_file = io.BytesIO(self.COMBINED_PART_DATA + b'\x00') + decoder = UMPDecoder(mock_file) + + # Iterate over the parts and check the values + with pytest.raises(EOFError, match='Unexpected EOF while reading part size'): + for idx, part in enumerate(decoder.iter_parts()): + assert part.part_id == self.EXAMPLE_PART_DATA[idx]['part_id'] + part.data.read() + + assert mock_file.closed diff --git a/test/test_sabr/test_utils.py b/test/test_sabr/test_utils.py new file mode 100644 index 
000000000..4224356c0 --- /dev/null +++ b/test/test_sabr/test_utils.py @@ -0,0 +1,23 @@ +import pytest +from yt_dlp.extractor.youtube._streaming.sabr.utils import ticks_to_ms, broadcast_id_from_url + + +@pytest.mark.parametrize( + 'ticks, timescale, expected_ms', + [ + (1000, 1000, 1000), + (5000, 10000, 500), + (234234, 44100, 5312), + (1, 1, 1000), + (None, 1000, None), + (1000, None, None), + (None, None, None), + ], +) +def test_ticks_to_ms(ticks, timescale, expected_ms): + assert ticks_to_ms(ticks, timescale) == expected_ms + + +def test_broadcast_id_from_url(): + assert broadcast_id_from_url('https://example.com/path?other=param&id=example.1~243&other2=param2') == 'example.1~243' + assert broadcast_id_from_url('https://example.com/path?other=param&other2=param2') is None diff --git a/yt_dlp/__pyinstaller/hook-yt_dlp.py b/yt_dlp/__pyinstaller/hook-yt_dlp.py index 8e7f42f59..f416fdb55 100644 --- a/yt_dlp/__pyinstaller/hook-yt_dlp.py +++ b/yt_dlp/__pyinstaller/hook-yt_dlp.py @@ -25,7 +25,7 @@ def get_hidden_imports(): for module in ('websockets', 'requests', 'urllib3'): yield from collect_submodules(module) # These are auto-detected, but explicitly add them just in case - yield from ('mutagen', 'brotli', 'certifi', 'secretstorage', 'curl_cffi') + yield from ('mutagen', 'brotli', 'certifi', 'secretstorage', 'curl_cffi', 'protobug') hiddenimports = list(get_hidden_imports()) diff --git a/yt_dlp/dependencies/__init__.py b/yt_dlp/dependencies/__init__.py index 0d58da2bd..a40d6c3c2 100644 --- a/yt_dlp/dependencies/__init__.py +++ b/yt_dlp/dependencies/__init__.py @@ -79,6 +79,11 @@ except ImportError: curl_cffi = None +try: + import protobug +except ImportError: + protobug = None + from . import Cryptodome all_dependencies = {k: v for k, v in globals().items() if not k.startswith('_')} diff --git a/yt_dlp/downloader/__init__.py b/yt_dlp/downloader/__init__.py index 9c34bd289..8f575ece4 100644 --- a/yt_dlp/downloader/__init__.py +++ b/yt_dlp/downloader/__init__.py @@ -15,6 +15,10 @@ def get_suitable_downloader(info_dict, params={}, default=NO_DEFAULT, protocol=N and not (to_stdout and len(protocols) > 1) and set(protocols) == {'http_dash_segments_generator'}): return DashSegmentsFD + elif SabrFD is not None and set(downloaders) == {SabrFD} and SabrFD.can_download(info_copy): + # NOTE: there may be one or more SABR downloaders for this info_dict, + # as SABR can download multiple formats at once. 
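Backtracking to the UMP varint vectors in `test_ump.py` above: this is not the standard protobuf varint. The first byte encodes the total size unary-style in its high bits, sizes 2-4 keep a few low value bits in the first byte, and the 5-byte form falls back to a plain little-endian u32. A sketch consistent with those test vectors (the real `ump.py` may differ in details; EOF handling is omitted):

```python
import io


def varint_size(first_byte: int) -> int:
    # Size is encoded in the high bits of the first byte:
    # 0xxxxxxx -> 1 byte, 10xxxxxx -> 2, 110xxxxx -> 3, 1110xxxx -> 4, 1111xxxx -> 5
    if first_byte < 0x80:
        return 1
    if first_byte < 0xC0:
        return 2
    if first_byte < 0xE0:
        return 3
    if first_byte < 0xF0:
        return 4
    return 5


def read_varint(stream: io.BufferedIOBase) -> int:
    first = stream.read(1)[0]
    size = varint_size(first)
    rest = stream.read(size - 1)
    if size == 1:
        return first
    if size == 5:
        # The first byte carries no value bits; the payload is a little-endian u32
        return int.from_bytes(rest, 'little')
    # Sizes 2, 3 and 4 keep 6, 5 and 4 low value bits in the first byte;
    # each continuation byte contributes 8 bits, least-significant first
    low_bits = {2: 6, 3: 5, 4: 4}[size]
    value = first & ((1 << low_bits) - 1)
    for i, byte in enumerate(rest):
        value |= byte << (low_bits + 8 * i)
    return value


assert read_varint(io.BytesIO(b'\x8A\x7F')) == 8138
assert read_varint(io.BytesIO(b'\xE0\x80\x80\x01')) == 1574912
```

The downloader-selection hunk continues below.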
+ return SabrFD elif len(downloaders) == 1: return downloaders[0] return None @@ -36,6 +40,7 @@ def get_suitable_downloader(info_dict, params={}, default=NO_DEFAULT, protocol=N from .websocket import WebSocketFragmentFD from .youtube_live_chat import YoutubeLiveChatFD from .bunnycdn import BunnyCdnFD +from .sabr import SabrFD PROTOCOL_MAP = { 'rtmp': RtmpFD, @@ -56,6 +61,7 @@ def get_suitable_downloader(info_dict, params={}, default=NO_DEFAULT, protocol=N 'youtube_live_chat': YoutubeLiveChatFD, 'youtube_live_chat_replay': YoutubeLiveChatFD, 'bunnycdn': BunnyCdnFD, + 'sabr': SabrFD, } diff --git a/yt_dlp/downloader/sabr/__init__.py b/yt_dlp/downloader/sabr/__init__.py new file mode 100644 index 000000000..57655ee0f --- /dev/null +++ b/yt_dlp/downloader/sabr/__init__.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug +from yt_dlp.utils import DownloadError +from yt_dlp.downloader import FileDownloader + +if not protobug: + class SabrFD(FileDownloader): + + @classmethod + def can_download(cls, info_dict): + is_sabr = ( + info_dict.get('requested_formats') + and all( + format_info.get('protocol') == 'sabr' + for format_info in info_dict['requested_formats'])) + + if is_sabr: + raise DownloadError('SABRFD requires protobug to be installed') + + return is_sabr + +else: + from ._fd import SabrFD # noqa: F401 diff --git a/yt_dlp/downloader/sabr/_fd.py b/yt_dlp/downloader/sabr/_fd.py new file mode 100644 index 000000000..ff70d0d6a --- /dev/null +++ b/yt_dlp/downloader/sabr/_fd.py @@ -0,0 +1,330 @@ +from __future__ import annotations +import collections +import itertools + +from yt_dlp.networking.exceptions import TransportError, HTTPError +from yt_dlp.utils import traverse_obj, int_or_none, DownloadError, join_nonempty +from yt_dlp.downloader import FileDownloader + +from ._writer import SabrFDFormatWriter +from ._logger import create_sabrfd_logger + +from yt_dlp.extractor.youtube._streaming.sabr.part import ( + MediaSegmentEndSabrPart, + MediaSegmentDataSabrPart, + MediaSegmentInitSabrPart, + PoTokenStatusSabrPart, + RefreshPlayerResponseSabrPart, + MediaSeekSabrPart, + FormatInitializedSabrPart, +) +from yt_dlp.extractor.youtube._streaming.sabr.stream import SabrStream +from yt_dlp.extractor.youtube._streaming.sabr.models import ConsumedRange, AudioSelector, VideoSelector, CaptionSelector +from yt_dlp.extractor.youtube._streaming.sabr.exceptions import SabrStreamError +from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, ClientName +from yt_dlp.extractor.youtube._proto.videostreaming import FormatId + + +class SabrFD(FileDownloader): + + @classmethod + def can_download(cls, info_dict): + return ( + info_dict.get('requested_formats') + and all( + format_info.get('protocol') == 'sabr' + for format_info in info_dict['requested_formats'])) + + def _group_formats_by_client(self, filename, info_dict): + format_groups = collections.defaultdict(dict, {}) + requested_formats = info_dict.get('requested_formats') or [info_dict] + + for _idx, f in enumerate(requested_formats): + sabr_config = f.get('_sabr_config') + client_name = sabr_config.get('client_name') + client_info = sabr_config.get('client_info') + server_abr_streaming_url = f.get('url') + video_playback_ustreamer_config = sabr_config.get('video_playback_ustreamer_config') + + if not video_playback_ustreamer_config: + raise DownloadError('Video playback ustreamer config not found') + + sabr_format_group_config = format_groups.get(client_name) + + if not sabr_format_group_config: + 
sabr_format_group_config = format_groups[client_name] = { + 'server_abr_streaming_url': server_abr_streaming_url, + 'video_playback_ustreamer_config': video_playback_ustreamer_config, + 'formats': [], + 'initial_po_token': sabr_config.get('po_token'), + 'fetch_po_token_fn': fn if callable(fn := sabr_config.get('fetch_po_token_fn')) else None, + 'reload_config_fn': fn if callable(fn := sabr_config.get('reload_config_fn')) else None, + 'live_status': sabr_config.get('live_status'), + 'video_id': sabr_config.get('video_id'), + 'client_info': ClientInfo( + client_name=traverse_obj(client_info, ('clientName', {lambda x: ClientName[x]})), + client_version=traverse_obj(client_info, 'clientVersion'), + os_version=traverse_obj(client_info, 'osVersion'), + os_name=traverse_obj(client_info, 'osName'), + device_model=traverse_obj(client_info, 'deviceModel'), + device_make=traverse_obj(client_info, 'deviceMake'), + ), + 'target_duration_sec': sabr_config.get('target_duration_sec'), + # Number.MAX_SAFE_INTEGER + 'start_time_ms': ((2**53) - 1) if info_dict.get('live_status') == 'is_live' and not f.get('is_from_start') else 0, + } + + else: + if sabr_format_group_config['server_abr_streaming_url'] != server_abr_streaming_url: + raise DownloadError('Server ABR streaming URL mismatch') + + if sabr_format_group_config['video_playback_ustreamer_config'] != video_playback_ustreamer_config: + raise DownloadError('Video playback ustreamer config mismatch') + + itag = int_or_none(sabr_config.get('itag')) + sabr_format_group_config['formats'].append({ + 'display_name': f.get('format_id'), + 'format_id': itag and FormatId( + itag=itag, lmt=int_or_none(sabr_config.get('last_modified')), xtags=sabr_config.get('xtags')), + 'format_type': format_type(f), + 'quality': sabr_config.get('quality'), + 'height': sabr_config.get('height'), + 'filename': f.get('filepath', filename), + 'info_dict': f, + }) + + return format_groups + + def real_download(self, filename, info_dict): + format_groups = self._group_formats_by_client(filename, info_dict) + + is_test = self.params.get('test', False) + resume = self.params.get('continuedl', True) + + for client_name, format_group in format_groups.items(): + formats = format_group['formats'] + audio_formats = (f for f in formats if f['format_type'] == 'audio') + video_formats = (f for f in formats if f['format_type'] == 'video') + caption_formats = (f for f in formats if f['format_type'] == 'caption') + for audio_format, video_format, caption_format in itertools.zip_longest(audio_formats, video_formats, caption_formats): + format_str = join_nonempty(*[ + traverse_obj(audio_format, 'display_name'), + traverse_obj(video_format, 'display_name'), + traverse_obj(caption_format, 'display_name')], delim='+') + self.write_debug(f'Downloading formats: {format_str} ({client_name} client)') + self._download_sabr_stream( + info_dict=info_dict, + video_format=video_format, + audio_format=audio_format, + caption_format=caption_format, + resume=resume, + is_test=is_test, + server_abr_streaming_url=format_group['server_abr_streaming_url'], + video_playback_ustreamer_config=format_group['video_playback_ustreamer_config'], + initial_po_token=format_group['initial_po_token'], + fetch_po_token_fn=format_group['fetch_po_token_fn'], + reload_config_fn=format_group['reload_config_fn'], + client_info=format_group['client_info'], + start_time_ms=format_group['start_time_ms'], + target_duration_sec=format_group.get('target_duration_sec', None), + live_status=format_group.get('live_status'), + 
video_id=format_group.get('video_id'), + ) + return True + + def _download_sabr_stream( + self, + video_id: str, + info_dict: dict, + video_format: dict, + audio_format: dict, + caption_format: dict, + resume: bool, + is_test: bool, + server_abr_streaming_url: str, + video_playback_ustreamer_config: str, + initial_po_token: str, + fetch_po_token_fn: callable | None = None, + reload_config_fn: callable | None = None, + client_info: ClientInfo | None = None, + start_time_ms: int = 0, + target_duration_sec: int | None = None, + live_status: str | None = None, + ): + + writers = {} + audio_selector = None + video_selector = None + caption_selector = None + + if audio_format: + audio_selector = AudioSelector( + display_name=audio_format['display_name'], format_ids=[audio_format['format_id']]) + writers[audio_selector.display_name] = SabrFDFormatWriter( + self, audio_format.get('filename'), + audio_format['info_dict'], len(writers), resume=resume) + + if video_format: + video_selector = VideoSelector( + display_name=video_format['display_name'], format_ids=[video_format['format_id']]) + writers[video_selector.display_name] = SabrFDFormatWriter( + self, video_format.get('filename'), + video_format['info_dict'], len(writers), resume=resume) + + if caption_format: + caption_selector = CaptionSelector( + display_name=caption_format['display_name'], format_ids=[caption_format['format_id']]) + writers[caption_selector.display_name] = SabrFDFormatWriter( + self, caption_format.get('filename'), + caption_format['info_dict'], len(writers), resume=resume) + + stream = SabrStream( + urlopen=self.ydl.urlopen, + logger=create_sabrfd_logger(self.ydl, prefix='sabr:stream'), + server_abr_streaming_url=server_abr_streaming_url, + video_playback_ustreamer_config=video_playback_ustreamer_config, + po_token=initial_po_token, + video_selection=video_selector, + audio_selection=audio_selector, + caption_selection=caption_selector, + start_time_ms=start_time_ms, + client_info=client_info, + live_segment_target_duration_sec=target_duration_sec, + post_live=live_status == 'post_live', + video_id=video_id, + retry_sleep_func=self.params.get('retry_sleep_functions', {}).get('http'), + ) + + self._prepare_multiline_status(len(writers) + 1) + + try: + total_bytes = 0 + for part in stream: + if is_test and total_bytes >= self._TEST_FILE_SIZE: + stream.close() + break + if isinstance(part, PoTokenStatusSabrPart): + if not fetch_po_token_fn: + self.report_warning( + 'No fetch PO token function found - this can happen if you use --load-info-json.' 
+ ' The download will fail if a valid PO token is required.', only_once=True) + if part.status in ( + part.PoTokenStatus.INVALID, + part.PoTokenStatus.PENDING, + ): + # Fetch a PO token with bypass_cache=True + # (ensure we create a new one) + po_token = fetch_po_token_fn(bypass_cache=True) + if po_token: + stream.processor.po_token = po_token + elif part.status in ( + part.PoTokenStatus.MISSING, + part.PoTokenStatus.PENDING_MISSING, + ): + # Fetch a PO Token, bypass_cache=False + po_token = fetch_po_token_fn() + if po_token: + stream.processor.po_token = po_token + + elif isinstance(part, FormatInitializedSabrPart): + writer = writers.get(part.format_selector.display_name) + if not writer: + self.report_warning(f'Unknown format selector: {part.format_selector}') + continue + + writer.initialize_format(part.format_id) + initialized_format = stream.processor.initialized_formats[str(part.format_id)] + if writer.state.init_sequence: + initialized_format.init_segment = True + initialized_format.current_segment = None # allow a seek + + # Build consumed ranges from the sequences + consumed_ranges = [] + for sequence in writer.state.sequences: + consumed_ranges.append(ConsumedRange( + start_time_ms=sequence.first_segment.start_time_ms, + duration_ms=(sequence.last_segment.start_time_ms + sequence.last_segment.duration_ms) - sequence.first_segment.start_time_ms, + start_sequence_number=sequence.first_segment.sequence_number, + end_sequence_number=sequence.last_segment.sequence_number, + )) + if consumed_ranges: + initialized_format.consumed_ranges = consumed_ranges + initialized_format.current_segment = None # allow a seek + self.to_screen(f'[download] Resuming download for format {part.format_selector.display_name}') + + elif isinstance(part, MediaSegmentInitSabrPart): + writer = writers.get(part.format_selector.display_name) + if not writer: + self.report_warning(f'Unknown init format selector: {part.format_selector}') + continue + writer.initialize_segment(part) + + elif isinstance(part, MediaSegmentDataSabrPart): + total_bytes += len(part.data) # TODO: not reliable + writer = writers.get(part.format_selector.display_name) + if not writer: + self.report_warning(f'Unknown data format selector: {part.format_selector}') + continue + writer.write_segment_data(part) + + elif isinstance(part, MediaSegmentEndSabrPart): + writer = writers.get(part.format_selector.display_name) + if not writer: + self.report_warning(f'Unknown end format selector: {part.format_selector}') + continue + writer.end_segment(part) + + elif isinstance(part, RefreshPlayerResponseSabrPart): + self.to_screen(f'Refreshing player response; Reason: {part.reason}') + # In-place refresh - not ideal but should work in most cases + # TODO: handle case where live stream changes to non-livestream on refresh? + # TODO: if live, allow a seek as for non-DVR streams the reload may be longer than the buffer duration + # TODO: handle po token function change + if not reload_config_fn: + self.report_warning( + 'No reload config function found - cannot refresh SABR streaming URL.'
+ ' The url will expire soon and the download will fail.') + try: + stream.url, stream.processor.video_playback_ustreamer_config = reload_config_fn(part.reload_playback_token) + except (TransportError, HTTPError) as e: + self.report_warning(f'Failed to refresh SABR streaming URL: {e}') + + elif isinstance(part, MediaSeekSabrPart): + if ( + not info_dict.get('is_live') + and live_status not in ('post_live', 'is_live') + and not stream.processor.is_live + and part.reason == MediaSeekSabrPart.Reason.SERVER_SEEK + ): + raise DownloadError('Server tried to seek a video') + else: + self.to_screen(f'Unhandled part type: {part.__class__.__name__}') + + for writer in writers.values(): + writer.finish() + except SabrStreamError as e: + raise DownloadError(str(e)) from e + except KeyboardInterrupt: + if ( + not info_dict.get('is_live') + and not live_status == 'is_live' + and not stream.processor.is_live + ): + raise + self.to_screen('Interrupted by user') + for writer in writers.values(): + writer.finish() + finally: + # TODO: for livestreams, since we cannot resume them, should we finish the writers? + for writer in writers.values(): + writer.close() + + +def format_type(f): + if f.get('acodec') == 'none': + return 'video' + elif f.get('vcodec') == 'none': + return 'audio' + elif f.get('vcodec') is None and f.get('acodec') is None: + return 'caption' + return None diff --git a/yt_dlp/downloader/sabr/_file.py b/yt_dlp/downloader/sabr/_file.py new file mode 100644 index 000000000..340999bf7 --- /dev/null +++ b/yt_dlp/downloader/sabr/_file.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import dataclasses +from yt_dlp.utils import DownloadError +from ._io import DiskFormatIOBackend, MemoryFormatIOBackend + + +@dataclasses.dataclass +class Segment: + segment_id: str + content_length: int | None = None + content_length_estimated: bool = False + sequence_number: int | None = None + start_time_ms: int | None = None + duration_ms: int | None = None + duration_estimated: bool = False + is_init_segment: bool = False + + +@dataclasses.dataclass +class Sequence: + sequence_id: str + # The segments may not have a start byte range, so to keep it simple we will track + # length of the sequence. We can infer from this and the segment's content_length where they should end and begin. 
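To make the comment above concrete: because segments are appended back-to-back, absolute byte ranges never need to be stored; they can be rebuilt from the ordered per-segment `content_length` values, and the running total matches `sequence_content_length`. A minimal sketch (hypothetical helper, not part of the patch); the `Sequence` fields continue below:

```python
from typing import Iterable


def segment_byte_ranges(content_lengths: Iterable[int]) -> list[tuple[int, int]]:
    # Reconstruct each segment's (start, end) byte range within the sequence
    # file purely from the ordered content lengths; `end` is exclusive
    ranges, position = [], 0
    for length in content_lengths:
        ranges.append((position, position + length))
        position += length
    return ranges


# Segments of 100, 250 and 80 bytes land at (0, 100), (100, 350), (350, 430);
# 430 is then the sequence_content_length
assert segment_byte_ranges([100, 250, 80]) == [(0, 100), (100, 350), (350, 430)]
```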
+ sequence_content_length: int = 0 + first_segment: Segment | None = None + last_segment: Segment | None = None + + +class SequenceFile: + + def __init__(self, fd, format_filename, sequence: Sequence, resume=False): + self.fd = fd + self.format_filename = format_filename + self.sequence = sequence + self.file = DiskFormatIOBackend( + fd=self.fd, + filename=self.format_filename + f'.sq{self.sequence_id}.sabr.part', + ) + self.current_segment: SegmentFile | None = None + self.resume = resume + + sequence_file_exists = self.file.exists() + + if not resume and sequence_file_exists: + self.file.remove() + + elif not self.sequence.last_segment and sequence_file_exists: + self.file.remove() + + if self.sequence.last_segment and not sequence_file_exists: + raise DownloadError(f'Cannot find existing sequence {self.sequence_id} file') + + if self.sequence.last_segment and not self.file.validate_length(self.sequence.sequence_content_length): + self.file.remove() + raise DownloadError(f'Existing sequence {self.sequence_id} file is not valid; removing') + + @property + def sequence_id(self): + return self.sequence.sequence_id + + @property + def current_length(self): + total = self.sequence.sequence_content_length + if self.current_segment: + total += self.current_segment.current_length + return total + + def is_next_segment(self, segment: Segment): + if self.current_segment: + return False + latest_segment = self.sequence.last_segment or self.sequence.first_segment + if not latest_segment: + return True + if segment.is_init_segment and latest_segment.is_init_segment: + # Only one segment allowed for init segments + return False + return segment.sequence_number == latest_segment.sequence_number + 1 + + def is_current_segment(self, segment_id: str): + if not self.current_segment: + return False + return self.current_segment.segment_id == segment_id + + def initialize_segment(self, segment: Segment): + if self.current_segment and not self.is_current_segment(segment.segment_id): + raise ValueError('Cannot reinitialize a segment that does not match the current segment') + + if not self.current_segment and not self.is_next_segment(segment): + raise ValueError('Cannot initialize a segment that does not match the next segment') + + self.current_segment = SegmentFile( + fd=self.fd, + format_filename=self.format_filename, + segment=segment, + ) + + def write_segment_data(self, data, segment_id: str): + if not self.is_current_segment(segment_id): + raise ValueError('Cannot write to a segment that does not match the current segment') + + self.current_segment.write(data) + + def end_segment(self, segment_id): + if not self.is_current_segment(segment_id): + raise ValueError('Cannot end a segment that does not exist') + + self.current_segment.finish_write() + + if ( + self.current_segment.segment.content_length + and not self.current_segment.segment.content_length_estimated + and self.current_segment.current_length != self.current_segment.segment.content_length + ): + raise DownloadError( + f'Filesize mismatch for segment {self.current_segment.segment_id}: ' + f'Expected {self.current_segment.segment.content_length} bytes, got {self.current_segment.current_length} bytes') + + self.current_segment.segment.content_length = self.current_segment.current_length + self.current_segment.segment.content_length_estimated = False + + if not self.sequence.first_segment: + self.sequence.first_segment = self.current_segment.segment + + self.sequence.last_segment = self.current_segment.segment + 
self.sequence.sequence_content_length += self.current_segment.current_length + + if not self.file.mode: + self.file.initialize_writer(self.resume) + + self.current_segment.read_into(self.file) + self.current_segment.remove() + self.current_segment = None + + def read_into(self, backend): + self.file.initialize_reader() + self.file.read_into(backend) + self.file.close() + + def remove(self): + self.close() + self.file.remove() + + def close(self): + self.file.close() + + +class SegmentFile: + + def __init__(self, fd, format_filename, segment: Segment, memory_file_limit=2 * 1024 * 1024): + self.fd = fd + self.format_filename = format_filename + self.segment: Segment = segment + self.current_length = 0 + + filename = format_filename + f'.sg{segment.sequence_number}.sabr.part' + # Store the segment in memory if it is small enough + if segment.content_length and segment.content_length <= memory_file_limit: + self.file = MemoryFormatIOBackend( + fd=self.fd, + filename=filename, + ) + else: + self.file = DiskFormatIOBackend( + fd=self.fd, + filename=filename, + ) + + # Never resume a segment + exists = self.file.exists() + if exists: + self.file.remove() + + @property + def segment_id(self): + return self.segment.segment_id + + def write(self, data): + if not self.file.mode: + self.file.initialize_writer(resume=False) + self.current_length += self.file.write(data) + + def read_into(self, file): + self.file.initialize_reader() + self.file.read_into(file) + self.file.close() + + def remove(self): + self.close() + self.file.remove() + + def finish_write(self): + self.close() + + def close(self): + self.file.close() diff --git a/yt_dlp/downloader/sabr/_io.py b/yt_dlp/downloader/sabr/_io.py new file mode 100644 index 000000000..b875a9435 --- /dev/null +++ b/yt_dlp/downloader/sabr/_io.py @@ -0,0 +1,162 @@ +from __future__ import annotations +import abc +import io +import os +import shutil +import typing + + +class FormatIOBackend(abc.ABC): + def __init__(self, fd, filename, buffer=1024 * 1024): + self.fd = fd + self.filename = filename + self.write_buffer = buffer + self._fp = None + self._fp_mode = None + + @property + def writer(self): + if self._fp is None or self._fp_mode != 'write': + return None + return self._fp + + @property + def reader(self): + if self._fp is None or self._fp_mode != 'read': + return None + return self._fp + + def initialize_writer(self, resume=False): + if self._fp is not None: + raise ValueError('Backend already initialized') + + self._fp = self._create_writer(resume) + self._fp_mode = 'write' + + @abc.abstractmethod + def _create_writer(self, resume=False) -> typing.IO: + pass + + def initialize_reader(self): + if self._fp is not None: + raise ValueError('Backend already initialized') + self._fp = self._create_reader() + self._fp_mode = 'read' + + @abc.abstractmethod + def _create_reader(self) -> typing.IO: + pass + + def close(self): + if self._fp and not self._fp.closed: + self._fp.flush() + self._fp.close() + self._fp = None + self._fp_mode = None + + @abc.abstractmethod + def validate_length(self, expected_length): + pass + + def remove(self): + self.close() + self._remove() + + @abc.abstractmethod + def _remove(self): + pass + + @abc.abstractmethod + def exists(self): + pass + + @property + def mode(self): + if self._fp is None: + return None + return self._fp_mode + + def write(self, data: io.BufferedIOBase | bytes): + if not self.writer: + raise ValueError('Backend writer not initialized') + + if isinstance(data, bytes): + bytes_written = self.writer.write(data) + 
elif isinstance(data, io.BufferedIOBase): + bytes_written = self.writer.tell() + shutil.copyfileobj(data, self.writer, length=self.write_buffer) + bytes_written = self.writer.tell() - bytes_written + else: + raise TypeError('Data must be bytes or a BufferedIOBase object') + + self.writer.flush() + + return bytes_written + + def read_into(self, backend): + if not backend.writer: + raise ValueError('Backend writer not initialized') + if not self.reader: + raise ValueError('Backend reader not initialized') + shutil.copyfileobj(self.reader, backend.writer, length=self.write_buffer) + backend.writer.flush() + + +class DiskFormatIOBackend(FormatIOBackend): + def _create_writer(self, resume=False) -> typing.IO: + if resume and self.exists(): + write_fp, self.filename = self.fd.sanitize_open(self.filename, 'ab') + else: + write_fp, self.filename = self.fd.sanitize_open(self.filename, 'wb') + return write_fp + + def _create_reader(self) -> typing.IO: + read_fp, self.filename = self.fd.sanitize_open(self.filename, 'rb') + return read_fp + + def validate_length(self, expected_length): + return os.path.getsize(self.filename) == expected_length + + def _remove(self): + self.fd.try_remove(self.filename) + + def exists(self): + return os.path.isfile(self.filename) + + +class MemoryFormatIOBackend(FormatIOBackend): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._memory_store = io.BytesIO() + + def _create_writer(self, resume=False) -> typing.IO: + class NonClosingBufferedWriter(io.BufferedWriter): + def close(self): + self.flush() + # Do not close the underlying buffer + + if resume and self.exists(): + self._memory_store.seek(0, io.SEEK_END) + else: + self._memory_store.seek(0) + self._memory_store.truncate(0) + + return NonClosingBufferedWriter(self._memory_store) + + def _create_reader(self) -> typing.IO: + class NonClosingBufferedReader(io.BufferedReader): + def close(self): + self.flush() + + # Seek to the beginning of the buffer + self._memory_store.seek(0) + return NonClosingBufferedReader(self._memory_store) + + def validate_length(self, expected_length): + return self._memory_store.getbuffer().nbytes == expected_length + + def _remove(self): + self._memory_store = io.BytesIO() + + def exists(self): + return self._memory_store.getbuffer().nbytes > 0 diff --git a/yt_dlp/downloader/sabr/_logger.py b/yt_dlp/downloader/sabr/_logger.py new file mode 100644 index 000000000..407f0bb27 --- /dev/null +++ b/yt_dlp/downloader/sabr/_logger.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from yt_dlp.utils import format_field, traverse_obj +from yt_dlp.extractor.youtube._streaming.sabr.models import SabrLogger +from yt_dlp.utils._utils import _YDLLogger + +# TODO: create a logger that logs to a file rather than the console. +# Might be useful for debugging SABR issues from users.
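Before the logger class, a brief aside on `_io.py` above: each backend is a small mode machine, created idle, switched into a single `'write'` or `'read'` role via `initialize_writer`/`initialize_reader`, and closed back to idle before changing roles. The non-closing wrapper classes are what let the in-memory buffer survive `close()`. A usage sketch under those assumptions (hypothetical, outside the patch):

```python
# fd is only consulted by the disk backend (sanitize_open/try_remove),
# so None suffices for the memory backend
buf = MemoryFormatIOBackend(fd=None, filename='0001.sabr.part')

buf.initialize_writer()         # idle -> 'write'; initializing twice raises
assert buf.write(b'abc') == 3   # returns the number of bytes written
buf.close()                     # back to idle; the BytesIO contents survive

buf.initialize_reader()         # idle -> 'read', rewound to the start
assert buf.reader.read() == b'abc'
buf.close()
```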
+ + +class SabrFDLogger(SabrLogger): + def __init__(self, ydl, prefix, log_level: SabrLogger.LogLevel | None = None): + self._ydl_logger = _YDLLogger(ydl) + self.prefix = prefix + self.log_level = log_level if log_level is not None else self.LogLevel.INFO + + def _format_msg(self, message: str): + prefixstr = format_field(self.prefix, None, '[%s] ') + return f'{prefixstr}{message}' + + def trace(self, message: str): + if self.log_level <= self.LogLevel.TRACE: + self._ydl_logger.debug(self._format_msg('TRACE: ' + message)) + + def debug(self, message: str): + if self.log_level <= self.LogLevel.DEBUG: + self._ydl_logger.debug(self._format_msg(message)) + + def info(self, message: str): + if self.log_level <= self.LogLevel.INFO: + self._ydl_logger.info(self._format_msg(message)) + + def warning(self, message: str, *, once=False): + if self.log_level <= self.LogLevel.WARNING: + self._ydl_logger.warning(self._format_msg(message), once=once) + + def error(self, message: str): + if self.log_level <= self.LogLevel.ERROR: + self._ydl_logger.error(self._format_msg(message), is_error=False) + + +def create_sabrfd_logger(ydl, prefix): + return SabrFDLogger( + ydl, prefix=prefix, + log_level=SabrFDLogger.LogLevel(traverse_obj( + ydl.params, ('extractor_args', 'youtube', 'sabr_log_level', 0, {str}), get_all=False))) diff --git a/yt_dlp/downloader/sabr/_state.py b/yt_dlp/downloader/sabr/_state.py new file mode 100644 index 000000000..5663c58bf --- /dev/null +++ b/yt_dlp/downloader/sabr/_state.py @@ -0,0 +1,77 @@ +from __future__ import annotations +import contextlib +import os +import tempfile + +from yt_dlp.dependencies import protobug +from yt_dlp.extractor.youtube._proto.videostreaming import FormatId + + +@protobug.message +class SabrStateSegment: + sequence_number: protobug.Int32 = protobug.field(1) + start_time_ms: protobug.Int64 = protobug.field(2) + duration_ms: protobug.Int64 = protobug.field(3) + duration_estimated: protobug.Bool = protobug.field(4) + content_length: protobug.Int64 = protobug.field(5) + + +@protobug.message +class SabrStateSequence: + sequence_start_number: protobug.Int32 = protobug.field(1) + sequence_content_length: protobug.Int64 = protobug.field(2) + first_segment: SabrStateSegment = protobug.field(3) + last_segment: SabrStateSegment = protobug.field(4) + + +@protobug.message +class SabrStateInitSegment: + content_length: protobug.Int64 = protobug.field(2) + + +@protobug.message +class SabrState: + format_id: FormatId = protobug.field(1) + init_segment: SabrStateInitSegment | None = protobug.field(2, default=None) + sequences: list[SabrStateSequence] = protobug.field(3, default_factory=list) + + +class SabrStateFile: + + def __init__(self, format_filename, fd): + self.filename = format_filename + '.sabr.state' + self.fd = fd + + @property + def exists(self): + return os.path.isfile(self.filename) + + def retrieve(self): + stream, self.filename = self.fd.sanitize_open(self.filename, 'rb') + try: + return self.deserialize(stream.read()) + finally: + stream.close() + + def update(self, sabr_document): + # Attempt to write progress document somewhat atomically to avoid corruption + tf = tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(self.filename)) + try: + with open(tf.name, 'wb') as f: + f.write(self.serialize(sabr_document)) + f.flush() + os.fsync(f.fileno()) + os.replace(tf.name, self.filename) + finally: + if os.path.exists(tf.name): + with contextlib.suppress(FileNotFoundError, OSError): + os.unlink(tf.name) + + def serialize(self, sabr_document): + return 
protobug.dumps(sabr_document) + + def deserialize(self, data): + return protobug.loads(data, SabrState) + + def remove(self): + self.fd.try_remove(self.filename) diff --git a/yt_dlp/downloader/sabr/_writer.py b/yt_dlp/downloader/sabr/_writer.py new file mode 100644 index 000000000..3bc5f3373 --- /dev/null +++ b/yt_dlp/downloader/sabr/_writer.py @@ -0,0 +1,355 @@ +from __future__ import annotations + +import dataclasses + +from ._io import DiskFormatIOBackend +from ._file import SequenceFile, Sequence, Segment +from ._state import ( + SabrStateSegment, + SabrStateSequence, + SabrStateInitSegment, + SabrState, + SabrStateFile, +) +from yt_dlp.extractor.youtube._streaming.sabr.part import ( + MediaSegmentInitSabrPart, + MediaSegmentDataSabrPart, + MediaSegmentEndSabrPart, +) + +from yt_dlp.utils import DownloadError +from yt_dlp.extractor.youtube._proto.videostreaming import FormatId +from yt_dlp.utils.progress import ProgressCalculator + +INIT_SEGMENT_ID = 'i' + + +@dataclasses.dataclass +class SabrFormatState: + format_id: FormatId + init_sequence: Sequence | None = None + sequences: list[Sequence] = dataclasses.field(default_factory=list) + + +class SabrFDFormatWriter: + def __init__(self, fd, filename, infodict, progress_idx=0, resume=False): + self.fd = fd + self.info_dict = infodict + self.filename = filename + self.progress_idx = progress_idx + self.resume = resume + + self._progress = None + self._downloaded_bytes = 0 + self._state = {} + self._format_id = None + + self.file = DiskFormatIOBackend( + fd=self.fd, + filename=self.fd.temp_name(filename), + ) + self._sabr_state_file = SabrStateFile(format_filename=self.filename, fd=fd) + self._sequence_files: list[SequenceFile] = [] + self._init_sequence: SequenceFile | None = None + + @property + def state(self): + return SabrFormatState( + format_id=self._format_id, + init_sequence=self._init_sequence.sequence if self._init_sequence else None, + sequences=[sf.sequence for sf in self._sequence_files], + ) + + @property + def downloaded_bytes(self): + return (sum( + sequence.current_length for sequence in self._sequence_files) + + (self._init_sequence.current_length if self._init_sequence else 0)) + + def initialize_format(self, format_id): + if self._format_id: + raise ValueError('Already initialized') + self._format_id = format_id + + if not self.resume: + if self._sabr_state_file.exists: + self._sabr_state_file.remove() + return + + document = self._load_sabr_state() + + if document.init_segment: + init_segment = Segment( + segment_id=INIT_SEGMENT_ID, + content_length=document.init_segment.content_length, + is_init_segment=True, + ) + + try: + self._init_sequence = SequenceFile( + fd=self.fd, + format_filename=self.filename, + resume=True, + sequence=Sequence( + sequence_id=INIT_SEGMENT_ID, + sequence_content_length=init_segment.content_length, + first_segment=init_segment, + last_segment=init_segment, + )) + except DownloadError as e: + self.fd.report_warning(f'Failed to resume init segment for format {self.info_dict.get("format_id")}: {e}') + + for sabr_sequence in list(document.sequences): + try: + self._sequence_files.append(SequenceFile( + fd=self.fd, + format_filename=self.filename, + resume=True, + sequence=Sequence( + sequence_id=str(sabr_sequence.sequence_start_number), + sequence_content_length=sabr_sequence.sequence_content_length, + first_segment=Segment( + segment_id=str(sabr_sequence.first_segment.sequence_number), + sequence_number=sabr_sequence.first_segment.sequence_number, + 
content_length=sabr_sequence.first_segment.content_length, + start_time_ms=sabr_sequence.first_segment.start_time_ms, + duration_ms=sabr_sequence.first_segment.duration_ms, + is_init_segment=False, + ), + last_segment=Segment( + segment_id=str(sabr_sequence.last_segment.sequence_number), + sequence_number=sabr_sequence.last_segment.sequence_number, + content_length=sabr_sequence.last_segment.content_length, + start_time_ms=sabr_sequence.last_segment.start_time_ms, + duration_ms=sabr_sequence.last_segment.duration_ms, + is_init_segment=False, + ), + ), + )) + except DownloadError as e: + self.fd.report_warning( + f'Failed to resume sequence {sabr_sequence.sequence_start_number} ' + f'for format {self.info_dict.get("format_id")}: {e}') + + @property + def initialized(self): + return self._format_id is not None + + def close(self): + if not self.file: + raise ValueError('Already closed') + for sequence in self._sequence_files: + sequence.close() + self._sequence_files.clear() + if self._init_sequence: + self._init_sequence.close() + self._init_sequence = None + self.file.close() + + def _find_sequence_file(self, predicate): + match = None + for sequence in self._sequence_files: + if predicate(sequence): + if match is not None: + raise DownloadError('Multiple sequence files found for segment') + match = sequence + return match + + def find_next_sequence_file(self, next_segment: Segment): + return self._find_sequence_file(lambda sequence: sequence.is_next_segment(next_segment)) + + def find_current_sequence_file(self, segment_id: str): + return self._find_sequence_file(lambda sequence: sequence.is_current_segment(segment_id)) + + def initialize_segment(self, part: MediaSegmentInitSabrPart): + if not self._progress: + self._progress = ProgressCalculator(part.start_bytes) + + if not self._format_id: + raise ValueError('not initialized') + + if part.is_init_segment: + if not self._init_sequence: + self._init_sequence = SequenceFile( + fd=self.fd, + format_filename=self.filename, + resume=False, + sequence=Sequence( + sequence_id=INIT_SEGMENT_ID, + )) + + self._init_sequence.initialize_segment(Segment( + segment_id=INIT_SEGMENT_ID, + content_length=part.content_length, + content_length_estimated=part.content_length_estimated, + is_init_segment=True, + )) + return True + + segment = Segment( + segment_id=str(part.sequence_number), + sequence_number=part.sequence_number, + start_time_ms=part.start_time_ms, + duration_ms=part.duration_ms, + duration_estimated=part.duration_estimated, + content_length=part.content_length, + content_length_estimated=part.content_length_estimated, + ) + + sequence_file = self.find_current_sequence_file(segment.segment_id) or self.find_next_sequence_file(segment) + + if not sequence_file: + sequence_file = SequenceFile( + fd=self.fd, + format_filename=self.filename, + resume=False, + sequence=Sequence(sequence_id=str(part.sequence_number)), + ) + self._sequence_files.append(sequence_file) + + sequence_file.initialize_segment(segment) + return True + + def write_segment_data(self, part: MediaSegmentDataSabrPart): + if part.is_init_segment: + sequence_file, segment_id = self._init_sequence, INIT_SEGMENT_ID + else: + segment_id = str(part.sequence_number) + sequence_file = self.find_current_sequence_file(segment_id) + + if not sequence_file: + raise DownloadError('Unable to find sequence file for segment. Was the segment initialized?') + + sequence_file.write_segment_data(part.data, segment_id) + + # TODO: Handling of disjointed segments (e.g. 
when downloading segments out of order / concurrently) + self._progress.total = self.info_dict.get('filesize') + self._state = { + 'status': 'downloading', + 'downloaded_bytes': self.downloaded_bytes, + 'total_bytes': self.info_dict.get('filesize'), + 'filename': self.filename, + 'eta': self._progress.eta.smooth, + 'speed': self._progress.speed.smooth, + 'elapsed': self._progress.elapsed, + 'progress_idx': self.progress_idx, + 'fragment_count': part.total_segments, + 'fragment_index': part.sequence_number, + } + + self._progress.update(self._state['downloaded_bytes']) + self.fd._hook_progress(self._state, self.info_dict) + + def end_segment(self, part: MediaSegmentEndSabrPart): + if part.is_init_segment: + sequence_file, segment_id = self._init_sequence, INIT_SEGMENT_ID + else: + segment_id = str(part.sequence_number) + sequence_file = self.find_current_sequence_file(segment_id) + + if not sequence_file: + raise DownloadError('Unable to find sequence file for segment. Was the segment initialized?') + + sequence_file.end_segment(segment_id) + self._write_sabr_state() + + def _load_sabr_state(self): + sabr_state = None + if self._sabr_state_file.exists: + try: + sabr_state = self._sabr_state_file.retrieve() + except Exception: + self.fd.report_warning( + f'Corrupted state file for format {self.info_dict.get("format_id")}, restarting download') + + if sabr_state and sabr_state.format_id != self._format_id: + self.fd.report_warning( + f'Format ID mismatch in state file for {self.info_dict.get("format_id")}, restarting download') + sabr_state = None + + if not sabr_state: + sabr_state = SabrState(format_id=self._format_id) + + return sabr_state + + def _write_sabr_state(self): + sabr_state = SabrState(format_id=self._format_id) + + if not self._init_sequence: + sabr_state.init_segment = None + else: + sabr_state.init_segment = SabrStateInitSegment( + content_length=self._init_sequence.sequence.sequence_content_length, + ) + + sabr_state.sequences = [] + for sequence_file in self._sequence_files: + # Ignore partial sequences + if not sequence_file.sequence.first_segment or not sequence_file.sequence.last_segment: + continue + sabr_state.sequences.append(SabrStateSequence( + sequence_start_number=sequence_file.sequence.first_segment.sequence_number, + sequence_content_length=sequence_file.sequence.sequence_content_length, + first_segment=SabrStateSegment( + sequence_number=sequence_file.sequence.first_segment.sequence_number, + start_time_ms=sequence_file.sequence.first_segment.start_time_ms, + duration_ms=sequence_file.sequence.first_segment.duration_ms, + duration_estimated=sequence_file.sequence.first_segment.duration_estimated, + content_length=sequence_file.sequence.first_segment.content_length, + ), + last_segment=SabrStateSegment( + sequence_number=sequence_file.sequence.last_segment.sequence_number, + start_time_ms=sequence_file.sequence.last_segment.start_time_ms, + duration_ms=sequence_file.sequence.last_segment.duration_ms, + duration_estimated=sequence_file.sequence.last_segment.duration_estimated, + content_length=sequence_file.sequence.last_segment.content_length, + ), + )) + + self._sabr_state_file.update(sabr_state) + + def finish(self): + self._state['status'] = 'finished' + self.fd._hook_progress(self._state, self.info_dict) + + for sequence_file in self._sequence_files: + sequence_file.close() + + if self._init_sequence: + self._init_sequence.close() + + # Now merge all the sequences together + self.file.initialize_writer(resume=False) + + # Note: May not always be an init 
segment, e.g for live streams + if self._init_sequence: + self._init_sequence.read_into(self.file) + self._init_sequence.close() + + # TODO: handling of disjointed segments + previous_seq_number = None + for sequence_file in sorted( + (sf for sf in self._sequence_files if sf.sequence.first_segment), + key=lambda s: s.sequence.first_segment.sequence_number): + if previous_seq_number and previous_seq_number + 1 != sequence_file.sequence.first_segment.sequence_number: + self.fd.report_warning(f'Disjointed sequences found in SABR format {self.info_dict.get("format_id")}') + previous_seq_number = sequence_file.sequence.last_segment.sequence_number + sequence_file.read_into(self.file) + sequence_file.close() + + # Format temp file should have all the segments, rename it to the final name + self.file.close() + self.fd.try_rename(self.file.filename, self.fd.undo_temp_name(self.file.filename)) + + # Remove the state file + self._sabr_state_file.remove() + + # Remove sequence files + for sf in self._sequence_files: + sf.close() + sf.remove() + + if self._init_sequence: + self._init_sequence.close() + self._init_sequence.remove() + self.close() diff --git a/yt_dlp/extractor/youtube/_proto/__init__.py b/yt_dlp/extractor/youtube/_proto/__init__.py new file mode 100644 index 000000000..12de76d50 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/__init__.py @@ -0,0 +1,14 @@ +import dataclasses +import typing + + +def unknown_fields(obj: typing.Any, path=()) -> typing.Iterable[tuple[tuple[str, ...], dict[int, list]]]: + if not dataclasses.is_dataclass(obj): + return + + if unknown := getattr(obj, '_unknown', None): + yield path, unknown + + for field in dataclasses.fields(obj): + value = getattr(obj, field.name) + yield from unknown_fields(value, (*path, field.name)) diff --git a/yt_dlp/extractor/youtube/_proto/innertube/__init__.py b/yt_dlp/extractor/youtube/_proto/innertube/__init__.py new file mode 100644 index 000000000..df4f4d05d --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/innertube/__init__.py @@ -0,0 +1,5 @@ +from .client_info import ClientInfo, ClientName # noqa: F401 +from .compression_algorithm import CompressionAlgorithm # noqa: F401 +from .next_request_policy import NextRequestPolicy # noqa: F401 +from .range import Range # noqa: F401 +from .seek_source import SeekSource # noqa: F401 diff --git a/yt_dlp/extractor/youtube/_proto/innertube/client_info.py b/yt_dlp/extractor/youtube/_proto/innertube/client_info.py new file mode 100644 index 000000000..e6651ca30 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/innertube/client_info.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + + +class ClientName(protobug.Enum, strict=False): + UNKNOWN_INTERFACE = 0 + WEB = 1 + MWEB = 2 + ANDROID = 3 + IOS = 5 + TVHTML5 = 7 + TVLITE = 8 + TVANDROID = 10 + XBOX = 11 + CLIENTX = 12 + XBOXONEGUIDE = 13 + ANDROID_CREATOR = 14 + IOS_CREATOR = 15 + TVAPPLE = 16 + IOS_INSTANT = 17 + ANDROID_KIDS = 18 + IOS_KIDS = 19 + ANDROID_INSTANT = 20 + ANDROID_MUSIC = 21 + IOS_TABLOID = 22 + ANDROID_TV = 23 + ANDROID_GAMING = 24 + IOS_GAMING = 25 + IOS_MUSIC = 26 + MWEB_TIER_2 = 27 + ANDROID_VR = 28 + ANDROID_UNPLUGGED = 29 + ANDROID_TESTSUITE = 30 + WEB_MUSIC_ANALYTICS = 31 + WEB_GAMING = 32 + IOS_UNPLUGGED = 33 + ANDROID_WITNESS = 34 + IOS_WITNESS = 35 + ANDROID_SPORTS = 36 + IOS_SPORTS = 37 + ANDROID_LITE = 38 + IOS_EMBEDDED_PLAYER = 39 + IOS_DIRECTOR = 40 + WEB_UNPLUGGED = 41 + WEB_EXPERIMENTS = 42 + TVHTML5_CAST = 43 + WEB_EMBEDDED_PLAYER = 56 + TVHTML5_AUDIO 
= 57 + TV_UNPLUGGED_CAST = 58 + TVHTML5_KIDS = 59 + WEB_HEROES = 60 + WEB_MUSIC = 61 + WEB_CREATOR = 62 + TV_UNPLUGGED_ANDROID = 63 + IOS_LIVE_CREATION_EXTENSION = 64 + TVHTML5_UNPLUGGED = 65 + IOS_MESSAGES_EXTENSION = 66 + WEB_REMIX = 67 + IOS_UPTIME = 68 + WEB_UNPLUGGED_ONBOARDING = 69 + WEB_UNPLUGGED_OPS = 70 + WEB_UNPLUGGED_PUBLIC = 71 + TVHTML5_VR = 72 + WEB_LIVE_STREAMING = 73 + ANDROID_TV_KIDS = 74 + TVHTML5_SIMPLY = 75 + WEB_KIDS = 76 + MUSIC_INTEGRATIONS = 77 + TVHTML5_YONGLE = 80 + GOOGLE_ASSISTANT = 84 + TVHTML5_SIMPLY_EMBEDDED_PLAYER = 85 + WEB_MUSIC_EMBEDDED_PLAYER = 86 + WEB_INTERNAL_ANALYTICS = 87 + WEB_PARENT_TOOLS = 88 + GOOGLE_MEDIA_ACTIONS = 89 + WEB_PHONE_VERIFICATION = 90 + ANDROID_PRODUCER = 91 + IOS_PRODUCER = 92 + TVHTML5_FOR_KIDS = 93 + GOOGLE_LIST_RECS = 94 + MEDIA_CONNECT_FRONTEND = 95 + WEB_EFFECT_MAKER = 98 + WEB_SHOPPING_EXTENSION = 99 + WEB_PLAYABLES_PORTAL = 100 + VISIONOS = 101 + WEB_LIVE_APPS = 102 + WEB_MUSIC_INTEGRATIONS = 103 + ANDROID_MUSIC_AOSP = 104 + + +@protobug.message +class ClientInfo: + hl: protobug.String | None = protobug.field(1, default=None) + gl: protobug.String | None = protobug.field(2, default=None) + remote_host: protobug.String | None = protobug.field(4, default=None) + + device_make: protobug.String | None = protobug.field(12, default=None) + device_model: protobug.String | None = protobug.field(13, default=None) + visitor_data: protobug.String | None = protobug.field(14, default=None) + user_agent: protobug.String | None = protobug.field(15, default=None) + client_name: ClientName | None = protobug.field(16, default=None) + client_version: protobug.String | None = protobug.field(17, default=None) + os_name: protobug.String | None = protobug.field(18, default=None) + os_version: protobug.String | None = protobug.field(19, default=None) diff --git a/yt_dlp/extractor/youtube/_proto/innertube/compression_algorithm.py b/yt_dlp/extractor/youtube/_proto/innertube/compression_algorithm.py new file mode 100644 index 000000000..c1b039069 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/innertube/compression_algorithm.py @@ -0,0 +1,7 @@ +from yt_dlp.dependencies import protobug + + +class CompressionAlgorithm(protobug.Enum, strict=False): + COMPRESSION_ALGORITHM_UNKNOWN = 0 + COMPRESSION_ALGORITHM_NONE = 1 + COMPRESSION_ALGORITHM_GZIP = 2 diff --git a/yt_dlp/extractor/youtube/_proto/innertube/next_request_policy.py b/yt_dlp/extractor/youtube/_proto/innertube/next_request_policy.py new file mode 100644 index 000000000..c670add97 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/innertube/next_request_policy.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + + +@protobug.message +class NextRequestPolicy: + target_audio_readahead_ms: protobug.Int32 | None = protobug.field(1, default=None) + target_video_readahead_ms: protobug.Int32 | None = protobug.field(2, default=None) + max_time_since_last_request_ms: protobug.Int32 | None = protobug.field(3, default=None) + backoff_time_ms: protobug.Int32 | None = protobug.field(4, default=None) + min_audio_readahead_ms: protobug.Int32 | None = protobug.field(5, default=None) + min_video_readahead_ms: protobug.Int32 | None = protobug.field(6, default=None) + playback_cookie: protobug.Bytes | None = protobug.field(7, default=None) + video_id: protobug.String | None = protobug.field(8, default=None) diff --git a/yt_dlp/extractor/youtube/_proto/innertube/range.py b/yt_dlp/extractor/youtube/_proto/innertube/range.py new file mode 100644 index 
000000000..8d21eb076 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/innertube/range.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + + +@protobug.message +class Range: + start: protobug.Int64 | None = protobug.field(1, default=None) + end: protobug.Int64 | None = protobug.field(2, default=None) diff --git a/yt_dlp/extractor/youtube/_proto/innertube/seek_source.py b/yt_dlp/extractor/youtube/_proto/innertube/seek_source.py new file mode 100644 index 000000000..ba6ddbdf6 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/innertube/seek_source.py @@ -0,0 +1,14 @@ +from yt_dlp.dependencies import protobug + + +class SeekSource(protobug.Enum, strict=False): + SEEK_SOURCE_UNKNOWN = 0 + SEEK_SOURCE_SABR_PARTIAL_CHUNK = 9 + SEEK_SOURCE_SABR_SEEK_TO_HEAD = 10 + SEEK_SOURCE_SABR_LIVE_DVR_USER_SEEK = 11 + SEEK_SOURCE_SABR_SEEK_TO_DVR_LOWER_BOUND = 12 + SEEK_SOURCE_SABR_SEEK_TO_DVR_UPPER_BOUND = 13 + SEEK_SOURCE_SABR_ACCURATE_SEEK = 17 + SEEK_SOURCE_SABR_INGESTION_WALL_TIME_SEEK = 29 + SEEK_SOURCE_SABR_SEEK_TO_CLOSEST_KEYFRAME = 59 + SEEK_SOURCE_SABR_RELOAD_PLAYER_RESPONSE_TOKEN_SEEK = 106 diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/__init__.py b/yt_dlp/extractor/youtube/_proto/videostreaming/__init__.py new file mode 100644 index 000000000..157477769 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/videostreaming/__init__.py @@ -0,0 +1,16 @@ +from .buffered_range import BufferedRange # noqa: F401 +from .client_abr_state import ClientAbrState # noqa: F401 +from .format_id import FormatId # noqa: F401 +from .format_initialization_metadata import FormatInitializationMetadata # noqa: F401 +from .live_metadata import LiveMetadata # noqa: F401 +from .media_header import MediaHeader # noqa: F401 +from .reload_player_response import ReloadPlayerResponse # noqa: F401 +from .sabr_context_sending_policy import SabrContextSendingPolicy # noqa: F401 +from .sabr_context_update import SabrContextUpdate # noqa: F401 +from .sabr_error import SabrError # noqa: F401 +from .sabr_redirect import SabrRedirect # noqa: F401 +from .sabr_seek import SabrSeek # noqa: F401 +from .stream_protection_status import StreamProtectionStatus # noqa: F401 +from .streamer_context import SabrContext, StreamerContext # noqa: F401 +from .time_range import TimeRange # noqa: F401 +from .video_playback_abr_request import VideoPlaybackAbrRequest # noqa: F401 diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/buffered_range.py b/yt_dlp/extractor/youtube/_proto/videostreaming/buffered_range.py new file mode 100644 index 000000000..5f22fa541 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/videostreaming/buffered_range.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + +from .format_id import FormatId +from .time_range import TimeRange + + +@protobug.message +class BufferedRange: + format_id: FormatId | None = protobug.field(1, default=None) + start_time_ms: protobug.Int64 | None = protobug.field(2, default=None) + duration_ms: protobug.Int64 | None = protobug.field(3, default=None) + start_segment_index: protobug.Int32 | None = protobug.field(4, default=None) + end_segment_index: protobug.Int32 | None = protobug.field(5, default=None) + time_range: TimeRange | None = protobug.field(6, default=None) diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/client_abr_state.py b/yt_dlp/extractor/youtube/_proto/videostreaming/client_abr_state.py new file mode 100644 index 000000000..d27f4667e --- /dev/null +++ 
b/yt_dlp/extractor/youtube/_proto/videostreaming/client_abr_state.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + + +@protobug.message +class ClientAbrState: + player_time_ms: protobug.Int64 | None = protobug.field(28, default=None) + enabled_track_types_bitfield: protobug.Int32 | None = protobug.field(40, default=None) diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/format_id.py b/yt_dlp/extractor/youtube/_proto/videostreaming/format_id.py new file mode 100644 index 000000000..59afa8b28 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/videostreaming/format_id.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + + +@protobug.message +class FormatId: + itag: protobug.Int32 | None = protobug.field(1) + lmt: protobug.UInt64 | None = protobug.field(2, default=None) + xtags: protobug.String | None = protobug.field(3, default=None) diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/format_initialization_metadata.py b/yt_dlp/extractor/youtube/_proto/videostreaming/format_initialization_metadata.py new file mode 100644 index 000000000..7786d885b --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/videostreaming/format_initialization_metadata.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + +from .format_id import FormatId +from ..innertube import Range + + +@protobug.message +class FormatInitializationMetadata: + video_id: protobug.String = protobug.field(1, default=None) + format_id: FormatId = protobug.field(2, default=None) + end_time_ms: protobug.Int32 | None = protobug.field(3, default=None) + total_segments: protobug.Int32 | None = protobug.field(4, default=None) + mime_type: protobug.String | None = protobug.field(5, default=None) + init_range: Range | None = protobug.field(6, default=None) + index_range: Range | None = protobug.field(7, default=None) + duration_ticks: protobug.Int32 | None = protobug.field(9, default=None) + duration_timescale: protobug.Int32 | None = protobug.field(10, default=None) diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/live_metadata.py b/yt_dlp/extractor/youtube/_proto/videostreaming/live_metadata.py new file mode 100644 index 000000000..41019f2fa --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/videostreaming/live_metadata.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + + +@protobug.message +class LiveMetadata: + head_sequence_number: protobug.Int32 | None = protobug.field(3, default=None) + head_sequence_time_ms: protobug.Int64 | None = protobug.field(4, default=None) + wall_time_ms: protobug.Int64 | None = protobug.field(5, default=None) + video_id: protobug.String | None = protobug.field(6, default=None) + source: protobug.String | None = protobug.field(7, default=None) + + min_seekable_time_ticks: protobug.Int64 | None = protobug.field(12, default=None) + min_seekable_timescale: protobug.Int32 | None = protobug.field(13, default=None) + + max_seekable_time_ticks: protobug.Int64 | None = protobug.field(14, default=None) + max_seekable_timescale: protobug.Int32 | None = protobug.field(15, default=None) diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/media_header.py b/yt_dlp/extractor/youtube/_proto/videostreaming/media_header.py new file mode 100644 index 000000000..61a19a107 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/videostreaming/media_header.py @@ -0,0 +1,27 @@ +from __future__ import 
annotations + +from yt_dlp.dependencies import protobug + +from .format_id import FormatId +from .time_range import TimeRange +from ..innertube import CompressionAlgorithm + + +@protobug.message +class MediaHeader: + header_id: protobug.UInt32 | None = protobug.field(1, default=None) + video_id: protobug.String | None = protobug.field(2, default=None) + itag: protobug.Int32 | None = protobug.field(3, default=None) + last_modified: protobug.UInt64 | None = protobug.field(4, default=None) + xtags: protobug.String | None = protobug.field(5, default=None) + start_data_range: protobug.Int32 | None = protobug.field(6, default=None) + compression: CompressionAlgorithm | None = protobug.field(7, default=None) + is_init_segment: protobug.Bool | None = protobug.field(8, default=None) + sequence_number: protobug.Int64 | None = protobug.field(9, default=None) + bitrate_bps: protobug.Int64 | None = protobug.field(10, default=None) + start_ms: protobug.Int32 | None = protobug.field(11, default=None) + duration_ms: protobug.Int32 | None = protobug.field(12, default=None) + format_id: FormatId | None = protobug.field(13, default=None) + content_length: protobug.Int64 | None = protobug.field(14, default=None) + time_range: TimeRange | None = protobug.field(15, default=None) + sequence_lmt: protobug.Int32 | None = protobug.field(16, default=None) diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/reload_player_response.py b/yt_dlp/extractor/youtube/_proto/videostreaming/reload_player_response.py new file mode 100644 index 000000000..ac7f1ed5b --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/videostreaming/reload_player_response.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + + +@protobug.message +class ReloadPlaybackParams: + token: protobug.String | None = protobug.field(1, default=None) + + +@protobug.message +class ReloadPlayerResponse: + reload_playback_params: ReloadPlaybackParams | None = protobug.field(1, default=None) diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/sabr_context_sending_policy.py b/yt_dlp/extractor/youtube/_proto/videostreaming/sabr_context_sending_policy.py new file mode 100644 index 000000000..99d961696 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/videostreaming/sabr_context_sending_policy.py @@ -0,0 +1,13 @@ +from yt_dlp.dependencies import protobug + + +@protobug.message +class SabrContextSendingPolicy: + # Start sending the SabrContextUpdates of this type + start_policy: list[protobug.Int32] = protobug.field(1, default_factory=list) + + # Stop sending the SabrContextUpdates of this type + stop_policy: list[protobug.Int32] = protobug.field(2, default_factory=list) + + # Stop and discard the SabrContextUpdates of this type + discard_policy: list[protobug.Int32] = protobug.field(3, default_factory=list) diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/sabr_context_update.py b/yt_dlp/extractor/youtube/_proto/videostreaming/sabr_context_update.py new file mode 100644 index 000000000..7e5eb28f3 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/videostreaming/sabr_context_update.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + + +@protobug.message +class SabrContextUpdate: + + class SabrContextScope(protobug.Enum, strict=False): + SABR_CONTEXT_SCOPE_UNKNOWN = 0 + SABR_CONTEXT_SCOPE_PLAYBACK = 1 + SABR_CONTEXT_SCOPE_REQUEST = 2 + SABR_CONTEXT_SCOPE_WATCH_ENDPOINT = 3 + SABR_CONTEXT_SCOPE_CONTENT_ADS = 4 + + class 
SabrContextWritePolicy(protobug.Enum, strict=False): + SABR_CONTEXT_WRITE_POLICY_UNSPECIFIED = 0 + SABR_CONTEXT_WRITE_POLICY_OVERWRITE = 1 + SABR_CONTEXT_WRITE_POLICY_KEEP_EXISTING = 2 + + type: protobug.Int32 | None = protobug.field(1, default=None) + scope: SabrContextScope | None = protobug.field(2, default=None) + value: protobug.Bytes | None = protobug.field(3, default=None) + send_by_default: protobug.Bool | None = protobug.field(4, default=None) + write_policy: SabrContextWritePolicy | None = protobug.field(5, default=None) diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/sabr_error.py b/yt_dlp/extractor/youtube/_proto/videostreaming/sabr_error.py new file mode 100644 index 000000000..3bf93fa21 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/videostreaming/sabr_error.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + + +@protobug.message +class Error: + status_code: protobug.Int32 | None = protobug.field(1, default=None) + type: protobug.Int32 | None = protobug.field(4, default=None) + + +@protobug.message +class SabrError: + type: protobug.String | None = protobug.field(1, default=None) + action: protobug.Int32 | None = protobug.field(2, default=None) + error: Error | None = protobug.field(3, default=None) diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/sabr_redirect.py b/yt_dlp/extractor/youtube/_proto/videostreaming/sabr_redirect.py new file mode 100644 index 000000000..fe0a1a1b9 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/videostreaming/sabr_redirect.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + + +@protobug.message +class SabrRedirect: + redirect_url: protobug.String | None = protobug.field(1, default=None) diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/sabr_seek.py b/yt_dlp/extractor/youtube/_proto/videostreaming/sabr_seek.py new file mode 100644 index 000000000..a820e2fe6 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/videostreaming/sabr_seek.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + +from ..innertube import SeekSource + + +@protobug.message +class SabrSeek: + seek_time_ticks: protobug.Int32 = protobug.field(1) + timescale: protobug.Int32 = protobug.field(2) + seek_source: SeekSource | None = protobug.field(3, default=None) diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/stream_protection_status.py b/yt_dlp/extractor/youtube/_proto/videostreaming/stream_protection_status.py new file mode 100644 index 000000000..13e88fd32 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/videostreaming/stream_protection_status.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + + +@protobug.message +class StreamProtectionStatus: + + class Status(protobug.Enum, strict=False): + OK = 1 + ATTESTATION_PENDING = 2 + ATTESTATION_REQUIRED = 3 + + status: Status | None = protobug.field(1, default=None) + max_retries: protobug.Int32 | None = protobug.field(2, default=None) diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/streamer_context.py b/yt_dlp/extractor/youtube/_proto/videostreaming/streamer_context.py new file mode 100644 index 000000000..443628f31 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/videostreaming/streamer_context.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + +from ..innertube import ClientInfo + + +@protobug.message +class 
SabrContext: + # Type and Value from a SabrContextUpdate + type: protobug.Int32 | None = protobug.field(1, default=None) + value: protobug.Bytes | None = protobug.field(2, default=None) + + +@protobug.message +class StreamerContext: + client_info: ClientInfo | None = protobug.field(1, default=None) + po_token: protobug.Bytes | None = protobug.field(2, default=None) + playback_cookie: protobug.Bytes | None = protobug.field(3, default=None) + sabr_contexts: list[SabrContext] = protobug.field(5, default_factory=list) + unsent_sabr_contexts: list[protobug.Int32] = protobug.field(6, default_factory=list) diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/time_range.py b/yt_dlp/extractor/youtube/_proto/videostreaming/time_range.py new file mode 100644 index 000000000..bafea01b4 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/videostreaming/time_range.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + + +@protobug.message +class TimeRange: + start_ticks: protobug.Int64 | None = protobug.field(1, default=None) + duration_ticks: protobug.Int64 | None = protobug.field(2, default=None) + timescale: protobug.Int32 | None = protobug.field(3, default=None) diff --git a/yt_dlp/extractor/youtube/_proto/videostreaming/video_playback_abr_request.py b/yt_dlp/extractor/youtube/_proto/videostreaming/video_playback_abr_request.py new file mode 100644 index 000000000..caa46ba91 --- /dev/null +++ b/yt_dlp/extractor/youtube/_proto/videostreaming/video_playback_abr_request.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from yt_dlp.dependencies import protobug + +from .buffered_range import BufferedRange +from .client_abr_state import ClientAbrState +from .format_id import FormatId +from .streamer_context import StreamerContext + + +@protobug.message +class VideoPlaybackAbrRequest: + client_abr_state: ClientAbrState = protobug.field(1, default=None) + initialized_format_ids: list[FormatId] = protobug.field(2, default_factory=list) + buffered_ranges: list[BufferedRange] = protobug.field(3, default_factory=list) + player_time_ms: protobug.Int64 | None = protobug.field(4, default=None) + video_playback_ustreamer_config: protobug.Bytes | None = protobug.field(5, default=None) + + selected_audio_format_ids: list[FormatId] = protobug.field(16, default_factory=list) + selected_video_format_ids: list[FormatId] = protobug.field(17, default_factory=list) + selected_caption_format_ids: list[FormatId] = protobug.field(18, default_factory=list) + streamer_context: StreamerContext = protobug.field(19, default_factory=StreamerContext) diff --git a/yt_dlp/extractor/youtube/_streaming/__init__.py b/yt_dlp/extractor/youtube/_streaming/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/yt_dlp/extractor/youtube/_streaming/sabr/exceptions.py b/yt_dlp/extractor/youtube/_streaming/sabr/exceptions.py new file mode 100644 index 000000000..7387db8e8 --- /dev/null +++ b/yt_dlp/extractor/youtube/_streaming/sabr/exceptions.py @@ -0,0 +1,26 @@ +from yt_dlp.extractor.youtube._proto.videostreaming import FormatId +from yt_dlp.utils import YoutubeDLError + + +class SabrStreamConsumedError(YoutubeDLError): + pass + + +class SabrStreamError(YoutubeDLError): + pass + + +class MediaSegmentMismatchError(SabrStreamError): + def __init__(self, format_id: FormatId, expected_sequence_number: int, received_sequence_number: int): + super().__init__( + f'Segment sequence number mismatch for format {format_id}: ' + f'expected {expected_sequence_number}, received 
{received_sequence_number}') + self.expected_sequence_number = expected_sequence_number + self.received_sequence_number = received_sequence_number + + +class PoTokenError(SabrStreamError): + def __init__(self, missing=False): + super().__init__( + f'This stream requires a GVS PO Token to continue' + f'{" and the one provided is invalid" if not missing else ""}') diff --git a/yt_dlp/extractor/youtube/_streaming/sabr/models.py b/yt_dlp/extractor/youtube/_streaming/sabr/models.py new file mode 100644 index 000000000..52edcc5ce --- /dev/null +++ b/yt_dlp/extractor/youtube/_streaming/sabr/models.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import dataclasses + +from yt_dlp.extractor.youtube._proto.videostreaming import FormatId +from yt_dlp.extractor.youtube.pot._provider import IEContentProviderLogger + + +@dataclasses.dataclass +class Segment: + format_id: FormatId + is_init_segment: bool = False + duration_ms: int = 0 + start_ms: int = 0 + start_data_range: int = 0 + sequence_number: int = 0 + content_length: int | None = None + content_length_estimated: bool = False + initialized_format: InitializedFormat = None + # Whether duration_ms is an estimate + duration_estimated: bool = False + # Whether we should discard the segment data + discard: bool = False + # Whether the segment has already been consumed. + # `discard` should be set to True if this is the case. + consumed: bool = False + received_data_length: int = 0 + sequence_lmt: int | None = None + + +@dataclasses.dataclass +class ConsumedRange: + start_sequence_number: int + end_sequence_number: int + start_time_ms: int + duration_ms: int + + +@dataclasses.dataclass +class InitializedFormat: + format_id: FormatId + video_id: str + format_selector: FormatSelector | None = None + duration_ms: int = 0 + end_time_ms: int = 0 + mime_type: str = None + # Current segment in the sequence. Set to None to break the sequence and allow a seek. 
+ current_segment: Segment | None = None + init_segment: Segment | None | bool = None + consumed_ranges: list[ConsumedRange] = dataclasses.field(default_factory=list) + total_segments: int = None + # Whether we should discard any data received for this format + discard: bool = False + sequence_lmt: int | None = None + + +SabrLogger = IEContentProviderLogger + + +@dataclasses.dataclass +class FormatSelector: + display_name: str + format_ids: list[FormatId] = dataclasses.field(default_factory=list) + discard_media: bool = False + + def match(self, format_id: FormatId = None, **kwargs) -> bool: + return format_id in self.format_ids + + +@dataclasses.dataclass +class AudioSelector(FormatSelector): + + def match(self, format_id: FormatId = None, mime_type: str | None = None, **kwargs) -> bool: + return ( + super().match(format_id, mime_type=mime_type, **kwargs) + or (not self.format_ids and mime_type and mime_type.lower().startswith('audio')) + ) + + +@dataclasses.dataclass +class VideoSelector(FormatSelector): + + def match(self, format_id: FormatId = None, mime_type: str | None = None, **kwargs) -> bool: + return ( + super().match(format_id, mime_type=mime_type, **kwargs) + or (not self.format_ids and mime_type and mime_type.lower().startswith('video')) + ) + + +@dataclasses.dataclass +class CaptionSelector(FormatSelector): + + def match(self, format_id: FormatId = None, mime_type: str | None = None, **kwargs) -> bool: + return ( + super().match(format_id, mime_type=mime_type, **kwargs) + or (not self.format_ids and mime_type and mime_type.lower().startswith('text')) + ) diff --git a/yt_dlp/extractor/youtube/_streaming/sabr/part.py b/yt_dlp/extractor/youtube/_streaming/sabr/part.py new file mode 100644 index 000000000..8c7d8dc0c --- /dev/null +++ b/yt_dlp/extractor/youtube/_streaming/sabr/part.py @@ -0,0 +1,92 @@ +import dataclasses +import enum + +from yt_dlp.extractor.youtube._proto.videostreaming import FormatId + +from .models import FormatSelector + + +@dataclasses.dataclass +class SabrPart: + pass + + +@dataclasses.dataclass +class MediaSegmentInitSabrPart(SabrPart): + format_selector: FormatSelector + format_id: FormatId + sequence_number: int = None + is_init_segment: bool = False + total_segments: int = None + start_time_ms: int = None + player_time_ms: int = None + duration_ms: int = None + duration_estimated: bool = False + start_bytes: int = None + content_length: int = None + content_length_estimated: bool = False + + +@dataclasses.dataclass +class MediaSegmentDataSabrPart(SabrPart): + format_selector: FormatSelector + format_id: FormatId + sequence_number: int = None + is_init_segment: bool = False + total_segments: int = None + data: bytes = b'' + content_length: int = None + segment_start_bytes: int = None + + +@dataclasses.dataclass +class MediaSegmentEndSabrPart(SabrPart): + format_selector: FormatSelector + format_id: FormatId + sequence_number: int = None + is_init_segment: bool = False + total_segments: int = None + + +@dataclasses.dataclass +class FormatInitializedSabrPart(SabrPart): + format_id: FormatId + format_selector: FormatSelector + + +@dataclasses.dataclass +class PoTokenStatusSabrPart(SabrPart): + class PoTokenStatus(enum.Enum): + OK = enum.auto() # PO Token is provided and valid + MISSING = enum.auto() # PO Token is not provided, and is required. A PO Token should be provided ASAP + INVALID = enum.auto() # PO Token is provided, but is invalid. 
A new one should be generated ASAP + PENDING = enum.auto() # PO Token is provided, but probably only a cold start token. A full PO Token should be provided ASAP + NOT_REQUIRED = enum.auto() # PO Token is not provided, and is not required + PENDING_MISSING = enum.auto() # PO Token is not provided, but is pending. A full PO Token should be (probably) provided ASAP + + status: PoTokenStatus + + +@dataclasses.dataclass +class RefreshPlayerResponseSabrPart(SabrPart): + + class Reason(enum.Enum): + UNKNOWN = enum.auto() + SABR_URL_EXPIRY = enum.auto() + SABR_RELOAD_PLAYER_RESPONSE = enum.auto() + + reason: Reason + reload_playback_token: str = None + + +@dataclasses.dataclass +class MediaSeekSabrPart(SabrPart): + # Lets the consumer know the media sequence for a format may change + class Reason(enum.Enum): + UNKNOWN = enum.auto() + SERVER_SEEK = enum.auto() # SABR_SEEK from server + CONSUMED_SEEK = enum.auto() # Seeking as next fragment is already buffered + + reason: Reason + format_id: FormatId + format_selector: FormatSelector diff --git a/yt_dlp/extractor/youtube/_streaming/sabr/processor.py b/yt_dlp/extractor/youtube/_streaming/sabr/processor.py new file mode 100644 index 000000000..85e3bedde --- /dev/null +++ b/yt_dlp/extractor/youtube/_streaming/sabr/processor.py @@ -0,0 +1,671 @@ +from __future__ import annotations + +import base64 +import io +import math + +from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy +from yt_dlp.extractor.youtube._proto.videostreaming import ( + BufferedRange, + ClientAbrState, + FormatInitializationMetadata, + LiveMetadata, + MediaHeader, + SabrContext, + SabrContextSendingPolicy, + SabrContextUpdate, + SabrSeek, + StreamerContext, + StreamProtectionStatus, + TimeRange, + VideoPlaybackAbrRequest, +) + +from .exceptions import MediaSegmentMismatchError, SabrStreamError +from .models import ( + AudioSelector, + CaptionSelector, + ConsumedRange, + InitializedFormat, + SabrLogger, + Segment, + VideoSelector, +) +from .part import ( + FormatInitializedSabrPart, + MediaSeekSabrPart, + MediaSegmentDataSabrPart, + MediaSegmentEndSabrPart, + MediaSegmentInitSabrPart, + PoTokenStatusSabrPart, +) +from .utils import ticks_to_ms + + +class ProcessMediaEndResult: + def __init__(self, sabr_part: MediaSegmentEndSabrPart = None, is_new_segment: bool = False): + self.is_new_segment = is_new_segment + self.sabr_part = sabr_part + + +class ProcessMediaResult: + def __init__(self, sabr_part: MediaSegmentDataSabrPart = None): + self.sabr_part = sabr_part + + +class ProcessMediaHeaderResult: + def __init__(self, sabr_part: MediaSegmentInitSabrPart | None = None): + self.sabr_part = sabr_part + + +class ProcessLiveMetadataResult: + def __init__(self, seek_sabr_parts: list[MediaSeekSabrPart] | None = None): + self.seek_sabr_parts = seek_sabr_parts or [] + + +class ProcessStreamProtectionStatusResult: + def __init__(self, sabr_part: PoTokenStatusSabrPart | None = None): + self.sabr_part = sabr_part + + +class ProcessFormatInitializationMetadataResult: + def __init__(self, sabr_part: FormatInitializedSabrPart | None = None): + self.sabr_part = sabr_part + + +class ProcessSabrSeekResult: + def __init__(self, seek_sabr_parts: list[MediaSeekSabrPart] | None = None): + self.seek_sabr_parts = seek_sabr_parts or [] + + +class SabrProcessor: + """ + SABR Processor + + This handles core SABR protocol logic, independent of requests. 
+ """ + + def __init__( + self, + logger: SabrLogger, + video_playback_ustreamer_config: str, + client_info: ClientInfo, + audio_selection: AudioSelector | None = None, + video_selection: VideoSelector | None = None, + caption_selection: CaptionSelector | None = None, + live_segment_target_duration_sec: int | None = None, + live_segment_target_duration_tolerance_ms: int | None = None, + start_time_ms: int | None = None, + po_token: str | None = None, + live_end_wait_sec: int | None = None, + live_end_segment_tolerance: int | None = None, + post_live: bool = False, + video_id: str | None = None, + ): + + self.logger = logger + + self.video_playback_ustreamer_config = video_playback_ustreamer_config + self.po_token = po_token + self.client_info = client_info + self.live_segment_target_duration_sec = live_segment_target_duration_sec or 5 + self.live_segment_target_duration_tolerance_ms = live_segment_target_duration_tolerance_ms or 100 + if self.live_segment_target_duration_tolerance_ms >= (self.live_segment_target_duration_sec * 1000) / 2: + raise ValueError( + 'live_segment_target_duration_tolerance_ms must be less than ' + 'half of live_segment_target_duration_sec in milliseconds', + ) + self.start_time_ms = start_time_ms or 0 + if self.start_time_ms < 0: + raise ValueError('start_time_ms must be greater than or equal to 0') + + self.live_end_wait_sec = live_end_wait_sec or max(10, 3 * self.live_segment_target_duration_sec) + self.live_end_segment_tolerance = live_end_segment_tolerance or 10 + self.post_live = post_live + self._is_live = False + self.video_id = video_id + + self._audio_format_selector = audio_selection + self._video_format_selector = video_selection + self._caption_format_selector = caption_selection + + # IMPORTANT: initialized formats is assumed to contain only ACTIVE formats + self.initialized_formats: dict[str, InitializedFormat] = {} + self.stream_protection_status: StreamProtectionStatus.Status | None = None + + self.partial_segments: dict[int, Segment] = {} + self.total_duration_ms = None + self.selected_audio_format_ids = [] + self.selected_video_format_ids = [] + self.selected_caption_format_ids = [] + self.next_request_policy: NextRequestPolicy | None = None + self.live_metadata: LiveMetadata | None = None + self.client_abr_state: ClientAbrState + self.sabr_contexts_to_send: set[int] = set() + self.sabr_context_updates: dict[int, SabrContextUpdate] = {} + self._initialize_cabr_state() + + @property + def is_live(self): + return bool( + self.live_metadata + or self._is_live, + ) + + @is_live.setter + def is_live(self, value: bool): + self._is_live = value + + def _initialize_cabr_state(self): + enabled_track_types_bitfield = 0 # Audio+Video + if not self._video_format_selector: + enabled_track_types_bitfield = 1 # Audio only + self._video_format_selector = VideoSelector(display_name='video_ignore', discard_media=True) + + if self._caption_format_selector: + # SABR does not support caption-only or audio+captions only - can only get audio+video with captions + # If audio or video is not selected, the tracks will be initialized but marked as buffered. + enabled_track_types_bitfield = 7 + + # SABR does not support video-only, so we need to discard the audio track received. + # We need a selector as the server sometimes does not like it + # if we haven't initialized an audio format (e.g. livestreams). 
+ if not self._audio_format_selector: + self._audio_format_selector = AudioSelector(display_name='audio_ignore', discard_media=True) + + if not self._caption_format_selector: + self._caption_format_selector = CaptionSelector(display_name='caption_ignore', discard_media=True) + + self.selected_audio_format_ids = self._audio_format_selector.format_ids + self.selected_video_format_ids = self._video_format_selector.format_ids + self.selected_caption_format_ids = self._caption_format_selector.format_ids + + self.logger.debug(f'Starting playback at: {self.start_time_ms}ms') + self.client_abr_state = ClientAbrState( + player_time_ms=self.start_time_ms, + enabled_track_types_bitfield=enabled_track_types_bitfield) + + def match_format_selector(self, format_init_metadata): + for format_selector in (self._video_format_selector, self._audio_format_selector, self._caption_format_selector): + if not format_selector: + continue + if format_selector.match(format_id=format_init_metadata.format_id, mime_type=format_init_metadata.mime_type): + return format_selector + return None + + def process_media_header(self, media_header: MediaHeader) -> ProcessMediaHeaderResult: + if media_header.video_id and self.video_id and media_header.video_id != self.video_id: + raise SabrStreamError( + f'Received unexpected MediaHeader for video' + f' {media_header.video_id} (expecting {self.video_id})') + + if not media_header.format_id: + raise SabrStreamError(f'Format ID not found in MediaHeader (media_header={media_header})') + + # Guard. This should not happen, except if we don't clear partial segments + if media_header.header_id in self.partial_segments: + raise SabrStreamError(f'Header ID {media_header.header_id} already exists') + + result = ProcessMediaHeaderResult() + + initialized_format = self.initialized_formats.get(str(media_header.format_id)) + if not initialized_format: + self.logger.debug(f'Initialized format not found for {media_header.format_id}') + return result + + if media_header.compression: + # Unknown when this is used, but it is not supported currently + raise SabrStreamError(f'Compression not supported in MediaHeader (media_header={media_header})') + + sequence_number, is_init_segment = media_header.sequence_number, media_header.is_init_segment + if sequence_number is None and not media_header.is_init_segment: + raise SabrStreamError(f'Sequence number not found in MediaHeader (media_header={media_header})') + + initialized_format.sequence_lmt = media_header.sequence_lmt + + # Need to keep track of if we discard due to be consumed or not + # for processing down the line (MediaEnd) + consumed = False + discard = initialized_format.discard + + # Guard: Check if sequence number is within any existing consumed range + # The server should not send us any segments that are already consumed + # However, if retrying a request, we may get the same segment again + if not is_init_segment and any( + cr.start_sequence_number <= sequence_number <= cr.end_sequence_number + for cr in initialized_format.consumed_ranges + ): + self.logger.debug(f'{initialized_format.format_id} segment {sequence_number} already consumed, marking segment as consumed') + consumed = True + + # Validate that the segment is in order. + # Note: If the format is to be discarded, we do not care about the order + # and can expect uncommanded seeks as the consumer does not know about it. + # Note: previous segment should never be an init segment. 
+ previous_segment = initialized_format.current_segment + if ( + previous_segment and not is_init_segment + and not previous_segment.discard and not discard and not consumed + and sequence_number != previous_segment.sequence_number + 1 + ): + # Bail out as the segment is not in order when it is expected to be + raise MediaSegmentMismatchError( + expected_sequence_number=previous_segment.sequence_number + 1, + received_sequence_number=sequence_number, + format_id=media_header.format_id) + + if initialized_format.init_segment and is_init_segment: + self.logger.debug( + f'Init segment {sequence_number} already seen for format {initialized_format.format_id}, marking segment as consumed') + consumed = True + + time_range = media_header.time_range + start_ms = media_header.start_ms or (time_range and ticks_to_ms(time_range.start_ticks, time_range.timescale)) or 0 + + # Calculate duration of this segment + # For videos, either duration_ms or time_range should be present + # For live streams, calculate segment duration based on live metadata target segment duration + actual_duration_ms = ( + media_header.duration_ms + or (time_range and ticks_to_ms(time_range.duration_ticks, time_range.timescale))) + + estimated_duration_ms = None + if self.is_live: + # Underestimate the duration of the segment slightly as + # the real duration may be slightly shorter than the target duration. + estimated_duration_ms = (self.live_segment_target_duration_sec * 1000) - self.live_segment_target_duration_tolerance_ms + elif is_init_segment: + estimated_duration_ms = 0 + + duration_ms = actual_duration_ms or estimated_duration_ms + + estimated_content_length = None + if self.is_live and media_header.content_length is None and media_header.bitrate_bps is not None: + estimated_content_length = math.ceil(media_header.bitrate_bps * (estimated_duration_ms / 1000)) + + # Guard: Bail out if we cannot determine the duration, which we need to progress. + if duration_ms is None: + raise SabrStreamError(f'Cannot determine duration of segment {sequence_number} (media_header={media_header})') + + segment = Segment( + format_id=media_header.format_id, + is_init_segment=is_init_segment, + duration_ms=duration_ms, + start_data_range=media_header.start_data_range, + sequence_number=sequence_number, + content_length=media_header.content_length or estimated_content_length, + content_length_estimated=estimated_content_length is not None, + start_ms=start_ms, + initialized_format=initialized_format, + duration_estimated=not actual_duration_ms, + discard=discard or consumed, + consumed=consumed, + sequence_lmt=media_header.sequence_lmt, + ) + + self.partial_segments[media_header.header_id] = segment + + if not segment.discard: + result.sabr_part = MediaSegmentInitSabrPart( + format_selector=segment.initialized_format.format_selector, + format_id=segment.format_id, + player_time_ms=self.client_abr_state.player_time_ms, + sequence_number=segment.sequence_number, + total_segments=segment.initialized_format.total_segments, + duration_ms=segment.duration_ms, + start_bytes=segment.start_data_range, + start_time_ms=segment.start_ms, + is_init_segment=segment.is_init_segment, + content_length=segment.content_length, + content_length_estimated=segment.content_length_estimated, + ) + + self.logger.trace( + f'Initialized Media Header {media_header.header_id} for sequence {sequence_number}. 
Segment: {segment}') + + return result + + def process_media(self, header_id: int, content_length: int, data: io.BufferedIOBase) -> ProcessMediaResult: + result = ProcessMediaResult() + segment = self.partial_segments.get(header_id) + if not segment: + self.logger.debug(f'Header ID {header_id} not found') + return result + + segment_start_bytes = segment.received_data_length + segment.received_data_length += content_length + + if not segment.discard: + result.sabr_part = MediaSegmentDataSabrPart( + format_selector=segment.initialized_format.format_selector, + format_id=segment.format_id, + sequence_number=segment.sequence_number, + is_init_segment=segment.is_init_segment, + total_segments=segment.initialized_format.total_segments, + data=data.read(), + content_length=content_length, + segment_start_bytes=segment_start_bytes, + ) + + return result + + def process_media_end(self, header_id: int) -> ProcessMediaEndResult: + result = ProcessMediaEndResult() + segment = self.partial_segments.pop(header_id, None) + if not segment: + # Should only happen due to server issue, + # or we have an uninitialized format (which itself should not happen) + self.logger.warning(f'Received a MediaEnd for an unknown or already finished header ID {header_id}') + return result + + self.logger.trace( + f'MediaEnd for {segment.format_id} (sequence {segment.sequence_number}, data length = {segment.received_data_length})') + + if segment.content_length is not None and segment.received_data_length != segment.content_length: + if segment.content_length_estimated: + self.logger.trace( + f'Content length for {segment.format_id} (sequence {segment.sequence_number}) was estimated, ' + f'estimated {segment.content_length} bytes, got {segment.received_data_length} bytes') + else: + raise SabrStreamError( + f'Content length mismatch for {segment.format_id} (sequence {segment.sequence_number}): ' + f'expected {segment.content_length} bytes, got {segment.received_data_length} bytes', + ) + + # Only count received segments as new segments if they are not discarded (consumed) + # or it was part of a format that was discarded (but not consumed). + # The latter can happen if the format is to be discarded but was not marked as fully consumed. + if not segment.discard or (segment.initialized_format.discard and not segment.consumed): + result.is_new_segment = True + + # Return the segment here instead of during MEDIA part(s) because: + # 1. We can validate that we received the correct data length + # 2. In the case of a retry during segment media, the partial data is not sent to the consumer + if not segment.discard: + # This needs to be yielded AFTER we have processed the segment + # So the consumer can see the updated consumed ranges and use them for e.g. 
syncing between concurrent streams + result.sabr_part = MediaSegmentEndSabrPart( + format_selector=segment.initialized_format.format_selector, + format_id=segment.format_id, + sequence_number=segment.sequence_number, + is_init_segment=segment.is_init_segment, + total_segments=segment.initialized_format.total_segments, + ) + else: + self.logger.trace(f'Discarding media for {segment.initialized_format.format_id}') + + if segment.is_init_segment: + segment.initialized_format.init_segment = segment + # Do not create a consumed range for init segments + return result + + if segment.initialized_format.current_segment and self.is_live: + previous_segment = segment.initialized_format.current_segment + self.logger.trace( + f'Previous segment {previous_segment.sequence_number} for format {segment.format_id} ' + f'estimated duration difference from this segment ({segment.sequence_number}): {segment.start_ms - (previous_segment.start_ms + previous_segment.duration_ms)}ms') + + segment.initialized_format.current_segment = segment + + # Try to find a consumed range for this segment in sequence + consumed_range = next( + (cr for cr in segment.initialized_format.consumed_ranges if cr.end_sequence_number == segment.sequence_number - 1), + None, + ) + + if not consumed_range and any( + cr.start_sequence_number <= segment.sequence_number <= cr.end_sequence_number + for cr in segment.initialized_format.consumed_ranges + ): + # Segment is already consumed, do not create a new consumed range. It was probably discarded. + # This can be expected to happen in the case of video-only, where we discard the audio track (and mark it as entirely buffered) + # We still want to create/update consumed range for discarded media IF it is not already consumed + self.logger.debug(f'{segment.format_id} segment {segment.sequence_number} already consumed, not creating or updating consumed range (discard={segment.discard})') + return result + + if not consumed_range: + # Create a new consumed range starting from this segment + segment.initialized_format.consumed_ranges.append(ConsumedRange( + start_time_ms=segment.start_ms, + duration_ms=segment.duration_ms, + start_sequence_number=segment.sequence_number, + end_sequence_number=segment.sequence_number, + )) + self.logger.debug(f'Created new consumed range for {segment.initialized_format.format_id} {segment.initialized_format.consumed_ranges[-1]}') + return result + + # Update the existing consumed range to include this segment + consumed_range.end_sequence_number = segment.sequence_number + consumed_range.duration_ms = (segment.start_ms - consumed_range.start_time_ms) + segment.duration_ms + + # TODO: Conduct a seek on consumed ranges + + return result + + def process_live_metadata(self, live_metadata: LiveMetadata) -> ProcessLiveMetadataResult: + self.live_metadata = live_metadata + if self.live_metadata.head_sequence_time_ms: + self.total_duration_ms = self.live_metadata.head_sequence_time_ms + + # If we have a head sequence number, we need to update the total sequences for each initialized format + # For livestreams, it is not available in the format initialization metadata + if self.live_metadata.head_sequence_number: + for izf in self.initialized_formats.values(): + izf.total_segments = self.live_metadata.head_sequence_number + + result = ProcessLiveMetadataResult() + + # If the current player time is less than the min dvr time, simulate a server seek to the min dvr time. + # The server SHOULD send us a SABR_SEEK part in this case, but it does not always happen (e.g. 
ANDROID_VR) + # The server SHOULD NOT send us segments before the min dvr time, so we should assume that the player time is correct. + min_seekable_time_ms = ticks_to_ms(self.live_metadata.min_seekable_time_ticks, self.live_metadata.min_seekable_timescale) + if min_seekable_time_ms is not None and self.client_abr_state.player_time_ms < min_seekable_time_ms: + self.logger.debug(f'Player time {self.client_abr_state.player_time_ms} is less than min seekable time {min_seekable_time_ms}, simulating server seek') + self.client_abr_state.player_time_ms = min_seekable_time_ms + + for izf in self.initialized_formats.values(): + izf.current_segment = None # Clear the current segment as we expect segments to no longer be in order. + result.seek_sabr_parts.append(MediaSeekSabrPart( + reason=MediaSeekSabrPart.Reason.SERVER_SEEK, + format_id=izf.format_id, + format_selector=izf.format_selector, + )) + + return result + + def process_stream_protection_status(self, stream_protection_status: StreamProtectionStatus) -> ProcessStreamProtectionStatusResult: + self.stream_protection_status = stream_protection_status.status + status = stream_protection_status.status + po_token = self.po_token + + if status == StreamProtectionStatus.Status.OK: + result_status = ( + PoTokenStatusSabrPart.PoTokenStatus.OK if po_token + else PoTokenStatusSabrPart.PoTokenStatus.NOT_REQUIRED + ) + elif status == StreamProtectionStatus.Status.ATTESTATION_PENDING: + result_status = ( + PoTokenStatusSabrPart.PoTokenStatus.PENDING if po_token + else PoTokenStatusSabrPart.PoTokenStatus.PENDING_MISSING + ) + elif status == StreamProtectionStatus.Status.ATTESTATION_REQUIRED: + result_status = ( + PoTokenStatusSabrPart.PoTokenStatus.INVALID if po_token + else PoTokenStatusSabrPart.PoTokenStatus.MISSING + ) + else: + result_status = None + + sabr_part = PoTokenStatusSabrPart(status=result_status) if result_status is not None else None + return ProcessStreamProtectionStatusResult(sabr_part) + + def process_format_initialization_metadata(self, format_init_metadata: FormatInitializationMetadata) -> ProcessFormatInitializationMetadataResult: + result = ProcessFormatInitializationMetadataResult() + if str(format_init_metadata.format_id) in self.initialized_formats: + self.logger.trace(f'Format {format_init_metadata.format_id} already initialized') + return result + + if format_init_metadata.video_id and self.video_id and format_init_metadata.video_id != self.video_id: + raise SabrStreamError( + f'Received unexpected Format Initialization Metadata for video' + f' {format_init_metadata.video_id} (expecting {self.video_id})') + + format_selector = self.match_format_selector(format_init_metadata) + if not format_selector: + # Should not happen. If we ignored the format the server may refuse to send us any more data + raise SabrStreamError(f'Received format {format_init_metadata.format_id} but it does not match any format selector') + + # Guard: Check if the format selector is already in use by another initialized format. + # This can happen when the server changes the format to use (e.g. changing quality). + # + # Changing a format will require adding some logic to handle inactive formats. + # Given we only provide one FormatId currently, and this should not occur in this case, + # we will mark this as not currently supported and bail. + for izf in self.initialized_formats.values(): + if izf.format_selector is format_selector: + raise SabrStreamError('Server changed format. 
Changing formats is not currently supported') + + duration_ms = ticks_to_ms(format_init_metadata.duration_timescale, format_init_metadata.duration_ticks) + + total_segments = format_init_metadata.total_segments + if not total_segments and self.live_metadata and self.live_metadata.head_sequence_number: + total_segments = self.live_metadata.head_sequence_number + + initialized_format = InitializedFormat( + format_id=format_init_metadata.format_id, + duration_ms=duration_ms, + end_time_ms=format_init_metadata.end_time_ms, + mime_type=format_init_metadata.mime_type, + video_id=format_init_metadata.video_id, + format_selector=format_selector, + total_segments=total_segments, + discard=format_selector.discard_media, + ) + self.total_duration_ms = max(self.total_duration_ms or 0, format_init_metadata.end_time_ms or 0, duration_ms or 0) + + if initialized_format.discard: + # Mark the entire format as buffered into oblivion if we plan to discard all media. + # This stops the server sending us any more data for this format. + # Note: Using JS_MAX_SAFE_INTEGER but could use any maximum value as long as the server accepts it. + initialized_format.consumed_ranges = [ConsumedRange( + start_time_ms=0, + duration_ms=(2**53) - 1, + start_sequence_number=0, + end_sequence_number=(2**53) - 1, + )] + + self.initialized_formats[str(format_init_metadata.format_id)] = initialized_format + self.logger.debug(f'Initialized Format: {initialized_format}') + + if not initialized_format.discard: + result.sabr_part = FormatInitializedSabrPart( + format_id=format_init_metadata.format_id, + format_selector=format_selector, + ) + + return ProcessFormatInitializationMetadataResult(sabr_part=result.sabr_part) + + def process_next_request_policy(self, next_request_policy: NextRequestPolicy): + self.next_request_policy = next_request_policy + self.logger.trace(f'Registered new NextRequestPolicy: {self.next_request_policy}') + + def process_sabr_seek(self, sabr_seek: SabrSeek) -> ProcessSabrSeekResult: + seek_to = ticks_to_ms(sabr_seek.seek_time_ticks, sabr_seek.timescale) + if seek_to is None: + raise SabrStreamError(f'Server sent a SabrSeek part that is missing required seek data: {sabr_seek}') + self.logger.debug(f'Seeking to {seek_to}ms') + self.client_abr_state.player_time_ms = seek_to + + result = ProcessSabrSeekResult() + + # Clear latest segment of each initialized format + # as we expect them to no longer be in order. + for initialized_format in self.initialized_formats.values(): + initialized_format.current_segment = None + result.seek_sabr_parts.append(MediaSeekSabrPart( + reason=MediaSeekSabrPart.Reason.SERVER_SEEK, + format_id=initialized_format.format_id, + format_selector=initialized_format.format_selector, + )) + return result + + def process_sabr_context_update(self, sabr_ctx_update: SabrContextUpdate): + if not (sabr_ctx_update.type and sabr_ctx_update.value and sabr_ctx_update.write_policy): + self.logger.warning('Received an invalid SabrContextUpdate, ignoring') + return + + if ( + sabr_ctx_update.write_policy == SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_KEEP_EXISTING + and sabr_ctx_update.type in self.sabr_context_updates + ): + self.logger.debug( + 'Received a SABR Context Update with write_policy=KEEP_EXISTING' + 'matching an existing SABR Context Update. Ignoring update') + return + + self.logger.warning( + 'Received a SABR Context Update. YouTube is likely trying to force ads on the client. 
' + 'This may cause issues with playback.') + + self.sabr_context_updates[sabr_ctx_update.type] = sabr_ctx_update + if sabr_ctx_update.send_by_default is True: + self.sabr_contexts_to_send.add(sabr_ctx_update.type) + self.logger.debug(f'Registered SabrContextUpdate {sabr_ctx_update}') + + def process_sabr_context_sending_policy(self, sabr_ctx_sending_policy: SabrContextSendingPolicy): + for start_type in sabr_ctx_sending_policy.start_policy: + if start_type not in self.sabr_context_updates: + self.logger.debug(f'Server requested to enable SABR Context Update for type {start_type}') + self.sabr_contexts_to_send.add(start_type) + + for stop_type in sabr_ctx_sending_policy.stop_policy: + if stop_type in self.sabr_contexts_to_send: + self.logger.debug(f'Server requested to disable SABR Context Update for type {stop_type}') + self.sabr_contexts_to_send.remove(stop_type) + + for discard_type in sabr_ctx_sending_policy.discard_policy: + if discard_type in self.sabr_context_updates: + self.logger.debug(f'Server requested to discard SABR Context Update for type {discard_type}') + self.sabr_context_updates.pop(discard_type, None) + + +def build_vpabr_request(processor: SabrProcessor): + return VideoPlaybackAbrRequest( + client_abr_state=processor.client_abr_state, + selected_video_format_ids=processor.selected_video_format_ids, + selected_audio_format_ids=processor.selected_audio_format_ids, + selected_caption_format_ids=processor.selected_caption_format_ids, + initialized_format_ids=[ + initialized_format.format_id for initialized_format in processor.initialized_formats.values() + ], + video_playback_ustreamer_config=base64.urlsafe_b64decode(processor.video_playback_ustreamer_config), + streamer_context=StreamerContext( + po_token=base64.urlsafe_b64decode(processor.po_token) if processor.po_token is not None else None, + playback_cookie=processor.next_request_policy.playback_cookie if processor.next_request_policy is not None else None, + client_info=processor.client_info, + sabr_contexts=[ + SabrContext(context.type, context.value) + for context in processor.sabr_context_updates.values() + if context.type in processor.sabr_contexts_to_send + ], + unsent_sabr_contexts=[ + context_type for context_type in processor.sabr_contexts_to_send + if context_type not in processor.sabr_context_updates + ], + ), + buffered_ranges=[ + BufferedRange( + format_id=initialized_format.format_id, + start_segment_index=cr.start_sequence_number, + end_segment_index=cr.end_sequence_number, + start_time_ms=cr.start_time_ms, + duration_ms=cr.duration_ms, + time_range=TimeRange( + start_ticks=cr.start_time_ms, + duration_ticks=cr.duration_ms, + timescale=1000, + ), + ) for initialized_format in processor.initialized_formats.values() + for cr in initialized_format.consumed_ranges + ], + ) diff --git a/yt_dlp/extractor/youtube/_streaming/sabr/stream.py b/yt_dlp/extractor/youtube/_streaming/sabr/stream.py new file mode 100644 index 000000000..5e1096a9a --- /dev/null +++ b/yt_dlp/extractor/youtube/_streaming/sabr/stream.py @@ -0,0 +1,813 @@ +from __future__ import annotations + +import base64 +import dataclasses +import datetime as dt +import math +import time +import typing +import urllib.parse + +from yt_dlp.dependencies import protobug +from yt_dlp.extractor.youtube._proto import unknown_fields +from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy +from yt_dlp.extractor.youtube._proto.videostreaming import ( + FormatInitializationMetadata, + LiveMetadata, + MediaHeader, + ReloadPlayerResponse, 
+ SabrContextSendingPolicy, + SabrContextUpdate, + SabrError, + SabrRedirect, + SabrSeek, + StreamProtectionStatus, +) +from yt_dlp.networking import Request, Response +from yt_dlp.networking.exceptions import HTTPError, TransportError +from yt_dlp.utils import RetryManager, int_or_none, parse_qs, str_or_none, traverse_obj + +from .exceptions import MediaSegmentMismatchError, PoTokenError, SabrStreamConsumedError, SabrStreamError +from .models import AudioSelector, CaptionSelector, SabrLogger, VideoSelector +from .part import ( + MediaSeekSabrPart, + RefreshPlayerResponseSabrPart, +) +from .processor import SabrProcessor, build_vpabr_request +from .utils import broadcast_id_from_url, get_cr_chain, next_gvs_fallback_url +from ..ump import UMPDecoder, UMPPart, UMPPartId, read_varint + + +class SabrStream: + + """ + + A YouTube SABR (Server Adaptive Bit Rate) client implementation designed for downloading streams and videos. + + It presents an iterator (iter_parts) that yields the next available segments and other metadata. + + Parameters: + @param urlopen: A callable that takes a Request and returns a Response. Raises TransportError or HTTPError on failure. + @param logger: The logger. + @param server_abr_streaming_url: SABR streaming URL. + @param video_playback_ustreamer_config: The base64url encoded ustreamer config. + @param client_info: The Innertube client info. + @param audio_selection: The audio format selector to use for audio. + @param video_selection: The video format selector to use for video. + @param caption_selection: The caption format selector to use for captions. + @param live_segment_target_duration_sec: The target duration of live segments in seconds. + @param live_segment_target_duration_tolerance_ms: The tolerance to accept for estimated duration of live segment in milliseconds. + @param start_time_ms: The time in milliseconds to start playback from. + @param po_token: Initial GVS PO Token. + @param http_retries: The maximum number of times to retry a request before failing. + @param pot_retries: The maximum number of times to retry PO Token errors before failing. + @param host_fallback_threshold: The number of consecutive retries before falling back to the next GVS server. + @param max_empty_requests: The maximum number of consecutive requests with no new segments before giving up. + @param live_end_wait_sec: The number of seconds to wait after the last received segment before considering the live stream ended. + @param live_end_segment_tolerance: The number of segments before the live head segment at which the livestream is allowed to end. Defaults to 10. + @param post_live: Whether the live stream is in post-live mode. Used to determine how to handle the end of the stream. + @param video_id: The video ID of the YouTube video. Used for validating received data is for the correct video. + @param retry_sleep_func: A function to sleep between retries. If None, will not sleep between retries. + @param expiry_threshold_sec: The number of seconds before the GVS expiry to consider it expired. Defaults to 1 minute. 
+ """ + + # Used for debugging + _IGNORED_PARTS = ( + UMPPartId.REQUEST_IDENTIFIER, + UMPPartId.REQUEST_CANCELLATION_POLICY, + UMPPartId.PLAYBACK_START_POLICY, + UMPPartId.ALLOWED_CACHED_FORMATS, + UMPPartId.PAUSE_BW_SAMPLING_HINT, + UMPPartId.START_BW_SAMPLING_HINT, + UMPPartId.REQUEST_PIPELINING, + UMPPartId.SELECTABLE_FORMATS, + UMPPartId.PREWARM_CONNECTION, + ) + + @dataclasses.dataclass + class _NoSegmentsTracker: + consecutive_requests: int = 0 + timestamp_started: float | None = None + live_head_segment_started: int | None = None + + def reset(self): + self.consecutive_requests = 0 + self.timestamp_started = None + self.live_head_segment_started = None + + def increment(self, live_head_segment=None): + if self.consecutive_requests == 0: + self.timestamp_started = time.time() + self.live_head_segment_started = live_head_segment + self.consecutive_requests += 1 + + def __init__( + self, + urlopen: typing.Callable[[Request], Response], + logger: SabrLogger, + server_abr_streaming_url: str, + video_playback_ustreamer_config: str, + client_info: ClientInfo, + audio_selection: AudioSelector | None = None, + video_selection: VideoSelector | None = None, + caption_selection: CaptionSelector | None = None, + live_segment_target_duration_sec: int | None = None, + live_segment_target_duration_tolerance_ms: int | None = None, + start_time_ms: int | None = None, + po_token: str | None = None, + http_retries: int | None = None, + pot_retries: int | None = None, + host_fallback_threshold: int | None = None, + max_empty_requests: int | None = None, + live_end_wait_sec: int | None = None, + live_end_segment_tolerance: int | None = None, + post_live: bool = False, + video_id: str | None = None, + retry_sleep_func: int | None = None, + expiry_threshold_sec: int | None = None, + ): + + self.logger = logger + self._urlopen = urlopen + + self.processor = SabrProcessor( + logger=logger, + video_playback_ustreamer_config=video_playback_ustreamer_config, + client_info=client_info, + audio_selection=audio_selection, + video_selection=video_selection, + caption_selection=caption_selection, + live_segment_target_duration_sec=live_segment_target_duration_sec, + live_segment_target_duration_tolerance_ms=live_segment_target_duration_tolerance_ms, + start_time_ms=start_time_ms, + po_token=po_token, + live_end_wait_sec=live_end_wait_sec, + live_end_segment_tolerance=live_end_segment_tolerance, + post_live=post_live, + video_id=video_id, + ) + self.url = server_abr_streaming_url + self.http_retries = http_retries or 10 + self.pot_retries = pot_retries or 5 + self.host_fallback_threshold = host_fallback_threshold or 8 + self.max_empty_requests = max_empty_requests or 3 + self.expiry_threshold_sec = expiry_threshold_sec or 60 # 60 seconds + if self.expiry_threshold_sec <= 0: + raise ValueError('expiry_threshold_sec must be greater than 0') + if self.max_empty_requests <= 0: + raise ValueError('max_empty_requests must be greater than 0') + self.retry_sleep_func = retry_sleep_func + self._request_number = 0 + + # Whether we got any new (not consumed) segments in the request. 
+ self._received_new_segments = False + self._no_new_segments_tracker = SabrStream._NoSegmentsTracker() + self._sps_retry_manager: typing.Generator | None = None + self._current_sps_retry = None + self._http_retry_manager: typing.Generator | None = None + self._current_http_retry = None + self._unknown_part_types = set() + + # Whether the current request is a result of a retry + self._is_retry = False + + self._consumed = False + self._sq_mismatch_backtrack_count = 0 + self._sq_mismatch_forward_count = 0 + + def close(self): + self._consumed = True + + def __iter__(self): + return self.iter_parts() + + @property + def url(self): + return self._url + + @url.setter + def url(self, url): + self.logger.debug(f'New URL: {url}') + if hasattr(self, '_url') and ((bn := broadcast_id_from_url(url)) != (bc := broadcast_id_from_url(self.url))): + raise SabrStreamError(f'Broadcast ID changed from {bc} to {bn}. The download will need to be restarted.') + self._url = url + if str_or_none(parse_qs(url).get('source', [None])[0]) == 'yt_live_broadcast': + self.processor.is_live = True + + def iter_parts(self): + if self._consumed: + raise SabrStreamConsumedError('SABR stream has already been consumed') + + self._http_retry_manager = None + self._sps_retry_manager = None + + def report_retry(err, count, retries, fatal=True): + if count >= self.host_fallback_threshold: + self._process_fallback_server() + RetryManager.report_retry( + err, count, retries, info=self.logger.info, + warn=lambda msg: self.logger.warning(f'[sabr] Got error: {msg}'), + error=None if fatal else lambda msg: self.logger.warning(f'[sabr] Got error: {msg}'), + sleep_func=self.retry_sleep_func, + ) + + def report_sps_retry(err, count, retries, fatal=True): + RetryManager.report_retry( + err, count, retries, info=self.logger.info, + warn=lambda msg: self.logger.warning(f'[sabr] Got error: {msg}'), + error=None if fatal else lambda msg: self.logger.warning(f'[sabr] Got error: {msg}'), + sleep_func=self.retry_sleep_func, + ) + + while not self._consumed: + if self._http_retry_manager is None: + self._http_retry_manager = iter(RetryManager(self.http_retries, report_retry)) + + if self._sps_retry_manager is None: + self._sps_retry_manager = iter(RetryManager(self.pot_retries, report_sps_retry)) + + self._current_http_retry = next(self._http_retry_manager) + self._current_sps_retry = next(self._sps_retry_manager) + + self._log_state() + + yield from self._process_expiry() + vpabr = build_vpabr_request(self.processor) + payload = protobug.dumps(vpabr) + self.logger.trace(f'Ustreamer Config: {self.processor.video_playback_ustreamer_config}') + self.logger.trace(f'Sending SABR request: {vpabr}') + + response = None + try: + response = self._urlopen( + Request( + url=self.url, + method='POST', + data=payload, + query={'rn': self._request_number}, + headers={ + 'content-type': 'application/x-protobuf', + 'accept-encoding': 'identity', + 'accept': 'application/vnd.yt-ump', + }, + ), + ) + self._request_number += 1 + except TransportError as e: + self._current_http_retry.error = e + except HTTPError as e: + # retry on 5xx errors only + if 500 <= e.status < 600: + self._current_http_retry.error = e + else: + raise SabrStreamError(f'HTTP Error: {e.status} - {e.reason}') + + if response: + try: + yield from self._parse_ump_response(response) + except TransportError as e: + self._current_http_retry.error = e + + if not response.closed: + response.close() + + self._validate_response_integrity() + self._process_sps_retry() + + if not 
self._current_http_retry.error:
+                self._http_retry_manager = None
+
+            if not self._current_sps_retry.error:
+                self._sps_retry_manager = None
+
+            retry_next_request = bool(self._current_http_retry.error or self._current_sps_retry.error)
+
+            # We are expecting to stay in the same place for a retry
+            if not retry_next_request:
+                # Only increment the no-new-segments counter if we are not retrying
+                self._process_request_had_segments()
+
+                # Calculate and apply the next playback time to skip to
+                yield from self._prepare_next_playback_time()
+
+                # Request successfully processed, next request is not a retry
+                self._is_retry = False
+            else:
+                self._is_retry = True
+
+            self._received_new_segments = False
+
+        self._consumed = True
+
+    def _process_sps_retry(self):
+        error = PoTokenError(missing=not self.processor.po_token)
+
+        if self.processor.stream_protection_status == StreamProtectionStatus.Status.ATTESTATION_REQUIRED:
+            # Always start retrying immediately on ATTESTATION_REQUIRED
+            self._current_sps_retry.error = error
+            return
+
+        elif (
+            self.processor.stream_protection_status == StreamProtectionStatus.Status.ATTESTATION_PENDING
+            and self._no_new_segments_tracker.consecutive_requests >= self.max_empty_requests
+            and (not self.processor.is_live or self.processor.stream_protection_status or (
+                self.processor.live_metadata is not None
+                and self._no_new_segments_tracker.live_head_segment_started != self.processor.live_metadata.head_sequence_number)
+            )
+        ):
+            # Sometimes YouTube sends no data on ATTESTATION_PENDING, so in this case we need to count retries to fail on a PO Token error.
+            # We only start counting retries after max_empty_requests empty requests, in case the lack of data is intermittent and we only need to advance the player time.
+            # For livestreams that return no new segments, this could be due to the stream ending or ATTESTATION_PENDING.
+            # To differentiate the two, we check whether the live head segment has changed since we started receiving no new segments.
+            # xxx: not perfect detection, sometimes we get a new segment we can never fetch (partial)
+            self._current_sps_retry.error = error
+            return
+
+    def _process_request_had_segments(self):
+        if not self._received_new_segments:
+            self._no_new_segments_tracker.increment(
+                live_head_segment=self.processor.live_metadata.head_sequence_number if self.processor.live_metadata else None)
+            self.logger.trace(f'No new segments received in request {self._request_number}, count: {self._no_new_segments_tracker.consecutive_requests}')
+        else:
+            self._no_new_segments_tracker.reset()
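The empty-request bookkeeping above drives both the PO Token retry logic and the live end-of-stream detection. A minimal, runnable sketch of the same pattern, with simplified names (not part of the patch):

import time


class EmptyResponseTracker:
    """Counts consecutive requests that produced no new segments."""

    def __init__(self):
        self.count = 0
        self.started = None  # wall-clock time of the first empty request

    def record(self, got_new_segments):
        if got_new_segments:
            self.count, self.started = 0, None
            return
        if self.count == 0:
            self.started = time.time()
        self.count += 1

    def exhausted(self, max_empty, max_wait_sec):
        # Both thresholds must trip: enough consecutive empty requests
        # *and* enough wall-clock time since the first one.
        return (
            self.count > max_empty
            and self.started is not None
            and time.time() - self.started >= max_wait_sec
        )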
+    def _validate_response_integrity(self):
+        if not len(self.processor.partial_segments):
+            return
+
+        msg = 'Received partial segments: ' + ', '.join(
+            f'{seg.format_id}: {seg.sequence_number}'
+            for seg in self.processor.partial_segments.values()
+        )
+        if self.processor.is_live:
+            # In post live, sometimes we get a partial segment at the end of the stream that should be ignored.
+            # If this occurs mid-stream, other guards should prevent corruption.
+            if (
+                self.processor.live_metadata
+                # TODO: generalize
+                and self.processor.client_abr_state.player_time_ms >= (
+                    self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.processor.live_end_segment_tolerance))
+            ):
+                # Only log a warning if we are not near the head of a stream
+                self.logger.debug(msg)
+            else:
+                self.logger.warning(msg)
+        else:
+            # Should not happen for videos
+            self._current_http_retry.error = SabrStreamError(msg)
+
+        self.processor.partial_segments.clear()
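The CONSUMED_SEEK decision in _prepare_next_playback_time below hinges on get_cr_chain (added in sabr/utils.py further down): consumed ranges whose sequence numbers meet end-to-start are treated as a single chain. A condensed, runnable illustration, using a simplified stand-in for the real ConsumedRange model:

import dataclasses


@dataclasses.dataclass
class Range:  # simplified stand-in for ConsumedRange
    start_sequence_number: int
    end_sequence_number: int


def chain_from(start, ranges):
    # Walk the ranges in sequence order, extending the chain while each
    # next range begins exactly one segment after the previous one ends.
    chain = [start]
    for r in sorted(ranges, key=lambda r: r.start_sequence_number):
        if r.start_sequence_number == chain[-1].end_sequence_number + 1:
            chain.append(r)
        elif r.start_sequence_number > chain[-1].end_sequence_number + 1:
            break
    return chain


# Ranges 1-10 and 11-20 meet end-to-start, so they form one chain and a
# seek to the chain end (segment 20) is acceptable; 40-45 is disjoint.
ranges = [Range(1, 10), Range(11, 20), Range(40, 45)]
assert [r.end_sequence_number for r in chain_from(ranges[0], ranges)] == [10, 20]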
+    def _prepare_next_playback_time(self):
+        # TODO: refactor and cleanup this massive function
+        wait_seconds = 0
+        for izf in self.processor.initialized_formats.values():
+            if not izf.current_segment:
+                continue
+
+            # Guard: Check that the segment is not in multiple consumed ranges
+            # This should not happen, but if it does, we should bail
+            count = sum(
+                1 for cr in izf.consumed_ranges
+                if cr.start_sequence_number <= izf.current_segment.sequence_number <= cr.end_sequence_number
+            )
+
+            if count > 1:
+                raise SabrStreamError(f'Segment {izf.current_segment.sequence_number} for format {izf.format_id} in {count} consumed ranges')
+
+            # Check if there are two or more consumed ranges where the end of one lines up with the start of another.
+            # This could happen in the case of concurrent playback.
+            # In this case, a seek to the end of the other range is considered acceptable.
+            # Note: It is assumed a segment is only present in one consumed range - it should not be allowed in multiple (by process media header)
+            prev_consumed_range = next(
+                (cr for cr in izf.consumed_ranges if cr.end_sequence_number == izf.current_segment.sequence_number),
+                None,
+            )
+            # TODO: move to processor MEDIA_END
+            if prev_consumed_range and len(get_cr_chain(prev_consumed_range, izf.consumed_ranges)) >= 2:
+                self.logger.debug(f'Found two or more consumed ranges that line up, allowing a seek for format {izf.format_id}')
+                izf.current_segment = None
+                yield MediaSeekSabrPart(
+                    reason=MediaSeekSabrPart.Reason.CONSUMED_SEEK,
+                    format_id=izf.format_id,
+                    format_selector=izf.format_selector)
+
+        enabled_initialized_formats = [izf for izf in self.processor.initialized_formats.values() if not izf.discard]
+
+        # For each initialized format:
+        # 1. find the consumed range that matches player_time_ms.
+        # 2. find the current consumed range in sequence (in case multiple are joined together)
+        # For livestreams, we allow a tolerance for the segment duration as it is estimated. This tolerance should be less than half the segment duration.
+        cr_tolerance_ms = 0
+        if self.processor.is_live:
+            cr_tolerance_ms = self.processor.live_segment_target_duration_tolerance_ms
+
+        current_consumed_ranges = []
+        for izf in enabled_initialized_formats:
+            for cr in sorted(izf.consumed_ranges, key=lambda cr: cr.start_sequence_number):
+                if (cr.start_time_ms - cr_tolerance_ms) <= self.processor.client_abr_state.player_time_ms <= cr.start_time_ms + cr.duration_ms + (cr_tolerance_ms * 2):
+                    chain = get_cr_chain(cr, izf.consumed_ranges)
+                    current_consumed_ranges.append(chain[-1])
+                    # There should only be one chain for the current player_time_ms (including the tolerance)
+                    break
+
+        min_consumed_duration_ms = None
+
+        # Get the lowest consumed range end time out of all current consumed ranges.
+        if current_consumed_ranges:
+            lowest_izf_consumed_range = min(current_consumed_ranges, key=lambda cr: cr.start_time_ms + cr.duration_ms)
+            min_consumed_duration_ms = lowest_izf_consumed_range.start_time_ms + lowest_izf_consumed_range.duration_ms
+
+        if len(current_consumed_ranges) != len(enabled_initialized_formats) or min_consumed_duration_ms is None:
+            # Missing a consumed range for a format.
+            # In this case, consider player_time_ms to be our correct next time.
+            # May happen in the case of:
+            # 1. A format has not been initialized yet (can happen if the response read fails)
+            # or
+            # 2. A SABR_SEEK to a time outside both formats' consumed ranges,
+            #    where only ONE of the formats returns data after the SABR_SEEK in that request
+            if min_consumed_duration_ms is None:
+                min_consumed_duration_ms = self.processor.client_abr_state.player_time_ms
+            else:
+                min_consumed_duration_ms = min(min_consumed_duration_ms, self.processor.client_abr_state.player_time_ms)
+
+        # Usually provided by the server if there were no segments returned.
+        # We'll use this to calculate the time to wait before the next request (for live streams).
+        next_request_backoff_ms = (self.processor.next_request_policy and self.processor.next_request_policy.backoff_time_ms) or 0
+
+        request_player_time = self.processor.client_abr_state.player_time_ms
+        self.logger.trace(f'min consumed duration ms: {min_consumed_duration_ms}')
+        self.processor.client_abr_state.player_time_ms = min_consumed_duration_ms
+
+        # Check if the latest segment is the last one of each format (if data is available)
+        if (
+            not (self.processor.is_live and not self.processor.post_live)
+            and enabled_initialized_formats
+            and len(current_consumed_ranges) == len(enabled_initialized_formats)
+            and all(
+                (
+                    initialized_format.total_segments is not None
+                    # consumed ranges are not guaranteed to be in order
+                    and sorted(
+                        initialized_format.consumed_ranges,
+                        key=lambda cr: cr.end_sequence_number,
+                    )[-1].end_sequence_number >= initialized_format.total_segments
+                )
+                for initialized_format in enabled_initialized_formats
+            )
+        ):
+            self.logger.debug('Reached last expected segment for all enabled formats, assuming end of media')
+            self._consumed = True
+
+        # Check if we have exceeded the total duration of the media (if not live),
+        # or wait for the next segment (if live)
+        elif self.processor.total_duration_ms and (self.processor.client_abr_state.player_time_ms >= self.processor.total_duration_ms):
+            if self.processor.is_live:
+                self.logger.trace(f'setting player time ms ({self.processor.client_abr_state.player_time_ms}) to total duration ms ({self.processor.total_duration_ms})')
+                self.processor.client_abr_state.player_time_ms = self.processor.total_duration_ms
+                if (
+                    self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
+                    and not self._is_retry
+                    and self._no_new_segments_tracker.timestamp_started < time.time() - self.processor.live_end_wait_sec
+                ):
+                    self.logger.debug(f'No new segments received for at least {self.processor.live_end_wait_sec} seconds, assuming end of live stream')
+                    self._consumed = True
+                else:
+                    wait_seconds = max(next_request_backoff_ms / 1000, self.processor.live_segment_target_duration_sec)
+            else:
+                self.logger.debug(f'End of media (player time ms {self.processor.client_abr_state.player_time_ms} >= total duration ms {self.processor.total_duration_ms})')
+                self._consumed = True
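To make the live end-of-stream thresholds concrete, a toy evaluation of the check used above and below (illustrative numbers only; the real values come from max_empty_requests and live_end_wait_sec):

import time

max_empty_requests = 3
live_end_wait_sec = 60
consecutive_requests = 5              # fifth empty request in a row
timestamp_started = time.time() - 75  # first empty request was 75s ago

assume_ended = (
    consecutive_requests > max_empty_requests
    and timestamp_started < time.time() - live_end_wait_sec
)
assert assume_ended  # both thresholds exceeded, so stop waiting for segments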
+        # Handle receiving no new segments before the end of the video/stream.
+        # For videos, this should not happen; if we exceed max_empty_requests we raise an error.
+        # For livestreams, if we exceed max_empty_requests, and we don't have live_metadata,
+        # and have not received any data for a while, we can assume the stream has ended (as we don't know the head sequence number)
+        elif (
+            # Determine if we are receiving no segments because the live stream has ended.
+            # Allow some tolerance, as the head segment may not be retrievable.
+            self.processor.is_live and not self.processor.post_live
+            and (
+                getattr(self.processor.live_metadata, 'head_sequence_number', None) is None
+                or (
+                    enabled_initialized_formats
+                    and len(current_consumed_ranges) == len(enabled_initialized_formats)
+                    and all(
+                        (
+                            initialized_format.total_segments is not None
+                            and sorted(
+                                initialized_format.consumed_ranges,
+                                key=lambda cr: cr.end_sequence_number,
+                            )[-1].end_sequence_number
+                            >= initialized_format.total_segments - self.processor.live_end_segment_tolerance
+                        )
+                        for initialized_format in enabled_initialized_formats
+                    )
+                )
+                or self.processor.live_metadata.head_sequence_time_ms is None
+                or (
+                    # Sometimes we receive a partial segment at the end of the stream
+                    # and the server seeks us to the end of the current segment.
+                    # As our consumed range for this segment has an estimated end time,
+                    # this may be slightly off what the server seeks to.
+                    # This can cause the player time to be slightly outside the consumed range.
+                    #
+                    # Because of this, we should also check the player time against
+                    # the head segment time using the estimated segment duration.
+                    # xxx: consider also taking into account the max seekable timestamp
+                    request_player_time >= self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.processor.live_end_segment_tolerance)
+                )
+            )
+        ):
+            if (
+                not self._is_retry  # allow us to sleep on a retry
+                and self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
+                and self._no_new_segments_tracker.timestamp_started < time.time() - self.processor.live_end_wait_sec
+            ):
+                self.logger.debug(f'No new segments received for at least {self.processor.live_end_wait_sec} seconds; assuming end of live stream')
+                self._consumed = True
+            elif self._no_new_segments_tracker.consecutive_requests >= 1:
+                # Sometimes we can't get the head segment; instead we tend to sit just behind it for the duration of the livestream
+                wait_seconds = max(next_request_backoff_ms / 1000, self.processor.live_segment_target_duration_sec)
+        elif (
+            self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
+            and not self._is_retry
+        ):
+            raise SabrStreamError(f'No new segments received in {self._no_new_segments_tracker.consecutive_requests} consecutive requests')
+
+        elif (
+            not self.processor.is_live and next_request_backoff_ms
+            and self._no_new_segments_tracker.consecutive_requests >= 1
+            and any(t in self.processor.sabr_contexts_to_send for t in self.processor.sabr_context_updates)
+        ):
+            wait_seconds = math.ceil(next_request_backoff_ms / 1000)
+            self.logger.info(f'The server is requiring yt-dlp to wait {wait_seconds} seconds before continuing due to ad enforcement')
+
+        if wait_seconds:
+            self.logger.debug(f'Waiting {wait_seconds} seconds for next segment(s)')
+            time.sleep(wait_seconds)
+
+    def _parse_ump_response(self, response):
+        self._unknown_part_types.clear()
+        ump = UMPDecoder(response)
+        for part in ump.iter_parts():
+            if part.part_id == UMPPartId.MEDIA_HEADER:
+                yield from self._process_media_header(part)
+            elif part.part_id == UMPPartId.MEDIA:
+                yield from self._process_media(part)
+            elif part.part_id ==
UMPPartId.MEDIA_END: + yield from self._process_media_end(part) + elif part.part_id == UMPPartId.STREAM_PROTECTION_STATUS: + yield from self._process_stream_protection_status(part) + elif part.part_id == UMPPartId.SABR_REDIRECT: + self._process_sabr_redirect(part) + elif part.part_id == UMPPartId.FORMAT_INITIALIZATION_METADATA: + yield from self._process_format_initialization_metadata(part) + elif part.part_id == UMPPartId.NEXT_REQUEST_POLICY: + self._process_next_request_policy(part) + elif part.part_id == UMPPartId.LIVE_METADATA: + yield from self._process_live_metadata(part) + elif part.part_id == UMPPartId.SABR_SEEK: + yield from self._process_sabr_seek(part) + elif part.part_id == UMPPartId.SABR_ERROR: + self._process_sabr_error(part) + elif part.part_id == UMPPartId.SABR_CONTEXT_UPDATE: + self._process_sabr_context_update(part) + elif part.part_id == UMPPartId.SABR_CONTEXT_SENDING_POLICY: + self._process_sabr_context_sending_policy(part) + elif part.part_id == UMPPartId.RELOAD_PLAYER_RESPONSE: + yield from self._process_reload_player_response(part) + else: + if part.part_id not in self._IGNORED_PARTS: + self._unknown_part_types.add(part.part_id) + self._log_part(part, msg='Unhandled part type', data=part.data.read()) + + # Cancel request processing if we are going to retry + if self._current_sps_retry.error or self._current_http_retry.error: + self.logger.debug('Request processing cancelled') + return + + def _process_media_header(self, part: UMPPart): + media_header = protobug.load(part.data, MediaHeader) + self._log_part(part=part, protobug_obj=media_header) + + try: + result = self.processor.process_media_header(media_header) + if result.sabr_part: + yield result.sabr_part + except MediaSegmentMismatchError as e: + # For livestreams, the server may not know the exact segment for a given player time. + # For segments near stream head, it estimates using segment duration, which can cause off-by-one segment mismatches. + # If a segment is much longer or shorter than expected, the server may return a segment ahead or behind. + # In such cases, retry with an adjusted player time to resync. + if self.processor.is_live and e.received_sequence_number == e.expected_sequence_number - 1: + # The segment before the previous segment was possibly longer than expected. + # Move the player time forward to try to adjust for this. + self.processor.client_abr_state.player_time_ms += self.processor.live_segment_target_duration_tolerance_ms + self._sq_mismatch_forward_count += 1 + self._current_http_retry.error = e + return + elif self.processor.is_live and e.received_sequence_number == e.expected_sequence_number + 2: + # The previous segment was possibly shorter than expected + # Move the player time backwards to try to adjust for this. 
+ self.processor.client_abr_state.player_time_ms = max(0, self.processor.client_abr_state.player_time_ms - self.processor.live_segment_target_duration_tolerance_ms) + self._sq_mismatch_backtrack_count += 1 + self._current_http_retry.error = e + return + raise e + + def _process_media(self, part: UMPPart): + header_id = read_varint(part.data) + content_length = part.size - part.data.tell() + result = self.processor.process_media(header_id, content_length, part.data) + if result.sabr_part: + yield result.sabr_part + + def _process_media_end(self, part: UMPPart): + header_id = read_varint(part.data) + self._log_part(part, msg=f'Header ID: {header_id}') + + result = self.processor.process_media_end(header_id) + if result.is_new_segment: + self._received_new_segments = True + + if result.sabr_part: + yield result.sabr_part + + def _process_live_metadata(self, part: UMPPart): + live_metadata = protobug.load(part.data, LiveMetadata) + self._log_part(part, protobug_obj=live_metadata) + yield from self.processor.process_live_metadata(live_metadata).seek_sabr_parts + + def _process_stream_protection_status(self, part: UMPPart): + sps = protobug.load(part.data, StreamProtectionStatus) + self._log_part(part, msg=f'Status: {StreamProtectionStatus.Status(sps.status).name}', protobug_obj=sps) + result = self.processor.process_stream_protection_status(sps) + if result.sabr_part: + yield result.sabr_part + + def _process_sabr_redirect(self, part: UMPPart): + sabr_redirect = protobug.load(part.data, SabrRedirect) + self._log_part(part, protobug_obj=sabr_redirect) + if not sabr_redirect.redirect_url: + self.logger.warning('Server requested to redirect to an invalid URL') + return + self.url = sabr_redirect.redirect_url + + def _process_format_initialization_metadata(self, part: UMPPart): + fmt_init_metadata = protobug.load(part.data, FormatInitializationMetadata) + self._log_part(part, protobug_obj=fmt_init_metadata) + result = self.processor.process_format_initialization_metadata(fmt_init_metadata) + if result.sabr_part: + yield result.sabr_part + + def _process_next_request_policy(self, part: UMPPart): + next_request_policy = protobug.load(part.data, NextRequestPolicy) + self._log_part(part, protobug_obj=next_request_policy) + self.processor.process_next_request_policy(next_request_policy) + + def _process_sabr_seek(self, part: UMPPart): + sabr_seek = protobug.load(part.data, SabrSeek) + self._log_part(part, protobug_obj=sabr_seek) + yield from self.processor.process_sabr_seek(sabr_seek).seek_sabr_parts + + def _process_sabr_error(self, part: UMPPart): + sabr_error = protobug.load(part.data, SabrError) + self._log_part(part, protobug_obj=sabr_error) + self._current_http_retry.error = SabrStreamError(f'SABR Protocol Error: {sabr_error}') + + def _process_sabr_context_update(self, part: UMPPart): + sabr_ctx_update = protobug.load(part.data, SabrContextUpdate) + self._log_part(part, protobug_obj=sabr_ctx_update) + self.processor.process_sabr_context_update(sabr_ctx_update) + + def _process_sabr_context_sending_policy(self, part: UMPPart): + sabr_ctx_sending_policy = protobug.load(part.data, SabrContextSendingPolicy) + self._log_part(part, protobug_obj=sabr_ctx_sending_policy) + self.processor.process_sabr_context_sending_policy(sabr_ctx_sending_policy) + + def _process_reload_player_response(self, part: UMPPart): + reload_player_response = protobug.load(part.data, ReloadPlayerResponse) + self._log_part(part, protobug_obj=reload_player_response) + yield RefreshPlayerResponseSabrPart( + 
reason=RefreshPlayerResponseSabrPart.Reason.SABR_RELOAD_PLAYER_RESPONSE,
+            reload_playback_token=reload_player_response.reload_playback_params.token,
+        )
+
+    def _process_fallback_server(self):
+        # Attempt to fall back to another GVS host in case the current one fails
+        new_url = next_gvs_fallback_url(self.url)
+        if not new_url:
+            self.logger.debug('No more fallback hosts available')
+            return
+
+        self.logger.warning(f'Falling back to host {urllib.parse.urlparse(new_url).netloc}')
+        self.url = new_url
+
+    def _gvs_expiry(self):
+        return int_or_none(traverse_obj(parse_qs(self.url), ('expire', 0), get_all=False))
+
+    def _process_expiry(self):
+        expires_at = self._gvs_expiry()
+
+        if not expires_at:
+            self.logger.warning(
+                'No expiry timestamp found in SABR URL. Will not be able to refresh.', once=True)
+            return
+
+        if expires_at - self.expiry_threshold_sec >= time.time():
+            self.logger.trace(f'SABR url expires in {int(expires_at - time.time())} seconds')
+            return
+
+        self.logger.debug(
+            f'Requesting player response refresh as SABR URL is due to expire within {self.expiry_threshold_sec} seconds')
+        yield RefreshPlayerResponseSabrPart(reason=RefreshPlayerResponseSabrPart.Reason.SABR_URL_EXPIRY)
+
+    def _log_part(self, part: UMPPart, msg=None, protobug_obj=None, data=None):
+        if self.logger.log_level > self.logger.LogLevel.TRACE:
+            return
+        message = f'[{part.part_id.name}]: (Size {part.size})'
+        if protobug_obj:
+            message += f' Parsed: {protobug_obj}'
+            uf = list(unknown_fields(protobug_obj))
+            if uf:
+                message += f' (Unknown fields: {uf})'
+        if msg:
+            message += f' {msg}'
+        if data:
+            message += f' Data: {base64.b64encode(data).decode("utf-8")}'
+        self.logger.trace(message.strip())
+
+    def _log_state(self):
+        # TODO: refactor
+        if self.logger.log_level > self.logger.LogLevel.DEBUG:
+            return
+
+        if self.processor.is_live and self.processor.post_live:
+            live_message = f'post_live ({self.processor.live_segment_target_duration_sec}s)'
+        elif self.processor.is_live:
+            live_message = f'live ({self.processor.live_segment_target_duration_sec}s)'
+        else:
+            live_message = 'not_live'
+
+        if self.processor.is_live:
+            live_message += ' bid:' + str_or_none(broadcast_id_from_url(self.url))
+
+        consumed_ranges_message = (
+            ', '.join(
+                f'{izf.format_id.itag}:'
+                + ', '.join(
+                    f'{cf.start_sequence_number}-{cf.end_sequence_number} '
+                    f'({cf.start_time_ms}-'
+                    f'{cf.start_time_ms + cf.duration_ms})'
+                    for cf in izf.consumed_ranges
+                )
+                for izf in self.processor.initialized_formats.values()
+            ) or 'none'
+        )
+
+        izf_parts = []
+        for izf in self.processor.initialized_formats.values():
+            s = f'{izf.format_id.itag}'
+            if izf.discard:
+                s += 'd'
+            p = []
+            if izf.total_segments:
+                p.append(f'{izf.total_segments}')
+            if izf.sequence_lmt is not None:
+                p.append(f'lmt={izf.sequence_lmt}')
+            if p:
+                s += ('(' + ','.join(p) + ')')
+            izf_parts.append(s)
+
+        initialized_formats_message = ', '.join(izf_parts) or 'none'
+
+        unknown_part_message = ''
+        if self._unknown_part_types:
+            unknown_part_message = 'unkpt:' + ', '.join(part_type.name for part_type in self._unknown_part_types)
+
+        sabr_context_update_msg = ''
+        if self.processor.sabr_context_updates:
+            sabr_context_update_msg += 'cu:[' + ','.join(
+                f'{k}{"(n)" if k not in self.processor.sabr_contexts_to_send else ""}'
+                for k in self.processor.sabr_context_updates
+            ) + ']'
+
+        self.logger.debug(
+            "[SABR State] "
+            f"v:{self.processor.video_id or 'unknown'} "
+            f"c:{self.processor.client_info.client_name.name} "
+            f"t:{self.processor.client_abr_state.player_time_ms} "
+
f"td:{self.processor.total_duration_ms if self.processor.total_duration_ms else 'n/a'} " + f"h:{urllib.parse.urlparse(self.url).netloc} " + f"exp:{dt.timedelta(seconds=int(self._gvs_expiry() - time.time())) if self._gvs_expiry() else 'n/a'} " + f"rn:{self._request_number} rnns:{self._no_new_segments_tracker.consecutive_requests} " + f"lnns:{self._no_new_segments_tracker.live_head_segment_started or 'n/a'} " + f"mmb:{self._sq_mismatch_backtrack_count} mmf:{self._sq_mismatch_forward_count} " + f"pot:{'Y' if self.processor.po_token else 'N'} " + f"sps:{self.processor.stream_protection_status.name if self.processor.stream_protection_status else 'n/a'} " + f"{live_message} " + f"if:[{initialized_formats_message}] " + f"cr:[{consumed_ranges_message}] " + f"{sabr_context_update_msg} " + f"{unknown_part_message}", + ) diff --git a/yt_dlp/extractor/youtube/_streaming/sabr/utils.py b/yt_dlp/extractor/youtube/_streaming/sabr/utils.py new file mode 100644 index 000000000..ee0358b16 --- /dev/null +++ b/yt_dlp/extractor/youtube/_streaming/sabr/utils.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import contextlib +import math +import urllib.parse + +from yt_dlp.extractor.youtube._streaming.sabr.models import ConsumedRange +from yt_dlp.utils import int_or_none, orderedSet, parse_qs, str_or_none, update_url_query + + +def get_cr_chain(start_consumed_range: ConsumedRange, consumed_ranges: list[ConsumedRange]) -> list[ConsumedRange]: + # TODO: unit test + # Return the continuous consumed range chain starting from the given consumed range + # Note: It is assumed a segment is only present in one consumed range - it should not be allowed in multiple (by process media header) + chain = [start_consumed_range] + for cr in sorted(consumed_ranges, key=lambda r: r.start_sequence_number): + if cr.start_sequence_number == chain[-1].end_sequence_number + 1: + chain.append(cr) + elif cr.start_sequence_number > chain[-1].end_sequence_number + 1: + break + return chain + + +def next_gvs_fallback_url(gvs_url): + # TODO: unit test + qs = parse_qs(gvs_url) + gvs_url_parsed = urllib.parse.urlparse(gvs_url) + fvip = int_or_none(qs.get('fvip', [None])[0]) + mvi = int_or_none(qs.get('mvi', [None])[0]) + mn = str_or_none(qs.get('mn', [None])[0], default='').split(',') + fallback_count = int_or_none(qs.get('fallback_count', ['0'])[0], default=0) + + hosts = [] + + def build_host(current_host, f, m): + rr = current_host.startswith('rr') + if f is None or m is None: + return None + return ('rr' if rr else 'r') + str(f) + '---' + m + '.googlevideo.com' + + original_host = build_host(gvs_url_parsed.netloc, mvi, mn[0]) + + # Order of fallback hosts: + # 1. Fallback host in url (mn[1] + fvip) + # 2. Fallback hosts brute forced (this usually contains the original host) + for mn_entry in reversed(mn): + for fvip_entry in orderedSet([fvip, 1, 2, 3, 4, 5]): + fallback_host = build_host(gvs_url_parsed.netloc, fvip_entry, mn_entry) + if fallback_host and fallback_host not in hosts: + hosts.append(fallback_host) + + if not hosts or len(hosts) == 1: + return None + + # if first fallback, anchor to start of list so we start with the known fallback hosts + # Sometimes we may get a SABR_REDIRECT after a fallback, which gives a new host with new fallbacks. 
+ # In this case, the original host indicated by the url params would match the current host + current_host_index = -1 + if fallback_count > 0 and gvs_url_parsed.netloc != original_host: + with contextlib.suppress(ValueError): + current_host_index = hosts.index(gvs_url_parsed.netloc) + + def next_host(idx, h): + return h[(idx + 1) % len(h)] + + new_host = next_host(current_host_index + 1, hosts) + # If the current URL only has one fallback host, then the first fallback host is the same as the current host. + if new_host == gvs_url_parsed.netloc: + new_host = next_host(current_host_index + 2, hosts) + + # TODO: do not return new_host if it still matches the original host + return update_url_query( + gvs_url_parsed._replace(netloc=new_host).geturl(), {'fallback_count': fallback_count + 1}) + + +def ticks_to_ms(time_ticks: int, timescale: int): + if time_ticks is None or timescale is None: + return None + return math.ceil((time_ticks / timescale) * 1000) + + +def broadcast_id_from_url(url: str) -> str | None: + return str_or_none(parse_qs(url).get('id', [None])[0]) diff --git a/yt_dlp/extractor/youtube/_streaming/ump.py b/yt_dlp/extractor/youtube/_streaming/ump.py new file mode 100644 index 000000000..a980c6eea --- /dev/null +++ b/yt_dlp/extractor/youtube/_streaming/ump.py @@ -0,0 +1,116 @@ +import dataclasses +import enum +import io + + +class UMPPartId(enum.IntEnum): + UNKNOWN = -1 + ONESIE_HEADER = 10 + ONESIE_DATA = 11 + ONESIE_ENCRYPTED_MEDIA = 12 + MEDIA_HEADER = 20 + MEDIA = 21 + MEDIA_END = 22 + LIVE_METADATA = 31 + HOSTNAME_CHANGE_HINT = 32 + LIVE_METADATA_PROMISE = 33 + LIVE_METADATA_PROMISE_CANCELLATION = 34 + NEXT_REQUEST_POLICY = 35 + USTREAMER_VIDEO_AND_FORMAT_DATA = 36 + FORMAT_SELECTION_CONFIG = 37 + USTREAMER_SELECTED_MEDIA_STREAM = 38 + FORMAT_INITIALIZATION_METADATA = 42 + SABR_REDIRECT = 43 + SABR_ERROR = 44 + SABR_SEEK = 45 + RELOAD_PLAYER_RESPONSE = 46 + PLAYBACK_START_POLICY = 47 + ALLOWED_CACHED_FORMATS = 48 + START_BW_SAMPLING_HINT = 49 + PAUSE_BW_SAMPLING_HINT = 50 + SELECTABLE_FORMATS = 51 + REQUEST_IDENTIFIER = 52 + REQUEST_CANCELLATION_POLICY = 53 + ONESIE_PREFETCH_REJECTION = 54 + TIMELINE_CONTEXT = 55 + REQUEST_PIPELINING = 56 + SABR_CONTEXT_UPDATE = 57 + STREAM_PROTECTION_STATUS = 58 + SABR_CONTEXT_SENDING_POLICY = 59 + LAWNMOWER_POLICY = 60 + SABR_ACK = 61 + END_OF_TRACK = 62 + CACHE_LOAD_POLICY = 63 + LAWNMOWER_MESSAGING_POLICY = 64 + PREWARM_CONNECTION = 65 + PLAYBACK_DEBUG_INFO = 66 + SNACKBAR_MESSAGE = 67 + + @classmethod + def _missing_(cls, value): + return cls.UNKNOWN + + +@dataclasses.dataclass +class UMPPart: + part_id: UMPPartId + size: int + data: io.BufferedIOBase + + +class UMPDecoder: + def __init__(self, fp: io.BufferedIOBase): + self.fp = fp + + def iter_parts(self): + while not self.fp.closed: + part_type = read_varint(self.fp) + if part_type == -1 and not self.fp.closed: + self.fp.close() + + if self.fp.closed: + break + part_size = read_varint(self.fp) + if part_size == -1 and not self.fp.closed: + self.fp.close() + + if self.fp.closed: + raise EOFError('Unexpected EOF while reading part size') + + part_data = self.fp.read(part_size) + # In the future, we could allow streaming the part data. + # But we will need to ensure that each part is completely read before continuing. 
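For reference, a UMP varint stores its total size (1 to 5 bytes) in the leading bits of the first byte; the remaining value bytes follow in little-endian order. A hypothetical encoder that round-trips against the read_varint decoder below (only the decoder is part of this patch, and encodings are not unique):

import io

from yt_dlp.extractor.youtube._streaming.ump import read_varint


def write_varint(value):
    # The count of leading 1-bits in the first byte selects the size;
    # the low bits of the value share the first byte, the rest follow
    # little-endian. The decoder also accepts padded (longer) forms.
    if value < 1 << 7:
        return bytes([value])
    if value < 1 << 14:
        return bytes([0x80 | (value & 0x3F), value >> 6])
    if value < 1 << 21:
        return bytes([0xC0 | (value & 0x1F), (value >> 5) & 0xFF, value >> 13])
    if value < 1 << 28:
        return bytes([0xE0 | (value & 0x0F), (value >> 4) & 0xFF, (value >> 12) & 0xFF, value >> 20])
    return bytes([0xF0]) + value.to_bytes(4, 'little')  # assumes value < 2**32


assert read_varint(io.BytesIO(write_varint(300))) == 300
assert read_varint(io.BytesIO(write_varint(2_000_000))) == 2_000_000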
+ yield UMPPart(UMPPartId(part_type), part_size, io.BytesIO(part_data)) + + +def read_varint(fp: io.BufferedIOBase) -> int: + # https://web.archive.org/web/20250430054327/https://github.com/gsuberland/UMP_Format/blob/main/UMP_Format.md + # https://web.archive.org/web/20250429151021/https://github.com/davidzeng0/innertube/blob/main/googlevideo/ump.md + byte = fp.read(1) + if not byte: + # Expected EOF + return -1 + + prefix = byte[0] + size = varint_size(prefix) + result = 0 + shift = 0 + + if size != 5: + shift = 8 - size + mask = (1 << shift) - 1 + result |= prefix & mask + + for _ in range(1, size): + next_byte = fp.read(1) + if not next_byte: + return -1 + byte_int = next_byte[0] + result |= byte_int << shift + shift += 8 + + return result + + +def varint_size(byte: int) -> int: + return 1 if byte < 128 else 2 if byte < 192 else 3 if byte < 224 else 4 if byte < 240 else 5 diff --git a/yt_dlp/extractor/youtube/_video.py b/yt_dlp/extractor/youtube/_video.py index 55ebdce1b..efb71356c 100644 --- a/yt_dlp/extractor/youtube/_video.py +++ b/yt_dlp/extractor/youtube/_video.py @@ -26,6 +26,7 @@ from .pot._director import initialize_pot_director from .pot.provider import PoTokenContext, PoTokenRequest from ..openload import PhantomJSwrapper +from ...dependencies import protobug from ...jsinterp import JSInterpreter from ...networking.exceptions import HTTPError from ...utils import ( @@ -72,6 +73,7 @@ STREAMING_DATA_CLIENT_NAME = '__yt_dlp_client' STREAMING_DATA_INITIAL_PO_TOKEN = '__yt_dlp_po_token' +STREAMING_DATA_FETCH_PO_TOKEN = '__yt_dlp_fetch_po_token' STREAMING_DATA_FETCH_SUBS_PO_TOKEN = '__yt_dlp_fetch_subs_po_token' STREAMING_DATA_INNERTUBE_CONTEXT = '__yt_dlp_innertube_context' @@ -1824,7 +1826,8 @@ def _real_initialize(self): def _prepare_live_from_start_formats(self, formats, video_id, live_start_time, url, webpage_url, smuggled_data, is_live): lock = threading.Lock() start_time = time.time() - formats = [f for f in formats if f.get('is_from_start')] + # TODO: only include dash formats + formats = [f for f in formats if f.get('is_from_start') and f.get('protocol') != 'sabr'] def refetch_manifest(format_id, delay): nonlocal formats, start_time, is_live @@ -1836,7 +1839,7 @@ def refetch_manifest(format_id, delay): microformats = traverse_obj( prs, (..., 'microformat', 'playerMicroformatRenderer'), expected_type=dict) - _, live_status, _, formats, _ = self._list_formats(video_id, microformats, video_details, prs, player_url) + _, live_status, formats, _ = self._list_formats(video_id, microformats, video_details, prs, player_url) is_live = live_status == 'is_live' start_time = time.time() @@ -2812,16 +2815,25 @@ def _get_checkok_params(): return {'contentCheckOk': True, 'racyCheckOk': True} @classmethod - def _generate_player_context(cls, sts=None): - context = { + def _generate_player_context(cls, sts=None, reload_playback_token=None): + content_playback_context = { 'html5Preference': 'HTML5_PREF_WANTS', + 'isInlinePlaybackNoAd': True, } if sts is not None: - context['signatureTimestamp'] = sts + content_playback_context['signatureTimestamp'] = sts + + playback_context = { + 'contentPlaybackContext': content_playback_context, + } + + if reload_playback_token: + playback_context['reloadPlaybackContext'] = { + 'reloadPlaybackParams': {'token': reload_playback_token}, + } + return { - 'playbackContext': { - 'contentPlaybackContext': context, - }, + 'playbackContext': playback_context, **cls._get_checkok_params(), } @@ -2863,7 +2875,7 @@ def _get_config_po_token(self, client: str, context: 
_PoTokenContext): def fetch_po_token(self, client='web', context=_PoTokenContext.GVS, ytcfg=None, visitor_data=None, data_sync_id=None, session_index=None, player_url=None, video_id=None, webpage=None, - required=False, **kwargs): + required=False, bypass_cache=None, **kwargs): """ Fetch a PO Token for a given client and context. This function will validate required parameters for a given context and client. @@ -2879,6 +2891,7 @@ def fetch_po_token(self, client='web', context=_PoTokenContext.GVS, ytcfg=None, @param video_id: video ID. @param webpage: video webpage. @param required: Whether the PO Token is required (i.e. try to fetch unless policy is "never"). + @param bypass_cache: Whether to bypass the cache. @param kwargs: Additional arguments to pass down. May be more added in the future. @return: The fetched PO Token. None if it could not be fetched. """ @@ -2900,7 +2913,7 @@ def fetch_po_token(self, client='web', context=_PoTokenContext.GVS, ytcfg=None, return config_po_token = self._get_config_po_token(client, context) - if config_po_token: + if config_po_token and not bypass_cache: # GVS WebPO token is bound to data_sync_id / account Session ID when logged in. if player_url and context == _PoTokenContext.GVS and not data_sync_id and self.is_authenticated: self.report_warning( @@ -2927,6 +2940,7 @@ def fetch_po_token(self, client='web', context=_PoTokenContext.GVS, ytcfg=None, player_url=player_url, video_id=video_id, video_webpage=webpage, + bypass_cache=bypass_cache, required=required, **kwargs, ) @@ -2984,7 +2998,7 @@ def _fetch_po_token(self, client, **kwargs): request_verify_tls=not self.get_param('nocheckcertificate'), request_source_address=self.get_param('source_address'), - bypass_cache=False, + bypass_cache=kwargs.get('bypass_cache', False), ) return self._pot_director.get_po_token(pot_request) @@ -3005,7 +3019,7 @@ def _is_agegated(player_response): def _is_unplayable(player_response): return traverse_obj(player_response, ('playabilityStatus', 'status')) == 'UNPLAYABLE' - def _extract_player_response(self, client, video_id, master_ytcfg, player_ytcfg, player_url, initial_pr, visitor_data, data_sync_id, po_token): + def _extract_player_response(self, client, video_id, master_ytcfg, player_ytcfg, player_url, initial_pr, visitor_data, data_sync_id, po_token, reload_playback_token): headers = self.generate_api_headers( ytcfg=player_ytcfg, default_client=client, @@ -3034,7 +3048,8 @@ def _extract_player_response(self, client, video_id, master_ytcfg, player_ytcfg, yt_query['serviceIntegrityDimensions'] = {'poToken': po_token} sts = self._extract_signature_timestamp(video_id, player_url, master_ytcfg, fatal=False) if player_url else None - yt_query.update(self._generate_player_context(sts)) + yt_query.update(self._generate_player_context(sts, reload_playback_token)) + return self._extract_response( item_id=video_id, ep='player', query=yt_query, ytcfg=player_ytcfg, headers=headers, fatal=True, @@ -3087,7 +3102,7 @@ def _invalid_player_response(self, pr, video_id): if (pr_id := traverse_obj(pr, ('videoDetails', 'videoId'))) != video_id: return pr_id - def _extract_player_responses(self, clients, video_id, webpage, master_ytcfg, smuggled_data): + def _extract_player_responses(self, clients, video_id, webpage, master_ytcfg, reload_playback_token): initial_pr = None if webpage: initial_pr = self._search_json( @@ -3136,7 +3151,7 @@ def append_client(*client_names): player_url = self._download_player_url(video_id) tried_iframe_fallback = True - pr = initial_pr if client == 'web' else 
None + pr = initial_pr if client == 'web' and not reload_playback_token else None visitor_data = visitor_data or self._extract_visitor_data(master_ytcfg, initial_pr, player_ytcfg) data_sync_id = data_sync_id or self._extract_data_sync_id(master_ytcfg, initial_pr, player_ytcfg) @@ -3156,8 +3171,13 @@ def append_client(*client_names): player_po_token = None if pr else self.fetch_po_token( context=_PoTokenContext.PLAYER, **fetch_po_token_args) - gvs_po_token = self.fetch_po_token( - context=_PoTokenContext.GVS, **fetch_po_token_args) + fetch_gvs_po_token_func = functools.partial( + self.fetch_po_token, + context=_PoTokenContext.GVS, + **fetch_po_token_args, + ) + + gvs_po_token = fetch_gvs_po_token_func() fetch_subs_po_token_func = functools.partial( self.fetch_po_token, @@ -3200,7 +3220,8 @@ def append_client(*client_names): initial_pr=initial_pr, visitor_data=visitor_data, data_sync_id=data_sync_id, - po_token=player_po_token) + po_token=player_po_token, + reload_playback_token=reload_playback_token) except ExtractorError as e: self.report_warning(e) continue @@ -3215,9 +3236,12 @@ def append_client(*client_names): sd[STREAMING_DATA_INITIAL_PO_TOKEN] = gvs_po_token sd[STREAMING_DATA_INNERTUBE_CONTEXT] = innertube_context sd[STREAMING_DATA_FETCH_SUBS_PO_TOKEN] = fetch_subs_po_token_func + sd[STREAMING_DATA_FETCH_PO_TOKEN] = fetch_gvs_po_token_func for f in traverse_obj(sd, (('formats', 'adaptiveFormats'), ..., {dict})): f[STREAMING_DATA_CLIENT_NAME] = client f[STREAMING_DATA_INITIAL_PO_TOKEN] = gvs_po_token + f[STREAMING_DATA_INNERTUBE_CONTEXT] = innertube_context + f[STREAMING_DATA_FETCH_PO_TOKEN] = fetch_gvs_po_token_func if deprioritize_pr: deprioritized_prs.append(pr) else: @@ -3295,12 +3319,54 @@ def _report_pot_subtitles_skipped(self, video_id, client_name, msg=None): else: self.report_warning(msg, only_once=True) - def _extract_formats_and_subtitles(self, streaming_data, video_id, player_url, live_status, duration): + def _process_n_param(self, gvs_url, video_id, player_url, proto='https'): + query = parse_qs(gvs_url) + if query.get('n'): + try: + decrypt_nsig = self._cached(self._decrypt_nsig, 'nsig', query['n'][0]) + return update_url_query(gvs_url, { + 'n': decrypt_nsig(query['n'][0], video_id, player_url), + }) + except ExtractorError as e: + if player_url: + self.report_warning( + f'nsig extraction failed: Some {proto} formats may be missing\n' + f' n = {query["n"][0]} ; player = {player_url}\n' + f' {bug_reports_message(before="")}', + video_id=video_id, only_once=True) + self.write_debug(e, only_once=True) + else: + self.report_warning( + f'Cannot decrypt nsig without player_url: Some {proto} formats may be missing', + video_id=video_id, only_once=True) + return None + return gvs_url + + def _reload_sabr_config(self, video_id, client_name, reload_playback_token): + # xxx: may also update client info? 
+ url = 'https://www.youtube.com/watch?v=' + video_id + _, _, prs, player_url = self._download_player_responses(url, {}, video_id, url, reload_playback_token) + video_details = traverse_obj(prs, (..., 'videoDetails'), expected_type=dict) + microformats = traverse_obj( + prs, (..., 'microformat', 'playerMicroformatRenderer'), + expected_type=dict) + _, live_status, formats, _ = self._list_formats(video_id, microformats, video_details, prs, player_url) + + for f in formats: + if f.get('protocol') == 'sabr': + sabr_config = f['_sabr_config'] + if sabr_config['client_name'] == client_name: + return f['url'], sabr_config['video_playback_ustreamer_config'] + + raise ExtractorError('No SABR formats found', expected=True) + + def _extract_formats_and_subtitles(self, video_id, player_responses, player_url, live_status, duration): CHUNK_SIZE = 10 << 20 PREFERRED_LANG_VALUE = 10 original_language = None itags, stream_ids = collections.defaultdict(set), [] itag_qualities, res_qualities = {}, {0: None} + subtitles = {} q = qualities([ # Normally tiny is the smallest video-only formats. But # audio-only formats with unknown quality may get tagged as tiny @@ -3308,7 +3374,6 @@ def _extract_formats_and_subtitles(self, streaming_data, video_id, player_url, l 'audio_quality_ultralow', 'audio_quality_low', 'audio_quality_medium', 'audio_quality_high', # Audio only formats 'small', 'medium', 'large', 'hd720', 'hd1080', 'hd1440', 'hd2160', 'hd2880', 'highres', ]) - streaming_formats = traverse_obj(streaming_data, (..., ('formats', 'adaptiveFormats'), ...)) format_types = self._configuration_arg('formats') all_formats = 'duplicate' in format_types if self._configuration_arg('include_duplicate_formats'): @@ -3323,271 +3388,331 @@ def build_fragments(f): }), } for range_start in range(0, f['filesize'], CHUNK_SIZE)) - for fmt in streaming_formats: - client_name = fmt[STREAMING_DATA_CLIENT_NAME] - if fmt.get('targetDurationSec'): + for pr in player_responses: + streaming_data = traverse_obj(pr, 'streamingData') + if not streaming_data: continue + client_name = streaming_data.get(STREAMING_DATA_CLIENT_NAME) + po_token = streaming_data.get(STREAMING_DATA_INITIAL_PO_TOKEN) + fetch_po_token = streaming_data.get(STREAMING_DATA_FETCH_PO_TOKEN) + innertube_context = streaming_data.get(STREAMING_DATA_INNERTUBE_CONTEXT) + streaming_formats = traverse_obj(streaming_data, (('formats', 'adaptiveFormats'), ...)) - itag = str_or_none(fmt.get('itag')) - audio_track = fmt.get('audioTrack') or {} - stream_id = (itag, audio_track.get('id'), fmt.get('isDrc')) - if not all_formats: - if stream_id in stream_ids: - continue + def get_stream_id(fmt_stream): + return str_or_none(fmt_stream.get('itag')), traverse_obj(fmt_stream, 'audioTrack', 'id'), fmt_stream.get('isDrc') - quality = fmt.get('quality') - height = int_or_none(fmt.get('height')) - if quality == 'tiny' or not quality: - quality = fmt.get('audioQuality', '').lower() or quality - # The 3gp format (17) in android client has a quality of "small", - # but is actually worse than other formats - if itag == '17': - quality = 'tiny' - if quality: - if itag: - itag_qualities[itag] = quality - if height: - res_qualities[height] = quality + def process_format_stream(fmt_stream, proto): + nonlocal itag_qualities, res_qualities, original_language + itag = str_or_none(fmt_stream.get('itag')) + audio_track = fmt_stream.get('audioTrack') or {} + quality = fmt_stream.get('quality') + height = int_or_none(fmt_stream.get('height')) + if quality == 'tiny' or not quality: + quality = 
fmt_stream.get('audioQuality', '').lower() or quality + # The 3gp format (17) in android client has a quality of "small", + # but is actually worse than other formats + if itag == '17': + quality = 'tiny' + if quality: + if itag: + itag_qualities[itag] = quality + if height: + res_qualities[height] = quality - display_name = audio_track.get('displayName') or '' - is_original = 'original' in display_name.lower() - is_descriptive = 'descriptive' in display_name.lower() - is_default = audio_track.get('audioIsDefault') - language_code = audio_track.get('id', '').split('.')[0] - if language_code and (is_original or (is_default and not original_language)): - original_language = language_code + display_name = audio_track.get('displayName') or '' + is_original = 'original' in display_name.lower() + is_descriptive = 'descriptive' in display_name.lower() + is_default = audio_track.get('audioIsDefault') + language_code = audio_track.get('id', '').split('.')[0] + if language_code and (is_original or (is_default and not original_language)): + original_language = language_code - has_drm = bool(fmt.get('drmFamilies')) + has_drm = bool(fmt_stream.get('drmFamilies')) - # FORMAT_STREAM_TYPE_OTF(otf=1) requires downloading the init fragment - # (adding `&sq=0` to the URL) and parsing emsg box to determine the - # number of fragment that would subsequently requested with (`&sq=N`) - if fmt.get('type') == 'FORMAT_STREAM_TYPE_OTF' and not has_drm: - continue - - if has_drm: - msg = f'Some {client_name} client https formats have been skipped as they are DRM protected. ' - if client_name == 'tv': - msg += ( - f'{"Your account" if self.is_authenticated else "The current session"} may have ' - f'an experiment that applies DRM to all videos on the tv client. ' - f'See https://github.com/yt-dlp/yt-dlp/issues/12563 for more details.' - ) - self.report_warning(msg, video_id, only_once=True) - - fmt_url = fmt.get('url') - if not fmt_url: - sc = urllib.parse.parse_qs(fmt.get('signatureCipher')) - fmt_url = url_or_none(try_get(sc, lambda x: x['url'][0])) - encrypted_sig = try_get(sc, lambda x: x['s'][0]) - if not all((sc, fmt_url, player_url, encrypted_sig)): - msg = f'Some {client_name} client https formats have been skipped as they are missing a url. ' - if client_name == 'web': - msg += 'YouTube is forcing SABR streaming for this client. ' - else: + if has_drm: + msg = f'Some {client_name} client {proto} formats have been skipped as they are DRM protected. ' + if client_name == 'tv': msg += ( - f'YouTube may have enabled the SABR-only or Server-Side Ad Placement experiment for ' - f'{"your account" if self.is_authenticated else "the current session"}. ' + f'{"Your account" if self.is_authenticated else "The current session"} may have ' + f'an experiment that applies DRM to all videos on the tv client. ' + f'See https://github.com/yt-dlp/yt-dlp/issues/12563 for more details.' ) - msg += 'See https://github.com/yt-dlp/yt-dlp/issues/12482 for more details' self.report_warning(msg, video_id, only_once=True) - continue - try: - fmt_url += '&{}={}'.format( - traverse_obj(sc, ('sp', -1)) or 'signature', - self._decrypt_signature(encrypted_sig, video_id, player_url), - ) - except ExtractorError as e: + + tbr = float_or_none(fmt_stream.get('averageBitrate') or fmt_stream.get('bitrate'), 1000) + format_duration = traverse_obj(fmt_stream, ('approxDurationMs', {float_or_none(scale=1000)})) + # Some formats may have much smaller duration than others (possibly damaged during encoding) + # E.g. 
2-nOtRESiUc Ref: https://github.com/yt-dlp/yt-dlp/issues/2823 + # Make sure to avoid false positives with small duration differences. + # E.g. __2ABJjxzNo, ySuUZEjARPY + is_damaged = try_call(lambda: format_duration < duration // 2) + if is_damaged: self.report_warning( - f'Signature extraction failed: Some formats may be missing\n' - f' player = {player_url}\n' - f' {bug_reports_message(before="")}', - video_id=video_id, only_once=True) - self.write_debug( - f'{video_id}: Signature extraction failure info:\n' - f' encrypted sig = {encrypted_sig}\n' - f' player = {player_url}') - self.write_debug(e, only_once=True) - continue + f'Some {client_name} client {proto} formats are possibly damaged. They will be deprioritized', video_id, only_once=True) - query = parse_qs(fmt_url) - if query.get('n'): - try: - decrypt_nsig = self._cached(self._decrypt_nsig, 'nsig', query['n'][0]) - fmt_url = update_url_query(fmt_url, { - 'n': decrypt_nsig(query['n'][0], video_id, player_url), - }) - except ExtractorError as e: - if player_url: - self.report_warning( - f'nsig extraction failed: Some formats may be missing\n' - f' n = {query["n"][0]} ; player = {player_url}\n' - f' {bug_reports_message(before="")}', - video_id=video_id, only_once=True) - self.write_debug(e, only_once=True) - else: - self.report_warning( - 'Cannot decrypt nsig without player_url: Some formats may be missing', - video_id=video_id, only_once=True) - continue + # Clients that require PO Token return videoplayback URLs that may return 403 + require_po_token = ( + not po_token + and _PoTokenContext.GVS in self._get_default_ytcfg(client_name)['PO_TOKEN_REQUIRED_CONTEXTS'] + and itag not in ['18']) # these formats do not require PO Token - tbr = float_or_none(fmt.get('averageBitrate') or fmt.get('bitrate'), 1000) - format_duration = traverse_obj(fmt, ('approxDurationMs', {float_or_none(scale=1000)})) - # Some formats may have much smaller duration than others (possibly damaged during encoding) - # E.g. 2-nOtRESiUc Ref: https://github.com/yt-dlp/yt-dlp/issues/2823 - # Make sure to avoid false positives with small duration differences. - # E.g. __2ABJjxzNo, ySuUZEjARPY - is_damaged = try_call(lambda: format_duration < duration // 2) - if is_damaged: - self.report_warning( - 'Some formats are possibly damaged. 
They will be deprioritized', video_id, only_once=True) - - po_token = fmt.get(STREAMING_DATA_INITIAL_PO_TOKEN) - - if po_token: - fmt_url = update_url_query(fmt_url, {'pot': po_token}) - - # Clients that require PO Token return videoplayback URLs that may return 403 - require_po_token = ( - not po_token - and _PoTokenContext.GVS in self._get_default_ytcfg(client_name)['PO_TOKEN_REQUIRED_CONTEXTS'] - and itag not in ['18']) # these formats do not require PO Token - - if require_po_token and 'missing_pot' not in self._configuration_arg('formats'): - self._report_pot_format_skipped(video_id, client_name, 'https') - continue - - name = fmt.get('qualityLabel') or quality.replace('audio_quality_', '') or '' - fps = int_or_none(fmt.get('fps')) or 0 - dct = { - 'asr': int_or_none(fmt.get('audioSampleRate')), - 'filesize': int_or_none(fmt.get('contentLength')), - 'format_id': f'{itag}{"-drc" if fmt.get("isDrc") else ""}', - 'format_note': join_nonempty( - join_nonempty(display_name, is_default and ' (default)', delim=''), - name, fmt.get('isDrc') and 'DRC', - try_get(fmt, lambda x: x['projectionType'].replace('RECTANGULAR', '').lower()), - try_get(fmt, lambda x: x['spatialAudioType'].replace('SPATIAL_AUDIO_TYPE_', '').lower()), - is_damaged and 'DAMAGED', require_po_token and 'MISSING POT', - (self.get_param('verbose') or all_formats) and short_client_name(client_name), - delim=', '), - # Format 22 is likely to be damaged. See https://github.com/yt-dlp/yt-dlp/issues/3372 - 'source_preference': (-5 if itag == '22' else -1) + (100 if 'Premium' in name else 0), - 'fps': fps if fps > 1 else None, # For some formats, fps is wrongly returned as 1 - 'audio_channels': fmt.get('audioChannels'), - 'height': height, - 'quality': q(quality) - bool(fmt.get('isDrc')) / 2, - 'has_drm': has_drm, - 'tbr': tbr, - 'filesize_approx': filesize_from_tbr(tbr, format_duration), - 'url': fmt_url, - 'width': int_or_none(fmt.get('width')), - 'language': join_nonempty(language_code, 'desc' if is_descriptive else '') or None, - 'language_preference': PREFERRED_LANG_VALUE if is_original else 5 if is_default else -10 if is_descriptive else -1, - # Strictly de-prioritize damaged and 3gp formats - 'preference': -10 if is_damaged else -2 if itag == '17' else None, - } - mime_mobj = re.match( - r'((?:[^/]+)/(?:[^;]+))(?:;\s*codecs="([^"]+)")?', fmt.get('mimeType') or '') - if mime_mobj: - dct['ext'] = mimetype2ext(mime_mobj.group(1)) - dct.update(parse_codecs(mime_mobj.group(2))) - if itag: - itags[itag].add(('https', dct.get('language'))) - stream_ids.append(stream_id) - single_stream = 'none' in (dct.get('acodec'), dct.get('vcodec')) - if single_stream and dct.get('ext'): - dct['container'] = dct['ext'] + '_dash' - - if (all_formats or 'dashy' in format_types) and dct['filesize']: - yield { - **dct, - 'format_id': f'{dct["format_id"]}-dashy' if all_formats else dct['format_id'], - 'protocol': 'http_dash_segments', - 'fragments': build_fragments(dct), - } - if all_formats or 'dashy' not in format_types: - dct['downloader_options'] = {'http_chunk_size': CHUNK_SIZE} - yield dct - - needs_live_processing = self._needs_live_processing(live_status, duration) - skip_bad_formats = 'incomplete' not in format_types - if self._configuration_arg('include_incomplete_formats'): - skip_bad_formats = False - self._downloader.deprecated_feature('[youtube] include_incomplete_formats extractor argument is deprecated. 
' - 'Use formats=incomplete extractor argument instead') - - skip_manifests = set(self._configuration_arg('skip')) - if (not self.get_param('youtube_include_hls_manifest', True) - or needs_live_processing == 'is_live' # These will be filtered out by YoutubeDL anyway - or (needs_live_processing and skip_bad_formats)): - skip_manifests.add('hls') - - if not self.get_param('youtube_include_dash_manifest', True): - skip_manifests.add('dash') - if self._configuration_arg('include_live_dash'): - self._downloader.deprecated_feature('[youtube] include_live_dash extractor argument is deprecated. ' - 'Use formats=incomplete extractor argument instead') - elif skip_bad_formats and live_status == 'is_live' and needs_live_processing != 'is_live': - skip_manifests.add('dash') - - def process_manifest_format(f, proto, client_name, itag, po_token): - key = (proto, f.get('language')) - if not all_formats and key in itags[itag]: - return False - - if f.get('source_preference') is None: - f['source_preference'] = -1 - - # Clients that require PO Token return videoplayback URLs that may return 403 - # hls does not currently require PO Token - if ( - not po_token - and _PoTokenContext.GVS in self._get_default_ytcfg(client_name)['PO_TOKEN_REQUIRED_CONTEXTS'] - and proto != 'hls' - ): - if 'missing_pot' not in self._configuration_arg('formats'): + if require_po_token and 'missing_pot' not in self._configuration_arg('formats'): self._report_pot_format_skipped(video_id, client_name, proto) + return None + + name = fmt_stream.get('qualityLabel') or quality.replace('audio_quality_', '') or '' + fps = int_or_none(fmt_stream.get('fps')) or 0 + dct = { + 'asr': int_or_none(fmt_stream.get('audioSampleRate')), + 'filesize': int_or_none(fmt_stream.get('contentLength')), + 'format_id': f'{itag}{"-drc" if fmt_stream.get("isDrc") else ""}', + 'format_note': join_nonempty( + join_nonempty(display_name, is_default and ' (default)', delim=''), + name, fmt_stream.get('isDrc') and 'DRC', + try_get(fmt_stream, lambda x: x['projectionType'].replace('RECTANGULAR', '').lower()), + try_get(fmt_stream, lambda x: x['spatialAudioType'].replace('SPATIAL_AUDIO_TYPE_', '').lower()), + is_damaged and 'DAMAGED', require_po_token and 'MISSING POT', + (self.get_param('verbose') or all_formats) and short_client_name(client_name), + delim=', '), + # Format 22 is likely to be damaged. 
See https://github.com/yt-dlp/yt-dlp/issues/3372 + 'source_preference': (-5 if itag == '22' else -1) + (100 if 'Premium' in name else 0), + 'fps': fps if fps > 1 else None, # For some formats, fps is wrongly returned as 1 + 'audio_channels': fmt_stream.get('audioChannels'), + 'height': height, + 'quality': q(quality) - bool(fmt_stream.get('isDrc')) / 2, + 'has_drm': has_drm, + 'tbr': tbr, + 'filesize_approx': filesize_from_tbr(tbr, format_duration), + 'width': int_or_none(fmt_stream.get('width')), + 'language': join_nonempty(language_code, 'desc' if is_descriptive else '') or None, + 'language_preference': PREFERRED_LANG_VALUE if is_original else 5 if is_default else -10 if is_descriptive else -1, + # Strictly de-prioritize damaged and 3gp formats + 'preference': -10 if is_damaged else -2 if itag == '17' else None, + } + mime_mobj = re.match( + r'((?:[^/]+)/(?:[^;]+))(?:;\s*codecs="([^"]+)")?', fmt_stream.get('mimeType') or '') + if mime_mobj: + dct['ext'] = mimetype2ext(mime_mobj.group(1)) + dct.update(parse_codecs(mime_mobj.group(2))) + + single_stream = 'none' in (dct.get('acodec'), dct.get('vcodec')) + if single_stream and dct.get('ext'): + dct['container'] = dct['ext'] + '_dash' + + return dct + + def process_sabr_formats_and_subtitles(): + proto = 'sabr' + server_abr_streaming_url = (self._process_n_param + (streaming_data.get('serverAbrStreamingUrl'), video_id, player_url, proto)) + video_playback_ustreamer_config = traverse_obj( + pr, ('playerConfig', 'mediaCommonConfig', 'mediaUstreamerRequestConfig', 'videoPlaybackUstreamerConfig')) + + if not server_abr_streaming_url or not video_playback_ustreamer_config: + return + + if protobug is None: + self.report_warning( + f'{video_id}: {client_name} client {proto} formats will be skipped as protobug is not installed.', + only_once=True) + return + + sabr_config = { + 'video_playback_ustreamer_config': video_playback_ustreamer_config, + 'po_token': po_token, + 'fetch_po_token_fn': fetch_po_token, + 'client_name': client_name, + 'client_info': traverse_obj(innertube_context, 'client'), + 'reload_config_fn': functools.partial(self._reload_sabr_config, video_id, client_name), + 'video_id': video_id, + 'live_status': live_status, + } + + for fmt_stream in streaming_formats: + stream_id = get_stream_id(fmt_stream) + if not all_formats: + if stream_id in stream_ids: + continue + + fmt = process_format_stream(fmt_stream, proto) + if not fmt: + continue + + caption_track = fmt_stream.get('captionTrack') + + fmt.update({ + 'is_from_start': live_status == 'is_live' and self.get_param('live_from_start'), + 'url': server_abr_streaming_url, + 'protocol': 'sabr', + }) + + fmt['_sabr_config'] = { + **sabr_config, + 'itag': stream_id[0], + 'xtags': fmt_stream.get('xtags'), + 'last_modified': fmt_stream.get('lastModified'), + 'target_duration_sec': fmt_stream.get('targetDurationSec'), + } + + single_stream = 'none' in (fmt.get('acodec'), fmt.get('vcodec')) + + nonlocal subtitles + if caption_track: + # TODO: proper live subtitle extraction + subtitles = self._merge_subtitles({str(stream_id[0]): [fmt]}, subtitles) + elif single_stream: + if stream_id[0]: + itags[stream_id[0]].add((proto, fmt.get('language'))) + stream_ids.append(stream_id) + yield fmt + + def process_https_formats(): + proto = 'https' + for fmt_stream in streaming_formats: + if fmt_stream.get('targetDurationSec'): + continue + + # FORMAT_STREAM_TYPE_OTF(otf=1) requires downloading the init fragment + # (adding `&sq=0` to the URL) and parsing emsg box to determine the + # number of 
+        def process_https_formats():
+            proto = 'https'
+            for fmt_stream in streaming_formats:
+                if fmt_stream.get('targetDurationSec'):
+                    continue
+
+                # FORMAT_STREAM_TYPE_OTF(otf=1) requires downloading the init fragment
+                # (adding `&sq=0` to the URL) and parsing the emsg box to determine the
+                # number of fragments that would subsequently be requested with `&sq=N`
+                if fmt_stream.get('type') == 'FORMAT_STREAM_TYPE_OTF' and not bool(fmt_stream.get('drmFamilies')):
+                    continue
+
+                stream_id = get_stream_id(fmt_stream)
+                if not all_formats:
+                    if stream_id in stream_ids:
+                        continue
+
+                fmt = process_format_stream(fmt_stream, proto)
+                if not fmt:
+                    continue
+
+                fmt_url = fmt_stream.get('url')
+                if not fmt_url:
+                    sc = urllib.parse.parse_qs(fmt_stream.get('signatureCipher'))
+                    fmt_url = url_or_none(try_get(sc, lambda x: x['url'][0]))
+                    encrypted_sig = try_get(sc, lambda x: x['s'][0])
+                    if not all((sc, fmt_url, player_url, encrypted_sig)):
+                        continue
+                    try:
+                        fmt_url += '&{}={}'.format(
+                            traverse_obj(sc, ('sp', -1)) or 'signature',
+                            self._decrypt_signature(encrypted_sig, video_id, player_url),
+                        )
+                    except ExtractorError as e:
+                        self.report_warning(
+                            f'Signature extraction failed: Some formats may be missing\n'
+                            f'         player = {player_url}\n'
+                            f'         {bug_reports_message(before="")}',
+                            video_id=video_id, only_once=True)
+                        self.write_debug(
+                            f'{video_id}: Signature extraction failure info:\n'
+                            f'         encrypted sig = {encrypted_sig}\n'
+                            f'         player = {player_url}')
+                        self.write_debug(e, only_once=True)
+                        continue
+
+                fmt_url = self._process_n_param(fmt_url, video_id, player_url, proto)
+                if not fmt_url:
+                    continue
+
+                if po_token:
+                    fmt_url = update_url_query(fmt_url, {'pot': po_token})
+
+                fmt['url'] = fmt_url
+
+                if stream_id[0]:
+                    itags[stream_id[0]].add((proto, fmt.get('language')))
+                stream_ids.append(stream_id)
+
+                if (all_formats or 'dashy' in format_types) and fmt['filesize']:
+                    yield {
+                        **fmt,
+                        'format_id': f'{fmt["format_id"]}-dashy' if all_formats else fmt['format_id'],
+                        'protocol': 'http_dash_segments',
+                        'fragments': build_fragments(fmt),
+                    }
+                if all_formats or 'dashy' not in format_types:
+                    fmt['downloader_options'] = {'http_chunk_size': CHUNK_SIZE}
+                    yield fmt
+
+        yield from process_https_formats()
+        yield from process_sabr_formats_and_subtitles()
+
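+        # hls/dash manifests are skipped per the `skip` extractor argument and
+        # the youtube_include_{hls,dash}_manifest params, and bad/incomplete
+        # live formats are dropped unless formats=incomplete is requested.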
+        needs_live_processing = self._needs_live_processing(live_status, duration)
+        skip_bad_formats = 'incomplete' not in format_types
+        if self._configuration_arg('include_incomplete_formats'):
+            skip_bad_formats = False
+            self._downloader.deprecated_feature('[youtube] include_incomplete_formats extractor argument is deprecated. '
+                                                'Use formats=incomplete extractor argument instead')
+
+        skip_manifests = set(self._configuration_arg('skip'))
+        if (not self.get_param('youtube_include_hls_manifest', True)
+                or needs_live_processing == 'is_live'  # These will be filtered out by YoutubeDL anyway
+                or (needs_live_processing and skip_bad_formats)):
+            skip_manifests.add('hls')
+
+        if not self.get_param('youtube_include_dash_manifest', True):
+            skip_manifests.add('dash')
+        if self._configuration_arg('include_live_dash'):
+            self._downloader.deprecated_feature('[youtube] include_live_dash extractor argument is deprecated. '
+                                                'Use formats=incomplete extractor argument instead')
+        elif skip_bad_formats and live_status == 'is_live' and needs_live_processing != 'is_live':
+            skip_manifests.add('dash')
+
+        def process_manifest_format(f, proto, client_name, itag, po_token):
+            key = (proto, f.get('language'))
+            if not all_formats and key in itags[itag]:
                 return False
-            f['format_note'] = join_nonempty(f.get('format_note'), 'MISSING POT', delim=' ')
-            f['source_preference'] -= 20
-            itags[itag].add(key)
+            if f.get('source_preference') is None:
+                f['source_preference'] = -1
-            if itag and all_formats:
-                f['format_id'] = f'{itag}-{proto}'
-            elif any(p != proto for p, _ in itags[itag]):
-                f['format_id'] = f'{itag}-{proto}'
-            elif itag:
-                f['format_id'] = itag
+            # Clients that require PO Token return videoplayback URLs that may return 403
+            # hls does not currently require PO Token
+            if (
+                not po_token
+                and _PoTokenContext.GVS in self._get_default_ytcfg(client_name)['PO_TOKEN_REQUIRED_CONTEXTS']
+                and proto != 'hls'
+            ):
+                if 'missing_pot' not in self._configuration_arg('formats'):
+                    self._report_pot_format_skipped(video_id, client_name, proto)
+                    return False
+                f['format_note'] = join_nonempty(f.get('format_note'), 'MISSING POT', delim=' ')
+                f['source_preference'] -= 20
-            if original_language and f.get('language') == original_language:
-                f['format_note'] = join_nonempty(f.get('format_note'), '(default)', delim=' ')
-                f['language_preference'] = PREFERRED_LANG_VALUE
+            itags[itag].add(key)
-            if itag in ('616', '235'):
-                f['format_note'] = join_nonempty(f.get('format_note'), 'Premium', delim=' ')
-                f['source_preference'] += 100
+            if itag and all_formats:
+                f['format_id'] = f'{itag}-{proto}'
+            elif any(p != proto for p, _ in itags[itag]):
+                f['format_id'] = f'{itag}-{proto}'
+            elif itag:
+                f['format_id'] = itag
-            f['quality'] = q(itag_qualities.get(try_get(f, lambda f: f['format_id'].split('-')[0]), -1))
-            if f['quality'] == -1 and f.get('height'):
-                f['quality'] = q(res_qualities[min(res_qualities, key=lambda x: abs(x - f['height']))])
-            if self.get_param('verbose') or all_formats:
-                f['format_note'] = join_nonempty(
-                    f.get('format_note'), short_client_name(client_name), delim=', ')
-            if f.get('fps') and f['fps'] <= 1:
-                del f['fps']
+            if original_language and f.get('language') == original_language:
+                f['format_note'] = join_nonempty(f.get('format_note'), '(default)', delim=' ')
+                f['language_preference'] = PREFERRED_LANG_VALUE
-            if proto == 'hls' and f.get('has_drm'):
-                f['has_drm'] = 'maybe'
-                f['source_preference'] -= 5
-            return True
+            if itag in ('616', '235'):
+                f['format_note'] = join_nonempty(f.get('format_note'), 'Premium', delim=' ')
+                f['source_preference'] += 100
-        subtitles = {}
-        for sd in streaming_data:
-            client_name = sd[STREAMING_DATA_CLIENT_NAME]
-            po_token = sd.get(STREAMING_DATA_INITIAL_PO_TOKEN)
-            hls_manifest_url = 'hls' not in skip_manifests and sd.get('hlsManifestUrl')
+            f['quality'] = q(itag_qualities.get(try_get(f, lambda f: f['format_id'].split('-')[0]), -1))
+            if f['quality'] == -1 and f.get('height'):
+                f['quality'] = q(res_qualities[min(res_qualities, key=lambda x: abs(x - f['height']))])
+            if self.get_param('verbose') or all_formats:
+                f['format_note'] = join_nonempty(
+                    f.get('format_note'), short_client_name(client_name), delim=', ')
+            if f.get('fps') and f['fps'] <= 1:
+                del f['fps']
+
+            if proto == 'hls' and f.get('has_drm'):
+                f['has_drm'] = 'maybe'
+                f['source_preference'] -= 5
+            return True
+
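+        # For manifests, the PO Token is appended to the manifest URL as a
+        # `/pot/<token>` path segment rather than as a query parameter.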
+        hls_manifest_url = 'hls' not in skip_manifests and streaming_data.get('hlsManifestUrl')
         if hls_manifest_url:
             if po_token:
                 hls_manifest_url = hls_manifest_url.rstrip('/') + f'/pot/{po_token}'
@@ -3602,7 +3727,7 @@ def process_manifest_format(f, proto, client_name, itag, po_token):
                     r'/itag/(\d+)', f['url'], 'itag', default=None), po_token):
                 yield f
 
-        dash_manifest_url = 'dash' not in skip_manifests and sd.get('dashManifestUrl')
+        dash_manifest_url = 'dash' not in skip_manifests and streaming_data.get('dashManifestUrl')
         if dash_manifest_url:
             if po_token:
                 dash_manifest_url = dash_manifest_url.rstrip('/') + f'/pot/{po_token}'
@@ -3659,11 +3784,14 @@ def _extract_storyboard(self, player_responses, duration):
                 } for j in range(math.ceil(fragment_count))],
             }
 
-    def _download_player_responses(self, url, smuggled_data, video_id, webpage_url):
+    def _download_player_responses(self, url, smuggled_data, video_id, webpage_url, reload_playback_token=None):
         webpage = None
         if 'webpage' not in self._configuration_arg('player_skip'):
             query = {'bpctr': '9999999999', 'has_verified': '1'}
-            pp = self._configuration_arg('player_params', [None], casesense=True)[0]
+            pp = (
+                self._configuration_arg('player_params', [None], casesense=True)[0]
+                or traverse_obj(INNERTUBE_CLIENTS, ('web', 'PLAYER_PARAMS', {str}))
+            )
             if pp:
                 query['pp'] = pp
             webpage = self._download_webpage_with_retries(webpage_url, video_id, query=query)
@@ -3672,7 +3800,7 @@ def _download_player_responses(self, url, smuggled_data, video_id, webpage_url):
 
         player_responses, player_url = self._extract_player_responses(
             self._get_requested_clients(url, smuggled_data),
-            video_id, webpage, master_ytcfg, smuggled_data)
+            video_id, webpage, master_ytcfg, reload_playback_token)
 
         return webpage, master_ytcfg, player_responses, player_url
 
@@ -3690,14 +3818,14 @@ def _list_formats(self, video_id, microformats, video_details, player_responses,
             else 'was_live' if live_content
             else 'not_live' if False in (is_live, live_content)
             else None)
-        streaming_data = traverse_obj(player_responses, (..., 'streamingData'))
-        *formats, subtitles = self._extract_formats_and_subtitles(streaming_data, video_id, player_url, live_status, duration)
+        *formats, subtitles = self._extract_formats_and_subtitles(video_id, player_responses, player_url, live_status, duration)
+
         if all(f.get('has_drm') for f in formats):
             # If there are no formats that definitely don't have DRM, all have DRM
             for f in formats:
                 f['has_drm'] = True
 
-        return live_broadcast_details, live_status, streaming_data, formats, subtitles
+        return live_broadcast_details, live_status, formats, subtitles
 
     def _real_extract(self, url):
         url, smuggled_data = unsmuggle_url(url, {})
@@ -3787,8 +3915,9 @@ def feed_entry(name):
             or int_or_none(get_first(microformats, 'lengthSeconds'))
             or parse_duration(search_meta('duration')) or None)
 
-        live_broadcast_details, live_status, streaming_data, formats, automatic_captions = \
+        live_broadcast_details, live_status, formats, automatic_captions = \
             self._list_formats(video_id, microformats, video_details, player_responses, player_url, duration)
+        streaming_data = traverse_obj(player_responses, (..., 'streamingData'))
 
         if live_status == 'post_live':
             self.write_debug(f'{video_id}: Video is in Post-Live Manifestless mode')