diff --git a/handler.go b/handler.go index 936026b..f1ba95f 100644 --- a/handler.go +++ b/handler.go @@ -10,6 +10,8 @@ import ( "time" ) +const HintsHeaderName = "X-Src-IP-Hints" + type HandlerDialer interface { DialContext(ctx context.Context, net, address string) (net.Conn, error) } @@ -22,11 +24,14 @@ type ProxyHandler struct { httptransport http.RoundTripper outbound map[string]string outboundMux sync.RWMutex + userIPHints bool } -func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer, logger *CondLogger) *ProxyHandler { +func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer, + userIPHints bool, logger *CondLogger) *ProxyHandler { httptransport := &http.Transport{ - DialContext: dialer.DialContext, + DialContext: dialer.DialContext, + DisableKeepAlives: userIPHints, } return &ProxyHandler{ timeout: timeout, @@ -35,6 +40,7 @@ func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer, log dialer: dialer, httptransport: httptransport, outbound: make(map[string]string), + userIPHints: userIPHints, } } @@ -133,6 +139,20 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { if !ok { return } + if s.userIPHints { + hintValues := req.Header.Values(HintsHeaderName) + if len(hintValues) > 0 { + req.Header.Del(HintsHeaderName) + if hintIPs, err := parseIPList(hintValues[0]); err != nil { + s.logger.Info("Request: %v %q %v %v %v -- bad IP hint header: %q", req.RemoteAddr, username, req.Proto, req.Method, req.URL, hintValues[0]) + http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) + return + } else { + newCtx := context.WithValue(req.Context(), BoundDialerContextKey{}, hintIPs) + req = req.WithContext(newCtx) + } + } + } delHopHeaders(req.Header) if isConnect { s.HandleTunnel(wr, req) diff --git a/main.go b/main.go index 5a90672..feddd5f 100644 --- a/main.go +++ b/main.go @@ -73,6 +73,7 @@ type CLIArgs struct { positionalArgs []string proxy []string sourceIPHints []net.IP + userIPHints bool } func parse_args() CLIArgs { @@ -110,6 +111,7 @@ func parse_args() CLIArgs { args.sourceIPHints = list return nil }) + flag.BoolVar(&args.userIPHints, "user-ip-hints", false, "allow IP hints to be specified by user in X-Src-IP-Hints header") flag.Parse() args.positionalArgs = flag.Args() return args @@ -167,7 +169,7 @@ func run() int { server := http.Server{ Addr: args.bind_address, - Handler: NewProxyHandler(args.timeout, auth, maybeWrapWithContextDialer(dialer), proxyLogger), + Handler: NewProxyHandler(args.timeout, auth, maybeWrapWithContextDialer(dialer), args.userIPHints, proxyLogger), ErrorLog: log.New(logWriter, "HTTPSRV : ", log.LstdFlags|log.Lshortfile), ReadTimeout: 0, ReadHeaderTimeout: 0,