Add session context outbounds as slice (#3356)

* Add session context outbounds as slice

slice is needed for dialer proxy where two outbounds work on top of each other
There are two sets of target addr for example
It also enable Xtls to correctly do splice copy by checking both outbounds are ready to do direct copy

* Fill outbound tag info

* Splice now checks capalibility from all outbounds

* Fix unit tests
This commit is contained in:
yuhan6665 2024-05-13 21:52:24 -04:00 committed by GitHub
parent 0735053348
commit 017f53b5fc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
42 changed files with 303 additions and 236 deletions

View file

@ -218,11 +218,12 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
if !destination.IsValid() {
panic("Dispatcher: Invalid destination.")
}
ob := session.OutboundFromContext(ctx)
if ob == nil {
ob = &session.Outbound{}
ctx = session.ContextWithOutbound(ctx, ob)
outbounds := session.OutboundsFromContext(ctx)
if len(outbounds) == 0 {
outbounds = []*session.Outbound{{}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
}
ob := outbounds[len(outbounds) - 1]
ob.OriginalTarget = destination
ob.Target = destination
content := session.ContentFromContext(ctx)
@ -274,11 +275,12 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
if !destination.IsValid() {
return newError("Dispatcher: Invalid destination.")
}
ob := session.OutboundFromContext(ctx)
if ob == nil {
ob = &session.Outbound{}
ctx = session.ContextWithOutbound(ctx, ob)
outbounds := session.OutboundsFromContext(ctx)
if len(outbounds) == 0 {
outbounds = []*session.Outbound{{}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
}
ob := outbounds[len(outbounds) - 1]
ob.OriginalTarget = destination
ob.Target = destination
content := session.ContentFromContext(ctx)
@ -368,7 +370,8 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, netw
return contentResult, contentErr
}
func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
ob := session.OutboundFromContext(ctx)
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() {
proxied := hosts.LookupHosts(ob.Target.String())
if proxied != nil {
@ -425,6 +428,7 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
return
}
ob.Tag = handler.Tag()
if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
if tag := handler.Tag(); tag != "" {
if inTag == "" {

View file

@ -26,11 +26,12 @@ func newFakeDNSSniffer(ctx context.Context) (protocolSnifferWithMetadata, error)
return protocolSnifferWithMetadata{}, errNotInit
}
return protocolSnifferWithMetadata{protocolSniffer: func(ctx context.Context, bytes []byte) (SniffResult, error) {
Target := session.OutboundFromContext(ctx).Target
if Target.Network == net.Network_TCP || Target.Network == net.Network_UDP {
domainFromFakeDNS := fakeDNSEngine.GetDomainFromFakeDNS(Target.Address)
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if ob.Target.Network == net.Network_TCP || ob.Target.Network == net.Network_UDP {
domainFromFakeDNS := fakeDNSEngine.GetDomainFromFakeDNS(ob.Target.Address)
if domainFromFakeDNS != "" {
newError("fake dns got domain: ", domainFromFakeDNS, " for ip: ", Target.Address.String()).WriteToLog(session.ExportIDToError(ctx))
newError("fake dns got domain: ", domainFromFakeDNS, " for ip: ", ob.Target.Address.String()).WriteToLog(session.ExportIDToError(ctx))
return &fakeDNSSniffResult{domainName: domainFromFakeDNS}, nil
}
}
@ -38,7 +39,7 @@ func newFakeDNSSniffer(ctx context.Context) (protocolSnifferWithMetadata, error)
if ipAddressInRangeValueI := ctx.Value(ipAddressInRange); ipAddressInRangeValueI != nil {
ipAddressInRangeValue := ipAddressInRangeValueI.(*ipAddressInRangeOpt)
if fkr0, ok := fakeDNSEngine.(dns.FakeDNSEngineRev0); ok {
inPool := fkr0.IsIPInIPPool(Target.Address)
inPool := fkr0.IsIPInIPPool(ob.Target.Address)
ipAddressInRangeValue.addressInRange = &inPool
}
}

View file

@ -60,7 +60,7 @@ func (w *tcpWorker) callback(conn stat.Connection) {
sid := session.NewID()
ctx = session.ContextWithID(ctx, sid)
var outbound = &session.Outbound{}
outbounds := []*session.Outbound{{}}
if w.recvOrigDest {
var dest net.Destination
switch getTProxyType(w.stream) {
@ -75,10 +75,10 @@ func (w *tcpWorker) callback(conn stat.Connection) {
dest = net.DestinationFromAddr(conn.LocalAddr())
}
if dest.IsValid() {
outbound.Target = dest
outbounds[0].Target = dest
}
}
ctx = session.ContextWithOutbound(ctx, outbound)
ctx = session.ContextWithOutbounds(ctx, outbounds)
if w.uplinkCounter != nil || w.downlinkCounter != nil {
conn = &stat.CounterConnection{
@ -309,9 +309,10 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest
ctx = session.ContextWithID(ctx, sid)
if originalDest.IsValid() {
ctx = session.ContextWithOutbound(ctx, &session.Outbound{
outbounds := []*session.Outbound{{
Target: originalDest,
})
}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
}
ctx = session.ContextWithInbound(ctx, &session.Inbound{
Source: source,

View file

@ -169,10 +169,11 @@ func (h *Handler) Tag() string {
// Dispatch implements proxy.Outbound.Dispatch.
func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
outbound := session.OutboundFromContext(ctx)
if outbound.Target.Network == net.Network_UDP && outbound.OriginalTarget.Address != nil && outbound.OriginalTarget.Address != outbound.Target.Address {
link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if ob.Target.Network == net.Network_UDP && ob.OriginalTarget.Address != nil && ob.OriginalTarget.Address != ob.Target.Address {
link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address}
link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address}
}
if h.mux != nil {
test := func(err error) {
@ -183,7 +184,7 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
common.Interrupt(link.Writer)
}
}
if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 {
if ob.Target.Network == net.Network_UDP && ob.Target.Port == 443 {
switch h.udp443 {
case "reject":
test(newError("XUDP rejected UDP/443 traffic").AtInfo())
@ -192,7 +193,7 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
goto out
}
}
if h.xudp != nil && outbound.Target.Network == net.Network_UDP {
if h.xudp != nil && ob.Target.Network == net.Network_UDP {
if !h.xudp.Enabled {
goto out
}
@ -243,10 +244,11 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti
handler := h.outboundManager.GetHandler(tag)
if handler != nil {
newError("proxying to ", tag, " for dest ", dest).AtDebug().WriteToLog(session.ExportIDToError(ctx))
ctx = session.ContextWithOutbound(ctx, &session.Outbound{
outbounds := session.OutboundsFromContext(ctx)
ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{
Target: dest,
})
Tag: tag,
})) // add another outbound in session ctx
opts := pipe.OptionsFromContext(ctx)
uplinkReader, uplinkWriter := pipe.New(opts...)
downlinkReader, downlinkWriter := pipe.New(opts...)
@ -266,15 +268,12 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti
}
if h.senderSettings.Via != nil {
outbound := session.OutboundFromContext(ctx)
if outbound == nil {
outbound = new(session.Outbound)
ctx = session.ContextWithOutbound(ctx, outbound)
}
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if h.senderSettings.ViaCidr == "" {
outbound.Gateway = h.senderSettings.Via.AsAddress()
ob.Gateway = h.senderSettings.Via.AsAddress()
} else { //Get a random address.
outbound.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr)
ob.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr)
}
}
}
@ -285,10 +284,9 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti
conn, err := internet.Dial(ctx, dest, h.streamSettings)
conn = h.getStatCouterConnection(conn)
outbound := session.OutboundFromContext(ctx)
if outbound != nil {
outbound.Conn = conn
}
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
ob.Conn = conn
return conn, err
}

View file

@ -14,6 +14,7 @@ import (
"github.com/xtls/xray-core/app/stats"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/serial"
"github.com/xtls/xray-core/common/session"
core "github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/outbound"
"github.com/xtls/xray-core/proxy/freedom"
@ -44,6 +45,7 @@ func TestOutboundWithoutStatCounter(t *testing.T) {
v, _ := core.New(config)
v.AddFeature((outbound.Manager)(new(Manager)))
ctx := context.WithValue(context.Background(), xrayKey, v)
ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{{}})
h, _ := NewHandler(ctx, &core.OutboundHandlerConfig{
Tag: "tag",
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
@ -73,6 +75,7 @@ func TestOutboundWithStatCounter(t *testing.T) {
v, _ := core.New(config)
v.AddFeature((outbound.Manager)(new(Manager)))
ctx := context.WithValue(context.Background(), xrayKey, v)
ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{{}})
h, _ := NewHandler(ctx, &core.OutboundHandlerConfig{
Tag: "tag",
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),

View file

@ -62,12 +62,13 @@ func (p *Portal) Close() error {
}
func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) error {
outboundMeta := session.OutboundFromContext(ctx)
if outboundMeta == nil {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if ob == nil {
return newError("outbound metadata not found").AtError()
}
if isDomain(outboundMeta.Target, p.domain) {
if isDomain(ob.Target, p.domain) {
muxClient, err := mux.NewClientWorker(*link, mux.ClientStrategy{})
if err != nil {
return newError("failed to create mux client worker").Base(err).AtWarning()
@ -206,9 +207,10 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
downlinkReader, downlinkWriter := pipe.New(opt...)
ctx := context.Background()
ctx = session.ContextWithOutbound(ctx, &session.Outbound{
outbounds := []*session.Outbound{{
Target: net.UDPDestination(net.DomainAddress(internalDomain), 0),
})
}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
f := client.Dispatch(ctx, &transport.Link{
Reader: uplinkReader,
Writer: downlinkWriter,

View file

@ -45,7 +45,9 @@ func TestSimpleRouter(t *testing.T) {
HandlerSelector: mockHs,
}, nil))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
@ -86,7 +88,9 @@ func TestSimpleBalancer(t *testing.T) {
HandlerSelector: mockHs,
}, nil))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
@ -174,7 +178,9 @@ func TestIPOnDemand(t *testing.T) {
r := new(Router)
common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
@ -213,7 +219,9 @@ func TestIPIfNonMatchDomain(t *testing.T) {
r := new(Router)
common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
@ -247,7 +255,9 @@ func TestIPIfNonMatchIP(t *testing.T) {
r := new(Router)
common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.LocalHostIP, 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {