allow use $lAddr in ip hint template
This commit is contained in:
parent
88898cb94f
commit
fd303bda12
38
handler.go
38
handler.go
|
@ -31,7 +31,7 @@ func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer,
|
||||||
userIPHints bool, logger *CondLogger) *ProxyHandler {
|
userIPHints bool, logger *CondLogger) *ProxyHandler {
|
||||||
httptransport := &http.Transport{
|
httptransport := &http.Transport{
|
||||||
DialContext: dialer.DialContext,
|
DialContext: dialer.DialContext,
|
||||||
DisableKeepAlives: userIPHints,
|
DisableKeepAlives: true,
|
||||||
}
|
}
|
||||||
return &ProxyHandler{
|
return &ProxyHandler{
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
|
@ -134,25 +134,26 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
username, ok := s.auth.Validate(wr, req)
|
username, ok := s.auth.Validate(wr, req)
|
||||||
s.logger.Info("Request: %v %q %v %v %v", req.RemoteAddr, username, req.Proto, req.Method, req.URL)
|
localAddr := getLocalAddr(req.Context())
|
||||||
|
s.logger.Info("Request: %v => %v %q %v %v %v", req.RemoteAddr, localAddr, username, req.Proto, req.Method, req.URL)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ipHints *string
|
||||||
if s.userIPHints {
|
if s.userIPHints {
|
||||||
hintValues := req.Header.Values(HintsHeaderName)
|
hintValues := req.Header.Values(HintsHeaderName)
|
||||||
if len(hintValues) > 0 {
|
if len(hintValues) > 0 {
|
||||||
req.Header.Del(HintsHeaderName)
|
req.Header.Del(HintsHeaderName)
|
||||||
if hintIPs, err := parseIPList(hintValues[0]); err != nil {
|
ipHints = &hintValues[0]
|
||||||
s.logger.Info("Request: %v %q %v %v %v -- bad IP hint header: %q", req.RemoteAddr, username, req.Proto, req.Method, req.URL, hintValues[0])
|
}
|
||||||
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
|
}
|
||||||
return
|
newCtx := context.WithValue(req.Context(), BoundDialerContextKey{}, BoundDialerContextValue{
|
||||||
} else {
|
Hints: ipHints,
|
||||||
newCtx := context.WithValue(req.Context(), BoundDialerContextKey{}, hintIPs)
|
LocalAddr: trimAddrPort(localAddr),
|
||||||
|
})
|
||||||
req = req.WithContext(newCtx)
|
req = req.WithContext(newCtx)
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delHopHeaders(req.Header)
|
delHopHeaders(req.Header)
|
||||||
if isConnect {
|
if isConnect {
|
||||||
s.HandleTunnel(wr, req)
|
s.HandleTunnel(wr, req)
|
||||||
|
@ -160,3 +161,18 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
||||||
s.HandleRequest(wr, req)
|
s.HandleRequest(wr, req)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func trimAddrPort(addrPort string) string {
|
||||||
|
res, _, err := net.SplitHostPort(addrPort)
|
||||||
|
if err != nil {
|
||||||
|
return addrPort
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLocalAddr(ctx context.Context) string {
|
||||||
|
if addr, ok := ctx.Value(http.LocalAddrContextKey).(net.Addr); ok {
|
||||||
|
return addr.String()
|
||||||
|
}
|
||||||
|
return "<request context is missing address>"
|
||||||
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
)
|
)
|
||||||
|
@ -17,16 +18,21 @@ var (
|
||||||
|
|
||||||
type BoundDialerContextKey struct{}
|
type BoundDialerContextKey struct{}
|
||||||
|
|
||||||
|
type BoundDialerContextValue struct {
|
||||||
|
Hints *string
|
||||||
|
LocalAddr string
|
||||||
|
}
|
||||||
|
|
||||||
type BoundDialerDefaultSink interface {
|
type BoundDialerDefaultSink interface {
|
||||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type BoundDialer struct {
|
type BoundDialer struct {
|
||||||
defaultDialer BoundDialerDefaultSink
|
defaultDialer BoundDialerDefaultSink
|
||||||
defaultHints []net.IP
|
defaultHints string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBoundDialer(defaultDialer BoundDialerDefaultSink, defaultHints []net.IP) *BoundDialer {
|
func NewBoundDialer(defaultDialer BoundDialerDefaultSink, defaultHints string) *BoundDialer {
|
||||||
if defaultDialer == nil {
|
if defaultDialer == nil {
|
||||||
defaultDialer = &net.Dialer{}
|
defaultDialer = &net.Dialer{}
|
||||||
}
|
}
|
||||||
|
@ -38,13 +44,22 @@ func NewBoundDialer(defaultDialer BoundDialerDefaultSink, defaultHints []net.IP)
|
||||||
|
|
||||||
func (d *BoundDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
func (d *BoundDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
hints := d.defaultHints
|
hints := d.defaultHints
|
||||||
|
lAddr := ""
|
||||||
if hintsOverride := ctx.Value(BoundDialerContextKey{}); hintsOverride != nil {
|
if hintsOverride := ctx.Value(BoundDialerContextKey{}); hintsOverride != nil {
|
||||||
if hintsOverrideValue, ok := hintsOverride.([]net.IP); ok {
|
if hintsOverrideValue, ok := hintsOverride.(BoundDialerContextValue); ok {
|
||||||
hints = hintsOverrideValue
|
if hintsOverrideValue.Hints != nil {
|
||||||
|
hints = *hintsOverrideValue.Hints
|
||||||
|
}
|
||||||
|
lAddr = hintsOverrideValue.LocalAddr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(hints) == 0 {
|
parsedHints, err := parseHints(hints, lAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dial failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parsedHints) == 0 {
|
||||||
return d.defaultDialer.DialContext(ctx, network, address)
|
return d.defaultDialer.DialContext(ctx, network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,7 +76,7 @@ func (d *BoundDialer) DialContext(ctx context.Context, network, address string)
|
||||||
}
|
}
|
||||||
|
|
||||||
var resErr error
|
var resErr error
|
||||||
for _, lIP := range hints {
|
for _, lIP := range parsedHints {
|
||||||
lAddr, restrictedNetwork, err := ipToLAddr(netBase, lIP)
|
lAddr, restrictedNetwork, err := ipToLAddr(netBase, lIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resErr = multierror.Append(resErr, fmt.Errorf("ipToLAddr(%q) failed: %w", lIP.String(), err))
|
resErr = multierror.Append(resErr, fmt.Errorf("ipToLAddr(%q) failed: %w", lIP.String(), err))
|
||||||
|
@ -136,3 +151,19 @@ func ipToLAddr(network string, ip net.IP) (net.Addr, string, error) {
|
||||||
|
|
||||||
return lAddr, lNetwork, nil
|
return lAddr, lNetwork, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseHints(hints, lAddr string) ([]net.IP, error) {
|
||||||
|
hints = os.Expand(hints, func(key string) string {
|
||||||
|
switch key {
|
||||||
|
case "lAddr":
|
||||||
|
return lAddr
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("<bad key:%q>", key)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
res, err := parseIPList(hints)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to parse source IP hints %q: %w", hints, err)
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
11
main.go
11
main.go
|
@ -72,7 +72,7 @@ type CLIArgs struct {
|
||||||
passwdCost int
|
passwdCost int
|
||||||
positionalArgs []string
|
positionalArgs []string
|
||||||
proxy []string
|
proxy []string
|
||||||
sourceIPHints []net.IP
|
sourceIPHints string
|
||||||
userIPHints bool
|
userIPHints bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,14 +103,7 @@ func parse_args() CLIArgs {
|
||||||
args.proxy = append(args.proxy, p)
|
args.proxy = append(args.proxy, p)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
flag.Func("ip-hints", "a comma-separated list of source addresses to use on dial attempts. Example: \"10.0.0.1,fe80::2,0.0.0.0,::\"", func(p string) error {
|
flag.StringVar(&args.sourceIPHints, "ip-hints", "", "a comma-separated list of source addresses to use on dial attempts. Example: \"10.0.0.1,fe80::2,0.0.0.0,::\"")
|
||||||
list, err := parseIPList(p)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
args.sourceIPHints = list
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
flag.BoolVar(&args.userIPHints, "user-ip-hints", false, "allow IP hints to be specified by user in X-Src-IP-Hints header")
|
flag.BoolVar(&args.userIPHints, "user-ip-hints", false, "allow IP hints to be specified by user in X-Src-IP-Hints header")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
args.positionalArgs = flag.Args()
|
args.positionalArgs = flag.Args()
|
||||||
|
|
Loading…
Reference in New Issue