prohibit CONNECT loops

This commit is contained in:
Vladislav Yarmak 2020-05-28 02:52:16 +03:00
parent 528b1d4d1f
commit b171a05588
1 changed files with 30 additions and 1 deletions

View File

@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"strings" "strings"
"context" "context"
"sync"
) )
type ProxyHandler struct { type ProxyHandler struct {
@ -14,6 +15,8 @@ type ProxyHandler struct {
auth Auth auth Auth
logger *CondLogger logger *CondLogger
httptransport http.RoundTripper httptransport http.RoundTripper
outbound map[string]string
outboundMux sync.RWMutex
} }
func NewProxyHandler(timeout time.Duration, auth Auth, logger *CondLogger) *ProxyHandler { func NewProxyHandler(timeout time.Duration, auth Auth, logger *CondLogger) *ProxyHandler {
@ -23,6 +26,7 @@ func NewProxyHandler(timeout time.Duration, auth Auth, logger *CondLogger) *Prox
auth: auth, auth: auth,
logger: logger, logger: logger,
httptransport: httptransport, httptransport: httptransport,
outbound: make(map[string]string),
} }
} }
@ -35,7 +39,17 @@ func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) {
http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway) http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway)
return return
} }
defer conn.Close()
localAddr := conn.LocalAddr().String()
s.outboundMux.Lock()
s.outbound[localAddr] = req.RemoteAddr
s.outboundMux.Unlock()
defer func() {
conn.Close()
s.outboundMux.Lock()
delete(s.outbound, localAddr)
s.outboundMux.Unlock()
}()
if req.ProtoMajor == 0 || req.ProtoMajor == 1 { if req.ProtoMajor == 0 || req.ProtoMajor == 1 {
// Upgrade client connection // Upgrade client connection
@ -84,8 +98,23 @@ func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request)
copyBody(wr, resp.Body) copyBody(wr, resp.Body)
} }
func (s *ProxyHandler) isLoopback(req *http.Request) (string, bool) {
s.outboundMux.RLock()
originator, found := s.outbound[req.RemoteAddr]
s.outboundMux.RUnlock()
return originator, found
}
func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
s.logger.Info("Request: %v %v %v %v", req.RemoteAddr, req.Proto, req.Method, req.URL) s.logger.Info("Request: %v %v %v %v", req.RemoteAddr, req.Proto, req.Method, req.URL)
if originator, isLoopback := s.isLoopback(req) ; isLoopback {
s.logger.Critical("Loopback tunnel detected: %s is an outbound " +
"address for another request from %s", req.RemoteAddr, originator)
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
return
}
isConnect := strings.ToUpper(req.Method) == "CONNECT" isConnect := strings.ToUpper(req.Method) == "CONNECT"
if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 || if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 ||
req.Host == "" && req.ProtoMajor == 2 { req.Host == "" && req.ProtoMajor == 2 {