add list-ciphers option

This commit is contained in:
Vladislav Yarmak 2021-02-26 09:09:55 +02:00
parent 6dd28baf6b
commit 0725d27faf
6 changed files with 510 additions and 498 deletions

309
auth.go
View File

@ -1,16 +1,16 @@
package main package main
import ( import (
"os" "bufio"
"net/http" "crypto/subtle"
"net/url" "encoding/base64"
"errors" "errors"
"strings" "golang.org/x/crypto/bcrypt"
"strconv" "net/http"
"encoding/base64" "net/url"
"crypto/subtle" "os"
"golang.org/x/crypto/bcrypt" "strconv"
"bufio" "strings"
) )
const AUTH_REQUIRED_MSG = "Proxy authentication required.\n" const AUTH_REQUIRED_MSG = "Proxy authentication required.\n"
@ -19,183 +19,182 @@ const AUTH_TRIGGERED_MSG = "Browser auth triggered!\n"
const EPOCH_EXPIRE = "Thu, 01 Jan 1970 00:00:01 GMT" const EPOCH_EXPIRE = "Thu, 01 Jan 1970 00:00:01 GMT"
type Auth interface { type Auth interface {
Validate(wr http.ResponseWriter, req *http.Request) bool Validate(wr http.ResponseWriter, req *http.Request) bool
} }
func NewAuth(paramstr string) (Auth, error) { func NewAuth(paramstr string) (Auth, error) {
url, err := url.Parse(paramstr) url, err := url.Parse(paramstr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
switch strings.ToLower(url.Scheme) { switch strings.ToLower(url.Scheme) {
case "static": case "static":
return NewStaticAuth(url) return NewStaticAuth(url)
case "basicfile": case "basicfile":
return NewBasicFileAuth(url) return NewBasicFileAuth(url)
case "cert": case "cert":
return CertAuth{}, nil return CertAuth{}, nil
case "none": case "none":
return NoAuth{}, nil return NoAuth{}, nil
default: default:
return nil, errors.New("Unknown auth scheme") return nil, errors.New("Unknown auth scheme")
} }
} }
func NewStaticAuth(param_url *url.URL) (*BasicAuth, error) { func NewStaticAuth(param_url *url.URL) (*BasicAuth, error) {
values, err := url.ParseQuery(param_url.RawQuery) values, err := url.ParseQuery(param_url.RawQuery)
if err != nil { if err != nil {
return nil, err return nil, err
} }
username := values.Get("username") username := values.Get("username")
if username == "" { if username == "" {
return nil, errors.New("\"username\" parameter is missing from auth config URI") return nil, errors.New("\"username\" parameter is missing from auth config URI")
} }
password := values.Get("password") password := values.Get("password")
if password == "" { if password == "" {
return nil, errors.New("\"password\" parameter is missing from auth config URI") return nil, errors.New("\"password\" parameter is missing from auth config URI")
} }
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost) hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &BasicAuth{ return &BasicAuth{
users: map[string][]byte{ users: map[string][]byte{
username: hashedPassword, username: hashedPassword,
}, },
hiddenDomain: strings.ToLower(values.Get("hidden_domain")), hiddenDomain: strings.ToLower(values.Get("hidden_domain")),
}, nil }, nil
} }
func requireBasicAuth(wr http.ResponseWriter, req *http.Request, hidden_domain string) { func requireBasicAuth(wr http.ResponseWriter, req *http.Request, hidden_domain string) {
if hidden_domain != "" && if hidden_domain != "" &&
(subtle.ConstantTimeCompare([]byte(req.URL.Host), []byte(hidden_domain)) != 1 && (subtle.ConstantTimeCompare([]byte(req.URL.Host), []byte(hidden_domain)) != 1 &&
subtle.ConstantTimeCompare([]byte(req.Host), []byte(hidden_domain)) != 1) { subtle.ConstantTimeCompare([]byte(req.Host), []byte(hidden_domain)) != 1) {
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
} else { } else {
wr.Header().Set("Proxy-Authenticate", `Basic realm="dumbproxy"`) wr.Header().Set("Proxy-Authenticate", `Basic realm="dumbproxy"`)
wr.Header().Set("Content-Length", strconv.Itoa(len([]byte(AUTH_REQUIRED_MSG)))) wr.Header().Set("Content-Length", strconv.Itoa(len([]byte(AUTH_REQUIRED_MSG))))
wr.WriteHeader(407) wr.WriteHeader(407)
wr.Write([]byte(AUTH_REQUIRED_MSG)) wr.Write([]byte(AUTH_REQUIRED_MSG))
} }
} }
type BasicAuth struct { type BasicAuth struct {
users map[string][]byte users map[string][]byte
hiddenDomain string hiddenDomain string
} }
func NewBasicFileAuth(param_url *url.URL) (*BasicAuth, error) { func NewBasicFileAuth(param_url *url.URL) (*BasicAuth, error) {
values, err := url.ParseQuery(param_url.RawQuery) values, err := url.ParseQuery(param_url.RawQuery)
if err != nil { if err != nil {
return nil, err return nil, err
} }
filename := values.Get("path") filename := values.Get("path")
if filename == "" { if filename == "" {
return nil, errors.New("\"path\" parameter is missing from auth config URI") return nil, errors.New("\"path\" parameter is missing from auth config URI")
} }
f, err := os.Open(filename) f, err := os.Open(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer f.Close() defer f.Close()
scanner := bufio.NewScanner(f) scanner := bufio.NewScanner(f)
users := make(map[string][]byte) users := make(map[string][]byte)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
trimmed := strings.TrimSpace(line) trimmed := strings.TrimSpace(line)
if trimmed == "" || strings.HasPrefix(trimmed, "#") { if trimmed == "" || strings.HasPrefix(trimmed, "#") {
continue continue
} }
pair := strings.SplitN(line, ":", 2) pair := strings.SplitN(line, ":", 2)
if len(pair) != 2 { if len(pair) != 2 {
return nil, errors.New("Malformed login and password line") return nil, errors.New("Malformed login and password line")
} }
login := pair[0] login := pair[0]
password := pair[1] password := pair[1]
users[login] = []byte(password) users[login] = []byte(password)
} }
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
return nil, err return nil, err
} }
if len(users) == 0 { if len(users) == 0 {
return nil, errors.New("No password lines were read from file") return nil, errors.New("No password lines were read from file")
} }
return &BasicAuth{ return &BasicAuth{
users: users, users: users,
hiddenDomain: strings.ToLower(values.Get("hidden_domain")), hiddenDomain: strings.ToLower(values.Get("hidden_domain")),
}, nil }, nil
} }
func (auth *BasicAuth) Validate(wr http.ResponseWriter, req *http.Request) bool { func (auth *BasicAuth) Validate(wr http.ResponseWriter, req *http.Request) bool {
hdr := req.Header.Get("Proxy-Authorization") hdr := req.Header.Get("Proxy-Authorization")
if hdr == "" { if hdr == "" {
requireBasicAuth(wr, req, auth.hiddenDomain) requireBasicAuth(wr, req, auth.hiddenDomain)
return false return false
} }
hdr_parts := strings.SplitN(hdr, " ", 2) hdr_parts := strings.SplitN(hdr, " ", 2)
if len(hdr_parts) != 2 || strings.ToLower(hdr_parts[0]) != "basic" { if len(hdr_parts) != 2 || strings.ToLower(hdr_parts[0]) != "basic" {
requireBasicAuth(wr, req, auth.hiddenDomain) requireBasicAuth(wr, req, auth.hiddenDomain)
return false return false
} }
token := hdr_parts[1] token := hdr_parts[1]
data, err := base64.StdEncoding.DecodeString(token) data, err := base64.StdEncoding.DecodeString(token)
if err != nil { if err != nil {
requireBasicAuth(wr, req, auth.hiddenDomain) requireBasicAuth(wr, req, auth.hiddenDomain)
return false return false
} }
pair := strings.SplitN(string(data), ":", 2) pair := strings.SplitN(string(data), ":", 2)
if len(pair) != 2 { if len(pair) != 2 {
requireBasicAuth(wr, req, auth.hiddenDomain) requireBasicAuth(wr, req, auth.hiddenDomain)
return false return false
} }
login := pair[0] login := pair[0]
password := pair[1] password := pair[1]
hashedPassword, ok := auth.users[login] hashedPassword, ok := auth.users[login]
if !ok { if !ok {
requireBasicAuth(wr, req, auth.hiddenDomain) requireBasicAuth(wr, req, auth.hiddenDomain)
return false return false
} }
if bcrypt.CompareHashAndPassword(hashedPassword, []byte(password)) == nil { if bcrypt.CompareHashAndPassword(hashedPassword, []byte(password)) == nil {
if auth.hiddenDomain != "" && if auth.hiddenDomain != "" &&
(req.Host == auth.hiddenDomain || req.URL.Host == auth.hiddenDomain) { (req.Host == auth.hiddenDomain || req.URL.Host == auth.hiddenDomain) {
wr.Header().Set("Content-Length", strconv.Itoa(len([]byte(AUTH_TRIGGERED_MSG)))) wr.Header().Set("Content-Length", strconv.Itoa(len([]byte(AUTH_TRIGGERED_MSG))))
wr.Header().Set("Pragma", "no-cache") wr.Header().Set("Pragma", "no-cache")
wr.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") wr.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
wr.Header().Set("Expires", EPOCH_EXPIRE) wr.Header().Set("Expires", EPOCH_EXPIRE)
wr.Header()["Date"] = nil wr.Header()["Date"] = nil
wr.WriteHeader(http.StatusOK) wr.WriteHeader(http.StatusOK)
wr.Write([]byte(AUTH_TRIGGERED_MSG)) wr.Write([]byte(AUTH_TRIGGERED_MSG))
return false return false
} else { } else {
return true return true
} }
} }
requireBasicAuth(wr, req, auth.hiddenDomain) requireBasicAuth(wr, req, auth.hiddenDomain)
return false return false
} }
type NoAuth struct {} type NoAuth struct{}
func (_ NoAuth) Validate(wr http.ResponseWriter, req *http.Request) bool { func (_ NoAuth) Validate(wr http.ResponseWriter, req *http.Request) bool {
return true return true
} }
type CertAuth struct{}
type CertAuth struct {}
func (_ CertAuth) Validate(wr http.ResponseWriter, req *http.Request) bool { func (_ CertAuth) Validate(wr http.ResponseWriter, req *http.Request) bool {
if req.TLS == nil || len(req.TLS.VerifiedChains) < 1 { if req.TLS == nil || len(req.TLS.VerifiedChains) < 1 {
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
return false return false
} else { } else {
return true return true
} }
} }

View File

@ -1,58 +1,58 @@
package main package main
import ( import (
"log" "fmt"
"fmt" "log"
) )
const ( const (
CRITICAL = 50 CRITICAL = 50
ERROR = 40 ERROR = 40
WARNING = 30 WARNING = 30
INFO = 20 INFO = 20
DEBUG = 10 DEBUG = 10
NOTSET = 0 NOTSET = 0
) )
type CondLogger struct { type CondLogger struct {
logger *log.Logger logger *log.Logger
verbosity int verbosity int
} }
func (cl *CondLogger) Log(verb int, format string, v ...interface{}) error { func (cl *CondLogger) Log(verb int, format string, v ...interface{}) error {
if verb >= cl.verbosity { if verb >= cl.verbosity {
return cl.logger.Output(2, fmt.Sprintf(format, v...)) return cl.logger.Output(2, fmt.Sprintf(format, v...))
} }
return nil return nil
} }
func (cl *CondLogger) log(verb int, format string, v ...interface{}) error { func (cl *CondLogger) log(verb int, format string, v ...interface{}) error {
if verb >= cl.verbosity { if verb >= cl.verbosity {
return cl.logger.Output(3, fmt.Sprintf(format, v...)) return cl.logger.Output(3, fmt.Sprintf(format, v...))
} }
return nil return nil
} }
func (cl *CondLogger) Critical(s string, v ...interface{}) error { func (cl *CondLogger) Critical(s string, v ...interface{}) error {
return cl.log(CRITICAL, "CRITICAL " + s, v...) return cl.log(CRITICAL, "CRITICAL "+s, v...)
} }
func (cl *CondLogger) Error(s string, v ...interface{}) error { func (cl *CondLogger) Error(s string, v ...interface{}) error {
return cl.log(ERROR, "ERROR " + s, v...) return cl.log(ERROR, "ERROR "+s, v...)
} }
func (cl *CondLogger) Warning(s string, v ...interface{}) error { func (cl *CondLogger) Warning(s string, v ...interface{}) error {
return cl.log(WARNING, "WARNING " + s, v...) return cl.log(WARNING, "WARNING "+s, v...)
} }
func (cl *CondLogger) Info(s string, v ...interface{}) error { func (cl *CondLogger) Info(s string, v ...interface{}) error {
return cl.log(INFO, "INFO " + s, v...) return cl.log(INFO, "INFO "+s, v...)
} }
func (cl *CondLogger) Debug(s string, v ...interface{}) error { func (cl *CondLogger) Debug(s string, v ...interface{}) error {
return cl.log(DEBUG, "DEBUG " + s, v...) return cl.log(DEBUG, "DEBUG "+s, v...)
} }
func NewCondLogger(logger *log.Logger, verbosity int) *CondLogger { func NewCondLogger(logger *log.Logger, verbosity int) *CondLogger {
return &CondLogger{verbosity: verbosity, logger: logger} return &CondLogger{verbosity: verbosity, logger: logger}
} }

View File

@ -1,133 +1,133 @@
package main package main
import ( import (
"net" "context"
"fmt" "fmt"
"time" "net"
"net/http" "net/http"
"strings" "strings"
"context" "sync"
"sync" "time"
) )
type ProxyHandler struct { type ProxyHandler struct {
timeout time.Duration timeout time.Duration
auth Auth auth Auth
logger *CondLogger logger *CondLogger
httptransport http.RoundTripper httptransport http.RoundTripper
outbound map[string]string outbound map[string]string
outboundMux sync.RWMutex outboundMux sync.RWMutex
} }
func NewProxyHandler(timeout time.Duration, auth Auth, logger *CondLogger) *ProxyHandler { func NewProxyHandler(timeout time.Duration, auth Auth, logger *CondLogger) *ProxyHandler {
httptransport := &http.Transport{} httptransport := &http.Transport{}
return &ProxyHandler{ return &ProxyHandler{
timeout: timeout, timeout: timeout,
auth: auth, auth: auth,
logger: logger, logger: logger,
httptransport: httptransport, httptransport: httptransport,
outbound: make(map[string]string), outbound: make(map[string]string),
} }
} }
func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) { func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) {
ctx, _ := context.WithTimeout(req.Context(), s.timeout) ctx, _ := context.WithTimeout(req.Context(), s.timeout)
dialer := net.Dialer{} dialer := net.Dialer{}
conn, err := dialer.DialContext(ctx, "tcp", req.RequestURI) conn, err := dialer.DialContext(ctx, "tcp", req.RequestURI)
if err != nil { if err != nil {
s.logger.Error("Can't satisfy CONNECT request: %v", err) s.logger.Error("Can't satisfy CONNECT request: %v", err)
http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway) http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway)
return return
} }
localAddr := conn.LocalAddr().String() localAddr := conn.LocalAddr().String()
s.outboundMux.Lock() s.outboundMux.Lock()
s.outbound[localAddr] = req.RemoteAddr s.outbound[localAddr] = req.RemoteAddr
s.outboundMux.Unlock() s.outboundMux.Unlock()
defer func() { defer func() {
conn.Close() conn.Close()
s.outboundMux.Lock() s.outboundMux.Lock()
delete(s.outbound, localAddr) delete(s.outbound, localAddr)
s.outboundMux.Unlock() s.outboundMux.Unlock()
}() }()
if req.ProtoMajor == 0 || req.ProtoMajor == 1 { if req.ProtoMajor == 0 || req.ProtoMajor == 1 {
// Upgrade client connection // Upgrade client connection
localconn, _, err := hijack(wr) localconn, _, err := hijack(wr)
if err != nil { if err != nil {
s.logger.Error("Can't hijack client connection: %v", err) s.logger.Error("Can't hijack client connection: %v", err)
http.Error(wr, "Can't hijack client connection", http.StatusInternalServerError) http.Error(wr, "Can't hijack client connection", http.StatusInternalServerError)
return return
} }
defer localconn.Close() defer localconn.Close()
// Inform client connection is built // Inform client connection is built
fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor) fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor)
proxy(req.Context(), localconn, conn) proxy(req.Context(), localconn, conn)
} else if req.ProtoMajor == 2 { } else if req.ProtoMajor == 2 {
wr.Header()["Date"] = nil wr.Header()["Date"] = nil
wr.WriteHeader(http.StatusOK) wr.WriteHeader(http.StatusOK)
flush(wr) flush(wr)
proxyh2(req.Context(), req.Body, wr, conn) proxyh2(req.Context(), req.Body, wr, conn)
} else { } else {
s.logger.Error("Unsupported protocol version: %s", req.Proto) s.logger.Error("Unsupported protocol version: %s", req.Proto)
http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest) http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest)
return return
} }
} }
func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request) { func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request) {
req.RequestURI = "" req.RequestURI = ""
if req.ProtoMajor == 2 { if req.ProtoMajor == 2 {
req.URL.Scheme = "http" // We can't access :scheme pseudo-header, so assume http req.URL.Scheme = "http" // We can't access :scheme pseudo-header, so assume http
req.URL.Host = req.Host req.URL.Host = req.Host
} }
resp, err := s.httptransport.RoundTrip(req) resp, err := s.httptransport.RoundTrip(req)
if err != nil { if err != nil {
s.logger.Error("HTTP fetch error: %v", err) s.logger.Error("HTTP fetch error: %v", err)
http.Error(wr, "Server Error", http.StatusInternalServerError) http.Error(wr, "Server Error", http.StatusInternalServerError)
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
s.logger.Info("%v %v %v %v", req.RemoteAddr, req.Method, req.URL, resp.Status) s.logger.Info("%v %v %v %v", req.RemoteAddr, req.Method, req.URL, resp.Status)
delHopHeaders(resp.Header) delHopHeaders(resp.Header)
copyHeader(wr.Header(), resp.Header) copyHeader(wr.Header(), resp.Header)
wr.WriteHeader(resp.StatusCode) wr.WriteHeader(resp.StatusCode)
flush(wr) flush(wr)
copyBody(wr, resp.Body) copyBody(wr, resp.Body)
} }
func (s *ProxyHandler) isLoopback(req *http.Request) (string, bool) { func (s *ProxyHandler) isLoopback(req *http.Request) (string, bool) {
s.outboundMux.RLock() s.outboundMux.RLock()
originator, found := s.outbound[req.RemoteAddr] originator, found := s.outbound[req.RemoteAddr]
s.outboundMux.RUnlock() s.outboundMux.RUnlock()
return originator, found return originator, found
} }
func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
s.logger.Info("Request: %v %v %v %v", req.RemoteAddr, req.Proto, req.Method, req.URL) s.logger.Info("Request: %v %v %v %v", req.RemoteAddr, req.Proto, req.Method, req.URL)
if originator, isLoopback := s.isLoopback(req) ; isLoopback { if originator, isLoopback := s.isLoopback(req); isLoopback {
s.logger.Critical("Loopback tunnel detected: %s is an outbound " + s.logger.Critical("Loopback tunnel detected: %s is an outbound "+
"address for another request from %s", req.RemoteAddr, originator) "address for another request from %s", req.RemoteAddr, originator)
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
return return
} }
isConnect := strings.ToUpper(req.Method) == "CONNECT" isConnect := strings.ToUpper(req.Method) == "CONNECT"
if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 || if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 ||
req.Host == "" && req.ProtoMajor == 2 { req.Host == "" && req.ProtoMajor == 2 {
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
return return
} }
if !s.auth.Validate(wr, req) { if !s.auth.Validate(wr, req) {
return return
} }
delHopHeaders(req.Header) delHopHeaders(req.Header)
if isConnect { if isConnect {
s.HandleTunnel(wr, req) s.HandleTunnel(wr, req)
} else { } else {
s.HandleRequest(wr, req) s.HandleRequest(wr, req)
} }
} }

View File

@ -1,57 +1,57 @@
package main package main
import ( import (
"io" "errors"
"errors" "io"
"time" "time"
) )
const MAX_LOG_QLEN = 128 const MAX_LOG_QLEN = 128
const QUEUE_SHUTDOWN_TIMEOUT = 500 * time.Millisecond const QUEUE_SHUTDOWN_TIMEOUT = 500 * time.Millisecond
type LogWriter struct { type LogWriter struct {
writer io.Writer writer io.Writer
ch chan []byte ch chan []byte
done chan struct{} done chan struct{}
} }
func (lw *LogWriter) Write(p []byte) (int, error) { func (lw *LogWriter) Write(p []byte) (int, error) {
if p == nil { if p == nil {
return 0, errors.New("Can't write nil byte slice") return 0, errors.New("Can't write nil byte slice")
} }
buf := make([]byte, len(p)) buf := make([]byte, len(p))
copy(buf, p) copy(buf, p)
select { select {
case lw.ch <- buf: case lw.ch <- buf:
return len(p), nil return len(p), nil
default: default:
return 0, errors.New("Writer queue overflow") return 0, errors.New("Writer queue overflow")
} }
} }
func NewLogWriter(writer io.Writer) *LogWriter { func NewLogWriter(writer io.Writer) *LogWriter {
lw := &LogWriter{writer, lw := &LogWriter{writer,
make(chan []byte, MAX_LOG_QLEN), make(chan []byte, MAX_LOG_QLEN),
make(chan struct{})} make(chan struct{})}
go lw.loop() go lw.loop()
return lw return lw
} }
func (lw *LogWriter) loop() { func (lw *LogWriter) loop() {
for p := range lw.ch { for p := range lw.ch {
if p == nil { if p == nil {
break break
} }
lw.writer.Write(p) lw.writer.Write(p)
} }
lw.done <- struct{}{} lw.done <- struct{}{}
} }
func (lw *LogWriter) Close() { func (lw *LogWriter) Close() {
lw.ch <- nil lw.ch <- nil
timer := time.After(QUEUE_SHUTDOWN_TIMEOUT) timer := time.After(QUEUE_SHUTDOWN_TIMEOUT)
select { select {
case <-timer: case <-timer:
case <-lw.done: case <-lw.done:
} }
} }

147
main.go
View File

@ -1,95 +1,108 @@
package main package main
import ( import (
"log" "crypto/tls"
"os" "flag"
"fmt" "fmt"
"flag" "log"
"time" "net/http"
"net/http" "os"
"time"
) )
func perror(msg string) { func perror(msg string) {
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, msg) fmt.Fprintln(os.Stderr, msg)
} }
func arg_fail(msg string) { func arg_fail(msg string) {
perror(msg) perror(msg)
perror("Usage:") perror("Usage:")
flag.PrintDefaults() flag.PrintDefaults()
os.Exit(2) os.Exit(2)
} }
type CLIArgs struct { type CLIArgs struct {
bind_address string bind_address string
auth string auth string
verbosity int verbosity int
timeout time.Duration timeout time.Duration
cert, key, cafile string cert, key, cafile string
list_suites bool
} }
func list_suites() {
for _, cipher := range tls.CipherSuites() {
fmt.Println(cipher.Name)
}
}
func parse_args() CLIArgs { func parse_args() CLIArgs {
var args CLIArgs var args CLIArgs
flag.StringVar(&args.bind_address, "bind-address", ":8080", "HTTP proxy listen address") flag.StringVar(&args.bind_address, "bind-address", ":8080", "HTTP proxy listen address")
flag.StringVar(&args.auth, "auth", "none://", "auth parameters") flag.StringVar(&args.auth, "auth", "none://", "auth parameters")
flag.IntVar(&args.verbosity, "verbosity", 20, "logging verbosity " + flag.IntVar(&args.verbosity, "verbosity", 20, "logging verbosity "+
"(10 - debug, 20 - info, 30 - warning, 40 - error, 50 - critical)") "(10 - debug, 20 - info, 30 - warning, 40 - error, 50 - critical)")
flag.DurationVar(&args.timeout, "timeout", 10 * time.Second, "timeout for network operations") flag.DurationVar(&args.timeout, "timeout", 10*time.Second, "timeout for network operations")
flag.StringVar(&args.cert, "cert", "", "enable TLS and use certificate") flag.StringVar(&args.cert, "cert", "", "enable TLS and use certificate")
flag.StringVar(&args.key, "key", "", "key for TLS certificate") flag.StringVar(&args.key, "key", "", "key for TLS certificate")
flag.StringVar(&args.cafile, "cafile", "", "CA file to authenticate clients with certificates") flag.StringVar(&args.cafile, "cafile", "", "CA file to authenticate clients with certificates")
flag.Parse() flag.BoolVar(&args.list_suites, "list-ciphers", false, "list ciphersuites")
return args flag.Parse()
return args
} }
func run() int { func run() int {
args := parse_args() args := parse_args()
logWriter := NewLogWriter(os.Stderr) if args.list_suites {
defer logWriter.Close() list_suites()
return 0
}
mainLogger := NewCondLogger(log.New(logWriter, "MAIN : ", logWriter := NewLogWriter(os.Stderr)
log.LstdFlags | log.Lshortfile), defer logWriter.Close()
args.verbosity)
proxyLogger := NewCondLogger(log.New(logWriter, "PROXY : ",
log.LstdFlags | log.Lshortfile),
args.verbosity)
auth, err := NewAuth(args.auth) mainLogger := NewCondLogger(log.New(logWriter, "MAIN : ",
if err != nil { log.LstdFlags|log.Lshortfile),
mainLogger.Critical("Failed to instantiate auth provider: %v", err) args.verbosity)
return 3 proxyLogger := NewCondLogger(log.New(logWriter, "PROXY : ",
} log.LstdFlags|log.Lshortfile),
args.verbosity)
server := http.Server{ auth, err := NewAuth(args.auth)
Addr: args.bind_address, if err != nil {
Handler: NewProxyHandler(args.timeout, auth, proxyLogger), mainLogger.Critical("Failed to instantiate auth provider: %v", err)
ErrorLog: log.New(logWriter, "HTTPSRV : ", log.LstdFlags | log.Lshortfile), return 3
ReadTimeout: 0, }
ReadHeaderTimeout: 0,
WriteTimeout: 0,
IdleTimeout: 0,
}
mainLogger.Info("Starting proxy server...") server := http.Server{
if args.cert != "" { Addr: args.bind_address,
cfg, err1 := makeServerTLSConfig(args.cert, args.key, args.cafile) Handler: NewProxyHandler(args.timeout, auth, proxyLogger),
if err1 != nil { ErrorLog: log.New(logWriter, "HTTPSRV : ", log.LstdFlags|log.Lshortfile),
mainLogger.Critical("TLS config construction failed: %v", err) ReadTimeout: 0,
return 3 ReadHeaderTimeout: 0,
} WriteTimeout: 0,
server.TLSConfig = cfg IdleTimeout: 0,
err = server.ListenAndServeTLS("", "") }
} else {
err = server.ListenAndServe() mainLogger.Info("Starting proxy server...")
} if args.cert != "" {
mainLogger.Critical("Server terminated with a reason: %v", err) cfg, err1 := makeServerTLSConfig(args.cert, args.key, args.cafile)
mainLogger.Info("Shutting down...") if err1 != nil {
return 0 mainLogger.Critical("TLS config construction failed: %v", err)
return 3
}
server.TLSConfig = cfg
err = server.ListenAndServeTLS("", "")
} else {
err = server.ListenAndServe()
}
mainLogger.Critical("Server terminated with a reason: %v", err)
mainLogger.Info("Shutting down...")
return 0
} }
func main() { func main() {
os.Exit(run()) os.Exit(run())
} }

228
utils.go
View File

@ -1,75 +1,75 @@
package main package main
import ( import (
"context" "bufio"
"net" "context"
"sync" "crypto/tls"
"io" "crypto/x509"
"time" "errors"
"errors" "io"
"net/http" "io/ioutil"
"bufio" "net"
"crypto/tls" "net/http"
"crypto/x509" "sync"
"io/ioutil" "time"
) )
const COPY_BUF = 128 * 1024 const COPY_BUF = 128 * 1024
func proxy(ctx context.Context, left, right net.Conn) { func proxy(ctx context.Context, left, right net.Conn) {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
cpy := func (dst, src net.Conn) { cpy := func(dst, src net.Conn) {
defer wg.Done() defer wg.Done()
io.Copy(dst, src) io.Copy(dst, src)
dst.Close() dst.Close()
} }
wg.Add(2) wg.Add(2)
go cpy(left, right) go cpy(left, right)
go cpy(right, left) go cpy(right, left)
groupdone := make(chan struct{}, 1) groupdone := make(chan struct{}, 1)
go func() { go func() {
wg.Wait() wg.Wait()
groupdone <-struct{}{} groupdone <- struct{}{}
}() }()
select { select {
case <-ctx.Done(): case <-ctx.Done():
left.Close() left.Close()
right.Close() right.Close()
case <-groupdone: case <-groupdone:
return return
} }
<-groupdone <-groupdone
return return
} }
func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) { func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
ltr := func (dst net.Conn, src io.Reader) { ltr := func(dst net.Conn, src io.Reader) {
defer wg.Done() defer wg.Done()
io.Copy(dst, src) io.Copy(dst, src)
dst.Close() dst.Close()
} }
rtl := func (dst io.Writer, src io.Reader) { rtl := func(dst io.Writer, src io.Reader) {
defer wg.Done() defer wg.Done()
copyBody(dst, src) copyBody(dst, src)
} }
wg.Add(2) wg.Add(2)
go ltr(right, leftreader) go ltr(right, leftreader)
go rtl(leftwriter, right) go rtl(leftwriter, right)
groupdone := make(chan struct{}, 1) groupdone := make(chan struct{}, 1)
go func() { go func() {
wg.Wait() wg.Wait()
groupdone <-struct{}{} groupdone <- struct{}{}
}() }()
select { select {
case <-ctx.Done(): case <-ctx.Done():
leftreader.Close() leftreader.Close()
right.Close() right.Close()
case <-groupdone: case <-groupdone:
return return
} }
<-groupdone <-groupdone
return return
} }
// Hop-by-hop headers. These are removed when sent to the backend. // Hop-by-hop headers. These are removed when sent to the backend.
@ -79,7 +79,7 @@ var hopHeaders = []string{
"Keep-Alive", "Keep-Alive",
"Proxy-Authenticate", "Proxy-Authenticate",
"Proxy-Connection", "Proxy-Connection",
"Proxy-Authorization", "Proxy-Authorization",
"Te", // canonicalized version of "TE" "Te", // canonicalized version of "TE"
"Trailers", "Trailers",
"Transfer-Encoding", "Transfer-Encoding",
@ -101,65 +101,65 @@ func delHopHeaders(header http.Header) {
} }
func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) { func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) {
hj, ok := hijackable.(http.Hijacker) hj, ok := hijackable.(http.Hijacker)
if !ok { if !ok {
return nil, nil, errors.New("Connection doesn't support hijacking") return nil, nil, errors.New("Connection doesn't support hijacking")
} }
conn, rw, err := hj.Hijack() conn, rw, err := hj.Hijack()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
var emptytime time.Time var emptytime time.Time
err = conn.SetDeadline(emptytime) err = conn.SetDeadline(emptytime)
if err != nil { if err != nil {
conn.Close() conn.Close()
return nil, nil, err return nil, nil, err
} }
return conn, rw, nil return conn, rw, nil
} }
func flush(flusher interface{}) bool { func flush(flusher interface{}) bool {
f, ok := flusher.(http.Flusher) f, ok := flusher.(http.Flusher)
if !ok { if !ok {
return false return false
} }
f.Flush() f.Flush()
return true return true
} }
func copyBody(wr io.Writer, body io.Reader) { func copyBody(wr io.Writer, body io.Reader) {
buf := make([]byte, COPY_BUF) buf := make([]byte, COPY_BUF)
for { for {
bread, read_err := body.Read(buf) bread, read_err := body.Read(buf)
var write_err error var write_err error
if bread > 0 { if bread > 0 {
_, write_err = wr.Write(buf[:bread]) _, write_err = wr.Write(buf[:bread])
flush(wr) flush(wr)
} }
if read_err != nil || write_err != nil { if read_err != nil || write_err != nil {
break break
} }
} }
} }
func makeServerTLSConfig(certfile, keyfile, cafile string) (*tls.Config, error) { func makeServerTLSConfig(certfile, keyfile, cafile string) (*tls.Config, error) {
var cfg tls.Config var cfg tls.Config
cert, err := tls.LoadX509KeyPair(certfile, keyfile) cert, err := tls.LoadX509KeyPair(certfile, keyfile)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cfg.Certificates = []tls.Certificate{cert} cfg.Certificates = []tls.Certificate{cert}
if cafile != "" { if cafile != "" {
roots := x509.NewCertPool() roots := x509.NewCertPool()
certs, err := ioutil.ReadFile(cafile) certs, err := ioutil.ReadFile(cafile)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if ok := roots.AppendCertsFromPEM(certs); !ok { if ok := roots.AppendCertsFromPEM(certs); !ok {
return nil, errors.New("Failed to load CA certificates") return nil, errors.New("Failed to load CA certificates")
} }
cfg.ClientCAs = roots cfg.ClientCAs = roots
cfg.ClientAuth = tls.VerifyClientCertIfGiven cfg.ClientAuth = tls.VerifyClientCertIfGiven
} }
return &cfg, nil return &cfg, nil
} }