diff --git a/app/reverse/portal.go b/app/reverse/portal.go index 818c5718..3168b59a 100644 --- a/app/reverse/portal.go +++ b/app/reverse/portal.go @@ -76,7 +76,7 @@ func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) err return errors.New("failed to create mux client worker").Base(err).AtWarning() } - worker, err := NewPortalWorker(muxClient) + worker, err := NewPortalWorker(muxClient, p.picker) if err != nil { return errors.New("failed to create portal worker").Base(err) } @@ -170,7 +170,7 @@ func (p *StaticMuxPicker) PickAvailable() (*mux.ClientWorker, error) { if w.draining { continue } - if w.client.Closed() { + if w.IsFull() { continue } if w.client.ActiveConnections() < minConn { @@ -206,14 +206,16 @@ func (p *StaticMuxPicker) AddWorker(worker *PortalWorker) { } type PortalWorker struct { - client *mux.ClientWorker - control *task.Periodic - writer buf.Writer - reader buf.Reader - draining bool + client *mux.ClientWorker + control *task.Periodic + writer buf.Writer + reader buf.Reader + draining bool + counter uint32 + parentPicker *StaticMuxPicker } -func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) { +func NewPortalWorker(client *mux.ClientWorker, picker *StaticMuxPicker) (*PortalWorker, error) { opt := []pipe.Option{pipe.WithSizeLimit(16 * 1024)} uplinkReader, uplinkWriter := pipe.New(opt...) downlinkReader, downlinkWriter := pipe.New(opt...) @@ -231,9 +233,10 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) { return nil, errors.New("unable to dispatch control connection") } w := &PortalWorker{ - client: client, - reader: downlinkReader, - writer: uplinkWriter, + client: client, + reader: downlinkReader, + writer: uplinkWriter, + parentPicker: picker, } w.control = &task.Periodic{ Execute: w.heartbeat, @@ -244,32 +247,32 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) { } func (w *PortalWorker) heartbeat() error { - if w.client.Closed() { + if w.Closed() { return errors.New("client worker stopped") } - if w.draining || w.writer == nil { + if w.writer == nil { return errors.New("already disposed") } msg := &Control{} msg.FillInRandom() - if w.client.TotalConnections() > 256 { + if w.draining || w.client.TotalConnections() > 256 { w.draining = true msg.State = Control_DRAIN - defer func() { - common.Close(w.writer) - common.Interrupt(w.reader) - w.writer = nil - }() + defer w.tryCloseHeartbeat() } - b, err := proto.Marshal(msg) - common.Must(err) - mb := buf.MergeBytes(nil, b) - return w.writer.WriteMultiBuffer(mb) + w.counter = (w.counter + 1) % 5 + if w.draining || w.counter == 1 { + b, err := proto.Marshal(msg) + common.Must(err) + mb := buf.MergeBytes(nil, b) + return w.writer.WriteMultiBuffer(mb) + } + return nil } func (w *PortalWorker) IsFull() bool { @@ -279,3 +282,20 @@ func (w *PortalWorker) IsFull() bool { func (w *PortalWorker) Closed() bool { return w.client.Closed() } + +func (w *PortalWorker) tryCloseHeartbeat() { + w.parentPicker.access.Lock() + closeable := false + for _, wo := range w.parentPicker.workers { + if wo != w && !wo.IsFull() && !wo.draining && wo.writer != nil { + closeable = true + break + } + } + w.parentPicker.access.Unlock() + if closeable { + common.Close(w.writer) + common.Interrupt(w.reader) + w.writer = nil + } +}