Add WebSocket 0-RTT support (#375)

This commit is contained in:
RPRX 2021-03-14 07:10:10 +00:00 committed by GitHub
parent 9adce5a6c4
commit 60b06877bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 122 additions and 14 deletions

View file

@ -2,9 +2,12 @@ package websocket
import (
"context"
"encoding/base64"
"io"
"time"
"github.com/gorilla/websocket"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/session"
@ -15,10 +18,21 @@ import (
// Dial dials a WebSocket connection to the given destination.
func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))
conn, err := dialWebsocket(ctx, dest, streamSettings)
if err != nil {
return nil, newError("failed to dial WebSocket").Base(err)
var conn net.Conn
if streamSettings.ProtocolSettings.(*Config).Ed > 0 {
ctx, cancel := context.WithCancel(ctx)
conn = &delayDialConn{
dialed: make(chan bool, 1),
cancel: cancel,
ctx: ctx,
dest: dest,
streamSettings: streamSettings,
}
} else {
var err error
if conn, err = dialWebSocket(ctx, dest, streamSettings, nil); err != nil {
return nil, newError("failed to dial WebSocket").Base(err)
}
}
return internet.Connection(conn), nil
}
@ -27,7 +41,7 @@ func init() {
common.Must(internet.RegisterTransportDialer(protocolName, Dial))
}
func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig, ed []byte) (net.Conn, error) {
wsSettings := streamSettings.ProtocolSettings.(*Config)
dialer := &websocket.Dialer{
@ -52,7 +66,12 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in
}
uri := protocol + "://" + host + wsSettings.GetNormalizedPath()
conn, resp, err := dialer.Dial(uri, wsSettings.GetRequestHeader())
header := wsSettings.GetRequestHeader()
if ed != nil {
header.Set("Sec-WebSocket-Protocol", base64.StdEncoding.EncodeToString(ed))
}
conn, resp, err := dialer.Dial(uri, header)
if err != nil {
var reason string
if resp != nil {
@ -61,5 +80,60 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in
return nil, newError("failed to dial to (", uri, "): ", reason).Base(err)
}
return newConnection(conn, conn.RemoteAddr()), nil
return newConnection(conn, conn.RemoteAddr(), nil), nil
}
type delayDialConn struct {
net.Conn
closed bool
dialed chan bool
cancel context.CancelFunc
ctx context.Context
dest net.Destination
streamSettings *internet.MemoryStreamConfig
}
func (d *delayDialConn) Write(b []byte) (int, error) {
if d.closed {
return 0, io.ErrClosedPipe
}
if d.Conn == nil {
ed := b
if len(ed) > int(d.streamSettings.ProtocolSettings.(*Config).Ed) {
ed = nil
}
var err error
if d.Conn, err = dialWebSocket(d.ctx, d.dest, d.streamSettings, ed); err != nil {
d.Close()
return 0, newError("failed to dial WebSocket").Base(err)
}
d.dialed <- true
if ed != nil {
return len(ed), nil
}
}
return d.Conn.Write(b)
}
func (d *delayDialConn) Read(b []byte) (int, error) {
if d.closed {
return 0, io.ErrClosedPipe
}
if d.Conn == nil {
select {
case <-d.ctx.Done():
return 0, io.ErrUnexpectedEOF
case <-d.dialed:
}
}
return d.Conn.Read(b)
}
func (d *delayDialConn) Close() error {
d.closed = true
d.cancel()
if d.Conn == nil {
return nil
}
return d.Conn.Close()
}