This commit is contained in:
patterniha 2025-07-21 14:55:50 +08:00 committed by GitHub
commit fbf34393c8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 27 additions and 16 deletions

View file

@ -276,6 +276,8 @@ func (m *ClientWorker) IsClosing() bool {
return false return false
} }
// IsFull returns true if this ClientWorker is unable to accept more connections.
// it might be because it is closing, or the number of connections has reached the limit.
func (m *ClientWorker) IsFull() bool { func (m *ClientWorker) IsFull() bool {
if m.IsClosing() || m.Closed() { if m.IsClosing() || m.Closed() {
return true return true
@ -289,12 +291,12 @@ func (m *ClientWorker) IsFull() bool {
} }
func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool { func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool {
if m.IsFull() || m.Closed() { if m.IsFull() {
return false return false
} }
sm := m.sessionManager sm := m.sessionManager
s := sm.Allocate() s := sm.Allocate(&m.strategy)
if s == nil { if s == nil {
return false return false
} }

View file

@ -201,11 +201,12 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata,
transferType: protocol.TransferTypePacket, transferType: protocol.TransferTypePacket,
XUDP: x, XUDP: x,
} }
go handle(ctx, x.Mux, w.link.Writer)
x.Status = Active x.Status = Active
if !w.sessionManager.Add(x.Mux) { if !w.sessionManager.Add(x.Mux) {
x.Mux.Close(false) x.Mux.Close(false)
return errors.New("failed to add new session")
} }
go handle(ctx, x.Mux, w.link.Writer)
return nil return nil
} }
@ -226,18 +227,23 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata,
if meta.Target.Network == net.Network_UDP { if meta.Target.Network == net.Network_UDP {
s.transferType = protocol.TransferTypePacket s.transferType = protocol.TransferTypePacket
} }
w.sessionManager.Add(s) if !w.sessionManager.Add(s) {
s.Close(false)
return errors.New("failed to add new session")
}
go handle(ctx, s, w.link.Writer) go handle(ctx, s, w.link.Writer)
if !meta.Option.Has(OptionData) { if !meta.Option.Has(OptionData) {
return nil return nil
} }
rr := s.NewReader(reader, &meta.Target) rr := s.NewReader(reader, &meta.Target)
if err := buf.Copy(rr, s.output); err != nil { err = buf.Copy(rr, s.output)
buf.Copy(rr, buf.Discard)
return s.Close(false) if err != nil && buf.IsWriteError(err) {
s.Close(false)
return buf.Copy(rr, buf.Discard)
} }
return nil return err
} }
func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error { func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
@ -304,10 +310,11 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedRead
} }
func (w *ServerWorker) run(ctx context.Context) { func (w *ServerWorker) run(ctx context.Context) {
input := w.link.Reader reader := &buf.BufferedReader{Reader: w.link.Reader}
reader := &buf.BufferedReader{Reader: input}
defer w.sessionManager.Close() defer w.sessionManager.Close()
defer common.Close(w.link.Writer)
defer common.Interrupt(w.link.Reader)
for { for {
select { select {
@ -318,7 +325,6 @@ func (w *ServerWorker) run(ctx context.Context) {
if err != nil { if err != nil {
if errors.Cause(err) != io.EOF { if errors.Cause(err) != io.EOF {
errors.LogInfoInner(ctx, err, "unexpected EOF") errors.LogInfoInner(ctx, err, "unexpected EOF")
common.Interrupt(input)
} }
return return
} }

View file

@ -50,11 +50,14 @@ func (m *SessionManager) Count() int {
return int(m.count) return int(m.count)
} }
func (m *SessionManager) Allocate() *Session { func (m *SessionManager) Allocate(Strategy *ClientStrategy) *Session {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
MaxConcurrency := int(Strategy.MaxConcurrency)
MaxConnection := uint16(Strategy.MaxConnection)
if m.closed { if m.closed || (MaxConcurrency > 0 && len(m.sessions) >= MaxConcurrency) || (MaxConnection > 0 && m.count >= MaxConnection) {
return nil return nil
} }

View file

@ -9,7 +9,7 @@ import (
func TestSessionManagerAdd(t *testing.T) { func TestSessionManagerAdd(t *testing.T) {
m := NewSessionManager() m := NewSessionManager()
s := m.Allocate() s := m.Allocate(&ClientStrategy{})
if s.ID != 1 { if s.ID != 1 {
t.Error("id: ", s.ID) t.Error("id: ", s.ID)
} }
@ -17,7 +17,7 @@ func TestSessionManagerAdd(t *testing.T) {
t.Error("size: ", m.Size()) t.Error("size: ", m.Size())
} }
s = m.Allocate() s = m.Allocate(&ClientStrategy{})
if s.ID != 2 { if s.ID != 2 {
t.Error("id: ", s.ID) t.Error("id: ", s.ID)
} }
@ -39,7 +39,7 @@ func TestSessionManagerAdd(t *testing.T) {
func TestSessionManagerClose(t *testing.T) { func TestSessionManagerClose(t *testing.T) {
m := NewSessionManager() m := NewSessionManager()
s := m.Allocate() s := m.Allocate(&ClientStrategy{})
if m.CloseIfNoSession() { if m.CloseIfNoSession() {
t.Error("able to close") t.Error("able to close")