diff --git a/modules/caddyhttp/celmatcher.go b/modules/caddyhttp/celmatcher.go index 83e01cfb..bab0a075 100644 --- a/modules/caddyhttp/celmatcher.go +++ b/modules/caddyhttp/celmatcher.go @@ -15,6 +15,7 @@ package caddyhttp import ( + "crypto/x509/pkix" "encoding/json" "fmt" "net/http" @@ -199,6 +200,27 @@ func (cr celHTTPRequest) Equal(other ref.Val) ref.Val { func (celHTTPRequest) Type() ref.Type { return httpRequestCELType } func (cr celHTTPRequest) Value() interface{} { return cr } +var pkixNameCELType = types.NewTypeValue("pkix.Name", traits.ReceiverType) + +// celPkixName wraps an pkix.Name with +// methods to satisfy the ref.Val interface. +type celPkixName struct{ *pkix.Name } + +func (pn celPkixName) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { + return pn.Name, nil +} +func (celPkixName) ConvertToType(typeVal ref.Type) ref.Val { + panic("not implemented") +} +func (pn celPkixName) Equal(other ref.Val) ref.Val { + if o, ok := other.Value().(string); ok { + return types.Bool(pn.Name.String() == o) + } + return types.ValOrErr(other, "%v is not comparable type", other) +} +func (celPkixName) Type() ref.Type { return pkixNameCELType } +func (pn celPkixName) Value() interface{} { return pn } + // celTypeAdapter can adapt our custom types to a CEL value. type celTypeAdapter struct{} @@ -206,6 +228,8 @@ func (celTypeAdapter) NativeToValue(value interface{}) ref.Val { switch v := value.(type) { case celHTTPRequest: return v + case pkix.Name: + return celPkixName{&v} case time.Time: // TODO: eliminate direct protobuf dependency, sigh -- just wrap stdlib time.Time instead... return types.Timestamp{Timestamp: ×tamp.Timestamp{Seconds: v.Unix(), Nanos: int32(v.Nanosecond())}} diff --git a/modules/caddyhttp/celmatcher_test.go b/modules/caddyhttp/celmatcher_test.go index 0e3b335d..a78eb5a4 100644 --- a/modules/caddyhttp/celmatcher_test.go +++ b/modules/caddyhttp/celmatcher_test.go @@ -15,6 +15,11 @@ package caddyhttp import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "net/http/httptest" "testing" "github.com/caddyserver/caddy/v2" @@ -27,7 +32,7 @@ func TestMatchExpressionProvision(t *testing.T) { wantErr bool }{ { - name: "boolean mtaches succeed", + name: "boolean matches succeed", expression: &MatchExpression{ Expr: "{http.request.uri.query} != ''", }, @@ -49,3 +54,71 @@ func TestMatchExpressionProvision(t *testing.T) { }) } } + +func TestMatchExpressionMatch(t *testing.T) { + + clientCert := []byte(`-----BEGIN CERTIFICATE----- +MIIB9jCCAV+gAwIBAgIBAjANBgkqhkiG9w0BAQsFADAYMRYwFAYDVQQDDA1DYWRk +eSBUZXN0IENBMB4XDTE4MDcyNDIxMzUwNVoXDTI4MDcyMTIxMzUwNVowHTEbMBkG +A1UEAwwSY2xpZW50LmxvY2FsZG9tYWluMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB +iQKBgQDFDEpzF0ew68teT3xDzcUxVFaTII+jXH1ftHXxxP4BEYBU4q90qzeKFneF +z83I0nC0WAQ45ZwHfhLMYHFzHPdxr6+jkvKPASf0J2v2HDJuTM1bHBbik5Ls5eq+ +fVZDP8o/VHKSBKxNs8Goc2NTsr5b07QTIpkRStQK+RJALk4x9QIDAQABo0swSTAJ +BgNVHRMEAjAAMAsGA1UdDwQEAwIHgDAaBgNVHREEEzARgglsb2NhbGhvc3SHBH8A +AAEwEwYDVR0lBAwwCgYIKwYBBQUHAwIwDQYJKoZIhvcNAQELBQADgYEANSjz2Sk+ +eqp31wM9il1n+guTNyxJd+FzVAH+hCZE5K+tCgVDdVFUlDEHHbS/wqb2PSIoouLV +3Q9fgDkiUod+uIK0IynzIKvw+Cjg+3nx6NQ0IM0zo8c7v398RzB4apbXKZyeeqUH +9fNwfEi+OoXR6s+upSKobCmLGLGi9Na5s5g= +-----END CERTIFICATE-----`) + + tests := []struct { + name string + expression *MatchExpression + wantErr bool + wantResult bool + clientCertificate []byte + }{ + { + name: "boolean matches succeed for placeholder http.request.tls.client.subject", + expression: &MatchExpression{ + Expr: "{http.request.tls.client.subject} == 'CN=client.localdomain'", + }, + clientCertificate: clientCert, + wantResult: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.expression.Provision(caddy.Context{}); (err != nil) != tt.wantErr { + t.Errorf("MatchExpression.Provision() error = %v, wantErr %v", err, tt.wantErr) + } + + req := httptest.NewRequest("GET", "https://example.com/foo", nil) + repl := caddy.NewReplacer() + ctx := context.WithValue(req.Context(), caddy.ReplacerCtxKey, repl) + req = req.WithContext(ctx) + addHTTPVarsToReplacer(repl, req, httptest.NewRecorder()) + + if tt.clientCertificate != nil { + block, _ := pem.Decode(clientCert) + if block == nil { + t.Fatalf("failed to decode PEM certificate") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + t.Fatalf("failed to decode PEM certificate: %v", err) + } + + req.TLS = &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{cert}, + } + } + + if tt.expression.Match(req) != tt.wantResult { + t.Errorf("MatchExpression.Match() expected to return '%t', for expression : '%s'", tt.wantResult, tt.expression) + } + + }) + } +}