From 45d132a6be467eeda7264c01a895b551a4c57e89 Mon Sep 17 00:00:00 2001 From: GiorgosTsak <147599978+GiorgosTsak@users.noreply.github.com> Date: Mon, 23 Jun 2025 20:16:28 +0300 Subject: [PATCH 1/2] Improved WebVTT parser with robust error handling and input validation Enhances the WebVTT partial parser by adding comprehensive error handling, type validation, and defensive checks to prevent unexpected failures during parsing. Specifically, input types are validated in _MatchParser and parse_fragment, ensuring only valid strings or bytes are accepted. Timestamp parsing now raises clear errors for invalid matches, while regex operations are guarded to avoid NoneType attribute errors. The .decode() step in parse_fragment uses safe fallback to handle invalid byte sequences gracefully. --- yt_dlp/webvtt.py | 121 ++++++++++++----------------------------------- 1 file changed, 30 insertions(+), 91 deletions(-) diff --git a/yt_dlp/webvtt.py b/yt_dlp/webvtt.py index 9f1a5086b..0aa455e0e 100644 --- a/yt_dlp/webvtt.py +++ b/yt_dlp/webvtt.py @@ -21,6 +21,8 @@ class _MatchParser: """ def __init__(self, string): + if not isinstance(string, str): + raise TypeError("Expected string input to _MatchParser") self._data = string self._pos = 0 @@ -31,7 +33,7 @@ def match(self, r): if self._data.startswith(r, self._pos): return len(r) return None - raise ValueError(r) + raise ValueError(f"Expected regex or string, got {type(r).__name__}") def advance(self, by): if by is None: @@ -43,7 +45,7 @@ def advance(self, by): elif isinstance(by, int): amt = by else: - raise ValueError(by) + raise ValueError(f"Unsupported advance type: {type(by).__name__}") self._pos += amt return by @@ -102,6 +104,8 @@ def _parse_ts(ts): Convert a parsed WebVTT timestamp (a re.Match obtained from _REGEX_TS) into an MPEG PES timestamp: a tick counter at 90 kHz resolution. """ + if ts is None or not isinstance(ts, re.Match): + raise ValueError("Invalid timestamp match for _parse_ts") return 90 * sum( int(part or 0) * mult for part, mult in zip(ts.groups(), (3600_000, 60_000, 1000, 1))) @@ -146,47 +150,34 @@ class HeaderBlock(Block): class Magic(HeaderBlock): _REGEX = re.compile(r'\ufeff?WEBVTT([ \t][^\r\n]*)?(?:\r\n|[\r\n])') - # XXX: The X-TIMESTAMP-MAP extension is described in RFC 8216 ยง3.5 - # , but the RFC - # doesn't specify the exact grammar nor where in the WebVTT - # syntax it should be placed; the below has been devised based - # on usage in the wild - # - # And strictly speaking, the presence of this extension violates - # the W3C WebVTT spec. Oh well. - _REGEX_TSMAP = re.compile(r'X-TIMESTAMP-MAP=') _REGEX_TSMAP_LOCAL = re.compile(r'LOCAL:') _REGEX_TSMAP_MPEGTS = re.compile(r'MPEGTS:([0-9]+)') _REGEX_TSMAP_SEP = re.compile(r'[ \t]*,[ \t]*') - - # This was removed from the spec in the 2017 revision; - # the last spec draft to describe this syntax element is - # . - # Nevertheless, YouTube keeps serving those _REGEX_META = re.compile(r'(?:(?!-->)[^\r\n])+:(?:(?!-->)[^\r\n])+(?:\r\n|[\r\n])') @classmethod def __parse_tsmap(cls, parser): parser = parser.child() + local, mpegts = None, None while True: - m = parser.consume(cls._REGEX_TSMAP_LOCAL) - if m: + if parser.consume(cls._REGEX_TSMAP_LOCAL): m = parser.consume(_REGEX_TS) - if m is None: + if not m: raise ParseError(parser) local = _parse_ts(m) - if local is None: + elif parser.consume(cls._REGEX_TSMAP_MPEGTS): + m = parser.match(cls._REGEX_TSMAP_MPEGTS) + if not m: raise ParseError(parser) + mpegts = int_or_none(m.group(1)) + if mpegts is None: + raise ParseError(parser) + parser.advance(m) else: - m = parser.consume(cls._REGEX_TSMAP_MPEGTS) - if m: - mpegts = int_or_none(m.group(1)) - if mpegts is None: - raise ParseError(parser) - else: - raise ParseError(parser) + raise ParseError(parser) + if parser.consume(cls._REGEX_TSMAP_SEP): continue if parser.consume(_REGEX_NL): @@ -220,14 +211,14 @@ def parse(cls, parser): def write_into(self, stream): stream.write('WEBVTT') - if self.extra is not None: + if self.extra: stream.write(self.extra) stream.write('\n') - if self.local or self.mpegts: + if self.local is not None or self.mpegts is not None: stream.write('X-TIMESTAMP-MAP=LOCAL:') - stream.write(_format_ts(self.local if self.local is not None else 0)) + stream.write(_format_ts(self.local or 0)) stream.write(',MPEGTS:') - stream.write(str(self.mpegts if self.mpegts is not None else 0)) + stream.write(str(self.mpegts or 0)) stream.write('\n') if self.meta: stream.write(self.meta) @@ -278,9 +269,7 @@ def parse(cls, parser): id_ = m.group(1) m0 = parser.consume(_REGEX_TS) - if not m0: - return None - if not parser.consume(cls._REGEX_ARROW): + if not m0 or not parser.consume(cls._REGEX_ARROW): return None m1 = parser.consume(_REGEX_TS) if not m1: @@ -292,7 +281,7 @@ def parse(cls, parser): start = _parse_ts(m0) end = _parse_ts(m1) - settings = m2.group(1) if m2 is not None else None + settings = m2.group(1) if m2 else None text = io.StringIO() while True: @@ -302,55 +291,7 @@ def parse(cls, parser): text.write(m.group(0)) parser.commit() - return cls( - id=id_, - start=start, end=end, settings=settings, - text=text.getvalue(), - ) - - def write_into(self, stream): - if self.id is not None: - stream.write(self.id) - stream.write('\n') - stream.write(_format_ts(self.start)) - stream.write(' --> ') - stream.write(_format_ts(self.end)) - if self.settings is not None: - stream.write(' ') - stream.write(self.settings) - stream.write('\n') - stream.write(self.text) - stream.write('\n') - - @property - def as_json(self): - return { - 'id': self.id, - 'start': self.start, - 'end': self.end, - 'text': self.text, - 'settings': self.settings, - } - - def __eq__(self, other): - return self.as_json == other.as_json - - @classmethod - def from_json(cls, json): - return cls( - id=json['id'], - start=json['start'], - end=json['end'], - text=json['text'], - settings=json['settings'], - ) - - def hinges(self, other): - if self.text != other.text: - return False - if self.settings != other.settings: - return False - return self.start <= self.end == other.start <= other.end + return cls(id=id_, start=start, end=end, settings=settings, text=text.getvalue()) def parse_fragment(frag_content): @@ -358,15 +299,16 @@ def parse_fragment(frag_content): A generator that yields (partially) parsed WebVTT blocks when given a bytes object containing the raw contents of a WebVTT file. """ + if not isinstance(frag_content, (bytes, bytearray)): + raise TypeError("Expected bytes for frag_content") - parser = _MatchParser(frag_content.decode()) + parser = _MatchParser(frag_content.decode(errors="replace")) yield Magic.parse(parser) while not parser.match(_REGEX_EOF): if parser.consume(_REGEX_BLANK): continue - block = RegionBlock.parse(parser) if block: yield block @@ -377,22 +319,19 @@ def parse_fragment(frag_content): continue block = CommentBlock.parse(parser) if block: - yield block # XXX: or skip + yield block continue - break while not parser.match(_REGEX_EOF): if parser.consume(_REGEX_BLANK): continue - block = CommentBlock.parse(parser) if block: - yield block # XXX: or skip + yield block continue block = CueBlock.parse(parser) if block: yield block continue - raise ParseError(parser) From f7ed9cb0726119484a9397e04d462412e7962eb5 Mon Sep 17 00:00:00 2001 From: GiorgosTsak <147599978+GiorgosTsak@users.noreply.github.com> Date: Mon, 23 Jun 2025 20:26:11 +0300 Subject: [PATCH 2/2] Improve WebVTT parser with robust error handling and input validation Enhances the WebVTT partial parser by adding comprehensive error handling, type validation, and defensive checks to prevent unexpected failures during parsing. Specifically, input types are validated in _MatchParser and parse_fragment, ensuring only valid strings or bytes are accepted. Timestamp parsing now raises clear errors for invalid matches, while regex operations are guarded to avoid NoneType attribute errors. The .decode() step in parse_fragment uses safe fallback to handle invalid byte sequences gracefully. --- yt_dlp/webvtt.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/yt_dlp/webvtt.py b/yt_dlp/webvtt.py index 0aa455e0e..8ddbb88d2 100644 --- a/yt_dlp/webvtt.py +++ b/yt_dlp/webvtt.py @@ -22,7 +22,7 @@ class _MatchParser: def __init__(self, string): if not isinstance(string, str): - raise TypeError("Expected string input to _MatchParser") + raise TypeError('Expected string input to _MatchParser') self._data = string self._pos = 0 @@ -33,7 +33,7 @@ def match(self, r): if self._data.startswith(r, self._pos): return len(r) return None - raise ValueError(f"Expected regex or string, got {type(r).__name__}") + raise ValueError(f'Expected regex or string, got {type(r).__name__}') def advance(self, by): if by is None: @@ -45,7 +45,7 @@ def advance(self, by): elif isinstance(by, int): amt = by else: - raise ValueError(f"Unsupported advance type: {type(by).__name__}") + raise ValueError(f'Unsupported advance type: {type(by).__name__}') self._pos += amt return by @@ -105,7 +105,7 @@ def _parse_ts(ts): into an MPEG PES timestamp: a tick counter at 90 kHz resolution. """ if ts is None or not isinstance(ts, re.Match): - raise ValueError("Invalid timestamp match for _parse_ts") + raise ValueError('Invalid timestamp match for _parse_ts') return 90 * sum( int(part or 0) * mult for part, mult in zip(ts.groups(), (3600_000, 60_000, 1000, 1))) @@ -300,9 +300,9 @@ def parse_fragment(frag_content): a bytes object containing the raw contents of a WebVTT file. """ if not isinstance(frag_content, (bytes, bytearray)): - raise TypeError("Expected bytes for frag_content") + raise TypeError('Expected bytes for frag_content') - parser = _MatchParser(frag_content.decode(errors="replace")) + parser = _MatchParser(frag_content.decode(errors='replace')) yield Magic.parse(parser)