From 6526e74d49eec5a4bb9a5e448a271d17f262d64b Mon Sep 17 00:00:00 2001 From: RPRX <63339210+RPRX@users.noreply.github.com> Date: Thu, 2 Mar 2023 14:50:26 +0000 Subject: [PATCH] Add WaitReadCloser to make H2 real 0-RTT --- transport/internet/http/dialer.go | 66 ++++++++++++++++++++++++++----- 1 file changed, 57 insertions(+), 9 deletions(-) diff --git a/transport/internet/http/dialer.go b/transport/internet/http/dialer.go index 25ede63f..75adc249 100644 --- a/transport/internet/http/dialer.go +++ b/transport/internet/http/dialer.go @@ -3,6 +3,7 @@ package http import ( "context" gotls "crypto/tls" + "io" "net/http" "net/url" "sync" @@ -166,23 +167,70 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me // Disable any compression method from server. request.Header.Set("Accept-Encoding", "identity") - response, err := client.Do(request) - if err != nil { - return nil, newError("failed to dial to ", dest).Base(err).AtWarning() - } - if response.StatusCode != 200 { - return nil, newError("unexpected status", response.StatusCode).AtWarning() - } + wrc := &WaitReadCloser{Wait: make(chan struct{})} + go func() { + response, err := client.Do(request) + if err != nil { + newError("failed to dial to ", dest).Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) + wrc.Close() + return + } + if response.StatusCode != 200 { + newError("unexpected status", response.StatusCode).AtWarning().WriteToLog(session.ExportIDToError(ctx)) + wrc.Close() + return + } + wrc.Set(response.Body) + }() bwriter := buf.NewBufferedWriter(pwriter) common.Must(bwriter.SetBuffered(false)) return cnc.NewConnection( - cnc.ConnectionOutput(response.Body), + cnc.ConnectionOutput(wrc), cnc.ConnectionInput(bwriter), - cnc.ConnectionOnClose(common.ChainedClosable{breader, bwriter, response.Body}), + cnc.ConnectionOnClose(common.ChainedClosable{breader, bwriter, wrc}), ), nil } func init() { common.Must(internet.RegisterTransportDialer(protocolName, Dial)) } + +type WaitReadCloser struct { + Wait chan struct{} + io.ReadCloser +} + +func (w *WaitReadCloser) Set(rc io.ReadCloser) { + w.ReadCloser = rc + defer func() { + if err := recover(); err != nil { + rc.Close() + } + }() + close(w.Wait) +} + +func (w *WaitReadCloser) Read(b []byte) (int, error) { + if w.ReadCloser == nil { + if <-w.Wait; w.ReadCloser == nil { + return 0, io.ErrClosedPipe + } + } + return w.ReadCloser.Read(b) +} + +func (w *WaitReadCloser) Close() error { + if w.ReadCloser != nil { + return w.ReadCloser.Close() + } + defer func() { + if err := recover(); err != nil { + if w.ReadCloser != nil { + w.ReadCloser.Close() + } + } + }() + close(w.Wait) + return nil +}