mirror of
https://github.com/caddyserver/caddy.git
synced 2024-12-26 13:43:47 +03:00
Keep type information with placeholders until replacements happen
This commit is contained in:
parent
deba26d225
commit
105acfa086
7 changed files with 184 additions and 109 deletions
|
@ -86,7 +86,7 @@ func (m *MatchExpression) Provision(_ caddy.Context) error {
|
||||||
decls.NewFunction(placeholderFuncName,
|
decls.NewFunction(placeholderFuncName,
|
||||||
decls.NewOverload(placeholderFuncName+"_httpRequest_string",
|
decls.NewOverload(placeholderFuncName+"_httpRequest_string",
|
||||||
[]*exprpb.Type{httpRequestObjectType, decls.String},
|
[]*exprpb.Type{httpRequestObjectType, decls.String},
|
||||||
decls.String)),
|
decls.Any)),
|
||||||
),
|
),
|
||||||
cel.CustomTypeAdapter(celHTTPRequestTypeAdapter{}),
|
cel.CustomTypeAdapter(celHTTPRequestTypeAdapter{}),
|
||||||
ext.Strings(),
|
ext.Strings(),
|
||||||
|
@ -210,7 +210,35 @@ func caddyPlaceholderFunc(lhs, rhs ref.Val) ref.Val {
|
||||||
repl := celReq.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
|
repl := celReq.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
|
||||||
val, _ := repl.Get(string(phStr))
|
val, _ := repl.Get(string(phStr))
|
||||||
|
|
||||||
return types.String(val)
|
// TODO: this is... kinda awful and underwhelming, how can we expand CEL's type system more easily?
|
||||||
|
switch v := val.(type) {
|
||||||
|
case string:
|
||||||
|
return types.String(v)
|
||||||
|
case fmt.Stringer:
|
||||||
|
return types.String(v.String())
|
||||||
|
case error:
|
||||||
|
return types.NewErr(v.Error())
|
||||||
|
case int:
|
||||||
|
return types.Int(v)
|
||||||
|
case int32:
|
||||||
|
return types.Int(v)
|
||||||
|
case int64:
|
||||||
|
return types.Int(v)
|
||||||
|
case uint:
|
||||||
|
return types.Int(v)
|
||||||
|
case uint32:
|
||||||
|
return types.Int(v)
|
||||||
|
case uint64:
|
||||||
|
return types.Int(v)
|
||||||
|
case float32:
|
||||||
|
return types.Double(v)
|
||||||
|
case float64:
|
||||||
|
return types.Double(v)
|
||||||
|
case bool:
|
||||||
|
return types.Bool(v)
|
||||||
|
default:
|
||||||
|
return types.String(fmt.Sprintf("%+v", v))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Interface guards
|
// Interface guards
|
||||||
|
|
|
@ -31,7 +31,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.ResponseWriter) {
|
func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.ResponseWriter) {
|
||||||
httpVars := func(key string) (string, bool) {
|
httpVars := func(key string) (interface{}, bool) {
|
||||||
if req != nil {
|
if req != nil {
|
||||||
// query string parameters
|
// query string parameters
|
||||||
if strings.HasPrefix(key, reqURIQueryReplPrefix) {
|
if strings.HasPrefix(key, reqURIQueryReplPrefix) {
|
||||||
|
@ -62,7 +62,7 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// http.request.tls.
|
// http.request.tls.*
|
||||||
if strings.HasPrefix(key, reqTLSReplPrefix) {
|
if strings.HasPrefix(key, reqTLSReplPrefix) {
|
||||||
return getReqTLSReplacement(req, key)
|
return getReqTLSReplacement(req, key)
|
||||||
}
|
}
|
||||||
|
@ -182,21 +182,10 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo
|
||||||
if strings.HasPrefix(key, varsReplPrefix) {
|
if strings.HasPrefix(key, varsReplPrefix) {
|
||||||
varName := key[len(varsReplPrefix):]
|
varName := key[len(varsReplPrefix):]
|
||||||
tbl := req.Context().Value(VarsCtxKey).(map[string]interface{})
|
tbl := req.Context().Value(VarsCtxKey).(map[string]interface{})
|
||||||
raw, ok := tbl[varName]
|
raw, _ := tbl[varName]
|
||||||
if !ok {
|
// variables can be dynamic, so always return true
|
||||||
// variables can be dynamic, so always return true
|
// even when it may not be set; treat as empty then
|
||||||
// even when it may not be set; treat as empty
|
return raw, true
|
||||||
return "", true
|
|
||||||
}
|
|
||||||
// do our best to convert it to a string efficiently
|
|
||||||
switch val := raw.(type) {
|
|
||||||
case string:
|
|
||||||
return val, true
|
|
||||||
case fmt.Stringer:
|
|
||||||
return val.String(), true
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("%s", val), true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -211,19 +200,19 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
repl.Map(httpVars)
|
repl.Map(httpVars)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getReqTLSReplacement(req *http.Request, key string) (string, bool) {
|
func getReqTLSReplacement(req *http.Request, key string) (interface{}, bool) {
|
||||||
if req == nil || req.TLS == nil {
|
if req == nil || req.TLS == nil {
|
||||||
return "", false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(key) < len(reqTLSReplPrefix) {
|
if len(key) < len(reqTLSReplPrefix) {
|
||||||
return "", false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
field := strings.ToLower(key[len(reqTLSReplPrefix):])
|
field := strings.ToLower(key[len(reqTLSReplPrefix):])
|
||||||
|
@ -231,20 +220,20 @@ func getReqTLSReplacement(req *http.Request, key string) (string, bool) {
|
||||||
if strings.HasPrefix(field, "client.") {
|
if strings.HasPrefix(field, "client.") {
|
||||||
cert := getTLSPeerCert(req.TLS)
|
cert := getTLSPeerCert(req.TLS)
|
||||||
if cert == nil {
|
if cert == nil {
|
||||||
return "", false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
switch field {
|
switch field {
|
||||||
case "client.fingerprint":
|
case "client.fingerprint":
|
||||||
return fmt.Sprintf("%x", sha256.Sum256(cert.Raw)), true
|
return fmt.Sprintf("%x", sha256.Sum256(cert.Raw)), true
|
||||||
case "client.issuer":
|
case "client.issuer":
|
||||||
return cert.Issuer.String(), true
|
return cert.Issuer, true
|
||||||
case "client.serial":
|
case "client.serial":
|
||||||
return fmt.Sprintf("%x", cert.SerialNumber), true
|
return cert.SerialNumber, true
|
||||||
case "client.subject":
|
case "client.subject":
|
||||||
return cert.Subject.String(), true
|
return cert.Subject, true
|
||||||
default:
|
default:
|
||||||
return "", false
|
return nil, false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -254,22 +243,15 @@ func getReqTLSReplacement(req *http.Request, key string) (string, bool) {
|
||||||
case "cipher_suite":
|
case "cipher_suite":
|
||||||
return tls.CipherSuiteName(req.TLS.CipherSuite), true
|
return tls.CipherSuiteName(req.TLS.CipherSuite), true
|
||||||
case "resumed":
|
case "resumed":
|
||||||
if req.TLS.DidResume {
|
return req.TLS.DidResume, true
|
||||||
return "true", true
|
|
||||||
}
|
|
||||||
return "false", true
|
|
||||||
case "proto":
|
case "proto":
|
||||||
return req.TLS.NegotiatedProtocol, true
|
return req.TLS.NegotiatedProtocol, true
|
||||||
case "proto_mutual":
|
case "proto_mutual":
|
||||||
if req.TLS.NegotiatedProtocolIsMutual {
|
return req.TLS.NegotiatedProtocolIsMutual, true
|
||||||
return "true", true
|
|
||||||
}
|
|
||||||
return "false", true
|
|
||||||
case "server_name":
|
case "server_name":
|
||||||
return req.TLS.ServerName, true
|
return req.TLS.ServerName, true
|
||||||
default:
|
|
||||||
return "", false
|
|
||||||
}
|
}
|
||||||
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// getTLSPeerCert retrieves the first peer certificate from a TLS session.
|
// getTLSPeerCert retrieves the first peer certificate from a TLS session.
|
||||||
|
|
|
@ -24,7 +24,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -328,9 +327,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
|
||||||
repl.Set("http.reverse_proxy.upstream.hostport", dialInfo.Address)
|
repl.Set("http.reverse_proxy.upstream.hostport", dialInfo.Address)
|
||||||
repl.Set("http.reverse_proxy.upstream.host", dialInfo.Host)
|
repl.Set("http.reverse_proxy.upstream.host", dialInfo.Host)
|
||||||
repl.Set("http.reverse_proxy.upstream.port", dialInfo.Port)
|
repl.Set("http.reverse_proxy.upstream.port", dialInfo.Port)
|
||||||
repl.Set("http.reverse_proxy.upstream.requests", strconv.Itoa(upstream.Host.NumRequests()))
|
repl.Set("http.reverse_proxy.upstream.requests", upstream.Host.NumRequests())
|
||||||
repl.Set("http.reverse_proxy.upstream.max_requests", strconv.Itoa(upstream.MaxRequests))
|
repl.Set("http.reverse_proxy.upstream.max_requests", upstream.MaxRequests)
|
||||||
repl.Set("http.reverse_proxy.upstream.fails", strconv.Itoa(upstream.Host.Fails()))
|
repl.Set("http.reverse_proxy.upstream.fails", upstream.Host.Fails())
|
||||||
|
|
||||||
// mutate request headers according to this upstream;
|
// mutate request headers according to this upstream;
|
||||||
// because we're in a retry loop, we have to copy
|
// because we're in a retry loop, we have to copy
|
||||||
|
|
|
@ -15,8 +15,10 @@
|
||||||
package rewrite
|
package rewrite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/caddyserver/caddy/v2"
|
"github.com/caddyserver/caddy/v2"
|
||||||
|
@ -208,11 +210,22 @@ func buildQueryString(qs string, repl *caddy.Replacer) string {
|
||||||
|
|
||||||
// consume the component and write the result
|
// consume the component and write the result
|
||||||
comp := qs[:end]
|
comp := qs[:end]
|
||||||
comp, _ = repl.ReplaceFunc(comp, func(name, val string) (string, error) {
|
comp, _ = repl.ReplaceFunc(comp, func(name string, val interface{}) (interface{}, error) {
|
||||||
if name == "http.request.uri.query" && wroteVal {
|
if name == "http.request.uri.query" && wroteVal {
|
||||||
return val, nil // already escaped
|
return val, nil // already escaped
|
||||||
}
|
}
|
||||||
return url.QueryEscape(val), nil
|
var valStr string
|
||||||
|
switch v := val.(type) {
|
||||||
|
case string:
|
||||||
|
valStr = v
|
||||||
|
case fmt.Stringer:
|
||||||
|
valStr = v.String()
|
||||||
|
case int:
|
||||||
|
valStr = strconv.Itoa(v)
|
||||||
|
default:
|
||||||
|
valStr = fmt.Sprintf("%+v", v)
|
||||||
|
}
|
||||||
|
return url.QueryEscape(valStr), nil
|
||||||
})
|
})
|
||||||
if end < len(qs) {
|
if end < len(qs) {
|
||||||
end++ // consume delimiter
|
end++ // consume delimiter
|
||||||
|
|
|
@ -21,7 +21,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -166,9 +165,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
defer func() {
|
defer func() {
|
||||||
latency := time.Since(start)
|
latency := time.Since(start)
|
||||||
|
|
||||||
repl.Set("http.response.status", strconv.Itoa(wrec.Status()))
|
repl.Set("http.response.status", wrec.Status())
|
||||||
repl.Set("http.response.size", strconv.Itoa(wrec.Size()))
|
repl.Set("http.response.size", wrec.Size())
|
||||||
repl.Set("http.response.latency", latency.String())
|
repl.Set("http.response.latency", latency)
|
||||||
|
|
||||||
logger := accLog
|
logger := accLog
|
||||||
if s.Logs != nil && s.Logs.LoggerNames != nil {
|
if s.Logs != nil && s.Logs.LoggerNames != nil {
|
||||||
|
@ -360,9 +359,9 @@ func (*HTTPErrorConfig) WithError(r *http.Request, err error) *http.Request {
|
||||||
|
|
||||||
// add error values to the replacer
|
// add error values to the replacer
|
||||||
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
|
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
|
||||||
repl.Set("http.error", err.Error())
|
repl.Set("http.error", err)
|
||||||
if handlerErr, ok := err.(HandlerError); ok {
|
if handlerErr, ok := err.(HandlerError); ok {
|
||||||
repl.Set("http.error.status_code", strconv.Itoa(handlerErr.StatusCode))
|
repl.Set("http.error.status_code", handlerErr.StatusCode)
|
||||||
repl.Set("http.error.status_text", http.StatusText(handlerErr.StatusCode))
|
repl.Set("http.error.status_text", http.StatusText(handlerErr.StatusCode))
|
||||||
repl.Set("http.error.trace", handlerErr.Trace)
|
repl.Set("http.error.trace", handlerErr.Trace)
|
||||||
repl.Set("http.error.id", handlerErr.ID)
|
repl.Set("http.error.id", handlerErr.ID)
|
||||||
|
|
85
replacer.go
85
replacer.go
|
@ -27,7 +27,7 @@ import (
|
||||||
// NewReplacer returns a new Replacer.
|
// NewReplacer returns a new Replacer.
|
||||||
func NewReplacer() *Replacer {
|
func NewReplacer() *Replacer {
|
||||||
rep := &Replacer{
|
rep := &Replacer{
|
||||||
static: make(map[string]string),
|
static: make(map[string]interface{}),
|
||||||
}
|
}
|
||||||
rep.providers = []ReplacerFunc{
|
rep.providers = []ReplacerFunc{
|
||||||
globalDefaultReplacements,
|
globalDefaultReplacements,
|
||||||
|
@ -41,7 +41,7 @@ func NewReplacer() *Replacer {
|
||||||
// use NewReplacer to make one.
|
// use NewReplacer to make one.
|
||||||
type Replacer struct {
|
type Replacer struct {
|
||||||
providers []ReplacerFunc
|
providers []ReplacerFunc
|
||||||
static map[string]string
|
static map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Map adds mapFunc to the list of value providers.
|
// Map adds mapFunc to the list of value providers.
|
||||||
|
@ -51,19 +51,19 @@ func (r *Replacer) Map(mapFunc ReplacerFunc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set sets a custom variable to a static value.
|
// Set sets a custom variable to a static value.
|
||||||
func (r *Replacer) Set(variable, value string) {
|
func (r *Replacer) Set(variable string, value interface{}) {
|
||||||
r.static[variable] = value
|
r.static[variable] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get gets a value from the replacer. It returns
|
// Get gets a value from the replacer. It returns
|
||||||
// the value and whether the variable was known.
|
// the value and whether the variable was known.
|
||||||
func (r *Replacer) Get(variable string) (string, bool) {
|
func (r *Replacer) Get(variable string) (interface{}, bool) {
|
||||||
for _, mapFunc := range r.providers {
|
for _, mapFunc := range r.providers {
|
||||||
if val, ok := mapFunc(variable); ok {
|
if val, ok := mapFunc(variable); ok {
|
||||||
return val, true
|
return val, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return "", false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete removes a variable with a static value
|
// Delete removes a variable with a static value
|
||||||
|
@ -73,9 +73,9 @@ func (r *Replacer) Delete(variable string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// fromStatic provides values from r.static.
|
// fromStatic provides values from r.static.
|
||||||
func (r *Replacer) fromStatic(key string) (val string, ok bool) {
|
func (r *Replacer) fromStatic(key string) (interface{}, bool) {
|
||||||
val, ok = r.static[key]
|
val, ok := r.static[key]
|
||||||
return
|
return val, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplaceOrErr is like ReplaceAll, but any placeholders
|
// ReplaceOrErr is like ReplaceAll, but any placeholders
|
||||||
|
@ -102,10 +102,9 @@ func (r *Replacer) ReplaceAll(input, empty string) string {
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplaceFunc calls ReplaceAll efficiently replaces placeholders in input with
|
// ReplaceFunc is the same as ReplaceAll, but calls f for every
|
||||||
// their values. All placeholders are replaced in the output
|
// replacement to be made, in case f wants to change or inspect
|
||||||
// whether they are recognized or not. Values that are empty
|
// the replacement.
|
||||||
// string will be substituted with empty.
|
|
||||||
func (r *Replacer) ReplaceFunc(input string, f ReplacementFunc) (string, error) {
|
func (r *Replacer) ReplaceFunc(input string, f ReplacementFunc) (string, error) {
|
||||||
return r.replace(input, "", true, false, false, f)
|
return r.replace(input, "", true, false, false, f)
|
||||||
}
|
}
|
||||||
|
@ -125,7 +124,7 @@ func (r *Replacer) replace(input, empty string,
|
||||||
|
|
||||||
// iterate the input to find each placeholder
|
// iterate the input to find each placeholder
|
||||||
var lastWriteCursor int
|
var lastWriteCursor int
|
||||||
|
|
||||||
scan:
|
scan:
|
||||||
for i := 0; i < len(input); i++ {
|
for i := 0; i < len(input); i++ {
|
||||||
|
|
||||||
|
@ -169,9 +168,8 @@ scan:
|
||||||
return "", fmt.Errorf("unrecognized placeholder %s%s%s",
|
return "", fmt.Errorf("unrecognized placeholder %s%s%s",
|
||||||
string(phOpen), key, string(phClose))
|
string(phOpen), key, string(phClose))
|
||||||
} else if !treatUnknownAsEmpty {
|
} else if !treatUnknownAsEmpty {
|
||||||
// if treatUnknownAsEmpty is true, we'll
|
// if treatUnknownAsEmpty is true, we'll handle an empty
|
||||||
// handle an empty val later; so only
|
// val later; so only continue otherwise
|
||||||
// continue otherwise
|
|
||||||
lastWriteCursor = i
|
lastWriteCursor = i
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -186,9 +184,12 @@ scan:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// convert val to a string as efficiently as possible
|
||||||
|
valStr := toString(val)
|
||||||
|
|
||||||
// write the value; if it's empty, either return
|
// write the value; if it's empty, either return
|
||||||
// an error or write a default value
|
// an error or write a default value
|
||||||
if val == "" {
|
if valStr == "" {
|
||||||
if errOnEmpty {
|
if errOnEmpty {
|
||||||
return "", fmt.Errorf("evaluated placeholder %s%s%s is empty",
|
return "", fmt.Errorf("evaluated placeholder %s%s%s is empty",
|
||||||
string(phOpen), key, string(phClose))
|
string(phOpen), key, string(phClose))
|
||||||
|
@ -196,7 +197,7 @@ scan:
|
||||||
sb.WriteString(empty)
|
sb.WriteString(empty)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
sb.WriteString(val)
|
sb.WriteString(valStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// advance cursor to end of placeholder
|
// advance cursor to end of placeholder
|
||||||
|
@ -210,14 +211,54 @@ scan:
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toString(val interface{}) string {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case nil:
|
||||||
|
return ""
|
||||||
|
case string:
|
||||||
|
return v
|
||||||
|
case fmt.Stringer:
|
||||||
|
return v.String()
|
||||||
|
case byte:
|
||||||
|
return string(v)
|
||||||
|
case []byte:
|
||||||
|
return string(v)
|
||||||
|
case []rune:
|
||||||
|
return string(v)
|
||||||
|
case int:
|
||||||
|
return strconv.Itoa(v)
|
||||||
|
case int32:
|
||||||
|
return strconv.Itoa(int(v))
|
||||||
|
case int64:
|
||||||
|
return strconv.Itoa(int(v))
|
||||||
|
case uint:
|
||||||
|
return strconv.Itoa(int(v))
|
||||||
|
case uint32:
|
||||||
|
return strconv.Itoa(int(v))
|
||||||
|
case uint64:
|
||||||
|
return strconv.Itoa(int(v))
|
||||||
|
case float32:
|
||||||
|
return strconv.FormatFloat(float64(v), 'f', -1, 32)
|
||||||
|
case float64:
|
||||||
|
return strconv.FormatFloat(v, 'f', -1, 64)
|
||||||
|
case bool:
|
||||||
|
if v {
|
||||||
|
return "true"
|
||||||
|
}
|
||||||
|
return "false"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%+v", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ReplacerFunc is a function that returns a replacement
|
// ReplacerFunc is a function that returns a replacement
|
||||||
// for the given key along with true if the function is able
|
// for the given key along with true if the function is able
|
||||||
// to service that key (even if the value is blank). If the
|
// to service that key (even if the value is blank). If the
|
||||||
// function does not recognize the key, false should be
|
// function does not recognize the key, false should be
|
||||||
// returned.
|
// returned.
|
||||||
type ReplacerFunc func(key string) (val string, ok bool)
|
type ReplacerFunc func(key string) (interface{}, bool)
|
||||||
|
|
||||||
func globalDefaultReplacements(key string) (string, bool) {
|
func globalDefaultReplacements(key string) (interface{}, bool) {
|
||||||
// check environment variable
|
// check environment variable
|
||||||
const envPrefix = "env."
|
const envPrefix = "env."
|
||||||
if strings.HasPrefix(key, envPrefix) {
|
if strings.HasPrefix(key, envPrefix) {
|
||||||
|
@ -241,7 +282,7 @@ func globalDefaultReplacements(key string) (string, bool) {
|
||||||
return strconv.Itoa(nowFunc().Year()), true
|
return strconv.Itoa(nowFunc().Year()), true
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplacementFunc is a function that is called when a
|
// ReplacementFunc is a function that is called when a
|
||||||
|
@ -250,7 +291,7 @@ func globalDefaultReplacements(key string) (string, bool) {
|
||||||
// will be the replacement, and returns the value that
|
// will be the replacement, and returns the value that
|
||||||
// will actually be the replacement, or an error. Note
|
// will actually be the replacement, or an error. Note
|
||||||
// that errors are sometimes ignored by replacers.
|
// that errors are sometimes ignored by replacers.
|
||||||
type ReplacementFunc func(variable, val string) (string, error)
|
type ReplacementFunc func(variable string, val interface{}) (interface{}, error)
|
||||||
|
|
||||||
// nowFunc is a variable so tests can change it
|
// nowFunc is a variable so tests can change it
|
||||||
// in order to obtain a deterministic time.
|
// in order to obtain a deterministic time.
|
||||||
|
|
|
@ -173,41 +173,12 @@ func TestReplacer(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkReplacer(b *testing.B) {
|
|
||||||
type testCase struct {
|
|
||||||
name, input, empty string
|
|
||||||
}
|
|
||||||
|
|
||||||
rep := testReplacer()
|
|
||||||
|
|
||||||
for _, bm := range []testCase{
|
|
||||||
{
|
|
||||||
name: "no placeholder",
|
|
||||||
input: `simple string`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "placeholder",
|
|
||||||
input: `{"json": "object"}`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "escaped placeholder",
|
|
||||||
input: `\{"json": \{"nested": "{bar}"\}\}`,
|
|
||||||
},
|
|
||||||
} {
|
|
||||||
b.Run(bm.name, func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
rep.ReplaceAll(bm.input, bm.empty)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReplacerSet(t *testing.T) {
|
func TestReplacerSet(t *testing.T) {
|
||||||
rep := testReplacer()
|
rep := testReplacer()
|
||||||
|
|
||||||
for _, tc := range []struct {
|
for _, tc := range []struct {
|
||||||
variable string
|
variable string
|
||||||
value string
|
value interface{}
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
variable: "test1",
|
variable: "test1",
|
||||||
|
@ -217,6 +188,10 @@ func TestReplacerSet(t *testing.T) {
|
||||||
variable: "asdf",
|
variable: "asdf",
|
||||||
value: "123",
|
value: "123",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
variable: "numbers",
|
||||||
|
value: 123.456,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
variable: "äöü",
|
variable: "äöü",
|
||||||
value: "öö_äü",
|
value: "öö_äü",
|
||||||
|
@ -252,7 +227,7 @@ func TestReplacerSet(t *testing.T) {
|
||||||
|
|
||||||
// test if all keys are still there (by length)
|
// test if all keys are still there (by length)
|
||||||
length := len(rep.static)
|
length := len(rep.static)
|
||||||
if len(rep.static) != 7 {
|
if len(rep.static) != 8 {
|
||||||
t.Errorf("Expected length '%v' got '%v'", 7, length)
|
t.Errorf("Expected length '%v' got '%v'", 7, length)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -261,7 +236,7 @@ func TestReplacerReplaceKnown(t *testing.T) {
|
||||||
rep := Replacer{
|
rep := Replacer{
|
||||||
providers: []ReplacerFunc{
|
providers: []ReplacerFunc{
|
||||||
// split our possible vars to two functions (to test if both functions are called)
|
// split our possible vars to two functions (to test if both functions are called)
|
||||||
func(key string) (val string, ok bool) {
|
func(key string) (val interface{}, ok bool) {
|
||||||
switch key {
|
switch key {
|
||||||
case "test1":
|
case "test1":
|
||||||
return "val1", true
|
return "val1", true
|
||||||
|
@ -275,7 +250,7 @@ func TestReplacerReplaceKnown(t *testing.T) {
|
||||||
return "NOOO", false
|
return "NOOO", false
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
func(key string) (val string, ok bool) {
|
func(key string) (val interface{}, ok bool) {
|
||||||
switch key {
|
switch key {
|
||||||
case "1":
|
case "1":
|
||||||
return "test-123", true
|
return "test-123", true
|
||||||
|
@ -331,7 +306,7 @@ func TestReplacerReplaceKnown(t *testing.T) {
|
||||||
|
|
||||||
func TestReplacerDelete(t *testing.T) {
|
func TestReplacerDelete(t *testing.T) {
|
||||||
rep := Replacer{
|
rep := Replacer{
|
||||||
static: map[string]string{
|
static: map[string]interface{}{
|
||||||
"key1": "val1",
|
"key1": "val1",
|
||||||
"key2": "val2",
|
"key2": "val2",
|
||||||
"key3": "val3",
|
"key3": "val3",
|
||||||
|
@ -366,10 +341,10 @@ func TestReplacerMap(t *testing.T) {
|
||||||
rep := testReplacer()
|
rep := testReplacer()
|
||||||
|
|
||||||
for i, tc := range []ReplacerFunc{
|
for i, tc := range []ReplacerFunc{
|
||||||
func(key string) (val string, ok bool) {
|
func(key string) (val interface{}, ok bool) {
|
||||||
return "", false
|
return "", false
|
||||||
},
|
},
|
||||||
func(key string) (val string, ok bool) {
|
func(key string) (val interface{}, ok bool) {
|
||||||
return "", false
|
return "", false
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
|
@ -434,12 +409,50 @@ func TestReplacerNew(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkReplacer(b *testing.B) {
|
||||||
|
type testCase struct {
|
||||||
|
name, input, empty string
|
||||||
|
}
|
||||||
|
|
||||||
|
rep := testReplacer()
|
||||||
|
rep.Set("str", "a string")
|
||||||
|
rep.Set("int", 123.456)
|
||||||
|
|
||||||
|
for _, bm := range []testCase{
|
||||||
|
{
|
||||||
|
name: "no placeholder",
|
||||||
|
input: `simple string`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string replacement",
|
||||||
|
input: `str={str}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "int replacement",
|
||||||
|
input: `int={int}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "placeholder",
|
||||||
|
input: `{"json": "object"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "escaped placeholder",
|
||||||
|
input: `\{"json": \{"nested": "{bar}"\}\}`,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
b.Run(bm.name, func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
rep.ReplaceAll(bm.input, bm.empty)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testReplacer() Replacer {
|
func testReplacer() Replacer {
|
||||||
return Replacer{
|
return Replacer{
|
||||||
providers: make([]ReplacerFunc, 0),
|
providers: make([]ReplacerFunc, 0),
|
||||||
static: make(map[string]string),
|
static: make(map[string]interface{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue