1
0
mirror of https://github.com/yt-dlp/yt-dlp.git synced 2025-12-31 03:51:20 +00:00

[networking] Fix various socks proxy bugs (#8065)

- Fixed support for IPv6 socks proxies
- Fixed support for IPv6 over socks5
- Fixed --source-address not being obeyed for socks4 and socks5
- Fixed socks4a when the destination address is an IPv4 address

Closes https://github.com/yt-dlp/yt-dlp/issues/7959
Fixes https://github.com/ytdl-org/youtube-dl/issues/15368

Authored by: coletdjnz
Co-authored-by: Simon Sawicki <accounts@grub4k.xyz>
Co-authored-by: bashonly <bashonly@bashonly.com>
This commit is contained in:
coletdjnz
2023-09-18 07:33:26 +00:00
committed by GitHub
parent 81f46ac573
commit 20fbbd9249
4 changed files with 110 additions and 84 deletions

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import contextlib
import functools
import socket
import ssl
import sys
import typing
@@ -206,3 +207,59 @@ def wrap_request_errors(func):
e.handler = self
raise
return wrapper
def _socket_connect(ip_addr, timeout, source_address):
af, socktype, proto, canonname, sa = ip_addr
sock = socket.socket(af, socktype, proto)
try:
if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
sock.settimeout(timeout)
if source_address:
sock.bind(source_address)
sock.connect(sa)
return sock
except socket.error:
sock.close()
raise
def create_connection(
address,
timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
source_address=None,
*,
_create_socket_func=_socket_connect
):
# Work around socket.create_connection() which tries all addresses from getaddrinfo() including IPv6.
# This filters the addresses based on the given source_address.
# Based on: https://github.com/python/cpython/blob/main/Lib/socket.py#L810
host, port = address
ip_addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
if not ip_addrs:
raise socket.error('getaddrinfo returns an empty list')
if source_address is not None:
af = socket.AF_INET if ':' not in source_address[0] else socket.AF_INET6
ip_addrs = [addr for addr in ip_addrs if addr[0] == af]
if not ip_addrs:
raise OSError(
f'No remote IPv{4 if af == socket.AF_INET else 6} addresses available for connect. '
f'Can\'t use "{source_address[0]}" as source address')
err = None
for ip_addr in ip_addrs:
try:
sock = _create_socket_func(ip_addr, timeout, source_address)
# Explicitly break __traceback__ reference cycle
# https://bugs.python.org/issue36820
err = None
return sock
except socket.error as e:
err = e
try:
raise err
finally:
# Explicitly break __traceback__ reference cycle
# https://bugs.python.org/issue36820
err = None

View File

@@ -23,6 +23,7 @@ from urllib.request import (
from ._helper import (
InstanceStoreMixin,
add_accept_encoding_header,
create_connection,
get_redirect_method,
make_socks_proxy_opts,
select_proxy,
@@ -54,44 +55,10 @@ if brotli:
def _create_http_connection(http_class, source_address, *args, **kwargs):
hc = http_class(*args, **kwargs)
if hasattr(hc, '_create_connection'):
hc._create_connection = create_connection
if source_address is not None:
# This is to workaround _create_connection() from socket where it will try all
# address data from getaddrinfo() including IPv6. This filters the result from
# getaddrinfo() based on the source_address value.
# This is based on the cpython socket.create_connection() function.
# https://github.com/python/cpython/blob/master/Lib/socket.py#L691
def _create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None):
host, port = address
err = None
addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
af = socket.AF_INET if '.' in source_address[0] else socket.AF_INET6
ip_addrs = [addr for addr in addrs if addr[0] == af]
if addrs and not ip_addrs:
ip_version = 'v4' if af == socket.AF_INET else 'v6'
raise OSError(
"No remote IP%s addresses available for connect, can't use '%s' as source address"
% (ip_version, source_address[0]))
for res in ip_addrs:
af, socktype, proto, canonname, sa = res
sock = None
try:
sock = socket.socket(af, socktype, proto)
if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
sock.settimeout(timeout)
sock.bind(source_address)
sock.connect(sa)
err = None # Explicitly break reference cycle
return sock
except OSError as _:
err = _
if sock is not None:
sock.close()
if err is not None:
raise err
else:
raise OSError('getaddrinfo returns an empty list')
if hasattr(hc, '_create_connection'):
hc._create_connection = _create_connection
hc.source_address = (source_address, 0)
return hc
@@ -220,13 +187,28 @@ def make_socks_conn_class(base_class, socks_proxy):
proxy_args = make_socks_proxy_opts(socks_proxy)
class SocksConnection(base_class):
def connect(self):
self.sock = sockssocket()
self.sock.setproxy(**proxy_args)
if type(self.timeout) in (int, float): # noqa: E721
self.sock.settimeout(self.timeout)
self.sock.connect((self.host, self.port))
_create_connection = create_connection
def connect(self):
def sock_socket_connect(ip_addr, timeout, source_address):
af, socktype, proto, canonname, sa = ip_addr
sock = sockssocket(af, socktype, proto)
try:
connect_proxy_args = proxy_args.copy()
connect_proxy_args.update({'addr': sa[0], 'port': sa[1]})
sock.setproxy(**connect_proxy_args)
if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: # noqa: E721
sock.settimeout(timeout)
if source_address:
sock.bind(source_address)
sock.connect((self.host, self.port))
return sock
except socket.error:
sock.close()
raise
self.sock = create_connection(
(proxy_args['addr'], proxy_args['port']), timeout=self.timeout,
source_address=self.source_address, _create_socket_func=sock_socket_connect)
if isinstance(self, http.client.HTTPSConnection):
self.sock = self._context.wrap_socket(self.sock, server_hostname=self.host)

View File

@@ -134,26 +134,31 @@ class sockssocket(socket.socket):
self.close()
raise InvalidVersionError(expected_version, got_version)
def _resolve_address(self, destaddr, default, use_remote_dns):
try:
return socket.inet_aton(destaddr)
except OSError:
if use_remote_dns and self._proxy.remote_dns:
return default
else:
return socket.inet_aton(socket.gethostbyname(destaddr))
def _resolve_address(self, destaddr, default, use_remote_dns, family=None):
for f in (family,) if family else (socket.AF_INET, socket.AF_INET6):
try:
return f, socket.inet_pton(f, destaddr)
except OSError:
continue
if use_remote_dns and self._proxy.remote_dns:
return 0, default
else:
res = socket.getaddrinfo(destaddr, None, family=family or 0)
f, _, _, _, ipaddr = res[0]
return f, socket.inet_pton(f, ipaddr[0])
def _setup_socks4(self, address, is_4a=False):
destaddr, port = address
ipaddr = self._resolve_address(destaddr, SOCKS4_DEFAULT_DSTIP, use_remote_dns=is_4a)
_, ipaddr = self._resolve_address(destaddr, SOCKS4_DEFAULT_DSTIP, use_remote_dns=is_4a, family=socket.AF_INET)
packet = struct.pack('!BBH', SOCKS4_VERSION, Socks4Command.CMD_CONNECT, port) + ipaddr
username = (self._proxy.username or '').encode()
packet += username + b'\x00'
if is_4a and self._proxy.remote_dns:
if is_4a and self._proxy.remote_dns and ipaddr == SOCKS4_DEFAULT_DSTIP:
packet += destaddr.encode() + b'\x00'
self.sendall(packet)
@@ -210,7 +215,7 @@ class sockssocket(socket.socket):
def _setup_socks5(self, address):
destaddr, port = address
ipaddr = self._resolve_address(destaddr, None, use_remote_dns=True)
family, ipaddr = self._resolve_address(destaddr, None, use_remote_dns=True)
self._socks5_auth()
@@ -220,8 +225,10 @@ class sockssocket(socket.socket):
destaddr = destaddr.encode()
packet += struct.pack('!B', Socks5AddressType.ATYP_DOMAINNAME)
packet += self._len_and_data(destaddr)
else:
elif family == socket.AF_INET:
packet += struct.pack('!B', Socks5AddressType.ATYP_IPV4) + ipaddr
elif family == socket.AF_INET6:
packet += struct.pack('!B', Socks5AddressType.ATYP_IPV6) + ipaddr
packet += struct.pack('!H', port)
self.sendall(packet)