diff --git a/app/reverse/config.go b/app/reverse/config.go index 517b6170..8ce38c9c 100644 --- a/app/reverse/config.go +++ b/app/reverse/config.go @@ -9,6 +9,7 @@ import ( func (c *Control) FillInRandom() { randomLength := dice.Roll(64) + randomLength++ c.Random = make([]byte, randomLength) io.ReadFull(rand.Reader, c.Random) } diff --git a/app/reverse/portal.go b/app/reverse/portal.go index 818c5718..8efe3996 100644 --- a/app/reverse/portal.go +++ b/app/reverse/portal.go @@ -170,7 +170,7 @@ func (p *StaticMuxPicker) PickAvailable() (*mux.ClientWorker, error) { if w.draining { continue } - if w.client.Closed() { + if w.IsFull() { continue } if w.client.ActiveConnections() < minConn { @@ -211,6 +211,7 @@ type PortalWorker struct { writer buf.Writer reader buf.Reader draining bool + counter uint32 } func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) { @@ -244,7 +245,7 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) { } func (w *PortalWorker) heartbeat() error { - if w.client.Closed() { + if w.Closed() { return errors.New("client worker stopped") } @@ -260,16 +261,21 @@ func (w *PortalWorker) heartbeat() error { msg.State = Control_DRAIN defer func() { + w.client.GetTimer().Reset(time.Second * 16) common.Close(w.writer) common.Interrupt(w.reader) w.writer = nil }() } - b, err := proto.Marshal(msg) - common.Must(err) - mb := buf.MergeBytes(nil, b) - return w.writer.WriteMultiBuffer(mb) + w.counter = (w.counter + 1) % 5 + if w.draining || w.counter == 1 { + b, err := proto.Marshal(msg) + common.Must(err) + mb := buf.MergeBytes(nil, b) + return w.writer.WriteMultiBuffer(mb) + } + return nil } func (w *PortalWorker) IsFull() bool { diff --git a/common/mux/client.go b/common/mux/client.go index df74ef17..3ebcc182 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -173,6 +173,7 @@ type ClientWorker struct { sessionManager *SessionManager link transport.Link done *done.Instance + timer *time.Ticker strategy ClientStrategy } @@ -187,6 +188,7 @@ func NewClientWorker(stream transport.Link, s ClientStrategy) (*ClientWorker, er sessionManager: NewSessionManager(), link: stream, done: done.New(), + timer: time.NewTicker(time.Second * 16), strategy: s, } @@ -209,9 +211,12 @@ func (m *ClientWorker) Closed() bool { return m.done.Done() } +func (m *ClientWorker) GetTimer() *time.Ticker { + return m.timer +} + func (m *ClientWorker) monitor() { - timer := time.NewTicker(time.Second * 16) - defer timer.Stop() + defer m.timer.Stop() for { select { @@ -220,7 +225,7 @@ func (m *ClientWorker) monitor() { common.Close(m.link.Writer) common.Interrupt(m.link.Reader) return - case <-timer.C: + case <-m.timer.C: size := m.sessionManager.Size() if size == 0 && m.sessionManager.CloseIfNoSession() { common.Must(m.done.Close())