mox/scram/scram_test.go
Mechiel Lukkien e7478ed6ac
implement the plus variants of scram, to bind the authentication exchange to the tls connection
to get the security benefits (detecting mitm attempts), explicitly configure
clients to use a scram plus variant, e.g. scram-sha-256-plus. unfortunately,
not many clients support it yet.

imapserver scram plus support seems to work with the latest imtest (imap test
client) from cyrus-sasl. no success yet with mutt (with gsasl) though.
2023-12-23 23:19:36 +01:00

300 lines
11 KiB
Go

package scram
import (
"crypto/ed25519"
cryptorand "crypto/rand"
"crypto/sha1"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"errors"
"hash"
"math/big"
"net"
"testing"
"time"
)
// base64Decode returns the bytes encoded by the standard-alphabet base64
// string s. It panics with "bad base64" when s cannot be decoded; in this
// file it is only called with fixed test-vector literals.
func base64Decode(s string) []byte {
	decoded, decodeErr := base64.StdEncoding.DecodeString(s)
	if decodeErr == nil {
		return decoded
	}
	panic("bad base64")
}
// tcheck aborts the running test when err is non-nil, reporting msg followed
// by the error text. It registers itself as a test helper so the failure is
// attributed to the caller's line.
func tcheck(t *testing.T, err error, msg string) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("%s: %s", msg, err)
}
// TestSCRAMSHA1Server runs the server side of a SCRAM-SHA-1 exchange against
// fixed client messages and checks the exact server-first and server-final
// messages that are produced.
func TestSCRAMSHA1Server(t *testing.T) {
	// Test vector from ../rfc/5802:496
	salt := base64Decode("QSXCR+Q6sek8bf92")
	saltedPassword := SaltPassword(sha1.New, "pencil", salt, 4096)

	server, err := NewServer(sha1.New, []byte("n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL"), nil, false)
	// Check the error before touching server, so a NewServer failure is
	// reported as such instead of the test dereferencing an unusable value.
	tcheck(t, err, "newserver")
	// Pin the server part of the nonce so the output matches the test vector.
	server.serverNonceOverride = "3rfcNHYJY1ZVvWVs7j"

	resp, err := server.ServerFirst(4096, salt)
	tcheck(t, err, "server first")
	if resp != "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096" {
		t.Fatalf("bad server first")
	}

	serverFinal, err := server.Finish([]byte("c=biws,r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,p=v0X8v3Bz2T0CJGbJQyF0X+HI4Ts="), saltedPassword)
	tcheck(t, err, "finish")
	if serverFinal != "v=rmF9pqV8S7suAoZWja4dJRkFsKQ=" {
		t.Fatalf("bad server final")
	}
}
// TestSCRAMSHA256Server runs the server side of a SCRAM-SHA-256 exchange
// against fixed client messages and checks the exact server-first and
// server-final messages that are produced.
func TestSCRAMSHA256Server(t *testing.T) {
	// Test vector from ../rfc/7677:122
	salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
	saltedPassword := SaltPassword(sha256.New, "pencil", salt, 4096)

	server, err := NewServer(sha256.New, []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"), nil, false)
	// Check the error before touching server, so a NewServer failure is
	// reported as such instead of the test dereferencing an unusable value.
	tcheck(t, err, "newserver")
	// Pin the server part of the nonce so the output matches the test vector.
	server.serverNonceOverride = "%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0"

	resp, err := server.ServerFirst(4096, salt)
	tcheck(t, err, "server first")
	if resp != "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096" {
		t.Fatalf("bad server first")
	}

	serverFinal, err := server.Finish([]byte("c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
	tcheck(t, err, "finish")
	if serverFinal != "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=" {
		t.Fatalf("bad server final")
	}
}
// TestScramServerBadPassword checks that a bad attempt with the wrong password
// is rejected: the server's salted password is derived from "marker", while the
// client-final message carries the proof used with password "pencil" in
// TestSCRAMSHA256Server. Finish must return ErrInvalidProof.
func TestScramServerBadPassword(t *testing.T) {
	salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
	saltedPassword := SaltPassword(sha256.New, "marker", salt, 4096)

	server, err := NewServer(sha256.New, []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"), nil, false)
	// Check the error before touching server, so a NewServer failure is
	// reported as such instead of the test dereferencing an unusable value.
	tcheck(t, err, "newserver")
	// Pin the server part of the nonce so it matches the nonce in the
	// client-final message below.
	server.serverNonceOverride = "%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0"

	_, err = server.ServerFirst(4096, salt)
	tcheck(t, err, "server first")

	_, err = server.Finish([]byte("c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
	if !errors.Is(err, ErrInvalidProof) {
		t.Fatalf("got %v, expected ErrInvalidProof", err)
	}
}
// TestScramServerBadIterations checks that a bad attempt with a different
// number of rounds is rejected: the server's salted password is derived with
// 2048 iterations, while the server-first message is generated for 4096 and
// the client-final message carries the proof used with 4096 iterations in
// TestSCRAMSHA256Server. Finish must return ErrInvalidProof.
func TestScramServerBadIterations(t *testing.T) {
	salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
	saltedPassword := SaltPassword(sha256.New, "pencil", salt, 2048)

	server, err := NewServer(sha256.New, []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"), nil, false)
	// Check the error before touching server, so a NewServer failure is
	// reported as such instead of the test dereferencing an unusable value.
	tcheck(t, err, "newserver")
	// Pin the server part of the nonce so it matches the nonce in the
	// client-final message below.
	server.serverNonceOverride = "%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0"

	_, err = server.ServerFirst(4096, salt)
	tcheck(t, err, "server first")

	_, err = server.Finish([]byte("c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
	if !errors.Is(err, ErrInvalidProof) {
		t.Fatalf("got %v, expected ErrInvalidProof", err)
	}
}
// TestScramServerBad is another bad attempt: no server nonce override is set,
// so the server uses its own nonce. The client-final message echoes that nonce
// but carries the proof from the fixed-nonce exchange, and Finish must return
// ErrInvalidProof.
func TestScramServerBad(t *testing.T) {
	salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
	saltedPassword := SaltPassword(sha256.New, "pencil", salt, 4096)

	srv, err := NewServer(sha256.New, []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"), nil, false)
	tcheck(t, err, "newserver")

	_, err = srv.ServerFirst(4096, salt)
	tcheck(t, err, "server first")

	// Use the nonce the server actually produced together with the old proof.
	clientFinal := "c=biws,r=" + srv.nonce + ",p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="
	if _, err = srv.Finish([]byte(clientFinal), saltedPassword); errors.Is(err, ErrInvalidProof) {
		return
	}
	t.Fatalf("got %v, expected ErrInvalidProof", err)
}
// TestScramClient drives the client side of a SCRAM-SHA-256 exchange with a
// fixed client nonce and checks the client-first and client-final messages
// byte for byte, then verifies that the server-final message is accepted.
func TestScramClient(t *testing.T) {
	const (
		wantClientFirst = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO"
		serverFirst     = "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096"
		wantClientFinal = "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="
		serverFinal     = "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="
	)

	client := NewClient(sha256.New, "user", "", false, nil)
	// Fix the client nonce so the generated messages are deterministic.
	client.clientNonce = "rOprNGfwEbeRWgbNEkqO"

	first, err := client.ClientFirst()
	tcheck(t, err, "ClientFirst")
	if first != wantClientFirst {
		t.Fatalf("bad clientFirst")
	}

	final, err := client.ServerFirst([]byte(serverFirst), "pencil")
	tcheck(t, err, "ServerFirst")
	if final != wantClientFinal {
		t.Fatalf("bad clientFinal")
	}

	err = client.ServerFinal([]byte(serverFinal))
	tcheck(t, err, "ServerFinal")
}
// TestScram runs complete client/server SCRAM exchanges in-process, with both
// SHA-1 and SHA-256. It covers exchanges without TLS connection state, exchanges
// with connection states taken from in-memory TLS handshakes (the PLUS
// variants), and cases that are expected to fail with a specific error.
func TestScram(t *testing.T) {
// runHash performs one full exchange with hash h between a new client and a
// new server. Any error returned by a step must match expErr. If expErr is
// non-nil and the exchange completes without error, the test fails.
// clientcs and servercs are the TLS connection states handed to the client
// and the server; nil means no TLS state.
// NOTE(review): authzid is accepted but never used; NewClient below is
// passed "" instead, so the cases with a non-empty authzid do not exercise
// it. Confirm whether that is intentional.
runHash := func(h func() hash.Hash, expErr error, username, authzid, password string, iterations int, clientNonce, serverNonce string, noServerPlus bool, clientcs, servercs *tls.ConnectionState) {
t.Helper()
// xerr aborts the exchange by panicking with "". Swallow that sentinel
// here so the test continues with the next case, and re-raise any other
// panic.
defer func() {
x := recover()
if x == nil || x == "" {
return
}
panic(x)
}()
// check err is either nil or the expected error. if the expected error, panic to abort the authentication session.
xerr := func(err error, msg string) {
t.Helper()
if err != nil && !errors.Is(err, expErr) {
t.Fatalf("%s: got %v, expected %v", msg, err, expErr)
}
if err != nil {
panic("") // Abort test.
}
}
salt := MakeRandom()
saltedPassword := SaltPassword(h, password, salt, iterations)
client := NewClient(h, username, "", noServerPlus, clientcs)
client.clientNonce = clientNonce
clientFirst, err := client.ClientFirst()
xerr(err, "client.ClientFirst")
server, err := NewServer(h, []byte(clientFirst), servercs, servercs != nil)
xerr(err, "NewServer")
server.serverNonceOverride = serverNonce
serverFirst, err := server.ServerFirst(iterations, salt)
xerr(err, "server.ServerFirst")
clientFinal, err := client.ServerFirst([]byte(serverFirst), password)
xerr(err, "client.ServerFirst")
serverFinal, err := server.Finish([]byte(clientFinal), saltedPassword)
xerr(err, "server.Finish")
err = client.ServerFinal([]byte(serverFinal))
xerr(err, "client.ServerFinal")
// Reaching this point means no step returned an error.
if expErr != nil {
t.Fatalf("got no error, expected %v", expErr)
}
}
// makeState performs a TLS handshake between the two ends of a net.Pipe and
// returns the resulting client and server connection states. maxTLSVersion
// is set as MaxVersion on both configs.
makeState := func(maxTLSVersion uint16) (tls.ConnectionState, tls.ConnectionState) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
tlsClient := tls.Client(client, &tls.Config{
InsecureSkipVerify: true,
MaxVersion: maxTLSVersion,
})
tlsServer := tls.Server(server, &tls.Config{
Certificates: []tls.Certificate{fakeCert(t, "mox.example", false)},
MaxVersion: maxTLSVersion,
})
// The server handshake runs in its own goroutine because both sides must
// make progress at the same time. The channel has room for one value, so
// the goroutine can deliver its result even if the client-side check
// below aborts the test before the receive.
errc := make(chan error, 1)
go func() {
errc <- tlsServer.Handshake()
}()
err := tlsClient.Handshake()
tcheck(t, err, "tls handshake")
err = <-errc
tcheck(t, err, "server tls handshake")
clientcs := tlsClient.ConnectionState()
servercs := tlsServer.ConnectionState()
return clientcs, servercs
}
// runPlus runs the exchange for both hashes with the client and server
// connection states of a single TLS handshake.
runPlus := func(maxTLSVersion uint16, expErr error, username, authzid, password string, iterations int, clientNonce, serverNonce string) {
t.Helper()
// PLUS variants.
clientcs, servercs := makeState(maxTLSVersion)
runHash(sha1.New, expErr, username, authzid, password, iterations, clientNonce, serverNonce, false, &clientcs, &servercs)
runHash(sha256.New, expErr, username, authzid, password, iterations, clientNonce, serverNonce, false, &clientcs, &servercs)
}
// run executes one case for both hashes without TLS state, and then with
// TLS state from a TLS 1.2 handshake and from a handshake without an
// explicit maximum version.
run := func(expErr error, username, authzid, password string, iterations int, clientNonce, serverNonce string) {
t.Helper()
// Bare variants
runHash(sha1.New, expErr, username, authzid, password, iterations, clientNonce, serverNonce, false, nil, nil)
runHash(sha256.New, expErr, username, authzid, password, iterations, clientNonce, serverNonce, false, nil, nil)
// Check with both TLS 1.2 for "tls-unique", and latest TLS for "tls-exporter".
runPlus(tls.VersionTLS12, expErr, username, authzid, password, iterations, clientNonce, serverNonce)
runPlus(0, expErr, username, authzid, password, iterations, clientNonce, serverNonce)
}
// Cases expected to succeed.
run(nil, "user", "", "pencil", 4096, "", "")
run(nil, "mjl@mox.example", "", "testtest", 4096, "", "")
run(nil, "mjl@mox.example", "", "short", 4096, "", "")
run(nil, "mjl@mox.example", "", "short", 2048, "", "")
run(nil, "mjl@mox.example", "mjl@mox.example", "testtest", 4096, "", "")
run(nil, "mjl@mox.example", "other@mox.example", "testtest", 4096, "", "")
// NOTE(review): the next case is identical to the previous one; possibly a
// different variation was intended.
run(nil, "mjl@mox.example", "other@mox.example", "testtest", 4096, "", "")
// Cases expected to fail with ErrUnsafe.
run(ErrUnsafe, "user", "", "pencil", 1, "", "") // Few iterations.
run(ErrUnsafe, "user", "", "pencil", 2048, "short", "") // Short client nonce.
run(ErrUnsafe, "user", "", "pencil", 2048, "test1234", "test") // Server added too few random data.
// Test mechanism downgrade attacks are detected.
runHash(sha1.New, ErrServerDoesSupportChannelBinding, "user", "", "pencil", 4096, "", "", true, nil, nil)
runHash(sha256.New, ErrServerDoesSupportChannelBinding, "user", "", "pencil", 4096, "", "", true, nil, nil)
// Test channel binding, detecting MitM attacks.
runChannelBind := func(maxTLSVersion uint16) {
t.Helper()
// Take the client state from one handshake and the server state from
// another, separate handshake, so both sides are on different TLS
// connections.
clientcs0, _ := makeState(maxTLSVersion)
_, servercs1 := makeState(maxTLSVersion)
runHash(sha1.New, ErrChannelBindingsDontMatch, "user", "", "pencil", 4096, "", "", false, &clientcs0, &servercs1)
runHash(sha256.New, ErrChannelBindingsDontMatch, "user", "", "pencil", 4096, "", "", false, &clientcs0, &servercs1)
// Client thinks it is on a TLS connection and server is not.
runHash(sha1.New, ErrChannelBindingsDontMatch, "user", "", "pencil", 4096, "", "", false, &clientcs0, nil)
runHash(sha256.New, ErrChannelBindingsDontMatch, "user", "", "pencil", 4096, "", "", false, &clientcs0, nil)
}
runChannelBind(0)
runChannelBind(tls.VersionTLS12)
}
// fakeCert returns a self-signed certificate for DNS name name that merely
// appears valid. Its NotBefore is one hour ago; its NotAfter is one hour ahead,
// or one hour in the past when expired is true. The key is derived from an
// all-zero seed, so the certificate must only be used in tests. Any failure to
// create or parse the certificate aborts the test through t.
func fakeCert(t *testing.T, name string, expired bool) tls.Certificate {
	// Offset of NotAfter relative to now.
	validity := time.Hour
	if expired {
		validity = -time.Hour
	}
	notAfter := time.Now().Add(validity)

	seed := make([]byte, ed25519.SeedSize)
	key := ed25519.NewKeyFromSeed(seed) // Fake key, don't use this for real!

	tmpl := &x509.Certificate{
		SerialNumber: big.NewInt(1), // Required field...
		DNSNames:     []string{name},
		NotBefore:    time.Now().Add(-time.Hour),
		NotAfter:     notAfter,
	}

	// The template doubles as parent, making the certificate self-signed.
	der, err := x509.CreateCertificate(cryptorand.Reader, tmpl, tmpl, key.Public(), key)
	if err != nil {
		t.Fatalf("making certificate: %s", err)
	}

	leaf, err := x509.ParseCertificate(der)
	if err != nil {
		t.Fatalf("parsing generated certificate: %s", err)
	}

	return tls.Certificate{
		Certificate: [][]byte{der},
		PrivateKey:  key,
		Leaf:        leaf,
	}
}