From 9d89944967329ade25e497b77b49b3c74d3da508 Mon Sep 17 00:00:00 2001 From: patterniha <71074308+patterniha@users.noreply.github.com> Date: Thu, 12 Jun 2025 18:22:09 +0330 Subject: [PATCH] Freedom: fix UDP reply address --- proxy/freedom/freedom.go | 83 +++++++++++++++++++++++++++------------- 1 file changed, 57 insertions(+), 26 deletions(-) diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index 0399c3d1..f31ab06e 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "io" + "sync" "time" "github.com/pires/go-proxyproto" @@ -238,7 +239,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte if destination.Network == net.Network_TCP { reader = buf.NewReader(conn) } else { - reader = NewPacketReader(conn, UDPOverride) + reader = NewPacketReader(ctx, conn, UDPOverride) } if err := buf.Copy(reader, output, buf.UpdateActivity(timer)); err != nil { return errors.New("failed to process response").Base(err) @@ -273,7 +274,7 @@ func isTLSConn(conn stat.Connection) bool { return false } -func NewPacketReader(conn net.Conn, UDPOverride net.Destination) buf.Reader { +func NewPacketReader(ctx context.Context, conn net.Conn, UDPOverride net.Destination) buf.Reader { iConn := conn statConn, ok := iConn.(*stat.CounterConnection) if ok { @@ -283,10 +284,17 @@ func NewPacketReader(conn net.Conn, UDPOverride net.Destination) buf.Reader { if statConn != nil { counter = statConn.ReadCounter } - if c, ok := iConn.(*internet.PacketConnWrapper); ok && UDPOverride.Address == nil && UDPOverride.Port == 0 { + if c, ok := iConn.(*internet.PacketConnWrapper); ok { + isAddrChanged := false + outbounds := session.OutboundsFromContext(ctx) + targetAddr := outbounds[len(outbounds)-1].Target.Address + if UDPOverride.Address != nil || UDPOverride.Port != 0 || targetAddr.Family().IsDomain() { + isAddrChanged = true + } return &PacketReader{ PacketConnWrapper: c, Counter: counter, + IsAddrChanged: isAddrChanged, } } return &buf.PacketReader{Reader: conn} @@ -295,6 +303,7 @@ func NewPacketReader(conn net.Conn, UDPOverride net.Destination) buf.Reader { type PacketReader struct { *internet.PacketConnWrapper stats.Counter + IsAddrChanged bool } func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { @@ -306,10 +315,12 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { return nil, err } b.Resize(0, int32(n)) - b.UDP = &net.Destination{ - Address: net.IPAddress(d.(*net.UDPAddr).IP), - Port: net.Port(d.(*net.UDPAddr).Port), - Network: net.Network_UDP, + if !r.IsAddrChanged { + b.UDP = &net.Destination{ + Address: net.IPAddress(d.(*net.UDPAddr).IP), + Port: net.Port(d.(*net.UDPAddr).Port), + Network: net.Network_UDP, + } } if r.Counter != nil { r.Counter.Add(int64(n)) @@ -335,7 +346,7 @@ func NewPacketWriter(conn net.Conn, h *Handler, ctx context.Context, UDPOverride resolvedUDPAddr := make(map[string]net.Address) if targetAddr.Family().IsDomain() { RemoteAddress, _, _ := net.SplitHostPort(conn.RemoteAddr().String()) - resolvedUDPAddr[targetAddr.String()] = net.ParseAddress(RemoteAddress) + resolvedUDPAddr[targetAddr.Domain()] = net.ParseAddress(RemoteAddress) } return &PacketWriter{ PacketConnWrapper: c, @@ -362,6 +373,43 @@ type PacketWriter struct { // Resulting in these packets being sent to many different IPs randomly // So, cache and keep the resolve result resolvedUDPAddr map[string]net.Address + sync.Mutex +} + +func (w *PacketWriter) getDestAddr(dest *net.Destination) net.Addr { + if w.UDPOverride.Address != nil { + dest.Address = w.UDPOverride.Address + } + if w.UDPOverride.Port != 0 { + dest.Port = w.UDPOverride.Port + } + + if dest.Address.Family().IsDomain() { + w.Lock() + defer w.Unlock() + ip := w.resolvedUDPAddr[dest.Address.Domain()] + if ip != nil { + dest.Address = ip + return dest.RawNetAddr() + } + if w.Handler.config.hasStrategy() { + ip := w.Handler.resolveIP(w.Context, dest.Address.Domain(), nil) + if ip != nil { + w.resolvedUDPAddr[dest.Address.Domain()] = ip + dest.Address = ip + return dest.RawNetAddr() + } + if w.Handler.config.forceIP() { + return nil + } + } + destAddr, _ := net.ResolveUDPAddr("udp", dest.NetAddr()) + if destAddr != nil { + w.resolvedUDPAddr[dest.Address.Domain()] = net.IPAddress(destAddr.IP) + } + return destAddr + } + return dest.RawNetAddr() } func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { @@ -374,24 +422,7 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { var n int var err error if b.UDP != nil { - if w.UDPOverride.Address != nil { - b.UDP.Address = w.UDPOverride.Address - } - if w.UDPOverride.Port != 0 { - b.UDP.Port = w.UDPOverride.Port - } - if w.Handler.config.hasStrategy() && b.UDP.Address.Family().IsDomain() { - if ip := w.resolvedUDPAddr[b.UDP.Address.Domain()]; ip != nil { - b.UDP.Address = ip - } else { - ip := w.Handler.resolveIP(w.Context, b.UDP.Address.Domain(), nil) - if ip != nil { - b.UDP.Address = ip - w.resolvedUDPAddr[b.UDP.Address.Domain()] = ip - } - } - } - destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr()) + destAddr := w.getDestAddr(b.UDP) if destAddr == nil { b.Release() continue