This commit is contained in:
RPRX 2020-11-25 19:01:53 +08:00
parent 47d23e9972
commit c7f7c08ead
711 changed files with 82154 additions and 2 deletions

View file

@ -0,0 +1,58 @@
package antireplay
import (
"sync"
"time"
cuckoo "github.com/seiflotfy/cuckoofilter"
)
const replayFilterCapacity = 100000
// ReplayFilter check for replay attacks.
type ReplayFilter struct {
lock sync.Mutex
poolA *cuckoo.Filter
poolB *cuckoo.Filter
poolSwap bool
lastSwap int64
interval int64
}
// NewReplayFilter create a new filter with specifying the expiration time interval in seconds.
func NewReplayFilter(interval int64) *ReplayFilter {
filter := &ReplayFilter{}
filter.interval = interval
return filter
}
// Interval in second for expiration time for duplicate records.
func (filter *ReplayFilter) Interval() int64 {
return filter.interval
}
// Check determine if there are duplicate records.
func (filter *ReplayFilter) Check(sum []byte) bool {
filter.lock.Lock()
defer filter.lock.Unlock()
now := time.Now().Unix()
if filter.lastSwap == 0 {
filter.lastSwap = now
filter.poolA = cuckoo.NewFilter(replayFilterCapacity)
filter.poolB = cuckoo.NewFilter(replayFilterCapacity)
}
elapsed := now - filter.lastSwap
if elapsed >= filter.Interval() {
if filter.poolSwap {
filter.poolA.Reset()
} else {
filter.poolB.Reset()
}
filter.poolSwap = !filter.poolSwap
filter.lastSwap = now
}
return filter.poolA.InsertUnique(sum) && filter.poolB.InsertUnique(sum)
}

21
common/bitmask/byte.go Normal file
View file

@ -0,0 +1,21 @@
package bitmask
// Byte is a bitmask in byte.
type Byte byte
// Has returns true if this bitmask contains another bitmask.
func (b Byte) Has(bb Byte) bool {
return (b & bb) != 0
}
func (b *Byte) Set(bb Byte) {
*b |= bb
}
func (b *Byte) Clear(bb Byte) {
*b &= ^bb
}
func (b *Byte) Toggle(bb Byte) {
*b ^= bb
}

View file

@ -0,0 +1,36 @@
package bitmask_test
import (
"testing"
. "github.com/xtls/xray-core/v1/common/bitmask"
)
func TestBitmaskByte(t *testing.T) {
b := Byte(0)
b.Set(Byte(1))
if !b.Has(1) {
t.Fatal("expected ", b, " to contain 1, but actually not")
}
b.Set(Byte(2))
if !b.Has(2) {
t.Fatal("expected ", b, " to contain 2, but actually not")
}
if !b.Has(1) {
t.Fatal("expected ", b, " to contain 1, but actually not")
}
b.Clear(Byte(1))
if !b.Has(2) {
t.Fatal("expected ", b, " to contain 2, but actually not")
}
if b.Has(1) {
t.Fatal("expected ", b, " to not contain 1, but actually did")
}
b.Toggle(Byte(2))
if b.Has(2) {
t.Fatal("expected ", b, " to not contain 2, but actually did")
}
}

4
common/buf/buf.go Normal file
View file

@ -0,0 +1,4 @@
// Package buf provides a light-weight memory allocation mechanism.
package buf // import "github.com/xtls/xray-core/v1/common/buf"
//go:generate go run github.com/xtls/xray-core/v1/common/errors/errorgen

212
common/buf/buffer.go Normal file
View file

@ -0,0 +1,212 @@
package buf
import (
"io"
"github.com/xtls/xray-core/v1/common/bytespool"
)
const (
// Size of a regular buffer.
Size = 8192
)
var pool = bytespool.GetPool(Size)
// Buffer is a recyclable allocation of a byte array. Buffer.Release() recycles
// the buffer into an internal buffer pool, in order to recreate a buffer more
// quickly.
type Buffer struct {
v []byte
start int32
end int32
}
// New creates a Buffer with 0 length and 2K capacity.
func New() *Buffer {
return &Buffer{
v: pool.Get().([]byte),
}
}
// StackNew creates a new Buffer object on stack.
// This method is for buffers that is released in the same function.
func StackNew() Buffer {
return Buffer{
v: pool.Get().([]byte),
}
}
// Release recycles the buffer into an internal buffer pool.
func (b *Buffer) Release() {
if b == nil || b.v == nil {
return
}
p := b.v
b.v = nil
b.Clear()
pool.Put(p)
}
// Clear clears the content of the buffer, results an empty buffer with
// Len() = 0.
func (b *Buffer) Clear() {
b.start = 0
b.end = 0
}
// Byte returns the bytes at index.
func (b *Buffer) Byte(index int32) byte {
return b.v[b.start+index]
}
// SetByte sets the byte value at index.
func (b *Buffer) SetByte(index int32, value byte) {
b.v[b.start+index] = value
}
// Bytes returns the content bytes of this Buffer.
func (b *Buffer) Bytes() []byte {
return b.v[b.start:b.end]
}
// Extend increases the buffer size by n bytes, and returns the extended part.
// It panics if result size is larger than buf.Size.
func (b *Buffer) Extend(n int32) []byte {
end := b.end + n
if end > int32(len(b.v)) {
panic("extending out of bound")
}
ext := b.v[b.end:end]
b.end = end
return ext
}
// BytesRange returns a slice of this buffer with given from and to boundary.
func (b *Buffer) BytesRange(from, to int32) []byte {
if from < 0 {
from += b.Len()
}
if to < 0 {
to += b.Len()
}
return b.v[b.start+from : b.start+to]
}
// BytesFrom returns a slice of this Buffer starting from the given position.
func (b *Buffer) BytesFrom(from int32) []byte {
if from < 0 {
from += b.Len()
}
return b.v[b.start+from : b.end]
}
// BytesTo returns a slice of this Buffer from start to the given position.
func (b *Buffer) BytesTo(to int32) []byte {
if to < 0 {
to += b.Len()
}
return b.v[b.start : b.start+to]
}
// Resize cuts the buffer at the given position.
func (b *Buffer) Resize(from, to int32) {
if from < 0 {
from += b.Len()
}
if to < 0 {
to += b.Len()
}
if to < from {
panic("Invalid slice")
}
b.end = b.start + to
b.start += from
}
// Advance cuts the buffer at the given position.
func (b *Buffer) Advance(from int32) {
if from < 0 {
from += b.Len()
}
b.start += from
}
// Len returns the length of the buffer content.
func (b *Buffer) Len() int32 {
if b == nil {
return 0
}
return b.end - b.start
}
// IsEmpty returns true if the buffer is empty.
func (b *Buffer) IsEmpty() bool {
return b.Len() == 0
}
// IsFull returns true if the buffer has no more room to grow.
func (b *Buffer) IsFull() bool {
return b != nil && b.end == int32(len(b.v))
}
// Write implements Write method in io.Writer.
func (b *Buffer) Write(data []byte) (int, error) {
nBytes := copy(b.v[b.end:], data)
b.end += int32(nBytes)
return nBytes, nil
}
// WriteByte writes a single byte into the buffer.
func (b *Buffer) WriteByte(v byte) error {
if b.IsFull() {
return newError("buffer full")
}
b.v[b.end] = v
b.end++
return nil
}
// WriteString implements io.StringWriter.
func (b *Buffer) WriteString(s string) (int, error) {
return b.Write([]byte(s))
}
// Read implements io.Reader.Read().
func (b *Buffer) Read(data []byte) (int, error) {
if b.Len() == 0 {
return 0, io.EOF
}
nBytes := copy(data, b.v[b.start:b.end])
if int32(nBytes) == b.Len() {
b.Clear()
} else {
b.start += int32(nBytes)
}
return nBytes, nil
}
// ReadFrom implements io.ReaderFrom.
func (b *Buffer) ReadFrom(reader io.Reader) (int64, error) {
n, err := reader.Read(b.v[b.end:])
b.end += int32(n)
return int64(n), err
}
// ReadFullFrom reads exact size of bytes from given reader, or until error occurs.
func (b *Buffer) ReadFullFrom(reader io.Reader, size int32) (int64, error) {
end := b.end + size
if end > int32(len(b.v)) {
v := end
return 0, newError("out of bound: ", v)
}
n, err := io.ReadFull(reader, b.v[b.end:end])
b.end += int32(n)
return int64(n), err
}
// String returns the string form of this Buffer.
func (b *Buffer) String() string {
return string(b.Bytes())
}

223
common/buf/buffer_test.go Normal file
View file

@ -0,0 +1,223 @@
package buf_test
import (
"bytes"
"crypto/rand"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/xtls/xray-core/v1/common"
. "github.com/xtls/xray-core/v1/common/buf"
)
func TestBufferClear(t *testing.T) {
buffer := New()
defer buffer.Release()
payload := "Bytes"
buffer.Write([]byte(payload))
if diff := cmp.Diff(buffer.Bytes(), []byte(payload)); diff != "" {
t.Error(diff)
}
buffer.Clear()
if buffer.Len() != 0 {
t.Error("expect 0 length, but got ", buffer.Len())
}
}
func TestBufferIsEmpty(t *testing.T) {
buffer := New()
defer buffer.Release()
if buffer.IsEmpty() != true {
t.Error("expect empty buffer, but not")
}
}
func TestBufferString(t *testing.T) {
buffer := New()
defer buffer.Release()
const payload = "Test String"
common.Must2(buffer.WriteString(payload))
if buffer.String() != payload {
t.Error("expect buffer content as ", payload, " but actually ", buffer.String())
}
}
func TestBufferByte(t *testing.T) {
{
buffer := New()
common.Must(buffer.WriteByte('m'))
if buffer.String() != "m" {
t.Error("expect buffer content as ", "m", " but actually ", buffer.String())
}
buffer.Release()
}
{
buffer := StackNew()
common.Must(buffer.WriteByte('n'))
if buffer.String() != "n" {
t.Error("expect buffer content as ", "n", " but actually ", buffer.String())
}
buffer.Release()
}
{
buffer := StackNew()
common.Must2(buffer.WriteString("HELLOWORLD"))
if b := buffer.Byte(5); b != 'W' {
t.Error("unexpected byte ", b)
}
buffer.SetByte(5, 'M')
if buffer.String() != "HELLOMORLD" {
t.Error("expect buffer content as ", "n", " but actually ", buffer.String())
}
buffer.Release()
}
}
func TestBufferResize(t *testing.T) {
buffer := New()
defer buffer.Release()
const payload = "Test String"
common.Must2(buffer.WriteString(payload))
if buffer.String() != payload {
t.Error("expect buffer content as ", payload, " but actually ", buffer.String())
}
buffer.Resize(-6, -3)
if l := buffer.Len(); int(l) != 3 {
t.Error("len error ", l)
}
if s := buffer.String(); s != "Str" {
t.Error("unexpect buffer ", s)
}
buffer.Resize(int32(len(payload)), 200)
if l := buffer.Len(); int(l) != 200-len(payload) {
t.Error("len error ", l)
}
}
func TestBufferSlice(t *testing.T) {
{
b := New()
common.Must2(b.Write([]byte("abcd")))
bytes := b.BytesFrom(-2)
if diff := cmp.Diff(bytes, []byte{'c', 'd'}); diff != "" {
t.Error(diff)
}
}
{
b := New()
common.Must2(b.Write([]byte("abcd")))
bytes := b.BytesTo(-2)
if diff := cmp.Diff(bytes, []byte{'a', 'b'}); diff != "" {
t.Error(diff)
}
}
{
b := New()
common.Must2(b.Write([]byte("abcd")))
bytes := b.BytesRange(-3, -1)
if diff := cmp.Diff(bytes, []byte{'b', 'c'}); diff != "" {
t.Error(diff)
}
}
}
func TestBufferReadFullFrom(t *testing.T) {
payload := make([]byte, 1024)
common.Must2(rand.Read(payload))
reader := bytes.NewReader(payload)
b := New()
n, err := b.ReadFullFrom(reader, 1024)
common.Must(err)
if n != 1024 {
t.Error("expect reading 1024 bytes, but actually ", n)
}
if diff := cmp.Diff(payload, b.Bytes()); diff != "" {
t.Error(diff)
}
}
func BenchmarkNewBuffer(b *testing.B) {
for i := 0; i < b.N; i++ {
buffer := New()
buffer.Release()
}
}
func BenchmarkNewBufferStack(b *testing.B) {
for i := 0; i < b.N; i++ {
buffer := StackNew()
buffer.Release()
}
}
func BenchmarkWrite2(b *testing.B) {
buffer := New()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = buffer.Write([]byte{'a', 'b'})
buffer.Clear()
}
}
func BenchmarkWrite8(b *testing.B) {
buffer := New()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = buffer.Write([]byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'})
buffer.Clear()
}
}
func BenchmarkWrite32(b *testing.B) {
buffer := New()
payload := make([]byte, 32)
rand.Read(payload)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = buffer.Write(payload)
buffer.Clear()
}
}
func BenchmarkWriteByte2(b *testing.B) {
buffer := New()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = buffer.WriteByte('a')
_ = buffer.WriteByte('b')
buffer.Clear()
}
}
func BenchmarkWriteByte8(b *testing.B) {
buffer := New()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = buffer.WriteByte('a')
_ = buffer.WriteByte('b')
_ = buffer.WriteByte('c')
_ = buffer.WriteByte('d')
_ = buffer.WriteByte('e')
_ = buffer.WriteByte('f')
_ = buffer.WriteByte('g')
_ = buffer.WriteByte('h')
buffer.Clear()
}
}

123
common/buf/copy.go Normal file
View file

@ -0,0 +1,123 @@
package buf
import (
"io"
"time"
"github.com/xtls/xray-core/v1/common/errors"
"github.com/xtls/xray-core/v1/common/signal"
)
type dataHandler func(MultiBuffer)
type copyHandler struct {
onData []dataHandler
}
// SizeCounter is for counting bytes copied by Copy().
type SizeCounter struct {
Size int64
}
// CopyOption is an option for copying data.
type CopyOption func(*copyHandler)
// UpdateActivity is a CopyOption to update activity on each data copy operation.
func UpdateActivity(timer signal.ActivityUpdater) CopyOption {
return func(handler *copyHandler) {
handler.onData = append(handler.onData, func(MultiBuffer) {
timer.Update()
})
}
}
// CountSize is a CopyOption that sums the total size of data copied into the given SizeCounter.
func CountSize(sc *SizeCounter) CopyOption {
return func(handler *copyHandler) {
handler.onData = append(handler.onData, func(b MultiBuffer) {
sc.Size += int64(b.Len())
})
}
}
type readError struct {
error
}
func (e readError) Error() string {
return e.error.Error()
}
func (e readError) Inner() error {
return e.error
}
// IsReadError returns true if the error in Copy() comes from reading.
func IsReadError(err error) bool {
_, ok := err.(readError)
return ok
}
type writeError struct {
error
}
func (e writeError) Error() string {
return e.error.Error()
}
func (e writeError) Inner() error {
return e.error
}
// IsWriteError returns true if the error in Copy() comes from writing.
func IsWriteError(err error) bool {
_, ok := err.(writeError)
return ok
}
func copyInternal(reader Reader, writer Writer, handler *copyHandler) error {
for {
buffer, err := reader.ReadMultiBuffer()
if !buffer.IsEmpty() {
for _, handler := range handler.onData {
handler(buffer)
}
if werr := writer.WriteMultiBuffer(buffer); werr != nil {
return writeError{werr}
}
}
if err != nil {
return readError{err}
}
}
}
// Copy dumps all payload from reader to writer or stops when an error occurs. It returns nil when EOF.
func Copy(reader Reader, writer Writer, options ...CopyOption) error {
var handler copyHandler
for _, option := range options {
option(&handler)
}
err := copyInternal(reader, writer, &handler)
if err != nil && errors.Cause(err) != io.EOF {
return err
}
return nil
}
var ErrNotTimeoutReader = newError("not a TimeoutReader")
func CopyOnceTimeout(reader Reader, writer Writer, timeout time.Duration) error {
timeoutReader, ok := reader.(TimeoutReader)
if !ok {
return ErrNotTimeoutReader
}
mb, err := timeoutReader.ReadMultiBufferTimeout(timeout)
if err != nil {
return err
}
return writer.WriteMultiBuffer(mb)
}

71
common/buf/copy_test.go Normal file
View file

@ -0,0 +1,71 @@
package buf_test
import (
"crypto/rand"
"io"
"testing"
"github.com/golang/mock/gomock"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/errors"
"github.com/xtls/xray-core/v1/testing/mocks"
)
func TestReadError(t *testing.T) {
mockCtl := gomock.NewController(t)
defer mockCtl.Finish()
mockReader := mocks.NewReader(mockCtl)
mockReader.EXPECT().Read(gomock.Any()).Return(0, errors.New("error"))
err := buf.Copy(buf.NewReader(mockReader), buf.Discard)
if err == nil {
t.Fatal("expected error, but nil")
}
if !buf.IsReadError(err) {
t.Error("expected to be ReadError, but not")
}
if err.Error() != "error" {
t.Fatal("unexpected error message: ", err.Error())
}
}
func TestWriteError(t *testing.T) {
mockCtl := gomock.NewController(t)
defer mockCtl.Finish()
mockWriter := mocks.NewWriter(mockCtl)
mockWriter.EXPECT().Write(gomock.Any()).Return(0, errors.New("error"))
err := buf.Copy(buf.NewReader(rand.Reader), buf.NewWriter(mockWriter))
if err == nil {
t.Fatal("expected error, but nil")
}
if !buf.IsWriteError(err) {
t.Error("expected to be WriteError, but not")
}
if err.Error() != "error" {
t.Fatal("unexpected error message: ", err.Error())
}
}
type TestReader struct{}
func (TestReader) Read(b []byte) (int, error) {
return len(b), nil
}
func BenchmarkCopy(b *testing.B) {
reader := buf.NewReader(io.LimitReader(TestReader{}, 10240))
writer := buf.Discard
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = buf.Copy(reader, writer)
}
}

View file

@ -0,0 +1,39 @@
???8+?$??I????+??????+?+++?IO7$ZD88ZDMD8OZ$7II7+++++++++++++
+??7++???I????????+?+++?+I?IZI$OND7ODDDDDD7Z$IZI++++++++++++
???I????????????~,...~?++I?777$DD8O8DDD88O$O7$7I++++++++++?+
???????????????.,::~...,+?I77ZZD8ZDNDDDDD8ZZ7$$7+++++++?+?+?
??????????????.,,:~~~==,I?7$$ZOD8ODNDD8DDZ$87777++?+++?????+
?????????I?=...:~~~~=~=+I?$$ZODD88ND8N8DDOZOZ77?????++??????
???II?????.,,,:==~~===I?IIZ$O$88ODD8ODNDDDOO$7$??I?++?++++??
???I????+..,,~=+???+?????7OOZZ8O$$778DDDDDO87I$I++++++++????
I??????..,,:~=??????+=,~?ZZZ$$I??II$DDDDD8Z8I~,+=?II$777IIII
II???,.,,::~??I?I?....,,~==I?+===+?$ODN8DD$O=,......+?????II
I?I?..,,:~~????,...,,::::~~~~~~~~=+$88ODD88=~,,,.......IIIII
II,..,,:~~I?:..,,,::::~~~~~~~~~~~~~+IOZ87?~~~::::,,,,...=?II
I,...,:::....,:::::::~~~~~~~~~~~~~~~=++=~~~~~~~~~~~:::,,,?II
,,,,~....,,,::::::::::::::~~~~~~~~~~~~~~~~~~~~~~~~~~~::,,,??
:~:...,,,:::::::::::::::::::~~~~~~~~~~~~~~~~~~~~~~~~~~::,,II
:::::::::::::::::::~+++::::::~~~~~~~~~~~~~~~~~~~~~~~~::::,,7
::::::::::::::~IIII?????:::::::~~~~~~~~~~~~~~~~~~~~~::::::,I
:,,,,,,,:+ZIIIIIIIIIIIII:::~::~~~~~~~~~~~~~~~~~~~~=~::::::::
7I777IIZI7ZIIIIIIIIIIII7?:~~~~~~~~~~~~~~~~~~~~~~~~~=~:::::::
$$$77$7Z77$7I77IIII7III$$:~~~~~~~~~~~~~~~~~~~~~~~=II~:::::::
$$$8$Z7$7$Z777777777777Z7~:~~~~~~~~~~~~~~~~~~~~~~$777::::::,
ZOZOZOZZ$7$$ZZ$8DDDZ777$$=~~~~~~~~~~~~~~~~~~~~~~~$$$7~:::::,
OOZOOOZZOOZO$ZZZ$O$$$$7ZZ$~~~~~~~~~~~~~~~~~~~~~~~ZZ$ZZ:::::,
O88OOOOO8ODOZZZZZOOZ8OOOOO:~~~~~~~~~~~~~~~~~~~~~ZOZZZZ~:::::
8888O8OODZ8ZOZOZZOOZOOOOOZ:::~~~~~~~~~~~~~~~~~~~,Z$ZOOO:::::
Z88O88D8Z88ZZOOZZOZ$$Z$$OZ:::~~~~~~~~~~~~~~~~~~~,,ZOOOOO::::
888D88OODD8DNDNDNNDDDD88OI:::::~~~~~~~~~~~~~~~~~.,:8ZO8O::::
D8D88DO88ZOOZOO8DDDNOZ$$O8~::::~~~~~~~~~~~~~===~..,88O8OO:::
8OD8O8OODO$D8DO88DO8O8888O~~::~~::~~~~~~~~~~~===...:8OOOZ~::
:..................,~,..~,~~:~:~~~~~~~~~~~~~~===...,+.....~~
.........................~~~:~~~~~~~~~~~~~~~~~==:..,......:~
.Made with love.........,~~~~~~~~:~~~~::~~~~~~~==..,,......:
........................~~~~~~~~~~~~~~:~~~~~~~~===,.,......~
...................,,..~~~~~~~~~~~~~~~~~~~~~~~~~==~,,.......
..................,,::~~~~~~~~~~~~~~~~~~~~~~~~~====~.,....,.
....................:~~~~~~~~~~~~~~~~~~~~~~~~~~~~==~:......,
......................,~================,.==~~~=~===~,......
.Thank you for your support.....................:~=,,,,,,,..

View file

@ -0,0 +1,9 @@
package buf
import "github.com/xtls/xray-core/v1/common/errors"
type errPathObjHolder struct{}
func newError(values ...interface{}) *errors.Error {
return errors.New(values...).WithPathObj(errPathObjHolder{})
}

116
common/buf/io.go Normal file
View file

@ -0,0 +1,116 @@
package buf
import (
"io"
"net"
"os"
"syscall"
"time"
)
// Reader extends io.Reader with MultiBuffer.
type Reader interface {
// ReadMultiBuffer reads content from underlying reader, and put it into a MultiBuffer.
ReadMultiBuffer() (MultiBuffer, error)
}
// ErrReadTimeout is an error that happens with IO timeout.
var ErrReadTimeout = newError("IO timeout")
// TimeoutReader is a reader that returns error if Read() operation takes longer than the given timeout.
type TimeoutReader interface {
ReadMultiBufferTimeout(time.Duration) (MultiBuffer, error)
}
// Writer extends io.Writer with MultiBuffer.
type Writer interface {
// WriteMultiBuffer writes a MultiBuffer into underlying writer.
WriteMultiBuffer(MultiBuffer) error
}
// WriteAllBytes ensures all bytes are written into the given writer.
func WriteAllBytes(writer io.Writer, payload []byte) error {
for len(payload) > 0 {
n, err := writer.Write(payload)
if err != nil {
return err
}
payload = payload[n:]
}
return nil
}
func isPacketReader(reader io.Reader) bool {
_, ok := reader.(net.PacketConn)
return ok
}
// NewReader creates a new Reader.
// The Reader instance doesn't take the ownership of reader.
func NewReader(reader io.Reader) Reader {
if mr, ok := reader.(Reader); ok {
return mr
}
if isPacketReader(reader) {
return &PacketReader{
Reader: reader,
}
}
_, isFile := reader.(*os.File)
if !isFile && useReadv {
if sc, ok := reader.(syscall.Conn); ok {
rawConn, err := sc.SyscallConn()
if err != nil {
newError("failed to get sysconn").Base(err).WriteToLog()
} else {
return NewReadVReader(reader, rawConn)
}
}
}
return &SingleReader{
Reader: reader,
}
}
// NewPacketReader creates a new PacketReader based on the given reader.
func NewPacketReader(reader io.Reader) Reader {
if mr, ok := reader.(Reader); ok {
return mr
}
return &PacketReader{
Reader: reader,
}
}
func isPacketWriter(writer io.Writer) bool {
if _, ok := writer.(net.PacketConn); ok {
return true
}
// If the writer doesn't implement syscall.Conn, it is probably not a TCP connection.
if _, ok := writer.(syscall.Conn); !ok {
return true
}
return false
}
// NewWriter creates a new Writer.
func NewWriter(writer io.Writer) Writer {
if mw, ok := writer.(Writer); ok {
return mw
}
if isPacketWriter(writer) {
return &SequentialWriter{
Writer: writer,
}
}
return &BufferToBytesWriter{
Writer: writer,
}
}

50
common/buf/io_test.go Normal file
View file

@ -0,0 +1,50 @@
package buf_test
import (
"crypto/tls"
"io"
"testing"
. "github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/net"
"github.com/xtls/xray-core/v1/testing/servers/tcp"
)
func TestWriterCreation(t *testing.T) {
tcpServer := tcp.Server{}
dest, err := tcpServer.Start()
if err != nil {
t.Fatal("failed to start tcp server: ", err)
}
defer tcpServer.Close()
conn, err := net.Dial("tcp", dest.NetAddr())
if err != nil {
t.Fatal("failed to dial a TCP connection: ", err)
}
defer conn.Close()
{
writer := NewWriter(conn)
if _, ok := writer.(*BufferToBytesWriter); !ok {
t.Fatal("writer is not a BufferToBytesWriter")
}
writer2 := NewWriter(writer.(io.Writer))
if writer2 != writer {
t.Fatal("writer is not reused")
}
}
tlsConn := tls.Client(conn, &tls.Config{
InsecureSkipVerify: true,
})
defer tlsConn.Close()
{
writer := NewWriter(tlsConn)
if _, ok := writer.(*SequentialWriter); !ok {
t.Fatal("writer is not a SequentialWriter")
}
}
}

297
common/buf/multi_buffer.go Normal file
View file

@ -0,0 +1,297 @@
package buf
import (
"io"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/errors"
"github.com/xtls/xray-core/v1/common/serial"
)
// ReadAllToBytes reads all content from the reader into a byte array, until EOF.
func ReadAllToBytes(reader io.Reader) ([]byte, error) {
mb, err := ReadFrom(reader)
if err != nil {
return nil, err
}
if mb.Len() == 0 {
return nil, nil
}
b := make([]byte, mb.Len())
mb, _ = SplitBytes(mb, b)
ReleaseMulti(mb)
return b, nil
}
// MultiBuffer is a list of Buffers. The order of Buffer matters.
type MultiBuffer []*Buffer
// MergeMulti merges content from src to dest, and returns the new address of dest and src
func MergeMulti(dest MultiBuffer, src MultiBuffer) (MultiBuffer, MultiBuffer) {
dest = append(dest, src...)
for idx := range src {
src[idx] = nil
}
return dest, src[:0]
}
// MergeBytes merges the given bytes into MultiBuffer and return the new address of the merged MultiBuffer.
func MergeBytes(dest MultiBuffer, src []byte) MultiBuffer {
n := len(dest)
if n > 0 && !(dest)[n-1].IsFull() {
nBytes, _ := (dest)[n-1].Write(src)
src = src[nBytes:]
}
for len(src) > 0 {
b := New()
nBytes, _ := b.Write(src)
src = src[nBytes:]
dest = append(dest, b)
}
return dest
}
// ReleaseMulti release all content of the MultiBuffer, and returns an empty MultiBuffer.
func ReleaseMulti(mb MultiBuffer) MultiBuffer {
for i := range mb {
mb[i].Release()
mb[i] = nil
}
return mb[:0]
}
// Copy copied the beginning part of the MultiBuffer into the given byte array.
func (mb MultiBuffer) Copy(b []byte) int {
total := 0
for _, bb := range mb {
nBytes := copy(b[total:], bb.Bytes())
total += nBytes
if int32(nBytes) < bb.Len() {
break
}
}
return total
}
// ReadFrom reads all content from reader until EOF.
func ReadFrom(reader io.Reader) (MultiBuffer, error) {
mb := make(MultiBuffer, 0, 16)
for {
b := New()
_, err := b.ReadFullFrom(reader, Size)
if b.IsEmpty() {
b.Release()
} else {
mb = append(mb, b)
}
if err != nil {
if errors.Cause(err) == io.EOF || errors.Cause(err) == io.ErrUnexpectedEOF {
return mb, nil
}
return mb, err
}
}
}
// SplitBytes splits the given amount of bytes from the beginning of the MultiBuffer.
// It returns the new address of MultiBuffer leftover, and number of bytes written into the input byte slice.
func SplitBytes(mb MultiBuffer, b []byte) (MultiBuffer, int) {
totalBytes := 0
endIndex := -1
for i := range mb {
pBuffer := mb[i]
nBytes, _ := pBuffer.Read(b)
totalBytes += nBytes
b = b[nBytes:]
if !pBuffer.IsEmpty() {
endIndex = i
break
}
pBuffer.Release()
mb[i] = nil
}
if endIndex == -1 {
mb = mb[:0]
} else {
mb = mb[endIndex:]
}
return mb, totalBytes
}
// SplitFirstBytes splits the first buffer from MultiBuffer, and then copy its content into the given slice.
func SplitFirstBytes(mb MultiBuffer, p []byte) (MultiBuffer, int) {
mb, b := SplitFirst(mb)
if b == nil {
return mb, 0
}
n := copy(p, b.Bytes())
b.Release()
return mb, n
}
// Compact returns another MultiBuffer by merging all content of the given one together.
func Compact(mb MultiBuffer) MultiBuffer {
if len(mb) == 0 {
return mb
}
mb2 := make(MultiBuffer, 0, len(mb))
last := mb[0]
for i := 1; i < len(mb); i++ {
curr := mb[i]
if last.Len()+curr.Len() > Size {
mb2 = append(mb2, last)
last = curr
} else {
common.Must2(last.ReadFrom(curr))
curr.Release()
}
}
mb2 = append(mb2, last)
return mb2
}
// SplitFirst splits the first Buffer from the beginning of the MultiBuffer.
func SplitFirst(mb MultiBuffer) (MultiBuffer, *Buffer) {
if len(mb) == 0 {
return mb, nil
}
b := mb[0]
mb[0] = nil
mb = mb[1:]
return mb, b
}
// SplitSize splits the beginning of the MultiBuffer into another one, for at most size bytes.
func SplitSize(mb MultiBuffer, size int32) (MultiBuffer, MultiBuffer) {
if len(mb) == 0 {
return mb, nil
}
if mb[0].Len() > size {
b := New()
copy(b.Extend(size), mb[0].BytesTo(size))
mb[0].Advance(size)
return mb, MultiBuffer{b}
}
totalBytes := int32(0)
var r MultiBuffer
endIndex := -1
for i := range mb {
if totalBytes+mb[i].Len() > size {
endIndex = i
break
}
totalBytes += mb[i].Len()
r = append(r, mb[i])
mb[i] = nil
}
if endIndex == -1 {
// To reuse mb array
mb = mb[:0]
} else {
mb = mb[endIndex:]
}
return mb, r
}
// WriteMultiBuffer writes all buffers from the MultiBuffer to the Writer one by one, and return error if any, with leftover MultiBuffer.
func WriteMultiBuffer(writer io.Writer, mb MultiBuffer) (MultiBuffer, error) {
for {
mb2, b := SplitFirst(mb)
mb = mb2
if b == nil {
break
}
_, err := writer.Write(b.Bytes())
b.Release()
if err != nil {
return mb, err
}
}
return nil, nil
}
// Len returns the total number of bytes in the MultiBuffer.
func (mb MultiBuffer) Len() int32 {
if mb == nil {
return 0
}
size := int32(0)
for _, b := range mb {
size += b.Len()
}
return size
}
// IsEmpty return true if the MultiBuffer has no content.
func (mb MultiBuffer) IsEmpty() bool {
for _, b := range mb {
if !b.IsEmpty() {
return false
}
}
return true
}
// String returns the content of the MultiBuffer in string.
func (mb MultiBuffer) String() string {
v := make([]interface{}, len(mb))
for i, b := range mb {
v[i] = b
}
return serial.Concat(v...)
}
// MultiBufferContainer is a ReadWriteCloser wrapper over MultiBuffer.
type MultiBufferContainer struct {
MultiBuffer
}
// Read implements io.Reader.
func (c *MultiBufferContainer) Read(b []byte) (int, error) {
if c.MultiBuffer.IsEmpty() {
return 0, io.EOF
}
mb, nBytes := SplitBytes(c.MultiBuffer, b)
c.MultiBuffer = mb
return nBytes, nil
}
// ReadMultiBuffer implements Reader.
func (c *MultiBufferContainer) ReadMultiBuffer() (MultiBuffer, error) {
mb := c.MultiBuffer
c.MultiBuffer = nil
return mb, nil
}
// Write implements io.Writer.
func (c *MultiBufferContainer) Write(b []byte) (int, error) {
c.MultiBuffer = MergeBytes(c.MultiBuffer, b)
return len(b), nil
}
// WriteMultiBuffer implement Writer.
func (c *MultiBufferContainer) WriteMultiBuffer(b MultiBuffer) error {
mb, _ := MergeMulti(c.MultiBuffer, b)
c.MultiBuffer = mb
return nil
}
// Close implement io.Closer.
func (c *MultiBufferContainer) Close() error {
c.MultiBuffer = ReleaseMulti(c.MultiBuffer)
return nil
}

View file

@ -0,0 +1,190 @@
package buf_test
import (
"bytes"
"crypto/rand"
"io"
"io/ioutil"
"os"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/xtls/xray-core/v1/common"
. "github.com/xtls/xray-core/v1/common/buf"
)
func TestMultiBufferRead(t *testing.T) {
b1 := New()
common.Must2(b1.WriteString("ab"))
b2 := New()
common.Must2(b2.WriteString("cd"))
mb := MultiBuffer{b1, b2}
bs := make([]byte, 32)
_, nBytes := SplitBytes(mb, bs)
if nBytes != 4 {
t.Error("expect 4 bytes split, but got ", nBytes)
}
if r := cmp.Diff(bs[:nBytes], []byte("abcd")); r != "" {
t.Error(r)
}
}
func TestMultiBufferAppend(t *testing.T) {
var mb MultiBuffer
b := New()
common.Must2(b.WriteString("ab"))
mb = append(mb, b)
if mb.Len() != 2 {
t.Error("expected length 2, but got ", mb.Len())
}
}
func TestMultiBufferSliceBySizeLarge(t *testing.T) {
lb := make([]byte, 8*1024)
common.Must2(io.ReadFull(rand.Reader, lb))
mb := MergeBytes(nil, lb)
mb, mb2 := SplitSize(mb, 1024)
if mb2.Len() != 1024 {
t.Error("expect length 1024, but got ", mb2.Len())
}
if mb.Len() != 7*1024 {
t.Error("expect length 7*1024, but got ", mb.Len())
}
mb, mb3 := SplitSize(mb, 7*1024)
if mb3.Len() != 7*1024 {
t.Error("expect length 7*1024, but got", mb.Len())
}
if !mb.IsEmpty() {
t.Error("expect empty buffer, but got ", mb.Len())
}
}
func TestMultiBufferSplitFirst(t *testing.T) {
b1 := New()
b1.WriteString("b1")
b2 := New()
b2.WriteString("b2")
b3 := New()
b3.WriteString("b3")
var mb MultiBuffer
mb = append(mb, b1, b2, b3)
mb, c1 := SplitFirst(mb)
if diff := cmp.Diff(b1.String(), c1.String()); diff != "" {
t.Error(diff)
}
mb, c2 := SplitFirst(mb)
if diff := cmp.Diff(b2.String(), c2.String()); diff != "" {
t.Error(diff)
}
mb, c3 := SplitFirst(mb)
if diff := cmp.Diff(b3.String(), c3.String()); diff != "" {
t.Error(diff)
}
if !mb.IsEmpty() {
t.Error("expect empty buffer, but got ", mb.String())
}
}
func TestMultiBufferReadAllToByte(t *testing.T) {
{
lb := make([]byte, 8*1024)
common.Must2(io.ReadFull(rand.Reader, lb))
rd := bytes.NewBuffer(lb)
b, err := ReadAllToBytes(rd)
common.Must(err)
if l := len(b); l != 8*1024 {
t.Error("unexpceted length from ReadAllToBytes", l)
}
}
{
const dat = "data/test_MultiBufferReadAllToByte.dat"
f, err := os.Open(dat)
common.Must(err)
buf2, err := ReadAllToBytes(f)
common.Must(err)
f.Close()
cnt, err := ioutil.ReadFile(dat)
common.Must(err)
if d := cmp.Diff(buf2, cnt); d != "" {
t.Error("fail to read from file: ", d)
}
}
}
func TestMultiBufferCopy(t *testing.T) {
lb := make([]byte, 8*1024)
common.Must2(io.ReadFull(rand.Reader, lb))
reader := bytes.NewBuffer(lb)
mb, err := ReadFrom(reader)
common.Must(err)
lbdst := make([]byte, 8*1024)
mb.Copy(lbdst)
if d := cmp.Diff(lb, lbdst); d != "" {
t.Error("unexpceted different from MultiBufferCopy ", d)
}
}
func TestSplitFirstBytes(t *testing.T) {
a := New()
common.Must2(a.WriteString("ab"))
b := New()
common.Must2(b.WriteString("bc"))
mb := MultiBuffer{a, b}
o := make([]byte, 2)
_, cnt := SplitFirstBytes(mb, o)
if cnt != 2 {
t.Error("unexpected cnt from SplitFirstBytes ", cnt)
}
if d := cmp.Diff(string(o), "ab"); d != "" {
t.Error("unexpected splited result from SplitFirstBytes ", d)
}
}
func TestCompact(t *testing.T) {
a := New()
common.Must2(a.WriteString("ab"))
b := New()
common.Must2(b.WriteString("bc"))
mb := MultiBuffer{a, b}
cmb := Compact(mb)
if w := cmb.String(); w != "abbc" {
t.Error("unexpected Compact result ", w)
}
}
func BenchmarkSplitBytes(b *testing.B) {
var mb MultiBuffer
raw := make([]byte, Size)
b.ResetTimer()
for i := 0; i < b.N; i++ {
buffer := StackNew()
buffer.Extend(Size)
mb = append(mb, &buffer)
mb, _ = SplitBytes(mb, raw)
}
}

174
common/buf/reader.go Normal file
View file

@ -0,0 +1,174 @@
package buf
import (
"io"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/errors"
)
func readOneUDP(r io.Reader) (*Buffer, error) {
b := New()
for i := 0; i < 64; i++ {
_, err := b.ReadFrom(r)
if !b.IsEmpty() {
return b, nil
}
if err != nil {
b.Release()
return nil, err
}
}
b.Release()
return nil, newError("Reader returns too many empty payloads.")
}
// ReadBuffer reads a Buffer from the given reader.
func ReadBuffer(r io.Reader) (*Buffer, error) {
b := New()
n, err := b.ReadFrom(r)
if n > 0 {
return b, err
}
b.Release()
return nil, err
}
// BufferedReader is a Reader that keeps its internal buffer.
type BufferedReader struct {
// Reader is the underlying reader to be read from
Reader Reader
// Buffer is the internal buffer to be read from first
Buffer MultiBuffer
// Spliter is a function to read bytes from MultiBuffer
Spliter func(MultiBuffer, []byte) (MultiBuffer, int)
}
// BufferedBytes returns the number of bytes that is cached in this reader.
func (r *BufferedReader) BufferedBytes() int32 {
return r.Buffer.Len()
}
// ReadByte implements io.ByteReader.
func (r *BufferedReader) ReadByte() (byte, error) {
var b [1]byte
_, err := r.Read(b[:])
return b[0], err
}
// Read implements io.Reader. It reads from internal buffer first (if available) and then reads from the underlying reader.
func (r *BufferedReader) Read(b []byte) (int, error) {
spliter := r.Spliter
if spliter == nil {
spliter = SplitBytes
}
if !r.Buffer.IsEmpty() {
buffer, nBytes := spliter(r.Buffer, b)
r.Buffer = buffer
if r.Buffer.IsEmpty() {
r.Buffer = nil
}
return nBytes, nil
}
mb, err := r.Reader.ReadMultiBuffer()
if err != nil {
return 0, err
}
mb, nBytes := spliter(mb, b)
if !mb.IsEmpty() {
r.Buffer = mb
}
return nBytes, nil
}
// ReadMultiBuffer implements Reader.
func (r *BufferedReader) ReadMultiBuffer() (MultiBuffer, error) {
if !r.Buffer.IsEmpty() {
mb := r.Buffer
r.Buffer = nil
return mb, nil
}
return r.Reader.ReadMultiBuffer()
}
// ReadAtMost returns a MultiBuffer with at most size.
func (r *BufferedReader) ReadAtMost(size int32) (MultiBuffer, error) {
if r.Buffer.IsEmpty() {
mb, err := r.Reader.ReadMultiBuffer()
if mb.IsEmpty() && err != nil {
return nil, err
}
r.Buffer = mb
}
rb, mb := SplitSize(r.Buffer, size)
r.Buffer = rb
if r.Buffer.IsEmpty() {
r.Buffer = nil
}
return mb, nil
}
func (r *BufferedReader) writeToInternal(writer io.Writer) (int64, error) {
mbWriter := NewWriter(writer)
var sc SizeCounter
if r.Buffer != nil {
sc.Size = int64(r.Buffer.Len())
if err := mbWriter.WriteMultiBuffer(r.Buffer); err != nil {
return 0, err
}
r.Buffer = nil
}
err := Copy(r.Reader, mbWriter, CountSize(&sc))
return sc.Size, err
}
// WriteTo implements io.WriterTo.
func (r *BufferedReader) WriteTo(writer io.Writer) (int64, error) {
nBytes, err := r.writeToInternal(writer)
if errors.Cause(err) == io.EOF {
return nBytes, nil
}
return nBytes, err
}
// Interrupt implements common.Interruptible.
func (r *BufferedReader) Interrupt() {
common.Interrupt(r.Reader)
}
// Close implements io.Closer.
func (r *BufferedReader) Close() error {
return common.Close(r.Reader)
}
// SingleReader is a Reader that read one Buffer every time.
type SingleReader struct {
io.Reader
}
// ReadMultiBuffer implements Reader.
func (r *SingleReader) ReadMultiBuffer() (MultiBuffer, error) {
b, err := ReadBuffer(r.Reader)
return MultiBuffer{b}, err
}
// PacketReader is a Reader that read one Buffer every time.
type PacketReader struct {
io.Reader
}
// ReadMultiBuffer implements Reader.
func (r *PacketReader) ReadMultiBuffer() (MultiBuffer, error) {
b, err := readOneUDP(r.Reader)
if err != nil {
return nil, err
}
return MultiBuffer{b}, nil
}

131
common/buf/reader_test.go Normal file
View file

@ -0,0 +1,131 @@
package buf_test
import (
"bytes"
"io"
"strings"
"testing"
"github.com/xtls/xray-core/v1/common"
. "github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/transport/pipe"
)
func TestBytesReaderWriteTo(t *testing.T) {
pReader, pWriter := pipe.New(pipe.WithSizeLimit(1024))
reader := &BufferedReader{Reader: pReader}
b1 := New()
b1.WriteString("abc")
b2 := New()
b2.WriteString("efg")
common.Must(pWriter.WriteMultiBuffer(MultiBuffer{b1, b2}))
pWriter.Close()
pReader2, pWriter2 := pipe.New(pipe.WithSizeLimit(1024))
writer := NewBufferedWriter(pWriter2)
writer.SetBuffered(false)
nBytes, err := io.Copy(writer, reader)
common.Must(err)
if nBytes != 6 {
t.Error("copy: ", nBytes)
}
mb, err := pReader2.ReadMultiBuffer()
common.Must(err)
if s := mb.String(); s != "abcefg" {
t.Error("content: ", s)
}
}
func TestBytesReaderMultiBuffer(t *testing.T) {
pReader, pWriter := pipe.New(pipe.WithSizeLimit(1024))
reader := &BufferedReader{Reader: pReader}
b1 := New()
b1.WriteString("abc")
b2 := New()
b2.WriteString("efg")
common.Must(pWriter.WriteMultiBuffer(MultiBuffer{b1, b2}))
pWriter.Close()
mbReader := NewReader(reader)
mb, err := mbReader.ReadMultiBuffer()
common.Must(err)
if s := mb.String(); s != "abcefg" {
t.Error("content: ", s)
}
}
func TestReadByte(t *testing.T) {
sr := strings.NewReader("abcd")
reader := &BufferedReader{
Reader: NewReader(sr),
}
b, err := reader.ReadByte()
common.Must(err)
if b != 'a' {
t.Error("unexpected byte: ", b, " want a")
}
if reader.BufferedBytes() != 3 { // 3 bytes left in buffer
t.Error("unexpected buffered Bytes: ", reader.BufferedBytes())
}
nBytes, err := reader.WriteTo(DiscardBytes)
common.Must(err)
if nBytes != 3 {
t.Error("unexpect bytes written: ", nBytes)
}
}
func TestReadBuffer(t *testing.T) {
{
sr := strings.NewReader("abcd")
buf, err := ReadBuffer(sr)
common.Must(err)
if s := buf.String(); s != "abcd" {
t.Error("unexpected str: ", s, " want abcd")
}
buf.Release()
}
}
func TestReadAtMost(t *testing.T) {
sr := strings.NewReader("abcd")
reader := &BufferedReader{
Reader: NewReader(sr),
}
mb, err := reader.ReadAtMost(3)
common.Must(err)
if s := mb.String(); s != "abc" {
t.Error("unexpected read result: ", s)
}
nBytes, err := reader.WriteTo(DiscardBytes)
common.Must(err)
if nBytes != 1 {
t.Error("unexpect bytes written: ", nBytes)
}
}
func TestPacketReader_ReadMultiBuffer(t *testing.T) {
const alpha = "abcefg"
buf := bytes.NewBufferString(alpha)
reader := &PacketReader{buf}
mb, err := reader.ReadMultiBuffer()
common.Must(err)
if s := mb.String(); s != alpha {
t.Error("content: ", s)
}
}
func TestReaderInterface(t *testing.T) {
_ = (io.Reader)(new(ReadVReader))
_ = (Reader)(new(ReadVReader))
_ = (Reader)(new(BufferedReader))
_ = (io.Reader)(new(BufferedReader))
_ = (io.ByteReader)(new(BufferedReader))
_ = (io.WriterTo)(new(BufferedReader))
}

47
common/buf/readv_posix.go Normal file
View file

@ -0,0 +1,47 @@
// +build !windows
// +build !wasm
// +build !illumos
package buf
import (
"syscall"
"unsafe"
)
type posixReader struct {
iovecs []syscall.Iovec
}
func (r *posixReader) Init(bs []*Buffer) {
iovecs := r.iovecs
if iovecs == nil {
iovecs = make([]syscall.Iovec, 0, len(bs))
}
for idx, b := range bs {
iovecs = append(iovecs, syscall.Iovec{
Base: &(b.v[0]),
})
iovecs[idx].SetLen(int(Size))
}
r.iovecs = iovecs
}
func (r *posixReader) Read(fd uintptr) int32 {
n, _, e := syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&r.iovecs[0])), uintptr(len(r.iovecs)))
if e != 0 {
return -1
}
return int32(n)
}
func (r *posixReader) Clear() {
for idx := range r.iovecs {
r.iovecs[idx].Base = nil
}
r.iovecs = r.iovecs[:0]
}
func newMultiReader() multiReader {
return &posixReader{}
}

150
common/buf/readv_reader.go Normal file
View file

@ -0,0 +1,150 @@
// +build !wasm
package buf
import (
"io"
"runtime"
"syscall"
"github.com/xtls/xray-core/v1/common/platform"
)
type allocStrategy struct {
current uint32
}
func (s *allocStrategy) Current() uint32 {
return s.current
}
func (s *allocStrategy) Adjust(n uint32) {
if n >= s.current {
s.current *= 2
} else {
s.current = n
}
if s.current > 8 {
s.current = 8
}
if s.current == 0 {
s.current = 1
}
}
func (s *allocStrategy) Alloc() []*Buffer {
bs := make([]*Buffer, s.current)
for i := range bs {
bs[i] = New()
}
return bs
}
type multiReader interface {
Init([]*Buffer)
Read(fd uintptr) int32
Clear()
}
// ReadVReader is a Reader that uses readv(2) syscall to read data.
type ReadVReader struct {
io.Reader
rawConn syscall.RawConn
mr multiReader
alloc allocStrategy
}
// NewReadVReader creates a new ReadVReader.
func NewReadVReader(reader io.Reader, rawConn syscall.RawConn) *ReadVReader {
return &ReadVReader{
Reader: reader,
rawConn: rawConn,
alloc: allocStrategy{
current: 1,
},
mr: newMultiReader(),
}
}
func (r *ReadVReader) readMulti() (MultiBuffer, error) {
bs := r.alloc.Alloc()
r.mr.Init(bs)
var nBytes int32
err := r.rawConn.Read(func(fd uintptr) bool {
n := r.mr.Read(fd)
if n < 0 {
return false
}
nBytes = n
return true
})
r.mr.Clear()
if err != nil {
ReleaseMulti(MultiBuffer(bs))
return nil, err
}
if nBytes == 0 {
ReleaseMulti(MultiBuffer(bs))
return nil, io.EOF
}
nBuf := 0
for nBuf < len(bs) {
if nBytes <= 0 {
break
}
end := nBytes
if end > Size {
end = Size
}
bs[nBuf].end = end
nBytes -= end
nBuf++
}
for i := nBuf; i < len(bs); i++ {
bs[i].Release()
bs[i] = nil
}
return MultiBuffer(bs[:nBuf]), nil
}
// ReadMultiBuffer implements Reader.
func (r *ReadVReader) ReadMultiBuffer() (MultiBuffer, error) {
if r.alloc.Current() == 1 {
b, err := ReadBuffer(r.Reader)
if b.IsFull() {
r.alloc.Adjust(1)
}
return MultiBuffer{b}, err
}
mb, err := r.readMulti()
if err != nil {
return nil, err
}
r.alloc.Adjust(uint32(len(mb)))
return mb, nil
}
var useReadv = true
func init() {
const defaultFlagValue = "NOT_DEFINED_AT_ALL"
value := platform.NewEnvFlag("xray.buf.readv").GetValue(func() string { return defaultFlagValue })
switch value {
case defaultFlagValue, "auto":
if (runtime.GOARCH == "386" || runtime.GOARCH == "amd64" || runtime.GOARCH == "s390x") && (runtime.GOOS == "linux" || runtime.GOOS == "darwin" || runtime.GOOS == "windows") {
useReadv = true
}
case "enable":
useReadv = true
}
}

View file

@ -0,0 +1,14 @@
// +build wasm
package buf
import (
"io"
"syscall"
)
const useReadv = false
func NewReadVReader(reader io.Reader, rawConn syscall.RawConn) Reader {
panic("not implemented")
}

72
common/buf/readv_test.go Normal file
View file

@ -0,0 +1,72 @@
// +build !wasm
package buf_test
import (
"crypto/rand"
"net"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/xtls/xray-core/v1/common"
. "github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/testing/servers/tcp"
"golang.org/x/sync/errgroup"
)
func TestReadvReader(t *testing.T) {
tcpServer := &tcp.Server{
MsgProcessor: func(b []byte) []byte {
return b
},
}
dest, err := tcpServer.Start()
common.Must(err)
defer tcpServer.Close()
conn, err := net.Dial("tcp", dest.NetAddr())
common.Must(err)
defer conn.Close()
const size = 8192
data := make([]byte, 8192)
common.Must2(rand.Read(data))
var errg errgroup.Group
errg.Go(func() error {
writer := NewWriter(conn)
mb := MergeBytes(nil, data)
return writer.WriteMultiBuffer(mb)
})
defer func() {
if err := errg.Wait(); err != nil {
t.Error(err)
}
}()
rawConn, err := conn.(*net.TCPConn).SyscallConn()
common.Must(err)
reader := NewReadVReader(conn, rawConn)
var rmb MultiBuffer
for {
mb, err := reader.ReadMultiBuffer()
if err != nil {
t.Fatal("unexpected error: ", err)
}
rmb, _ = MergeMulti(rmb, mb)
if rmb.Len() == size {
break
}
}
rdata := make([]byte, size)
SplitBytes(rmb, rdata)
if r := cmp.Diff(data, rdata); r != "" {
t.Fatal(r)
}
}

36
common/buf/readv_unix.go Normal file
View file

@ -0,0 +1,36 @@
// +build illumos
package buf
import "golang.org/x/sys/unix"
type unixReader struct {
iovs [][]byte
}
func (r *unixReader) Init(bs []*Buffer) {
iovs := r.iovs
if iovs == nil {
iovs = make([][]byte, 0, len(bs))
}
for _, b := range bs {
iovs = append(iovs, b.v)
}
r.iovs = iovs
}
func (r *unixReader) Read(fd uintptr) int32 {
n, e := unix.Readv(int(fd), r.iovs)
if e != nil {
return -1
}
return int32(n)
}
func (r *unixReader) Clear() {
r.iovs = r.iovs[:0]
}
func newMultiReader() multiReader {
return &unixReader{}
}

View file

@ -0,0 +1,39 @@
package buf
import (
"syscall"
)
type windowsReader struct {
bufs []syscall.WSABuf
}
func (r *windowsReader) Init(bs []*Buffer) {
if r.bufs == nil {
r.bufs = make([]syscall.WSABuf, 0, len(bs))
}
for _, b := range bs {
r.bufs = append(r.bufs, syscall.WSABuf{Len: uint32(Size), Buf: &b.v[0]})
}
}
func (r *windowsReader) Clear() {
for idx := range r.bufs {
r.bufs[idx].Buf = nil
}
r.bufs = r.bufs[:0]
}
func (r *windowsReader) Read(fd uintptr) int32 {
var nBytes uint32
var flags uint32
err := syscall.WSARecv(syscall.Handle(fd), &r.bufs[0], uint32(len(r.bufs)), &nBytes, &flags, nil, nil)
if err != nil {
return -1
}
return int32(nBytes)
}
func newMultiReader() multiReader {
return new(windowsReader)
}

262
common/buf/writer.go Normal file
View file

@ -0,0 +1,262 @@
package buf
import (
"io"
"net"
"sync"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/errors"
)
// BufferToBytesWriter is a Writer that writes alloc.Buffer into underlying writer.
type BufferToBytesWriter struct {
io.Writer
cache [][]byte
}
// WriteMultiBuffer implements Writer. This method takes ownership of the given buffer.
func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error {
defer ReleaseMulti(mb)
size := mb.Len()
if size == 0 {
return nil
}
if len(mb) == 1 {
return WriteAllBytes(w.Writer, mb[0].Bytes())
}
if cap(w.cache) < len(mb) {
w.cache = make([][]byte, 0, len(mb))
}
bs := w.cache
for _, b := range mb {
bs = append(bs, b.Bytes())
}
defer func() {
for idx := range bs {
bs[idx] = nil
}
}()
nb := net.Buffers(bs)
for size > 0 {
n, err := nb.WriteTo(w.Writer)
if err != nil {
return err
}
size -= int32(n)
}
return nil
}
// ReadFrom implements io.ReaderFrom.
func (w *BufferToBytesWriter) ReadFrom(reader io.Reader) (int64, error) {
var sc SizeCounter
err := Copy(NewReader(reader), w, CountSize(&sc))
return sc.Size, err
}
// BufferedWriter is a Writer with internal buffer.
type BufferedWriter struct {
sync.Mutex
writer Writer
buffer *Buffer
buffered bool
}
// NewBufferedWriter creates a new BufferedWriter.
func NewBufferedWriter(writer Writer) *BufferedWriter {
return &BufferedWriter{
writer: writer,
buffer: New(),
buffered: true,
}
}
// WriteByte implements io.ByteWriter.
func (w *BufferedWriter) WriteByte(c byte) error {
return common.Error2(w.Write([]byte{c}))
}
// Write implements io.Writer.
func (w *BufferedWriter) Write(b []byte) (int, error) {
if len(b) == 0 {
return 0, nil
}
w.Lock()
defer w.Unlock()
if !w.buffered {
if writer, ok := w.writer.(io.Writer); ok {
return writer.Write(b)
}
}
totalBytes := 0
for len(b) > 0 {
if w.buffer == nil {
w.buffer = New()
}
nBytes, err := w.buffer.Write(b)
totalBytes += nBytes
if err != nil {
return totalBytes, err
}
if !w.buffered || w.buffer.IsFull() {
if err := w.flushInternal(); err != nil {
return totalBytes, err
}
}
b = b[nBytes:]
}
return totalBytes, nil
}
// WriteMultiBuffer implements Writer. It takes ownership of the given MultiBuffer.
func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error {
if b.IsEmpty() {
return nil
}
w.Lock()
defer w.Unlock()
if !w.buffered {
return w.writer.WriteMultiBuffer(b)
}
reader := MultiBufferContainer{
MultiBuffer: b,
}
defer reader.Close()
for !reader.MultiBuffer.IsEmpty() {
if w.buffer == nil {
w.buffer = New()
}
common.Must2(w.buffer.ReadFrom(&reader))
if w.buffer.IsFull() {
if err := w.flushInternal(); err != nil {
return err
}
}
}
return nil
}
// Flush flushes buffered content into underlying writer.
func (w *BufferedWriter) Flush() error {
w.Lock()
defer w.Unlock()
return w.flushInternal()
}
func (w *BufferedWriter) flushInternal() error {
if w.buffer.IsEmpty() {
return nil
}
b := w.buffer
w.buffer = nil
if writer, ok := w.writer.(io.Writer); ok {
err := WriteAllBytes(writer, b.Bytes())
b.Release()
return err
}
return w.writer.WriteMultiBuffer(MultiBuffer{b})
}
// SetBuffered sets whether the internal buffer is used. If set to false, Flush() will be called to clear the buffer.
func (w *BufferedWriter) SetBuffered(f bool) error {
w.Lock()
defer w.Unlock()
w.buffered = f
if !f {
return w.flushInternal()
}
return nil
}
// ReadFrom implements io.ReaderFrom.
func (w *BufferedWriter) ReadFrom(reader io.Reader) (int64, error) {
if err := w.SetBuffered(false); err != nil {
return 0, err
}
var sc SizeCounter
err := Copy(NewReader(reader), w, CountSize(&sc))
return sc.Size, err
}
// Close implements io.Closable.
func (w *BufferedWriter) Close() error {
if err := w.Flush(); err != nil {
return err
}
return common.Close(w.writer)
}
// SequentialWriter is a Writer that writes MultiBuffer sequentially into the underlying io.Writer.
type SequentialWriter struct {
io.Writer
}
// WriteMultiBuffer implements Writer.
func (w *SequentialWriter) WriteMultiBuffer(mb MultiBuffer) error {
mb, err := WriteMultiBuffer(w.Writer, mb)
ReleaseMulti(mb)
return err
}
type noOpWriter byte
func (noOpWriter) WriteMultiBuffer(b MultiBuffer) error {
ReleaseMulti(b)
return nil
}
func (noOpWriter) Write(b []byte) (int, error) {
return len(b), nil
}
func (noOpWriter) ReadFrom(reader io.Reader) (int64, error) {
b := New()
defer b.Release()
totalBytes := int64(0)
for {
b.Clear()
_, err := b.ReadFrom(reader)
totalBytes += int64(b.Len())
if err != nil {
if errors.Cause(err) == io.EOF {
return totalBytes, nil
}
return totalBytes, err
}
}
}
var (
// Discard is a Writer that swallows all contents written in.
Discard Writer = noOpWriter(0)
// DiscardBytes is an io.Writer that swallows all contents written in.
DiscardBytes io.Writer = noOpWriter(0)
)

98
common/buf/writer_test.go Normal file
View file

@ -0,0 +1,98 @@
package buf_test
import (
"bufio"
"bytes"
"crypto/rand"
"io"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/xtls/xray-core/v1/common"
. "github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/transport/pipe"
)
func TestWriter(t *testing.T) {
lb := New()
common.Must2(lb.ReadFrom(rand.Reader))
expectedBytes := append([]byte(nil), lb.Bytes()...)
writeBuffer := bytes.NewBuffer(make([]byte, 0, 1024*1024))
writer := NewBufferedWriter(NewWriter(writeBuffer))
writer.SetBuffered(false)
common.Must(writer.WriteMultiBuffer(MultiBuffer{lb}))
common.Must(writer.Flush())
if r := cmp.Diff(expectedBytes, writeBuffer.Bytes()); r != "" {
t.Error(r)
}
}
func TestBytesWriterReadFrom(t *testing.T) {
const size = 50000
pReader, pWriter := pipe.New(pipe.WithSizeLimit(size))
reader := bufio.NewReader(io.LimitReader(rand.Reader, size))
writer := NewBufferedWriter(pWriter)
writer.SetBuffered(false)
nBytes, err := reader.WriteTo(writer)
if nBytes != size {
t.Fatal("unexpected size of bytes written: ", nBytes)
}
if err != nil {
t.Fatal("expect success, but actually error: ", err.Error())
}
mb, err := pReader.ReadMultiBuffer()
common.Must(err)
if mb.Len() != size {
t.Fatal("unexpected size read: ", mb.Len())
}
}
func TestDiscardBytes(t *testing.T) {
b := New()
common.Must2(b.ReadFullFrom(rand.Reader, Size))
nBytes, err := io.Copy(DiscardBytes, b)
common.Must(err)
if nBytes != Size {
t.Error("copy size: ", nBytes)
}
}
func TestDiscardBytesMultiBuffer(t *testing.T) {
const size = 10240*1024 + 1
buffer := bytes.NewBuffer(make([]byte, 0, size))
common.Must2(buffer.ReadFrom(io.LimitReader(rand.Reader, size)))
r := NewReader(buffer)
nBytes, err := io.Copy(DiscardBytes, &BufferedReader{Reader: r})
common.Must(err)
if nBytes != size {
t.Error("copy size: ", nBytes)
}
}
func TestWriterInterface(t *testing.T) {
{
var writer interface{} = (*BufferToBytesWriter)(nil)
switch writer.(type) {
case Writer, io.Writer, io.ReaderFrom:
default:
t.Error("BufferToBytesWriter is not Writer, io.Writer or io.ReaderFrom")
}
}
{
var writer interface{} = (*BufferedWriter)(nil)
switch writer.(type) {
case Writer, io.Writer, io.ReaderFrom, io.ByteWriter:
default:
t.Error("BufferedWriter is not Writer, io.Writer, io.ReaderFrom or io.ByteWriter")
}
}
}

72
common/bytespool/pool.go Normal file
View file

@ -0,0 +1,72 @@
package bytespool
import "sync"
func createAllocFunc(size int32) func() interface{} {
return func() interface{} {
return make([]byte, size)
}
}
// The following parameters controls the size of buffer pools.
// There are numPools pools. Starting from 2k size, the size of each pool is sizeMulti of the previous one.
// Package buf is guaranteed to not use buffers larger than the largest pool.
// Other packets may use larger buffers.
const (
numPools = 4
sizeMulti = 4
)
var (
pool [numPools]sync.Pool
poolSize [numPools]int32
)
func init() {
size := int32(2048)
for i := 0; i < numPools; i++ {
pool[i] = sync.Pool{
New: createAllocFunc(size),
}
poolSize[i] = size
size *= sizeMulti
}
}
// GetPool returns a sync.Pool that generates bytes array with at least the given size.
// It may return nil if no such pool exists.
//
// xray:api:stable
func GetPool(size int32) *sync.Pool {
for idx, ps := range poolSize {
if size <= ps {
return &pool[idx]
}
}
return nil
}
// Alloc returns a byte slice with at least the given size. Minimum size of returned slice is 2048.
//
// xray:api:stable
func Alloc(size int32) []byte {
pool := GetPool(size)
if pool != nil {
return pool.Get().([]byte)
}
return make([]byte, size)
}
// Free puts a byte slice into the internal pool.
//
// xray:api:stable
func Free(b []byte) {
size := int32(cap(b))
b = b[0:cap(b)]
for i := numPools - 1; i >= 0; i-- {
if size >= poolSize[i] {
pool[i].Put(b)
return
}
}
}

16
common/cmdarg/cmdarg.go Normal file
View file

@ -0,0 +1,16 @@
package cmdarg
import "strings"
// Arg is used by flag to accept multiple argument.
type Arg []string
func (c *Arg) String() string {
return strings.Join([]string(*c), " ")
}
// Set is the method flag package calls
func (c *Arg) Set(value string) error {
*c = append(*c, value)
return nil
}

158
common/common.go Normal file
View file

@ -0,0 +1,158 @@
// Package common contains common utilities that are shared among other packages.
// See each sub-package for detail.
package common
import (
"fmt"
"go/build"
"io/ioutil"
"os"
"path/filepath"
"strings"
"github.com/xtls/xray-core/v1/common/errors"
)
//go:generate go run github.com/xtls/xray-core/v1/common/errors/errorgen
var (
// ErrNoClue is for the situation that existing information is not enough to make a decision. For example, Router may return this error when there is no suitable route.
ErrNoClue = errors.New("not enough information for making a decision")
)
// Must panics if err is not nil.
func Must(err error) {
if err != nil {
panic(err)
}
}
// Must2 panics if the second parameter is not nil, otherwise returns the first parameter.
func Must2(v interface{}, err error) interface{} {
Must(err)
return v
}
// Error2 returns the err from the 2nd parameter.
func Error2(v interface{}, err error) error {
return err
}
// envFile returns the name of the Go environment configuration file.
// Copy from https://github.com/golang/go/blob/c4f2a9788a7be04daf931ac54382fbe2cb754938/src/cmd/go/internal/cfg/cfg.go#L150-L166
func envFile() (string, error) {
if file := os.Getenv("GOENV"); file != "" {
if file == "off" {
return "", fmt.Errorf("GOENV=off")
}
return file, nil
}
dir, err := os.UserConfigDir()
if err != nil {
return "", err
}
if dir == "" {
return "", fmt.Errorf("missing user-config dir")
}
return filepath.Join(dir, "go", "env"), nil
}
// GetRuntimeEnv returns the value of runtime environment variable,
// that is set by running following command: `go env -w key=value`.
func GetRuntimeEnv(key string) (string, error) {
file, err := envFile()
if err != nil {
return "", err
}
if file == "" {
return "", fmt.Errorf("missing runtime env file")
}
var data []byte
var runtimeEnv string
data, readErr := ioutil.ReadFile(file)
if readErr != nil {
return "", readErr
}
envStrings := strings.Split(string(data), "\n")
for _, envItem := range envStrings {
envItem = strings.TrimSuffix(envItem, "\r")
envKeyValue := strings.Split(envItem, "=")
if strings.EqualFold(strings.TrimSpace(envKeyValue[0]), key) {
runtimeEnv = strings.TrimSpace(envKeyValue[1])
}
}
return runtimeEnv, nil
}
// GetGOBIN returns GOBIN environment variable as a string. It will NOT be empty.
func GetGOBIN() string {
// The one set by user explicitly by `export GOBIN=/path` or `env GOBIN=/path command`
GOBIN := os.Getenv("GOBIN")
if GOBIN == "" {
var err error
// The one set by user by running `go env -w GOBIN=/path`
GOBIN, err = GetRuntimeEnv("GOBIN")
if err != nil {
// The default one that Golang uses
return filepath.Join(build.Default.GOPATH, "bin")
}
if GOBIN == "" {
return filepath.Join(build.Default.GOPATH, "bin")
}
return GOBIN
}
return GOBIN
}
// GetGOPATH returns GOPATH environment variable as a string. It will NOT be empty.
func GetGOPATH() string {
// The one set by user explicitly by `export GOPATH=/path` or `env GOPATH=/path command`
GOPATH := os.Getenv("GOPATH")
if GOPATH == "" {
var err error
// The one set by user by running `go env -w GOPATH=/path`
GOPATH, err = GetRuntimeEnv("GOPATH")
if err != nil {
// The default one that Golang uses
return build.Default.GOPATH
}
if GOPATH == "" {
return build.Default.GOPATH
}
return GOPATH
}
return GOPATH
}
// GetModuleName returns the value of module in `go.mod` file.
func GetModuleName(pathToProjectRoot string) (string, error) {
var moduleName string
loopPath := pathToProjectRoot
for {
if idx := strings.LastIndex(loopPath, string(filepath.Separator)); idx >= 0 {
gomodPath := filepath.Join(loopPath, "go.mod")
gomodBytes, err := ioutil.ReadFile(gomodPath)
if err != nil {
loopPath = loopPath[:idx]
continue
}
gomodContent := string(gomodBytes)
moduleIdx := strings.Index(gomodContent, "module ")
newLineIdx := strings.Index(gomodContent, "\n")
if moduleIdx >= 0 {
if newLineIdx >= 0 {
moduleName = strings.TrimSpace(gomodContent[moduleIdx+6 : newLineIdx])
moduleName = strings.TrimSuffix(moduleName, "\r")
} else {
moduleName = strings.TrimSpace(gomodContent[moduleIdx+6:])
}
return moduleName, nil
}
return "", fmt.Errorf("can not get module path in `%s`", gomodPath)
}
break
}
return moduleName, fmt.Errorf("no `go.mod` file in every parent directory of `%s`", pathToProjectRoot)
}

44
common/common_test.go Normal file
View file

@ -0,0 +1,44 @@
package common_test
import (
"errors"
"testing"
. "github.com/xtls/xray-core/v1/common"
)
func TestMust(t *testing.T) {
hasPanic := func(f func()) (ret bool) {
defer func() {
if r := recover(); r != nil {
ret = true
}
}()
f()
return false
}
testCases := []struct {
Input func()
Panic bool
}{
{
Panic: true,
Input: func() { Must(func() error { return errors.New("test error") }()) },
},
{
Panic: true,
Input: func() { Must2(func() (int, error) { return 0, errors.New("test error") }()) },
},
{
Panic: false,
Input: func() { Must(func() error { return nil }()) },
},
}
for idx, test := range testCases {
if hasPanic(test.Input) != test.Panic {
t.Error("test case #", idx, " expect panic ", test.Panic, " but actually not")
}
}
}

40
common/crypto/aes.go Normal file
View file

@ -0,0 +1,40 @@
package crypto
import (
"crypto/aes"
"crypto/cipher"
"github.com/xtls/xray-core/v1/common"
)
// NewAesDecryptionStream creates a new AES encryption stream based on given key and IV.
// Caller must ensure the length of key and IV is either 16, 24 or 32 bytes.
func NewAesDecryptionStream(key []byte, iv []byte) cipher.Stream {
return NewAesStreamMethod(key, iv, cipher.NewCFBDecrypter)
}
// NewAesEncryptionStream creates a new AES description stream based on given key and IV.
// Caller must ensure the length of key and IV is either 16, 24 or 32 bytes.
func NewAesEncryptionStream(key []byte, iv []byte) cipher.Stream {
return NewAesStreamMethod(key, iv, cipher.NewCFBEncrypter)
}
func NewAesStreamMethod(key []byte, iv []byte, f func(cipher.Block, []byte) cipher.Stream) cipher.Stream {
aesBlock, err := aes.NewCipher(key)
common.Must(err)
return f(aesBlock, iv)
}
// NewAesCTRStream creates a stream cipher based on AES-CTR.
func NewAesCTRStream(key []byte, iv []byte) cipher.Stream {
return NewAesStreamMethod(key, iv, cipher.NewCTR)
}
// NewAesGcm creates a AEAD cipher based on AES-GCM.
func NewAesGcm(key []byte) cipher.AEAD {
block, err := aes.NewCipher(key)
common.Must(err)
aead, err := cipher.NewGCM(block)
common.Must(err)
return aead
}

345
common/crypto/auth.go Normal file
View file

@ -0,0 +1,345 @@
package crypto
import (
"crypto/cipher"
"io"
"math/rand"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/bytespool"
"github.com/xtls/xray-core/v1/common/protocol"
)
type BytesGenerator func() []byte
func GenerateEmptyBytes() BytesGenerator {
var b [1]byte
return func() []byte {
return b[:0]
}
}
func GenerateStaticBytes(content []byte) BytesGenerator {
return func() []byte {
return content
}
}
func GenerateIncreasingNonce(nonce []byte) BytesGenerator {
c := append([]byte(nil), nonce...)
return func() []byte {
for i := range c {
c[i]++
if c[i] != 0 {
break
}
}
return c
}
}
func GenerateInitialAEADNonce() BytesGenerator {
return GenerateIncreasingNonce([]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF})
}
type Authenticator interface {
NonceSize() int
Overhead() int
Open(dst, cipherText []byte) ([]byte, error)
Seal(dst, plainText []byte) ([]byte, error)
}
type AEADAuthenticator struct {
cipher.AEAD
NonceGenerator BytesGenerator
AdditionalDataGenerator BytesGenerator
}
func (v *AEADAuthenticator) Open(dst, cipherText []byte) ([]byte, error) {
iv := v.NonceGenerator()
if len(iv) != v.AEAD.NonceSize() {
return nil, newError("invalid AEAD nonce size: ", len(iv))
}
var additionalData []byte
if v.AdditionalDataGenerator != nil {
additionalData = v.AdditionalDataGenerator()
}
return v.AEAD.Open(dst, iv, cipherText, additionalData)
}
func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) {
iv := v.NonceGenerator()
if len(iv) != v.AEAD.NonceSize() {
return nil, newError("invalid AEAD nonce size: ", len(iv))
}
var additionalData []byte
if v.AdditionalDataGenerator != nil {
additionalData = v.AdditionalDataGenerator()
}
return v.AEAD.Seal(dst, iv, plainText, additionalData), nil
}
type AuthenticationReader struct {
auth Authenticator
reader *buf.BufferedReader
sizeParser ChunkSizeDecoder
sizeBytes []byte
transferType protocol.TransferType
padding PaddingLengthGenerator
size uint16
paddingLen uint16
hasSize bool
done bool
}
func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader, transferType protocol.TransferType, paddingLen PaddingLengthGenerator) *AuthenticationReader {
r := &AuthenticationReader{
auth: auth,
sizeParser: sizeParser,
transferType: transferType,
padding: paddingLen,
sizeBytes: make([]byte, sizeParser.SizeBytes()),
}
if breader, ok := reader.(*buf.BufferedReader); ok {
r.reader = breader
} else {
r.reader = &buf.BufferedReader{Reader: buf.NewReader(reader)}
}
return r
}
func (r *AuthenticationReader) readSize() (uint16, uint16, error) {
if r.hasSize {
r.hasSize = false
return r.size, r.paddingLen, nil
}
if _, err := io.ReadFull(r.reader, r.sizeBytes); err != nil {
return 0, 0, err
}
var padding uint16
if r.padding != nil {
padding = r.padding.NextPaddingLen()
}
size, err := r.sizeParser.Decode(r.sizeBytes)
return size, padding, err
}
var errSoft = newError("waiting for more data")
func (r *AuthenticationReader) readBuffer(size int32, padding int32) (*buf.Buffer, error) {
b := buf.New()
if _, err := b.ReadFullFrom(r.reader, size); err != nil {
b.Release()
return nil, err
}
size -= padding
rb, err := r.auth.Open(b.BytesTo(0), b.BytesTo(size))
if err != nil {
b.Release()
return nil, err
}
b.Resize(0, int32(len(rb)))
return b, nil
}
func (r *AuthenticationReader) readInternal(soft bool, mb *buf.MultiBuffer) error {
if soft && r.reader.BufferedBytes() < r.sizeParser.SizeBytes() {
return errSoft
}
if r.done {
return io.EOF
}
size, padding, err := r.readSize()
if err != nil {
return err
}
if size == uint16(r.auth.Overhead())+padding {
r.done = true
return io.EOF
}
if soft && int32(size) > r.reader.BufferedBytes() {
r.size = size
r.paddingLen = padding
r.hasSize = true
return errSoft
}
if size <= buf.Size {
b, err := r.readBuffer(int32(size), int32(padding))
if err != nil {
return nil
}
*mb = append(*mb, b)
return nil
}
payload := bytespool.Alloc(int32(size))
defer bytespool.Free(payload)
if _, err := io.ReadFull(r.reader, payload[:size]); err != nil {
return err
}
size -= padding
rb, err := r.auth.Open(payload[:0], payload[:size])
if err != nil {
return err
}
*mb = buf.MergeBytes(*mb, rb)
return nil
}
func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
const readSize = 16
mb := make(buf.MultiBuffer, 0, readSize)
if err := r.readInternal(false, &mb); err != nil {
buf.ReleaseMulti(mb)
return nil, err
}
for i := 1; i < readSize; i++ {
err := r.readInternal(true, &mb)
if err == errSoft || err == io.EOF {
break
}
if err != nil {
buf.ReleaseMulti(mb)
return nil, err
}
}
return mb, nil
}
type AuthenticationWriter struct {
auth Authenticator
writer buf.Writer
sizeParser ChunkSizeEncoder
transferType protocol.TransferType
padding PaddingLengthGenerator
}
func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, transferType protocol.TransferType, padding PaddingLengthGenerator) *AuthenticationWriter {
w := &AuthenticationWriter{
auth: auth,
writer: buf.NewWriter(writer),
sizeParser: sizeParser,
transferType: transferType,
}
if padding != nil {
w.padding = padding
}
return w
}
func (w *AuthenticationWriter) seal(b []byte) (*buf.Buffer, error) {
encryptedSize := int32(len(b) + w.auth.Overhead())
var paddingSize int32
if w.padding != nil {
paddingSize = int32(w.padding.NextPaddingLen())
}
sizeBytes := w.sizeParser.SizeBytes()
totalSize := sizeBytes + encryptedSize + paddingSize
if totalSize > buf.Size {
return nil, newError("size too large: ", totalSize)
}
eb := buf.New()
w.sizeParser.Encode(uint16(encryptedSize+paddingSize), eb.Extend(sizeBytes))
if _, err := w.auth.Seal(eb.Extend(encryptedSize)[:0], b); err != nil {
eb.Release()
return nil, err
}
if paddingSize > 0 {
// With size of the chunk and padding length encrypted, the content of padding doesn't matter much.
paddingBytes := eb.Extend(paddingSize)
common.Must2(rand.Read(paddingBytes))
}
return eb, nil
}
func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error {
defer buf.ReleaseMulti(mb)
var maxPadding int32
if w.padding != nil {
maxPadding = int32(w.padding.MaxPaddingLen())
}
payloadSize := buf.Size - int32(w.auth.Overhead()) - w.sizeParser.SizeBytes() - maxPadding
mb2Write := make(buf.MultiBuffer, 0, len(mb)+10)
temp := buf.New()
defer temp.Release()
rawBytes := temp.Extend(payloadSize)
for {
nb, nBytes := buf.SplitBytes(mb, rawBytes)
mb = nb
eb, err := w.seal(rawBytes[:nBytes])
if err != nil {
buf.ReleaseMulti(mb2Write)
return err
}
mb2Write = append(mb2Write, eb)
if mb.IsEmpty() {
break
}
}
return w.writer.WriteMultiBuffer(mb2Write)
}
func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error {
defer buf.ReleaseMulti(mb)
mb2Write := make(buf.MultiBuffer, 0, len(mb)+1)
for _, b := range mb {
if b.IsEmpty() {
continue
}
eb, err := w.seal(b.Bytes())
if err != nil {
continue
}
mb2Write = append(mb2Write, eb)
}
if mb2Write.IsEmpty() {
return nil
}
return w.writer.WriteMultiBuffer(mb2Write)
}
// WriteMultiBuffer implements buf.Writer.
func (w *AuthenticationWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
if mb.IsEmpty() {
eb, err := w.seal([]byte{})
common.Must(err)
return w.writer.WriteMultiBuffer(buf.MultiBuffer{eb})
}
if w.transferType == protocol.TransferTypeStream {
return w.writeStream(mb)
}
return w.writePacket(mb)
}

143
common/crypto/auth_test.go Normal file
View file

@ -0,0 +1,143 @@
package crypto_test
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"io"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
. "github.com/xtls/xray-core/v1/common/crypto"
"github.com/xtls/xray-core/v1/common/protocol"
)
func TestAuthenticationReaderWriter(t *testing.T) {
key := make([]byte, 16)
rand.Read(key)
block, err := aes.NewCipher(key)
common.Must(err)
aead, err := cipher.NewGCM(block)
common.Must(err)
const payloadSize = 1024 * 80
rawPayload := make([]byte, payloadSize)
rand.Read(rawPayload)
payload := buf.MergeBytes(nil, rawPayload)
cache := bytes.NewBuffer(nil)
iv := make([]byte, 12)
rand.Read(iv)
writer := NewAuthenticationWriter(&AEADAuthenticator{
AEAD: aead,
NonceGenerator: GenerateStaticBytes(iv),
AdditionalDataGenerator: GenerateEmptyBytes(),
}, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream, nil)
common.Must(writer.WriteMultiBuffer(payload))
if cache.Len() <= 1024*80 {
t.Error("cache len: ", cache.Len())
}
common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{}))
reader := NewAuthenticationReader(&AEADAuthenticator{
AEAD: aead,
NonceGenerator: GenerateStaticBytes(iv),
AdditionalDataGenerator: GenerateEmptyBytes(),
}, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream, nil)
var mb buf.MultiBuffer
for mb.Len() < payloadSize {
mb2, err := reader.ReadMultiBuffer()
common.Must(err)
mb, _ = buf.MergeMulti(mb, mb2)
}
if mb.Len() != payloadSize {
t.Error("mb len: ", mb.Len())
}
mbContent := make([]byte, payloadSize)
buf.SplitBytes(mb, mbContent)
if r := cmp.Diff(mbContent, rawPayload); r != "" {
t.Error(r)
}
_, err = reader.ReadMultiBuffer()
if err != io.EOF {
t.Error("error: ", err)
}
}
func TestAuthenticationReaderWriterPacket(t *testing.T) {
key := make([]byte, 16)
common.Must2(rand.Read(key))
block, err := aes.NewCipher(key)
common.Must(err)
aead, err := cipher.NewGCM(block)
common.Must(err)
cache := buf.New()
iv := make([]byte, 12)
rand.Read(iv)
writer := NewAuthenticationWriter(&AEADAuthenticator{
AEAD: aead,
NonceGenerator: GenerateStaticBytes(iv),
AdditionalDataGenerator: GenerateEmptyBytes(),
}, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket, nil)
var payload buf.MultiBuffer
pb1 := buf.New()
pb1.Write([]byte("abcd"))
payload = append(payload, pb1)
pb2 := buf.New()
pb2.Write([]byte("efgh"))
payload = append(payload, pb2)
common.Must(writer.WriteMultiBuffer(payload))
if cache.Len() == 0 {
t.Error("cache len: ", cache.Len())
}
common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{}))
reader := NewAuthenticationReader(&AEADAuthenticator{
AEAD: aead,
NonceGenerator: GenerateStaticBytes(iv),
AdditionalDataGenerator: GenerateEmptyBytes(),
}, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket, nil)
mb, err := reader.ReadMultiBuffer()
common.Must(err)
mb, b1 := buf.SplitFirst(mb)
if b1.String() != "abcd" {
t.Error("b1: ", b1.String())
}
mb, b2 := buf.SplitFirst(mb)
if b2.String() != "efgh" {
t.Error("b2: ", b2.String())
}
if !mb.IsEmpty() {
t.Error("not empty")
}
_, err = reader.ReadMultiBuffer()
if err != io.EOF {
t.Error("error: ", err)
}
}

View file

@ -0,0 +1,50 @@
package crypto_test
import (
"crypto/cipher"
"testing"
. "github.com/xtls/xray-core/v1/common/crypto"
)
const benchSize = 1024 * 1024
func benchmarkStream(b *testing.B, c cipher.Stream) {
b.SetBytes(benchSize)
input := make([]byte, benchSize)
output := make([]byte, benchSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.XORKeyStream(output, input)
}
}
func BenchmarkChaCha20(b *testing.B) {
key := make([]byte, 32)
nonce := make([]byte, 8)
c := NewChaCha20Stream(key, nonce)
benchmarkStream(b, c)
}
func BenchmarkChaCha20IETF(b *testing.B) {
key := make([]byte, 32)
nonce := make([]byte, 12)
c := NewChaCha20Stream(key, nonce)
benchmarkStream(b, c)
}
func BenchmarkAESEncryption(b *testing.B) {
key := make([]byte, 32)
iv := make([]byte, 16)
c := NewAesEncryptionStream(key, iv)
benchmarkStream(b, c)
}
func BenchmarkAESDecryption(b *testing.B) {
key := make([]byte, 32)
iv := make([]byte, 16)
c := NewAesDecryptionStream(key, iv)
benchmarkStream(b, c)
}

13
common/crypto/chacha20.go Normal file
View file

@ -0,0 +1,13 @@
package crypto
import (
"crypto/cipher"
"github.com/xtls/xray-core/v1/common/crypto/internal"
)
// NewChaCha20Stream creates a new Chacha20 encryption/descryption stream based on give key and IV.
// Caller must ensure the length of key is 32 bytes, and length of IV is either 8 or 12 bytes.
func NewChaCha20Stream(key []byte, iv []byte) cipher.Stream {
return internal.NewChaCha20Stream(key, iv, 20)
}

View file

@ -0,0 +1,77 @@
package crypto_test
import (
"crypto/rand"
"encoding/hex"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/xtls/xray-core/v1/common"
. "github.com/xtls/xray-core/v1/common/crypto"
)
func mustDecodeHex(s string) []byte {
b, err := hex.DecodeString(s)
common.Must(err)
return b
}
func TestChaCha20Stream(t *testing.T) {
var cases = []struct {
key []byte
iv []byte
output []byte
}{
{
key: mustDecodeHex("0000000000000000000000000000000000000000000000000000000000000000"),
iv: mustDecodeHex("0000000000000000"),
output: mustDecodeHex("76b8e0ada0f13d90405d6ae55386bd28bdd219b8a08ded1aa836efcc8b770dc7" +
"da41597c5157488d7724e03fb8d84a376a43b8f41518a11cc387b669b2ee6586" +
"9f07e7be5551387a98ba977c732d080dcb0f29a048e3656912c6533e32ee7aed" +
"29b721769ce64e43d57133b074d839d531ed1f28510afb45ace10a1f4b794d6f"),
},
{
key: mustDecodeHex("5555555555555555555555555555555555555555555555555555555555555555"),
iv: mustDecodeHex("5555555555555555"),
output: mustDecodeHex("bea9411aa453c5434a5ae8c92862f564396855a9ea6e22d6d3b50ae1b3663311" +
"a4a3606c671d605ce16c3aece8e61ea145c59775017bee2fa6f88afc758069f7" +
"e0b8f676e644216f4d2a3422d7fa36c6c4931aca950e9da42788e6d0b6d1cd83" +
"8ef652e97b145b14871eae6c6804c7004db5ac2fce4c68c726d004b10fcaba86"),
},
{
key: mustDecodeHex("0000000000000000000000000000000000000000000000000000000000000000"),
iv: mustDecodeHex("000000000000000000000000"),
output: mustDecodeHex("76b8e0ada0f13d90405d6ae55386bd28bdd219b8a08ded1aa836efcc8b770dc7da41597c5157488d7724e03fb8d84a376a43b8f41518a11cc387b669b2ee6586"),
},
}
for _, c := range cases {
s := NewChaCha20Stream(c.key, c.iv)
input := make([]byte, len(c.output))
actualOutout := make([]byte, len(c.output))
s.XORKeyStream(actualOutout, input)
if r := cmp.Diff(c.output, actualOutout); r != "" {
t.Fatal(r)
}
}
}
func TestChaCha20Decoding(t *testing.T) {
key := make([]byte, 32)
common.Must2(rand.Read(key))
iv := make([]byte, 8)
common.Must2(rand.Read(iv))
stream := NewChaCha20Stream(key, iv)
payload := make([]byte, 1024)
common.Must2(rand.Read(payload))
x := make([]byte, len(payload))
stream.XORKeyStream(x, payload)
stream2 := NewChaCha20Stream(key, iv)
stream2.XORKeyStream(x, x)
if r := cmp.Diff(x, payload); r != "" {
t.Fatal(r)
}
}

160
common/crypto/chunk.go Normal file
View file

@ -0,0 +1,160 @@
package crypto
import (
"encoding/binary"
"io"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
)
// ChunkSizeDecoder is a utility class to decode size value from bytes.
type ChunkSizeDecoder interface {
SizeBytes() int32
Decode([]byte) (uint16, error)
}
// ChunkSizeEncoder is a utility class to encode size value into bytes.
type ChunkSizeEncoder interface {
SizeBytes() int32
Encode(uint16, []byte) []byte
}
type PaddingLengthGenerator interface {
MaxPaddingLen() uint16
NextPaddingLen() uint16
}
type PlainChunkSizeParser struct{}
func (PlainChunkSizeParser) SizeBytes() int32 {
return 2
}
func (PlainChunkSizeParser) Encode(size uint16, b []byte) []byte {
binary.BigEndian.PutUint16(b, size)
return b[:2]
}
func (PlainChunkSizeParser) Decode(b []byte) (uint16, error) {
return binary.BigEndian.Uint16(b), nil
}
type AEADChunkSizeParser struct {
Auth *AEADAuthenticator
}
func (p *AEADChunkSizeParser) SizeBytes() int32 {
return 2 + int32(p.Auth.Overhead())
}
func (p *AEADChunkSizeParser) Encode(size uint16, b []byte) []byte {
binary.BigEndian.PutUint16(b, size-uint16(p.Auth.Overhead()))
b, err := p.Auth.Seal(b[:0], b[:2])
common.Must(err)
return b
}
func (p *AEADChunkSizeParser) Decode(b []byte) (uint16, error) {
b, err := p.Auth.Open(b[:0], b)
if err != nil {
return 0, err
}
return binary.BigEndian.Uint16(b) + uint16(p.Auth.Overhead()), nil
}
type ChunkStreamReader struct {
sizeDecoder ChunkSizeDecoder
reader *buf.BufferedReader
buffer []byte
leftOverSize int32
maxNumChunk uint32
numChunk uint32
}
func NewChunkStreamReader(sizeDecoder ChunkSizeDecoder, reader io.Reader) *ChunkStreamReader {
return NewChunkStreamReaderWithChunkCount(sizeDecoder, reader, 0)
}
func NewChunkStreamReaderWithChunkCount(sizeDecoder ChunkSizeDecoder, reader io.Reader, maxNumChunk uint32) *ChunkStreamReader {
r := &ChunkStreamReader{
sizeDecoder: sizeDecoder,
buffer: make([]byte, sizeDecoder.SizeBytes()),
maxNumChunk: maxNumChunk,
}
if breader, ok := reader.(*buf.BufferedReader); ok {
r.reader = breader
} else {
r.reader = &buf.BufferedReader{Reader: buf.NewReader(reader)}
}
return r
}
func (r *ChunkStreamReader) readSize() (uint16, error) {
if _, err := io.ReadFull(r.reader, r.buffer); err != nil {
return 0, err
}
return r.sizeDecoder.Decode(r.buffer)
}
func (r *ChunkStreamReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
size := r.leftOverSize
if size == 0 {
r.numChunk++
if r.maxNumChunk > 0 && r.numChunk > r.maxNumChunk {
return nil, io.EOF
}
nextSize, err := r.readSize()
if err != nil {
return nil, err
}
if nextSize == 0 {
return nil, io.EOF
}
size = int32(nextSize)
}
r.leftOverSize = size
mb, err := r.reader.ReadAtMost(size)
if !mb.IsEmpty() {
r.leftOverSize -= mb.Len()
return mb, nil
}
return nil, err
}
type ChunkStreamWriter struct {
sizeEncoder ChunkSizeEncoder
writer buf.Writer
}
func NewChunkStreamWriter(sizeEncoder ChunkSizeEncoder, writer io.Writer) *ChunkStreamWriter {
return &ChunkStreamWriter{
sizeEncoder: sizeEncoder,
writer: buf.NewWriter(writer),
}
}
func (w *ChunkStreamWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
const sliceSize = 8192
mbLen := mb.Len()
mb2Write := make(buf.MultiBuffer, 0, mbLen/buf.Size+mbLen/sliceSize+2)
for {
mb2, slice := buf.SplitSize(mb, sliceSize)
mb = mb2
b := buf.New()
w.sizeEncoder.Encode(uint16(slice.Len()), b.Extend(w.sizeEncoder.SizeBytes()))
mb2Write = append(mb2Write, b)
mb2Write = append(mb2Write, slice...)
if mb.IsEmpty() {
break
}
}
return w.writer.WriteMultiBuffer(mb2Write)
}

View file

@ -0,0 +1,51 @@
package crypto_test
import (
"bytes"
"io"
"testing"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
. "github.com/xtls/xray-core/v1/common/crypto"
)
func TestChunkStreamIO(t *testing.T) {
cache := bytes.NewBuffer(make([]byte, 0, 8192))
writer := NewChunkStreamWriter(PlainChunkSizeParser{}, cache)
reader := NewChunkStreamReader(PlainChunkSizeParser{}, cache)
b := buf.New()
b.WriteString("abcd")
common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{b}))
b = buf.New()
b.WriteString("efg")
common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{b}))
common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{}))
if cache.Len() != 13 {
t.Fatalf("Cache length is %d, want 13", cache.Len())
}
mb, err := reader.ReadMultiBuffer()
common.Must(err)
if s := mb.String(); s != "abcd" {
t.Error("content: ", s)
}
mb, err = reader.ReadMultiBuffer()
common.Must(err)
if s := mb.String(); s != "efg" {
t.Error("content: ", s)
}
_, err = reader.ReadMultiBuffer()
if err != io.EOF {
t.Error("error: ", err)
}
}

4
common/crypto/crypto.go Normal file
View file

@ -0,0 +1,4 @@
// Package crypto provides common crypto libraries for Xray.
package crypto // import "github.com/xtls/xray-core/v1/common/crypto"
//go:generate go run github.com/xtls/xray-core/v1/common/errors/errorgen

View file

@ -0,0 +1,9 @@
package crypto
import "github.com/xtls/xray-core/v1/common/errors"
type errPathObjHolder struct{}
func newError(values ...interface{}) *errors.Error {
return errors.New(values...).WithPathObj(errPathObjHolder{})
}

View file

@ -0,0 +1,80 @@
package internal
//go:generate go run chacha_core_gen.go
import (
"encoding/binary"
)
const (
wordSize = 4 // the size of ChaCha20's words
stateSize = 16 // the size of ChaCha20's state, in words
blockSize = stateSize * wordSize // the size of ChaCha20's block, in bytes
)
type ChaCha20Stream struct {
state [stateSize]uint32 // the state as an array of 16 32-bit words
block [blockSize]byte // the keystream as an array of 64 bytes
offset int // the offset of used bytes in block
rounds int
}
func NewChaCha20Stream(key []byte, nonce []byte, rounds int) *ChaCha20Stream {
s := new(ChaCha20Stream)
// the magic constants for 256-bit keys
s.state[0] = 0x61707865
s.state[1] = 0x3320646e
s.state[2] = 0x79622d32
s.state[3] = 0x6b206574
for i := 0; i < 8; i++ {
s.state[i+4] = binary.LittleEndian.Uint32(key[i*4 : i*4+4])
}
switch len(nonce) {
case 8:
s.state[14] = binary.LittleEndian.Uint32(nonce[0:])
s.state[15] = binary.LittleEndian.Uint32(nonce[4:])
case 12:
s.state[13] = binary.LittleEndian.Uint32(nonce[0:4])
s.state[14] = binary.LittleEndian.Uint32(nonce[4:8])
s.state[15] = binary.LittleEndian.Uint32(nonce[8:12])
default:
panic("bad nonce length")
}
s.rounds = rounds
ChaCha20Block(&s.state, s.block[:], s.rounds)
return s
}
func (s *ChaCha20Stream) XORKeyStream(dst, src []byte) {
// Stride over the input in 64-byte blocks, minus the amount of keystream
// previously used. This will produce best results when processing blocks
// of a size evenly divisible by 64.
i := 0
max := len(src)
for i < max {
gap := blockSize - s.offset
limit := i + gap
if limit > max {
limit = max
}
o := s.offset
for j := i; j < limit; j++ {
dst[j] = src[j] ^ s.block[o]
o++
}
i += gap
s.offset = o
if o == blockSize {
s.offset = 0
s.state[12]++
ChaCha20Block(&s.state, s.block[:], s.rounds)
}
}
}

View file

@ -0,0 +1,123 @@
package internal
import "encoding/binary"
func ChaCha20Block(s *[16]uint32, out []byte, rounds int) {
var x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15 = s[0], s[1], s[2], s[3], s[4], s[5], s[6], s[7], s[8], s[9], s[10], s[11], s[12], s[13], s[14], s[15]
for i := 0; i < rounds; i += 2 {
var x uint32
x0 += x4
x = x12 ^ x0
x12 = (x << 16) | (x >> (32 - 16))
x8 += x12
x = x4 ^ x8
x4 = (x << 12) | (x >> (32 - 12))
x0 += x4
x = x12 ^ x0
x12 = (x << 8) | (x >> (32 - 8))
x8 += x12
x = x4 ^ x8
x4 = (x << 7) | (x >> (32 - 7))
x1 += x5
x = x13 ^ x1
x13 = (x << 16) | (x >> (32 - 16))
x9 += x13
x = x5 ^ x9
x5 = (x << 12) | (x >> (32 - 12))
x1 += x5
x = x13 ^ x1
x13 = (x << 8) | (x >> (32 - 8))
x9 += x13
x = x5 ^ x9
x5 = (x << 7) | (x >> (32 - 7))
x2 += x6
x = x14 ^ x2
x14 = (x << 16) | (x >> (32 - 16))
x10 += x14
x = x6 ^ x10
x6 = (x << 12) | (x >> (32 - 12))
x2 += x6
x = x14 ^ x2
x14 = (x << 8) | (x >> (32 - 8))
x10 += x14
x = x6 ^ x10
x6 = (x << 7) | (x >> (32 - 7))
x3 += x7
x = x15 ^ x3
x15 = (x << 16) | (x >> (32 - 16))
x11 += x15
x = x7 ^ x11
x7 = (x << 12) | (x >> (32 - 12))
x3 += x7
x = x15 ^ x3
x15 = (x << 8) | (x >> (32 - 8))
x11 += x15
x = x7 ^ x11
x7 = (x << 7) | (x >> (32 - 7))
x0 += x5
x = x15 ^ x0
x15 = (x << 16) | (x >> (32 - 16))
x10 += x15
x = x5 ^ x10
x5 = (x << 12) | (x >> (32 - 12))
x0 += x5
x = x15 ^ x0
x15 = (x << 8) | (x >> (32 - 8))
x10 += x15
x = x5 ^ x10
x5 = (x << 7) | (x >> (32 - 7))
x1 += x6
x = x12 ^ x1
x12 = (x << 16) | (x >> (32 - 16))
x11 += x12
x = x6 ^ x11
x6 = (x << 12) | (x >> (32 - 12))
x1 += x6
x = x12 ^ x1
x12 = (x << 8) | (x >> (32 - 8))
x11 += x12
x = x6 ^ x11
x6 = (x << 7) | (x >> (32 - 7))
x2 += x7
x = x13 ^ x2
x13 = (x << 16) | (x >> (32 - 16))
x8 += x13
x = x7 ^ x8
x7 = (x << 12) | (x >> (32 - 12))
x2 += x7
x = x13 ^ x2
x13 = (x << 8) | (x >> (32 - 8))
x8 += x13
x = x7 ^ x8
x7 = (x << 7) | (x >> (32 - 7))
x3 += x4
x = x14 ^ x3
x14 = (x << 16) | (x >> (32 - 16))
x9 += x14
x = x4 ^ x9
x4 = (x << 12) | (x >> (32 - 12))
x3 += x4
x = x14 ^ x3
x14 = (x << 8) | (x >> (32 - 8))
x9 += x14
x = x4 ^ x9
x4 = (x << 7) | (x >> (32 - 7))
}
binary.LittleEndian.PutUint32(out[0:4], s[0]+x0)
binary.LittleEndian.PutUint32(out[4:8], s[1]+x1)
binary.LittleEndian.PutUint32(out[8:12], s[2]+x2)
binary.LittleEndian.PutUint32(out[12:16], s[3]+x3)
binary.LittleEndian.PutUint32(out[16:20], s[4]+x4)
binary.LittleEndian.PutUint32(out[20:24], s[5]+x5)
binary.LittleEndian.PutUint32(out[24:28], s[6]+x6)
binary.LittleEndian.PutUint32(out[28:32], s[7]+x7)
binary.LittleEndian.PutUint32(out[32:36], s[8]+x8)
binary.LittleEndian.PutUint32(out[36:40], s[9]+x9)
binary.LittleEndian.PutUint32(out[40:44], s[10]+x10)
binary.LittleEndian.PutUint32(out[44:48], s[11]+x11)
binary.LittleEndian.PutUint32(out[48:52], s[12]+x12)
binary.LittleEndian.PutUint32(out[52:56], s[13]+x13)
binary.LittleEndian.PutUint32(out[56:60], s[14]+x14)
binary.LittleEndian.PutUint32(out[60:64], s[15]+x15)
}

View file

@ -0,0 +1,69 @@
// +build generate
package main
import (
"fmt"
"log"
"os"
)
func writeQuarterRound(file *os.File, a, b, c, d int) {
add := "x%d+=x%d\n"
xor := "x=x%d^x%d\n"
rotate := "x%d=(x << %d) | (x >> (32 - %d))\n"
fmt.Fprintf(file, add, a, b)
fmt.Fprintf(file, xor, d, a)
fmt.Fprintf(file, rotate, d, 16, 16)
fmt.Fprintf(file, add, c, d)
fmt.Fprintf(file, xor, b, c)
fmt.Fprintf(file, rotate, b, 12, 12)
fmt.Fprintf(file, add, a, b)
fmt.Fprintf(file, xor, d, a)
fmt.Fprintf(file, rotate, d, 8, 8)
fmt.Fprintf(file, add, c, d)
fmt.Fprintf(file, xor, b, c)
fmt.Fprintf(file, rotate, b, 7, 7)
}
func writeChacha20Block(file *os.File) {
fmt.Fprintln(file, `
func ChaCha20Block(s *[16]uint32, out []byte, rounds int) {
var x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15 = s[0],s[1],s[2],s[3],s[4],s[5],s[6],s[7],s[8],s[9],s[10],s[11],s[12],s[13],s[14],s[15]
for i := 0; i < rounds; i+=2 {
var x uint32
`)
writeQuarterRound(file, 0, 4, 8, 12)
writeQuarterRound(file, 1, 5, 9, 13)
writeQuarterRound(file, 2, 6, 10, 14)
writeQuarterRound(file, 3, 7, 11, 15)
writeQuarterRound(file, 0, 5, 10, 15)
writeQuarterRound(file, 1, 6, 11, 12)
writeQuarterRound(file, 2, 7, 8, 13)
writeQuarterRound(file, 3, 4, 9, 14)
fmt.Fprintln(file, "}")
for i := 0; i < 16; i++ {
fmt.Fprintf(file, "binary.LittleEndian.PutUint32(out[%d:%d], s[%d]+x%d)\n", i*4, i*4+4, i, i)
}
fmt.Fprintln(file, "}")
fmt.Fprintln(file)
}
func main() {
file, err := os.OpenFile("chacha_core.generated.go", os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0644)
if err != nil {
log.Fatalf("Failed to generate chacha_core.go: %v", err)
}
defer file.Close()
fmt.Fprintln(file, "package internal")
fmt.Fprintln(file)
fmt.Fprintln(file, "import \"encoding/binary\"")
fmt.Fprintln(file)
writeChacha20Block(file)
}

66
common/crypto/io.go Normal file
View file

@ -0,0 +1,66 @@
package crypto
import (
"crypto/cipher"
"io"
"github.com/xtls/xray-core/v1/common/buf"
)
type CryptionReader struct {
stream cipher.Stream
reader io.Reader
}
func NewCryptionReader(stream cipher.Stream, reader io.Reader) *CryptionReader {
return &CryptionReader{
stream: stream,
reader: reader,
}
}
func (r *CryptionReader) Read(data []byte) (int, error) {
nBytes, err := r.reader.Read(data)
if nBytes > 0 {
r.stream.XORKeyStream(data[:nBytes], data[:nBytes])
}
return nBytes, err
}
var (
_ buf.Writer = (*CryptionWriter)(nil)
)
type CryptionWriter struct {
stream cipher.Stream
writer io.Writer
bufWriter buf.Writer
}
// NewCryptionWriter creates a new CryptionWriter.
func NewCryptionWriter(stream cipher.Stream, writer io.Writer) *CryptionWriter {
return &CryptionWriter{
stream: stream,
writer: writer,
bufWriter: buf.NewWriter(writer),
}
}
// Write implements io.Writer.Write().
func (w *CryptionWriter) Write(data []byte) (int, error) {
w.stream.XORKeyStream(data, data)
if err := buf.WriteAllBytes(w.writer, data); err != nil {
return 0, err
}
return len(data), nil
}
// WriteMultiBuffer implements buf.Writer.
func (w *CryptionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
for _, b := range mb {
w.stream.XORKeyStream(b.Bytes(), b.Bytes())
}
return w.bufWriter.WriteMultiBuffer(mb)
}

52
common/dice/dice.go Normal file
View file

@ -0,0 +1,52 @@
// Package dice contains common functions to generate random number.
// It also initialize math/rand with the time in seconds at launch time.
package dice // import "github.com/xtls/xray-core/v1/common/dice"
import (
"math/rand"
"time"
)
// Roll returns a non-negative number between 0 (inclusive) and n (exclusive).
func Roll(n int) int {
if n == 1 {
return 0
}
return rand.Intn(n)
}
// Roll returns a non-negative number between 0 (inclusive) and n (exclusive).
func RollDeterministic(n int, seed int64) int {
if n == 1 {
return 0
}
return rand.New(rand.NewSource(seed)).Intn(n)
}
// RollUint16 returns a random uint16 value.
func RollUint16() uint16 {
return uint16(rand.Int63() >> 47)
}
func RollUint64() uint64 {
return rand.Uint64()
}
func NewDeterministicDice(seed int64) *DeterministicDice {
return &DeterministicDice{rand.New(rand.NewSource(seed))}
}
type DeterministicDice struct {
*rand.Rand
}
func (dd *DeterministicDice) Roll(n int) int {
if n == 1 {
return 0
}
return dd.Intn(n)
}
func init() {
rand.Seed(time.Now().Unix())
}

50
common/dice/dice_test.go Normal file
View file

@ -0,0 +1,50 @@
package dice_test
import (
"math/rand"
"testing"
. "github.com/xtls/xray-core/v1/common/dice"
)
func BenchmarkRoll1(b *testing.B) {
for i := 0; i < b.N; i++ {
Roll(1)
}
}
func BenchmarkRoll20(b *testing.B) {
for i := 0; i < b.N; i++ {
Roll(20)
}
}
func BenchmarkIntn1(b *testing.B) {
for i := 0; i < b.N; i++ {
rand.Intn(1)
}
}
func BenchmarkIntn20(b *testing.B) {
for i := 0; i < b.N; i++ {
rand.Intn(20)
}
}
func BenchmarkInt63(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = uint16(rand.Int63() >> 47)
}
}
func BenchmarkInt31(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = uint16(rand.Int31() >> 15)
}
}
func BenchmarkIntn(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = uint16(rand.Intn(65536))
}
}

View file

@ -0,0 +1,9 @@
package common
import "github.com/xtls/xray-core/v1/common/errors"
type errPathObjHolder struct{}
func newError(values ...interface{}) *errors.Error {
return errors.New(values...).WithPathObj(errPathObjHolder{})
}

View file

@ -0,0 +1,45 @@
package main
import (
"fmt"
"log"
"os"
"path/filepath"
"github.com/xtls/xray-core/v1/common"
)
func main() {
pwd, err := os.Getwd()
if err != nil {
fmt.Println("can not get current working directory")
os.Exit(1)
}
pkg := filepath.Base(pwd)
if pkg == "xray-core" {
pkg = "core"
}
moduleName, gmnErr := common.GetModuleName(pwd)
if gmnErr != nil {
fmt.Println("can not get module path", gmnErr)
os.Exit(1)
}
file, err := os.OpenFile("errors.generated.go", os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0644)
if err != nil {
log.Fatalf("Failed to generate errors.generated.go: %v", err)
os.Exit(1)
}
defer file.Close()
fmt.Fprintln(file, "package", pkg)
fmt.Fprintln(file, "")
fmt.Fprintln(file, "import \""+moduleName+"/common/errors\"")
fmt.Fprintln(file, "")
fmt.Fprintln(file, "type errPathObjHolder struct{}")
fmt.Fprintln(file, "")
fmt.Fprintln(file, "func newError(values ...interface{}) *errors.Error {")
fmt.Fprintln(file, " return errors.New(values...).WithPathObj(errPathObjHolder{})")
fmt.Fprintln(file, "}")
}

195
common/errors/errors.go Normal file
View file

@ -0,0 +1,195 @@
// Package errors is a drop-in replacement for Golang lib 'errors'.
package errors // import "github.com/xtls/xray-core/v1/common/errors"
import (
"os"
"reflect"
"strings"
"github.com/xtls/xray-core/v1/common/log"
"github.com/xtls/xray-core/v1/common/serial"
)
type hasInnerError interface {
// Inner returns the underlying error of this one.
Inner() error
}
type hasSeverity interface {
Severity() log.Severity
}
// Error is an error object with underlying error.
type Error struct {
pathObj interface{}
prefix []interface{}
message []interface{}
inner error
severity log.Severity
}
func (err *Error) WithPathObj(obj interface{}) *Error {
err.pathObj = obj
return err
}
func (err *Error) pkgPath() string {
if err.pathObj == nil {
return ""
}
return reflect.TypeOf(err.pathObj).PkgPath()
}
// Error implements error.Error().
func (err *Error) Error() string {
builder := strings.Builder{}
for _, prefix := range err.prefix {
builder.WriteByte('[')
builder.WriteString(serial.ToString(prefix))
builder.WriteString("] ")
}
path := err.pkgPath()
if len(path) > 0 {
builder.WriteString(path)
builder.WriteString(": ")
}
msg := serial.Concat(err.message...)
builder.WriteString(msg)
if err.inner != nil {
builder.WriteString(" > ")
builder.WriteString(err.inner.Error())
}
return builder.String()
}
// Inner implements hasInnerError.Inner()
func (err *Error) Inner() error {
if err.inner == nil {
return nil
}
return err.inner
}
func (err *Error) Base(e error) *Error {
err.inner = e
return err
}
func (err *Error) atSeverity(s log.Severity) *Error {
err.severity = s
return err
}
func (err *Error) Severity() log.Severity {
if err.inner == nil {
return err.severity
}
if s, ok := err.inner.(hasSeverity); ok {
as := s.Severity()
if as < err.severity {
return as
}
}
return err.severity
}
// AtDebug sets the severity to debug.
func (err *Error) AtDebug() *Error {
return err.atSeverity(log.Severity_Debug)
}
// AtInfo sets the severity to info.
func (err *Error) AtInfo() *Error {
return err.atSeverity(log.Severity_Info)
}
// AtWarning sets the severity to warning.
func (err *Error) AtWarning() *Error {
return err.atSeverity(log.Severity_Warning)
}
// AtError sets the severity to error.
func (err *Error) AtError() *Error {
return err.atSeverity(log.Severity_Error)
}
// String returns the string representation of this error.
func (err *Error) String() string {
return err.Error()
}
// WriteToLog writes current error into log.
func (err *Error) WriteToLog(opts ...ExportOption) {
var holder ExportOptionHolder
for _, opt := range opts {
opt(&holder)
}
if holder.SessionID > 0 {
err.prefix = append(err.prefix, holder.SessionID)
}
log.Record(&log.GeneralMessage{
Severity: GetSeverity(err),
Content: err,
})
}
type ExportOptionHolder struct {
SessionID uint32
}
type ExportOption func(*ExportOptionHolder)
// New returns a new error object with message formed from given arguments.
func New(msg ...interface{}) *Error {
return &Error{
message: msg,
severity: log.Severity_Info,
}
}
// Cause returns the root cause of this error.
func Cause(err error) error {
if err == nil {
return nil
}
L:
for {
switch inner := err.(type) {
case hasInnerError:
if inner.Inner() == nil {
break L
}
err = inner.Inner()
case *os.PathError:
if inner.Err == nil {
break L
}
err = inner.Err
case *os.SyscallError:
if inner.Err == nil {
break L
}
err = inner.Err
default:
break L
}
}
return err
}
// GetSeverity returns the actual severity of the error, including inner errors.
func GetSeverity(err error) log.Severity {
if s, ok := err.(hasSeverity); ok {
return s.Severity()
}
return log.Severity_Info
}

View file

@ -0,0 +1,62 @@
package errors_test
import (
"io"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
. "github.com/xtls/xray-core/v1/common/errors"
"github.com/xtls/xray-core/v1/common/log"
)
func TestError(t *testing.T) {
err := New("TestError")
if v := GetSeverity(err); v != log.Severity_Info {
t.Error("severity: ", v)
}
err = New("TestError2").Base(io.EOF)
if v := GetSeverity(err); v != log.Severity_Info {
t.Error("severity: ", v)
}
err = New("TestError3").Base(io.EOF).AtWarning()
if v := GetSeverity(err); v != log.Severity_Warning {
t.Error("severity: ", v)
}
err = New("TestError4").Base(io.EOF).AtWarning()
err = New("TestError5").Base(err)
if v := GetSeverity(err); v != log.Severity_Warning {
t.Error("severity: ", v)
}
if v := err.Error(); !strings.Contains(v, "EOF") {
t.Error("error: ", v)
}
}
type e struct{}
func TestErrorMessage(t *testing.T) {
data := []struct {
err error
msg string
}{
{
err: New("a").Base(New("b")).WithPathObj(e{}),
msg: "github.com/xtls/xray-core/v1/common/errors_test: a > b",
},
{
err: New("a").Base(New("b").WithPathObj(e{})),
msg: "a > github.com/xtls/xray-core/v1/common/errors_test: b",
},
}
for _, d := range data {
if diff := cmp.Diff(d.msg, d.err.Error()); diff != "" {
t.Error(diff)
}
}
}

View file

@ -0,0 +1,30 @@
package errors
import (
"strings"
)
type multiError []error
func (e multiError) Error() string {
var r strings.Builder
r.WriteString("multierr: ")
for _, err := range e {
r.WriteString(err.Error())
r.WriteString(" | ")
}
return r.String()
}
func Combine(maybeError ...error) error {
var errs multiError
for _, err := range maybeError {
if err != nil {
errs = append(errs, err)
}
}
if len(errs) == 0 {
return nil
}
return errs
}

68
common/interfaces.go Normal file
View file

@ -0,0 +1,68 @@
package common
import "github.com/xtls/xray-core/v1/common/errors"
// Closable is the interface for objects that can release its resources.
//
// xray:api:beta
type Closable interface {
// Close release all resources used by this object, including goroutines.
Close() error
}
// Interruptible is an interface for objects that can be stopped before its completion.
//
// xray:api:beta
type Interruptible interface {
Interrupt()
}
// Close closes the obj if it is a Closable.
//
// xray:api:beta
func Close(obj interface{}) error {
if c, ok := obj.(Closable); ok {
return c.Close()
}
return nil
}
// Interrupt calls Interrupt() if object implements Interruptible interface, or Close() if the object implements Closable interface.
//
// xray:api:beta
func Interrupt(obj interface{}) error {
if c, ok := obj.(Interruptible); ok {
c.Interrupt()
return nil
}
return Close(obj)
}
// Runnable is the interface for objects that can start to work and stop on demand.
type Runnable interface {
// Start starts the runnable object. Upon the method returning nil, the object begins to function properly.
Start() error
Closable
}
// HasType is the interface for objects that knows its type.
type HasType interface {
// Type returns the type of the object.
// Usually it returns (*Type)(nil) of the object.
Type() interface{}
}
// ChainedClosable is a Closable that consists of multiple Closable objects.
type ChainedClosable []Closable
// Close implements Closable.
func (cc ChainedClosable) Close() error {
var errs []error
for _, c := range cc {
if err := c.Close(); err != nil {
errs = append(errs, err)
}
}
return errors.Combine(errs...)
}

64
common/log/access.go Normal file
View file

@ -0,0 +1,64 @@
package log
import (
"context"
"strings"
"github.com/xtls/xray-core/v1/common/serial"
)
type logKey int
const (
accessMessageKey logKey = iota
)
type AccessStatus string
const (
AccessAccepted = AccessStatus("accepted")
AccessRejected = AccessStatus("rejected")
)
type AccessMessage struct {
From interface{}
To interface{}
Status AccessStatus
Reason interface{}
Email string
Detour string
}
func (m *AccessMessage) String() string {
builder := strings.Builder{}
builder.WriteString(serial.ToString(m.From))
builder.WriteByte(' ')
builder.WriteString(string(m.Status))
builder.WriteByte(' ')
builder.WriteString(serial.ToString(m.To))
builder.WriteByte(' ')
if len(m.Detour) > 0 {
builder.WriteByte('[')
builder.WriteString(m.Detour)
builder.WriteString("] ")
}
builder.WriteString(serial.ToString(m.Reason))
if len(m.Email) > 0 {
builder.WriteString("email:")
builder.WriteString(m.Email)
builder.WriteByte(' ')
}
return builder.String()
}
func ContextWithAccessMessage(ctx context.Context, accessMessage *AccessMessage) context.Context {
return context.WithValue(ctx, accessMessageKey, accessMessage)
}
func AccessMessageFromContext(ctx context.Context) *AccessMessage {
if accessMessage, ok := ctx.Value(accessMessageKey).(*AccessMessage); ok {
return accessMessage
}
return nil
}

66
common/log/log.go Normal file
View file

@ -0,0 +1,66 @@
package log // import "github.com/xtls/xray-core/v1/common/log"
import (
"sync"
"github.com/xtls/xray-core/v1/common/serial"
)
// Message is the interface for all log messages.
type Message interface {
String() string
}
// Handler is the interface for log handler.
type Handler interface {
Handle(msg Message)
}
// GeneralMessage is a general log message that can contain all kind of content.
type GeneralMessage struct {
Severity Severity
Content interface{}
}
// String implements Message.
func (m *GeneralMessage) String() string {
return serial.Concat("[", m.Severity, "] ", m.Content)
}
// Record writes a message into log stream.
func Record(msg Message) {
logHandler.Handle(msg)
}
var (
logHandler syncHandler
)
// RegisterHandler register a new handler as current log handler. Previous registered handler will be discarded.
func RegisterHandler(handler Handler) {
if handler == nil {
panic("Log handler is nil")
}
logHandler.Set(handler)
}
type syncHandler struct {
sync.RWMutex
Handler
}
func (h *syncHandler) Handle(msg Message) {
h.RLock()
defer h.RUnlock()
if h.Handler != nil {
h.Handler.Handle(msg)
}
}
func (h *syncHandler) Set(handler Handler) {
h.Lock()
defer h.Unlock()
h.Handler = handler
}

148
common/log/log.pb.go Normal file
View file

@ -0,0 +1,148 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.25.0
// protoc v3.14.0
// source: common/log/log.proto
package log
import (
proto "github.com/golang/protobuf/proto"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
type Severity int32
const (
Severity_Unknown Severity = 0
Severity_Error Severity = 1
Severity_Warning Severity = 2
Severity_Info Severity = 3
Severity_Debug Severity = 4
)
// Enum value maps for Severity.
var (
Severity_name = map[int32]string{
0: "Unknown",
1: "Error",
2: "Warning",
3: "Info",
4: "Debug",
}
Severity_value = map[string]int32{
"Unknown": 0,
"Error": 1,
"Warning": 2,
"Info": 3,
"Debug": 4,
}
)
func (x Severity) Enum() *Severity {
p := new(Severity)
*p = x
return p
}
func (x Severity) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (Severity) Descriptor() protoreflect.EnumDescriptor {
return file_common_log_log_proto_enumTypes[0].Descriptor()
}
func (Severity) Type() protoreflect.EnumType {
return &file_common_log_log_proto_enumTypes[0]
}
func (x Severity) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use Severity.Descriptor instead.
func (Severity) EnumDescriptor() ([]byte, []int) {
return file_common_log_log_proto_rawDescGZIP(), []int{0}
}
var File_common_log_log_proto protoreflect.FileDescriptor
var file_common_log_log_proto_rawDesc = []byte{
0x0a, 0x14, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x6c, 0x6f, 0x67, 0x2f, 0x6c, 0x6f, 0x67,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0f, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d,
0x6d, 0x6f, 0x6e, 0x2e, 0x6c, 0x6f, 0x67, 0x2a, 0x44, 0x0a, 0x08, 0x53, 0x65, 0x76, 0x65, 0x72,
0x69, 0x74, 0x79, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x10, 0x00,
0x12, 0x09, 0x0a, 0x05, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x10, 0x01, 0x12, 0x0b, 0x0a, 0x07, 0x57,
0x61, 0x72, 0x6e, 0x69, 0x6e, 0x67, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x6e, 0x66, 0x6f,
0x10, 0x03, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x65, 0x62, 0x75, 0x67, 0x10, 0x04, 0x42, 0x52, 0x0a,
0x13, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e,
0x2e, 0x6c, 0x6f, 0x67, 0x50, 0x01, 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63,
0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72,
0x65, 0x2f, 0x76, 0x31, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x6c, 0x6f, 0x67, 0xaa,
0x02, 0x0f, 0x58, 0x72, 0x61, 0x79, 0x2e, 0x43, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f,
0x67, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_common_log_log_proto_rawDescOnce sync.Once
file_common_log_log_proto_rawDescData = file_common_log_log_proto_rawDesc
)
func file_common_log_log_proto_rawDescGZIP() []byte {
file_common_log_log_proto_rawDescOnce.Do(func() {
file_common_log_log_proto_rawDescData = protoimpl.X.CompressGZIP(file_common_log_log_proto_rawDescData)
})
return file_common_log_log_proto_rawDescData
}
var file_common_log_log_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_common_log_log_proto_goTypes = []interface{}{
(Severity)(0), // 0: xray.common.log.Severity
}
var file_common_log_log_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_common_log_log_proto_init() }
func file_common_log_log_proto_init() {
if File_common_log_log_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_common_log_log_proto_rawDesc,
NumEnums: 1,
NumMessages: 0,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_common_log_log_proto_goTypes,
DependencyIndexes: file_common_log_log_proto_depIdxs,
EnumInfos: file_common_log_log_proto_enumTypes,
}.Build()
File_common_log_log_proto = out.File
file_common_log_log_proto_rawDesc = nil
file_common_log_log_proto_goTypes = nil
file_common_log_log_proto_depIdxs = nil
}

15
common/log/log.proto Normal file
View file

@ -0,0 +1,15 @@
syntax = "proto3";
package xray.common.log;
option csharp_namespace = "Xray.Common.Log";
option go_package = "github.com/xtls/xray-core/v1/common/log";
option java_package = "com.xray.common.log";
option java_multiple_files = true;
enum Severity {
Unknown = 0;
Error = 1;
Warning = 2;
Info = 3;
Debug = 4;
}

33
common/log/log_test.go Normal file
View file

@ -0,0 +1,33 @@
package log_test
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/xtls/xray-core/v1/common/log"
"github.com/xtls/xray-core/v1/common/net"
)
type testLogger struct {
value string
}
func (l *testLogger) Handle(msg log.Message) {
l.value = msg.String()
}
func TestLogRecord(t *testing.T) {
var logger testLogger
log.RegisterHandler(&logger)
ip := "8.8.8.8"
log.Record(&log.GeneralMessage{
Severity: log.Severity_Error,
Content: net.ParseAddress(ip),
})
if diff := cmp.Diff("[Error] "+ip, logger.value); diff != "" {
t.Error(diff)
}
}

152
common/log/logger.go Normal file
View file

@ -0,0 +1,152 @@
package log
import (
"io"
"log"
"os"
"time"
"github.com/xtls/xray-core/v1/common/platform"
"github.com/xtls/xray-core/v1/common/signal/done"
"github.com/xtls/xray-core/v1/common/signal/semaphore"
)
// Writer is the interface for writing logs.
type Writer interface {
Write(string) error
io.Closer
}
// WriterCreator is a function to create LogWriters.
type WriterCreator func() Writer
type generalLogger struct {
creator WriterCreator
buffer chan Message
access *semaphore.Instance
done *done.Instance
}
// NewLogger returns a generic log handler that can handle all type of messages.
func NewLogger(logWriterCreator WriterCreator) Handler {
return &generalLogger{
creator: logWriterCreator,
buffer: make(chan Message, 16),
access: semaphore.New(1),
done: done.New(),
}
}
func (l *generalLogger) run() {
defer l.access.Signal()
dataWritten := false
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
logger := l.creator()
if logger == nil {
return
}
defer logger.Close()
for {
select {
case <-l.done.Wait():
return
case msg := <-l.buffer:
logger.Write(msg.String() + platform.LineSeparator())
dataWritten = true
case <-ticker.C:
if !dataWritten {
return
}
dataWritten = false
}
}
}
func (l *generalLogger) Handle(msg Message) {
select {
case l.buffer <- msg:
default:
}
select {
case <-l.access.Wait():
go l.run()
default:
}
}
func (l *generalLogger) Close() error {
return l.done.Close()
}
type consoleLogWriter struct {
logger *log.Logger
}
func (w *consoleLogWriter) Write(s string) error {
w.logger.Print(s)
return nil
}
func (w *consoleLogWriter) Close() error {
return nil
}
type fileLogWriter struct {
file *os.File
logger *log.Logger
}
func (w *fileLogWriter) Write(s string) error {
w.logger.Print(s)
return nil
}
func (w *fileLogWriter) Close() error {
return w.file.Close()
}
// CreateStdoutLogWriter returns a LogWriterCreator that creates LogWriter for stdout.
func CreateStdoutLogWriter() WriterCreator {
return func() Writer {
return &consoleLogWriter{
logger: log.New(os.Stdout, "", log.Ldate|log.Ltime),
}
}
}
// CreateStderrLogWriter returns a LogWriterCreator that creates LogWriter for stderr.
func CreateStderrLogWriter() WriterCreator {
return func() Writer {
return &consoleLogWriter{
logger: log.New(os.Stderr, "", log.Ldate|log.Ltime),
}
}
}
// CreateFileLogWriter returns a LogWriterCreator that creates LogWriter for the given file.
func CreateFileLogWriter(path string) (WriterCreator, error) {
file, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return nil, err
}
file.Close()
return func() Writer {
file, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return nil
}
return &fileLogWriter{
file: file,
logger: log.New(file, "", log.Ldate|log.Ltime),
}
}, nil
}
func init() {
RegisterHandler(NewLogger(CreateStdoutLogWriter()))
}

39
common/log/logger_test.go Normal file
View file

@ -0,0 +1,39 @@
package log_test
import (
"io/ioutil"
"os"
"strings"
"testing"
"time"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
. "github.com/xtls/xray-core/v1/common/log"
)
func TestFileLogger(t *testing.T) {
f, err := ioutil.TempFile("", "vtest")
common.Must(err)
path := f.Name()
common.Must(f.Close())
creator, err := CreateFileLogWriter(path)
common.Must(err)
handler := NewLogger(creator)
handler.Handle(&GeneralMessage{Content: "Test Log"})
time.Sleep(2 * time.Second)
common.Must(common.Close(handler))
f, err = os.Open(path)
common.Must(err)
defer f.Close()
b, err := buf.ReadAllToBytes(f)
common.Must(err)
if !strings.Contains(string(b), "Test Log") {
t.Fatal("Expect log text contains 'Test Log', but actually: ", string(b))
}
}

402
common/mux/client.go Normal file
View file

@ -0,0 +1,402 @@
package mux
import (
"context"
"io"
"sync"
"time"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/errors"
"github.com/xtls/xray-core/v1/common/net"
"github.com/xtls/xray-core/v1/common/protocol"
"github.com/xtls/xray-core/v1/common/session"
"github.com/xtls/xray-core/v1/common/signal/done"
"github.com/xtls/xray-core/v1/common/task"
"github.com/xtls/xray-core/v1/proxy"
"github.com/xtls/xray-core/v1/transport"
"github.com/xtls/xray-core/v1/transport/internet"
"github.com/xtls/xray-core/v1/transport/pipe"
)
type ClientManager struct {
Enabled bool // wheather mux is enabled from user config
Picker WorkerPicker
}
func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error {
for i := 0; i < 16; i++ {
worker, err := m.Picker.PickAvailable()
if err != nil {
return err
}
if worker.Dispatch(ctx, link) {
return nil
}
}
return newError("unable to find an available mux client").AtWarning()
}
type WorkerPicker interface {
PickAvailable() (*ClientWorker, error)
}
type IncrementalWorkerPicker struct {
Factory ClientWorkerFactory
access sync.Mutex
workers []*ClientWorker
cleanupTask *task.Periodic
}
func (p *IncrementalWorkerPicker) cleanupFunc() error {
p.access.Lock()
defer p.access.Unlock()
if len(p.workers) == 0 {
return newError("no worker")
}
p.cleanup()
return nil
}
func (p *IncrementalWorkerPicker) cleanup() {
var activeWorkers []*ClientWorker
for _, w := range p.workers {
if !w.Closed() {
activeWorkers = append(activeWorkers, w)
}
}
p.workers = activeWorkers
}
func (p *IncrementalWorkerPicker) findAvailable() int {
for idx, w := range p.workers {
if !w.IsFull() {
return idx
}
}
return -1
}
func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, bool, error) {
p.access.Lock()
defer p.access.Unlock()
idx := p.findAvailable()
if idx >= 0 {
n := len(p.workers)
if n > 1 && idx != n-1 {
p.workers[n-1], p.workers[idx] = p.workers[idx], p.workers[n-1]
}
return p.workers[idx], false, nil
}
p.cleanup()
worker, err := p.Factory.Create()
if err != nil {
return nil, false, err
}
p.workers = append(p.workers, worker)
if p.cleanupTask == nil {
p.cleanupTask = &task.Periodic{
Interval: time.Second * 30,
Execute: p.cleanupFunc,
}
}
return worker, true, nil
}
func (p *IncrementalWorkerPicker) PickAvailable() (*ClientWorker, error) {
worker, start, err := p.pickInternal()
if start {
common.Must(p.cleanupTask.Start())
}
return worker, err
}
type ClientWorkerFactory interface {
Create() (*ClientWorker, error)
}
type DialingWorkerFactory struct {
Proxy proxy.Outbound
Dialer internet.Dialer
Strategy ClientStrategy
}
func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)}
uplinkReader, upLinkWriter := pipe.New(opts...)
downlinkReader, downlinkWriter := pipe.New(opts...)
c, err := NewClientWorker(transport.Link{
Reader: downlinkReader,
Writer: upLinkWriter,
}, f.Strategy)
if err != nil {
return nil, err
}
go func(p proxy.Outbound, d internet.Dialer, c common.Closable) {
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{
Target: net.TCPDestination(muxCoolAddress, muxCoolPort),
})
ctx, cancel := context.WithCancel(ctx)
if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
errors.New("failed to handler mux client connection").Base(err).WriteToLog()
}
common.Must(c.Close())
cancel()
}(f.Proxy, f.Dialer, c.done)
return c, nil
}
type ClientStrategy struct {
MaxConcurrency uint32
MaxConnection uint32
}
type ClientWorker struct {
sessionManager *SessionManager
link transport.Link
done *done.Instance
strategy ClientStrategy
}
var muxCoolAddress = net.DomainAddress("v1.mux.cool")
var muxCoolPort = net.Port(9527)
// NewClientWorker creates a new mux.Client.
func NewClientWorker(stream transport.Link, s ClientStrategy) (*ClientWorker, error) {
c := &ClientWorker{
sessionManager: NewSessionManager(),
link: stream,
done: done.New(),
strategy: s,
}
go c.fetchOutput()
go c.monitor()
return c, nil
}
func (m *ClientWorker) TotalConnections() uint32 {
return uint32(m.sessionManager.Count())
}
func (m *ClientWorker) ActiveConnections() uint32 {
return uint32(m.sessionManager.Size())
}
// Closed returns true if this Client is closed.
func (m *ClientWorker) Closed() bool {
return m.done.Done()
}
func (m *ClientWorker) monitor() {
timer := time.NewTicker(time.Second * 16)
defer timer.Stop()
for {
select {
case <-m.done.Wait():
m.sessionManager.Close()
common.Close(m.link.Writer)
common.Interrupt(m.link.Reader)
return
case <-timer.C:
size := m.sessionManager.Size()
if size == 0 && m.sessionManager.CloseIfNoSession() {
common.Must(m.done.Close())
}
}
}
}
func writeFirstPayload(reader buf.Reader, writer *Writer) error {
err := buf.CopyOnceTimeout(reader, writer, time.Millisecond*100)
if err == buf.ErrNotTimeoutReader || err == buf.ErrReadTimeout {
return writer.WriteMultiBuffer(buf.MultiBuffer{})
}
if err != nil {
return err
}
return nil
}
func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
dest := session.OutboundFromContext(ctx).Target
transferType := protocol.TransferTypeStream
if dest.Network == net.Network_UDP {
transferType = protocol.TransferTypePacket
}
s.transferType = transferType
writer := NewWriter(s.ID, dest, output, transferType)
defer s.Close()
defer writer.Close()
newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx))
if err := writeFirstPayload(s.input, writer); err != nil {
newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
writer.hasError = true
common.Interrupt(s.input)
return
}
if err := buf.Copy(s.input, writer); err != nil {
newError("failed to fetch all input").Base(err).WriteToLog(session.ExportIDToError(ctx))
writer.hasError = true
common.Interrupt(s.input)
return
}
}
func (m *ClientWorker) IsClosing() bool {
sm := m.sessionManager
if m.strategy.MaxConnection > 0 && sm.Count() >= int(m.strategy.MaxConnection) {
return true
}
return false
}
func (m *ClientWorker) IsFull() bool {
if m.IsClosing() || m.Closed() {
return true
}
sm := m.sessionManager
if m.strategy.MaxConcurrency > 0 && sm.Size() >= int(m.strategy.MaxConcurrency) {
return true
}
return false
}
func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool {
if m.IsFull() || m.Closed() {
return false
}
sm := m.sessionManager
s := sm.Allocate()
if s == nil {
return false
}
s.input = link.Reader
s.output = link.Writer
go fetchInput(ctx, s, m.link.Writer)
return true
}
func (m *ClientWorker) handleStatueKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
if meta.Option.Has(OptionData) {
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
return nil
}
func (m *ClientWorker) handleStatusNew(meta *FrameMetadata, reader *buf.BufferedReader) error {
if meta.Option.Has(OptionData) {
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
return nil
}
func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
if !meta.Option.Has(OptionData) {
return nil
}
s, found := m.sessionManager.Get(meta.SessionID)
if !found {
// Notify remote peer to close this session.
closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
closingWriter.Close()
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
rr := s.NewReader(reader)
err := buf.Copy(rr, s.output)
if err != nil && buf.IsWriteError(err) {
newError("failed to write to downstream. closing session ", s.ID).Base(err).WriteToLog()
// Notify remote peer to close this session.
closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
closingWriter.Close()
drainErr := buf.Copy(rr, buf.Discard)
common.Interrupt(s.input)
s.Close()
return drainErr
}
return err
}
func (m *ClientWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
if s, found := m.sessionManager.Get(meta.SessionID); found {
if meta.Option.Has(OptionError) {
common.Interrupt(s.input)
common.Interrupt(s.output)
}
s.Close()
}
if meta.Option.Has(OptionData) {
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
return nil
}
func (m *ClientWorker) fetchOutput() {
defer func() {
common.Must(m.done.Close())
}()
reader := &buf.BufferedReader{Reader: m.link.Reader}
var meta FrameMetadata
for {
err := meta.Unmarshal(reader)
if err != nil {
if errors.Cause(err) != io.EOF {
newError("failed to read metadata").Base(err).WriteToLog()
}
break
}
switch meta.SessionStatus {
case SessionStatusKeepAlive:
err = m.handleStatueKeepAlive(&meta, reader)
case SessionStatusEnd:
err = m.handleStatusEnd(&meta, reader)
case SessionStatusNew:
err = m.handleStatusNew(&meta, reader)
case SessionStatusKeep:
err = m.handleStatusKeep(&meta, reader)
default:
status := meta.SessionStatus
newError("unknown status: ", status).AtError().WriteToLog()
return
}
if err != nil {
newError("failed to process data").Base(err).WriteToLog()
return
}
}
}

116
common/mux/client_test.go Normal file
View file

@ -0,0 +1,116 @@
package mux_test
import (
"context"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/errors"
"github.com/xtls/xray-core/v1/common/mux"
"github.com/xtls/xray-core/v1/common/net"
"github.com/xtls/xray-core/v1/common/session"
"github.com/xtls/xray-core/v1/testing/mocks"
"github.com/xtls/xray-core/v1/transport"
"github.com/xtls/xray-core/v1/transport/pipe"
)
func TestIncrementalPickerFailure(t *testing.T) {
mockCtl := gomock.NewController(t)
defer mockCtl.Finish()
mockWorkerFactory := mocks.NewMuxClientWorkerFactory(mockCtl)
mockWorkerFactory.EXPECT().Create().Return(nil, errors.New("test"))
picker := mux.IncrementalWorkerPicker{
Factory: mockWorkerFactory,
}
_, err := picker.PickAvailable()
if err == nil {
t.Error("expected error, but nil")
}
}
func TestClientWorkerEOF(t *testing.T) {
reader, writer := pipe.New(pipe.WithoutSizeLimit())
common.Must(writer.Close())
worker, err := mux.NewClientWorker(transport.Link{Reader: reader, Writer: writer}, mux.ClientStrategy{})
common.Must(err)
time.Sleep(time.Millisecond * 500)
f := worker.Dispatch(context.Background(), nil)
if f {
t.Error("expected failed dispatching, but actually not")
}
}
func TestClientWorkerClose(t *testing.T) {
mockCtl := gomock.NewController(t)
defer mockCtl.Finish()
r1, w1 := pipe.New(pipe.WithoutSizeLimit())
worker1, err := mux.NewClientWorker(transport.Link{
Reader: r1,
Writer: w1,
}, mux.ClientStrategy{
MaxConcurrency: 4,
MaxConnection: 4,
})
common.Must(err)
r2, w2 := pipe.New(pipe.WithoutSizeLimit())
worker2, err := mux.NewClientWorker(transport.Link{
Reader: r2,
Writer: w2,
}, mux.ClientStrategy{
MaxConcurrency: 4,
MaxConnection: 4,
})
common.Must(err)
factory := mocks.NewMuxClientWorkerFactory(mockCtl)
gomock.InOrder(
factory.EXPECT().Create().Return(worker1, nil),
factory.EXPECT().Create().Return(worker2, nil),
)
picker := &mux.IncrementalWorkerPicker{
Factory: factory,
}
manager := &mux.ClientManager{
Picker: picker,
}
tr1, tw1 := pipe.New(pipe.WithoutSizeLimit())
ctx1 := session.ContextWithOutbound(context.Background(), &session.Outbound{
Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80),
})
common.Must(manager.Dispatch(ctx1, &transport.Link{
Reader: tr1,
Writer: tw1,
}))
defer tw1.Close()
common.Must(w1.Close())
time.Sleep(time.Millisecond * 500)
if !worker1.Closed() {
t.Error("worker1 is not finished")
}
tr2, tw2 := pipe.New(pipe.WithoutSizeLimit())
ctx2 := session.ContextWithOutbound(context.Background(), &session.Outbound{
Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80),
})
common.Must(manager.Dispatch(ctx2, &transport.Link{
Reader: tr2,
Writer: tw2,
}))
defer tw2.Close()
common.Must(w2.Close())
}

View file

@ -0,0 +1,9 @@
package mux
import "github.com/xtls/xray-core/v1/common/errors"
type errPathObjHolder struct{}
func newError(values ...interface{}) *errors.Error {
return errors.New(values...).WithPathObj(errPathObjHolder{})
}

145
common/mux/frame.go Normal file
View file

@ -0,0 +1,145 @@
package mux
import (
"encoding/binary"
"io"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/bitmask"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/net"
"github.com/xtls/xray-core/v1/common/protocol"
"github.com/xtls/xray-core/v1/common/serial"
)
type SessionStatus byte
const (
SessionStatusNew SessionStatus = 0x01
SessionStatusKeep SessionStatus = 0x02
SessionStatusEnd SessionStatus = 0x03
SessionStatusKeepAlive SessionStatus = 0x04
)
const (
OptionData bitmask.Byte = 0x01
OptionError bitmask.Byte = 0x02
)
type TargetNetwork byte
const (
TargetNetworkTCP TargetNetwork = 0x01
TargetNetworkUDP TargetNetwork = 0x02
)
var addrParser = protocol.NewAddressParser(
protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv4), net.AddressFamilyIPv4),
protocol.AddressFamilyByte(byte(protocol.AddressTypeDomain), net.AddressFamilyDomain),
protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv6), net.AddressFamilyIPv6),
protocol.PortThenAddress(),
)
/*
Frame format
2 bytes - length
2 bytes - session id
1 bytes - status
1 bytes - option
1 byte - network
2 bytes - port
n bytes - address
*/
type FrameMetadata struct {
Target net.Destination
SessionID uint16
Option bitmask.Byte
SessionStatus SessionStatus
}
func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
lenBytes := b.Extend(2)
len0 := b.Len()
sessionBytes := b.Extend(2)
binary.BigEndian.PutUint16(sessionBytes, f.SessionID)
common.Must(b.WriteByte(byte(f.SessionStatus)))
common.Must(b.WriteByte(byte(f.Option)))
if f.SessionStatus == SessionStatusNew {
switch f.Target.Network {
case net.Network_TCP:
common.Must(b.WriteByte(byte(TargetNetworkTCP)))
case net.Network_UDP:
common.Must(b.WriteByte(byte(TargetNetworkUDP)))
}
if err := addrParser.WriteAddressPort(b, f.Target.Address, f.Target.Port); err != nil {
return err
}
}
len1 := b.Len()
binary.BigEndian.PutUint16(lenBytes, uint16(len1-len0))
return nil
}
// Unmarshal reads FrameMetadata from the given reader.
func (f *FrameMetadata) Unmarshal(reader io.Reader) error {
metaLen, err := serial.ReadUint16(reader)
if err != nil {
return err
}
if metaLen > 512 {
return newError("invalid metalen ", metaLen).AtError()
}
b := buf.New()
defer b.Release()
if _, err := b.ReadFullFrom(reader, int32(metaLen)); err != nil {
return err
}
return f.UnmarshalFromBuffer(b)
}
// UnmarshalFromBuffer reads a FrameMetadata from the given buffer.
// Visible for testing only.
func (f *FrameMetadata) UnmarshalFromBuffer(b *buf.Buffer) error {
if b.Len() < 4 {
return newError("insufficient buffer: ", b.Len())
}
f.SessionID = binary.BigEndian.Uint16(b.BytesTo(2))
f.SessionStatus = SessionStatus(b.Byte(2))
f.Option = bitmask.Byte(b.Byte(3))
f.Target.Network = net.Network_Unknown
if f.SessionStatus == SessionStatusNew {
if b.Len() < 8 {
return newError("insufficient buffer: ", b.Len())
}
network := TargetNetwork(b.Byte(4))
b.Advance(5)
addr, port, err := addrParser.ReadAddressPort(nil, b)
if err != nil {
return newError("failed to parse address and port").Base(err)
}
switch network {
case TargetNetworkTCP:
f.Target = net.TCPDestination(addr, port)
case TargetNetworkUDP:
f.Target = net.UDPDestination(addr, port)
default:
return newError("unknown network type: ", network)
}
}
return nil
}

25
common/mux/frame_test.go Normal file
View file

@ -0,0 +1,25 @@
package mux_test
import (
"testing"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/mux"
"github.com/xtls/xray-core/v1/common/net"
)
func BenchmarkFrameWrite(b *testing.B) {
frame := mux.FrameMetadata{
Target: net.TCPDestination(net.DomainAddress("www.example.com"), net.Port(80)),
SessionID: 1,
SessionStatus: mux.SessionStatusNew,
}
writer := buf.New()
defer writer.Release()
for i := 0; i < b.N; i++ {
common.Must(frame.WriteTo(writer))
writer.Clear()
}
}

3
common/mux/mux.go Normal file
View file

@ -0,0 +1,3 @@
package mux
//go:generate go run github.com/xtls/xray-core/v1/common/errors/errorgen

196
common/mux/mux_test.go Normal file
View file

@ -0,0 +1,196 @@
package mux_test
import (
"io"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
. "github.com/xtls/xray-core/v1/common/mux"
"github.com/xtls/xray-core/v1/common/net"
"github.com/xtls/xray-core/v1/common/protocol"
"github.com/xtls/xray-core/v1/transport/pipe"
)
func readAll(reader buf.Reader) (buf.MultiBuffer, error) {
var mb buf.MultiBuffer
for {
b, err := reader.ReadMultiBuffer()
if err == io.EOF {
break
}
if err != nil {
return nil, err
}
mb = append(mb, b...)
}
return mb, nil
}
func TestReaderWriter(t *testing.T) {
pReader, pWriter := pipe.New(pipe.WithSizeLimit(1024))
dest := net.TCPDestination(net.DomainAddress("example.com"), 80)
writer := NewWriter(1, dest, pWriter, protocol.TransferTypeStream)
dest2 := net.TCPDestination(net.LocalHostIP, 443)
writer2 := NewWriter(2, dest2, pWriter, protocol.TransferTypeStream)
dest3 := net.TCPDestination(net.LocalHostIPv6, 18374)
writer3 := NewWriter(3, dest3, pWriter, protocol.TransferTypeStream)
writePayload := func(writer *Writer, payload ...byte) error {
b := buf.New()
b.Write(payload)
return writer.WriteMultiBuffer(buf.MultiBuffer{b})
}
common.Must(writePayload(writer, 'a', 'b', 'c', 'd'))
common.Must(writePayload(writer2))
common.Must(writePayload(writer, 'e', 'f', 'g', 'h'))
common.Must(writePayload(writer3, 'x'))
writer.Close()
writer3.Close()
common.Must(writePayload(writer2, 'y'))
writer2.Close()
bytesReader := &buf.BufferedReader{Reader: pReader}
{
var meta FrameMetadata
common.Must(meta.Unmarshal(bytesReader))
if r := cmp.Diff(meta, FrameMetadata{
SessionID: 1,
SessionStatus: SessionStatusNew,
Target: dest,
Option: OptionData,
}); r != "" {
t.Error("metadata: ", r)
}
data, err := readAll(NewStreamReader(bytesReader))
common.Must(err)
if s := data.String(); s != "abcd" {
t.Error("data: ", s)
}
}
{
var meta FrameMetadata
common.Must(meta.Unmarshal(bytesReader))
if r := cmp.Diff(meta, FrameMetadata{
SessionStatus: SessionStatusNew,
SessionID: 2,
Option: 0,
Target: dest2,
}); r != "" {
t.Error("meta: ", r)
}
}
{
var meta FrameMetadata
common.Must(meta.Unmarshal(bytesReader))
if r := cmp.Diff(meta, FrameMetadata{
SessionID: 1,
SessionStatus: SessionStatusKeep,
Option: 1,
}); r != "" {
t.Error("meta: ", r)
}
data, err := readAll(NewStreamReader(bytesReader))
common.Must(err)
if s := data.String(); s != "efgh" {
t.Error("data: ", s)
}
}
{
var meta FrameMetadata
common.Must(meta.Unmarshal(bytesReader))
if r := cmp.Diff(meta, FrameMetadata{
SessionID: 3,
SessionStatus: SessionStatusNew,
Option: 1,
Target: dest3,
}); r != "" {
t.Error("meta: ", r)
}
data, err := readAll(NewStreamReader(bytesReader))
common.Must(err)
if s := data.String(); s != "x" {
t.Error("data: ", s)
}
}
{
var meta FrameMetadata
common.Must(meta.Unmarshal(bytesReader))
if r := cmp.Diff(meta, FrameMetadata{
SessionID: 1,
SessionStatus: SessionStatusEnd,
Option: 0,
}); r != "" {
t.Error("meta: ", r)
}
}
{
var meta FrameMetadata
common.Must(meta.Unmarshal(bytesReader))
if r := cmp.Diff(meta, FrameMetadata{
SessionID: 3,
SessionStatus: SessionStatusEnd,
Option: 0,
}); r != "" {
t.Error("meta: ", r)
}
}
{
var meta FrameMetadata
common.Must(meta.Unmarshal(bytesReader))
if r := cmp.Diff(meta, FrameMetadata{
SessionID: 2,
SessionStatus: SessionStatusKeep,
Option: 1,
}); r != "" {
t.Error("meta: ", r)
}
data, err := readAll(NewStreamReader(bytesReader))
common.Must(err)
if s := data.String(); s != "y" {
t.Error("data: ", s)
}
}
{
var meta FrameMetadata
common.Must(meta.Unmarshal(bytesReader))
if r := cmp.Diff(meta, FrameMetadata{
SessionID: 2,
SessionStatus: SessionStatusEnd,
Option: 0,
}); r != "" {
t.Error("meta: ", r)
}
}
pWriter.Close()
{
var meta FrameMetadata
err := meta.Unmarshal(bytesReader)
if err == nil {
t.Error("nil error")
}
}
}

52
common/mux/reader.go Normal file
View file

@ -0,0 +1,52 @@
package mux
import (
"io"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/crypto"
"github.com/xtls/xray-core/v1/common/serial"
)
// PacketReader is an io.Reader that reads whole chunk of Mux frames every time.
type PacketReader struct {
reader io.Reader
eof bool
}
// NewPacketReader creates a new PacketReader.
func NewPacketReader(reader io.Reader) *PacketReader {
return &PacketReader{
reader: reader,
eof: false,
}
}
// ReadMultiBuffer implements buf.Reader.
func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
if r.eof {
return nil, io.EOF
}
size, err := serial.ReadUint16(r.reader)
if err != nil {
return nil, err
}
if size > buf.Size {
return nil, newError("packet size too large: ", size)
}
b := buf.New()
if _, err := b.ReadFullFrom(r.reader, int32(size)); err != nil {
b.Release()
return nil, err
}
r.eof = true
return buf.MultiBuffer{b}, nil
}
// NewStreamReader creates a new StreamReader.
func NewStreamReader(reader *buf.BufferedReader) buf.Reader {
return crypto.NewChunkStreamReaderWithChunkCount(crypto.PlainChunkSizeParser{}, reader, 1)
}

252
common/mux/server.go Normal file
View file

@ -0,0 +1,252 @@
package mux
import (
"context"
"io"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/errors"
"github.com/xtls/xray-core/v1/common/log"
"github.com/xtls/xray-core/v1/common/net"
"github.com/xtls/xray-core/v1/common/protocol"
"github.com/xtls/xray-core/v1/common/session"
"github.com/xtls/xray-core/v1/core"
"github.com/xtls/xray-core/v1/features/routing"
"github.com/xtls/xray-core/v1/transport"
"github.com/xtls/xray-core/v1/transport/pipe"
)
type Server struct {
dispatcher routing.Dispatcher
}
// NewServer creates a new mux.Server.
func NewServer(ctx context.Context) *Server {
s := &Server{}
core.RequireFeatures(ctx, func(d routing.Dispatcher) {
s.dispatcher = d
})
return s
}
// Type implements common.HasType.
func (s *Server) Type() interface{} {
return s.dispatcher.Type()
}
// Dispatch implements routing.Dispatcher
func (s *Server) Dispatch(ctx context.Context, dest net.Destination) (*transport.Link, error) {
if dest.Address != muxCoolAddress {
return s.dispatcher.Dispatch(ctx, dest)
}
opts := pipe.OptionsFromContext(ctx)
uplinkReader, uplinkWriter := pipe.New(opts...)
downlinkReader, downlinkWriter := pipe.New(opts...)
_, err := NewServerWorker(ctx, s.dispatcher, &transport.Link{
Reader: uplinkReader,
Writer: downlinkWriter,
})
if err != nil {
return nil, err
}
return &transport.Link{Reader: downlinkReader, Writer: uplinkWriter}, nil
}
// Start implements common.Runnable.
func (s *Server) Start() error {
return nil
}
// Close implements common.Closable.
func (s *Server) Close() error {
return nil
}
type ServerWorker struct {
dispatcher routing.Dispatcher
link *transport.Link
sessionManager *SessionManager
}
func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.Link) (*ServerWorker, error) {
worker := &ServerWorker{
dispatcher: d,
link: link,
sessionManager: NewSessionManager(),
}
go worker.run(ctx)
return worker, nil
}
func handle(ctx context.Context, s *Session, output buf.Writer) {
writer := NewResponseWriter(s.ID, output, s.transferType)
if err := buf.Copy(s.input, writer); err != nil {
newError("session ", s.ID, " ends.").Base(err).WriteToLog(session.ExportIDToError(ctx))
writer.hasError = true
}
writer.Close()
s.Close()
}
func (w *ServerWorker) ActiveConnections() uint32 {
return uint32(w.sessionManager.Size())
}
func (w *ServerWorker) Closed() bool {
return w.sessionManager.Closed()
}
func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
if meta.Option.Has(OptionData) {
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
return nil
}
func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader *buf.BufferedReader) error {
newError("received request for ", meta.Target).WriteToLog(session.ExportIDToError(ctx))
{
msg := &log.AccessMessage{
To: meta.Target,
Status: log.AccessAccepted,
Reason: "",
}
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() {
msg.From = inbound.Source
msg.Email = inbound.User.Email
}
ctx = log.ContextWithAccessMessage(ctx, msg)
}
link, err := w.dispatcher.Dispatch(ctx, meta.Target)
if err != nil {
if meta.Option.Has(OptionData) {
buf.Copy(NewStreamReader(reader), buf.Discard)
}
return newError("failed to dispatch request.").Base(err)
}
s := &Session{
input: link.Reader,
output: link.Writer,
parent: w.sessionManager,
ID: meta.SessionID,
transferType: protocol.TransferTypeStream,
}
if meta.Target.Network == net.Network_UDP {
s.transferType = protocol.TransferTypePacket
}
w.sessionManager.Add(s)
go handle(ctx, s, w.link.Writer)
if !meta.Option.Has(OptionData) {
return nil
}
rr := s.NewReader(reader)
if err := buf.Copy(rr, s.output); err != nil {
buf.Copy(rr, buf.Discard)
common.Interrupt(s.input)
return s.Close()
}
return nil
}
func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
if !meta.Option.Has(OptionData) {
return nil
}
s, found := w.sessionManager.Get(meta.SessionID)
if !found {
// Notify remote peer to close this session.
closingWriter := NewResponseWriter(meta.SessionID, w.link.Writer, protocol.TransferTypeStream)
closingWriter.Close()
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
rr := s.NewReader(reader)
err := buf.Copy(rr, s.output)
if err != nil && buf.IsWriteError(err) {
newError("failed to write to downstream writer. closing session ", s.ID).Base(err).WriteToLog()
// Notify remote peer to close this session.
closingWriter := NewResponseWriter(meta.SessionID, w.link.Writer, protocol.TransferTypeStream)
closingWriter.Close()
drainErr := buf.Copy(rr, buf.Discard)
common.Interrupt(s.input)
s.Close()
return drainErr
}
return err
}
func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
if s, found := w.sessionManager.Get(meta.SessionID); found {
if meta.Option.Has(OptionError) {
common.Interrupt(s.input)
common.Interrupt(s.output)
}
s.Close()
}
if meta.Option.Has(OptionData) {
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
return nil
}
func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedReader) error {
var meta FrameMetadata
err := meta.Unmarshal(reader)
if err != nil {
return newError("failed to read metadata").Base(err)
}
switch meta.SessionStatus {
case SessionStatusKeepAlive:
err = w.handleStatusKeepAlive(&meta, reader)
case SessionStatusEnd:
err = w.handleStatusEnd(&meta, reader)
case SessionStatusNew:
err = w.handleStatusNew(ctx, &meta, reader)
case SessionStatusKeep:
err = w.handleStatusKeep(&meta, reader)
default:
status := meta.SessionStatus
return newError("unknown status: ", status).AtError()
}
if err != nil {
return newError("failed to process data").Base(err)
}
return nil
}
func (w *ServerWorker) run(ctx context.Context) {
input := w.link.Reader
reader := &buf.BufferedReader{Reader: input}
defer w.sessionManager.Close()
for {
select {
case <-ctx.Done():
return
default:
err := w.handleFrame(ctx, reader)
if err != nil {
if errors.Cause(err) != io.EOF {
newError("unexpected EOF").Base(err).WriteToLog(session.ExportIDToError(ctx))
common.Interrupt(input)
}
return
}
}
}
}

160
common/mux/session.go Normal file
View file

@ -0,0 +1,160 @@
package mux
import (
"sync"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/protocol"
)
type SessionManager struct {
sync.RWMutex
sessions map[uint16]*Session
count uint16
closed bool
}
func NewSessionManager() *SessionManager {
return &SessionManager{
count: 0,
sessions: make(map[uint16]*Session, 16),
}
}
func (m *SessionManager) Closed() bool {
m.RLock()
defer m.RUnlock()
return m.closed
}
func (m *SessionManager) Size() int {
m.RLock()
defer m.RUnlock()
return len(m.sessions)
}
func (m *SessionManager) Count() int {
m.RLock()
defer m.RUnlock()
return int(m.count)
}
func (m *SessionManager) Allocate() *Session {
m.Lock()
defer m.Unlock()
if m.closed {
return nil
}
m.count++
s := &Session{
ID: m.count,
parent: m,
}
m.sessions[s.ID] = s
return s
}
func (m *SessionManager) Add(s *Session) {
m.Lock()
defer m.Unlock()
if m.closed {
return
}
m.count++
m.sessions[s.ID] = s
}
func (m *SessionManager) Remove(id uint16) {
m.Lock()
defer m.Unlock()
if m.closed {
return
}
delete(m.sessions, id)
if len(m.sessions) == 0 {
m.sessions = make(map[uint16]*Session, 16)
}
}
func (m *SessionManager) Get(id uint16) (*Session, bool) {
m.RLock()
defer m.RUnlock()
if m.closed {
return nil, false
}
s, found := m.sessions[id]
return s, found
}
func (m *SessionManager) CloseIfNoSession() bool {
m.Lock()
defer m.Unlock()
if m.closed {
return true
}
if len(m.sessions) != 0 {
return false
}
m.closed = true
return true
}
func (m *SessionManager) Close() error {
m.Lock()
defer m.Unlock()
if m.closed {
return nil
}
m.closed = true
for _, s := range m.sessions {
common.Close(s.input)
common.Close(s.output)
}
m.sessions = nil
return nil
}
// Session represents a client connection in a Mux connection.
type Session struct {
input buf.Reader
output buf.Writer
parent *SessionManager
ID uint16
transferType protocol.TransferType
}
// Close closes all resources associated with this session.
func (s *Session) Close() error {
common.Close(s.output)
common.Close(s.input)
s.parent.Remove(s.ID)
return nil
}
// NewReader creates a buf.Reader based on the transfer type of this Session.
func (s *Session) NewReader(reader *buf.BufferedReader) buf.Reader {
if s.transferType == protocol.TransferTypeStream {
return NewStreamReader(reader)
}
return NewPacketReader(reader)
}

View file

@ -0,0 +1,51 @@
package mux_test
import (
"testing"
. "github.com/xtls/xray-core/v1/common/mux"
)
func TestSessionManagerAdd(t *testing.T) {
m := NewSessionManager()
s := m.Allocate()
if s.ID != 1 {
t.Error("id: ", s.ID)
}
if m.Size() != 1 {
t.Error("size: ", m.Size())
}
s = m.Allocate()
if s.ID != 2 {
t.Error("id: ", s.ID)
}
if m.Size() != 2 {
t.Error("size: ", m.Size())
}
s = &Session{
ID: 4,
}
m.Add(s)
if s.ID != 4 {
t.Error("id: ", s.ID)
}
if m.Size() != 3 {
t.Error("size: ", m.Size())
}
}
func TestSessionManagerClose(t *testing.T) {
m := NewSessionManager()
s := m.Allocate()
if m.CloseIfNoSession() {
t.Error("able to close")
}
m.Remove(s.ID)
if !m.CloseIfNoSession() {
t.Error("not able to close")
}
}

126
common/mux/writer.go Normal file
View file

@ -0,0 +1,126 @@
package mux
import (
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/net"
"github.com/xtls/xray-core/v1/common/protocol"
"github.com/xtls/xray-core/v1/common/serial"
)
type Writer struct {
dest net.Destination
writer buf.Writer
id uint16
followup bool
hasError bool
transferType protocol.TransferType
}
func NewWriter(id uint16, dest net.Destination, writer buf.Writer, transferType protocol.TransferType) *Writer {
return &Writer{
id: id,
dest: dest,
writer: writer,
followup: false,
transferType: transferType,
}
}
func NewResponseWriter(id uint16, writer buf.Writer, transferType protocol.TransferType) *Writer {
return &Writer{
id: id,
writer: writer,
followup: true,
transferType: transferType,
}
}
func (w *Writer) getNextFrameMeta() FrameMetadata {
meta := FrameMetadata{
SessionID: w.id,
Target: w.dest,
}
if w.followup {
meta.SessionStatus = SessionStatusKeep
} else {
w.followup = true
meta.SessionStatus = SessionStatusNew
}
return meta
}
func (w *Writer) writeMetaOnly() error {
meta := w.getNextFrameMeta()
b := buf.New()
if err := meta.WriteTo(b); err != nil {
return err
}
return w.writer.WriteMultiBuffer(buf.MultiBuffer{b})
}
func writeMetaWithFrame(writer buf.Writer, meta FrameMetadata, data buf.MultiBuffer) error {
frame := buf.New()
if err := meta.WriteTo(frame); err != nil {
return err
}
if _, err := serial.WriteUint16(frame, uint16(data.Len())); err != nil {
return err
}
mb2 := make(buf.MultiBuffer, 0, len(data)+1)
mb2 = append(mb2, frame)
mb2 = append(mb2, data...)
return writer.WriteMultiBuffer(mb2)
}
func (w *Writer) writeData(mb buf.MultiBuffer) error {
meta := w.getNextFrameMeta()
meta.Option.Set(OptionData)
return writeMetaWithFrame(w.writer, meta, mb)
}
// WriteMultiBuffer implements buf.Writer.
func (w *Writer) WriteMultiBuffer(mb buf.MultiBuffer) error {
defer buf.ReleaseMulti(mb)
if mb.IsEmpty() {
return w.writeMetaOnly()
}
for !mb.IsEmpty() {
var chunk buf.MultiBuffer
if w.transferType == protocol.TransferTypeStream {
mb, chunk = buf.SplitSize(mb, 8*1024)
} else {
mb2, b := buf.SplitFirst(mb)
mb = mb2
chunk = buf.MultiBuffer{b}
}
if err := w.writeData(chunk); err != nil {
return err
}
}
return nil
}
// Close implements common.Closable.
func (w *Writer) Close() error {
meta := FrameMetadata{
SessionID: w.id,
SessionStatus: SessionStatusEnd,
}
if w.hasError {
meta.Option.Set(OptionError)
}
frame := buf.New()
common.Must(meta.WriteTo(frame))
w.writer.WriteMultiBuffer(buf.MultiBuffer{frame})
return nil
}

211
common/net/address.go Normal file
View file

@ -0,0 +1,211 @@
package net
import (
"bytes"
"net"
"strings"
)
var (
// LocalHostIP is a constant value for localhost IP in IPv4.
LocalHostIP = IPAddress([]byte{127, 0, 0, 1})
// AnyIP is a constant value for any IP in IPv4.
AnyIP = IPAddress([]byte{0, 0, 0, 0})
// LocalHostDomain is a constant value for localhost domain.
LocalHostDomain = DomainAddress("localhost")
// LocalHostIPv6 is a constant value for localhost IP in IPv6.
LocalHostIPv6 = IPAddress([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
// AnyIPv6 is a constant value for any IP in IPv6.
AnyIPv6 = IPAddress([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
)
// AddressFamily is the type of address.
type AddressFamily byte
const (
// AddressFamilyIPv4 represents address as IPv4
AddressFamilyIPv4 = AddressFamily(0)
// AddressFamilyIPv6 represents address as IPv6
AddressFamilyIPv6 = AddressFamily(1)
// AddressFamilyDomain represents address as Domain
AddressFamilyDomain = AddressFamily(2)
)
// IsIPv4 returns true if current AddressFamily is IPv4.
func (af AddressFamily) IsIPv4() bool {
return af == AddressFamilyIPv4
}
// IsIPv6 returns true if current AddressFamily is IPv6.
func (af AddressFamily) IsIPv6() bool {
return af == AddressFamilyIPv6
}
// IsIP returns true if current AddressFamily is IPv6 or IPv4.
func (af AddressFamily) IsIP() bool {
return af == AddressFamilyIPv4 || af == AddressFamilyIPv6
}
// IsDomain returns true if current AddressFamily is Domain.
func (af AddressFamily) IsDomain() bool {
return af == AddressFamilyDomain
}
// Address represents a network address to be communicated with. It may be an IP address or domain
// address, not both. This interface doesn't resolve IP address for a given domain.
type Address interface {
IP() net.IP // IP of this Address
Domain() string // Domain of this Address
Family() AddressFamily
String() string // String representation of this Address
}
func isAlphaNum(c byte) bool {
return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
}
// ParseAddress parses a string into an Address. The return value will be an IPAddress when
// the string is in the form of IPv4 or IPv6 address, or a DomainAddress otherwise.
func ParseAddress(addr string) Address {
// Handle IPv6 address in form as "[2001:4860:0:2001::68]"
lenAddr := len(addr)
if lenAddr > 0 && addr[0] == '[' && addr[lenAddr-1] == ']' {
addr = addr[1 : lenAddr-1]
lenAddr -= 2
}
if lenAddr > 0 && (!isAlphaNum(addr[0]) || !isAlphaNum(addr[len(addr)-1])) {
addr = strings.TrimSpace(addr)
}
ip := net.ParseIP(addr)
if ip != nil {
return IPAddress(ip)
}
return DomainAddress(addr)
}
var bytes0 = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
// IPAddress creates an Address with given IP.
func IPAddress(ip []byte) Address {
switch len(ip) {
case net.IPv4len:
var addr ipv4Address = [4]byte{ip[0], ip[1], ip[2], ip[3]}
return addr
case net.IPv6len:
if bytes.Equal(ip[:10], bytes0) && ip[10] == 0xff && ip[11] == 0xff {
return IPAddress(ip[12:16])
}
var addr ipv6Address = [16]byte{
ip[0], ip[1], ip[2], ip[3],
ip[4], ip[5], ip[6], ip[7],
ip[8], ip[9], ip[10], ip[11],
ip[12], ip[13], ip[14], ip[15],
}
return addr
default:
newError("invalid IP format: ", ip).AtError().WriteToLog()
return nil
}
}
// DomainAddress creates an Address with given domain.
func DomainAddress(domain string) Address {
return domainAddress(domain)
}
type ipv4Address [4]byte
func (a ipv4Address) IP() net.IP {
return net.IP(a[:])
}
func (ipv4Address) Domain() string {
panic("Calling Domain() on an IPv4Address.")
}
func (ipv4Address) Family() AddressFamily {
return AddressFamilyIPv4
}
func (a ipv4Address) String() string {
return a.IP().String()
}
type ipv6Address [16]byte
func (a ipv6Address) IP() net.IP {
return net.IP(a[:])
}
func (ipv6Address) Domain() string {
panic("Calling Domain() on an IPv6Address.")
}
func (ipv6Address) Family() AddressFamily {
return AddressFamilyIPv6
}
func (a ipv6Address) String() string {
return "[" + a.IP().String() + "]"
}
type domainAddress string
func (domainAddress) IP() net.IP {
panic("Calling IP() on a DomainAddress.")
}
func (a domainAddress) Domain() string {
return string(a)
}
func (domainAddress) Family() AddressFamily {
return AddressFamilyDomain
}
func (a domainAddress) String() string {
return a.Domain()
}
// AsAddress translates IPOrDomain to Address.
func (d *IPOrDomain) AsAddress() Address {
if d == nil {
return nil
}
switch addr := d.Address.(type) {
case *IPOrDomain_Ip:
return IPAddress(addr.Ip)
case *IPOrDomain_Domain:
return DomainAddress(addr.Domain)
}
panic("Common|Net: Invalid address.")
}
// NewIPOrDomain translates Address to IPOrDomain
func NewIPOrDomain(addr Address) *IPOrDomain {
switch addr.Family() {
case AddressFamilyDomain:
return &IPOrDomain{
Address: &IPOrDomain_Domain{
Domain: addr.Domain(),
},
}
case AddressFamilyIPv4, AddressFamilyIPv6:
return &IPOrDomain{
Address: &IPOrDomain_Ip{
Ip: addr.IP(),
},
}
default:
panic("Unknown Address type.")
}
}

195
common/net/address.pb.go Normal file
View file

@ -0,0 +1,195 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.25.0
// protoc v3.14.0
// source: common/net/address.proto
package net
import (
proto "github.com/golang/protobuf/proto"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
// Address of a network host. It may be either an IP address or a domain
// address.
type IPOrDomain struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Types that are assignable to Address:
// *IPOrDomain_Ip
// *IPOrDomain_Domain
Address isIPOrDomain_Address `protobuf_oneof:"address"`
}
func (x *IPOrDomain) Reset() {
*x = IPOrDomain{}
if protoimpl.UnsafeEnabled {
mi := &file_common_net_address_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *IPOrDomain) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*IPOrDomain) ProtoMessage() {}
func (x *IPOrDomain) ProtoReflect() protoreflect.Message {
mi := &file_common_net_address_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use IPOrDomain.ProtoReflect.Descriptor instead.
func (*IPOrDomain) Descriptor() ([]byte, []int) {
return file_common_net_address_proto_rawDescGZIP(), []int{0}
}
func (m *IPOrDomain) GetAddress() isIPOrDomain_Address {
if m != nil {
return m.Address
}
return nil
}
func (x *IPOrDomain) GetIp() []byte {
if x, ok := x.GetAddress().(*IPOrDomain_Ip); ok {
return x.Ip
}
return nil
}
func (x *IPOrDomain) GetDomain() string {
if x, ok := x.GetAddress().(*IPOrDomain_Domain); ok {
return x.Domain
}
return ""
}
type isIPOrDomain_Address interface {
isIPOrDomain_Address()
}
type IPOrDomain_Ip struct {
// IP address. Must by either 4 or 16 bytes.
Ip []byte `protobuf:"bytes,1,opt,name=ip,proto3,oneof"`
}
type IPOrDomain_Domain struct {
// Domain address.
Domain string `protobuf:"bytes,2,opt,name=domain,proto3,oneof"`
}
func (*IPOrDomain_Ip) isIPOrDomain_Address() {}
func (*IPOrDomain_Domain) isIPOrDomain_Address() {}
var File_common_net_address_proto protoreflect.FileDescriptor
var file_common_net_address_proto_rawDesc = []byte{
0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x6e, 0x65, 0x74, 0x2f, 0x61, 0x64, 0x64,
0x72, 0x65, 0x73, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0f, 0x78, 0x72, 0x61, 0x79,
0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x6e, 0x65, 0x74, 0x22, 0x43, 0x0a, 0x0a, 0x49,
0x50, 0x4f, 0x72, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x10, 0x0a, 0x02, 0x69, 0x70, 0x18,
0x01, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x02, 0x69, 0x70, 0x12, 0x18, 0x0a, 0x06, 0x64,
0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x06, 0x64,
0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x42, 0x09, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73,
0x42, 0x52, 0x0a, 0x13, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d,
0x6d, 0x6f, 0x6e, 0x2e, 0x6e, 0x65, 0x74, 0x50, 0x01, 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75,
0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d,
0x63, 0x6f, 0x72, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x6e,
0x65, 0x74, 0xaa, 0x02, 0x0f, 0x58, 0x72, 0x61, 0x79, 0x2e, 0x43, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e,
0x2e, 0x4e, 0x65, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_common_net_address_proto_rawDescOnce sync.Once
file_common_net_address_proto_rawDescData = file_common_net_address_proto_rawDesc
)
func file_common_net_address_proto_rawDescGZIP() []byte {
file_common_net_address_proto_rawDescOnce.Do(func() {
file_common_net_address_proto_rawDescData = protoimpl.X.CompressGZIP(file_common_net_address_proto_rawDescData)
})
return file_common_net_address_proto_rawDescData
}
var file_common_net_address_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_common_net_address_proto_goTypes = []interface{}{
(*IPOrDomain)(nil), // 0: xray.common.net.IPOrDomain
}
var file_common_net_address_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_common_net_address_proto_init() }
func file_common_net_address_proto_init() {
if File_common_net_address_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_common_net_address_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*IPOrDomain); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_common_net_address_proto_msgTypes[0].OneofWrappers = []interface{}{
(*IPOrDomain_Ip)(nil),
(*IPOrDomain_Domain)(nil),
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_common_net_address_proto_rawDesc,
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_common_net_address_proto_goTypes,
DependencyIndexes: file_common_net_address_proto_depIdxs,
MessageInfos: file_common_net_address_proto_msgTypes,
}.Build()
File_common_net_address_proto = out.File
file_common_net_address_proto_rawDesc = nil
file_common_net_address_proto_goTypes = nil
file_common_net_address_proto_depIdxs = nil
}

19
common/net/address.proto Normal file
View file

@ -0,0 +1,19 @@
syntax = "proto3";
package xray.common.net;
option csharp_namespace = "Xray.Common.Net";
option go_package = "github.com/xtls/xray-core/v1/common/net";
option java_package = "com.xray.common.net";
option java_multiple_files = true;
// Address of a network host. It may be either an IP address or a domain
// address.
message IPOrDomain {
oneof address {
// IP address. Must by either 4 or 16 bytes.
bytes ip = 1;
// Domain address.
string domain = 2;
}
}

194
common/net/address_test.go Normal file
View file

@ -0,0 +1,194 @@
package net_test
import (
"net"
"testing"
"github.com/google/go-cmp/cmp"
. "github.com/xtls/xray-core/v1/common/net"
)
func TestAddressProperty(t *testing.T) {
type addrProprty struct {
IP []byte
Domain string
Family AddressFamily
String string
}
testCases := []struct {
Input Address
Output addrProprty
}{
{
Input: IPAddress([]byte{byte(1), byte(2), byte(3), byte(4)}),
Output: addrProprty{
IP: []byte{byte(1), byte(2), byte(3), byte(4)},
Family: AddressFamilyIPv4,
String: "1.2.3.4",
},
},
{
Input: IPAddress([]byte{
byte(1), byte(2), byte(3), byte(4),
byte(1), byte(2), byte(3), byte(4),
byte(1), byte(2), byte(3), byte(4),
byte(1), byte(2), byte(3), byte(4),
}),
Output: addrProprty{
IP: []byte{
byte(1), byte(2), byte(3), byte(4),
byte(1), byte(2), byte(3), byte(4),
byte(1), byte(2), byte(3), byte(4),
byte(1), byte(2), byte(3), byte(4),
},
Family: AddressFamilyIPv6,
String: "[102:304:102:304:102:304:102:304]",
},
},
{
Input: IPAddress([]byte{
byte(0), byte(0), byte(0), byte(0),
byte(0), byte(0), byte(0), byte(0),
byte(0), byte(0), byte(255), byte(255),
byte(1), byte(2), byte(3), byte(4),
}),
Output: addrProprty{
IP: []byte{byte(1), byte(2), byte(3), byte(4)},
Family: AddressFamilyIPv4,
String: "1.2.3.4",
},
},
{
Input: DomainAddress("example.com"),
Output: addrProprty{
Domain: "example.com",
Family: AddressFamilyDomain,
String: "example.com",
},
},
{
Input: IPAddress(net.IPv4(1, 2, 3, 4)),
Output: addrProprty{
IP: []byte{byte(1), byte(2), byte(3), byte(4)},
Family: AddressFamilyIPv4,
String: "1.2.3.4",
},
},
{
Input: ParseAddress("[2001:4860:0:2001::68]"),
Output: addrProprty{
IP: []byte{0x20, 0x01, 0x48, 0x60, 0x00, 0x00, 0x20, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x68},
Family: AddressFamilyIPv6,
String: "[2001:4860:0:2001::68]",
},
},
{
Input: ParseAddress("::0"),
Output: addrProprty{
IP: AnyIPv6.IP(),
Family: AddressFamilyIPv6,
String: "[::]",
},
},
{
Input: ParseAddress("[::ffff:123.151.71.143]"),
Output: addrProprty{
IP: []byte{123, 151, 71, 143},
Family: AddressFamilyIPv4,
String: "123.151.71.143",
},
},
{
Input: NewIPOrDomain(ParseAddress("example.com")).AsAddress(),
Output: addrProprty{
Domain: "example.com",
Family: AddressFamilyDomain,
String: "example.com",
},
},
{
Input: NewIPOrDomain(ParseAddress("8.8.8.8")).AsAddress(),
Output: addrProprty{
IP: []byte{8, 8, 8, 8},
Family: AddressFamilyIPv4,
String: "8.8.8.8",
},
},
{
Input: NewIPOrDomain(ParseAddress("[2001:4860:0:2001::68]")).AsAddress(),
Output: addrProprty{
IP: []byte{0x20, 0x01, 0x48, 0x60, 0x00, 0x00, 0x20, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x68},
Family: AddressFamilyIPv6,
String: "[2001:4860:0:2001::68]",
},
},
}
for _, testCase := range testCases {
actual := addrProprty{
Family: testCase.Input.Family(),
String: testCase.Input.String(),
}
if testCase.Input.Family().IsIP() {
actual.IP = testCase.Input.IP()
} else {
actual.Domain = testCase.Input.Domain()
}
if r := cmp.Diff(actual, testCase.Output); r != "" {
t.Error("for input: ", testCase.Input, ":", r)
}
}
}
func TestInvalidAddressConvertion(t *testing.T) {
panics := func(f func()) (ret bool) {
defer func() {
if r := recover(); r != nil {
ret = true
}
}()
f()
return false
}
testCases := []func(){
func() { ParseAddress("8.8.8.8").Domain() },
func() { ParseAddress("2001:4860:0:2001::68").Domain() },
func() { ParseAddress("example.com").IP() },
}
for idx, testCase := range testCases {
if !panics(testCase) {
t.Error("case ", idx, " failed")
}
}
}
func BenchmarkParseAddressIPv4(b *testing.B) {
for i := 0; i < b.N; i++ {
addr := ParseAddress("8.8.8.8")
if addr.Family() != AddressFamilyIPv4 {
panic("not ipv4")
}
}
}
func BenchmarkParseAddressIPv6(b *testing.B) {
for i := 0; i < b.N; i++ {
addr := ParseAddress("2001:4860:0:2001::68")
if addr.Family() != AddressFamilyIPv6 {
panic("not ipv6")
}
}
}
func BenchmarkParseAddressDomain(b *testing.B) {
for i := 0; i < b.N; i++ {
addr := ParseAddress("example.com")
if addr.Family() != AddressFamilyDomain {
panic("not domain")
}
}
}

162
common/net/connection.go Normal file
View file

@ -0,0 +1,162 @@
// +build !confonly
package net
import (
"io"
"net"
"time"
"github.com/xtls/xray-core/v1/common"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/signal/done"
)
type ConnectionOption func(*connection)
func ConnectionLocalAddr(a net.Addr) ConnectionOption {
return func(c *connection) {
c.local = a
}
}
func ConnectionRemoteAddr(a net.Addr) ConnectionOption {
return func(c *connection) {
c.remote = a
}
}
func ConnectionInput(writer io.Writer) ConnectionOption {
return func(c *connection) {
c.writer = buf.NewWriter(writer)
}
}
func ConnectionInputMulti(writer buf.Writer) ConnectionOption {
return func(c *connection) {
c.writer = writer
}
}
func ConnectionOutput(reader io.Reader) ConnectionOption {
return func(c *connection) {
c.reader = &buf.BufferedReader{Reader: buf.NewReader(reader)}
}
}
func ConnectionOutputMulti(reader buf.Reader) ConnectionOption {
return func(c *connection) {
c.reader = &buf.BufferedReader{Reader: reader}
}
}
func ConnectionOutputMultiUDP(reader buf.Reader) ConnectionOption {
return func(c *connection) {
c.reader = &buf.BufferedReader{
Reader: reader,
Spliter: buf.SplitFirstBytes,
}
}
}
func ConnectionOnClose(n io.Closer) ConnectionOption {
return func(c *connection) {
c.onClose = n
}
}
func NewConnection(opts ...ConnectionOption) net.Conn {
c := &connection{
done: done.New(),
local: &net.TCPAddr{
IP: []byte{0, 0, 0, 0},
Port: 0,
},
remote: &net.TCPAddr{
IP: []byte{0, 0, 0, 0},
Port: 0,
},
}
for _, opt := range opts {
opt(c)
}
return c
}
type connection struct {
reader *buf.BufferedReader
writer buf.Writer
done *done.Instance
onClose io.Closer
local Addr
remote Addr
}
func (c *connection) Read(b []byte) (int, error) {
return c.reader.Read(b)
}
// ReadMultiBuffer implements buf.Reader.
func (c *connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
return c.reader.ReadMultiBuffer()
}
// Write implements net.Conn.Write().
func (c *connection) Write(b []byte) (int, error) {
if c.done.Done() {
return 0, io.ErrClosedPipe
}
l := len(b)
mb := make(buf.MultiBuffer, 0, l/buf.Size+1)
mb = buf.MergeBytes(mb, b)
return l, c.writer.WriteMultiBuffer(mb)
}
func (c *connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
if c.done.Done() {
buf.ReleaseMulti(mb)
return io.ErrClosedPipe
}
return c.writer.WriteMultiBuffer(mb)
}
// Close implements net.Conn.Close().
func (c *connection) Close() error {
common.Must(c.done.Close())
common.Interrupt(c.reader)
common.Close(c.writer)
if c.onClose != nil {
return c.onClose.Close()
}
return nil
}
// LocalAddr implements net.Conn.LocalAddr().
func (c *connection) LocalAddr() net.Addr {
return c.local
}
// RemoteAddr implements net.Conn.RemoteAddr().
func (c *connection) RemoteAddr() net.Addr {
return c.remote
}
// SetDeadline implements net.Conn.SetDeadline().
func (c *connection) SetDeadline(t time.Time) error {
return nil
}
// SetReadDeadline implements net.Conn.SetReadDeadline().
func (c *connection) SetReadDeadline(t time.Time) error {
return nil
}
// SetWriteDeadline implements net.Conn.SetWriteDeadline().
func (c *connection) SetWriteDeadline(t time.Time) error {
return nil
}

126
common/net/destination.go Normal file
View file

@ -0,0 +1,126 @@
package net
import (
"net"
"strings"
)
// Destination represents a network destination including address and protocol (tcp / udp).
type Destination struct {
Address Address
Port Port
Network Network
}
// DestinationFromAddr generates a Destination from a net address.
func DestinationFromAddr(addr net.Addr) Destination {
switch addr := addr.(type) {
case *net.TCPAddr:
return TCPDestination(IPAddress(addr.IP), Port(addr.Port))
case *net.UDPAddr:
return UDPDestination(IPAddress(addr.IP), Port(addr.Port))
case *net.UnixAddr:
return UnixDestination(DomainAddress(addr.Name))
default:
panic("Net: Unknown address type.")
}
}
// ParseDestination converts a destination from its string presentation.
func ParseDestination(dest string) (Destination, error) {
d := Destination{
Address: AnyIP,
Port: Port(0),
}
if strings.HasPrefix(dest, "tcp:") {
d.Network = Network_TCP
dest = dest[4:]
} else if strings.HasPrefix(dest, "udp:") {
d.Network = Network_UDP
dest = dest[4:]
} else if strings.HasPrefix(dest, "unix:") {
d = UnixDestination(DomainAddress(dest[5:]))
return d, nil
}
hstr, pstr, err := SplitHostPort(dest)
if err != nil {
return d, err
}
if len(hstr) > 0 {
d.Address = ParseAddress(hstr)
}
if len(pstr) > 0 {
port, err := PortFromString(pstr)
if err != nil {
return d, err
}
d.Port = port
}
return d, nil
}
// TCPDestination creates a TCP destination with given address
func TCPDestination(address Address, port Port) Destination {
return Destination{
Network: Network_TCP,
Address: address,
Port: port,
}
}
// UDPDestination creates a UDP destination with given address
func UDPDestination(address Address, port Port) Destination {
return Destination{
Network: Network_UDP,
Address: address,
Port: port,
}
}
// UnixDestination creates a Unix destination with given address
func UnixDestination(address Address) Destination {
return Destination{
Network: Network_UNIX,
Address: address,
}
}
// NetAddr returns the network address in this Destination in string form.
func (d Destination) NetAddr() string {
addr := ""
if d.Network == Network_TCP || d.Network == Network_UDP {
addr = d.Address.String() + ":" + d.Port.String()
} else if d.Network == Network_UNIX {
addr = d.Address.String()
}
return addr
}
// String returns the strings form of this Destination.
func (d Destination) String() string {
prefix := "unknown:"
switch d.Network {
case Network_TCP:
prefix = "tcp:"
case Network_UDP:
prefix = "udp:"
case Network_UNIX:
prefix = "unix:"
}
return prefix + d.NetAddr()
}
// IsValid returns true if this Destination is valid.
func (d Destination) IsValid() bool {
return d.Network != Network_Unknown
}
// AsDestination converts current Endpoint into Destination.
func (p *Endpoint) AsDestination() Destination {
return Destination{
Network: p.Network,
Address: p.Address.AsAddress(),
Port: Port(p.Port),
}
}

View file

@ -0,0 +1,185 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.25.0
// protoc v3.14.0
// source: common/net/destination.proto
package net
import (
proto "github.com/golang/protobuf/proto"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
// Endpoint of a network connection.
type Endpoint struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Network Network `protobuf:"varint,1,opt,name=network,proto3,enum=xray.common.net.Network" json:"network,omitempty"`
Address *IPOrDomain `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"`
Port uint32 `protobuf:"varint,3,opt,name=port,proto3" json:"port,omitempty"`
}
func (x *Endpoint) Reset() {
*x = Endpoint{}
if protoimpl.UnsafeEnabled {
mi := &file_common_net_destination_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Endpoint) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Endpoint) ProtoMessage() {}
func (x *Endpoint) ProtoReflect() protoreflect.Message {
mi := &file_common_net_destination_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Endpoint.ProtoReflect.Descriptor instead.
func (*Endpoint) Descriptor() ([]byte, []int) {
return file_common_net_destination_proto_rawDescGZIP(), []int{0}
}
func (x *Endpoint) GetNetwork() Network {
if x != nil {
return x.Network
}
return Network_Unknown
}
func (x *Endpoint) GetAddress() *IPOrDomain {
if x != nil {
return x.Address
}
return nil
}
func (x *Endpoint) GetPort() uint32 {
if x != nil {
return x.Port
}
return 0
}
var File_common_net_destination_proto protoreflect.FileDescriptor
var file_common_net_destination_proto_rawDesc = []byte{
0x0a, 0x1c, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x6e, 0x65, 0x74, 0x2f, 0x64, 0x65, 0x73,
0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0f,
0x78, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x6e, 0x65, 0x74, 0x1a,
0x18, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x6e, 0x65, 0x74, 0x2f, 0x6e, 0x65, 0x74, 0x77,
0x6f, 0x72, 0x6b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x18, 0x63, 0x6f, 0x6d, 0x6d, 0x6f,
0x6e, 0x2f, 0x6e, 0x65, 0x74, 0x2f, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x22, 0x89, 0x01, 0x0a, 0x08, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74,
0x12, 0x32, 0x0a, 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x01, 0x20, 0x01, 0x28,
0x0e, 0x32, 0x18, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e,
0x6e, 0x65, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x52, 0x07, 0x6e, 0x65, 0x74,
0x77, 0x6f, 0x72, 0x6b, 0x12, 0x35, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18,
0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d,
0x6d, 0x6f, 0x6e, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x49, 0x50, 0x4f, 0x72, 0x44, 0x6f, 0x6d, 0x61,
0x69, 0x6e, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x70,
0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x42,
0x52, 0x0a, 0x13, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x6d,
0x6f, 0x6e, 0x2e, 0x6e, 0x65, 0x74, 0x50, 0x01, 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62,
0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63,
0x6f, 0x72, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x6e, 0x65,
0x74, 0xaa, 0x02, 0x0f, 0x58, 0x72, 0x61, 0x79, 0x2e, 0x43, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e,
0x4e, 0x65, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_common_net_destination_proto_rawDescOnce sync.Once
file_common_net_destination_proto_rawDescData = file_common_net_destination_proto_rawDesc
)
func file_common_net_destination_proto_rawDescGZIP() []byte {
file_common_net_destination_proto_rawDescOnce.Do(func() {
file_common_net_destination_proto_rawDescData = protoimpl.X.CompressGZIP(file_common_net_destination_proto_rawDescData)
})
return file_common_net_destination_proto_rawDescData
}
var file_common_net_destination_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_common_net_destination_proto_goTypes = []interface{}{
(*Endpoint)(nil), // 0: xray.common.net.Endpoint
(Network)(0), // 1: xray.common.net.Network
(*IPOrDomain)(nil), // 2: xray.common.net.IPOrDomain
}
var file_common_net_destination_proto_depIdxs = []int32{
1, // 0: xray.common.net.Endpoint.network:type_name -> xray.common.net.Network
2, // 1: xray.common.net.Endpoint.address:type_name -> xray.common.net.IPOrDomain
2, // [2:2] is the sub-list for method output_type
2, // [2:2] is the sub-list for method input_type
2, // [2:2] is the sub-list for extension type_name
2, // [2:2] is the sub-list for extension extendee
0, // [0:2] is the sub-list for field type_name
}
func init() { file_common_net_destination_proto_init() }
func file_common_net_destination_proto_init() {
if File_common_net_destination_proto != nil {
return
}
file_common_net_network_proto_init()
file_common_net_address_proto_init()
if !protoimpl.UnsafeEnabled {
file_common_net_destination_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Endpoint); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_common_net_destination_proto_rawDesc,
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_common_net_destination_proto_goTypes,
DependencyIndexes: file_common_net_destination_proto_depIdxs,
MessageInfos: file_common_net_destination_proto_msgTypes,
}.Build()
File_common_net_destination_proto = out.File
file_common_net_destination_proto_rawDesc = nil
file_common_net_destination_proto_goTypes = nil
file_common_net_destination_proto_depIdxs = nil
}

View file

@ -0,0 +1,17 @@
syntax = "proto3";
package xray.common.net;
option csharp_namespace = "Xray.Common.Net";
option go_package = "github.com/xtls/xray-core/v1/common/net";
option java_package = "com.xray.common.net";
option java_multiple_files = true;
import "common/net/network.proto";
import "common/net/address.proto";
// Endpoint of a network connection.
message Endpoint {
Network network = 1;
IPOrDomain address = 2;
uint32 port = 3;
}

View file

@ -0,0 +1,111 @@
package net_test
import (
"testing"
"github.com/google/go-cmp/cmp"
. "github.com/xtls/xray-core/v1/common/net"
)
func TestDestinationProperty(t *testing.T) {
testCases := []struct {
Input Destination
Network Network
String string
NetString string
}{
{
Input: TCPDestination(IPAddress([]byte{1, 2, 3, 4}), 80),
Network: Network_TCP,
String: "tcp:1.2.3.4:80",
NetString: "1.2.3.4:80",
},
{
Input: UDPDestination(IPAddress([]byte{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x88, 0x88}), 53),
Network: Network_UDP,
String: "udp:[2001:4860:4860::8888]:53",
NetString: "[2001:4860:4860::8888]:53",
},
{
Input: UnixDestination(DomainAddress("/tmp/test.sock")),
Network: Network_UNIX,
String: "unix:/tmp/test.sock",
NetString: "/tmp/test.sock",
},
}
for _, testCase := range testCases {
dest := testCase.Input
if r := cmp.Diff(dest.Network, testCase.Network); r != "" {
t.Error("unexpected Network in ", dest.String(), ": ", r)
}
if r := cmp.Diff(dest.String(), testCase.String); r != "" {
t.Error(r)
}
if r := cmp.Diff(dest.NetAddr(), testCase.NetString); r != "" {
t.Error(r)
}
}
}
func TestDestinationParse(t *testing.T) {
cases := []struct {
Input string
Output Destination
Error bool
}{
{
Input: "tcp:127.0.0.1:80",
Output: TCPDestination(LocalHostIP, Port(80)),
},
{
Input: "udp:8.8.8.8:53",
Output: UDPDestination(IPAddress([]byte{8, 8, 8, 8}), Port(53)),
},
{
Input: "unix:/tmp/test.sock",
Output: UnixDestination(DomainAddress("/tmp/test.sock")),
},
{
Input: "8.8.8.8:53",
Output: Destination{
Address: IPAddress([]byte{8, 8, 8, 8}),
Port: Port(53),
},
},
{
Input: ":53",
Output: Destination{
Address: AnyIP,
Port: Port(53),
},
},
{
Input: "8.8.8.8",
Error: true,
},
{
Input: "8.8.8.8:http",
Error: true,
},
{
Input: "/tmp/test.sock",
Error: true,
},
}
for _, testcase := range cases {
d, err := ParseDestination(testcase.Input)
if !testcase.Error {
if err != nil {
t.Error("for test case: ", testcase.Input, " expected no error, but got ", err)
}
if d != testcase.Output {
t.Error("for test case: ", testcase.Input, " expected output: ", testcase.Output.String(), " but got ", d.String())
}
} else if err == nil {
t.Error("for test case: ", testcase.Input, " expected error, but got nil")
}
}
}

View file

@ -0,0 +1,9 @@
package net
import "github.com/xtls/xray-core/v1/common/errors"
type errPathObjHolder struct{}
func newError(values ...interface{}) *errors.Error {
return errors.New(values...).WithPathObj(errPathObjHolder{})
}

4
common/net/net.go Normal file
View file

@ -0,0 +1,4 @@
// Package net is a drop-in replacement to Golang's net package, with some more functionalities.
package net // import "github.com/xtls/xray-core/v1/common/net"
//go:generate go run github.com/xtls/xray-core/v1/common/errors/errorgen

24
common/net/network.go Normal file
View file

@ -0,0 +1,24 @@
package net
func (n Network) SystemString() string {
switch n {
case Network_TCP:
return "tcp"
case Network_UDP:
return "udp"
case Network_UNIX:
return "unix"
default:
return "unknown"
}
}
// HasNetwork returns true if the network list has a certain network.
func HasNetwork(list []Network, network Network) bool {
for _, value := range list {
if value == network {
return true
}
}
return false
}

219
common/net/network.pb.go Normal file
View file

@ -0,0 +1,219 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.25.0
// protoc v3.14.0
// source: common/net/network.proto
package net
import (
proto "github.com/golang/protobuf/proto"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
type Network int32
const (
Network_Unknown Network = 0
// Deprecated: Do not use.
Network_RawTCP Network = 1
Network_TCP Network = 2
Network_UDP Network = 3
Network_UNIX Network = 4
)
// Enum value maps for Network.
var (
Network_name = map[int32]string{
0: "Unknown",
1: "RawTCP",
2: "TCP",
3: "UDP",
4: "UNIX",
}
Network_value = map[string]int32{
"Unknown": 0,
"RawTCP": 1,
"TCP": 2,
"UDP": 3,
"UNIX": 4,
}
)
func (x Network) Enum() *Network {
p := new(Network)
*p = x
return p
}
func (x Network) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (Network) Descriptor() protoreflect.EnumDescriptor {
return file_common_net_network_proto_enumTypes[0].Descriptor()
}
func (Network) Type() protoreflect.EnumType {
return &file_common_net_network_proto_enumTypes[0]
}
func (x Network) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use Network.Descriptor instead.
func (Network) EnumDescriptor() ([]byte, []int) {
return file_common_net_network_proto_rawDescGZIP(), []int{0}
}
// NetworkList is a list of Networks.
type NetworkList struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Network []Network `protobuf:"varint,1,rep,packed,name=network,proto3,enum=xray.common.net.Network" json:"network,omitempty"`
}
func (x *NetworkList) Reset() {
*x = NetworkList{}
if protoimpl.UnsafeEnabled {
mi := &file_common_net_network_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *NetworkList) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*NetworkList) ProtoMessage() {}
func (x *NetworkList) ProtoReflect() protoreflect.Message {
mi := &file_common_net_network_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use NetworkList.ProtoReflect.Descriptor instead.
func (*NetworkList) Descriptor() ([]byte, []int) {
return file_common_net_network_proto_rawDescGZIP(), []int{0}
}
func (x *NetworkList) GetNetwork() []Network {
if x != nil {
return x.Network
}
return nil
}
var File_common_net_network_proto protoreflect.FileDescriptor
var file_common_net_network_proto_rawDesc = []byte{
0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x6e, 0x65, 0x74, 0x2f, 0x6e, 0x65, 0x74,
0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0f, 0x78, 0x72, 0x61, 0x79,
0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x6e, 0x65, 0x74, 0x22, 0x41, 0x0a, 0x0b, 0x4e,
0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x32, 0x0a, 0x07, 0x6e, 0x65,
0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x78, 0x72,
0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x4e, 0x65,
0x74, 0x77, 0x6f, 0x72, 0x6b, 0x52, 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2a, 0x42,
0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x6e, 0x6b,
0x6e, 0x6f, 0x77, 0x6e, 0x10, 0x00, 0x12, 0x0e, 0x0a, 0x06, 0x52, 0x61, 0x77, 0x54, 0x43, 0x50,
0x10, 0x01, 0x1a, 0x02, 0x08, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12,
0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x55, 0x4e, 0x49, 0x58,
0x10, 0x04, 0x42, 0x52, 0x0a, 0x13, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x63,
0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x6e, 0x65, 0x74, 0x50, 0x01, 0x5a, 0x27, 0x67, 0x69, 0x74,
0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61,
0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e,
0x2f, 0x6e, 0x65, 0x74, 0xaa, 0x02, 0x0f, 0x58, 0x72, 0x61, 0x79, 0x2e, 0x43, 0x6f, 0x6d, 0x6d,
0x6f, 0x6e, 0x2e, 0x4e, 0x65, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_common_net_network_proto_rawDescOnce sync.Once
file_common_net_network_proto_rawDescData = file_common_net_network_proto_rawDesc
)
func file_common_net_network_proto_rawDescGZIP() []byte {
file_common_net_network_proto_rawDescOnce.Do(func() {
file_common_net_network_proto_rawDescData = protoimpl.X.CompressGZIP(file_common_net_network_proto_rawDescData)
})
return file_common_net_network_proto_rawDescData
}
var file_common_net_network_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_common_net_network_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_common_net_network_proto_goTypes = []interface{}{
(Network)(0), // 0: xray.common.net.Network
(*NetworkList)(nil), // 1: xray.common.net.NetworkList
}
var file_common_net_network_proto_depIdxs = []int32{
0, // 0: xray.common.net.NetworkList.network:type_name -> xray.common.net.Network
1, // [1:1] is the sub-list for method output_type
1, // [1:1] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
}
func init() { file_common_net_network_proto_init() }
func file_common_net_network_proto_init() {
if File_common_net_network_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_common_net_network_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*NetworkList); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_common_net_network_proto_rawDesc,
NumEnums: 1,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_common_net_network_proto_goTypes,
DependencyIndexes: file_common_net_network_proto_depIdxs,
EnumInfos: file_common_net_network_proto_enumTypes,
MessageInfos: file_common_net_network_proto_msgTypes,
}.Build()
File_common_net_network_proto = out.File
file_common_net_network_proto_rawDesc = nil
file_common_net_network_proto_goTypes = nil
file_common_net_network_proto_depIdxs = nil
}

19
common/net/network.proto Normal file
View file

@ -0,0 +1,19 @@
syntax = "proto3";
package xray.common.net;
option csharp_namespace = "Xray.Common.Net";
option go_package = "github.com/xtls/xray-core/v1/common/net";
option java_package = "com.xray.common.net";
option java_multiple_files = true;
enum Network {
Unknown = 0;
RawTCP = 1 [deprecated = true];
TCP = 2;
UDP = 3;
UNIX = 4;
}
// NetworkList is a list of Networks.
message NetworkList { repeated Network network = 1; }

95
common/net/port.go Normal file
View file

@ -0,0 +1,95 @@
package net
import (
"encoding/binary"
"strconv"
)
// Port represents a network port in TCP and UDP protocol.
type Port uint16
// PortFromBytes converts a byte array to a Port, assuming bytes are in big endian order.
// @unsafe Caller must ensure that the byte array has at least 2 elements.
func PortFromBytes(port []byte) Port {
return Port(binary.BigEndian.Uint16(port))
}
// PortFromInt converts an integer to a Port.
// @error when the integer is not positive or larger then 65535
func PortFromInt(val uint32) (Port, error) {
if val > 65535 {
return Port(0), newError("invalid port range: ", val)
}
return Port(val), nil
}
// PortFromString converts a string to a Port.
// @error when the string is not an integer or the integral value is a not a valid Port.
func PortFromString(s string) (Port, error) {
val, err := strconv.ParseUint(s, 10, 32)
if err != nil {
return Port(0), newError("invalid port range: ", s)
}
return PortFromInt(uint32(val))
}
// Value return the corresponding uint16 value of a Port.
func (p Port) Value() uint16 {
return uint16(p)
}
// String returns the string presentation of a Port.
func (p Port) String() string {
return strconv.Itoa(int(p))
}
// FromPort returns the beginning port of this PortRange.
func (p *PortRange) FromPort() Port {
return Port(p.From)
}
// ToPort returns the end port of this PortRange.
func (p *PortRange) ToPort() Port {
return Port(p.To)
}
// Contains returns true if the given port is within the range of a PortRange.
func (p *PortRange) Contains(port Port) bool {
return p.FromPort() <= port && port <= p.ToPort()
}
// SinglePortRange returns a PortRange contains a single port.
func SinglePortRange(p Port) *PortRange {
return &PortRange{
From: uint32(p),
To: uint32(p),
}
}
type MemoryPortRange struct {
From Port
To Port
}
func (r MemoryPortRange) Contains(port Port) bool {
return r.From <= port && port <= r.To
}
type MemoryPortList []MemoryPortRange
func PortListFromProto(l *PortList) MemoryPortList {
mpl := make(MemoryPortList, 0, len(l.Range))
for _, r := range l.Range {
mpl = append(mpl, MemoryPortRange{From: Port(r.From), To: Port(r.To)})
}
return mpl
}
func (mpl MemoryPortList) Contains(port Port) bool {
for _, pr := range mpl {
if pr.Contains(port) {
return true
}
}
return false
}

230
common/net/port.pb.go Normal file
View file

@ -0,0 +1,230 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.25.0
// protoc v3.14.0
// source: common/net/port.proto
package net
import (
proto "github.com/golang/protobuf/proto"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
// PortRange represents a range of ports.
type PortRange struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// The port that this range starts from.
From uint32 `protobuf:"varint,1,opt,name=From,proto3" json:"From,omitempty"`
// The port that this range ends with (inclusive).
To uint32 `protobuf:"varint,2,opt,name=To,proto3" json:"To,omitempty"`
}
func (x *PortRange) Reset() {
*x = PortRange{}
if protoimpl.UnsafeEnabled {
mi := &file_common_net_port_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *PortRange) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*PortRange) ProtoMessage() {}
func (x *PortRange) ProtoReflect() protoreflect.Message {
mi := &file_common_net_port_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use PortRange.ProtoReflect.Descriptor instead.
func (*PortRange) Descriptor() ([]byte, []int) {
return file_common_net_port_proto_rawDescGZIP(), []int{0}
}
func (x *PortRange) GetFrom() uint32 {
if x != nil {
return x.From
}
return 0
}
func (x *PortRange) GetTo() uint32 {
if x != nil {
return x.To
}
return 0
}
// PortList is a list of ports.
type PortList struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Range []*PortRange `protobuf:"bytes,1,rep,name=range,proto3" json:"range,omitempty"`
}
func (x *PortList) Reset() {
*x = PortList{}
if protoimpl.UnsafeEnabled {
mi := &file_common_net_port_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *PortList) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*PortList) ProtoMessage() {}
func (x *PortList) ProtoReflect() protoreflect.Message {
mi := &file_common_net_port_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use PortList.ProtoReflect.Descriptor instead.
func (*PortList) Descriptor() ([]byte, []int) {
return file_common_net_port_proto_rawDescGZIP(), []int{1}
}
func (x *PortList) GetRange() []*PortRange {
if x != nil {
return x.Range
}
return nil
}
var File_common_net_port_proto protoreflect.FileDescriptor
var file_common_net_port_proto_rawDesc = []byte{
0x0a, 0x15, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x6e, 0x65, 0x74, 0x2f, 0x70, 0x6f, 0x72,
0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0f, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f,
0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x6e, 0x65, 0x74, 0x22, 0x2f, 0x0a, 0x09, 0x50, 0x6f, 0x72, 0x74,
0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x46, 0x72, 0x6f, 0x6d, 0x18, 0x01, 0x20,
0x01, 0x28, 0x0d, 0x52, 0x04, 0x46, 0x72, 0x6f, 0x6d, 0x12, 0x0e, 0x0a, 0x02, 0x54, 0x6f, 0x18,
0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x02, 0x54, 0x6f, 0x22, 0x3c, 0x0a, 0x08, 0x50, 0x6f, 0x72,
0x74, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x30, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x01,
0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x6d,
0x6f, 0x6e, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x52, 0x61, 0x6e, 0x67, 0x65,
0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x42, 0x52, 0x0a, 0x13, 0x63, 0x6f, 0x6d, 0x2e, 0x78,
0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x6e, 0x65, 0x74, 0x50, 0x01,
0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c,
0x73, 0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x63,
0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x6e, 0x65, 0x74, 0xaa, 0x02, 0x0f, 0x58, 0x72, 0x61, 0x79,
0x2e, 0x43, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x65, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
}
var (
file_common_net_port_proto_rawDescOnce sync.Once
file_common_net_port_proto_rawDescData = file_common_net_port_proto_rawDesc
)
func file_common_net_port_proto_rawDescGZIP() []byte {
file_common_net_port_proto_rawDescOnce.Do(func() {
file_common_net_port_proto_rawDescData = protoimpl.X.CompressGZIP(file_common_net_port_proto_rawDescData)
})
return file_common_net_port_proto_rawDescData
}
var file_common_net_port_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_common_net_port_proto_goTypes = []interface{}{
(*PortRange)(nil), // 0: xray.common.net.PortRange
(*PortList)(nil), // 1: xray.common.net.PortList
}
var file_common_net_port_proto_depIdxs = []int32{
0, // 0: xray.common.net.PortList.range:type_name -> xray.common.net.PortRange
1, // [1:1] is the sub-list for method output_type
1, // [1:1] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
}
func init() { file_common_net_port_proto_init() }
func file_common_net_port_proto_init() {
if File_common_net_port_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_common_net_port_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*PortRange); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_common_net_port_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*PortList); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_common_net_port_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_common_net_port_proto_goTypes,
DependencyIndexes: file_common_net_port_proto_depIdxs,
MessageInfos: file_common_net_port_proto_msgTypes,
}.Build()
File_common_net_port_proto = out.File
file_common_net_port_proto_rawDesc = nil
file_common_net_port_proto_goTypes = nil
file_common_net_port_proto_depIdxs = nil
}

20
common/net/port.proto Normal file
View file

@ -0,0 +1,20 @@
syntax = "proto3";
package xray.common.net;
option csharp_namespace = "Xray.Common.Net";
option go_package = "github.com/xtls/xray-core/v1/common/net";
option java_package = "com.xray.common.net";
option java_multiple_files = true;
// PortRange represents a range of ports.
message PortRange {
// The port that this range starts from.
uint32 From = 1;
// The port that this range ends with (inclusive).
uint32 To = 2;
}
// PortList is a list of ports.
message PortList {
repeated PortRange range = 1;
}

18
common/net/port_test.go Normal file
View file

@ -0,0 +1,18 @@
package net_test
import (
"testing"
. "github.com/xtls/xray-core/v1/common/net"
)
func TestPortRangeContains(t *testing.T) {
portRange := &PortRange{
From: 53,
To: 53,
}
if !portRange.Contains(Port(53)) {
t.Error("expected port range containing 53, but actually not")
}
}

61
common/net/system.go Normal file
View file

@ -0,0 +1,61 @@
package net
import "net"
// DialTCP is an alias of net.DialTCP.
var DialTCP = net.DialTCP
var DialUDP = net.DialUDP
var DialUnix = net.DialUnix
var Dial = net.Dial
type ListenConfig = net.ListenConfig
var Listen = net.Listen
var ListenTCP = net.ListenTCP
var ListenUDP = net.ListenUDP
var ListenUnix = net.ListenUnix
var LookupIP = net.LookupIP
var FileConn = net.FileConn
// ParseIP is an alias of net.ParseIP
var ParseIP = net.ParseIP
var SplitHostPort = net.SplitHostPort
var CIDRMask = net.CIDRMask
type Addr = net.Addr
type Conn = net.Conn
type PacketConn = net.PacketConn
type TCPAddr = net.TCPAddr
type TCPConn = net.TCPConn
type UDPAddr = net.UDPAddr
type UDPConn = net.UDPConn
type UnixAddr = net.UnixAddr
type UnixConn = net.UnixConn
// IP is an alias for net.IP.
type IP = net.IP
type IPMask = net.IPMask
type IPNet = net.IPNet
const IPv4len = net.IPv4len
const IPv6len = net.IPv6len
type Error = net.Error
type AddrError = net.AddrError
type Dialer = net.Dialer
type Listener = net.Listener
type TCPListener = net.TCPListener
type UnixListener = net.UnixListener
var ResolveUnixAddr = net.ResolveUnixAddr
var ResolveUDPAddr = net.ResolveUDPAddr
type Resolver = net.Resolver

30
common/peer/latency.go Normal file
View file

@ -0,0 +1,30 @@
package peer
import (
"sync"
)
type Latency interface {
Value() uint64
}
type HasLatency interface {
ConnectionLatency() Latency
HandshakeLatency() Latency
}
type AverageLatency struct {
access sync.Mutex
value uint64
}
func (al *AverageLatency) Update(newValue uint64) {
al.access.Lock()
defer al.access.Unlock()
al.value = (al.value + newValue*2) / 3
}
func (al *AverageLatency) Value() uint64 {
return al.value
}

1
common/peer/peer.go Normal file
View file

@ -0,0 +1 @@
package peer

View file

@ -0,0 +1,9 @@
// +build !windows
package ctlcmd
import "syscall"
func getSysProcAttr() *syscall.SysProcAttr {
return nil
}

View file

@ -0,0 +1,11 @@
// +build windows
package ctlcmd
import "syscall"
func getSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
HideWindow: true,
}
}

View file

@ -0,0 +1,50 @@
package ctlcmd
import (
"io"
"os"
"os/exec"
"strings"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/platform"
)
//go:generate go run github.com/xtls/xray-core/v1/common/errors/errorgen
func Run(args []string, input io.Reader) (buf.MultiBuffer, error) {
xctl := platform.GetToolLocation("xctl")
if _, err := os.Stat(xctl); err != nil {
return nil, newError("xctl doesn't exist").Base(err)
}
var errBuffer buf.MultiBufferContainer
var outBuffer buf.MultiBufferContainer
cmd := exec.Command(xctl, args...)
cmd.Stderr = &errBuffer
cmd.Stdout = &outBuffer
cmd.SysProcAttr = getSysProcAttr()
if input != nil {
cmd.Stdin = input
}
if err := cmd.Start(); err != nil {
return nil, newError("failed to start xctl").Base(err)
}
if err := cmd.Wait(); err != nil {
msg := "failed to execute xctl"
if errBuffer.Len() > 0 {
msg += ": \n" + strings.TrimSpace(errBuffer.MultiBuffer.String())
}
return nil, newError(msg).Base(err)
}
// log stderr, info message
if !errBuffer.IsEmpty() {
newError("<xctl message> \n", strings.TrimSpace(errBuffer.MultiBuffer.String())).AtInfo().WriteToLog()
}
return outBuffer.MultiBuffer, nil
}

View file

@ -0,0 +1,9 @@
package ctlcmd
import "github.com/xtls/xray-core/v1/common/errors"
type errPathObjHolder struct{}
func newError(values ...interface{}) *errors.Error {
return errors.New(values...).WithPathObj(errPathObjHolder{})
}

View file

@ -0,0 +1,44 @@
package filesystem
import (
"io"
"os"
"github.com/xtls/xray-core/v1/common/buf"
"github.com/xtls/xray-core/v1/common/platform"
)
type FileReaderFunc func(path string) (io.ReadCloser, error)
var NewFileReader FileReaderFunc = func(path string) (io.ReadCloser, error) {
return os.Open(path)
}
func ReadFile(path string) ([]byte, error) {
reader, err := NewFileReader(path)
if err != nil {
return nil, err
}
defer reader.Close()
return buf.ReadAllToBytes(reader)
}
func ReadAsset(file string) ([]byte, error) {
return ReadFile(platform.GetAssetLocation(file))
}
func CopyFile(dst string, src string) error {
bytes, err := ReadFile(src)
if err != nil {
return err
}
f, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return err
}
defer f.Close()
_, err = f.Write(bytes)
return err
}

44
common/platform/others.go Normal file
View file

@ -0,0 +1,44 @@
// +build !windows
package platform
import (
"os"
"path/filepath"
)
func ExpandEnv(s string) string {
return os.ExpandEnv(s)
}
func LineSeparator() string {
return "\n"
}
func GetToolLocation(file string) string {
const name = "xray.location.tool"
toolPath := EnvFlag{Name: name, AltName: NormalizeEnvName(name)}.GetValue(getExecutableDir)
return filepath.Join(toolPath, file)
}
// GetAssetLocation search for `file` in certain locations
func GetAssetLocation(file string) string {
const name = "xray.location.asset"
assetPath := NewEnvFlag(name).GetValue(getExecutableDir)
defPath := filepath.Join(assetPath, file)
for _, p := range []string{
defPath,
filepath.Join("/usr/local/share/xray/", file),
filepath.Join("/usr/share/xray/", file),
} {
if _, err := os.Stat(p); os.IsNotExist(err) {
continue
}
// asset found
return p
}
// asset not found, let the caller throw out the error
return defPath
}

View file

@ -0,0 +1,86 @@
package platform // import "github.com/xtls/xray-core/v1/common/platform"
import (
"os"
"path/filepath"
"strconv"
"strings"
)
type EnvFlag struct {
Name string
AltName string
}
func NewEnvFlag(name string) EnvFlag {
return EnvFlag{
Name: name,
AltName: NormalizeEnvName(name),
}
}
func (f EnvFlag) GetValue(defaultValue func() string) string {
if v, found := os.LookupEnv(f.Name); found {
return v
}
if len(f.AltName) > 0 {
if v, found := os.LookupEnv(f.AltName); found {
return v
}
}
return defaultValue()
}
func (f EnvFlag) GetValueAsInt(defaultValue int) int {
useDefaultValue := false
s := f.GetValue(func() string {
useDefaultValue = true
return ""
})
if useDefaultValue {
return defaultValue
}
v, err := strconv.ParseInt(s, 10, 32)
if err != nil {
return defaultValue
}
return int(v)
}
func NormalizeEnvName(name string) string {
return strings.ReplaceAll(strings.ToUpper(strings.TrimSpace(name)), ".", "_")
}
func getExecutableDir() string {
exec, err := os.Executable()
if err != nil {
return ""
}
return filepath.Dir(exec)
}
func getExecutableSubDir(dir string) func() string {
return func() string {
return filepath.Join(getExecutableDir(), dir)
}
}
func GetPluginDirectory() string {
const name = "xray.location.plugin"
pluginDir := NewEnvFlag(name).GetValue(getExecutableSubDir("plugins"))
return pluginDir
}
func GetConfigurationPath() string {
const name = "xray.location.config"
configPath := NewEnvFlag(name).GetValue(getExecutableDir)
return filepath.Join(configPath, "config.json")
}
// GetConfDirPath reads "xray.location.confdir"
func GetConfDirPath() string {
const name = "xray.location.confdir"
configPath := NewEnvFlag(name).GetValue(func() string { return "" })
return configPath
}

View file

@ -0,0 +1,65 @@
package platform_test
import (
"os"
"path/filepath"
"runtime"
"testing"
"github.com/xtls/xray-core/v1/common"
. "github.com/xtls/xray-core/v1/common/platform"
)
func TestNormalizeEnvName(t *testing.T) {
cases := []struct {
input string
output string
}{
{
input: "a",
output: "A",
},
{
input: "a.a",
output: "A_A",
},
{
input: "A.A.B",
output: "A_A_B",
},
}
for _, test := range cases {
if v := NormalizeEnvName(test.input); v != test.output {
t.Error("unexpected output: ", v, " want ", test.output)
}
}
}
func TestEnvFlag(t *testing.T) {
if v := (EnvFlag{
Name: "xxxxx.y",
}.GetValueAsInt(10)); v != 10 {
t.Error("env value: ", v)
}
}
func TestGetAssetLocation(t *testing.T) {
exec, err := os.Executable()
common.Must(err)
loc := GetAssetLocation("t")
if filepath.Dir(loc) != filepath.Dir(exec) {
t.Error("asset dir: ", loc, " not in ", exec)
}
os.Setenv("xray.location.asset", "/xray")
if runtime.GOOS == "windows" {
if v := GetAssetLocation("t"); v != "\\xray\\t" {
t.Error("asset loc: ", v)
}
} else {
if v := GetAssetLocation("t"); v != "/xray/t" {
t.Error("asset loc: ", v)
}
}
}

View file

@ -0,0 +1,27 @@
// +build windows
package platform
import "path/filepath"
func ExpandEnv(s string) string {
// TODO
return s
}
func LineSeparator() string {
return "\r\n"
}
func GetToolLocation(file string) string {
const name = "xray.location.tool"
toolPath := EnvFlag{Name: name, AltName: NormalizeEnvName(name)}.GetValue(getExecutableDir)
return filepath.Join(toolPath, file+".exe")
}
// GetAssetLocation search for `file` in the excutable dir
func GetAssetLocation(file string) string {
const name = "xray.location.asset"
assetPath := NewEnvFlag(name).GetValue(getExecutableDir)
return filepath.Join(assetPath, file)
}

View file

@ -0,0 +1,11 @@
package protocol
// Account is a user identity used for authentication.
type Account interface {
Equals(Account) bool
}
// AsAccount is an object can be converted into account.
type AsAccount interface {
AsAccount() (Account, error)
}

Some files were not shown because too many files have changed in this diff Show more