1
0
mirror of https://github.com/yt-dlp/yt-dlp.git synced 2026-02-11 00:14:34 +00:00
Files
yt-dlp/yt_dlp/downloader/sabr/_writer.py
2025-06-21 11:15:25 +12:00

356 lines
14 KiB
Python

from __future__ import annotations
import dataclasses
from ._io import DiskFormatIOBackend
from ._file import SequenceFile, Sequence, Segment
from ._state import (
SabrStateSegment,
SabrStateSequence,
SabrStateInitSegment,
SabrState,
SabrStateFile,
)
from yt_dlp.extractor.youtube._streaming.sabr.part import (
MediaSegmentInitSabrPart,
MediaSegmentDataSabrPart,
MediaSegmentEndSabrPart,
)
from yt_dlp.utils import DownloadError
from yt_dlp.extractor.youtube._proto.videostreaming import FormatId
from yt_dlp.utils.progress import ProgressCalculator
# ID used both for the init segment and for the sequence that stores it
INIT_SEGMENT_ID = 'i'
@dataclasses.dataclass
class SabrFormatState:
    """Snapshot of the sequences tracked for a single format."""

    # Identifier of the format this snapshot belongs to
    format_id: FormatId
    # Sequence holding the init segment, or None when there is none
    init_sequence: Sequence | None = None
    # Media sequences; every instance gets its own fresh list
    sequences: list[Sequence] = dataclasses.field(default_factory=list)
class SabrFDFormatWriter:
    """Store the segments of one SABR format on disk and merge them at the end.

    Incoming segments are written into per-sequence files (``SequenceFile``);
    the init segment lives in its own sequence. The state file is rewritten
    every time a segment ends, and ``initialize_format()`` reads it back when
    ``resume`` is set. ``finish()`` concatenates the init sequence and all
    media sequences into the format's temporary file and renames it.
    """

    def __init__(self, fd, filename, infodict, progress_idx=0, resume=False):
        """
        Args:
            fd: Downloader object. This class calls its ``temp_name``,
                ``undo_temp_name``, ``try_rename``, ``report_warning`` and
                ``_hook_progress`` methods.
            filename: Final filename of the format.
            infodict: Info dict of the format (``format_id`` and ``filesize``
                are read from it).
            progress_idx: Value reported as ``progress_idx`` in progress hooks.
            resume: Whether ``initialize_format()`` should try to reuse the
                sequences recorded in an existing state file.
        """
        self.fd = fd
        self.info_dict = infodict
        self.filename = filename
        self.progress_idx = progress_idx
        self.resume = resume

        self._progress = None
        self._downloaded_bytes = 0
        self._state = {}
        self._format_id = None

        self.file = DiskFormatIOBackend(
            fd=self.fd,
            filename=self.fd.temp_name(filename),
        )
        self._sabr_state_file = SabrStateFile(format_filename=self.filename, fd=fd)
        self._sequence_files: list[SequenceFile] = []
        self._init_sequence: SequenceFile | None = None

    @property
    def state(self):
        """A ``SabrFormatState`` describing the sequences currently tracked."""
        return SabrFormatState(
            format_id=self._format_id,
            init_sequence=self._init_sequence.sequence if self._init_sequence else None,
            sequences=[sf.sequence for sf in self._sequence_files],
        )

    @property
    def downloaded_bytes(self):
        """Sum of ``current_length`` over the init sequence and all media sequences."""
        init_length = self._init_sequence.current_length if self._init_sequence else 0
        return sum(sf.current_length for sf in self._sequence_files) + init_length

    @property
    def initialized(self):
        """Whether ``initialize_format()`` has set a format id."""
        return self._format_id is not None

    @staticmethod
    def _segment_from_state(state_segment):
        """Build a media ``Segment`` from a segment entry of the state file."""
        # NOTE(review): ``duration_estimated`` is written by _write_sabr_state()
        # but is not restored here - confirm whether that is intentional.
        return Segment(
            segment_id=str(state_segment.sequence_number),
            sequence_number=state_segment.sequence_number,
            content_length=state_segment.content_length,
            start_time_ms=state_segment.start_time_ms,
            duration_ms=state_segment.duration_ms,
            is_init_segment=False,
        )

    @staticmethod
    def _segment_to_state(segment):
        """Build the state file entry (``SabrStateSegment``) for a media ``Segment``."""
        return SabrStateSegment(
            sequence_number=segment.sequence_number,
            start_time_ms=segment.start_time_ms,
            duration_ms=segment.duration_ms,
            duration_estimated=segment.duration_estimated,
            content_length=segment.content_length,
        )

    def initialize_format(self, format_id):
        """Bind the writer to *format_id* and, when resuming, reopen saved sequences.

        Without ``resume`` any existing state file is removed and nothing else
        happens. With ``resume`` every sequence recorded in the state file is
        reopened; a sequence whose ``SequenceFile`` raises ``DownloadError``
        is skipped with a warning.

        Raises:
            ValueError: if a format has already been initialized.
        """
        if self.initialized:
            raise ValueError('Already initialized')
        self._format_id = format_id

        if not self.resume:
            if self._sabr_state_file.exists:
                self._sabr_state_file.remove()
            return

        document = self._load_sabr_state()

        if document.init_segment:
            init_segment = Segment(
                segment_id=INIT_SEGMENT_ID,
                content_length=document.init_segment.content_length,
                is_init_segment=True,
            )
            try:
                self._init_sequence = SequenceFile(
                    fd=self.fd,
                    format_filename=self.filename,
                    resume=True,
                    sequence=Sequence(
                        sequence_id=INIT_SEGMENT_ID,
                        sequence_content_length=init_segment.content_length,
                        first_segment=init_segment,
                        last_segment=init_segment,
                    ))
            except DownloadError as e:
                self.fd.report_warning(f'Failed to resume init segment for format {self.info_dict.get("format_id")}: {e}')

        for sabr_sequence in document.sequences:
            try:
                self._sequence_files.append(SequenceFile(
                    fd=self.fd,
                    format_filename=self.filename,
                    resume=True,
                    sequence=Sequence(
                        sequence_id=str(sabr_sequence.sequence_start_number),
                        sequence_content_length=sabr_sequence.sequence_content_length,
                        first_segment=self._segment_from_state(sabr_sequence.first_segment),
                        last_segment=self._segment_from_state(sabr_sequence.last_segment),
                    ),
                ))
            except DownloadError as e:
                self.fd.report_warning(
                    f'Failed to resume sequence {sabr_sequence.sequence_start_number} '
                    f'for format {self.info_dict.get("format_id")}: {e}')

    def close(self):
        """Close every sequence file and the output file, and forget the sequences."""
        # NOTE(review): self.file is never cleared, so this guard only fires
        # if the backend object itself evaluates as falsy.
        if not self.file:
            raise ValueError('Already closed')
        for sequence in self._sequence_files:
            sequence.close()
        self._sequence_files.clear()
        if self._init_sequence:
            self._init_sequence.close()
            self._init_sequence = None
        self.file.close()

    def _find_sequence_file(self, predicate):
        """Return the only media sequence file matching *predicate*, or None.

        Raises:
            DownloadError: if more than one sequence file matches.
        """
        match = None
        for sequence in self._sequence_files:
            if not predicate(sequence):
                continue
            if match is not None:
                raise DownloadError('Multiple sequence files found for segment')
            match = sequence
        return match

    def find_next_sequence_file(self, next_segment: Segment):
        """Return the sequence file that reports *next_segment* as its next segment."""
        return self._find_sequence_file(lambda sequence: sequence.is_next_segment(next_segment))

    def find_current_sequence_file(self, segment_id: str):
        """Return the sequence file that reports *segment_id* as its current segment."""
        return self._find_sequence_file(lambda sequence: sequence.is_current_segment(segment_id))

    def _require_sequence_file(self, part):
        """Return ``(sequence_file, segment_id)`` for the segment *part* belongs to.

        Raises:
            DownloadError: if no sequence file is available for that segment.
        """
        if part.is_init_segment:
            sequence_file, segment_id = self._init_sequence, INIT_SEGMENT_ID
        else:
            segment_id = str(part.sequence_number)
            sequence_file = self.find_current_sequence_file(segment_id)
        if not sequence_file:
            raise DownloadError('Unable to find sequence file for segment. Was the segment initialized?')
        return sequence_file, segment_id

    def initialize_segment(self, part: MediaSegmentInitSabrPart):
        """Start a new segment in the sequence file it belongs to.

        An init segment goes to the init sequence (created on first use). A
        media segment goes to the sequence file that already holds it or that
        expects it next; if there is none a new sequence file is started.

        Returns:
            True (always).

        Raises:
            ValueError: if ``initialize_format()`` has not been called.
        """
        # Validate first so a rejected call leaves no progress calculator behind
        if not self.initialized:
            raise ValueError('not initialized')
        if not self._progress:
            self._progress = ProgressCalculator(part.start_bytes)

        if part.is_init_segment:
            if not self._init_sequence:
                self._init_sequence = SequenceFile(
                    fd=self.fd,
                    format_filename=self.filename,
                    resume=False,
                    sequence=Sequence(
                        sequence_id=INIT_SEGMENT_ID,
                    ))
            self._init_sequence.initialize_segment(Segment(
                segment_id=INIT_SEGMENT_ID,
                content_length=part.content_length,
                content_length_estimated=part.content_length_estimated,
                is_init_segment=True,
            ))
            return True

        segment = Segment(
            segment_id=str(part.sequence_number),
            sequence_number=part.sequence_number,
            start_time_ms=part.start_time_ms,
            duration_ms=part.duration_ms,
            duration_estimated=part.duration_estimated,
            content_length=part.content_length,
            content_length_estimated=part.content_length_estimated,
        )
        sequence_file = self.find_current_sequence_file(segment.segment_id) or self.find_next_sequence_file(segment)
        if not sequence_file:
            sequence_file = SequenceFile(
                fd=self.fd,
                format_filename=self.filename,
                resume=False,
                sequence=Sequence(sequence_id=str(part.sequence_number)),
            )
            self._sequence_files.append(sequence_file)
        sequence_file.initialize_segment(segment)
        return True

    def write_segment_data(self, part: MediaSegmentDataSabrPart):
        """Append the data of *part* to its segment and report download progress.

        Raises:
            DownloadError: if no sequence file is available for the segment.
        """
        sequence_file, segment_id = self._require_sequence_file(part)
        sequence_file.write_segment_data(part.data, segment_id)

        # TODO: Handling of disjointed segments (e.g. when downloading segments out of order / concurrently)
        downloaded_bytes = self.downloaded_bytes
        total_bytes = self.info_dict.get('filesize')
        self._progress.total = total_bytes
        # Update the calculator before reading it; reading first would report
        # the eta/speed/elapsed values left over from the previous call.
        self._progress.update(downloaded_bytes)
        self._state = {
            'status': 'downloading',
            'downloaded_bytes': downloaded_bytes,
            'total_bytes': total_bytes,
            'filename': self.filename,
            'eta': self._progress.eta.smooth,
            'speed': self._progress.speed.smooth,
            'elapsed': self._progress.elapsed,
            'progress_idx': self.progress_idx,
            'fragment_count': part.total_segments,
            'fragment_index': part.sequence_number,
        }
        self.fd._hook_progress(self._state, self.info_dict)

    def end_segment(self, part: MediaSegmentEndSabrPart):
        """Finish the segment *part* refers to and rewrite the state file.

        Raises:
            DownloadError: if no sequence file is available for the segment.
        """
        sequence_file, segment_id = self._require_sequence_file(part)
        sequence_file.end_segment(segment_id)
        self._write_sabr_state()

    def _load_sabr_state(self):
        """Return the saved ``SabrState``, or a fresh one if it cannot be used.

        A state file that fails to load, or that was written for a different
        format id, is reported with a warning and replaced by an empty state.
        """
        sabr_state = None
        if self._sabr_state_file.exists:
            try:
                sabr_state = self._sabr_state_file.retrieve()
            except Exception:
                # Best effort: any failure to read the state means starting over
                self.fd.report_warning(
                    f'Corrupted state file for format {self.info_dict.get("format_id")}, restarting download')

        if sabr_state and sabr_state.format_id != self._format_id:
            self.fd.report_warning(
                f'Format ID mismatch in state file for {self.info_dict.get("format_id")}, restarting download')
            sabr_state = None

        if not sabr_state:
            sabr_state = SabrState(format_id=self._format_id)
        return sabr_state

    def _write_sabr_state(self):
        """Write the init segment and all complete sequences to the state file."""
        sabr_state = SabrState(format_id=self._format_id)

        if not self._init_sequence:
            sabr_state.init_segment = None
        else:
            sabr_state.init_segment = SabrStateInitSegment(
                content_length=self._init_sequence.sequence.sequence_content_length,
            )

        sabr_state.sequences = [
            SabrStateSequence(
                sequence_start_number=sequence.first_segment.sequence_number,
                sequence_content_length=sequence.sequence_content_length,
                first_segment=self._segment_to_state(sequence.first_segment),
                last_segment=self._segment_to_state(sequence.last_segment),
            )
            for sequence in (sf.sequence for sf in self._sequence_files)
            # Ignore partial sequences
            if sequence.first_segment and sequence.last_segment
        ]

        self._sabr_state_file.update(sabr_state)

    def finish(self):
        """Merge all sequences into the output file and clean up.

        Reports a final ``finished`` progress state, writes the init sequence
        followed by the media sequences (ordered by their first segment's
        sequence number) into the temporary output file, renames it to its
        final name, then removes the state file and the sequence files.
        A warning is reported when consecutive sequences are not contiguous.
        """
        self._state['status'] = 'finished'
        self.fd._hook_progress(self._state, self.info_dict)

        # Close the sequence files before they are read back for merging
        for sequence_file in self._sequence_files:
            sequence_file.close()
        if self._init_sequence:
            self._init_sequence.close()

        # Now merge all the sequences together
        self.file.initialize_writer(resume=False)

        # Note: May not always be an init segment, e.g for live streams
        if self._init_sequence:
            self._init_sequence.read_into(self.file)
            self._init_sequence.close()

        # TODO: handling of disjointed segments
        previous_seq_number = None
        for sequence_file in sorted(
                (sf for sf in self._sequence_files if sf.sequence.first_segment),
                key=lambda s: s.sequence.first_segment.sequence_number):
            # Compare against None explicitly: 0 is a valid sequence number
            if (previous_seq_number is not None
                    and previous_seq_number + 1 != sequence_file.sequence.first_segment.sequence_number):
                self.fd.report_warning(f'Disjointed sequences found in SABR format {self.info_dict.get("format_id")}')
            previous_seq_number = sequence_file.sequence.last_segment.sequence_number
            sequence_file.read_into(self.file)
            sequence_file.close()

        # Format temp file should have all the segments, rename it to the final name
        self.file.close()
        self.fd.try_rename(self.file.filename, self.fd.undo_temp_name(self.file.filename))

        # Remove the state file
        self._sabr_state_file.remove()

        # Remove sequence files
        for sf in self._sequence_files:
            sf.close()
            sf.remove()

        if self._init_sequence:
            self._init_sequence.close()
            self._init_sequence.remove()
        self.close()