diff --git a/transport/internet/udp/dispatcher.go b/transport/internet/udp/dispatcher.go index f8884003..e6267d2f 100644 --- a/transport/internet/udp/dispatcher.go +++ b/transport/internet/udp/dispatcher.go @@ -24,15 +24,6 @@ type connEntry struct { link *transport.Link timer signal.ActivityUpdater cancel context.CancelFunc - closed bool -} - -func (c *connEntry) Close() error { - c.closed = true - c.cancel() - common.Interrupt(c.link.Reader) - common.Close(c.link.Writer) - return nil } type Dispatcher struct { @@ -54,7 +45,8 @@ func (v *Dispatcher) RemoveRay() { v.Lock() defer v.Unlock() if v.conn != nil { - v.conn.Close() + common.Interrupt(v.conn.link.Reader) + common.Close(v.conn.link.Writer) v.conn = nil } } @@ -64,32 +56,28 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (* defer v.Unlock() if v.conn != nil { - if v.conn.closed { - v.conn = nil - } else { - return v.conn, nil - } + return v.conn, nil } errors.LogInfo(ctx, "establishing new connection for ", dest) ctx, cancel := context.WithCancel(ctx) + removeRay := func() { + cancel() + v.RemoveRay() + } + timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute) link, err := v.dispatcher.Dispatch(ctx, dest) if err != nil { - cancel() return nil, errors.New("failed to dispatch request to ", dest).Base(err) } entry := &connEntry{ link: link, - cancel: cancel, + timer: timer, + cancel: removeRay, } - entryClose := func() { - entry.Close() - } - - entry.timer = signal.CancelAfterInactivity(ctx, entryClose, time.Minute) v.conn = entry go handleInput(ctx, entry, dest, v.callback, v.callClose) return entry, nil @@ -108,7 +96,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, if outputStream != nil { if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil { errors.LogInfoInner(ctx, err, "failed to write first UDP payload") - conn.Close() + conn.cancel() return } } @@ -116,7 +104,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) { defer func() { - conn.Close() + conn.cancel() if callClose != nil { callClose() } @@ -148,7 +136,6 @@ func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, cal Payload: b, Source: dest, }) - b.Release() } } }