From fb0e517158801ef49def1bdb706708b8a33f6f6e Mon Sep 17 00:00:00 2001 From: RPRX <63339210+rprx@users.noreply.github.com> Date: Fri, 8 Jan 2021 06:00:51 +0000 Subject: [PATCH] Adjust Trojan & Socks handleUDPPayload --- proxy/socks/server.go | 5 +++-- proxy/trojan/server.go | 29 +++++++++++++++++++---------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/proxy/socks/server.go b/proxy/socks/server.go index f9013da6..dd142a72 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -218,7 +218,8 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, conn.Write(udpMessage.Bytes()) }) - if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() { + inbound := session.InboundFromContext(ctx) + if inbound != nil && inbound.Source.IsValid() { newError("client UDP connection from ", inbound.Source).WriteToLog(session.ExportIDToError(ctx)) } @@ -249,7 +250,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, currentPacketCtx := ctx newError("send packet to ", destination, " with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx)) - if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() { + if inbound != nil && inbound.Source.IsValid() { currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ From: inbound.Source, To: destination, diff --git a/proxy/trojan/server.go b/proxy/trojan/server.go index 57e232b1..da6361ff 100644 --- a/proxy/trojan/server.go +++ b/proxy/trojan/server.go @@ -251,7 +251,9 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error { udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { udpPayload := packet.Payload - udpPayload.UDP = &packet.Source + if udpPayload.UDP == nil { + udpPayload.UDP = &packet.Source + } common.Must(clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload})) }) @@ -274,23 +276,30 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade } mb2, b := buf.SplitFirst(mb) + if b == nil { + continue + } destination := *b.UDP - ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ - From: inbound.Source, - To: destination, - Status: log.AccessAccepted, - Reason: "", - Email: user.Email, - }) + + currentPacketCtx := ctx + if inbound.Source.IsValid() { + currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ + From: inbound.Source, + To: destination, + Status: log.AccessAccepted, + Reason: "", + Email: user.Email, + }) + } newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(ctx)) if !buf.Cone || dest == nil { dest = &destination } - udpServer.Dispatch(ctx, *dest, b) // first packet + udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet for _, payload := range mb2 { - udpServer.Dispatch(ctx, *dest, payload) + udpServer.Dispatch(currentPacketCtx, *dest, payload) } } }