diff --git a/handler.go b/handler.go index ff0a446..5a5fe8d 100644 --- a/handler.go +++ b/handler.go @@ -38,19 +38,30 @@ func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) { defer conn.Close() - // Upgrade client connection - localconn, _, err := hijack(wr) - if err != nil { - s.logger.Error("Can't hijack client connection: %v", err) - http.Error(wr, "Can't hijack client connection", http.StatusInternalServerError) + if req.ProtoMajor == 0 || req.ProtoMajor == 1 { + // Upgrade client connection + localconn, _, err := hijack(wr) + if err != nil { + s.logger.Error("Can't hijack client connection: %v", err) + http.Error(wr, "Can't hijack client connection", http.StatusInternalServerError) + return + } + defer localconn.Close() + + // Inform client connection is built + fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor) + + proxy(req.Context(), localconn, conn) + } else if req.ProtoMajor == 2 { + wr.Header()["Date"] = nil + wr.WriteHeader(http.StatusOK) + flush(wr) + proxyh2(req.Context(), req.Body, wr, conn) + } else { + s.logger.Error("Unsupported protocol version: %s", req.Proto) + http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest) return } - defer localconn.Close() - - // Inform client connection is built - fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor) - - proxy(req.Context(), localconn, conn) } func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request) { @@ -71,7 +82,7 @@ func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request) } func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { - s.logger.Info("Request: %v %v %v", req.RemoteAddr, req.Method, req.URL) + s.logger.Info("Request: %v %v %v %v", req.RemoteAddr, req.Proto, req.Method, req.URL) if !s.auth.Validate(wr, req) { return } diff --git a/main.go b/main.go index ed9885a..8faaa0b 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,6 @@ import ( "flag" "time" "net/http" - "crypto/tls" ) func perror(msg string) { @@ -67,7 +66,6 @@ func run() int { Addr: args.bind_address, Handler: NewProxyHandler(args.timeout, auth, proxyLogger), ErrorLog: log.New(logWriter, "HTTPSRV : ", log.LstdFlags | log.Lshortfile), - TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // No HTTP/2 ReadTimeout: 0, ReadHeaderTimeout: 0, WriteTimeout: 0, diff --git a/utils.go b/utils.go index 919663c..d2b0e02 100644 --- a/utils.go +++ b/utils.go @@ -42,6 +42,36 @@ func proxy(ctx context.Context, left, right net.Conn) { return } +func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) { + wg := sync.WaitGroup{} + ltr := func (dst net.Conn, src io.Reader) { + defer wg.Done() + io.Copy(dst, src) + dst.Close() + } + rtl := func (dst io.Writer, src io.Reader) { + defer wg.Done() + copyBody(dst, src) + } + wg.Add(2) + go ltr(right, leftreader) + go rtl(leftwriter, right) + groupdone := make(chan struct{}, 1) + go func() { + wg.Wait() + groupdone <-struct{}{} + }() + select { + case <-ctx.Done(): + leftreader.Close() + right.Close() + case <-groupdone: + return + } + <-groupdone + return +} + // Hop-by-hop headers. These are removed when sent to the backend. // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html var hopHeaders = []string{