Export Replacer and use concrete type instead of interface

The interface was only making things difficult; a concrete pointer is
probably best.
This commit is contained in:
Matthew Holt 2019-12-29 13:12:52 -07:00
parent 2b33d9a5e5
commit 95d944613b
No known key found for this signature in database
GPG key ID: 2A349DD577D586A5
19 changed files with 128 additions and 134 deletions

View file

@ -76,7 +76,7 @@ func (a Authentication) ServeHTTP(w http.ResponseWriter, r *http.Request, next c
return caddyhttp.Error(http.StatusUnauthorized, fmt.Errorf("not authenticated")) return caddyhttp.Error(http.StatusUnauthorized, fmt.Errorf("not authenticated"))
} }
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
repl.Set("http.authentication.user.id", user.ID) repl.Set("http.authentication.user.id", user.ID)
return next.ServeHTTP(w, r) return next.ServeHTTP(w, r)

View file

@ -52,7 +52,7 @@ func (fsrv *FileServer) serveBrowse(dirPath string, w http.ResponseWriter, r *ht
} }
defer dir.Close() defer dir.Close()
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
// calling path.Clean here prevents weird breadcrumbs when URL paths are sketchy like /%2e%2e%2f // calling path.Clean here prevents weird breadcrumbs when URL paths are sketchy like /%2e%2e%2f
listing, err := fsrv.loadDirectoryContents(dir, path.Clean(r.URL.Path), repl) listing, err := fsrv.loadDirectoryContents(dir, path.Clean(r.URL.Path), repl)
@ -87,7 +87,7 @@ func (fsrv *FileServer) serveBrowse(dirPath string, w http.ResponseWriter, r *ht
return nil return nil
} }
func (fsrv *FileServer) loadDirectoryContents(dir *os.File, urlPath string, repl caddy.Replacer) (browseListing, error) { func (fsrv *FileServer) loadDirectoryContents(dir *os.File, urlPath string, repl *caddy.Replacer) (browseListing, error) {
files, err := dir.Readdir(-1) files, err := dir.Readdir(-1)
if err != nil { if err != nil {
return browseListing{}, err return browseListing{}, err

View file

@ -27,7 +27,7 @@ import (
"github.com/dustin/go-humanize" "github.com/dustin/go-humanize"
) )
func (fsrv *FileServer) directoryListing(files []os.FileInfo, canGoUp bool, urlPath string, repl caddy.Replacer) browseListing { func (fsrv *FileServer) directoryListing(files []os.FileInfo, canGoUp bool, urlPath string, repl *caddy.Replacer) browseListing {
filesToHide := fsrv.transformHidePaths(repl) filesToHide := fsrv.transformHidePaths(repl)
var ( var (

View file

@ -126,7 +126,7 @@ func (m MatchFile) Validate() error {
// - http.matchers.file.relative // - http.matchers.file.relative
// - http.matchers.file.absolute // - http.matchers.file.absolute
func (m MatchFile) Match(r *http.Request) bool { func (m MatchFile) Match(r *http.Request) bool {
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
rel, abs, matched := m.selectFile(r) rel, abs, matched := m.selectFile(r)
if matched { if matched {
repl.Set("http.matchers.file.relative", rel) repl.Set("http.matchers.file.relative", rel)
@ -140,7 +140,7 @@ func (m MatchFile) Match(r *http.Request) bool {
// It returns the root-relative path to the matched file, the full // It returns the root-relative path to the matched file, the full
// or absolute path, and whether a match was made. // or absolute path, and whether a match was made.
func (m MatchFile) selectFile(r *http.Request) (rel, abs string, matched bool) { func (m MatchFile) selectFile(r *http.Request) (rel, abs string, matched bool) {
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
root := repl.ReplaceAll(m.Root, ".") root := repl.ReplaceAll(m.Root, ".")

View file

@ -106,7 +106,7 @@ func (fsrv *FileServer) Provision(ctx caddy.Context) error {
} }
func (fsrv *FileServer) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { func (fsrv *FileServer) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
filesToHide := fsrv.transformHidePaths(repl) filesToHide := fsrv.transformHidePaths(repl)
@ -293,7 +293,7 @@ func mapDirOpenError(originalErr error, name string) error {
// transformHidePaths performs replacements for all the elements of // transformHidePaths performs replacements for all the elements of
// fsrv.Hide and returns a new list of the transformed values. // fsrv.Hide and returns a new list of the transformed values.
func (fsrv *FileServer) transformHidePaths(repl caddy.Replacer) []string { func (fsrv *FileServer) transformHidePaths(repl *caddy.Replacer) []string {
hide := make([]string, len(fsrv.Hide)) hide := make([]string, len(fsrv.Hide))
for i := range fsrv.Hide { for i := range fsrv.Hide {
hide[i] = repl.ReplaceAll(fsrv.Hide[i], "") hide[i] = repl.ReplaceAll(fsrv.Hide[i], "")

View file

@ -88,7 +88,7 @@ func (h Handler) Validate() error {
} }
func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
if h.Request != nil { if h.Request != nil {
h.Request.ApplyToRequest(r) h.Request.ApplyToRequest(r)
@ -182,7 +182,7 @@ type RespHeaderOps struct {
} }
// ApplyTo applies ops to hdr using repl. // ApplyTo applies ops to hdr using repl.
func (ops HeaderOps) ApplyTo(hdr http.Header, repl caddy.Replacer) { func (ops HeaderOps) ApplyTo(hdr http.Header, repl *caddy.Replacer) {
// add // add
for fieldName, vals := range ops.Add { for fieldName, vals := range ops.Add {
fieldName = repl.ReplaceAll(fieldName, "") fieldName = repl.ReplaceAll(fieldName, "")
@ -249,7 +249,7 @@ func (ops HeaderOps) ApplyTo(hdr http.Header, repl caddy.Replacer) {
// header which the standard library does not include with the // header which the standard library does not include with the
// header map with all the others. This method mutates r.Host. // header map with all the others. This method mutates r.Host.
func (ops HeaderOps) ApplyToRequest(r *http.Request) { func (ops HeaderOps) ApplyToRequest(r *http.Request) {
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
// capture the current Host header so we can // capture the current Host header so we can
// reset to it when we're done // reset to it when we're done
@ -285,7 +285,7 @@ func (ops HeaderOps) ApplyToRequest(r *http.Request) {
// operations until WriteHeader is called. // operations until WriteHeader is called.
type responseWriterWrapper struct { type responseWriterWrapper struct {
*caddyhttp.ResponseWriterWrapper *caddyhttp.ResponseWriterWrapper
replacer caddy.Replacer replacer *caddy.Replacer
require *caddyhttp.ResponseMatcher require *caddyhttp.ResponseMatcher
headerOps *HeaderOps headerOps *HeaderOps
wroteHeader bool wroteHeader bool

View file

@ -118,7 +118,7 @@ func (m MatchHost) Match(r *http.Request) bool {
reqHost = strings.TrimSuffix(reqHost, "]") reqHost = strings.TrimSuffix(reqHost, "]")
} }
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
outer: outer:
for _, host := range m { for _, host := range m {
@ -223,7 +223,7 @@ func (MatchPathRE) CaddyModule() caddy.ModuleInfo {
// Match returns true if r matches m. // Match returns true if r matches m.
func (m MatchPathRE) Match(r *http.Request) bool { func (m MatchPathRE) Match(r *http.Request) bool {
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
return m.MatchRegexp.Match(r.URL.Path, repl) return m.MatchRegexp.Match(r.URL.Path, repl)
} }
@ -380,7 +380,7 @@ func (m *MatchHeaderRE) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
// Match returns true if r matches m. // Match returns true if r matches m.
func (m MatchHeaderRE) Match(r *http.Request) bool { func (m MatchHeaderRE) Match(r *http.Request) bool {
for field, rm := range m { for field, rm := range m {
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
match := rm.Match(r.Header.Get(field), repl) match := rm.Match(r.Header.Get(field), repl)
if !match { if !match {
return false return false
@ -652,10 +652,8 @@ func (mre *MatchRegexp) Validate() error {
// Match returns true if input matches the compiled regular // Match returns true if input matches the compiled regular
// expression in mre. It sets values on the replacer repl // expression in mre. It sets values on the replacer repl
// associated with capture groups, using the given scope // associated with capture groups, using the given scope
// (namespace). Capture groups stored to repl will take on // (namespace).
// the name "http.matchers.<scope>.<mre.Name>.<N>" where func (mre *MatchRegexp) Match(input string, repl *caddy.Replacer) bool {
// <N> is the name or number of the capture group.
func (mre *MatchRegexp) Match(input string, repl caddy.Replacer) bool {
matches := mre.compiled.FindStringSubmatch(input) matches := mre.compiled.FindStringSubmatch(input)
if matches == nil { if matches == nil {
return false return false

View file

@ -26,7 +26,7 @@ import (
"github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2"
) )
func addHTTPVarsToReplacer(repl caddy.Replacer, req *http.Request, w http.ResponseWriter) { func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.ResponseWriter) {
httpVars := func(key string) (string, bool) { httpVars := func(key string) (string, bool) {
if req != nil { if req != nil {
// query string parameters // query string parameters

View file

@ -148,7 +148,7 @@ func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) {
// buildEnv returns a set of CGI environment variables for the request. // buildEnv returns a set of CGI environment variables for the request.
func (t Transport) buildEnv(r *http.Request) (map[string]string, error) { func (t Transport) buildEnv(r *http.Request) (map[string]string, error) {
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
var env map[string]string var env map[string]string

View file

@ -204,7 +204,7 @@ func (di DialInfo) String() string {
// fillDialInfo returns a filled DialInfo for the given upstream, using // fillDialInfo returns a filled DialInfo for the given upstream, using
// the given Replacer. Note that the returned value is not a pointer. // the given Replacer. Note that the returned value is not a pointer.
func fillDialInfo(upstream *Upstream, repl caddy.Replacer) (DialInfo, error) { func fillDialInfo(upstream *Upstream, repl *caddy.Replacer) (DialInfo, error) {
dial := repl.ReplaceAll(upstream.Dial, "") dial := repl.ReplaceAll(upstream.Dial, "")
addr, err := caddy.ParseNetworkAddress(dial) addr, err := caddy.ParseNetworkAddress(dial)
if err != nil { if err != nil {

View file

@ -263,7 +263,7 @@ func (h *Handler) Cleanup() error {
} }
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
// if enabled, buffer client request; // if enabled, buffer client request;
// this should only be enabled if the // this should only be enabled if the
@ -507,7 +507,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, di Dia
if h.Headers != nil && h.Headers.Response != nil { if h.Headers != nil && h.Headers.Response != nil {
if h.Headers.Response.Require == nil || if h.Headers.Response.Require == nil ||
h.Headers.Response.Require.Match(res.StatusCode, rw.Header()) { h.Headers.Response.Require.Match(res.StatusCode, rw.Header()) {
repl := req.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := req.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
h.Headers.Response.ApplyTo(rw.Header(), repl) h.Headers.Response.ApplyTo(rw.Header(), repl)
} }
} }

View file

@ -91,7 +91,7 @@ func (rewr Rewrite) Validate() error {
} }
func (rewr Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { func (rewr Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
logger := rewr.logger.With( logger := rewr.logger.With(
zap.Object("request", caddyhttp.LoggableHTTPRequest{Request: r}), zap.Object("request", caddyhttp.LoggableHTTPRequest{Request: r}),
@ -124,7 +124,7 @@ func (rewr Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddy
// rewrite performs the rewrites on r using repl, which // rewrite performs the rewrites on r using repl, which
// should have been obtained from r, but is passed in for // should have been obtained from r, but is passed in for
// efficiency. It returns true if any changes were made to r. // efficiency. It returns true if any changes were made to r.
func (rewr Rewrite) rewrite(r *http.Request, repl caddy.Replacer, logger *zap.Logger) bool { func (rewr Rewrite) rewrite(r *http.Request, repl *caddy.Replacer, logger *zap.Logger) bool {
oldMethod := r.Method oldMethod := r.Method
oldURI := r.RequestURI oldURI := r.RequestURI
@ -209,7 +209,7 @@ type replacer struct {
} }
// do performs the replacement on r and returns true if any changes were made. // do performs the replacement on r and returns true if any changes were made.
func (rep replacer) do(r *http.Request, repl caddy.Replacer) bool { func (rep replacer) do(r *http.Request, repl *caddy.Replacer) bool {
if rep.Find == "" || rep.Replace == "" { if rep.Find == "" || rep.Replace == "" {
return false return false
} }

View file

@ -451,7 +451,7 @@ func (*HTTPErrorConfig) WithError(r *http.Request, err error) *http.Request {
r = r.WithContext(c) r = r.WithContext(c)
// add error values to the replacer // add error values to the replacer
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
repl.Set("http.error", err.Error()) repl.Set("http.error", err.Error())
if handlerErr, ok := err.(HandlerError); ok { if handlerErr, ok := err.(HandlerError); ok {
repl.Set("http.error.status_code", strconv.Itoa(handlerErr.StatusCode)) repl.Set("http.error.status_code", strconv.Itoa(handlerErr.StatusCode))

View file

@ -51,7 +51,7 @@ func (StaticError) CaddyModule() caddy.ModuleInfo {
} }
func (e StaticError) ServeHTTP(w http.ResponseWriter, r *http.Request, _ Handler) error { func (e StaticError) ServeHTTP(w http.ResponseWriter, r *http.Request, _ Handler) error {
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
statusCode := http.StatusInternalServerError statusCode := http.StatusInternalServerError
if codeStr := e.StatusCode.String(); codeStr != "" { if codeStr := e.StatusCode.String(); codeStr != "" {

View file

@ -86,7 +86,7 @@ func (s *StaticResponse) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
} }
func (s StaticResponse) ServeHTTP(w http.ResponseWriter, r *http.Request, _ Handler) error { func (s StaticResponse) ServeHTTP(w http.ResponseWriter, r *http.Request, _ Handler) error {
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
// close the connection after responding // close the connection after responding
r.Close = s.Close r.Close = s.Close

View file

@ -122,7 +122,7 @@ func (t *Templates) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddy
func (t *Templates) executeTemplate(rr caddyhttp.ResponseRecorder, r *http.Request) error { func (t *Templates) executeTemplate(rr caddyhttp.ResponseRecorder, r *http.Request) error {
var fs http.FileSystem var fs http.FileSystem
if t.FileRoot != "" { if t.FileRoot != "" {
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
fs = http.Dir(repl.ReplaceAll(t.FileRoot, ".")) fs = http.Dir(repl.ReplaceAll(t.FileRoot, "."))
} }

View file

@ -40,7 +40,7 @@ func (VarsMiddleware) CaddyModule() caddy.ModuleInfo {
func (t VarsMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next Handler) error { func (t VarsMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next Handler) error {
vars := r.Context().Value(VarsCtxKey).(map[string]interface{}) vars := r.Context().Value(VarsCtxKey).(map[string]interface{})
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
for k, v := range t { for k, v := range t {
keyExpanded := repl.ReplaceAll(k, "") keyExpanded := repl.ReplaceAll(k, "")
valExpanded := repl.ReplaceAll(v, "") valExpanded := repl.ReplaceAll(v, "")
@ -64,7 +64,7 @@ func (VarsMatcher) CaddyModule() caddy.ModuleInfo {
// Match matches a request based on variables in the context. // Match matches a request based on variables in the context.
func (m VarsMatcher) Match(r *http.Request) bool { func (m VarsMatcher) Match(r *http.Request) bool {
vars := r.Context().Value(VarsCtxKey).(map[string]string) vars := r.Context().Value(VarsCtxKey).(map[string]string)
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
for k, v := range m { for k, v := range m {
keyExpanded := repl.ReplaceAll(k, "") keyExpanded := repl.ReplaceAll(k, "")
valExpanded := repl.ReplaceAll(v, "") valExpanded := repl.ReplaceAll(v, "")

View file

@ -23,20 +23,9 @@ import (
"time" "time"
) )
// Replacer can replace values in strings.
type Replacer interface {
Set(variable, value string)
Delete(variable string)
Map(ReplacerFunc)
ReplaceAll(input, empty string) string
ReplaceKnown(input, empty string) string
ReplaceOrErr(input string, errOnEmpty, errOnUnknown bool) (string, error)
ReplaceFunc(input string, f ReplacementFunc) (string, error)
}
// NewReplacer returns a new Replacer. // NewReplacer returns a new Replacer.
func NewReplacer() Replacer { func NewReplacer() *Replacer {
rep := &replacer{ rep := &Replacer{
static: make(map[string]string), static: make(map[string]string),
} }
rep.providers = []ReplacerFunc{ rep.providers = []ReplacerFunc{
@ -46,30 +35,44 @@ func NewReplacer() Replacer {
return rep return rep
} }
type replacer struct { // Replacer can replace values in strings.
// A default/empty Replacer is not valid;
// use NewReplacer to make one.
type Replacer struct {
providers []ReplacerFunc providers []ReplacerFunc
static map[string]string static map[string]string
} }
// Map adds mapFunc to the list of value providers. // Map adds mapFunc to the list of value providers.
// mapFunc will be executed only at replace-time. // mapFunc will be executed only at replace-time.
func (r *replacer) Map(mapFunc ReplacerFunc) { func (r *Replacer) Map(mapFunc ReplacerFunc) {
r.providers = append(r.providers, mapFunc) r.providers = append(r.providers, mapFunc)
} }
// Set sets a custom variable to a static value. // Set sets a custom variable to a static value.
func (r *replacer) Set(variable, value string) { func (r *Replacer) Set(variable, value string) {
r.static[variable] = value r.static[variable] = value
} }
// Get gets a value from the replacer. It returns
// the value and whether the variable was known.
func (r *Replacer) Get(variable string) (string, bool) {
for _, mapFunc := range r.providers {
if val, ok := mapFunc(variable); ok {
return val, true
}
}
return "", false
}
// Delete removes a variable with a static value // Delete removes a variable with a static value
// that was created using Set. // that was created using Set.
func (r *replacer) Delete(variable string) { func (r *Replacer) Delete(variable string) {
delete(r.static, variable) delete(r.static, variable)
} }
// fromStatic provides values from r.static. // fromStatic provides values from r.static.
func (r *replacer) fromStatic(key string) (val string, ok bool) { func (r *Replacer) fromStatic(key string) (val string, ok bool) {
val, ok = r.static[key] val, ok = r.static[key]
return return
} }
@ -77,14 +80,14 @@ func (r *replacer) fromStatic(key string) (val string, ok bool) {
// ReplaceOrErr is like ReplaceAll, but any placeholders // ReplaceOrErr is like ReplaceAll, but any placeholders
// that are empty or not recognized will cause an error to // that are empty or not recognized will cause an error to
// be returned. // be returned.
func (r *replacer) ReplaceOrErr(input string, errOnEmpty, errOnUnknown bool) (string, error) { func (r *Replacer) ReplaceOrErr(input string, errOnEmpty, errOnUnknown bool) (string, error) {
return r.replace(input, "", false, errOnEmpty, errOnUnknown, nil) return r.replace(input, "", false, errOnEmpty, errOnUnknown, nil)
} }
// ReplaceKnown is like ReplaceAll but only replaces // ReplaceKnown is like ReplaceAll but only replaces
// placeholders that are known (recognized). Unrecognized // placeholders that are known (recognized). Unrecognized
// placeholders will remain in the output. // placeholders will remain in the output.
func (r *replacer) ReplaceKnown(input, empty string) string { func (r *Replacer) ReplaceKnown(input, empty string) string {
out, _ := r.replace(input, empty, false, false, false, nil) out, _ := r.replace(input, empty, false, false, false, nil)
return out return out
} }
@ -93,7 +96,7 @@ func (r *replacer) ReplaceKnown(input, empty string) string {
// their values. All placeholders are replaced in the output // their values. All placeholders are replaced in the output
// whether they are recognized or not. Values that are empty // whether they are recognized or not. Values that are empty
// string will be substituted with empty. // string will be substituted with empty.
func (r *replacer) ReplaceAll(input, empty string) string { func (r *Replacer) ReplaceAll(input, empty string) string {
out, _ := r.replace(input, empty, true, false, false, nil) out, _ := r.replace(input, empty, true, false, false, nil)
return out return out
} }
@ -102,11 +105,11 @@ func (r *replacer) ReplaceAll(input, empty string) string {
// their values. All placeholders are replaced in the output // their values. All placeholders are replaced in the output
// whether they are recognized or not. Values that are empty // whether they are recognized or not. Values that are empty
// string will be substituted with empty. // string will be substituted with empty.
func (r *replacer) ReplaceFunc(input string, f ReplacementFunc) (string, error) { func (r *Replacer) ReplaceFunc(input string, f ReplacementFunc) (string, error) {
return r.replace(input, "", true, false, false, f) return r.replace(input, "", true, false, false, f)
} }
func (r *replacer) replace(input, empty string, func (r *Replacer) replace(input, empty string,
treatUnknownAsEmpty, errOnEmpty, errOnUnknown bool, treatUnknownAsEmpty, errOnEmpty, errOnUnknown bool,
f ReplacementFunc) (string, error) { f ReplacementFunc) (string, error) {
if !strings.Contains(input, string(phOpen)) { if !strings.Contains(input, string(phOpen)) {
@ -138,12 +141,24 @@ func (r *replacer) replace(input, empty string,
// trim opening bracket // trim opening bracket
key := input[i+1 : end] key := input[i+1 : end]
// try to get a value for this key, // try to get a value for this key, handle empty values accordingly
// handle empty values accordingly val, found := r.Get(key)
var found bool if !found {
for _, mapFunc := range r.providers { // placeholder is unknown (unrecognized); handle accordingly
if val, ok := mapFunc(key); ok { if errOnUnknown {
found = true return "", fmt.Errorf("unrecognized placeholder %s%s%s",
string(phOpen), key, string(phClose))
} else if treatUnknownAsEmpty {
if empty != "" {
sb.WriteString(empty)
}
} else {
lastWriteCursor = i
continue
}
}
// apply any transformations
if f != nil { if f != nil {
var err error var err error
val, err = f(key, val) val, err = f(key, val)
@ -151,6 +166,9 @@ func (r *replacer) replace(input, empty string,
return "", err return "", err
} }
} }
// write the value; if it's empty, either return
// an error or write a default value
if val == "" { if val == "" {
if errOnEmpty { if errOnEmpty {
return "", fmt.Errorf("evaluated placeholder %s%s%s is empty", return "", fmt.Errorf("evaluated placeholder %s%s%s is empty",
@ -161,24 +179,6 @@ func (r *replacer) replace(input, empty string,
} else { } else {
sb.WriteString(val) sb.WriteString(val)
} }
break
}
}
if !found {
// placeholder is unknown (unrecognized), handle accordingly
switch {
case errOnUnknown:
return "", fmt.Errorf("unrecognized placeholder %s%s%s",
string(phOpen), key, string(phClose))
case treatUnknownAsEmpty:
if empty != "" {
sb.WriteString(empty)
}
default:
lastWriteCursor = i
continue
}
}
// advance cursor to end of placeholder // advance cursor to end of placeholder
i = end i = end

View file

@ -132,7 +132,7 @@ func TestReplacerSet(t *testing.T) {
} }
func TestReplacerReplaceKnown(t *testing.T) { func TestReplacerReplaceKnown(t *testing.T) {
rep := replacer{ rep := Replacer{
providers: []ReplacerFunc{ providers: []ReplacerFunc{
// split our possible vars to two functions (to test if both functions are called) // split our possible vars to two functions (to test if both functions are called)
func(key string) (val string, ok bool) { func(key string) (val string, ok bool) {
@ -204,7 +204,7 @@ func TestReplacerReplaceKnown(t *testing.T) {
} }
func TestReplacerDelete(t *testing.T) { func TestReplacerDelete(t *testing.T) {
rep := replacer{ rep := Replacer{
static: map[string]string{ static: map[string]string{
"key1": "val1", "key1": "val1",
"key2": "val2", "key2": "val2",
@ -264,10 +264,8 @@ func TestReplacerMap(t *testing.T) {
} }
func TestReplacerNew(t *testing.T) { func TestReplacerNew(t *testing.T) {
var tc = NewReplacer() rep := NewReplacer()
rep, ok := tc.(*replacer)
if ok {
if len(rep.providers) != 2 { if len(rep.providers) != 2 {
t.Errorf("Expected providers length '%v' got length '%v'", 2, len(rep.providers)) t.Errorf("Expected providers length '%v' got length '%v'", 2, len(rep.providers))
} else { } else {
@ -310,13 +308,11 @@ func TestReplacerNew(t *testing.T) {
} }
} }
} }
} else {
t.Errorf("Expected type of replacer %T got %T ", &replacer{}, tc)
}
} }
func testReplacer() replacer { func testReplacer() Replacer {
return replacer{ return Replacer{
providers: make([]ReplacerFunc, 0), providers: make([]ReplacerFunc, 0),
static: make(map[string]string), static: make(map[string]string),
} }