From a9933aace1f8f96d9f32a68436cc4d18f683a7ef Mon Sep 17 00:00:00 2001 From: Mohammed Al Sahaf Date: Mon, 18 Sep 2023 00:00:00 +0300 Subject: [PATCH] refactor and tests --- modules/caddypki/adminapi.go | 36 +---- modules/caddypki/ca.go | 2 +- modules/caddypki/csr.go | 159 +++++++++++++++++++ modules/caddypki/csr_test.go | 292 +++++++++++++++++++++++++++++++++++ 4 files changed, 458 insertions(+), 31 deletions(-) create mode 100644 modules/caddypki/csr.go create mode 100644 modules/caddypki/csr_test.go diff --git a/modules/caddypki/adminapi.go b/modules/caddypki/adminapi.go index 448e59b9..22633114 100644 --- a/modules/caddypki/adminapi.go +++ b/modules/caddypki/adminapi.go @@ -177,35 +177,6 @@ func (a *adminAPI) handleCACerts(w http.ResponseWriter, r *http.Request) error { return nil } -type csrRequest struct { - // Custom name assigned to the CSR key. If empty, UUID is generated and assigned. - ID string `json:"id,omitempty"` - - // Customization knobs of the generated/loaded key, if desired. - // If empty, sane defaults will be managed internally without exposing their details - // to the user. At the moment, the default parameters are: - // { - // "type": "EC", - // "curve": "P-256" - // } - Key *struct { - // The key type to be used for signing the CSR. The possible types are: - // EC, RSA, and OKP. - Type string `json:"type"` - - // The curve to use with key types EC and OKP. - // If the Type is OKP, then acceptable curves are: Ed25519, or X25519 - // If the Type is EC, then acceptable curves are: P-256, P-384, or P-521 - Curve string `json:"curve,omitempty"` - - // Only used with RSA keys and accepts minimum of 2048. - Size int `json:"size,omitempty"` - } `json:"key,omitempty"` - - // SANs is a list of subject alternative names for the certificate. - SANs []string `json:"sans"` -} - func (a *adminAPI) handleCSRGeneration(w http.ResponseWriter, r *http.Request) error { if r.Method != http.MethodPost { return caddy.APIError{ @@ -234,7 +205,12 @@ func (a *adminAPI) handleCSRGeneration(w http.ResponseWriter, r *http.Request) e if len(csrReq.ID) == 0 { csrReq.ID = uuid.New().String() } - + if err := csrReq.validate(); err != nil { + return caddy.APIError{ + HTTPStatus: http.StatusBadRequest, + Err: fmt.Errorf("invalid CSR request: %v", err), + } + } // Generate the CSR csr, err := ca.generateCSR(csrReq) if err != nil { diff --git a/modules/caddypki/ca.go b/modules/caddypki/ca.go index 326f1711..6d25b8f7 100644 --- a/modules/caddypki/ca.go +++ b/modules/caddypki/ca.go @@ -440,7 +440,7 @@ func (ca CA) generateCSR(csrReq csrRequest) (csr *x509.CertificateRequest, err e return nil, err } } else { - signer, err = keyutil.GenerateSigner(csrReq.Key.Type, csrReq.Key.Curve, csrReq.Key.Size) + signer, err = keyutil.GenerateSigner(csrReq.Key.Type.String(), csrReq.Key.Curve.String(), csrReq.Key.Size) if err != nil { return nil, err } diff --git a/modules/caddypki/csr.go b/modules/caddypki/csr.go new file mode 100644 index 00000000..2f379ef9 --- /dev/null +++ b/modules/caddypki/csr.go @@ -0,0 +1,159 @@ +package caddypki + +import ( + "encoding/json" + "fmt" +) + +// The key type to be used for signing the CSR. The possible types are: +// EC, RSA, and OKP. +type keyType string + +const ( + keyTypeEC keyType = "EC" + keyTypeRSA keyType = "RSA" + keyTypeOKP keyType = "OKP" +) + +var stringToKey = map[string]keyType{ + "EC": keyTypeEC, + "RSA": keyTypeRSA, + "OKP": keyTypeOKP, +} + +func (kt *keyType) UnmarshalJSON(b []byte) error { + var s string + if err := json.Unmarshal(b, &s); err != nil { + return err + } + switch s { + case string(keyTypeEC), string(keyTypeRSA), string(keyTypeOKP): + *kt = stringToKey[s] + default: + return fmt.Errorf("unknown key type: %s", s) + } + return nil +} + +func (kt keyType) String() string { + return string(kt) +} + +// The curve to use with key types EC and OKP. +// If the Type is OKP, then acceptable curves are: Ed25519, or X25519 +// If the Type is EC, then acceptable curves are: P-256, P-384, or P-521 +type curve string + +const ( + curveEd25519 curve = "Ed25519" + curveX25519 curve = "X25519" + curveP256 curve = "P-256" + curveP384 curve = "P-384" + curveP521 curve = "P-521" +) + +var stringToCurve = map[string]curve{ + "Ed25519": curveEd25519, + "X25519": curveX25519, + "P-256": curveP256, + "P-384": curveP384, + "P-521": curveP521, +} + +func (c *curve) UnmarshalJSON(b []byte) error { + var s string + if err := json.Unmarshal(b, &s); err != nil { + return err + } + switch s { + case string(curveEd25519), string(curveX25519), string(curveP256), string(curveP384), string(curveP521): + *c = stringToCurve[s] + default: + return fmt.Errorf("unknown curve: %s", s) + } + return nil +} + +func (c curve) String() string { + return string(c) +} + +type keyParameters struct { + // The key type to be used for signing the CSR. The possible types are: + // EC, RSA, and OKP. + // The value of this field is case-sensitive. + Type keyType `json:"type"` + + // The curve to use with key types EC and OKP. + // If the Type is OKP, then acceptable curves are: Ed25519, or X25519 + // If the Type is EC, then acceptable curves are: P-256, P-384, or P-521 + // The value of this field is case-sensitive. + Curve curve `json:"curve,omitempty"` + + // Only used with RSA keys and accepts minimum of 2048. + Size int `json:"size,omitempty"` +} + +func (kp *keyParameters) validate() error { + if kp == nil { + return nil + } + + if kp.Type == keyTypeRSA { + if kp.Size < 2048 { + return fmt.Errorf("minimum RSA key size is 2048 bits: %v", kp.Size) + } + } + if kp.Type == keyTypeEC { + switch kp.Curve { + case curveP256, curveP384, curveP521: + return nil + default: + return fmt.Errorf("unrecognized EC curve: %v", kp.Curve) + } + } + if kp.Type == keyTypeOKP { + switch kp.Curve { + case curveEd25519, curveX25519: + return nil + default: + return fmt.Errorf("unrecognized OKP curve: %v", kp.Curve) + } + } + return nil +} + +type csrRequest struct { + // Custom name assigned to the CSR key. If empty, UUID is generated and assigned. + ID string `json:"id,omitempty"` + + // Customization knobs of the generated/loaded key, if desired. The format is: + // { + // // Valid values for type are: EC, RSA, and OKP. + // "type": "", + // + // // The curve to use with key types EC and OKP. + // // If the Type is OKP, then acceptable curves are: Ed25519, or X25519 + // // If the Type is EC, then acceptable curves are: P-256, P-384, or P-521 + // "curve": "", + // + // // Only used with RSA keys and accepts minimum of 2048. + // "size": 0 + // } + // + // If empty, sane defaults will be managed internally without exposing their details + // to the user. At the moment, the default parameters are: + // { + // "type": "EC", + // "curve": "P-256" + // } + // The values are case-sensitive. + Key *keyParameters `json:"key,omitempty"` + + // SANs is a list of subject alternative names for the certificate. + SANs []string `json:"sans"` +} + +func (c csrRequest) validate() error { + return c.Key.validate() +} diff --git a/modules/caddypki/csr_test.go b/modules/caddypki/csr_test.go new file mode 100644 index 00000000..19a56d65 --- /dev/null +++ b/modules/caddypki/csr_test.go @@ -0,0 +1,292 @@ +package caddypki + +import ( + "encoding/json" + "reflect" + "testing" +) + +func TestParseKeyType(t *testing.T) { + tests := []struct { + name string + input string + expected keyType + err string + }{ + { + name: "uppercase EC is recognized", + input: `"EC"`, + expected: keyTypeEC, + }, + { + name: "lowercase EC is recognized", + input: `"ec"`, + err: "unknown key type: ec", + }, + { + name: "mixed case EC is recognized", + input: `"eC"`, + err: "unknown key type: eC", + }, + { + name: "uppercase RSA is recognized", + input: `"RSA"`, + expected: keyTypeRSA, + }, + { + name: "lowercase rsa is not accepted", + input: `"rsa"`, + err: "unknown key type: rsa", + }, + { + name: "mixed case RSA is not accepted", + input: `"RsA"`, + err: "unknown key type: RsA", + }, + { + name: "uppercase OKP is recognized", + input: `"OKP"`, + expected: keyTypeOKP, + }, + { + name: "lowercase OKP is not accepted", + input: `"okp"`, + err: "unknown key type: okp", + }, + { + name: "mixed case OKP is not accepted", + input: `"OkP"`, + err: "unknown key type: OkP", + }, + { + name: "unknown key type is an error", + input: `"foo"`, + err: "unknown key type: foo", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var kt keyType + + err := json.Unmarshal([]byte(test.input), &kt) + if test.err != "" { + if err == nil { + t.Errorf("expected error %q, but got nil", test.err) + } + if err.Error() != test.err { + t.Errorf("expected error %q, but got %q", test.err, err.Error()) + } + return + } + if err != nil { + t.Errorf("expected no error, but got %q", err.Error()) + return + } + if kt != test.expected { + t.Errorf("expected %v, but got %v", test.expected, kt) + } + }) + } +} + +func TestCSRRequestValidate(t *testing.T) { + tests := []struct { + name string + key *keyParameters + wantErr bool + }{ + { + name: "empty request is valid", + key: nil, + wantErr: false, + }, + { + name: "RSA with size 2048 is valid", + key: &keyParameters{ + Type: keyTypeRSA, + Size: 2048, + }, + wantErr: false, + }, + { + name: "RSA with size less than 2048 is invalid", + key: &keyParameters{ + Type: keyTypeRSA, + Size: 1024, + }, + wantErr: true, + }, + { + name: "EC key with curve P-256 is valid", + key: &keyParameters{ + Type: keyTypeEC, + Curve: "P-256", + }, + wantErr: false, + }, + { + name: "EC key with curve P-256 is valid", + key: &keyParameters{ + Type: keyTypeEC, + Curve: "P-256", + }, + wantErr: false, + }, + { + name: "EC key with curve P-384 is valid", + key: &keyParameters{ + Type: keyTypeEC, + Curve: "P-384", + }, + wantErr: false, + }, + { + name: "EC key with curve P-521 is valid", + key: &keyParameters{ + Type: keyTypeEC, + Curve: "P-521", + }, + wantErr: false, + }, + { + name: "EC key with unknown curve is invalid", + key: &keyParameters{ + Type: keyTypeEC, + Curve: "foo", + }, + wantErr: true, + }, + { + name: "EC key with Ed25519 curve is invalid", + key: &keyParameters{ + Type: keyTypeEC, + Curve: "Ed25519", + }, + wantErr: true, + }, + { + name: "EC key with X25519 curve is invalid", + key: &keyParameters{ + Type: keyTypeEC, + Curve: "X25519", + }, + wantErr: true, + }, + { + name: "OKP key with curve Ed25519 is valid", + key: &keyParameters{ + Type: keyTypeOKP, + Curve: "Ed25519", + }, + wantErr: false, + }, + { + name: "OKP key with curve X25519 is valid", + key: &keyParameters{ + Type: keyTypeOKP, + Curve: "X25519", + }, + wantErr: false, + }, + { + name: "OKP with unknown curve is invalid", + key: &keyParameters{ + Type: keyTypeOKP, + Curve: "foo", + }, + wantErr: true, + }, + { + name: "OKP key with curve P-256 is invalid", + key: &keyParameters{ + Type: keyTypeOKP, + Curve: "P-256", + }, + wantErr: true, + }, + { + name: "OKP key with curve P-384 is invalid", + key: &keyParameters{ + Type: keyTypeOKP, + Curve: "P-384", + }, + wantErr: true, + }, + { + name: "OKP key with curve P-521 is invalid", + key: &keyParameters{ + Type: keyTypeOKP, + Curve: "P-521", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := csrRequest{ + Key: tt.key, + } + if err := c.validate(); (err != nil) != tt.wantErr { + t.Errorf("csrRequest.validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestCSRRequestUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + request string + want csrRequest + err string + }{ + { + name: "empty request is valid", + request: "{}", + want: csrRequest{ + Key: nil, + }, + }, + { + name: "RSA with size 2048 is valid", + request: `{"key":{"type":"RSA","size":2048}}`, + want: csrRequest{ + Key: &keyParameters{ + Type: keyTypeRSA, + Size: 2048, + }, + }, + }, + { + name: "EC key with curve P-256 is valid", + request: `{"key":{"type":"EC","curve":"P-256"}}`, + want: csrRequest{ + Key: &keyParameters{ + Type: keyTypeEC, + Curve: "P-256", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var c csrRequest + err := json.Unmarshal([]byte(tt.request), &c) + if tt.err != "" { + if err == nil { + t.Errorf("expected error %q, but got nil", tt.err) + } + if err.Error() != tt.err { + t.Errorf("expected error %q, but got %q", tt.err, err.Error()) + } + } + if err != nil { + t.Errorf("expected no error, but got %q", err.Error()) + } + if !reflect.DeepEqual(c, tt.want) { + t.Errorf("csrRequest.unmarshalJSON() = %v, want %v", c, tt.want) + } + }) + } +}