From 7b7084f825a7d74938061b223c79f21801acb7ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A7=8B=E3=81=AE=E3=81=8B=E3=81=88=E3=81=A7?= Date: Sun, 18 Apr 2021 13:21:17 +0800 Subject: [PATCH] Refactor: A faster DomainMatcher implementation (#348) Co-authored-by: DarthVader <61409963+darsvador@users.noreply.github.com> --- app/dispatcher/default.go | 2 +- app/dns/server.go | 1 - app/dns/server_test.go | 4 +- app/router/condition.go | 22 +- app/router/condition_test.go | 96 ++++++- app/router/config.go | 21 +- app/router/config.pb.go | 81 +++--- app/router/config.proto | 2 + common/protocol/tls/sniff.go | 2 +- common/strmatcher/ac_automaton_matcher.go | 243 +++++++++++++++++ common/strmatcher/benchmark_test.go | 13 + common/strmatcher/mph_matcher.go | 301 ++++++++++++++++++++++ common/strmatcher/strmatcher.go | 1 + common/strmatcher/strmatcher_test.go | 169 ++++++++++++ infra/conf/router.go | 6 + 15 files changed, 912 insertions(+), 52 deletions(-) create mode 100644 common/strmatcher/ac_automaton_matcher.go create mode 100644 common/strmatcher/mph_matcher.go diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index 50ed6134..e7679751 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -179,7 +179,7 @@ func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *tran func shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool { domain := result.Domain() for _, d := range request.ExcludeForDomain { - if domain == d { + if strings.ToLower(domain) == d { return false } } diff --git a/app/dns/server.go b/app/dns/server.go index 0090c998..2004c60f 100644 --- a/app/dns/server.go +++ b/app/dns/server.go @@ -348,7 +348,6 @@ func (s *Server) LookupIP(domain string, option dns.IPOption) ([]net.IP, error) if domain == "" { return nil, newError("empty domain name") } - domain = strings.ToLower(domain) // normalize the FQDN form query if strings.HasSuffix(domain, ".") { diff --git a/app/dns/server_test.go b/app/dns/server_test.go index c2b984ad..feb382fa 100644 --- a/app/dns/server_test.go +++ b/app/dns/server_test.go @@ -101,8 +101,8 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { rr, _ := dns.NewRR("localhost-b. IN A 127.0.0.4") ans.Answer = append(ans.Answer, rr) - case q.Name == "mijia\\ cloud." && q.Qtype == dns.TypeA: - rr, _ := dns.NewRR("mijia\\ cloud. IN A 127.0.0.1") + case q.Name == "Mijia\\ Cloud." && q.Qtype == dns.TypeA: + rr, _ := dns.NewRR("Mijia\\ Cloud. IN A 127.0.0.1") ans.Answer = append(ans.Answer, rr) } } diff --git a/app/router/condition.go b/app/router/condition.go index 8c5d1de8..fd421560 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -66,6 +66,24 @@ type DomainMatcher struct { matchers strmatcher.IndexMatcher } +func NewMphMatcherGroup(domains []*Domain) (*DomainMatcher, error) { + g := strmatcher.NewMphMatcherGroup() + for _, d := range domains { + matcherType, f := matcherTypeMap[d.Type] + if !f { + return nil, newError("unsupported domain type", d.Type) + } + _, err := g.AddPattern(d.Value, matcherType) + if err != nil { + return nil, err + } + } + g.Build() + return &DomainMatcher{ + matchers: g, + }, nil +} + func NewDomainMatcher(domains []*Domain) (*DomainMatcher, error) { g := new(strmatcher.MatcherGroup) for _, d := range domains { @@ -82,7 +100,7 @@ func NewDomainMatcher(domains []*Domain) (*DomainMatcher, error) { } func (m *DomainMatcher) ApplyDomain(domain string) bool { - return len(m.matchers.Match(domain)) > 0 + return len(m.matchers.Match(strings.ToLower(domain))) > 0 } // Apply implements Condition. @@ -91,7 +109,7 @@ func (m *DomainMatcher) Apply(ctx routing.Context) bool { if len(domain) == 0 { return false } - return m.ApplyDomain(strings.ToLower(domain)) + return m.ApplyDomain(domain) } type MultiGeoIPMatcher struct { diff --git a/app/router/condition_test.go b/app/router/condition_test.go index 3f7185e6..ef473a24 100644 --- a/app/router/condition_test.go +++ b/app/router/condition_test.go @@ -359,6 +359,9 @@ func TestChinaSites(t *testing.T) { matcher, err := NewDomainMatcher(domains) common.Must(err) + acMatcher, err := NewMphMatcherGroup(domains) + common.Must(err) + type TestCase struct { Domain string Output bool @@ -387,9 +390,96 @@ func TestChinaSites(t *testing.T) { } for _, testCase := range testCases { - r := matcher.ApplyDomain(testCase.Domain) - if r != testCase.Output { - t.Error("expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r) + r1 := matcher.ApplyDomain(testCase.Domain) + r2 := acMatcher.ApplyDomain(testCase.Domain) + if r1 != testCase.Output { + t.Error("DomainMatcher expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r1) + } else if r2 != testCase.Output { + t.Error("ACDomainMatcher expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r2) + } + } +} + +func BenchmarkMphDomainMatcher(b *testing.B) { + domains, err := loadGeoSite("CN") + common.Must(err) + + matcher, err := NewMphMatcherGroup(domains) + common.Must(err) + + type TestCase struct { + Domain string + Output bool + } + testCases := []TestCase{ + { + Domain: "163.com", + Output: true, + }, + { + Domain: "163.com", + Output: true, + }, + { + Domain: "164.com", + Output: false, + }, + { + Domain: "164.com", + Output: false, + }, + } + + for i := 0; i < 1024; i++ { + testCases = append(testCases, TestCase{Domain: strconv.Itoa(i) + ".not-exists.com", Output: false}) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, testCase := range testCases { + _ = matcher.ApplyDomain(testCase.Domain) + } + } +} + +func BenchmarkDomainMatcher(b *testing.B) { + domains, err := loadGeoSite("CN") + common.Must(err) + + matcher, err := NewDomainMatcher(domains) + common.Must(err) + + type TestCase struct { + Domain string + Output bool + } + testCases := []TestCase{ + { + Domain: "163.com", + Output: true, + }, + { + Domain: "163.com", + Output: true, + }, + { + Domain: "164.com", + Output: false, + }, + { + Domain: "164.com", + Output: false, + }, + } + + for i := 0; i < 1024; i++ { + testCases = append(testCases, TestCase{Domain: strconv.Itoa(i) + ".not-exists.com", Output: false}) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, testCase := range testCases { + _ = matcher.ApplyDomain(testCase.Domain) } } } diff --git a/app/router/config.go b/app/router/config.go index 07167a53..f7ce0911 100644 --- a/app/router/config.go +++ b/app/router/config.go @@ -67,11 +67,24 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { conds := NewConditionChan() if len(rr.Domain) > 0 { - matcher, err := NewDomainMatcher(rr.Domain) - if err != nil { - return nil, newError("failed to build domain condition").Base(err) + switch rr.DomainMatcher { + case "linear": + matcher, err := NewDomainMatcher(rr.Domain) + if err != nil { + return nil, newError("failed to build domain condition").Base(err) + } + conds.Add(matcher) + case "mph", "hybrid": + fallthrough + default: + matcher, err := NewMphMatcherGroup(rr.Domain) + if err != nil { + return nil, newError("failed to build domain condition with MphDomainMatcher").Base(err) + } + newError("MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)").AtDebug().WriteToLog() + conds.Add(matcher) } - conds.Add(matcher) + } if len(rr.UserEmail) > 0 { diff --git a/app/router/config.pb.go b/app/router/config.pb.go index 69abd370..5938781b 100644 --- a/app/router/config.pb.go +++ b/app/router/config.pb.go @@ -1,13 +1,12 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.25.0 -// protoc v3.14.0 +// protoc-gen-go v1.26.0 +// protoc v3.15.8 // source: app/router/config.proto package router import ( - proto "github.com/golang/protobuf/proto" net "github.com/xtls/xray-core/common/net" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" @@ -22,10 +21,6 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -// This is a compile-time assertion that a sufficiently up-to-date version -// of the legacy proto package is being used. -const _ = proto.ProtoPackageIsVersion4 - // Type of domain value. type Domain_Type int32 @@ -515,6 +510,7 @@ type RoutingRule struct { InboundTag []string `protobuf:"bytes,8,rep,name=inbound_tag,json=inboundTag,proto3" json:"inbound_tag,omitempty"` Protocol []string `protobuf:"bytes,9,rep,name=protocol,proto3" json:"protocol,omitempty"` Attributes string `protobuf:"bytes,15,opt,name=attributes,proto3" json:"attributes,omitempty"` + DomainMatcher string `protobuf:"bytes,17,opt,name=domain_matcher,json=domainMatcher,proto3" json:"domain_matcher,omitempty"` } func (x *RoutingRule) Reset() { @@ -672,6 +668,13 @@ func (x *RoutingRule) GetAttributes() string { return "" } +func (x *RoutingRule) GetDomainMatcher() string { + if x != nil { + return x.DomainMatcher + } + return "" +} + type isRoutingRule_TargetTag interface { isRoutingRule_TargetTag() } @@ -946,7 +949,7 @@ var file_app_router_config_proto_rawDesc = []byte{ 0x74, 0x12, 0x2e, 0x0a, 0x05, 0x65, 0x6e, 0x74, 0x72, 0x79, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x2e, 0x47, 0x65, 0x6f, 0x53, 0x69, 0x74, 0x65, 0x52, 0x05, 0x65, 0x6e, 0x74, 0x72, - 0x79, 0x22, 0x8e, 0x06, 0x0a, 0x0b, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, + 0x79, 0x22, 0xb5, 0x06, 0x0a, 0x0b, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x12, 0x0a, 0x03, 0x74, 0x61, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x03, 0x74, 0x61, 0x67, 0x12, 0x25, 0x0a, 0x0d, 0x62, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x69, 0x6e, 0x67, 0x5f, 0x74, 0x61, 0x67, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, @@ -994,36 +997,38 @@ var file_app_router_config_proto_rawDesc = []byte{ 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x09, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x74, 0x74, 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x73, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x74, 0x74, 0x72, 0x69, 0x62, - 0x75, 0x74, 0x65, 0x73, 0x42, 0x0c, 0x0a, 0x0a, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x5f, 0x74, - 0x61, 0x67, 0x22, 0x4e, 0x0a, 0x0d, 0x42, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x69, 0x6e, 0x67, 0x52, - 0x75, 0x6c, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x74, 0x61, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x03, 0x74, 0x61, 0x67, 0x12, 0x2b, 0x0a, 0x11, 0x6f, 0x75, 0x74, 0x62, 0x6f, 0x75, 0x6e, - 0x64, 0x5f, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x10, 0x6f, 0x75, 0x74, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, - 0x6f, 0x72, 0x22, 0x9b, 0x02, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x4f, 0x0a, - 0x0f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x73, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x26, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x61, 0x70, - 0x70, 0x2e, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x2e, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x52, 0x0e, - 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x12, 0x30, - 0x0a, 0x04, 0x72, 0x75, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x78, - 0x72, 0x61, 0x79, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x2e, 0x52, - 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x04, 0x72, 0x75, 0x6c, 0x65, - 0x12, 0x45, 0x0a, 0x0e, 0x62, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x69, 0x6e, 0x67, 0x5f, 0x72, 0x75, - 0x6c, 0x65, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, - 0x61, 0x70, 0x70, 0x2e, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x2e, 0x42, 0x61, 0x6c, 0x61, 0x6e, - 0x63, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x62, 0x61, 0x6c, 0x61, 0x6e, 0x63, - 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x22, 0x47, 0x0a, 0x0e, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x12, 0x08, 0x0a, 0x04, 0x41, 0x73, 0x49, - 0x73, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x55, 0x73, 0x65, 0x49, 0x70, 0x10, 0x01, 0x12, 0x10, - 0x0a, 0x0c, 0x49, 0x70, 0x49, 0x66, 0x4e, 0x6f, 0x6e, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x02, - 0x12, 0x0e, 0x0a, 0x0a, 0x49, 0x70, 0x4f, 0x6e, 0x44, 0x65, 0x6d, 0x61, 0x6e, 0x64, 0x10, 0x03, - 0x42, 0x4f, 0x0a, 0x13, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x61, 0x70, 0x70, - 0x2e, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x50, 0x01, 0x5a, 0x24, 0x67, 0x69, 0x74, 0x68, 0x75, - 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, - 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x61, 0x70, 0x70, 0x2f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0xaa, - 0x02, 0x0f, 0x58, 0x72, 0x61, 0x79, 0x2e, 0x41, 0x70, 0x70, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, - 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x75, 0x74, 0x65, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x6d, + 0x61, 0x74, 0x63, 0x68, 0x65, 0x72, 0x18, 0x11, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x65, 0x72, 0x42, 0x0c, 0x0a, 0x0a, 0x74, + 0x61, 0x72, 0x67, 0x65, 0x74, 0x5f, 0x74, 0x61, 0x67, 0x22, 0x4e, 0x0a, 0x0d, 0x42, 0x61, 0x6c, + 0x61, 0x6e, 0x63, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x74, 0x61, + 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x74, 0x61, 0x67, 0x12, 0x2b, 0x0a, 0x11, + 0x6f, 0x75, 0x74, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x5f, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x6f, + 0x72, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x10, 0x6f, 0x75, 0x74, 0x62, 0x6f, 0x75, 0x6e, + 0x64, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x22, 0x9b, 0x02, 0x0a, 0x06, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x4f, 0x0a, 0x0f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x73, + 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x26, 0x2e, + 0x78, 0x72, 0x61, 0x79, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x2e, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, + 0x61, 0x74, 0x65, 0x67, 0x79, 0x52, 0x0e, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, + 0x61, 0x74, 0x65, 0x67, 0x79, 0x12, 0x30, 0x0a, 0x04, 0x72, 0x75, 0x6c, 0x65, 0x18, 0x02, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x72, + 0x6f, 0x75, 0x74, 0x65, 0x72, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, + 0x65, 0x52, 0x04, 0x72, 0x75, 0x6c, 0x65, 0x12, 0x45, 0x0a, 0x0e, 0x62, 0x61, 0x6c, 0x61, 0x6e, + 0x63, 0x69, 0x6e, 0x67, 0x5f, 0x72, 0x75, 0x6c, 0x65, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x1e, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x72, 0x6f, 0x75, 0x74, 0x65, + 0x72, 0x2e, 0x42, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, + 0x0d, 0x62, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x22, 0x47, + 0x0a, 0x0e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, + 0x12, 0x08, 0x0a, 0x04, 0x41, 0x73, 0x49, 0x73, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x55, 0x73, + 0x65, 0x49, 0x70, 0x10, 0x01, 0x12, 0x10, 0x0a, 0x0c, 0x49, 0x70, 0x49, 0x66, 0x4e, 0x6f, 0x6e, + 0x4d, 0x61, 0x74, 0x63, 0x68, 0x10, 0x02, 0x12, 0x0e, 0x0a, 0x0a, 0x49, 0x70, 0x4f, 0x6e, 0x44, + 0x65, 0x6d, 0x61, 0x6e, 0x64, 0x10, 0x03, 0x42, 0x4f, 0x0a, 0x13, 0x63, 0x6f, 0x6d, 0x2e, 0x78, + 0x72, 0x61, 0x79, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x50, 0x01, + 0x5a, 0x24, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, + 0x73, 0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x61, 0x70, 0x70, 0x2f, + 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0xaa, 0x02, 0x0f, 0x58, 0x72, 0x61, 0x79, 0x2e, 0x41, 0x70, + 0x70, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/app/router/config.proto b/app/router/config.proto index c5e655ad..8877c657 100644 --- a/app/router/config.proto +++ b/app/router/config.proto @@ -119,6 +119,8 @@ message RoutingRule { repeated string protocol = 9; string attributes = 15; + + string domain_matcher = 17; } message BalancingRule { diff --git a/common/protocol/tls/sniff.go b/common/protocol/tls/sniff.go index f3806a05..a8fec15f 100644 --- a/common/protocol/tls/sniff.go +++ b/common/protocol/tls/sniff.go @@ -102,7 +102,7 @@ func ReadClientHello(data []byte, h *SniffHeader) error { return errNotClientHello } if nameType == 0 { - serverName := strings.ToLower(string(d[:nameLen])) + serverName := string(d[:nameLen]) // An SNI value may not include a // trailing dot. See // https://tools.ietf.org/html/rfc6066#section-3. diff --git a/common/strmatcher/ac_automaton_matcher.go b/common/strmatcher/ac_automaton_matcher.go new file mode 100644 index 00000000..e21364ec --- /dev/null +++ b/common/strmatcher/ac_automaton_matcher.go @@ -0,0 +1,243 @@ +package strmatcher + +import ( + "container/list" +) + +const validCharCount = 53 + +type MatchType struct { + matchType Type + exist bool +} + +const ( + TrieEdge bool = true + FailEdge bool = false +) + +type Edge struct { + edgeType bool + nextNode int +} + +type ACAutomaton struct { + trie [][validCharCount]Edge + fail []int + exists []MatchType + count int +} + +func newNode() [validCharCount]Edge { + var s [validCharCount]Edge + for i := range s { + s[i] = Edge{ + edgeType: FailEdge, + nextNode: 0, + } + } + return s +} + +var char2Index = []int{ + 'A': 0, + 'a': 0, + 'B': 1, + 'b': 1, + 'C': 2, + 'c': 2, + 'D': 3, + 'd': 3, + 'E': 4, + 'e': 4, + 'F': 5, + 'f': 5, + 'G': 6, + 'g': 6, + 'H': 7, + 'h': 7, + 'I': 8, + 'i': 8, + 'J': 9, + 'j': 9, + 'K': 10, + 'k': 10, + 'L': 11, + 'l': 11, + 'M': 12, + 'm': 12, + 'N': 13, + 'n': 13, + 'O': 14, + 'o': 14, + 'P': 15, + 'p': 15, + 'Q': 16, + 'q': 16, + 'R': 17, + 'r': 17, + 'S': 18, + 's': 18, + 'T': 19, + 't': 19, + 'U': 20, + 'u': 20, + 'V': 21, + 'v': 21, + 'W': 22, + 'w': 22, + 'X': 23, + 'x': 23, + 'Y': 24, + 'y': 24, + 'Z': 25, + 'z': 25, + '!': 26, + '$': 27, + '&': 28, + '\'': 29, + '(': 30, + ')': 31, + '*': 32, + '+': 33, + ',': 34, + ';': 35, + '=': 36, + ':': 37, + '%': 38, + '-': 39, + '.': 40, + '_': 41, + '~': 42, + '0': 43, + '1': 44, + '2': 45, + '3': 46, + '4': 47, + '5': 48, + '6': 49, + '7': 50, + '8': 51, + '9': 52, +} + +func NewACAutomaton() *ACAutomaton { + var ac = new(ACAutomaton) + ac.trie = append(ac.trie, newNode()) + ac.fail = append(ac.fail, 0) + ac.exists = append(ac.exists, MatchType{ + matchType: Full, + exist: false, + }) + return ac +} + +func (ac *ACAutomaton) Add(domain string, t Type) { + var node = 0 + for i := len(domain) - 1; i >= 0; i-- { + var idx = char2Index[domain[i]] + if ac.trie[node][idx].nextNode == 0 { + ac.count++ + if len(ac.trie) < ac.count+1 { + ac.trie = append(ac.trie, newNode()) + ac.fail = append(ac.fail, 0) + ac.exists = append(ac.exists, MatchType{ + matchType: Full, + exist: false, + }) + } + ac.trie[node][idx] = Edge{ + edgeType: TrieEdge, + nextNode: ac.count, + } + } + node = ac.trie[node][idx].nextNode + } + ac.exists[node] = MatchType{ + matchType: t, + exist: true, + } + switch t { + case Domain: + ac.exists[node] = MatchType{ + matchType: Full, + exist: true, + } + var idx = char2Index['.'] + if ac.trie[node][idx].nextNode == 0 { + ac.count++ + if len(ac.trie) < ac.count+1 { + ac.trie = append(ac.trie, newNode()) + ac.fail = append(ac.fail, 0) + ac.exists = append(ac.exists, MatchType{ + matchType: Full, + exist: false, + }) + } + ac.trie[node][idx] = Edge{ + edgeType: TrieEdge, + nextNode: ac.count, + } + } + node = ac.trie[node][idx].nextNode + ac.exists[node] = MatchType{ + matchType: t, + exist: true, + } + default: + break + } +} + +func (ac *ACAutomaton) Build() { + var queue = list.New() + for i := 0; i < validCharCount; i++ { + if ac.trie[0][i].nextNode != 0 { + queue.PushBack(ac.trie[0][i]) + } + } + for { + var front = queue.Front() + if front == nil { + break + } else { + var node = front.Value.(Edge).nextNode + queue.Remove(front) + for i := 0; i < validCharCount; i++ { + if ac.trie[node][i].nextNode != 0 { + ac.fail[ac.trie[node][i].nextNode] = ac.trie[ac.fail[node]][i].nextNode + queue.PushBack(ac.trie[node][i]) + } else { + ac.trie[node][i] = Edge{ + edgeType: FailEdge, + nextNode: ac.trie[ac.fail[node]][i].nextNode, + } + } + } + } + } +} + +func (ac *ACAutomaton) Match(s string) bool { + var node = 0 + var fullMatch = true + // 1. the match string is all through trie edge. FULL MATCH or DOMAIN + // 2. the match string is through a fail edge. NOT FULL MATCH + // 2.1 Through a fail edge, but there exists a valid node. SUBSTR + for i := len(s) - 1; i >= 0; i-- { + var idx = char2Index[s[i]] + fullMatch = fullMatch && ac.trie[node][idx].edgeType + node = ac.trie[node][idx].nextNode + switch ac.exists[node].matchType { + case Substr: + return true + case Domain: + if fullMatch { + return true + } + default: + break + } + } + return fullMatch && ac.exists[node].exist +} diff --git a/common/strmatcher/benchmark_test.go b/common/strmatcher/benchmark_test.go index 3e70ca04..9d09140c 100644 --- a/common/strmatcher/benchmark_test.go +++ b/common/strmatcher/benchmark_test.go @@ -8,6 +8,19 @@ import ( . "github.com/xtls/xray-core/common/strmatcher" ) +func BenchmarkACAutomaton(b *testing.B) { + ac := NewACAutomaton() + for i := 1; i <= 1024; i++ { + ac.Add(strconv.Itoa(i)+".v2ray.com", Domain) + } + ac.Build() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ac.Match("0.v2ray.com") + } +} + func BenchmarkDomainMatcherGroup(b *testing.B) { g := new(DomainMatcherGroup) diff --git a/common/strmatcher/mph_matcher.go b/common/strmatcher/mph_matcher.go new file mode 100644 index 00000000..21b98e76 --- /dev/null +++ b/common/strmatcher/mph_matcher.go @@ -0,0 +1,301 @@ +package strmatcher + +import ( + "math/bits" + "regexp" + "sort" + "strings" + "unsafe" +) + +// PrimeRK is the prime base used in Rabin-Karp algorithm. +const PrimeRK = 16777619 + +// calculate the rolling murmurHash of given string +func RollingHash(s string) uint32 { + h := uint32(0) + for i := len(s) - 1; i >= 0; i-- { + h = h*PrimeRK + uint32(s[i]) + } + return h +} + +// A MphMatcherGroup is divided into three parts: +// 1. `full` and `domain` patterns are matched by Rabin-Karp algorithm and minimal perfect hash table; +// 2. `substr` patterns are matched by ac automaton; +// 3. `regex` patterns are matched with the regex library. +type MphMatcherGroup struct { + ac *ACAutomaton + otherMatchers []matcherEntry + rules []string + level0 []uint32 + level0Mask int + level1 []uint32 + level1Mask int + count uint32 + ruleMap *map[string]uint32 +} + +func (g *MphMatcherGroup) AddFullOrDomainPattern(pattern string, t Type) { + h := RollingHash(pattern) + switch t { + case Domain: + (*g.ruleMap)["."+pattern] = h*PrimeRK + uint32('.') + fallthrough + case Full: + (*g.ruleMap)[pattern] = h + default: + } +} + +func NewMphMatcherGroup() *MphMatcherGroup { + return &MphMatcherGroup{ + ac: nil, + otherMatchers: nil, + rules: nil, + level0: nil, + level0Mask: 0, + level1: nil, + level1Mask: 0, + count: 1, + ruleMap: &map[string]uint32{}, + } +} + +// AddPattern adds a pattern to MphMatcherGroup +func (g *MphMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) { + switch t { + case Substr: + if g.ac == nil { + g.ac = NewACAutomaton() + } + g.ac.Add(pattern, t) + case Full, Domain: + pattern = strings.ToLower(pattern) + g.AddFullOrDomainPattern(pattern, t) + case Regex: + r, err := regexp.Compile(pattern) + if err != nil { + return 0, err + } + g.otherMatchers = append(g.otherMatchers, matcherEntry{ + m: ®exMatcher{pattern: r}, + id: g.count, + }) + default: + panic("Unknown type") + } + return g.count, nil +} + +// Build builds a minimal perfect hash table and ac automaton from insert rules +func (g *MphMatcherGroup) Build() { + if g.ac != nil { + g.ac.Build() + } + keyLen := len(*g.ruleMap) + if keyLen == 0 { + keyLen = 1 + (*g.ruleMap)["empty___"] = RollingHash("empty___") + } + g.level0 = make([]uint32, nextPow2(keyLen/4)) + g.level0Mask = len(g.level0) - 1 + g.level1 = make([]uint32, nextPow2(keyLen)) + g.level1Mask = len(g.level1) - 1 + var sparseBuckets = make([][]int, len(g.level0)) + var ruleIdx int + for rule, hash := range *g.ruleMap { + n := int(hash) & g.level0Mask + g.rules = append(g.rules, rule) + sparseBuckets[n] = append(sparseBuckets[n], ruleIdx) + ruleIdx++ + } + g.ruleMap = nil + var buckets []indexBucket + for n, vals := range sparseBuckets { + if len(vals) > 0 { + buckets = append(buckets, indexBucket{n, vals}) + } + } + sort.Sort(bySize(buckets)) + + occ := make([]bool, len(g.level1)) + var tmpOcc []int + for _, bucket := range buckets { + var seed = uint32(0) + for { + findSeed := true + tmpOcc = tmpOcc[:0] + for _, i := range bucket.vals { + n := int(strhashFallback(unsafe.Pointer(&g.rules[i]), uintptr(seed))) & g.level1Mask + if occ[n] { + for _, n := range tmpOcc { + occ[n] = false + } + seed++ + findSeed = false + break + } + occ[n] = true + tmpOcc = append(tmpOcc, n) + g.level1[n] = uint32(i) + } + if findSeed { + g.level0[bucket.n] = seed + break + } + } + } +} + +func nextPow2(v int) int { + if v <= 1 { + return 1 + } + const MaxUInt = ^uint(0) + n := (MaxUInt >> bits.LeadingZeros(uint(v))) + 1 + return int(n) +} + +// Lookup searches for s in t and returns its index and whether it was found. +func (g *MphMatcherGroup) Lookup(h uint32, s string) bool { + i0 := int(h) & g.level0Mask + seed := g.level0[i0] + i1 := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.level1Mask + n := g.level1[i1] + return s == g.rules[int(n)] +} + +// Match implements IndexMatcher.Match. +func (g *MphMatcherGroup) Match(pattern string) []uint32 { + result := []uint32{} + hash := uint32(0) + for i := len(pattern) - 1; i >= 0; i-- { + hash = hash*PrimeRK + uint32(pattern[i]) + if pattern[i] == '.' { + if g.Lookup(hash, pattern[i:]) { + result = append(result, 1) + return result + } + } + } + if g.Lookup(hash, pattern) { + result = append(result, 1) + return result + } + if g.ac != nil && g.ac.Match(pattern) { + result = append(result, 1) + return result + } + for _, e := range g.otherMatchers { + if e.m.Match(pattern) { + result = append(result, e.id) + return result + } + } + return nil +} + +type indexBucket struct { + n int + vals []int +} + +type bySize []indexBucket + +func (s bySize) Len() int { return len(s) } +func (s bySize) Less(i, j int) bool { return len(s[i].vals) > len(s[j].vals) } +func (s bySize) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type stringStruct struct { + str unsafe.Pointer + len int +} + +func strhashFallback(a unsafe.Pointer, h uintptr) uintptr { + x := (*stringStruct)(a) + return memhashFallback(x.str, h, uintptr(x.len)) +} + +const ( + // Constants for multiplication: four random odd 64-bit numbers. + m1 = 16877499708836156737 + m2 = 2820277070424839065 + m3 = 9497967016996688599 + m4 = 15839092249703872147 +) + +var hashkey = [4]uintptr{1, 1, 1, 1} + +func memhashFallback(p unsafe.Pointer, seed, s uintptr) uintptr { + h := uint64(seed + s*hashkey[0]) +tail: + switch { + case s == 0: + case s < 4: + h ^= uint64(*(*byte)(p)) + h ^= uint64(*(*byte)(add(p, s>>1))) << 8 + h ^= uint64(*(*byte)(add(p, s-1))) << 16 + h = rotl31(h*m1) * m2 + case s <= 8: + h ^= uint64(readUnaligned32(p)) + h ^= uint64(readUnaligned32(add(p, s-4))) << 32 + h = rotl31(h*m1) * m2 + case s <= 16: + h ^= readUnaligned64(p) + h = rotl31(h*m1) * m2 + h ^= readUnaligned64(add(p, s-8)) + h = rotl31(h*m1) * m2 + case s <= 32: + h ^= readUnaligned64(p) + h = rotl31(h*m1) * m2 + h ^= readUnaligned64(add(p, 8)) + h = rotl31(h*m1) * m2 + h ^= readUnaligned64(add(p, s-16)) + h = rotl31(h*m1) * m2 + h ^= readUnaligned64(add(p, s-8)) + h = rotl31(h*m1) * m2 + default: + v1 := h + v2 := uint64(seed * hashkey[1]) + v3 := uint64(seed * hashkey[2]) + v4 := uint64(seed * hashkey[3]) + for s >= 32 { + v1 ^= readUnaligned64(p) + v1 = rotl31(v1*m1) * m2 + p = add(p, 8) + v2 ^= readUnaligned64(p) + v2 = rotl31(v2*m2) * m3 + p = add(p, 8) + v3 ^= readUnaligned64(p) + v3 = rotl31(v3*m3) * m4 + p = add(p, 8) + v4 ^= readUnaligned64(p) + v4 = rotl31(v4*m4) * m1 + p = add(p, 8) + s -= 32 + } + h = v1 ^ v2 ^ v3 ^ v4 + goto tail + } + + h ^= h >> 29 + h *= m3 + h ^= h >> 32 + return uintptr(h) +} +func add(p unsafe.Pointer, x uintptr) unsafe.Pointer { + return unsafe.Pointer(uintptr(p) + x) +} +func readUnaligned32(p unsafe.Pointer) uint32 { + q := (*[4]byte)(p) + return uint32(q[0]) | uint32(q[1])<<8 | uint32(q[2])<<16 | uint32(q[3])<<24 +} + +func rotl31(x uint64) uint64 { + return (x << 31) | (x >> (64 - 31)) +} +func readUnaligned64(p unsafe.Pointer) uint64 { + q := (*[8]byte)(p) + return uint64(q[0]) | uint64(q[1])<<8 | uint64(q[2])<<16 | uint64(q[3])<<24 | uint64(q[4])<<32 | uint64(q[5])<<40 | uint64(q[6])<<48 | uint64(q[7])<<56 +} diff --git a/common/strmatcher/strmatcher.go b/common/strmatcher/strmatcher.go index 9728047d..294e6e73 100644 --- a/common/strmatcher/strmatcher.go +++ b/common/strmatcher/strmatcher.go @@ -27,6 +27,7 @@ const ( // New creates a new Matcher based on the given pattern. func (t Type) New(pattern string) (Matcher, error) { + // 1. regex matching is case-sensitive switch t { case Full: return fullMatcher(pattern), nil diff --git a/common/strmatcher/strmatcher_test.go b/common/strmatcher/strmatcher_test.go index 87d6c5b3..ec87ce5c 100644 --- a/common/strmatcher/strmatcher_test.go +++ b/common/strmatcher/strmatcher_test.go @@ -91,3 +91,172 @@ func TestMatcherGroup(t *testing.T) { } } } + +func TestACAutomaton(t *testing.T) { + cases1 := []struct { + pattern string + mType Type + input string + output bool + }{ + { + pattern: "xtls.github.io", + mType: Domain, + input: "www.xtls.github.io", + output: true, + }, + { + pattern: "xtls.github.io", + mType: Domain, + input: "xtls.github.io", + output: true, + }, + { + pattern: "xtls.github.io", + mType: Domain, + input: "www.xtis.github.io", + output: false, + }, + { + pattern: "xtls.github.io", + mType: Domain, + input: "tls.github.io", + output: false, + }, + { + pattern: "xtls.github.io", + mType: Domain, + input: "xxtls.github.io", + output: false, + }, + { + pattern: "xtls.github.io", + mType: Full, + input: "xtls.github.io", + output: true, + }, + { + pattern: "xtls.github.io", + mType: Full, + input: "xxtls.github.io", + output: false, + }, + } + for _, test := range cases1 { + var ac = NewACAutomaton() + ac.Add(test.pattern, test.mType) + ac.Build() + if m := ac.Match(test.input); m != test.output { + t.Error("unexpected output: ", m, " for test case ", test) + } + } + { + cases2Input := []struct { + pattern string + mType Type + }{ + { + pattern: "163.com", + mType: Domain, + }, + { + pattern: "m.126.com", + mType: Full, + }, + { + pattern: "3.com", + mType: Full, + }, + { + pattern: "google.com", + mType: Substr, + }, + { + pattern: "vgoogle.com", + mType: Substr, + }, + } + var ac = NewACAutomaton() + for _, test := range cases2Input { + ac.Add(test.pattern, test.mType) + } + ac.Build() + cases2Output := []struct { + pattern string + res bool + }{ + { + pattern: "126.com", + res: false, + }, + { + pattern: "m.163.com", + res: true, + }, + { + pattern: "mm163.com", + res: false, + }, + { + pattern: "m.126.com", + res: true, + }, + { + pattern: "163.com", + res: true, + }, + { + pattern: "63.com", + res: false, + }, + { + pattern: "oogle.com", + res: false, + }, + { + pattern: "vvgoogle.com", + res: true, + }, + } + for _, test := range cases2Output { + if m := ac.Match(test.pattern); m != test.res { + t.Error("unexpected output: ", m, " for test case ", test) + } + } + } + + { + cases3Input := []struct { + pattern string + mType Type + }{ + { + pattern: "video.google.com", + mType: Domain, + }, + { + pattern: "gle.com", + mType: Domain, + }, + } + var ac = NewACAutomaton() + for _, test := range cases3Input { + ac.Add(test.pattern, test.mType) + } + ac.Build() + cases3Output := []struct { + pattern string + res bool + }{ + { + pattern: "google.com", + res: false, + }, + } + for _, test := range cases3Output { + if m := ac.Match(test.pattern); m != test.res { + t.Error("unexpected output: ", m, " for test case ", test) + } + } + } +} diff --git a/infra/conf/router.go b/infra/conf/router.go index a93da6b1..abed22c6 100644 --- a/infra/conf/router.go +++ b/infra/conf/router.go @@ -98,6 +98,8 @@ type RouterRule struct { Type string `json:"type"` OutboundTag string `json:"outboundTag"` BalancerTag string `json:"balancerTag"` + + DomainMatcher string `json:"domainMatcher"` } func ParseIP(s string) (*router.CIDR, error) { @@ -491,6 +493,10 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) { return nil, newError("neither outboundTag nor balancerTag is specified in routing rule") } + if rawFieldRule.DomainMatcher != "" { + rule.DomainMatcher = rawFieldRule.DomainMatcher + } + if rawFieldRule.Domain != nil { for _, domain := range *rawFieldRule.Domain { rules, err := parseDomainRule(domain)