Merge branch 'master' into telemetry

# Conflicts:
#	caddy/caddymain/run.go
#	caddyhttp/httpserver/plugin.go
#	caddytls/client.go
This commit is contained in:
Matthew Holt 2018-04-20 00:03:57 -06:00
commit b019501b8b
No known key found for this signature in database
GPG key ID: 2A349DD577D586A5
228 changed files with 10317 additions and 5964 deletions

View file

@ -103,7 +103,7 @@ While we really do value your requests and implement many of them, not all featu
### Improving documentation
Caddy's documentation is available at [https://caddyserver.com/docs](https://caddyserver.com/docs). If you would like to make a fix to the docs, feel free to contribute at the [caddyserver/website](https://github.com/caddyserver/website) repository!
Caddy's documentation is available at [https://caddyserver.com/docs](https://caddyserver.com/docs). If you would like to make a fix to the docs, please submit an issue here describing the change to make.
Note that plugin documentation is not hosted by the Caddy website, other than basic usage examples. They are managed by the individual plugin authors, and you will have to contact them to change their documentation.

View file

@ -1,5 +1,5 @@
<p align="center">
<a href="https://caddyserver.com"><img src="https://user-images.githubusercontent.com/1128849/36137292-bebc223a-1051-11e8-9a81-4ea9054c96ac.png" alt="Caddy" width="400"></a>
<a href="https://caddyserver.com"><img src="https://user-images.githubusercontent.com/1128849/36338535-05fb646a-136f-11e8-987b-e6901e717d5a.png" alt="Caddy" width="450"></a>
</p>
<h3 align="center">Every Site on HTTPS <!-- Serve Confidently --></h3>
<p align="center">Caddy is a general-purpose HTTP/2 web server that serves HTTPS by default.</p>
@ -21,7 +21,7 @@
---
Caddy is fast, easy to use, and makes you more productive.
Caddy is a **production-ready** open-source web server that is fast, easy to use, and makes you more productive.
Available for Windows, Mac, Linux, BSD, Solaris, and [Android](https://github.com/mholt/caddy/wiki/Running-Caddy-on-Android).
@ -41,31 +41,35 @@ Available for Windows, Mac, Linux, BSD, Solaris, and [Android](https://github.co
- **Automatic HTTPS** on by default (via [Let's Encrypt](https://letsencrypt.org))
- **HTTP/2** by default
- **Virtual hosting** so multiple sites just work
- Experimental **QUIC support** for those that like speed
- Experimental **QUIC support** for cutting-edge transmissions
- TLS session ticket **key rotation** for more secure connections
- **Extensible with plugins** because a convenient web server is a helpful one
- **Runs anywhere** with **no external dependencies** (not even libc)
There's way more, too! [See all features built into Caddy.](https://caddyserver.com/features) On top of all those, Caddy does even more with plugins: choose which plugins you want at [download](https://caddyserver.com/download).
[See a more complete list of features built into Caddy.](https://caddyserver.com/features) On top of all those, Caddy does even more with plugins: choose which plugins you want at [download](https://caddyserver.com/download).
Altogether, Caddy can do things other web servers simply cannot do. Its features and plugins save you time and mistakes, and will cheer you up. Your Caddy instance takes care of the details for you!
## Install
Caddy binaries have no dependencies and are available for every platform. Get Caddy any one of these ways:
Caddy binaries have no dependencies and are available for every platform. Get Caddy either of these ways:
- **[Download page](https://caddyserver.com/download)** (RECOMMENDED) allows you to customize your build in the browser
- **[Latest release](https://github.com/mholt/caddy/releases/latest)** for pre-built, vanilla binaries
- **[Download page](https://caddyserver.com/download)** allows you to
customize your build in the browser
- **[Latest release](https://github.com/mholt/caddy/releases/latest)** for
pre-built, vanilla binaries
## Build
To build from source you need **[Git](https://git-scm.com/downloads)** and **[Go](https://golang.org/doc/install)** (1.9 or newer). Follow these instruction for fast building:
- Get source `go get github.com/mholt/caddy/caddy` and then run `go get github.com/caddyserver/builds`
- Now `cd` to `$GOPATH/src/github.com/mholt/caddy/caddy` and run `go run build.go`
- Get the source with `go get github.com/mholt/caddy/caddy` and then run `go get github.com/caddyserver/builds`
- Now `cd $GOPATH/src/github.com/mholt/caddy/caddy` and run `go run build.go`
Then make sure the `caddy` binary is in your PATH.
To build for other platforms, use build.go with the `--goos` and `--goarch` flags.
## Quick Start
@ -85,7 +89,7 @@ If the `caddy` binary has permission to bind to low ports and your domain name's
caddy -host example.com
```
This command serves static files from the current directory over HTTPS. Certificates are automatically obtained and renewed for you!
This command serves static files from the current directory over HTTPS. Certificates are automatically obtained and renewed for you! Caddy is also automatically configuring ports 80 and 443 for you, and redirecting HTTP to HTTPS. Cool, huh?
### Customizing your site
@ -115,7 +119,7 @@ To host multiple sites and do more with the Caddyfile, please see the [Caddyfile
Sites with qualifying hostnames are served over [HTTPS by default](https://caddyserver.com/docs/automatic-https).
Caddy has a command line interface. Run `caddy -h` to view basic help or see the [CLI documentation](https://caddyserver.com/docs/cli) for details.
Caddy has a nice little command line interface. Run `caddy -h` to view basic help or see the [CLI documentation](https://caddyserver.com/docs/cli) for details.
## Running in Production
@ -139,7 +143,7 @@ Please see our [contributing guidelines](https://github.com/mholt/caddy/blob/mas
We use GitHub issues and pull requests only for discussing bug reports and the development of specific changes. We welcome all other topics on the [forum](https://caddy.community)!
If you want to contribute to the documentation, please submit pull requests to [caddyserver/website](https://github.com/caddyserver/website).
If you want to contribute to the documentation, please [submit an issue](https://github.com/mholt/caddy/issues/new) describing the change that should be made.
Thanks for making Caddy -- and the Web -- better!
@ -158,6 +162,6 @@ We thank them for their services. **If you want to help keep Caddy free, please
Caddy was born out of the need for a "batteries-included" web server that runs anywhere and doesn't have to take its configuration with it. Caddy took inspiration from [spark](https://github.com/rif/spark), [nginx](https://github.com/nginx/nginx), lighttpd,
[Websocketd](https://github.com/joewalnes/websocketd) and [Vagrant](https://www.vagrantup.com/), which provides a pleasant mixture of features from each of them.
**The name "Caddy":** The name of the software is "Caddy", not "Caddy Server" or "CaddyServer". Please call it "Caddy" or, if you wish to clarify, "the Caddy web server". See [brand guidelines](https://caddyserver.com/brand).
**The name "Caddy" is trademarked:** The name of the software is "Caddy", not "Caddy Server" or "CaddyServer". Please call it "Caddy" or, if you wish to clarify, "the Caddy web server". See [brand guidelines](https://caddyserver.com/brand). Caddy is a registered trademark of Light Code Labs, LLC.
*Author on Twitter: [@mholt6](https://twitter.com/mholt6)*

View file

@ -802,7 +802,7 @@ func startServers(serverList []Server, inst *Instance, restartFds map[string]res
continue
}
if strings.Contains(err.Error(), "use of closed network connection") {
// this error is normal when closing the listener
// this error is normal when closing the listener; see https://github.com/golang/go/issues/4373
continue
}
log.Println(err)

View file

@ -31,7 +31,7 @@ import (
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddytls"
"github.com/mholt/caddy/telemetry"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acmev2"
"gopkg.in/natefinch/lumberjack.v2"
_ "github.com/mholt/caddy/caddyhttp" // plug in the HTTP server type
@ -43,7 +43,7 @@ func init() {
setVersion()
flag.BoolVar(&caddytls.Agreed, "agree", false, "Agree to the CA's Subscriber Agreement")
flag.StringVar(&caddytls.DefaultCAUrl, "ca", "https://acme-v01.api.letsencrypt.org/directory", "URL to certificate authority's ACME server directory")
flag.StringVar(&caddytls.DefaultCAUrl, "ca", "https://acme-v02.api.letsencrypt.org/directory", "URL to certificate authority's ACME server directory")
flag.BoolVar(&caddytls.DisableHTTPChallenge, "disable-http-challenge", caddytls.DisableHTTPChallenge, "Disable the ACME HTTP challenge")
flag.BoolVar(&caddytls.DisableTLSSNIChallenge, "disable-tls-sni-challenge", caddytls.DisableTLSSNIChallenge, "Disable the ACME TLS-SNI challenge")
flag.StringVar(&disabledMetrics, "disabled-metrics", "", "Comma-separated list of telemetry metrics to disable")

View file

@ -265,14 +265,19 @@ func (p *parser) doImport() error {
} else {
globPattern = importPattern
}
if strings.Count(globPattern, "*") > 1 || strings.Count(globPattern, "?") > 1 ||
(strings.Contains(globPattern, "[") && strings.Contains(globPattern, "]")) {
// See issue #2096 - a pattern with many glob expansions can hang for too long
return p.Errf("Glob pattern may only contain one wildcard (*), but has others: %s", globPattern)
}
matches, err = filepath.Glob(globPattern)
if err != nil {
return p.Errf("Failed to use import pattern %s: %v", importPattern, err)
}
if len(matches) == 0 {
if strings.Contains(globPattern, "*") {
log.Printf("[WARNING] No files matching import pattern: %s", importPattern)
if strings.ContainsAny(globPattern, "*?[]") {
log.Printf("[WARNING] No files matching import glob pattern: %s", importPattern)
} else {
return p.Errf("File to import not found: %s", importPattern)
}
@ -443,7 +448,7 @@ func replaceEnvReferences(s, refStart, refEnd string) string {
index := strings.Index(s, refStart)
for index != -1 {
endIndex := strings.Index(s, refEnd)
if endIndex != -1 {
if endIndex > index+len(refStart) {
ref := s[index : endIndex+len(refEnd)]
s = strings.Replace(s, ref, os.Getenv(ref[len(refStart):len(ref)-len(refEnd)]), -1)
} else {

View file

@ -228,6 +228,17 @@ func TestParseOneAndImport(t *testing.T) {
{`""`, false, []string{}, map[string]int{}},
{``, false, []string{}, map[string]int{}},
// test cases found by fuzzing!
{`import }{$"`, true, []string{}, map[string]int{}},
{`import /*/*.txt`, true, []string{}, map[string]int{}},
{`import /???/?*?o`, true, []string{}, map[string]int{}},
{`import /??`, true, []string{}, map[string]int{}},
{`import /[a-z]`, true, []string{}, map[string]int{}},
{`import {$}`, true, []string{}, map[string]int{}},
{`import {%}`, true, []string{}, map[string]int{}},
{`import {$$}`, true, []string{}, map[string]int{}},
{`import {%%}`, true, []string{}, map[string]int{}},
} {
result, err := testParseOne(test.input)

View file

@ -46,5 +46,4 @@ import (
_ "github.com/mholt/caddy/caddyhttp/timeouts"
_ "github.com/mholt/caddy/caddyhttp/websocket"
_ "github.com/mholt/caddy/onevent"
_ "github.com/mholt/caddy/startupshutdown"
)

View file

@ -25,7 +25,7 @@ import (
// ensure that the standard plugins are in fact plugged in
// and registered properly; this is a quick/naive way to do it.
func TestStandardPlugins(t *testing.T) {
numStandardPlugins := 33 // importing caddyhttp plugs in this many plugins
numStandardPlugins := 31 // importing caddyhttp plugs in this many plugins
s := caddy.DescribePlugins()
if got, want := strings.Count(s, "\n"), numStandardPlugins+5; got != want {
t.Errorf("Expected all standard plugins to be plugged in, got:\n%s", s)

View file

@ -33,8 +33,11 @@ import (
"sync/atomic"
"time"
"crypto/tls"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
"github.com/mholt/caddy/caddytls"
)
// Handler is a middleware type that can handle requests as a FastCGI client.
@ -323,6 +326,19 @@ func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string]
// Some web apps rely on knowing HTTPS or not
if r.TLS != nil {
env["HTTPS"] = "on"
// and pass the protocol details in a manner compatible with apache's mod_ssl
// (which is why they have a SSL_ prefix and not TLS_).
v, ok := tlsProtocolStringToMap[r.TLS.Version]
if ok {
env["SSL_PROTOCOL"] = v
}
// and pass the cipher suite in a manner compatible with apache's mod_ssl
for k, v := range caddytls.SupportedCiphersMap {
if v == r.TLS.CipherSuite {
env["SSL_CIPHER"] = k
break
}
}
}
// Add env variables from config (with support for placeholders in values)
@ -465,3 +481,11 @@ type LogError string
func (l LogError) Error() string {
return string(l)
}
// Map of supported protocols to Apache ssl_mod format
// Note that these are slightly different from SupportedProtocols in caddytls/config.go's
var tlsProtocolStringToMap = map[uint16]string{
tls.VersionTLS10: "TLSv1",
tls.VersionTLS11: "TLSv1.1",
tls.VersionTLS12: "TLSv1.2",
}

View file

@ -100,8 +100,8 @@ func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error {
}
cfg.TLS.Enabled = true
cfg.Addr.Scheme = "https"
if loadCertificates && caddytls.HostQualifies(cfg.Addr.Host) {
_, err := cfg.TLS.CacheManagedCertificate(cfg.Addr.Host)
if loadCertificates && caddytls.HostQualifies(cfg.TLS.Hostname) {
_, err := cfg.TLS.CacheManagedCertificate(cfg.TLS.Hostname)
if err != nil {
return err
}

View file

@ -44,6 +44,7 @@ type Logger struct {
V4ipMask net.IPMask
V6ipMask net.IPMask
IPMaskExists bool
Exceptions []string
}
// NewTestLogger creates logger suitable for testing purposes
@ -84,6 +85,17 @@ func (l Logger) MaskIP(ip string) string {
}
// ShouldLog returns true if the path is not exempted from
// being logged (i.e. it is not found in l.Exceptions).
func (l Logger) ShouldLog(path string) bool {
for _, exc := range l.Exceptions {
if Path(path).Matches(exc) {
return false
}
}
return true
}
// Attach binds logger Start and Close functions to
// controller's OnStartup and OnShutdown hooks.
func (l *Logger) Attach(controller *caddy.Controller) {

View file

@ -15,6 +15,7 @@
package httpserver
import (
"crypto/tls"
"flag"
"fmt"
"log"
@ -123,15 +124,17 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
// For each address in each server block, make a new config
for _, sb := range serverBlocks {
for _, key := range sb.Keys {
key = strings.ToLower(key)
if _, dup := h.keysToSiteConfigs[key]; dup {
return serverBlocks, fmt.Errorf("duplicate site key: %s", key)
}
addr, err := standardizeAddress(key)
if err != nil {
return serverBlocks, err
}
addr = addr.Normalize()
key = addr.Key()
if _, dup := h.keysToSiteConfigs[key]; dup {
return serverBlocks, fmt.Errorf("duplicate site key: %s", key)
}
// Fill in address components from command line so that middleware
// have access to the correct information during setup
if addr.Host == "" && Host != DefaultHost {
@ -146,7 +149,7 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
if addrCopy.Port == "" && Port == DefaultPort {
addrCopy.Port = Port
}
addrStr := strings.ToLower(addrCopy.String())
addrStr := addrCopy.String()
if otherSiteKey, dup := siteAddrs[addrStr]; dup {
err := fmt.Errorf("duplicate site address: %s", addrStr)
if (addrCopy.Host == Host && Host != DefaultHost) ||
@ -218,6 +221,13 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) {
}
}
// Iterate each site configuration and make sure that:
// 1) TLS is disabled for explicitly-HTTP sites (necessary
// when an HTTP address shares a block containing tls)
// 2) if QUIC is enabled, TLS ClientAuth is not, because
// currently, QUIC does not support ClientAuth (TODO:
// revisit this when our QUIC implementation supports it)
// 3) if TLS ClientAuth is used, StrictHostMatching is on
var atLeastOneSiteLooksLikeProduction bool
for _, cfg := range h.siteConfigs {
// see if all the addresses (both sites and
@ -254,6 +264,17 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) {
// instead of 443 because it doesn't know about TLS.
cfg.Addr.Port = HTTPSPort
}
if cfg.TLS.ClientAuth != tls.NoClientCert {
if QUIC {
return nil, fmt.Errorf("cannot enable TLS client authentication with QUIC, because QUIC does not yet support it")
}
// this must be enabled so that a client cannot connect
// using SNI for another site on this listener that
// does NOT require ClientAuth, and then send HTTP
// requests with the Host header of this site which DOES
// require client auth, thus bypassing it...
cfg.StrictHostMatching = true
}
}
// we must map (group) each config to a bind address
@ -287,12 +308,22 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) {
return servers, nil
}
// normalizedKey returns "normalized" key representation:
// scheme and host names are lowered, everything else stays the same
func normalizedKey(key string) string {
addr, err := standardizeAddress(key)
if err != nil {
return key
}
return addr.Normalize().Key()
}
// GetConfig gets the SiteConfig that corresponds to c.
// If none exist (should only happen in tests), then a
// new, empty one will be created.
func GetConfig(c *caddy.Controller) *SiteConfig {
ctx := c.Context().(*httpContext)
key := strings.ToLower(c.Key)
key := normalizedKey(c.Key)
if cfg, ok := ctx.keysToSiteConfigs[key]; ok {
return cfg
}
@ -396,6 +427,43 @@ func (a Address) VHost() string {
return a.Original
}
// Normalize normalizes URL: turn scheme and host names into lower case
func (a Address) Normalize() Address {
path := a.Path
if !CaseSensitivePath {
path = strings.ToLower(path)
}
return Address{
Original: a.Original,
Scheme: strings.ToLower(a.Scheme),
Host: strings.ToLower(a.Host),
Port: a.Port,
Path: path,
}
}
// Key is similar to String, just replaces scheme and host values with modified values.
// Unlike String it doesn't add anything default (scheme, port, etc)
func (a Address) Key() string {
res := ""
if a.Scheme != "" {
res += a.Scheme + "://"
}
if a.Host != "" {
res += a.Host
}
if a.Port != "" {
if strings.HasPrefix(a.Original[len(res):], ":"+a.Port) {
// insert port only if the original has its own explicit port
res += ":" + a.Port
}
}
if a.Path != "" {
res += a.Path
}
return res
}
// standardizeAddress parses an address string into a structured format with separate
// scheme, host, port, and path portions, as well as the original input string.
func standardizeAddress(str string) (Address, error) {
@ -523,6 +591,7 @@ var directives = []string{
"startup", // TODO: Deprecate this directive
"shutdown", // TODO: Deprecate this directive
"on",
"supervisor", // github.com/lucaslorentz/caddy-supervisor
"request_id",
"realip", // github.com/captncraig/caddy-realip
"git", // github.com/abiosoft/caddy-git
@ -538,13 +607,13 @@ var directives = []string{
"ext",
"gzip",
"header",
"geoip", // github.com/kodnaplakal/caddy-geoip
"errors",
"authz", // github.com/casbin/caddy-authz
"filter", // github.com/echocat/caddy-filter
"minify", // github.com/hacdias/caddy-minify
"ipfilter", // github.com/pyed/ipfilter
"ratelimit", // github.com/xuqingfeng/caddy-rate-limit
"search", // github.com/pedronasser/caddy-search
"expires", // github.com/epicagency/caddy-expires
"forwardproxy", // github.com/caddyserver/forwardproxy
"basicauth",

View file

@ -18,6 +18,10 @@ import (
"strings"
"testing"
"sort"
"fmt"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyfile"
)
@ -147,7 +151,20 @@ func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) {
if err != nil {
t.Fatalf("Didn't expect an error, but got: %v", err)
}
addr := ctx.keysToSiteConfigs["localhost"].Addr
localhostKey := "localhost"
item, ok := ctx.keysToSiteConfigs[localhostKey]
if !ok {
availableKeys := make(sort.StringSlice, len(ctx.keysToSiteConfigs))
i := 0
for key := range ctx.keysToSiteConfigs {
availableKeys[i] = fmt.Sprintf("'%s'", key)
i++
}
availableKeys.Sort()
t.Errorf("`%s` not found within registered keys, only these are available: %s", localhostKey, strings.Join(availableKeys, ", "))
return
}
addr := item.Addr
if addr.Port != Port {
t.Errorf("Expected the port on the address to be set, but got: %#v", addr)
}
@ -184,6 +201,64 @@ func TestInspectServerBlocksCaseInsensitiveKey(t *testing.T) {
}
}
func TestKeyNormalization(t *testing.T) {
originalCaseSensitivePath := CaseSensitivePath
defer func() {
CaseSensitivePath = originalCaseSensitivePath
}()
CaseSensitivePath = true
caseSensitiveData := []struct {
orig string
res string
}{
{
orig: "HTTP://A/ABCDEF",
res: "http://a/ABCDEF",
},
{
orig: "A/ABCDEF",
res: "a/ABCDEF",
},
{
orig: "A:2015/Port",
res: "a:2015/Port",
},
}
for _, item := range caseSensitiveData {
v := normalizedKey(item.orig)
if v != item.res {
t.Errorf("Normalization of `%s` with CaseSensitivePath option set to true must be equal to `%s`, got `%s` instead", item.orig, item.res, v)
}
}
CaseSensitivePath = false
caseInsensitiveData := []struct {
orig string
res string
}{
{
orig: "HTTP://A/ABCDEF",
res: "http://a/abcdef",
},
{
orig: "A/ABCDEF",
res: "a/abcdef",
},
{
orig: "A:2015/Port",
res: "a:2015/port",
},
}
for _, item := range caseInsensitiveData {
v := normalizedKey(item.orig)
if v != item.res {
t.Errorf("Normalization of `%s` with CaseSensitivePath option set to false must be equal to `%s`, got `%s` instead", item.orig, item.res, v)
}
}
}
func TestGetConfig(t *testing.T) {
// case insensitivity for key
con := caddy.NewTestController("http", "")
@ -201,6 +276,14 @@ func TestGetConfig(t *testing.T) {
if cfg == cfg3 {
t.Errorf("Expected different configs using when key is different; got %p and %p", cfg, cfg3)
}
con.Key = "foo/foobar"
cfg4 := GetConfig(con)
con.Key = "foo/Foobar"
cfg5 := GetConfig(con)
if cfg4 == cfg5 {
t.Errorf("Expected different cases in path to differentiate keys in general")
}
}
func TestDirectivesList(t *testing.T) {

View file

@ -29,6 +29,7 @@ import (
"time"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddytls"
)
// requestReplacer is a strings.Replacer which is used to
@ -140,6 +141,14 @@ func canLogRequest(r *http.Request) bool {
return false
}
// unescapeBraces finds escaped braces in s and returns
// a string with those braces unescaped.
func unescapeBraces(s string) string {
s = strings.Replace(s, "\\{", "{", -1)
s = strings.Replace(s, "\\}", "}", -1)
return s
}
// Replace performs a replacement of values on s and returns
// the string with the replaced values.
func (r *replacer) Replace(s string) string {
@ -149,32 +158,59 @@ func (r *replacer) Replace(s string) string {
}
result := ""
Placeholders: // process each placeholder in sequence
for {
idxStart := strings.Index(s, "{")
if idxStart == -1 {
// no placeholder anymore
break
}
idxEnd := strings.Index(s[idxStart:], "}")
if idxEnd == -1 {
// unpaired placeholder
break
}
idxEnd += idxStart
var idxStart, idxEnd int
// get a replacement
placeholder := s[idxStart : idxEnd+1]
idxOffset := 0
for { // find first unescaped opening brace
searchSpace := s[idxOffset:]
idxStart = strings.Index(searchSpace, "{")
if idxStart == -1 {
// no more placeholders
break Placeholders
}
if idxStart == 0 || searchSpace[idxStart-1] != '\\' {
// preceding character is not an escape
idxStart += idxOffset
break
}
// the brace we found was escaped
// search the rest of the string next
idxOffset += idxStart + 1
}
idxOffset = 0
for { // find first unescaped closing brace
searchSpace := s[idxStart+idxOffset:]
idxEnd = strings.Index(searchSpace, "}")
if idxEnd == -1 {
// unpaired placeholder
break Placeholders
}
if idxEnd == 0 || searchSpace[idxEnd-1] != '\\' {
// preceding character is not an escape
idxEnd += idxOffset + idxStart
break
}
// the brace we found was escaped
// search the rest of the string next
idxOffset += idxEnd + 1
}
// get a replacement for the unescaped placeholder
placeholder := unescapeBraces(s[idxStart : idxEnd+1])
replacement := r.getSubstitution(placeholder)
// append prefix + replacement
result += s[:idxStart] + replacement
// append unescaped prefix + replacement
result += strings.TrimPrefix(unescapeBraces(s[:idxStart]), "\\") + replacement
// strip out scanned parts
s = s[idxEnd+1:]
}
// append unscanned parts
return result + s
return result + unescapeBraces(s)
}
func roundDuration(d time.Duration) time.Duration {
@ -224,6 +260,16 @@ func (r *replacer) getSubstitution(key string) string {
}
}
}
// search response headers then
if r.responseRecorder != nil && key[1] == '<' {
want := key[2 : len(key)-1]
for key, values := range r.responseRecorder.Header() {
// Header placeholders (case-insensitive)
if strings.EqualFold(key, want) {
return strings.Join(values, ",")
}
}
}
// next check for cookies
if key[1] == '~' {
name := key[2 : len(key)-1]
@ -365,12 +411,46 @@ func (r *replacer) getSubstitution(key string) string {
}
elapsedDuration := time.Since(r.responseRecorder.start)
return strconv.FormatInt(convertToMilliseconds(elapsedDuration), 10)
case "{tls_protocol}":
if r.request.TLS != nil {
for k, v := range caddytls.SupportedProtocols {
if v == r.request.TLS.Version {
return k
}
}
return "tls" // this should never happen, but guard in case
}
return r.emptyValue // because not using a secure channel
case "{tls_cipher}":
if r.request.TLS != nil {
for k, v := range caddytls.SupportedCiphersMap {
if v == r.request.TLS.CipherSuite {
return k
}
}
return "UNKNOWN" // this should never happen, but guard in case
}
return r.emptyValue
default:
// {labelN}
if strings.HasPrefix(key, "{label") {
nStr := key[6 : len(key)-1] // get the integer N in "{labelN}"
n, err := strconv.Atoi(nStr)
if err != nil || n < 1 {
return r.emptyValue
}
labels := strings.Split(r.request.Host, ".")
if n > len(labels) {
return r.emptyValue
}
return labels[n-1]
}
}
return r.emptyValue
}
//convertToMilliseconds returns the number of milliseconds in the given duration
// convertToMilliseconds returns the number of milliseconds in the given duration
func convertToMilliseconds(d time.Duration) int64 {
return d.Nanoseconds() / 1e6
}

View file

@ -53,7 +53,7 @@ func TestReplace(t *testing.T) {
recordRequest := NewResponseRecorder(w)
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
request, err := http.NewRequest("POST", "http://localhost.local/?foo=bar", reader)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
@ -67,6 +67,9 @@ func TestReplace(t *testing.T) {
request.Header.Set("CustomAdd", "caddy")
request.Header.Set("Cookie", "foo=bar; taste=delicious")
// add some respons headers
recordRequest.Header().Set("Custom", "CustomResponseHeader")
hostname, err := os.Hostname()
if err != nil {
t.Fatalf("Failed to determine hostname: %v", err)
@ -84,7 +87,7 @@ func TestReplace(t *testing.T) {
expect string
}{
{"This hostname is {hostname}", "This hostname is " + hostname},
{"This host is {host}.", "This host is localhost."},
{"This host is {host}.", "This host is localhost.local."},
{"This request method is {method}.", "This request method is POST."},
{"The response status is {status}.", "The response status is 200."},
{"{when}", "02/Jan/2006:15:04:05 +0000"},
@ -92,10 +95,13 @@ func TestReplace(t *testing.T) {
{"{when_unix}", "1136214252"},
{"The Custom header is {>Custom}.", "The Custom header is foobarbaz."},
{"The CustomAdd header is {>CustomAdd}.", "The CustomAdd header is caddy."},
{"The request is {request}.", "The request is POST /?foo=bar HTTP/1.1\\r\\nHost: localhost\\r\\n" +
{"The Custom response header is {<Custom}.", "The Custom response header is CustomResponseHeader."},
{"Bad {>Custom placeholder", "Bad {>Custom placeholder"},
{"The request is {request}.", "The request is POST /?foo=bar HTTP/1.1\\r\\nHost: localhost.local\\r\\n" +
"Cookie: foo=bar; taste=delicious\\r\\nCustom: foobarbaz\\r\\nCustomadd: caddy\\r\\n" +
"Shorterval: 1\\r\\n\\r\\n."},
{"The cUsToM header is {>cUsToM}...", "The cUsToM header is foobarbaz..."},
{"The cUsToM response header is {<CuSTom}.", "The cUsToM response header is CustomResponseHeader."},
{"The Non-Existent header is {>Non-Existent}.", "The Non-Existent header is -."},
{"Bad {host placeholder...", "Bad {host placeholder..."},
{"Bad {>Custom placeholder", "Bad {>Custom placeholder"},
@ -106,6 +112,9 @@ func TestReplace(t *testing.T) {
{"Query string is {query}", "Query string is foo=bar"},
{"Query string value for foo is {?foo}", "Query string value for foo is bar"},
{"Missing query string argument is {?missing}", "Missing query string argument is "},
{"{label1} {label2} {label3} {label4}", "localhost local - -"},
{"Label with missing number is {label} or {labelQQ}", "Label with missing number is - or -"},
{"\\{ 'hostname': '{hostname}' \\}", "{ 'hostname': '" + hostname + "' }"},
}
for _, c := range testCases {
@ -138,6 +147,107 @@ func TestReplace(t *testing.T) {
}
}
func BenchmarkReplace(b *testing.B) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
if err != nil {
b.Fatalf("Failed to make request: %v", err)
}
ctx := context.WithValue(request.Context(), OriginalURLCtxKey, *request.URL)
request = request.WithContext(ctx)
request.Header.Set("Custom", "foobarbaz")
request.Header.Set("ShorterVal", "1")
repl := NewReplacer(request, recordRequest, "-")
// add some headers after creating replacer
request.Header.Set("CustomAdd", "caddy")
request.Header.Set("Cookie", "foo=bar; taste=delicious")
// add some respons headers
recordRequest.Header().Set("Custom", "CustomResponseHeader")
now = func() time.Time {
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
repl.Replace("This hostname is {hostname}")
}
}
func BenchmarkReplaceEscaped(b *testing.B) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
if err != nil {
b.Fatalf("Failed to make request: %v", err)
}
ctx := context.WithValue(request.Context(), OriginalURLCtxKey, *request.URL)
request = request.WithContext(ctx)
request.Header.Set("Custom", "foobarbaz")
request.Header.Set("ShorterVal", "1")
repl := NewReplacer(request, recordRequest, "-")
// add some headers after creating replacer
request.Header.Set("CustomAdd", "caddy")
request.Header.Set("Cookie", "foo=bar; taste=delicious")
// add some respons headers
recordRequest.Header().Set("Custom", "CustomResponseHeader")
now = func() time.Time {
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
repl.Replace("\\{ 'hostname': '{hostname}' \\}")
}
}
func TestResponseRecorderNil(t *testing.T) {
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
request.Header.Set("Custom", "foobarbaz")
repl := NewReplacer(request, nil, "-")
// add some headers after creating replacer
request.Header.Set("CustomAdd", "caddy")
request.Header.Set("Cookie", "foo=bar; taste=delicious")
old := now
now = func() time.Time {
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
}
defer func() {
now = old
}()
testCases := []struct {
template string
expect string
}{
{"The Custom response header is {<Custom}.", "The Custom response header is -."},
}
for _, c := range testCases {
if expected, actual := c.expect, repl.Replace(c.template); expected != actual {
t.Errorf("for template '%s', expected '%s', got '%s'", c.template, expected, actual)
}
}
}
func TestSet(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)

View file

@ -320,6 +320,9 @@ func (s *Server) Serve(ln net.Listener) error {
}
err := s.Server.Serve(ln)
if err == http.ErrServerClosed {
err = nil // not an error worth reporting since closing a server is intentional
}
if s.quicServer != nil {
s.quicServer.Close()
}
@ -421,19 +424,39 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
r.URL = trimPathPrefix(r.URL, pathPrefix)
}
// enforce strict host matching, which ensures that the SNI
// value (if any), matches the Host header; essential for
// sites that rely on TLS ClientAuth sharing a port with
// sites that do not - if mismatched, close the connection
if vhost.StrictHostMatching && r.TLS != nil &&
strings.ToLower(r.TLS.ServerName) != strings.ToLower(hostname) {
r.Close = true
log.Printf("[ERROR] %s - strict host matching: SNI (%s) and HTTP Host (%s) values differ",
vhost.Addr, r.TLS.ServerName, hostname)
return http.StatusForbidden, nil
}
return vhost.middlewareChain.ServeHTTP(w, r)
}
func trimPathPrefix(u *url.URL, prefix string) *url.URL {
// We need to use URL.EscapedPath() when trimming the pathPrefix as
// URL.Path is ambiguous about / or %2f - see docs. See #1927
trimmed := strings.TrimPrefix(u.EscapedPath(), prefix)
if !strings.HasPrefix(trimmed, "/") {
trimmed = "/" + trimmed
trimmedPath := strings.TrimPrefix(u.EscapedPath(), prefix)
if !strings.HasPrefix(trimmedPath, "/") {
trimmedPath = "/" + trimmedPath
}
trimmedURL, err := url.Parse(trimmed)
// After trimming path reconstruct uri string with Query before parsing
trimmedURI := trimmedPath
if u.RawQuery != "" || u.ForceQuery == true {
trimmedURI = trimmedPath + "?" + u.RawQuery
}
if u.Fragment != "" {
trimmedURI = trimmedURI + "#" + u.Fragment
}
trimmedURL, err := url.Parse(trimmedURI)
if err != nil {
log.Printf("[ERROR] Unable to parse trimmed URL %s: %v", trimmed, err)
log.Printf("[ERROR] Unable to parse trimmed URL %s: %v", trimmedURI, err)
return u
}
return trimmedURL

View file

@ -129,88 +129,108 @@ func TestMakeHTTPServerWithTimeouts(t *testing.T) {
func TestTrimPathPrefix(t *testing.T) {
for i, pt := range []struct {
path string
url string
prefix string
expected string
shouldFail bool
}{
{
path: "/my/path",
url: "/my/path",
prefix: "/my",
expected: "/path",
shouldFail: false,
},
{
path: "/my/%2f/path",
url: "/my/%2f/path",
prefix: "/my",
expected: "/%2f/path",
shouldFail: false,
},
{
path: "/my/path",
url: "/my/path",
prefix: "/my/",
expected: "/path",
shouldFail: false,
},
{
path: "/my///path",
url: "/my///path",
prefix: "/my",
expected: "/path",
shouldFail: true,
},
{
path: "/my///path",
url: "/my///path",
prefix: "/my",
expected: "///path",
shouldFail: false,
},
{
path: "/my/path///slash",
url: "/my/path///slash",
prefix: "/my",
expected: "/path///slash",
shouldFail: false,
},
{
path: "/my/%2f/path/%2f",
url: "/my/%2f/path/%2f",
prefix: "/my",
expected: "/%2f/path/%2f",
shouldFail: false,
}, {
path: "/my/%20/path",
url: "/my/%20/path",
prefix: "/my",
expected: "/%20/path",
shouldFail: false,
}, {
path: "/path",
url: "/path",
prefix: "",
expected: "/path",
shouldFail: false,
}, {
path: "/path/my/",
url: "/path/my/",
prefix: "/my",
expected: "/path/my/",
shouldFail: false,
}, {
path: "",
url: "",
prefix: "/my",
expected: "/",
shouldFail: false,
}, {
path: "/apath",
url: "/apath",
prefix: "",
expected: "/apath",
shouldFail: false,
}, {
url: "/my/path/page.php?akey=value",
prefix: "/my",
expected: "/path/page.php?akey=value",
shouldFail: false,
}, {
url: "/my/path/page?key=value#fragment",
prefix: "/my",
expected: "/path/page?key=value#fragment",
shouldFail: false,
}, {
url: "/my/path/page#fragment",
prefix: "/my",
expected: "/path/page#fragment",
shouldFail: false,
}, {
url: "/my/apath?",
prefix: "/my",
expected: "/apath?",
shouldFail: false,
},
} {
u, _ := url.Parse(pt.path)
if got, want := trimPathPrefix(u, pt.prefix), pt.expected; got.EscapedPath() != want {
u, _ := url.Parse(pt.url)
if got, want := trimPathPrefix(u, pt.prefix), pt.expected; got.String() != want {
if !pt.shouldFail {
t.Errorf("Test %d: Expected='%s', but was '%s' ", i, want, got.EscapedPath())
t.Errorf("Test %d: Expected='%s', but was '%s' ", i, want, got.String())
}
} else if pt.shouldFail {
t.Errorf("SHOULDFAIL Test %d: Expected='%s', and was '%s' but should fail", i, want, got.EscapedPath())
t.Errorf("SHOULDFAIL Test %d: Expected='%s', and was '%s' but should fail", i, want, got.String())
}
}
}

View file

@ -36,6 +36,16 @@ type SiteConfig struct {
// TLS configuration
TLS *caddytls.Config
// If true, the Host header in the HTTP request must
// match the SNI value in the TLS handshake (if any).
// This should be enabled whenever a site relies on
// TLS client authentication, for example; or any time
// you want to enforce that THIS site's TLS config
// is used and not the TLS config of any other site
// on the same listener. TODO: Check how relevant this
// is with TLS 1.3.
StrictHostMatching bool
// Uncompiled middleware stack
middleware []Middleware

View file

@ -277,7 +277,7 @@ func TestHostname(t *testing.T) {
// // Test 3 - ipv6 without port and brackets
// {"2001:4860:4860::8888", "google-public-dns-a.google.com."},
// Test 4 - no hostname available
{"1.1.1.1", "1.1.1.1"},
{"0.0.0.0", "0.0.0.0"},
}
for i, test := range tests {

View file

@ -67,6 +67,10 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// Write log entries
for _, e := range rule.Entries {
// Check if there is an exception to prevent log being written
if !e.Log.ShouldLog(r.URL.Path) {
continue
}
// Mask IP Address
if e.Log.IPMaskExists {
@ -78,6 +82,7 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
}
}
e.Log.Println(rep.Replace(e.Format))
}
return status, err

View file

@ -177,3 +177,85 @@ func TestMultiEntries(t *testing.T) {
t.Errorf("Expected %q, but got %q", expect, got)
}
}
func TestLogExcept(t *testing.T) {
tests := []struct {
LogRules []Rule
logPath string
shouldLog bool
}{
{[]Rule{{
PathScope: "/",
Entries: []*Entry{{
Log: &httpserver.Logger{
Exceptions: []string{"/soup"},
},
Format: DefaultLogFormat,
}},
}}, `/soup`, false},
{[]Rule{{
PathScope: "/",
Entries: []*Entry{{
Log: &httpserver.Logger{
Exceptions: []string{"/tart"},
},
Format: DefaultLogFormat,
}},
}}, `/soup`, true},
{[]Rule{{
PathScope: "/",
Entries: []*Entry{{
Log: &httpserver.Logger{
Exceptions: []string{"/soup"},
},
Format: DefaultLogFormat,
}},
}}, `/tomatosoup`, true},
{[]Rule{{
PathScope: "/",
Entries: []*Entry{{
Log: &httpserver.Logger{
Exceptions: []string{"/pie/"},
},
Format: DefaultLogFormat,
}},
// Check exception with a trailing slash does not match without
}}, `/pie`, true},
{[]Rule{{
PathScope: "/",
Entries: []*Entry{{
Log: &httpserver.Logger{
Exceptions: []string{"/pie.php"},
},
Format: DefaultLogFormat,
}},
}}, `/pie`, true},
{[]Rule{{
PathScope: "/",
Entries: []*Entry{{
Log: &httpserver.Logger{
Exceptions: []string{"/pie"},
},
Format: DefaultLogFormat,
}},
// Check that a word without trailing slash will match a filename
}}, `/pie.php`, false},
}
for i, test := range tests {
for _, LogRule := range test.LogRules {
for _, e := range LogRule.Entries {
shouldLog := e.Log.ShouldLog(test.logPath)
if shouldLog != test.shouldLog {
t.Fatalf("Test %d expected shouldLog=%t but got shouldLog=%t,", i, test.shouldLog, shouldLog)
}
}
}
}
}

View file

@ -44,7 +44,7 @@ func setup(c *caddy.Controller) error {
func logParse(c *caddy.Controller) ([]*Rule, error) {
var rules []*Rule
var logExceptions []string
for c.Next() {
args := c.RemainingArgs()
@ -91,6 +91,12 @@ func logParse(c *caddy.Controller) ([]*Rule, error) {
}
} else if what == "except" {
for i := 0; i < len(where); i++ {
logExceptions = append(logExceptions, where[i])
}
} else if httpserver.IsLogRollerSubdirective(what) {
if err := httpserver.ParseRoller(logRoller, what, where...); err != nil {
@ -133,6 +139,7 @@ func logParse(c *caddy.Controller) ([]*Rule, error) {
V4ipMask: ip4Mask,
V6ipMask: ip6Mask,
IPMaskExists: ipMaskExists,
Exceptions: logExceptions,
},
Format: format,
})

View file

@ -58,6 +58,10 @@ type Upstream interface {
// Gets the number of upstream hosts.
GetHostCount() int
// Gets how long to wait before timing out
// the request
GetTimeout() time.Duration
// Stops the upstream from proxying requests to shutdown goroutines cleanly.
Stop() error
}
@ -187,7 +191,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
if nameURL, err := url.Parse(host.Name); err == nil {
outreq.Host = nameURL.Host
if proxy == nil {
proxy = NewSingleHostReverseProxy(nameURL, host.WithoutPathPrefix, http.DefaultMaxIdleConnsPerHost)
proxy = NewSingleHostReverseProxy(nameURL,
host.WithoutPathPrefix,
http.DefaultMaxIdleConnsPerHost,
upstream.GetTimeout(),
)
}
// use upstream credentials by default

View file

@ -122,7 +122,7 @@ func TestReverseProxy(t *testing.T) {
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second)},
}
// Create the fake request body.
@ -202,7 +202,7 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) {
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream(backend.URL, true)},
Upstreams: []Upstream{newFakeUpstream(backend.URL, true, 30*time.Second)},
}
// create request and response recorder
@ -287,6 +287,31 @@ func TestReverseProxyMaxConnLimit(t *testing.T) {
jobs.Wait()
}
func TestReverseProxyTimeout(t *testing.T) {
timeout := 2 * time.Second
errorMargin := 100 * time.Millisecond
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream("https://8.8.8.8", true, timeout)},
}
// create request and response recorder
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
start := time.Now()
p.ServeHTTP(w, r)
took := time.Since(start)
if took > timeout+errorMargin {
t.Errorf("Expected timeout ~ %v but got %v", timeout, took)
}
}
func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
// Capture the expected panic
defer func() {
@ -301,7 +326,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
defer wsNop.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL, false)
p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second)
// Create client request
r := httptest.NewRequest("GET", "/", nil)
@ -331,7 +356,7 @@ func TestWebSocketReverseProxyBackendShutDown(t *testing.T) {
}()
// Get proxy to use for the test
p := newWebSocketTestProxy(backend.URL, false)
p := newWebSocketTestProxy(backend.URL, false, 30*time.Second)
backendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
@ -360,7 +385,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
defer wsNop.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL, false)
p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second)
// Create client request
r := httptest.NewRequest("GET", "/", nil)
@ -407,7 +432,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
defer wsEcho.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsEcho.URL, false)
p := newWebSocketTestProxy(wsEcho.URL, false, 30*time.Second)
// This is a full end-end test, so the proxy handler
// has to be part of a server listening on a port. Our
@ -452,7 +477,7 @@ func TestWebSocketReverseProxyFromWSSClient(t *testing.T) {
}))
defer wsEcho.Close()
p := newWebSocketTestProxy(wsEcho.URL, true)
p := newWebSocketTestProxy(wsEcho.URL, true, 30*time.Second)
echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
@ -528,7 +553,7 @@ func TestUnixSocketProxy(t *testing.T) {
defer ts.Close()
url := strings.Replace(ts.URL, "http://", "unix:", 1)
p := newWebSocketTestProxy(url, false)
p := newWebSocketTestProxy(url, false, 30*time.Second)
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
@ -686,7 +711,7 @@ func TestUpstreamHeadersUpdate(t *testing.T) {
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream := newFakeUpstream(backend.URL, false, 30*time.Second)
upstream.host.UpstreamHeaders = http.Header{
"Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"},
@ -753,7 +778,7 @@ func TestDownstreamHeadersUpdate(t *testing.T) {
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream := newFakeUpstream(backend.URL, false, 30*time.Second)
upstream.host.DownstreamHeaders = http.Header{
"+Merge-Me": {"Merge-Value"},
"+Add-Me": {"Add-Value"},
@ -893,7 +918,7 @@ func TestHostSimpleProxyNoHeaderForward(t *testing.T) {
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second)},
}
r := httptest.NewRequest("GET", "/", nil)
@ -913,6 +938,67 @@ func TestHostSimpleProxyNoHeaderForward(t *testing.T) {
}
}
func TestReverseProxyTransparentHeaders(t *testing.T) {
testCases := []struct {
name string
remoteAddr string
forwardedForHeader string
expected []string
}{
{"No header", "192.168.0.1:80", "", []string{"192.168.0.1"}},
{"Existing", "192.168.0.1:80", "1.1.1.1, 2.2.2.2", []string{"1.1.1.1, 2.2.2.2, 192.168.0.1"}},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
testReverseProxyTransparentHeaders(t, tc.remoteAddr, tc.forwardedForHeader, tc.expected)
})
}
}
func testReverseProxyTransparentHeaders(t *testing.T, remoteAddr, forwardedForHeader string, expected []string) {
// Arrange
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
var actualHeaders http.Header
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualHeaders = r.Header
}))
defer backend.Close()
config := "proxy / " + backend.URL + " {\n transparent \n}"
// make proxy
upstreams, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(config)), "")
if err != nil {
t.Errorf("Expected no error. Got: %s", err.Error())
}
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: upstreams,
}
// create request and response recorder
r := httptest.NewRequest("GET", backend.URL, nil)
r.RemoteAddr = remoteAddr
if forwardedForHeader != "" {
r.Header.Set("X-Forwarded-For", forwardedForHeader)
}
w := httptest.NewRecorder()
// Act
p.ServeHTTP(w, r)
// Assert
if got := actualHeaders["X-Forwarded-For"]; !reflect.DeepEqual(got, expected) {
t.Errorf("Transparent proxy response does not contain expected %v header: expect %v, but got %v",
"X-Forwarded-For", expected, got)
}
}
func TestHostHeaderReplacedUsingForward(t *testing.T) {
var requestHost string
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -921,7 +1007,7 @@ func TestHostHeaderReplacedUsingForward(t *testing.T) {
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream := newFakeUpstream(backend.URL, false, 30*time.Second)
proxyHostHeader := "test2.com"
upstream.host.UpstreamHeaders = http.Header{"Host": []string{proxyHostHeader}}
// set up proxy
@ -943,11 +1029,22 @@ func TestHostHeaderReplacedUsingForward(t *testing.T) {
}
func TestBasicAuth(t *testing.T) {
basicAuthTestcase(t, nil, nil)
basicAuthTestcase(t, nil, url.UserPassword("username", "password"))
basicAuthTestcase(t, url.UserPassword("usename", "password"), nil)
basicAuthTestcase(t, url.UserPassword("unused", "unused"),
url.UserPassword("username", "password"))
testCases := []struct {
name string
upstreamUser *url.Userinfo
clientUser *url.Userinfo
}{
{"Nil Both", nil, nil},
{"Nil Upstream User", nil, url.UserPassword("username", "password")},
{"Nil Client User", url.UserPassword("usename", "password"), nil},
{"Both Provided", url.UserPassword("unused", "unused"),
url.UserPassword("username", "password")},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
basicAuthTestcase(t, tc.upstreamUser, tc.clientUser)
})
}
}
func basicAuthTestcase(t *testing.T, upstreamUser, clientUser *url.Userinfo) {
@ -972,7 +1069,7 @@ func basicAuthTestcase(t *testing.T, upstreamUser, clientUser *url.Userinfo) {
p := &Proxy{
Next: httpserver.EmptyNext,
Upstreams: []Upstream{newFakeUpstream(backURL.String(), false)},
Upstreams: []Upstream{newFakeUpstream(backURL.String(), false, 30*time.Second)},
}
r, err := http.NewRequest("GET", "/foo", nil)
if err != nil {
@ -1107,7 +1204,7 @@ func TestProxyDirectorURL(t *testing.T) {
continue
}
NewSingleHostReverseProxy(targetURL, c.without, 0).Director(req)
NewSingleHostReverseProxy(targetURL, c.without, 0, 30*time.Second).Director(req)
if expect, got := c.expectURL, req.URL.String(); expect != got {
t.Errorf("case %d url not equal: expect %q, but got %q",
i, expect, got)
@ -1254,7 +1351,7 @@ func TestCancelRequest(t *testing.T) {
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second)},
}
// setup request with cancel ctx
@ -1303,14 +1400,15 @@ func (r *noopReader) Read(b []byte) (int, error) {
return n, nil
}
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
func newFakeUpstream(name string, insecure bool, timeout time.Duration) *fakeUpstream {
uri, _ := url.Parse(name)
u := &fakeUpstream{
name: name,
from: "/",
name: name,
from: "/",
timeout: timeout,
host: &UpstreamHost{
Name: name,
ReverseProxy: NewSingleHostReverseProxy(uri, "", http.DefaultMaxIdleConnsPerHost),
ReverseProxy: NewSingleHostReverseProxy(uri, "", http.DefaultMaxIdleConnsPerHost, timeout),
},
}
if insecure {
@ -1324,6 +1422,7 @@ type fakeUpstream struct {
host *UpstreamHost
from string
without string
timeout time.Duration
}
func (u *fakeUpstream) From() string {
@ -1338,7 +1437,7 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost {
}
u.host = &UpstreamHost{
Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost, u.GetTimeout()),
}
}
return u.host
@ -1347,6 +1446,7 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost {
func (u *fakeUpstream) AllowedPath(requestPath string) bool { return true }
func (u *fakeUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
func (u *fakeUpstream) GetTimeout() time.Duration { return u.timeout }
func (u *fakeUpstream) GetHostCount() int { return 1 }
func (u *fakeUpstream) Stop() error { return nil }
@ -1354,13 +1454,14 @@ func (u *fakeUpstream) Stop() error { return nil }
// redirect to the specified backendAddr. The function
// also sets up the rules/environment for testing WebSocket
// proxy.
func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy {
func newWebSocketTestProxy(backendAddr string, insecure bool, timeout time.Duration) *Proxy {
return &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{&fakeWsUpstream{
name: backendAddr,
without: "",
insecure: insecure,
timeout: timeout,
}},
}
}
@ -1368,7 +1469,7 @@ func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy {
func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
return &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix}},
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix, timeout: 30 * time.Second}},
}
}
@ -1376,6 +1477,7 @@ type fakeWsUpstream struct {
name string
without string
insecure bool
timeout time.Duration
}
func (u *fakeWsUpstream) From() string {
@ -1386,7 +1488,7 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
uri, _ := url.Parse(u.name)
host := &UpstreamHost{
Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost, u.GetTimeout()),
UpstreamHeaders: http.Header{
"Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"}},
@ -1400,6 +1502,7 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }
func (u *fakeWsUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
func (u *fakeWsUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
func (u *fakeWsUpstream) GetTimeout() time.Duration { return u.timeout }
func (u *fakeWsUpstream) GetHostCount() int { return 1 }
func (u *fakeWsUpstream) Stop() error { return nil }
@ -1445,7 +1548,7 @@ func BenchmarkProxy(b *testing.B) {
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream := newFakeUpstream(backend.URL, false, 30*time.Second)
upstream.host.UpstreamHeaders = http.Header{
"Hostname": {"{hostname}"},
"Host": {"{host}"},
@ -1488,7 +1591,7 @@ func TestChunkedWebSocketReverseProxy(t *testing.T) {
defer wsNop.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL, false)
p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second)
// Create client request
r := httptest.NewRequest("GET", "/", nil)

View file

@ -94,6 +94,10 @@ type ReverseProxy struct {
// If zero, no periodic flushing is done.
FlushInterval time.Duration
// dialer is used when values from the
// defaultDialer need to be overridden per Proxy
dialer *net.Dialer
srvResolver srvResolver
}
@ -103,13 +107,13 @@ type ReverseProxy struct {
// What we need is just the path, so if "unix:/var/run/www.socket"
// was the proxy directive, the parsed hostName would be
// "unix:///var/run/www.socket", hence the ambiguous trimming.
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) {
func socketDial(hostName string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) {
return func(network, addr string) (conn net.Conn, err error) {
return net.Dial("unix", hostName[len("unix://"):])
return net.DialTimeout("unix", hostName[len("unix://"):], timeout)
}
}
func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) (conn net.Conn, err error) {
func (rp *ReverseProxy) srvDialerFunc(locator string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) {
service := locator
if strings.HasPrefix(locator, "srv://") {
service = locator[6:]
@ -122,7 +126,7 @@ func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string)
if err != nil {
return nil, err
}
return net.Dial("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port))
return net.DialTimeout("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port), timeout)
}
}
@ -144,7 +148,7 @@ func singleJoiningSlash(a, b string) string {
// the target request will be for /base/dir.
// Without logic: target's path is "/", incoming is "/api/messages",
// without is "/api", then the target request will be for /messages.
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *ReverseProxy {
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int, timeout time.Duration) *ReverseProxy {
targetQuery := target.RawQuery
director := func(req *http.Request) {
if target.Scheme == "unix" {
@ -226,15 +230,21 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
}
}
dialer := *defaultDialer
if timeout != defaultDialer.Timeout {
dialer.Timeout = timeout
}
rp := &ReverseProxy{
Director: director,
FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events
srvResolver: net.DefaultResolver,
dialer: &dialer,
}
if target.Scheme == "unix" {
rp.Transport = &http.Transport{
Dial: socketDial(target.String()),
Dial: socketDial(target.String(), timeout),
}
} else if target.Scheme == "quic" {
rp.Transport = &h2quic.RoundTripper{
@ -244,9 +254,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
},
}
} else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") {
dialFunc := defaultDialer.Dial
dialFunc := rp.dialer.Dial
if strings.HasPrefix(target.Scheme, "srv") {
dialFunc = rp.srvDialerFunc(target.String())
dialFunc = rp.srvDialerFunc(target.String(), timeout)
}
transport := &http.Transport{
@ -275,7 +285,7 @@ func (rp *ReverseProxy) UseInsecureTransport() {
if rp.Transport == nil {
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: defaultDialer.Dial,
Dial: rp.dialer.Dial,
TLSHandshakeTimeout: defaultCryptoHandshakeTimeout,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
@ -306,7 +316,9 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
if requestIsWebsocket(outreq) {
transport = newConnHijackerTransport(transport)
} else if transport == nil {
transport = http.DefaultTransport
transport = &http.Transport{
Dial: rp.dialer.Dial,
}
}
rp.Director(outreq)
@ -361,7 +373,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
}
bufferPool.Put(hj.Replay)
} else {
backendConn, err = net.Dial("tcp", outreq.URL.Host)
backendConn, err = net.DialTimeout("tcp", outreq.URL.Host, rp.dialer.Timeout)
if err != nil {
return err
}

View file

@ -21,6 +21,7 @@ import (
"net/url"
"strconv"
"testing"
"time"
)
const (
@ -66,7 +67,7 @@ func TestSingleSRVHostReverseProxy(t *testing.T) {
}
port := uint16(pp)
rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost)
rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost, 30*time.Second)
rp.srvResolver = testResolver{
result: []*net.SRV{
{Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1},

View file

@ -49,6 +49,7 @@ type staticUpstream struct {
Hosts HostPool
Policy Policy
KeepAlive int
Timeout time.Duration
FailTimeout time.Duration
TryDuration time.Duration
TryInterval time.Duration
@ -92,6 +93,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
TryInterval: 250 * time.Millisecond,
MaxConns: 0,
KeepAlive: http.DefaultMaxIdleConnsPerHost,
Timeout: 30 * time.Second,
resolver: net.DefaultResolver,
}
@ -225,7 +227,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
return nil, err
}
uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive)
uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive, u.Timeout)
if u.insecureSkipVerify {
uh.ReverseProxy.UseInsecureTransport()
}
@ -431,9 +433,10 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
}
u.downstreamHeaders.Add(header, value)
case "transparent":
// Note: X-Forwarded-For header is always being appended for proxy connections
// See implementation of createUpstreamRequest in proxy.go
u.upstreamHeaders.Add("Host", "{host}")
u.upstreamHeaders.Add("X-Real-IP", "{remote}")
u.upstreamHeaders.Add("X-Forwarded-For", "{remote}")
u.upstreamHeaders.Add("X-Forwarded-Proto", "{scheme}")
case "websocket":
u.upstreamHeaders.Add("Connection", "{>Connection}")
@ -463,6 +466,15 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
return c.ArgErr()
}
u.KeepAlive = n
case "timeout":
if !c.NextArg() {
return c.ArgErr()
}
dur, err := time.ParseDuration(c.Val())
if err != nil {
return c.Errf("unable to parse timeout duration '%s'", c.Val())
}
u.Timeout = dur
default:
return c.Errf("unknown property '%s'", c.Val())
}
@ -618,6 +630,11 @@ func (u *staticUpstream) GetTryInterval() time.Duration {
return u.TryInterval
}
// GetTimeout returns u.Timeout.
func (u *staticUpstream) GetTimeout() time.Duration {
return u.Timeout
}
func (u *staticUpstream) GetHostCount() int {
return len(u.Hosts)
}

View file

@ -282,7 +282,8 @@ func TestStop(t *testing.T) {
}
}
func TestParseBlock(t *testing.T) {
func TestParseBlockTransparent(t *testing.T) {
// tests for transparent proxy presets
r, _ := http.NewRequest("GET", "/", nil)
tests := []struct {
config string
@ -316,6 +317,10 @@ func TestParseBlock(t *testing.T) {
if _, ok := headers["X-Forwarded-Proto"]; !ok {
t.Errorf("Test %d: Could not find the X-Forwarded-Proto header", i+1)
}
if _, ok := headers["X-Forwarded-For"]; ok {
t.Errorf("Test %d: Found unexpected X-Forwarded-For header", i+1)
}
}
}
}

View file

@ -63,22 +63,38 @@ type Rule interface {
// SimpleRule is a simple rewrite rule.
type SimpleRule struct {
From, To string
Regexp *regexp.Regexp
To string
Negate bool
}
// NewSimpleRule creates a new Simple Rule
func NewSimpleRule(from, to string) SimpleRule {
return SimpleRule{from, to}
func NewSimpleRule(from, to string, negate bool) (*SimpleRule, error) {
r, err := regexp.Compile(from)
if err != nil {
return nil, err
}
return &SimpleRule{
Regexp: r,
To: to,
Negate: negate,
}, nil
}
// BasePath satisfies httpserver.Config
func (s SimpleRule) BasePath() string { return s.From }
func (s SimpleRule) BasePath() string { return "/" }
// Match satisfies httpserver.Config
func (s SimpleRule) Match(r *http.Request) bool { return s.From == r.URL.Path }
func (s *SimpleRule) Match(r *http.Request) bool {
matches := regexpMatches(s.Regexp, "/", r.URL.Path)
if s.Negate {
return len(matches) == 0
}
return len(matches) > 0
}
// Rewrite rewrites the internal location of the current request.
func (s SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) Result {
func (s *SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) Result {
// attempt rewrite
return To(fs, r, s.To, newReplacer(r))
@ -165,7 +181,7 @@ func (r ComplexRule) Match(req *http.Request) bool {
return true
}
// otherwise validate regex
return r.regexpMatches(req.URL.Path) != nil
return regexpMatches(r.Regexp, r.Base, req.URL.Path) != nil
}
// Rewrite rewrites the internal location of the current request.
@ -174,7 +190,7 @@ func (r ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) (re Result)
// validate regexp if present
if r.Regexp != nil {
matches := r.regexpMatches(req.URL.Path)
matches := regexpMatches(r.Regexp, r.Base, req.URL.Path)
switch len(matches) {
case 0:
// no match
@ -230,14 +246,14 @@ func (r ComplexRule) matchExt(rPath string) bool {
return !mustUse
}
func (r ComplexRule) regexpMatches(rPath string) []string {
if r.Regexp != nil {
func regexpMatches(regexp *regexp.Regexp, base, rPath string) []string {
if regexp != nil {
// include trailing slash in regexp if present
start := len(r.Base)
if strings.HasSuffix(r.Base, "/") {
start := len(base)
if strings.HasSuffix(base, "/") {
start--
}
return r.Regexp.FindStringSubmatch(rPath[start:])
return regexp.FindStringSubmatch(rPath[start:])
}
return nil
}

View file

@ -29,9 +29,9 @@ func TestRewrite(t *testing.T) {
rw := Rewrite{
Next: httpserver.HandlerFunc(urlPrinter),
Rules: []httpserver.HandlerConfig{
NewSimpleRule("/from", "/to"),
NewSimpleRule("/a", "/b"),
NewSimpleRule("/b", "/b{uri}"),
newSimpleRule(t, "^/from$", "/to"),
newSimpleRule(t, "^/a$", "/b"),
newSimpleRule(t, "^/b$", "/b{uri}"),
},
FileSys: http.Dir("."),
}
@ -131,6 +131,45 @@ func TestRewrite(t *testing.T) {
}
}
// TestWordpress is a test for wordpress usecase.
func TestWordpress(t *testing.T) {
rw := Rewrite{
Next: httpserver.HandlerFunc(urlPrinter),
Rules: []httpserver.HandlerConfig{
// both rules are same, thanks to Go regexp (confusion).
newSimpleRule(t, "^/wp-admin", "{path} {path}/ /index.php?{query}", true),
newSimpleRule(t, "^\\/wp-admin", "{path} {path}/ /index.php?{query}", true),
},
FileSys: http.Dir("."),
}
tests := []struct {
from string
expectedTo string
}{
{"/wp-admin", "/wp-admin"},
{"/wp-admin/login.php", "/wp-admin/login.php"},
{"/not-wp-admin/login.php?not=admin", "/index.php?not=admin"},
{"/loophole", "/index.php"},
{"/user?name=john", "/index.php?name=john"},
}
for i, test := range tests {
req, err := http.NewRequest("GET", test.from, nil)
if err != nil {
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
}
ctx := context.WithValue(req.Context(), httpserver.OriginalURLCtxKey, *req.URL)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
rw.ServeHTTP(rec, req)
if got, want := rec.Body.String(), test.expectedTo; got != want {
t.Errorf("Test %d: Expected URL to be '%s' but was '%s'", i, want, got)
}
}
}
func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) {
fmt.Fprint(w, r.URL.String())
return 0, nil

View file

@ -58,6 +58,7 @@ func rewriteParse(c *caddy.Controller) ([]httpserver.HandlerConfig, error) {
var base = "/"
var pattern, to string
var ext []string
var negate bool
args := c.RemainingArgs()
@ -111,7 +112,14 @@ func rewriteParse(c *caddy.Controller) ([]httpserver.HandlerConfig, error) {
// the only unhandled case is 2 and above
default:
rule = NewSimpleRule(args[0], strings.Join(args[1:], " "))
if args[0] == "not" {
negate = true
args = args[1:]
}
rule, err = NewSimpleRule(args[0], strings.Join(args[1:], " "), negate)
if err != nil {
return nil, err
}
rules = append(rules, rule)
}

View file

@ -50,6 +50,19 @@ func TestSetup(t *testing.T) {
}
}
// newSimpleRule is convenience test function for SimpleRule.
func newSimpleRule(t *testing.T, from, to string, negate ...bool) Rule {
var n bool
if len(negate) > 0 {
n = negate[0]
}
rule, err := NewSimpleRule(from, to, n)
if err != nil {
t.Fatal(err)
}
return rule
}
func TestRewriteParse(t *testing.T) {
simpleTests := []struct {
input string
@ -57,17 +70,20 @@ func TestRewriteParse(t *testing.T) {
expected []Rule
}{
{`rewrite /from /to`, false, []Rule{
SimpleRule{From: "/from", To: "/to"},
newSimpleRule(t, "/from", "/to"),
}},
{`rewrite /from /to
rewrite a b`, false, []Rule{
SimpleRule{From: "/from", To: "/to"},
SimpleRule{From: "a", To: "b"},
newSimpleRule(t, "/from", "/to"),
newSimpleRule(t, "a", "b"),
}},
{`rewrite a`, true, []Rule{}},
{`rewrite`, true, []Rule{}},
{`rewrite a b c`, false, []Rule{
SimpleRule{From: "a", To: "b c"},
newSimpleRule(t, "a", "b c"),
}},
{`rewrite not a b c`, false, []Rule{
newSimpleRule(t, "a", "b c", true),
}},
}
@ -88,17 +104,22 @@ func TestRewriteParse(t *testing.T) {
}
for j, e := range test.expected {
actualRule := actual[j].(SimpleRule)
expectedRule := e.(SimpleRule)
actualRule := actual[j].(*SimpleRule)
expectedRule := e.(*SimpleRule)
if actualRule.From != expectedRule.From {
if actualRule.Regexp.String() != expectedRule.Regexp.String() {
t.Errorf("Test %d, rule %d: Expected From=%s, got %s",
i, j, expectedRule.From, actualRule.From)
i, j, expectedRule.Regexp.String(), actualRule.Regexp.String())
}
if actualRule.To != expectedRule.To {
t.Errorf("Test %d, rule %d: Expected To=%s, got %s",
i, j, expectedRule.To, actualRule.To)
i, j, expectedRule.Regexp.String(), actualRule.Regexp.String())
}
if actualRule.Negate != expectedRule.Negate {
t.Errorf("Test %d, rule %d: Expected Negate=%v, got %v",
i, j, expectedRule.Negate, actualRule.Negate)
}
}
}

View file

@ -265,21 +265,21 @@ func fillCertFromLeaf(cert *Certificate, tlsCert tls.Certificate) error {
return err
}
if leaf.Subject.CommonName != "" {
if leaf.Subject.CommonName != "" { // TODO: CommonName is deprecated
cert.Names = []string{strings.ToLower(leaf.Subject.CommonName)}
}
for _, name := range leaf.DNSNames {
if name != leaf.Subject.CommonName {
if name != leaf.Subject.CommonName { // TODO: CommonName is deprecated
cert.Names = append(cert.Names, strings.ToLower(name))
}
}
for _, ip := range leaf.IPAddresses {
if ipStr := ip.String(); ipStr != leaf.Subject.CommonName {
if ipStr := ip.String(); ipStr != leaf.Subject.CommonName { // TODO: CommonName is deprecated
cert.Names = append(cert.Names, strings.ToLower(ipStr))
}
}
for _, email := range leaf.EmailAddresses {
if email != leaf.Subject.CommonName {
if email != leaf.Subject.CommonName { // TODO: CommonName is deprecated
cert.Names = append(cert.Names, strings.ToLower(email))
}
}

View file

@ -43,10 +43,11 @@ func TestUnexportedGetCertificate(t *testing.T) {
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
}
// When no certificate matches and SNI is provided, return no certificate (should be TLS alert)
if cert, matched, defaulted := cfg.getCertificate("nomatch"); matched || defaulted {
t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
}
// TODO: Re-implement this behavior when I'm not in the middle of upgrading for ACMEv2 support. :) (it was reverted in #2037)
// // When no certificate matches and SNI is provided, return no certificate (should be TLS alert)
// if cert, matched, defaulted := cfg.getCertificate("nomatch"); matched || defaulted {
// t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
// }
// When no certificate matches and SNI is NOT provided, a random is returned
if cert, matched, defaulted := cfg.getCertificate(""); matched || !defaulted {

View file

@ -27,7 +27,7 @@ import (
"github.com/mholt/caddy"
"github.com/mholt/caddy/telemetry"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acmev2"
)
// acmeMu ensures that only one ACME challenge occurs at a time.
@ -90,27 +90,22 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
// If not registered, the user must register an account with the CA
// and agree to terms
if leUser.Registration == nil {
reg, err := client.Register()
if allowPrompts { // can't prompt a user who isn't there
termsURL := client.GetToSURL()
if !Agreed && termsURL != "" {
Agreed = askUserAgreement(client.GetToSURL())
}
if !Agreed && termsURL != "" {
return nil, errors.New("user must agree to CA terms (use -agree flag)")
}
}
reg, err := client.Register(Agreed)
if err != nil {
return nil, errors.New("registration error: " + err.Error())
}
leUser.Registration = reg
if allowPrompts { // can't prompt a user who isn't there
if !Agreed && reg.TosURL == "" {
Agreed = promptUserAgreement(saURL, false) // TODO - latest URL
}
if !Agreed && reg.TosURL == "" {
return nil, errors.New("user must agree to terms")
}
}
err = client.AgreeToTOS()
if err != nil {
saveUser(storage, leUser) // Might as well try, right?
return nil, errors.New("error agreeing to terms: " + err.Error())
}
// save user to the file system
err = saveUser(storage, leUser)
if err != nil {
@ -137,38 +132,57 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
useHTTPPort = DefaultHTTPAlternatePort
}
// TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
// See which port TLS-SNI challenges will be accomplished on
useTLSSNIPort := TLSSNIChallengePort
if config.AltTLSSNIPort != "" {
useTLSSNIPort = config.AltTLSSNIPort
}
// Always respect user's bind preferences by using config.ListenHost.
// NOTE(Sep'16): At time of writing, SetHTTPAddress() and SetTLSAddress()
// must be called before SetChallengeProvider(), since they reset the
// challenge provider back to the default one!
err := c.acmeClient.SetHTTPAddress(net.JoinHostPort(config.ListenHost, useHTTPPort))
if err != nil {
return nil, err
}
err = c.acmeClient.SetTLSAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort))
if err != nil {
return nil, err
// useTLSSNIPort := TLSSNIChallengePort
// if config.AltTLSSNIPort != "" {
// useTLSSNIPort = config.AltTLSSNIPort
// }
// err := c.acmeClient.SetTLSAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort))
// if err != nil {
// return nil, err
// }
// if using file storage, we can distribute the HTTP challenge across
// all instances sharing the acme folder; either way, we must still set
// the address for the default HTTP provider server
var useDistributedHTTPSolver bool
if storage, err := c.config.StorageFor(c.config.CAUrl); err == nil {
if _, ok := storage.(*FileStorage); ok {
useDistributedHTTPSolver = true
}
}
if useDistributedHTTPSolver {
c.acmeClient.SetChallengeProvider(acme.HTTP01, distributedHTTPSolver{
// being careful to respect user's listener bind preferences
httpProviderServer: acme.NewHTTPProviderServer(config.ListenHost, useHTTPPort),
})
} else {
// Always respect user's bind preferences by using config.ListenHost.
// NOTE(Sep'16): At time of writing, SetHTTPAddress() and SetTLSAddress()
// must be called before SetChallengeProvider() (see above), since they reset
// the challenge provider back to the default one! (still true in March 2018)
err := c.acmeClient.SetHTTPAddress(net.JoinHostPort(config.ListenHost, useHTTPPort))
if err != nil {
return nil, err
}
}
// TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
// See if TLS challenge needs to be handled by our own facilities
if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) {
c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSNISolver{certCache: config.certCache})
}
// if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) {
// c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSNISolver{certCache: config.certCache})
// }
// Disable any challenges that should not be used
var disabledChallenges []acme.Challenge
if DisableHTTPChallenge {
disabledChallenges = append(disabledChallenges, acme.HTTP01)
}
if DisableTLSSNIChallenge {
disabledChallenges = append(disabledChallenges, acme.TLSSNI01)
}
// TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
// if DisableTLSSNIChallenge {
// disabledChallenges = append(disabledChallenges, acme.TLSSNI01)
// }
if len(disabledChallenges) > 0 {
c.acmeClient.ExcludeChallenges(disabledChallenges)
}
@ -189,7 +203,9 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
}
// Use the DNS challenge exclusively
c.acmeClient.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.TLSSNI01})
// TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
// c.acmeClient.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.TLSSNI01})
c.acmeClient.ExcludeChallenges([]acme.Challenge{acme.HTTP01})
c.acmeClient.SetChallengeProvider(acme.DNS01, prov)
}
@ -222,41 +238,31 @@ func (c *ACMEClient) Obtain(name string) error {
}
}()
Attempts:
for attempts := 0; attempts < 2; attempts++ {
namesObtaining.Add([]string{name})
acmeMu.Lock()
certificate, failures := c.acmeClient.ObtainCertificate([]string{name}, true, nil, c.config.MustStaple)
certificate, err := c.acmeClient.ObtainCertificate([]string{name}, true, nil, c.config.MustStaple)
acmeMu.Unlock()
namesObtaining.Remove([]string{name})
if len(failures) > 0 {
// Error - try to fix it or report it to the user and abort
var errMsg string // we'll combine all the failures into a single error message
var promptedForAgreement bool // only prompt user for agreement at most once
for errDomain, obtainErr := range failures {
if obtainErr == nil {
continue
}
if tosErr, ok := obtainErr.(acme.TOSError); ok {
// Terms of Service agreement error; we can probably deal with this
if !Agreed && !promptedForAgreement && c.AllowPrompts {
Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
promptedForAgreement = true
}
if Agreed || !c.AllowPrompts {
err := c.acmeClient.AgreeToTOS()
if err != nil {
return errors.New("error agreeing to updated terms: " + err.Error())
}
continue Attempts
if err != nil {
// for a certain kind of error, we can enumerate the error per-domain
if failures, ok := err.(acme.ObtainError); ok && len(failures) > 0 {
var errMsg string // combine all the failures into a single error message
for errDomain, obtainErr := range failures {
if obtainErr == nil {
continue
}
errMsg += fmt.Sprintf("[%s] failed to get certificate: %v\n", errDomain, obtainErr)
}
// If user did not agree or it was any other kind of error, just append to the list of errors
errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
return errors.New(errMsg)
}
return errors.New(errMsg)
return fmt.Errorf("[%s] failed to obtain certificate: %v", name, err)
}
// double-check that we actually got a certificate, in case there's a bug upstream (see issue #2121)
if certificate.Domain == "" || certificate.Certificate == nil {
return errors.New("returned certificate was empty; probably an unchecked error obtaining it")
}
// Success - immediately save the certificate resource
@ -315,23 +321,20 @@ func (c *ACMEClient) Renew(name string) error {
acmeMu.Unlock()
namesObtaining.Remove([]string{name})
if err == nil {
success = true
break
}
// If the legal terms were updated and need to be
// agreed to again, we can handle that.
if _, ok := err.(acme.TOSError); ok {
err := c.acmeClient.AgreeToTOS()
if err != nil {
return err
// double-check that we actually got a certificate; check a couple fields
// TODO: This is a temporary workaround for what I think is a bug in the acmev2 package (March 2018)
// but it might not hurt to keep this extra check in place
if newCertMeta.Domain == "" || newCertMeta.Certificate == nil {
err = errors.New("returned certificate was empty; probably an unchecked error renewing it")
} else {
success = true
break
}
continue
}
// For any other kind of error, wait 10s and try again.
// wait a little bit and try again
wait := 10 * time.Second
log.Printf("[ERROR] Renewing: %v; trying again in %s", err, wait)
log.Printf("[ERROR] Renewing [%v]: %v; trying again in %s", name, err, wait)
time.Sleep(wait)
}

View file

@ -25,7 +25,7 @@ import (
"github.com/klauspost/cpuid"
"github.com/mholt/caddy"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acmev2"
)
// Config describes how TLS should be configured and used.
@ -190,10 +190,15 @@ func NewConfig(inst *caddy.Instance) *Config {
// it does not load them into memory. If allowPrompts is true,
// the user may be shown a prompt.
func (c *Config) ObtainCert(name string, allowPrompts bool) error {
if !c.Managed || !HostQualifies(name) {
skip, err := c.preObtainOrRenewChecks(name, allowPrompts)
if err != nil {
return err
}
if skip {
return nil
}
// we expect this to be a new (non-existent) site
storage, err := c.StorageFor(c.CAUrl)
if err != nil {
return err
@ -205,9 +210,6 @@ func (c *Config) ObtainCert(name string, allowPrompts bool) error {
if siteExists {
return nil
}
if c.ACMEEmail == "" {
c.ACMEEmail = getEmail(storage, allowPrompts)
}
client, err := newACMEClient(c, allowPrompts)
if err != nil {
@ -219,6 +221,14 @@ func (c *Config) ObtainCert(name string, allowPrompts bool) error {
// RenewCert renews the certificate for name using c. It stows the
// renewed certificate and its assets in storage if successful.
func (c *Config) RenewCert(name string, allowPrompts bool) error {
skip, err := c.preObtainOrRenewChecks(name, allowPrompts)
if err != nil {
return err
}
if skip {
return nil
}
client, err := newACMEClient(c, allowPrompts)
if err != nil {
return err
@ -226,6 +236,33 @@ func (c *Config) RenewCert(name string, allowPrompts bool) error {
return client.Renew(name)
}
// preObtainOrRenewChecks perform a few simple checks before
// obtaining or renewing a certificate with ACME, and returns
// whether this name should be skipped (like if it's not
// managed TLS) as well as any error. It ensures that the
// config is Managed, that the name qualifies for a certificate,
// and that an email address is available.
func (c *Config) preObtainOrRenewChecks(name string, allowPrompts bool) (bool, error) {
if !c.Managed || !HostQualifies(name) {
return true, nil
}
// wildcard certificates require DNS challenge (as of March 2018)
if strings.Contains(name, "*") && c.DNSProvider == "" {
return false, fmt.Errorf("wildcard domain name (%s) requires DNS challenge; use dns subdirective to configure it", name)
}
if c.ACMEEmail == "" {
var err error
c.ACMEEmail, err = getEmail(c, allowPrompts)
if err != nil {
return false, err
}
}
return false, nil
}
// StorageFor obtains a TLS Storage instance for the given CA URL which should
// be unique for every different ACME CA. If a StorageCreator is set on this
// Config, it will be used. Otherwise the default file storage implementation
@ -476,6 +513,14 @@ func assertConfigsCompatible(cfg1, cfg2 *Config) error {
if c1.ClientAuth != c2.ClientAuth {
return fmt.Errorf("client authentication policy mismatch")
}
if c1.ClientAuth != tls.NoClientCert && c2.ClientAuth != tls.NoClientCert && c1.ClientCAs != c2.ClientCAs {
// Two hosts defined on the same listener are not compatible if they
// have ClientAuth enabled, because there's no guarantee beyond the
// hostname which config will be used (because SNI only has server name).
// To prevent clients from bypassing authentication, require that
// ClientAuth be configured in an unambiguous manner.
return fmt.Errorf("multiple hosts requiring client authentication ambiguously configured")
}
return nil
}
@ -511,7 +556,7 @@ func SetDefaultTLSParams(config *Config) {
// Set default protocol min and max versions - must balance compatibility and security
if config.ProtocolMinVersion == 0 {
config.ProtocolMinVersion = tls.VersionTLS11
config.ProtocolMinVersion = tls.VersionTLS12
}
if config.ProtocolMaxVersion == 0 {
config.ProtocolMaxVersion = tls.VersionTLS12
@ -532,7 +577,8 @@ var supportedKeyTypes = map[string]acme.KeyType{
// Map of supported protocols.
// HTTP/2 only supports TLS 1.2 and higher.
var supportedProtocols = map[string]uint16{
// If updating this map, also update tlsProtocolStringToMap in caddyhttp/fastcgi/fastcgi.go
var SupportedProtocols = map[string]uint16{
"tls1.0": tls.VersionTLS10,
"tls1.1": tls.VersionTLS11,
"tls1.2": tls.VersionTLS12,
@ -548,7 +594,7 @@ var supportedProtocols = map[string]uint16{
// it is always added (even though it is not technically a cipher suite).
//
// This map, like any map, is NOT ORDERED. Do not range over this map.
var supportedCiphersMap = map[string]uint16{
var SupportedCiphersMap = map[string]uint16{
"ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
"ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
"ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,

View file

@ -35,13 +35,14 @@ import (
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"golang.org/x/crypto/ocsp"
"github.com/mholt/caddy"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acmev2"
)
// loadPrivateKey loads a PEM-encoded ECC/RSA private key from an array of bytes.
@ -106,7 +107,8 @@ func stapleOCSP(cert *Certificate, pemBundle []byte) error {
// TODO: Use Storage interface instead of disk directly
var ocspFileNamePrefix string
if len(cert.Names) > 0 {
ocspFileNamePrefix = cert.Names[0] + "-"
firstName := strings.Replace(cert.Names[0], "*", "wildcard_", -1)
ocspFileNamePrefix = firstName + "-"
}
ocspFileName := ocspFileNamePrefix + fastHash(pemBundle)
ocspCachePath := filepath.Join(ocspFolder, ocspFileName)
@ -216,10 +218,13 @@ func makeSelfSignedCert(config *Config) error {
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
var names []string
if ip := net.ParseIP(config.Hostname); ip != nil {
names = append(names, strings.ToLower(ip.String()))
cert.IPAddresses = append(cert.IPAddresses, ip)
} else {
cert.DNSNames = append(cert.DNSNames, config.Hostname)
names = append(names, strings.ToLower(config.Hostname))
cert.DNSNames = append(cert.DNSNames, strings.ToLower(config.Hostname))
}
publicKey := func(privKey interface{}) interface{} {
@ -245,7 +250,7 @@ func makeSelfSignedCert(config *Config) error {
PrivateKey: privKey,
Leaf: cert,
},
Names: cert.DNSNames,
Names: names,
NotAfter: cert.NotAfter,
Hash: hashCertificateChain(chain),
})

View file

@ -30,14 +30,14 @@ func init() {
RegisterStorageProvider("file", NewFileStorage)
}
// storageBasePath is the root path in which all TLS/ACME assets are
// stored. Do not change this value during the lifetime of the program.
var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme")
// NewFileStorage is a StorageConstructor function that creates a new
// Storage instance backed by the local disk. The resulting Storage
// instance is guaranteed to be non-nil if there is no error.
func NewFileStorage(caURL *url.URL) (Storage, error) {
// storageBasePath is the root path in which all TLS/ACME assets are
// stored. Do not change this value during the lifetime of the program.
storageBasePath := filepath.Join(caddy.AssetsPath(), "acme")
storage := &FileStorage{Path: filepath.Join(storageBasePath, caURL.Host)}
storage.Locker = &fileStorageLock{caURL: caURL.Host, storage: storage}
return storage, nil
@ -58,25 +58,25 @@ func (s *FileStorage) sites() string {
// site returns the path to the folder containing assets for domain.
func (s *FileStorage) site(domain string) string {
domain = strings.ToLower(domain)
domain = fileSafe(domain)
return filepath.Join(s.sites(), domain)
}
// siteCertFile returns the path to the certificate file for domain.
func (s *FileStorage) siteCertFile(domain string) string {
domain = strings.ToLower(domain)
domain = fileSafe(domain)
return filepath.Join(s.site(domain), domain+".crt")
}
// siteKeyFile returns the path to domain's private key file.
func (s *FileStorage) siteKeyFile(domain string) string {
domain = strings.ToLower(domain)
domain = fileSafe(domain)
return filepath.Join(s.site(domain), domain+".key")
}
// siteMetaFile returns the path to the domain's asset metadata file.
func (s *FileStorage) siteMetaFile(domain string) string {
domain = strings.ToLower(domain)
domain = fileSafe(domain)
return filepath.Join(s.site(domain), domain+".json")
}
@ -90,7 +90,7 @@ func (s *FileStorage) user(email string) string {
if email == "" {
email = emptyEmail
}
email = strings.ToLower(email)
email = fileSafe(email)
return filepath.Join(s.users(), email)
}
@ -117,6 +117,7 @@ func (s *FileStorage) userRegFile(email string) string {
if fileName == "" {
fileName = "registration"
}
fileName = fileSafe(fileName)
return filepath.Join(s.user(email), fileName+".json")
}
@ -131,6 +132,7 @@ func (s *FileStorage) userKeyFile(email string) string {
if fileName == "" {
fileName = "private"
}
fileName = fileSafe(fileName)
return filepath.Join(s.user(email), fileName+".key")
}
@ -274,3 +276,29 @@ func (s *FileStorage) MostRecentUserEmail() string {
}
return ""
}
// fileSafe standardizes and sanitizes str for use in a file path.
func fileSafe(str string) string {
str = strings.ToLower(str)
str = strings.TrimSpace(str)
repl := strings.NewReplacer("..", "",
"/", "",
"\\", "",
// TODO: Consider also replacing "@" with "_at_" (but migrate existing accounts...)
"+", "_plus_",
"%", "",
"$", "",
"`", "",
"~", "",
":", "",
";", "",
"=", "",
"!", "",
"#", "",
"&", "",
"|", "",
"\"", "",
"'", "",
"*", "wildcard_")
return repl.Replace(str)
}

View file

@ -14,7 +14,71 @@
package caddytls
import (
"path/filepath"
"testing"
)
// *********************************** NOTE ********************************
// Due to circular package dependencies with the storagetest sub package and
// the fact that we want to use that harness to test file storage, the tests
// for file storage are done in the storagetest package.
// the fact that we want to use that harness to test file storage, most of
// the tests for file storage are done in the storagetest package.
func TestPathBuilders(t *testing.T) {
fs := FileStorage{Path: "test"}
for i, testcase := range []struct {
in, folder, certFile, keyFile, metaFile string
}{
{
in: "example.com",
folder: filepath.Join("test", "sites", "example.com"),
certFile: filepath.Join("test", "sites", "example.com", "example.com.crt"),
keyFile: filepath.Join("test", "sites", "example.com", "example.com.key"),
metaFile: filepath.Join("test", "sites", "example.com", "example.com.json"),
},
{
in: "*.example.com",
folder: filepath.Join("test", "sites", "wildcard_.example.com"),
certFile: filepath.Join("test", "sites", "wildcard_.example.com", "wildcard_.example.com.crt"),
keyFile: filepath.Join("test", "sites", "wildcard_.example.com", "wildcard_.example.com.key"),
metaFile: filepath.Join("test", "sites", "wildcard_.example.com", "wildcard_.example.com.json"),
},
{
// prevent directory traversal! very important, esp. with on-demand TLS
// see issue #2092
in: "a/../../../foo",
folder: filepath.Join("test", "sites", "afoo"),
certFile: filepath.Join("test", "sites", "afoo", "afoo.crt"),
keyFile: filepath.Join("test", "sites", "afoo", "afoo.key"),
metaFile: filepath.Join("test", "sites", "afoo", "afoo.json"),
},
{
in: "b\\..\\..\\..\\foo",
folder: filepath.Join("test", "sites", "bfoo"),
certFile: filepath.Join("test", "sites", "bfoo", "bfoo.crt"),
keyFile: filepath.Join("test", "sites", "bfoo", "bfoo.key"),
metaFile: filepath.Join("test", "sites", "bfoo", "bfoo.json"),
},
{
in: "c/foo",
folder: filepath.Join("test", "sites", "cfoo"),
certFile: filepath.Join("test", "sites", "cfoo", "cfoo.crt"),
keyFile: filepath.Join("test", "sites", "cfoo", "cfoo.key"),
metaFile: filepath.Join("test", "sites", "cfoo", "cfoo.json"),
},
} {
if actual := fs.site(testcase.in); actual != testcase.folder {
t.Errorf("Test %d: site folder: Expected '%s' but got '%s'", i, testcase.folder, actual)
}
if actual := fs.siteCertFile(testcase.in); actual != testcase.certFile {
t.Errorf("Test %d: site cert file: Expected '%s' but got '%s'", i, testcase.certFile, actual)
}
if actual := fs.siteKeyFile(testcase.in); actual != testcase.keyFile {
t.Errorf("Test %d: site key file: Expected '%s' but got '%s'", i, testcase.keyFile, actual)
}
if actual := fs.siteMetaFile(testcase.in); actual != testcase.metaFile {
t.Errorf("Test %d: site meta file: Expected '%s' but got '%s'", i, testcase.metaFile, actual)
}
}
}

View file

@ -91,7 +91,20 @@ func (s *fileStorageLock) Unlock(name string) error {
if !ok {
return fmt.Errorf("FileStorage: no lock to release for %s", name)
}
// remove lock file
os.Remove(fw.filename)
// if parent folder is now empty, remove it too to keep it tidy
lockParentFolder := s.storage.site(name)
dir, err := os.Open(lockParentFolder)
if err == nil {
items, _ := dir.Readdirnames(3) // OK to ignore error here
if len(items) == 0 {
os.Remove(lockParentFolder)
}
dir.Close()
}
fw.wg.Done()
delete(fileStorageNameLocks, s.caURL+name)
return nil

View file

@ -61,10 +61,9 @@ func (cg configGroup) getConfig(name string) *Config {
}
}
// try a config that serves all names (this
// is basically the same as a config defined
// for "*" -- I think -- but the above loop
// doesn't try an empty string)
// try a config that serves all names (the above
// loop doesn't try empty string; for hosts defined
// with only a port, for instance, like ":443")
if config, ok := cg[""]; ok {
return config
}
@ -190,17 +189,19 @@ func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defau
return
}
// if nothing matches and SNI was not provided, use a random
// certificate; at least there's a chance this older client
// can connect, and in the future we won't need this provision
// (if SNI is present, it's probably best to just raise a TLS
// alert by not serving a certificate)
if name == "" {
for _, certKey := range cfg.Certificates {
defaulted = true
cert = cfg.certCache.cache[certKey]
return
}
// if nothing matches, use a random certificate
// TODO: This is not my favorite behavior; I would rather serve
// no certificate if SNI is provided and cause a TLS alert, than
// serve the wrong certificate (but sometimes the 'wrong' cert
// is what is wanted, but in those cases I would prefer that the
// site owner explicitly configure a "default" certificate).
// (See issue 2035; any change to this behavior must account for
// hosts defined like ":443" or "0.0.0.0:443" where the hostname
// is empty or a catch-all IP or something.)
for _, certKey := range cfg.Certificates {
cert = cfg.certCache.cache[certKey]
defaulted = true
return
}
return

View file

@ -27,7 +27,7 @@ func TestGetCertificate(t *testing.T) {
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
helloNoSNI := &tls.ClientHelloInfo{}
helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"}
// helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"} // TODO (see below)
// When cache is empty
if cert, err := cfg.GetCertificate(hello); err == nil {
@ -69,8 +69,9 @@ func TestGetCertificate(t *testing.T) {
t.Errorf("Expected random cert with no matches, got: %v", cert)
}
// TODO: Re-implement this behavior (it was reverted in #2037)
// When no certificate matches, raise an alert
if _, err := cfg.GetCertificate(helloNoMatch); err == nil {
t.Errorf("Expected an error when no certificate matched the SNI, got: %v", err)
}
// if _, err := cfg.GetCertificate(helloNoMatch); err == nil {
// t.Errorf("Expected an error when no certificate matched the SNI, got: %v", err)
// }
}

View file

@ -16,12 +16,16 @@ package caddytls
import (
"crypto/tls"
"encoding/json"
"fmt"
"log"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strings"
"github.com/xenolf/lego/acmev2"
)
const challengeBasePath = "/.well-known/acme-challenge"
@ -38,6 +42,13 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost str
if DisableHTTPChallenge {
return false
}
// see if another instance started the HTTP challenge for this name
if tryDistributedChallengeSolver(w, r) {
return true
}
// otherwise, if we aren't getting the name, then ignore this challenge
if !namesObtaining.Has(r.Host) {
return false
}
@ -70,3 +81,40 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost str
return true
}
// tryDistributedChallengeSolver checks to see if this challenge
// request was initiated by another instance that shares file
// storage, and attempts to complete the challenge for it. It
// returns true if the challenge was handled; false otherwise.
func tryDistributedChallengeSolver(w http.ResponseWriter, r *http.Request) bool {
filePath := distributedHTTPSolver{}.challengeTokensPath(r.Host)
f, err := os.Open(filePath)
if err != nil {
if !os.IsNotExist(err) {
log.Printf("[ERROR][%s] Opening distributed challenge token file: %v", r.Host, err)
}
return false
}
defer f.Close()
var chalInfo challengeInfo
err = json.NewDecoder(f).Decode(&chalInfo)
if err != nil {
log.Printf("[ERROR][%s] Decoding challenge token file %s (corrupted?): %v", r.Host, filePath, err)
return false
}
// this part borrowed from xenolf/lego's built-in HTTP-01 challenge solver (March 2018)
challengeReqPath := acme.HTTP01ChallengePath(chalInfo.Token)
if r.URL.Path == challengeReqPath &&
strings.HasPrefix(r.Host, chalInfo.Domain) &&
r.Method == "GET" {
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte(chalInfo.KeyAuth))
r.Close = true
log.Printf("[INFO][%s] Served key authentication", chalInfo.Domain)
return true
}
return false
}

View file

@ -334,6 +334,7 @@ func DeleteOldStapleFiles() {
if err != nil {
log.Printf("[ERROR] Purging corrupt staple file %s: %v", stapleFile, err)
}
continue
}
if time.Now().After(resp.NextUpdate) {
// response has expired; delete it

View file

@ -107,19 +107,19 @@ func setupTLS(c *caddy.Controller) error {
case "protocols":
args := c.RemainingArgs()
if len(args) == 1 {
value, ok := supportedProtocols[strings.ToLower(args[0])]
value, ok := SupportedProtocols[strings.ToLower(args[0])]
if !ok {
return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[0])
}
config.ProtocolMinVersion, config.ProtocolMaxVersion = value, value
} else {
value, ok := supportedProtocols[strings.ToLower(args[0])]
value, ok := SupportedProtocols[strings.ToLower(args[0])]
if !ok {
return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[0])
}
config.ProtocolMinVersion = value
value, ok = supportedProtocols[strings.ToLower(args[1])]
value, ok = SupportedProtocols[strings.ToLower(args[1])]
if !ok {
return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[1])
}
@ -130,7 +130,7 @@ func setupTLS(c *caddy.Controller) error {
}
case "ciphers":
for c.NextArg() {
value, ok := supportedCiphersMap[strings.ToUpper(c.Val())]
value, ok := SupportedCiphersMap[strings.ToUpper(c.Val())]
if !ok {
return c.Errf("Wrong cipher name or cipher not supported: '%s'", c.Val())
}
@ -210,8 +210,21 @@ func setupTLS(c *caddy.Controller) error {
}
case "must_staple":
config.MustStaple = true
case "wildcard":
if !HostQualifies(config.Hostname) {
return c.Errf("Hostname '%s' does not qualify for managed TLS, so cannot manage wildcard certificate for it", config.Hostname)
}
if strings.Contains(config.Hostname, "*") {
return c.Errf("Cannot convert domain name '%s' to a valid wildcard: already has a wildcard label", config.Hostname)
}
parts := strings.Split(config.Hostname, ".")
if len(parts) < 3 {
return c.Errf("Cannot convert domain name '%s' to a valid wildcard: too few labels", config.Hostname)
}
parts[0] = "*"
config.Hostname = strings.Join(parts, ".")
default:
return c.Errf("Unknown keyword '%s'", c.Val())
return c.Errf("Unknown subdirective '%s'", c.Val())
}
}

View file

@ -22,7 +22,7 @@ import (
"testing"
"github.com/mholt/caddy"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acmev2"
)
func TestMain(m *testing.M) {
@ -67,8 +67,8 @@ func TestSetupParseBasic(t *testing.T) {
}
// Security defaults
if cfg.ProtocolMinVersion != tls.VersionTLS11 {
t.Errorf("Expected 'tls1.1 (0x0302)' as ProtocolMinVersion, got %#v", cfg.ProtocolMinVersion)
if cfg.ProtocolMinVersion != tls.VersionTLS12 {
t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMinVersion, got %#v", cfg.ProtocolMinVersion)
}
if cfg.ProtocolMaxVersion != tls.VersionTLS12 {
t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMaxVersion, got %v", cfg.ProtocolMaxVersion)

View file

@ -58,7 +58,8 @@ type Locker interface {
// successfully obtained the lock (no Waiter value was returned)
// should call this method, and it should be called only after
// the obtain/renew and store are finished, even if there was
// an error (or a timeout).
// an error (or a timeout). Unlock should also clean up any
// unused resources allocated during TryLock.
Unlock(name string) error
}

View file

@ -30,26 +30,35 @@ package caddytls
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net"
"os"
"path/filepath"
"strings"
"github.com/mholt/caddy"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acmev2"
)
// HostQualifies returns true if the hostname alone
// appears eligible for automatic HTTPS. For example,
// appears eligible for automatic HTTPS. For example:
// localhost, empty hostname, and IP addresses are
// not eligible because we cannot obtain certificates
// for those names.
// for those names. Wildcard names are allowed, as long
// as they conform to CABF requirements (only one wildcard
// label, and it must be the left-most label).
func HostQualifies(hostname string) bool {
return hostname != "localhost" && // localhost is ineligible
// hostname must not be empty
strings.TrimSpace(hostname) != "" &&
// must not contain wildcard (*) characters (until CA supports it)
!strings.Contains(hostname, "*") &&
// only one wildcard label allowed, and it must be left-most
(!strings.Contains(hostname, "*") ||
(strings.Count(hostname, "*") == 1 &&
strings.HasPrefix(hostname, "*."))) &&
// must not start or end with a dot
!strings.HasPrefix(hostname, ".") &&
@ -88,39 +97,125 @@ func Revoke(host string) error {
return client.Revoke(host)
}
// tlsSNISolver is a type that can solve TLS-SNI challenges using
// an existing listener and our custom, in-memory certificate cache.
type tlsSNISolver struct {
certCache *certificateCache
// TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
// // tlsSNISolver is a type that can solve TLS-SNI challenges using
// // an existing listener and our custom, in-memory certificate cache.
// type tlsSNISolver struct {
// certCache *certificateCache
// }
// // Present adds the challenge certificate to the cache.
// func (s tlsSNISolver) Present(domain, token, keyAuth string) error {
// cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
// if err != nil {
// return err
// }
// certHash := hashCertificateChain(cert.Certificate)
// s.certCache.Lock()
// s.certCache.cache[acmeDomain] = Certificate{
// Certificate: cert,
// Names: []string{acmeDomain},
// Hash: certHash, // perhaps not necesssary
// }
// s.certCache.Unlock()
// return nil
// }
// // CleanUp removes the challenge certificate from the cache.
// func (s tlsSNISolver) CleanUp(domain, token, keyAuth string) error {
// _, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
// if err != nil {
// return err
// }
// s.certCache.Lock()
// delete(s.certCache.cache, acmeDomain)
// s.certCache.Unlock()
// return nil
// }
// distributedHTTPSolver allows the HTTP-01 challenge to be solved by
// an instance other than the one which initiated it. This is useful
// behind load balancers or in other cluster/fleet configurations.
// The only requirement is that this (the initiating) instance share
// the $CADDYPATH/acme folder with the instance that will complete
// the challenge. Mounting the folder locally should be sufficient.
//
// Obviously, the instance which completes the challenge must be
// serving on the HTTPChallengePort to receive and handle the request.
// The HTTP server which receives it must check if a file exists, e.g.:
// $CADDYPATH/acme/challenge_tokens/example.com.json, and if so,
// decode it and use it to serve up the correct response. Caddy's HTTP
// server does this by default.
//
// So as long as the folder is shared, this will just work. There are
// no other requirements. The instances may be on other machines or
// even other networks, as long as they share the folder as part of
// the local file system.
//
// This solver works by persisting the token and keyauth information
// to disk in the shared folder when the authorization is presented,
// and then deletes it when it is cleaned up.
type distributedHTTPSolver struct {
// The distributed HTTPS solver only works if an instance (either
// this one or another one) is already listening and serving on the
// HTTPChallengePort. If not -- for example: if this is the only
// instance, and it is just starting up and hasn't started serving
// yet -- then we still need a listener open with an HTTP server
// to handle the challenge request. Set this field to have the
// standard HTTPProviderServer open its listener for the duration
// of the challenge. Make sure to configure its listen address
// correctly.
httpProviderServer *acme.HTTPProviderServer
}
type challengeInfo struct {
Domain, Token, KeyAuth string
}
// Present adds the challenge certificate to the cache.
func (s tlsSNISolver) Present(domain, token, keyAuth string) error {
cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
func (dhs distributedHTTPSolver) Present(domain, token, keyAuth string) error {
if dhs.httpProviderServer != nil {
err := dhs.httpProviderServer.Present(domain, token, keyAuth)
if err != nil {
return fmt.Errorf("presenting with standard HTTP provider server: %v", err)
}
}
err := os.MkdirAll(dhs.challengeTokensBasePath(), 0755)
if err != nil {
return err
}
certHash := hashCertificateChain(cert.Certificate)
s.certCache.Lock()
s.certCache.cache[acmeDomain] = Certificate{
Certificate: cert,
Names: []string{acmeDomain},
Hash: certHash, // perhaps not necesssary
infoBytes, err := json.Marshal(challengeInfo{
Domain: domain,
Token: token,
KeyAuth: keyAuth,
})
if err != nil {
return err
}
s.certCache.Unlock()
return nil
return ioutil.WriteFile(dhs.challengeTokensPath(domain), infoBytes, 0644)
}
// CleanUp removes the challenge certificate from the cache.
func (s tlsSNISolver) CleanUp(domain, token, keyAuth string) error {
_, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
if err != nil {
return err
func (dhs distributedHTTPSolver) CleanUp(domain, token, keyAuth string) error {
if dhs.httpProviderServer != nil {
err := dhs.httpProviderServer.CleanUp(domain, token, keyAuth)
if err != nil {
log.Printf("[ERROR] Cleaning up standard HTTP provider server: %v", err)
}
}
s.certCache.Lock()
delete(s.certCache.cache, acmeDomain)
s.certCache.Unlock()
return nil
return os.Remove(dhs.challengeTokensPath(domain))
}
func (dhs distributedHTTPSolver) challengeTokensPath(domain string) string {
domainFile := strings.Replace(strings.ToLower(domain), "*", "wildcard_", -1)
return filepath.Join(dhs.challengeTokensBasePath(), domainFile+".json")
}
func (dhs distributedHTTPSolver) challengeTokensBasePath() string {
return filepath.Join(caddy.AssetsPath(), "acme", "challenge_tokens")
}
// ConfigHolder is any type that has a Config; it presumably is

View file

@ -18,7 +18,7 @@ import (
"os"
"testing"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acmev2"
)
func TestHostQualifies(t *testing.T) {
@ -37,7 +37,10 @@ func TestHostQualifies(t *testing.T) {
{"0.0.0.0", false},
{"", false},
{" ", false},
{"*.example.com", false},
{"*.example.com", true},
{"*.*.example.com", false},
{"sub.*.example.com", false},
{"*sub.example.com", false},
{".com", false},
{"example.com.", false},
{"localhost", false},
@ -77,7 +80,10 @@ func TestQualifiesForManagedTLS(t *testing.T) {
{holder{host: "localhost", cfg: new(Config)}, false},
{holder{host: "123.44.3.21", cfg: new(Config)}, false},
{holder{host: "example.com", cfg: new(Config)}, true},
{holder{host: "*.example.com", cfg: new(Config)}, false},
{holder{host: "*.example.com", cfg: new(Config)}, true},
{holder{host: "*.*.example.com", cfg: new(Config)}, false},
{holder{host: "*sub.example.com", cfg: new(Config)}, false},
{holder{host: "sub.*.example.com", cfg: new(Config)}, false},
{holder{host: "example.com", cfg: &Config{Manual: true}}, false},
{holder{host: "example.com", cfg: &Config{ACMEEmail: "off"}}, false},
{holder{host: "example.com", cfg: &Config{ACMEEmail: "foo@bar.com"}}, true},

View file

@ -27,7 +27,7 @@ import (
"os"
"strings"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acmev2"
)
// User represents a Let's Encrypt user account.
@ -67,43 +67,82 @@ func newUser(email string) (User, error) {
return user, nil
}
// getEmail does everything it can to obtain an email
// address from the user within the scope of storage
// to use for ACME TLS. If it cannot get an email
// address, it returns empty string. (It will warn the
// user of the consequences of an empty email.) This
// function MAY prompt the user for input. If userPresent
// is false, the operator will NOT be prompted and an
// empty email may be returned.
func getEmail(storage Storage, userPresent bool) string {
// getEmail does everything it can to obtain an email address
// from the user within the scope of memory and storage to use
// for ACME TLS. If it cannot get an email address, it returns
// empty string. (If user is present, it will warn the user of
// the consequences of an empty email.) This function MAY prompt
// the user for input. If userPresent is false, the operator
// will NOT be prompted and an empty email may be returned.
// If the user is prompted, a new User will be created and
// stored in storage according to the email address they
// provided (which might be blank).
func getEmail(cfg *Config, userPresent bool) (string, error) {
storage, err := cfg.StorageFor(cfg.CAUrl)
if err != nil {
return "", err
}
// First try memory (command line flag or typed by user previously)
leEmail := DefaultEmail
// Then try to get most recent user email from storage
if leEmail == "" {
// Then try to get most recent user email
leEmail = storage.MostRecentUserEmail()
// Save for next time
DefaultEmail = leEmail
DefaultEmail = leEmail // save for next time
}
// Looks like there is no email address readily available,
// so we will have to ask the user if we can.
if leEmail == "" && userPresent {
// Alas, we must bother the user and ask for an email address;
// if they proceed they also agree to the SA.
reader := bufio.NewReader(stdin)
fmt.Println("\nYour sites will be served over HTTPS automatically using Let's Encrypt.")
fmt.Println("By continuing, you agree to the Let's Encrypt Subscriber Agreement at:")
fmt.Println(" " + saURL) // TODO: Show current SA link
fmt.Println("Please enter your email address so you can recover your account if needed.")
fmt.Println("You can leave it blank, but you'll lose the ability to recover your account.")
fmt.Print("Email address: ")
var err error
leEmail, err = reader.ReadString('\n')
// evidently, no User data was present in storage;
// thus we must make a new User so that we can get
// the Terms of Service URL via our ACME client, phew!
user, err := newUser("")
if err != nil {
return ""
return "", err
}
// get the agreement URL
agreementURL := agreementTestURL
if agreementURL == "" {
// we call acme.NewClient directly because newACMEClient
// would require that we already know the user's email
caURL := DefaultCAUrl
if cfg.CAUrl != "" {
caURL = cfg.CAUrl
}
tempClient, err := acme.NewClient(caURL, user, "")
if err != nil {
return "", fmt.Errorf("making ACME client to get ToS URL: %v", err)
}
agreementURL = tempClient.GetToSURL()
}
// prompt the user for an email address and terms agreement
reader := bufio.NewReader(stdin)
promptUserAgreement(agreementURL)
fmt.Println("Please enter your email address to signify agreement and to be notified")
fmt.Println("in case of issues. You can leave it blank, but we don't recommend it.")
fmt.Print(" Email address: ")
leEmail, err = reader.ReadString('\n')
if err != nil && err != io.EOF {
return "", fmt.Errorf("reading email address: %v", err)
}
leEmail = strings.TrimSpace(leEmail)
DefaultEmail = leEmail
Agreed = true
// save the new user to preserve this for next time
user.Email = leEmail
err = saveUser(storage, user)
if err != nil {
return "", err
}
}
return strings.ToLower(leEmail)
// lower-casing the email is important for consistency
return strings.ToLower(leEmail), nil
}
// getUser loads the user with the given email from disk
@ -154,18 +193,21 @@ func saveUser(storage Storage, user User) error {
return err
}
// promptUserAgreement prompts the user to agree to the agreement
// at agreementURL via stdin. If the agreement has changed, then pass
// true as the second argument. If this is the user's first time
// agreeing, pass false. It returns whether the user agreed or not.
func promptUserAgreement(agreementURL string, changed bool) bool {
if changed {
fmt.Printf("The Let's Encrypt Subscriber Agreement has changed:\n %s\n", agreementURL)
fmt.Print("Do you agree to the new terms? (y/n): ")
} else {
fmt.Printf("To continue, you must agree to the Let's Encrypt Subscriber Agreement:\n %s\n", agreementURL)
fmt.Print("Do you agree to the terms? (y/n): ")
}
// promptUserAgreement simply outputs the standard user
// agreement prompt with the given agreement URL.
// It outputs a newline after the message.
func promptUserAgreement(agreementURL string) {
const userAgreementPrompt = `Your sites will be served over HTTPS automatically using Let's Encrypt.
By continuing, you agree to the Let's Encrypt Subscriber Agreement at:`
fmt.Printf("\n\n%s\n %s\n", userAgreementPrompt, agreementURL)
}
// askUserAgreement prompts the user to agree to the agreement
// at the given agreement URL via stdin. It returns whether the
// user agreed or not.
func askUserAgreement(agreementURL string) bool {
promptUserAgreement(agreementURL)
fmt.Print("Do you agree to the terms? (y/n): ")
reader := bufio.NewReader(stdin)
answer, err := reader.ReadString('\n')
@ -177,14 +219,15 @@ func promptUserAgreement(agreementURL string, changed bool) bool {
return answer == "y" || answer == "yes"
}
// agreementTestURL is set during tests to skip requiring
// setting up an entire ACME CA endpoint.
var agreementTestURL string
// stdin is used to read the user's input if prompted;
// this is changed by tests during tests.
var stdin = io.ReadWriter(os.Stdin)
// The name of the folder for accounts where the email
// address was not provided; default 'username' if you will.
// address was not provided; default 'username' if you will,
// but only for local/storage use, not with the CA.
const emptyEmail = "default"
// TODO: After Boulder implements the 'meta' field of the directory,
// we can get this link dynamically.
const saURL = "https://acme-v01.api.letsencrypt.org/terms"

View file

@ -20,13 +20,14 @@ import (
"crypto/elliptic"
"crypto/rand"
"io"
"path/filepath"
"strings"
"testing"
"time"
"os"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acmev2"
)
func TestUser(t *testing.T) {
@ -135,7 +136,13 @@ func TestGetUserAlreadyExists(t *testing.T) {
}
func TestGetEmail(t *testing.T) {
storageBasePath = testStorage.Path // to contain calls that create a new Storage...
// ensure storage (via StorageFor) uses the local testdata folder that we delete later
origCaddypath := os.Getenv("CADDYPATH")
os.Setenv("CADDYPATH", "./testdata")
defer os.Setenv("CADDYPATH", origCaddypath)
agreementTestURL = "(none - testing)"
defer func() { agreementTestURL = "" }()
// let's not clutter up the output
origStdout := os.Stdout
@ -146,7 +153,10 @@ func TestGetEmail(t *testing.T) {
DefaultEmail = "test2@foo.com"
// Test1: Use default email from flag (or user previously typing it)
actual := getEmail(testStorage, true)
actual, err := getEmail(testConfig, true)
if err != nil {
t.Fatalf("getEmail (1) error: %v", err)
}
if actual != DefaultEmail {
t.Errorf("Did not get correct email from memory; expected '%s' but got '%s'", DefaultEmail, actual)
}
@ -154,16 +164,19 @@ func TestGetEmail(t *testing.T) {
// Test2: Get input from user
DefaultEmail = ""
stdin = new(bytes.Buffer)
_, err := io.Copy(stdin, strings.NewReader("test3@foo.com\n"))
_, err = io.Copy(stdin, strings.NewReader("test3@foo.com\n"))
if err != nil {
t.Fatalf("Could not simulate user input, error: %v", err)
}
actual = getEmail(testStorage, true)
actual, err = getEmail(testConfig, true)
if err != nil {
t.Fatalf("getEmail (2) error: %v", err)
}
if actual != "test3@foo.com" {
t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual)
}
// Test3: Get most recent email from before
// Test3: Get most recent email from before (in storage)
DefaultEmail = ""
for i, eml := range []string{
"TEST4-3@foo.com", // test case insensitivity
@ -189,14 +202,20 @@ func TestGetEmail(t *testing.T) {
t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err)
}
}
actual = getEmail(testStorage, true)
actual, err = getEmail(testConfig, true)
if err != nil {
t.Fatalf("getEmail (3) error: %v", err)
}
if actual != "test4-3@foo.com" {
t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual)
}
}
var testStorage = &FileStorage{Path: "./testdata"}
var (
testStorageBase = "./testdata" // ephemeral folder that gets deleted after tests finish
testCAHost = "localhost"
testConfig = &Config{CAUrl: "http://" + testCAHost + "/directory", StorageProvider: "file"}
testStorage = &FileStorage{Path: filepath.Join(testStorageBase, "acme", testCAHost)}
)
func (s *FileStorage) clean() error {
return os.RemoveAll(s.Path)
}
func (s *FileStorage) clean() error { return os.RemoveAll(testStorageBase) }

59
dist/CHANGES.txt vendored
View file

@ -1,5 +1,64 @@
CHANGES
0.10.14 (April 19, 2018)
- tls: Fix error handling bug when obtaining certificates
0.10.13 (April 18, 2018)
- New third-party plugin: supervisor
- Updated QUIC
- proxy: Fix transparent pass-thru of X-Forwarded-For
- proxy: Configurable timeout to upstream
- rewrite: Now supports regular expressions on single-line
- tls: StrictHostMatching mode to prevent client auth bypass
- tls: Disable client auth when using QUIC
- tls: Require same client auth cert pools per hostname
- tls: Prevent On-Demand TLS directory traversal
- tls: Fix empty files when using ACME fails to obtain cert
- Fixed test broken by 1.1.1.1 resolving
- Improved Caddyfile parser robustness by fuzzing
0.10.12 (March 27, 2018)
- Switch to Let's Encrypt ACMEv2 production endpoint
- Support for automated wildcard certificates
- Support distributed solving of HTTP-01 challenge
- New {labelN}, {tls_cipher}, and {tls_version} placeholders
- Curly braces can now be escaped when not used as placeholders
- New third-party plugin: geoip
- Updated QUIC
- fastcgi: Add SSL_CIPHER and SSL_PROTOCOL environment variables
- log: New 'except' subdirective to exempt paths from logging
- startup/shutdown: Removed in favor of 'on'
- tls: Default minimum version is TLS 1.2
- tls: Revert to fallback cert if no cert matches SNI
- tls: New 'wildcard' subdirective to force automated wildcard cert
- Several significant bug fixes and improvements!
0.10.11 (February 20, 2018)
- Built with Go 1.10
- Reusable snippets for the Caddyfile
- Updated QUIC
- Auto-HTTPS certificates may be shared by multiple instances
- Expand globbed values in -conf flag
- Swap behavior of SIGTERM and SIGQUIT; ignore SIGHUP
- 9 new DNS provider plugins for the ACME DNS challenge
- New placeholder for {<Response-Header} values
- basicauth: Username put in {user} placeholder
- fastcgi: GET requests can now send a body
- proxy: Service discovery with DNS SRV load balancing
- request_id: Allow reusing request ID from header field
- tls: Improved efficiency of many certificates and reloads
- tls: Raise error if conflicting TLS configurations collide
- tls: Raise TLS alert if SNI used and no cert matched
- tls: Reject OCSP responses that expire after the certificate
- tls: Clients can use SNI to request a specific certificate
- tls: Add option for backend to approve on-demand certificate
- tls: Synchronize maintenance of shared, managed certificates
- Numerous fabulous bug fixes
0.10.10 (October 9, 2017)
- Built with Go 1.9.1
- Removed Caddy-Sponsors header

6
dist/README.txt vendored
View file

@ -1,4 +1,4 @@
CADDY 0.10.10
CADDY 0.10.14
Website
https://caddyserver.com
@ -32,9 +32,9 @@ the project wiki: https://github.com/mholt/caddy/wiki
And thanks - you're awesome!
If you think Caddy is awesome too, consider sponsoring it:
https://caddyserver.com/pricing - and help keep Caddy free
https://caddyserver.com/sponsor - and help keep Caddy free
for personal use.
---
(c) 2015-2017 Light Code Labs, LLC
(c) 2015-2018 Light Code Labs, LLC

View file

@ -39,7 +39,7 @@ var (
// eventHooks is a map of hook name to Hook. All hooks plugins
// must have a name.
eventHooks = sync.Map{}
eventHooks = &sync.Map{}
// parsingCallbacks maps server type to map of directive
// to list of callback functions. These aren't really
@ -296,6 +296,36 @@ func EmitEvent(event EventName, info interface{}) {
})
}
// cloneEventHooks return a clone of the event hooks *sync.Map
func cloneEventHooks() *sync.Map {
c := &sync.Map{}
eventHooks.Range(func(k, v interface{}) bool {
c.Store(k, v)
return true
})
return c
}
// purgeEventHooks purges all event hooks from the map
func purgeEventHooks() {
eventHooks.Range(func(k, _ interface{}) bool {
eventHooks.Delete(k)
return true
})
}
// restoreEventHooks restores eventHooks with a provided *sync.Map
func restoreEventHooks(m *sync.Map) {
// Purge old event hooks
purgeEventHooks()
// Restore event hooks
m.Range(func(k, v interface{}) bool {
eventHooks.Store(k, v)
return true
})
}
// ParsingCallback is a function that is called after
// a directive's setup functions have been executed
// for all the server blocks.

View file

@ -83,9 +83,17 @@ func trapSignalsPosix() {
caddyfileToUse = newCaddyfile
}
// Backup old event hooks
oldEventHooks := cloneEventHooks()
// Purge the old event hooks
purgeEventHooks()
// Kick off the restart; our work is done
_, err = inst.Restart(caddyfileToUse)
if err != nil {
restoreEventHooks(oldEventHooks)
log.Printf("[ERROR] SIGUSR1: %v", err)
}

View file

@ -1,98 +0,0 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package startupshutdown
import (
"fmt"
"strings"
"github.com/google/uuid"
"github.com/mholt/caddy"
"github.com/mholt/caddy/onevent/hook"
)
func init() {
caddy.RegisterPlugin("startup", caddy.Plugin{Action: Startup})
caddy.RegisterPlugin("shutdown", caddy.Plugin{Action: Shutdown})
}
// Startup (an alias for 'on startup') registers a startup callback to execute during server start.
func Startup(c *caddy.Controller) error {
config, err := onParse(c, caddy.InstanceStartupEvent)
if err != nil {
return c.ArgErr()
}
// Register Event Hooks.
c.OncePerServerBlock(func() error {
for _, cfg := range config {
caddy.RegisterEventHook("on-"+cfg.ID, cfg.Hook)
}
return nil
})
fmt.Println("NOTICE: Startup directive will be removed in a later version. Please migrate to 'on startup'")
return nil
}
// Shutdown (an alias for 'on shutdown') registers a shutdown callback to execute during server start.
func Shutdown(c *caddy.Controller) error {
config, err := onParse(c, caddy.ShutdownEvent)
if err != nil {
return c.ArgErr()
}
// Register Event Hooks.
for _, cfg := range config {
caddy.RegisterEventHook("on-"+cfg.ID, cfg.Hook)
}
fmt.Println("NOTICE: Shutdown directive will be removed in a later version. Please migrate to 'on shutdown'")
return nil
}
func onParse(c *caddy.Controller, event caddy.EventName) ([]*hook.Config, error) {
var config []*hook.Config
for c.Next() {
cfg := new(hook.Config)
args := c.RemainingArgs()
if len(args) == 0 {
return config, c.ArgErr()
}
// Configure Event.
cfg.Event = event
// Assign an unique ID.
cfg.ID = uuid.New().String()
// Extract command and arguments.
command, args, err := caddy.SplitCommandAndArgs(strings.Join(args, " "))
if err != nil {
return config, c.Err(err.Error())
}
cfg.Command = command
cfg.Args = args
config = append(config, cfg)
}
return config, nil
}

View file

@ -1,69 +0,0 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package startupshutdown
import (
"testing"
"github.com/mholt/caddy"
)
func TestStartup(t *testing.T) {
tests := []struct {
name string
input string
shouldErr bool
}{
{name: "noInput", input: "startup", shouldErr: true},
{name: "startup", input: "startup cmd arg", shouldErr: false},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
c := caddy.NewTestController("", test.input)
err := Startup(c)
if err == nil && test.shouldErr {
t.Error("Test didn't error, but it should have")
} else if err != nil && !test.shouldErr {
t.Errorf("Test errored, but it shouldn't have; got '%v'", err)
}
})
}
}
func TestShutdown(t *testing.T) {
tests := []struct {
name string
input string
shouldErr bool
}{
{name: "noInput", input: "shutdown", shouldErr: true},
{name: "shutdown", input: "shutdown cmd arg", shouldErr: false},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
c := caddy.NewTestController("", test.input)
err := Shutdown(c)
if err == nil && test.shouldErr {
t.Error("Test didn't error, but it should have")
} else if err != nil && !test.shouldErr {
t.Errorf("Test errored, but it shouldn't have; got '%v'", err)
}
})
}
}

View file

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !amd64,!s390x
// +build !amd64
package aes12

View file

@ -8,19 +8,20 @@ import (
var bufferPool sync.Pool
func getPacketBuffer() []byte {
return bufferPool.Get().([]byte)
func getPacketBuffer() *[]byte {
return bufferPool.Get().(*[]byte)
}
func putPacketBuffer(buf []byte) {
if cap(buf) != int(protocol.MaxReceivePacketSize) {
func putPacketBuffer(buf *[]byte) {
if cap(*buf) != int(protocol.MaxReceivePacketSize) {
panic("putPacketBuffer called with packet of wrong size!")
}
bufferPool.Put(buf[:0])
bufferPool.Put(buf)
}
func init() {
bufferPool.New = func() interface{} {
return make([]byte, 0, protocol.MaxReceivePacketSize)
b := make([]byte, 0, protocol.MaxReceivePacketSize)
return &b
}
}

View file

@ -38,6 +38,8 @@ type client struct {
version protocol.VersionNumber
session packetHandler
logger utils.Logger
}
var (
@ -85,6 +87,14 @@ func Dial(
}
}
// check that all versions are actually supported
if config != nil {
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
}
}
clientConfig := populateClientConfig(config)
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
@ -94,9 +104,10 @@ func Dial(
config: clientConfig,
version: clientConfig.Versions[0],
versionNegotiationChan: make(chan struct{}),
logger: utils.DefaultLogger,
}
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
c.logger.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
if err := c.dial(); err != nil {
return nil, err
@ -132,6 +143,18 @@ func populateClientConfig(config *Config) *Config {
if maxReceiveConnectionFlowControlWindow == 0 {
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
}
maxIncomingStreams := config.MaxIncomingStreams
if maxIncomingStreams == 0 {
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
} else if maxIncomingStreams < 0 {
maxIncomingStreams = 0
}
maxIncomingUniStreams := config.MaxIncomingUniStreams
if maxIncomingUniStreams == 0 {
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
return &Config{
Versions: versions,
@ -140,7 +163,9 @@ func populateClientConfig(config *Config) *Config {
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
KeepAlive: config.KeepAlive,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
KeepAlive: config.KeepAlive,
}
}
@ -171,12 +196,11 @@ func (c *client) dialTLS() error {
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
IdleTimeout: c.config.IdleTimeout,
OmitConnectionID: c.config.RequestConnectionIDOmission,
// TODO(#523): make these values configurable
MaxBidiStreamID: protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
MaxUniStreamID: protocol.MaxUniStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
}
csc := handshake.NewCryptoStreamConn(nil)
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version)
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
if err != nil {
return err
@ -193,7 +217,7 @@ func (c *client) dialTLS() error {
if err != handshake.ErrCloseSessionForRetry {
return err
}
utils.Infof("Received a Retry packet. Recreating session.")
c.logger.Infof("Received a Retry packet. Recreating session.")
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err
}
@ -216,7 +240,7 @@ func (c *client) establishSecureConnection() error {
go func() {
runErr = c.session.run() // returns as soon as the session is closed
close(errorChan)
utils.Infof("Connection %x closed.", c.connectionID)
c.logger.Infof("Connection %x closed.", c.connectionID)
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
c.conn.Close()
}
@ -245,7 +269,7 @@ func (c *client) listen() {
for {
var n int
var addr net.Addr
data := getPacketBuffer()
data := *getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize]
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable
@ -270,7 +294,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
r := bytes.NewReader(packet)
hdr, err := wire.ParseHeaderSentByServer(r, c.version)
if err != nil {
utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
c.logger.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
// drop this packet if we can't parse the header
return
}
@ -293,15 +317,15 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
// check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID {
utils.Infof("Received a spoofed Public Reset. Ignoring.")
c.logger.Infof("Received a spoofed Public Reset. Ignoring.")
return
}
pr, err := wire.ParsePublicReset(r)
if err != nil {
utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
c.logger.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
return
}
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
c.logger.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
return
}
@ -347,6 +371,8 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
}
}
c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok {
return qerr.InvalidVersion
@ -362,7 +388,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
if err != nil {
return err
}
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
c.logger.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
c.session.Close(errCloseSessionForNewVersion)
return nil
}
@ -379,6 +405,7 @@ func (c *client) createNewGQUICSession() (err error) {
c.config,
c.initialVersion,
c.negotiatedVersions,
c.logger,
)
return err
}
@ -398,6 +425,7 @@ func (c *client) createNewTLSSession(
c.tls,
paramsChan,
1,
c.logger,
)
return err
}

View file

@ -19,12 +19,14 @@ func main() {
flag.Parse()
urls := flag.Args()
logger := utils.DefaultLogger
if *verbose {
utils.SetLogLevel(utils.LogLevelDebug)
logger.SetLogLevel(utils.LogLevelDebug)
} else {
utils.SetLogLevel(utils.LogLevelInfo)
logger.SetLogLevel(utils.LogLevelInfo)
}
utils.SetLogTimeFormat("")
logger.SetLogTimeFormat("")
versions := protocol.SupportedVersions
if *tls {
@ -42,21 +44,21 @@ func main() {
var wg sync.WaitGroup
wg.Add(len(urls))
for _, addr := range urls {
utils.Infof("GET %s", addr)
logger.Infof("GET %s", addr)
go func(addr string) {
rsp, err := hclient.Get(addr)
if err != nil {
panic(err)
}
utils.Infof("Got response for %s: %#v", addr, rsp)
logger.Infof("Got response for %s: %#v", addr, rsp)
body := &bytes.Buffer{}
_, err = io.Copy(body, rsp.Body)
if err != nil {
panic(err)
}
utils.Infof("Request Body:")
utils.Infof("%s", body.Bytes())
logger.Infof("Request Body:")
logger.Infof("%s", body.Bytes())
wg.Done()
}(addr)
}

View file

@ -91,7 +91,7 @@ func init() {
}
}
if err != nil {
utils.Infof("Error receiving upload: %#v", err)
utils.DefaultLogger.Infof("Error receiving upload: %#v", err)
}
}
io.WriteString(w, `<html><body><form action="/demo/upload" method="post" enctype="multipart/form-data">
@ -126,12 +126,14 @@ func main() {
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
flag.Parse()
logger := utils.DefaultLogger
if *verbose {
utils.SetLogLevel(utils.LogLevelDebug)
logger.SetLogLevel(utils.LogLevelDebug)
} else {
utils.SetLogLevel(utils.LogLevelInfo)
logger.SetLogLevel(utils.LogLevelInfo)
}
utils.SetLogTimeFormat("")
logger.SetLogTimeFormat("")
versions := protocol.SupportedVersions
if *tls {

View file

@ -46,6 +46,8 @@ type client struct {
requestWriter *requestWriter
responses map[protocol.StreamID]chan *http.Response
logger utils.Logger
}
var _ http.RoundTripper = &client{}
@ -75,6 +77,7 @@ func newClient(
opts: opts,
headerErrored: make(chan struct{}),
dialer: dialer,
logger: utils.DefaultLogger,
}
}
@ -95,7 +98,7 @@ func (c *client) dial() error {
if err != nil {
return err
}
c.requestWriter = newRequestWriter(c.headerStream)
c.requestWriter = newRequestWriter(c.headerStream, c.logger)
go c.handleHeaderStream()
return nil
}
@ -108,7 +111,9 @@ func (c *client) handleHeaderStream() {
for err == nil {
err = c.readResponse(h2framer, decoder)
}
utils.Debugf("Error handling header stream: %s", err)
if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway {
c.logger.Debugf("Error handling header stream: %s", err)
}
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
// stop all running request
close(c.headerErrored)
@ -202,6 +207,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
bodySent = true
}
ctx := req.Context()
for !(bodySent && receivedResponse) {
select {
case res = <-responseChan:
@ -214,8 +220,16 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
if err != nil {
return nil, err
}
case <-ctx.Done():
// error code 6 signals that stream was canceled
dataStream.CancelRead(6)
dataStream.CancelWrite(6)
c.mutex.Lock()
delete(c.responses, dataStream.StreamID())
c.mutex.Unlock()
return nil, ctx.Err()
case <-c.headerErrored:
// an error occured on the header stream
// an error occurred on the header stream
_ = c.CloseWithError(c.headerErr)
return nil, c.headerErr
}

View file

@ -23,13 +23,16 @@ type requestWriter struct {
henc *hpack.Encoder
hbuf bytes.Buffer // HPACK encoder writes into this
logger utils.Logger
}
const defaultUserAgent = "quic-go"
func newRequestWriter(headerStream quic.Stream) *requestWriter {
func newRequestWriter(headerStream quic.Stream, logger utils.Logger) *requestWriter {
rw := &requestWriter{
headerStream: headerStream,
logger: logger,
}
rw.henc = hpack.NewEncoder(&rw.hbuf)
return rw
@ -76,9 +79,8 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra
if !validPseudoPath(path) {
if req.URL.Opaque != "" {
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
} else {
return nil, fmt.Errorf("invalid request :path %q", orig)
}
return nil, fmt.Errorf("invalid request :path %q", orig)
}
}
}
@ -157,7 +159,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra
}
func (w *requestWriter) writeHeader(name, value string) {
utils.Debugf("http2: Transport encoding header %q = %q", name, value)
w.logger.Debugf("http2: Transport encoding header %q = %q", name, value)
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
}

View file

@ -3,7 +3,6 @@ package h2quic
import (
"bytes"
"errors"
"io"
"io/ioutil"
"net/http"
"net/textproto"
@ -16,7 +15,7 @@ import (
// copied from net/http2/transport.go
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
var noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil))
var noBody = ioutil.NopCloser(bytes.NewReader(nil))
// from the handleResponse function
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
@ -33,16 +32,7 @@ func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
return nil, errors.New("malformed non-numeric status pseudo header")
}
if statusCode == 100 {
// TODO: handle this
// traceGot100Continue(cs.trace)
// if cs.on100 != nil {
// cs.on100() // forces any write delay timer to fire
// }
// cs.pastHeaders = false // do it all again
// return nil, nil
}
// TODO: handle statusCode == 100
header := make(http.Header)
res := &http.Response{
@ -78,13 +68,7 @@ func setLength(res *http.Response, isHead, streamEnded bool) *http.Response {
if clens := res.Header["Content-Length"]; len(clens) == 1 {
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
res.ContentLength = clen64
} else {
// TODO: care? unlike http/1, it won't mess up our framing, so it's
// more safe smuggling-wise to ignore.
}
} else if len(clens) > 1 {
// TODO: care? unlike http/1, it won't mess up our framing, so it's
// more safe smuggling-wise to ignore.
}
}
return res

View file

@ -24,15 +24,24 @@ type responseWriter struct {
header http.Header
status int // status code passed to WriteHeader
headerWritten bool
logger utils.Logger
}
func newResponseWriter(headerStream quic.Stream, headerStreamMutex *sync.Mutex, dataStream quic.Stream, dataStreamID protocol.StreamID) *responseWriter {
func newResponseWriter(
headerStream quic.Stream,
headerStreamMutex *sync.Mutex,
dataStream quic.Stream,
dataStreamID protocol.StreamID,
logger utils.Logger,
) *responseWriter {
return &responseWriter{
header: http.Header{},
headerStream: headerStream,
headerStreamMutex: headerStreamMutex,
dataStream: dataStream,
dataStreamID: dataStreamID,
logger: logger,
}
}
@ -57,7 +66,7 @@ func (w *responseWriter) WriteHeader(status int) {
}
}
utils.Infof("Responding with %d", status)
w.logger.Infof("Responding with %d", status)
w.headerStreamMutex.Lock()
defer w.headerStreamMutex.Unlock()
h2framer := http2.NewFramer(w.headerStream, nil)
@ -67,7 +76,7 @@ func (w *responseWriter) WriteHeader(status int) {
BlockFragment: headers.Bytes(),
})
if err != nil {
utils.Errorf("could not write h2 header: %s", err.Error())
w.logger.Errorf("could not write h2 header: %s", err.Error())
}
}

View file

@ -53,6 +53,8 @@ type Server struct {
closed bool
supportedVersionsAsString string
logger utils.Logger // will be set by Server.serveImpl()
}
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
@ -88,6 +90,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server")
}
s.logger = utils.DefaultLogger
s.listenerMutex.Lock()
if s.closed {
s.listenerMutex.Unlock()
@ -138,7 +141,7 @@ func (s *Server) handleHeaderStream(session streamCreator) {
// In this case, the session has already logged the error, so we don't
// need to log it again.
if _, ok := err.(*qerr.QuicError); !ok {
utils.Errorf("error handling h2 request: %s", err.Error())
s.logger.Errorf("error handling h2 request: %s", err.Error())
}
session.Close(err)
return
@ -160,7 +163,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
}
headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment())
if err != nil {
utils.Errorf("invalid http2 headers encoding: %s", err.Error())
s.logger.Errorf("invalid http2 headers encoding: %s", err.Error())
return err
}
@ -169,10 +172,10 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
return err
}
if utils.Debug() {
utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
if s.logger.Debug() {
s.logger.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
} else {
utils.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
}
dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID))
@ -201,7 +204,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
req.RemoteAddr = session.RemoteAddr().String()
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID), s.logger)
handler := s.Handler
if handler == nil {
@ -215,7 +218,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
utils.Errorf("http: panic serving: %v\n%s", p, buf)
s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
panicked = true
}
}()

View file

@ -7,6 +7,7 @@ import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// Connection is a UDP connection
@ -43,6 +44,8 @@ func (d Direction) String() string {
}
}
// Is says if one direction matches another direction.
// For example, incoming matches both incoming and both, but not outgoing.
func (d Direction) Is(dir Direction) bool {
if d == DirectionBoth || dir == DirectionBoth {
return true
@ -92,6 +95,8 @@ type QuicProxy struct {
// Mapping from client addresses (as host:port) to connection
clientDict map[string]*connection
logger utils.Logger
}
// NewQuicProxy creates a new UDP proxy
@ -129,14 +134,23 @@ func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*Qu
dropPacket: packetDropper,
delayPacket: packetDelayer,
version: version,
logger: utils.DefaultLogger,
}
p.logger.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr)
go p.runProxy()
return &p, nil
}
// Close stops the UDP Proxy
func (p *QuicProxy) Close() error {
p.mutex.Lock()
defer p.mutex.Unlock()
for _, c := range p.clientDict {
if err := c.ServerConn.Close(); err != nil {
return err
}
}
return p.conn.Close()
}
@ -189,19 +203,27 @@ func (p *QuicProxy) runProxy() error {
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
if p.dropPacket(DirectionIncoming, packetCount) {
if p.logger.Debug() {
p.logger.Debugf("dropping incoming packet %d (%d bytes)", packetCount, n)
}
continue
}
// Send the packet to the server
delay := p.delayPacket(DirectionIncoming, packetCount)
if delay != 0 {
if p.logger.Debug() {
p.logger.Debugf("delaying incoming packet %d (%d bytes) to %s by %s", packetCount, n, conn.ServerConn.RemoteAddr(), delay)
}
time.AfterFunc(delay, func() {
// TODO: handle error
_, _ = conn.ServerConn.Write(raw)
})
} else {
_, err := conn.ServerConn.Write(raw)
if err != nil {
if p.logger.Debug() {
p.logger.Debugf("forwarding incoming packet %d (%d bytes) to %s", packetCount, n, conn.ServerConn.RemoteAddr())
}
if _, err := conn.ServerConn.Write(raw); err != nil {
return err
}
}
@ -221,18 +243,26 @@ func (p *QuicProxy) runConnection(conn *connection) error {
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
if p.dropPacket(DirectionOutgoing, packetCount) {
if p.logger.Debug() {
p.logger.Debugf("dropping outgoing packet %d (%d bytes)", packetCount, n)
}
continue
}
delay := p.delayPacket(DirectionOutgoing, packetCount)
if delay != 0 {
if p.logger.Debug() {
p.logger.Debugf("delaying outgoing packet %d (%d bytes) to %s by %s", packetCount, n, conn.ClientAddr, delay)
}
time.AfterFunc(delay, func() {
// TODO: handle error
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr)
})
} else {
_, err := p.conn.WriteToUDP(raw, conn.ClientAddr)
if err != nil {
if p.logger.Debug() {
p.logger.Debugf("forwarding outgoing packet %d (%d bytes) to %s", packetCount, n, conn.ClientAddr)
}
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
return err
}
}

View file

@ -30,7 +30,7 @@ var _ = BeforeEach(func() {
logFile, err = os.Create(logFileName)
Expect(err).ToNot(HaveOccurred())
log.SetOutput(logFile)
utils.SetLogLevel(utils.LogLevelDebug)
utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
}
})

View file

@ -22,7 +22,9 @@ const (
)
var (
PRData = GeneratePRData(dataLen)
// PRData contains dataLen bytes of pseudo-random data.
PRData = GeneratePRData(dataLen)
// PRDataLong contains dataLenLong bytes of pseudo-random data.
PRDataLong = GeneratePRData(dataLenLong)
server *h2quic.Server
@ -105,11 +107,13 @@ func StartQuicServer(versions []protocol.VersionNumber) {
}()
}
// StopQuicServer stops the h2quic.Server.
func StopQuicServer() {
Expect(server.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing).Should(BeClosed())
}
// Port returns the UDP port of the QUIC server.
func Port() string {
return port
}

View file

@ -16,6 +16,9 @@ type StreamID = protocol.StreamID
// A VersionNumber is a QUIC version number.
type VersionNumber = protocol.VersionNumber
// VersionGQUIC39 is gQUIC version 39.
const VersionGQUIC39 = protocol.Version39
// A Cookie can be used to verify the ownership of the client address.
type Cookie = handshake.Cookie
@ -113,15 +116,25 @@ type StreamError interface {
// A Session is a QUIC connection between two peers.
type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
// Since stream 1 is reserved for the crypto stream, the first stream is either 2 (for a client) or 3 (for a server).
AcceptStream() (Stream, error)
// OpenStream opens a new QUIC stream, returning a special error when the peer's concurrent stream limit is reached.
// New streams always have the smallest possible stream ID.
// TODO: Enable testing for the special error
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
AcceptUniStream() (ReceiveStream, error)
// OpenStream opens a new bidirectional QUIC stream.
// It returns a special error when the peer's concurrent stream limit is reached.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream.
// TODO(#1152): Enable testing for the special error
OpenStream() (Stream, error)
// OpenStreamSync opens a new QUIC stream, blocking until the peer's concurrent stream limit allows a new stream to be opened.
// It always picks the smallest possible stream ID.
// OpenStreamSync opens a new bidirectional QUIC stream.
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
OpenStreamSync() (Stream, error)
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
// It returns a special error when the peer's concurrent stream limit is reached.
// TODO(#1152): Enable testing for the special error
OpenUniStream() (SendStream, error)
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
OpenUniStreamSync() (SendStream, error)
// LocalAddr returns the local address.
LocalAddr() net.Addr
// RemoteAddr returns the address of the peer.
@ -166,6 +179,17 @@ type Config struct {
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
MaxReceiveConnectionFlowControlWindow uint64
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any bidirectional streams.
// Values larger than 65535 (math.MaxUint16) are invalid.
MaxIncomingStreams int
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
// This value doesn't have any effect in Google QUIC.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any unidirectional streams.
// Values larger than 65535 (math.MaxUint16) are invalid.
MaxIncomingUniStreams int
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
KeepAlive bool
}

View file

@ -0,0 +1,3 @@
package ackhandler
//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet

View file

@ -10,15 +10,13 @@ import (
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet) error
SentPacket(packet *Packet)
SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber)
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
SetHandshakeComplete()
// SendingAllowed says if a packet can be sent.
// Sending packets might not be possible because:
// * we're congestion limited
// * we're tracking the maximum number of sent packets
SendingAllowed() bool
// The SendMode determines if and what kind of packets can be sent.
SendMode() SendMode
// TimeUntilSend is the time when the next packet should be sent.
// It is used for pacing packets.
TimeUntilSend() time.Time
@ -32,10 +30,10 @@ type SentPacketHandler interface {
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
DequeuePacketForRetransmission() (packet *Packet)
GetLeastUnacked() protocol.PacketNumber
GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
GetAlarmTimeout() time.Time
OnAlarm()
OnAlarm() error
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets

View file

@ -8,28 +8,22 @@ import (
)
// A Packet is a packet
// +gen linkedlist
type Packet struct {
PacketNumber protocol.PacketNumber
PacketType protocol.PacketType
Frames []wire.Frame
Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel
SendTime time.Time
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
sendTime time.Time
}
// GetFramesForRetransmission gets all the frames for retransmission
func (p *Packet) GetFramesForRetransmission() []wire.Frame {
var fs []wire.Frame
for _, frame := range p.Frames {
switch frame.(type) {
case *wire.AckFrame:
continue
case *wire.StopWaitingFrame:
continue
}
fs = append(fs, frame)
}
return fs
// There are two reasons why a packet cannot be retransmitted:
// * it was already retransmitted
// * this packet is a retransmission, and we already received an ACK for the original packet
canBeRetransmitted bool
includedInBytesInFlight bool
retransmittedAs []protocol.PacketNumber
isRetransmission bool // we need a separate bool here because 0 is a valid packet number
retransmissionOf protocol.PacketNumber
}

View file

@ -1,13 +1,10 @@
// Generated by: main
// TypeWriter: linkedlist
// Directive: +gen on Packet
// This file was automatically generated by genny.
// Any changes will be lost if this file is regenerated.
// see https://github.com/cheekybits/genny
package ackhandler
// List is a modification of http://golang.org/pkg/container/list/
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Linked list implementation from the Go standard library.
// PacketElement is an element of a linked list.
type PacketElement struct {
@ -41,8 +38,7 @@ func (e *PacketElement) Prev() *PacketElement {
return nil
}
// PacketList represents a doubly linked list.
// The zero value for PacketList is an empty list ready to use.
// PacketList is a linked list of Packets.
type PacketList struct {
root PacketElement // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element
@ -63,7 +59,7 @@ func NewPacketList() *PacketList { return new(PacketList).Init() }
// The complexity is O(1).
func (l *PacketList) Len() int { return l.len }
// Front returns the first element of list l or nil.
// Front returns the first element of list l or nil if the list is empty.
func (l *PacketList) Front() *PacketElement {
if l.len == 0 {
return nil
@ -71,7 +67,7 @@ func (l *PacketList) Front() *PacketElement {
return l.root.next
}
// Back returns the last element of list l or nil.
// Back returns the last element of list l or nil if the list is empty.
func (l *PacketList) Back() *PacketElement {
if l.len == 0 {
return nil
@ -79,7 +75,7 @@ func (l *PacketList) Back() *PacketElement {
return l.root.prev
}
// lazyInit lazily initializes a zero PacketList value.
// lazyInit lazily initializes a zero List value.
func (l *PacketList) lazyInit() {
if l.root.next == nil {
l.Init()
@ -98,7 +94,7 @@ func (l *PacketList) insert(e, at *PacketElement) *PacketElement {
return e
}
// insertValue is a convenience wrapper for insert(&PacketElement{Value: v}, at).
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement {
return l.insert(&PacketElement{Value: v}, at)
}
@ -116,10 +112,11 @@ func (l *PacketList) remove(e *PacketElement) *PacketElement {
// Remove removes e from l if e is an element of list l.
// It returns the element value e.Value.
// The element must not be nil.
func (l *PacketList) Remove(e *PacketElement) Packet {
if e.list == l {
// if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero PacketElement) and l.remove will crash
// in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e)
}
return e.Value
@ -139,46 +136,51 @@ func (l *PacketList) PushBack(v Packet) *PacketElement {
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement {
if mark.list != l {
return nil
}
// see comment in PacketList.Remove about initialization of l
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev)
}
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement {
if mark.list != l {
return nil
}
// see comment in PacketList.Remove about initialization of l
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketList) MoveToFront(e *PacketElement) {
if e.list != l || l.root.next == e {
return
}
// see comment in PacketList.Remove about initialization of l
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), &l.root)
}
// MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketList) MoveToBack(e *PacketElement) {
if e.list != l || l.root.prev == e {
return
}
// see comment in PacketList.Remove about initialization of l
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), l.root.prev)
}
// MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketList) MoveBefore(e, mark *PacketElement) {
if e.list != l || e == mark || mark.list != l {
return
@ -187,7 +189,8 @@ func (l *PacketList) MoveBefore(e, mark *PacketElement) {
}
// MoveAfter moves element e to its new position after mark.
// If e is not an element of l, or e == mark, the list is not modified.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketList) MoveAfter(e, mark *PacketElement) {
if e.list != l || e == mark || mark.list != l {
return
@ -196,7 +199,7 @@ func (l *PacketList) MoveAfter(e, mark *PacketElement) {
}
// PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same.
// The lists l and other may be the same. They must not be nil.
func (l *PacketList) PushBackList(other *PacketList) {
l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
@ -205,7 +208,7 @@ func (l *PacketList) PushBackList(other *PacketList) {
}
// PushFrontList inserts a copy of an other list at the front of list l.
// The lists l and other may be the same.
// The lists l and other may be the same. They must not be nil.
func (l *PacketList) PushFrontList(other *PacketList) {
l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {

View file

@ -3,7 +3,9 @@ package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
@ -15,6 +17,7 @@ type receivedPacketHandler struct {
packetHistory *receivedPacketHistory
ackSendDelay time.Duration
rttStats *congestion.RTTStats
packetsReceivedSinceLastAck int
retransmittablePacketsReceivedSinceLastAck int
@ -25,29 +28,54 @@ type receivedPacketHandler struct {
version protocol.VersionNumber
}
const (
// maximum delay that can be applied to an ACK for a retransmittable packet
ackSendDelay = 25 * time.Millisecond
// initial maximum number of retransmittable packets received before sending an ack.
initialRetransmittablePacketsBeforeAck = 2
// number of retransmittable that an ACK is sent for
retransmittablePacketsBeforeAck = 10
// 1/5 RTT delay when doing ack decimation
ackDecimationDelay = 1.0 / 4
// 1/8 RTT delay when doing ack decimation
shortAckDecimationDelay = 1.0 / 8
// Minimum number of packets received before ack decimation is enabled.
// This intends to avoid the beginning of slow start, when CWNDs may be
// rapidly increasing.
minReceivedBeforeAckDecimation = 100
// Maximum number of packets to ack immediately after a missing packet for
// fast retransmission to kick in at the sender. This limit is created to
// reduce the number of acks sent that have no benefit for fast retransmission.
// Set to the number of nacks needed for fast retransmit plus one for protection
// against an ack loss
maxPacketsAfterNewMissing = 4
)
// NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHandler {
func NewReceivedPacketHandler(rttStats *congestion.RTTStats, version protocol.VersionNumber) ReceivedPacketHandler {
return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(),
ackSendDelay: protocol.AckSendDelay,
ackSendDelay: ackSendDelay,
rttStats: rttStats,
version: version,
}
}
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
if packetNumber < h.ignoreBelow {
return nil
}
isMissing := h.isMissing(packetNumber)
if packetNumber > h.largestObserved {
h.largestObserved = packetNumber
h.largestObservedReceivedTime = rcvTime
}
if packetNumber < h.ignoreBelow {
return nil
}
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err
}
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck)
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing)
return nil
}
@ -58,35 +86,68 @@ func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
h.packetHistory.DeleteBelow(p)
}
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) {
h.packetsReceivedSinceLastAck++
if shouldInstigateAck {
h.retransmittablePacketsReceivedSinceLastAck++
// isMissing says if a packet was reported missing in the last ACK.
func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool {
if h.lastAck == nil {
return false
}
return p < h.lastAck.LargestAcked && !h.lastAck.AcksPacket(p)
}
func (h *receivedPacketHandler) hasNewMissingPackets() bool {
if h.lastAck == nil {
return false
}
highestRange := h.packetHistory.GetHighestAckRange()
return highestRange.First >= h.lastAck.LargestAcked && highestRange.Len() <= maxPacketsAfterNewMissing
}
// maybeQueueAck queues an ACK, if necessary.
// It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck()
// in ACK_DECIMATION_WITH_REORDERING mode.
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) {
h.packetsReceivedSinceLastAck++
// always ack the first packet
if h.lastAck == nil {
h.ackQueued = true
return
}
// if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK
// note that it cannot be a duplicate because they're already filtered out by ReceivedPacket()
if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked {
h.ackQueued = true
}
// check if a new missing range above the previously was created
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().First > h.lastAck.LargestAcked {
// Send an ACK if this packet was reported missing in an ACK sent before.
// Ack decimation with reordering relies on the timer to send an ACK, but if
// missing packets we reported in the previous ack, send an ACK immediately.
if wasMissing {
h.ackQueued = true
}
if !h.ackQueued && shouldInstigateAck {
if h.retransmittablePacketsReceivedSinceLastAck >= protocol.RetransmittablePacketsBeforeAck {
h.ackQueued = true
h.retransmittablePacketsReceivedSinceLastAck++
if packetNumber > minReceivedBeforeAckDecimation {
// ack up to 10 packets at once
if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
h.ackQueued = true
} else if h.ackAlarm.IsZero() {
// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
h.ackAlarm = rcvTime.Add(ackDelay)
}
} else {
if h.ackAlarm.IsZero() {
h.ackAlarm = rcvTime.Add(h.ackSendDelay)
// send an ACK every 2 retransmittable packets
if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
h.ackQueued = true
} else if h.ackAlarm.IsZero() {
h.ackAlarm = rcvTime.Add(ackSendDelay)
}
}
// If there are new missing packets to report, set a short timer to send an ACK.
if h.hasNewMissingPackets() {
// wait the minimum of 1/8 min RTT and the existing ack time
ackDelay := float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay)
ackTime := rcvTime.Add(time.Duration(ackDelay))
if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
h.ackAlarm = ackTime
}
}
}
@ -118,7 +179,6 @@ func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
h.ackQueued = false
h.packetsReceivedSinceLastAck = 0
h.retransmittablePacketsReceivedSinceLastAck = 0
return ack
}

View file

@ -0,0 +1,36 @@
package ackhandler
import "fmt"
// The SendMode says what kind of packets can be sent.
type SendMode uint8
const (
// SendNone means that no packets should be sent
SendNone SendMode = iota
// SendAck means an ACK-only packet should be sent
SendAck
// SendRetransmission means that retransmissions should be sent
SendRetransmission
// SendRTO means that an RTO probe packet should be sent
SendRTO
// SendAny packet should be sent
SendAny
)
func (s SendMode) String() string {
switch s {
case SendNone:
return "none"
case SendAck:
return "ack"
case SendRetransmission:
return "retransmission"
case SendRTO:
return "rto"
case SendAny:
return "any"
default:
return fmt.Sprintf("invalid send mode: %d", s)
}
}

View file

@ -1,7 +1,6 @@
package ackhandler
import (
"errors"
"fmt"
"math"
"time"
@ -30,13 +29,13 @@ const (
maxRTOTimeout = 60 * time.Second
)
// ErrDuplicateOrOutOfOrderAck occurs when a duplicate or an out-of-order ACK is received
var ErrDuplicateOrOutOfOrderAck = errors.New("SentPacketHandler: Duplicate or out-of-order ACK")
type sentPacketHandler struct {
lastSentPacketNumber protocol.PacketNumber
nextPacketSendTime time.Time
skippedPackets []protocol.PacketNumber
lastSentPacketNumber protocol.PacketNumber
lastSentRetransmittablePacketTime time.Time
lastSentHandshakePacketTime time.Time
nextPacketSendTime time.Time
skippedPackets []protocol.PacketNumber
largestAcked protocol.PacketNumber
largestReceivedPacketWithAck protocol.PacketNumber
@ -44,8 +43,9 @@ type sentPacketHandler struct {
// example: we send an ACK for packets 90-100 with packet number 20
// once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101
lowestPacketNotConfirmedAcked protocol.PacketNumber
largestSentBeforeRTO protocol.PacketNumber
packetHistory *PacketList
packetHistory *sentPacketHistory
stopWaitingManager stopWaitingManager
retransmissionQueue []*Packet
@ -61,16 +61,20 @@ type sentPacketHandler struct {
// The number of times an RTO has been sent without receiving an ack.
rtoCount uint32
// The number of RTO probe packets that should be sent.
numRTOs int
// The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time.
lossTime time.Time
// The alarm timeout
alarm time.Time
logger utils.Logger
}
// NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler {
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
@ -80,16 +84,17 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler {
)
return &sentPacketHandler{
packetHistory: NewPacketList(),
packetHistory: newSentPacketHistory(),
stopWaitingManager: stopWaitingManager{},
rttStats: rttStats,
congestion: congestion,
logger: logger,
}
}
func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber {
if f := h.packetHistory.Front(); f != nil {
return f.Value.PacketNumber
if p := h.packetHistory.FirstOutstanding(); p != nil {
return p.PacketNumber
}
return h.largestAcked + 1
}
@ -101,30 +106,51 @@ func (h *sentPacketHandler) SetHandshakeComplete() {
queue = append(queue, packet)
}
}
var handshakePackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.EncryptionLevel != protocol.EncryptionForwardSecure {
handshakePackets = append(handshakePackets, p)
}
return true, nil
})
for _, p := range handshakePackets {
h.packetHistory.Remove(p.PacketNumber)
}
h.retransmissionQueue = queue
h.handshakeComplete = true
}
func (h *sentPacketHandler) SentPacket(packet *Packet) error {
if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets {
return errors.New("Too many outstanding non-acked and non-retransmitted packets")
func (h *sentPacketHandler) SentPacket(packet *Packet) {
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
h.packetHistory.SentPacket(packet)
h.updateLossDetectionAlarm()
}
}
func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
var p []*Packet
for _, packet := range packets {
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
p = append(p, packet)
}
}
h.packetHistory.SentPacketsAsRetransmission(p, retransmissionOf)
h.updateLossDetectionAlarm()
}
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
h.skippedPackets = append(h.skippedPackets, p)
if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets {
h.skippedPackets = h.skippedPackets[1:]
}
}
now := time.Now()
h.lastSentPacketNumber = packet.PacketNumber
var largestAcked protocol.PacketNumber
if len(packet.Frames) > 0 {
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
largestAcked = ackFrame.LargestAcked
packet.largestAcked = ackFrame.LargestAcked
}
}
@ -132,24 +158,21 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
isRetransmittable := len(packet.Frames) != 0
if isRetransmittable {
packet.sendTime = now
packet.largestAcked = largestAcked
if packet.EncryptionLevel < protocol.EncryptionForwardSecure {
h.lastSentHandshakePacketTime = packet.SendTime
}
h.lastSentRetransmittablePacketTime = packet.SendTime
packet.includedInBytesInFlight = true
h.bytesInFlight += packet.Length
h.packetHistory.PushBack(*packet)
packet.canBeRetransmitted = true
if h.numRTOs > 0 {
h.numRTOs--
}
}
h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isRetransmittable)
h.congestion.OnPacketSent(
now,
h.bytesInFlight,
packet.PacketNumber,
packet.Length,
isRetransmittable,
)
h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, now).Add(h.congestion.TimeUntilSend(h.bytesInFlight))
h.updateLossDetectionAlarm(now)
return nil
h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, packet.SendTime).Add(h.congestion.TimeUntilSend(h.bytesInFlight))
return isRetransmittable
}
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
@ -157,26 +180,19 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
}
// duplicate or out-of-order ACK
// if withPacketNumber <= h.largestReceivedPacketWithAck && withPacketNumber != 0 {
if withPacketNumber <= h.largestReceivedPacketWithAck {
return ErrDuplicateOrOutOfOrderAck
}
h.largestReceivedPacketWithAck = withPacketNumber
// ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK)
if ackFrame.LargestAcked < h.lowestUnacked() {
// duplicate or out of order ACK
if withPacketNumber != 0 && withPacketNumber <= h.largestReceivedPacketWithAck {
h.logger.Debugf("Ignoring ACK frame (duplicate or out of order).")
return nil
}
h.largestAcked = ackFrame.LargestAcked
h.largestReceivedPacketWithAck = withPacketNumber
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, ackFrame.LargestAcked)
if h.skippedPacketsAcked(ackFrame) {
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
}
rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime)
if rttUpdated {
if rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime); rttUpdated {
h.congestion.MaybeExitSlowStart()
}
@ -185,24 +201,29 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
return err
}
if len(ackedPackets) > 0 {
for _, p := range ackedPackets {
if encLevel < p.Value.EncryptionLevel {
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel)
}
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
// It is safe to ignore the corner case of packets that just acked packet 0, because
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
if p.Value.largestAcked != 0 {
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.Value.largestAcked+1)
}
h.onPacketAcked(p)
h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
priorInFlight := h.bytesInFlight
for _, p := range ackedPackets {
if encLevel < p.EncryptionLevel {
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
}
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
// It is safe to ignore the corner case of packets that just acked packet 0, because
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
if p.largestAcked != 0 {
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.largestAcked+1)
}
if err := h.onPacketAcked(p); err != nil {
return err
}
if p.includedInBytesInFlight {
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight)
}
}
h.detectLostPackets(rcvTime)
h.updateLossDetectionAlarm(rcvTime)
if err := h.detectLostPackets(rcvTime, priorInFlight); err != nil {
return err
}
h.updateLossDetectionAlarm()
h.garbageCollectSkippedPackets()
h.stopWaitingManager.ReceivedAck(ackFrame)
@ -214,59 +235,50 @@ func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNu
return h.lowestPacketNotConfirmedAcked
}
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) {
var ackedPackets []*PacketElement
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*Packet, error) {
var ackedPackets []*Packet
ackRangeIndex := 0
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
packet := el.Value
packetNumber := packet.PacketNumber
err := h.packetHistory.Iterate(func(p *Packet) (bool, error) {
// Ignore packets below the LowestAcked
if packetNumber < ackFrame.LowestAcked {
continue
if p.PacketNumber < ackFrame.LowestAcked {
return true, nil
}
// Break after LargestAcked is reached
if packetNumber > ackFrame.LargestAcked {
break
if p.PacketNumber > ackFrame.LargestAcked {
return false, nil
}
if ackFrame.HasMissingRanges() {
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
for packetNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 {
for p.PacketNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 {
ackRangeIndex++
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
}
if packetNumber >= ackRange.First { // packet i contained in ACK range
if packetNumber > ackRange.Last {
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.First, ackRange.Last)
if p.PacketNumber >= ackRange.First { // packet i contained in ACK range
if p.PacketNumber > ackRange.Last {
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.First, ackRange.Last)
}
ackedPackets = append(ackedPackets, el)
ackedPackets = append(ackedPackets, p)
}
} else {
ackedPackets = append(ackedPackets, el)
ackedPackets = append(ackedPackets, p)
}
}
return ackedPackets, nil
return true, nil
})
return ackedPackets, err
}
func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool {
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
packet := el.Value
if packet.PacketNumber == largestAcked {
h.rttStats.UpdateRTT(rcvTime.Sub(packet.sendTime), ackDelay, rcvTime)
return true
}
// Packets are sorted by number, so we can stop searching
if packet.PacketNumber > largestAcked {
break
}
if p := h.packetHistory.GetPacket(largestAcked); p != nil {
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
return true
}
return false
}
func (h *sentPacketHandler) updateLossDetectionAlarm(now time.Time) {
func (h *sentPacketHandler) updateLossDetectionAlarm() {
// Cancel the alarm if no packets are outstanding
if h.packetHistory.Len() == 0 {
h.alarm = time.Time{}
@ -275,76 +287,152 @@ func (h *sentPacketHandler) updateLossDetectionAlarm(now time.Time) {
// TODO(#497): TLP
if !h.handshakeComplete {
h.alarm = now.Add(h.computeHandshakeTimeout())
h.alarm = h.lastSentHandshakePacketTime.Add(h.computeHandshakeTimeout())
} else if !h.lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = h.lossTime
} else {
// RTO
h.alarm = now.Add(h.computeRTOTimeout())
h.alarm = h.lastSentRetransmittablePacketTime.Add(h.computeRTOTimeout())
}
}
func (h *sentPacketHandler) detectLostPackets(now time.Time) {
func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight protocol.ByteCount) error {
h.lossTime = time.Time{}
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
var lostPackets []*PacketElement
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
packet := el.Value
var lostPackets []*Packet
h.packetHistory.Iterate(func(packet *Packet) (bool, error) {
if packet.PacketNumber > h.largestAcked {
break
return false, nil
}
timeSinceSent := now.Sub(packet.sendTime)
timeSinceSent := now.Sub(packet.SendTime)
if timeSinceSent > delayUntilLost {
lostPackets = append(lostPackets, el)
lostPackets = append(lostPackets, packet)
} else if h.lossTime.IsZero() {
// Note: This conditional is only entered once per call
h.lossTime = now.Add(delayUntilLost - timeSinceSent)
}
}
return true, nil
})
if len(lostPackets) > 0 {
for _, p := range lostPackets {
h.queuePacketForRetransmission(p)
h.congestion.OnPacketLost(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
for _, p := range lostPackets {
// the bytes in flight need to be reduced no matter if this packet will be retransmitted
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
}
if p.canBeRetransmitted {
// queue the packet for retransmission, and report the loss to the congestion controller
h.logger.Debugf("\tQueueing packet %#x because it was detected lost", p.PacketNumber)
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
}
h.packetHistory.Remove(p.PacketNumber)
}
return nil
}
func (h *sentPacketHandler) OnAlarm() {
func (h *sentPacketHandler) OnAlarm() error {
now := time.Now()
// TODO(#497): TLP
var err error
if !h.handshakeComplete {
h.queueHandshakePacketsForRetransmission()
h.handshakeCount++
err = h.queueHandshakePacketsForRetransmission()
} else if !h.lossTime.IsZero() {
// Early retransmit or time loss detection
h.detectLostPackets(now)
err = h.detectLostPackets(now, h.bytesInFlight)
} else {
// RTO
h.retransmitOldestTwoPackets()
h.rtoCount++
h.numRTOs += 2
err = h.queueRTOs()
}
h.updateLossDetectionAlarm(now)
if err != nil {
return err
}
h.updateLossDetectionAlarm()
return nil
}
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
return h.alarm
}
func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) {
h.bytesInFlight -= packetElement.Value.Length
func (h *sentPacketHandler) onPacketAcked(p *Packet) error {
// This happens if a packet and its retransmissions is acked in the same ACK.
// As soon as we process the first one, this will remove all the retransmissions,
// so we won't find the retransmitted packet number later.
if packet := h.packetHistory.GetPacket(p.PacketNumber); packet == nil {
return nil
}
// only report the acking of this packet to the congestion controller if:
// * it is a retransmittable packet
// * this packet wasn't retransmitted yet
if p.isRetransmission {
// that the parent doesn't exist is expected to happen every time the original packet was already acked
if parent := h.packetHistory.GetPacket(p.retransmissionOf); parent != nil {
if len(parent.retransmittedAs) == 1 {
parent.retransmittedAs = nil
} else {
// remove this packet from the slice of retransmission
retransmittedAs := make([]protocol.PacketNumber, 0, len(parent.retransmittedAs)-1)
for _, pn := range parent.retransmittedAs {
if pn != p.PacketNumber {
retransmittedAs = append(retransmittedAs, pn)
}
}
parent.retransmittedAs = retransmittedAs
}
}
}
// this also applies to packets that have been retransmitted as probe packets
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
}
if h.rtoCount > 0 {
h.verifyRTO(p.PacketNumber)
}
if err := h.stopRetransmissionsFor(p); err != nil {
return err
}
h.rtoCount = 0
h.handshakeCount = 0
// TODO(#497): h.tlpCount = 0
h.packetHistory.Remove(packetElement)
return h.packetHistory.Remove(p.PacketNumber)
}
func (h *sentPacketHandler) stopRetransmissionsFor(p *Packet) error {
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
return err
}
for _, r := range p.retransmittedAs {
packet := h.packetHistory.GetPacket(r)
if packet == nil {
return fmt.Errorf("sent packet handler BUG: marking packet as not retransmittable %d (retransmission of %d) not found in history", r, p.PacketNumber)
}
h.stopRetransmissionsFor(packet)
}
return nil
}
func (h *sentPacketHandler) verifyRTO(pn protocol.PacketNumber) {
if pn <= h.largestSentBeforeRTO {
h.logger.Debugf("Spurious RTO detected. Received an ACK for %#x (largest sent before RTO: %#x)", pn, h.largestSentBeforeRTO)
// Replace SRTT with latest_rtt and increase the variance to prevent
// a spurious RTO from happening again.
h.rttStats.ExpireSmoothedMetrics()
return
}
h.logger.Debugf("RTO verified. Received an ACK for %#x (largest sent before RTO: %#x", pn, h.largestSentBeforeRTO)
h.congestion.OnRetransmissionTimeout(true)
}
func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
@ -359,26 +447,42 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
return packet
}
func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber {
return h.lowestUnacked()
func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked())
}
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
return h.stopWaitingManager.GetStopWaitingFrame(force)
}
func (h *sentPacketHandler) SendingAllowed() bool {
cwnd := h.congestion.GetCongestionWindow()
congestionLimited := h.bytesInFlight > cwnd
maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets
if congestionLimited {
utils.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
func (h *sentPacketHandler) SendMode() SendMode {
numTrackedPackets := len(h.retransmissionQueue) + h.packetHistory.Len()
// Don't send any packets if we're keeping track of the maximum number of packets.
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
// but still allow sending of retransmissions and ACKs.
if numTrackedPackets >= protocol.MaxTrackedSentPackets {
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
return SendNone
}
// Workaround for #555:
// Always allow sending of retransmissions. This should probably be limited
// to RTOs, but we currently don't have a nice way of distinguishing them.
haveRetransmissions := len(h.retransmissionQueue) > 0
return !maxTrackedLimited && (!congestionLimited || haveRetransmissions)
if h.numRTOs > 0 {
return SendRTO
}
// Only send ACKs if we're congestion limited.
if cwnd := h.congestion.GetCongestionWindow(); h.bytesInFlight > cwnd {
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
return SendAck
}
// Send retransmissions first, if there are any.
if len(h.retransmissionQueue) > 0 {
return SendRetransmission
}
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
return SendAck
}
return SendAny
}
func (h *sentPacketHandler) TimeUntilSend() time.Time {
@ -386,6 +490,10 @@ func (h *sentPacketHandler) TimeUntilSend() time.Time {
}
func (h *sentPacketHandler) ShouldSendNumPackets() int {
if h.numRTOs > 0 {
// RTO probes should not be paced, but must be sent immediately.
return h.numRTOs
}
delay := h.congestion.TimeUntilSend(h.bytesInFlight)
if delay == 0 || delay > protocol.MinPacingDelay {
return 1
@ -393,45 +501,50 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int {
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
}
func (h *sentPacketHandler) retransmitOldestTwoPackets() {
if p := h.packetHistory.Front(); p != nil {
h.queueRTO(p)
}
if p := h.packetHistory.Front(); p != nil {
h.queueRTO(p)
}
}
func (h *sentPacketHandler) queueRTO(el *PacketElement) {
packet := &el.Value
utils.Debugf(
"\tQueueing packet 0x%x for retransmission (RTO), %d outstanding",
packet.PacketNumber,
h.packetHistory.Len(),
)
h.queuePacketForRetransmission(el)
h.congestion.OnPacketLost(packet.PacketNumber, packet.Length, h.bytesInFlight)
h.congestion.OnRetransmissionTimeout(true)
}
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() {
var handshakePackets []*PacketElement
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
handshakePackets = append(handshakePackets, el)
// retransmit the oldest two packets
func (h *sentPacketHandler) queueRTOs() error {
h.largestSentBeforeRTO = h.lastSentPacketNumber
// Queue the first two outstanding packets for retransmission.
// This does NOT declare this packets as lost:
// They are still tracked in the packet history and count towards the bytes in flight.
for i := 0; i < 2; i++ {
if p := h.packetHistory.FirstOutstanding(); p != nil {
h.logger.Debugf("\tQueueing packet %#x for retransmission (RTO)", p.PacketNumber)
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
}
}
for _, el := range handshakePackets {
h.queuePacketForRetransmission(el)
}
return nil
}
func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) {
packet := &packetElement.Value
h.bytesInFlight -= packet.Length
h.retransmissionQueue = append(h.retransmissionQueue, packet)
h.packetHistory.Remove(packetElement)
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber)
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
var handshakePackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.canBeRetransmitted && p.EncryptionLevel < protocol.EncryptionForwardSecure {
handshakePackets = append(handshakePackets, p)
}
return true, nil
})
for _, p := range handshakePackets {
h.logger.Debugf("\tQueueing packet %#x as a handshake retransmission", p.PacketNumber)
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
}
return nil
}
func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
if !p.canBeRetransmitted {
return fmt.Errorf("sent packet handler BUG: packet %d already queued for retransmission", p.PacketNumber)
}
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
return err
}
h.retransmissionQueue = append(h.retransmissionQueue, p)
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(p.PacketNumber)
return nil
}
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
@ -446,9 +559,12 @@ func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
}
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
rto := h.congestion.RetransmissionDelay()
if rto == 0 {
var rto time.Duration
rtt := h.rttStats.SmoothedRTT()
if rtt == 0 {
rto = defaultRTOTimeout
} else {
rto = rtt + 4*h.rttStats.MeanDeviation()
}
rto = utils.MaxDuration(rto, minRTOTimeout)
// Exponential backoff

View file

@ -0,0 +1,127 @@
package ackhandler
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type sentPacketHistory struct {
packetList *PacketList
packetMap map[protocol.PacketNumber]*PacketElement
firstOutstanding *PacketElement
}
func newSentPacketHistory() *sentPacketHistory {
return &sentPacketHistory{
packetList: NewPacketList(),
packetMap: make(map[protocol.PacketNumber]*PacketElement),
}
}
func (h *sentPacketHistory) SentPacket(p *Packet) {
h.sentPacketImpl(p)
}
func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
el := h.packetList.PushBack(*p)
h.packetMap[p.PacketNumber] = el
if h.firstOutstanding == nil {
h.firstOutstanding = el
}
return el
}
func (h *sentPacketHistory) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
retransmission, ok := h.packetMap[retransmissionOf]
// The retransmitted packet is not present anymore.
// This can happen if it was acked in between dequeueing of the retransmission and sending.
// Just treat the retransmissions as normal packets.
// TODO: This won't happen if we clear packets queued for retransmission on new ACKs.
if !ok {
for _, packet := range packets {
h.sentPacketImpl(packet)
}
return
}
retransmission.Value.retransmittedAs = make([]protocol.PacketNumber, len(packets))
for i, packet := range packets {
retransmission.Value.retransmittedAs[i] = packet.PacketNumber
el := h.sentPacketImpl(packet)
el.Value.isRetransmission = true
el.Value.retransmissionOf = retransmissionOf
}
}
func (h *sentPacketHistory) GetPacket(p protocol.PacketNumber) *Packet {
if el, ok := h.packetMap[p]; ok {
return &el.Value
}
return nil
}
// Iterate iterates through all packets.
// The callback must not modify the history.
func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error {
cont := true
for el := h.packetList.Front(); cont && el != nil; el = el.Next() {
var err error
cont, err = cb(&el.Value)
if err != nil {
return err
}
}
return nil
}
// FirstOutStanding returns the first outstanding packet.
// It must not be modified (e.g. retransmitted).
// Use DequeueFirstPacketForRetransmission() to retransmit it.
func (h *sentPacketHistory) FirstOutstanding() *Packet {
if h.firstOutstanding == nil {
return nil
}
return &h.firstOutstanding.Value
}
// QueuePacketForRetransmission marks a packet for retransmission.
// A packet can only be queued once.
func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber) error {
el, ok := h.packetMap[pn]
if !ok {
return fmt.Errorf("sent packet history: packet %d not found", pn)
}
el.Value.canBeRetransmitted = false
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
return nil
}
// readjustFirstOutstanding readjusts the pointer to the first outstanding packet.
// This is necessary every time the first outstanding packet is deleted or retransmitted.
func (h *sentPacketHistory) readjustFirstOutstanding() {
el := h.firstOutstanding.Next()
for el != nil && !el.Value.canBeRetransmitted {
el = el.Next()
}
h.firstOutstanding = el
}
func (h *sentPacketHistory) Len() int {
return len(h.packetMap)
}
func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
el, ok := h.packetMap[p]
if !ok {
return fmt.Errorf("packet %d not found in sent packet history", p)
}
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
h.packetList.Remove(el)
delete(h.packetMap, p)
return nil
}

View file

@ -292,11 +292,3 @@ func (c *cubicSender) OnConnectionMigration() {
func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) {
c.slowStartLargeReduction = enabled
}
// RetransmissionDelay gives the time to retransmission
func (c *cubicSender) RetransmissionDelay() time.Duration {
if c.rttStats.SmoothedRTT() == 0 {
return 0
}
return c.rttStats.SmoothedRTT() + c.rttStats.MeanDeviation()*4
}

View file

@ -17,7 +17,6 @@ type SendAlgorithm interface {
SetNumEmulatedConnections(n int)
OnRetransmissionTimeout(packetsRetransmitted bool)
OnConnectionMigration()
RetransmissionDelay() time.Duration
// Experiments
SetSlowStartLargeReduction(enabled bool)

View file

@ -84,7 +84,6 @@ func (r *RTTStats) SetRecentMinRTTwindow(recentMinRTTwindow time.Duration) {
// UpdateRTT updates the RTT based on a new sample.
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
if sendDelta == utils.InfDuration || sendDelta <= 0 {
utils.Debugf("Ignoring measured sendDelta, because it's is either infinite, zero, or negative: %d", sendDelta/time.Microsecond)
return
}

View file

@ -55,28 +55,28 @@ func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
return cert.Certificate[0], nil
}
func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
c := cc.config
c, err := maybeGetConfigForClient(c, sni)
func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
conf := c.config
conf, err := maybeGetConfigForClient(conf, sni)
if err != nil {
return nil, err
}
// The rest of this function is mostly copied from crypto/tls.getCertificate
if c.GetCertificate != nil {
cert, err := c.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
if conf.GetCertificate != nil {
cert, err := conf.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
if cert != nil || err != nil {
return cert, err
}
}
if len(c.Certificates) == 0 {
if len(conf.Certificates) == 0 {
return nil, errNoMatchingCertificate
}
if len(c.Certificates) == 1 || c.NameToCertificate == nil {
if len(conf.Certificates) == 1 || conf.NameToCertificate == nil {
// There's only one choice, so no point doing any work.
return &c.Certificates[0], nil
return &conf.Certificates[0], nil
}
name := strings.ToLower(sni)
@ -84,7 +84,7 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
name = name[:len(name)-1]
}
if cert, ok := c.NameToCertificate[name]; ok {
if cert, ok := conf.NameToCertificate[name]; ok {
return cert, nil
}
@ -94,13 +94,13 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if cert, ok := c.NameToCertificate[candidate]; ok {
if cert, ok := conf.NameToCertificate[candidate]; ok {
return cert, nil
}
}
// If nothing matches, return the first certificate.
return &c.Certificates[0], nil
return &conf.Certificates[0], nil
}
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {

View file

@ -21,10 +21,6 @@ func NewCurve25519KEX() (KeyExchange, error) {
if _, err := rand.Read(c.secret[:]); err != nil {
return nil, errors.New("Curve25519: could not create private key")
}
// See https://cr.yp.to/ecdh.html
c.secret[0] &= 248
c.secret[31] &= 127
c.secret[31] |= 64
curve25519.ScalarBaseMult(&c.public, &c.secret)
return c, nil
}

View file

@ -1,13 +1,16 @@
package crypto
import (
"crypto"
"encoding/binary"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
const (
clientExporterLabel = "EXPORTER-QUIC client 1-RTT Secret"
serverExporterLabel = "EXPORTER-QUIC server 1-RTT Secret"
clientExporterLabel = "EXPORTER-QUIC client 1rtt"
serverExporterLabel = "EXPORTER-QUIC server 1rtt"
)
// A TLSExporter gets the negotiated ciphersuite and computes exporter
@ -16,6 +19,14 @@ type TLSExporter interface {
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
}
func qhkdfExpand(secret []byte, label string, length int) []byte {
qlabel := make([]byte, 2+1+5+len(label))
binary.BigEndian.PutUint16(qlabel[0:2], uint16(length))
qlabel[2] = uint8(5 + len(label))
copy(qlabel[3:], []byte("QUIC "+label))
return mint.HkdfExpand(crypto.SHA256, secret, qlabel, length)
}
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
var myLabel, otherLabel string
@ -43,7 +54,7 @@ func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error)
if err != nil {
return nil, nil, err
}
key = mint.HkdfExpandLabel(cs.Hash, secret, "key", nil, cs.KeyLen)
iv = mint.HkdfExpandLabel(cs.Hash, secret, "iv", nil, cs.IvLen)
key = qhkdfExpand(secret, "key", cs.KeyLen)
iv = qhkdfExpand(secret, "iv", cs.IvLen)
return key, iv, nil
}

View file

@ -8,7 +8,7 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var quicVersion1Salt = []byte{0xaf, 0xc8, 0x24, 0xec, 0x5f, 0xc7, 0x7e, 0xca, 0x1e, 0x9d, 0x36, 0xf3, 0x7f, 0xb2, 0xd4, 0x65, 0x18, 0xc3, 0x66, 0x39}
var quicVersion1Salt = []byte{0x9c, 0x10, 0x8f, 0x98, 0x52, 0x0a, 0x5c, 0x5c, 0x32, 0x96, 0x8e, 0x95, 0x0e, 0x8a, 0x2c, 0x5f, 0xe0, 0x6d, 0x6c, 0x38}
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
clientSecret, serverSecret := computeSecrets(connectionID)
@ -31,14 +31,14 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec
func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
connID := make([]byte, 8)
binary.BigEndian.PutUint64(connID, uint64(connectionID))
cleartextSecret := mint.HkdfExtract(crypto.SHA256, []byte(quicVersion1Salt), connID)
clientSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC client cleartext Secret", []byte{}, crypto.SHA256.Size())
serverSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC server cleartext Secret", []byte{}, crypto.SHA256.Size())
handshakeSecret := mint.HkdfExtract(crypto.SHA256, quicVersion1Salt, connID)
clientSecret = qhkdfExpand(handshakeSecret, "client hs", crypto.SHA256.Size())
serverSecret = qhkdfExpand(handshakeSecret, "server hs", crypto.SHA256.Size())
return
}
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) {
key = mint.HkdfExpandLabel(crypto.SHA256, secret, "key", nil, 16)
iv = mint.HkdfExpandLabel(crypto.SHA256, secret, "iv", nil, 12)
key = qhkdfExpand(secret, "key", 16)
iv = qhkdfExpand(secret, "iv", 12)
return
}

View file

@ -1,10 +1,11 @@
package crypto
import (
"encoding/binary"
"bytes"
"errors"
"fmt"
"hash/fnv"
"github.com/lucas-clemente/fnv128a"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
@ -21,7 +22,7 @@ func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumb
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
}
hash := fnv128a.New()
hash := fnv.New128a()
hash.Write(associatedData)
hash.Write(src[12:])
if n.perspective == protocol.PerspectiveServer {
@ -29,13 +30,13 @@ func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumb
} else {
hash.Write([]byte("Server"))
}
testHigh, testLow := hash.Sum128()
sum := make([]byte, 0, 16)
sum = hash.Sum(sum)
// The tag is written in little endian, so we need to reverse the slice.
reverse(sum)
low := binary.LittleEndian.Uint64(src)
high := binary.LittleEndian.Uint32(src[8:])
if uint32(testHigh&0xffffffff) != high || testLow != low {
return nil, errors.New("NullAEAD: failed to authenticate received data")
if !bytes.Equal(sum[:12], src[:12]) {
return nil, fmt.Errorf("NullAEAD: failed to authenticate received data (%#v vs %#v)", sum[:12], src[:12])
}
return src[12:], nil
}
@ -48,7 +49,7 @@ func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumb
dst = dst[:12+len(src)]
}
hash := fnv128a.New()
hash := fnv.New128a()
hash.Write(associatedData)
hash.Write(src)
@ -57,15 +58,22 @@ func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumb
} else {
hash.Write([]byte("Client"))
}
high, low := hash.Sum128()
sum := make([]byte, 0, 16)
sum = hash.Sum(sum)
// The tag is written in little endian, so we need to reverse the slice.
reverse(sum)
copy(dst[12:], src)
binary.LittleEndian.PutUint64(dst, low)
binary.LittleEndian.PutUint32(dst[8:], uint32(high))
copy(dst, sum[:12])
return dst
}
func (n *nullAEADFNV128a) Overhead() int {
return 12
}
func reverse(a []byte) {
for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 {
a[left], a[right] = a[right], a[left]
}
}

View file

@ -25,6 +25,8 @@ type baseFlowController struct {
epochStartTime time.Time
epochStartOffset protocol.ByteCount
rttStats *congestion.RTTStats
logger utils.Logger
}
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {

View file

@ -22,6 +22,7 @@ func NewConnectionFlowController(
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
rttStats *congestion.RTTStats,
logger utils.Logger,
) ConnectionFlowController {
return &connectionFlowController{
baseFlowController: baseFlowController{
@ -29,6 +30,7 @@ func NewConnectionFlowController(
receiveWindow: receiveWindow,
receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow,
logger: logger,
},
}
}
@ -65,7 +67,7 @@ func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if oldWindowSize < c.receiveWindowSize {
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
}
c.mutex.Unlock()
return offset

View file

@ -31,6 +31,7 @@ func NewStreamFlowController(
maxReceiveWindow protocol.ByteCount,
initialSendWindow protocol.ByteCount,
rttStats *congestion.RTTStats,
logger utils.Logger,
) StreamFlowController {
return &streamFlowController{
streamID: streamID,
@ -42,6 +43,7 @@ func NewStreamFlowController(
receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow,
sendWindow: initialSendWindow,
logger: logger,
},
}
}
@ -137,7 +139,7 @@ func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
if c.contributesToConnection {
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
}

View file

@ -7,15 +7,20 @@ import (
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A CookieHandler generates and validates cookies.
// The cookie is sent in the TLS Retry.
// By including the cookie in its ClientHello, a client can proof ownership of its source address.
type CookieHandler struct {
callback func(net.Addr, *Cookie) bool
callback func(net.Addr, *Cookie) bool
cookieGenerator *CookieGenerator
logger utils.Logger
}
var _ mint.CookieHandler = &CookieHandler{}
func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, error) {
// NewCookieHandler creates a new CookieHandler.
func NewCookieHandler(callback func(net.Addr, *Cookie) bool, logger utils.Logger) (*CookieHandler, error) {
cookieGenerator, err := NewCookieGenerator()
if err != nil {
return nil, err
@ -23,9 +28,11 @@ func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, er
return &CookieHandler{
callback: callback,
cookieGenerator: cookieGenerator,
logger: logger,
}, nil
}
// Generate a new cookie for a mint connection.
func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
if h.callback(conn.RemoteAddr(), nil) {
return nil, nil
@ -33,10 +40,11 @@ func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
return h.cookieGenerator.NewToken(conn.RemoteAddr())
}
// Validate a cookie.
func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
data, err := h.cookieGenerator.DecodeToken(token)
if err != nil {
utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
h.logger.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
return false
}
return h.callback(conn.RemoteAddr(), data)

View file

@ -38,13 +38,12 @@ type cryptoSetupClient struct {
lastSentCHLO []byte
certManager crypto.CertManager
divNonceChan chan []byte
divNonceChan <-chan []byte
diversificationNonce []byte
clientHelloCounter int
serverVerified bool // has the certificate chain and the proof already been verified
keyDerivation QuicCryptoKeyDerivationFunction
keyExchange KeyExchangeFunction
receivedSecurePacket bool
nullAEAD crypto.AEAD
@ -55,6 +54,8 @@ type cryptoSetupClient struct {
handshakeEvent chan<- struct{}
params *TransportParameters
logger utils.Logger
}
var _ CryptoSetup = &cryptoSetupClient{}
@ -77,12 +78,14 @@ func NewCryptoSetupClient(
handshakeEvent chan<- struct{},
initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber,
) (CryptoSetup, error) {
logger utils.Logger,
) (CryptoSetup, chan<- []byte, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil {
return nil, err
return nil, nil, err
}
return &cryptoSetupClient{
divNonceChan := make(chan []byte)
cs := &cryptoSetupClient{
cryptoStream: cryptoStream,
hostname: hostname,
connID: connID,
@ -90,14 +93,15 @@ func NewCryptoSetupClient(
certManager: crypto.NewCertManager(tlsConfig),
params: params,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
keyExchange: getEphermalKEX,
nullAEAD: nullAEAD,
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
initialVersion: initialVersion,
negotiatedVersions: negotiatedVersions,
divNonceChan: make(chan []byte),
}, nil
divNonceChan: divNonceChan,
logger: logger,
}
return cs, divNonceChan, nil
}
func (h *cryptoSetupClient) HandleCryptoStream() error {
@ -146,7 +150,7 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
return err
}
utils.Debugf("Got %s", message)
h.logger.Debugf("Got %s", message)
switch message.Tag {
case TagREJ:
if err := h.handleREJMessage(message.Data); err != nil {
@ -211,7 +215,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
err = h.certManager.Verify(h.hostname)
if err != nil {
utils.Infof("Certificate validation failed: %s", err.Error())
h.logger.Infof("Certificate validation failed: %s", err.Error())
return qerr.ProofInvalid
}
}
@ -219,7 +223,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil {
validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get())
if !validProof {
utils.Infof("Server proof verification failed")
h.logger.Infof("Server proof verification failed")
return qerr.ProofInvalid
}
@ -373,14 +377,6 @@ func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.Encry
return nil, errors.New("CryptoSetupClient: no encryption level specified")
}
func (h *cryptoSetupClient) DiversificationNonce() []byte {
panic("not needed for cryptoSetupClient")
}
func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
h.divNonceChan <- data
}
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
@ -408,7 +404,7 @@ func (h *cryptoSetupClient) sendCHLO() error {
Data: tags,
}
utils.Debugf("Sending %s", message)
h.logger.Debugf("Sending %s", message)
message.Write(b)
_, err = h.cryptoStream.Write(b.Bytes())

View file

@ -19,7 +19,7 @@ import (
type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error)
// KeyExchangeFunction is used to make a new KEX
type KeyExchangeFunction func() crypto.KeyExchange
type KeyExchangeFunction func() (crypto.KeyExchange, error)
// The CryptoSetupServer handles all things crypto for the Session
type cryptoSetupServer struct {
@ -54,6 +54,8 @@ type cryptoSetupServer struct {
params *TransportParameters
sni string // need to fill out the ConnectionState
logger utils.Logger
}
var _ CryptoSetup = &cryptoSetupServer{}
@ -73,32 +75,36 @@ func NewCryptoSetup(
connID protocol.ConnectionID,
remoteAddr net.Addr,
version protocol.VersionNumber,
divNonce []byte,
scfg *ServerConfig,
params *TransportParameters,
supportedVersions []protocol.VersionNumber,
acceptSTK func(net.Addr, *Cookie) bool,
paramsChan chan<- TransportParameters,
handshakeEvent chan<- struct{},
logger utils.Logger,
) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil {
return nil, err
}
return &cryptoSetupServer{
cryptoStream: cryptoStream,
connID: connID,
remoteAddr: remoteAddr,
version: version,
supportedVersions: supportedVersions,
scfg: scfg,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
keyExchange: getEphermalKEX,
nullAEAD: nullAEAD,
params: params,
acceptSTKCallback: acceptSTK,
sentSHLO: make(chan struct{}),
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
cryptoStream: cryptoStream,
connID: connID,
remoteAddr: remoteAddr,
version: version,
supportedVersions: supportedVersions,
diversificationNonce: divNonce,
scfg: scfg,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
keyExchange: getEphermalKEX,
nullAEAD: nullAEAD,
params: params,
acceptSTKCallback: acceptSTK,
sentSHLO: make(chan struct{}),
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
logger: logger,
}, nil
}
@ -114,7 +120,7 @@ func (h *cryptoSetupServer) HandleCryptoStream() error {
return qerr.InvalidCryptoMessageType
}
utils.Debugf("Got %s", message)
h.logger.Debugf("Got %s", message)
done, err := h.handleMessage(chloData.Bytes(), message.Data)
if err != nil {
return err
@ -297,7 +303,7 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt
func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
stk, err := h.scfg.cookieGenerator.DecodeToken(token)
if err != nil {
utils.Debugf("STK invalid: %s", err.Error())
h.logger.Debugf("STK invalid: %s", err.Error())
return false
}
return h.acceptSTKCallback(h.remoteAddr, stk)
@ -340,7 +346,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa
var serverReply bytes.Buffer
message.Write(&serverReply)
utils.Debugf("Sending %s", message)
h.logger.Debugf("Sending %s", message)
return serverReply.Bytes(), nil
}
@ -364,11 +370,6 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
return nil, err
}
h.diversificationNonce = make([]byte, 32)
if _, err = rand.Read(h.diversificationNonce); err != nil {
return nil, err
}
clientNonce := cryptoData[TagNONC]
err = h.validateClientNonce(clientNonce)
if err != nil {
@ -405,7 +406,10 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
var fsNonce bytes.Buffer
fsNonce.Write(clientNonce)
fsNonce.Write(serverNonce)
ephermalKex := h.keyExchange()
ephermalKex, err := h.keyExchange()
if err != nil {
return nil, err
}
ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS])
if err != nil {
return nil, err
@ -429,7 +433,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
replyMap := h.params.getHelloMap()
// add crypto parameters
verTag := &bytes.Buffer{}
for _, v := range protocol.GetGreasedVersions(h.supportedVersions) {
for _, v := range h.supportedVersions {
utils.BigEndian.WriteUint32(verTag, uint32(v))
}
replyMap[TagPUBS] = ephermalKex.PublicKey()
@ -443,19 +447,10 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
}
var reply bytes.Buffer
message.Write(&reply)
utils.Debugf("Sending %s", message)
h.logger.Debugf("Sending %s", message)
return reply.Bytes(), nil
}
// DiversificationNonce returns the diversification nonce
func (h *cryptoSetupServer) DiversificationNonce() []byte {
return h.diversificationNonce
}
func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
panic("not needed for cryptoSetupServer")
}
func (h *cryptoSetupServer) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()

View file

@ -31,6 +31,8 @@ type cryptoSetupTLS struct {
handshakeEvent chan<- struct{}
}
var _ CryptoSetupTLS = &cryptoSetupTLS{}
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
func NewCryptoSetupTLSServer(
tls MintTLS,
@ -38,7 +40,7 @@ func NewCryptoSetupTLSServer(
nullAEAD crypto.AEAD,
handshakeEvent chan<- struct{},
version protocol.VersionNumber,
) CryptoSetup {
) CryptoSetupTLS {
return &cryptoSetupTLS{
tls: tls,
cryptoStream: cryptoStream,
@ -57,7 +59,7 @@ func NewCryptoSetupTLSClient(
handshakeEvent chan<- struct{},
tls MintTLS,
version protocol.VersionNumber,
) (CryptoSetup, error) {
) (CryptoSetupTLS, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil {
return nil, err
@ -107,22 +109,18 @@ handshakeLoop:
return nil
}
func (h *cryptoSetupTLS) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
func (h *cryptoSetupTLS) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return h.nullAEAD.Open(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupTLS) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.aead != nil {
data, err := h.aead.Open(dst, src, packetNumber, associatedData)
if err != nil {
return nil, protocol.EncryptionUnspecified, err
}
return data, protocol.EncryptionForwardSecure, nil
if h.aead == nil {
return nil, errors.New("no 1-RTT sealer")
}
data, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
if err != nil {
return nil, protocol.EncryptionUnspecified, err
}
return data, protocol.EncryptionUnencrypted, nil
return h.aead.Open(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) {
@ -157,14 +155,6 @@ func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, S
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupTLS) DiversificationNonce() []byte {
panic("diversification nonce not needed for TLS")
}
func (h *cryptoSetupTLS) SetDiversificationNonce([]byte) {
panic("diversification nonce not needed for TLS")
}
func (h *cryptoSetupTLS) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()

View file

@ -6,7 +6,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
var (
@ -24,13 +23,13 @@ var (
// used for all connections for 60 seconds is negligible. Thus we can amortise
// the Diffie-Hellman key generation at the server over all the connections in a
// small time span.
func getEphermalKEX() (res crypto.KeyExchange) {
func getEphermalKEX() (crypto.KeyExchange, error) {
kexMutex.RLock()
res = kexCurrent
res := kexCurrent
t := kexCurrentTime
kexMutex.RUnlock()
if res != nil && time.Since(t) < kexLifetime {
return res
return res, nil
}
kexMutex.Lock()
@ -39,12 +38,11 @@ func getEphermalKEX() (res crypto.KeyExchange) {
if kexCurrent == nil || time.Since(kexCurrentTime) > kexLifetime {
kex, err := crypto.NewCurve25519KEX()
if err != nil {
utils.Errorf("could not set KEX: %s", err.Error())
return kexCurrent
return nil, err
}
kexCurrent = kex
kexCurrentTime = time.Now()
return kexCurrent
return kexCurrent, nil
}
return kexCurrent
return kexCurrent, nil
}

View file

@ -84,7 +84,7 @@ func (h HandshakeMessage) Write(b *bytes.Buffer) {
offset := uint32(0)
for i, t := range h.getTagsSorted() {
v := data[Tag(t)]
v := data[t]
b.Write(v)
offset += uint32(len(v))
binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t))
@ -111,8 +111,7 @@ func (h *HandshakeMessage) getTagsSorted() []Tag {
func (h HandshakeMessage) String() string {
var pad string
res := tagToString(h.Tag) + ":\n"
for _, t := range h.getTagsSorted() {
tag := Tag(t)
for _, tag := range h.getTagsSorted() {
if tag == TagPAD {
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag]))
} else {

View file

@ -35,13 +35,8 @@ type MintTLS interface {
SetCryptoStream(io.ReadWriter)
}
// CryptoSetup is a crypto setup
type CryptoSetup interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
type baseCryptoSetup interface {
HandleCryptoStream() error
// TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
ConnectionState() ConnectionState
GetSealer() (protocol.EncryptionLevel, Sealer)
@ -49,6 +44,21 @@ type CryptoSetup interface {
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
}
// CryptoSetup is the crypto setup used by gQUIC
type CryptoSetup interface {
baseCryptoSetup
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
}
// CryptoSetupTLS is the crypto setup used by IETF QUIC
type CryptoSetupTLS interface {
baseCryptoSetup
OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
}
// ConnectionState records basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon.
type ConnectionState struct {

View file

@ -9,10 +9,10 @@ import (
// ServerConfig is a server config
type ServerConfig struct {
kex crypto.KeyExchange
certChain crypto.CertChain
ID []byte
obit []byte
kex crypto.KeyExchange
certChain crypto.CertChain
ID []byte
obit []byte
cookieGenerator *CookieGenerator
}
@ -36,10 +36,10 @@ func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*Serve
}
return &ServerConfig{
kex: kex,
certChain: certChain,
ID: id,
obit: obit,
kex: kex,
certChain: certChain,
ID: id,
obit: obit,
cookieGenerator: cookieGenerator,
}, nil
}

View file

@ -102,32 +102,37 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
}
var pubs_kexs []struct{Length uint32; Value []byte}
var last_len uint32
for i := 0; i < len(pubs)-3; i += int(last_len)+3 {
var pubsKexs []struct {
Length uint32
Value []byte
}
var lastLen uint32
for i := 0; i < len(pubs)-3; i += int(lastLen) + 3 {
// the PUBS value is always prepended by 3 byte little endian length field
err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &last_len);
err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &lastLen)
if err != nil {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable")
}
if last_len == 0 {
if lastLen == 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
if i+3+int(last_len) > len(pubs) {
if i+3+int(lastLen) > len(pubs) {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
pubs_kexs = append(pubs_kexs, struct{Length uint32; Value []byte}{last_len, pubs[i+3:i+3+int(last_len)]})
pubsKexs = append(pubsKexs, struct {
Length uint32
Value []byte
}{lastLen, pubs[i+3 : i+3+int(lastLen)]})
}
if c255Foundat >= len(pubs_kexs) {
if c255Foundat >= len(pubsKexs) {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS")
}
if pubs_kexs[c255Foundat].Length != 32 {
if pubsKexs[c255Foundat].Length != 32 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
@ -137,8 +142,7 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
return err
}
s.sharedSecret, err = s.kex.CalculateSharedKey(pubs_kexs[c255Foundat].Value)
s.sharedSecret, err = s.kex.CalculateSharedKey(pubsKexs[c255Foundat].Value)
if err != nil {
return err
}

View file

@ -9,14 +9,14 @@ type transportParameterID uint16
const quicTLSExtensionType = 26
const (
initialMaxStreamDataParameterID transportParameterID = 0x0
initialMaxDataParameterID transportParameterID = 0x1
initialMaxStreamIDBiDiParameterID transportParameterID = 0x2
idleTimeoutParameterID transportParameterID = 0x3
omitConnectionIDParameterID transportParameterID = 0x4
maxPacketSizeParameterID transportParameterID = 0x5
statelessResetTokenParameterID transportParameterID = 0x6
initialMaxStreamIDUniParameterID transportParameterID = 0x8
initialMaxStreamDataParameterID transportParameterID = 0x0
initialMaxDataParameterID transportParameterID = 0x1
initialMaxStreamsBiDiParameterID transportParameterID = 0x2
idleTimeoutParameterID transportParameterID = 0x3
omitConnectionIDParameterID transportParameterID = 0x4
maxPacketSizeParameterID transportParameterID = 0x5
statelessResetTokenParameterID transportParameterID = 0x6
initialMaxStreamsUniParameterID transportParameterID = 0x8
)
type transportParameter struct {

Some files were not shown because too many files have changed in this diff Show more