Xray-core/proxy/vless/inbound/inbound.go
yuhan6665 009037b1e2 Add separate uplink/downlink flag for direct copy
For each connection, XTLS needs four flags (one each for the uplink reader, uplink writer, downlink reader, and downlink writer) to decide when it switches to direct copy.
In the past, there was only one flag for reading and one for writing.
If a service has both an XTLS inbound and an XTLS outbound, those two flags may be corrupted by signals coming from different directions.
2025-01-26 12:18:16 -05:00

603 lines
19 KiB
Go

package inbound
import (
"bytes"
"context"
gotls "crypto/tls"
"io"
"reflect"
"strconv"
"strings"
"time"
"unsafe"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/retry"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns"
feature_inbound "github.com/xtls/xray-core/features/inbound"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/proxy/vless"
"github.com/xtls/xray-core/proxy/vless/encoding"
"github.com/xtls/xray-core/transport/internet/reality"
"github.com/xtls/xray-core/transport/internet/stat"
"github.com/xtls/xray-core/transport/internet/tls"
)
func init() {
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
var dc dns.Client
if err := core.RequireFeatures(ctx, func(d dns.Client) error {
dc = d
return nil
}); err != nil {
return nil, err
}
c := config.(*Config)
validator := new(vless.MemoryValidator)
for _, user := range c.Clients {
u, err := user.ToMemoryUser()
if err != nil {
return nil, errors.New("failed to get VLESS user").Base(err).AtError()
}
if err := validator.Add(u); err != nil {
return nil, errors.New("failed to initiate user").Base(err).AtError()
}
}
return New(ctx, c, dc, validator)
}))
}
// Handler is an inbound connection handler that handles messages in VLess protocol.
//
// A Handler is built by New; Process is the per-connection entry point.
type Handler struct {
// inboundHandlerManager is the inbound manager feature fetched from the core instance in New.
inboundHandlerManager feature_inbound.Manager
// policyManager supplies the per-level session policy (timeouts, buffer settings) used by Process.
policyManager policy.Manager
// validator stores the VLESS users; the UserManager methods and request decoding go through it.
validator vless.Validator
// dns is the DNS client passed to New.
dns dns.Client
// fallbacks is indexed as fallbacks[name][alpn][path]; it stays nil when the
// config declares no fallbacks, which disables the fallback logic in Process.
fallbacks map[string]map[string]map[string]*Fallback // or nil
// regexps map[string]*regexp.Regexp // or nil
}
// New creates a new VLess inbound handler.
//
// The inbound manager and policy manager are taken from the core instance
// stored in ctx. When config.Fallbacks is non-nil, the fallbacks are indexed as
// fallbacks[name][alpn][path] and then completed in three passes so that a
// lookup for a specific name/alpn can fall through to the entries registered
// under the empty ("default") name and the empty alpn.
func New(ctx context.Context, config *Config, dc dns.Client, validator vless.Validator) (*Handler, error) {
	instance := core.MustFromContext(ctx)
	h := &Handler{
		inboundHandlerManager: instance.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager),
		policyManager:         instance.GetFeature(policy.ManagerType()).(policy.Manager),
		dns:                   dc,
		validator:             validator,
	}

	if config.Fallbacks == nil {
		return h, nil
	}

	// Build the raw name -> alpn -> path index from the configuration.
	h.fallbacks = make(map[string]map[string]map[string]*Fallback)
	// h.regexps = make(map[string]*regexp.Regexp)
	for _, fb := range config.Fallbacks {
		byAlpn := h.fallbacks[fb.Name]
		if byAlpn == nil {
			byAlpn = make(map[string]map[string]*Fallback)
			h.fallbacks[fb.Name] = byAlpn
		}
		byPath := byAlpn[fb.Alpn]
		if byPath == nil {
			byPath = make(map[string]*Fallback)
			byAlpn[fb.Alpn] = byPath
		}
		byPath[fb.Path] = fb
		/*
			if fb.Path != "" {
				if r, err := regexp.Compile(fb.Path); err != nil {
					return nil, errors.New("invalid path regexp").Base(err).AtError()
				} else {
					h.regexps[fb.Path] = r
				}
			}
		*/
	}

	// The entries registered under the empty name act as defaults for every
	// named entry. The map value itself is never replaced below, only filled.
	defaults := h.fallbacks[""]

	// Pass 1: every named entry gets an (initially empty) path table for each
	// alpn known to the default name.
	if defaults != nil {
		for name, byAlpn := range h.fallbacks {
			if name == "" {
				continue
			}
			for alpn := range defaults {
				if byAlpn[alpn] == nil {
					byAlpn[alpn] = make(map[string]*Fallback)
				}
			}
		}
	}

	// Pass 2: within each name, paths registered under the empty alpn are
	// copied into every other alpn that does not define that path itself.
	for _, byAlpn := range h.fallbacks {
		anyAlpn := byAlpn[""]
		if anyAlpn == nil {
			continue
		}
		for alpn, byPath := range byAlpn {
			if alpn == "" { // && alpn != "h2" {
				continue
			}
			for path, fb := range anyAlpn {
				if byPath[path] == nil {
					byPath[path] = fb
				}
			}
		}
	}

	// Pass 3: named entries inherit every alpn/path of the default name that
	// they do not define themselves.
	if defaults != nil {
		for name, byAlpn := range h.fallbacks {
			if name == "" {
				continue
			}
			for alpn, defaultPaths := range defaults {
				for path, fb := range defaultPaths {
					if byAlpn[alpn][path] == nil {
						byAlpn[alpn][path] = fb
					}
				}
			}
		}
	}

	return h, nil
}
// isMuxAndNotXUDP reports whether request is a Mux request whose first payload
// does not look like XUDP.
//
// It returns false for any non-Mux command and true when fewer than 7 bytes
// are available in first. Otherwise the payload counts as XUDP only when bytes
// 2 and 3 (the ID, high and low byte) are zero and byte 6 (the network type)
// equals 2, i.e. UDP.
func isMuxAndNotXUDP(request *protocol.RequestHeader, first *buf.Buffer) bool {
	switch {
	case request.Command != protocol.RequestCommandMux:
		return false
	case first.Len() < 7:
		// Too short to inspect: treat as plain Mux.
		return true
	}

	header := first.Bytes()
	looksLikeXUDP := header[2] == 0 && header[3] == 0 && header[6] == 2
	return !looksLikeXUDP
}
// Close implements common.Closable.Close().
//
// It closes the user validator through common.Close and returns that result
// passed through errors.Combine.
func (h *Handler) Close() error {
	validatorErr := common.Close(h.validator)
	return errors.Combine(validatorErr)
}
// AddUser implements proxy.UserManager.AddUser().
//
// The user is handed to the validator; the validator's error, if any, is
// returned unchanged. ctx is not used.
func (h *Handler) AddUser(ctx context.Context, u *protocol.MemoryUser) error {
	if err := h.validator.Add(u); err != nil {
		return err
	}
	return nil
}
// RemoveUser implements proxy.UserManager.RemoveUser().
//
// e is forwarded to the validator's Del method; the validator's error, if any,
// is returned unchanged. ctx is not used.
func (h *Handler) RemoveUser(ctx context.Context, e string) error {
	if err := h.validator.Del(e); err != nil {
		return err
	}
	return nil
}
// GetUser implements proxy.UserManager.GetUser().
//
// It returns whatever the validator's GetByEmail lookup yields for email.
// ctx is not used.
func (h *Handler) GetUser(ctx context.Context, email string) *protocol.MemoryUser {
	user := h.validator.GetByEmail(email)
	return user
}
// GetUsers implements proxy.UserManager.GetUsers().
//
// It returns the result of the validator's GetAll call. ctx is not used.
func (h *Handler) GetUsers(ctx context.Context) []*protocol.MemoryUser {
	users := h.validator.GetAll()
	return users
}
// GetUsersCount implements proxy.UserManager.GetUsersCount().
//
// It returns the count reported by the validator. The context argument is
// ignored.
func (h *Handler) GetUsersCount(context.Context) int64 {
	count := h.validator.GetCount()
	return count
}
// Network implements proxy.Inbound.Network().
//
// The handler accepts TCP and UNIX-socket connections.
func (*Handler) Network() []net.Network {
	supported := make([]net.Network, 0, 2)
	supported = append(supported, net.Network_TCP, net.Network_UNIX)
	return supported
}
// Process implements proxy.Inbound.Process().
//
// It handles a single accepted connection:
//
//  1. The first chunk of data is read under the level-0 handshake deadline and
//     decoded as a VLESS request header.
//  2. If decoding fails (or fewer than 18 bytes arrived) and fallbacks are
//     configured, the connection is relayed to the fallback selected by TLS
//     server name, negotiated ALPN and HTTP request path, optionally preceded
//     by a PROXY protocol v1/v2 header.
//  3. If decoding fails and no fallback applies, the request is logged as
//     rejected (unless the cause is io.EOF) and the error is returned.
//  4. Otherwise the request flow is validated against the account, the request
//     is dispatched and payload is copied in both directions until one side
//     ends or the inactivity timer cancels the context.
//
// The network argument is not used by this implementation. The returned error
// is nil only when both copy directions finish without error.
func (h *Handler) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error {
	// iConn is the connection without the statistics wrapper so that the
	// concrete TLS / REALITY type can be detected below.
	iConn := connection
	if statConn, ok := iConn.(*stat.CounterConnection); ok {
		iConn = statConn.Connection
	}

	sessionPolicy := h.policyManager.ForLevel(0)
	if err := connection.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil {
		return errors.New("unable to set read deadline").Base(err).AtWarning()
	}

	first := buf.FromBytes(make([]byte, buf.Size))
	first.Clear()
	// The read error is deliberately ignored here: only the number of bytes
	// received matters at this point (presumably a failed read surfaces again
	// when the header is decoded from reader — verify against encoding).
	firstLen, _ := first.ReadFrom(connection)
	errors.LogInfo(ctx, "firstLen = ", firstLen)

	// reader replays the bytes already consumed into first before reading
	// further from the connection.
	reader := &buf.BufferedReader{
		Reader: buf.NewReader(connection),
		Buffer: buf.MultiBuffer{first},
	}

	var request *protocol.RequestHeader
	var requestAddons *encoding.Addons
	var err error

	napfb := h.fallbacks
	isfb := napfb != nil

	if isfb && firstLen < 18 {
		err = errors.New("fallback directly")
	} else {
		request, requestAddons, isfb, err = encoding.DecodeRequestHeader(isfb, first, reader, h.validator)
	}

	if err != nil {
		if isfb {
			if err := connection.SetReadDeadline(time.Time{}); err != nil {
				errors.LogWarningInner(ctx, err, "unable to set back read deadline")
			}
			errors.LogInfoInner(ctx, err, "fallback starts")

			// Collect the TLS / REALITY handshake facts used to pick a fallback.
			name := ""
			alpn := ""
			if tlsConn, ok := iConn.(*tls.Conn); ok {
				cs := tlsConn.ConnectionState()
				name = cs.ServerName
				alpn = cs.NegotiatedProtocol
				errors.LogInfo(ctx, "realName = "+name)
				errors.LogInfo(ctx, "realAlpn = "+alpn)
			} else if realityConn, ok := iConn.(*reality.Conn); ok {
				cs := realityConn.ConnectionState()
				name = cs.ServerName
				alpn = cs.NegotiatedProtocol
				errors.LogInfo(ctx, "realName = "+name)
				errors.LogInfo(ctx, "realAlpn = "+alpn)
			}
			name = strings.ToLower(name)
			alpn = strings.ToLower(alpn)

			// When there is no exact entry for the server name, use the longest
			// configured name that is a substring of it.
			if len(napfb) > 1 || napfb[""] == nil {
				if name != "" && napfb[name] == nil {
					match := ""
					for n := range napfb {
						if n != "" && strings.Contains(name, n) && len(n) > len(match) {
							match = n
						}
					}
					name = match
				}
			}

			if napfb[name] == nil {
				name = ""
			}
			apfb := napfb[name]
			if apfb == nil {
				return errors.New(`failed to find the default "name" config`).AtWarning()
			}

			if apfb[alpn] == nil {
				alpn = ""
			}
			pfb := apfb[alpn]
			if pfb == nil {
				return errors.New(`failed to find the default "alpn" config`).AtWarning()
			}

			path := ""
			if len(pfb) > 1 || pfb[""] == nil {
				/*
					if lines := bytes.Split(firstBytes, []byte{'\r', '\n'}); len(lines) > 1 {
						if s := bytes.Split(lines[0], []byte{' '}); len(s) == 3 {
							if len(s[0]) < 8 && len(s[1]) > 0 && len(s[2]) == 8 {
								errors.New("realPath = " + string(s[1])).AtInfo().WriteToLog(sid)
								for _, fb := range pfb {
									if fb.Path != "" && h.regexps[fb.Path].Match(s[1]) {
										path = fb.Path
										break
									}
								}
							}
						}
					}
				*/
				// Extract the request path from an HTTP/1.x request line: look
				// for " /" within bytes 3..8 (i.e. right after the method) and
				// take everything up to the next '?' or ' ' within the first
				// 64 bytes.
				if firstLen >= 18 && first.Byte(4) != '*' { // not h2c
					firstBytes := first.Bytes()
					for i := 4; i <= 8; i++ { // 5 -> 9
						if firstBytes[i] == '/' && firstBytes[i-1] == ' ' {
							search := len(firstBytes)
							if search > 64 {
								search = 64 // up to about 60
							}
							for j := i + 1; j < search; j++ {
								k := firstBytes[j]
								if k == '\r' || k == '\n' { // avoid logging \r or \n
									break
								}
								if k == '?' || k == ' ' {
									path = string(firstBytes[i:j])
									errors.LogInfo(ctx, "realPath = "+path)
									if pfb[path] == nil {
										path = ""
									}
									break
								}
							}
							break
						}
					}
				}
			}
			fb := pfb[path]
			if fb == nil {
				return errors.New(`failed to find the default "path" config`).AtWarning()
			}

			ctx, cancel := context.WithCancel(ctx)
			timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
			ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)

			// Dial the fallback destination, retrying through
			// retry.ExponentialBackoff(5, 100).
			var conn net.Conn
			if err := retry.ExponentialBackoff(5, 100).On(func() error {
				var dialer net.Dialer
				conn, err = dialer.DialContext(ctx, fb.Type, fb.Dest)
				if err != nil {
					return err
				}
				return nil
			}); err != nil {
				return errors.New("failed to dial to " + fb.Dest).Base(err).AtWarning()
			}
			defer conn.Close()

			serverReader := buf.NewReader(conn)
			serverWriter := buf.NewWriter(conn)

			// postRequest forwards client -> fallback, optionally prefixed by a
			// PROXY protocol header describing the original connection.
			postRequest := func() error {
				defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
				if fb.Xver != 0 {
					// ipType: 4 = IPv4, 6 = IPv6, 0 = unknown (address could
					// not be split into host and port).
					ipType := 4
					remoteAddr, remotePort, err := net.SplitHostPort(connection.RemoteAddr().String())
					if err != nil {
						ipType = 0
					}
					localAddr, localPort, err := net.SplitHostPort(connection.LocalAddr().String())
					if err != nil {
						ipType = 0
					}
					// A colon inside the host part means an IPv6 literal.
					if ipType == 4 && strings.Contains(remoteAddr, ":") {
						ipType = 6
					}
					pro := buf.New()
					defer pro.Release()
					switch fb.Xver {
					case 1:
						if ipType == 0 {
							pro.Write([]byte("PROXY UNKNOWN\r\n"))
							break
						}
						if ipType == 4 {
							pro.Write([]byte("PROXY TCP4 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n"))
						} else {
							pro.Write([]byte("PROXY TCP6 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n"))
						}
					case 2:
						pro.Write([]byte("\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A")) // signature
						if ipType == 0 {
							pro.Write([]byte("\x20\x00\x00\x00")) // v2 + LOCAL + UNSPEC + UNSPEC + 0 bytes
							break
						}
						if ipType == 4 {
							pro.Write([]byte("\x21\x11\x00\x0C")) // v2 + PROXY + AF_INET + STREAM + 12 bytes
							pro.Write(net.ParseIP(remoteAddr).To4())
							pro.Write(net.ParseIP(localAddr).To4())
						} else {
							pro.Write([]byte("\x21\x21\x00\x24")) // v2 + PROXY + AF_INET6 + STREAM + 36 bytes
							pro.Write(net.ParseIP(remoteAddr).To16())
							pro.Write(net.ParseIP(localAddr).To16())
						}
						// Ports are written big-endian, source first.
						p1, _ := strconv.ParseUint(remotePort, 10, 16)
						p2, _ := strconv.ParseUint(localPort, 10, 16)
						pro.Write([]byte{byte(p1 >> 8), byte(p1), byte(p2 >> 8), byte(p2)})
					}
					if err := serverWriter.WriteMultiBuffer(buf.MultiBuffer{pro}); err != nil {
						return errors.New("failed to set PROXY protocol v", fb.Xver).Base(err).AtWarning()
					}
				}
				if err := buf.Copy(reader, serverWriter, buf.UpdateActivity(timer)); err != nil {
					return errors.New("failed to fallback request payload").Base(err).AtInfo()
				}
				return nil
			}

			writer := buf.NewWriter(connection)

			// getResponse forwards fallback -> client.
			getResponse := func() error {
				defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
				if err := buf.Copy(serverReader, writer, buf.UpdateActivity(timer)); err != nil {
					return errors.New("failed to deliver response payload").Base(err).AtInfo()
				}
				return nil
			}

			if err := task.Run(ctx, task.OnSuccess(postRequest, task.Close(serverWriter)), task.OnSuccess(getResponse, task.Close(writer))); err != nil {
				common.Interrupt(serverReader)
				common.Interrupt(serverWriter)
				return errors.New("fallback ends").Base(err).AtInfo()
			}
			return nil
		}

		// No fallback: record the rejection unless the peer simply went away.
		if errors.Cause(err) != io.EOF {
			log.Record(&log.AccessMessage{
				From:   connection.RemoteAddr(),
				To:     "",
				Status: log.AccessRejected,
				Reason: err,
			})
			err = errors.New("invalid request from ", connection.RemoteAddr()).Base(err).AtInfo()
		}
		return err
	}

	if err := connection.SetReadDeadline(time.Time{}); err != nil {
		errors.LogWarningInner(ctx, err, "unable to set back read deadline")
	}
	errors.LogInfo(ctx, "received request for ", request.Destination())

	inbound := session.InboundFromContext(ctx)
	if inbound == nil {
		panic("no inbound metadata")
	}
	inbound.Name = "vless"
	inbound.User = request.User

	account := request.User.Account.(*vless.MemoryAccount)

	responseAddons := &encoding.Addons{
		// Flow: requestAddons.Flow,
	}

	// input and rawInput point at the unexported buffers of the underlying
	// TLS / REALITY connection; they stay nil unless the XRV flow is in use.
	var input *bytes.Reader
	var rawInput *bytes.Buffer
	switch requestAddons.Flow {
	case vless.XRV:
		if account.Flow == requestAddons.Flow {
			inbound.CanSpliceCopy = 2
			switch request.Command {
			case protocol.RequestCommandUDP:
				return errors.New(requestAddons.Flow + " doesn't support UDP").AtWarning()
			case protocol.RequestCommandMux:
				fallthrough // we will break Mux connections that contain TCP requests
			case protocol.RequestCommandTCP:
				var t reflect.Type
				// p is kept as unsafe.Pointer (never parked in a uintptr
				// variable) so the pointer arithmetic below follows the
				// unsafe.Pointer conversion rules.
				var p unsafe.Pointer
				if tlsConn, ok := iConn.(*tls.Conn); ok {
					if tlsConn.ConnectionState().Version != gotls.VersionTLS13 {
						return errors.New(`failed to use `+requestAddons.Flow+`, found outer tls version `, tlsConn.ConnectionState().Version).AtWarning()
					}
					t = reflect.TypeOf(tlsConn.Conn).Elem()
					p = unsafe.Pointer(tlsConn.Conn)
				} else if realityConn, ok := iConn.(*reality.Conn); ok {
					t = reflect.TypeOf(realityConn.Conn).Elem()
					p = unsafe.Pointer(realityConn.Conn)
				} else {
					return errors.New("XTLS only supports TLS and REALITY directly for now.").AtWarning()
				}
				// Refuse to continue when the expected fields are missing:
				// a zero offset would silently alias unrelated memory.
				inputField, ok := t.FieldByName("input")
				if !ok {
					return errors.New(`failed to use ` + requestAddons.Flow + `, field "input" not found in ` + t.String()).AtWarning()
				}
				rawInputField, ok := t.FieldByName("rawInput")
				if !ok {
					return errors.New(`failed to use ` + requestAddons.Flow + `, field "rawInput" not found in ` + t.String()).AtWarning()
				}
				input = (*bytes.Reader)(unsafe.Pointer(uintptr(p) + inputField.Offset))
				rawInput = (*bytes.Buffer)(unsafe.Pointer(uintptr(p) + rawInputField.Offset))
			}
		} else {
			return errors.New("account " + account.ID.String() + " is not able to use the flow " + requestAddons.Flow).AtWarning()
		}
	case "":
		inbound.CanSpliceCopy = 3
		if account.Flow == vless.XRV && (request.Command == protocol.RequestCommandTCP || isMuxAndNotXUDP(request, first)) {
			return errors.New("account " + account.ID.String() + " is rejected since the client flow is empty. Note that the pure TLS proxy has certain TLS in TLS characters.").AtWarning()
		}
	default:
		return errors.New("unknown request flow " + requestAddons.Flow).AtWarning()
	}

	if request.Command != protocol.RequestCommandMux {
		ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
			From:   connection.RemoteAddr(),
			To:     request.Destination(),
			Status: log.AccessAccepted,
			Reason: "",
			Email:  request.User.Email,
		})
	} else if account.Flow == vless.XRV {
		ctx = session.ContextWithAllowedNetwork(ctx, net.Network_UDP)
	}

	// From here on the policy of the authenticated user's level applies.
	sessionPolicy = h.policyManager.ForLevel(request.User.Level)
	ctx, cancel := context.WithCancel(ctx)
	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
	inbound.Timer = timer
	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)

	link, err := dispatcher.Dispatch(ctx, request.Destination())
	if err != nil {
		return errors.New("failed to dispatch request to ", request.Destination()).Base(err).AtWarning()
	}

	serverReader := link.Reader // .(*pipe.Reader)
	serverWriter := link.Writer // .(*pipe.Writer)
	trafficState := proxy.NewTrafficState(account.ID.Bytes())

	// postRequest copies client -> dispatched link.
	postRequest := func() error {
		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)

		// default: clientReader := reader
		clientReader := encoding.DecodeBodyAddons(reader, request, requestAddons)

		var err error

		if requestAddons.Flow == vless.XRV {
			ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice
			clientReader = proxy.NewVisionReader(clientReader, trafficState, true, ctx1)
			err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, nil, true, ctx1)
		} else {
			// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBuffer
			err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer))
		}

		if err != nil {
			return errors.New("failed to transfer request payload").Base(err).AtInfo()
		}

		return nil
	}

	// getResponse writes the response header together with the first payload
	// chunk, then copies dispatched link -> client.
	getResponse := func() error {
		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)

		bufferWriter := buf.NewBufferedWriter(buf.NewWriter(connection))
		if err := encoding.EncodeResponseHeader(bufferWriter, request, responseAddons); err != nil {
			return errors.New("failed to encode response header").Base(err).AtWarning()
		}

		// default: clientWriter := bufferWriter
		clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, false, ctx)
		multiBuffer, err1 := serverReader.ReadMultiBuffer()
		if err1 != nil {
			return err1 // ...
		}
		if err := clientWriter.WriteMultiBuffer(multiBuffer); err != nil {
			return err // ...
		}
		// Flush; bufferWriter.WriteMultiBuffer now is bufferWriter.writer.WriteMultiBuffer
		if err := bufferWriter.SetBuffered(false); err != nil {
			return errors.New("failed to write A response payload").Base(err).AtWarning()
		}

		var err error
		if requestAddons.Flow == vless.XRV {
			err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, nil, false, ctx)
		} else {
			// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBuffer
			err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))
		}
		if err != nil {
			return errors.New("failed to transfer response payload").Base(err).AtInfo()
		}
		// Indicates the end of response payload.
		switch responseAddons.Flow {
		default:
		}

		return nil
	}

	if err := task.Run(ctx, task.OnSuccess(postRequest, task.Close(serverWriter)), getResponse); err != nil {
		common.Interrupt(serverReader)
		common.Interrupt(serverWriter)
		return errors.New("connection ends").Base(err).AtInfo()
	}

	return nil
}