From da5a28a088091b86ac5b70ca732fc11cdb4c43fe Mon Sep 17 00:00:00 2001 From: dyhkwong <50692134+dyhkwong@users.noreply.github.com> Date: Mon, 15 Jan 2024 23:33:15 +0800 Subject: [PATCH] Fix #2654 (#2941) * fix udp dispatcher * fix test --- transport/internet/udp/dispatcher.go | 29 +++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/transport/internet/udp/dispatcher.go b/transport/internet/udp/dispatcher.go index 32c8c8ac..c29d4b13 100644 --- a/transport/internet/udp/dispatcher.go +++ b/transport/internet/udp/dispatcher.go @@ -28,7 +28,7 @@ type connEntry struct { type Dispatcher struct { sync.RWMutex - conns map[net.Destination]*connEntry + conn *connEntry dispatcher routing.Dispatcher callback ResponseCallback callClose func() error @@ -36,19 +36,18 @@ type Dispatcher struct { func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher { return &Dispatcher{ - conns: make(map[net.Destination]*connEntry), dispatcher: dispatcher, callback: callback, } } -func (v *Dispatcher) RemoveRay(dest net.Destination) { +func (v *Dispatcher) RemoveRay() { v.Lock() defer v.Unlock() - if conn, found := v.conns[dest]; found { - common.Close(conn.link.Reader) - common.Close(conn.link.Writer) - delete(v.conns, dest) + if v.conn != nil { + common.Close(v.conn.link.Reader) + common.Close(v.conn.link.Writer) + v.conn = nil } } @@ -56,8 +55,8 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (* v.Lock() defer v.Unlock() - if entry, found := v.conns[dest]; found { - return entry, nil + if v.conn != nil { + return v.conn, nil } newError("establishing new connection for ", dest).WriteToLog() @@ -65,7 +64,7 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (* ctx, cancel := context.WithCancel(ctx) removeRay := func() { cancel() - v.RemoveRay(dest) + v.RemoveRay() } timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute) @@ -79,7 +78,7 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (* timer: timer, cancel: removeRay, } - v.conns[dest] = entry + v.conn = entry go handleInput(ctx, entry, dest, v.callback, v.callClose) return entry, nil } @@ -130,6 +129,9 @@ func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, cal } timer.Update() for _, b := range mb { + if b.UDP != nil { + dest = *b.UDP + } callback(ctx, &udp.Packet{ Payload: b, Source: dest, @@ -153,7 +155,6 @@ func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.Pac } d := &Dispatcher{ - conns: make(map[net.Destination]*connEntry), dispatcher: dispatcher, callback: c.callback, callClose: c.Close, @@ -199,7 +200,9 @@ func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) { n := copy(raw, p) buffer.Resize(0, int32(n)) - c.dispatcher.Dispatch(c.ctx, net.DestinationFromAddr(addr), buffer) + destination := net.DestinationFromAddr(addr) + buffer.UDP = &destination + c.dispatcher.Dispatch(c.ctx, destination, buffer) return n, nil }