add list-ciphers option
This commit is contained in:
parent
6dd28baf6b
commit
0725d27faf
309
auth.go
309
auth.go
|
@ -1,16 +1,16 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"bufio"
|
||||||
"net/http"
|
"crypto/subtle"
|
||||||
"net/url"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"strings"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"strconv"
|
"net/http"
|
||||||
"encoding/base64"
|
"net/url"
|
||||||
"crypto/subtle"
|
"os"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"strconv"
|
||||||
"bufio"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
const AUTH_REQUIRED_MSG = "Proxy authentication required.\n"
|
const AUTH_REQUIRED_MSG = "Proxy authentication required.\n"
|
||||||
|
@ -19,183 +19,182 @@ const AUTH_TRIGGERED_MSG = "Browser auth triggered!\n"
|
||||||
const EPOCH_EXPIRE = "Thu, 01 Jan 1970 00:00:01 GMT"
|
const EPOCH_EXPIRE = "Thu, 01 Jan 1970 00:00:01 GMT"
|
||||||
|
|
||||||
type Auth interface {
|
type Auth interface {
|
||||||
Validate(wr http.ResponseWriter, req *http.Request) bool
|
Validate(wr http.ResponseWriter, req *http.Request) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuth(paramstr string) (Auth, error) {
|
func NewAuth(paramstr string) (Auth, error) {
|
||||||
url, err := url.Parse(paramstr)
|
url, err := url.Parse(paramstr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
switch strings.ToLower(url.Scheme) {
|
switch strings.ToLower(url.Scheme) {
|
||||||
case "static":
|
case "static":
|
||||||
return NewStaticAuth(url)
|
return NewStaticAuth(url)
|
||||||
case "basicfile":
|
case "basicfile":
|
||||||
return NewBasicFileAuth(url)
|
return NewBasicFileAuth(url)
|
||||||
case "cert":
|
case "cert":
|
||||||
return CertAuth{}, nil
|
return CertAuth{}, nil
|
||||||
case "none":
|
case "none":
|
||||||
return NoAuth{}, nil
|
return NoAuth{}, nil
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("Unknown auth scheme")
|
return nil, errors.New("Unknown auth scheme")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStaticAuth(param_url *url.URL) (*BasicAuth, error) {
|
func NewStaticAuth(param_url *url.URL) (*BasicAuth, error) {
|
||||||
values, err := url.ParseQuery(param_url.RawQuery)
|
values, err := url.ParseQuery(param_url.RawQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
username := values.Get("username")
|
username := values.Get("username")
|
||||||
if username == "" {
|
if username == "" {
|
||||||
return nil, errors.New("\"username\" parameter is missing from auth config URI")
|
return nil, errors.New("\"username\" parameter is missing from auth config URI")
|
||||||
}
|
}
|
||||||
password := values.Get("password")
|
password := values.Get("password")
|
||||||
if password == "" {
|
if password == "" {
|
||||||
return nil, errors.New("\"password\" parameter is missing from auth config URI")
|
return nil, errors.New("\"password\" parameter is missing from auth config URI")
|
||||||
}
|
}
|
||||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost)
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &BasicAuth{
|
return &BasicAuth{
|
||||||
users: map[string][]byte{
|
users: map[string][]byte{
|
||||||
username: hashedPassword,
|
username: hashedPassword,
|
||||||
},
|
},
|
||||||
hiddenDomain: strings.ToLower(values.Get("hidden_domain")),
|
hiddenDomain: strings.ToLower(values.Get("hidden_domain")),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func requireBasicAuth(wr http.ResponseWriter, req *http.Request, hidden_domain string) {
|
func requireBasicAuth(wr http.ResponseWriter, req *http.Request, hidden_domain string) {
|
||||||
if hidden_domain != "" &&
|
if hidden_domain != "" &&
|
||||||
(subtle.ConstantTimeCompare([]byte(req.URL.Host), []byte(hidden_domain)) != 1 &&
|
(subtle.ConstantTimeCompare([]byte(req.URL.Host), []byte(hidden_domain)) != 1 &&
|
||||||
subtle.ConstantTimeCompare([]byte(req.Host), []byte(hidden_domain)) != 1) {
|
subtle.ConstantTimeCompare([]byte(req.Host), []byte(hidden_domain)) != 1) {
|
||||||
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
|
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
|
||||||
} else {
|
} else {
|
||||||
wr.Header().Set("Proxy-Authenticate", `Basic realm="dumbproxy"`)
|
wr.Header().Set("Proxy-Authenticate", `Basic realm="dumbproxy"`)
|
||||||
wr.Header().Set("Content-Length", strconv.Itoa(len([]byte(AUTH_REQUIRED_MSG))))
|
wr.Header().Set("Content-Length", strconv.Itoa(len([]byte(AUTH_REQUIRED_MSG))))
|
||||||
wr.WriteHeader(407)
|
wr.WriteHeader(407)
|
||||||
wr.Write([]byte(AUTH_REQUIRED_MSG))
|
wr.Write([]byte(AUTH_REQUIRED_MSG))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type BasicAuth struct {
|
type BasicAuth struct {
|
||||||
users map[string][]byte
|
users map[string][]byte
|
||||||
hiddenDomain string
|
hiddenDomain string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBasicFileAuth(param_url *url.URL) (*BasicAuth, error) {
|
func NewBasicFileAuth(param_url *url.URL) (*BasicAuth, error) {
|
||||||
values, err := url.ParseQuery(param_url.RawQuery)
|
values, err := url.ParseQuery(param_url.RawQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
filename := values.Get("path")
|
filename := values.Get("path")
|
||||||
if filename == "" {
|
if filename == "" {
|
||||||
return nil, errors.New("\"path\" parameter is missing from auth config URI")
|
return nil, errors.New("\"path\" parameter is missing from auth config URI")
|
||||||
}
|
}
|
||||||
|
|
||||||
f, err := os.Open(filename)
|
f, err := os.Open(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
scanner := bufio.NewScanner(f)
|
scanner := bufio.NewScanner(f)
|
||||||
users := make(map[string][]byte)
|
users := make(map[string][]byte)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
trimmed := strings.TrimSpace(line)
|
trimmed := strings.TrimSpace(line)
|
||||||
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
|
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
pair := strings.SplitN(line, ":", 2)
|
pair := strings.SplitN(line, ":", 2)
|
||||||
if len(pair) != 2 {
|
if len(pair) != 2 {
|
||||||
return nil, errors.New("Malformed login and password line")
|
return nil, errors.New("Malformed login and password line")
|
||||||
}
|
}
|
||||||
login := pair[0]
|
login := pair[0]
|
||||||
password := pair[1]
|
password := pair[1]
|
||||||
users[login] = []byte(password)
|
users[login] = []byte(password)
|
||||||
}
|
}
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(users) == 0 {
|
if len(users) == 0 {
|
||||||
return nil, errors.New("No password lines were read from file")
|
return nil, errors.New("No password lines were read from file")
|
||||||
}
|
}
|
||||||
return &BasicAuth{
|
return &BasicAuth{
|
||||||
users: users,
|
users: users,
|
||||||
hiddenDomain: strings.ToLower(values.Get("hidden_domain")),
|
hiddenDomain: strings.ToLower(values.Get("hidden_domain")),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *BasicAuth) Validate(wr http.ResponseWriter, req *http.Request) bool {
|
func (auth *BasicAuth) Validate(wr http.ResponseWriter, req *http.Request) bool {
|
||||||
hdr := req.Header.Get("Proxy-Authorization")
|
hdr := req.Header.Get("Proxy-Authorization")
|
||||||
if hdr == "" {
|
if hdr == "" {
|
||||||
requireBasicAuth(wr, req, auth.hiddenDomain)
|
requireBasicAuth(wr, req, auth.hiddenDomain)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
hdr_parts := strings.SplitN(hdr, " ", 2)
|
hdr_parts := strings.SplitN(hdr, " ", 2)
|
||||||
if len(hdr_parts) != 2 || strings.ToLower(hdr_parts[0]) != "basic" {
|
if len(hdr_parts) != 2 || strings.ToLower(hdr_parts[0]) != "basic" {
|
||||||
requireBasicAuth(wr, req, auth.hiddenDomain)
|
requireBasicAuth(wr, req, auth.hiddenDomain)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
token := hdr_parts[1]
|
token := hdr_parts[1]
|
||||||
data, err := base64.StdEncoding.DecodeString(token)
|
data, err := base64.StdEncoding.DecodeString(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requireBasicAuth(wr, req, auth.hiddenDomain)
|
requireBasicAuth(wr, req, auth.hiddenDomain)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
pair := strings.SplitN(string(data), ":", 2)
|
pair := strings.SplitN(string(data), ":", 2)
|
||||||
if len(pair) != 2 {
|
if len(pair) != 2 {
|
||||||
requireBasicAuth(wr, req, auth.hiddenDomain)
|
requireBasicAuth(wr, req, auth.hiddenDomain)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
login := pair[0]
|
login := pair[0]
|
||||||
password := pair[1]
|
password := pair[1]
|
||||||
|
|
||||||
hashedPassword, ok := auth.users[login]
|
hashedPassword, ok := auth.users[login]
|
||||||
if !ok {
|
if !ok {
|
||||||
requireBasicAuth(wr, req, auth.hiddenDomain)
|
requireBasicAuth(wr, req, auth.hiddenDomain)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if bcrypt.CompareHashAndPassword(hashedPassword, []byte(password)) == nil {
|
if bcrypt.CompareHashAndPassword(hashedPassword, []byte(password)) == nil {
|
||||||
if auth.hiddenDomain != "" &&
|
if auth.hiddenDomain != "" &&
|
||||||
(req.Host == auth.hiddenDomain || req.URL.Host == auth.hiddenDomain) {
|
(req.Host == auth.hiddenDomain || req.URL.Host == auth.hiddenDomain) {
|
||||||
wr.Header().Set("Content-Length", strconv.Itoa(len([]byte(AUTH_TRIGGERED_MSG))))
|
wr.Header().Set("Content-Length", strconv.Itoa(len([]byte(AUTH_TRIGGERED_MSG))))
|
||||||
wr.Header().Set("Pragma", "no-cache")
|
wr.Header().Set("Pragma", "no-cache")
|
||||||
wr.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
wr.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||||
wr.Header().Set("Expires", EPOCH_EXPIRE)
|
wr.Header().Set("Expires", EPOCH_EXPIRE)
|
||||||
wr.Header()["Date"] = nil
|
wr.Header()["Date"] = nil
|
||||||
wr.WriteHeader(http.StatusOK)
|
wr.WriteHeader(http.StatusOK)
|
||||||
wr.Write([]byte(AUTH_TRIGGERED_MSG))
|
wr.Write([]byte(AUTH_TRIGGERED_MSG))
|
||||||
return false
|
return false
|
||||||
} else {
|
} else {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
requireBasicAuth(wr, req, auth.hiddenDomain)
|
requireBasicAuth(wr, req, auth.hiddenDomain)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
type NoAuth struct {}
|
type NoAuth struct{}
|
||||||
|
|
||||||
func (_ NoAuth) Validate(wr http.ResponseWriter, req *http.Request) bool {
|
func (_ NoAuth) Validate(wr http.ResponseWriter, req *http.Request) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CertAuth struct{}
|
||||||
type CertAuth struct {}
|
|
||||||
|
|
||||||
func (_ CertAuth) Validate(wr http.ResponseWriter, req *http.Request) bool {
|
func (_ CertAuth) Validate(wr http.ResponseWriter, req *http.Request) bool {
|
||||||
if req.TLS == nil || len(req.TLS.VerifiedChains) < 1 {
|
if req.TLS == nil || len(req.TLS.VerifiedChains) < 1 {
|
||||||
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
|
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
|
||||||
return false
|
return false
|
||||||
} else {
|
} else {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
48
condlog.go
48
condlog.go
|
@ -1,58 +1,58 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"fmt"
|
||||||
"fmt"
|
"log"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CRITICAL = 50
|
CRITICAL = 50
|
||||||
ERROR = 40
|
ERROR = 40
|
||||||
WARNING = 30
|
WARNING = 30
|
||||||
INFO = 20
|
INFO = 20
|
||||||
DEBUG = 10
|
DEBUG = 10
|
||||||
NOTSET = 0
|
NOTSET = 0
|
||||||
)
|
)
|
||||||
|
|
||||||
type CondLogger struct {
|
type CondLogger struct {
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
verbosity int
|
verbosity int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cl *CondLogger) Log(verb int, format string, v ...interface{}) error {
|
func (cl *CondLogger) Log(verb int, format string, v ...interface{}) error {
|
||||||
if verb >= cl.verbosity {
|
if verb >= cl.verbosity {
|
||||||
return cl.logger.Output(2, fmt.Sprintf(format, v...))
|
return cl.logger.Output(2, fmt.Sprintf(format, v...))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cl *CondLogger) log(verb int, format string, v ...interface{}) error {
|
func (cl *CondLogger) log(verb int, format string, v ...interface{}) error {
|
||||||
if verb >= cl.verbosity {
|
if verb >= cl.verbosity {
|
||||||
return cl.logger.Output(3, fmt.Sprintf(format, v...))
|
return cl.logger.Output(3, fmt.Sprintf(format, v...))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cl *CondLogger) Critical(s string, v ...interface{}) error {
|
func (cl *CondLogger) Critical(s string, v ...interface{}) error {
|
||||||
return cl.log(CRITICAL, "CRITICAL " + s, v...)
|
return cl.log(CRITICAL, "CRITICAL "+s, v...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cl *CondLogger) Error(s string, v ...interface{}) error {
|
func (cl *CondLogger) Error(s string, v ...interface{}) error {
|
||||||
return cl.log(ERROR, "ERROR " + s, v...)
|
return cl.log(ERROR, "ERROR "+s, v...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cl *CondLogger) Warning(s string, v ...interface{}) error {
|
func (cl *CondLogger) Warning(s string, v ...interface{}) error {
|
||||||
return cl.log(WARNING, "WARNING " + s, v...)
|
return cl.log(WARNING, "WARNING "+s, v...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cl *CondLogger) Info(s string, v ...interface{}) error {
|
func (cl *CondLogger) Info(s string, v ...interface{}) error {
|
||||||
return cl.log(INFO, "INFO " + s, v...)
|
return cl.log(INFO, "INFO "+s, v...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cl *CondLogger) Debug(s string, v ...interface{}) error {
|
func (cl *CondLogger) Debug(s string, v ...interface{}) error {
|
||||||
return cl.log(DEBUG, "DEBUG " + s, v...)
|
return cl.log(DEBUG, "DEBUG "+s, v...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCondLogger(logger *log.Logger, verbosity int) *CondLogger {
|
func NewCondLogger(logger *log.Logger, verbosity int) *CondLogger {
|
||||||
return &CondLogger{verbosity: verbosity, logger: logger}
|
return &CondLogger{verbosity: verbosity, logger: logger}
|
||||||
}
|
}
|
||||||
|
|
206
handler.go
206
handler.go
|
@ -1,133 +1,133 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"context"
|
"sync"
|
||||||
"sync"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProxyHandler struct {
|
type ProxyHandler struct {
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
auth Auth
|
auth Auth
|
||||||
logger *CondLogger
|
logger *CondLogger
|
||||||
httptransport http.RoundTripper
|
httptransport http.RoundTripper
|
||||||
outbound map[string]string
|
outbound map[string]string
|
||||||
outboundMux sync.RWMutex
|
outboundMux sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProxyHandler(timeout time.Duration, auth Auth, logger *CondLogger) *ProxyHandler {
|
func NewProxyHandler(timeout time.Duration, auth Auth, logger *CondLogger) *ProxyHandler {
|
||||||
httptransport := &http.Transport{}
|
httptransport := &http.Transport{}
|
||||||
return &ProxyHandler{
|
return &ProxyHandler{
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
auth: auth,
|
auth: auth,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
httptransport: httptransport,
|
httptransport: httptransport,
|
||||||
outbound: make(map[string]string),
|
outbound: make(map[string]string),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) {
|
func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) {
|
||||||
ctx, _ := context.WithTimeout(req.Context(), s.timeout)
|
ctx, _ := context.WithTimeout(req.Context(), s.timeout)
|
||||||
dialer := net.Dialer{}
|
dialer := net.Dialer{}
|
||||||
conn, err := dialer.DialContext(ctx, "tcp", req.RequestURI)
|
conn, err := dialer.DialContext(ctx, "tcp", req.RequestURI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("Can't satisfy CONNECT request: %v", err)
|
s.logger.Error("Can't satisfy CONNECT request: %v", err)
|
||||||
http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway)
|
http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
localAddr := conn.LocalAddr().String()
|
localAddr := conn.LocalAddr().String()
|
||||||
s.outboundMux.Lock()
|
s.outboundMux.Lock()
|
||||||
s.outbound[localAddr] = req.RemoteAddr
|
s.outbound[localAddr] = req.RemoteAddr
|
||||||
s.outboundMux.Unlock()
|
s.outboundMux.Unlock()
|
||||||
defer func() {
|
defer func() {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
s.outboundMux.Lock()
|
s.outboundMux.Lock()
|
||||||
delete(s.outbound, localAddr)
|
delete(s.outbound, localAddr)
|
||||||
s.outboundMux.Unlock()
|
s.outboundMux.Unlock()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if req.ProtoMajor == 0 || req.ProtoMajor == 1 {
|
if req.ProtoMajor == 0 || req.ProtoMajor == 1 {
|
||||||
// Upgrade client connection
|
// Upgrade client connection
|
||||||
localconn, _, err := hijack(wr)
|
localconn, _, err := hijack(wr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("Can't hijack client connection: %v", err)
|
s.logger.Error("Can't hijack client connection: %v", err)
|
||||||
http.Error(wr, "Can't hijack client connection", http.StatusInternalServerError)
|
http.Error(wr, "Can't hijack client connection", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer localconn.Close()
|
defer localconn.Close()
|
||||||
|
|
||||||
// Inform client connection is built
|
// Inform client connection is built
|
||||||
fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor)
|
fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor)
|
||||||
|
|
||||||
proxy(req.Context(), localconn, conn)
|
proxy(req.Context(), localconn, conn)
|
||||||
} else if req.ProtoMajor == 2 {
|
} else if req.ProtoMajor == 2 {
|
||||||
wr.Header()["Date"] = nil
|
wr.Header()["Date"] = nil
|
||||||
wr.WriteHeader(http.StatusOK)
|
wr.WriteHeader(http.StatusOK)
|
||||||
flush(wr)
|
flush(wr)
|
||||||
proxyh2(req.Context(), req.Body, wr, conn)
|
proxyh2(req.Context(), req.Body, wr, conn)
|
||||||
} else {
|
} else {
|
||||||
s.logger.Error("Unsupported protocol version: %s", req.Proto)
|
s.logger.Error("Unsupported protocol version: %s", req.Proto)
|
||||||
http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest)
|
http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request) {
|
func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request) {
|
||||||
req.RequestURI = ""
|
req.RequestURI = ""
|
||||||
if req.ProtoMajor == 2 {
|
if req.ProtoMajor == 2 {
|
||||||
req.URL.Scheme = "http" // We can't access :scheme pseudo-header, so assume http
|
req.URL.Scheme = "http" // We can't access :scheme pseudo-header, so assume http
|
||||||
req.URL.Host = req.Host
|
req.URL.Host = req.Host
|
||||||
}
|
}
|
||||||
resp, err := s.httptransport.RoundTrip(req)
|
resp, err := s.httptransport.RoundTrip(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("HTTP fetch error: %v", err)
|
s.logger.Error("HTTP fetch error: %v", err)
|
||||||
http.Error(wr, "Server Error", http.StatusInternalServerError)
|
http.Error(wr, "Server Error", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
s.logger.Info("%v %v %v %v", req.RemoteAddr, req.Method, req.URL, resp.Status)
|
s.logger.Info("%v %v %v %v", req.RemoteAddr, req.Method, req.URL, resp.Status)
|
||||||
delHopHeaders(resp.Header)
|
delHopHeaders(resp.Header)
|
||||||
copyHeader(wr.Header(), resp.Header)
|
copyHeader(wr.Header(), resp.Header)
|
||||||
wr.WriteHeader(resp.StatusCode)
|
wr.WriteHeader(resp.StatusCode)
|
||||||
flush(wr)
|
flush(wr)
|
||||||
copyBody(wr, resp.Body)
|
copyBody(wr, resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyHandler) isLoopback(req *http.Request) (string, bool) {
|
func (s *ProxyHandler) isLoopback(req *http.Request) (string, bool) {
|
||||||
s.outboundMux.RLock()
|
s.outboundMux.RLock()
|
||||||
originator, found := s.outbound[req.RemoteAddr]
|
originator, found := s.outbound[req.RemoteAddr]
|
||||||
s.outboundMux.RUnlock()
|
s.outboundMux.RUnlock()
|
||||||
return originator, found
|
return originator, found
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
||||||
s.logger.Info("Request: %v %v %v %v", req.RemoteAddr, req.Proto, req.Method, req.URL)
|
s.logger.Info("Request: %v %v %v %v", req.RemoteAddr, req.Proto, req.Method, req.URL)
|
||||||
|
|
||||||
if originator, isLoopback := s.isLoopback(req) ; isLoopback {
|
if originator, isLoopback := s.isLoopback(req); isLoopback {
|
||||||
s.logger.Critical("Loopback tunnel detected: %s is an outbound " +
|
s.logger.Critical("Loopback tunnel detected: %s is an outbound "+
|
||||||
"address for another request from %s", req.RemoteAddr, originator)
|
"address for another request from %s", req.RemoteAddr, originator)
|
||||||
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
|
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
isConnect := strings.ToUpper(req.Method) == "CONNECT"
|
isConnect := strings.ToUpper(req.Method) == "CONNECT"
|
||||||
if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 ||
|
if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 ||
|
||||||
req.Host == "" && req.ProtoMajor == 2 {
|
req.Host == "" && req.ProtoMajor == 2 {
|
||||||
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
|
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !s.auth.Validate(wr, req) {
|
if !s.auth.Validate(wr, req) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
delHopHeaders(req.Header)
|
delHopHeaders(req.Header)
|
||||||
if isConnect {
|
if isConnect {
|
||||||
s.HandleTunnel(wr, req)
|
s.HandleTunnel(wr, req)
|
||||||
} else {
|
} else {
|
||||||
s.HandleRequest(wr, req)
|
s.HandleRequest(wr, req)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
70
logwriter.go
70
logwriter.go
|
@ -1,57 +1,57 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"errors"
|
||||||
"errors"
|
"io"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const MAX_LOG_QLEN = 128
|
const MAX_LOG_QLEN = 128
|
||||||
const QUEUE_SHUTDOWN_TIMEOUT = 500 * time.Millisecond
|
const QUEUE_SHUTDOWN_TIMEOUT = 500 * time.Millisecond
|
||||||
|
|
||||||
type LogWriter struct {
|
type LogWriter struct {
|
||||||
writer io.Writer
|
writer io.Writer
|
||||||
ch chan []byte
|
ch chan []byte
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lw *LogWriter) Write(p []byte) (int, error) {
|
func (lw *LogWriter) Write(p []byte) (int, error) {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
return 0, errors.New("Can't write nil byte slice")
|
return 0, errors.New("Can't write nil byte slice")
|
||||||
}
|
}
|
||||||
buf := make([]byte, len(p))
|
buf := make([]byte, len(p))
|
||||||
copy(buf, p)
|
copy(buf, p)
|
||||||
select {
|
select {
|
||||||
case lw.ch <- buf:
|
case lw.ch <- buf:
|
||||||
return len(p), nil
|
return len(p), nil
|
||||||
default:
|
default:
|
||||||
return 0, errors.New("Writer queue overflow")
|
return 0, errors.New("Writer queue overflow")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLogWriter(writer io.Writer) *LogWriter {
|
func NewLogWriter(writer io.Writer) *LogWriter {
|
||||||
lw := &LogWriter{writer,
|
lw := &LogWriter{writer,
|
||||||
make(chan []byte, MAX_LOG_QLEN),
|
make(chan []byte, MAX_LOG_QLEN),
|
||||||
make(chan struct{})}
|
make(chan struct{})}
|
||||||
go lw.loop()
|
go lw.loop()
|
||||||
return lw
|
return lw
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lw *LogWriter) loop() {
|
func (lw *LogWriter) loop() {
|
||||||
for p := range lw.ch {
|
for p := range lw.ch {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
lw.writer.Write(p)
|
lw.writer.Write(p)
|
||||||
}
|
}
|
||||||
lw.done <- struct{}{}
|
lw.done <- struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lw *LogWriter) Close() {
|
func (lw *LogWriter) Close() {
|
||||||
lw.ch <- nil
|
lw.ch <- nil
|
||||||
timer := time.After(QUEUE_SHUTDOWN_TIMEOUT)
|
timer := time.After(QUEUE_SHUTDOWN_TIMEOUT)
|
||||||
select {
|
select {
|
||||||
case <-timer:
|
case <-timer:
|
||||||
case <-lw.done:
|
case <-lw.done:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
147
main.go
147
main.go
|
@ -1,95 +1,108 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"crypto/tls"
|
||||||
"os"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"flag"
|
"log"
|
||||||
"time"
|
"net/http"
|
||||||
"net/http"
|
"os"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func perror(msg string) {
|
func perror(msg string) {
|
||||||
fmt.Fprintln(os.Stderr, "")
|
fmt.Fprintln(os.Stderr, "")
|
||||||
fmt.Fprintln(os.Stderr, msg)
|
fmt.Fprintln(os.Stderr, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func arg_fail(msg string) {
|
func arg_fail(msg string) {
|
||||||
perror(msg)
|
perror(msg)
|
||||||
perror("Usage:")
|
perror("Usage:")
|
||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
os.Exit(2)
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
type CLIArgs struct {
|
type CLIArgs struct {
|
||||||
bind_address string
|
bind_address string
|
||||||
auth string
|
auth string
|
||||||
verbosity int
|
verbosity int
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
cert, key, cafile string
|
cert, key, cafile string
|
||||||
|
list_suites bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func list_suites() {
|
||||||
|
for _, cipher := range tls.CipherSuites() {
|
||||||
|
fmt.Println(cipher.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func parse_args() CLIArgs {
|
func parse_args() CLIArgs {
|
||||||
var args CLIArgs
|
var args CLIArgs
|
||||||
flag.StringVar(&args.bind_address, "bind-address", ":8080", "HTTP proxy listen address")
|
flag.StringVar(&args.bind_address, "bind-address", ":8080", "HTTP proxy listen address")
|
||||||
flag.StringVar(&args.auth, "auth", "none://", "auth parameters")
|
flag.StringVar(&args.auth, "auth", "none://", "auth parameters")
|
||||||
flag.IntVar(&args.verbosity, "verbosity", 20, "logging verbosity " +
|
flag.IntVar(&args.verbosity, "verbosity", 20, "logging verbosity "+
|
||||||
"(10 - debug, 20 - info, 30 - warning, 40 - error, 50 - critical)")
|
"(10 - debug, 20 - info, 30 - warning, 40 - error, 50 - critical)")
|
||||||
flag.DurationVar(&args.timeout, "timeout", 10 * time.Second, "timeout for network operations")
|
flag.DurationVar(&args.timeout, "timeout", 10*time.Second, "timeout for network operations")
|
||||||
flag.StringVar(&args.cert, "cert", "", "enable TLS and use certificate")
|
flag.StringVar(&args.cert, "cert", "", "enable TLS and use certificate")
|
||||||
flag.StringVar(&args.key, "key", "", "key for TLS certificate")
|
flag.StringVar(&args.key, "key", "", "key for TLS certificate")
|
||||||
flag.StringVar(&args.cafile, "cafile", "", "CA file to authenticate clients with certificates")
|
flag.StringVar(&args.cafile, "cafile", "", "CA file to authenticate clients with certificates")
|
||||||
flag.Parse()
|
flag.BoolVar(&args.list_suites, "list-ciphers", false, "list ciphersuites")
|
||||||
return args
|
flag.Parse()
|
||||||
|
return args
|
||||||
}
|
}
|
||||||
|
|
||||||
func run() int {
|
func run() int {
|
||||||
args := parse_args()
|
args := parse_args()
|
||||||
|
|
||||||
logWriter := NewLogWriter(os.Stderr)
|
if args.list_suites {
|
||||||
defer logWriter.Close()
|
list_suites()
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
mainLogger := NewCondLogger(log.New(logWriter, "MAIN : ",
|
logWriter := NewLogWriter(os.Stderr)
|
||||||
log.LstdFlags | log.Lshortfile),
|
defer logWriter.Close()
|
||||||
args.verbosity)
|
|
||||||
proxyLogger := NewCondLogger(log.New(logWriter, "PROXY : ",
|
|
||||||
log.LstdFlags | log.Lshortfile),
|
|
||||||
args.verbosity)
|
|
||||||
|
|
||||||
auth, err := NewAuth(args.auth)
|
mainLogger := NewCondLogger(log.New(logWriter, "MAIN : ",
|
||||||
if err != nil {
|
log.LstdFlags|log.Lshortfile),
|
||||||
mainLogger.Critical("Failed to instantiate auth provider: %v", err)
|
args.verbosity)
|
||||||
return 3
|
proxyLogger := NewCondLogger(log.New(logWriter, "PROXY : ",
|
||||||
}
|
log.LstdFlags|log.Lshortfile),
|
||||||
|
args.verbosity)
|
||||||
|
|
||||||
server := http.Server{
|
auth, err := NewAuth(args.auth)
|
||||||
Addr: args.bind_address,
|
if err != nil {
|
||||||
Handler: NewProxyHandler(args.timeout, auth, proxyLogger),
|
mainLogger.Critical("Failed to instantiate auth provider: %v", err)
|
||||||
ErrorLog: log.New(logWriter, "HTTPSRV : ", log.LstdFlags | log.Lshortfile),
|
return 3
|
||||||
ReadTimeout: 0,
|
}
|
||||||
ReadHeaderTimeout: 0,
|
|
||||||
WriteTimeout: 0,
|
|
||||||
IdleTimeout: 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
mainLogger.Info("Starting proxy server...")
|
server := http.Server{
|
||||||
if args.cert != "" {
|
Addr: args.bind_address,
|
||||||
cfg, err1 := makeServerTLSConfig(args.cert, args.key, args.cafile)
|
Handler: NewProxyHandler(args.timeout, auth, proxyLogger),
|
||||||
if err1 != nil {
|
ErrorLog: log.New(logWriter, "HTTPSRV : ", log.LstdFlags|log.Lshortfile),
|
||||||
mainLogger.Critical("TLS config construction failed: %v", err)
|
ReadTimeout: 0,
|
||||||
return 3
|
ReadHeaderTimeout: 0,
|
||||||
}
|
WriteTimeout: 0,
|
||||||
server.TLSConfig = cfg
|
IdleTimeout: 0,
|
||||||
err = server.ListenAndServeTLS("", "")
|
}
|
||||||
} else {
|
|
||||||
err = server.ListenAndServe()
|
mainLogger.Info("Starting proxy server...")
|
||||||
}
|
if args.cert != "" {
|
||||||
mainLogger.Critical("Server terminated with a reason: %v", err)
|
cfg, err1 := makeServerTLSConfig(args.cert, args.key, args.cafile)
|
||||||
mainLogger.Info("Shutting down...")
|
if err1 != nil {
|
||||||
return 0
|
mainLogger.Critical("TLS config construction failed: %v", err)
|
||||||
|
return 3
|
||||||
|
}
|
||||||
|
server.TLSConfig = cfg
|
||||||
|
err = server.ListenAndServeTLS("", "")
|
||||||
|
} else {
|
||||||
|
err = server.ListenAndServe()
|
||||||
|
}
|
||||||
|
mainLogger.Critical("Server terminated with a reason: %v", err)
|
||||||
|
mainLogger.Info("Shutting down...")
|
||||||
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
os.Exit(run())
|
os.Exit(run())
|
||||||
}
|
}
|
||||||
|
|
228
utils.go
228
utils.go
|
@ -1,75 +1,75 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"bufio"
|
||||||
"net"
|
"context"
|
||||||
"sync"
|
"crypto/tls"
|
||||||
"io"
|
"crypto/x509"
|
||||||
"time"
|
"errors"
|
||||||
"errors"
|
"io"
|
||||||
"net/http"
|
"io/ioutil"
|
||||||
"bufio"
|
"net"
|
||||||
"crypto/tls"
|
"net/http"
|
||||||
"crypto/x509"
|
"sync"
|
||||||
"io/ioutil"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const COPY_BUF = 128 * 1024
|
const COPY_BUF = 128 * 1024
|
||||||
|
|
||||||
func proxy(ctx context.Context, left, right net.Conn) {
|
func proxy(ctx context.Context, left, right net.Conn) {
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
cpy := func (dst, src net.Conn) {
|
cpy := func(dst, src net.Conn) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
io.Copy(dst, src)
|
io.Copy(dst, src)
|
||||||
dst.Close()
|
dst.Close()
|
||||||
}
|
}
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
go cpy(left, right)
|
go cpy(left, right)
|
||||||
go cpy(right, left)
|
go cpy(right, left)
|
||||||
groupdone := make(chan struct{}, 1)
|
groupdone := make(chan struct{}, 1)
|
||||||
go func() {
|
go func() {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
groupdone <-struct{}{}
|
groupdone <- struct{}{}
|
||||||
}()
|
}()
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
left.Close()
|
left.Close()
|
||||||
right.Close()
|
right.Close()
|
||||||
case <-groupdone:
|
case <-groupdone:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
<-groupdone
|
<-groupdone
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) {
|
func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) {
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
ltr := func (dst net.Conn, src io.Reader) {
|
ltr := func(dst net.Conn, src io.Reader) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
io.Copy(dst, src)
|
io.Copy(dst, src)
|
||||||
dst.Close()
|
dst.Close()
|
||||||
}
|
}
|
||||||
rtl := func (dst io.Writer, src io.Reader) {
|
rtl := func(dst io.Writer, src io.Reader) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
copyBody(dst, src)
|
copyBody(dst, src)
|
||||||
}
|
}
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
go ltr(right, leftreader)
|
go ltr(right, leftreader)
|
||||||
go rtl(leftwriter, right)
|
go rtl(leftwriter, right)
|
||||||
groupdone := make(chan struct{}, 1)
|
groupdone := make(chan struct{}, 1)
|
||||||
go func() {
|
go func() {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
groupdone <-struct{}{}
|
groupdone <- struct{}{}
|
||||||
}()
|
}()
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
leftreader.Close()
|
leftreader.Close()
|
||||||
right.Close()
|
right.Close()
|
||||||
case <-groupdone:
|
case <-groupdone:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
<-groupdone
|
<-groupdone
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||||
|
@ -79,7 +79,7 @@ var hopHeaders = []string{
|
||||||
"Keep-Alive",
|
"Keep-Alive",
|
||||||
"Proxy-Authenticate",
|
"Proxy-Authenticate",
|
||||||
"Proxy-Connection",
|
"Proxy-Connection",
|
||||||
"Proxy-Authorization",
|
"Proxy-Authorization",
|
||||||
"Te", // canonicalized version of "TE"
|
"Te", // canonicalized version of "TE"
|
||||||
"Trailers",
|
"Trailers",
|
||||||
"Transfer-Encoding",
|
"Transfer-Encoding",
|
||||||
|
@ -101,65 +101,65 @@ func delHopHeaders(header http.Header) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) {
|
func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) {
|
||||||
hj, ok := hijackable.(http.Hijacker)
|
hj, ok := hijackable.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil, errors.New("Connection doesn't support hijacking")
|
return nil, nil, errors.New("Connection doesn't support hijacking")
|
||||||
}
|
}
|
||||||
conn, rw, err := hj.Hijack()
|
conn, rw, err := hj.Hijack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
var emptytime time.Time
|
var emptytime time.Time
|
||||||
err = conn.SetDeadline(emptytime)
|
err = conn.SetDeadline(emptytime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
return conn, rw, nil
|
return conn, rw, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func flush(flusher interface{}) bool {
|
func flush(flusher interface{}) bool {
|
||||||
f, ok := flusher.(http.Flusher)
|
f, ok := flusher.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
f.Flush()
|
f.Flush()
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func copyBody(wr io.Writer, body io.Reader) {
|
func copyBody(wr io.Writer, body io.Reader) {
|
||||||
buf := make([]byte, COPY_BUF)
|
buf := make([]byte, COPY_BUF)
|
||||||
for {
|
for {
|
||||||
bread, read_err := body.Read(buf)
|
bread, read_err := body.Read(buf)
|
||||||
var write_err error
|
var write_err error
|
||||||
if bread > 0 {
|
if bread > 0 {
|
||||||
_, write_err = wr.Write(buf[:bread])
|
_, write_err = wr.Write(buf[:bread])
|
||||||
flush(wr)
|
flush(wr)
|
||||||
}
|
}
|
||||||
if read_err != nil || write_err != nil {
|
if read_err != nil || write_err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeServerTLSConfig(certfile, keyfile, cafile string) (*tls.Config, error) {
|
func makeServerTLSConfig(certfile, keyfile, cafile string) (*tls.Config, error) {
|
||||||
var cfg tls.Config
|
var cfg tls.Config
|
||||||
cert, err := tls.LoadX509KeyPair(certfile, keyfile)
|
cert, err := tls.LoadX509KeyPair(certfile, keyfile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
cfg.Certificates = []tls.Certificate{cert}
|
cfg.Certificates = []tls.Certificate{cert}
|
||||||
if cafile != "" {
|
if cafile != "" {
|
||||||
roots := x509.NewCertPool()
|
roots := x509.NewCertPool()
|
||||||
certs, err := ioutil.ReadFile(cafile)
|
certs, err := ioutil.ReadFile(cafile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if ok := roots.AppendCertsFromPEM(certs); !ok {
|
if ok := roots.AppendCertsFromPEM(certs); !ok {
|
||||||
return nil, errors.New("Failed to load CA certificates")
|
return nil, errors.New("Failed to load CA certificates")
|
||||||
}
|
}
|
||||||
cfg.ClientCAs = roots
|
cfg.ClientCAs = roots
|
||||||
cfg.ClientAuth = tls.VerifyClientCertIfGiven
|
cfg.ClientAuth = tls.VerifyClientCertIfGiven
|
||||||
}
|
}
|
||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue