diff --git a/caddy/caddy.go b/caddy/caddy.go index 734e984d..600abe66 100644 --- a/caddy/caddy.go +++ b/caddy/caddy.go @@ -191,8 +191,9 @@ func startServers(groupings bindingGroup) error { if err != nil { return err } - s.HTTP2 = HTTP2 // TODO: This setting is temporary - s.ReqCallback = letsencrypt.RequestCallback // ensures we can solve ACME challenges while running + s.HTTP2 = HTTP2 // TODO: This setting is temporary + s.ReqCallback = letsencrypt.RequestCallback // ensures we can solve ACME challenges while running + s.SNICallback = letsencrypt.GetCertificateDuringHandshake // TLS on demand -- awesome! var ln server.ListenerFile if IsRestart() { diff --git a/caddy/letsencrypt/handshake.go b/caddy/letsencrypt/handshake.go new file mode 100644 index 00000000..690eb076 --- /dev/null +++ b/caddy/letsencrypt/handshake.go @@ -0,0 +1,99 @@ +package letsencrypt + +import ( + "crypto/tls" + "errors" + "strings" + "sync" + + "github.com/mholt/caddy/server" +) + +// GetCertificateDuringHandshake is a function that gets a certificate during a TLS handshake. +// It first checks an in-memory cache in case the cert was requested before, then tries to load +// a certificate in the storage folder from disk. If it can't find an existing certificate, it +// will try to obtain one using ACME, which will then be stored on disk and cached in memory. +// +// This function is safe for use by multiple concurrent goroutines. +func GetCertificateDuringHandshake(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + // Utility function to help us load a cert from disk and put it in the cache if successful + loadCertFromDisk := func(domain string) *tls.Certificate { + cert, err := tls.LoadX509KeyPair(storage.SiteCertFile(domain), storage.SiteKeyFile(domain)) + if err == nil { + certCacheMu.Lock() + if len(certCache) < 10000 { // limit size of cache to prevent a ridiculous, unusual kind of attack + certCache[domain] = &cert + } + certCacheMu.Unlock() + return &cert + } + return nil + } + + // First check our in-memory cache to see if we've already loaded it + certCacheMu.RLock() + cert := server.GetCertificateFromCache(clientHello, certCache) + certCacheMu.RUnlock() + if cert != nil { + return cert, nil + } + + // Then check to see if we already have one on disk; if we do, add it to cache and use it + name := strings.ToLower(clientHello.ServerName) + cert = loadCertFromDisk(name) + if cert != nil { + return cert, nil + } + + // Only option left is to get one from LE, but the name has to qualify first + if !HostQualifies(name) { + return nil, nil + } + + // By this point, we need to obtain one from the CA. We must protect this process + // from happening concurrently, so synchronize. + obtainCertWaitGroupsMutex.Lock() + wg, ok := obtainCertWaitGroups[name] + if ok { + // lucky us -- another goroutine is already obtaining the certificate. + // wait for it to finish obtaining the cert and then we'll use it. + obtainCertWaitGroupsMutex.Unlock() + wg.Wait() + return GetCertificateDuringHandshake(clientHello) + } + + // looks like it's up to us to do all the work and obtain the cert + wg = new(sync.WaitGroup) + wg.Add(1) + obtainCertWaitGroups[name] = wg + obtainCertWaitGroupsMutex.Unlock() + + // Unblock waiters and delete waitgroup when we return + defer func() { + obtainCertWaitGroupsMutex.Lock() + wg.Done() + delete(obtainCertWaitGroups, name) + obtainCertWaitGroupsMutex.Unlock() + }() + + // obtain cert + client, err := newClientPort(DefaultEmail, AlternatePort) + if err != nil { + return nil, errors.New("error creating client: " + err.Error()) + } + err = clientObtain(client, []string{name}, false) + if err != nil { + return nil, err + } + + // load certificate into memory and return it + return loadCertFromDisk(name), nil +} + +// obtainCertWaitGroups is used to coordinate obtaining certs for each hostname. +var obtainCertWaitGroups = make(map[string]*sync.WaitGroup) +var obtainCertWaitGroupsMutex sync.Mutex + +// certCache stores certificates that have been obtained in memory. +var certCache = make(map[string]*tls.Certificate) +var certCacheMu sync.RWMutex diff --git a/caddy/letsencrypt/letsencrypt.go b/caddy/letsencrypt/letsencrypt.go index a2965a10..e721a7b1 100644 --- a/caddy/letsencrypt/letsencrypt.go +++ b/caddy/letsencrypt/letsencrypt.go @@ -6,6 +6,7 @@ package letsencrypt import ( "encoding/json" "errors" + "fmt" "io/ioutil" "net" "net/http" @@ -82,6 +83,13 @@ func Activate(configs []server.Config) ([]server.Config, error) { // keep certificates renewed and OCSP stapling updated go maintainAssets(configs, stopChan) + // TODO - experimental dynamic TLS! + for i := range configs { + if configs[i].Host == "" && configs[i].Port == "443" { + configs[i].TLS.Enabled = true + } + } + return configs, nil } @@ -127,41 +135,9 @@ func ObtainCerts(configs []server.Config, altPort string) error { continue } - Obtain: - certificate, failures := client.ObtainCertificate([]string{cfg.Host}, true, nil) - if len(failures) == 0 { - // Success - immediately save the certificate resource - err := saveCertResource(certificate) - if err != nil { - return errors.New("error saving assets for " + cfg.Host + ": " + err.Error()) - } - } else { - // Error - either try to fix it or report them it to the user and abort - var errMsg string // we'll combine all the failures into a single error message - var promptedForAgreement bool // only prompt user for agreement at most once - - for errDomain, obtainErr := range failures { - // TODO: Double-check, will obtainErr ever be nil? - if tosErr, ok := obtainErr.(acme.TOSError); ok { - // Terms of Service agreement error; we can probably deal with this - if !Agreed && !promptedForAgreement && altPort == "" { // don't prompt if server is already running - Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL - promptedForAgreement = true - } - if Agreed || altPort != "" { - err := client.AgreeToTOS() - if err != nil { - return errors.New("error agreeing to updated terms: " + err.Error()) - } - goto Obtain - } - } - - // If user did not agree or it was any other kind of error, just append to the list of errors - errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n" - } - - return errors.New(errMsg) + err := clientObtain(client, []string{cfg.Host}, altPort == "") + if err != nil { + return err } } } @@ -447,6 +423,49 @@ func redirPlaintextHost(cfg server.Config) server.Config { } } +// clientObtain uses client to obtain a single certificate for domains in names. If +// the user is present to provide an email address, pass in true for allowPrompt, +// otherwise pass in false. If err == nil, the certificate (and key) will be saved +// to disk in the storage folder. +func clientObtain(client *acme.Client, names []string, allowPrompt bool) error { + certificate, failures := client.ObtainCertificate(names, true, nil) + if len(failures) > 0 { + // Error - either try to fix it or report them it to the user and abort + var errMsg string // we'll combine all the failures into a single error message + var promptedForAgreement bool // only prompt user for agreement at most once + + for errDomain, obtainErr := range failures { + // TODO: Double-check, will obtainErr ever be nil? + if tosErr, ok := obtainErr.(acme.TOSError); ok { + // Terms of Service agreement error; we can probably deal with this + if !Agreed && !promptedForAgreement && allowPrompt { // don't prompt if server is already running + Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL + promptedForAgreement = true + } + if Agreed || !allowPrompt { + err := client.AgreeToTOS() + if err != nil { + return errors.New("error agreeing to updated terms: " + err.Error()) + } + return clientObtain(client, names, allowPrompt) + } + } + + // If user did not agree or it was any other kind of error, just append to the list of errors + errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n" + } + return errors.New(errMsg) + } + + // Success - immediately save the certificate resource + err := saveCertResource(certificate) + if err != nil { + return fmt.Errorf("error saving assets for %v: %v", names, err) + } + + return nil +} + // Revoke revokes the certificate for host via ACME protocol. func Revoke(host string) error { if !existingCertAndKey(host) { diff --git a/server/server.go b/server/server.go index 5794c167..293092c6 100644 --- a/server/server.go +++ b/server/server.go @@ -13,6 +13,7 @@ import ( "net/http" "os" "runtime" + "strings" "sync" "time" ) @@ -33,6 +34,7 @@ type Server struct { startChan chan struct{} // used to block until server is finished starting connTimeout time.Duration // the maximum duration of a graceful shutdown ReqCallback OptionalCallback // if non-nil, is executed at the beginning of every request + SNICallback func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) } // ListenerFile represents a listener. @@ -206,17 +208,39 @@ func serveTLSWithSNI(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error { // Here we diverge from the stdlib a bit by loading multiple certs/key pairs // then we map the server names to their certs - var err error - config.Certificates = make([]tls.Certificate, len(tlsConfigs)) - for i, tlsConfig := range tlsConfigs { - config.Certificates[i], err = tls.LoadX509KeyPair(tlsConfig.Certificate, tlsConfig.Key) - config.Certificates[i].OCSPStaple = tlsConfig.OCSPStaple + for _, tlsConfig := range tlsConfigs { + if tlsConfig.Certificate == "" || tlsConfig.Key == "" { + continue + } + cert, err := tls.LoadX509KeyPair(tlsConfig.Certificate, tlsConfig.Key) if err != nil { defer close(s.startChan) - return err + return fmt.Errorf("loading certificate and key pair: %v", err) } + cert.OCSPStaple = tlsConfig.OCSPStaple + config.Certificates = append(config.Certificates, cert) + } + if len(config.Certificates) > 0 { + config.BuildNameToCertificate() + } + + config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + // TODO: When Caddy starts, if it is to issue certs dynamically, we need + // terms agreement and an email address. make sure this is enforced at server + // start if the Caddyfile enables dynamic certificate issuance! + + // Check NameToCertificate like the std lib does in "getCertificate" (unexported, bah) + cert := GetCertificateFromCache(clientHello, config.NameToCertificate) + if cert != nil { + return cert, nil + } + + if s.SNICallback != nil { + return s.SNICallback(clientHello) + } + + return nil, nil } - config.BuildNameToCertificate() // Customize our TLS configuration config.MinVersion = tlsConfigs[0].ProtocolMinVersion @@ -225,7 +249,7 @@ func serveTLSWithSNI(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error { config.PreferServerCipherSuites = tlsConfigs[0].PreferServerCipherSuites // TLS client authentication, if user enabled it - err = setupClientAuth(tlsConfigs, config) + err := setupClientAuth(tlsConfigs, config) if err != nil { defer close(s.startChan) return err @@ -242,6 +266,36 @@ func serveTLSWithSNI(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error { return s.Server.Serve(ln) } +// Borrowed from the Go standard library, crypto/tls pacakge, common.go. +// It has been modified to fit this program. +// Original license: +// +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +func GetCertificateFromCache(clientHello *tls.ClientHelloInfo, cache map[string]*tls.Certificate) *tls.Certificate { + name := strings.ToLower(clientHello.ServerName) + for len(name) > 0 && name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + + // exact match? great! use it + if cert, ok := cache[name]; ok { + return cert + } + + // try replacing labels in the name with wildcards until we get a match. + labels := strings.Split(name, ".") + for i := range labels { + labels[i] = "*" + candidate := strings.Join(labels, ".") + if cert, ok := cache[candidate]; ok { + return cert + } + } + return nil +} + // Stop stops the server. It blocks until the server is // totally stopped. On POSIX systems, it will wait for // connections to close (up to a max timeout of a few