Add DispatchLink

This commit is contained in:
世界 2021-09-28 09:42:57 +08:00
parent 625cf7361a
commit 50e576081e
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
5 changed files with 90 additions and 2 deletions

View File

@ -271,6 +271,67 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
return inbound, nil return inbound, nil
} }
// DispatchLink implements routing.Dispatcher.
func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
if !destination.IsValid() {
return newError("Dispatcher: Invalid destination.")
}
ob := &session.Outbound{
Target: destination,
}
ctx = session.ContextWithOutbound(ctx, ob)
content := session.ContentFromContext(ctx)
if content == nil {
content = new(session.Content)
ctx = session.ContextWithContent(ctx, content)
}
sniffingRequest := content.SniffingRequest
switch {
case !sniffingRequest.Enabled:
go d.routedDispatch(ctx, outbound, destination)
case destination.Network != net.Network_TCP:
// Only metadata sniff will be used for non tcp connection
result, err := sniffer(ctx, nil, true)
if err == nil {
content.Protocol = result.Protocol()
if shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) {
domain := result.Domain()
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
destination.Address = net.ParseAddress(domain)
if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
ob.RouteTarget = destination
} else {
ob.Target = destination
}
}
}
go d.routedDispatch(ctx, outbound, destination)
default:
go func() {
cReader := &cachedReader{
reader: outbound.Reader.(*pipe.Reader),
}
outbound.Reader = cReader
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly)
if err == nil {
content.Protocol = result.Protocol()
}
if err == nil && shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) {
domain := result.Domain()
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
destination.Address = net.ParseAddress(domain)
if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
ob.RouteTarget = destination
} else {
ob.Target = destination
}
}
d.routedDispatch(ctx, outbound, destination)
}()
}
return nil
}
func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (SniffResult, error) { func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (SniffResult, error) {
payload := buf.New() payload := buf.New()
defer payload.Release() defer payload.Release()

View File

@ -147,7 +147,7 @@ func (w *BridgeWorker) Connections() uint32 {
return w.worker.ActiveConnections() return w.worker.ActiveConnections()
} }
func (w *BridgeWorker) handleInternalConn(link transport.Link) { func (w *BridgeWorker) handleInternalConn(link *transport.Link) {
go func() { go func() {
reader := link.Reader reader := link.Reader
for { for {
@ -181,7 +181,7 @@ func (w *BridgeWorker) Dispatch(ctx context.Context, dest net.Destination) (*tra
uplinkReader, uplinkWriter := pipe.New(opt...) uplinkReader, uplinkWriter := pipe.New(opt...)
downlinkReader, downlinkWriter := pipe.New(opt...) downlinkReader, downlinkWriter := pipe.New(opt...)
w.handleInternalConn(transport.Link{ w.handleInternalConn(&transport.Link{
Reader: downlinkReader, Reader: downlinkReader,
Writer: uplinkWriter, Writer: uplinkWriter,
}) })
@ -191,3 +191,16 @@ func (w *BridgeWorker) Dispatch(ctx context.Context, dest net.Destination) (*tra
Writer: downlinkWriter, Writer: downlinkWriter,
}, nil }, nil
} }
func (w *BridgeWorker) DispatchLink(ctx context.Context, dest net.Destination, link *transport.Link) error {
if !isInternalDomain(dest) {
ctx = session.ContextWithInbound(ctx, &session.Inbound{
Tag: w.tag,
})
return w.dispatcher.DispatchLink(ctx, dest, link)
}
w.handleInternalConn(link)
return nil
}

View File

@ -56,6 +56,15 @@ func (s *Server) Dispatch(ctx context.Context, dest net.Destination) (*transport
return &transport.Link{Reader: downlinkReader, Writer: uplinkWriter}, nil return &transport.Link{Reader: downlinkReader, Writer: uplinkWriter}, nil
} }
// DispatchLink implements routing.Dispatcher
func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *transport.Link) error {
if dest.Address != muxCoolAddress {
return s.dispatcher.DispatchLink(ctx, dest, link)
}
_, err := NewServerWorker(ctx, s.dispatcher, link)
return err
}
// Start implements common.Runnable. // Start implements common.Runnable.
func (s *Server) Start() error { func (s *Server) Start() error {
return nil return nil

View File

@ -17,6 +17,7 @@ type Dispatcher interface {
// Dispatch returns a Ray for transporting data for the given request. // Dispatch returns a Ray for transporting data for the given request.
Dispatch(ctx context.Context, dest net.Destination) (*transport.Link, error) Dispatch(ctx context.Context, dest net.Destination) (*transport.Link, error)
DispatchLink(ctx context.Context, dest net.Destination, link *transport.Link) error
} }
// DispatcherType returns the type of Dispatcher interface. Can be used to implement common.HasType. // DispatcherType returns the type of Dispatcher interface. Can be used to implement common.HasType.

View File

@ -24,6 +24,10 @@ func (d *TestDispatcher) Dispatch(ctx context.Context, dest net.Destination) (*t
return d.OnDispatch(ctx, dest) return d.OnDispatch(ctx, dest)
} }
func (d *TestDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
return nil
}
func (d *TestDispatcher) Start() error { func (d *TestDispatcher) Start() error {
return nil return nil
} }