diff --git a/handler.go b/handler.go index 62ab0c5..936026b 100644 --- a/handler.go +++ b/handler.go @@ -10,21 +10,29 @@ import ( "time" ) +type HandlerDialer interface { + DialContext(ctx context.Context, net, address string) (net.Conn, error) +} + type ProxyHandler struct { timeout time.Duration auth Auth logger *CondLogger + dialer HandlerDialer httptransport http.RoundTripper outbound map[string]string outboundMux sync.RWMutex } -func NewProxyHandler(timeout time.Duration, auth Auth, logger *CondLogger) *ProxyHandler { - httptransport := &http.Transport{} +func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer, logger *CondLogger) *ProxyHandler { + httptransport := &http.Transport{ + DialContext: dialer.DialContext, + } return &ProxyHandler{ timeout: timeout, auth: auth, logger: logger, + dialer: dialer, httptransport: httptransport, outbound: make(map[string]string), } @@ -32,8 +40,7 @@ func NewProxyHandler(timeout time.Duration, auth Auth, logger *CondLogger) *Prox func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) { ctx, _ := context.WithTimeout(req.Context(), s.timeout) - dialer := net.Dialer{} - conn, err := dialer.DialContext(ctx, "tcp", req.RequestURI) + conn, err := s.dialer.DialContext(ctx, "tcp", req.RequestURI) if err != nil { s.logger.Error("Can't satisfy CONNECT request: %v", err) http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway) diff --git a/main.go b/main.go index acd0484..407500c 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "flag" "fmt" "log" + "net" "net/http" "os" "path/filepath" @@ -141,7 +142,7 @@ func run() int { server := http.Server{ Addr: args.bind_address, - Handler: NewProxyHandler(args.timeout, auth, proxyLogger), + Handler: NewProxyHandler(args.timeout, auth, new(net.Dialer), proxyLogger), ErrorLog: log.New(logWriter, "HTTPSRV : ", log.LstdFlags|log.Lshortfile), ReadTimeout: 0, ReadHeaderTimeout: 0,