Parse big server hello properly

This commit is contained in:
yuhan6665 2022-11-21 22:37:22 -05:00
parent e5e9e58d66
commit d87758d46f
3 changed files with 55 additions and 32 deletions

View File

@ -247,7 +247,9 @@ func ReadV(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, c
} }
// XtlsRead filter and read xtls protocol // XtlsRead filter and read xtls protocol
func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, rawConn syscall.RawConn, counter stats.Counter, ctx context.Context, userUUID []byte, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool) error { func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, rawConn syscall.RawConn,
counter stats.Counter, ctx context.Context, userUUID []byte, numberOfPacketToFilter *int, enableXtls *bool,
isTLS12orAbove *bool, isTLS *bool, cipher *uint16, remainingServerHello *int32) error {
err := func() error { err := func() error {
var ct stats.Counter var ct stats.Counter
filterUUID := true filterUUID := true
@ -306,7 +308,7 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater
} }
} }
if *numberOfPacketToFilter > 0 { if *numberOfPacketToFilter > 0 {
XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, ctx) XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, cipher, remainingServerHello, ctx)
} }
if ct != nil { if ct != nil {
ct.Add(int64(buffer.Len())) ct.Add(int64(buffer.Len()))
@ -328,7 +330,9 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater
} }
// XtlsWrite filter and write xtls protocol // XtlsWrite filter and write xtls protocol
func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, counter stats.Counter, ctx context.Context, userUUID *[]byte, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool) error { func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, counter stats.Counter,
ctx context.Context, userUUID *[]byte, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool,
cipher *uint16, remainingServerHello *int32) error {
err := func() error { err := func() error {
var ct stats.Counter var ct stats.Counter
filterTlsApplicationData := true filterTlsApplicationData := true
@ -337,7 +341,7 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate
buffer, err := reader.ReadMultiBuffer() buffer, err := reader.ReadMultiBuffer()
if !buffer.IsEmpty() { if !buffer.IsEmpty() {
if *numberOfPacketToFilter > 0 { if *numberOfPacketToFilter > 0 {
XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, ctx) XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, cipher, remainingServerHello, ctx)
} }
if filterTlsApplicationData && *isTLS { if filterTlsApplicationData && *isTLS {
buffer = ReshapeMultiBuffer(ctx, buffer) buffer = ReshapeMultiBuffer(ctx, buffer)
@ -399,40 +403,51 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate
} }
// XtlsFilterTls filter and recognize tls 1.3 and other info // XtlsFilterTls filter and recognize tls 1.3 and other info
func XtlsFilterTls(buffer buf.MultiBuffer, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool, ctx context.Context) { func XtlsFilterTls(buffer buf.MultiBuffer, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool,
cipher *uint16, remainingServerHello *int32, ctx context.Context) {
for _, b := range buffer { for _, b := range buffer {
*numberOfPacketToFilter-- *numberOfPacketToFilter--
if b.Len() >= 6 { if b.Len() >= 6 {
startsBytes := b.BytesTo(6) startsBytes := b.BytesTo(6)
if bytes.Equal(tlsServerHandShakeStart, startsBytes[:3]) && startsBytes[5] == 0x02 { if bytes.Equal(tlsServerHandShakeStart, startsBytes[:3]) && startsBytes[5] == 0x02 {
total := (int(startsBytes[3])<<8 | int(startsBytes[4])) + 5 *remainingServerHello = (int32(startsBytes[3])<<8 | int32(startsBytes[4])) + 5
if b.Len() >= 74 && total >= 74 {
if bytes.Contains(b.BytesTo(int32(total)), tls13SupportedVersions) {
sessionIdLen := int32(b.Byte(43))
cipherSuite := b.BytesRange(43 + sessionIdLen + 1, 43 + sessionIdLen + 3)
cipherNum := uint16(cipherSuite[0]) << 8 | uint16(cipherSuite[1])
v, ok := Tls13CipherSuiteDic[cipherNum]
if !ok {
v = "Unknown cipher!"
} else if (v != "TLS_AES_128_CCM_8_SHA256") {
*enableXtls = true
}
newError("XtlsFilterTls found tls 1.3! ", buffer.Len(), " ", v).WriteToLog(session.ExportIDToError(ctx))
} else {
newError("XtlsFilterTls found tls 1.2! ", buffer.Len()).WriteToLog(session.ExportIDToError(ctx))
}
*isTLS12orAbove = true *isTLS12orAbove = true
*isTLS = true *isTLS = true
*numberOfPacketToFilter = 0 if b.Len() >= 79 && *remainingServerHello >= 79 {
return sessionIdLen := int32(b.Byte(43))
cipherSuite := b.BytesRange(43 + sessionIdLen + 1, 43 + sessionIdLen + 3)
*cipher = uint16(cipherSuite[0]) << 8 | uint16(cipherSuite[1])
} else { } else {
newError("XtlsFilterTls short server hello, tls 1.2 or older? ", b.Len(), " ", total).WriteToLog(session.ExportIDToError(ctx)) newError("XtlsFilterTls short server hello, tls 1.2 or older? ", b.Len(), " ", *remainingServerHello).WriteToLog(session.ExportIDToError(ctx))
} }
} else if bytes.Equal(tlsClientHandShakeStart, startsBytes[:2]) && startsBytes[5] == 0x01 { } else if bytes.Equal(tlsClientHandShakeStart, startsBytes[:2]) && startsBytes[5] == 0x01 {
*isTLS = true *isTLS = true
newError("XtlsFilterTls found tls client hello! ", buffer.Len()).WriteToLog(session.ExportIDToError(ctx)) newError("XtlsFilterTls found tls client hello! ", buffer.Len()).WriteToLog(session.ExportIDToError(ctx))
} }
} }
if *remainingServerHello > 0 {
end := *remainingServerHello
if end > b.Len() {
end = b.Len()
}
*remainingServerHello -= b.Len()
if bytes.Contains(b.BytesTo(end), tls13SupportedVersions) {
v, ok := Tls13CipherSuiteDic[*cipher]
if !ok {
v = "Old cipher: " + strconv.FormatUint(uint64(*cipher), 16)
} else if (v != "TLS_AES_128_CCM_8_SHA256") {
*enableXtls = true
}
newError("XtlsFilterTls found tls 1.3! ", b.Len(), " ", v).WriteToLog(session.ExportIDToError(ctx))
*numberOfPacketToFilter = 0
return
} else if *remainingServerHello <= 0 {
newError("XtlsFilterTls found tls 1.2! ", b.Len()).WriteToLog(session.ExportIDToError(ctx))
*numberOfPacketToFilter = 0
return
}
newError("XtlsFilterTls inclusive server hello ", b.Len(), " ", *remainingServerHello).WriteToLog(session.ExportIDToError(ctx))
}
if *numberOfPacketToFilter <= 0 { if *numberOfPacketToFilter <= 0 {
newError("XtlsFilterTls stop filtering", buffer.Len()).WriteToLog(session.ExportIDToError(ctx)) newError("XtlsFilterTls stop filtering", buffer.Len()).WriteToLog(session.ExportIDToError(ctx))
} }

View File

@ -511,6 +511,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
enableXtls := false enableXtls := false
isTLS12orAbove := false isTLS12orAbove := false
isTLS := false isTLS := false
var cipher uint16 = 0
var remainingServerHello int32 = -1
numberOfPacketToFilter := 8 numberOfPacketToFilter := 8
postRequest := func() error { postRequest := func() error {
@ -529,7 +531,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
//TODO enable splice //TODO enable splice
ctx = session.ContextWithInbound(ctx, nil) ctx = session.ContextWithInbound(ctx, nil)
if requestAddons.Flow == vless.XRV { if requestAddons.Flow == vless.XRV {
err = encoding.XtlsRead(clientReader, serverWriter, timer, netConn, rawConn, counter, ctx, account.ID.Bytes(), &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS) err = encoding.XtlsRead(clientReader, serverWriter, timer, netConn, rawConn, counter, ctx, account.ID.Bytes(),
&numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
} else { } else {
err = encoding.ReadV(clientReader, serverWriter, timer, iConn.(*xtls.Conn), rawConn, counter, ctx) err = encoding.ReadV(clientReader, serverWriter, timer, iConn.(*xtls.Conn), rawConn, counter, ctx)
} }
@ -561,7 +564,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
return err1 // ... return err1 // ...
} }
if requestAddons.Flow == vless.XRV { if requestAddons.Flow == vless.XRV {
encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, ctx) encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello, ctx)
if isTLS { if isTLS {
multiBuffer = encoding.ReshapeMultiBuffer(ctx, multiBuffer) multiBuffer = encoding.ReshapeMultiBuffer(ctx, multiBuffer)
for i, b := range multiBuffer { for i, b := range multiBuffer {
@ -583,7 +586,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
if statConn != nil { if statConn != nil {
counter = statConn.WriteCounter counter = statConn.WriteCounter
} }
err = encoding.XtlsWrite(serverReader, clientWriter, timer, netConn, counter, ctx, &userUUID, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS) err = encoding.XtlsWrite(serverReader, clientWriter, timer, netConn, counter, ctx, &userUUID, &numberOfPacketToFilter,
&enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
} else { } else {
// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer
err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)) err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))

View File

@ -193,6 +193,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
enableXtls := false enableXtls := false
isTLS12orAbove := false isTLS12orAbove := false
isTLS := false isTLS := false
var cipher uint16 = 0
var remainingServerHello int32 = -1
numberOfPacketToFilter := 8 numberOfPacketToFilter := 8
if request.Command == protocol.RequestCommandUDP && h.cone && request.Port != 53 && request.Port != 443 { if request.Command == protocol.RequestCommandUDP && h.cone && request.Port != 53 && request.Port != 443 {
@ -220,7 +222,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
return err1 // ... return err1 // ...
} }
if requestAddons.Flow == vless.XRV { if requestAddons.Flow == vless.XRV {
encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, ctx) encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello, ctx)
if isTLS { if isTLS {
for i, b := range multiBuffer { for i, b := range multiBuffer {
multiBuffer[i] = encoding.XtlsPadding(b, 0x00, &userUUID, ctx) multiBuffer[i] = encoding.XtlsPadding(b, 0x00, &userUUID, ctx)
@ -241,7 +243,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
if statConn != nil { if statConn != nil {
counter = statConn.WriteCounter counter = statConn.WriteCounter
} }
err = encoding.XtlsWrite(clientReader, serverWriter, timer, netConn, counter, ctx, &userUUID, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS) err = encoding.XtlsWrite(clientReader, serverWriter, timer, netConn, counter, ctx, &userUUID, &numberOfPacketToFilter,
&enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
} else { } else {
// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer
err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer))
@ -277,7 +280,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
counter = statConn.ReadCounter counter = statConn.ReadCounter
} }
if requestAddons.Flow == vless.XRV { if requestAddons.Flow == vless.XRV {
err = encoding.XtlsRead(serverReader, clientWriter, timer, netConn, rawConn, counter, ctx, account.ID.Bytes(), &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS) err = encoding.XtlsRead(serverReader, clientWriter, timer, netConn, rawConn, counter, ctx, account.ID.Bytes(),
&numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
} else { } else {
if requestAddons.Flow != vless.XRS { if requestAddons.Flow != vless.XRS {
ctx = session.ContextWithInbound(ctx, nil) ctx = session.ContextWithInbound(ctx, nil)