mirror of
https://github.com/XTLS/Xray-core.git
synced 2025-04-03 04:06:37 +00:00
WireGuard: Improve config error handling; Prevent panic in case of errors during server initialization (#4566)
https://github.com/XTLS/Xray-core/pull/4566#issuecomment-2764779273
This commit is contained in:
parent
52a2c63682
commit
17207fc5e4
@ -67,7 +67,7 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
|
|||||||
var err error
|
var err error
|
||||||
config.SecretKey, err = ParseWireGuardKey(c.SecretKey)
|
config.SecretKey, err = ParseWireGuardKey(c.SecretKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.New("invalid WireGuard secret key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.Address == nil {
|
if c.Address == nil {
|
||||||
@ -126,6 +126,10 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
|
|||||||
func ParseWireGuardKey(str string) (string, error) {
|
func ParseWireGuardKey(str string) (string, error) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
if str == "" {
|
||||||
|
return "", errors.New("key must not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
if len(str)%2 == 0 {
|
if len(str)%2 == 0 {
|
||||||
_, err = hex.DecodeString(str)
|
_, err = hex.DecodeString(str)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -241,14 +241,14 @@ func (c *InboundDetourConfig) Build() (*core.InboundHandlerConfig, error) {
|
|||||||
}
|
}
|
||||||
rawConfig, err := inboundConfigLoader.LoadWithID(settings, c.Protocol)
|
rawConfig, err := inboundConfigLoader.LoadWithID(settings, c.Protocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("failed to load inbound detour config.").Base(err)
|
return nil, errors.New("failed to load inbound detour config for protocol ", c.Protocol).Base(err)
|
||||||
}
|
}
|
||||||
if dokodemoConfig, ok := rawConfig.(*DokodemoConfig); ok {
|
if dokodemoConfig, ok := rawConfig.(*DokodemoConfig); ok {
|
||||||
receiverSettings.ReceiveOriginalDestination = dokodemoConfig.Redirect
|
receiverSettings.ReceiveOriginalDestination = dokodemoConfig.Redirect
|
||||||
}
|
}
|
||||||
ts, err := rawConfig.(Buildable).Build()
|
ts, err := rawConfig.(Buildable).Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.New("failed to build inbound handler for protocol ", c.Protocol).Base(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &core.InboundHandlerConfig{
|
return &core.InboundHandlerConfig{
|
||||||
@ -303,7 +303,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
|
|||||||
if c.StreamSetting != nil {
|
if c.StreamSetting != nil {
|
||||||
ss, err := c.StreamSetting.Build()
|
ss, err := c.StreamSetting.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.New("failed to build stream settings for outbound detour").Base(err)
|
||||||
}
|
}
|
||||||
senderSettings.StreamSettings = ss
|
senderSettings.StreamSettings = ss
|
||||||
}
|
}
|
||||||
@ -311,7 +311,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
|
|||||||
if c.ProxySettings != nil {
|
if c.ProxySettings != nil {
|
||||||
ps, err := c.ProxySettings.Build()
|
ps, err := c.ProxySettings.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("invalid outbound detour proxy settings.").Base(err)
|
return nil, errors.New("invalid outbound detour proxy settings").Base(err)
|
||||||
}
|
}
|
||||||
if ps.TransportLayerProxy {
|
if ps.TransportLayerProxy {
|
||||||
if senderSettings.StreamSettings != nil {
|
if senderSettings.StreamSettings != nil {
|
||||||
@ -331,7 +331,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
|
|||||||
if c.MuxSettings != nil {
|
if c.MuxSettings != nil {
|
||||||
ms, err := c.MuxSettings.Build()
|
ms, err := c.MuxSettings.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("failed to build Mux config.").Base(err)
|
return nil, errors.New("failed to build Mux config").Base(err)
|
||||||
}
|
}
|
||||||
senderSettings.MultiplexSettings = ms
|
senderSettings.MultiplexSettings = ms
|
||||||
}
|
}
|
||||||
@ -342,11 +342,11 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
|
|||||||
}
|
}
|
||||||
rawConfig, err := outboundConfigLoader.LoadWithID(settings, c.Protocol)
|
rawConfig, err := outboundConfigLoader.LoadWithID(settings, c.Protocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("failed to parse to outbound detour config.").Base(err)
|
return nil, errors.New("failed to load outbound detour config for protocol ", c.Protocol).Base(err)
|
||||||
}
|
}
|
||||||
ts, err := rawConfig.(Buildable).Build()
|
ts, err := rawConfig.(Buildable).Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.New("failed to build outbound handler for protocol ", c.Protocol).Base(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &core.OutboundHandlerConfig{
|
return &core.OutboundHandlerConfig{
|
||||||
@ -490,7 +490,7 @@ func (c *Config) Override(o *Config, fn string) {
|
|||||||
// Build implements Buildable.
|
// Build implements Buildable.
|
||||||
func (c *Config) Build() (*core.Config, error) {
|
func (c *Config) Build() (*core.Config, error) {
|
||||||
if err := PostProcessConfigureFile(c); err != nil {
|
if err := PostProcessConfigureFile(c); err != nil {
|
||||||
return nil, err
|
return nil, errors.New("failed to post-process configuration file").Base(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config := &core.Config{
|
config := &core.Config{
|
||||||
@ -504,21 +504,21 @@ func (c *Config) Build() (*core.Config, error) {
|
|||||||
if c.API != nil {
|
if c.API != nil {
|
||||||
apiConf, err := c.API.Build()
|
apiConf, err := c.API.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.New("failed to build API configuration").Base(err)
|
||||||
}
|
}
|
||||||
config.App = append(config.App, serial.ToTypedMessage(apiConf))
|
config.App = append(config.App, serial.ToTypedMessage(apiConf))
|
||||||
}
|
}
|
||||||
if c.Metrics != nil {
|
if c.Metrics != nil {
|
||||||
metricsConf, err := c.Metrics.Build()
|
metricsConf, err := c.Metrics.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.New("failed to build metrics configuration").Base(err)
|
||||||
}
|
}
|
||||||
config.App = append(config.App, serial.ToTypedMessage(metricsConf))
|
config.App = append(config.App, serial.ToTypedMessage(metricsConf))
|
||||||
}
|
}
|
||||||
if c.Stats != nil {
|
if c.Stats != nil {
|
||||||
statsConf, err := c.Stats.Build()
|
statsConf, err := c.Stats.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.New("failed to build stats configuration").Base(err)
|
||||||
}
|
}
|
||||||
config.App = append(config.App, serial.ToTypedMessage(statsConf))
|
config.App = append(config.App, serial.ToTypedMessage(statsConf))
|
||||||
}
|
}
|
||||||
@ -536,7 +536,7 @@ func (c *Config) Build() (*core.Config, error) {
|
|||||||
if c.RouterConfig != nil {
|
if c.RouterConfig != nil {
|
||||||
routerConfig, err := c.RouterConfig.Build()
|
routerConfig, err := c.RouterConfig.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.New("failed to build routing configuration").Base(err)
|
||||||
}
|
}
|
||||||
config.App = append(config.App, serial.ToTypedMessage(routerConfig))
|
config.App = append(config.App, serial.ToTypedMessage(routerConfig))
|
||||||
}
|
}
|
||||||
@ -544,7 +544,7 @@ func (c *Config) Build() (*core.Config, error) {
|
|||||||
if c.DNSConfig != nil {
|
if c.DNSConfig != nil {
|
||||||
dnsApp, err := c.DNSConfig.Build()
|
dnsApp, err := c.DNSConfig.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("failed to parse DNS config").Base(err)
|
return nil, errors.New("failed to build DNS configuration").Base(err)
|
||||||
}
|
}
|
||||||
config.App = append(config.App, serial.ToTypedMessage(dnsApp))
|
config.App = append(config.App, serial.ToTypedMessage(dnsApp))
|
||||||
}
|
}
|
||||||
@ -552,7 +552,7 @@ func (c *Config) Build() (*core.Config, error) {
|
|||||||
if c.Policy != nil {
|
if c.Policy != nil {
|
||||||
pc, err := c.Policy.Build()
|
pc, err := c.Policy.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.New("failed to build policy configuration").Base(err)
|
||||||
}
|
}
|
||||||
config.App = append(config.App, serial.ToTypedMessage(pc))
|
config.App = append(config.App, serial.ToTypedMessage(pc))
|
||||||
}
|
}
|
||||||
@ -560,7 +560,7 @@ func (c *Config) Build() (*core.Config, error) {
|
|||||||
if c.Reverse != nil {
|
if c.Reverse != nil {
|
||||||
r, err := c.Reverse.Build()
|
r, err := c.Reverse.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.New("failed to build reverse configuration").Base(err)
|
||||||
}
|
}
|
||||||
config.App = append(config.App, serial.ToTypedMessage(r))
|
config.App = append(config.App, serial.ToTypedMessage(r))
|
||||||
}
|
}
|
||||||
@ -568,7 +568,7 @@ func (c *Config) Build() (*core.Config, error) {
|
|||||||
if c.FakeDNS != nil {
|
if c.FakeDNS != nil {
|
||||||
r, err := c.FakeDNS.Build()
|
r, err := c.FakeDNS.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.New("failed to build fake DNS configuration").Base(err)
|
||||||
}
|
}
|
||||||
config.App = append([]*serial.TypedMessage{serial.ToTypedMessage(r)}, config.App...)
|
config.App = append([]*serial.TypedMessage{serial.ToTypedMessage(r)}, config.App...)
|
||||||
}
|
}
|
||||||
@ -576,7 +576,7 @@ func (c *Config) Build() (*core.Config, error) {
|
|||||||
if c.Observatory != nil {
|
if c.Observatory != nil {
|
||||||
r, err := c.Observatory.Build()
|
r, err := c.Observatory.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.New("failed to build observatory configuration").Base(err)
|
||||||
}
|
}
|
||||||
config.App = append(config.App, serial.ToTypedMessage(r))
|
config.App = append(config.App, serial.ToTypedMessage(r))
|
||||||
}
|
}
|
||||||
@ -584,7 +584,7 @@ func (c *Config) Build() (*core.Config, error) {
|
|||||||
if c.BurstObservatory != nil {
|
if c.BurstObservatory != nil {
|
||||||
r, err := c.BurstObservatory.Build()
|
r, err := c.BurstObservatory.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.New("failed to build burst observatory configuration").Base(err)
|
||||||
}
|
}
|
||||||
config.App = append(config.App, serial.ToTypedMessage(r))
|
config.App = append(config.App, serial.ToTypedMessage(r))
|
||||||
}
|
}
|
||||||
@ -602,7 +602,7 @@ func (c *Config) Build() (*core.Config, error) {
|
|||||||
for _, rawInboundConfig := range inbounds {
|
for _, rawInboundConfig := range inbounds {
|
||||||
ic, err := rawInboundConfig.Build()
|
ic, err := rawInboundConfig.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.New("failed to build inbound config with tag ", rawInboundConfig.Tag).Base(err)
|
||||||
}
|
}
|
||||||
config.Inbound = append(config.Inbound, ic)
|
config.Inbound = append(config.Inbound, ic)
|
||||||
}
|
}
|
||||||
@ -616,7 +616,7 @@ func (c *Config) Build() (*core.Config, error) {
|
|||||||
for _, rawOutboundConfig := range outbounds {
|
for _, rawOutboundConfig := range outbounds {
|
||||||
oc, err := rawOutboundConfig.Build()
|
oc, err := rawOutboundConfig.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.New("failed to build outbound config with tag ", rawOutboundConfig.Tag).Base(err)
|
||||||
}
|
}
|
||||||
config.Outbound = append(config.Outbound, oc)
|
config.Outbound = append(config.Outbound, oc)
|
||||||
}
|
}
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
@ -33,6 +34,7 @@ type netTun struct {
|
|||||||
incomingPacket chan *buffer.View
|
incomingPacket chan *buffer.View
|
||||||
mtu int
|
mtu int
|
||||||
hasV4, hasV6 bool
|
hasV4, hasV6 bool
|
||||||
|
closeOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
type Net netTun
|
type Net netTun
|
||||||
@ -174,18 +176,15 @@ func (tun *netTun) Flush() error {
|
|||||||
|
|
||||||
// Close implements tun.Device
|
// Close implements tun.Device
|
||||||
func (tun *netTun) Close() error {
|
func (tun *netTun) Close() error {
|
||||||
tun.stack.RemoveNIC(1)
|
tun.closeOnce.Do(func() {
|
||||||
|
tun.stack.RemoveNIC(1)
|
||||||
|
|
||||||
if tun.events != nil {
|
|
||||||
close(tun.events)
|
close(tun.events)
|
||||||
}
|
|
||||||
|
|
||||||
tun.ep.Close()
|
tun.ep.Close()
|
||||||
|
|
||||||
if tun.incomingPacket != nil {
|
|
||||||
close(tun.incomingPacket)
|
close(tun.incomingPacket)
|
||||||
}
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
52
proxy/wireguard/server_test.go
Normal file
52
proxy/wireguard/server_test.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
package wireguard_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"runtime/debug"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/xtls/xray-core/core"
|
||||||
|
"github.com/xtls/xray-core/proxy/wireguard"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestWireGuardServerInitializationError verifies that an error during TUN initialization
|
||||||
|
// (triggered by an empty SecretKey) in the WireGuard server does not cause a panic and returns an error instead.
|
||||||
|
func TestWireGuardServerInitializationError(t *testing.T) {
|
||||||
|
// Create a minimal core instance with default features
|
||||||
|
config := &core.Config{}
|
||||||
|
instance, err := core.New(config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create core instance: %v", err)
|
||||||
|
}
|
||||||
|
// Set the Xray instance in the context
|
||||||
|
ctx := context.WithValue(context.Background(), core.XrayKey(1), instance)
|
||||||
|
|
||||||
|
// Define the server configuration with an empty SecretKey to trigger error
|
||||||
|
conf := &wireguard.DeviceConfig{
|
||||||
|
IsClient: false,
|
||||||
|
Endpoint: []string{"10.0.0.1/32"},
|
||||||
|
Mtu: 1420,
|
||||||
|
SecretKey: "", // Empty SecretKey to trigger error
|
||||||
|
Peers: []*wireguard.PeerConfig{
|
||||||
|
{
|
||||||
|
PublicKey: "some_public_key",
|
||||||
|
AllowedIps: []string{"10.0.0.2/32"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use defer to catch any panic and fail the test explicitly
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Errorf("TUN initialization panicked: %v", r)
|
||||||
|
debug.PrintStack()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Attempt to initialize the WireGuard server
|
||||||
|
_, err = wireguard.NewServer(ctx, conf)
|
||||||
|
|
||||||
|
// Check that an error is returned
|
||||||
|
assert.ErrorContains(t, err, "failed to set private_key: hex string does not fit the slice")
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user