allow use $lAddr in ip hint template

This commit is contained in:
Vladislav Yarmak 2023-09-06 17:34:07 +03:00
parent 88898cb94f
commit fd303bda12
3 changed files with 65 additions and 25 deletions

View File

@ -31,7 +31,7 @@ func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer,
userIPHints bool, logger *CondLogger) *ProxyHandler { userIPHints bool, logger *CondLogger) *ProxyHandler {
httptransport := &http.Transport{ httptransport := &http.Transport{
DialContext: dialer.DialContext, DialContext: dialer.DialContext,
DisableKeepAlives: userIPHints, DisableKeepAlives: true,
} }
return &ProxyHandler{ return &ProxyHandler{
timeout: timeout, timeout: timeout,
@ -134,25 +134,26 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
} }
username, ok := s.auth.Validate(wr, req) username, ok := s.auth.Validate(wr, req)
s.logger.Info("Request: %v %q %v %v %v", req.RemoteAddr, username, req.Proto, req.Method, req.URL) localAddr := getLocalAddr(req.Context())
s.logger.Info("Request: %v => %v %q %v %v %v", req.RemoteAddr, localAddr, username, req.Proto, req.Method, req.URL)
if !ok { if !ok {
return return
} }
var ipHints *string
if s.userIPHints { if s.userIPHints {
hintValues := req.Header.Values(HintsHeaderName) hintValues := req.Header.Values(HintsHeaderName)
if len(hintValues) > 0 { if len(hintValues) > 0 {
req.Header.Del(HintsHeaderName) req.Header.Del(HintsHeaderName)
if hintIPs, err := parseIPList(hintValues[0]); err != nil { ipHints = &hintValues[0]
s.logger.Info("Request: %v %q %v %v %v -- bad IP hint header: %q", req.RemoteAddr, username, req.Proto, req.Method, req.URL, hintValues[0]) }
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) }
return newCtx := context.WithValue(req.Context(), BoundDialerContextKey{}, BoundDialerContextValue{
} else { Hints: ipHints,
newCtx := context.WithValue(req.Context(), BoundDialerContextKey{}, hintIPs) LocalAddr: trimAddrPort(localAddr),
})
req = req.WithContext(newCtx) req = req.WithContext(newCtx)
}
}
}
delHopHeaders(req.Header) delHopHeaders(req.Header)
if isConnect { if isConnect {
s.HandleTunnel(wr, req) s.HandleTunnel(wr, req)
@ -160,3 +161,18 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
s.HandleRequest(wr, req) s.HandleRequest(wr, req)
} }
} }
func trimAddrPort(addrPort string) string {
res, _, err := net.SplitHostPort(addrPort)
if err != nil {
return addrPort
}
return res
}
func getLocalAddr(ctx context.Context) string {
if addr, ok := ctx.Value(http.LocalAddrContextKey).(net.Addr); ok {
return addr.String()
}
return "<request context is missing address>"
}

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"os"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
) )
@ -17,16 +18,21 @@ var (
type BoundDialerContextKey struct{} type BoundDialerContextKey struct{}
type BoundDialerContextValue struct {
Hints *string
LocalAddr string
}
type BoundDialerDefaultSink interface { type BoundDialerDefaultSink interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error) DialContext(ctx context.Context, network, address string) (net.Conn, error)
} }
type BoundDialer struct { type BoundDialer struct {
defaultDialer BoundDialerDefaultSink defaultDialer BoundDialerDefaultSink
defaultHints []net.IP defaultHints string
} }
func NewBoundDialer(defaultDialer BoundDialerDefaultSink, defaultHints []net.IP) *BoundDialer { func NewBoundDialer(defaultDialer BoundDialerDefaultSink, defaultHints string) *BoundDialer {
if defaultDialer == nil { if defaultDialer == nil {
defaultDialer = &net.Dialer{} defaultDialer = &net.Dialer{}
} }
@ -38,13 +44,22 @@ func NewBoundDialer(defaultDialer BoundDialerDefaultSink, defaultHints []net.IP)
func (d *BoundDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (d *BoundDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
hints := d.defaultHints hints := d.defaultHints
lAddr := ""
if hintsOverride := ctx.Value(BoundDialerContextKey{}); hintsOverride != nil { if hintsOverride := ctx.Value(BoundDialerContextKey{}); hintsOverride != nil {
if hintsOverrideValue, ok := hintsOverride.([]net.IP); ok { if hintsOverrideValue, ok := hintsOverride.(BoundDialerContextValue); ok {
hints = hintsOverrideValue if hintsOverrideValue.Hints != nil {
hints = *hintsOverrideValue.Hints
}
lAddr = hintsOverrideValue.LocalAddr
} }
} }
if len(hints) == 0 { parsedHints, err := parseHints(hints, lAddr)
if err != nil {
return nil, fmt.Errorf("dial failed: %w", err)
}
if len(parsedHints) == 0 {
return d.defaultDialer.DialContext(ctx, network, address) return d.defaultDialer.DialContext(ctx, network, address)
} }
@ -61,7 +76,7 @@ func (d *BoundDialer) DialContext(ctx context.Context, network, address string)
} }
var resErr error var resErr error
for _, lIP := range hints { for _, lIP := range parsedHints {
lAddr, restrictedNetwork, err := ipToLAddr(netBase, lIP) lAddr, restrictedNetwork, err := ipToLAddr(netBase, lIP)
if err != nil { if err != nil {
resErr = multierror.Append(resErr, fmt.Errorf("ipToLAddr(%q) failed: %w", lIP.String(), err)) resErr = multierror.Append(resErr, fmt.Errorf("ipToLAddr(%q) failed: %w", lIP.String(), err))
@ -136,3 +151,19 @@ func ipToLAddr(network string, ip net.IP) (net.Addr, string, error) {
return lAddr, lNetwork, nil return lAddr, lNetwork, nil
} }
func parseHints(hints, lAddr string) ([]net.IP, error) {
hints = os.Expand(hints, func(key string) string {
switch key {
case "lAddr":
return lAddr
default:
return fmt.Sprintf("<bad key:%q>", key)
}
})
res, err := parseIPList(hints)
if err != nil {
return nil, fmt.Errorf("unable to parse source IP hints %q: %w", hints, err)
}
return res, nil
}

11
main.go
View File

@ -72,7 +72,7 @@ type CLIArgs struct {
passwdCost int passwdCost int
positionalArgs []string positionalArgs []string
proxy []string proxy []string
sourceIPHints []net.IP sourceIPHints string
userIPHints bool userIPHints bool
} }
@ -103,14 +103,7 @@ func parse_args() CLIArgs {
args.proxy = append(args.proxy, p) args.proxy = append(args.proxy, p)
return nil return nil
}) })
flag.Func("ip-hints", "a comma-separated list of source addresses to use on dial attempts. Example: \"10.0.0.1,fe80::2,0.0.0.0,::\"", func(p string) error { flag.StringVar(&args.sourceIPHints, "ip-hints", "", "a comma-separated list of source addresses to use on dial attempts. Example: \"10.0.0.1,fe80::2,0.0.0.0,::\"")
list, err := parseIPList(p)
if err != nil {
return err
}
args.sourceIPHints = list
return nil
})
flag.BoolVar(&args.userIPHints, "user-ip-hints", false, "allow IP hints to be specified by user in X-Src-IP-Hints header") flag.BoolVar(&args.userIPHints, "user-ip-hints", false, "allow IP hints to be specified by user in X-Src-IP-Hints header")
flag.Parse() flag.Parse()
args.positionalArgs = flag.Args() args.positionalArgs = flag.Args()