dynamic integration for bounded dialer by user request

This commit is contained in:
Vladislav Yarmak 2023-06-26 23:53:08 +03:00
parent f1dc40b0ce
commit 61ba22edb5
2 changed files with 25 additions and 3 deletions

View File

@ -10,6 +10,8 @@ import (
"time" "time"
) )
const HintsHeaderName = "X-Src-IP-Hints"
type HandlerDialer interface { type HandlerDialer interface {
DialContext(ctx context.Context, net, address string) (net.Conn, error) DialContext(ctx context.Context, net, address string) (net.Conn, error)
} }
@ -22,11 +24,14 @@ type ProxyHandler struct {
httptransport http.RoundTripper httptransport http.RoundTripper
outbound map[string]string outbound map[string]string
outboundMux sync.RWMutex outboundMux sync.RWMutex
userIPHints bool
} }
func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer, logger *CondLogger) *ProxyHandler { func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer,
userIPHints bool, logger *CondLogger) *ProxyHandler {
httptransport := &http.Transport{ httptransport := &http.Transport{
DialContext: dialer.DialContext, DialContext: dialer.DialContext,
DisableKeepAlives: userIPHints,
} }
return &ProxyHandler{ return &ProxyHandler{
timeout: timeout, timeout: timeout,
@ -35,6 +40,7 @@ func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer, log
dialer: dialer, dialer: dialer,
httptransport: httptransport, httptransport: httptransport,
outbound: make(map[string]string), outbound: make(map[string]string),
userIPHints: userIPHints,
} }
} }
@ -133,6 +139,20 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
if !ok { if !ok {
return return
} }
if s.userIPHints {
hintValues := req.Header.Values(HintsHeaderName)
if len(hintValues) > 0 {
req.Header.Del(HintsHeaderName)
if hintIPs, err := parseIPList(hintValues[0]); err != nil {
s.logger.Info("Request: %v %q %v %v %v -- bad IP hint header: %q", req.RemoteAddr, username, req.Proto, req.Method, req.URL, hintValues[0])
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
return
} else {
newCtx := context.WithValue(req.Context(), BoundDialerContextKey{}, hintIPs)
req = req.WithContext(newCtx)
}
}
}
delHopHeaders(req.Header) delHopHeaders(req.Header)
if isConnect { if isConnect {
s.HandleTunnel(wr, req) s.HandleTunnel(wr, req)

View File

@ -73,6 +73,7 @@ type CLIArgs struct {
positionalArgs []string positionalArgs []string
proxy []string proxy []string
sourceIPHints []net.IP sourceIPHints []net.IP
userIPHints bool
} }
func parse_args() CLIArgs { func parse_args() CLIArgs {
@ -110,6 +111,7 @@ func parse_args() CLIArgs {
args.sourceIPHints = list args.sourceIPHints = list
return nil return nil
}) })
flag.BoolVar(&args.userIPHints, "user-ip-hints", false, "allow IP hints to be specified by user in X-Src-IP-Hints header")
flag.Parse() flag.Parse()
args.positionalArgs = flag.Args() args.positionalArgs = flag.Args()
return args return args
@ -167,7 +169,7 @@ func run() int {
server := http.Server{ server := http.Server{
Addr: args.bind_address, Addr: args.bind_address,
Handler: NewProxyHandler(args.timeout, auth, maybeWrapWithContextDialer(dialer), proxyLogger), Handler: NewProxyHandler(args.timeout, auth, maybeWrapWithContextDialer(dialer), args.userIPHints, proxyLogger),
ErrorLog: log.New(logWriter, "HTTPSRV : ", log.LstdFlags|log.Lshortfile), ErrorLog: log.New(logWriter, "HTTPSRV : ", log.LstdFlags|log.Lshortfile),
ReadTimeout: 0, ReadTimeout: 0,
ReadHeaderTimeout: 0, ReadHeaderTimeout: 0,