From 3dd3bf94d4e2519eb549803c3be08fe9368b715c Mon Sep 17 00:00:00 2001 From: mmmray <142015632+mmmray@users.noreply.github.com> Date: Sun, 25 Aug 2024 21:02:01 +0200 Subject: [PATCH] Fix data leak between mux.cool connections (#3718) Fix #116 --- common/mux/server.go | 5 +- common/mux/server_test.go | 124 ++++++++++++++++++++++++++++++++++++++ common/session/context.go | 16 +++++ 3 files changed, 144 insertions(+), 1 deletion(-) create mode 100644 common/mux/server_test.go diff --git a/common/mux/server.go b/common/mux/server.go index 5a4e9974..480175ba 100644 --- a/common/mux/server.go +++ b/common/mux/server.go @@ -118,6 +118,9 @@ func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.Bu } func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader *buf.BufferedReader) error { + // deep-clone outbounds because it is going to be mutated concurrently + // (Target and OriginalTarget) + ctx = session.ContextCloneOutbounds(ctx) errors.LogInfo(ctx, "received request for ", meta.Target) { msg := &log.AccessMessage{ @@ -170,7 +173,7 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, b.Release() mb = nil } - errors.LogInfoInner(ctx, err,"XUDP hit ", meta.GlobalID) + errors.LogInfoInner(ctx, err, "XUDP hit ", meta.GlobalID) } if mb != nil { ctx = session.ContextWithTimeoutOnly(ctx, true) diff --git a/common/mux/server_test.go b/common/mux/server_test.go new file mode 100644 index 00000000..4158bf46 --- /dev/null +++ b/common/mux/server_test.go @@ -0,0 +1,124 @@ +package mux_test + +import ( + "context" + "testing" + + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/buf" + "github.com/xtls/xray-core/common/mux" + "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/session" + "github.com/xtls/xray-core/features/routing" + "github.com/xtls/xray-core/transport" + "github.com/xtls/xray-core/transport/pipe" +) + +func newLinkPair() (*transport.Link, *transport.Link) { + opt := pipe.WithoutSizeLimit() + uplinkReader, uplinkWriter := pipe.New(opt) + downlinkReader, downlinkWriter := pipe.New(opt) + + uplink := &transport.Link{ + Reader: uplinkReader, + Writer: downlinkWriter, + } + + downlink := &transport.Link{ + Reader: downlinkReader, + Writer: uplinkWriter, + } + + return uplink, downlink +} + +type TestDispatcher struct { + OnDispatch func(ctx context.Context, dest net.Destination) (*transport.Link, error) +} + +func (d *TestDispatcher) Dispatch(ctx context.Context, dest net.Destination) (*transport.Link, error) { + return d.OnDispatch(ctx, dest) +} + +func (d *TestDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error { + return nil +} + +func (d *TestDispatcher) Start() error { + return nil +} + +func (d *TestDispatcher) Close() error { + return nil +} + +func (*TestDispatcher) Type() interface{} { + return routing.DispatcherType() +} + +func TestRegressionOutboundLeak(t *testing.T) { + originalOutbounds := []*session.Outbound{{}} + serverCtx := session.ContextWithOutbounds(context.Background(), originalOutbounds) + + websiteUplink, websiteDownlink := newLinkPair() + + dispatcher := TestDispatcher{ + OnDispatch: func(ctx context.Context, dest net.Destination) (*transport.Link, error) { + // emulate what DefaultRouter.Dispatch does, and mutate something on the context + ob := session.OutboundsFromContext(ctx)[0] + ob.Target = dest + return websiteDownlink, nil + }, + } + + muxServerUplink, muxServerDownlink := newLinkPair() + _, err := mux.NewServerWorker(serverCtx, &dispatcher, muxServerUplink) + common.Must(err) + + client, err := mux.NewClientWorker(*muxServerDownlink, mux.ClientStrategy{}) + common.Must(err) + + clientCtx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{ + Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80), + }}) + + muxClientUplink, muxClientDownlink := newLinkPair() + + ok := client.Dispatch(clientCtx, muxClientUplink) + if !ok { + t.Error("failed to dispatch") + } + + { + b := buf.FromBytes([]byte("hello")) + common.Must(muxClientDownlink.Writer.WriteMultiBuffer(buf.MultiBuffer{b})) + } + + resMb, err := websiteUplink.Reader.ReadMultiBuffer() + common.Must(err) + res := resMb.String() + if res != "hello" { + t.Error("upload: ", res) + } + + { + b := buf.FromBytes([]byte("world")) + common.Must(websiteUplink.Writer.WriteMultiBuffer(buf.MultiBuffer{b})) + } + + resMb, err = muxClientDownlink.Reader.ReadMultiBuffer() + common.Must(err) + res = resMb.String() + if res != "world" { + t.Error("download: ", res) + } + + outbounds := session.OutboundsFromContext(serverCtx) + if outbounds[0] != originalOutbounds[0] { + t.Error("outbound got reassigned: ", outbounds[0]) + } + + if outbounds[0].Target.Address != nil { + t.Error("outbound target got leaked: ", outbounds[0].Target.String()) + } +} diff --git a/common/session/context.go b/common/session/context.go index 3fed0151..b7af69cc 100644 --- a/common/session/context.go +++ b/common/session/context.go @@ -40,6 +40,22 @@ func ContextWithOutbounds(ctx context.Context, outbounds []*Outbound) context.Co return context.WithValue(ctx, outboundSessionKey, outbounds) } +func ContextCloneOutbounds(ctx context.Context) context.Context { + outbounds := OutboundsFromContext(ctx) + newOutbounds := make([]*Outbound, len(outbounds)) + for i, ob := range outbounds { + if ob == nil { + continue + } + + // copy outbound by value + v := *ob + newOutbounds[i] = &v + } + + return ContextWithOutbounds(ctx, newOutbounds) +} + func OutboundsFromContext(ctx context.Context) []*Outbound { if outbounds, ok := ctx.Value(outboundSessionKey).([]*Outbound); ok { return outbounds