diff --git a/ISSUE_TEMPLATE b/ISSUE_TEMPLATE index b49ec37cc..44e2caa6b 100644 --- a/ISSUE_TEMPLATE +++ b/ISSUE_TEMPLATE @@ -1,4 +1,4 @@ -(Are you asking for help with Caddy? Please use our forum instead: https://forum.caddyserver.com. If you are filing a bug report, please answer the following questions. If your issue is not a bug report, you do not need to use this template. Either way, please consider donating if we've helped you. Thanks!) +(Are you asking for help with using Caddy? Please use our forum instead: https://forum.caddyserver.com. If you are filing a bug report, please answer the following questions. If your issue is not a bug report, you do not need to use this template. Either way, please consider donating if we've helped you. Thanks!) #### 1. What version of Caddy are you running (`caddy -version`)? diff --git a/caddy/assets/path.go b/assets.go similarity index 57% rename from caddy/assets/path.go rename to assets.go index 46b883b1c..e353af8d3 100644 --- a/caddy/assets/path.go +++ b/assets.go @@ -1,4 +1,4 @@ -package assets +package caddy import ( "os" @@ -6,10 +6,15 @@ import ( "runtime" ) -// Path returns the path to the folder -// where the application may store data. This -// currently resolves to ~/.caddy -func Path() string { +// AssetsPath returns the path to the folder +// where the application may store data. If +// CADDYPATH env variable is set, that value +// is used. Otherwise, the path is the result +// of evaluating "$HOME/.caddy". +func AssetsPath() string { + if caddyPath := os.Getenv("CADDYPATH"); caddyPath != "" { + return caddyPath + } return filepath.Join(userHomeDir(), ".caddy") } diff --git a/assets_test.go b/assets_test.go new file mode 100644 index 000000000..193361048 --- /dev/null +++ b/assets_test.go @@ -0,0 +1,19 @@ +package caddy + +import ( + "os" + "strings" + "testing" +) + +func TestAssetsPath(t *testing.T) { + if actual := AssetsPath(); !strings.HasSuffix(actual, ".caddy") { + t.Errorf("Expected path to be a .caddy folder, got: %v", actual) + } + + os.Setenv("CADDYPATH", "testpath") + if actual, expected := AssetsPath(), "testpath"; actual != expected { + t.Errorf("Expected path to be %v, got: %v", expected, actual) + } + os.Setenv("CADDYPATH", "") +} diff --git a/caddy.go b/caddy.go new file mode 100644 index 000000000..69897cb33 --- /dev/null +++ b/caddy.go @@ -0,0 +1,752 @@ +package caddy + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "os" + "os/exec" + "path" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "github.com/mholt/caddy/caddyfile" +) + +// Configurable application parameters +var ( + // AppName is the name of the application. + AppName string + + // AppVersion is the version of the application. + AppVersion string + + // Quiet mode will not show any informative output on initialization. + Quiet bool + + // PidFile is the path to the pidfile to create. + PidFile string + + // GracefulTimeout is the maximum duration of a graceful shutdown. + GracefulTimeout time.Duration + + // isUpgrade will be set to true if this process + // was started as part of an upgrade, where a parent + // Caddy process started this one. + isUpgrade bool +) + +// Instance contains the state of servers created as a result of +// calling Start and can be used to access or control those servers. +type Instance struct { + // serverType is the name of the instance's server type + serverType string + + // caddyfileInput is the input configuration text used for this process + caddyfileInput Input + + // wg is used to wait for all servers to shut down + wg sync.WaitGroup + + // servers is the list of servers with their listeners... + servers []serverListener + + // these are callbacks to execute when certain events happen + onStartup []func() error + onRestart []func() error + onShutdown []func() error +} + +// Stop stops all servers contained in i. It does NOT +// execute shutdown callbacks. +func (i *Instance) Stop() error { + // stop the servers + for _, s := range i.servers { + if gs, ok := s.server.(GracefulServer); ok { + if err := gs.Stop(); err != nil { + log.Printf("[ERROR] Stopping %s: %v", gs.Address(), err) + } + } + } + + // splice instance list to delete this one + for j, other := range instances { + if other == i { + instances = append(instances[:j], instances[j+1:]...) + break + } + } + + return nil +} + +// shutdownCallbacks executes all the shutdown callbacks of i. +// An error returned from one does not stop execution of the rest. +// All the errors will be returned. +func (i *Instance) shutdownCallbacks() []error { + var errs []error + for _, shutdownFunc := range i.onShutdown { + err := shutdownFunc() + if err != nil { + errs = append(errs, err) + } + } + return errs +} + +// Restart replaces the servers in i with new servers created from +// executing the newCaddyfile. Upon success, it returns the new +// instance to replace i. Upon failure, i will not be replaced. +func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) { + log.Println("[INFO] Reloading") + + // run restart callbacks + for _, fn := range i.onRestart { + err := fn() + if err != nil { + return i, err + } + } + + if newCaddyfile == nil { + newCaddyfile = i.caddyfileInput + } + + // Add file descriptors of all the sockets that are capable of it + restartFds := make(map[string]restartPair) + for _, s := range i.servers { + gs, srvOk := s.server.(GracefulServer) + ln, lnOk := s.listener.(Listener) + if srvOk && lnOk { + restartFds[gs.Address()] = restartPair{server: gs, listener: ln} + } + } + + // create new instance; if the restart fails, it is simply discarded + newInst := &Instance{serverType: newCaddyfile.ServerType()} + + // attempt to start new instance + err := startWithListenerFds(newCaddyfile, newInst, restartFds) + if err != nil { + return i, err + } + + // success! bump the old instance out so it will be garbage-collected + instancesMu.Lock() + for j, other := range instances { + if other == i { + instances = append(instances[:j], instances[j+1:]...) + break + } + } + instancesMu.Unlock() + + log.Println("[INFO] Reloading complete") + + return newInst, nil +} + +// SaveServer adds s and its associated listener ln to the +// internally-kept list of servers that is running. For +// saved servers, graceful restarts will be provided. +func (i *Instance) SaveServer(s Server, ln net.Listener) { + i.servers = append(i.servers, serverListener{server: s, listener: ln}) +} + +// HasListenerWithAddress returns whether this package is +// tracking a server using a listener with the address +// addr. +func HasListenerWithAddress(addr string) bool { + instancesMu.Lock() + defer instancesMu.Unlock() + for _, inst := range instances { + for _, sln := range inst.servers { + if listenerAddrEqual(sln.listener, addr) { + return true + } + } + } + return false +} + +// listenerAddrEqual compares a listener's address with +// addr. Extra care is taken to match addresses with an +// empty hostname portion, as listeners tend to report +// [::]:80, for example, when the matching address that +// created the listener might be simply :80. +func listenerAddrEqual(ln net.Listener, addr string) bool { + lnAddr := ln.Addr().String() + hostname, port, err := net.SplitHostPort(addr) + if err != nil || hostname != "" { + return lnAddr == addr + } + if lnAddr == net.JoinHostPort("::", port) { + return true + } + if lnAddr == net.JoinHostPort("0.0.0.0", port) { + return true + } + return false +} + +/* +// TODO: We should be able to support UDP servers... I'm considering this pattern. + +type UDPListener struct { + *net.UDPConn +} + +func (u UDPListener) Accept() (net.Conn, error) { + return u.UDPConn, nil +} + +func (u UDPListener) Close() error { + return u.UDPConn.Close() +} + +func (u UDPListener) Addr() net.Addr { + return u.UDPConn.LocalAddr() +} + +var _ net.Listener = UDPListener{} +*/ + +// Server is a type that can listen and serve. A Server +// must associate with exactly zero or one listeners. +type Server interface { + // Listen starts listening by creating a new listener + // and returning it. It does not start accepting + // connections. + Listen() (net.Listener, error) + + // Serve starts serving using the provided listener. + // Serve must start the server loop nearly immediately, + // or at least not return any errors before the server + // loop begins. Serve blocks indefinitely, or in other + // words, until the server is stopped. + Serve(net.Listener) error +} + +// Stopper is a type that can stop serving. The stop +// does not necessarily have to be graceful. +type Stopper interface { + // Stop stops the server. It blocks until the + // server is completely stopped. + Stop() error +} + +// GracefulServer is a Server and Stopper, the stopping +// of which is graceful (whatever that means for the kind +// of server being implemented). It must be able to return +// the address it is configured to listen on so that its +// listener can be paired with it upon graceful restarts. +// The net.Listener that a GracefulServer creates must +// implement the Listener interface for restarts to be +// graceful (assuming the listener is for TCP). +type GracefulServer interface { + Server + Stopper + + // Address returns the address the server should + // listen on; it is used to pair the server to + // its listener during a graceful/zero-downtime + // restart. Thus when implementing this method, + // you must not access a listener to get the + // address; you must store the address the + // server is to serve on some other way. + Address() string +} + +// Listener is a net.Listener with an underlying file descriptor. +// A server's listener should implement this interface if it is +// to support zero-downtime reloads. +type Listener interface { + net.Listener + File() (*os.File, error) +} + +// AfterStartup is an interface that can be implemented +// by a server type that wants to run some code after all +// servers for the same Instance have started. +type AfterStartup interface { + OnStartupComplete() +} + +// LoadCaddyfile loads a Caddyfile by calling the plugged in +// Caddyfile loader methods. An error is returned if more than +// one loader returns a non-nil Caddyfile input. If no loaders +// load a Caddyfile, the default loader is used. If no default +// loader is registered or it returns nil, the server type's +// default Caddyfile is loaded. If the server type does not +// specify any default Caddyfile value, then an empty Caddyfile +// is returned. Consequently, this function never returns a nil +// value as long as there are no errors. +func LoadCaddyfile(serverType string) (Input, error) { + // Ask plugged-in loaders for a Caddyfile + cdyfile, err := loadCaddyfileInput(serverType) + if err != nil { + return nil, err + } + + // Otherwise revert to default + if cdyfile == nil { + cdyfile = DefaultInput(serverType) + } + + // Still nil? Geez. + if cdyfile == nil { + cdyfile = CaddyfileInput{ServerTypeName: serverType} + } + + return cdyfile, nil +} + +// Wait blocks until all of i's servers have stopped. +func (i *Instance) Wait() { + i.wg.Wait() +} + +// CaddyfileFromPipe loads the Caddyfile input from f if f is +// not interactive input. f is assumed to be a pipe or stream, +// such as os.Stdin. If f is not a pipe, no error is returned +// but the Input value will be nil. An error is only returned +// if there was an error reading the pipe, even if the length +// of what was read is 0. +func CaddyfileFromPipe(f *os.File) (Input, error) { + fi, err := f.Stat() + if err == nil && fi.Mode()&os.ModeCharDevice == 0 { + // Note that a non-nil error is not a problem. Windows + // will not create a stdin if there is no pipe, which + // produces an error when calling Stat(). But Unix will + // make one either way, which is why we also check that + // bitmask. + // NOTE: Reading from stdin after this fails (e.g. for the let's encrypt email address) (OS X) + confBody, err := ioutil.ReadAll(f) + if err != nil { + return nil, err + } + return CaddyfileInput{ + Contents: confBody, + Filepath: f.Name(), + }, nil + } + + // not having input from the pipe is not itself an error, + // just means no input to return. + return nil, nil +} + +// Caddyfile returns the Caddyfile used to create i. +func (i *Instance) Caddyfile() Input { + return i.caddyfileInput +} + +// Start starts Caddy with the given Caddyfile. +// +// This function blocks until all the servers are listening. +func Start(cdyfile Input) (*Instance, error) { + writePidFile() + inst := &Instance{serverType: cdyfile.ServerType()} + return inst, startWithListenerFds(cdyfile, inst, nil) +} + +func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]restartPair) error { + if cdyfile == nil { + cdyfile = CaddyfileInput{} + } + + stypeName := cdyfile.ServerType() + + stype, err := getServerType(stypeName) + if err != nil { + return err + } + + inst.caddyfileInput = cdyfile + + sblocks, err := loadServerBlocks(stypeName, path.Base(cdyfile.Path()), bytes.NewReader(cdyfile.Body())) + if err != nil { + return err + } + + ctx := stype.NewContext() + + sblocks, err = ctx.InspectServerBlocks(cdyfile.Path(), sblocks) + if err != nil { + return err + } + + err = executeDirectives(inst, cdyfile.Path(), stype.Directives, sblocks) + if err != nil { + return err + } + + slist, err := ctx.MakeServers() + if err != nil { + return err + } + + if restartFds == nil { + // run startup callbacks since this is not a restart + for _, startupFunc := range inst.onStartup { + err := startupFunc() + if err != nil { + return err + } + } + } + + err = startServers(slist, inst, restartFds) + if err != nil { + return err + } + + instancesMu.Lock() + instances = append(instances, inst) + instancesMu.Unlock() + + // run any AfterStartup callbacks if this is not + // part of a restart; then show file descriptor notice + if restartFds == nil { + for _, srvln := range inst.servers { + if srv, ok := srvln.server.(AfterStartup); ok { + srv.OnStartupComplete() + } + } + if !Quiet { + for _, srvln := range inst.servers { + if !IsLocalhost(srvln.listener.Addr().String()) { + checkFdlimit() + break + } + } + } + } + + return nil +} + +func executeDirectives(inst *Instance, filename string, + directives []string, sblocks []caddyfile.ServerBlock) error { + + // map of server block ID to map of directive name to whatever. + storages := make(map[int]map[string]interface{}) + + // It is crucial that directives are executed in the proper order. + // We loop with the directives on the outer loop so we execute + // a directive for all server blocks before going to the next directive. + // This is important mainly due to the parsing callbacks (below). + for _, dir := range directives { + for i, sb := range sblocks { + var once sync.Once + if _, ok := storages[i]; !ok { + storages[i] = make(map[string]interface{}) + } + + for j, key := range sb.Keys { + // Execute directive if it is in the server block + if tokens, ok := sb.Tokens[dir]; ok { + controller := &Controller{ + instance: inst, + Key: key, + Dispenser: caddyfile.NewDispenserTokens(filename, tokens), + OncePerServerBlock: func(f func() error) error { + var err error + once.Do(func() { + err = f() + }) + return err + }, + ServerBlockIndex: i, + ServerBlockKeyIndex: j, + ServerBlockKeys: sb.Keys, + ServerBlockStorage: storages[i][dir], + } + + setup, err := DirectiveAction(inst.serverType, dir) + if err != nil { + return err + } + + err = setup(controller) + if err != nil { + return err + } + + storages[i][dir] = controller.ServerBlockStorage // persist for this server block + } + } + } + + // See if there are any callbacks to execute after this directive + if allCallbacks, ok := parsingCallbacks[inst.serverType]; ok { + callbacks := allCallbacks[dir] + for _, callback := range callbacks { + if err := callback(); err != nil { + return err + } + } + } + } + + return nil +} + +func startServers(serverList []Server, inst *Instance, restartFds map[string]restartPair) error { + errChan := make(chan error, len(serverList)) + + for _, s := range serverList { + var ln net.Listener + var err error + + // If this is a reload and s is a GracefulServer, + // reuse the listener for a graceful restart. + if gs, ok := s.(GracefulServer); ok && restartFds != nil { + addr := gs.Address() + if old, ok := restartFds[addr]; ok { + file, err := old.listener.File() + if err != nil { + return err + } + ln, err = net.FileListener(file) + if err != nil { + return err + } + file.Close() + delete(restartFds, addr) + } + } + + if ln == nil { + ln, err = s.Listen() + if err != nil { + return err + } + } + + inst.wg.Add(1) + go func(s Server, ln net.Listener, inst *Instance) { + defer inst.wg.Done() + errChan <- s.Serve(ln) + }(s, ln, inst) + + inst.servers = append(inst.servers, serverListener{server: s, listener: ln}) + } + + // Close the remaining (unused) file descriptors to free up resources + // and stop old servers that aren't used anymore + for key, old := range restartFds { + if err := old.server.Stop(); err != nil { + log.Printf("[ERROR] Stopping %s: %v", old.server.Address(), err) + } + delete(restartFds, key) + } + + // Log errors that may be returned from Serve() calls, + // these errors should only be occurring in the server loop. + go func() { + for err := range errChan { + if err == nil { + continue + } + if strings.Contains(err.Error(), "use of closed network connection") { + // this error is normal when closing the listener + continue + } + log.Println(err) + } + }() + + return nil +} + +func getServerType(serverType string) (ServerType, error) { + stype, ok := serverTypes[serverType] + if ok { + return stype, nil + } + if serverType == "" { + if len(serverTypes) == 1 { + for _, stype := range serverTypes { + return stype, nil + } + } + return ServerType{}, fmt.Errorf("multiple server types available; must choose one") + } + if len(serverTypes) == 0 { + return ServerType{}, fmt.Errorf("no server types plugged in") + } + return ServerType{}, fmt.Errorf("unknown server type '%s'", serverType) +} + +func loadServerBlocks(serverType, filename string, input io.Reader) ([]caddyfile.ServerBlock, error) { + validDirectives := ValidDirectives(serverType) + serverBlocks, err := caddyfile.ServerBlocks(filename, input, validDirectives) + if err != nil { + return nil, err + } + if len(serverBlocks) == 0 && serverTypes[serverType].DefaultInput != nil { + newInput := serverTypes[serverType].DefaultInput() + serverBlocks, err = caddyfile.ServerBlocks(newInput.Path(), + bytes.NewReader(newInput.Body()), validDirectives) + if err != nil { + return nil, err + } + } + return serverBlocks, nil +} + +// Stop stops ALL servers. It blocks until they are all stopped. +// It does NOT execute shutdown callbacks, and it deletes all +// instances after stopping is completed. Do not re-use any +// references to old instances after calling Stop. +func Stop() error { + instancesMu.Lock() + for _, inst := range instances { + if err := inst.Stop(); err != nil { + log.Printf("[ERROR] Stopping %s: %v", inst.serverType, err) + } + } + instances = []*Instance{} + instancesMu.Unlock() + return nil +} + +// IsLocalhost returns true if the hostname of addr looks +// explicitly like a common local hostname. addr must only +// be a host or a host:port combination. +func IsLocalhost(addr string) bool { + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr // happens if the addr is just a hostname + } + return host == "localhost" || + host == "::1" || + strings.HasPrefix(host, "127.") +} + +// checkFdlimit issues a warning if the OS limit for +// max file descriptors is below a recommended minimum. +func checkFdlimit() { + const min = 8192 + + // Warn if ulimit is too low for production sites + if runtime.GOOS == "linux" || runtime.GOOS == "darwin" { + out, err := exec.Command("sh", "-c", "ulimit -n").Output() // use sh because ulimit isn't in Linux $PATH + if err == nil { + lim, err := strconv.Atoi(string(bytes.TrimSpace(out))) + if err == nil && lim < min { + fmt.Printf("WARNING: File descriptor limit %d is too low for production servers. "+ + "At least %d is recommended. Fix with \"ulimit -n %d\".\n", lim, min, min) + } + } + } +} + +// Upgrade re-launches the process, preserving the listeners +// for a graceful restart. It does NOT load new configuration; +// it only starts the process anew with a fresh binary. +// +// TODO: This is not yet implemented +func Upgrade() error { + return fmt.Errorf("not implemented") + // TODO: have child process set isUpgrade = true +} + +// IsUpgrade returns true if this process is part of an upgrade +// where a parent caddy process spawned this one to ugprade +// the binary. +func IsUpgrade() bool { + return isUpgrade +} + +// CaddyfileInput represents a Caddyfile as input +// and is simply a convenient way to implement +// the Input interface. +type CaddyfileInput struct { + Filepath string + Contents []byte + ServerTypeName string +} + +// Body returns c.Contents. +func (c CaddyfileInput) Body() []byte { return c.Contents } + +// Path returns c.Filepath. +func (c CaddyfileInput) Path() string { return c.Filepath } + +// ServerType returns c.ServerType. +func (c CaddyfileInput) ServerType() string { return c.ServerTypeName } + +// Input represents a Caddyfile; its contents and file path +// (which should include the file name at the end of the path). +// If path does not apply (e.g. piped input) you may use +// any understandable value. The path is mainly used for logging, +// error messages, and debugging. +type Input interface { + // Gets the Caddyfile contents + Body() []byte + + // Gets the path to the origin file + Path() string + + // The type of server this input is intended for + ServerType() string +} + +// DefaultInput returns the default Caddyfile input +// to use when it is otherwise empty or missing. +// It uses the default host and port (depends on +// host, e.g. localhost is 2015, otherwise 443) and +// root. +func DefaultInput(serverType string) Input { + if _, ok := serverTypes[serverType]; !ok { + return nil + } + if serverTypes[serverType].DefaultInput == nil { + return nil + } + return serverTypes[serverType].DefaultInput() +} + +// IsLoopback returns true if host looks explicitly like a loopback address. +func IsLoopback(host string) bool { + return host == "localhost" || + host == "::1" || + strings.HasPrefix(host, "127.") +} + +// writePidFile writes the process ID to the file at PidFile. +// It does nothing if PidFile is not set. +func writePidFile() error { + if PidFile == "" { + return nil + } + pid := []byte(strconv.Itoa(os.Getpid()) + "\n") + return ioutil.WriteFile(PidFile, pid, 0644) +} + +type restartPair struct { + server GracefulServer + listener Listener +} + +var ( + // instances is the list of running Instances. + instances []*Instance + + // instancesMu protects instances. + instancesMu sync.Mutex +) + +const ( + // DefaultConfigFile is the name of the configuration file that is loaded + // by default if no other file is specified. + DefaultConfigFile = "Caddyfile" +) diff --git a/caddy/assets/path_test.go b/caddy/assets/path_test.go deleted file mode 100644 index 374f813af..000000000 --- a/caddy/assets/path_test.go +++ /dev/null @@ -1,12 +0,0 @@ -package assets - -import ( - "strings" - "testing" -) - -func TestPath(t *testing.T) { - if actual := Path(); !strings.HasSuffix(actual, ".caddy") { - t.Errorf("Expected path to be a .caddy folder, got: %v", actual) - } -} diff --git a/build.bash b/caddy/build.bash similarity index 94% rename from build.bash rename to caddy/build.bash index ec8a67d56..03c553451 100755 --- a/build.bash +++ b/caddy/build.bash @@ -7,19 +7,18 @@ # $ ./build.bash [output_filename] [git_repo] # # Outputs compiled program in current directory. -# Default file name is 'ecaddy'. # Default git repo is current directory. # Builds always take place from current directory. set -euo pipefail : ${output_filename:="${1:-}"} -: ${output_filename:="ecaddy"} +: ${output_filename:="caddy"} : ${git_repo:="${2:-}"} : ${git_repo:="."} -pkg=main +pkg=github.com/mholt/caddy/caddy/caddymain ldflags=() # Timestamp of build diff --git a/caddy/caddy.go b/caddy/caddy.go deleted file mode 100644 index 1484e1127..000000000 --- a/caddy/caddy.go +++ /dev/null @@ -1,346 +0,0 @@ -// Package caddy implements the Caddy web server as a service -// in your own Go programs. -// -// To use this package, follow a few simple steps: -// -// 1. Set the AppName and AppVersion variables. -// 2. Call LoadCaddyfile() to get the Caddyfile. -// You should pass in your own Caddyfile loader. -// 3. Call caddy.Start() to start Caddy, caddy.Stop() -// to stop it, or caddy.Restart() to restart it. -// -// You should use caddy.Wait() to wait for all Caddy servers -// to quit before your process exits. -package caddy - -import ( - "bytes" - "errors" - "fmt" - "io/ioutil" - "log" - "net" - "os" - "path" - "strings" - "sync" - "time" - - "github.com/mholt/caddy/caddy/https" - "github.com/mholt/caddy/server" -) - -// Configurable application parameters -var ( - // AppName is the name of the application. - AppName string - - // AppVersion is the version of the application. - AppVersion string - - // Quiet when set to true, will not show any informative output on initialization. - Quiet bool - - // HTTP2 indicates whether HTTP2 is enabled or not. - HTTP2 bool - - // PidFile is the path to the pidfile to create. - PidFile string - - // GracefulTimeout is the maximum duration of a graceful shutdown. - GracefulTimeout time.Duration -) - -var ( - // caddyfile is the input configuration text used for this process - caddyfile Input - - // caddyfileMu protects caddyfile during changes - caddyfileMu sync.Mutex - - // servers is a list of all the currently-listening servers - servers []*server.Server - - // serversMu protects the servers slice during changes - serversMu sync.Mutex - - // wg is used to wait for all servers to shut down - wg sync.WaitGroup - - // restartFds keeps the servers' sockets for graceful in-process restart - restartFds = make(map[string]*os.File) - - // startedBefore should be set to true if caddy has been started - // at least once (does not indicate whether currently running). - startedBefore bool -) - -const ( - // DefaultHost is the default host. - DefaultHost = "" - // DefaultPort is the default port. - DefaultPort = "2015" - // DefaultRoot is the default root folder. - DefaultRoot = "." -) - -// Start starts Caddy with the given Caddyfile. If cdyfile -// is nil, the LoadCaddyfile function will be called to get -// one. -// -// This function blocks until all the servers are listening. -func Start(cdyfile Input) (err error) { - // Input must never be nil; try to load something - if cdyfile == nil { - cdyfile, err = LoadCaddyfile(nil) - if err != nil { - return err - } - } - - caddyfileMu.Lock() - caddyfile = cdyfile - caddyfileMu.Unlock() - - // load the server configs (activates Let's Encrypt) - configs, err := loadConfigs(path.Base(cdyfile.Path()), bytes.NewReader(cdyfile.Body())) - if err != nil { - return err - } - - // group virtualhosts by address - groupings, err := arrangeBindings(configs) - if err != nil { - return err - } - - // Start each server with its one or more configurations - err = startServers(groupings) - if err != nil { - return err - } - - showInitializationOutput(groupings) - - startedBefore = true - - return nil -} - -// showInitializationOutput just outputs some basic information about -// what is being served to stdout, as well as any applicable, non-essential -// warnings for the user. -func showInitializationOutput(groupings bindingGroup) { - // Show initialization output - if !Quiet && !IsRestart() { - var checkedFdLimit bool - for _, group := range groupings { - for _, conf := range group.Configs { - // Print address of site - fmt.Println(conf.Address()) - - // Note if non-localhost site resolves to loopback interface - if group.BindAddr.IP.IsLoopback() && !isLocalhost(conf.Host) { - fmt.Printf("Notice: %s is only accessible on this machine (%s)\n", - conf.Host, group.BindAddr.IP.String()) - } - if !checkedFdLimit && !group.BindAddr.IP.IsLoopback() && !isLocalhost(conf.Host) { - checkFdlimit() - checkedFdLimit = true - } - } - } - } -} - -// startServers starts all the servers in groupings, -// taking into account whether or not this process is -// from a graceful restart or not. It blocks until -// the servers are listening. -func startServers(groupings bindingGroup) error { - var startupWg sync.WaitGroup - errChan := make(chan error, len(groupings)) // must be buffered to allow Serve functions below to return if stopped later - - for _, group := range groupings { - s, err := server.New(group.BindAddr.String(), group.Configs, GracefulTimeout) - if err != nil { - return err - } - s.HTTP2 = HTTP2 - s.ReqCallback = https.RequestCallback // ensures we can solve ACME challenges while running - if s.OnDemandTLS { - s.TLSConfig.GetCertificate = https.GetOrObtainCertificate // TLS on demand -- awesome! - } else { - s.TLSConfig.GetCertificate = https.GetCertificate - } - - var ln server.ListenerFile - if len(restartFds) > 0 { - // Reuse the listeners for in-process restart - if file, ok := restartFds[s.Addr]; ok { - fln, err := net.FileListener(file) - if err != nil { - return err - } - - ln, ok = fln.(server.ListenerFile) - if !ok { - return errors.New("listener for " + s.Addr + " was not a ListenerFile") - } - - file.Close() - delete(restartFds, s.Addr) - } - } - - wg.Add(1) - go func(s *server.Server, ln server.ListenerFile) { - defer wg.Done() - - // run startup functions that should only execute when - // the original parent process is starting. - if !startedBefore { - err := s.RunFirstStartupFuncs() - if err != nil { - errChan <- err - return - } - } - - // start the server - if ln != nil { - errChan <- s.Serve(ln) - } else { - errChan <- s.ListenAndServe() - } - }(s, ln) - - startupWg.Add(1) - go func(s *server.Server) { - defer startupWg.Done() - s.WaitUntilStarted() - }(s) - - serversMu.Lock() - servers = append(servers, s) - serversMu.Unlock() - } - - // Close the remaining (unused) file descriptors to free up resources - if len(restartFds) > 0 { - for key, file := range restartFds { - file.Close() - delete(restartFds, key) - } - } - - // Wait for all servers to finish starting - startupWg.Wait() - - // Return the first error, if any - select { - case err := <-errChan: - // "use of closed network connection" is normal if it was a graceful shutdown - if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { - return err - } - default: - } - - return nil -} - -// Stop stops all servers. It blocks until they are all stopped. -// It does NOT execute shutdown callbacks that may have been -// configured by middleware (they must be executed separately). -func Stop() error { - https.Deactivate() - - serversMu.Lock() - for _, s := range servers { - if err := s.Stop(); err != nil { - log.Printf("[ERROR] Stopping %s: %v", s.Addr, err) - } - } - servers = []*server.Server{} // don't reuse servers - serversMu.Unlock() - - return nil -} - -// Wait blocks until all servers are stopped. -func Wait() { - wg.Wait() -} - -// LoadCaddyfile loads a Caddyfile by calling the user's loader function, -// and if that returns nil, then this function resorts to the default -// configuration. Thus, if there are no other errors, this function -// always returns at least the default Caddyfile. -func LoadCaddyfile(loader func() (Input, error)) (cdyfile Input, err error) { - // Try user's loader - if cdyfile == nil && loader != nil { - cdyfile, err = loader() - } - - // Otherwise revert to default - if cdyfile == nil { - cdyfile = DefaultInput() - } - - return -} - -// CaddyfileFromPipe loads the Caddyfile input from f if f is -// not interactive input. f is assumed to be a pipe or stream, -// such as os.Stdin. If f is not a pipe, no error is returned -// but the Input value will be nil. An error is only returned -// if there was an error reading the pipe, even if the length -// of what was read is 0. -func CaddyfileFromPipe(f *os.File) (Input, error) { - fi, err := f.Stat() - if err == nil && fi.Mode()&os.ModeCharDevice == 0 { - // Note that a non-nil error is not a problem. Windows - // will not create a stdin if there is no pipe, which - // produces an error when calling Stat(). But Unix will - // make one either way, which is why we also check that - // bitmask. - // BUG: Reading from stdin after this fails (e.g. for the let's encrypt email address) (OS X) - confBody, err := ioutil.ReadAll(f) - if err != nil { - return nil, err - } - return CaddyfileInput{ - Contents: confBody, - Filepath: f.Name(), - }, nil - } - - // not having input from the pipe is not itself an error, - // just means no input to return. - return nil, nil -} - -// Caddyfile returns the current Caddyfile -func Caddyfile() Input { - caddyfileMu.Lock() - defer caddyfileMu.Unlock() - return caddyfile -} - -// Input represents a Caddyfile; its contents and file path -// (which should include the file name at the end of the path). -// If path does not apply (e.g. piped input) you may use -// any understandable value. The path is mainly used for logging, -// error messages, and debugging. -type Input interface { - // Gets the Caddyfile contents - Body() []byte - - // Gets the path to the origin file - Path() string - - // IsFile returns true if the original input was a file on the file system - // that could be loaded again later if requested. - IsFile() bool -} diff --git a/caddy/caddy_test.go b/caddy/caddy_test.go deleted file mode 100644 index be40075dc..000000000 --- a/caddy/caddy_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package caddy - -import ( - "net/http" - "testing" - "time" -) - -func TestCaddyStartStop(t *testing.T) { - caddyfile := "localhost:1984" - - for i := 0; i < 2; i++ { - err := Start(CaddyfileInput{Contents: []byte(caddyfile)}) - if err != nil { - t.Fatalf("Error starting, iteration %d: %v", i, err) - } - - client := http.Client{ - Timeout: time.Duration(2 * time.Second), - } - resp, err := client.Get("http://localhost:1984") - if err != nil { - t.Fatalf("Expected GET request to succeed (iteration %d), but it failed: %v", i, err) - } - resp.Body.Close() - - err = Stop() - if err != nil { - t.Fatalf("Error stopping, iteration %d: %v", i, err) - } - } -} diff --git a/main.go b/caddy/caddymain/run.go similarity index 51% rename from main.go rename to caddy/caddymain/run.go index abb1b3f39..e113330a5 100644 --- a/main.go +++ b/caddy/caddymain/run.go @@ -1,4 +1,4 @@ -package main +package caddymain import ( "errors" @@ -7,46 +7,55 @@ import ( "io/ioutil" "log" "os" + "path/filepath" "runtime" "strconv" "strings" - "time" - "github.com/mholt/caddy/caddy" - "github.com/mholt/caddy/caddy/https" - "github.com/xenolf/lego/acme" "gopkg.in/natefinch/lumberjack.v2" + + "github.com/xenolf/lego/acme" + + "github.com/mholt/caddy" + // plug in the HTTP server type + _ "github.com/mholt/caddy/caddyhttp" + + "github.com/mholt/caddy/caddytls" + // This is where other plugins get plugged in (imported) ) func init() { caddy.TrapSignals() setVersion() - flag.BoolVar(&https.Agreed, "agree", false, "Agree to Let's Encrypt Subscriber Agreement") - flag.StringVar(&https.CAUrl, "ca", "https://acme-v01.api.letsencrypt.org/directory", "Certificate authority ACME server") - flag.StringVar(&conf, "conf", "", "Configuration file to use (default="+caddy.DefaultConfigFile+")") + + flag.BoolVar(&caddytls.Agreed, "agree", false, "Agree to the CA's Subscriber Agreement") + // TODO: Change from staging to v01 + flag.StringVar(&caddytls.DefaultCAUrl, "ca", "https://acme-staging.api.letsencrypt.org/directory", "URL to certificate authority's ACME server directory") + flag.StringVar(&conf, "conf", "", "Caddyfile to load (default \""+caddy.DefaultConfigFile+"\")") flag.StringVar(&cpu, "cpu", "100%", "CPU cap") - flag.StringVar(&https.DefaultEmail, "email", "", "Default Let's Encrypt account email address") - flag.DurationVar(&caddy.GracefulTimeout, "grace", 5*time.Second, "Maximum duration of graceful shutdown") - flag.StringVar(&caddy.Host, "host", caddy.DefaultHost, "Default host") - flag.BoolVar(&caddy.HTTP2, "http2", true, "Use HTTP/2") + flag.BoolVar(&plugins, "plugins", false, "List supported plugins") // TODO: change to plugins + flag.StringVar(&caddytls.DefaultEmail, "email", "", "Default ACME CA account email address") flag.StringVar(&logfile, "log", "", "Process log file") flag.StringVar(&caddy.PidFile, "pidfile", "", "Path to write pid file") - flag.StringVar(&caddy.Port, "port", caddy.DefaultPort, "Default port") - flag.BoolVar(&caddy.Quiet, "quiet", false, "Quiet mode (no initialization output)") + flag.BoolVar(&caddy.Quiet, "quiet", false, "Quiet mode (no initialization output)") // TODO flag.StringVar(&revoke, "revoke", "", "Hostname for which to revoke the certificate") - flag.StringVar(&caddy.Root, "root", caddy.DefaultRoot, "Root path to default site") + flag.StringVar(&serverType, "type", "http", "Type of server to run") flag.BoolVar(&version, "version", false, "Show version") - flag.BoolVar(&directives, "directives", false, "List supported directives") + + caddy.RegisterCaddyfileLoader("flag", caddy.LoaderFunc(confLoader)) + caddy.SetDefaultCaddyfileLoader("default", caddy.LoaderFunc(defaultLoader)) } -func main() { - flag.Parse() // called here in main() to allow other packages to set flags in their inits +// Run is Caddy's main() function. +func Run() { + flag.Parse() + moveStorage() // TODO: This is temporary for the 0.9 release, or until most users upgrade to 0.9+ caddy.AppName = appName caddy.AppVersion = appVersion acme.UserAgent = appName + "/" + appVersion - // set up process log before anything bad happens + // Set up process log before anything bad happens switch logfile { case "stdout": log.SetOutput(os.Stdout) @@ -63,8 +72,9 @@ func main() { }) } + // Check for one-time actions if revoke != "" { - err := https.Revoke(revoke) + err := caddytls.Revoke(revoke) if err != nil { log.Fatal(err) } @@ -78,10 +88,8 @@ func main() { } os.Exit(0) } - if directives { - for _, d := range caddy.Directives() { - fmt.Println(d) - } + if plugins { + fmt.Println(caddy.DescribePlugins()) os.Exit(0) } @@ -92,77 +100,124 @@ func main() { } // Get Caddyfile input - caddyfile, err := caddy.LoadCaddyfile(loadCaddyfile) + caddyfile, err := caddy.LoadCaddyfile(serverType) if err != nil { mustLogFatal(err) } // Start your engines - err = caddy.Start(caddyfile) + instance, err := caddy.Start(caddyfile) if err != nil { mustLogFatal(err) } // Twiddle your thumbs - caddy.Wait() + instance.Wait() } -// mustLogFatal just wraps log.Fatal() in a way that ensures the +// mustLogFatal wraps log.Fatal() in a way that ensures the // output is always printed to stderr so the user can see it // if the user is still there, even if the process log was not -// enabled. If this process is a restart, however, and the user -// might not be there anymore, this just logs to the process log -// and exits. +// enabled. If this process is an upgrade, however, and the user +// might not be there anymore, this just logs to the process +// log and exits. func mustLogFatal(args ...interface{}) { - if !caddy.IsRestart() { + if !caddy.IsUpgrade() { log.SetOutput(os.Stderr) } log.Fatal(args...) } -func loadCaddyfile() (caddy.Input, error) { - // Try -conf flag - if conf != "" { - if conf == "stdin" { - return caddy.CaddyfileFromPipe(os.Stdin) - } - - contents, err := ioutil.ReadFile(conf) - if err != nil { - return nil, err - } - - return caddy.CaddyfileInput{ - Contents: contents, - Filepath: conf, - RealFile: true, - }, nil +// confLoader loads the Caddyfile using the -conf flag. +func confLoader(serverType string) (caddy.Input, error) { + if conf == "" { + return nil, nil } - // command line args - if flag.NArg() > 0 { - confBody := caddy.Host + ":" + caddy.Port + "\n" + strings.Join(flag.Args(), "\n") - return caddy.CaddyfileInput{ - Contents: []byte(confBody), - Filepath: "args", - }, nil + if conf == "stdin" { + return caddy.CaddyfileFromPipe(os.Stdin) } - // Caddyfile in cwd + contents, err := ioutil.ReadFile(conf) + if err != nil { + return nil, err + } + return caddy.CaddyfileInput{ + Contents: contents, + Filepath: conf, + ServerTypeName: serverType, + }, nil +} + +// defaultLoader loads the Caddyfile from the current working directory. +func defaultLoader(serverType string) (caddy.Input, error) { contents, err := ioutil.ReadFile(caddy.DefaultConfigFile) if err != nil { if os.IsNotExist(err) { - return caddy.DefaultInput(), nil + return nil, nil } return nil, err } return caddy.CaddyfileInput{ - Contents: contents, - Filepath: caddy.DefaultConfigFile, - RealFile: true, + Contents: contents, + Filepath: caddy.DefaultConfigFile, + ServerTypeName: serverType, }, nil } +// moveStorage moves the old certificate storage location by +// renaming the "letsencrypt" folder to the hostname of the +// CA URL. This is TEMPORARY until most users have upgraded to 0.9+. +func moveStorage() { + oldPath := filepath.Join(caddy.AssetsPath(), "letsencrypt") + _, err := os.Stat(oldPath) + if os.IsNotExist(err) { + return + } + newPath, err := caddytls.StorageFor(caddytls.DefaultCAUrl) + if err != nil { + log.Fatalf("[ERROR] Unable to get new path for certificate storage: %v", err) + } + err = os.MkdirAll(string(newPath), 0700) + if err != nil { + log.Fatalf("[ERROR] Unable to make new certificate storage path: %v", err) + } + err = os.Rename(oldPath, string(newPath)) + if err != nil { + log.Fatalf("[ERROR] Unable to migrate certificate storage: %v", err) + } + // convert mixed case folder and file names to lowercase + filepath.Walk(string(newPath), func(path string, info os.FileInfo, err error) error { + // must be careful to only lowercase the base of the path, not the whole thing!! + base := filepath.Base(path) + if lowerBase := strings.ToLower(base); base != lowerBase { + lowerPath := filepath.Join(filepath.Dir(path), lowerBase) + err = os.Rename(path, lowerPath) + if err != nil { + log.Fatalf("[ERROR] Unable to lower-case: %v", err) + } + } + return nil + }) +} + +// setVersion figures out the version information +// based on variables set by -ldflags. +func setVersion() { + // A development build is one that's not at a tag or has uncommitted changes + devBuild = gitTag == "" || gitShortStat != "" + + // Only set the appVersion if -ldflags was used + if gitNearestTag != "" || gitTag != "" { + if devBuild && gitNearestTag != "" { + appVersion = fmt.Sprintf("%s (+%s %s)", + strings.TrimPrefix(gitNearestTag, "v"), gitCommit, buildDate) + } else if gitTag != "" { + appVersion = strings.TrimPrefix(gitTag, "v") + } + } +} + // setCPU parses string cpu and sets GOMAXPROCS // according to its value. It accepts either // a number (e.g. 3) or a percent (e.g. 50%). @@ -198,33 +253,17 @@ func setCPU(cpu string) error { return nil } -// setVersion figures out the version information based on -// variables set by -ldflags. -func setVersion() { - // A development build is one that's not at a tag or has uncommitted changes - devBuild = gitTag == "" || gitShortStat != "" - - // Only set the appVersion if -ldflags was used - if gitNearestTag != "" || gitTag != "" { - if devBuild && gitNearestTag != "" { - appVersion = fmt.Sprintf("%s (+%s %s)", - strings.TrimPrefix(gitNearestTag, "v"), gitCommit, buildDate) - } else if gitTag != "" { - appVersion = strings.TrimPrefix(gitTag, "v") - } - } -} - const appName = "Caddy" // Flags that control program flow or startup var ( + serverType string conf string cpu string logfile string revoke string version bool - directives bool + plugins bool ) // Build information obtained with the help of -ldflags diff --git a/caddy/config.go b/caddy/config.go deleted file mode 100644 index c8ea6b4da..000000000 --- a/caddy/config.go +++ /dev/null @@ -1,348 +0,0 @@ -package caddy - -import ( - "bytes" - "fmt" - "io" - "log" - "net" - "sync" - - "github.com/mholt/caddy/caddy/https" - "github.com/mholt/caddy/caddy/parse" - "github.com/mholt/caddy/caddy/setup" - "github.com/mholt/caddy/server" -) - -const ( - // DefaultConfigFile is the name of the configuration file that is loaded - // by default if no other file is specified. - DefaultConfigFile = "Caddyfile" -) - -// loadConfigsUpToIncludingTLS loads the configs from input with name filename and returns them, -// the parsed server blocks, the index of the last directive it processed, and an error (if any). -func loadConfigsUpToIncludingTLS(filename string, input io.Reader) ([]server.Config, []parse.ServerBlock, int, error) { - var configs []server.Config - - // Each server block represents similar hosts/addresses, since they - // were grouped together in the Caddyfile. - serverBlocks, err := parse.ServerBlocks(filename, input, true) - if err != nil { - return nil, nil, 0, err - } - if len(serverBlocks) == 0 { - newInput := DefaultInput() - serverBlocks, err = parse.ServerBlocks(newInput.Path(), bytes.NewReader(newInput.Body()), true) - if err != nil { - return nil, nil, 0, err - } - } - - var lastDirectiveIndex int // we set up directives in two parts; this stores where we left off - - // Iterate each server block and make a config for each one, - // executing the directives that were parsed in order up to the tls - // directive; this is because we must activate Let's Encrypt. - for i, sb := range serverBlocks { - onces := makeOnces() - storages := makeStorages() - - for j, addr := range sb.Addresses { - config := server.Config{ - Host: addr.Host, - Port: addr.Port, - Scheme: addr.Scheme, - Root: Root, - ConfigFile: filename, - AppName: AppName, - AppVersion: AppVersion, - } - - // It is crucial that directives are executed in the proper order. - for k, dir := range directiveOrder { - // Execute directive if it is in the server block - if tokens, ok := sb.Tokens[dir.name]; ok { - // Each setup function gets a controller, from which setup functions - // get access to the config, tokens, and other state information useful - // to set up its own host only. - controller := &setup.Controller{ - Config: &config, - Dispenser: parse.NewDispenserTokens(filename, tokens), - OncePerServerBlock: func(f func() error) error { - var err error - onces[dir.name].Do(func() { - err = f() - }) - return err - }, - ServerBlockIndex: i, - ServerBlockHostIndex: j, - ServerBlockHosts: sb.HostList(), - ServerBlockStorage: storages[dir.name], - } - // execute setup function and append middleware handler, if any - midware, err := dir.setup(controller) - if err != nil { - return nil, nil, lastDirectiveIndex, err - } - if midware != nil { - config.Middleware = append(config.Middleware, midware) - } - storages[dir.name] = controller.ServerBlockStorage // persist for this server block - } - - // Stop after TLS setup, since we need to activate Let's Encrypt before continuing; - // it makes some changes to the configs that middlewares might want to know about. - if dir.name == "tls" { - lastDirectiveIndex = k - break - } - } - - configs = append(configs, config) - } - } - - return configs, serverBlocks, lastDirectiveIndex, nil -} - -// loadConfigs reads input (named filename) and parses it, returning the -// server configurations in the order they appeared in the input. As part -// of this, it activates Let's Encrypt for the configs that are produced. -// Thus, the returned configs are already optimally configured for HTTPS. -func loadConfigs(filename string, input io.Reader) ([]server.Config, error) { - configs, serverBlocks, lastDirectiveIndex, err := loadConfigsUpToIncludingTLS(filename, input) - if err != nil { - return nil, err - } - - // Now we have all the configs, but they have only been set up to the - // point of tls. We need to activate Let's Encrypt before setting up - // the rest of the middlewares so they have correct information regarding - // TLS configuration, if necessary. (this only appends, so our iterations - // over server blocks below shouldn't be affected) - if !IsRestart() && !Quiet { - fmt.Print("Activating privacy features...") - } - configs, err = https.Activate(configs) - if err != nil { - return nil, err - } else if !IsRestart() && !Quiet { - fmt.Println(" done.") - } - - // Finish setting up the rest of the directives, now that TLS is - // optimally configured. These loops are similar to above except - // we don't iterate all the directives from the beginning and we - // don't create new configs. - configIndex := -1 - for i, sb := range serverBlocks { - onces := makeOnces() - storages := makeStorages() - - for j := range sb.Addresses { - configIndex++ - - for k := lastDirectiveIndex + 1; k < len(directiveOrder); k++ { - dir := directiveOrder[k] - - if tokens, ok := sb.Tokens[dir.name]; ok { - controller := &setup.Controller{ - Config: &configs[configIndex], - Dispenser: parse.NewDispenserTokens(filename, tokens), - OncePerServerBlock: func(f func() error) error { - var err error - onces[dir.name].Do(func() { - err = f() - }) - return err - }, - ServerBlockIndex: i, - ServerBlockHostIndex: j, - ServerBlockHosts: sb.HostList(), - ServerBlockStorage: storages[dir.name], - } - midware, err := dir.setup(controller) - if err != nil { - return nil, err - } - if midware != nil { - configs[configIndex].Middleware = append(configs[configIndex].Middleware, midware) - } - storages[dir.name] = controller.ServerBlockStorage // persist for this server block - } - } - } - } - - return configs, nil -} - -// makeOnces makes a map of directive name to sync.Once -// instance. This is intended to be called once per server -// block when setting up configs so that Setup functions -// for each directive can perform a task just once per -// server block, even if there are multiple hosts on the block. -// -// We need one Once per directive, otherwise the first -// directive to use it would exclude other directives from -// using it at all, which would be a bug. -func makeOnces() map[string]*sync.Once { - onces := make(map[string]*sync.Once) - for _, dir := range directiveOrder { - onces[dir.name] = new(sync.Once) - } - return onces -} - -// makeStorages makes a map of directive name to interface{} -// so that directives' setup functions can persist state -// between different hosts on the same server block during the -// setup phase. -func makeStorages() map[string]interface{} { - storages := make(map[string]interface{}) - for _, dir := range directiveOrder { - storages[dir.name] = nil - } - return storages -} - -// arrangeBindings groups configurations by their bind address. For example, -// a server that should listen on localhost and another on 127.0.0.1 will -// be grouped into the same address: 127.0.0.1. It will return an error -// if an address is malformed or a TLS listener is configured on the -// same address as a plaintext HTTP listener. The return value is a map of -// bind address to list of configs that would become VirtualHosts on that -// server. Use the keys of the returned map to create listeners, and use -// the associated values to set up the virtualhosts. -func arrangeBindings(allConfigs []server.Config) (bindingGroup, error) { - var groupings bindingGroup - - // Group configs by bind address - for _, conf := range allConfigs { - // use default port if none is specified - if conf.Port == "" { - conf.Port = Port - } - - bindAddr, warnErr, fatalErr := resolveAddr(conf) - if fatalErr != nil { - return groupings, fatalErr - } - if warnErr != nil { - log.Printf("[WARNING] Resolving bind address for %s: %v", conf.Address(), warnErr) - } - - // Make sure to compare the string representation of the address, - // not the pointer, since a new *TCPAddr is created each time. - var existing bool - for i := 0; i < len(groupings); i++ { - if groupings[i].BindAddr.String() == bindAddr.String() { - groupings[i].Configs = append(groupings[i].Configs, conf) - existing = true - break - } - } - if !existing { - groupings = append(groupings, bindingMapping{ - BindAddr: bindAddr, - Configs: []server.Config{conf}, - }) - } - } - - // Don't allow HTTP and HTTPS to be served on the same address - for _, group := range groupings { - isTLS := group.Configs[0].TLS.Enabled - for _, config := range group.Configs { - if config.TLS.Enabled != isTLS { - thisConfigProto, otherConfigProto := "HTTP", "HTTP" - if config.TLS.Enabled { - thisConfigProto = "HTTPS" - } - if group.Configs[0].TLS.Enabled { - otherConfigProto = "HTTPS" - } - return groupings, fmt.Errorf("configuration error: Cannot multiplex %s (%s) and %s (%s) on same address", - group.Configs[0].Address(), otherConfigProto, config.Address(), thisConfigProto) - } - } - } - - return groupings, nil -} - -// resolveAddr determines the address (host and port) that a config will -// bind to. The returned address, resolvAddr, should be used to bind the -// listener or group the config with other configs using the same address. -// The first error, if not nil, is just a warning and should be reported -// but execution may continue. The second error, if not nil, is a real -// problem and the server should not be started. -// -// This function does not handle edge cases like port "http" or "https" if -// they are not known to the system. It does, however, serve on the wildcard -// host if resolving the address of the specific hostname fails. -func resolveAddr(conf server.Config) (resolvAddr *net.TCPAddr, warnErr, fatalErr error) { - resolvAddr, warnErr = net.ResolveTCPAddr("tcp", net.JoinHostPort(conf.BindHost, conf.Port)) - if warnErr != nil { - // the hostname probably couldn't be resolved, just bind to wildcard then - resolvAddr, fatalErr = net.ResolveTCPAddr("tcp", net.JoinHostPort("", conf.Port)) - if fatalErr != nil { - return - } - } - - return -} - -// validDirective returns true if d is a valid -// directive; false otherwise. -func validDirective(d string) bool { - for _, dir := range directiveOrder { - if dir.name == d { - return true - } - } - return false -} - -// DefaultInput returns the default Caddyfile input -// to use when it is otherwise empty or missing. -// It uses the default host and port (depends on -// host, e.g. localhost is 2015, otherwise 443) and -// root. -func DefaultInput() CaddyfileInput { - port := Port - if https.HostQualifies(Host) && port == DefaultPort { - port = "443" - } - return CaddyfileInput{ - Contents: []byte(fmt.Sprintf("%s:%s\nroot %s", Host, port, Root)), - } -} - -// These defaults are configurable through the command line -var ( - // Root is the site root - Root = DefaultRoot - - // Host is the site host - Host = DefaultHost - - // Port is the site port - Port = DefaultPort -) - -// bindingMapping maps a network address to configurations -// that will bind to it. The order of the configs is important. -type bindingMapping struct { - BindAddr *net.TCPAddr - Configs []server.Config -} - -// bindingGroup maps network addresses to their configurations. -// Preserving the order of the groupings is important -// (related to graceful shutdown and restart) -// so this is a slice, not a literal map. -type bindingGroup []bindingMapping diff --git a/caddy/config_test.go b/caddy/config_test.go deleted file mode 100644 index f5f0db6c2..000000000 --- a/caddy/config_test.go +++ /dev/null @@ -1,159 +0,0 @@ -package caddy - -import ( - "reflect" - "sync" - "testing" - - "github.com/mholt/caddy/server" -) - -func TestDefaultInput(t *testing.T) { - if actual, expected := string(DefaultInput().Body()), ":2015\nroot ."; actual != expected { - t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) - } - - // next few tests simulate user providing -host and/or -port flags - - Host = "not-localhost.com" - if actual, expected := string(DefaultInput().Body()), "not-localhost.com:443\nroot ."; actual != expected { - t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) - } - - Host = "[::1]" - if actual, expected := string(DefaultInput().Body()), "[::1]:2015\nroot ."; actual != expected { - t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) - } - - Host = "127.0.1.1" - if actual, expected := string(DefaultInput().Body()), "127.0.1.1:2015\nroot ."; actual != expected { - t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) - } - - Host = "not-localhost.com" - Port = "1234" - if actual, expected := string(DefaultInput().Body()), "not-localhost.com:1234\nroot ."; actual != expected { - t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) - } - - Host = DefaultHost - Port = "1234" - if actual, expected := string(DefaultInput().Body()), ":1234\nroot ."; actual != expected { - t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) - } -} - -func TestResolveAddr(t *testing.T) { - // NOTE: If tests fail due to comparing to string "127.0.0.1", - // it's possible that system env resolves with IPv6, or ::1. - // If that happens, maybe we should use actualAddr.IP.IsLoopback() - // for the assertion, rather than a direct string comparison. - - // NOTE: Tests with {Host: "", Port: ""} and {Host: "localhost", Port: ""} - // will not behave the same cross-platform, so they have been omitted. - - for i, test := range []struct { - config server.Config - shouldWarnErr bool - shouldFatalErr bool - expectedIP string - expectedPort int - }{ - {server.Config{Host: "127.0.0.1", Port: "1234"}, false, false, "", 1234}, - {server.Config{Host: "localhost", Port: "80"}, false, false, "", 80}, - {server.Config{BindHost: "localhost", Port: "1234"}, false, false, "127.0.0.1", 1234}, - {server.Config{BindHost: "127.0.0.1", Port: "1234"}, false, false, "127.0.0.1", 1234}, - {server.Config{BindHost: "should-not-resolve", Port: "1234"}, true, false, "", 1234}, - {server.Config{BindHost: "localhost", Port: "http"}, false, false, "127.0.0.1", 80}, - {server.Config{BindHost: "localhost", Port: "https"}, false, false, "127.0.0.1", 443}, - {server.Config{BindHost: "", Port: "1234"}, false, false, "", 1234}, - {server.Config{BindHost: "localhost", Port: "abcd"}, false, true, "", 0}, - {server.Config{BindHost: "127.0.0.1", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234}, - {server.Config{BindHost: "localhost", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234}, - {server.Config{BindHost: "should-not-resolve", Host: "localhost", Port: "1234"}, true, false, "", 1234}, - } { - actualAddr, warnErr, fatalErr := resolveAddr(test.config) - - if test.shouldFatalErr && fatalErr == nil { - t.Errorf("Test %d: Expected error, but there wasn't any", i) - } - if !test.shouldFatalErr && fatalErr != nil { - t.Errorf("Test %d: Expected no error, but there was one: %v", i, fatalErr) - } - if fatalErr != nil { - continue - } - - if test.shouldWarnErr && warnErr == nil { - t.Errorf("Test %d: Expected warning, but there wasn't any", i) - } - if !test.shouldWarnErr && warnErr != nil { - t.Errorf("Test %d: Expected no warning, but there was one: %v", i, warnErr) - } - - if actual, expected := actualAddr.IP.String(), test.expectedIP; actual != expected { - t.Errorf("Test %d: IP was %s but expected %s", i, actual, expected) - } - if actual, expected := actualAddr.Port, test.expectedPort; actual != expected { - t.Errorf("Test %d: Port was %d but expected %d", i, actual, expected) - } - } -} - -func TestMakeOnces(t *testing.T) { - directives := []directive{ - {"dummy", nil}, - {"dummy2", nil}, - } - directiveOrder = directives - onces := makeOnces() - if len(onces) != len(directives) { - t.Errorf("onces had len %d , expected %d", len(onces), len(directives)) - } - expected := map[string]*sync.Once{ - "dummy": new(sync.Once), - "dummy2": new(sync.Once), - } - if !reflect.DeepEqual(onces, expected) { - t.Errorf("onces was %v, expected %v", onces, expected) - } -} - -func TestMakeStorages(t *testing.T) { - directives := []directive{ - {"dummy", nil}, - {"dummy2", nil}, - } - directiveOrder = directives - storages := makeStorages() - if len(storages) != len(directives) { - t.Errorf("storages had len %d , expected %d", len(storages), len(directives)) - } - expected := map[string]interface{}{ - "dummy": nil, - "dummy2": nil, - } - if !reflect.DeepEqual(storages, expected) { - t.Errorf("storages was %v, expected %v", storages, expected) - } -} - -func TestValidDirective(t *testing.T) { - directives := []directive{ - {"dummy", nil}, - {"dummy2", nil}, - } - directiveOrder = directives - for i, test := range []struct { - directive string - valid bool - }{ - {"dummy", true}, - {"dummy2", true}, - {"dummy3", false}, - } { - if actual, expected := validDirective(test.directive), test.valid; actual != expected { - t.Errorf("Test %d: valid was %t, expected %t", i, actual, expected) - } - } -} diff --git a/caddy/directives.go b/caddy/directives.go deleted file mode 100644 index 66e123a2d..000000000 --- a/caddy/directives.go +++ /dev/null @@ -1,109 +0,0 @@ -package caddy - -import ( - "github.com/mholt/caddy/caddy/https" - "github.com/mholt/caddy/caddy/parse" - "github.com/mholt/caddy/caddy/setup" - "github.com/mholt/caddy/middleware" -) - -func init() { - // The parse package must know which directives - // are valid, but it must not import the setup - // or config package. To solve this problem, we - // fill up this map in our init function here. - // The parse package does not need to know the - // ordering of the directives. - for _, dir := range directiveOrder { - parse.ValidDirectives[dir.name] = struct{}{} - } -} - -// Directives are registered in the order they should be -// executed. Middleware (directives that inject a handler) -// are executed in the order A-B-C-*-C-B-A, assuming -// they all call the Next handler in the chain. -// -// Ordering is VERY important. Every middleware will -// feel the effects of all other middleware below -// (after) them during a request, but they must not -// care what middleware above them are doing. -// -// For example, log needs to know the status code and -// exactly how many bytes were written to the client, -// which every other middleware can affect, so it gets -// registered first. The errors middleware does not -// care if gzip or log modifies its response, so it -// gets registered below them. Gzip, on the other hand, -// DOES care what errors does to the response since it -// must compress every output to the client, even error -// pages, so it must be registered before the errors -// middleware and any others that would write to the -// response. -var directiveOrder = []directive{ - // Essential directives that initialize vital configuration settings - {"root", setup.Root}, - {"bind", setup.BindHost}, - {"tls", https.Setup}, - - // Other directives that don't create HTTP handlers - {"startup", setup.Startup}, - {"shutdown", setup.Shutdown}, - - // Directives that inject handlers (middleware) - {"log", setup.Log}, - {"gzip", setup.Gzip}, - {"errors", setup.Errors}, - {"header", setup.Headers}, - {"rewrite", setup.Rewrite}, - {"redir", setup.Redir}, - {"ext", setup.Ext}, - {"mime", setup.Mime}, - {"basicauth", setup.BasicAuth}, - {"internal", setup.Internal}, - {"pprof", setup.PProf}, - {"expvar", setup.ExpVar}, - {"proxy", setup.Proxy}, - {"fastcgi", setup.FastCGI}, - {"websocket", setup.WebSocket}, - {"markdown", setup.Markdown}, - {"templates", setup.Templates}, - {"browse", setup.Browse}, -} - -// Directives returns the list of directives in order of priority. -func Directives() []string { - directives := make([]string, len(directiveOrder)) - for i, d := range directiveOrder { - directives[i] = d.name - } - return directives -} - -// RegisterDirective adds the given directive to caddy's list of directives. -// Pass the name of a directive you want it to be placed after, -// otherwise it will be placed at the bottom of the stack. -func RegisterDirective(name string, setup SetupFunc, after string) { - dir := directive{name: name, setup: setup} - idx := len(directiveOrder) - for i := range directiveOrder { - if directiveOrder[i].name == after { - idx = i + 1 - break - } - } - newDirectives := append(directiveOrder[:idx], append([]directive{dir}, directiveOrder[idx:]...)...) - directiveOrder = newDirectives - parse.ValidDirectives[name] = struct{}{} -} - -// directive ties together a directive name with its setup function. -type directive struct { - name string - setup SetupFunc -} - -// SetupFunc takes a controller and may optionally return a middleware. -// If the resulting middleware is not nil, it will be chained into -// the HTTP handlers in the order specified in this package. -type SetupFunc func(c *setup.Controller) (middleware.Middleware, error) diff --git a/caddy/directives_test.go b/caddy/directives_test.go deleted file mode 100644 index e37411f1c..000000000 --- a/caddy/directives_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package caddy - -import ( - "reflect" - "testing" -) - -func TestRegister(t *testing.T) { - directives := []directive{ - {"dummy", nil}, - {"dummy2", nil}, - } - directiveOrder = directives - RegisterDirective("foo", nil, "dummy") - if len(directiveOrder) != 3 { - t.Fatal("Should have 3 directives now") - } - getNames := func() (s []string) { - for _, d := range directiveOrder { - s = append(s, d.name) - } - return s - } - if !reflect.DeepEqual(getNames(), []string{"dummy", "foo", "dummy2"}) { - t.Fatalf("directive order doesn't match: %s", getNames()) - } - RegisterDirective("bar", nil, "ASDASD") - if !reflect.DeepEqual(getNames(), []string{"dummy", "foo", "dummy2", "bar"}) { - t.Fatalf("directive order doesn't match: %s", getNames()) - } -} diff --git a/caddy/helpers.go b/caddy/helpers.go deleted file mode 100644 index 2338fff0f..000000000 --- a/caddy/helpers.go +++ /dev/null @@ -1,64 +0,0 @@ -package caddy - -import ( - "bytes" - "fmt" - "io/ioutil" - "os" - "os/exec" - "runtime" - "strconv" - "strings" -) - -// isLocalhost returns true if host looks explicitly like a localhost address. -func isLocalhost(host string) bool { - return host == "localhost" || host == "::1" || strings.HasPrefix(host, "127.") -} - -// checkFdlimit issues a warning if the OS max file descriptors is below a recommended minimum. -func checkFdlimit() { - const min = 4096 - - // Warn if ulimit is too low for production sites - if runtime.GOOS == "linux" || runtime.GOOS == "darwin" { - out, err := exec.Command("sh", "-c", "ulimit -n").Output() // use sh because ulimit isn't in Linux $PATH - if err == nil { - // Note that an error here need not be reported - lim, err := strconv.Atoi(string(bytes.TrimSpace(out))) - if err == nil && lim < min { - fmt.Printf("Warning: File descriptor limit %d is too low for production sites. At least %d is recommended. Set with \"ulimit -n %d\".\n", lim, min, min) - } - } - } -} - -// IsRestart returns whether this process is, according -// to env variables, a fork as part of a graceful restart. -func IsRestart() bool { - return startedBefore -} - -// writePidFile writes the process ID to the file at PidFile, if specified. -func writePidFile() error { - pid := []byte(strconv.Itoa(os.Getpid()) + "\n") - return ioutil.WriteFile(PidFile, pid, 0644) -} - -// CaddyfileInput represents a Caddyfile as input -// and is simply a convenient way to implement -// the Input interface. -type CaddyfileInput struct { - Filepath string - Contents []byte - RealFile bool -} - -// Body returns c.Contents. -func (c CaddyfileInput) Body() []byte { return c.Contents } - -// Path returns c.Filepath. -func (c CaddyfileInput) Path() string { return c.Filepath } - -// IsFile returns true if the original input was a real file on the file system. -func (c CaddyfileInput) IsFile() bool { return c.RealFile } diff --git a/caddy/https/crypto.go b/caddy/https/crypto.go deleted file mode 100644 index 7971bda36..000000000 --- a/caddy/https/crypto.go +++ /dev/null @@ -1,57 +0,0 @@ -package https - -import ( - "crypto" - "crypto/ecdsa" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "errors" - "io/ioutil" - "os" -) - -// loadPrivateKey loads a PEM-encoded ECC/RSA private key from file. -func loadPrivateKey(file string) (crypto.PrivateKey, error) { - keyBytes, err := ioutil.ReadFile(file) - if err != nil { - return nil, err - } - keyBlock, _ := pem.Decode(keyBytes) - - switch keyBlock.Type { - case "RSA PRIVATE KEY": - return x509.ParsePKCS1PrivateKey(keyBlock.Bytes) - case "EC PRIVATE KEY": - return x509.ParseECPrivateKey(keyBlock.Bytes) - } - - return nil, errors.New("unknown private key type") -} - -// savePrivateKey saves a PEM-encoded ECC/RSA private key to file. -func savePrivateKey(key crypto.PrivateKey, file string) error { - var pemType string - var keyBytes []byte - switch key := key.(type) { - case *ecdsa.PrivateKey: - var err error - pemType = "EC" - keyBytes, err = x509.MarshalECPrivateKey(key) - if err != nil { - return err - } - case *rsa.PrivateKey: - pemType = "RSA" - keyBytes = x509.MarshalPKCS1PrivateKey(key) - } - - pemKey := pem.Block{Type: pemType + " PRIVATE KEY", Bytes: keyBytes} - keyOut, err := os.Create(file) - if err != nil { - return err - } - keyOut.Chmod(0600) - defer keyOut.Close() - return pem.Encode(keyOut, &pemKey) -} diff --git a/caddy/https/handler.go b/caddy/https/handler.go deleted file mode 100644 index f3139f54e..000000000 --- a/caddy/https/handler.go +++ /dev/null @@ -1,42 +0,0 @@ -package https - -import ( - "crypto/tls" - "log" - "net/http" - "net/http/httputil" - "net/url" - "strings" -) - -const challengeBasePath = "/.well-known/acme-challenge" - -// RequestCallback proxies challenge requests to ACME client if the -// request path starts with challengeBasePath. It returns true if it -// handled the request and no more needs to be done; it returns false -// if this call was a no-op and the request still needs handling. -func RequestCallback(w http.ResponseWriter, r *http.Request) bool { - if strings.HasPrefix(r.URL.Path, challengeBasePath) { - scheme := "http" - if r.TLS != nil { - scheme = "https" - } - - upstream, err := url.Parse(scheme + "://localhost:" + AlternatePort) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - log.Printf("[ERROR] ACME proxy handler: %v", err) - return true - } - - proxy := httputil.NewSingleHostReverseProxy(upstream) - proxy.Transport = &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // solver uses self-signed certs - } - proxy.ServeHTTP(w, r) - - return true - } - - return false -} diff --git a/caddy/https/https.go b/caddy/https/https.go deleted file mode 100644 index f9214f149..000000000 --- a/caddy/https/https.go +++ /dev/null @@ -1,411 +0,0 @@ -// Package https facilitates the management of TLS assets and integrates -// Let's Encrypt functionality into Caddy with first-class support for -// creating and renewing certificates automatically. It is designed to -// configure sites for HTTPS by default. -package https - -import ( - "encoding/json" - "errors" - "io/ioutil" - "net" - "net/http" - "os" - "strings" - - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/redirect" - "github.com/mholt/caddy/server" - "github.com/xenolf/lego/acme" -) - -// Activate sets up TLS for each server config in configs -// as needed; this consists of acquiring and maintaining -// certificates and keys for qualifying configs and enabling -// OCSP stapling for all TLS-enabled configs. -// -// This function may prompt the user to provide an email -// address if none is available through other means. It -// prefers the email address specified in the config, but -// if that is not available it will check the command line -// argument. If absent, it will use the most recent email -// address from last time. If there isn't one, the user -// will be prompted and shown SA link. -// -// Also note that calling this function activates asset -// management automatically, which keeps certificates -// renewed and OCSP stapling updated. -// -// Activate returns the updated list of configs, since -// some may have been appended, for example, to redirect -// plaintext HTTP requests to their HTTPS counterpart. -// This function only appends; it does not splice. -func Activate(configs []server.Config) ([]server.Config, error) { - // just in case previous caller forgot... - Deactivate() - - // pre-screen each config and earmark the ones that qualify for managed TLS - MarkQualified(configs) - - // place certificates and keys on disk - err := ObtainCerts(configs, true, false) - if err != nil { - return configs, err - } - - // update TLS configurations - err = EnableTLS(configs, true) - if err != nil { - return configs, err - } - - // set up redirects - configs = MakePlaintextRedirects(configs) - - // renew all relevant certificates that need renewal. this is important - // to do right away for a couple reasons, mainly because each restart, - // the renewal ticker is reset, so if restarts happen more often than - // the ticker interval, renewals would never happen. but doing - // it right away at start guarantees that renewals aren't missed. - err = renewManagedCertificates(true) - if err != nil { - return configs, err - } - - // keep certificates renewed and OCSP stapling updated - go maintainAssets(stopChan) - - return configs, nil -} - -// Deactivate cleans up long-term, in-memory resources -// allocated by calling Activate(). Essentially, it stops -// the asset maintainer from running, meaning that certificates -// will not be renewed, OCSP staples will not be updated, etc. -func Deactivate() (err error) { - defer func() { - if rec := recover(); rec != nil { - err = errors.New("already deactivated") - } - }() - close(stopChan) - stopChan = make(chan struct{}) - return -} - -// MarkQualified scans each config and, if it qualifies for managed -// TLS, it sets the Managed field of the TLSConfig to true. -func MarkQualified(configs []server.Config) { - for i := 0; i < len(configs); i++ { - if ConfigQualifies(configs[i]) { - configs[i].TLS.Managed = true - } - } -} - -// ObtainCerts obtains certificates for all these configs as long as a -// certificate does not already exist on disk. It does not modify the -// configs at all; it only obtains and stores certificates and keys to -// the disk. If allowPrompts is true, the user may be shown a prompt. -// If proxyACME is true, the ACME challenges will be proxied to our alt port. -func ObtainCerts(configs []server.Config, allowPrompts, proxyACME bool) error { - // We group configs by email so we don't make the same clients over and - // over. This has the potential to prompt the user for an email, but we - // prevent that by assuming that if we already have a listener that can - // proxy ACME challenge requests, then the server is already running and - // the operator is no longer present. - groupedConfigs := groupConfigsByEmail(configs, allowPrompts) - - for email, group := range groupedConfigs { - // Wait as long as we can before creating the client, because it - // may not be needed, for example, if we already have what we - // need on disk. Creating a client involves the network and - // potentially prompting the user, etc., so only do if necessary. - var client *ACMEClient - - for _, cfg := range group { - if !HostQualifies(cfg.Host) || existingCertAndKey(cfg.Host) { - continue - } - - // Now we definitely do need a client - if client == nil { - var err error - client, err = NewACMEClient(email, allowPrompts) - if err != nil { - return errors.New("error creating client: " + err.Error()) - } - } - - // c.Configure assumes that allowPrompts == !proxyACME, - // but that's not always true. For example, a restart where - // the user isn't present and we're not listening on port 80. - // TODO: This could probably be refactored better. - if proxyACME { - client.SetHTTPAddress(net.JoinHostPort(cfg.BindHost, AlternatePort)) - client.SetTLSAddress(net.JoinHostPort(cfg.BindHost, AlternatePort)) - client.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01}) - } else { - client.SetHTTPAddress(net.JoinHostPort(cfg.BindHost, "")) - client.SetTLSAddress(net.JoinHostPort(cfg.BindHost, "")) - client.ExcludeChallenges([]acme.Challenge{acme.DNS01}) - } - - err := client.Obtain([]string{cfg.Host}) - if err != nil { - return err - } - } - } - - return nil -} - -// groupConfigsByEmail groups configs by the email address to be used by an -// ACME client. It only groups configs that have TLS enabled and that are -// marked as Managed. If userPresent is true, the operator MAY be prompted -// for an email address. -func groupConfigsByEmail(configs []server.Config, userPresent bool) map[string][]server.Config { - initMap := make(map[string][]server.Config) - for _, cfg := range configs { - if !cfg.TLS.Managed { - continue - } - leEmail := getEmail(cfg, userPresent) - initMap[leEmail] = append(initMap[leEmail], cfg) - } - return initMap -} - -// EnableTLS configures each config to use TLS according to default settings. -// It will only change configs that are marked as managed, and assumes that -// certificates and keys are already on disk. If loadCertificates is true, -// the certificates will be loaded from disk into the cache for this process -// to use. If false, TLS will still be enabled and configured with default -// settings, but no certificates will be parsed loaded into the cache, and -// the returned error value will always be nil. -func EnableTLS(configs []server.Config, loadCertificates bool) error { - for i := 0; i < len(configs); i++ { - if !configs[i].TLS.Managed { - continue - } - configs[i].TLS.Enabled = true - if loadCertificates && HostQualifies(configs[i].Host) { - _, err := cacheManagedCertificate(configs[i].Host, false) - if err != nil { - return err - } - } - setDefaultTLSParams(&configs[i]) - } - return nil -} - -// hostHasOtherPort returns true if there is another config in the list with the same -// hostname that has port otherPort, or false otherwise. All the configs are checked -// against the hostname of allConfigs[thisConfigIdx]. -func hostHasOtherPort(allConfigs []server.Config, thisConfigIdx int, otherPort string) bool { - for i, otherCfg := range allConfigs { - if i == thisConfigIdx { - continue // has to be a config OTHER than the one we're comparing against - } - if otherCfg.Host == allConfigs[thisConfigIdx].Host && otherCfg.Port == otherPort { - return true - } - } - return false -} - -// MakePlaintextRedirects sets up redirects from port 80 to the relevant HTTPS -// hosts. You must pass in all configs, not just configs that qualify, since -// we must know whether the same host already exists on port 80, and those would -// not be in a list of configs that qualify for automatic HTTPS. This function will -// only set up redirects for configs that qualify. It returns the updated list of -// all configs. -func MakePlaintextRedirects(allConfigs []server.Config) []server.Config { - for i, cfg := range allConfigs { - if cfg.TLS.Managed && - !hostHasOtherPort(allConfigs, i, "80") && - (cfg.Port == "443" || !hostHasOtherPort(allConfigs, i, "443")) { - allConfigs = append(allConfigs, redirPlaintextHost(cfg)) - } - } - return allConfigs -} - -// ConfigQualifies returns true if cfg qualifies for -// fully managed TLS (but not on-demand TLS, which is -// not considered here). It does NOT check to see if a -// cert and key already exist for the config. If the -// config does qualify, you should set cfg.TLS.Managed -// to true and check that instead, because the process of -// setting up the config may make it look like it -// doesn't qualify even though it originally did. -func ConfigQualifies(cfg server.Config) bool { - return (!cfg.TLS.Manual || cfg.TLS.OnDemand) && // user might provide own cert and key - - // user can force-disable automatic HTTPS for this host - cfg.Scheme != "http" && - cfg.Port != "80" && - cfg.TLS.LetsEncryptEmail != "off" && - - // we get can't certs for some kinds of hostnames, but - // on-demand TLS allows empty hostnames at startup - (HostQualifies(cfg.Host) || cfg.TLS.OnDemand) -} - -// HostQualifies returns true if the hostname alone -// appears eligible for automatic HTTPS. For example, -// localhost, empty hostname, and IP addresses are -// not eligible because we cannot obtain certificates -// for those names. -func HostQualifies(hostname string) bool { - return hostname != "localhost" && // localhost is ineligible - - // hostname must not be empty - strings.TrimSpace(hostname) != "" && - - // cannot be an IP address, see - // https://community.letsencrypt.org/t/certificate-for-static-ip/84/2?u=mholt - // (also trim [] from either end, since that special case can sneak through - // for IPv6 addresses using the -host flag and with empty/no Caddyfile) - net.ParseIP(strings.Trim(hostname, "[]")) == nil -} - -// existingCertAndKey returns true if the host has a certificate -// and private key in storage already, false otherwise. -func existingCertAndKey(host string) bool { - _, err := os.Stat(storage.SiteCertFile(host)) - if err != nil { - return false - } - _, err = os.Stat(storage.SiteKeyFile(host)) - if err != nil { - return false - } - return true -} - -// saveCertResource saves the certificate resource to disk. This -// includes the certificate file itself, the private key, and the -// metadata file. -func saveCertResource(cert acme.CertificateResource) error { - err := os.MkdirAll(storage.Site(cert.Domain), 0700) - if err != nil { - return err - } - - // Save cert - err = ioutil.WriteFile(storage.SiteCertFile(cert.Domain), cert.Certificate, 0600) - if err != nil { - return err - } - - // Save private key - err = ioutil.WriteFile(storage.SiteKeyFile(cert.Domain), cert.PrivateKey, 0600) - if err != nil { - return err - } - - // Save cert metadata - jsonBytes, err := json.MarshalIndent(&cert, "", "\t") - if err != nil { - return err - } - err = ioutil.WriteFile(storage.SiteMetaFile(cert.Domain), jsonBytes, 0600) - if err != nil { - return err - } - - return nil -} - -// redirPlaintextHost returns a new plaintext HTTP configuration for -// a virtualHost that simply redirects to cfg, which is assumed to -// be the HTTPS configuration. The returned configuration is set -// to listen on port 80. -func redirPlaintextHost(cfg server.Config) server.Config { - toURL := "https://{host}" // serve any host, since cfg.Host could be empty - if cfg.Port != "443" && cfg.Port != "80" { - toURL += ":" + cfg.Port - } - - redirMidware := func(next middleware.Handler) middleware.Handler { - return redirect.Redirect{Next: next, Rules: []redirect.Rule{ - { - FromScheme: "http", - FromPath: "/", - To: toURL + "{uri}", - Code: http.StatusMovedPermanently, - }, - }} - } - - return server.Config{ - Host: cfg.Host, - BindHost: cfg.BindHost, - Port: "80", - Middleware: []middleware.Middleware{redirMidware}, - } -} - -// Revoke revokes the certificate for host via ACME protocol. -func Revoke(host string) error { - if !existingCertAndKey(host) { - return errors.New("no certificate and key for " + host) - } - - email := getEmail(server.Config{Host: host}, true) - if email == "" { - return errors.New("email is required to revoke") - } - - client, err := NewACMEClient(email, true) - if err != nil { - return err - } - - certFile := storage.SiteCertFile(host) - certBytes, err := ioutil.ReadFile(certFile) - if err != nil { - return err - } - - err = client.RevokeCertificate(certBytes) - if err != nil { - return err - } - - err = os.Remove(certFile) - if err != nil { - return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error()) - } - - return nil -} - -var ( - // DefaultEmail represents the Let's Encrypt account email to use if none provided - DefaultEmail string - - // Agreed indicates whether user has agreed to the Let's Encrypt SA - Agreed bool - - // CAUrl represents the base URL to the CA's ACME endpoint - CAUrl string -) - -// AlternatePort is the port on which the acme client will open a -// listener and solve the CA's challenges. If this alternate port -// is used instead of the default port (80 or 443), then the -// default port for the challenge must be forwarded to this one. -const AlternatePort = "5033" - -// KeyType is the type to use for new keys. -// This shouldn't need to change except for in tests; -// the size can be drastically reduced for speed. -var KeyType = acme.RSA2048 - -// stopChan is used to signal the maintenance goroutine -// to terminate. -var stopChan chan struct{} diff --git a/caddy/https/https_test.go b/caddy/https/https_test.go deleted file mode 100644 index 0f118f095..000000000 --- a/caddy/https/https_test.go +++ /dev/null @@ -1,332 +0,0 @@ -package https - -import ( - "io/ioutil" - "net/http" - "os" - "testing" - - "github.com/mholt/caddy/middleware/redirect" - "github.com/mholt/caddy/server" - "github.com/xenolf/lego/acme" -) - -func TestHostQualifies(t *testing.T) { - for i, test := range []struct { - host string - expect bool - }{ - {"localhost", false}, - {"127.0.0.1", false}, - {"127.0.1.5", false}, - {"::1", false}, - {"[::1]", false}, - {"[::]", false}, - {"::", false}, - {"", false}, - {" ", false}, - {"0.0.0.0", false}, - {"192.168.1.3", false}, - {"10.0.2.1", false}, - {"169.112.53.4", false}, - {"foobar.com", true}, - {"sub.foobar.com", true}, - } { - if HostQualifies(test.host) && !test.expect { - t.Errorf("Test %d: Expected '%s' to NOT qualify, but it did", i, test.host) - } - if !HostQualifies(test.host) && test.expect { - t.Errorf("Test %d: Expected '%s' to qualify, but it did NOT", i, test.host) - } - } -} - -func TestConfigQualifies(t *testing.T) { - for i, test := range []struct { - cfg server.Config - expect bool - }{ - {server.Config{Host: ""}, false}, - {server.Config{Host: "localhost"}, false}, - {server.Config{Host: "123.44.3.21"}, false}, - {server.Config{Host: "example.com"}, true}, - {server.Config{Host: "example.com", TLS: server.TLSConfig{Manual: true}}, false}, - {server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, false}, - {server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, true}, - {server.Config{Host: "example.com", Scheme: "http"}, false}, - {server.Config{Host: "example.com", Port: "80"}, false}, - {server.Config{Host: "example.com", Port: "1234"}, true}, - {server.Config{Host: "example.com", Scheme: "https"}, true}, - {server.Config{Host: "example.com", Port: "80", Scheme: "https"}, false}, - } { - if test.expect && !ConfigQualifies(test.cfg) { - t.Errorf("Test %d: Expected config to qualify, but it did NOT: %#v", i, test.cfg) - } - if !test.expect && ConfigQualifies(test.cfg) { - t.Errorf("Test %d: Expected config to NOT qualify, but it did: %#v", i, test.cfg) - } - } -} - -func TestRedirPlaintextHost(t *testing.T) { - cfg := redirPlaintextHost(server.Config{ - Host: "example.com", - BindHost: "93.184.216.34", - Port: "1234", - }) - - // Check host and port - if actual, expected := cfg.Host, "example.com"; actual != expected { - t.Errorf("Expected redir config to have host %s but got %s", expected, actual) - } - if actual, expected := cfg.BindHost, "93.184.216.34"; actual != expected { - t.Errorf("Expected redir config to have bindhost %s but got %s", expected, actual) - } - if actual, expected := cfg.Port, "80"; actual != expected { - t.Errorf("Expected redir config to have port '%s' but got '%s'", expected, actual) - } - - // Make sure redirect handler is set up properly - if cfg.Middleware == nil || len(cfg.Middleware) != 1 { - t.Fatalf("Redir config middleware not set up properly; got: %#v", cfg.Middleware) - } - - handler, ok := cfg.Middleware[0](nil).(redirect.Redirect) - if !ok { - t.Fatalf("Expected a redirect.Redirect middleware, but got: %#v", handler) - } - if len(handler.Rules) != 1 { - t.Fatalf("Expected one redirect rule, got: %#v", handler.Rules) - } - - // Check redirect rule for correctness - if actual, expected := handler.Rules[0].FromScheme, "http"; actual != expected { - t.Errorf("Expected redirect rule to be from scheme '%s' but is actually from '%s'", expected, actual) - } - if actual, expected := handler.Rules[0].FromPath, "/"; actual != expected { - t.Errorf("Expected redirect rule to be for path '%s' but is actually for '%s'", expected, actual) - } - if actual, expected := handler.Rules[0].To, "https://{host}:1234{uri}"; actual != expected { - t.Errorf("Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual) - } - if actual, expected := handler.Rules[0].Code, http.StatusMovedPermanently; actual != expected { - t.Errorf("Expected redirect rule to have code %d but was %d", expected, actual) - } - - // browsers can infer a default port from scheme, so make sure the port - // doesn't get added in explicitly for default ports like 443 for https. - cfg = redirPlaintextHost(server.Config{Host: "example.com", Port: "443"}) - handler, _ = cfg.Middleware[0](nil).(redirect.Redirect) - if actual, expected := handler.Rules[0].To, "https://{host}{uri}"; actual != expected { - t.Errorf("(Default Port) Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual) - } -} - -func TestSaveCertResource(t *testing.T) { - storage = Storage("./le_test_save") - defer func() { - err := os.RemoveAll(string(storage)) - if err != nil { - t.Fatalf("Could not remove temporary storage directory (%s): %v", storage, err) - } - }() - - domain := "example.com" - certContents := "certificate" - keyContents := "private key" - metaContents := `{ - "domain": "example.com", - "certUrl": "https://example.com/cert", - "certStableUrl": "https://example.com/cert/stable" -}` - - cert := acme.CertificateResource{ - Domain: domain, - CertURL: "https://example.com/cert", - CertStableURL: "https://example.com/cert/stable", - PrivateKey: []byte(keyContents), - Certificate: []byte(certContents), - } - - err := saveCertResource(cert) - if err != nil { - t.Fatalf("Expected no error, got: %v", err) - } - - certFile, err := ioutil.ReadFile(storage.SiteCertFile(domain)) - if err != nil { - t.Errorf("Expected no error reading certificate file, got: %v", err) - } - if string(certFile) != certContents { - t.Errorf("Expected certificate file to contain '%s', got '%s'", certContents, string(certFile)) - } - - keyFile, err := ioutil.ReadFile(storage.SiteKeyFile(domain)) - if err != nil { - t.Errorf("Expected no error reading private key file, got: %v", err) - } - if string(keyFile) != keyContents { - t.Errorf("Expected private key file to contain '%s', got '%s'", keyContents, string(keyFile)) - } - - metaFile, err := ioutil.ReadFile(storage.SiteMetaFile(domain)) - if err != nil { - t.Errorf("Expected no error reading meta file, got: %v", err) - } - if string(metaFile) != metaContents { - t.Errorf("Expected meta file to contain '%s', got '%s'", metaContents, string(metaFile)) - } -} - -func TestExistingCertAndKey(t *testing.T) { - storage = Storage("./le_test_existing") - defer func() { - err := os.RemoveAll(string(storage)) - if err != nil { - t.Fatalf("Could not remove temporary storage directory (%s): %v", storage, err) - } - }() - - domain := "example.com" - - if existingCertAndKey(domain) { - t.Errorf("Did NOT expect %v to have existing cert or key, but it did", domain) - } - - err := saveCertResource(acme.CertificateResource{ - Domain: domain, - PrivateKey: []byte("key"), - Certificate: []byte("cert"), - }) - if err != nil { - t.Fatalf("Expected no error, got: %v", err) - } - - if !existingCertAndKey(domain) { - t.Errorf("Expected %v to have existing cert and key, but it did NOT", domain) - } -} - -func TestHostHasOtherPort(t *testing.T) { - configs := []server.Config{ - {Host: "example.com", Port: "80"}, - {Host: "sub1.example.com", Port: "80"}, - {Host: "sub1.example.com", Port: "443"}, - } - - if hostHasOtherPort(configs, 0, "80") { - t.Errorf(`Expected hostHasOtherPort(configs, 0, "80") to be false, but got true`) - } - if hostHasOtherPort(configs, 0, "443") { - t.Errorf(`Expected hostHasOtherPort(configs, 0, "443") to be false, but got true`) - } - if !hostHasOtherPort(configs, 1, "443") { - t.Errorf(`Expected hostHasOtherPort(configs, 1, "443") to be true, but got false`) - } -} - -func TestMakePlaintextRedirects(t *testing.T) { - configs := []server.Config{ - // Happy path = standard redirect from 80 to 443 - {Host: "example.com", TLS: server.TLSConfig{Managed: true}}, - - // Host on port 80 already defined; don't change it (no redirect) - {Host: "sub1.example.com", Port: "80", Scheme: "http"}, - {Host: "sub1.example.com", TLS: server.TLSConfig{Managed: true}}, - - // Redirect from port 80 to port 5000 in this case - {Host: "sub2.example.com", Port: "5000", TLS: server.TLSConfig{Managed: true}}, - - // Can redirect from 80 to either 443 or 5001, but choose 443 - {Host: "sub3.example.com", Port: "443", TLS: server.TLSConfig{Managed: true}}, - {Host: "sub3.example.com", Port: "5001", Scheme: "https", TLS: server.TLSConfig{Managed: true}}, - } - - result := MakePlaintextRedirects(configs) - expectedRedirCount := 3 - - if len(result) != len(configs)+expectedRedirCount { - t.Errorf("Expected %d redirect(s) to be added, but got %d", - expectedRedirCount, len(result)-len(configs)) - } -} - -func TestEnableTLS(t *testing.T) { - configs := []server.Config{ - {Host: "example.com", TLS: server.TLSConfig{Managed: true}}, - {}, // not managed - no changes! - } - - EnableTLS(configs, false) - - if !configs[0].TLS.Enabled { - t.Errorf("Expected config 0 to have TLS.Enabled == true, but it was false") - } - if configs[1].TLS.Enabled { - t.Errorf("Expected config 1 to have TLS.Enabled == false, but it was true") - } -} - -func TestGroupConfigsByEmail(t *testing.T) { - if groupConfigsByEmail([]server.Config{}, false) == nil { - t.Errorf("With empty input, returned map was nil, but expected non-nil map") - } - - configs := []server.Config{ - {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}}, - {Host: "sub1.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}}, - {Host: "sub2.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}}, - {Host: "sub3.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}}, - {Host: "sub4.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}}, - {Host: "sub5.example.com", TLS: server.TLSConfig{LetsEncryptEmail: ""}}, // not managed - } - DefaultEmail = "test@example.com" - - groups := groupConfigsByEmail(configs, true) - - if groups == nil { - t.Fatalf("Returned map was nil, but expected values") - } - - if len(groups) != 2 { - t.Errorf("Expected 2 groups, got %d: %#v", len(groups), groups) - } - if len(groups["foo@bar"]) != 2 { - t.Errorf("Expected 2 configs for foo@bar, got %d: %#v", len(groups["foobar"]), groups["foobar"]) - } - if len(groups[DefaultEmail]) != 3 { - t.Errorf("Expected 3 configs for %s, got %d: %#v", DefaultEmail, len(groups["foobar"]), groups["foobar"]) - } -} - -func TestMarkQualified(t *testing.T) { - // TODO: TestConfigQualifies and this test share the same config list... - configs := []server.Config{ - {Host: ""}, - {Host: "localhost"}, - {Host: "123.44.3.21"}, - {Host: "example.com"}, - {Host: "example.com", TLS: server.TLSConfig{Manual: true}}, - {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, - {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, - {Host: "example.com", Scheme: "http"}, - {Host: "example.com", Port: "80"}, - {Host: "example.com", Port: "1234"}, - {Host: "example.com", Scheme: "https"}, - {Host: "example.com", Port: "80", Scheme: "https"}, - } - expectedManagedCount := 4 - - MarkQualified(configs) - - count := 0 - for _, cfg := range configs { - if cfg.TLS.Managed { - count++ - } - } - - if count != expectedManagedCount { - t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count) - } -} diff --git a/caddy/https/setup.go b/caddy/https/setup.go deleted file mode 100644 index eebfc62da..000000000 --- a/caddy/https/setup.go +++ /dev/null @@ -1,355 +0,0 @@ -package https - -import ( - "bytes" - "crypto/tls" - "encoding/pem" - "io/ioutil" - "log" - "os" - "path/filepath" - "strconv" - "strings" - - "github.com/mholt/caddy/caddy/setup" - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/server" - "github.com/xenolf/lego/acme" -) - -// Setup sets up the TLS configuration and installs certificates that -// are specified by the user in the config file. All the automatic HTTPS -// stuff comes later outside of this function. -func Setup(c *setup.Controller) (middleware.Middleware, error) { - if c.Port == "80" || c.Scheme == "http" { - c.TLS.Enabled = false - log.Printf("[WARNING] TLS disabled for %s://%s.", c.Scheme, c.Address()) - return nil, nil - } - c.TLS.Enabled = true - - for c.Next() { - var certificateFile, keyFile, loadDir, maxCerts string - - args := c.RemainingArgs() - switch len(args) { - case 1: - c.TLS.LetsEncryptEmail = args[0] - - // user can force-disable managed TLS this way - if c.TLS.LetsEncryptEmail == "off" { - c.TLS.Enabled = false - return nil, nil - } - case 2: - certificateFile = args[0] - keyFile = args[1] - c.TLS.Manual = true - } - - // Optional block with extra parameters - var hadBlock bool - for c.NextBlock() { - hadBlock = true - switch c.Val() { - case "key_type": - arg := c.RemainingArgs() - value, ok := supportedKeyTypes[strings.ToUpper(arg[0])] - if !ok { - return nil, c.Errf("Wrong KeyType name or KeyType not supported '%s'", c.Val()) - } - KeyType = value - case "protocols": - args := c.RemainingArgs() - if len(args) != 2 { - return nil, c.ArgErr() - } - value, ok := supportedProtocols[strings.ToLower(args[0])] - if !ok { - return nil, c.Errf("Wrong protocol name or protocol not supported '%s'", c.Val()) - } - c.TLS.ProtocolMinVersion = value - value, ok = supportedProtocols[strings.ToLower(args[1])] - if !ok { - return nil, c.Errf("Wrong protocol name or protocol not supported '%s'", c.Val()) - } - c.TLS.ProtocolMaxVersion = value - case "ciphers": - for c.NextArg() { - value, ok := supportedCiphersMap[strings.ToUpper(c.Val())] - if !ok { - return nil, c.Errf("Wrong cipher name or cipher not supported '%s'", c.Val()) - } - c.TLS.Ciphers = append(c.TLS.Ciphers, value) - } - case "clients": - clientCertList := c.RemainingArgs() - if len(clientCertList) == 0 { - return nil, c.ArgErr() - } - - listStart, mustProvideCA := 1, true - switch clientCertList[0] { - case "request": - c.TLS.ClientAuth = tls.RequestClientCert - mustProvideCA = false - case "require": - c.TLS.ClientAuth = tls.RequireAnyClientCert - mustProvideCA = false - case "verify_if_given": - c.TLS.ClientAuth = tls.VerifyClientCertIfGiven - default: - c.TLS.ClientAuth = tls.RequireAndVerifyClientCert - listStart = 0 - } - if mustProvideCA && len(clientCertList) <= listStart { - return nil, c.ArgErr() - } - - c.TLS.ClientCerts = clientCertList[listStart:] - case "load": - c.Args(&loadDir) - c.TLS.Manual = true - case "max_certs": - c.Args(&maxCerts) - c.TLS.OnDemand = true - default: - return nil, c.Errf("Unknown keyword '%s'", c.Val()) - } - } - - // tls requires at least one argument if a block is not opened - if len(args) == 0 && !hadBlock { - return nil, c.ArgErr() - } - - // set certificate limit if on-demand TLS is enabled - if maxCerts != "" { - maxCertsNum, err := strconv.Atoi(maxCerts) - if err != nil || maxCertsNum < 1 { - return nil, c.Err("max_certs must be a positive integer") - } - if onDemandMaxIssue == 0 || int32(maxCertsNum) < onDemandMaxIssue { // keep the minimum; TODO: We have to do this because it is global; should be per-server or per-vhost... - onDemandMaxIssue = int32(maxCertsNum) - } - } - - // don't try to load certificates unless we're supposed to - if !c.TLS.Enabled || !c.TLS.Manual { - continue - } - - // load a single certificate and key, if specified - if certificateFile != "" && keyFile != "" { - err := cacheUnmanagedCertificatePEMFile(certificateFile, keyFile) - if err != nil { - return nil, c.Errf("Unable to load certificate and key files for %s: %v", c.Host, err) - } - log.Printf("[INFO] Successfully loaded TLS assets from %s and %s", certificateFile, keyFile) - } - - // load a directory of certificates, if specified - if loadDir != "" { - err := loadCertsInDir(c, loadDir) - if err != nil { - return nil, err - } - } - } - - setDefaultTLSParams(c.Config) - - return nil, nil -} - -// loadCertsInDir loads all the certificates/keys in dir, as long as -// the file ends with .pem. This method of loading certificates is -// modeled after haproxy, which expects the certificate and key to -// be bundled into the same file: -// https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt -// -// This function may write to the log as it walks the directory tree. -func loadCertsInDir(c *setup.Controller, dir string) error { - return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { - if err != nil { - log.Printf("[WARNING] Unable to traverse into %s; skipping", path) - return nil - } - if info.IsDir() { - return nil - } - if strings.HasSuffix(strings.ToLower(info.Name()), ".pem") { - certBuilder, keyBuilder := new(bytes.Buffer), new(bytes.Buffer) - var foundKey bool // use only the first key in the file - - bundle, err := ioutil.ReadFile(path) - if err != nil { - return err - } - - for { - // Decode next block so we can see what type it is - var derBlock *pem.Block - derBlock, bundle = pem.Decode(bundle) - if derBlock == nil { - break - } - - if derBlock.Type == "CERTIFICATE" { - // Re-encode certificate as PEM, appending to certificate chain - pem.Encode(certBuilder, derBlock) - } else if derBlock.Type == "EC PARAMETERS" { - // EC keys generated from openssl can be composed of two blocks: - // parameters and key (parameter block should come first) - if !foundKey { - // Encode parameters - pem.Encode(keyBuilder, derBlock) - - // Key must immediately follow - derBlock, bundle = pem.Decode(bundle) - if derBlock == nil || derBlock.Type != "EC PRIVATE KEY" { - return c.Errf("%s: expected elliptic private key to immediately follow EC parameters", path) - } - pem.Encode(keyBuilder, derBlock) - foundKey = true - } - } else if derBlock.Type == "PRIVATE KEY" || strings.HasSuffix(derBlock.Type, " PRIVATE KEY") { - // RSA key - if !foundKey { - pem.Encode(keyBuilder, derBlock) - foundKey = true - } - } else { - return c.Errf("%s: unrecognized PEM block type: %s", path, derBlock.Type) - } - } - - certPEMBytes, keyPEMBytes := certBuilder.Bytes(), keyBuilder.Bytes() - if len(certPEMBytes) == 0 { - return c.Errf("%s: failed to parse PEM data", path) - } - if len(keyPEMBytes) == 0 { - return c.Errf("%s: no private key block found", path) - } - - err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes) - if err != nil { - return c.Errf("%s: failed to load cert and key for %s: %v", path, c.Host, err) - } - log.Printf("[INFO] Successfully loaded TLS assets from %s", path) - } - return nil - }) -} - -// setDefaultTLSParams sets the default TLS cipher suites, protocol versions, -// and server preferences of a server.Config if they were not previously set -// (it does not overwrite; only fills in missing values). It will also set the -// port to 443 if not already set, TLS is enabled, TLS is manual, and the host -// does not equal localhost. -func setDefaultTLSParams(c *server.Config) { - // If no ciphers provided, use default list - if len(c.TLS.Ciphers) == 0 { - c.TLS.Ciphers = defaultCiphers - } - - // Not a cipher suite, but still important for mitigating protocol downgrade attacks - // (prepend since having it at end breaks http2 due to non-h2-approved suites before it) - c.TLS.Ciphers = append([]uint16{tls.TLS_FALLBACK_SCSV}, c.TLS.Ciphers...) - - // Set default protocol min and max versions - must balance compatibility and security - if c.TLS.ProtocolMinVersion == 0 { - c.TLS.ProtocolMinVersion = tls.VersionTLS10 - } - if c.TLS.ProtocolMaxVersion == 0 { - c.TLS.ProtocolMaxVersion = tls.VersionTLS12 - } - - // Prefer server cipher suites - c.TLS.PreferServerCipherSuites = true - - // Default TLS port is 443; only use if port is not manually specified, - // TLS is enabled, and the host is not localhost - if c.Port == "" && c.TLS.Enabled && (!c.TLS.Manual || c.TLS.OnDemand) && c.Host != "localhost" { - c.Port = "443" - } -} - -// Map of supported key types -var supportedKeyTypes = map[string]acme.KeyType{ - "P384": acme.EC384, - "P256": acme.EC256, - "RSA8192": acme.RSA8192, - "RSA4096": acme.RSA4096, - "RSA2048": acme.RSA2048, -} - -// Map of supported protocols. -// SSLv3 will be not supported in future release. -// HTTP/2 only supports TLS 1.2 and higher. -var supportedProtocols = map[string]uint16{ - "ssl3.0": tls.VersionSSL30, - "tls1.0": tls.VersionTLS10, - "tls1.1": tls.VersionTLS11, - "tls1.2": tls.VersionTLS12, -} - -// Map of supported ciphers, used only for parsing config. -// -// Note that, at time of writing, HTTP/2 blacklists 276 cipher suites, -// including all but two of the suites below (the two GCM suites). -// See https://http2.github.io/http2-spec/#BadCipherSuites -// -// TLS_FALLBACK_SCSV is not in this list because we manually ensure -// it is always added (even though it is not technically a cipher suite). -// -// This map, like any map, is NOT ORDERED. Do not range over this map. -var supportedCiphersMap = map[string]uint16{ - "ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - "ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - "ECDHE-RSA-AES128-GCM-SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - "ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - "ECDHE-RSA-AES128-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - "ECDHE-RSA-AES256-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - "ECDHE-ECDSA-AES256-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - "ECDHE-ECDSA-AES128-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - "RSA-AES128-CBC-SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA, - "RSA-AES256-CBC-SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA, - "ECDHE-RSA-3DES-EDE-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - "RSA-3DES-EDE-CBC-SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, -} - -// List of supported cipher suites in descending order of preference. -// Ordering is very important! Getting the wrong order will break -// mainstream clients, especially with HTTP/2. -// -// Note that TLS_FALLBACK_SCSV is not in this list since it is always -// added manually. -var supportedCiphers = []uint16{ - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - tls.TLS_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, -} - -// List of all the ciphers we want to use by default -var defaultCiphers = []uint16{ - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - tls.TLS_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_RSA_WITH_AES_128_CBC_SHA, -} diff --git a/caddy/main.go b/caddy/main.go new file mode 100644 index 000000000..4559be03a --- /dev/null +++ b/caddy/main.go @@ -0,0 +1,7 @@ +package main + +import "github.com/mholt/caddy/caddy/caddymain" + +func main() { + caddymain.Run() +} diff --git a/caddy/parse/parse.go b/caddy/parse/parse.go deleted file mode 100644 index faef36c28..000000000 --- a/caddy/parse/parse.go +++ /dev/null @@ -1,32 +0,0 @@ -// Package parse provides facilities for parsing configuration files. -package parse - -import "io" - -// ServerBlocks parses the input just enough to organize tokens, -// in order, by server block. No further parsing is performed. -// If checkDirectives is true, only valid directives will be allowed -// otherwise we consider it a parse error. Server blocks are returned -// in the order in which they appear. -func ServerBlocks(filename string, input io.Reader, checkDirectives bool) ([]ServerBlock, error) { - p := parser{Dispenser: NewDispenser(filename, input)} - p.checkDirectives = checkDirectives - blocks, err := p.parseAll() - return blocks, err -} - -// allTokens lexes the entire input, but does not parse it. -// It returns all the tokens from the input, unstructured -// and in order. -func allTokens(input io.Reader) (tokens []token) { - l := new(lexer) - l.load(input) - for l.next() { - tokens = append(tokens, l.token) - } - return -} - -// ValidDirectives is a set of directives that are valid (unordered). Populated -// by config package's init function. -var ValidDirectives = make(map[string]struct{}) diff --git a/caddy/parse/parse_test.go b/caddy/parse/parse_test.go deleted file mode 100644 index 48746300f..000000000 --- a/caddy/parse/parse_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package parse - -import ( - "strings" - "testing" -) - -func TestAllTokens(t *testing.T) { - input := strings.NewReader("a b c\nd e") - expected := []string{"a", "b", "c", "d", "e"} - tokens := allTokens(input) - - if len(tokens) != len(expected) { - t.Fatalf("Expected %d tokens, got %d", len(expected), len(tokens)) - } - - for i, val := range expected { - if tokens[i].text != val { - t.Errorf("Token %d should be '%s' but was '%s'", i, val, tokens[i].text) - } - } -} diff --git a/caddy/parse/parsing_test.go b/caddy/parse/parsing_test.go deleted file mode 100644 index db7fd3e1b..000000000 --- a/caddy/parse/parsing_test.go +++ /dev/null @@ -1,480 +0,0 @@ -package parse - -import ( - "os" - "strings" - "testing" -) - -func TestStandardAddress(t *testing.T) { - for i, test := range []struct { - input string - scheme, host, port string - shouldErr bool - }{ - {`localhost`, "", "localhost", "", false}, - {`LOCALHOST`, "", "localhost", "", false}, - {`localhost:1234`, "", "localhost", "1234", false}, - {`LOCALHOST:1234`, "", "localhost", "1234", false}, - {`localhost:`, "", "localhost", "", false}, - {`0.0.0.0`, "", "0.0.0.0", "", false}, - {`127.0.0.1:1234`, "", "127.0.0.1", "1234", false}, - {`:1234`, "", "", "1234", false}, - {`[::1]`, "", "::1", "", false}, - {`[::1]:1234`, "", "::1", "1234", false}, - {`:`, "", "", "", false}, - {`localhost:http`, "http", "localhost", "80", false}, - {`localhost:https`, "https", "localhost", "443", false}, - {`:http`, "http", "", "80", false}, - {`:https`, "https", "", "443", false}, - {`http://localhost:https`, "", "", "", true}, // conflict - {`http://localhost:http`, "", "", "", true}, // repeated scheme - {`http://localhost:443`, "", "", "", true}, // not conventional - {`https://localhost:80`, "", "", "", true}, // not conventional - {`http://localhost`, "http", "localhost", "80", false}, - {`https://localhost`, "https", "localhost", "443", false}, - {`http://127.0.0.1`, "http", "127.0.0.1", "80", false}, - {`https://127.0.0.1`, "https", "127.0.0.1", "443", false}, - {`http://[::1]`, "http", "::1", "80", false}, - {`http://localhost:1234`, "http", "localhost", "1234", false}, - {`http://LOCALHOST:1234`, "http", "localhost", "1234", false}, - {`https://127.0.0.1:1234`, "https", "127.0.0.1", "1234", false}, - {`http://[::1]:1234`, "http", "::1", "1234", false}, - {``, "", "", "", false}, - {`::1`, "", "::1", "", true}, - {`localhost::`, "", "localhost::", "", true}, - {`#$%@`, "", "#$%@", "", true}, - } { - actual, err := standardAddress(test.input) - - if err != nil && !test.shouldErr { - t.Errorf("Test %d (%s): Expected no error, but had error: %v", i, test.input, err) - } - if err == nil && test.shouldErr { - t.Errorf("Test %d (%s): Expected error, but had none", i, test.input) - } - - if actual.Scheme != test.scheme { - t.Errorf("Test %d (%s): Expected scheme '%s', got '%s'", i, test.input, test.scheme, actual.Scheme) - } - if actual.Host != test.host { - t.Errorf("Test %d (%s): Expected host '%s', got '%s'", i, test.input, test.host, actual.Host) - } - if actual.Port != test.port { - t.Errorf("Test %d (%s): Expected port '%s', got '%s'", i, test.input, test.port, actual.Port) - } - } -} - -func TestParseOneAndImport(t *testing.T) { - setupParseTests() - - testParseOne := func(input string) (ServerBlock, error) { - p := testParser(input) - p.Next() // parseOne doesn't call Next() to start, so we must - err := p.parseOne() - return p.block, err - } - - for i, test := range []struct { - input string - shouldErr bool - addresses []address - tokens map[string]int // map of directive name to number of tokens expected - }{ - {`localhost`, false, []address{ - {"localhost", "", "localhost", ""}, - }, map[string]int{}}, - - {`localhost - dir1`, false, []address{ - {"localhost", "", "localhost", ""}, - }, map[string]int{ - "dir1": 1, - }}, - - {`localhost:1234 - dir1 foo bar`, false, []address{ - {"localhost:1234", "", "localhost", "1234"}, - }, map[string]int{ - "dir1": 3, - }}, - - {`localhost { - dir1 - }`, false, []address{ - {"localhost", "", "localhost", ""}, - }, map[string]int{ - "dir1": 1, - }}, - - {`localhost:1234 { - dir1 foo bar - dir2 - }`, false, []address{ - {"localhost:1234", "", "localhost", "1234"}, - }, map[string]int{ - "dir1": 3, - "dir2": 1, - }}, - - {`http://localhost https://localhost - dir1 foo bar`, false, []address{ - {"http://localhost", "http", "localhost", "80"}, - {"https://localhost", "https", "localhost", "443"}, - }, map[string]int{ - "dir1": 3, - }}, - - {`http://localhost https://localhost { - dir1 foo bar - }`, false, []address{ - {"http://localhost", "http", "localhost", "80"}, - {"https://localhost", "https", "localhost", "443"}, - }, map[string]int{ - "dir1": 3, - }}, - - {`http://localhost, https://localhost { - dir1 foo bar - }`, false, []address{ - {"http://localhost", "http", "localhost", "80"}, - {"https://localhost", "https", "localhost", "443"}, - }, map[string]int{ - "dir1": 3, - }}, - - {`http://localhost, { - }`, true, []address{ - {"http://localhost", "http", "localhost", "80"}, - }, map[string]int{}}, - - {`host1:80, http://host2.com - dir1 foo bar - dir2 baz`, false, []address{ - {"host1:80", "", "host1", "80"}, - {"http://host2.com", "http", "host2.com", "80"}, - }, map[string]int{ - "dir1": 3, - "dir2": 2, - }}, - - {`http://host1.com, - http://host2.com, - https://host3.com`, false, []address{ - {"http://host1.com", "http", "host1.com", "80"}, - {"http://host2.com", "http", "host2.com", "80"}, - {"https://host3.com", "https", "host3.com", "443"}, - }, map[string]int{}}, - - {`http://host1.com:1234, https://host2.com - dir1 foo { - bar baz - } - dir2`, false, []address{ - {"http://host1.com:1234", "http", "host1.com", "1234"}, - {"https://host2.com", "https", "host2.com", "443"}, - }, map[string]int{ - "dir1": 6, - "dir2": 1, - }}, - - {`127.0.0.1 - dir1 { - bar baz - } - dir2 { - foo bar - }`, false, []address{ - {"127.0.0.1", "", "127.0.0.1", ""}, - }, map[string]int{ - "dir1": 5, - "dir2": 5, - }}, - - {`127.0.0.1 - unknown_directive`, true, []address{ - {"127.0.0.1", "", "127.0.0.1", ""}, - }, map[string]int{}}, - - {`localhost - dir1 { - foo`, true, []address{ - {"localhost", "", "localhost", ""}, - }, map[string]int{ - "dir1": 3, - }}, - - {`localhost - dir1 { - }`, false, []address{ - {"localhost", "", "localhost", ""}, - }, map[string]int{ - "dir1": 3, - }}, - - {`localhost - dir1 { - } }`, true, []address{ - {"localhost", "", "localhost", ""}, - }, map[string]int{ - "dir1": 3, - }}, - - {`localhost - dir1 { - nested { - foo - } - } - dir2 foo bar`, false, []address{ - {"localhost", "", "localhost", ""}, - }, map[string]int{ - "dir1": 7, - "dir2": 3, - }}, - - {``, false, []address{}, map[string]int{}}, - - {`localhost - dir1 arg1 - import import_test1.txt`, false, []address{ - {"localhost", "", "localhost", ""}, - }, map[string]int{ - "dir1": 2, - "dir2": 3, - "dir3": 1, - }}, - - {`import import_test2.txt`, false, []address{ - {"host1", "", "host1", ""}, - }, map[string]int{ - "dir1": 1, - "dir2": 2, - }}, - - {`import import_test1.txt import_test2.txt`, true, []address{}, map[string]int{}}, - - {`import not_found.txt`, true, []address{}, map[string]int{}}, - - {`""`, false, []address{}, map[string]int{}}, - - {``, false, []address{}, map[string]int{}}, - } { - result, err := testParseOne(test.input) - - if test.shouldErr && err == nil { - t.Errorf("Test %d: Expected an error, but didn't get one", i) - } - if !test.shouldErr && err != nil { - t.Errorf("Test %d: Expected no error, but got: %v", i, err) - } - - if len(result.Addresses) != len(test.addresses) { - t.Errorf("Test %d: Expected %d addresses, got %d", - i, len(test.addresses), len(result.Addresses)) - continue - } - for j, addr := range result.Addresses { - if addr.Host != test.addresses[j].Host { - t.Errorf("Test %d, address %d: Expected host to be '%s', but was '%s'", - i, j, test.addresses[j].Host, addr.Host) - } - if addr.Port != test.addresses[j].Port { - t.Errorf("Test %d, address %d: Expected port to be '%s', but was '%s'", - i, j, test.addresses[j].Port, addr.Port) - } - } - - if len(result.Tokens) != len(test.tokens) { - t.Errorf("Test %d: Expected %d directives, had %d", - i, len(test.tokens), len(result.Tokens)) - continue - } - for directive, tokens := range result.Tokens { - if len(tokens) != test.tokens[directive] { - t.Errorf("Test %d, directive '%s': Expected %d tokens, counted %d", - i, directive, test.tokens[directive], len(tokens)) - continue - } - } - } -} - -func TestParseAll(t *testing.T) { - setupParseTests() - - for i, test := range []struct { - input string - shouldErr bool - addresses [][]address // addresses per server block, in order - }{ - {`localhost`, false, [][]address{ - {{"localhost", "", "localhost", ""}}, - }}, - - {`localhost:1234`, false, [][]address{ - {{"localhost:1234", "", "localhost", "1234"}}, - }}, - - {`localhost:1234 { - } - localhost:2015 { - }`, false, [][]address{ - {{"localhost:1234", "", "localhost", "1234"}}, - {{"localhost:2015", "", "localhost", "2015"}}, - }}, - - {`localhost:1234, http://host2`, false, [][]address{ - {{"localhost:1234", "", "localhost", "1234"}, {"http://host2", "http", "host2", "80"}}, - }}, - - {`localhost:1234, http://host2,`, true, [][]address{}}, - - {`http://host1.com, http://host2.com { - } - https://host3.com, https://host4.com { - }`, false, [][]address{ - {{"http://host1.com", "http", "host1.com", "80"}, {"http://host2.com", "http", "host2.com", "80"}}, - {{"https://host3.com", "https", "host3.com", "443"}, {"https://host4.com", "https", "host4.com", "443"}}, - }}, - - {`import import_glob*.txt`, false, [][]address{ - {{"glob0.host0", "", "glob0.host0", ""}}, - {{"glob0.host1", "", "glob0.host1", ""}}, - {{"glob1.host0", "", "glob1.host0", ""}}, - {{"glob2.host0", "", "glob2.host0", ""}}, - }}, - } { - p := testParser(test.input) - blocks, err := p.parseAll() - - if test.shouldErr && err == nil { - t.Errorf("Test %d: Expected an error, but didn't get one", i) - } - if !test.shouldErr && err != nil { - t.Errorf("Test %d: Expected no error, but got: %v", i, err) - } - - if len(blocks) != len(test.addresses) { - t.Errorf("Test %d: Expected %d server blocks, got %d", - i, len(test.addresses), len(blocks)) - continue - } - for j, block := range blocks { - if len(block.Addresses) != len(test.addresses[j]) { - t.Errorf("Test %d: Expected %d addresses in block %d, got %d", - i, len(test.addresses[j]), j, len(block.Addresses)) - continue - } - for k, addr := range block.Addresses { - if addr.Host != test.addresses[j][k].Host { - t.Errorf("Test %d, block %d, address %d: Expected host to be '%s', but was '%s'", - i, j, k, test.addresses[j][k].Host, addr.Host) - } - if addr.Port != test.addresses[j][k].Port { - t.Errorf("Test %d, block %d, address %d: Expected port to be '%s', but was '%s'", - i, j, k, test.addresses[j][k].Port, addr.Port) - } - } - } - } -} - -func TestEnvironmentReplacement(t *testing.T) { - setupParseTests() - - os.Setenv("PORT", "8080") - os.Setenv("ADDRESS", "servername.com") - os.Setenv("FOOBAR", "foobar") - - // basic test; unix-style env vars - p := testParser(`{$ADDRESS}`) - blocks, _ := p.parseAll() - if actual, expected := blocks[0].Addresses[0].Host, "servername.com"; expected != actual { - t.Errorf("Expected host to be '%s' but was '%s'", expected, actual) - } - - // multiple vars per token - p = testParser(`{$ADDRESS}:{$PORT}`) - blocks, _ = p.parseAll() - if actual, expected := blocks[0].Addresses[0].Host, "servername.com"; expected != actual { - t.Errorf("Expected host to be '%s' but was '%s'", expected, actual) - } - if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual { - t.Errorf("Expected port to be '%s' but was '%s'", expected, actual) - } - - // windows-style var and unix style in same token - p = testParser(`{%ADDRESS%}:{$PORT}`) - blocks, _ = p.parseAll() - if actual, expected := blocks[0].Addresses[0].Host, "servername.com"; expected != actual { - t.Errorf("Expected host to be '%s' but was '%s'", expected, actual) - } - if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual { - t.Errorf("Expected port to be '%s' but was '%s'", expected, actual) - } - - // reverse order - p = testParser(`{$ADDRESS}:{%PORT%}`) - blocks, _ = p.parseAll() - if actual, expected := blocks[0].Addresses[0].Host, "servername.com"; expected != actual { - t.Errorf("Expected host to be '%s' but was '%s'", expected, actual) - } - if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual { - t.Errorf("Expected port to be '%s' but was '%s'", expected, actual) - } - - // env var in server block body as argument - p = testParser(":{%PORT%}\ndir1 {$FOOBAR}") - blocks, _ = p.parseAll() - if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual { - t.Errorf("Expected port to be '%s' but was '%s'", expected, actual) - } - if actual, expected := blocks[0].Tokens["dir1"][1].text, "foobar"; expected != actual { - t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual) - } - - // combined windows env vars in argument - p = testParser(":{%PORT%}\ndir1 {%ADDRESS%}/{%FOOBAR%}") - blocks, _ = p.parseAll() - if actual, expected := blocks[0].Tokens["dir1"][1].text, "servername.com/foobar"; expected != actual { - t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual) - } - - // malformed env var (windows) - p = testParser(":1234\ndir1 {%ADDRESS}") - blocks, _ = p.parseAll() - if actual, expected := blocks[0].Tokens["dir1"][1].text, "{%ADDRESS}"; expected != actual { - t.Errorf("Expected host to be '%s' but was '%s'", expected, actual) - } - - // malformed (non-existent) env var (unix) - p = testParser(`:{$PORT$}`) - blocks, _ = p.parseAll() - if actual, expected := blocks[0].Addresses[0].Port, ""; expected != actual { - t.Errorf("Expected port to be '%s' but was '%s'", expected, actual) - } - - // in quoted field - p = testParser(":1234\ndir1 \"Test {$FOOBAR} test\"") - blocks, _ = p.parseAll() - if actual, expected := blocks[0].Tokens["dir1"][1].text, "Test foobar test"; expected != actual { - t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual) - } -} - -func setupParseTests() { - // Set up some bogus directives for testing - ValidDirectives = map[string]struct{}{ - "dir1": {}, - "dir2": {}, - "dir3": {}, - } -} - -func testParser(input string) parser { - buf := strings.NewReader(input) - p := parser{Dispenser: NewDispenser("Test", buf), checkDirectives: true} - return p -} diff --git a/caddy/restart.go b/caddy/restart.go deleted file mode 100644 index afd1b79e8..000000000 --- a/caddy/restart.go +++ /dev/null @@ -1,92 +0,0 @@ -// +build !windows - -package caddy - -import ( - "bytes" - "errors" - "log" - "net" - "path/filepath" - - "github.com/mholt/caddy/caddy/https" -) - -// Restart restarts the entire application; gracefully with zero -// downtime if on a POSIX-compatible system, or forcefully if on -// Windows but with imperceptibly-short downtime. -// -// The behavior can be controlled by the RestartMode variable, -// where "inproc" will restart forcefully in process same as -// Windows on a POSIX-compatible system. -// -// The restarted application will use newCaddyfile as its input -// configuration. If newCaddyfile is nil, the current (existing) -// Caddyfile configuration will be used. -// -// Note: The process must exist in the same place on the disk in -// order for this to work. Thus, multiple graceful restarts don't -// work if executing with `go run`, since the binary is cleaned up -// when `go run` sees the initial parent process exit. -func Restart(newCaddyfile Input) error { - log.Println("[INFO] Restarting") - - if newCaddyfile == nil { - caddyfileMu.Lock() - newCaddyfile = caddyfile - caddyfileMu.Unlock() - } - - // Get certificates for any new hosts in the new Caddyfile without causing downtime - err := getCertsForNewCaddyfile(newCaddyfile) - if err != nil { - return errors.New("TLS preload: " + err.Error()) - } - - // Add file descriptors of all the sockets for new instance - serversMu.Lock() - for _, s := range servers { - restartFds[s.Addr] = s.ListenerFd() - } - serversMu.Unlock() - - return restartInProc(newCaddyfile) -} - -func getCertsForNewCaddyfile(newCaddyfile Input) error { - // parse the new caddyfile only up to (and including) TLS - // so we can know what we need to get certs for. - configs, _, _, err := loadConfigsUpToIncludingTLS(filepath.Base(newCaddyfile.Path()), bytes.NewReader(newCaddyfile.Body())) - if err != nil { - return errors.New("loading Caddyfile: " + err.Error()) - } - - // first mark the configs that are qualified for managed TLS - https.MarkQualified(configs) - - // since we group by bind address to obtain certs, we must call - // EnableTLS to make sure the port is set properly first - // (can ignore error since we aren't actually using the certs) - https.EnableTLS(configs, false) - - // find out if we can let the acme package start its own challenge listener - // on port 80 - var proxyACME bool - serversMu.Lock() - for _, s := range servers { - _, port, _ := net.SplitHostPort(s.Addr) - if port == "80" { - proxyACME = true - break - } - } - serversMu.Unlock() - - // place certs on the disk - err = https.ObtainCerts(configs, false, proxyACME) - if err != nil { - return errors.New("obtaining certs: " + err.Error()) - } - - return nil -} diff --git a/caddy/restart_windows.go b/caddy/restart_windows.go deleted file mode 100644 index d860e9131..000000000 --- a/caddy/restart_windows.go +++ /dev/null @@ -1,17 +0,0 @@ -package caddy - -import "log" - -// Restart restarts Caddy forcefully using newCaddyfile, -// or, if nil, the current/existing Caddyfile is reused. -func Restart(newCaddyfile Input) error { - log.Println("[INFO] Restarting") - - if newCaddyfile == nil { - caddyfileMu.Lock() - newCaddyfile = caddyfile - caddyfileMu.Unlock() - } - - return restartInProc(newCaddyfile) -} diff --git a/caddy/restartinproc.go b/caddy/restartinproc.go deleted file mode 100644 index 677857a14..000000000 --- a/caddy/restartinproc.go +++ /dev/null @@ -1,28 +0,0 @@ -package caddy - -import "log" - -// restartInProc restarts Caddy forcefully in process using newCaddyfile. -func restartInProc(newCaddyfile Input) error { - wg.Add(1) // barrier so Wait() doesn't unblock - defer wg.Done() - - err := Stop() - if err != nil { - return err - } - - caddyfileMu.Lock() - oldCaddyfile := caddyfile - caddyfileMu.Unlock() - - err = Start(newCaddyfile) - if err != nil { - // revert to old Caddyfile - if oldErr := Start(oldCaddyfile); oldErr != nil { - log.Printf("[ERROR] Restart: in-process restart failed and cannot revert to old Caddyfile: %v", oldErr) - } - } - - return err -} diff --git a/caddy/setup/basicauth.go b/caddy/setup/basicauth.go deleted file mode 100644 index bc57d1c6e..000000000 --- a/caddy/setup/basicauth.go +++ /dev/null @@ -1,72 +0,0 @@ -package setup - -import ( - "strings" - - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/basicauth" -) - -// BasicAuth configures a new BasicAuth middleware instance. -func BasicAuth(c *Controller) (middleware.Middleware, error) { - root := c.Root - - rules, err := basicAuthParse(c) - if err != nil { - return nil, err - } - - basic := basicauth.BasicAuth{Rules: rules} - - return func(next middleware.Handler) middleware.Handler { - basic.Next = next - basic.SiteRoot = root - return basic - }, nil -} - -func basicAuthParse(c *Controller) ([]basicauth.Rule, error) { - var rules []basicauth.Rule - - var err error - for c.Next() { - var rule basicauth.Rule - - args := c.RemainingArgs() - - switch len(args) { - case 2: - rule.Username = args[0] - if rule.Password, err = passwordMatcher(rule.Username, args[1], c.Root); err != nil { - return rules, c.Errf("Get password matcher from %s: %v", c.Val(), err) - } - - for c.NextBlock() { - rule.Resources = append(rule.Resources, c.Val()) - if c.NextArg() { - return rules, c.Errf("Expecting only one resource per line (extra '%s')", c.Val()) - } - } - case 3: - rule.Resources = append(rule.Resources, args[0]) - rule.Username = args[1] - if rule.Password, err = passwordMatcher(rule.Username, args[2], c.Root); err != nil { - return rules, c.Errf("Get password matcher from %s: %v", c.Val(), err) - } - default: - return rules, c.ArgErr() - } - - rules = append(rules, rule) - } - - return rules, nil -} - -func passwordMatcher(username, passw, siteRoot string) (basicauth.PasswordMatcher, error) { - if !strings.HasPrefix(passw, "htpasswd=") { - return basicauth.PlainMatcher(passw), nil - } - - return basicauth.GetHtpasswdMatcher(passw[9:], username, siteRoot) -} diff --git a/caddy/setup/basicauth_test.go b/caddy/setup/basicauth_test.go deleted file mode 100644 index 186a3e97e..000000000 --- a/caddy/setup/basicauth_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package setup - -import ( - "fmt" - "io/ioutil" - "os" - "strings" - "testing" - - "github.com/mholt/caddy/middleware/basicauth" -) - -func TestBasicAuth(t *testing.T) { - c := NewTestController(`basicauth user pwd`) - - mid, err := BasicAuth(c) - if err != nil { - t.Errorf("Expected no errors, but got: %v", err) - } - if mid == nil { - t.Fatal("Expected middleware, was nil instead") - } - - handler := mid(EmptyNext) - myHandler, ok := handler.(basicauth.BasicAuth) - if !ok { - t.Fatalf("Expected handler to be type BasicAuth, got: %#v", handler) - } - - if !SameNext(myHandler.Next, EmptyNext) { - t.Error("'Next' field of handler was not set properly") - } -} - -func TestBasicAuthParse(t *testing.T) { - htpasswdPasswd := "IedFOuGmTpT8" - htpasswdFile := `sha1:{SHA}dcAUljwz99qFjYR0YLTXx0RqLww= -md5:$apr1$l42y8rex$pOA2VJ0x/0TwaFeAF9nX61` - - var skipHtpassword bool - htfh, err := ioutil.TempFile(".", "basicauth-") - if err != nil { - t.Logf("Error creating temp file (%v), will skip htpassword test", err) - skipHtpassword = true - } else { - if _, err = htfh.Write([]byte(htpasswdFile)); err != nil { - t.Fatalf("write htpasswd file %q: %v", htfh.Name(), err) - } - htfh.Close() - defer os.Remove(htfh.Name()) - } - - tests := []struct { - input string - shouldErr bool - password string - expected []basicauth.Rule - }{ - {`basicauth user pwd`, false, "pwd", []basicauth.Rule{ - {Username: "user"}, - }}, - {`basicauth user pwd { - }`, false, "pwd", []basicauth.Rule{ - {Username: "user"}, - }}, - {`basicauth user pwd { - /resource1 - /resource2 - }`, false, "pwd", []basicauth.Rule{ - {Username: "user", Resources: []string{"/resource1", "/resource2"}}, - }}, - {`basicauth /resource user pwd`, false, "pwd", []basicauth.Rule{ - {Username: "user", Resources: []string{"/resource"}}, - }}, - {`basicauth /res1 user1 pwd1 - basicauth /res2 user2 pwd2`, false, "pwd", []basicauth.Rule{ - {Username: "user1", Resources: []string{"/res1"}}, - {Username: "user2", Resources: []string{"/res2"}}, - }}, - {`basicauth user`, true, "", []basicauth.Rule{}}, - {`basicauth`, true, "", []basicauth.Rule{}}, - {`basicauth /resource user pwd asdf`, true, "", []basicauth.Rule{}}, - - {`basicauth sha1 htpasswd=` + htfh.Name(), false, htpasswdPasswd, []basicauth.Rule{ - {Username: "sha1"}, - }}, - } - - for i, test := range tests { - c := NewTestController(test.input) - actual, err := basicAuthParse(c) - - if err == nil && test.shouldErr { - t.Errorf("Test %d didn't error, but it should have", i) - } else if err != nil && !test.shouldErr { - t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) - } - - if len(actual) != len(test.expected) { - t.Fatalf("Test %d expected %d rules, but got %d", - i, len(test.expected), len(actual)) - } - - for j, expectedRule := range test.expected { - actualRule := actual[j] - - if actualRule.Username != expectedRule.Username { - t.Errorf("Test %d, rule %d: Expected username '%s', got '%s'", - i, j, expectedRule.Username, actualRule.Username) - } - - if strings.Contains(test.input, "htpasswd=") && skipHtpassword { - continue - } - pwd := test.password - if len(actual) > 1 { - pwd = fmt.Sprintf("%s%d", pwd, j+1) - } - if !actualRule.Password(pwd) || actualRule.Password(test.password+"!") { - t.Errorf("Test %d, rule %d: Expected password '%v', got '%v'", - i, j, test.password, actualRule.Password("")) - } - - expectedRes := fmt.Sprintf("%v", expectedRule.Resources) - actualRes := fmt.Sprintf("%v", actualRule.Resources) - if actualRes != expectedRes { - t.Errorf("Test %d, rule %d: Expected resource list %s, but got %s", - i, j, expectedRes, actualRes) - } - } - } -} diff --git a/caddy/setup/bindhost.go b/caddy/setup/bindhost.go deleted file mode 100644 index 363163dcb..000000000 --- a/caddy/setup/bindhost.go +++ /dev/null @@ -1,13 +0,0 @@ -package setup - -import "github.com/mholt/caddy/middleware" - -// BindHost sets the host to bind the listener to. -func BindHost(c *Controller) (middleware.Middleware, error) { - for c.Next() { - if !c.Args(&c.BindHost) { - return nil, c.ArgErr() - } - } - return nil, nil -} diff --git a/caddy/setup/browse.go b/caddy/setup/browse.go deleted file mode 100644 index fdb667227..000000000 --- a/caddy/setup/browse.go +++ /dev/null @@ -1,416 +0,0 @@ -package setup - -import ( - "fmt" - "io/ioutil" - "net/http" - "text/template" - - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/browse" -) - -// Browse configures a new Browse middleware instance. -func Browse(c *Controller) (middleware.Middleware, error) { - configs, err := browseParse(c) - if err != nil { - return nil, err - } - - browse := browse.Browse{ - Configs: configs, - IgnoreIndexes: false, - } - - return func(next middleware.Handler) middleware.Handler { - browse.Next = next - return browse - }, nil -} - -func browseParse(c *Controller) ([]browse.Config, error) { - var configs []browse.Config - - appendCfg := func(bc browse.Config) error { - for _, c := range configs { - if c.PathScope == bc.PathScope { - return fmt.Errorf("duplicate browsing config for %s", c.PathScope) - } - } - configs = append(configs, bc) - return nil - } - - for c.Next() { - var bc browse.Config - - // First argument is directory to allow browsing; default is site root - if c.NextArg() { - bc.PathScope = c.Val() - } else { - bc.PathScope = "/" - } - bc.Root = http.Dir(c.Root) - theRoot, err := bc.Root.Open("/") // catch a missing path early - if err != nil { - return configs, err - } - defer theRoot.Close() - _, err = theRoot.Readdir(-1) - if err != nil { - return configs, err - } - - // Second argument would be the template file to use - var tplText string - if c.NextArg() { - tplBytes, err := ioutil.ReadFile(c.Val()) - if err != nil { - return configs, err - } - tplText = string(tplBytes) - } else { - tplText = defaultTemplate - } - - // Build the template - tpl, err := template.New("listing").Parse(tplText) - if err != nil { - return configs, err - } - bc.Template = tpl - - // Save configuration - err = appendCfg(bc) - if err != nil { - return configs, err - } - } - - return configs, nil -} - -// The default template to use when serving up directory listings -const defaultTemplate = ` - - - {{.Name}} - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-

- {{range $url, $name := .BreadcrumbMap}}{{$name}}{{if ne $url "/"}}/{{end}}{{end}} -

-
-
-
-
- {{.NumDirs}} director{{if eq 1 .NumDirs}}y{{else}}ies{{end}} - {{.NumFiles}} file{{if ne 1 .NumFiles}}s{{end}} - {{- if ne 0 .ItemsLimitedTo}} - (of which only {{.ItemsLimitedTo}} are displayed) - {{- end}} -
-
-
- - - - - - - - - - {{- if .CanGoUp}} - - - - - - {{- end}} - {{- range .Items}} - - - {{- if .IsDir}} - - {{- else}} - - {{- end}} - - - {{- end}} - -
- {{- if and (eq .Sort "name") (ne .Order "desc")}} - Name - {{- else if and (eq .Sort "name") (ne .Order "asc")}} - Name - {{- else}} - Name - {{- end}} - - {{- if and (eq .Sort "size") (ne .Order "desc")}} - Size - {{- else if and (eq .Sort "size") (ne .Order "asc")}} - Size - {{- else}} - Size - {{- end}} - - {{- if and (eq .Sort "time") (ne .Order "desc")}} - Modified - {{- else if and (eq .Sort "time") (ne .Order "asc")}} - Modified - {{- else}} - Modified - {{- end}} -
- - Go up - -
- - {{- if .IsDir}} - - {{- else}} - - {{- end}} - {{.Name}} - - {{.HumanSize}}
-
-
- - - -` diff --git a/caddy/setup/browse_test.go b/caddy/setup/browse_test.go deleted file mode 100644 index 443e008bb..000000000 --- a/caddy/setup/browse_test.go +++ /dev/null @@ -1,65 +0,0 @@ -package setup - -import ( - "io/ioutil" - "os" - "path/filepath" - "strconv" - "testing" - "time" - - "github.com/mholt/caddy/middleware/browse" -) - -func TestBrowse(t *testing.T) { - - tempDirPath, err := getTempDirPath() - if err != nil { - t.Fatalf("BeforeTest: Failed to find an existing directory for testing! Error was: %v", err) - } - nonExistantDirPath := filepath.Join(tempDirPath, strconv.Itoa(int(time.Now().UnixNano()))) - - tempTemplate, err := ioutil.TempFile(".", "tempTemplate") - if err != nil { - t.Fatalf("BeforeTest: Failed to create a temporary file in the working directory! Error was: %v", err) - } - defer os.Remove(tempTemplate.Name()) - - tempTemplatePath := filepath.Join(".", tempTemplate.Name()) - - for i, test := range []struct { - input string - expectedPathScope []string - shouldErr bool - }{ - // test case #0 tests handling of multiple pathscopes - {"browse " + tempDirPath + "\n browse .", []string{tempDirPath, "."}, false}, - - // test case #1 tests instantiation of browse.Config with default values - {"browse /", []string{"/"}, false}, - - // test case #2 tests detectaction of custom template - {"browse . " + tempTemplatePath, []string{"."}, false}, - - // test case #3 tests detection of non-existent template - {"browse . " + nonExistantDirPath, nil, true}, - - // test case #4 tests detection of duplicate pathscopes - {"browse " + tempDirPath + "\n browse " + tempDirPath, nil, true}, - } { - - recievedFunc, err := Browse(NewTestController(test.input)) - if err != nil && !test.shouldErr { - t.Errorf("Test case #%d recieved an error of %v", i, err) - } - if test.expectedPathScope == nil { - continue - } - recievedConfigs := recievedFunc(nil).(browse.Browse).Configs - for j, config := range recievedConfigs { - if config.PathScope != test.expectedPathScope[j] { - t.Errorf("Test case #%d expected a pathscope of %v, but got %v", i, test.expectedPathScope, config.PathScope) - } - } - } -} diff --git a/caddy/setup/controller.go b/caddy/setup/controller.go deleted file mode 100644 index e31207263..000000000 --- a/caddy/setup/controller.go +++ /dev/null @@ -1,83 +0,0 @@ -package setup - -import ( - "fmt" - "net/http" - "strings" - - "github.com/mholt/caddy/caddy/parse" - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/server" -) - -// Controller is given to the setup function of middlewares which -// gives them access to be able to read tokens and set config. Each -// virtualhost gets their own server config and dispenser. -type Controller struct { - *server.Config - parse.Dispenser - - // OncePerServerBlock is a function that executes f - // exactly once per server block, no matter how many - // hosts are associated with it. If it is the first - // time, the function f is executed immediately - // (not deferred) and may return an error which is - // returned by OncePerServerBlock. - OncePerServerBlock func(f func() error) error - - // ServerBlockIndex is the 0-based index of the - // server block as it appeared in the input. - ServerBlockIndex int - - // ServerBlockHostIndex is the 0-based index of this - // host as it appeared in the input at the head of the - // server block. - ServerBlockHostIndex int - - // ServerBlockHosts is a list of hosts that are - // associated with this server block. All these - // hosts, consequently, share the same tokens. - ServerBlockHosts []string - - // ServerBlockStorage is used by a directive's - // setup function to persist state between all - // the hosts on a server block. - ServerBlockStorage interface{} -} - -// NewTestController creates a new *Controller for -// the input specified, with a filename of "Testfile". -// The Config is bare, consisting only of a Root of cwd. -// -// Used primarily for testing but needs to be exported so -// add-ons can use this as a convenience. Does not initialize -// the server-block-related fields. -func NewTestController(input string) *Controller { - return &Controller{ - Config: &server.Config{ - Root: ".", - }, - Dispenser: parse.NewDispenser("Testfile", strings.NewReader(input)), - OncePerServerBlock: func(f func() error) error { - return f() - }, - } -} - -// EmptyNext is a no-op function that can be passed into -// middleware.Middleware functions so that the assignment -// to the Next field of the Handler can be tested. -// -// Used primarily for testing but needs to be exported so -// add-ons can use this as a convenience. -var EmptyNext = middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - return 0, nil -}) - -// SameNext does a pointer comparison between next1 and next2. -// -// Used primarily for testing but needs to be exported so -// add-ons can use this as a convenience. -func SameNext(next1, next2 middleware.Handler) bool { - return fmt.Sprintf("%v", next1) == fmt.Sprintf("%v", next2) -} diff --git a/caddy/setup/errors.go b/caddy/setup/errors.go deleted file mode 100644 index b4c0ab697..000000000 --- a/caddy/setup/errors.go +++ /dev/null @@ -1,148 +0,0 @@ -package setup - -import ( - "io" - "log" - "os" - "path/filepath" - "strconv" - - "github.com/hashicorp/go-syslog" - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/errors" -) - -// Errors configures a new errors middleware instance. -func Errors(c *Controller) (middleware.Middleware, error) { - handler, err := errorsParse(c) - if err != nil { - return nil, err - } - - // Open the log file for writing when the server starts - c.Startup = append(c.Startup, func() error { - var err error - var writer io.Writer - - switch handler.LogFile { - case "visible": - handler.Debug = true - case "stdout": - writer = os.Stdout - case "stderr": - writer = os.Stderr - case "syslog": - writer, err = gsyslog.NewLogger(gsyslog.LOG_ERR, "LOCAL0", "caddy") - if err != nil { - return err - } - default: - if handler.LogFile == "" { - writer = os.Stderr // default - break - } - - var file *os.File - file, err = os.OpenFile(handler.LogFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644) - if err != nil { - return err - } - if handler.LogRoller != nil { - file.Close() - - handler.LogRoller.Filename = handler.LogFile - - writer = handler.LogRoller.GetLogWriter() - } else { - writer = file - } - } - - handler.Log = log.New(writer, "", 0) - return nil - }) - - return func(next middleware.Handler) middleware.Handler { - handler.Next = next - return handler - }, nil -} - -func errorsParse(c *Controller) (*errors.ErrorHandler, error) { - // Very important that we make a pointer because the Startup - // function that opens the log file must have access to the - // same instance of the handler, not a copy. - handler := &errors.ErrorHandler{ErrorPages: make(map[int]string)} - - optionalBlock := func() (bool, error) { - var hadBlock bool - - for c.NextBlock() { - hadBlock = true - - what := c.Val() - if !c.NextArg() { - return hadBlock, c.ArgErr() - } - where := c.Val() - - if what == "log" { - if where == "visible" { - handler.Debug = true - } else { - handler.LogFile = where - if c.NextArg() { - if c.Val() == "{" { - c.IncrNest() - logRoller, err := parseRoller(c) - if err != nil { - return hadBlock, err - } - handler.LogRoller = logRoller - } - } - } - } else { - // Error page; ensure it exists - where = filepath.Join(c.Root, where) - f, err := os.Open(where) - if err != nil { - log.Printf("[WARNING] Unable to open error page '%s': %v", where, err) - } - f.Close() - - whatInt, err := strconv.Atoi(what) - if err != nil { - return hadBlock, c.Err("Expecting a numeric status code, got '" + what + "'") - } - handler.ErrorPages[whatInt] = where - } - } - return hadBlock, nil - } - - for c.Next() { - // weird hack to avoid having the handler values overwritten. - if c.Val() == "}" { - continue - } - // Configuration may be in a block - hadBlock, err := optionalBlock() - if err != nil { - return handler, err - } - - // Otherwise, the only argument would be an error log file name or 'visible' - if !hadBlock { - if c.NextArg() { - if c.Val() == "visible" { - handler.Debug = true - } else { - handler.LogFile = c.Val() - } - } - } - } - - return handler, nil -} diff --git a/caddy/setup/errors_test.go b/caddy/setup/errors_test.go deleted file mode 100644 index ace04624d..000000000 --- a/caddy/setup/errors_test.go +++ /dev/null @@ -1,158 +0,0 @@ -package setup - -import ( - "testing" - - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/errors" -) - -func TestErrors(t *testing.T) { - c := NewTestController(`errors`) - mid, err := Errors(c) - - if err != nil { - t.Errorf("Expected no errors, got: %v", err) - } - - if mid == nil { - t.Fatal("Expected middleware, was nil instead") - } - - handler := mid(EmptyNext) - myHandler, ok := handler.(*errors.ErrorHandler) - - if !ok { - t.Fatalf("Expected handler to be type ErrorHandler, got: %#v", handler) - } - - if myHandler.LogFile != "" { - t.Errorf("Expected '%s' as the default LogFile", "") - } - if myHandler.LogRoller != nil { - t.Errorf("Expected LogRoller to be nil, got: %v", *myHandler.LogRoller) - } - if !SameNext(myHandler.Next, EmptyNext) { - t.Error("'Next' field of handler was not set properly") - } - - // Test Startup function - if len(c.Startup) == 0 { - t.Fatal("Expected 1 startup function, had 0") - } - c.Startup[0]() - if myHandler.Log == nil { - t.Error("Expected Log to be non-nil after startup because Debug is not enabled") - } -} - -func TestErrorsParse(t *testing.T) { - tests := []struct { - inputErrorsRules string - shouldErr bool - expectedErrorHandler errors.ErrorHandler - }{ - {`errors`, false, errors.ErrorHandler{ - LogFile: "", - }}, - {`errors errors.txt`, false, errors.ErrorHandler{ - LogFile: "errors.txt", - }}, - {`errors visible`, false, errors.ErrorHandler{ - LogFile: "", - Debug: true, - }}, - {`errors { log visible }`, false, errors.ErrorHandler{ - LogFile: "", - Debug: true, - }}, - {`errors { log errors.txt - 404 404.html - 500 500.html -}`, false, errors.ErrorHandler{ - LogFile: "errors.txt", - ErrorPages: map[int]string{ - 404: "404.html", - 500: "500.html", - }, - }}, - {`errors { log errors.txt { size 2 age 10 keep 3 } }`, false, errors.ErrorHandler{ - LogFile: "errors.txt", - LogRoller: &middleware.LogRoller{ - MaxSize: 2, - MaxAge: 10, - MaxBackups: 3, - LocalTime: true, - }, - }}, - {`errors { log errors.txt { - size 3 - age 11 - keep 5 - } - 404 404.html - 503 503.html -}`, false, errors.ErrorHandler{ - LogFile: "errors.txt", - ErrorPages: map[int]string{ - 404: "404.html", - 503: "503.html", - }, - LogRoller: &middleware.LogRoller{ - MaxSize: 3, - MaxAge: 11, - MaxBackups: 5, - LocalTime: true, - }, - }}, - } - for i, test := range tests { - c := NewTestController(test.inputErrorsRules) - actualErrorsRule, err := errorsParse(c) - - if err == nil && test.shouldErr { - t.Errorf("Test %d didn't error, but it should have", i) - } else if err != nil && !test.shouldErr { - t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) - } - if actualErrorsRule.LogFile != test.expectedErrorHandler.LogFile { - t.Errorf("Test %d expected LogFile to be %s, but got %s", - i, test.expectedErrorHandler.LogFile, actualErrorsRule.LogFile) - } - if actualErrorsRule.Debug != test.expectedErrorHandler.Debug { - t.Errorf("Test %d expected Debug to be %v, but got %v", - i, test.expectedErrorHandler.Debug, actualErrorsRule.Debug) - } - if actualErrorsRule.LogRoller != nil && test.expectedErrorHandler.LogRoller == nil || actualErrorsRule.LogRoller == nil && test.expectedErrorHandler.LogRoller != nil { - t.Fatalf("Test %d expected LogRoller to be %v, but got %v", - i, test.expectedErrorHandler.LogRoller, actualErrorsRule.LogRoller) - } - if len(actualErrorsRule.ErrorPages) != len(test.expectedErrorHandler.ErrorPages) { - t.Fatalf("Test %d expected %d no of Error pages, but got %d ", - i, len(test.expectedErrorHandler.ErrorPages), len(actualErrorsRule.ErrorPages)) - } - if actualErrorsRule.LogRoller != nil && test.expectedErrorHandler.LogRoller != nil { - if actualErrorsRule.LogRoller.Filename != test.expectedErrorHandler.LogRoller.Filename { - t.Fatalf("Test %d expected LogRoller Filename to be %s, but got %s", - i, test.expectedErrorHandler.LogRoller.Filename, actualErrorsRule.LogRoller.Filename) - } - if actualErrorsRule.LogRoller.MaxAge != test.expectedErrorHandler.LogRoller.MaxAge { - t.Fatalf("Test %d expected LogRoller MaxAge to be %d, but got %d", - i, test.expectedErrorHandler.LogRoller.MaxAge, actualErrorsRule.LogRoller.MaxAge) - } - if actualErrorsRule.LogRoller.MaxBackups != test.expectedErrorHandler.LogRoller.MaxBackups { - t.Fatalf("Test %d expected LogRoller MaxBackups to be %d, but got %d", - i, test.expectedErrorHandler.LogRoller.MaxBackups, actualErrorsRule.LogRoller.MaxBackups) - } - if actualErrorsRule.LogRoller.MaxSize != test.expectedErrorHandler.LogRoller.MaxSize { - t.Fatalf("Test %d expected LogRoller MaxSize to be %d, but got %d", - i, test.expectedErrorHandler.LogRoller.MaxSize, actualErrorsRule.LogRoller.MaxSize) - } - if actualErrorsRule.LogRoller.LocalTime != test.expectedErrorHandler.LogRoller.LocalTime { - t.Fatalf("Test %d expected LogRoller LocalTime to be %t, but got %t", - i, test.expectedErrorHandler.LogRoller.LocalTime, actualErrorsRule.LogRoller.LocalTime) - } - } - } - -} diff --git a/caddy/setup/expvar.go b/caddy/setup/expvar.go deleted file mode 100644 index 4d9c353de..000000000 --- a/caddy/setup/expvar.go +++ /dev/null @@ -1,60 +0,0 @@ -package setup - -import ( - stdexpvar "expvar" - "runtime" - "sync" - - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/expvar" -) - -// ExpVar configures a new ExpVar middleware instance. -func ExpVar(c *Controller) (middleware.Middleware, error) { - resource, err := expVarParse(c) - if err != nil { - return nil, err - } - - // publish any extra information/metrics we may want to capture - publishExtraVars() - - expvar := expvar.ExpVar{Resource: resource} - - return func(next middleware.Handler) middleware.Handler { - expvar.Next = next - return expvar - }, nil -} - -func expVarParse(c *Controller) (expvar.Resource, error) { - var resource expvar.Resource - var err error - - for c.Next() { - args := c.RemainingArgs() - switch len(args) { - case 0: - resource = expvar.Resource(defaultExpvarPath) - case 1: - resource = expvar.Resource(args[0]) - default: - return resource, c.ArgErr() - } - } - - return resource, err -} - -func publishExtraVars() { - // By using sync.Once instead of an init() function, we don't clutter - // the app's expvar export unnecessarily, or risk colliding with it. - publishOnce.Do(func() { - stdexpvar.Publish("Goroutines", stdexpvar.Func(func() interface{} { - return runtime.NumGoroutine() - })) - }) -} - -var publishOnce sync.Once // publishing variables should only be done once -var defaultExpvarPath = "/debug/vars" diff --git a/caddy/setup/expvar_test.go b/caddy/setup/expvar_test.go deleted file mode 100644 index 5fb018ce9..000000000 --- a/caddy/setup/expvar_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package setup - -import ( - "testing" - - "github.com/mholt/caddy/middleware/expvar" -) - -func TestExpvar(t *testing.T) { - c := NewTestController(`expvar`) - mid, err := ExpVar(c) - if err != nil { - t.Errorf("Expected no errors, got: %v", err) - } - if mid == nil { - t.Fatal("Expected middleware, was nil instead") - } - - c = NewTestController(`expvar /d/v`) - mid, err = ExpVar(c) - if err != nil { - t.Errorf("Expected no errors, got: %v", err) - } - if mid == nil { - t.Fatal("Expected middleware, was nil instead") - } - - handler := mid(EmptyNext) - myHandler, ok := handler.(expvar.ExpVar) - if !ok { - t.Fatalf("Expected handler to be type ExpVar, got: %#v", handler) - } - if myHandler.Resource != "/d/v" { - t.Errorf("Expected /d/v as expvar resource") - } - if !SameNext(myHandler.Next, EmptyNext) { - t.Error("'Next' field of handler was not set properly") - } -} diff --git a/caddy/setup/ext.go b/caddy/setup/ext.go deleted file mode 100644 index bfd4cba55..000000000 --- a/caddy/setup/ext.go +++ /dev/null @@ -1,55 +0,0 @@ -package setup - -import ( - "os" - "path/filepath" - - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/extensions" -) - -// Ext configures a new instance of 'extensions' middleware for clean URLs. -func Ext(c *Controller) (middleware.Middleware, error) { - root := c.Root - - exts, err := extParse(c) - if err != nil { - return nil, err - } - - return func(next middleware.Handler) middleware.Handler { - return extensions.Ext{ - Next: next, - Extensions: exts, - Root: root, - } - }, nil -} - -// extParse sets up an instance of extension middleware -// from a middleware controller and returns a list of extensions. -func extParse(c *Controller) ([]string, error) { - var exts []string - - for c.Next() { - // At least one extension is required - if !c.NextArg() { - return exts, c.ArgErr() - } - exts = append(exts, c.Val()) - - // Tack on any other extensions that may have been listed - exts = append(exts, c.RemainingArgs()...) - } - - return exts, nil -} - -// resourceExists returns true if the file specified at -// root + path exists; false otherwise. -func resourceExists(root, path string) bool { - _, err := os.Stat(filepath.Join(root, path)) - // technically we should use os.IsNotExist(err) - // but we don't handle any other kinds of errors anyway - return err == nil -} diff --git a/caddy/setup/ext_test.go b/caddy/setup/ext_test.go deleted file mode 100644 index 24e3cf947..000000000 --- a/caddy/setup/ext_test.go +++ /dev/null @@ -1,76 +0,0 @@ -package setup - -import ( - "testing" - - "github.com/mholt/caddy/middleware/extensions" -) - -func TestExt(t *testing.T) { - c := NewTestController(`ext .html .htm .php`) - - mid, err := Ext(c) - - if err != nil { - t.Errorf("Expected no errors, got: %v", err) - } - - if mid == nil { - t.Fatal("Expected middleware, was nil instead") - } - - handler := mid(EmptyNext) - myHandler, ok := handler.(extensions.Ext) - - if !ok { - t.Fatalf("Expected handler to be type Ext, got: %#v", handler) - } - - if myHandler.Extensions[0] != ".html" { - t.Errorf("Expected .html in the list of Extensions") - } - if myHandler.Extensions[1] != ".htm" { - t.Errorf("Expected .htm in the list of Extensions") - } - if myHandler.Extensions[2] != ".php" { - t.Errorf("Expected .php in the list of Extensions") - } - if !SameNext(myHandler.Next, EmptyNext) { - t.Error("'Next' field of handler was not set properly") - } - -} - -func TestExtParse(t *testing.T) { - tests := []struct { - inputExts string - shouldErr bool - expectedExts []string - }{ - {`ext .html .htm .php`, false, []string{".html", ".htm", ".php"}}, - {`ext .php .html .xml`, false, []string{".php", ".html", ".xml"}}, - {`ext .txt .php .xml`, false, []string{".txt", ".php", ".xml"}}, - } - for i, test := range tests { - c := NewTestController(test.inputExts) - actualExts, err := extParse(c) - - if err == nil && test.shouldErr { - t.Errorf("Test %d didn't error, but it should have", i) - } else if err != nil && !test.shouldErr { - t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) - } - - if len(actualExts) != len(test.expectedExts) { - t.Fatalf("Test %d expected %d rules, but got %d", - i, len(test.expectedExts), len(actualExts)) - } - for j, actualExt := range actualExts { - if actualExt != test.expectedExts[j] { - t.Fatalf("Test %d expected %dth extension to be %s , but got %s", - i, j, test.expectedExts[j], actualExt) - } - } - } - -} diff --git a/caddy/setup/fastcgi.go b/caddy/setup/fastcgi.go deleted file mode 100644 index d1e53d151..000000000 --- a/caddy/setup/fastcgi.go +++ /dev/null @@ -1,116 +0,0 @@ -package setup - -import ( - "errors" - "net/http" - "path/filepath" - - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/fastcgi" -) - -// FastCGI configures a new FastCGI middleware instance. -func FastCGI(c *Controller) (middleware.Middleware, error) { - absRoot, err := filepath.Abs(c.Root) - if err != nil { - return nil, err - } - - rules, err := fastcgiParse(c) - if err != nil { - return nil, err - } - - return func(next middleware.Handler) middleware.Handler { - return fastcgi.Handler{ - Next: next, - Rules: rules, - Root: c.Root, - AbsRoot: absRoot, - FileSys: http.Dir(c.Root), - SoftwareName: c.AppName, - SoftwareVersion: c.AppVersion, - ServerName: c.Host, - ServerPort: c.Port, - } - }, nil -} - -func fastcgiParse(c *Controller) ([]fastcgi.Rule, error) { - var rules []fastcgi.Rule - - for c.Next() { - var rule fastcgi.Rule - - args := c.RemainingArgs() - - switch len(args) { - case 0: - return rules, c.ArgErr() - case 1: - rule.Path = "/" - rule.Address = args[0] - case 2: - rule.Path = args[0] - rule.Address = args[1] - case 3: - rule.Path = args[0] - rule.Address = args[1] - err := fastcgiPreset(args[2], &rule) - if err != nil { - return rules, c.Err("Invalid fastcgi rule preset '" + args[2] + "'") - } - } - - for c.NextBlock() { - switch c.Val() { - case "ext": - if !c.NextArg() { - return rules, c.ArgErr() - } - rule.Ext = c.Val() - case "split": - if !c.NextArg() { - return rules, c.ArgErr() - } - rule.SplitPath = c.Val() - case "index": - args := c.RemainingArgs() - if len(args) == 0 { - return rules, c.ArgErr() - } - rule.IndexFiles = args - case "env": - envArgs := c.RemainingArgs() - if len(envArgs) < 2 { - return rules, c.ArgErr() - } - rule.EnvVars = append(rule.EnvVars, [2]string{envArgs[0], envArgs[1]}) - case "except": - ignoredPaths := c.RemainingArgs() - if len(ignoredPaths) == 0 { - return rules, c.ArgErr() - } - rule.IgnoredSubPaths = ignoredPaths - } - } - - rules = append(rules, rule) - } - - return rules, nil -} - -// fastcgiPreset configures rule according to name. It returns an error if -// name is not a recognized preset name. -func fastcgiPreset(name string, rule *fastcgi.Rule) error { - switch name { - case "php": - rule.Ext = ".php" - rule.SplitPath = ".php" - rule.IndexFiles = []string{"index.php"} - default: - return errors.New(name + " is not a valid preset name") - } - return nil -} diff --git a/caddy/setup/fastcgi_test.go b/caddy/setup/fastcgi_test.go deleted file mode 100644 index 366446dee..000000000 --- a/caddy/setup/fastcgi_test.go +++ /dev/null @@ -1,125 +0,0 @@ -package setup - -import ( - "fmt" - "testing" - - "github.com/mholt/caddy/middleware/fastcgi" -) - -func TestFastCGI(t *testing.T) { - - c := NewTestController(`fastcgi / 127.0.0.1:9000`) - - mid, err := FastCGI(c) - - if err != nil { - t.Errorf("Expected no errors, got: %v", err) - } - - if mid == nil { - t.Fatal("Expected middleware, was nil instead") - } - - handler := mid(EmptyNext) - myHandler, ok := handler.(fastcgi.Handler) - - if !ok { - t.Fatalf("Expected handler to be type , got: %#v", handler) - } - - if myHandler.Rules[0].Path != "/" { - t.Errorf("Expected / as the Path") - } - if myHandler.Rules[0].Address != "127.0.0.1:9000" { - t.Errorf("Expected 127.0.0.1:9000 as the Address") - } - -} - -func TestFastcgiParse(t *testing.T) { - tests := []struct { - inputFastcgiConfig string - shouldErr bool - expectedFastcgiConfig []fastcgi.Rule - }{ - - {`fastcgi /blog 127.0.0.1:9000 php`, - false, []fastcgi.Rule{{ - Path: "/blog", - Address: "127.0.0.1:9000", - Ext: ".php", - SplitPath: ".php", - IndexFiles: []string{"index.php"}, - }}}, - {`fastcgi / 127.0.0.1:9001 { - split .html - }`, - false, []fastcgi.Rule{{ - Path: "/", - Address: "127.0.0.1:9001", - Ext: "", - SplitPath: ".html", - IndexFiles: []string{}, - }}}, - {`fastcgi / 127.0.0.1:9001 { - split .html - except /admin /user - }`, - false, []fastcgi.Rule{{ - Path: "/", - Address: "127.0.0.1:9001", - Ext: "", - SplitPath: ".html", - IndexFiles: []string{}, - IgnoredSubPaths: []string{"/admin", "/user"}, - }}}, - } - for i, test := range tests { - c := NewTestController(test.inputFastcgiConfig) - actualFastcgiConfigs, err := fastcgiParse(c) - - if err == nil && test.shouldErr { - t.Errorf("Test %d didn't error, but it should have", i) - } else if err != nil && !test.shouldErr { - t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) - } - if len(actualFastcgiConfigs) != len(test.expectedFastcgiConfig) { - t.Fatalf("Test %d expected %d no of FastCGI configs, but got %d ", - i, len(test.expectedFastcgiConfig), len(actualFastcgiConfigs)) - } - for j, actualFastcgiConfig := range actualFastcgiConfigs { - - if actualFastcgiConfig.Path != test.expectedFastcgiConfig[j].Path { - t.Errorf("Test %d expected %dth FastCGI Path to be %s , but got %s", - i, j, test.expectedFastcgiConfig[j].Path, actualFastcgiConfig.Path) - } - - if actualFastcgiConfig.Address != test.expectedFastcgiConfig[j].Address { - t.Errorf("Test %d expected %dth FastCGI Address to be %s , but got %s", - i, j, test.expectedFastcgiConfig[j].Address, actualFastcgiConfig.Address) - } - - if actualFastcgiConfig.Ext != test.expectedFastcgiConfig[j].Ext { - t.Errorf("Test %d expected %dth FastCGI Ext to be %s , but got %s", - i, j, test.expectedFastcgiConfig[j].Ext, actualFastcgiConfig.Ext) - } - - if actualFastcgiConfig.SplitPath != test.expectedFastcgiConfig[j].SplitPath { - t.Errorf("Test %d expected %dth FastCGI SplitPath to be %s , but got %s", - i, j, test.expectedFastcgiConfig[j].SplitPath, actualFastcgiConfig.SplitPath) - } - - if fmt.Sprint(actualFastcgiConfig.IndexFiles) != fmt.Sprint(test.expectedFastcgiConfig[j].IndexFiles) { - t.Errorf("Test %d expected %dth FastCGI IndexFiles to be %s , but got %s", - i, j, test.expectedFastcgiConfig[j].IndexFiles, actualFastcgiConfig.IndexFiles) - } - - if fmt.Sprint(actualFastcgiConfig.IgnoredSubPaths) != fmt.Sprint(test.expectedFastcgiConfig[j].IgnoredSubPaths) { - t.Errorf("Test %d expected %dth FastCGI IgnoredSubPaths to be %s , but got %s", - i, j, test.expectedFastcgiConfig[j].IgnoredSubPaths, actualFastcgiConfig.IgnoredSubPaths) - } - } - } - -} diff --git a/caddy/setup/headers.go b/caddy/setup/headers.go deleted file mode 100644 index 553f20b18..000000000 --- a/caddy/setup/headers.go +++ /dev/null @@ -1,84 +0,0 @@ -package setup - -import ( - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/headers" -) - -// Headers configures a new Headers middleware instance. -func Headers(c *Controller) (middleware.Middleware, error) { - rules, err := headersParse(c) - if err != nil { - return nil, err - } - - return func(next middleware.Handler) middleware.Handler { - return headers.Headers{Next: next, Rules: rules} - }, nil -} - -func headersParse(c *Controller) ([]headers.Rule, error) { - var rules []headers.Rule - - for c.NextLine() { - var head headers.Rule - var isNewPattern bool - - if !c.NextArg() { - return rules, c.ArgErr() - } - pattern := c.Val() - - // See if we already have a definition for this Path pattern... - for _, h := range rules { - if h.Path == pattern { - head = h - break - } - } - - // ...otherwise, this is a new pattern - if head.Path == "" { - head.Path = pattern - isNewPattern = true - } - - for c.NextBlock() { - // A block of headers was opened... - - h := headers.Header{Name: c.Val()} - - if c.NextArg() { - h.Value = c.Val() - } - - head.Headers = append(head.Headers, h) - } - if c.NextArg() { - // ... or single header was defined as an argument instead. - - h := headers.Header{Name: c.Val()} - - h.Value = c.Val() - - if c.NextArg() { - h.Value = c.Val() - } - - head.Headers = append(head.Headers, h) - } - - if isNewPattern { - rules = append(rules, head) - } else { - for i := 0; i < len(rules); i++ { - if rules[i].Path == pattern { - rules[i] = head - break - } - } - } - } - - return rules, nil -} diff --git a/caddy/setup/headers_test.go b/caddy/setup/headers_test.go deleted file mode 100644 index 7b111cb42..000000000 --- a/caddy/setup/headers_test.go +++ /dev/null @@ -1,85 +0,0 @@ -package setup - -import ( - "fmt" - "testing" - - "github.com/mholt/caddy/middleware/headers" -) - -func TestHeaders(t *testing.T) { - c := NewTestController(`header / Foo Bar`) - - mid, err := Headers(c) - if err != nil { - t.Errorf("Expected no errors, but got: %v", err) - } - if mid == nil { - t.Fatal("Expected middleware, was nil instead") - } - - handler := mid(EmptyNext) - myHandler, ok := handler.(headers.Headers) - if !ok { - t.Fatalf("Expected handler to be type Headers, got: %#v", handler) - } - - if !SameNext(myHandler.Next, EmptyNext) { - t.Error("'Next' field of handler was not set properly") - } -} - -func TestHeadersParse(t *testing.T) { - tests := []struct { - input string - shouldErr bool - expected []headers.Rule - }{ - {`header /foo Foo "Bar Baz"`, - false, []headers.Rule{ - {Path: "/foo", Headers: []headers.Header{ - {Name: "Foo", Value: "Bar Baz"}, - }}, - }}, - {`header /bar { Foo "Bar Baz" Baz Qux }`, - false, []headers.Rule{ - {Path: "/bar", Headers: []headers.Header{ - {Name: "Foo", Value: "Bar Baz"}, - {Name: "Baz", Value: "Qux"}, - }}, - }}, - } - - for i, test := range tests { - c := NewTestController(test.input) - actual, err := headersParse(c) - - if err == nil && test.shouldErr { - t.Errorf("Test %d didn't error, but it should have", i) - } else if err != nil && !test.shouldErr { - t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) - } - - if len(actual) != len(test.expected) { - t.Fatalf("Test %d expected %d rules, but got %d", - i, len(test.expected), len(actual)) - } - - for j, expectedRule := range test.expected { - actualRule := actual[j] - - if actualRule.Path != expectedRule.Path { - t.Errorf("Test %d, rule %d: Expected path %s, but got %s", - i, j, expectedRule.Path, actualRule.Path) - } - - expectedHeaders := fmt.Sprintf("%v", expectedRule.Headers) - actualHeaders := fmt.Sprintf("%v", actualRule.Headers) - - if actualHeaders != expectedHeaders { - t.Errorf("Test %d, rule %d: Expected headers %s, but got %s", - i, j, expectedHeaders, actualHeaders) - } - } - } -} diff --git a/caddy/setup/internal.go b/caddy/setup/internal.go deleted file mode 100644 index e83863b80..000000000 --- a/caddy/setup/internal.go +++ /dev/null @@ -1,31 +0,0 @@ -package setup - -import ( - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/inner" -) - -// Internal configures a new Internal middleware instance. -func Internal(c *Controller) (middleware.Middleware, error) { - paths, err := internalParse(c) - if err != nil { - return nil, err - } - - return func(next middleware.Handler) middleware.Handler { - return inner.Internal{Next: next, Paths: paths} - }, nil -} - -func internalParse(c *Controller) ([]string, error) { - var paths []string - - for c.Next() { - if !c.NextArg() { - return paths, c.ArgErr() - } - paths = append(paths, c.Val()) - } - - return paths, nil -} diff --git a/caddy/setup/internal_test.go b/caddy/setup/internal_test.go deleted file mode 100644 index f4d0ed8b9..000000000 --- a/caddy/setup/internal_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package setup - -import ( - "testing" - - "github.com/mholt/caddy/middleware/inner" -) - -func TestInternal(t *testing.T) { - c := NewTestController(`internal /internal`) - - mid, err := Internal(c) - - if err != nil { - t.Errorf("Expected no errors, got: %v", err) - } - - if mid == nil { - t.Fatal("Expected middleware, was nil instead") - } - - handler := mid(EmptyNext) - myHandler, ok := handler.(inner.Internal) - - if !ok { - t.Fatalf("Expected handler to be type Internal, got: %#v", handler) - } - - if myHandler.Paths[0] != "/internal" { - t.Errorf("Expected internal in the list of internal Paths") - } - - if !SameNext(myHandler.Next, EmptyNext) { - t.Error("'Next' field of handler was not set properly") - } - -} - -func TestInternalParse(t *testing.T) { - tests := []struct { - inputInternalPaths string - shouldErr bool - expectedInternalPaths []string - }{ - {`internal /internal`, false, []string{"/internal"}}, - - {`internal /internal1 - internal /internal2`, false, []string{"/internal1", "/internal2"}}, - } - for i, test := range tests { - c := NewTestController(test.inputInternalPaths) - actualInternalPaths, err := internalParse(c) - - if err == nil && test.shouldErr { - t.Errorf("Test %d didn't error, but it should have", i) - } else if err != nil && !test.shouldErr { - t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) - } - - if len(actualInternalPaths) != len(test.expectedInternalPaths) { - t.Fatalf("Test %d expected %d InternalPaths, but got %d", - i, len(test.expectedInternalPaths), len(actualInternalPaths)) - } - for j, actualInternalPath := range actualInternalPaths { - if actualInternalPath != test.expectedInternalPaths[j] { - t.Fatalf("Test %d expected %dth Internal Path to be %s , but got %s", - i, j, test.expectedInternalPaths[j], actualInternalPath) - } - } - } - -} diff --git a/caddy/setup/markdown.go b/caddy/setup/markdown.go deleted file mode 100644 index fdc91991a..000000000 --- a/caddy/setup/markdown.go +++ /dev/null @@ -1,127 +0,0 @@ -package setup - -import ( - "net/http" - "path/filepath" - - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/markdown" - "github.com/russross/blackfriday" -) - -// Markdown configures a new Markdown middleware instance. -func Markdown(c *Controller) (middleware.Middleware, error) { - mdconfigs, err := markdownParse(c) - if err != nil { - return nil, err - } - - md := markdown.Markdown{ - Root: c.Root, - FileSys: http.Dir(c.Root), - Configs: mdconfigs, - IndexFiles: []string{"index.md"}, - } - - return func(next middleware.Handler) middleware.Handler { - md.Next = next - return md - }, nil -} - -func markdownParse(c *Controller) ([]*markdown.Config, error) { - var mdconfigs []*markdown.Config - - for c.Next() { - md := &markdown.Config{ - Renderer: blackfriday.HtmlRenderer(0, "", ""), - Extensions: make(map[string]struct{}), - Template: markdown.GetDefaultTemplate(), - } - - // Get the path scope - args := c.RemainingArgs() - switch len(args) { - case 0: - md.PathScope = "/" - case 1: - md.PathScope = args[0] - default: - return mdconfigs, c.ArgErr() - } - - // Load any other configuration parameters - for c.NextBlock() { - if err := loadParams(c, md); err != nil { - return mdconfigs, err - } - } - - // If no extensions were specified, assume some defaults - if len(md.Extensions) == 0 { - md.Extensions[".md"] = struct{}{} - md.Extensions[".markdown"] = struct{}{} - md.Extensions[".mdown"] = struct{}{} - } - - mdconfigs = append(mdconfigs, md) - } - - return mdconfigs, nil -} - -func loadParams(c *Controller, mdc *markdown.Config) error { - switch c.Val() { - case "ext": - for _, ext := range c.RemainingArgs() { - mdc.Extensions[ext] = struct{}{} - } - return nil - case "css": - if !c.NextArg() { - return c.ArgErr() - } - mdc.Styles = append(mdc.Styles, c.Val()) - return nil - case "js": - if !c.NextArg() { - return c.ArgErr() - } - mdc.Scripts = append(mdc.Scripts, c.Val()) - return nil - case "template": - tArgs := c.RemainingArgs() - switch len(tArgs) { - default: - return c.ArgErr() - case 1: - fpath := filepath.ToSlash(filepath.Clean(c.Root + string(filepath.Separator) + tArgs[0])) - - if err := markdown.SetTemplate(mdc.Template, "", fpath); err != nil { - c.Errf("default template parse error: %v", err) - } - return nil - case 2: - fpath := filepath.ToSlash(filepath.Clean(c.Root + string(filepath.Separator) + tArgs[1])) - - if err := markdown.SetTemplate(mdc.Template, tArgs[0], fpath); err != nil { - c.Errf("template parse error: %v", err) - } - return nil - } - case "templatedir": - if !c.NextArg() { - return c.ArgErr() - } - _, err := mdc.Template.ParseGlob(c.Val()) - if err != nil { - c.Errf("template load error: %v", err) - } - if c.NextArg() { - return c.ArgErr() - } - return nil - default: - return c.Err("Expected valid markdown configuration property") - } -} diff --git a/caddy/setup/markdown_test.go b/caddy/setup/markdown_test.go deleted file mode 100644 index fee9a3326..000000000 --- a/caddy/setup/markdown_test.go +++ /dev/null @@ -1,156 +0,0 @@ -package setup - -import ( - "bytes" - "fmt" - "net/http" - "testing" - "text/template" - - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/markdown" -) - -func TestMarkdown(t *testing.T) { - - c := NewTestController(`markdown /blog`) - - mid, err := Markdown(c) - - if err != nil { - t.Errorf("Expected no errors, got: %v", err) - } - - if mid == nil { - t.Fatal("Expected middleware, was nil instead") - } - - handler := mid(EmptyNext) - myHandler, ok := handler.(markdown.Markdown) - - if !ok { - t.Fatalf("Expected handler to be type Markdown, got: %#v", handler) - } - - if myHandler.Configs[0].PathScope != "/blog" { - t.Errorf("Expected /blog as the Path Scope") - } - if len(myHandler.Configs[0].Extensions) != 3 { - t.Error("Expected 3 markdown extensions") - } - for _, key := range []string{".md", ".markdown", ".mdown"} { - if ext, ok := myHandler.Configs[0].Extensions[key]; !ok { - t.Errorf("Expected extensions to contain %v", ext) - } - } -} - -func TestMarkdownParse(t *testing.T) { - tests := []struct { - inputMarkdownConfig string - shouldErr bool - expectedMarkdownConfig []markdown.Config - }{ - - {`markdown /blog { - ext .md .txt - css /resources/css/blog.css - js /resources/js/blog.js -}`, false, []markdown.Config{{ - PathScope: "/blog", - Extensions: map[string]struct{}{ - ".md": {}, - ".txt": {}, - }, - Styles: []string{"/resources/css/blog.css"}, - Scripts: []string{"/resources/js/blog.js"}, - Template: markdown.GetDefaultTemplate(), - }}}, - {`markdown /blog { - ext .md - template tpl_with_include.html -}`, false, []markdown.Config{{ - PathScope: "/blog", - Extensions: map[string]struct{}{ - ".md": {}, - }, - Template: markdown.GetDefaultTemplate(), - }}}, - } - // Setup the extra template - tmpl := tests[1].expectedMarkdownConfig[0].Template - markdown.SetTemplate(tmpl, "", "./testdata/tpl_with_include.html") - - for i, test := range tests { - c := NewTestController(test.inputMarkdownConfig) - c.Root = "./testdata" - actualMarkdownConfigs, err := markdownParse(c) - - if err == nil && test.shouldErr { - t.Errorf("Test %d didn't error, but it should have", i) - } else if err != nil && !test.shouldErr { - t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) - } - if len(actualMarkdownConfigs) != len(test.expectedMarkdownConfig) { - t.Fatalf("Test %d expected %d no of WebSocket configs, but got %d ", - i, len(test.expectedMarkdownConfig), len(actualMarkdownConfigs)) - } - for j, actualMarkdownConfig := range actualMarkdownConfigs { - - if actualMarkdownConfig.PathScope != test.expectedMarkdownConfig[j].PathScope { - t.Errorf("Test %d expected %dth Markdown PathScope to be %s , but got %s", - i, j, test.expectedMarkdownConfig[j].PathScope, actualMarkdownConfig.PathScope) - } - - if fmt.Sprint(actualMarkdownConfig.Styles) != fmt.Sprint(test.expectedMarkdownConfig[j].Styles) { - t.Errorf("Test %d expected %dth Markdown Config Styles to be %s , but got %s", - i, j, fmt.Sprint(test.expectedMarkdownConfig[j].Styles), fmt.Sprint(actualMarkdownConfig.Styles)) - } - if fmt.Sprint(actualMarkdownConfig.Scripts) != fmt.Sprint(test.expectedMarkdownConfig[j].Scripts) { - t.Errorf("Test %d expected %dth Markdown Config Scripts to be %s , but got %s", - i, j, fmt.Sprint(test.expectedMarkdownConfig[j].Scripts), fmt.Sprint(actualMarkdownConfig.Scripts)) - } - if ok, tx, ty := equalTemplates(actualMarkdownConfig.Template, test.expectedMarkdownConfig[j].Template); !ok { - t.Errorf("Test %d the %dth Markdown Config Templates did not match, expected %s to be %s", i, j, tx, ty) - } - } - } -} - -func equalTemplates(i, j *template.Template) (bool, string, string) { - // Just in case :) - if i == j { - return true, "", "" - } - - // We can't do much here, templates can't really be compared. However, - // we can execute the templates and compare their outputs to be reasonably - // sure that they're the same. - - // This is exceedingly ugly. - ctx := middleware.Context{ - Root: http.Dir("./testdata"), - } - - md := markdown.Data{ - Context: ctx, - Doc: make(map[string]string), - DocFlags: make(map[string]bool), - Styles: []string{"style1"}, - Scripts: []string{"js1"}, - } - md.Doc["title"] = "some title" - md.Doc["body"] = "some body" - - bufi := new(bytes.Buffer) - bufj := new(bytes.Buffer) - - if err := i.Execute(bufi, md); err != nil { - return false, fmt.Sprintf("%v", err), "" - } - if err := j.Execute(bufj, md); err != nil { - return false, "", fmt.Sprintf("%v", err) - } - - return bytes.Equal(bufi.Bytes(), bufj.Bytes()), string(bufi.Bytes()), string(bufj.Bytes()) -} diff --git a/caddy/setup/mime.go b/caddy/setup/mime.go deleted file mode 100644 index 59667dc36..000000000 --- a/caddy/setup/mime.go +++ /dev/null @@ -1,65 +0,0 @@ -package setup - -import ( - "fmt" - "strings" - - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/mime" -) - -// Mime configures a new mime middleware instance. -func Mime(c *Controller) (middleware.Middleware, error) { - configs, err := mimeParse(c) - if err != nil { - return nil, err - } - - return func(next middleware.Handler) middleware.Handler { - return mime.Mime{Next: next, Configs: configs} - }, nil -} - -func mimeParse(c *Controller) (mime.Config, error) { - configs := mime.Config{} - - for c.Next() { - // At least one extension is required - - args := c.RemainingArgs() - switch len(args) { - case 2: - if err := validateExt(configs, args[0]); err != nil { - return configs, err - } - configs[args[0]] = args[1] - case 1: - return configs, c.ArgErr() - case 0: - for c.NextBlock() { - ext := c.Val() - if err := validateExt(configs, ext); err != nil { - return configs, err - } - if !c.NextArg() { - return configs, c.ArgErr() - } - configs[ext] = c.Val() - } - } - - } - - return configs, nil -} - -// validateExt checks for valid file name extension. -func validateExt(configs mime.Config, ext string) error { - if !strings.HasPrefix(ext, ".") { - return fmt.Errorf(`mime: invalid extension "%v" (must start with dot)`, ext) - } - if _, ok := configs[ext]; ok { - return fmt.Errorf(`mime: duplicate extension "%v" found`, ext) - } - return nil -} diff --git a/caddy/setup/mime_test.go b/caddy/setup/mime_test.go deleted file mode 100644 index 7b11f3d57..000000000 --- a/caddy/setup/mime_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package setup - -import ( - "testing" - - "github.com/mholt/caddy/middleware/mime" -) - -func TestMime(t *testing.T) { - - c := NewTestController(`mime .txt text/plain`) - - mid, err := Mime(c) - if err != nil { - t.Errorf("Expected no errors, but got: %v", err) - } - if mid == nil { - t.Fatal("Expected middleware, was nil instead") - } - - handler := mid(EmptyNext) - myHandler, ok := handler.(mime.Mime) - if !ok { - t.Fatalf("Expected handler to be type Mime, got: %#v", handler) - } - - if !SameNext(myHandler.Next, EmptyNext) { - t.Error("'Next' field of handler was not set properly") - } - - tests := []struct { - input string - shouldErr bool - }{ - {`mime {`, true}, - {`mime {}`, true}, - {`mime a b`, true}, - {`mime a {`, true}, - {`mime { txt f } `, true}, - {`mime { html } `, true}, - {`mime { - .html text/html - .txt text/plain - } `, false}, - {`mime { - .foo text/foo - .bar text/bar - .foo text/foobar - } `, true}, - {`mime { .html text/html } `, false}, - {`mime { .html - } `, true}, - {`mime .txt text/plain`, false}, - } - for i, test := range tests { - c := NewTestController(test.input) - m, err := mimeParse(c) - if test.shouldErr && err == nil { - t.Errorf("Test %v: Expected error but found nil %v", i, m) - } else if !test.shouldErr && err != nil { - t.Errorf("Test %v: Expected no error but found error: %v", i, err) - } - } -} diff --git a/caddy/setup/pprof.go b/caddy/setup/pprof.go deleted file mode 100644 index 010485026..000000000 --- a/caddy/setup/pprof.go +++ /dev/null @@ -1,27 +0,0 @@ -package setup - -import ( - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/pprof" -) - -//PProf returns a new instance of a pprof handler. It accepts no arguments or options. -func PProf(c *Controller) (middleware.Middleware, error) { - found := false - for c.Next() { - if found { - return nil, c.Err("pprof can only be specified once") - } - if len(c.RemainingArgs()) != 0 { - return nil, c.ArgErr() - } - if c.NextBlock() { - return nil, c.ArgErr() - } - found = true - } - - return func(next middleware.Handler) middleware.Handler { - return &pprof.Handler{Next: next, Mux: pprof.NewMux()} - }, nil -} diff --git a/caddy/setup/pprof_test.go b/caddy/setup/pprof_test.go deleted file mode 100644 index ac9375af7..000000000 --- a/caddy/setup/pprof_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package setup - -import "testing" - -func TestPProf(t *testing.T) { - tests := []struct { - input string - shouldErr bool - }{ - {`pprof`, false}, - {`pprof {}`, true}, - {`pprof /foo`, true}, - {`pprof { - a b - }`, true}, - {`pprof - pprof`, true}, - } - for i, test := range tests { - c := NewTestController(test.input) - _, err := PProf(c) - if test.shouldErr && err == nil { - t.Errorf("Test %v: Expected error but found nil", i) - } else if !test.shouldErr && err != nil { - t.Errorf("Test %v: Expected no error but found error: %v", i, err) - } - } -} diff --git a/caddy/setup/proxy.go b/caddy/setup/proxy.go deleted file mode 100644 index 3011cb0e4..000000000 --- a/caddy/setup/proxy.go +++ /dev/null @@ -1,17 +0,0 @@ -package setup - -import ( - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/proxy" -) - -// Proxy configures a new Proxy middleware instance. -func Proxy(c *Controller) (middleware.Middleware, error) { - upstreams, err := proxy.NewStaticUpstreams(c.Dispenser) - if err != nil { - return nil, err - } - return func(next middleware.Handler) middleware.Handler { - return proxy.Proxy{Next: next, Upstreams: upstreams} - }, nil -} diff --git a/caddy/setup/proxy_test.go b/caddy/setup/proxy_test.go deleted file mode 100644 index 3d6d04a09..000000000 --- a/caddy/setup/proxy_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package setup - -import ( - "reflect" - "testing" - - "github.com/mholt/caddy/middleware/proxy" -) - -func TestUpstream(t *testing.T) { - for i, test := range []struct { - input string - shouldErr bool - expectedHosts map[string]struct{} - }{ - // test #0 test usual to destination still works normally - { - "proxy / localhost:80", - false, - map[string]struct{}{ - "http://localhost:80": {}, - }, - }, - - // test #1 test usual to destination with port range - { - "proxy / localhost:8080-8082", - false, - map[string]struct{}{ - "http://localhost:8080": {}, - "http://localhost:8081": {}, - "http://localhost:8082": {}, - }, - }, - - // test #2 test upstream directive - { - "proxy / {\n upstream localhost:8080\n}", - false, - map[string]struct{}{ - "http://localhost:8080": {}, - }, - }, - - // test #3 test upstream directive with port range - { - "proxy / {\n upstream localhost:8080-8081\n}", - false, - map[string]struct{}{ - "http://localhost:8080": {}, - "http://localhost:8081": {}, - }, - }, - - // test #4 test to destination with upstream directive - { - "proxy / localhost:8080 {\n upstream localhost:8081-8082\n}", - false, - map[string]struct{}{ - "http://localhost:8080": {}, - "http://localhost:8081": {}, - "http://localhost:8082": {}, - }, - }, - - // test #5 test with unix sockets - { - "proxy / localhost:8080 {\n upstream unix:/var/foo\n}", - false, - map[string]struct{}{ - "http://localhost:8080": {}, - "unix:/var/foo": {}, - }, - }, - - // test #6 test fail on malformed port range - { - "proxy / localhost:8090-8080", - true, - nil, - }, - - // test #7 test fail on malformed port range 2 - { - "proxy / {\n upstream localhost:80-A\n}", - true, - nil, - }, - - // test #8 test upstreams without ports work correctly - { - "proxy / http://localhost {\n upstream testendpoint\n}", - false, - map[string]struct{}{ - "http://localhost": {}, - "http://testendpoint": {}, - }, - }, - - // test #9 test several upstream directives - { - "proxy / localhost:8080 {\n upstream localhost:8081-8082\n upstream localhost:8083-8085\n}", - false, - map[string]struct{}{ - "http://localhost:8080": {}, - "http://localhost:8081": {}, - "http://localhost:8082": {}, - "http://localhost:8083": {}, - "http://localhost:8084": {}, - "http://localhost:8085": {}, - }, - }, - } { - receivedFunc, err := Proxy(NewTestController(test.input)) - if err != nil && !test.shouldErr { - t.Errorf("Test case #%d received an error of %v", i, err) - } else if test.shouldErr { - continue - } - - upstreams := receivedFunc(nil).(proxy.Proxy).Upstreams - for _, upstream := range upstreams { - val := reflect.ValueOf(upstream).Elem() - hosts := val.FieldByName("Hosts").Interface().(proxy.HostPool) - if len(hosts) != len(test.expectedHosts) { - t.Errorf("Test case #%d expected %d hosts but received %d", i, len(test.expectedHosts), len(hosts)) - } else { - for _, host := range hosts { - if _, found := test.expectedHosts[host.Name]; !found { - t.Errorf("Test case #%d has an unexpected host %s", i, host.Name) - } - } - } - } - } -} diff --git a/caddy/setup/redir.go b/caddy/setup/redir.go deleted file mode 100644 index 63488f4ab..000000000 --- a/caddy/setup/redir.go +++ /dev/null @@ -1,173 +0,0 @@ -package setup - -import ( - "net/http" - - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/redirect" -) - -// Redir configures a new Redirect middleware instance. -func Redir(c *Controller) (middleware.Middleware, error) { - rules, err := redirParse(c) - if err != nil { - return nil, err - } - - return func(next middleware.Handler) middleware.Handler { - return redirect.Redirect{Next: next, Rules: rules} - }, nil -} - -func redirParse(c *Controller) ([]redirect.Rule, error) { - var redirects []redirect.Rule - - // setRedirCode sets the redirect code for rule if it can, or returns an error - setRedirCode := func(code string, rule *redirect.Rule) error { - if code == "meta" { - rule.Meta = true - } else if codeNumber, ok := httpRedirs[code]; ok { - rule.Code = codeNumber - } else { - return c.Errf("Invalid redirect code '%v'", code) - } - return nil - } - - // checkAndSaveRule checks the rule for validity (except the redir code) - // and saves it if it's valid, or returns an error. - checkAndSaveRule := func(rule redirect.Rule) error { - if rule.FromPath == rule.To { - return c.Err("'from' and 'to' values of redirect rule cannot be the same") - } - - for _, otherRule := range redirects { - if otherRule.FromPath == rule.FromPath { - return c.Errf("rule with duplicate 'from' value: %s -> %s", otherRule.FromPath, otherRule.To) - } - } - - redirects = append(redirects, rule) - return nil - } - - for c.Next() { - args := c.RemainingArgs() - - var hadOptionalBlock bool - for c.NextBlock() { - hadOptionalBlock = true - - var rule redirect.Rule - - if c.Config.TLS.Enabled { - rule.FromScheme = "https" - } else { - rule.FromScheme = "http" - } - - // Set initial redirect code - // BUG: If the code is specified for a whole block and that code is invalid, - // the line number will appear on the first line inside the block, even if that - // line overwrites the block-level code with a valid redirect code. The program - // still functions correctly, but the line number in the error reporting is - // misleading to the user. - if len(args) == 1 { - err := setRedirCode(args[0], &rule) - if err != nil { - return redirects, err - } - } else { - rule.Code = http.StatusMovedPermanently // default code - } - - // RemainingArgs only gets the values after the current token, but in our - // case we want to include the current token to get an accurate count. - insideArgs := append([]string{c.Val()}, c.RemainingArgs()...) - - switch len(insideArgs) { - case 1: - // To specified (catch-all redirect) - // Not sure why user is doing this in a table, as it causes all other redirects to be ignored. - // As such, this feature remains undocumented. - rule.FromPath = "/" - rule.To = insideArgs[0] - case 2: - // From and To specified - rule.FromPath = insideArgs[0] - rule.To = insideArgs[1] - case 3: - // From, To, and Code specified - rule.FromPath = insideArgs[0] - rule.To = insideArgs[1] - err := setRedirCode(insideArgs[2], &rule) - if err != nil { - return redirects, err - } - default: - return redirects, c.ArgErr() - } - - err := checkAndSaveRule(rule) - if err != nil { - return redirects, err - } - } - - if !hadOptionalBlock { - var rule redirect.Rule - - if c.Config.TLS.Enabled { - rule.FromScheme = "https" - } else { - rule.FromScheme = "http" - } - - rule.Code = http.StatusMovedPermanently // default - - switch len(args) { - case 1: - // To specified (catch-all redirect) - rule.FromPath = "/" - rule.To = args[0] - case 2: - // To and Code specified (catch-all redirect) - rule.FromPath = "/" - rule.To = args[0] - err := setRedirCode(args[1], &rule) - if err != nil { - return redirects, err - } - case 3: - // From, To, and Code specified - rule.FromPath = args[0] - rule.To = args[1] - err := setRedirCode(args[2], &rule) - if err != nil { - return redirects, err - } - default: - return redirects, c.ArgErr() - } - - err := checkAndSaveRule(rule) - if err != nil { - return redirects, err - } - } - } - - return redirects, nil -} - -// httpRedirs is a list of supported HTTP redirect codes. -var httpRedirs = map[string]int{ - "300": http.StatusMultipleChoices, - "301": http.StatusMovedPermanently, - "302": http.StatusFound, // (NOT CORRECT for "Temporary Redirect", see 307) - "303": http.StatusSeeOther, - "304": http.StatusNotModified, - "305": http.StatusUseProxy, - "307": http.StatusTemporaryRedirect, - "308": 308, // Permanent Redirect -} diff --git a/caddy/setup/redir_test.go b/caddy/setup/redir_test.go deleted file mode 100644 index 0285784fa..000000000 --- a/caddy/setup/redir_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package setup - -import ( - "testing" - - "github.com/mholt/caddy/middleware/redirect" -) - -func TestRedir(t *testing.T) { - - for j, test := range []struct { - input string - shouldErr bool - expectedRules []redirect.Rule - }{ - // test case #0 tests the recognition of a valid HTTP status code defined outside of block statement - {"redir 300 {\n/ /foo\n}", false, []redirect.Rule{{FromPath: "/", To: "/foo", Code: 300}}}, - - // test case #1 tests the recognition of an invalid HTTP status code defined outside of block statement - {"redir 9000 {\n/ /foo\n}", true, []redirect.Rule{{}}}, - - // test case #2 tests the detection of a valid HTTP status code outside of a block statement being overriden by an invalid HTTP status code inside statement of a block statement - {"redir 300 {\n/ /foo 9000\n}", true, []redirect.Rule{{}}}, - - // test case #3 tests the detection of an invalid HTTP status code outside of a block statement being overriden by a valid HTTP status code inside statement of a block statement - {"redir 9000 {\n/ /foo 300\n}", true, []redirect.Rule{{}}}, - - // test case #4 tests the recognition of a TO redirection in a block statement.The HTTP status code is set to the default of 301 - MovedPermanently - {"redir 302 {\n/foo\n}", false, []redirect.Rule{{FromPath: "/", To: "/foo", Code: 302}}}, - - // test case #5 tests the recognition of a TO and From redirection in a block statement - {"redir {\n/bar /foo 303\n}", false, []redirect.Rule{{FromPath: "/bar", To: "/foo", Code: 303}}}, - - // test case #6 tests the recognition of a TO redirection in a non-block statement. The HTTP status code is set to the default of 301 - MovedPermanently - {"redir /foo", false, []redirect.Rule{{FromPath: "/", To: "/foo", Code: 301}}}, - - // test case #7 tests the recognition of a TO and From redirection in a non-block statement - {"redir /bar /foo 303", false, []redirect.Rule{{FromPath: "/bar", To: "/foo", Code: 303}}}, - - // test case #8 tests the recognition of multiple redirections - {"redir {\n / /foo 304 \n} \n redir {\n /bar /foobar 305 \n}", false, []redirect.Rule{{FromPath: "/", To: "/foo", Code: 304}, {FromPath: "/bar", To: "/foobar", Code: 305}}}, - - // test case #9 tests the detection of duplicate redirections - {"redir {\n /bar /foo 304 \n} redir {\n /bar /foo 304 \n}", true, []redirect.Rule{{}}}, - } { - recievedFunc, err := Redir(NewTestController(test.input)) - if err != nil && !test.shouldErr { - t.Errorf("Test case #%d recieved an error of %v", j, err) - } else if test.shouldErr { - continue - } - recievedRules := recievedFunc(nil).(redirect.Redirect).Rules - - for i, recievedRule := range recievedRules { - if recievedRule.FromPath != test.expectedRules[i].FromPath { - t.Errorf("Test case #%d.%d expected a from path of %s, but recieved a from path of %s", j, i, test.expectedRules[i].FromPath, recievedRule.FromPath) - } - if recievedRule.To != test.expectedRules[i].To { - t.Errorf("Test case #%d.%d expected a TO path of %s, but recieved a TO path of %s", j, i, test.expectedRules[i].To, recievedRule.To) - } - if recievedRule.Code != test.expectedRules[i].Code { - t.Errorf("Test case #%d.%d expected a HTTP status code of %d, but recieved a code of %d", j, i, test.expectedRules[i].Code, recievedRule.Code) - } - } - } - -} diff --git a/caddy/setup/rewrite.go b/caddy/setup/rewrite.go deleted file mode 100644 index b270c93dd..000000000 --- a/caddy/setup/rewrite.go +++ /dev/null @@ -1,109 +0,0 @@ -package setup - -import ( - "net/http" - "strconv" - "strings" - - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/rewrite" -) - -// Rewrite configures a new Rewrite middleware instance. -func Rewrite(c *Controller) (middleware.Middleware, error) { - rewrites, err := rewriteParse(c) - if err != nil { - return nil, err - } - - return func(next middleware.Handler) middleware.Handler { - return rewrite.Rewrite{ - Next: next, - FileSys: http.Dir(c.Root), - Rules: rewrites, - } - }, nil -} - -func rewriteParse(c *Controller) ([]rewrite.Rule, error) { - var simpleRules []rewrite.Rule - var regexpRules []rewrite.Rule - - for c.Next() { - var rule rewrite.Rule - var err error - var base = "/" - var pattern, to string - var status int - var ext []string - - args := c.RemainingArgs() - - var ifs []rewrite.If - - switch len(args) { - case 1: - base = args[0] - fallthrough - case 0: - for c.NextBlock() { - switch c.Val() { - case "r", "regexp": - if !c.NextArg() { - return nil, c.ArgErr() - } - pattern = c.Val() - case "to": - args1 := c.RemainingArgs() - if len(args1) == 0 { - return nil, c.ArgErr() - } - to = strings.Join(args1, " ") - case "ext": - args1 := c.RemainingArgs() - if len(args1) == 0 { - return nil, c.ArgErr() - } - ext = args1 - case "if": - args1 := c.RemainingArgs() - if len(args1) != 3 { - return nil, c.ArgErr() - } - ifCond, err := rewrite.NewIf(args1[0], args1[1], args1[2]) - if err != nil { - return nil, err - } - ifs = append(ifs, ifCond) - case "status": - if !c.NextArg() { - return nil, c.ArgErr() - } - status, _ = strconv.Atoi(c.Val()) - if status < 200 || (status > 299 && status < 400) || status > 499 { - return nil, c.Err("status must be 2xx or 4xx") - } - default: - return nil, c.ArgErr() - } - } - // ensure to or status is specified - if to == "" && status == 0 { - return nil, c.ArgErr() - } - if rule, err = rewrite.NewComplexRule(base, pattern, to, status, ext, ifs); err != nil { - return nil, err - } - regexpRules = append(regexpRules, rule) - - // the only unhandled case is 2 and above - default: - rule = rewrite.NewSimpleRule(args[0], strings.Join(args[1:], " ")) - simpleRules = append(simpleRules, rule) - } - - } - - // put simple rules in front to avoid regexp computation for them - return append(simpleRules, regexpRules...), nil -} diff --git a/caddy/setup/rewrite_test.go b/caddy/setup/rewrite_test.go deleted file mode 100644 index d252ed904..000000000 --- a/caddy/setup/rewrite_test.go +++ /dev/null @@ -1,241 +0,0 @@ -package setup - -import ( - "fmt" - "regexp" - "testing" - - "github.com/mholt/caddy/middleware/rewrite" -) - -func TestRewrite(t *testing.T) { - c := NewTestController(`rewrite /from /to`) - - mid, err := Rewrite(c) - if err != nil { - t.Errorf("Expected no errors, but got: %v", err) - } - if mid == nil { - t.Fatal("Expected middleware, was nil instead") - } - - handler := mid(EmptyNext) - myHandler, ok := handler.(rewrite.Rewrite) - if !ok { - t.Fatalf("Expected handler to be type Rewrite, got: %#v", handler) - } - - if !SameNext(myHandler.Next, EmptyNext) { - t.Error("'Next' field of handler was not set properly") - } - - if len(myHandler.Rules) != 1 { - t.Errorf("Expected handler to have %d rule, has %d instead", 1, len(myHandler.Rules)) - } -} - -func TestRewriteParse(t *testing.T) { - simpleTests := []struct { - input string - shouldErr bool - expected []rewrite.Rule - }{ - {`rewrite /from /to`, false, []rewrite.Rule{ - rewrite.SimpleRule{From: "/from", To: "/to"}, - }}, - {`rewrite /from /to - rewrite a b`, false, []rewrite.Rule{ - rewrite.SimpleRule{From: "/from", To: "/to"}, - rewrite.SimpleRule{From: "a", To: "b"}, - }}, - {`rewrite a`, true, []rewrite.Rule{}}, - {`rewrite`, true, []rewrite.Rule{}}, - {`rewrite a b c`, false, []rewrite.Rule{ - rewrite.SimpleRule{From: "a", To: "b c"}, - }}, - } - - for i, test := range simpleTests { - c := NewTestController(test.input) - actual, err := rewriteParse(c) - - if err == nil && test.shouldErr { - t.Errorf("Test %d didn't error, but it should have", i) - } else if err != nil && !test.shouldErr { - t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) - } else if err != nil && test.shouldErr { - continue - } - - if len(actual) != len(test.expected) { - t.Fatalf("Test %d expected %d rules, but got %d", - i, len(test.expected), len(actual)) - } - - for j, e := range test.expected { - actualRule := actual[j].(rewrite.SimpleRule) - expectedRule := e.(rewrite.SimpleRule) - - if actualRule.From != expectedRule.From { - t.Errorf("Test %d, rule %d: Expected From=%s, got %s", - i, j, expectedRule.From, actualRule.From) - } - - if actualRule.To != expectedRule.To { - t.Errorf("Test %d, rule %d: Expected To=%s, got %s", - i, j, expectedRule.To, actualRule.To) - } - } - } - - regexpTests := []struct { - input string - shouldErr bool - expected []rewrite.Rule - }{ - {`rewrite { - r .* - to /to /index.php? - }`, false, []rewrite.Rule{ - &rewrite.ComplexRule{Base: "/", To: "/to /index.php?", Regexp: regexp.MustCompile(".*")}, - }}, - {`rewrite { - regexp .* - to /to - ext / html txt - }`, false, []rewrite.Rule{ - &rewrite.ComplexRule{Base: "/", To: "/to", Exts: []string{"/", "html", "txt"}, Regexp: regexp.MustCompile(".*")}, - }}, - {`rewrite /path { - r rr - to /dest - } - rewrite / { - regexp [a-z]+ - to /to /to2 - } - `, false, []rewrite.Rule{ - &rewrite.ComplexRule{Base: "/path", To: "/dest", Regexp: regexp.MustCompile("rr")}, - &rewrite.ComplexRule{Base: "/", To: "/to /to2", Regexp: regexp.MustCompile("[a-z]+")}, - }}, - {`rewrite { - r .* - }`, true, []rewrite.Rule{ - &rewrite.ComplexRule{}, - }}, - {`rewrite { - - }`, true, []rewrite.Rule{ - &rewrite.ComplexRule{}, - }}, - {`rewrite /`, true, []rewrite.Rule{ - &rewrite.ComplexRule{}, - }}, - {`rewrite { - to /to - if {path} is a - }`, false, []rewrite.Rule{ - &rewrite.ComplexRule{Base: "/", To: "/to", Ifs: []rewrite.If{{A: "{path}", Operator: "is", B: "a"}}}, - }}, - {`rewrite { - status 500 - }`, true, []rewrite.Rule{ - &rewrite.ComplexRule{}, - }}, - {`rewrite { - status 400 - }`, false, []rewrite.Rule{ - &rewrite.ComplexRule{Base: "/", Status: 400}, - }}, - {`rewrite { - to /to - status 400 - }`, false, []rewrite.Rule{ - &rewrite.ComplexRule{Base: "/", To: "/to", Status: 400}, - }}, - {`rewrite { - status 399 - }`, true, []rewrite.Rule{ - &rewrite.ComplexRule{}, - }}, - {`rewrite { - status 200 - }`, false, []rewrite.Rule{ - &rewrite.ComplexRule{Base: "/", Status: 200}, - }}, - {`rewrite { - to /to - status 200 - }`, false, []rewrite.Rule{ - &rewrite.ComplexRule{Base: "/", To: "/to", Status: 200}, - }}, - {`rewrite { - status 199 - }`, true, []rewrite.Rule{ - &rewrite.ComplexRule{}, - }}, - {`rewrite { - status 0 - }`, true, []rewrite.Rule{ - &rewrite.ComplexRule{}, - }}, - {`rewrite { - to /to - status 0 - }`, true, []rewrite.Rule{ - &rewrite.ComplexRule{}, - }}, - } - - for i, test := range regexpTests { - c := NewTestController(test.input) - actual, err := rewriteParse(c) - - if err == nil && test.shouldErr { - t.Errorf("Test %d didn't error, but it should have", i) - } else if err != nil && !test.shouldErr { - t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) - } else if err != nil && test.shouldErr { - continue - } - - if len(actual) != len(test.expected) { - t.Fatalf("Test %d expected %d rules, but got %d", - i, len(test.expected), len(actual)) - } - - for j, e := range test.expected { - actualRule := actual[j].(*rewrite.ComplexRule) - expectedRule := e.(*rewrite.ComplexRule) - - if actualRule.Base != expectedRule.Base { - t.Errorf("Test %d, rule %d: Expected Base=%s, got %s", - i, j, expectedRule.Base, actualRule.Base) - } - - if actualRule.To != expectedRule.To { - t.Errorf("Test %d, rule %d: Expected To=%s, got %s", - i, j, expectedRule.To, actualRule.To) - } - - if fmt.Sprint(actualRule.Exts) != fmt.Sprint(expectedRule.Exts) { - t.Errorf("Test %d, rule %d: Expected Ext=%v, got %v", - i, j, expectedRule.To, actualRule.To) - } - - if actualRule.Regexp != nil { - if actualRule.String() != expectedRule.String() { - t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s", - i, j, expectedRule.String(), actualRule.String()) - } - } - - if fmt.Sprint(actualRule.Ifs) != fmt.Sprint(expectedRule.Ifs) { - t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s", - i, j, fmt.Sprint(expectedRule.Ifs), fmt.Sprint(actualRule.Ifs)) - } - - } - } - -} diff --git a/caddy/setup/root.go b/caddy/setup/root.go deleted file mode 100644 index 5100f6961..000000000 --- a/caddy/setup/root.go +++ /dev/null @@ -1,32 +0,0 @@ -package setup - -import ( - "log" - "os" - - "github.com/mholt/caddy/middleware" -) - -// Root sets up the root file path of the server. -func Root(c *Controller) (middleware.Middleware, error) { - for c.Next() { - if !c.NextArg() { - return nil, c.ArgErr() - } - c.Root = c.Val() - } - - // Check if root path exists - _, err := os.Stat(c.Root) - if err != nil { - if os.IsNotExist(err) { - // Allow this, because the folder might appear later. - // But make sure the user knows! - log.Printf("[WARNING] Root path does not exist: %s", c.Root) - } else { - return nil, c.Errf("Unable to access root path '%s': %v", c.Root, err) - } - } - - return nil, nil -} diff --git a/caddy/setup/templates.go b/caddy/setup/templates.go deleted file mode 100644 index f8d7e98bd..000000000 --- a/caddy/setup/templates.go +++ /dev/null @@ -1,90 +0,0 @@ -package setup - -import ( - "net/http" - - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/templates" -) - -// Templates configures a new Templates middleware instance. -func Templates(c *Controller) (middleware.Middleware, error) { - rules, err := templatesParse(c) - if err != nil { - return nil, err - } - - tmpls := templates.Templates{ - Rules: rules, - Root: c.Root, - FileSys: http.Dir(c.Root), - } - - return func(next middleware.Handler) middleware.Handler { - tmpls.Next = next - return tmpls - }, nil -} - -func templatesParse(c *Controller) ([]templates.Rule, error) { - var rules []templates.Rule - - for c.Next() { - var rule templates.Rule - - rule.Path = defaultTemplatePath - rule.Extensions = defaultTemplateExtensions - - args := c.RemainingArgs() - - switch len(args) { - case 0: - // Optional block - for c.NextBlock() { - switch c.Val() { - case "path": - args := c.RemainingArgs() - if len(args) != 1 { - return nil, c.ArgErr() - } - rule.Path = args[0] - - case "ext": - args := c.RemainingArgs() - if len(args) == 0 { - return nil, c.ArgErr() - } - rule.Extensions = args - - case "between": - args := c.RemainingArgs() - if len(args) != 2 { - return nil, c.ArgErr() - } - rule.Delims[0] = args[0] - rule.Delims[1] = args[1] - } - } - default: - // First argument would be the path - rule.Path = args[0] - - // Any remaining arguments are extensions - rule.Extensions = args[1:] - if len(rule.Extensions) == 0 { - rule.Extensions = defaultTemplateExtensions - } - } - - for _, ext := range rule.Extensions { - rule.IndexFiles = append(rule.IndexFiles, "index"+ext) - } - - rules = append(rules, rule) - } - return rules, nil -} - -const defaultTemplatePath = "/" - -var defaultTemplateExtensions = []string{".html", ".htm", ".tmpl", ".tpl", ".txt"} diff --git a/caddy/setup/templates_test.go b/caddy/setup/templates_test.go deleted file mode 100644 index b1cfb29ce..000000000 --- a/caddy/setup/templates_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package setup - -import ( - "fmt" - "testing" - - "github.com/mholt/caddy/middleware/templates" -) - -func TestTemplates(t *testing.T) { - - c := NewTestController(`templates`) - - mid, err := Templates(c) - - if err != nil { - t.Errorf("Expected no errors, got: %v", err) - } - - if mid == nil { - t.Fatal("Expected middleware, was nil instead") - } - - handler := mid(EmptyNext) - myHandler, ok := handler.(templates.Templates) - - if !ok { - t.Fatalf("Expected handler to be type Templates, got: %#v", handler) - } - - if myHandler.Rules[0].Path != defaultTemplatePath { - t.Errorf("Expected / as the default Path") - } - if fmt.Sprint(myHandler.Rules[0].Extensions) != fmt.Sprint(defaultTemplateExtensions) { - t.Errorf("Expected %v to be the Default Extensions", defaultTemplateExtensions) - } - var indexFiles []string - for _, extension := range defaultTemplateExtensions { - indexFiles = append(indexFiles, "index"+extension) - } - if fmt.Sprint(myHandler.Rules[0].IndexFiles) != fmt.Sprint(indexFiles) { - t.Errorf("Expected %v to be the Default Index files", indexFiles) - } - if myHandler.Rules[0].Delims != [2]string{} { - t.Errorf("Expected %v to be the Default Delims", [2]string{}) - } -} - -func TestTemplatesParse(t *testing.T) { - tests := []struct { - inputTemplateConfig string - shouldErr bool - expectedTemplateConfig []templates.Rule - }{ - {`templates /api1`, false, []templates.Rule{{ - Path: "/api1", - Extensions: defaultTemplateExtensions, - Delims: [2]string{}, - }}}, - {`templates /api2 .txt .htm`, false, []templates.Rule{{ - Path: "/api2", - Extensions: []string{".txt", ".htm"}, - Delims: [2]string{}, - }}}, - - {`templates /api3 .htm .html - templates /api4 .txt .tpl `, false, []templates.Rule{{ - Path: "/api3", - Extensions: []string{".htm", ".html"}, - Delims: [2]string{}, - }, { - Path: "/api4", - Extensions: []string{".txt", ".tpl"}, - Delims: [2]string{}, - }}}, - {`templates { - path /api5 - ext .html - between {% %} - }`, false, []templates.Rule{{ - Path: "/api5", - Extensions: []string{".html"}, - Delims: [2]string{"{%", "%}"}, - }}}, - } - for i, test := range tests { - c := NewTestController(test.inputTemplateConfig) - actualTemplateConfigs, err := templatesParse(c) - - if err == nil && test.shouldErr { - t.Errorf("Test %d didn't error, but it should have", i) - } else if err != nil && !test.shouldErr { - t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) - } - if len(actualTemplateConfigs) != len(test.expectedTemplateConfig) { - t.Fatalf("Test %d expected %d no of Template configs, but got %d ", - i, len(test.expectedTemplateConfig), len(actualTemplateConfigs)) - } - for j, actualTemplateConfig := range actualTemplateConfigs { - - if actualTemplateConfig.Path != test.expectedTemplateConfig[j].Path { - t.Errorf("Test %d expected %dth Template Config Path to be %s , but got %s", - i, j, test.expectedTemplateConfig[j].Path, actualTemplateConfig.Path) - } - - if fmt.Sprint(actualTemplateConfig.Extensions) != fmt.Sprint(test.expectedTemplateConfig[j].Extensions) { - t.Errorf("Expected %v to be the Extensions , but got %v instead", test.expectedTemplateConfig[j].Extensions, actualTemplateConfig.Extensions) - } - } - } - -} diff --git a/caddy/setup/testdata/blog/first_post.md b/caddy/setup/testdata/blog/first_post.md deleted file mode 100644 index f26583b75..000000000 --- a/caddy/setup/testdata/blog/first_post.md +++ /dev/null @@ -1 +0,0 @@ -# Test h1 diff --git a/caddy/setup/testdata/header.html b/caddy/setup/testdata/header.html deleted file mode 100644 index 9c96e0e37..000000000 --- a/caddy/setup/testdata/header.html +++ /dev/null @@ -1 +0,0 @@ -

Header title

diff --git a/caddy/setup/testdata/tpl_with_include.html b/caddy/setup/testdata/tpl_with_include.html deleted file mode 100644 index 95eeae0c8..000000000 --- a/caddy/setup/testdata/tpl_with_include.html +++ /dev/null @@ -1,10 +0,0 @@ - - - -{{.Doc.title}} - - -{{.Include "header.html"}} -{{.Doc.body}} - - diff --git a/caddy/setup/websocket.go b/caddy/setup/websocket.go deleted file mode 100644 index 17617c406..000000000 --- a/caddy/setup/websocket.go +++ /dev/null @@ -1,87 +0,0 @@ -package setup - -import ( - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/websocket" -) - -// WebSocket configures a new WebSocket middleware instance. -func WebSocket(c *Controller) (middleware.Middleware, error) { - - websocks, err := webSocketParse(c) - if err != nil { - return nil, err - } - websocket.GatewayInterface = c.AppName + "-CGI/1.1" - websocket.ServerSoftware = c.AppName + "/" + c.AppVersion - - return func(next middleware.Handler) middleware.Handler { - return websocket.WebSocket{Next: next, Sockets: websocks} - }, nil -} - -func webSocketParse(c *Controller) ([]websocket.Config, error) { - var websocks []websocket.Config - var respawn bool - - optionalBlock := func() (hadBlock bool, err error) { - for c.NextBlock() { - hadBlock = true - if c.Val() == "respawn" { - respawn = true - } else { - return true, c.Err("Expected websocket configuration parameter in block") - } - } - return - } - - for c.Next() { - var val, path, command string - - // Path or command; not sure which yet - if !c.NextArg() { - return nil, c.ArgErr() - } - val = c.Val() - - // Extra configuration may be in a block - hadBlock, err := optionalBlock() - if err != nil { - return nil, err - } - - if !hadBlock { - // The next argument on this line will be the command or an open curly brace - if c.NextArg() { - path = val - command = c.Val() - } else { - path = "/" - command = val - } - - // Okay, check again for optional block - _, err = optionalBlock() - if err != nil { - return nil, err - } - } - - // Split command into the actual command and its arguments - cmd, args, err := middleware.SplitCommandAndArgs(command) - if err != nil { - return nil, err - } - - websocks = append(websocks, websocket.Config{ - Path: path, - Command: cmd, - Arguments: args, - Respawn: respawn, // TODO: This isn't used currently - }) - } - - return websocks, nil - -} diff --git a/caddy/setup/websocket_test.go b/caddy/setup/websocket_test.go deleted file mode 100644 index ae3513602..000000000 --- a/caddy/setup/websocket_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package setup - -import ( - "testing" - - "github.com/mholt/caddy/middleware/websocket" -) - -func TestWebSocket(t *testing.T) { - - c := NewTestController(`websocket cat`) - - mid, err := WebSocket(c) - - if err != nil { - t.Errorf("Expected no errors, got: %v", err) - } - - if mid == nil { - t.Fatal("Expected middleware, was nil instead") - } - - handler := mid(EmptyNext) - myHandler, ok := handler.(websocket.WebSocket) - - if !ok { - t.Fatalf("Expected handler to be type WebSocket, got: %#v", handler) - } - - if myHandler.Sockets[0].Path != "/" { - t.Errorf("Expected / as the default Path") - } - if myHandler.Sockets[0].Command != "cat" { - t.Errorf("Expected %s as the command", "cat") - } - -} -func TestWebSocketParse(t *testing.T) { - tests := []struct { - inputWebSocketConfig string - shouldErr bool - expectedWebSocketConfig []websocket.Config - }{ - {`websocket /api1 cat`, false, []websocket.Config{{ - Path: "/api1", - Command: "cat", - }}}, - - {`websocket /api3 cat - websocket /api4 cat `, false, []websocket.Config{{ - Path: "/api3", - Command: "cat", - }, { - Path: "/api4", - Command: "cat", - }}}, - - {`websocket /api5 "cmd arg1 arg2 arg3"`, false, []websocket.Config{{ - Path: "/api5", - Command: "cmd", - Arguments: []string{"arg1", "arg2", "arg3"}, - }}}, - - // accept respawn - {`websocket /api6 cat { - respawn - }`, false, []websocket.Config{{ - Path: "/api6", - Command: "cat", - }}}, - - // invalid configuration - {`websocket /api7 cat { - invalid - }`, true, []websocket.Config{}}, - } - for i, test := range tests { - c := NewTestController(test.inputWebSocketConfig) - actualWebSocketConfigs, err := webSocketParse(c) - - if err == nil && test.shouldErr { - t.Errorf("Test %d didn't error, but it should have", i) - } else if err != nil && !test.shouldErr { - t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) - } - if len(actualWebSocketConfigs) != len(test.expectedWebSocketConfig) { - t.Fatalf("Test %d expected %d no of WebSocket configs, but got %d ", - i, len(test.expectedWebSocketConfig), len(actualWebSocketConfigs)) - } - for j, actualWebSocketConfig := range actualWebSocketConfigs { - - if actualWebSocketConfig.Path != test.expectedWebSocketConfig[j].Path { - t.Errorf("Test %d expected %dth WebSocket Config Path to be %s , but got %s", - i, j, test.expectedWebSocketConfig[j].Path, actualWebSocketConfig.Path) - } - - if actualWebSocketConfig.Command != test.expectedWebSocketConfig[j].Command { - t.Errorf("Test %d expected %dth WebSocket Config Command to be %s , but got %s", - i, j, test.expectedWebSocketConfig[j].Command, actualWebSocketConfig.Command) - } - - } - } - -} diff --git a/caddy/parse/dispenser.go b/caddyfile/dispenser.go similarity index 99% rename from caddy/parse/dispenser.go rename to caddyfile/dispenser.go index 08aa6e76d..7fa169d45 100644 --- a/caddy/parse/dispenser.go +++ b/caddyfile/dispenser.go @@ -1,4 +1,4 @@ -package parse +package caddyfile import ( "errors" diff --git a/caddy/parse/dispenser_test.go b/caddyfile/dispenser_test.go similarity index 99% rename from caddy/parse/dispenser_test.go rename to caddyfile/dispenser_test.go index 20a7ddcac..313e273b0 100644 --- a/caddy/parse/dispenser_test.go +++ b/caddyfile/dispenser_test.go @@ -1,4 +1,4 @@ -package parse +package caddyfile import ( "reflect" diff --git a/caddy/caddyfile/json.go b/caddyfile/json.go similarity index 80% rename from caddy/caddyfile/json.go rename to caddyfile/json.go index e1213c27d..52c7b90fd 100644 --- a/caddy/caddyfile/json.go +++ b/caddyfile/json.go @@ -4,31 +4,26 @@ import ( "bytes" "encoding/json" "fmt" - "net" "sort" "strconv" "strings" - - "github.com/mholt/caddy/caddy/parse" ) const filename = "Caddyfile" // ToJSON converts caddyfile to its JSON representation. func ToJSON(caddyfile []byte) ([]byte, error) { - var j Caddyfile + var j EncodedCaddyfile - serverBlocks, err := parse.ServerBlocks(filename, bytes.NewReader(caddyfile), false) + serverBlocks, err := ServerBlocks(filename, bytes.NewReader(caddyfile), nil) if err != nil { return nil, err } for _, sb := range serverBlocks { - block := ServerBlock{Body: [][]interface{}{}} - - // Fill up host list - for _, host := range sb.HostList() { - block.Hosts = append(block.Hosts, standardizeScheme(host)) + block := EncodedServerBlock{ + Keys: sb.Keys, + Body: [][]interface{}{}, } // Extract directives deterministically by sorting them @@ -40,7 +35,7 @@ func ToJSON(caddyfile []byte) ([]byte, error) { // Convert each directive's tokens into our JSON structure for _, dir := range directives { - disp := parse.NewDispenserTokens(filename, sb.Tokens[dir]) + disp := NewDispenserTokens(filename, sb.Tokens[dir]) for disp.Next() { block.Body = append(block.Body, constructLine(&disp)) } @@ -62,7 +57,7 @@ func ToJSON(caddyfile []byte) ([]byte, error) { // but only one line at a time, to be used at the top-level of // a server block only (where the first token on each line is a // directive) - not to be used at any other nesting level. -func constructLine(d *parse.Dispenser) []interface{} { +func constructLine(d *Dispenser) []interface{} { var args []interface{} args = append(args, d.Val()) @@ -81,7 +76,7 @@ func constructLine(d *parse.Dispenser) []interface{} { // constructBlock recursively processes tokens into a // JSON-encodable structure. To be used in a directive's // block. Goes to end of block. -func constructBlock(d *parse.Dispenser) [][]interface{} { +func constructBlock(d *Dispenser) [][]interface{} { block := [][]interface{}{} for d.Next() { @@ -96,7 +91,7 @@ func constructBlock(d *parse.Dispenser) [][]interface{} { // FromJSON converts JSON-encoded jsonBytes to Caddyfile text func FromJSON(jsonBytes []byte) ([]byte, error) { - var j Caddyfile + var j EncodedCaddyfile var result string err := json.Unmarshal(jsonBytes, &j) @@ -108,11 +103,12 @@ func FromJSON(jsonBytes []byte) ([]byte, error) { if sbPos > 0 { result += "\n\n" } - for i, host := range sb.Hosts { + for i, key := range sb.Keys { if i > 0 { result += ", " } - result += standardizeScheme(host) + //result += standardizeScheme(key) + result += key } result += jsonToText(sb.Body, 1) } @@ -164,6 +160,8 @@ func jsonToText(scope interface{}, depth int) string { return result } +// TODO: Will this function come in handy somewhere else? +/* // standardizeScheme turns an address like host:https into https://host, // or "host:" into "host". func standardizeScheme(addr string) string { @@ -174,12 +172,13 @@ func standardizeScheme(addr string) string { } return strings.TrimSuffix(addr, ":") } +*/ -// Caddyfile encapsulates a slice of ServerBlocks. -type Caddyfile []ServerBlock +// EncodedCaddyfile encapsulates a slice of EncodedServerBlocks. +type EncodedCaddyfile []EncodedServerBlock -// ServerBlock represents a server block. -type ServerBlock struct { - Hosts []string `json:"hosts"` - Body [][]interface{} `json:"body"` +// EncodedServerBlock represents a server block ripe for encoding. +type EncodedServerBlock struct { + Keys []string `json:"keys"` + Body [][]interface{} `json:"body"` } diff --git a/caddy/caddyfile/json_test.go b/caddyfile/json_test.go similarity index 60% rename from caddy/caddyfile/json_test.go rename to caddyfile/json_test.go index 2e44ae2a2..97d553c33 100644 --- a/caddy/caddyfile/json_test.go +++ b/caddyfile/json_test.go @@ -9,7 +9,7 @@ var tests = []struct { caddyfile: `foo { root /bar }`, - json: `[{"hosts":["foo"],"body":[["root","/bar"]]}]`, + json: `[{"keys":["foo"],"body":[["root","/bar"]]}]`, }, { // 1 caddyfile: `host1, host2 { @@ -17,7 +17,7 @@ var tests = []struct { def } }`, - json: `[{"hosts":["host1","host2"],"body":[["dir",[["def"]]]]}]`, + json: `[{"keys":["host1","host2"],"body":[["dir",[["def"]]]]}]`, }, { // 2 caddyfile: `host1, host2 { @@ -26,58 +26,58 @@ var tests = []struct { jkl } }`, - json: `[{"hosts":["host1","host2"],"body":[["dir","abc",[["def","ghi"],["jkl"]]]]}]`, + json: `[{"keys":["host1","host2"],"body":[["dir","abc",[["def","ghi"],["jkl"]]]]}]`, }, { // 3 caddyfile: `host1:1234, host2:5678 { dir abc { } }`, - json: `[{"hosts":["host1:1234","host2:5678"],"body":[["dir","abc",[]]]}]`, + json: `[{"keys":["host1:1234","host2:5678"],"body":[["dir","abc",[]]]}]`, }, { // 4 caddyfile: `host { foo "bar baz" }`, - json: `[{"hosts":["host"],"body":[["foo","bar baz"]]}]`, + json: `[{"keys":["host"],"body":[["foo","bar baz"]]}]`, }, { // 5 caddyfile: `host, host:80 { foo "bar \"baz\"" }`, - json: `[{"hosts":["host","host:80"],"body":[["foo","bar \"baz\""]]}]`, + json: `[{"keys":["host","host:80"],"body":[["foo","bar \"baz\""]]}]`, }, { // 6 caddyfile: `host { foo "bar baz" }`, - json: `[{"hosts":["host"],"body":[["foo","bar\nbaz"]]}]`, + json: `[{"keys":["host"],"body":[["foo","bar\nbaz"]]}]`, }, { // 7 caddyfile: `host { dir 123 4.56 true }`, - json: `[{"hosts":["host"],"body":[["dir","123","4.56","true"]]}]`, // NOTE: I guess we assume numbers and booleans should be encoded as strings...? + json: `[{"keys":["host"],"body":[["dir","123","4.56","true"]]}]`, // NOTE: I guess we assume numbers and booleans should be encoded as strings...? }, { // 8 caddyfile: `http://host, https://host { }`, - json: `[{"hosts":["http://host","https://host"],"body":[]}]`, // hosts in JSON are always host:port format (if port is specified), for consistency + json: `[{"keys":["http://host","https://host"],"body":[]}]`, // hosts in JSON are always host:port format (if port is specified), for consistency }, { // 9 caddyfile: `host { dir1 a b dir2 c d }`, - json: `[{"hosts":["host"],"body":[["dir1","a","b"],["dir2","c","d"]]}]`, + json: `[{"keys":["host"],"body":[["dir1","a","b"],["dir2","c","d"]]}]`, }, { // 10 caddyfile: `host { dir a b dir c d }`, - json: `[{"hosts":["host"],"body":[["dir","a","b"],["dir","c","d"]]}]`, + json: `[{"keys":["host"],"body":[["dir","a","b"],["dir","c","d"]]}]`, }, { // 11 caddyfile: `host { @@ -87,7 +87,7 @@ baz" d } }`, - json: `[{"hosts":["host"],"body":[["dir1","a","b"],["dir2",[["c"],["d"]]]]}]`, + json: `[{"keys":["host"],"body":[["dir1","a","b"],["dir2",[["c"],["d"]]]]}]`, }, { // 12 caddyfile: `host1 { @@ -97,7 +97,7 @@ baz" host2 { dir2 }`, - json: `[{"hosts":["host1"],"body":[["dir1"]]},{"hosts":["host2"],"body":[["dir2"]]}]`, + json: `[{"keys":["host1"],"body":[["dir1"]]},{"keys":["host2"],"body":[["dir2"]]}]`, }, } @@ -125,17 +125,19 @@ func TestFromJSON(t *testing.T) { } } +// TODO: Will these tests come in handy somewhere else? +/* func TestStandardizeAddress(t *testing.T) { // host:https should be converted to https://host output, err := ToJSON([]byte(`host:https`)) if err != nil { t.Fatal(err) } - if expected, actual := `[{"hosts":["https://host"],"body":[]}]`, string(output); expected != actual { + if expected, actual := `[{"keys":["https://host"],"body":[]}]`, string(output); expected != actual { t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual) } - output, err = FromJSON([]byte(`[{"hosts":["https://host"],"body":[]}]`)) + output, err = FromJSON([]byte(`[{"keys":["https://host"],"body":[]}]`)) if err != nil { t.Fatal(err) } @@ -148,10 +150,10 @@ func TestStandardizeAddress(t *testing.T) { if err != nil { t.Fatal(err) } - if expected, actual := `[{"hosts":["host"],"body":[]}]`, string(output); expected != actual { + if expected, actual := `[{"keys":["host"],"body":[]}]`, string(output); expected != actual { t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual) } - output, err = FromJSON([]byte(`[{"hosts":["host:"],"body":[]}]`)) + output, err = FromJSON([]byte(`[{"keys":["host:"],"body":[]}]`)) if err != nil { t.Fatal(err) } @@ -159,3 +161,4 @@ func TestStandardizeAddress(t *testing.T) { t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual) } } +*/ diff --git a/caddy/parse/lexer.go b/caddyfile/lexer.go similarity index 99% rename from caddy/parse/lexer.go rename to caddyfile/lexer.go index d2939eba2..537c263aa 100644 --- a/caddy/parse/lexer.go +++ b/caddyfile/lexer.go @@ -1,4 +1,4 @@ -package parse +package caddyfile import ( "bufio" diff --git a/caddy/parse/lexer_test.go b/caddyfile/lexer_test.go similarity index 99% rename from caddy/parse/lexer_test.go rename to caddyfile/lexer_test.go index f12c7e7dc..0e30974d6 100644 --- a/caddy/parse/lexer_test.go +++ b/caddyfile/lexer_test.go @@ -1,4 +1,4 @@ -package parse +package caddyfile import ( "strings" diff --git a/caddy/parse/parsing.go b/caddyfile/parse.go similarity index 74% rename from caddy/parse/parsing.go rename to caddyfile/parse.go index 3d4a383cd..cce4cfbf0 100644 --- a/caddy/parse/parsing.go +++ b/caddyfile/parse.go @@ -1,18 +1,41 @@ -package parse +package caddyfile import ( - "fmt" - "net" + "io" "os" "path/filepath" "strings" ) +// ServerBlocks parses the input just enough to group tokens, +// in order, by server block. No further parsing is performed. +// Server blocks are returned in the order in which they appear. +// Directives that do not appear in validDirectives will cause +// an error. If you do not want to check for valid directives, +// pass in nil instead. +func ServerBlocks(filename string, input io.Reader, validDirectives []string) ([]ServerBlock, error) { + p := parser{Dispenser: NewDispenser(filename, input), validDirectives: validDirectives} + blocks, err := p.parseAll() + return blocks, err +} + +// allTokens lexes the entire input, but does not parse it. +// It returns all the tokens from the input, unstructured +// and in order. +func allTokens(input io.Reader) (tokens []token) { + l := new(lexer) + l.load(input) + for l.next() { + tokens = append(tokens, l.token) + } + return +} + type parser struct { Dispenser block ServerBlock // current server block being parsed + validDirectives []string // a directive must be valid or it's an error eof bool // if we encounter a valid EOF in a hard place - checkDirectives bool // if true, directives must be known } func (p *parser) parseAll() ([]ServerBlock, error) { @@ -23,7 +46,7 @@ func (p *parser) parseAll() ([]ServerBlock, error) { if err != nil { return blocks, err } - if len(p.block.Addresses) > 0 { + if len(p.block.Keys) > 0 { blocks = append(blocks, p.block) } } @@ -89,7 +112,7 @@ func (p *parser) addresses() error { break } - if tkn != "" { // empty token possible if user typed "" in Caddyfile + if tkn != "" { // empty token possible if user typed "" // Trailing comma indicates another address will follow, which // may possibly be on the next line if tkn[len(tkn)-1] == ',' { @@ -99,13 +122,7 @@ func (p *parser) addresses() error { expectingAnother = false // but we may still see another one on this line } - // Parse and save this address - addr, err := standardAddress(tkn) - if err != nil { - return err - } - - p.block.Addresses = append(p.block.Addresses, addr) + p.block.Keys = append(p.block.Keys, tkn) } // Advance token and possibly break out of loop or return error @@ -253,10 +270,9 @@ func (p *parser) directive() error { dir := p.Val() nesting := 0 - if p.checkDirectives { - if _, ok := ValidDirectives[dir]; !ok { - return p.Errf("Unknown directive '%s'", dir) - } + // TODO: More helpful error message ("did you mean..." or "maybe you need to install its server type") + if !p.validDirective(dir) { + return p.Errf("Unknown directive '%s'", dir) } // The directive itself is appended as a relevant token @@ -305,63 +321,17 @@ func (p *parser) closeCurlyBrace() error { return nil } -// standardAddress parses an address string into a structured format with separate -// scheme, host, and port portions, as well as the original input string. -func standardAddress(str string) (address, error) { - var scheme string - var err error - - // first check for scheme and strip it off - input := str - if strings.HasPrefix(str, "https://") { - scheme = "https" - str = str[8:] - } else if strings.HasPrefix(str, "http://") { - scheme = "http" - str = str[7:] +// validDirective returns true if dir is in p.validDirectives. +func (p *parser) validDirective(dir string) bool { + if p.validDirectives == nil { + return true } - - // separate host and port - host, port, err := net.SplitHostPort(str) - if err != nil { - host, port, err = net.SplitHostPort(str + ":") - if err != nil { - host = str + for _, d := range p.validDirectives { + if d == dir { + return true } } - - // "The host subcomponent is case-insensitive." (RFC 3986) - host = strings.ToLower(host) - - // see if we can set port based off scheme - if port == "" { - if scheme == "http" { - port = "80" - } else if scheme == "https" { - port = "443" - } - } - - // repeated or conflicting scheme is confusing, so error - if scheme != "" && (port == "http" || port == "https") { - return address{}, fmt.Errorf("[%s] scheme specified twice in address", input) - } - - // error if scheme and port combination violate convention - if (scheme == "http" && port == "443") || (scheme == "https" && port == "80") { - return address{}, fmt.Errorf("[%s] scheme and port violate convention", input) - } - - // standardize http and https ports to their respective port numbers - if port == "http" { - scheme = "http" - port = "80" - } else if port == "https" { - scheme = "https" - port = "443" - } - - return address{Original: input, Scheme: scheme, Host: host, Port: port}, err + return false } // replaceEnvVars replaces environment variables that appear in the token @@ -390,26 +360,10 @@ func replaceEnvReferences(s, refStart, refEnd string) string { } type ( - // ServerBlock associates tokens with a list of addresses - // and groups tokens by directive name. + // ServerBlock associates any number of keys (usually addresses + // of some sort) with tokens (grouped by directive name). ServerBlock struct { - Addresses []address - Tokens map[string][]token - } - - address struct { - Original, Scheme, Host, Port string + Keys []string + Tokens map[string][]token } ) - -// HostList converts the list of addresses that are -// associated with this server block into a slice of -// strings, where each address is as it was originally -// read from the input. -func (sb ServerBlock) HostList() []string { - sbHosts := make([]string, len(sb.Addresses)) - for j, addr := range sb.Addresses { - sbHosts[j] = addr.Original - } - return sbHosts -} diff --git a/caddyfile/parse_test.go b/caddyfile/parse_test.go new file mode 100644 index 000000000..6b7ad47bb --- /dev/null +++ b/caddyfile/parse_test.go @@ -0,0 +1,410 @@ +package caddyfile + +import ( + "os" + "strings" + "testing" +) + +func TestAllTokens(t *testing.T) { + input := strings.NewReader("a b c\nd e") + expected := []string{"a", "b", "c", "d", "e"} + tokens := allTokens(input) + + if len(tokens) != len(expected) { + t.Fatalf("Expected %d tokens, got %d", len(expected), len(tokens)) + } + + for i, val := range expected { + if tokens[i].text != val { + t.Errorf("Token %d should be '%s' but was '%s'", i, val, tokens[i].text) + } + } +} + +func TestParseOneAndImport(t *testing.T) { + setupParseTests() + + testParseOne := func(input string) (ServerBlock, error) { + p := testParser(input) + p.Next() // parseOne doesn't call Next() to start, so we must + err := p.parseOne() + return p.block, err + } + + for i, test := range []struct { + input string + shouldErr bool + keys []string + tokens map[string]int // map of directive name to number of tokens expected + }{ + {`localhost`, false, []string{ + "localhost", + }, map[string]int{}}, + + {`localhost + dir1`, false, []string{ + "localhost", + }, map[string]int{ + "dir1": 1, + }}, + + {`localhost:1234 + dir1 foo bar`, false, []string{ + "localhost:1234", + }, map[string]int{ + "dir1": 3, + }}, + + {`localhost { + dir1 + }`, false, []string{ + "localhost", + }, map[string]int{ + "dir1": 1, + }}, + + {`localhost:1234 { + dir1 foo bar + dir2 + }`, false, []string{ + "localhost:1234", + }, map[string]int{ + "dir1": 3, + "dir2": 1, + }}, + + {`http://localhost https://localhost + dir1 foo bar`, false, []string{ + "http://localhost", + "https://localhost", + }, map[string]int{ + "dir1": 3, + }}, + + {`http://localhost https://localhost { + dir1 foo bar + }`, false, []string{ + "http://localhost", + "https://localhost", + }, map[string]int{ + "dir1": 3, + }}, + + {`http://localhost, https://localhost { + dir1 foo bar + }`, false, []string{ + "http://localhost", + "https://localhost", + }, map[string]int{ + "dir1": 3, + }}, + + {`http://localhost, { + }`, true, []string{ + "http://localhost", + }, map[string]int{}}, + + {`host1:80, http://host2.com + dir1 foo bar + dir2 baz`, false, []string{ + "host1:80", + "http://host2.com", + }, map[string]int{ + "dir1": 3, + "dir2": 2, + }}, + + {`http://host1.com, + http://host2.com, + https://host3.com`, false, []string{ + "http://host1.com", + "http://host2.com", + "https://host3.com", + }, map[string]int{}}, + + {`http://host1.com:1234, https://host2.com + dir1 foo { + bar baz + } + dir2`, false, []string{ + "http://host1.com:1234", + "https://host2.com", + }, map[string]int{ + "dir1": 6, + "dir2": 1, + }}, + + {`127.0.0.1 + dir1 { + bar baz + } + dir2 { + foo bar + }`, false, []string{ + "127.0.0.1", + }, map[string]int{ + "dir1": 5, + "dir2": 5, + }}, + + {`localhost + dir1 { + foo`, true, []string{ + "localhost", + }, map[string]int{ + "dir1": 3, + }}, + + {`localhost + dir1 { + }`, false, []string{ + "localhost", + }, map[string]int{ + "dir1": 3, + }}, + + {`localhost + dir1 { + } }`, true, []string{ + "localhost", + }, map[string]int{ + "dir1": 3, + }}, + + {`localhost + dir1 { + nested { + foo + } + } + dir2 foo bar`, false, []string{ + "localhost", + }, map[string]int{ + "dir1": 7, + "dir2": 3, + }}, + + {``, false, []string{}, map[string]int{}}, + + {`localhost + dir1 arg1 + import testdata/import_test1.txt`, false, []string{ + "localhost", + }, map[string]int{ + "dir1": 2, + "dir2": 3, + "dir3": 1, + }}, + + {`import testdata/import_test2.txt`, false, []string{ + "host1", + }, map[string]int{ + "dir1": 1, + "dir2": 2, + }}, + + {`import testdata/import_test1.txt testdata/import_test2.txt`, true, []string{}, map[string]int{}}, + + {`import testdata/not_found.txt`, true, []string{}, map[string]int{}}, + + {`""`, false, []string{}, map[string]int{}}, + + {``, false, []string{}, map[string]int{}}, + } { + result, err := testParseOne(test.input) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected an error, but didn't get one", i) + } + if !test.shouldErr && err != nil { + t.Errorf("Test %d: Expected no error, but got: %v", i, err) + } + + if len(result.Keys) != len(test.keys) { + t.Errorf("Test %d: Expected %d keys, got %d", + i, len(test.keys), len(result.Keys)) + continue + } + for j, addr := range result.Keys { + if addr != test.keys[j] { + t.Errorf("Test %d, key %d: Expected '%s', but was '%s'", + i, j, test.keys[j], addr) + } + } + + if len(result.Tokens) != len(test.tokens) { + t.Errorf("Test %d: Expected %d directives, had %d", + i, len(test.tokens), len(result.Tokens)) + continue + } + for directive, tokens := range result.Tokens { + if len(tokens) != test.tokens[directive] { + t.Errorf("Test %d, directive '%s': Expected %d tokens, counted %d", + i, directive, test.tokens[directive], len(tokens)) + continue + } + } + } +} + +func TestParseAll(t *testing.T) { + setupParseTests() + + for i, test := range []struct { + input string + shouldErr bool + keys [][]string // keys per server block, in order + }{ + {`localhost`, false, [][]string{ + {"localhost"}, + }}, + + {`localhost:1234`, false, [][]string{ + {"localhost:1234"}, + }}, + + {`localhost:1234 { + } + localhost:2015 { + }`, false, [][]string{ + {"localhost:1234"}, + {"localhost:2015"}, + }}, + + {`localhost:1234, http://host2`, false, [][]string{ + {"localhost:1234", "http://host2"}, + }}, + + {`localhost:1234, http://host2,`, true, [][]string{}}, + + {`http://host1.com, http://host2.com { + } + https://host3.com, https://host4.com { + }`, false, [][]string{ + {"http://host1.com", "http://host2.com"}, + {"https://host3.com", "https://host4.com"}, + }}, + + {`import testdata/import_glob*.txt`, false, [][]string{ + {"glob0.host0"}, + {"glob0.host1"}, + {"glob1.host0"}, + {"glob2.host0"}, + }}, + } { + p := testParser(test.input) + blocks, err := p.parseAll() + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected an error, but didn't get one", i) + } + if !test.shouldErr && err != nil { + t.Errorf("Test %d: Expected no error, but got: %v", i, err) + } + + if len(blocks) != len(test.keys) { + t.Errorf("Test %d: Expected %d server blocks, got %d", + i, len(test.keys), len(blocks)) + continue + } + for j, block := range blocks { + if len(block.Keys) != len(test.keys[j]) { + t.Errorf("Test %d: Expected %d keys in block %d, got %d", + i, len(test.keys[j]), j, len(block.Keys)) + continue + } + for k, addr := range block.Keys { + if addr != test.keys[j][k] { + t.Errorf("Test %d, block %d, key %d: Expected '%s', but got '%s'", + i, j, k, test.keys[j][k], addr) + } + } + } + } +} + +func TestEnvironmentReplacement(t *testing.T) { + setupParseTests() + + os.Setenv("PORT", "8080") + os.Setenv("ADDRESS", "servername.com") + os.Setenv("FOOBAR", "foobar") + + // basic test; unix-style env vars + p := testParser(`{$ADDRESS}`) + blocks, _ := p.parseAll() + if actual, expected := blocks[0].Keys[0], "servername.com"; expected != actual { + t.Errorf("Expected key to be '%s' but was '%s'", expected, actual) + } + + // multiple vars per token + p = testParser(`{$ADDRESS}:{$PORT}`) + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Keys[0], "servername.com:8080"; expected != actual { + t.Errorf("Expected key to be '%s' but was '%s'", expected, actual) + } + + // windows-style var and unix style in same token + p = testParser(`{%ADDRESS%}:{$PORT}`) + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Keys[0], "servername.com:8080"; expected != actual { + t.Errorf("Expected key to be '%s' but was '%s'", expected, actual) + } + + // reverse order + p = testParser(`{$ADDRESS}:{%PORT%}`) + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Keys[0], "servername.com:8080"; expected != actual { + t.Errorf("Expected key to be '%s' but was '%s'", expected, actual) + } + + // env var in server block body as argument + p = testParser(":{%PORT%}\ndir1 {$FOOBAR}") + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Keys[0], ":8080"; expected != actual { + t.Errorf("Expected key to be '%s' but was '%s'", expected, actual) + } + if actual, expected := blocks[0].Tokens["dir1"][1].text, "foobar"; expected != actual { + t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual) + } + + // combined windows env vars in argument + p = testParser(":{%PORT%}\ndir1 {%ADDRESS%}/{%FOOBAR%}") + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Tokens["dir1"][1].text, "servername.com/foobar"; expected != actual { + t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual) + } + + // malformed env var (windows) + p = testParser(":1234\ndir1 {%ADDRESS}") + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Tokens["dir1"][1].text, "{%ADDRESS}"; expected != actual { + t.Errorf("Expected host to be '%s' but was '%s'", expected, actual) + } + + // malformed (non-existent) env var (unix) + p = testParser(`:{$PORT$}`) + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Keys[0], ":"; expected != actual { + t.Errorf("Expected key to be '%s' but was '%s'", expected, actual) + } + + // in quoted field + p = testParser(":1234\ndir1 \"Test {$FOOBAR} test\"") + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Tokens["dir1"][1].text, "Test foobar test"; expected != actual { + t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual) + } +} + +func setupParseTests() { + // Set up some bogus directives for testing + //directives = []string{"dir1", "dir2", "dir3"} +} + +func testParser(input string) parser { + buf := strings.NewReader(input) + p := parser{Dispenser: NewDispenser("Test", buf)} + return p +} diff --git a/caddy/parse/import_glob0.txt b/caddyfile/testdata/import_glob0.txt similarity index 100% rename from caddy/parse/import_glob0.txt rename to caddyfile/testdata/import_glob0.txt diff --git a/caddy/parse/import_glob1.txt b/caddyfile/testdata/import_glob1.txt similarity index 100% rename from caddy/parse/import_glob1.txt rename to caddyfile/testdata/import_glob1.txt diff --git a/caddy/parse/import_glob2.txt b/caddyfile/testdata/import_glob2.txt similarity index 100% rename from caddy/parse/import_glob2.txt rename to caddyfile/testdata/import_glob2.txt diff --git a/caddy/parse/import_test1.txt b/caddyfile/testdata/import_test1.txt similarity index 100% rename from caddy/parse/import_test1.txt rename to caddyfile/testdata/import_test1.txt diff --git a/caddy/parse/import_test2.txt b/caddyfile/testdata/import_test2.txt similarity index 100% rename from caddy/parse/import_test2.txt rename to caddyfile/testdata/import_test2.txt diff --git a/caddyhttp/bind/bind.go b/caddyhttp/bind/bind.go new file mode 100644 index 000000000..fd60f8d1e --- /dev/null +++ b/caddyhttp/bind/bind.go @@ -0,0 +1,25 @@ +package bind + +import ( + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func init() { + caddy.RegisterPlugin(caddy.Plugin{ + Name: "bind", + ServerType: "http", + Action: setupBind, + }) +} + +func setupBind(c *caddy.Controller) error { + config := httpserver.GetConfig(c.Key) + for c.Next() { + if !c.Args(&config.ListenHost) { + return c.ArgErr() + } + config.TLS.ListenHost = config.ListenHost // necessary for ACME challenges, see issue #309 + } + return nil +} diff --git a/caddyhttp/bind/bind_test.go b/caddyhttp/bind/bind_test.go new file mode 100644 index 000000000..330d5427d --- /dev/null +++ b/caddyhttp/bind/bind_test.go @@ -0,0 +1,23 @@ +package bind + +import ( + "testing" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestSetupBind(t *testing.T) { + err := setupBind(caddy.NewTestController(`bind 1.2.3.4`)) + if err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + + cfg := httpserver.GetConfig("") + if got, want := cfg.ListenHost, "1.2.3.4"; got != want { + t.Errorf("Expected the config's ListenHost to be %s, was %s", want, got) + } + if got, want := cfg.TLS.ListenHost, "1.2.3.4"; got != want { + t.Errorf("Expected the TLS config's ListenHost to be %s, was %s", want, got) + } +} diff --git a/caddyhttp/caddyhttp.go b/caddyhttp/caddyhttp.go new file mode 100644 index 000000000..5c804b648 --- /dev/null +++ b/caddyhttp/caddyhttp.go @@ -0,0 +1,13 @@ +package caddyhttp + +import ( + // plug in the server + _ "github.com/mholt/caddy/caddyhttp/httpserver" + + // plug in the standard directives + _ "github.com/mholt/caddy/caddyhttp/bind" + _ "github.com/mholt/caddy/caddyhttp/gzip" + _ "github.com/mholt/caddy/caddyhttp/log" + _ "github.com/mholt/caddy/caddyhttp/root" + _ "github.com/mholt/caddy/startupshutdown" +) diff --git a/middleware/gzip/gzip.go b/caddyhttp/gzip/gzip.go similarity index 93% rename from middleware/gzip/gzip.go rename to caddyhttp/gzip/gzip.go index 4ef658556..dfee6c2e9 100644 --- a/middleware/gzip/gzip.go +++ b/caddyhttp/gzip/gzip.go @@ -1,4 +1,4 @@ -// Package gzip provides a simple middleware layer that performs +// Package gzip provides a middleware layer that performs // gzip compression on the response. package gzip @@ -12,15 +12,24 @@ import ( "net/http" "strings" - "github.com/mholt/caddy/middleware" + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" ) +func init() { + caddy.RegisterPlugin(caddy.Plugin{ + Name: "gzip", + ServerType: "http", + Action: setup, + }) +} + // Gzip is a middleware type which gzips HTTP responses. It is // imperative that any handler which writes to a gzipped response // specifies the Content-Type, otherwise some clients will assume // application/x-gzip and try to download a file. type Gzip struct { - Next middleware.Handler + Next httpserver.Handler Configs []Config } @@ -36,7 +45,6 @@ func (g Gzip) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { return g.Next.ServeHTTP(w, r) } - outer: for _, c := range g.Configs { @@ -79,9 +87,7 @@ outer: // to send something back before gzipWriter gets closed at // the return of this method! if status >= 400 { - gz.Header().Set("Content-Type", "text/plain") // very necessary - gz.WriteHeader(status) - fmt.Fprintf(gz, "%d %s", status, http.StatusText(status)) + httpserver.DefaultErrorFunc(w, r, status) return 0, err } return status, err diff --git a/middleware/gzip/gzip_test.go b/caddyhttp/gzip/gzip_test.go similarity index 94% rename from middleware/gzip/gzip_test.go rename to caddyhttp/gzip/gzip_test.go index b39dd1af8..738dff679 100644 --- a/middleware/gzip/gzip_test.go +++ b/caddyhttp/gzip/gzip_test.go @@ -8,11 +8,10 @@ import ( "strings" "testing" - "github.com/mholt/caddy/middleware" + "github.com/mholt/caddy/caddyhttp/httpserver" ) func TestGzipHandler(t *testing.T) { - pathFilter := PathFilter{make(Set)} badPaths := []string{"/bad", "/nogzip", "/nongzip"} for _, p := range badPaths { @@ -80,9 +79,8 @@ func TestGzipHandler(t *testing.T) { } } -func nextFunc(shouldGzip bool) middleware.Handler { - return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - +func nextFunc(shouldGzip bool) httpserver.Handler { + return httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { // write a relatively large text file b, err := ioutil.ReadFile("testdata/test.txt") if err != nil { diff --git a/middleware/gzip/request_filter.go b/caddyhttp/gzip/requestfilter.go similarity index 91% rename from middleware/gzip/request_filter.go rename to caddyhttp/gzip/requestfilter.go index 10f25c59b..804232a9d 100644 --- a/middleware/gzip/request_filter.go +++ b/caddyhttp/gzip/requestfilter.go @@ -4,7 +4,7 @@ import ( "net/http" "path" - "github.com/mholt/caddy/middleware" + "github.com/mholt/caddy/caddyhttp/httpserver" ) // RequestFilter determines if a request should be gzipped. @@ -15,7 +15,8 @@ type RequestFilter interface { } // defaultExtensions is the list of default extensions for which to enable gzipping. -var defaultExtensions = []string{"", ".txt", ".htm", ".html", ".css", ".php", ".js", ".json", ".md", ".xml", ".svg"} +var defaultExtensions = []string{"", ".txt", ".htm", ".html", ".css", ".php", ".js", ".json", + ".md", ".mdown", ".xml", ".svg", ".go", ".cgi", ".py", ".pl", ".aspx", ".asp"} // DefaultExtFilter creates an ExtFilter with default extensions. func DefaultExtFilter() ExtFilter { @@ -54,7 +55,7 @@ type PathFilter struct { // is found and true otherwise. func (p PathFilter) ShouldCompress(r *http.Request) bool { return !p.IgnoredPaths.ContainsFunc(func(value string) bool { - return middleware.Path(r.URL.Path).Matches(value) + return httpserver.Path(r.URL.Path).Matches(value) }) } diff --git a/middleware/gzip/request_filter_test.go b/caddyhttp/gzip/requestfilter_test.go similarity index 100% rename from middleware/gzip/request_filter_test.go rename to caddyhttp/gzip/requestfilter_test.go diff --git a/middleware/gzip/response_filter.go b/caddyhttp/gzip/responsefilter.go similarity index 100% rename from middleware/gzip/response_filter.go rename to caddyhttp/gzip/responsefilter.go diff --git a/middleware/gzip/response_filter_test.go b/caddyhttp/gzip/responsefilter_test.go similarity index 94% rename from middleware/gzip/response_filter_test.go rename to caddyhttp/gzip/responsefilter_test.go index 2878336c3..a34f58cd3 100644 --- a/middleware/gzip/response_filter_test.go +++ b/caddyhttp/gzip/responsefilter_test.go @@ -7,7 +7,7 @@ import ( "net/http/httptest" "testing" - "github.com/mholt/caddy/middleware" + "github.com/mholt/caddy/caddyhttp/httpserver" ) func TestLengthFilter(t *testing.T) { @@ -61,7 +61,7 @@ func TestResponseFilterWriter(t *testing.T) { }} for i, ts := range tests { - server.Next = middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + server.Next = httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { w.Header().Set("Content-Length", fmt.Sprint(len(ts.body))) w.Write([]byte(ts.body)) return 200, nil diff --git a/caddy/setup/gzip.go b/caddyhttp/gzip/setup.go similarity index 69% rename from caddy/setup/gzip.go rename to caddyhttp/gzip/setup.go index 7d09fe01e..824ac2141 100644 --- a/caddy/setup/gzip.go +++ b/caddyhttp/gzip/setup.go @@ -1,38 +1,40 @@ -package setup +package gzip import ( "fmt" "strconv" "strings" - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/gzip" + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" ) -// Gzip configures a new gzip middleware instance. -func Gzip(c *Controller) (middleware.Middleware, error) { +// setup configures a new gzip middleware instance. +func setup(c *caddy.Controller) error { configs, err := gzipParse(c) if err != nil { - return nil, err + return err } - return func(next middleware.Handler) middleware.Handler { - return gzip.Gzip{Next: next, Configs: configs} - }, nil + httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { + return Gzip{Next: next, Configs: configs} + }) + + return nil } -func gzipParse(c *Controller) ([]gzip.Config, error) { - var configs []gzip.Config +func gzipParse(c *caddy.Controller) ([]Config, error) { + var configs []Config for c.Next() { - config := gzip.Config{} + config := Config{} // Request Filters - pathFilter := gzip.PathFilter{IgnoredPaths: make(gzip.Set)} - extFilter := gzip.ExtFilter{Exts: make(gzip.Set)} + pathFilter := PathFilter{IgnoredPaths: make(Set)} + extFilter := ExtFilter{Exts: make(Set)} // Response Filters - lengthFilter := gzip.LengthFilter(0) + lengthFilter := LengthFilter(0) // No extra args expected if len(c.RemainingArgs()) > 0 { @@ -47,7 +49,7 @@ func gzipParse(c *Controller) ([]gzip.Config, error) { return configs, c.ArgErr() } for _, e := range exts { - if !strings.HasPrefix(e, ".") && e != gzip.ExtWildCard && e != "" { + if !strings.HasPrefix(e, ".") && e != ExtWildCard && e != "" { return configs, fmt.Errorf(`gzip: invalid extension "%v" (must start with dot)`, e) } extFilter.Exts.Add(e) @@ -82,18 +84,18 @@ func gzipParse(c *Controller) ([]gzip.Config, error) { } else if length == 0 { return configs, fmt.Errorf(`gzip: min_length must be greater than 0`) } - lengthFilter = gzip.LengthFilter(length) + lengthFilter = LengthFilter(length) default: return configs, c.ArgErr() } } // Request Filters - config.RequestFilters = []gzip.RequestFilter{} + config.RequestFilters = []RequestFilter{} // If ignored paths are specified, put in front to filter with path first if len(pathFilter.IgnoredPaths) > 0 { - config.RequestFilters = []gzip.RequestFilter{pathFilter} + config.RequestFilters = []RequestFilter{pathFilter} } // Then, if extensions are specified, use those to filter. @@ -101,7 +103,7 @@ func gzipParse(c *Controller) ([]gzip.Config, error) { if len(extFilter.Exts) > 0 { config.RequestFilters = append(config.RequestFilters, extFilter) } else { - config.RequestFilters = append(config.RequestFilters, gzip.DefaultExtFilter()) + config.RequestFilters = append(config.RequestFilters, DefaultExtFilter()) } // Response Filters diff --git a/caddy/setup/gzip_test.go b/caddyhttp/gzip/setup_test.go similarity index 76% rename from caddy/setup/gzip_test.go rename to caddyhttp/gzip/setup_test.go index 4c24ab0ab..a71c9b19f 100644 --- a/caddy/setup/gzip_test.go +++ b/caddyhttp/gzip/setup_test.go @@ -1,29 +1,29 @@ -package setup +package gzip import ( "testing" - "github.com/mholt/caddy/middleware/gzip" + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" ) -func TestGzip(t *testing.T) { - c := NewTestController(`gzip`) - - mid, err := Gzip(c) +func TestSetup(t *testing.T) { + err := setup(caddy.NewTestController(`gzip`)) if err != nil { t.Errorf("Expected no errors, but got: %v", err) } - if mid == nil { + mids := httpserver.GetConfig("").Middleware() + if mids == nil { t.Fatal("Expected middleware, was nil instead") } - handler := mid(EmptyNext) - myHandler, ok := handler.(gzip.Gzip) + handler := mids[0](httpserver.EmptyNext) + myHandler, ok := handler.(Gzip) if !ok { t.Fatalf("Expected handler to be type Gzip, got: %#v", handler) } - if !SameNext(myHandler.Next, EmptyNext) { + if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) { t.Error("'Next' field of handler was not set properly") } @@ -90,8 +90,7 @@ func TestGzip(t *testing.T) { `, false}, } for i, test := range tests { - c := NewTestController(test.input) - _, err := gzipParse(c) + _, err := gzipParse(caddy.NewTestController(test.input)) if test.shouldErr && err == nil { t.Errorf("Test %v: Expected error but found nil", i) } else if !test.shouldErr && err != nil { diff --git a/middleware/gzip/testdata/test.txt b/caddyhttp/gzip/testdata/test.txt similarity index 100% rename from middleware/gzip/testdata/test.txt rename to caddyhttp/gzip/testdata/test.txt diff --git a/server/graceful.go b/caddyhttp/httpserver/graceful.go similarity index 68% rename from server/graceful.go rename to caddyhttp/httpserver/graceful.go index 5057d039b..f11a6c9aa 100644 --- a/server/graceful.go +++ b/caddyhttp/httpserver/graceful.go @@ -1,4 +1,4 @@ -package server +package httpserver import ( "net" @@ -6,16 +6,20 @@ import ( "syscall" ) +// TODO: Should this be a generic graceful listener available in its own package or something? +// Also, passing in a WaitGroup is a little awkward. Why can't this listener just keep +// the waitgroup internal to itself? + // newGracefulListener returns a gracefulListener that wraps l and // uses wg (stored in the host server) to count connections. -func newGracefulListener(l ListenerFile, wg *sync.WaitGroup) *gracefulListener { - gl := &gracefulListener{ListenerFile: l, stop: make(chan error), httpWg: wg} +func newGracefulListener(l net.Listener, wg *sync.WaitGroup) *gracefulListener { + gl := &gracefulListener{Listener: l, stop: make(chan error), connWg: wg} go func() { <-gl.stop gl.Lock() gl.stopped = true gl.Unlock() - gl.stop <- gl.ListenerFile.Close() + gl.stop <- gl.Listener.Close() }() return gl } @@ -24,21 +28,21 @@ func newGracefulListener(l ListenerFile, wg *sync.WaitGroup) *gracefulListener { // count the number of connections on it. Its // methods mainly wrap net.Listener to be graceful. type gracefulListener struct { - ListenerFile + net.Listener stop chan error stopped bool sync.Mutex // protects the stopped flag - httpWg *sync.WaitGroup // pointer to the host's wg used for counting connections + connWg *sync.WaitGroup // pointer to the host's wg used for counting connections } // Accept accepts a connection. func (gl *gracefulListener) Accept() (c net.Conn, err error) { - c, err = gl.ListenerFile.Accept() + c, err = gl.Listener.Accept() if err != nil { return } - c = gracefulConn{Conn: c, httpWg: gl.httpWg} - gl.httpWg.Add(1) + c = gracefulConn{Conn: c, connWg: gl.connWg} + gl.connWg.Add(1) return } @@ -60,7 +64,7 @@ func (gl *gracefulListener) Close() error { // a graceful shutdown. type gracefulConn struct { net.Conn - httpWg *sync.WaitGroup // pointer to the host server's connection waitgroup + connWg *sync.WaitGroup // pointer to the host server's connection waitgroup } // Close closes c's underlying connection while updating the wg count. @@ -71,6 +75,6 @@ func (c gracefulConn) Close() error { } // close can fail on http2 connections (as of Oct. 2015, before http2 in std lib) // so don't decrement count unless close succeeds - c.httpWg.Done() + c.connWg.Done() return nil } diff --git a/caddyhttp/httpserver/https.go b/caddyhttp/httpserver/https.go new file mode 100644 index 000000000..b93a85439 --- /dev/null +++ b/caddyhttp/httpserver/https.go @@ -0,0 +1,154 @@ +package httpserver + +import ( + "net" + "net/http" + + "github.com/mholt/caddy/caddytls" +) + +func activateHTTPS() error { + // TODO: Is this loop a bug? Should we scope this method to just a single context? (restarts...?) + for _, ctx := range contexts { + // pre-screen each config and earmark the ones that qualify for managed TLS + markQualifiedForAutoHTTPS(ctx.siteConfigs) + + // place certificates and keys on disk + for _, c := range ctx.siteConfigs { + err := c.TLS.ObtainCert(true) + if err != nil { + return err + } + } + + // update TLS configurations + err := enableAutoHTTPS(ctx.siteConfigs, true) + if err != nil { + return err + } + + // set up redirects + ctx.siteConfigs = makePlaintextRedirects(ctx.siteConfigs) + } + + // renew all relevant certificates that need renewal. this is important + // to do right away so we guarantee that renewals aren't missed, and + // also the user can respond to any potential errors that occur. + err := caddytls.RenewManagedCertificates(true) + if err != nil { + return err + } + + return nil +} + +// markQualifiedForAutoHTTPS scans each config and, if it +// qualifies for managed TLS, it sets the Managed field of +// the TLS config to true. +func markQualifiedForAutoHTTPS(configs []*SiteConfig) { + for _, cfg := range configs { + if caddytls.QualifiesForManagedTLS(cfg) && cfg.Addr.Scheme != "http" { + cfg.TLS.Managed = true + } + } +} + +// enableAutoHTTPS configures each config to use TLS according to default settings. +// It will only change configs that are marked as managed, and assumes that +// certificates and keys are already on disk. If loadCertificates is true, +// the certificates will be loaded from disk into the cache for this process +// to use. If false, TLS will still be enabled and configured with default +// settings, but no certificates will be parsed loaded into the cache, and +// the returned error value will always be nil. +func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error { + for _, cfg := range configs { + if cfg == nil || cfg.TLS == nil || !cfg.TLS.Managed { + continue + } + cfg.TLS.Enabled = true + cfg.Addr.Scheme = "https" + if loadCertificates && caddytls.HostQualifies(cfg.Addr.Host) { + _, err := caddytls.CacheManagedCertificate(cfg.Addr.Host, cfg.TLS) + if err != nil { + return err + } + } + + // Make sure any config values not explicitly set are set to default + caddytls.SetDefaultTLSParams(cfg.TLS) + + // Set default port of 443 if not explicitly set + if cfg.Addr.Port == "" && + cfg.TLS.Enabled && + (!cfg.TLS.Manual || cfg.TLS.OnDemand) && + cfg.Addr.Host != "localhost" { + cfg.Addr.Port = "443" + } + } + return nil +} + +// makePlaintextRedirects sets up redirects from port 80 to the relevant HTTPS +// hosts. You must pass in all configs, not just configs that qualify, since +// we must know whether the same host already exists on port 80, and those would +// not be in a list of configs that qualify for automatic HTTPS. This function will +// only set up redirects for configs that qualify. It returns the updated list of +// all configs. +func makePlaintextRedirects(allConfigs []*SiteConfig) []*SiteConfig { + for i, cfg := range allConfigs { + if cfg.TLS.Managed && + !hostHasOtherPort(allConfigs, i, "80") && + (cfg.Addr.Port == "443" || !hostHasOtherPort(allConfigs, i, "443")) { + allConfigs = append(allConfigs, redirPlaintextHost(cfg)) + } + } + return allConfigs +} + +// hostHasOtherPort returns true if there is another config in the list with the same +// hostname that has port otherPort, or false otherwise. All the configs are checked +// against the hostname of allConfigs[thisConfigIdx]. +func hostHasOtherPort(allConfigs []*SiteConfig, thisConfigIdx int, otherPort string) bool { + for i, otherCfg := range allConfigs { + if i == thisConfigIdx { + continue // has to be a config OTHER than the one we're comparing against + } + if otherCfg.Addr.Host == allConfigs[thisConfigIdx].Addr.Host && + otherCfg.Addr.Port == otherPort { + return true + } + } + return false +} + +// redirPlaintextHost returns a new plaintext HTTP configuration for +// a virtualHost that simply redirects to cfg, which is assumed to +// be the HTTPS configuration. The returned configuration is set +// to listen on port 80. The TLS field of cfg must not be nil. +func redirPlaintextHost(cfg *SiteConfig) *SiteConfig { + redirPort := cfg.Addr.Port + if redirPort == "443" { + // default port is redundant + redirPort = "" + } + redirMiddleware := func(next Handler) Handler { + return HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + toURL := "https://" + r.Host + if redirPort != "" { + toURL += ":" + redirPort + } + toURL += r.URL.RequestURI() + http.Redirect(w, r, toURL, http.StatusMovedPermanently) + return 0, nil + }) + } + host := cfg.Addr.Host + port := "80" + addr := net.JoinHostPort(host, port) + return &SiteConfig{ + Addr: Address{Original: addr, Host: host, Port: port}, + ListenHost: cfg.ListenHost, + middleware: []Middleware{redirMiddleware}, + TLS: &caddytls.Config{AltHTTPPort: cfg.TLS.AltHTTPPort}, + } +} diff --git a/caddyhttp/httpserver/https_test.go b/caddyhttp/httpserver/https_test.go new file mode 100644 index 000000000..04a4db1e4 --- /dev/null +++ b/caddyhttp/httpserver/https_test.go @@ -0,0 +1,178 @@ +package httpserver + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/mholt/caddy/caddytls" +) + +func TestRedirPlaintextHost(t *testing.T) { + cfg := redirPlaintextHost(&SiteConfig{ + Addr: Address{ + Host: "example.com", + Port: "1234", + }, + ListenHost: "93.184.216.34", + TLS: new(caddytls.Config), + }) + + // Check host and port + if actual, expected := cfg.Addr.Host, "example.com"; actual != expected { + t.Errorf("Expected redir config to have host %s but got %s", expected, actual) + } + if actual, expected := cfg.ListenHost, "93.184.216.34"; actual != expected { + t.Errorf("Expected redir config to have bindhost %s but got %s", expected, actual) + } + if actual, expected := cfg.Addr.Port, "80"; actual != expected { + t.Errorf("Expected redir config to have port '%s' but got '%s'", expected, actual) + } + + // Make sure redirect handler is set up properly + if cfg.middleware == nil || len(cfg.middleware) != 1 { + t.Fatalf("Redir config middleware not set up properly; got: %#v", cfg.middleware) + } + + handler := cfg.middleware[0](nil) + + // Check redirect for correctness + rec := httptest.NewRecorder() + req, err := http.NewRequest("GET", "http://foo/bar?q=1", nil) + if err != nil { + t.Fatal(err) + } + status, err := handler.ServeHTTP(rec, req) + if status != 0 { + t.Errorf("Expected status return to be 0, but was %d", status) + } + if err != nil { + t.Errorf("Expected returned error to be nil, but was %v", err) + } + if rec.Code != http.StatusMovedPermanently { + t.Errorf("Expected status %d but got %d", http.StatusMovedPermanently, rec.Code) + } + if got, want := rec.Header().Get("Location"), "https://foo:1234/bar?q=1"; got != want { + t.Errorf("Expected Location: '%s' but got '%s'", want, got) + } + + // browsers can infer a default port from scheme, so make sure the port + // doesn't get added in explicitly for default ports like 443 for https. + cfg = redirPlaintextHost(&SiteConfig{Addr: Address{Host: "example.com", Port: "443"}, TLS: new(caddytls.Config)}) + handler = cfg.middleware[0](nil) + + rec = httptest.NewRecorder() + req, err = http.NewRequest("GET", "http://foo/bar?q=1", nil) + if err != nil { + t.Fatal(err) + } + status, err = handler.ServeHTTP(rec, req) + if status != 0 { + t.Errorf("Expected status return to be 0, but was %d", status) + } + if err != nil { + t.Errorf("Expected returned error to be nil, but was %v", err) + } + if rec.Code != http.StatusMovedPermanently { + t.Errorf("Expected status %d but got %d", http.StatusMovedPermanently, rec.Code) + } + if got, want := rec.Header().Get("Location"), "https://foo/bar?q=1"; got != want { + t.Errorf("Expected Location: '%s' but got '%s'", want, got) + } +} + +func TestHostHasOtherPort(t *testing.T) { + configs := []*SiteConfig{ + {Addr: Address{Host: "example.com", Port: "80"}}, + {Addr: Address{Host: "sub1.example.com", Port: "80"}}, + {Addr: Address{Host: "sub1.example.com", Port: "443"}}, + } + + if hostHasOtherPort(configs, 0, "80") { + t.Errorf(`Expected hostHasOtherPort(configs, 0, "80") to be false, but got true`) + } + if hostHasOtherPort(configs, 0, "443") { + t.Errorf(`Expected hostHasOtherPort(configs, 0, "443") to be false, but got true`) + } + if !hostHasOtherPort(configs, 1, "443") { + t.Errorf(`Expected hostHasOtherPort(configs, 1, "443") to be true, but got false`) + } +} + +func TestMakePlaintextRedirects(t *testing.T) { + configs := []*SiteConfig{ + // Happy path = standard redirect from 80 to 443 + {Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{Managed: true}}, + + // Host on port 80 already defined; don't change it (no redirect) + {Addr: Address{Host: "sub1.example.com", Port: "80", Scheme: "http"}, TLS: new(caddytls.Config)}, + {Addr: Address{Host: "sub1.example.com"}, TLS: &caddytls.Config{Managed: true}}, + + // Redirect from port 80 to port 5000 in this case + {Addr: Address{Host: "sub2.example.com", Port: "5000"}, TLS: &caddytls.Config{Managed: true}}, + + // Can redirect from 80 to either 443 or 5001, but choose 443 + {Addr: Address{Host: "sub3.example.com", Port: "443"}, TLS: &caddytls.Config{Managed: true}}, + {Addr: Address{Host: "sub3.example.com", Port: "5001", Scheme: "https"}, TLS: &caddytls.Config{Managed: true}}, + } + + result := makePlaintextRedirects(configs) + expectedRedirCount := 3 + + if len(result) != len(configs)+expectedRedirCount { + t.Errorf("Expected %d redirect(s) to be added, but got %d", + expectedRedirCount, len(result)-len(configs)) + } +} + +func TestEnableAutoHTTPS(t *testing.T) { + configs := []*SiteConfig{ + {Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{Managed: true}}, + {}, // not managed - no changes! + } + + enableAutoHTTPS(configs, false) + + if !configs[0].TLS.Enabled { + t.Errorf("Expected config 0 to have TLS.Enabled == true, but it was false") + } + if configs[0].Addr.Scheme != "https" { + t.Errorf("Expected config 0 to have Addr.Scheme == \"https\", but it was \"%s\"", + configs[0].Addr.Scheme) + } + if configs[1].TLS != nil && configs[1].TLS.Enabled { + t.Errorf("Expected config 1 to have TLS.Enabled == false, but it was true") + } +} + +func TestMarkQualifiedForAutoHTTPS(t *testing.T) { + // TODO: caddytls.TestQualifiesForManagedTLS and this test share nearly the same config list... + configs := []*SiteConfig{ + {Addr: Address{Host: ""}, TLS: new(caddytls.Config)}, + {Addr: Address{Host: "localhost"}, TLS: new(caddytls.Config)}, + {Addr: Address{Host: "123.44.3.21"}, TLS: new(caddytls.Config)}, + {Addr: Address{Host: "example.com"}, TLS: new(caddytls.Config)}, + {Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{Manual: true}}, + {Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{ACMEEmail: "off"}}, + {Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{ACMEEmail: "foo@bar.com"}}, + {Addr: Address{Host: "example.com", Scheme: "http"}, TLS: new(caddytls.Config)}, + {Addr: Address{Host: "example.com", Port: "80"}, TLS: new(caddytls.Config)}, + {Addr: Address{Host: "example.com", Port: "1234"}, TLS: new(caddytls.Config)}, + {Addr: Address{Host: "example.com", Scheme: "https"}, TLS: new(caddytls.Config)}, + {Addr: Address{Host: "example.com", Port: "80", Scheme: "https"}, TLS: new(caddytls.Config)}, + } + expectedManagedCount := 4 + + markQualifiedForAutoHTTPS(configs) + + count := 0 + for _, cfg := range configs { + if cfg.TLS.Managed { + count++ + } + } + + if count != expectedManagedCount { + t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count) + } +} diff --git a/middleware/middleware.go b/caddyhttp/httpserver/middleware.go similarity index 68% rename from middleware/middleware.go rename to caddyhttp/httpserver/middleware.go index d91044ebe..e5e70de42 100644 --- a/middleware/middleware.go +++ b/caddyhttp/httpserver/middleware.go @@ -1,12 +1,18 @@ -// Package middleware provides some types and functions common among middleware. -package middleware +package httpserver import ( + "fmt" "net/http" + "os" "path" + "strings" "time" ) +func init() { + initCaseSettings() +} + type ( // Middleware is the middle layer which represents the traditional // idea of middleware: it chains one Handler to the next by being @@ -96,8 +102,55 @@ func SetLastModifiedHeader(w http.ResponseWriter, modTime time.Time) { w.Header().Set("Last-Modified", modTime.UTC().Format(http.TimeFormat)) } +// CaseSensitivePath determines if paths should be case sensitive. +// This is configurable via CASE_SENSITIVE_PATH environment variable. +var CaseSensitivePath = true + +const caseSensitivePathEnv = "CASE_SENSITIVE_PATH" + +// initCaseSettings loads case sensitivity config from environment variable. +// +// This could have been in init, but init cannot be called from tests. +func initCaseSettings() { + switch os.Getenv(caseSensitivePathEnv) { + case "0", "false": + CaseSensitivePath = false + default: + CaseSensitivePath = true + } +} + +// Path represents a URI path. +type Path string + +// Matches checks to see if other matches p. +// +// Path matching will probably not always be a direct +// comparison; this method assures that paths can be +// easily and consistently matched. +func (p Path) Matches(other string) bool { + if CaseSensitivePath { + return strings.HasPrefix(string(p), other) + } + return strings.HasPrefix(strings.ToLower(string(p)), strings.ToLower(other)) +} + // currentTime, as it is defined here, returns time.Now(). // It's defined as a variable for mocking time in tests. -var currentTime = func() time.Time { - return time.Now() +var currentTime = func() time.Time { return time.Now() } + +// EmptyNext is a no-op function that can be passed into +// Middleware functions so that the assignment to the +// Next field of the Handler can be tested. +// +// Used primarily for testing but needs to be exported so +// plugins can use this as a convenience. +var EmptyNext = HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { return 0, nil }) + +// SameNext does a pointer comparison between next1 and next2. +// +// Used primarily for testing but needs to be exported so +// plugins can use this as a convenience. +func SameNext(next1, next2 Handler) bool { + return fmt.Sprintf("%v", next1) == fmt.Sprintf("%v", next2) } diff --git a/middleware/path_test.go b/caddyhttp/httpserver/middleware_test.go similarity index 98% rename from middleware/path_test.go rename to caddyhttp/httpserver/middleware_test.go index eb054b1e4..2f75c8bb9 100644 --- a/middleware/path_test.go +++ b/caddyhttp/httpserver/middleware_test.go @@ -1,4 +1,4 @@ -package middleware +package httpserver import ( "os" diff --git a/caddyhttp/httpserver/plugin.go b/caddyhttp/httpserver/plugin.go new file mode 100644 index 000000000..5e6cb0e78 --- /dev/null +++ b/caddyhttp/httpserver/plugin.go @@ -0,0 +1,367 @@ +package httpserver + +import ( + "flag" + "fmt" + "log" + "net" + "net/url" + "strings" + "time" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyfile" + "github.com/mholt/caddy/caddytls" +) + +const serverType = "http" + +func init() { + flag.StringVar(&Host, "host", DefaultHost, "Default host") + flag.StringVar(&Port, "port", DefaultPort, "Default port") + flag.StringVar(&Root, "root", DefaultRoot, "Root path of default site") + flag.DurationVar(&GracefulTimeout, "grace", 5*time.Second, "Maximum duration of graceful shutdown") // TODO + flag.BoolVar(&HTTP2, "http2", true, "Use HTTP/2") + flag.BoolVar(&QUIC, "quic", false, "Use experimental QUIC") + + caddy.RegisterServerType(serverType, caddy.ServerType{ + Directives: directives, + DefaultInput: func() caddy.Input { + if Port == DefaultPort && Host != "" { + // by leaving the port blank in this case we give auto HTTPS + // a chance to set the port to 443 for us + return caddy.CaddyfileInput{ + Contents: []byte(fmt.Sprintf("%s\nroot %s", Host, Root)), + ServerTypeName: serverType, + } + } + return caddy.CaddyfileInput{ + Contents: []byte(fmt.Sprintf("%s:%s\nroot %s", Host, Port, Root)), + ServerTypeName: serverType, + } + }, + NewContext: newContext, + }) + caddy.RegisterCaddyfileLoader("short", caddy.LoaderFunc(shortCaddyfileLoader)) + caddy.RegisterParsingCallback(serverType, "tls", activateHTTPS) + caddytls.RegisterConfigGetter(serverType, func(key string) *caddytls.Config { return GetConfig(key).TLS }) +} + +var contexts []*httpContext + +func newContext() caddy.Context { + context := &httpContext{keysToSiteConfigs: make(map[string]*SiteConfig)} + contexts = append(contexts, context) + return context +} + +type httpContext struct { + // keysToSiteConfigs maps an address at the top of a + // server block (a "key") to its SiteConfig. Not all + // SiteConfigs will be represented here, only ones + // that appeared in the Caddyfile. + keysToSiteConfigs map[string]*SiteConfig + + // siteConfigs is the master list of all site configs. + siteConfigs []*SiteConfig +} + +// InspectServerBlocks make sure that everything checks out before +// executing directives and otherwise prepares the directives to +// be parsed and executed. +func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) { + // TODO: Here we can inspect the server blocks + // and make changes to them, like adding a directive + // that must always be present (e.g. 'errors discard`?) - + // totally optional; server types need not register this + // function. + + // For each address in each server block, make a new config + for _, sb := range serverBlocks { + for _, key := range sb.Keys { + key = strings.ToLower(key) + if _, dup := h.keysToSiteConfigs[key]; dup { + return serverBlocks, fmt.Errorf("duplicate site address: %s", key) + } + addr, err := standardizeAddress(key) + if err != nil { + return serverBlocks, err + } + // Save the config to our master list, and key it for lookups + cfg := &SiteConfig{ + Addr: addr, + TLS: &caddytls.Config{Hostname: addr.Host}, + HiddenFiles: []string{sourceFile}, + } + h.siteConfigs = append(h.siteConfigs, cfg) + h.keysToSiteConfigs[key] = cfg + } + } + + return serverBlocks, nil +} + +// MakeServers uses the newly-created siteConfigs to +// create and return a list of server instances. +func (h *httpContext) MakeServers() ([]caddy.Server, error) { + // make sure TLS is disabled for explicitly-HTTP sites + // (necessary when HTTP address shares a block containing tls) + for _, cfg := range h.siteConfigs { + if cfg.TLS.Enabled && (cfg.Addr.Port == "80" || cfg.Addr.Scheme == "http") { + cfg.TLS.Enabled = false + log.Printf("[WARNING] TLS disabled for %s", cfg.Addr) + } + } + + // we must map (group) each config to a bind address + groups, err := groupSiteConfigsByListenAddr(h.siteConfigs) + if err != nil { + return nil, err + } + + // then we create a server for each group + var servers []caddy.Server + for addr, group := range groups { + s, err := NewServer(addr, group) + if err != nil { + return nil, err + } + servers = append(servers, s) + } + + return servers, nil +} + +// GetConfig gets a SiteConfig that is keyed by addrKey. +// It creates an empty one in the latest context if +// the key does not exist in any context, so it +// will never return nil. If no contexts exist (which +// should never happen except in tests), it creates a +// new context in which to put it. +func GetConfig(addrKey string) *SiteConfig { + for _, context := range contexts { + if cfg, ok := context.keysToSiteConfigs[addrKey]; ok { + return cfg + } + } + if len(contexts) == 0 { + // this shouldn't happen except in tests + newContext() + } + cfg := new(SiteConfig) + cfg.TLS = new(caddytls.Config) + defaultCtx := contexts[len(contexts)-1] + defaultCtx.siteConfigs = append(defaultCtx.siteConfigs, cfg) + defaultCtx.keysToSiteConfigs[addrKey] = cfg + return cfg +} + +// shortCaddyfileLoader loads a Caddyfile if positional arguments are +// detected, or, in other words, if un-named arguments are provided to +// the program. A "short Caddyfile" is one in which each argument +// is a line of the Caddyfile. The default host and port are prepended +// according to the Host and Port values. +func shortCaddyfileLoader(serverType string) (caddy.Input, error) { + if flag.NArg() > 0 && serverType == "http" { + confBody := fmt.Sprintf("%s:%s\n%s", Host, Port, strings.Join(flag.Args(), "\n")) + return caddy.CaddyfileInput{ + Contents: []byte(confBody), + Filepath: "args", + ServerTypeName: serverType, + }, nil + } + return nil, nil +} + +// groupSiteConfigsByListenAddr groups site configs by their listen +// (bind) address, so sites that use the same listener can be served +// on the same server instance. The return value maps the listen +// address (what you pass into net.Listen) to the list of site configs. +// This function does NOT vet the configs to ensure they are compatible. +func groupSiteConfigsByListenAddr(configs []*SiteConfig) (map[string][]*SiteConfig, error) { + groups := make(map[string][]*SiteConfig) + + for _, conf := range configs { + if caddy.IsLoopback(conf.Addr.Host) && conf.ListenHost == "" { + // special case: one would not expect a site served + // at loopback to be connected to from the outside. + conf.ListenHost = conf.Addr.Host + } + if conf.Addr.Port == "" { + conf.Addr.Port = Port + } + addr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(conf.ListenHost, conf.Addr.Port)) + if err != nil { + return nil, err + } + addrstr := addr.String() + groups[addrstr] = append(groups[addrstr], conf) + } + + return groups, nil +} + +// AddMiddleware adds a middleware to a site's middleware stack. +func (sc *SiteConfig) AddMiddleware(m Middleware) { + sc.middleware = append(sc.middleware, m) +} + +// Address represents a site address. It contains +// the original input value, and the component +// parts of an address. +type Address struct { + Original, Scheme, Host, Port, Path string +} + +// String returns a human-friendly print of the address. +func (a Address) String() string { + scheme := a.Scheme + if scheme == "" { + if a.Port == "80" { + scheme = "http" + } else if a.Port == "443" { + scheme = "https" + } + } + s := scheme + if s != "" { + s += "://" + } + s += a.Host + if (scheme == "https" && a.Port != "443") || + (scheme == "http" && a.Port != "80") { + s += ":" + a.Port + } + if a.Path != "" { + s += "/" + a.Path + } + return s +} + +// VHost returns a sensible concatenation of Host:Port/Path from a. +// It's basically the a.Original but without the scheme. +func (a Address) VHost() string { + if idx := strings.Index(a.Original, "://"); idx > -1 { + return a.Original[idx+3:] + } + return a.Original +} + +// standardizeAddress parses an address string into a structured format with separate +// scheme, host, and port portions, as well as the original input string. +func standardizeAddress(str string) (Address, error) { + input := str + + // Split input into components (prepend with // to assert host by default) + if !strings.Contains(str, "//") { + str = "//" + str + } + u, err := url.Parse(str) + if err != nil { + return Address{}, err + } + + // separate host and port + host, port, err := net.SplitHostPort(u.Host) + if err != nil { + host, port, err = net.SplitHostPort(u.Host + ":") + if err != nil { + host = u.Host + } + } + + // see if we can set port based off scheme + if port == "" { + if u.Scheme == "http" { + port = "80" + } else if u.Scheme == "https" { + port = "443" + } + } + + // repeated or conflicting scheme is confusing, so error + if u.Scheme != "" && (port == "http" || port == "https") { + return Address{}, fmt.Errorf("[%s] scheme specified twice in address", input) + } + + // error if scheme and port combination violate convention + if (u.Scheme == "http" && port == "443") || (u.Scheme == "https" && port == "80") { + return Address{}, fmt.Errorf("[%s] scheme and port violate convention", input) + } + + // standardize http and https ports to their respective port numbers + if port == "http" { + u.Scheme = "http" + port = "80" + } else if port == "https" { + u.Scheme = "https" + port = "443" + } + + return Address{Original: input, Scheme: u.Scheme, Host: host, Port: port, Path: u.Path}, err +} + +// directives is the list of all directives known to exist for the +// http server type, including non-standard (3rd-party) directives. +// The ordering of this list is important. +var directives = []string{ + // primitive actions that set up the basics of each config + "root", + "bind", + "tls", + + // these don't inject middleware handlers + "startup", + "shutdown", + + // these add middleware to the stack + "log", + "gzip", + "errors", + "header", + "rewrite", + "redir", + "ext", + "mime", + "basicauth", + "internal", + "pprof", + "expvar", + "proxy", + "fastcgi", + "websocket", + "markdown", + "templates", + "browse", +} + +const ( + // DefaultHost is the default host. + DefaultHost = "" + // DefaultPort is the default port. + DefaultPort = "2015" + // DefaultRoot is the default root folder. + DefaultRoot = "." +) + +// These "soft defaults" are configurable by +// command line flags, etc. +var ( + // Root is the site root + Root = DefaultRoot + + // Host is the site host + Host = DefaultHost + + // Port is the site port + Port = DefaultPort + + // GracefulTimeout is the maximum duration of a graceful shutdown. + GracefulTimeout time.Duration + + // HTTP2 indicates whether HTTP2 is enabled or not. + HTTP2 bool + + // QUIC indicates whether QUIC is enabled or not. + QUIC bool +) diff --git a/caddyhttp/httpserver/plugin_test.go b/caddyhttp/httpserver/plugin_test.go new file mode 100644 index 000000000..353e0948c --- /dev/null +++ b/caddyhttp/httpserver/plugin_test.go @@ -0,0 +1,92 @@ +package httpserver + +import "testing" + +func TestStandardizeAddress(t *testing.T) { + for i, test := range []struct { + input string + scheme, host, port, path string + shouldErr bool + }{ + {`localhost`, "", "localhost", "", "", false}, + {`localhost:1234`, "", "localhost", "1234", "", false}, + {`localhost:`, "", "localhost", "", "", false}, + {`0.0.0.0`, "", "0.0.0.0", "", "", false}, + {`127.0.0.1:1234`, "", "127.0.0.1", "1234", "", false}, + {`:1234`, "", "", "1234", "", false}, + {`[::1]`, "", "::1", "", "", false}, + {`[::1]:1234`, "", "::1", "1234", "", false}, + {`:`, "", "", "", "", false}, + {`localhost:http`, "http", "localhost", "80", "", false}, + {`localhost:https`, "https", "localhost", "443", "", false}, + {`:http`, "http", "", "80", "", false}, + {`:https`, "https", "", "443", "", false}, + {`http://localhost:https`, "", "", "", "", true}, // conflict + {`http://localhost:http`, "", "", "", "", true}, // repeated scheme + {`http://localhost:443`, "", "", "", "", true}, // not conventional + {`https://localhost:80`, "", "", "", "", true}, // not conventional + {`http://localhost`, "http", "localhost", "80", "", false}, + {`https://localhost`, "https", "localhost", "443", "", false}, + {`http://127.0.0.1`, "http", "127.0.0.1", "80", "", false}, + {`https://127.0.0.1`, "https", "127.0.0.1", "443", "", false}, + {`http://[::1]`, "http", "::1", "80", "", false}, + {`http://localhost:1234`, "http", "localhost", "1234", "", false}, + {`https://127.0.0.1:1234`, "https", "127.0.0.1", "1234", "", false}, + {`http://[::1]:1234`, "http", "::1", "1234", "", false}, + {``, "", "", "", "", false}, + {`::1`, "", "::1", "", "", true}, + {`localhost::`, "", "localhost::", "", "", true}, + {`#$%@`, "", "", "", "", true}, + {`host/path`, "", "host", "", "/path", false}, + {`http://host/`, "http", "host", "80", "/", false}, + {`//asdf`, "", "asdf", "", "", false}, + {`:1234/asdf`, "", "", "1234", "/asdf", false}, + {`http://host/path`, "http", "host", "80", "/path", false}, + {`https://host:443/path/foo`, "https", "host", "443", "/path/foo", false}, + {`host:80/path`, "", "host", "80", "/path", false}, + {`host:https/path`, "https", "host", "443", "/path", false}, + } { + actual, err := standardizeAddress(test.input) + + if err != nil && !test.shouldErr { + t.Errorf("Test %d (%s): Expected no error, but had error: %v", i, test.input, err) + } + if err == nil && test.shouldErr { + t.Errorf("Test %d (%s): Expected error, but had none", i, test.input) + } + + if !test.shouldErr && actual.Original != test.input { + t.Errorf("Test %d (%s): Expected original '%s', got '%s'", i, test.input, test.input, actual.Original) + } + if actual.Scheme != test.scheme { + t.Errorf("Test %d (%s): Expected scheme '%s', got '%s'", i, test.input, test.scheme, actual.Scheme) + } + if actual.Host != test.host { + t.Errorf("Test %d (%s): Expected host '%s', got '%s'", i, test.input, test.host, actual.Host) + } + if actual.Port != test.port { + t.Errorf("Test %d (%s): Expected port '%s', got '%s'", i, test.input, test.port, actual.Port) + } + if actual.Path != test.path { + t.Errorf("Test %d (%s): Expected path '%s', got '%s'", i, test.input, test.path, actual.Path) + } + } +} + +func TestAddressVHost(t *testing.T) { + for i, test := range []struct { + addr Address + expected string + }{ + {Address{Original: "host:1234"}, "host:1234"}, + {Address{Original: "host:1234/foo"}, "host:1234/foo"}, + {Address{Original: "host/foo"}, "host/foo"}, + {Address{Original: "http://host/foo"}, "host/foo"}, + {Address{Original: "https://host/foo"}, "host/foo"}, + } { + actual := test.addr.VHost() + if actual != test.expected { + t.Errorf("Test %d: expected '%s' but got '%s'", i, test.expected, actual) + } + } +} diff --git a/middleware/recorder.go b/caddyhttp/httpserver/recorder.go similarity index 99% rename from middleware/recorder.go rename to caddyhttp/httpserver/recorder.go index 50f4811cf..5788ab44b 100644 --- a/middleware/recorder.go +++ b/caddyhttp/httpserver/recorder.go @@ -1,4 +1,4 @@ -package middleware +package httpserver import ( "bufio" diff --git a/middleware/recorder_test.go b/caddyhttp/httpserver/recorder_test.go similarity index 98% rename from middleware/recorder_test.go rename to caddyhttp/httpserver/recorder_test.go index ed6c6abdd..0772d669f 100644 --- a/middleware/recorder_test.go +++ b/caddyhttp/httpserver/recorder_test.go @@ -1,4 +1,4 @@ -package middleware +package httpserver import ( "net/http" diff --git a/middleware/replacer.go b/caddyhttp/httpserver/replacer.go similarity index 99% rename from middleware/replacer.go rename to caddyhttp/httpserver/replacer.go index 6748f6060..e0299b8bf 100644 --- a/middleware/replacer.go +++ b/caddyhttp/httpserver/replacer.go @@ -1,4 +1,4 @@ -package middleware +package httpserver import ( "net" diff --git a/middleware/replacer_test.go b/caddyhttp/httpserver/replacer_test.go similarity index 99% rename from middleware/replacer_test.go rename to caddyhttp/httpserver/replacer_test.go index f5d50b047..466e06239 100644 --- a/middleware/replacer_test.go +++ b/caddyhttp/httpserver/replacer_test.go @@ -1,4 +1,4 @@ -package middleware +package httpserver import ( "net/http" diff --git a/caddy/setup/roller.go b/caddyhttp/httpserver/roller.go similarity index 51% rename from caddy/setup/roller.go rename to caddyhttp/httpserver/roller.go index fedc52c58..b82646361 100644 --- a/caddy/setup/roller.go +++ b/caddyhttp/httpserver/roller.go @@ -1,12 +1,36 @@ -package setup +package httpserver import ( + "io" "strconv" - "github.com/mholt/caddy/middleware" + "github.com/mholt/caddy" + + "gopkg.in/natefinch/lumberjack.v2" ) -func parseRoller(c *Controller) (*middleware.LogRoller, error) { +// LogRoller implements a type that provides a rolling logger. +type LogRoller struct { + Filename string + MaxSize int + MaxAge int + MaxBackups int + LocalTime bool +} + +// GetLogWriter returns an io.Writer that writes to a rolling logger. +func (l LogRoller) GetLogWriter() io.Writer { + return &lumberjack.Logger{ + Filename: l.Filename, + MaxSize: l.MaxSize, + MaxAge: l.MaxAge, + MaxBackups: l.MaxBackups, + LocalTime: l.LocalTime, + } +} + +// ParseRoller parses roller contents out of c. +func ParseRoller(c *caddy.Controller) (*LogRoller, error) { var size, age, keep int // This is kind of a hack to support nested blocks: // As we are already in a block: either log or errors, @@ -31,7 +55,7 @@ func parseRoller(c *Controller) (*middleware.LogRoller, error) { return nil, err } } - return &middleware.LogRoller{ + return &LogRoller{ MaxSize: size, MaxAge: age, MaxBackups: keep, diff --git a/caddyhttp/httpserver/server.go b/caddyhttp/httpserver/server.go new file mode 100644 index 000000000..e1270ec64 --- /dev/null +++ b/caddyhttp/httpserver/server.go @@ -0,0 +1,378 @@ +// Package httpserver implements an HTTP server on top of Caddy. +package httpserver + +import ( + "crypto/tls" + "fmt" + "log" + "net" + "net/http" + "os" + "path" + "runtime" + "strings" + "sync" + "time" + + "github.com/lucas-clemente/quic-go/h2quic" + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/staticfiles" + "github.com/mholt/caddy/caddytls" +) + +// Server is the HTTP server implementation. +type Server struct { + Server *http.Server + quicServer *h2quic.Server + listener net.Listener + listenerMu sync.Mutex + sites []*SiteConfig + connTimeout time.Duration // max time to wait for a connection before force stop + connWg sync.WaitGroup // one increment per connection + tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine + vhosts *vhostTrie +} + +// ensure it satisfies the interface +var _ caddy.GracefulServer = new(Server) + +// NewServer creates a new Server instance that will listen on addr +// and will serve the sites configured in group. +func NewServer(addr string, group []*SiteConfig) (*Server, error) { + s := &Server{ + Server: &http.Server{ + Addr: addr, + // TODO: Make these values configurable? + // ReadTimeout: 2 * time.Minute, + // WriteTimeout: 2 * time.Minute, + // MaxHeaderBytes: 1 << 16, + }, + vhosts: newVHostTrie(), + sites: group, + connTimeout: GracefulTimeout, + } + s.Server.Handler = s // this is weird, but whatever + + // Disable HTTP/2 if desired + if !HTTP2 { + s.Server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) + } + + // Enable QUIC if desired + if QUIC { + s.quicServer = &h2quic.Server{Server: s.Server} + } + + // We have to bound our wg with one increment + // to prevent a "race condition" that is hard-coded + // into sync.WaitGroup.Wait() - basically, an add + // with a positive delta must be guaranteed to + // occur before Wait() is called on the wg. + // In a way, this kind of acts as a safety barrier. + s.connWg.Add(1) + + // Set up TLS configuration + var tlsConfigs []*caddytls.Config + var err error + for _, site := range group { + tlsConfigs = append(tlsConfigs, site.TLS) + } + s.Server.TLSConfig, err = caddytls.MakeTLSConfig(tlsConfigs) + if err != nil { + return nil, err + } + + // Compile custom middleware for every site (enables virtual hosting) + for _, site := range group { + stack := Handler(staticfiles.FileServer{Root: http.Dir(site.Root), Hide: site.HiddenFiles}) + for i := len(site.middleware) - 1; i >= 0; i-- { + stack = site.middleware[i](stack) + } + site.middlewareChain = stack + s.vhosts.Insert(site.Addr.VHost(), site) + } + + return s, nil +} + +// Listen creates an active listener for s that can be +// used to serve requests. +func (s *Server) Listen() (net.Listener, error) { + if s.Server == nil { + return nil, fmt.Errorf("Server field is nil") + } + + ln, err := net.Listen("tcp", s.Server.Addr) + if err != nil { + var succeeded bool + if runtime.GOOS == "windows" { + // Windows has been known to keep sockets open even after closing the listeners. + // Tests reveal this error case easily because they call Start() then Stop() + // in succession. TODO: Better way to handle this? And why limit this to Windows? + for i := 0; i < 20; i++ { + time.Sleep(100 * time.Millisecond) + ln, err = net.Listen("tcp", s.Server.Addr) + if err == nil { + succeeded = true + break + } + } + } + if !succeeded { + return nil, err + } + } + + // Very important to return a concrete caddy.Listener + // implementation for graceful restarts. + return ln.(*net.TCPListener), nil +} + +// Serve serves requests on ln. It blocks until ln is closed. +func (s *Server) Serve(ln net.Listener) error { + if tcpLn, ok := ln.(*net.TCPListener); ok { + ln = tcpKeepAliveListener{TCPListener: tcpLn} + } + + ln = newGracefulListener(ln, &s.connWg) + + s.listenerMu.Lock() + s.listener = ln + s.listenerMu.Unlock() + + if s.Server.TLSConfig != nil { + // Create TLS listener - note that we do not replace s.listener + // with this TLS listener; tls.listener is unexported and does + // not implement the File() method we need for graceful restarts + // on POSIX systems. + // TODO: Is this ^ still relevant anymore? Maybe we can now that it's a net.Listener... + ln = tls.NewListener(ln, s.Server.TLSConfig) + + // Rotate TLS session ticket keys + s.tlsGovChan = caddytls.RotateSessionTicketKeys(s.Server.TLSConfig) + } + + if QUIC { + go func() { + err := s.quicServer.ListenAndServe() + if err != nil { + log.Printf("[ERROR] listening for QUIC connections: %v", err) + } + }() + } + + err := s.Server.Serve(ln) + if QUIC { + s.quicServer.Close() + } + return err +} + +// ServeHTTP is the entry point of all HTTP requests. +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer func() { + // We absolutely need to be sure we stay alive up here, + // even though, in theory, the errors middleware does this. + if rec := recover(); rec != nil { + log.Printf("[PANIC] %v", rec) + DefaultErrorFunc(w, r, http.StatusInternalServerError) + } + }() + + w.Header().Set("Server", "Caddy") + + sanitizePath(r) + + status, _ := s.serveHTTP(w, r) + + // Fallback error response in case error handling wasn't chained in + if status >= 400 { + DefaultErrorFunc(w, r, status) + } +} + +func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error) { + // strip out the port because it's not used in virtual + // hosting; the port is irrelevant because each listener + // is on a different port. + hostname, _, err := net.SplitHostPort(r.Host) + if err != nil { + hostname = r.Host + } + + // look up the virtualhost; if no match, serve error + vhost, pathPrefix := s.vhosts.Match(hostname + r.URL.Path) + + if vhost == nil { + // check for ACME challenge even if vhost is nil; + // could be a new host coming online soon + if caddytls.HTTPChallengeHandler(w, r, caddytls.DefaultHTTPAlternatePort) { + return 0, nil + } + // otherwise, log the error and write a message to the client + remoteHost, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + remoteHost = r.RemoteAddr + } + writeTextResponse(w, http.StatusNotFound, "No such site at "+s.Server.Addr) + log.Printf("[INFO] %s - No such site at %s (Remote: %s, Referer: %s)", + hostname, s.Server.Addr, remoteHost, r.Header.Get("Referer")) + return 0, nil + } + + // we still check for ACME challenge if the vhost exists, + // because we must apply its HTTP challenge config settings + if s.proxyHTTPChallenge(vhost, w, r) { + return 0, nil + } + + // trim the path portion of the site address from the beginning of + // the URL path, so a request to example.com/foo/blog on the site + // defined as example.com/foo appears as /blog instead of /foo/blog. + if pathPrefix != "/" { + r.URL.Path = strings.TrimPrefix(r.URL.Path, pathPrefix) + if !strings.HasPrefix(r.URL.Path, "/") { + r.URL.Path = "/" + r.URL.Path + } + } + + return vhost.middlewareChain.ServeHTTP(w, r) +} + +// proxyHTTPChallenge solves the ACME HTTP challenge if r is the HTTP +// request for the challenge. If it is, and if the request has been +// fulfilled (response written), true is returned; false otherwise. +// If you don't have a vhost, just call the challenge handler directly. +func (s *Server) proxyHTTPChallenge(vhost *SiteConfig, w http.ResponseWriter, r *http.Request) bool { + if vhost.Addr.Port != caddytls.HTTPChallengePort { + return false + } + if vhost.TLS != nil && vhost.TLS.Manual { + return false + } + altPort := caddytls.DefaultHTTPAlternatePort + if vhost.TLS != nil && vhost.TLS.AltHTTPPort != "" { + altPort = vhost.TLS.AltHTTPPort + } + return caddytls.HTTPChallengeHandler(w, r, altPort) +} + +// Address returns the address s was assigned to listen on. +func (s *Server) Address() string { + return s.Server.Addr +} + +// Stop stops s gracefully (or forcefully after timeout) and +// closes its listener. +func (s *Server) Stop() (err error) { + s.Server.SetKeepAlivesEnabled(false) + + if runtime.GOOS != "windows" { + // force connections to close after timeout + done := make(chan struct{}) + go func() { + s.connWg.Done() // decrement our initial increment used as a barrier + s.connWg.Wait() + close(done) + }() + + // Wait for remaining connections to finish or + // force them all to close after timeout + select { + case <-time.After(s.connTimeout): + case <-done: + } + } + + // Close the listener now; this stops the server without delay + s.listenerMu.Lock() + if s.listener != nil { + err = s.listener.Close() + } + s.listenerMu.Unlock() + + // Closing this signals any TLS governor goroutines to exit + if s.tlsGovChan != nil { + close(s.tlsGovChan) + } + + return +} + +// sanitizePath collapses any ./ ../ /// madness +// which helps prevent path traversal attacks. +// Note to middleware: use URL.RawPath If you need +// the "original" URL.Path value. +func sanitizePath(r *http.Request) { + if r.URL.Path == "/" { + return + } + cleanedPath := path.Clean(r.URL.Path) + if cleanedPath == "." { + r.URL.Path = "/" + } else { + if !strings.HasPrefix(cleanedPath, "/") { + cleanedPath = "/" + cleanedPath + } + if strings.HasSuffix(r.URL.Path, "/") && !strings.HasSuffix(cleanedPath, "/") { + cleanedPath = cleanedPath + "/" + } + r.URL.Path = cleanedPath + } +} + +// OnStartupComplete lists the sites served by this server +// and any relevant information, assuming caddy.Quiet == false. +func (s *Server) OnStartupComplete() { + if caddy.Quiet { + return + } + for _, site := range s.sites { + output := site.Addr.String() + if caddy.IsLocalhost(s.Address()) && !caddy.IsLocalhost(site.Addr.Host) { + output += " (only accessible on this machine)" + } + fmt.Println(output) + } +} + +// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted +// connections. It's used by ListenAndServe and ListenAndServeTLS so +// dead TCP connections (e.g. closing laptop mid-download) eventually +// go away. +// +// Borrowed from the Go standard library. +type tcpKeepAliveListener struct { + *net.TCPListener +} + +// Accept accepts the connection with a keep-alive enabled. +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + +// File implements caddy.Listener; it returns the underlying file of the listener. +func (ln tcpKeepAliveListener) File() (*os.File, error) { + return ln.TCPListener.File() +} + +// DefaultErrorFunc responds to an HTTP request with a simple description +// of the specified HTTP status code. +func DefaultErrorFunc(w http.ResponseWriter, r *http.Request, status int) { + writeTextResponse(w, status, fmt.Sprintf("%d %s", status, http.StatusText(status))) +} + +// writeTextResponse writes body with code status to w. The body will +// be interpreted as plain text. +func writeTextResponse(w http.ResponseWriter, status int, body string) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(status) + w.Write([]byte(body)) +} diff --git a/caddyhttp/httpserver/server_test.go b/caddyhttp/httpserver/server_test.go new file mode 100644 index 000000000..d8e53c100 --- /dev/null +++ b/caddyhttp/httpserver/server_test.go @@ -0,0 +1,15 @@ +package httpserver + +import ( + "net/http" + "testing" +) + +func TestAddress(t *testing.T) { + addr := "127.0.0.1:9005" + srv := &Server{Server: &http.Server{Addr: addr}} + + if got, want := srv.Address(), addr; got != want { + t.Errorf("Expected '%s' but got '%s'", want, got) + } +} diff --git a/caddyhttp/httpserver/siteconfig.go b/caddyhttp/httpserver/siteconfig.go new file mode 100644 index 000000000..dffe74901 --- /dev/null +++ b/caddyhttp/httpserver/siteconfig.go @@ -0,0 +1,53 @@ +package httpserver + +import "github.com/mholt/caddy/caddytls" + +// SiteConfig contains information about a site +// (also known as a virtual host). +type SiteConfig struct { + // The address of the site + Addr Address + + // The hostname to bind listener to; + // defaults to Addr.Host + ListenHost string + + // TLS configuration + TLS *caddytls.Config + + // Uncompiled middleware stack + middleware []Middleware + + // Compiled middleware stack + middlewareChain Handler + + // Directory from which to serve files + Root string + + // A list of files to hide (for example, the + // source Caddyfile). TODO: Enforcing this + // should be centralized, for example, a + // standardized way of loading files from disk + // for a request. + HiddenFiles []string +} + +// TLSConfig returns s.TLS. +func (s SiteConfig) TLSConfig() *caddytls.Config { + return s.TLS +} + +// Host returns s.Addr.Host. +func (s SiteConfig) Host() string { + return s.Addr.Host +} + +// Port returns s.Addr.Port. +func (s SiteConfig) Port() string { + return s.Addr.Port +} + +// Middleware returns s.middleware (useful for tests). +func (s SiteConfig) Middleware() []Middleware { + return s.middleware +} diff --git a/caddyhttp/httpserver/vhosttrie.go b/caddyhttp/httpserver/vhosttrie.go new file mode 100644 index 000000000..558255783 --- /dev/null +++ b/caddyhttp/httpserver/vhosttrie.go @@ -0,0 +1,139 @@ +package httpserver + +import ( + "net" + "strings" +) + +// vhostTrie facilitates virtual hosting. It matches +// requests first by hostname (with support for +// wildcards as TLS certificates support them), then +// by longest matching path. +type vhostTrie struct { + edges map[string]*vhostTrie + site *SiteConfig // also known as a virtual host + path string // the path portion of the key for this node +} + +// newVHostTrie returns a new vhostTrie. +func newVHostTrie() *vhostTrie { + return &vhostTrie{edges: make(map[string]*vhostTrie)} +} + +// Insert adds stack to t keyed by key. The key should be +// a valid "host/path" combination (or just host). +func (t *vhostTrie) Insert(key string, site *SiteConfig) { + host, path := t.splitHostPath(key) + if _, ok := t.edges[host]; !ok { + t.edges[host] = newVHostTrie() + } + t.edges[host].insertPath(path, path, site) +} + +// insertPath expects t to be a host node (not a root node), +// and inserts site into the t according to remainingPath. +func (t *vhostTrie) insertPath(remainingPath, originalPath string, site *SiteConfig) { + if remainingPath == "" { + t.site = site + t.path = originalPath + return + } + ch := string(remainingPath[0]) + if _, ok := t.edges[ch]; !ok { + t.edges[ch] = newVHostTrie() + } + t.edges[ch].insertPath(remainingPath[1:], originalPath, site) +} + +// Match returns the virtual host (site) in v with +// the closest match to key. If there was a match, +// it returns the SiteConfig and the path portion of +// the key used to make the match. The matched path +// would be a prefix of the path portion of the +// key, if not the whole path portion of the key. +// If there is no match, nil and empty string will +// be returned. +// +// A typical key will be in the form "host" or "host/path". +func (t *vhostTrie) Match(key string) (*SiteConfig, string) { + host, path := t.splitHostPath(key) + // try the given host, then, if no match, try wildcard hosts + branch := t.matchHost(host) + if branch == nil { + branch = t.matchHost("0.0.0.0") + } + if branch == nil { + branch = t.matchHost("") + } + if branch == nil { + return nil, "" + } + node := branch.matchPath(path) + if node == nil { + return nil, "" + } + return node.site, node.path +} + +// matchHost returns the vhostTrie matching host. The matching +// algorithm is the same as used to match certificates to host +// with SNI during TLS handshakes. In other words, it supports, +// to some degree, the use of wildcard (*) characters. +func (t *vhostTrie) matchHost(host string) *vhostTrie { + // try exact match + if subtree, ok := t.edges[host]; ok { + return subtree + } + + // then try replacing labels in the host + // with wildcards until we get a match + labels := strings.Split(host, ".") + for i := range labels { + labels[i] = "*" + candidate := strings.Join(labels, ".") + if subtree, ok := t.edges[candidate]; ok { + return subtree + } + } + + return nil +} + +// matchPath traverses t until it finds the longest key matching +// remainingPath, and returns its node. +func (t *vhostTrie) matchPath(remainingPath string) *vhostTrie { + var longestMatch *vhostTrie + for len(remainingPath) > 0 { + ch := string(remainingPath[0]) + next, ok := t.edges[ch] + if !ok { + break + } + if next.site != nil { + longestMatch = next + } + t = next + remainingPath = remainingPath[1:] + } + return longestMatch +} + +// splitHostPath separates host from path in key. +func (t *vhostTrie) splitHostPath(key string) (host, path string) { + parts := strings.SplitN(key, "/", 2) + host, path = strings.ToLower(parts[0]), "/" + if len(parts) > 1 { + path += parts[1] + } + // strip out the port (if present) from the host, since + // each port has its own socket, and each socket has its + // own listener, and each listener has its own server + // instance, and each server instance has its own vhosts. + // removing the port is a simple way to standardize so + // when requests come in, we can be sure to get a match. + hostname, _, err := net.SplitHostPort(host) + if err == nil { + host = hostname + } + return +} diff --git a/caddyhttp/httpserver/vhosttrie_test.go b/caddyhttp/httpserver/vhosttrie_test.go new file mode 100644 index 000000000..95ef1fba5 --- /dev/null +++ b/caddyhttp/httpserver/vhosttrie_test.go @@ -0,0 +1,141 @@ +package httpserver + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestVHostTrie(t *testing.T) { + trie := newVHostTrie() + populateTestTrie(trie, []string{ + "example", + "example.com", + "*.example.com", + "example.com/foo", + "example.com/foo/bar", + "*.example.com/test", + }) + assertTestTrie(t, trie, []vhostTrieTest{ + {"not-in-trie.com", false, "", "/"}, + {"example", true, "example", "/"}, + {"example.com", true, "example.com", "/"}, + {"example.com/test", true, "example.com", "/"}, + {"example.com/foo", true, "example.com/foo", "/foo"}, + {"example.com/foo/", true, "example.com/foo", "/foo"}, + {"EXAMPLE.COM/foo", true, "example.com/foo", "/foo"}, + {"EXAMPLE.COM/Foo", true, "example.com", "/"}, + {"example.com/foo/bar", true, "example.com/foo/bar", "/foo/bar"}, + {"example.com/foo/bar/baz", true, "example.com/foo/bar", "/foo/bar"}, + {"example.com/foo/other", true, "example.com/foo", "/foo"}, + {"foo.example.com", true, "*.example.com", "/"}, + {"foo.example.com/else", true, "*.example.com", "/"}, + }, false) +} + +func TestVHostTrieWildcard1(t *testing.T) { + trie := newVHostTrie() + populateTestTrie(trie, []string{ + "example.com", + "", + }) + assertTestTrie(t, trie, []vhostTrieTest{ + {"not-in-trie.com", true, "", "/"}, + {"example.com", true, "example.com", "/"}, + {"example.com/foo", true, "example.com", "/"}, + {"not-in-trie.com/asdf", true, "", "/"}, + }, true) +} + +func TestVHostTrieWildcard2(t *testing.T) { + trie := newVHostTrie() + populateTestTrie(trie, []string{ + "0.0.0.0/asdf", + }) + assertTestTrie(t, trie, []vhostTrieTest{ + {"example.com/asdf/foo", true, "0.0.0.0/asdf", "/asdf"}, + {"example.com/foo", false, "", "/"}, + {"host/asdf", true, "0.0.0.0/asdf", "/asdf"}, + }, true) +} + +func TestVHostTrieWildcard3(t *testing.T) { + trie := newVHostTrie() + populateTestTrie(trie, []string{ + "*/foo", + }) + assertTestTrie(t, trie, []vhostTrieTest{ + {"example.com/foo", true, "*/foo", "/foo"}, + {"example.com", false, "", "/"}, + }, true) +} + +func TestVHostTriePort(t *testing.T) { + // Make sure port is stripped out + trie := newVHostTrie() + populateTestTrie(trie, []string{ + "example.com:1234", + }) + assertTestTrie(t, trie, []vhostTrieTest{ + {"example.com/foo", true, "example.com:1234", "/"}, + }, true) +} + +func populateTestTrie(trie *vhostTrie, keys []string) { + for _, key := range keys { + // we wrap this in a func, passing in the key, otherwise the + // handler always writes the last key to the response, even + // if the handler is actually from one of the earlier keys. + func(key string) { + site := &SiteConfig{ + middlewareChain: HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + w.Write([]byte(key)) + return 0, nil + }), + } + trie.Insert(key, site) + }(key) + } +} + +type vhostTrieTest struct { + query string + expectMatch bool + expectedKey string + matchedPrefix string // the path portion of a key that is expected to be matched +} + +func assertTestTrie(t *testing.T, trie *vhostTrie, tests []vhostTrieTest, hasWildcardHosts bool) { + for i, test := range tests { + site, pathPrefix := trie.Match(test.query) + + if !test.expectMatch { + if site != nil { + // If not expecting a value, then just make sure we didn't get one + t.Errorf("Test %d: Expected no matches, but got %v", i, site) + } + continue + } + + // Otherwise, we must assert we got a value + if site == nil { + t.Errorf("Test %d: Expected non-nil return value, but got: %v", i, site) + continue + } + + // And it must be the correct value + resp := httptest.NewRecorder() + site.middlewareChain.ServeHTTP(resp, nil) + actualHandlerKey := resp.Body.String() + if actualHandlerKey != test.expectedKey { + t.Errorf("Test %d: Expected match '%s' but matched '%s'", + i, test.expectedKey, actualHandlerKey) + } + + // The path prefix must also be correct + if test.matchedPrefix != pathPrefix { + t.Errorf("Test %d: Expected matched path prefix to be '%s', got '%s'", + i, test.matchedPrefix, pathPrefix) + } + } +} diff --git a/middleware/log/log.go b/caddyhttp/log/log.go similarity index 81% rename from middleware/log/log.go rename to caddyhttp/log/log.go index 4e0c2e29b..1f0b5f0bc 100644 --- a/middleware/log/log.go +++ b/caddyhttp/log/log.go @@ -6,25 +6,34 @@ import ( "log" "net/http" - "github.com/mholt/caddy/middleware" + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" ) +func init() { + caddy.RegisterPlugin(caddy.Plugin{ + Name: "log", + ServerType: "http", + Action: setup, + }) +} + // Logger is a basic request logging middleware. type Logger struct { - Next middleware.Handler + Next httpserver.Handler Rules []Rule ErrorFunc func(http.ResponseWriter, *http.Request, int) // failover error handler } func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { for _, rule := range l.Rules { - if middleware.Path(r.URL.Path).Matches(rule.PathScope) { + if httpserver.Path(r.URL.Path).Matches(rule.PathScope) { // Record the response - responseRecorder := middleware.NewResponseRecorder(w) + responseRecorder := httpserver.NewResponseRecorder(w) // Attach the Replacer we'll use so that other middlewares can // set their own placeholders if they want to. - rep := middleware.NewReplacer(r, responseRecorder, CommonLogEmptyValue) + rep := httpserver.NewReplacer(r, responseRecorder, CommonLogEmptyValue) responseRecorder.Replacer = rep // Bon voyage, request! @@ -58,7 +67,7 @@ type Rule struct { OutputFile string Format string Log *log.Logger - Roller *middleware.LogRoller + Roller *httpserver.LogRoller } const ( diff --git a/middleware/log/log_test.go b/caddyhttp/log/log_test.go similarity index 92% rename from middleware/log/log_test.go rename to caddyhttp/log/log_test.go index 0ce12b0ca..af48f4424 100644 --- a/middleware/log/log_test.go +++ b/caddyhttp/log/log_test.go @@ -8,13 +8,13 @@ import ( "strings" "testing" - "github.com/mholt/caddy/middleware" + "github.com/mholt/caddy/caddyhttp/httpserver" ) type erroringMiddleware struct{} func (erroringMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - if rr, ok := w.(*middleware.ResponseRecorder); ok { + if rr, ok := w.(*httpserver.ResponseRecorder); ok { rr.Replacer.Set("testval", "foobar") } return http.StatusNotFound, nil diff --git a/caddy/setup/log.go b/caddyhttp/log/setup.go similarity index 66% rename from caddy/setup/log.go rename to caddyhttp/log/setup.go index 8bb4788a1..9aa3d9a49 100644 --- a/caddy/setup/log.go +++ b/caddyhttp/log/setup.go @@ -1,4 +1,4 @@ -package setup +package log import ( "io" @@ -6,20 +6,19 @@ import ( "os" "github.com/hashicorp/go-syslog" - "github.com/mholt/caddy/middleware" - caddylog "github.com/mholt/caddy/middleware/log" - "github.com/mholt/caddy/server" + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" ) -// Log sets up the logging middleware. -func Log(c *Controller) (middleware.Middleware, error) { +// setup sets up the logging middleware. +func setup(c *caddy.Controller) error { rules, err := logParse(c) if err != nil { - return nil, err + return err } // Open the log files for writing when the server starts - c.Startup = append(c.Startup, func() error { + c.OnStartup(func() error { for i := 0; i < len(rules); i++ { var err error var writer io.Writer @@ -54,24 +53,26 @@ func Log(c *Controller) (middleware.Middleware, error) { return nil }) - return func(next middleware.Handler) middleware.Handler { - return caddylog.Logger{Next: next, Rules: rules, ErrorFunc: server.DefaultErrorFunc} - }, nil + httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { + return Logger{Next: next, Rules: rules, ErrorFunc: httpserver.DefaultErrorFunc} + }) + + return nil } -func logParse(c *Controller) ([]caddylog.Rule, error) { - var rules []caddylog.Rule +func logParse(c *caddy.Controller) ([]Rule, error) { + var rules []Rule for c.Next() { args := c.RemainingArgs() - var logRoller *middleware.LogRoller + var logRoller *httpserver.LogRoller if c.NextBlock() { if c.Val() == "rotate" { if c.NextArg() { if c.Val() == "{" { var err error - logRoller, err = parseRoller(c) + logRoller, err = httpserver.ParseRoller(c) if err != nil { return nil, err } @@ -87,37 +88,37 @@ func logParse(c *Controller) ([]caddylog.Rule, error) { } if len(args) == 0 { // Nothing specified; use defaults - rules = append(rules, caddylog.Rule{ + rules = append(rules, Rule{ PathScope: "/", - OutputFile: caddylog.DefaultLogFilename, - Format: caddylog.DefaultLogFormat, + OutputFile: DefaultLogFilename, + Format: DefaultLogFormat, Roller: logRoller, }) } else if len(args) == 1 { // Only an output file specified - rules = append(rules, caddylog.Rule{ + rules = append(rules, Rule{ PathScope: "/", OutputFile: args[0], - Format: caddylog.DefaultLogFormat, + Format: DefaultLogFormat, Roller: logRoller, }) } else { // Path scope, output file, and maybe a format specified - format := caddylog.DefaultLogFormat + format := DefaultLogFormat if len(args) > 2 { switch args[2] { case "{common}": - format = caddylog.CommonLogFormat + format = CommonLogFormat case "{combined}": - format = caddylog.CombinedLogFormat + format = CombinedLogFormat default: format = args[2] } } - rules = append(rules, caddylog.Rule{ + rules = append(rules, Rule{ PathScope: args[0], OutputFile: args[1], Format: format, diff --git a/caddy/setup/log_test.go b/caddyhttp/log/setup_test.go similarity index 73% rename from caddy/setup/log_test.go rename to caddyhttp/log/setup_test.go index ae7a96e31..436002ac3 100644 --- a/caddy/setup/log_test.go +++ b/caddyhttp/log/setup_test.go @@ -1,28 +1,27 @@ -package setup +package log import ( "testing" - "github.com/mholt/caddy/middleware" - caddylog "github.com/mholt/caddy/middleware/log" + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" ) -func TestLog(t *testing.T) { - - c := NewTestController(`log`) - - mid, err := Log(c) +func TestSetup(t *testing.T) { + cfg := httpserver.GetConfig("") + err := setup(caddy.NewTestController(`log`)) if err != nil { t.Errorf("Expected no errors, got: %v", err) } - if mid == nil { + mids := cfg.Middleware() + if mids == nil { t.Fatal("Expected middleware, was nil instead") } - handler := mid(EmptyNext) - myHandler, ok := handler.(caddylog.Logger) + handler := mids[0](httpserver.EmptyNext) + myHandler, ok := handler.(Logger) if !ok { t.Fatalf("Expected handler to be type Logger, got: %#v", handler) @@ -31,16 +30,16 @@ func TestLog(t *testing.T) { if myHandler.Rules[0].PathScope != "/" { t.Errorf("Expected / as the default PathScope") } - if myHandler.Rules[0].OutputFile != caddylog.DefaultLogFilename { - t.Errorf("Expected %s as the default OutputFile", caddylog.DefaultLogFilename) + if myHandler.Rules[0].OutputFile != DefaultLogFilename { + t.Errorf("Expected %s as the default OutputFile", DefaultLogFilename) } - if myHandler.Rules[0].Format != caddylog.DefaultLogFormat { - t.Errorf("Expected %s as the default Log Format", caddylog.DefaultLogFormat) + if myHandler.Rules[0].Format != DefaultLogFormat { + t.Errorf("Expected %s as the default Log Format", DefaultLogFormat) } if myHandler.Rules[0].Roller != nil { t.Errorf("Expected Roller to be nil, got: %v", *myHandler.Rules[0].Roller) } - if !SameNext(myHandler.Next, EmptyNext) { + if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) { t.Error("'Next' field of handler was not set properly") } @@ -50,50 +49,50 @@ func TestLogParse(t *testing.T) { tests := []struct { inputLogRules string shouldErr bool - expectedLogRules []caddylog.Rule + expectedLogRules []Rule }{ - {`log`, false, []caddylog.Rule{{ + {`log`, false, []Rule{{ PathScope: "/", - OutputFile: caddylog.DefaultLogFilename, - Format: caddylog.DefaultLogFormat, + OutputFile: DefaultLogFilename, + Format: DefaultLogFormat, }}}, - {`log log.txt`, false, []caddylog.Rule{{ + {`log log.txt`, false, []Rule{{ PathScope: "/", OutputFile: "log.txt", - Format: caddylog.DefaultLogFormat, + Format: DefaultLogFormat, }}}, - {`log /api log.txt`, false, []caddylog.Rule{{ + {`log /api log.txt`, false, []Rule{{ PathScope: "/api", OutputFile: "log.txt", - Format: caddylog.DefaultLogFormat, + Format: DefaultLogFormat, }}}, - {`log /serve stdout`, false, []caddylog.Rule{{ + {`log /serve stdout`, false, []Rule{{ PathScope: "/serve", OutputFile: "stdout", - Format: caddylog.DefaultLogFormat, + Format: DefaultLogFormat, }}}, - {`log /myapi log.txt {common}`, false, []caddylog.Rule{{ + {`log /myapi log.txt {common}`, false, []Rule{{ PathScope: "/myapi", OutputFile: "log.txt", - Format: caddylog.CommonLogFormat, + Format: CommonLogFormat, }}}, - {`log /test accesslog.txt {combined}`, false, []caddylog.Rule{{ + {`log /test accesslog.txt {combined}`, false, []Rule{{ PathScope: "/test", OutputFile: "accesslog.txt", - Format: caddylog.CombinedLogFormat, + Format: CombinedLogFormat, }}}, {`log /api1 log.txt - log /api2 accesslog.txt {combined}`, false, []caddylog.Rule{{ + log /api2 accesslog.txt {combined}`, false, []Rule{{ PathScope: "/api1", OutputFile: "log.txt", - Format: caddylog.DefaultLogFormat, + Format: DefaultLogFormat, }, { PathScope: "/api2", OutputFile: "accesslog.txt", - Format: caddylog.CombinedLogFormat, + Format: CombinedLogFormat, }}}, {`log /api3 stdout {host} - log /api4 log.txt {when}`, false, []caddylog.Rule{{ + log /api4 log.txt {when}`, false, []Rule{{ PathScope: "/api3", OutputFile: "stdout", Format: "{host}", @@ -102,11 +101,11 @@ func TestLogParse(t *testing.T) { OutputFile: "log.txt", Format: "{when}", }}}, - {`log access.log { rotate { size 2 age 10 keep 3 } }`, false, []caddylog.Rule{{ + {`log access.log { rotate { size 2 age 10 keep 3 } }`, false, []Rule{{ PathScope: "/", OutputFile: "access.log", - Format: caddylog.DefaultLogFormat, - Roller: &middleware.LogRoller{ + Format: DefaultLogFormat, + Roller: &httpserver.LogRoller{ MaxSize: 2, MaxAge: 10, MaxBackups: 3, @@ -115,7 +114,7 @@ func TestLogParse(t *testing.T) { }}}, } for i, test := range tests { - c := NewTestController(test.inputLogRules) + c := caddy.NewTestController(test.inputLogRules) actualLogRules, err := logParse(c) if err == nil && test.shouldErr { diff --git a/caddyhttp/root/root.go b/caddyhttp/root/root.go new file mode 100644 index 000000000..b4e485d1f --- /dev/null +++ b/caddyhttp/root/root.go @@ -0,0 +1,42 @@ +package root + +import ( + "log" + "os" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func init() { + caddy.RegisterPlugin(caddy.Plugin{ + Name: "root", + ServerType: "http", + Action: setupRoot, + }) +} + +func setupRoot(c *caddy.Controller) error { + config := httpserver.GetConfig(c.Key) + + for c.Next() { + if !c.NextArg() { + return c.ArgErr() + } + config.Root = c.Val() + } + + // Check if root path exists + _, err := os.Stat(config.Root) + if err != nil { + if os.IsNotExist(err) { + // Allow this, because the folder might appear later. + // But make sure the user knows! + log.Printf("[WARNING] Root path does not exist: %s", config.Root) + } else { + return c.Errf("Unable to access root path '%s': %v", config.Root, err) + } + } + + return nil +} diff --git a/caddy/setup/root_test.go b/caddyhttp/root/root_test.go similarity index 83% rename from caddy/setup/root_test.go rename to caddyhttp/root/root_test.go index 8b38e6d04..20b2c7a9b 100644 --- a/caddy/setup/root_test.go +++ b/caddyhttp/root/root_test.go @@ -1,4 +1,4 @@ -package setup +package root import ( "fmt" @@ -7,9 +7,13 @@ import ( "path/filepath" "strings" "testing" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" ) func TestRoot(t *testing.T) { + cfg := httpserver.GetConfig("") // Predefined error substrings parseErrContent := "Parse error:" @@ -61,8 +65,8 @@ func TestRoot(t *testing.T) { } for i, test := range tests { - c := NewTestController(test.input) - mid, err := Root(c) + c := caddy.NewTestController(test.input) + err := setupRoot(c) if test.shouldErr && err == nil { t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input) @@ -78,14 +82,9 @@ func TestRoot(t *testing.T) { } } - // the Root method always returns a nil middleware - if mid != nil { - t.Errorf("Middware, returned from Root() was not nil: %v", mid) - } - - // check c.Root only if we are in a positive test. - if !test.shouldErr && test.expectedRoot != c.Root { - t.Errorf("Root not correctly set for input %s. Expected: %s, actual: %s", test.input, test.expectedRoot, c.Root) + // check root only if we are in a positive test. + if !test.shouldErr && test.expectedRoot != cfg.Root { + t.Errorf("Root not correctly set for input %s. Expected: %s, actual: %s", test.input, test.expectedRoot, cfg.Root) } } } @@ -93,16 +92,13 @@ func TestRoot(t *testing.T) { // getTempDirPath returnes the path to the system temp directory. If it does not exists - an error is returned. func getTempDirPath() (string, error) { tempDir := os.TempDir() - _, err := os.Stat(tempDir) if err != nil { return "", err } - return tempDir, nil } func getInaccessiblePath(file string) string { - // null byte in filename is not allowed on Windows AND unix - return filepath.Join("C:", "file\x00name") + return filepath.Join("C:", "file\x00name") // null byte in filename is not allowed on Windows AND unix } diff --git a/middleware/fileserver.go b/caddyhttp/staticfiles/fileserver.go similarity index 68% rename from middleware/fileserver.go rename to caddyhttp/staticfiles/fileserver.go index b1c3d66d5..a2b874f8b 100644 --- a/middleware/fileserver.go +++ b/caddyhttp/staticfiles/fileserver.go @@ -1,4 +1,4 @@ -package middleware +package staticfiles import ( "fmt" @@ -7,59 +7,51 @@ import ( "os" "path" "path/filepath" + "runtime" "strconv" "strings" ) -// This file contains a standard way for Caddy middleware -// to load files from the file system given a request -// URI and path to site root. Other middleware that load -// files should use these facilities. - // FileServer implements a production-ready file server // and is the 'default' handler for all requests to Caddy. -// It simply loads and serves the URI requested. If Caddy is -// run without any extra configuration/directives, this is the -// only middleware handler that runs. It is not in its own -// folder like most other middleware handlers because it does -// not require a directive. It is a special case. -// -// FileServer is adapted from the one in net/http by -// the Go authors. Significant modifications have been made. +// It simply loads and serves the URI requested. FileServer +// is adapted from the one in net/http by the Go authors. +// Significant modifications have been made. // // Original license: // // Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -func FileServer(root http.FileSystem, hide []string) Handler { - return &fileHandler{root: root, hide: hide} +type FileServer struct { + // Jailed disk access + Root http.FileSystem + + // List of files to treat as "Not Found" + Hide []string } -type fileHandler struct { - root http.FileSystem - hide []string // list of files to treat as "Not Found" -} - -func (fh *fileHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - // r.URL.Path has already been cleaned in caddy/server by path.Clean(). +// ServeHTTP serves static files for r according to fs's configuration. +func (fs FileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { + // r.URL.Path has already been cleaned by Caddy. if r.URL.Path == "" { r.URL.Path = "/" } - return fh.serveFile(w, r, r.URL.Path) + return fs.serveFile(w, r, r.URL.Path) } // serveFile writes the specified file to the HTTP response. // name is '/'-separated, not filepath.Separator. -func (fh *fileHandler) serveFile(w http.ResponseWriter, r *http.Request, name string) (int, error) { +func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request, name string) (int, error) { // Prevent absolute path access on Windows. // TODO remove when stdlib http.Dir fixes this. - if runtimeGoos == "windows" { + if runtime.GOOS == "windows" { if filepath.IsAbs(name[1:]) { return http.StatusNotFound, nil } } - f, err := fh.root.Open(name) + + f, err := fs.Root.Open(name) if err != nil { if os.IsNotExist(err) { return http.StatusNotFound, nil @@ -104,7 +96,7 @@ func (fh *fileHandler) serveFile(w http.ResponseWriter, r *http.Request, name st if d.IsDir() { for _, indexPage := range IndexPages { index := strings.TrimSuffix(name, "/") + "/" + indexPage - ff, err := fh.root.Open(index) + ff, err := fs.Root.Open(index) if err == nil { // this defer does not leak fds because previous iterations // of the loop must have had an err, so nothing to close @@ -126,12 +118,11 @@ func (fh *fileHandler) serveFile(w http.ResponseWriter, r *http.Request, name st return http.StatusNotFound, nil } - // If file is on hide list. - if fh.isHidden(d) { + if fs.isHidden(d) { return http.StatusNotFound, nil } - // Add ETag header + // Experimental ETag header e := fmt.Sprintf(`W/"%x-%x"`, d.ModTime().Unix(), d.Size()) w.Header().Set("ETag", e) @@ -143,12 +134,11 @@ func (fh *fileHandler) serveFile(w http.ResponseWriter, r *http.Request, name st } // isHidden checks if file with FileInfo d is on hide list. -func (fh fileHandler) isHidden(d os.FileInfo) bool { +func (fs FileServer) isHidden(d os.FileInfo) bool { // If the file is supposed to be hidden, return a 404 - // (TODO: If the slice gets large, a set may be faster) - for _, hiddenPath := range fh.hide { + for _, hiddenPath := range fs.Hide { // Check if the served file is exactly the hidden file. - if hFile, err := fh.root.Open(hiddenPath); err == nil { + if hFile, err := fs.Root.Open(hiddenPath); err == nil { fs, _ := hFile.Stat() hFile.Close() if os.SameFile(d, fs) { @@ -160,8 +150,8 @@ func (fh fileHandler) isHidden(d os.FileInfo) bool { } // redirect is taken from http.localRedirect of the std lib. It -// sends an HTTP redirect to the client but will preserve the -// query string for the new path. +// sends an HTTP permanent redirect to the client but will +// preserve the query string for the new path. func redirect(w http.ResponseWriter, r *http.Request, newPath string) { if q := r.URL.RawQuery; q != "" { newPath += "?" + q diff --git a/middleware/fileserver_test.go b/caddyhttp/staticfiles/fileserver_test.go similarity index 90% rename from middleware/fileserver_test.go rename to caddyhttp/staticfiles/fileserver_test.go index 40d369a8b..6c77cec1a 100644 --- a/middleware/fileserver_test.go +++ b/caddyhttp/staticfiles/fileserver_test.go @@ -1,4 +1,4 @@ -package middleware +package staticfiles import ( "errors" @@ -44,7 +44,10 @@ func TestServeHTTP(t *testing.T) { beforeServeHTTPTest(t) defer afterServeHTTPTest(t) - fileserver := FileServer(http.Dir(testWebRoot), []string{"dir/hidden.html"}) + fileserver := FileServer{ + Root: http.Dir(testWebRoot), + Hide: []string{"dir/hidden.html"}, + } movedPermanently := "Moved Permanently" @@ -169,22 +172,22 @@ func TestServeHTTP(t *testing.T) { // check if error matches expectations if err != nil { - t.Errorf(getTestPrefix(i)+"Serving file at %s failed. Error was: %v", test.url, err) + t.Errorf("Test %d: Serving file at %s failed. Error was: %v", i, test.url, err) } // check status code if test.expectedStatus != status { - t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status) + t.Errorf("Test %d: Expected status %d, found %d", i, test.expectedStatus, status) } // check etag if test.expectedEtag != etag { - t.Errorf(getTestPrefix(i)+"Expected Etag header %d, found %d", test.expectedEtag, etag) + t.Errorf("Test %d: Expected Etag header %s, found %s", i, test.expectedEtag, etag) } // check body content if !strings.Contains(responseRecorder.Body.String(), test.expectedBodyContent) { - t.Errorf(getTestPrefix(i)+"Expected body to contain %q, found %q", test.expectedBodyContent, responseRecorder.Body.String()) + t.Errorf("Test %d: Expected body to contain %q, found %q", i, test.expectedBodyContent, responseRecorder.Body.String()) } } @@ -302,7 +305,7 @@ func TestServeHTTPFailingFS(t *testing.T) { for i, test := range tests { // initialize a file server with the failing FileSystem - fileserver := FileServer(failingFS{err: test.fsErr}, nil) + fileserver := FileServer{Root: failingFS{err: test.fsErr}} // prepare the request and response request, err := http.NewRequest("GET", "https://foo/", nil) @@ -315,12 +318,12 @@ func TestServeHTTPFailingFS(t *testing.T) { // check the status if status != test.expectedStatus { - t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status) + t.Errorf("Test %d: Expected status %d, found %d", i, test.expectedStatus, status) } // check the error if actualErr != test.expectedErr { - t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr) + t.Errorf("Test %d: Expected err %v, found %v", i, test.expectedErr, actualErr) } // check the headers - a special case for server under load @@ -328,7 +331,7 @@ func TestServeHTTPFailingFS(t *testing.T) { for expectedKey, expectedVal := range test.expectedHeaders { actualVal := responseRecorder.Header().Get(expectedKey) if expectedVal != actualVal { - t.Errorf(getTestPrefix(i)+"Expected header %s: %s, found %s", expectedKey, expectedVal, actualVal) + t.Errorf("Test %d: Expected header %s: %s, found %s", i, expectedKey, expectedVal, actualVal) } } } @@ -362,7 +365,7 @@ func TestServeHTTPFailingStat(t *testing.T) { for i, test := range tests { // initialize a file server. The FileSystem will not fail, but calls to the Stat method of the returned File object will - fileserver := FileServer(failingFS{err: nil, fileImpl: failingFile{err: test.statErr}}, nil) + fileserver := FileServer{Root: failingFS{err: nil, fileImpl: failingFile{err: test.statErr}}} // prepare the request and response request, err := http.NewRequest("GET", "https://foo/", nil) @@ -375,12 +378,12 @@ func TestServeHTTPFailingStat(t *testing.T) { // check the status if status != test.expectedStatus { - t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status) + t.Errorf("Test %d: Expected status %d, found %d", i, test.expectedStatus, status) } // check the error if actualErr != test.expectedErr { - t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr) + t.Errorf("Test %d: Expected err %v, found %v", i, test.expectedErr, actualErr) } } } diff --git a/caddy/https/certificates.go b/caddytls/certificates.go similarity index 87% rename from caddy/https/certificates.go rename to caddytls/certificates.go index 0dc3db523..b91180ba5 100644 --- a/caddy/https/certificates.go +++ b/caddytls/certificates.go @@ -1,4 +1,4 @@ -package https +package caddytls import ( "crypto/tls" @@ -33,21 +33,12 @@ type Certificate struct { // NotAfter is when the certificate expires. NotAfter time.Time - // Managed certificates are certificates that Caddy is managing, - // as opposed to the user specifying a certificate and key file - // or directory and managing the certificate resources themselves. - Managed bool - - // OnDemand certificates are obtained or loaded on-demand during TLS - // handshakes (as opposed to preloaded certificates, which are loaded - // at startup). If OnDemand is true, Managed must necessarily be true. - // OnDemand certificates are maintained in the background just like - // preloaded ones, however, if an OnDemand certificate fails to renew, - // it is removed from the in-memory cache. - OnDemand bool - // OCSP contains the certificate's parsed OCSP response. OCSP *ocsp.Response + + // Config is the configuration with which the certificate was + // loaded or obtained and with which it should be maintained. + Config *Config } // getCertificate gets a certificate that matches name (a server name) @@ -95,18 +86,21 @@ func getCertificate(name string) (cert Certificate, matched, defaulted bool) { return } -// cacheManagedCertificate loads the certificate for domain into the -// cache, flagging it as Managed and, if onDemand is true, as OnDemand +// CacheManagedCertificate loads the certificate for domain into the +// cache, flagging it as Managed and, if onDemand is true, as "OnDemand" // (meaning that it was obtained or loaded during a TLS handshake). // // This function is safe for concurrent use. -func cacheManagedCertificate(domain string, onDemand bool) (Certificate, error) { +func CacheManagedCertificate(domain string, cfg *Config) (Certificate, error) { + storage, err := StorageFor(cfg.CAUrl) + if err != nil { + return Certificate{}, err + } cert, err := makeCertificateFromDisk(storage.SiteCertFile(domain), storage.SiteKeyFile(domain)) if err != nil { return cert, err } - cert.Managed = true - cert.OnDemand = onDemand + cert.Config = cfg cacheCertificate(cert) return cert, nil } @@ -213,7 +207,7 @@ func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { func cacheCertificate(cert Certificate) { certCacheMu.Lock() if _, ok := certCache[""]; !ok { - // use as default + // use as default - must be *appended* to list, or bad things happen! cert.Names = append(cert.Names, "") certCache[""] = cert } @@ -232,3 +226,12 @@ func cacheCertificate(cert Certificate) { } certCacheMu.Unlock() } + +// uncacheCertificate deletes name's certificate from the +// cache. If name is not a key in the certificate cache, +// this function does nothing. +func uncacheCertificate(name string) { + certCacheMu.Lock() + delete(certCache, name) + certCacheMu.Unlock() +} diff --git a/caddy/https/certificates_test.go b/caddytls/certificates_test.go similarity index 99% rename from caddy/https/certificates_test.go rename to caddytls/certificates_test.go index dbfb4efc1..02f46cf1e 100644 --- a/caddy/https/certificates_test.go +++ b/caddytls/certificates_test.go @@ -1,4 +1,4 @@ -package https +package caddytls import "testing" diff --git a/caddy/https/client.go b/caddytls/client.go similarity index 55% rename from caddy/https/client.go rename to caddytls/client.go index 762e58aa1..c0e093322 100644 --- a/caddy/https/client.go +++ b/caddytls/client.go @@ -1,15 +1,19 @@ -package https +package caddytls import ( "encoding/json" "errors" "fmt" "io/ioutil" + "log" "net" + "net/url" + "os" + "strings" "sync" "time" - "github.com/mholt/caddy/server" + "github.com/mholt/caddy" "github.com/xenolf/lego/acme" ) @@ -19,22 +23,51 @@ var acmeMu sync.Mutex // ACMEClient is an acme.Client with custom state attached. type ACMEClient struct { *acme.Client - AllowPrompts bool // if false, we assume AlternatePort must be used + AllowPrompts bool + config *Config } -// NewACMEClient creates a new ACMEClient given an email and whether -// prompting the user is allowed. Clients should not be kept and -// re-used over long periods of time, but immediate re-use is more -// efficient than re-creating on every iteration. -var NewACMEClient = func(email string, allowPrompts bool) (*ACMEClient, error) { - // Look up or create the LE user account - leUser, err := getUser(email) +// newACMEClient creates a new ACMEClient given an email and whether +// prompting the user is allowed. It's a variable so we can mock in tests. +var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) { + storage, err := StorageFor(config.CAUrl) if err != nil { return nil, err } + // Look up or create the LE user account + leUser, err := getUser(storage, config.ACMEEmail) + if err != nil { + return nil, err + } + + // ensure key type is set + keyType := DefaultKeyType + if config.KeyType != "" { + keyType = config.KeyType + } + + // ensure CA URL (directory endpoint) is set + caURL := DefaultCAUrl + if config.CAUrl != "" { + caURL = config.CAUrl + } + + // ensure endpoint is secure (assume HTTPS if scheme is missing) + if !strings.Contains(caURL, "://") { + caURL = "https://" + caURL + } + u, err := url.Parse(caURL) + if u.Scheme != "https" && + u.Host != "localhost" && + u.Host != "[::1]" && + !strings.HasPrefix(u.Host, "127.") && + !strings.HasPrefix(u.Host, "10.") { + return nil, fmt.Errorf("%s: insecure CA URL (HTTPS required)", caURL) + } + // The client facilitates our communication with the CA server. - client, err := acme.NewClient(CAUrl, &leUser, KeyType) + client, err := acme.NewClient(caURL, &leUser, keyType) if err != nil { return nil, err } @@ -59,52 +92,57 @@ var NewACMEClient = func(email string, allowPrompts bool) (*ACMEClient, error) { err = client.AgreeToTOS() if err != nil { - saveUser(leUser) // Might as well try, right? + saveUser(storage, leUser) // Might as well try, right? return nil, errors.New("error agreeing to terms: " + err.Error()) } // save user to the file system - err = saveUser(leUser) + err = saveUser(storage, leUser) if err != nil { return nil, errors.New("could not save user: " + err.Error()) } } - return &ACMEClient{ - Client: client, - AllowPrompts: allowPrompts, - }, nil -} + c := &ACMEClient{Client: client, AllowPrompts: allowPrompts, config: config} -// NewACMEClientGetEmail creates a new ACMEClient and gets an email -// address at the same time (a server config is required, since it -// may contain an email address in it). -func NewACMEClientGetEmail(config server.Config, allowPrompts bool) (*ACMEClient, error) { - return NewACMEClient(getEmail(config, allowPrompts), allowPrompts) -} + if config.DNSProvider == "" { + // Use HTTP and TLS-SNI challenges by default -// Configure configures c according to bindHost, which is the host (not -// whole address) to bind the listener to in solving the http and tls-sni -// challenges. -func (c *ACMEClient) Configure(bindHost string) { - // If we allow prompts, operator must be present. In our case, - // that is synonymous with saying the server is not already - // started. So if the user is still there, we don't use - // AlternatePort because we don't need to proxy the challenges. - // Conversely, if the operator is not there, the server has - // already started and we need to proxy the challenge. - if c.AllowPrompts { - // Operator is present; server is not already listening - c.SetHTTPAddress(net.JoinHostPort(bindHost, "")) - c.SetTLSAddress(net.JoinHostPort(bindHost, "")) - //c.ExcludeChallenges([]acme.Challenge{acme.DNS01}) + // See if HTTP challenge needs to be proxied + if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, HTTPChallengePort)) { + altPort := config.AltHTTPPort + if altPort == "" { + altPort = DefaultHTTPAlternatePort + } + c.SetHTTPAddress(net.JoinHostPort(config.ListenHost, altPort)) + } + + // See if TLS challenge needs to be handled by our own facilities + if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, TLSSNIChallengePort)) { + c.SetChallengeProvider(acme.TLSSNI01, tlsSniSolver{}) + } } else { - // Operator is not present; server is started, so proxy challenges - c.SetHTTPAddress(net.JoinHostPort(bindHost, AlternatePort)) - c.SetTLSAddress(net.JoinHostPort(bindHost, AlternatePort)) - //c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01}) + // Otherwise, DNS challenge it is + + // Load provider constructor function + provFn, ok := dnsProviders[config.DNSProvider] + if !ok { + return nil, errors.New("unknown DNS provider by name '" + config.DNSProvider + "'") + } + + // we could pass credentials to create the provider, but for now + // we just let the solver package get them from the environment + prov, err := provFn() + if err != nil { + return nil, err + } + + // Use the DNS challenge exclusively + c.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.TLSSNI01}) + c.SetChallengeProvider(acme.DNS01, prov) } - c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01}) // TODO: can we proxy TLS challenges? and we should support DNS... + + return c, nil } // Obtain obtains a single certificate for names. It stores the certificate @@ -121,7 +159,9 @@ Attempts: var promptedForAgreement bool // only prompt user for agreement at most once for errDomain, obtainErr := range failures { - // TODO: Double-check, will obtainErr ever be nil? + if obtainErr == nil { + continue + } if tosErr, ok := obtainErr.(acme.TOSError); ok { // Terms of Service agreement error; we can probably deal with this if !Agreed && !promptedForAgreement && c.AllowPrompts { @@ -144,7 +184,11 @@ Attempts: } // Success - immediately save the certificate resource - err := saveCertResource(certificate) + storage, err := StorageFor(c.config.CAUrl) + if err != nil { + return err + } + err = saveCertResource(storage, certificate) if err != nil { return fmt.Errorf("error saving assets for %v: %v", names, err) } @@ -163,6 +207,12 @@ Attempts: // // Anyway, this function is safe for concurrent use. func (c *ACMEClient) Renew(name string) error { + // Get access to ACME storage + storage, err := StorageFor(c.config.CAUrl) + if err != nil { + return err + } + // Prepare for renewal (load PEM cert, key, and meta) certBytes, err := ioutil.ReadFile(storage.SiteCertFile(name)) if err != nil { @@ -204,12 +254,45 @@ func (c *ACMEClient) Renew(name string) error { } // For any other kind of error, wait 10s and try again. - time.Sleep(10 * time.Second) + wait := 10 * time.Second + log.Printf("[ERROR] Renewing: %v; trying again in %s", err, wait) + time.Sleep(wait) } if !success { return errors.New("too many renewal attempts; last error: " + err.Error()) } - return saveCertResource(newCertMeta) + return saveCertResource(storage, newCertMeta) +} + +// Revoke revokes the certificate for name and deltes +// it from storage. +func (c *ACMEClient) Revoke(name string) error { + storage, err := StorageFor(c.config.CAUrl) + if err != nil { + return err + } + + if !existingCertAndKey(storage, name) { + return errors.New("no certificate and key for " + name) + } + + certFile := storage.SiteCertFile(name) + certBytes, err := ioutil.ReadFile(certFile) + if err != nil { + return err + } + + err = c.Client.RevokeCertificate(certBytes) + if err != nil { + return err + } + + err = os.Remove(certFile) + if err != nil { + return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error()) + } + + return nil } diff --git a/caddytls/client_test.go b/caddytls/client_test.go new file mode 100644 index 000000000..bd9cbbc81 --- /dev/null +++ b/caddytls/client_test.go @@ -0,0 +1,3 @@ +package caddytls + +// TODO diff --git a/caddytls/config.go b/caddytls/config.go new file mode 100644 index 000000000..550da2101 --- /dev/null +++ b/caddytls/config.go @@ -0,0 +1,437 @@ +package caddytls + +import ( + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "time" + + "github.com/xenolf/lego/acme" +) + +// Config describes how TLS should be configured and used. +type Config struct { + // The hostname or class of hostnames this config is + // designated for; can contain wildcard characters + // according to RFC 6125 §6.4.3 - this field MUST + // NOT be empty in order for things to work smoothly + Hostname string + + // Whether TLS is enabled + Enabled bool + + // Minimum and maximum protocol versions to allow + ProtocolMinVersion uint16 + ProtocolMaxVersion uint16 + + // The list of cipher suites; first should be + // TLS_FALLBACK_SCSV to prevent degrade attacks + Ciphers []uint16 + + // Whether to prefer server cipher suites + PreferServerCipherSuites bool + + // Client authentication policy + ClientAuth tls.ClientAuthType + + // List of client CA certificates to allow, if + // client authentication is enabled + ClientCerts []string + + // Manual means user provides own certs and keys + Manual bool + + // Managed means config qualifies for implicit, + // automatic, managed TLS; as opposed to the user + // providing and managing the certificate manually + Managed bool + + // OnDemand means the class of hostnames this + // config applies to may obtain and manage + // certificates at handshake-time (as opposed + // to pre-loaded at startup); OnDemand certs + // will be managed the same way as preloaded + // ones, however, if an OnDemand cert fails to + // renew, it is removed from the in-memory + // cache; if this is true, Managed must + // necessarily be true + OnDemand bool + + // SelfSigned means that this hostname is + // served with a self-signed certificate + // that we generated in memory for convenience + SelfSigned bool + + // The endpoint of the directory for the ACME + // CA we are to use + CAUrl string + + // The host (ONLY the host, not port) to listen + //on if necessary to start a a listener to solve + // an ACME challenge + ListenHost string + + // The alternate port (ONLY port, not host) + // to use for the ACME HTTP challenge; this + // port will be used if we proxy challenges + // coming in on port 80 to this alternate port + AltHTTPPort string + + // The string identifier of the DNS provider + // to use when solving the ACME DNS challenge + DNSProvider string + + // The email address to use when creating or + // using an ACME account (fun fact: if this + // is set to "off" then this config will not + // qualify for managed TLS) + ACMEEmail string + + // The type of key to use when generating + // certificates + KeyType acme.KeyType +} + +// ObtainCert obtains a certificate for c.Hostname, as long as a certificate +// does not already exist in storage on disk. It only obtains and stores +// certificates (and their keys) to disk, it does not load them into memory. +// If allowPrompts is true, the user may be shown a prompt. If proxyACME is +// true, the relevant ACME challenges will be proxied to the alternate port. +func (c *Config) ObtainCert(allowPrompts bool) error { + return c.obtainCertName(c.Hostname, allowPrompts) +} + +func (c *Config) obtainCertName(name string, allowPrompts bool) error { + storage, err := StorageFor(c.CAUrl) + if err != nil { + return err + } + + if !c.Managed || !HostQualifies(name) || existingCertAndKey(storage, name) { + return nil + } + + if c.ACMEEmail == "" { + c.ACMEEmail = getEmail(storage, allowPrompts) + } + + client, err := newACMEClient(c, allowPrompts) + if err != nil { + return err + } + + return client.Obtain([]string{name}) +} + +// RenewCert renews the certificate for c.Hostname. +func (c *Config) RenewCert(allowPrompts bool) error { + return c.renewCertName(c.Hostname, allowPrompts) +} + +func (c *Config) renewCertName(name string, allowPrompts bool) error { + storage, err := StorageFor(c.CAUrl) + if err != nil { + return err + } + + // Prepare for renewal (load PEM cert, key, and meta) + certBytes, err := ioutil.ReadFile(storage.SiteCertFile(c.Hostname)) + if err != nil { + return err + } + keyBytes, err := ioutil.ReadFile(storage.SiteKeyFile(c.Hostname)) + if err != nil { + return err + } + metaBytes, err := ioutil.ReadFile(storage.SiteMetaFile(c.Hostname)) + if err != nil { + return err + } + var certMeta acme.CertificateResource + err = json.Unmarshal(metaBytes, &certMeta) + certMeta.Certificate = certBytes + certMeta.PrivateKey = keyBytes + + client, err := newACMEClient(c, allowPrompts) + if err != nil { + return err + } + + // Perform renewal and retry if necessary, but not too many times. + var newCertMeta acme.CertificateResource + var success bool + for attempts := 0; attempts < 2; attempts++ { + acmeMu.Lock() + newCertMeta, err = client.RenewCertificate(certMeta, true) + acmeMu.Unlock() + if err == nil { + success = true + break + } + + // If the legal terms were updated and need to be + // agreed to again, we can handle that. + if _, ok := err.(acme.TOSError); ok { + err := client.AgreeToTOS() + if err != nil { + return err + } + continue + } + + // For any other kind of error, wait 10s and try again. + time.Sleep(10 * time.Second) + } + + if !success { + return errors.New("too many renewal attempts; last error: " + err.Error()) + } + + return saveCertResource(storage, newCertMeta) +} + +// MakeTLSConfig reduces configs into a single tls.Config. +// If TLS is to be disabled, a nil tls.Config will be returned. +func MakeTLSConfig(configs []*Config) (*tls.Config, error) { + if configs == nil || len(configs) == 0 { + return nil, nil + } + + config := new(tls.Config) + ciphersAdded := make(map[uint16]struct{}) + configMap := make(configGroup) + + for i, cfg := range configs { + if cfg == nil { + // avoid nil pointer dereference below + configs[i] = new(Config) + continue + } + + // Key this config by its hostname; this + // overwrites configs with the same hostname + configMap[cfg.Hostname] = cfg + + // Can't serve TLS and not-TLS on same port + if i > 0 && cfg.Enabled != configs[i-1].Enabled { + thisConfProto, lastConfProto := "not TLS", "not TLS" + if cfg.Enabled { + thisConfProto = "TLS" + } + if configs[i-1].Enabled { + lastConfProto = "TLS" + } + return nil, fmt.Errorf("cannot multiplex %s (%s) and %s (%s) on same listener", + configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto) + } + + // Union cipher suites + for _, ciph := range cfg.Ciphers { + if _, ok := ciphersAdded[ciph]; !ok { + ciphersAdded[ciph] = struct{}{} + config.CipherSuites = append(config.CipherSuites, ciph) + } + } + + // Can't resolve conflicting PreferServerCipherSuites settings + if i > 0 && cfg.PreferServerCipherSuites != configs[i-1].PreferServerCipherSuites { + return nil, fmt.Errorf("cannot both use PreferServerCipherSuites and not use it") + } + + // Go with the widest range of protocol versions + if cfg.ProtocolMinVersion < config.MinVersion { + config.MinVersion = cfg.ProtocolMinVersion + } + if cfg.ProtocolMaxVersion < config.MaxVersion { + config.MaxVersion = cfg.ProtocolMaxVersion + } + + // Go with the strictest ClientAuth type + if cfg.ClientAuth > config.ClientAuth { + config.ClientAuth = cfg.ClientAuth + } + } + + // Is TLS disabled? If so, we're done here. + // By now, we know that all configs agree + // whether it is or not, so we can just look + // at the first one. + if len(configs) == 0 || !configs[0].Enabled { + return nil, nil + } + + // Default cipher suites + if len(config.CipherSuites) == 0 { + config.CipherSuites = defaultCiphers + } + + // For security, ensure TLS_FALLBACK_SCSV is always included + if config.CipherSuites[0] != tls.TLS_FALLBACK_SCSV { + config.CipherSuites = append([]uint16{tls.TLS_FALLBACK_SCSV}, config.CipherSuites...) + } + + // Set up client authentication if enabled + if config.ClientAuth != tls.NoClientCert { + pool := x509.NewCertPool() + clientCertsAdded := make(map[string]struct{}) + for _, cfg := range configs { + for _, caFile := range cfg.ClientCerts { + // don't add cert to pool more than once + if _, ok := clientCertsAdded[caFile]; ok { + continue + } + clientCertsAdded[caFile] = struct{}{} + + // Any client with a certificate from this CA will be allowed to connect + caCrt, err := ioutil.ReadFile(caFile) + if err != nil { + return nil, err + } + + if !pool.AppendCertsFromPEM(caCrt) { + return nil, fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile) + } + } + } + config.ClientCAs = pool + } + + // Associate the GetCertificate callback, or almost nothing we just did will work + config.GetCertificate = configMap.GetCertificate + + return config, nil +} + +// ConfigGetter gets a Config keyed by key. +type ConfigGetter func(key string) *Config + +var configGetters = make(map[string]ConfigGetter) + +// RegisterConfigGetter registers fn as the way to get a +// Config for server type serverType. +func RegisterConfigGetter(serverType string, fn ConfigGetter) { + configGetters[serverType] = fn +} + +// SetDefaultTLSParams sets the default TLS cipher suites, protocol versions, +// and server preferences of a server.Config if they were not previously set +// (it does not overwrite; only fills in missing values). +func SetDefaultTLSParams(config *Config) { + // If no ciphers provided, use default list + if len(config.Ciphers) == 0 { + config.Ciphers = defaultCiphers + } + + // Not a cipher suite, but still important for mitigating protocol downgrade attacks + // (prepend since having it at end breaks http2 due to non-h2-approved suites before it) + config.Ciphers = append([]uint16{tls.TLS_FALLBACK_SCSV}, config.Ciphers...) + + // Set default protocol min and max versions - must balance compatibility and security + if config.ProtocolMinVersion == 0 { + config.ProtocolMinVersion = tls.VersionTLS11 + } + if config.ProtocolMaxVersion == 0 { + config.ProtocolMaxVersion = tls.VersionTLS12 + } + + // Prefer server cipher suites + config.PreferServerCipherSuites = true +} + +// Map of supported key types +var supportedKeyTypes = map[string]acme.KeyType{ + "P384": acme.EC384, + "P256": acme.EC256, + "RSA8192": acme.RSA8192, + "RSA4096": acme.RSA4096, + "RSA2048": acme.RSA2048, +} + +// Map of supported protocols. +// HTTP/2 only supports TLS 1.2 and higher. +var supportedProtocols = map[string]uint16{ + "tls1.0": tls.VersionTLS10, + "tls1.1": tls.VersionTLS11, + "tls1.2": tls.VersionTLS12, +} + +// Map of supported ciphers, used only for parsing config. +// +// Note that, at time of writing, HTTP/2 blacklists 276 cipher suites, +// including all but four of the suites below (the four GCM suites). +// See https://http2.github.io/http2-spec/#BadCipherSuites +// +// TLS_FALLBACK_SCSV is not in this list because we manually ensure +// it is always added (even though it is not technically a cipher suite). +// +// This map, like any map, is NOT ORDERED. Do not range over this map. +var supportedCiphersMap = map[string]uint16{ + "ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + "ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + "ECDHE-RSA-AES128-GCM-SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + "ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + "ECDHE-RSA-AES128-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + "ECDHE-RSA-AES256-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + "ECDHE-ECDSA-AES256-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + "ECDHE-ECDSA-AES128-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + "RSA-AES128-CBC-SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA, + "RSA-AES256-CBC-SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA, + "ECDHE-RSA-3DES-EDE-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + "RSA-3DES-EDE-CBC-SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, +} + +// List of supported cipher suites in descending order of preference. +// Ordering is very important! Getting the wrong order will break +// mainstream clients, especially with HTTP/2. +// +// Note that TLS_FALLBACK_SCSV is not in this list since it is always +// added manually. +var supportedCiphers = []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, +} + +// List of all the ciphers we want to use by default +var defaultCiphers = []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, +} + +const ( + // HTTPChallengePort is the officially designated port for + // the HTTP challenge. + HTTPChallengePort = "80" + + // TLSSNIChallengePort is the officially designated port for + // the TLS-SNI challenge. + TLSSNIChallengePort = "443" + + // DefaultHTTPAlternatePort is the port on which the ACME + // client will open a listener and solve the HTTP challenge. + // If this alternate port is used instead of the default + // port, then whatever is listening on the default port must + // be capable of proxying or forwarding the request to this + // alternate port. + DefaultHTTPAlternatePort = "5033" +) diff --git a/caddytls/crypto.go b/caddytls/crypto.go new file mode 100644 index 000000000..243b37f5d --- /dev/null +++ b/caddytls/crypto.go @@ -0,0 +1,258 @@ +package caddytls + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "fmt" + "io" + "io/ioutil" + "math/big" + "net" + "os" + "time" + + "github.com/xenolf/lego/acme" +) + +// loadPrivateKey loads a PEM-encoded ECC/RSA private key from file. +func loadPrivateKey(file string) (crypto.PrivateKey, error) { + keyBytes, err := ioutil.ReadFile(file) + if err != nil { + return nil, err + } + keyBlock, _ := pem.Decode(keyBytes) + + switch keyBlock.Type { + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(keyBlock.Bytes) + case "EC PRIVATE KEY": + return x509.ParseECPrivateKey(keyBlock.Bytes) + } + + return nil, errors.New("unknown private key type") +} + +// savePrivateKey saves a PEM-encoded ECC/RSA private key to file. +func savePrivateKey(key crypto.PrivateKey, file string) error { + var pemType string + var keyBytes []byte + switch key := key.(type) { + case *ecdsa.PrivateKey: + var err error + pemType = "EC" + keyBytes, err = x509.MarshalECPrivateKey(key) + if err != nil { + return err + } + case *rsa.PrivateKey: + pemType = "RSA" + keyBytes = x509.MarshalPKCS1PrivateKey(key) + } + + pemKey := pem.Block{Type: pemType + " PRIVATE KEY", Bytes: keyBytes} + keyOut, err := os.Create(file) + if err != nil { + return err + } + keyOut.Chmod(0600) + defer keyOut.Close() + return pem.Encode(keyOut, &pemKey) +} + +// stapleOCSP staples OCSP information to cert for hostname name. +// If you have it handy, you should pass in the PEM-encoded certificate +// bundle; otherwise the DER-encoded cert will have to be PEM-encoded. +// If you don't have the PEM blocks handy, just pass in nil. +// +// Errors here are not necessarily fatal, it could just be that the +// certificate doesn't have an issuer URL. +func stapleOCSP(cert *Certificate, pemBundle []byte) error { + if pemBundle == nil { + // The function in the acme package that gets OCSP requires a PEM-encoded cert + bundle := new(bytes.Buffer) + for _, derBytes := range cert.Certificate.Certificate { + pem.Encode(bundle, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + } + pemBundle = bundle.Bytes() + } + + ocspBytes, ocspResp, err := acme.GetOCSPForCert(pemBundle) + if err != nil { + return err + } + + cert.Certificate.OCSPStaple = ocspBytes + cert.OCSP = ocspResp + + return nil +} + +// makeSelfSignedCert makes a self-signed certificate according +// to the parameters in config. It then caches the certificate +// in our cache. +func makeSelfSignedCert(config *Config) error { + // start by generating private key + var privKey interface{} + var err error + switch config.KeyType { + case "", acme.EC256: + privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + case acme.EC384: + privKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + case acme.RSA2048: + privKey, err = rsa.GenerateKey(rand.Reader, 2048) + case acme.RSA4096: + privKey, err = rsa.GenerateKey(rand.Reader, 4096) + case acme.RSA8192: + privKey, err = rsa.GenerateKey(rand.Reader, 8192) + default: + return fmt.Errorf("cannot generate private key; unknown key type %v", config.KeyType) + } + if err != nil { + return fmt.Errorf("failed to generate private key: %v", err) + } + + // create certificate structure with proper values + notBefore := time.Now() + notAfter := notBefore.Add(24 * time.Hour * 7) + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return fmt.Errorf("failed to generate serial number: %v", err) + } + cert := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{Organization: []string{"Caddy Self-Signed"}}, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + if ip := net.ParseIP(config.Hostname); ip != nil { + cert.IPAddresses = append(cert.IPAddresses, ip) + } else { + cert.DNSNames = append(cert.DNSNames, config.Hostname) + } + + publicKey := func(privKey interface{}) interface{} { + switch k := privKey.(type) { + case *rsa.PrivateKey: + return &k.PublicKey + case *ecdsa.PrivateKey: + return &k.PublicKey + default: + return errors.New("unknown key type") + } + } + derBytes, err := x509.CreateCertificate(rand.Reader, cert, cert, publicKey(privKey), privKey) + if err != nil { + return fmt.Errorf("could not create certificate: %v", err) + } + + cacheCertificate(Certificate{ + Certificate: tls.Certificate{ + Certificate: [][]byte{derBytes}, + PrivateKey: privKey, + Leaf: cert, + }, + Names: cert.DNSNames, + NotAfter: cert.NotAfter, + Config: config, + }) + + return nil +} + +// RotateSessionTicketKeys rotates the TLS session ticket keys +// on cfg every TicketRotateInterval. It spawns a new goroutine so +// this function does NOT block. It returns a channel you should +// close when you are ready to stop the key rotation, like when the +// server using cfg is no longer running. +func RotateSessionTicketKeys(cfg *tls.Config) chan struct{} { + ch := make(chan struct{}) + ticker := time.NewTicker(TicketRotateInterval) + go runTLSTicketKeyRotation(cfg, ticker, ch) + return ch +} + +// Functions that may be swapped out for testing +var ( + runTLSTicketKeyRotation = standaloneTLSTicketKeyRotation + setSessionTicketKeysTestHook = func(keys [][32]byte) [][32]byte { return keys } +) + +// standaloneTLSTicketKeyRotation governs over the array of TLS ticket keys used to de/crypt TLS tickets. +// It periodically sets a new ticket key as the first one, used to encrypt (and decrypt), +// pushing any old ticket keys to the back, where they are considered for decryption only. +// +// Lack of entropy for the very first ticket key results in the feature being disabled (as does Go), +// later lack of entropy temporarily disables ticket key rotation. +// Old ticket keys are still phased out, though. +// +// Stops the ticker when returning. +func standaloneTLSTicketKeyRotation(c *tls.Config, ticker *time.Ticker, exitChan chan struct{}) { + defer ticker.Stop() + + // The entire page should be marked as sticky, but Go cannot do that + // without resorting to syscall#Mlock. And, we don't have madvise (for NODUMP), too. ☹ + keys := make([][32]byte, 1, NumTickets) + + rng := c.Rand + if rng == nil { + rng = rand.Reader + } + if _, err := io.ReadFull(rng, keys[0][:]); err != nil { + c.SessionTicketsDisabled = true // bail if we don't have the entropy for the first one + return + } + c.SessionTicketKey = keys[0] // SetSessionTicketKeys doesn't set a 'tls.keysAlreadySet' + c.SetSessionTicketKeys(setSessionTicketKeysTestHook(keys)) + + for { + select { + case _, isOpen := <-exitChan: + if !isOpen { + return + } + case <-ticker.C: + rng = c.Rand // could've changed since the start + if rng == nil { + rng = rand.Reader + } + var newTicketKey [32]byte + _, err := io.ReadFull(rng, newTicketKey[:]) + + if len(keys) < NumTickets { + keys = append(keys, keys[0]) // manipulates the internal length + } + for idx := len(keys) - 1; idx >= 1; idx-- { + keys[idx] = keys[idx-1] // yes, this makes copies + } + + if err == nil { + keys[0] = newTicketKey + } + // pushes the last key out, doesn't matter that we don't have a new one + c.SetSessionTicketKeys(setSessionTicketKeysTestHook(keys)) + } + } +} + +const ( + // NumTickets is how many tickets to hold and consider + // to decrypt TLS sessions. + NumTickets = 4 + + // TicketRotateInterval is how often to generate + // new ticket for TLS PFS encryption + TicketRotateInterval = 10 * time.Hour +) diff --git a/caddy/https/crypto_test.go b/caddytls/crypto_test.go similarity index 60% rename from caddy/https/crypto_test.go rename to caddytls/crypto_test.go index efa45c65a..3eca43ae2 100644 --- a/caddy/https/crypto_test.go +++ b/caddytls/crypto_test.go @@ -1,4 +1,4 @@ -package https +package caddytls import ( "bytes" @@ -7,11 +7,12 @@ import ( "crypto/elliptic" "crypto/rand" "crypto/rsa" + "crypto/tls" "crypto/x509" - "errors" "os" "runtime" "testing" + "time" ) func TestSaveAndLoadRSAPrivateKey(t *testing.T) { @@ -96,25 +97,70 @@ func TestSaveAndLoadECCPrivateKey(t *testing.T) { // PrivateKeysSame compares the bytes of a and b and returns true if they are the same. func PrivateKeysSame(a, b crypto.PrivateKey) bool { - var abytes, bbytes []byte - var err error - - if abytes, err = PrivateKeyBytes(a); err != nil { - return false - } - if bbytes, err = PrivateKeyBytes(b); err != nil { - return false - } - return bytes.Equal(abytes, bbytes) + return bytes.Equal(PrivateKeyBytes(a), PrivateKeyBytes(b)) } // PrivateKeyBytes returns the bytes of DER-encoded key. -func PrivateKeyBytes(key crypto.PrivateKey) ([]byte, error) { +func PrivateKeyBytes(key crypto.PrivateKey) []byte { + var keyBytes []byte switch key := key.(type) { case *rsa.PrivateKey: - return x509.MarshalPKCS1PrivateKey(key), nil + keyBytes = x509.MarshalPKCS1PrivateKey(key) case *ecdsa.PrivateKey: - return x509.MarshalECPrivateKey(key) + keyBytes, _ = x509.MarshalECPrivateKey(key) + } + return keyBytes +} + +func TestStandaloneTLSTicketKeyRotation(t *testing.T) { + tlsGovChan := make(chan struct{}) + defer close(tlsGovChan) + callSync := make(chan bool, 1) + defer close(callSync) + + oldHook := setSessionTicketKeysTestHook + defer func() { + setSessionTicketKeysTestHook = oldHook + }() + var keysInUse [][32]byte + setSessionTicketKeysTestHook = func(keys [][32]byte) [][32]byte { + keysInUse = keys + callSync <- true + return keys + } + + c := new(tls.Config) + timer := time.NewTicker(time.Millisecond * 1) + + go standaloneTLSTicketKeyRotation(c, timer, tlsGovChan) + + rounds := 0 + var lastTicketKey [32]byte + for { + select { + case <-callSync: + if lastTicketKey == keysInUse[0] { + close(tlsGovChan) + t.Errorf("The same TLS ticket key has been used again (not rotated): %x.", lastTicketKey) + return + } + lastTicketKey = keysInUse[0] + rounds++ + if rounds <= NumTickets && len(keysInUse) != rounds { + close(tlsGovChan) + t.Errorf("Expected TLS ticket keys in use: %d; Got instead: %d.", rounds, len(keysInUse)) + return + } + if c.SessionTicketsDisabled == true { + t.Error("Session tickets have been disabled unexpectedly.") + return + } + if rounds >= NumTickets+1 { + return + } + case <-time.After(time.Second * 1): + t.Errorf("Timeout after %d rounds.", rounds) + return + } } - return nil, errors.New("Unknown private key type") } diff --git a/caddy/https/handshake.go b/caddytls/handshake.go similarity index 64% rename from caddy/https/handshake.go rename to caddytls/handshake.go index fc6ef809e..f389dd7ae 100644 --- a/caddy/https/handshake.go +++ b/caddytls/handshake.go @@ -1,9 +1,7 @@ -package https +package caddytls import ( - "bytes" "crypto/tls" - "encoding/pem" "errors" "fmt" "log" @@ -11,67 +9,98 @@ import ( "sync" "sync/atomic" "time" - - "github.com/mholt/caddy/server" - "github.com/xenolf/lego/acme" ) -// GetCertificate gets a certificate to satisfy clientHello as long as -// the certificate is already cached in memory. It will not be loaded -// from disk or obtained from the CA during the handshake. +// configGroup is a type that keys configs by their hostname +// (hostnames can have wildcard characters; use the getConfig +// method to get a config by matching its hostname). Its +// GetCertificate function can be used with tls.Config. +type configGroup map[string]*Config + +// getConfig gets the config by the first key match for name. +// In other words, "sub.foo.bar" will get the config for "*.foo.bar" +// if that is the closest match. This function MAY return nil +// if no match is found. // -// This function is safe for use as a tls.Config.GetCertificate callback. -func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - cert, err := getCertDuringHandshake(clientHello.ServerName, false, false) - return &cert.Certificate, err +// This function follows nearly the same logic to lookup +// a hostname as the getCertificate function uses. +func (cg configGroup) getConfig(name string) *Config { + name = strings.ToLower(name) + + // exact match? great, let's use it + if config, ok := cg[name]; ok { + return config + } + + // try replacing labels in the name with wildcards until we get a match + labels := strings.Split(name, ".") + for i := range labels { + labels[i] = "*" + candidate := strings.Join(labels, ".") + if config, ok := cg[candidate]; ok { + return config + } + } + + // as last resort, try a config that serves all names + if config, ok := cg[""]; ok { + return config + } + + return nil } -// GetOrObtainCertificate will get a certificate to satisfy clientHello, even -// if that means obtaining a new certificate from a CA during the handshake. -// It first checks the in-memory cache, then accesses disk, then accesses the -// network if it must. An obtained certificate will be stored on disk and -// cached in memory. +// GetCertificate gets a certificate to satisfy clientHello. In getting +// the certificate, it abides the rules and settings defined in the +// Config that matches clientHello.ServerName. It first checks the in- +// memory cache, then, if the config enables "OnDemand", it accessses +// disk, then accesses the network if it must obtain a new certificate +// via ACME. // -// This function is safe for use as a tls.Config.GetCertificate callback. -func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - cert, err := getCertDuringHandshake(clientHello.ServerName, true, true) +// This method is safe for use as a tls.Config.GetCertificate callback. +func (cg configGroup) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := cg.getCertDuringHandshake(clientHello.ServerName, true, true) return &cert.Certificate, err } // getCertDuringHandshake will get a certificate for name. It first tries -// the in-memory cache. If no certificate for name is in the cache and if -// loadIfNecessary == true, it goes to disk to load it into the cache and -// serve it. If it's not on disk and if obtainIfNecessary == true, the -// certificate will be obtained from the CA, cached, and served. If -// obtainIfNecessary is true, then loadIfNecessary must also be set to true. -// An error will be returned if and only if no certificate is available. +// the in-memory cache. If no certificate for name is in the cache, the +// config most closely corresponding to name will be loaded. If that config +// allows it (OnDemand==true) and if loadIfNecessary == true, it goes to disk +// to load it into the cache and serve it. If it's not on disk and if +// obtainIfNecessary == true, the certificate will be obtained from the CA, +// cached, and served. If obtainIfNecessary is true, then loadIfNecessary +// must also be set to true. An error will be returned if and only if no +// certificate is available. // // This function is safe for concurrent use. -func getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { +func (cg configGroup) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { // First check our in-memory cache to see if we've already loaded it cert, matched, defaulted := getCertificate(name) if matched { return cert, nil } - if loadIfNecessary { + // Get the relevant TLS config for this name. If OnDemand is enabled, + // then we might be able to load or obtain a needed certificate. + cfg := cg.getConfig(name) + if cfg != nil && cfg.OnDemand && loadIfNecessary { // Then check to see if we have one on disk - loadedCert, err := cacheManagedCertificate(name, true) + loadedCert, err := CacheManagedCertificate(name, cfg) if err == nil { - loadedCert, err = handshakeMaintenance(name, loadedCert) + loadedCert, err = cg.handshakeMaintenance(name, loadedCert) if err != nil { log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err) } return loadedCert, nil } - if obtainIfNecessary { // By this point, we need to ask the CA for a certificate name = strings.ToLower(name) // Make sure aren't over any applicable limits - err := checkLimitsForObtainingNewCerts(name) + err := cg.checkLimitsForObtainingNewCerts(name) if err != nil { return Certificate{}, err } @@ -82,22 +111,23 @@ func getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool } // Obtain certificate from the CA - return obtainOnDemandCertificate(name) + return cg.obtainOnDemandCertificate(name, cfg) } } + // Fall back to the default certificate if there is one if defaulted { return cert, nil } - return Certificate{}, errors.New("no certificate for " + name) + return Certificate{}, fmt.Errorf("no certificate available for %s", name) } // checkLimitsForObtainingNewCerts checks to see if name can be issued right // now according to mitigating factors we keep track of and preferences the // user has set. If a non-nil error is returned, do not issue a new certificate // for name. -func checkLimitsForObtainingNewCerts(name string) error { +func (cg configGroup) checkLimitsForObtainingNewCerts(name string) error { // User can set hard limit for number of certs for the process to issue if onDemandMaxIssue > 0 && atomic.LoadInt32(OnDemandIssuedCount) >= onDemandMaxIssue { return fmt.Errorf("%s: maximum certificates issued (%d)", name, onDemandMaxIssue) @@ -129,7 +159,7 @@ func checkLimitsForObtainingNewCerts(name string) error { // name, it will wait and use what the other goroutine obtained. // // This function is safe for use by multiple concurrent goroutines. -func obtainOnDemandCertificate(name string) (Certificate, error) { +func (cg configGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certificate, error) { // We must protect this process from happening concurrently, so synchronize. obtainCertWaitChansMu.Lock() wait, ok := obtainCertWaitChans[name] @@ -138,7 +168,7 @@ func obtainOnDemandCertificate(name string) (Certificate, error) { // wait for it to finish obtaining the cert and then we'll use it. obtainCertWaitChansMu.Unlock() <-wait - return getCertDuringHandshake(name, true, false) + return cg.getCertDuringHandshake(name, true, false) } // looks like it's up to us to do all the work and obtain the cert @@ -156,14 +186,7 @@ func obtainOnDemandCertificate(name string) (Certificate, error) { log.Printf("[INFO] Obtaining new certificate for %s", name) - // obtain cert - client, err := NewACMEClientGetEmail(server.Config{}, false) - if err != nil { - return Certificate{}, errors.New("error creating client: " + err.Error()) - } - client.Configure("") // TODO: which BindHost? - err = client.Obtain([]string{name}) - if err != nil { + if err := cfg.obtainCertName(name, false); err != nil { // Failed to solve challenge, so don't allow another on-demand // issue for this name to be attempted for a little while. failedIssuanceMu.Lock() @@ -185,19 +208,19 @@ func obtainOnDemandCertificate(name string) (Certificate, error) { lastIssueTimeMu.Unlock() // The certificate is already on disk; now just start over to load it and serve it - return getCertDuringHandshake(name, true, false) + return cg.getCertDuringHandshake(name, true, false) } // handshakeMaintenance performs a check on cert for expiration and OCSP // validity. // // This function is safe for use by multiple concurrent goroutines. -func handshakeMaintenance(name string, cert Certificate) (Certificate, error) { +func (cg configGroup) handshakeMaintenance(name string, cert Certificate) (Certificate, error) { // Check cert expiration timeLeft := cert.NotAfter.Sub(time.Now().UTC()) - if timeLeft < renewDurationBefore { + if timeLeft < RenewDurationBefore { log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft) - return renewDynamicCertificate(name) + return cg.renewDynamicCertificate(name, cert.Config) } // Check OCSP staple validity @@ -219,13 +242,13 @@ func handshakeMaintenance(name string, cert Certificate) (Certificate, error) { return cert, nil } -// renewDynamicCertificate renews currentCert using the clientHello. It returns the +// renewDynamicCertificate renews the certificate for name using cfg. It returns the // certificate to use and an error, if any. currentCert may be returned even if an // error occurs, since we perform renewals before they expire and it may still be // usable. name should already be lower-cased before calling this function. // // This function is safe for use by multiple concurrent goroutines. -func renewDynamicCertificate(name string) (Certificate, error) { +func (cg configGroup) renewDynamicCertificate(name string, cfg *Config) (Certificate, error) { obtainCertWaitChansMu.Lock() wait, ok := obtainCertWaitChans[name] if ok { @@ -233,7 +256,7 @@ func renewDynamicCertificate(name string) (Certificate, error) { // wait for it to finish, then we'll use the new one. obtainCertWaitChansMu.Unlock() <-wait - return getCertDuringHandshake(name, true, false) + return cg.getCertDuringHandshake(name, true, false) } // looks like it's up to us to do all the work and renew the cert @@ -251,45 +274,12 @@ func renewDynamicCertificate(name string) (Certificate, error) { log.Printf("[INFO] Renewing certificate for %s", name) - client, err := NewACMEClientGetEmail(server.Config{}, false) - if err != nil { - return Certificate{}, err - } - client.Configure("") // TODO: Bind address of relevant listener, yuck - err = client.Renew(name) + err := cfg.renewCertName(name, false) if err != nil { return Certificate{}, err } - return getCertDuringHandshake(name, true, false) -} - -// stapleOCSP staples OCSP information to cert for hostname name. -// If you have it handy, you should pass in the PEM-encoded certificate -// bundle; otherwise the DER-encoded cert will have to be PEM-encoded. -// If you don't have the PEM blocks handy, just pass in nil. -// -// Errors here are not necessarily fatal, it could just be that the -// certificate doesn't have an issuer URL. -func stapleOCSP(cert *Certificate, pemBundle []byte) error { - if pemBundle == nil { - // The function in the acme package that gets OCSP requires a PEM-encoded cert - bundle := new(bytes.Buffer) - for _, derBytes := range cert.Certificate.Certificate { - pem.Encode(bundle, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - } - pemBundle = bundle.Bytes() - } - - ocspBytes, ocspResp, err := acme.GetOCSPForCert(pemBundle) - if err != nil { - return err - } - - cert.Certificate.OCSPStaple = ocspBytes - cert.OCSP = ocspResp - - return nil + return cg.getCertDuringHandshake(name, true, false) } // obtainCertWaitChans is used to coordinate obtaining certs for each hostname. @@ -318,3 +308,5 @@ var failedIssuanceMu sync.RWMutex // If this value is recent, do not make any on-demand certificate requests. var lastIssueTime time.Time var lastIssueTimeMu sync.Mutex + +var errNoCert = errors.New("no certificate available") diff --git a/caddy/https/handshake_test.go b/caddytls/handshake_test.go similarity index 83% rename from caddy/https/handshake_test.go rename to caddytls/handshake_test.go index cf70eb17d..6abfb767f 100644 --- a/caddy/https/handshake_test.go +++ b/caddytls/handshake_test.go @@ -1,4 +1,4 @@ -package https +package caddytls import ( "crypto/tls" @@ -9,16 +9,18 @@ import ( func TestGetCertificate(t *testing.T) { defer func() { certCache = make(map[string]Certificate) }() + cg := make(configGroup) + hello := &tls.ClientHelloInfo{ServerName: "example.com"} helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"} helloNoSNI := &tls.ClientHelloInfo{} helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"} // When cache is empty - if cert, err := GetCertificate(hello); err == nil { + if cert, err := cg.GetCertificate(hello); err == nil { t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert) } - if cert, err := GetCertificate(helloNoSNI); err == nil { + if cert, err := cg.GetCertificate(helloNoSNI); err == nil { t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert) } @@ -26,12 +28,12 @@ func TestGetCertificate(t *testing.T) { defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}} certCache[""] = defaultCert certCache["example.com"] = defaultCert - if cert, err := GetCertificate(hello); err != nil { + if cert, err := cg.GetCertificate(hello); err != nil { t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err) } else if cert.Leaf.DNSNames[0] != "example.com" { t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert) } - if cert, err := GetCertificate(helloNoSNI); err != nil { + if cert, err := cg.GetCertificate(helloNoSNI); err != nil { t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err) } else if cert.Leaf.DNSNames[0] != "example.com" { t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert) @@ -39,14 +41,14 @@ func TestGetCertificate(t *testing.T) { // When retrieving wildcard certificate certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}} - if cert, err := GetCertificate(helloSub); err != nil { + if cert, err := cg.GetCertificate(helloSub); err != nil { t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err) } else if cert.Leaf.DNSNames[0] != "*.example.com" { t.Errorf("Got wrong certificate, expected wildcard: %v", cert) } // When no certificate matches, the default is returned - if cert, err := GetCertificate(helloNoMatch); err != nil { + if cert, err := cg.GetCertificate(helloNoMatch); err != nil { t.Errorf("Expected default certificate with no error when no matches, got err: %v", err) } else if cert.Leaf.DNSNames[0] != "example.com" { t.Errorf("Expected default cert with no matches, got: %v", cert) diff --git a/caddytls/httphandler.go b/caddytls/httphandler.go new file mode 100644 index 000000000..8115f9450 --- /dev/null +++ b/caddytls/httphandler.go @@ -0,0 +1,42 @@ +package caddytls + +import ( + "crypto/tls" + "log" + "net/http" + "net/http/httputil" + "net/url" + "strings" +) + +const challengeBasePath = "/.well-known/acme-challenge" + +// HTTPChallengeHandler proxies challenge requests to ACME client if the +// request path starts with challengeBasePath. It returns true if it +// handled the request and no more needs to be done; it returns false +// if this call was a no-op and the request still needs handling. +func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, altPort string) bool { + if !strings.HasPrefix(r.URL.Path, challengeBasePath) { + return false + } + + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + upstream, err := url.Parse(scheme + "://localhost:" + altPort) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + log.Printf("[ERROR] ACME proxy handler: %v", err) + return true + } + + proxy := httputil.NewSingleHostReverseProxy(upstream) + proxy.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // solver uses self-signed certs + } + proxy.ServeHTTP(w, r) + + return true +} diff --git a/caddy/https/handler_test.go b/caddytls/httphandler_test.go similarity index 76% rename from caddy/https/handler_test.go rename to caddytls/httphandler_test.go index 016799ffb..fc04e8eeb 100644 --- a/caddy/https/handler_test.go +++ b/caddytls/httphandler_test.go @@ -1,4 +1,4 @@ -package https +package caddytls import ( "net" @@ -7,7 +7,7 @@ import ( "testing" ) -func TestRequestCallbackNoOp(t *testing.T) { +func TestHTTPChallengeHandlerNoOp(t *testing.T) { // try base paths that aren't handled by this handler for _, url := range []string{ "http://localhost/", @@ -21,13 +21,13 @@ func TestRequestCallbackNoOp(t *testing.T) { t.Fatalf("Could not craft request, got error: %v", err) } rw := httptest.NewRecorder() - if RequestCallback(rw, req) { + if HTTPChallengeHandler(rw, req, DefaultHTTPAlternatePort) { t.Errorf("Got true with this URL, but shouldn't have: %s", url) } } } -func TestRequestCallbackSuccess(t *testing.T) { +func TestHTTPChallengeHandlerSuccess(t *testing.T) { expectedPath := challengeBasePath + "/asdf" // Set up fake acme handler backend to make sure proxying succeeds @@ -40,7 +40,7 @@ func TestRequestCallbackSuccess(t *testing.T) { })) // Custom listener that uses the port we expect - ln, err := net.Listen("tcp", "127.0.0.1:"+AlternatePort) + ln, err := net.Listen("tcp", "127.0.0.1:"+DefaultHTTPAlternatePort) if err != nil { t.Fatalf("Unable to start test server listener: %v", err) } @@ -49,13 +49,13 @@ func TestRequestCallbackSuccess(t *testing.T) { // Start our engines and run the test ts.Start() defer ts.Close() - req, err := http.NewRequest("GET", "http://127.0.0.1:"+AlternatePort+expectedPath, nil) + req, err := http.NewRequest("GET", "http://127.0.0.1:"+DefaultHTTPAlternatePort+expectedPath, nil) if err != nil { t.Fatalf("Could not craft request, got error: %v", err) } rw := httptest.NewRecorder() - RequestCallback(rw, req) + HTTPChallengeHandler(rw, req, DefaultHTTPAlternatePort) if !proxySuccess { t.Fatal("Expected request to be proxied, but it wasn't") diff --git a/caddy/https/maintain.go b/caddytls/maintain.go similarity index 74% rename from caddy/https/maintain.go rename to caddytls/maintain.go index a0fb0557b..96514ac24 100644 --- a/caddy/https/maintain.go +++ b/caddytls/maintain.go @@ -1,31 +1,39 @@ -package https +package caddytls import ( "log" "time" - "github.com/mholt/caddy/server" - "golang.org/x/crypto/ocsp" ) +func init() { + // maintain assets while this package is imported, which is + // always. we don't ever stop it, since we need it running. + go maintainAssets(make(chan struct{})) +} + const ( // RenewInterval is how often to check certificates for renewal. RenewInterval = 12 * time.Hour // OCSPInterval is how often to check if OCSP stapling needs updating. OCSPInterval = 1 * time.Hour + + // RenewDurationBefore is how long before expiration to renew certificates. + RenewDurationBefore = (24 * time.Hour) * 30 ) // maintainAssets is a permanently-blocking function // that loops indefinitely and, on a regular schedule, checks // certificates for expiration and initiates a renewal of certs // that are expiring soon. It also updates OCSP stapling and -// performs other maintenance of assets. +// performs other maintenance of assets. It should only be +// called once per process. // // You must pass in the channel which you'll close when // maintenance should stop, to allow this goroutine to clean up -// after itself and unblock. +// after itself and unblock. (Not that you HAVE to stop it...) func maintainAssets(stopChan chan struct{}) { renewalTicker := time.NewTicker(RenewInterval) ocspTicker := time.NewTicker(OCSPInterval) @@ -34,11 +42,11 @@ func maintainAssets(stopChan chan struct{}) { select { case <-renewalTicker.C: log.Println("[INFO] Scanning for expiring certificates") - renewManagedCertificates(false) + RenewManagedCertificates(false) log.Println("[INFO] Done checking certificates") case <-ocspTicker.C: log.Println("[INFO] Scanning for stale OCSP staples") - updateOCSPStaples() + UpdateOCSPStaples() log.Println("[INFO] Done checking OCSP staples") case <-stopChan: renewalTicker.Stop() @@ -49,20 +57,20 @@ func maintainAssets(stopChan chan struct{}) { } } -func renewManagedCertificates(allowPrompts bool) (err error) { +// RenewManagedCertificates renews managed certificates. +func RenewManagedCertificates(allowPrompts bool) (err error) { var renewed, deleted []Certificate - var client *ACMEClient visitedNames := make(map[string]struct{}) certCacheMu.RLock() for name, cert := range certCache { - if !cert.Managed { + if !cert.Config.Managed || cert.Config.SelfSigned { continue } // the list of names on this cert should never be empty... if cert.Names == nil || len(cert.Names) == 0 { - log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v", name, cert.Names) + log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v - removing from cache", name, cert.Names) deleted = append(deleted, cert) continue } @@ -75,21 +83,21 @@ func renewManagedCertificates(allowPrompts bool) (err error) { visitedNames[name] = struct{}{} } + // if its time is up or ending soon, we need to try to renew it timeLeft := cert.NotAfter.Sub(time.Now().UTC()) - if timeLeft < renewDurationBefore { + if timeLeft < RenewDurationBefore { log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft) - if client == nil { - client, err = NewACMEClientGetEmail(server.Config{}, allowPrompts) - if err != nil { - return err - } - client.Configure("") // TODO: Bind address of relevant listener, yuck + if cert.Config == nil { + log.Printf("[ERROR] %s: No associated TLS config; unable to renew", name) + continue } - err := client.Renew(cert.Names[0]) // managed certs better have only one name + // this works well because managed certs are only associated with one name per config + err := cert.Config.RenewCert(allowPrompts) + if err != nil { - if client.AllowPrompts && timeLeft < 0 { + if allowPrompts && timeLeft < 0 { // Certificate renewal failed, the operator is present, and the certificate // is already expired; we should stop immediately and return the error. Note // that we used to do this any time a renewal failed at startup. However, @@ -100,7 +108,7 @@ func renewManagedCertificates(allowPrompts bool) (err error) { return err } log.Printf("[ERROR] %v", err) - if cert.OnDemand { + if cert.Config.OnDemand { deleted = append(deleted, cert) } } else { @@ -113,20 +121,21 @@ func renewManagedCertificates(allowPrompts bool) (err error) { // Apply changes to the cache for _, cert := range renewed { if cert.Names[len(cert.Names)-1] == "" { - // Special case: This is the default certificate, so we must - // ensure it gets updated as well, otherwise the renewal - // routine will find it and think it still needs to be renewed, - // even though we already renewed it... + // Special case: This is the default certificate. We must + // flush it out of the cache so that we no longer point to + // the old, un-renewed certificate. Otherwise it will be + // renewed on every scan, which is too often. When we cache + // this certificate in a moment, it will be the default again. certCacheMu.Lock() delete(certCache, "") certCacheMu.Unlock() } - _, err := cacheManagedCertificate(cert.Names[0], cert.OnDemand) + _, err := CacheManagedCertificate(cert.Names[0], cert.Config) if err != nil { - if client.AllowPrompts { + if allowPrompts { return err // operator is present, so report error immediately } - log.Printf("[ERROR] Caching renewed certificate: %v", err) + log.Printf("[ERROR] %v", err) } } for _, cert := range deleted { @@ -140,7 +149,9 @@ func renewManagedCertificates(allowPrompts bool) (err error) { return nil } -func updateOCSPStaples() { +// UpdateOCSPStaples updates the OCSP stapling in all +// eligible, cached certificates. +func UpdateOCSPStaples() { // Create a temporary place to store updates // until we release the potentially long-lived // read lock and use a short-lived write lock. @@ -186,7 +197,7 @@ func updateOCSPStaples() { err := stapleOCSP(&cert, nil) if err != nil { if cert.OCSP != nil { - // if it was no staple before, that's fine, otherwise we should log the error + // if there was no staple before, that's fine; otherwise we should log the error log.Printf("[ERROR] Checking OCSP for %v: %v", cert.Names, err) } continue @@ -215,6 +226,3 @@ func updateOCSPStaples() { } certCacheMu.Unlock() } - -// renewDurationBefore is how long before expiration to renew certificates. -const renewDurationBefore = (24 * time.Hour) * 30 diff --git a/caddytls/setup.go b/caddytls/setup.go new file mode 100644 index 000000000..f9cfc9847 --- /dev/null +++ b/caddytls/setup.go @@ -0,0 +1,278 @@ +package caddytls + +import ( + "bytes" + "crypto/tls" + "encoding/pem" + "fmt" + "io/ioutil" + "log" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/mholt/caddy" +) + +func init() { + caddy.RegisterPlugin(caddy.Plugin{ + Name: "tls", + Action: setupTLS, + }) +} + +// setupTLS sets up the TLS configuration and installs certificates that +// are specified by the user in the config file. All the automatic HTTPS +// stuff comes later outside of this function. +func setupTLS(c *caddy.Controller) error { + configGetter, ok := configGetters[c.ServerType()] + if !ok { + return fmt.Errorf("no caddytls.ConfigGetter for %s server type; must call RegisterConfigGetter", c.ServerType()) + } + config := configGetter(c.Key) + if config == nil { + return fmt.Errorf("no caddytls.Config to set up for %s", c.Key) + } + + config.Enabled = true + + for c.Next() { + var certificateFile, keyFile, loadDir, maxCerts string + + args := c.RemainingArgs() + switch len(args) { + case 1: + // even if the email is one of the special values below, + // it is still necessary for future analysis that we store + // that value in the ACMEEmail field. + config.ACMEEmail = args[0] + + // user can force-disable managed TLS this way + if args[0] == "off" { + config.Enabled = false + return nil + } + + // user might want a temporary, in-memory, self-signed cert + if args[0] == "self_signed" { + config.SelfSigned = true + } + case 2: + certificateFile = args[0] + keyFile = args[1] + config.Manual = true + } + + // Optional block with extra parameters + var hadBlock bool + for c.NextBlock() { + hadBlock = true + switch c.Val() { + case "key_type": + arg := c.RemainingArgs() + value, ok := supportedKeyTypes[strings.ToUpper(arg[0])] + if !ok { + return c.Errf("Wrong key type name or key type not supported: '%s'", c.Val()) + } + config.KeyType = value + case "protocols": + args := c.RemainingArgs() + if len(args) != 2 { + return c.ArgErr() + } + value, ok := supportedProtocols[strings.ToLower(args[0])] + if !ok { + return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[0]) + } + config.ProtocolMinVersion = value + value, ok = supportedProtocols[strings.ToLower(args[1])] + if !ok { + return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[1]) + } + config.ProtocolMaxVersion = value + case "ciphers": + for c.NextArg() { + value, ok := supportedCiphersMap[strings.ToUpper(c.Val())] + if !ok { + return c.Errf("Wrong cipher name or cipher not supported: '%s'", c.Val()) + } + config.Ciphers = append(config.Ciphers, value) + } + case "clients": + clientCertList := c.RemainingArgs() + if len(clientCertList) == 0 { + return c.ArgErr() + } + + listStart, mustProvideCA := 1, true + switch clientCertList[0] { + case "request": + config.ClientAuth = tls.RequestClientCert + mustProvideCA = false + case "require": + config.ClientAuth = tls.RequireAnyClientCert + mustProvideCA = false + case "verify_if_given": + config.ClientAuth = tls.VerifyClientCertIfGiven + default: + config.ClientAuth = tls.RequireAndVerifyClientCert + listStart = 0 + } + if mustProvideCA && len(clientCertList) <= listStart { + return c.ArgErr() + } + + config.ClientCerts = clientCertList[listStart:] + case "load": + c.Args(&loadDir) + config.Manual = true + case "max_certs": + c.Args(&maxCerts) + config.OnDemand = true + case "dns": + args := c.RemainingArgs() + if len(args) != 1 { + return c.ArgErr() + } + dnsProvName := args[0] + if _, ok := dnsProviders[dnsProvName]; !ok { + return c.Errf("Unsupported DNS provider '%s'", args[0]) + } + config.DNSProvider = args[0] + default: + return c.Errf("Unknown keyword '%s'", c.Val()) + } + } + + // tls requires at least one argument if a block is not opened + if len(args) == 0 && !hadBlock { + return c.ArgErr() + } + + // set certificate limit if on-demand TLS is enabled + if maxCerts != "" { + maxCertsNum, err := strconv.Atoi(maxCerts) + if err != nil || maxCertsNum < 1 { + return c.Err("max_certs must be a positive integer") + } + if onDemandMaxIssue == 0 || int32(maxCertsNum) < onDemandMaxIssue { // keep the minimum; TODO: We have to do this because it is global; should be per-server or per-vhost... + onDemandMaxIssue = int32(maxCertsNum) + } + } + + // don't try to load certificates unless we're supposed to + if !config.Enabled || !config.Manual { + continue + } + + // load a single certificate and key, if specified + if certificateFile != "" && keyFile != "" { + err := cacheUnmanagedCertificatePEMFile(certificateFile, keyFile) + if err != nil { + return c.Errf("Unable to load certificate and key files for '%s': %v", c.Key, err) + } + log.Printf("[INFO] Successfully loaded TLS assets from %s and %s", certificateFile, keyFile) + } + + // load a directory of certificates, if specified + if loadDir != "" { + err := loadCertsInDir(c, loadDir) + if err != nil { + return err + } + } + } + + SetDefaultTLSParams(config) + + // generate self-signed cert if needed + if config.SelfSigned { + err := makeSelfSignedCert(config) + if err != nil { + return fmt.Errorf("self-signed: %v", err) + } + } + + return nil +} + +// loadCertsInDir loads all the certificates/keys in dir, as long as +// the file ends with .pem. This method of loading certificates is +// modeled after haproxy, which expects the certificate and key to +// be bundled into the same file: +// https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt +// +// This function may write to the log as it walks the directory tree. +func loadCertsInDir(c *caddy.Controller, dir string) error { + return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + log.Printf("[WARNING] Unable to traverse into %s; skipping", path) + return nil + } + if info.IsDir() { + return nil + } + if strings.HasSuffix(strings.ToLower(info.Name()), ".pem") { + certBuilder, keyBuilder := new(bytes.Buffer), new(bytes.Buffer) + var foundKey bool // use only the first key in the file + + bundle, err := ioutil.ReadFile(path) + if err != nil { + return err + } + + for { + // Decode next block so we can see what type it is + var derBlock *pem.Block + derBlock, bundle = pem.Decode(bundle) + if derBlock == nil { + break + } + + if derBlock.Type == "CERTIFICATE" { + // Re-encode certificate as PEM, appending to certificate chain + pem.Encode(certBuilder, derBlock) + } else if derBlock.Type == "EC PARAMETERS" { + // EC keys generated from openssl can be composed of two blocks: + // parameters and key (parameter block should come first) + if !foundKey { + // Encode parameters + pem.Encode(keyBuilder, derBlock) + + // Key must immediately follow + derBlock, bundle = pem.Decode(bundle) + if derBlock == nil || derBlock.Type != "EC PRIVATE KEY" { + return c.Errf("%s: expected elliptic private key to immediately follow EC parameters", path) + } + pem.Encode(keyBuilder, derBlock) + foundKey = true + } + } else if derBlock.Type == "PRIVATE KEY" || strings.HasSuffix(derBlock.Type, " PRIVATE KEY") { + // RSA key + if !foundKey { + pem.Encode(keyBuilder, derBlock) + foundKey = true + } + } else { + return c.Errf("%s: unrecognized PEM block type: %s", path, derBlock.Type) + } + } + + certPEMBytes, keyPEMBytes := certBuilder.Bytes(), keyBuilder.Bytes() + if len(certPEMBytes) == 0 { + return c.Errf("%s: failed to parse PEM data", path) + } + if len(keyPEMBytes) == 0 { + return c.Errf("%s: no private key block found", path) + } + + err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes) + if err != nil { + return c.Errf("%s: failed to load cert and key for '%s': %v", path, c.Key, err) + } + log.Printf("[INFO] Successfully loaded TLS assets from %s", path) + } + return nil + }) +} diff --git a/caddy/https/setup_test.go b/caddytls/setup_test.go similarity index 70% rename from caddy/https/setup_test.go rename to caddytls/setup_test.go index 59a772c45..ceb1c2a4a 100644 --- a/caddy/https/setup_test.go +++ b/caddytls/setup_test.go @@ -1,4 +1,4 @@ -package https +package caddytls import ( "crypto/tls" @@ -7,7 +7,7 @@ import ( "os" "testing" - "github.com/mholt/caddy/caddy/setup" + "github.com/mholt/caddy" "github.com/xenolf/lego/acme" ) @@ -32,32 +32,29 @@ func TestMain(m *testing.M) { } func TestSetupParseBasic(t *testing.T) { - c := setup.NewTestController(`tls ` + certFile + ` ` + keyFile + ``) + cfg := new(Config) + RegisterConfigGetter("", func(key string) *Config { return cfg }) + c := caddy.NewTestController(`tls ` + certFile + ` ` + keyFile + ``) - _, err := Setup(c) + err := setupTLS(c) if err != nil { t.Errorf("Expected no errors, got: %v", err) } // Basic checks - if !c.TLS.Manual { + if !cfg.Manual { t.Error("Expected TLS Manual=true, but was false") } - if !c.TLS.Enabled { + if !cfg.Enabled { t.Error("Expected TLS Enabled=true, but was false") } // Security defaults - if c.TLS.ProtocolMinVersion != tls.VersionTLS10 { - t.Errorf("Expected 'tls1.0 (0x0301)' as ProtocolMinVersion, got %#v", c.TLS.ProtocolMinVersion) + if cfg.ProtocolMinVersion != tls.VersionTLS11 { + t.Errorf("Expected 'tls1.1 (0x0302)' as ProtocolMinVersion, got %#v", cfg.ProtocolMinVersion) } - if c.TLS.ProtocolMaxVersion != tls.VersionTLS12 { - t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMaxVersion, got %v", c.TLS.ProtocolMaxVersion) - } - - // KeyType default - if KeyType != acme.RSA2048 { - t.Errorf("Expected '2048' as KeyType, got %#v", KeyType) + if cfg.ProtocolMaxVersion != tls.VersionTLS12 { + t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMaxVersion, got %v", cfg.ProtocolMaxVersion) } // Cipher checks @@ -76,27 +73,27 @@ func TestSetupParseBasic(t *testing.T) { } // Ensure count is correct (plus one for TLS_FALLBACK_SCSV) - if len(c.TLS.Ciphers) != len(expectedCiphers) { + if len(cfg.Ciphers) != len(expectedCiphers) { t.Errorf("Expected %v Ciphers (including TLS_FALLBACK_SCSV), got %v", - len(expectedCiphers), len(c.TLS.Ciphers)) + len(expectedCiphers), len(cfg.Ciphers)) } // Ensure ordering is correct - for i, actual := range c.TLS.Ciphers { + for i, actual := range cfg.Ciphers { if actual != expectedCiphers[i] { t.Errorf("Expected cipher in position %d to be %0x, got %0x", i, expectedCiphers[i], actual) } } - if !c.TLS.PreferServerCipherSuites { + if !cfg.PreferServerCipherSuites { t.Error("Expected PreferServerCipherSuites = true, but was false") } } func TestSetupParseIncompleteParams(t *testing.T) { // Using tls without args is an error because it's unnecessary. - c := setup.NewTestController(`tls`) - _, err := Setup(c) + c := caddy.NewTestController(`tls`) + err := setupTLS(c) if err == nil { t.Error("Expected an error, but didn't get one") } @@ -104,26 +101,28 @@ func TestSetupParseIncompleteParams(t *testing.T) { func TestSetupParseWithOptionalParams(t *testing.T) { params := `tls ` + certFile + ` ` + keyFile + ` { - protocols ssl3.0 tls1.2 + protocols tls1.0 tls1.2 ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384 }` - c := setup.NewTestController(params) + cfg := new(Config) + RegisterConfigGetter("", func(key string) *Config { return cfg }) + c := caddy.NewTestController(params) - _, err := Setup(c) + err := setupTLS(c) if err != nil { t.Errorf("Expected no errors, got: %v", err) } - if c.TLS.ProtocolMinVersion != tls.VersionSSL30 { - t.Errorf("Expected 'ssl3.0 (0x0300)' as ProtocolMinVersion, got %#v", c.TLS.ProtocolMinVersion) + if cfg.ProtocolMinVersion != tls.VersionTLS10 { + t.Errorf("Expected 'tls1.0 (0x0301)' as ProtocolMinVersion, got %#v", cfg.ProtocolMinVersion) } - if c.TLS.ProtocolMaxVersion != tls.VersionTLS12 { - t.Errorf("Expected 'tls1.2 (0x0302)' as ProtocolMaxVersion, got %#v", c.TLS.ProtocolMaxVersion) + if cfg.ProtocolMaxVersion != tls.VersionTLS12 { + t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMaxVersion, got %#v", cfg.ProtocolMaxVersion) } - if len(c.TLS.Ciphers)-1 != 3 { - t.Errorf("Expected 3 Ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1) + if len(cfg.Ciphers)-1 != 3 { + t.Errorf("Expected 3 Ciphers (not including TLS_FALLBACK_SCSV), got %v", len(cfg.Ciphers)-1) } } @@ -131,38 +130,28 @@ func TestSetupDefaultWithOptionalParams(t *testing.T) { params := `tls { ciphers RSA-3DES-EDE-CBC-SHA }` - c := setup.NewTestController(params) + cfg := new(Config) + RegisterConfigGetter("", func(key string) *Config { return cfg }) + c := caddy.NewTestController(params) - _, err := Setup(c) + err := setupTLS(c) if err != nil { t.Errorf("Expected no errors, got: %v", err) } - if len(c.TLS.Ciphers)-1 != 1 { - t.Errorf("Expected 1 ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1) + if len(cfg.Ciphers)-1 != 1 { + t.Errorf("Expected 1 ciphers (not including TLS_FALLBACK_SCSV), got %v", len(cfg.Ciphers)-1) } } -// TODO: If we allow this... but probably not a good idea. -// func TestSetupDisableHTTPRedirect(t *testing.T) { -// c := NewTestController(`tls { -// allow_http -// }`) -// _, err := TLS(c) -// if err != nil { -// t.Errorf("Expected no error, but got %v", err) -// } -// if !c.TLS.DisableHTTPRedir { -// t.Error("Expected HTTP redirect to be disabled, but it wasn't") -// } -// } - func TestSetupParseWithWrongOptionalParams(t *testing.T) { // Test protocols wrong params params := `tls ` + certFile + ` ` + keyFile + ` { protocols ssl tls }` - c := setup.NewTestController(params) - _, err := Setup(c) + cfg := new(Config) + RegisterConfigGetter("", func(key string) *Config { return cfg }) + c := caddy.NewTestController(params) + err := setupTLS(c) if err == nil { t.Errorf("Expected errors, but no error returned") } @@ -171,8 +160,10 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { params = `tls ` + certFile + ` ` + keyFile + ` { ciphers not-valid-cipher }` - c = setup.NewTestController(params) - _, err = Setup(c) + cfg = new(Config) + RegisterConfigGetter("", func(key string) *Config { return cfg }) + c = caddy.NewTestController(params) + err = setupTLS(c) if err == nil { t.Errorf("Expected errors, but no error returned") } @@ -181,8 +172,10 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { params = `tls { key_type ab123 }` - c = setup.NewTestController(params) - _, err = Setup(c) + cfg = new(Config) + RegisterConfigGetter("", func(key string) *Config { return cfg }) + c = caddy.NewTestController(params) + err = setupTLS(c) if err == nil { t.Errorf("Expected errors, but no error returned") } @@ -193,8 +186,10 @@ func TestSetupParseWithClientAuth(t *testing.T) { params := `tls ` + certFile + ` ` + keyFile + ` { clients }` - c := setup.NewTestController(params) - _, err := Setup(c) + cfg := new(Config) + RegisterConfigGetter("", func(key string) *Config { return cfg }) + c := caddy.NewTestController(params) + err := setupTLS(c) if err == nil { t.Errorf("Expected an error, but no error returned") } @@ -224,8 +219,10 @@ func TestSetupParseWithClientAuth(t *testing.T) { clients verify_if_given }`, tls.VerifyClientCertIfGiven, true, noCAs}, } { - c := setup.NewTestController(caseData.params) - _, err := Setup(c) + cfg := new(Config) + RegisterConfigGetter("", func(key string) *Config { return cfg }) + c := caddy.NewTestController(caseData.params) + err := setupTLS(c) if caseData.expectedErr { if err == nil { t.Errorf("In case %d: Expected an error, got: %v", caseNumber, err) @@ -236,17 +233,17 @@ func TestSetupParseWithClientAuth(t *testing.T) { t.Errorf("In case %d: Expected no errors, got: %v", caseNumber, err) } - if caseData.clientAuthType != c.TLS.ClientAuth { + if caseData.clientAuthType != cfg.ClientAuth { t.Errorf("In case %d: Expected TLS client auth type %v, got: %v", - caseNumber, caseData.clientAuthType, c.TLS.ClientAuth) + caseNumber, caseData.clientAuthType, cfg.ClientAuth) } - if count := len(c.TLS.ClientCerts); count < len(caseData.expectedCAs) { + if count := len(cfg.ClientCerts); count < len(caseData.expectedCAs) { t.Fatalf("In case %d: Expected %d client certs, had %d", caseNumber, len(caseData.expectedCAs), count) } for idx, expected := range caseData.expectedCAs { - if actual := c.TLS.ClientCerts[idx]; actual != expected { + if actual := cfg.ClientCerts[idx]; actual != expected { t.Errorf("In case %d: Expected %dth client cert file to be '%s', but was '%s'", caseNumber, idx, expected, actual) } @@ -258,15 +255,17 @@ func TestSetupParseWithKeyType(t *testing.T) { params := `tls { key_type p384 }` - c := setup.NewTestController(params) + cfg := new(Config) + RegisterConfigGetter("", func(key string) *Config { return cfg }) + c := caddy.NewTestController(params) - _, err := Setup(c) + err := setupTLS(c) if err != nil { t.Errorf("Expected no errors, got: %v", err) } - if KeyType != acme.EC384 { - t.Errorf("Expected 'P384' as KeyType, got %#v", KeyType) + if cfg.KeyType != acme.EC384 { + t.Errorf("Expected 'P384' as KeyType, got %#v", cfg.KeyType) } } diff --git a/caddy/https/storage.go b/caddytls/storage.go similarity index 62% rename from caddy/https/storage.go rename to caddytls/storage.go index 5d487837f..1a00a9de7 100644 --- a/caddy/https/storage.go +++ b/caddytls/storage.go @@ -1,19 +1,48 @@ -package https +package caddytls import ( + "fmt" + "net/url" "path/filepath" "strings" - "github.com/mholt/caddy/caddy/assets" + "github.com/mholt/caddy" ) -// storage is used to get file paths in a consistent, -// cross-platform way for persisting Let's Encrypt assets -// on the file system. -var storage = Storage(filepath.Join(assets.Path(), "letsencrypt")) +// StorageFor gets the storage value associated with the +// caURL, which should be unique for every different +// ACME CA. +func StorageFor(caURL string) (Storage, error) { + if caURL == "" { + caURL = DefaultCAUrl + } + if caURL == "" { + return "", fmt.Errorf("cannot create storage without CA URL") + } + caURL = strings.ToLower(caURL) + + // scheme required or host will be parsed as path (as of Go 1.6) + if !strings.Contains(caURL, "://") { + caURL = "https://" + caURL + } + + u, err := url.Parse(caURL) + if err != nil { + return "", fmt.Errorf("%s: unable to parse CA URL: %v", caURL, err) + } + + if u.Host == "" { + return "", fmt.Errorf("%s: no host in CA URL", caURL) + } + + return Storage(filepath.Join(storageBasePath, u.Host)), nil +} // Storage is a root directory and facilitates -// forming file paths derived from it. +// forming file paths derived from it. It is used +// to get file paths in a consistent, cross- +// platform way for persisting ACME assets. +// on the file system. type Storage string // Sites gets the directory that stores site certificate and keys. @@ -23,21 +52,25 @@ func (s Storage) Sites() string { // Site returns the path to the folder containing assets for domain. func (s Storage) Site(domain string) string { + domain = strings.ToLower(domain) return filepath.Join(s.Sites(), domain) } // SiteCertFile returns the path to the certificate file for domain. func (s Storage) SiteCertFile(domain string) string { + domain = strings.ToLower(domain) return filepath.Join(s.Site(domain), domain+".crt") } // SiteKeyFile returns the path to domain's private key file. func (s Storage) SiteKeyFile(domain string) string { + domain = strings.ToLower(domain) return filepath.Join(s.Site(domain), domain+".key") } // SiteMetaFile returns the path to the domain's asset metadata file. func (s Storage) SiteMetaFile(domain string) string { + domain = strings.ToLower(domain) return filepath.Join(s.Site(domain), domain+".json") } @@ -51,6 +84,7 @@ func (s Storage) User(email string) string { if email == "" { email = emptyEmail } + email = strings.ToLower(email) return filepath.Join(s.Users(), email) } @@ -60,6 +94,7 @@ func (s Storage) UserRegFile(email string) string { if email == "" { email = emptyEmail } + email = strings.ToLower(email) fileName := emailUsername(email) if fileName == "" { fileName = "registration" @@ -73,6 +108,7 @@ func (s Storage) UserKeyFile(email string) string { if email == "" { email = emptyEmail } + email = strings.ToLower(email) fileName := emailUsername(email) if fileName == "" { fileName = "private" @@ -92,3 +128,7 @@ func emailUsername(email string) string { } return email[:at] } + +// storageBasePath is the root path in which all TLS/ACME assets are +// stored. Do not change this value during the lifetime of the program. +var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme") diff --git a/caddy/https/storage_test.go b/caddytls/storage_test.go similarity index 62% rename from caddy/https/storage_test.go rename to caddytls/storage_test.go index 85c2220eb..e9175af96 100644 --- a/caddy/https/storage_test.go +++ b/caddytls/storage_test.go @@ -1,35 +1,82 @@ -package https +package caddytls import ( "path/filepath" "testing" ) +func TestStorageFor(t *testing.T) { + // first try without DefaultCAUrl set + DefaultCAUrl = "" + _, err := StorageFor("") + if err == nil { + t.Errorf("Without a default CA, expected error, but didn't get one") + } + st, err := StorageFor("https://example.com/foo") + if err != nil { + t.Errorf("Without a default CA but given input, expected no error, but got: %v", err) + } + if string(st) != filepath.Join(storageBasePath, "example.com") { + t.Errorf("Without a default CA but given input, expected '%s' not '%s'", "example.com", st) + } + + // try with the DefaultCAUrl set + DefaultCAUrl = "https://defaultCA/directory" + for i, test := range []struct { + input, expect string + shouldErr bool + }{ + {"https://acme-staging.api.letsencrypt.org/directory", "acme-staging.api.letsencrypt.org", false}, + {"https://foo/boo?bar=q", "foo", false}, + {"http://foo", "foo", false}, + {"", "defaultca", false}, + {"https://FooBar/asdf", "foobar", false}, + {"noscheme/path", "noscheme", false}, + {"/nohost", "", true}, + {"https:///nohost", "", true}, + {"FooBar", "foobar", false}, + } { + st, err := StorageFor(test.input) + if err == nil && test.shouldErr { + t.Errorf("Test %d: Expected an error, but didn't get one", i) + } else if err != nil && !test.shouldErr { + t.Errorf("Test %d: Expected no errors, but got: %v", i, err) + } + want := filepath.Join(storageBasePath, test.expect) + if test.shouldErr { + want = "" + } + if string(st) != want { + t.Errorf("Test %d: Expected '%s' but got '%s'", i, want, string(st)) + } + } +} + func TestStorage(t *testing.T) { - storage = Storage("./le_test") + storage := Storage("./le_test") if expected, actual := filepath.Join("le_test", "sites"), storage.Sites(); actual != expected { t.Errorf("Expected Sites() to return '%s' but got '%s'", expected, actual) } - if expected, actual := filepath.Join("le_test", "sites", "test.com"), storage.Site("test.com"); actual != expected { + if expected, actual := filepath.Join("le_test", "sites", "test.com"), storage.Site("Test.com"); actual != expected { t.Errorf("Expected Site() to return '%s' but got '%s'", expected, actual) } - if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.crt"), storage.SiteCertFile("test.com"); actual != expected { + if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.crt"), storage.SiteCertFile("Test.com"); actual != expected { t.Errorf("Expected SiteCertFile() to return '%s' but got '%s'", expected, actual) } if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.key"), storage.SiteKeyFile("test.com"); actual != expected { t.Errorf("Expected SiteKeyFile() to return '%s' but got '%s'", expected, actual) } - if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.json"), storage.SiteMetaFile("test.com"); actual != expected { + if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.json"), storage.SiteMetaFile("TEST.COM"); actual != expected { t.Errorf("Expected SiteMetaFile() to return '%s' but got '%s'", expected, actual) } if expected, actual := filepath.Join("le_test", "users"), storage.Users(); actual != expected { t.Errorf("Expected Users() to return '%s' but got '%s'", expected, actual) } - if expected, actual := filepath.Join("le_test", "users", "me@example.com"), storage.User("me@example.com"); actual != expected { + if expected, actual := filepath.Join("le_test", "users", "me@example.com"), storage.User("Me@example.com"); actual != expected { t.Errorf("Expected User() to return '%s' but got '%s'", expected, actual) } - if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.json"), storage.UserRegFile("me@example.com"); actual != expected { + if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.json"), storage.UserRegFile("ME@EXAMPLE.COM"); actual != expected { t.Errorf("Expected UserRegFile() to return '%s' but got '%s'", expected, actual) } if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.key"), storage.UserKeyFile("me@example.com"); actual != expected { diff --git a/caddytls/tls.go b/caddytls/tls.go new file mode 100644 index 000000000..601556a11 --- /dev/null +++ b/caddytls/tls.go @@ -0,0 +1,187 @@ +// Package caddytls facilitates the management of TLS assets and integrates +// Let's Encrypt functionality into Caddy with first-class support for +// creating and renewing certificates automatically. +package caddytls + +import ( + "encoding/json" + "io/ioutil" + "net" + "os" + "strings" + + "github.com/xenolf/lego/acme" +) + +// HostQualifies returns true if the hostname alone +// appears eligible for automatic HTTPS. For example, +// localhost, empty hostname, and IP addresses are +// not eligible because we cannot obtain certificates +// for those names. +func HostQualifies(hostname string) bool { + return hostname != "localhost" && // localhost is ineligible + + // hostname must not be empty + strings.TrimSpace(hostname) != "" && + + // must not contain wildcard (*) characters (until CA supports it) + !strings.Contains(hostname, "*") && + + // must not start or end with a dot + !strings.HasPrefix(hostname, ".") && + !strings.HasSuffix(hostname, ".") && + + // cannot be an IP address, see + // https://community.letsencrypt.org/t/certificate-for-static-ip/84/2?u=mholt + net.ParseIP(hostname) == nil +} + +// existingCertAndKey returns true if the hostname has +// a certificate and private key in storage already under +// the storage provided, otherwise it returns false. +func existingCertAndKey(storage Storage, hostname string) bool { + _, err := os.Stat(storage.SiteCertFile(hostname)) + if err != nil { + return false + } + _, err = os.Stat(storage.SiteKeyFile(hostname)) + if err != nil { + return false + } + return true +} + +// saveCertResource saves the certificate resource to disk. This +// includes the certificate file itself, the private key, and the +// metadata file. +func saveCertResource(storage Storage, cert acme.CertificateResource) error { + err := os.MkdirAll(storage.Site(cert.Domain), 0700) + if err != nil { + return err + } + + // Save cert + err = ioutil.WriteFile(storage.SiteCertFile(cert.Domain), cert.Certificate, 0600) + if err != nil { + return err + } + + // Save private key + err = ioutil.WriteFile(storage.SiteKeyFile(cert.Domain), cert.PrivateKey, 0600) + if err != nil { + return err + } + + // Save cert metadata + jsonBytes, err := json.MarshalIndent(&cert, "", "\t") + if err != nil { + return err + } + err = ioutil.WriteFile(storage.SiteMetaFile(cert.Domain), jsonBytes, 0600) + if err != nil { + return err + } + + return nil +} + +// Revoke revokes the certificate for host via ACME protocol. +// It assumes the certificate was obtained from the +// CA at DefaultCAUrl. +func Revoke(host string) error { + client, err := newACMEClient(new(Config), true) + if err != nil { + return err + } + return client.Revoke(host) +} + +// tlsSniSolver is a type that can solve tls-sni challenges using +// an existing listener and our custom, in-memory certificate cache. +type tlsSniSolver struct{} + +// Present adds the challenge certificate to the cache. +func (s tlsSniSolver) Present(domain, token, keyAuth string) error { + cert, err := acme.TLSSNI01ChallengeCert(keyAuth) + if err != nil { + return err + } + cacheCertificate(Certificate{ + Certificate: cert, + Names: []string{domain}, + }) + return nil +} + +// CleanUp removes the challenge certificate from the cache. +func (s tlsSniSolver) CleanUp(domain, token, keyAuth string) error { + uncacheCertificate(domain) + return nil +} + +// ConfigHolder is any type that has a Config; it presumably is +// connected to a hostname and port on which it is serving. +type ConfigHolder interface { + TLSConfig() *Config + Host() string + Port() string +} + +// QualifiesForManagedTLS returns true if c qualifies for +// for managed TLS (but not on-demand TLS specifically). +// It does NOT check to see if a cert and key already exist +// for the config. If the return value is true, you should +// be OK to set c.TLSConfig().Managed to true; then you should +// check that value in the future instead, because the process +// of setting up the config may make it look like it doesn't +// qualify even though it originally did. +func QualifiesForManagedTLS(c ConfigHolder) bool { + if c == nil { + return false + } + tlsConfig := c.TLSConfig() + if tlsConfig == nil { + return false + } + + return (!tlsConfig.Manual || tlsConfig.OnDemand) && // user might provide own cert and key + + // if self-signed, we've already generated one to use + !tlsConfig.SelfSigned && + + // user can force-disable managed TLS + c.Port() != "80" && + tlsConfig.ACMEEmail != "off" && + + // we get can't certs for some kinds of hostnames, but + // on-demand TLS allows empty hostnames at startup + (HostQualifies(c.Host()) || tlsConfig.OnDemand) +} + +// DNSProviderConstructor is a function that takes credentials and +// returns a type that can solve the ACME DNS challenges. +type DNSProviderConstructor func(credentials ...string) (acme.ChallengeProvider, error) + +// dnsProviders is the list of DNS providers that have been plugged in. +var dnsProviders = make(map[string]DNSProviderConstructor) + +// RegisterDNSProvider registers provider by name for solving the ACME DNS challenge. +func RegisterDNSProvider(name string, provider DNSProviderConstructor) { + dnsProviders[name] = provider +} + +var ( + // DefaultEmail represents the Let's Encrypt account email to use if none provided. + DefaultEmail string + + // Agreed indicates whether user has agreed to the Let's Encrypt SA. + Agreed bool + + // DefaultCAUrl is the default URL to the CA's ACME directory endpoint. + // It's very important to set this unless you set it in every Config. + DefaultCAUrl string + + // DefaultKeyType is used as the type of key for new certificates + // when no other key type is specified. + DefaultKeyType = acme.RSA2048 +) diff --git a/caddytls/tls_test.go b/caddytls/tls_test.go new file mode 100644 index 000000000..c46e24947 --- /dev/null +++ b/caddytls/tls_test.go @@ -0,0 +1,165 @@ +package caddytls + +import ( + "io/ioutil" + "os" + "testing" + + "github.com/xenolf/lego/acme" +) + +func TestHostQualifies(t *testing.T) { + for i, test := range []struct { + host string + expect bool + }{ + {"example.com", true}, + {"sub.example.com", true}, + {"Sub.Example.COM", true}, + {"127.0.0.1", false}, + {"127.0.1.5", false}, + {"69.123.43.94", false}, + {"::1", false}, + {"::", false}, + {"0.0.0.0", false}, + {"", false}, + {" ", false}, + {"*.example.com", false}, + {".com", false}, + {"example.com.", false}, + {"localhost", false}, + {"local", true}, + {"devsite", true}, + {"192.168.1.3", false}, + {"10.0.2.1", false}, + {"169.112.53.4", false}, + } { + actual := HostQualifies(test.host) + if actual != test.expect { + t.Errorf("Test %d: Expected HostQualifies(%s)=%v, but got %v", + i, test.host, test.expect, actual) + } + } +} + +type holder struct { + host, port string + cfg *Config +} + +func (h holder) TLSConfig() *Config { return h.cfg } +func (h holder) Host() string { return h.host } +func (h holder) Port() string { return h.port } + +func TestQualifiesForManagedTLS(t *testing.T) { + for i, test := range []struct { + cfg ConfigHolder + expect bool + }{ + {holder{host: ""}, false}, + {holder{host: "localhost"}, false}, + {holder{host: "123.44.3.21"}, false}, + {holder{host: "example.com"}, false}, + {holder{host: "", cfg: new(Config)}, false}, + {holder{host: "localhost", cfg: new(Config)}, false}, + {holder{host: "123.44.3.21", cfg: new(Config)}, false}, + {holder{host: "example.com", cfg: new(Config)}, true}, + {holder{host: "*.example.com", cfg: new(Config)}, false}, + {holder{host: "example.com", cfg: &Config{Manual: true}}, false}, + {holder{host: "example.com", cfg: &Config{ACMEEmail: "off"}}, false}, + {holder{host: "example.com", cfg: &Config{ACMEEmail: "foo@bar.com"}}, true}, + {holder{host: "example.com", port: "80"}, false}, + {holder{host: "example.com", port: "1234", cfg: new(Config)}, true}, + {holder{host: "example.com", port: "443", cfg: new(Config)}, true}, + {holder{host: "example.com", port: "80"}, false}, + } { + if got, want := QualifiesForManagedTLS(test.cfg), test.expect; got != want { + t.Errorf("Test %d: Expected %v but got %v", i, want, got) + } + } +} + +func TestSaveCertResource(t *testing.T) { + storage := Storage("./le_test_save") + defer func() { + err := os.RemoveAll(string(storage)) + if err != nil { + t.Fatalf("Could not remove temporary storage directory (%s): %v", storage, err) + } + }() + + domain := "example.com" + certContents := "certificate" + keyContents := "private key" + metaContents := `{ + "domain": "example.com", + "certUrl": "https://example.com/cert", + "certStableUrl": "https://example.com/cert/stable" +}` + + cert := acme.CertificateResource{ + Domain: domain, + CertURL: "https://example.com/cert", + CertStableURL: "https://example.com/cert/stable", + PrivateKey: []byte(keyContents), + Certificate: []byte(certContents), + } + + err := saveCertResource(storage, cert) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + certFile, err := ioutil.ReadFile(storage.SiteCertFile(domain)) + if err != nil { + t.Errorf("Expected no error reading certificate file, got: %v", err) + } + if string(certFile) != certContents { + t.Errorf("Expected certificate file to contain '%s', got '%s'", certContents, string(certFile)) + } + + keyFile, err := ioutil.ReadFile(storage.SiteKeyFile(domain)) + if err != nil { + t.Errorf("Expected no error reading private key file, got: %v", err) + } + if string(keyFile) != keyContents { + t.Errorf("Expected private key file to contain '%s', got '%s'", keyContents, string(keyFile)) + } + + metaFile, err := ioutil.ReadFile(storage.SiteMetaFile(domain)) + if err != nil { + t.Errorf("Expected no error reading meta file, got: %v", err) + } + if string(metaFile) != metaContents { + t.Errorf("Expected meta file to contain '%s', got '%s'", metaContents, string(metaFile)) + } +} + +func TestExistingCertAndKey(t *testing.T) { + storage := Storage("./le_test_existing") + defer func() { + err := os.RemoveAll(string(storage)) + if err != nil { + t.Fatalf("Could not remove temporary storage directory (%s): %v", storage, err) + } + }() + + domain := "example.com" + + if existingCertAndKey(storage, domain) { + t.Errorf("Did NOT expect %v to have existing cert or key, but it did", domain) + } + + err := saveCertResource(storage, acme.CertificateResource{ + Domain: domain, + PrivateKey: []byte("key"), + Certificate: []byte("cert"), + }) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if !existingCertAndKey(storage, domain) { + t.Errorf("Expected %v to have existing cert and key, but it did NOT", domain) + } +} diff --git a/caddy/https/user.go b/caddytls/user.go similarity index 80% rename from caddy/https/user.go rename to caddytls/user.go index a7e6e5f62..d10680b91 100644 --- a/caddy/https/user.go +++ b/caddytls/user.go @@ -1,4 +1,4 @@ -package https +package caddytls import ( "bufio" @@ -14,7 +14,6 @@ import ( "os" "strings" - "github.com/mholt/caddy/server" "github.com/xenolf/lego/acme" ) @@ -40,11 +39,77 @@ func (u User) GetPrivateKey() crypto.PrivateKey { return u.key } -// getUser loads the user with the given email from disk. -// If the user does not exist, it will create a new one, -// but it does NOT save new users to the disk or register -// them via ACME. It does NOT prompt the user. -func getUser(email string) (User, error) { +// newUser creates a new User for the given email address +// with a new private key. This function does NOT save the +// user to disk or register it via ACME. If you want to use +// a user account that might already exist, call getUser +// instead. It does NOT prompt the user. +func newUser(email string) (User, error) { + user := User{Email: email} + privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + return user, errors.New("error generating private key: " + err.Error()) + } + user.key = privateKey + return user, nil +} + +// getEmail does everything it can to obtain an email +// address from the user within the scope of storage +// to use for ACME TLS. If it cannot get an email +// address, it returns empty string. (It will warn the +// user of the consequences of an empty email.) This +// function MAY prompt the user for input. If userPresent +// is false, the operator will NOT be prompted and an +// empty email may be returned. +func getEmail(storage Storage, userPresent bool) string { + // First try memory (command line flag or typed by user previously) + leEmail := DefaultEmail + if leEmail == "" { + // Then try to get most recent user email + userDirs, err := ioutil.ReadDir(storage.Users()) + if err == nil { + var mostRecent os.FileInfo + for _, dir := range userDirs { + if !dir.IsDir() { + continue + } + if mostRecent == nil || dir.ModTime().After(mostRecent.ModTime()) { + leEmail = dir.Name() + DefaultEmail = leEmail // save for next time + mostRecent = dir + } + } + } + } + if leEmail == "" && userPresent { + // Alas, we must bother the user and ask for an email address; + // if they proceed they also agree to the SA. + reader := bufio.NewReader(stdin) + fmt.Println("\nYour sites will be served over HTTPS automatically using Let's Encrypt.") + fmt.Println("By continuing, you agree to the Let's Encrypt Subscriber Agreement at:") + fmt.Println(" " + saURL) // TODO: Show current SA link + fmt.Println("Please enter your email address so you can recover your account if needed.") + fmt.Println("You can leave it blank, but you'll lose the ability to recover your account.") + fmt.Print("Email address: ") + var err error + leEmail, err = reader.ReadString('\n') + if err != nil { + return "" + } + leEmail = strings.TrimSpace(leEmail) + DefaultEmail = leEmail + Agreed = true + } + return strings.ToLower(leEmail) +} + +// getUser loads the user with the given email from disk +// using the provided storage. If the user does not exist, +// it will create a new one, but it does NOT save new +// users to the disk or register them via ACME. It does +// NOT prompt the user. +func getUser(storage Storage, email string) (User, error) { var user User // open user file @@ -75,8 +140,10 @@ func getUser(email string) (User, error) { // saveUser persists a user's key and account registration // to the file system. It does NOT register the user via ACME -// or prompt the user. -func saveUser(user User) error { +// or prompt the user. You must also pass in the storage +// wherein the user should be saved. It should be the storage +// for the CA with which user has an account. +func saveUser(storage Storage, user User) error { // make user account folder err := os.MkdirAll(storage.User(user.Email), 0700) if err != nil { @@ -98,73 +165,6 @@ func saveUser(user User) error { return ioutil.WriteFile(storage.UserRegFile(user.Email), jsonBytes, 0600) } -// newUser creates a new User for the given email address -// with a new private key. This function does NOT save the -// user to disk or register it via ACME. If you want to use -// a user account that might already exist, call getUser -// instead. It does NOT prompt the user. -func newUser(email string) (User, error) { - user := User{Email: email} - privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) - if err != nil { - return user, errors.New("error generating private key: " + err.Error()) - } - user.key = privateKey - return user, nil -} - -// getEmail does everything it can to obtain an email -// address from the user to use for TLS for cfg. If it -// cannot get an email address, it returns empty string. -// (It will warn the user of the consequences of an -// empty email.) This function MAY prompt the user for -// input. If userPresent is false, the operator will -// NOT be prompted and an empty email may be returned. -func getEmail(cfg server.Config, userPresent bool) string { - // First try the tls directive from the Caddyfile - leEmail := cfg.TLS.LetsEncryptEmail - if leEmail == "" { - // Then try memory (command line flag or typed by user previously) - leEmail = DefaultEmail - } - if leEmail == "" { - // Then try to get most recent user email ~/.caddy/users file - userDirs, err := ioutil.ReadDir(storage.Users()) - if err == nil { - var mostRecent os.FileInfo - for _, dir := range userDirs { - if !dir.IsDir() { - continue - } - if mostRecent == nil || dir.ModTime().After(mostRecent.ModTime()) { - leEmail = dir.Name() - DefaultEmail = leEmail // save for next time - } - } - } - } - if leEmail == "" && userPresent { - // Alas, we must bother the user and ask for an email address; - // if they proceed they also agree to the SA. - reader := bufio.NewReader(stdin) - fmt.Println("\nYour sites will be served over HTTPS automatically using Let's Encrypt.") - fmt.Println("By continuing, you agree to the Let's Encrypt Subscriber Agreement at:") - fmt.Println(" " + saURL) // TODO: Show current SA link - fmt.Println("Please enter your email address so you can recover your account if needed.") - fmt.Println("You can leave it blank, but you'll lose the ability to recover your account.") - fmt.Print("Email address: ") - var err error - leEmail, err = reader.ReadString('\n') - if err != nil { - return "" - } - leEmail = strings.TrimSpace(leEmail) - DefaultEmail = leEmail - Agreed = true - } - return leEmail -} - // promptUserAgreement prompts the user to agree to the agreement // at agreementURL via stdin. If the agreement has changed, then pass // true as the second argument. If this is the user's first time diff --git a/caddy/https/user_test.go b/caddytls/user_test.go similarity index 74% rename from caddy/https/user_test.go rename to caddytls/user_test.go index c1d115e1f..67f730827 100644 --- a/caddy/https/user_test.go +++ b/caddytls/user_test.go @@ -1,4 +1,4 @@ -package https +package caddytls import ( "bytes" @@ -10,7 +10,6 @@ import ( "testing" "time" - "github.com/mholt/caddy/server" "github.com/xenolf/lego/acme" ) @@ -54,8 +53,7 @@ func TestNewUser(t *testing.T) { } func TestSaveUser(t *testing.T) { - storage = Storage("./testdata") - defer os.RemoveAll(string(storage)) + defer os.RemoveAll(string(testStorage)) email := "me@foobar.com" user, err := newUser(email) @@ -63,25 +61,24 @@ func TestSaveUser(t *testing.T) { t.Fatalf("Error creating user: %v", err) } - err = saveUser(user) + err = saveUser(testStorage, user) if err != nil { t.Fatalf("Error saving user: %v", err) } - _, err = os.Stat(storage.UserRegFile(email)) + _, err = os.Stat(testStorage.UserRegFile(email)) if err != nil { t.Errorf("Cannot access user registration file, error: %v", err) } - _, err = os.Stat(storage.UserKeyFile(email)) + _, err = os.Stat(testStorage.UserKeyFile(email)) if err != nil { t.Errorf("Cannot access user private key file, error: %v", err) } } func TestGetUserDoesNotAlreadyExist(t *testing.T) { - storage = Storage("./testdata") - defer os.RemoveAll(string(storage)) + defer os.RemoveAll(string(testStorage)) - user, err := getUser("user_does_not_exist@foobar.com") + user, err := getUser(testStorage, "user_does_not_exist@foobar.com") if err != nil { t.Fatalf("Error getting user: %v", err) } @@ -92,8 +89,7 @@ func TestGetUserDoesNotAlreadyExist(t *testing.T) { } func TestGetUserAlreadyExists(t *testing.T) { - storage = Storage("./testdata") - defer os.RemoveAll(string(storage)) + defer os.RemoveAll(string(testStorage)) email := "me@foobar.com" @@ -102,13 +98,13 @@ func TestGetUserAlreadyExists(t *testing.T) { if err != nil { t.Fatalf("Error creating user: %v", err) } - err = saveUser(user) + err = saveUser(testStorage, user) if err != nil { t.Fatalf("Error saving user: %v", err) } // Expect to load user from disk - user2, err := getUser(email) + user2, err := getUser(testStorage, email) if err != nil { t.Fatalf("Error getting user: %v", err) } @@ -125,48 +121,38 @@ func TestGetUserAlreadyExists(t *testing.T) { } func TestGetEmail(t *testing.T) { + storageBasePath = string(testStorage) // to contain calls that create a new Storage... + // let's not clutter up the output origStdout := os.Stdout os.Stdout = nil defer func() { os.Stdout = origStdout }() - storage = Storage("./testdata") - defer os.RemoveAll(string(storage)) + defer os.RemoveAll(string(testStorage)) DefaultEmail = "test2@foo.com" - // Test1: Use email in config - config := server.Config{ - TLS: server.TLSConfig{ - LetsEncryptEmail: "test1@foo.com", - }, - } - actual := getEmail(config, true) - if actual != "test1@foo.com" { - t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", "test1@foo.com", actual) - } - - // Test2: Use default email from flag (or user previously typing it) - actual = getEmail(server.Config{}, true) + // Test1: Use default email from flag (or user previously typing it) + actual := getEmail(testStorage, true) if actual != DefaultEmail { - t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", DefaultEmail, actual) + t.Errorf("Did not get correct email from memory; expected '%s' but got '%s'", DefaultEmail, actual) } - // Test3: Get input from user + // Test2: Get input from user DefaultEmail = "" stdin = new(bytes.Buffer) _, err := io.Copy(stdin, strings.NewReader("test3@foo.com\n")) if err != nil { t.Fatalf("Could not simulate user input, error: %v", err) } - actual = getEmail(server.Config{}, true) + actual = getEmail(testStorage, true) if actual != "test3@foo.com" { t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual) } - // Test4: Get most recent email from before + // Test3: Get most recent email from before DefaultEmail = "" for i, eml := range []string{ - "test4-3@foo.com", + "TEST4-3@foo.com", // test case insensitivity "test4-2@foo.com", "test4-1@foo.com", } { @@ -174,23 +160,25 @@ func TestGetEmail(t *testing.T) { if err != nil { t.Fatalf("Error creating user %d: %v", i, err) } - err = saveUser(u) + err = saveUser(testStorage, u) if err != nil { t.Fatalf("Error saving user %d: %v", i, err) } // Change modified time so they're all different, so the test becomes deterministic - f, err := os.Stat(storage.User(eml)) + f, err := os.Stat(testStorage.User(eml)) if err != nil { t.Fatalf("Could not access user folder for '%s': %v", eml, err) } chTime := f.ModTime().Add(-(time.Duration(i) * time.Second)) - if err := os.Chtimes(storage.User(eml), chTime, chTime); err != nil { + if err := os.Chtimes(testStorage.User(eml), chTime, chTime); err != nil { t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err) } } - actual = getEmail(server.Config{}, true) + actual = getEmail(testStorage, true) if actual != "test4-3@foo.com" { t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual) } } + +var testStorage = Storage("./testdata") diff --git a/middleware/commands.go b/commands.go similarity index 94% rename from middleware/commands.go rename to commands.go index 2aaeb6141..3e64c90b9 100644 --- a/middleware/commands.go +++ b/commands.go @@ -1,4 +1,4 @@ -package middleware +package caddy import ( "errors" @@ -10,8 +10,8 @@ import ( var runtimeGoos = runtime.GOOS -// SplitCommandAndArgs takes a command string and parses it -// shell-style into the command and its separate arguments. +// SplitCommandAndArgs takes a command string and parses it shell-style into the +// command and its separate arguments. func SplitCommandAndArgs(command string) (cmd string, args []string, err error) { var parts []string diff --git a/middleware/commands_test.go b/commands_test.go similarity index 99% rename from middleware/commands_test.go rename to commands_test.go index 3001e65a5..5de37c761 100644 --- a/middleware/commands_test.go +++ b/commands_test.go @@ -1,4 +1,4 @@ -package middleware +package caddy import ( "fmt" diff --git a/controller.go b/controller.go new file mode 100644 index 000000000..4be794821 --- /dev/null +++ b/controller.go @@ -0,0 +1,86 @@ +package caddy + +import ( + "strings" + + "github.com/mholt/caddy/caddyfile" +) + +// Controller is given to the setup function of directives which +// gives them access to be able to read tokens and do whatever +// they need to do. +type Controller struct { + caddyfile.Dispenser + + // The instance in which the setup is occurring + instance *Instance + + // Key is the key from the top of the server block, usually + // an address, hostname, or identifier of some sort. + Key string + + // OncePerServerBlock is a function that executes f + // exactly once per server block, no matter how many + // hosts are associated with it. If it is the first + // time, the function f is executed immediately + // (not deferred) and may return an error which is + // returned by OncePerServerBlock. + OncePerServerBlock func(f func() error) error + + // ServerBlockIndex is the 0-based index of the + // server block as it appeared in the input. + ServerBlockIndex int + + // ServerBlockKeyIndex is the 0-based index of this + // key as it appeared in the input at the head of the + // server block. + ServerBlockKeyIndex int + + // ServerBlockKeys is a list of keys that are + // associated with this server block. All these + // keys, consequently, share the same tokens. + ServerBlockKeys []string + + // ServerBlockStorage is used by a directive's + // setup function to persist state between all + // the keys on a server block. + ServerBlockStorage interface{} +} + +// ServerType gets the name of the server type that is being set up. +func (c *Controller) ServerType() string { + return c.instance.serverType +} + +// OnStartup adds fn to the list of callback functions to execute +// when the server is about to be started. +func (c *Controller) OnStartup(fn func() error) { + c.instance.onStartup = append(c.instance.onStartup, fn) +} + +// OnRestart adds fn to the list of callback functions to execute +// when the server is about to be restarted. +func (c *Controller) OnRestart(fn func() error) { + c.instance.onRestart = append(c.instance.onRestart, fn) +} + +// OnShutdown adds fn to the list of callback functions to execute +// when the server is about to be shut down.. +func (c *Controller) OnShutdown(fn func() error) { + c.instance.onShutdown = append(c.instance.onShutdown, fn) +} + +// NewTestController creates a new *Controller for +// the input specified, with a filename of "Testfile". +// The Config is bare, consisting only of a Root of cwd. +// +// Used primarily for testing but needs to be exported so +// add-ons can use this as a convenience. Does not initialize +// the server-block-related fields. +func NewTestController(input string) *Controller { + return &Controller{ + instance: &Instance{serverType: ""}, + Dispenser: caddyfile.NewDispenser("Testfile", strings.NewReader(input)), + OncePerServerBlock: func(f func() error) error { return f() }, + } +} diff --git a/dist/README.txt b/dist/README.txt index 298f35f7a..19699f047 100644 --- a/dist/README.txt +++ b/dist/README.txt @@ -1,4 +1,4 @@ -CADDY 0.8.3 +CADDY 0.9 beta 1 Website https://caddyserver.com @@ -14,15 +14,21 @@ Source Code https://github.com/caddyserver -For instructions on using Caddy, please see the user guide on the website. -For a list of what's new in this version, see CHANGES.txt. +For instructions on using Caddy, please see the user guide on +the website. For a list of what's new in this version, see +CHANGES.txt. -Please consider donating to the project if you think it is helpful, -especially if your company is using Caddy. There are also sponsorship -opportunities available! +The Caddy project accepts pull requests! That means you can make +changes to the code and submit it for review, and if it's good, +we'll use it! You can help thousands of Caddy users and level +up your Go programming game by contributing to Caddy's source. -If you have a question, bug report, or would like to contribute, please open an -issue or submit a pull request on GitHub. Your contributions do not go unnoticed! +To report bugs or request features, open an issue on GitHub. + +Want to support the project financially? Consider donating, +especially if your company is using Caddy. Believe me, your +contributions do not go unnoticed! We also have sponsorship +opportunities available. For a good time, follow @mholt6 on Twitter. diff --git a/dist/automate.go b/dist/automate.go index 594233f56..9b65475a1 100644 --- a/dist/automate.go +++ b/dist/automate.go @@ -12,12 +12,12 @@ import ( "github.com/mholt/archiver" ) -var buildScript, pkgDir, distDir, buildDir, releaseDir string +var buildScript, repoDir, distDir, buildDir, releaseDir string func init() { - pkgDir = filepath.Join(os.Getenv("GOPATH"), "src", "github.com", "mholt", "caddy") - buildScript = filepath.Join(pkgDir, "build.bash") - distDir = filepath.Join(pkgDir, "dist") + repoDir = filepath.Join(os.Getenv("GOPATH"), "src", "github.com", "mholt", "caddy") + buildScript = filepath.Join(repoDir, "caddy", "build.bash") + distDir = filepath.Join(repoDir, "dist") buildDir = filepath.Join(distDir, "builds") releaseDir = filepath.Join(distDir, "release") } @@ -98,7 +98,7 @@ func main() { func build(p platform, out string) error { cmd := exec.Command(buildScript, out) - cmd.Dir = pkgDir + cmd.Dir = repoDir cmd.Env = os.Environ() cmd.Env = append(cmd.Env, "CGO_ENABLED=0") cmd.Env = append(cmd.Env, "GOOS="+p.os) @@ -132,8 +132,8 @@ func numProcs() int { // Not all supported platforms are listed since some are // problematic and we only build the most common ones. // These are just the pre-made, readily-available static -// builds, and we can add more upon request if there is -// enough demand. +// builds, and we can try to add more upon request if there +// is enough demand. var platforms = []platform{ {os: "darwin", arch: "amd64", archive: "zip"}, {os: "freebsd", arch: "386", archive: "tar.gz"}, diff --git a/main_test.go b/main_test.go deleted file mode 100644 index 01722ed60..000000000 --- a/main_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package main - -import ( - "runtime" - "testing" -) - -func TestSetCPU(t *testing.T) { - currentCPU := runtime.GOMAXPROCS(-1) - maxCPU := runtime.NumCPU() - halfCPU := int(0.5 * float32(maxCPU)) - if halfCPU < 1 { - halfCPU = 1 - } - for i, test := range []struct { - input string - output int - shouldErr bool - }{ - {"1", 1, false}, - {"-1", currentCPU, true}, - {"0", currentCPU, true}, - {"100%", maxCPU, false}, - {"50%", halfCPU, false}, - {"110%", currentCPU, true}, - {"-10%", currentCPU, true}, - {"invalid input", currentCPU, true}, - {"invalid input%", currentCPU, true}, - {"9999", maxCPU, false}, // over available CPU - } { - err := setCPU(test.input) - if test.shouldErr && err == nil { - t.Errorf("Test %d: Expected error, but there wasn't any", i) - } - if !test.shouldErr && err != nil { - t.Errorf("Test %d: Expected no error, but there was one: %v", i, err) - } - if actual, expected := runtime.GOMAXPROCS(-1), test.output; actual != expected { - t.Errorf("Test %d: GOMAXPROCS was %d but expected %d", i, actual, expected) - } - // teardown - runtime.GOMAXPROCS(currentCPU) - } -} - -func TestSetVersion(t *testing.T) { - setVersion() - if !devBuild { - t.Error("Expected default to assume development build, but it didn't") - } - if got, want := appVersion, "(untracked dev build)"; got != want { - t.Errorf("Expected appVersion='%s', got: '%s'", want, got) - } - - gitTag = "v1.1" - setVersion() - if devBuild { - t.Error("Expected a stable build if gitTag is set with no changes") - } - if got, want := appVersion, "1.1"; got != want { - t.Errorf("Expected appVersion='%s', got: '%s'", want, got) - } - - gitTag = "" - gitNearestTag = "v1.0" - gitCommit = "deadbeef" - buildDate = "Fri Feb 26 06:53:17 UTC 2016" - setVersion() - if !devBuild { - t.Error("Expected inferring a dev build when gitTag is empty") - } - if got, want := appVersion, "1.0 (+deadbeef Fri Feb 26 06:53:17 UTC 2016)"; got != want { - t.Errorf("Expected appVersion='%s', got: '%s'", want, got) - } -} diff --git a/middleware/basicauth/basicauth.go b/middleware/basicauth/basicauth.go deleted file mode 100644 index ebfd0a8e6..000000000 --- a/middleware/basicauth/basicauth.go +++ /dev/null @@ -1,148 +0,0 @@ -// Package basicauth implements HTTP Basic Authentication. -package basicauth - -import ( - "bufio" - "crypto/subtle" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strings" - "sync" - - "github.com/jimstudt/http-authentication/basic" - "github.com/mholt/caddy/middleware" -) - -// BasicAuth is middleware to protect resources with a username and password. -// Note that HTTP Basic Authentication is not secure by itself and should -// not be used to protect important assets without HTTPS. Even then, the -// security of HTTP Basic Auth is disputed. Use discretion when deciding -// what to protect with BasicAuth. -type BasicAuth struct { - Next middleware.Handler - SiteRoot string - Rules []Rule -} - -// ServeHTTP implements the middleware.Handler interface. -func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - - var hasAuth bool - var isAuthenticated bool - - for _, rule := range a.Rules { - for _, res := range rule.Resources { - if !middleware.Path(r.URL.Path).Matches(res) { - continue - } - - // Path matches; parse auth header - username, password, ok := r.BasicAuth() - hasAuth = true - - // Check credentials - if !ok || - username != rule.Username || - !rule.Password(password) { - //subtle.ConstantTimeCompare([]byte(password), []byte(rule.Password)) != 1 { - continue - } - - // Flag set only on successful authentication - isAuthenticated = true - } - } - - if hasAuth { - if !isAuthenticated { - w.Header().Set("WWW-Authenticate", "Basic") - return http.StatusUnauthorized, nil - } - // "It's an older code, sir, but it checks out. I was about to clear them." - return a.Next.ServeHTTP(w, r) - } - - // Pass-thru when no paths match - return a.Next.ServeHTTP(w, r) -} - -// Rule represents a BasicAuth rule. A username and password -// combination protect the associated resources, which are -// file or directory paths. -type Rule struct { - Username string - Password func(string) bool - Resources []string -} - -// PasswordMatcher determines whether a password matches a rule. -type PasswordMatcher func(pw string) bool - -var ( - htpasswords map[string]map[string]PasswordMatcher - htpasswordsMu sync.Mutex -) - -// GetHtpasswdMatcher matches password rules. -func GetHtpasswdMatcher(filename, username, siteRoot string) (PasswordMatcher, error) { - filename = filepath.Join(siteRoot, filename) - htpasswordsMu.Lock() - if htpasswords == nil { - htpasswords = make(map[string]map[string]PasswordMatcher) - } - pm := htpasswords[filename] - if pm == nil { - fh, err := os.Open(filename) - if err != nil { - return nil, fmt.Errorf("open %q: %v", filename, err) - } - defer fh.Close() - pm = make(map[string]PasswordMatcher) - if err = parseHtpasswd(pm, fh); err != nil { - return nil, fmt.Errorf("parsing htpasswd %q: %v", fh.Name(), err) - } - htpasswords[filename] = pm - } - htpasswordsMu.Unlock() - if pm[username] == nil { - return nil, fmt.Errorf("username %q not found in %q", username, filename) - } - return pm[username], nil -} - -func parseHtpasswd(pm map[string]PasswordMatcher, r io.Reader) error { - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" || strings.IndexByte(line, '#') == 0 { - continue - } - i := strings.IndexByte(line, ':') - if i <= 0 { - return fmt.Errorf("malformed line, no color: %q", line) - } - user, encoded := line[:i], line[i+1:] - for _, p := range basic.DefaultSystems { - matcher, err := p(encoded) - if err != nil { - return err - } - if matcher != nil { - pm[user] = matcher.MatchesPassword - break - } - } - } - return scanner.Err() -} - -// PlainMatcher returns a PasswordMatcher that does a constant-time -// byte-wise comparison. -func PlainMatcher(passw string) PasswordMatcher { - return func(pw string) bool { - return subtle.ConstantTimeCompare([]byte(pw), []byte(passw)) == 1 - } -} diff --git a/middleware/basicauth/basicauth_test.go b/middleware/basicauth/basicauth_test.go deleted file mode 100644 index 631aaaed9..000000000 --- a/middleware/basicauth/basicauth_test.go +++ /dev/null @@ -1,147 +0,0 @@ -package basicauth - -import ( - "encoding/base64" - "fmt" - "io/ioutil" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "testing" - - "github.com/mholt/caddy/middleware" -) - -func TestBasicAuth(t *testing.T) { - - rw := BasicAuth{ - Next: middleware.HandlerFunc(contentHandler), - Rules: []Rule{ - {Username: "test", Password: PlainMatcher("ttest"), Resources: []string{"/testing"}}, - }, - } - - tests := []struct { - from string - result int - cred string - }{ - {"/testing", http.StatusUnauthorized, "ttest:test"}, - {"/testing", http.StatusOK, "test:ttest"}, - {"/testing", http.StatusUnauthorized, ""}, - } - - for i, test := range tests { - - req, err := http.NewRequest("GET", test.from, nil) - if err != nil { - t.Fatalf("Test %d: Could not create HTTP request %v", i, err) - } - auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(test.cred)) - req.Header.Set("Authorization", auth) - - rec := httptest.NewRecorder() - result, err := rw.ServeHTTP(rec, req) - if err != nil { - t.Fatalf("Test %d: Could not ServeHTTP %v", i, err) - } - if result != test.result { - t.Errorf("Test %d: Expected Header '%d' but was '%d'", - i, test.result, result) - } - if result == http.StatusUnauthorized { - headers := rec.Header() - if val, ok := headers["Www-Authenticate"]; ok { - if val[0] != "Basic" { - t.Errorf("Test %d, Www-Authenticate should be %s provided %s", i, "Basic", val[0]) - } - } else { - t.Errorf("Test %d, should provide a header Www-Authenticate", i) - } - } - - } - -} - -func TestMultipleOverlappingRules(t *testing.T) { - rw := BasicAuth{ - Next: middleware.HandlerFunc(contentHandler), - Rules: []Rule{ - {Username: "t", Password: PlainMatcher("p1"), Resources: []string{"/t"}}, - {Username: "t1", Password: PlainMatcher("p2"), Resources: []string{"/t/t"}}, - }, - } - - tests := []struct { - from string - result int - cred string - }{ - {"/t", http.StatusOK, "t:p1"}, - {"/t/t", http.StatusOK, "t:p1"}, - {"/t/t", http.StatusOK, "t1:p2"}, - {"/a", http.StatusOK, "t1:p2"}, - {"/t/t", http.StatusUnauthorized, "t1:p3"}, - {"/t", http.StatusUnauthorized, "t1:p2"}, - } - - for i, test := range tests { - - req, err := http.NewRequest("GET", test.from, nil) - if err != nil { - t.Fatalf("Test %d: Could not create HTTP request %v", i, err) - } - auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(test.cred)) - req.Header.Set("Authorization", auth) - - rec := httptest.NewRecorder() - result, err := rw.ServeHTTP(rec, req) - if err != nil { - t.Fatalf("Test %d: Could not ServeHTTP %v", i, err) - } - if result != test.result { - t.Errorf("Test %d: Expected Header '%d' but was '%d'", - i, test.result, result) - } - - } - -} - -func contentHandler(w http.ResponseWriter, r *http.Request) (int, error) { - fmt.Fprintf(w, r.URL.String()) - return http.StatusOK, nil -} - -func TestHtpasswd(t *testing.T) { - htpasswdPasswd := "IedFOuGmTpT8" - htpasswdFile := `sha1:{SHA}dcAUljwz99qFjYR0YLTXx0RqLww= -md5:$apr1$l42y8rex$pOA2VJ0x/0TwaFeAF9nX61` - - htfh, err := ioutil.TempFile("", "basicauth-") - if err != nil { - t.Skipf("Error creating temp file (%v), will skip htpassword test") - return - } - defer os.Remove(htfh.Name()) - if _, err = htfh.Write([]byte(htpasswdFile)); err != nil { - t.Fatalf("write htpasswd file %q: %v", htfh.Name(), err) - } - htfh.Close() - - for i, username := range []string{"sha1", "md5"} { - rule := Rule{Username: username, Resources: []string{"/testing"}} - - siteRoot := filepath.Dir(htfh.Name()) - filename := filepath.Base(htfh.Name()) - if rule.Password, err = GetHtpasswdMatcher(filename, rule.Username, siteRoot); err != nil { - t.Fatalf("GetHtpasswdMatcher(%q, %q): %v", htfh.Name(), rule.Username, err) - } - t.Logf("%d. username=%q", i, rule.Username) - if !rule.Password(htpasswdPasswd) || rule.Password(htpasswdPasswd+"!") { - t.Errorf("%d (%s) password does not match.", i, rule.Username) - } - } -} diff --git a/middleware/browse/browse.go b/middleware/browse/browse.go deleted file mode 100644 index 62e4b1684..000000000 --- a/middleware/browse/browse.go +++ /dev/null @@ -1,431 +0,0 @@ -// Package browse provides middleware for listing files in a directory -// when directory path is requested instead of a specific file. -package browse - -import ( - "bytes" - "encoding/json" - "net/http" - "net/url" - "os" - "path" - "sort" - "strconv" - "strings" - "text/template" - "time" - - "github.com/dustin/go-humanize" - "github.com/mholt/caddy/middleware" -) - -// Browse is an http.Handler that can show a file listing when -// directories in the given paths are specified. -type Browse struct { - Next middleware.Handler - Configs []Config - IgnoreIndexes bool -} - -// Config is a configuration for browsing in a particular path. -type Config struct { - PathScope string - Root http.FileSystem - Variables interface{} - Template *template.Template -} - -// A Listing is the context used to fill out a template. -type Listing struct { - // The name of the directory (the last element of the path) - Name string - - // The full path of the request - Path string - - // Whether the parent directory is browsable - CanGoUp bool - - // The items (files and folders) in the path - Items []FileInfo - - // The number of directories in the listing - NumDirs int - - // The number of files (items that aren't directories) in the listing - NumFiles int - - // Which sorting order is used - Sort string - - // And which order - Order string - - // If ≠0 then Items have been limited to that many elements - ItemsLimitedTo int - - // Optional custom variables for use in browse templates - User interface{} - - middleware.Context -} - -// BreadcrumbMap returns l.Path where every element is a map -// of URLs and path segment names. -func (l Listing) BreadcrumbMap() map[string]string { - result := map[string]string{} - - if len(l.Path) == 0 { - return result - } - - // skip trailing slash - lpath := l.Path - if lpath[len(lpath)-1] == '/' { - lpath = lpath[:len(lpath)-1] - } - - parts := strings.Split(lpath, "/") - for i, part := range parts { - if i == 0 && part == "" { - // Leading slash (root) - result["/"] = "/" - continue - } - result[strings.Join(parts[:i+1], "/")] = part - } - - return result -} - -// FileInfo is the info about a particular file or directory -type FileInfo struct { - IsDir bool - Name string - Size int64 - URL string - ModTime time.Time - Mode os.FileMode -} - -// HumanSize returns the size of the file as a human-readable string -// in IEC format (i.e. power of 2 or base 1024). -func (fi FileInfo) HumanSize() string { - return humanize.IBytes(uint64(fi.Size)) -} - -// HumanModTime returns the modified time of the file as a human-readable string. -func (fi FileInfo) HumanModTime(format string) string { - return fi.ModTime.Format(format) -} - -// Implement sorting for Listing -type byName Listing -type bySize Listing -type byTime Listing - -// By Name -func (l byName) Len() int { return len(l.Items) } -func (l byName) Swap(i, j int) { l.Items[i], l.Items[j] = l.Items[j], l.Items[i] } - -// Treat upper and lower case equally -func (l byName) Less(i, j int) bool { - return strings.ToLower(l.Items[i].Name) < strings.ToLower(l.Items[j].Name) -} - -// By Size -func (l bySize) Len() int { return len(l.Items) } -func (l bySize) Swap(i, j int) { l.Items[i], l.Items[j] = l.Items[j], l.Items[i] } - -const directoryOffset = -1 << 31 // = math.MinInt32 -func (l bySize) Less(i, j int) bool { - iSize, jSize := l.Items[i].Size, l.Items[j].Size - if l.Items[i].IsDir { - iSize = directoryOffset + iSize - } - if l.Items[j].IsDir { - jSize = directoryOffset + jSize - } - return iSize < jSize -} - -// By Time -func (l byTime) Len() int { return len(l.Items) } -func (l byTime) Swap(i, j int) { l.Items[i], l.Items[j] = l.Items[j], l.Items[i] } -func (l byTime) Less(i, j int) bool { return l.Items[i].ModTime.Before(l.Items[j].ModTime) } - -// Add sorting method to "Listing" -// it will apply what's in ".Sort" and ".Order" -func (l Listing) applySort() { - // Check '.Order' to know how to sort - if l.Order == "desc" { - switch l.Sort { - case "name": - sort.Sort(sort.Reverse(byName(l))) - case "size": - sort.Sort(sort.Reverse(bySize(l))) - case "time": - sort.Sort(sort.Reverse(byTime(l))) - default: - // If not one of the above, do nothing - return - } - } else { // If we had more Orderings we could add them here - switch l.Sort { - case "name": - sort.Sort(byName(l)) - case "size": - sort.Sort(bySize(l)) - case "time": - sort.Sort(byTime(l)) - default: - // If not one of the above, do nothing - return - } - } -} - -func directoryListing(files []os.FileInfo, canGoUp bool, urlPath string) (Listing, bool) { - var ( - fileinfos []FileInfo - dirCount, fileCount int - hasIndexFile bool - ) - - for _, f := range files { - name := f.Name() - - for _, indexName := range middleware.IndexPages { - if name == indexName { - hasIndexFile = true - break - } - } - - if f.IsDir() { - name += "/" - dirCount++ - } else { - fileCount++ - } - - url := url.URL{Path: "./" + name} // prepend with "./" to fix paths with ':' in the name - - fileinfos = append(fileinfos, FileInfo{ - IsDir: f.IsDir(), - Name: f.Name(), - Size: f.Size(), - URL: url.String(), - ModTime: f.ModTime().UTC(), - Mode: f.Mode(), - }) - } - - return Listing{ - Name: path.Base(urlPath), - Path: urlPath, - CanGoUp: canGoUp, - Items: fileinfos, - NumDirs: dirCount, - NumFiles: fileCount, - }, hasIndexFile -} - -// ServeHTTP determines if the request is for this plugin, and if all prerequisites are met. -// If so, control is handed over to ServeListing. -func (b Browse) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - var bc *Config - // See if there's a browse configuration to match the path - for i := range b.Configs { - if middleware.Path(r.URL.Path).Matches(b.Configs[i].PathScope) { - bc = &b.Configs[i] - goto inScope - } - } - return b.Next.ServeHTTP(w, r) -inScope: - - // Browse works on existing directories; delegate everything else - requestedFilepath, err := bc.Root.Open(r.URL.Path) - if err != nil { - switch { - case os.IsPermission(err): - return http.StatusForbidden, err - case os.IsExist(err): - return http.StatusNotFound, err - default: - return b.Next.ServeHTTP(w, r) - } - } - defer requestedFilepath.Close() - - info, err := requestedFilepath.Stat() - if err != nil { - switch { - case os.IsPermission(err): - return http.StatusForbidden, err - case os.IsExist(err): - return http.StatusGone, err - default: - return b.Next.ServeHTTP(w, r) - } - } - if !info.IsDir() { - return b.Next.ServeHTTP(w, r) - } - - // Do not reply to anything else because it might be nonsensical - switch r.Method { - case http.MethodGet, http.MethodHead: - // proceed, noop - case "PROPFIND", http.MethodOptions: - return http.StatusNotImplemented, nil - default: - return b.Next.ServeHTTP(w, r) - } - - // Browsing navigation gets messed up if browsing a directory - // that doesn't end in "/" (which it should, anyway) - if !strings.HasSuffix(r.URL.Path, "/") { - http.Redirect(w, r, r.URL.Path+"/", http.StatusTemporaryRedirect) - return 0, nil - } - - return b.ServeListing(w, r, requestedFilepath, bc) -} - -func (b Browse) loadDirectoryContents(requestedFilepath http.File, urlPath string) (*Listing, bool, error) { - files, err := requestedFilepath.Readdir(-1) - if err != nil { - return nil, false, err - } - - // Determine if user can browse up another folder - var canGoUp bool - curPathDir := path.Dir(strings.TrimSuffix(urlPath, "/")) - for _, other := range b.Configs { - if strings.HasPrefix(curPathDir, other.PathScope) { - canGoUp = true - break - } - } - - // Assemble listing of directory contents - listing, hasIndex := directoryListing(files, canGoUp, urlPath) - - return &listing, hasIndex, nil -} - -// handleSortOrder gets and stores for a Listing the 'sort' and 'order', -// and reads 'limit' if given. The latter is 0 if not given. -// -// This sets Cookies. -func (b Browse) handleSortOrder(w http.ResponseWriter, r *http.Request, scope string) (sort string, order string, limit int, err error) { - sort, order, limitQuery := r.URL.Query().Get("sort"), r.URL.Query().Get("order"), r.URL.Query().Get("limit") - - // If the query 'sort' or 'order' is empty, use defaults or any values previously saved in Cookies - switch sort { - case "": - sort = "name" - if sortCookie, sortErr := r.Cookie("sort"); sortErr == nil { - sort = sortCookie.Value - } - case "name", "size", "type": - http.SetCookie(w, &http.Cookie{Name: "sort", Value: sort, Path: scope, Secure: r.TLS != nil}) - } - - switch order { - case "": - order = "asc" - if orderCookie, orderErr := r.Cookie("order"); orderErr == nil { - order = orderCookie.Value - } - case "asc", "desc": - http.SetCookie(w, &http.Cookie{Name: "order", Value: order, Path: scope, Secure: r.TLS != nil}) - } - - if limitQuery != "" { - limit, err = strconv.Atoi(limitQuery) - if err != nil { // if the 'limit' query can't be interpreted as a number, return err - return - } - } - - return -} - -// ServeListing returns a formatted view of 'requestedFilepath' contents'. -func (b Browse) ServeListing(w http.ResponseWriter, r *http.Request, requestedFilepath http.File, bc *Config) (int, error) { - listing, containsIndex, err := b.loadDirectoryContents(requestedFilepath, r.URL.Path) - if err != nil { - switch { - case os.IsPermission(err): - return http.StatusForbidden, err - case os.IsExist(err): - return http.StatusGone, err - default: - return http.StatusInternalServerError, err - } - } - if containsIndex && !b.IgnoreIndexes { // directory isn't browsable - return b.Next.ServeHTTP(w, r) - } - listing.Context = middleware.Context{ - Root: bc.Root, - Req: r, - URL: r.URL, - } - listing.User = bc.Variables - - // Copy the query values into the Listing struct - var limit int - listing.Sort, listing.Order, limit, err = b.handleSortOrder(w, r, bc.PathScope) - if err != nil { - return http.StatusBadRequest, err - } - - listing.applySort() - - if limit > 0 && limit <= len(listing.Items) { - listing.Items = listing.Items[:limit] - listing.ItemsLimitedTo = limit - } - - var buf *bytes.Buffer - acceptHeader := strings.ToLower(strings.Join(r.Header["Accept"], ",")) - switch { - case strings.Contains(acceptHeader, "application/json"): - if buf, err = b.formatAsJSON(listing, bc); err != nil { - return http.StatusInternalServerError, err - } - w.Header().Set("Content-Type", "application/json; charset=utf-8") - - default: // There's no 'application/json' in the 'Accept' header; browse normally - if buf, err = b.formatAsHTML(listing, bc); err != nil { - return http.StatusInternalServerError, err - } - w.Header().Set("Content-Type", "text/html; charset=utf-8") - - } - - buf.WriteTo(w) - - return http.StatusOK, nil -} - -func (b Browse) formatAsJSON(listing *Listing, bc *Config) (*bytes.Buffer, error) { - marsh, err := json.Marshal(listing.Items) - if err != nil { - return nil, err - } - - buf := new(bytes.Buffer) - _, err = buf.Write(marsh) - return buf, err -} - -func (b Browse) formatAsHTML(listing *Listing, bc *Config) (*bytes.Buffer, error) { - buf := new(bytes.Buffer) - err := bc.Template.Execute(buf, listing) - return buf, err -} diff --git a/middleware/browse/browse_test.go b/middleware/browse/browse_test.go deleted file mode 100644 index 161498f43..000000000 --- a/middleware/browse/browse_test.go +++ /dev/null @@ -1,356 +0,0 @@ -package browse - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "net/url" - "os" - "path/filepath" - "sort" - "testing" - "text/template" - "time" - - "github.com/mholt/caddy/middleware" -) - -// "sort" package has "IsSorted" function, but no "IsReversed"; -func isReversed(data sort.Interface) bool { - n := data.Len() - for i := n - 1; i > 0; i-- { - if !data.Less(i, i-1) { - return false - } - } - return true -} - -func TestSort(t *testing.T) { - // making up []fileInfo with bogus values; - // to be used to make up our "listing" - fileInfos := []FileInfo{ - { - Name: "fizz", - Size: 4, - ModTime: time.Now().AddDate(-1, 1, 0), - }, - { - Name: "buzz", - Size: 2, - ModTime: time.Now().AddDate(0, -3, 3), - }, - { - Name: "bazz", - Size: 1, - ModTime: time.Now().AddDate(0, -2, -23), - }, - { - Name: "jazz", - Size: 3, - ModTime: time.Now(), - }, - } - listing := Listing{ - Name: "foobar", - Path: "/fizz/buzz", - CanGoUp: false, - Items: fileInfos, - } - - // sort by name - listing.Sort = "name" - listing.applySort() - if !sort.IsSorted(byName(listing)) { - t.Errorf("The listing isn't name sorted: %v", listing.Items) - } - - // sort by size - listing.Sort = "size" - listing.applySort() - if !sort.IsSorted(bySize(listing)) { - t.Errorf("The listing isn't size sorted: %v", listing.Items) - } - - // sort by Time - listing.Sort = "time" - listing.applySort() - if !sort.IsSorted(byTime(listing)) { - t.Errorf("The listing isn't time sorted: %v", listing.Items) - } - - // reverse by name - listing.Sort = "name" - listing.Order = "desc" - listing.applySort() - if !isReversed(byName(listing)) { - t.Errorf("The listing isn't reversed by name: %v", listing.Items) - } - - // reverse by size - listing.Sort = "size" - listing.Order = "desc" - listing.applySort() - if !isReversed(bySize(listing)) { - t.Errorf("The listing isn't reversed by size: %v", listing.Items) - } - - // reverse by time - listing.Sort = "time" - listing.Order = "desc" - listing.applySort() - if !isReversed(byTime(listing)) { - t.Errorf("The listing isn't reversed by time: %v", listing.Items) - } -} - -func TestBrowseHTTPMethods(t *testing.T) { - tmpl, err := template.ParseFiles("testdata/photos.tpl") - if err != nil { - t.Fatalf("An error occured while parsing the template: %v", err) - } - - b := Browse{ - Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - return http.StatusTeapot, nil // not t.Fatalf, or we will not see what other methods yield - }), - Configs: []Config{ - { - PathScope: "/photos", - Root: http.Dir("./testdata"), - Template: tmpl, - }, - }, - } - - rec := httptest.NewRecorder() - for method, expected := range map[string]int{ - http.MethodGet: http.StatusOK, - http.MethodHead: http.StatusOK, - http.MethodOptions: http.StatusNotImplemented, - "PROPFIND": http.StatusNotImplemented, - } { - req, err := http.NewRequest(method, "/photos/", nil) - if err != nil { - t.Fatalf("Test: Could not create HTTP request: %v", err) - } - - code, _ := b.ServeHTTP(rec, req) - if code != expected { - t.Errorf("Wrong status with HTTP Method %s: expected %d, got %d", method, expected, code) - } - } -} - -func TestBrowseTemplate(t *testing.T) { - tmpl, err := template.ParseFiles("testdata/photos.tpl") - if err != nil { - t.Fatalf("An error occured while parsing the template: %v", err) - } - - b := Browse{ - Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - t.Fatalf("Next shouldn't be called") - return 0, nil - }), - Configs: []Config{ - { - PathScope: "/photos", - Root: http.Dir("./testdata"), - Template: tmpl, - }, - }, - } - - req, err := http.NewRequest("GET", "/photos/", nil) - if err != nil { - t.Fatalf("Test: Could not create HTTP request: %v", err) - } - - rec := httptest.NewRecorder() - - code, _ := b.ServeHTTP(rec, req) - if code != http.StatusOK { - t.Fatalf("Wrong status, expected %d, got %d", http.StatusOK, code) - } - - respBody := rec.Body.String() - expectedBody := ` - - -Template - - -

Header

- -

/photos/

- -test.html
- -test2.html
- -test3.html
- - - -` - - if respBody != expectedBody { - t.Fatalf("Expected body: %v got: %v", expectedBody, respBody) - } - -} - -func TestBrowseJson(t *testing.T) { - - b := Browse{ - Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - t.Fatalf("Next shouldn't be called") - return 0, nil - }), - Configs: []Config{ - { - PathScope: "/photos/", - Root: http.Dir("./testdata"), - }, - }, - } - - //Getting the listing from the ./testdata/photos, the listing returned will be used to validate test results - testDataPath := filepath.Join("./testdata", "photos") - file, err := os.Open(testDataPath) - if err != nil { - if os.IsPermission(err) { - t.Fatalf("Os Permission Error") - } - } - defer file.Close() - - files, err := file.Readdir(-1) - if err != nil { - t.Fatalf("Unable to Read Contents of the directory") - } - var fileinfos []FileInfo - - for i, f := range files { - name := f.Name() - - // Tests fail in CI environment because all file mod times are the same for - // some reason, making the sorting unpredictable. To hack around this, - // we ensure here that each file has a different mod time. - chTime := f.ModTime().UTC().Add(-(time.Duration(i) * time.Second)) - if err := os.Chtimes(filepath.Join(testDataPath, name), chTime, chTime); err != nil { - t.Fatal(err) - } - - if f.IsDir() { - name += "/" - } - - url := url.URL{Path: "./" + name} - - fileinfos = append(fileinfos, FileInfo{ - IsDir: f.IsDir(), - Name: f.Name(), - Size: f.Size(), - URL: url.String(), - ModTime: chTime, - Mode: f.Mode(), - }) - } - listing := Listing{Items: fileinfos} // this listing will be used for validation inside the tests - - tests := []struct { - QueryURL string - SortBy string - OrderBy string - Limit int - shouldErr bool - expectedResult []FileInfo - }{ - //test case 1: testing for default sort and order and without the limit parameter, default sort is by name and the default order is ascending - //without the limit query entire listing will be produced - {"/", "", "", -1, false, listing.Items}, - //test case 2: limit is set to 1, orderBy and sortBy is default - {"/?limit=1", "", "", 1, false, listing.Items[:1]}, - //test case 3 : if the listing request is bigger than total size of listing then it should return everything - {"/?limit=100000000", "", "", 100000000, false, listing.Items}, - //test case 4 : testing for negative limit - {"/?limit=-1", "", "", -1, false, listing.Items}, - //test case 5 : testing with limit set to -1 and order set to descending - {"/?limit=-1&order=desc", "", "desc", -1, false, listing.Items}, - //test case 6 : testing with limit set to 2 and order set to descending - {"/?limit=2&order=desc", "", "desc", 2, false, listing.Items}, - //test case 7 : testing with limit set to 3 and order set to descending - {"/?limit=3&order=desc", "", "desc", 3, false, listing.Items}, - //test case 8 : testing with limit set to 3 and order set to ascending - {"/?limit=3&order=asc", "", "asc", 3, false, listing.Items}, - //test case 9 : testing with limit set to 1111111 and order set to ascending - {"/?limit=1111111&order=asc", "", "asc", 1111111, false, listing.Items}, - //test case 10 : testing with limit set to default and order set to ascending and sorting by size - {"/?order=asc&sort=size", "size", "asc", -1, false, listing.Items}, - //test case 11 : testing with limit set to default and order set to ascending and sorting by last modified - {"/?order=asc&sort=time", "time", "asc", -1, false, listing.Items}, - //test case 12 : testing with limit set to 1 and order set to ascending and sorting by last modified - {"/?order=asc&sort=time&limit=1", "time", "asc", 1, false, listing.Items}, - //test case 13 : testing with limit set to -100 and order set to ascending and sorting by last modified - {"/?order=asc&sort=time&limit=-100", "time", "asc", -100, false, listing.Items}, - //test case 14 : testing with limit set to -100 and order set to ascending and sorting by size - {"/?order=asc&sort=size&limit=-100", "size", "asc", -100, false, listing.Items}, - } - - for i, test := range tests { - var marsh []byte - req, err := http.NewRequest("GET", "/photos"+test.QueryURL, nil) - - if err == nil && test.shouldErr { - t.Errorf("Test %d didn't error, but it should have", i) - } else if err != nil && !test.shouldErr { - t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) - } - - req.Header.Set("Accept", "application/json") - rec := httptest.NewRecorder() - - code, err := b.ServeHTTP(rec, req) - - if code != http.StatusOK { - t.Fatalf("In test %d: Wrong status, expected %d, got %d", i, http.StatusOK, code) - } - if rec.HeaderMap.Get("Content-Type") != "application/json; charset=utf-8" { - t.Fatalf("Expected Content type to be application/json; charset=utf-8, but got %s ", rec.HeaderMap.Get("Content-Type")) - } - - actualJSONResponse := rec.Body.String() - copyOflisting := listing - if test.SortBy == "" { - copyOflisting.Sort = "name" - } else { - copyOflisting.Sort = test.SortBy - } - if test.OrderBy == "" { - copyOflisting.Order = "asc" - } else { - copyOflisting.Order = test.OrderBy - } - - copyOflisting.applySort() - - limit := test.Limit - if limit <= len(copyOflisting.Items) && limit > 0 { - marsh, err = json.Marshal(copyOflisting.Items[:limit]) - } else { // if the 'limit' query is empty, or has the wrong value, list everything - marsh, err = json.Marshal(copyOflisting.Items) - } - - if err != nil { - t.Fatalf("Unable to Marshal the listing ") - } - expectedJSON := string(marsh) - - if actualJSONResponse != expectedJSON { - t.Errorf("JSON response doesn't match the expected for test number %d with sort=%s, order=%s\nExpected response %s\nActual response = %s\n", - i+1, test.SortBy, test.OrderBy, expectedJSON, actualJSONResponse) - } - } -} diff --git a/middleware/browse/testdata/header.html b/middleware/browse/testdata/header.html deleted file mode 100644 index 78e5a6a48..000000000 --- a/middleware/browse/testdata/header.html +++ /dev/null @@ -1 +0,0 @@ -

Header

diff --git a/middleware/browse/testdata/photos.tpl b/middleware/browse/testdata/photos.tpl deleted file mode 100644 index 5163ca008..000000000 --- a/middleware/browse/testdata/photos.tpl +++ /dev/null @@ -1,13 +0,0 @@ - - - -Template - - -{{.Include "header.html"}} -

{{.Path}}

-{{range .Items}} -{{.Name}}
-{{end}} - - diff --git a/middleware/browse/testdata/photos/test.html b/middleware/browse/testdata/photos/test.html deleted file mode 100644 index 40535a223..000000000 --- a/middleware/browse/testdata/photos/test.html +++ /dev/null @@ -1,8 +0,0 @@ - - - -Test - - - - diff --git a/middleware/browse/testdata/photos/test2.html b/middleware/browse/testdata/photos/test2.html deleted file mode 100644 index 8e10c5780..000000000 --- a/middleware/browse/testdata/photos/test2.html +++ /dev/null @@ -1,8 +0,0 @@ - - - -Test 2 - - - - diff --git a/middleware/browse/testdata/photos/test3.html b/middleware/browse/testdata/photos/test3.html deleted file mode 100644 index 6c70af2fa..000000000 --- a/middleware/browse/testdata/photos/test3.html +++ /dev/null @@ -1,3 +0,0 @@ - - - \ No newline at end of file diff --git a/middleware/context.go b/middleware/context.go deleted file mode 100644 index 7dbb8c877..000000000 --- a/middleware/context.go +++ /dev/null @@ -1,263 +0,0 @@ -package middleware - -import ( - "bytes" - "fmt" - "io/ioutil" - "net" - "net/http" - "net/url" - "strings" - "text/template" - "time" - - "github.com/russross/blackfriday" -) - -// This file contains the context and functions available for -// use in the templates. - -// Context is the context with which Caddy templates are executed. -type Context struct { - Root http.FileSystem - Req *http.Request - URL *url.URL -} - -// Include returns the contents of filename relative to the site root. -func (c Context) Include(filename string) (string, error) { - return ContextInclude(filename, c, c.Root) -} - -// Now returns the current timestamp in the specified format. -func (c Context) Now(format string) string { - return time.Now().Format(format) -} - -// NowDate returns the current date/time that can be used -// in other time functions. -func (c Context) NowDate() time.Time { - return time.Now() -} - -// Cookie gets the value of a cookie with name name. -func (c Context) Cookie(name string) string { - cookies := c.Req.Cookies() - for _, cookie := range cookies { - if cookie.Name == name { - return cookie.Value - } - } - return "" -} - -// Header gets the value of a request header with field name. -func (c Context) Header(name string) string { - return c.Req.Header.Get(name) -} - -// IP gets the (remote) IP address of the client making the request. -func (c Context) IP() string { - ip, _, err := net.SplitHostPort(c.Req.RemoteAddr) - if err != nil { - return c.Req.RemoteAddr - } - return ip -} - -// URI returns the raw, unprocessed request URI (including query -// string and hash) obtained directly from the Request-Line of -// the HTTP request. -func (c Context) URI() string { - return c.Req.RequestURI -} - -// Host returns the hostname portion of the Host header -// from the HTTP request. -func (c Context) Host() (string, error) { - host, _, err := net.SplitHostPort(c.Req.Host) - if err != nil { - if !strings.Contains(c.Req.Host, ":") { - // common with sites served on the default port 80 - return c.Req.Host, nil - } - return "", err - } - return host, nil -} - -// Port returns the port portion of the Host header if specified. -func (c Context) Port() (string, error) { - _, port, err := net.SplitHostPort(c.Req.Host) - if err != nil { - if !strings.Contains(c.Req.Host, ":") { - // common with sites served on the default port 80 - return "80", nil - } - return "", err - } - return port, nil -} - -// Method returns the method (GET, POST, etc.) of the request. -func (c Context) Method() string { - return c.Req.Method -} - -// PathMatches returns true if the path portion of the request -// URL matches pattern. -func (c Context) PathMatches(pattern string) bool { - return Path(c.Req.URL.Path).Matches(pattern) -} - -// Truncate truncates the input string to the given length. -// If length is negative, it returns that many characters -// starting from the end of the string. If the absolute value -// of length is greater than len(input), the whole input is -// returned. -func (c Context) Truncate(input string, length int) string { - if length < 0 && len(input)+length > 0 { - return input[len(input)+length:] - } - if length >= 0 && len(input) > length { - return input[:length] - } - return input -} - -// StripHTML returns s without HTML tags. It is fairly naive -// but works with most valid HTML inputs. -func (c Context) StripHTML(s string) string { - var buf bytes.Buffer - var inTag, inQuotes bool - var tagStart int - for i, ch := range s { - if inTag { - if ch == '>' && !inQuotes { - inTag = false - } else if ch == '<' && !inQuotes { - // false start - buf.WriteString(s[tagStart:i]) - tagStart = i - } else if ch == '"' { - inQuotes = !inQuotes - } - continue - } - if ch == '<' { - inTag = true - tagStart = i - continue - } - buf.WriteRune(ch) - } - if inTag { - // false start - buf.WriteString(s[tagStart:]) - } - return buf.String() -} - -// StripExt returns the input string without the extension, -// which is the suffix starting with the final '.' character -// but not before the final path separator ('/') character. -// If there is no extension, the whole input is returned. -func (c Context) StripExt(path string) string { - for i := len(path) - 1; i >= 0 && path[i] != '/'; i-- { - if path[i] == '.' { - return path[:i] - } - } - return path -} - -// Replace replaces instances of find in input with replacement. -func (c Context) Replace(input, find, replacement string) string { - return strings.Replace(input, find, replacement, -1) -} - -// Markdown returns the HTML contents of the markdown contained in filename -// (relative to the site root). -func (c Context) Markdown(filename string) (string, error) { - body, err := c.Include(filename) - if err != nil { - return "", err - } - renderer := blackfriday.HtmlRenderer(0, "", "") - extns := 0 - extns |= blackfriday.EXTENSION_TABLES - extns |= blackfriday.EXTENSION_FENCED_CODE - extns |= blackfriday.EXTENSION_STRIKETHROUGH - extns |= blackfriday.EXTENSION_DEFINITION_LISTS - markdown := blackfriday.Markdown([]byte(body), renderer, extns) - - return string(markdown), nil -} - -// ContextInclude opens filename using fs and executes a template with the context ctx. -// This does the same thing that Context.Include() does, but with the ability to provide -// your own context so that the included files can have access to additional fields your -// type may provide. You can embed Context in your type, then override its Include method -// to call this function with ctx being the instance of your type, and fs being Context.Root. -func ContextInclude(filename string, ctx interface{}, fs http.FileSystem) (string, error) { - file, err := fs.Open(filename) - if err != nil { - return "", err - } - defer file.Close() - - body, err := ioutil.ReadAll(file) - if err != nil { - return "", err - } - - tpl, err := template.New(filename).Parse(string(body)) - if err != nil { - return "", err - } - - var buf bytes.Buffer - err = tpl.Execute(&buf, ctx) - if err != nil { - return "", err - } - - return buf.String(), nil -} - -// ToLower will convert the given string to lower case. -func (c Context) ToLower(s string) string { - return strings.ToLower(s) -} - -// ToUpper will convert the given string to upper case. -func (c Context) ToUpper(s string) string { - return strings.ToUpper(s) -} - -// Split is a passthrough to strings.Split. It will split the first argument at each instance of the separator and return a slice of strings. -func (c Context) Split(s string, sep string) []string { - return strings.Split(s, sep) -} - -// Slice will convert the given arguments into a slice. -func (c Context) Slice(elems ...interface{}) []interface{} { - return elems -} - -// Map will convert the arguments into a map. It expects alternating string keys and values. This is useful for building more complicated data structures -// if you are using subtemplates or things like that. -func (c Context) Map(values ...interface{}) (map[string]interface{}, error) { - if len(values)%2 != 0 { - return nil, fmt.Errorf("Map expects an even number of arguments") - } - dict := make(map[string]interface{}, len(values)/2) - for i := 0; i < len(values); i += 2 { - key, ok := values[i].(string) - if !ok { - return nil, fmt.Errorf("Map keys must be strings") - } - dict[key] = values[i+1] - } - return dict, nil -} diff --git a/middleware/context_test.go b/middleware/context_test.go deleted file mode 100644 index a61a3bf92..000000000 --- a/middleware/context_test.go +++ /dev/null @@ -1,650 +0,0 @@ -package middleware - -import ( - "bytes" - "fmt" - "io/ioutil" - "net/http" - "net/url" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "text/template" -) - -func TestInclude(t *testing.T) { - context := getContextOrFail(t) - - inputFilename := "test_file" - absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename) - defer func() { - err := os.Remove(absInFilePath) - if err != nil && !os.IsNotExist(err) { - t.Fatalf("Failed to clean test file!") - } - }() - - tests := []struct { - fileContent string - expectedContent string - shouldErr bool - expectedErrorContent string - }{ - // Test 0 - all good - { - fileContent: `str1 {{ .Root }} str2`, - expectedContent: fmt.Sprintf("str1 %s str2", context.Root), - shouldErr: false, - expectedErrorContent: "", - }, - // Test 1 - failure on template.Parse - { - fileContent: `str1 {{ .Root } str2`, - expectedContent: "", - shouldErr: true, - expectedErrorContent: `unexpected "}" in operand`, - }, - // Test 3 - failure on template.Execute - { - fileContent: `str1 {{ .InvalidField }} str2`, - expectedContent: "", - shouldErr: true, - expectedErrorContent: `InvalidField`, - }, - { - fileContent: `str1 {{ .InvalidField }} str2`, - expectedContent: "", - shouldErr: true, - expectedErrorContent: `type middleware.Context`, - }, - } - - for i, test := range tests { - testPrefix := getTestPrefix(i) - - // WriteFile truncates the contentt - err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm) - if err != nil { - t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err) - } - - content, err := context.Include(inputFilename) - if err != nil { - if !test.shouldErr { - t.Errorf(testPrefix+"Expected no error, found [%s]", test.expectedErrorContent, err.Error()) - } - if !strings.Contains(err.Error(), test.expectedErrorContent) { - t.Errorf(testPrefix+"Expected error content [%s], found [%s]", test.expectedErrorContent, err.Error()) - } - } - - if err == nil && test.shouldErr { - t.Errorf(testPrefix+"Expected error [%s] but found nil. Input file was: %s", test.expectedErrorContent, inputFilename) - } - - if content != test.expectedContent { - t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename) - } - } -} - -func TestIncludeNotExisting(t *testing.T) { - context := getContextOrFail(t) - - _, err := context.Include("not_existing") - if err == nil { - t.Errorf("Expected error but found nil!") - } -} - -func TestMarkdown(t *testing.T) { - context := getContextOrFail(t) - - inputFilename := "test_file" - absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename) - defer func() { - err := os.Remove(absInFilePath) - if err != nil && !os.IsNotExist(err) { - t.Fatalf("Failed to clean test file!") - } - }() - - tests := []struct { - fileContent string - expectedContent string - }{ - // Test 0 - test parsing of markdown - { - fileContent: "* str1\n* str2\n", - expectedContent: "
    \n
  • str1
  • \n
  • str2
  • \n
\n", - }, - } - - for i, test := range tests { - testPrefix := getTestPrefix(i) - - // WriteFile truncates the contentt - err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm) - if err != nil { - t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err) - } - - content, _ := context.Markdown(inputFilename) - if content != test.expectedContent { - t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename) - } - } -} - -func TestCookie(t *testing.T) { - - tests := []struct { - cookie *http.Cookie - cookieName string - expectedValue string - }{ - // Test 0 - happy path - { - cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"}, - cookieName: "cookieName", - expectedValue: "cookieValue", - }, - // Test 1 - try to get a non-existing cookie - { - cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"}, - cookieName: "notExisting", - expectedValue: "", - }, - // Test 2 - partial name match - { - cookie: &http.Cookie{Name: "cookie", Value: "cookieValue"}, - cookieName: "cook", - expectedValue: "", - }, - // Test 3 - cookie with optional fields - { - cookie: &http.Cookie{Name: "cookie", Value: "cookieValue", Path: "/path", Domain: "https://localhost", Expires: (time.Now().Add(10 * time.Minute)), MaxAge: 120}, - cookieName: "cookie", - expectedValue: "cookieValue", - }, - } - - for i, test := range tests { - testPrefix := getTestPrefix(i) - - // reinitialize the context for each test - context := getContextOrFail(t) - - context.Req.AddCookie(test.cookie) - - actualCookieVal := context.Cookie(test.cookieName) - - if actualCookieVal != test.expectedValue { - t.Errorf(testPrefix+"Expected cookie value [%s] but found [%s] for cookie with name %s", test.expectedValue, actualCookieVal, test.cookieName) - } - } -} - -func TestCookieMultipleCookies(t *testing.T) { - context := getContextOrFail(t) - - cookieNameBase, cookieValueBase := "cookieName", "cookieValue" - - // make sure that there's no state and multiple requests for different cookies return the correct result - for i := 0; i < 10; i++ { - context.Req.AddCookie(&http.Cookie{Name: fmt.Sprintf("%s%d", cookieNameBase, i), Value: fmt.Sprintf("%s%d", cookieValueBase, i)}) - } - - for i := 0; i < 10; i++ { - expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i) - actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i)) - if actualCookieVal != expectedCookieVal { - t.Fatalf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal) - } - } -} - -func TestHeader(t *testing.T) { - context := getContextOrFail(t) - - headerKey, headerVal := "Header1", "HeaderVal1" - context.Req.Header.Add(headerKey, headerVal) - - actualHeaderVal := context.Header(headerKey) - if actualHeaderVal != headerVal { - t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal) - } - - missingHeaderVal := context.Header("not-existing") - if missingHeaderVal != "" { - t.Errorf("Expected empty header value, found %s", missingHeaderVal) - } -} - -func TestIP(t *testing.T) { - context := getContextOrFail(t) - - tests := []struct { - inputRemoteAddr string - expectedIP string - }{ - // Test 0 - ipv4 with port - {"1.1.1.1:1111", "1.1.1.1"}, - // Test 1 - ipv4 without port - {"1.1.1.1", "1.1.1.1"}, - // Test 2 - ipv6 with port - {"[::1]:11", "::1"}, - // Test 3 - ipv6 without port and brackets - {"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"}, - // Test 4 - ipv6 with zone and port - {`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`}, - } - - for i, test := range tests { - testPrefix := getTestPrefix(i) - - context.Req.RemoteAddr = test.inputRemoteAddr - actualIP := context.IP() - - if actualIP != test.expectedIP { - t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP) - } - } -} - -func TestURL(t *testing.T) { - context := getContextOrFail(t) - - inputURL := "http://localhost" - context.Req.RequestURI = inputURL - - if inputURL != context.URI() { - t.Errorf("Expected url %s, found %s", inputURL, context.URI()) - } -} - -func TestHost(t *testing.T) { - tests := []struct { - input string - expectedHost string - shouldErr bool - }{ - { - input: "localhost:123", - expectedHost: "localhost", - shouldErr: false, - }, - { - input: "localhost", - expectedHost: "localhost", - shouldErr: false, - }, - { - input: "[::]", - expectedHost: "", - shouldErr: true, - }, - } - - for _, test := range tests { - testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr) - } -} - -func TestPort(t *testing.T) { - tests := []struct { - input string - expectedPort string - shouldErr bool - }{ - { - input: "localhost:123", - expectedPort: "123", - shouldErr: false, - }, - { - input: "localhost", - expectedPort: "80", // assuming 80 is the default port - shouldErr: false, - }, - { - input: ":8080", - expectedPort: "8080", - shouldErr: false, - }, - { - input: "[::]", - expectedPort: "", - shouldErr: true, - }, - } - - for _, test := range tests { - testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr) - } -} - -func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) { - context := getContextOrFail(t) - - context.Req.Host = input - var actualResult, testedObject string - var err error - - if isTestingHost { - actualResult, err = context.Host() - testedObject = "host" - } else { - actualResult, err = context.Port() - testedObject = "port" - } - - if shouldErr && err == nil { - t.Errorf("Expected error, found nil!") - return - } - - if !shouldErr && err != nil { - t.Errorf("Expected no error, found %s", err) - return - } - - if actualResult != expectedResult { - t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult) - } -} - -func TestMethod(t *testing.T) { - context := getContextOrFail(t) - - method := "POST" - context.Req.Method = method - - if method != context.Method() { - t.Errorf("Expected method %s, found %s", method, context.Method()) - } - -} - -func TestPathMatches(t *testing.T) { - context := getContextOrFail(t) - - tests := []struct { - urlStr string - pattern string - shouldMatch bool - }{ - // Test 0 - { - urlStr: "http://localhost/", - pattern: "", - shouldMatch: true, - }, - // Test 1 - { - urlStr: "http://localhost", - pattern: "", - shouldMatch: true, - }, - // Test 1 - { - urlStr: "http://localhost/", - pattern: "/", - shouldMatch: true, - }, - // Test 3 - { - urlStr: "http://localhost/?param=val", - pattern: "/", - shouldMatch: true, - }, - // Test 4 - { - urlStr: "http://localhost/dir1/dir2", - pattern: "/dir2", - shouldMatch: false, - }, - // Test 5 - { - urlStr: "http://localhost/dir1/dir2", - pattern: "/dir1", - shouldMatch: true, - }, - // Test 6 - { - urlStr: "http://localhost:444/dir1/dir2", - pattern: "/dir1", - shouldMatch: true, - }, - // Test 7 - { - urlStr: "http://localhost/dir1/dir2", - pattern: "*/dir2", - shouldMatch: false, - }, - } - - for i, test := range tests { - testPrefix := getTestPrefix(i) - var err error - context.Req.URL, err = url.Parse(test.urlStr) - if err != nil { - t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err) - } - - matches := context.PathMatches(test.pattern) - if matches != test.shouldMatch { - t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches) - } - } -} - -func TestTruncate(t *testing.T) { - context := getContextOrFail(t) - tests := []struct { - inputString string - inputLength int - expected string - }{ - // Test 0 - small length - { - inputString: "string", - inputLength: 1, - expected: "s", - }, - // Test 1 - exact length - { - inputString: "string", - inputLength: 6, - expected: "string", - }, - // Test 2 - bigger length - { - inputString: "string", - inputLength: 10, - expected: "string", - }, - // Test 3 - zero length - { - inputString: "string", - inputLength: 0, - expected: "", - }, - // Test 4 - negative, smaller length - { - inputString: "string", - inputLength: -5, - expected: "tring", - }, - // Test 5 - negative, exact length - { - inputString: "string", - inputLength: -6, - expected: "string", - }, - // Test 6 - negative, bigger length - { - inputString: "string", - inputLength: -7, - expected: "string", - }, - } - - for i, test := range tests { - actual := context.Truncate(test.inputString, test.inputLength) - if actual != test.expected { - t.Errorf(getTestPrefix(i)+"Expected '%s', found '%s'. Input was Truncate(%q, %d)", test.expected, actual, test.inputString, test.inputLength) - } - } -} - -func TestStripHTML(t *testing.T) { - context := getContextOrFail(t) - tests := []struct { - input string - expected string - }{ - // Test 0 - no tags - { - input: `h1`, - expected: `h1`, - }, - // Test 1 - happy path - { - input: `

h1

`, - expected: `h1`, - }, - // Test 2 - tag in quotes - { - input: `">h1`, - expected: `h1`, - }, - // Test 3 - multiple tags - { - input: `

h1

`, - expected: `h1`, - }, - // Test 4 - tags not closed - { - input: `hi`, - expected: `= 400 { - h.errorPage(w, r, status) - return 0, err - } - - return status, err -} - -// errorPage serves a static error page to w according to the status -// code. If there is an error serving the error page, a plaintext error -// message is written instead, and the extra error is logged. -func (h ErrorHandler) errorPage(w http.ResponseWriter, r *http.Request, code int) { - defaultBody := fmt.Sprintf("%d %s", code, http.StatusText(code)) - - // See if an error page for this status code was specified - if pagePath, ok := h.ErrorPages[code]; ok { - - // Try to open it - errorPage, err := os.Open(pagePath) - if err != nil { - // An additional error handling an error... - h.Log.Printf("%s [NOTICE %d %s] could not load error page: %v", - time.Now().Format(timeFormat), code, r.URL.String(), err) - http.Error(w, defaultBody, code) - return - } - defer errorPage.Close() - - // Copy the page body into the response - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.WriteHeader(code) - _, err = io.Copy(w, errorPage) - - if err != nil { - // Epic fail... sigh. - h.Log.Printf("%s [NOTICE %d %s] could not respond with %s: %v", - time.Now().Format(timeFormat), code, r.URL.String(), pagePath, err) - http.Error(w, defaultBody, code) - } - - return - } - - // Default error response - http.Error(w, defaultBody, code) -} - -func (h ErrorHandler) recovery(w http.ResponseWriter, r *http.Request) { - rec := recover() - if rec == nil { - return - } - - // Obtain source of panic - // From: https://gist.github.com/swdunlop/9629168 - var name, file string // function name, file name - var line int - var pc [16]uintptr - n := runtime.Callers(3, pc[:]) - for _, pc := range pc[:n] { - fn := runtime.FuncForPC(pc) - if fn == nil { - continue - } - file, line = fn.FileLine(pc) - name = fn.Name() - if !strings.HasPrefix(name, "runtime.") { - break - } - } - - // Trim file path - delim := "/caddy/" - pkgPathPos := strings.Index(file, delim) - if pkgPathPos > -1 && len(file) > pkgPathPos+len(delim) { - file = file[pkgPathPos+len(delim):] - } - - panicMsg := fmt.Sprintf("%s [PANIC %s] %s:%d - %v", time.Now().Format(timeFormat), r.URL.String(), file, line, rec) - if h.Debug { - // Write error and stack trace to the response rather than to a log - var stackBuf [4096]byte - stack := stackBuf[:runtime.Stack(stackBuf[:], false)] - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprintf(w, "%s\n\n%s", panicMsg, stack) - } else { - // Currently we don't use the function name, since file:line is more conventional - h.Log.Printf(panicMsg) - h.errorPage(w, r, http.StatusInternalServerError) - } -} - -const timeFormat = "02/Jan/2006:15:04:05 -0700" diff --git a/middleware/errors/errors_test.go b/middleware/errors/errors_test.go deleted file mode 100644 index 49af3e4f4..000000000 --- a/middleware/errors/errors_test.go +++ /dev/null @@ -1,168 +0,0 @@ -package errors - -import ( - "bytes" - "errors" - "fmt" - "log" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "strconv" - "strings" - "testing" - - "github.com/mholt/caddy/middleware" -) - -func TestErrors(t *testing.T) { - // create a temporary page - path := filepath.Join(os.TempDir(), "errors_test.html") - f, err := os.Create(path) - if err != nil { - t.Fatal(err) - } - defer os.Remove(path) - - const content = "This is a error page" - _, err = f.WriteString(content) - if err != nil { - t.Fatal(err) - } - f.Close() - - buf := bytes.Buffer{} - em := ErrorHandler{ - ErrorPages: map[int]string{ - http.StatusNotFound: path, - http.StatusForbidden: "not_exist_file", - }, - Log: log.New(&buf, "", 0), - } - _, notExistErr := os.Open("not_exist_file") - - testErr := errors.New("test error") - tests := []struct { - next middleware.Handler - expectedCode int - expectedBody string - expectedLog string - expectedErr error - }{ - { - next: genErrorHandler(http.StatusOK, nil, "normal"), - expectedCode: http.StatusOK, - expectedBody: "normal", - expectedLog: "", - expectedErr: nil, - }, - { - next: genErrorHandler(http.StatusMovedPermanently, testErr, ""), - expectedCode: http.StatusMovedPermanently, - expectedBody: "", - expectedLog: fmt.Sprintf("[ERROR %d %s] %v\n", http.StatusMovedPermanently, "/", testErr), - expectedErr: testErr, - }, - { - next: genErrorHandler(http.StatusBadRequest, nil, ""), - expectedCode: 0, - expectedBody: fmt.Sprintf("%d %s\n", http.StatusBadRequest, - http.StatusText(http.StatusBadRequest)), - expectedLog: "", - expectedErr: nil, - }, - { - next: genErrorHandler(http.StatusNotFound, nil, ""), - expectedCode: 0, - expectedBody: content, - expectedLog: "", - expectedErr: nil, - }, - { - next: genErrorHandler(http.StatusForbidden, nil, ""), - expectedCode: 0, - expectedBody: fmt.Sprintf("%d %s\n", http.StatusForbidden, - http.StatusText(http.StatusForbidden)), - expectedLog: fmt.Sprintf("[NOTICE %d /] could not load error page: %v\n", - http.StatusForbidden, notExistErr), - expectedErr: nil, - }, - } - - req, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatal(err) - } - for i, test := range tests { - em.Next = test.next - buf.Reset() - rec := httptest.NewRecorder() - code, err := em.ServeHTTP(rec, req) - - if err != test.expectedErr { - t.Errorf("Test %d: Expected error %v, but got %v", - i, test.expectedErr, err) - } - if code != test.expectedCode { - t.Errorf("Test %d: Expected status code %d, but got %d", - i, test.expectedCode, code) - } - if body := rec.Body.String(); body != test.expectedBody { - t.Errorf("Test %d: Expected body %q, but got %q", - i, test.expectedBody, body) - } - if log := buf.String(); !strings.Contains(log, test.expectedLog) { - t.Errorf("Test %d: Expected log %q, but got %q", - i, test.expectedLog, log) - } - } -} - -func TestVisibleErrorWithPanic(t *testing.T) { - const panicMsg = "I'm a panic" - eh := ErrorHandler{ - ErrorPages: make(map[int]string), - Debug: true, - Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - panic(panicMsg) - }), - } - - req, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatal(err) - } - rec := httptest.NewRecorder() - - code, err := eh.ServeHTTP(rec, req) - - if code != 0 { - t.Errorf("Expected error handler to return 0 (it should write to response), got status %d", code) - } - if err != nil { - t.Errorf("Expected error handler to return nil error (it should panic!), but got '%v'", err) - } - - body := rec.Body.String() - - if !strings.Contains(body, "[PANIC /] middleware/errors/errors_test.go") { - t.Errorf("Expected response body to contain error log line, but it didn't:\n%s", body) - } - if !strings.Contains(body, panicMsg) { - t.Errorf("Expected response body to contain panic message, but it didn't:\n%s", body) - } - if len(body) < 500 { - t.Errorf("Expected response body to contain stack trace, but it was too short: len=%d", len(body)) - } -} - -func genErrorHandler(status int, err error, body string) middleware.Handler { - return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - if len(body) > 0 { - w.Header().Set("Content-Length", strconv.Itoa(len(body))) - fmt.Fprint(w, body) - } - return status, err - }) -} diff --git a/middleware/expvar/expvar.go b/middleware/expvar/expvar.go deleted file mode 100644 index 178243486..000000000 --- a/middleware/expvar/expvar.go +++ /dev/null @@ -1,46 +0,0 @@ -package expvar - -import ( - "expvar" - "fmt" - "net/http" - - "github.com/mholt/caddy/middleware" -) - -// ExpVar is a simple struct to hold expvar's configuration -type ExpVar struct { - Next middleware.Handler - Resource Resource -} - -// ServeHTTP handles requests to expvar's configured entry point with -// expvar, or passes all other requests up the chain. -func (e ExpVar) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - if middleware.Path(r.URL.Path).Matches(string(e.Resource)) { - expvarHandler(w, r) - return 0, nil - } - - return e.Next.ServeHTTP(w, r) -} - -// expvarHandler returns a JSON object will all the published variables. -// -// This is lifted straight from the expvar package. -func expvarHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - fmt.Fprintf(w, "{\n") - first := true - expvar.Do(func(kv expvar.KeyValue) { - if !first { - fmt.Fprintf(w, ",\n") - } - first = false - fmt.Fprintf(w, "%q: %s", kv.Key, kv.Value) - }) - fmt.Fprintf(w, "\n}\n") -} - -// Resource contains the path to the expvar entry point -type Resource string diff --git a/middleware/expvar/expvar_test.go b/middleware/expvar/expvar_test.go deleted file mode 100644 index e702f9418..000000000 --- a/middleware/expvar/expvar_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package expvar - -import ( - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/mholt/caddy/middleware" -) - -func TestExpVar(t *testing.T) { - rw := ExpVar{ - Next: middleware.HandlerFunc(contentHandler), - Resource: "/d/v", - } - - tests := []struct { - from string - result int - }{ - {"/d/v", 0}, - {"/x/y", http.StatusOK}, - } - - for i, test := range tests { - req, err := http.NewRequest("GET", test.from, nil) - if err != nil { - t.Fatalf("Test %d: Could not create HTTP request %v", i, err) - } - rec := httptest.NewRecorder() - result, err := rw.ServeHTTP(rec, req) - if err != nil { - t.Fatalf("Test %d: Could not ServeHTTP %v", i, err) - } - if result != test.result { - t.Errorf("Test %d: Expected Header '%d' but was '%d'", - i, test.result, result) - } - } -} - -func contentHandler(w http.ResponseWriter, r *http.Request) (int, error) { - fmt.Fprintf(w, r.URL.String()) - return http.StatusOK, nil -} diff --git a/middleware/extensions/ext.go b/middleware/extensions/ext.go deleted file mode 100644 index 6796325d8..000000000 --- a/middleware/extensions/ext.go +++ /dev/null @@ -1,52 +0,0 @@ -// Package extensions contains middleware for clean URLs. -// -// The root path of the site is passed in as well as possible extensions -// to try internally for paths requested that don't match an existing -// resource. The first path+ext combination that matches a valid file -// will be used. -package extensions - -import ( - "net/http" - "os" - "path" - "strings" - - "github.com/mholt/caddy/middleware" -) - -// Ext can assume an extension from clean URLs. -// It tries extensions in the order listed in Extensions. -type Ext struct { - // Next handler in the chain - Next middleware.Handler - - // Path to ther root of the site - Root string - - // List of extensions to try - Extensions []string -} - -// ServeHTTP implements the middleware.Handler interface. -func (e Ext) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - urlpath := strings.TrimSuffix(r.URL.Path, "/") - if path.Ext(urlpath) == "" && len(r.URL.Path) > 0 && r.URL.Path[len(r.URL.Path)-1] != '/' { - for _, ext := range e.Extensions { - if resourceExists(e.Root, urlpath+ext) { - r.URL.Path = urlpath + ext - break - } - } - } - return e.Next.ServeHTTP(w, r) -} - -// resourceExists returns true if the file specified at -// root + path exists; false otherwise. -func resourceExists(root, path string) bool { - _, err := os.Stat(root + path) - // technically we should use os.IsNotExist(err) - // but we don't handle any other kinds of errors anyway - return err == nil -} diff --git a/middleware/extensions/ext_test.go b/middleware/extensions/ext_test.go deleted file mode 100644 index f03eaa2f3..000000000 --- a/middleware/extensions/ext_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package extensions - -import ( - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "testing" - - "github.com/mholt/caddy/middleware" -) - -func TestExtensions(t *testing.T) { - rootDir := os.TempDir() - - // create a temporary page - path := filepath.Join(rootDir, "extensions_test.html") - _, err := os.Create(path) - if err != nil { - t.Fatal(err) - } - defer os.Remove(path) - - for i, test := range []struct { - path string - extensions []string - expectedURL string - }{ - {"/extensions_test", []string{".html"}, "/extensions_test.html"}, - {"/extensions_test/", []string{".html"}, "/extensions_test/"}, - {"/extensions_test", []string{".json"}, "/extensions_test"}, - {"/another_test", []string{".html"}, "/another_test"}, - {"", []string{".html"}, ""}, - } { - ex := Ext{ - Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - return 0, nil - }), - Root: rootDir, - Extensions: test.extensions, - } - - req, err := http.NewRequest("GET", test.path, nil) - if err != nil { - t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) - } - - rec := httptest.NewRecorder() - - ex.ServeHTTP(rec, req) - - if got := req.URL.String(); got != test.expectedURL { - t.Fatalf("Test %d: Got unexpected request URL: %q, wanted %q", i, got, test.expectedURL) - } - } -} diff --git a/middleware/fastcgi/fastcgi.go b/middleware/fastcgi/fastcgi.go deleted file mode 100755 index 9db71f594..000000000 --- a/middleware/fastcgi/fastcgi.go +++ /dev/null @@ -1,336 +0,0 @@ -// Package fastcgi has middleware that acts as a FastCGI client. Requests -// that get forwarded to FastCGI stop the middleware execution chain. -// The most common use for this package is to serve PHP websites via php-fpm. -package fastcgi - -import ( - "errors" - "io" - "net/http" - "os" - "path" - "path/filepath" - "strconv" - "strings" - - "github.com/mholt/caddy/middleware" -) - -// Handler is a middleware type that can handle requests as a FastCGI client. -type Handler struct { - Next middleware.Handler - Rules []Rule - Root string - AbsRoot string // same as root, but absolute path - FileSys http.FileSystem - - // These are sent to CGI scripts in env variables - SoftwareName string - SoftwareVersion string - ServerName string - ServerPort string -} - -// ServeHTTP satisfies the middleware.Handler interface. -func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - for _, rule := range h.Rules { - - // First requirement: Base path must match and the path must be allowed. - if !middleware.Path(r.URL.Path).Matches(rule.Path) || !rule.AllowedPath(r.URL.Path) { - continue - } - - // In addition to matching the path, a request must meet some - // other criteria before being proxied as FastCGI. For example, - // we probably want to exclude static assets (CSS, JS, images...) - // but we also want to be flexible for the script we proxy to. - - fpath := r.URL.Path - - if idx, ok := middleware.IndexFile(h.FileSys, fpath, rule.IndexFiles); ok { - fpath = idx - // Index file present. - // If request path cannot be split, return error. - if !rule.canSplit(fpath) { - return http.StatusInternalServerError, ErrIndexMissingSplit - } - } else { - // No index file present. - // If request path cannot be split, ignore request. - if !rule.canSplit(fpath) { - continue - } - } - - // These criteria work well in this order for PHP sites - if !h.exists(fpath) || fpath[len(fpath)-1] == '/' || strings.HasSuffix(fpath, rule.Ext) { - - // Create environment for CGI script - env, err := h.buildEnv(r, rule, fpath) - if err != nil { - return http.StatusInternalServerError, err - } - - // Connect to FastCGI gateway - network, address := rule.parseAddress() - fcgiBackend, err := Dial(network, address) - if err != nil { - return http.StatusBadGateway, err - } - - var resp *http.Response - contentLength, _ := strconv.Atoi(r.Header.Get("Content-Length")) - switch r.Method { - case "HEAD": - resp, err = fcgiBackend.Head(env) - case "GET": - resp, err = fcgiBackend.Get(env) - case "OPTIONS": - resp, err = fcgiBackend.Options(env) - default: - resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength) - } - - if resp.Body != nil { - defer resp.Body.Close() - } - - if err != nil && err != io.EOF { - return http.StatusBadGateway, err - } - - // Write response header - writeHeader(w, resp) - - // Write the response body - _, err = io.Copy(w, resp.Body) - if err != nil { - return http.StatusBadGateway, err - } - - // Log any stderr output from upstream - if fcgiBackend.stderr.Len() != 0 { - // Remove trailing newline, error logger already does this. - err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n")) - } - - // Normally we would return the status code if it is an error status (>= 400), - // however, upstream FastCGI apps don't know about our contract and have - // probably already written an error page. So we just return 0, indicating - // that the response body is already written. However, we do return any - // error value so it can be logged. - // Note that the proxy middleware works the same way, returning status=0. - return 0, err - } - } - - return h.Next.ServeHTTP(w, r) -} - -// parseAddress returns the network and address of r. -// The first string is the network, "tcp" or "unix", implied from the scheme and address. -// The second string is r.Address, with scheme prefixes removed. -// The two returned strings can be used as parameters to the Dial() function. -func (r Rule) parseAddress() (string, string) { - // check if address has tcp scheme explicitly set - if strings.HasPrefix(r.Address, "tcp://") { - return "tcp", r.Address[len("tcp://"):] - } - // check if address has fastcgi scheme explicitly set - if strings.HasPrefix(r.Address, "fastcgi://") { - return "tcp", r.Address[len("fastcgi://"):] - } - // check if unix socket - if trim := strings.HasPrefix(r.Address, "unix"); strings.HasPrefix(r.Address, "/") || trim { - if trim { - return "unix", r.Address[len("unix:"):] - } - return "unix", r.Address - } - // default case, a plain tcp address with no scheme - return "tcp", r.Address -} - -func writeHeader(w http.ResponseWriter, r *http.Response) { - for key, vals := range r.Header { - for _, val := range vals { - w.Header().Add(key, val) - } - } - w.WriteHeader(r.StatusCode) -} - -func (h Handler) exists(path string) bool { - if _, err := os.Stat(h.Root + path); err == nil { - return true - } - return false -} - -// buildEnv returns a set of CGI environment variables for the request. -func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string]string, error) { - var env map[string]string - - // Get absolute path of requested resource - absPath := filepath.Join(h.AbsRoot, fpath) - - // Separate remote IP and port; more lenient than net.SplitHostPort - var ip, port string - if idx := strings.LastIndex(r.RemoteAddr, ":"); idx > -1 { - ip = r.RemoteAddr[:idx] - port = r.RemoteAddr[idx+1:] - } else { - ip = r.RemoteAddr - } - - // Remove [] from IPv6 addresses - ip = strings.Replace(ip, "[", "", 1) - ip = strings.Replace(ip, "]", "", 1) - - // Split path in preparation for env variables. - // Previous rule.canSplit checks ensure this can never be -1. - splitPos := rule.splitPos(fpath) - - // Request has the extension; path was split successfully - docURI := fpath[:splitPos+len(rule.SplitPath)] - pathInfo := fpath[splitPos+len(rule.SplitPath):] - scriptName := fpath - scriptFilename := absPath - - // Strip PATH_INFO from SCRIPT_NAME - scriptName = strings.TrimSuffix(scriptName, pathInfo) - - // Get the request URI. The request URI might be as it came in over the wire, - // or it might have been rewritten internally by the rewrite middleware (see issue #256). - // If it was rewritten, there will be a header indicating the original URL, - // which is needed to get the correct RequestURI value for PHP apps. - const internalRewriteFieldName = "Caddy-Rewrite-Original-URI" - reqURI := r.URL.RequestURI() - if origURI := r.Header.Get(internalRewriteFieldName); origURI != "" { - reqURI = origURI - r.Header.Del(internalRewriteFieldName) - } - - // Some variables are unused but cleared explicitly to prevent - // the parent environment from interfering. - env = map[string]string{ - - // Variables defined in CGI 1.1 spec - "AUTH_TYPE": "", // Not used - "CONTENT_LENGTH": r.Header.Get("Content-Length"), - "CONTENT_TYPE": r.Header.Get("Content-Type"), - "GATEWAY_INTERFACE": "CGI/1.1", - "PATH_INFO": pathInfo, - "QUERY_STRING": r.URL.RawQuery, - "REMOTE_ADDR": ip, - "REMOTE_HOST": ip, // For speed, remote host lookups disabled - "REMOTE_PORT": port, - "REMOTE_IDENT": "", // Not used - "REMOTE_USER": "", // Not used - "REQUEST_METHOD": r.Method, - "SERVER_NAME": h.ServerName, - "SERVER_PORT": h.ServerPort, - "SERVER_PROTOCOL": r.Proto, - "SERVER_SOFTWARE": h.SoftwareName + "/" + h.SoftwareVersion, - - // Other variables - "DOCUMENT_ROOT": h.AbsRoot, - "DOCUMENT_URI": docURI, - "HTTP_HOST": r.Host, // added here, since not always part of headers - "REQUEST_URI": reqURI, - "SCRIPT_FILENAME": scriptFilename, - "SCRIPT_NAME": scriptName, - } - - // compliance with the CGI specification that PATH_TRANSLATED - // should only exist if PATH_INFO is defined. - // Info: https://www.ietf.org/rfc/rfc3875 Page 14 - if env["PATH_INFO"] != "" { - env["PATH_TRANSLATED"] = filepath.Join(h.AbsRoot, pathInfo) // Info: http://www.oreilly.com/openbook/cgi/ch02_04.html - } - - // Some web apps rely on knowing HTTPS or not - if r.TLS != nil { - env["HTTPS"] = "on" - } - - // Add env variables from config - for _, envVar := range rule.EnvVars { - env[envVar[0]] = envVar[1] - } - - // Add all HTTP headers to env variables - for field, val := range r.Header { - header := strings.ToUpper(field) - header = headerNameReplacer.Replace(header) - env["HTTP_"+header] = strings.Join(val, ", ") - } - - return env, nil -} - -// Rule represents a FastCGI handling rule. -type Rule struct { - // The base path to match. Required. - Path string - - // The address of the FastCGI server. Required. - Address string - - // Always process files with this extension with fastcgi. - Ext string - - // The path in the URL will be split into two, with the first piece ending - // with the value of SplitPath. The first piece will be assumed as the - // actual resource (CGI script) name, and the second piece will be set to - // PATH_INFO for the CGI script to use. - SplitPath string - - // If the URL ends with '/' (which indicates a directory), these index - // files will be tried instead. - IndexFiles []string - - // Environment Variables - EnvVars [][2]string - - // Ignored paths - IgnoredSubPaths []string -} - -// canSplit checks if path can split into two based on rule.SplitPath. -func (r Rule) canSplit(path string) bool { - return r.splitPos(path) >= 0 -} - -// splitPos returns the index where path should be split -// based on rule.SplitPath. -func (r Rule) splitPos(path string) int { - if middleware.CaseSensitivePath { - return strings.Index(path, r.SplitPath) - } - return strings.Index(strings.ToLower(path), strings.ToLower(r.SplitPath)) -} - -// AllowedPath checks if requestPath is not an ignored path. -func (r Rule) AllowedPath(requestPath string) bool { - for _, ignoredSubPath := range r.IgnoredSubPaths { - if middleware.Path(path.Clean(requestPath)).Matches(path.Join(r.Path, ignoredSubPath)) { - return false - } - } - return true -} - -var ( - headerNameReplacer = strings.NewReplacer(" ", "_", "-", "_") - // ErrIndexMissingSplit describes an index configuration error. - ErrIndexMissingSplit = errors.New("configured index file(s) must include split value") -) - -// LogError is a non fatal error that allows requests to go through. -type LogError string - -// Error satisfies error interface. -func (l LogError) Error() string { - return string(l) -} diff --git a/middleware/fastcgi/fastcgi_test.go b/middleware/fastcgi/fastcgi_test.go deleted file mode 100644 index e1e394919..000000000 --- a/middleware/fastcgi/fastcgi_test.go +++ /dev/null @@ -1,160 +0,0 @@ -package fastcgi - -import ( - "net" - "net/http" - "net/http/fcgi" - "net/http/httptest" - "net/url" - "strconv" - "testing" -) - -func TestServeHTTP(t *testing.T) { - body := "This is some test body content" - - bodyLenStr := strconv.Itoa(len(body)) - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("Unable to create listener for test: %v", err) - } - defer listener.Close() - go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Length", bodyLenStr) - w.Write([]byte(body)) - })) - - handler := Handler{ - Next: nil, - Rules: []Rule{{Path: "/", Address: listener.Addr().String()}}, - } - r, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatalf("Unable to create request: %v", err) - } - w := httptest.NewRecorder() - - status, err := handler.ServeHTTP(w, r) - - if got, want := status, 0; got != want { - t.Errorf("Expected returned status code to be %d, got %d", want, got) - } - if err != nil { - t.Errorf("Expected nil error, got: %v", err) - } - if got, want := w.Header().Get("Content-Length"), bodyLenStr; got != want { - t.Errorf("Expected Content-Length to be '%s', got: '%s'", want, got) - } - if got, want := w.Body.String(), body; got != want { - t.Errorf("Expected response body to be '%s', got: '%s'", want, got) - } -} - -func TestRuleParseAddress(t *testing.T) { - getClientTestTable := []struct { - rule *Rule - expectednetwork string - expectedaddress string - }{ - {&Rule{Address: "tcp://172.17.0.1:9000"}, "tcp", "172.17.0.1:9000"}, - {&Rule{Address: "fastcgi://localhost:9000"}, "tcp", "localhost:9000"}, - {&Rule{Address: "172.17.0.15"}, "tcp", "172.17.0.15"}, - {&Rule{Address: "/my/unix/socket"}, "unix", "/my/unix/socket"}, - {&Rule{Address: "unix:/second/unix/socket"}, "unix", "/second/unix/socket"}, - } - - for _, entry := range getClientTestTable { - if actualnetwork, _ := entry.rule.parseAddress(); actualnetwork != entry.expectednetwork { - t.Errorf("Unexpected network for address string %v. Got %v, expected %v", entry.rule.Address, actualnetwork, entry.expectednetwork) - } - if _, actualaddress := entry.rule.parseAddress(); actualaddress != entry.expectedaddress { - t.Errorf("Unexpected parsed address for address string %v. Got %v, expected %v", entry.rule.Address, actualaddress, entry.expectedaddress) - } - } -} - -func TestRuleIgnoredPath(t *testing.T) { - rule := &Rule{ - Path: "/fastcgi", - IgnoredSubPaths: []string{"/download", "/static"}, - } - tests := []struct { - url string - expected bool - }{ - {"/fastcgi", true}, - {"/fastcgi/dl", true}, - {"/fastcgi/download", false}, - {"/fastcgi/download/static", false}, - {"/fastcgi/static", false}, - {"/fastcgi/static/download", false}, - {"/fastcgi/something/download", true}, - {"/fastcgi/something/static", true}, - {"/fastcgi//static", false}, - {"/fastcgi//static//download", false}, - {"/fastcgi//download", false}, - } - - for i, test := range tests { - allowed := rule.AllowedPath(test.url) - if test.expected != allowed { - t.Errorf("Test %d: expected %v found %v", i, test.expected, allowed) - } - } -} - -func TestBuildEnv(t *testing.T) { - testBuildEnv := func(r *http.Request, rule Rule, fpath string, envExpected map[string]string) { - var h Handler - env, err := h.buildEnv(r, rule, fpath) - if err != nil { - t.Error("Unexpected error:", err.Error()) - } - for k, v := range envExpected { - if env[k] != v { - t.Errorf("Unexpected %v. Got %v, expected %v", k, env[k], v) - } - } - } - - rule := Rule{} - url, err := url.Parse("http://localhost:2015/fgci_test.php?test=blabla") - if err != nil { - t.Error("Unexpected error:", err.Error()) - } - - r := http.Request{ - Method: "GET", - URL: url, - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Host: "localhost:2015", - RemoteAddr: "[2b02:1810:4f2d:9400:70ab:f822:be8a:9093]:51688", - RequestURI: "/fgci_test.php", - } - - fpath := "/fgci_test.php" - - var envExpected = map[string]string{ - "REMOTE_ADDR": "2b02:1810:4f2d:9400:70ab:f822:be8a:9093", - "REMOTE_PORT": "51688", - "SERVER_PROTOCOL": "HTTP/1.1", - "QUERY_STRING": "test=blabla", - "REQUEST_METHOD": "GET", - "HTTP_HOST": "localhost:2015", - } - - // 1. Test for full canonical IPv6 address - testBuildEnv(&r, rule, fpath, envExpected) - - // 2. Test for shorthand notation of IPv6 address - r.RemoteAddr = "[::1]:51688" - envExpected["REMOTE_ADDR"] = "::1" - testBuildEnv(&r, rule, fpath, envExpected) - - // 3. Test for IPv4 address - r.RemoteAddr = "192.168.0.10:51688" - envExpected["REMOTE_ADDR"] = "192.168.0.10" - testBuildEnv(&r, rule, fpath, envExpected) -} diff --git a/middleware/fastcgi/fcgi_test.php b/middleware/fastcgi/fcgi_test.php deleted file mode 100644 index 3f5e5f2db..000000000 --- a/middleware/fastcgi/fcgi_test.php +++ /dev/null @@ -1,79 +0,0 @@ - $val) { - $md5 = md5($val); - - if ($key != $md5) { - $stat = "FAILED"; - echo "server:err ".$md5." != ".$key."\n"; - } - - $length += strlen($key) + strlen($val); - - $ret .= $key."(".strlen($key).") "; - } - $ret .= "] ["; - foreach ($_FILES as $k0 => $val) { - - $error = $val["error"]; - if ($error == UPLOAD_ERR_OK) { - $tmp_name = $val["tmp_name"]; - $name = $val["name"]; - $datafile = "/tmp/test.go"; - move_uploaded_file($tmp_name, $datafile); - $md5 = md5_file($datafile); - - if ($k0 != $md5) { - $stat = "FAILED"; - echo "server:err ".$md5." != ".$key."\n"; - } - - $length += strlen($k0) + filesize($datafile); - - unlink($datafile); - $ret .= $k0."(".strlen($k0).") "; - } - else{ - $stat = "FAILED"; - echo "server:file err ".file_upload_error_message($error)."\n"; - } - } - $ret .= "]"; - echo "server:got data length " .$length."\n"; -} - - -echo "-{$stat}-POST(".count($_POST).") FILE(".count($_FILES).")\n"; - -function file_upload_error_message($error_code) { - switch ($error_code) { - case UPLOAD_ERR_INI_SIZE: - return 'The uploaded file exceeds the upload_max_filesize directive in php.ini'; - case UPLOAD_ERR_FORM_SIZE: - return 'The uploaded file exceeds the MAX_FILE_SIZE directive that was specified in the HTML form'; - case UPLOAD_ERR_PARTIAL: - return 'The uploaded file was only partially uploaded'; - case UPLOAD_ERR_NO_FILE: - return 'No file was uploaded'; - case UPLOAD_ERR_NO_TMP_DIR: - return 'Missing a temporary folder'; - case UPLOAD_ERR_CANT_WRITE: - return 'Failed to write file to disk'; - case UPLOAD_ERR_EXTENSION: - return 'File upload stopped by extension'; - default: - return 'Unknown upload error'; - } -} \ No newline at end of file diff --git a/middleware/fastcgi/fcgiclient.go b/middleware/fastcgi/fcgiclient.go deleted file mode 100644 index f443f63d9..000000000 --- a/middleware/fastcgi/fcgiclient.go +++ /dev/null @@ -1,560 +0,0 @@ -// Forked Jan. 2015 from http://bitbucket.org/PinIdea/fcgi_client -// (which is forked from https://code.google.com/p/go-fastcgi-client/) - -// This fork contains several fixes and improvements by Matt Holt and -// other contributors to this project. - -// Copyright 2012 Junqing Tan and The Go Authors -// Use of this source code is governed by a BSD-style -// Part of source code is from Go fcgi package - -package fastcgi - -import ( - "bufio" - "bytes" - "encoding/binary" - "errors" - "io" - "io/ioutil" - "mime/multipart" - "net" - "net/http" - "net/http/httputil" - "net/textproto" - "net/url" - "os" - "path/filepath" - "strconv" - "strings" - "sync" -) - -// FCGIListenSockFileno describes listen socket file number. -const FCGIListenSockFileno uint8 = 0 - -// FCGIHeaderLen describes header length. -const FCGIHeaderLen uint8 = 8 - -// Version1 describes the version. -const Version1 uint8 = 1 - -// FCGINullRequestID describes the null request ID. -const FCGINullRequestID uint8 = 0 - -// FCGIKeepConn describes keep connection mode. -const FCGIKeepConn uint8 = 1 -const doubleCRLF = "\r\n\r\n" - -const ( - // BeginRequest is the begin request flag. - BeginRequest uint8 = iota + 1 - // AbortRequest is the abort request flag. - AbortRequest - // EndRequest is the end request flag. - EndRequest - // Params is the parameters flag. - Params - // Stdin is the standard input flag. - Stdin - // Stdout is the standard output flag. - Stdout - // Stderr is the standard error flag. - Stderr - // Data is the data flag. - Data - // GetValues is the get values flag. - GetValues - // GetValuesResult is the get values result flag. - GetValuesResult - // UnknownType is the unknown type flag. - UnknownType - // MaxType is the maximum type flag. - MaxType = UnknownType -) - -const ( - // Responder is the responder flag. - Responder uint8 = iota + 1 - // Authorizer is the authorizer flag. - Authorizer - // Filter is the filter flag. - Filter -) - -const ( - // RequestComplete is the completed request flag. - RequestComplete uint8 = iota - // CantMultiplexConns is the multiplexed connections flag. - CantMultiplexConns - // Overloaded is the overloaded flag. - Overloaded - // UnknownRole is the unknown role flag. - UnknownRole -) - -const ( - // MaxConns is the maximum connections flag. - MaxConns string = "MAX_CONNS" - // MaxRequests is the maximum requests flag. - MaxRequests string = "MAX_REQS" - // MultiplexConns is the multiplex connections flag. - MultiplexConns string = "MPXS_CONNS" -) - -const ( - maxWrite = 65500 // 65530 may work, but for compatibility - maxPad = 255 -) - -type header struct { - Version uint8 - Type uint8 - ID uint16 - ContentLength uint16 - PaddingLength uint8 - Reserved uint8 -} - -// for padding so we don't have to allocate all the time -// not synchronized because we don't care what the contents are -var pad [maxPad]byte - -func (h *header) init(recType uint8, reqID uint16, contentLength int) { - h.Version = 1 - h.Type = recType - h.ID = reqID - h.ContentLength = uint16(contentLength) - h.PaddingLength = uint8(-contentLength & 7) -} - -type record struct { - h header - rbuf []byte -} - -func (rec *record) read(r io.Reader) (buf []byte, err error) { - if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil { - return - } - if rec.h.Version != 1 { - err = errors.New("fcgi: invalid header version") - return - } - if rec.h.Type == EndRequest { - err = io.EOF - return - } - n := int(rec.h.ContentLength) + int(rec.h.PaddingLength) - if len(rec.rbuf) < n { - rec.rbuf = make([]byte, n) - } - if _, err = io.ReadFull(r, rec.rbuf[:n]); err != nil { - return - } - buf = rec.rbuf[:int(rec.h.ContentLength)] - - return -} - -// FCGIClient implements a FastCGI client, which is a standard for -// interfacing external applications with Web servers. -type FCGIClient struct { - mutex sync.Mutex - rwc io.ReadWriteCloser - h header - buf bytes.Buffer - stderr bytes.Buffer - keepAlive bool - reqID uint16 -} - -// DialWithDialer connects to the fcgi responder at the specified network address, using custom net.Dialer. -// See func net.Dial for a description of the network and address parameters. -func DialWithDialer(network, address string, dialer net.Dialer) (fcgi *FCGIClient, err error) { - var conn net.Conn - conn, err = dialer.Dial(network, address) - if err != nil { - return - } - - fcgi = &FCGIClient{ - rwc: conn, - keepAlive: false, - reqID: 1, - } - - return -} - -// Dial connects to the fcgi responder at the specified network address, using default net.Dialer. -// See func net.Dial for a description of the network and address parameters. -func Dial(network, address string) (fcgi *FCGIClient, err error) { - return DialWithDialer(network, address, net.Dialer{}) -} - -// Close closes fcgi connnection -func (c *FCGIClient) Close() { - c.rwc.Close() -} - -func (c *FCGIClient) writeRecord(recType uint8, content []byte) (err error) { - c.mutex.Lock() - defer c.mutex.Unlock() - c.buf.Reset() - c.h.init(recType, c.reqID, len(content)) - if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil { - return err - } - if _, err := c.buf.Write(content); err != nil { - return err - } - if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil { - return err - } - _, err = c.rwc.Write(c.buf.Bytes()) - return err -} - -func (c *FCGIClient) writeBeginRequest(role uint16, flags uint8) error { - b := [8]byte{byte(role >> 8), byte(role), flags} - return c.writeRecord(BeginRequest, b[:]) -} - -func (c *FCGIClient) writeEndRequest(appStatus int, protocolStatus uint8) error { - b := make([]byte, 8) - binary.BigEndian.PutUint32(b, uint32(appStatus)) - b[4] = protocolStatus - return c.writeRecord(EndRequest, b) -} - -func (c *FCGIClient) writePairs(recType uint8, pairs map[string]string) error { - w := newWriter(c, recType) - b := make([]byte, 8) - nn := 0 - for k, v := range pairs { - m := 8 + len(k) + len(v) - if m > maxWrite { - // param data size exceed 65535 bytes" - vl := maxWrite - 8 - len(k) - v = v[:vl] - } - n := encodeSize(b, uint32(len(k))) - n += encodeSize(b[n:], uint32(len(v))) - m = n + len(k) + len(v) - if (nn + m) > maxWrite { - w.Flush() - nn = 0 - } - nn += m - if _, err := w.Write(b[:n]); err != nil { - return err - } - if _, err := w.WriteString(k); err != nil { - return err - } - if _, err := w.WriteString(v); err != nil { - return err - } - } - w.Close() - return nil -} - -func readSize(s []byte) (uint32, int) { - if len(s) == 0 { - return 0, 0 - } - size, n := uint32(s[0]), 1 - if size&(1<<7) != 0 { - if len(s) < 4 { - return 0, 0 - } - n = 4 - size = binary.BigEndian.Uint32(s) - size &^= 1 << 31 - } - return size, n -} - -func readString(s []byte, size uint32) string { - if size > uint32(len(s)) { - return "" - } - return string(s[:size]) -} - -func encodeSize(b []byte, size uint32) int { - if size > 127 { - size |= 1 << 31 - binary.BigEndian.PutUint32(b, size) - return 4 - } - b[0] = byte(size) - return 1 -} - -// bufWriter encapsulates bufio.Writer but also closes the underlying stream when -// Closed. -type bufWriter struct { - closer io.Closer - *bufio.Writer -} - -func (w *bufWriter) Close() error { - if err := w.Writer.Flush(); err != nil { - w.closer.Close() - return err - } - return w.closer.Close() -} - -func newWriter(c *FCGIClient, recType uint8) *bufWriter { - s := &streamWriter{c: c, recType: recType} - w := bufio.NewWriterSize(s, maxWrite) - return &bufWriter{s, w} -} - -// streamWriter abstracts out the separation of a stream into discrete records. -// It only writes maxWrite bytes at a time. -type streamWriter struct { - c *FCGIClient - recType uint8 -} - -func (w *streamWriter) Write(p []byte) (int, error) { - nn := 0 - for len(p) > 0 { - n := len(p) - if n > maxWrite { - n = maxWrite - } - if err := w.c.writeRecord(w.recType, p[:n]); err != nil { - return nn, err - } - nn += n - p = p[n:] - } - return nn, nil -} - -func (w *streamWriter) Close() error { - // send empty record to close the stream - return w.c.writeRecord(w.recType, nil) -} - -type streamReader struct { - c *FCGIClient - buf []byte -} - -func (w *streamReader) Read(p []byte) (n int, err error) { - - if len(p) > 0 { - if len(w.buf) == 0 { - - // filter outputs for error log - for { - rec := &record{} - var buf []byte - buf, err = rec.read(w.c.rwc) - if err != nil { - return - } - // standard error output - if rec.h.Type == Stderr { - w.c.stderr.Write(buf) - continue - } - w.buf = buf - break - } - } - - n = len(p) - if n > len(w.buf) { - n = len(w.buf) - } - copy(p, w.buf[:n]) - w.buf = w.buf[n:] - } - - return -} - -// Do made the request and returns a io.Reader that translates the data read -// from fcgi responder out of fcgi packet before returning it. -func (c *FCGIClient) Do(p map[string]string, req io.Reader) (r io.Reader, err error) { - err = c.writeBeginRequest(uint16(Responder), 0) - if err != nil { - return - } - - err = c.writePairs(Params, p) - if err != nil { - return - } - - body := newWriter(c, Stdin) - if req != nil { - io.Copy(body, req) - } - body.Close() - - r = &streamReader{c: c} - return -} - -// clientCloser is a io.ReadCloser. It wraps a io.Reader with a Closer -// that closes FCGIClient connection. -type clientCloser struct { - *FCGIClient - io.Reader -} - -func (f clientCloser) Close() error { return f.rwc.Close() } - -// Request returns a HTTP Response with Header and Body -// from fcgi responder -func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Response, err error) { - - r, err := c.Do(p, req) - if err != nil { - return - } - - rb := bufio.NewReader(r) - tp := textproto.NewReader(rb) - resp = new(http.Response) - - // Parse the response headers. - mimeHeader, err := tp.ReadMIMEHeader() - if err != nil && err != io.EOF { - return - } - resp.Header = http.Header(mimeHeader) - - if resp.Header.Get("Status") != "" { - statusParts := strings.SplitN(resp.Header.Get("Status"), " ", 2) - resp.StatusCode, err = strconv.Atoi(statusParts[0]) - if err != nil { - return - } - if len(statusParts) > 1 { - resp.Status = statusParts[1] - } - - } else { - resp.StatusCode = http.StatusOK - } - - // TODO: fixTransferEncoding ? - resp.TransferEncoding = resp.Header["Transfer-Encoding"] - resp.ContentLength, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) - - if chunked(resp.TransferEncoding) { - resp.Body = clientCloser{c, httputil.NewChunkedReader(rb)} - } else { - resp.Body = clientCloser{c, ioutil.NopCloser(rb)} - } - return -} - -// Get issues a GET request to the fcgi responder. -func (c *FCGIClient) Get(p map[string]string) (resp *http.Response, err error) { - - p["REQUEST_METHOD"] = "GET" - p["CONTENT_LENGTH"] = "0" - - return c.Request(p, nil) -} - -// Head issues a HEAD request to the fcgi responder. -func (c *FCGIClient) Head(p map[string]string) (resp *http.Response, err error) { - - p["REQUEST_METHOD"] = "HEAD" - p["CONTENT_LENGTH"] = "0" - - return c.Request(p, nil) -} - -// Options issues an OPTIONS request to the fcgi responder. -func (c *FCGIClient) Options(p map[string]string) (resp *http.Response, err error) { - - p["REQUEST_METHOD"] = "OPTIONS" - p["CONTENT_LENGTH"] = "0" - - return c.Request(p, nil) -} - -// Post issues a POST request to the fcgi responder. with request body -// in the format that bodyType specified -func (c *FCGIClient) Post(p map[string]string, method string, bodyType string, body io.Reader, l int) (resp *http.Response, err error) { - if p == nil { - p = make(map[string]string) - } - - p["REQUEST_METHOD"] = strings.ToUpper(method) - - if len(p["REQUEST_METHOD"]) == 0 || p["REQUEST_METHOD"] == "GET" { - p["REQUEST_METHOD"] = "POST" - } - - p["CONTENT_LENGTH"] = strconv.Itoa(l) - if len(bodyType) > 0 { - p["CONTENT_TYPE"] = bodyType - } else { - p["CONTENT_TYPE"] = "application/x-www-form-urlencoded" - } - - return c.Request(p, body) -} - -// PostForm issues a POST to the fcgi responder, with form -// as a string key to a list values (url.Values) -func (c *FCGIClient) PostForm(p map[string]string, data url.Values) (resp *http.Response, err error) { - body := bytes.NewReader([]byte(data.Encode())) - return c.Post(p, "POST", "application/x-www-form-urlencoded", body, body.Len()) -} - -// PostFile issues a POST to the fcgi responder in multipart(RFC 2046) standard, -// with form as a string key to a list values (url.Values), -// and/or with file as a string key to a list file path. -func (c *FCGIClient) PostFile(p map[string]string, data url.Values, file map[string]string) (resp *http.Response, err error) { - buf := &bytes.Buffer{} - writer := multipart.NewWriter(buf) - bodyType := writer.FormDataContentType() - - for key, val := range data { - for _, v0 := range val { - err = writer.WriteField(key, v0) - if err != nil { - return - } - } - } - - for key, val := range file { - fd, e := os.Open(val) - if e != nil { - return nil, e - } - defer fd.Close() - - part, e := writer.CreateFormFile(key, filepath.Base(val)) - if e != nil { - return nil, e - } - _, err = io.Copy(part, fd) - } - - err = writer.Close() - if err != nil { - return - } - - return c.Post(p, "POST", bodyType, buf, buf.Len()) -} - -// Checks whether chunked is part of the encodings stack -func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } diff --git a/middleware/fastcgi/fcgiclient_test.go b/middleware/fastcgi/fcgiclient_test.go deleted file mode 100644 index c4d997844..000000000 --- a/middleware/fastcgi/fcgiclient_test.go +++ /dev/null @@ -1,275 +0,0 @@ -// NOTE: These tests were adapted from the original -// repository from which this package was forked. -// The tests are slow (~10s) and in dire need of rewriting. -// As such, the tests have been disabled to speed up -// automated builds until they can be properly written. - -package fastcgi - -import ( - "bytes" - "crypto/md5" - "encoding/binary" - "fmt" - "io" - "io/ioutil" - "log" - "math/rand" - "net" - "net/http" - "net/http/fcgi" - "net/url" - "os" - "path/filepath" - "strconv" - "strings" - "testing" - "time" -) - -// test fcgi protocol includes: -// Get, Post, Post in multipart/form-data, and Post with files -// each key should be the md5 of the value or the file uploaded -// sepicify remote fcgi responer ip:port to test with php -// test failed if the remote fcgi(script) failed md5 verification -// and output "FAILED" in response -const ( - scriptFile = "/tank/www/fcgic_test.php" - //ipPort = "remote-php-serv:59000" - ipPort = "127.0.0.1:59000" -) - -var globalt *testing.T - -type FastCGIServer struct{} - -func (s FastCGIServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) { - - req.ParseMultipartForm(100000000) - - stat := "PASSED" - fmt.Fprintln(resp, "-") - fileNum := 0 - { - length := 0 - for k0, v0 := range req.Form { - h := md5.New() - io.WriteString(h, v0[0]) - md5 := fmt.Sprintf("%x", h.Sum(nil)) - - length += len(k0) - length += len(v0[0]) - - // echo error when key != md5(val) - if md5 != k0 { - fmt.Fprintln(resp, "server:err ", md5, k0) - stat = "FAILED" - } - } - if req.MultipartForm != nil { - fileNum = len(req.MultipartForm.File) - for kn, fns := range req.MultipartForm.File { - //fmt.Fprintln(resp, "server:filekey ", kn ) - length += len(kn) - for _, f := range fns { - fd, err := f.Open() - if err != nil { - log.Println("server:", err) - return - } - h := md5.New() - l0, err := io.Copy(h, fd) - if err != nil { - log.Println(err) - return - } - length += int(l0) - defer fd.Close() - md5 := fmt.Sprintf("%x", h.Sum(nil)) - //fmt.Fprintln(resp, "server:filemd5 ", md5 ) - - if kn != md5 { - fmt.Fprintln(resp, "server:err ", md5, kn) - stat = "FAILED" - } - //fmt.Fprintln(resp, "server:filename ", f.Filename ) - } - } - } - - fmt.Fprintln(resp, "server:got data length", length) - } - fmt.Fprintln(resp, "-"+stat+"-POST(", len(req.Form), ")-FILE(", fileNum, ")--") -} - -func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[string]string, files map[string]string) (content []byte) { - fcgi, err := Dial("tcp", ipPort) - if err != nil { - log.Println("err:", err) - return - } - - length := 0 - - var resp *http.Response - switch reqType { - case 0: - if len(data) > 0 { - length = len(data) - rd := bytes.NewReader(data) - resp, err = fcgi.Post(fcgiParams, "", "", rd, rd.Len()) - } else if len(posts) > 0 { - values := url.Values{} - for k, v := range posts { - values.Set(k, v) - length += len(k) + 2 + len(v) - } - resp, err = fcgi.PostForm(fcgiParams, values) - } else { - resp, err = fcgi.Get(fcgiParams) - } - - default: - values := url.Values{} - for k, v := range posts { - values.Set(k, v) - length += len(k) + 2 + len(v) - } - - for k, v := range files { - fi, _ := os.Lstat(v) - length += len(k) + int(fi.Size()) - } - resp, err = fcgi.PostFile(fcgiParams, values, files) - } - - if err != nil { - log.Println("err:", err) - return - } - - defer resp.Body.Close() - content, _ = ioutil.ReadAll(resp.Body) - - log.Println("c: send data length ≈", length, string(content)) - fcgi.Close() - time.Sleep(1 * time.Second) - - if bytes.Index(content, []byte("FAILED")) >= 0 { - globalt.Error("Server return failed message") - } - - return -} - -func generateRandFile(size int) (p string, m string) { - - p = filepath.Join(os.TempDir(), "fcgict"+strconv.Itoa(rand.Int())) - - // open output file - fo, err := os.Create(p) - if err != nil { - panic(err) - } - // close fo on exit and check for its returned error - defer func() { - if err := fo.Close(); err != nil { - panic(err) - } - }() - - h := md5.New() - for i := 0; i < size/16; i++ { - buf := make([]byte, 16) - binary.PutVarint(buf, rand.Int63()) - fo.Write(buf) - h.Write(buf) - } - m = fmt.Sprintf("%x", h.Sum(nil)) - return -} - -func DisabledTest(t *testing.T) { - // TODO: test chunked reader - globalt = t - - rand.Seed(time.Now().UTC().UnixNano()) - - // server - go func() { - listener, err := net.Listen("tcp", ipPort) - if err != nil { - // handle error - log.Println("listener creation failed: ", err) - } - - srv := new(FastCGIServer) - fcgi.Serve(listener, srv) - }() - - time.Sleep(1 * time.Second) - - // init - fcgiParams := make(map[string]string) - fcgiParams["REQUEST_METHOD"] = "GET" - fcgiParams["SERVER_PROTOCOL"] = "HTTP/1.1" - //fcgi_params["GATEWAY_INTERFACE"] = "CGI/1.1" - fcgiParams["SCRIPT_FILENAME"] = scriptFile - - // simple GET - log.Println("test:", "get") - sendFcgi(0, fcgiParams, nil, nil, nil) - - // simple post data - log.Println("test:", "post") - sendFcgi(0, fcgiParams, []byte("c4ca4238a0b923820dcc509a6f75849b=1&7b8b965ad4bca0e41ab51de7b31363a1=n"), nil, nil) - - log.Println("test:", "post data (more than 60KB)") - data := "" - for i := 0x00; i < 0xff; i++ { - v0 := strings.Repeat(string(i), 256) - h := md5.New() - io.WriteString(h, v0) - k0 := fmt.Sprintf("%x", h.Sum(nil)) - data += k0 + "=" + url.QueryEscape(v0) + "&" - } - sendFcgi(0, fcgiParams, []byte(data), nil, nil) - - log.Println("test:", "post form (use url.Values)") - p0 := make(map[string]string, 1) - p0["c4ca4238a0b923820dcc509a6f75849b"] = "1" - p0["7b8b965ad4bca0e41ab51de7b31363a1"] = "n" - sendFcgi(1, fcgiParams, nil, p0, nil) - - log.Println("test:", "post forms (256 keys, more than 1MB)") - p1 := make(map[string]string, 1) - for i := 0x00; i < 0xff; i++ { - v0 := strings.Repeat(string(i), 4096) - h := md5.New() - io.WriteString(h, v0) - k0 := fmt.Sprintf("%x", h.Sum(nil)) - p1[k0] = v0 - } - sendFcgi(1, fcgiParams, nil, p1, nil) - - log.Println("test:", "post file (1 file, 500KB)) ") - f0 := make(map[string]string, 1) - path0, m0 := generateRandFile(500000) - f0[m0] = path0 - sendFcgi(1, fcgiParams, nil, p1, f0) - - log.Println("test:", "post multiple files (2 files, 5M each) and forms (256 keys, more than 1MB data") - path1, m1 := generateRandFile(5000000) - f0[m1] = path1 - sendFcgi(1, fcgiParams, nil, p1, f0) - - log.Println("test:", "post only files (2 files, 5M each)") - sendFcgi(1, fcgiParams, nil, nil, f0) - - log.Println("test:", "post only 1 file") - delete(f0, "m0") - sendFcgi(1, fcgiParams, nil, nil, f0) - - os.Remove(path0) - os.Remove(path1) -} diff --git a/middleware/headers/headers.go b/middleware/headers/headers.go deleted file mode 100644 index 831a1afb4..000000000 --- a/middleware/headers/headers.go +++ /dev/null @@ -1,51 +0,0 @@ -// Package headers provides middleware that appends headers to -// requests based on a set of configuration rules that define -// which routes receive which headers. -package headers - -import ( - "net/http" - "strings" - - "github.com/mholt/caddy/middleware" -) - -// Headers is middleware that adds headers to the responses -// for requests matching a certain path. -type Headers struct { - Next middleware.Handler - Rules []Rule -} - -// ServeHTTP implements the middleware.Handler interface and serves requests, -// setting headers on the response according to the configured rules. -func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - replacer := middleware.NewReplacer(r, nil, "") - for _, rule := range h.Rules { - if middleware.Path(r.URL.Path).Matches(rule.Path) { - for _, header := range rule.Headers { - if strings.HasPrefix(header.Name, "-") { - w.Header().Del(strings.TrimLeft(header.Name, "-")) - } else { - w.Header().Set(header.Name, replacer.Replace(header.Value)) - } - } - } - } - return h.Next.ServeHTTP(w, r) -} - -type ( - // Rule groups a slice of HTTP headers by a URL pattern. - // TODO: use http.Header type instead? - Rule struct { - Path string - Headers []Header - } - - // Header represents a single HTTP header, simply a name and value. - Header struct { - Name string - Value string - } -) diff --git a/middleware/headers/headers_test.go b/middleware/headers/headers_test.go deleted file mode 100644 index 0627902d1..000000000 --- a/middleware/headers/headers_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package headers - -import ( - "net/http" - "net/http/httptest" - "os" - "testing" - - "github.com/mholt/caddy/middleware" -) - -func TestHeaders(t *testing.T) { - hostname, err := os.Hostname() - if err != nil { - t.Fatalf("Could not determine hostname: %v", err) - } - for i, test := range []struct { - from string - name string - value string - }{ - {"/a", "Foo", "Bar"}, - {"/a", "Bar", ""}, - {"/a", "Baz", ""}, - {"/a", "ServerName", hostname}, - {"/b", "Foo", ""}, - {"/b", "Bar", "Removed in /a"}, - } { - he := Headers{ - Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - return 0, nil - }), - Rules: []Rule{ - {Path: "/a", Headers: []Header{ - {Name: "Foo", Value: "Bar"}, - {Name: "ServerName", Value: "{hostname}"}, - {Name: "-Bar"}, - }}, - }, - } - - req, err := http.NewRequest("GET", test.from, nil) - if err != nil { - t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) - } - - rec := httptest.NewRecorder() - rec.Header().Set("Bar", "Removed in /a") - - he.ServeHTTP(rec, req) - - if got := rec.Header().Get(test.name); got != test.value { - t.Errorf("Test %d: Expected %s header to be %q but was %q", - i, test.name, test.value, got) - } - } -} diff --git a/middleware/inner/internal.go b/middleware/inner/internal.go deleted file mode 100644 index d7f044f70..000000000 --- a/middleware/inner/internal.go +++ /dev/null @@ -1,90 +0,0 @@ -// Package inner provides a simple middleware that (a) prevents access -// to internal locations and (b) allows to return files from internal location -// by setting a special header, e.g. in a proxy response. -package inner - -import ( - "net/http" - - "github.com/mholt/caddy/middleware" -) - -// Internal middleware protects internal locations from external requests - -// but allows access from the inside by using a special HTTP header. -type Internal struct { - Next middleware.Handler - Paths []string -} - -const ( - redirectHeader string = "X-Accel-Redirect" - maxRedirectCount int = 10 -) - -func isInternalRedirect(w http.ResponseWriter) bool { - return w.Header().Get(redirectHeader) != "" -} - -// ServeHTTP implements the middlware.Handler interface. -func (i Internal) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - - // Internal location requested? -> Not found. - for _, prefix := range i.Paths { - if middleware.Path(r.URL.Path).Matches(prefix) { - return http.StatusNotFound, nil - } - } - - // Use internal response writer to ignore responses that will be - // redirected to internal locations - iw := internalResponseWriter{ResponseWriter: w} - status, err := i.Next.ServeHTTP(iw, r) - - for c := 0; c < maxRedirectCount && isInternalRedirect(iw); c++ { - // Redirect - adapt request URL path and send it again - // "down the chain" - r.URL.Path = iw.Header().Get(redirectHeader) - iw.ClearHeader() - - status, err = i.Next.ServeHTTP(iw, r) - } - - if isInternalRedirect(iw) { - // Too many redirect cycles - iw.ClearHeader() - return http.StatusInternalServerError, nil - } - - return status, err -} - -// internalResponseWriter wraps the underlying http.ResponseWriter and ignores -// calls to Write and WriteHeader if the response should be redirected to an -// internal location. -type internalResponseWriter struct { - http.ResponseWriter -} - -// ClearHeader removes all header fields that are already set. -func (w internalResponseWriter) ClearHeader() { - for k := range w.Header() { - w.Header().Del(k) - } -} - -// WriteHeader ignores the call if the response should be redirected to an -// internal location. -func (w internalResponseWriter) WriteHeader(code int) { - if !isInternalRedirect(w) { - w.ResponseWriter.WriteHeader(code) - } -} - -// Write ignores the call if the response should be redirected to an internal -// location. -func (w internalResponseWriter) Write(b []byte) (int, error) { - if isInternalRedirect(w) { - return 0, nil - } - return w.ResponseWriter.Write(b) -} diff --git a/middleware/inner/internal_test.go b/middleware/inner/internal_test.go deleted file mode 100644 index 97078febc..000000000 --- a/middleware/inner/internal_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package inner - -import ( - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/mholt/caddy/middleware" -) - -func TestInternal(t *testing.T) { - im := Internal{ - Next: middleware.HandlerFunc(internalTestHandlerFunc), - Paths: []string{"/internal"}, - } - - tests := []struct { - url string - expectedCode int - expectedBody string - }{ - {"/internal", http.StatusNotFound, ""}, - - {"/public", 0, "/public"}, - {"/public/internal", 0, "/public/internal"}, - - {"/redirect", 0, "/internal"}, - - {"/cycle", http.StatusInternalServerError, ""}, - } - - for i, test := range tests { - req, err := http.NewRequest("GET", test.url, nil) - if err != nil { - t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) - } - - rec := httptest.NewRecorder() - code, _ := im.ServeHTTP(rec, req) - - if code != test.expectedCode { - t.Errorf("Test %d: Expected status code %d for %s, but got %d", - i, test.expectedCode, test.url, code) - } - if rec.Body.String() != test.expectedBody { - t.Errorf("Test %d: Expected body '%s' for %s, but got '%s'", - i, test.expectedBody, test.url, rec.Body.String()) - } - } -} - -func internalTestHandlerFunc(w http.ResponseWriter, r *http.Request) (int, error) { - switch r.URL.Path { - case "/redirect": - w.Header().Set("X-Accel-Redirect", "/internal") - case "/cycle": - w.Header().Set("X-Accel-Redirect", "/cycle") - } - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, r.URL.String()) - - return 0, nil -} diff --git a/middleware/markdown/markdown.go b/middleware/markdown/markdown.go deleted file mode 100644 index ab53710e0..000000000 --- a/middleware/markdown/markdown.go +++ /dev/null @@ -1,171 +0,0 @@ -// Package markdown is middleware to render markdown files as HTML -// on-the-fly. -package markdown - -import ( - "net/http" - "os" - "path" - "strconv" - "strings" - "text/template" - "time" - - "github.com/mholt/caddy/middleware" - "github.com/russross/blackfriday" -) - -// Markdown implements a layer of middleware that serves -// markdown as HTML. -type Markdown struct { - // Server root - Root string - - // Jail the requests to site root with a mock file system - FileSys http.FileSystem - - // Next HTTP handler in the chain - Next middleware.Handler - - // The list of markdown configurations - Configs []*Config - - // The list of index files to try - IndexFiles []string -} - -// Config stores markdown middleware configurations. -type Config struct { - // Markdown renderer - Renderer blackfriday.Renderer - - // Base path to match - PathScope string - - // List of extensions to consider as markdown files - Extensions map[string]struct{} - - // List of style sheets to load for each markdown file - Styles []string - - // List of JavaScript files to load for each markdown file - Scripts []string - - // Template(s) to render with - Template *template.Template -} - -// ServeHTTP implements the http.Handler interface. -func (md Markdown) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - var cfg *Config - for _, c := range md.Configs { - if middleware.Path(r.URL.Path).Matches(c.PathScope) { // not negated - cfg = c - break // or goto - } - } - if cfg == nil { - return md.Next.ServeHTTP(w, r) // exit early - } - - // We only deal with HEAD/GET - switch r.Method { - case http.MethodGet, http.MethodHead: - default: - return http.StatusMethodNotAllowed, nil - } - - var dirents []os.FileInfo - var lastModTime time.Time - fpath := r.URL.Path - if idx, ok := middleware.IndexFile(md.FileSys, fpath, md.IndexFiles); ok { - // We're serving a directory index file, which may be a markdown - // file with a template. Let's grab a list of files this directory - // URL points to, and pass that in to any possible template invocations, - // so that templates can customize the look and feel of a directory. - fdp, err := md.FileSys.Open(fpath) - switch { - case err == nil: // nop - case os.IsPermission(err): - return http.StatusForbidden, err - case os.IsExist(err): - return http.StatusNotFound, nil - default: // did we run out of FD? - return http.StatusInternalServerError, err - } - defer fdp.Close() - - // Grab a possible set of directory entries. Note, we do not check - // for errors here (unreadable directory, for example). It may - // still be useful to have a directory template file, without the - // directory contents being present. Note, the directory's last - // modification is also present here (entry "."). - dirents, _ = fdp.Readdir(-1) - for _, d := range dirents { - lastModTime = latest(lastModTime, d.ModTime()) - } - - // Set path to found index file - fpath = idx - } - - // If not supported extension, pass on it - if _, ok := cfg.Extensions[path.Ext(fpath)]; !ok { - return md.Next.ServeHTTP(w, r) - } - - // At this point we have a supported extension/markdown - f, err := md.FileSys.Open(fpath) - switch { - case err == nil: // nop - case os.IsPermission(err): - return http.StatusForbidden, err - case os.IsExist(err): - return http.StatusNotFound, nil - default: // did we run out of FD? - return http.StatusInternalServerError, err - } - defer f.Close() - - if fs, err := f.Stat(); err != nil { - return http.StatusGone, nil - } else { - lastModTime = latest(lastModTime, fs.ModTime()) - } - - ctx := middleware.Context{ - Root: md.FileSys, - Req: r, - URL: r.URL, - } - html, err := cfg.Markdown(title(fpath), f, dirents, ctx) - if err != nil { - return http.StatusInternalServerError, err - } - - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.Header().Set("Content-Length", strconv.FormatInt(int64(len(html)), 10)) - middleware.SetLastModifiedHeader(w, lastModTime) - if r.Method == http.MethodGet { - w.Write(html) - } - return http.StatusOK, nil -} - -// latest returns the latest time.Time -func latest(t ...time.Time) time.Time { - var last time.Time - - for _, tt := range t { - if tt.After(last) { - last = tt - } - } - - return last -} - -// title gives a backup generated title for a page -func title(p string) string { - return strings.TrimRight(path.Base(p), path.Ext(p)) -} diff --git a/middleware/markdown/markdown_test.go b/middleware/markdown/markdown_test.go deleted file mode 100644 index 382c8e120..000000000 --- a/middleware/markdown/markdown_test.go +++ /dev/null @@ -1,230 +0,0 @@ -package markdown - -import ( - "bufio" - "io/ioutil" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "strings" - "testing" - "text/template" - "time" - - "github.com/mholt/caddy/middleware" - "github.com/russross/blackfriday" -) - -func TestMarkdown(t *testing.T) { - rootDir := "./testdata" - - f := func(filename string) string { - return filepath.ToSlash(rootDir + string(filepath.Separator) + filename) - } - - md := Markdown{ - Root: rootDir, - FileSys: http.Dir(rootDir), - Configs: []*Config{ - { - Renderer: blackfriday.HtmlRenderer(0, "", ""), - PathScope: "/blog", - Extensions: map[string]struct{}{ - ".md": {}, - }, - Styles: []string{}, - Scripts: []string{}, - Template: setDefaultTemplate(f("markdown_tpl.html")), - }, - { - Renderer: blackfriday.HtmlRenderer(0, "", ""), - PathScope: "/docflags", - Extensions: map[string]struct{}{ - ".md": {}, - }, - Styles: []string{}, - Scripts: []string{}, - Template: setDefaultTemplate(f("docflags/template.txt")), - }, - { - Renderer: blackfriday.HtmlRenderer(0, "", ""), - PathScope: "/log", - Extensions: map[string]struct{}{ - ".md": {}, - }, - Styles: []string{"/resources/css/log.css", "/resources/css/default.css"}, - Scripts: []string{"/resources/js/log.js", "/resources/js/default.js"}, - Template: GetDefaultTemplate(), - }, - { - Renderer: blackfriday.HtmlRenderer(0, "", ""), - PathScope: "/og", - Extensions: map[string]struct{}{ - ".md": {}, - }, - Styles: []string{}, - Scripts: []string{}, - Template: setDefaultTemplate(f("markdown_tpl.html")), - }, - }, - IndexFiles: []string{"index.html"}, - Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - t.Fatalf("Next shouldn't be called") - return 0, nil - }), - } - - req, err := http.NewRequest("GET", "/blog/test.md", nil) - if err != nil { - t.Fatalf("Could not create HTTP request: %v", err) - } - - rec := httptest.NewRecorder() - - md.ServeHTTP(rec, req) - if rec.Code != http.StatusOK { - t.Fatalf("Wrong status, expected: %d and got %d", http.StatusOK, rec.Code) - } - - respBody := rec.Body.String() - expectedBody := ` - - -Markdown test 1 - - -

Header for: Markdown test 1

- -Welcome to A Caddy website! -

Welcome on the blog

- -

Body

- -
func getTrue() bool {
-    return true
-}
-
- - - -` - if !equalStrings(respBody, expectedBody) { - t.Fatalf("Expected body: %v got: %v", expectedBody, respBody) - } - - req, err = http.NewRequest("GET", "/docflags/test.md", nil) - if err != nil { - t.Fatalf("Could not create HTTP request: %v", err) - } - rec = httptest.NewRecorder() - - md.ServeHTTP(rec, req) - if rec.Code != http.StatusOK { - t.Fatalf("Wrong status, expected: %d and got %d", http.StatusOK, rec.Code) - } - respBody = rec.Body.String() - expectedBody = `Doc.var_string hello -Doc.var_bool -DocFlags.var_string -DocFlags.var_bool true` - - if !equalStrings(respBody, expectedBody) { - t.Fatalf("Expected body: %v got: %v", expectedBody, respBody) - } - - req, err = http.NewRequest("GET", "/log/test.md", nil) - if err != nil { - t.Fatalf("Could not create HTTP request: %v", err) - } - rec = httptest.NewRecorder() - - md.ServeHTTP(rec, req) - if rec.Code != http.StatusOK { - t.Fatalf("Wrong status, expected: %d and got %d", http.StatusOK, rec.Code) - } - respBody = rec.Body.String() - expectedBody = ` - - - Markdown test 2 - - - - - - - -

Welcome on the blog

- -

Body

- -
func getTrue() bool {
-    return true
-}
-
- - -` - - if !equalStrings(respBody, expectedBody) { - t.Fatalf("Expected body: %v got: %v", expectedBody, respBody) - } - - req, err = http.NewRequest("GET", "/og/first.md", nil) - if err != nil { - t.Fatalf("Could not create HTTP request: %v", err) - } - rec = httptest.NewRecorder() - currenttime := time.Now().Local().Add(-time.Second) - _ = os.Chtimes("testdata/og/first.md", currenttime, currenttime) - currenttime = time.Now().Local() - _ = os.Chtimes("testdata/og_static/og/first.md/index.html", currenttime, currenttime) - time.Sleep(time.Millisecond * 200) - - md.ServeHTTP(rec, req) - if rec.Code != http.StatusOK { - t.Fatalf("Wrong status, expected: %d and got %d", http.StatusOK, rec.Code) - } - respBody = rec.Body.String() - expectedBody = ` - - -first_post - - -

Header for: first_post

- -Welcome to title! -

Test h1

- - -` - - if !equalStrings(respBody, expectedBody) { - t.Fatalf("Expected body: %v got: %v", expectedBody, respBody) - } -} - -func equalStrings(s1, s2 string) bool { - s1 = strings.TrimSpace(s1) - s2 = strings.TrimSpace(s2) - in := bufio.NewScanner(strings.NewReader(s1)) - for in.Scan() { - txt := strings.TrimSpace(in.Text()) - if !strings.HasPrefix(strings.TrimSpace(s2), txt) { - return false - } - s2 = strings.Replace(s2, txt, "", 1) - } - return true -} - -func setDefaultTemplate(filename string) *template.Template { - buf, err := ioutil.ReadFile(filename) - if err != nil { - return nil - } - - return template.Must(GetDefaultTemplate().Parse(string(buf))) -} diff --git a/middleware/markdown/metadata/metadata.go b/middleware/markdown/metadata/metadata.go deleted file mode 100644 index ade7fcc9d..000000000 --- a/middleware/markdown/metadata/metadata.go +++ /dev/null @@ -1,158 +0,0 @@ -package metadata - -import ( - "bufio" - "bytes" - "time" -) - -var ( - // Date format YYYY-MM-DD HH:MM:SS or YYYY-MM-DD - timeLayout = []string{ - `2006-01-02 15:04:05-0700`, - `2006-01-02 15:04:05`, - `2006-01-02`, - } -) - -// Metadata stores a page's metadata -type Metadata struct { - // Page title - Title string - - // Page template - Template string - - // Publish date - Date time.Time - - // Variables to be used with Template - Variables map[string]string - - // Flags to be used with Template - Flags map[string]bool -} - -// NewMetadata() returns a new Metadata struct, loaded with the given map -func NewMetadata(parsedMap map[string]interface{}) Metadata { - md := Metadata{ - Variables: make(map[string]string), - Flags: make(map[string]bool), - } - md.load(parsedMap) - - return md -} - -// load loads parsed values in parsedMap into Metadata -func (m *Metadata) load(parsedMap map[string]interface{}) { - - // Pull top level things out - if title, ok := parsedMap["title"]; ok { - m.Title, _ = title.(string) - } - if template, ok := parsedMap["template"]; ok { - m.Template, _ = template.(string) - } - if date, ok := parsedMap["date"].(string); ok { - for _, layout := range timeLayout { - if t, err := time.Parse(layout, date); err == nil { - m.Date = t - break - } - } - } - - // Store everything as a flag or variable - for key, val := range parsedMap { - switch v := val.(type) { - case bool: - m.Flags[key] = v - case string: - m.Variables[key] = v - } - } -} - -// MetadataParser is a an interface that must be satisfied by each parser -type MetadataParser interface { - // Initialize a parser - Init(b *bytes.Buffer) bool - - // Type of metadata - Type() string - - // Parsed metadata. - Metadata() Metadata - - // Raw markdown. - Markdown() []byte -} - -// GetParser returns a parser for the given data -func GetParser(buf []byte) MetadataParser { - for _, p := range parsers() { - b := bytes.NewBuffer(buf) - if p.Init(b) { - return p - } - } - - return nil -} - -// parsers returns all available parsers -func parsers() []MetadataParser { - return []MetadataParser{ - &TOMLMetadataParser{}, - &YAMLMetadataParser{}, - &JSONMetadataParser{}, - - // This one must be last - &NoneMetadataParser{}, - } -} - -// Split out prefixed/suffixed metadata with given delimiter -func splitBuffer(b *bytes.Buffer, delim string) (*bytes.Buffer, *bytes.Buffer) { - scanner := bufio.NewScanner(b) - - // Read and check first line - if !scanner.Scan() { - return nil, nil - } - if string(bytes.TrimSpace(scanner.Bytes())) != delim { - return nil, nil - } - - // Accumulate metadata, until delimiter - meta := bytes.NewBuffer(nil) - for scanner.Scan() { - if string(bytes.TrimSpace(scanner.Bytes())) == delim { - break - } - if _, err := meta.Write(scanner.Bytes()); err != nil { - return nil, nil - } - if _, err := meta.WriteRune('\n'); err != nil { - return nil, nil - } - } - // Make sure we saw closing delimiter - if string(bytes.TrimSpace(scanner.Bytes())) != delim { - return nil, nil - } - - // The rest is markdown - markdown := new(bytes.Buffer) - for scanner.Scan() { - if _, err := markdown.Write(scanner.Bytes()); err != nil { - return nil, nil - } - if _, err := markdown.WriteRune('\n'); err != nil { - return nil, nil - } - } - - return meta, markdown -} diff --git a/middleware/markdown/metadata/metadata_json.go b/middleware/markdown/metadata/metadata_json.go deleted file mode 100644 index d3b9991ff..000000000 --- a/middleware/markdown/metadata/metadata_json.go +++ /dev/null @@ -1,53 +0,0 @@ -package metadata - -import ( - "bytes" - "encoding/json" -) - -// JSONMetadataParser is the MetadataParser for JSON -type JSONMetadataParser struct { - metadata Metadata - markdown *bytes.Buffer -} - -func (j *JSONMetadataParser) Type() string { - return "JSON" -} - -// Parse metadata/markdown file -func (j *JSONMetadataParser) Init(b *bytes.Buffer) bool { - m := make(map[string]interface{}) - - err := json.Unmarshal(b.Bytes(), &m) - if err != nil { - var offset int - - if jerr, ok := err.(*json.SyntaxError); !ok { - return false - } else { - offset = int(jerr.Offset) - } - - m = make(map[string]interface{}) - err = json.Unmarshal(b.Next(offset-1), &m) - if err != nil { - return false - } - } - - j.metadata = NewMetadata(m) - j.markdown = bytes.NewBuffer(b.Bytes()) - - return true -} - -// Metadata returns parsed metadata. It should be called -// only after a call to Parse returns without error. -func (j *JSONMetadataParser) Metadata() Metadata { - return j.metadata -} - -func (j *JSONMetadataParser) Markdown() []byte { - return j.markdown.Bytes() -} diff --git a/middleware/markdown/metadata/metadata_none.go b/middleware/markdown/metadata/metadata_none.go deleted file mode 100644 index ed034f2fa..000000000 --- a/middleware/markdown/metadata/metadata_none.go +++ /dev/null @@ -1,39 +0,0 @@ -package metadata - -import ( - "bytes" -) - -// TOMLMetadataParser is the MetadataParser for TOML -type NoneMetadataParser struct { - metadata Metadata - markdown *bytes.Buffer -} - -func (n *NoneMetadataParser) Type() string { - return "None" -} - -// Parse metadata/markdown file -func (n *NoneMetadataParser) Init(b *bytes.Buffer) bool { - m := make(map[string]interface{}) - n.metadata = NewMetadata(m) - n.markdown = bytes.NewBuffer(b.Bytes()) - - return true -} - -// Parse the metadata -func (n *NoneMetadataParser) Parse(b []byte) ([]byte, error) { - return nil, nil -} - -// Metadata returns parsed metadata. It should be called -// only after a call to Parse returns without error. -func (n *NoneMetadataParser) Metadata() Metadata { - return n.metadata -} - -func (n *NoneMetadataParser) Markdown() []byte { - return n.markdown.Bytes() -} diff --git a/middleware/markdown/metadata/metadata_test.go b/middleware/markdown/metadata/metadata_test.go deleted file mode 100644 index 0c155d37c..000000000 --- a/middleware/markdown/metadata/metadata_test.go +++ /dev/null @@ -1,261 +0,0 @@ -package metadata - -import ( - "bytes" - "strings" - "testing" -) - -func check(t *testing.T, err error) { - if err != nil { - t.Fatal(err) - } -} - -var TOML = [5]string{` -title = "A title" -template = "default" -name = "value" -positive = true -negative = false -`, - `+++ -title = "A title" -template = "default" -name = "value" -positive = true -negative = false -+++ -Page content - `, - `+++ -title = "A title" -template = "default" -name = "value" -positive = true -negative = false - `, - `title = "A title" template = "default" [variables] name = "value"`, - `+++ -title = "A title" -template = "default" -name = "value" -positive = true -negative = false -+++ -`, -} - -var YAML = [5]string{` -title : A title -template : default -name : value -positive : true -negative : false -`, - `--- -title : A title -template : default -name : value -positive : true -negative : false ---- - Page content - `, - `--- -title : A title -template : default -name : value - `, - `title : A title template : default variables : name : value : positive : true : negative : false`, - `--- -title : A title -template : default -name : value -positive : true -negative : false ---- -`, -} - -var JSON = [5]string{` - "title" : "A title", - "template" : "default", - "name" : "value", - "positive" : true, - "negative" : false -`, - `{ - "title" : "A title", - "template" : "default", - "name" : "value", - "positive" : true, - "negative" : false -} -Page content - `, - ` -{ - "title" : "A title", - "template" : "default", - "name" : "value", - "positive" : true, - "negative" : false - `, - ` -{ - "title" :: "A title", - "template" : "default", - "name" : "value", - "positive" : true, - "negative" : false -} - `, - `{ - "title" : "A title", - "template" : "default", - "name" : "value", - "positive" : true, - "negative" : false -} -`, -} - -func TestParsers(t *testing.T) { - expected := Metadata{ - Title: "A title", - Template: "default", - Variables: map[string]string{ - "name": "value", - "title": "A title", - "template": "default", - }, - Flags: map[string]bool{ - "positive": true, - "negative": false, - }, - } - compare := func(m Metadata) bool { - if m.Title != expected.Title { - return false - } - if m.Template != expected.Template { - return false - } - for k, v := range m.Variables { - if v != expected.Variables[k] { - return false - } - } - for k, v := range m.Flags { - if v != expected.Flags[k] { - return false - } - } - varLenOK := len(m.Variables) == len(expected.Variables) - flagLenOK := len(m.Flags) == len(expected.Flags) - return varLenOK && flagLenOK - } - - data := []struct { - parser MetadataParser - testData [5]string - name string - }{ - {&JSONMetadataParser{}, JSON, "JSON"}, - {&YAMLMetadataParser{}, YAML, "YAML"}, - {&TOMLMetadataParser{}, TOML, "TOML"}, - } - - for _, v := range data { - // metadata without identifiers - if v.parser.Init(bytes.NewBufferString(v.testData[0])) { - t.Fatalf("Expected error for invalid metadata for %v", v.name) - } - - // metadata with identifiers - if !v.parser.Init(bytes.NewBufferString(v.testData[1])) { - t.Fatalf("Metadata failed to initialize, type %v", v.parser.Type()) - } - md := v.parser.Markdown() - if !compare(v.parser.Metadata()) { - t.Fatalf("Expected %v, found %v for %v", expected, v.parser.Metadata(), v.name) - } - if "Page content" != strings.TrimSpace(string(md)) { - t.Fatalf("Expected %v, found %v for %v", "Page content", string(md), v.name) - } - // Check that we find the correct metadata parser type - if p := GetParser([]byte(v.testData[1])); p.Type() != v.name { - t.Fatalf("Wrong parser found, expected %v, found %v", v.name, p.Type()) - } - - // metadata without closing identifier - if v.parser.Init(bytes.NewBufferString(v.testData[2])) { - t.Fatalf("Expected error for missing closing identifier for %v parser", v.name) - } - - // invalid metadata - if v.parser.Init(bytes.NewBufferString(v.testData[3])) { - t.Fatalf("Expected error for invalid metadata for %v", v.name) - } - - // front matter but no body - if !v.parser.Init(bytes.NewBufferString(v.testData[4])) { - t.Fatalf("Unexpected error for valid metadata but no body for %v", v.name) - } - } -} - -func TestLargeBody(t *testing.T) { - - var JSON = `{ -"template": "chapter" -} - -Mycket olika byggnader har man i de nordiska rikena: pyramidformiga, kilformiga, välvda, runda och fyrkantiga. De pyramidformiga består helt enkelt av träribbor, som upptill löper samman och nedtill bildar en vidare krets; de är avsedda att användas av hantverkarna under sommaren, för att de inte ska plågas av solen, på samma gång som de besväras av rök och eld. De kilformiga husen är i regel försedda med höga tak, för att de täta och tunga snömassorna fortare ska kunna blåsa av och inte tynga ned taken. Dessa är täckta av björknäver, tegel eller kluvet spån av furu - för kådans skull -, gran, ek eller bok; taken på de förmögnas hus däremot med plåtar av koppar eller bly, i likhet med kyrktaken. Valvbyggnaderna uppförs ganska konstnärligt till skydd mot våldsamma vindar och snöfall, görs av sten eller trä, och är avsedda för olika alldagliga viktiga ändamål. Liknande byggnader kan finnas i stormännens gårdar där de används som förvaringsrum för husgeråd och jordbruksredskap. De runda byggnaderna - som för övrigt är de högst sällsynta - används av konstnärer, som vid sitt arbete behöver ett jämnt fördelat ljus från taket. Vanligast är de fyrkantiga husen, vars grova bjälkar är synnerligen väl hopfogade i hörnen - ett sant mästerverk av byggnadskonst; även dessa har fönster högt uppe i taken, för att dagsljuset skall kunna strömma in och ge alla därinne full belysning. Stenhusen har dörröppningar i förhållande till byggnadens storlek, men smala fönstergluggar, som skydd mot den stränga kölden, frosten och snön. Vore de större och vidare, såsom fönstren i Italien, skulle husen i följd av den fint yrande snön, som röres upp av den starka blåsten, precis som dammet av virvelvinden, snart nog fyllas med massor av snö och inte kunna stå emot dess tryck, utan störta samman. - - ` - var TOML = `+++ -template = "chapter" -+++ - -Mycket olika byggnader har man i de nordiska rikena: pyramidformiga, kilformiga, välvda, runda och fyrkantiga. De pyramidformiga består helt enkelt av träribbor, som upptill löper samman och nedtill bildar en vidare krets; de är avsedda att användas av hantverkarna under sommaren, för att de inte ska plågas av solen, på samma gång som de besväras av rök och eld. De kilformiga husen är i regel försedda med höga tak, för att de täta och tunga snömassorna fortare ska kunna blåsa av och inte tynga ned taken. Dessa är täckta av björknäver, tegel eller kluvet spån av furu - för kådans skull -, gran, ek eller bok; taken på de förmögnas hus däremot med plåtar av koppar eller bly, i likhet med kyrktaken. Valvbyggnaderna uppförs ganska konstnärligt till skydd mot våldsamma vindar och snöfall, görs av sten eller trä, och är avsedda för olika alldagliga viktiga ändamål. Liknande byggnader kan finnas i stormännens gårdar där de används som förvaringsrum för husgeråd och jordbruksredskap. De runda byggnaderna - som för övrigt är de högst sällsynta - används av konstnärer, som vid sitt arbete behöver ett jämnt fördelat ljus från taket. Vanligast är de fyrkantiga husen, vars grova bjälkar är synnerligen väl hopfogade i hörnen - ett sant mästerverk av byggnadskonst; även dessa har fönster högt uppe i taken, för att dagsljuset skall kunna strömma in och ge alla därinne full belysning. Stenhusen har dörröppningar i förhållande till byggnadens storlek, men smala fönstergluggar, som skydd mot den stränga kölden, frosten och snön. Vore de större och vidare, såsom fönstren i Italien, skulle husen i följd av den fint yrande snön, som röres upp av den starka blåsten, precis som dammet av virvelvinden, snart nog fyllas med massor av snö och inte kunna stå emot dess tryck, utan störta samman. - - ` - var YAML = `--- -template : chapter ---- - -Mycket olika byggnader har man i de nordiska rikena: pyramidformiga, kilformiga, välvda, runda och fyrkantiga. De pyramidformiga består helt enkelt av träribbor, som upptill löper samman och nedtill bildar en vidare krets; de är avsedda att användas av hantverkarna under sommaren, för att de inte ska plågas av solen, på samma gång som de besväras av rök och eld. De kilformiga husen är i regel försedda med höga tak, för att de täta och tunga snömassorna fortare ska kunna blåsa av och inte tynga ned taken. Dessa är täckta av björknäver, tegel eller kluvet spån av furu - för kådans skull -, gran, ek eller bok; taken på de förmögnas hus däremot med plåtar av koppar eller bly, i likhet med kyrktaken. Valvbyggnaderna uppförs ganska konstnärligt till skydd mot våldsamma vindar och snöfall, görs av sten eller trä, och är avsedda för olika alldagliga viktiga ändamål. Liknande byggnader kan finnas i stormännens gårdar där de används som förvaringsrum för husgeråd och jordbruksredskap. De runda byggnaderna - som för övrigt är de högst sällsynta - används av konstnärer, som vid sitt arbete behöver ett jämnt fördelat ljus från taket. Vanligast är de fyrkantiga husen, vars grova bjälkar är synnerligen väl hopfogade i hörnen - ett sant mästerverk av byggnadskonst; även dessa har fönster högt uppe i taken, för att dagsljuset skall kunna strömma in och ge alla därinne full belysning. Stenhusen har dörröppningar i förhållande till byggnadens storlek, men smala fönstergluggar, som skydd mot den stränga kölden, frosten och snön. Vore de större och vidare, såsom fönstren i Italien, skulle husen i följd av den fint yrande snön, som röres upp av den starka blåsten, precis som dammet av virvelvinden, snart nog fyllas med massor av snö och inte kunna stå emot dess tryck, utan störta samman. - - ` - var NONE = ` - -Mycket olika byggnader har man i de nordiska rikena: pyramidformiga, kilformiga, välvda, runda och fyrkantiga. De pyramidformiga består helt enkelt av träribbor, som upptill löper samman och nedtill bildar en vidare krets; de är avsedda att användas av hantverkarna under sommaren, för att de inte ska plågas av solen, på samma gång som de besväras av rök och eld. De kilformiga husen är i regel försedda med höga tak, för att de täta och tunga snömassorna fortare ska kunna blåsa av och inte tynga ned taken. Dessa är täckta av björknäver, tegel eller kluvet spån av furu - för kådans skull -, gran, ek eller bok; taken på de förmögnas hus däremot med plåtar av koppar eller bly, i likhet med kyrktaken. Valvbyggnaderna uppförs ganska konstnärligt till skydd mot våldsamma vindar och snöfall, görs av sten eller trä, och är avsedda för olika alldagliga viktiga ändamål. Liknande byggnader kan finnas i stormännens gårdar där de används som förvaringsrum för husgeråd och jordbruksredskap. De runda byggnaderna - som för övrigt är de högst sällsynta - används av konstnärer, som vid sitt arbete behöver ett jämnt fördelat ljus från taket. Vanligast är de fyrkantiga husen, vars grova bjälkar är synnerligen väl hopfogade i hörnen - ett sant mästerverk av byggnadskonst; även dessa har fönster högt uppe i taken, för att dagsljuset skall kunna strömma in och ge alla därinne full belysning. Stenhusen har dörröppningar i förhållande till byggnadens storlek, men smala fönstergluggar, som skydd mot den stränga kölden, frosten och snön. Vore de större och vidare, såsom fönstren i Italien, skulle husen i följd av den fint yrande snön, som röres upp av den starka blåsten, precis som dammet av virvelvinden, snart nog fyllas med massor av snö och inte kunna stå emot dess tryck, utan störta samman. - - ` - var expectedBody = `Mycket olika byggnader har man i de nordiska rikena: pyramidformiga, kilformiga, välvda, runda och fyrkantiga. De pyramidformiga består helt enkelt av träribbor, som upptill löper samman och nedtill bildar en vidare krets; de är avsedda att användas av hantverkarna under sommaren, för att de inte ska plågas av solen, på samma gång som de besväras av rök och eld. De kilformiga husen är i regel försedda med höga tak, för att de täta och tunga snömassorna fortare ska kunna blåsa av och inte tynga ned taken. Dessa är täckta av björknäver, tegel eller kluvet spån av furu - för kådans skull -, gran, ek eller bok; taken på de förmögnas hus däremot med plåtar av koppar eller bly, i likhet med kyrktaken. Valvbyggnaderna uppförs ganska konstnärligt till skydd mot våldsamma vindar och snöfall, görs av sten eller trä, och är avsedda för olika alldagliga viktiga ändamål. Liknande byggnader kan finnas i stormännens gårdar där de används som förvaringsrum för husgeråd och jordbruksredskap. De runda byggnaderna - som för övrigt är de högst sällsynta - används av konstnärer, som vid sitt arbete behöver ett jämnt fördelat ljus från taket. Vanligast är de fyrkantiga husen, vars grova bjälkar är synnerligen väl hopfogade i hörnen - ett sant mästerverk av byggnadskonst; även dessa har fönster högt uppe i taken, för att dagsljuset skall kunna strömma in och ge alla därinne full belysning. Stenhusen har dörröppningar i förhållande till byggnadens storlek, men smala fönstergluggar, som skydd mot den stränga kölden, frosten och snön. Vore de större och vidare, såsom fönstren i Italien, skulle husen i följd av den fint yrande snön, som röres upp av den starka blåsten, precis som dammet av virvelvinden, snart nog fyllas med massor av snö och inte kunna stå emot dess tryck, utan störta samman. -` - - data := []struct { - pType string - testData string - }{ - {"JSON", JSON}, - {"TOML", TOML}, - {"YAML", YAML}, - {"None", NONE}, - } - for _, v := range data { - p := GetParser([]byte(v.testData)) - if v.pType != p.Type() { - t.Fatalf("Wrong parser type, expected %v, got %v", v.pType, p.Type()) - } - md := p.Markdown() - if strings.TrimSpace(string(md)) != strings.TrimSpace(expectedBody) { - t.Log("Provided:", v.testData) - t.Log("Returned:", p.Markdown()) - t.Fatalf("Error, mismatched body in expected type %v, matched type %v", v.pType, p.Type()) - } - } -} diff --git a/middleware/markdown/metadata/metadata_toml.go b/middleware/markdown/metadata/metadata_toml.go deleted file mode 100644 index 75c2067f0..000000000 --- a/middleware/markdown/metadata/metadata_toml.go +++ /dev/null @@ -1,44 +0,0 @@ -package metadata - -import ( - "bytes" - - "github.com/BurntSushi/toml" -) - -// TOMLMetadataParser is the MetadataParser for TOML -type TOMLMetadataParser struct { - metadata Metadata - markdown *bytes.Buffer -} - -func (t *TOMLMetadataParser) Type() string { - return "TOML" -} - -// Parse metadata/markdown file -func (t *TOMLMetadataParser) Init(b *bytes.Buffer) bool { - meta, data := splitBuffer(b, "+++") - if meta == nil || data == nil { - return false - } - t.markdown = data - - m := make(map[string]interface{}) - if err := toml.Unmarshal(meta.Bytes(), &m); err != nil { - return false - } - t.metadata = NewMetadata(m) - - return true -} - -// Metadata returns parsed metadata. It should be called -// only after a call to Parse returns without error. -func (t *TOMLMetadataParser) Metadata() Metadata { - return t.metadata -} - -func (t *TOMLMetadataParser) Markdown() []byte { - return t.markdown.Bytes() -} diff --git a/middleware/markdown/metadata/metadata_yaml.go b/middleware/markdown/metadata/metadata_yaml.go deleted file mode 100644 index f7ef5bb4f..000000000 --- a/middleware/markdown/metadata/metadata_yaml.go +++ /dev/null @@ -1,43 +0,0 @@ -package metadata - -import ( - "bytes" - - "gopkg.in/yaml.v2" -) - -// YAMLMetadataParser is the MetadataParser for YAML -type YAMLMetadataParser struct { - metadata Metadata - markdown *bytes.Buffer -} - -func (y *YAMLMetadataParser) Type() string { - return "YAML" -} - -func (y *YAMLMetadataParser) Init(b *bytes.Buffer) bool { - meta, data := splitBuffer(b, "---") - if meta == nil || data == nil { - return false - } - y.markdown = data - - m := make(map[string]interface{}) - if err := yaml.Unmarshal(meta.Bytes(), &m); err != nil { - return false - } - y.metadata = NewMetadata(m) - - return true -} - -// Metadata returns parsed metadata. It should be called -// only after a call to Parse returns without error. -func (y *YAMLMetadataParser) Metadata() Metadata { - return y.metadata -} - -func (y *YAMLMetadataParser) Markdown() []byte { - return y.markdown.Bytes() -} diff --git a/middleware/markdown/process.go b/middleware/markdown/process.go deleted file mode 100644 index dc1dc6d0b..000000000 --- a/middleware/markdown/process.go +++ /dev/null @@ -1,74 +0,0 @@ -package markdown - -import ( - "io" - "io/ioutil" - "os" - - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/markdown/metadata" - "github.com/mholt/caddy/middleware/markdown/summary" - "github.com/russross/blackfriday" -) - -type FileInfo struct { - os.FileInfo - ctx middleware.Context -} - -func (f FileInfo) Summarize(wordcount int) (string, error) { - fp, err := f.ctx.Root.Open(f.Name()) - if err != nil { - return "", err - } - defer fp.Close() - - buf, err := ioutil.ReadAll(fp) - if err != nil { - return "", err - } - - return string(summary.Markdown(buf, wordcount)), nil -} - -// Markdown processes the contents of a page in b. It parses the metadata -// (if any) and uses the template (if found). -func (c *Config) Markdown(title string, r io.Reader, dirents []os.FileInfo, ctx middleware.Context) ([]byte, error) { - body, err := ioutil.ReadAll(r) - if err != nil { - return nil, err - } - - parser := metadata.GetParser(body) - markdown := parser.Markdown() - mdata := parser.Metadata() - - // process markdown - extns := 0 - extns |= blackfriday.EXTENSION_TABLES - extns |= blackfriday.EXTENSION_FENCED_CODE - extns |= blackfriday.EXTENSION_STRIKETHROUGH - extns |= blackfriday.EXTENSION_DEFINITION_LISTS - html := blackfriday.Markdown(markdown, c.Renderer, extns) - - // set it as body for template - mdata.Variables["body"] = string(html) - - // fixup title - mdata.Variables["title"] = mdata.Title - if mdata.Variables["title"] == "" { - mdata.Variables["title"] = title - } - - // massage possible files - files := []FileInfo{} - for _, ent := range dirents { - file := FileInfo{ - FileInfo: ent, - ctx: ctx, - } - files = append(files, file) - } - - return execTemplate(c, mdata, files, ctx) -} diff --git a/middleware/markdown/summary/render.go b/middleware/markdown/summary/render.go deleted file mode 100644 index 0de9800e2..000000000 --- a/middleware/markdown/summary/render.go +++ /dev/null @@ -1,153 +0,0 @@ -package summary - -import ( - "bytes" - - "github.com/russross/blackfriday" -) - -// Ensure we implement the Blackfriday Markdown Renderer interface -var _ blackfriday.Renderer = (*renderer)(nil) - -// renderer renders Markdown to plain-text meant for listings and excerpts, -// and implements the blackfriday.Renderer interface. -// -// Many of the methods are stubs with no output to prevent output of HTML markup. -type renderer struct{} - -// Blocklevel callbacks - -// Stub BlockCode is the code tag callback. -func (r renderer) BlockCode(out *bytes.Buffer, text []byte, land string) {} - -// Stub BlockQuote is teh quote tag callback. -func (r renderer) BlockQuote(out *bytes.Buffer, text []byte) {} - -// Stub BlockHtml is the HTML tag callback. -func (r renderer) BlockHtml(out *bytes.Buffer, text []byte) {} - -// Stub Header is the header tag callback. -func (r renderer) Header(out *bytes.Buffer, text func() bool, level int, id string) {} - -// Stub HRule is the horizontal rule tag callback. -func (r renderer) HRule(out *bytes.Buffer) {} - -// List is the list tag callback. -func (r renderer) List(out *bytes.Buffer, text func() bool, flags int) { - // TODO: This is not desired (we'd rather not write lists as part of summary), - // but see this issue: https://github.com/russross/blackfriday/issues/189 - marker := out.Len() - if !text() { - out.Truncate(marker) - } - out.Write([]byte{' '}) -} - -// Stub ListItem is the list item tag callback. -func (r renderer) ListItem(out *bytes.Buffer, text []byte, flags int) {} - -// Paragraph is the paragraph tag callback. This renders simple paragraph text -// into plain text, such that summaries can be easily generated. -func (r renderer) Paragraph(out *bytes.Buffer, text func() bool) { - marker := out.Len() - if !text() { - out.Truncate(marker) - } - out.Write([]byte{' '}) -} - -// Stub Table is the table tag callback. -func (r renderer) Table(out *bytes.Buffer, header []byte, body []byte, columnData []int) {} - -// Stub TableRow is the table row tag callback. -func (r renderer) TableRow(out *bytes.Buffer, text []byte) {} - -// Stub TableHeaderCell is the table header cell tag callback. -func (r renderer) TableHeaderCell(out *bytes.Buffer, text []byte, flags int) {} - -// Stub TableCell is the table cell tag callback. -func (r renderer) TableCell(out *bytes.Buffer, text []byte, flags int) {} - -// Stub Footnotes is the foot notes tag callback. -func (r renderer) Footnotes(out *bytes.Buffer, text func() bool) {} - -// Stub FootnoteItem is the footnote item tag callback. -func (r renderer) FootnoteItem(out *bytes.Buffer, name, text []byte, flags int) {} - -// Stub TitleBlock is the title tag callback. -func (r renderer) TitleBlock(out *bytes.Buffer, text []byte) {} - -// Spanlevel callbacks - -// Stub AutoLink is the autolink tag callback. -func (r renderer) AutoLink(out *bytes.Buffer, link []byte, kind int) {} - -// CodeSpan is the code span tag callback. Outputs a simple Markdown version -// of the code span. -func (r renderer) CodeSpan(out *bytes.Buffer, text []byte) { - out.Write([]byte("`")) - out.Write(text) - out.Write([]byte("`")) -} - -// DoubleEmphasis is the double emphasis tag callback. Outputs a simple -// plain-text version of the input. -func (r renderer) DoubleEmphasis(out *bytes.Buffer, text []byte) { - out.Write(text) -} - -// Emphasis is the emphasis tag callback. Outputs a simple plain-text -// version of the input. -func (r renderer) Emphasis(out *bytes.Buffer, text []byte) { - out.Write(text) -} - -// Stub Image is the image tag callback. -func (r renderer) Image(out *bytes.Buffer, link []byte, title []byte, alt []byte) {} - -// Stub LineBreak is the line break tag callback. -func (r renderer) LineBreak(out *bytes.Buffer) {} - -// Link is the link tag callback. Outputs a sipmle plain-text version -// of the input. -func (r renderer) Link(out *bytes.Buffer, link []byte, title []byte, content []byte) { - out.Write(content) -} - -// Stub RawHtmlTag is the raw HTML tag callback. -func (r renderer) RawHtmlTag(out *bytes.Buffer, tag []byte) {} - -// TripleEmphasis is the triple emphasis tag callback. Outputs a simple plain-text -// version of the input. -func (r renderer) TripleEmphasis(out *bytes.Buffer, text []byte) { - out.Write(text) -} - -// Stub StrikeThrough is the strikethrough tag callback. -func (r renderer) StrikeThrough(out *bytes.Buffer, text []byte) {} - -// Stub FootnoteRef is the footnote ref tag callback. -func (r renderer) FootnoteRef(out *bytes.Buffer, ref []byte, id int) {} - -// Lowlevel callbacks - -// Entity callback. Outputs a simple plain-text version of the input. -func (r renderer) Entity(out *bytes.Buffer, entity []byte) { - out.Write(entity) -} - -// NormalText callback. Outputs a simple plain-text version of the input. -func (r renderer) NormalText(out *bytes.Buffer, text []byte) { - out.Write(text) -} - -// Header and footer - -// Stub DocumentHeader callback. -func (r renderer) DocumentHeader(out *bytes.Buffer) {} - -// Stub DocumentFooter callback. -func (r renderer) DocumentFooter(out *bytes.Buffer) {} - -// Stub GetFlags returns zero. -func (r renderer) GetFlags() int { return 0 } diff --git a/middleware/markdown/summary/summary.go b/middleware/markdown/summary/summary.go deleted file mode 100644 index e43a17187..000000000 --- a/middleware/markdown/summary/summary.go +++ /dev/null @@ -1,18 +0,0 @@ -package summary - -import ( - "bytes" - - "github.com/russross/blackfriday" -) - -// Markdown formats input using a plain-text renderer, and -// then returns up to the first `wordcount` words as a summary. -func Markdown(input []byte, wordcount int) []byte { - words := bytes.Fields(blackfriday.Markdown(input, renderer{}, 0)) - if wordcount > len(words) { - wordcount = len(words) - } - - return bytes.Join(words[0:wordcount], []byte{' '}) -} diff --git a/middleware/markdown/template.go b/middleware/markdown/template.go deleted file mode 100644 index 10ea31c58..000000000 --- a/middleware/markdown/template.go +++ /dev/null @@ -1,88 +0,0 @@ -package markdown - -import ( - "bytes" - "io/ioutil" - "text/template" - - "github.com/mholt/caddy/middleware" - "github.com/mholt/caddy/middleware/markdown/metadata" -) - -// Data represents a markdown document. -type Data struct { - middleware.Context - Doc map[string]string - DocFlags map[string]bool - Styles []string - Scripts []string - Files []FileInfo -} - -// Include "overrides" the embedded middleware.Context's Include() -// method so that included files have access to d's fields. -// Note: using {{template 'template-name' .}} instead might be better. -func (d Data) Include(filename string) (string, error) { - return middleware.ContextInclude(filename, d, d.Root) -} - -// execTemplate executes a template given a requestPath, template, and metadata -func execTemplate(c *Config, mdata metadata.Metadata, files []FileInfo, ctx middleware.Context) ([]byte, error) { - mdData := Data{ - Context: ctx, - Doc: mdata.Variables, - DocFlags: mdata.Flags, - Styles: c.Styles, - Scripts: c.Scripts, - Files: files, - } - - b := new(bytes.Buffer) - if err := c.Template.ExecuteTemplate(b, mdata.Template, mdData); err != nil { - return nil, err - } - - return b.Bytes(), nil -} - -func SetTemplate(t *template.Template, name, filename string) error { - - // Read template - buf, err := ioutil.ReadFile(filename) - if err != nil { - return err - } - - // Update if exists - if tt := t.Lookup(name); tt != nil { - _, err = tt.Parse(string(buf)) - return err - } - - // Allocate new name if not - _, err = t.New(name).Parse(string(buf)) - return err -} - -func GetDefaultTemplate() *template.Template { - return template.Must(template.New("").Parse(defaultTemplate)) -} - -const ( - defaultTemplate = ` - - - {{.Doc.title}} - - {{- range .Styles}} - - {{- end}} - {{- range .Scripts}} - - {{- end}} - - - {{.Doc.body}} - -` -) diff --git a/middleware/markdown/testdata/blog/test.md b/middleware/markdown/testdata/blog/test.md deleted file mode 100644 index 93f07a493..000000000 --- a/middleware/markdown/testdata/blog/test.md +++ /dev/null @@ -1,14 +0,0 @@ ---- -title: Markdown test 1 -sitename: A Caddy website ---- - -## Welcome on the blog - -Body - -``` go -func getTrue() bool { - return true -} -``` diff --git a/middleware/markdown/testdata/docflags/template.txt b/middleware/markdown/testdata/docflags/template.txt deleted file mode 100644 index 2760d18d1..000000000 --- a/middleware/markdown/testdata/docflags/template.txt +++ /dev/null @@ -1,4 +0,0 @@ -Doc.var_string {{.Doc.var_string}} -Doc.var_bool {{.Doc.var_bool}} -DocFlags.var_string {{.DocFlags.var_string}} -DocFlags.var_bool {{.DocFlags.var_bool}} diff --git a/middleware/markdown/testdata/docflags/test.md b/middleware/markdown/testdata/docflags/test.md deleted file mode 100644 index 64ca7f78d..000000000 --- a/middleware/markdown/testdata/docflags/test.md +++ /dev/null @@ -1,4 +0,0 @@ ---- -var_string: hello -var_bool: true ---- diff --git a/middleware/markdown/testdata/header.html b/middleware/markdown/testdata/header.html deleted file mode 100644 index cfbdc75b5..000000000 --- a/middleware/markdown/testdata/header.html +++ /dev/null @@ -1 +0,0 @@ -

Header for: {{.Doc.title}}

\ No newline at end of file diff --git a/middleware/markdown/testdata/log/test.md b/middleware/markdown/testdata/log/test.md deleted file mode 100644 index 476ab3015..000000000 --- a/middleware/markdown/testdata/log/test.md +++ /dev/null @@ -1,14 +0,0 @@ ---- -title: Markdown test 2 -sitename: A Caddy website ---- - -## Welcome on the blog - -Body - -``` go -func getTrue() bool { - return true -} -``` diff --git a/middleware/markdown/testdata/markdown_tpl.html b/middleware/markdown/testdata/markdown_tpl.html deleted file mode 100644 index 7c6978500..000000000 --- a/middleware/markdown/testdata/markdown_tpl.html +++ /dev/null @@ -1,11 +0,0 @@ - - - -{{.Doc.title}} - - -{{.Include "header.html"}} -Welcome to {{.Doc.sitename}}! -{{.Doc.body}} - - diff --git a/middleware/markdown/testdata/og/first.md b/middleware/markdown/testdata/og/first.md deleted file mode 100644 index 4d7a4251f..000000000 --- a/middleware/markdown/testdata/og/first.md +++ /dev/null @@ -1,5 +0,0 @@ ---- -title: first_post -sitename: title ---- -# Test h1 diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go deleted file mode 100644 index 62fa4e250..000000000 --- a/middleware/middleware_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package middleware - -import ( - "fmt" - "net/http" - "net/http/httptest" - "testing" - "time" -) - -func TestIndexfile(t *testing.T) { - tests := []struct { - rootDir http.FileSystem - fpath string - indexFiles []string - shouldErr bool - expectedFilePath string //retun value - expectedBoolValue bool //return value - }{ - { - http.Dir("./templates/testdata"), - "/images/", - []string{"img.htm"}, - false, - "/images/img.htm", - true, - }, - } - for i, test := range tests { - actualFilePath, actualBoolValue := IndexFile(test.rootDir, test.fpath, test.indexFiles) - if actualBoolValue == true && test.shouldErr { - t.Errorf("Test %d didn't error, but it should have", i) - } else if actualBoolValue != true && !test.shouldErr { - t.Errorf("Test %d errored, but it shouldn't have; got %s", i, "Please Add a / at the end of fpath or the indexFiles doesnt exist") - } - if actualFilePath != test.expectedFilePath { - t.Fatalf("Test %d expected returned filepath to be %s, but got %s ", - i, test.expectedFilePath, actualFilePath) - - } - if actualBoolValue != test.expectedBoolValue { - t.Fatalf("Test %d expected returned bool value to be %v, but got %v ", - i, test.expectedBoolValue, actualBoolValue) - - } - } -} - -func TestSetLastModified(t *testing.T) { - nowTime := time.Now() - - // ovewrite the function to return reliable time - originalGetCurrentTimeFunc := currentTime - currentTime = func() time.Time { - return nowTime - } - defer func() { - currentTime = originalGetCurrentTimeFunc - }() - - pastTime := nowTime.Truncate(1 * time.Hour) - futureTime := nowTime.Add(1 * time.Hour) - - tests := []struct { - inputModTime time.Time - expectedIsHeaderSet bool - expectedLastModified string - }{ - { - inputModTime: pastTime, - expectedIsHeaderSet: true, - expectedLastModified: pastTime.UTC().Format(http.TimeFormat), - }, - { - inputModTime: nowTime, - expectedIsHeaderSet: true, - expectedLastModified: nowTime.UTC().Format(http.TimeFormat), - }, - { - inputModTime: futureTime, - expectedIsHeaderSet: true, - expectedLastModified: nowTime.UTC().Format(http.TimeFormat), - }, - { - inputModTime: time.Time{}, - expectedIsHeaderSet: false, - }, - } - - for i, test := range tests { - responseRecorder := httptest.NewRecorder() - errorPrefix := fmt.Sprintf("Test [%d]: ", i) - SetLastModifiedHeader(responseRecorder, test.inputModTime) - actualLastModifiedHeader := responseRecorder.Header().Get("Last-Modified") - - if test.expectedIsHeaderSet && actualLastModifiedHeader == "" { - t.Fatalf(errorPrefix + "Expected to find Last-Modified header, but found nothing") - } - - if !test.expectedIsHeaderSet && actualLastModifiedHeader != "" { - t.Fatalf(errorPrefix+"Did not expect to find Last-Modified header, but found one [%s].", actualLastModifiedHeader) - } - - if test.expectedLastModified != actualLastModifiedHeader { - t.Errorf(errorPrefix+"Expected Last-Modified content [%s], found [%s}", test.expectedLastModified, actualLastModifiedHeader) - } - } -} diff --git a/middleware/mime/mime.go b/middleware/mime/mime.go deleted file mode 100644 index 6990c596d..000000000 --- a/middleware/mime/mime.go +++ /dev/null @@ -1,32 +0,0 @@ -package mime - -import ( - "net/http" - "path" - - "github.com/mholt/caddy/middleware" -) - -// Config represent a mime config. Map from extension to mime-type. -// Note, this should be safe with concurrent read access, as this is -// not modified concurrently. -type Config map[string]string - -// Mime sets Content-Type header of requests based on configurations. -type Mime struct { - Next middleware.Handler - Configs Config -} - -// ServeHTTP implements the middleware.Handler interface. -func (e Mime) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - - // Get a clean /-path, grab the extension - ext := path.Ext(path.Clean(r.URL.Path)) - - if contentType, ok := e.Configs[ext]; ok { - w.Header().Set("Content-Type", contentType) - } - - return e.Next.ServeHTTP(w, r) -} diff --git a/middleware/mime/mime_test.go b/middleware/mime/mime_test.go deleted file mode 100644 index 4010b0aef..000000000 --- a/middleware/mime/mime_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package mime - -import ( - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/mholt/caddy/middleware" -) - -func TestMimeHandler(t *testing.T) { - - mimes := Config{ - ".html": "text/html", - ".txt": "text/plain", - ".swf": "application/x-shockwave-flash", - } - - m := Mime{Configs: mimes} - - w := httptest.NewRecorder() - exts := []string{ - ".html", ".txt", ".swf", - } - for _, e := range exts { - url := "/file" + e - r, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Error(err) - } - m.Next = nextFunc(true, mimes[e]) - _, err = m.ServeHTTP(w, r) - if err != nil { - t.Error(err) - } - } - - w = httptest.NewRecorder() - exts = []string{ - ".htm1", ".abc", ".mdx", - } - for _, e := range exts { - url := "/file" + e - r, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Error(err) - } - m.Next = nextFunc(false, "") - _, err = m.ServeHTTP(w, r) - if err != nil { - t.Error(err) - } - } -} - -func nextFunc(shouldMime bool, contentType string) middleware.Handler { - return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - if shouldMime { - if w.Header().Get("Content-Type") != contentType { - return 0, fmt.Errorf("expected Content-Type: %v, found %v", contentType, r.Header.Get("Content-Type")) - } - return 0, nil - } - if w.Header().Get("Content-Type") != "" { - return 0, fmt.Errorf("Content-Type header not expected") - } - return 0, nil - }) -} diff --git a/middleware/path.go b/middleware/path.go deleted file mode 100644 index 9c831e771..000000000 --- a/middleware/path.go +++ /dev/null @@ -1,44 +0,0 @@ -package middleware - -import ( - "os" - "strings" -) - -const caseSensitivePathEnv = "CASE_SENSITIVE_PATH" - -func init() { - initCaseSettings() -} - -// CaseSensitivePath determines if paths should be case sensitive. -// This is configurable via CASE_SENSITIVE_PATH environment variable. -// It defaults to false. -var CaseSensitivePath = true - -// initCaseSettings loads case sensitivity config from environment variable. -// -// This could have been in init, but init cannot be called from tests. -func initCaseSettings() { - switch os.Getenv(caseSensitivePathEnv) { - case "0", "false": - CaseSensitivePath = false - default: - CaseSensitivePath = true - } -} - -// Path represents a URI path, maybe with pattern characters. -type Path string - -// Matches checks to see if other matches p. -// -// Path matching will probably not always be a direct -// comparison; this method assures that paths can be -// easily and consistently matched. -func (p Path) Matches(other string) bool { - if CaseSensitivePath { - return strings.HasPrefix(string(p), other) - } - return strings.HasPrefix(strings.ToLower(string(p)), strings.ToLower(other)) -} diff --git a/middleware/pprof/pprof.go b/middleware/pprof/pprof.go deleted file mode 100644 index 8d8e9c788..000000000 --- a/middleware/pprof/pprof.go +++ /dev/null @@ -1,41 +0,0 @@ -package pprof - -import ( - "net/http" - pp "net/http/pprof" - - "github.com/mholt/caddy/middleware" -) - -// BasePath is the base path to match for all pprof requests. -const BasePath = "/debug/pprof" - -// Handler is a simple struct whose ServeHTTP will delegate pprof -// endpoints to their equivalent net/http/pprof handlers. -type Handler struct { - Next middleware.Handler - Mux *http.ServeMux -} - -// ServeHTTP handles requests to BasePath with pprof, or passes -// all other requests up the chain. -func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - if middleware.Path(r.URL.Path).Matches(BasePath) { - h.Mux.ServeHTTP(w, r) - return 0, nil - } - return h.Next.ServeHTTP(w, r) -} - -// NewMux returns a new http.ServeMux that routes pprof requests. -// It pretty much copies what the std lib pprof does on init: -// https://golang.org/src/net/http/pprof/pprof.go#L67 -func NewMux() *http.ServeMux { - mux := http.NewServeMux() - mux.HandleFunc(BasePath+"/", pp.Index) - mux.HandleFunc(BasePath+"/cmdline", pp.Cmdline) - mux.HandleFunc(BasePath+"/profile", pp.Profile) - mux.HandleFunc(BasePath+"/symbol", pp.Symbol) - mux.HandleFunc(BasePath+"/trace", pp.Trace) - return mux -} diff --git a/middleware/pprof/pprof_test.go b/middleware/pprof/pprof_test.go deleted file mode 100644 index a9aee20c9..000000000 --- a/middleware/pprof/pprof_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package pprof - -import ( - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/mholt/caddy/middleware" -) - -func TestServeHTTP(t *testing.T) { - h := Handler{ - Next: middleware.HandlerFunc(nextHandler), - Mux: NewMux(), - } - - w := httptest.NewRecorder() - r, err := http.NewRequest("GET", "/debug/pprof", nil) - if err != nil { - t.Fatal(err) - } - status, err := h.ServeHTTP(w, r) - - if status != 0 { - t.Errorf("Expected status %d but got %d", 0, status) - } - if err != nil { - t.Errorf("Expected nil error, but got: %v", err) - } - if w.Body.String() == "content" { - t.Errorf("Expected pprof to handle request, but it didn't") - } - - w = httptest.NewRecorder() - r, err = http.NewRequest("GET", "/foo", nil) - if err != nil { - t.Fatal(err) - } - status, err = h.ServeHTTP(w, r) - if status != http.StatusNotFound { - t.Errorf("Test two: Expected status %d but got %d", http.StatusNotFound, status) - } - if err != nil { - t.Errorf("Test two: Expected nil error, but got: %v", err) - } - if w.Body.String() != "content" { - t.Errorf("Expected pprof to pass the request thru, but it didn't; got: %s", w.Body.String()) - } -} - -func nextHandler(w http.ResponseWriter, r *http.Request) (int, error) { - fmt.Fprintf(w, "content") - return http.StatusNotFound, nil -} diff --git a/middleware/proxy/policy.go b/middleware/proxy/policy.go deleted file mode 100644 index 96e382a5c..000000000 --- a/middleware/proxy/policy.go +++ /dev/null @@ -1,101 +0,0 @@ -package proxy - -import ( - "math/rand" - "sync/atomic" -) - -// HostPool is a collection of UpstreamHosts. -type HostPool []*UpstreamHost - -// Policy decides how a host will be selected from a pool. -type Policy interface { - Select(pool HostPool) *UpstreamHost -} - -func init() { - RegisterPolicy("random", func() Policy { return &Random{} }) - RegisterPolicy("least_conn", func() Policy { return &LeastConn{} }) - RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} }) -} - -// Random is a policy that selects up hosts from a pool at random. -type Random struct{} - -// Select selects an up host at random from the specified pool. -func (r *Random) Select(pool HostPool) *UpstreamHost { - // instead of just generating a random index - // this is done to prevent selecting a unavailable host - var randHost *UpstreamHost - count := 0 - for _, host := range pool { - if !host.Available() { - continue - } - count++ - if count == 1 { - randHost = host - } else { - r := rand.Int() % count - if r == (count - 1) { - randHost = host - } - } - } - return randHost -} - -// LeastConn is a policy that selects the host with the least connections. -type LeastConn struct{} - -// Select selects the up host with the least number of connections in the -// pool. If more than one host has the same least number of connections, -// one of the hosts is chosen at random. -func (r *LeastConn) Select(pool HostPool) *UpstreamHost { - var bestHost *UpstreamHost - count := 0 - leastConn := int64(1<<63 - 1) - for _, host := range pool { - if !host.Available() { - continue - } - hostConns := host.Conns - if hostConns < leastConn { - bestHost = host - leastConn = hostConns - count = 1 - } else if hostConns == leastConn { - // randomly select host among hosts with least connections - count++ - if count == 1 { - bestHost = host - } else { - r := rand.Int() % count - if r == (count - 1) { - bestHost = host - } - } - } - } - return bestHost -} - -// RoundRobin is a policy that selects hosts based on round robin ordering. -type RoundRobin struct { - Robin uint32 -} - -// Select selects an up host from the pool using a round robin ordering scheme. -func (r *RoundRobin) Select(pool HostPool) *UpstreamHost { - poolLen := uint32(len(pool)) - selection := atomic.AddUint32(&r.Robin, 1) % poolLen - host := pool[selection] - // if the currently selected host is not available, just ffwd to up host - for i := uint32(1); !host.Available() && i < poolLen; i++ { - host = pool[(selection+i)%poolLen] - } - if !host.Available() { - return nil - } - return host -} diff --git a/middleware/proxy/policy_test.go b/middleware/proxy/policy_test.go deleted file mode 100644 index 4cc05f029..000000000 --- a/middleware/proxy/policy_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package proxy - -import ( - "net/http" - "net/http/httptest" - "os" - "testing" -) - -var workableServer *httptest.Server - -func TestMain(m *testing.M) { - workableServer = httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - // do nothing - })) - r := m.Run() - workableServer.Close() - os.Exit(r) -} - -type customPolicy struct{} - -func (r *customPolicy) Select(pool HostPool) *UpstreamHost { - return pool[0] -} - -func testPool() HostPool { - pool := []*UpstreamHost{ - { - Name: workableServer.URL, // this should resolve (healthcheck test) - }, - { - Name: "http://shouldnot.resolve", // this shouldn't - }, - { - Name: "http://C", - }, - } - return HostPool(pool) -} - -func TestRoundRobinPolicy(t *testing.T) { - pool := testPool() - rrPolicy := &RoundRobin{} - h := rrPolicy.Select(pool) - // First selected host is 1, because counter starts at 0 - // and increments before host is selected - if h != pool[1] { - t.Error("Expected first round robin host to be second host in the pool.") - } - h = rrPolicy.Select(pool) - if h != pool[2] { - t.Error("Expected second round robin host to be third host in the pool.") - } - h = rrPolicy.Select(pool) - if h != pool[0] { - t.Error("Expected third round robin host to be first host in the pool.") - } - // mark host as down - pool[1].Unhealthy = true - h = rrPolicy.Select(pool) - if h != pool[2] { - t.Error("Expected to skip down host.") - } - // mark host as full - pool[2].Conns = 1 - pool[2].MaxConns = 1 - h = rrPolicy.Select(pool) - if h != pool[0] { - t.Error("Expected to skip full host.") - } -} - -func TestLeastConnPolicy(t *testing.T) { - pool := testPool() - lcPolicy := &LeastConn{} - pool[0].Conns = 10 - pool[1].Conns = 10 - h := lcPolicy.Select(pool) - if h != pool[2] { - t.Error("Expected least connection host to be third host.") - } - pool[2].Conns = 100 - h = lcPolicy.Select(pool) - if h != pool[0] && h != pool[1] { - t.Error("Expected least connection host to be first or second host.") - } -} - -func TestCustomPolicy(t *testing.T) { - pool := testPool() - customPolicy := &customPolicy{} - h := customPolicy.Select(pool) - if h != pool[0] { - t.Error("Expected custom policy host to be the first host.") - } -} diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go deleted file mode 100644 index 264444d9a..000000000 --- a/middleware/proxy/proxy.go +++ /dev/null @@ -1,242 +0,0 @@ -// Package proxy is middleware that proxies requests. -package proxy - -import ( - "errors" - "net" - "net/http" - "net/url" - "strings" - "sync/atomic" - "time" - - "github.com/mholt/caddy/middleware" -) - -var errUnreachable = errors.New("unreachable backend") - -// Proxy represents a middleware instance that can proxy requests. -type Proxy struct { - Next middleware.Handler - Upstreams []Upstream -} - -// Upstream manages a pool of proxy upstream hosts. Select should return a -// suitable upstream host, or nil if no such hosts are available. -type Upstream interface { - // The path this upstream host should be routed on - From() string - // Selects an upstream host to be routed to. - Select() *UpstreamHost - // Checks if subpath is not an ignored path - AllowedPath(string) bool -} - -// UpstreamHostDownFunc can be used to customize how Down behaves. -type UpstreamHostDownFunc func(*UpstreamHost) bool - -// UpstreamHost represents a single proxy upstream -type UpstreamHost struct { - Conns int64 // must be first field to be 64-bit aligned on 32-bit systems - Name string // hostname of this upstream host - ReverseProxy *ReverseProxy - Fails int32 - FailTimeout time.Duration - Unhealthy bool - UpstreamHeaders http.Header - DownstreamHeaders http.Header - CheckDown UpstreamHostDownFunc - WithoutPathPrefix string - MaxConns int64 -} - -// Down checks whether the upstream host is down or not. -// Down will try to use uh.CheckDown first, and will fall -// back to some default criteria if necessary. -func (uh *UpstreamHost) Down() bool { - if uh.CheckDown == nil { - // Default settings - return uh.Unhealthy || uh.Fails > 0 - } - return uh.CheckDown(uh) -} - -// Full checks whether the upstream host has reached its maximum connections -func (uh *UpstreamHost) Full() bool { - return uh.MaxConns > 0 && uh.Conns >= uh.MaxConns -} - -// Available checks whether the upstream host is available for proxying to -func (uh *UpstreamHost) Available() bool { - return !uh.Down() && !uh.Full() -} - -// tryDuration is how long to try upstream hosts; failures result in -// immediate retries until this duration ends or we get a nil host. -var tryDuration = 60 * time.Second - -// ServeHTTP satisfies the middleware.Handler interface. -func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - for _, upstream := range p.Upstreams { - if !middleware.Path(r.URL.Path).Matches(upstream.From()) || - !upstream.AllowedPath(r.URL.Path) { - continue - } - - var replacer middleware.Replacer - start := time.Now() - - outreq := createUpstreamRequest(r) - - // Since Select() should give us "up" hosts, keep retrying - // hosts until timeout (or until we get a nil host). - for time.Now().Sub(start) < tryDuration { - host := upstream.Select() - if host == nil { - return http.StatusBadGateway, errUnreachable - } - if rr, ok := w.(*middleware.ResponseRecorder); ok && rr.Replacer != nil { - rr.Replacer.Set("upstream", host.Name) - } - - outreq.Host = host.Name - if host.UpstreamHeaders != nil { - if replacer == nil { - rHost := r.Host - replacer = middleware.NewReplacer(r, nil, "") - outreq.Host = rHost - } - if v, ok := host.UpstreamHeaders["Host"]; ok { - outreq.Host = replacer.Replace(v[len(v)-1]) - } - // Modify headers for request that will be sent to the upstream host - upHeaders := createHeadersByRules(host.UpstreamHeaders, r.Header, replacer) - for k, v := range upHeaders { - outreq.Header[k] = v - } - } - - var downHeaderUpdateFn respUpdateFn - if host.DownstreamHeaders != nil { - if replacer == nil { - rHost := r.Host - replacer = middleware.NewReplacer(r, nil, "") - outreq.Host = rHost - } - //Creates a function that is used to update headers the response received by the reverse proxy - downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer) - } - - proxy := host.ReverseProxy - if baseURL, err := url.Parse(host.Name); err == nil { - r.Host = baseURL.Host - if proxy == nil { - proxy = NewSingleHostReverseProxy(baseURL, host.WithoutPathPrefix) - } - } else if proxy == nil { - return http.StatusInternalServerError, err - } - - atomic.AddInt64(&host.Conns, 1) - backendErr := proxy.ServeHTTP(w, outreq, downHeaderUpdateFn) - atomic.AddInt64(&host.Conns, -1) - if backendErr == nil { - return 0, nil - } - timeout := host.FailTimeout - if timeout == 0 { - timeout = 10 * time.Second - } - atomic.AddInt32(&host.Fails, 1) - go func(host *UpstreamHost, timeout time.Duration) { - time.Sleep(timeout) - atomic.AddInt32(&host.Fails, -1) - }(host, timeout) - } - return http.StatusBadGateway, errUnreachable - } - - return p.Next.ServeHTTP(w, r) -} - -// createUpstremRequest shallow-copies r into a new request -// that can be sent upstream. -func createUpstreamRequest(r *http.Request) *http.Request { - outreq := new(http.Request) - *outreq = *r // includes shallow copies of maps, but okay - - // Restore URL Path if it has been modified - if outreq.URL.RawPath != "" { - outreq.URL.Opaque = outreq.URL.RawPath - } - - // Remove hop-by-hop headers to the backend. Especially - // important is "Connection" because we want a persistent - // connection, regardless of what the client sent to us. This - // is modifying the same underlying map from r (shallow - // copied above) so we only copy it if necessary. - for _, h := range hopHeaders { - if outreq.Header.Get(h) != "" { - outreq.Header = make(http.Header) - copyHeader(outreq.Header, r.Header) - outreq.Header.Del(h) - } - } - - if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { - // If we aren't the first proxy, retain prior - // X-Forwarded-For information as a comma+space - // separated list and fold multiple headers into one. - if prior, ok := outreq.Header["X-Forwarded-For"]; ok { - clientIP = strings.Join(prior, ", ") + ", " + clientIP - } - outreq.Header.Set("X-Forwarded-For", clientIP) - } - - return outreq -} - -func createRespHeaderUpdateFn(rules http.Header, replacer middleware.Replacer) respUpdateFn { - return func(resp *http.Response) { - newHeaders := createHeadersByRules(rules, resp.Header, replacer) - for h, v := range newHeaders { - resp.Header[h] = v - } - } -} - -func createHeadersByRules(rules http.Header, base http.Header, repl middleware.Replacer) http.Header { - newHeaders := make(http.Header) - for header, values := range rules { - if strings.HasPrefix(header, "+") { - header = strings.TrimLeft(header, "+") - add(newHeaders, header, base[header]) - applyEach(values, repl.Replace) - add(newHeaders, header, values) - } else if strings.HasPrefix(header, "-") { - base.Del(strings.TrimLeft(header, "-")) - } else if _, ok := base[header]; ok { - applyEach(values, repl.Replace) - for _, v := range values { - newHeaders.Set(header, v) - } - } else { - applyEach(values, repl.Replace) - add(newHeaders, header, values) - add(newHeaders, header, base[header]) - } - } - return newHeaders -} - -func applyEach(values []string, mapFn func(string) string) { - for i, v := range values { - values[i] = mapFn(v) - } -} - -func add(base http.Header, header string, values []string) { - for _, v := range values { - base.Add(header, v) - } -} diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go deleted file mode 100644 index 592b782ea..000000000 --- a/middleware/proxy/proxy_test.go +++ /dev/null @@ -1,583 +0,0 @@ -package proxy - -import ( - "bufio" - "bytes" - "fmt" - "io" - "io/ioutil" - "log" - "net" - "net/http" - "net/http/httptest" - "net/url" - "os" - "path/filepath" - "runtime" - "strings" - "testing" - "time" - - "github.com/mholt/caddy/middleware" - - "golang.org/x/net/websocket" -) - -func init() { - tryDuration = 50 * time.Millisecond // prevent tests from hanging -} - -func TestReverseProxy(t *testing.T) { - log.SetOutput(ioutil.Discard) - defer log.SetOutput(os.Stderr) - - var requestReceived bool - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestReceived = true - w.Write([]byte("Hello, client")) - })) - defer backend.Close() - - // set up proxy - p := &Proxy{ - Upstreams: []Upstream{newFakeUpstream(backend.URL, false)}, - } - - // create request and response recorder - r, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - w := httptest.NewRecorder() - - p.ServeHTTP(w, r) - - if !requestReceived { - t.Error("Expected backend to receive request, but it didn't") - } - - // Make sure {upstream} placeholder is set - rr := middleware.NewResponseRecorder(httptest.NewRecorder()) - rr.Replacer = middleware.NewReplacer(r, rr, "-") - - p.ServeHTTP(rr, r) - - if got, want := rr.Replacer.Replace("{upstream}"), backend.URL; got != want { - t.Errorf("Expected custom placeholder {upstream} to be set (%s), but it wasn't; got: %s", want, got) - } -} - -func TestReverseProxyInsecureSkipVerify(t *testing.T) { - log.SetOutput(ioutil.Discard) - defer log.SetOutput(os.Stderr) - - var requestReceived bool - backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestReceived = true - w.Write([]byte("Hello, client")) - })) - defer backend.Close() - - // set up proxy - p := &Proxy{ - Upstreams: []Upstream{newFakeUpstream(backend.URL, true)}, - } - - // create request and response recorder - r, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - w := httptest.NewRecorder() - - p.ServeHTTP(w, r) - - if !requestReceived { - t.Error("Even with insecure HTTPS, expected backend to receive request, but it didn't") - } -} - -func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) { - // No-op websocket backend simply allows the WS connection to be - // accepted then it will be immediately closed. Perfect for testing. - wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {})) - defer wsNop.Close() - - // Get proxy to use for the test - p := newWebSocketTestProxy(wsNop.URL) - - // Create client request - r, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - r.Header = http.Header{ - "Connection": {"Upgrade"}, - "Upgrade": {"websocket"}, - "Origin": {wsNop.URL}, - "Sec-WebSocket-Key": {"x3JJHMbDL1EzLkh9GBhXDw=="}, - "Sec-WebSocket-Version": {"13"}, - } - - // Capture the request - w := &recorderHijacker{httptest.NewRecorder(), new(fakeConn)} - - // Booya! Do the test. - p.ServeHTTP(w, r) - - // Make sure the backend accepted the WS connection. - // Mostly interested in the Upgrade and Connection response headers - // and the 101 status code. - expected := []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n\r\n") - actual := w.fakeConn.writeBuf.Bytes() - if !bytes.Equal(actual, expected) { - t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual) - } -} - -func TestWebSocketReverseProxyFromWSClient(t *testing.T) { - // Echo server allows us to test that socket bytes are properly - // being proxied. - wsEcho := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { - io.Copy(ws, ws) - })) - defer wsEcho.Close() - - // Get proxy to use for the test - p := newWebSocketTestProxy(wsEcho.URL) - - // This is a full end-end test, so the proxy handler - // has to be part of a server listening on a port. Our - // WS client will connect to this test server, not - // the echo client directly. - echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - p.ServeHTTP(w, r) - })) - defer echoProxy.Close() - - // Set up WebSocket client - url := strings.Replace(echoProxy.URL, "http://", "ws://", 1) - ws, err := websocket.Dial(url, "", echoProxy.URL) - if err != nil { - t.Fatal(err) - } - defer ws.Close() - - // Send test message - trialMsg := "Is it working?" - websocket.Message.Send(ws, trialMsg) - - // It should be echoed back to us - var actualMsg string - websocket.Message.Receive(ws, &actualMsg) - if actualMsg != trialMsg { - t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg) - } -} - -func TestUnixSocketProxy(t *testing.T) { - if runtime.GOOS == "windows" { - return - } - - trialMsg := "Is it working?" - - var proxySuccess bool - - // This is our fake "application" we want to proxy to - ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Request was proxied when this is called - proxySuccess = true - - fmt.Fprint(w, trialMsg) - })) - - // Get absolute path for unix: socket - socketPath, err := filepath.Abs("./test_socket") - if err != nil { - t.Fatalf("Unable to get absolute path: %v", err) - } - - // Change httptest.Server listener to listen to unix: socket - ln, err := net.Listen("unix", socketPath) - if err != nil { - t.Fatalf("Unable to listen: %v", err) - } - ts.Listener = ln - - ts.Start() - defer ts.Close() - - url := strings.Replace(ts.URL, "http://", "unix:", 1) - p := newWebSocketTestProxy(url) - - echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - p.ServeHTTP(w, r) - })) - defer echoProxy.Close() - - res, err := http.Get(echoProxy.URL) - if err != nil { - t.Fatalf("Unable to GET: %v", err) - } - - greeting, err := ioutil.ReadAll(res.Body) - res.Body.Close() - if err != nil { - t.Fatalf("Unable to GET: %v", err) - } - - actualMsg := fmt.Sprintf("%s", greeting) - - if !proxySuccess { - t.Errorf("Expected request to be proxied, but it wasn't") - } - - if actualMsg != trialMsg { - t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg) - } -} - -func GetHTTPProxy(messageFormat string, prefix string) (*Proxy, *httptest.Server) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, messageFormat, r.URL.String()) - })) - - return newPrefixedWebSocketTestProxy(ts.URL, prefix), ts -} - -func GetSocketProxy(messageFormat string, prefix string) (*Proxy, *httptest.Server, error) { - ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, messageFormat, r.URL.String()) - })) - - socketPath, err := filepath.Abs("./test_socket") - if err != nil { - return nil, nil, fmt.Errorf("Unable to get absolute path: %v", err) - } - - ln, err := net.Listen("unix", socketPath) - if err != nil { - return nil, nil, fmt.Errorf("Unable to listen: %v", err) - } - ts.Listener = ln - - ts.Start() - - tsURL := strings.Replace(ts.URL, "http://", "unix:", 1) - - return newPrefixedWebSocketTestProxy(tsURL, prefix), ts, nil -} - -func GetTestServerMessage(p *Proxy, ts *httptest.Server, path string) (string, error) { - echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - p.ServeHTTP(w, r) - })) - - // *httptest.Server is passed so it can be `defer`red properly - defer ts.Close() - defer echoProxy.Close() - - res, err := http.Get(echoProxy.URL + path) - if err != nil { - return "", fmt.Errorf("Unable to GET: %v", err) - } - - greeting, err := ioutil.ReadAll(res.Body) - res.Body.Close() - if err != nil { - return "", fmt.Errorf("Unable to read body: %v", err) - } - - return fmt.Sprintf("%s", greeting), nil -} - -func TestUnixSocketProxyPaths(t *testing.T) { - greeting := "Hello route %s" - - tests := []struct { - url string - prefix string - expected string - }{ - {"", "", fmt.Sprintf(greeting, "/")}, - {"/hello", "", fmt.Sprintf(greeting, "/hello")}, - {"/foo/bar", "", fmt.Sprintf(greeting, "/foo/bar")}, - {"/foo?bar", "", fmt.Sprintf(greeting, "/foo?bar")}, - {"/greet?name=john", "", fmt.Sprintf(greeting, "/greet?name=john")}, - {"/world?wonderful&colorful", "", fmt.Sprintf(greeting, "/world?wonderful&colorful")}, - {"/proxy/hello", "/proxy", fmt.Sprintf(greeting, "/hello")}, - {"/proxy/foo/bar", "/proxy", fmt.Sprintf(greeting, "/foo/bar")}, - {"/proxy/?foo=bar", "/proxy", fmt.Sprintf(greeting, "/?foo=bar")}, - {"/queues/%2F/fetchtasks", "", fmt.Sprintf(greeting, "/queues/%2F/fetchtasks")}, - {"/queues/%2F/fetchtasks?foo=bar", "", fmt.Sprintf(greeting, "/queues/%2F/fetchtasks?foo=bar")}, - } - - for _, test := range tests { - p, ts := GetHTTPProxy(greeting, test.prefix) - - actualMsg, err := GetTestServerMessage(p, ts, test.url) - - if err != nil { - t.Fatalf("Getting server message failed - %v", err) - } - - if actualMsg != test.expected { - t.Errorf("Expected '%s' but got '%s' instead", test.expected, actualMsg) - } - } - - if runtime.GOOS == "windows" { - return - } - - for _, test := range tests { - p, ts, err := GetSocketProxy(greeting, test.prefix) - - if err != nil { - t.Fatalf("Getting socket proxy failed - %v", err) - } - - actualMsg, err := GetTestServerMessage(p, ts, test.url) - - if err != nil { - t.Fatalf("Getting server message failed - %v", err) - } - - if actualMsg != test.expected { - t.Errorf("Expected '%s' but got '%s' instead", test.expected, actualMsg) - } - } -} - -func TestUpstreamHeadersUpdate(t *testing.T) { - log.SetOutput(ioutil.Discard) - defer log.SetOutput(os.Stderr) - - var actualHeaders http.Header - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("Hello, client")) - actualHeaders = r.Header - })) - defer backend.Close() - - upstream := newFakeUpstream(backend.URL, false) - upstream.host.UpstreamHeaders = http.Header{ - "Connection": {"{>Connection}"}, - "Upgrade": {"{>Upgrade}"}, - "+Merge-Me": {"Merge-Value"}, - "+Add-Me": {"Add-Value"}, - "-Remove-Me": {""}, - "Replace-Me": {"{hostname}"}, - } - // set up proxy - p := &Proxy{ - Upstreams: []Upstream{upstream}, - } - - // create request and response recorder - r, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - w := httptest.NewRecorder() - - //add initial headers - r.Header.Add("Merge-Me", "Initial") - r.Header.Add("Remove-Me", "Remove-Value") - r.Header.Add("Replace-Me", "Replace-Value") - - p.ServeHTTP(w, r) - - replacer := middleware.NewReplacer(r, nil, "") - - headerKey := "Merge-Me" - values, ok := actualHeaders[headerKey] - if !ok { - t.Errorf("Request sent to upstream backend does not contain expected %v header. Expected header to be added", headerKey) - } else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) { - t.Errorf("Values for proxy header `+Merge-Me` should be merged. Got %v", values) - } - - headerKey = "Add-Me" - if _, ok := actualHeaders[headerKey]; !ok { - t.Errorf("Request sent to upstream backend does not contain expected %v header", headerKey) - } - - headerKey = "Remove-Me" - if _, ok := actualHeaders[headerKey]; ok { - t.Errorf("Request sent to upstream backend should not contain %v header", headerKey) - } - - headerKey = "Replace-Me" - headerValue := replacer.Replace("{hostname}") - value, ok := actualHeaders[headerKey] - if !ok { - t.Errorf("Request sent to upstream backend should not remove %v header", headerKey) - } else if len(value) > 0 && headerValue != value[0] { - t.Errorf("Request sent to upstream backend should replace value of %v header with %v. Instead value was %v", headerKey, headerValue, value) - } - -} - -func TestDownstreamHeadersUpdate(t *testing.T) { - log.SetOutput(ioutil.Discard) - defer log.SetOutput(os.Stderr) - - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Merge-Me", "Initial") - w.Header().Add("Remove-Me", "Remove-Value") - w.Header().Add("Replace-Me", "Replace-Value") - w.Write([]byte("Hello, client")) - })) - defer backend.Close() - - upstream := newFakeUpstream(backend.URL, false) - upstream.host.DownstreamHeaders = http.Header{ - "+Merge-Me": {"Merge-Value"}, - "+Add-Me": {"Add-Value"}, - "-Remove-Me": {""}, - "Replace-Me": {"{hostname}"}, - } - // set up proxy - p := &Proxy{ - Upstreams: []Upstream{upstream}, - } - - // create request and response recorder - r, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - w := httptest.NewRecorder() - - p.ServeHTTP(w, r) - - replacer := middleware.NewReplacer(r, nil, "") - actualHeaders := w.Header() - - headerKey := "Merge-Me" - values, ok := actualHeaders[headerKey] - if !ok { - t.Errorf("Downstream response does not contain expected %v header. Expected header should be added", headerKey) - } else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) { - t.Errorf("Values for header `+Merge-Me` should be merged. Got %v", values) - } - - headerKey = "Add-Me" - if _, ok := actualHeaders[headerKey]; !ok { - t.Errorf("Downstream response does not contain expected %v header", headerKey) - } - - headerKey = "Remove-Me" - if _, ok := actualHeaders[headerKey]; ok { - t.Errorf("Downstream response should not contain %v header received from upstream", headerKey) - } - - headerKey = "Replace-Me" - headerValue := replacer.Replace("{hostname}") - value, ok := actualHeaders[headerKey] - if !ok { - t.Errorf("Downstream response should contain %v header and not remove it", headerKey) - } else if len(value) > 0 && headerValue != value[0] { - t.Errorf("Downstream response should have header %v with value %v. Instead value was %v", headerKey, headerValue, value) - } - -} - -func newFakeUpstream(name string, insecure bool) *fakeUpstream { - uri, _ := url.Parse(name) - u := &fakeUpstream{ - name: name, - host: &UpstreamHost{ - Name: name, - ReverseProxy: NewSingleHostReverseProxy(uri, ""), - }, - } - if insecure { - u.host.ReverseProxy.Transport = InsecureTransport - } - return u -} - -type fakeUpstream struct { - name string - host *UpstreamHost -} - -func (u *fakeUpstream) From() string { - return "/" -} - -func (u *fakeUpstream) Select() *UpstreamHost { - return u.host -} - -func (u *fakeUpstream) AllowedPath(requestPath string) bool { - return true -} - -// newWebSocketTestProxy returns a test proxy that will -// redirect to the specified backendAddr. The function -// also sets up the rules/environment for testing WebSocket -// proxy. -func newWebSocketTestProxy(backendAddr string) *Proxy { - return &Proxy{ - Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: ""}}, - } -} - -func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy { - return &Proxy{ - Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix}}, - } -} - -type fakeWsUpstream struct { - name string - without string -} - -func (u *fakeWsUpstream) From() string { - return "/" -} - -func (u *fakeWsUpstream) Select() *UpstreamHost { - uri, _ := url.Parse(u.name) - return &UpstreamHost{ - Name: u.name, - ReverseProxy: NewSingleHostReverseProxy(uri, u.without), - UpstreamHeaders: http.Header{ - "Connection": {"{>Connection}"}, - "Upgrade": {"{>Upgrade}"}}, - } -} - -func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { - return true -} - -// recorderHijacker is a ResponseRecorder that can -// be hijacked. -type recorderHijacker struct { - *httptest.ResponseRecorder - fakeConn *fakeConn -} - -func (rh *recorderHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return rh.fakeConn, nil, nil -} - -type fakeConn struct { - readBuf bytes.Buffer - writeBuf bytes.Buffer -} - -func (c *fakeConn) LocalAddr() net.Addr { return nil } -func (c *fakeConn) RemoteAddr() net.Addr { return nil } -func (c *fakeConn) SetDeadline(t time.Time) error { return nil } -func (c *fakeConn) SetReadDeadline(t time.Time) error { return nil } -func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil } -func (c *fakeConn) Close() error { return nil } -func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) } -func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) } diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go deleted file mode 100644 index 5a14aea79..000000000 --- a/middleware/proxy/reverseproxy.go +++ /dev/null @@ -1,269 +0,0 @@ -// This file is adapted from code in the net/http/httputil -// package of the Go standard library, which is by the -// Go Authors, and bears this copyright and license info: -// -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. -// -// This file has been modified from the standard lib to -// meet the needs of the application. - -package proxy - -import ( - "crypto/tls" - "io" - "net" - "net/http" - "net/url" - "strings" - "sync" - "time" -) - -// onExitFlushLoop is a callback set by tests to detect the state of the -// flushLoop() goroutine. -var onExitFlushLoop func() - -// ReverseProxy is an HTTP Handler that takes an incoming request and -// sends it to another server, proxying the response back to the -// client. -type ReverseProxy struct { - // Director must be a function which modifies - // the request into a new request to be sent - // using Transport. Its response is then copied - // back to the original client unmodified. - Director func(*http.Request) - - // The transport used to perform proxy requests. - // If nil, http.DefaultTransport is used. - Transport http.RoundTripper - - // FlushInterval specifies the flush interval - // to flush to the client while copying the - // response body. - // If zero, no periodic flushing is done. - FlushInterval time.Duration -} - -func singleJoiningSlash(a, b string) string { - aslash := strings.HasSuffix(a, "/") - bslash := strings.HasPrefix(b, "/") - switch { - case aslash && bslash: - return a + b[1:] - case !aslash && !bslash: - return a + "/" + b - } - return a + b -} - -// Though the relevant directive prefix is just "unix:", url.Parse -// will - assuming the regular URL scheme - add additional slashes -// as if "unix" was a request protocol. -// What we need is just the path, so if "unix:/var/run/www.socket" -// was the proxy directive, the parsed hostName would be -// "unix:///var/run/www.socket", hence the ambiguous trimming. -func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) { - return func(network, addr string) (conn net.Conn, err error) { - return net.Dial("unix", hostName[len("unix://"):]) - } -} - -// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites -// URLs to the scheme, host, and base path provided in target. If the -// target's path is "/base" and the incoming request was for "/dir", -// the target request will be for /base/dir. -// Without logic: target's path is "/", incoming is "/api/messages", -// without is "/api", then the target request will be for /messages. -func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy { - targetQuery := target.RawQuery - director := func(req *http.Request) { - if target.Scheme == "unix" { - // to make Dial work with unix URL, - // scheme and host have to be faked - req.URL.Scheme = "http" - req.URL.Host = "socket" - } else { - req.URL.Scheme = target.Scheme - req.URL.Host = target.Host - } - req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) - if targetQuery == "" || req.URL.RawQuery == "" { - req.URL.RawQuery = targetQuery + req.URL.RawQuery - } else { - req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery - } - // Trims the path of the socket from the URL path. - // This is done because req.URL passed to your proxied service - // will have the full path of the socket file prefixed to it. - // Calling /test on a server that proxies requests to - // unix:/var/run/www.socket will thus set the requested path - // to /var/run/www.socket/test, rendering paths useless. - if target.Scheme == "unix" { - // See comment on socketDial for the trim - socketPrefix := target.String()[len("unix://"):] - req.URL.Path = strings.TrimPrefix(req.URL.Path, socketPrefix) - } - // We are then safe to remove the `without` prefix. - if without != "" { - req.URL.Path = strings.TrimPrefix(req.URL.Path, without) - } - } - rp := &ReverseProxy{Director: director, FlushInterval: 250 * time.Millisecond} // flushing good for streaming & server-sent events - if target.Scheme == "unix" { - rp.Transport = &http.Transport{ - Dial: socketDial(target.String()), - } - } - return rp -} - -func copyHeader(dst, src http.Header) { - for k, vv := range src { - for _, v := range vv { - dst.Add(k, v) - } - } -} - -// Hop-by-hop headers. These are removed when sent to the backend. -// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html -var hopHeaders = []string{ - "Connection", - "Keep-Alive", - "Proxy-Authenticate", - "Proxy-Authorization", - "Te", // canonicalized version of "TE" - "Trailers", - "Transfer-Encoding", - "Upgrade", -} - -// InsecureTransport is used to facilitate HTTPS proxying -// when it is OK for upstream to be using a bad certificate, -// since this transport skips verification. -var InsecureTransport http.RoundTripper = &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, - TLSHandshakeTimeout: 10 * time.Second, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, -} - -type respUpdateFn func(resp *http.Response) - -func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error { - transport := p.Transport - if transport == nil { - transport = http.DefaultTransport - } - - p.Director(outreq) - outreq.Proto = "HTTP/1.1" - outreq.ProtoMajor = 1 - outreq.ProtoMinor = 1 - outreq.Close = false - - res, err := transport.RoundTrip(outreq) - if err != nil { - return err - } else if respUpdateFn != nil { - respUpdateFn(res) - } - - if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" { - res.Body.Close() - hj, ok := rw.(http.Hijacker) - if !ok { - return nil - } - - conn, _, err := hj.Hijack() - if err != nil { - return err - } - defer conn.Close() - - backendConn, err := net.Dial("tcp", outreq.URL.Host) - if err != nil { - return err - } - defer backendConn.Close() - - outreq.Write(backendConn) - - go func() { - io.Copy(backendConn, conn) // write tcp stream to backend. - }() - io.Copy(conn, backendConn) // read tcp stream from backend. - } else { - defer res.Body.Close() - for _, h := range hopHeaders { - res.Header.Del(h) - } - copyHeader(rw.Header(), res.Header) - rw.WriteHeader(res.StatusCode) - p.copyResponse(rw, res.Body) - } - - return nil -} - -func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { - if p.FlushInterval != 0 { - if wf, ok := dst.(writeFlusher); ok { - mlw := &maxLatencyWriter{ - dst: wf, - latency: p.FlushInterval, - done: make(chan bool), - } - go mlw.flushLoop() - defer mlw.stop() - dst = mlw - } - } - io.Copy(dst, src) -} - -type writeFlusher interface { - io.Writer - http.Flusher -} - -type maxLatencyWriter struct { - dst writeFlusher - latency time.Duration - - lk sync.Mutex // protects Write + Flush - done chan bool -} - -func (m *maxLatencyWriter) Write(p []byte) (int, error) { - m.lk.Lock() - defer m.lk.Unlock() - return m.dst.Write(p) -} - -func (m *maxLatencyWriter) flushLoop() { - t := time.NewTicker(m.latency) - defer t.Stop() - for { - select { - case <-m.done: - if onExitFlushLoop != nil { - onExitFlushLoop() - } - return - case <-t.C: - m.lk.Lock() - m.dst.Flush() - m.lk.Unlock() - } - } -} - -func (m *maxLatencyWriter) stop() { m.done <- true } diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go deleted file mode 100644 index a1d9fcfce..000000000 --- a/middleware/proxy/upstream.go +++ /dev/null @@ -1,345 +0,0 @@ -package proxy - -import ( - "fmt" - "io" - "io/ioutil" - "net/http" - "net/url" - "path" - "strconv" - "strings" - "time" - - "github.com/mholt/caddy/caddy/parse" - "github.com/mholt/caddy/middleware" -) - -var ( - supportedPolicies = make(map[string]func() Policy) -) - -type staticUpstream struct { - from string - upstreamHeaders http.Header - downstreamHeaders http.Header - Hosts HostPool - Policy Policy - insecureSkipVerify bool - - FailTimeout time.Duration - MaxFails int32 - MaxConns int64 - HealthCheck struct { - Path string - Interval time.Duration - } - WithoutPathPrefix string - IgnoredSubPaths []string -} - -// NewStaticUpstreams parses the configuration input and sets up -// static upstreams for the proxy middleware. -func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { - var upstreams []Upstream - for c.Next() { - upstream := &staticUpstream{ - from: "", - upstreamHeaders: make(http.Header), - downstreamHeaders: make(http.Header), - Hosts: nil, - Policy: &Random{}, - FailTimeout: 10 * time.Second, - MaxFails: 1, - MaxConns: 0, - } - - if !c.Args(&upstream.from) { - return upstreams, c.ArgErr() - } - - var to []string - for _, t := range c.RemainingArgs() { - parsed, err := parseUpstream(t) - if err != nil { - return upstreams, err - } - to = append(to, parsed...) - } - - for c.NextBlock() { - switch c.Val() { - case "upstream": - if !c.NextArg() { - return upstreams, c.ArgErr() - } - parsed, err := parseUpstream(c.Val()) - if err != nil { - return upstreams, err - } - to = append(to, parsed...) - default: - if err := parseBlock(&c, upstream); err != nil { - return upstreams, err - } - } - } - - if len(to) == 0 { - return upstreams, c.ArgErr() - } - - upstream.Hosts = make([]*UpstreamHost, len(to)) - for i, host := range to { - uh, err := upstream.NewHost(host) - if err != nil { - return upstreams, err - } - upstream.Hosts[i] = uh - } - - if upstream.HealthCheck.Path != "" { - go upstream.HealthCheckWorker(nil) - } - upstreams = append(upstreams, upstream) - } - return upstreams, nil -} - -// RegisterPolicy adds a custom policy to the proxy. -func RegisterPolicy(name string, policy func() Policy) { - supportedPolicies[name] = policy -} - -func (u *staticUpstream) From() string { - return u.from -} - -func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { - if !strings.HasPrefix(host, "http") && - !strings.HasPrefix(host, "unix:") { - host = "http://" + host - } - uh := &UpstreamHost{ - Name: host, - Conns: 0, - Fails: 0, - FailTimeout: u.FailTimeout, - Unhealthy: false, - UpstreamHeaders: u.upstreamHeaders, - DownstreamHeaders: u.downstreamHeaders, - CheckDown: func(u *staticUpstream) UpstreamHostDownFunc { - return func(uh *UpstreamHost) bool { - if uh.Unhealthy { - return true - } - if uh.Fails >= u.MaxFails && - u.MaxFails != 0 { - return true - } - return false - } - }(u), - WithoutPathPrefix: u.WithoutPathPrefix, - MaxConns: u.MaxConns, - } - - baseURL, err := url.Parse(uh.Name) - if err != nil { - return nil, err - } - - uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix) - if u.insecureSkipVerify { - uh.ReverseProxy.Transport = InsecureTransport - } - return uh, nil -} - -func parseUpstream(u string) ([]string, error) { - if !strings.HasPrefix(u, "unix:") { - colonIdx := strings.LastIndex(u, ":") - protoIdx := strings.Index(u, "://") - - if colonIdx != -1 && colonIdx != protoIdx { - us := u[:colonIdx] - ports := u[len(us)+1:] - if separators := strings.Count(ports, "-"); separators > 1 { - return nil, fmt.Errorf("port range [%s] is invalid", ports) - } else if separators == 1 { - portsStr := strings.Split(ports, "-") - pIni, err := strconv.Atoi(portsStr[0]) - if err != nil { - return nil, err - } - - pEnd, err := strconv.Atoi(portsStr[1]) - if err != nil { - return nil, err - } - - if pEnd <= pIni { - return nil, fmt.Errorf("port range [%s] is invalid", ports) - } - - hosts := []string{} - for p := pIni; p <= pEnd; p++ { - hosts = append(hosts, fmt.Sprintf("%s:%d", us, p)) - } - return hosts, nil - } - } - } - - return []string{u}, nil - -} - -func parseBlock(c *parse.Dispenser, u *staticUpstream) error { - switch c.Val() { - case "policy": - if !c.NextArg() { - return c.ArgErr() - } - policyCreateFunc, ok := supportedPolicies[c.Val()] - if !ok { - return c.ArgErr() - } - u.Policy = policyCreateFunc() - case "fail_timeout": - if !c.NextArg() { - return c.ArgErr() - } - dur, err := time.ParseDuration(c.Val()) - if err != nil { - return err - } - u.FailTimeout = dur - case "max_fails": - if !c.NextArg() { - return c.ArgErr() - } - n, err := strconv.Atoi(c.Val()) - if err != nil { - return err - } - u.MaxFails = int32(n) - case "max_conns": - if !c.NextArg() { - return c.ArgErr() - } - n, err := strconv.ParseInt(c.Val(), 10, 64) - if err != nil { - return err - } - u.MaxConns = n - case "health_check": - if !c.NextArg() { - return c.ArgErr() - } - u.HealthCheck.Path = c.Val() - u.HealthCheck.Interval = 30 * time.Second - if c.NextArg() { - dur, err := time.ParseDuration(c.Val()) - if err != nil { - return err - } - u.HealthCheck.Interval = dur - } - case "header_upstream": - fallthrough - case "proxy_header": - var header, value string - if !c.Args(&header, &value) { - return c.ArgErr() - } - u.upstreamHeaders.Add(header, value) - case "header_downstream": - var header, value string - if !c.Args(&header, &value) { - return c.ArgErr() - } - u.downstreamHeaders.Add(header, value) - case "websocket": - u.upstreamHeaders.Add("Connection", "{>Connection}") - u.upstreamHeaders.Add("Upgrade", "{>Upgrade}") - case "without": - if !c.NextArg() { - return c.ArgErr() - } - u.WithoutPathPrefix = c.Val() - case "except": - ignoredPaths := c.RemainingArgs() - if len(ignoredPaths) == 0 { - return c.ArgErr() - } - u.IgnoredSubPaths = ignoredPaths - case "insecure_skip_verify": - u.insecureSkipVerify = true - default: - return c.Errf("unknown property '%s'", c.Val()) - } - return nil -} - -func (u *staticUpstream) healthCheck() { - for _, host := range u.Hosts { - hostURL := host.Name + u.HealthCheck.Path - if r, err := http.Get(hostURL); err == nil { - io.Copy(ioutil.Discard, r.Body) - r.Body.Close() - host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400 - } else { - host.Unhealthy = true - } - } -} - -func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) { - ticker := time.NewTicker(u.HealthCheck.Interval) - u.healthCheck() - for { - select { - case <-ticker.C: - u.healthCheck() - case <-stop: - // TODO: the library should provide a stop channel and global - // waitgroup to allow goroutines started by plugins a chance - // to clean themselves up. - } - } -} - -func (u *staticUpstream) Select() *UpstreamHost { - pool := u.Hosts - if len(pool) == 1 { - if !pool[0].Available() { - return nil - } - return pool[0] - } - allUnavailable := true - for _, host := range pool { - if host.Available() { - allUnavailable = false - break - } - } - if allUnavailable { - return nil - } - - if u.Policy == nil { - return (&Random{}).Select(pool) - } - return u.Policy.Select(pool) -} - -func (u *staticUpstream) AllowedPath(requestPath string) bool { - for _, ignoredSubPath := range u.IgnoredSubPaths { - if middleware.Path(path.Clean(requestPath)).Matches(path.Join(u.From(), ignoredSubPath)) { - return false - } - } - return true -} diff --git a/middleware/proxy/upstream_test.go b/middleware/proxy/upstream_test.go deleted file mode 100644 index 9d38b785f..000000000 --- a/middleware/proxy/upstream_test.go +++ /dev/null @@ -1,135 +0,0 @@ -package proxy - -import ( - "testing" - "time" -) - -func TestNewHost(t *testing.T) { - upstream := &staticUpstream{ - FailTimeout: 10 * time.Second, - MaxConns: 1, - MaxFails: 1, - } - - uh, err := upstream.NewHost("example.com") - if err != nil { - t.Error("Expected no error") - } - if uh.Name != "http://example.com" { - t.Error("Expected default schema to be added to Name.") - } - if uh.FailTimeout != upstream.FailTimeout { - t.Error("Expected default FailTimeout to be set.") - } - if uh.MaxConns != upstream.MaxConns { - t.Error("Expected default MaxConns to be set.") - } - if uh.CheckDown == nil { - t.Error("Expected default CheckDown to be set.") - } - if uh.CheckDown(uh) { - t.Error("Expected new host not to be down.") - } - // mark Unhealthy - uh.Unhealthy = true - if !uh.CheckDown(uh) { - t.Error("Expected unhealthy host to be down.") - } - // mark with Fails - uh.Unhealthy = false - uh.Fails = 1 - if !uh.CheckDown(uh) { - t.Error("Expected failed host to be down.") - } -} - -func TestHealthCheck(t *testing.T) { - upstream := &staticUpstream{ - from: "", - Hosts: testPool(), - Policy: &Random{}, - FailTimeout: 10 * time.Second, - MaxFails: 1, - } - upstream.healthCheck() - if upstream.Hosts[0].Down() { - t.Error("Expected first host in testpool to not fail healthcheck.") - } - if !upstream.Hosts[1].Down() { - t.Error("Expected second host in testpool to fail healthcheck.") - } -} - -func TestSelect(t *testing.T) { - upstream := &staticUpstream{ - from: "", - Hosts: testPool()[:3], - Policy: &Random{}, - FailTimeout: 10 * time.Second, - MaxFails: 1, - } - upstream.Hosts[0].Unhealthy = true - upstream.Hosts[1].Unhealthy = true - upstream.Hosts[2].Unhealthy = true - if h := upstream.Select(); h != nil { - t.Error("Expected select to return nil as all host are down") - } - upstream.Hosts[2].Unhealthy = false - if h := upstream.Select(); h == nil { - t.Error("Expected select to not return nil") - } - upstream.Hosts[0].Conns = 1 - upstream.Hosts[0].MaxConns = 1 - upstream.Hosts[1].Conns = 1 - upstream.Hosts[1].MaxConns = 1 - upstream.Hosts[2].Conns = 1 - upstream.Hosts[2].MaxConns = 1 - if h := upstream.Select(); h != nil { - t.Error("Expected select to return nil as all hosts are full") - } - upstream.Hosts[2].Conns = 0 - if h := upstream.Select(); h == nil { - t.Error("Expected select to not return nil") - } -} - -func TestRegisterPolicy(t *testing.T) { - name := "custom" - customPolicy := &customPolicy{} - RegisterPolicy(name, func() Policy { return customPolicy }) - if _, ok := supportedPolicies[name]; !ok { - t.Error("Expected supportedPolicies to have a custom policy.") - } - -} - -func TestAllowedPaths(t *testing.T) { - upstream := &staticUpstream{ - from: "/proxy", - IgnoredSubPaths: []string{"/download", "/static"}, - } - tests := []struct { - url string - expected bool - }{ - {"/proxy", true}, - {"/proxy/dl", true}, - {"/proxy/download", false}, - {"/proxy/download/static", false}, - {"/proxy/static", false}, - {"/proxy/static/download", false}, - {"/proxy/something/download", true}, - {"/proxy/something/static", true}, - {"/proxy//static", false}, - {"/proxy//static//download", false}, - {"/proxy//download", false}, - } - - for i, test := range tests { - allowed := upstream.AllowedPath(test.url) - if test.expected != allowed { - t.Errorf("Test %d: expected %v found %v", i+1, test.expected, allowed) - } - } -} diff --git a/middleware/redirect/redirect.go b/middleware/redirect/redirect.go deleted file mode 100644 index 04fb1c63a..000000000 --- a/middleware/redirect/redirect.go +++ /dev/null @@ -1,57 +0,0 @@ -// Package redirect is middleware for redirecting certain requests -// to other locations. -package redirect - -import ( - "fmt" - "html" - "net/http" - - "github.com/mholt/caddy/middleware" -) - -// Redirect is middleware to respond with HTTP redirects -type Redirect struct { - Next middleware.Handler - Rules []Rule -} - -// ServeHTTP implements the middleware.Handler interface. -func (rd Redirect) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - for _, rule := range rd.Rules { - if (rule.FromPath == "/" || r.URL.Path == rule.FromPath) && schemeMatches(rule, r) { - to := middleware.NewReplacer(r, nil, "").Replace(rule.To) - if rule.Meta { - safeTo := html.EscapeString(to) - fmt.Fprintf(w, metaRedir, safeTo, safeTo) - } else { - http.Redirect(w, r, to, rule.Code) - } - return 0, nil - } - } - return rd.Next.ServeHTTP(w, r) -} - -func schemeMatches(rule Rule, req *http.Request) bool { - return (rule.FromScheme == "https" && req.TLS != nil) || - (rule.FromScheme != "https" && req.TLS == nil) -} - -// Rule describes an HTTP redirect rule. -type Rule struct { - FromScheme, FromPath, To string - Code int - Meta bool -} - -// Script tag comes first since that will better imitate a redirect in the browser's -// history, but the meta tag is a fallback for most non-JS clients. -const metaRedir = ` - - - - - - Redirecting... -` diff --git a/middleware/redirect/redirect_test.go b/middleware/redirect/redirect_test.go deleted file mode 100644 index 3107921af..000000000 --- a/middleware/redirect/redirect_test.go +++ /dev/null @@ -1,154 +0,0 @@ -package redirect - -import ( - "bytes" - "crypto/tls" - "io/ioutil" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/mholt/caddy/middleware" -) - -func TestRedirect(t *testing.T) { - for i, test := range []struct { - from string - expectedLocation string - expectedCode int - }{ - {"http://localhost/from", "/to", http.StatusMovedPermanently}, - {"http://localhost/a", "/b", http.StatusTemporaryRedirect}, - {"http://localhost/aa", "", http.StatusOK}, - {"http://localhost/", "", http.StatusOK}, - {"http://localhost/a?foo=bar", "/b", http.StatusTemporaryRedirect}, - {"http://localhost/asdf?foo=bar", "", http.StatusOK}, - {"http://localhost/foo#bar", "", http.StatusOK}, - {"http://localhost/a#foo", "/b", http.StatusTemporaryRedirect}, - - // The scheme checks that were added to this package don't actually - // help with redirects because of Caddy's design: a redirect middleware - // for http will always be different than the redirect middleware for - // https because they have to be on different listeners. These tests - // just go to show extra bulletproofing, I guess. - {"http://localhost/scheme", "https://localhost/scheme", http.StatusMovedPermanently}, - {"https://localhost/scheme", "", http.StatusOK}, - {"https://localhost/scheme2", "http://localhost/scheme2", http.StatusMovedPermanently}, - {"http://localhost/scheme2", "", http.StatusOK}, - {"http://localhost/scheme3", "https://localhost/scheme3", http.StatusMovedPermanently}, - {"https://localhost/scheme3", "", http.StatusOK}, - } { - var nextCalled bool - - re := Redirect{ - Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - nextCalled = true - return 0, nil - }), - Rules: []Rule{ - {FromPath: "/from", To: "/to", Code: http.StatusMovedPermanently}, - {FromPath: "/a", To: "/b", Code: http.StatusTemporaryRedirect}, - - // These http and https schemes would never actually be mixed in the same - // redirect rule with Caddy because http and https schemes have different listeners, - // so they don't share a redirect rule. So although these tests prove something - // impossible with Caddy, it's extra bulletproofing at very little cost. - {FromScheme: "http", FromPath: "/scheme", To: "https://localhost/scheme", Code: http.StatusMovedPermanently}, - {FromScheme: "https", FromPath: "/scheme2", To: "http://localhost/scheme2", Code: http.StatusMovedPermanently}, - {FromScheme: "", FromPath: "/scheme3", To: "https://localhost/scheme3", Code: http.StatusMovedPermanently}, - }, - } - - req, err := http.NewRequest("GET", test.from, nil) - if err != nil { - t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) - } - if strings.HasPrefix(test.from, "https://") { - req.TLS = new(tls.ConnectionState) // faux HTTPS - } - - rec := httptest.NewRecorder() - re.ServeHTTP(rec, req) - - if rec.Header().Get("Location") != test.expectedLocation { - t.Errorf("Test %d: Expected Location header to be %q but was %q", - i, test.expectedLocation, rec.Header().Get("Location")) - } - - if rec.Code != test.expectedCode { - t.Errorf("Test %d: Expected status code to be %d but was %d", - i, test.expectedCode, rec.Code) - } - - if nextCalled && test.expectedLocation != "" { - t.Errorf("Test %d: Next handler was unexpectedly called", i) - } - } -} - -func TestParametersRedirect(t *testing.T) { - re := Redirect{ - Rules: []Rule{ - {FromPath: "/", Meta: false, To: "http://example.com{uri}"}, - }, - } - - req, err := http.NewRequest("GET", "/a?b=c", nil) - if err != nil { - t.Fatalf("Test: Could not create HTTP request: %v", err) - } - - rec := httptest.NewRecorder() - re.ServeHTTP(rec, req) - - if rec.Header().Get("Location") != "http://example.com/a?b=c" { - t.Fatalf("Test: expected location header %q but was %q", "http://example.com/a?b=c", rec.Header().Get("Location")) - } - - re = Redirect{ - Rules: []Rule{ - {FromPath: "/", Meta: false, To: "http://example.com/a{path}?b=c&{query}"}, - }, - } - - req, err = http.NewRequest("GET", "/d?e=f", nil) - if err != nil { - t.Fatalf("Test: Could not create HTTP request: %v", err) - } - - re.ServeHTTP(rec, req) - - if "http://example.com/a/d?b=c&e=f" != rec.Header().Get("Location") { - t.Fatalf("Test: expected location header %q but was %q", "http://example.com/a/d?b=c&e=f", rec.Header().Get("Location")) - } -} - -func TestMetaRedirect(t *testing.T) { - re := Redirect{ - Rules: []Rule{ - {FromPath: "/whatever", Meta: true, To: "/something"}, - {FromPath: "/", Meta: true, To: "https://example.com/"}, - }, - } - - for i, test := range re.Rules { - req, err := http.NewRequest("GET", test.FromPath, nil) - if err != nil { - t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) - } - - rec := httptest.NewRecorder() - re.ServeHTTP(rec, req) - - body, err := ioutil.ReadAll(rec.Body) - if err != nil { - t.Fatalf("Test %d: Could not read HTTP response body: %v", i, err) - } - expectedSnippet := `` - if !bytes.Contains(body, []byte(expectedSnippet)) { - t.Errorf("Test %d: Expected Response Body to contain %q but was %q", - i, expectedSnippet, body) - } - } -} diff --git a/middleware/rewrite/condition.go b/middleware/rewrite/condition.go deleted file mode 100644 index 1431afc9c..000000000 --- a/middleware/rewrite/condition.go +++ /dev/null @@ -1,130 +0,0 @@ -package rewrite - -import ( - "fmt" - "net/http" - "regexp" - "strings" - - "github.com/mholt/caddy/middleware" -) - -// Operators -const ( - Is = "is" - Not = "not" - Has = "has" - NotHas = "not_has" - StartsWith = "starts_with" - EndsWith = "ends_with" - Match = "match" - NotMatch = "not_match" -) - -func operatorError(operator string) error { - return fmt.Errorf("Invalid operator %v", operator) -} - -func newReplacer(r *http.Request) middleware.Replacer { - return middleware.NewReplacer(r, nil, "") -} - -// condition is a rewrite condition. -type condition func(string, string) bool - -var conditions = map[string]condition{ - Is: isFunc, - Not: notFunc, - Has: hasFunc, - NotHas: notHasFunc, - StartsWith: startsWithFunc, - EndsWith: endsWithFunc, - Match: matchFunc, - NotMatch: notMatchFunc, -} - -// isFunc is condition for Is operator. -// It checks for equality. -func isFunc(a, b string) bool { - return a == b -} - -// notFunc is condition for Not operator. -// It checks for inequality. -func notFunc(a, b string) bool { - return a != b -} - -// hasFunc is condition for Has operator. -// It checks if b is a substring of a. -func hasFunc(a, b string) bool { - return strings.Contains(a, b) -} - -// notHasFunc is condition for NotHas operator. -// It checks if b is not a substring of a. -func notHasFunc(a, b string) bool { - return !strings.Contains(a, b) -} - -// startsWithFunc is condition for StartsWith operator. -// It checks if b is a prefix of a. -func startsWithFunc(a, b string) bool { - return strings.HasPrefix(a, b) -} - -// endsWithFunc is condition for EndsWith operator. -// It checks if b is a suffix of a. -func endsWithFunc(a, b string) bool { - return strings.HasSuffix(a, b) -} - -// matchFunc is condition for Match operator. -// It does regexp matching of a against pattern in b -// and returns if they match. -func matchFunc(a, b string) bool { - matched, _ := regexp.MatchString(b, a) - return matched -} - -// notMatchFunc is condition for NotMatch operator. -// It does regexp matching of a against pattern in b -// and returns if they do not match. -func notMatchFunc(a, b string) bool { - matched, _ := regexp.MatchString(b, a) - return !matched -} - -// If is statement for a rewrite condition. -type If struct { - A string - Operator string - B string -} - -// True returns true if the condition is true and false otherwise. -// If r is not nil, it replaces placeholders before comparison. -func (i If) True(r *http.Request) bool { - if c, ok := conditions[i.Operator]; ok { - a, b := i.A, i.B - if r != nil { - replacer := newReplacer(r) - a = replacer.Replace(i.A) - b = replacer.Replace(i.B) - } - return c(a, b) - } - return false -} - -// NewIf creates a new If condition. -func NewIf(a, operator, b string) (If, error) { - if _, ok := conditions[operator]; !ok { - return If{}, operatorError(operator) - } - return If{ - A: a, - Operator: operator, - B: b, - }, nil -} diff --git a/middleware/rewrite/condition_test.go b/middleware/rewrite/condition_test.go deleted file mode 100644 index 3c3b6053a..000000000 --- a/middleware/rewrite/condition_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package rewrite - -import ( - "net/http" - "strings" - "testing" -) - -func TestConditions(t *testing.T) { - tests := []struct { - condition string - isTrue bool - }{ - {"a is b", false}, - {"a is a", true}, - {"a not b", true}, - {"a not a", false}, - {"a has a", true}, - {"a has b", false}, - {"ba has b", true}, - {"bab has b", true}, - {"bab has bb", false}, - {"a not_has a", false}, - {"a not_has b", true}, - {"ba not_has b", false}, - {"bab not_has b", false}, - {"bab not_has bb", true}, - {"bab starts_with bb", false}, - {"bab starts_with ba", true}, - {"bab starts_with bab", true}, - {"bab ends_with bb", false}, - {"bab ends_with bab", true}, - {"bab ends_with ab", true}, - {"a match *", false}, - {"a match a", true}, - {"a match .*", true}, - {"a match a.*", true}, - {"a match b.*", false}, - {"ba match b.*", true}, - {"ba match b[a-z]", true}, - {"b0 match b[a-z]", false}, - {"b0a match b[a-z]", false}, - {"b0a match b[a-z]+", false}, - {"b0a match b[a-z0-9]+", true}, - {"a not_match *", true}, - {"a not_match a", false}, - {"a not_match .*", false}, - {"a not_match a.*", false}, - {"a not_match b.*", true}, - {"ba not_match b.*", false}, - {"ba not_match b[a-z]", false}, - {"b0 not_match b[a-z]", true}, - {"b0a not_match b[a-z]", true}, - {"b0a not_match b[a-z]+", true}, - {"b0a not_match b[a-z0-9]+", false}, - } - - for i, test := range tests { - str := strings.Fields(test.condition) - ifCond, err := NewIf(str[0], str[1], str[2]) - if err != nil { - t.Error(err) - } - isTrue := ifCond.True(nil) - if isTrue != test.isTrue { - t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue) - } - } - - invalidOperators := []string{"ss", "and", "if"} - for _, op := range invalidOperators { - _, err := NewIf("a", op, "b") - if err == nil { - t.Errorf("Invalid operator %v used, expected error.", op) - } - } - - replaceTests := []struct { - url string - condition string - isTrue bool - }{ - {"/home", "{uri} match /home", true}, - {"/hom", "{uri} match /home", false}, - {"/hom", "{uri} starts_with /home", false}, - {"/hom", "{uri} starts_with /h", true}, - {"/home/.hiddenfile", `{uri} match \/\.(.*)`, true}, - {"/home/.hiddendir/afile", `{uri} match \/\.(.*)`, true}, - } - - for i, test := range replaceTests { - r, err := http.NewRequest("GET", test.url, nil) - if err != nil { - t.Error(err) - } - str := strings.Fields(test.condition) - ifCond, err := NewIf(str[0], str[1], str[2]) - if err != nil { - t.Error(err) - } - isTrue := ifCond.True(r) - if isTrue != test.isTrue { - t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue) - } - } -} diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go deleted file mode 100644 index 1c2e26006..000000000 --- a/middleware/rewrite/rewrite.go +++ /dev/null @@ -1,236 +0,0 @@ -// Package rewrite is middleware for rewriting requests internally to -// a different path. -package rewrite - -import ( - "fmt" - "net/http" - "net/url" - "path" - "path/filepath" - "regexp" - "strings" - - "github.com/mholt/caddy/middleware" -) - -// Result is the result of a rewrite -type Result int - -const ( - // RewriteIgnored is returned when rewrite is not done on request. - RewriteIgnored Result = iota - // RewriteDone is returned when rewrite is done on request. - RewriteDone - // RewriteStatus is returned when rewrite is not needed and status code should be set - // for the request. - RewriteStatus -) - -// Rewrite is middleware to rewrite request locations internally before being handled. -type Rewrite struct { - Next middleware.Handler - FileSys http.FileSystem - Rules []Rule -} - -// ServeHTTP implements the middleware.Handler interface. -func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { -outer: - for _, rule := range rw.Rules { - switch result := rule.Rewrite(rw.FileSys, r); result { - case RewriteDone: - break outer - case RewriteIgnored: - break - case RewriteStatus: - // only valid for complex rules. - if cRule, ok := rule.(*ComplexRule); ok && cRule.Status != 0 { - return cRule.Status, nil - } - } - } - return rw.Next.ServeHTTP(w, r) -} - -// Rule describes an internal location rewrite rule. -type Rule interface { - // Rewrite rewrites the internal location of the current request. - Rewrite(http.FileSystem, *http.Request) Result -} - -// SimpleRule is a simple rewrite rule. -type SimpleRule struct { - From, To string -} - -// NewSimpleRule creates a new Simple Rule -func NewSimpleRule(from, to string) SimpleRule { - return SimpleRule{from, to} -} - -// Rewrite rewrites the internal location of the current request. -func (s SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) Result { - if s.From == r.URL.Path { - // take note of this rewrite for internal use by fastcgi - // all we need is the URI, not full URL - r.Header.Set(headerFieldName, r.URL.RequestURI()) - - // attempt rewrite - return To(fs, r, s.To, newReplacer(r)) - } - return RewriteIgnored -} - -// ComplexRule is a rewrite rule based on a regular expression -type ComplexRule struct { - // Path base. Request to this path and subpaths will be rewritten - Base string - - // Path to rewrite to - To string - - // If set, neither performs rewrite nor proceeds - // with request. Only returns code. - Status int - - // Extensions to filter by - Exts []string - - // Rewrite conditions - Ifs []If - - *regexp.Regexp -} - -// NewComplexRule creates a new RegexpRule. It returns an error if regexp -// pattern (pattern) or extensions (ext) are invalid. -func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) { - // validate regexp if present - var r *regexp.Regexp - if pattern != "" { - var err error - r, err = regexp.Compile(pattern) - if err != nil { - return nil, err - } - } - - // validate extensions if present - for _, v := range ext { - if len(v) < 2 || (len(v) < 3 && v[0] == '!') { - // check if no extension is specified - if v != "/" && v != "!/" { - return nil, fmt.Errorf("invalid extension %v", v) - } - } - } - - return &ComplexRule{ - Base: base, - To: to, - Status: status, - Exts: ext, - Ifs: ifs, - Regexp: r, - }, nil -} - -// Rewrite rewrites the internal location of the current request. -func (r *ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) (re Result) { - rPath := req.URL.Path - replacer := newReplacer(req) - - // validate base - if !middleware.Path(rPath).Matches(r.Base) { - return - } - - // validate extensions - if !r.matchExt(rPath) { - return - } - - // validate regexp if present - if r.Regexp != nil { - // include trailing slash in regexp if present - start := len(r.Base) - if strings.HasSuffix(r.Base, "/") { - start-- - } - - matches := r.FindStringSubmatch(rPath[start:]) - switch len(matches) { - case 0: - // no match - return - default: - // set regexp match variables {1}, {2} ... - - // url escaped values of ? and #. - q, f := url.QueryEscape("?"), url.QueryEscape("#") - - for i := 1; i < len(matches); i++ { - // Special case of unescaped # and ? by stdlib regexp. - // Reverse the unescape. - if strings.ContainsAny(matches[i], "?#") { - matches[i] = strings.NewReplacer("?", q, "#", f).Replace(matches[i]) - } - - replacer.Set(fmt.Sprint(i), matches[i]) - } - } - } - - // validate rewrite conditions - for _, i := range r.Ifs { - if !i.True(req) { - return - } - } - - // if status is present, stop rewrite and return it. - if r.Status != 0 { - return RewriteStatus - } - - // attempt rewrite - return To(fs, req, r.To, replacer) -} - -// matchExt matches rPath against registered file extensions. -// Returns true if a match is found and false otherwise. -func (r *ComplexRule) matchExt(rPath string) bool { - f := filepath.Base(rPath) - ext := path.Ext(f) - if ext == "" { - ext = "/" - } - - mustUse := false - for _, v := range r.Exts { - use := true - if v[0] == '!' { - use = false - v = v[1:] - } - - if use { - mustUse = true - } - - if ext == v { - return use - } - } - - if mustUse { - return false - } - return true -} - -// When a rewrite is performed, this header is added to the request -// and is for internal use only, specifically the fastcgi middleware. -// It contains the original request URI before the rewrite. -const headerFieldName = "Caddy-Rewrite-Original-URI" diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go deleted file mode 100644 index 2baf91219..000000000 --- a/middleware/rewrite/rewrite_test.go +++ /dev/null @@ -1,163 +0,0 @@ -package rewrite - -import ( - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/mholt/caddy/middleware" -) - -func TestRewrite(t *testing.T) { - rw := Rewrite{ - Next: middleware.HandlerFunc(urlPrinter), - Rules: []Rule{ - NewSimpleRule("/from", "/to"), - NewSimpleRule("/a", "/b"), - NewSimpleRule("/b", "/b{uri}"), - }, - FileSys: http.Dir("."), - } - - regexps := [][]string{ - {"/reg/", ".*", "/to", ""}, - {"/r/", "[a-z]+", "/toaz", "!.html|"}, - {"/url/", "a([a-z0-9]*)s([A-Z]{2})", "/to/{path}", ""}, - {"/ab/", "ab", "/ab?{query}", ".txt|"}, - {"/ab/", "ab", "/ab?type=html&{query}", ".html|"}, - {"/abc/", "ab", "/abc/{file}", ".html|"}, - {"/abcd/", "ab", "/a/{dir}/{file}", ".html|"}, - {"/abcde/", "ab", "/a#{fragment}", ".html|"}, - {"/ab/", `.*\.jpg`, "/ajpg", ""}, - {"/reggrp", `/ad/([0-9]+)([a-z]*)`, "/a{1}/{2}", ""}, - {"/reg2grp", `(.*)`, "/{1}", ""}, - {"/reg3grp", `(.*)/(.*)/(.*)`, "/{1}{2}{3}", ""}, - {"/hashtest", "(.*)", "/{1}", ""}, - } - - for _, regexpRule := range regexps { - var ext []string - if s := strings.Split(regexpRule[3], "|"); len(s) > 1 { - ext = s[:len(s)-1] - } - rule, err := NewComplexRule(regexpRule[0], regexpRule[1], regexpRule[2], 0, ext, nil) - if err != nil { - t.Fatal(err) - } - rw.Rules = append(rw.Rules, rule) - } - - tests := []struct { - from string - expectedTo string - }{ - {"/from", "/to"}, - {"/a", "/b"}, - {"/b", "/b/b"}, - {"/aa", "/aa"}, - {"/", "/"}, - {"/a?foo=bar", "/b?foo=bar"}, - {"/asdf?foo=bar", "/asdf?foo=bar"}, - {"/foo#bar", "/foo#bar"}, - {"/a#foo", "/b#foo"}, - {"/reg/foo", "/to"}, - {"/re", "/re"}, - {"/r/", "/r/"}, - {"/r/123", "/r/123"}, - {"/r/a123", "/toaz"}, - {"/r/abcz", "/toaz"}, - {"/r/z", "/toaz"}, - {"/r/z.html", "/r/z.html"}, - {"/r/z.js", "/toaz"}, - {"/url/asAB", "/to/url/asAB"}, - {"/url/aBsAB", "/url/aBsAB"}, - {"/url/a00sAB", "/to/url/a00sAB"}, - {"/url/a0z0sAB", "/to/url/a0z0sAB"}, - {"/ab/aa", "/ab/aa"}, - {"/ab/ab", "/ab/ab"}, - {"/ab/ab.txt", "/ab"}, - {"/ab/ab.txt?name=name", "/ab?name=name"}, - {"/ab/ab.html?name=name", "/ab?type=html&name=name"}, - {"/abc/ab.html", "/abc/ab.html"}, - {"/abcd/abcd.html", "/a/abcd/abcd.html"}, - {"/abcde/abcde.html", "/a"}, - {"/abcde/abcde.html#1234", "/a#1234"}, - {"/ab/ab.jpg", "/ajpg"}, - {"/reggrp/ad/12", "/a12"}, - {"/reggrp/ad/124a", "/a124/a"}, - {"/reggrp/ad/124abc", "/a124/abc"}, - {"/reg2grp/ad/124abc", "/ad/124abc"}, - {"/reg3grp/ad/aa/66", "/adaa66"}, - {"/reg3grp/ad612/n1n/ab", "/ad612n1nab"}, - {"/hashtest/a%20%23%20test", "/a%20%23%20test"}, - {"/hashtest/a%20%3F%20test", "/a%20%3F%20test"}, - {"/hashtest/a%20%3F%23test", "/a%20%3F%23test"}, - } - - for i, test := range tests { - req, err := http.NewRequest("GET", test.from, nil) - if err != nil { - t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) - } - - rec := httptest.NewRecorder() - rw.ServeHTTP(rec, req) - - if rec.Body.String() != test.expectedTo { - t.Errorf("Test %d: Expected URL to be '%s' but was '%s'", - i, test.expectedTo, rec.Body.String()) - } - } - - statusTests := []struct { - status int - base string - to string - regexp string - statusExpected bool - }{ - {400, "/status", "", "", true}, - {400, "/ignore", "", "", false}, - {400, "/", "", "^/ignore", false}, - {400, "/", "", "(.*)", true}, - {400, "/status", "", "", true}, - } - - for i, s := range statusTests { - urlPath := fmt.Sprintf("/status%d", i) - rule, err := NewComplexRule(s.base, s.regexp, s.to, s.status, nil, nil) - if err != nil { - t.Fatalf("Test %d: No error expected for rule but found %v", i, err) - } - rw.Rules = []Rule{rule} - req, err := http.NewRequest("GET", urlPath, nil) - if err != nil { - t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) - } - - rec := httptest.NewRecorder() - code, err := rw.ServeHTTP(rec, req) - if err != nil { - t.Fatalf("Test %d: No error expected for handler but found %v", i, err) - } - if s.statusExpected { - if rec.Body.String() != "" { - t.Errorf("Test %d: Expected empty body but found %s", i, rec.Body.String()) - } - if code != s.status { - t.Errorf("Test %d: Expected status code %d found %d", i, s.status, code) - } - } else { - if code != 0 { - t.Errorf("Test %d: Expected no status code found %d", i, code) - } - } - } -} - -func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) { - fmt.Fprint(w, r.URL.String()) - return 0, nil -} diff --git a/middleware/rewrite/testdata/testdir/empty b/middleware/rewrite/testdata/testdir/empty deleted file mode 100644 index e69de29bb..000000000 diff --git a/middleware/rewrite/testdata/testfile b/middleware/rewrite/testdata/testfile deleted file mode 100644 index 7b4d68d70..000000000 --- a/middleware/rewrite/testdata/testfile +++ /dev/null @@ -1 +0,0 @@ -empty \ No newline at end of file diff --git a/middleware/rewrite/to.go b/middleware/rewrite/to.go deleted file mode 100644 index 7a38349ff..000000000 --- a/middleware/rewrite/to.go +++ /dev/null @@ -1,87 +0,0 @@ -package rewrite - -import ( - "log" - "net/http" - "net/url" - "path" - "strings" - - "github.com/mholt/caddy/middleware" -) - -// To attempts rewrite. It attempts to rewrite to first valid path -// or the last path if none of the paths are valid. -// Returns true if rewrite is successful and false otherwise. -func To(fs http.FileSystem, r *http.Request, to string, replacer middleware.Replacer) Result { - tos := strings.Fields(to) - - // try each rewrite paths - t := "" - for _, v := range tos { - t = path.Clean(replacer.Replace(v)) - - // add trailing slash for directories, if present - if strings.HasSuffix(v, "/") && !strings.HasSuffix(t, "/") { - t += "/" - } - - // validate file - if isValidFile(fs, t) { - break - } - } - - // validate resulting path - u, err := url.Parse(t) - if err != nil { - // Let the user know we got here. Rewrite is expected but - // the resulting url is invalid. - log.Printf("[ERROR] rewrite: resulting path '%v' is invalid. error: %v", t, err) - return RewriteIgnored - } - - // take note of this rewrite for internal use by fastcgi - // all we need is the URI, not full URL - r.Header.Set(headerFieldName, r.URL.RequestURI()) - - // perform rewrite - r.URL.Path = u.Path - if u.RawQuery != "" { - // overwrite query string if present - r.URL.RawQuery = u.RawQuery - } - if u.Fragment != "" { - // overwrite fragment if present - r.URL.Fragment = u.Fragment - } - - return RewriteDone -} - -// isValidFile checks if file exists on the filesystem. -// if file ends with `/`, it is validated as a directory. -func isValidFile(fs http.FileSystem, file string) bool { - if fs == nil { - return false - } - - f, err := fs.Open(file) - if err != nil { - return false - } - defer f.Close() - - stat, err := f.Stat() - if err != nil { - return false - } - - // directory - if strings.HasSuffix(file, "/") { - return stat.IsDir() - } - - // file - return !stat.IsDir() -} diff --git a/middleware/rewrite/to_test.go b/middleware/rewrite/to_test.go deleted file mode 100644 index 6133c0b63..000000000 --- a/middleware/rewrite/to_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package rewrite - -import ( - "net/http" - "net/url" - "testing" -) - -func TestTo(t *testing.T) { - fs := http.Dir("testdata") - tests := []struct { - url string - to string - expected string - }{ - {"/", "/somefiles", "/somefiles"}, - {"/somefiles", "/somefiles /index.php{uri}", "/index.php/somefiles"}, - {"/somefiles", "/testfile /index.php{uri}", "/testfile"}, - {"/somefiles", "/testfile/ /index.php{uri}", "/index.php/somefiles"}, - {"/somefiles", "/somefiles /index.php{uri}", "/index.php/somefiles"}, - {"/?a=b", "/somefiles /index.php?{query}", "/index.php?a=b"}, - {"/?a=b", "/testfile /index.php?{query}", "/testfile?a=b"}, - {"/?a=b", "/testdir /index.php?{query}", "/index.php?a=b"}, - {"/?a=b", "/testdir/ /index.php?{query}", "/testdir/?a=b"}, - } - - uri := func(r *url.URL) string { - uri := r.Path - if r.RawQuery != "" { - uri += "?" + r.RawQuery - } - return uri - } - for i, test := range tests { - r, err := http.NewRequest("GET", test.url, nil) - if err != nil { - t.Error(err) - } - To(fs, r, test.to, newReplacer(r)) - if uri(r.URL) != test.expected { - t.Errorf("Test %v: expected %v found %v", i, test.expected, uri(r.URL)) - } - } -} diff --git a/middleware/roller.go b/middleware/roller.go deleted file mode 100644 index 995cabf91..000000000 --- a/middleware/roller.go +++ /dev/null @@ -1,27 +0,0 @@ -package middleware - -import ( - "io" - - "gopkg.in/natefinch/lumberjack.v2" -) - -// LogRoller implements a middleware that provides a rolling logger. -type LogRoller struct { - Filename string - MaxSize int - MaxAge int - MaxBackups int - LocalTime bool -} - -// GetLogWriter returns an io.Writer that writes to a rolling logger. -func (l LogRoller) GetLogWriter() io.Writer { - return &lumberjack.Logger{ - Filename: l.Filename, - MaxSize: l.MaxSize, - MaxAge: l.MaxAge, - MaxBackups: l.MaxBackups, - LocalTime: l.LocalTime, - } -} diff --git a/middleware/templates/templates.go b/middleware/templates/templates.go deleted file mode 100644 index c8c08922b..000000000 --- a/middleware/templates/templates.go +++ /dev/null @@ -1,95 +0,0 @@ -// Package templates implements template execution for files to be dynamically rendered for the client. -package templates - -import ( - "bytes" - "net/http" - "os" - "path" - "path/filepath" - "text/template" - - "github.com/mholt/caddy/middleware" -) - -// ServeHTTP implements the middleware.Handler interface. -func (t Templates) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - for _, rule := range t.Rules { - if !middleware.Path(r.URL.Path).Matches(rule.Path) { - continue - } - - // Check for index files - fpath := r.URL.Path - if idx, ok := middleware.IndexFile(t.FileSys, fpath, rule.IndexFiles); ok { - fpath = idx - } - - // Check the extension - reqExt := path.Ext(fpath) - - for _, ext := range rule.Extensions { - if reqExt == ext { - // Create execution context - ctx := middleware.Context{Root: t.FileSys, Req: r, URL: r.URL} - - // New template - templateName := filepath.Base(fpath) - tpl := template.New(templateName) - - // Set delims - if rule.Delims != [2]string{} { - tpl.Delims(rule.Delims[0], rule.Delims[1]) - } - - // Build the template - templatePath := filepath.Join(t.Root, fpath) - tpl, err := tpl.ParseFiles(templatePath) - if err != nil { - if os.IsNotExist(err) { - return http.StatusNotFound, nil - } else if os.IsPermission(err) { - return http.StatusForbidden, nil - } - return http.StatusInternalServerError, err - } - - // Execute it - var buf bytes.Buffer - err = tpl.Execute(&buf, ctx) - if err != nil { - return http.StatusInternalServerError, err - } - - templateInfo, err := os.Stat(templatePath) - if err == nil { - // add the Last-Modified header if we were able to read the stamp - middleware.SetLastModifiedHeader(w, templateInfo.ModTime()) - } - buf.WriteTo(w) - - return http.StatusOK, nil - } - } - } - - return t.Next.ServeHTTP(w, r) -} - -// Templates is middleware to render templated files as the HTTP response. -type Templates struct { - Next middleware.Handler - Rules []Rule - Root string - FileSys http.FileSystem -} - -// Rule represents a template rule. A template will only execute -// with this rule if the request path matches the Path specified -// and requests a resource with one of the extensions specified. -type Rule struct { - Path string - Extensions []string - IndexFiles []string - Delims [2]string -} diff --git a/middleware/templates/templates_test.go b/middleware/templates/templates_test.go deleted file mode 100644 index c5a5d24a8..000000000 --- a/middleware/templates/templates_test.go +++ /dev/null @@ -1,146 +0,0 @@ -package templates - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/mholt/caddy/middleware" -) - -func Test(t *testing.T) { - tmpl := Templates{ - Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - return 0, nil - }), - Rules: []Rule{ - { - Extensions: []string{".html"}, - IndexFiles: []string{"index.html"}, - Path: "/photos", - }, - { - Extensions: []string{".html", ".htm"}, - IndexFiles: []string{"index.html", "index.htm"}, - Path: "/images", - Delims: [2]string{"{%", "%}"}, - }, - }, - Root: "./testdata", - FileSys: http.Dir("./testdata"), - } - - tmplroot := Templates{ - Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - return 0, nil - }), - Rules: []Rule{ - { - Extensions: []string{".html"}, - IndexFiles: []string{"index.html"}, - Path: "/", - }, - }, - Root: "./testdata", - FileSys: http.Dir("./testdata"), - } - - /* - * Test tmpl on /photos/test.html - */ - req, err := http.NewRequest("GET", "/photos/test.html", nil) - if err != nil { - t.Fatalf("Test: Could not create HTTP request: %v", err) - } - - rec := httptest.NewRecorder() - - tmpl.ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Fatalf("Test: Wrong response code: %d, should be %d", rec.Code, http.StatusOK) - } - - respBody := rec.Body.String() - expectedBody := `test page

Header title

- -` - - if respBody != expectedBody { - t.Fatalf("Test: the expected body %v is different from the response one: %v", expectedBody, respBody) - } - - /* - * Test tmpl on /images/img.htm - */ - req, err = http.NewRequest("GET", "/images/img.htm", nil) - if err != nil { - t.Fatalf("Could not create HTTP request: %v", err) - } - - rec = httptest.NewRecorder() - - tmpl.ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Fatalf("Test: Wrong response code: %d, should be %d", rec.Code, http.StatusOK) - } - - respBody = rec.Body.String() - expectedBody = `img

Header title

- -` - - if respBody != expectedBody { - t.Fatalf("Test: the expected body %v is different from the response one: %v", expectedBody, respBody) - } - - /* - * Test tmpl on /images/img2.htm - */ - req, err = http.NewRequest("GET", "/images/img2.htm", nil) - if err != nil { - t.Fatalf("Could not create HTTP request: %v", err) - } - - rec = httptest.NewRecorder() - - tmpl.ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Fatalf("Test: Wrong response code: %d, should be %d", rec.Code, http.StatusOK) - } - - respBody = rec.Body.String() - expectedBody = `img{{.Include "header.html"}} -` - - if respBody != expectedBody { - t.Fatalf("Test: the expected body %v is different from the response one: %v", expectedBody, respBody) - } - - /* - * Test tmplroot on /root.html - */ - req, err = http.NewRequest("GET", "/root.html", nil) - if err != nil { - t.Fatalf("Could not create HTTP request: %v", err) - } - - rec = httptest.NewRecorder() - - tmplroot.ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Fatalf("Test: Wrong response code: %d, should be %d", rec.Code, http.StatusOK) - } - - respBody = rec.Body.String() - expectedBody = `root

Header title

- -` - - if respBody != expectedBody { - t.Fatalf("Test: the expected body %v is different from the response one: %v", expectedBody, respBody) - } -} diff --git a/middleware/templates/testdata/header.html b/middleware/templates/testdata/header.html deleted file mode 100644 index 9c96e0e37..000000000 --- a/middleware/templates/testdata/header.html +++ /dev/null @@ -1 +0,0 @@ -

Header title

diff --git a/middleware/templates/testdata/images/header.html b/middleware/templates/testdata/images/header.html deleted file mode 100644 index 9c96e0e37..000000000 --- a/middleware/templates/testdata/images/header.html +++ /dev/null @@ -1 +0,0 @@ -

Header title

diff --git a/middleware/templates/testdata/images/img.htm b/middleware/templates/testdata/images/img.htm deleted file mode 100644 index c90602044..000000000 --- a/middleware/templates/testdata/images/img.htm +++ /dev/null @@ -1 +0,0 @@ -img{%.Include "header.html"%} diff --git a/middleware/templates/testdata/images/img2.htm b/middleware/templates/testdata/images/img2.htm deleted file mode 100644 index 865a73809..000000000 --- a/middleware/templates/testdata/images/img2.htm +++ /dev/null @@ -1 +0,0 @@ -img{{.Include "header.html"}} diff --git a/middleware/templates/testdata/photos/test.html b/middleware/templates/testdata/photos/test.html deleted file mode 100644 index e2e95e133..000000000 --- a/middleware/templates/testdata/photos/test.html +++ /dev/null @@ -1 +0,0 @@ -test page{{.Include "../header.html"}} diff --git a/middleware/templates/testdata/root.html b/middleware/templates/testdata/root.html deleted file mode 100644 index e1720e726..000000000 --- a/middleware/templates/testdata/root.html +++ /dev/null @@ -1 +0,0 @@ -root{{.Include "header.html"}} diff --git a/middleware/websocket/websocket.go b/middleware/websocket/websocket.go deleted file mode 100644 index 76781ba10..000000000 --- a/middleware/websocket/websocket.go +++ /dev/null @@ -1,259 +0,0 @@ -// Package websocket implements a WebSocket server by executing -// a command and piping its input and output through the WebSocket -// connection. -package websocket - -import ( - "bufio" - "bytes" - "io" - "net" - "net/http" - "os" - "os/exec" - "strings" - "time" - - "github.com/gorilla/websocket" - "github.com/mholt/caddy/middleware" -) - -const ( - // Time allowed to write a message to the peer. - writeWait = 10 * time.Second - - // Time allowed to read the next pong message from the peer. - pongWait = 60 * time.Second - - // Send pings to peer with this period. Must be less than pongWait. - pingPeriod = (pongWait * 9) / 10 - - // Maximum message size allowed from peer. - maxMessageSize = 1024 * 1024 * 10 // 10 MB default. -) - -var ( - // GatewayInterface is the dialect of CGI being used by the server - // to communicate with the script. See CGI spec, 4.1.4 - GatewayInterface string - - // ServerSoftware is the name and version of the information server - // software making the CGI request. See CGI spec, 4.1.17 - ServerSoftware string -) - -type ( - // WebSocket is a type that holds configuration for the - // websocket middleware generally, like a list of all the - // websocket endpoints. - WebSocket struct { - // Next is the next HTTP handler in the chain for when the path doesn't match - Next middleware.Handler - - // Sockets holds all the web socket endpoint configurations - Sockets []Config - } - - // Config holds the configuration for a single websocket - // endpoint which may serve multiple websocket connections. - Config struct { - Path string - Command string - Arguments []string - Respawn bool // TODO: Not used, but parser supports it until we decide on it - } -) - -// ServeHTTP converts the HTTP request to a WebSocket connection and serves it up. -func (ws WebSocket) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - for _, sockconfig := range ws.Sockets { - if middleware.Path(r.URL.Path).Matches(sockconfig.Path) { - return serveWS(w, r, &sockconfig) - } - } - - // Didn't match a websocket path, so pass-thru - return ws.Next.ServeHTTP(w, r) -} - -// serveWS is used for setting and upgrading the HTTP connection to a websocket connection. -// It also spawns the child process that is associated with matched HTTP path/url. -func serveWS(w http.ResponseWriter, r *http.Request, config *Config) (int, error) { - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return http.StatusBadRequest, err - } - defer conn.Close() - - cmd := exec.Command(config.Command, config.Arguments...) - - stdout, err := cmd.StdoutPipe() - if err != nil { - return http.StatusBadGateway, err - } - defer stdout.Close() - - stdin, err := cmd.StdinPipe() - if err != nil { - return http.StatusBadGateway, err - } - defer stdin.Close() - - metavars, err := buildEnv(cmd.Path, r) - if err != nil { - return http.StatusBadGateway, err - } - - cmd.Env = metavars - - if err := cmd.Start(); err != nil { - return http.StatusBadGateway, err - } - - done := make(chan struct{}) - go pumpStdout(conn, stdout, done) - pumpStdin(conn, stdin) - - stdin.Close() // close stdin to end the process - - if err := cmd.Process.Signal(os.Interrupt); err != nil { // signal an interrupt to kill the process - return http.StatusInternalServerError, err - } - - select { - case <-done: - case <-time.After(time.Second): - // terminate with extreme prejudice. - if err := cmd.Process.Signal(os.Kill); err != nil { - return http.StatusInternalServerError, err - } - <-done - } - - // not sure what we want to do here. - // status for an "exited" process is greater - // than 0, but isn't really an error per se. - // just going to ignore it for now. - cmd.Wait() - - return 0, nil -} - -// buildEnv creates the meta-variables for the child process according -// to the CGI 1.1 specification: http://tools.ietf.org/html/rfc3875#section-4.1 -// cmdPath should be the path of the command being run. -// The returned string slice can be set to the command's Env property. -func buildEnv(cmdPath string, r *http.Request) (metavars []string, err error) { - if !strings.Contains(r.RemoteAddr, ":") { - r.RemoteAddr += ":" - } - remoteHost, remotePort, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return - } - - if !strings.Contains(r.Host, ":") { - r.Host += ":" - } - serverHost, serverPort, err := net.SplitHostPort(r.Host) - if err != nil { - return - } - - metavars = []string{ - `AUTH_TYPE=`, // Not used - `CONTENT_LENGTH=`, // Not used - `CONTENT_TYPE=`, // Not used - `GATEWAY_INTERFACE=` + GatewayInterface, - `PATH_INFO=`, // TODO - `PATH_TRANSLATED=`, // TODO - `QUERY_STRING=` + r.URL.RawQuery, - `REMOTE_ADDR=` + remoteHost, - `REMOTE_HOST=` + remoteHost, // Host lookups are slow - don't do them - `REMOTE_IDENT=`, // Not used - `REMOTE_PORT=` + remotePort, - `REMOTE_USER=`, // Not used, - `REQUEST_METHOD=` + r.Method, - `REQUEST_URI=` + r.RequestURI, - `SCRIPT_NAME=` + cmdPath, // path of the program being executed - `SERVER_NAME=` + serverHost, - `SERVER_PORT=` + serverPort, - `SERVER_PROTOCOL=` + r.Proto, - `SERVER_SOFTWARE=` + ServerSoftware, - } - - // Add each HTTP header to the environment as well - for header, values := range r.Header { - value := strings.Join(values, ", ") - header = strings.ToUpper(header) - header = strings.Replace(header, "-", "_", -1) - value = strings.Replace(value, "\n", " ", -1) - metavars = append(metavars, "HTTP_"+header+"="+value) - } - - return -} - -// pumpStdin handles reading data from the websocket connection and writing -// it to stdin of the process. -func pumpStdin(conn *websocket.Conn, stdin io.WriteCloser) { - // Setup our connection's websocket ping/pong handlers from our const values. - defer conn.Close() - conn.SetReadLimit(maxMessageSize) - conn.SetReadDeadline(time.Now().Add(pongWait)) - conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) - for { - _, message, err := conn.ReadMessage() - if err != nil { - break - } - message = append(message, '\n') - if _, err := stdin.Write(message); err != nil { - break - } - } -} - -// pumpStdout handles reading data from stdout of the process and writing -// it to websocket connection. -func pumpStdout(conn *websocket.Conn, stdout io.Reader, done chan struct{}) { - go pinger(conn, done) - defer func() { - conn.Close() - close(done) // make sure to close the pinger when we are done. - }() - - s := bufio.NewScanner(stdout) - for s.Scan() { - conn.SetWriteDeadline(time.Now().Add(writeWait)) - if err := conn.WriteMessage(websocket.TextMessage, bytes.TrimSpace(s.Bytes())); err != nil { - break - } - } - if s.Err() != nil { - conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, s.Err().Error()), time.Time{}) - } -} - -// pinger simulates the websocket to keep it alive with ping messages. -func pinger(conn *websocket.Conn, done chan struct{}) { - ticker := time.NewTicker(pingPeriod) - defer ticker.Stop() - - for { // blocking loop with select to wait for stimulation. - select { - case <-ticker.C: - if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil { - conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, err.Error()), time.Time{}) - return - } - case <-done: - return // clean up this routine. - } - } -} diff --git a/middleware/websocket/websocket_test.go b/middleware/websocket/websocket_test.go deleted file mode 100644 index 61d7d382b..000000000 --- a/middleware/websocket/websocket_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package websocket - -import ( - "net/http" - "testing" -) - -func TestBuildEnv(t *testing.T) { - req, err := http.NewRequest("GET", "http://localhost", nil) - if err != nil { - t.Fatal("Error setting up request:", err) - } - req.RemoteAddr = "localhost:50302" - - env, err := buildEnv("/bin/command", req) - if err != nil { - t.Fatal("Didn't expect an error:", err) - } - if len(env) == 0 { - t.Fatalf("Expected non-empty environment; got %#v", env) - } -} diff --git a/plugins.go b/plugins.go new file mode 100644 index 000000000..f66e24cda --- /dev/null +++ b/plugins.go @@ -0,0 +1,289 @@ +package caddy + +import ( + "fmt" + "net" + "sort" + + "github.com/mholt/caddy/caddyfile" +) + +// These are all the registered plugins. +var ( + // serverTypes is a map of registered server types. + serverTypes = make(map[string]ServerType) + + // plugins is a map of server type to map of plugin name to + // Plugin. These are the "general" plugins that may or may + // not be associated with a specific server type. If it's + // applicable to multiple server types or the server type is + // irrelevant, the key is empty string (""). But all plugins + // must have a name. + plugins = make(map[string]map[string]Plugin) + + // parsingCallbacks maps server type to map of directive + // to list of callback functions. These aren't really + // plugins on their own, but are often registered from + // plugins. + parsingCallbacks = make(map[string]map[string][]func() error) + + // caddyfileLoaders is the list of all Caddyfile loaders + // in registration order. + caddyfileLoaders []caddyfileLoader +) + +// DescribePlugins returns a string describing the registered plugins. +func DescribePlugins() string { + str := "Server types:\n" + for name := range serverTypes { + str += " " + name + "\n" + } + + // List the loaders in registration order + str += "\nCaddyfile loaders:\n" + for _, loader := range caddyfileLoaders { + str += " " + loader.name + "\n" + } + if defaultCaddyfileLoader.name != "" { + str += " " + defaultCaddyfileLoader.name + "\n" + } + + // Let's alphabetize the rest of these... + var others []string + for stype, stypePlugins := range plugins { + for name := range stypePlugins { + var s string + if stype != "" { + s = stype + "." + } + s += name + others = append(others, s) + } + } + sort.Strings(others) + str += "\nOther plugins:\n" + for _, name := range others { + str += " " + name + "\n" + } + + return str +} + +// ValidDirectives returns the list of all directives that are +// recognized for the server type serverType. However, not all +// directives may be installed. This makes it possible to give +// more helpful error messages, like "did you mean ..." or +// "maybe you need to plug in ...". +func ValidDirectives(serverType string) []string { + stype, err := getServerType(serverType) + if err != nil { + return nil + } + return stype.Directives +} + +// serverListener pairs a server to its listener. +type serverListener struct { + server Server + listener net.Listener +} + +// Context is a type that carries a server type through +// the load and setup phase; it maintains the state +// between loading the Caddyfile, then executing its +// directives, then making the servers for Caddy to +// manage. Typically, such state involves configuration +// structs, etc. +type Context interface { + InspectServerBlocks(string, []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) + MakeServers() ([]Server, error) +} + +// RegisterServerType registers a server type srv by its +// name, typeName. +func RegisterServerType(typeName string, srv ServerType) { + if _, ok := serverTypes[typeName]; ok { + panic("server type already registered") + } + serverTypes[typeName] = srv +} + +// ServerType contains information about a server type. +type ServerType struct { + // List of directives, in execution order, that are + // valid for this server type. Directives should be + // one word if possible and lower-cased. + Directives []string + + // InspectServerBlocks is an optional callback that is + // executed after loading the tokens for each server + // block but before executing the directives in them. + // This func may modify the server blocks and return + // new ones to be used. + InspectServerBlocks func(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) + + // MakeServers is a callback that makes the server + // instances. + MakeServers func() ([]Server, error) + + // DefaultInput returns a default config input if none + // is otherwise loaded. + DefaultInput func() Input + + NewContext func() Context +} + +// Plugin is a type which holds information about a plugin. +type Plugin struct { + // The plugin must have a name: lower case and one word. + // If this plugin has an action, it must be the name of + // the directive to attach to. A name is always required. + Name string + + // ServerType is the type of server this plugin is for. + // Can be empty if not applicable, or if the plugin + // can associate with any server type. + ServerType string + + // Action is the plugin's setup function, if associated + // with a directive in the Caddyfile. + Action SetupFunc +} + +// RegisterPlugin plugs in plugin. All plugins should register +// themselves, even if they do not perform an action associated +// with a directive. It is important for the process to know +// which plugins are available. +func RegisterPlugin(plugin Plugin) { + if plugin.Name == "" { + panic("plugin must have a name") + } + if _, ok := plugins[plugin.ServerType]; !ok { + plugins[plugin.ServerType] = make(map[string]Plugin) + } + if _, dup := plugins[plugin.ServerType][plugin.Name]; dup { + panic("plugin named " + plugin.Name + " already registered for server type " + plugin.ServerType) + } + plugins[plugin.ServerType][plugin.Name] = plugin +} + +// RegisterParsingCallback registers callback to be called after +// executing the directive afterDir for server type serverType. +func RegisterParsingCallback(serverType, afterDir string, callback func() error) { + if _, ok := parsingCallbacks[serverType]; !ok { + parsingCallbacks[serverType] = make(map[string][]func() error) + } + parsingCallbacks[serverType][afterDir] = append(parsingCallbacks[serverType][afterDir], callback) +} + +// SetupFunc is used to set up a plugin, or in other words, +// execute a directive. It will be called once per key for +// each server block it appears in. +type SetupFunc func(c *Controller) error + +// DirectiveAction gets the action for directive dir of +// server type serverType. +func DirectiveAction(serverType, dir string) (SetupFunc, error) { + if stypePlugins, ok := plugins[serverType]; ok { + if plugin, ok := stypePlugins[dir]; ok { + return plugin.Action, nil + } + } + if genericPlugins, ok := plugins[""]; ok { + if plugin, ok := genericPlugins[dir]; ok { + return plugin.Action, nil + } + } + return nil, fmt.Errorf("no action found for directive '%s' with server type '%s' (missing a plugin?)", + dir, serverType) +} + +// Loader is a type that can load a Caddyfile. +// It is passed the name of the server type. +// It returns an error only if something went +// wrong, not simply if there is no Caddyfile +// for this loader to load. +// +// A Loader should only load the Caddyfile if +// a certain condition or requirement is met, +// as returning a non-nil Input value along with +// another Loader will result in an error. +// In other words, loading the Caddyfile must +// be deliberate & deterministic, not haphazard. +// +// The exception is the default Caddyfile loader, +// which will be called only if no other Caddyfile +// loaders return a non-nil Input. The default +// loader may always return an Input value. +type Loader interface { + Load(string) (Input, error) +} + +// LoaderFunc is a convenience type similar to http.HandlerFunc +// that allows you to use a plain function as a Load() method. +type LoaderFunc func(string) (Input, error) + +// Load loads a Caddyfile. +func (lf LoaderFunc) Load(serverType string) (Input, error) { + return lf(serverType) +} + +// RegisterCaddyfileLoader registers loader named name. +func RegisterCaddyfileLoader(name string, loader Loader) { + caddyfileLoaders = append(caddyfileLoaders, caddyfileLoader{name: name, loader: loader}) +} + +// SetDefaultCaddyfileLoader registers loader by name +// as the default Caddyfile loader if no others produce +// a Caddyfile. If another Caddyfile loader has already +// been set as the default, this replaces it. +// +// Do not call RegisterCaddyfileLoader on the same +// loader; that would be redundant. +func SetDefaultCaddyfileLoader(name string, loader Loader) { + defaultCaddyfileLoader = caddyfileLoader{name: name, loader: loader} +} + +// loadCaddyfileInput iterates the registered Caddyfile loaders +// and, if needed, calls the default loader, to load a Caddyfile. +// It is an error if any of the loaders return an error or if +// more than one loader returns a Caddyfile. +func loadCaddyfileInput(serverType string) (Input, error) { + var loadedBy string + var caddyfileToUse Input + for _, l := range caddyfileLoaders { + if cdyfile, err := l.loader.Load(serverType); cdyfile != nil { + if caddyfileToUse != nil { + return nil, fmt.Errorf("Caddyfile loaded multiple times; first by %s, then by %s", loadedBy, l.name) + } + if err != nil { + return nil, err + } + loaderUsed = l + caddyfileToUse = cdyfile + loadedBy = l.name + } + } + if caddyfileToUse == nil && defaultCaddyfileLoader.loader != nil { + cdyfile, err := defaultCaddyfileLoader.loader.Load(serverType) + if err != nil { + return nil, err + } + if cdyfile != nil { + loaderUsed = defaultCaddyfileLoader + caddyfileToUse = cdyfile + } + } + return caddyfileToUse, nil +} + +// caddyfileLoader pairs the name of a loader to the loader. +type caddyfileLoader struct { + name string + loader Loader +} + +var ( + defaultCaddyfileLoader caddyfileLoader // the default loader if all else fail + loaderUsed caddyfileLoader // the loader that was used (relevant for reloads) +) diff --git a/server/config.go b/server/config.go deleted file mode 100644 index e66ec801c..000000000 --- a/server/config.go +++ /dev/null @@ -1,80 +0,0 @@ -package server - -import ( - "crypto/tls" - "net" - - "github.com/mholt/caddy/middleware" -) - -// Config configuration for a single server. -type Config struct { - // The hostname or IP on which to serve - Host string - - // The host address to bind on - defaults to (virtual) Host if empty - BindHost string - - // The port to listen on - Port string - - // The protocol (http/https) to serve with this config; only set if user explicitly specifies it - Scheme string - - // The directory from which to serve files - Root string - - // HTTPS configuration - TLS TLSConfig - - // Middleware stack - Middleware []middleware.Middleware - - // Startup is a list of functions (or methods) to execute at - // server startup and restart; these are executed before any - // parts of the server are configured, and the functions are - // blocking. These are good for setting up middlewares and - // starting goroutines. - Startup []func() error - - // FirstStartup is like Startup but these functions only execute - // during the initial startup, not on subsequent restarts. - // - // (Note: The server does not ever run these on its own; it is up - // to the calling application to do so, and do so only once, as the - // server itself has no notion whether it's a restart or not.) - FirstStartup []func() error - - // Functions (or methods) to execute when the server quits; - // these are executed in response to SIGINT and are blocking - Shutdown []func() error - - // The path to the configuration file from which this was loaded - ConfigFile string - - // The name of the application - AppName string - - // The application's version - AppVersion string -} - -// Address returns the host:port of c as a string. -func (c Config) Address() string { - return net.JoinHostPort(c.Host, c.Port) -} - -// TLSConfig describes how TLS should be configured and used. -type TLSConfig struct { - Enabled bool // will be set to true if TLS is enabled - LetsEncryptEmail string - Manual bool // will be set to true if user provides own certs and keys - Managed bool // will be set to true if config qualifies for implicit automatic/managed HTTPS - OnDemand bool // will be set to true if user enables on-demand TLS (obtain certs during handshakes) - Ciphers []uint16 - ProtocolMinVersion uint16 - ProtocolMaxVersion uint16 - PreferServerCipherSuites bool - ClientCerts []string - ClientAuth tls.ClientAuthType -} diff --git a/server/config_test.go b/server/config_test.go deleted file mode 100644 index 8787e467b..000000000 --- a/server/config_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package server - -import "testing" - -func TestConfigAddress(t *testing.T) { - cfg := Config{Host: "foobar", Port: "1234"} - if actual, expected := cfg.Address(), "foobar:1234"; expected != actual { - t.Errorf("Expected '%s' but got '%s'", expected, actual) - } - - cfg = Config{Host: "", Port: "1234"} - if actual, expected := cfg.Address(), ":1234"; expected != actual { - t.Errorf("Expected '%s' but got '%s'", expected, actual) - } - - cfg = Config{Host: "foobar", Port: ""} - if actual, expected := cfg.Address(), "foobar:"; expected != actual { - t.Errorf("Expected '%s' but got '%s'", expected, actual) - } - - cfg = Config{Host: "::1", Port: "443"} - if actual, expected := cfg.Address(), "[::1]:443"; expected != actual { - t.Errorf("Expected '%s' but got '%s'", expected, actual) - } -} diff --git a/server/server.go b/server/server.go deleted file mode 100644 index ea98f5e5f..000000000 --- a/server/server.go +++ /dev/null @@ -1,544 +0,0 @@ -// Package server implements a configurable, general-purpose web server. -// It relies on configurations obtained from the adjacent config package -// and can execute middleware as defined by the adjacent middleware package. -package server - -import ( - "crypto/rand" - "crypto/tls" - "crypto/x509" - "fmt" - "io" - "io/ioutil" - "log" - "net" - "net/http" - "os" - "path" - "runtime" - "strings" - "sync" - "time" -) - -const ( - tlsNewTicketEvery = time.Hour * 10 // generate a new ticket for TLS PFS encryption every so often - tlsNumTickets = 4 // hold and consider that many tickets to decrypt TLS sessions -) - -// Server represents an instance of a server, which serves -// HTTP requests at a particular address (host and port). A -// server is capable of serving numerous virtual hosts on -// the same address and the listener may be stopped for -// graceful termination (POSIX only). -type Server struct { - *http.Server - HTTP2 bool // whether to enable HTTP/2 - tls bool // whether this server is serving all HTTPS hosts or not - OnDemandTLS bool // whether this server supports on-demand TLS (load certs at handshake-time) - tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine - vhosts map[string]virtualHost // virtual hosts keyed by their address - listener ListenerFile // the listener which is bound to the socket - listenerMu sync.Mutex // protects listener - httpWg sync.WaitGroup // used to wait on outstanding connections - startChan chan struct{} // used to block until server is finished starting - connTimeout time.Duration // the maximum duration of a graceful shutdown - ReqCallback OptionalCallback // if non-nil, is executed at the beginning of every request - SNICallback func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) -} - -// ListenerFile represents a listener. -type ListenerFile interface { - net.Listener - File() (*os.File, error) -} - -// OptionalCallback is a function that may or may not handle a request. -// It returns whether or not it handled the request. If it handled the -// request, it is presumed that no further request handling should occur. -type OptionalCallback func(http.ResponseWriter, *http.Request) bool - -// New creates a new Server which will bind to addr and serve -// the sites/hosts configured in configs. Its listener will -// gracefully close when the server is stopped which will take -// no longer than gracefulTimeout. -// -// This function does not start serving. -// -// Do not re-use a server (start, stop, then start again). We -// could probably add more locking to make this possible, but -// as it stands, you should dispose of a server after stopping it. -// The behavior of serving with a spent server is undefined. -func New(addr string, configs []Config, gracefulTimeout time.Duration) (*Server, error) { - var useTLS, useOnDemandTLS bool - if len(configs) > 0 { - useTLS = configs[0].TLS.Enabled - useOnDemandTLS = configs[0].TLS.OnDemand - } - - s := &Server{ - Server: &http.Server{ - Addr: addr, - TLSConfig: new(tls.Config), - // TODO: Make these values configurable? - // ReadTimeout: 2 * time.Minute, - // WriteTimeout: 2 * time.Minute, - // MaxHeaderBytes: 1 << 16, - }, - tls: useTLS, - OnDemandTLS: useOnDemandTLS, - vhosts: make(map[string]virtualHost), - startChan: make(chan struct{}), - connTimeout: gracefulTimeout, - } - s.Handler = s // this is weird, but whatever - - // We have to bound our wg with one increment - // to prevent a "race condition" that is hard-coded - // into sync.WaitGroup.Wait() - basically, an add - // with a positive delta must be guaranteed to - // occur before Wait() is called on the wg. - // In a way, this kind of acts as a safety barrier. - s.httpWg.Add(1) - - // Set up each virtualhost - for _, conf := range configs { - if _, exists := s.vhosts[conf.Host]; exists { - return nil, fmt.Errorf("cannot serve %s - host already defined for address %s", conf.Address(), s.Addr) - } - - vh := virtualHost{config: conf} - - // Build middleware stack - err := vh.buildStack() - if err != nil { - return nil, err - } - - s.vhosts[conf.Host] = vh - } - - return s, nil -} - -// Serve starts the server with an existing listener. It blocks until the -// server stops. -func (s *Server) Serve(ln ListenerFile) error { - err := s.setup() - if err != nil { - defer close(s.startChan) // MUST defer so error is properly reported, same with all cases in this file - return err - } - return s.serve(ln) -} - -// ListenAndServe starts the server with a new listener. It blocks until the server stops. -func (s *Server) ListenAndServe() error { - err := s.setup() - if err != nil { - defer close(s.startChan) - return err - } - - ln, err := net.Listen("tcp", s.Addr) - if err != nil { - var succeeded bool - if runtime.GOOS == "windows" { // TODO: Limit this to Windows only? (it keeps sockets open after closing listeners) - for i := 0; i < 20; i++ { - time.Sleep(100 * time.Millisecond) - ln, err = net.Listen("tcp", s.Addr) - if err == nil { - succeeded = true - break - } - } - } - if !succeeded { - defer close(s.startChan) - return err - } - } - - return s.serve(ln.(*net.TCPListener)) -} - -// serve prepares s to listen on ln by wrapping ln in a -// tcpKeepAliveListener (if ln is a *net.TCPListener) and -// then in a gracefulListener, so that keep-alive is supported -// as well as graceful shutdown/restart. It also configures -// TLS listener on top of that if applicable. -func (s *Server) serve(ln ListenerFile) error { - if tcpLn, ok := ln.(*net.TCPListener); ok { - ln = tcpKeepAliveListener{TCPListener: tcpLn} - } - - s.listenerMu.Lock() - s.listener = newGracefulListener(ln, &s.httpWg) - s.listenerMu.Unlock() - - if s.tls { - var tlsConfigs []TLSConfig - for _, vh := range s.vhosts { - tlsConfigs = append(tlsConfigs, vh.config.TLS) - } - return serveTLS(s, s.listener, tlsConfigs) - } - - close(s.startChan) // unblock anyone waiting for this to start listening - return s.Server.Serve(s.listener) -} - -// setup prepares the server s to begin listening; it should be -// called just before the listener announces itself on the network -// and should only be called when the server is just starting up. -func (s *Server) setup() error { - if !s.HTTP2 { - s.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) - } - - // Execute startup functions now - for _, vh := range s.vhosts { - for _, startupFunc := range vh.config.Startup { - err := startupFunc() - if err != nil { - return err - } - } - } - - return nil -} - -// serveTLS serves TLS with SNI and client auth support if s has them enabled. It -// blocks until s quits. -func serveTLS(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error { - // Customize our TLS configuration - s.TLSConfig.MinVersion = tlsConfigs[0].ProtocolMinVersion - s.TLSConfig.MaxVersion = tlsConfigs[0].ProtocolMaxVersion - s.TLSConfig.CipherSuites = tlsConfigs[0].Ciphers - s.TLSConfig.PreferServerCipherSuites = tlsConfigs[0].PreferServerCipherSuites - - // TLS client authentication, if user enabled it - err := setupClientAuth(tlsConfigs, s.TLSConfig) - if err != nil { - defer close(s.startChan) - return err - } - - // Setup any goroutines governing over TLS settings - s.tlsGovChan = make(chan struct{}) - timer := time.NewTicker(tlsNewTicketEvery) - go runTLSTicketKeyRotation(s.TLSConfig, timer, s.tlsGovChan) - - // Create TLS listener - note that we do not replace s.listener - // with this TLS listener; tls.listener is unexported and does - // not implement the File() method we need for graceful restarts - // on POSIX systems. - ln = tls.NewListener(ln, s.TLSConfig) - - close(s.startChan) // unblock anyone waiting for this to start listening - return s.Server.Serve(ln) -} - -// Stop stops the server. It blocks until the server is -// totally stopped. On POSIX systems, it will wait for -// connections to close (up to a max timeout of a few -// seconds); on Windows it will close the listener -// immediately. -func (s *Server) Stop() (err error) { - s.Server.SetKeepAlivesEnabled(false) - - if runtime.GOOS != "windows" { - // force connections to close after timeout - done := make(chan struct{}) - go func() { - s.httpWg.Done() // decrement our initial increment used as a barrier - s.httpWg.Wait() - close(done) - }() - - // Wait for remaining connections to finish or - // force them all to close after timeout - select { - case <-time.After(s.connTimeout): - case <-done: - } - } - - // Close the listener now; this stops the server without delay - s.listenerMu.Lock() - if s.listener != nil { - err = s.listener.Close() - } - s.listenerMu.Unlock() - - // Closing this signals any TLS governor goroutines to exit - if s.tlsGovChan != nil { - close(s.tlsGovChan) - } - - return -} - -// WaitUntilStarted blocks until the server s is started, meaning -// that practically the next instruction is to start the server loop. -// It also unblocks if the server encounters an error during startup. -func (s *Server) WaitUntilStarted() { - <-s.startChan -} - -// ListenerFd gets a dup'ed file of the listener. If there -// is no underlying file, the return value will be nil. It -// is the caller's responsibility to close the file. -func (s *Server) ListenerFd() *os.File { - s.listenerMu.Lock() - defer s.listenerMu.Unlock() - if s.listener != nil { - file, _ := s.listener.File() - return file - } - return nil -} - -// ServeHTTP is the entry point for every request to the address that s -// is bound to. It acts as a multiplexer for the requests hostname as -// defined in the Host header so that the correct virtualhost -// (configuration and middleware stack) will handle the request. -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - defer func() { - // In case the user doesn't enable error middleware, we still - // need to make sure that we stay alive up here - if rec := recover(); rec != nil { - http.Error(w, http.StatusText(http.StatusInternalServerError), - http.StatusInternalServerError) - } - }() - - w.Header().Set("Server", "Caddy") - - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - host = r.Host // oh well - } - - // "The host subcomponent is case-insensitive." (RFC 3986) - host = strings.ToLower(host) - - // Try the host as given, or try falling back to 0.0.0.0 (wildcard) - if _, ok := s.vhosts[host]; !ok { - if _, ok2 := s.vhosts["0.0.0.0"]; ok2 { - host = "0.0.0.0" - } else if _, ok2 := s.vhosts[""]; ok2 { - host = "" - } - } - - // Use URL.RawPath If you need the original, "raw" URL.Path in your middleware. - // Collapse any ./ ../ /// madness here instead of doing that in every plugin. - if r.URL.Path != "/" { - cleanedPath := path.Clean(r.URL.Path) - if cleanedPath == "." { - r.URL.Path = "/" - } else { - if !strings.HasPrefix(cleanedPath, "/") { - cleanedPath = "/" + cleanedPath - } - if strings.HasSuffix(r.URL.Path, "/") && !strings.HasSuffix(cleanedPath, "/") { - cleanedPath = cleanedPath + "/" - } - r.URL.Path = cleanedPath - } - } - - // Execute the optional request callback if it exists and it's not disabled - if s.ReqCallback != nil && !s.vhosts[host].config.TLS.Manual && s.ReqCallback(w, r) { - return - } - - if vh, ok := s.vhosts[host]; ok { - status, _ := vh.stack.ServeHTTP(w, r) - - // Fallback error response in case error handling wasn't chained in - if status >= 400 { - DefaultErrorFunc(w, r, status) - } - } else { - // Get the remote host - remoteHost, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - remoteHost = r.RemoteAddr - } - - w.WriteHeader(http.StatusNotFound) - fmt.Fprintf(w, "No such host at %s", s.Server.Addr) - log.Printf("[INFO] %s - No such host at %s (Remote: %s, Referer: %s)", - host, s.Server.Addr, remoteHost, r.Header.Get("Referer")) - } -} - -// DefaultErrorFunc responds to an HTTP request with a simple description -// of the specified HTTP status code. -func DefaultErrorFunc(w http.ResponseWriter, r *http.Request, status int) { - w.WriteHeader(status) - fmt.Fprintf(w, "%d %s", status, http.StatusText(status)) -} - -// setupClientAuth sets up TLS client authentication only if -// any of the TLS configs specified at least one cert file. -func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error { - whatClientAuth := tls.NoClientCert - for _, cfg := range tlsConfigs { - if whatClientAuth < cfg.ClientAuth { // Use the most restrictive. - whatClientAuth = cfg.ClientAuth - } - } - - if whatClientAuth != tls.NoClientCert { - pool := x509.NewCertPool() - for _, cfg := range tlsConfigs { - if len(cfg.ClientCerts) == 0 { - continue - } - for _, caFile := range cfg.ClientCerts { - caCrt, err := ioutil.ReadFile(caFile) // Anyone that gets a cert from this CA can connect - if err != nil { - return err - } - if !pool.AppendCertsFromPEM(caCrt) { - return fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile) - } - } - } - config.ClientCAs = pool - config.ClientAuth = whatClientAuth - } - - return nil -} - -var runTLSTicketKeyRotation = standaloneTLSTicketKeyRotation - -var setSessionTicketKeysTestHook = func(keys [][32]byte) [][32]byte { - return keys -} - -// standaloneTLSTicketKeyRotation governs over the array of TLS ticket keys used to de/crypt TLS tickets. -// It periodically sets a new ticket key as the first one, used to encrypt (and decrypt), -// pushing any old ticket keys to the back, where they are considered for decryption only. -// -// Lack of entropy for the very first ticket key results in the feature being disabled (as does Go), -// later lack of entropy temporarily disables ticket key rotation. -// Old ticket keys are still phased out, though. -// -// Stops the timer when returning. -func standaloneTLSTicketKeyRotation(c *tls.Config, timer *time.Ticker, exitChan chan struct{}) { - defer timer.Stop() - // The entire page should be marked as sticky, but Go cannot do that - // without resorting to syscall#Mlock. And, we don't have madvise (for NODUMP), too. ☹ - keys := make([][32]byte, 1, tlsNumTickets) - - rng := c.Rand - if rng == nil { - rng = rand.Reader - } - if _, err := io.ReadFull(rng, keys[0][:]); err != nil { - c.SessionTicketsDisabled = true // bail if we don't have the entropy for the first one - return - } - c.SessionTicketKey = keys[0] // SetSessionTicketKeys doesn't set a 'tls.keysAlreadSet' - c.SetSessionTicketKeys(setSessionTicketKeysTestHook(keys)) - - for { - select { - case _, isOpen := <-exitChan: - if !isOpen { - return - } - case <-timer.C: - rng = c.Rand // could've changed since the start - if rng == nil { - rng = rand.Reader - } - var newTicketKey [32]byte - _, err := io.ReadFull(rng, newTicketKey[:]) - - if len(keys) < tlsNumTickets { - keys = append(keys, keys[0]) // manipulates the internal length - } - for idx := len(keys) - 1; idx >= 1; idx-- { - keys[idx] = keys[idx-1] // yes, this makes copies - } - - if err == nil { - keys[0] = newTicketKey - } - // pushes the last key out, doesn't matter that we don't have a new one - c.SetSessionTicketKeys(setSessionTicketKeysTestHook(keys)) - } - } -} - -// RunFirstStartupFuncs runs all of the server's FirstStartup -// callback functions unless one of them returns an error first. -// It is the caller's responsibility to call this only once and -// at the correct time. The functions here should not be executed -// at restarts or where the user does not explicitly start a new -// instance of the server. -func (s *Server) RunFirstStartupFuncs() error { - for _, vh := range s.vhosts { - for _, f := range vh.config.FirstStartup { - if err := f(); err != nil { - return err - } - } - } - return nil -} - -// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted -// connections. It's used by ListenAndServe and ListenAndServeTLS so -// dead TCP connections (e.g. closing laptop mid-download) eventually -// go away. -// -// Borrowed from the Go standard library. -type tcpKeepAliveListener struct { - *net.TCPListener -} - -// Accept accepts the connection with a keep-alive enabled. -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { - tc, err := ln.AcceptTCP() - if err != nil { - return - } - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(3 * time.Minute) - return tc, nil -} - -// File implements ListenerFile; returns the underlying file of the listener. -func (ln tcpKeepAliveListener) File() (*os.File, error) { - return ln.TCPListener.File() -} - -// ShutdownCallbacks executes all the shutdown callbacks -// for all the virtualhosts in servers, and returns all the -// errors generated during their execution. In other words, -// an error executing one shutdown callback does not stop -// execution of others. Only one shutdown callback is executed -// at a time. You must protect the servers that are passed in -// if they are shared across threads. -func ShutdownCallbacks(servers []*Server) []error { - var errs []error - for _, s := range servers { - for _, vhost := range s.vhosts { - for _, shutdownFunc := range vhost.config.Shutdown { - err := shutdownFunc() - if err != nil { - errs = append(errs, err) - } - } - } - } - return errs -} diff --git a/server/server_test.go b/server/server_test.go deleted file mode 100644 index 08f1915bd..000000000 --- a/server/server_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package server - -import ( - "crypto/tls" - "testing" - "time" -) - -func TestStandaloneTLSTicketKeyRotation(t *testing.T) { - tlsGovChan := make(chan struct{}) - defer close(tlsGovChan) - callSync := make(chan bool, 1) - defer close(callSync) - - oldHook := setSessionTicketKeysTestHook - defer func() { - setSessionTicketKeysTestHook = oldHook - }() - var keysInUse [][32]byte - setSessionTicketKeysTestHook = func(keys [][32]byte) [][32]byte { - keysInUse = keys - callSync <- true - return keys - } - - c := new(tls.Config) - timer := time.NewTicker(time.Millisecond * 1) - - go standaloneTLSTicketKeyRotation(c, timer, tlsGovChan) - - rounds := 0 - var lastTicketKey [32]byte - for { - select { - case <-callSync: - if lastTicketKey == keysInUse[0] { - close(tlsGovChan) - t.Errorf("The same TLS ticket key has been used again (not rotated): %x.", lastTicketKey) - return - } - lastTicketKey = keysInUse[0] - rounds++ - if rounds <= tlsNumTickets && len(keysInUse) != rounds { - close(tlsGovChan) - t.Errorf("Expected TLS ticket keys in use: %d; Got instead: %d.", rounds, len(keysInUse)) - return - } - if c.SessionTicketsDisabled == true { - t.Error("Session tickets have been disabled unexpectedly.") - return - } - if rounds >= tlsNumTickets+1 { - return - } - case <-time.After(time.Second * 1): - t.Errorf("Timeout after %d rounds.", rounds) - return - } - } -} diff --git a/server/virtualhost.go b/server/virtualhost.go deleted file mode 100644 index 0f44cc68c..000000000 --- a/server/virtualhost.go +++ /dev/null @@ -1,35 +0,0 @@ -package server - -import ( - "net/http" - - "github.com/mholt/caddy/middleware" -) - -// virtualHost represents a virtual host/server. While a Server -// is what actually binds to the address, a user may want to serve -// multiple sites on a single address, and this is what a -// virtualHost allows us to do. -type virtualHost struct { - config Config - fileServer middleware.Handler - stack middleware.Handler -} - -// buildStack builds the server's middleware stack based -// on its config. This method should be called last before -// ListenAndServe begins. -func (vh *virtualHost) buildStack() error { - vh.fileServer = middleware.FileServer(http.Dir(vh.config.Root), []string{vh.config.ConfigFile}) - vh.compile(vh.config.Middleware) - return nil -} - -// compile is an elegant alternative to nesting middleware function -// calls like handler1(handler2(handler3(finalHandler))). -func (vh *virtualHost) compile(layers []middleware.Middleware) { - vh.stack = vh.fileServer // core app layer - for i := len(layers) - 1; i >= 0; i-- { - vh.stack = layers[i](vh.stack) - } -} diff --git a/caddy/sigtrap.go b/sigtrap.go similarity index 69% rename from caddy/sigtrap.go rename to sigtrap.go index 6fd00cac5..7acdbc491 100644 --- a/caddy/sigtrap.go +++ b/sigtrap.go @@ -5,15 +5,13 @@ import ( "os" "os/signal" "sync" - - "github.com/mholt/caddy/server" ) // TrapSignals create signal handlers for all applicable signals for this // system. If your Go program uses signals, this is a rather invasive // function; best to implement them yourself in that case. Signals are not // required for the caddy package to function properly, but this is a -// convenient way to allow the user to control this package of your program. +// convenient way to allow the user to control this part of your program. func TrapSignals() { trapSignalsCrossPlatform() trapSignalsPosix() @@ -54,10 +52,7 @@ func trapSignalsCrossPlatform() { // This function is idempotent; subsequent invocations always return 0. func executeShutdownCallbacks(signame string) (exitCode int) { shutdownCallbacksOnce.Do(func() { - serversMu.Lock() - errs := server.ShutdownCallbacks(servers) - serversMu.Unlock() - + errs := allShutdownCallbacks() if len(errs) > 0 { for _, err := range errs { log.Printf("[ERROR] %s shutdown: %v", signame, err) @@ -68,4 +63,21 @@ func executeShutdownCallbacks(signame string) (exitCode int) { return } +// allShutdownCallbacks executes all the shutdown callbacks +// for all the instances, and returns all the errors generated +// during their execution. An error executing one shutdown +// callback does not stop execution of others. Only one shutdown +// callback is executed at a time. +func allShutdownCallbacks() []error { + var errs []error + instancesMu.Lock() + for _, inst := range instances { + errs = append(errs, inst.shutdownCallbacks()...) + } + instancesMu.Unlock() + return errs +} + +// shutdownCallbacksOnce ensures that shutdown callbacks +// for all instances are only executed once. var shutdownCallbacksOnce sync.Once diff --git a/caddy/sigtrap_posix.go b/sigtrap_posix.go similarity index 62% rename from caddy/sigtrap_posix.go rename to sigtrap_posix.go index ac3000d76..9ee7bbba3 100644 --- a/caddy/sigtrap_posix.go +++ b/sigtrap_posix.go @@ -3,7 +3,6 @@ package caddy import ( - "io/ioutil" "log" "os" "os/signal" @@ -48,28 +47,34 @@ func trapSignalsPosix() { case syscall.SIGUSR1: log.Println("[INFO] SIGUSR1: Reloading") - var updatedCaddyfile Input - - caddyfileMu.Lock() - if caddyfile == nil { + // Start with the existing Caddyfile + instancesMu.Lock() + inst := instances[0] // we only support one instance at this time + instancesMu.Unlock() + updatedCaddyfile := inst.caddyfileInput + if updatedCaddyfile == nil { // Hmm, did spawing process forget to close stdin? Anyhow, this is unusual. log.Println("[ERROR] SIGUSR1: no Caddyfile to reload (was stdin left open?)") - caddyfileMu.Unlock() continue } - if caddyfile.IsFile() { - body, err := ioutil.ReadFile(caddyfile.Path()) - if err == nil { - updatedCaddyfile = CaddyfileInput{ - Filepath: caddyfile.Path(), - Contents: body, - RealFile: true, - } - } + if loaderUsed.loader == nil { + // This also should never happen + log.Println("[ERROR] SIGUSR1: no Caddyfile loader with which to reload Caddyfile") + continue } - caddyfileMu.Unlock() - err := Restart(updatedCaddyfile) + // Load the updated Caddyfile + newCaddyfile, err := loaderUsed.loader.Load(inst.serverType) + if err != nil { + log.Printf("[ERROR] SIGUSR1: loading updated Caddyfile: %v", err) + continue + } + if newCaddyfile != nil { + updatedCaddyfile = newCaddyfile + } + + // Kick off the restart; our work is done + inst, err = inst.Restart(updatedCaddyfile) if err != nil { log.Printf("[ERROR] SIGUSR1: %v", err) } diff --git a/caddy/sigtrap_windows.go b/sigtrap_windows.go similarity index 100% rename from caddy/sigtrap_windows.go rename to sigtrap_windows.go diff --git a/caddy/setup/startupshutdown.go b/startupshutdown/startupshutdown.go similarity index 51% rename from caddy/setup/startupshutdown.go rename to startupshutdown/startupshutdown.go index 7a21ef47a..911e1b015 100644 --- a/caddy/setup/startupshutdown.go +++ b/startupshutdown/startupshutdown.go @@ -1,27 +1,38 @@ -package setup +package startupshutdown import ( "os" "os/exec" "strings" - "github.com/mholt/caddy/middleware" + "github.com/mholt/caddy" ) -// Startup registers a startup callback to execute during server start. -func Startup(c *Controller) (middleware.Middleware, error) { - return nil, registerCallback(c, &c.FirstStartup) +func init() { + caddy.RegisterPlugin(caddy.Plugin{ + Name: "startup", + Action: Startup, + }) + caddy.RegisterPlugin(caddy.Plugin{ + Name: "shutdown", + Action: Shutdown, + }) } -// Shutdown registers a shutdown callback to execute during process exit. -func Shutdown(c *Controller) (middleware.Middleware, error) { - return nil, registerCallback(c, &c.Shutdown) +// Startup registers a startup callback to execute during server start. +func Startup(c *caddy.Controller) error { + return registerCallback(c, c.OnStartup) +} + +// Shutdown registers a shutdown callback to execute during server stop. +func Shutdown(c *caddy.Controller) error { + return registerCallback(c, c.OnShutdown) } // registerCallback registers a callback function to execute by -// using c to parse the line. It appends the callback function -// to the list of callback functions passed in by reference. -func registerCallback(c *Controller, list *[]func() error) error { +// using c to parse the directive. It registers the callback +// to be executed using registerFunc. +func registerCallback(c *caddy.Controller, registerFunc func(func() error)) error { var funcs []func() error for c.Next() { @@ -37,7 +48,7 @@ func registerCallback(c *Controller, list *[]func() error) error { args = args[:len(args)-1] } - command, args, err := middleware.SplitCommandAndArgs(strings.Join(args, " ")) + command, args, err := caddy.SplitCommandAndArgs(strings.Join(args, " ")) if err != nil { return c.Err(err.Error()) } @@ -47,7 +58,6 @@ func registerCallback(c *Controller, list *[]func() error) error { cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - if nonblock { return cmd.Start() } @@ -58,7 +68,9 @@ func registerCallback(c *Controller, list *[]func() error) error { } return c.OncePerServerBlock(func() error { - *list = append(*list, funcs...) + for _, fn := range funcs { + registerFunc(fn) + } return nil }) } diff --git a/caddy/setup/startupshutdown_test.go b/startupshutdown/startupshutdown_test.go similarity index 78% rename from caddy/setup/startupshutdown_test.go rename to startupshutdown/startupshutdown_test.go index 871a64214..8bc98f9ab 100644 --- a/caddy/setup/startupshutdown_test.go +++ b/startupshutdown/startupshutdown_test.go @@ -1,4 +1,4 @@ -package setup +package startupshutdown import ( "os" @@ -6,16 +6,15 @@ import ( "strconv" "testing" "time" + + "github.com/mholt/caddy" ) // The Startup function's tests are symmetrical to Shutdown tests, // because the Startup and Shutdown functions share virtually the // same functionality func TestStartup(t *testing.T) { - tempDirPath, err := getTempDirPath() - if err != nil { - t.Fatalf("BeforeTest: Failed to find an existing directory for testing! Error was: %v", err) - } + tempDirPath := os.TempDir() testDir := filepath.Join(tempDirPath, "temp_dir_for_testing_startupshutdown") defer func() { @@ -26,6 +25,11 @@ func TestStartup(t *testing.T) { osSenitiveTestDir := filepath.FromSlash(testDir) os.RemoveAll(osSenitiveTestDir) // start with a clean slate + var registeredFunction func() error + fakeRegister := func(fn func() error) { + registeredFunction = fn + } + tests := []struct { input string shouldExecutionErr bool @@ -42,12 +46,15 @@ func TestStartup(t *testing.T) { } for i, test := range tests { - c := NewTestController(test.input) - _, err = Startup(c) + c := caddy.NewTestController(test.input) + err := registerCallback(c, fakeRegister) if err != nil { t.Errorf("Expected no errors, got: %v", err) } - err = c.FirstStartup[0]() + if registeredFunction == nil { + t.Fatalf("Expected function to be registered, but it wasn't") + } + err = registeredFunction() if err != nil && !test.shouldExecutionErr { t.Errorf("Test %d recieved an error of:\n%v", i, err) }