mirror of
https://github.com/caddyserver/caddy.git
synced 2024-12-27 22:23:48 +03:00
Merge branch 'master' into telemetry
# Conflicts: # caddy/caddymain/run.go # caddyhttp/httpserver/plugin.go # caddytls/client.go
This commit is contained in:
commit
b019501b8b
228 changed files with 10317 additions and 5964 deletions
2
.github/CONTRIBUTING.md
vendored
2
.github/CONTRIBUTING.md
vendored
|
@ -103,7 +103,7 @@ While we really do value your requests and implement many of them, not all featu
|
|||
|
||||
### Improving documentation
|
||||
|
||||
Caddy's documentation is available at [https://caddyserver.com/docs](https://caddyserver.com/docs). If you would like to make a fix to the docs, feel free to contribute at the [caddyserver/website](https://github.com/caddyserver/website) repository!
|
||||
Caddy's documentation is available at [https://caddyserver.com/docs](https://caddyserver.com/docs). If you would like to make a fix to the docs, please submit an issue here describing the change to make.
|
||||
|
||||
Note that plugin documentation is not hosted by the Caddy website, other than basic usage examples. They are managed by the individual plugin authors, and you will have to contact them to change their documentation.
|
||||
|
||||
|
|
34
README.md
34
README.md
|
@ -1,5 +1,5 @@
|
|||
<p align="center">
|
||||
<a href="https://caddyserver.com"><img src="https://user-images.githubusercontent.com/1128849/36137292-bebc223a-1051-11e8-9a81-4ea9054c96ac.png" alt="Caddy" width="400"></a>
|
||||
<a href="https://caddyserver.com"><img src="https://user-images.githubusercontent.com/1128849/36338535-05fb646a-136f-11e8-987b-e6901e717d5a.png" alt="Caddy" width="450"></a>
|
||||
</p>
|
||||
<h3 align="center">Every Site on HTTPS <!-- Serve Confidently --></h3>
|
||||
<p align="center">Caddy is a general-purpose HTTP/2 web server that serves HTTPS by default.</p>
|
||||
|
@ -21,7 +21,7 @@
|
|||
|
||||
---
|
||||
|
||||
Caddy is fast, easy to use, and makes you more productive.
|
||||
Caddy is a **production-ready** open-source web server that is fast, easy to use, and makes you more productive.
|
||||
|
||||
Available for Windows, Mac, Linux, BSD, Solaris, and [Android](https://github.com/mholt/caddy/wiki/Running-Caddy-on-Android).
|
||||
|
||||
|
@ -41,31 +41,35 @@ Available for Windows, Mac, Linux, BSD, Solaris, and [Android](https://github.co
|
|||
- **Automatic HTTPS** on by default (via [Let's Encrypt](https://letsencrypt.org))
|
||||
- **HTTP/2** by default
|
||||
- **Virtual hosting** so multiple sites just work
|
||||
- Experimental **QUIC support** for those that like speed
|
||||
- Experimental **QUIC support** for cutting-edge transmissions
|
||||
- TLS session ticket **key rotation** for more secure connections
|
||||
- **Extensible with plugins** because a convenient web server is a helpful one
|
||||
- **Runs anywhere** with **no external dependencies** (not even libc)
|
||||
|
||||
There's way more, too! [See all features built into Caddy.](https://caddyserver.com/features) On top of all those, Caddy does even more with plugins: choose which plugins you want at [download](https://caddyserver.com/download).
|
||||
[See a more complete list of features built into Caddy.](https://caddyserver.com/features) On top of all those, Caddy does even more with plugins: choose which plugins you want at [download](https://caddyserver.com/download).
|
||||
|
||||
Altogether, Caddy can do things other web servers simply cannot do. Its features and plugins save you time and mistakes, and will cheer you up. Your Caddy instance takes care of the details for you!
|
||||
|
||||
|
||||
## Install
|
||||
|
||||
Caddy binaries have no dependencies and are available for every platform. Get Caddy any one of these ways:
|
||||
Caddy binaries have no dependencies and are available for every platform. Get Caddy either of these ways:
|
||||
|
||||
- **[Download page](https://caddyserver.com/download)** (RECOMMENDED) allows you to customize your build in the browser
|
||||
- **[Latest release](https://github.com/mholt/caddy/releases/latest)** for pre-built, vanilla binaries
|
||||
|
||||
- **[Download page](https://caddyserver.com/download)** allows you to
|
||||
customize your build in the browser
|
||||
- **[Latest release](https://github.com/mholt/caddy/releases/latest)** for
|
||||
pre-built, vanilla binaries
|
||||
|
||||
## Build
|
||||
|
||||
To build from source you need **[Git](https://git-scm.com/downloads)** and **[Go](https://golang.org/doc/install)** (1.9 or newer). Follow these instruction for fast building:
|
||||
|
||||
- Get source `go get github.com/mholt/caddy/caddy` and then run `go get github.com/caddyserver/builds`
|
||||
- Now `cd` to `$GOPATH/src/github.com/mholt/caddy/caddy` and run `go run build.go`
|
||||
- Get the source with `go get github.com/mholt/caddy/caddy` and then run `go get github.com/caddyserver/builds`
|
||||
- Now `cd $GOPATH/src/github.com/mholt/caddy/caddy` and run `go run build.go`
|
||||
|
||||
Then make sure the `caddy` binary is in your PATH.
|
||||
|
||||
To build for other platforms, use build.go with the `--goos` and `--goarch` flags.
|
||||
|
||||
|
||||
## Quick Start
|
||||
|
||||
|
@ -85,7 +89,7 @@ If the `caddy` binary has permission to bind to low ports and your domain name's
|
|||
caddy -host example.com
|
||||
```
|
||||
|
||||
This command serves static files from the current directory over HTTPS. Certificates are automatically obtained and renewed for you!
|
||||
This command serves static files from the current directory over HTTPS. Certificates are automatically obtained and renewed for you! Caddy is also automatically configuring ports 80 and 443 for you, and redirecting HTTP to HTTPS. Cool, huh?
|
||||
|
||||
### Customizing your site
|
||||
|
||||
|
@ -115,7 +119,7 @@ To host multiple sites and do more with the Caddyfile, please see the [Caddyfile
|
|||
|
||||
Sites with qualifying hostnames are served over [HTTPS by default](https://caddyserver.com/docs/automatic-https).
|
||||
|
||||
Caddy has a command line interface. Run `caddy -h` to view basic help or see the [CLI documentation](https://caddyserver.com/docs/cli) for details.
|
||||
Caddy has a nice little command line interface. Run `caddy -h` to view basic help or see the [CLI documentation](https://caddyserver.com/docs/cli) for details.
|
||||
|
||||
|
||||
## Running in Production
|
||||
|
@ -139,7 +143,7 @@ Please see our [contributing guidelines](https://github.com/mholt/caddy/blob/mas
|
|||
|
||||
We use GitHub issues and pull requests only for discussing bug reports and the development of specific changes. We welcome all other topics on the [forum](https://caddy.community)!
|
||||
|
||||
If you want to contribute to the documentation, please submit pull requests to [caddyserver/website](https://github.com/caddyserver/website).
|
||||
If you want to contribute to the documentation, please [submit an issue](https://github.com/mholt/caddy/issues/new) describing the change that should be made.
|
||||
|
||||
Thanks for making Caddy -- and the Web -- better!
|
||||
|
||||
|
@ -158,6 +162,6 @@ We thank them for their services. **If you want to help keep Caddy free, please
|
|||
Caddy was born out of the need for a "batteries-included" web server that runs anywhere and doesn't have to take its configuration with it. Caddy took inspiration from [spark](https://github.com/rif/spark), [nginx](https://github.com/nginx/nginx), lighttpd,
|
||||
[Websocketd](https://github.com/joewalnes/websocketd) and [Vagrant](https://www.vagrantup.com/), which provides a pleasant mixture of features from each of them.
|
||||
|
||||
**The name "Caddy":** The name of the software is "Caddy", not "Caddy Server" or "CaddyServer". Please call it "Caddy" or, if you wish to clarify, "the Caddy web server". See [brand guidelines](https://caddyserver.com/brand).
|
||||
**The name "Caddy" is trademarked:** The name of the software is "Caddy", not "Caddy Server" or "CaddyServer". Please call it "Caddy" or, if you wish to clarify, "the Caddy web server". See [brand guidelines](https://caddyserver.com/brand). Caddy is a registered trademark of Light Code Labs, LLC.
|
||||
|
||||
*Author on Twitter: [@mholt6](https://twitter.com/mholt6)*
|
||||
|
|
2
caddy.go
2
caddy.go
|
@ -802,7 +802,7 @@ func startServers(serverList []Server, inst *Instance, restartFds map[string]res
|
|||
continue
|
||||
}
|
||||
if strings.Contains(err.Error(), "use of closed network connection") {
|
||||
// this error is normal when closing the listener
|
||||
// this error is normal when closing the listener; see https://github.com/golang/go/issues/4373
|
||||
continue
|
||||
}
|
||||
log.Println(err)
|
||||
|
|
|
@ -31,7 +31,7 @@ import (
|
|||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddytls"
|
||||
"github.com/mholt/caddy/telemetry"
|
||||
"github.com/xenolf/lego/acme"
|
||||
"github.com/xenolf/lego/acmev2"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
|
||||
_ "github.com/mholt/caddy/caddyhttp" // plug in the HTTP server type
|
||||
|
@ -43,7 +43,7 @@ func init() {
|
|||
setVersion()
|
||||
|
||||
flag.BoolVar(&caddytls.Agreed, "agree", false, "Agree to the CA's Subscriber Agreement")
|
||||
flag.StringVar(&caddytls.DefaultCAUrl, "ca", "https://acme-v01.api.letsencrypt.org/directory", "URL to certificate authority's ACME server directory")
|
||||
flag.StringVar(&caddytls.DefaultCAUrl, "ca", "https://acme-v02.api.letsencrypt.org/directory", "URL to certificate authority's ACME server directory")
|
||||
flag.BoolVar(&caddytls.DisableHTTPChallenge, "disable-http-challenge", caddytls.DisableHTTPChallenge, "Disable the ACME HTTP challenge")
|
||||
flag.BoolVar(&caddytls.DisableTLSSNIChallenge, "disable-tls-sni-challenge", caddytls.DisableTLSSNIChallenge, "Disable the ACME TLS-SNI challenge")
|
||||
flag.StringVar(&disabledMetrics, "disabled-metrics", "", "Comma-separated list of telemetry metrics to disable")
|
||||
|
|
|
@ -265,14 +265,19 @@ func (p *parser) doImport() error {
|
|||
} else {
|
||||
globPattern = importPattern
|
||||
}
|
||||
if strings.Count(globPattern, "*") > 1 || strings.Count(globPattern, "?") > 1 ||
|
||||
(strings.Contains(globPattern, "[") && strings.Contains(globPattern, "]")) {
|
||||
// See issue #2096 - a pattern with many glob expansions can hang for too long
|
||||
return p.Errf("Glob pattern may only contain one wildcard (*), but has others: %s", globPattern)
|
||||
}
|
||||
matches, err = filepath.Glob(globPattern)
|
||||
|
||||
if err != nil {
|
||||
return p.Errf("Failed to use import pattern %s: %v", importPattern, err)
|
||||
}
|
||||
if len(matches) == 0 {
|
||||
if strings.Contains(globPattern, "*") {
|
||||
log.Printf("[WARNING] No files matching import pattern: %s", importPattern)
|
||||
if strings.ContainsAny(globPattern, "*?[]") {
|
||||
log.Printf("[WARNING] No files matching import glob pattern: %s", importPattern)
|
||||
} else {
|
||||
return p.Errf("File to import not found: %s", importPattern)
|
||||
}
|
||||
|
@ -443,7 +448,7 @@ func replaceEnvReferences(s, refStart, refEnd string) string {
|
|||
index := strings.Index(s, refStart)
|
||||
for index != -1 {
|
||||
endIndex := strings.Index(s, refEnd)
|
||||
if endIndex != -1 {
|
||||
if endIndex > index+len(refStart) {
|
||||
ref := s[index : endIndex+len(refEnd)]
|
||||
s = strings.Replace(s, ref, os.Getenv(ref[len(refStart):len(ref)-len(refEnd)]), -1)
|
||||
} else {
|
||||
|
|
|
@ -228,6 +228,17 @@ func TestParseOneAndImport(t *testing.T) {
|
|||
{`""`, false, []string{}, map[string]int{}},
|
||||
|
||||
{``, false, []string{}, map[string]int{}},
|
||||
|
||||
// test cases found by fuzzing!
|
||||
{`import }{$"`, true, []string{}, map[string]int{}},
|
||||
{`import /*/*.txt`, true, []string{}, map[string]int{}},
|
||||
{`import /???/?*?o`, true, []string{}, map[string]int{}},
|
||||
{`import /??`, true, []string{}, map[string]int{}},
|
||||
{`import /[a-z]`, true, []string{}, map[string]int{}},
|
||||
{`import {$}`, true, []string{}, map[string]int{}},
|
||||
{`import {%}`, true, []string{}, map[string]int{}},
|
||||
{`import {$$}`, true, []string{}, map[string]int{}},
|
||||
{`import {%%}`, true, []string{}, map[string]int{}},
|
||||
} {
|
||||
result, err := testParseOne(test.input)
|
||||
|
||||
|
|
|
@ -46,5 +46,4 @@ import (
|
|||
_ "github.com/mholt/caddy/caddyhttp/timeouts"
|
||||
_ "github.com/mholt/caddy/caddyhttp/websocket"
|
||||
_ "github.com/mholt/caddy/onevent"
|
||||
_ "github.com/mholt/caddy/startupshutdown"
|
||||
)
|
||||
|
|
|
@ -25,7 +25,7 @@ import (
|
|||
// ensure that the standard plugins are in fact plugged in
|
||||
// and registered properly; this is a quick/naive way to do it.
|
||||
func TestStandardPlugins(t *testing.T) {
|
||||
numStandardPlugins := 33 // importing caddyhttp plugs in this many plugins
|
||||
numStandardPlugins := 31 // importing caddyhttp plugs in this many plugins
|
||||
s := caddy.DescribePlugins()
|
||||
if got, want := strings.Count(s, "\n"), numStandardPlugins+5; got != want {
|
||||
t.Errorf("Expected all standard plugins to be plugged in, got:\n%s", s)
|
||||
|
|
|
@ -33,8 +33,11 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
"github.com/mholt/caddy/caddytls"
|
||||
)
|
||||
|
||||
// Handler is a middleware type that can handle requests as a FastCGI client.
|
||||
|
@ -323,6 +326,19 @@ func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string]
|
|||
// Some web apps rely on knowing HTTPS or not
|
||||
if r.TLS != nil {
|
||||
env["HTTPS"] = "on"
|
||||
// and pass the protocol details in a manner compatible with apache's mod_ssl
|
||||
// (which is why they have a SSL_ prefix and not TLS_).
|
||||
v, ok := tlsProtocolStringToMap[r.TLS.Version]
|
||||
if ok {
|
||||
env["SSL_PROTOCOL"] = v
|
||||
}
|
||||
// and pass the cipher suite in a manner compatible with apache's mod_ssl
|
||||
for k, v := range caddytls.SupportedCiphersMap {
|
||||
if v == r.TLS.CipherSuite {
|
||||
env["SSL_CIPHER"] = k
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add env variables from config (with support for placeholders in values)
|
||||
|
@ -465,3 +481,11 @@ type LogError string
|
|||
func (l LogError) Error() string {
|
||||
return string(l)
|
||||
}
|
||||
|
||||
// Map of supported protocols to Apache ssl_mod format
|
||||
// Note that these are slightly different from SupportedProtocols in caddytls/config.go's
|
||||
var tlsProtocolStringToMap = map[uint16]string{
|
||||
tls.VersionTLS10: "TLSv1",
|
||||
tls.VersionTLS11: "TLSv1.1",
|
||||
tls.VersionTLS12: "TLSv1.2",
|
||||
}
|
||||
|
|
|
@ -100,8 +100,8 @@ func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error {
|
|||
}
|
||||
cfg.TLS.Enabled = true
|
||||
cfg.Addr.Scheme = "https"
|
||||
if loadCertificates && caddytls.HostQualifies(cfg.Addr.Host) {
|
||||
_, err := cfg.TLS.CacheManagedCertificate(cfg.Addr.Host)
|
||||
if loadCertificates && caddytls.HostQualifies(cfg.TLS.Hostname) {
|
||||
_, err := cfg.TLS.CacheManagedCertificate(cfg.TLS.Hostname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@ type Logger struct {
|
|||
V4ipMask net.IPMask
|
||||
V6ipMask net.IPMask
|
||||
IPMaskExists bool
|
||||
Exceptions []string
|
||||
}
|
||||
|
||||
// NewTestLogger creates logger suitable for testing purposes
|
||||
|
@ -84,6 +85,17 @@ func (l Logger) MaskIP(ip string) string {
|
|||
|
||||
}
|
||||
|
||||
// ShouldLog returns true if the path is not exempted from
|
||||
// being logged (i.e. it is not found in l.Exceptions).
|
||||
func (l Logger) ShouldLog(path string) bool {
|
||||
for _, exc := range l.Exceptions {
|
||||
if Path(path).Matches(exc) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Attach binds logger Start and Close functions to
|
||||
// controller's OnStartup and OnShutdown hooks.
|
||||
func (l *Logger) Attach(controller *caddy.Controller) {
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
package httpserver
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
|
@ -123,15 +124,17 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
|
|||
// For each address in each server block, make a new config
|
||||
for _, sb := range serverBlocks {
|
||||
for _, key := range sb.Keys {
|
||||
key = strings.ToLower(key)
|
||||
if _, dup := h.keysToSiteConfigs[key]; dup {
|
||||
return serverBlocks, fmt.Errorf("duplicate site key: %s", key)
|
||||
}
|
||||
addr, err := standardizeAddress(key)
|
||||
if err != nil {
|
||||
return serverBlocks, err
|
||||
}
|
||||
|
||||
addr = addr.Normalize()
|
||||
key = addr.Key()
|
||||
if _, dup := h.keysToSiteConfigs[key]; dup {
|
||||
return serverBlocks, fmt.Errorf("duplicate site key: %s", key)
|
||||
}
|
||||
|
||||
// Fill in address components from command line so that middleware
|
||||
// have access to the correct information during setup
|
||||
if addr.Host == "" && Host != DefaultHost {
|
||||
|
@ -146,7 +149,7 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
|
|||
if addrCopy.Port == "" && Port == DefaultPort {
|
||||
addrCopy.Port = Port
|
||||
}
|
||||
addrStr := strings.ToLower(addrCopy.String())
|
||||
addrStr := addrCopy.String()
|
||||
if otherSiteKey, dup := siteAddrs[addrStr]; dup {
|
||||
err := fmt.Errorf("duplicate site address: %s", addrStr)
|
||||
if (addrCopy.Host == Host && Host != DefaultHost) ||
|
||||
|
@ -218,6 +221,13 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Iterate each site configuration and make sure that:
|
||||
// 1) TLS is disabled for explicitly-HTTP sites (necessary
|
||||
// when an HTTP address shares a block containing tls)
|
||||
// 2) if QUIC is enabled, TLS ClientAuth is not, because
|
||||
// currently, QUIC does not support ClientAuth (TODO:
|
||||
// revisit this when our QUIC implementation supports it)
|
||||
// 3) if TLS ClientAuth is used, StrictHostMatching is on
|
||||
var atLeastOneSiteLooksLikeProduction bool
|
||||
for _, cfg := range h.siteConfigs {
|
||||
// see if all the addresses (both sites and
|
||||
|
@ -254,6 +264,17 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) {
|
|||
// instead of 443 because it doesn't know about TLS.
|
||||
cfg.Addr.Port = HTTPSPort
|
||||
}
|
||||
if cfg.TLS.ClientAuth != tls.NoClientCert {
|
||||
if QUIC {
|
||||
return nil, fmt.Errorf("cannot enable TLS client authentication with QUIC, because QUIC does not yet support it")
|
||||
}
|
||||
// this must be enabled so that a client cannot connect
|
||||
// using SNI for another site on this listener that
|
||||
// does NOT require ClientAuth, and then send HTTP
|
||||
// requests with the Host header of this site which DOES
|
||||
// require client auth, thus bypassing it...
|
||||
cfg.StrictHostMatching = true
|
||||
}
|
||||
}
|
||||
|
||||
// we must map (group) each config to a bind address
|
||||
|
@ -287,12 +308,22 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) {
|
|||
return servers, nil
|
||||
}
|
||||
|
||||
// normalizedKey returns "normalized" key representation:
|
||||
// scheme and host names are lowered, everything else stays the same
|
||||
func normalizedKey(key string) string {
|
||||
addr, err := standardizeAddress(key)
|
||||
if err != nil {
|
||||
return key
|
||||
}
|
||||
return addr.Normalize().Key()
|
||||
}
|
||||
|
||||
// GetConfig gets the SiteConfig that corresponds to c.
|
||||
// If none exist (should only happen in tests), then a
|
||||
// new, empty one will be created.
|
||||
func GetConfig(c *caddy.Controller) *SiteConfig {
|
||||
ctx := c.Context().(*httpContext)
|
||||
key := strings.ToLower(c.Key)
|
||||
key := normalizedKey(c.Key)
|
||||
if cfg, ok := ctx.keysToSiteConfigs[key]; ok {
|
||||
return cfg
|
||||
}
|
||||
|
@ -396,6 +427,43 @@ func (a Address) VHost() string {
|
|||
return a.Original
|
||||
}
|
||||
|
||||
// Normalize normalizes URL: turn scheme and host names into lower case
|
||||
func (a Address) Normalize() Address {
|
||||
path := a.Path
|
||||
if !CaseSensitivePath {
|
||||
path = strings.ToLower(path)
|
||||
}
|
||||
return Address{
|
||||
Original: a.Original,
|
||||
Scheme: strings.ToLower(a.Scheme),
|
||||
Host: strings.ToLower(a.Host),
|
||||
Port: a.Port,
|
||||
Path: path,
|
||||
}
|
||||
}
|
||||
|
||||
// Key is similar to String, just replaces scheme and host values with modified values.
|
||||
// Unlike String it doesn't add anything default (scheme, port, etc)
|
||||
func (a Address) Key() string {
|
||||
res := ""
|
||||
if a.Scheme != "" {
|
||||
res += a.Scheme + "://"
|
||||
}
|
||||
if a.Host != "" {
|
||||
res += a.Host
|
||||
}
|
||||
if a.Port != "" {
|
||||
if strings.HasPrefix(a.Original[len(res):], ":"+a.Port) {
|
||||
// insert port only if the original has its own explicit port
|
||||
res += ":" + a.Port
|
||||
}
|
||||
}
|
||||
if a.Path != "" {
|
||||
res += a.Path
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// standardizeAddress parses an address string into a structured format with separate
|
||||
// scheme, host, port, and path portions, as well as the original input string.
|
||||
func standardizeAddress(str string) (Address, error) {
|
||||
|
@ -523,6 +591,7 @@ var directives = []string{
|
|||
"startup", // TODO: Deprecate this directive
|
||||
"shutdown", // TODO: Deprecate this directive
|
||||
"on",
|
||||
"supervisor", // github.com/lucaslorentz/caddy-supervisor
|
||||
"request_id",
|
||||
"realip", // github.com/captncraig/caddy-realip
|
||||
"git", // github.com/abiosoft/caddy-git
|
||||
|
@ -538,13 +607,13 @@ var directives = []string{
|
|||
"ext",
|
||||
"gzip",
|
||||
"header",
|
||||
"geoip", // github.com/kodnaplakal/caddy-geoip
|
||||
"errors",
|
||||
"authz", // github.com/casbin/caddy-authz
|
||||
"filter", // github.com/echocat/caddy-filter
|
||||
"minify", // github.com/hacdias/caddy-minify
|
||||
"ipfilter", // github.com/pyed/ipfilter
|
||||
"ratelimit", // github.com/xuqingfeng/caddy-rate-limit
|
||||
"search", // github.com/pedronasser/caddy-search
|
||||
"expires", // github.com/epicagency/caddy-expires
|
||||
"forwardproxy", // github.com/caddyserver/forwardproxy
|
||||
"basicauth",
|
||||
|
|
|
@ -18,6 +18,10 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"sort"
|
||||
|
||||
"fmt"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyfile"
|
||||
)
|
||||
|
@ -147,7 +151,20 @@ func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("Didn't expect an error, but got: %v", err)
|
||||
}
|
||||
addr := ctx.keysToSiteConfigs["localhost"].Addr
|
||||
localhostKey := "localhost"
|
||||
item, ok := ctx.keysToSiteConfigs[localhostKey]
|
||||
if !ok {
|
||||
availableKeys := make(sort.StringSlice, len(ctx.keysToSiteConfigs))
|
||||
i := 0
|
||||
for key := range ctx.keysToSiteConfigs {
|
||||
availableKeys[i] = fmt.Sprintf("'%s'", key)
|
||||
i++
|
||||
}
|
||||
availableKeys.Sort()
|
||||
t.Errorf("`%s` not found within registered keys, only these are available: %s", localhostKey, strings.Join(availableKeys, ", "))
|
||||
return
|
||||
}
|
||||
addr := item.Addr
|
||||
if addr.Port != Port {
|
||||
t.Errorf("Expected the port on the address to be set, but got: %#v", addr)
|
||||
}
|
||||
|
@ -184,6 +201,64 @@ func TestInspectServerBlocksCaseInsensitiveKey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestKeyNormalization(t *testing.T) {
|
||||
originalCaseSensitivePath := CaseSensitivePath
|
||||
defer func() {
|
||||
CaseSensitivePath = originalCaseSensitivePath
|
||||
}()
|
||||
CaseSensitivePath = true
|
||||
|
||||
caseSensitiveData := []struct {
|
||||
orig string
|
||||
res string
|
||||
}{
|
||||
{
|
||||
orig: "HTTP://A/ABCDEF",
|
||||
res: "http://a/ABCDEF",
|
||||
},
|
||||
{
|
||||
orig: "A/ABCDEF",
|
||||
res: "a/ABCDEF",
|
||||
},
|
||||
{
|
||||
orig: "A:2015/Port",
|
||||
res: "a:2015/Port",
|
||||
},
|
||||
}
|
||||
for _, item := range caseSensitiveData {
|
||||
v := normalizedKey(item.orig)
|
||||
if v != item.res {
|
||||
t.Errorf("Normalization of `%s` with CaseSensitivePath option set to true must be equal to `%s`, got `%s` instead", item.orig, item.res, v)
|
||||
}
|
||||
}
|
||||
|
||||
CaseSensitivePath = false
|
||||
caseInsensitiveData := []struct {
|
||||
orig string
|
||||
res string
|
||||
}{
|
||||
{
|
||||
orig: "HTTP://A/ABCDEF",
|
||||
res: "http://a/abcdef",
|
||||
},
|
||||
{
|
||||
orig: "A/ABCDEF",
|
||||
res: "a/abcdef",
|
||||
},
|
||||
{
|
||||
orig: "A:2015/Port",
|
||||
res: "a:2015/port",
|
||||
},
|
||||
}
|
||||
for _, item := range caseInsensitiveData {
|
||||
v := normalizedKey(item.orig)
|
||||
if v != item.res {
|
||||
t.Errorf("Normalization of `%s` with CaseSensitivePath option set to false must be equal to `%s`, got `%s` instead", item.orig, item.res, v)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestGetConfig(t *testing.T) {
|
||||
// case insensitivity for key
|
||||
con := caddy.NewTestController("http", "")
|
||||
|
@ -201,6 +276,14 @@ func TestGetConfig(t *testing.T) {
|
|||
if cfg == cfg3 {
|
||||
t.Errorf("Expected different configs using when key is different; got %p and %p", cfg, cfg3)
|
||||
}
|
||||
|
||||
con.Key = "foo/foobar"
|
||||
cfg4 := GetConfig(con)
|
||||
con.Key = "foo/Foobar"
|
||||
cfg5 := GetConfig(con)
|
||||
if cfg4 == cfg5 {
|
||||
t.Errorf("Expected different cases in path to differentiate keys in general")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirectivesList(t *testing.T) {
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddytls"
|
||||
)
|
||||
|
||||
// requestReplacer is a strings.Replacer which is used to
|
||||
|
@ -140,6 +141,14 @@ func canLogRequest(r *http.Request) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// unescapeBraces finds escaped braces in s and returns
|
||||
// a string with those braces unescaped.
|
||||
func unescapeBraces(s string) string {
|
||||
s = strings.Replace(s, "\\{", "{", -1)
|
||||
s = strings.Replace(s, "\\}", "}", -1)
|
||||
return s
|
||||
}
|
||||
|
||||
// Replace performs a replacement of values on s and returns
|
||||
// the string with the replaced values.
|
||||
func (r *replacer) Replace(s string) string {
|
||||
|
@ -149,32 +158,59 @@ func (r *replacer) Replace(s string) string {
|
|||
}
|
||||
|
||||
result := ""
|
||||
Placeholders: // process each placeholder in sequence
|
||||
for {
|
||||
idxStart := strings.Index(s, "{")
|
||||
if idxStart == -1 {
|
||||
// no placeholder anymore
|
||||
break
|
||||
}
|
||||
idxEnd := strings.Index(s[idxStart:], "}")
|
||||
if idxEnd == -1 {
|
||||
// unpaired placeholder
|
||||
break
|
||||
}
|
||||
idxEnd += idxStart
|
||||
var idxStart, idxEnd int
|
||||
|
||||
// get a replacement
|
||||
placeholder := s[idxStart : idxEnd+1]
|
||||
idxOffset := 0
|
||||
for { // find first unescaped opening brace
|
||||
searchSpace := s[idxOffset:]
|
||||
idxStart = strings.Index(searchSpace, "{")
|
||||
if idxStart == -1 {
|
||||
// no more placeholders
|
||||
break Placeholders
|
||||
}
|
||||
if idxStart == 0 || searchSpace[idxStart-1] != '\\' {
|
||||
// preceding character is not an escape
|
||||
idxStart += idxOffset
|
||||
break
|
||||
}
|
||||
// the brace we found was escaped
|
||||
// search the rest of the string next
|
||||
idxOffset += idxStart + 1
|
||||
}
|
||||
|
||||
idxOffset = 0
|
||||
for { // find first unescaped closing brace
|
||||
searchSpace := s[idxStart+idxOffset:]
|
||||
idxEnd = strings.Index(searchSpace, "}")
|
||||
if idxEnd == -1 {
|
||||
// unpaired placeholder
|
||||
break Placeholders
|
||||
}
|
||||
if idxEnd == 0 || searchSpace[idxEnd-1] != '\\' {
|
||||
// preceding character is not an escape
|
||||
idxEnd += idxOffset + idxStart
|
||||
break
|
||||
}
|
||||
// the brace we found was escaped
|
||||
// search the rest of the string next
|
||||
idxOffset += idxEnd + 1
|
||||
}
|
||||
|
||||
// get a replacement for the unescaped placeholder
|
||||
placeholder := unescapeBraces(s[idxStart : idxEnd+1])
|
||||
replacement := r.getSubstitution(placeholder)
|
||||
|
||||
// append prefix + replacement
|
||||
result += s[:idxStart] + replacement
|
||||
// append unescaped prefix + replacement
|
||||
result += strings.TrimPrefix(unescapeBraces(s[:idxStart]), "\\") + replacement
|
||||
|
||||
// strip out scanned parts
|
||||
s = s[idxEnd+1:]
|
||||
}
|
||||
|
||||
// append unscanned parts
|
||||
return result + s
|
||||
return result + unescapeBraces(s)
|
||||
}
|
||||
|
||||
func roundDuration(d time.Duration) time.Duration {
|
||||
|
@ -224,6 +260,16 @@ func (r *replacer) getSubstitution(key string) string {
|
|||
}
|
||||
}
|
||||
}
|
||||
// search response headers then
|
||||
if r.responseRecorder != nil && key[1] == '<' {
|
||||
want := key[2 : len(key)-1]
|
||||
for key, values := range r.responseRecorder.Header() {
|
||||
// Header placeholders (case-insensitive)
|
||||
if strings.EqualFold(key, want) {
|
||||
return strings.Join(values, ",")
|
||||
}
|
||||
}
|
||||
}
|
||||
// next check for cookies
|
||||
if key[1] == '~' {
|
||||
name := key[2 : len(key)-1]
|
||||
|
@ -365,12 +411,46 @@ func (r *replacer) getSubstitution(key string) string {
|
|||
}
|
||||
elapsedDuration := time.Since(r.responseRecorder.start)
|
||||
return strconv.FormatInt(convertToMilliseconds(elapsedDuration), 10)
|
||||
case "{tls_protocol}":
|
||||
if r.request.TLS != nil {
|
||||
for k, v := range caddytls.SupportedProtocols {
|
||||
if v == r.request.TLS.Version {
|
||||
return k
|
||||
}
|
||||
}
|
||||
return "tls" // this should never happen, but guard in case
|
||||
}
|
||||
return r.emptyValue // because not using a secure channel
|
||||
case "{tls_cipher}":
|
||||
if r.request.TLS != nil {
|
||||
for k, v := range caddytls.SupportedCiphersMap {
|
||||
if v == r.request.TLS.CipherSuite {
|
||||
return k
|
||||
}
|
||||
}
|
||||
return "UNKNOWN" // this should never happen, but guard in case
|
||||
}
|
||||
return r.emptyValue
|
||||
default:
|
||||
// {labelN}
|
||||
if strings.HasPrefix(key, "{label") {
|
||||
nStr := key[6 : len(key)-1] // get the integer N in "{labelN}"
|
||||
n, err := strconv.Atoi(nStr)
|
||||
if err != nil || n < 1 {
|
||||
return r.emptyValue
|
||||
}
|
||||
labels := strings.Split(r.request.Host, ".")
|
||||
if n > len(labels) {
|
||||
return r.emptyValue
|
||||
}
|
||||
return labels[n-1]
|
||||
}
|
||||
}
|
||||
|
||||
return r.emptyValue
|
||||
}
|
||||
|
||||
//convertToMilliseconds returns the number of milliseconds in the given duration
|
||||
// convertToMilliseconds returns the number of milliseconds in the given duration
|
||||
func convertToMilliseconds(d time.Duration) int64 {
|
||||
return d.Nanoseconds() / 1e6
|
||||
}
|
||||
|
|
|
@ -53,7 +53,7 @@ func TestReplace(t *testing.T) {
|
|||
recordRequest := NewResponseRecorder(w)
|
||||
reader := strings.NewReader(`{"username": "dennis"}`)
|
||||
|
||||
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
|
||||
request, err := http.NewRequest("POST", "http://localhost.local/?foo=bar", reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to make request: %v", err)
|
||||
}
|
||||
|
@ -67,6 +67,9 @@ func TestReplace(t *testing.T) {
|
|||
request.Header.Set("CustomAdd", "caddy")
|
||||
request.Header.Set("Cookie", "foo=bar; taste=delicious")
|
||||
|
||||
// add some respons headers
|
||||
recordRequest.Header().Set("Custom", "CustomResponseHeader")
|
||||
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to determine hostname: %v", err)
|
||||
|
@ -84,7 +87,7 @@ func TestReplace(t *testing.T) {
|
|||
expect string
|
||||
}{
|
||||
{"This hostname is {hostname}", "This hostname is " + hostname},
|
||||
{"This host is {host}.", "This host is localhost."},
|
||||
{"This host is {host}.", "This host is localhost.local."},
|
||||
{"This request method is {method}.", "This request method is POST."},
|
||||
{"The response status is {status}.", "The response status is 200."},
|
||||
{"{when}", "02/Jan/2006:15:04:05 +0000"},
|
||||
|
@ -92,10 +95,13 @@ func TestReplace(t *testing.T) {
|
|||
{"{when_unix}", "1136214252"},
|
||||
{"The Custom header is {>Custom}.", "The Custom header is foobarbaz."},
|
||||
{"The CustomAdd header is {>CustomAdd}.", "The CustomAdd header is caddy."},
|
||||
{"The request is {request}.", "The request is POST /?foo=bar HTTP/1.1\\r\\nHost: localhost\\r\\n" +
|
||||
{"The Custom response header is {<Custom}.", "The Custom response header is CustomResponseHeader."},
|
||||
{"Bad {>Custom placeholder", "Bad {>Custom placeholder"},
|
||||
{"The request is {request}.", "The request is POST /?foo=bar HTTP/1.1\\r\\nHost: localhost.local\\r\\n" +
|
||||
"Cookie: foo=bar; taste=delicious\\r\\nCustom: foobarbaz\\r\\nCustomadd: caddy\\r\\n" +
|
||||
"Shorterval: 1\\r\\n\\r\\n."},
|
||||
{"The cUsToM header is {>cUsToM}...", "The cUsToM header is foobarbaz..."},
|
||||
{"The cUsToM response header is {<CuSTom}.", "The cUsToM response header is CustomResponseHeader."},
|
||||
{"The Non-Existent header is {>Non-Existent}.", "The Non-Existent header is -."},
|
||||
{"Bad {host placeholder...", "Bad {host placeholder..."},
|
||||
{"Bad {>Custom placeholder", "Bad {>Custom placeholder"},
|
||||
|
@ -106,6 +112,9 @@ func TestReplace(t *testing.T) {
|
|||
{"Query string is {query}", "Query string is foo=bar"},
|
||||
{"Query string value for foo is {?foo}", "Query string value for foo is bar"},
|
||||
{"Missing query string argument is {?missing}", "Missing query string argument is "},
|
||||
{"{label1} {label2} {label3} {label4}", "localhost local - -"},
|
||||
{"Label with missing number is {label} or {labelQQ}", "Label with missing number is - or -"},
|
||||
{"\\{ 'hostname': '{hostname}' \\}", "{ 'hostname': '" + hostname + "' }"},
|
||||
}
|
||||
|
||||
for _, c := range testCases {
|
||||
|
@ -138,6 +147,107 @@ func TestReplace(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func BenchmarkReplace(b *testing.B) {
|
||||
w := httptest.NewRecorder()
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
reader := strings.NewReader(`{"username": "dennis"}`)
|
||||
|
||||
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to make request: %v", err)
|
||||
}
|
||||
ctx := context.WithValue(request.Context(), OriginalURLCtxKey, *request.URL)
|
||||
request = request.WithContext(ctx)
|
||||
|
||||
request.Header.Set("Custom", "foobarbaz")
|
||||
request.Header.Set("ShorterVal", "1")
|
||||
repl := NewReplacer(request, recordRequest, "-")
|
||||
// add some headers after creating replacer
|
||||
request.Header.Set("CustomAdd", "caddy")
|
||||
request.Header.Set("Cookie", "foo=bar; taste=delicious")
|
||||
|
||||
// add some respons headers
|
||||
recordRequest.Header().Set("Custom", "CustomResponseHeader")
|
||||
|
||||
now = func() time.Time {
|
||||
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
repl.Replace("This hostname is {hostname}")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReplaceEscaped(b *testing.B) {
|
||||
w := httptest.NewRecorder()
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
reader := strings.NewReader(`{"username": "dennis"}`)
|
||||
|
||||
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to make request: %v", err)
|
||||
}
|
||||
ctx := context.WithValue(request.Context(), OriginalURLCtxKey, *request.URL)
|
||||
request = request.WithContext(ctx)
|
||||
|
||||
request.Header.Set("Custom", "foobarbaz")
|
||||
request.Header.Set("ShorterVal", "1")
|
||||
repl := NewReplacer(request, recordRequest, "-")
|
||||
// add some headers after creating replacer
|
||||
request.Header.Set("CustomAdd", "caddy")
|
||||
request.Header.Set("Cookie", "foo=bar; taste=delicious")
|
||||
|
||||
// add some respons headers
|
||||
recordRequest.Header().Set("Custom", "CustomResponseHeader")
|
||||
|
||||
now = func() time.Time {
|
||||
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
repl.Replace("\\{ 'hostname': '{hostname}' \\}")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseRecorderNil(t *testing.T) {
|
||||
|
||||
reader := strings.NewReader(`{"username": "dennis"}`)
|
||||
|
||||
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to make request: %v", err)
|
||||
}
|
||||
|
||||
request.Header.Set("Custom", "foobarbaz")
|
||||
repl := NewReplacer(request, nil, "-")
|
||||
// add some headers after creating replacer
|
||||
request.Header.Set("CustomAdd", "caddy")
|
||||
request.Header.Set("Cookie", "foo=bar; taste=delicious")
|
||||
|
||||
old := now
|
||||
now = func() time.Time {
|
||||
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
|
||||
}
|
||||
defer func() {
|
||||
now = old
|
||||
}()
|
||||
testCases := []struct {
|
||||
template string
|
||||
expect string
|
||||
}{
|
||||
{"The Custom response header is {<Custom}.", "The Custom response header is -."},
|
||||
}
|
||||
|
||||
for _, c := range testCases {
|
||||
if expected, actual := c.expect, repl.Replace(c.template); expected != actual {
|
||||
t.Errorf("for template '%s', expected '%s', got '%s'", c.template, expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
|
|
|
@ -320,6 +320,9 @@ func (s *Server) Serve(ln net.Listener) error {
|
|||
}
|
||||
|
||||
err := s.Server.Serve(ln)
|
||||
if err == http.ErrServerClosed {
|
||||
err = nil // not an error worth reporting since closing a server is intentional
|
||||
}
|
||||
if s.quicServer != nil {
|
||||
s.quicServer.Close()
|
||||
}
|
||||
|
@ -421,19 +424,39 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
|||
r.URL = trimPathPrefix(r.URL, pathPrefix)
|
||||
}
|
||||
|
||||
// enforce strict host matching, which ensures that the SNI
|
||||
// value (if any), matches the Host header; essential for
|
||||
// sites that rely on TLS ClientAuth sharing a port with
|
||||
// sites that do not - if mismatched, close the connection
|
||||
if vhost.StrictHostMatching && r.TLS != nil &&
|
||||
strings.ToLower(r.TLS.ServerName) != strings.ToLower(hostname) {
|
||||
r.Close = true
|
||||
log.Printf("[ERROR] %s - strict host matching: SNI (%s) and HTTP Host (%s) values differ",
|
||||
vhost.Addr, r.TLS.ServerName, hostname)
|
||||
return http.StatusForbidden, nil
|
||||
}
|
||||
|
||||
return vhost.middlewareChain.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func trimPathPrefix(u *url.URL, prefix string) *url.URL {
|
||||
// We need to use URL.EscapedPath() when trimming the pathPrefix as
|
||||
// URL.Path is ambiguous about / or %2f - see docs. See #1927
|
||||
trimmed := strings.TrimPrefix(u.EscapedPath(), prefix)
|
||||
if !strings.HasPrefix(trimmed, "/") {
|
||||
trimmed = "/" + trimmed
|
||||
trimmedPath := strings.TrimPrefix(u.EscapedPath(), prefix)
|
||||
if !strings.HasPrefix(trimmedPath, "/") {
|
||||
trimmedPath = "/" + trimmedPath
|
||||
}
|
||||
trimmedURL, err := url.Parse(trimmed)
|
||||
// After trimming path reconstruct uri string with Query before parsing
|
||||
trimmedURI := trimmedPath
|
||||
if u.RawQuery != "" || u.ForceQuery == true {
|
||||
trimmedURI = trimmedPath + "?" + u.RawQuery
|
||||
}
|
||||
if u.Fragment != "" {
|
||||
trimmedURI = trimmedURI + "#" + u.Fragment
|
||||
}
|
||||
trimmedURL, err := url.Parse(trimmedURI)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Unable to parse trimmed URL %s: %v", trimmed, err)
|
||||
log.Printf("[ERROR] Unable to parse trimmed URL %s: %v", trimmedURI, err)
|
||||
return u
|
||||
}
|
||||
return trimmedURL
|
||||
|
|
|
@ -129,88 +129,108 @@ func TestMakeHTTPServerWithTimeouts(t *testing.T) {
|
|||
|
||||
func TestTrimPathPrefix(t *testing.T) {
|
||||
for i, pt := range []struct {
|
||||
path string
|
||||
url string
|
||||
prefix string
|
||||
expected string
|
||||
shouldFail bool
|
||||
}{
|
||||
{
|
||||
path: "/my/path",
|
||||
url: "/my/path",
|
||||
prefix: "/my",
|
||||
expected: "/path",
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
path: "/my/%2f/path",
|
||||
url: "/my/%2f/path",
|
||||
prefix: "/my",
|
||||
expected: "/%2f/path",
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
path: "/my/path",
|
||||
url: "/my/path",
|
||||
prefix: "/my/",
|
||||
expected: "/path",
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
path: "/my///path",
|
||||
url: "/my///path",
|
||||
prefix: "/my",
|
||||
expected: "/path",
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
path: "/my///path",
|
||||
url: "/my///path",
|
||||
prefix: "/my",
|
||||
expected: "///path",
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
path: "/my/path///slash",
|
||||
url: "/my/path///slash",
|
||||
prefix: "/my",
|
||||
expected: "/path///slash",
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
path: "/my/%2f/path/%2f",
|
||||
url: "/my/%2f/path/%2f",
|
||||
prefix: "/my",
|
||||
expected: "/%2f/path/%2f",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
path: "/my/%20/path",
|
||||
url: "/my/%20/path",
|
||||
prefix: "/my",
|
||||
expected: "/%20/path",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
path: "/path",
|
||||
url: "/path",
|
||||
prefix: "",
|
||||
expected: "/path",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
path: "/path/my/",
|
||||
url: "/path/my/",
|
||||
prefix: "/my",
|
||||
expected: "/path/my/",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
path: "",
|
||||
url: "",
|
||||
prefix: "/my",
|
||||
expected: "/",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
path: "/apath",
|
||||
url: "/apath",
|
||||
prefix: "",
|
||||
expected: "/apath",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
url: "/my/path/page.php?akey=value",
|
||||
prefix: "/my",
|
||||
expected: "/path/page.php?akey=value",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
url: "/my/path/page?key=value#fragment",
|
||||
prefix: "/my",
|
||||
expected: "/path/page?key=value#fragment",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
url: "/my/path/page#fragment",
|
||||
prefix: "/my",
|
||||
expected: "/path/page#fragment",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
url: "/my/apath?",
|
||||
prefix: "/my",
|
||||
expected: "/apath?",
|
||||
shouldFail: false,
|
||||
},
|
||||
} {
|
||||
|
||||
u, _ := url.Parse(pt.path)
|
||||
if got, want := trimPathPrefix(u, pt.prefix), pt.expected; got.EscapedPath() != want {
|
||||
u, _ := url.Parse(pt.url)
|
||||
if got, want := trimPathPrefix(u, pt.prefix), pt.expected; got.String() != want {
|
||||
if !pt.shouldFail {
|
||||
|
||||
t.Errorf("Test %d: Expected='%s', but was '%s' ", i, want, got.EscapedPath())
|
||||
t.Errorf("Test %d: Expected='%s', but was '%s' ", i, want, got.String())
|
||||
}
|
||||
} else if pt.shouldFail {
|
||||
t.Errorf("SHOULDFAIL Test %d: Expected='%s', and was '%s' but should fail", i, want, got.EscapedPath())
|
||||
t.Errorf("SHOULDFAIL Test %d: Expected='%s', and was '%s' but should fail", i, want, got.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,6 +36,16 @@ type SiteConfig struct {
|
|||
// TLS configuration
|
||||
TLS *caddytls.Config
|
||||
|
||||
// If true, the Host header in the HTTP request must
|
||||
// match the SNI value in the TLS handshake (if any).
|
||||
// This should be enabled whenever a site relies on
|
||||
// TLS client authentication, for example; or any time
|
||||
// you want to enforce that THIS site's TLS config
|
||||
// is used and not the TLS config of any other site
|
||||
// on the same listener. TODO: Check how relevant this
|
||||
// is with TLS 1.3.
|
||||
StrictHostMatching bool
|
||||
|
||||
// Uncompiled middleware stack
|
||||
middleware []Middleware
|
||||
|
||||
|
|
|
@ -277,7 +277,7 @@ func TestHostname(t *testing.T) {
|
|||
// // Test 3 - ipv6 without port and brackets
|
||||
// {"2001:4860:4860::8888", "google-public-dns-a.google.com."},
|
||||
// Test 4 - no hostname available
|
||||
{"1.1.1.1", "1.1.1.1"},
|
||||
{"0.0.0.0", "0.0.0.0"},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
|
|
|
@ -67,6 +67,10 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
|||
|
||||
// Write log entries
|
||||
for _, e := range rule.Entries {
|
||||
// Check if there is an exception to prevent log being written
|
||||
if !e.Log.ShouldLog(r.URL.Path) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Mask IP Address
|
||||
if e.Log.IPMaskExists {
|
||||
|
@ -78,6 +82,7 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
|||
}
|
||||
}
|
||||
e.Log.Println(rep.Replace(e.Format))
|
||||
|
||||
}
|
||||
|
||||
return status, err
|
||||
|
|
|
@ -177,3 +177,85 @@ func TestMultiEntries(t *testing.T) {
|
|||
t.Errorf("Expected %q, but got %q", expect, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogExcept(t *testing.T) {
|
||||
tests := []struct {
|
||||
LogRules []Rule
|
||||
logPath string
|
||||
shouldLog bool
|
||||
}{
|
||||
{[]Rule{{
|
||||
PathScope: "/",
|
||||
Entries: []*Entry{{
|
||||
Log: &httpserver.Logger{
|
||||
|
||||
Exceptions: []string{"/soup"},
|
||||
},
|
||||
Format: DefaultLogFormat,
|
||||
}},
|
||||
}}, `/soup`, false},
|
||||
{[]Rule{{
|
||||
PathScope: "/",
|
||||
Entries: []*Entry{{
|
||||
Log: &httpserver.Logger{
|
||||
|
||||
Exceptions: []string{"/tart"},
|
||||
},
|
||||
Format: DefaultLogFormat,
|
||||
}},
|
||||
}}, `/soup`, true},
|
||||
{[]Rule{{
|
||||
PathScope: "/",
|
||||
Entries: []*Entry{{
|
||||
Log: &httpserver.Logger{
|
||||
|
||||
Exceptions: []string{"/soup"},
|
||||
},
|
||||
Format: DefaultLogFormat,
|
||||
}},
|
||||
}}, `/tomatosoup`, true},
|
||||
{[]Rule{{
|
||||
PathScope: "/",
|
||||
Entries: []*Entry{{
|
||||
Log: &httpserver.Logger{
|
||||
|
||||
Exceptions: []string{"/pie/"},
|
||||
},
|
||||
Format: DefaultLogFormat,
|
||||
}},
|
||||
// Check exception with a trailing slash does not match without
|
||||
}}, `/pie`, true},
|
||||
{[]Rule{{
|
||||
PathScope: "/",
|
||||
Entries: []*Entry{{
|
||||
Log: &httpserver.Logger{
|
||||
|
||||
Exceptions: []string{"/pie.php"},
|
||||
},
|
||||
Format: DefaultLogFormat,
|
||||
}},
|
||||
}}, `/pie`, true},
|
||||
{[]Rule{{
|
||||
PathScope: "/",
|
||||
Entries: []*Entry{{
|
||||
Log: &httpserver.Logger{
|
||||
|
||||
Exceptions: []string{"/pie"},
|
||||
},
|
||||
Format: DefaultLogFormat,
|
||||
}},
|
||||
// Check that a word without trailing slash will match a filename
|
||||
}}, `/pie.php`, false},
|
||||
}
|
||||
for i, test := range tests {
|
||||
for _, LogRule := range test.LogRules {
|
||||
for _, e := range LogRule.Entries {
|
||||
shouldLog := e.Log.ShouldLog(test.logPath)
|
||||
if shouldLog != test.shouldLog {
|
||||
t.Fatalf("Test %d expected shouldLog=%t but got shouldLog=%t,", i, test.shouldLog, shouldLog)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ func setup(c *caddy.Controller) error {
|
|||
|
||||
func logParse(c *caddy.Controller) ([]*Rule, error) {
|
||||
var rules []*Rule
|
||||
|
||||
var logExceptions []string
|
||||
for c.Next() {
|
||||
args := c.RemainingArgs()
|
||||
|
||||
|
@ -91,6 +91,12 @@ func logParse(c *caddy.Controller) ([]*Rule, error) {
|
|||
|
||||
}
|
||||
|
||||
} else if what == "except" {
|
||||
|
||||
for i := 0; i < len(where); i++ {
|
||||
logExceptions = append(logExceptions, where[i])
|
||||
}
|
||||
|
||||
} else if httpserver.IsLogRollerSubdirective(what) {
|
||||
|
||||
if err := httpserver.ParseRoller(logRoller, what, where...); err != nil {
|
||||
|
@ -133,6 +139,7 @@ func logParse(c *caddy.Controller) ([]*Rule, error) {
|
|||
V4ipMask: ip4Mask,
|
||||
V6ipMask: ip6Mask,
|
||||
IPMaskExists: ipMaskExists,
|
||||
Exceptions: logExceptions,
|
||||
},
|
||||
Format: format,
|
||||
})
|
||||
|
|
|
@ -58,6 +58,10 @@ type Upstream interface {
|
|||
// Gets the number of upstream hosts.
|
||||
GetHostCount() int
|
||||
|
||||
// Gets how long to wait before timing out
|
||||
// the request
|
||||
GetTimeout() time.Duration
|
||||
|
||||
// Stops the upstream from proxying requests to shutdown goroutines cleanly.
|
||||
Stop() error
|
||||
}
|
||||
|
@ -187,7 +191,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
|||
if nameURL, err := url.Parse(host.Name); err == nil {
|
||||
outreq.Host = nameURL.Host
|
||||
if proxy == nil {
|
||||
proxy = NewSingleHostReverseProxy(nameURL, host.WithoutPathPrefix, http.DefaultMaxIdleConnsPerHost)
|
||||
proxy = NewSingleHostReverseProxy(nameURL,
|
||||
host.WithoutPathPrefix,
|
||||
http.DefaultMaxIdleConnsPerHost,
|
||||
upstream.GetTimeout(),
|
||||
)
|
||||
}
|
||||
|
||||
// use upstream credentials by default
|
||||
|
|
|
@ -122,7 +122,7 @@ func TestReverseProxy(t *testing.T) {
|
|||
// set up proxy
|
||||
p := &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second)},
|
||||
}
|
||||
|
||||
// Create the fake request body.
|
||||
|
@ -202,7 +202,7 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) {
|
|||
// set up proxy
|
||||
p := &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, true)},
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, true, 30*time.Second)},
|
||||
}
|
||||
|
||||
// create request and response recorder
|
||||
|
@ -287,6 +287,31 @@ func TestReverseProxyMaxConnLimit(t *testing.T) {
|
|||
jobs.Wait()
|
||||
}
|
||||
|
||||
func TestReverseProxyTimeout(t *testing.T) {
|
||||
timeout := 2 * time.Second
|
||||
errorMargin := 100 * time.Millisecond
|
||||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
// set up proxy
|
||||
p := &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{newFakeUpstream("https://8.8.8.8", true, timeout)},
|
||||
}
|
||||
|
||||
// create request and response recorder
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
start := time.Now()
|
||||
p.ServeHTTP(w, r)
|
||||
took := time.Since(start)
|
||||
|
||||
if took > timeout+errorMargin {
|
||||
t.Errorf("Expected timeout ~ %v but got %v", timeout, took)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
|
||||
// Capture the expected panic
|
||||
defer func() {
|
||||
|
@ -301,7 +326,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
|
|||
defer wsNop.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsNop.URL, false)
|
||||
p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second)
|
||||
|
||||
// Create client request
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
@ -331,7 +356,7 @@ func TestWebSocketReverseProxyBackendShutDown(t *testing.T) {
|
|||
}()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(backend.URL, false)
|
||||
p := newWebSocketTestProxy(backend.URL, false, 30*time.Second)
|
||||
backendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p.ServeHTTP(w, r)
|
||||
}))
|
||||
|
@ -360,7 +385,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
|
|||
defer wsNop.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsNop.URL, false)
|
||||
p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second)
|
||||
|
||||
// Create client request
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
@ -407,7 +432,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
|
|||
defer wsEcho.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsEcho.URL, false)
|
||||
p := newWebSocketTestProxy(wsEcho.URL, false, 30*time.Second)
|
||||
|
||||
// This is a full end-end test, so the proxy handler
|
||||
// has to be part of a server listening on a port. Our
|
||||
|
@ -452,7 +477,7 @@ func TestWebSocketReverseProxyFromWSSClient(t *testing.T) {
|
|||
}))
|
||||
defer wsEcho.Close()
|
||||
|
||||
p := newWebSocketTestProxy(wsEcho.URL, true)
|
||||
p := newWebSocketTestProxy(wsEcho.URL, true, 30*time.Second)
|
||||
|
||||
echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p.ServeHTTP(w, r)
|
||||
|
@ -528,7 +553,7 @@ func TestUnixSocketProxy(t *testing.T) {
|
|||
defer ts.Close()
|
||||
|
||||
url := strings.Replace(ts.URL, "http://", "unix:", 1)
|
||||
p := newWebSocketTestProxy(url, false)
|
||||
p := newWebSocketTestProxy(url, false, 30*time.Second)
|
||||
|
||||
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p.ServeHTTP(w, r)
|
||||
|
@ -686,7 +711,7 @@ func TestUpstreamHeadersUpdate(t *testing.T) {
|
|||
}))
|
||||
defer backend.Close()
|
||||
|
||||
upstream := newFakeUpstream(backend.URL, false)
|
||||
upstream := newFakeUpstream(backend.URL, false, 30*time.Second)
|
||||
upstream.host.UpstreamHeaders = http.Header{
|
||||
"Connection": {"{>Connection}"},
|
||||
"Upgrade": {"{>Upgrade}"},
|
||||
|
@ -753,7 +778,7 @@ func TestDownstreamHeadersUpdate(t *testing.T) {
|
|||
}))
|
||||
defer backend.Close()
|
||||
|
||||
upstream := newFakeUpstream(backend.URL, false)
|
||||
upstream := newFakeUpstream(backend.URL, false, 30*time.Second)
|
||||
upstream.host.DownstreamHeaders = http.Header{
|
||||
"+Merge-Me": {"Merge-Value"},
|
||||
"+Add-Me": {"Add-Value"},
|
||||
|
@ -893,7 +918,7 @@ func TestHostSimpleProxyNoHeaderForward(t *testing.T) {
|
|||
// set up proxy
|
||||
p := &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second)},
|
||||
}
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
@ -913,6 +938,67 @@ func TestHostSimpleProxyNoHeaderForward(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyTransparentHeaders(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
forwardedForHeader string
|
||||
expected []string
|
||||
}{
|
||||
{"No header", "192.168.0.1:80", "", []string{"192.168.0.1"}},
|
||||
{"Existing", "192.168.0.1:80", "1.1.1.1, 2.2.2.2", []string{"1.1.1.1, 2.2.2.2, 192.168.0.1"}},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
testReverseProxyTransparentHeaders(t, tc.remoteAddr, tc.forwardedForHeader, tc.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testReverseProxyTransparentHeaders(t *testing.T, remoteAddr, forwardedForHeader string, expected []string) {
|
||||
// Arrange
|
||||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
var actualHeaders http.Header
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
actualHeaders = r.Header
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
config := "proxy / " + backend.URL + " {\n transparent \n}"
|
||||
|
||||
// make proxy
|
||||
upstreams, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(config)), "")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error. Got: %s", err.Error())
|
||||
}
|
||||
|
||||
// set up proxy
|
||||
p := &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: upstreams,
|
||||
}
|
||||
|
||||
// create request and response recorder
|
||||
r := httptest.NewRequest("GET", backend.URL, nil)
|
||||
r.RemoteAddr = remoteAddr
|
||||
if forwardedForHeader != "" {
|
||||
r.Header.Set("X-Forwarded-For", forwardedForHeader)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Act
|
||||
p.ServeHTTP(w, r)
|
||||
|
||||
// Assert
|
||||
if got := actualHeaders["X-Forwarded-For"]; !reflect.DeepEqual(got, expected) {
|
||||
t.Errorf("Transparent proxy response does not contain expected %v header: expect %v, but got %v",
|
||||
"X-Forwarded-For", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostHeaderReplacedUsingForward(t *testing.T) {
|
||||
var requestHost string
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -921,7 +1007,7 @@ func TestHostHeaderReplacedUsingForward(t *testing.T) {
|
|||
}))
|
||||
defer backend.Close()
|
||||
|
||||
upstream := newFakeUpstream(backend.URL, false)
|
||||
upstream := newFakeUpstream(backend.URL, false, 30*time.Second)
|
||||
proxyHostHeader := "test2.com"
|
||||
upstream.host.UpstreamHeaders = http.Header{"Host": []string{proxyHostHeader}}
|
||||
// set up proxy
|
||||
|
@ -943,11 +1029,22 @@ func TestHostHeaderReplacedUsingForward(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBasicAuth(t *testing.T) {
|
||||
basicAuthTestcase(t, nil, nil)
|
||||
basicAuthTestcase(t, nil, url.UserPassword("username", "password"))
|
||||
basicAuthTestcase(t, url.UserPassword("usename", "password"), nil)
|
||||
basicAuthTestcase(t, url.UserPassword("unused", "unused"),
|
||||
url.UserPassword("username", "password"))
|
||||
testCases := []struct {
|
||||
name string
|
||||
upstreamUser *url.Userinfo
|
||||
clientUser *url.Userinfo
|
||||
}{
|
||||
{"Nil Both", nil, nil},
|
||||
{"Nil Upstream User", nil, url.UserPassword("username", "password")},
|
||||
{"Nil Client User", url.UserPassword("usename", "password"), nil},
|
||||
{"Both Provided", url.UserPassword("unused", "unused"),
|
||||
url.UserPassword("username", "password")},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
basicAuthTestcase(t, tc.upstreamUser, tc.clientUser)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func basicAuthTestcase(t *testing.T, upstreamUser, clientUser *url.Userinfo) {
|
||||
|
@ -972,7 +1069,7 @@ func basicAuthTestcase(t *testing.T, upstreamUser, clientUser *url.Userinfo) {
|
|||
|
||||
p := &Proxy{
|
||||
Next: httpserver.EmptyNext,
|
||||
Upstreams: []Upstream{newFakeUpstream(backURL.String(), false)},
|
||||
Upstreams: []Upstream{newFakeUpstream(backURL.String(), false, 30*time.Second)},
|
||||
}
|
||||
r, err := http.NewRequest("GET", "/foo", nil)
|
||||
if err != nil {
|
||||
|
@ -1107,7 +1204,7 @@ func TestProxyDirectorURL(t *testing.T) {
|
|||
continue
|
||||
}
|
||||
|
||||
NewSingleHostReverseProxy(targetURL, c.without, 0).Director(req)
|
||||
NewSingleHostReverseProxy(targetURL, c.without, 0, 30*time.Second).Director(req)
|
||||
if expect, got := c.expectURL, req.URL.String(); expect != got {
|
||||
t.Errorf("case %d url not equal: expect %q, but got %q",
|
||||
i, expect, got)
|
||||
|
@ -1254,7 +1351,7 @@ func TestCancelRequest(t *testing.T) {
|
|||
// set up proxy
|
||||
p := &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second)},
|
||||
}
|
||||
|
||||
// setup request with cancel ctx
|
||||
|
@ -1303,14 +1400,15 @@ func (r *noopReader) Read(b []byte) (int, error) {
|
|||
return n, nil
|
||||
}
|
||||
|
||||
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
|
||||
func newFakeUpstream(name string, insecure bool, timeout time.Duration) *fakeUpstream {
|
||||
uri, _ := url.Parse(name)
|
||||
u := &fakeUpstream{
|
||||
name: name,
|
||||
from: "/",
|
||||
name: name,
|
||||
from: "/",
|
||||
timeout: timeout,
|
||||
host: &UpstreamHost{
|
||||
Name: name,
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, "", http.DefaultMaxIdleConnsPerHost),
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, "", http.DefaultMaxIdleConnsPerHost, timeout),
|
||||
},
|
||||
}
|
||||
if insecure {
|
||||
|
@ -1324,6 +1422,7 @@ type fakeUpstream struct {
|
|||
host *UpstreamHost
|
||||
from string
|
||||
without string
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func (u *fakeUpstream) From() string {
|
||||
|
@ -1338,7 +1437,7 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost {
|
|||
}
|
||||
u.host = &UpstreamHost{
|
||||
Name: u.name,
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost, u.GetTimeout()),
|
||||
}
|
||||
}
|
||||
return u.host
|
||||
|
@ -1347,6 +1446,7 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost {
|
|||
func (u *fakeUpstream) AllowedPath(requestPath string) bool { return true }
|
||||
func (u *fakeUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
|
||||
func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
|
||||
func (u *fakeUpstream) GetTimeout() time.Duration { return u.timeout }
|
||||
func (u *fakeUpstream) GetHostCount() int { return 1 }
|
||||
func (u *fakeUpstream) Stop() error { return nil }
|
||||
|
||||
|
@ -1354,13 +1454,14 @@ func (u *fakeUpstream) Stop() error { return nil }
|
|||
// redirect to the specified backendAddr. The function
|
||||
// also sets up the rules/environment for testing WebSocket
|
||||
// proxy.
|
||||
func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy {
|
||||
func newWebSocketTestProxy(backendAddr string, insecure bool, timeout time.Duration) *Proxy {
|
||||
return &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{&fakeWsUpstream{
|
||||
name: backendAddr,
|
||||
without: "",
|
||||
insecure: insecure,
|
||||
timeout: timeout,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
@ -1368,7 +1469,7 @@ func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy {
|
|||
func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
|
||||
return &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix}},
|
||||
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix, timeout: 30 * time.Second}},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1376,6 +1477,7 @@ type fakeWsUpstream struct {
|
|||
name string
|
||||
without string
|
||||
insecure bool
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func (u *fakeWsUpstream) From() string {
|
||||
|
@ -1386,7 +1488,7 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
|
|||
uri, _ := url.Parse(u.name)
|
||||
host := &UpstreamHost{
|
||||
Name: u.name,
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost, u.GetTimeout()),
|
||||
UpstreamHeaders: http.Header{
|
||||
"Connection": {"{>Connection}"},
|
||||
"Upgrade": {"{>Upgrade}"}},
|
||||
|
@ -1400,6 +1502,7 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
|
|||
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }
|
||||
func (u *fakeWsUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
|
||||
func (u *fakeWsUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
|
||||
func (u *fakeWsUpstream) GetTimeout() time.Duration { return u.timeout }
|
||||
func (u *fakeWsUpstream) GetHostCount() int { return 1 }
|
||||
func (u *fakeWsUpstream) Stop() error { return nil }
|
||||
|
||||
|
@ -1445,7 +1548,7 @@ func BenchmarkProxy(b *testing.B) {
|
|||
}))
|
||||
defer backend.Close()
|
||||
|
||||
upstream := newFakeUpstream(backend.URL, false)
|
||||
upstream := newFakeUpstream(backend.URL, false, 30*time.Second)
|
||||
upstream.host.UpstreamHeaders = http.Header{
|
||||
"Hostname": {"{hostname}"},
|
||||
"Host": {"{host}"},
|
||||
|
@ -1488,7 +1591,7 @@ func TestChunkedWebSocketReverseProxy(t *testing.T) {
|
|||
defer wsNop.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsNop.URL, false)
|
||||
p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second)
|
||||
|
||||
// Create client request
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
|
|
@ -94,6 +94,10 @@ type ReverseProxy struct {
|
|||
// If zero, no periodic flushing is done.
|
||||
FlushInterval time.Duration
|
||||
|
||||
// dialer is used when values from the
|
||||
// defaultDialer need to be overridden per Proxy
|
||||
dialer *net.Dialer
|
||||
|
||||
srvResolver srvResolver
|
||||
}
|
||||
|
||||
|
@ -103,13 +107,13 @@ type ReverseProxy struct {
|
|||
// What we need is just the path, so if "unix:/var/run/www.socket"
|
||||
// was the proxy directive, the parsed hostName would be
|
||||
// "unix:///var/run/www.socket", hence the ambiguous trimming.
|
||||
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) {
|
||||
func socketDial(hostName string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) {
|
||||
return func(network, addr string) (conn net.Conn, err error) {
|
||||
return net.Dial("unix", hostName[len("unix://"):])
|
||||
return net.DialTimeout("unix", hostName[len("unix://"):], timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) (conn net.Conn, err error) {
|
||||
func (rp *ReverseProxy) srvDialerFunc(locator string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) {
|
||||
service := locator
|
||||
if strings.HasPrefix(locator, "srv://") {
|
||||
service = locator[6:]
|
||||
|
@ -122,7 +126,7 @@ func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.Dial("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port))
|
||||
return net.DialTimeout("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port), timeout)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -144,7 +148,7 @@ func singleJoiningSlash(a, b string) string {
|
|||
// the target request will be for /base/dir.
|
||||
// Without logic: target's path is "/", incoming is "/api/messages",
|
||||
// without is "/api", then the target request will be for /messages.
|
||||
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *ReverseProxy {
|
||||
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int, timeout time.Duration) *ReverseProxy {
|
||||
targetQuery := target.RawQuery
|
||||
director := func(req *http.Request) {
|
||||
if target.Scheme == "unix" {
|
||||
|
@ -226,15 +230,21 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
|||
}
|
||||
}
|
||||
|
||||
dialer := *defaultDialer
|
||||
if timeout != defaultDialer.Timeout {
|
||||
dialer.Timeout = timeout
|
||||
}
|
||||
|
||||
rp := &ReverseProxy{
|
||||
Director: director,
|
||||
FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events
|
||||
srvResolver: net.DefaultResolver,
|
||||
dialer: &dialer,
|
||||
}
|
||||
|
||||
if target.Scheme == "unix" {
|
||||
rp.Transport = &http.Transport{
|
||||
Dial: socketDial(target.String()),
|
||||
Dial: socketDial(target.String(), timeout),
|
||||
}
|
||||
} else if target.Scheme == "quic" {
|
||||
rp.Transport = &h2quic.RoundTripper{
|
||||
|
@ -244,9 +254,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
|||
},
|
||||
}
|
||||
} else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") {
|
||||
dialFunc := defaultDialer.Dial
|
||||
dialFunc := rp.dialer.Dial
|
||||
if strings.HasPrefix(target.Scheme, "srv") {
|
||||
dialFunc = rp.srvDialerFunc(target.String())
|
||||
dialFunc = rp.srvDialerFunc(target.String(), timeout)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
|
@ -275,7 +285,7 @@ func (rp *ReverseProxy) UseInsecureTransport() {
|
|||
if rp.Transport == nil {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: defaultDialer.Dial,
|
||||
Dial: rp.dialer.Dial,
|
||||
TLSHandshakeTimeout: defaultCryptoHandshakeTimeout,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
@ -306,7 +316,9 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
|
|||
if requestIsWebsocket(outreq) {
|
||||
transport = newConnHijackerTransport(transport)
|
||||
} else if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
transport = &http.Transport{
|
||||
Dial: rp.dialer.Dial,
|
||||
}
|
||||
}
|
||||
|
||||
rp.Director(outreq)
|
||||
|
@ -361,7 +373,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
|
|||
}
|
||||
bufferPool.Put(hj.Replay)
|
||||
} else {
|
||||
backendConn, err = net.Dial("tcp", outreq.URL.Host)
|
||||
backendConn, err = net.DialTimeout("tcp", outreq.URL.Host, rp.dialer.Timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"net/url"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -66,7 +67,7 @@ func TestSingleSRVHostReverseProxy(t *testing.T) {
|
|||
}
|
||||
port := uint16(pp)
|
||||
|
||||
rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost)
|
||||
rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost, 30*time.Second)
|
||||
rp.srvResolver = testResolver{
|
||||
result: []*net.SRV{
|
||||
{Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1},
|
||||
|
|
|
@ -49,6 +49,7 @@ type staticUpstream struct {
|
|||
Hosts HostPool
|
||||
Policy Policy
|
||||
KeepAlive int
|
||||
Timeout time.Duration
|
||||
FailTimeout time.Duration
|
||||
TryDuration time.Duration
|
||||
TryInterval time.Duration
|
||||
|
@ -92,6 +93,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
|
|||
TryInterval: 250 * time.Millisecond,
|
||||
MaxConns: 0,
|
||||
KeepAlive: http.DefaultMaxIdleConnsPerHost,
|
||||
Timeout: 30 * time.Second,
|
||||
resolver: net.DefaultResolver,
|
||||
}
|
||||
|
||||
|
@ -225,7 +227,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive)
|
||||
uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive, u.Timeout)
|
||||
if u.insecureSkipVerify {
|
||||
uh.ReverseProxy.UseInsecureTransport()
|
||||
}
|
||||
|
@ -431,9 +433,10 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
|
|||
}
|
||||
u.downstreamHeaders.Add(header, value)
|
||||
case "transparent":
|
||||
// Note: X-Forwarded-For header is always being appended for proxy connections
|
||||
// See implementation of createUpstreamRequest in proxy.go
|
||||
u.upstreamHeaders.Add("Host", "{host}")
|
||||
u.upstreamHeaders.Add("X-Real-IP", "{remote}")
|
||||
u.upstreamHeaders.Add("X-Forwarded-For", "{remote}")
|
||||
u.upstreamHeaders.Add("X-Forwarded-Proto", "{scheme}")
|
||||
case "websocket":
|
||||
u.upstreamHeaders.Add("Connection", "{>Connection}")
|
||||
|
@ -463,6 +466,15 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
|
|||
return c.ArgErr()
|
||||
}
|
||||
u.KeepAlive = n
|
||||
case "timeout":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
dur, err := time.ParseDuration(c.Val())
|
||||
if err != nil {
|
||||
return c.Errf("unable to parse timeout duration '%s'", c.Val())
|
||||
}
|
||||
u.Timeout = dur
|
||||
default:
|
||||
return c.Errf("unknown property '%s'", c.Val())
|
||||
}
|
||||
|
@ -618,6 +630,11 @@ func (u *staticUpstream) GetTryInterval() time.Duration {
|
|||
return u.TryInterval
|
||||
}
|
||||
|
||||
// GetTimeout returns u.Timeout.
|
||||
func (u *staticUpstream) GetTimeout() time.Duration {
|
||||
return u.Timeout
|
||||
}
|
||||
|
||||
func (u *staticUpstream) GetHostCount() int {
|
||||
return len(u.Hosts)
|
||||
}
|
||||
|
|
|
@ -282,7 +282,8 @@ func TestStop(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestParseBlock(t *testing.T) {
|
||||
func TestParseBlockTransparent(t *testing.T) {
|
||||
// tests for transparent proxy presets
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
tests := []struct {
|
||||
config string
|
||||
|
@ -316,6 +317,10 @@ func TestParseBlock(t *testing.T) {
|
|||
if _, ok := headers["X-Forwarded-Proto"]; !ok {
|
||||
t.Errorf("Test %d: Could not find the X-Forwarded-Proto header", i+1)
|
||||
}
|
||||
|
||||
if _, ok := headers["X-Forwarded-For"]; ok {
|
||||
t.Errorf("Test %d: Found unexpected X-Forwarded-For header", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -63,22 +63,38 @@ type Rule interface {
|
|||
|
||||
// SimpleRule is a simple rewrite rule.
|
||||
type SimpleRule struct {
|
||||
From, To string
|
||||
Regexp *regexp.Regexp
|
||||
To string
|
||||
Negate bool
|
||||
}
|
||||
|
||||
// NewSimpleRule creates a new Simple Rule
|
||||
func NewSimpleRule(from, to string) SimpleRule {
|
||||
return SimpleRule{from, to}
|
||||
func NewSimpleRule(from, to string, negate bool) (*SimpleRule, error) {
|
||||
r, err := regexp.Compile(from)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SimpleRule{
|
||||
Regexp: r,
|
||||
To: to,
|
||||
Negate: negate,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BasePath satisfies httpserver.Config
|
||||
func (s SimpleRule) BasePath() string { return s.From }
|
||||
func (s SimpleRule) BasePath() string { return "/" }
|
||||
|
||||
// Match satisfies httpserver.Config
|
||||
func (s SimpleRule) Match(r *http.Request) bool { return s.From == r.URL.Path }
|
||||
func (s *SimpleRule) Match(r *http.Request) bool {
|
||||
matches := regexpMatches(s.Regexp, "/", r.URL.Path)
|
||||
if s.Negate {
|
||||
return len(matches) == 0
|
||||
}
|
||||
return len(matches) > 0
|
||||
}
|
||||
|
||||
// Rewrite rewrites the internal location of the current request.
|
||||
func (s SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) Result {
|
||||
func (s *SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) Result {
|
||||
|
||||
// attempt rewrite
|
||||
return To(fs, r, s.To, newReplacer(r))
|
||||
|
@ -165,7 +181,7 @@ func (r ComplexRule) Match(req *http.Request) bool {
|
|||
return true
|
||||
}
|
||||
// otherwise validate regex
|
||||
return r.regexpMatches(req.URL.Path) != nil
|
||||
return regexpMatches(r.Regexp, r.Base, req.URL.Path) != nil
|
||||
}
|
||||
|
||||
// Rewrite rewrites the internal location of the current request.
|
||||
|
@ -174,7 +190,7 @@ func (r ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) (re Result)
|
|||
|
||||
// validate regexp if present
|
||||
if r.Regexp != nil {
|
||||
matches := r.regexpMatches(req.URL.Path)
|
||||
matches := regexpMatches(r.Regexp, r.Base, req.URL.Path)
|
||||
switch len(matches) {
|
||||
case 0:
|
||||
// no match
|
||||
|
@ -230,14 +246,14 @@ func (r ComplexRule) matchExt(rPath string) bool {
|
|||
return !mustUse
|
||||
}
|
||||
|
||||
func (r ComplexRule) regexpMatches(rPath string) []string {
|
||||
if r.Regexp != nil {
|
||||
func regexpMatches(regexp *regexp.Regexp, base, rPath string) []string {
|
||||
if regexp != nil {
|
||||
// include trailing slash in regexp if present
|
||||
start := len(r.Base)
|
||||
if strings.HasSuffix(r.Base, "/") {
|
||||
start := len(base)
|
||||
if strings.HasSuffix(base, "/") {
|
||||
start--
|
||||
}
|
||||
return r.Regexp.FindStringSubmatch(rPath[start:])
|
||||
return regexp.FindStringSubmatch(rPath[start:])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -29,9 +29,9 @@ func TestRewrite(t *testing.T) {
|
|||
rw := Rewrite{
|
||||
Next: httpserver.HandlerFunc(urlPrinter),
|
||||
Rules: []httpserver.HandlerConfig{
|
||||
NewSimpleRule("/from", "/to"),
|
||||
NewSimpleRule("/a", "/b"),
|
||||
NewSimpleRule("/b", "/b{uri}"),
|
||||
newSimpleRule(t, "^/from$", "/to"),
|
||||
newSimpleRule(t, "^/a$", "/b"),
|
||||
newSimpleRule(t, "^/b$", "/b{uri}"),
|
||||
},
|
||||
FileSys: http.Dir("."),
|
||||
}
|
||||
|
@ -131,6 +131,45 @@ func TestRewrite(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestWordpress is a test for wordpress usecase.
|
||||
func TestWordpress(t *testing.T) {
|
||||
rw := Rewrite{
|
||||
Next: httpserver.HandlerFunc(urlPrinter),
|
||||
Rules: []httpserver.HandlerConfig{
|
||||
// both rules are same, thanks to Go regexp (confusion).
|
||||
newSimpleRule(t, "^/wp-admin", "{path} {path}/ /index.php?{query}", true),
|
||||
newSimpleRule(t, "^\\/wp-admin", "{path} {path}/ /index.php?{query}", true),
|
||||
},
|
||||
FileSys: http.Dir("."),
|
||||
}
|
||||
tests := []struct {
|
||||
from string
|
||||
expectedTo string
|
||||
}{
|
||||
{"/wp-admin", "/wp-admin"},
|
||||
{"/wp-admin/login.php", "/wp-admin/login.php"},
|
||||
{"/not-wp-admin/login.php?not=admin", "/index.php?not=admin"},
|
||||
{"/loophole", "/index.php"},
|
||||
{"/user?name=john", "/index.php?name=john"},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
req, err := http.NewRequest("GET", test.from, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
|
||||
}
|
||||
ctx := context.WithValue(req.Context(), httpserver.OriginalURLCtxKey, *req.URL)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
rw.ServeHTTP(rec, req)
|
||||
|
||||
if got, want := rec.Body.String(), test.expectedTo; got != want {
|
||||
t.Errorf("Test %d: Expected URL to be '%s' but was '%s'", i, want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
fmt.Fprint(w, r.URL.String())
|
||||
return 0, nil
|
||||
|
|
|
@ -58,6 +58,7 @@ func rewriteParse(c *caddy.Controller) ([]httpserver.HandlerConfig, error) {
|
|||
var base = "/"
|
||||
var pattern, to string
|
||||
var ext []string
|
||||
var negate bool
|
||||
|
||||
args := c.RemainingArgs()
|
||||
|
||||
|
@ -111,7 +112,14 @@ func rewriteParse(c *caddy.Controller) ([]httpserver.HandlerConfig, error) {
|
|||
|
||||
// the only unhandled case is 2 and above
|
||||
default:
|
||||
rule = NewSimpleRule(args[0], strings.Join(args[1:], " "))
|
||||
if args[0] == "not" {
|
||||
negate = true
|
||||
args = args[1:]
|
||||
}
|
||||
rule, err = NewSimpleRule(args[0], strings.Join(args[1:], " "), negate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
|
||||
|
|
|
@ -50,6 +50,19 @@ func TestSetup(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// newSimpleRule is convenience test function for SimpleRule.
|
||||
func newSimpleRule(t *testing.T, from, to string, negate ...bool) Rule {
|
||||
var n bool
|
||||
if len(negate) > 0 {
|
||||
n = negate[0]
|
||||
}
|
||||
rule, err := NewSimpleRule(from, to, n)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
func TestRewriteParse(t *testing.T) {
|
||||
simpleTests := []struct {
|
||||
input string
|
||||
|
@ -57,17 +70,20 @@ func TestRewriteParse(t *testing.T) {
|
|||
expected []Rule
|
||||
}{
|
||||
{`rewrite /from /to`, false, []Rule{
|
||||
SimpleRule{From: "/from", To: "/to"},
|
||||
newSimpleRule(t, "/from", "/to"),
|
||||
}},
|
||||
{`rewrite /from /to
|
||||
rewrite a b`, false, []Rule{
|
||||
SimpleRule{From: "/from", To: "/to"},
|
||||
SimpleRule{From: "a", To: "b"},
|
||||
newSimpleRule(t, "/from", "/to"),
|
||||
newSimpleRule(t, "a", "b"),
|
||||
}},
|
||||
{`rewrite a`, true, []Rule{}},
|
||||
{`rewrite`, true, []Rule{}},
|
||||
{`rewrite a b c`, false, []Rule{
|
||||
SimpleRule{From: "a", To: "b c"},
|
||||
newSimpleRule(t, "a", "b c"),
|
||||
}},
|
||||
{`rewrite not a b c`, false, []Rule{
|
||||
newSimpleRule(t, "a", "b c", true),
|
||||
}},
|
||||
}
|
||||
|
||||
|
@ -88,17 +104,22 @@ func TestRewriteParse(t *testing.T) {
|
|||
}
|
||||
|
||||
for j, e := range test.expected {
|
||||
actualRule := actual[j].(SimpleRule)
|
||||
expectedRule := e.(SimpleRule)
|
||||
actualRule := actual[j].(*SimpleRule)
|
||||
expectedRule := e.(*SimpleRule)
|
||||
|
||||
if actualRule.From != expectedRule.From {
|
||||
if actualRule.Regexp.String() != expectedRule.Regexp.String() {
|
||||
t.Errorf("Test %d, rule %d: Expected From=%s, got %s",
|
||||
i, j, expectedRule.From, actualRule.From)
|
||||
i, j, expectedRule.Regexp.String(), actualRule.Regexp.String())
|
||||
}
|
||||
|
||||
if actualRule.To != expectedRule.To {
|
||||
t.Errorf("Test %d, rule %d: Expected To=%s, got %s",
|
||||
i, j, expectedRule.To, actualRule.To)
|
||||
i, j, expectedRule.Regexp.String(), actualRule.Regexp.String())
|
||||
}
|
||||
|
||||
if actualRule.Negate != expectedRule.Negate {
|
||||
t.Errorf("Test %d, rule %d: Expected Negate=%v, got %v",
|
||||
i, j, expectedRule.Negate, actualRule.Negate)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -265,21 +265,21 @@ func fillCertFromLeaf(cert *Certificate, tlsCert tls.Certificate) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if leaf.Subject.CommonName != "" {
|
||||
if leaf.Subject.CommonName != "" { // TODO: CommonName is deprecated
|
||||
cert.Names = []string{strings.ToLower(leaf.Subject.CommonName)}
|
||||
}
|
||||
for _, name := range leaf.DNSNames {
|
||||
if name != leaf.Subject.CommonName {
|
||||
if name != leaf.Subject.CommonName { // TODO: CommonName is deprecated
|
||||
cert.Names = append(cert.Names, strings.ToLower(name))
|
||||
}
|
||||
}
|
||||
for _, ip := range leaf.IPAddresses {
|
||||
if ipStr := ip.String(); ipStr != leaf.Subject.CommonName {
|
||||
if ipStr := ip.String(); ipStr != leaf.Subject.CommonName { // TODO: CommonName is deprecated
|
||||
cert.Names = append(cert.Names, strings.ToLower(ipStr))
|
||||
}
|
||||
}
|
||||
for _, email := range leaf.EmailAddresses {
|
||||
if email != leaf.Subject.CommonName {
|
||||
if email != leaf.Subject.CommonName { // TODO: CommonName is deprecated
|
||||
cert.Names = append(cert.Names, strings.ToLower(email))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,10 +43,11 @@ func TestUnexportedGetCertificate(t *testing.T) {
|
|||
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
||||
}
|
||||
|
||||
// When no certificate matches and SNI is provided, return no certificate (should be TLS alert)
|
||||
if cert, matched, defaulted := cfg.getCertificate("nomatch"); matched || defaulted {
|
||||
t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
|
||||
}
|
||||
// TODO: Re-implement this behavior when I'm not in the middle of upgrading for ACMEv2 support. :) (it was reverted in #2037)
|
||||
// // When no certificate matches and SNI is provided, return no certificate (should be TLS alert)
|
||||
// if cert, matched, defaulted := cfg.getCertificate("nomatch"); matched || defaulted {
|
||||
// t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
|
||||
// }
|
||||
|
||||
// When no certificate matches and SNI is NOT provided, a random is returned
|
||||
if cert, matched, defaulted := cfg.getCertificate(""); matched || !defaulted {
|
||||
|
|
|
@ -27,7 +27,7 @@ import (
|
|||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/telemetry"
|
||||
"github.com/xenolf/lego/acme"
|
||||
"github.com/xenolf/lego/acmev2"
|
||||
)
|
||||
|
||||
// acmeMu ensures that only one ACME challenge occurs at a time.
|
||||
|
@ -90,27 +90,22 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
|
|||
// If not registered, the user must register an account with the CA
|
||||
// and agree to terms
|
||||
if leUser.Registration == nil {
|
||||
reg, err := client.Register()
|
||||
if allowPrompts { // can't prompt a user who isn't there
|
||||
termsURL := client.GetToSURL()
|
||||
if !Agreed && termsURL != "" {
|
||||
Agreed = askUserAgreement(client.GetToSURL())
|
||||
}
|
||||
if !Agreed && termsURL != "" {
|
||||
return nil, errors.New("user must agree to CA terms (use -agree flag)")
|
||||
}
|
||||
}
|
||||
|
||||
reg, err := client.Register(Agreed)
|
||||
if err != nil {
|
||||
return nil, errors.New("registration error: " + err.Error())
|
||||
}
|
||||
leUser.Registration = reg
|
||||
|
||||
if allowPrompts { // can't prompt a user who isn't there
|
||||
if !Agreed && reg.TosURL == "" {
|
||||
Agreed = promptUserAgreement(saURL, false) // TODO - latest URL
|
||||
}
|
||||
if !Agreed && reg.TosURL == "" {
|
||||
return nil, errors.New("user must agree to terms")
|
||||
}
|
||||
}
|
||||
|
||||
err = client.AgreeToTOS()
|
||||
if err != nil {
|
||||
saveUser(storage, leUser) // Might as well try, right?
|
||||
return nil, errors.New("error agreeing to terms: " + err.Error())
|
||||
}
|
||||
|
||||
// save user to the file system
|
||||
err = saveUser(storage, leUser)
|
||||
if err != nil {
|
||||
|
@ -137,38 +132,57 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
|
|||
useHTTPPort = DefaultHTTPAlternatePort
|
||||
}
|
||||
|
||||
// TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
|
||||
// See which port TLS-SNI challenges will be accomplished on
|
||||
useTLSSNIPort := TLSSNIChallengePort
|
||||
if config.AltTLSSNIPort != "" {
|
||||
useTLSSNIPort = config.AltTLSSNIPort
|
||||
}
|
||||
|
||||
// Always respect user's bind preferences by using config.ListenHost.
|
||||
// NOTE(Sep'16): At time of writing, SetHTTPAddress() and SetTLSAddress()
|
||||
// must be called before SetChallengeProvider(), since they reset the
|
||||
// challenge provider back to the default one!
|
||||
err := c.acmeClient.SetHTTPAddress(net.JoinHostPort(config.ListenHost, useHTTPPort))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = c.acmeClient.SetTLSAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// useTLSSNIPort := TLSSNIChallengePort
|
||||
// if config.AltTLSSNIPort != "" {
|
||||
// useTLSSNIPort = config.AltTLSSNIPort
|
||||
// }
|
||||
// err := c.acmeClient.SetTLSAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort))
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// if using file storage, we can distribute the HTTP challenge across
|
||||
// all instances sharing the acme folder; either way, we must still set
|
||||
// the address for the default HTTP provider server
|
||||
var useDistributedHTTPSolver bool
|
||||
if storage, err := c.config.StorageFor(c.config.CAUrl); err == nil {
|
||||
if _, ok := storage.(*FileStorage); ok {
|
||||
useDistributedHTTPSolver = true
|
||||
}
|
||||
}
|
||||
if useDistributedHTTPSolver {
|
||||
c.acmeClient.SetChallengeProvider(acme.HTTP01, distributedHTTPSolver{
|
||||
// being careful to respect user's listener bind preferences
|
||||
httpProviderServer: acme.NewHTTPProviderServer(config.ListenHost, useHTTPPort),
|
||||
})
|
||||
} else {
|
||||
// Always respect user's bind preferences by using config.ListenHost.
|
||||
// NOTE(Sep'16): At time of writing, SetHTTPAddress() and SetTLSAddress()
|
||||
// must be called before SetChallengeProvider() (see above), since they reset
|
||||
// the challenge provider back to the default one! (still true in March 2018)
|
||||
err := c.acmeClient.SetHTTPAddress(net.JoinHostPort(config.ListenHost, useHTTPPort))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
|
||||
// See if TLS challenge needs to be handled by our own facilities
|
||||
if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) {
|
||||
c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSNISolver{certCache: config.certCache})
|
||||
}
|
||||
// if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) {
|
||||
// c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSNISolver{certCache: config.certCache})
|
||||
// }
|
||||
|
||||
// Disable any challenges that should not be used
|
||||
var disabledChallenges []acme.Challenge
|
||||
if DisableHTTPChallenge {
|
||||
disabledChallenges = append(disabledChallenges, acme.HTTP01)
|
||||
}
|
||||
if DisableTLSSNIChallenge {
|
||||
disabledChallenges = append(disabledChallenges, acme.TLSSNI01)
|
||||
}
|
||||
// TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
|
||||
// if DisableTLSSNIChallenge {
|
||||
// disabledChallenges = append(disabledChallenges, acme.TLSSNI01)
|
||||
// }
|
||||
if len(disabledChallenges) > 0 {
|
||||
c.acmeClient.ExcludeChallenges(disabledChallenges)
|
||||
}
|
||||
|
@ -189,7 +203,9 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
|
|||
}
|
||||
|
||||
// Use the DNS challenge exclusively
|
||||
c.acmeClient.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.TLSSNI01})
|
||||
// TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
|
||||
// c.acmeClient.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.TLSSNI01})
|
||||
c.acmeClient.ExcludeChallenges([]acme.Challenge{acme.HTTP01})
|
||||
c.acmeClient.SetChallengeProvider(acme.DNS01, prov)
|
||||
}
|
||||
|
||||
|
@ -222,41 +238,31 @@ func (c *ACMEClient) Obtain(name string) error {
|
|||
}
|
||||
}()
|
||||
|
||||
Attempts:
|
||||
for attempts := 0; attempts < 2; attempts++ {
|
||||
namesObtaining.Add([]string{name})
|
||||
acmeMu.Lock()
|
||||
certificate, failures := c.acmeClient.ObtainCertificate([]string{name}, true, nil, c.config.MustStaple)
|
||||
certificate, err := c.acmeClient.ObtainCertificate([]string{name}, true, nil, c.config.MustStaple)
|
||||
acmeMu.Unlock()
|
||||
namesObtaining.Remove([]string{name})
|
||||
if len(failures) > 0 {
|
||||
// Error - try to fix it or report it to the user and abort
|
||||
var errMsg string // we'll combine all the failures into a single error message
|
||||
var promptedForAgreement bool // only prompt user for agreement at most once
|
||||
|
||||
for errDomain, obtainErr := range failures {
|
||||
if obtainErr == nil {
|
||||
continue
|
||||
}
|
||||
if tosErr, ok := obtainErr.(acme.TOSError); ok {
|
||||
// Terms of Service agreement error; we can probably deal with this
|
||||
if !Agreed && !promptedForAgreement && c.AllowPrompts {
|
||||
Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
|
||||
promptedForAgreement = true
|
||||
}
|
||||
if Agreed || !c.AllowPrompts {
|
||||
err := c.acmeClient.AgreeToTOS()
|
||||
if err != nil {
|
||||
return errors.New("error agreeing to updated terms: " + err.Error())
|
||||
}
|
||||
continue Attempts
|
||||
if err != nil {
|
||||
// for a certain kind of error, we can enumerate the error per-domain
|
||||
if failures, ok := err.(acme.ObtainError); ok && len(failures) > 0 {
|
||||
var errMsg string // combine all the failures into a single error message
|
||||
for errDomain, obtainErr := range failures {
|
||||
if obtainErr == nil {
|
||||
continue
|
||||
}
|
||||
errMsg += fmt.Sprintf("[%s] failed to get certificate: %v\n", errDomain, obtainErr)
|
||||
}
|
||||
|
||||
// If user did not agree or it was any other kind of error, just append to the list of errors
|
||||
errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
return errors.New(errMsg)
|
||||
|
||||
return fmt.Errorf("[%s] failed to obtain certificate: %v", name, err)
|
||||
}
|
||||
|
||||
// double-check that we actually got a certificate, in case there's a bug upstream (see issue #2121)
|
||||
if certificate.Domain == "" || certificate.Certificate == nil {
|
||||
return errors.New("returned certificate was empty; probably an unchecked error obtaining it")
|
||||
}
|
||||
|
||||
// Success - immediately save the certificate resource
|
||||
|
@ -315,23 +321,20 @@ func (c *ACMEClient) Renew(name string) error {
|
|||
acmeMu.Unlock()
|
||||
namesObtaining.Remove([]string{name})
|
||||
if err == nil {
|
||||
success = true
|
||||
break
|
||||
}
|
||||
|
||||
// If the legal terms were updated and need to be
|
||||
// agreed to again, we can handle that.
|
||||
if _, ok := err.(acme.TOSError); ok {
|
||||
err := c.acmeClient.AgreeToTOS()
|
||||
if err != nil {
|
||||
return err
|
||||
// double-check that we actually got a certificate; check a couple fields
|
||||
// TODO: This is a temporary workaround for what I think is a bug in the acmev2 package (March 2018)
|
||||
// but it might not hurt to keep this extra check in place
|
||||
if newCertMeta.Domain == "" || newCertMeta.Certificate == nil {
|
||||
err = errors.New("returned certificate was empty; probably an unchecked error renewing it")
|
||||
} else {
|
||||
success = true
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// For any other kind of error, wait 10s and try again.
|
||||
// wait a little bit and try again
|
||||
wait := 10 * time.Second
|
||||
log.Printf("[ERROR] Renewing: %v; trying again in %s", err, wait)
|
||||
log.Printf("[ERROR] Renewing [%v]: %v; trying again in %s", name, err, wait)
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ import (
|
|||
|
||||
"github.com/klauspost/cpuid"
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/xenolf/lego/acme"
|
||||
"github.com/xenolf/lego/acmev2"
|
||||
)
|
||||
|
||||
// Config describes how TLS should be configured and used.
|
||||
|
@ -190,10 +190,15 @@ func NewConfig(inst *caddy.Instance) *Config {
|
|||
// it does not load them into memory. If allowPrompts is true,
|
||||
// the user may be shown a prompt.
|
||||
func (c *Config) ObtainCert(name string, allowPrompts bool) error {
|
||||
if !c.Managed || !HostQualifies(name) {
|
||||
skip, err := c.preObtainOrRenewChecks(name, allowPrompts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if skip {
|
||||
return nil
|
||||
}
|
||||
|
||||
// we expect this to be a new (non-existent) site
|
||||
storage, err := c.StorageFor(c.CAUrl)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -205,9 +210,6 @@ func (c *Config) ObtainCert(name string, allowPrompts bool) error {
|
|||
if siteExists {
|
||||
return nil
|
||||
}
|
||||
if c.ACMEEmail == "" {
|
||||
c.ACMEEmail = getEmail(storage, allowPrompts)
|
||||
}
|
||||
|
||||
client, err := newACMEClient(c, allowPrompts)
|
||||
if err != nil {
|
||||
|
@ -219,6 +221,14 @@ func (c *Config) ObtainCert(name string, allowPrompts bool) error {
|
|||
// RenewCert renews the certificate for name using c. It stows the
|
||||
// renewed certificate and its assets in storage if successful.
|
||||
func (c *Config) RenewCert(name string, allowPrompts bool) error {
|
||||
skip, err := c.preObtainOrRenewChecks(name, allowPrompts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if skip {
|
||||
return nil
|
||||
}
|
||||
|
||||
client, err := newACMEClient(c, allowPrompts)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -226,6 +236,33 @@ func (c *Config) RenewCert(name string, allowPrompts bool) error {
|
|||
return client.Renew(name)
|
||||
}
|
||||
|
||||
// preObtainOrRenewChecks perform a few simple checks before
|
||||
// obtaining or renewing a certificate with ACME, and returns
|
||||
// whether this name should be skipped (like if it's not
|
||||
// managed TLS) as well as any error. It ensures that the
|
||||
// config is Managed, that the name qualifies for a certificate,
|
||||
// and that an email address is available.
|
||||
func (c *Config) preObtainOrRenewChecks(name string, allowPrompts bool) (bool, error) {
|
||||
if !c.Managed || !HostQualifies(name) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// wildcard certificates require DNS challenge (as of March 2018)
|
||||
if strings.Contains(name, "*") && c.DNSProvider == "" {
|
||||
return false, fmt.Errorf("wildcard domain name (%s) requires DNS challenge; use dns subdirective to configure it", name)
|
||||
}
|
||||
|
||||
if c.ACMEEmail == "" {
|
||||
var err error
|
||||
c.ACMEEmail, err = getEmail(c, allowPrompts)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// StorageFor obtains a TLS Storage instance for the given CA URL which should
|
||||
// be unique for every different ACME CA. If a StorageCreator is set on this
|
||||
// Config, it will be used. Otherwise the default file storage implementation
|
||||
|
@ -476,6 +513,14 @@ func assertConfigsCompatible(cfg1, cfg2 *Config) error {
|
|||
if c1.ClientAuth != c2.ClientAuth {
|
||||
return fmt.Errorf("client authentication policy mismatch")
|
||||
}
|
||||
if c1.ClientAuth != tls.NoClientCert && c2.ClientAuth != tls.NoClientCert && c1.ClientCAs != c2.ClientCAs {
|
||||
// Two hosts defined on the same listener are not compatible if they
|
||||
// have ClientAuth enabled, because there's no guarantee beyond the
|
||||
// hostname which config will be used (because SNI only has server name).
|
||||
// To prevent clients from bypassing authentication, require that
|
||||
// ClientAuth be configured in an unambiguous manner.
|
||||
return fmt.Errorf("multiple hosts requiring client authentication ambiguously configured")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -511,7 +556,7 @@ func SetDefaultTLSParams(config *Config) {
|
|||
|
||||
// Set default protocol min and max versions - must balance compatibility and security
|
||||
if config.ProtocolMinVersion == 0 {
|
||||
config.ProtocolMinVersion = tls.VersionTLS11
|
||||
config.ProtocolMinVersion = tls.VersionTLS12
|
||||
}
|
||||
if config.ProtocolMaxVersion == 0 {
|
||||
config.ProtocolMaxVersion = tls.VersionTLS12
|
||||
|
@ -532,7 +577,8 @@ var supportedKeyTypes = map[string]acme.KeyType{
|
|||
|
||||
// Map of supported protocols.
|
||||
// HTTP/2 only supports TLS 1.2 and higher.
|
||||
var supportedProtocols = map[string]uint16{
|
||||
// If updating this map, also update tlsProtocolStringToMap in caddyhttp/fastcgi/fastcgi.go
|
||||
var SupportedProtocols = map[string]uint16{
|
||||
"tls1.0": tls.VersionTLS10,
|
||||
"tls1.1": tls.VersionTLS11,
|
||||
"tls1.2": tls.VersionTLS12,
|
||||
|
@ -548,7 +594,7 @@ var supportedProtocols = map[string]uint16{
|
|||
// it is always added (even though it is not technically a cipher suite).
|
||||
//
|
||||
// This map, like any map, is NOT ORDERED. Do not range over this map.
|
||||
var supportedCiphersMap = map[string]uint16{
|
||||
var SupportedCiphersMap = map[string]uint16{
|
||||
"ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
"ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
|
|
|
@ -35,13 +35,14 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ocsp"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/xenolf/lego/acme"
|
||||
"github.com/xenolf/lego/acmev2"
|
||||
)
|
||||
|
||||
// loadPrivateKey loads a PEM-encoded ECC/RSA private key from an array of bytes.
|
||||
|
@ -106,7 +107,8 @@ func stapleOCSP(cert *Certificate, pemBundle []byte) error {
|
|||
// TODO: Use Storage interface instead of disk directly
|
||||
var ocspFileNamePrefix string
|
||||
if len(cert.Names) > 0 {
|
||||
ocspFileNamePrefix = cert.Names[0] + "-"
|
||||
firstName := strings.Replace(cert.Names[0], "*", "wildcard_", -1)
|
||||
ocspFileNamePrefix = firstName + "-"
|
||||
}
|
||||
ocspFileName := ocspFileNamePrefix + fastHash(pemBundle)
|
||||
ocspCachePath := filepath.Join(ocspFolder, ocspFileName)
|
||||
|
@ -216,10 +218,13 @@ func makeSelfSignedCert(config *Config) error {
|
|||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
var names []string
|
||||
if ip := net.ParseIP(config.Hostname); ip != nil {
|
||||
names = append(names, strings.ToLower(ip.String()))
|
||||
cert.IPAddresses = append(cert.IPAddresses, ip)
|
||||
} else {
|
||||
cert.DNSNames = append(cert.DNSNames, config.Hostname)
|
||||
names = append(names, strings.ToLower(config.Hostname))
|
||||
cert.DNSNames = append(cert.DNSNames, strings.ToLower(config.Hostname))
|
||||
}
|
||||
|
||||
publicKey := func(privKey interface{}) interface{} {
|
||||
|
@ -245,7 +250,7 @@ func makeSelfSignedCert(config *Config) error {
|
|||
PrivateKey: privKey,
|
||||
Leaf: cert,
|
||||
},
|
||||
Names: cert.DNSNames,
|
||||
Names: names,
|
||||
NotAfter: cert.NotAfter,
|
||||
Hash: hashCertificateChain(chain),
|
||||
})
|
||||
|
|
|
@ -30,14 +30,14 @@ func init() {
|
|||
RegisterStorageProvider("file", NewFileStorage)
|
||||
}
|
||||
|
||||
// storageBasePath is the root path in which all TLS/ACME assets are
|
||||
// stored. Do not change this value during the lifetime of the program.
|
||||
var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme")
|
||||
|
||||
// NewFileStorage is a StorageConstructor function that creates a new
|
||||
// Storage instance backed by the local disk. The resulting Storage
|
||||
// instance is guaranteed to be non-nil if there is no error.
|
||||
func NewFileStorage(caURL *url.URL) (Storage, error) {
|
||||
// storageBasePath is the root path in which all TLS/ACME assets are
|
||||
// stored. Do not change this value during the lifetime of the program.
|
||||
storageBasePath := filepath.Join(caddy.AssetsPath(), "acme")
|
||||
|
||||
storage := &FileStorage{Path: filepath.Join(storageBasePath, caURL.Host)}
|
||||
storage.Locker = &fileStorageLock{caURL: caURL.Host, storage: storage}
|
||||
return storage, nil
|
||||
|
@ -58,25 +58,25 @@ func (s *FileStorage) sites() string {
|
|||
|
||||
// site returns the path to the folder containing assets for domain.
|
||||
func (s *FileStorage) site(domain string) string {
|
||||
domain = strings.ToLower(domain)
|
||||
domain = fileSafe(domain)
|
||||
return filepath.Join(s.sites(), domain)
|
||||
}
|
||||
|
||||
// siteCertFile returns the path to the certificate file for domain.
|
||||
func (s *FileStorage) siteCertFile(domain string) string {
|
||||
domain = strings.ToLower(domain)
|
||||
domain = fileSafe(domain)
|
||||
return filepath.Join(s.site(domain), domain+".crt")
|
||||
}
|
||||
|
||||
// siteKeyFile returns the path to domain's private key file.
|
||||
func (s *FileStorage) siteKeyFile(domain string) string {
|
||||
domain = strings.ToLower(domain)
|
||||
domain = fileSafe(domain)
|
||||
return filepath.Join(s.site(domain), domain+".key")
|
||||
}
|
||||
|
||||
// siteMetaFile returns the path to the domain's asset metadata file.
|
||||
func (s *FileStorage) siteMetaFile(domain string) string {
|
||||
domain = strings.ToLower(domain)
|
||||
domain = fileSafe(domain)
|
||||
return filepath.Join(s.site(domain), domain+".json")
|
||||
}
|
||||
|
||||
|
@ -90,7 +90,7 @@ func (s *FileStorage) user(email string) string {
|
|||
if email == "" {
|
||||
email = emptyEmail
|
||||
}
|
||||
email = strings.ToLower(email)
|
||||
email = fileSafe(email)
|
||||
return filepath.Join(s.users(), email)
|
||||
}
|
||||
|
||||
|
@ -117,6 +117,7 @@ func (s *FileStorage) userRegFile(email string) string {
|
|||
if fileName == "" {
|
||||
fileName = "registration"
|
||||
}
|
||||
fileName = fileSafe(fileName)
|
||||
return filepath.Join(s.user(email), fileName+".json")
|
||||
}
|
||||
|
||||
|
@ -131,6 +132,7 @@ func (s *FileStorage) userKeyFile(email string) string {
|
|||
if fileName == "" {
|
||||
fileName = "private"
|
||||
}
|
||||
fileName = fileSafe(fileName)
|
||||
return filepath.Join(s.user(email), fileName+".key")
|
||||
}
|
||||
|
||||
|
@ -274,3 +276,29 @@ func (s *FileStorage) MostRecentUserEmail() string {
|
|||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// fileSafe standardizes and sanitizes str for use in a file path.
|
||||
func fileSafe(str string) string {
|
||||
str = strings.ToLower(str)
|
||||
str = strings.TrimSpace(str)
|
||||
repl := strings.NewReplacer("..", "",
|
||||
"/", "",
|
||||
"\\", "",
|
||||
// TODO: Consider also replacing "@" with "_at_" (but migrate existing accounts...)
|
||||
"+", "_plus_",
|
||||
"%", "",
|
||||
"$", "",
|
||||
"`", "",
|
||||
"~", "",
|
||||
":", "",
|
||||
";", "",
|
||||
"=", "",
|
||||
"!", "",
|
||||
"#", "",
|
||||
"&", "",
|
||||
"|", "",
|
||||
"\"", "",
|
||||
"'", "",
|
||||
"*", "wildcard_")
|
||||
return repl.Replace(str)
|
||||
}
|
||||
|
|
|
@ -14,7 +14,71 @@
|
|||
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// *********************************** NOTE ********************************
|
||||
// Due to circular package dependencies with the storagetest sub package and
|
||||
// the fact that we want to use that harness to test file storage, the tests
|
||||
// for file storage are done in the storagetest package.
|
||||
// the fact that we want to use that harness to test file storage, most of
|
||||
// the tests for file storage are done in the storagetest package.
|
||||
|
||||
func TestPathBuilders(t *testing.T) {
|
||||
fs := FileStorage{Path: "test"}
|
||||
|
||||
for i, testcase := range []struct {
|
||||
in, folder, certFile, keyFile, metaFile string
|
||||
}{
|
||||
{
|
||||
in: "example.com",
|
||||
folder: filepath.Join("test", "sites", "example.com"),
|
||||
certFile: filepath.Join("test", "sites", "example.com", "example.com.crt"),
|
||||
keyFile: filepath.Join("test", "sites", "example.com", "example.com.key"),
|
||||
metaFile: filepath.Join("test", "sites", "example.com", "example.com.json"),
|
||||
},
|
||||
{
|
||||
in: "*.example.com",
|
||||
folder: filepath.Join("test", "sites", "wildcard_.example.com"),
|
||||
certFile: filepath.Join("test", "sites", "wildcard_.example.com", "wildcard_.example.com.crt"),
|
||||
keyFile: filepath.Join("test", "sites", "wildcard_.example.com", "wildcard_.example.com.key"),
|
||||
metaFile: filepath.Join("test", "sites", "wildcard_.example.com", "wildcard_.example.com.json"),
|
||||
},
|
||||
{
|
||||
// prevent directory traversal! very important, esp. with on-demand TLS
|
||||
// see issue #2092
|
||||
in: "a/../../../foo",
|
||||
folder: filepath.Join("test", "sites", "afoo"),
|
||||
certFile: filepath.Join("test", "sites", "afoo", "afoo.crt"),
|
||||
keyFile: filepath.Join("test", "sites", "afoo", "afoo.key"),
|
||||
metaFile: filepath.Join("test", "sites", "afoo", "afoo.json"),
|
||||
},
|
||||
{
|
||||
in: "b\\..\\..\\..\\foo",
|
||||
folder: filepath.Join("test", "sites", "bfoo"),
|
||||
certFile: filepath.Join("test", "sites", "bfoo", "bfoo.crt"),
|
||||
keyFile: filepath.Join("test", "sites", "bfoo", "bfoo.key"),
|
||||
metaFile: filepath.Join("test", "sites", "bfoo", "bfoo.json"),
|
||||
},
|
||||
{
|
||||
in: "c/foo",
|
||||
folder: filepath.Join("test", "sites", "cfoo"),
|
||||
certFile: filepath.Join("test", "sites", "cfoo", "cfoo.crt"),
|
||||
keyFile: filepath.Join("test", "sites", "cfoo", "cfoo.key"),
|
||||
metaFile: filepath.Join("test", "sites", "cfoo", "cfoo.json"),
|
||||
},
|
||||
} {
|
||||
if actual := fs.site(testcase.in); actual != testcase.folder {
|
||||
t.Errorf("Test %d: site folder: Expected '%s' but got '%s'", i, testcase.folder, actual)
|
||||
}
|
||||
if actual := fs.siteCertFile(testcase.in); actual != testcase.certFile {
|
||||
t.Errorf("Test %d: site cert file: Expected '%s' but got '%s'", i, testcase.certFile, actual)
|
||||
}
|
||||
if actual := fs.siteKeyFile(testcase.in); actual != testcase.keyFile {
|
||||
t.Errorf("Test %d: site key file: Expected '%s' but got '%s'", i, testcase.keyFile, actual)
|
||||
}
|
||||
if actual := fs.siteMetaFile(testcase.in); actual != testcase.metaFile {
|
||||
t.Errorf("Test %d: site meta file: Expected '%s' but got '%s'", i, testcase.metaFile, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -91,7 +91,20 @@ func (s *fileStorageLock) Unlock(name string) error {
|
|||
if !ok {
|
||||
return fmt.Errorf("FileStorage: no lock to release for %s", name)
|
||||
}
|
||||
// remove lock file
|
||||
os.Remove(fw.filename)
|
||||
|
||||
// if parent folder is now empty, remove it too to keep it tidy
|
||||
lockParentFolder := s.storage.site(name)
|
||||
dir, err := os.Open(lockParentFolder)
|
||||
if err == nil {
|
||||
items, _ := dir.Readdirnames(3) // OK to ignore error here
|
||||
if len(items) == 0 {
|
||||
os.Remove(lockParentFolder)
|
||||
}
|
||||
dir.Close()
|
||||
}
|
||||
|
||||
fw.wg.Done()
|
||||
delete(fileStorageNameLocks, s.caURL+name)
|
||||
return nil
|
||||
|
|
|
@ -61,10 +61,9 @@ func (cg configGroup) getConfig(name string) *Config {
|
|||
}
|
||||
}
|
||||
|
||||
// try a config that serves all names (this
|
||||
// is basically the same as a config defined
|
||||
// for "*" -- I think -- but the above loop
|
||||
// doesn't try an empty string)
|
||||
// try a config that serves all names (the above
|
||||
// loop doesn't try empty string; for hosts defined
|
||||
// with only a port, for instance, like ":443")
|
||||
if config, ok := cg[""]; ok {
|
||||
return config
|
||||
}
|
||||
|
@ -190,17 +189,19 @@ func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defau
|
|||
return
|
||||
}
|
||||
|
||||
// if nothing matches and SNI was not provided, use a random
|
||||
// certificate; at least there's a chance this older client
|
||||
// can connect, and in the future we won't need this provision
|
||||
// (if SNI is present, it's probably best to just raise a TLS
|
||||
// alert by not serving a certificate)
|
||||
if name == "" {
|
||||
for _, certKey := range cfg.Certificates {
|
||||
defaulted = true
|
||||
cert = cfg.certCache.cache[certKey]
|
||||
return
|
||||
}
|
||||
// if nothing matches, use a random certificate
|
||||
// TODO: This is not my favorite behavior; I would rather serve
|
||||
// no certificate if SNI is provided and cause a TLS alert, than
|
||||
// serve the wrong certificate (but sometimes the 'wrong' cert
|
||||
// is what is wanted, but in those cases I would prefer that the
|
||||
// site owner explicitly configure a "default" certificate).
|
||||
// (See issue 2035; any change to this behavior must account for
|
||||
// hosts defined like ":443" or "0.0.0.0:443" where the hostname
|
||||
// is empty or a catch-all IP or something.)
|
||||
for _, certKey := range cfg.Certificates {
|
||||
cert = cfg.certCache.cache[certKey]
|
||||
defaulted = true
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
|
|
|
@ -27,7 +27,7 @@ func TestGetCertificate(t *testing.T) {
|
|||
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
|
||||
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
|
||||
helloNoSNI := &tls.ClientHelloInfo{}
|
||||
helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"}
|
||||
// helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"} // TODO (see below)
|
||||
|
||||
// When cache is empty
|
||||
if cert, err := cfg.GetCertificate(hello); err == nil {
|
||||
|
@ -69,8 +69,9 @@ func TestGetCertificate(t *testing.T) {
|
|||
t.Errorf("Expected random cert with no matches, got: %v", cert)
|
||||
}
|
||||
|
||||
// TODO: Re-implement this behavior (it was reverted in #2037)
|
||||
// When no certificate matches, raise an alert
|
||||
if _, err := cfg.GetCertificate(helloNoMatch); err == nil {
|
||||
t.Errorf("Expected an error when no certificate matched the SNI, got: %v", err)
|
||||
}
|
||||
// if _, err := cfg.GetCertificate(helloNoMatch); err == nil {
|
||||
// t.Errorf("Expected an error when no certificate matched the SNI, got: %v", err)
|
||||
// }
|
||||
}
|
||||
|
|
|
@ -16,12 +16,16 @@ package caddytls
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/xenolf/lego/acmev2"
|
||||
)
|
||||
|
||||
const challengeBasePath = "/.well-known/acme-challenge"
|
||||
|
@ -38,6 +42,13 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost str
|
|||
if DisableHTTPChallenge {
|
||||
return false
|
||||
}
|
||||
|
||||
// see if another instance started the HTTP challenge for this name
|
||||
if tryDistributedChallengeSolver(w, r) {
|
||||
return true
|
||||
}
|
||||
|
||||
// otherwise, if we aren't getting the name, then ignore this challenge
|
||||
if !namesObtaining.Has(r.Host) {
|
||||
return false
|
||||
}
|
||||
|
@ -70,3 +81,40 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost str
|
|||
|
||||
return true
|
||||
}
|
||||
|
||||
// tryDistributedChallengeSolver checks to see if this challenge
|
||||
// request was initiated by another instance that shares file
|
||||
// storage, and attempts to complete the challenge for it. It
|
||||
// returns true if the challenge was handled; false otherwise.
|
||||
func tryDistributedChallengeSolver(w http.ResponseWriter, r *http.Request) bool {
|
||||
filePath := distributedHTTPSolver{}.challengeTokensPath(r.Host)
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
log.Printf("[ERROR][%s] Opening distributed challenge token file: %v", r.Host, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var chalInfo challengeInfo
|
||||
err = json.NewDecoder(f).Decode(&chalInfo)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR][%s] Decoding challenge token file %s (corrupted?): %v", r.Host, filePath, err)
|
||||
return false
|
||||
}
|
||||
|
||||
// this part borrowed from xenolf/lego's built-in HTTP-01 challenge solver (March 2018)
|
||||
challengeReqPath := acme.HTTP01ChallengePath(chalInfo.Token)
|
||||
if r.URL.Path == challengeReqPath &&
|
||||
strings.HasPrefix(r.Host, chalInfo.Domain) &&
|
||||
r.Method == "GET" {
|
||||
w.Header().Add("Content-Type", "text/plain")
|
||||
w.Write([]byte(chalInfo.KeyAuth))
|
||||
r.Close = true
|
||||
log.Printf("[INFO][%s] Served key authentication", chalInfo.Domain)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -334,6 +334,7 @@ func DeleteOldStapleFiles() {
|
|||
if err != nil {
|
||||
log.Printf("[ERROR] Purging corrupt staple file %s: %v", stapleFile, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if time.Now().After(resp.NextUpdate) {
|
||||
// response has expired; delete it
|
||||
|
|
|
@ -107,19 +107,19 @@ func setupTLS(c *caddy.Controller) error {
|
|||
case "protocols":
|
||||
args := c.RemainingArgs()
|
||||
if len(args) == 1 {
|
||||
value, ok := supportedProtocols[strings.ToLower(args[0])]
|
||||
value, ok := SupportedProtocols[strings.ToLower(args[0])]
|
||||
if !ok {
|
||||
return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[0])
|
||||
}
|
||||
|
||||
config.ProtocolMinVersion, config.ProtocolMaxVersion = value, value
|
||||
} else {
|
||||
value, ok := supportedProtocols[strings.ToLower(args[0])]
|
||||
value, ok := SupportedProtocols[strings.ToLower(args[0])]
|
||||
if !ok {
|
||||
return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[0])
|
||||
}
|
||||
config.ProtocolMinVersion = value
|
||||
value, ok = supportedProtocols[strings.ToLower(args[1])]
|
||||
value, ok = SupportedProtocols[strings.ToLower(args[1])]
|
||||
if !ok {
|
||||
return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[1])
|
||||
}
|
||||
|
@ -130,7 +130,7 @@ func setupTLS(c *caddy.Controller) error {
|
|||
}
|
||||
case "ciphers":
|
||||
for c.NextArg() {
|
||||
value, ok := supportedCiphersMap[strings.ToUpper(c.Val())]
|
||||
value, ok := SupportedCiphersMap[strings.ToUpper(c.Val())]
|
||||
if !ok {
|
||||
return c.Errf("Wrong cipher name or cipher not supported: '%s'", c.Val())
|
||||
}
|
||||
|
@ -210,8 +210,21 @@ func setupTLS(c *caddy.Controller) error {
|
|||
}
|
||||
case "must_staple":
|
||||
config.MustStaple = true
|
||||
case "wildcard":
|
||||
if !HostQualifies(config.Hostname) {
|
||||
return c.Errf("Hostname '%s' does not qualify for managed TLS, so cannot manage wildcard certificate for it", config.Hostname)
|
||||
}
|
||||
if strings.Contains(config.Hostname, "*") {
|
||||
return c.Errf("Cannot convert domain name '%s' to a valid wildcard: already has a wildcard label", config.Hostname)
|
||||
}
|
||||
parts := strings.Split(config.Hostname, ".")
|
||||
if len(parts) < 3 {
|
||||
return c.Errf("Cannot convert domain name '%s' to a valid wildcard: too few labels", config.Hostname)
|
||||
}
|
||||
parts[0] = "*"
|
||||
config.Hostname = strings.Join(parts, ".")
|
||||
default:
|
||||
return c.Errf("Unknown keyword '%s'", c.Val())
|
||||
return c.Errf("Unknown subdirective '%s'", c.Val())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/xenolf/lego/acme"
|
||||
"github.com/xenolf/lego/acmev2"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
@ -67,8 +67,8 @@ func TestSetupParseBasic(t *testing.T) {
|
|||
}
|
||||
|
||||
// Security defaults
|
||||
if cfg.ProtocolMinVersion != tls.VersionTLS11 {
|
||||
t.Errorf("Expected 'tls1.1 (0x0302)' as ProtocolMinVersion, got %#v", cfg.ProtocolMinVersion)
|
||||
if cfg.ProtocolMinVersion != tls.VersionTLS12 {
|
||||
t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMinVersion, got %#v", cfg.ProtocolMinVersion)
|
||||
}
|
||||
if cfg.ProtocolMaxVersion != tls.VersionTLS12 {
|
||||
t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMaxVersion, got %v", cfg.ProtocolMaxVersion)
|
||||
|
|
|
@ -58,7 +58,8 @@ type Locker interface {
|
|||
// successfully obtained the lock (no Waiter value was returned)
|
||||
// should call this method, and it should be called only after
|
||||
// the obtain/renew and store are finished, even if there was
|
||||
// an error (or a timeout).
|
||||
// an error (or a timeout). Unlock should also clean up any
|
||||
// unused resources allocated during TryLock.
|
||||
Unlock(name string) error
|
||||
}
|
||||
|
||||
|
|
149
caddytls/tls.go
149
caddytls/tls.go
|
@ -30,26 +30,35 @@ package caddytls
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/xenolf/lego/acme"
|
||||
"github.com/xenolf/lego/acmev2"
|
||||
)
|
||||
|
||||
// HostQualifies returns true if the hostname alone
|
||||
// appears eligible for automatic HTTPS. For example,
|
||||
// appears eligible for automatic HTTPS. For example:
|
||||
// localhost, empty hostname, and IP addresses are
|
||||
// not eligible because we cannot obtain certificates
|
||||
// for those names.
|
||||
// for those names. Wildcard names are allowed, as long
|
||||
// as they conform to CABF requirements (only one wildcard
|
||||
// label, and it must be the left-most label).
|
||||
func HostQualifies(hostname string) bool {
|
||||
return hostname != "localhost" && // localhost is ineligible
|
||||
|
||||
// hostname must not be empty
|
||||
strings.TrimSpace(hostname) != "" &&
|
||||
|
||||
// must not contain wildcard (*) characters (until CA supports it)
|
||||
!strings.Contains(hostname, "*") &&
|
||||
// only one wildcard label allowed, and it must be left-most
|
||||
(!strings.Contains(hostname, "*") ||
|
||||
(strings.Count(hostname, "*") == 1 &&
|
||||
strings.HasPrefix(hostname, "*."))) &&
|
||||
|
||||
// must not start or end with a dot
|
||||
!strings.HasPrefix(hostname, ".") &&
|
||||
|
@ -88,39 +97,125 @@ func Revoke(host string) error {
|
|||
return client.Revoke(host)
|
||||
}
|
||||
|
||||
// tlsSNISolver is a type that can solve TLS-SNI challenges using
|
||||
// an existing listener and our custom, in-memory certificate cache.
|
||||
type tlsSNISolver struct {
|
||||
certCache *certificateCache
|
||||
// TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
|
||||
// // tlsSNISolver is a type that can solve TLS-SNI challenges using
|
||||
// // an existing listener and our custom, in-memory certificate cache.
|
||||
// type tlsSNISolver struct {
|
||||
// certCache *certificateCache
|
||||
// }
|
||||
|
||||
// // Present adds the challenge certificate to the cache.
|
||||
// func (s tlsSNISolver) Present(domain, token, keyAuth string) error {
|
||||
// cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// certHash := hashCertificateChain(cert.Certificate)
|
||||
// s.certCache.Lock()
|
||||
// s.certCache.cache[acmeDomain] = Certificate{
|
||||
// Certificate: cert,
|
||||
// Names: []string{acmeDomain},
|
||||
// Hash: certHash, // perhaps not necesssary
|
||||
// }
|
||||
// s.certCache.Unlock()
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// // CleanUp removes the challenge certificate from the cache.
|
||||
// func (s tlsSNISolver) CleanUp(domain, token, keyAuth string) error {
|
||||
// _, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// s.certCache.Lock()
|
||||
// delete(s.certCache.cache, acmeDomain)
|
||||
// s.certCache.Unlock()
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// distributedHTTPSolver allows the HTTP-01 challenge to be solved by
|
||||
// an instance other than the one which initiated it. This is useful
|
||||
// behind load balancers or in other cluster/fleet configurations.
|
||||
// The only requirement is that this (the initiating) instance share
|
||||
// the $CADDYPATH/acme folder with the instance that will complete
|
||||
// the challenge. Mounting the folder locally should be sufficient.
|
||||
//
|
||||
// Obviously, the instance which completes the challenge must be
|
||||
// serving on the HTTPChallengePort to receive and handle the request.
|
||||
// The HTTP server which receives it must check if a file exists, e.g.:
|
||||
// $CADDYPATH/acme/challenge_tokens/example.com.json, and if so,
|
||||
// decode it and use it to serve up the correct response. Caddy's HTTP
|
||||
// server does this by default.
|
||||
//
|
||||
// So as long as the folder is shared, this will just work. There are
|
||||
// no other requirements. The instances may be on other machines or
|
||||
// even other networks, as long as they share the folder as part of
|
||||
// the local file system.
|
||||
//
|
||||
// This solver works by persisting the token and keyauth information
|
||||
// to disk in the shared folder when the authorization is presented,
|
||||
// and then deletes it when it is cleaned up.
|
||||
type distributedHTTPSolver struct {
|
||||
// The distributed HTTPS solver only works if an instance (either
|
||||
// this one or another one) is already listening and serving on the
|
||||
// HTTPChallengePort. If not -- for example: if this is the only
|
||||
// instance, and it is just starting up and hasn't started serving
|
||||
// yet -- then we still need a listener open with an HTTP server
|
||||
// to handle the challenge request. Set this field to have the
|
||||
// standard HTTPProviderServer open its listener for the duration
|
||||
// of the challenge. Make sure to configure its listen address
|
||||
// correctly.
|
||||
httpProviderServer *acme.HTTPProviderServer
|
||||
}
|
||||
|
||||
type challengeInfo struct {
|
||||
Domain, Token, KeyAuth string
|
||||
}
|
||||
|
||||
// Present adds the challenge certificate to the cache.
|
||||
func (s tlsSNISolver) Present(domain, token, keyAuth string) error {
|
||||
cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
|
||||
func (dhs distributedHTTPSolver) Present(domain, token, keyAuth string) error {
|
||||
if dhs.httpProviderServer != nil {
|
||||
err := dhs.httpProviderServer.Present(domain, token, keyAuth)
|
||||
if err != nil {
|
||||
return fmt.Errorf("presenting with standard HTTP provider server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err := os.MkdirAll(dhs.challengeTokensBasePath(), 0755)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
certHash := hashCertificateChain(cert.Certificate)
|
||||
s.certCache.Lock()
|
||||
s.certCache.cache[acmeDomain] = Certificate{
|
||||
Certificate: cert,
|
||||
Names: []string{acmeDomain},
|
||||
Hash: certHash, // perhaps not necesssary
|
||||
|
||||
infoBytes, err := json.Marshal(challengeInfo{
|
||||
Domain: domain,
|
||||
Token: token,
|
||||
KeyAuth: keyAuth,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.certCache.Unlock()
|
||||
return nil
|
||||
|
||||
return ioutil.WriteFile(dhs.challengeTokensPath(domain), infoBytes, 0644)
|
||||
}
|
||||
|
||||
// CleanUp removes the challenge certificate from the cache.
|
||||
func (s tlsSNISolver) CleanUp(domain, token, keyAuth string) error {
|
||||
_, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
func (dhs distributedHTTPSolver) CleanUp(domain, token, keyAuth string) error {
|
||||
if dhs.httpProviderServer != nil {
|
||||
err := dhs.httpProviderServer.CleanUp(domain, token, keyAuth)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Cleaning up standard HTTP provider server: %v", err)
|
||||
}
|
||||
}
|
||||
s.certCache.Lock()
|
||||
delete(s.certCache.cache, acmeDomain)
|
||||
s.certCache.Unlock()
|
||||
return nil
|
||||
return os.Remove(dhs.challengeTokensPath(domain))
|
||||
}
|
||||
|
||||
func (dhs distributedHTTPSolver) challengeTokensPath(domain string) string {
|
||||
domainFile := strings.Replace(strings.ToLower(domain), "*", "wildcard_", -1)
|
||||
return filepath.Join(dhs.challengeTokensBasePath(), domainFile+".json")
|
||||
}
|
||||
|
||||
func (dhs distributedHTTPSolver) challengeTokensBasePath() string {
|
||||
return filepath.Join(caddy.AssetsPath(), "acme", "challenge_tokens")
|
||||
}
|
||||
|
||||
// ConfigHolder is any type that has a Config; it presumably is
|
||||
|
|
|
@ -18,7 +18,7 @@ import (
|
|||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/xenolf/lego/acme"
|
||||
"github.com/xenolf/lego/acmev2"
|
||||
)
|
||||
|
||||
func TestHostQualifies(t *testing.T) {
|
||||
|
@ -37,7 +37,10 @@ func TestHostQualifies(t *testing.T) {
|
|||
{"0.0.0.0", false},
|
||||
{"", false},
|
||||
{" ", false},
|
||||
{"*.example.com", false},
|
||||
{"*.example.com", true},
|
||||
{"*.*.example.com", false},
|
||||
{"sub.*.example.com", false},
|
||||
{"*sub.example.com", false},
|
||||
{".com", false},
|
||||
{"example.com.", false},
|
||||
{"localhost", false},
|
||||
|
@ -77,7 +80,10 @@ func TestQualifiesForManagedTLS(t *testing.T) {
|
|||
{holder{host: "localhost", cfg: new(Config)}, false},
|
||||
{holder{host: "123.44.3.21", cfg: new(Config)}, false},
|
||||
{holder{host: "example.com", cfg: new(Config)}, true},
|
||||
{holder{host: "*.example.com", cfg: new(Config)}, false},
|
||||
{holder{host: "*.example.com", cfg: new(Config)}, true},
|
||||
{holder{host: "*.*.example.com", cfg: new(Config)}, false},
|
||||
{holder{host: "*sub.example.com", cfg: new(Config)}, false},
|
||||
{holder{host: "sub.*.example.com", cfg: new(Config)}, false},
|
||||
{holder{host: "example.com", cfg: &Config{Manual: true}}, false},
|
||||
{holder{host: "example.com", cfg: &Config{ACMEEmail: "off"}}, false},
|
||||
{holder{host: "example.com", cfg: &Config{ACMEEmail: "foo@bar.com"}}, true},
|
||||
|
|
129
caddytls/user.go
129
caddytls/user.go
|
@ -27,7 +27,7 @@ import (
|
|||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/xenolf/lego/acme"
|
||||
"github.com/xenolf/lego/acmev2"
|
||||
)
|
||||
|
||||
// User represents a Let's Encrypt user account.
|
||||
|
@ -67,43 +67,82 @@ func newUser(email string) (User, error) {
|
|||
return user, nil
|
||||
}
|
||||
|
||||
// getEmail does everything it can to obtain an email
|
||||
// address from the user within the scope of storage
|
||||
// to use for ACME TLS. If it cannot get an email
|
||||
// address, it returns empty string. (It will warn the
|
||||
// user of the consequences of an empty email.) This
|
||||
// function MAY prompt the user for input. If userPresent
|
||||
// is false, the operator will NOT be prompted and an
|
||||
// empty email may be returned.
|
||||
func getEmail(storage Storage, userPresent bool) string {
|
||||
// getEmail does everything it can to obtain an email address
|
||||
// from the user within the scope of memory and storage to use
|
||||
// for ACME TLS. If it cannot get an email address, it returns
|
||||
// empty string. (If user is present, it will warn the user of
|
||||
// the consequences of an empty email.) This function MAY prompt
|
||||
// the user for input. If userPresent is false, the operator
|
||||
// will NOT be prompted and an empty email may be returned.
|
||||
// If the user is prompted, a new User will be created and
|
||||
// stored in storage according to the email address they
|
||||
// provided (which might be blank).
|
||||
func getEmail(cfg *Config, userPresent bool) (string, error) {
|
||||
storage, err := cfg.StorageFor(cfg.CAUrl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// First try memory (command line flag or typed by user previously)
|
||||
leEmail := DefaultEmail
|
||||
|
||||
// Then try to get most recent user email from storage
|
||||
if leEmail == "" {
|
||||
// Then try to get most recent user email
|
||||
leEmail = storage.MostRecentUserEmail()
|
||||
// Save for next time
|
||||
DefaultEmail = leEmail
|
||||
DefaultEmail = leEmail // save for next time
|
||||
}
|
||||
|
||||
// Looks like there is no email address readily available,
|
||||
// so we will have to ask the user if we can.
|
||||
if leEmail == "" && userPresent {
|
||||
// Alas, we must bother the user and ask for an email address;
|
||||
// if they proceed they also agree to the SA.
|
||||
reader := bufio.NewReader(stdin)
|
||||
fmt.Println("\nYour sites will be served over HTTPS automatically using Let's Encrypt.")
|
||||
fmt.Println("By continuing, you agree to the Let's Encrypt Subscriber Agreement at:")
|
||||
fmt.Println(" " + saURL) // TODO: Show current SA link
|
||||
fmt.Println("Please enter your email address so you can recover your account if needed.")
|
||||
fmt.Println("You can leave it blank, but you'll lose the ability to recover your account.")
|
||||
fmt.Print("Email address: ")
|
||||
var err error
|
||||
leEmail, err = reader.ReadString('\n')
|
||||
// evidently, no User data was present in storage;
|
||||
// thus we must make a new User so that we can get
|
||||
// the Terms of Service URL via our ACME client, phew!
|
||||
user, err := newUser("")
|
||||
if err != nil {
|
||||
return ""
|
||||
return "", err
|
||||
}
|
||||
|
||||
// get the agreement URL
|
||||
agreementURL := agreementTestURL
|
||||
if agreementURL == "" {
|
||||
// we call acme.NewClient directly because newACMEClient
|
||||
// would require that we already know the user's email
|
||||
caURL := DefaultCAUrl
|
||||
if cfg.CAUrl != "" {
|
||||
caURL = cfg.CAUrl
|
||||
}
|
||||
tempClient, err := acme.NewClient(caURL, user, "")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("making ACME client to get ToS URL: %v", err)
|
||||
}
|
||||
agreementURL = tempClient.GetToSURL()
|
||||
}
|
||||
|
||||
// prompt the user for an email address and terms agreement
|
||||
reader := bufio.NewReader(stdin)
|
||||
promptUserAgreement(agreementURL)
|
||||
fmt.Println("Please enter your email address to signify agreement and to be notified")
|
||||
fmt.Println("in case of issues. You can leave it blank, but we don't recommend it.")
|
||||
fmt.Print(" Email address: ")
|
||||
leEmail, err = reader.ReadString('\n')
|
||||
if err != nil && err != io.EOF {
|
||||
return "", fmt.Errorf("reading email address: %v", err)
|
||||
}
|
||||
leEmail = strings.TrimSpace(leEmail)
|
||||
DefaultEmail = leEmail
|
||||
Agreed = true
|
||||
|
||||
// save the new user to preserve this for next time
|
||||
user.Email = leEmail
|
||||
err = saveUser(storage, user)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return strings.ToLower(leEmail)
|
||||
|
||||
// lower-casing the email is important for consistency
|
||||
return strings.ToLower(leEmail), nil
|
||||
}
|
||||
|
||||
// getUser loads the user with the given email from disk
|
||||
|
@ -154,18 +193,21 @@ func saveUser(storage Storage, user User) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// promptUserAgreement prompts the user to agree to the agreement
|
||||
// at agreementURL via stdin. If the agreement has changed, then pass
|
||||
// true as the second argument. If this is the user's first time
|
||||
// agreeing, pass false. It returns whether the user agreed or not.
|
||||
func promptUserAgreement(agreementURL string, changed bool) bool {
|
||||
if changed {
|
||||
fmt.Printf("The Let's Encrypt Subscriber Agreement has changed:\n %s\n", agreementURL)
|
||||
fmt.Print("Do you agree to the new terms? (y/n): ")
|
||||
} else {
|
||||
fmt.Printf("To continue, you must agree to the Let's Encrypt Subscriber Agreement:\n %s\n", agreementURL)
|
||||
fmt.Print("Do you agree to the terms? (y/n): ")
|
||||
}
|
||||
// promptUserAgreement simply outputs the standard user
|
||||
// agreement prompt with the given agreement URL.
|
||||
// It outputs a newline after the message.
|
||||
func promptUserAgreement(agreementURL string) {
|
||||
const userAgreementPrompt = `Your sites will be served over HTTPS automatically using Let's Encrypt.
|
||||
By continuing, you agree to the Let's Encrypt Subscriber Agreement at:`
|
||||
fmt.Printf("\n\n%s\n %s\n", userAgreementPrompt, agreementURL)
|
||||
}
|
||||
|
||||
// askUserAgreement prompts the user to agree to the agreement
|
||||
// at the given agreement URL via stdin. It returns whether the
|
||||
// user agreed or not.
|
||||
func askUserAgreement(agreementURL string) bool {
|
||||
promptUserAgreement(agreementURL)
|
||||
fmt.Print("Do you agree to the terms? (y/n): ")
|
||||
|
||||
reader := bufio.NewReader(stdin)
|
||||
answer, err := reader.ReadString('\n')
|
||||
|
@ -177,14 +219,15 @@ func promptUserAgreement(agreementURL string, changed bool) bool {
|
|||
return answer == "y" || answer == "yes"
|
||||
}
|
||||
|
||||
// agreementTestURL is set during tests to skip requiring
|
||||
// setting up an entire ACME CA endpoint.
|
||||
var agreementTestURL string
|
||||
|
||||
// stdin is used to read the user's input if prompted;
|
||||
// this is changed by tests during tests.
|
||||
var stdin = io.ReadWriter(os.Stdin)
|
||||
|
||||
// The name of the folder for accounts where the email
|
||||
// address was not provided; default 'username' if you will.
|
||||
// address was not provided; default 'username' if you will,
|
||||
// but only for local/storage use, not with the CA.
|
||||
const emptyEmail = "default"
|
||||
|
||||
// TODO: After Boulder implements the 'meta' field of the directory,
|
||||
// we can get this link dynamically.
|
||||
const saURL = "https://acme-v01.api.letsencrypt.org/terms"
|
||||
|
|
|
@ -20,13 +20,14 @@ import (
|
|||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"os"
|
||||
|
||||
"github.com/xenolf/lego/acme"
|
||||
"github.com/xenolf/lego/acmev2"
|
||||
)
|
||||
|
||||
func TestUser(t *testing.T) {
|
||||
|
@ -135,7 +136,13 @@ func TestGetUserAlreadyExists(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGetEmail(t *testing.T) {
|
||||
storageBasePath = testStorage.Path // to contain calls that create a new Storage...
|
||||
// ensure storage (via StorageFor) uses the local testdata folder that we delete later
|
||||
origCaddypath := os.Getenv("CADDYPATH")
|
||||
os.Setenv("CADDYPATH", "./testdata")
|
||||
defer os.Setenv("CADDYPATH", origCaddypath)
|
||||
|
||||
agreementTestURL = "(none - testing)"
|
||||
defer func() { agreementTestURL = "" }()
|
||||
|
||||
// let's not clutter up the output
|
||||
origStdout := os.Stdout
|
||||
|
@ -146,7 +153,10 @@ func TestGetEmail(t *testing.T) {
|
|||
DefaultEmail = "test2@foo.com"
|
||||
|
||||
// Test1: Use default email from flag (or user previously typing it)
|
||||
actual := getEmail(testStorage, true)
|
||||
actual, err := getEmail(testConfig, true)
|
||||
if err != nil {
|
||||
t.Fatalf("getEmail (1) error: %v", err)
|
||||
}
|
||||
if actual != DefaultEmail {
|
||||
t.Errorf("Did not get correct email from memory; expected '%s' but got '%s'", DefaultEmail, actual)
|
||||
}
|
||||
|
@ -154,16 +164,19 @@ func TestGetEmail(t *testing.T) {
|
|||
// Test2: Get input from user
|
||||
DefaultEmail = ""
|
||||
stdin = new(bytes.Buffer)
|
||||
_, err := io.Copy(stdin, strings.NewReader("test3@foo.com\n"))
|
||||
_, err = io.Copy(stdin, strings.NewReader("test3@foo.com\n"))
|
||||
if err != nil {
|
||||
t.Fatalf("Could not simulate user input, error: %v", err)
|
||||
}
|
||||
actual = getEmail(testStorage, true)
|
||||
actual, err = getEmail(testConfig, true)
|
||||
if err != nil {
|
||||
t.Fatalf("getEmail (2) error: %v", err)
|
||||
}
|
||||
if actual != "test3@foo.com" {
|
||||
t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual)
|
||||
}
|
||||
|
||||
// Test3: Get most recent email from before
|
||||
// Test3: Get most recent email from before (in storage)
|
||||
DefaultEmail = ""
|
||||
for i, eml := range []string{
|
||||
"TEST4-3@foo.com", // test case insensitivity
|
||||
|
@ -189,14 +202,20 @@ func TestGetEmail(t *testing.T) {
|
|||
t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err)
|
||||
}
|
||||
}
|
||||
actual = getEmail(testStorage, true)
|
||||
actual, err = getEmail(testConfig, true)
|
||||
if err != nil {
|
||||
t.Fatalf("getEmail (3) error: %v", err)
|
||||
}
|
||||
if actual != "test4-3@foo.com" {
|
||||
t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual)
|
||||
}
|
||||
}
|
||||
|
||||
var testStorage = &FileStorage{Path: "./testdata"}
|
||||
var (
|
||||
testStorageBase = "./testdata" // ephemeral folder that gets deleted after tests finish
|
||||
testCAHost = "localhost"
|
||||
testConfig = &Config{CAUrl: "http://" + testCAHost + "/directory", StorageProvider: "file"}
|
||||
testStorage = &FileStorage{Path: filepath.Join(testStorageBase, "acme", testCAHost)}
|
||||
)
|
||||
|
||||
func (s *FileStorage) clean() error {
|
||||
return os.RemoveAll(s.Path)
|
||||
}
|
||||
func (s *FileStorage) clean() error { return os.RemoveAll(testStorageBase) }
|
||||
|
|
59
dist/CHANGES.txt
vendored
59
dist/CHANGES.txt
vendored
|
@ -1,5 +1,64 @@
|
|||
CHANGES
|
||||
|
||||
0.10.14 (April 19, 2018)
|
||||
- tls: Fix error handling bug when obtaining certificates
|
||||
|
||||
|
||||
0.10.13 (April 18, 2018)
|
||||
- New third-party plugin: supervisor
|
||||
- Updated QUIC
|
||||
- proxy: Fix transparent pass-thru of X-Forwarded-For
|
||||
- proxy: Configurable timeout to upstream
|
||||
- rewrite: Now supports regular expressions on single-line
|
||||
- tls: StrictHostMatching mode to prevent client auth bypass
|
||||
- tls: Disable client auth when using QUIC
|
||||
- tls: Require same client auth cert pools per hostname
|
||||
- tls: Prevent On-Demand TLS directory traversal
|
||||
- tls: Fix empty files when using ACME fails to obtain cert
|
||||
- Fixed test broken by 1.1.1.1 resolving
|
||||
- Improved Caddyfile parser robustness by fuzzing
|
||||
|
||||
|
||||
0.10.12 (March 27, 2018)
|
||||
- Switch to Let's Encrypt ACMEv2 production endpoint
|
||||
- Support for automated wildcard certificates
|
||||
- Support distributed solving of HTTP-01 challenge
|
||||
- New {labelN}, {tls_cipher}, and {tls_version} placeholders
|
||||
- Curly braces can now be escaped when not used as placeholders
|
||||
- New third-party plugin: geoip
|
||||
- Updated QUIC
|
||||
- fastcgi: Add SSL_CIPHER and SSL_PROTOCOL environment variables
|
||||
- log: New 'except' subdirective to exempt paths from logging
|
||||
- startup/shutdown: Removed in favor of 'on'
|
||||
- tls: Default minimum version is TLS 1.2
|
||||
- tls: Revert to fallback cert if no cert matches SNI
|
||||
- tls: New 'wildcard' subdirective to force automated wildcard cert
|
||||
- Several significant bug fixes and improvements!
|
||||
|
||||
|
||||
0.10.11 (February 20, 2018)
|
||||
- Built with Go 1.10
|
||||
- Reusable snippets for the Caddyfile
|
||||
- Updated QUIC
|
||||
- Auto-HTTPS certificates may be shared by multiple instances
|
||||
- Expand globbed values in -conf flag
|
||||
- Swap behavior of SIGTERM and SIGQUIT; ignore SIGHUP
|
||||
- 9 new DNS provider plugins for the ACME DNS challenge
|
||||
- New placeholder for {<Response-Header} values
|
||||
- basicauth: Username put in {user} placeholder
|
||||
- fastcgi: GET requests can now send a body
|
||||
- proxy: Service discovery with DNS SRV load balancing
|
||||
- request_id: Allow reusing request ID from header field
|
||||
- tls: Improved efficiency of many certificates and reloads
|
||||
- tls: Raise error if conflicting TLS configurations collide
|
||||
- tls: Raise TLS alert if SNI used and no cert matched
|
||||
- tls: Reject OCSP responses that expire after the certificate
|
||||
- tls: Clients can use SNI to request a specific certificate
|
||||
- tls: Add option for backend to approve on-demand certificate
|
||||
- tls: Synchronize maintenance of shared, managed certificates
|
||||
- Numerous fabulous bug fixes
|
||||
|
||||
|
||||
0.10.10 (October 9, 2017)
|
||||
- Built with Go 1.9.1
|
||||
- Removed Caddy-Sponsors header
|
||||
|
|
6
dist/README.txt
vendored
6
dist/README.txt
vendored
|
@ -1,4 +1,4 @@
|
|||
CADDY 0.10.10
|
||||
CADDY 0.10.14
|
||||
|
||||
Website
|
||||
https://caddyserver.com
|
||||
|
@ -32,9 +32,9 @@ the project wiki: https://github.com/mholt/caddy/wiki
|
|||
And thanks - you're awesome!
|
||||
|
||||
If you think Caddy is awesome too, consider sponsoring it:
|
||||
https://caddyserver.com/pricing - and help keep Caddy free
|
||||
https://caddyserver.com/sponsor - and help keep Caddy free
|
||||
for personal use.
|
||||
|
||||
|
||||
---
|
||||
(c) 2015-2017 Light Code Labs, LLC
|
||||
(c) 2015-2018 Light Code Labs, LLC
|
||||
|
|
32
plugins.go
32
plugins.go
|
@ -39,7 +39,7 @@ var (
|
|||
|
||||
// eventHooks is a map of hook name to Hook. All hooks plugins
|
||||
// must have a name.
|
||||
eventHooks = sync.Map{}
|
||||
eventHooks = &sync.Map{}
|
||||
|
||||
// parsingCallbacks maps server type to map of directive
|
||||
// to list of callback functions. These aren't really
|
||||
|
@ -296,6 +296,36 @@ func EmitEvent(event EventName, info interface{}) {
|
|||
})
|
||||
}
|
||||
|
||||
// cloneEventHooks return a clone of the event hooks *sync.Map
|
||||
func cloneEventHooks() *sync.Map {
|
||||
c := &sync.Map{}
|
||||
eventHooks.Range(func(k, v interface{}) bool {
|
||||
c.Store(k, v)
|
||||
return true
|
||||
})
|
||||
return c
|
||||
}
|
||||
|
||||
// purgeEventHooks purges all event hooks from the map
|
||||
func purgeEventHooks() {
|
||||
eventHooks.Range(func(k, _ interface{}) bool {
|
||||
eventHooks.Delete(k)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// restoreEventHooks restores eventHooks with a provided *sync.Map
|
||||
func restoreEventHooks(m *sync.Map) {
|
||||
// Purge old event hooks
|
||||
purgeEventHooks()
|
||||
|
||||
// Restore event hooks
|
||||
m.Range(func(k, v interface{}) bool {
|
||||
eventHooks.Store(k, v)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// ParsingCallback is a function that is called after
|
||||
// a directive's setup functions have been executed
|
||||
// for all the server blocks.
|
||||
|
|
|
@ -83,9 +83,17 @@ func trapSignalsPosix() {
|
|||
caddyfileToUse = newCaddyfile
|
||||
}
|
||||
|
||||
// Backup old event hooks
|
||||
oldEventHooks := cloneEventHooks()
|
||||
|
||||
// Purge the old event hooks
|
||||
purgeEventHooks()
|
||||
|
||||
// Kick off the restart; our work is done
|
||||
_, err = inst.Restart(caddyfileToUse)
|
||||
if err != nil {
|
||||
restoreEventHooks(oldEventHooks)
|
||||
|
||||
log.Printf("[ERROR] SIGUSR1: %v", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,98 +0,0 @@
|
|||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package startupshutdown
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/onevent/hook"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin("startup", caddy.Plugin{Action: Startup})
|
||||
caddy.RegisterPlugin("shutdown", caddy.Plugin{Action: Shutdown})
|
||||
}
|
||||
|
||||
// Startup (an alias for 'on startup') registers a startup callback to execute during server start.
|
||||
func Startup(c *caddy.Controller) error {
|
||||
config, err := onParse(c, caddy.InstanceStartupEvent)
|
||||
if err != nil {
|
||||
return c.ArgErr()
|
||||
}
|
||||
|
||||
// Register Event Hooks.
|
||||
c.OncePerServerBlock(func() error {
|
||||
for _, cfg := range config {
|
||||
caddy.RegisterEventHook("on-"+cfg.ID, cfg.Hook)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
fmt.Println("NOTICE: Startup directive will be removed in a later version. Please migrate to 'on startup'")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown (an alias for 'on shutdown') registers a shutdown callback to execute during server start.
|
||||
func Shutdown(c *caddy.Controller) error {
|
||||
config, err := onParse(c, caddy.ShutdownEvent)
|
||||
if err != nil {
|
||||
return c.ArgErr()
|
||||
}
|
||||
|
||||
// Register Event Hooks.
|
||||
for _, cfg := range config {
|
||||
caddy.RegisterEventHook("on-"+cfg.ID, cfg.Hook)
|
||||
}
|
||||
|
||||
fmt.Println("NOTICE: Shutdown directive will be removed in a later version. Please migrate to 'on shutdown'")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func onParse(c *caddy.Controller, event caddy.EventName) ([]*hook.Config, error) {
|
||||
var config []*hook.Config
|
||||
|
||||
for c.Next() {
|
||||
cfg := new(hook.Config)
|
||||
|
||||
args := c.RemainingArgs()
|
||||
if len(args) == 0 {
|
||||
return config, c.ArgErr()
|
||||
}
|
||||
|
||||
// Configure Event.
|
||||
cfg.Event = event
|
||||
|
||||
// Assign an unique ID.
|
||||
cfg.ID = uuid.New().String()
|
||||
|
||||
// Extract command and arguments.
|
||||
command, args, err := caddy.SplitCommandAndArgs(strings.Join(args, " "))
|
||||
if err != nil {
|
||||
return config, c.Err(err.Error())
|
||||
}
|
||||
|
||||
cfg.Command = command
|
||||
cfg.Args = args
|
||||
|
||||
config = append(config, cfg)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
|
@ -1,69 +0,0 @@
|
|||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package startupshutdown
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func TestStartup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
shouldErr bool
|
||||
}{
|
||||
{name: "noInput", input: "startup", shouldErr: true},
|
||||
{name: "startup", input: "startup cmd arg", shouldErr: false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
c := caddy.NewTestController("", test.input)
|
||||
|
||||
err := Startup(c)
|
||||
if err == nil && test.shouldErr {
|
||||
t.Error("Test didn't error, but it should have")
|
||||
} else if err != nil && !test.shouldErr {
|
||||
t.Errorf("Test errored, but it shouldn't have; got '%v'", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdown(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
shouldErr bool
|
||||
}{
|
||||
{name: "noInput", input: "shutdown", shouldErr: true},
|
||||
{name: "shutdown", input: "shutdown cmd arg", shouldErr: false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
c := caddy.NewTestController("", test.input)
|
||||
|
||||
err := Shutdown(c)
|
||||
if err == nil && test.shouldErr {
|
||||
t.Error("Test didn't error, but it should have")
|
||||
} else if err != nil && !test.shouldErr {
|
||||
t.Errorf("Test errored, but it shouldn't have; got '%v'", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
2
vendor/github.com/lucas-clemente/aes12/cipher_generic.go
generated
vendored
2
vendor/github.com/lucas-clemente/aes12/cipher_generic.go
generated
vendored
|
@ -2,7 +2,7 @@
|
|||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !amd64,!s390x
|
||||
// +build !amd64
|
||||
|
||||
package aes12
|
||||
|
||||
|
|
13
vendor/github.com/lucas-clemente/quic-go/buffer_pool.go
generated
vendored
13
vendor/github.com/lucas-clemente/quic-go/buffer_pool.go
generated
vendored
|
@ -8,19 +8,20 @@ import (
|
|||
|
||||
var bufferPool sync.Pool
|
||||
|
||||
func getPacketBuffer() []byte {
|
||||
return bufferPool.Get().([]byte)
|
||||
func getPacketBuffer() *[]byte {
|
||||
return bufferPool.Get().(*[]byte)
|
||||
}
|
||||
|
||||
func putPacketBuffer(buf []byte) {
|
||||
if cap(buf) != int(protocol.MaxReceivePacketSize) {
|
||||
func putPacketBuffer(buf *[]byte) {
|
||||
if cap(*buf) != int(protocol.MaxReceivePacketSize) {
|
||||
panic("putPacketBuffer called with packet of wrong size!")
|
||||
}
|
||||
bufferPool.Put(buf[:0])
|
||||
bufferPool.Put(buf)
|
||||
}
|
||||
|
||||
func init() {
|
||||
bufferPool.New = func() interface{} {
|
||||
return make([]byte, 0, protocol.MaxReceivePacketSize)
|
||||
b := make([]byte, 0, protocol.MaxReceivePacketSize)
|
||||
return &b
|
||||
}
|
||||
}
|
||||
|
|
56
vendor/github.com/lucas-clemente/quic-go/client.go
generated
vendored
56
vendor/github.com/lucas-clemente/quic-go/client.go
generated
vendored
|
@ -38,6 +38,8 @@ type client struct {
|
|||
version protocol.VersionNumber
|
||||
|
||||
session packetHandler
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -85,6 +87,14 @@ func Dial(
|
|||
}
|
||||
}
|
||||
|
||||
// check that all versions are actually supported
|
||||
if config != nil {
|
||||
for _, v := range config.Versions {
|
||||
if !protocol.IsValidVersion(v) {
|
||||
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
clientConfig := populateClientConfig(config)
|
||||
c := &client{
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
|
@ -94,9 +104,10 @@ func Dial(
|
|||
config: clientConfig,
|
||||
version: clientConfig.Versions[0],
|
||||
versionNegotiationChan: make(chan struct{}),
|
||||
logger: utils.DefaultLogger,
|
||||
}
|
||||
|
||||
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||
|
||||
if err := c.dial(); err != nil {
|
||||
return nil, err
|
||||
|
@ -132,6 +143,18 @@ func populateClientConfig(config *Config) *Config {
|
|||
if maxReceiveConnectionFlowControlWindow == 0 {
|
||||
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
|
||||
}
|
||||
maxIncomingStreams := config.MaxIncomingStreams
|
||||
if maxIncomingStreams == 0 {
|
||||
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
|
||||
} else if maxIncomingStreams < 0 {
|
||||
maxIncomingStreams = 0
|
||||
}
|
||||
maxIncomingUniStreams := config.MaxIncomingUniStreams
|
||||
if maxIncomingUniStreams == 0 {
|
||||
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
|
||||
} else if maxIncomingUniStreams < 0 {
|
||||
maxIncomingUniStreams = 0
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Versions: versions,
|
||||
|
@ -140,7 +163,9 @@ func populateClientConfig(config *Config) *Config {
|
|||
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
|
||||
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
||||
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
||||
KeepAlive: config.KeepAlive,
|
||||
MaxIncomingStreams: maxIncomingStreams,
|
||||
MaxIncomingUniStreams: maxIncomingUniStreams,
|
||||
KeepAlive: config.KeepAlive,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -171,12 +196,11 @@ func (c *client) dialTLS() error {
|
|||
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||
IdleTimeout: c.config.IdleTimeout,
|
||||
OmitConnectionID: c.config.RequestConnectionIDOmission,
|
||||
// TODO(#523): make these values configurable
|
||||
MaxBidiStreamID: protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
|
||||
MaxUniStreamID: protocol.MaxUniStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
|
||||
MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
|
||||
MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
|
||||
}
|
||||
csc := handshake.NewCryptoStreamConn(nil)
|
||||
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version)
|
||||
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
|
||||
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -193,7 +217,7 @@ func (c *client) dialTLS() error {
|
|||
if err != handshake.ErrCloseSessionForRetry {
|
||||
return err
|
||||
}
|
||||
utils.Infof("Received a Retry packet. Recreating session.")
|
||||
c.logger.Infof("Received a Retry packet. Recreating session.")
|
||||
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -216,7 +240,7 @@ func (c *client) establishSecureConnection() error {
|
|||
go func() {
|
||||
runErr = c.session.run() // returns as soon as the session is closed
|
||||
close(errorChan)
|
||||
utils.Infof("Connection %x closed.", c.connectionID)
|
||||
c.logger.Infof("Connection %x closed.", c.connectionID)
|
||||
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
|
||||
c.conn.Close()
|
||||
}
|
||||
|
@ -245,7 +269,7 @@ func (c *client) listen() {
|
|||
for {
|
||||
var n int
|
||||
var addr net.Addr
|
||||
data := getPacketBuffer()
|
||||
data := *getPacketBuffer()
|
||||
data = data[:protocol.MaxReceivePacketSize]
|
||||
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
|
||||
// If it does, we only read a truncated packet, which will then end up undecryptable
|
||||
|
@ -270,7 +294,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||
r := bytes.NewReader(packet)
|
||||
hdr, err := wire.ParseHeaderSentByServer(r, c.version)
|
||||
if err != nil {
|
||||
utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
|
||||
c.logger.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
|
||||
// drop this packet if we can't parse the header
|
||||
return
|
||||
}
|
||||
|
@ -293,15 +317,15 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||
// check if the remote address and the connection ID match
|
||||
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
|
||||
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID {
|
||||
utils.Infof("Received a spoofed Public Reset. Ignoring.")
|
||||
c.logger.Infof("Received a spoofed Public Reset. Ignoring.")
|
||||
return
|
||||
}
|
||||
pr, err := wire.ParsePublicReset(r)
|
||||
if err != nil {
|
||||
utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
||||
c.logger.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
||||
return
|
||||
}
|
||||
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
|
||||
c.logger.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
|
||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
||||
return
|
||||
}
|
||||
|
@ -347,6 +371,8 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
|||
}
|
||||
}
|
||||
|
||||
c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
|
||||
|
||||
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
||||
if !ok {
|
||||
return qerr.InvalidVersion
|
||||
|
@ -362,7 +388,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
|
||||
c.logger.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
|
||||
c.session.Close(errCloseSessionForNewVersion)
|
||||
return nil
|
||||
}
|
||||
|
@ -379,6 +405,7 @@ func (c *client) createNewGQUICSession() (err error) {
|
|||
c.config,
|
||||
c.initialVersion,
|
||||
c.negotiatedVersions,
|
||||
c.logger,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
@ -398,6 +425,7 @@ func (c *client) createNewTLSSession(
|
|||
c.tls,
|
||||
paramsChan,
|
||||
1,
|
||||
c.logger,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
|
16
vendor/github.com/lucas-clemente/quic-go/example/client/main.go
generated
vendored
16
vendor/github.com/lucas-clemente/quic-go/example/client/main.go
generated
vendored
|
@ -19,12 +19,14 @@ func main() {
|
|||
flag.Parse()
|
||||
urls := flag.Args()
|
||||
|
||||
logger := utils.DefaultLogger
|
||||
|
||||
if *verbose {
|
||||
utils.SetLogLevel(utils.LogLevelDebug)
|
||||
logger.SetLogLevel(utils.LogLevelDebug)
|
||||
} else {
|
||||
utils.SetLogLevel(utils.LogLevelInfo)
|
||||
logger.SetLogLevel(utils.LogLevelInfo)
|
||||
}
|
||||
utils.SetLogTimeFormat("")
|
||||
logger.SetLogTimeFormat("")
|
||||
|
||||
versions := protocol.SupportedVersions
|
||||
if *tls {
|
||||
|
@ -42,21 +44,21 @@ func main() {
|
|||
var wg sync.WaitGroup
|
||||
wg.Add(len(urls))
|
||||
for _, addr := range urls {
|
||||
utils.Infof("GET %s", addr)
|
||||
logger.Infof("GET %s", addr)
|
||||
go func(addr string) {
|
||||
rsp, err := hclient.Get(addr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
utils.Infof("Got response for %s: %#v", addr, rsp)
|
||||
logger.Infof("Got response for %s: %#v", addr, rsp)
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
_, err = io.Copy(body, rsp.Body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
utils.Infof("Request Body:")
|
||||
utils.Infof("%s", body.Bytes())
|
||||
logger.Infof("Request Body:")
|
||||
logger.Infof("%s", body.Bytes())
|
||||
wg.Done()
|
||||
}(addr)
|
||||
}
|
||||
|
|
10
vendor/github.com/lucas-clemente/quic-go/example/main.go
generated
vendored
10
vendor/github.com/lucas-clemente/quic-go/example/main.go
generated
vendored
|
@ -91,7 +91,7 @@ func init() {
|
|||
}
|
||||
}
|
||||
if err != nil {
|
||||
utils.Infof("Error receiving upload: %#v", err)
|
||||
utils.DefaultLogger.Infof("Error receiving upload: %#v", err)
|
||||
}
|
||||
}
|
||||
io.WriteString(w, `<html><body><form action="/demo/upload" method="post" enctype="multipart/form-data">
|
||||
|
@ -126,12 +126,14 @@ func main() {
|
|||
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
|
||||
flag.Parse()
|
||||
|
||||
logger := utils.DefaultLogger
|
||||
|
||||
if *verbose {
|
||||
utils.SetLogLevel(utils.LogLevelDebug)
|
||||
logger.SetLogLevel(utils.LogLevelDebug)
|
||||
} else {
|
||||
utils.SetLogLevel(utils.LogLevelInfo)
|
||||
logger.SetLogLevel(utils.LogLevelInfo)
|
||||
}
|
||||
utils.SetLogTimeFormat("")
|
||||
logger.SetLogTimeFormat("")
|
||||
|
||||
versions := protocol.SupportedVersions
|
||||
if *tls {
|
||||
|
|
20
vendor/github.com/lucas-clemente/quic-go/h2quic/client.go
generated
vendored
20
vendor/github.com/lucas-clemente/quic-go/h2quic/client.go
generated
vendored
|
@ -46,6 +46,8 @@ type client struct {
|
|||
requestWriter *requestWriter
|
||||
|
||||
responses map[protocol.StreamID]chan *http.Response
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ http.RoundTripper = &client{}
|
||||
|
@ -75,6 +77,7 @@ func newClient(
|
|||
opts: opts,
|
||||
headerErrored: make(chan struct{}),
|
||||
dialer: dialer,
|
||||
logger: utils.DefaultLogger,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -95,7 +98,7 @@ func (c *client) dial() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.requestWriter = newRequestWriter(c.headerStream)
|
||||
c.requestWriter = newRequestWriter(c.headerStream, c.logger)
|
||||
go c.handleHeaderStream()
|
||||
return nil
|
||||
}
|
||||
|
@ -108,7 +111,9 @@ func (c *client) handleHeaderStream() {
|
|||
for err == nil {
|
||||
err = c.readResponse(h2framer, decoder)
|
||||
}
|
||||
utils.Debugf("Error handling header stream: %s", err)
|
||||
if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway {
|
||||
c.logger.Debugf("Error handling header stream: %s", err)
|
||||
}
|
||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
|
||||
// stop all running request
|
||||
close(c.headerErrored)
|
||||
|
@ -202,6 +207,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
bodySent = true
|
||||
}
|
||||
|
||||
ctx := req.Context()
|
||||
for !(bodySent && receivedResponse) {
|
||||
select {
|
||||
case res = <-responseChan:
|
||||
|
@ -214,8 +220,16 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
// error code 6 signals that stream was canceled
|
||||
dataStream.CancelRead(6)
|
||||
dataStream.CancelWrite(6)
|
||||
c.mutex.Lock()
|
||||
delete(c.responses, dataStream.StreamID())
|
||||
c.mutex.Unlock()
|
||||
return nil, ctx.Err()
|
||||
case <-c.headerErrored:
|
||||
// an error occured on the header stream
|
||||
// an error occurred on the header stream
|
||||
_ = c.CloseWithError(c.headerErr)
|
||||
return nil, c.headerErr
|
||||
}
|
||||
|
|
10
vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go
generated
vendored
10
vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go
generated
vendored
|
@ -23,13 +23,16 @@ type requestWriter struct {
|
|||
|
||||
henc *hpack.Encoder
|
||||
hbuf bytes.Buffer // HPACK encoder writes into this
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
const defaultUserAgent = "quic-go"
|
||||
|
||||
func newRequestWriter(headerStream quic.Stream) *requestWriter {
|
||||
func newRequestWriter(headerStream quic.Stream, logger utils.Logger) *requestWriter {
|
||||
rw := &requestWriter{
|
||||
headerStream: headerStream,
|
||||
logger: logger,
|
||||
}
|
||||
rw.henc = hpack.NewEncoder(&rw.hbuf)
|
||||
return rw
|
||||
|
@ -76,9 +79,8 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra
|
|||
if !validPseudoPath(path) {
|
||||
if req.URL.Opaque != "" {
|
||||
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid request :path %q", orig)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid request :path %q", orig)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -157,7 +159,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra
|
|||
}
|
||||
|
||||
func (w *requestWriter) writeHeader(name, value string) {
|
||||
utils.Debugf("http2: Transport encoding header %q = %q", name, value)
|
||||
w.logger.Debugf("http2: Transport encoding header %q = %q", name, value)
|
||||
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
|
||||
}
|
||||
|
||||
|
|
20
vendor/github.com/lucas-clemente/quic-go/h2quic/response.go
generated
vendored
20
vendor/github.com/lucas-clemente/quic-go/h2quic/response.go
generated
vendored
|
@ -3,7 +3,6 @@ package h2quic
|
|||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
|
@ -16,7 +15,7 @@ import (
|
|||
// copied from net/http2/transport.go
|
||||
|
||||
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
|
||||
var noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil))
|
||||
var noBody = ioutil.NopCloser(bytes.NewReader(nil))
|
||||
|
||||
// from the handleResponse function
|
||||
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
|
||||
|
@ -33,16 +32,7 @@ func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
|
|||
return nil, errors.New("malformed non-numeric status pseudo header")
|
||||
}
|
||||
|
||||
if statusCode == 100 {
|
||||
// TODO: handle this
|
||||
|
||||
// traceGot100Continue(cs.trace)
|
||||
// if cs.on100 != nil {
|
||||
// cs.on100() // forces any write delay timer to fire
|
||||
// }
|
||||
// cs.pastHeaders = false // do it all again
|
||||
// return nil, nil
|
||||
}
|
||||
// TODO: handle statusCode == 100
|
||||
|
||||
header := make(http.Header)
|
||||
res := &http.Response{
|
||||
|
@ -78,13 +68,7 @@ func setLength(res *http.Response, isHead, streamEnded bool) *http.Response {
|
|||
if clens := res.Header["Content-Length"]; len(clens) == 1 {
|
||||
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
|
||||
res.ContentLength = clen64
|
||||
} else {
|
||||
// TODO: care? unlike http/1, it won't mess up our framing, so it's
|
||||
// more safe smuggling-wise to ignore.
|
||||
}
|
||||
} else if len(clens) > 1 {
|
||||
// TODO: care? unlike http/1, it won't mess up our framing, so it's
|
||||
// more safe smuggling-wise to ignore.
|
||||
}
|
||||
}
|
||||
return res
|
||||
|
|
15
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go
generated
vendored
15
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go
generated
vendored
|
@ -24,15 +24,24 @@ type responseWriter struct {
|
|||
header http.Header
|
||||
status int // status code passed to WriteHeader
|
||||
headerWritten bool
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
func newResponseWriter(headerStream quic.Stream, headerStreamMutex *sync.Mutex, dataStream quic.Stream, dataStreamID protocol.StreamID) *responseWriter {
|
||||
func newResponseWriter(
|
||||
headerStream quic.Stream,
|
||||
headerStreamMutex *sync.Mutex,
|
||||
dataStream quic.Stream,
|
||||
dataStreamID protocol.StreamID,
|
||||
logger utils.Logger,
|
||||
) *responseWriter {
|
||||
return &responseWriter{
|
||||
header: http.Header{},
|
||||
headerStream: headerStream,
|
||||
headerStreamMutex: headerStreamMutex,
|
||||
dataStream: dataStream,
|
||||
dataStreamID: dataStreamID,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,7 +66,7 @@ func (w *responseWriter) WriteHeader(status int) {
|
|||
}
|
||||
}
|
||||
|
||||
utils.Infof("Responding with %d", status)
|
||||
w.logger.Infof("Responding with %d", status)
|
||||
w.headerStreamMutex.Lock()
|
||||
defer w.headerStreamMutex.Unlock()
|
||||
h2framer := http2.NewFramer(w.headerStream, nil)
|
||||
|
@ -67,7 +76,7 @@ func (w *responseWriter) WriteHeader(status int) {
|
|||
BlockFragment: headers.Bytes(),
|
||||
})
|
||||
if err != nil {
|
||||
utils.Errorf("could not write h2 header: %s", err.Error())
|
||||
w.logger.Errorf("could not write h2 header: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
17
vendor/github.com/lucas-clemente/quic-go/h2quic/server.go
generated
vendored
17
vendor/github.com/lucas-clemente/quic-go/h2quic/server.go
generated
vendored
|
@ -53,6 +53,8 @@ type Server struct {
|
|||
closed bool
|
||||
|
||||
supportedVersionsAsString string
|
||||
|
||||
logger utils.Logger // will be set by Server.serveImpl()
|
||||
}
|
||||
|
||||
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
||||
|
@ -88,6 +90,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
|||
if s.Server == nil {
|
||||
return errors.New("use of h2quic.Server without http.Server")
|
||||
}
|
||||
s.logger = utils.DefaultLogger
|
||||
s.listenerMutex.Lock()
|
||||
if s.closed {
|
||||
s.listenerMutex.Unlock()
|
||||
|
@ -138,7 +141,7 @@ func (s *Server) handleHeaderStream(session streamCreator) {
|
|||
// In this case, the session has already logged the error, so we don't
|
||||
// need to log it again.
|
||||
if _, ok := err.(*qerr.QuicError); !ok {
|
||||
utils.Errorf("error handling h2 request: %s", err.Error())
|
||||
s.logger.Errorf("error handling h2 request: %s", err.Error())
|
||||
}
|
||||
session.Close(err)
|
||||
return
|
||||
|
@ -160,7 +163,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||
}
|
||||
headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment())
|
||||
if err != nil {
|
||||
utils.Errorf("invalid http2 headers encoding: %s", err.Error())
|
||||
s.logger.Errorf("invalid http2 headers encoding: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -169,10 +172,10 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||
return err
|
||||
}
|
||||
|
||||
if utils.Debug() {
|
||||
utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
|
||||
if s.logger.Debug() {
|
||||
s.logger.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
|
||||
} else {
|
||||
utils.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
|
||||
s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
|
||||
}
|
||||
|
||||
dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID))
|
||||
|
@ -201,7 +204,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||
|
||||
req.RemoteAddr = session.RemoteAddr().String()
|
||||
|
||||
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
|
||||
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID), s.logger)
|
||||
|
||||
handler := s.Handler
|
||||
if handler == nil {
|
||||
|
@ -215,7 +218,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||
const size = 64 << 10
|
||||
buf := make([]byte, size)
|
||||
buf = buf[:runtime.Stack(buf, false)]
|
||||
utils.Errorf("http: panic serving: %v\n%s", p, buf)
|
||||
s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
|
||||
panicked = true
|
||||
}
|
||||
}()
|
||||
|
|
38
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go
generated
vendored
38
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go
generated
vendored
|
@ -7,6 +7,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// Connection is a UDP connection
|
||||
|
@ -43,6 +44,8 @@ func (d Direction) String() string {
|
|||
}
|
||||
}
|
||||
|
||||
// Is says if one direction matches another direction.
|
||||
// For example, incoming matches both incoming and both, but not outgoing.
|
||||
func (d Direction) Is(dir Direction) bool {
|
||||
if d == DirectionBoth || dir == DirectionBoth {
|
||||
return true
|
||||
|
@ -92,6 +95,8 @@ type QuicProxy struct {
|
|||
|
||||
// Mapping from client addresses (as host:port) to connection
|
||||
clientDict map[string]*connection
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
// NewQuicProxy creates a new UDP proxy
|
||||
|
@ -129,14 +134,23 @@ func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*Qu
|
|||
dropPacket: packetDropper,
|
||||
delayPacket: packetDelayer,
|
||||
version: version,
|
||||
logger: utils.DefaultLogger,
|
||||
}
|
||||
|
||||
p.logger.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr)
|
||||
go p.runProxy()
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// Close stops the UDP Proxy
|
||||
func (p *QuicProxy) Close() error {
|
||||
p.mutex.Lock()
|
||||
defer p.mutex.Unlock()
|
||||
for _, c := range p.clientDict {
|
||||
if err := c.ServerConn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return p.conn.Close()
|
||||
}
|
||||
|
||||
|
@ -189,19 +203,27 @@ func (p *QuicProxy) runProxy() error {
|
|||
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
|
||||
|
||||
if p.dropPacket(DirectionIncoming, packetCount) {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("dropping incoming packet %d (%d bytes)", packetCount, n)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Send the packet to the server
|
||||
delay := p.delayPacket(DirectionIncoming, packetCount)
|
||||
if delay != 0 {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("delaying incoming packet %d (%d bytes) to %s by %s", packetCount, n, conn.ServerConn.RemoteAddr(), delay)
|
||||
}
|
||||
time.AfterFunc(delay, func() {
|
||||
// TODO: handle error
|
||||
_, _ = conn.ServerConn.Write(raw)
|
||||
})
|
||||
} else {
|
||||
_, err := conn.ServerConn.Write(raw)
|
||||
if err != nil {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("forwarding incoming packet %d (%d bytes) to %s", packetCount, n, conn.ServerConn.RemoteAddr())
|
||||
}
|
||||
if _, err := conn.ServerConn.Write(raw); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -221,18 +243,26 @@ func (p *QuicProxy) runConnection(conn *connection) error {
|
|||
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
|
||||
|
||||
if p.dropPacket(DirectionOutgoing, packetCount) {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("dropping outgoing packet %d (%d bytes)", packetCount, n)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
delay := p.delayPacket(DirectionOutgoing, packetCount)
|
||||
if delay != 0 {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("delaying outgoing packet %d (%d bytes) to %s by %s", packetCount, n, conn.ClientAddr, delay)
|
||||
}
|
||||
time.AfterFunc(delay, func() {
|
||||
// TODO: handle error
|
||||
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr)
|
||||
})
|
||||
} else {
|
||||
_, err := p.conn.WriteToUDP(raw, conn.ClientAddr)
|
||||
if err != nil {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("forwarding outgoing packet %d (%d bytes) to %s", packetCount, n, conn.ClientAddr)
|
||||
}
|
||||
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
2
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go
generated
vendored
2
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go
generated
vendored
|
@ -30,7 +30,7 @@ var _ = BeforeEach(func() {
|
|||
logFile, err = os.Create(logFileName)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
log.SetOutput(logFile)
|
||||
utils.SetLogLevel(utils.LogLevelDebug)
|
||||
utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
|
||||
}
|
||||
})
|
||||
|
||||
|
|
6
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go
generated
vendored
6
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go
generated
vendored
|
@ -22,7 +22,9 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
PRData = GeneratePRData(dataLen)
|
||||
// PRData contains dataLen bytes of pseudo-random data.
|
||||
PRData = GeneratePRData(dataLen)
|
||||
// PRDataLong contains dataLenLong bytes of pseudo-random data.
|
||||
PRDataLong = GeneratePRData(dataLenLong)
|
||||
|
||||
server *h2quic.Server
|
||||
|
@ -105,11 +107,13 @@ func StartQuicServer(versions []protocol.VersionNumber) {
|
|||
}()
|
||||
}
|
||||
|
||||
// StopQuicServer stops the h2quic.Server.
|
||||
func StopQuicServer() {
|
||||
Expect(server.Close()).NotTo(HaveOccurred())
|
||||
Eventually(stoppedServing).Should(BeClosed())
|
||||
}
|
||||
|
||||
// Port returns the UDP port of the QUIC server.
|
||||
func Port() string {
|
||||
return port
|
||||
}
|
||||
|
|
36
vendor/github.com/lucas-clemente/quic-go/interface.go
generated
vendored
36
vendor/github.com/lucas-clemente/quic-go/interface.go
generated
vendored
|
@ -16,6 +16,9 @@ type StreamID = protocol.StreamID
|
|||
// A VersionNumber is a QUIC version number.
|
||||
type VersionNumber = protocol.VersionNumber
|
||||
|
||||
// VersionGQUIC39 is gQUIC version 39.
|
||||
const VersionGQUIC39 = protocol.Version39
|
||||
|
||||
// A Cookie can be used to verify the ownership of the client address.
|
||||
type Cookie = handshake.Cookie
|
||||
|
||||
|
@ -113,15 +116,25 @@ type StreamError interface {
|
|||
// A Session is a QUIC connection between two peers.
|
||||
type Session interface {
|
||||
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
||||
// Since stream 1 is reserved for the crypto stream, the first stream is either 2 (for a client) or 3 (for a server).
|
||||
AcceptStream() (Stream, error)
|
||||
// OpenStream opens a new QUIC stream, returning a special error when the peer's concurrent stream limit is reached.
|
||||
// New streams always have the smallest possible stream ID.
|
||||
// TODO: Enable testing for the special error
|
||||
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
|
||||
AcceptUniStream() (ReceiveStream, error)
|
||||
// OpenStream opens a new bidirectional QUIC stream.
|
||||
// It returns a special error when the peer's concurrent stream limit is reached.
|
||||
// There is no signaling to the peer about new streams:
|
||||
// The peer can only accept the stream after data has been sent on the stream.
|
||||
// TODO(#1152): Enable testing for the special error
|
||||
OpenStream() (Stream, error)
|
||||
// OpenStreamSync opens a new QUIC stream, blocking until the peer's concurrent stream limit allows a new stream to be opened.
|
||||
// It always picks the smallest possible stream ID.
|
||||
// OpenStreamSync opens a new bidirectional QUIC stream.
|
||||
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
|
||||
OpenStreamSync() (Stream, error)
|
||||
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
|
||||
// It returns a special error when the peer's concurrent stream limit is reached.
|
||||
// TODO(#1152): Enable testing for the special error
|
||||
OpenUniStream() (SendStream, error)
|
||||
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
|
||||
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
|
||||
OpenUniStreamSync() (SendStream, error)
|
||||
// LocalAddr returns the local address.
|
||||
LocalAddr() net.Addr
|
||||
// RemoteAddr returns the address of the peer.
|
||||
|
@ -166,6 +179,17 @@ type Config struct {
|
|||
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
|
||||
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
|
||||
MaxReceiveConnectionFlowControlWindow uint64
|
||||
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
|
||||
// If not set, it will default to 100.
|
||||
// If set to a negative value, it doesn't allow any bidirectional streams.
|
||||
// Values larger than 65535 (math.MaxUint16) are invalid.
|
||||
MaxIncomingStreams int
|
||||
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
|
||||
// This value doesn't have any effect in Google QUIC.
|
||||
// If not set, it will default to 100.
|
||||
// If set to a negative value, it doesn't allow any unidirectional streams.
|
||||
// Values larger than 65535 (math.MaxUint16) are invalid.
|
||||
MaxIncomingUniStreams int
|
||||
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
|
||||
KeepAlive bool
|
||||
}
|
||||
|
|
3
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go
generated
vendored
Normal file
3
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go
generated
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
package ackhandler
|
||||
|
||||
//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet
|
14
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
generated
vendored
|
@ -10,15 +10,13 @@ import (
|
|||
// SentPacketHandler handles ACKs received for outgoing packets
|
||||
type SentPacketHandler interface {
|
||||
// SentPacket may modify the packet
|
||||
SentPacket(packet *Packet) error
|
||||
SentPacket(packet *Packet)
|
||||
SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber)
|
||||
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
|
||||
SetHandshakeComplete()
|
||||
|
||||
// SendingAllowed says if a packet can be sent.
|
||||
// Sending packets might not be possible because:
|
||||
// * we're congestion limited
|
||||
// * we're tracking the maximum number of sent packets
|
||||
SendingAllowed() bool
|
||||
// The SendMode determines if and what kind of packets can be sent.
|
||||
SendMode() SendMode
|
||||
// TimeUntilSend is the time when the next packet should be sent.
|
||||
// It is used for pacing packets.
|
||||
TimeUntilSend() time.Time
|
||||
|
@ -32,10 +30,10 @@ type SentPacketHandler interface {
|
|||
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
|
||||
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
||||
DequeuePacketForRetransmission() (packet *Packet)
|
||||
GetLeastUnacked() protocol.PacketNumber
|
||||
GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
OnAlarm()
|
||||
OnAlarm() error
|
||||
}
|
||||
|
||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||
|
|
26
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet.go
generated
vendored
26
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet.go
generated
vendored
|
@ -8,28 +8,22 @@ import (
|
|||
)
|
||||
|
||||
// A Packet is a packet
|
||||
// +gen linkedlist
|
||||
type Packet struct {
|
||||
PacketNumber protocol.PacketNumber
|
||||
PacketType protocol.PacketType
|
||||
Frames []wire.Frame
|
||||
Length protocol.ByteCount
|
||||
EncryptionLevel protocol.EncryptionLevel
|
||||
SendTime time.Time
|
||||
|
||||
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
|
||||
sendTime time.Time
|
||||
}
|
||||
|
||||
// GetFramesForRetransmission gets all the frames for retransmission
|
||||
func (p *Packet) GetFramesForRetransmission() []wire.Frame {
|
||||
var fs []wire.Frame
|
||||
for _, frame := range p.Frames {
|
||||
switch frame.(type) {
|
||||
case *wire.AckFrame:
|
||||
continue
|
||||
case *wire.StopWaitingFrame:
|
||||
continue
|
||||
}
|
||||
fs = append(fs, frame)
|
||||
}
|
||||
return fs
|
||||
// There are two reasons why a packet cannot be retransmitted:
|
||||
// * it was already retransmitted
|
||||
// * this packet is a retransmission, and we already received an ACK for the original packet
|
||||
canBeRetransmitted bool
|
||||
includedInBytesInFlight bool
|
||||
retransmittedAs []protocol.PacketNumber
|
||||
isRetransmission bool // we need a separate bool here because 0 is a valid packet number
|
||||
retransmissionOf protocol.PacketNumber
|
||||
}
|
||||
|
|
45
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_linkedlist.go
generated
vendored
45
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_linkedlist.go
generated
vendored
|
@ -1,13 +1,10 @@
|
|||
// Generated by: main
|
||||
// TypeWriter: linkedlist
|
||||
// Directive: +gen on Packet
|
||||
// This file was automatically generated by genny.
|
||||
// Any changes will be lost if this file is regenerated.
|
||||
// see https://github.com/cheekybits/genny
|
||||
|
||||
package ackhandler
|
||||
|
||||
// List is a modification of http://golang.org/pkg/container/list/
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// Linked list implementation from the Go standard library.
|
||||
|
||||
// PacketElement is an element of a linked list.
|
||||
type PacketElement struct {
|
||||
|
@ -41,8 +38,7 @@ func (e *PacketElement) Prev() *PacketElement {
|
|||
return nil
|
||||
}
|
||||
|
||||
// PacketList represents a doubly linked list.
|
||||
// The zero value for PacketList is an empty list ready to use.
|
||||
// PacketList is a linked list of Packets.
|
||||
type PacketList struct {
|
||||
root PacketElement // sentinel list element, only &root, root.prev, and root.next are used
|
||||
len int // current list length excluding (this) sentinel element
|
||||
|
@ -63,7 +59,7 @@ func NewPacketList() *PacketList { return new(PacketList).Init() }
|
|||
// The complexity is O(1).
|
||||
func (l *PacketList) Len() int { return l.len }
|
||||
|
||||
// Front returns the first element of list l or nil.
|
||||
// Front returns the first element of list l or nil if the list is empty.
|
||||
func (l *PacketList) Front() *PacketElement {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
|
@ -71,7 +67,7 @@ func (l *PacketList) Front() *PacketElement {
|
|||
return l.root.next
|
||||
}
|
||||
|
||||
// Back returns the last element of list l or nil.
|
||||
// Back returns the last element of list l or nil if the list is empty.
|
||||
func (l *PacketList) Back() *PacketElement {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
|
@ -79,7 +75,7 @@ func (l *PacketList) Back() *PacketElement {
|
|||
return l.root.prev
|
||||
}
|
||||
|
||||
// lazyInit lazily initializes a zero PacketList value.
|
||||
// lazyInit lazily initializes a zero List value.
|
||||
func (l *PacketList) lazyInit() {
|
||||
if l.root.next == nil {
|
||||
l.Init()
|
||||
|
@ -98,7 +94,7 @@ func (l *PacketList) insert(e, at *PacketElement) *PacketElement {
|
|||
return e
|
||||
}
|
||||
|
||||
// insertValue is a convenience wrapper for insert(&PacketElement{Value: v}, at).
|
||||
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
|
||||
func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement {
|
||||
return l.insert(&PacketElement{Value: v}, at)
|
||||
}
|
||||
|
@ -116,10 +112,11 @@ func (l *PacketList) remove(e *PacketElement) *PacketElement {
|
|||
|
||||
// Remove removes e from l if e is an element of list l.
|
||||
// It returns the element value e.Value.
|
||||
// The element must not be nil.
|
||||
func (l *PacketList) Remove(e *PacketElement) Packet {
|
||||
if e.list == l {
|
||||
// if e.list == l, l must have been initialized when e was inserted
|
||||
// in l or l == nil (e is a zero PacketElement) and l.remove will crash
|
||||
// in l or l == nil (e is a zero Element) and l.remove will crash
|
||||
l.remove(e)
|
||||
}
|
||||
return e.Value
|
||||
|
@ -139,46 +136,51 @@ func (l *PacketList) PushBack(v Packet) *PacketElement {
|
|||
|
||||
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
|
||||
// If mark is not an element of l, the list is not modified.
|
||||
// The mark must not be nil.
|
||||
func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement {
|
||||
if mark.list != l {
|
||||
return nil
|
||||
}
|
||||
// see comment in PacketList.Remove about initialization of l
|
||||
// see comment in List.Remove about initialization of l
|
||||
return l.insertValue(v, mark.prev)
|
||||
}
|
||||
|
||||
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
|
||||
// If mark is not an element of l, the list is not modified.
|
||||
// The mark must not be nil.
|
||||
func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement {
|
||||
if mark.list != l {
|
||||
return nil
|
||||
}
|
||||
// see comment in PacketList.Remove about initialization of l
|
||||
// see comment in List.Remove about initialization of l
|
||||
return l.insertValue(v, mark)
|
||||
}
|
||||
|
||||
// MoveToFront moves element e to the front of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
// The element must not be nil.
|
||||
func (l *PacketList) MoveToFront(e *PacketElement) {
|
||||
if e.list != l || l.root.next == e {
|
||||
return
|
||||
}
|
||||
// see comment in PacketList.Remove about initialization of l
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.insert(l.remove(e), &l.root)
|
||||
}
|
||||
|
||||
// MoveToBack moves element e to the back of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
// The element must not be nil.
|
||||
func (l *PacketList) MoveToBack(e *PacketElement) {
|
||||
if e.list != l || l.root.prev == e {
|
||||
return
|
||||
}
|
||||
// see comment in PacketList.Remove about initialization of l
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.insert(l.remove(e), l.root.prev)
|
||||
}
|
||||
|
||||
// MoveBefore moves element e to its new position before mark.
|
||||
// If e or mark is not an element of l, or e == mark, the list is not modified.
|
||||
// The element and mark must not be nil.
|
||||
func (l *PacketList) MoveBefore(e, mark *PacketElement) {
|
||||
if e.list != l || e == mark || mark.list != l {
|
||||
return
|
||||
|
@ -187,7 +189,8 @@ func (l *PacketList) MoveBefore(e, mark *PacketElement) {
|
|||
}
|
||||
|
||||
// MoveAfter moves element e to its new position after mark.
|
||||
// If e is not an element of l, or e == mark, the list is not modified.
|
||||
// If e or mark is not an element of l, or e == mark, the list is not modified.
|
||||
// The element and mark must not be nil.
|
||||
func (l *PacketList) MoveAfter(e, mark *PacketElement) {
|
||||
if e.list != l || e == mark || mark.list != l {
|
||||
return
|
||||
|
@ -196,7 +199,7 @@ func (l *PacketList) MoveAfter(e, mark *PacketElement) {
|
|||
}
|
||||
|
||||
// PushBackList inserts a copy of an other list at the back of list l.
|
||||
// The lists l and other may be the same.
|
||||
// The lists l and other may be the same. They must not be nil.
|
||||
func (l *PacketList) PushBackList(other *PacketList) {
|
||||
l.lazyInit()
|
||||
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
|
||||
|
@ -205,7 +208,7 @@ func (l *PacketList) PushBackList(other *PacketList) {
|
|||
}
|
||||
|
||||
// PushFrontList inserts a copy of an other list at the front of list l.
|
||||
// The lists l and other may be the same.
|
||||
// The lists l and other may be the same. They must not be nil.
|
||||
func (l *PacketList) PushFrontList(other *PacketList) {
|
||||
l.lazyInit()
|
||||
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
|
||||
|
|
110
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
110
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
|
@ -3,7 +3,9 @@ package ackhandler
|
|||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
|
@ -15,6 +17,7 @@ type receivedPacketHandler struct {
|
|||
packetHistory *receivedPacketHistory
|
||||
|
||||
ackSendDelay time.Duration
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
packetsReceivedSinceLastAck int
|
||||
retransmittablePacketsReceivedSinceLastAck int
|
||||
|
@ -25,29 +28,54 @@ type receivedPacketHandler struct {
|
|||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
const (
|
||||
// maximum delay that can be applied to an ACK for a retransmittable packet
|
||||
ackSendDelay = 25 * time.Millisecond
|
||||
// initial maximum number of retransmittable packets received before sending an ack.
|
||||
initialRetransmittablePacketsBeforeAck = 2
|
||||
// number of retransmittable that an ACK is sent for
|
||||
retransmittablePacketsBeforeAck = 10
|
||||
// 1/5 RTT delay when doing ack decimation
|
||||
ackDecimationDelay = 1.0 / 4
|
||||
// 1/8 RTT delay when doing ack decimation
|
||||
shortAckDecimationDelay = 1.0 / 8
|
||||
// Minimum number of packets received before ack decimation is enabled.
|
||||
// This intends to avoid the beginning of slow start, when CWNDs may be
|
||||
// rapidly increasing.
|
||||
minReceivedBeforeAckDecimation = 100
|
||||
// Maximum number of packets to ack immediately after a missing packet for
|
||||
// fast retransmission to kick in at the sender. This limit is created to
|
||||
// reduce the number of acks sent that have no benefit for fast retransmission.
|
||||
// Set to the number of nacks needed for fast retransmit plus one for protection
|
||||
// against an ack loss
|
||||
maxPacketsAfterNewMissing = 4
|
||||
)
|
||||
|
||||
// NewReceivedPacketHandler creates a new receivedPacketHandler
|
||||
func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHandler {
|
||||
func NewReceivedPacketHandler(rttStats *congestion.RTTStats, version protocol.VersionNumber) ReceivedPacketHandler {
|
||||
return &receivedPacketHandler{
|
||||
packetHistory: newReceivedPacketHistory(),
|
||||
ackSendDelay: protocol.AckSendDelay,
|
||||
ackSendDelay: ackSendDelay,
|
||||
rttStats: rttStats,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
|
||||
if packetNumber < h.ignoreBelow {
|
||||
return nil
|
||||
}
|
||||
|
||||
isMissing := h.isMissing(packetNumber)
|
||||
if packetNumber > h.largestObserved {
|
||||
h.largestObserved = packetNumber
|
||||
h.largestObservedReceivedTime = rcvTime
|
||||
}
|
||||
|
||||
if packetNumber < h.ignoreBelow {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck)
|
||||
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -58,35 +86,68 @@ func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
|
|||
h.packetHistory.DeleteBelow(p)
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) {
|
||||
h.packetsReceivedSinceLastAck++
|
||||
|
||||
if shouldInstigateAck {
|
||||
h.retransmittablePacketsReceivedSinceLastAck++
|
||||
// isMissing says if a packet was reported missing in the last ACK.
|
||||
func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool {
|
||||
if h.lastAck == nil {
|
||||
return false
|
||||
}
|
||||
return p < h.lastAck.LargestAcked && !h.lastAck.AcksPacket(p)
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) hasNewMissingPackets() bool {
|
||||
if h.lastAck == nil {
|
||||
return false
|
||||
}
|
||||
highestRange := h.packetHistory.GetHighestAckRange()
|
||||
return highestRange.First >= h.lastAck.LargestAcked && highestRange.Len() <= maxPacketsAfterNewMissing
|
||||
}
|
||||
|
||||
// maybeQueueAck queues an ACK, if necessary.
|
||||
// It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck()
|
||||
// in ACK_DECIMATION_WITH_REORDERING mode.
|
||||
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) {
|
||||
h.packetsReceivedSinceLastAck++
|
||||
|
||||
// always ack the first packet
|
||||
if h.lastAck == nil {
|
||||
h.ackQueued = true
|
||||
return
|
||||
}
|
||||
|
||||
// if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK
|
||||
// note that it cannot be a duplicate because they're already filtered out by ReceivedPacket()
|
||||
if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked {
|
||||
h.ackQueued = true
|
||||
}
|
||||
|
||||
// check if a new missing range above the previously was created
|
||||
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().First > h.lastAck.LargestAcked {
|
||||
// Send an ACK if this packet was reported missing in an ACK sent before.
|
||||
// Ack decimation with reordering relies on the timer to send an ACK, but if
|
||||
// missing packets we reported in the previous ack, send an ACK immediately.
|
||||
if wasMissing {
|
||||
h.ackQueued = true
|
||||
}
|
||||
|
||||
if !h.ackQueued && shouldInstigateAck {
|
||||
if h.retransmittablePacketsReceivedSinceLastAck >= protocol.RetransmittablePacketsBeforeAck {
|
||||
h.ackQueued = true
|
||||
h.retransmittablePacketsReceivedSinceLastAck++
|
||||
|
||||
if packetNumber > minReceivedBeforeAckDecimation {
|
||||
// ack up to 10 packets at once
|
||||
if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
|
||||
h.ackQueued = true
|
||||
} else if h.ackAlarm.IsZero() {
|
||||
// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
|
||||
ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
|
||||
h.ackAlarm = rcvTime.Add(ackDelay)
|
||||
}
|
||||
} else {
|
||||
if h.ackAlarm.IsZero() {
|
||||
h.ackAlarm = rcvTime.Add(h.ackSendDelay)
|
||||
// send an ACK every 2 retransmittable packets
|
||||
if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
|
||||
h.ackQueued = true
|
||||
} else if h.ackAlarm.IsZero() {
|
||||
h.ackAlarm = rcvTime.Add(ackSendDelay)
|
||||
}
|
||||
}
|
||||
// If there are new missing packets to report, set a short timer to send an ACK.
|
||||
if h.hasNewMissingPackets() {
|
||||
// wait the minimum of 1/8 min RTT and the existing ack time
|
||||
ackDelay := float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay)
|
||||
ackTime := rcvTime.Add(time.Duration(ackDelay))
|
||||
if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
|
||||
h.ackAlarm = ackTime
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -118,7 +179,6 @@ func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
|
|||
h.ackQueued = false
|
||||
h.packetsReceivedSinceLastAck = 0
|
||||
h.retransmittablePacketsReceivedSinceLastAck = 0
|
||||
|
||||
return ack
|
||||
}
|
||||
|
||||
|
|
36
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go
generated
vendored
Normal file
36
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go
generated
vendored
Normal file
|
@ -0,0 +1,36 @@
|
|||
package ackhandler
|
||||
|
||||
import "fmt"
|
||||
|
||||
// The SendMode says what kind of packets can be sent.
|
||||
type SendMode uint8
|
||||
|
||||
const (
|
||||
// SendNone means that no packets should be sent
|
||||
SendNone SendMode = iota
|
||||
// SendAck means an ACK-only packet should be sent
|
||||
SendAck
|
||||
// SendRetransmission means that retransmissions should be sent
|
||||
SendRetransmission
|
||||
// SendRTO means that an RTO probe packet should be sent
|
||||
SendRTO
|
||||
// SendAny packet should be sent
|
||||
SendAny
|
||||
)
|
||||
|
||||
func (s SendMode) String() string {
|
||||
switch s {
|
||||
case SendNone:
|
||||
return "none"
|
||||
case SendAck:
|
||||
return "ack"
|
||||
case SendRetransmission:
|
||||
return "retransmission"
|
||||
case SendRTO:
|
||||
return "rto"
|
||||
case SendAny:
|
||||
return "any"
|
||||
default:
|
||||
return fmt.Sprintf("invalid send mode: %d", s)
|
||||
}
|
||||
}
|
446
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
446
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
|
@ -1,7 +1,6 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
@ -30,13 +29,13 @@ const (
|
|||
maxRTOTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
// ErrDuplicateOrOutOfOrderAck occurs when a duplicate or an out-of-order ACK is received
|
||||
var ErrDuplicateOrOutOfOrderAck = errors.New("SentPacketHandler: Duplicate or out-of-order ACK")
|
||||
|
||||
type sentPacketHandler struct {
|
||||
lastSentPacketNumber protocol.PacketNumber
|
||||
nextPacketSendTime time.Time
|
||||
skippedPackets []protocol.PacketNumber
|
||||
lastSentPacketNumber protocol.PacketNumber
|
||||
lastSentRetransmittablePacketTime time.Time
|
||||
lastSentHandshakePacketTime time.Time
|
||||
|
||||
nextPacketSendTime time.Time
|
||||
skippedPackets []protocol.PacketNumber
|
||||
|
||||
largestAcked protocol.PacketNumber
|
||||
largestReceivedPacketWithAck protocol.PacketNumber
|
||||
|
@ -44,8 +43,9 @@ type sentPacketHandler struct {
|
|||
// example: we send an ACK for packets 90-100 with packet number 20
|
||||
// once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101
|
||||
lowestPacketNotConfirmedAcked protocol.PacketNumber
|
||||
largestSentBeforeRTO protocol.PacketNumber
|
||||
|
||||
packetHistory *PacketList
|
||||
packetHistory *sentPacketHistory
|
||||
stopWaitingManager stopWaitingManager
|
||||
|
||||
retransmissionQueue []*Packet
|
||||
|
@ -61,16 +61,20 @@ type sentPacketHandler struct {
|
|||
|
||||
// The number of times an RTO has been sent without receiving an ack.
|
||||
rtoCount uint32
|
||||
// The number of RTO probe packets that should be sent.
|
||||
numRTOs int
|
||||
|
||||
// The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time.
|
||||
lossTime time.Time
|
||||
|
||||
// The alarm timeout
|
||||
alarm time.Time
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
// NewSentPacketHandler creates a new sentPacketHandler
|
||||
func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler {
|
||||
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
|
||||
congestion := congestion.NewCubicSender(
|
||||
congestion.DefaultClock{},
|
||||
rttStats,
|
||||
|
@ -80,16 +84,17 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler {
|
|||
)
|
||||
|
||||
return &sentPacketHandler{
|
||||
packetHistory: NewPacketList(),
|
||||
packetHistory: newSentPacketHistory(),
|
||||
stopWaitingManager: stopWaitingManager{},
|
||||
rttStats: rttStats,
|
||||
congestion: congestion,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber {
|
||||
if f := h.packetHistory.Front(); f != nil {
|
||||
return f.Value.PacketNumber
|
||||
if p := h.packetHistory.FirstOutstanding(); p != nil {
|
||||
return p.PacketNumber
|
||||
}
|
||||
return h.largestAcked + 1
|
||||
}
|
||||
|
@ -101,30 +106,51 @@ func (h *sentPacketHandler) SetHandshakeComplete() {
|
|||
queue = append(queue, packet)
|
||||
}
|
||||
}
|
||||
var handshakePackets []*Packet
|
||||
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||
if p.EncryptionLevel != protocol.EncryptionForwardSecure {
|
||||
handshakePackets = append(handshakePackets, p)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
for _, p := range handshakePackets {
|
||||
h.packetHistory.Remove(p.PacketNumber)
|
||||
}
|
||||
h.retransmissionQueue = queue
|
||||
h.handshakeComplete = true
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SentPacket(packet *Packet) error {
|
||||
if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets {
|
||||
return errors.New("Too many outstanding non-acked and non-retransmitted packets")
|
||||
func (h *sentPacketHandler) SentPacket(packet *Packet) {
|
||||
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
|
||||
h.packetHistory.SentPacket(packet)
|
||||
h.updateLossDetectionAlarm()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
|
||||
var p []*Packet
|
||||
for _, packet := range packets {
|
||||
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
|
||||
p = append(p, packet)
|
||||
}
|
||||
}
|
||||
h.packetHistory.SentPacketsAsRetransmission(p, retransmissionOf)
|
||||
h.updateLossDetectionAlarm()
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
|
||||
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
|
||||
h.skippedPackets = append(h.skippedPackets, p)
|
||||
|
||||
if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets {
|
||||
h.skippedPackets = h.skippedPackets[1:]
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
h.lastSentPacketNumber = packet.PacketNumber
|
||||
|
||||
var largestAcked protocol.PacketNumber
|
||||
if len(packet.Frames) > 0 {
|
||||
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
|
||||
largestAcked = ackFrame.LargestAcked
|
||||
packet.largestAcked = ackFrame.LargestAcked
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -132,24 +158,21 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
|
|||
isRetransmittable := len(packet.Frames) != 0
|
||||
|
||||
if isRetransmittable {
|
||||
packet.sendTime = now
|
||||
packet.largestAcked = largestAcked
|
||||
if packet.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
h.lastSentHandshakePacketTime = packet.SendTime
|
||||
}
|
||||
h.lastSentRetransmittablePacketTime = packet.SendTime
|
||||
packet.includedInBytesInFlight = true
|
||||
h.bytesInFlight += packet.Length
|
||||
h.packetHistory.PushBack(*packet)
|
||||
packet.canBeRetransmitted = true
|
||||
if h.numRTOs > 0 {
|
||||
h.numRTOs--
|
||||
}
|
||||
}
|
||||
h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isRetransmittable)
|
||||
|
||||
h.congestion.OnPacketSent(
|
||||
now,
|
||||
h.bytesInFlight,
|
||||
packet.PacketNumber,
|
||||
packet.Length,
|
||||
isRetransmittable,
|
||||
)
|
||||
|
||||
h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, now).Add(h.congestion.TimeUntilSend(h.bytesInFlight))
|
||||
|
||||
h.updateLossDetectionAlarm(now)
|
||||
return nil
|
||||
h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, packet.SendTime).Add(h.congestion.TimeUntilSend(h.bytesInFlight))
|
||||
return isRetransmittable
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
|
||||
|
@ -157,26 +180,19 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
|||
return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
|
||||
}
|
||||
|
||||
// duplicate or out-of-order ACK
|
||||
// if withPacketNumber <= h.largestReceivedPacketWithAck && withPacketNumber != 0 {
|
||||
if withPacketNumber <= h.largestReceivedPacketWithAck {
|
||||
return ErrDuplicateOrOutOfOrderAck
|
||||
}
|
||||
h.largestReceivedPacketWithAck = withPacketNumber
|
||||
|
||||
// ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK)
|
||||
if ackFrame.LargestAcked < h.lowestUnacked() {
|
||||
// duplicate or out of order ACK
|
||||
if withPacketNumber != 0 && withPacketNumber <= h.largestReceivedPacketWithAck {
|
||||
h.logger.Debugf("Ignoring ACK frame (duplicate or out of order).")
|
||||
return nil
|
||||
}
|
||||
h.largestAcked = ackFrame.LargestAcked
|
||||
h.largestReceivedPacketWithAck = withPacketNumber
|
||||
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, ackFrame.LargestAcked)
|
||||
|
||||
if h.skippedPacketsAcked(ackFrame) {
|
||||
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
|
||||
}
|
||||
|
||||
rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime)
|
||||
|
||||
if rttUpdated {
|
||||
if rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime); rttUpdated {
|
||||
h.congestion.MaybeExitSlowStart()
|
||||
}
|
||||
|
||||
|
@ -185,24 +201,29 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
|||
return err
|
||||
}
|
||||
|
||||
if len(ackedPackets) > 0 {
|
||||
for _, p := range ackedPackets {
|
||||
if encLevel < p.Value.EncryptionLevel {
|
||||
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel)
|
||||
}
|
||||
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
|
||||
// It is safe to ignore the corner case of packets that just acked packet 0, because
|
||||
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
|
||||
if p.Value.largestAcked != 0 {
|
||||
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.Value.largestAcked+1)
|
||||
}
|
||||
h.onPacketAcked(p)
|
||||
h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
|
||||
priorInFlight := h.bytesInFlight
|
||||
for _, p := range ackedPackets {
|
||||
if encLevel < p.EncryptionLevel {
|
||||
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
|
||||
}
|
||||
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
|
||||
// It is safe to ignore the corner case of packets that just acked packet 0, because
|
||||
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
|
||||
if p.largestAcked != 0 {
|
||||
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.largestAcked+1)
|
||||
}
|
||||
if err := h.onPacketAcked(p); err != nil {
|
||||
return err
|
||||
}
|
||||
if p.includedInBytesInFlight {
|
||||
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight)
|
||||
}
|
||||
}
|
||||
|
||||
h.detectLostPackets(rcvTime)
|
||||
h.updateLossDetectionAlarm(rcvTime)
|
||||
if err := h.detectLostPackets(rcvTime, priorInFlight); err != nil {
|
||||
return err
|
||||
}
|
||||
h.updateLossDetectionAlarm()
|
||||
|
||||
h.garbageCollectSkippedPackets()
|
||||
h.stopWaitingManager.ReceivedAck(ackFrame)
|
||||
|
@ -214,59 +235,50 @@ func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNu
|
|||
return h.lowestPacketNotConfirmedAcked
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) {
|
||||
var ackedPackets []*PacketElement
|
||||
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*Packet, error) {
|
||||
var ackedPackets []*Packet
|
||||
ackRangeIndex := 0
|
||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||
packet := el.Value
|
||||
packetNumber := packet.PacketNumber
|
||||
|
||||
err := h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||
// Ignore packets below the LowestAcked
|
||||
if packetNumber < ackFrame.LowestAcked {
|
||||
continue
|
||||
if p.PacketNumber < ackFrame.LowestAcked {
|
||||
return true, nil
|
||||
}
|
||||
// Break after LargestAcked is reached
|
||||
if packetNumber > ackFrame.LargestAcked {
|
||||
break
|
||||
if p.PacketNumber > ackFrame.LargestAcked {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if ackFrame.HasMissingRanges() {
|
||||
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
||||
|
||||
for packetNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 {
|
||||
for p.PacketNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 {
|
||||
ackRangeIndex++
|
||||
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
||||
}
|
||||
|
||||
if packetNumber >= ackRange.First { // packet i contained in ACK range
|
||||
if packetNumber > ackRange.Last {
|
||||
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.First, ackRange.Last)
|
||||
if p.PacketNumber >= ackRange.First { // packet i contained in ACK range
|
||||
if p.PacketNumber > ackRange.Last {
|
||||
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.First, ackRange.Last)
|
||||
}
|
||||
ackedPackets = append(ackedPackets, el)
|
||||
ackedPackets = append(ackedPackets, p)
|
||||
}
|
||||
} else {
|
||||
ackedPackets = append(ackedPackets, el)
|
||||
ackedPackets = append(ackedPackets, p)
|
||||
}
|
||||
}
|
||||
return ackedPackets, nil
|
||||
return true, nil
|
||||
})
|
||||
return ackedPackets, err
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool {
|
||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||
packet := el.Value
|
||||
if packet.PacketNumber == largestAcked {
|
||||
h.rttStats.UpdateRTT(rcvTime.Sub(packet.sendTime), ackDelay, rcvTime)
|
||||
return true
|
||||
}
|
||||
// Packets are sorted by number, so we can stop searching
|
||||
if packet.PacketNumber > largestAcked {
|
||||
break
|
||||
}
|
||||
if p := h.packetHistory.GetPacket(largestAcked); p != nil {
|
||||
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) updateLossDetectionAlarm(now time.Time) {
|
||||
func (h *sentPacketHandler) updateLossDetectionAlarm() {
|
||||
// Cancel the alarm if no packets are outstanding
|
||||
if h.packetHistory.Len() == 0 {
|
||||
h.alarm = time.Time{}
|
||||
|
@ -275,76 +287,152 @@ func (h *sentPacketHandler) updateLossDetectionAlarm(now time.Time) {
|
|||
|
||||
// TODO(#497): TLP
|
||||
if !h.handshakeComplete {
|
||||
h.alarm = now.Add(h.computeHandshakeTimeout())
|
||||
h.alarm = h.lastSentHandshakePacketTime.Add(h.computeHandshakeTimeout())
|
||||
} else if !h.lossTime.IsZero() {
|
||||
// Early retransmit timer or time loss detection.
|
||||
h.alarm = h.lossTime
|
||||
} else {
|
||||
// RTO
|
||||
h.alarm = now.Add(h.computeRTOTimeout())
|
||||
h.alarm = h.lastSentRetransmittablePacketTime.Add(h.computeRTOTimeout())
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) detectLostPackets(now time.Time) {
|
||||
func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight protocol.ByteCount) error {
|
||||
h.lossTime = time.Time{}
|
||||
|
||||
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
|
||||
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
|
||||
|
||||
var lostPackets []*PacketElement
|
||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||
packet := el.Value
|
||||
|
||||
var lostPackets []*Packet
|
||||
h.packetHistory.Iterate(func(packet *Packet) (bool, error) {
|
||||
if packet.PacketNumber > h.largestAcked {
|
||||
break
|
||||
return false, nil
|
||||
}
|
||||
|
||||
timeSinceSent := now.Sub(packet.sendTime)
|
||||
timeSinceSent := now.Sub(packet.SendTime)
|
||||
if timeSinceSent > delayUntilLost {
|
||||
lostPackets = append(lostPackets, el)
|
||||
lostPackets = append(lostPackets, packet)
|
||||
} else if h.lossTime.IsZero() {
|
||||
// Note: This conditional is only entered once per call
|
||||
h.lossTime = now.Add(delayUntilLost - timeSinceSent)
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
|
||||
if len(lostPackets) > 0 {
|
||||
for _, p := range lostPackets {
|
||||
h.queuePacketForRetransmission(p)
|
||||
h.congestion.OnPacketLost(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
|
||||
for _, p := range lostPackets {
|
||||
// the bytes in flight need to be reduced no matter if this packet will be retransmitted
|
||||
if p.includedInBytesInFlight {
|
||||
h.bytesInFlight -= p.Length
|
||||
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
|
||||
}
|
||||
if p.canBeRetransmitted {
|
||||
// queue the packet for retransmission, and report the loss to the congestion controller
|
||||
h.logger.Debugf("\tQueueing packet %#x because it was detected lost", p.PacketNumber)
|
||||
if err := h.queuePacketForRetransmission(p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
h.packetHistory.Remove(p.PacketNumber)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) OnAlarm() {
|
||||
func (h *sentPacketHandler) OnAlarm() error {
|
||||
now := time.Now()
|
||||
|
||||
// TODO(#497): TLP
|
||||
var err error
|
||||
if !h.handshakeComplete {
|
||||
h.queueHandshakePacketsForRetransmission()
|
||||
h.handshakeCount++
|
||||
err = h.queueHandshakePacketsForRetransmission()
|
||||
} else if !h.lossTime.IsZero() {
|
||||
// Early retransmit or time loss detection
|
||||
h.detectLostPackets(now)
|
||||
err = h.detectLostPackets(now, h.bytesInFlight)
|
||||
} else {
|
||||
// RTO
|
||||
h.retransmitOldestTwoPackets()
|
||||
h.rtoCount++
|
||||
h.numRTOs += 2
|
||||
err = h.queueRTOs()
|
||||
}
|
||||
|
||||
h.updateLossDetectionAlarm(now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.updateLossDetectionAlarm()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
|
||||
return h.alarm
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) {
|
||||
h.bytesInFlight -= packetElement.Value.Length
|
||||
func (h *sentPacketHandler) onPacketAcked(p *Packet) error {
|
||||
// This happens if a packet and its retransmissions is acked in the same ACK.
|
||||
// As soon as we process the first one, this will remove all the retransmissions,
|
||||
// so we won't find the retransmitted packet number later.
|
||||
if packet := h.packetHistory.GetPacket(p.PacketNumber); packet == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// only report the acking of this packet to the congestion controller if:
|
||||
// * it is a retransmittable packet
|
||||
// * this packet wasn't retransmitted yet
|
||||
if p.isRetransmission {
|
||||
// that the parent doesn't exist is expected to happen every time the original packet was already acked
|
||||
if parent := h.packetHistory.GetPacket(p.retransmissionOf); parent != nil {
|
||||
if len(parent.retransmittedAs) == 1 {
|
||||
parent.retransmittedAs = nil
|
||||
} else {
|
||||
// remove this packet from the slice of retransmission
|
||||
retransmittedAs := make([]protocol.PacketNumber, 0, len(parent.retransmittedAs)-1)
|
||||
for _, pn := range parent.retransmittedAs {
|
||||
if pn != p.PacketNumber {
|
||||
retransmittedAs = append(retransmittedAs, pn)
|
||||
}
|
||||
}
|
||||
parent.retransmittedAs = retransmittedAs
|
||||
}
|
||||
}
|
||||
}
|
||||
// this also applies to packets that have been retransmitted as probe packets
|
||||
if p.includedInBytesInFlight {
|
||||
h.bytesInFlight -= p.Length
|
||||
}
|
||||
if h.rtoCount > 0 {
|
||||
h.verifyRTO(p.PacketNumber)
|
||||
}
|
||||
if err := h.stopRetransmissionsFor(p); err != nil {
|
||||
return err
|
||||
}
|
||||
h.rtoCount = 0
|
||||
h.handshakeCount = 0
|
||||
// TODO(#497): h.tlpCount = 0
|
||||
h.packetHistory.Remove(packetElement)
|
||||
return h.packetHistory.Remove(p.PacketNumber)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) stopRetransmissionsFor(p *Packet) error {
|
||||
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, r := range p.retransmittedAs {
|
||||
packet := h.packetHistory.GetPacket(r)
|
||||
if packet == nil {
|
||||
return fmt.Errorf("sent packet handler BUG: marking packet as not retransmittable %d (retransmission of %d) not found in history", r, p.PacketNumber)
|
||||
}
|
||||
h.stopRetransmissionsFor(packet)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) verifyRTO(pn protocol.PacketNumber) {
|
||||
if pn <= h.largestSentBeforeRTO {
|
||||
h.logger.Debugf("Spurious RTO detected. Received an ACK for %#x (largest sent before RTO: %#x)", pn, h.largestSentBeforeRTO)
|
||||
// Replace SRTT with latest_rtt and increase the variance to prevent
|
||||
// a spurious RTO from happening again.
|
||||
h.rttStats.ExpireSmoothedMetrics()
|
||||
return
|
||||
}
|
||||
h.logger.Debugf("RTO verified. Received an ACK for %#x (largest sent before RTO: %#x", pn, h.largestSentBeforeRTO)
|
||||
h.congestion.OnRetransmissionTimeout(true)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
|
||||
|
@ -359,26 +447,42 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
|
|||
return packet
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber {
|
||||
return h.lowestUnacked()
|
||||
func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
|
||||
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked())
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
||||
return h.stopWaitingManager.GetStopWaitingFrame(force)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SendingAllowed() bool {
|
||||
cwnd := h.congestion.GetCongestionWindow()
|
||||
congestionLimited := h.bytesInFlight > cwnd
|
||||
maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets
|
||||
if congestionLimited {
|
||||
utils.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
|
||||
func (h *sentPacketHandler) SendMode() SendMode {
|
||||
numTrackedPackets := len(h.retransmissionQueue) + h.packetHistory.Len()
|
||||
|
||||
// Don't send any packets if we're keeping track of the maximum number of packets.
|
||||
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,
|
||||
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
|
||||
// but still allow sending of retransmissions and ACKs.
|
||||
if numTrackedPackets >= protocol.MaxTrackedSentPackets {
|
||||
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
|
||||
return SendNone
|
||||
}
|
||||
// Workaround for #555:
|
||||
// Always allow sending of retransmissions. This should probably be limited
|
||||
// to RTOs, but we currently don't have a nice way of distinguishing them.
|
||||
haveRetransmissions := len(h.retransmissionQueue) > 0
|
||||
return !maxTrackedLimited && (!congestionLimited || haveRetransmissions)
|
||||
if h.numRTOs > 0 {
|
||||
return SendRTO
|
||||
}
|
||||
// Only send ACKs if we're congestion limited.
|
||||
if cwnd := h.congestion.GetCongestionWindow(); h.bytesInFlight > cwnd {
|
||||
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
|
||||
return SendAck
|
||||
}
|
||||
// Send retransmissions first, if there are any.
|
||||
if len(h.retransmissionQueue) > 0 {
|
||||
return SendRetransmission
|
||||
}
|
||||
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
|
||||
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
|
||||
return SendAck
|
||||
}
|
||||
return SendAny
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) TimeUntilSend() time.Time {
|
||||
|
@ -386,6 +490,10 @@ func (h *sentPacketHandler) TimeUntilSend() time.Time {
|
|||
}
|
||||
|
||||
func (h *sentPacketHandler) ShouldSendNumPackets() int {
|
||||
if h.numRTOs > 0 {
|
||||
// RTO probes should not be paced, but must be sent immediately.
|
||||
return h.numRTOs
|
||||
}
|
||||
delay := h.congestion.TimeUntilSend(h.bytesInFlight)
|
||||
if delay == 0 || delay > protocol.MinPacingDelay {
|
||||
return 1
|
||||
|
@ -393,45 +501,50 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int {
|
|||
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) retransmitOldestTwoPackets() {
|
||||
if p := h.packetHistory.Front(); p != nil {
|
||||
h.queueRTO(p)
|
||||
}
|
||||
if p := h.packetHistory.Front(); p != nil {
|
||||
h.queueRTO(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) queueRTO(el *PacketElement) {
|
||||
packet := &el.Value
|
||||
utils.Debugf(
|
||||
"\tQueueing packet 0x%x for retransmission (RTO), %d outstanding",
|
||||
packet.PacketNumber,
|
||||
h.packetHistory.Len(),
|
||||
)
|
||||
h.queuePacketForRetransmission(el)
|
||||
h.congestion.OnPacketLost(packet.PacketNumber, packet.Length, h.bytesInFlight)
|
||||
h.congestion.OnRetransmissionTimeout(true)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() {
|
||||
var handshakePackets []*PacketElement
|
||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
handshakePackets = append(handshakePackets, el)
|
||||
// retransmit the oldest two packets
|
||||
func (h *sentPacketHandler) queueRTOs() error {
|
||||
h.largestSentBeforeRTO = h.lastSentPacketNumber
|
||||
// Queue the first two outstanding packets for retransmission.
|
||||
// This does NOT declare this packets as lost:
|
||||
// They are still tracked in the packet history and count towards the bytes in flight.
|
||||
for i := 0; i < 2; i++ {
|
||||
if p := h.packetHistory.FirstOutstanding(); p != nil {
|
||||
h.logger.Debugf("\tQueueing packet %#x for retransmission (RTO)", p.PacketNumber)
|
||||
if err := h.queuePacketForRetransmission(p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, el := range handshakePackets {
|
||||
h.queuePacketForRetransmission(el)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) {
|
||||
packet := &packetElement.Value
|
||||
h.bytesInFlight -= packet.Length
|
||||
h.retransmissionQueue = append(h.retransmissionQueue, packet)
|
||||
h.packetHistory.Remove(packetElement)
|
||||
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber)
|
||||
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
|
||||
var handshakePackets []*Packet
|
||||
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||
if p.canBeRetransmitted && p.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
handshakePackets = append(handshakePackets, p)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
for _, p := range handshakePackets {
|
||||
h.logger.Debugf("\tQueueing packet %#x as a handshake retransmission", p.PacketNumber)
|
||||
if err := h.queuePacketForRetransmission(p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
|
||||
if !p.canBeRetransmitted {
|
||||
return fmt.Errorf("sent packet handler BUG: packet %d already queued for retransmission", p.PacketNumber)
|
||||
}
|
||||
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
h.retransmissionQueue = append(h.retransmissionQueue, p)
|
||||
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(p.PacketNumber)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
|
||||
|
@ -446,9 +559,12 @@ func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
|
|||
}
|
||||
|
||||
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
|
||||
rto := h.congestion.RetransmissionDelay()
|
||||
if rto == 0 {
|
||||
var rto time.Duration
|
||||
rtt := h.rttStats.SmoothedRTT()
|
||||
if rtt == 0 {
|
||||
rto = defaultRTOTimeout
|
||||
} else {
|
||||
rto = rtt + 4*h.rttStats.MeanDeviation()
|
||||
}
|
||||
rto = utils.MaxDuration(rto, minRTOTimeout)
|
||||
// Exponential backoff
|
||||
|
|
127
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
Normal file
127
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
Normal file
|
@ -0,0 +1,127 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type sentPacketHistory struct {
|
||||
packetList *PacketList
|
||||
packetMap map[protocol.PacketNumber]*PacketElement
|
||||
|
||||
firstOutstanding *PacketElement
|
||||
}
|
||||
|
||||
func newSentPacketHistory() *sentPacketHistory {
|
||||
return &sentPacketHistory{
|
||||
packetList: NewPacketList(),
|
||||
packetMap: make(map[protocol.PacketNumber]*PacketElement),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) SentPacket(p *Packet) {
|
||||
h.sentPacketImpl(p)
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
|
||||
el := h.packetList.PushBack(*p)
|
||||
h.packetMap[p.PacketNumber] = el
|
||||
if h.firstOutstanding == nil {
|
||||
h.firstOutstanding = el
|
||||
}
|
||||
return el
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
|
||||
retransmission, ok := h.packetMap[retransmissionOf]
|
||||
// The retransmitted packet is not present anymore.
|
||||
// This can happen if it was acked in between dequeueing of the retransmission and sending.
|
||||
// Just treat the retransmissions as normal packets.
|
||||
// TODO: This won't happen if we clear packets queued for retransmission on new ACKs.
|
||||
if !ok {
|
||||
for _, packet := range packets {
|
||||
h.sentPacketImpl(packet)
|
||||
}
|
||||
return
|
||||
}
|
||||
retransmission.Value.retransmittedAs = make([]protocol.PacketNumber, len(packets))
|
||||
for i, packet := range packets {
|
||||
retransmission.Value.retransmittedAs[i] = packet.PacketNumber
|
||||
el := h.sentPacketImpl(packet)
|
||||
el.Value.isRetransmission = true
|
||||
el.Value.retransmissionOf = retransmissionOf
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) GetPacket(p protocol.PacketNumber) *Packet {
|
||||
if el, ok := h.packetMap[p]; ok {
|
||||
return &el.Value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Iterate iterates through all packets.
|
||||
// The callback must not modify the history.
|
||||
func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error {
|
||||
cont := true
|
||||
for el := h.packetList.Front(); cont && el != nil; el = el.Next() {
|
||||
var err error
|
||||
cont, err = cb(&el.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FirstOutStanding returns the first outstanding packet.
|
||||
// It must not be modified (e.g. retransmitted).
|
||||
// Use DequeueFirstPacketForRetransmission() to retransmit it.
|
||||
func (h *sentPacketHistory) FirstOutstanding() *Packet {
|
||||
if h.firstOutstanding == nil {
|
||||
return nil
|
||||
}
|
||||
return &h.firstOutstanding.Value
|
||||
}
|
||||
|
||||
// QueuePacketForRetransmission marks a packet for retransmission.
|
||||
// A packet can only be queued once.
|
||||
func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber) error {
|
||||
el, ok := h.packetMap[pn]
|
||||
if !ok {
|
||||
return fmt.Errorf("sent packet history: packet %d not found", pn)
|
||||
}
|
||||
el.Value.canBeRetransmitted = false
|
||||
if el == h.firstOutstanding {
|
||||
h.readjustFirstOutstanding()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readjustFirstOutstanding readjusts the pointer to the first outstanding packet.
|
||||
// This is necessary every time the first outstanding packet is deleted or retransmitted.
|
||||
func (h *sentPacketHistory) readjustFirstOutstanding() {
|
||||
el := h.firstOutstanding.Next()
|
||||
for el != nil && !el.Value.canBeRetransmitted {
|
||||
el = el.Next()
|
||||
}
|
||||
h.firstOutstanding = el
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) Len() int {
|
||||
return len(h.packetMap)
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
|
||||
el, ok := h.packetMap[p]
|
||||
if !ok {
|
||||
return fmt.Errorf("packet %d not found in sent packet history", p)
|
||||
}
|
||||
if el == h.firstOutstanding {
|
||||
h.readjustFirstOutstanding()
|
||||
}
|
||||
h.packetList.Remove(el)
|
||||
delete(h.packetMap, p)
|
||||
return nil
|
||||
}
|
8
vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go
generated
vendored
8
vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go
generated
vendored
|
@ -292,11 +292,3 @@ func (c *cubicSender) OnConnectionMigration() {
|
|||
func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) {
|
||||
c.slowStartLargeReduction = enabled
|
||||
}
|
||||
|
||||
// RetransmissionDelay gives the time to retransmission
|
||||
func (c *cubicSender) RetransmissionDelay() time.Duration {
|
||||
if c.rttStats.SmoothedRTT() == 0 {
|
||||
return 0
|
||||
}
|
||||
return c.rttStats.SmoothedRTT() + c.rttStats.MeanDeviation()*4
|
||||
}
|
||||
|
|
1
vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go
generated
vendored
1
vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go
generated
vendored
|
@ -17,7 +17,6 @@ type SendAlgorithm interface {
|
|||
SetNumEmulatedConnections(n int)
|
||||
OnRetransmissionTimeout(packetsRetransmitted bool)
|
||||
OnConnectionMigration()
|
||||
RetransmissionDelay() time.Duration
|
||||
|
||||
// Experiments
|
||||
SetSlowStartLargeReduction(enabled bool)
|
||||
|
|
1
vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go
generated
vendored
1
vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go
generated
vendored
|
@ -84,7 +84,6 @@ func (r *RTTStats) SetRecentMinRTTwindow(recentMinRTTwindow time.Duration) {
|
|||
// UpdateRTT updates the RTT based on a new sample.
|
||||
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
|
||||
if sendDelta == utils.InfDuration || sendDelta <= 0 {
|
||||
utils.Debugf("Ignoring measured sendDelta, because it's is either infinite, zero, or negative: %d", sendDelta/time.Microsecond)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
22
vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_chain.go
generated
vendored
22
vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_chain.go
generated
vendored
|
@ -55,28 +55,28 @@ func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
|
|||
return cert.Certificate[0], nil
|
||||
}
|
||||
|
||||
func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
|
||||
c := cc.config
|
||||
c, err := maybeGetConfigForClient(c, sni)
|
||||
func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
|
||||
conf := c.config
|
||||
conf, err := maybeGetConfigForClient(conf, sni)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// The rest of this function is mostly copied from crypto/tls.getCertificate
|
||||
|
||||
if c.GetCertificate != nil {
|
||||
cert, err := c.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
|
||||
if conf.GetCertificate != nil {
|
||||
cert, err := conf.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
|
||||
if cert != nil || err != nil {
|
||||
return cert, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.Certificates) == 0 {
|
||||
if len(conf.Certificates) == 0 {
|
||||
return nil, errNoMatchingCertificate
|
||||
}
|
||||
|
||||
if len(c.Certificates) == 1 || c.NameToCertificate == nil {
|
||||
if len(conf.Certificates) == 1 || conf.NameToCertificate == nil {
|
||||
// There's only one choice, so no point doing any work.
|
||||
return &c.Certificates[0], nil
|
||||
return &conf.Certificates[0], nil
|
||||
}
|
||||
|
||||
name := strings.ToLower(sni)
|
||||
|
@ -84,7 +84,7 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
|
|||
name = name[:len(name)-1]
|
||||
}
|
||||
|
||||
if cert, ok := c.NameToCertificate[name]; ok {
|
||||
if cert, ok := conf.NameToCertificate[name]; ok {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
|
@ -94,13 +94,13 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
|
|||
for i := range labels {
|
||||
labels[i] = "*"
|
||||
candidate := strings.Join(labels, ".")
|
||||
if cert, ok := c.NameToCertificate[candidate]; ok {
|
||||
if cert, ok := conf.NameToCertificate[candidate]; ok {
|
||||
return cert, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If nothing matches, return the first certificate.
|
||||
return &c.Certificates[0], nil
|
||||
return &conf.Certificates[0], nil
|
||||
}
|
||||
|
||||
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
|
||||
|
|
4
vendor/github.com/lucas-clemente/quic-go/internal/crypto/curve_25519.go
generated
vendored
4
vendor/github.com/lucas-clemente/quic-go/internal/crypto/curve_25519.go
generated
vendored
|
@ -21,10 +21,6 @@ func NewCurve25519KEX() (KeyExchange, error) {
|
|||
if _, err := rand.Read(c.secret[:]); err != nil {
|
||||
return nil, errors.New("Curve25519: could not create private key")
|
||||
}
|
||||
// See https://cr.yp.to/ecdh.html
|
||||
c.secret[0] &= 248
|
||||
c.secret[31] &= 127
|
||||
c.secret[31] |= 64
|
||||
curve25519.ScalarBaseMult(&c.public, &c.secret)
|
||||
return c, nil
|
||||
}
|
||||
|
|
19
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go
generated
vendored
19
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go
generated
vendored
|
@ -1,13 +1,16 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
const (
|
||||
clientExporterLabel = "EXPORTER-QUIC client 1-RTT Secret"
|
||||
serverExporterLabel = "EXPORTER-QUIC server 1-RTT Secret"
|
||||
clientExporterLabel = "EXPORTER-QUIC client 1rtt"
|
||||
serverExporterLabel = "EXPORTER-QUIC server 1rtt"
|
||||
)
|
||||
|
||||
// A TLSExporter gets the negotiated ciphersuite and computes exporter
|
||||
|
@ -16,6 +19,14 @@ type TLSExporter interface {
|
|||
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
|
||||
}
|
||||
|
||||
func qhkdfExpand(secret []byte, label string, length int) []byte {
|
||||
qlabel := make([]byte, 2+1+5+len(label))
|
||||
binary.BigEndian.PutUint16(qlabel[0:2], uint16(length))
|
||||
qlabel[2] = uint8(5 + len(label))
|
||||
copy(qlabel[3:], []byte("QUIC "+label))
|
||||
return mint.HkdfExpand(crypto.SHA256, secret, qlabel, length)
|
||||
}
|
||||
|
||||
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
|
||||
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
|
||||
var myLabel, otherLabel string
|
||||
|
@ -43,7 +54,7 @@ func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error)
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
key = mint.HkdfExpandLabel(cs.Hash, secret, "key", nil, cs.KeyLen)
|
||||
iv = mint.HkdfExpandLabel(cs.Hash, secret, "iv", nil, cs.IvLen)
|
||||
key = qhkdfExpand(secret, "key", cs.KeyLen)
|
||||
iv = qhkdfExpand(secret, "iv", cs.IvLen)
|
||||
return key, iv, nil
|
||||
}
|
||||
|
|
12
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go
generated
vendored
12
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go
generated
vendored
|
@ -8,7 +8,7 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var quicVersion1Salt = []byte{0xaf, 0xc8, 0x24, 0xec, 0x5f, 0xc7, 0x7e, 0xca, 0x1e, 0x9d, 0x36, 0xf3, 0x7f, 0xb2, 0xd4, 0x65, 0x18, 0xc3, 0x66, 0x39}
|
||||
var quicVersion1Salt = []byte{0x9c, 0x10, 0x8f, 0x98, 0x52, 0x0a, 0x5c, 0x5c, 0x32, 0x96, 0x8e, 0x95, 0x0e, 0x8a, 0x2c, 0x5f, 0xe0, 0x6d, 0x6c, 0x38}
|
||||
|
||||
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
|
||||
clientSecret, serverSecret := computeSecrets(connectionID)
|
||||
|
@ -31,14 +31,14 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec
|
|||
func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
||||
connID := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(connID, uint64(connectionID))
|
||||
cleartextSecret := mint.HkdfExtract(crypto.SHA256, []byte(quicVersion1Salt), connID)
|
||||
clientSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC client cleartext Secret", []byte{}, crypto.SHA256.Size())
|
||||
serverSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC server cleartext Secret", []byte{}, crypto.SHA256.Size())
|
||||
handshakeSecret := mint.HkdfExtract(crypto.SHA256, quicVersion1Salt, connID)
|
||||
clientSecret = qhkdfExpand(handshakeSecret, "client hs", crypto.SHA256.Size())
|
||||
serverSecret = qhkdfExpand(handshakeSecret, "server hs", crypto.SHA256.Size())
|
||||
return
|
||||
}
|
||||
|
||||
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) {
|
||||
key = mint.HkdfExpandLabel(crypto.SHA256, secret, "key", nil, 16)
|
||||
iv = mint.HkdfExpandLabel(crypto.SHA256, secret, "iv", nil, 12)
|
||||
key = qhkdfExpand(secret, "key", 16)
|
||||
iv = qhkdfExpand(secret, "iv", 12)
|
||||
return
|
||||
}
|
||||
|
|
36
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_fnv128a.go
generated
vendored
36
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_fnv128a.go
generated
vendored
|
@ -1,10 +1,11 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
|
||||
"github.com/lucas-clemente/fnv128a"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
|
@ -21,7 +22,7 @@ func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumb
|
|||
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
|
||||
}
|
||||
|
||||
hash := fnv128a.New()
|
||||
hash := fnv.New128a()
|
||||
hash.Write(associatedData)
|
||||
hash.Write(src[12:])
|
||||
if n.perspective == protocol.PerspectiveServer {
|
||||
|
@ -29,13 +30,13 @@ func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumb
|
|||
} else {
|
||||
hash.Write([]byte("Server"))
|
||||
}
|
||||
testHigh, testLow := hash.Sum128()
|
||||
sum := make([]byte, 0, 16)
|
||||
sum = hash.Sum(sum)
|
||||
// The tag is written in little endian, so we need to reverse the slice.
|
||||
reverse(sum)
|
||||
|
||||
low := binary.LittleEndian.Uint64(src)
|
||||
high := binary.LittleEndian.Uint32(src[8:])
|
||||
|
||||
if uint32(testHigh&0xffffffff) != high || testLow != low {
|
||||
return nil, errors.New("NullAEAD: failed to authenticate received data")
|
||||
if !bytes.Equal(sum[:12], src[:12]) {
|
||||
return nil, fmt.Errorf("NullAEAD: failed to authenticate received data (%#v vs %#v)", sum[:12], src[:12])
|
||||
}
|
||||
return src[12:], nil
|
||||
}
|
||||
|
@ -48,7 +49,7 @@ func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumb
|
|||
dst = dst[:12+len(src)]
|
||||
}
|
||||
|
||||
hash := fnv128a.New()
|
||||
hash := fnv.New128a()
|
||||
hash.Write(associatedData)
|
||||
hash.Write(src)
|
||||
|
||||
|
@ -57,15 +58,22 @@ func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumb
|
|||
} else {
|
||||
hash.Write([]byte("Client"))
|
||||
}
|
||||
|
||||
high, low := hash.Sum128()
|
||||
sum := make([]byte, 0, 16)
|
||||
sum = hash.Sum(sum)
|
||||
// The tag is written in little endian, so we need to reverse the slice.
|
||||
reverse(sum)
|
||||
|
||||
copy(dst[12:], src)
|
||||
binary.LittleEndian.PutUint64(dst, low)
|
||||
binary.LittleEndian.PutUint32(dst[8:], uint32(high))
|
||||
copy(dst, sum[:12])
|
||||
return dst
|
||||
}
|
||||
|
||||
func (n *nullAEADFNV128a) Overhead() int {
|
||||
return 12
|
||||
}
|
||||
|
||||
func reverse(a []byte) {
|
||||
for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 {
|
||||
a[left], a[right] = a[right], a[left]
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,8 @@ type baseFlowController struct {
|
|||
epochStartTime time.Time
|
||||
epochStartOffset protocol.ByteCount
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||
|
|
|
@ -22,6 +22,7 @@ func NewConnectionFlowController(
|
|||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
rttStats *congestion.RTTStats,
|
||||
logger utils.Logger,
|
||||
) ConnectionFlowController {
|
||||
return &connectionFlowController{
|
||||
baseFlowController: baseFlowController{
|
||||
|
@ -29,6 +30,7 @@ func NewConnectionFlowController(
|
|||
receiveWindow: receiveWindow,
|
||||
receiveWindowSize: receiveWindow,
|
||||
maxReceiveWindowSize: maxReceiveWindow,
|
||||
logger: logger,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -65,7 +67,7 @@ func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
|
|||
oldWindowSize := c.receiveWindowSize
|
||||
offset := c.baseFlowController.getWindowUpdate()
|
||||
if oldWindowSize < c.receiveWindowSize {
|
||||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return offset
|
||||
|
|
|
@ -31,6 +31,7 @@ func NewStreamFlowController(
|
|||
maxReceiveWindow protocol.ByteCount,
|
||||
initialSendWindow protocol.ByteCount,
|
||||
rttStats *congestion.RTTStats,
|
||||
logger utils.Logger,
|
||||
) StreamFlowController {
|
||||
return &streamFlowController{
|
||||
streamID: streamID,
|
||||
|
@ -42,6 +43,7 @@ func NewStreamFlowController(
|
|||
receiveWindowSize: receiveWindow,
|
||||
maxReceiveWindowSize: maxReceiveWindow,
|
||||
sendWindow: initialSendWindow,
|
||||
logger: logger,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -137,7 +139,7 @@ func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
|||
oldWindowSize := c.receiveWindowSize
|
||||
offset := c.baseFlowController.getWindowUpdate()
|
||||
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
|
||||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||
if c.contributesToConnection {
|
||||
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
|
||||
}
|
||||
|
|
16
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go
generated
vendored
16
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go
generated
vendored
|
@ -7,15 +7,20 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// A CookieHandler generates and validates cookies.
|
||||
// The cookie is sent in the TLS Retry.
|
||||
// By including the cookie in its ClientHello, a client can proof ownership of its source address.
|
||||
type CookieHandler struct {
|
||||
callback func(net.Addr, *Cookie) bool
|
||||
|
||||
callback func(net.Addr, *Cookie) bool
|
||||
cookieGenerator *CookieGenerator
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ mint.CookieHandler = &CookieHandler{}
|
||||
|
||||
func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, error) {
|
||||
// NewCookieHandler creates a new CookieHandler.
|
||||
func NewCookieHandler(callback func(net.Addr, *Cookie) bool, logger utils.Logger) (*CookieHandler, error) {
|
||||
cookieGenerator, err := NewCookieGenerator()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -23,9 +28,11 @@ func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, er
|
|||
return &CookieHandler{
|
||||
callback: callback,
|
||||
cookieGenerator: cookieGenerator,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Generate a new cookie for a mint connection.
|
||||
func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
|
||||
if h.callback(conn.RemoteAddr(), nil) {
|
||||
return nil, nil
|
||||
|
@ -33,10 +40,11 @@ func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
|
|||
return h.cookieGenerator.NewToken(conn.RemoteAddr())
|
||||
}
|
||||
|
||||
// Validate a cookie.
|
||||
func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
|
||||
data, err := h.cookieGenerator.DecodeToken(token)
|
||||
if err != nil {
|
||||
utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
|
||||
h.logger.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
|
||||
return false
|
||||
}
|
||||
return h.callback(conn.RemoteAddr(), data)
|
||||
|
|
36
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go
generated
vendored
36
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go
generated
vendored
|
@ -38,13 +38,12 @@ type cryptoSetupClient struct {
|
|||
lastSentCHLO []byte
|
||||
certManager crypto.CertManager
|
||||
|
||||
divNonceChan chan []byte
|
||||
divNonceChan <-chan []byte
|
||||
diversificationNonce []byte
|
||||
|
||||
clientHelloCounter int
|
||||
serverVerified bool // has the certificate chain and the proof already been verified
|
||||
keyDerivation QuicCryptoKeyDerivationFunction
|
||||
keyExchange KeyExchangeFunction
|
||||
|
||||
receivedSecurePacket bool
|
||||
nullAEAD crypto.AEAD
|
||||
|
@ -55,6 +54,8 @@ type cryptoSetupClient struct {
|
|||
handshakeEvent chan<- struct{}
|
||||
|
||||
params *TransportParameters
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ CryptoSetup = &cryptoSetupClient{}
|
||||
|
@ -77,12 +78,14 @@ func NewCryptoSetupClient(
|
|||
handshakeEvent chan<- struct{},
|
||||
initialVersion protocol.VersionNumber,
|
||||
negotiatedVersions []protocol.VersionNumber,
|
||||
) (CryptoSetup, error) {
|
||||
logger utils.Logger,
|
||||
) (CryptoSetup, chan<- []byte, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
return &cryptoSetupClient{
|
||||
divNonceChan := make(chan []byte)
|
||||
cs := &cryptoSetupClient{
|
||||
cryptoStream: cryptoStream,
|
||||
hostname: hostname,
|
||||
connID: connID,
|
||||
|
@ -90,14 +93,15 @@ func NewCryptoSetupClient(
|
|||
certManager: crypto.NewCertManager(tlsConfig),
|
||||
params: params,
|
||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||
keyExchange: getEphermalKEX,
|
||||
nullAEAD: nullAEAD,
|
||||
paramsChan: paramsChan,
|
||||
handshakeEvent: handshakeEvent,
|
||||
initialVersion: initialVersion,
|
||||
negotiatedVersions: negotiatedVersions,
|
||||
divNonceChan: make(chan []byte),
|
||||
}, nil
|
||||
divNonceChan: divNonceChan,
|
||||
logger: logger,
|
||||
}
|
||||
return cs, divNonceChan, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) HandleCryptoStream() error {
|
||||
|
@ -146,7 +150,7 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
|
|||
return err
|
||||
}
|
||||
|
||||
utils.Debugf("Got %s", message)
|
||||
h.logger.Debugf("Got %s", message)
|
||||
switch message.Tag {
|
||||
case TagREJ:
|
||||
if err := h.handleREJMessage(message.Data); err != nil {
|
||||
|
@ -211,7 +215,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
|
|||
|
||||
err = h.certManager.Verify(h.hostname)
|
||||
if err != nil {
|
||||
utils.Infof("Certificate validation failed: %s", err.Error())
|
||||
h.logger.Infof("Certificate validation failed: %s", err.Error())
|
||||
return qerr.ProofInvalid
|
||||
}
|
||||
}
|
||||
|
@ -219,7 +223,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
|
|||
if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil {
|
||||
validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get())
|
||||
if !validProof {
|
||||
utils.Infof("Server proof verification failed")
|
||||
h.logger.Infof("Server proof verification failed")
|
||||
return qerr.ProofInvalid
|
||||
}
|
||||
|
||||
|
@ -373,14 +377,6 @@ func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.Encry
|
|||
return nil, errors.New("CryptoSetupClient: no encryption level specified")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) DiversificationNonce() []byte {
|
||||
panic("not needed for cryptoSetupClient")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
|
||||
h.divNonceChan <- data
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
@ -408,7 +404,7 @@ func (h *cryptoSetupClient) sendCHLO() error {
|
|||
Data: tags,
|
||||
}
|
||||
|
||||
utils.Debugf("Sending %s", message)
|
||||
h.logger.Debugf("Sending %s", message)
|
||||
message.Write(b)
|
||||
|
||||
_, err = h.cryptoStream.Write(b.Bytes())
|
||||
|
|
65
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go
generated
vendored
65
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go
generated
vendored
|
@ -19,7 +19,7 @@ import (
|
|||
type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error)
|
||||
|
||||
// KeyExchangeFunction is used to make a new KEX
|
||||
type KeyExchangeFunction func() crypto.KeyExchange
|
||||
type KeyExchangeFunction func() (crypto.KeyExchange, error)
|
||||
|
||||
// The CryptoSetupServer handles all things crypto for the Session
|
||||
type cryptoSetupServer struct {
|
||||
|
@ -54,6 +54,8 @@ type cryptoSetupServer struct {
|
|||
params *TransportParameters
|
||||
|
||||
sni string // need to fill out the ConnectionState
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ CryptoSetup = &cryptoSetupServer{}
|
||||
|
@ -73,32 +75,36 @@ func NewCryptoSetup(
|
|||
connID protocol.ConnectionID,
|
||||
remoteAddr net.Addr,
|
||||
version protocol.VersionNumber,
|
||||
divNonce []byte,
|
||||
scfg *ServerConfig,
|
||||
params *TransportParameters,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
acceptSTK func(net.Addr, *Cookie) bool,
|
||||
paramsChan chan<- TransportParameters,
|
||||
handshakeEvent chan<- struct{},
|
||||
logger utils.Logger,
|
||||
) (CryptoSetup, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cryptoSetupServer{
|
||||
cryptoStream: cryptoStream,
|
||||
connID: connID,
|
||||
remoteAddr: remoteAddr,
|
||||
version: version,
|
||||
supportedVersions: supportedVersions,
|
||||
scfg: scfg,
|
||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||
keyExchange: getEphermalKEX,
|
||||
nullAEAD: nullAEAD,
|
||||
params: params,
|
||||
acceptSTKCallback: acceptSTK,
|
||||
sentSHLO: make(chan struct{}),
|
||||
paramsChan: paramsChan,
|
||||
handshakeEvent: handshakeEvent,
|
||||
cryptoStream: cryptoStream,
|
||||
connID: connID,
|
||||
remoteAddr: remoteAddr,
|
||||
version: version,
|
||||
supportedVersions: supportedVersions,
|
||||
diversificationNonce: divNonce,
|
||||
scfg: scfg,
|
||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||
keyExchange: getEphermalKEX,
|
||||
nullAEAD: nullAEAD,
|
||||
params: params,
|
||||
acceptSTKCallback: acceptSTK,
|
||||
sentSHLO: make(chan struct{}),
|
||||
paramsChan: paramsChan,
|
||||
handshakeEvent: handshakeEvent,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -114,7 +120,7 @@ func (h *cryptoSetupServer) HandleCryptoStream() error {
|
|||
return qerr.InvalidCryptoMessageType
|
||||
}
|
||||
|
||||
utils.Debugf("Got %s", message)
|
||||
h.logger.Debugf("Got %s", message)
|
||||
done, err := h.handleMessage(chloData.Bytes(), message.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -297,7 +303,7 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt
|
|||
func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
|
||||
stk, err := h.scfg.cookieGenerator.DecodeToken(token)
|
||||
if err != nil {
|
||||
utils.Debugf("STK invalid: %s", err.Error())
|
||||
h.logger.Debugf("STK invalid: %s", err.Error())
|
||||
return false
|
||||
}
|
||||
return h.acceptSTKCallback(h.remoteAddr, stk)
|
||||
|
@ -340,7 +346,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa
|
|||
|
||||
var serverReply bytes.Buffer
|
||||
message.Write(&serverReply)
|
||||
utils.Debugf("Sending %s", message)
|
||||
h.logger.Debugf("Sending %s", message)
|
||||
return serverReply.Bytes(), nil
|
||||
}
|
||||
|
||||
|
@ -364,11 +370,6 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
|||
return nil, err
|
||||
}
|
||||
|
||||
h.diversificationNonce = make([]byte, 32)
|
||||
if _, err = rand.Read(h.diversificationNonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientNonce := cryptoData[TagNONC]
|
||||
err = h.validateClientNonce(clientNonce)
|
||||
if err != nil {
|
||||
|
@ -405,7 +406,10 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
|||
var fsNonce bytes.Buffer
|
||||
fsNonce.Write(clientNonce)
|
||||
fsNonce.Write(serverNonce)
|
||||
ephermalKex := h.keyExchange()
|
||||
ephermalKex, err := h.keyExchange()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -429,7 +433,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
|||
replyMap := h.params.getHelloMap()
|
||||
// add crypto parameters
|
||||
verTag := &bytes.Buffer{}
|
||||
for _, v := range protocol.GetGreasedVersions(h.supportedVersions) {
|
||||
for _, v := range h.supportedVersions {
|
||||
utils.BigEndian.WriteUint32(verTag, uint32(v))
|
||||
}
|
||||
replyMap[TagPUBS] = ephermalKex.PublicKey()
|
||||
|
@ -443,19 +447,10 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
|||
}
|
||||
var reply bytes.Buffer
|
||||
message.Write(&reply)
|
||||
utils.Debugf("Sending %s", message)
|
||||
h.logger.Debugf("Sending %s", message)
|
||||
return reply.Bytes(), nil
|
||||
}
|
||||
|
||||
// DiversificationNonce returns the diversification nonce
|
||||
func (h *cryptoSetupServer) DiversificationNonce() []byte {
|
||||
return h.diversificationNonce
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
|
||||
panic("not needed for cryptoSetupServer")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) ConnectionState() ConnectionState {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
|
34
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go
generated
vendored
34
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go
generated
vendored
|
@ -31,6 +31,8 @@ type cryptoSetupTLS struct {
|
|||
handshakeEvent chan<- struct{}
|
||||
}
|
||||
|
||||
var _ CryptoSetupTLS = &cryptoSetupTLS{}
|
||||
|
||||
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
|
||||
func NewCryptoSetupTLSServer(
|
||||
tls MintTLS,
|
||||
|
@ -38,7 +40,7 @@ func NewCryptoSetupTLSServer(
|
|||
nullAEAD crypto.AEAD,
|
||||
handshakeEvent chan<- struct{},
|
||||
version protocol.VersionNumber,
|
||||
) CryptoSetup {
|
||||
) CryptoSetupTLS {
|
||||
return &cryptoSetupTLS{
|
||||
tls: tls,
|
||||
cryptoStream: cryptoStream,
|
||||
|
@ -57,7 +59,7 @@ func NewCryptoSetupTLSClient(
|
|||
handshakeEvent chan<- struct{},
|
||||
tls MintTLS,
|
||||
version protocol.VersionNumber,
|
||||
) (CryptoSetup, error) {
|
||||
) (CryptoSetupTLS, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -107,22 +109,18 @@ handshakeLoop:
|
|||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
|
||||
func (h *cryptoSetupTLS) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
return h.nullAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
if h.aead != nil {
|
||||
data, err := h.aead.Open(dst, src, packetNumber, associatedData)
|
||||
if err != nil {
|
||||
return nil, protocol.EncryptionUnspecified, err
|
||||
}
|
||||
return data, protocol.EncryptionForwardSecure, nil
|
||||
if h.aead == nil {
|
||||
return nil, errors.New("no 1-RTT sealer")
|
||||
}
|
||||
data, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err != nil {
|
||||
return nil, protocol.EncryptionUnspecified, err
|
||||
}
|
||||
return data, protocol.EncryptionUnencrypted, nil
|
||||
return h.aead.Open(dst, src, packetNumber, associatedData)
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
||||
|
@ -157,14 +155,6 @@ func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, S
|
|||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) DiversificationNonce() []byte {
|
||||
panic("diversification nonce not needed for TLS")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) SetDiversificationNonce([]byte) {
|
||||
panic("diversification nonce not needed for TLS")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) ConnectionState() ConnectionState {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
|
14
vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go
generated
vendored
|
@ -6,7 +6,6 @@ import (
|
|||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -24,13 +23,13 @@ var (
|
|||
// used for all connections for 60 seconds is negligible. Thus we can amortise
|
||||
// the Diffie-Hellman key generation at the server over all the connections in a
|
||||
// small time span.
|
||||
func getEphermalKEX() (res crypto.KeyExchange) {
|
||||
func getEphermalKEX() (crypto.KeyExchange, error) {
|
||||
kexMutex.RLock()
|
||||
res = kexCurrent
|
||||
res := kexCurrent
|
||||
t := kexCurrentTime
|
||||
kexMutex.RUnlock()
|
||||
if res != nil && time.Since(t) < kexLifetime {
|
||||
return res
|
||||
return res, nil
|
||||
}
|
||||
|
||||
kexMutex.Lock()
|
||||
|
@ -39,12 +38,11 @@ func getEphermalKEX() (res crypto.KeyExchange) {
|
|||
if kexCurrent == nil || time.Since(kexCurrentTime) > kexLifetime {
|
||||
kex, err := crypto.NewCurve25519KEX()
|
||||
if err != nil {
|
||||
utils.Errorf("could not set KEX: %s", err.Error())
|
||||
return kexCurrent
|
||||
return nil, err
|
||||
}
|
||||
kexCurrent = kex
|
||||
kexCurrentTime = time.Now()
|
||||
return kexCurrent
|
||||
return kexCurrent, nil
|
||||
}
|
||||
return kexCurrent
|
||||
return kexCurrent, nil
|
||||
}
|
||||
|
|
5
vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go
generated
vendored
5
vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go
generated
vendored
|
@ -84,7 +84,7 @@ func (h HandshakeMessage) Write(b *bytes.Buffer) {
|
|||
|
||||
offset := uint32(0)
|
||||
for i, t := range h.getTagsSorted() {
|
||||
v := data[Tag(t)]
|
||||
v := data[t]
|
||||
b.Write(v)
|
||||
offset += uint32(len(v))
|
||||
binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t))
|
||||
|
@ -111,8 +111,7 @@ func (h *HandshakeMessage) getTagsSorted() []Tag {
|
|||
func (h HandshakeMessage) String() string {
|
||||
var pad string
|
||||
res := tagToString(h.Tag) + ":\n"
|
||||
for _, t := range h.getTagsSorted() {
|
||||
tag := Tag(t)
|
||||
for _, tag := range h.getTagsSorted() {
|
||||
if tag == TagPAD {
|
||||
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag]))
|
||||
} else {
|
||||
|
|
22
vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go
generated
vendored
22
vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go
generated
vendored
|
@ -35,13 +35,8 @@ type MintTLS interface {
|
|||
SetCryptoStream(io.ReadWriter)
|
||||
}
|
||||
|
||||
// CryptoSetup is a crypto setup
|
||||
type CryptoSetup interface {
|
||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
||||
type baseCryptoSetup interface {
|
||||
HandleCryptoStream() error
|
||||
// TODO: clean up this interface
|
||||
DiversificationNonce() []byte // only needed for cryptoSetupServer
|
||||
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
|
||||
ConnectionState() ConnectionState
|
||||
|
||||
GetSealer() (protocol.EncryptionLevel, Sealer)
|
||||
|
@ -49,6 +44,21 @@ type CryptoSetup interface {
|
|||
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
|
||||
}
|
||||
|
||||
// CryptoSetup is the crypto setup used by gQUIC
|
||||
type CryptoSetup interface {
|
||||
baseCryptoSetup
|
||||
|
||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
||||
}
|
||||
|
||||
// CryptoSetupTLS is the crypto setup used by IETF QUIC
|
||||
type CryptoSetupTLS interface {
|
||||
baseCryptoSetup
|
||||
|
||||
OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||
Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// ConnectionState records basic details about the QUIC connection.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
type ConnectionState struct {
|
||||
|
|
16
vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config.go
generated
vendored
16
vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config.go
generated
vendored
|
@ -9,10 +9,10 @@ import (
|
|||
|
||||
// ServerConfig is a server config
|
||||
type ServerConfig struct {
|
||||
kex crypto.KeyExchange
|
||||
certChain crypto.CertChain
|
||||
ID []byte
|
||||
obit []byte
|
||||
kex crypto.KeyExchange
|
||||
certChain crypto.CertChain
|
||||
ID []byte
|
||||
obit []byte
|
||||
cookieGenerator *CookieGenerator
|
||||
}
|
||||
|
||||
|
@ -36,10 +36,10 @@ func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*Serve
|
|||
}
|
||||
|
||||
return &ServerConfig{
|
||||
kex: kex,
|
||||
certChain: certChain,
|
||||
ID: id,
|
||||
obit: obit,
|
||||
kex: kex,
|
||||
certChain: certChain,
|
||||
ID: id,
|
||||
obit: obit,
|
||||
cookieGenerator: cookieGenerator,
|
||||
}, nil
|
||||
}
|
||||
|
|
28
vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go
generated
vendored
28
vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go
generated
vendored
|
@ -102,32 +102,37 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
|
|||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
|
||||
}
|
||||
|
||||
var pubs_kexs []struct{Length uint32; Value []byte}
|
||||
var last_len uint32
|
||||
|
||||
for i := 0; i < len(pubs)-3; i += int(last_len)+3 {
|
||||
var pubsKexs []struct {
|
||||
Length uint32
|
||||
Value []byte
|
||||
}
|
||||
var lastLen uint32
|
||||
for i := 0; i < len(pubs)-3; i += int(lastLen) + 3 {
|
||||
// the PUBS value is always prepended by 3 byte little endian length field
|
||||
|
||||
err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &last_len);
|
||||
err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &lastLen)
|
||||
if err != nil {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable")
|
||||
}
|
||||
if last_len == 0 {
|
||||
if lastLen == 0 {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
||||
}
|
||||
|
||||
if i+3+int(last_len) > len(pubs) {
|
||||
if i+3+int(lastLen) > len(pubs) {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
||||
}
|
||||
|
||||
pubs_kexs = append(pubs_kexs, struct{Length uint32; Value []byte}{last_len, pubs[i+3:i+3+int(last_len)]})
|
||||
pubsKexs = append(pubsKexs, struct {
|
||||
Length uint32
|
||||
Value []byte
|
||||
}{lastLen, pubs[i+3 : i+3+int(lastLen)]})
|
||||
}
|
||||
|
||||
if c255Foundat >= len(pubs_kexs) {
|
||||
if c255Foundat >= len(pubsKexs) {
|
||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS")
|
||||
}
|
||||
|
||||
if pubs_kexs[c255Foundat].Length != 32 {
|
||||
if pubsKexs[c255Foundat].Length != 32 {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
||||
}
|
||||
|
||||
|
@ -137,8 +142,7 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
|
|||
return err
|
||||
}
|
||||
|
||||
|
||||
s.sharedSecret, err = s.kex.CalculateSharedKey(pubs_kexs[c255Foundat].Value)
|
||||
s.sharedSecret, err = s.kex.CalculateSharedKey(pubsKexs[c255Foundat].Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
16
vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go
generated
vendored
16
vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go
generated
vendored
|
@ -9,14 +9,14 @@ type transportParameterID uint16
|
|||
const quicTLSExtensionType = 26
|
||||
|
||||
const (
|
||||
initialMaxStreamDataParameterID transportParameterID = 0x0
|
||||
initialMaxDataParameterID transportParameterID = 0x1
|
||||
initialMaxStreamIDBiDiParameterID transportParameterID = 0x2
|
||||
idleTimeoutParameterID transportParameterID = 0x3
|
||||
omitConnectionIDParameterID transportParameterID = 0x4
|
||||
maxPacketSizeParameterID transportParameterID = 0x5
|
||||
statelessResetTokenParameterID transportParameterID = 0x6
|
||||
initialMaxStreamIDUniParameterID transportParameterID = 0x8
|
||||
initialMaxStreamDataParameterID transportParameterID = 0x0
|
||||
initialMaxDataParameterID transportParameterID = 0x1
|
||||
initialMaxStreamsBiDiParameterID transportParameterID = 0x2
|
||||
idleTimeoutParameterID transportParameterID = 0x3
|
||||
omitConnectionIDParameterID transportParameterID = 0x4
|
||||
maxPacketSizeParameterID transportParameterID = 0x5
|
||||
statelessResetTokenParameterID transportParameterID = 0x6
|
||||
initialMaxStreamsUniParameterID transportParameterID = 0x8
|
||||
)
|
||||
|
||||
type transportParameter struct {
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue