package udp import ( "context" "io" "sync" "time" "github.com/xtls/xray-core/common/signal/done" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol/udp" "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport" ) type ResponseCallback func(ctx context.Context, packet *udp.Packet) type connEntry struct { link *transport.Link timer signal.ActivityUpdater cancel context.CancelFunc } type Dispatcher struct { sync.RWMutex conns map[net.Destination]*connEntry dispatcher routing.Dispatcher callback ResponseCallback } func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher { return &Dispatcher{ conns: make(map[net.Destination]*connEntry), dispatcher: dispatcher, callback: callback, } } func (v *Dispatcher) RemoveRay(dest net.Destination) { v.Lock() defer v.Unlock() if conn, found := v.conns[dest]; found { common.Close(conn.link.Reader) common.Close(conn.link.Writer) delete(v.conns, dest) } } func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) *connEntry { v.Lock() defer v.Unlock() if entry, found := v.conns[dest]; found { return entry } newError("establishing new connection for ", dest).WriteToLog() ctx, cancel := context.WithCancel(ctx) removeRay := func() { cancel() v.RemoveRay(dest) } timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute) link, _ := v.dispatcher.Dispatch(ctx, dest) entry := &connEntry{ link: link, timer: timer, cancel: removeRay, } v.conns[dest] = entry go handleInput(ctx, entry, dest, v.callback) return entry } func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) { // TODO: Add user to destString newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx)) conn := v.getInboundRay(ctx, destination) outputStream := conn.link.Writer if outputStream != nil { if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil { newError("failed to write first UDP payload").Base(err).WriteToLog(session.ExportIDToError(ctx)) conn.cancel() return } } } func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback) { defer conn.cancel() input := conn.link.Reader timer := conn.timer for { select { case <-ctx.Done(): return default: } mb, err := input.ReadMultiBuffer() if err != nil { newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx)) return } timer.Update() for _, b := range mb { callback(ctx, &udp.Packet{ Payload: b, Source: dest, }) } } } type dispatcherConn struct { dispatcher *Dispatcher cache chan *udp.Packet done *done.Instance } func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) { c := &dispatcherConn{ cache: make(chan *udp.Packet, 16), done: done.New(), } d := NewDispatcher(dispatcher, c.callback) c.dispatcher = d return c, nil } func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) { select { case <-c.done.Wait(): packet.Payload.Release() return case c.cache <- packet: default: packet.Payload.Release() return } } func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) { select { case <-c.done.Wait(): return 0, nil, io.EOF case packet := <-c.cache: n := copy(p, packet.Payload.Bytes()) return n, &net.UDPAddr{ IP: packet.Source.Address.IP(), Port: int(packet.Source.Port), }, nil } } func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) { buffer := buf.New() raw := buffer.Extend(buf.Size) n := copy(raw, p) buffer.Resize(0, int32(n)) ctx := context.Background() c.dispatcher.Dispatch(ctx, net.DestinationFromAddr(addr), buffer) return n, nil } func (c *dispatcherConn) Close() error { return c.done.Close() } func (c *dispatcherConn) LocalAddr() net.Addr { return &net.UDPAddr{ IP: []byte{0, 0, 0, 0}, Port: 0, } } func (c *dispatcherConn) SetDeadline(t time.Time) error { return nil } func (c *dispatcherConn) SetReadDeadline(t time.Time) error { return nil } func (c *dispatcherConn) SetWriteDeadline(t time.Time) error { return nil }