diff --git a/transport/internet/splithttp/client.go b/transport/internet/splithttp/client.go index 8330d5d7..f9b7cba1 100644 --- a/transport/internet/splithttp/client.go +++ b/transport/internet/splithttp/client.go @@ -117,10 +117,10 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string) func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string, payload io.ReadWriteCloser, contentLength int64) error { req, err := http.NewRequest("POST", url, payload) - req.ContentLength = contentLength if err != nil { return err } + req.ContentLength = contentLength req.Header = c.transportConfig.GetRequestHeader() if c.isH2 || c.isH3 { diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index d4579bc7..373e1613 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -314,14 +314,6 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet return nil, errors.New("failed to listen TCP(for SH) on ", address, ":", port).Base(err) } errors.LogInfo(ctx, "listening TCP(for SH) on ", address, ":", port) - - // h2cHandler can handle both plaintext HTTP/1.1 and h2c - h2cHandler := h2c.NewHandler(handler, &http2.Server{}) - l.server = http.Server{ - Handler: h2cHandler, - ReadHeaderTimeout: time.Second * 4, - MaxHeaderBytes: 8192, - } } // tcp/unix (h1/h2) @@ -332,7 +324,14 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet } } + // h2cHandler can handle both plaintext HTTP/1.1 and h2c + h2cHandler := h2c.NewHandler(handler, &http2.Server{}) l.listener = listener + l.server = http.Server{ + Handler: h2cHandler, + ReadHeaderTimeout: time.Second * 4, + MaxHeaderBytes: 8192, + } go func() { if err := l.server.Serve(l.listener); err != nil { diff --git a/transport/internet/splithttp/splithttp_test.go b/transport/internet/splithttp/splithttp_test.go index a3b609ab..5002e1a5 100644 --- a/transport/internet/splithttp/splithttp_test.go +++ b/transport/internet/splithttp/splithttp_test.go @@ -298,3 +298,65 @@ func Test_listenSHAndDial_QUIC(t *testing.T) { t.Error("end: ", end, " start: ", start) } } + +func Test_listenSHAndDial_Unix(t *testing.T) { + tempDir := t.TempDir() + tempSocket := tempDir + "/server.sock" + + listen, err := ListenSH(context.Background(), net.DomainAddress(tempSocket), 0, &internet.MemoryStreamConfig{ + ProtocolName: "splithttp", + ProtocolSettings: &Config{ + Path: "/sh", + }, + }, func(conn stat.Connection) { + go func(c stat.Connection) { + defer c.Close() + + var b [1024]byte + c.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err := c.Read(b[:]) + if err != nil { + return + } + + common.Must2(c.Write([]byte("Response"))) + }(conn) + }) + common.Must(err) + ctx := context.Background() + streamSettings := &internet.MemoryStreamConfig{ + ProtocolName: "splithttp", + ProtocolSettings: &Config{ + Host: "example.com", + Path: "sh", + }, + } + conn, err := Dial(ctx, net.UnixDestination(net.DomainAddress(tempSocket)), streamSettings) + + common.Must(err) + _, err = conn.Write([]byte("Test connection 1")) + common.Must(err) + + var b [1024]byte + fmt.Println("test2") + n, _ := conn.Read(b[:]) + fmt.Println("string is", n) + if string(b[:n]) != "Response" { + t.Error("response: ", string(b[:n])) + } + + common.Must(conn.Close()) + conn, err = Dial(ctx, net.UnixDestination(net.DomainAddress(tempSocket)), streamSettings) + + common.Must(err) + _, err = conn.Write([]byte("Test connection 2")) + common.Must(err) + n, _ = conn.Read(b[:]) + common.Must(err) + if string(b[:n]) != "Response" { + t.Error("response: ", string(b[:n])) + } + common.Must(conn.Close()) + + common.Must(listen.Close()) +}