diff --git a/transport/internet/udp/dispatcher.go b/transport/internet/udp/dispatcher.go index a8d9c6f5..48b90b63 100644 --- a/transport/internet/udp/dispatcher.go +++ b/transport/internet/udp/dispatcher.go @@ -31,6 +31,7 @@ type Dispatcher struct { conns map[net.Destination]*connEntry dispatcher routing.Dispatcher callback ResponseCallback + callClose func() error } func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher { @@ -79,7 +80,7 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (* cancel: removeRay, } v.conns[dest] = entry - go handleInput(ctx, entry, dest, v.callback) + go handleInput(ctx, entry, dest, v.callback, v.callClose) return entry, nil } @@ -102,8 +103,13 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, } } -func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback) { - defer conn.cancel() +func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) { + defer func() { + conn.cancel() + if callClose != nil { + callClose() + } + }() input := conn.link.Reader timer := conn.timer @@ -144,7 +150,12 @@ func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.Pac done: done.New(), } - d := NewDispatcher(dispatcher, c.callback) + d := &Dispatcher{ + conns: make(map[net.Destination]*connEntry), + dispatcher: dispatcher, + callback: c.callback, + callClose: c.Close, + } c.dispatcher = d return c, nil } @@ -162,16 +173,22 @@ func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) { } func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) { + var packet *udp.Packet +s: select { case <-c.done.Wait(): - return 0, nil, io.EOF - case packet := <-c.cache: - n := copy(p, packet.Payload.Bytes()) - return n, &net.UDPAddr{ - IP: packet.Source.Address.IP(), - Port: int(packet.Source.Port), - }, nil + select { + case packet = <-c.cache: + break s + default: + return 0, nil, io.EOF + } + case packet = <-c.cache: } + return copy(p, packet.Payload.Bytes()), &net.UDPAddr{ + IP: packet.Source.Address.IP(), + Port: int(packet.Source.Port), + }, nil } func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) {