refactor and tests

This commit is contained in:
Mohammed Al Sahaf 2023-09-18 00:00:00 +03:00
parent 37c6f1c5b6
commit a9933aace1
4 changed files with 458 additions and 31 deletions

View file

@ -177,35 +177,6 @@ func (a *adminAPI) handleCACerts(w http.ResponseWriter, r *http.Request) error {
return nil
}
type csrRequest struct {
// Custom name assigned to the CSR key. If empty, UUID is generated and assigned.
ID string `json:"id,omitempty"`
// Customization knobs of the generated/loaded key, if desired.
// If empty, sane defaults will be managed internally without exposing their details
// to the user. At the moment, the default parameters are:
// {
// "type": "EC",
// "curve": "P-256"
// }
Key *struct {
// The key type to be used for signing the CSR. The possible types are:
// EC, RSA, and OKP.
Type string `json:"type"`
// The curve to use with key types EC and OKP.
// If the Type is OKP, then acceptable curves are: Ed25519, or X25519
// If the Type is EC, then acceptable curves are: P-256, P-384, or P-521
Curve string `json:"curve,omitempty"`
// Only used with RSA keys and accepts minimum of 2048.
Size int `json:"size,omitempty"`
} `json:"key,omitempty"`
// SANs is a list of subject alternative names for the certificate.
SANs []string `json:"sans"`
}
func (a *adminAPI) handleCSRGeneration(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodPost {
return caddy.APIError{
@ -234,7 +205,12 @@ func (a *adminAPI) handleCSRGeneration(w http.ResponseWriter, r *http.Request) e
if len(csrReq.ID) == 0 {
csrReq.ID = uuid.New().String()
}
if err := csrReq.validate(); err != nil {
return caddy.APIError{
HTTPStatus: http.StatusBadRequest,
Err: fmt.Errorf("invalid CSR request: %v", err),
}
}
// Generate the CSR
csr, err := ca.generateCSR(csrReq)
if err != nil {

View file

@ -440,7 +440,7 @@ func (ca CA) generateCSR(csrReq csrRequest) (csr *x509.CertificateRequest, err e
return nil, err
}
} else {
signer, err = keyutil.GenerateSigner(csrReq.Key.Type, csrReq.Key.Curve, csrReq.Key.Size)
signer, err = keyutil.GenerateSigner(csrReq.Key.Type.String(), csrReq.Key.Curve.String(), csrReq.Key.Size)
if err != nil {
return nil, err
}

159
modules/caddypki/csr.go Normal file
View file

@ -0,0 +1,159 @@
package caddypki
import (
"encoding/json"
"fmt"
)
// The key type to be used for signing the CSR. The possible types are:
// EC, RSA, and OKP.
type keyType string
const (
keyTypeEC keyType = "EC"
keyTypeRSA keyType = "RSA"
keyTypeOKP keyType = "OKP"
)
var stringToKey = map[string]keyType{
"EC": keyTypeEC,
"RSA": keyTypeRSA,
"OKP": keyTypeOKP,
}
func (kt *keyType) UnmarshalJSON(b []byte) error {
var s string
if err := json.Unmarshal(b, &s); err != nil {
return err
}
switch s {
case string(keyTypeEC), string(keyTypeRSA), string(keyTypeOKP):
*kt = stringToKey[s]
default:
return fmt.Errorf("unknown key type: %s", s)
}
return nil
}
func (kt keyType) String() string {
return string(kt)
}
// The curve to use with key types EC and OKP.
// If the Type is OKP, then acceptable curves are: Ed25519, or X25519
// If the Type is EC, then acceptable curves are: P-256, P-384, or P-521
type curve string
const (
curveEd25519 curve = "Ed25519"
curveX25519 curve = "X25519"
curveP256 curve = "P-256"
curveP384 curve = "P-384"
curveP521 curve = "P-521"
)
var stringToCurve = map[string]curve{
"Ed25519": curveEd25519,
"X25519": curveX25519,
"P-256": curveP256,
"P-384": curveP384,
"P-521": curveP521,
}
func (c *curve) UnmarshalJSON(b []byte) error {
var s string
if err := json.Unmarshal(b, &s); err != nil {
return err
}
switch s {
case string(curveEd25519), string(curveX25519), string(curveP256), string(curveP384), string(curveP521):
*c = stringToCurve[s]
default:
return fmt.Errorf("unknown curve: %s", s)
}
return nil
}
func (c curve) String() string {
return string(c)
}
type keyParameters struct {
// The key type to be used for signing the CSR. The possible types are:
// EC, RSA, and OKP.
// The value of this field is case-sensitive.
Type keyType `json:"type"`
// The curve to use with key types EC and OKP.
// If the Type is OKP, then acceptable curves are: Ed25519, or X25519
// If the Type is EC, then acceptable curves are: P-256, P-384, or P-521
// The value of this field is case-sensitive.
Curve curve `json:"curve,omitempty"`
// Only used with RSA keys and accepts minimum of 2048.
Size int `json:"size,omitempty"`
}
func (kp *keyParameters) validate() error {
if kp == nil {
return nil
}
if kp.Type == keyTypeRSA {
if kp.Size < 2048 {
return fmt.Errorf("minimum RSA key size is 2048 bits: %v", kp.Size)
}
}
if kp.Type == keyTypeEC {
switch kp.Curve {
case curveP256, curveP384, curveP521:
return nil
default:
return fmt.Errorf("unrecognized EC curve: %v", kp.Curve)
}
}
if kp.Type == keyTypeOKP {
switch kp.Curve {
case curveEd25519, curveX25519:
return nil
default:
return fmt.Errorf("unrecognized OKP curve: %v", kp.Curve)
}
}
return nil
}
type csrRequest struct {
// Custom name assigned to the CSR key. If empty, UUID is generated and assigned.
ID string `json:"id,omitempty"`
// Customization knobs of the generated/loaded key, if desired. The format is:
// {
// // Valid values for type are: EC, RSA, and OKP.
// "type": "",
//
// // The curve to use with key types EC and OKP.
// // If the Type is OKP, then acceptable curves are: Ed25519, or X25519
// // If the Type is EC, then acceptable curves are: P-256, P-384, or P-521
// "curve": "",
//
// // Only used with RSA keys and accepts minimum of 2048.
// "size": 0
// }
//
// If empty, sane defaults will be managed internally without exposing their details
// to the user. At the moment, the default parameters are:
// {
// "type": "EC",
// "curve": "P-256"
// }
// The values are case-sensitive.
Key *keyParameters `json:"key,omitempty"`
// SANs is a list of subject alternative names for the certificate.
SANs []string `json:"sans"`
}
func (c csrRequest) validate() error {
return c.Key.validate()
}

View file

@ -0,0 +1,292 @@
package caddypki
import (
"encoding/json"
"reflect"
"testing"
)
func TestParseKeyType(t *testing.T) {
tests := []struct {
name string
input string
expected keyType
err string
}{
{
name: "uppercase EC is recognized",
input: `"EC"`,
expected: keyTypeEC,
},
{
name: "lowercase EC is recognized",
input: `"ec"`,
err: "unknown key type: ec",
},
{
name: "mixed case EC is recognized",
input: `"eC"`,
err: "unknown key type: eC",
},
{
name: "uppercase RSA is recognized",
input: `"RSA"`,
expected: keyTypeRSA,
},
{
name: "lowercase rsa is not accepted",
input: `"rsa"`,
err: "unknown key type: rsa",
},
{
name: "mixed case RSA is not accepted",
input: `"RsA"`,
err: "unknown key type: RsA",
},
{
name: "uppercase OKP is recognized",
input: `"OKP"`,
expected: keyTypeOKP,
},
{
name: "lowercase OKP is not accepted",
input: `"okp"`,
err: "unknown key type: okp",
},
{
name: "mixed case OKP is not accepted",
input: `"OkP"`,
err: "unknown key type: OkP",
},
{
name: "unknown key type is an error",
input: `"foo"`,
err: "unknown key type: foo",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var kt keyType
err := json.Unmarshal([]byte(test.input), &kt)
if test.err != "" {
if err == nil {
t.Errorf("expected error %q, but got nil", test.err)
}
if err.Error() != test.err {
t.Errorf("expected error %q, but got %q", test.err, err.Error())
}
return
}
if err != nil {
t.Errorf("expected no error, but got %q", err.Error())
return
}
if kt != test.expected {
t.Errorf("expected %v, but got %v", test.expected, kt)
}
})
}
}
func TestCSRRequestValidate(t *testing.T) {
tests := []struct {
name string
key *keyParameters
wantErr bool
}{
{
name: "empty request is valid",
key: nil,
wantErr: false,
},
{
name: "RSA with size 2048 is valid",
key: &keyParameters{
Type: keyTypeRSA,
Size: 2048,
},
wantErr: false,
},
{
name: "RSA with size less than 2048 is invalid",
key: &keyParameters{
Type: keyTypeRSA,
Size: 1024,
},
wantErr: true,
},
{
name: "EC key with curve P-256 is valid",
key: &keyParameters{
Type: keyTypeEC,
Curve: "P-256",
},
wantErr: false,
},
{
name: "EC key with curve P-256 is valid",
key: &keyParameters{
Type: keyTypeEC,
Curve: "P-256",
},
wantErr: false,
},
{
name: "EC key with curve P-384 is valid",
key: &keyParameters{
Type: keyTypeEC,
Curve: "P-384",
},
wantErr: false,
},
{
name: "EC key with curve P-521 is valid",
key: &keyParameters{
Type: keyTypeEC,
Curve: "P-521",
},
wantErr: false,
},
{
name: "EC key with unknown curve is invalid",
key: &keyParameters{
Type: keyTypeEC,
Curve: "foo",
},
wantErr: true,
},
{
name: "EC key with Ed25519 curve is invalid",
key: &keyParameters{
Type: keyTypeEC,
Curve: "Ed25519",
},
wantErr: true,
},
{
name: "EC key with X25519 curve is invalid",
key: &keyParameters{
Type: keyTypeEC,
Curve: "X25519",
},
wantErr: true,
},
{
name: "OKP key with curve Ed25519 is valid",
key: &keyParameters{
Type: keyTypeOKP,
Curve: "Ed25519",
},
wantErr: false,
},
{
name: "OKP key with curve X25519 is valid",
key: &keyParameters{
Type: keyTypeOKP,
Curve: "X25519",
},
wantErr: false,
},
{
name: "OKP with unknown curve is invalid",
key: &keyParameters{
Type: keyTypeOKP,
Curve: "foo",
},
wantErr: true,
},
{
name: "OKP key with curve P-256 is invalid",
key: &keyParameters{
Type: keyTypeOKP,
Curve: "P-256",
},
wantErr: true,
},
{
name: "OKP key with curve P-384 is invalid",
key: &keyParameters{
Type: keyTypeOKP,
Curve: "P-384",
},
wantErr: true,
},
{
name: "OKP key with curve P-521 is invalid",
key: &keyParameters{
Type: keyTypeOKP,
Curve: "P-521",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := csrRequest{
Key: tt.key,
}
if err := c.validate(); (err != nil) != tt.wantErr {
t.Errorf("csrRequest.validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestCSRRequestUnmarshalJSON(t *testing.T) {
tests := []struct {
name string
request string
want csrRequest
err string
}{
{
name: "empty request is valid",
request: "{}",
want: csrRequest{
Key: nil,
},
},
{
name: "RSA with size 2048 is valid",
request: `{"key":{"type":"RSA","size":2048}}`,
want: csrRequest{
Key: &keyParameters{
Type: keyTypeRSA,
Size: 2048,
},
},
},
{
name: "EC key with curve P-256 is valid",
request: `{"key":{"type":"EC","curve":"P-256"}}`,
want: csrRequest{
Key: &keyParameters{
Type: keyTypeEC,
Curve: "P-256",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var c csrRequest
err := json.Unmarshal([]byte(tt.request), &c)
if tt.err != "" {
if err == nil {
t.Errorf("expected error %q, but got nil", tt.err)
}
if err.Error() != tt.err {
t.Errorf("expected error %q, but got %q", tt.err, err.Error())
}
}
if err != nil {
t.Errorf("expected no error, but got %q", err.Error())
}
if !reflect.DeepEqual(c, tt.want) {
t.Errorf("csrRequest.unmarshalJSON() = %v, want %v", c, tt.want)
}
})
}
}