mirror of
https://github.com/yt-dlp/yt-dlp.git
synced 2025-07-07 22:08:38 +00:00
336 lines
16 KiB
Python
336 lines
16 KiB
Python
from __future__ import annotations
|
|
import collections
|
|
import itertools
|
|
|
|
from yt_dlp.networking.exceptions import TransportError, HTTPError
|
|
from yt_dlp.utils import traverse_obj, int_or_none, DownloadError, join_nonempty
|
|
from yt_dlp.downloader import FileDownloader
|
|
|
|
from ._writer import SabrFDFormatWriter
|
|
from ._logger import create_sabrfd_logger
|
|
|
|
from yt_dlp.extractor.youtube._streaming.sabr.part import (
|
|
MediaSegmentEndSabrPart,
|
|
MediaSegmentDataSabrPart,
|
|
MediaSegmentInitSabrPart,
|
|
PoTokenStatusSabrPart,
|
|
RefreshPlayerResponseSabrPart,
|
|
MediaSeekSabrPart,
|
|
FormatInitializedSabrPart,
|
|
)
|
|
from yt_dlp.extractor.youtube._streaming.sabr.stream import SabrStream
|
|
from yt_dlp.extractor.youtube._streaming.sabr.models import ConsumedRange, AudioSelector, VideoSelector, CaptionSelector
|
|
from yt_dlp.extractor.youtube._streaming.sabr.exceptions import SabrStreamError
|
|
from yt_dlp.extractor.youtube._proto.innertube import ClientInfo, ClientName
|
|
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId
|
|
|
|
|
|
class SabrFD(FileDownloader):
|
|
|
|
@classmethod
|
|
def can_download(cls, info_dict):
|
|
return (
|
|
info_dict.get('requested_formats')
|
|
and all(
|
|
format_info.get('protocol') == 'sabr'
|
|
for format_info in info_dict['requested_formats']))
|
|
|
|
def _group_formats_by_client(self, filename, info_dict):
|
|
format_groups = collections.defaultdict(dict, {})
|
|
requested_formats = info_dict.get('requested_formats') or [info_dict]
|
|
|
|
for _idx, f in enumerate(requested_formats):
|
|
sabr_config = f.get('_sabr_config')
|
|
client_name = sabr_config.get('client_name')
|
|
client_info = sabr_config.get('client_info')
|
|
server_abr_streaming_url = f.get('url')
|
|
video_playback_ustreamer_config = sabr_config.get('video_playback_ustreamer_config')
|
|
|
|
if not video_playback_ustreamer_config:
|
|
raise DownloadError('Video playback ustreamer config not found')
|
|
|
|
sabr_format_group_config = format_groups.get(client_name)
|
|
|
|
if not sabr_format_group_config:
|
|
sabr_format_group_config = format_groups[client_name] = {
|
|
'server_abr_streaming_url': server_abr_streaming_url,
|
|
'video_playback_ustreamer_config': video_playback_ustreamer_config,
|
|
'formats': [],
|
|
'initial_po_token': sabr_config.get('po_token'),
|
|
'fetch_po_token_fn': fn if callable(fn := sabr_config.get('fetch_po_token_fn')) else None,
|
|
'reload_config_fn': fn if callable(fn := sabr_config.get('reload_config_fn')) else None,
|
|
'live_status': sabr_config.get('live_status'),
|
|
'video_id': sabr_config.get('video_id'),
|
|
'client_info': ClientInfo(
|
|
client_name=traverse_obj(client_info, ('clientName', {lambda x: ClientName[x]})),
|
|
client_version=traverse_obj(client_info, 'clientVersion'),
|
|
os_version=traverse_obj(client_info, 'osVersion'),
|
|
os_name=traverse_obj(client_info, 'osName'),
|
|
device_model=traverse_obj(client_info, 'deviceModel'),
|
|
device_make=traverse_obj(client_info, 'deviceMake'),
|
|
),
|
|
'target_duration_sec': sabr_config.get('target_duration_sec'),
|
|
# Number.MAX_SAFE_INTEGER
|
|
'start_time_ms': ((2**53) - 1) if info_dict.get('live_status') == 'is_live' and not f.get('is_from_start') else 0,
|
|
}
|
|
|
|
else:
|
|
if sabr_format_group_config['server_abr_streaming_url'] != server_abr_streaming_url:
|
|
raise DownloadError('Server ABR streaming URL mismatch')
|
|
|
|
if sabr_format_group_config['video_playback_ustreamer_config'] != video_playback_ustreamer_config:
|
|
raise DownloadError('Video playback ustreamer config mismatch')
|
|
|
|
itag = int_or_none(sabr_config.get('itag'))
|
|
sabr_format_group_config['formats'].append({
|
|
'display_name': f.get('format_id'),
|
|
'format_id': itag and FormatId(
|
|
itag=itag, lmt=int_or_none(sabr_config.get('last_modified')), xtags=sabr_config.get('xtags')),
|
|
'format_type': format_type(f),
|
|
'quality': sabr_config.get('quality'),
|
|
'height': sabr_config.get('height'),
|
|
'filename': f.get('filepath', filename),
|
|
'info_dict': f,
|
|
})
|
|
|
|
return format_groups
|
|
|
|
def real_download(self, filename, info_dict):
|
|
format_groups = self._group_formats_by_client(filename, info_dict)
|
|
|
|
is_test = self.params.get('test', False)
|
|
resume = self.params.get('continuedl', True)
|
|
|
|
for client_name, format_group in format_groups.items():
|
|
formats = format_group['formats']
|
|
audio_formats = (f for f in formats if f['format_type'] == 'audio')
|
|
video_formats = (f for f in formats if f['format_type'] == 'video')
|
|
caption_formats = (f for f in formats if f['format_type'] == 'caption')
|
|
for audio_format, video_format, caption_format in itertools.zip_longest(audio_formats, video_formats, caption_formats):
|
|
format_str = join_nonempty(*[
|
|
traverse_obj(audio_format, 'display_name'),
|
|
traverse_obj(video_format, 'display_name'),
|
|
traverse_obj(caption_format, 'display_name')], delim='+')
|
|
self.write_debug(f'Downloading formats: {format_str} ({client_name} client)')
|
|
self._download_sabr_stream(
|
|
info_dict=info_dict,
|
|
video_format=video_format,
|
|
audio_format=audio_format,
|
|
caption_format=caption_format,
|
|
resume=resume,
|
|
is_test=is_test,
|
|
server_abr_streaming_url=format_group['server_abr_streaming_url'],
|
|
video_playback_ustreamer_config=format_group['video_playback_ustreamer_config'],
|
|
initial_po_token=format_group['initial_po_token'],
|
|
fetch_po_token_fn=format_group['fetch_po_token_fn'],
|
|
reload_config_fn=format_group['reload_config_fn'],
|
|
client_info=format_group['client_info'],
|
|
start_time_ms=format_group['start_time_ms'],
|
|
target_duration_sec=format_group.get('target_duration_sec', None),
|
|
live_status=format_group.get('live_status'),
|
|
video_id=format_group.get('video_id'),
|
|
)
|
|
return True
|
|
|
|
def _download_sabr_stream(
|
|
self,
|
|
video_id: str,
|
|
info_dict: dict,
|
|
video_format: dict,
|
|
audio_format: dict,
|
|
caption_format: dict,
|
|
resume: bool,
|
|
is_test: bool,
|
|
server_abr_streaming_url: str,
|
|
video_playback_ustreamer_config: str,
|
|
initial_po_token: str,
|
|
fetch_po_token_fn: callable | None = None,
|
|
reload_config_fn: callable | None = None,
|
|
client_info: ClientInfo | None = None,
|
|
start_time_ms: int = 0,
|
|
target_duration_sec: int | None = None,
|
|
live_status: str | None = None,
|
|
):
|
|
|
|
writers = {}
|
|
audio_selector = None
|
|
video_selector = None
|
|
caption_selector = None
|
|
|
|
if audio_format:
|
|
audio_selector = AudioSelector(
|
|
display_name=audio_format['display_name'], format_ids=[audio_format['format_id']])
|
|
writers[audio_selector.display_name] = SabrFDFormatWriter(
|
|
self, audio_format.get('filename'),
|
|
audio_format['info_dict'], len(writers), resume=resume)
|
|
|
|
if video_format:
|
|
video_selector = VideoSelector(
|
|
display_name=video_format['display_name'], format_ids=[video_format['format_id']])
|
|
writers[video_selector.display_name] = SabrFDFormatWriter(
|
|
self, video_format.get('filename'),
|
|
video_format['info_dict'], len(writers), resume=resume)
|
|
|
|
if caption_format:
|
|
caption_selector = CaptionSelector(
|
|
display_name=caption_format['display_name'], format_ids=[caption_format['format_id']])
|
|
writers[caption_selector.display_name] = SabrFDFormatWriter(
|
|
self, caption_format.get('filename'),
|
|
caption_format['info_dict'], len(writers), resume=resume)
|
|
|
|
# Report the destination files before we start downloading instead of when we initialize the writers,
|
|
# as the formats may not all start at the same time (leading to messy output)
|
|
for writer in writers.values():
|
|
self.report_destination(writer.filename)
|
|
|
|
stream = SabrStream(
|
|
urlopen=self.ydl.urlopen,
|
|
logger=create_sabrfd_logger(self.ydl, prefix='sabr:stream'),
|
|
server_abr_streaming_url=server_abr_streaming_url,
|
|
video_playback_ustreamer_config=video_playback_ustreamer_config,
|
|
po_token=initial_po_token,
|
|
video_selection=video_selector,
|
|
audio_selection=audio_selector,
|
|
caption_selection=caption_selector,
|
|
start_time_ms=start_time_ms,
|
|
client_info=client_info,
|
|
live_segment_target_duration_sec=target_duration_sec,
|
|
post_live=live_status == 'post_live',
|
|
video_id=video_id,
|
|
retry_sleep_func=self.params.get('retry_sleep_functions', {}).get('http'),
|
|
)
|
|
|
|
self._prepare_multiline_status(len(writers) + 1)
|
|
|
|
try:
|
|
total_bytes = 0
|
|
for part in stream:
|
|
if is_test and total_bytes >= self._TEST_FILE_SIZE:
|
|
stream.close()
|
|
break
|
|
if isinstance(part, PoTokenStatusSabrPart):
|
|
if not fetch_po_token_fn:
|
|
self.report_warning(
|
|
'No fetch PO token function found - this can happen if you use --load-info-json.'
|
|
' The download will fail if a valid PO token is required.', only_once=True)
|
|
if part.status in (
|
|
part.PoTokenStatus.INVALID,
|
|
part.PoTokenStatus.PENDING,
|
|
):
|
|
# Fetch a PO token with bypass_cache=True
|
|
# (ensure we create a new one)
|
|
po_token = fetch_po_token_fn(bypass_cache=True)
|
|
if po_token:
|
|
stream.processor.po_token = po_token
|
|
elif part.status in (
|
|
part.PoTokenStatus.MISSING,
|
|
part.PoTokenStatus.PENDING_MISSING,
|
|
):
|
|
# Fetch a PO Token, bypass_cache=False
|
|
po_token = fetch_po_token_fn()
|
|
if po_token:
|
|
stream.processor.po_token = po_token
|
|
|
|
elif isinstance(part, FormatInitializedSabrPart):
|
|
writer = writers.get(part.format_selector.display_name)
|
|
if not writer:
|
|
self.report_warning(f'Unknown format selector: {part.format_selector}')
|
|
continue
|
|
|
|
writer.initialize_format(part.format_id)
|
|
initialized_format = stream.processor.initialized_formats[str(part.format_id)]
|
|
if writer.state.init_sequence:
|
|
initialized_format.init_segment = True
|
|
initialized_format.current_segment = None # allow a seek
|
|
|
|
# Build consumed ranges from the sequences
|
|
consumed_ranges = []
|
|
for sequence in writer.state.sequences:
|
|
consumed_ranges.append(ConsumedRange(
|
|
start_time_ms=sequence.first_segment.start_time_ms,
|
|
duration_ms=(sequence.last_segment.start_time_ms + sequence.last_segment.duration_ms) - sequence.first_segment.start_time_ms,
|
|
start_sequence_number=sequence.first_segment.sequence_number,
|
|
end_sequence_number=sequence.last_segment.sequence_number,
|
|
))
|
|
if consumed_ranges:
|
|
initialized_format.consumed_ranges = consumed_ranges
|
|
initialized_format.current_segment = None # allow a seek
|
|
self.to_screen(f'[download] Resuming download for format {part.format_selector.display_name}')
|
|
|
|
elif isinstance(part, MediaSegmentInitSabrPart):
|
|
writer = writers.get(part.format_selector.display_name)
|
|
if not writer:
|
|
self.report_warning(f'Unknown init format selector: {part.format_selector}')
|
|
continue
|
|
writer.initialize_segment(part)
|
|
|
|
elif isinstance(part, MediaSegmentDataSabrPart):
|
|
total_bytes += len(part.data) # TODO: not reliable
|
|
writer = writers.get(part.format_selector.display_name)
|
|
if not writer:
|
|
self.report_warning(f'Unknown data format selector: {part.format_selector}')
|
|
continue
|
|
writer.write_segment_data(part)
|
|
|
|
elif isinstance(part, MediaSegmentEndSabrPart):
|
|
writer = writers.get(part.format_selector.display_name)
|
|
if not writer:
|
|
self.report_warning(f'Unknown end format selector: {part.format_selector}')
|
|
continue
|
|
writer.end_segment(part)
|
|
|
|
elif isinstance(part, RefreshPlayerResponseSabrPart):
|
|
self.to_screen(f'Refreshing player response; Reason: {part.reason}')
|
|
# In-place refresh - not ideal but should work in most cases
|
|
# TODO: handle case where live stream changes to non-livestream on refresh?
|
|
# TODO: if live, allow a seek as for non-DVR streams the reload may be longer than the buffer duration
|
|
# TODO: handle po token function change
|
|
if not reload_config_fn:
|
|
raise self.report_warning(
|
|
'No reload config function found - cannot refresh SABR streaming URL.'
|
|
' The url will expire soon and the download will fail.')
|
|
try:
|
|
stream.url, stream.processor.video_playback_ustreamer_config = reload_config_fn(part.reload_playback_token)
|
|
except (TransportError, HTTPError) as e:
|
|
self.report_warning(f'Failed to refresh SABR streaming URL: {e}')
|
|
|
|
elif isinstance(part, MediaSeekSabrPart):
|
|
if (
|
|
not info_dict.get('is_live')
|
|
and live_status not in ('post_live', 'is_live')
|
|
and not stream.processor.is_live
|
|
and part.reason == MediaSeekSabrPart.Reason.SERVER_SEEK
|
|
):
|
|
raise DownloadError('Server tried to seek a video')
|
|
else:
|
|
self.to_screen(f'Unhandled part type: {part.__class__.__name__}')
|
|
|
|
for writer in writers.values():
|
|
writer.finish()
|
|
except SabrStreamError as e:
|
|
raise DownloadError(str(e)) from e
|
|
except KeyboardInterrupt:
|
|
if (
|
|
not info_dict.get('is_live')
|
|
and not live_status == 'is_live'
|
|
and not stream.processor.is_live
|
|
):
|
|
raise
|
|
self.to_screen('Interrupted by user')
|
|
for writer in writers.values():
|
|
writer.finish()
|
|
finally:
|
|
# TODO: for livestreams, since we cannot resume them, should we finish the writers?
|
|
for writer in writers.values():
|
|
writer.close()
|
|
|
|
|
|
def format_type(f):
|
|
if f.get('acodec') == 'none':
|
|
return 'video'
|
|
elif f.get('vcodec') == 'none':
|
|
return 'audio'
|
|
elif f.get('vcodec') is None and f.get('acodec') is None:
|
|
return 'caption'
|
|
return None
|