Disable VMess drain when not pure connection

This commit is contained in:
RPRX 2020-12-18 12:45:47 +00:00 committed by GitHub
parent ff9bb2d8df
commit f390047b37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 7 deletions

View File

@ -57,14 +57,14 @@ func TestRequestSerialization(t *testing.T) {
defer common.Close(userValidator) defer common.Close(userValidator)
server := NewServerSession(userValidator, sessionHistory) server := NewServerSession(userValidator, sessionHistory)
actualRequest, err := server.DecodeRequestHeader(buffer) actualRequest, err := server.DecodeRequestHeader(buffer, false)
common.Must(err) common.Must(err)
if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" { if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" {
t.Error(r) t.Error(r)
} }
_, err = server.DecodeRequestHeader(buffer2) _, err = server.DecodeRequestHeader(buffer2, false)
// anti replay attack // anti replay attack
if err == nil { if err == nil {
t.Error("nil error") t.Error("nil error")
@ -107,7 +107,7 @@ func TestInvalidRequest(t *testing.T) {
defer common.Close(userValidator) defer common.Close(userValidator)
server := NewServerSession(userValidator, sessionHistory) server := NewServerSession(userValidator, sessionHistory)
_, err := server.DecodeRequestHeader(buffer) _, err := server.DecodeRequestHeader(buffer, false)
if err == nil { if err == nil {
t.Error("nil error") t.Error("nil error")
} }
@ -148,7 +148,7 @@ func TestMuxRequest(t *testing.T) {
defer common.Close(userValidator) defer common.Close(userValidator)
server := NewServerSession(userValidator, sessionHistory) server := NewServerSession(userValidator, sessionHistory)
actualRequest, err := server.DecodeRequestHeader(buffer) actualRequest, err := server.DecodeRequestHeader(buffer, false)
common.Must(err) common.Must(err)
if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" { if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" {

View File

@ -131,7 +131,7 @@ func parseSecurityType(b byte) protocol.SecurityType {
} }
// DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream. // DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream.
func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) { func (s *ServerSession) DecodeRequestHeader(reader io.Reader, isDrain bool) (*protocol.RequestHeader, error) {
buffer := buf.New() buffer := buf.New()
behaviorRand := dice.NewDeterministicDice(int64(s.userValidator.GetBehaviorSeed())) behaviorRand := dice.NewDeterministicDice(int64(s.userValidator.GetBehaviorSeed()))
BaseDrainSize := behaviorRand.Roll(3266) BaseDrainSize := behaviorRand.Roll(3266)
@ -143,7 +143,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
drainConnection := func(e error) error { drainConnection := func(e error) error {
// We read a deterministic generated length of data before closing the connection to offset padding read pattern // We read a deterministic generated length of data before closing the connection to offset padding read pattern
readSizeRemain -= int(buffer.Len()) readSizeRemain -= int(buffer.Len())
if readSizeRemain > 0 { if readSizeRemain > 0 && isDrain {
err := s.DrainConnN(reader, readSizeRemain) err := s.DrainConnN(reader, readSizeRemain)
if err != nil { if err != nil {
return newError("failed to drain connection DrainSize = ", BaseDrainSize, " ", RandDrainMax, " ", RandDrainRolled).Base(err).Base(e) return newError("failed to drain connection DrainSize = ", BaseDrainSize, " ", RandDrainMax, " ", RandDrainRolled).Base(err).Base(e)

View File

@ -220,9 +220,18 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
return newError("unable to set read deadline").Base(err).AtWarning() return newError("unable to set read deadline").Base(err).AtWarning()
} }
iConn := connection
if statConn, ok := iConn.(*internet.StatCouterConnection); ok {
iConn = statConn.Connection
}
_, isDrain := iConn.(*net.TCPConn)
if !isDrain {
_, isDrain = iConn.(*net.UnixConn)
}
reader := &buf.BufferedReader{Reader: buf.NewReader(connection)} reader := &buf.BufferedReader{Reader: buf.NewReader(connection)}
svrSession := encoding.NewServerSession(h.clients, h.sessionHistory) svrSession := encoding.NewServerSession(h.clients, h.sessionHistory)
request, err := svrSession.DecodeRequestHeader(reader) request, err := svrSession.DecodeRequestHeader(reader, isDrain)
if err != nil { if err != nil {
if errors.Cause(err) != io.EOF { if errors.Cause(err) != io.EOF {
log.Record(&log.AccessMessage{ log.Record(&log.AccessMessage{