diff --git a/app/reverse/portal.go b/app/reverse/portal.go index 3168b59a..8efe3996 100644 --- a/app/reverse/portal.go +++ b/app/reverse/portal.go @@ -76,7 +76,7 @@ func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) err return errors.New("failed to create mux client worker").Base(err).AtWarning() } - worker, err := NewPortalWorker(muxClient, p.picker) + worker, err := NewPortalWorker(muxClient) if err != nil { return errors.New("failed to create portal worker").Base(err) } @@ -206,16 +206,15 @@ func (p *StaticMuxPicker) AddWorker(worker *PortalWorker) { } type PortalWorker struct { - client *mux.ClientWorker - control *task.Periodic - writer buf.Writer - reader buf.Reader - draining bool - counter uint32 - parentPicker *StaticMuxPicker + client *mux.ClientWorker + control *task.Periodic + writer buf.Writer + reader buf.Reader + draining bool + counter uint32 } -func NewPortalWorker(client *mux.ClientWorker, picker *StaticMuxPicker) (*PortalWorker, error) { +func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) { opt := []pipe.Option{pipe.WithSizeLimit(16 * 1024)} uplinkReader, uplinkWriter := pipe.New(opt...) downlinkReader, downlinkWriter := pipe.New(opt...) @@ -233,10 +232,9 @@ func NewPortalWorker(client *mux.ClientWorker, picker *StaticMuxPicker) (*Portal return nil, errors.New("unable to dispatch control connection") } w := &PortalWorker{ - client: client, - reader: downlinkReader, - writer: uplinkWriter, - parentPicker: picker, + client: client, + reader: downlinkReader, + writer: uplinkWriter, } w.control = &task.Periodic{ Execute: w.heartbeat, @@ -251,18 +249,23 @@ func (w *PortalWorker) heartbeat() error { return errors.New("client worker stopped") } - if w.writer == nil { + if w.draining || w.writer == nil { return errors.New("already disposed") } msg := &Control{} msg.FillInRandom() - if w.draining || w.client.TotalConnections() > 256 { + if w.client.TotalConnections() > 256 { w.draining = true msg.State = Control_DRAIN - defer w.tryCloseHeartbeat() + defer func() { + w.client.GetTimer().Reset(time.Second * 16) + common.Close(w.writer) + common.Interrupt(w.reader) + w.writer = nil + }() } w.counter = (w.counter + 1) % 5 @@ -282,20 +285,3 @@ func (w *PortalWorker) IsFull() bool { func (w *PortalWorker) Closed() bool { return w.client.Closed() } - -func (w *PortalWorker) tryCloseHeartbeat() { - w.parentPicker.access.Lock() - closeable := false - for _, wo := range w.parentPicker.workers { - if wo != w && !wo.IsFull() && !wo.draining && wo.writer != nil { - closeable = true - break - } - } - w.parentPicker.access.Unlock() - if closeable { - common.Close(w.writer) - common.Interrupt(w.reader) - w.writer = nil - } -} diff --git a/common/mux/client.go b/common/mux/client.go index df74ef17..3ebcc182 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -173,6 +173,7 @@ type ClientWorker struct { sessionManager *SessionManager link transport.Link done *done.Instance + timer *time.Ticker strategy ClientStrategy } @@ -187,6 +188,7 @@ func NewClientWorker(stream transport.Link, s ClientStrategy) (*ClientWorker, er sessionManager: NewSessionManager(), link: stream, done: done.New(), + timer: time.NewTicker(time.Second * 16), strategy: s, } @@ -209,9 +211,12 @@ func (m *ClientWorker) Closed() bool { return m.done.Done() } +func (m *ClientWorker) GetTimer() *time.Ticker { + return m.timer +} + func (m *ClientWorker) monitor() { - timer := time.NewTicker(time.Second * 16) - defer timer.Stop() + defer m.timer.Stop() for { select { @@ -220,7 +225,7 @@ func (m *ClientWorker) monitor() { common.Close(m.link.Writer) common.Interrupt(m.link.Reader) return - case <-timer.C: + case <-m.timer.C: size := m.sessionManager.Size() if size == 0 && m.sessionManager.CloseIfNoSession() { common.Must(m.done.Close())