diff --git a/common/mux/client.go b/common/mux/client.go index df74ef17..9df339e0 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -334,9 +334,13 @@ func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere rr := s.NewReader(reader, &meta.Target) err := buf.Copy(rr, s.output) + if err != nil { + s.Close(false) + } if err != nil && buf.IsWriteError(err) { errors.LogInfoInner(context.Background(), err, "failed to write to downstream. closing session ", s.ID) - s.Close(false) + closingWriter := NewResponseWriter(s.ID, m.link.Writer, protocol.TransferTypeStream) + closingWriter.Close() return buf.Copy(rr, buf.Discard) } diff --git a/common/mux/server.go b/common/mux/server.go index 8fc0ae09..2e009e2c 100644 --- a/common/mux/server.go +++ b/common/mux/server.go @@ -201,11 +201,12 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, transferType: protocol.TransferTypePacket, XUDP: x, } - go handle(ctx, x.Mux, w.link.Writer) - x.Status = Active if !w.sessionManager.Add(x.Mux) { x.Mux.Close(false) + return errors.New("failed to add session.") } + x.Status = Active + go handle(ctx, x.Mux, w.link.Writer) return nil } @@ -226,18 +227,26 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, if meta.Target.Network == net.Network_UDP { s.transferType = protocol.TransferTypePacket } - w.sessionManager.Add(s) + if !w.sessionManager.Add(s) { + s.Close(false) + return errors.New("failed to add session.") + } go handle(ctx, s, w.link.Writer) if !meta.Option.Has(OptionData) { return nil } rr := s.NewReader(reader, &meta.Target) - if err := buf.Copy(rr, s.output); err != nil { - buf.Copy(rr, buf.Discard) - return s.Close(false) + err = buf.Copy(rr, s.output) + if err != nil { + s.Close(false) } - return nil + if err != nil && buf.IsWriteError(err) { + closingWriter := NewResponseWriter(s.ID, w.link.Writer, protocol.TransferTypeStream) + closingWriter.Close() + return buf.Copy(rr, buf.Discard) + } + return err } func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error { @@ -256,10 +265,13 @@ func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere rr := s.NewReader(reader, &meta.Target) err := buf.Copy(rr, s.output) - + if err != nil { + s.Close(false) + } if err != nil && buf.IsWriteError(err) { errors.LogInfoInner(context.Background(), err, "failed to write to downstream writer. closing session ", s.ID) - s.Close(false) + closingWriter := NewResponseWriter(s.ID, w.link.Writer, protocol.TransferTypeStream) + closingWriter.Close() return buf.Copy(rr, buf.Discard) } @@ -304,10 +316,11 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedRead } func (w *ServerWorker) run(ctx context.Context) { - input := w.link.Reader - reader := &buf.BufferedReader{Reader: input} + reader := &buf.BufferedReader{Reader: w.link.Reader} defer w.sessionManager.Close() + defer common.Close(w.link.Writer) + defer common.Interrupt(w.link.Reader) for { select { @@ -318,7 +331,6 @@ func (w *ServerWorker) run(ctx context.Context) { if err != nil { if errors.Cause(err) != io.EOF { errors.LogInfoInner(ctx, err, "unexpected EOF") - common.Interrupt(input) } return }