
[ie/youtube] SABR Downloader

coletdjnz 2025-06-21 11:15:25 +12:00
parent 73bf102116
commit c898d5f738
48 changed files with 4069 additions and 274 deletions


@@ -208,7 +208,7 @@ jobs:
python3.9 -m pip install -U pip wheel 'setuptools>=71.0.2'
# XXX: Keep this in sync with pyproject.toml (it can't be accessed at this stage) and exclude secretstorage
python3.9 -m pip install -U Pyinstaller mutagen pycryptodomex brotli certifi cffi \
'requests>=2.32.2,<3' 'urllib3>=1.26.17,<3' 'websockets>=13.0' 'protobug==0.3.0'
run: |
cd repo


@@ -212,6 +212,7 @@ ### Metadata
### Misc
* [**protobug**](https://github.com/yt-dlp/protobug)\* - Protobuf library, for serializing and deserializing protobuf data. Licensed under [Unlicense](https://github.com/yt-dlp/protobug/blob/main/LICENSE)
* [**pycryptodomex**](https://github.com/Legrandin/pycryptodome)\* - For decrypting AES-128 HLS streams and various other data. Licensed under [BSD-2-Clause](https://github.com/Legrandin/pycryptodome/blob/master/LICENSE.rst)
* [**phantomjs**](https://github.com/ariya/phantomjs) - Used in extractors where javascript needs to be run. Licensed under [BSD-3-Clause](https://github.com/ariya/phantomjs/blob/master/LICENSE.BSD)
* [**secretstorage**](https://github.com/mitya57/secretstorage)\* - For `--cookies-from-browser` to access the **Gnome** keyring while decrypting cookies of **Chromium**-based browsers on **Linux**. Licensed under [BSD-3-Clause](https://github.com/mitya57/secretstorage/blob/master/LICENSE)
@@ -1808,6 +1809,7 @@ #### youtube
* `innertube_host`: Innertube API host to use for all API requests; e.g. `studio.youtube.com`, `youtubei.googleapis.com`. Note that cookies exported from one subdomain will not work on others
* `innertube_key`: Innertube API key to use for all API requests. By default, no API key is used
* `raise_incomplete_data`: `Incomplete Data Received` raises an error instead of reporting a warning
* `sabr_log_level`: Set the log level of the SABR downloader. One of `TRACE`, `DEBUG` or `INFO` (default)
* `data_sync_id`: Overrides the account Data Sync ID used in Innertube API requests. This may be needed if you are using an account with `youtube:player_skip=webpage,configs` or `youtubetab:skip=webpage`
* `visitor_data`: Overrides the Visitor Data used in Innertube API requests. This should be used with `player_skip=webpage,configs` and without cookies. Note: this may have adverse effects if used improperly. If a session from a browser is wanted, you should pass cookies instead (which contain the Visitor ID)
* `po_token`: Proof of Origin (PO) Token(s) to use. Comma separated list of PO Tokens in the format `CLIENT.CONTEXT+PO_TOKEN`, e.g. `youtube:po_token=web.gvs+XXX,web.player+XXX,web_safari.gvs+YYY`. Context can be any of `gvs` (Google Video Server URLs), `player` (Innertube player request) or `subs` (Subtitles)
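
For example, to enable the most verbose SABR logging when embedding yt-dlp, a minimal sketch assuming the standard `YoutubeDL` API (the URL is a placeholder):

from yt_dlp import YoutubeDL

# Equivalent to: yt-dlp --extractor-args 'youtube:sabr_log_level=TRACE' URL
with YoutubeDL({'extractor_args': {'youtube': {'sabr_log_level': ['TRACE']}}}) as ydl:
    ydl.download(['https://www.youtube.com/watch?v=...'])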


@@ -53,6 +53,7 @@ default = [
"requests>=2.32.2,<3",
"urllib3>=1.26.17,<3",
"websockets>=13.0",
"protobug==0.3.0",
]
curl-cffi = [
"curl-cffi>=0.5.10,!=0.6.*,!=0.7.*,!=0.8.*,!=0.9.*,<0.11; implementation_name=='cpython'",

test/test_sabr/test_ump.py (new file)

@@ -0,0 +1,102 @@
import io
import pytest
from yt_dlp.extractor.youtube._streaming.ump import varint_size, read_varint, UMPDecoder, UMPPartId
@pytest.mark.parametrize('data, expected', [
(0x01, 1),
(0x4F, 1),
(0x80, 2),
(0xBF, 2),
(0xC0, 3),
(0xDF, 3),
(0xE0, 4),
(0xEF, 4),
(0xF0, 5),
(0xFF, 5),
])
def test_varint_size(data, expected):
assert varint_size(data) == expected
@pytest.mark.parametrize('data, expected', [
# 1 byte long varint
(b'\x01', 1),
(b'\x4F', 79),
# 2 byte long varint
(b'\x80\x01', 64),
(b'\x8A\x7F', 8138),
(b'\xBF\x7F', 8191),
# 3 byte long varint
(b'\xC0\x80\x01', 12288),
(b'\xDF\x7F\xFF', 2093055),
# 4 byte long varint
(b'\xE0\x80\x80\x01', 1574912),
(b'\xEF\x7F\xFF\xFF', 268433407),
# 5 byte long varint
(b'\xF0\x80\x80\x80\x01', 25198720),
(b'\xFF\x7F\xFF\xFF\xFF', 4294967167),
],
)
def test_readvarint(data, expected):
assert read_varint(io.BytesIO(data)) == expected
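The vectors above pin down the scheme: the number of leading one bits in the first byte gives the total size (capped at five), the remaining low bits of the first byte are the least-significant value bits, and each following byte contributes eight more bits (in the five-byte form the first byte carries no value bits). A minimal decoder sketch consistent with these vectors, suffixed _sketch to avoid shadowing the imported functions (the real implementations live in yt_dlp.extractor.youtube._streaming.ump):

def varint_size_sketch(byte: int) -> int:
    # 0xxxxxxx -> 1, 10xxxxxx -> 2, 110xxxxx -> 3, 1110xxxx -> 4, 1111xxxx -> 5
    if byte < 0x80:
        return 1
    if byte < 0xC0:
        return 2
    if byte < 0xE0:
        return 3
    if byte < 0xF0:
        return 4
    return 5

def read_varint_sketch(stream: io.BytesIO) -> int:
    first = stream.read(1)[0]
    size = varint_size_sketch(first)
    mask = (1 << (8 - size)) - 1 if size < 5 else 0  # value bits kept from the first byte
    value = first & mask
    shift = 8 - size if size < 5 else 0
    for byte in stream.read(size - 1):
        value |= byte << shift
        shift += 8
    return value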
class TestUMPDecoder:
EXAMPLE_PART_DATA = [
{
# Part 1: Part type of 20, part size of 127
'part_type_bytes': b'\x14',
'part_size_bytes': b'\x7F',
'part_data_bytes': b'\x01' * 127,
'part_id': UMPPartId.MEDIA_HEADER,
'part_size': 127,
},
# Part 2, Part type of 4294967295, part size of 0
{
'part_type_bytes': b'\xFF\xFF\xFF\xFF\xFF',
'part_size_bytes': b'\x00',
'part_data_bytes': b'',
'part_id': UMPPartId.UNKNOWN,
'part_size': 0,
},
# Part 3: Part type of 21, part size of 1574912
{
'part_type_bytes': b'\x15',
'part_size_bytes': b'\xE0\x80\x80\x01',
'part_data_bytes': b'\x01' * 1574912,
'part_id': UMPPartId.MEDIA,
'part_size': 1574912,
},
]
COMBINED_PART_DATA = b''.join(part['part_type_bytes'] + part['part_size_bytes'] + part['part_data_bytes'] for part in EXAMPLE_PART_DATA)
def test_iter_parts(self):
# Create a mock file-like object
mock_file = io.BytesIO(self.COMBINED_PART_DATA)
# Create an instance of UMPDecoder with the mock file
decoder = UMPDecoder(mock_file)
# Iterate over the parts and check the values
for idx, part in enumerate(decoder.iter_parts()):
assert part.part_id == self.EXAMPLE_PART_DATA[idx]['part_id']
assert part.size == self.EXAMPLE_PART_DATA[idx]['part_size']
assert part.data.read() == self.EXAMPLE_PART_DATA[idx]['part_data_bytes']
assert mock_file.closed
def test_unexpected_eof(self):
# Unexpected bytes at the end of the file
mock_file = io.BytesIO(self.COMBINED_PART_DATA + b'\x00')
decoder = UMPDecoder(mock_file)
# Iterate over the parts and check the values
with pytest.raises(EOFError, match='Unexpected EOF while reading part size'):
for idx, part in enumerate(decoder.iter_parts()):
assert part.part_id == self.EXAMPLE_PART_DATA[idx]['part_id']
part.data.read()
assert mock_file.closed
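
For reference, a sketch of the wire framing these cases encode; the real UMPDecoder additionally maps part types to UMPPartId values and wraps payload bytes in readers:

def iter_ump_parts_sketch(stream: io.BytesIO):
    # Each UMP part is a varint part type, a varint part size, then `size` payload bytes.
    while True:
        probe = stream.read(1)
        if not probe:
            return  # clean EOF on a part boundary
        stream.seek(-1, io.SEEK_CUR)  # un-read the probe byte
        part_type = read_varint(stream)
        part_size = read_varint(stream)
        yield part_type, part_size, stream.read(part_size)

Run against io.BytesIO(TestUMPDecoder.COMBINED_PART_DATA), this yields part types 20, 4294967295 and 21 with sizes 127, 0 and 1574912, matching the example parts above.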


@@ -0,0 +1,23 @@
import pytest
from yt_dlp.extractor.youtube._streaming.sabr.utils import ticks_to_ms, broadcast_id_from_url
@pytest.mark.parametrize(
'ticks, timescale, expected_ms',
[
(1000, 1000, 1000),
(5000, 10000, 500),
(234234, 44100, 5312),
(1, 1, 1000),
(None, 1000, None),
(1000, None, None),
(None, None, None),
],
)
def test_ticks_to_ms(ticks, timescale, expected_ms):
assert ticks_to_ms(ticks, timescale) == expected_ms
def test_broadcast_id_from_url():
assert broadcast_id_from_url('https://example.com/path?other=param&id=example.1~243&other2=param2') == 'example.1~243'
assert broadcast_id_from_url('https://example.com/path?other=param&other2=param2') is None
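
Hedged sketches of the two helpers, consistent with the vectors above; note the 44100 Hz case (234234 ticks -> 5312 ms) implies rounding up, and None propagates. The real implementations live in yt_dlp.extractor.youtube._streaming.sabr.utils:

import math
import urllib.parse

def ticks_to_ms_sketch(ticks, timescale):
    # Convert `ticks` at `timescale` ticks-per-second to milliseconds, rounding up.
    if ticks is None or timescale is None:
        return None
    return math.ceil(ticks * 1000 / timescale)

def broadcast_id_from_url_sketch(url):
    # The broadcast id travels as the `id` query parameter; absent means None.
    query = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
    return query.get('id', [None])[0]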


@@ -25,7 +25,7 @@ def get_hidden_imports():
for module in ('websockets', 'requests', 'urllib3'):
yield from collect_submodules(module)
# These are auto-detected, but explicitly add them just in case
yield from ('mutagen', 'brotli', 'certifi', 'secretstorage', 'curl_cffi', 'protobug')
hiddenimports = list(get_hidden_imports())


@@ -79,6 +79,11 @@
except ImportError:
curl_cffi = None
try:
import protobug
except ImportError:
protobug = None
from . import Cryptodome
all_dependencies = {k: v for k, v in globals().items() if not k.startswith('_')}


@@ -15,6 +15,10 @@ def get_suitable_downloader(info_dict, params={}, default=NO_DEFAULT, protocol=N
and not (to_stdout and len(protocols) > 1)
and set(protocols) == {'http_dash_segments_generator'}):
return DashSegmentsFD
elif SabrFD is not None and set(downloaders) == {SabrFD} and SabrFD.can_download(info_copy):
# NOTE: there may be one or more SABR downloaders for this info_dict,
# as SABR can download multiple formats at once.
return SabrFD
elif len(downloaders) == 1:
return downloaders[0]
return None
@@ -36,6 +40,7 @@ def get_suitable_downloader(info_dict, params={}, default=NO_DEFAULT, protocol=N
from .websocket import WebSocketFragmentFD
from .youtube_live_chat import YoutubeLiveChatFD
from .bunnycdn import BunnyCdnFD
from .sabr import SabrFD
PROTOCOL_MAP = {
'rtmp': RtmpFD,
@@ -56,6 +61,7 @@ def get_suitable_downloader(info_dict, params={}, default=NO_DEFAULT, protocol=N
'youtube_live_chat': YoutubeLiveChatFD,
'youtube_live_chat_replay': YoutubeLiveChatFD,
'bunnycdn': BunnyCdnFD,
'sabr': SabrFD,
}
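
Routing to SabrFD is driven purely by each format's `protocol` field. A hypothetical minimal info_dict shape that would select it (all values here are illustrative placeholders, not real extractor output):

info_dict = {
    'id': 'example-id',
    'requested_formats': [
        # SabrFD.can_download requires every requested format to use protocol 'sabr'
        {'format_id': '299', 'protocol': 'sabr', 'url': 'https://example.invalid/videoplayback',
         '_sabr_config': {}},  # populated by the YouTube extractor; see _group_formats_by_client below
        {'format_id': '140', 'protocol': 'sabr', 'url': 'https://example.invalid/videoplayback',
         '_sabr_config': {}},
    ],
}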


@@ -0,0 +1,24 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
from yt_dlp.utils import DownloadError
from yt_dlp.downloader import FileDownloader
if not protobug:
class SabrFD(FileDownloader):
@classmethod
def can_download(cls, info_dict):
is_sabr = (
info_dict.get('requested_formats')
and all(
format_info.get('protocol') == 'sabr'
for format_info in info_dict['requested_formats']))
if is_sabr:
raise DownloadError('SabrFD requires protobug to be installed')
return is_sabr
else:
from ._fd import SabrFD # noqa: F401


@@ -0,0 +1,330 @@
from __future__ import annotations
import collections
import itertools
from yt_dlp.networking.exceptions import TransportError, HTTPError
from yt_dlp.utils import traverse_obj, int_or_none, DownloadError, join_nonempty
from yt_dlp.downloader import FileDownloader
from ._writer import SabrFDFormatWriter
from ._logger import create_sabrfd_logger
from yt_dlp.extractor.youtube._streaming.sabr.part import (
MediaSegmentEndSabrPart,
MediaSegmentDataSabrPart,
MediaSegmentInitSabrPart,
PoTokenStatusSabrPart,
RefreshPlayerResponseSabrPart,
MediaSeekSabrPart,
FormatInitializedSabrPart,
)
from yt_dlp.extractor.youtube._streaming.sabr.stream import SabrStream
from yt_dlp.extractor.youtube._streaming.sabr.models import ConsumedRange, AudioSelector, VideoSelector, CaptionSelector
from yt_dlp.extractor.youtube._streaming.sabr.exceptions import SabrStreamError
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, ClientName
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId
class SabrFD(FileDownloader):
@classmethod
def can_download(cls, info_dict):
return (
info_dict.get('requested_formats')
and all(
format_info.get('protocol') == 'sabr'
for format_info in info_dict['requested_formats']))
def _group_formats_by_client(self, filename, info_dict):
format_groups = collections.defaultdict(dict, {})
requested_formats = info_dict.get('requested_formats') or [info_dict]
for _idx, f in enumerate(requested_formats):
sabr_config = f.get('_sabr_config')
client_name = sabr_config.get('client_name')
client_info = sabr_config.get('client_info')
server_abr_streaming_url = f.get('url')
video_playback_ustreamer_config = sabr_config.get('video_playback_ustreamer_config')
if not video_playback_ustreamer_config:
raise DownloadError('Video playback ustreamer config not found')
sabr_format_group_config = format_groups.get(client_name)
if not sabr_format_group_config:
sabr_format_group_config = format_groups[client_name] = {
'server_abr_streaming_url': server_abr_streaming_url,
'video_playback_ustreamer_config': video_playback_ustreamer_config,
'formats': [],
'initial_po_token': sabr_config.get('po_token'),
'fetch_po_token_fn': fn if callable(fn := sabr_config.get('fetch_po_token_fn')) else None,
'reload_config_fn': fn if callable(fn := sabr_config.get('reload_config_fn')) else None,
'live_status': sabr_config.get('live_status'),
'video_id': sabr_config.get('video_id'),
'client_info': ClientInfo(
client_name=traverse_obj(client_info, ('clientName', {lambda x: ClientName[x]})),
client_version=traverse_obj(client_info, 'clientVersion'),
os_version=traverse_obj(client_info, 'osVersion'),
os_name=traverse_obj(client_info, 'osName'),
device_model=traverse_obj(client_info, 'deviceModel'),
device_make=traverse_obj(client_info, 'deviceMake'),
),
'target_duration_sec': sabr_config.get('target_duration_sec'),
# Number.MAX_SAFE_INTEGER
'start_time_ms': ((2**53) - 1) if info_dict.get('live_status') == 'is_live' and not f.get('is_from_start') else 0,
}
else:
if sabr_format_group_config['server_abr_streaming_url'] != server_abr_streaming_url:
raise DownloadError('Server ABR streaming URL mismatch')
if sabr_format_group_config['video_playback_ustreamer_config'] != video_playback_ustreamer_config:
raise DownloadError('Video playback ustreamer config mismatch')
itag = int_or_none(sabr_config.get('itag'))
sabr_format_group_config['formats'].append({
'display_name': f.get('format_id'),
'format_id': itag and FormatId(
itag=itag, lmt=int_or_none(sabr_config.get('last_modified')), xtags=sabr_config.get('xtags')),
'format_type': format_type(f),
'quality': sabr_config.get('quality'),
'height': sabr_config.get('height'),
'filename': f.get('filepath', filename),
'info_dict': f,
})
return format_groups
def real_download(self, filename, info_dict):
format_groups = self._group_formats_by_client(filename, info_dict)
is_test = self.params.get('test', False)
resume = self.params.get('continuedl', True)
for client_name, format_group in format_groups.items():
formats = format_group['formats']
audio_formats = (f for f in formats if f['format_type'] == 'audio')
video_formats = (f for f in formats if f['format_type'] == 'video')
caption_formats = (f for f in formats if f['format_type'] == 'caption')
for audio_format, video_format, caption_format in itertools.zip_longest(audio_formats, video_formats, caption_formats):
format_str = join_nonempty(*[
traverse_obj(audio_format, 'display_name'),
traverse_obj(video_format, 'display_name'),
traverse_obj(caption_format, 'display_name')], delim='+')
self.write_debug(f'Downloading formats: {format_str} ({client_name} client)')
self._download_sabr_stream(
info_dict=info_dict,
video_format=video_format,
audio_format=audio_format,
caption_format=caption_format,
resume=resume,
is_test=is_test,
server_abr_streaming_url=format_group['server_abr_streaming_url'],
video_playback_ustreamer_config=format_group['video_playback_ustreamer_config'],
initial_po_token=format_group['initial_po_token'],
fetch_po_token_fn=format_group['fetch_po_token_fn'],
reload_config_fn=format_group['reload_config_fn'],
client_info=format_group['client_info'],
start_time_ms=format_group['start_time_ms'],
target_duration_sec=format_group.get('target_duration_sec', None),
live_status=format_group.get('live_status'),
video_id=format_group.get('video_id'),
)
return True
def _download_sabr_stream(
self,
video_id: str,
info_dict: dict,
video_format: dict,
audio_format: dict,
caption_format: dict,
resume: bool,
is_test: bool,
server_abr_streaming_url: str,
video_playback_ustreamer_config: str,
initial_po_token: str,
fetch_po_token_fn: callable | None = None,
reload_config_fn: callable | None = None,
client_info: ClientInfo | None = None,
start_time_ms: int = 0,
target_duration_sec: int | None = None,
live_status: str | None = None,
):
writers = {}
audio_selector = None
video_selector = None
caption_selector = None
if audio_format:
audio_selector = AudioSelector(
display_name=audio_format['display_name'], format_ids=[audio_format['format_id']])
writers[audio_selector.display_name] = SabrFDFormatWriter(
self, audio_format.get('filename'),
audio_format['info_dict'], len(writers), resume=resume)
if video_format:
video_selector = VideoSelector(
display_name=video_format['display_name'], format_ids=[video_format['format_id']])
writers[video_selector.display_name] = SabrFDFormatWriter(
self, video_format.get('filename'),
video_format['info_dict'], len(writers), resume=resume)
if caption_format:
caption_selector = CaptionSelector(
display_name=caption_format['display_name'], format_ids=[caption_format['format_id']])
writers[caption_selector.display_name] = SabrFDFormatWriter(
self, caption_format.get('filename'),
caption_format['info_dict'], len(writers), resume=resume)
stream = SabrStream(
urlopen=self.ydl.urlopen,
logger=create_sabrfd_logger(self.ydl, prefix='sabr:stream'),
server_abr_streaming_url=server_abr_streaming_url,
video_playback_ustreamer_config=video_playback_ustreamer_config,
po_token=initial_po_token,
video_selection=video_selector,
audio_selection=audio_selector,
caption_selection=caption_selector,
start_time_ms=start_time_ms,
client_info=client_info,
live_segment_target_duration_sec=target_duration_sec,
post_live=live_status == 'post_live',
video_id=video_id,
retry_sleep_func=self.params.get('retry_sleep_functions', {}).get('http'),
)
self._prepare_multiline_status(len(writers) + 1)
try:
total_bytes = 0
for part in stream:
if is_test and total_bytes >= self._TEST_FILE_SIZE:
stream.close()
break
if isinstance(part, PoTokenStatusSabrPart):
if not fetch_po_token_fn:
self.report_warning(
'No fetch PO token function found - this can happen if you use --load-info-json.'
' The download will fail if a valid PO token is required.', only_once=True)
if part.status in (
part.PoTokenStatus.INVALID,
part.PoTokenStatus.PENDING,
):
# Fetch a PO token with bypass_cache=True
# (ensure we create a new one)
po_token = fetch_po_token_fn(bypass_cache=True)
if po_token:
stream.processor.po_token = po_token
elif part.status in (
part.PoTokenStatus.MISSING,
part.PoTokenStatus.PENDING_MISSING,
):
# Fetch a PO Token, bypass_cache=False
po_token = fetch_po_token_fn()
if po_token:
stream.processor.po_token = po_token
elif isinstance(part, FormatInitializedSabrPart):
writer = writers.get(part.format_selector.display_name)
if not writer:
self.report_warning(f'Unknown format selector: {part.format_selector}')
continue
writer.initialize_format(part.format_id)
initialized_format = stream.processor.initialized_formats[str(part.format_id)]
if writer.state.init_sequence:
initialized_format.init_segment = True
initialized_format.current_segment = None # allow a seek
# Build consumed ranges from the sequences
consumed_ranges = []
for sequence in writer.state.sequences:
consumed_ranges.append(ConsumedRange(
start_time_ms=sequence.first_segment.start_time_ms,
duration_ms=(sequence.last_segment.start_time_ms + sequence.last_segment.duration_ms) - sequence.first_segment.start_time_ms,
start_sequence_number=sequence.first_segment.sequence_number,
end_sequence_number=sequence.last_segment.sequence_number,
))
if consumed_ranges:
initialized_format.consumed_ranges = consumed_ranges
initialized_format.current_segment = None # allow a seek
self.to_screen(f'[download] Resuming download for format {part.format_selector.display_name}')
elif isinstance(part, MediaSegmentInitSabrPart):
writer = writers.get(part.format_selector.display_name)
if not writer:
self.report_warning(f'Unknown init format selector: {part.format_selector}')
continue
writer.initialize_segment(part)
elif isinstance(part, MediaSegmentDataSabrPart):
total_bytes += len(part.data) # TODO: not reliable
writer = writers.get(part.format_selector.display_name)
if not writer:
self.report_warning(f'Unknown data format selector: {part.format_selector}')
continue
writer.write_segment_data(part)
elif isinstance(part, MediaSegmentEndSabrPart):
writer = writers.get(part.format_selector.display_name)
if not writer:
self.report_warning(f'Unknown end format selector: {part.format_selector}')
continue
writer.end_segment(part)
elif isinstance(part, RefreshPlayerResponseSabrPart):
self.to_screen(f'Refreshing player response; Reason: {part.reason}')
# In-place refresh - not ideal but should work in most cases
# TODO: handle case where live stream changes to non-livestream on refresh?
# TODO: if live, allow a seek as for non-DVR streams the reload may be longer than the buffer duration
# TODO: handle po token function change
if not reload_config_fn:
self.report_warning(
'No reload config function found - cannot refresh SABR streaming URL.'
' The url will expire soon and the download will fail.')
continue
try:
stream.url, stream.processor.video_playback_ustreamer_config = reload_config_fn(part.reload_playback_token)
except (TransportError, HTTPError) as e:
self.report_warning(f'Failed to refresh SABR streaming URL: {e}')
elif isinstance(part, MediaSeekSabrPart):
if (
not info_dict.get('is_live')
and live_status not in ('post_live', 'is_live')
and not stream.processor.is_live
and part.reason == MediaSeekSabrPart.Reason.SERVER_SEEK
):
raise DownloadError('Server tried to seek a video')
else:
self.to_screen(f'Unhandled part type: {part.__class__.__name__}')
for writer in writers.values():
writer.finish()
except SabrStreamError as e:
raise DownloadError(str(e)) from e
except KeyboardInterrupt:
if (
not info_dict.get('is_live')
and not live_status == 'is_live'
and not stream.processor.is_live
):
raise
self.to_screen('Interrupted by user')
for writer in writers.values():
writer.finish()
finally:
# TODO: for livestreams, since we cannot resume them, should we finish the writers?
for writer in writers.values():
writer.close()
def format_type(f):
if f.get('acodec') == 'none':
return 'video'
elif f.get('vcodec') == 'none':
return 'audio'
elif f.get('vcodec') is None and f.get('acodec') is None:
return 'caption'
return None


@@ -0,0 +1,196 @@
from __future__ import annotations
import dataclasses
from yt_dlp.utils import DownloadError
from ._io import DiskFormatIOBackend, MemoryFormatIOBackend
@dataclasses.dataclass
class Segment:
segment_id: str
content_length: int | None = None
content_length_estimated: bool = False
sequence_number: int | None = None
start_time_ms: int | None = None
duration_ms: int | None = None
duration_estimated: bool = False
is_init_segment: bool = False
@dataclasses.dataclass
class Sequence:
sequence_id: str
# The segments may not have a start byte range, so to keep it simple we track the total
# length of the sequence. From this and each segment's content_length we can infer where each segment begins and ends.
sequence_content_length: int = 0
first_segment: Segment | None = None
last_segment: Segment | None = None
class SequenceFile:
def __init__(self, fd, format_filename, sequence: Sequence, resume=False):
self.fd = fd
self.format_filename = format_filename
self.sequence = sequence
self.file = DiskFormatIOBackend(
fd=self.fd,
filename=self.format_filename + f'.sq{self.sequence_id}.sabr.part',
)
self.current_segment: SegmentFile | None = None
self.resume = resume
sequence_file_exists = self.file.exists()
if not resume and sequence_file_exists:
self.file.remove()
elif not self.sequence.last_segment and sequence_file_exists:
self.file.remove()
if self.sequence.last_segment and not sequence_file_exists:
raise DownloadError(f'Cannot find existing sequence {self.sequence_id} file')
if self.sequence.last_segment and not self.file.validate_length(self.sequence.sequence_content_length):
self.file.remove()
raise DownloadError(f'Existing sequence {self.sequence_id} file is not valid; removing')
@property
def sequence_id(self):
return self.sequence.sequence_id
@property
def current_length(self):
total = self.sequence.sequence_content_length
if self.current_segment:
total += self.current_segment.current_length
return total
def is_next_segment(self, segment: Segment):
if self.current_segment:
return False
latest_segment = self.sequence.last_segment or self.sequence.first_segment
if not latest_segment:
return True
if segment.is_init_segment and latest_segment.is_init_segment:
# Only one segment allowed for init segments
return False
return segment.sequence_number == latest_segment.sequence_number + 1
def is_current_segment(self, segment_id: str):
if not self.current_segment:
return False
return self.current_segment.segment_id == segment_id
def initialize_segment(self, segment: Segment):
if self.current_segment and not self.is_current_segment(segment.segment_id):
raise ValueError('Cannot reinitialize a segment that does not match the current segment')
if not self.current_segment and not self.is_next_segment(segment):
raise ValueError('Cannot initialize a segment that does not match the next segment')
self.current_segment = SegmentFile(
fd=self.fd,
format_filename=self.format_filename,
segment=segment,
)
def write_segment_data(self, data, segment_id: str):
if not self.is_current_segment(segment_id):
raise ValueError('Cannot write to a segment that does not match the current segment')
self.current_segment.write(data)
def end_segment(self, segment_id):
if not self.is_current_segment(segment_id):
raise ValueError('Cannot end a segment that does not exist')
self.current_segment.finish_write()
if (
self.current_segment.segment.content_length
and not self.current_segment.segment.content_length_estimated
and self.current_segment.current_length != self.current_segment.segment.content_length
):
raise DownloadError(
f'Filesize mismatch for segment {self.current_segment.segment_id}: '
f'Expected {self.current_segment.segment.content_length} bytes, got {self.current_segment.current_length} bytes')
self.current_segment.segment.content_length = self.current_segment.current_length
self.current_segment.segment.content_length_estimated = False
if not self.sequence.first_segment:
self.sequence.first_segment = self.current_segment.segment
self.sequence.last_segment = self.current_segment.segment
self.sequence.sequence_content_length += self.current_segment.current_length
if not self.file.mode:
self.file.initialize_writer(self.resume)
self.current_segment.read_into(self.file)
self.current_segment.remove()
self.current_segment = None
def read_into(self, backend):
self.file.initialize_reader()
self.file.read_into(backend)
self.file.close()
def remove(self):
self.close()
self.file.remove()
def close(self):
self.file.close()
class SegmentFile:
def __init__(self, fd, format_filename, segment: Segment, memory_file_limit=2 * 1024 * 1024):
self.fd = fd
self.format_filename = format_filename
self.segment: Segment = segment
self.current_length = 0
filename = format_filename + f'.sg{segment.sequence_number}.sabr.part'
# Store the segment in memory if it is small enough
if segment.content_length and segment.content_length <= memory_file_limit:
self.file = MemoryFormatIOBackend(
fd=self.fd,
filename=filename,
)
else:
self.file = DiskFormatIOBackend(
fd=self.fd,
filename=filename,
)
# Never resume a segment
exists = self.file.exists()
if exists:
self.file.remove()
@property
def segment_id(self):
return self.segment.segment_id
def write(self, data):
if not self.file.mode:
self.file.initialize_writer(resume=False)
self.current_length += self.file.write(data)
def read_into(self, file):
self.file.initialize_reader()
self.file.read_into(file)
self.file.close()
def remove(self):
self.close()
self.file.remove()
def finish_write(self):
self.close()
def close(self):
self.file.close()


@@ -0,0 +1,162 @@
from __future__ import annotations
import abc
import io
import os
import shutil
import typing
class FormatIOBackend(abc.ABC):
def __init__(self, fd, filename, buffer=1024 * 1024):
self.fd = fd
self.filename = filename
self.write_buffer = buffer
self._fp = None
self._fp_mode = None
@property
def writer(self):
if self._fp is None or self._fp_mode != 'write':
return None
return self._fp
@property
def reader(self):
if self._fp is None or self._fp_mode != 'read':
return None
return self._fp
def initialize_writer(self, resume=False):
if self._fp is not None:
raise ValueError('Backend already initialized')
self._fp = self._create_writer(resume)
self._fp_mode = 'write'
@abc.abstractmethod
def _create_writer(self, resume=False) -> typing.IO:
pass
def initialize_reader(self):
if self._fp is not None:
raise ValueError('Backend already initialized')
self._fp = self._create_reader()
self._fp_mode = 'read'
@abc.abstractmethod
def _create_reader(self) -> typing.IO:
pass
def close(self):
if self._fp and not self._fp.closed:
self._fp.flush()
self._fp.close()
self._fp = None
self._fp_mode = None
@abc.abstractmethod
def validate_length(self, expected_length):
pass
def remove(self):
self.close()
self._remove()
@abc.abstractmethod
def _remove(self):
pass
@abc.abstractmethod
def exists(self):
pass
@property
def mode(self):
if self._fp is None:
return None
return self._fp_mode
def write(self, data: io.BufferedIOBase | bytes):
if not self.writer:
raise ValueError('Backend writer not initialized')
if isinstance(data, bytes):
bytes_written = self.writer.write(data)
elif isinstance(data, io.BufferedIOBase):
bytes_written = self.writer.tell()
shutil.copyfileobj(data, self.writer, length=self.write_buffer)
bytes_written = self.writer.tell() - bytes_written
else:
raise TypeError('Data must be bytes or a BufferedIOBase object')
self.writer.flush()
return bytes_written
def read_into(self, backend):
if not backend.writer:
raise ValueError('Backend writer not initialized')
if not self.reader:
raise ValueError('Backend reader not initialized')
shutil.copyfileobj(self.reader, backend.writer, length=self.write_buffer)
backend.writer.flush()
class DiskFormatIOBackend(FormatIOBackend):
def _create_writer(self, resume=False) -> typing.IO:
if resume and self.exists():
write_fp, self.filename = self.fd.sanitize_open(self.filename, 'ab')
else:
write_fp, self.filename = self.fd.sanitize_open(self.filename, 'wb')
return write_fp
def _create_reader(self) -> typing.IO:
read_fp, self.filename = self.fd.sanitize_open(self.filename, 'rb')
return read_fp
def validate_length(self, expected_length):
return os.path.getsize(self.filename) == expected_length
def _remove(self):
self.fd.try_remove(self.filename)
def exists(self):
return os.path.isfile(self.filename)
class MemoryFormatIOBackend(FormatIOBackend):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._memory_store = io.BytesIO()
def _create_writer(self, resume=False) -> typing.IO:
class NonClosingBufferedWriter(io.BufferedWriter):
def close(self):
self.flush()
# Do not close the underlying buffer
if resume and self.exists():
self._memory_store.seek(0, io.SEEK_END)
else:
self._memory_store.seek(0)
self._memory_store.truncate(0)
return NonClosingBufferedWriter(self._memory_store)
def _create_reader(self) -> typing.IO:
class NonClosingBufferedReader(io.BufferedReader):
def close(self):
self.flush()
# Seek to the beginning of the buffer
self._memory_store.seek(0)
return NonClosingBufferedReader(self._memory_store)
def validate_length(self, expected_length):
return self._memory_store.getbuffer().nbytes == expected_length
def _remove(self):
self._memory_store = io.BytesIO()
def exists(self):
return self._memory_store.getbuffer().nbytes > 0


@@ -0,0 +1,46 @@
from __future__ import annotations
from yt_dlp.utils import format_field, traverse_obj
from yt_dlp.extractor.youtube._streaming.sabr.models import SabrLogger
from yt_dlp.utils._utils import _YDLLogger
# TODO: create a logger that logs to a file rather than the console.
# Might be useful for debugging SABR issues from users.
class SabrFDLogger(SabrLogger):
def __init__(self, ydl, prefix, log_level: SabrLogger.LogLevel | None = None):
self._ydl_logger = _YDLLogger(ydl)
self.prefix = prefix
self.log_level = log_level if log_level is not None else self.LogLevel.INFO
def _format_msg(self, message: str):
prefixstr = format_field(self.prefix, None, '[%s] ')
return f'{prefixstr}{message}'
def trace(self, message: str):
if self.log_level <= self.LogLevel.TRACE:
self._ydl_logger.debug(self._format_msg('TRACE: ' + message))
def debug(self, message: str):
if self.log_level <= self.LogLevel.DEBUG:
self._ydl_logger.debug(self._format_msg(message))
def info(self, message: str):
if self.log_level <= self.LogLevel.INFO:
self._ydl_logger.info(self._format_msg(message))
def warning(self, message: str, *, once=False):
if self.log_level <= self.LogLevel.WARNING:
self._ydl_logger.warning(self._format_msg(message), once=once)
def error(self, message: str):
if self.log_level <= self.LogLevel.ERROR:
self._ydl_logger.error(self._format_msg(message), is_error=False)
def create_sabrfd_logger(ydl, prefix):
return SabrFDLogger(
ydl, prefix=prefix,
log_level=SabrFDLogger.LogLevel(traverse_obj(
ydl.params, ('extractor_args', 'youtube', 'sabr_log_level', 0, {str}), get_all=False)))


@@ -0,0 +1,77 @@
from __future__ import annotations
import contextlib
import os
import tempfile
from yt_dlp.dependencies import protobug
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId
@protobug.message
class SabrStateSegment:
sequence_number: protobug.Int32 = protobug.field(1)
start_time_ms: protobug.Int64 = protobug.field(2)
duration_ms: protobug.Int64 = protobug.field(3)
duration_estimated: protobug.Bool = protobug.field(4)
content_length: protobug.Int64 = protobug.field(5)
@protobug.message
class SabrStateSequence:
sequence_start_number: protobug.Int32 = protobug.field(1)
sequence_content_length: protobug.Int64 = protobug.field(2)
first_segment: SabrStateSegment = protobug.field(3)
last_segment: SabrStateSegment = protobug.field(4)
@protobug.message
class SabrStateInitSegment:
content_length: protobug.Int64 = protobug.field(2)
@protobug.message
class SabrState:
format_id: FormatId = protobug.field(1)
init_segment: SabrStateInitSegment | None = protobug.field(2, default=None)
sequences: list[SabrStateSequence] = protobug.field(3, default_factory=list)
class SabrStateFile:
def __init__(self, format_filename, fd):
self.filename = format_filename + '.sabr.state'
self.fd = fd
@property
def exists(self):
return os.path.isfile(self.filename)
def retrieve(self):
stream, self.filename = self.fd.sanitize_open(self.filename, 'rb')
try:
return self.deserialize(stream.read())
finally:
stream.close()
def update(self, sabr_document):
# Attempt to write progress document somewhat atomically to avoid corruption
tf = tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(self.filename))
tf.close()  # close the handle so the path can be re-opened for writing on Windows
try:
with open(tf.name, 'wb') as f:
f.write(self.serialize(sabr_document))
f.flush()
os.fsync(f.fileno())
os.replace(tf.name, self.filename)
finally:
if os.path.exists(tf.name):
with contextlib.suppress(FileNotFoundError, OSError):
os.unlink(tf.name)
def serialize(self, sabr_document):
return protobug.dumps(sabr_document)
def deserialize(self, data):
return protobug.loads(data, SabrState)
def remove(self):
self.fd.try_remove(self.filename)
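
A hedged round-trip sketch of the state document (field values are illustrative; protobug.dumps/protobug.loads as used above):

state = SabrState(
    format_id=FormatId(itag=140),  # example itag
    init_segment=SabrStateInitSegment(content_length=1234),
)
blob = protobug.dumps(state)  # the bytes written to '<format filename>.sabr.state'
restored = protobug.loads(blob, SabrState)
assert restored.init_segment.content_length == 1234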


@@ -0,0 +1,355 @@
from __future__ import annotations
import dataclasses
from ._io import DiskFormatIOBackend
from ._file import SequenceFile, Sequence, Segment
from ._state import (
SabrStateSegment,
SabrStateSequence,
SabrStateInitSegment,
SabrState,
SabrStateFile,
)
from yt_dlp.extractor.youtube._streaming.sabr.part import (
MediaSegmentInitSabrPart,
MediaSegmentDataSabrPart,
MediaSegmentEndSabrPart,
)
from yt_dlp.utils import DownloadError
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId
from yt_dlp.utils.progress import ProgressCalculator
INIT_SEGMENT_ID = 'i'
@dataclasses.dataclass
class SabrFormatState:
format_id: FormatId
init_sequence: Sequence | None = None
sequences: list[Sequence] = dataclasses.field(default_factory=list)
class SabrFDFormatWriter:
def __init__(self, fd, filename, infodict, progress_idx=0, resume=False):
self.fd = fd
self.info_dict = infodict
self.filename = filename
self.progress_idx = progress_idx
self.resume = resume
self._progress = None
self._downloaded_bytes = 0
self._state = {}
self._format_id = None
self.file = DiskFormatIOBackend(
fd=self.fd,
filename=self.fd.temp_name(filename),
)
self._sabr_state_file = SabrStateFile(format_filename=self.filename, fd=fd)
self._sequence_files: list[SequenceFile] = []
self._init_sequence: SequenceFile | None = None
@property
def state(self):
return SabrFormatState(
format_id=self._format_id,
init_sequence=self._init_sequence.sequence if self._init_sequence else None,
sequences=[sf.sequence for sf in self._sequence_files],
)
@property
def downloaded_bytes(self):
return (sum(
sequence.current_length for sequence in self._sequence_files)
+ (self._init_sequence.current_length if self._init_sequence else 0))
def initialize_format(self, format_id):
if self._format_id:
raise ValueError('Already initialized')
self._format_id = format_id
if not self.resume:
if self._sabr_state_file.exists:
self._sabr_state_file.remove()
return
document = self._load_sabr_state()
if document.init_segment:
init_segment = Segment(
segment_id=INIT_SEGMENT_ID,
content_length=document.init_segment.content_length,
is_init_segment=True,
)
try:
self._init_sequence = SequenceFile(
fd=self.fd,
format_filename=self.filename,
resume=True,
sequence=Sequence(
sequence_id=INIT_SEGMENT_ID,
sequence_content_length=init_segment.content_length,
first_segment=init_segment,
last_segment=init_segment,
))
except DownloadError as e:
self.fd.report_warning(f'Failed to resume init segment for format {self.info_dict.get("format_id")}: {e}')
for sabr_sequence in list(document.sequences):
try:
self._sequence_files.append(SequenceFile(
fd=self.fd,
format_filename=self.filename,
resume=True,
sequence=Sequence(
sequence_id=str(sabr_sequence.sequence_start_number),
sequence_content_length=sabr_sequence.sequence_content_length,
first_segment=Segment(
segment_id=str(sabr_sequence.first_segment.sequence_number),
sequence_number=sabr_sequence.first_segment.sequence_number,
content_length=sabr_sequence.first_segment.content_length,
start_time_ms=sabr_sequence.first_segment.start_time_ms,
duration_ms=sabr_sequence.first_segment.duration_ms,
is_init_segment=False,
),
last_segment=Segment(
segment_id=str(sabr_sequence.last_segment.sequence_number),
sequence_number=sabr_sequence.last_segment.sequence_number,
content_length=sabr_sequence.last_segment.content_length,
start_time_ms=sabr_sequence.last_segment.start_time_ms,
duration_ms=sabr_sequence.last_segment.duration_ms,
is_init_segment=False,
),
),
))
except DownloadError as e:
self.fd.report_warning(
f'Failed to resume sequence {sabr_sequence.sequence_start_number} '
f'for format {self.info_dict.get("format_id")}: {e}')
@property
def initialized(self):
return self._format_id is not None
def close(self):
if not self.file:
raise ValueError('Already closed')
for sequence in self._sequence_files:
sequence.close()
self._sequence_files.clear()
if self._init_sequence:
self._init_sequence.close()
self._init_sequence = None
self.file.close()
def _find_sequence_file(self, predicate):
match = None
for sequence in self._sequence_files:
if predicate(sequence):
if match is not None:
raise DownloadError('Multiple sequence files found for segment')
match = sequence
return match
def find_next_sequence_file(self, next_segment: Segment):
return self._find_sequence_file(lambda sequence: sequence.is_next_segment(next_segment))
def find_current_sequence_file(self, segment_id: str):
return self._find_sequence_file(lambda sequence: sequence.is_current_segment(segment_id))
def initialize_segment(self, part: MediaSegmentInitSabrPart):
if not self._progress:
self._progress = ProgressCalculator(part.start_bytes)
if not self._format_id:
raise ValueError('not initialized')
if part.is_init_segment:
if not self._init_sequence:
self._init_sequence = SequenceFile(
fd=self.fd,
format_filename=self.filename,
resume=False,
sequence=Sequence(
sequence_id=INIT_SEGMENT_ID,
))
self._init_sequence.initialize_segment(Segment(
segment_id=INIT_SEGMENT_ID,
content_length=part.content_length,
content_length_estimated=part.content_length_estimated,
is_init_segment=True,
))
return True
segment = Segment(
segment_id=str(part.sequence_number),
sequence_number=part.sequence_number,
start_time_ms=part.start_time_ms,
duration_ms=part.duration_ms,
duration_estimated=part.duration_estimated,
content_length=part.content_length,
content_length_estimated=part.content_length_estimated,
)
sequence_file = self.find_current_sequence_file(segment.segment_id) or self.find_next_sequence_file(segment)
if not sequence_file:
sequence_file = SequenceFile(
fd=self.fd,
format_filename=self.filename,
resume=False,
sequence=Sequence(sequence_id=str(part.sequence_number)),
)
self._sequence_files.append(sequence_file)
sequence_file.initialize_segment(segment)
return True
def write_segment_data(self, part: MediaSegmentDataSabrPart):
if part.is_init_segment:
sequence_file, segment_id = self._init_sequence, INIT_SEGMENT_ID
else:
segment_id = str(part.sequence_number)
sequence_file = self.find_current_sequence_file(segment_id)
if not sequence_file:
raise DownloadError('Unable to find sequence file for segment. Was the segment initialized?')
sequence_file.write_segment_data(part.data, segment_id)
# TODO: Handling of disjointed segments (e.g. when downloading segments out of order / concurrently)
self._progress.total = self.info_dict.get('filesize')
self._state = {
'status': 'downloading',
'downloaded_bytes': self.downloaded_bytes,
'total_bytes': self.info_dict.get('filesize'),
'filename': self.filename,
'eta': self._progress.eta.smooth,
'speed': self._progress.speed.smooth,
'elapsed': self._progress.elapsed,
'progress_idx': self.progress_idx,
'fragment_count': part.total_segments,
'fragment_index': part.sequence_number,
}
self._progress.update(self._state['downloaded_bytes'])
self.fd._hook_progress(self._state, self.info_dict)
def end_segment(self, part: MediaSegmentEndSabrPart):
if part.is_init_segment:
sequence_file, segment_id = self._init_sequence, INIT_SEGMENT_ID
else:
segment_id = str(part.sequence_number)
sequence_file = self.find_current_sequence_file(segment_id)
if not sequence_file:
raise DownloadError('Unable to find sequence file for segment. Was the segment initialized?')
sequence_file.end_segment(segment_id)
self._write_sabr_state()
def _load_sabr_state(self):
sabr_state = None
if self._sabr_state_file.exists:
try:
sabr_state = self._sabr_state_file.retrieve()
except Exception:
self.fd.report_warning(
f'Corrupted state file for format {self.info_dict.get("format_id")}, restarting download')
if sabr_state and sabr_state.format_id != self._format_id:
self.fd.report_warning(
f'Format ID mismatch in state file for {self.info_dict.get("format_id")}, restarting download')
sabr_state = None
if not sabr_state:
sabr_state = SabrState(format_id=self._format_id)
return sabr_state
def _write_sabr_state(self):
sabr_state = SabrState(format_id=self._format_id)
if not self._init_sequence:
sabr_state.init_segment = None
else:
sabr_state.init_segment = SabrStateInitSegment(
content_length=self._init_sequence.sequence.sequence_content_length,
)
sabr_state.sequences = []
for sequence_file in self._sequence_files:
# Ignore partial sequences
if not sequence_file.sequence.first_segment or not sequence_file.sequence.last_segment:
continue
sabr_state.sequences.append(SabrStateSequence(
sequence_start_number=sequence_file.sequence.first_segment.sequence_number,
sequence_content_length=sequence_file.sequence.sequence_content_length,
first_segment=SabrStateSegment(
sequence_number=sequence_file.sequence.first_segment.sequence_number,
start_time_ms=sequence_file.sequence.first_segment.start_time_ms,
duration_ms=sequence_file.sequence.first_segment.duration_ms,
duration_estimated=sequence_file.sequence.first_segment.duration_estimated,
content_length=sequence_file.sequence.first_segment.content_length,
),
last_segment=SabrStateSegment(
sequence_number=sequence_file.sequence.last_segment.sequence_number,
start_time_ms=sequence_file.sequence.last_segment.start_time_ms,
duration_ms=sequence_file.sequence.last_segment.duration_ms,
duration_estimated=sequence_file.sequence.last_segment.duration_estimated,
content_length=sequence_file.sequence.last_segment.content_length,
),
))
self._sabr_state_file.update(sabr_state)
def finish(self):
self._state['status'] = 'finished'
self.fd._hook_progress(self._state, self.info_dict)
for sequence_file in self._sequence_files:
sequence_file.close()
if self._init_sequence:
self._init_sequence.close()
# Now merge all the sequences together
self.file.initialize_writer(resume=False)
# Note: May not always be an init segment, e.g for live streams
if self._init_sequence:
self._init_sequence.read_into(self.file)
self._init_sequence.close()
# TODO: handling of disjointed segments
previous_seq_number = None
for sequence_file in sorted(
(sf for sf in self._sequence_files if sf.sequence.first_segment),
key=lambda s: s.sequence.first_segment.sequence_number):
if previous_seq_number and previous_seq_number + 1 != sequence_file.sequence.first_segment.sequence_number:
self.fd.report_warning(f'Disjointed sequences found in SABR format {self.info_dict.get("format_id")}')
previous_seq_number = sequence_file.sequence.last_segment.sequence_number
sequence_file.read_into(self.file)
sequence_file.close()
# Format temp file should have all the segments, rename it to the final name
self.file.close()
self.fd.try_rename(self.file.filename, self.fd.undo_temp_name(self.file.filename))
# Remove the state file
self._sabr_state_file.remove()
# Remove sequence files
for sf in self._sequence_files:
sf.close()
sf.remove()
if self._init_sequence:
self._init_sequence.close()
self._init_sequence.remove()
self.close()


@@ -0,0 +1,14 @@
import dataclasses
import typing
def unknown_fields(obj: typing.Any, path=()) -> typing.Iterable[tuple[tuple[str, ...], dict[int, list]]]:
if not dataclasses.is_dataclass(obj):
return
if unknown := getattr(obj, '_unknown', None):
yield path, unknown
for field in dataclasses.fields(obj):
value = getattr(obj, field.name)
yield from unknown_fields(value, (*path, field.name))
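
A small usage sketch (hand-built dataclasses stand in for protobug messages; the assumption, mirrored from the getattr above, is that unparsed fields are recorded on a _unknown attribute):

import dataclasses

@dataclasses.dataclass
class Inner:
    value: int = 0

@dataclasses.dataclass
class Outer:
    inner: Inner = dataclasses.field(default_factory=Inner)

outer = Outer()
outer.inner._unknown = {99: [b'\x01']}  # hypothetical unknown field as protobug might record it
print(list(unknown_fields(outer)))  # [(('inner',), {99: [b'\x01']})]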


@@ -0,0 +1,5 @@
from .client_info import ClientInfo, ClientName # noqa: F401
from .compression_algorithm import CompressionAlgorithm # noqa: F401
from .next_request_policy import NextRequestPolicy # noqa: F401
from .range import Range # noqa: F401
from .seek_source import SeekSource # noqa: F401


@@ -0,0 +1,105 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
class ClientName(protobug.Enum, strict=False):
UNKNOWN_INTERFACE = 0
WEB = 1
MWEB = 2
ANDROID = 3
IOS = 5
TVHTML5 = 7
TVLITE = 8
TVANDROID = 10
XBOX = 11
CLIENTX = 12
XBOXONEGUIDE = 13
ANDROID_CREATOR = 14
IOS_CREATOR = 15
TVAPPLE = 16
IOS_INSTANT = 17
ANDROID_KIDS = 18
IOS_KIDS = 19
ANDROID_INSTANT = 20
ANDROID_MUSIC = 21
IOS_TABLOID = 22
ANDROID_TV = 23
ANDROID_GAMING = 24
IOS_GAMING = 25
IOS_MUSIC = 26
MWEB_TIER_2 = 27
ANDROID_VR = 28
ANDROID_UNPLUGGED = 29
ANDROID_TESTSUITE = 30
WEB_MUSIC_ANALYTICS = 31
WEB_GAMING = 32
IOS_UNPLUGGED = 33
ANDROID_WITNESS = 34
IOS_WITNESS = 35
ANDROID_SPORTS = 36
IOS_SPORTS = 37
ANDROID_LITE = 38
IOS_EMBEDDED_PLAYER = 39
IOS_DIRECTOR = 40
WEB_UNPLUGGED = 41
WEB_EXPERIMENTS = 42
TVHTML5_CAST = 43
WEB_EMBEDDED_PLAYER = 56
TVHTML5_AUDIO = 57
TV_UNPLUGGED_CAST = 58
TVHTML5_KIDS = 59
WEB_HEROES = 60
WEB_MUSIC = 61
WEB_CREATOR = 62
TV_UNPLUGGED_ANDROID = 63
IOS_LIVE_CREATION_EXTENSION = 64
TVHTML5_UNPLUGGED = 65
IOS_MESSAGES_EXTENSION = 66
WEB_REMIX = 67
IOS_UPTIME = 68
WEB_UNPLUGGED_ONBOARDING = 69
WEB_UNPLUGGED_OPS = 70
WEB_UNPLUGGED_PUBLIC = 71
TVHTML5_VR = 72
WEB_LIVE_STREAMING = 73
ANDROID_TV_KIDS = 74
TVHTML5_SIMPLY = 75
WEB_KIDS = 76
MUSIC_INTEGRATIONS = 77
TVHTML5_YONGLE = 80
GOOGLE_ASSISTANT = 84
TVHTML5_SIMPLY_EMBEDDED_PLAYER = 85
WEB_MUSIC_EMBEDDED_PLAYER = 86
WEB_INTERNAL_ANALYTICS = 87
WEB_PARENT_TOOLS = 88
GOOGLE_MEDIA_ACTIONS = 89
WEB_PHONE_VERIFICATION = 90
ANDROID_PRODUCER = 91
IOS_PRODUCER = 92
TVHTML5_FOR_KIDS = 93
GOOGLE_LIST_RECS = 94
MEDIA_CONNECT_FRONTEND = 95
WEB_EFFECT_MAKER = 98
WEB_SHOPPING_EXTENSION = 99
WEB_PLAYABLES_PORTAL = 100
VISIONOS = 101
WEB_LIVE_APPS = 102
WEB_MUSIC_INTEGRATIONS = 103
ANDROID_MUSIC_AOSP = 104
@protobug.message
class ClientInfo:
hl: protobug.String | None = protobug.field(1, default=None)
gl: protobug.String | None = protobug.field(2, default=None)
remote_host: protobug.String | None = protobug.field(4, default=None)
device_make: protobug.String | None = protobug.field(12, default=None)
device_model: protobug.String | None = protobug.field(13, default=None)
visitor_data: protobug.String | None = protobug.field(14, default=None)
user_agent: protobug.String | None = protobug.field(15, default=None)
client_name: ClientName | None = protobug.field(16, default=None)
client_version: protobug.String | None = protobug.field(17, default=None)
os_name: protobug.String | None = protobug.field(18, default=None)
os_version: protobug.String | None = protobug.field(19, default=None)


@@ -0,0 +1,7 @@
from yt_dlp.dependencies import protobug
class CompressionAlgorithm(protobug.Enum, strict=False):
COMPRESSION_ALGORITHM_UNKNOWN = 0
COMPRESSION_ALGORITHM_NONE = 1
COMPRESSION_ALGORITHM_GZIP = 2


@@ -0,0 +1,15 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
@protobug.message
class NextRequestPolicy:
target_audio_readahead_ms: protobug.Int32 | None = protobug.field(1, default=None)
target_video_readahead_ms: protobug.Int32 | None = protobug.field(2, default=None)
max_time_since_last_request_ms: protobug.Int32 | None = protobug.field(3, default=None)
backoff_time_ms: protobug.Int32 | None = protobug.field(4, default=None)
min_audio_readahead_ms: protobug.Int32 | None = protobug.field(5, default=None)
min_video_readahead_ms: protobug.Int32 | None = protobug.field(6, default=None)
playback_cookie: protobug.Bytes | None = protobug.field(7, default=None)
video_id: protobug.String | None = protobug.field(8, default=None)


@@ -0,0 +1,9 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
@protobug.message
class Range:
start: protobug.Int64 | None = protobug.field(1, default=None)
end: protobug.Int64 | None = protobug.field(2, default=None)


@@ -0,0 +1,14 @@
from yt_dlp.dependencies import protobug
class SeekSource(protobug.Enum, strict=False):
SEEK_SOURCE_UNKNOWN = 0
SEEK_SOURCE_SABR_PARTIAL_CHUNK = 9
SEEK_SOURCE_SABR_SEEK_TO_HEAD = 10
SEEK_SOURCE_SABR_LIVE_DVR_USER_SEEK = 11
SEEK_SOURCE_SABR_SEEK_TO_DVR_LOWER_BOUND = 12
SEEK_SOURCE_SABR_SEEK_TO_DVR_UPPER_BOUND = 13
SEEK_SOURCE_SABR_ACCURATE_SEEK = 17
SEEK_SOURCE_SABR_INGESTION_WALL_TIME_SEEK = 29
SEEK_SOURCE_SABR_SEEK_TO_CLOSEST_KEYFRAME = 59
SEEK_SOURCE_SABR_RELOAD_PLAYER_RESPONSE_TOKEN_SEEK = 106


@@ -0,0 +1,16 @@
from .buffered_range import BufferedRange # noqa: F401
from .client_abr_state import ClientAbrState # noqa: F401
from .format_id import FormatId # noqa: F401
from .format_initialization_metadata import FormatInitializationMetadata # noqa: F401
from .live_metadata import LiveMetadata # noqa: F401
from .media_header import MediaHeader # noqa: F401
from .reload_player_response import ReloadPlayerResponse # noqa: F401
from .sabr_context_sending_policy import SabrContextSendingPolicy # noqa: F401
from .sabr_context_update import SabrContextUpdate # noqa: F401
from .sabr_error import SabrError # noqa: F401
from .sabr_redirect import SabrRedirect # noqa: F401
from .sabr_seek import SabrSeek # noqa: F401
from .stream_protection_status import StreamProtectionStatus # noqa: F401
from .streamer_context import SabrContext, StreamerContext # noqa: F401
from .time_range import TimeRange # noqa: F401
from .video_playback_abr_request import VideoPlaybackAbrRequest # noqa: F401


@@ -0,0 +1,16 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
from .format_id import FormatId
from .time_range import TimeRange
@protobug.message
class BufferedRange:
format_id: FormatId | None = protobug.field(1, default=None)
start_time_ms: protobug.Int64 | None = protobug.field(2, default=None)
duration_ms: protobug.Int64 | None = protobug.field(3, default=None)
start_segment_index: protobug.Int32 | None = protobug.field(4, default=None)
end_segment_index: protobug.Int32 | None = protobug.field(5, default=None)
time_range: TimeRange | None = protobug.field(6, default=None)


@@ -0,0 +1,9 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
@protobug.message
class ClientAbrState:
player_time_ms: protobug.Int64 | None = protobug.field(28, default=None)
enabled_track_types_bitfield: protobug.Int32 | None = protobug.field(40, default=None)


@@ -0,0 +1,10 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
@protobug.message
class FormatId:
itag: protobug.Int32 | None = protobug.field(1)
lmt: protobug.UInt64 | None = protobug.field(2, default=None)
xtags: protobug.String | None = protobug.field(3, default=None)


@@ -0,0 +1,19 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
from .format_id import FormatId
from ..innertube import Range
@protobug.message
class FormatInitializationMetadata:
video_id: protobug.String | None = protobug.field(1, default=None)
format_id: FormatId | None = protobug.field(2, default=None)
end_time_ms: protobug.Int32 | None = protobug.field(3, default=None)
total_segments: protobug.Int32 | None = protobug.field(4, default=None)
mime_type: protobug.String | None = protobug.field(5, default=None)
init_range: Range | None = protobug.field(6, default=None)
index_range: Range | None = protobug.field(7, default=None)
duration_ticks: protobug.Int32 | None = protobug.field(9, default=None)
duration_timescale: protobug.Int32 | None = protobug.field(10, default=None)


@@ -0,0 +1,18 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
@protobug.message
class LiveMetadata:
head_sequence_number: protobug.Int32 | None = protobug.field(3, default=None)
head_sequence_time_ms: protobug.Int64 | None = protobug.field(4, default=None)
wall_time_ms: protobug.Int64 | None = protobug.field(5, default=None)
video_id: protobug.String | None = protobug.field(6, default=None)
source: protobug.String | None = protobug.field(7, default=None)
min_seekable_time_ticks: protobug.Int64 | None = protobug.field(12, default=None)
min_seekable_timescale: protobug.Int32 | None = protobug.field(13, default=None)
max_seekable_time_ticks: protobug.Int64 | None = protobug.field(14, default=None)
max_seekable_timescale: protobug.Int32 | None = protobug.field(15, default=None)


@@ -0,0 +1,27 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
from .format_id import FormatId
from .time_range import TimeRange
from ..innertube import CompressionAlgorithm
@protobug.message
class MediaHeader:
header_id: protobug.UInt32 | None = protobug.field(1, default=None)
video_id: protobug.String | None = protobug.field(2, default=None)
itag: protobug.Int32 | None = protobug.field(3, default=None)
last_modified: protobug.UInt64 | None = protobug.field(4, default=None)
xtags: protobug.String | None = protobug.field(5, default=None)
start_data_range: protobug.Int32 | None = protobug.field(6, default=None)
compression: CompressionAlgorithm | None = protobug.field(7, default=None)
is_init_segment: protobug.Bool | None = protobug.field(8, default=None)
sequence_number: protobug.Int64 | None = protobug.field(9, default=None)
bitrate_bps: protobug.Int64 | None = protobug.field(10, default=None)
start_ms: protobug.Int32 | None = protobug.field(11, default=None)
duration_ms: protobug.Int32 | None = protobug.field(12, default=None)
format_id: FormatId | None = protobug.field(13, default=None)
content_length: protobug.Int64 | None = protobug.field(14, default=None)
time_range: TimeRange | None = protobug.field(15, default=None)
sequence_lmt: protobug.Int32 | None = protobug.field(16, default=None)

View File

@ -0,0 +1,13 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
@protobug.message
class ReloadPlaybackParams:
token: protobug.String | None = protobug.field(1, default=None)
@protobug.message
class ReloadPlayerResponse:
reload_playback_params: ReloadPlaybackParams | None = protobug.field(1, default=None)

View File

@ -0,0 +1,13 @@
from yt_dlp.dependencies import protobug
@protobug.message
class SabrContextSendingPolicy:
# Start sending the SabrContextUpdates of this type
start_policy: list[protobug.Int32] = protobug.field(1, default_factory=list)
# Stop sending the SabrContextUpdates of this type
stop_policy: list[protobug.Int32] = protobug.field(2, default_factory=list)
# Stop and discard the SabrContextUpdates of this type
discard_policy: list[protobug.Int32] = protobug.field(3, default_factory=list)
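
The three policy lists drive a small piece of client state: which stored SabrContextUpdates to attach to subsequent requests, which to stop attaching, and which to forget entirely. A minimal sketch of that handling (the SabrProcessor later in this diff implements the real version against its sabr_contexts_to_send set and sabr_context_updates dict):

to_send: set[int] = set()       # context types to attach to the next request
updates: dict[int, bytes] = {}  # stored context payloads, keyed by type

policy = SabrContextSendingPolicy(start_policy=[2], stop_policy=[], discard_policy=[3])
to_send.update(policy.start_policy)            # start sending these context types
to_send.difference_update(policy.stop_policy)  # stop sending these
for discard_type in policy.discard_policy:     # forget the stored update entirely
    updates.pop(discard_type, None)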

View File

@ -0,0 +1,25 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
@protobug.message
class SabrContextUpdate:
class SabrContextScope(protobug.Enum, strict=False):
SABR_CONTEXT_SCOPE_UNKNOWN = 0
SABR_CONTEXT_SCOPE_PLAYBACK = 1
SABR_CONTEXT_SCOPE_REQUEST = 2
SABR_CONTEXT_SCOPE_WATCH_ENDPOINT = 3
SABR_CONTEXT_SCOPE_CONTENT_ADS = 4
class SabrContextWritePolicy(protobug.Enum, strict=False):
SABR_CONTEXT_WRITE_POLICY_UNSPECIFIED = 0
SABR_CONTEXT_WRITE_POLICY_OVERWRITE = 1
SABR_CONTEXT_WRITE_POLICY_KEEP_EXISTING = 2
type: protobug.Int32 | None = protobug.field(1, default=None)
scope: SabrContextScope | None = protobug.field(2, default=None)
value: protobug.Bytes | None = protobug.field(3, default=None)
send_by_default: protobug.Bool | None = protobug.field(4, default=None)
write_policy: SabrContextWritePolicy | None = protobug.field(5, default=None)

View File

@ -0,0 +1,16 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
@protobug.message
class Error:
status_code: protobug.Int32 | None = protobug.field(1, default=None)
type: protobug.Int32 | None = protobug.field(4, default=None)
@protobug.message
class SabrError:
type: protobug.String | None = protobug.field(1, default=None)
action: protobug.Int32 | None = protobug.field(2, default=None)
error: Error | None = protobug.field(3, default=None)

View File

@ -0,0 +1,8 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
@protobug.message
class SabrRedirect:
redirect_url: protobug.String | None = protobug.field(1, default=None)

View File

@ -0,0 +1,12 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
from ..innertube import SeekSource
@protobug.message
class SabrSeek:
seek_time_ticks: protobug.Int32 = protobug.field(1)
timescale: protobug.Int32 = protobug.field(2)
seek_source: SeekSource | None = protobug.field(3, default=None)

View File

@ -0,0 +1,15 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
@protobug.message
class StreamProtectionStatus:
class Status(protobug.Enum, strict=False):
OK = 1
ATTESTATION_PENDING = 2
ATTESTATION_REQUIRED = 3
status: Status | None = protobug.field(1, default=None)
max_retries: protobug.Int32 | None = protobug.field(2, default=None)

View File

@ -0,0 +1,21 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
from ..innertube import ClientInfo
@protobug.message
class SabrContext:
# Type and Value from a SabrContextUpdate
type: protobug.Int32 | None = protobug.field(1, default=None)
value: protobug.Bytes | None = protobug.field(2, default=None)
@protobug.message
class StreamerContext:
client_info: ClientInfo | None = protobug.field(1, default=None)
po_token: protobug.Bytes | None = protobug.field(2, default=None)
playback_cookie: protobug.Bytes | None = protobug.field(3, default=None)
sabr_contexts: list[SabrContext] = protobug.field(5, default_factory=list)
unsent_sabr_contexts: list[protobug.Int32] = protobug.field(6, default_factory=list)

View File

@ -0,0 +1,10 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
@protobug.message
class TimeRange:
start_ticks: protobug.Int64 | None = protobug.field(1, default=None)
duration_ticks: protobug.Int64 | None = protobug.field(2, default=None)
timescale: protobug.Int32 | None = protobug.field(3, default=None)
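
A TimeRange expresses positions as ticks against a timescale (ticks per second), so converting to milliseconds is a simple proportion. Later files in this diff import a ticks_to_ms helper for this; below is a minimal sketch of such a helper, with the argument order inferred from its call sites (ticks first, then timescale) and the rounding behavior assumed:

def ticks_to_ms(ticks, timescale):
    # e.g. 180000 ticks at a timescale of 90000 ticks/sec -> 2000 ms
    if ticks is None or not timescale:
        return None
    return int(ticks * 1000 / timescale)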

View File

@ -0,0 +1,22 @@
from __future__ import annotations
from yt_dlp.dependencies import protobug
from .buffered_range import BufferedRange
from .client_abr_state import ClientAbrState
from .format_id import FormatId
from .streamer_context import StreamerContext
@protobug.message
class VideoPlaybackAbrRequest:
    client_abr_state: ClientAbrState | None = protobug.field(1, default=None)
initialized_format_ids: list[FormatId] = protobug.field(2, default_factory=list)
buffered_ranges: list[BufferedRange] = protobug.field(3, default_factory=list)
player_time_ms: protobug.Int64 | None = protobug.field(4, default=None)
video_playback_ustreamer_config: protobug.Bytes | None = protobug.field(5, default=None)
selected_audio_format_ids: list[FormatId] = protobug.field(16, default_factory=list)
selected_video_format_ids: list[FormatId] = protobug.field(17, default_factory=list)
selected_caption_format_ids: list[FormatId] = protobug.field(18, default_factory=list)
streamer_context: StreamerContext = protobug.field(19, default_factory=StreamerContext)
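
These protobug messages behave like typed protobuf dataclasses. As a quick orientation, a round-trip sketch: protobug.dumps appears later in this diff when the request is POSTed, while the loads(data, message_class) form shown here is an assumption about the protobug API:

from yt_dlp.dependencies import protobug

req = VideoPlaybackAbrRequest(
    client_abr_state=ClientAbrState(player_time_ms=1000, enabled_track_types_bitfield=0),
    video_playback_ustreamer_config=b'\x00',  # placeholder; the real config comes from the player response
)
payload = protobug.dumps(req)  # bytes suitable for an application/x-protobuf POST body
decoded = protobug.loads(payload, VideoPlaybackAbrRequest)  # assumed protobug API
assert decoded.client_abr_state.player_time_ms == 1000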

View File

@ -0,0 +1,26 @@
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId
from yt_dlp.utils import YoutubeDLError
class SabrStreamConsumedError(YoutubeDLError):
pass
class SabrStreamError(YoutubeDLError):
pass
class MediaSegmentMismatchError(SabrStreamError):
def __init__(self, format_id: FormatId, expected_sequence_number: int, received_sequence_number: int):
super().__init__(
f'Segment sequence number mismatch for format {format_id}: '
f'expected {expected_sequence_number}, received {received_sequence_number}')
self.expected_sequence_number = expected_sequence_number
self.received_sequence_number = received_sequence_number
class PoTokenError(SabrStreamError):
def __init__(self, missing=False):
super().__init__(
f'This stream requires a GVS PO Token to continue'
f'{" and the one provided is invalid" if not missing else ""}')

View File

@ -0,0 +1,97 @@
from __future__ import annotations
import dataclasses
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId
from yt_dlp.extractor.youtube.pot._provider import IEContentProviderLogger
@dataclasses.dataclass
class Segment:
format_id: FormatId
is_init_segment: bool = False
duration_ms: int = 0
start_ms: int = 0
start_data_range: int = 0
sequence_number: int = 0
content_length: int | None = None
content_length_estimated: bool = False
    initialized_format: InitializedFormat | None = None
# Whether duration_ms is an estimate
duration_estimated: bool = False
# Whether we should discard the segment data
discard: bool = False
# Whether the segment has already been consumed.
# `discard` should be set to True if this is the case.
consumed: bool = False
received_data_length: int = 0
sequence_lmt: int | None = None
@dataclasses.dataclass
class ConsumedRange:
start_sequence_number: int
end_sequence_number: int
start_time_ms: int
duration_ms: int
@dataclasses.dataclass
class InitializedFormat:
format_id: FormatId
video_id: str
format_selector: FormatSelector | None = None
duration_ms: int = 0
end_time_ms: int = 0
    mime_type: str | None = None
# Current segment in the sequence. Set to None to break the sequence and allow a seek.
current_segment: Segment | None = None
init_segment: Segment | None | bool = None
consumed_ranges: list[ConsumedRange] = dataclasses.field(default_factory=list)
    total_segments: int | None = None
# Whether we should discard any data received for this format
discard: bool = False
sequence_lmt: int | None = None
SabrLogger = IEContentProviderLogger
@dataclasses.dataclass
class FormatSelector:
display_name: str
format_ids: list[FormatId] = dataclasses.field(default_factory=list)
discard_media: bool = False
    def match(self, format_id: FormatId | None = None, **kwargs) -> bool:
return format_id in self.format_ids
@dataclasses.dataclass
class AudioSelector(FormatSelector):
    def match(self, format_id: FormatId | None = None, mime_type: str | None = None, **kwargs) -> bool:
return (
super().match(format_id, mime_type=mime_type, **kwargs)
or (not self.format_ids and mime_type and mime_type.lower().startswith('audio'))
)
@dataclasses.dataclass
class VideoSelector(FormatSelector):
    def match(self, format_id: FormatId | None = None, mime_type: str | None = None, **kwargs) -> bool:
return (
super().match(format_id, mime_type=mime_type, **kwargs)
or (not self.format_ids and mime_type and mime_type.lower().startswith('video'))
)
@dataclasses.dataclass
class CaptionSelector(FormatSelector):
    def match(self, format_id: FormatId | None = None, mime_type: str | None = None, **kwargs) -> bool:
return (
super().match(format_id, mime_type=mime_type, **kwargs)
or (not self.format_ids and mime_type and mime_type.lower().startswith('text'))
)
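
The selectors match against an explicit FormatId list when one is given, and otherwise fall back to the mime type prefix. A small illustration (the itag values are arbitrary, and field-wise FormatId equality is assumed):

from yt_dlp.extractor.youtube._proto.videostreaming import FormatId

audio = AudioSelector(display_name='audio')  # no format_ids: matches any audio/* mime type
video = VideoSelector(display_name='video', format_ids=[FormatId(itag=248)])

assert audio.match(format_id=FormatId(itag=140), mime_type='audio/mp4')
assert video.match(format_id=FormatId(itag=248), mime_type='video/webm')
assert not video.match(format_id=FormatId(itag=136), mime_type='video/mp4')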

View File

@ -0,0 +1,92 @@
import dataclasses
import enum
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId
from .models import FormatSelector
@dataclasses.dataclass
class SabrPart:
pass
@dataclasses.dataclass
class MediaSegmentInitSabrPart(SabrPart):
format_selector: FormatSelector
format_id: FormatId
    sequence_number: int | None = None
    is_init_segment: bool = False
    total_segments: int | None = None
    start_time_ms: int | None = None
    player_time_ms: int | None = None
    duration_ms: int | None = None
    duration_estimated: bool = False
    start_bytes: int | None = None
    content_length: int | None = None
    content_length_estimated: bool = False
@dataclasses.dataclass
class MediaSegmentDataSabrPart(SabrPart):
format_selector: FormatSelector
format_id: FormatId
    sequence_number: int | None = None
    is_init_segment: bool = False
    total_segments: int | None = None
    data: bytes = b''
    content_length: int | None = None
    segment_start_bytes: int | None = None
@dataclasses.dataclass
class MediaSegmentEndSabrPart(SabrPart):
format_selector: FormatSelector
format_id: FormatId
    sequence_number: int | None = None
    is_init_segment: bool = False
    total_segments: int | None = None
@dataclasses.dataclass
class FormatInitializedSabrPart(SabrPart):
format_id: FormatId
format_selector: FormatSelector
@dataclasses.dataclass
class PoTokenStatusSabrPart(SabrPart):
class PoTokenStatus(enum.Enum):
OK = enum.auto() # PO Token is provided and valid
MISSING = enum.auto() # PO Token is not provided, and is required. A PO Token should be provided ASAP
INVALID = enum.auto() # PO Token is provided, but is invalid. A new one should be generated ASAP
PENDING = enum.auto() # PO Token is provided, but probably only a cold start token. A full PO Token should be provided ASAP
NOT_REQUIRED = enum.auto() # PO Token is not provided, and is not required
        PENDING_MISSING = enum.auto()  # PO Token is not provided, but attestation is pending. A full PO Token should (probably) be provided ASAP
status: PoTokenStatus
@dataclasses.dataclass
class RefreshPlayerResponseSabrPart(SabrPart):
class Reason(enum.Enum):
UNKNOWN = enum.auto()
SABR_URL_EXPIRY = enum.auto()
SABR_RELOAD_PLAYER_RESPONSE = enum.auto()
reason: Reason
    reload_playback_token: str | None = None
@dataclasses.dataclass
class MediaSeekSabrPart(SabrPart):
# Lets the consumer know the media sequence for a format may change
class Reason(enum.Enum):
UNKNOWN = enum.auto()
SERVER_SEEK = enum.auto() # SABR_SEEK from server
CONSUMED_SEEK = enum.auto() # Seeking as next fragment is already buffered
reason: Reason
format_id: FormatId
format_selector: FormatSelector
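
These parts form the public surface of the SabrStream iterator defined later in this diff. A hedged consumer sketch, assuming stream is an already-constructed SabrStream and out is a writable binary file object:

for part in stream.iter_parts():
    if isinstance(part, MediaSegmentDataSabrPart):
        out.write(part.data)  # raw media bytes for the current segment
    elif isinstance(part, PoTokenStatusSabrPart) and part.status in (
        PoTokenStatusSabrPart.PoTokenStatus.MISSING,
        PoTokenStatusSabrPart.PoTokenStatus.INVALID,
    ):
        ...  # obtain a fresh GVS PO Token before continuing
    elif isinstance(part, RefreshPlayerResponseSabrPart):
        ...  # re-fetch the player response and restart with the new URL/config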

View File

@ -0,0 +1,671 @@
from __future__ import annotations
import base64
import io
import math
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy
from yt_dlp.extractor.youtube._proto.videostreaming import (
BufferedRange,
ClientAbrState,
FormatInitializationMetadata,
LiveMetadata,
MediaHeader,
SabrContext,
SabrContextSendingPolicy,
SabrContextUpdate,
SabrSeek,
StreamerContext,
StreamProtectionStatus,
TimeRange,
VideoPlaybackAbrRequest,
)
from .exceptions import MediaSegmentMismatchError, SabrStreamError
from .models import (
AudioSelector,
CaptionSelector,
ConsumedRange,
InitializedFormat,
SabrLogger,
Segment,
VideoSelector,
)
from .part import (
FormatInitializedSabrPart,
MediaSeekSabrPart,
MediaSegmentDataSabrPart,
MediaSegmentEndSabrPart,
MediaSegmentInitSabrPart,
PoTokenStatusSabrPart,
)
from .utils import ticks_to_ms
class ProcessMediaEndResult:
def __init__(self, sabr_part: MediaSegmentEndSabrPart = None, is_new_segment: bool = False):
self.is_new_segment = is_new_segment
self.sabr_part = sabr_part
class ProcessMediaResult:
def __init__(self, sabr_part: MediaSegmentDataSabrPart = None):
self.sabr_part = sabr_part
class ProcessMediaHeaderResult:
def __init__(self, sabr_part: MediaSegmentInitSabrPart | None = None):
self.sabr_part = sabr_part
class ProcessLiveMetadataResult:
def __init__(self, seek_sabr_parts: list[MediaSeekSabrPart] | None = None):
self.seek_sabr_parts = seek_sabr_parts or []
class ProcessStreamProtectionStatusResult:
def __init__(self, sabr_part: PoTokenStatusSabrPart | None = None):
self.sabr_part = sabr_part
class ProcessFormatInitializationMetadataResult:
def __init__(self, sabr_part: FormatInitializedSabrPart | None = None):
self.sabr_part = sabr_part
class ProcessSabrSeekResult:
def __init__(self, seek_sabr_parts: list[MediaSeekSabrPart] | None = None):
self.seek_sabr_parts = seek_sabr_parts or []
class SabrProcessor:
"""
SABR Processor
This handles core SABR protocol logic, independent of requests.
"""
def __init__(
self,
logger: SabrLogger,
video_playback_ustreamer_config: str,
client_info: ClientInfo,
audio_selection: AudioSelector | None = None,
video_selection: VideoSelector | None = None,
caption_selection: CaptionSelector | None = None,
live_segment_target_duration_sec: int | None = None,
live_segment_target_duration_tolerance_ms: int | None = None,
start_time_ms: int | None = None,
po_token: str | None = None,
live_end_wait_sec: int | None = None,
live_end_segment_tolerance: int | None = None,
post_live: bool = False,
video_id: str | None = None,
):
self.logger = logger
self.video_playback_ustreamer_config = video_playback_ustreamer_config
self.po_token = po_token
self.client_info = client_info
self.live_segment_target_duration_sec = live_segment_target_duration_sec or 5
self.live_segment_target_duration_tolerance_ms = live_segment_target_duration_tolerance_ms or 100
if self.live_segment_target_duration_tolerance_ms >= (self.live_segment_target_duration_sec * 1000) / 2:
raise ValueError(
'live_segment_target_duration_tolerance_ms must be less than '
'half of live_segment_target_duration_sec in milliseconds',
)
self.start_time_ms = start_time_ms or 0
if self.start_time_ms < 0:
raise ValueError('start_time_ms must be greater than or equal to 0')
self.live_end_wait_sec = live_end_wait_sec or max(10, 3 * self.live_segment_target_duration_sec)
self.live_end_segment_tolerance = live_end_segment_tolerance or 10
self.post_live = post_live
self._is_live = False
self.video_id = video_id
self._audio_format_selector = audio_selection
self._video_format_selector = video_selection
self._caption_format_selector = caption_selection
        # IMPORTANT: initialized_formats is assumed to contain only ACTIVE formats
self.initialized_formats: dict[str, InitializedFormat] = {}
self.stream_protection_status: StreamProtectionStatus.Status | None = None
self.partial_segments: dict[int, Segment] = {}
self.total_duration_ms = None
self.selected_audio_format_ids = []
self.selected_video_format_ids = []
self.selected_caption_format_ids = []
self.next_request_policy: NextRequestPolicy | None = None
self.live_metadata: LiveMetadata | None = None
self.client_abr_state: ClientAbrState
self.sabr_contexts_to_send: set[int] = set()
self.sabr_context_updates: dict[int, SabrContextUpdate] = {}
self._initialize_cabr_state()
@property
def is_live(self):
return bool(
self.live_metadata
or self._is_live,
)
@is_live.setter
def is_live(self, value: bool):
self._is_live = value
def _initialize_cabr_state(self):
enabled_track_types_bitfield = 0 # Audio+Video
if not self._video_format_selector:
enabled_track_types_bitfield = 1 # Audio only
self._video_format_selector = VideoSelector(display_name='video_ignore', discard_media=True)
if self._caption_format_selector:
# SABR does not support caption-only or audio+captions only - can only get audio+video with captions
# If audio or video is not selected, the tracks will be initialized but marked as buffered.
enabled_track_types_bitfield = 7
# SABR does not support video-only, so we need to discard the audio track received.
# We need a selector as the server sometimes does not like it
# if we haven't initialized an audio format (e.g. livestreams).
if not self._audio_format_selector:
self._audio_format_selector = AudioSelector(display_name='audio_ignore', discard_media=True)
if not self._caption_format_selector:
self._caption_format_selector = CaptionSelector(display_name='caption_ignore', discard_media=True)
self.selected_audio_format_ids = self._audio_format_selector.format_ids
self.selected_video_format_ids = self._video_format_selector.format_ids
self.selected_caption_format_ids = self._caption_format_selector.format_ids
self.logger.debug(f'Starting playback at: {self.start_time_ms}ms')
self.client_abr_state = ClientAbrState(
player_time_ms=self.start_time_ms,
enabled_track_types_bitfield=enabled_track_types_bitfield)
def match_format_selector(self, format_init_metadata):
for format_selector in (self._video_format_selector, self._audio_format_selector, self._caption_format_selector):
if not format_selector:
continue
if format_selector.match(format_id=format_init_metadata.format_id, mime_type=format_init_metadata.mime_type):
return format_selector
return None
def process_media_header(self, media_header: MediaHeader) -> ProcessMediaHeaderResult:
if media_header.video_id and self.video_id and media_header.video_id != self.video_id:
raise SabrStreamError(
f'Received unexpected MediaHeader for video'
f' {media_header.video_id} (expecting {self.video_id})')
if not media_header.format_id:
raise SabrStreamError(f'Format ID not found in MediaHeader (media_header={media_header})')
# Guard. This should not happen, except if we don't clear partial segments
if media_header.header_id in self.partial_segments:
raise SabrStreamError(f'Header ID {media_header.header_id} already exists')
result = ProcessMediaHeaderResult()
initialized_format = self.initialized_formats.get(str(media_header.format_id))
if not initialized_format:
self.logger.debug(f'Initialized format not found for {media_header.format_id}')
return result
if media_header.compression:
# Unknown when this is used, but it is not supported currently
raise SabrStreamError(f'Compression not supported in MediaHeader (media_header={media_header})')
sequence_number, is_init_segment = media_header.sequence_number, media_header.is_init_segment
if sequence_number is None and not media_header.is_init_segment:
raise SabrStreamError(f'Sequence number not found in MediaHeader (media_header={media_header})')
initialized_format.sequence_lmt = media_header.sequence_lmt
        # Need to keep track of whether we discard due to the segment being consumed or not
# for processing down the line (MediaEnd)
consumed = False
discard = initialized_format.discard
# Guard: Check if sequence number is within any existing consumed range
# The server should not send us any segments that are already consumed
# However, if retrying a request, we may get the same segment again
if not is_init_segment and any(
cr.start_sequence_number <= sequence_number <= cr.end_sequence_number
for cr in initialized_format.consumed_ranges
):
self.logger.debug(f'{initialized_format.format_id} segment {sequence_number} already consumed, marking segment as consumed')
consumed = True
# Validate that the segment is in order.
# Note: If the format is to be discarded, we do not care about the order
# and can expect uncommanded seeks as the consumer does not know about it.
# Note: previous segment should never be an init segment.
previous_segment = initialized_format.current_segment
if (
previous_segment and not is_init_segment
and not previous_segment.discard and not discard and not consumed
and sequence_number != previous_segment.sequence_number + 1
):
# Bail out as the segment is not in order when it is expected to be
raise MediaSegmentMismatchError(
expected_sequence_number=previous_segment.sequence_number + 1,
received_sequence_number=sequence_number,
format_id=media_header.format_id)
if initialized_format.init_segment and is_init_segment:
self.logger.debug(
f'Init segment {sequence_number} already seen for format {initialized_format.format_id}, marking segment as consumed')
consumed = True
time_range = media_header.time_range
start_ms = media_header.start_ms or (time_range and ticks_to_ms(time_range.start_ticks, time_range.timescale)) or 0
# Calculate duration of this segment
# For videos, either duration_ms or time_range should be present
# For live streams, calculate segment duration based on live metadata target segment duration
actual_duration_ms = (
media_header.duration_ms
or (time_range and ticks_to_ms(time_range.duration_ticks, time_range.timescale)))
estimated_duration_ms = None
if self.is_live:
# Underestimate the duration of the segment slightly as
# the real duration may be slightly shorter than the target duration.
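            # e.g. with the default 5s target and 100ms tolerance, each live segment is assumed to last 4900ms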
estimated_duration_ms = (self.live_segment_target_duration_sec * 1000) - self.live_segment_target_duration_tolerance_ms
elif is_init_segment:
estimated_duration_ms = 0
duration_ms = actual_duration_ms or estimated_duration_ms
estimated_content_length = None
if self.is_live and media_header.content_length is None and media_header.bitrate_bps is not None:
estimated_content_length = math.ceil(media_header.bitrate_bps * (estimated_duration_ms / 1000))
# Guard: Bail out if we cannot determine the duration, which we need to progress.
if duration_ms is None:
raise SabrStreamError(f'Cannot determine duration of segment {sequence_number} (media_header={media_header})')
segment = Segment(
format_id=media_header.format_id,
is_init_segment=is_init_segment,
duration_ms=duration_ms,
start_data_range=media_header.start_data_range,
sequence_number=sequence_number,
content_length=media_header.content_length or estimated_content_length,
content_length_estimated=estimated_content_length is not None,
start_ms=start_ms,
initialized_format=initialized_format,
duration_estimated=not actual_duration_ms,
discard=discard or consumed,
consumed=consumed,
sequence_lmt=media_header.sequence_lmt,
)
self.partial_segments[media_header.header_id] = segment
if not segment.discard:
result.sabr_part = MediaSegmentInitSabrPart(
format_selector=segment.initialized_format.format_selector,
format_id=segment.format_id,
player_time_ms=self.client_abr_state.player_time_ms,
sequence_number=segment.sequence_number,
total_segments=segment.initialized_format.total_segments,
duration_ms=segment.duration_ms,
start_bytes=segment.start_data_range,
start_time_ms=segment.start_ms,
is_init_segment=segment.is_init_segment,
content_length=segment.content_length,
content_length_estimated=segment.content_length_estimated,
)
self.logger.trace(
f'Initialized Media Header {media_header.header_id} for sequence {sequence_number}. Segment: {segment}')
return result
def process_media(self, header_id: int, content_length: int, data: io.BufferedIOBase) -> ProcessMediaResult:
result = ProcessMediaResult()
segment = self.partial_segments.get(header_id)
if not segment:
self.logger.debug(f'Header ID {header_id} not found')
return result
segment_start_bytes = segment.received_data_length
segment.received_data_length += content_length
if not segment.discard:
result.sabr_part = MediaSegmentDataSabrPart(
format_selector=segment.initialized_format.format_selector,
format_id=segment.format_id,
sequence_number=segment.sequence_number,
is_init_segment=segment.is_init_segment,
total_segments=segment.initialized_format.total_segments,
data=data.read(),
content_length=content_length,
segment_start_bytes=segment_start_bytes,
)
return result
def process_media_end(self, header_id: int) -> ProcessMediaEndResult:
result = ProcessMediaEndResult()
segment = self.partial_segments.pop(header_id, None)
if not segment:
# Should only happen due to server issue,
# or we have an uninitialized format (which itself should not happen)
self.logger.warning(f'Received a MediaEnd for an unknown or already finished header ID {header_id}')
return result
self.logger.trace(
f'MediaEnd for {segment.format_id} (sequence {segment.sequence_number}, data length = {segment.received_data_length})')
if segment.content_length is not None and segment.received_data_length != segment.content_length:
if segment.content_length_estimated:
self.logger.trace(
f'Content length for {segment.format_id} (sequence {segment.sequence_number}) was estimated, '
f'estimated {segment.content_length} bytes, got {segment.received_data_length} bytes')
else:
raise SabrStreamError(
f'Content length mismatch for {segment.format_id} (sequence {segment.sequence_number}): '
f'expected {segment.content_length} bytes, got {segment.received_data_length} bytes',
)
# Only count received segments as new segments if they are not discarded (consumed)
# or it was part of a format that was discarded (but not consumed).
# The latter can happen if the format is to be discarded but was not marked as fully consumed.
if not segment.discard or (segment.initialized_format.discard and not segment.consumed):
result.is_new_segment = True
# Return the segment here instead of during MEDIA part(s) because:
# 1. We can validate that we received the correct data length
# 2. In the case of a retry during segment media, the partial data is not sent to the consumer
if not segment.discard:
# This needs to be yielded AFTER we have processed the segment
# So the consumer can see the updated consumed ranges and use them for e.g. syncing between concurrent streams
result.sabr_part = MediaSegmentEndSabrPart(
format_selector=segment.initialized_format.format_selector,
format_id=segment.format_id,
sequence_number=segment.sequence_number,
is_init_segment=segment.is_init_segment,
total_segments=segment.initialized_format.total_segments,
)
else:
self.logger.trace(f'Discarding media for {segment.initialized_format.format_id}')
if segment.is_init_segment:
segment.initialized_format.init_segment = segment
# Do not create a consumed range for init segments
return result
if segment.initialized_format.current_segment and self.is_live:
previous_segment = segment.initialized_format.current_segment
self.logger.trace(
f'Previous segment {previous_segment.sequence_number} for format {segment.format_id} '
f'estimated duration difference from this segment ({segment.sequence_number}): {segment.start_ms - (previous_segment.start_ms + previous_segment.duration_ms)}ms')
segment.initialized_format.current_segment = segment
# Try to find a consumed range for this segment in sequence
consumed_range = next(
(cr for cr in segment.initialized_format.consumed_ranges if cr.end_sequence_number == segment.sequence_number - 1),
None,
)
if not consumed_range and any(
cr.start_sequence_number <= segment.sequence_number <= cr.end_sequence_number
for cr in segment.initialized_format.consumed_ranges
):
# Segment is already consumed, do not create a new consumed range. It was probably discarded.
# This can be expected to happen in the case of video-only, where we discard the audio track (and mark it as entirely buffered)
# We still want to create/update consumed range for discarded media IF it is not already consumed
self.logger.debug(f'{segment.format_id} segment {segment.sequence_number} already consumed, not creating or updating consumed range (discard={segment.discard})')
return result
if not consumed_range:
# Create a new consumed range starting from this segment
segment.initialized_format.consumed_ranges.append(ConsumedRange(
start_time_ms=segment.start_ms,
duration_ms=segment.duration_ms,
start_sequence_number=segment.sequence_number,
end_sequence_number=segment.sequence_number,
))
self.logger.debug(f'Created new consumed range for {segment.initialized_format.format_id} {segment.initialized_format.consumed_ranges[-1]}')
return result
# Update the existing consumed range to include this segment
consumed_range.end_sequence_number = segment.sequence_number
consumed_range.duration_ms = (segment.start_ms - consumed_range.start_time_ms) + segment.duration_ms
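        # e.g. a range starting at 0ms, extended by a segment with start_ms=10000 and duration_ms=5000, now spans 15000ms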
# TODO: Conduct a seek on consumed ranges
return result
def process_live_metadata(self, live_metadata: LiveMetadata) -> ProcessLiveMetadataResult:
self.live_metadata = live_metadata
if self.live_metadata.head_sequence_time_ms:
self.total_duration_ms = self.live_metadata.head_sequence_time_ms
        # If we have a head sequence number, we need to update the total segment count for each initialized format.
        # For livestreams, it is not available in the format initialization metadata.
if self.live_metadata.head_sequence_number:
for izf in self.initialized_formats.values():
izf.total_segments = self.live_metadata.head_sequence_number
result = ProcessLiveMetadataResult()
# If the current player time is less than the min dvr time, simulate a server seek to the min dvr time.
# The server SHOULD send us a SABR_SEEK part in this case, but it does not always happen (e.g. ANDROID_VR)
# The server SHOULD NOT send us segments before the min dvr time, so we should assume that the player time is correct.
min_seekable_time_ms = ticks_to_ms(self.live_metadata.min_seekable_time_ticks, self.live_metadata.min_seekable_timescale)
if min_seekable_time_ms is not None and self.client_abr_state.player_time_ms < min_seekable_time_ms:
self.logger.debug(f'Player time {self.client_abr_state.player_time_ms} is less than min seekable time {min_seekable_time_ms}, simulating server seek')
self.client_abr_state.player_time_ms = min_seekable_time_ms
for izf in self.initialized_formats.values():
izf.current_segment = None # Clear the current segment as we expect segments to no longer be in order.
result.seek_sabr_parts.append(MediaSeekSabrPart(
reason=MediaSeekSabrPart.Reason.SERVER_SEEK,
format_id=izf.format_id,
format_selector=izf.format_selector,
))
return result
def process_stream_protection_status(self, stream_protection_status: StreamProtectionStatus) -> ProcessStreamProtectionStatusResult:
self.stream_protection_status = stream_protection_status.status
status = stream_protection_status.status
po_token = self.po_token
if status == StreamProtectionStatus.Status.OK:
result_status = (
PoTokenStatusSabrPart.PoTokenStatus.OK if po_token
else PoTokenStatusSabrPart.PoTokenStatus.NOT_REQUIRED
)
elif status == StreamProtectionStatus.Status.ATTESTATION_PENDING:
result_status = (
PoTokenStatusSabrPart.PoTokenStatus.PENDING if po_token
else PoTokenStatusSabrPart.PoTokenStatus.PENDING_MISSING
)
elif status == StreamProtectionStatus.Status.ATTESTATION_REQUIRED:
result_status = (
PoTokenStatusSabrPart.PoTokenStatus.INVALID if po_token
else PoTokenStatusSabrPart.PoTokenStatus.MISSING
)
else:
result_status = None
sabr_part = PoTokenStatusSabrPart(status=result_status) if result_status is not None else None
return ProcessStreamProtectionStatusResult(sabr_part)
def process_format_initialization_metadata(self, format_init_metadata: FormatInitializationMetadata) -> ProcessFormatInitializationMetadataResult:
result = ProcessFormatInitializationMetadataResult()
if str(format_init_metadata.format_id) in self.initialized_formats:
self.logger.trace(f'Format {format_init_metadata.format_id} already initialized')
return result
if format_init_metadata.video_id and self.video_id and format_init_metadata.video_id != self.video_id:
raise SabrStreamError(
f'Received unexpected Format Initialization Metadata for video'
f' {format_init_metadata.video_id} (expecting {self.video_id})')
format_selector = self.match_format_selector(format_init_metadata)
if not format_selector:
# Should not happen. If we ignored the format the server may refuse to send us any more data
raise SabrStreamError(f'Received format {format_init_metadata.format_id} but it does not match any format selector')
# Guard: Check if the format selector is already in use by another initialized format.
# This can happen when the server changes the format to use (e.g. changing quality).
#
# Changing a format will require adding some logic to handle inactive formats.
# Given we only provide one FormatId currently, and this should not occur in this case,
# we will mark this as not currently supported and bail.
for izf in self.initialized_formats.values():
if izf.format_selector is format_selector:
raise SabrStreamError('Server changed format. Changing formats is not currently supported')
        duration_ms = ticks_to_ms(format_init_metadata.duration_ticks, format_init_metadata.duration_timescale)
total_segments = format_init_metadata.total_segments
if not total_segments and self.live_metadata and self.live_metadata.head_sequence_number:
total_segments = self.live_metadata.head_sequence_number
initialized_format = InitializedFormat(
format_id=format_init_metadata.format_id,
duration_ms=duration_ms,
end_time_ms=format_init_metadata.end_time_ms,
mime_type=format_init_metadata.mime_type,
video_id=format_init_metadata.video_id,
format_selector=format_selector,
total_segments=total_segments,
discard=format_selector.discard_media,
)
self.total_duration_ms = max(self.total_duration_ms or 0, format_init_metadata.end_time_ms or 0, duration_ms or 0)
if initialized_format.discard:
# Mark the entire format as buffered into oblivion if we plan to discard all media.
# This stops the server sending us any more data for this format.
# Note: Using JS_MAX_SAFE_INTEGER but could use any maximum value as long as the server accepts it.
initialized_format.consumed_ranges = [ConsumedRange(
start_time_ms=0,
duration_ms=(2**53) - 1,
start_sequence_number=0,
end_sequence_number=(2**53) - 1,
)]
self.initialized_formats[str(format_init_metadata.format_id)] = initialized_format
self.logger.debug(f'Initialized Format: {initialized_format}')
if not initialized_format.discard:
result.sabr_part = FormatInitializedSabrPart(
format_id=format_init_metadata.format_id,
format_selector=format_selector,
)
        return result
def process_next_request_policy(self, next_request_policy: NextRequestPolicy):
self.next_request_policy = next_request_policy
self.logger.trace(f'Registered new NextRequestPolicy: {self.next_request_policy}')
def process_sabr_seek(self, sabr_seek: SabrSeek) -> ProcessSabrSeekResult:
seek_to = ticks_to_ms(sabr_seek.seek_time_ticks, sabr_seek.timescale)
if seek_to is None:
raise SabrStreamError(f'Server sent a SabrSeek part that is missing required seek data: {sabr_seek}')
self.logger.debug(f'Seeking to {seek_to}ms')
self.client_abr_state.player_time_ms = seek_to
result = ProcessSabrSeekResult()
# Clear latest segment of each initialized format
# as we expect them to no longer be in order.
for initialized_format in self.initialized_formats.values():
initialized_format.current_segment = None
result.seek_sabr_parts.append(MediaSeekSabrPart(
reason=MediaSeekSabrPart.Reason.SERVER_SEEK,
format_id=initialized_format.format_id,
format_selector=initialized_format.format_selector,
))
return result
def process_sabr_context_update(self, sabr_ctx_update: SabrContextUpdate):
if not (sabr_ctx_update.type and sabr_ctx_update.value and sabr_ctx_update.write_policy):
self.logger.warning('Received an invalid SabrContextUpdate, ignoring')
return
if (
sabr_ctx_update.write_policy == SabrContextUpdate.SabrContextWritePolicy.SABR_CONTEXT_WRITE_POLICY_KEEP_EXISTING
and sabr_ctx_update.type in self.sabr_context_updates
):
self.logger.debug(
            'Received a SABR Context Update with write_policy=KEEP_EXISTING '
'matching an existing SABR Context Update. Ignoring update')
return
self.logger.warning(
'Received a SABR Context Update. YouTube is likely trying to force ads on the client. '
'This may cause issues with playback.')
self.sabr_context_updates[sabr_ctx_update.type] = sabr_ctx_update
if sabr_ctx_update.send_by_default is True:
self.sabr_contexts_to_send.add(sabr_ctx_update.type)
self.logger.debug(f'Registered SabrContextUpdate {sabr_ctx_update}')
def process_sabr_context_sending_policy(self, sabr_ctx_sending_policy: SabrContextSendingPolicy):
for start_type in sabr_ctx_sending_policy.start_policy:
if start_type not in self.sabr_context_updates:
self.logger.debug(f'Server requested to enable SABR Context Update for type {start_type}')
self.sabr_contexts_to_send.add(start_type)
for stop_type in sabr_ctx_sending_policy.stop_policy:
if stop_type in self.sabr_contexts_to_send:
self.logger.debug(f'Server requested to disable SABR Context Update for type {stop_type}')
self.sabr_contexts_to_send.remove(stop_type)
for discard_type in sabr_ctx_sending_policy.discard_policy:
if discard_type in self.sabr_context_updates:
self.logger.debug(f'Server requested to discard SABR Context Update for type {discard_type}')
self.sabr_context_updates.pop(discard_type, None)
def build_vpabr_request(processor: SabrProcessor):
return VideoPlaybackAbrRequest(
client_abr_state=processor.client_abr_state,
selected_video_format_ids=processor.selected_video_format_ids,
selected_audio_format_ids=processor.selected_audio_format_ids,
selected_caption_format_ids=processor.selected_caption_format_ids,
initialized_format_ids=[
initialized_format.format_id for initialized_format in processor.initialized_formats.values()
],
video_playback_ustreamer_config=base64.urlsafe_b64decode(processor.video_playback_ustreamer_config),
streamer_context=StreamerContext(
po_token=base64.urlsafe_b64decode(processor.po_token) if processor.po_token is not None else None,
playback_cookie=processor.next_request_policy.playback_cookie if processor.next_request_policy is not None else None,
client_info=processor.client_info,
sabr_contexts=[
SabrContext(context.type, context.value)
for context in processor.sabr_context_updates.values()
if context.type in processor.sabr_contexts_to_send
],
unsent_sabr_contexts=[
context_type for context_type in processor.sabr_contexts_to_send
if context_type not in processor.sabr_context_updates
],
),
buffered_ranges=[
BufferedRange(
format_id=initialized_format.format_id,
start_segment_index=cr.start_sequence_number,
end_segment_index=cr.end_sequence_number,
start_time_ms=cr.start_time_ms,
duration_ms=cr.duration_ms,
time_range=TimeRange(
start_ticks=cr.start_time_ms,
duration_ticks=cr.duration_ms,
timescale=1000,
),
) for initialized_format in processor.initialized_formats.values()
for cr in initialized_format.consumed_ranges
],
)
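
Since build_vpabr_request reads only processor state, producing a request body is two calls once a processor exists. A construction sketch: my_logger, ustreamer_cfg and client_info are hypothetical stand-ins for a SabrLogger implementation and values lifted from a player response:

from yt_dlp.dependencies import protobug

processor = SabrProcessor(
    logger=my_logger,  # hypothetical object implementing the SabrLogger interface
    video_playback_ustreamer_config=ustreamer_cfg,  # base64url string from the player response
    client_info=client_info,  # innertube ClientInfo message
    audio_selection=AudioSelector(display_name='audio'),
)
payload = protobug.dumps(build_vpabr_request(processor))  # body for the SABR POST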

View File

@ -0,0 +1,813 @@
from __future__ import annotations
import base64
import dataclasses
import datetime as dt
import math
import time
import typing
import urllib.parse
from yt_dlp.dependencies import protobug
from yt_dlp.extractor.youtube._proto import unknown_fields
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, NextRequestPolicy
from yt_dlp.extractor.youtube._proto.videostreaming import (
FormatInitializationMetadata,
LiveMetadata,
MediaHeader,
ReloadPlayerResponse,
SabrContextSendingPolicy,
SabrContextUpdate,
SabrError,
SabrRedirect,
SabrSeek,
StreamProtectionStatus,
)
from yt_dlp.networking import Request, Response
from yt_dlp.networking.exceptions import HTTPError, TransportError
from yt_dlp.utils import RetryManager, int_or_none, parse_qs, str_or_none, traverse_obj
from .exceptions import MediaSegmentMismatchError, PoTokenError, SabrStreamConsumedError, SabrStreamError
from .models import AudioSelector, CaptionSelector, SabrLogger, VideoSelector
from .part import (
MediaSeekSabrPart,
RefreshPlayerResponseSabrPart,
)
from .processor import SabrProcessor, build_vpabr_request
from .utils import broadcast_id_from_url, get_cr_chain, next_gvs_fallback_url
from ..ump import UMPDecoder, UMPPart, UMPPartId, read_varint
class SabrStream:
"""
A YouTube SABR (Server Adaptive Bit Rate) client implementation designed for downloading streams and videos.
It presents an iterator (iter_parts) that yields the next available segments and other metadata.
Parameters:
@param urlopen: A callable that takes a Request and returns a Response. Raises TransportError or HTTPError on failure.
@param logger: The logger.
@param server_abr_streaming_url: SABR streaming URL.
@param video_playback_ustreamer_config: The base64url encoded ustreamer config.
@param client_info: The Innertube client info.
@param audio_selection: The audio format selector to use for audio.
@param video_selection: The video format selector to use for video.
@param caption_selection: The caption format selector to use for captions.
@param live_segment_target_duration_sec: The target duration of live segments in seconds.
    @param live_segment_target_duration_tolerance_ms: The tolerance to accept for the estimated duration of live segments, in milliseconds.
@param start_time_ms: The time in milliseconds to start playback from.
@param po_token: Initial GVS PO Token.
@param http_retries: The maximum number of times to retry a request before failing.
@param pot_retries: The maximum number of times to retry PO Token errors before failing.
@param host_fallback_threshold: The number of consecutive retries before falling back to the next GVS server.
@param max_empty_requests: The maximum number of consecutive requests with no new segments before giving up.
@param live_end_wait_sec: The number of seconds to wait after the last received segment before considering the live stream ended.
@param live_end_segment_tolerance: The number of segments before the live head segment at which the livestream is allowed to end. Defaults to 10.
@param post_live: Whether the live stream is in post-live mode. Used to determine how to handle the end of the stream.
@param video_id: The video ID of the YouTube video. Used for validating received data is for the correct video.
@param retry_sleep_func: A function to sleep between retries. If None, will not sleep between retries.
@param expiry_threshold_sec: The number of seconds before the GVS expiry to consider it expired. Defaults to 1 minute.
"""
# Used for debugging
_IGNORED_PARTS = (
UMPPartId.REQUEST_IDENTIFIER,
UMPPartId.REQUEST_CANCELLATION_POLICY,
UMPPartId.PLAYBACK_START_POLICY,
UMPPartId.ALLOWED_CACHED_FORMATS,
UMPPartId.PAUSE_BW_SAMPLING_HINT,
UMPPartId.START_BW_SAMPLING_HINT,
UMPPartId.REQUEST_PIPELINING,
UMPPartId.SELECTABLE_FORMATS,
UMPPartId.PREWARM_CONNECTION,
)
@dataclasses.dataclass
class _NoSegmentsTracker:
consecutive_requests: int = 0
timestamp_started: float | None = None
live_head_segment_started: int | None = None
def reset(self):
self.consecutive_requests = 0
self.timestamp_started = None
self.live_head_segment_started = None
def increment(self, live_head_segment=None):
if self.consecutive_requests == 0:
self.timestamp_started = time.time()
self.live_head_segment_started = live_head_segment
self.consecutive_requests += 1
def __init__(
self,
urlopen: typing.Callable[[Request], Response],
logger: SabrLogger,
server_abr_streaming_url: str,
video_playback_ustreamer_config: str,
client_info: ClientInfo,
audio_selection: AudioSelector | None = None,
video_selection: VideoSelector | None = None,
caption_selection: CaptionSelector | None = None,
live_segment_target_duration_sec: int | None = None,
live_segment_target_duration_tolerance_ms: int | None = None,
start_time_ms: int | None = None,
po_token: str | None = None,
http_retries: int | None = None,
pot_retries: int | None = None,
host_fallback_threshold: int | None = None,
max_empty_requests: int | None = None,
live_end_wait_sec: int | None = None,
live_end_segment_tolerance: int | None = None,
post_live: bool = False,
video_id: str | None = None,
        retry_sleep_func: typing.Callable | None = None,
expiry_threshold_sec: int | None = None,
):
self.logger = logger
self._urlopen = urlopen
self.processor = SabrProcessor(
logger=logger,
video_playback_ustreamer_config=video_playback_ustreamer_config,
client_info=client_info,
audio_selection=audio_selection,
video_selection=video_selection,
caption_selection=caption_selection,
live_segment_target_duration_sec=live_segment_target_duration_sec,
live_segment_target_duration_tolerance_ms=live_segment_target_duration_tolerance_ms,
start_time_ms=start_time_ms,
po_token=po_token,
live_end_wait_sec=live_end_wait_sec,
live_end_segment_tolerance=live_end_segment_tolerance,
post_live=post_live,
video_id=video_id,
)
self.url = server_abr_streaming_url
self.http_retries = http_retries or 10
self.pot_retries = pot_retries or 5
self.host_fallback_threshold = host_fallback_threshold or 8
self.max_empty_requests = max_empty_requests or 3
self.expiry_threshold_sec = expiry_threshold_sec or 60 # 60 seconds
if self.expiry_threshold_sec <= 0:
raise ValueError('expiry_threshold_sec must be greater than 0')
if self.max_empty_requests <= 0:
raise ValueError('max_empty_requests must be greater than 0')
self.retry_sleep_func = retry_sleep_func
self._request_number = 0
# Whether we got any new (not consumed) segments in the request.
self._received_new_segments = False
self._no_new_segments_tracker = SabrStream._NoSegmentsTracker()
self._sps_retry_manager: typing.Generator | None = None
self._current_sps_retry = None
self._http_retry_manager: typing.Generator | None = None
self._current_http_retry = None
self._unknown_part_types = set()
# Whether the current request is a result of a retry
self._is_retry = False
self._consumed = False
self._sq_mismatch_backtrack_count = 0
self._sq_mismatch_forward_count = 0
def close(self):
self._consumed = True
def __iter__(self):
return self.iter_parts()
@property
def url(self):
return self._url
@url.setter
def url(self, url):
self.logger.debug(f'New URL: {url}')
if hasattr(self, '_url') and ((bn := broadcast_id_from_url(url)) != (bc := broadcast_id_from_url(self.url))):
raise SabrStreamError(f'Broadcast ID changed from {bc} to {bn}. The download will need to be restarted.')
self._url = url
if str_or_none(parse_qs(url).get('source', [None])[0]) == 'yt_live_broadcast':
self.processor.is_live = True
def iter_parts(self):
if self._consumed:
raise SabrStreamConsumedError('SABR stream has already been consumed')
self._http_retry_manager = None
self._sps_retry_manager = None
def report_retry(err, count, retries, fatal=True):
if count >= self.host_fallback_threshold:
self._process_fallback_server()
RetryManager.report_retry(
err, count, retries, info=self.logger.info,
warn=lambda msg: self.logger.warning(f'[sabr] Got error: {msg}'),
error=None if fatal else lambda msg: self.logger.warning(f'[sabr] Got error: {msg}'),
sleep_func=self.retry_sleep_func,
)
def report_sps_retry(err, count, retries, fatal=True):
RetryManager.report_retry(
err, count, retries, info=self.logger.info,
warn=lambda msg: self.logger.warning(f'[sabr] Got error: {msg}'),
error=None if fatal else lambda msg: self.logger.warning(f'[sabr] Got error: {msg}'),
sleep_func=self.retry_sleep_func,
)
while not self._consumed:
if self._http_retry_manager is None:
self._http_retry_manager = iter(RetryManager(self.http_retries, report_retry))
if self._sps_retry_manager is None:
self._sps_retry_manager = iter(RetryManager(self.pot_retries, report_sps_retry))
self._current_http_retry = next(self._http_retry_manager)
self._current_sps_retry = next(self._sps_retry_manager)
self._log_state()
yield from self._process_expiry()
vpabr = build_vpabr_request(self.processor)
payload = protobug.dumps(vpabr)
self.logger.trace(f'Ustreamer Config: {self.processor.video_playback_ustreamer_config}')
self.logger.trace(f'Sending SABR request: {vpabr}')
response = None
try:
response = self._urlopen(
Request(
url=self.url,
method='POST',
data=payload,
query={'rn': self._request_number},
headers={
'content-type': 'application/x-protobuf',
'accept-encoding': 'identity',
'accept': 'application/vnd.yt-ump',
},
),
)
self._request_number += 1
except TransportError as e:
self._current_http_retry.error = e
except HTTPError as e:
# retry on 5xx errors only
if 500 <= e.status < 600:
self._current_http_retry.error = e
else:
raise SabrStreamError(f'HTTP Error: {e.status} - {e.reason}')
if response:
try:
yield from self._parse_ump_response(response)
except TransportError as e:
self._current_http_retry.error = e
if not response.closed:
response.close()
self._validate_response_integrity()
self._process_sps_retry()
if not self._current_http_retry.error:
self._http_retry_manager = None
if not self._current_sps_retry.error:
self._sps_retry_manager = None
retry_next_request = bool(self._current_http_retry.error or self._current_sps_retry.error)
# We are expecting to stay in the same place for a retry
if not retry_next_request:
# Only increment request no segments number if we are not retrying
self._process_request_had_segments()
# Calculate and apply the next playback time to skip to
yield from self._prepare_next_playback_time()
# Request successfully processed, next request is not a retry
self._is_retry = False
else:
self._is_retry = True
self._received_new_segments = False
self._consumed = True
def _process_sps_retry(self):
error = PoTokenError(missing=not self.processor.po_token)
if self.processor.stream_protection_status == StreamProtectionStatus.Status.ATTESTATION_REQUIRED:
# Always start retrying immediately on ATTESTATION_REQUIRED
self._current_sps_retry.error = error
return
elif (
self.processor.stream_protection_status == StreamProtectionStatus.Status.ATTESTATION_PENDING
and self._no_new_segments_tracker.consecutive_requests >= self.max_empty_requests
and (not self.processor.is_live or self.processor.stream_protection_status or (
self.processor.live_metadata is not None
and self._no_new_segments_tracker.live_head_segment_started != self.processor.live_metadata.head_sequence_number)
)
):
# Sometimes YouTube sends no data on ATTESTATION_PENDING, so in this case we need to count retries to fail on a PO Token error.
# We only start counting retries after max_empty_requests in case of intermittent no data that we need to increase the player time on.
            # For livestreams, when we receive no new segments, this could be due to the stream ending or ATTESTATION_PENDING.
# To differentiate the two, we check if the live head segment has changed during the time we start getting no new segments.
# xxx: not perfect detection, sometimes get a new segment we can never fetch (partial)
self._current_sps_retry.error = error
return
def _process_request_had_segments(self):
if not self._received_new_segments:
self._no_new_segments_tracker.increment(
live_head_segment=self.processor.live_metadata.head_sequence_number if self.processor.live_metadata else None)
self.logger.trace(f'No new segments received in request {self._request_number}, count: {self._no_new_segments_tracker.consecutive_requests}')
else:
self._no_new_segments_tracker.reset()
def _validate_response_integrity(self):
        if not self.processor.partial_segments:
return
msg = 'Received partial segments: ' + ', '.join(
f'{seg.format_id}: {seg.sequence_number}'
for seg in self.processor.partial_segments.values()
)
if self.processor.is_live:
# In post live, sometimes we get a partial segment at the end of the stream that should be ignored.
# If this occurs mid-stream, other guards should prevent corruption.
if (
self.processor.live_metadata
# TODO: generalize
and self.processor.client_abr_state.player_time_ms >= (
self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.processor.live_end_segment_tolerance))
):
# Only log a warning if we are not near the head of a stream
self.logger.debug(msg)
else:
self.logger.warning(msg)
else:
# Should not happen for videos
self._current_http_retry.error = SabrStreamError(msg)
self.processor.partial_segments.clear()
def _prepare_next_playback_time(self):
# TODO: refactor and cleanup this massive function
wait_seconds = 0
for izf in self.processor.initialized_formats.values():
if not izf.current_segment:
continue
# Guard: Check that the segment is not in multiple consumed ranges
# This should not happen, but if it does, we should bail
count = sum(
1 for cr in izf.consumed_ranges
if cr.start_sequence_number <= izf.current_segment.sequence_number <= cr.end_sequence_number
)
if count > 1:
raise SabrStreamError(f'Segment {izf.current_segment.sequence_number} for format {izf.format_id} in {count} consumed ranges')
            # Check if there are two or more consumed ranges where the end of one lines up with the start of another.
            # This could happen in the case of concurrent playback.
            # In this case, a seek to the end of the other range is considered acceptable.
            # Note: It is assumed a segment is only present in one consumed range - it should not be allowed in multiple (by process media header)
prev_consumed_range = next(
(cr for cr in izf.consumed_ranges if cr.end_sequence_number == izf.current_segment.sequence_number),
None,
)
# TODO: move to processor MEDIA_END
if prev_consumed_range and len(get_cr_chain(prev_consumed_range, izf.consumed_ranges)) >= 2:
self.logger.debug(f'Found two or more consumed ranges that line up, allowing a seek for format {izf.format_id}')
izf.current_segment = None
yield MediaSeekSabrPart(
reason=MediaSeekSabrPart.Reason.CONSUMED_SEEK,
format_id=izf.format_id,
format_selector=izf.format_selector)
enabled_initialized_formats = [izf for izf in self.processor.initialized_formats.values() if not izf.discard]
# For each initialized format:
# 1. find the consumed format that matches player_time_ms.
# 2. find the current consumed range in sequence (in case multiple are joined together)
# For livestreams, we allow a tolerance for the segment duration as it is estimated. This tolerance should be less than the segment duration / 2.
cr_tolerance_ms = 0
if self.processor.is_live:
cr_tolerance_ms = self.processor.live_segment_target_duration_tolerance_ms
current_consumed_ranges = []
for izf in enabled_initialized_formats:
for cr in sorted(izf.consumed_ranges, key=lambda cr: cr.start_sequence_number):
if (cr.start_time_ms - cr_tolerance_ms) <= self.processor.client_abr_state.player_time_ms <= cr.start_time_ms + cr.duration_ms + (cr_tolerance_ms * 2):
chain = get_cr_chain(cr, izf.consumed_ranges)
current_consumed_ranges.append(chain[-1])
# There should only be one chain for the current player_time_ms (including the tolerance)
break
min_consumed_duration_ms = None
# Get the lowest consumed range end time out of all current consumed ranges.
if current_consumed_ranges:
lowest_izf_consumed_range = min(current_consumed_ranges, key=lambda cr: cr.start_time_ms + cr.duration_ms)
min_consumed_duration_ms = lowest_izf_consumed_range.start_time_ms + lowest_izf_consumed_range.duration_ms
if len(current_consumed_ranges) != len(enabled_initialized_formats) or min_consumed_duration_ms is None:
# Missing a consumed range for a format.
# In this case, consider player_time_ms to be our correct next time
            # May happen in the case of:
            # 1. A format has not been initialized yet (can happen if a response read fails)
            # 2. A SABR_SEEK to a time outside both formats' consumed ranges,
            #    where only ONE of the formats returns data after the SABR_SEEK in that request
if min_consumed_duration_ms is None:
min_consumed_duration_ms = self.processor.client_abr_state.player_time_ms
else:
min_consumed_duration_ms = min(min_consumed_duration_ms, self.processor.client_abr_state.player_time_ms)
        # Usually provided by the server if no segments were returned.
        # We'll use this to calculate how long to wait before making the next request (for live streams).
next_request_backoff_ms = (self.processor.next_request_policy and self.processor.next_request_policy.backoff_time_ms) or 0
request_player_time = self.processor.client_abr_state.player_time_ms
self.logger.trace(f'min consumed duration ms: {min_consumed_duration_ms}')
self.processor.client_abr_state.player_time_ms = min_consumed_duration_ms
# Check if the latest segment is the last one of each format (if data is available)
if (
not (self.processor.is_live and not self.processor.post_live)
and enabled_initialized_formats
and len(current_consumed_ranges) == len(enabled_initialized_formats)
and all(
(
initialized_format.total_segments is not None
# consumed ranges are not guaranteed to be in order
and sorted(
initialized_format.consumed_ranges,
key=lambda cr: cr.end_sequence_number,
)[-1].end_sequence_number >= initialized_format.total_segments
)
for initialized_format in enabled_initialized_formats
)
):
self.logger.debug('Reached last expected segment for all enabled formats, assuming end of media')
self._consumed = True
# Check if we have exceeded the total duration of the media (if not live),
# or wait for the next segment (if live)
elif self.processor.total_duration_ms and (self.processor.client_abr_state.player_time_ms >= self.processor.total_duration_ms):
if self.processor.is_live:
self.logger.trace(f'setting player time ms ({self.processor.client_abr_state.player_time_ms}) to total duration ms ({self.processor.total_duration_ms})')
self.processor.client_abr_state.player_time_ms = self.processor.total_duration_ms
if (
self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
and not self._is_retry
and self._no_new_segments_tracker.timestamp_started < time.time() - self.processor.live_end_wait_sec
):
self.logger.debug(f'No new segments received for at least {self.processor.live_end_wait_sec} seconds, assuming end of live stream')
self._consumed = True
else:
wait_seconds = max(next_request_backoff_ms / 1000, self.processor.live_segment_target_duration_sec)
else:
self.logger.debug(f'End of media (player time ms {self.processor.client_abr_state.player_time_ms} >= total duration ms {self.processor.total_duration_ms})')
self._consumed = True
# Handle receiving no new segments before the end of the video/stream.
# For videos, exceeding max_empty_requests should not happen, so we raise an error.
# For livestreams, if we exceed max_empty_requests, don't have live_metadata,
# and have not received any data for a while, we can assume the stream has ended (as we don't know the head sequence number)
elif (
# Determine if we are receiving no segments as the live stream has ended.
# Allow some tolerance, as the head segment may not always be receivable.
self.processor.is_live and not self.processor.post_live
and (
getattr(self.processor.live_metadata, 'head_sequence_number', None) is None
or (
enabled_initialized_formats
and len(current_consumed_ranges) == len(enabled_initialized_formats)
and all(
(
initialized_format.total_segments is not None
and sorted(
initialized_format.consumed_ranges,
key=lambda cr: cr.end_sequence_number,
)[-1].end_sequence_number
>= initialized_format.total_segments - self.processor.live_end_segment_tolerance
)
for initialized_format in enabled_initialized_formats
)
)
or self.processor.live_metadata.head_sequence_time_ms is None
or (
# Sometimes we receive a partial segment at the end of the stream
# and the server seeks us to the end of the current segment.
# As our consumed range for this segment has an estimated end time,
# this may be slightly off what the server seeks to.
# This can cause the player time to be slightly outside the consumed range.
#
# Because of this, we should also check the player time against
# the head segment time using the estimated segment duration.
# xxx: consider also taking into account the max seekable timestamp
request_player_time >= self.processor.live_metadata.head_sequence_time_ms - (self.processor.live_segment_target_duration_sec * 1000 * self.processor.live_end_segment_tolerance)
)
)
):
if (
not self._is_retry # allow us to sleep on a retry
and self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
and self._no_new_segments_tracker.timestamp_started < time.time() - self.processor.live_end_wait_sec
):
self.logger.debug(f'No new segments received for at least {self.processor.live_end_wait_sec} seconds; assuming end of live stream')
self._consumed = True
elif self._no_new_segments_tracker.consecutive_requests >= 1:
# Sometimes we can't get the head segment and instead tend to sit behind it for the duration of the livestream
wait_seconds = max(next_request_backoff_ms / 1000, self.processor.live_segment_target_duration_sec)
elif (
self._no_new_segments_tracker.consecutive_requests > self.max_empty_requests
and not self._is_retry
):
raise SabrStreamError(f'No new segments received in {self._no_new_segments_tracker.consecutive_requests} consecutive requests')
elif (
not self.processor.is_live and next_request_backoff_ms
and self._no_new_segments_tracker.consecutive_requests >= 1
and any(t in self.processor.sabr_contexts_to_send for t in self.processor.sabr_context_updates)
):
wait_seconds = math.ceil(next_request_backoff_ms / 1000)
self.logger.info(f'The server is requiring yt-dlp to wait {wait_seconds} seconds before continuing due to ad enforcement')
if wait_seconds:
self.logger.debug(f'Waiting {wait_seconds} seconds for next segment(s)')
time.sleep(wait_seconds)
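# (Illustrative example: with a server backoff of 2500 ms and a 5 s live segment
# target duration, wait_seconds = max(2.5, 5) = 5.)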
def _parse_ump_response(self, response):
self._unknown_part_types.clear()
ump = UMPDecoder(response)
for part in ump.iter_parts():
if part.part_id == UMPPartId.MEDIA_HEADER:
yield from self._process_media_header(part)
elif part.part_id == UMPPartId.MEDIA:
yield from self._process_media(part)
elif part.part_id == UMPPartId.MEDIA_END:
yield from self._process_media_end(part)
elif part.part_id == UMPPartId.STREAM_PROTECTION_STATUS:
yield from self._process_stream_protection_status(part)
elif part.part_id == UMPPartId.SABR_REDIRECT:
self._process_sabr_redirect(part)
elif part.part_id == UMPPartId.FORMAT_INITIALIZATION_METADATA:
yield from self._process_format_initialization_metadata(part)
elif part.part_id == UMPPartId.NEXT_REQUEST_POLICY:
self._process_next_request_policy(part)
elif part.part_id == UMPPartId.LIVE_METADATA:
yield from self._process_live_metadata(part)
elif part.part_id == UMPPartId.SABR_SEEK:
yield from self._process_sabr_seek(part)
elif part.part_id == UMPPartId.SABR_ERROR:
self._process_sabr_error(part)
elif part.part_id == UMPPartId.SABR_CONTEXT_UPDATE:
self._process_sabr_context_update(part)
elif part.part_id == UMPPartId.SABR_CONTEXT_SENDING_POLICY:
self._process_sabr_context_sending_policy(part)
elif part.part_id == UMPPartId.RELOAD_PLAYER_RESPONSE:
yield from self._process_reload_player_response(part)
else:
if part.part_id not in self._IGNORED_PARTS:
self._unknown_part_types.add(part.part_id)
self._log_part(part, msg='Unhandled part type', data=part.data.read())
# Cancel request processing if we are going to retry
if self._current_sps_retry.error or self._current_http_retry.error:
self.logger.debug('Request processing cancelled')
return
def _process_media_header(self, part: UMPPart):
media_header = protobug.load(part.data, MediaHeader)
self._log_part(part=part, protobug_obj=media_header)
try:
result = self.processor.process_media_header(media_header)
if result.sabr_part:
yield result.sabr_part
except MediaSegmentMismatchError as e:
# For livestreams, the server may not know the exact segment for a given player time.
# For segments near stream head, it estimates using segment duration, which can cause off-by-one segment mismatches.
# If a segment is much longer or shorter than expected, the server may return a segment ahead or behind.
# In such cases, retry with an adjusted player time to resync.
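# (Illustrative example: with ~5000 ms target segments, if we expected segment 100 but
# the server returned 99, the earlier segments likely ran long, so nudging
# player_time_ms forward by the tolerance should re-align the estimate on retry.)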
if self.processor.is_live and e.received_sequence_number == e.expected_sequence_number - 1:
# The segment before the previous segment was possibly longer than expected.
# Move the player time forward to try to adjust for this.
self.processor.client_abr_state.player_time_ms += self.processor.live_segment_target_duration_tolerance_ms
self._sq_mismatch_forward_count += 1
self._current_http_retry.error = e
return
elif self.processor.is_live and e.received_sequence_number == e.expected_sequence_number + 2:
# The previous segment was possibly shorter than expected
# Move the player time backwards to try to adjust for this.
self.processor.client_abr_state.player_time_ms = max(0, self.processor.client_abr_state.player_time_ms - self.processor.live_segment_target_duration_tolerance_ms)
self._sq_mismatch_backtrack_count += 1
self._current_http_retry.error = e
return
raise e
def _process_media(self, part: UMPPart):
header_id = read_varint(part.data)
content_length = part.size - part.data.tell()
result = self.processor.process_media(header_id, content_length, part.data)
if result.sabr_part:
yield result.sabr_part
def _process_media_end(self, part: UMPPart):
header_id = read_varint(part.data)
self._log_part(part, msg=f'Header ID: {header_id}')
result = self.processor.process_media_end(header_id)
if result.is_new_segment:
self._received_new_segments = True
if result.sabr_part:
yield result.sabr_part
def _process_live_metadata(self, part: UMPPart):
live_metadata = protobug.load(part.data, LiveMetadata)
self._log_part(part, protobug_obj=live_metadata)
yield from self.processor.process_live_metadata(live_metadata).seek_sabr_parts
def _process_stream_protection_status(self, part: UMPPart):
sps = protobug.load(part.data, StreamProtectionStatus)
self._log_part(part, msg=f'Status: {StreamProtectionStatus.Status(sps.status).name}', protobug_obj=sps)
result = self.processor.process_stream_protection_status(sps)
if result.sabr_part:
yield result.sabr_part
def _process_sabr_redirect(self, part: UMPPart):
sabr_redirect = protobug.load(part.data, SabrRedirect)
self._log_part(part, protobug_obj=sabr_redirect)
if not sabr_redirect.redirect_url:
self.logger.warning('Server requested to redirect to an invalid URL')
return
self.url = sabr_redirect.redirect_url
def _process_format_initialization_metadata(self, part: UMPPart):
fmt_init_metadata = protobug.load(part.data, FormatInitializationMetadata)
self._log_part(part, protobug_obj=fmt_init_metadata)
result = self.processor.process_format_initialization_metadata(fmt_init_metadata)
if result.sabr_part:
yield result.sabr_part
def _process_next_request_policy(self, part: UMPPart):
next_request_policy = protobug.load(part.data, NextRequestPolicy)
self._log_part(part, protobug_obj=next_request_policy)
self.processor.process_next_request_policy(next_request_policy)
def _process_sabr_seek(self, part: UMPPart):
sabr_seek = protobug.load(part.data, SabrSeek)
self._log_part(part, protobug_obj=sabr_seek)
yield from self.processor.process_sabr_seek(sabr_seek).seek_sabr_parts
def _process_sabr_error(self, part: UMPPart):
sabr_error = protobug.load(part.data, SabrError)
self._log_part(part, protobug_obj=sabr_error)
self._current_http_retry.error = SabrStreamError(f'SABR Protocol Error: {sabr_error}')
def _process_sabr_context_update(self, part: UMPPart):
sabr_ctx_update = protobug.load(part.data, SabrContextUpdate)
self._log_part(part, protobug_obj=sabr_ctx_update)
self.processor.process_sabr_context_update(sabr_ctx_update)
def _process_sabr_context_sending_policy(self, part: UMPPart):
sabr_ctx_sending_policy = protobug.load(part.data, SabrContextSendingPolicy)
self._log_part(part, protobug_obj=sabr_ctx_sending_policy)
self.processor.process_sabr_context_sending_policy(sabr_ctx_sending_policy)
def _process_reload_player_response(self, part: UMPPart):
reload_player_response = protobug.load(part.data, ReloadPlayerResponse)
self._log_part(part, protobug_obj=reload_player_response)
yield RefreshPlayerResponseSabrPart(
reason=RefreshPlayerResponseSabrPart.Reason.SABR_RELOAD_PLAYER_RESPONSE,
reload_playback_token=reload_player_response.reload_playback_params.token,
)
def _process_fallback_server(self):
# Attempt to fall back to another GVS host in case the current one fails
new_url = next_gvs_fallback_url(self.url)
if not new_url:
self.logger.debug('No more fallback hosts available')
return
self.logger.warning(f'Falling back to host {urllib.parse.urlparse(new_url).netloc}')
self.url = new_url
def _gvs_expiry(self):
return int_or_none(traverse_obj(parse_qs(self.url), ('expire', 0), get_all=False))
def _process_expiry(self):
expires_at = self._gvs_expiry()
if not expires_at:
self.logger.warning(
'No expiry timestamp found in SABR URL. Will not be able to refresh.', once=True)
return
if expires_at - self.expiry_threshold_sec >= time.time():
self.logger.trace(f'SABR url expires in {int(expires_at - time.time())} seconds')
return
self.logger.debug(
f'Requesting player response refresh as SABR URL is due to expire within {self.expiry_threshold_sec} seconds')
yield RefreshPlayerResponseSabrPart(reason=RefreshPlayerResponseSabrPart.Reason.SABR_URL_EXPIRY)
def _log_part(self, part: UMPPart, msg=None, protobug_obj=None, data=None):
if self.logger.log_level > self.logger.LogLevel.TRACE:
return
message = f'[{part.part_id.name}]: (Size {part.size})'
if protobug_obj:
message += f' Parsed: {protobug_obj}'
uf = list(unknown_fields(protobug_obj))
if uf:
message += f' (Unknown fields: {uf})'
if msg:
message += f' {msg}'
if data:
message += f' Data: {base64.b64encode(data).decode("utf-8")}'
self.logger.trace(message.strip())
def _log_state(self):
# TODO: refactor
if self.logger.log_level > self.logger.LogLevel.DEBUG:
return
if self.processor.is_live and self.processor.post_live:
live_message = f'post_live ({self.processor.live_segment_target_duration_sec}s)'
elif self.processor.is_live:
live_message = f'live ({self.processor.live_segment_target_duration_sec}s)'
else:
live_message = 'not_live'
if self.processor.is_live:
live_message += ' bid:' + str_or_none(broadcast_id_from_url(self.url))
consumed_ranges_message = (
', '.join(
f'{izf.format_id.itag}:'
+ ', '.join(
f'{cf.start_sequence_number}-{cf.end_sequence_number} '
f'({cf.start_time_ms}-'
f'{cf.start_time_ms + cf.duration_ms})'
for cf in izf.consumed_ranges
)
for izf in self.processor.initialized_formats.values()
) or 'none'
)
izf_parts = []
for izf in self.processor.initialized_formats.values():
s = f'{izf.format_id.itag}'
if izf.discard:
s += 'd'
p = []
if izf.total_segments:
p.append(f'{izf.total_segments}')
if izf.sequence_lmt is not None:
p.append(f'lmt={izf.sequence_lmt}')
if p:
s += ('(' + ','.join(p) + ')')
izf_parts.append(s)
initialized_formats_message = ', '.join(izf_parts) or 'none'
unknown_part_message = ''
if self._unknown_part_types:
unknown_part_message = 'unkpt:' + ', '.join(part_type.name for part_type in self._unknown_part_types)
sabr_context_update_msg = ''
if self.processor.sabr_context_updates:
sabr_context_update_msg += 'cu:[' + ','.join(
f'{k}{"(n)" if k not in self.processor.sabr_contexts_to_send else ""}'
for k in self.processor.sabr_context_updates
) + ']'
self.logger.debug(
"[SABR State] "
f"v:{self.processor.video_id or 'unknown'} "
f"c:{self.processor.client_info.client_name.name} "
f"t:{self.processor.client_abr_state.player_time_ms} "
f"td:{self.processor.total_duration_ms if self.processor.total_duration_ms else 'n/a'} "
f"h:{urllib.parse.urlparse(self.url).netloc} "
f"exp:{dt.timedelta(seconds=int(self._gvs_expiry() - time.time())) if self._gvs_expiry() else 'n/a'} "
f"rn:{self._request_number} rnns:{self._no_new_segments_tracker.consecutive_requests} "
f"lnns:{self._no_new_segments_tracker.live_head_segment_started or 'n/a'} "
f"mmb:{self._sq_mismatch_backtrack_count} mmf:{self._sq_mismatch_forward_count} "
f"pot:{'Y' if self.processor.po_token else 'N'} "
f"sps:{self.processor.stream_protection_status.name if self.processor.stream_protection_status else 'n/a'} "
f"{live_message} "
f"if:[{initialized_formats_message}] "
f"cr:[{consumed_ranges_message}] "
f"{sabr_context_update_msg} "
f"{unknown_part_message}",
)

View File

@ -0,0 +1,83 @@
from __future__ import annotations
import contextlib
import math
import urllib.parse
from yt_dlp.extractor.youtube._streaming.sabr.models import ConsumedRange
from yt_dlp.utils import int_or_none, orderedSet, parse_qs, str_or_none, update_url_query
def get_cr_chain(start_consumed_range: ConsumedRange, consumed_ranges: list[ConsumedRange]) -> list[ConsumedRange]:
# TODO: unit test
# Return the continuous chain of consumed ranges starting from the given consumed range.
# Note: It is assumed a segment is only present in one consumed range - it should not be allowed in multiple (enforced by process_media_header)
chain = [start_consumed_range]
for cr in sorted(consumed_ranges, key=lambda r: r.start_sequence_number):
if cr.start_sequence_number == chain[-1].end_sequence_number + 1:
chain.append(cr)
elif cr.start_sequence_number > chain[-1].end_sequence_number + 1:
break
return chain
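# Illustrative sketch (not part of the module): for ranges covering segments
# 0-4, 5-9 and 12-13, the first two are contiguous and chain together, while
# the gap at segments 10-11 breaks the chain. The constructor arguments below
# are assumed for illustration; only the sequence number fields matter here:
#
#   a = ConsumedRange(start_sequence_number=0, end_sequence_number=4, ...)
#   b = ConsumedRange(start_sequence_number=5, end_sequence_number=9, ...)
#   c = ConsumedRange(start_sequence_number=12, end_sequence_number=13, ...)
#   get_cr_chain(a, [a, b, c]) == [a, b]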
def next_gvs_fallback_url(gvs_url):
# TODO: unit test
qs = parse_qs(gvs_url)
gvs_url_parsed = urllib.parse.urlparse(gvs_url)
fvip = int_or_none(qs.get('fvip', [None])[0])
mvi = int_or_none(qs.get('mvi', [None])[0])
mn = str_or_none(qs.get('mn', [None])[0], default='').split(',')
fallback_count = int_or_none(qs.get('fallback_count', ['0'])[0], default=0)
hosts = []
def build_host(current_host, f, m):
rr = current_host.startswith('rr')
if f is None or m is None:
return None
return ('rr' if rr else 'r') + str(f) + '---' + m + '.googlevideo.com'
original_host = build_host(gvs_url_parsed.netloc, mvi, mn[0])
# Order of fallback hosts:
# 1. Fallback host in url (mn[1] + fvip)
# 2. Fallback hosts brute forced (this usually contains the original host)
for mn_entry in reversed(mn):
for fvip_entry in orderedSet([fvip, 1, 2, 3, 4, 5]):
fallback_host = build_host(gvs_url_parsed.netloc, fvip_entry, mn_entry)
if fallback_host and fallback_host not in hosts:
hosts.append(fallback_host)
if not hosts or len(hosts) == 1:
return None
# On the first fallback, anchor to the start of the list so we start with the known fallback hosts.
# Sometimes we may get a SABR_REDIRECT after a fallback, which gives a new host with new fallbacks.
# In that case, the original host indicated by the URL params would match the current host.
current_host_index = -1
if fallback_count > 0 and gvs_url_parsed.netloc != original_host:
with contextlib.suppress(ValueError):
current_host_index = hosts.index(gvs_url_parsed.netloc)
def next_host(idx, h):
return h[(idx + 1) % len(h)]
new_host = next_host(current_host_index + 1, hosts)
# If the current URL only has one fallback host, then the first fallback host is the same as the current host.
if new_host == gvs_url_parsed.netloc:
new_host = next_host(current_host_index + 2, hosts)
# TODO: do not return new_host if it still matches the original host
return update_url_query(
gvs_url_parsed._replace(netloc=new_host).geturl(), {'fallback_count': fallback_count + 1})
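# Illustrative sketch (hypothetical values): for a URL on
# rr3---sn-abc.googlevideo.com with fvip=4, mvi=3 and mn=sn-abc,sn-def, the
# candidates are built from the last mn entry first (rr4---sn-def.googlevideo.com,
# rr1---sn-def.googlevideo.com, ...), followed by the brute-forced variants of
# sn-abc, and the returned URL carries fallback_count=1 so subsequent calls
# advance down the list.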
def ticks_to_ms(time_ticks: int, timescale: int):
if time_ticks is None or timescale is None:
return None
return math.ceil((time_ticks / timescale) * 1000)
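# e.g. ticks_to_ms(144144, 90000) == math.ceil(1601.6) == 1602,
# and ticks_to_ms(90000, 90000) == 1000; None inputs pass through as None.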
def broadcast_id_from_url(url: str) -> str | None:
return str_or_none(parse_qs(url).get('id', [None])[0])
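# e.g. broadcast_id_from_url('https://example.googlevideo.com/videoplayback?id=abc.1') == 'abc.1'
# (hypothetical URL; returns None when no id query parameter is present)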

View File

@ -0,0 +1,116 @@
import dataclasses
import enum
import io
class UMPPartId(enum.IntEnum):
UNKNOWN = -1
ONESIE_HEADER = 10
ONESIE_DATA = 11
ONESIE_ENCRYPTED_MEDIA = 12
MEDIA_HEADER = 20
MEDIA = 21
MEDIA_END = 22
LIVE_METADATA = 31
HOSTNAME_CHANGE_HINT = 32
LIVE_METADATA_PROMISE = 33
LIVE_METADATA_PROMISE_CANCELLATION = 34
NEXT_REQUEST_POLICY = 35
USTREAMER_VIDEO_AND_FORMAT_DATA = 36
FORMAT_SELECTION_CONFIG = 37
USTREAMER_SELECTED_MEDIA_STREAM = 38
FORMAT_INITIALIZATION_METADATA = 42
SABR_REDIRECT = 43
SABR_ERROR = 44
SABR_SEEK = 45
RELOAD_PLAYER_RESPONSE = 46
PLAYBACK_START_POLICY = 47
ALLOWED_CACHED_FORMATS = 48
START_BW_SAMPLING_HINT = 49
PAUSE_BW_SAMPLING_HINT = 50
SELECTABLE_FORMATS = 51
REQUEST_IDENTIFIER = 52
REQUEST_CANCELLATION_POLICY = 53
ONESIE_PREFETCH_REJECTION = 54
TIMELINE_CONTEXT = 55
REQUEST_PIPELINING = 56
SABR_CONTEXT_UPDATE = 57
STREAM_PROTECTION_STATUS = 58
SABR_CONTEXT_SENDING_POLICY = 59
LAWNMOWER_POLICY = 60
SABR_ACK = 61
END_OF_TRACK = 62
CACHE_LOAD_POLICY = 63
LAWNMOWER_MESSAGING_POLICY = 64
PREWARM_CONNECTION = 65
PLAYBACK_DEBUG_INFO = 66
SNACKBAR_MESSAGE = 67
@classmethod
def _missing_(cls, value):
return cls.UNKNOWN
@dataclasses.dataclass
class UMPPart:
part_id: UMPPartId
size: int
data: io.BufferedIOBase
class UMPDecoder:
def __init__(self, fp: io.BufferedIOBase):
self.fp = fp
def iter_parts(self):
while not self.fp.closed:
part_type = read_varint(self.fp)
if part_type == -1 and not self.fp.closed:
self.fp.close()
if self.fp.closed:
break
part_size = read_varint(self.fp)
if part_size == -1 and not self.fp.closed:
self.fp.close()
if self.fp.closed:
raise EOFError('Unexpected EOF while reading part size')
part_data = self.fp.read(part_size)
# In the future, we could allow streaming the part data.
# But we will need to ensure that each part is completely read before continuing.
yield UMPPart(UMPPartId(part_type), part_size, io.BytesIO(part_data))
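# Minimal usage sketch (hypothetical payload): one MEDIA_HEADER part with a
# 3-byte body, encoded as <part_type varint><part_size varint><data>:
#
#   buf = io.BytesIO(bytes([20, 3, 1, 2, 3]))
#   for p in UMPDecoder(buf).iter_parts():
#       assert p.part_id == UMPPartId.MEDIA_HEADER and p.size == 3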
def read_varint(fp: io.BufferedIOBase) -> int:
# https://web.archive.org/web/20250430054327/https://github.com/gsuberland/UMP_Format/blob/main/UMP_Format.md
# https://web.archive.org/web/20250429151021/https://github.com/davidzeng0/innertube/blob/main/googlevideo/ump.md
byte = fp.read(1)
if not byte:
# Expected EOF
return -1
prefix = byte[0]
size = varint_size(prefix)
result = 0
shift = 0
if size != 5:
shift = 8 - size
mask = (1 << shift) - 1
result |= prefix & mask
for _ in range(1, size):
next_byte = fp.read(1)
if not next_byte:
return -1
byte_int = next_byte[0]
result |= byte_int << shift
shift += 8
return result
def varint_size(byte: int) -> int:
return 1 if byte < 128 else 2 if byte < 192 else 3 if byte < 224 else 4 if byte < 240 else 5
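# Worked example of the prefix encoding above: the first byte determines the
# total size, and for sizes under 5 its low bits carry the least-significant
# value bits:
#
#   read_varint(io.BytesIO(b'\x2a'))     == 42  (1 byte, 7 value bits)
#   read_varint(io.BytesIO(b'\x95\x01')) == 85  (0x95 & 0x3f == 21, plus 1 << 6)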

View File

@ -26,6 +26,7 @@
from .pot._director import initialize_pot_director
from .pot.provider import PoTokenContext, PoTokenRequest
from ..openload import PhantomJSwrapper
+ from ...dependencies import protobug
from ...jsinterp import JSInterpreter
from ...networking.exceptions import HTTPError
from ...utils import (
@ -72,6 +73,7 @@
STREAMING_DATA_CLIENT_NAME = '__yt_dlp_client'
STREAMING_DATA_INITIAL_PO_TOKEN = '__yt_dlp_po_token'
+ STREAMING_DATA_FETCH_PO_TOKEN = '__yt_dlp_fetch_po_token'
STREAMING_DATA_FETCH_SUBS_PO_TOKEN = '__yt_dlp_fetch_subs_po_token'
STREAMING_DATA_INNERTUBE_CONTEXT = '__yt_dlp_innertube_context'
@ -1824,7 +1826,8 @@ def _real_initialize(self):
def _prepare_live_from_start_formats(self, formats, video_id, live_start_time, url, webpage_url, smuggled_data, is_live):
lock = threading.Lock()
start_time = time.time()
- formats = [f for f in formats if f.get('is_from_start')]
+ # TODO: only include dash formats
+ formats = [f for f in formats if f.get('is_from_start') and f.get('protocol') != 'sabr']
def refetch_manifest(format_id, delay):
nonlocal formats, start_time, is_live
@ -1836,7 +1839,7 @@ def refetch_manifest(format_id, delay):
microformats = traverse_obj(
prs, (..., 'microformat', 'playerMicroformatRenderer'),
expected_type=dict)
- _, live_status, _, formats, _ = self._list_formats(video_id, microformats, video_details, prs, player_url)
+ _, live_status, formats, _ = self._list_formats(video_id, microformats, video_details, prs, player_url)
is_live = live_status == 'is_live'
start_time = time.time()
@ -2812,16 +2815,25 @@ def _get_checkok_params():
return {'contentCheckOk': True, 'racyCheckOk': True}
@classmethod
- def _generate_player_context(cls, sts=None):
+ def _generate_player_context(cls, sts=None, reload_playback_token=None):
- context = {
+ content_playback_context = {
'html5Preference': 'HTML5_PREF_WANTS',
+ 'isInlinePlaybackNoAd': True,
}
if sts is not None:
- context['signatureTimestamp'] = sts
+ content_playback_context['signatureTimestamp'] = sts
+ playback_context = {
+ 'contentPlaybackContext': content_playback_context,
+ }
+ if reload_playback_token:
+ playback_context['reloadPlaybackContext'] = {
+ 'reloadPlaybackParams': {'token': reload_playback_token},
+ }
return {
- 'playbackContext': {
- 'contentPlaybackContext': context,
- },
+ 'playbackContext': playback_context,
**cls._get_checkok_params(),
}
@ -2863,7 +2875,7 @@ def _get_config_po_token(self, client: str, context: _PoTokenContext):
def fetch_po_token(self, client='web', context=_PoTokenContext.GVS, ytcfg=None, visitor_data=None,
data_sync_id=None, session_index=None, player_url=None, video_id=None, webpage=None,
- required=False, **kwargs):
+ required=False, bypass_cache=None, **kwargs):
"""
Fetch a PO Token for a given client and context. This function will validate required parameters for a given context and client. Fetch a PO Token for a given client and context. This function will validate required parameters for a given context and client.
@ -2879,6 +2891,7 @@ def fetch_po_token(self, client='web', context=_PoTokenContext.GVS, ytcfg=None,
@param video_id: video ID.
@param webpage: video webpage.
@param required: Whether the PO Token is required (i.e. try to fetch unless policy is "never").
+ @param bypass_cache: Whether to bypass the cache.
@param kwargs: Additional arguments to pass down. May be more added in the future.
@return: The fetched PO Token. None if it could not be fetched.
"""
@ -2900,7 +2913,7 @@ def fetch_po_token(self, client='web', context=_PoTokenContext.GVS, ytcfg=None,
return
config_po_token = self._get_config_po_token(client, context)
- if config_po_token:
+ if config_po_token and not bypass_cache:
# GVS WebPO token is bound to data_sync_id / account Session ID when logged in.
if player_url and context == _PoTokenContext.GVS and not data_sync_id and self.is_authenticated:
self.report_warning(
@ -2927,6 +2940,7 @@ def fetch_po_token(self, client='web', context=_PoTokenContext.GVS, ytcfg=None,
player_url=player_url,
video_id=video_id,
video_webpage=webpage,
+ bypass_cache=bypass_cache,
required=required,
**kwargs,
)
@ -2984,7 +2998,7 @@ def _fetch_po_token(self, client, **kwargs):
request_verify_tls=not self.get_param('nocheckcertificate'),
request_source_address=self.get_param('source_address'),
- bypass_cache=False,
+ bypass_cache=kwargs.get('bypass_cache', False),
)
return self._pot_director.get_po_token(pot_request)
@ -3005,7 +3019,7 @@ def _is_agegated(player_response):
def _is_unplayable(player_response):
return traverse_obj(player_response, ('playabilityStatus', 'status')) == 'UNPLAYABLE'
- def _extract_player_response(self, client, video_id, master_ytcfg, player_ytcfg, player_url, initial_pr, visitor_data, data_sync_id, po_token):
+ def _extract_player_response(self, client, video_id, master_ytcfg, player_ytcfg, player_url, initial_pr, visitor_data, data_sync_id, po_token, reload_playback_token):
headers = self.generate_api_headers(
ytcfg=player_ytcfg,
default_client=client,
@ -3034,7 +3048,8 @@ def _extract_player_response(self, client, video_id, master_ytcfg, player_ytcfg,
yt_query['serviceIntegrityDimensions'] = {'poToken': po_token}
sts = self._extract_signature_timestamp(video_id, player_url, master_ytcfg, fatal=False) if player_url else None
- yt_query.update(self._generate_player_context(sts))
+ yt_query.update(self._generate_player_context(sts, reload_playback_token))
return self._extract_response(
item_id=video_id, ep='player', query=yt_query,
ytcfg=player_ytcfg, headers=headers, fatal=True,
@ -3087,7 +3102,7 @@ def _invalid_player_response(self, pr, video_id):
if (pr_id := traverse_obj(pr, ('videoDetails', 'videoId'))) != video_id:
return pr_id
- def _extract_player_responses(self, clients, video_id, webpage, master_ytcfg, smuggled_data):
+ def _extract_player_responses(self, clients, video_id, webpage, master_ytcfg, reload_playback_token):
initial_pr = None
if webpage:
initial_pr = self._search_json(
@ -3136,7 +3151,7 @@ def append_client(*client_names):
player_url = self._download_player_url(video_id)
tried_iframe_fallback = True
- pr = initial_pr if client == 'web' else None
+ pr = initial_pr if client == 'web' and not reload_playback_token else None
visitor_data = visitor_data or self._extract_visitor_data(master_ytcfg, initial_pr, player_ytcfg)
data_sync_id = data_sync_id or self._extract_data_sync_id(master_ytcfg, initial_pr, player_ytcfg)
@ -3156,8 +3171,13 @@ def append_client(*client_names):
player_po_token = None if pr else self.fetch_po_token(
context=_PoTokenContext.PLAYER, **fetch_po_token_args)
- gvs_po_token = self.fetch_po_token(
- context=_PoTokenContext.GVS, **fetch_po_token_args)
+ fetch_gvs_po_token_func = functools.partial(
+ self.fetch_po_token,
+ context=_PoTokenContext.GVS,
+ **fetch_po_token_args,
+ )
+ gvs_po_token = fetch_gvs_po_token_func()
fetch_subs_po_token_func = functools.partial(
self.fetch_po_token,
@ -3200,7 +3220,8 @@ def append_client(*client_names):
initial_pr=initial_pr,
visitor_data=visitor_data,
data_sync_id=data_sync_id,
- po_token=player_po_token)
+ po_token=player_po_token,
+ reload_playback_token=reload_playback_token)
except ExtractorError as e:
self.report_warning(e)
continue
@ -3215,9 +3236,12 @@ def append_client(*client_names):
sd[STREAMING_DATA_INITIAL_PO_TOKEN] = gvs_po_token
sd[STREAMING_DATA_INNERTUBE_CONTEXT] = innertube_context
sd[STREAMING_DATA_FETCH_SUBS_PO_TOKEN] = fetch_subs_po_token_func
+ sd[STREAMING_DATA_FETCH_PO_TOKEN] = fetch_gvs_po_token_func
for f in traverse_obj(sd, (('formats', 'adaptiveFormats'), ..., {dict})):
f[STREAMING_DATA_CLIENT_NAME] = client
f[STREAMING_DATA_INITIAL_PO_TOKEN] = gvs_po_token
+ f[STREAMING_DATA_INNERTUBE_CONTEXT] = innertube_context
+ f[STREAMING_DATA_FETCH_PO_TOKEN] = fetch_gvs_po_token_func
if deprioritize_pr:
deprioritized_prs.append(pr)
else:
@ -3295,12 +3319,54 @@ def _report_pot_subtitles_skipped(self, video_id, client_name, msg=None):
else:
self.report_warning(msg, only_once=True)
- def _extract_formats_and_subtitles(self, streaming_data, video_id, player_url, live_status, duration):
+ def _process_n_param(self, gvs_url, video_id, player_url, proto='https'):
query = parse_qs(gvs_url)
if query.get('n'):
try:
decrypt_nsig = self._cached(self._decrypt_nsig, 'nsig', query['n'][0])
return update_url_query(gvs_url, {
'n': decrypt_nsig(query['n'][0], video_id, player_url),
})
except ExtractorError as e:
if player_url:
self.report_warning(
f'nsig extraction failed: Some {proto} formats may be missing\n'
f' n = {query["n"][0]} ; player = {player_url}\n'
f' {bug_reports_message(before="")}',
video_id=video_id, only_once=True)
self.write_debug(e, only_once=True)
else:
self.report_warning(
f'Cannot decrypt nsig without player_url: Some {proto} formats may be missing',
video_id=video_id, only_once=True)
return None
return gvs_url
def _reload_sabr_config(self, video_id, client_name, reload_playback_token):
# xxx: may also update client info?
url = 'https://www.youtube.com/watch?v=' + video_id
_, _, prs, player_url = self._download_player_responses(url, {}, video_id, url, reload_playback_token)
video_details = traverse_obj(prs, (..., 'videoDetails'), expected_type=dict)
microformats = traverse_obj(
prs, (..., 'microformat', 'playerMicroformatRenderer'),
expected_type=dict)
_, live_status, formats, _ = self._list_formats(video_id, microformats, video_details, prs, player_url)
for f in formats:
if f.get('protocol') == 'sabr':
sabr_config = f['_sabr_config']
if sabr_config['client_name'] == client_name:
return f['url'], sabr_config['video_playback_ustreamer_config']
raise ExtractorError('No SABR formats found', expected=True)
def _extract_formats_and_subtitles(self, video_id, player_responses, player_url, live_status, duration):
CHUNK_SIZE = 10 << 20
PREFERRED_LANG_VALUE = 10
original_language = None
itags, stream_ids = collections.defaultdict(set), []
itag_qualities, res_qualities = {}, {0: None}
subtitles = {}
q = qualities([
# Normally tiny is the smallest video-only formats. But
# audio-only formats with unknown quality may get tagged as tiny
@ -3308,7 +3374,6 @@ def _extract_formats_and_subtitles(self, streaming_data, video_id, player_url, l
'audio_quality_ultralow', 'audio_quality_low', 'audio_quality_medium', 'audio_quality_high', # Audio only formats
'small', 'medium', 'large', 'hd720', 'hd1080', 'hd1440', 'hd2160', 'hd2880', 'highres',
])
- streaming_formats = traverse_obj(streaming_data, (..., ('formats', 'adaptiveFormats'), ...))
format_types = self._configuration_arg('formats')
all_formats = 'duplicate' in format_types
if self._configuration_arg('include_duplicate_formats'):
@ -3323,271 +3388,331 @@ def build_fragments(f):
}),
} for range_start in range(0, f['filesize'], CHUNK_SIZE))
- for fmt in streaming_formats:
- client_name = fmt[STREAMING_DATA_CLIENT_NAME]
- if fmt.get('targetDurationSec'):
+ for pr in player_responses:
+ streaming_data = traverse_obj(pr, 'streamingData')
+ if not streaming_data:
continue
client_name = streaming_data.get(STREAMING_DATA_CLIENT_NAME)
po_token = streaming_data.get(STREAMING_DATA_INITIAL_PO_TOKEN)
fetch_po_token = streaming_data.get(STREAMING_DATA_FETCH_PO_TOKEN)
innertube_context = streaming_data.get(STREAMING_DATA_INNERTUBE_CONTEXT)
streaming_formats = traverse_obj(streaming_data, (('formats', 'adaptiveFormats'), ...))
- itag = str_or_none(fmt.get('itag'))
- audio_track = fmt.get('audioTrack') or {}
- stream_id = (itag, audio_track.get('id'), fmt.get('isDrc'))
- if not all_formats:
- if stream_id in stream_ids:
- continue
+ def get_stream_id(fmt_stream):
+ return str_or_none(fmt_stream.get('itag')), traverse_obj(fmt_stream, 'audioTrack', 'id'), fmt_stream.get('isDrc')
- quality = fmt.get('quality')
- height = int_or_none(fmt.get('height'))
- if quality == 'tiny' or not quality:
- quality = fmt.get('audioQuality', '').lower() or quality
- # The 3gp format (17) in android client has a quality of "small",
- # but is actually worse than other formats
- if itag == '17':
- quality = 'tiny'
- if quality:
- if itag:
- itag_qualities[itag] = quality
- if height:
- res_qualities[height] = quality
+ def process_format_stream(fmt_stream, proto):
+ nonlocal itag_qualities, res_qualities, original_language
+ itag = str_or_none(fmt_stream.get('itag'))
+ audio_track = fmt_stream.get('audioTrack') or {}
+ quality = fmt_stream.get('quality')
+ height = int_or_none(fmt_stream.get('height'))
+ if quality == 'tiny' or not quality:
+ quality = fmt_stream.get('audioQuality', '').lower() or quality
+ # The 3gp format (17) in android client has a quality of "small",
+ # but is actually worse than other formats
+ if itag == '17':
+ quality = 'tiny'
+ if quality:
+ if itag:
+ itag_qualities[itag] = quality
+ if height:
+ res_qualities[height] = quality
display_name = audio_track.get('displayName') or ''
is_original = 'original' in display_name.lower()
is_descriptive = 'descriptive' in display_name.lower()
is_default = audio_track.get('audioIsDefault')
language_code = audio_track.get('id', '').split('.')[0]
if language_code and (is_original or (is_default and not original_language)):
original_language = language_code
- has_drm = bool(fmt.get('drmFamilies'))
- # FORMAT_STREAM_TYPE_OTF(otf=1) requires downloading the init fragment
- # (adding `&sq=0` to the URL) and parsing emsg box to determine the
- # number of fragment that would subsequently requested with (`&sq=N`)
- if fmt.get('type') == 'FORMAT_STREAM_TYPE_OTF' and not has_drm:
- continue
- if has_drm:
- msg = f'Some {client_name} client https formats have been skipped as they are DRM protected. '
- if client_name == 'tv':
+ has_drm = bool(fmt_stream.get('drmFamilies'))
+ if has_drm:
+ msg = f'Some {client_name} client {proto} formats have been skipped as they are DRM protected. '
+ if client_name == 'tv':
msg += (
f'{"Your account" if self.is_authenticated else "The current session"} may have '
f'an experiment that applies DRM to all videos on the tv client. '
f'See https://github.com/yt-dlp/yt-dlp/issues/12563 for more details.'
)
self.report_warning(msg, video_id, only_once=True)
fmt_url = fmt.get('url')
if not fmt_url:
sc = urllib.parse.parse_qs(fmt.get('signatureCipher'))
fmt_url = url_or_none(try_get(sc, lambda x: x['url'][0]))
encrypted_sig = try_get(sc, lambda x: x['s'][0])
if not all((sc, fmt_url, player_url, encrypted_sig)):
msg = f'Some {client_name} client https formats have been skipped as they are missing a url. '
if client_name == 'web':
msg += 'YouTube is forcing SABR streaming for this client. '
else:
msg += ( msg += (
f'YouTube may have enabled the SABR-only or Server-Side Ad Placement experiment for ' f'{"Your account" if self.is_authenticated else "The current session"} may have '
f'{"your account" if self.is_authenticated else "the current session"}. ' f'an experiment that applies DRM to all videos on the tv client. '
f'See https://github.com/yt-dlp/yt-dlp/issues/12563 for more details.'
) )
msg += 'See https://github.com/yt-dlp/yt-dlp/issues/12482 for more details'
self.report_warning(msg, video_id, only_once=True)
continue
- try:
- fmt_url += '&{}={}'.format(
- traverse_obj(sc, ('sp', -1)) or 'signature',
- self._decrypt_signature(encrypted_sig, video_id, player_url),
- )
- except ExtractorError as e:
- self.report_warning(
- f'Signature extraction failed: Some formats may be missing\n'
- f' player = {player_url}\n'
- f' {bug_reports_message(before="")}',
- video_id=video_id, only_once=True)
- self.write_debug(
- f'{video_id}: Signature extraction failure info:\n'
- f' encrypted sig = {encrypted_sig}\n'
- f' player = {player_url}')
- self.write_debug(e, only_once=True)
- continue
+ tbr = float_or_none(fmt_stream.get('averageBitrate') or fmt_stream.get('bitrate'), 1000)
+ format_duration = traverse_obj(fmt_stream, ('approxDurationMs', {float_or_none(scale=1000)}))
+ # Some formats may have much smaller duration than others (possibly damaged during encoding)
+ # E.g. 2-nOtRESiUc Ref: https://github.com/yt-dlp/yt-dlp/issues/2823
+ # Make sure to avoid false positives with small duration differences.
+ # E.g. __2ABJjxzNo, ySuUZEjARPY
+ is_damaged = try_call(lambda: format_duration < duration // 2)
+ if is_damaged:
+ self.report_warning(
+ f'Some {client_name} client {proto} formats are possibly damaged. They will be deprioritized', video_id, only_once=True)
- query = parse_qs(fmt_url)
- if query.get('n'):
- try:
- decrypt_nsig = self._cached(self._decrypt_nsig, 'nsig', query['n'][0])
- fmt_url = update_url_query(fmt_url, {
- 'n': decrypt_nsig(query['n'][0], video_id, player_url),
- })
- except ExtractorError as e:
- if player_url:
- self.report_warning(
- f'nsig extraction failed: Some formats may be missing\n'
- f' n = {query["n"][0]} ; player = {player_url}\n'
- f' {bug_reports_message(before="")}',
- video_id=video_id, only_once=True)
- self.write_debug(e, only_once=True)
- else:
- self.report_warning(
- 'Cannot decrypt nsig without player_url: Some formats may be missing',
- video_id=video_id, only_once=True)
- continue
+ # Clients that require PO Token return videoplayback URLs that may return 403
+ require_po_token = (
+ not po_token
+ and _PoTokenContext.GVS in self._get_default_ytcfg(client_name)['PO_TOKEN_REQUIRED_CONTEXTS']
+ and itag not in ['18']) # these formats do not require PO Token
- tbr = float_or_none(fmt.get('averageBitrate') or fmt.get('bitrate'), 1000)
- format_duration = traverse_obj(fmt, ('approxDurationMs', {float_or_none(scale=1000)}))
- # Some formats may have much smaller duration than others (possibly damaged during encoding)
- # E.g. 2-nOtRESiUc Ref: https://github.com/yt-dlp/yt-dlp/issues/2823
- # Make sure to avoid false positives with small duration differences.
- # E.g. __2ABJjxzNo, ySuUZEjARPY
- is_damaged = try_call(lambda: format_duration < duration // 2)
- if is_damaged:
- self.report_warning(
- 'Some formats are possibly damaged. They will be deprioritized', video_id, only_once=True)
- po_token = fmt.get(STREAMING_DATA_INITIAL_PO_TOKEN)
- if po_token:
- fmt_url = update_url_query(fmt_url, {'pot': po_token})
- # Clients that require PO Token return videoplayback URLs that may return 403
- require_po_token = (
- not po_token
- and _PoTokenContext.GVS in self._get_default_ytcfg(client_name)['PO_TOKEN_REQUIRED_CONTEXTS']
- and itag not in ['18']) # these formats do not require PO Token
- if require_po_token and 'missing_pot' not in self._configuration_arg('formats'):
- self._report_pot_format_skipped(video_id, client_name, 'https')
- continue
+ if require_po_token and 'missing_pot' not in self._configuration_arg('formats'):
name = fmt.get('qualityLabel') or quality.replace('audio_quality_', '') or ''
fps = int_or_none(fmt.get('fps')) or 0
dct = {
'asr': int_or_none(fmt.get('audioSampleRate')),
'filesize': int_or_none(fmt.get('contentLength')),
'format_id': f'{itag}{"-drc" if fmt.get("isDrc") else ""}',
'format_note': join_nonempty(
join_nonempty(display_name, is_default and ' (default)', delim=''),
name, fmt.get('isDrc') and 'DRC',
try_get(fmt, lambda x: x['projectionType'].replace('RECTANGULAR', '').lower()),
try_get(fmt, lambda x: x['spatialAudioType'].replace('SPATIAL_AUDIO_TYPE_', '').lower()),
is_damaged and 'DAMAGED', require_po_token and 'MISSING POT',
(self.get_param('verbose') or all_formats) and short_client_name(client_name),
delim=', '),
# Format 22 is likely to be damaged. See https://github.com/yt-dlp/yt-dlp/issues/3372
'source_preference': (-5 if itag == '22' else -1) + (100 if 'Premium' in name else 0),
'fps': fps if fps > 1 else None, # For some formats, fps is wrongly returned as 1
'audio_channels': fmt.get('audioChannels'),
'height': height,
'quality': q(quality) - bool(fmt.get('isDrc')) / 2,
'has_drm': has_drm,
'tbr': tbr,
'filesize_approx': filesize_from_tbr(tbr, format_duration),
'url': fmt_url,
'width': int_or_none(fmt.get('width')),
'language': join_nonempty(language_code, 'desc' if is_descriptive else '') or None,
'language_preference': PREFERRED_LANG_VALUE if is_original else 5 if is_default else -10 if is_descriptive else -1,
# Strictly de-prioritize damaged and 3gp formats
'preference': -10 if is_damaged else -2 if itag == '17' else None,
}
mime_mobj = re.match(
r'((?:[^/]+)/(?:[^;]+))(?:;\s*codecs="([^"]+)")?', fmt.get('mimeType') or '')
if mime_mobj:
dct['ext'] = mimetype2ext(mime_mobj.group(1))
dct.update(parse_codecs(mime_mobj.group(2)))
if itag:
itags[itag].add(('https', dct.get('language')))
stream_ids.append(stream_id)
single_stream = 'none' in (dct.get('acodec'), dct.get('vcodec'))
if single_stream and dct.get('ext'):
dct['container'] = dct['ext'] + '_dash'
if (all_formats or 'dashy' in format_types) and dct['filesize']:
yield {
**dct,
'format_id': f'{dct["format_id"]}-dashy' if all_formats else dct['format_id'],
'protocol': 'http_dash_segments',
'fragments': build_fragments(dct),
}
if all_formats or 'dashy' not in format_types:
dct['downloader_options'] = {'http_chunk_size': CHUNK_SIZE}
yield dct
needs_live_processing = self._needs_live_processing(live_status, duration)
skip_bad_formats = 'incomplete' not in format_types
if self._configuration_arg('include_incomplete_formats'):
skip_bad_formats = False
self._downloader.deprecated_feature('[youtube] include_incomplete_formats extractor argument is deprecated. '
'Use formats=incomplete extractor argument instead')
skip_manifests = set(self._configuration_arg('skip'))
if (not self.get_param('youtube_include_hls_manifest', True)
or needs_live_processing == 'is_live' # These will be filtered out by YoutubeDL anyway
or (needs_live_processing and skip_bad_formats)):
skip_manifests.add('hls')
if not self.get_param('youtube_include_dash_manifest', True):
skip_manifests.add('dash')
if self._configuration_arg('include_live_dash'):
self._downloader.deprecated_feature('[youtube] include_live_dash extractor argument is deprecated. '
'Use formats=incomplete extractor argument instead')
elif skip_bad_formats and live_status == 'is_live' and needs_live_processing != 'is_live':
skip_manifests.add('dash')
def process_manifest_format(f, proto, client_name, itag, po_token):
key = (proto, f.get('language'))
if not all_formats and key in itags[itag]:
return False
if f.get('source_preference') is None:
f['source_preference'] = -1
# Clients that require PO Token return videoplayback URLs that may return 403
# hls does not currently require PO Token
if (
not po_token
and _PoTokenContext.GVS in self._get_default_ytcfg(client_name)['PO_TOKEN_REQUIRED_CONTEXTS']
and proto != 'hls'
):
if 'missing_pot' not in self._configuration_arg('formats'):
self._report_pot_format_skipped(video_id, client_name, proto)
return None
name = fmt_stream.get('qualityLabel') or quality.replace('audio_quality_', '') or ''
fps = int_or_none(fmt_stream.get('fps')) or 0
dct = {
'asr': int_or_none(fmt_stream.get('audioSampleRate')),
'filesize': int_or_none(fmt_stream.get('contentLength')),
'format_id': f'{itag}{"-drc" if fmt_stream.get("isDrc") else ""}',
'format_note': join_nonempty(
join_nonempty(display_name, is_default and ' (default)', delim=''),
name, fmt_stream.get('isDrc') and 'DRC',
try_get(fmt_stream, lambda x: x['projectionType'].replace('RECTANGULAR', '').lower()),
try_get(fmt_stream, lambda x: x['spatialAudioType'].replace('SPATIAL_AUDIO_TYPE_', '').lower()),
is_damaged and 'DAMAGED', require_po_token and 'MISSING POT',
(self.get_param('verbose') or all_formats) and short_client_name(client_name),
delim=', '),
# Format 22 is likely to be damaged. See https://github.com/yt-dlp/yt-dlp/issues/3372
'source_preference': (-5 if itag == '22' else -1) + (100 if 'Premium' in name else 0),
'fps': fps if fps > 1 else None, # For some formats, fps is wrongly returned as 1
'audio_channels': fmt_stream.get('audioChannels'),
'height': height,
'quality': q(quality) - bool(fmt_stream.get('isDrc')) / 2,
'has_drm': has_drm,
'tbr': tbr,
'filesize_approx': filesize_from_tbr(tbr, format_duration),
'width': int_or_none(fmt_stream.get('width')),
'language': join_nonempty(language_code, 'desc' if is_descriptive else '') or None,
'language_preference': PREFERRED_LANG_VALUE if is_original else 5 if is_default else -10 if is_descriptive else -1,
# Strictly de-prioritize damaged and 3gp formats
'preference': -10 if is_damaged else -2 if itag == '17' else None,
}
mime_mobj = re.match(
r'((?:[^/]+)/(?:[^;]+))(?:;\s*codecs="([^"]+)")?', fmt_stream.get('mimeType') or '')
if mime_mobj:
dct['ext'] = mimetype2ext(mime_mobj.group(1))
dct.update(parse_codecs(mime_mobj.group(2)))
single_stream = 'none' in (dct.get('acodec'), dct.get('vcodec'))
if single_stream and dct.get('ext'):
dct['container'] = dct['ext'] + '_dash'
return dct
def process_sabr_formats_and_subtitles():
proto = 'sabr'
server_abr_streaming_url = self._process_n_param(
streaming_data.get('serverAbrStreamingUrl'), video_id, player_url, proto)
video_playback_ustreamer_config = traverse_obj(
pr, ('playerConfig', 'mediaCommonConfig', 'mediaUstreamerRequestConfig', 'videoPlaybackUstreamerConfig'))
if not server_abr_streaming_url or not video_playback_ustreamer_config:
return
if protobug is None:
self.report_warning(
f'{video_id}: {client_name} client {proto} formats will be skipped as protobug is not installed.',
only_once=True)
return
sabr_config = {
'video_playback_ustreamer_config': video_playback_ustreamer_config,
'po_token': po_token,
'fetch_po_token_fn': fetch_po_token,
'client_name': client_name,
'client_info': traverse_obj(innertube_context, 'client'),
'reload_config_fn': functools.partial(self._reload_sabr_config, video_id, client_name),
'video_id': video_id,
'live_status': live_status,
}
for fmt_stream in streaming_formats:
stream_id = get_stream_id(fmt_stream)
if not all_formats:
if stream_id in stream_ids:
continue
fmt = process_format_stream(fmt_stream, proto)
if not fmt:
continue
caption_track = fmt_stream.get('captionTrack')
fmt.update({
'is_from_start': live_status == 'is_live' and self.get_param('live_from_start'),
'url': server_abr_streaming_url,
'protocol': 'sabr',
})
fmt['_sabr_config'] = {
**sabr_config,
'itag': stream_id[0],
'xtags': fmt_stream.get('xtags'),
'last_modified': fmt_stream.get('lastModified'),
'target_duration_sec': fmt_stream.get('targetDurationSec'),
}
single_stream = 'none' in (fmt.get('acodec'), fmt.get('vcodec'))
nonlocal subtitles
if caption_track:
# TODO: proper live subtitle extraction
subtitles = self._merge_subtitles({str(stream_id[0]): [fmt]}, subtitles)
elif single_stream:
if stream_id[0]:
itags[stream_id[0]].add((proto, fmt.get('language')))
stream_ids.append(stream_id)
yield fmt
def process_https_formats():
proto = 'https'
for fmt_stream in streaming_formats:
if fmt_stream.get('targetDurationSec'):
continue
# FORMAT_STREAM_TYPE_OTF(otf=1) requires downloading the init fragment
# (adding `&sq=0` to the URL) and parsing emsg box to determine the
# number of fragment that would subsequently requested with (`&sq=N`)
if fmt_stream.get('type') == 'FORMAT_STREAM_TYPE_OTF' and not bool(fmt_stream.get('drmFamilies')):
continue
stream_id = get_stream_id(fmt_stream)
if not all_formats:
if stream_id in stream_ids:
continue
fmt = process_format_stream(fmt_stream, proto)
if not fmt:
continue
fmt_url = fmt_stream.get('url')
if not fmt_url:
sc = urllib.parse.parse_qs(fmt_stream.get('signatureCipher'))
fmt_url = url_or_none(try_get(sc, lambda x: x['url'][0]))
encrypted_sig = try_get(sc, lambda x: x['s'][0])
if not all((sc, fmt_url, player_url, encrypted_sig)):
continue
try:
fmt_url += '&{}={}'.format(
traverse_obj(sc, ('sp', -1)) or 'signature',
self._decrypt_signature(encrypted_sig, video_id, player_url),
)
except ExtractorError as e:
self.report_warning(
f'Signature extraction failed: Some formats may be missing\n'
f' player = {player_url}\n'
f' {bug_reports_message(before="")}',
video_id=video_id, only_once=True)
self.write_debug(
f'{video_id}: Signature extraction failure info:\n'
f' encrypted sig = {encrypted_sig}\n'
f' player = {player_url}')
self.write_debug(e, only_once=True)
continue
fmt_url = self._process_n_param(fmt_url, video_id, player_url, proto)
if not fmt_url:
continue
if po_token:
fmt_url = update_url_query(fmt_url, {'pot': po_token})
fmt['url'] = fmt_url
if stream_id[0]:
itags[stream_id[0]].add((proto, fmt.get('language')))
stream_ids.append(stream_id)
if (all_formats or 'dashy' in format_types) and fmt['filesize']:
yield {
**fmt,
'format_id': f'{fmt["format_id"]}-dashy' if all_formats else fmt['format_id'],
'protocol': 'http_dash_segments',
'fragments': build_fragments(fmt),
}
if all_formats or 'dashy' not in format_types:
fmt['downloader_options'] = {'http_chunk_size': CHUNK_SIZE}
yield fmt
yield from process_https_formats()
yield from process_sabr_formats_and_subtitles()
needs_live_processing = self._needs_live_processing(live_status, duration)
skip_bad_formats = 'incomplete' not in format_types
if self._configuration_arg('include_incomplete_formats'):
skip_bad_formats = False
self._downloader.deprecated_feature('[youtube] include_incomplete_formats extractor argument is deprecated. '
'Use formats=incomplete extractor argument instead')
skip_manifests = set(self._configuration_arg('skip'))
if (not self.get_param('youtube_include_hls_manifest', True)
or needs_live_processing == 'is_live' # These will be filtered out by YoutubeDL anyway
or (needs_live_processing and skip_bad_formats)):
skip_manifests.add('hls')
if not self.get_param('youtube_include_dash_manifest', True):
skip_manifests.add('dash')
if self._configuration_arg('include_live_dash'):
self._downloader.deprecated_feature('[youtube] include_live_dash extractor argument is deprecated. '
'Use formats=incomplete extractor argument instead')
elif skip_bad_formats and live_status == 'is_live' and needs_live_processing != 'is_live':
skip_manifests.add('dash')
def process_manifest_format(f, proto, client_name, itag, po_token):
key = (proto, f.get('language'))
if not all_formats and key in itags[itag]:
return False return False
f['format_note'] = join_nonempty(f.get('format_note'), 'MISSING POT', delim=' ')
f['source_preference'] -= 20
- itags[itag].add(key)
+ if f.get('source_preference') is None:
f['source_preference'] = -1
- if itag and all_formats:
- f['format_id'] = f'{itag}-{proto}'
- elif any(p != proto for p, _ in itags[itag]):
- f['format_id'] = f'{itag}-{proto}'
- elif itag:
- f['format_id'] = itag
+ # Clients that require PO Token return videoplayback URLs that may return 403
+ # hls does not currently require PO Token
+ if (
+ not po_token
+ and _PoTokenContext.GVS in self._get_default_ytcfg(client_name)['PO_TOKEN_REQUIRED_CONTEXTS']
+ and proto != 'hls'
):
if 'missing_pot' not in self._configuration_arg('formats'):
self._report_pot_format_skipped(video_id, client_name, proto)
return False
f['format_note'] = join_nonempty(f.get('format_note'), 'MISSING POT', delim=' ')
f['source_preference'] -= 20
- if original_language and f.get('language') == original_language:
+ itags[itag].add(key)
f['format_note'] = join_nonempty(f.get('format_note'), '(default)', delim=' ')
f['language_preference'] = PREFERRED_LANG_VALUE
if itag in ('616', '235'): if itag and all_formats:
f['format_note'] = join_nonempty(f.get('format_note'), 'Premium', delim=' ') f['format_id'] = f'{itag}-{proto}'
f['source_preference'] += 100 elif any(p != proto for p, _ in itags[itag]):
f['format_id'] = f'{itag}-{proto}'
elif itag:
f['format_id'] = itag
f['quality'] = q(itag_qualities.get(try_get(f, lambda f: f['format_id'].split('-')[0]), -1)) if original_language and f.get('language') == original_language:
if f['quality'] == -1 and f.get('height'): f['format_note'] = join_nonempty(f.get('format_note'), '(default)', delim=' ')
f['quality'] = q(res_qualities[min(res_qualities, key=lambda x: abs(x - f['height']))]) f['language_preference'] = PREFERRED_LANG_VALUE
if self.get_param('verbose') or all_formats:
f['format_note'] = join_nonempty(
f.get('format_note'), short_client_name(client_name), delim=', ')
if f.get('fps') and f['fps'] <= 1:
del f['fps']
if proto == 'hls' and f.get('has_drm'): if itag in ('616', '235'):
f['has_drm'] = 'maybe' f['format_note'] = join_nonempty(f.get('format_note'), 'Premium', delim=' ')
f['source_preference'] -= 5 f['source_preference'] += 100
return True
subtitles = {} f['quality'] = q(itag_qualities.get(try_get(f, lambda f: f['format_id'].split('-')[0]), -1))
for sd in streaming_data: if f['quality'] == -1 and f.get('height'):
client_name = sd[STREAMING_DATA_CLIENT_NAME] f['quality'] = q(res_qualities[min(res_qualities, key=lambda x: abs(x - f['height']))])
po_token = sd.get(STREAMING_DATA_INITIAL_PO_TOKEN) if self.get_param('verbose') or all_formats:
hls_manifest_url = 'hls' not in skip_manifests and sd.get('hlsManifestUrl') f['format_note'] = join_nonempty(
f.get('format_note'), short_client_name(client_name), delim=', ')
if f.get('fps') and f['fps'] <= 1:
del f['fps']
if proto == 'hls' and f.get('has_drm'):
f['has_drm'] = 'maybe'
f['source_preference'] -= 5
return True
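The (proto, language) bookkeeping keyed on itag is easier to see in isolation. A minimal standalone re-creation of the dedup rule (with the extractor state reduced to a module-level dict):

import collections

itags = collections.defaultdict(set)

def register(itag, proto, language=None, all_formats=False):
    key = (proto, language)
    if not all_formats and key in itags[itag]:
        return False  # duplicate: this itag was already seen for this protocol/language
    itags[itag].add(key)
    return True

assert register('18', 'https') is True    # first sighting is kept
assert register('18', 'https') is False   # exact duplicate is dropped
assert register('18', 'hls') is True      # same itag over another protocol is kept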
hls_manifest_url = 'hls' not in skip_manifests and streaming_data.get('hlsManifestUrl')
if hls_manifest_url:
    if po_token:
        hls_manifest_url = hls_manifest_url.rstrip('/') + f'/pot/{po_token}'

@@ -3602,7 +3727,7 @@ def process_manifest_format(f, proto, client_name, itag, po_token):
            r'/itag/(\d+)', f['url'], 'itag', default=None), po_token):
        yield f

dash_manifest_url = 'dash' not in skip_manifests and streaming_data.get('dashManifestUrl')
if dash_manifest_url:
    if po_token:
        dash_manifest_url = dash_manifest_url.rstrip('/') + f'/pot/{po_token}'
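Note the two PO Token placements: progressive videoplayback URLs receive it as a pot query parameter (via update_url_query earlier in this hunk), while HLS/DASH manifest URLs carry it as a trailing path segment. A quick illustration with made-up values:

po_token = 'EXAMPLE_POT'  # hypothetical value
video_url = 'https://rr1---example.googlevideo.com/videoplayback?itag=18'
manifest_url = 'https://example.googlevideo.com/api/manifest/dash/id/abc/'

video_url += f'&pot={po_token}'                               # query parameter
manifest_url = manifest_url.rstrip('/') + f'/pot/{po_token}'  # path segment

assert manifest_url.endswith('/abc/pot/EXAMPLE_POT')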
@@ -3659,11 +3784,14 @@ def _extract_storyboard(self, player_responses, duration):
    } for j in range(math.ceil(fragment_count))],
}
def _download_player_responses(self, url, smuggled_data, video_id, webpage_url, reload_playback_token=None):
    webpage = None
    if 'webpage' not in self._configuration_arg('player_skip'):
        query = {'bpctr': '9999999999', 'has_verified': '1'}
        pp = (
            self._configuration_arg('player_params', [None], casesense=True)[0]
            or traverse_obj(INNERTUBE_CLIENTS, ('web', 'PLAYER_PARAMS', {str}))
        )
        if pp:
            query['pp'] = pp
        webpage = self._download_webpage_with_retries(webpage_url, video_id, query=query)

@@ -3672,7 +3800,7 @@ def _download_player_responses(self, url, smuggled_data, video_id, webpage_url):
    player_responses, player_url = self._extract_player_responses(
        self._get_requested_clients(url, smuggled_data),
        video_id, webpage, master_ytcfg, reload_playback_token)

    return webpage, master_ytcfg, player_responses, player_url
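The new pp fallback means that when no player_params extractor-arg is supplied, the web client's default PLAYER_PARAMS (if the client defines one) is sent instead. Reduced to plain dicts (the PLAYER_PARAMS value below is illustrative only, not taken from the real client config):

INNERTUBE_CLIENTS = {'web': {'PLAYER_PARAMS': '8AEB'}}  # illustrative value

user_pp = None  # i.e. no --extractor-args "youtube:player_params=..." given
pp = user_pp or INNERTUBE_CLIENTS.get('web', {}).get('PLAYER_PARAMS')

query = {'bpctr': '9999999999', 'has_verified': '1'}
if pp:
    query['pp'] = pp
assert query['pp'] == '8AEB'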
@@ -3690,14 +3818,14 @@ def _list_formats(self, video_id, microformats, video_details, player_responses,
        else 'was_live' if live_content
        else 'not_live' if False in (is_live, live_content)
        else None)

    *formats, subtitles = self._extract_formats_and_subtitles(video_id, player_responses, player_url, live_status, duration)
    if all(f.get('has_drm') for f in formats):
        # If there are no formats that definitely don't have DRM, all have DRM
        for f in formats:
            f['has_drm'] = True

    return live_broadcast_details, live_status, formats, subtitles
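The DRM collapse is worth seeing in isolation: 'maybe' is truthy, so when every format is at best maybe-DRM, all of them get promoted to definitely DRM'd, while a single definitely-clean format leaves everything untouched:

formats = [{'format_id': 'a', 'has_drm': 'maybe'}, {'format_id': 'b', 'has_drm': True}]
if all(f.get('has_drm') for f in formats):  # no format is definitely DRM-free
    for f in formats:
        f['has_drm'] = True
assert all(f['has_drm'] is True for f in formats)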
def _real_extract(self, url):
    url, smuggled_data = unsmuggle_url(url, {})

@@ -3787,8 +3915,9 @@ def feed_entry(name):
        or int_or_none(get_first(microformats, 'lengthSeconds'))
        or parse_duration(search_meta('duration')) or None)

    live_broadcast_details, live_status, formats, automatic_captions = \
        self._list_formats(video_id, microformats, video_details, player_responses, player_url, duration)
    streaming_data = traverse_obj(player_responses, (..., 'streamingData'))

    if live_status == 'post_live':
        self.write_debug(f'{video_id}: Video is in Post-Live Manifestless mode')