mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-24 11:15:49 +03:00
Merge pull request #757 from mholt/extend-tls-client-auth
Extend tls client auth
This commit is contained in:
commit
ddf4b1fd3b
4 changed files with 99 additions and 29 deletions
|
@ -83,10 +83,30 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) {
|
||||||
c.TLS.Ciphers = append(c.TLS.Ciphers, value)
|
c.TLS.Ciphers = append(c.TLS.Ciphers, value)
|
||||||
}
|
}
|
||||||
case "clients":
|
case "clients":
|
||||||
c.TLS.ClientCerts = c.RemainingArgs()
|
clientCertList := c.RemainingArgs()
|
||||||
if len(c.TLS.ClientCerts) == 0 {
|
if len(clientCertList) == 0 {
|
||||||
return nil, c.ArgErr()
|
return nil, c.ArgErr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
listStart, mustProvideCA := 1, true
|
||||||
|
switch clientCertList[0] {
|
||||||
|
case "request":
|
||||||
|
c.TLS.ClientAuth = tls.RequestClientCert
|
||||||
|
mustProvideCA = false
|
||||||
|
case "require":
|
||||||
|
c.TLS.ClientAuth = tls.RequireAnyClientCert
|
||||||
|
mustProvideCA = false
|
||||||
|
case "verify_if_given":
|
||||||
|
c.TLS.ClientAuth = tls.VerifyClientCertIfGiven
|
||||||
|
default:
|
||||||
|
c.TLS.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
|
listStart = 0
|
||||||
|
}
|
||||||
|
if mustProvideCA && len(clientCertList) <= listStart {
|
||||||
|
return nil, c.ArgErr()
|
||||||
|
}
|
||||||
|
|
||||||
|
c.TLS.ClientCerts = clientCertList[listStart:]
|
||||||
case "load":
|
case "load":
|
||||||
c.Args(&loadDir)
|
c.Args(&loadDir)
|
||||||
c.TLS.Manual = true
|
c.TLS.Manual = true
|
||||||
|
|
|
@ -189,34 +189,69 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupParseWithClientAuth(t *testing.T) {
|
func TestSetupParseWithClientAuth(t *testing.T) {
|
||||||
|
// Test missing client cert file
|
||||||
params := `tls ` + certFile + ` ` + keyFile + ` {
|
params := `tls ` + certFile + ` ` + keyFile + ` {
|
||||||
clients client_ca.crt client2_ca.crt
|
clients
|
||||||
}`
|
}`
|
||||||
c := setup.NewTestController(params)
|
c := setup.NewTestController(params)
|
||||||
_, err := Setup(c)
|
_, err := Setup(c)
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Expected no errors, got: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if count := len(c.TLS.ClientCerts); count != 2 {
|
|
||||||
t.Fatalf("Expected two client certs, had %d", count)
|
|
||||||
}
|
|
||||||
if actual := c.TLS.ClientCerts[0]; actual != "client_ca.crt" {
|
|
||||||
t.Errorf("Expected first client cert file to be '%s', but was '%s'", "client_ca.crt", actual)
|
|
||||||
}
|
|
||||||
if actual := c.TLS.ClientCerts[1]; actual != "client2_ca.crt" {
|
|
||||||
t.Errorf("Expected second client cert file to be '%s', but was '%s'", "client2_ca.crt", actual)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test missing client cert file
|
|
||||||
params = `tls ` + certFile + ` ` + keyFile + ` {
|
|
||||||
clients
|
|
||||||
}`
|
|
||||||
c = setup.NewTestController(params)
|
|
||||||
_, err = Setup(c)
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected an error, but no error returned")
|
t.Errorf("Expected an error, but no error returned")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
noCAs, twoCAs := []string{}, []string{"client_ca.crt", "client2_ca.crt"}
|
||||||
|
for caseNumber, caseData := range []struct {
|
||||||
|
params string
|
||||||
|
clientAuthType tls.ClientAuthType
|
||||||
|
expectedErr bool
|
||||||
|
expectedCAs []string
|
||||||
|
}{
|
||||||
|
{"", tls.NoClientCert, false, noCAs},
|
||||||
|
{`tls ` + certFile + ` ` + keyFile + ` {
|
||||||
|
clients client_ca.crt client2_ca.crt
|
||||||
|
}`, tls.RequireAndVerifyClientCert, false, twoCAs},
|
||||||
|
// now come modifier
|
||||||
|
{`tls ` + certFile + ` ` + keyFile + ` {
|
||||||
|
clients request
|
||||||
|
}`, tls.RequestClientCert, false, noCAs},
|
||||||
|
{`tls ` + certFile + ` ` + keyFile + ` {
|
||||||
|
clients require
|
||||||
|
}`, tls.RequireAnyClientCert, false, noCAs},
|
||||||
|
{`tls ` + certFile + ` ` + keyFile + ` {
|
||||||
|
clients verify_if_given client_ca.crt client2_ca.crt
|
||||||
|
}`, tls.VerifyClientCertIfGiven, false, twoCAs},
|
||||||
|
{`tls ` + certFile + ` ` + keyFile + ` {
|
||||||
|
clients verify_if_given
|
||||||
|
}`, tls.VerifyClientCertIfGiven, true, noCAs},
|
||||||
|
} {
|
||||||
|
c := setup.NewTestController(caseData.params)
|
||||||
|
_, err := Setup(c)
|
||||||
|
if caseData.expectedErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("In case %d: Expected an error, got: %v", caseNumber, err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("In case %d: Expected no errors, got: %v", caseNumber, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if caseData.clientAuthType != c.TLS.ClientAuth {
|
||||||
|
t.Errorf("In case %d: Expected TLS client auth type %v, got: %v",
|
||||||
|
caseNumber, caseData.clientAuthType, c.TLS.ClientAuth)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := len(c.TLS.ClientCerts); count < len(caseData.expectedCAs) {
|
||||||
|
t.Fatalf("In case %d: Expected %d client certs, had %d", caseNumber, len(caseData.expectedCAs), count)
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, expected := range caseData.expectedCAs {
|
||||||
|
if actual := c.TLS.ClientCerts[idx]; actual != expected {
|
||||||
|
t.Errorf("In case %d: Expected %dth client cert file to be '%s', but was '%s'",
|
||||||
|
caseNumber, idx, expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupParseWithKeyType(t *testing.T) {
|
func TestSetupParseWithKeyType(t *testing.T) {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/mholt/caddy/middleware"
|
"github.com/mholt/caddy/middleware"
|
||||||
|
@ -75,4 +76,5 @@ type TLSConfig struct {
|
||||||
ProtocolMaxVersion uint16
|
ProtocolMaxVersion uint16
|
||||||
PreferServerCipherSuites bool
|
PreferServerCipherSuites bool
|
||||||
ClientCerts []string
|
ClientCerts []string
|
||||||
|
ClientAuth tls.ClientAuthType
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -332,6 +333,16 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use URL.RawPath If you need the original, "raw" URL.Path in your middleware.
|
||||||
|
// Collapse any ./ ../ /// madness here instead of doing that in every plugin.
|
||||||
|
if r.URL.Path != "/" {
|
||||||
|
path := filepath.Clean(r.URL.Path)
|
||||||
|
if !strings.HasPrefix(path, "/") {
|
||||||
|
path = "/" + path
|
||||||
|
}
|
||||||
|
r.URL.Path = path
|
||||||
|
}
|
||||||
|
|
||||||
// Execute the optional request callback if it exists and it's not disabled
|
// Execute the optional request callback if it exists and it's not disabled
|
||||||
if s.ReqCallback != nil && !s.vhosts[host].config.TLS.Manual && s.ReqCallback(w, r) {
|
if s.ReqCallback != nil && !s.vhosts[host].config.TLS.Manual && s.ReqCallback(w, r) {
|
||||||
return
|
return
|
||||||
|
@ -368,17 +379,19 @@ func DefaultErrorFunc(w http.ResponseWriter, r *http.Request, status int) {
|
||||||
// setupClientAuth sets up TLS client authentication only if
|
// setupClientAuth sets up TLS client authentication only if
|
||||||
// any of the TLS configs specified at least one cert file.
|
// any of the TLS configs specified at least one cert file.
|
||||||
func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error {
|
func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error {
|
||||||
var clientAuth bool
|
whatClientAuth := tls.NoClientCert
|
||||||
for _, cfg := range tlsConfigs {
|
for _, cfg := range tlsConfigs {
|
||||||
if len(cfg.ClientCerts) > 0 {
|
if whatClientAuth < cfg.ClientAuth { // Use the most restrictive.
|
||||||
clientAuth = true
|
whatClientAuth = cfg.ClientAuth
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if clientAuth {
|
if whatClientAuth != tls.NoClientCert {
|
||||||
pool := x509.NewCertPool()
|
pool := x509.NewCertPool()
|
||||||
for _, cfg := range tlsConfigs {
|
for _, cfg := range tlsConfigs {
|
||||||
|
if len(cfg.ClientCerts) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
for _, caFile := range cfg.ClientCerts {
|
for _, caFile := range cfg.ClientCerts {
|
||||||
caCrt, err := ioutil.ReadFile(caFile) // Anyone that gets a cert from this CA can connect
|
caCrt, err := ioutil.ReadFile(caFile) // Anyone that gets a cert from this CA can connect
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -390,7 +403,7 @@ func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
config.ClientCAs = pool
|
config.ClientCAs = pool
|
||||||
config.ClientAuth = tls.RequireAndVerifyClientCert
|
config.ClientAuth = whatClientAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
Loading…
Reference in a new issue