mirror of
https://github.com/XTLS/Xray-core.git
synced 2025-04-30 01:08:33 +00:00
Add session context outbounds as slice (#3356)
* Add session context outbounds as slice slice is needed for dialer proxy where two outbounds work on top of each other There are two sets of target addr for example It also enable Xtls to correctly do splice copy by checking both outbounds are ready to do direct copy * Fill outbound tag info * Splice now checks capalibility from all outbounds * Fix unit tests
This commit is contained in:
parent
0735053348
commit
017f53b5fc
42 changed files with 303 additions and 236 deletions
|
@ -218,11 +218,12 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
|
|||
if !destination.IsValid() {
|
||||
panic("Dispatcher: Invalid destination.")
|
||||
}
|
||||
ob := session.OutboundFromContext(ctx)
|
||||
if ob == nil {
|
||||
ob = &session.Outbound{}
|
||||
ctx = session.ContextWithOutbound(ctx, ob)
|
||||
outbounds := session.OutboundsFromContext(ctx)
|
||||
if len(outbounds) == 0 {
|
||||
outbounds = []*session.Outbound{{}}
|
||||
ctx = session.ContextWithOutbounds(ctx, outbounds)
|
||||
}
|
||||
ob := outbounds[len(outbounds) - 1]
|
||||
ob.OriginalTarget = destination
|
||||
ob.Target = destination
|
||||
content := session.ContentFromContext(ctx)
|
||||
|
@ -274,11 +275,12 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
|
|||
if !destination.IsValid() {
|
||||
return newError("Dispatcher: Invalid destination.")
|
||||
}
|
||||
ob := session.OutboundFromContext(ctx)
|
||||
if ob == nil {
|
||||
ob = &session.Outbound{}
|
||||
ctx = session.ContextWithOutbound(ctx, ob)
|
||||
outbounds := session.OutboundsFromContext(ctx)
|
||||
if len(outbounds) == 0 {
|
||||
outbounds = []*session.Outbound{{}}
|
||||
ctx = session.ContextWithOutbounds(ctx, outbounds)
|
||||
}
|
||||
ob := outbounds[len(outbounds) - 1]
|
||||
ob.OriginalTarget = destination
|
||||
ob.Target = destination
|
||||
content := session.ContentFromContext(ctx)
|
||||
|
@ -368,7 +370,8 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, netw
|
|||
return contentResult, contentErr
|
||||
}
|
||||
func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
|
||||
ob := session.OutboundFromContext(ctx)
|
||||
outbounds := session.OutboundsFromContext(ctx)
|
||||
ob := outbounds[len(outbounds) - 1]
|
||||
if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() {
|
||||
proxied := hosts.LookupHosts(ob.Target.String())
|
||||
if proxied != nil {
|
||||
|
@ -425,6 +428,7 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
|
|||
return
|
||||
}
|
||||
|
||||
ob.Tag = handler.Tag()
|
||||
if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
|
||||
if tag := handler.Tag(); tag != "" {
|
||||
if inTag == "" {
|
||||
|
|
|
@ -26,11 +26,12 @@ func newFakeDNSSniffer(ctx context.Context) (protocolSnifferWithMetadata, error)
|
|||
return protocolSnifferWithMetadata{}, errNotInit
|
||||
}
|
||||
return protocolSnifferWithMetadata{protocolSniffer: func(ctx context.Context, bytes []byte) (SniffResult, error) {
|
||||
Target := session.OutboundFromContext(ctx).Target
|
||||
if Target.Network == net.Network_TCP || Target.Network == net.Network_UDP {
|
||||
domainFromFakeDNS := fakeDNSEngine.GetDomainFromFakeDNS(Target.Address)
|
||||
outbounds := session.OutboundsFromContext(ctx)
|
||||
ob := outbounds[len(outbounds) - 1]
|
||||
if ob.Target.Network == net.Network_TCP || ob.Target.Network == net.Network_UDP {
|
||||
domainFromFakeDNS := fakeDNSEngine.GetDomainFromFakeDNS(ob.Target.Address)
|
||||
if domainFromFakeDNS != "" {
|
||||
newError("fake dns got domain: ", domainFromFakeDNS, " for ip: ", Target.Address.String()).WriteToLog(session.ExportIDToError(ctx))
|
||||
newError("fake dns got domain: ", domainFromFakeDNS, " for ip: ", ob.Target.Address.String()).WriteToLog(session.ExportIDToError(ctx))
|
||||
return &fakeDNSSniffResult{domainName: domainFromFakeDNS}, nil
|
||||
}
|
||||
}
|
||||
|
@ -38,7 +39,7 @@ func newFakeDNSSniffer(ctx context.Context) (protocolSnifferWithMetadata, error)
|
|||
if ipAddressInRangeValueI := ctx.Value(ipAddressInRange); ipAddressInRangeValueI != nil {
|
||||
ipAddressInRangeValue := ipAddressInRangeValueI.(*ipAddressInRangeOpt)
|
||||
if fkr0, ok := fakeDNSEngine.(dns.FakeDNSEngineRev0); ok {
|
||||
inPool := fkr0.IsIPInIPPool(Target.Address)
|
||||
inPool := fkr0.IsIPInIPPool(ob.Target.Address)
|
||||
ipAddressInRangeValue.addressInRange = &inPool
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,7 +60,7 @@ func (w *tcpWorker) callback(conn stat.Connection) {
|
|||
sid := session.NewID()
|
||||
ctx = session.ContextWithID(ctx, sid)
|
||||
|
||||
var outbound = &session.Outbound{}
|
||||
outbounds := []*session.Outbound{{}}
|
||||
if w.recvOrigDest {
|
||||
var dest net.Destination
|
||||
switch getTProxyType(w.stream) {
|
||||
|
@ -75,10 +75,10 @@ func (w *tcpWorker) callback(conn stat.Connection) {
|
|||
dest = net.DestinationFromAddr(conn.LocalAddr())
|
||||
}
|
||||
if dest.IsValid() {
|
||||
outbound.Target = dest
|
||||
outbounds[0].Target = dest
|
||||
}
|
||||
}
|
||||
ctx = session.ContextWithOutbound(ctx, outbound)
|
||||
ctx = session.ContextWithOutbounds(ctx, outbounds)
|
||||
|
||||
if w.uplinkCounter != nil || w.downlinkCounter != nil {
|
||||
conn = &stat.CounterConnection{
|
||||
|
@ -309,9 +309,10 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest
|
|||
ctx = session.ContextWithID(ctx, sid)
|
||||
|
||||
if originalDest.IsValid() {
|
||||
ctx = session.ContextWithOutbound(ctx, &session.Outbound{
|
||||
outbounds := []*session.Outbound{{
|
||||
Target: originalDest,
|
||||
})
|
||||
}}
|
||||
ctx = session.ContextWithOutbounds(ctx, outbounds)
|
||||
}
|
||||
ctx = session.ContextWithInbound(ctx, &session.Inbound{
|
||||
Source: source,
|
||||
|
|
|
@ -169,10 +169,11 @@ func (h *Handler) Tag() string {
|
|||
|
||||
// Dispatch implements proxy.Outbound.Dispatch.
|
||||
func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
|
||||
outbound := session.OutboundFromContext(ctx)
|
||||
if outbound.Target.Network == net.Network_UDP && outbound.OriginalTarget.Address != nil && outbound.OriginalTarget.Address != outbound.Target.Address {
|
||||
link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
|
||||
link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
|
||||
outbounds := session.OutboundsFromContext(ctx)
|
||||
ob := outbounds[len(outbounds) - 1]
|
||||
if ob.Target.Network == net.Network_UDP && ob.OriginalTarget.Address != nil && ob.OriginalTarget.Address != ob.Target.Address {
|
||||
link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address}
|
||||
link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address}
|
||||
}
|
||||
if h.mux != nil {
|
||||
test := func(err error) {
|
||||
|
@ -183,7 +184,7 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
|
|||
common.Interrupt(link.Writer)
|
||||
}
|
||||
}
|
||||
if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 {
|
||||
if ob.Target.Network == net.Network_UDP && ob.Target.Port == 443 {
|
||||
switch h.udp443 {
|
||||
case "reject":
|
||||
test(newError("XUDP rejected UDP/443 traffic").AtInfo())
|
||||
|
@ -192,7 +193,7 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
|
|||
goto out
|
||||
}
|
||||
}
|
||||
if h.xudp != nil && outbound.Target.Network == net.Network_UDP {
|
||||
if h.xudp != nil && ob.Target.Network == net.Network_UDP {
|
||||
if !h.xudp.Enabled {
|
||||
goto out
|
||||
}
|
||||
|
@ -243,10 +244,11 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti
|
|||
handler := h.outboundManager.GetHandler(tag)
|
||||
if handler != nil {
|
||||
newError("proxying to ", tag, " for dest ", dest).AtDebug().WriteToLog(session.ExportIDToError(ctx))
|
||||
ctx = session.ContextWithOutbound(ctx, &session.Outbound{
|
||||
outbounds := session.OutboundsFromContext(ctx)
|
||||
ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{
|
||||
Target: dest,
|
||||
})
|
||||
|
||||
Tag: tag,
|
||||
})) // add another outbound in session ctx
|
||||
opts := pipe.OptionsFromContext(ctx)
|
||||
uplinkReader, uplinkWriter := pipe.New(opts...)
|
||||
downlinkReader, downlinkWriter := pipe.New(opts...)
|
||||
|
@ -266,15 +268,12 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti
|
|||
}
|
||||
|
||||
if h.senderSettings.Via != nil {
|
||||
outbound := session.OutboundFromContext(ctx)
|
||||
if outbound == nil {
|
||||
outbound = new(session.Outbound)
|
||||
ctx = session.ContextWithOutbound(ctx, outbound)
|
||||
}
|
||||
outbounds := session.OutboundsFromContext(ctx)
|
||||
ob := outbounds[len(outbounds) - 1]
|
||||
if h.senderSettings.ViaCidr == "" {
|
||||
outbound.Gateway = h.senderSettings.Via.AsAddress()
|
||||
ob.Gateway = h.senderSettings.Via.AsAddress()
|
||||
} else { //Get a random address.
|
||||
outbound.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr)
|
||||
ob.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -285,10 +284,9 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti
|
|||
|
||||
conn, err := internet.Dial(ctx, dest, h.streamSettings)
|
||||
conn = h.getStatCouterConnection(conn)
|
||||
outbound := session.OutboundFromContext(ctx)
|
||||
if outbound != nil {
|
||||
outbound.Conn = conn
|
||||
}
|
||||
outbounds := session.OutboundsFromContext(ctx)
|
||||
ob := outbounds[len(outbounds) - 1]
|
||||
ob.Conn = conn
|
||||
return conn, err
|
||||
}
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/xtls/xray-core/app/stats"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/serial"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
core "github.com/xtls/xray-core/core"
|
||||
"github.com/xtls/xray-core/features/outbound"
|
||||
"github.com/xtls/xray-core/proxy/freedom"
|
||||
|
@ -44,6 +45,7 @@ func TestOutboundWithoutStatCounter(t *testing.T) {
|
|||
v, _ := core.New(config)
|
||||
v.AddFeature((outbound.Manager)(new(Manager)))
|
||||
ctx := context.WithValue(context.Background(), xrayKey, v)
|
||||
ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{{}})
|
||||
h, _ := NewHandler(ctx, &core.OutboundHandlerConfig{
|
||||
Tag: "tag",
|
||||
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
|
||||
|
@ -73,6 +75,7 @@ func TestOutboundWithStatCounter(t *testing.T) {
|
|||
v, _ := core.New(config)
|
||||
v.AddFeature((outbound.Manager)(new(Manager)))
|
||||
ctx := context.WithValue(context.Background(), xrayKey, v)
|
||||
ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{{}})
|
||||
h, _ := NewHandler(ctx, &core.OutboundHandlerConfig{
|
||||
Tag: "tag",
|
||||
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
|
||||
|
|
|
@ -62,12 +62,13 @@ func (p *Portal) Close() error {
|
|||
}
|
||||
|
||||
func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) error {
|
||||
outboundMeta := session.OutboundFromContext(ctx)
|
||||
if outboundMeta == nil {
|
||||
outbounds := session.OutboundsFromContext(ctx)
|
||||
ob := outbounds[len(outbounds) - 1]
|
||||
if ob == nil {
|
||||
return newError("outbound metadata not found").AtError()
|
||||
}
|
||||
|
||||
if isDomain(outboundMeta.Target, p.domain) {
|
||||
if isDomain(ob.Target, p.domain) {
|
||||
muxClient, err := mux.NewClientWorker(*link, mux.ClientStrategy{})
|
||||
if err != nil {
|
||||
return newError("failed to create mux client worker").Base(err).AtWarning()
|
||||
|
@ -206,9 +207,10 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
|
|||
downlinkReader, downlinkWriter := pipe.New(opt...)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = session.ContextWithOutbound(ctx, &session.Outbound{
|
||||
outbounds := []*session.Outbound{{
|
||||
Target: net.UDPDestination(net.DomainAddress(internalDomain), 0),
|
||||
})
|
||||
}}
|
||||
ctx = session.ContextWithOutbounds(ctx, outbounds)
|
||||
f := client.Dispatch(ctx, &transport.Link{
|
||||
Reader: uplinkReader,
|
||||
Writer: downlinkWriter,
|
||||
|
|
|
@ -45,7 +45,9 @@ func TestSimpleRouter(t *testing.T) {
|
|||
HandlerSelector: mockHs,
|
||||
}, nil))
|
||||
|
||||
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
|
||||
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
|
||||
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
|
||||
}})
|
||||
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
|
||||
common.Must(err)
|
||||
if tag := route.GetOutboundTag(); tag != "test" {
|
||||
|
@ -86,7 +88,9 @@ func TestSimpleBalancer(t *testing.T) {
|
|||
HandlerSelector: mockHs,
|
||||
}, nil))
|
||||
|
||||
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
|
||||
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
|
||||
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
|
||||
}})
|
||||
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
|
||||
common.Must(err)
|
||||
if tag := route.GetOutboundTag(); tag != "test" {
|
||||
|
@ -174,7 +178,9 @@ func TestIPOnDemand(t *testing.T) {
|
|||
r := new(Router)
|
||||
common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil))
|
||||
|
||||
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
|
||||
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
|
||||
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
|
||||
}})
|
||||
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
|
||||
common.Must(err)
|
||||
if tag := route.GetOutboundTag(); tag != "test" {
|
||||
|
@ -213,7 +219,9 @@ func TestIPIfNonMatchDomain(t *testing.T) {
|
|||
r := new(Router)
|
||||
common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil))
|
||||
|
||||
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
|
||||
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
|
||||
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
|
||||
}})
|
||||
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
|
||||
common.Must(err)
|
||||
if tag := route.GetOutboundTag(); tag != "test" {
|
||||
|
@ -247,7 +255,9 @@ func TestIPIfNonMatchIP(t *testing.T) {
|
|||
r := new(Router)
|
||||
common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil))
|
||||
|
||||
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
|
||||
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
|
||||
Target: net.TCPDestination(net.LocalHostIP, 80),
|
||||
}})
|
||||
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
|
||||
common.Must(err)
|
||||
if tag := route.GetOutboundTag(); tag != "test" {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue