1
0
mirror of https://github.com/yt-dlp/yt-dlp.git synced 2025-07-07 22:08:38 +00:00
yt-dlp/yt_dlp/downloader/sabr/_fd.py
2025-06-24 19:52:48 +12:00

336 lines
16 KiB
Python

from __future__ import annotations
import collections
import itertools
from yt_dlp.networking.exceptions import TransportError, HTTPError
from yt_dlp.utils import traverse_obj, int_or_none, DownloadError, join_nonempty
from yt_dlp.downloader import FileDownloader
from ._writer import SabrFDFormatWriter
from ._logger import create_sabrfd_logger
from yt_dlp.extractor.youtube._streaming.sabr.part import (
MediaSegmentEndSabrPart,
MediaSegmentDataSabrPart,
MediaSegmentInitSabrPart,
PoTokenStatusSabrPart,
RefreshPlayerResponseSabrPart,
MediaSeekSabrPart,
FormatInitializedSabrPart,
)
from yt_dlp.extractor.youtube._streaming.sabr.stream import SabrStream
from yt_dlp.extractor.youtube._streaming.sabr.models import ConsumedRange, AudioSelector, VideoSelector, CaptionSelector
from yt_dlp.extractor.youtube._streaming.sabr.exceptions import SabrStreamError
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, ClientName
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId
class SabrFD(FileDownloader):
    """File downloader for formats whose ``protocol`` is ``sabr``.

    Requested formats are grouped by the client named in their ``_sabr_config``.
    Within a group, audio/video/caption formats are paired up and each pairing
    is downloaded through a single ``SabrStream``.
    """

    @classmethod
    def can_download(cls, info_dict):
        """Return True if *info_dict* has requested formats and all of them use the ``sabr`` protocol."""
        requested_formats = info_dict.get('requested_formats')
        return bool(requested_formats) and all(
            format_info.get('protocol') == 'sabr'
            for format_info in requested_formats)

    def _group_formats_by_client(self, filename, info_dict):
        """Group the requested formats by the ``client_name`` in their ``_sabr_config``.

        @param filename:   Fallback output filename for formats that have no ``filepath``.
        @param info_dict:  The info dict; ``requested_formats`` is used when present,
                           otherwise the info dict itself is treated as the only format.
        @returns           A dict mapping client name to a group config dict holding the
                           shared stream settings and a ``formats`` list.
        @raises DownloadError  If a format has no SABR config or no ustreamer config, or if
                               formats of the same client disagree on the streaming URL or
                               the ustreamer config.
        """
        format_groups = {}
        requested_formats = info_dict.get('requested_formats') or [info_dict]

        for f in requested_formats:
            sabr_config = f.get('_sabr_config')
            if not sabr_config:
                # Fail with a download error instead of an AttributeError on None below
                raise DownloadError('SABR config not found')

            client_name = sabr_config.get('client_name')
            client_info = sabr_config.get('client_info')
            server_abr_streaming_url = f.get('url')
            video_playback_ustreamer_config = sabr_config.get('video_playback_ustreamer_config')

            if not video_playback_ustreamer_config:
                raise DownloadError('Video playback ustreamer config not found')

            sabr_format_group_config = format_groups.get(client_name)

            if not sabr_format_group_config:
                # The first format seen for a client defines the settings shared by the group
                sabr_format_group_config = format_groups[client_name] = {
                    'server_abr_streaming_url': server_abr_streaming_url,
                    'video_playback_ustreamer_config': video_playback_ustreamer_config,
                    'formats': [],
                    'initial_po_token': sabr_config.get('po_token'),
                    'fetch_po_token_fn': fn if callable(fn := sabr_config.get('fetch_po_token_fn')) else None,
                    'reload_config_fn': fn if callable(fn := sabr_config.get('reload_config_fn')) else None,
                    'live_status': sabr_config.get('live_status'),
                    'video_id': sabr_config.get('video_id'),
                    'client_info': ClientInfo(
                        client_name=traverse_obj(client_info, ('clientName', {lambda x: ClientName[x]})),
                        client_version=traverse_obj(client_info, 'clientVersion'),
                        os_version=traverse_obj(client_info, 'osVersion'),
                        os_name=traverse_obj(client_info, 'osName'),
                        device_model=traverse_obj(client_info, 'deviceModel'),
                        device_make=traverse_obj(client_info, 'deviceMake'),
                    ),
                    'target_duration_sec': sabr_config.get('target_duration_sec'),
                    # Number.MAX_SAFE_INTEGER
                    'start_time_ms': ((2**53) - 1) if info_dict.get('live_status') == 'is_live' and not f.get('is_from_start') else 0,
                }
            else:
                # Every format in a group must agree on the settings the stream is created with
                if sabr_format_group_config['server_abr_streaming_url'] != server_abr_streaming_url:
                    raise DownloadError('Server ABR streaming URL mismatch')

                if sabr_format_group_config['video_playback_ustreamer_config'] != video_playback_ustreamer_config:
                    raise DownloadError('Video playback ustreamer config mismatch')

            itag = int_or_none(sabr_config.get('itag'))
            sabr_format_group_config['formats'].append({
                'display_name': f.get('format_id'),
                # Left falsy when there is no usable itag
                'format_id': itag and FormatId(
                    itag=itag, lmt=int_or_none(sabr_config.get('last_modified')), xtags=sabr_config.get('xtags')),
                'format_type': format_type(f),
                'quality': sabr_config.get('quality'),
                'height': sabr_config.get('height'),
                'filename': f.get('filepath', filename),
                'info_dict': f,
            })

        return format_groups

    def real_download(self, filename, info_dict):
        """Download every requested SABR format of *info_dict*.

        For each client group, the n-th audio, n-th video and n-th caption format are
        downloaded together in one stream; missing members of a pairing are ``None``.

        @returns  True once all streams have been processed.
        """
        format_groups = self._group_formats_by_client(filename, info_dict)

        is_test = self.params.get('test', False)
        resume = self.params.get('continuedl', True)

        for client_name, format_group in format_groups.items():
            formats = format_group['formats']
            audio_formats = (f for f in formats if f['format_type'] == 'audio')
            video_formats = (f for f in formats if f['format_type'] == 'video')
            caption_formats = (f for f in formats if f['format_type'] == 'caption')
            for audio_format, video_format, caption_format in itertools.zip_longest(audio_formats, video_formats, caption_formats):
                format_str = join_nonempty(*[
                    traverse_obj(audio_format, 'display_name'),
                    traverse_obj(video_format, 'display_name'),
                    traverse_obj(caption_format, 'display_name')], delim='+')
                self.write_debug(f'Downloading formats: {format_str} ({client_name} client)')
                self._download_sabr_stream(
                    info_dict=info_dict,
                    video_format=video_format,
                    audio_format=audio_format,
                    caption_format=caption_format,
                    resume=resume,
                    is_test=is_test,
                    server_abr_streaming_url=format_group['server_abr_streaming_url'],
                    video_playback_ustreamer_config=format_group['video_playback_ustreamer_config'],
                    initial_po_token=format_group['initial_po_token'],
                    fetch_po_token_fn=format_group['fetch_po_token_fn'],
                    reload_config_fn=format_group['reload_config_fn'],
                    client_info=format_group['client_info'],
                    start_time_ms=format_group['start_time_ms'],
                    target_duration_sec=format_group.get('target_duration_sec', None),
                    live_status=format_group.get('live_status'),
                    video_id=format_group.get('video_id'),
                )
        return True

    def _download_sabr_stream(
        self,
        video_id: str,
        info_dict: dict,
        video_format: dict,
        audio_format: dict,
        caption_format: dict,
        resume: bool,
        is_test: bool,
        server_abr_streaming_url: str,
        video_playback_ustreamer_config: str,
        initial_po_token: str,
        fetch_po_token_fn: callable | None = None,
        reload_config_fn: callable | None = None,
        client_info: ClientInfo | None = None,
        start_time_ms: int = 0,
        target_duration_sec: int | None = None,
        live_status: str | None = None,
    ):
        """Download up to one audio, one video and one caption format through one ``SabrStream``.

        @param video_format / audio_format / caption_format
                                  Format dicts as built by ``_group_formats_by_client``, or a
                                  falsy value when that kind of format is not requested.
        @param resume             Passed to each format writer as ``resume``.
        @param is_test            Stop once ``_TEST_FILE_SIZE`` bytes of segment data were received.
        @param fetch_po_token_fn  Called to obtain a PO token when the stream reports its token
                                  status; when missing, a warning is issued and no token is fetched.
        @param reload_config_fn   Called with the reload playback token when the stream asks for a
                                  player response refresh; when missing, a warning is issued and the
                                  refresh is skipped.
        @raises DownloadError     On ``SabrStreamError``, or when the server seeks a non-live video.
        """
        writers = {}

        def prepare_format(selector_cls, fmt):
            # Build the selector for a requested format and register its writer under the
            # selector's display name. Returns None when the format is not requested.
            if not fmt:
                return None
            selector = selector_cls(
                display_name=fmt['display_name'], format_ids=[fmt['format_id']])
            writers[selector.display_name] = SabrFDFormatWriter(
                self, fmt.get('filename'),
                fmt['info_dict'], len(writers), resume=resume)
            return selector

        audio_selector = prepare_format(AudioSelector, audio_format)
        video_selector = prepare_format(VideoSelector, video_format)
        caption_selector = prepare_format(CaptionSelector, caption_format)

        # Report the destination files before we start downloading instead of when we initialize the writers,
        # as the formats may not all start at the same time (leading to messy output)
        for writer in writers.values():
            self.report_destination(writer.filename)

        stream = SabrStream(
            urlopen=self.ydl.urlopen,
            logger=create_sabrfd_logger(self.ydl, prefix='sabr:stream'),
            server_abr_streaming_url=server_abr_streaming_url,
            video_playback_ustreamer_config=video_playback_ustreamer_config,
            po_token=initial_po_token,
            video_selection=video_selector,
            audio_selection=audio_selector,
            caption_selection=caption_selector,
            start_time_ms=start_time_ms,
            client_info=client_info,
            live_segment_target_duration_sec=target_duration_sec,
            post_live=live_status == 'post_live',
            video_id=video_id,
            retry_sleep_func=self.params.get('retry_sleep_functions', {}).get('http'),
        )

        self._prepare_multiline_status(len(writers) + 1)

        try:
            total_bytes = 0
            for part in stream:
                if is_test and total_bytes >= self._TEST_FILE_SIZE:
                    stream.close()
                    break

                if isinstance(part, PoTokenStatusSabrPart):
                    if not fetch_po_token_fn:
                        self.report_warning(
                            'No fetch PO token function found - this can happen if you use --load-info-json.'
                            ' The download will fail if a valid PO token is required.', only_once=True)
                        # There is nothing to call; keep whatever token the stream already has
                        continue

                    po_token = None
                    if part.status in (
                        part.PoTokenStatus.INVALID,
                        part.PoTokenStatus.PENDING,
                    ):
                        # Fetch a PO token with bypass_cache=True
                        # (ensure we create a new one)
                        po_token = fetch_po_token_fn(bypass_cache=True)
                    elif part.status in (
                        part.PoTokenStatus.MISSING,
                        part.PoTokenStatus.PENDING_MISSING,
                    ):
                        # Fetch a PO Token, bypass_cache=False
                        po_token = fetch_po_token_fn()

                    if po_token:
                        stream.processor.po_token = po_token

                elif isinstance(part, FormatInitializedSabrPart):
                    writer = writers.get(part.format_selector.display_name)
                    if not writer:
                        self.report_warning(f'Unknown format selector: {part.format_selector}')
                        continue

                    writer.initialize_format(part.format_id)
                    initialized_format = stream.processor.initialized_formats[str(part.format_id)]
                    if writer.state.init_sequence:
                        initialized_format.init_segment = True
                        initialized_format.current_segment = None  # allow a seek

                    # Build consumed ranges from the sequences
                    consumed_ranges = [
                        ConsumedRange(
                            start_time_ms=sequence.first_segment.start_time_ms,
                            duration_ms=(sequence.last_segment.start_time_ms + sequence.last_segment.duration_ms) - sequence.first_segment.start_time_ms,
                            start_sequence_number=sequence.first_segment.sequence_number,
                            end_sequence_number=sequence.last_segment.sequence_number,
                        )
                        for sequence in writer.state.sequences
                    ]
                    if consumed_ranges:
                        initialized_format.consumed_ranges = consumed_ranges
                        initialized_format.current_segment = None  # allow a seek
                        self.to_screen(f'[download] Resuming download for format {part.format_selector.display_name}')

                elif isinstance(part, MediaSegmentInitSabrPart):
                    writer = writers.get(part.format_selector.display_name)
                    if not writer:
                        self.report_warning(f'Unknown init format selector: {part.format_selector}')
                        continue
                    writer.initialize_segment(part)

                elif isinstance(part, MediaSegmentDataSabrPart):
                    total_bytes += len(part.data)  # TODO: not reliable
                    writer = writers.get(part.format_selector.display_name)
                    if not writer:
                        self.report_warning(f'Unknown data format selector: {part.format_selector}')
                        continue
                    writer.write_segment_data(part)

                elif isinstance(part, MediaSegmentEndSabrPart):
                    writer = writers.get(part.format_selector.display_name)
                    if not writer:
                        self.report_warning(f'Unknown end format selector: {part.format_selector}')
                        continue
                    writer.end_segment(part)

                elif isinstance(part, RefreshPlayerResponseSabrPart):
                    self.to_screen(f'Refreshing player response; Reason: {part.reason}')
                    # In-place refresh - not ideal but should work in most cases
                    # TODO: handle case where live stream changes to non-livestream on refresh?
                    # TODO: if live, allow a seek as for non-DVR streams the reload may be longer than the buffer duration
                    # TODO: handle po token function change
                    if not reload_config_fn:
                        # report_warning() returns None, so it must not be raised;
                        # warn and keep downloading with the current URL
                        self.report_warning(
                            'No reload config function found - cannot refresh SABR streaming URL.'
                            ' The url will expire soon and the download will fail.')
                        continue
                    try:
                        stream.url, stream.processor.video_playback_ustreamer_config = reload_config_fn(part.reload_playback_token)
                    except (TransportError, HTTPError) as e:
                        self.report_warning(f'Failed to refresh SABR streaming URL: {e}')

                elif isinstance(part, MediaSeekSabrPart):
                    if (
                        not info_dict.get('is_live')
                        and live_status not in ('post_live', 'is_live')
                        and not stream.processor.is_live
                        and part.reason == MediaSeekSabrPart.Reason.SERVER_SEEK
                    ):
                        raise DownloadError('Server tried to seek a video')
                else:
                    self.to_screen(f'Unhandled part type: {part.__class__.__name__}')

            for writer in writers.values():
                writer.finish()
        except SabrStreamError as e:
            raise DownloadError(str(e)) from e
        except KeyboardInterrupt:
            # Only live downloads are finalized on interrupt; anything else propagates
            if (
                not info_dict.get('is_live')
                and not live_status == 'is_live'
                and not stream.processor.is_live
            ):
                raise
            self.to_screen('Interrupted by user')
            for writer in writers.values():
                writer.finish()
        finally:
            # TODO: for livestreams, since we cannot resume them, should we finish the writers?
            for writer in writers.values():
                writer.close()
def format_type(f):
    """Classify the format dict *f* as ``'video'``, ``'audio'`` or ``'caption'``.

    A format whose ``acodec`` is ``'none'`` is video, one whose ``vcodec`` is
    ``'none'`` is audio, and one with neither codec key set is a caption.
    Returns None for anything else.
    """
    audio_codec = f.get('acodec')
    video_codec = f.get('vcodec')

    if audio_codec == 'none':
        return 'video'
    if video_codec == 'none':
        return 'audio'
    if video_codec is None and audio_codec is None:
        return 'caption'
    return None