diff --git a/.gitignore b/.gitignore index 58bb2f3..4f16c1f 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ # vendor/ bin/ *.snap +passwd.txt diff --git a/auth.go b/auth.go index aeacc61..8aa686c 100644 --- a/auth.go +++ b/auth.go @@ -37,12 +37,7 @@ func NewAuth(paramstr string) (Auth, error) { } } -type StaticAuth struct { - token string - hiddenDomain string -} - -func NewStaticAuth(param_url *url.URL) (*StaticAuth, error) { +func NewStaticAuth(param_url *url.URL) (*BasicAuth, error) { values, err := url.ParseQuery(param_url.RawQuery) if err != nil { return nil, err @@ -55,8 +50,14 @@ func NewStaticAuth(param_url *url.URL) (*StaticAuth, error) { if password == "" { return nil, errors.New("\"password\" parameter is missing from auth config URI") } - return &StaticAuth{ - token: base64.StdEncoding.EncodeToString([]byte(username + ":" + password)), + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost) + if err != nil { + return nil, err + } + return &BasicAuth{ + users: map[string][]byte{ + username: hashedPassword, + }, hiddenDomain: strings.ToLower(values.Get("hidden_domain")), }, nil } @@ -74,33 +75,6 @@ func requireBasicAuth(wr http.ResponseWriter, req *http.Request, hidden_domain s } } -func (auth *StaticAuth) Validate(wr http.ResponseWriter, req *http.Request) bool { - hdr := req.Header.Get("Proxy-Authorization") - if hdr == "" { - requireBasicAuth(wr, req, auth.hiddenDomain) - return false - } - hdr_parts := strings.SplitN(hdr, " ", 2) - if len(hdr_parts) != 2 || strings.ToLower(hdr_parts[0]) != "basic" { - requireBasicAuth(wr, req, auth.hiddenDomain) - return false - } - token := hdr_parts[1] - ok := (subtle.ConstantTimeCompare([]byte(token), []byte(auth.token)) == 1) - if ok { - if auth.hiddenDomain != "" && - (req.Host == auth.hiddenDomain || req.URL.Host == auth.hiddenDomain) { - http.Error(wr, "Browser auth triggered!", http.StatusGone) - return false - } else { - return true - } - } else { - requireBasicAuth(wr, req, auth.hiddenDomain) - return false - } -} - type BasicAuth struct { users map[string][]byte hiddenDomain string @@ -190,6 +164,7 @@ func (auth *BasicAuth) Validate(wr http.ResponseWriter, req *http.Request) bool return true } } + requireBasicAuth(wr, req, auth.hiddenDomain) return false } diff --git a/handler.go b/handler.go index 04ec71f..8391dd9 100644 --- a/handler.go +++ b/handler.go @@ -37,7 +37,6 @@ func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) { } defer conn.Close() - if req.ProtoMajor == 0 || req.ProtoMajor == 1 { // Upgrade client connection localconn, _, err := hijack(wr) @@ -87,8 +86,9 @@ func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request) func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { s.logger.Info("Request: %v %v %v %v", req.RemoteAddr, req.Proto, req.Method, req.URL) - if ((req.URL.Host == "" || req.URL.Scheme == "") && req.ProtoMajor < 2) || - (req.Host == "" && req.ProtoMajor == 2) { + isConnect := strings.ToUpper(req.Method) == "CONNECT" + if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 || + req.Host == "" && req.ProtoMajor == 2 { http.Error(wr, "Bad Request", http.StatusBadRequest) return } @@ -96,7 +96,7 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { return } delHopHeaders(req.Header) - if strings.ToUpper(req.Method) == "CONNECT" { + if isConnect { s.HandleTunnel(wr, req) } else { s.HandleRequest(wr, req)