from __future__ import annotations

import logging
import re
import threading
import types
import typing

import h2.config  # type: ignore[import-untyped]
import h2.connection  # type: ignore[import-untyped]
import h2.events  # type: ignore[import-untyped]

from .._base_connection import _TYPE_BODY
from .._collections import HTTPHeaderDict
from ..connection import HTTPSConnection, _get_default_user_agent
from ..exceptions import ConnectionError
from ..response import BaseHTTPResponse

orig_HTTPSConnection = HTTPSConnection

T = typing.TypeVar("T")

log = logging.getLogger(__name__)

RE_IS_LEGAL_HEADER_NAME = re.compile(rb"^[!#$%&'*+\-.^_`|~0-9a-z]+$")
RE_IS_ILLEGAL_HEADER_VALUE = re.compile(rb"[\0\x00\x0a\x0d\r\n]|^[ \r\n\t]|[ \r\n\t]$")


def _is_legal_header_name(name: bytes) -> bool:
    """
    "An implementation that validates fields according to the definitions in Sections
    5.1 and 5.5 of [HTTP] only needs an additional check that field names do not
    include uppercase characters." (RFC 9113, Section 8.2.1, "Field Validity")

    `http.client._is_legal_header_name` does not validate the field name according to the
    HTTP 1.1 spec, so we do that here, in addition to checking for uppercase characters.

    This does not allow for the `:` character in the header name, so should not
    be used to validate pseudo-headers.
    """
    return bool(RE_IS_LEGAL_HEADER_NAME.match(name))


def _is_illegal_header_value(value: bytes) -> bool:
    """
    "A field value MUST NOT contain the zero value (ASCII NUL, 0x00), line feed
    (ASCII LF, 0x0a), or carriage return (ASCII CR, 0x0d) at any position. A field
    value MUST NOT start or end with an ASCII whitespace character (ASCII SP or HTAB,
    0x20 or 0x09)." (RFC 9113, Section 8.2.1, "Field Validity")
    """
    # `search` (not `match`): the forbidden bytes may appear anywhere in the value.
    return bool(RE_IS_ILLEGAL_HEADER_VALUE.search(value))


class _LockedObject(typing.Generic[T]):
    """
    A wrapper class that hides a specific object behind a lock.

    The goal here is to provide a simple way to protect access to an object
    that cannot safely be simultaneously accessed from multiple threads. The
    intended use of this class is simple: take hold of it with a context
    manager, which returns the protected object.
    """

    __slots__ = (
        "lock",
        "_obj",
    )

    def __init__(self, obj: T):
        """Wrap *obj* and create the re-entrant lock that guards it."""
        self.lock = threading.RLock()
        self._obj = obj

    def __enter__(self) -> T:
        """Acquire the lock and hand out the protected object."""
        self.lock.acquire()
        return self._obj

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        """Release the lock; exceptions from the `with` body are not suppressed."""
        self.lock.release()


class HTTP2Connection(HTTPSConnection):
    """An `HTTPSConnection` that speaks HTTP/2 through an `h2` state machine.

    The `h2` connection object is kept behind a `_LockedObject`; every method
    that touches it does so inside a `with self._h2_conn as conn:` block.
    Proxies and tunnels are rejected.
    """

    def __init__(
        self, host: str, port: int | None = None, **kwargs: typing.Any
    ) -> None:
        """Set up HTTP/2 state and delegate the rest to `HTTPSConnection`.

        :raises NotImplementedError: if `proxy` / `proxy_config` is passed, or
            if a tunnel host is set after the parent initializer has run.
        """
        self._h2_conn = self._new_h2_conn()
        self._h2_stream: int | None = None
        self._headers: list[tuple[bytes, bytes]] = []

        if "proxy" in kwargs or "proxy_config" in kwargs:  # Defensive:
            raise NotImplementedError("Proxies aren't supported with HTTP/2")

        super().__init__(host, port, **kwargs)

        if self._tunnel_host is not None:
            raise NotImplementedError("Tunneling isn't supported with HTTP/2")

    def _new_h2_conn(self) -> _LockedObject[h2.connection.H2Connection]:
        """Return a fresh client-side `H2Connection` wrapped in a lock."""
        config = h2.config.H2Configuration(client_side=True)
        return _LockedObject(h2.connection.H2Connection(config=config))

    def connect(self) -> None:
        """Connect via the parent class, then start the HTTP/2 connection."""
        super().connect()

        with self._h2_conn as conn:
            conn.initiate_connection()
            if data_to_send := conn.data_to_send():
                self.sock.sendall(data_to_send)

    def putrequest(  # type: ignore[override]
        self,
        method: str,
        url: str,
        **kwargs: typing.Any,
    ) -> None:
        """putrequest
        This deviates from the HTTPConnection method signature since we never need to override
        sending accept-encoding headers or the host header.

        Queues the `:scheme`, `:method`, `:authority` and `:path` pseudo-headers
        and reserves the next available stream id for this request.

        :raises NotImplementedError: if `skip_host` or `skip_accept_encoding`
            is passed.
        """
        if "skip_host" in kwargs:
            raise NotImplementedError("`skip_host` isn't supported")
        if "skip_accept_encoding" in kwargs:
            raise NotImplementedError("`skip_accept_encoding` isn't supported")

        self._request_url = url or "/"
        self._validate_path(url)  # type: ignore[attr-defined]

        # A ":" in the host means an IPv6 literal, which must be bracketed so
        # the trailing ":port" stays unambiguous.
        if ":" in self.host:
            authority = f"[{self.host}]:{self.port or 443}"
        else:
            authority = f"{self.host}:{self.port or 443}"

        self._headers.append((b":scheme", b"https"))
        self._headers.append((b":method", method.encode()))
        self._headers.append((b":authority", authority.encode()))
        self._headers.append((b":path", url.encode()))

        with self._h2_conn as conn:
            self._h2_stream = conn.get_next_available_stream_id()

    def putheader(self, header: str | bytes, *values: str | bytes) -> None:
        """Queue one `(header, value)` pair per value for the pending request.

        The header name is encoded to bytes and lower-cased before validation.

        :raises ValueError: if the name or any value fails validation.
        """
        # TODO SKIPPABLE_HEADERS from urllib3 are ignored.
        header = header.encode() if isinstance(header, str) else header
        header = header.lower()  # A lot of upstream code uses capitalized headers.
        if not _is_legal_header_name(header):
            raise ValueError(f"Illegal header name {str(header)}")

        for value in values:
            value = value.encode() if isinstance(value, str) else value
            if _is_illegal_header_value(value):
                raise ValueError(f"Illegal header value {str(value)}")
            self._headers.append((header, value))

    def endheaders(self, message_body: typing.Any = None) -> None:  # type: ignore[override]
        """Send the queued headers on the current stream.

        The stream is ended together with the headers when `message_body` is
        `None`. The queued header list is cleared afterwards.

        :raises ConnectionError: if `putrequest` has not been called.
        """
        if self._h2_stream is None:
            raise ConnectionError("Must call `putrequest` first.")

        with self._h2_conn as conn:
            conn.send_headers(
                stream_id=self._h2_stream,
                headers=self._headers,
                end_stream=(message_body is None),
            )
            if data_to_send := conn.data_to_send():
                self.sock.sendall(data_to_send)
        self._headers = []  # Reset headers for the next request.

    def send(self, data: typing.Any) -> None:
        """Send data to the server.
        `data` can be: `str`, `bytes`, an iterable, or file-like objects
        that support a .read() method.

        The stream is ended once all of `data` has been handed to `h2`.

        :raises ConnectionError: if `putrequest` has not been called.
        :raises TypeError: if `data` is none of the supported kinds.
        """
        if self._h2_stream is None:
            raise ConnectionError("Must call `putrequest` first.")

        with self._h2_conn as conn:
            # Flush anything h2 already has buffered before queuing body data.
            if data_to_send := conn.data_to_send():
                self.sock.sendall(data_to_send)

            if hasattr(data, "read"):  # file-like objects
                while True:
                    chunk = data.read(self.blocksize)
                    if not chunk:
                        break
                    if isinstance(chunk, str):
                        chunk = chunk.encode()  # pragma: no cover
                    conn.send_data(self._h2_stream, chunk, end_stream=False)
                    if data_to_send := conn.data_to_send():
                        self.sock.sendall(data_to_send)
                conn.end_stream(self._h2_stream)
                return

            if isinstance(data, str):  # str -> bytes
                data = data.encode()

            try:
                if isinstance(data, bytes):
                    conn.send_data(self._h2_stream, data, end_stream=True)
                    if data_to_send := conn.data_to_send():
                        self.sock.sendall(data_to_send)
                else:
                    for chunk in data:
                        conn.send_data(self._h2_stream, chunk, end_stream=False)
                        if data_to_send := conn.data_to_send():
                            self.sock.sendall(data_to_send)
                    conn.end_stream(self._h2_stream)
            except TypeError:
                raise TypeError(
                    "`data` should be str, bytes, iterable, or file. got %r"
                    % type(data)
                )

    def set_tunnel(
        self,
        host: str,
        port: int | None = None,
        headers: typing.Mapping[str, str] | None = None,
        scheme: str = "http",
    ) -> None:
        """Always raise: tunnels are not supported by this connection class.

        :raises NotImplementedError: unconditionally.
        """
        raise NotImplementedError(
            "HTTP/2 does not support setting up a tunnel through a proxy"
        )

    def getresponse(  # type: ignore[override]
        self,
    ) -> HTTP2Response:
        """Read from the socket until the stream ends and build the response.

        Response headers (minus `:status`) are collected into an
        `HTTPHeaderDict`; body data is accumulated in memory and each received
        chunk is acknowledged to `h2` for flow control.

        :raises ConnectionError: if the peer closes the socket before a
            stream-end event is seen.
        """
        status = None
        data = bytearray()
        with self._h2_conn as conn:
            end_stream = False
            while not end_stream:
                # TODO: Arbitrary read value.
                received_data = self.sock.recv(65535)
                if not received_data:
                    # An empty read means the peer closed the connection.
                    # Without this check the loop would spin forever, since
                    # no further event could ever set `end_stream`.
                    raise ConnectionError(
                        "Connection closed before the HTTP/2 response completed"
                    )

                events = conn.receive_data(received_data)
                for event in events:
                    if isinstance(event, h2.events.ResponseReceived):
                        headers = HTTPHeaderDict()
                        for header, value in event.headers:
                            if header == b":status":
                                status = int(value.decode())
                            else:
                                headers.add(
                                    header.decode("ascii"), value.decode("ascii")
                                )

                    elif isinstance(event, h2.events.DataReceived):
                        data += event.data
                        conn.acknowledge_received_data(
                            event.flow_controlled_length, event.stream_id
                        )

                    elif isinstance(event, h2.events.StreamEnded):
                        end_stream = True

                # Send whatever h2 queued in reaction to the received frames.
                if data_to_send := conn.data_to_send():
                    self.sock.sendall(data_to_send)

        assert status is not None
        return HTTP2Response(
            status=status,
            headers=headers,
            request_url=self._request_url,
            data=bytes(data),
        )

    def request(  # type: ignore[override]
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        *,
        preload_content: bool = True,
        decode_content: bool = True,
        enforce_content_length: bool = True,
        **kwargs: typing.Any,
    ) -> None:
        """Send an HTTP/2 request

        A `transfer-encoding: chunked` header is dropped, and a default
        `user-agent` header is added when none was supplied. A falsy `body`
        is treated as no body.

        `preload_content`, `decode_content` and `enforce_content_length` are
        accepted but not used by this method.
        """
        if "chunked" in kwargs:
            # TODO this is often present from upstream.
            # raise NotImplementedError("`chunked` isn't supported with HTTP/2")
            pass

        if self.sock is not None:
            self.sock.settimeout(self.timeout)

        self.putrequest(method, url)

        headers = headers or {}
        for k, v in headers.items():
            if k.lower() == "transfer-encoding" and v == "chunked":
                continue
            else:
                self.putheader(k, v)

        if b"user-agent" not in dict(self._headers):
            self.putheader(b"user-agent", _get_default_user_agent())

        if body:
            self.endheaders(message_body=body)
            self.send(body)
        else:
            self.endheaders()

    def close(self) -> None:
        """Close the HTTP/2 connection, reset its state, then close the socket."""
        with self._h2_conn as conn:
            try:
                conn.close_connection()
                if data := conn.data_to_send():
                    self.sock.sendall(data)
            except Exception:
                # Best-effort shutdown: any failure here is ignored so the
                # state reset and the parent close below still run.
                pass

        # Reset all our HTTP/2 connection state.
        self._h2_conn = self._new_h2_conn()
        self._h2_stream = None
        self._headers = []

        super().close()


class HTTP2Response(BaseHTTPResponse):
    """Response object holding a fully-read HTTP/2 response body."""

    # TODO: This is a woefully incomplete response object, but works for non-streaming.
    def __init__(
        self,
        status: int,
        headers: HTTPHeaderDict,
        request_url: str,
        data: bytes,
        decode_content: bool = False,  # TODO: support decoding
    ) -> None:
        """Store the already-received body and initialize the base response."""
        super().__init__(
            status=status,
            headers=headers,
            # Following CPython, we map HTTP versions to major * 10 + minor integers
            version=20,
            version_string="HTTP/2",
            # No reason phrase in HTTP/2
            reason=None,
            decode_content=decode_content,
            request_url=request_url,
        )
        self._data = data
        self.length_remaining = 0

    @property
    def data(self) -> bytes:
        """The complete response body as passed to the constructor."""
        return self._data

    def get_redirect_location(self) -> None:
        """Always return `None`."""
        return None

    def close(self) -> None:
        """No-op."""
        pass