1
0
mirror of https://github.com/yt-dlp/yt-dlp.git synced 2025-06-27 17:08:32 +00:00

Improved WebVTT parser with robust error handling and input validation

Enhances the WebVTT partial parser by adding comprehensive error handling, type validation, and defensive checks to prevent unexpected failures during parsing. Specifically, input types are validated in _MatchParser and parse_fragment, ensuring only valid strings or bytes are accepted. Timestamp parsing now raises clear errors for invalid matches, while regex operations are guarded to avoid NoneType attribute errors. The .decode() step in parse_fragment uses safe fallback to handle invalid byte sequences gracefully.
This commit is contained in:
GiorgosTsak 2025-06-23 20:16:28 +03:00 committed by GitHub
parent 73bf102116
commit 45d132a6be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -21,6 +21,8 @@ class _MatchParser:
""" """
def __init__(self, string): def __init__(self, string):
if not isinstance(string, str):
raise TypeError("Expected string input to _MatchParser")
self._data = string self._data = string
self._pos = 0 self._pos = 0
@ -31,7 +33,7 @@ def match(self, r):
if self._data.startswith(r, self._pos): if self._data.startswith(r, self._pos):
return len(r) return len(r)
return None return None
raise ValueError(r) raise ValueError(f"Expected regex or string, got {type(r).__name__}")
def advance(self, by): def advance(self, by):
if by is None: if by is None:
@ -43,7 +45,7 @@ def advance(self, by):
elif isinstance(by, int): elif isinstance(by, int):
amt = by amt = by
else: else:
raise ValueError(by) raise ValueError(f"Unsupported advance type: {type(by).__name__}")
self._pos += amt self._pos += amt
return by return by
@ -102,6 +104,8 @@ def _parse_ts(ts):
Convert a parsed WebVTT timestamp (a re.Match obtained from _REGEX_TS) Convert a parsed WebVTT timestamp (a re.Match obtained from _REGEX_TS)
into an MPEG PES timestamp: a tick counter at 90 kHz resolution. into an MPEG PES timestamp: a tick counter at 90 kHz resolution.
""" """
if ts is None or not isinstance(ts, re.Match):
raise ValueError("Invalid timestamp match for _parse_ts")
return 90 * sum( return 90 * sum(
int(part or 0) * mult for part, mult in zip(ts.groups(), (3600_000, 60_000, 1000, 1))) int(part or 0) * mult for part, mult in zip(ts.groups(), (3600_000, 60_000, 1000, 1)))
@ -146,47 +150,34 @@ class HeaderBlock(Block):
class Magic(HeaderBlock): class Magic(HeaderBlock):
_REGEX = re.compile(r'\ufeff?WEBVTT([ \t][^\r\n]*)?(?:\r\n|[\r\n])') _REGEX = re.compile(r'\ufeff?WEBVTT([ \t][^\r\n]*)?(?:\r\n|[\r\n])')
# XXX: The X-TIMESTAMP-MAP extension is described in RFC 8216 §3.5
# <https://tools.ietf.org/html/rfc8216#section-3.5>, but the RFC
# doesn't specify the exact grammar nor where in the WebVTT
# syntax it should be placed; the below has been devised based
# on usage in the wild
#
# And strictly speaking, the presence of this extension violates
# the W3C WebVTT spec. Oh well.
_REGEX_TSMAP = re.compile(r'X-TIMESTAMP-MAP=') _REGEX_TSMAP = re.compile(r'X-TIMESTAMP-MAP=')
_REGEX_TSMAP_LOCAL = re.compile(r'LOCAL:') _REGEX_TSMAP_LOCAL = re.compile(r'LOCAL:')
_REGEX_TSMAP_MPEGTS = re.compile(r'MPEGTS:([0-9]+)') _REGEX_TSMAP_MPEGTS = re.compile(r'MPEGTS:([0-9]+)')
_REGEX_TSMAP_SEP = re.compile(r'[ \t]*,[ \t]*') _REGEX_TSMAP_SEP = re.compile(r'[ \t]*,[ \t]*')
# This was removed from the spec in the 2017 revision;
# the last spec draft to describe this syntax element is
# <https://www.w3.org/TR/2015/WD-webvtt1-20151208/#webvtt-metadata-header>.
# Nevertheless, YouTube keeps serving those
_REGEX_META = re.compile(r'(?:(?!-->)[^\r\n])+:(?:(?!-->)[^\r\n])+(?:\r\n|[\r\n])') _REGEX_META = re.compile(r'(?:(?!-->)[^\r\n])+:(?:(?!-->)[^\r\n])+(?:\r\n|[\r\n])')
@classmethod @classmethod
def __parse_tsmap(cls, parser): def __parse_tsmap(cls, parser):
parser = parser.child() parser = parser.child()
local, mpegts = None, None
while True: while True:
m = parser.consume(cls._REGEX_TSMAP_LOCAL) if parser.consume(cls._REGEX_TSMAP_LOCAL):
if m:
m = parser.consume(_REGEX_TS) m = parser.consume(_REGEX_TS)
if m is None: if not m:
raise ParseError(parser) raise ParseError(parser)
local = _parse_ts(m) local = _parse_ts(m)
if local is None: elif parser.consume(cls._REGEX_TSMAP_MPEGTS):
m = parser.match(cls._REGEX_TSMAP_MPEGTS)
if not m:
raise ParseError(parser) raise ParseError(parser)
mpegts = int_or_none(m.group(1))
if mpegts is None:
raise ParseError(parser)
parser.advance(m)
else: else:
m = parser.consume(cls._REGEX_TSMAP_MPEGTS) raise ParseError(parser)
if m:
mpegts = int_or_none(m.group(1))
if mpegts is None:
raise ParseError(parser)
else:
raise ParseError(parser)
if parser.consume(cls._REGEX_TSMAP_SEP): if parser.consume(cls._REGEX_TSMAP_SEP):
continue continue
if parser.consume(_REGEX_NL): if parser.consume(_REGEX_NL):
@ -220,14 +211,14 @@ def parse(cls, parser):
def write_into(self, stream): def write_into(self, stream):
stream.write('WEBVTT') stream.write('WEBVTT')
if self.extra is not None: if self.extra:
stream.write(self.extra) stream.write(self.extra)
stream.write('\n') stream.write('\n')
if self.local or self.mpegts: if self.local is not None or self.mpegts is not None:
stream.write('X-TIMESTAMP-MAP=LOCAL:') stream.write('X-TIMESTAMP-MAP=LOCAL:')
stream.write(_format_ts(self.local if self.local is not None else 0)) stream.write(_format_ts(self.local or 0))
stream.write(',MPEGTS:') stream.write(',MPEGTS:')
stream.write(str(self.mpegts if self.mpegts is not None else 0)) stream.write(str(self.mpegts or 0))
stream.write('\n') stream.write('\n')
if self.meta: if self.meta:
stream.write(self.meta) stream.write(self.meta)
@ -278,9 +269,7 @@ def parse(cls, parser):
id_ = m.group(1) id_ = m.group(1)
m0 = parser.consume(_REGEX_TS) m0 = parser.consume(_REGEX_TS)
if not m0: if not m0 or not parser.consume(cls._REGEX_ARROW):
return None
if not parser.consume(cls._REGEX_ARROW):
return None return None
m1 = parser.consume(_REGEX_TS) m1 = parser.consume(_REGEX_TS)
if not m1: if not m1:
@ -292,7 +281,7 @@ def parse(cls, parser):
start = _parse_ts(m0) start = _parse_ts(m0)
end = _parse_ts(m1) end = _parse_ts(m1)
settings = m2.group(1) if m2 is not None else None settings = m2.group(1) if m2 else None
text = io.StringIO() text = io.StringIO()
while True: while True:
@ -302,55 +291,7 @@ def parse(cls, parser):
text.write(m.group(0)) text.write(m.group(0))
parser.commit() parser.commit()
return cls( return cls(id=id_, start=start, end=end, settings=settings, text=text.getvalue())
id=id_,
start=start, end=end, settings=settings,
text=text.getvalue(),
)
def write_into(self, stream):
if self.id is not None:
stream.write(self.id)
stream.write('\n')
stream.write(_format_ts(self.start))
stream.write(' --> ')
stream.write(_format_ts(self.end))
if self.settings is not None:
stream.write(' ')
stream.write(self.settings)
stream.write('\n')
stream.write(self.text)
stream.write('\n')
@property
def as_json(self):
return {
'id': self.id,
'start': self.start,
'end': self.end,
'text': self.text,
'settings': self.settings,
}
def __eq__(self, other):
return self.as_json == other.as_json
@classmethod
def from_json(cls, json):
return cls(
id=json['id'],
start=json['start'],
end=json['end'],
text=json['text'],
settings=json['settings'],
)
def hinges(self, other):
if self.text != other.text:
return False
if self.settings != other.settings:
return False
return self.start <= self.end == other.start <= other.end
def parse_fragment(frag_content): def parse_fragment(frag_content):
@ -358,15 +299,16 @@ def parse_fragment(frag_content):
A generator that yields (partially) parsed WebVTT blocks when given A generator that yields (partially) parsed WebVTT blocks when given
a bytes object containing the raw contents of a WebVTT file. a bytes object containing the raw contents of a WebVTT file.
""" """
if not isinstance(frag_content, (bytes, bytearray)):
raise TypeError("Expected bytes for frag_content")
parser = _MatchParser(frag_content.decode()) parser = _MatchParser(frag_content.decode(errors="replace"))
yield Magic.parse(parser) yield Magic.parse(parser)
while not parser.match(_REGEX_EOF): while not parser.match(_REGEX_EOF):
if parser.consume(_REGEX_BLANK): if parser.consume(_REGEX_BLANK):
continue continue
block = RegionBlock.parse(parser) block = RegionBlock.parse(parser)
if block: if block:
yield block yield block
@ -377,22 +319,19 @@ def parse_fragment(frag_content):
continue continue
block = CommentBlock.parse(parser) block = CommentBlock.parse(parser)
if block: if block:
yield block # XXX: or skip yield block
continue continue
break break
while not parser.match(_REGEX_EOF): while not parser.match(_REGEX_EOF):
if parser.consume(_REGEX_BLANK): if parser.consume(_REGEX_BLANK):
continue continue
block = CommentBlock.parse(parser) block = CommentBlock.parse(parser)
if block: if block:
yield block # XXX: or skip yield block
continue continue
block = CueBlock.parse(parser) block = CueBlock.parse(parser)
if block: if block:
yield block yield block
continue continue
raise ParseError(parser) raise ParseError(parser)