mirror of
https://github.com/caddyserver/caddy.git
synced 2025-02-05 08:38:26 +03:00
tls: Refactor internals related to TLS configurations (#1466)
* tls: Refactor TLS config innards with a few minor syntax changes muststaple -> must_staple "http2 off" -> "alpn" with list of ALPN values * Fix typo * Fix QUIC handler * Inline struct field assignments
This commit is contained in:
parent
4b877eebc4
commit
73794f2a2c
13 changed files with 316 additions and 301 deletions
|
@ -79,7 +79,7 @@ func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error {
|
|||
cfg.TLS.Enabled = true
|
||||
cfg.Addr.Scheme = "https"
|
||||
if loadCertificates && caddytls.HostQualifies(cfg.Addr.Host) {
|
||||
_, err := caddytls.CacheManagedCertificate(cfg.Addr.Host, cfg.TLS)
|
||||
_, err := cfg.TLS.CacheManagedCertificate(cfg.Addr.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -35,6 +35,11 @@ type tlsHandler struct {
|
|||
// Halderman, et. al. in "The Security Impact of HTTPS Interception" (NDSS '17):
|
||||
// https://jhalderm.com/pub/papers/interception-ndss17.pdf
|
||||
func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if h.listener == nil {
|
||||
h.next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
h.listener.helloInfosMu.RLock()
|
||||
info := h.listener.helloInfos[r.RemoteAddr]
|
||||
h.listener.helloInfosMu.RUnlock()
|
||||
|
@ -78,63 +83,62 @@ func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
h.next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// clientHelloConn reads the ClientHello
|
||||
// and stores it in the attached listener.
|
||||
type clientHelloConn struct {
|
||||
net.Conn
|
||||
readHello bool
|
||||
listener *tlsHelloListener
|
||||
readHello bool // whether ClientHello has been read
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
// Read reads from c.Conn (by letting the standard library
|
||||
// do the reading off the wire), with the exception of
|
||||
// getting a copy of the ClientHello so it can parse it.
|
||||
func (c *clientHelloConn) Read(b []byte) (n int, err error) {
|
||||
if !c.readHello {
|
||||
// Read the header bytes.
|
||||
hdr := make([]byte, 5)
|
||||
n, err := io.ReadFull(c.Conn, hdr)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Get the length of the ClientHello message and read it as well.
|
||||
length := uint16(hdr[3])<<8 | uint16(hdr[4])
|
||||
hello := make([]byte, int(length))
|
||||
n, err = io.ReadFull(c.Conn, hello)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Parse the ClientHello and store it in the map.
|
||||
rawParsed := parseRawClientHello(hello)
|
||||
c.listener.helloInfosMu.Lock()
|
||||
c.listener.helloInfos[c.Conn.RemoteAddr().String()] = rawParsed
|
||||
c.listener.helloInfosMu.Unlock()
|
||||
|
||||
// Since we buffered the header and ClientHello, pretend we were
|
||||
// never here by lining up the buffered values to be read with a
|
||||
// custom connection type, followed by the rest of the actual
|
||||
// underlying connection.
|
||||
mr := io.MultiReader(bytes.NewReader(hdr), bytes.NewReader(hello), c.Conn)
|
||||
mc := multiConn{Conn: c.Conn, reader: mr}
|
||||
|
||||
c.Conn = mc
|
||||
|
||||
c.readHello = true
|
||||
// if we've already read the ClientHello, pass thru
|
||||
if c.readHello {
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
// multiConn is a net.Conn that reads from the
|
||||
// given reader instead of the wire directly. This
|
||||
// is useful when some of the connection has already
|
||||
// been read (like the TLS Client Hello) and the
|
||||
// reader is a io.MultiReader that starts with
|
||||
// the contents of the buffer.
|
||||
type multiConn struct {
|
||||
net.Conn
|
||||
reader io.Reader
|
||||
}
|
||||
// we let the standard lib read off the wire for us, and
|
||||
// tee that into our buffer so we can read the ClientHello
|
||||
tee := io.TeeReader(c.Conn, c.buf)
|
||||
n, err = tee.Read(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if c.buf.Len() < 5 {
|
||||
return // need to read more bytes for header
|
||||
}
|
||||
|
||||
// Read reads from mc.reader.
|
||||
func (mc multiConn) Read(b []byte) (n int, err error) {
|
||||
return mc.reader.Read(b)
|
||||
// read the header bytes
|
||||
hdr := make([]byte, 5)
|
||||
_, err = io.ReadFull(c.buf, hdr)
|
||||
if err != nil {
|
||||
return // this would be highly unusual and sad
|
||||
}
|
||||
|
||||
// get length of the ClientHello message and read it
|
||||
length := int(uint16(hdr[3])<<8 | uint16(hdr[4]))
|
||||
if c.buf.Len() < length {
|
||||
return // need to read more bytes
|
||||
}
|
||||
hello := make([]byte, length)
|
||||
_, err = io.ReadFull(c.buf, hello)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.buf = nil // buffer no longer needed
|
||||
|
||||
// parse the ClientHello and store it in the map
|
||||
rawParsed := parseRawClientHello(hello)
|
||||
c.listener.helloInfosMu.Lock()
|
||||
c.listener.helloInfos[c.Conn.RemoteAddr().String()] = rawParsed
|
||||
c.listener.helloInfosMu.Unlock()
|
||||
|
||||
c.readHello = true
|
||||
return
|
||||
}
|
||||
|
||||
// parseRawClientHello parses data which contains the raw
|
||||
|
@ -279,7 +283,7 @@ func (l *tlsHelloListener) Accept() (net.Conn, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
helloConn := &clientHelloConn{Conn: conn, listener: l}
|
||||
helloConn := &clientHelloConn{Conn: conn, listener: l, buf: new(bytes.Buffer)}
|
||||
return tls.Server(helloConn, l.config), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -84,7 +84,7 @@ func TestHeuristicFunctions(t *testing.T) {
|
|||
// clientHello pairs a User-Agent string to its ClientHello message.
|
||||
type clientHello struct {
|
||||
userAgent string
|
||||
helloHex string
|
||||
helloHex string // do NOT include the header, just the ClientHello message
|
||||
}
|
||||
|
||||
// clientHellos groups samples of true (real) ClientHellos by the
|
||||
|
@ -158,7 +158,12 @@ func TestHeuristicFunctions(t *testing.T) {
|
|||
},
|
||||
{
|
||||
// IE 11 on Windows 7, this connection was intercepted by Blue Coat
|
||||
helloHex: "010000b1030358a3f3bae627f464da8cb35976b88e9119640032d41e62a107d608ed8d3e62b9000034c028c027c014c013009f009e009d009cc02cc02bc024c023c00ac009003d003c0035002f006a004000380032000a0013000500040100005400000014001200000f66696e6572706978656c732e636f6d000500050100000000000a00080006001700180019000b00020100000d0014001206010603040105010201040305030203020200170000ff01000100",
|
||||
helloHex: `010000b1030358a3f3bae627f464da8cb35976b88e9119640032d41e62a107d608ed8d3e62b9000034c028c027c014c013009f009e009d009cc02cc02bc024c023c00ac009003d003c0035002f006a004000380032000a0013000500040100005400000014001200000f66696e6572706978656c732e636f6d000500050100000000000a00080006001700180019000b00020100000d0014001206010603040105010201040305030203020200170000ff01000100`,
|
||||
},
|
||||
{
|
||||
// Firefox 51.0.1 being intercepted by burp 1.7.17
|
||||
userAgent: "(TODO)",
|
||||
helloHex: `010000d8030358a92f4daca95acc2f6a10a9c50d736135eae39406d3090238464540d482677600003ac023c027003cc025c02900670040c009c013002fc004c00e00330032c02bc02f009cc02dc031009e00a2c008c012000ac003c00d0016001300ff01000075000a0034003200170001000300130015000600070009000a0018000b000c0019000d000e000f001000110002001200040005001400080016000b00020100000d00180016060306010503050104030401040202030201020201010000001700150000126a61677561722e6b796877616e612e6f7267`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
@ -31,40 +31,47 @@ type Server struct {
|
|||
connTimeout time.Duration // max time to wait for a connection before force stop
|
||||
tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine
|
||||
vhosts *vhostTrie
|
||||
tlsConfig caddytls.ConfigGroup
|
||||
}
|
||||
|
||||
// ensure it satisfies the interface
|
||||
var _ caddy.GracefulServer = new(Server)
|
||||
|
||||
var defaultALPN = []string{"h2", "http/1.1"}
|
||||
|
||||
// makeTLSConfig extracts TLS settings from each site config to
|
||||
// build a tls.Config usable in Caddy HTTP servers. The returned
|
||||
// config will be nil if TLS is disabled for these sites.
|
||||
func makeTLSConfig(group []*SiteConfig) (*tls.Config, error) {
|
||||
var tlsConfigs []*caddytls.Config
|
||||
for i := range group {
|
||||
if HTTP2 && len(group[i].TLS.ALPN) == 0 {
|
||||
// if no application-level protocol was configured up to now,
|
||||
// default to HTTP/2, then HTTP/1.1 if necessary
|
||||
group[i].TLS.ALPN = defaultALPN
|
||||
}
|
||||
tlsConfigs = append(tlsConfigs, group[i].TLS)
|
||||
}
|
||||
return caddytls.MakeTLSConfig(tlsConfigs)
|
||||
}
|
||||
|
||||
// NewServer creates a new Server instance that will listen on addr
|
||||
// and will serve the sites configured in group.
|
||||
func NewServer(addr string, group []*SiteConfig) (*Server, error) {
|
||||
s := &Server{
|
||||
Server: makeHTTPServer(addr, group),
|
||||
Server: makeHTTPServerWithTimeouts(addr, group),
|
||||
vhosts: newVHostTrie(),
|
||||
sites: group,
|
||||
connTimeout: GracefulTimeout,
|
||||
}
|
||||
|
||||
s.Server.Handler = s // this is weird, but whatever
|
||||
tlsh := &tlsHandler{next: s.Server.Handler}
|
||||
s.Server.ConnState = func(c net.Conn, cs http.ConnState) {
|
||||
// when a connection closes or is hijacked, delete its entry
|
||||
// in the map, because we are done with it.
|
||||
if tlsh.listener != nil {
|
||||
if cs == http.StateHijacked || cs == http.StateClosed {
|
||||
tlsh.listener.helloInfosMu.Lock()
|
||||
delete(tlsh.listener.helloInfos, c.RemoteAddr().String())
|
||||
tlsh.listener.helloInfosMu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Disable HTTP/2 if desired
|
||||
if !HTTP2 {
|
||||
s.Server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
|
||||
// extract TLS settings from each site config to build
|
||||
// a tls.Config, which will not be nil if TLS is enabled
|
||||
tlsConfig, err := makeTLSConfig(group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.Server.TLSConfig = tlsConfig
|
||||
|
||||
// Enable QUIC if desired
|
||||
if QUIC {
|
||||
|
@ -72,41 +79,36 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
|
|||
s.Server.Handler = s.wrapWithSvcHeaders(s.Server.Handler)
|
||||
}
|
||||
|
||||
// Set up TLS configuration
|
||||
tlsConfigs := make(caddytls.ConfigGroup)
|
||||
var allConfigs []*caddytls.Config
|
||||
|
||||
for _, site := range group {
|
||||
|
||||
if err := site.TLS.Build(tlsConfigs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfigs[site.TLS.Hostname] = site.TLS
|
||||
allConfigs = append(allConfigs, site.TLS)
|
||||
}
|
||||
|
||||
// Check if configs are valid
|
||||
if err := caddytls.CheckConfigs(allConfigs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.tlsConfig = tlsConfigs
|
||||
|
||||
if caddytls.HasTLSEnabled(allConfigs) {
|
||||
s.Server.TLSConfig = &tls.Config{
|
||||
GetConfigForClient: s.tlsConfig.GetConfigForClient,
|
||||
GetCertificate: s.tlsConfig.GetCertificate,
|
||||
}
|
||||
}
|
||||
|
||||
// As of Go 1.7, HTTP/2 is enabled only if NextProtos includes the string "h2"
|
||||
if HTTP2 && s.Server.TLSConfig != nil && len(s.Server.TLSConfig.NextProtos) == 0 {
|
||||
s.Server.TLSConfig.NextProtos = []string{"h2"}
|
||||
}
|
||||
|
||||
// if TLS is enabled, make sure we prepare the Server accordingly
|
||||
if s.Server.TLSConfig != nil {
|
||||
s.Server.Handler = tlsh
|
||||
// wrap the HTTP handler with a handler that does MITM detection
|
||||
tlsh := &tlsHandler{next: s.Server.Handler}
|
||||
s.Server.Handler = tlsh // this needs to be the "outer" handler when Serve() is called, for type assertion
|
||||
|
||||
// when Serve() creates the TLS listener later, that listener should
|
||||
// be adding a reference the ClientHello info to a map; this callback
|
||||
// will be sure to clear out that entry when the connection closes.
|
||||
s.Server.ConnState = func(c net.Conn, cs http.ConnState) {
|
||||
// when a connection closes or is hijacked, delete its entry
|
||||
// in the map, because we are done with it.
|
||||
if tlsh.listener != nil {
|
||||
if cs == http.StateHijacked || cs == http.StateClosed {
|
||||
tlsh.listener.helloInfosMu.Lock()
|
||||
delete(tlsh.listener.helloInfos, c.RemoteAddr().String())
|
||||
tlsh.listener.helloInfosMu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// As of Go 1.7, if the Server's TLSConfig is not nil, HTTP/2 is enabled only
|
||||
// if TLSConfig.NextProtos includes the string "h2"
|
||||
if HTTP2 && len(s.Server.TLSConfig.NextProtos) == 0 {
|
||||
// some experimenting shows that this NextProtos must have at least
|
||||
// one value that overlaps with the NextProtos of any other tls.Config
|
||||
// that is returned from GetConfigForClient; if there is no overlap,
|
||||
// the connection will fail (as of Go 1.8, Feb. 2017).
|
||||
s.Server.TLSConfig.NextProtos = defaultALPN
|
||||
}
|
||||
}
|
||||
|
||||
// Compile custom middleware for every site (enables virtual hosting)
|
||||
|
@ -122,6 +124,61 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
|
|||
return s, nil
|
||||
}
|
||||
|
||||
// makeHTTPServerWithTimeouts makes an http.Server from the group of
|
||||
// configs in a way that configures timeouts (or, if not set, it uses
|
||||
// the default timeouts) by combining the configuration of each
|
||||
// SiteConfig in the group. (Timeouts are important for mitigating
|
||||
// slowloris attacks.)
|
||||
func makeHTTPServerWithTimeouts(addr string, group []*SiteConfig) *http.Server {
|
||||
// find the minimum duration configured for each timeout
|
||||
var min Timeouts
|
||||
for _, cfg := range group {
|
||||
if cfg.Timeouts.ReadTimeoutSet &&
|
||||
(!min.ReadTimeoutSet || cfg.Timeouts.ReadTimeout < min.ReadTimeout) {
|
||||
min.ReadTimeoutSet = true
|
||||
min.ReadTimeout = cfg.Timeouts.ReadTimeout
|
||||
}
|
||||
if cfg.Timeouts.ReadHeaderTimeoutSet &&
|
||||
(!min.ReadHeaderTimeoutSet || cfg.Timeouts.ReadHeaderTimeout < min.ReadHeaderTimeout) {
|
||||
min.ReadHeaderTimeoutSet = true
|
||||
min.ReadHeaderTimeout = cfg.Timeouts.ReadHeaderTimeout
|
||||
}
|
||||
if cfg.Timeouts.WriteTimeoutSet &&
|
||||
(!min.WriteTimeoutSet || cfg.Timeouts.WriteTimeout < min.WriteTimeout) {
|
||||
min.WriteTimeoutSet = true
|
||||
min.WriteTimeout = cfg.Timeouts.WriteTimeout
|
||||
}
|
||||
if cfg.Timeouts.IdleTimeoutSet &&
|
||||
(!min.IdleTimeoutSet || cfg.Timeouts.IdleTimeout < min.IdleTimeout) {
|
||||
min.IdleTimeoutSet = true
|
||||
min.IdleTimeout = cfg.Timeouts.IdleTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// for the values that were not set, use defaults
|
||||
if !min.ReadTimeoutSet {
|
||||
min.ReadTimeout = defaultTimeouts.ReadTimeout
|
||||
}
|
||||
if !min.ReadHeaderTimeoutSet {
|
||||
min.ReadHeaderTimeout = defaultTimeouts.ReadHeaderTimeout
|
||||
}
|
||||
if !min.WriteTimeoutSet {
|
||||
min.WriteTimeout = defaultTimeouts.WriteTimeout
|
||||
}
|
||||
if !min.IdleTimeoutSet {
|
||||
min.IdleTimeout = defaultTimeouts.IdleTimeout
|
||||
}
|
||||
|
||||
// set the final values on the server and return it
|
||||
return &http.Server{
|
||||
Addr: addr,
|
||||
ReadTimeout: min.ReadTimeout,
|
||||
ReadHeaderTimeout: min.ReadHeaderTimeout,
|
||||
WriteTimeout: min.WriteTimeout,
|
||||
IdleTimeout: min.IdleTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) wrapWithSvcHeaders(previousHandler http.Handler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
s.quicServer.SetQuicHeaders(w.Header())
|
||||
|
@ -390,62 +447,6 @@ var defaultTimeouts = Timeouts{
|
|||
IdleTimeout: 2 * time.Minute,
|
||||
}
|
||||
|
||||
// makeHTTPServer makes an http.Server from the group of configs
|
||||
// in a way that configures timeouts (or, if not set, it uses the
|
||||
// default timeouts) and other http.Server properties by combining
|
||||
// the configuration of each SiteConfig in the group. (Timeouts
|
||||
// are important for mitigating slowloris attacks.)
|
||||
func makeHTTPServer(addr string, group []*SiteConfig) *http.Server {
|
||||
s := &http.Server{Addr: addr}
|
||||
|
||||
// find the minimum duration configured for each timeout
|
||||
var min Timeouts
|
||||
for _, cfg := range group {
|
||||
if cfg.Timeouts.ReadTimeoutSet &&
|
||||
(!min.ReadTimeoutSet || cfg.Timeouts.ReadTimeout < min.ReadTimeout) {
|
||||
min.ReadTimeoutSet = true
|
||||
min.ReadTimeout = cfg.Timeouts.ReadTimeout
|
||||
}
|
||||
if cfg.Timeouts.ReadHeaderTimeoutSet &&
|
||||
(!min.ReadHeaderTimeoutSet || cfg.Timeouts.ReadHeaderTimeout < min.ReadHeaderTimeout) {
|
||||
min.ReadHeaderTimeoutSet = true
|
||||
min.ReadHeaderTimeout = cfg.Timeouts.ReadHeaderTimeout
|
||||
}
|
||||
if cfg.Timeouts.WriteTimeoutSet &&
|
||||
(!min.WriteTimeoutSet || cfg.Timeouts.WriteTimeout < min.WriteTimeout) {
|
||||
min.WriteTimeoutSet = true
|
||||
min.WriteTimeout = cfg.Timeouts.WriteTimeout
|
||||
}
|
||||
if cfg.Timeouts.IdleTimeoutSet &&
|
||||
(!min.IdleTimeoutSet || cfg.Timeouts.IdleTimeout < min.IdleTimeout) {
|
||||
min.IdleTimeoutSet = true
|
||||
min.IdleTimeout = cfg.Timeouts.IdleTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// for the values that were not set, use defaults
|
||||
if !min.ReadTimeoutSet {
|
||||
min.ReadTimeout = defaultTimeouts.ReadTimeout
|
||||
}
|
||||
if !min.ReadHeaderTimeoutSet {
|
||||
min.ReadHeaderTimeout = defaultTimeouts.ReadHeaderTimeout
|
||||
}
|
||||
if !min.WriteTimeoutSet {
|
||||
min.WriteTimeout = defaultTimeouts.WriteTimeout
|
||||
}
|
||||
if !min.IdleTimeoutSet {
|
||||
min.IdleTimeout = defaultTimeouts.IdleTimeout
|
||||
}
|
||||
|
||||
// set the final values on the server
|
||||
s.ReadTimeout = min.ReadTimeout
|
||||
s.ReadHeaderTimeout = min.ReadHeaderTimeout
|
||||
s.WriteTimeout = min.WriteTimeout
|
||||
s.IdleTimeout = min.IdleTimeout
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
|
||||
// connections. It's used by ListenAndServe and ListenAndServeTLS so
|
||||
// dead TCP connections (e.g. closing laptop mid-download) eventually
|
||||
|
|
|
@ -92,7 +92,7 @@ func TestMakeHTTPServer(t *testing.T) {
|
|||
},
|
||||
},
|
||||
} {
|
||||
actual := makeHTTPServer("127.0.0.1:9005", tc.group)
|
||||
actual := makeHTTPServerWithTimeouts("127.0.0.1:9005", tc.group)
|
||||
|
||||
if got, want := actual.Addr, "127.0.0.1:9005"; got != want {
|
||||
t.Errorf("Test %d: Expected Addr=%s, but was %s", i, want, got)
|
||||
|
|
|
@ -89,8 +89,8 @@ func getCertificate(name string) (cert Certificate, matched, defaulted bool) {
|
|||
// cache, flagging it as Managed and, if onDemand is true, as "OnDemand"
|
||||
// (meaning that it was obtained or loaded during a TLS handshake).
|
||||
//
|
||||
// This function is safe for concurrent use.
|
||||
func CacheManagedCertificate(domain string, cfg *Config) (Certificate, error) {
|
||||
// This method is safe for concurrent use.
|
||||
func (cfg *Config) CacheManagedCertificate(domain string) (Certificate, error) {
|
||||
storage, err := cfg.StorageFor(cfg.CAUrl)
|
||||
if err != nil {
|
||||
return Certificate{}, err
|
||||
|
|
|
@ -109,11 +109,11 @@ type Config struct {
|
|||
// Add the must staple TLS extension to the CSR generated by lego/acme
|
||||
MustStaple bool
|
||||
|
||||
// Disables HTTP2 completely
|
||||
DisableHTTP2 bool
|
||||
// The list of protocols to choose from for Application Layer
|
||||
// Protocol Negotiation (ALPN).
|
||||
ALPN []string
|
||||
|
||||
// Holds final tls.Config
|
||||
tlsConfig *tls.Config
|
||||
tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig()
|
||||
}
|
||||
|
||||
// OnDemandState contains some state relevant for providing
|
||||
|
@ -223,33 +223,20 @@ func (c *Config) StorageFor(caURL string) (Storage, error) {
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (cfg *Config) Build(group ConfigGroup) error {
|
||||
config, err := cfg.build()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if config != nil {
|
||||
cfg.tlsConfig = config
|
||||
cfg.tlsConfig.GetCertificate = group.GetCertificate
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (cfg *Config) build() (*tls.Config, error) {
|
||||
config := new(tls.Config)
|
||||
|
||||
// buildStandardTLSConfig converts cfg (*caddytls.Config) to a *tls.Config
|
||||
// and stores it in cfg so it can be used in servers. If TLS is disabled,
|
||||
// no tls.Config is created.
|
||||
func (cfg *Config) buildStandardTLSConfig() error {
|
||||
if !cfg.Enabled {
|
||||
return nil, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
config := new(tls.Config)
|
||||
|
||||
ciphersAdded := make(map[uint16]struct{})
|
||||
curvesAdded := make(map[tls.CurveID]struct{})
|
||||
|
||||
// Add cipher suites
|
||||
// add cipher suites
|
||||
for _, ciph := range cfg.Ciphers {
|
||||
if _, ok := ciphersAdded[ciph]; !ok {
|
||||
ciphersAdded[ciph] = struct{}{}
|
||||
|
@ -259,7 +246,7 @@ func (cfg *Config) build() (*tls.Config, error) {
|
|||
|
||||
config.PreferServerCipherSuites = cfg.PreferServerCipherSuites
|
||||
|
||||
// Union curves
|
||||
// add curve preferences
|
||||
for _, curv := range cfg.CurvePreferences {
|
||||
if _, ok := curvesAdded[curv]; !ok {
|
||||
curvesAdded[curv] = struct{}{}
|
||||
|
@ -270,8 +257,10 @@ func (cfg *Config) build() (*tls.Config, error) {
|
|||
config.MinVersion = cfg.ProtocolMinVersion
|
||||
config.MaxVersion = cfg.ProtocolMaxVersion
|
||||
config.ClientAuth = cfg.ClientAuth
|
||||
config.NextProtos = cfg.ALPN
|
||||
config.GetCertificate = cfg.GetCertificate
|
||||
|
||||
// Set up client authentication if enabled
|
||||
// set up client authentication if enabled
|
||||
if config.ClientAuth != tls.NoClientCert {
|
||||
pool := x509.NewCertPool()
|
||||
clientCertsAdded := make(map[string]struct{})
|
||||
|
@ -286,45 +275,51 @@ func (cfg *Config) build() (*tls.Config, error) {
|
|||
// Any client with a certificate from this CA will be allowed to connect
|
||||
caCrt, err := ioutil.ReadFile(caFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if !pool.AppendCertsFromPEM(caCrt) {
|
||||
return nil, fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile)
|
||||
return fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile)
|
||||
}
|
||||
}
|
||||
|
||||
config.ClientCAs = pool
|
||||
}
|
||||
|
||||
// Default cipher suites
|
||||
// default cipher suites
|
||||
if len(config.CipherSuites) == 0 {
|
||||
config.CipherSuites = defaultCiphers
|
||||
}
|
||||
|
||||
// For security, ensure TLS_FALLBACK_SCSV is always included first
|
||||
// for security, ensure TLS_FALLBACK_SCSV is always included first
|
||||
if len(config.CipherSuites) == 0 || config.CipherSuites[0] != tls.TLS_FALLBACK_SCSV {
|
||||
config.CipherSuites = append([]uint16{tls.TLS_FALLBACK_SCSV}, config.CipherSuites...)
|
||||
}
|
||||
|
||||
if cfg.DisableHTTP2 {
|
||||
config.NextProtos = []string{}
|
||||
} else {
|
||||
config.NextProtos = []string{"h2"}
|
||||
}
|
||||
// store the resulting new tls.Config
|
||||
cfg.tlsConfig = config
|
||||
|
||||
return config, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckConfigs checks if multiple TLS configs does not collide with each other
|
||||
func CheckConfigs(configs []*Config) error {
|
||||
// MakeTLSConfig makes a tls.Config from configs. The returned
|
||||
// tls.Config is programmed to load the matching caddytls.Config
|
||||
// based on the hostname in SNI, but that's all.
|
||||
func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
|
||||
if len(configs) == 0 {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for i, cfg := range configs {
|
||||
configMap := make(configGroup)
|
||||
|
||||
// Can't serve TLS and not-TLS on same port
|
||||
for i, cfg := range configs {
|
||||
if cfg == nil {
|
||||
// avoid nil pointer dereference below this loop
|
||||
configs[i] = new(Config)
|
||||
continue
|
||||
}
|
||||
|
||||
// can't serve TLS and non-TLS on same port
|
||||
if i > 0 && cfg.Enabled != configs[i-1].Enabled {
|
||||
thisConfProto, lastConfProto := "not TLS", "not TLS"
|
||||
if cfg.Enabled {
|
||||
|
@ -333,26 +328,33 @@ func CheckConfigs(configs []*Config) error {
|
|||
if configs[i-1].Enabled {
|
||||
lastConfProto = "TLS"
|
||||
}
|
||||
return fmt.Errorf("cannot multiplex %s (%s) and %s (%s) on same listener",
|
||||
return nil, fmt.Errorf("cannot multiplex %s (%s) and %s (%s) on same listener",
|
||||
configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto)
|
||||
}
|
||||
|
||||
if !cfg.Enabled {
|
||||
continue
|
||||
// convert each caddytls.Config into a tls.Config
|
||||
if err := cfg.buildStandardTLSConfig(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Key this config by its hostname (overwriting
|
||||
// configs with the same hostname pattern); during
|
||||
// TLS handshakes, configs are loaded based on
|
||||
// the hostname pattern, according to client's SNI.
|
||||
configMap[cfg.Hostname] = cfg
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func HasTLSEnabled(configs []*Config) bool {
|
||||
for _, config := range configs {
|
||||
if config.Enabled {
|
||||
return true
|
||||
}
|
||||
// Is TLS disabled? By now, we know that all
|
||||
// configs agree whether it is or not, so we
|
||||
// can just look at the first one. If so,
|
||||
// we're done here.
|
||||
if len(configs) == 0 || !configs[0].Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return false
|
||||
return &tls.Config{
|
||||
GetConfigForClient: configMap.GetConfigForClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ConfigGetter gets a Config keyed by key.
|
||||
|
|
|
@ -8,50 +8,50 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func TestMakeTLSConfigProtocolVersions(t *testing.T) {
|
||||
func TestConvertTLSConfigProtocolVersions(t *testing.T) {
|
||||
// same min and max protocol versions
|
||||
config := Config{
|
||||
config := &Config{
|
||||
Enabled: true,
|
||||
ProtocolMinVersion: tls.VersionTLS12,
|
||||
ProtocolMaxVersion: tls.VersionTLS12,
|
||||
}
|
||||
result, err := config.build()
|
||||
err := config.buildStandardTLSConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("Did not expect an error, but got %v", err)
|
||||
}
|
||||
if got, want := result.MinVersion, uint16(tls.VersionTLS12); got != want {
|
||||
if got, want := config.tlsConfig.MinVersion, uint16(tls.VersionTLS12); got != want {
|
||||
t.Errorf("Expected min version to be %x, got %x", want, got)
|
||||
}
|
||||
if got, want := result.MaxVersion, uint16(tls.VersionTLS12); got != want {
|
||||
if got, want := config.tlsConfig.MaxVersion, uint16(tls.VersionTLS12); got != want {
|
||||
t.Errorf("Expected max version to be %x, got %x", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeTLSConfigPreferServerCipherSuites(t *testing.T) {
|
||||
func TestConvertTLSConfigPreferServerCipherSuites(t *testing.T) {
|
||||
// prefer server cipher suites
|
||||
config := Config{Enabled: true, PreferServerCipherSuites: true}
|
||||
result, err := config.build()
|
||||
err := config.buildStandardTLSConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("Did not expect an error, but got %v", err)
|
||||
}
|
||||
if got, want := result.PreferServerCipherSuites, true; got != want {
|
||||
if got, want := config.tlsConfig.PreferServerCipherSuites, true; got != want {
|
||||
t.Errorf("Expected PreferServerCipherSuites==%v but got %v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeTLSConfigTLSEnabledDisabled(t *testing.T) {
|
||||
func TestMakeTLSConfigTLSEnabledDisabledError(t *testing.T) {
|
||||
// verify handling when Enabled is true and false
|
||||
configs := []*Config{
|
||||
{Enabled: true},
|
||||
{Enabled: false},
|
||||
}
|
||||
err := CheckConfigs(configs)
|
||||
_, err := MakeTLSConfig(configs)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected an error, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeTLSConfigCipherSuites(t *testing.T) {
|
||||
func TestConvertTLSConfigCipherSuites(t *testing.T) {
|
||||
// ensure cipher suites are unioned and
|
||||
// that TLS_FALLBACK_SCSV is prepended
|
||||
configs := []*Config{
|
||||
|
@ -67,10 +67,13 @@ func TestMakeTLSConfigCipherSuites(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, config := range configs {
|
||||
cfg, _ := config.build()
|
||||
|
||||
if !reflect.DeepEqual(cfg.CipherSuites, expectedCiphers[i]) {
|
||||
t.Errorf("Expected ciphers %v but got %v", expectedCiphers[i], cfg.CipherSuites)
|
||||
err := config.buildStandardTLSConfig()
|
||||
if err != nil {
|
||||
t.Errorf("Test %d: Expected no error, got: %v", i, err)
|
||||
}
|
||||
if !reflect.DeepEqual(config.tlsConfig.CipherSuites, expectedCiphers[i]) {
|
||||
t.Errorf("Test %d: Expected ciphers %v but got %v",
|
||||
i, expectedCiphers[i], config.tlsConfig.CipherSuites)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -13,18 +13,19 @@ import (
|
|||
|
||||
// configGroup is a type that keys configs by their hostname
|
||||
// (hostnames can have wildcard characters; use the getConfig
|
||||
// method to get a config by matching its hostname). Its
|
||||
// GetCertificate function can be used with tls.Config.
|
||||
type ConfigGroup map[string]*Config
|
||||
// method to get a config by matching its hostname).
|
||||
type configGroup map[string]*Config
|
||||
|
||||
// getConfig gets the config by the first key match for name.
|
||||
// In other words, "sub.foo.bar" will get the config for "*.foo.bar"
|
||||
// if that is the closest match. This function MAY return nil
|
||||
// if no match is found.
|
||||
// if that is the closest match. If no match is found, the first
|
||||
// (random) config will be loaded, which will defer any TLS alerts
|
||||
// to the certificate validation (this may or may not be ideal;
|
||||
// let's talk about it if this becomes problematic).
|
||||
//
|
||||
// This function follows nearly the same logic to lookup
|
||||
// a hostname as the getCertificate function uses.
|
||||
func (cg ConfigGroup) getConfig(name string) *Config {
|
||||
func (cg configGroup) getConfig(name string) *Config {
|
||||
name = strings.ToLower(name)
|
||||
|
||||
// exact match? great, let's use it
|
||||
|
@ -42,14 +43,36 @@ func (cg ConfigGroup) getConfig(name string) *Config {
|
|||
}
|
||||
}
|
||||
|
||||
// as last resort, try a config that serves all names
|
||||
// as a fallback, try a config that serves all names
|
||||
if config, ok := cg[""]; ok {
|
||||
return config
|
||||
}
|
||||
|
||||
// as a last resort, use a random config
|
||||
// (even if the config isn't for that hostname,
|
||||
// it should help us serve clients without SNI
|
||||
// or at least defer TLS alerts to the cert)
|
||||
for _, config := range cg {
|
||||
return config
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConfigForClient gets a TLS configuration satisfying clientHello.
|
||||
// In getting the configuration, it abides the rules and settings
|
||||
// defined in the Config that matches clientHello.ServerName. If no
|
||||
// tls.Config is set on the matching Config, a nil value is returned.
|
||||
//
|
||||
// This method is safe for use as a tls.Config.GetConfigForClient callback.
|
||||
func (cg configGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
config := cg.getConfig(clientHello.ServerName)
|
||||
if config != nil {
|
||||
return config.tlsConfig, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetCertificate gets a certificate to satisfy clientHello. In getting
|
||||
// the certificate, it abides the rules and settings defined in the
|
||||
// Config that matches clientHello.ServerName. It first checks the in-
|
||||
|
@ -58,27 +81,11 @@ func (cg ConfigGroup) getConfig(name string) *Config {
|
|||
// via ACME.
|
||||
//
|
||||
// This method is safe for use as a tls.Config.GetCertificate callback.
|
||||
func (cg ConfigGroup) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
cert, err := cg.getCertDuringHandshake(strings.ToLower(clientHello.ServerName), true, true)
|
||||
func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
cert, err := cfg.getCertDuringHandshake(strings.ToLower(clientHello.ServerName), true, true)
|
||||
return &cert.Certificate, err
|
||||
}
|
||||
|
||||
// GetConfigForClient gets a TLS configuration satisfying clientHello. In getting
|
||||
// the configuration, it abides the rules and settings defined in the
|
||||
// Config that matches clientHello.ServerName.
|
||||
//
|
||||
// This method is safe for use as a tls.Config.GetConfigForClient callback.
|
||||
func (cg ConfigGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
|
||||
config := cg.getConfig(clientHello.ServerName)
|
||||
|
||||
if config != nil {
|
||||
return config.tlsConfig, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// getCertDuringHandshake will get a certificate for name. It first tries
|
||||
// the in-memory cache. If no certificate for name is in the cache, the
|
||||
// config most closely corresponding to name will be loaded. If that config
|
||||
|
@ -90,21 +97,20 @@ func (cg ConfigGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls
|
|||
// certificate is available.
|
||||
//
|
||||
// This function is safe for concurrent use.
|
||||
func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
|
||||
func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
|
||||
// First check our in-memory cache to see if we've already loaded it
|
||||
cert, matched, defaulted := getCertificate(name)
|
||||
if matched {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// Get the relevant TLS config for this name. If OnDemand is enabled,
|
||||
// then we might be able to load or obtain a needed certificate.
|
||||
cfg := cg.getConfig(name)
|
||||
if cfg != nil && cfg.OnDemand && loadIfNecessary {
|
||||
// If OnDemand is enabled, then we might be able to load or
|
||||
// obtain a needed certificate
|
||||
if cfg.OnDemand && loadIfNecessary {
|
||||
// Then check to see if we have one on disk
|
||||
loadedCert, err := CacheManagedCertificate(name, cfg)
|
||||
loadedCert, err := cfg.CacheManagedCertificate(name)
|
||||
if err == nil {
|
||||
loadedCert, err = cg.handshakeMaintenance(name, loadedCert)
|
||||
loadedCert, err = cfg.handshakeMaintenance(name, loadedCert)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err)
|
||||
}
|
||||
|
@ -116,7 +122,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai
|
|||
name = strings.ToLower(name)
|
||||
|
||||
// Make sure aren't over any applicable limits
|
||||
err := cg.checkLimitsForObtainingNewCerts(name, cfg)
|
||||
err := cfg.checkLimitsForObtainingNewCerts(name)
|
||||
if err != nil {
|
||||
return Certificate{}, err
|
||||
}
|
||||
|
@ -127,7 +133,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai
|
|||
}
|
||||
|
||||
// Obtain certificate from the CA
|
||||
return cg.obtainOnDemandCertificate(name, cfg)
|
||||
return cfg.obtainOnDemandCertificate(name)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -143,7 +149,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai
|
|||
// now according to mitigating factors we keep track of and preferences the
|
||||
// user has set. If a non-nil error is returned, do not issue a new certificate
|
||||
// for name.
|
||||
func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config) error {
|
||||
func (cfg *Config) checkLimitsForObtainingNewCerts(name string) error {
|
||||
// User can set hard limit for number of certs for the process to issue
|
||||
if cfg.OnDemandState.MaxObtain > 0 &&
|
||||
atomic.LoadInt32(&cfg.OnDemandState.ObtainedCount) >= cfg.OnDemandState.MaxObtain {
|
||||
|
@ -167,7 +173,7 @@ func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config)
|
|||
return fmt.Errorf("%s: throttled; last certificate was obtained %v ago", name, since)
|
||||
}
|
||||
|
||||
// 👍Good to go
|
||||
// Good to go 👍
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -176,7 +182,7 @@ func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config)
|
|||
// name, it will wait and use what the other goroutine obtained.
|
||||
//
|
||||
// This function is safe for use by multiple concurrent goroutines.
|
||||
func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certificate, error) {
|
||||
func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) {
|
||||
// We must protect this process from happening concurrently, so synchronize.
|
||||
obtainCertWaitChansMu.Lock()
|
||||
wait, ok := obtainCertWaitChans[name]
|
||||
|
@ -185,7 +191,7 @@ func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certi
|
|||
// wait for it to finish obtaining the cert and then we'll use it.
|
||||
obtainCertWaitChansMu.Unlock()
|
||||
<-wait
|
||||
return cg.getCertDuringHandshake(name, true, false)
|
||||
return cfg.getCertDuringHandshake(name, true, false)
|
||||
}
|
||||
|
||||
// looks like it's up to us to do all the work and obtain the cert.
|
||||
|
@ -228,19 +234,19 @@ func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certi
|
|||
lastIssueTimeMu.Unlock()
|
||||
|
||||
// certificate is already on disk; now just start over to load it and serve it
|
||||
return cg.getCertDuringHandshake(name, true, false)
|
||||
return cfg.getCertDuringHandshake(name, true, false)
|
||||
}
|
||||
|
||||
// handshakeMaintenance performs a check on cert for expiration and OCSP
|
||||
// validity.
|
||||
//
|
||||
// This function is safe for use by multiple concurrent goroutines.
|
||||
func (cg ConfigGroup) handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
|
||||
func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
|
||||
// Check cert expiration
|
||||
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
|
||||
if timeLeft < RenewDurationBefore {
|
||||
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
|
||||
return cg.renewDynamicCertificate(name, cert.Config)
|
||||
return cfg.renewDynamicCertificate(name)
|
||||
}
|
||||
|
||||
// Check OCSP staple validity
|
||||
|
@ -268,7 +274,7 @@ func (cg ConfigGroup) handshakeMaintenance(name string, cert Certificate) (Certi
|
|||
// usable. name should already be lower-cased before calling this function.
|
||||
//
|
||||
// This function is safe for use by multiple concurrent goroutines.
|
||||
func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certificate, error) {
|
||||
func (cfg *Config) renewDynamicCertificate(name string) (Certificate, error) {
|
||||
obtainCertWaitChansMu.Lock()
|
||||
wait, ok := obtainCertWaitChans[name]
|
||||
if ok {
|
||||
|
@ -276,7 +282,7 @@ func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certifi
|
|||
// wait for it to finish, then we'll use the new one.
|
||||
obtainCertWaitChansMu.Unlock()
|
||||
<-wait
|
||||
return cg.getCertDuringHandshake(name, true, false)
|
||||
return cfg.getCertDuringHandshake(name, true, false)
|
||||
}
|
||||
|
||||
// looks like it's up to us to do all the work and renew the cert
|
||||
|
@ -300,7 +306,7 @@ func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certifi
|
|||
return Certificate{}, err
|
||||
}
|
||||
|
||||
return cg.getCertDuringHandshake(name, true, false)
|
||||
return cfg.getCertDuringHandshake(name, true, false)
|
||||
}
|
||||
|
||||
// obtainCertWaitChans is used to coordinate obtaining certs for each hostname.
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
func TestGetCertificate(t *testing.T) {
|
||||
defer func() { certCache = make(map[string]Certificate) }()
|
||||
|
||||
cg := make(ConfigGroup)
|
||||
cfg := new(Config)
|
||||
|
||||
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
|
||||
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
|
||||
|
@ -17,10 +17,10 @@ func TestGetCertificate(t *testing.T) {
|
|||
helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"}
|
||||
|
||||
// When cache is empty
|
||||
if cert, err := cg.GetCertificate(hello); err == nil {
|
||||
if cert, err := cfg.GetCertificate(hello); err == nil {
|
||||
t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert)
|
||||
}
|
||||
if cert, err := cg.GetCertificate(helloNoSNI); err == nil {
|
||||
if cert, err := cfg.GetCertificate(helloNoSNI); err == nil {
|
||||
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
|
||||
}
|
||||
|
||||
|
@ -28,12 +28,12 @@ func TestGetCertificate(t *testing.T) {
|
|||
defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
|
||||
certCache[""] = defaultCert
|
||||
certCache["example.com"] = defaultCert
|
||||
if cert, err := cg.GetCertificate(hello); err != nil {
|
||||
if cert, err := cfg.GetCertificate(hello); err != nil {
|
||||
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
|
||||
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
||||
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
|
||||
}
|
||||
if cert, err := cg.GetCertificate(helloNoSNI); err != nil {
|
||||
if cert, err := cfg.GetCertificate(helloNoSNI); err != nil {
|
||||
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
|
||||
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
||||
t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert)
|
||||
|
@ -41,14 +41,14 @@ func TestGetCertificate(t *testing.T) {
|
|||
|
||||
// When retrieving wildcard certificate
|
||||
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}}
|
||||
if cert, err := cg.GetCertificate(helloSub); err != nil {
|
||||
if cert, err := cfg.GetCertificate(helloSub); err != nil {
|
||||
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
|
||||
} else if cert.Leaf.DNSNames[0] != "*.example.com" {
|
||||
t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
|
||||
}
|
||||
|
||||
// When no certificate matches, the default is returned
|
||||
if cert, err := cg.GetCertificate(helloNoMatch); err != nil {
|
||||
if cert, err := cfg.GetCertificate(helloNoMatch); err != nil {
|
||||
t.Errorf("Expected default certificate with no error when no matches, got err: %v", err)
|
||||
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
||||
t.Errorf("Expected default cert with no matches, got: %v", cert)
|
||||
|
|
|
@ -152,7 +152,7 @@ func RenewManagedCertificates(allowPrompts bool) (err error) {
|
|||
delete(certCache, "")
|
||||
certCacheMu.Unlock()
|
||||
}
|
||||
_, err := CacheManagedCertificate(cert.Names[0], cert.Config)
|
||||
_, err := cert.Config.CacheManagedCertificate(cert.Names[0])
|
||||
if err != nil {
|
||||
if allowPrompts {
|
||||
return err // operator is present, so report error immediately
|
||||
|
|
|
@ -164,21 +164,15 @@ func setupTLS(c *caddy.Controller) error {
|
|||
return c.Errf("Unsupported Storage provider '%s'", args[0])
|
||||
}
|
||||
config.StorageProvider = args[0]
|
||||
|
||||
case "http2":
|
||||
case "alpn":
|
||||
args := c.RemainingArgs()
|
||||
if len(args) != 1 {
|
||||
if len(args) == 0 {
|
||||
return c.ArgErr()
|
||||
}
|
||||
|
||||
switch args[0] {
|
||||
case "off":
|
||||
config.DisableHTTP2 = true
|
||||
default:
|
||||
c.ArgErr()
|
||||
for _, arg := range args {
|
||||
config.ALPN = append(config.ALPN, arg)
|
||||
}
|
||||
|
||||
case "muststaple":
|
||||
case "must_staple":
|
||||
config.MustStaple = true
|
||||
default:
|
||||
return c.Errf("Unknown keyword '%s'", c.Val())
|
||||
|
|
|
@ -91,8 +91,8 @@ func TestSetupParseBasic(t *testing.T) {
|
|||
t.Error("Expected PreferServerCipherSuites = true, but was false")
|
||||
}
|
||||
|
||||
if cfg.DisableHTTP2 {
|
||||
t.Error("Expected HTTP2 to be enabled by default")
|
||||
if len(cfg.ALPN) != 0 {
|
||||
t.Error("Expected ALPN empty by default")
|
||||
}
|
||||
|
||||
// Ensure curve count is correct
|
||||
|
@ -121,8 +121,8 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
|
|||
params := `tls ` + certFile + ` ` + keyFile + ` {
|
||||
protocols tls1.0 tls1.2
|
||||
ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384
|
||||
muststaple
|
||||
http2 off
|
||||
must_staple
|
||||
alpn http/1.1
|
||||
}`
|
||||
cfg := new(Config)
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
|
@ -149,8 +149,8 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
|
|||
t.Error("Expected must staple to be true")
|
||||
}
|
||||
|
||||
if !cfg.DisableHTTP2 {
|
||||
t.Error("Expected HTTP2 to be disabled")
|
||||
if len(cfg.ALPN) != 1 || cfg.ALPN[0] != "http/1.1" {
|
||||
t.Errorf("Expected ALPN to contain only 'http/1.1' but got: %v", cfg.ALPN)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue