diff --git a/auth.go b/auth.go index 52166fc..3c8eb15 100644 --- a/auth.go +++ b/auth.go @@ -1,16 +1,16 @@ package main import ( - "os" - "net/http" - "net/url" - "errors" - "strings" - "strconv" - "encoding/base64" - "crypto/subtle" - "golang.org/x/crypto/bcrypt" - "bufio" + "bufio" + "crypto/subtle" + "encoding/base64" + "errors" + "golang.org/x/crypto/bcrypt" + "net/http" + "net/url" + "os" + "strconv" + "strings" ) const AUTH_REQUIRED_MSG = "Proxy authentication required.\n" @@ -19,183 +19,182 @@ const AUTH_TRIGGERED_MSG = "Browser auth triggered!\n" const EPOCH_EXPIRE = "Thu, 01 Jan 1970 00:00:01 GMT" type Auth interface { - Validate(wr http.ResponseWriter, req *http.Request) bool + Validate(wr http.ResponseWriter, req *http.Request) bool } func NewAuth(paramstr string) (Auth, error) { - url, err := url.Parse(paramstr) - if err != nil { - return nil, err - } + url, err := url.Parse(paramstr) + if err != nil { + return nil, err + } - switch strings.ToLower(url.Scheme) { - case "static": - return NewStaticAuth(url) - case "basicfile": - return NewBasicFileAuth(url) - case "cert": - return CertAuth{}, nil - case "none": - return NoAuth{}, nil - default: - return nil, errors.New("Unknown auth scheme") - } + switch strings.ToLower(url.Scheme) { + case "static": + return NewStaticAuth(url) + case "basicfile": + return NewBasicFileAuth(url) + case "cert": + return CertAuth{}, nil + case "none": + return NoAuth{}, nil + default: + return nil, errors.New("Unknown auth scheme") + } } func NewStaticAuth(param_url *url.URL) (*BasicAuth, error) { - values, err := url.ParseQuery(param_url.RawQuery) - if err != nil { - return nil, err - } - username := values.Get("username") - if username == "" { - return nil, errors.New("\"username\" parameter is missing from auth config URI") - } - password := values.Get("password") - if password == "" { - return nil, errors.New("\"password\" parameter is missing from auth config URI") - } - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost) - if err != nil { - return nil, err - } - return &BasicAuth{ - users: map[string][]byte{ - username: hashedPassword, - }, - hiddenDomain: strings.ToLower(values.Get("hidden_domain")), - }, nil + values, err := url.ParseQuery(param_url.RawQuery) + if err != nil { + return nil, err + } + username := values.Get("username") + if username == "" { + return nil, errors.New("\"username\" parameter is missing from auth config URI") + } + password := values.Get("password") + if password == "" { + return nil, errors.New("\"password\" parameter is missing from auth config URI") + } + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost) + if err != nil { + return nil, err + } + return &BasicAuth{ + users: map[string][]byte{ + username: hashedPassword, + }, + hiddenDomain: strings.ToLower(values.Get("hidden_domain")), + }, nil } func requireBasicAuth(wr http.ResponseWriter, req *http.Request, hidden_domain string) { - if hidden_domain != "" && - (subtle.ConstantTimeCompare([]byte(req.URL.Host), []byte(hidden_domain)) != 1 && - subtle.ConstantTimeCompare([]byte(req.Host), []byte(hidden_domain)) != 1) { - http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) - } else { - wr.Header().Set("Proxy-Authenticate", `Basic realm="dumbproxy"`) - wr.Header().Set("Content-Length", strconv.Itoa(len([]byte(AUTH_REQUIRED_MSG)))) - wr.WriteHeader(407) - wr.Write([]byte(AUTH_REQUIRED_MSG)) - } + if hidden_domain != "" && + (subtle.ConstantTimeCompare([]byte(req.URL.Host), []byte(hidden_domain)) != 1 && + subtle.ConstantTimeCompare([]byte(req.Host), []byte(hidden_domain)) != 1) { + http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) + } else { + wr.Header().Set("Proxy-Authenticate", `Basic realm="dumbproxy"`) + wr.Header().Set("Content-Length", strconv.Itoa(len([]byte(AUTH_REQUIRED_MSG)))) + wr.WriteHeader(407) + wr.Write([]byte(AUTH_REQUIRED_MSG)) + } } type BasicAuth struct { - users map[string][]byte - hiddenDomain string + users map[string][]byte + hiddenDomain string } func NewBasicFileAuth(param_url *url.URL) (*BasicAuth, error) { - values, err := url.ParseQuery(param_url.RawQuery) - if err != nil { - return nil, err - } - filename := values.Get("path") - if filename == "" { - return nil, errors.New("\"path\" parameter is missing from auth config URI") - } + values, err := url.ParseQuery(param_url.RawQuery) + if err != nil { + return nil, err + } + filename := values.Get("path") + if filename == "" { + return nil, errors.New("\"path\" parameter is missing from auth config URI") + } - f, err := os.Open(filename) - if err != nil { - return nil, err - } - defer f.Close() + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() - scanner := bufio.NewScanner(f) - users := make(map[string][]byte) - for scanner.Scan() { - line := scanner.Text() - trimmed := strings.TrimSpace(line) - if trimmed == "" || strings.HasPrefix(trimmed, "#") { - continue - } - pair := strings.SplitN(line, ":", 2) - if len(pair) != 2 { - return nil, errors.New("Malformed login and password line") - } - login := pair[0] - password := pair[1] - users[login] = []byte(password) - } - if err := scanner.Err(); err != nil { - return nil, err - } - if len(users) == 0 { - return nil, errors.New("No password lines were read from file") - } - return &BasicAuth{ - users: users, - hiddenDomain: strings.ToLower(values.Get("hidden_domain")), - }, nil + scanner := bufio.NewScanner(f) + users := make(map[string][]byte) + for scanner.Scan() { + line := scanner.Text() + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + pair := strings.SplitN(line, ":", 2) + if len(pair) != 2 { + return nil, errors.New("Malformed login and password line") + } + login := pair[0] + password := pair[1] + users[login] = []byte(password) + } + if err := scanner.Err(); err != nil { + return nil, err + } + if len(users) == 0 { + return nil, errors.New("No password lines were read from file") + } + return &BasicAuth{ + users: users, + hiddenDomain: strings.ToLower(values.Get("hidden_domain")), + }, nil } func (auth *BasicAuth) Validate(wr http.ResponseWriter, req *http.Request) bool { - hdr := req.Header.Get("Proxy-Authorization") - if hdr == "" { - requireBasicAuth(wr, req, auth.hiddenDomain) - return false - } - hdr_parts := strings.SplitN(hdr, " ", 2) - if len(hdr_parts) != 2 || strings.ToLower(hdr_parts[0]) != "basic" { - requireBasicAuth(wr, req, auth.hiddenDomain) - return false - } + hdr := req.Header.Get("Proxy-Authorization") + if hdr == "" { + requireBasicAuth(wr, req, auth.hiddenDomain) + return false + } + hdr_parts := strings.SplitN(hdr, " ", 2) + if len(hdr_parts) != 2 || strings.ToLower(hdr_parts[0]) != "basic" { + requireBasicAuth(wr, req, auth.hiddenDomain) + return false + } - token := hdr_parts[1] - data, err := base64.StdEncoding.DecodeString(token) - if err != nil { - requireBasicAuth(wr, req, auth.hiddenDomain) - return false - } + token := hdr_parts[1] + data, err := base64.StdEncoding.DecodeString(token) + if err != nil { + requireBasicAuth(wr, req, auth.hiddenDomain) + return false + } - pair := strings.SplitN(string(data), ":", 2) - if len(pair) != 2 { - requireBasicAuth(wr, req, auth.hiddenDomain) - return false - } + pair := strings.SplitN(string(data), ":", 2) + if len(pair) != 2 { + requireBasicAuth(wr, req, auth.hiddenDomain) + return false + } - login := pair[0] - password := pair[1] + login := pair[0] + password := pair[1] - hashedPassword, ok := auth.users[login] - if !ok { - requireBasicAuth(wr, req, auth.hiddenDomain) - return false - } + hashedPassword, ok := auth.users[login] + if !ok { + requireBasicAuth(wr, req, auth.hiddenDomain) + return false + } - if bcrypt.CompareHashAndPassword(hashedPassword, []byte(password)) == nil { - if auth.hiddenDomain != "" && - (req.Host == auth.hiddenDomain || req.URL.Host == auth.hiddenDomain) { - wr.Header().Set("Content-Length", strconv.Itoa(len([]byte(AUTH_TRIGGERED_MSG)))) - wr.Header().Set("Pragma", "no-cache") - wr.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") - wr.Header().Set("Expires", EPOCH_EXPIRE) - wr.Header()["Date"] = nil - wr.WriteHeader(http.StatusOK) - wr.Write([]byte(AUTH_TRIGGERED_MSG)) - return false - } else { - return true - } - } - requireBasicAuth(wr, req, auth.hiddenDomain) - return false + if bcrypt.CompareHashAndPassword(hashedPassword, []byte(password)) == nil { + if auth.hiddenDomain != "" && + (req.Host == auth.hiddenDomain || req.URL.Host == auth.hiddenDomain) { + wr.Header().Set("Content-Length", strconv.Itoa(len([]byte(AUTH_TRIGGERED_MSG)))) + wr.Header().Set("Pragma", "no-cache") + wr.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") + wr.Header().Set("Expires", EPOCH_EXPIRE) + wr.Header()["Date"] = nil + wr.WriteHeader(http.StatusOK) + wr.Write([]byte(AUTH_TRIGGERED_MSG)) + return false + } else { + return true + } + } + requireBasicAuth(wr, req, auth.hiddenDomain) + return false } -type NoAuth struct {} +type NoAuth struct{} func (_ NoAuth) Validate(wr http.ResponseWriter, req *http.Request) bool { - return true + return true } - -type CertAuth struct {} +type CertAuth struct{} func (_ CertAuth) Validate(wr http.ResponseWriter, req *http.Request) bool { - if req.TLS == nil || len(req.TLS.VerifiedChains) < 1 { - http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) - return false - } else { - return true - } + if req.TLS == nil || len(req.TLS.VerifiedChains) < 1 { + http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) + return false + } else { + return true + } } diff --git a/condlog.go b/condlog.go index cb76d2b..96a18f3 100644 --- a/condlog.go +++ b/condlog.go @@ -1,58 +1,58 @@ package main import ( - "log" - "fmt" + "fmt" + "log" ) const ( - CRITICAL = 50 - ERROR = 40 - WARNING = 30 - INFO = 20 - DEBUG = 10 - NOTSET = 0 + CRITICAL = 50 + ERROR = 40 + WARNING = 30 + INFO = 20 + DEBUG = 10 + NOTSET = 0 ) type CondLogger struct { - logger *log.Logger - verbosity int + logger *log.Logger + verbosity int } func (cl *CondLogger) Log(verb int, format string, v ...interface{}) error { - if verb >= cl.verbosity { - return cl.logger.Output(2, fmt.Sprintf(format, v...)) - } - return nil + if verb >= cl.verbosity { + return cl.logger.Output(2, fmt.Sprintf(format, v...)) + } + return nil } func (cl *CondLogger) log(verb int, format string, v ...interface{}) error { - if verb >= cl.verbosity { - return cl.logger.Output(3, fmt.Sprintf(format, v...)) - } - return nil + if verb >= cl.verbosity { + return cl.logger.Output(3, fmt.Sprintf(format, v...)) + } + return nil } func (cl *CondLogger) Critical(s string, v ...interface{}) error { - return cl.log(CRITICAL, "CRITICAL " + s, v...) + return cl.log(CRITICAL, "CRITICAL "+s, v...) } func (cl *CondLogger) Error(s string, v ...interface{}) error { - return cl.log(ERROR, "ERROR " + s, v...) + return cl.log(ERROR, "ERROR "+s, v...) } func (cl *CondLogger) Warning(s string, v ...interface{}) error { - return cl.log(WARNING, "WARNING " + s, v...) + return cl.log(WARNING, "WARNING "+s, v...) } func (cl *CondLogger) Info(s string, v ...interface{}) error { - return cl.log(INFO, "INFO " + s, v...) + return cl.log(INFO, "INFO "+s, v...) } func (cl *CondLogger) Debug(s string, v ...interface{}) error { - return cl.log(DEBUG, "DEBUG " + s, v...) + return cl.log(DEBUG, "DEBUG "+s, v...) } func NewCondLogger(logger *log.Logger, verbosity int) *CondLogger { - return &CondLogger{verbosity: verbosity, logger: logger} + return &CondLogger{verbosity: verbosity, logger: logger} } diff --git a/handler.go b/handler.go index 3c6260d..eb3fcaf 100644 --- a/handler.go +++ b/handler.go @@ -1,133 +1,133 @@ package main import ( - "net" - "fmt" - "time" - "net/http" - "strings" - "context" - "sync" + "context" + "fmt" + "net" + "net/http" + "strings" + "sync" + "time" ) type ProxyHandler struct { - timeout time.Duration - auth Auth - logger *CondLogger - httptransport http.RoundTripper - outbound map[string]string - outboundMux sync.RWMutex + timeout time.Duration + auth Auth + logger *CondLogger + httptransport http.RoundTripper + outbound map[string]string + outboundMux sync.RWMutex } func NewProxyHandler(timeout time.Duration, auth Auth, logger *CondLogger) *ProxyHandler { httptransport := &http.Transport{} - return &ProxyHandler{ - timeout: timeout, - auth: auth, - logger: logger, - httptransport: httptransport, - outbound: make(map[string]string), - } + return &ProxyHandler{ + timeout: timeout, + auth: auth, + logger: logger, + httptransport: httptransport, + outbound: make(map[string]string), + } } func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) { - ctx, _ := context.WithTimeout(req.Context(), s.timeout) - dialer := net.Dialer{} - conn, err := dialer.DialContext(ctx, "tcp", req.RequestURI) - if err != nil { - s.logger.Error("Can't satisfy CONNECT request: %v", err) - http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway) - return - } + ctx, _ := context.WithTimeout(req.Context(), s.timeout) + dialer := net.Dialer{} + conn, err := dialer.DialContext(ctx, "tcp", req.RequestURI) + if err != nil { + s.logger.Error("Can't satisfy CONNECT request: %v", err) + http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway) + return + } - localAddr := conn.LocalAddr().String() - s.outboundMux.Lock() - s.outbound[localAddr] = req.RemoteAddr - s.outboundMux.Unlock() - defer func() { - conn.Close() - s.outboundMux.Lock() - delete(s.outbound, localAddr) - s.outboundMux.Unlock() - }() + localAddr := conn.LocalAddr().String() + s.outboundMux.Lock() + s.outbound[localAddr] = req.RemoteAddr + s.outboundMux.Unlock() + defer func() { + conn.Close() + s.outboundMux.Lock() + delete(s.outbound, localAddr) + s.outboundMux.Unlock() + }() - if req.ProtoMajor == 0 || req.ProtoMajor == 1 { - // Upgrade client connection - localconn, _, err := hijack(wr) - if err != nil { - s.logger.Error("Can't hijack client connection: %v", err) - http.Error(wr, "Can't hijack client connection", http.StatusInternalServerError) - return - } - defer localconn.Close() + if req.ProtoMajor == 0 || req.ProtoMajor == 1 { + // Upgrade client connection + localconn, _, err := hijack(wr) + if err != nil { + s.logger.Error("Can't hijack client connection: %v", err) + http.Error(wr, "Can't hijack client connection", http.StatusInternalServerError) + return + } + defer localconn.Close() - // Inform client connection is built - fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor) + // Inform client connection is built + fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor) - proxy(req.Context(), localconn, conn) - } else if req.ProtoMajor == 2 { - wr.Header()["Date"] = nil - wr.WriteHeader(http.StatusOK) - flush(wr) - proxyh2(req.Context(), req.Body, wr, conn) - } else { - s.logger.Error("Unsupported protocol version: %s", req.Proto) - http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest) - return - } + proxy(req.Context(), localconn, conn) + } else if req.ProtoMajor == 2 { + wr.Header()["Date"] = nil + wr.WriteHeader(http.StatusOK) + flush(wr) + proxyh2(req.Context(), req.Body, wr, conn) + } else { + s.logger.Error("Unsupported protocol version: %s", req.Proto) + http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest) + return + } } func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request) { - req.RequestURI = "" - if req.ProtoMajor == 2 { - req.URL.Scheme = "http" // We can't access :scheme pseudo-header, so assume http - req.URL.Host = req.Host - } - resp, err := s.httptransport.RoundTrip(req) - if err != nil { - s.logger.Error("HTTP fetch error: %v", err) - http.Error(wr, "Server Error", http.StatusInternalServerError) - return - } - defer resp.Body.Close() - s.logger.Info("%v %v %v %v", req.RemoteAddr, req.Method, req.URL, resp.Status) - delHopHeaders(resp.Header) - copyHeader(wr.Header(), resp.Header) - wr.WriteHeader(resp.StatusCode) - flush(wr) - copyBody(wr, resp.Body) + req.RequestURI = "" + if req.ProtoMajor == 2 { + req.URL.Scheme = "http" // We can't access :scheme pseudo-header, so assume http + req.URL.Host = req.Host + } + resp, err := s.httptransport.RoundTrip(req) + if err != nil { + s.logger.Error("HTTP fetch error: %v", err) + http.Error(wr, "Server Error", http.StatusInternalServerError) + return + } + defer resp.Body.Close() + s.logger.Info("%v %v %v %v", req.RemoteAddr, req.Method, req.URL, resp.Status) + delHopHeaders(resp.Header) + copyHeader(wr.Header(), resp.Header) + wr.WriteHeader(resp.StatusCode) + flush(wr) + copyBody(wr, resp.Body) } func (s *ProxyHandler) isLoopback(req *http.Request) (string, bool) { - s.outboundMux.RLock() - originator, found := s.outbound[req.RemoteAddr] - s.outboundMux.RUnlock() - return originator, found + s.outboundMux.RLock() + originator, found := s.outbound[req.RemoteAddr] + s.outboundMux.RUnlock() + return originator, found } func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { s.logger.Info("Request: %v %v %v %v", req.RemoteAddr, req.Proto, req.Method, req.URL) - if originator, isLoopback := s.isLoopback(req) ; isLoopback { - s.logger.Critical("Loopback tunnel detected: %s is an outbound " + - "address for another request from %s", req.RemoteAddr, originator) - http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) - return - } + if originator, isLoopback := s.isLoopback(req); isLoopback { + s.logger.Critical("Loopback tunnel detected: %s is an outbound "+ + "address for another request from %s", req.RemoteAddr, originator) + http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) + return + } - isConnect := strings.ToUpper(req.Method) == "CONNECT" - if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 || - req.Host == "" && req.ProtoMajor == 2 { - http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) - return - } - if !s.auth.Validate(wr, req) { - return - } - delHopHeaders(req.Header) - if isConnect { - s.HandleTunnel(wr, req) - } else { - s.HandleRequest(wr, req) - } + isConnect := strings.ToUpper(req.Method) == "CONNECT" + if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 || + req.Host == "" && req.ProtoMajor == 2 { + http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) + return + } + if !s.auth.Validate(wr, req) { + return + } + delHopHeaders(req.Header) + if isConnect { + s.HandleTunnel(wr, req) + } else { + s.HandleRequest(wr, req) + } } diff --git a/logwriter.go b/logwriter.go index a880db0..657c2f3 100644 --- a/logwriter.go +++ b/logwriter.go @@ -1,57 +1,57 @@ package main import ( - "io" - "errors" - "time" + "errors" + "io" + "time" ) const MAX_LOG_QLEN = 128 const QUEUE_SHUTDOWN_TIMEOUT = 500 * time.Millisecond type LogWriter struct { - writer io.Writer - ch chan []byte - done chan struct{} + writer io.Writer + ch chan []byte + done chan struct{} } func (lw *LogWriter) Write(p []byte) (int, error) { - if p == nil { - return 0, errors.New("Can't write nil byte slice") - } - buf := make([]byte, len(p)) - copy(buf, p) - select { - case lw.ch <- buf: - return len(p), nil - default: - return 0, errors.New("Writer queue overflow") - } + if p == nil { + return 0, errors.New("Can't write nil byte slice") + } + buf := make([]byte, len(p)) + copy(buf, p) + select { + case lw.ch <- buf: + return len(p), nil + default: + return 0, errors.New("Writer queue overflow") + } } func NewLogWriter(writer io.Writer) *LogWriter { - lw := &LogWriter{writer, - make(chan []byte, MAX_LOG_QLEN), - make(chan struct{})} - go lw.loop() - return lw + lw := &LogWriter{writer, + make(chan []byte, MAX_LOG_QLEN), + make(chan struct{})} + go lw.loop() + return lw } func (lw *LogWriter) loop() { - for p := range lw.ch { - if p == nil { - break - } - lw.writer.Write(p) - } - lw.done <- struct{}{} + for p := range lw.ch { + if p == nil { + break + } + lw.writer.Write(p) + } + lw.done <- struct{}{} } func (lw *LogWriter) Close() { - lw.ch <- nil - timer := time.After(QUEUE_SHUTDOWN_TIMEOUT) - select { - case <-timer: - case <-lw.done: - } + lw.ch <- nil + timer := time.After(QUEUE_SHUTDOWN_TIMEOUT) + select { + case <-timer: + case <-lw.done: + } } diff --git a/main.go b/main.go index 17b4d03..30fb007 100644 --- a/main.go +++ b/main.go @@ -1,95 +1,108 @@ package main import ( - "log" - "os" - "fmt" - "flag" - "time" - "net/http" + "crypto/tls" + "flag" + "fmt" + "log" + "net/http" + "os" + "time" ) func perror(msg string) { - fmt.Fprintln(os.Stderr, "") - fmt.Fprintln(os.Stderr, msg) + fmt.Fprintln(os.Stderr, "") + fmt.Fprintln(os.Stderr, msg) } func arg_fail(msg string) { - perror(msg) - perror("Usage:") - flag.PrintDefaults() - os.Exit(2) + perror(msg) + perror("Usage:") + flag.PrintDefaults() + os.Exit(2) } type CLIArgs struct { - bind_address string - auth string - verbosity int - timeout time.Duration - cert, key, cafile string + bind_address string + auth string + verbosity int + timeout time.Duration + cert, key, cafile string + list_suites bool } +func list_suites() { + for _, cipher := range tls.CipherSuites() { + fmt.Println(cipher.Name) + } +} func parse_args() CLIArgs { - var args CLIArgs - flag.StringVar(&args.bind_address, "bind-address", ":8080", "HTTP proxy listen address") - flag.StringVar(&args.auth, "auth", "none://", "auth parameters") - flag.IntVar(&args.verbosity, "verbosity", 20, "logging verbosity " + - "(10 - debug, 20 - info, 30 - warning, 40 - error, 50 - critical)") - flag.DurationVar(&args.timeout, "timeout", 10 * time.Second, "timeout for network operations") - flag.StringVar(&args.cert, "cert", "", "enable TLS and use certificate") - flag.StringVar(&args.key, "key", "", "key for TLS certificate") - flag.StringVar(&args.cafile, "cafile", "", "CA file to authenticate clients with certificates") - flag.Parse() - return args + var args CLIArgs + flag.StringVar(&args.bind_address, "bind-address", ":8080", "HTTP proxy listen address") + flag.StringVar(&args.auth, "auth", "none://", "auth parameters") + flag.IntVar(&args.verbosity, "verbosity", 20, "logging verbosity "+ + "(10 - debug, 20 - info, 30 - warning, 40 - error, 50 - critical)") + flag.DurationVar(&args.timeout, "timeout", 10*time.Second, "timeout for network operations") + flag.StringVar(&args.cert, "cert", "", "enable TLS and use certificate") + flag.StringVar(&args.key, "key", "", "key for TLS certificate") + flag.StringVar(&args.cafile, "cafile", "", "CA file to authenticate clients with certificates") + flag.BoolVar(&args.list_suites, "list-ciphers", false, "list ciphersuites") + flag.Parse() + return args } func run() int { - args := parse_args() + args := parse_args() - logWriter := NewLogWriter(os.Stderr) - defer logWriter.Close() + if args.list_suites { + list_suites() + return 0 + } - mainLogger := NewCondLogger(log.New(logWriter, "MAIN : ", - log.LstdFlags | log.Lshortfile), - args.verbosity) - proxyLogger := NewCondLogger(log.New(logWriter, "PROXY : ", - log.LstdFlags | log.Lshortfile), - args.verbosity) + logWriter := NewLogWriter(os.Stderr) + defer logWriter.Close() - auth, err := NewAuth(args.auth) - if err != nil { - mainLogger.Critical("Failed to instantiate auth provider: %v", err) - return 3 - } + mainLogger := NewCondLogger(log.New(logWriter, "MAIN : ", + log.LstdFlags|log.Lshortfile), + args.verbosity) + proxyLogger := NewCondLogger(log.New(logWriter, "PROXY : ", + log.LstdFlags|log.Lshortfile), + args.verbosity) - server := http.Server{ - Addr: args.bind_address, - Handler: NewProxyHandler(args.timeout, auth, proxyLogger), - ErrorLog: log.New(logWriter, "HTTPSRV : ", log.LstdFlags | log.Lshortfile), - ReadTimeout: 0, - ReadHeaderTimeout: 0, - WriteTimeout: 0, - IdleTimeout: 0, - } + auth, err := NewAuth(args.auth) + if err != nil { + mainLogger.Critical("Failed to instantiate auth provider: %v", err) + return 3 + } - mainLogger.Info("Starting proxy server...") - if args.cert != "" { - cfg, err1 := makeServerTLSConfig(args.cert, args.key, args.cafile) - if err1 != nil { - mainLogger.Critical("TLS config construction failed: %v", err) - return 3 - } - server.TLSConfig = cfg - err = server.ListenAndServeTLS("", "") - } else { - err = server.ListenAndServe() - } - mainLogger.Critical("Server terminated with a reason: %v", err) - mainLogger.Info("Shutting down...") - return 0 + server := http.Server{ + Addr: args.bind_address, + Handler: NewProxyHandler(args.timeout, auth, proxyLogger), + ErrorLog: log.New(logWriter, "HTTPSRV : ", log.LstdFlags|log.Lshortfile), + ReadTimeout: 0, + ReadHeaderTimeout: 0, + WriteTimeout: 0, + IdleTimeout: 0, + } + + mainLogger.Info("Starting proxy server...") + if args.cert != "" { + cfg, err1 := makeServerTLSConfig(args.cert, args.key, args.cafile) + if err1 != nil { + mainLogger.Critical("TLS config construction failed: %v", err) + return 3 + } + server.TLSConfig = cfg + err = server.ListenAndServeTLS("", "") + } else { + err = server.ListenAndServe() + } + mainLogger.Critical("Server terminated with a reason: %v", err) + mainLogger.Info("Shutting down...") + return 0 } func main() { - os.Exit(run()) + os.Exit(run()) } diff --git a/utils.go b/utils.go index 0a92dfb..9541b8d 100644 --- a/utils.go +++ b/utils.go @@ -1,75 +1,75 @@ package main import ( - "context" - "net" - "sync" - "io" - "time" - "errors" - "net/http" - "bufio" - "crypto/tls" - "crypto/x509" - "io/ioutil" + "bufio" + "context" + "crypto/tls" + "crypto/x509" + "errors" + "io" + "io/ioutil" + "net" + "net/http" + "sync" + "time" ) const COPY_BUF = 128 * 1024 func proxy(ctx context.Context, left, right net.Conn) { - wg := sync.WaitGroup{} - cpy := func (dst, src net.Conn) { - defer wg.Done() - io.Copy(dst, src) - dst.Close() - } - wg.Add(2) - go cpy(left, right) - go cpy(right, left) - groupdone := make(chan struct{}, 1) - go func() { - wg.Wait() - groupdone <-struct{}{} - }() - select { - case <-ctx.Done(): - left.Close() - right.Close() - case <-groupdone: - return - } - <-groupdone - return + wg := sync.WaitGroup{} + cpy := func(dst, src net.Conn) { + defer wg.Done() + io.Copy(dst, src) + dst.Close() + } + wg.Add(2) + go cpy(left, right) + go cpy(right, left) + groupdone := make(chan struct{}, 1) + go func() { + wg.Wait() + groupdone <- struct{}{} + }() + select { + case <-ctx.Done(): + left.Close() + right.Close() + case <-groupdone: + return + } + <-groupdone + return } func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) { - wg := sync.WaitGroup{} - ltr := func (dst net.Conn, src io.Reader) { - defer wg.Done() - io.Copy(dst, src) - dst.Close() - } - rtl := func (dst io.Writer, src io.Reader) { - defer wg.Done() - copyBody(dst, src) - } - wg.Add(2) - go ltr(right, leftreader) - go rtl(leftwriter, right) - groupdone := make(chan struct{}, 1) - go func() { - wg.Wait() - groupdone <-struct{}{} - }() - select { - case <-ctx.Done(): - leftreader.Close() - right.Close() - case <-groupdone: - return - } - <-groupdone - return + wg := sync.WaitGroup{} + ltr := func(dst net.Conn, src io.Reader) { + defer wg.Done() + io.Copy(dst, src) + dst.Close() + } + rtl := func(dst io.Writer, src io.Reader) { + defer wg.Done() + copyBody(dst, src) + } + wg.Add(2) + go ltr(right, leftreader) + go rtl(leftwriter, right) + groupdone := make(chan struct{}, 1) + go func() { + wg.Wait() + groupdone <- struct{}{} + }() + select { + case <-ctx.Done(): + leftreader.Close() + right.Close() + case <-groupdone: + return + } + <-groupdone + return } // Hop-by-hop headers. These are removed when sent to the backend. @@ -79,7 +79,7 @@ var hopHeaders = []string{ "Keep-Alive", "Proxy-Authenticate", "Proxy-Connection", - "Proxy-Authorization", + "Proxy-Authorization", "Te", // canonicalized version of "TE" "Trailers", "Transfer-Encoding", @@ -101,65 +101,65 @@ func delHopHeaders(header http.Header) { } func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) { - hj, ok := hijackable.(http.Hijacker) - if !ok { - return nil, nil, errors.New("Connection doesn't support hijacking") - } - conn, rw, err := hj.Hijack() - if err != nil { - return nil, nil, err - } - var emptytime time.Time - err = conn.SetDeadline(emptytime) - if err != nil { - conn.Close() - return nil, nil, err - } - return conn, rw, nil + hj, ok := hijackable.(http.Hijacker) + if !ok { + return nil, nil, errors.New("Connection doesn't support hijacking") + } + conn, rw, err := hj.Hijack() + if err != nil { + return nil, nil, err + } + var emptytime time.Time + err = conn.SetDeadline(emptytime) + if err != nil { + conn.Close() + return nil, nil, err + } + return conn, rw, nil } func flush(flusher interface{}) bool { - f, ok := flusher.(http.Flusher) - if !ok { - return false - } - f.Flush() - return true + f, ok := flusher.(http.Flusher) + if !ok { + return false + } + f.Flush() + return true } func copyBody(wr io.Writer, body io.Reader) { - buf := make([]byte, COPY_BUF) - for { - bread, read_err := body.Read(buf) - var write_err error - if bread > 0 { - _, write_err = wr.Write(buf[:bread]) - flush(wr) - } - if read_err != nil || write_err != nil { - break - } - } + buf := make([]byte, COPY_BUF) + for { + bread, read_err := body.Read(buf) + var write_err error + if bread > 0 { + _, write_err = wr.Write(buf[:bread]) + flush(wr) + } + if read_err != nil || write_err != nil { + break + } + } } func makeServerTLSConfig(certfile, keyfile, cafile string) (*tls.Config, error) { - var cfg tls.Config - cert, err := tls.LoadX509KeyPair(certfile, keyfile) - if err != nil { - return nil, err - } - cfg.Certificates = []tls.Certificate{cert} - if cafile != "" { - roots := x509.NewCertPool() - certs, err := ioutil.ReadFile(cafile) - if err != nil { - return nil, err - } - if ok := roots.AppendCertsFromPEM(certs); !ok { - return nil, errors.New("Failed to load CA certificates") - } - cfg.ClientCAs = roots - cfg.ClientAuth = tls.VerifyClientCertIfGiven - } - return &cfg, nil + var cfg tls.Config + cert, err := tls.LoadX509KeyPair(certfile, keyfile) + if err != nil { + return nil, err + } + cfg.Certificates = []tls.Certificate{cert} + if cafile != "" { + roots := x509.NewCertPool() + certs, err := ioutil.ReadFile(cafile) + if err != nil { + return nil, err + } + if ok := roots.AppendCertsFromPEM(certs); !ok { + return nil, errors.New("Failed to load CA certificates") + } + cfg.ClientCAs = roots + cfg.ClientAuth = tls.VerifyClientCertIfGiven + } + return &cfg, nil }