This commit is contained in:
RPRX 2020-11-25 19:01:53 +08:00
parent 47d23e9972
commit c7f7c08ead
711 changed files with 82154 additions and 2 deletions

402
common/mux/client.go Normal file
View file

@ -0,0 +1,402 @@
package mux
import (
"context"
"io"
"sync"
"time"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/errors"
"github.com/xtls/xray-core/v1/common/net"
"github.com/xtls/xray-core/v1/common/protocol"
"github.com/xtls/xray-core/v1/common/session"
"github.com/xtls/xray-core/v1/common/signal/done"
"github.com/xtls/xray-core/v1/common/task"
"github.com/xtls/xray-core/v1/proxy"
"github.com/xtls/xray-core/v1/transport"
"github.com/xtls/xray-core/v1/transport/internet"
"github.com/xtls/xray-core/v1/transport/pipe"
)
type ClientManager struct {
Enabled bool // wheather mux is enabled from user config
Picker WorkerPicker
}
func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error {
for i := 0; i < 16; i++ {
worker, err := m.Picker.PickAvailable()
if err != nil {
return err
}
if worker.Dispatch(ctx, link) {
return nil
}
}
return newError("unable to find an available mux client").AtWarning()
}
type WorkerPicker interface {
PickAvailable() (*ClientWorker, error)
}
type IncrementalWorkerPicker struct {
Factory ClientWorkerFactory
access sync.Mutex
workers []*ClientWorker
cleanupTask *task.Periodic
}
func (p *IncrementalWorkerPicker) cleanupFunc() error {
p.access.Lock()
defer p.access.Unlock()
if len(p.workers) == 0 {
return newError("no worker")
}
p.cleanup()
return nil
}
func (p *IncrementalWorkerPicker) cleanup() {
var activeWorkers []*ClientWorker
for _, w := range p.workers {
if !w.Closed() {
activeWorkers = append(activeWorkers, w)
}
}
p.workers = activeWorkers
}
func (p *IncrementalWorkerPicker) findAvailable() int {
for idx, w := range p.workers {
if !w.IsFull() {
return idx
}
}
return -1
}
func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, bool, error) {
p.access.Lock()
defer p.access.Unlock()
idx := p.findAvailable()
if idx >= 0 {
n := len(p.workers)
if n > 1 && idx != n-1 {
p.workers[n-1], p.workers[idx] = p.workers[idx], p.workers[n-1]
}
return p.workers[idx], false, nil
}
p.cleanup()
worker, err := p.Factory.Create()
if err != nil {
return nil, false, err
}
p.workers = append(p.workers, worker)
if p.cleanupTask == nil {
p.cleanupTask = &task.Periodic{
Interval: time.Second * 30,
Execute: p.cleanupFunc,
}
}
return worker, true, nil
}
func (p *IncrementalWorkerPicker) PickAvailable() (*ClientWorker, error) {
worker, start, err := p.pickInternal()
if start {
common.Must(p.cleanupTask.Start())
}
return worker, err
}
type ClientWorkerFactory interface {
Create() (*ClientWorker, error)
}
type DialingWorkerFactory struct {
Proxy proxy.Outbound
Dialer internet.Dialer
Strategy ClientStrategy
}
func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)}
uplinkReader, upLinkWriter := pipe.New(opts...)
downlinkReader, downlinkWriter := pipe.New(opts...)
c, err := NewClientWorker(transport.Link{
Reader: downlinkReader,
Writer: upLinkWriter,
}, f.Strategy)
if err != nil {
return nil, err
}
go func(p proxy.Outbound, d internet.Dialer, c common.Closable) {
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{
Target: net.TCPDestination(muxCoolAddress, muxCoolPort),
})
ctx, cancel := context.WithCancel(ctx)
if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
errors.New("failed to handler mux client connection").Base(err).WriteToLog()
}
common.Must(c.Close())
cancel()
}(f.Proxy, f.Dialer, c.done)
return c, nil
}
type ClientStrategy struct {
MaxConcurrency uint32
MaxConnection uint32
}
type ClientWorker struct {
sessionManager *SessionManager
link transport.Link
done *done.Instance
strategy ClientStrategy
}
var muxCoolAddress = net.DomainAddress("v1.mux.cool")
var muxCoolPort = net.Port(9527)
// NewClientWorker creates a new mux.Client.
func NewClientWorker(stream transport.Link, s ClientStrategy) (*ClientWorker, error) {
c := &ClientWorker{
sessionManager: NewSessionManager(),
link: stream,
done: done.New(),
strategy: s,
}
go c.fetchOutput()
go c.monitor()
return c, nil
}
func (m *ClientWorker) TotalConnections() uint32 {
return uint32(m.sessionManager.Count())
}
func (m *ClientWorker) ActiveConnections() uint32 {
return uint32(m.sessionManager.Size())
}
// Closed returns true if this Client is closed.
func (m *ClientWorker) Closed() bool {
return m.done.Done()
}
func (m *ClientWorker) monitor() {
timer := time.NewTicker(time.Second * 16)
defer timer.Stop()
for {
select {
case <-m.done.Wait():
m.sessionManager.Close()
common.Close(m.link.Writer)
common.Interrupt(m.link.Reader)
return
case <-timer.C:
size := m.sessionManager.Size()
if size == 0 && m.sessionManager.CloseIfNoSession() {
common.Must(m.done.Close())
}
}
}
}
func writeFirstPayload(reader buf.Reader, writer *Writer) error {
err := buf.CopyOnceTimeout(reader, writer, time.Millisecond*100)
if err == buf.ErrNotTimeoutReader || err == buf.ErrReadTimeout {
return writer.WriteMultiBuffer(buf.MultiBuffer{})
}
if err != nil {
return err
}
return nil
}
func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
dest := session.OutboundFromContext(ctx).Target
transferType := protocol.TransferTypeStream
if dest.Network == net.Network_UDP {
transferType = protocol.TransferTypePacket
}
s.transferType = transferType
writer := NewWriter(s.ID, dest, output, transferType)
defer s.Close()
defer writer.Close()
newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx))
if err := writeFirstPayload(s.input, writer); err != nil {
newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
writer.hasError = true
common.Interrupt(s.input)
return
}
if err := buf.Copy(s.input, writer); err != nil {
newError("failed to fetch all input").Base(err).WriteToLog(session.ExportIDToError(ctx))
writer.hasError = true
common.Interrupt(s.input)
return
}
}
func (m *ClientWorker) IsClosing() bool {
sm := m.sessionManager
if m.strategy.MaxConnection > 0 && sm.Count() >= int(m.strategy.MaxConnection) {
return true
}
return false
}
func (m *ClientWorker) IsFull() bool {
if m.IsClosing() || m.Closed() {
return true
}
sm := m.sessionManager
if m.strategy.MaxConcurrency > 0 && sm.Size() >= int(m.strategy.MaxConcurrency) {
return true
}
return false
}
func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool {
if m.IsFull() || m.Closed() {
return false
}
sm := m.sessionManager
s := sm.Allocate()
if s == nil {
return false
}
s.input = link.Reader
s.output = link.Writer
go fetchInput(ctx, s, m.link.Writer)
return true
}
func (m *ClientWorker) handleStatueKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
if meta.Option.Has(OptionData) {
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
return nil
}
func (m *ClientWorker) handleStatusNew(meta *FrameMetadata, reader *buf.BufferedReader) error {
if meta.Option.Has(OptionData) {
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
return nil
}
func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
if !meta.Option.Has(OptionData) {
return nil
}
s, found := m.sessionManager.Get(meta.SessionID)
if !found {
// Notify remote peer to close this session.
closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
closingWriter.Close()
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
rr := s.NewReader(reader)
err := buf.Copy(rr, s.output)
if err != nil && buf.IsWriteError(err) {
newError("failed to write to downstream. closing session ", s.ID).Base(err).WriteToLog()
// Notify remote peer to close this session.
closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
closingWriter.Close()
drainErr := buf.Copy(rr, buf.Discard)
common.Interrupt(s.input)
s.Close()
return drainErr
}
return err
}
func (m *ClientWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
if s, found := m.sessionManager.Get(meta.SessionID); found {
if meta.Option.Has(OptionError) {
common.Interrupt(s.input)
common.Interrupt(s.output)
}
s.Close()
}
if meta.Option.Has(OptionData) {
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
return nil
}
func (m *ClientWorker) fetchOutput() {
defer func() {
common.Must(m.done.Close())
}()
reader := &buf.BufferedReader{Reader: m.link.Reader}
var meta FrameMetadata
for {
err := meta.Unmarshal(reader)
if err != nil {
if errors.Cause(err) != io.EOF {
newError("failed to read metadata").Base(err).WriteToLog()
}
break
}
switch meta.SessionStatus {
case SessionStatusKeepAlive:
err = m.handleStatueKeepAlive(&meta, reader)
case SessionStatusEnd:
err = m.handleStatusEnd(&meta, reader)
case SessionStatusNew:
err = m.handleStatusNew(&meta, reader)
case SessionStatusKeep:
err = m.handleStatusKeep(&meta, reader)
default:
status := meta.SessionStatus
newError("unknown status: ", status).AtError().WriteToLog()
return
}
if err != nil {
newError("failed to process data").Base(err).WriteToLog()
return
}
}
}

116
common/mux/client_test.go Normal file
View file

@ -0,0 +1,116 @@
package mux_test
import (
"context"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/errors"
"github.com/xtls/xray-core/v1/common/mux"
"github.com/xtls/xray-core/v1/common/net"
"github.com/xtls/xray-core/v1/common/session"
"github.com/xtls/xray-core/v1/testing/mocks"
"github.com/xtls/xray-core/v1/transport"
"github.com/xtls/xray-core/v1/transport/pipe"
)
func TestIncrementalPickerFailure(t *testing.T) {
mockCtl := gomock.NewController(t)
defer mockCtl.Finish()
mockWorkerFactory := mocks.NewMuxClientWorkerFactory(mockCtl)
mockWorkerFactory.EXPECT().Create().Return(nil, errors.New("test"))
picker := mux.IncrementalWorkerPicker{
Factory: mockWorkerFactory,
}
_, err := picker.PickAvailable()
if err == nil {
t.Error("expected error, but nil")
}
}
func TestClientWorkerEOF(t *testing.T) {
reader, writer := pipe.New(pipe.WithoutSizeLimit())
common.Must(writer.Close())
worker, err := mux.NewClientWorker(transport.Link{Reader: reader, Writer: writer}, mux.ClientStrategy{})
common.Must(err)
time.Sleep(time.Millisecond * 500)
f := worker.Dispatch(context.Background(), nil)
if f {
t.Error("expected failed dispatching, but actually not")
}
}
func TestClientWorkerClose(t *testing.T) {
mockCtl := gomock.NewController(t)
defer mockCtl.Finish()
r1, w1 := pipe.New(pipe.WithoutSizeLimit())
worker1, err := mux.NewClientWorker(transport.Link{
Reader: r1,
Writer: w1,
}, mux.ClientStrategy{
MaxConcurrency: 4,
MaxConnection: 4,
})
common.Must(err)
r2, w2 := pipe.New(pipe.WithoutSizeLimit())
worker2, err := mux.NewClientWorker(transport.Link{
Reader: r2,
Writer: w2,
}, mux.ClientStrategy{
MaxConcurrency: 4,
MaxConnection: 4,
})
common.Must(err)
factory := mocks.NewMuxClientWorkerFactory(mockCtl)
gomock.InOrder(
factory.EXPECT().Create().Return(worker1, nil),
factory.EXPECT().Create().Return(worker2, nil),
)
picker := &mux.IncrementalWorkerPicker{
Factory: factory,
}
manager := &mux.ClientManager{
Picker: picker,
}
tr1, tw1 := pipe.New(pipe.WithoutSizeLimit())
ctx1 := session.ContextWithOutbound(context.Background(), &session.Outbound{
Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80),
})
common.Must(manager.Dispatch(ctx1, &transport.Link{
Reader: tr1,
Writer: tw1,
}))
defer tw1.Close()
common.Must(w1.Close())
time.Sleep(time.Millisecond * 500)
if !worker1.Closed() {
t.Error("worker1 is not finished")
}
tr2, tw2 := pipe.New(pipe.WithoutSizeLimit())
ctx2 := session.ContextWithOutbound(context.Background(), &session.Outbound{
Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80),
})
common.Must(manager.Dispatch(ctx2, &transport.Link{
Reader: tr2,
Writer: tw2,
}))
defer tw2.Close()
common.Must(w2.Close())
}

View file

@ -0,0 +1,9 @@
package mux
import "github.com/xtls/xray-core/v1/common/errors"
type errPathObjHolder struct{}
func newError(values ...interface{}) *errors.Error {
return errors.New(values...).WithPathObj(errPathObjHolder{})
}

145
common/mux/frame.go Normal file
View file

@ -0,0 +1,145 @@
package mux
import (
"encoding/binary"
"io"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/bitmask"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/net"
"github.com/xtls/xray-core/v1/common/protocol"
"github.com/xtls/xray-core/v1/common/serial"
)
type SessionStatus byte
const (
SessionStatusNew SessionStatus = 0x01
SessionStatusKeep SessionStatus = 0x02
SessionStatusEnd SessionStatus = 0x03
SessionStatusKeepAlive SessionStatus = 0x04
)
const (
OptionData bitmask.Byte = 0x01
OptionError bitmask.Byte = 0x02
)
type TargetNetwork byte
const (
TargetNetworkTCP TargetNetwork = 0x01
TargetNetworkUDP TargetNetwork = 0x02
)
var addrParser = protocol.NewAddressParser(
protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv4), net.AddressFamilyIPv4),
protocol.AddressFamilyByte(byte(protocol.AddressTypeDomain), net.AddressFamilyDomain),
protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv6), net.AddressFamilyIPv6),
protocol.PortThenAddress(),
)
/*
Frame format
2 bytes - length
2 bytes - session id
1 bytes - status
1 bytes - option
1 byte - network
2 bytes - port
n bytes - address
*/
type FrameMetadata struct {
Target net.Destination
SessionID uint16
Option bitmask.Byte
SessionStatus SessionStatus
}
func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
lenBytes := b.Extend(2)
len0 := b.Len()
sessionBytes := b.Extend(2)
binary.BigEndian.PutUint16(sessionBytes, f.SessionID)
common.Must(b.WriteByte(byte(f.SessionStatus)))
common.Must(b.WriteByte(byte(f.Option)))
if f.SessionStatus == SessionStatusNew {
switch f.Target.Network {
case net.Network_TCP:
common.Must(b.WriteByte(byte(TargetNetworkTCP)))
case net.Network_UDP:
common.Must(b.WriteByte(byte(TargetNetworkUDP)))
}
if err := addrParser.WriteAddressPort(b, f.Target.Address, f.Target.Port); err != nil {
return err
}
}
len1 := b.Len()
binary.BigEndian.PutUint16(lenBytes, uint16(len1-len0))
return nil
}
// Unmarshal reads FrameMetadata from the given reader.
func (f *FrameMetadata) Unmarshal(reader io.Reader) error {
metaLen, err := serial.ReadUint16(reader)
if err != nil {
return err
}
if metaLen > 512 {
return newError("invalid metalen ", metaLen).AtError()
}
b := buf.New()
defer b.Release()
if _, err := b.ReadFullFrom(reader, int32(metaLen)); err != nil {
return err
}
return f.UnmarshalFromBuffer(b)
}
// UnmarshalFromBuffer reads a FrameMetadata from the given buffer.
// Visible for testing only.
func (f *FrameMetadata) UnmarshalFromBuffer(b *buf.Buffer) error {
if b.Len() < 4 {
return newError("insufficient buffer: ", b.Len())
}
f.SessionID = binary.BigEndian.Uint16(b.BytesTo(2))
f.SessionStatus = SessionStatus(b.Byte(2))
f.Option = bitmask.Byte(b.Byte(3))
f.Target.Network = net.Network_Unknown
if f.SessionStatus == SessionStatusNew {
if b.Len() < 8 {
return newError("insufficient buffer: ", b.Len())
}
network := TargetNetwork(b.Byte(4))
b.Advance(5)
addr, port, err := addrParser.ReadAddressPort(nil, b)
if err != nil {
return newError("failed to parse address and port").Base(err)
}
switch network {
case TargetNetworkTCP:
f.Target = net.TCPDestination(addr, port)
case TargetNetworkUDP:
f.Target = net.UDPDestination(addr, port)
default:
return newError("unknown network type: ", network)
}
}
return nil
}

25
common/mux/frame_test.go Normal file
View file

@ -0,0 +1,25 @@
package mux_test
import (
"testing"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/mux"
"github.com/xtls/xray-core/v1/common/net"
)
func BenchmarkFrameWrite(b *testing.B) {
frame := mux.FrameMetadata{
Target: net.TCPDestination(net.DomainAddress("www.example.com"), net.Port(80)),
SessionID: 1,
SessionStatus: mux.SessionStatusNew,
}
writer := buf.New()
defer writer.Release()
for i := 0; i < b.N; i++ {
common.Must(frame.WriteTo(writer))
writer.Clear()
}
}

3
common/mux/mux.go Normal file
View file

@ -0,0 +1,3 @@
package mux
//go:generate go run github.com/xtls/xray-core/v1/common/errors/errorgen

196
common/mux/mux_test.go Normal file
View file

@ -0,0 +1,196 @@
package mux_test
import (
"io"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
. "github.com/xtls/xray-core/v1/common/mux"
"github.com/xtls/xray-core/v1/common/net"
"github.com/xtls/xray-core/v1/common/protocol"
"github.com/xtls/xray-core/v1/transport/pipe"
)
func readAll(reader buf.Reader) (buf.MultiBuffer, error) {
var mb buf.MultiBuffer
for {
b, err := reader.ReadMultiBuffer()
if err == io.EOF {
break
}
if err != nil {
return nil, err
}
mb = append(mb, b...)
}
return mb, nil
}
func TestReaderWriter(t *testing.T) {
pReader, pWriter := pipe.New(pipe.WithSizeLimit(1024))
dest := net.TCPDestination(net.DomainAddress("example.com"), 80)
writer := NewWriter(1, dest, pWriter, protocol.TransferTypeStream)
dest2 := net.TCPDestination(net.LocalHostIP, 443)
writer2 := NewWriter(2, dest2, pWriter, protocol.TransferTypeStream)
dest3 := net.TCPDestination(net.LocalHostIPv6, 18374)
writer3 := NewWriter(3, dest3, pWriter, protocol.TransferTypeStream)
writePayload := func(writer *Writer, payload ...byte) error {
b := buf.New()
b.Write(payload)
return writer.WriteMultiBuffer(buf.MultiBuffer{b})
}
common.Must(writePayload(writer, 'a', 'b', 'c', 'd'))
common.Must(writePayload(writer2))
common.Must(writePayload(writer, 'e', 'f', 'g', 'h'))
common.Must(writePayload(writer3, 'x'))
writer.Close()
writer3.Close()
common.Must(writePayload(writer2, 'y'))
writer2.Close()
bytesReader := &buf.BufferedReader{Reader: pReader}
{
var meta FrameMetadata
common.Must(meta.Unmarshal(bytesReader))
if r := cmp.Diff(meta, FrameMetadata{
SessionID: 1,
SessionStatus: SessionStatusNew,
Target: dest,
Option: OptionData,
}); r != "" {
t.Error("metadata: ", r)
}
data, err := readAll(NewStreamReader(bytesReader))
common.Must(err)
if s := data.String(); s != "abcd" {
t.Error("data: ", s)
}
}
{
var meta FrameMetadata
common.Must(meta.Unmarshal(bytesReader))
if r := cmp.Diff(meta, FrameMetadata{
SessionStatus: SessionStatusNew,
SessionID: 2,
Option: 0,
Target: dest2,
}); r != "" {
t.Error("meta: ", r)
}
}
{
var meta FrameMetadata
common.Must(meta.Unmarshal(bytesReader))
if r := cmp.Diff(meta, FrameMetadata{
SessionID: 1,
SessionStatus: SessionStatusKeep,
Option: 1,
}); r != "" {
t.Error("meta: ", r)
}
data, err := readAll(NewStreamReader(bytesReader))
common.Must(err)
if s := data.String(); s != "efgh" {
t.Error("data: ", s)
}
}
{
var meta FrameMetadata
common.Must(meta.Unmarshal(bytesReader))
if r := cmp.Diff(meta, FrameMetadata{
SessionID: 3,
SessionStatus: SessionStatusNew,
Option: 1,
Target: dest3,
}); r != "" {
t.Error("meta: ", r)
}
data, err := readAll(NewStreamReader(bytesReader))
common.Must(err)
if s := data.String(); s != "x" {
t.Error("data: ", s)
}
}
{
var meta FrameMetadata
common.Must(meta.Unmarshal(bytesReader))
if r := cmp.Diff(meta, FrameMetadata{
SessionID: 1,
SessionStatus: SessionStatusEnd,
Option: 0,
}); r != "" {
t.Error("meta: ", r)
}
}
{
var meta FrameMetadata
common.Must(meta.Unmarshal(bytesReader))
if r := cmp.Diff(meta, FrameMetadata{
SessionID: 3,
SessionStatus: SessionStatusEnd,
Option: 0,
}); r != "" {
t.Error("meta: ", r)
}
}
{
var meta FrameMetadata
common.Must(meta.Unmarshal(bytesReader))
if r := cmp.Diff(meta, FrameMetadata{
SessionID: 2,
SessionStatus: SessionStatusKeep,
Option: 1,
}); r != "" {
t.Error("meta: ", r)
}
data, err := readAll(NewStreamReader(bytesReader))
common.Must(err)
if s := data.String(); s != "y" {
t.Error("data: ", s)
}
}
{
var meta FrameMetadata
common.Must(meta.Unmarshal(bytesReader))
if r := cmp.Diff(meta, FrameMetadata{
SessionID: 2,
SessionStatus: SessionStatusEnd,
Option: 0,
}); r != "" {
t.Error("meta: ", r)
}
}
pWriter.Close()
{
var meta FrameMetadata
err := meta.Unmarshal(bytesReader)
if err == nil {
t.Error("nil error")
}
}
}

52
common/mux/reader.go Normal file
View file

@ -0,0 +1,52 @@
package mux
import (
"io"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/crypto"
"github.com/xtls/xray-core/v1/common/serial"
)
// PacketReader is an io.Reader that reads whole chunk of Mux frames every time.
type PacketReader struct {
reader io.Reader
eof bool
}
// NewPacketReader creates a new PacketReader.
func NewPacketReader(reader io.Reader) *PacketReader {
return &PacketReader{
reader: reader,
eof: false,
}
}
// ReadMultiBuffer implements buf.Reader.
func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
if r.eof {
return nil, io.EOF
}
size, err := serial.ReadUint16(r.reader)
if err != nil {
return nil, err
}
if size > buf.Size {
return nil, newError("packet size too large: ", size)
}
b := buf.New()
if _, err := b.ReadFullFrom(r.reader, int32(size)); err != nil {
b.Release()
return nil, err
}
r.eof = true
return buf.MultiBuffer{b}, nil
}
// NewStreamReader creates a new StreamReader.
func NewStreamReader(reader *buf.BufferedReader) buf.Reader {
return crypto.NewChunkStreamReaderWithChunkCount(crypto.PlainChunkSizeParser{}, reader, 1)
}

252
common/mux/server.go Normal file
View file

@ -0,0 +1,252 @@
package mux
import (
"context"
"io"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/errors"
"github.com/xtls/xray-core/v1/common/log"
"github.com/xtls/xray-core/v1/common/net"
"github.com/xtls/xray-core/v1/common/protocol"
"github.com/xtls/xray-core/v1/common/session"
"github.com/xtls/xray-core/v1/core"
"github.com/xtls/xray-core/v1/features/routing"
"github.com/xtls/xray-core/v1/transport"
"github.com/xtls/xray-core/v1/transport/pipe"
)
type Server struct {
dispatcher routing.Dispatcher
}
// NewServer creates a new mux.Server.
func NewServer(ctx context.Context) *Server {
s := &Server{}
core.RequireFeatures(ctx, func(d routing.Dispatcher) {
s.dispatcher = d
})
return s
}
// Type implements common.HasType.
func (s *Server) Type() interface{} {
return s.dispatcher.Type()
}
// Dispatch implements routing.Dispatcher
func (s *Server) Dispatch(ctx context.Context, dest net.Destination) (*transport.Link, error) {
if dest.Address != muxCoolAddress {
return s.dispatcher.Dispatch(ctx, dest)
}
opts := pipe.OptionsFromContext(ctx)
uplinkReader, uplinkWriter := pipe.New(opts...)
downlinkReader, downlinkWriter := pipe.New(opts...)
_, err := NewServerWorker(ctx, s.dispatcher, &transport.Link{
Reader: uplinkReader,
Writer: downlinkWriter,
})
if err != nil {
return nil, err
}
return &transport.Link{Reader: downlinkReader, Writer: uplinkWriter}, nil
}
// Start implements common.Runnable.
func (s *Server) Start() error {
return nil
}
// Close implements common.Closable.
func (s *Server) Close() error {
return nil
}
type ServerWorker struct {
dispatcher routing.Dispatcher
link *transport.Link
sessionManager *SessionManager
}
func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.Link) (*ServerWorker, error) {
worker := &ServerWorker{
dispatcher: d,
link: link,
sessionManager: NewSessionManager(),
}
go worker.run(ctx)
return worker, nil
}
func handle(ctx context.Context, s *Session, output buf.Writer) {
writer := NewResponseWriter(s.ID, output, s.transferType)
if err := buf.Copy(s.input, writer); err != nil {
newError("session ", s.ID, " ends.").Base(err).WriteToLog(session.ExportIDToError(ctx))
writer.hasError = true
}
writer.Close()
s.Close()
}
func (w *ServerWorker) ActiveConnections() uint32 {
return uint32(w.sessionManager.Size())
}
func (w *ServerWorker) Closed() bool {
return w.sessionManager.Closed()
}
func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
if meta.Option.Has(OptionData) {
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
return nil
}
func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader *buf.BufferedReader) error {
newError("received request for ", meta.Target).WriteToLog(session.ExportIDToError(ctx))
{
msg := &log.AccessMessage{
To: meta.Target,
Status: log.AccessAccepted,
Reason: "",
}
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() {
msg.From = inbound.Source
msg.Email = inbound.User.Email
}
ctx = log.ContextWithAccessMessage(ctx, msg)
}
link, err := w.dispatcher.Dispatch(ctx, meta.Target)
if err != nil {
if meta.Option.Has(OptionData) {
buf.Copy(NewStreamReader(reader), buf.Discard)
}
return newError("failed to dispatch request.").Base(err)
}
s := &Session{
input: link.Reader,
output: link.Writer,
parent: w.sessionManager,
ID: meta.SessionID,
transferType: protocol.TransferTypeStream,
}
if meta.Target.Network == net.Network_UDP {
s.transferType = protocol.TransferTypePacket
}
w.sessionManager.Add(s)
go handle(ctx, s, w.link.Writer)
if !meta.Option.Has(OptionData) {
return nil
}
rr := s.NewReader(reader)
if err := buf.Copy(rr, s.output); err != nil {
buf.Copy(rr, buf.Discard)
common.Interrupt(s.input)
return s.Close()
}
return nil
}
func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
if !meta.Option.Has(OptionData) {
return nil
}
s, found := w.sessionManager.Get(meta.SessionID)
if !found {
// Notify remote peer to close this session.
closingWriter := NewResponseWriter(meta.SessionID, w.link.Writer, protocol.TransferTypeStream)
closingWriter.Close()
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
rr := s.NewReader(reader)
err := buf.Copy(rr, s.output)
if err != nil && buf.IsWriteError(err) {
newError("failed to write to downstream writer. closing session ", s.ID).Base(err).WriteToLog()
// Notify remote peer to close this session.
closingWriter := NewResponseWriter(meta.SessionID, w.link.Writer, protocol.TransferTypeStream)
closingWriter.Close()
drainErr := buf.Copy(rr, buf.Discard)
common.Interrupt(s.input)
s.Close()
return drainErr
}
return err
}
func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
if s, found := w.sessionManager.Get(meta.SessionID); found {
if meta.Option.Has(OptionError) {
common.Interrupt(s.input)
common.Interrupt(s.output)
}
s.Close()
}
if meta.Option.Has(OptionData) {
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
return nil
}
func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedReader) error {
var meta FrameMetadata
err := meta.Unmarshal(reader)
if err != nil {
return newError("failed to read metadata").Base(err)
}
switch meta.SessionStatus {
case SessionStatusKeepAlive:
err = w.handleStatusKeepAlive(&meta, reader)
case SessionStatusEnd:
err = w.handleStatusEnd(&meta, reader)
case SessionStatusNew:
err = w.handleStatusNew(ctx, &meta, reader)
case SessionStatusKeep:
err = w.handleStatusKeep(&meta, reader)
default:
status := meta.SessionStatus
return newError("unknown status: ", status).AtError()
}
if err != nil {
return newError("failed to process data").Base(err)
}
return nil
}
func (w *ServerWorker) run(ctx context.Context) {
input := w.link.Reader
reader := &buf.BufferedReader{Reader: input}
defer w.sessionManager.Close()
for {
select {
case <-ctx.Done():
return
default:
err := w.handleFrame(ctx, reader)
if err != nil {
if errors.Cause(err) != io.EOF {
newError("unexpected EOF").Base(err).WriteToLog(session.ExportIDToError(ctx))
common.Interrupt(input)
}
return
}
}
}
}

160
common/mux/session.go Normal file
View file

@ -0,0 +1,160 @@
package mux
import (
"sync"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/protocol"
)
type SessionManager struct {
sync.RWMutex
sessions map[uint16]*Session
count uint16
closed bool
}
func NewSessionManager() *SessionManager {
return &SessionManager{
count: 0,
sessions: make(map[uint16]*Session, 16),
}
}
func (m *SessionManager) Closed() bool {
m.RLock()
defer m.RUnlock()
return m.closed
}
func (m *SessionManager) Size() int {
m.RLock()
defer m.RUnlock()
return len(m.sessions)
}
func (m *SessionManager) Count() int {
m.RLock()
defer m.RUnlock()
return int(m.count)
}
func (m *SessionManager) Allocate() *Session {
m.Lock()
defer m.Unlock()
if m.closed {
return nil
}
m.count++
s := &Session{
ID: m.count,
parent: m,
}
m.sessions[s.ID] = s
return s
}
func (m *SessionManager) Add(s *Session) {
m.Lock()
defer m.Unlock()
if m.closed {
return
}
m.count++
m.sessions[s.ID] = s
}
func (m *SessionManager) Remove(id uint16) {
m.Lock()
defer m.Unlock()
if m.closed {
return
}
delete(m.sessions, id)
if len(m.sessions) == 0 {
m.sessions = make(map[uint16]*Session, 16)
}
}
func (m *SessionManager) Get(id uint16) (*Session, bool) {
m.RLock()
defer m.RUnlock()
if m.closed {
return nil, false
}
s, found := m.sessions[id]
return s, found
}
func (m *SessionManager) CloseIfNoSession() bool {
m.Lock()
defer m.Unlock()
if m.closed {
return true
}
if len(m.sessions) != 0 {
return false
}
m.closed = true
return true
}
func (m *SessionManager) Close() error {
m.Lock()
defer m.Unlock()
if m.closed {
return nil
}
m.closed = true
for _, s := range m.sessions {
common.Close(s.input)
common.Close(s.output)
}
m.sessions = nil
return nil
}
// Session represents a client connection in a Mux connection.
type Session struct {
input buf.Reader
output buf.Writer
parent *SessionManager
ID uint16
transferType protocol.TransferType
}
// Close closes all resources associated with this session.
func (s *Session) Close() error {
common.Close(s.output)
common.Close(s.input)
s.parent.Remove(s.ID)
return nil
}
// NewReader creates a buf.Reader based on the transfer type of this Session.
func (s *Session) NewReader(reader *buf.BufferedReader) buf.Reader {
if s.transferType == protocol.TransferTypeStream {
return NewStreamReader(reader)
}
return NewPacketReader(reader)
}

View file

@ -0,0 +1,51 @@
package mux_test
import (
"testing"
. "github.com/xtls/xray-core/v1/common/mux"
)
func TestSessionManagerAdd(t *testing.T) {
m := NewSessionManager()
s := m.Allocate()
if s.ID != 1 {
t.Error("id: ", s.ID)
}
if m.Size() != 1 {
t.Error("size: ", m.Size())
}
s = m.Allocate()
if s.ID != 2 {
t.Error("id: ", s.ID)
}
if m.Size() != 2 {
t.Error("size: ", m.Size())
}
s = &Session{
ID: 4,
}
m.Add(s)
if s.ID != 4 {
t.Error("id: ", s.ID)
}
if m.Size() != 3 {
t.Error("size: ", m.Size())
}
}
func TestSessionManagerClose(t *testing.T) {
m := NewSessionManager()
s := m.Allocate()
if m.CloseIfNoSession() {
t.Error("able to close")
}
m.Remove(s.ID)
if !m.CloseIfNoSession() {
t.Error("not able to close")
}
}

126
common/mux/writer.go Normal file
View file

@ -0,0 +1,126 @@
package mux
import (
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/net"
"github.com/xtls/xray-core/v1/common/protocol"
"github.com/xtls/xray-core/v1/common/serial"
)
type Writer struct {
dest net.Destination
writer buf.Writer
id uint16
followup bool
hasError bool
transferType protocol.TransferType
}
func NewWriter(id uint16, dest net.Destination, writer buf.Writer, transferType protocol.TransferType) *Writer {
return &Writer{
id: id,
dest: dest,
writer: writer,
followup: false,
transferType: transferType,
}
}
func NewResponseWriter(id uint16, writer buf.Writer, transferType protocol.TransferType) *Writer {
return &Writer{
id: id,
writer: writer,
followup: true,
transferType: transferType,
}
}
func (w *Writer) getNextFrameMeta() FrameMetadata {
meta := FrameMetadata{
SessionID: w.id,
Target: w.dest,
}
if w.followup {
meta.SessionStatus = SessionStatusKeep
} else {
w.followup = true
meta.SessionStatus = SessionStatusNew
}
return meta
}
func (w *Writer) writeMetaOnly() error {
meta := w.getNextFrameMeta()
b := buf.New()
if err := meta.WriteTo(b); err != nil {
return err
}
return w.writer.WriteMultiBuffer(buf.MultiBuffer{b})
}
func writeMetaWithFrame(writer buf.Writer, meta FrameMetadata, data buf.MultiBuffer) error {
frame := buf.New()
if err := meta.WriteTo(frame); err != nil {
return err
}
if _, err := serial.WriteUint16(frame, uint16(data.Len())); err != nil {
return err
}
mb2 := make(buf.MultiBuffer, 0, len(data)+1)
mb2 = append(mb2, frame)
mb2 = append(mb2, data...)
return writer.WriteMultiBuffer(mb2)
}
func (w *Writer) writeData(mb buf.MultiBuffer) error {
meta := w.getNextFrameMeta()
meta.Option.Set(OptionData)
return writeMetaWithFrame(w.writer, meta, mb)
}
// WriteMultiBuffer implements buf.Writer.
func (w *Writer) WriteMultiBuffer(mb buf.MultiBuffer) error {
defer buf.ReleaseMulti(mb)
if mb.IsEmpty() {
return w.writeMetaOnly()
}
for !mb.IsEmpty() {
var chunk buf.MultiBuffer
if w.transferType == protocol.TransferTypeStream {
mb, chunk = buf.SplitSize(mb, 8*1024)
} else {
mb2, b := buf.SplitFirst(mb)
mb = mb2
chunk = buf.MultiBuffer{b}
}
if err := w.writeData(chunk); err != nil {
return err
}
}
return nil
}
// Close implements common.Closable.
func (w *Writer) Close() error {
meta := FrameMetadata{
SessionID: w.id,
SessionStatus: SessionStatusEnd,
}
if w.hasError {
meta.Option.Set(OptionError)
}
frame := buf.New()
common.Must(meta.WriteTo(frame))
w.writer.WriteMultiBuffer(buf.MultiBuffer{frame})
return nil
}