From b171a05588dae479533b2b0245559c171db63d83 Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Thu, 28 May 2020 02:52:16 +0300 Subject: [PATCH] prohibit CONNECT loops --- handler.go | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/handler.go b/handler.go index 5d31414..3c6260d 100644 --- a/handler.go +++ b/handler.go @@ -7,6 +7,7 @@ import ( "net/http" "strings" "context" + "sync" ) type ProxyHandler struct { @@ -14,6 +15,8 @@ type ProxyHandler struct { auth Auth logger *CondLogger httptransport http.RoundTripper + outbound map[string]string + outboundMux sync.RWMutex } func NewProxyHandler(timeout time.Duration, auth Auth, logger *CondLogger) *ProxyHandler { @@ -23,6 +26,7 @@ func NewProxyHandler(timeout time.Duration, auth Auth, logger *CondLogger) *Prox auth: auth, logger: logger, httptransport: httptransport, + outbound: make(map[string]string), } } @@ -35,7 +39,17 @@ func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) { http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway) return } - defer conn.Close() + + localAddr := conn.LocalAddr().String() + s.outboundMux.Lock() + s.outbound[localAddr] = req.RemoteAddr + s.outboundMux.Unlock() + defer func() { + conn.Close() + s.outboundMux.Lock() + delete(s.outbound, localAddr) + s.outboundMux.Unlock() + }() if req.ProtoMajor == 0 || req.ProtoMajor == 1 { // Upgrade client connection @@ -84,8 +98,23 @@ func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request) copyBody(wr, resp.Body) } +func (s *ProxyHandler) isLoopback(req *http.Request) (string, bool) { + s.outboundMux.RLock() + originator, found := s.outbound[req.RemoteAddr] + s.outboundMux.RUnlock() + return originator, found +} + func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { s.logger.Info("Request: %v %v %v %v", req.RemoteAddr, req.Proto, req.Method, req.URL) + + if originator, isLoopback := s.isLoopback(req) ; isLoopback { + s.logger.Critical("Loopback tunnel detected: %s is an outbound " + + "address for another request from %s", req.RemoteAddr, originator) + http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) + return + } + isConnect := strings.ToUpper(req.Method) == "CONNECT" if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 || req.Host == "" && req.ProtoMajor == 2 {