Mirror of https://github.com/yt-dlp/yt-dlp.git (synced 2026-01-07 07:21:18 +00:00)
[networking] Add keep_header_casing extension (#11652)
Authored by: coletdjnz, Grub4K
Co-authored-by: coletdjnz <coletdjnz@protonmail.com>
@@ -296,6 +296,7 @@ class RequestsRH(RequestHandler, InstanceStoreMixin):
         extensions.pop('cookiejar', None)
         extensions.pop('timeout', None)
         extensions.pop('legacy_ssl', None)
+        extensions.pop('keep_header_casing', None)
 
     def _create_instance(self, cookiejar, legacy_ssl_support=None):
         session = RequestsSession()

@@ -312,11 +313,12 @@ class RequestsRH(RequestHandler, InstanceStoreMixin):
         session.trust_env = False  # no need, we already load proxies from env
         return session
 
-    def _send(self, request):
-
-        headers = self._merge_headers(request.headers)
+    def _prepare_headers(self, _, headers):
         add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
 
+    def _send(self, request):
+
+        headers = self._get_headers(request)
         max_redirects_exceeded = False
 
         session = self._get_instance(

@@ -379,13 +379,15 @@ class UrllibRH(RequestHandler, InstanceStoreMixin):
         opener.addheaders = []
         return opener
 
-    def _send(self, request):
-        headers = self._merge_headers(request.headers)
+    def _prepare_headers(self, _, headers):
         add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
+
+    def _send(self, request):
+        headers = self._get_headers(request)
         urllib_req = urllib.request.Request(
             url=request.url,
             data=request.data,
-            headers=dict(headers),
+            headers=headers,
             method=request.method,
         )
 

@@ -116,6 +116,7 @@ class WebsocketsRH(WebSocketRequestHandler):
         extensions.pop('timeout', None)
         extensions.pop('cookiejar', None)
         extensions.pop('legacy_ssl', None)
+        extensions.pop('keep_header_casing', None)
 
     def close(self):
         # Remove the logging handler that contains a reference to our logger

@@ -123,15 +124,16 @@ class WebsocketsRH(WebSocketRequestHandler):
         for name, handler in self.__logging_handlers.items():
             logging.getLogger(name).removeHandler(handler)
 
-    def _send(self, request):
-        timeout = self._calculate_timeout(request)
-        headers = self._merge_headers(request.headers)
+    def _prepare_headers(self, request, headers):
         if 'cookie' not in headers:
             cookiejar = self._get_cookiejar(request)
             cookie_header = cookiejar.get_cookie_header(request.url)
             if cookie_header:
                 headers['cookie'] = cookie_header
 
+    def _send(self, request):
+        timeout = self._calculate_timeout(request)
+        headers = self._get_headers(request)
         wsuri = parse_uri(request.url)
         create_conn_kwargs = {
             'source_address': (self.source_address, 0) if self.source_address else None,

@@ -206,6 +206,7 @@ class RequestHandler(abc.ABC):
     - `cookiejar`: Cookiejar to use for this request.
     - `timeout`: socket timeout to use for this request.
     - `legacy_ssl`: Enable legacy SSL options for this request. See legacy_ssl_support.
+    - `keep_header_casing`: Keep the casing of headers when sending the request.
     To enable these, add extensions.pop('<extension>', None) to _check_extensions
 
     Apart from the url protocol, proxies dict may contain the following keys:

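A minimal illustrative sketch of how a caller could opt in to the new extension (the URL and header name below are made up, not part of the diff):

from yt_dlp.networking import Request

req = Request(
    'https://example.com/api',                # hypothetical endpoint
    headers={'X-Custom-ID': 'abc123'},        # casing we want to reach the wire unchanged
    extensions={'keep_header_casing': True},  # without this, the name is normalized to 'X-Custom-Id'
)
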
@@ -259,6 +260,23 @@
     def _merge_headers(self, request_headers):
         return HTTPHeaderDict(self.headers, request_headers)
 
+    def _prepare_headers(self, request: Request, headers: HTTPHeaderDict) -> None:  # noqa: B027
+        """Additional operations to prepare headers before building. To be extended by subclasses.
+        @param request: Request object
+        @param headers: Merged headers to prepare
+        """
+
+    def _get_headers(self, request: Request) -> dict[str, str]:
+        """
+        Get headers for external use.
+        Subclasses may define a _prepare_headers method to modify headers after merge but before building.
+        """
+        headers = self._merge_headers(request.headers)
+        self._prepare_headers(request, headers)
+        if request.extensions.get('keep_header_casing'):
+            return headers.sensitive()
+        return dict(headers)
+
     def _calculate_timeout(self, request):
         return float(request.extensions.get('timeout') or self.timeout)
 

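An illustrative sketch of the intended subclassing pattern; the class name, URL schemes, and header tweak below are hypothetical, not part of the change:

from yt_dlp.networking.common import RequestHandler


class ExampleRH(RequestHandler):
    _SUPPORTED_URL_SCHEMES = ('http', 'https')

    def _prepare_headers(self, request, headers):
        # runs on the merged HTTPHeaderDict, before casing is resolved
        headers.setdefault('Accept-Encoding', 'identity')

    def _send(self, request):
        # plain dict: title-cased keys by default, or the original casing
        # if the request carries the keep_header_casing extension
        headers = self._get_headers(request)
        ...  # hand `headers` to the underlying transport
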
@@ -317,6 +335,7 @@
         assert isinstance(extensions.get('cookiejar'), (YoutubeDLCookieJar, NoneType))
         assert isinstance(extensions.get('timeout'), (float, int, NoneType))
         assert isinstance(extensions.get('legacy_ssl'), (bool, NoneType))
+        assert isinstance(extensions.get('keep_header_casing'), (bool, NoneType))
 
     def _validate(self, request):
         self._check_url_scheme(request)

@@ -5,11 +5,11 @@ from abc import ABC
 from dataclasses import dataclass
 from typing import Any
 
-from .common import RequestHandler, register_preference
+from .common import RequestHandler, register_preference, Request
 from .exceptions import UnsupportedRequest
 from ..compat.types import NoneType
 from ..utils import classproperty, join_nonempty
-from ..utils.networking import std_headers
+from ..utils.networking import std_headers, HTTPHeaderDict
 
 
 @dataclass(order=True, frozen=True)

@@ -123,7 +123,17 @@ class ImpersonateRequestHandler(RequestHandler, ABC):
         """Get the requested target for the request"""
         return self._resolve_target(request.extensions.get('impersonate') or self.impersonate)
 
-    def _get_impersonate_headers(self, request):
+    def _prepare_impersonate_headers(self, request: Request, headers: HTTPHeaderDict) -> None:  # noqa: B027
+        """Additional operations to prepare headers before building. To be extended by subclasses.
+        @param request: Request object
+        @param headers: Merged headers to prepare
+        """
+
+    def _get_impersonate_headers(self, request: Request) -> dict[str, str]:
+        """
+        Get headers for external impersonation use.
+        Subclasses may define a _prepare_impersonate_headers method to modify headers after merge but before building.
+        """
         headers = self._merge_headers(request.headers)
         if self._get_request_target(request) is not None:
             # remove all headers present in std_headers

@@ -131,7 +141,11 @@ class ImpersonateRequestHandler(RequestHandler, ABC):
             for k, v in std_headers.items():
                 if headers.get(k) == v:
                     headers.pop(k)
-        return headers
+
+        self._prepare_impersonate_headers(request, headers)
+        if request.extensions.get('keep_header_casing'):
+            return headers.sensitive()
+        return dict(headers)
 
 
 @register_preference(ImpersonateRequestHandler)

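The impersonation path mirrors the generic one; a hedged sketch of a hypothetical subclass relying on the new hook (names below are illustrative only):

from yt_dlp.networking.impersonate import ImpersonateRequestHandler


class ExampleImpersonateRH(ImpersonateRequestHandler):
    def _prepare_impersonate_headers(self, request, headers):
        # e.g. let the impersonated client supply its own Accept-Encoding
        headers.pop('Accept-Encoding', None)

    def _send(self, request):
        # std_headers defaults are stripped when a target is active, and
        # keep_header_casing is honoured, inside _get_impersonate_headers()
        headers = self._get_impersonate_headers(request)
        ...
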
@@ -1,9 +1,16 @@
 from __future__ import annotations
 
 import collections
+import collections.abc
 import random
+import typing
 import urllib.parse
 import urllib.request
 
-from ._utils import remove_start
+if typing.TYPE_CHECKING:
+    T = typing.TypeVar('T')
+
+from ._utils import NO_DEFAULT, remove_start
 
 
 def random_user_agent():

@@ -51,32 +58,141 @@ def random_user_agent():
     return _USER_AGENT_TPL % random.choice(_CHROME_VERSIONS)
 
 
-class HTTPHeaderDict(collections.UserDict, dict):
+class HTTPHeaderDict(dict):
     """
     Store and access keys case-insensitively.
     The constructor can take multiple dicts, in which keys in the latter are prioritised.
+
+    Retains a case sensitive mapping of the headers, which can be accessed via `.sensitive()`.
     """
+    def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> typing.Self:
+        obj = dict.__new__(cls, *args, **kwargs)
+        obj.__sensitive_map = {}
+        return obj
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, /, *args, **kwargs):
         super().__init__()
-        for dct in args:
-            if dct is not None:
-                self.update(dct)
-        self.update(kwargs)
+        self.__sensitive_map = {}
 
-    def __setitem__(self, key, value):
-        if isinstance(value, bytes):
-            value = value.decode('latin-1')
-        super().__setitem__(key.title(), str(value).strip())
+        for dct in filter(None, args):
+            self.update(dct)
+        if kwargs:
+            self.update(kwargs)
 
-    def __getitem__(self, key):
+    def sensitive(self, /) -> dict[str, str]:
+        return {
+            self.__sensitive_map[key]: value
+            for key, value in self.items()
+        }
+
+    def __contains__(self, key: str, /) -> bool:
+        return super().__contains__(key.title() if isinstance(key, str) else key)
+
+    def __delitem__(self, key: str, /) -> None:
+        key = key.title()
+        del self.__sensitive_map[key]
+        super().__delitem__(key)
+
+    def __getitem__(self, key, /) -> str:
         return super().__getitem__(key.title())
 
-    def __delitem__(self, key):
-        super().__delitem__(key.title())
+    def __ior__(self, other, /):
+        if isinstance(other, type(self)):
+            other = other.sensitive()
+        if isinstance(other, dict):
+            self.update(other)
+            return
+        return NotImplemented
 
-    def __contains__(self, key):
-        return super().__contains__(key.title() if isinstance(key, str) else key)
+    def __or__(self, other, /) -> typing.Self:
+        if isinstance(other, type(self)):
+            other = other.sensitive()
+        if isinstance(other, dict):
+            return type(self)(self.sensitive(), other)
+        return NotImplemented
+
+    def __ror__(self, other, /) -> typing.Self:
+        if isinstance(other, type(self)):
+            other = other.sensitive()
+        if isinstance(other, dict):
+            return type(self)(other, self.sensitive())
+        return NotImplemented
+
+    def __setitem__(self, key: str, value, /) -> None:
+        if isinstance(value, bytes):
+            value = value.decode('latin-1')
+        key_title = key.title()
+        self.__sensitive_map[key_title] = key
+        super().__setitem__(key_title, str(value).strip())
+
+    def clear(self, /) -> None:
+        self.__sensitive_map.clear()
+        super().clear()
+
+    def copy(self, /) -> typing.Self:
+        return type(self)(self.sensitive())
+
+    @typing.overload
+    def get(self, key: str, /) -> str | None: ...
+
+    @typing.overload
+    def get(self, key: str, /, default: T) -> str | T: ...
+
+    def get(self, key, /, default=NO_DEFAULT):
+        key = key.title()
+        if default is NO_DEFAULT:
+            return super().get(key)
+        return super().get(key, default)
+
+    @typing.overload
+    def pop(self, key: str, /) -> str: ...
+
+    @typing.overload
+    def pop(self, key: str, /, default: T) -> str | T: ...
+
+    def pop(self, key, /, default=NO_DEFAULT):
+        key = key.title()
+        if default is NO_DEFAULT:
+            self.__sensitive_map.pop(key)
+            return super().pop(key)
+        self.__sensitive_map.pop(key, default)
+        return super().pop(key, default)
+
+    def popitem(self) -> tuple[str, str]:
+        self.__sensitive_map.popitem()
+        return super().popitem()
+
+    @typing.overload
+    def setdefault(self, key: str, /) -> str: ...
+
+    @typing.overload
+    def setdefault(self, key: str, /, default) -> str: ...
+
+    def setdefault(self, key, /, default=None) -> str:
+        key = key.title()
+        if key in self.__sensitive_map:
+            return super().__getitem__(key)
+
+        self[key] = default or ''
+        return self[key]
+
+    def update(self, other, /, **kwargs) -> None:
+        if isinstance(other, type(self)):
+            other = other.sensitive()
+        if isinstance(other, collections.abc.Mapping):
+            for key, value in other.items():
+                self[key] = value
+
+        elif hasattr(other, 'keys'):
+            for key in other.keys():  # noqa: SIM118
+                self[key] = other[key]
+
+        else:
+            for key, value in other:
+                self[key] = value
+
+        for key, value in kwargs.items():
+            self[key] = value
 
 
 std_headers = HTTPHeaderDict({

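A small usage sketch of the reworked class (header names are illustrative, not part of the diff):

from yt_dlp.utils.networking import HTTPHeaderDict

headers = HTTPHeaderDict({'x-goog-visitor-id': 'abc'}, {'user-agent': 'yt-dlp'})
headers['X-GOOG-VISITOR-ID']  # -> 'abc'; lookups are case-insensitive
dict(headers)                 # -> {'X-Goog-Visitor-Id': 'abc', 'User-Agent': 'yt-dlp'} (normalized casing)
headers.sensitive()           # -> {'x-goog-visitor-id': 'abc', 'user-agent': 'yt-dlp'} (casing as given)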