PoC: on-demand TLS

Implements "on-demand TLS" as I call it, which means obtaining TLS certificates on-the-fly during TLS handshakes if a certificate for the requested hostname is not already available. Only the first request for a new hostname will experience higher latency; subsequent requests will get the new certificates right out of memory.

Code still needs lots of cleanup but the feature is basically working.
This commit is contained in:
Matthew Holt 2016-01-13 00:32:46 -07:00
parent b4cab78bec
commit 47079c3d24
4 changed files with 218 additions and 45 deletions

View file

@ -191,8 +191,9 @@ func startServers(groupings bindingGroup) error {
if err != nil {
return err
}
s.HTTP2 = HTTP2 // TODO: This setting is temporary
s.ReqCallback = letsencrypt.RequestCallback // ensures we can solve ACME challenges while running
s.HTTP2 = HTTP2 // TODO: This setting is temporary
s.ReqCallback = letsencrypt.RequestCallback // ensures we can solve ACME challenges while running
s.SNICallback = letsencrypt.GetCertificateDuringHandshake // TLS on demand -- awesome!
var ln server.ListenerFile
if IsRestart() {

View file

@ -0,0 +1,99 @@
package letsencrypt
import (
"crypto/tls"
"errors"
"strings"
"sync"
"github.com/mholt/caddy/server"
)
// GetCertificateDuringHandshake is a function that gets a certificate during a TLS handshake.
// It first checks an in-memory cache in case the cert was requested before, then tries to load
// a certificate in the storage folder from disk. If it can't find an existing certificate, it
// will try to obtain one using ACME, which will then be stored on disk and cached in memory.
//
// This function is safe for use by multiple concurrent goroutines.
func GetCertificateDuringHandshake(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
// Utility function to help us load a cert from disk and put it in the cache if successful
loadCertFromDisk := func(domain string) *tls.Certificate {
cert, err := tls.LoadX509KeyPair(storage.SiteCertFile(domain), storage.SiteKeyFile(domain))
if err == nil {
certCacheMu.Lock()
if len(certCache) < 10000 { // limit size of cache to prevent a ridiculous, unusual kind of attack
certCache[domain] = &cert
}
certCacheMu.Unlock()
return &cert
}
return nil
}
// First check our in-memory cache to see if we've already loaded it
certCacheMu.RLock()
cert := server.GetCertificateFromCache(clientHello, certCache)
certCacheMu.RUnlock()
if cert != nil {
return cert, nil
}
// Then check to see if we already have one on disk; if we do, add it to cache and use it
name := strings.ToLower(clientHello.ServerName)
cert = loadCertFromDisk(name)
if cert != nil {
return cert, nil
}
// Only option left is to get one from LE, but the name has to qualify first
if !HostQualifies(name) {
return nil, nil
}
// By this point, we need to obtain one from the CA. We must protect this process
// from happening concurrently, so synchronize.
obtainCertWaitGroupsMutex.Lock()
wg, ok := obtainCertWaitGroups[name]
if ok {
// lucky us -- another goroutine is already obtaining the certificate.
// wait for it to finish obtaining the cert and then we'll use it.
obtainCertWaitGroupsMutex.Unlock()
wg.Wait()
return GetCertificateDuringHandshake(clientHello)
}
// looks like it's up to us to do all the work and obtain the cert
wg = new(sync.WaitGroup)
wg.Add(1)
obtainCertWaitGroups[name] = wg
obtainCertWaitGroupsMutex.Unlock()
// Unblock waiters and delete waitgroup when we return
defer func() {
obtainCertWaitGroupsMutex.Lock()
wg.Done()
delete(obtainCertWaitGroups, name)
obtainCertWaitGroupsMutex.Unlock()
}()
// obtain cert
client, err := newClientPort(DefaultEmail, AlternatePort)
if err != nil {
return nil, errors.New("error creating client: " + err.Error())
}
err = clientObtain(client, []string{name}, false)
if err != nil {
return nil, err
}
// load certificate into memory and return it
return loadCertFromDisk(name), nil
}
// obtainCertWaitGroups is used to coordinate obtaining certs for each hostname.
var obtainCertWaitGroups = make(map[string]*sync.WaitGroup)
var obtainCertWaitGroupsMutex sync.Mutex
// certCache stores certificates that have been obtained in memory.
var certCache = make(map[string]*tls.Certificate)
var certCacheMu sync.RWMutex

View file

@ -6,6 +6,7 @@ package letsencrypt
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
@ -82,6 +83,13 @@ func Activate(configs []server.Config) ([]server.Config, error) {
// keep certificates renewed and OCSP stapling updated
go maintainAssets(configs, stopChan)
// TODO - experimental dynamic TLS!
for i := range configs {
if configs[i].Host == "" && configs[i].Port == "443" {
configs[i].TLS.Enabled = true
}
}
return configs, nil
}
@ -127,41 +135,9 @@ func ObtainCerts(configs []server.Config, altPort string) error {
continue
}
Obtain:
certificate, failures := client.ObtainCertificate([]string{cfg.Host}, true, nil)
if len(failures) == 0 {
// Success - immediately save the certificate resource
err := saveCertResource(certificate)
if err != nil {
return errors.New("error saving assets for " + cfg.Host + ": " + err.Error())
}
} else {
// Error - either try to fix it or report them it to the user and abort
var errMsg string // we'll combine all the failures into a single error message
var promptedForAgreement bool // only prompt user for agreement at most once
for errDomain, obtainErr := range failures {
// TODO: Double-check, will obtainErr ever be nil?
if tosErr, ok := obtainErr.(acme.TOSError); ok {
// Terms of Service agreement error; we can probably deal with this
if !Agreed && !promptedForAgreement && altPort == "" { // don't prompt if server is already running
Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
promptedForAgreement = true
}
if Agreed || altPort != "" {
err := client.AgreeToTOS()
if err != nil {
return errors.New("error agreeing to updated terms: " + err.Error())
}
goto Obtain
}
}
// If user did not agree or it was any other kind of error, just append to the list of errors
errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
}
return errors.New(errMsg)
err := clientObtain(client, []string{cfg.Host}, altPort == "")
if err != nil {
return err
}
}
}
@ -447,6 +423,49 @@ func redirPlaintextHost(cfg server.Config) server.Config {
}
}
// clientObtain uses client to obtain a single certificate for domains in names. If
// the user is present to provide an email address, pass in true for allowPrompt,
// otherwise pass in false. If err == nil, the certificate (and key) will be saved
// to disk in the storage folder.
func clientObtain(client *acme.Client, names []string, allowPrompt bool) error {
certificate, failures := client.ObtainCertificate(names, true, nil)
if len(failures) > 0 {
// Error - either try to fix it or report them it to the user and abort
var errMsg string // we'll combine all the failures into a single error message
var promptedForAgreement bool // only prompt user for agreement at most once
for errDomain, obtainErr := range failures {
// TODO: Double-check, will obtainErr ever be nil?
if tosErr, ok := obtainErr.(acme.TOSError); ok {
// Terms of Service agreement error; we can probably deal with this
if !Agreed && !promptedForAgreement && allowPrompt { // don't prompt if server is already running
Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
promptedForAgreement = true
}
if Agreed || !allowPrompt {
err := client.AgreeToTOS()
if err != nil {
return errors.New("error agreeing to updated terms: " + err.Error())
}
return clientObtain(client, names, allowPrompt)
}
}
// If user did not agree or it was any other kind of error, just append to the list of errors
errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
}
return errors.New(errMsg)
}
// Success - immediately save the certificate resource
err := saveCertResource(certificate)
if err != nil {
return fmt.Errorf("error saving assets for %v: %v", names, err)
}
return nil
}
// Revoke revokes the certificate for host via ACME protocol.
func Revoke(host string) error {
if !existingCertAndKey(host) {

View file

@ -13,6 +13,7 @@ import (
"net/http"
"os"
"runtime"
"strings"
"sync"
"time"
)
@ -33,6 +34,7 @@ type Server struct {
startChan chan struct{} // used to block until server is finished starting
connTimeout time.Duration // the maximum duration of a graceful shutdown
ReqCallback OptionalCallback // if non-nil, is executed at the beginning of every request
SNICallback func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)
}
// ListenerFile represents a listener.
@ -206,17 +208,39 @@ func serveTLSWithSNI(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error {
// Here we diverge from the stdlib a bit by loading multiple certs/key pairs
// then we map the server names to their certs
var err error
config.Certificates = make([]tls.Certificate, len(tlsConfigs))
for i, tlsConfig := range tlsConfigs {
config.Certificates[i], err = tls.LoadX509KeyPair(tlsConfig.Certificate, tlsConfig.Key)
config.Certificates[i].OCSPStaple = tlsConfig.OCSPStaple
for _, tlsConfig := range tlsConfigs {
if tlsConfig.Certificate == "" || tlsConfig.Key == "" {
continue
}
cert, err := tls.LoadX509KeyPair(tlsConfig.Certificate, tlsConfig.Key)
if err != nil {
defer close(s.startChan)
return err
return fmt.Errorf("loading certificate and key pair: %v", err)
}
cert.OCSPStaple = tlsConfig.OCSPStaple
config.Certificates = append(config.Certificates, cert)
}
if len(config.Certificates) > 0 {
config.BuildNameToCertificate()
}
config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
// TODO: When Caddy starts, if it is to issue certs dynamically, we need
// terms agreement and an email address. make sure this is enforced at server
// start if the Caddyfile enables dynamic certificate issuance!
// Check NameToCertificate like the std lib does in "getCertificate" (unexported, bah)
cert := GetCertificateFromCache(clientHello, config.NameToCertificate)
if cert != nil {
return cert, nil
}
if s.SNICallback != nil {
return s.SNICallback(clientHello)
}
return nil, nil
}
config.BuildNameToCertificate()
// Customize our TLS configuration
config.MinVersion = tlsConfigs[0].ProtocolMinVersion
@ -225,7 +249,7 @@ func serveTLSWithSNI(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error {
config.PreferServerCipherSuites = tlsConfigs[0].PreferServerCipherSuites
// TLS client authentication, if user enabled it
err = setupClientAuth(tlsConfigs, config)
err := setupClientAuth(tlsConfigs, config)
if err != nil {
defer close(s.startChan)
return err
@ -242,6 +266,36 @@ func serveTLSWithSNI(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error {
return s.Server.Serve(ln)
}
// Borrowed from the Go standard library, crypto/tls pacakge, common.go.
// It has been modified to fit this program.
// Original license:
//
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
func GetCertificateFromCache(clientHello *tls.ClientHelloInfo, cache map[string]*tls.Certificate) *tls.Certificate {
name := strings.ToLower(clientHello.ServerName)
for len(name) > 0 && name[len(name)-1] == '.' {
name = name[:len(name)-1]
}
// exact match? great! use it
if cert, ok := cache[name]; ok {
return cert
}
// try replacing labels in the name with wildcards until we get a match.
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if cert, ok := cache[candidate]; ok {
return cert
}
}
return nil
}
// Stop stops the server. It blocks until the server is
// totally stopped. On POSIX systems, it will wait for
// connections to close (up to a max timeout of a few