Add session context outbounds as slice (#3356)

* Add session context outbounds as slice

slice is needed for dialer proxy where two outbounds work on top of each other
There are two sets of target addr for example
It also enable Xtls to correctly do splice copy by checking both outbounds are ready to do direct copy

* Fill outbound tag info

* Splice now checks capalibility from all outbounds

* Fix unit tests
This commit is contained in:
yuhan6665 2024-05-13 21:52:24 -04:00 committed by GitHub
parent 0735053348
commit 017f53b5fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
42 changed files with 303 additions and 236 deletions

View File

@ -218,11 +218,12 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
if !destination.IsValid() {
panic("Dispatcher: Invalid destination.")
}
ob := session.OutboundFromContext(ctx)
if ob == nil {
ob = &session.Outbound{}
ctx = session.ContextWithOutbound(ctx, ob)
outbounds := session.OutboundsFromContext(ctx)
if len(outbounds) == 0 {
outbounds = []*session.Outbound{{}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
}
ob := outbounds[len(outbounds) - 1]
ob.OriginalTarget = destination
ob.Target = destination
content := session.ContentFromContext(ctx)
@ -274,11 +275,12 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
if !destination.IsValid() {
return newError("Dispatcher: Invalid destination.")
}
ob := session.OutboundFromContext(ctx)
if ob == nil {
ob = &session.Outbound{}
ctx = session.ContextWithOutbound(ctx, ob)
outbounds := session.OutboundsFromContext(ctx)
if len(outbounds) == 0 {
outbounds = []*session.Outbound{{}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
}
ob := outbounds[len(outbounds) - 1]
ob.OriginalTarget = destination
ob.Target = destination
content := session.ContentFromContext(ctx)
@ -368,7 +370,8 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, netw
return contentResult, contentErr
}
func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
ob := session.OutboundFromContext(ctx)
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() {
proxied := hosts.LookupHosts(ob.Target.String())
if proxied != nil {
@ -425,6 +428,7 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
return
}
ob.Tag = handler.Tag()
if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
if tag := handler.Tag(); tag != "" {
if inTag == "" {

View File

@ -26,11 +26,12 @@ func newFakeDNSSniffer(ctx context.Context) (protocolSnifferWithMetadata, error)
return protocolSnifferWithMetadata{}, errNotInit
}
return protocolSnifferWithMetadata{protocolSniffer: func(ctx context.Context, bytes []byte) (SniffResult, error) {
Target := session.OutboundFromContext(ctx).Target
if Target.Network == net.Network_TCP || Target.Network == net.Network_UDP {
domainFromFakeDNS := fakeDNSEngine.GetDomainFromFakeDNS(Target.Address)
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if ob.Target.Network == net.Network_TCP || ob.Target.Network == net.Network_UDP {
domainFromFakeDNS := fakeDNSEngine.GetDomainFromFakeDNS(ob.Target.Address)
if domainFromFakeDNS != "" {
newError("fake dns got domain: ", domainFromFakeDNS, " for ip: ", Target.Address.String()).WriteToLog(session.ExportIDToError(ctx))
newError("fake dns got domain: ", domainFromFakeDNS, " for ip: ", ob.Target.Address.String()).WriteToLog(session.ExportIDToError(ctx))
return &fakeDNSSniffResult{domainName: domainFromFakeDNS}, nil
}
}
@ -38,7 +39,7 @@ func newFakeDNSSniffer(ctx context.Context) (protocolSnifferWithMetadata, error)
if ipAddressInRangeValueI := ctx.Value(ipAddressInRange); ipAddressInRangeValueI != nil {
ipAddressInRangeValue := ipAddressInRangeValueI.(*ipAddressInRangeOpt)
if fkr0, ok := fakeDNSEngine.(dns.FakeDNSEngineRev0); ok {
inPool := fkr0.IsIPInIPPool(Target.Address)
inPool := fkr0.IsIPInIPPool(ob.Target.Address)
ipAddressInRangeValue.addressInRange = &inPool
}
}

View File

@ -60,7 +60,7 @@ func (w *tcpWorker) callback(conn stat.Connection) {
sid := session.NewID()
ctx = session.ContextWithID(ctx, sid)
var outbound = &session.Outbound{}
outbounds := []*session.Outbound{{}}
if w.recvOrigDest {
var dest net.Destination
switch getTProxyType(w.stream) {
@ -75,10 +75,10 @@ func (w *tcpWorker) callback(conn stat.Connection) {
dest = net.DestinationFromAddr(conn.LocalAddr())
}
if dest.IsValid() {
outbound.Target = dest
outbounds[0].Target = dest
}
}
ctx = session.ContextWithOutbound(ctx, outbound)
ctx = session.ContextWithOutbounds(ctx, outbounds)
if w.uplinkCounter != nil || w.downlinkCounter != nil {
conn = &stat.CounterConnection{
@ -309,9 +309,10 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest
ctx = session.ContextWithID(ctx, sid)
if originalDest.IsValid() {
ctx = session.ContextWithOutbound(ctx, &session.Outbound{
outbounds := []*session.Outbound{{
Target: originalDest,
})
}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
}
ctx = session.ContextWithInbound(ctx, &session.Inbound{
Source: source,

View File

@ -169,10 +169,11 @@ func (h *Handler) Tag() string {
// Dispatch implements proxy.Outbound.Dispatch.
func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
outbound := session.OutboundFromContext(ctx)
if outbound.Target.Network == net.Network_UDP && outbound.OriginalTarget.Address != nil && outbound.OriginalTarget.Address != outbound.Target.Address {
link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if ob.Target.Network == net.Network_UDP && ob.OriginalTarget.Address != nil && ob.OriginalTarget.Address != ob.Target.Address {
link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address}
link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address}
}
if h.mux != nil {
test := func(err error) {
@ -183,7 +184,7 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
common.Interrupt(link.Writer)
}
}
if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 {
if ob.Target.Network == net.Network_UDP && ob.Target.Port == 443 {
switch h.udp443 {
case "reject":
test(newError("XUDP rejected UDP/443 traffic").AtInfo())
@ -192,7 +193,7 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
goto out
}
}
if h.xudp != nil && outbound.Target.Network == net.Network_UDP {
if h.xudp != nil && ob.Target.Network == net.Network_UDP {
if !h.xudp.Enabled {
goto out
}
@ -243,10 +244,11 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti
handler := h.outboundManager.GetHandler(tag)
if handler != nil {
newError("proxying to ", tag, " for dest ", dest).AtDebug().WriteToLog(session.ExportIDToError(ctx))
ctx = session.ContextWithOutbound(ctx, &session.Outbound{
outbounds := session.OutboundsFromContext(ctx)
ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{
Target: dest,
})
Tag: tag,
})) // add another outbound in session ctx
opts := pipe.OptionsFromContext(ctx)
uplinkReader, uplinkWriter := pipe.New(opts...)
downlinkReader, downlinkWriter := pipe.New(opts...)
@ -266,15 +268,12 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti
}
if h.senderSettings.Via != nil {
outbound := session.OutboundFromContext(ctx)
if outbound == nil {
outbound = new(session.Outbound)
ctx = session.ContextWithOutbound(ctx, outbound)
}
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if h.senderSettings.ViaCidr == "" {
outbound.Gateway = h.senderSettings.Via.AsAddress()
ob.Gateway = h.senderSettings.Via.AsAddress()
} else { //Get a random address.
outbound.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr)
ob.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr)
}
}
}
@ -285,10 +284,9 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti
conn, err := internet.Dial(ctx, dest, h.streamSettings)
conn = h.getStatCouterConnection(conn)
outbound := session.OutboundFromContext(ctx)
if outbound != nil {
outbound.Conn = conn
}
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
ob.Conn = conn
return conn, err
}

View File

@ -14,6 +14,7 @@ import (
"github.com/xtls/xray-core/app/stats"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/serial"
"github.com/xtls/xray-core/common/session"
core "github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/outbound"
"github.com/xtls/xray-core/proxy/freedom"
@ -44,6 +45,7 @@ func TestOutboundWithoutStatCounter(t *testing.T) {
v, _ := core.New(config)
v.AddFeature((outbound.Manager)(new(Manager)))
ctx := context.WithValue(context.Background(), xrayKey, v)
ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{{}})
h, _ := NewHandler(ctx, &core.OutboundHandlerConfig{
Tag: "tag",
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
@ -73,6 +75,7 @@ func TestOutboundWithStatCounter(t *testing.T) {
v, _ := core.New(config)
v.AddFeature((outbound.Manager)(new(Manager)))
ctx := context.WithValue(context.Background(), xrayKey, v)
ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{{}})
h, _ := NewHandler(ctx, &core.OutboundHandlerConfig{
Tag: "tag",
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),

View File

@ -62,12 +62,13 @@ func (p *Portal) Close() error {
}
func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) error {
outboundMeta := session.OutboundFromContext(ctx)
if outboundMeta == nil {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if ob == nil {
return newError("outbound metadata not found").AtError()
}
if isDomain(outboundMeta.Target, p.domain) {
if isDomain(ob.Target, p.domain) {
muxClient, err := mux.NewClientWorker(*link, mux.ClientStrategy{})
if err != nil {
return newError("failed to create mux client worker").Base(err).AtWarning()
@ -206,9 +207,10 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
downlinkReader, downlinkWriter := pipe.New(opt...)
ctx := context.Background()
ctx = session.ContextWithOutbound(ctx, &session.Outbound{
outbounds := []*session.Outbound{{
Target: net.UDPDestination(net.DomainAddress(internalDomain), 0),
})
}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
f := client.Dispatch(ctx, &transport.Link{
Reader: uplinkReader,
Writer: downlinkWriter,

View File

@ -45,7 +45,9 @@ func TestSimpleRouter(t *testing.T) {
HandlerSelector: mockHs,
}, nil))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
@ -86,7 +88,9 @@ func TestSimpleBalancer(t *testing.T) {
HandlerSelector: mockHs,
}, nil))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
@ -174,7 +178,9 @@ func TestIPOnDemand(t *testing.T) {
r := new(Router)
common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
@ -213,7 +219,9 @@ func TestIPIfNonMatchDomain(t *testing.T) {
r := new(Router)
common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
@ -247,7 +255,9 @@ func TestIPIfNonMatchIP(t *testing.T) {
r := new(Router)
common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.LocalHostIP, 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {

View File

@ -148,9 +148,10 @@ func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
}
go func(p proxy.Outbound, d internet.Dialer, c common.Closable) {
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{
outbounds := []*session.Outbound{{
Target: net.TCPDestination(muxCoolAddress, muxCoolPort),
})
}}
ctx := session.ContextWithOutbounds(context.Background(), outbounds)
ctx, cancel := context.WithCancel(ctx)
if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
@ -242,17 +243,18 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error {
}
func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
dest := session.OutboundFromContext(ctx).Target
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
transferType := protocol.TransferTypeStream
if dest.Network == net.Network_UDP {
if ob.Target.Network == net.Network_UDP {
transferType = protocol.TransferTypePacket
}
s.transferType = transferType
writer := NewWriter(s.ID, dest, output, transferType, xudp.GetGlobalID(ctx))
writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
defer s.Close(false)
defer writer.Close()
newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx))
newError("dispatching request to ", ob.Target).WriteToLog(session.ExportIDToError(ctx))
if err := writeFirstPayload(s.input, writer); err != nil {
newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
writer.hasError = true

View File

@ -86,9 +86,9 @@ func TestClientWorkerClose(t *testing.T) {
}
tr1, tw1 := pipe.New(pipe.WithoutSizeLimit())
ctx1 := session.ContextWithOutbound(context.Background(), &session.Outbound{
ctx1 := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80),
})
}})
common.Must(manager.Dispatch(ctx1, &transport.Link{
Reader: tr1,
Writer: tw1,
@ -103,9 +103,9 @@ func TestClientWorkerClose(t *testing.T) {
}
tr2, tw2 := pipe.New(pipe.WithoutSizeLimit())
ctx2 := session.ContextWithOutbound(context.Background(), &session.Outbound{
ctx2 := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80),
})
}})
common.Must(manager.Dispatch(ctx2, &transport.Link{
Reader: tr2,
Writer: tw2,

View File

@ -51,13 +51,13 @@ func InboundFromContext(ctx context.Context) *Inbound {
return nil
}
func ContextWithOutbound(ctx context.Context, outbound *Outbound) context.Context {
return context.WithValue(ctx, outboundSessionKey, outbound)
func ContextWithOutbounds(ctx context.Context, outbounds []*Outbound) context.Context {
return context.WithValue(ctx, outboundSessionKey, outbounds)
}
func OutboundFromContext(ctx context.Context) *Outbound {
if outbound, ok := ctx.Value(outboundSessionKey).(*Outbound); ok {
return outbound
func OutboundsFromContext(ctx context.Context) []*Outbound {
if outbounds, ok := ctx.Value(outboundSessionKey).([]*Outbound); ok {
return outbounds
}
return nil
}

View File

@ -50,18 +50,11 @@ type Inbound struct {
Conn net.Conn
// Timer of the inbound buf copier. May be nil.
Timer *signal.ActivityTimer
// CanSpliceCopy is a property for this connection, set by both inbound and outbound
// CanSpliceCopy is a property for this connection
// 1 = can, 2 = after processing protocol info should be able to, 3 = cannot
CanSpliceCopy int
}
func(i *Inbound) SetCanSpliceCopy(canSpliceCopy int) int {
if canSpliceCopy > i.CanSpliceCopy {
i.CanSpliceCopy = canSpliceCopy
}
return i.CanSpliceCopy
}
// Outbound is the metadata of an outbound connection.
type Outbound struct {
// Target address of the outbound connection.
@ -70,10 +63,15 @@ type Outbound struct {
RouteTarget net.Destination
// Gateway address
Gateway net.Address
// Tag of the outbound proxy that handles the connection.
Tag string
// Name of the outbound proxy that handles the connection.
Name string
// Conn is actually internet.Connection. May be nil. It is currently nil for outbound with proxySettings
Conn net.Conn
// CanSpliceCopy is a property for this connection
// 1 = can, 2 = after processing protocol info should be able to, 3 = cannot
CanSpliceCopy int
}
// SniffingRequest controls the behavior of content sniffing.

View File

@ -43,9 +43,14 @@ func NewOutboundDialer(outbound proxy.Outbound, dialer internet.Dialer) *XrayOut
}
func (d *XrayOutboundDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
ctx = session.ContextWithOutbound(ctx, &session.Outbound{
Target: ToDestination(destination, ToNetwork(network)),
})
outbounds := session.OutboundsFromContext(ctx)
if len(outbounds) == 0 {
outbounds = []*session.Outbound{{}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
}
ob := outbounds[len(outbounds) - 1]
ob.Target = ToDestination(destination, ToNetwork(network))
opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)}
uplinkReader, uplinkWriter := pipe.New(opts...)
downlinkReader, downlinkWriter := pipe.New(opts...)

View File

@ -124,9 +124,11 @@ func (ctx *Context) GetSkipDNSResolve() bool {
// AsRoutingContext creates a context from context.context with session info.
func AsRoutingContext(ctx context.Context) routing.Context {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
return &Context{
Inbound: session.InboundFromContext(ctx),
Outbound: session.OutboundFromContext(ctx),
Outbound: ob,
Content: session.ContentFromContext(ctx),
}
}

View File

@ -31,10 +31,9 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
// Process implements OutboundHandler.Dispatch().
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound != nil {
outbound.Name = "blackhole"
}
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
ob.Name = "blackhole"
nBytes := h.response.WriteTo(link.Writer)
if nBytes > 0 {

View File

@ -7,13 +7,15 @@ import (
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/serial"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/proxy/blackhole"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/pipe"
)
func TestBlackholeHTTPResponse(t *testing.T) {
handler, err := blackhole.New(context.Background(), &blackhole.Config{
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{}})
handler, err := blackhole.New(ctx, &blackhole.Config{
Response: serial.ToTypedMessage(&blackhole.HTTPResponse{}),
})
common.Must(err)
@ -32,7 +34,7 @@ func TestBlackholeHTTPResponse(t *testing.T) {
Reader: reader,
Writer: writer,
}
common.Must(handler.Process(context.Background(), &link, nil))
common.Must(handler.Process(ctx, &link, nil))
common.Must(rerr)
if mb.IsEmpty() {
t.Error("expect http response, but nothing")

View File

@ -96,15 +96,16 @@ func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.
// Process implements proxy.Outbound.
func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if !ob.Target.IsValid() {
return newError("invalid outbound")
}
outbound.Name = "dns"
ob.Name = "dns"
srcNetwork := outbound.Target.Network
srcNetwork := ob.Target.Network
dest := outbound.Target
dest := ob.Target
if h.server.Network != net.Network_Unknown {
dest.Network = h.server.Network
}

View File

@ -86,10 +86,15 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
destinationOverridden := false
if d.config.FollowRedirect {
if outbound := session.OutboundFromContext(ctx); outbound != nil && outbound.Target.IsValid() {
dest = outbound.Target
destinationOverridden = true
} else if handshake, ok := conn.(hasHandshakeAddressContext); ok {
outbounds := session.OutboundsFromContext(ctx)
if len(outbounds) > 0 {
ob := outbounds[len(outbounds) - 1]
if ob.Target.IsValid() {
dest = ob.Target
destinationOverridden = true
}
}
if handshake, ok := conn.(hasHandshakeAddressContext); ok && !destinationOverridden {
addr := handshake.HandshakeAddressContext(ctx)
if addr != nil {
dest.Address = addr
@ -103,7 +108,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
inbound := session.InboundFromContext(ctx)
inbound.Name = "dokodemo-door"
inbound.SetCanSpliceCopy(1)
inbound.CanSpliceCopy = 1
inbound.User = &protocol.MemoryUser{
Level: d.config.UserLevel,
}

View File

@ -106,16 +106,16 @@ func isValidAddress(addr *net.IPOrDomain) bool {
// Process implements proxy.Outbound.
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if !ob.Target.IsValid() {
return newError("target not specified.")
}
outbound.Name = "freedom"
ob.Name = "freedom"
ob.CanSpliceCopy = 1
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(1)
}
destination := outbound.Target
destination := ob.Target
UDPOverride := net.UDPDestination(nil, 0)
if h.config.DestinationOverride != nil {
server := h.config.DestinationOverride.Server

View File

@ -69,16 +69,14 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
// Process implements proxy.Outbound.Process. We first create a socket tunnel via HTTP CONNECT method, then redirect all inbound traffic to that tunnel.
func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if !ob.Target.IsValid() {
return newError("target not specified.")
}
outbound.Name = "http"
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(2)
}
target := outbound.Target
ob.Name = "http"
ob.CanSpliceCopy = 2
target := ob.Target
targetAddr := target.NetAddr()
if target.Network == net.Network_UDP {
@ -175,9 +173,10 @@ func fillRequestHeader(ctx context.Context, header []*Header) ([]*Header, error)
}
inbound := session.InboundFromContext(ctx)
outbound := session.OutboundFromContext(ctx)
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if inbound == nil || outbound == nil {
if inbound == nil || ob == nil {
return nil, newError("missing inbound or outbound metadata from context")
}
@ -186,7 +185,7 @@ func fillRequestHeader(ctx context.Context, header []*Header) ([]*Header, error)
Target net.Destination
}{
Source: inbound.Source,
Target: outbound.Target,
Target: ob.Target,
}
filled := make([]*Header, len(header))

View File

@ -85,7 +85,7 @@ type readerOnly struct {
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
inbound := session.InboundFromContext(ctx)
inbound.Name = "http"
inbound.SetCanSpliceCopy(2)
inbound.CanSpliceCopy = 2
inbound.User = &protocol.MemoryUser{
Level: s.config.UserLevel,
}

View File

@ -22,12 +22,13 @@ type Loopback struct {
}
func (l *Loopback) Process(ctx context.Context, link *transport.Link, _ internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if !ob.Target.IsValid() {
return newError("target not specified.")
}
outbound.Name = "loopback"
destination := outbound.Target
ob.Name = "loopback"
destination := ob.Target
newError("opening connection to ", destination).WriteToLog(session.ExportIDToError(ctx))

View File

@ -474,45 +474,73 @@ func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net
readerConn, readCounter, _ := UnwrapRawConn(readerConn)
writerConn, _, writeCounter := UnwrapRawConn(writerConn)
reader := buf.NewReader(readerConn)
if inbound := session.InboundFromContext(ctx); inbound != nil {
if tc, ok := writerConn.(*net.TCPConn); ok && readerConn != nil && writerConn != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
for inbound.CanSpliceCopy != 3 {
if inbound.CanSpliceCopy == 1 {
newError("CopyRawConn splice").WriteToLog(session.ExportIDToError(ctx))
statWriter, _ := writer.(*dispatcher.SizeStatWriter)
//runtime.Gosched() // necessary
time.Sleep(time.Millisecond) // without this, there will be a rare ssl error for freedom splice
w, err := tc.ReadFrom(readerConn)
if readCounter != nil {
readCounter.Add(w) // outbound stats
}
if writeCounter != nil {
writeCounter.Add(w) // inbound stats
}
if statWriter != nil {
statWriter.Counter.Add(w) // user stats
}
if err != nil && errors.Cause(err) != io.EOF {
return err
}
return nil
}
buffer, err := reader.ReadMultiBuffer()
if !buffer.IsEmpty() {
if readCounter != nil {
readCounter.Add(int64(buffer.Len()))
}
timer.Update()
if werr := writer.WriteMultiBuffer(buffer); werr != nil {
return werr
}
}
if err != nil {
return err
}
}
if runtime.GOOS != "linux" && runtime.GOOS != "android" {
return readV(ctx, reader, writer, timer, readCounter)
}
tc, ok := writerConn.(*net.TCPConn)
if !ok || readerConn == nil || writerConn == nil {
return readV(ctx, reader, writer, timer, readCounter)
}
inbound := session.InboundFromContext(ctx)
if inbound == nil || inbound.CanSpliceCopy == 3 {
return readV(ctx, reader, writer, timer, readCounter)
}
outbounds := session.OutboundsFromContext(ctx)
if len(outbounds) == 0 {
return readV(ctx, reader, writer, timer, readCounter)
}
for _, ob := range outbounds {
if ob.CanSpliceCopy == 3 {
return readV(ctx, reader, writer, timer, readCounter)
}
}
for {
inbound := session.InboundFromContext(ctx)
outbounds := session.OutboundsFromContext(ctx)
var splice = inbound.CanSpliceCopy == 1
for _, ob := range outbounds {
if ob.CanSpliceCopy != 1 {
splice = false
}
}
if splice {
newError("CopyRawConn splice").WriteToLog(session.ExportIDToError(ctx))
statWriter, _ := writer.(*dispatcher.SizeStatWriter)
//runtime.Gosched() // necessary
time.Sleep(time.Millisecond) // without this, there will be a rare ssl error for freedom splice
w, err := tc.ReadFrom(readerConn)
if readCounter != nil {
readCounter.Add(w) // outbound stats
}
if writeCounter != nil {
writeCounter.Add(w) // inbound stats
}
if statWriter != nil {
statWriter.Counter.Add(w) // user stats
}
if err != nil && errors.Cause(err) != io.EOF {
return err
}
return nil
}
buffer, err := reader.ReadMultiBuffer()
if !buffer.IsEmpty() {
if readCounter != nil {
readCounter.Add(int64(buffer.Len()))
}
timer.Update()
if werr := writer.WriteMultiBuffer(buffer); werr != nil {
return werr
}
}
if err != nil {
return err
}
}
}
func readV(ctx context.Context, reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, readCounter stats.Counter) error {
newError("CopyRawConn readv").WriteToLog(session.ExportIDToError(ctx))
if err := buf.Copy(reader, writer, buf.UpdateActivity(timer), buf.AddToStatCounter(readCounter)); err != nil {
return newError("failed to process response").Base(err)

View File

@ -49,16 +49,14 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
// Process implements OutboundHandler.Process().
func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if !ob.Target.IsValid() {
return newError("target not specified")
}
outbound.Name = "shadowsocks"
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(3)
}
destination := outbound.Target
ob.Name = "shadowsocks"
ob.CanSpliceCopy = 3
destination := ob.Target
network := destination.Network
var server *protocol.ServerSpec

View File

@ -73,7 +73,7 @@ func (s *Server) Network() []net.Network {
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
inbound := session.InboundFromContext(ctx)
inbound.Name = "shadowsocks"
inbound.SetCanSpliceCopy(3)
inbound.CanSpliceCopy = 3
switch network {
case net.Network_TCP:

View File

@ -66,7 +66,7 @@ func (i *Inbound) Network() []net.Network {
func (i *Inbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error {
inbound := session.InboundFromContext(ctx)
inbound.Name = "shadowsocks-2022"
inbound.SetCanSpliceCopy(3)
inbound.CanSpliceCopy = 3
var metadata M.Metadata
if inbound.Source.IsValid() {

View File

@ -155,7 +155,7 @@ func (i *MultiUserInbound) Network() []net.Network {
func (i *MultiUserInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error {
inbound := session.InboundFromContext(ctx)
inbound.Name = "shadowsocks-2022-multi"
inbound.SetCanSpliceCopy(3)
inbound.CanSpliceCopy = 3
var metadata M.Metadata
if inbound.Source.IsValid() {

View File

@ -87,7 +87,7 @@ func (i *RelayInbound) Network() []net.Network {
func (i *RelayInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error {
inbound := session.InboundFromContext(ctx)
inbound.Name = "shadowsocks-2022-relay"
inbound.SetCanSpliceCopy(3)
inbound.CanSpliceCopy = 3
var metadata M.Metadata
if inbound.Source.IsValid() {

View File

@ -65,15 +65,16 @@ func (o *Outbound) Process(ctx context.Context, link *transport.Link, dialer int
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inboundConn = inbound.Conn
inbound.SetCanSpliceCopy(3)
}
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if !ob.Target.IsValid() {
return newError("target not specified")
}
outbound.Name = "shadowsocks-2022"
destination := outbound.Target
ob.Name = "shadowsocks-2022"
ob.CanSpliceCopy = 3
destination := ob.Target
network := destination.Network
newError("tunneling request to ", destination, " via ", o.server.NetAddr()).WriteToLog(session.ExportIDToError(ctx))

View File

@ -57,17 +57,15 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
// Process implements proxy.Outbound.Process.
func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if !ob.Target.IsValid() {
return newError("target not specified.")
}
outbound.Name = "socks"
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(2)
}
ob.Name = "socks"
ob.CanSpliceCopy = 2
// Destination of the inner request.
destination := outbound.Target
destination := ob.Target
// Outbound server.
var server *protocol.ServerSpec

View File

@ -65,7 +65,7 @@ func (s *Server) Network() []net.Network {
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
inbound := session.InboundFromContext(ctx)
inbound.Name = "socks"
inbound.SetCanSpliceCopy(2)
inbound.CanSpliceCopy = 2
inbound.User = &protocol.MemoryUser{
Level: s.config.UserLevel,
}

View File

@ -50,16 +50,14 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
// Process implements OutboundHandler.Process().
func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if !ob.Target.IsValid() {
return newError("target not specified")
}
outbound.Name = "trojan"
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(3)
}
destination := outbound.Target
ob.Name = "trojan"
ob.CanSpliceCopy = 3
destination := ob.Target
network := destination.Network
var server *protocol.ServerSpec

View File

@ -215,7 +215,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
inbound := session.InboundFromContext(ctx)
inbound.Name = "trojan"
inbound.SetCanSpliceCopy(3)
inbound.CanSpliceCopy = 3
inbound.User = user
sessionPolicy = s.policyManager.ForLevel(user.Level)

View File

@ -174,15 +174,18 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A
}
// XtlsRead filter and read xtls protocol
func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ctx context.Context) error {
func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error {
err := func() error {
for {
if trafficState.ReaderSwitchToDirectCopy {
var writerConn net.Conn
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil {
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil && ob != nil {
writerConn = inbound.Conn
if inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1 // force the value to 1, don't use setter
inbound.CanSpliceCopy = 1
}
if ob.CanSpliceCopy == 2 { // ob need to be passed in due to context can change
ob.CanSpliceCopy = 1
}
}
return proxy.CopyRawConnIfExist(ctx, conn, writerConn, writer, timer)
@ -219,14 +222,19 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater
}
// XtlsWrite filter and write xtls protocol
func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ctx context.Context) error {
func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error {
err := func() error {
var ct stats.Counter
for {
buffer, err := reader.ReadMultiBuffer()
if trafficState.WriterSwitchToDirectCopy {
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1 // force the value to 1, don't use setter
if inbound := session.InboundFromContext(ctx); inbound != nil && ob != nil {
if inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1
}
if ob.CanSpliceCopy == 2 {
ob.CanSpliceCopy = 1
}
}
rawConn, _, writerCounter := proxy.UnwrapRawConn(conn)
writer = buf.NewWriter(rawConn)

View File

@ -449,7 +449,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
switch requestAddons.Flow {
case vless.XRV:
if account.Flow == requestAddons.Flow {
inbound.SetCanSpliceCopy(2)
inbound.CanSpliceCopy = 2
switch request.Command {
case protocol.RequestCommandUDP:
return newError(requestAddons.Flow + " doesn't support UDP").AtWarning()
@ -479,7 +479,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
return newError(account.ID.String() + " is not able to use " + requestAddons.Flow).AtWarning()
}
case "":
inbound.SetCanSpliceCopy(3)
inbound.CanSpliceCopy = 3
if account.Flow == vless.XRV && (request.Command == protocol.RequestCommandTCP || isMuxAndNotXUDP(request, first)) {
return newError(account.ID.String() + " is not able to use \"\". Note that the pure TLS proxy has certain TLS in TLS characters.").AtWarning()
}
@ -523,7 +523,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
if requestAddons.Flow == vless.XRV {
ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice
clientReader = proxy.NewVisionReader(clientReader, trafficState, ctx1)
err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, ctx1)
err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, nil, ctx1)
} else {
// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer
err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer))
@ -560,7 +560,9 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
var err error
if requestAddons.Flow == vless.XRV {
err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, ctx)
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, ob, ctx)
} else {
// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer
err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))

View File

@ -70,12 +70,12 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
// Process implements proxy.Outbound.Process().
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if !ob.Target.IsValid() {
return newError("target not specified").AtError()
}
outbound.Name = "vless"
inbound := session.InboundFromContext(ctx)
ob.Name = "vless"
var rec *protocol.ServerSpec
var conn stat.Connection
@ -96,7 +96,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
if statConn, ok := iConn.(*stat.CounterConnection); ok {
iConn = statConn.Connection
}
target := outbound.Target
target := ob.Target
newError("tunneling request to ", target, " via ", rec.Destination().NetAddr()).AtInfo().WriteToLog(session.ExportIDToError(ctx))
command := protocol.RequestCommandTCP
@ -130,9 +130,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
requestAddons.Flow = requestAddons.Flow[:16]
fallthrough
case vless.XRV:
if inbound != nil {
inbound.SetCanSpliceCopy(2)
}
ob.CanSpliceCopy = 2
switch request.Command {
case protocol.RequestCommandUDP:
if !allowUDP443 && request.Port == 443 {
@ -161,9 +159,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
rawInput = (*bytes.Buffer)(unsafe.Pointer(p + r.Offset))
}
default:
if inbound != nil {
inbound.SetCanSpliceCopy(3)
}
ob.CanSpliceCopy = 3
}
var newCtx context.Context
@ -238,8 +234,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
return newError(`failed to use `+requestAddons.Flow+`, found outer tls version `, utlsConn.ConnectionState().Version).AtWarning()
}
}
ctx1 := session.ContextWithOutbound(ctx, nil) // TODO enable splice
err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ctx1)
ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice
err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ob, ctx1)
} else {
// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer
err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer))
@ -277,7 +273,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
}
if requestAddons.Flow == vless.XRV {
err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ctx)
err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ob, ctx)
} else {
// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer
err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))

View File

@ -257,7 +257,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
inbound := session.InboundFromContext(ctx)
inbound.Name = "vmess"
inbound.SetCanSpliceCopy(3)
inbound.CanSpliceCopy = 3
inbound.User = request.User
sessionPolicy = h.policyManager.ForLevel(request.User.Level)

View File

@ -60,15 +60,13 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
// Process implements proxy.Outbound.Process().
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if !ob.Target.IsValid() {
return newError("target not specified").AtError()
}
outbound.Name = "vmess"
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(3)
}
ob.Name = "vmess"
ob.CanSpliceCopy = 3
var rec *protocol.ServerSpec
var conn stat.Connection
@ -87,7 +85,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
}
defer conn.Close()
target := outbound.Target
target := ob.Target
newError("tunneling request to ", target, " via ", rec.Destination().NetAddr()).WriteToLog(session.ExportIDToError(ctx))
command := protocol.RequestCommandTCP

View File

@ -127,22 +127,20 @@ func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) {
// Process implements OutboundHandler.Dispatch().
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if !ob.Target.IsValid() {
return newError("target not specified")
}
outbound.Name = "wireguard"
inbound := session.InboundFromContext(ctx)
if inbound != nil {
inbound.SetCanSpliceCopy(3)
}
ob.Name = "wireguard"
ob.CanSpliceCopy = 3
if err := h.processWireGuard(dialer); err != nil {
return err
}
// Destination of the inner request.
destination := outbound.Target
destination := ob.Target
command := protocol.RequestCommandTCP
if destination.Network == net.Network_UDP {
command = protocol.RequestCommandUDP

View File

@ -79,13 +79,15 @@ func (*Server) Network() []net.Network {
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
inbound := session.InboundFromContext(ctx)
inbound.Name = "wireguard"
inbound.SetCanSpliceCopy(3)
inbound.CanSpliceCopy = 3
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
s.info = routingInfo{
ctx: core.ToBackgroundDetachedContext(ctx),
dispatcher: dispatcher,
inboundTag: session.InboundFromContext(ctx),
outboundTag: session.OutboundFromContext(ctx),
outboundTag: ob,
contentTag: session.ContentFromContext(ctx),
}
@ -145,7 +147,7 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
ctx = session.ContextWithInbound(ctx, s.info.inboundTag)
}
if s.info.outboundTag != nil {
ctx = session.ContextWithOutbound(ctx, s.info.outboundTag)
ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{s.info.outboundTag})
}
if s.info.contentTag != nil {
ctx = session.ContextWithContent(ctx, s.info.contentTag)

View File

@ -112,7 +112,12 @@ func canLookupIP(ctx context.Context, dst net.Destination, sockopt *SocketConfig
func redirect(ctx context.Context, dst net.Destination, obt string) net.Conn {
newError("redirecting request " + dst.String() + " to " + obt).WriteToLog(session.ExportIDToError(ctx))
h := obm.GetHandler(obt)
ctx = session.ContextWithOutbound(ctx, &session.Outbound{Target: dst, Gateway: nil})
outbounds := session.OutboundsFromContext(ctx)
ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{
Target: dst,
Gateway: nil,
Tag: obt,
})) // add another outbound in session ctx
if h != nil {
ur, uw := pipe.New(pipe.OptionsFromContext(ctx)...)
dr, dw := pipe.New(pipe.OptionsFromContext(ctx)...)
@ -131,8 +136,10 @@ func redirect(ctx context.Context, dst net.Destination, obt string) net.Conn {
// DialSystem calls system dialer to create a network connection.
func DialSystem(ctx context.Context, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) {
var src net.Address
if outbound := session.OutboundFromContext(ctx); outbound != nil {
src = outbound.Gateway
outbounds := session.OutboundsFromContext(ctx)
if len(outbounds) > 0 {
ob := outbounds[len(outbounds) - 1]
src = ob.Gateway
}
if sockopt == nil {
return effectiveSystemDialer.Dial(ctx, src, dest, sockopt)

View File

@ -118,7 +118,7 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in
address := net.ParseAddress(rawHost)
gctx = session.ContextWithID(gctx, session.IDFromContext(ctx))
gctx = session.ContextWithOutbound(gctx, session.OutboundFromContext(ctx))
gctx = session.ContextWithOutbounds(gctx, session.OutboundsFromContext(ctx))
gctx = session.ContextWithTimeoutOnly(gctx, true)
c, err := internet.DialSystem(gctx, net.TCPDestination(address, port), sockopt)

View File

@ -68,7 +68,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
address := net.ParseAddress(rawHost)
hctx = session.ContextWithID(hctx, session.IDFromContext(ctx))
hctx = session.ContextWithOutbound(hctx, session.OutboundFromContext(ctx))
hctx = session.ContextWithOutbounds(hctx, session.OutboundsFromContext(ctx))
hctx = session.ContextWithTimeoutOnly(hctx, true)
pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt)