2023-06-26 20:12:28 +02:00
|
|
|
package main
|
|
|
|
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"net"
	"os"
	"strings"

	"github.com/hashicorp/go-multierror"
)
|
|
|
|
|
|
|
|
// Sentinel errors reported by BoundDialer and its address helpers.
var (
	// ErrNoSuitableAddress is returned by BoundDialer.DialContext when source
	// IP hints are configured but every one of them was skipped as unsuitable
	// for the requested network, so no dial was attempted.
	ErrNoSuitableAddress = errors.New("no suitable address")

	// ErrBadIPAddressLength reports a hint IP that has neither a 4-byte IPv4
	// form nor the 16-byte IPv6 length.
	ErrBadIPAddressLength = errors.New("bad IP address length")

	// ErrUnknownNetwork reports a network name other than "tcp", "udp" or
	// "ip" (each optionally suffixed with 4 or 6).
	ErrUnknownNetwork = errors.New("unknown network")
)
|
|
|
|
|
|
|
|
// BoundDialerContextKey is the context key BoundDialer.DialContext looks up
// to find a per-call BoundDialerContextValue override. Being an unexported-
// field-free empty struct, every BoundDialerContextKey{} compares equal, so
// callers can use the literal directly with context.WithValue.
type BoundDialerContextKey struct{}
|
|
|
|
|
2023-09-06 16:34:07 +02:00
|
|
|
// BoundDialerContextValue carries per-call overrides for
// BoundDialer.DialContext. Store it (as a value) in the dial context under
// BoundDialerContextKey{}.
type BoundDialerContextValue struct {
	// Hints, when non-nil, replaces the dialer's default source IP hints for
	// this call. A nil Hints keeps the defaults.
	Hints *string

	// LocalAddr is the text that a ${lAddr} placeholder inside the hints
	// expands to.
	LocalAddr string
}
|
|
|
|
|
2023-06-26 20:12:28 +02:00
|
|
|
// BoundDialerDefaultSink is the dialer a BoundDialer delegates to when a dial
// is not bound to a source IP hint. *net.Dialer satisfies it, and a zero
// net.Dialer is used when NewBoundDialer is given nil.
type BoundDialerDefaultSink interface {
	// DialContext is expected to behave like net.Dialer.DialContext.
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
|
|
|
|
|
|
|
|
// BoundDialer dials outgoing connections with the local end bound to one of a
// list of source IP hints, delegating to a default dialer when there are no
// hints or the network cannot be bound. Construct it with NewBoundDialer: the
// zero value has no default dialer.
type BoundDialer struct {
	// defaultDialer handles every dial that is not bound to a hint.
	defaultDialer BoundDialerDefaultSink

	// defaultHints is the hints string used when the dial context carries no
	// override.
	defaultHints string
}
|
|
|
|
|
2023-09-06 16:34:07 +02:00
|
|
|
func NewBoundDialer(defaultDialer BoundDialerDefaultSink, defaultHints string) *BoundDialer {
|
2023-06-26 20:12:28 +02:00
|
|
|
if defaultDialer == nil {
|
|
|
|
defaultDialer = &net.Dialer{}
|
|
|
|
}
|
|
|
|
return &BoundDialer{
|
|
|
|
defaultDialer: defaultDialer,
|
|
|
|
defaultHints: defaultHints,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (d *BoundDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
|
|
|
hints := d.defaultHints
|
2023-09-06 16:34:07 +02:00
|
|
|
lAddr := ""
|
2023-06-26 20:12:28 +02:00
|
|
|
if hintsOverride := ctx.Value(BoundDialerContextKey{}); hintsOverride != nil {
|
2023-09-06 16:34:07 +02:00
|
|
|
if hintsOverrideValue, ok := hintsOverride.(BoundDialerContextValue); ok {
|
|
|
|
if hintsOverrideValue.Hints != nil {
|
|
|
|
hints = *hintsOverrideValue.Hints
|
|
|
|
}
|
|
|
|
lAddr = hintsOverrideValue.LocalAddr
|
2023-06-26 20:12:28 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-09-06 16:34:07 +02:00
|
|
|
parsedHints, err := parseHints(hints, lAddr)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("dial failed: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(parsedHints) == 0 {
|
2023-06-26 20:12:28 +02:00
|
|
|
return d.defaultDialer.DialContext(ctx, network, address)
|
|
|
|
}
|
|
|
|
|
|
|
|
var netBase string
|
|
|
|
switch network {
|
|
|
|
case "tcp", "tcp4", "tcp6":
|
|
|
|
netBase = "tcp"
|
|
|
|
case "udp", "udp4", "udp6":
|
|
|
|
netBase = "udp"
|
|
|
|
case "ip", "ip4", "ip6":
|
|
|
|
netBase = "ip"
|
|
|
|
default:
|
|
|
|
return d.defaultDialer.DialContext(ctx, network, address)
|
|
|
|
}
|
|
|
|
|
|
|
|
var resErr error
|
2023-09-06 16:34:07 +02:00
|
|
|
for _, lIP := range parsedHints {
|
2023-06-26 20:12:28 +02:00
|
|
|
lAddr, restrictedNetwork, err := ipToLAddr(netBase, lIP)
|
|
|
|
if err != nil {
|
|
|
|
resErr = multierror.Append(resErr, fmt.Errorf("ipToLAddr(%q) failed: %w", lIP.String(), err))
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
if network != netBase && network != restrictedNetwork {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
conn, err := (&net.Dialer{
|
|
|
|
LocalAddr: lAddr,
|
|
|
|
}).DialContext(ctx, restrictedNetwork, address)
|
|
|
|
if err != nil {
|
|
|
|
resErr = multierror.Append(resErr, fmt.Errorf("dial failed: %w", err))
|
|
|
|
} else {
|
|
|
|
return conn, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if resErr == nil {
|
|
|
|
resErr = ErrNoSuitableAddress
|
|
|
|
}
|
|
|
|
return nil, resErr
|
|
|
|
}
|
|
|
|
|
|
|
|
func (d *BoundDialer) Dial(network, address string) (net.Conn, error) {
|
|
|
|
return d.DialContext(context.Background(), network, address)
|
|
|
|
}
|
|
|
|
|
|
|
|
func ipToLAddr(network string, ip net.IP) (net.Addr, string, error) {
|
|
|
|
v6 := true
|
|
|
|
if ip4 := ip.To4(); len(ip4) == net.IPv4len {
|
|
|
|
ip = ip4
|
|
|
|
v6 = false
|
|
|
|
} else if len(ip) != net.IPv6len {
|
|
|
|
return nil, "", ErrBadIPAddressLength
|
|
|
|
}
|
|
|
|
|
|
|
|
var lAddr net.Addr
|
|
|
|
var lNetwork string
|
|
|
|
switch network {
|
|
|
|
case "tcp", "tcp4", "tcp6":
|
|
|
|
lAddr = &net.TCPAddr{
|
|
|
|
IP: ip,
|
|
|
|
}
|
|
|
|
if v6 {
|
|
|
|
lNetwork = "tcp6"
|
|
|
|
} else {
|
|
|
|
lNetwork = "tcp4"
|
|
|
|
}
|
|
|
|
case "udp", "udp4", "udp6":
|
|
|
|
lAddr = &net.UDPAddr{
|
|
|
|
IP: ip,
|
|
|
|
}
|
|
|
|
if v6 {
|
|
|
|
lNetwork = "udp6"
|
|
|
|
} else {
|
|
|
|
lNetwork = "udp4"
|
|
|
|
}
|
|
|
|
case "ip", "ip4", "ip6":
|
|
|
|
lAddr = &net.IPAddr{
|
|
|
|
IP: ip,
|
|
|
|
}
|
|
|
|
if v6 {
|
|
|
|
lNetwork = "ip6"
|
|
|
|
} else {
|
|
|
|
lNetwork = "ip4"
|
|
|
|
}
|
|
|
|
default:
|
|
|
|
return nil, "", ErrUnknownNetwork
|
|
|
|
}
|
|
|
|
|
|
|
|
return lAddr, lNetwork, nil
|
|
|
|
}
|
2023-09-06 16:34:07 +02:00
|
|
|
|
|
|
|
func parseHints(hints, lAddr string) ([]net.IP, error) {
|
|
|
|
hints = os.Expand(hints, func(key string) string {
|
|
|
|
switch key {
|
|
|
|
case "lAddr":
|
|
|
|
return lAddr
|
|
|
|
default:
|
|
|
|
return fmt.Sprintf("<bad key:%q>", key)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
res, err := parseIPList(hints)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("unable to parse source IP hints %q: %w", hints, err)
|
|
|
|
}
|
|
|
|
return res, nil
|
|
|
|
}
|