diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index ab44a1d5..9a91480f 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -241,7 +241,9 @@ func (h *Handler) DestIpAddress() net.IP { // Dial implements internet.Dialer. func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connection, error) { if h.senderSettings != nil { + if h.senderSettings.ProxySettings.HasTag() { + tag := h.senderSettings.ProxySettings.Tag handler := h.outboundManager.GetHandler(tag) if handler != nil { @@ -270,22 +272,40 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti } if h.senderSettings.Via != nil { + outbounds := session.OutboundsFromContext(ctx) ob := outbounds[len(outbounds)-1] - if h.senderSettings.ViaCidr == "" { - if h.senderSettings.Via.AsAddress().Family().IsDomain() && h.senderSettings.Via.AsAddress().Domain() == "origin" { - if inbound := session.InboundFromContext(ctx); inbound != nil { - origin, _, err := net.SplitHostPort(inbound.Conn.LocalAddr().String()) - if err == nil { - ob.Gateway = net.ParseAddress(origin) - } - } - } else { - ob.Gateway = h.senderSettings.Via.AsAddress() - } - } else { //Get a random address. - ob.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr) + addr := h.senderSettings.Via.AsAddress() + var domain string + if addr.Family().IsDomain() { + domain = addr.Domain() } + switch { + case h.senderSettings.ViaCidr != "": + ob.Gateway = ParseRandomIP(addr, h.senderSettings.ViaCidr) + + case domain == "origin": + + if inbound := session.InboundFromContext(ctx); inbound != nil { + origin, _, err := net.SplitHostPort(inbound.Conn.LocalAddr().String()) + if err == nil { + ob.Gateway = net.ParseAddress(origin) + } + + } + case domain == "srcip": + if inbound := session.InboundFromContext(ctx); inbound != nil { + srcip, _, err := net.SplitHostPort(inbound.Conn.RemoteAddr().String()) + if err == nil { + ob.Gateway = net.ParseAddress(srcip) + } + } + //case addr.Family().IsDomain(): + default: + ob.Gateway = addr + + } + } } @@ -329,20 +349,21 @@ func (h *Handler) Close() error { return nil } -func ParseRandomIPv6(address net.Address, prefix string) net.Address { - _, network, _ := gonet.ParseCIDR(address.IP().String() + "/" + prefix) +func ParseRandomIP(addr net.Address, prefix string) net.Address { - maskSize, totalBits := network.Mask.Size() - subnetSize := big.NewInt(1).Lsh(big.NewInt(1), uint(totalBits-maskSize)) + _, ipnet, _ := gonet.ParseCIDR(addr.IP().String() + "/" + prefix) - // random - randomBigInt, _ := rand.Int(rand.Reader, subnetSize) + ones, bits := ipnet.Mask.Size() + subnetSize := new(big.Int).Lsh(big.NewInt(1), uint(bits-ones)) - startIPBigInt := big.NewInt(0).SetBytes(network.IP.To16()) - randomIPBigInt := big.NewInt(0).Add(startIPBigInt, randomBigInt) + rnd, _ := rand.Int(rand.Reader, subnetSize) - randomIPBytes := randomIPBigInt.Bytes() - randomIPBytes = append(make([]byte, 16-len(randomIPBytes)), randomIPBytes...) + startInt := new(big.Int).SetBytes(ipnet.IP) + rndInt := new(big.Int).Add(startInt, rnd) - return net.ParseAddress(gonet.IP(randomIPBytes).String()) + rndBytes := rndInt.Bytes() + padded := make([]byte, len(ipnet.IP)) + copy(padded[len(padded)-len(rndBytes):], rndBytes) + + return net.ParseAddress(gonet.IP(padded).String()) }