implementation of local address-bound dialer
This commit is contained in:
parent
61221213c2
commit
36235d68fb
1
go.mod
1
go.mod
|
@ -5,6 +5,7 @@ go 1.13
|
|||
require (
|
||||
github.com/GehirnInc/crypt v0.0.0-20230320061759-8cc1b52080c5 // indirect
|
||||
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/tg123/go-htpasswd v1.2.1
|
||||
golang.org/x/crypto v0.7.0
|
||||
|
|
4
go.sum
4
go.sum
|
@ -8,6 +8,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
|
|||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
|
||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
||||
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
|
|
|
@ -0,0 +1,138 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoSuitableAddress = errors.New("no suitable address")
|
||||
ErrBadIPAddressLength = errors.New("bad IP address length")
|
||||
ErrUnknownNetwork = errors.New("unknown network")
|
||||
)
|
||||
|
||||
type BoundDialerContextKey struct{}
|
||||
|
||||
type BoundDialerDefaultSink interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
type BoundDialer struct {
|
||||
defaultDialer BoundDialerDefaultSink
|
||||
defaultHints []net.IP
|
||||
}
|
||||
|
||||
func NewBoundDialer(defaultDialer BoundDialerDefaultSink, defaultHints []net.IP) *BoundDialer {
|
||||
if defaultDialer == nil {
|
||||
defaultDialer = &net.Dialer{}
|
||||
}
|
||||
return &BoundDialer{
|
||||
defaultDialer: defaultDialer,
|
||||
defaultHints: defaultHints,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *BoundDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
hints := d.defaultHints
|
||||
if hintsOverride := ctx.Value(BoundDialerContextKey{}); hintsOverride != nil {
|
||||
if hintsOverrideValue, ok := hintsOverride.([]net.IP); ok {
|
||||
hints = hintsOverrideValue
|
||||
}
|
||||
}
|
||||
|
||||
if len(hints) == 0 {
|
||||
return d.defaultDialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
var netBase string
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
netBase = "tcp"
|
||||
case "udp", "udp4", "udp6":
|
||||
netBase = "udp"
|
||||
case "ip", "ip4", "ip6":
|
||||
netBase = "ip"
|
||||
default:
|
||||
return d.defaultDialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
var resErr error
|
||||
for _, lIP := range hints {
|
||||
lAddr, restrictedNetwork, err := ipToLAddr(netBase, lIP)
|
||||
if err != nil {
|
||||
resErr = multierror.Append(resErr, fmt.Errorf("ipToLAddr(%q) failed: %w", lIP.String(), err))
|
||||
continue
|
||||
}
|
||||
if network != netBase && network != restrictedNetwork {
|
||||
continue
|
||||
}
|
||||
|
||||
conn, err := (&net.Dialer{
|
||||
LocalAddr: lAddr,
|
||||
}).DialContext(ctx, restrictedNetwork, address)
|
||||
if err != nil {
|
||||
resErr = multierror.Append(resErr, fmt.Errorf("dial failed: %w", err))
|
||||
} else {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
if resErr == nil {
|
||||
resErr = ErrNoSuitableAddress
|
||||
}
|
||||
return nil, resErr
|
||||
}
|
||||
|
||||
func (d *BoundDialer) Dial(network, address string) (net.Conn, error) {
|
||||
return d.DialContext(context.Background(), network, address)
|
||||
}
|
||||
|
||||
func ipToLAddr(network string, ip net.IP) (net.Addr, string, error) {
|
||||
v6 := true
|
||||
if ip4 := ip.To4(); len(ip4) == net.IPv4len {
|
||||
ip = ip4
|
||||
v6 = false
|
||||
} else if len(ip) != net.IPv6len {
|
||||
return nil, "", ErrBadIPAddressLength
|
||||
}
|
||||
|
||||
var lAddr net.Addr
|
||||
var lNetwork string
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
lAddr = &net.TCPAddr{
|
||||
IP: ip,
|
||||
}
|
||||
if v6 {
|
||||
lNetwork = "tcp6"
|
||||
} else {
|
||||
lNetwork = "tcp4"
|
||||
}
|
||||
case "udp", "udp4", "udp6":
|
||||
lAddr = &net.UDPAddr{
|
||||
IP: ip,
|
||||
}
|
||||
if v6 {
|
||||
lNetwork = "udp6"
|
||||
} else {
|
||||
lNetwork = "udp4"
|
||||
}
|
||||
case "ip", "ip4", "ip6":
|
||||
lAddr = &net.IPAddr{
|
||||
IP: ip,
|
||||
}
|
||||
if v6 {
|
||||
lNetwork = "ip6"
|
||||
} else {
|
||||
lNetwork = "ip4"
|
||||
}
|
||||
default:
|
||||
return nil, "", ErrUnknownNetwork
|
||||
}
|
||||
|
||||
return lAddr, lNetwork, nil
|
||||
}
|
Loading…
Reference in New Issue