implementation of local address-bound dialer

This commit is contained in:
Vladislav Yarmak 2023-06-26 21:12:28 +03:00
parent 61221213c2
commit 36235d68fb
3 changed files with 143 additions and 0 deletions

1
go.mod
View File

@ -5,6 +5,7 @@ go 1.13
require (
github.com/GehirnInc/crypt v0.0.0-20230320061759-8cc1b52080c5 // indirect
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/tg123/go-htpasswd v1.2.1
golang.org/x/crypto v0.7.0

4
go.sum
View File

@ -8,6 +8,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=

138
hintdialer.go Normal file
View File

@ -0,0 +1,138 @@
package main
import (
"context"
"errors"
"fmt"
"net"
"github.com/hashicorp/go-multierror"
)
var (
ErrNoSuitableAddress = errors.New("no suitable address")
ErrBadIPAddressLength = errors.New("bad IP address length")
ErrUnknownNetwork = errors.New("unknown network")
)
type BoundDialerContextKey struct{}
type BoundDialerDefaultSink interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
type BoundDialer struct {
defaultDialer BoundDialerDefaultSink
defaultHints []net.IP
}
func NewBoundDialer(defaultDialer BoundDialerDefaultSink, defaultHints []net.IP) *BoundDialer {
if defaultDialer == nil {
defaultDialer = &net.Dialer{}
}
return &BoundDialer{
defaultDialer: defaultDialer,
defaultHints: defaultHints,
}
}
func (d *BoundDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
hints := d.defaultHints
if hintsOverride := ctx.Value(BoundDialerContextKey{}); hintsOverride != nil {
if hintsOverrideValue, ok := hintsOverride.([]net.IP); ok {
hints = hintsOverrideValue
}
}
if len(hints) == 0 {
return d.defaultDialer.DialContext(ctx, network, address)
}
var netBase string
switch network {
case "tcp", "tcp4", "tcp6":
netBase = "tcp"
case "udp", "udp4", "udp6":
netBase = "udp"
case "ip", "ip4", "ip6":
netBase = "ip"
default:
return d.defaultDialer.DialContext(ctx, network, address)
}
var resErr error
for _, lIP := range hints {
lAddr, restrictedNetwork, err := ipToLAddr(netBase, lIP)
if err != nil {
resErr = multierror.Append(resErr, fmt.Errorf("ipToLAddr(%q) failed: %w", lIP.String(), err))
continue
}
if network != netBase && network != restrictedNetwork {
continue
}
conn, err := (&net.Dialer{
LocalAddr: lAddr,
}).DialContext(ctx, restrictedNetwork, address)
if err != nil {
resErr = multierror.Append(resErr, fmt.Errorf("dial failed: %w", err))
} else {
return conn, nil
}
}
if resErr == nil {
resErr = ErrNoSuitableAddress
}
return nil, resErr
}
func (d *BoundDialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}
func ipToLAddr(network string, ip net.IP) (net.Addr, string, error) {
v6 := true
if ip4 := ip.To4(); len(ip4) == net.IPv4len {
ip = ip4
v6 = false
} else if len(ip) != net.IPv6len {
return nil, "", ErrBadIPAddressLength
}
var lAddr net.Addr
var lNetwork string
switch network {
case "tcp", "tcp4", "tcp6":
lAddr = &net.TCPAddr{
IP: ip,
}
if v6 {
lNetwork = "tcp6"
} else {
lNetwork = "tcp4"
}
case "udp", "udp4", "udp6":
lAddr = &net.UDPAddr{
IP: ip,
}
if v6 {
lNetwork = "udp6"
} else {
lNetwork = "udp4"
}
case "ip", "ip4", "ip6":
lAddr = &net.IPAddr{
IP: ip,
}
if v6 {
lNetwork = "ip6"
} else {
lNetwork = "ip4"
}
default:
return nil, "", ErrUnknownNetwork
}
return lAddr, lNetwork, nil
}