From e1ea58b7c4dcd9aea5f51383d0029830842a2994 Mon Sep 17 00:00:00 2001
From: elcore <eldinhadzic@protonmail.com>
Date: Mon, 3 Oct 2016 18:52:45 +0200
Subject: [PATCH] Customize curve preferences, closes #1117 (#1159)

* Feature Request: #1117

* The order of the curves matter
---
 caddytls/config.go     | 20 ++++++++++++++++++++
 caddytls/setup.go      |  8 ++++++++
 caddytls/setup_test.go | 37 +++++++++++++++++++++++++++++++++++++
 3 files changed, 65 insertions(+)

diff --git a/caddytls/config.go b/caddytls/config.go
index 80d301426..92e5729a5 100644
--- a/caddytls/config.go
+++ b/caddytls/config.go
@@ -35,6 +35,9 @@ type Config struct {
 	// Whether to prefer server cipher suites
 	PreferServerCipherSuites bool
 
+	// The list of preferred curves
+	CurvePreferences []tls.CurveID
+
 	// Client authentication policy
 	ClientAuth tls.ClientAuthType
 
@@ -220,6 +223,7 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
 
 	config := new(tls.Config)
 	ciphersAdded := make(map[uint16]struct{})
+	curvesAdded := make(map[tls.CurveID]struct{})
 	configMap := make(configGroup)
 
 	for i, cfg := range configs {
@@ -264,6 +268,14 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
 		}
 		config.PreferServerCipherSuites = cfg.PreferServerCipherSuites
 
+		// Union curves
+		for _, curv := range cfg.CurvePreferences {
+			if _, ok := curvesAdded[curv]; !ok {
+				curvesAdded[curv] = struct{}{}
+				config.CurvePreferences = append(config.CurvePreferences, curv)
+			}
+		}
+
 		// Go with the widest range of protocol versions
 		if config.MinVersion == 0 || cfg.ProtocolMinVersion < config.MinVersion {
 			config.MinVersion = cfg.ProtocolMinVersion
@@ -441,6 +453,14 @@ var defaultCiphers = []uint16{
 	tls.TLS_RSA_WITH_AES_128_CBC_SHA,
 }
 
+// Map of supported curves
+// https://golang.org/pkg/crypto/tls/#CurveID
+var supportedCurvesMap = map[string]tls.CurveID{
+	"P256": tls.CurveP256,
+	"P384": tls.CurveP384,
+	"P521": tls.CurveP521,
+}
+
 const (
 	// HTTPChallengePort is the officially designated port for
 	// the HTTP challenge.
diff --git a/caddytls/setup.go b/caddytls/setup.go
index e782e04ea..8e822015b 100644
--- a/caddytls/setup.go
+++ b/caddytls/setup.go
@@ -105,6 +105,14 @@ func setupTLS(c *caddy.Controller) error {
 					}
 					config.Ciphers = append(config.Ciphers, value)
 				}
+			case "curves":
+				for c.NextArg() {
+					value, ok := supportedCurvesMap[strings.ToUpper(c.Val())]
+					if !ok {
+						return c.Errf("Wrong curve name or curve not supported: '%s'", c.Val())
+					}
+					config.CurvePreferences = append(config.CurvePreferences, value)
+				}
 			case "clients":
 				clientCertList := c.RemainingArgs()
 				if len(clientCertList) == 0 {
diff --git a/caddytls/setup_test.go b/caddytls/setup_test.go
index 2c18f1d1a..8811b2c63 100644
--- a/caddytls/setup_test.go
+++ b/caddytls/setup_test.go
@@ -179,6 +179,18 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
 	if err == nil {
 		t.Errorf("Expected errors, but no error returned")
 	}
+
+	// Test curves wrong params
+	params = `tls {
+			curves ab123, cd456, ef789
+		}`
+	cfg = new(Config)
+	RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
+	c = caddy.NewTestController("", params)
+	err = setupTLS(c)
+	if err == nil {
+		t.Errorf("Expected errors, but no error returned")
+	}
 }
 
 func TestSetupParseWithClientAuth(t *testing.T) {
@@ -269,6 +281,31 @@ func TestSetupParseWithKeyType(t *testing.T) {
 	}
 }
 
+func TestSetupParseWithCurves(t *testing.T) {
+	params := `tls {
+            curves p256 p384 p521
+        }`
+	cfg := new(Config)
+	RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
+	c := caddy.NewTestController("", params)
+
+	err := setupTLS(c)
+	if err != nil {
+		t.Errorf("Expected no errors, got: %v", err)
+	}
+
+	if len(cfg.CurvePreferences) != 3 {
+		t.Errorf("Expected 3 curves, got %v", len(cfg.CurvePreferences))
+	}
+
+	expectedCurveOrder := []tls.CurveID{tls.CurveP256, tls.CurveP384, tls.CurveP521}
+	for i := range cfg.CurvePreferences {
+		if cfg.CurvePreferences[i] != expectedCurveOrder[i] {
+			t.Errorf("Expected %v as curve, got %v", expectedCurveOrder[i], cfg.CurvePreferences[i])
+		}
+	}
+}
+
 func TestSetupParseWithOneTLSProtocol(t *testing.T) {
 	params := `tls {
             protocols tls1.2