diff --git a/go.mod b/go.mod index 4f764fa..c21ddbe 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.13 require ( github.com/GehirnInc/crypt v0.0.0-20230320061759-8cc1b52080c5 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/tg123/go-htpasswd v1.2.1 golang.org/x/crypto v0.7.0 diff --git a/go.sum b/go.sum index cc4c260..c80efb9 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= diff --git a/hintdialer.go b/hintdialer.go new file mode 100644 index 0000000..d1ded98 --- /dev/null +++ b/hintdialer.go @@ -0,0 +1,138 @@ +package main + +import ( + "context" + "errors" + "fmt" + "net" + + "github.com/hashicorp/go-multierror" +) + +var ( + ErrNoSuitableAddress = errors.New("no suitable address") + ErrBadIPAddressLength = errors.New("bad IP address length") + ErrUnknownNetwork = errors.New("unknown network") +) + +type BoundDialerContextKey struct{} + +type BoundDialerDefaultSink interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +type BoundDialer struct { + defaultDialer BoundDialerDefaultSink + defaultHints []net.IP +} + +func NewBoundDialer(defaultDialer BoundDialerDefaultSink, defaultHints []net.IP) *BoundDialer { + if defaultDialer == nil { + defaultDialer = &net.Dialer{} + } + return &BoundDialer{ + defaultDialer: defaultDialer, + defaultHints: defaultHints, + } +} + +func (d *BoundDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + hints := d.defaultHints + if hintsOverride := ctx.Value(BoundDialerContextKey{}); hintsOverride != nil { + if hintsOverrideValue, ok := hintsOverride.([]net.IP); ok { + hints = hintsOverrideValue + } + } + + if len(hints) == 0 { + return d.defaultDialer.DialContext(ctx, network, address) + } + + var netBase string + switch network { + case "tcp", "tcp4", "tcp6": + netBase = "tcp" + case "udp", "udp4", "udp6": + netBase = "udp" + case "ip", "ip4", "ip6": + netBase = "ip" + default: + return d.defaultDialer.DialContext(ctx, network, address) + } + + var resErr error + for _, lIP := range hints { + lAddr, restrictedNetwork, err := ipToLAddr(netBase, lIP) + if err != nil { + resErr = multierror.Append(resErr, fmt.Errorf("ipToLAddr(%q) failed: %w", lIP.String(), err)) + continue + } + if network != netBase && network != restrictedNetwork { + continue + } + + conn, err := (&net.Dialer{ + LocalAddr: lAddr, + }).DialContext(ctx, restrictedNetwork, address) + if err != nil { + resErr = multierror.Append(resErr, fmt.Errorf("dial failed: %w", err)) + } else { + return conn, nil + } + } + + if resErr == nil { + resErr = ErrNoSuitableAddress + } + return nil, resErr +} + +func (d *BoundDialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func ipToLAddr(network string, ip net.IP) (net.Addr, string, error) { + v6 := true + if ip4 := ip.To4(); len(ip4) == net.IPv4len { + ip = ip4 + v6 = false + } else if len(ip) != net.IPv6len { + return nil, "", ErrBadIPAddressLength + } + + var lAddr net.Addr + var lNetwork string + switch network { + case "tcp", "tcp4", "tcp6": + lAddr = &net.TCPAddr{ + IP: ip, + } + if v6 { + lNetwork = "tcp6" + } else { + lNetwork = "tcp4" + } + case "udp", "udp4", "udp6": + lAddr = &net.UDPAddr{ + IP: ip, + } + if v6 { + lNetwork = "udp6" + } else { + lNetwork = "udp4" + } + case "ip", "ip4", "ip6": + lAddr = &net.IPAddr{ + IP: ip, + } + if v6 { + lNetwork = "ip6" + } else { + lNetwork = "ip4" + } + default: + return nil, "", ErrUnknownNetwork + } + + return lAddr, lNetwork, nil +}