prohibit CONNECT loops
This commit is contained in:
parent
528b1d4d1f
commit
b171a05588
31
handler.go
31
handler.go
|
@ -7,6 +7,7 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type ProxyHandler struct {
|
||||
|
@ -14,6 +15,8 @@ type ProxyHandler struct {
|
|||
auth Auth
|
||||
logger *CondLogger
|
||||
httptransport http.RoundTripper
|
||||
outbound map[string]string
|
||||
outboundMux sync.RWMutex
|
||||
}
|
||||
|
||||
func NewProxyHandler(timeout time.Duration, auth Auth, logger *CondLogger) *ProxyHandler {
|
||||
|
@ -23,6 +26,7 @@ func NewProxyHandler(timeout time.Duration, auth Auth, logger *CondLogger) *Prox
|
|||
auth: auth,
|
||||
logger: logger,
|
||||
httptransport: httptransport,
|
||||
outbound: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -35,7 +39,17 @@ func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) {
|
|||
http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
localAddr := conn.LocalAddr().String()
|
||||
s.outboundMux.Lock()
|
||||
s.outbound[localAddr] = req.RemoteAddr
|
||||
s.outboundMux.Unlock()
|
||||
defer func() {
|
||||
conn.Close()
|
||||
s.outboundMux.Lock()
|
||||
delete(s.outbound, localAddr)
|
||||
s.outboundMux.Unlock()
|
||||
}()
|
||||
|
||||
if req.ProtoMajor == 0 || req.ProtoMajor == 1 {
|
||||
// Upgrade client connection
|
||||
|
@ -84,8 +98,23 @@ func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request)
|
|||
copyBody(wr, resp.Body)
|
||||
}
|
||||
|
||||
func (s *ProxyHandler) isLoopback(req *http.Request) (string, bool) {
|
||||
s.outboundMux.RLock()
|
||||
originator, found := s.outbound[req.RemoteAddr]
|
||||
s.outboundMux.RUnlock()
|
||||
return originator, found
|
||||
}
|
||||
|
||||
func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
||||
s.logger.Info("Request: %v %v %v %v", req.RemoteAddr, req.Proto, req.Method, req.URL)
|
||||
|
||||
if originator, isLoopback := s.isLoopback(req) ; isLoopback {
|
||||
s.logger.Critical("Loopback tunnel detected: %s is an outbound " +
|
||||
"address for another request from %s", req.RemoteAddr, originator)
|
||||
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
isConnect := strings.ToUpper(req.Method) == "CONNECT"
|
||||
if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 ||
|
||||
req.Host == "" && req.ProtoMajor == 2 {
|
||||
|
|
Loading…
Reference in New Issue