diff --git a/handler.go b/handler.go index f1ba95f..cbca628 100644 --- a/handler.go +++ b/handler.go @@ -31,7 +31,7 @@ func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer, userIPHints bool, logger *CondLogger) *ProxyHandler { httptransport := &http.Transport{ DialContext: dialer.DialContext, - DisableKeepAlives: userIPHints, + DisableKeepAlives: true, } return &ProxyHandler{ timeout: timeout, @@ -134,25 +134,26 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { } username, ok := s.auth.Validate(wr, req) - s.logger.Info("Request: %v %q %v %v %v", req.RemoteAddr, username, req.Proto, req.Method, req.URL) + localAddr := getLocalAddr(req.Context()) + s.logger.Info("Request: %v => %v %q %v %v %v", req.RemoteAddr, localAddr, username, req.Proto, req.Method, req.URL) if !ok { return } + + var ipHints *string if s.userIPHints { hintValues := req.Header.Values(HintsHeaderName) if len(hintValues) > 0 { req.Header.Del(HintsHeaderName) - if hintIPs, err := parseIPList(hintValues[0]); err != nil { - s.logger.Info("Request: %v %q %v %v %v -- bad IP hint header: %q", req.RemoteAddr, username, req.Proto, req.Method, req.URL, hintValues[0]) - http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) - return - } else { - newCtx := context.WithValue(req.Context(), BoundDialerContextKey{}, hintIPs) - req = req.WithContext(newCtx) - } + ipHints = &hintValues[0] } } + newCtx := context.WithValue(req.Context(), BoundDialerContextKey{}, BoundDialerContextValue{ + Hints: ipHints, + LocalAddr: trimAddrPort(localAddr), + }) + req = req.WithContext(newCtx) delHopHeaders(req.Header) if isConnect { s.HandleTunnel(wr, req) @@ -160,3 +161,18 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { s.HandleRequest(wr, req) } } + +func trimAddrPort(addrPort string) string { + res, _, err := net.SplitHostPort(addrPort) + if err != nil { + return addrPort + } + return res +} + +func getLocalAddr(ctx context.Context) string { + if addr, ok := ctx.Value(http.LocalAddrContextKey).(net.Addr); ok { + return addr.String() + } + return "" +} diff --git a/hintdialer.go b/hintdialer.go index d1ded98..e76efce 100644 --- a/hintdialer.go +++ b/hintdialer.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "os" "github.com/hashicorp/go-multierror" ) @@ -17,16 +18,21 @@ var ( type BoundDialerContextKey struct{} +type BoundDialerContextValue struct { + Hints *string + LocalAddr string +} + type BoundDialerDefaultSink interface { DialContext(ctx context.Context, network, address string) (net.Conn, error) } type BoundDialer struct { defaultDialer BoundDialerDefaultSink - defaultHints []net.IP + defaultHints string } -func NewBoundDialer(defaultDialer BoundDialerDefaultSink, defaultHints []net.IP) *BoundDialer { +func NewBoundDialer(defaultDialer BoundDialerDefaultSink, defaultHints string) *BoundDialer { if defaultDialer == nil { defaultDialer = &net.Dialer{} } @@ -38,13 +44,22 @@ func NewBoundDialer(defaultDialer BoundDialerDefaultSink, defaultHints []net.IP) func (d *BoundDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { hints := d.defaultHints + lAddr := "" if hintsOverride := ctx.Value(BoundDialerContextKey{}); hintsOverride != nil { - if hintsOverrideValue, ok := hintsOverride.([]net.IP); ok { - hints = hintsOverrideValue + if hintsOverrideValue, ok := hintsOverride.(BoundDialerContextValue); ok { + if hintsOverrideValue.Hints != nil { + hints = *hintsOverrideValue.Hints + } + lAddr = hintsOverrideValue.LocalAddr } } - if len(hints) == 0 { + parsedHints, err := parseHints(hints, lAddr) + if err != nil { + return nil, fmt.Errorf("dial failed: %w", err) + } + + if len(parsedHints) == 0 { return d.defaultDialer.DialContext(ctx, network, address) } @@ -61,7 +76,7 @@ func (d *BoundDialer) DialContext(ctx context.Context, network, address string) } var resErr error - for _, lIP := range hints { + for _, lIP := range parsedHints { lAddr, restrictedNetwork, err := ipToLAddr(netBase, lIP) if err != nil { resErr = multierror.Append(resErr, fmt.Errorf("ipToLAddr(%q) failed: %w", lIP.String(), err)) @@ -136,3 +151,19 @@ func ipToLAddr(network string, ip net.IP) (net.Addr, string, error) { return lAddr, lNetwork, nil } + +func parseHints(hints, lAddr string) ([]net.IP, error) { + hints = os.Expand(hints, func(key string) string { + switch key { + case "lAddr": + return lAddr + default: + return fmt.Sprintf("", key) + } + }) + res, err := parseIPList(hints) + if err != nil { + return nil, fmt.Errorf("unable to parse source IP hints %q: %w", hints, err) + } + return res, nil +} diff --git a/main.go b/main.go index 4084690..0b4a925 100644 --- a/main.go +++ b/main.go @@ -72,7 +72,7 @@ type CLIArgs struct { passwdCost int positionalArgs []string proxy []string - sourceIPHints []net.IP + sourceIPHints string userIPHints bool } @@ -103,14 +103,7 @@ func parse_args() CLIArgs { args.proxy = append(args.proxy, p) return nil }) - flag.Func("ip-hints", "a comma-separated list of source addresses to use on dial attempts. Example: \"10.0.0.1,fe80::2,0.0.0.0,::\"", func(p string) error { - list, err := parseIPList(p) - if err != nil { - return err - } - args.sourceIPHints = list - return nil - }) + flag.StringVar(&args.sourceIPHints, "ip-hints", "", "a comma-separated list of source addresses to use on dial attempts. Example: \"10.0.0.1,fe80::2,0.0.0.0,::\"") flag.BoolVar(&args.userIPHints, "user-ip-hints", false, "allow IP hints to be specified by user in X-Src-IP-Hints header") flag.Parse() args.positionalArgs = flag.Args()