Xray-core/common/mux/client.go
yuhan6665 017f53b5fc
Add session context outbounds as slice (#3356)
* Add session context outbounds as slice

A slice is needed for the dialer proxy, where two outbounds work on top of each other.
For example, there are two sets of target addresses.
It also enables XTLS to do splice copy correctly by checking that both outbounds are ready to do direct copy.

* Fill outbound tag info

* Splice now checks capabilities from all outbounds

* Fix unit tests
2024-05-13 21:52:24 -04:00

394 lines
9.1 KiB
Go

package mux
import (
"context"
"io"
"sync"
"time"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal/done"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/common/xudp"
"github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/pipe"
)
// ClientManager dispatches links onto mux client workers obtained from Picker.
type ClientManager struct {
Enabled bool // whether mux is enabled from user config
// Picker supplies the worker that Dispatch tries on each attempt.
Picker WorkerPicker
}
// Dispatch hands link to a mux worker obtained from m.Picker.
//
// It asks the picker for a worker up to 16 times. A picker error is returned
// immediately; a worker that refuses the link (its Dispatch returns false)
// causes another attempt. If every attempt is refused, a warning-level error
// is returned.
func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error {
	const maxAttempts = 16

	for attempt := 0; attempt < maxAttempts; attempt++ {
		candidate, pickErr := m.Picker.PickAvailable()
		if pickErr != nil {
			return pickErr
		}
		if accepted := candidate.Dispatch(ctx, link); accepted {
			return nil
		}
	}

	return newError("unable to find an available mux client").AtWarning()
}
// WorkerPicker selects a ClientWorker for ClientManager.Dispatch to try.
type WorkerPicker interface {
// PickAvailable returns a worker to dispatch on, or an error if none can be provided.
PickAvailable() (*ClientWorker, error)
}
// IncrementalWorkerPicker is a WorkerPicker that reuses existing workers while
// one of them is not full and otherwise creates a new worker through Factory.
type IncrementalWorkerPicker struct {
// Factory creates a new worker when no existing worker is available.
Factory ClientWorkerFactory
// access is locked by cleanupFunc and pickInternal for their whole duration.
access sync.Mutex
// workers holds the workers created so far; cleanup drops the closed ones.
workers []*ClientWorker
// cleanupTask is created lazily by pickInternal with a 30 second interval
// and started by PickAvailable when a new worker has been created.
cleanupTask *task.Periodic
}
// cleanupFunc is the Execute callback of the periodic cleanup task.
//
// Under the picker lock it removes closed workers. When the worker list is
// already empty it returns an error instead.
// NOTE(review): presumably the error makes task.Periodic stop rescheduling —
// confirm against the task package.
func (p *IncrementalWorkerPicker) cleanupFunc() error {
	p.access.Lock()
	defer p.access.Unlock()

	if len(p.workers) > 0 {
		p.cleanup()
		return nil
	}
	return newError("no worker")
}
// cleanup replaces p.workers with a fresh slice containing only the workers
// that are not closed, preserving their order. It does not lock; both callers
// in this file (cleanupFunc and pickInternal) hold p.access.
func (p *IncrementalWorkerPicker) cleanup() {
	var alive []*ClientWorker
	for i := range p.workers {
		candidate := p.workers[i]
		if candidate.Closed() {
			continue
		}
		alive = append(alive, candidate)
	}
	p.workers = alive
}
// findAvailable returns the index of the first worker in p.workers that is
// not full, or -1 when every worker is full (or there are no workers).
func (p *IncrementalWorkerPicker) findAvailable() int {
	for i := 0; i < len(p.workers); i++ {
		if p.workers[i].IsFull() {
			continue
		}
		return i
	}
	return -1
}
// pickInternal returns a worker that can take a new session, holding p.access
// for the whole call.
//
// If an existing worker is not full, it is moved to the end of p.workers and
// returned with created == false. Otherwise closed workers are removed, a new
// worker is created through p.Factory and appended, and the periodic cleanup
// task is lazily constructed; in that case the boolean result is true so that
// PickAvailable knows to start the cleanup task. A factory error is returned
// as-is with a nil worker.
func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, bool, error) {
	p.access.Lock()
	defer p.access.Unlock()

	if idx := p.findAvailable(); idx >= 0 {
		// Capture the selected worker BEFORE reordering. Reading
		// p.workers[idx] after the swap would yield the worker that used to
		// be last, which may be full, instead of the available one.
		worker := p.workers[idx]
		if last := len(p.workers) - 1; idx != last {
			p.workers[idx], p.workers[last] = p.workers[last], p.workers[idx]
		}
		return worker, false, nil
	}

	// No worker has capacity: drop closed workers before growing the list.
	p.cleanup()

	worker, err := p.Factory.Create()
	if err != nil {
		return nil, false, err
	}
	p.workers = append(p.workers, worker)

	if p.cleanupTask == nil {
		p.cleanupTask = &task.Periodic{
			Interval: time.Second * 30,
			Execute:  p.cleanupFunc,
		}
	}

	return worker, true, nil
}
// PickAvailable implements WorkerPicker.
//
// It delegates the selection to pickInternal. When pickInternal reports that
// a new worker was created, the periodic cleanup task is started; a failure
// to start it is passed to common.Must.
func (p *IncrementalWorkerPicker) PickAvailable() (*ClientWorker, error) {
	picked, created, err := p.pickInternal()
	if !created {
		return picked, err
	}
	common.Must(p.cleanupTask.Start())
	return picked, err
}
// ClientWorkerFactory creates new ClientWorker instances; IncrementalWorkerPicker
// calls it when no existing worker is available.
type ClientWorkerFactory interface {
// Create returns a new worker, or an error if one could not be created.
Create() (*ClientWorker, error)
}
// DialingWorkerFactory is a ClientWorkerFactory whose workers are connected,
// through a pair of pipes, to an outbound proxy run in a background goroutine.
type DialingWorkerFactory struct {
// Proxy is the outbound whose Process method carries the mux traffic.
Proxy proxy.Outbound
// Dialer is passed to Proxy.Process.
Dialer internet.Dialer
// Strategy is handed to every worker created by this factory.
Strategy ClientStrategy
}
// Create builds a ClientWorker wired to f.Proxy through two pipes, each
// limited to 64 KiB.
//
// The worker writes into the uplink pipe and reads from the downlink pipe. A
// goroutine runs f.Proxy.Process on the opposite ends, using a cancellable
// background context whose single outbound targets the mux.cool pseudo
// address. When Process returns (its error, if any, is logged), the worker's
// done instance is closed and then the context is cancelled.
func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
	pipeOpts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)}
	upReader, upWriter := pipe.New(pipeOpts...)
	downReader, downWriter := pipe.New(pipeOpts...)

	workerLink := transport.Link{
		Reader: downReader,
		Writer: upWriter,
	}
	worker, err := NewClientWorker(workerLink, f.Strategy)
	if err != nil {
		return nil, err
	}

	go func(outbound proxy.Outbound, dialer internet.Dialer, closer common.Closable) {
		target := net.TCPDestination(muxCoolAddress, muxCoolPort)
		base := session.ContextWithOutbounds(
			context.Background(),
			[]*session.Outbound{{Target: target}},
		)
		ctx, cancel := context.WithCancel(base)

		proxyLink := &transport.Link{Reader: upReader, Writer: downWriter}
		if procErr := outbound.Process(ctx, proxyLink, dialer); procErr != nil {
			errors.New("failed to handler mux client connection").Base(procErr).WriteToLog()
		}

		// Order matters: mark the worker done first, then release the context.
		common.Must(closer.Close())
		cancel()
	}(f.Proxy, f.Dialer, worker.done)

	return worker, nil
}
// ClientStrategy holds the limits applied to a single ClientWorker.
// A zero value for either field disables that limit.
type ClientStrategy struct {
// MaxConcurrency: IsFull reports true once sessionManager.Size() reaches this value.
MaxConcurrency uint32
// MaxConnection: IsClosing reports true once sessionManager.Count() reaches this value.
MaxConnection uint32
}
// ClientWorker multiplexes sessions over a single underlying link.
type ClientWorker struct {
// sessionManager allocates and looks up the sessions carried by this worker.
sessionManager *SessionManager
// link is the underlying stream: frames are read from link.Reader by
// fetchOutput and session data is written to link.Writer.
link transport.Link
// done is closed when the worker shuts down; Closed reports its state.
done *done.Instance
// strategy carries the limits checked by IsClosing and IsFull.
strategy ClientStrategy
}
// muxCoolAddress and muxCoolPort form the destination that
// DialingWorkerFactory.Create sets as the outbound target of the mux connection.
var (
muxCoolAddress = net.DomainAddress("v1.mux.cool")
muxCoolPort = net.Port(9527)
)
// NewClientWorker creates a ClientWorker that multiplexes sessions over
// stream using the limits in s.
//
// Two goroutines are started before returning: fetchOutput, which reads
// frames from the stream, and monitor, which shuts the worker down. The
// returned error is always nil.
func NewClientWorker(stream transport.Link, s ClientStrategy) (*ClientWorker, error) {
	worker := new(ClientWorker)
	worker.sessionManager = NewSessionManager()
	worker.link = stream
	worker.done = done.New()
	worker.strategy = s

	go worker.fetchOutput()
	go worker.monitor()

	return worker, nil
}
// TotalConnections returns the session manager's Count as a uint32.
// NOTE(review): presumably the number of sessions ever allocated on this
// worker — confirm against SessionManager.Count.
func (m *ClientWorker) TotalConnections() uint32 {
	total := m.sessionManager.Count()
	return uint32(total)
}
// ActiveConnections returns the session manager's Size as a uint32.
// NOTE(review): presumably the number of currently open sessions — confirm
// against SessionManager.Size.
func (m *ClientWorker) ActiveConnections() uint32 {
	active := m.sessionManager.Size()
	return uint32(active)
}
// Closed reports whether this worker's done instance has been closed.
func (m *ClientWorker) Closed() bool {
	closed := m.done.Done()
	return closed
}
// monitor runs for the lifetime of the worker.
//
// Every 16 seconds it checks whether the worker has no sessions; if so and
// the session manager agrees to close (CloseIfNoSession), the worker is marked
// done. Once the worker is done, the session manager is closed, the link
// writer is closed, the link reader is interrupted, and monitor returns.
func (m *ClientWorker) monitor() {
	ticker := time.NewTicker(16 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if m.sessionManager.Size() != 0 {
				continue
			}
			if m.sessionManager.CloseIfNoSession() {
				common.Must(m.done.Close())
			}
		case <-m.done.Wait():
			m.sessionManager.Close()
			common.Close(m.link.Writer)
			common.Interrupt(m.link.Reader)
			return
		}
	}
}
// writeFirstPayload tries to copy one batch of data from reader to writer,
// waiting at most 100 milliseconds for it.
//
// If the reader does not support timeouts or no data arrives in time, an
// empty MultiBuffer is written instead. Any other error from the copy is
// returned unchanged; nil is returned on success.
func writeFirstPayload(reader buf.Reader, writer *Writer) error {
	const firstReadTimeout = 100 * time.Millisecond

	switch err := buf.CopyOnceTimeout(reader, writer, firstReadTimeout); err {
	case buf.ErrNotTimeoutReader, buf.ErrReadTimeout:
		return writer.WriteMultiBuffer(buf.MultiBuffer{})
	default:
		return err
	}
}
// fetchInput pumps the input of session s into the shared mux output.
//
// The target is taken from the last outbound stored in ctx; a UDP target
// selects the packet transfer type, anything else the stream type. The first
// payload is written with writeFirstPayload, then the rest of s.input is
// copied. On return the frame writer is closed first and the session second.
// If either write step fails the error is logged and writer.hasError is set
// before the deferred closes run.
//
// If ctx carries no outbounds there is no target to dispatch to: a warning is
// logged and the session is closed, instead of panicking in this goroutine on
// an out-of-range index.
func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
	outbounds := session.OutboundsFromContext(ctx)
	if len(outbounds) == 0 {
		newError("no outbound found in context").AtWarning().WriteToLog(session.ExportIDToError(ctx))
		s.Close(false)
		return
	}
	ob := outbounds[len(outbounds)-1]

	transferType := protocol.TransferTypeStream
	if ob.Target.Network == net.Network_UDP {
		transferType = protocol.TransferTypePacket
	}
	s.transferType = transferType

	writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
	// Deferred calls run in reverse order: writer.Close, then s.Close.
	defer s.Close(false)
	defer writer.Close()

	newError("dispatching request to ", ob.Target).WriteToLog(session.ExportIDToError(ctx))
	if err := writeFirstPayload(s.input, writer); err != nil {
		newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
		// NOTE(review): presumably changes how writer.Close ends the session — confirm in Writer.
		writer.hasError = true
		return
	}

	if err := buf.Copy(s.input, writer); err != nil {
		newError("failed to fetch all input").Base(err).WriteToLog(session.ExportIDToError(ctx))
		writer.hasError = true
		return
	}
}
// IsClosing reports whether the worker has reached its MaxConnection limit,
// i.e. the limit is non-zero and sessionManager.Count() is at least that
// large. With a zero limit it always returns false.
func (m *ClientWorker) IsClosing() bool {
	limit := m.strategy.MaxConnection
	if limit == 0 {
		return false
	}
	return m.sessionManager.Count() >= int(limit)
}
// IsFull reports whether the worker must not be given another session: it is
// closing, it is closed, or a non-zero MaxConcurrency limit has been reached
// by sessionManager.Size().
func (m *ClientWorker) IsFull() bool {
	switch {
	case m.IsClosing(), m.Closed():
		return true
	case m.strategy.MaxConcurrency == 0:
		// No concurrency limit configured.
		return false
	default:
		return m.sessionManager.Size() >= int(m.strategy.MaxConcurrency)
	}
}
// Dispatch tries to start a new mux session for link on this worker.
//
// It returns false when the worker is full or closed, or when the session
// manager cannot allocate a session. Otherwise the session is bound to the
// link's reader and writer, a goroutine is started to forward the session's
// input into the worker's link, and true is returned.
func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool {
	if unavailable := m.IsFull() || m.Closed(); unavailable {
		return false
	}

	sess := m.sessionManager.Allocate()
	if sess == nil {
		return false
	}

	sess.input, sess.output = link.Reader, link.Writer
	go fetchInput(ctx, sess, m.link.Writer)
	return true
}
// handleStatueKeepAlive handles a keep-alive frame. The frame carries nothing
// the client acts on, so any attached data is read and discarded to keep the
// stream positioned at the next frame.
func (m *ClientWorker) handleStatueKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
	if !meta.Option.Has(OptionData) {
		return nil
	}
	payload := NewStreamReader(reader)
	return buf.Copy(payload, buf.Discard)
}
// handleStatusNew handles a new-session frame received by the client. The
// client does not act on it; any attached data is read and discarded to keep
// the stream positioned at the next frame.
func (m *ClientWorker) handleStatusNew(meta *FrameMetadata, reader *buf.BufferedReader) error {
	if !meta.Option.Has(OptionData) {
		return nil
	}
	payload := NewStreamReader(reader)
	return buf.Copy(payload, buf.Discard)
}
// handleStatusKeep delivers the data of a keep frame to its session.
//
// Frames without data are ignored. If the session is unknown, an end-of-session
// frame is sent to the remote peer and the payload is discarded. Otherwise
// the payload is copied to the session's output; when that fails with a write
// error the session is closed and the remainder of the payload is drained so
// the stream stays aligned. Any other copy error is returned to the caller.
func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
	if hasData := meta.Option.Has(OptionData); !hasData {
		return nil
	}

	sess, ok := m.sessionManager.Get(meta.SessionID)
	if !ok {
		// Notify remote peer to close this session.
		endWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
		endWriter.Close()

		return buf.Copy(NewStreamReader(reader), buf.Discard)
	}

	frameReader := sess.NewReader(reader, &meta.Target)
	copyErr := buf.Copy(frameReader, sess.output)
	if copyErr == nil || !buf.IsWriteError(copyErr) {
		return copyErr
	}

	newError("failed to write to downstream. closing session ", sess.ID).Base(copyErr).WriteToLog()
	sess.Close(false)
	return buf.Copy(frameReader, buf.Discard)
}
// handleStatusEnd handles an end-of-session frame: the referenced session, if
// it still exists, is closed, and any data attached to the frame is discarded.
func (m *ClientWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
	sess, ok := m.sessionManager.Get(meta.SessionID)
	if ok {
		sess.Close(false)
	}

	if !meta.Option.Has(OptionData) {
		return nil
	}
	return buf.Copy(NewStreamReader(reader), buf.Discard)
}
// fetchOutput is the worker's read loop.
//
// It repeatedly decodes frame metadata from the link reader and hands each
// frame to the handler for its session status. The loop ends when metadata
// can no longer be read (logged unless the cause is io.EOF), when a frame has
// an unknown status, or when a handler returns an error. In every case the
// worker's done instance is closed on the way out.
func (m *ClientWorker) fetchOutput() {
	// The closure delays m.done.Close() until fetchOutput returns.
	defer func() {
		common.Must(m.done.Close())
	}()

	var (
		frame FrameMetadata
		input = &buf.BufferedReader{Reader: m.link.Reader}
	)

	for {
		if err := frame.Unmarshal(input); err != nil {
			if errors.Cause(err) != io.EOF {
				newError("failed to read metadata").Base(err).WriteToLog()
			}
			return
		}

		var handleErr error
		switch frame.SessionStatus {
		case SessionStatusKeep:
			handleErr = m.handleStatusKeep(&frame, input)
		case SessionStatusNew:
			handleErr = m.handleStatusNew(&frame, input)
		case SessionStatusEnd:
			handleErr = m.handleStatusEnd(&frame, input)
		case SessionStatusKeepAlive:
			handleErr = m.handleStatueKeepAlive(&frame, input)
		default:
			newError("unknown status: ", frame.SessionStatus).AtError().WriteToLog()
			return
		}

		if handleErr != nil {
			newError("failed to process data").Base(handleErr).WriteToLog()
			return
		}
	}
}