diff --git a/yt_dlp/webvtt.py b/yt_dlp/webvtt.py index 9f1a5086b..0aa455e0e 100644 --- a/yt_dlp/webvtt.py +++ b/yt_dlp/webvtt.py @@ -21,6 +21,8 @@ class _MatchParser: """ def __init__(self, string): + if not isinstance(string, str): + raise TypeError("Expected string input to _MatchParser") self._data = string self._pos = 0 @@ -31,7 +33,7 @@ def match(self, r): if self._data.startswith(r, self._pos): return len(r) return None - raise ValueError(r) + raise ValueError(f"Expected regex or string, got {type(r).__name__}") def advance(self, by): if by is None: @@ -43,7 +45,7 @@ def advance(self, by): elif isinstance(by, int): amt = by else: - raise ValueError(by) + raise ValueError(f"Unsupported advance type: {type(by).__name__}") self._pos += amt return by @@ -102,6 +104,8 @@ def _parse_ts(ts): Convert a parsed WebVTT timestamp (a re.Match obtained from _REGEX_TS) into an MPEG PES timestamp: a tick counter at 90 kHz resolution. """ + if ts is None or not isinstance(ts, re.Match): + raise ValueError("Invalid timestamp match for _parse_ts") return 90 * sum( int(part or 0) * mult for part, mult in zip(ts.groups(), (3600_000, 60_000, 1000, 1))) @@ -146,47 +150,34 @@ class HeaderBlock(Block): class Magic(HeaderBlock): _REGEX = re.compile(r'\ufeff?WEBVTT([ \t][^\r\n]*)?(?:\r\n|[\r\n])') - # XXX: The X-TIMESTAMP-MAP extension is described in RFC 8216 ยง3.5 - # , but the RFC - # doesn't specify the exact grammar nor where in the WebVTT - # syntax it should be placed; the below has been devised based - # on usage in the wild - # - # And strictly speaking, the presence of this extension violates - # the W3C WebVTT spec. Oh well. - _REGEX_TSMAP = re.compile(r'X-TIMESTAMP-MAP=') _REGEX_TSMAP_LOCAL = re.compile(r'LOCAL:') _REGEX_TSMAP_MPEGTS = re.compile(r'MPEGTS:([0-9]+)') _REGEX_TSMAP_SEP = re.compile(r'[ \t]*,[ \t]*') - - # This was removed from the spec in the 2017 revision; - # the last spec draft to describe this syntax element is - # . - # Nevertheless, YouTube keeps serving those _REGEX_META = re.compile(r'(?:(?!-->)[^\r\n])+:(?:(?!-->)[^\r\n])+(?:\r\n|[\r\n])') @classmethod def __parse_tsmap(cls, parser): parser = parser.child() + local, mpegts = None, None while True: - m = parser.consume(cls._REGEX_TSMAP_LOCAL) - if m: + if parser.consume(cls._REGEX_TSMAP_LOCAL): m = parser.consume(_REGEX_TS) - if m is None: + if not m: raise ParseError(parser) local = _parse_ts(m) - if local is None: + elif parser.consume(cls._REGEX_TSMAP_MPEGTS): + m = parser.match(cls._REGEX_TSMAP_MPEGTS) + if not m: raise ParseError(parser) + mpegts = int_or_none(m.group(1)) + if mpegts is None: + raise ParseError(parser) + parser.advance(m) else: - m = parser.consume(cls._REGEX_TSMAP_MPEGTS) - if m: - mpegts = int_or_none(m.group(1)) - if mpegts is None: - raise ParseError(parser) - else: - raise ParseError(parser) + raise ParseError(parser) + if parser.consume(cls._REGEX_TSMAP_SEP): continue if parser.consume(_REGEX_NL): @@ -220,14 +211,14 @@ def parse(cls, parser): def write_into(self, stream): stream.write('WEBVTT') - if self.extra is not None: + if self.extra: stream.write(self.extra) stream.write('\n') - if self.local or self.mpegts: + if self.local is not None or self.mpegts is not None: stream.write('X-TIMESTAMP-MAP=LOCAL:') - stream.write(_format_ts(self.local if self.local is not None else 0)) + stream.write(_format_ts(self.local or 0)) stream.write(',MPEGTS:') - stream.write(str(self.mpegts if self.mpegts is not None else 0)) + stream.write(str(self.mpegts or 0)) stream.write('\n') if self.meta: stream.write(self.meta) @@ -278,9 +269,7 @@ def parse(cls, parser): id_ = m.group(1) m0 = parser.consume(_REGEX_TS) - if not m0: - return None - if not parser.consume(cls._REGEX_ARROW): + if not m0 or not parser.consume(cls._REGEX_ARROW): return None m1 = parser.consume(_REGEX_TS) if not m1: @@ -292,7 +281,7 @@ def parse(cls, parser): start = _parse_ts(m0) end = _parse_ts(m1) - settings = m2.group(1) if m2 is not None else None + settings = m2.group(1) if m2 else None text = io.StringIO() while True: @@ -302,55 +291,7 @@ def parse(cls, parser): text.write(m.group(0)) parser.commit() - return cls( - id=id_, - start=start, end=end, settings=settings, - text=text.getvalue(), - ) - - def write_into(self, stream): - if self.id is not None: - stream.write(self.id) - stream.write('\n') - stream.write(_format_ts(self.start)) - stream.write(' --> ') - stream.write(_format_ts(self.end)) - if self.settings is not None: - stream.write(' ') - stream.write(self.settings) - stream.write('\n') - stream.write(self.text) - stream.write('\n') - - @property - def as_json(self): - return { - 'id': self.id, - 'start': self.start, - 'end': self.end, - 'text': self.text, - 'settings': self.settings, - } - - def __eq__(self, other): - return self.as_json == other.as_json - - @classmethod - def from_json(cls, json): - return cls( - id=json['id'], - start=json['start'], - end=json['end'], - text=json['text'], - settings=json['settings'], - ) - - def hinges(self, other): - if self.text != other.text: - return False - if self.settings != other.settings: - return False - return self.start <= self.end == other.start <= other.end + return cls(id=id_, start=start, end=end, settings=settings, text=text.getvalue()) def parse_fragment(frag_content): @@ -358,15 +299,16 @@ def parse_fragment(frag_content): A generator that yields (partially) parsed WebVTT blocks when given a bytes object containing the raw contents of a WebVTT file. """ + if not isinstance(frag_content, (bytes, bytearray)): + raise TypeError("Expected bytes for frag_content") - parser = _MatchParser(frag_content.decode()) + parser = _MatchParser(frag_content.decode(errors="replace")) yield Magic.parse(parser) while not parser.match(_REGEX_EOF): if parser.consume(_REGEX_BLANK): continue - block = RegionBlock.parse(parser) if block: yield block @@ -377,22 +319,19 @@ def parse_fragment(frag_content): continue block = CommentBlock.parse(parser) if block: - yield block # XXX: or skip + yield block continue - break while not parser.match(_REGEX_EOF): if parser.consume(_REGEX_BLANK): continue - block = CommentBlock.parse(parser) if block: - yield block # XXX: or skip + yield block continue block = CueBlock.parse(parser) if block: yield block continue - raise ParseError(parser)