Keep type information with placeholders until replacements happen

This commit is contained in:
Matthew Holt 2020-03-30 11:49:53 -06:00
parent deba26d225
commit 105acfa086
No known key found for this signature in database
GPG key ID: 2A349DD577D586A5
7 changed files with 184 additions and 109 deletions

View file

@ -86,7 +86,7 @@ func (m *MatchExpression) Provision(_ caddy.Context) error {
decls.NewFunction(placeholderFuncName, decls.NewFunction(placeholderFuncName,
decls.NewOverload(placeholderFuncName+"_httpRequest_string", decls.NewOverload(placeholderFuncName+"_httpRequest_string",
[]*exprpb.Type{httpRequestObjectType, decls.String}, []*exprpb.Type{httpRequestObjectType, decls.String},
decls.String)), decls.Any)),
), ),
cel.CustomTypeAdapter(celHTTPRequestTypeAdapter{}), cel.CustomTypeAdapter(celHTTPRequestTypeAdapter{}),
ext.Strings(), ext.Strings(),
@ -210,7 +210,35 @@ func caddyPlaceholderFunc(lhs, rhs ref.Val) ref.Val {
repl := celReq.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) repl := celReq.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
val, _ := repl.Get(string(phStr)) val, _ := repl.Get(string(phStr))
return types.String(val) // TODO: this is... kinda awful and underwhelming, how can we expand CEL's type system more easily?
switch v := val.(type) {
case string:
return types.String(v)
case fmt.Stringer:
return types.String(v.String())
case error:
return types.NewErr(v.Error())
case int:
return types.Int(v)
case int32:
return types.Int(v)
case int64:
return types.Int(v)
case uint:
return types.Int(v)
case uint32:
return types.Int(v)
case uint64:
return types.Int(v)
case float32:
return types.Double(v)
case float64:
return types.Double(v)
case bool:
return types.Bool(v)
default:
return types.String(fmt.Sprintf("%+v", v))
}
} }
// Interface guards // Interface guards

View file

@ -31,7 +31,7 @@ import (
) )
func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.ResponseWriter) { func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.ResponseWriter) {
httpVars := func(key string) (string, bool) { httpVars := func(key string) (interface{}, bool) {
if req != nil { if req != nil {
// query string parameters // query string parameters
if strings.HasPrefix(key, reqURIQueryReplPrefix) { if strings.HasPrefix(key, reqURIQueryReplPrefix) {
@ -62,7 +62,7 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo
} }
} }
// http.request.tls. // http.request.tls.*
if strings.HasPrefix(key, reqTLSReplPrefix) { if strings.HasPrefix(key, reqTLSReplPrefix) {
return getReqTLSReplacement(req, key) return getReqTLSReplacement(req, key)
} }
@ -182,21 +182,10 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo
if strings.HasPrefix(key, varsReplPrefix) { if strings.HasPrefix(key, varsReplPrefix) {
varName := key[len(varsReplPrefix):] varName := key[len(varsReplPrefix):]
tbl := req.Context().Value(VarsCtxKey).(map[string]interface{}) tbl := req.Context().Value(VarsCtxKey).(map[string]interface{})
raw, ok := tbl[varName] raw, _ := tbl[varName]
if !ok { // variables can be dynamic, so always return true
// variables can be dynamic, so always return true // even when it may not be set; treat as empty then
// even when it may not be set; treat as empty return raw, true
return "", true
}
// do our best to convert it to a string efficiently
switch val := raw.(type) {
case string:
return val, true
case fmt.Stringer:
return val.String(), true
default:
return fmt.Sprintf("%s", val), true
}
} }
} }
@ -211,19 +200,19 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo
} }
} }
return "", false return nil, false
} }
repl.Map(httpVars) repl.Map(httpVars)
} }
func getReqTLSReplacement(req *http.Request, key string) (string, bool) { func getReqTLSReplacement(req *http.Request, key string) (interface{}, bool) {
if req == nil || req.TLS == nil { if req == nil || req.TLS == nil {
return "", false return nil, false
} }
if len(key) < len(reqTLSReplPrefix) { if len(key) < len(reqTLSReplPrefix) {
return "", false return nil, false
} }
field := strings.ToLower(key[len(reqTLSReplPrefix):]) field := strings.ToLower(key[len(reqTLSReplPrefix):])
@ -231,20 +220,20 @@ func getReqTLSReplacement(req *http.Request, key string) (string, bool) {
if strings.HasPrefix(field, "client.") { if strings.HasPrefix(field, "client.") {
cert := getTLSPeerCert(req.TLS) cert := getTLSPeerCert(req.TLS)
if cert == nil { if cert == nil {
return "", false return nil, false
} }
switch field { switch field {
case "client.fingerprint": case "client.fingerprint":
return fmt.Sprintf("%x", sha256.Sum256(cert.Raw)), true return fmt.Sprintf("%x", sha256.Sum256(cert.Raw)), true
case "client.issuer": case "client.issuer":
return cert.Issuer.String(), true return cert.Issuer, true
case "client.serial": case "client.serial":
return fmt.Sprintf("%x", cert.SerialNumber), true return cert.SerialNumber, true
case "client.subject": case "client.subject":
return cert.Subject.String(), true return cert.Subject, true
default: default:
return "", false return nil, false
} }
} }
@ -254,22 +243,15 @@ func getReqTLSReplacement(req *http.Request, key string) (string, bool) {
case "cipher_suite": case "cipher_suite":
return tls.CipherSuiteName(req.TLS.CipherSuite), true return tls.CipherSuiteName(req.TLS.CipherSuite), true
case "resumed": case "resumed":
if req.TLS.DidResume { return req.TLS.DidResume, true
return "true", true
}
return "false", true
case "proto": case "proto":
return req.TLS.NegotiatedProtocol, true return req.TLS.NegotiatedProtocol, true
case "proto_mutual": case "proto_mutual":
if req.TLS.NegotiatedProtocolIsMutual { return req.TLS.NegotiatedProtocolIsMutual, true
return "true", true
}
return "false", true
case "server_name": case "server_name":
return req.TLS.ServerName, true return req.TLS.ServerName, true
default:
return "", false
} }
return nil, false
} }
// getTLSPeerCert retrieves the first peer certificate from a TLS session. // getTLSPeerCert retrieves the first peer certificate from a TLS session.

View file

@ -24,7 +24,6 @@ import (
"net" "net"
"net/http" "net/http"
"regexp" "regexp"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -328,9 +327,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
repl.Set("http.reverse_proxy.upstream.hostport", dialInfo.Address) repl.Set("http.reverse_proxy.upstream.hostport", dialInfo.Address)
repl.Set("http.reverse_proxy.upstream.host", dialInfo.Host) repl.Set("http.reverse_proxy.upstream.host", dialInfo.Host)
repl.Set("http.reverse_proxy.upstream.port", dialInfo.Port) repl.Set("http.reverse_proxy.upstream.port", dialInfo.Port)
repl.Set("http.reverse_proxy.upstream.requests", strconv.Itoa(upstream.Host.NumRequests())) repl.Set("http.reverse_proxy.upstream.requests", upstream.Host.NumRequests())
repl.Set("http.reverse_proxy.upstream.max_requests", strconv.Itoa(upstream.MaxRequests)) repl.Set("http.reverse_proxy.upstream.max_requests", upstream.MaxRequests)
repl.Set("http.reverse_proxy.upstream.fails", strconv.Itoa(upstream.Host.Fails())) repl.Set("http.reverse_proxy.upstream.fails", upstream.Host.Fails())
// mutate request headers according to this upstream; // mutate request headers according to this upstream;
// because we're in a retry loop, we have to copy // because we're in a retry loop, we have to copy

View file

@ -15,8 +15,10 @@
package rewrite package rewrite
import ( import (
"fmt"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"strings" "strings"
"github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2"
@ -208,11 +210,22 @@ func buildQueryString(qs string, repl *caddy.Replacer) string {
// consume the component and write the result // consume the component and write the result
comp := qs[:end] comp := qs[:end]
comp, _ = repl.ReplaceFunc(comp, func(name, val string) (string, error) { comp, _ = repl.ReplaceFunc(comp, func(name string, val interface{}) (interface{}, error) {
if name == "http.request.uri.query" && wroteVal { if name == "http.request.uri.query" && wroteVal {
return val, nil // already escaped return val, nil // already escaped
} }
return url.QueryEscape(val), nil var valStr string
switch v := val.(type) {
case string:
valStr = v
case fmt.Stringer:
valStr = v.String()
case int:
valStr = strconv.Itoa(v)
default:
valStr = fmt.Sprintf("%+v", v)
}
return url.QueryEscape(valStr), nil
}) })
if end < len(qs) { if end < len(qs) {
end++ // consume delimiter end++ // consume delimiter

View file

@ -21,7 +21,6 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"strings" "strings"
"time" "time"
@ -166,9 +165,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer func() { defer func() {
latency := time.Since(start) latency := time.Since(start)
repl.Set("http.response.status", strconv.Itoa(wrec.Status())) repl.Set("http.response.status", wrec.Status())
repl.Set("http.response.size", strconv.Itoa(wrec.Size())) repl.Set("http.response.size", wrec.Size())
repl.Set("http.response.latency", latency.String()) repl.Set("http.response.latency", latency)
logger := accLog logger := accLog
if s.Logs != nil && s.Logs.LoggerNames != nil { if s.Logs != nil && s.Logs.LoggerNames != nil {
@ -360,9 +359,9 @@ func (*HTTPErrorConfig) WithError(r *http.Request, err error) *http.Request {
// add error values to the replacer // add error values to the replacer
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
repl.Set("http.error", err.Error()) repl.Set("http.error", err)
if handlerErr, ok := err.(HandlerError); ok { if handlerErr, ok := err.(HandlerError); ok {
repl.Set("http.error.status_code", strconv.Itoa(handlerErr.StatusCode)) repl.Set("http.error.status_code", handlerErr.StatusCode)
repl.Set("http.error.status_text", http.StatusText(handlerErr.StatusCode)) repl.Set("http.error.status_text", http.StatusText(handlerErr.StatusCode))
repl.Set("http.error.trace", handlerErr.Trace) repl.Set("http.error.trace", handlerErr.Trace)
repl.Set("http.error.id", handlerErr.ID) repl.Set("http.error.id", handlerErr.ID)

View file

@ -27,7 +27,7 @@ import (
// NewReplacer returns a new Replacer. // NewReplacer returns a new Replacer.
func NewReplacer() *Replacer { func NewReplacer() *Replacer {
rep := &Replacer{ rep := &Replacer{
static: make(map[string]string), static: make(map[string]interface{}),
} }
rep.providers = []ReplacerFunc{ rep.providers = []ReplacerFunc{
globalDefaultReplacements, globalDefaultReplacements,
@ -41,7 +41,7 @@ func NewReplacer() *Replacer {
// use NewReplacer to make one. // use NewReplacer to make one.
type Replacer struct { type Replacer struct {
providers []ReplacerFunc providers []ReplacerFunc
static map[string]string static map[string]interface{}
} }
// Map adds mapFunc to the list of value providers. // Map adds mapFunc to the list of value providers.
@ -51,19 +51,19 @@ func (r *Replacer) Map(mapFunc ReplacerFunc) {
} }
// Set sets a custom variable to a static value. // Set sets a custom variable to a static value.
func (r *Replacer) Set(variable, value string) { func (r *Replacer) Set(variable string, value interface{}) {
r.static[variable] = value r.static[variable] = value
} }
// Get gets a value from the replacer. It returns // Get gets a value from the replacer. It returns
// the value and whether the variable was known. // the value and whether the variable was known.
func (r *Replacer) Get(variable string) (string, bool) { func (r *Replacer) Get(variable string) (interface{}, bool) {
for _, mapFunc := range r.providers { for _, mapFunc := range r.providers {
if val, ok := mapFunc(variable); ok { if val, ok := mapFunc(variable); ok {
return val, true return val, true
} }
} }
return "", false return nil, false
} }
// Delete removes a variable with a static value // Delete removes a variable with a static value
@ -73,9 +73,9 @@ func (r *Replacer) Delete(variable string) {
} }
// fromStatic provides values from r.static. // fromStatic provides values from r.static.
func (r *Replacer) fromStatic(key string) (val string, ok bool) { func (r *Replacer) fromStatic(key string) (interface{}, bool) {
val, ok = r.static[key] val, ok := r.static[key]
return return val, ok
} }
// ReplaceOrErr is like ReplaceAll, but any placeholders // ReplaceOrErr is like ReplaceAll, but any placeholders
@ -102,10 +102,9 @@ func (r *Replacer) ReplaceAll(input, empty string) string {
return out return out
} }
// ReplaceFunc calls ReplaceAll efficiently replaces placeholders in input with // ReplaceFunc is the same as ReplaceAll, but calls f for every
// their values. All placeholders are replaced in the output // replacement to be made, in case f wants to change or inspect
// whether they are recognized or not. Values that are empty // the replacement.
// string will be substituted with empty.
func (r *Replacer) ReplaceFunc(input string, f ReplacementFunc) (string, error) { func (r *Replacer) ReplaceFunc(input string, f ReplacementFunc) (string, error) {
return r.replace(input, "", true, false, false, f) return r.replace(input, "", true, false, false, f)
} }
@ -169,9 +168,8 @@ scan:
return "", fmt.Errorf("unrecognized placeholder %s%s%s", return "", fmt.Errorf("unrecognized placeholder %s%s%s",
string(phOpen), key, string(phClose)) string(phOpen), key, string(phClose))
} else if !treatUnknownAsEmpty { } else if !treatUnknownAsEmpty {
// if treatUnknownAsEmpty is true, we'll // if treatUnknownAsEmpty is true, we'll handle an empty
// handle an empty val later; so only // val later; so only continue otherwise
// continue otherwise
lastWriteCursor = i lastWriteCursor = i
continue continue
} }
@ -186,9 +184,12 @@ scan:
} }
} }
// convert val to a string as efficiently as possible
valStr := toString(val)
// write the value; if it's empty, either return // write the value; if it's empty, either return
// an error or write a default value // an error or write a default value
if val == "" { if valStr == "" {
if errOnEmpty { if errOnEmpty {
return "", fmt.Errorf("evaluated placeholder %s%s%s is empty", return "", fmt.Errorf("evaluated placeholder %s%s%s is empty",
string(phOpen), key, string(phClose)) string(phOpen), key, string(phClose))
@ -196,7 +197,7 @@ scan:
sb.WriteString(empty) sb.WriteString(empty)
} }
} else { } else {
sb.WriteString(val) sb.WriteString(valStr)
} }
// advance cursor to end of placeholder // advance cursor to end of placeholder
@ -210,14 +211,54 @@ scan:
return sb.String(), nil return sb.String(), nil
} }
func toString(val interface{}) string {
switch v := val.(type) {
case nil:
return ""
case string:
return v
case fmt.Stringer:
return v.String()
case byte:
return string(v)
case []byte:
return string(v)
case []rune:
return string(v)
case int:
return strconv.Itoa(v)
case int32:
return strconv.Itoa(int(v))
case int64:
return strconv.Itoa(int(v))
case uint:
return strconv.Itoa(int(v))
case uint32:
return strconv.Itoa(int(v))
case uint64:
return strconv.Itoa(int(v))
case float32:
return strconv.FormatFloat(float64(v), 'f', -1, 32)
case float64:
return strconv.FormatFloat(v, 'f', -1, 64)
case bool:
if v {
return "true"
}
return "false"
default:
return fmt.Sprintf("%+v", v)
}
}
// ReplacerFunc is a function that returns a replacement // ReplacerFunc is a function that returns a replacement
// for the given key along with true if the function is able // for the given key along with true if the function is able
// to service that key (even if the value is blank). If the // to service that key (even if the value is blank). If the
// function does not recognize the key, false should be // function does not recognize the key, false should be
// returned. // returned.
type ReplacerFunc func(key string) (val string, ok bool) type ReplacerFunc func(key string) (interface{}, bool)
func globalDefaultReplacements(key string) (string, bool) { func globalDefaultReplacements(key string) (interface{}, bool) {
// check environment variable // check environment variable
const envPrefix = "env." const envPrefix = "env."
if strings.HasPrefix(key, envPrefix) { if strings.HasPrefix(key, envPrefix) {
@ -241,7 +282,7 @@ func globalDefaultReplacements(key string) (string, bool) {
return strconv.Itoa(nowFunc().Year()), true return strconv.Itoa(nowFunc().Year()), true
} }
return "", false return nil, false
} }
// ReplacementFunc is a function that is called when a // ReplacementFunc is a function that is called when a
@ -250,7 +291,7 @@ func globalDefaultReplacements(key string) (string, bool) {
// will be the replacement, and returns the value that // will be the replacement, and returns the value that
// will actually be the replacement, or an error. Note // will actually be the replacement, or an error. Note
// that errors are sometimes ignored by replacers. // that errors are sometimes ignored by replacers.
type ReplacementFunc func(variable, val string) (string, error) type ReplacementFunc func(variable string, val interface{}) (interface{}, error)
// nowFunc is a variable so tests can change it // nowFunc is a variable so tests can change it
// in order to obtain a deterministic time. // in order to obtain a deterministic time.

View file

@ -173,41 +173,12 @@ func TestReplacer(t *testing.T) {
} }
} }
func BenchmarkReplacer(b *testing.B) {
type testCase struct {
name, input, empty string
}
rep := testReplacer()
for _, bm := range []testCase{
{
name: "no placeholder",
input: `simple string`,
},
{
name: "placeholder",
input: `{"json": "object"}`,
},
{
name: "escaped placeholder",
input: `\{"json": \{"nested": "{bar}"\}\}`,
},
} {
b.Run(bm.name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
rep.ReplaceAll(bm.input, bm.empty)
}
})
}
}
func TestReplacerSet(t *testing.T) { func TestReplacerSet(t *testing.T) {
rep := testReplacer() rep := testReplacer()
for _, tc := range []struct { for _, tc := range []struct {
variable string variable string
value string value interface{}
}{ }{
{ {
variable: "test1", variable: "test1",
@ -217,6 +188,10 @@ func TestReplacerSet(t *testing.T) {
variable: "asdf", variable: "asdf",
value: "123", value: "123",
}, },
{
variable: "numbers",
value: 123.456,
},
{ {
variable: "äöü", variable: "äöü",
value: "öö_äü", value: "öö_äü",
@ -252,7 +227,7 @@ func TestReplacerSet(t *testing.T) {
// test if all keys are still there (by length) // test if all keys are still there (by length)
length := len(rep.static) length := len(rep.static)
if len(rep.static) != 7 { if len(rep.static) != 8 {
t.Errorf("Expected length '%v' got '%v'", 7, length) t.Errorf("Expected length '%v' got '%v'", 7, length)
} }
} }
@ -261,7 +236,7 @@ func TestReplacerReplaceKnown(t *testing.T) {
rep := Replacer{ rep := Replacer{
providers: []ReplacerFunc{ providers: []ReplacerFunc{
// split our possible vars to two functions (to test if both functions are called) // split our possible vars to two functions (to test if both functions are called)
func(key string) (val string, ok bool) { func(key string) (val interface{}, ok bool) {
switch key { switch key {
case "test1": case "test1":
return "val1", true return "val1", true
@ -275,7 +250,7 @@ func TestReplacerReplaceKnown(t *testing.T) {
return "NOOO", false return "NOOO", false
} }
}, },
func(key string) (val string, ok bool) { func(key string) (val interface{}, ok bool) {
switch key { switch key {
case "1": case "1":
return "test-123", true return "test-123", true
@ -331,7 +306,7 @@ func TestReplacerReplaceKnown(t *testing.T) {
func TestReplacerDelete(t *testing.T) { func TestReplacerDelete(t *testing.T) {
rep := Replacer{ rep := Replacer{
static: map[string]string{ static: map[string]interface{}{
"key1": "val1", "key1": "val1",
"key2": "val2", "key2": "val2",
"key3": "val3", "key3": "val3",
@ -366,10 +341,10 @@ func TestReplacerMap(t *testing.T) {
rep := testReplacer() rep := testReplacer()
for i, tc := range []ReplacerFunc{ for i, tc := range []ReplacerFunc{
func(key string) (val string, ok bool) { func(key string) (val interface{}, ok bool) {
return "", false return "", false
}, },
func(key string) (val string, ok bool) { func(key string) (val interface{}, ok bool) {
return "", false return "", false
}, },
} { } {
@ -434,12 +409,50 @@ func TestReplacerNew(t *testing.T) {
} }
} }
} }
}
func BenchmarkReplacer(b *testing.B) {
type testCase struct {
name, input, empty string
}
rep := testReplacer()
rep.Set("str", "a string")
rep.Set("int", 123.456)
for _, bm := range []testCase{
{
name: "no placeholder",
input: `simple string`,
},
{
name: "string replacement",
input: `str={str}`,
},
{
name: "int replacement",
input: `int={int}`,
},
{
name: "placeholder",
input: `{"json": "object"}`,
},
{
name: "escaped placeholder",
input: `\{"json": \{"nested": "{bar}"\}\}`,
},
} {
b.Run(bm.name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
rep.ReplaceAll(bm.input, bm.empty)
}
})
}
} }
func testReplacer() Replacer { func testReplacer() Replacer {
return Replacer{ return Replacer{
providers: make([]ReplacerFunc, 0), providers: make([]ReplacerFunc, 0),
static: make(map[string]string), static: make(map[string]interface{}),
} }
} }