diff --git a/common/buf/buffer.go b/common/buf/buffer.go index 2fa43ec5..0365e02b 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -2,6 +2,7 @@ package buf import ( "io" + "net" "github.com/xtls/xray-core/common/bytespool" ) @@ -20,6 +21,7 @@ type Buffer struct { v []byte start int32 end int32 + UDP *net.UDPAddr } // New creates a Buffer with 0 length and 2K capacity. diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index edb31d18..4a1be462 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -17,6 +17,7 @@ import ( "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/features/policy" + "github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/internet" ) @@ -148,7 +149,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte if destination.Network == net.Network_TCP { writer = buf.NewWriter(conn) } else { - writer = &buf.SequentialWriter{Writer: conn} + writer = NewPacketWriter(conn) } if err := buf.Copy(input, writer, buf.UpdateActivity(timer)); err != nil { @@ -165,7 +166,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte if destination.Network == net.Network_TCP { reader = buf.NewReader(conn) } else { - reader = buf.NewPacketReader(conn) + reader = NewPacketReader(conn) } if err := buf.Copy(reader, output, buf.UpdateActivity(timer)); err != nil { return newError("failed to process response").Base(err) @@ -180,3 +181,93 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte return nil } + +func NewPacketReader(conn net.Conn) buf.Reader { + iConn := conn + statConn, ok := iConn.(*internet.StatCouterConnection) + if ok { + iConn = statConn.Connection + } + var counter stats.Counter + if statConn != nil { + counter = statConn.ReadCounter + } + if c, ok := iConn.(*internet.PacketConnWrapper); ok { + return &PacketReader{ + PacketConnWrapper: c, + Counter: counter, + } + } + return &buf.PacketReader{Reader: conn} +} + +type PacketReader struct { + *internet.PacketConnWrapper + stats.Counter +} + +func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { + b := buf.New() + b.Resize(0, buf.Size) + n, d, err := r.PacketConnWrapper.ReadFrom(b.Bytes()) + if err != nil { + b.Release() + return nil, err + } + b.Resize(0, int32(n)) + b.UDP = d.(*net.UDPAddr) + if r.Counter != nil { + r.Counter.Add(int64(n)) + } + return buf.MultiBuffer{b}, nil +} + +func NewPacketWriter(conn net.Conn) buf.Writer { + iConn := conn + statConn, ok := iConn.(*internet.StatCouterConnection) + if ok { + iConn = statConn.Connection + } + var counter stats.Counter + if statConn != nil { + counter = statConn.WriteCounter + } + if c, ok := iConn.(*internet.PacketConnWrapper); ok { + return &PacketWriter{ + PacketConnWrapper: c, + Counter: counter, + } + } + return &buf.SequentialWriter{Writer: conn} +} + +type PacketWriter struct { + *internet.PacketConnWrapper + stats.Counter +} + +func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { + for { + mb2, b := buf.SplitFirst(mb) + mb = mb2 + if b == nil { + break + } + var n int + var err error + if b.UDP != nil { + n, err = w.PacketConnWrapper.WriteTo(b.Bytes(), b.UDP) + } else { + n, err = w.PacketConnWrapper.Write(b.Bytes()) + } + b.Release() + if err != nil { + buf.ReleaseMulti(mb) + return err + } + if w.Counter != nil { + w.Counter.Add(int64(n)) + } + } + return nil +} diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index 06766be1..374bf8d4 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -134,14 +134,15 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter } if request.Command == protocol.RequestCommandUDP { - writer := &buf.SequentialWriter{Writer: &UDPWriter{ - Writer: conn, - Request: request, - }} requestDone := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) + writer := &UDPWriter{ + Writer: conn, + Request: request, + } + if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil { return newError("failed to transport all UDP request").Base(err) } diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index 12ed9b3f..9cb74fc3 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -230,11 +230,15 @@ func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { buffer.Release() return nil, err } - _, payload, err := DecodeUDPPacket(v.User, buffer) + u, payload, err := DecodeUDPPacket(v.User, buffer) if err != nil { buffer.Release() return nil, err } + payload.UDP = &net.UDPAddr{ + IP: u.Address.IP(), + Port: int(u.Port), + } return buf.MultiBuffer{payload}, nil } @@ -243,13 +247,36 @@ type UDPWriter struct { Request *protocol.RequestHeader } -// Write implements io.Writer. -func (w *UDPWriter) Write(payload []byte) (int, error) { - packet, err := EncodeUDPPacket(w.Request, payload) - if err != nil { - return 0, err +func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { + for { + mb2, b := buf.SplitFirst(mb) + mb = mb2 + if b == nil { + break + } + var packet *buf.Buffer + var err error + if b.UDP != nil { + request := &protocol.RequestHeader{ + User: w.Request.User, + Address: net.IPAddress(b.UDP.IP), + Port: net.Port(b.UDP.Port), + } + packet, err = EncodeUDPPacket(request, b.Bytes()) + } else { + packet, err = EncodeUDPPacket(w.Request, b.Bytes()) + } + b.Release() + if err != nil { + buf.ReleaseMulti(mb) + return err + } + _, err = w.Writer.Write(packet.Bytes()) + packet.Release() + if err != nil { + buf.ReleaseMulti(mb) + return err + } } - _, err = w.Writer.Write(packet.Bytes()) - packet.Release() - return len(payload), err + return nil } diff --git a/proxy/shadowsocks/protocol_test.go b/proxy/shadowsocks/protocol_test.go index 5429a8b9..9654e9df 100644 --- a/proxy/shadowsocks/protocol_test.go +++ b/proxy/shadowsocks/protocol_test.go @@ -145,7 +145,7 @@ func TestUDPReaderWriter(t *testing.T) { cache := buf.New() defer cache.Release() - writer := &buf.SequentialWriter{Writer: &UDPWriter{ + writer := &UDPWriter{ Writer: cache, Request: &protocol.RequestHeader{ Version: Version, @@ -153,7 +153,7 @@ func TestUDPReaderWriter(t *testing.T) { Port: 123, User: user, }, - }} + } reader := &UDPReader{ Reader: cache, diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 07316020..090d5c56 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -77,6 +77,15 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection } payload := packet.Payload + + if payload.UDP != nil { + request = &protocol.RequestHeader{ + User: request.User, + Address: net.IPAddress(payload.UDP.IP), + Port: net.Port(payload.UDP.Port), + } + } + data, err := EncodeUDPPacket(request, payload.Bytes()) payload.Release() if err != nil { @@ -94,6 +103,8 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection } inbound.User = s.user + var dest net.Destination + reader := buf.NewPacketReader(conn) for { mpayload, err := reader.ReadMultiBuffer() @@ -118,17 +129,25 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection } currentPacketCtx := ctx - dest := request.Destination() if inbound.Source.IsValid() { currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ From: inbound.Source, - To: dest, + To: request.Destination(), Status: log.AccessAccepted, Reason: "", Email: request.User.Email, }) } - newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(currentPacketCtx)) + newError("tunnelling request to ", request.Destination()).WriteToLog(session.ExportIDToError(currentPacketCtx)) + + data.UDP = &net.UDPAddr{ + IP: request.Address.IP(), + Port: int(request.Port), + } + + if dest.Network == 0 { + dest = request.Destination() // JUST FOLLOW THE FIREST PACKET + } currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request) udpServer.Dispatch(currentPacketCtx, dest, data) diff --git a/proxy/socks/server.go b/proxy/socks/server.go index fa75f1d6..620d5dae 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -196,6 +196,15 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, if request == nil { return } + + if payload.UDP != nil { + request = &protocol.RequestHeader{ + User: request.User, + Address: net.IPAddress(payload.UDP.IP), + Port: net.Port(payload.UDP.Port), + } + } + udpMessage, err := EncodeUDPPacket(request, payload.Bytes()) payload.Release() @@ -211,6 +220,8 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, newError("client UDP connection from ", inbound.Source).WriteToLog(session.ExportIDToError(ctx)) } + var dest net.Destination + reader := buf.NewPacketReader(conn) for { mpayload, err := reader.ReadMultiBuffer() @@ -242,8 +253,17 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, }) } + payload.UDP = &net.UDPAddr{ + IP: request.Address.IP(), + Port: int(request.Port), + } + + if dest.Network == 0 { + dest = request.Destination() // JUST FOLLOW THE FIREST PACKET + } + currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request) - udpServer.Dispatch(currentPacketCtx, request.Destination(), payload) + udpServer.Dispatch(currentPacketCtx, dest, payload) } } } diff --git a/proxy/trojan/protocol.go b/proxy/trojan/protocol.go index 07fe188b..c5643805 100644 --- a/proxy/trojan/protocol.go +++ b/proxy/trojan/protocol.go @@ -128,31 +128,43 @@ type PacketWriter struct { // WriteMultiBuffer implements buf.Writer func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { - b := make([]byte, maxLength) - for !mb.IsEmpty() { - var length int - mb, length = buf.SplitBytes(mb, b) - if _, err := w.writePacket(b[:length], w.Target); err != nil { + for { + mb2, b := buf.SplitFirst(mb) + mb = mb2 + if b == nil { + break + } + target := w.Target + if b.UDP != nil { + target.Address = net.IPAddress(b.UDP.IP) + target.Port = net.Port(b.UDP.Port) + } + if _, err := w.writePacket(b.Bytes(), target); err != nil { buf.ReleaseMulti(mb) return err } } - return nil } // WriteMultiBufferWithMetadata writes udp packet with destination specified func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net.Destination) error { - b := make([]byte, maxLength) - for !mb.IsEmpty() { - var length int - mb, length = buf.SplitBytes(mb, b) - if _, err := w.writePacket(b[:length], dest); err != nil { + for { + mb2, b := buf.SplitFirst(mb) + mb = mb2 + if b == nil { + break + } + source := dest + if b.UDP != nil { + source.Address = net.IPAddress(b.UDP.IP) + source.Port = net.Port(b.UDP.Port) + } + if _, err := w.writePacket(b.Bytes(), source); err != nil { buf.ReleaseMulti(mb) return err } } - return nil } @@ -300,6 +312,10 @@ func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) { } b := buf.New() + b.UDP = &net.UDPAddr{ + IP: addr.IP(), + Port: int(port.Value()), + } mb = append(mb, b) n, err := b.ReadFullFrom(r, int32(length)) if err != nil { diff --git a/proxy/trojan/server.go b/proxy/trojan/server.go index b43af83b..60ff7622 100644 --- a/proxy/trojan/server.go +++ b/proxy/trojan/server.go @@ -256,6 +256,8 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade inbound := session.InboundFromContext(ctx) user := inbound.User + var dest net.Destination + for { select { case <-ctx.Done(): @@ -278,8 +280,12 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade }) newError("tunnelling request to ", p.Target).WriteToLog(session.ExportIDToError(ctx)) + if dest.Network == 0 { + dest = p.Target // JUST FOLLOW THE FIREST PACKET + } + for _, b := range p.Buffer { - udpServer.Dispatch(ctx, p.Target, b) + udpServer.Dispatch(ctx, dest, b) } } } diff --git a/transport/internet/system_dialer.go b/transport/internet/system_dialer.go index ef925971..52fe5e76 100644 --- a/transport/internet/system_dialer.go +++ b/transport/internet/system_dialer.go @@ -60,7 +60,7 @@ func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest ne if err != nil { return nil, err } - return &packetConnWrapper{ + return &PacketConnWrapper{ conn: packetConn, dest: destAddr, }, nil @@ -98,41 +98,49 @@ func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest ne return dialer.DialContext(ctx, dest.Network.SystemString(), dest.NetAddr()) } -type packetConnWrapper struct { +type PacketConnWrapper struct { conn net.PacketConn dest net.Addr } -func (c *packetConnWrapper) Close() error { +func (c *PacketConnWrapper) Close() error { return c.conn.Close() } -func (c *packetConnWrapper) LocalAddr() net.Addr { +func (c *PacketConnWrapper) LocalAddr() net.Addr { return c.conn.LocalAddr() } -func (c *packetConnWrapper) RemoteAddr() net.Addr { +func (c *PacketConnWrapper) RemoteAddr() net.Addr { return c.dest } -func (c *packetConnWrapper) Write(p []byte) (int, error) { +func (c *PacketConnWrapper) Write(p []byte) (int, error) { return c.conn.WriteTo(p, c.dest) } -func (c *packetConnWrapper) Read(p []byte) (int, error) { +func (c *PacketConnWrapper) Read(p []byte) (int, error) { n, _, err := c.conn.ReadFrom(p) return n, err } -func (c *packetConnWrapper) SetDeadline(t time.Time) error { +func (c *PacketConnWrapper) WriteTo(p []byte, d net.Addr) (int, error) { + return c.conn.WriteTo(p, d) +} + +func (c *PacketConnWrapper) ReadFrom(p []byte) (int, net.Addr, error) { + return c.conn.ReadFrom(p) +} + +func (c *PacketConnWrapper) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } -func (c *packetConnWrapper) SetReadDeadline(t time.Time) error { +func (c *PacketConnWrapper) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } -func (c *packetConnWrapper) SetWriteDeadline(t time.Time) error { +func (c *PacketConnWrapper) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }