diff --git a/infra/conf/wireguard.go b/infra/conf/wireguard.go index 9952101a..34ce7215 100644 --- a/infra/conf/wireguard.go +++ b/infra/conf/wireguard.go @@ -67,7 +67,7 @@ func (c *WireGuardConfig) Build() (proto.Message, error) { var err error config.SecretKey, err = ParseWireGuardKey(c.SecretKey) if err != nil { - return nil, err + return nil, errors.New("invalid WireGuard secret key: %w", err) } if c.Address == nil { @@ -126,6 +126,10 @@ func (c *WireGuardConfig) Build() (proto.Message, error) { func ParseWireGuardKey(str string) (string, error) { var err error + if str == "" { + return "", errors.New("key must not be empty") + } + if len(str)%2 == 0 { _, err = hex.DecodeString(str) if err == nil { diff --git a/infra/conf/xray.go b/infra/conf/xray.go index a9cc88bc..4b084b56 100644 --- a/infra/conf/xray.go +++ b/infra/conf/xray.go @@ -241,14 +241,14 @@ func (c *InboundDetourConfig) Build() (*core.InboundHandlerConfig, error) { } rawConfig, err := inboundConfigLoader.LoadWithID(settings, c.Protocol) if err != nil { - return nil, errors.New("failed to load inbound detour config.").Base(err) + return nil, errors.New("failed to load inbound detour config for protocol ", c.Protocol).Base(err) } if dokodemoConfig, ok := rawConfig.(*DokodemoConfig); ok { receiverSettings.ReceiveOriginalDestination = dokodemoConfig.Redirect } ts, err := rawConfig.(Buildable).Build() if err != nil { - return nil, err + return nil, errors.New("failed to build inbound handler for protocol ", c.Protocol).Base(err) } return &core.InboundHandlerConfig{ @@ -303,7 +303,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) { if c.StreamSetting != nil { ss, err := c.StreamSetting.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build stream settings for outbound detour").Base(err) } senderSettings.StreamSettings = ss } @@ -311,7 +311,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) { if c.ProxySettings != nil { ps, err := c.ProxySettings.Build() if err != nil { - return nil, errors.New("invalid outbound detour proxy settings.").Base(err) + return nil, errors.New("invalid outbound detour proxy settings").Base(err) } if ps.TransportLayerProxy { if senderSettings.StreamSettings != nil { @@ -331,7 +331,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) { if c.MuxSettings != nil { ms, err := c.MuxSettings.Build() if err != nil { - return nil, errors.New("failed to build Mux config.").Base(err) + return nil, errors.New("failed to build Mux config").Base(err) } senderSettings.MultiplexSettings = ms } @@ -342,11 +342,11 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) { } rawConfig, err := outboundConfigLoader.LoadWithID(settings, c.Protocol) if err != nil { - return nil, errors.New("failed to parse to outbound detour config.").Base(err) + return nil, errors.New("failed to load outbound detour config for protocol ", c.Protocol).Base(err) } ts, err := rawConfig.(Buildable).Build() if err != nil { - return nil, err + return nil, errors.New("failed to build outbound handler for protocol ", c.Protocol).Base(err) } return &core.OutboundHandlerConfig{ @@ -490,7 +490,7 @@ func (c *Config) Override(o *Config, fn string) { // Build implements Buildable. func (c *Config) Build() (*core.Config, error) { if err := PostProcessConfigureFile(c); err != nil { - return nil, err + return nil, errors.New("failed to post-process configuration file").Base(err) } config := &core.Config{ @@ -504,21 +504,21 @@ func (c *Config) Build() (*core.Config, error) { if c.API != nil { apiConf, err := c.API.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build API configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(apiConf)) } if c.Metrics != nil { metricsConf, err := c.Metrics.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build metrics configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(metricsConf)) } if c.Stats != nil { statsConf, err := c.Stats.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build stats configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(statsConf)) } @@ -536,7 +536,7 @@ func (c *Config) Build() (*core.Config, error) { if c.RouterConfig != nil { routerConfig, err := c.RouterConfig.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build routing configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(routerConfig)) } @@ -544,7 +544,7 @@ func (c *Config) Build() (*core.Config, error) { if c.DNSConfig != nil { dnsApp, err := c.DNSConfig.Build() if err != nil { - return nil, errors.New("failed to parse DNS config").Base(err) + return nil, errors.New("failed to build DNS configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(dnsApp)) } @@ -552,7 +552,7 @@ func (c *Config) Build() (*core.Config, error) { if c.Policy != nil { pc, err := c.Policy.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build policy configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(pc)) } @@ -560,7 +560,7 @@ func (c *Config) Build() (*core.Config, error) { if c.Reverse != nil { r, err := c.Reverse.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build reverse configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(r)) } @@ -568,7 +568,7 @@ func (c *Config) Build() (*core.Config, error) { if c.FakeDNS != nil { r, err := c.FakeDNS.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build fake DNS configuration").Base(err) } config.App = append([]*serial.TypedMessage{serial.ToTypedMessage(r)}, config.App...) } @@ -576,7 +576,7 @@ func (c *Config) Build() (*core.Config, error) { if c.Observatory != nil { r, err := c.Observatory.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build observatory configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(r)) } @@ -584,7 +584,7 @@ func (c *Config) Build() (*core.Config, error) { if c.BurstObservatory != nil { r, err := c.BurstObservatory.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build burst observatory configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(r)) } @@ -602,7 +602,7 @@ func (c *Config) Build() (*core.Config, error) { for _, rawInboundConfig := range inbounds { ic, err := rawInboundConfig.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build inbound config with tag ", rawInboundConfig.Tag).Base(err) } config.Inbound = append(config.Inbound, ic) } @@ -616,7 +616,7 @@ func (c *Config) Build() (*core.Config, error) { for _, rawOutboundConfig := range outbounds { oc, err := rawOutboundConfig.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build outbound config with tag ", rawOutboundConfig.Tag).Base(err) } config.Outbound = append(config.Outbound, oc) } diff --git a/proxy/wireguard/gvisortun/tun.go b/proxy/wireguard/gvisortun/tun.go index 65677c48..2f9aa33c 100644 --- a/proxy/wireguard/gvisortun/tun.go +++ b/proxy/wireguard/gvisortun/tun.go @@ -10,6 +10,7 @@ import ( "fmt" "net/netip" "os" + "sync" "syscall" "golang.zx2c4.com/wireguard/tun" @@ -33,6 +34,7 @@ type netTun struct { incomingPacket chan *buffer.View mtu int hasV4, hasV6 bool + closeOnce sync.Once } type Net netTun @@ -174,18 +176,15 @@ func (tun *netTun) Flush() error { // Close implements tun.Device func (tun *netTun) Close() error { - tun.stack.RemoveNIC(1) + tun.closeOnce.Do(func() { + tun.stack.RemoveNIC(1) - if tun.events != nil { close(tun.events) - } - tun.ep.Close() + tun.ep.Close() - if tun.incomingPacket != nil { close(tun.incomingPacket) - } - + }) return nil } diff --git a/proxy/wireguard/server_test.go b/proxy/wireguard/server_test.go new file mode 100644 index 00000000..057b508e --- /dev/null +++ b/proxy/wireguard/server_test.go @@ -0,0 +1,52 @@ +package wireguard_test + +import ( + "context" + "github.com/stretchr/testify/assert" + "runtime/debug" + "testing" + + "github.com/xtls/xray-core/core" + "github.com/xtls/xray-core/proxy/wireguard" +) + +// TestWireGuardServerInitializationError verifies that an error during TUN initialization +// (triggered by an empty SecretKey) in the WireGuard server does not cause a panic and returns an error instead. +func TestWireGuardServerInitializationError(t *testing.T) { + // Create a minimal core instance with default features + config := &core.Config{} + instance, err := core.New(config) + if err != nil { + t.Fatalf("Failed to create core instance: %v", err) + } + // Set the Xray instance in the context + ctx := context.WithValue(context.Background(), core.XrayKey(1), instance) + + // Define the server configuration with an empty SecretKey to trigger error + conf := &wireguard.DeviceConfig{ + IsClient: false, + Endpoint: []string{"10.0.0.1/32"}, + Mtu: 1420, + SecretKey: "", // Empty SecretKey to trigger error + Peers: []*wireguard.PeerConfig{ + { + PublicKey: "some_public_key", + AllowedIps: []string{"10.0.0.2/32"}, + }, + }, + } + + // Use defer to catch any panic and fail the test explicitly + defer func() { + if r := recover(); r != nil { + t.Errorf("TUN initialization panicked: %v", r) + debug.PrintStack() + } + }() + + // Attempt to initialize the WireGuard server + _, err = wireguard.NewServer(ctx, conf) + + // Check that an error is returned + assert.ErrorContains(t, err, "failed to set private_key: hex string does not fit the slice") +}