diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index b8131b8f..26019bbe 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -218,11 +218,12 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin if !destination.IsValid() { panic("Dispatcher: Invalid destination.") } - ob := session.OutboundFromContext(ctx) - if ob == nil { - ob = &session.Outbound{} - ctx = session.ContextWithOutbound(ctx, ob) + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) == 0 { + outbounds = []*session.Outbound{{}} + ctx = session.ContextWithOutbounds(ctx, outbounds) } + ob := outbounds[len(outbounds) - 1] ob.OriginalTarget = destination ob.Target = destination content := session.ContentFromContext(ctx) @@ -274,11 +275,12 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De if !destination.IsValid() { return newError("Dispatcher: Invalid destination.") } - ob := session.OutboundFromContext(ctx) - if ob == nil { - ob = &session.Outbound{} - ctx = session.ContextWithOutbound(ctx, ob) + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) == 0 { + outbounds = []*session.Outbound{{}} + ctx = session.ContextWithOutbounds(ctx, outbounds) } + ob := outbounds[len(outbounds) - 1] ob.OriginalTarget = destination ob.Target = destination content := session.ContentFromContext(ctx) @@ -368,7 +370,8 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, netw return contentResult, contentErr } func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { - ob := session.OutboundFromContext(ctx) + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() { proxied := hosts.LookupHosts(ob.Target.String()) if proxied != nil { @@ -425,6 +428,7 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport. return } + ob.Tag = handler.Tag() if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil { if tag := handler.Tag(); tag != "" { if inTag == "" { diff --git a/app/dispatcher/fakednssniffer.go b/app/dispatcher/fakednssniffer.go index ad879daf..8d0804de 100644 --- a/app/dispatcher/fakednssniffer.go +++ b/app/dispatcher/fakednssniffer.go @@ -26,11 +26,12 @@ func newFakeDNSSniffer(ctx context.Context) (protocolSnifferWithMetadata, error) return protocolSnifferWithMetadata{}, errNotInit } return protocolSnifferWithMetadata{protocolSniffer: func(ctx context.Context, bytes []byte) (SniffResult, error) { - Target := session.OutboundFromContext(ctx).Target - if Target.Network == net.Network_TCP || Target.Network == net.Network_UDP { - domainFromFakeDNS := fakeDNSEngine.GetDomainFromFakeDNS(Target.Address) + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if ob.Target.Network == net.Network_TCP || ob.Target.Network == net.Network_UDP { + domainFromFakeDNS := fakeDNSEngine.GetDomainFromFakeDNS(ob.Target.Address) if domainFromFakeDNS != "" { - newError("fake dns got domain: ", domainFromFakeDNS, " for ip: ", Target.Address.String()).WriteToLog(session.ExportIDToError(ctx)) + newError("fake dns got domain: ", domainFromFakeDNS, " for ip: ", ob.Target.Address.String()).WriteToLog(session.ExportIDToError(ctx)) return &fakeDNSSniffResult{domainName: domainFromFakeDNS}, nil } } @@ -38,7 +39,7 @@ func newFakeDNSSniffer(ctx context.Context) (protocolSnifferWithMetadata, error) if ipAddressInRangeValueI := ctx.Value(ipAddressInRange); ipAddressInRangeValueI != nil { ipAddressInRangeValue := ipAddressInRangeValueI.(*ipAddressInRangeOpt) if fkr0, ok := fakeDNSEngine.(dns.FakeDNSEngineRev0); ok { - inPool := fkr0.IsIPInIPPool(Target.Address) + inPool := fkr0.IsIPInIPPool(ob.Target.Address) ipAddressInRangeValue.addressInRange = &inPool } } diff --git a/app/proxyman/inbound/worker.go b/app/proxyman/inbound/worker.go index 1fe86655..9a6499f1 100644 --- a/app/proxyman/inbound/worker.go +++ b/app/proxyman/inbound/worker.go @@ -60,7 +60,7 @@ func (w *tcpWorker) callback(conn stat.Connection) { sid := session.NewID() ctx = session.ContextWithID(ctx, sid) - var outbound = &session.Outbound{} + outbounds := []*session.Outbound{{}} if w.recvOrigDest { var dest net.Destination switch getTProxyType(w.stream) { @@ -75,10 +75,10 @@ func (w *tcpWorker) callback(conn stat.Connection) { dest = net.DestinationFromAddr(conn.LocalAddr()) } if dest.IsValid() { - outbound.Target = dest + outbounds[0].Target = dest } } - ctx = session.ContextWithOutbound(ctx, outbound) + ctx = session.ContextWithOutbounds(ctx, outbounds) if w.uplinkCounter != nil || w.downlinkCounter != nil { conn = &stat.CounterConnection{ @@ -309,9 +309,10 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest ctx = session.ContextWithID(ctx, sid) if originalDest.IsValid() { - ctx = session.ContextWithOutbound(ctx, &session.Outbound{ + outbounds := []*session.Outbound{{ Target: originalDest, - }) + }} + ctx = session.ContextWithOutbounds(ctx, outbounds) } ctx = session.ContextWithInbound(ctx, &session.Inbound{ Source: source, diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index 792ac249..4262c76a 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -169,10 +169,11 @@ func (h *Handler) Tag() string { // Dispatch implements proxy.Outbound.Dispatch. func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) { - outbound := session.OutboundFromContext(ctx) - if outbound.Target.Network == net.Network_UDP && outbound.OriginalTarget.Address != nil && outbound.OriginalTarget.Address != outbound.Target.Address { - link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address} - link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address} + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if ob.Target.Network == net.Network_UDP && ob.OriginalTarget.Address != nil && ob.OriginalTarget.Address != ob.Target.Address { + link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address} + link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address} } if h.mux != nil { test := func(err error) { @@ -183,7 +184,7 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) { common.Interrupt(link.Writer) } } - if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 { + if ob.Target.Network == net.Network_UDP && ob.Target.Port == 443 { switch h.udp443 { case "reject": test(newError("XUDP rejected UDP/443 traffic").AtInfo()) @@ -192,7 +193,7 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) { goto out } } - if h.xudp != nil && outbound.Target.Network == net.Network_UDP { + if h.xudp != nil && ob.Target.Network == net.Network_UDP { if !h.xudp.Enabled { goto out } @@ -243,10 +244,11 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti handler := h.outboundManager.GetHandler(tag) if handler != nil { newError("proxying to ", tag, " for dest ", dest).AtDebug().WriteToLog(session.ExportIDToError(ctx)) - ctx = session.ContextWithOutbound(ctx, &session.Outbound{ + outbounds := session.OutboundsFromContext(ctx) + ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{ Target: dest, - }) - + Tag: tag, + })) // add another outbound in session ctx opts := pipe.OptionsFromContext(ctx) uplinkReader, uplinkWriter := pipe.New(opts...) downlinkReader, downlinkWriter := pipe.New(opts...) @@ -266,15 +268,12 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti } if h.senderSettings.Via != nil { - outbound := session.OutboundFromContext(ctx) - if outbound == nil { - outbound = new(session.Outbound) - ctx = session.ContextWithOutbound(ctx, outbound) - } + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] if h.senderSettings.ViaCidr == "" { - outbound.Gateway = h.senderSettings.Via.AsAddress() + ob.Gateway = h.senderSettings.Via.AsAddress() } else { //Get a random address. - outbound.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr) + ob.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr) } } } @@ -285,10 +284,9 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti conn, err := internet.Dial(ctx, dest, h.streamSettings) conn = h.getStatCouterConnection(conn) - outbound := session.OutboundFromContext(ctx) - if outbound != nil { - outbound.Conn = conn - } + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + ob.Conn = conn return conn, err } diff --git a/app/proxyman/outbound/handler_test.go b/app/proxyman/outbound/handler_test.go index e5b67308..3f7ef28e 100644 --- a/app/proxyman/outbound/handler_test.go +++ b/app/proxyman/outbound/handler_test.go @@ -14,6 +14,7 @@ import ( "github.com/xtls/xray-core/app/stats" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/serial" + "github.com/xtls/xray-core/common/session" core "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/outbound" "github.com/xtls/xray-core/proxy/freedom" @@ -44,6 +45,7 @@ func TestOutboundWithoutStatCounter(t *testing.T) { v, _ := core.New(config) v.AddFeature((outbound.Manager)(new(Manager))) ctx := context.WithValue(context.Background(), xrayKey, v) + ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{{}}) h, _ := NewHandler(ctx, &core.OutboundHandlerConfig{ Tag: "tag", ProxySettings: serial.ToTypedMessage(&freedom.Config{}), @@ -73,6 +75,7 @@ func TestOutboundWithStatCounter(t *testing.T) { v, _ := core.New(config) v.AddFeature((outbound.Manager)(new(Manager))) ctx := context.WithValue(context.Background(), xrayKey, v) + ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{{}}) h, _ := NewHandler(ctx, &core.OutboundHandlerConfig{ Tag: "tag", ProxySettings: serial.ToTypedMessage(&freedom.Config{}), diff --git a/app/reverse/portal.go b/app/reverse/portal.go index fb0b6930..456de550 100644 --- a/app/reverse/portal.go +++ b/app/reverse/portal.go @@ -62,12 +62,13 @@ func (p *Portal) Close() error { } func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) error { - outboundMeta := session.OutboundFromContext(ctx) - if outboundMeta == nil { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if ob == nil { return newError("outbound metadata not found").AtError() } - if isDomain(outboundMeta.Target, p.domain) { + if isDomain(ob.Target, p.domain) { muxClient, err := mux.NewClientWorker(*link, mux.ClientStrategy{}) if err != nil { return newError("failed to create mux client worker").Base(err).AtWarning() @@ -206,9 +207,10 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) { downlinkReader, downlinkWriter := pipe.New(opt...) ctx := context.Background() - ctx = session.ContextWithOutbound(ctx, &session.Outbound{ + outbounds := []*session.Outbound{{ Target: net.UDPDestination(net.DomainAddress(internalDomain), 0), - }) + }} + ctx = session.ContextWithOutbounds(ctx, outbounds) f := client.Dispatch(ctx, &transport.Link{ Reader: uplinkReader, Writer: downlinkWriter, diff --git a/app/router/router_test.go b/app/router/router_test.go index 4c6bfc63..2c33aae1 100644 --- a/app/router/router_test.go +++ b/app/router/router_test.go @@ -45,7 +45,9 @@ func TestSimpleRouter(t *testing.T) { HandlerSelector: mockHs, }, nil)) - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)}) + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ + Target: net.TCPDestination(net.DomainAddress("example.com"), 80), + }}) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag := route.GetOutboundTag(); tag != "test" { @@ -86,7 +88,9 @@ func TestSimpleBalancer(t *testing.T) { HandlerSelector: mockHs, }, nil)) - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)}) + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ + Target: net.TCPDestination(net.DomainAddress("example.com"), 80), + }}) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag := route.GetOutboundTag(); tag != "test" { @@ -174,7 +178,9 @@ func TestIPOnDemand(t *testing.T) { r := new(Router) common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil)) - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)}) + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ + Target: net.TCPDestination(net.DomainAddress("example.com"), 80), + }}) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag := route.GetOutboundTag(); tag != "test" { @@ -213,7 +219,9 @@ func TestIPIfNonMatchDomain(t *testing.T) { r := new(Router) common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil)) - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)}) + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ + Target: net.TCPDestination(net.DomainAddress("example.com"), 80), + }}) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag := route.GetOutboundTag(); tag != "test" { @@ -247,7 +255,9 @@ func TestIPIfNonMatchIP(t *testing.T) { r := new(Router) common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil)) - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)}) + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ + Target: net.TCPDestination(net.LocalHostIP, 80), + }}) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag := route.GetOutboundTag(); tag != "test" { diff --git a/common/mux/client.go b/common/mux/client.go index 88621be0..2537f02b 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -148,9 +148,10 @@ func (f *DialingWorkerFactory) Create() (*ClientWorker, error) { } go func(p proxy.Outbound, d internet.Dialer, c common.Closable) { - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{ + outbounds := []*session.Outbound{{ Target: net.TCPDestination(muxCoolAddress, muxCoolPort), - }) + }} + ctx := session.ContextWithOutbounds(context.Background(), outbounds) ctx, cancel := context.WithCancel(ctx) if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil { @@ -242,17 +243,18 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error { } func fetchInput(ctx context.Context, s *Session, output buf.Writer) { - dest := session.OutboundFromContext(ctx).Target + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] transferType := protocol.TransferTypeStream - if dest.Network == net.Network_UDP { + if ob.Target.Network == net.Network_UDP { transferType = protocol.TransferTypePacket } s.transferType = transferType - writer := NewWriter(s.ID, dest, output, transferType, xudp.GetGlobalID(ctx)) + writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx)) defer s.Close(false) defer writer.Close() - newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx)) + newError("dispatching request to ", ob.Target).WriteToLog(session.ExportIDToError(ctx)) if err := writeFirstPayload(s.input, writer); err != nil { newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx)) writer.hasError = true diff --git a/common/mux/client_test.go b/common/mux/client_test.go index 7837a86e..9626e2a2 100644 --- a/common/mux/client_test.go +++ b/common/mux/client_test.go @@ -86,9 +86,9 @@ func TestClientWorkerClose(t *testing.T) { } tr1, tw1 := pipe.New(pipe.WithoutSizeLimit()) - ctx1 := session.ContextWithOutbound(context.Background(), &session.Outbound{ + ctx1 := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80), - }) + }}) common.Must(manager.Dispatch(ctx1, &transport.Link{ Reader: tr1, Writer: tw1, @@ -103,9 +103,9 @@ func TestClientWorkerClose(t *testing.T) { } tr2, tw2 := pipe.New(pipe.WithoutSizeLimit()) - ctx2 := session.ContextWithOutbound(context.Background(), &session.Outbound{ + ctx2 := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80), - }) + }}) common.Must(manager.Dispatch(ctx2, &transport.Link{ Reader: tr2, Writer: tw2, diff --git a/common/session/context.go b/common/session/context.go index 87586169..fc37bd72 100644 --- a/common/session/context.go +++ b/common/session/context.go @@ -51,13 +51,13 @@ func InboundFromContext(ctx context.Context) *Inbound { return nil } -func ContextWithOutbound(ctx context.Context, outbound *Outbound) context.Context { - return context.WithValue(ctx, outboundSessionKey, outbound) +func ContextWithOutbounds(ctx context.Context, outbounds []*Outbound) context.Context { + return context.WithValue(ctx, outboundSessionKey, outbounds) } -func OutboundFromContext(ctx context.Context) *Outbound { - if outbound, ok := ctx.Value(outboundSessionKey).(*Outbound); ok { - return outbound +func OutboundsFromContext(ctx context.Context) []*Outbound { + if outbounds, ok := ctx.Value(outboundSessionKey).([]*Outbound); ok { + return outbounds } return nil } diff --git a/common/session/session.go b/common/session/session.go index 38ffa7bd..d8ab1ec4 100644 --- a/common/session/session.go +++ b/common/session/session.go @@ -50,18 +50,11 @@ type Inbound struct { Conn net.Conn // Timer of the inbound buf copier. May be nil. Timer *signal.ActivityTimer - // CanSpliceCopy is a property for this connection, set by both inbound and outbound + // CanSpliceCopy is a property for this connection // 1 = can, 2 = after processing protocol info should be able to, 3 = cannot CanSpliceCopy int } -func(i *Inbound) SetCanSpliceCopy(canSpliceCopy int) int { - if canSpliceCopy > i.CanSpliceCopy { - i.CanSpliceCopy = canSpliceCopy - } - return i.CanSpliceCopy -} - // Outbound is the metadata of an outbound connection. type Outbound struct { // Target address of the outbound connection. @@ -70,10 +63,15 @@ type Outbound struct { RouteTarget net.Destination // Gateway address Gateway net.Address + // Tag of the outbound proxy that handles the connection. + Tag string // Name of the outbound proxy that handles the connection. Name string // Conn is actually internet.Connection. May be nil. It is currently nil for outbound with proxySettings Conn net.Conn + // CanSpliceCopy is a property for this connection + // 1 = can, 2 = after processing protocol info should be able to, 3 = cannot + CanSpliceCopy int } // SniffingRequest controls the behavior of content sniffing. diff --git a/common/singbridge/dialer.go b/common/singbridge/dialer.go index 896c97fe..6be83036 100644 --- a/common/singbridge/dialer.go +++ b/common/singbridge/dialer.go @@ -43,9 +43,14 @@ func NewOutboundDialer(outbound proxy.Outbound, dialer internet.Dialer) *XrayOut } func (d *XrayOutboundDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - ctx = session.ContextWithOutbound(ctx, &session.Outbound{ - Target: ToDestination(destination, ToNetwork(network)), - }) + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) == 0 { + outbounds = []*session.Outbound{{}} + ctx = session.ContextWithOutbounds(ctx, outbounds) + } + ob := outbounds[len(outbounds) - 1] + ob.Target = ToDestination(destination, ToNetwork(network)) + opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)} uplinkReader, uplinkWriter := pipe.New(opts...) downlinkReader, downlinkWriter := pipe.New(opts...) diff --git a/features/routing/session/context.go b/features/routing/session/context.go index c900219d..3c9764b3 100644 --- a/features/routing/session/context.go +++ b/features/routing/session/context.go @@ -124,9 +124,11 @@ func (ctx *Context) GetSkipDNSResolve() bool { // AsRoutingContext creates a context from context.context with session info. func AsRoutingContext(ctx context.Context) routing.Context { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] return &Context{ Inbound: session.InboundFromContext(ctx), - Outbound: session.OutboundFromContext(ctx), + Outbound: ob, Content: session.ContentFromContext(ctx), } } diff --git a/proxy/blackhole/blackhole.go b/proxy/blackhole/blackhole.go index 4b819417..23c9c291 100644 --- a/proxy/blackhole/blackhole.go +++ b/proxy/blackhole/blackhole.go @@ -31,10 +31,9 @@ func New(ctx context.Context, config *Config) (*Handler, error) { // Process implements OutboundHandler.Dispatch(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound != nil { - outbound.Name = "blackhole" - } + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + ob.Name = "blackhole" nBytes := h.response.WriteTo(link.Writer) if nBytes > 0 { diff --git a/proxy/blackhole/blackhole_test.go b/proxy/blackhole/blackhole_test.go index 8e487e0c..6a9cb8e8 100644 --- a/proxy/blackhole/blackhole_test.go +++ b/proxy/blackhole/blackhole_test.go @@ -7,13 +7,15 @@ import ( "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/serial" + "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/proxy/blackhole" "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/pipe" ) func TestBlackholeHTTPResponse(t *testing.T) { - handler, err := blackhole.New(context.Background(), &blackhole.Config{ + ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{}}) + handler, err := blackhole.New(ctx, &blackhole.Config{ Response: serial.ToTypedMessage(&blackhole.HTTPResponse{}), }) common.Must(err) @@ -32,7 +34,7 @@ func TestBlackholeHTTPResponse(t *testing.T) { Reader: reader, Writer: writer, } - common.Must(handler.Process(context.Background(), &link, nil)) + common.Must(handler.Process(ctx, &link, nil)) common.Must(rerr) if mb.IsEmpty() { t.Error("expect http response, but nothing") diff --git a/proxy/dns/dns.go b/proxy/dns/dns.go index 2cf21a42..86790f76 100644 --- a/proxy/dns/dns.go +++ b/proxy/dns/dns.go @@ -96,15 +96,16 @@ func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage. // Process implements proxy.Outbound. func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("invalid outbound") } - outbound.Name = "dns" + ob.Name = "dns" - srcNetwork := outbound.Target.Network + srcNetwork := ob.Target.Network - dest := outbound.Target + dest := ob.Target if h.server.Network != net.Network_Unknown { dest.Network = h.server.Network } diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index 1c59fe62..5a07df5c 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -86,10 +86,15 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st destinationOverridden := false if d.config.FollowRedirect { - if outbound := session.OutboundFromContext(ctx); outbound != nil && outbound.Target.IsValid() { - dest = outbound.Target - destinationOverridden = true - } else if handshake, ok := conn.(hasHandshakeAddressContext); ok { + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) > 0 { + ob := outbounds[len(outbounds) - 1] + if ob.Target.IsValid() { + dest = ob.Target + destinationOverridden = true + } + } + if handshake, ok := conn.(hasHandshakeAddressContext); ok && !destinationOverridden { addr := handshake.HandshakeAddressContext(ctx) if addr != nil { dest.Address = addr @@ -103,7 +108,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st inbound := session.InboundFromContext(ctx) inbound.Name = "dokodemo-door" - inbound.SetCanSpliceCopy(1) + inbound.CanSpliceCopy = 1 inbound.User = &protocol.MemoryUser{ Level: d.config.UserLevel, } diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index 0176929c..9e6afc9d 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -106,16 +106,16 @@ func isValidAddress(addr *net.IPOrDomain) bool { // Process implements proxy.Outbound. func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified.") } - outbound.Name = "freedom" + ob.Name = "freedom" + ob.CanSpliceCopy = 1 inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(1) - } - destination := outbound.Target + + destination := ob.Target UDPOverride := net.UDPDestination(nil, 0) if h.config.DestinationOverride != nil { server := h.config.DestinationOverride.Server diff --git a/proxy/http/client.go b/proxy/http/client.go index 72060c4d..80a0328a 100644 --- a/proxy/http/client.go +++ b/proxy/http/client.go @@ -69,16 +69,14 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { // Process implements proxy.Outbound.Process. We first create a socket tunnel via HTTP CONNECT method, then redirect all inbound traffic to that tunnel. func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified.") } - outbound.Name = "http" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(2) - } - target := outbound.Target + ob.Name = "http" + ob.CanSpliceCopy = 2 + target := ob.Target targetAddr := target.NetAddr() if target.Network == net.Network_UDP { @@ -175,9 +173,10 @@ func fillRequestHeader(ctx context.Context, header []*Header) ([]*Header, error) } inbound := session.InboundFromContext(ctx) - outbound := session.OutboundFromContext(ctx) + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] - if inbound == nil || outbound == nil { + if inbound == nil || ob == nil { return nil, newError("missing inbound or outbound metadata from context") } @@ -186,7 +185,7 @@ func fillRequestHeader(ctx context.Context, header []*Header) ([]*Header, error) Target net.Destination }{ Source: inbound.Source, - Target: outbound.Target, + Target: ob.Target, } filled := make([]*Header, len(header)) diff --git a/proxy/http/server.go b/proxy/http/server.go index 511d9b08..a7df317d 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -85,7 +85,7 @@ type readerOnly struct { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "http" - inbound.SetCanSpliceCopy(2) + inbound.CanSpliceCopy = 2 inbound.User = &protocol.MemoryUser{ Level: s.config.UserLevel, } diff --git a/proxy/loopback/loopback.go b/proxy/loopback/loopback.go index 30c39bd9..f3be5a95 100644 --- a/proxy/loopback/loopback.go +++ b/proxy/loopback/loopback.go @@ -22,12 +22,13 @@ type Loopback struct { } func (l *Loopback) Process(ctx context.Context, link *transport.Link, _ internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified.") } - outbound.Name = "loopback" - destination := outbound.Target + ob.Name = "loopback" + destination := ob.Target newError("opening connection to ", destination).WriteToLog(session.ExportIDToError(ctx)) diff --git a/proxy/proxy.go b/proxy/proxy.go index 6a5a1798..2507d029 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -474,45 +474,73 @@ func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net readerConn, readCounter, _ := UnwrapRawConn(readerConn) writerConn, _, writeCounter := UnwrapRawConn(writerConn) reader := buf.NewReader(readerConn) - if inbound := session.InboundFromContext(ctx); inbound != nil { - if tc, ok := writerConn.(*net.TCPConn); ok && readerConn != nil && writerConn != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { - for inbound.CanSpliceCopy != 3 { - if inbound.CanSpliceCopy == 1 { - newError("CopyRawConn splice").WriteToLog(session.ExportIDToError(ctx)) - statWriter, _ := writer.(*dispatcher.SizeStatWriter) - //runtime.Gosched() // necessary - time.Sleep(time.Millisecond) // without this, there will be a rare ssl error for freedom splice - w, err := tc.ReadFrom(readerConn) - if readCounter != nil { - readCounter.Add(w) // outbound stats - } - if writeCounter != nil { - writeCounter.Add(w) // inbound stats - } - if statWriter != nil { - statWriter.Counter.Add(w) // user stats - } - if err != nil && errors.Cause(err) != io.EOF { - return err - } - return nil - } - buffer, err := reader.ReadMultiBuffer() - if !buffer.IsEmpty() { - if readCounter != nil { - readCounter.Add(int64(buffer.Len())) - } - timer.Update() - if werr := writer.WriteMultiBuffer(buffer); werr != nil { - return werr - } - } - if err != nil { - return err - } - } + if runtime.GOOS != "linux" && runtime.GOOS != "android" { + return readV(ctx, reader, writer, timer, readCounter) + } + tc, ok := writerConn.(*net.TCPConn) + if !ok || readerConn == nil || writerConn == nil { + return readV(ctx, reader, writer, timer, readCounter) + } + inbound := session.InboundFromContext(ctx) + if inbound == nil || inbound.CanSpliceCopy == 3 { + return readV(ctx, reader, writer, timer, readCounter) + } + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) == 0 { + return readV(ctx, reader, writer, timer, readCounter) + } + for _, ob := range outbounds { + if ob.CanSpliceCopy == 3 { + return readV(ctx, reader, writer, timer, readCounter) } } + + for { + inbound := session.InboundFromContext(ctx) + outbounds := session.OutboundsFromContext(ctx) + var splice = inbound.CanSpliceCopy == 1 + for _, ob := range outbounds { + if ob.CanSpliceCopy != 1 { + splice = false + } + } + if splice { + newError("CopyRawConn splice").WriteToLog(session.ExportIDToError(ctx)) + statWriter, _ := writer.(*dispatcher.SizeStatWriter) + //runtime.Gosched() // necessary + time.Sleep(time.Millisecond) // without this, there will be a rare ssl error for freedom splice + w, err := tc.ReadFrom(readerConn) + if readCounter != nil { + readCounter.Add(w) // outbound stats + } + if writeCounter != nil { + writeCounter.Add(w) // inbound stats + } + if statWriter != nil { + statWriter.Counter.Add(w) // user stats + } + if err != nil && errors.Cause(err) != io.EOF { + return err + } + return nil + } + buffer, err := reader.ReadMultiBuffer() + if !buffer.IsEmpty() { + if readCounter != nil { + readCounter.Add(int64(buffer.Len())) + } + timer.Update() + if werr := writer.WriteMultiBuffer(buffer); werr != nil { + return werr + } + } + if err != nil { + return err + } + } +} + +func readV(ctx context.Context, reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, readCounter stats.Counter) error { newError("CopyRawConn readv").WriteToLog(session.ExportIDToError(ctx)) if err := buf.Copy(reader, writer, buf.UpdateActivity(timer), buf.AddToStatCounter(readCounter)); err != nil { return newError("failed to process response").Base(err) diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index 57d8f81c..8ebe7631 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -49,16 +49,14 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { // Process implements OutboundHandler.Process(). func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified") } - outbound.Name = "shadowsocks" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(3) - } - destination := outbound.Target + ob.Name = "shadowsocks" + ob.CanSpliceCopy = 3 + destination := ob.Target network := destination.Network var server *protocol.ServerSpec diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 2975ba70..8253506a 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -73,7 +73,7 @@ func (s *Server) Network() []net.Network { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "shadowsocks" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 switch network { case net.Network_TCP: diff --git a/proxy/shadowsocks_2022/inbound.go b/proxy/shadowsocks_2022/inbound.go index 00314c90..f1eb76a5 100644 --- a/proxy/shadowsocks_2022/inbound.go +++ b/proxy/shadowsocks_2022/inbound.go @@ -66,7 +66,7 @@ func (i *Inbound) Network() []net.Network { func (i *Inbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "shadowsocks-2022" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 var metadata M.Metadata if inbound.Source.IsValid() { diff --git a/proxy/shadowsocks_2022/inbound_multi.go b/proxy/shadowsocks_2022/inbound_multi.go index df837894..f80ec6d1 100644 --- a/proxy/shadowsocks_2022/inbound_multi.go +++ b/proxy/shadowsocks_2022/inbound_multi.go @@ -155,7 +155,7 @@ func (i *MultiUserInbound) Network() []net.Network { func (i *MultiUserInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "shadowsocks-2022-multi" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 var metadata M.Metadata if inbound.Source.IsValid() { diff --git a/proxy/shadowsocks_2022/inbound_relay.go b/proxy/shadowsocks_2022/inbound_relay.go index 7317f8dd..1c4b8248 100644 --- a/proxy/shadowsocks_2022/inbound_relay.go +++ b/proxy/shadowsocks_2022/inbound_relay.go @@ -87,7 +87,7 @@ func (i *RelayInbound) Network() []net.Network { func (i *RelayInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "shadowsocks-2022-relay" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 var metadata M.Metadata if inbound.Source.IsValid() { diff --git a/proxy/shadowsocks_2022/outbound.go b/proxy/shadowsocks_2022/outbound.go index bc1eb556..cac9a91b 100644 --- a/proxy/shadowsocks_2022/outbound.go +++ b/proxy/shadowsocks_2022/outbound.go @@ -65,15 +65,16 @@ func (o *Outbound) Process(ctx context.Context, link *transport.Link, dialer int inbound := session.InboundFromContext(ctx) if inbound != nil { inboundConn = inbound.Conn - inbound.SetCanSpliceCopy(3) } - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified") } - outbound.Name = "shadowsocks-2022" - destination := outbound.Target + ob.Name = "shadowsocks-2022" + ob.CanSpliceCopy = 3 + destination := ob.Target network := destination.Network newError("tunneling request to ", destination, " via ", o.server.NetAddr()).WriteToLog(session.ExportIDToError(ctx)) diff --git a/proxy/socks/client.go b/proxy/socks/client.go index 82591be4..b283eb65 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -57,17 +57,15 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { // Process implements proxy.Outbound.Process. func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified.") } - outbound.Name = "socks" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(2) - } + ob.Name = "socks" + ob.CanSpliceCopy = 2 // Destination of the inner request. - destination := outbound.Target + destination := ob.Target // Outbound server. var server *protocol.ServerSpec diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 2f789757..0109d5b4 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -65,7 +65,7 @@ func (s *Server) Network() []net.Network { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "socks" - inbound.SetCanSpliceCopy(2) + inbound.CanSpliceCopy = 2 inbound.User = &protocol.MemoryUser{ Level: s.config.UserLevel, } diff --git a/proxy/trojan/client.go b/proxy/trojan/client.go index d6b95fc0..3a4d838a 100644 --- a/proxy/trojan/client.go +++ b/proxy/trojan/client.go @@ -50,16 +50,14 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { // Process implements OutboundHandler.Process(). func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified") } - outbound.Name = "trojan" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(3) - } - destination := outbound.Target + ob.Name = "trojan" + ob.CanSpliceCopy = 3 + destination := ob.Target network := destination.Network var server *protocol.ServerSpec diff --git a/proxy/trojan/server.go b/proxy/trojan/server.go index 5c3fcd91..bc52c2b1 100644 --- a/proxy/trojan/server.go +++ b/proxy/trojan/server.go @@ -215,7 +215,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con inbound := session.InboundFromContext(ctx) inbound.Name = "trojan" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 inbound.User = user sessionPolicy = s.policyManager.ForLevel(user.Level) diff --git a/proxy/vless/encoding/encoding.go b/proxy/vless/encoding/encoding.go index 5956389a..2976be74 100644 --- a/proxy/vless/encoding/encoding.go +++ b/proxy/vless/encoding/encoding.go @@ -174,15 +174,18 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A } // XtlsRead filter and read xtls protocol -func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ctx context.Context) error { +func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error { err := func() error { for { if trafficState.ReaderSwitchToDirectCopy { var writerConn net.Conn - if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil { + if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil && ob != nil { writerConn = inbound.Conn if inbound.CanSpliceCopy == 2 { - inbound.CanSpliceCopy = 1 // force the value to 1, don't use setter + inbound.CanSpliceCopy = 1 + } + if ob.CanSpliceCopy == 2 { // ob need to be passed in due to context can change + ob.CanSpliceCopy = 1 } } return proxy.CopyRawConnIfExist(ctx, conn, writerConn, writer, timer) @@ -219,14 +222,19 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater } // XtlsWrite filter and write xtls protocol -func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ctx context.Context) error { +func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error { err := func() error { var ct stats.Counter for { buffer, err := reader.ReadMultiBuffer() if trafficState.WriterSwitchToDirectCopy { - if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.CanSpliceCopy == 2 { - inbound.CanSpliceCopy = 1 // force the value to 1, don't use setter + if inbound := session.InboundFromContext(ctx); inbound != nil && ob != nil { + if inbound.CanSpliceCopy == 2 { + inbound.CanSpliceCopy = 1 + } + if ob.CanSpliceCopy == 2 { + ob.CanSpliceCopy = 1 + } } rawConn, _, writerCounter := proxy.UnwrapRawConn(conn) writer = buf.NewWriter(rawConn) diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index 0ffa61d2..7d2dd507 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -449,7 +449,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s switch requestAddons.Flow { case vless.XRV: if account.Flow == requestAddons.Flow { - inbound.SetCanSpliceCopy(2) + inbound.CanSpliceCopy = 2 switch request.Command { case protocol.RequestCommandUDP: return newError(requestAddons.Flow + " doesn't support UDP").AtWarning() @@ -479,7 +479,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s return newError(account.ID.String() + " is not able to use " + requestAddons.Flow).AtWarning() } case "": - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 if account.Flow == vless.XRV && (request.Command == protocol.RequestCommandTCP || isMuxAndNotXUDP(request, first)) { return newError(account.ID.String() + " is not able to use \"\". Note that the pure TLS proxy has certain TLS in TLS characters.").AtWarning() } @@ -523,7 +523,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s if requestAddons.Flow == vless.XRV { ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice clientReader = proxy.NewVisionReader(clientReader, trafficState, ctx1) - err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, ctx1) + err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, nil, ctx1) } else { // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) @@ -560,7 +560,9 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s var err error if requestAddons.Flow == vless.XRV { - err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, ctx) + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, ob, ctx) } else { // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)) diff --git a/proxy/vless/outbound/outbound.go b/proxy/vless/outbound/outbound.go index a9368813..bf98253b 100644 --- a/proxy/vless/outbound/outbound.go +++ b/proxy/vless/outbound/outbound.go @@ -70,12 +70,12 @@ func New(ctx context.Context, config *Config) (*Handler, error) { // Process implements proxy.Outbound.Process(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified").AtError() } - outbound.Name = "vless" - inbound := session.InboundFromContext(ctx) + ob.Name = "vless" var rec *protocol.ServerSpec var conn stat.Connection @@ -96,7 +96,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte if statConn, ok := iConn.(*stat.CounterConnection); ok { iConn = statConn.Connection } - target := outbound.Target + target := ob.Target newError("tunneling request to ", target, " via ", rec.Destination().NetAddr()).AtInfo().WriteToLog(session.ExportIDToError(ctx)) command := protocol.RequestCommandTCP @@ -130,9 +130,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte requestAddons.Flow = requestAddons.Flow[:16] fallthrough case vless.XRV: - if inbound != nil { - inbound.SetCanSpliceCopy(2) - } + ob.CanSpliceCopy = 2 switch request.Command { case protocol.RequestCommandUDP: if !allowUDP443 && request.Port == 443 { @@ -161,9 +159,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte rawInput = (*bytes.Buffer)(unsafe.Pointer(p + r.Offset)) } default: - if inbound != nil { - inbound.SetCanSpliceCopy(3) - } + ob.CanSpliceCopy = 3 } var newCtx context.Context @@ -238,8 +234,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte return newError(`failed to use `+requestAddons.Flow+`, found outer tls version `, utlsConn.ConnectionState().Version).AtWarning() } } - ctx1 := session.ContextWithOutbound(ctx, nil) // TODO enable splice - err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ctx1) + ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice + err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ob, ctx1) } else { // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) @@ -277,7 +273,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } if requestAddons.Flow == vless.XRV { - err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ctx) + err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ob, ctx) } else { // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)) diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 679ea5da..f5340f20 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -257,7 +257,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s inbound := session.InboundFromContext(ctx) inbound.Name = "vmess" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 inbound.User = request.User sessionPolicy = h.policyManager.ForLevel(request.User.Level) diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index c3c55d95..8f102dbb 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -60,15 +60,13 @@ func New(ctx context.Context, config *Config) (*Handler, error) { // Process implements proxy.Outbound.Process(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified").AtError() } - outbound.Name = "vmess" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(3) - } + ob.Name = "vmess" + ob.CanSpliceCopy = 3 var rec *protocol.ServerSpec var conn stat.Connection @@ -87,7 +85,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } defer conn.Close() - target := outbound.Target + target := ob.Target newError("tunneling request to ", target, " via ", rec.Destination().NetAddr()).WriteToLog(session.ExportIDToError(ctx)) command := protocol.RequestCommandTCP diff --git a/proxy/wireguard/client.go b/proxy/wireguard/client.go index 4136525e..00a6fa51 100644 --- a/proxy/wireguard/client.go +++ b/proxy/wireguard/client.go @@ -127,22 +127,20 @@ func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) { // Process implements OutboundHandler.Dispatch(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] + if !ob.Target.IsValid() { return newError("target not specified") } - outbound.Name = "wireguard" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(3) - } + ob.Name = "wireguard" + ob.CanSpliceCopy = 3 if err := h.processWireGuard(dialer); err != nil { return err } // Destination of the inner request. - destination := outbound.Target + destination := ob.Target command := protocol.RequestCommandTCP if destination.Network == net.Network_UDP { command = protocol.RequestCommandUDP diff --git a/proxy/wireguard/server.go b/proxy/wireguard/server.go index bdb4e801..3d3b584c 100644 --- a/proxy/wireguard/server.go +++ b/proxy/wireguard/server.go @@ -79,13 +79,15 @@ func (*Server) Network() []net.Network { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "wireguard" - inbound.SetCanSpliceCopy(3) + inbound.CanSpliceCopy = 3 + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds) - 1] s.info = routingInfo{ ctx: core.ToBackgroundDetachedContext(ctx), dispatcher: dispatcher, inboundTag: session.InboundFromContext(ctx), - outboundTag: session.OutboundFromContext(ctx), + outboundTag: ob, contentTag: session.ContentFromContext(ctx), } @@ -145,7 +147,7 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { ctx = session.ContextWithInbound(ctx, s.info.inboundTag) } if s.info.outboundTag != nil { - ctx = session.ContextWithOutbound(ctx, s.info.outboundTag) + ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{s.info.outboundTag}) } if s.info.contentTag != nil { ctx = session.ContextWithContent(ctx, s.info.contentTag) diff --git a/transport/internet/dialer.go b/transport/internet/dialer.go index 3d5d046f..ffa868a3 100644 --- a/transport/internet/dialer.go +++ b/transport/internet/dialer.go @@ -112,7 +112,12 @@ func canLookupIP(ctx context.Context, dst net.Destination, sockopt *SocketConfig func redirect(ctx context.Context, dst net.Destination, obt string) net.Conn { newError("redirecting request " + dst.String() + " to " + obt).WriteToLog(session.ExportIDToError(ctx)) h := obm.GetHandler(obt) - ctx = session.ContextWithOutbound(ctx, &session.Outbound{Target: dst, Gateway: nil}) + outbounds := session.OutboundsFromContext(ctx) + ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{ + Target: dst, + Gateway: nil, + Tag: obt, + })) // add another outbound in session ctx if h != nil { ur, uw := pipe.New(pipe.OptionsFromContext(ctx)...) dr, dw := pipe.New(pipe.OptionsFromContext(ctx)...) @@ -131,8 +136,10 @@ func redirect(ctx context.Context, dst net.Destination, obt string) net.Conn { // DialSystem calls system dialer to create a network connection. func DialSystem(ctx context.Context, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) { var src net.Address - if outbound := session.OutboundFromContext(ctx); outbound != nil { - src = outbound.Gateway + outbounds := session.OutboundsFromContext(ctx) + if len(outbounds) > 0 { + ob := outbounds[len(outbounds) - 1] + src = ob.Gateway } if sockopt == nil { return effectiveSystemDialer.Dial(ctx, src, dest, sockopt) diff --git a/transport/internet/grpc/dial.go b/transport/internet/grpc/dial.go index 5d5789b4..a4b03ced 100644 --- a/transport/internet/grpc/dial.go +++ b/transport/internet/grpc/dial.go @@ -118,7 +118,7 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in address := net.ParseAddress(rawHost) gctx = session.ContextWithID(gctx, session.IDFromContext(ctx)) - gctx = session.ContextWithOutbound(gctx, session.OutboundFromContext(ctx)) + gctx = session.ContextWithOutbounds(gctx, session.OutboundsFromContext(ctx)) gctx = session.ContextWithTimeoutOnly(gctx, true) c, err := internet.DialSystem(gctx, net.TCPDestination(address, port), sockopt) diff --git a/transport/internet/http/dialer.go b/transport/internet/http/dialer.go index acccd0b7..0148658c 100644 --- a/transport/internet/http/dialer.go +++ b/transport/internet/http/dialer.go @@ -68,7 +68,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in address := net.ParseAddress(rawHost) hctx = session.ContextWithID(hctx, session.IDFromContext(ctx)) - hctx = session.ContextWithOutbound(hctx, session.OutboundFromContext(ctx)) + hctx = session.ContextWithOutbounds(hctx, session.OutboundsFromContext(ctx)) hctx = session.ContextWithTimeoutOnly(hctx, true) pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt)