mirror of
https://github.com/XTLS/Xray-core.git
synced 2025-04-01 19:26:39 +00:00
WireGuard: Improve config error handling; Prevent panic in case of errors during server initialization (#4566)
https://github.com/XTLS/Xray-core/pull/4566#issuecomment-2764779273
This commit is contained in:
parent
52a2c63682
commit
17207fc5e4
@ -67,7 +67,7 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
|
||||
var err error
|
||||
config.SecretKey, err = ParseWireGuardKey(c.SecretKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("invalid WireGuard secret key: %w", err)
|
||||
}
|
||||
|
||||
if c.Address == nil {
|
||||
@ -126,6 +126,10 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
|
||||
func ParseWireGuardKey(str string) (string, error) {
|
||||
var err error
|
||||
|
||||
if str == "" {
|
||||
return "", errors.New("key must not be empty")
|
||||
}
|
||||
|
||||
if len(str)%2 == 0 {
|
||||
_, err = hex.DecodeString(str)
|
||||
if err == nil {
|
||||
|
@ -241,14 +241,14 @@ func (c *InboundDetourConfig) Build() (*core.InboundHandlerConfig, error) {
|
||||
}
|
||||
rawConfig, err := inboundConfigLoader.LoadWithID(settings, c.Protocol)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to load inbound detour config.").Base(err)
|
||||
return nil, errors.New("failed to load inbound detour config for protocol ", c.Protocol).Base(err)
|
||||
}
|
||||
if dokodemoConfig, ok := rawConfig.(*DokodemoConfig); ok {
|
||||
receiverSettings.ReceiveOriginalDestination = dokodemoConfig.Redirect
|
||||
}
|
||||
ts, err := rawConfig.(Buildable).Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("failed to build inbound handler for protocol ", c.Protocol).Base(err)
|
||||
}
|
||||
|
||||
return &core.InboundHandlerConfig{
|
||||
@ -303,7 +303,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
|
||||
if c.StreamSetting != nil {
|
||||
ss, err := c.StreamSetting.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("failed to build stream settings for outbound detour").Base(err)
|
||||
}
|
||||
senderSettings.StreamSettings = ss
|
||||
}
|
||||
@ -311,7 +311,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
|
||||
if c.ProxySettings != nil {
|
||||
ps, err := c.ProxySettings.Build()
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid outbound detour proxy settings.").Base(err)
|
||||
return nil, errors.New("invalid outbound detour proxy settings").Base(err)
|
||||
}
|
||||
if ps.TransportLayerProxy {
|
||||
if senderSettings.StreamSettings != nil {
|
||||
@ -331,7 +331,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
|
||||
if c.MuxSettings != nil {
|
||||
ms, err := c.MuxSettings.Build()
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to build Mux config.").Base(err)
|
||||
return nil, errors.New("failed to build Mux config").Base(err)
|
||||
}
|
||||
senderSettings.MultiplexSettings = ms
|
||||
}
|
||||
@ -342,11 +342,11 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
|
||||
}
|
||||
rawConfig, err := outboundConfigLoader.LoadWithID(settings, c.Protocol)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to parse to outbound detour config.").Base(err)
|
||||
return nil, errors.New("failed to load outbound detour config for protocol ", c.Protocol).Base(err)
|
||||
}
|
||||
ts, err := rawConfig.(Buildable).Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("failed to build outbound handler for protocol ", c.Protocol).Base(err)
|
||||
}
|
||||
|
||||
return &core.OutboundHandlerConfig{
|
||||
@ -490,7 +490,7 @@ func (c *Config) Override(o *Config, fn string) {
|
||||
// Build implements Buildable.
|
||||
func (c *Config) Build() (*core.Config, error) {
|
||||
if err := PostProcessConfigureFile(c); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("failed to post-process configuration file").Base(err)
|
||||
}
|
||||
|
||||
config := &core.Config{
|
||||
@ -504,21 +504,21 @@ func (c *Config) Build() (*core.Config, error) {
|
||||
if c.API != nil {
|
||||
apiConf, err := c.API.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("failed to build API configuration").Base(err)
|
||||
}
|
||||
config.App = append(config.App, serial.ToTypedMessage(apiConf))
|
||||
}
|
||||
if c.Metrics != nil {
|
||||
metricsConf, err := c.Metrics.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("failed to build metrics configuration").Base(err)
|
||||
}
|
||||
config.App = append(config.App, serial.ToTypedMessage(metricsConf))
|
||||
}
|
||||
if c.Stats != nil {
|
||||
statsConf, err := c.Stats.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("failed to build stats configuration").Base(err)
|
||||
}
|
||||
config.App = append(config.App, serial.ToTypedMessage(statsConf))
|
||||
}
|
||||
@ -536,7 +536,7 @@ func (c *Config) Build() (*core.Config, error) {
|
||||
if c.RouterConfig != nil {
|
||||
routerConfig, err := c.RouterConfig.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("failed to build routing configuration").Base(err)
|
||||
}
|
||||
config.App = append(config.App, serial.ToTypedMessage(routerConfig))
|
||||
}
|
||||
@ -544,7 +544,7 @@ func (c *Config) Build() (*core.Config, error) {
|
||||
if c.DNSConfig != nil {
|
||||
dnsApp, err := c.DNSConfig.Build()
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to parse DNS config").Base(err)
|
||||
return nil, errors.New("failed to build DNS configuration").Base(err)
|
||||
}
|
||||
config.App = append(config.App, serial.ToTypedMessage(dnsApp))
|
||||
}
|
||||
@ -552,7 +552,7 @@ func (c *Config) Build() (*core.Config, error) {
|
||||
if c.Policy != nil {
|
||||
pc, err := c.Policy.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("failed to build policy configuration").Base(err)
|
||||
}
|
||||
config.App = append(config.App, serial.ToTypedMessage(pc))
|
||||
}
|
||||
@ -560,7 +560,7 @@ func (c *Config) Build() (*core.Config, error) {
|
||||
if c.Reverse != nil {
|
||||
r, err := c.Reverse.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("failed to build reverse configuration").Base(err)
|
||||
}
|
||||
config.App = append(config.App, serial.ToTypedMessage(r))
|
||||
}
|
||||
@ -568,7 +568,7 @@ func (c *Config) Build() (*core.Config, error) {
|
||||
if c.FakeDNS != nil {
|
||||
r, err := c.FakeDNS.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("failed to build fake DNS configuration").Base(err)
|
||||
}
|
||||
config.App = append([]*serial.TypedMessage{serial.ToTypedMessage(r)}, config.App...)
|
||||
}
|
||||
@ -576,7 +576,7 @@ func (c *Config) Build() (*core.Config, error) {
|
||||
if c.Observatory != nil {
|
||||
r, err := c.Observatory.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("failed to build observatory configuration").Base(err)
|
||||
}
|
||||
config.App = append(config.App, serial.ToTypedMessage(r))
|
||||
}
|
||||
@ -584,7 +584,7 @@ func (c *Config) Build() (*core.Config, error) {
|
||||
if c.BurstObservatory != nil {
|
||||
r, err := c.BurstObservatory.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("failed to build burst observatory configuration").Base(err)
|
||||
}
|
||||
config.App = append(config.App, serial.ToTypedMessage(r))
|
||||
}
|
||||
@ -602,7 +602,7 @@ func (c *Config) Build() (*core.Config, error) {
|
||||
for _, rawInboundConfig := range inbounds {
|
||||
ic, err := rawInboundConfig.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("failed to build inbound config with tag ", rawInboundConfig.Tag).Base(err)
|
||||
}
|
||||
config.Inbound = append(config.Inbound, ic)
|
||||
}
|
||||
@ -616,7 +616,7 @@ func (c *Config) Build() (*core.Config, error) {
|
||||
for _, rawOutboundConfig := range outbounds {
|
||||
oc, err := rawOutboundConfig.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("failed to build outbound config with tag ", rawOutboundConfig.Tag).Base(err)
|
||||
}
|
||||
config.Outbound = append(config.Outbound, oc)
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
@ -33,6 +34,7 @@ type netTun struct {
|
||||
incomingPacket chan *buffer.View
|
||||
mtu int
|
||||
hasV4, hasV6 bool
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
type Net netTun
|
||||
@ -174,18 +176,15 @@ func (tun *netTun) Flush() error {
|
||||
|
||||
// Close implements tun.Device
|
||||
func (tun *netTun) Close() error {
|
||||
tun.stack.RemoveNIC(1)
|
||||
tun.closeOnce.Do(func() {
|
||||
tun.stack.RemoveNIC(1)
|
||||
|
||||
if tun.events != nil {
|
||||
close(tun.events)
|
||||
}
|
||||
|
||||
tun.ep.Close()
|
||||
tun.ep.Close()
|
||||
|
||||
if tun.incomingPacket != nil {
|
||||
close(tun.incomingPacket)
|
||||
}
|
||||
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
|
52
proxy/wireguard/server_test.go
Normal file
52
proxy/wireguard/server_test.go
Normal file
@ -0,0 +1,52 @@
|
||||
package wireguard_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"runtime/debug"
|
||||
"testing"
|
||||
|
||||
"github.com/xtls/xray-core/core"
|
||||
"github.com/xtls/xray-core/proxy/wireguard"
|
||||
)
|
||||
|
||||
// TestWireGuardServerInitializationError verifies that an error during TUN initialization
|
||||
// (triggered by an empty SecretKey) in the WireGuard server does not cause a panic and returns an error instead.
|
||||
func TestWireGuardServerInitializationError(t *testing.T) {
|
||||
// Create a minimal core instance with default features
|
||||
config := &core.Config{}
|
||||
instance, err := core.New(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create core instance: %v", err)
|
||||
}
|
||||
// Set the Xray instance in the context
|
||||
ctx := context.WithValue(context.Background(), core.XrayKey(1), instance)
|
||||
|
||||
// Define the server configuration with an empty SecretKey to trigger error
|
||||
conf := &wireguard.DeviceConfig{
|
||||
IsClient: false,
|
||||
Endpoint: []string{"10.0.0.1/32"},
|
||||
Mtu: 1420,
|
||||
SecretKey: "", // Empty SecretKey to trigger error
|
||||
Peers: []*wireguard.PeerConfig{
|
||||
{
|
||||
PublicKey: "some_public_key",
|
||||
AllowedIps: []string{"10.0.0.2/32"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Use defer to catch any panic and fail the test explicitly
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("TUN initialization panicked: %v", r)
|
||||
debug.PrintStack()
|
||||
}
|
||||
}()
|
||||
|
||||
// Attempt to initialize the WireGuard server
|
||||
_, err = wireguard.NewServer(ctx, conf)
|
||||
|
||||
// Check that an error is returned
|
||||
assert.ErrorContains(t, err, "failed to set private_key: hex string does not fit the slice")
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user