dynamic integration for bounded dialer by user request
This commit is contained in:
parent
f1dc40b0ce
commit
61ba22edb5
22
handler.go
22
handler.go
|
@ -10,6 +10,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const HintsHeaderName = "X-Src-IP-Hints"
|
||||||
|
|
||||||
type HandlerDialer interface {
|
type HandlerDialer interface {
|
||||||
DialContext(ctx context.Context, net, address string) (net.Conn, error)
|
DialContext(ctx context.Context, net, address string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
@ -22,11 +24,14 @@ type ProxyHandler struct {
|
||||||
httptransport http.RoundTripper
|
httptransport http.RoundTripper
|
||||||
outbound map[string]string
|
outbound map[string]string
|
||||||
outboundMux sync.RWMutex
|
outboundMux sync.RWMutex
|
||||||
|
userIPHints bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer, logger *CondLogger) *ProxyHandler {
|
func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer,
|
||||||
|
userIPHints bool, logger *CondLogger) *ProxyHandler {
|
||||||
httptransport := &http.Transport{
|
httptransport := &http.Transport{
|
||||||
DialContext: dialer.DialContext,
|
DialContext: dialer.DialContext,
|
||||||
|
DisableKeepAlives: userIPHints,
|
||||||
}
|
}
|
||||||
return &ProxyHandler{
|
return &ProxyHandler{
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
|
@ -35,6 +40,7 @@ func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer, log
|
||||||
dialer: dialer,
|
dialer: dialer,
|
||||||
httptransport: httptransport,
|
httptransport: httptransport,
|
||||||
outbound: make(map[string]string),
|
outbound: make(map[string]string),
|
||||||
|
userIPHints: userIPHints,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -133,6 +139,20 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.userIPHints {
|
||||||
|
hintValues := req.Header.Values(HintsHeaderName)
|
||||||
|
if len(hintValues) > 0 {
|
||||||
|
req.Header.Del(HintsHeaderName)
|
||||||
|
if hintIPs, err := parseIPList(hintValues[0]); err != nil {
|
||||||
|
s.logger.Info("Request: %v %q %v %v %v -- bad IP hint header: %q", req.RemoteAddr, username, req.Proto, req.Method, req.URL, hintValues[0])
|
||||||
|
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
newCtx := context.WithValue(req.Context(), BoundDialerContextKey{}, hintIPs)
|
||||||
|
req = req.WithContext(newCtx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
delHopHeaders(req.Header)
|
delHopHeaders(req.Header)
|
||||||
if isConnect {
|
if isConnect {
|
||||||
s.HandleTunnel(wr, req)
|
s.HandleTunnel(wr, req)
|
||||||
|
|
4
main.go
4
main.go
|
@ -73,6 +73,7 @@ type CLIArgs struct {
|
||||||
positionalArgs []string
|
positionalArgs []string
|
||||||
proxy []string
|
proxy []string
|
||||||
sourceIPHints []net.IP
|
sourceIPHints []net.IP
|
||||||
|
userIPHints bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func parse_args() CLIArgs {
|
func parse_args() CLIArgs {
|
||||||
|
@ -110,6 +111,7 @@ func parse_args() CLIArgs {
|
||||||
args.sourceIPHints = list
|
args.sourceIPHints = list
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
flag.BoolVar(&args.userIPHints, "user-ip-hints", false, "allow IP hints to be specified by user in X-Src-IP-Hints header")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
args.positionalArgs = flag.Args()
|
args.positionalArgs = flag.Args()
|
||||||
return args
|
return args
|
||||||
|
@ -167,7 +169,7 @@ func run() int {
|
||||||
|
|
||||||
server := http.Server{
|
server := http.Server{
|
||||||
Addr: args.bind_address,
|
Addr: args.bind_address,
|
||||||
Handler: NewProxyHandler(args.timeout, auth, maybeWrapWithContextDialer(dialer), proxyLogger),
|
Handler: NewProxyHandler(args.timeout, auth, maybeWrapWithContextDialer(dialer), args.userIPHints, proxyLogger),
|
||||||
ErrorLog: log.New(logWriter, "HTTPSRV : ", log.LstdFlags|log.Lshortfile),
|
ErrorLog: log.New(logWriter, "HTTPSRV : ", log.LstdFlags|log.Lshortfile),
|
||||||
ReadTimeout: 0,
|
ReadTimeout: 0,
|
||||||
ReadHeaderTimeout: 0,
|
ReadHeaderTimeout: 0,
|
||||||
|
|
Loading…
Reference in New Issue