XHTTP client: Move x_padding into Referer header (#4298)

""Breaking"": Update the server side first, then client
This commit is contained in:
rPDmYQ 2025-01-18 20:05:19 +08:00 committed by GitHub
parent 30cb22afb1
commit 14a6636a41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 182 additions and 59 deletions

View File

@ -5,6 +5,7 @@ import (
"context" "context"
_ "embed" _ "embed"
"encoding/base64" "encoding/base64"
"encoding/json"
"net/http" "net/http"
"time" "time"
@ -17,6 +18,12 @@ import (
//go:embed dialer.html //go:embed dialer.html
var webpage []byte var webpage []byte
type task struct {
Method string `json:"method"`
URL string `json:"url"`
Extra any `json:"extra,omitempty"`
}
var conns chan *websocket.Conn var conns chan *websocket.Conn
var upgrader = &websocket.Upgrader{ var upgrader = &websocket.Upgrader{
@ -55,23 +62,69 @@ func HasBrowserDialer() bool {
return conns != nil return conns != nil
} }
type webSocketExtra struct {
Protocol string `json:"protocol,omitempty"`
}
func DialWS(uri string, ed []byte) (*websocket.Conn, error) { func DialWS(uri string, ed []byte) (*websocket.Conn, error) {
data := []byte("WS " + uri) task := task{
Method: "WS",
URL: uri,
}
if ed != nil { if ed != nil {
data = append(data, " "+base64.RawURLEncoding.EncodeToString(ed)...) task.Extra = webSocketExtra{
Protocol: base64.RawURLEncoding.EncodeToString(ed),
}
} }
return dialRaw(data) return dialTask(task)
} }
func DialGet(uri string) (*websocket.Conn, error) { type httpExtra struct {
data := []byte("GET " + uri) Referrer string `json:"referrer,omitempty"`
return dialRaw(data) Headers map[string]string `json:"headers,omitempty"`
} }
func DialPost(uri string, payload []byte) error { func httpExtraFromHeaders(headers http.Header) *httpExtra {
data := []byte("POST " + uri) if len(headers) == 0 {
conn, err := dialRaw(data) return nil
}
extra := httpExtra{}
if referrer := headers.Get("Referer"); referrer != "" {
extra.Referrer = referrer
headers.Del("Referer")
}
if len(headers) > 0 {
extra.Headers = make(map[string]string)
for header := range headers {
extra.Headers[header] = headers.Get(header)
}
}
return &extra
}
func DialGet(uri string, headers http.Header) (*websocket.Conn, error) {
task := task{
Method: "GET",
URL: uri,
Extra: httpExtraFromHeaders(headers),
}
return dialTask(task)
}
func DialPost(uri string, headers http.Header, payload []byte) error {
task := task{
Method: "POST",
URL: uri,
Extra: httpExtraFromHeaders(headers),
}
conn, err := dialTask(task)
if err != nil { if err != nil {
return err return err
} }
@ -90,7 +143,12 @@ func DialPost(uri string, payload []byte) error {
return nil return nil
} }
func dialRaw(data []byte) (*websocket.Conn, error) { func dialTask(task task) (*websocket.Conn, error) {
data, err := json.Marshal(task)
if err != nil {
return nil, err
}
var conn *websocket.Conn var conn *websocket.Conn
for { for {
conn = <-conns conn = <-conns
@ -100,7 +158,7 @@ func dialRaw(data []byte) (*websocket.Conn, error) {
break break
} }
} }
err := CheckOK(conn) err = CheckOK(conn)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -14,10 +14,28 @@
let upstreamGetCount = 0; let upstreamGetCount = 0;
let upstreamWsCount = 0; let upstreamWsCount = 0;
let upstreamPostCount = 0; let upstreamPostCount = 0;
function prepareRequestInit(extra) {
const requestInit = {};
if (extra.referrer) {
// note: we have to strip the protocol and host part.
// Browsers disallow that, and will reset the value to current page if attempted.
const referrer = URL.parse(extra.referrer);
requestInit.referrer = referrer.pathname + referrer.search + referrer.hash;
requestInit.referrerPolicy = "unsafe-url";
}
if (extra.headers) {
requestInit.headers = extra.headers;
}
return requestInit;
}
let check = function () { let check = function () {
if (clientIdleCount > 0) { if (clientIdleCount > 0) {
return; return;
}; }
clientIdleCount += 1; clientIdleCount += 1;
console.log("Prepare", url); console.log("Prepare", url);
let ws = new WebSocket(url); let ws = new WebSocket(url);
@ -29,12 +47,12 @@
// double-checking that this continues to work // double-checking that this continues to work
ws.onmessage = function (event) { ws.onmessage = function (event) {
clientIdleCount -= 1; clientIdleCount -= 1;
let [method, url, protocol] = event.data.split(" "); let task = JSON.parse(event.data);
switch (method) { switch (task.method) {
case "WS": { case "WS": {
upstreamWsCount += 1; upstreamWsCount += 1;
console.log("Dial WS", url, protocol); console.log("Dial WS", task.url, task.extra.protocol);
const wss = new WebSocket(url, protocol); const wss = new WebSocket(task.url, task.extra.protocol);
wss.binaryType = "arraybuffer"; wss.binaryType = "arraybuffer";
let opened = false; let opened = false;
ws.onmessage = function (event) { ws.onmessage = function (event) {
@ -60,10 +78,12 @@
wss.close() wss.close()
}; };
break; break;
}; }
case "GET": { case "GET": {
(async () => { (async () => {
console.log("Dial GET", url); const requestInit = prepareRequestInit(task.extra);
console.log("Dial GET", task.url);
ws.send("ok"); ws.send("ok");
const controller = new AbortController(); const controller = new AbortController();
@ -83,58 +103,62 @@
ws.onclose = (event) => { ws.onclose = (event) => {
try { try {
reader && reader.cancel(); reader && reader.cancel();
} catch(e) {}; } catch(e) {}
try { try {
controller.abort(); controller.abort();
} catch(e) {}; } catch(e) {}
}; };
try { try {
upstreamGetCount += 1; upstreamGetCount += 1;
const response = await fetch(url, {signal: controller.signal});
requestInit.signal = controller.signal;
const response = await fetch(task.url, requestInit);
const body = await response.body; const body = await response.body;
reader = body.getReader(); reader = body.getReader();
while (true) { while (true) {
const { done, value } = await reader.read(); const { done, value } = await reader.read();
ws.send(value); if (value) ws.send(value); // don't send back "undefined" string when received nothing
if (done) break; if (done) break;
}; }
} finally { } finally {
upstreamGetCount -= 1; upstreamGetCount -= 1;
console.log("Dial GET DONE, remaining: ", upstreamGetCount); console.log("Dial GET DONE, remaining: ", upstreamGetCount);
ws.close(); ws.close();
}; }
})(); })();
break; break;
}; }
case "POST": { case "POST": {
upstreamPostCount += 1; upstreamPostCount += 1;
console.log("Dial POST", url);
const requestInit = prepareRequestInit(task.extra);
requestInit.method = "POST";
console.log("Dial POST", task.url);
ws.send("ok"); ws.send("ok");
ws.onmessage = async (event) => { ws.onmessage = async (event) => {
try { try {
const response = await fetch( requestInit.body = event.data;
url, const response = await fetch(task.url, requestInit);
{method: "POST", body: event.data}
);
if (response.ok) { if (response.ok) {
ws.send("ok"); ws.send("ok");
} else { } else {
console.error("bad status code"); console.error("bad status code");
ws.send("fail"); ws.send("fail");
}; }
} finally { } finally {
upstreamPostCount -= 1; upstreamPostCount -= 1;
console.log("Dial POST DONE, remaining: ", upstreamPostCount); console.log("Dial POST DONE, remaining: ", upstreamPostCount);
ws.close(); ws.close();
}; }
}; };
break; break;
}; }
}; }
check(); check();
}; };

View File

@ -5,13 +5,15 @@ import (
"io" "io"
gonet "net" gonet "net"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/transport/internet/browser_dialer" "github.com/xtls/xray-core/transport/internet/browser_dialer"
"github.com/xtls/xray-core/transport/internet/websocket" "github.com/xtls/xray-core/transport/internet/websocket"
) )
// implements splithttp.DialerClient in terms of browser dialer // BrowserDialerClient implements splithttp.DialerClient in terms of browser dialer
// has no fields because everything is global state :O) type BrowserDialerClient struct {
type BrowserDialerClient struct{} transportConfig *Config
}
func (c *BrowserDialerClient) IsClosed() bool { func (c *BrowserDialerClient) IsClosed() bool {
panic("not implemented yet") panic("not implemented yet")
@ -19,10 +21,10 @@ func (c *BrowserDialerClient) IsClosed() bool {
func (c *BrowserDialerClient) OpenStream(ctx context.Context, url string, body io.Reader, uploadOnly bool) (io.ReadCloser, gonet.Addr, gonet.Addr, error) { func (c *BrowserDialerClient) OpenStream(ctx context.Context, url string, body io.Reader, uploadOnly bool) (io.ReadCloser, gonet.Addr, gonet.Addr, error) {
if body != nil { if body != nil {
panic("not implemented yet") return nil, nil, nil, errors.New("bidirectional streaming for browser dialer not implemented yet")
} }
conn, err := browser_dialer.DialGet(url) conn, err := browser_dialer.DialGet(url, c.transportConfig.GetRequestHeader())
dummyAddr := &gonet.IPAddr{} dummyAddr := &gonet.IPAddr{}
if err != nil { if err != nil {
return nil, dummyAddr, dummyAddr, err return nil, dummyAddr, dummyAddr, err
@ -37,7 +39,7 @@ func (c *BrowserDialerClient) PostPacket(ctx context.Context, url string, body i
return err return err
} }
err = browser_dialer.DialPost(url, bytes) err = browser_dialer.DialPost(url, c.transportConfig.GetRequestHeader(), bytes)
if err != nil { if err != nil {
return err return err
} }

View File

@ -4,6 +4,7 @@ import (
"crypto/rand" "crypto/rand"
"math/big" "math/big"
"net/http" "net/http"
"net/url"
"strings" "strings"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
@ -11,6 +12,8 @@ import (
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
) )
const paddingQuery = "x_padding"
func (c *Config) GetNormalizedPath() string { func (c *Config) GetNormalizedPath() string {
pathAndQuery := strings.SplitN(c.Path, "?", 2) pathAndQuery := strings.SplitN(c.Path, "?", 2)
path := pathAndQuery[0] path := pathAndQuery[0]
@ -39,11 +42,6 @@ func (c *Config) GetNormalizedQuery() string {
} }
query += "x_version=" + core.Version() query += "x_version=" + core.Version()
paddingLen := c.GetNormalizedXPaddingBytes().rand()
if paddingLen > 0 {
query += "&x_padding=" + strings.Repeat("0", int(paddingLen))
}
return query return query
} }
@ -53,6 +51,28 @@ func (c *Config) GetRequestHeader() http.Header {
header.Add(k, v) header.Add(k, v)
} }
paddingLen := c.GetNormalizedXPaddingBytes().rand()
if paddingLen > 0 {
query, err := url.ParseQuery(c.GetNormalizedQuery())
if err != nil {
query = url.Values{}
}
// https://www.rfc-editor.org/rfc/rfc7541.html#appendix-B
// h2's HPACK Header Compression feature employs a huffman encoding using a static table.
// 'X' is assigned an 8 bit code, so HPACK compression won't change actual padding length on the wire.
// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.1.2-2
// h3's similar QPACK feature uses the same huffman table.
query.Set(paddingQuery, strings.Repeat("X", int(paddingLen)))
referrer := url.URL{
Scheme: "https", // maybe http actually, but this part is not being checked
Host: c.Host,
Path: c.GetNormalizedPath(),
RawQuery: query.Encode(),
}
header.Set("Referer", referrer.String())
}
return header return header
} }
@ -63,7 +83,7 @@ func (c *Config) WriteResponseHeader(writer http.ResponseWriter) {
writer.Header().Set("X-Version", core.Version()) writer.Header().Set("X-Version", core.Version())
paddingLen := c.GetNormalizedXPaddingBytes().rand() paddingLen := c.GetNormalizedXPaddingBytes().rand()
if paddingLen > 0 { if paddingLen > 0 {
writer.Header().Set("X-Padding", strings.Repeat("0", int(paddingLen))) writer.Header().Set("X-Padding", strings.Repeat("X", int(paddingLen)))
} }
} }

View File

@ -53,8 +53,8 @@ var (
func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (DialerClient, *XmuxClient) { func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (DialerClient, *XmuxClient) {
realityConfig := reality.ConfigFromStreamSettings(streamSettings) realityConfig := reality.ConfigFromStreamSettings(streamSettings)
if browser_dialer.HasBrowserDialer() && realityConfig != nil { if browser_dialer.HasBrowserDialer() && realityConfig == nil {
return &BrowserDialerClient{}, nil return &BrowserDialerClient{transportConfig: streamSettings.ProtocolSettings.(*Config)}, nil
} }
globalDialerAccess.Lock() globalDialerAccess.Lock()
@ -367,15 +367,18 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
}, },
} }
var err error
if mode == "stream-one" { if mode == "stream-one" {
requestURL.Path = transportConfiguration.GetNormalizedPath() requestURL.Path = transportConfiguration.GetNormalizedPath()
if xmuxClient != nil { if xmuxClient != nil {
xmuxClient.LeftRequests.Add(-1) xmuxClient.LeftRequests.Add(-1)
} }
conn.reader, conn.remoteAddr, conn.localAddr, _ = httpClient.OpenStream(ctx, requestURL.String(), reader, false) conn.reader, conn.remoteAddr, conn.localAddr, err = httpClient.OpenStream(ctx, requestURL.String(), reader, false)
if err != nil { // browser dialer only
return nil, err
}
return stat.Connection(&conn), nil return stat.Connection(&conn), nil
} else { // stream-down } else { // stream-down
var err error
if xmuxClient2 != nil { if xmuxClient2 != nil {
xmuxClient2.LeftRequests.Add(-1) xmuxClient2.LeftRequests.Add(-1)
} }
@ -388,7 +391,10 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
if xmuxClient != nil { if xmuxClient != nil {
xmuxClient.LeftRequests.Add(-1) xmuxClient.LeftRequests.Add(-1)
} }
httpClient.OpenStream(ctx, requestURL.String(), reader, true) _, _, _, err = httpClient.OpenStream(ctx, requestURL.String(), reader, true)
if err != nil { // browser dialer only
return nil, err
}
return stat.Connection(&conn), nil return stat.Connection(&conn), nil
} }
@ -428,8 +434,6 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
// can reassign Path (potentially concurrently) // can reassign Path (potentially concurrently)
url := requestURL url := requestURL
url.Path += "/" + strconv.FormatInt(seq, 10) url.Path += "/" + strconv.FormatInt(seq, 10)
// reassign query to get different padding
url.RawQuery = transportConfiguration.GetNormalizedQuery()
seq += 1 seq += 1

View File

@ -7,6 +7,7 @@ import (
"io" "io"
gonet "net" gonet "net"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -110,9 +111,23 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
} }
validRange := h.config.GetNormalizedXPaddingBytes() validRange := h.config.GetNormalizedXPaddingBytes()
x_padding := int32(len(request.URL.Query().Get("x_padding"))) paddingLength := -1
if validRange.To > 0 && (x_padding < validRange.From || x_padding > validRange.To) {
errors.LogInfo(context.Background(), "invalid x_padding length:", x_padding) if referrerPadding := request.Header.Get("Referer"); referrerPadding != "" {
// Browser dialer cannot control the host part of referrer header, so only check the query
if referrerURL, err := url.Parse(referrerPadding); err == nil {
if query := referrerURL.Query(); query.Has(paddingQuery) {
paddingLength = len(query.Get(paddingQuery))
}
}
}
if paddingLength == -1 {
paddingLength = len(request.URL.Query().Get(paddingQuery))
}
if validRange.To > 0 && (int32(paddingLength) < validRange.From || int32(paddingLength) > validRange.To) {
errors.LogInfo(context.Background(), "invalid x_padding length:", int32(paddingLength))
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
return return
} }
@ -185,10 +200,10 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
return return
} }
payload, err := io.ReadAll(request.Body) payload, err := io.ReadAll(io.LimitReader(request.Body, int64(scMaxEachPostBytes)+1))
if len(payload) > scMaxEachPostBytes { if len(payload) > scMaxEachPostBytes {
errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request had size ", len(payload), ". Adjust scMaxEachPostBytes on the server to be at least as large as client.") errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request size exceed it. Adjust scMaxEachPostBytes on the server to be at least as large as client.")
writer.WriteHeader(http.StatusRequestEntityTooLarge) writer.WriteHeader(http.StatusRequestEntityTooLarge)
return return
} }