diff --git a/README.md b/README.md index f9a33f9..0d6bf63 100644 --- a/README.md +++ b/README.md @@ -194,6 +194,8 @@ Usage of /home/user/go/bin/dumbproxy: colon-separated list of enabled ciphers -disable-http2 disable HTTP2 + -ip-hints value + a comma-separated list of source addresses to use on dial attempts. Example: "10.0.0.1,fe80::2,0.0.0.0,::" -key string key for TLS certificate -list-ciphers @@ -206,6 +208,8 @@ Usage of /home/user/go/bin/dumbproxy: upstream proxy URL. Can be repeated multiple times to chain proxies. Examples: socks5h://127.0.0.1:9050; https://user:password@example.com:443 -timeout duration timeout for network operations (default 10s) + -user-ip-hints + allow IP hints to be specified by user in X-Src-IP-Hints header -verbosity int logging verbosity (10 - debug, 20 - info, 30 - warning, 40 - error, 50 - critical) (default 20) -version diff --git a/go.mod b/go.mod index 4f764fa..c21ddbe 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.13 require ( github.com/GehirnInc/crypt v0.0.0-20230320061759-8cc1b52080c5 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/tg123/go-htpasswd v1.2.1 golang.org/x/crypto v0.7.0 diff --git a/go.sum b/go.sum index cc4c260..c80efb9 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= diff --git a/handler.go b/handler.go index 936026b..f1ba95f 100644 --- a/handler.go +++ b/handler.go @@ -10,6 +10,8 @@ import ( "time" ) +const HintsHeaderName = "X-Src-IP-Hints" + type HandlerDialer interface { DialContext(ctx context.Context, net, address string) (net.Conn, error) } @@ -22,11 +24,14 @@ type ProxyHandler struct { httptransport http.RoundTripper outbound map[string]string outboundMux sync.RWMutex + userIPHints bool } -func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer, logger *CondLogger) *ProxyHandler { +func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer, + userIPHints bool, logger *CondLogger) *ProxyHandler { httptransport := &http.Transport{ - DialContext: dialer.DialContext, + DialContext: dialer.DialContext, + DisableKeepAlives: userIPHints, } return &ProxyHandler{ timeout: timeout, @@ -35,6 +40,7 @@ func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer, log dialer: dialer, httptransport: httptransport, outbound: make(map[string]string), + userIPHints: userIPHints, } } @@ -133,6 +139,20 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { if !ok { return } + if s.userIPHints { + hintValues := req.Header.Values(HintsHeaderName) + if len(hintValues) > 0 { + req.Header.Del(HintsHeaderName) + if hintIPs, err := parseIPList(hintValues[0]); err != nil { + s.logger.Info("Request: %v %q %v %v %v -- bad IP hint header: %q", req.RemoteAddr, username, req.Proto, req.Method, req.URL, hintValues[0]) + http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) + return + } else { + newCtx := context.WithValue(req.Context(), BoundDialerContextKey{}, hintIPs) + req = req.WithContext(newCtx) + } + } + } delHopHeaders(req.Header) if isConnect { s.HandleTunnel(wr, req) diff --git a/hintdialer.go b/hintdialer.go new file mode 100644 index 0000000..d1ded98 --- /dev/null +++ b/hintdialer.go @@ -0,0 +1,138 @@ +package main + +import ( + "context" + "errors" + "fmt" + "net" + + "github.com/hashicorp/go-multierror" +) + +var ( + ErrNoSuitableAddress = errors.New("no suitable address") + ErrBadIPAddressLength = errors.New("bad IP address length") + ErrUnknownNetwork = errors.New("unknown network") +) + +type BoundDialerContextKey struct{} + +type BoundDialerDefaultSink interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +type BoundDialer struct { + defaultDialer BoundDialerDefaultSink + defaultHints []net.IP +} + +func NewBoundDialer(defaultDialer BoundDialerDefaultSink, defaultHints []net.IP) *BoundDialer { + if defaultDialer == nil { + defaultDialer = &net.Dialer{} + } + return &BoundDialer{ + defaultDialer: defaultDialer, + defaultHints: defaultHints, + } +} + +func (d *BoundDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + hints := d.defaultHints + if hintsOverride := ctx.Value(BoundDialerContextKey{}); hintsOverride != nil { + if hintsOverrideValue, ok := hintsOverride.([]net.IP); ok { + hints = hintsOverrideValue + } + } + + if len(hints) == 0 { + return d.defaultDialer.DialContext(ctx, network, address) + } + + var netBase string + switch network { + case "tcp", "tcp4", "tcp6": + netBase = "tcp" + case "udp", "udp4", "udp6": + netBase = "udp" + case "ip", "ip4", "ip6": + netBase = "ip" + default: + return d.defaultDialer.DialContext(ctx, network, address) + } + + var resErr error + for _, lIP := range hints { + lAddr, restrictedNetwork, err := ipToLAddr(netBase, lIP) + if err != nil { + resErr = multierror.Append(resErr, fmt.Errorf("ipToLAddr(%q) failed: %w", lIP.String(), err)) + continue + } + if network != netBase && network != restrictedNetwork { + continue + } + + conn, err := (&net.Dialer{ + LocalAddr: lAddr, + }).DialContext(ctx, restrictedNetwork, address) + if err != nil { + resErr = multierror.Append(resErr, fmt.Errorf("dial failed: %w", err)) + } else { + return conn, nil + } + } + + if resErr == nil { + resErr = ErrNoSuitableAddress + } + return nil, resErr +} + +func (d *BoundDialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func ipToLAddr(network string, ip net.IP) (net.Addr, string, error) { + v6 := true + if ip4 := ip.To4(); len(ip4) == net.IPv4len { + ip = ip4 + v6 = false + } else if len(ip) != net.IPv6len { + return nil, "", ErrBadIPAddressLength + } + + var lAddr net.Addr + var lNetwork string + switch network { + case "tcp", "tcp4", "tcp6": + lAddr = &net.TCPAddr{ + IP: ip, + } + if v6 { + lNetwork = "tcp6" + } else { + lNetwork = "tcp4" + } + case "udp", "udp4", "udp6": + lAddr = &net.UDPAddr{ + IP: ip, + } + if v6 { + lNetwork = "udp6" + } else { + lNetwork = "udp4" + } + case "ip", "ip4", "ip6": + lAddr = &net.IPAddr{ + IP: ip, + } + if v6 { + lNetwork = "ip6" + } else { + lNetwork = "ip4" + } + default: + return nil, "", ErrUnknownNetwork + } + + return lAddr, lNetwork, nil +} diff --git a/main.go b/main.go index 7d8b812..4084690 100644 --- a/main.go +++ b/main.go @@ -72,6 +72,8 @@ type CLIArgs struct { passwdCost int positionalArgs []string proxy []string + sourceIPHints []net.IP + userIPHints bool } func parse_args() CLIArgs { @@ -101,6 +103,15 @@ func parse_args() CLIArgs { args.proxy = append(args.proxy, p) return nil }) + flag.Func("ip-hints", "a comma-separated list of source addresses to use on dial attempts. Example: \"10.0.0.1,fe80::2,0.0.0.0,::\"", func(p string) error { + list, err := parseIPList(p) + if err != nil { + return err + } + args.sourceIPHints = list + return nil + }) + flag.BoolVar(&args.userIPHints, "user-ip-hints", false, "allow IP hints to be specified by user in X-Src-IP-Hints header") flag.Parse() args.positionalArgs = flag.Args() return args @@ -146,7 +157,7 @@ func run() int { } defer auth.Stop() - var dialer Dialer = new(net.Dialer) + var dialer Dialer = NewBoundDialer(new(net.Dialer), args.sourceIPHints) for _, proxyURL := range args.proxy { newDialer, err := proxyDialerFromURL(proxyURL, dialer) if err != nil { @@ -158,7 +169,7 @@ func run() int { server := http.Server{ Addr: args.bind_address, - Handler: NewProxyHandler(args.timeout, auth, maybeWrapWithContextDialer(dialer), proxyLogger), + Handler: NewProxyHandler(args.timeout, auth, maybeWrapWithContextDialer(dialer), args.userIPHints, proxyLogger), ErrorLog: log.New(logWriter, "HTTPSRV : ", log.LstdFlags|log.Lshortfile), ReadTimeout: 0, ReadHeaderTimeout: 0, diff --git a/utils.go b/utils.go index 2070019..5cbfc89 100644 --- a/utils.go +++ b/utils.go @@ -369,3 +369,19 @@ func maybeWrapWithContextDialer(d Dialer) ContextDialer { } return wrappedDialer{d} } + +func parseIPList(list string) ([]net.IP, error) { + res := make([]net.IP, 0) + for _, elem := range strings.Split(list, ",") { + elem = strings.TrimSpace(elem) + if len(elem) == 0 { + continue + } + if parsed := net.ParseIP(elem); parsed == nil { + return nil, fmt.Errorf("unable to parse IP address %q", elem) + } else { + res = append(res, parsed) + } + } + return res, nil +}