1
0
mirror of https://github.com/yt-dlp/yt-dlp.git synced 2025-06-27 17:08:32 +00:00

Improved WebVTT parser with robust error handling and input validation

Enhances the WebVTT partial parser by adding comprehensive error handling, type validation, and defensive checks to prevent unexpected failures during parsing. Specifically, input types are validated in _MatchParser and parse_fragment, ensuring only valid strings or bytes are accepted. Timestamp parsing now raises clear errors for invalid matches, while regex operations are guarded to avoid NoneType attribute errors. The .decode() step in parse_fragment uses safe fallback to handle invalid byte sequences gracefully.
This commit is contained in:
GiorgosTsak 2025-06-23 20:16:28 +03:00 committed by GitHub
parent 73bf102116
commit 45d132a6be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -21,6 +21,8 @@ class _MatchParser:
"""
def __init__(self, string):
if not isinstance(string, str):
raise TypeError("Expected string input to _MatchParser")
self._data = string
self._pos = 0
@ -31,7 +33,7 @@ def match(self, r):
if self._data.startswith(r, self._pos):
return len(r)
return None
raise ValueError(r)
raise ValueError(f"Expected regex or string, got {type(r).__name__}")
def advance(self, by):
if by is None:
@ -43,7 +45,7 @@ def advance(self, by):
elif isinstance(by, int):
amt = by
else:
raise ValueError(by)
raise ValueError(f"Unsupported advance type: {type(by).__name__}")
self._pos += amt
return by
@ -102,6 +104,8 @@ def _parse_ts(ts):
Convert a parsed WebVTT timestamp (a re.Match obtained from _REGEX_TS)
into an MPEG PES timestamp: a tick counter at 90 kHz resolution.
"""
if ts is None or not isinstance(ts, re.Match):
raise ValueError("Invalid timestamp match for _parse_ts")
return 90 * sum(
int(part or 0) * mult for part, mult in zip(ts.groups(), (3600_000, 60_000, 1000, 1)))
@ -146,47 +150,34 @@ class HeaderBlock(Block):
class Magic(HeaderBlock):
_REGEX = re.compile(r'\ufeff?WEBVTT([ \t][^\r\n]*)?(?:\r\n|[\r\n])')
# XXX: The X-TIMESTAMP-MAP extension is described in RFC 8216 §3.5
# <https://tools.ietf.org/html/rfc8216#section-3.5>, but the RFC
# doesn't specify the exact grammar nor where in the WebVTT
# syntax it should be placed; the below has been devised based
# on usage in the wild
#
# And strictly speaking, the presence of this extension violates
# the W3C WebVTT spec. Oh well.
_REGEX_TSMAP = re.compile(r'X-TIMESTAMP-MAP=')
_REGEX_TSMAP_LOCAL = re.compile(r'LOCAL:')
_REGEX_TSMAP_MPEGTS = re.compile(r'MPEGTS:([0-9]+)')
_REGEX_TSMAP_SEP = re.compile(r'[ \t]*,[ \t]*')
# This was removed from the spec in the 2017 revision;
# the last spec draft to describe this syntax element is
# <https://www.w3.org/TR/2015/WD-webvtt1-20151208/#webvtt-metadata-header>.
# Nevertheless, YouTube keeps serving those
_REGEX_META = re.compile(r'(?:(?!-->)[^\r\n])+:(?:(?!-->)[^\r\n])+(?:\r\n|[\r\n])')
@classmethod
def __parse_tsmap(cls, parser):
parser = parser.child()
local, mpegts = None, None
while True:
m = parser.consume(cls._REGEX_TSMAP_LOCAL)
if m:
if parser.consume(cls._REGEX_TSMAP_LOCAL):
m = parser.consume(_REGEX_TS)
if m is None:
if not m:
raise ParseError(parser)
local = _parse_ts(m)
if local is None:
elif parser.consume(cls._REGEX_TSMAP_MPEGTS):
m = parser.match(cls._REGEX_TSMAP_MPEGTS)
if not m:
raise ParseError(parser)
mpegts = int_or_none(m.group(1))
if mpegts is None:
raise ParseError(parser)
parser.advance(m)
else:
m = parser.consume(cls._REGEX_TSMAP_MPEGTS)
if m:
mpegts = int_or_none(m.group(1))
if mpegts is None:
raise ParseError(parser)
else:
raise ParseError(parser)
raise ParseError(parser)
if parser.consume(cls._REGEX_TSMAP_SEP):
continue
if parser.consume(_REGEX_NL):
@ -220,14 +211,14 @@ def parse(cls, parser):
def write_into(self, stream):
stream.write('WEBVTT')
if self.extra is not None:
if self.extra:
stream.write(self.extra)
stream.write('\n')
if self.local or self.mpegts:
if self.local is not None or self.mpegts is not None:
stream.write('X-TIMESTAMP-MAP=LOCAL:')
stream.write(_format_ts(self.local if self.local is not None else 0))
stream.write(_format_ts(self.local or 0))
stream.write(',MPEGTS:')
stream.write(str(self.mpegts if self.mpegts is not None else 0))
stream.write(str(self.mpegts or 0))
stream.write('\n')
if self.meta:
stream.write(self.meta)
@ -278,9 +269,7 @@ def parse(cls, parser):
id_ = m.group(1)
m0 = parser.consume(_REGEX_TS)
if not m0:
return None
if not parser.consume(cls._REGEX_ARROW):
if not m0 or not parser.consume(cls._REGEX_ARROW):
return None
m1 = parser.consume(_REGEX_TS)
if not m1:
@ -292,7 +281,7 @@ def parse(cls, parser):
start = _parse_ts(m0)
end = _parse_ts(m1)
settings = m2.group(1) if m2 is not None else None
settings = m2.group(1) if m2 else None
text = io.StringIO()
while True:
@ -302,55 +291,7 @@ def parse(cls, parser):
text.write(m.group(0))
parser.commit()
return cls(
id=id_,
start=start, end=end, settings=settings,
text=text.getvalue(),
)
def write_into(self, stream):
if self.id is not None:
stream.write(self.id)
stream.write('\n')
stream.write(_format_ts(self.start))
stream.write(' --> ')
stream.write(_format_ts(self.end))
if self.settings is not None:
stream.write(' ')
stream.write(self.settings)
stream.write('\n')
stream.write(self.text)
stream.write('\n')
@property
def as_json(self):
return {
'id': self.id,
'start': self.start,
'end': self.end,
'text': self.text,
'settings': self.settings,
}
def __eq__(self, other):
return self.as_json == other.as_json
@classmethod
def from_json(cls, json):
return cls(
id=json['id'],
start=json['start'],
end=json['end'],
text=json['text'],
settings=json['settings'],
)
def hinges(self, other):
if self.text != other.text:
return False
if self.settings != other.settings:
return False
return self.start <= self.end == other.start <= other.end
return cls(id=id_, start=start, end=end, settings=settings, text=text.getvalue())
def parse_fragment(frag_content):
@ -358,15 +299,16 @@ def parse_fragment(frag_content):
A generator that yields (partially) parsed WebVTT blocks when given
a bytes object containing the raw contents of a WebVTT file.
"""
if not isinstance(frag_content, (bytes, bytearray)):
raise TypeError("Expected bytes for frag_content")
parser = _MatchParser(frag_content.decode())
parser = _MatchParser(frag_content.decode(errors="replace"))
yield Magic.parse(parser)
while not parser.match(_REGEX_EOF):
if parser.consume(_REGEX_BLANK):
continue
block = RegionBlock.parse(parser)
if block:
yield block
@ -377,22 +319,19 @@ def parse_fragment(frag_content):
continue
block = CommentBlock.parse(parser)
if block:
yield block # XXX: or skip
yield block
continue
break
while not parser.match(_REGEX_EOF):
if parser.consume(_REGEX_BLANK):
continue
block = CommentBlock.parse(parser)
if block:
yield block # XXX: or skip
yield block
continue
block = CueBlock.parse(parser)
if block:
yield block
continue
raise ParseError(parser)