package kcp import ( "bytes" "context" "io" "net" "runtime" "sync" "sync/atomic" "time" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/signal/semaphore" ) var ( ErrIOTimeout = errors.New("Read/Write timeout") ErrClosedListener = errors.New("Listener closed.") ErrClosedConnection = errors.New("Connection closed.") ) // State of the connection type State int32 // Is returns true if current State is one of the candidates. func (s State) Is(states ...State) bool { for _, state := range states { if s == state { return true } } return false } const ( StateActive State = 0 // Connection is active StateReadyToClose State = 1 // Connection is closed locally StatePeerClosed State = 2 // Connection is closed on remote StateTerminating State = 3 // Connection is ready to be destroyed locally StatePeerTerminating State = 4 // Connection is ready to be destroyed on remote StateTerminated State = 5 // Connection is destroyed. ) func nowMillisec() int64 { now := time.Now() return now.Unix()*1000 + int64(now.Nanosecond()/1000000) } type RoundTripInfo struct { sync.RWMutex variation uint32 srtt uint32 rto uint32 minRtt uint32 updatedTimestamp uint32 } func (info *RoundTripInfo) UpdatePeerRTO(rto uint32, current uint32) { info.Lock() defer info.Unlock() if current-info.updatedTimestamp < 3000 { return } info.updatedTimestamp = current info.rto = rto } func (info *RoundTripInfo) Update(rtt uint32, current uint32) { if rtt > 0x7FFFFFFF { return } info.Lock() defer info.Unlock() // https://tools.ietf.org/html/rfc6298 if info.srtt == 0 { info.srtt = rtt info.variation = rtt / 2 } else { delta := rtt - info.srtt if info.srtt > rtt { delta = info.srtt - rtt } info.variation = (3*info.variation + delta) / 4 info.srtt = (7*info.srtt + rtt) / 8 if info.srtt < info.minRtt { info.srtt = info.minRtt } } var rto uint32 if info.minRtt < 4*info.variation { rto = info.srtt + 4*info.variation } else { rto = info.srtt + info.variation } if rto > 10000 { rto = 10000 } info.rto = rto * 5 / 4 info.updatedTimestamp = current } func (info *RoundTripInfo) Timeout() uint32 { info.RLock() defer info.RUnlock() return info.rto } func (info *RoundTripInfo) SmoothedTime() uint32 { info.RLock() defer info.RUnlock() return info.srtt } type Updater struct { interval int64 shouldContinue func() bool shouldTerminate func() bool updateFunc func() notifier *semaphore.Instance } func NewUpdater(interval uint32, shouldContinue func() bool, shouldTerminate func() bool, updateFunc func()) *Updater { u := &Updater{ interval: int64(time.Duration(interval) * time.Millisecond), shouldContinue: shouldContinue, shouldTerminate: shouldTerminate, updateFunc: updateFunc, notifier: semaphore.New(1), } return u } func (u *Updater) WakeUp() { select { case <-u.notifier.Wait(): go u.run() default: } } func (u *Updater) run() { defer u.notifier.Signal() if u.shouldTerminate() { return } ticker := time.NewTicker(u.Interval()) for u.shouldContinue() { u.updateFunc() <-ticker.C } ticker.Stop() } func (u *Updater) Interval() time.Duration { return time.Duration(atomic.LoadInt64(&u.interval)) } func (u *Updater) SetInterval(d time.Duration) { atomic.StoreInt64(&u.interval, int64(d)) } type ConnMetadata struct { LocalAddr net.Addr RemoteAddr net.Addr Conversation uint16 } // Connection is a KCP connection over UDP. type Connection struct { meta ConnMetadata closer io.Closer rd time.Time wd time.Time // write deadline since int64 dataInput *signal.Notifier dataOutput *signal.Notifier Config *Config state State stateBeginTime uint32 lastIncomingTime uint32 lastPingTime uint32 mss uint32 roundTrip *RoundTripInfo receivingWorker *ReceivingWorker sendingWorker *SendingWorker output SegmentWriter dataUpdater *Updater pingUpdater *Updater } // NewConnection create a new KCP connection between local and remote. func NewConnection(meta ConnMetadata, writer PacketWriter, closer io.Closer, config *Config) *Connection { errors.LogInfo(context.Background(), "#", meta.Conversation, " creating connection to ", meta.RemoteAddr) conn := &Connection{ meta: meta, closer: closer, since: nowMillisec(), dataInput: signal.NewNotifier(), dataOutput: signal.NewNotifier(), Config: config, output: NewRetryableWriter(NewSegmentWriter(writer)), mss: config.GetMTUValue() - uint32(writer.Overhead()) - DataSegmentOverhead, roundTrip: &RoundTripInfo{ rto: 100, minRtt: config.GetTTIValue(), }, } conn.receivingWorker = NewReceivingWorker(conn) conn.sendingWorker = NewSendingWorker(conn) isTerminating := func() bool { return conn.State().Is(StateTerminating, StateTerminated) } isTerminated := func() bool { return conn.State() == StateTerminated } conn.dataUpdater = NewUpdater( config.GetTTIValue(), func() bool { return !isTerminating() && (conn.sendingWorker.UpdateNecessary() || conn.receivingWorker.UpdateNecessary()) }, isTerminating, conn.updateTask) conn.pingUpdater = NewUpdater( 5000, // 5 seconds func() bool { return !isTerminated() }, isTerminated, conn.updateTask) conn.pingUpdater.WakeUp() return conn } func (c *Connection) Elapsed() uint32 { return uint32(nowMillisec() - c.since) } // ReadMultiBuffer implements buf.Reader. func (c *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) { if c == nil { return nil, io.EOF } for { if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) { return nil, io.EOF } mb := c.receivingWorker.ReadMultiBuffer() if !mb.IsEmpty() { c.dataUpdater.WakeUp() return mb, nil } if c.State() == StatePeerTerminating { return nil, io.EOF } if err := c.waitForDataInput(); err != nil { return nil, err } } } func (c *Connection) waitForDataInput() error { for i := 0; i < 16; i++ { select { case <-c.dataInput.Wait(): return nil default: runtime.Gosched() } } duration := time.Second * 16 if !c.rd.IsZero() { duration = time.Until(c.rd) if duration < 0 { return ErrIOTimeout } } timeout := time.NewTimer(duration) defer timeout.Stop() select { case <-c.dataInput.Wait(): case <-timeout.C: if !c.rd.IsZero() && c.rd.Before(time.Now()) { return ErrIOTimeout } } return nil } // Read implements the Conn Read method. func (c *Connection) Read(b []byte) (int, error) { if c == nil { return 0, io.EOF } for { if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) { return 0, io.EOF } nBytes := c.receivingWorker.Read(b) if nBytes > 0 { c.dataUpdater.WakeUp() return nBytes, nil } if err := c.waitForDataInput(); err != nil { return 0, err } } } func (c *Connection) waitForDataOutput() error { for i := 0; i < 16; i++ { select { case <-c.dataOutput.Wait(): return nil default: runtime.Gosched() } } duration := time.Second * 16 if !c.wd.IsZero() { duration = time.Until(c.wd) if duration < 0 { return ErrIOTimeout } } timeout := time.NewTimer(duration) defer timeout.Stop() select { case <-c.dataOutput.Wait(): case <-timeout.C: if !c.wd.IsZero() && c.wd.Before(time.Now()) { return ErrIOTimeout } } return nil } // Write implements io.Writer. func (c *Connection) Write(b []byte) (int, error) { reader := bytes.NewReader(b) if err := c.writeMultiBufferInternal(reader); err != nil { return 0, err } return len(b), nil } // WriteMultiBuffer implements buf.Writer. func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error { reader := &buf.MultiBufferContainer{ MultiBuffer: mb, } defer reader.Close() return c.writeMultiBufferInternal(reader) } func (c *Connection) writeMultiBufferInternal(reader io.Reader) error { updatePending := false defer func() { if updatePending { c.dataUpdater.WakeUp() } }() var b *buf.Buffer defer b.Release() for { for { if c == nil || c.State() != StateActive { return io.ErrClosedPipe } if b == nil { b = buf.New() _, err := b.ReadFrom(io.LimitReader(reader, int64(c.mss))) if err != nil { return nil } } if !c.sendingWorker.Push(b) { break } updatePending = true b = nil } if updatePending { c.dataUpdater.WakeUp() updatePending = false } if err := c.waitForDataOutput(); err != nil { return err } } } func (c *Connection) SetState(state State) { current := c.Elapsed() atomic.StoreInt32((*int32)(&c.state), int32(state)) atomic.StoreUint32(&c.stateBeginTime, current) errors.LogDebug(context.Background(), "#", c.meta.Conversation, " entering state ", state, " at ", current) switch state { case StateReadyToClose: c.receivingWorker.CloseRead() case StatePeerClosed: c.sendingWorker.CloseWrite() case StateTerminating: c.receivingWorker.CloseRead() c.sendingWorker.CloseWrite() c.pingUpdater.SetInterval(time.Second) case StatePeerTerminating: c.sendingWorker.CloseWrite() c.pingUpdater.SetInterval(time.Second) case StateTerminated: c.receivingWorker.CloseRead() c.sendingWorker.CloseWrite() c.pingUpdater.SetInterval(time.Second) c.dataUpdater.WakeUp() c.pingUpdater.WakeUp() go c.Terminate() } } // Close closes the connection. func (c *Connection) Close() error { if c == nil { return ErrClosedConnection } c.dataInput.Signal() c.dataOutput.Signal() switch c.State() { case StateReadyToClose, StateTerminating, StateTerminated: return ErrClosedConnection case StateActive: c.SetState(StateReadyToClose) case StatePeerClosed: c.SetState(StateTerminating) case StatePeerTerminating: c.SetState(StateTerminated) } errors.LogInfo(context.Background(), "#", c.meta.Conversation, " closing connection to ", c.meta.RemoteAddr) return nil } // LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it. func (c *Connection) LocalAddr() net.Addr { if c == nil { return nil } return c.meta.LocalAddr } // RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it. func (c *Connection) RemoteAddr() net.Addr { if c == nil { return nil } return c.meta.RemoteAddr } // SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline. func (c *Connection) SetDeadline(t time.Time) error { if err := c.SetReadDeadline(t); err != nil { return err } return c.SetWriteDeadline(t) } // SetReadDeadline implements the Conn SetReadDeadline method. func (c *Connection) SetReadDeadline(t time.Time) error { if c == nil || c.State() != StateActive { return ErrClosedConnection } c.rd = t return nil } // SetWriteDeadline implements the Conn SetWriteDeadline method. func (c *Connection) SetWriteDeadline(t time.Time) error { if c == nil || c.State() != StateActive { return ErrClosedConnection } c.wd = t return nil } // kcp update, input loop func (c *Connection) updateTask() { c.flush() } func (c *Connection) Terminate() { if c == nil { return } errors.LogInfo(context.Background(), "#", c.meta.Conversation, " terminating connection to ", c.RemoteAddr()) // v.SetState(StateTerminated) c.dataInput.Signal() c.dataOutput.Signal() c.closer.Close() c.sendingWorker.Release() c.receivingWorker.Release() } func (c *Connection) HandleOption(opt SegmentOption) { if (opt & SegmentOptionClose) == SegmentOptionClose { c.OnPeerClosed() } } func (c *Connection) OnPeerClosed() { switch c.State() { case StateReadyToClose: c.SetState(StateTerminating) case StateActive: c.SetState(StatePeerClosed) } } // Input when you received a low level packet (eg. UDP packet), call it func (c *Connection) Input(segments []Segment) { current := c.Elapsed() atomic.StoreUint32(&c.lastIncomingTime, current) for _, seg := range segments { if seg.Conversation() != c.meta.Conversation { break } switch seg := seg.(type) { case *DataSegment: c.HandleOption(seg.Option) c.receivingWorker.ProcessSegment(seg) if c.receivingWorker.IsDataAvailable() { c.dataInput.Signal() } c.dataUpdater.WakeUp() case *AckSegment: c.HandleOption(seg.Option) c.sendingWorker.ProcessSegment(current, seg, c.roundTrip.Timeout()) c.dataOutput.Signal() c.dataUpdater.WakeUp() case *CmdOnlySegment: c.HandleOption(seg.Option) if seg.Command() == CommandTerminate { switch c.State() { case StateActive, StatePeerClosed: c.SetState(StatePeerTerminating) case StateReadyToClose: c.SetState(StateTerminating) case StateTerminating: c.SetState(StateTerminated) } } if seg.Option == SegmentOptionClose || seg.Command() == CommandTerminate { c.dataInput.Signal() c.dataOutput.Signal() } c.sendingWorker.ProcessReceivingNext(seg.ReceivingNext) c.receivingWorker.ProcessSendingNext(seg.SendingNext) c.roundTrip.UpdatePeerRTO(seg.PeerRTO, current) seg.Release() default: } } } func (c *Connection) flush() { current := c.Elapsed() if c.State() == StateTerminated { return } if c.State() == StateActive && current-atomic.LoadUint32(&c.lastIncomingTime) >= 30000 { c.Close() } if c.State() == StateReadyToClose && c.sendingWorker.IsEmpty() { c.SetState(StateTerminating) } if c.State() == StateTerminating { errors.LogDebug(context.Background(), "#", c.meta.Conversation, " sending terminating cmd.") c.Ping(current, CommandTerminate) if current-atomic.LoadUint32(&c.stateBeginTime) > 8000 { c.SetState(StateTerminated) } return } if c.State() == StatePeerTerminating && current-atomic.LoadUint32(&c.stateBeginTime) > 4000 { c.SetState(StateTerminating) } if c.State() == StateReadyToClose && current-atomic.LoadUint32(&c.stateBeginTime) > 15000 { c.SetState(StateTerminating) } // flush acknowledges c.receivingWorker.Flush(current) c.sendingWorker.Flush(current) if current-atomic.LoadUint32(&c.lastPingTime) >= 3000 { c.Ping(current, CommandPing) } } func (c *Connection) State() State { return State(atomic.LoadInt32((*int32)(&c.state))) } func (c *Connection) Ping(current uint32, cmd Command) { seg := NewCmdOnlySegment() seg.Conv = c.meta.Conversation seg.Cmd = cmd seg.ReceivingNext = c.receivingWorker.NextNumber() seg.SendingNext = c.sendingWorker.FirstUnacknowledged() seg.PeerRTO = c.roundTrip.Timeout() if c.State() == StateReadyToClose { seg.Option = SegmentOptionClose } c.output.Write(seg) atomic.StoreUint32(&c.lastPingTime, current) seg.Release() }