detect hostname the smart way

This commit is contained in:
b1ek 2024-07-26 17:17:02 +10:00
parent 506272e290
commit 3d5c53b566
2 changed files with 21 additions and 5 deletions

14
auth.go
View File

@ -84,15 +84,16 @@ func NewStaticAuth(param_url *url.URL, logger *CondLogger) (*BasicAuth, error) {
} }
func requireBasicAuth(wr http.ResponseWriter, req *http.Request, hidden_domain string) { func requireBasicAuth(wr http.ResponseWriter, req *http.Request, hidden_domain string) {
if IsAstraHost(req) && req.URL.Path == "/" {
SendIndex(wr, req)
return
}
if hidden_domain != "" && if hidden_domain != "" &&
(subtle.ConstantTimeCompare([]byte(req.URL.Host), []byte(hidden_domain)) != 1 && (subtle.ConstantTimeCompare([]byte(req.URL.Host), []byte(hidden_domain)) != 1 &&
subtle.ConstantTimeCompare([]byte(req.Host), []byte(hidden_domain)) != 1) { subtle.ConstantTimeCompare([]byte(req.Host), []byte(hidden_domain)) != 1) {
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
} else { } else {
if req.Host == "astra.blek.codes" && req.URL.Host == "astra.blek.codes" && req.URL.Path == "/" {
SendIndex(wr, req)
return
}
wr.Header().Set("Proxy-Authenticate", `Basic realm="dumbproxy"`) wr.Header().Set("Proxy-Authenticate", `Basic realm="dumbproxy"`)
wr.Header().Set("Content-Length", strconv.Itoa(len([]byte(AUTH_REQUIRED_MSG)))) wr.Header().Set("Content-Length", strconv.Itoa(len([]byte(AUTH_REQUIRED_MSG))))
wr.WriteHeader(407) wr.WriteHeader(407)
@ -263,6 +264,11 @@ func (_ NoAuth) Stop() {}
type CertAuth struct{} type CertAuth struct{}
func (_ CertAuth) Validate(wr http.ResponseWriter, req *http.Request) (string, bool) { func (_ CertAuth) Validate(wr http.ResponseWriter, req *http.Request) (string, bool) {
if req.Host == "astra.blek.codes" && req.URL.Host == "astra.blek.codes" && req.URL.Path == "/" {
SendIndex(wr, req)
return "", false
}
if req.TLS == nil || len(req.TLS.VerifiedChains) < 1 || len(req.TLS.VerifiedChains[0]) < 1 { if req.TLS == nil || len(req.TLS.VerifiedChains) < 1 || len(req.TLS.VerifiedChains[0]) < 1 {
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
return "", false return "", false

View File

@ -10,8 +10,13 @@ import (
"time" "time"
) )
const astrahost = "astra.blek.codes"
const HintsHeaderName = "X-Src-IP-Hints" const HintsHeaderName = "X-Src-IP-Hints"
func IsAstraHost(req *http.Request) bool {
return req.Host == astrahost || req.URL.Host == astrahost
}
type HandlerDialer interface { type HandlerDialer interface {
DialContext(ctx context.Context, net, address string) (net.Conn, error) DialContext(ctx context.Context, net, address string) (net.Conn, error)
} }
@ -122,6 +127,11 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
if originator, isLoopback := s.isLoopback(req); isLoopback { if originator, isLoopback := s.isLoopback(req); isLoopback {
s.logger.Critical("Loopback tunnel detected: %s is an outbound "+ s.logger.Critical("Loopback tunnel detected: %s is an outbound "+
"address for another request from %s", req.RemoteAddr, originator) "address for another request from %s", req.RemoteAddr, originator)
if IsAstraHost(req) && req.URL.Path == "/" {
SendIndex(wr, req)
return
}
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
return return
} }
@ -129,7 +139,7 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
isConnect := strings.ToUpper(req.Method) == "CONNECT" isConnect := strings.ToUpper(req.Method) == "CONNECT"
if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 || if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 ||
req.Host == "" && req.ProtoMajor == 2 { req.Host == "" && req.ProtoMajor == 2 {
if req.Host == "astra.blek.codes" && req.URL.Host == "astra.blek.codes" && req.URL.Path == "/" { if IsAstraHost(req) && req.URL.Path == "/" {
SendIndex(wr, req) SendIndex(wr, req)
return return
} }