diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..3ae0bf7 --- /dev/null +++ b/auth.go @@ -0,0 +1,88 @@ +package main + +import ( + "net/http" + "net/url" + "errors" + "strings" + "strconv" + "encoding/base64" + "crypto/subtle" +) + +const AUTH_REQUIRED_MSG = "Proxy authentication required." + +type Auth interface { + Validate(wr http.ResponseWriter, req *http.Request) bool +} + +func NewAuth(paramstr string) (Auth, error) { + url, err := url.Parse(paramstr) + if err != nil { + return nil, err + } + + switch strings.ToLower(url.Scheme) { + case "static": + auth, err := NewStaticAuth(url) + return auth, err + case "none": + return NoAuth{}, nil + default: + return nil, errors.New("Unknown auth scheme") + } +} + +type StaticAuth string + +func NewStaticAuth(param_url *url.URL) (StaticAuth, error) { + values, err := url.ParseQuery(param_url.RawQuery) + if err != nil { + return StaticAuth(""), err + } + username := values.Get("username") + if username == "" { + return StaticAuth(""), errors.New("\"username\" parameter is missing from auth config URI") + } + password := values.Get("password") + if password == "" { + return StaticAuth(""), errors.New("\"password\" parameter is missing from auth config URI") + } + return StaticAuth(base64.StdEncoding.EncodeToString( + []byte(username + ":" + password))), nil +} + +func requireBasicAuth(wr http.ResponseWriter) { + wr.Header().Set("Proxy-Authenticate", `Basic realm="dumbproxy"`) + wr.Header().Set("Content-Length", strconv.Itoa(len([]byte(AUTH_REQUIRED_MSG)))) + wr.WriteHeader(407) + wr.Write([]byte(AUTH_REQUIRED_MSG)) +} + +func (auth StaticAuth) Validate(wr http.ResponseWriter, req *http.Request) bool { + hdr := req.Header.Get("Proxy-Authorization") + if hdr == "" { + requireBasicAuth(wr) + return false + } + hdr_parts := strings.SplitN(hdr, " ", 2) + if len(hdr_parts) != 2 || strings.ToLower(hdr_parts[0]) != "basic" { + requireBasicAuth(wr) + return false + } + token := hdr_parts[1] + ok := (subtle.ConstantTimeCompare([]byte(token), []byte(auth)) == 1) + if ok { + return true + } else { + requireBasicAuth(wr) + return false + } +} + +type NoAuth struct {} + +func (_ NoAuth) Validate(wr http.ResponseWriter, req *http.Request) bool { + return true +} + diff --git a/handler.go b/handler.go index 8ff18e2..0889ebf 100644 --- a/handler.go +++ b/handler.go @@ -12,14 +12,16 @@ import ( type ProxyHandler struct { timeout time.Duration + auth Auth logger *CondLogger httptransport http.RoundTripper } -func NewProxyHandler(timeout time.Duration, logger *CondLogger) *ProxyHandler { +func NewProxyHandler(timeout time.Duration, auth Auth, logger *CondLogger) *ProxyHandler { httptransport := &http.Transport{} return &ProxyHandler{ timeout: timeout, + auth: auth, logger: logger, httptransport: httptransport, } @@ -65,11 +67,15 @@ func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request) delHopHeaders(resp.Header) copyHeader(wr.Header(), resp.Header) wr.WriteHeader(resp.StatusCode) + flush(wr) io.Copy(wr, resp.Body) } func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { s.logger.Info("Request: %v %v %v", req.RemoteAddr, req.Method, req.URL) + if !s.auth.Validate(wr, req) { + return + } delHopHeaders(req.Header) if strings.ToUpper(req.Method) == "CONNECT" { s.HandleTunnel(wr, req) diff --git a/main.go b/main.go index 7c289a4..93f1758 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,7 @@ func arg_fail(msg string) { type CLIArgs struct { bind_address string + auth string verbosity int timeout time.Duration } @@ -31,6 +32,7 @@ type CLIArgs struct { func parse_args() CLIArgs { var args CLIArgs flag.StringVar(&args.bind_address, "bind-address", ":8080", "HTTP proxy listen address") + flag.StringVar(&args.auth, "auth", "none://", "auth parameters") flag.IntVar(&args.verbosity, "verbosity", 20, "logging verbosity " + "(10 - debug, 20 - info, 30 - warning, 40 - error, 50 - critical)") flag.DurationVar(&args.timeout, "timeout", 10 * time.Second, "timeout for network operations") @@ -51,8 +53,13 @@ func run() int { log.LstdFlags | log.Lshortfile), args.verbosity) mainLogger.Info("Starting proxy server...") - handler := NewProxyHandler(args.timeout, proxyLogger) - err := http.ListenAndServe(args.bind_address, handler) + auth, err := NewAuth(args.auth) + if err != nil { + mainLogger.Critical("Failed to instantiate auth provider: %v", err) + return 3 + } + handler := NewProxyHandler(args.timeout, auth, proxyLogger) + err = http.ListenAndServe(args.bind_address, handler) mainLogger.Critical("Server terminated with a reason: %v", err) mainLogger.Info("Shutting down...") return 0 diff --git a/utils.go b/utils.go index 072f80c..b1501fa 100644 --- a/utils.go +++ b/utils.go @@ -82,3 +82,12 @@ func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) { } return conn, rw, nil } + +func flush(flusher interface{}) bool { + f, ok := flusher.(http.Flusher) + if !ok { + return false + } + f.Flush() + return true +}