diff --git a/proxy/trojan/protocol.go b/proxy/trojan/protocol.go index c48b9f28..1273b0a2 100644 --- a/proxy/trojan/protocol.go +++ b/proxy/trojan/protocol.go @@ -146,26 +146,6 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { return nil } -// WriteMultiBufferWithMetadata writes udp packet with destination specified -func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net.Destination) error { - for { - mb2, b := buf.SplitFirst(mb) - mb = mb2 - if b == nil { - break - } - source := &dest - if b.UDP != nil { - source = b.UDP - } - if _, err := w.writePacket(b.Bytes(), *source); err != nil { - buf.ReleaseMulti(mb) - return err - } - } - return nil -} - func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) { buffer := buf.StackNew() defer buffer.Release() @@ -259,12 +239,6 @@ func (c *ConnReader) ReadMultiBuffer() (buf.MultiBuffer, error) { return buf.MultiBuffer{b}, err } -// PacketPayload combines udp payload and destination -type PacketPayload struct { - Target net.Destination - Buffer buf.MultiBuffer -} - // PacketReader is UDP Connection Reader Wrapper for trojan protocol type PacketReader struct { io.Reader @@ -272,15 +246,6 @@ type PacketReader struct { // ReadMultiBuffer implements buf.Reader func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { - p, err := r.ReadMultiBufferWithMetadata() - if p != nil { - return p.Buffer, err - } - return nil, err -} - -// ReadMultiBufferWithMetadata reads udp packet with destination -func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) { addr, port, err := addrParser.ReadAddressPort(nil, r) if err != nil { return nil, newError("failed to read address and port").Base(err) @@ -321,7 +286,7 @@ func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) { remain -= int(n) } - return &PacketPayload{Target: dest, Buffer: mb}, nil + return mb, nil } func ReadV(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn *xtls.Conn, rawConn syscall.RawConn, counter stats.Counter, sctx context.Context) error { diff --git a/proxy/trojan/protocol_test.go b/proxy/trojan/protocol_test.go index ea1cf7ac..038f45fd 100644 --- a/proxy/trojan/protocol_test.go +++ b/proxy/trojan/protocol_test.go @@ -71,21 +71,22 @@ func TestUDPRequest(t *testing.T) { common.Must(connReader.ParseHeader()) packetReader := &PacketReader{Reader: connReader} - p, err := packetReader.ReadMultiBufferWithMetadata() + mb, err := packetReader.ReadMultiBuffer() common.Must(err) - if p.Buffer.IsEmpty() { + if mb.IsEmpty() { t.Error("no request data") } - if r := cmp.Diff(p.Target, destination); r != "" { + mb2, b := buf.SplitFirst(mb) + defer buf.ReleaseMulti(mb2) + + dest := *b.UDP + if r := cmp.Diff(dest, destination); r != "" { t.Error("destination: ", r) } - mb, decoded := buf.SplitFirst(p.Buffer) - buf.ReleaseMulti(mb) - - if r := cmp.Diff(decoded.Bytes(), payload); r != "" { + if r := cmp.Diff(b.Bytes(), payload); r != "" { t.Error("data: ", r) } } diff --git a/proxy/trojan/server.go b/proxy/trojan/server.go index c73d4e2f..57e232b1 100644 --- a/proxy/trojan/server.go +++ b/proxy/trojan/server.go @@ -250,7 +250,9 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error { udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { - common.Must(clientWriter.WriteMultiBufferWithMetadata(buf.MultiBuffer{packet.Payload}, packet.Source)) + udpPayload := packet.Payload + udpPayload.UDP = &packet.Source + common.Must(clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload})) }) inbound := session.InboundFromContext(ctx) @@ -263,7 +265,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade case <-ctx.Done(): return nil default: - p, err := clientReader.ReadMultiBufferWithMetadata() + mb, err := clientReader.ReadMultiBuffer() if err != nil { if errors.Cause(err) != io.EOF { return newError("unexpected EOF").Base(err) @@ -271,21 +273,24 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade return nil } + mb2, b := buf.SplitFirst(mb) + destination := *b.UDP ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ From: inbound.Source, - To: p.Target, + To: destination, Status: log.AccessAccepted, Reason: "", Email: user.Email, }) - newError("tunnelling request to ", p.Target).WriteToLog(session.ExportIDToError(ctx)) + newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(ctx)) if !buf.Cone || dest == nil { - dest = &p.Target + dest = &destination } - for _, b := range p.Buffer { - udpServer.Dispatch(ctx, *dest, b) + udpServer.Dispatch(ctx, *dest, b) // first packet + for _, payload := range mb2 { + udpServer.Dispatch(ctx, *dest, payload) } } }