diff --git a/common/mux/client.go b/common/mux/client.go index df74ef17..6c9e90bc 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -276,6 +276,8 @@ func (m *ClientWorker) IsClosing() bool { return false } +// IsFull returns true if this ClientWorker is unable to accept more connections. +// it might be because it is closing, or the number of connections has reached the limit. func (m *ClientWorker) IsFull() bool { if m.IsClosing() || m.Closed() { return true @@ -289,12 +291,12 @@ func (m *ClientWorker) IsFull() bool { } func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool { - if m.IsFull() || m.Closed() { + if m.IsFull() { return false } sm := m.sessionManager - s := sm.Allocate() + s := sm.Allocate(&m.strategy) if s == nil { return false } diff --git a/common/mux/server.go b/common/mux/server.go index 8fc0ae09..30dcf06e 100644 --- a/common/mux/server.go +++ b/common/mux/server.go @@ -201,11 +201,12 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, transferType: protocol.TransferTypePacket, XUDP: x, } - go handle(ctx, x.Mux, w.link.Writer) x.Status = Active if !w.sessionManager.Add(x.Mux) { x.Mux.Close(false) + return errors.New("failed to add new session") } + go handle(ctx, x.Mux, w.link.Writer) return nil } @@ -226,18 +227,23 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, if meta.Target.Network == net.Network_UDP { s.transferType = protocol.TransferTypePacket } - w.sessionManager.Add(s) + if !w.sessionManager.Add(s) { + s.Close(false) + return errors.New("failed to add new session") + } go handle(ctx, s, w.link.Writer) if !meta.Option.Has(OptionData) { return nil } rr := s.NewReader(reader, &meta.Target) - if err := buf.Copy(rr, s.output); err != nil { - buf.Copy(rr, buf.Discard) - return s.Close(false) + err = buf.Copy(rr, s.output) + + if err != nil && buf.IsWriteError(err) { + s.Close(false) + return buf.Copy(rr, buf.Discard) } - return nil + return err } func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error { @@ -304,10 +310,11 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedRead } func (w *ServerWorker) run(ctx context.Context) { - input := w.link.Reader - reader := &buf.BufferedReader{Reader: input} + reader := &buf.BufferedReader{Reader: w.link.Reader} defer w.sessionManager.Close() + defer common.Close(w.link.Writer) + defer common.Interrupt(w.link.Reader) for { select { @@ -318,7 +325,6 @@ func (w *ServerWorker) run(ctx context.Context) { if err != nil { if errors.Cause(err) != io.EOF { errors.LogInfoInner(ctx, err, "unexpected EOF") - common.Interrupt(input) } return } diff --git a/common/mux/session.go b/common/mux/session.go index 5e4b69ca..8bcb01bb 100644 --- a/common/mux/session.go +++ b/common/mux/session.go @@ -50,11 +50,14 @@ func (m *SessionManager) Count() int { return int(m.count) } -func (m *SessionManager) Allocate() *Session { +func (m *SessionManager) Allocate(Strategy *ClientStrategy) *Session { m.Lock() defer m.Unlock() + + MaxConcurrency := int(Strategy.MaxConcurrency) + MaxConnection := uint16(Strategy.MaxConnection) - if m.closed { + if m.closed || (MaxConcurrency > 0 && len(m.sessions) >= MaxConcurrency) || (MaxConnection > 0 && m.count >= MaxConnection) { return nil } diff --git a/common/mux/session_test.go b/common/mux/session_test.go index d81ad8c4..a8491a9c 100644 --- a/common/mux/session_test.go +++ b/common/mux/session_test.go @@ -9,7 +9,7 @@ import ( func TestSessionManagerAdd(t *testing.T) { m := NewSessionManager() - s := m.Allocate() + s := m.Allocate(&ClientStrategy{}) if s.ID != 1 { t.Error("id: ", s.ID) } @@ -17,7 +17,7 @@ func TestSessionManagerAdd(t *testing.T) { t.Error("size: ", m.Size()) } - s = m.Allocate() + s = m.Allocate(&ClientStrategy{}) if s.ID != 2 { t.Error("id: ", s.ID) } @@ -39,7 +39,7 @@ func TestSessionManagerAdd(t *testing.T) { func TestSessionManagerClose(t *testing.T) { m := NewSessionManager() - s := m.Allocate() + s := m.Allocate(&ClientStrategy{}) if m.CloseIfNoSession() { t.Error("able to close")