update to quic-go v0.10.0 (#2288)

quic-go now vendors all of its dependencies, so we don't need to vendor
them here.

Created by running:
gvt delete github.com/lucas-clemente/quic-go
gvt delete github.com/bifurcation/mint
gvt delete github.com/lucas-clemente/aes12
gvt delete github.com/lucas-clemente/fnv128a
gvt delete github.com/lucas-clemente/quic-go-certificates
gvt delete github.com/aead/chacha20
gvt delete github.com/hashicorp/golang-lru
gvt fetch -tag v0.10.0-no-integrationtests github.com/lucas-clemente/quic-go
This commit is contained in:
Marten Seemann 2018-09-03 04:18:54 +07:00 committed by Matt Holt
parent 9edc16e4d6
commit dfbc2e81e3
185 changed files with 7807 additions and 12112 deletions

View file

@ -9,6 +9,7 @@ package chacha // import "github.com/aead/chacha20/chacha"
import (
"encoding/binary"
"errors"
"math"
)
const (
@ -28,6 +29,7 @@ const (
var (
useSSE2 bool
useSSSE3 bool
useAVX bool
useAVX2 bool
)
@ -55,7 +57,7 @@ func setup(state *[64]byte, nonce, key []byte) (err error) {
copy(hNonce[:], nonce[:16])
copy(tmpKey[:], key)
hChaCha20(&tmpKey, &hNonce, &tmpKey)
HChaCha20(&tmpKey, &hNonce, &tmpKey)
copy(Nonce[8:], nonce[16:])
initialize(state, tmpKey[:], &Nonce)
@ -161,6 +163,21 @@ func (c *Cipher) XORKeyStream(dst, src []byte) {
c.off = 0
}
// check for counter overflow
blocksToXOR := len(src) / 64
if len(src)%64 != 0 {
blocksToXOR++
}
var overflow bool
if c.noncesize == INonceSize {
overflow = binary.LittleEndian.Uint32(c.state[48:]) > math.MaxUint32-uint32(blocksToXOR)
} else {
overflow = binary.LittleEndian.Uint64(c.state[48:]) > math.MaxUint64-uint64(blocksToXOR)
}
if overflow {
panic("chacha20/chacha: counter overflow")
}
c.off += xorKeyStream(dst, src, &(c.block), &(c.state), c.rounds)
}
@ -174,3 +191,7 @@ func (c *Cipher) SetCounter(ctr uint64) {
}
c.off = 0
}
// HChaCha20 generates 32 pseudo-random bytes from a 128 bit nonce and a 256 bit secret key.
// It can be used as a key-derivation-function (KDF).
func HChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { hChaCha20(out, nonce, key) }

View file

@ -2,111 +2,10 @@
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build go1.7,amd64,!gccgo,!appengine,!nacl
// +build amd64,!gccgo,!appengine,!nacl
#include "textflag.h"
DATA ·sigma_AVX<>+0x00(SB)/4, $0x61707865
DATA ·sigma_AVX<>+0x04(SB)/4, $0x3320646e
DATA ·sigma_AVX<>+0x08(SB)/4, $0x79622d32
DATA ·sigma_AVX<>+0x0C(SB)/4, $0x6b206574
GLOBL ·sigma_AVX<>(SB), (NOPTR+RODATA), $16
DATA ·one_AVX<>+0x00(SB)/8, $1
DATA ·one_AVX<>+0x08(SB)/8, $0
GLOBL ·one_AVX<>(SB), (NOPTR+RODATA), $16
DATA ·one_AVX2<>+0x00(SB)/8, $0
DATA ·one_AVX2<>+0x08(SB)/8, $0
DATA ·one_AVX2<>+0x10(SB)/8, $1
DATA ·one_AVX2<>+0x18(SB)/8, $0
GLOBL ·one_AVX2<>(SB), (NOPTR+RODATA), $32
DATA ·two_AVX2<>+0x00(SB)/8, $2
DATA ·two_AVX2<>+0x08(SB)/8, $0
DATA ·two_AVX2<>+0x10(SB)/8, $2
DATA ·two_AVX2<>+0x18(SB)/8, $0
GLOBL ·two_AVX2<>(SB), (NOPTR+RODATA), $32
DATA ·rol16_AVX2<>+0x00(SB)/8, $0x0504070601000302
DATA ·rol16_AVX2<>+0x08(SB)/8, $0x0D0C0F0E09080B0A
DATA ·rol16_AVX2<>+0x10(SB)/8, $0x0504070601000302
DATA ·rol16_AVX2<>+0x18(SB)/8, $0x0D0C0F0E09080B0A
GLOBL ·rol16_AVX2<>(SB), (NOPTR+RODATA), $32
DATA ·rol8_AVX2<>+0x00(SB)/8, $0x0605040702010003
DATA ·rol8_AVX2<>+0x08(SB)/8, $0x0E0D0C0F0A09080B
DATA ·rol8_AVX2<>+0x10(SB)/8, $0x0605040702010003
DATA ·rol8_AVX2<>+0x18(SB)/8, $0x0E0D0C0F0A09080B
GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32
#define ROTL(n, t, v) \
VPSLLD $n, v, t; \
VPSRLD $(32-n), v, v; \
VPXOR v, t, v
#define CHACHA_QROUND(v0, v1, v2, v3, t, c16, c8) \
VPADDD v0, v1, v0; \
VPXOR v3, v0, v3; \
VPSHUFB c16, v3, v3; \
VPADDD v2, v3, v2; \
VPXOR v1, v2, v1; \
ROTL(12, t, v1); \
VPADDD v0, v1, v0; \
VPXOR v3, v0, v3; \
VPSHUFB c8, v3, v3; \
VPADDD v2, v3, v2; \
VPXOR v1, v2, v1; \
ROTL(7, t, v1)
#define CHACHA_SHUFFLE(v1, v2, v3) \
VPSHUFD $0x39, v1, v1; \
VPSHUFD $0x4E, v2, v2; \
VPSHUFD $-109, v3, v3
#define XOR_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \
VMOVDQU (0+off)(src), t0; \
VPERM2I128 $32, v1, v0, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (0+off)(dst); \
VMOVDQU (32+off)(src), t0; \
VPERM2I128 $32, v3, v2, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (32+off)(dst); \
VMOVDQU (64+off)(src), t0; \
VPERM2I128 $49, v1, v0, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (64+off)(dst); \
VMOVDQU (96+off)(src), t0; \
VPERM2I128 $49, v3, v2, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (96+off)(dst)
#define XOR_UPPER_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \
VMOVDQU (0+off)(src), t0; \
VPERM2I128 $32, v1, v0, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (0+off)(dst); \
VMOVDQU (32+off)(src), t0; \
VPERM2I128 $32, v3, v2, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (32+off)(dst); \
#define EXTRACT_LOWER(dst, v0, v1, v2, v3, t0) \
VPERM2I128 $49, v1, v0, t0; \
VMOVDQU t0, 0(dst); \
VPERM2I128 $49, v3, v2, t0; \
VMOVDQU t0, 32(dst)
#define XOR_AVX(dst, src, off, v0, v1, v2, v3, t0) \
VPXOR 0+off(src), v0, t0; \
VMOVDQU t0, 0+off(dst); \
VPXOR 16+off(src), v1, t0; \
VMOVDQU t0, 16+off(dst); \
VPXOR 32+off(src), v2, t0; \
VMOVDQU t0, 32+off(dst); \
VPXOR 48+off(src), v3, t0; \
VMOVDQU t0, 48+off(dst)
#include "const.s"
#include "macro.s"
#define TWO 0(SP)
#define C16 32(SP)
@ -122,10 +21,10 @@ GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32
TEXT ·xorKeyStreamAVX2(SB), 4, $320-80
MOVQ dst_base+0(FP), DI
MOVQ src_base+24(FP), SI
MOVQ src_len+32(FP), CX
MOVQ block+48(FP), BX
MOVQ state+56(FP), AX
MOVQ rounds+64(FP), DX
MOVQ src_len+32(FP), CX
MOVQ SP, R8
ADDQ $32, SP
@ -185,28 +84,28 @@ at_least_512:
chacha_loop_512:
VMOVDQA Y8, TMP_0
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y8, C16, C8)
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y8, C16, C8)
CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y8, C16, C8)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y8, C16, C8)
VMOVDQA TMP_0, Y8
VMOVDQA Y0, TMP_0
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y0, C16, C8)
CHACHA_QROUND(Y12, Y13, Y14, Y15, Y0, C16, C8)
CHACHA_SHUFFLE(Y1, Y2, Y3)
CHACHA_SHUFFLE(Y5, Y6, Y7)
CHACHA_SHUFFLE(Y9, Y10, Y11)
CHACHA_SHUFFLE(Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y0, C16, C8)
CHACHA_QROUND_AVX(Y12, Y13, Y14, Y15, Y0, C16, C8)
CHACHA_SHUFFLE_AVX(Y1, Y2, Y3)
CHACHA_SHUFFLE_AVX(Y5, Y6, Y7)
CHACHA_SHUFFLE_AVX(Y9, Y10, Y11)
CHACHA_SHUFFLE_AVX(Y13, Y14, Y15)
CHACHA_QROUND(Y12, Y13, Y14, Y15, Y0, C16, C8)
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y0, C16, C8)
CHACHA_QROUND_AVX(Y12, Y13, Y14, Y15, Y0, C16, C8)
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y0, C16, C8)
VMOVDQA TMP_0, Y0
VMOVDQA Y8, TMP_0
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y8, C16, C8)
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y8, C16, C8)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y8, C16, C8)
CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y8, C16, C8)
VMOVDQA TMP_0, Y8
CHACHA_SHUFFLE(Y3, Y2, Y1)
CHACHA_SHUFFLE(Y7, Y6, Y5)
CHACHA_SHUFFLE(Y11, Y10, Y9)
CHACHA_SHUFFLE(Y15, Y14, Y13)
CHACHA_SHUFFLE_AVX(Y3, Y2, Y1)
CHACHA_SHUFFLE_AVX(Y7, Y6, Y5)
CHACHA_SHUFFLE_AVX(Y11, Y10, Y9)
CHACHA_SHUFFLE_AVX(Y15, Y14, Y13)
SUBQ $2, R9
JA chacha_loop_512
@ -289,18 +188,18 @@ between_320_and_448:
MOVQ DX, R9
chacha_loop_384:
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y13, Y14, Y15)
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE(Y1, Y2, Y3)
CHACHA_SHUFFLE(Y5, Y6, Y7)
CHACHA_SHUFFLE(Y9, Y10, Y11)
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y13, Y14, Y15)
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE(Y3, Y2, Y1)
CHACHA_SHUFFLE(Y7, Y6, Y5)
CHACHA_SHUFFLE(Y11, Y10, Y9)
CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y1, Y2, Y3)
CHACHA_SHUFFLE_AVX(Y5, Y6, Y7)
CHACHA_SHUFFLE_AVX(Y9, Y10, Y11)
CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y3, Y2, Y1)
CHACHA_SHUFFLE_AVX(Y7, Y6, Y5)
CHACHA_SHUFFLE_AVX(Y11, Y10, Y9)
SUBQ $2, R9
JA chacha_loop_384
@ -361,14 +260,14 @@ between_192_and_320:
MOVQ DX, R9
chacha_loop_256:
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE(Y5, Y6, Y7)
CHACHA_SHUFFLE(Y9, Y10, Y11)
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE(Y7, Y6, Y5)
CHACHA_SHUFFLE(Y11, Y10, Y9)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y5, Y6, Y7)
CHACHA_SHUFFLE_AVX(Y9, Y10, Y11)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y7, Y6, Y5)
CHACHA_SHUFFLE_AVX(Y11, Y10, Y9)
SUBQ $2, R9
JA chacha_loop_256
@ -413,10 +312,10 @@ between_64_and_192:
MOVQ DX, R9
chacha_loop_128:
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_SHUFFLE(Y5, Y6, Y7)
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_SHUFFLE(Y7, Y6, Y5)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y5, Y6, Y7)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y7, Y6, Y5)
SUBQ $2, R9
JA chacha_loop_128
@ -455,10 +354,10 @@ between_0_and_64:
MOVQ DX, R9
chacha_loop_64:
CHACHA_QROUND(X4, X5, X6, X7, X13, X14, X15)
CHACHA_SHUFFLE(X5, X6, X7)
CHACHA_QROUND(X4, X5, X6, X7, X13, X14, X15)
CHACHA_SHUFFLE(X7, X6, X5)
CHACHA_QROUND_AVX(X4, X5, X6, X7, X13, X14, X15)
CHACHA_SHUFFLE_AVX(X5, X6, X7)
CHACHA_QROUND_AVX(X4, X5, X6, X7, X13, X14, X15)
CHACHA_SHUFFLE_AVX(X7, X6, X5)
SUBQ $2, R9
JA chacha_loop_64
@ -466,7 +365,7 @@ chacha_loop_64:
VPADDD X1, X5, X5
VPADDD X2, X6, X6
VPADDD X3, X7, X7
VMOVDQU ·one_AVX<>(SB), X0
VMOVDQU ·one<>(SB), X0
VPADDQ X0, X3, X3
CMPQ CX, $64
@ -505,38 +404,3 @@ done:
MOVQ CX, ret+72(FP)
RET
// func hChaCha20AVX(out *[32]byte, nonce *[16]byte, key *[32]byte)
TEXT ·hChaCha20AVX(SB), 4, $0-24
MOVQ out+0(FP), DI
MOVQ nonce+8(FP), AX
MOVQ key+16(FP), BX
VMOVDQU ·sigma_AVX<>(SB), X0
VMOVDQU 0(BX), X1
VMOVDQU 16(BX), X2
VMOVDQU 0(AX), X3
VMOVDQU ·rol16_AVX2<>(SB), X5
VMOVDQU ·rol8_AVX2<>(SB), X6
MOVQ $20, CX
chacha_loop:
CHACHA_QROUND(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE(X1, X2, X3)
CHACHA_QROUND(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE(X3, X2, X1)
SUBQ $2, CX
JNZ chacha_loop
VMOVDQU X0, 0(DI)
VMOVDQU X3, 16(DI)
VZEROUPPER
RET
// func supportsAVX2() bool
TEXT ·supportsAVX2(SB), 4, $0-1
MOVQ runtime·support_avx(SB), AX
MOVQ runtime·support_avx2(SB), BX
ANDQ AX, BX
MOVB BX, ret+0(FP)
RET

View file

@ -6,11 +6,16 @@
package chacha
import "encoding/binary"
import (
"encoding/binary"
"golang.org/x/sys/cpu"
)
func init() {
useSSE2 = supportsSSE2()
useSSSE3 = supportsSSSE3()
useSSE2 = cpu.X86.HasSSE2
useSSSE3 = cpu.X86.HasSSSE3
useAVX = false
useAVX2 = false
}
@ -23,14 +28,6 @@ func initialize(state *[64]byte, key []byte, nonce *[16]byte) {
copy(state[48:], nonce[:])
}
// This function is implemented in chacha_386.s
//go:noescape
func supportsSSE2() bool
// This function is implemented in chacha_386.s
//go:noescape
func supportsSSSE3() bool
// This function is implemented in chacha_386.s
//go:noescape
func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
@ -43,25 +40,21 @@ func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte)
//go:noescape
func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
// This function is implemented in chacha_386.s
//go:noescape
func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int
func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) {
if useSSSE3 {
switch {
case useSSSE3:
hChaCha20SSSE3(out, nonce, key)
} else if useSSE2 {
case useSSE2:
hChaCha20SSE2(out, nonce, key)
} else {
default:
hChaCha20Generic(out, nonce, key)
}
}
func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int {
if useSSSE3 {
return xorKeyStreamSSSE3(dst, src, block, state, rounds)
} else if useSSE2 {
if useSSE2 {
return xorKeyStreamSSE2(dst, src, block, state, rounds)
} else {
return xorKeyStreamGeneric(dst, src, block, state, rounds)
}
return xorKeyStreamGeneric(dst, src, block, state, rounds)
}

View file

@ -4,126 +4,125 @@
// +build 386,!gccgo,!appengine,!nacl
#include "textflag.h"
DATA ·sigma<>+0x00(SB)/4, $0x61707865
DATA ·sigma<>+0x04(SB)/4, $0x3320646e
DATA ·sigma<>+0x08(SB)/4, $0x79622d32
DATA ·sigma<>+0x0C(SB)/4, $0x6b206574
GLOBL ·sigma<>(SB), (NOPTR+RODATA), $16
DATA ·one<>+0x00(SB)/8, $1
DATA ·one<>+0x08(SB)/8, $0
GLOBL ·one<>(SB), (NOPTR+RODATA), $16
DATA ·rol16<>+0x00(SB)/8, $0x0504070601000302
DATA ·rol16<>+0x08(SB)/8, $0x0D0C0F0E09080B0A
GLOBL ·rol16<>(SB), (NOPTR+RODATA), $16
DATA ·rol8<>+0x00(SB)/8, $0x0605040702010003
DATA ·rol8<>+0x08(SB)/8, $0x0E0D0C0F0A09080B
GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16
#define ROTL_SSE2(n, t, v) \
MOVO v, t; \
PSLLL $n, t; \
PSRLL $(32-n), v; \
PXOR t, v
#define CHACHA_QROUND_SSE2(v0, v1, v2, v3, t0) \
PADDL v1, v0; \
PXOR v0, v3; \
ROTL_SSE2(16, t0, v3); \
PADDL v3, v2; \
PXOR v2, v1; \
ROTL_SSE2(12, t0, v1); \
PADDL v1, v0; \
PXOR v0, v3; \
ROTL_SSE2(8, t0, v3); \
PADDL v3, v2; \
PXOR v2, v1; \
ROTL_SSE2(7, t0, v1)
#define CHACHA_QROUND_SSSE3(v0, v1, v2, v3, t0, r16, r8) \
PADDL v1, v0; \
PXOR v0, v3; \
PSHUFB r16, v3; \
PADDL v3, v2; \
PXOR v2, v1; \
ROTL_SSE2(12, t0, v1); \
PADDL v1, v0; \
PXOR v0, v3; \
PSHUFB r8, v3; \
PADDL v3, v2; \
PXOR v2, v1; \
ROTL_SSE2(7, t0, v1)
#define CHACHA_SHUFFLE(v1, v2, v3) \
PSHUFL $0x39, v1, v1; \
PSHUFL $0x4E, v2, v2; \
PSHUFL $0x93, v3, v3
#define XOR(dst, src, off, v0, v1, v2, v3, t0) \
MOVOU 0+off(src), t0; \
PXOR v0, t0; \
MOVOU t0, 0+off(dst); \
MOVOU 16+off(src), t0; \
PXOR v1, t0; \
MOVOU t0, 16+off(dst); \
MOVOU 32+off(src), t0; \
PXOR v2, t0; \
MOVOU t0, 32+off(dst); \
MOVOU 48+off(src), t0; \
PXOR v3, t0; \
MOVOU t0, 48+off(dst)
#include "const.s"
#include "macro.s"
// FINALIZE xors len bytes from src and block using
// the temp. registers t0 and t1 and writes the result
// to dst.
#define FINALIZE(dst, src, block, len, t0, t1) \
XORL t0, t0; \
XORL t1, t1; \
finalize: \
MOVB 0(src), t0; \
MOVB 0(block), t1; \
XORL t0, t1; \
MOVB t1, 0(dst); \
INCL src; \
INCL block; \
INCL dst; \
DECL len; \
JA finalize \
XORL t0, t0; \
XORL t1, t1; \
FINALIZE_LOOP:; \
MOVB 0(src), t0; \
MOVB 0(block), t1; \
XORL t0, t1; \
MOVB t1, 0(dst); \
INCL src; \
INCL block; \
INCL dst; \
DECL len; \
JG FINALIZE_LOOP \
#define Dst DI
#define Nonce AX
#define Key BX
#define Rounds DX
// func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
TEXT ·hChaCha20SSE2(SB), 4, $0-12
MOVL out+0(FP), Dst
MOVL nonce+4(FP), Nonce
MOVL key+8(FP), Key
MOVOU ·sigma<>(SB), X0
MOVOU 0*16(Key), X1
MOVOU 1*16(Key), X2
MOVOU 0*16(Nonce), X3
MOVL $20, Rounds
chacha_loop:
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4)
CHACHA_SHUFFLE_SSE(X1, X2, X3)
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4)
CHACHA_SHUFFLE_SSE(X3, X2, X1)
SUBL $2, Rounds
JNZ chacha_loop
MOVOU X0, 0*16(Dst)
MOVOU X3, 1*16(Dst)
RET
// func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte)
TEXT ·hChaCha20SSSE3(SB), 4, $0-12
MOVL out+0(FP), Dst
MOVL nonce+4(FP), Nonce
MOVL key+8(FP), Key
MOVOU ·sigma<>(SB), X0
MOVOU 0*16(Key), X1
MOVOU 1*16(Key), X2
MOVOU 0*16(Nonce), X3
MOVL $20, Rounds
MOVOU ·rol16<>(SB), X5
MOVOU ·rol8<>(SB), X6
chacha_loop:
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE_SSE(X1, X2, X3)
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE_SSE(X3, X2, X1)
SUBL $2, Rounds
JNZ chacha_loop
MOVOU X0, 0*16(Dst)
MOVOU X3, 1*16(Dst)
RET
#undef Dst
#undef Nonce
#undef Key
#undef Rounds
#define State AX
#define Dst DI
#define Src SI
#define Len DX
#define Tmp0 BX
#define Tmp1 BP
// func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
TEXT ·xorKeyStreamSSE2(SB), 4, $0-40
MOVL dst_base+0(FP), DI
MOVL src_base+12(FP), SI
MOVL src_len+16(FP), CX
MOVL state+28(FP), AX
MOVL rounds+32(FP), DX
MOVL dst_base+0(FP), Dst
MOVL src_base+12(FP), Src
MOVL state+28(FP), State
MOVL src_len+16(FP), Len
MOVL $0, ret+36(FP) // Number of bytes written to the keystream buffer - 0 iff len mod 64 == 0
MOVOU 0(AX), X0
MOVOU 16(AX), X1
MOVOU 32(AX), X2
MOVOU 48(AX), X3
MOVOU 0*16(State), X0
MOVOU 1*16(State), X1
MOVOU 2*16(State), X2
MOVOU 3*16(State), X3
TESTL Len, Len
JZ DONE
TESTL CX, CX
JZ done
at_least_64:
GENERATE_KEYSTREAM:
MOVO X0, X4
MOVO X1, X5
MOVO X2, X6
MOVO X3, X7
MOVL rounds+32(FP), Tmp0
MOVL DX, BX
chacha_loop:
CHACHA_LOOP:
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0)
CHACHA_SHUFFLE(X5, X6, X7)
CHACHA_SHUFFLE_SSE(X5, X6, X7)
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0)
CHACHA_SHUFFLE(X7, X6, X5)
SUBL $2, BX
JA chacha_loop
CHACHA_SHUFFLE_SSE(X7, X6, X5)
SUBL $2, Tmp0
JA CHACHA_LOOP
MOVOU 0(AX), X0
MOVOU 0*16(State), X0 // Restore X0 from state
PADDL X0, X4
PADDL X1, X5
PADDL X2, X6
@ -131,181 +130,34 @@ chacha_loop:
MOVOU ·one<>(SB), X0
PADDQ X0, X3
CMPL CX, $64
JB less_than_64
CMPL Len, $64
JL BUFFER_KEYSTREAM
XOR(DI, SI, 0, X4, X5, X6, X7, X0)
MOVOU 0(AX), X0
ADDL $64, SI
ADDL $64, DI
SUBL $64, CX
JNZ at_least_64
XOR_SSE(Dst, Src, 0, X4, X5, X6, X7, X0)
MOVOU 0*16(State), X0 // Restore X0 from state
ADDL $64, Src
ADDL $64, Dst
SUBL $64, Len
JZ DONE
JMP GENERATE_KEYSTREAM // There is at least one more plaintext byte
less_than_64:
MOVL CX, BP
TESTL BP, BP
JZ done
BUFFER_KEYSTREAM:
MOVL block+24(FP), State
MOVOU X4, 0(State)
MOVOU X5, 16(State)
MOVOU X6, 32(State)
MOVOU X7, 48(State)
MOVL Len, ret+36(FP) // Number of bytes written to the keystream buffer - 0 < Len < 64
FINALIZE(Dst, Src, State, Len, Tmp0, Tmp1)
MOVL block+24(FP), BX
MOVOU X4, 0(BX)
MOVOU X5, 16(BX)
MOVOU X6, 32(BX)
MOVOU X7, 48(BX)
FINALIZE(DI, SI, BX, BP, AX, DX)
done:
MOVL state+28(FP), AX
MOVOU X3, 48(AX)
MOVL CX, ret+36(FP)
DONE:
MOVL state+28(FP), State
MOVOU X3, 3*16(State)
RET
// func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int
TEXT ·xorKeyStreamSSSE3(SB), 4, $64-40
MOVL dst_base+0(FP), DI
MOVL src_base+12(FP), SI
MOVL src_len+16(FP), CX
MOVL state+28(FP), AX
MOVL rounds+32(FP), DX
MOVOU 48(AX), X3
TESTL CX, CX
JZ done
MOVL SP, BP
ADDL $16, SP
ANDL $-16, SP
MOVOU ·one<>(SB), X0
MOVOU 16(AX), X1
MOVOU 32(AX), X2
MOVO X0, 0(SP)
MOVO X1, 16(SP)
MOVO X2, 32(SP)
MOVOU 0(AX), X0
MOVOU ·rol16<>(SB), X1
MOVOU ·rol8<>(SB), X2
at_least_64:
MOVO X0, X4
MOVO 16(SP), X5
MOVO 32(SP), X6
MOVO X3, X7
MOVL DX, BX
chacha_loop:
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2)
CHACHA_SHUFFLE(X5, X6, X7)
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2)
CHACHA_SHUFFLE(X7, X6, X5)
SUBL $2, BX
JA chacha_loop
MOVOU 0(AX), X0
PADDL X0, X4
PADDL 16(SP), X5
PADDL 32(SP), X6
PADDL X3, X7
PADDQ 0(SP), X3
CMPL CX, $64
JB less_than_64
XOR(DI, SI, 0, X4, X5, X6, X7, X0)
MOVOU 0(AX), X0
ADDL $64, SI
ADDL $64, DI
SUBL $64, CX
JNZ at_least_64
less_than_64:
MOVL BP, SP
MOVL CX, BP
TESTL BP, BP
JE done
MOVL block+24(FP), BX
MOVOU X4, 0(BX)
MOVOU X5, 16(BX)
MOVOU X6, 32(BX)
MOVOU X7, 48(BX)
FINALIZE(DI, SI, BX, BP, AX, DX)
done:
MOVL state+28(FP), AX
MOVOU X3, 48(AX)
MOVL CX, ret+36(FP)
RET
// func supportsSSE2() bool
TEXT ·supportsSSE2(SB), NOSPLIT, $0-1
XORL AX, AX
INCL AX
CPUID
SHRL $26, DX
ANDL $1, DX
MOVB DX, ret+0(FP)
RET
// func supportsSSSE3() bool
TEXT ·supportsSSSE3(SB), NOSPLIT, $0-1
XORL AX, AX
INCL AX
CPUID
SHRL $9, CX
ANDL $1, CX
MOVB CX, ret+0(FP)
RET
// func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
TEXT ·hChaCha20SSE2(SB), 4, $0-12
MOVL out+0(FP), DI
MOVL nonce+4(FP), AX
MOVL key+8(FP), BX
MOVOU ·sigma<>(SB), X0
MOVOU 0(BX), X1
MOVOU 16(BX), X2
MOVOU 0(AX), X3
MOVL $20, CX
chacha_loop:
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4)
CHACHA_SHUFFLE(X1, X2, X3)
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4)
CHACHA_SHUFFLE(X3, X2, X1)
SUBL $2, CX
JNZ chacha_loop
MOVOU X0, 0(DI)
MOVOU X3, 16(DI)
RET
// func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte)
TEXT ·hChaCha20SSSE3(SB), 4, $0-12
MOVL out+0(FP), DI
MOVL nonce+4(FP), AX
MOVL key+8(FP), BX
MOVOU ·sigma<>(SB), X0
MOVOU 0(BX), X1
MOVOU 16(BX), X2
MOVOU 0(AX), X3
MOVOU ·rol16<>(SB), X5
MOVOU ·rol8<>(SB), X6
MOVL $20, CX
chacha_loop:
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE(X1, X2, X3)
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE(X3, X2, X1)
SUBL $2, CX
JNZ chacha_loop
MOVOU X0, 0(DI)
MOVOU X3, 16(DI)
RET
#undef State
#undef Dst
#undef Src
#undef Len
#undef Tmp0
#undef Tmp1

View file

@ -6,24 +6,19 @@
package chacha
import "golang.org/x/sys/cpu"
func init() {
useSSE2 = true
useSSSE3 = supportsSSSE3()
useAVX2 = supportsAVX2()
useSSE2 = cpu.X86.HasSSE2
useSSSE3 = cpu.X86.HasSSSE3
useAVX = cpu.X86.HasAVX
useAVX2 = cpu.X86.HasAVX2
}
// This function is implemented in chacha_amd64.s
//go:noescape
func initialize(state *[64]byte, key []byte, nonce *[16]byte)
// This function is implemented in chacha_amd64.s
//go:noescape
func supportsSSSE3() bool
// This function is implemented in chachaAVX2_amd64.s
//go:noescape
func supportsAVX2() bool
// This function is implemented in chacha_amd64.s
//go:noescape
func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
@ -44,29 +39,38 @@ func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
//go:noescape
func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int
// This function is implemented in chacha_amd64.s
//go:noescape
func xorKeyStreamAVX(dst, src []byte, block, state *[64]byte, rounds int) int
// This function is implemented in chachaAVX2_amd64.s
//go:noescape
func xorKeyStreamAVX2(dst, src []byte, block, state *[64]byte, rounds int) int
func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) {
if useAVX2 {
switch {
case useAVX:
hChaCha20AVX(out, nonce, key)
} else if useSSSE3 {
case useSSSE3:
hChaCha20SSSE3(out, nonce, key)
} else if useSSE2 { // on amd64 this is always true - neccessary for testing generic on amd64
case useSSE2:
hChaCha20SSE2(out, nonce, key)
} else {
default:
hChaCha20Generic(out, nonce, key)
}
}
func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int {
if useAVX2 {
switch {
case useAVX2:
return xorKeyStreamAVX2(dst, src, block, state, rounds)
} else if useSSSE3 {
case useAVX:
return xorKeyStreamAVX(dst, src, block, state, rounds)
case useSSSE3:
return xorKeyStreamSSSE3(dst, src, block, state, rounds)
} else if useSSE2 { // on amd64 this is always true - neccessary for testing generic on amd64
case useSSE2:
return xorKeyStreamSSE2(dst, src, block, state, rounds)
default:
return xorKeyStreamGeneric(dst, src, block, state, rounds)
}
return xorKeyStreamGeneric(dst, src, block, state, rounds)
}

File diff suppressed because it is too large Load diff

View file

@ -1,56 +0,0 @@
// Copyright (c) 2017 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build amd64,!gccgo,!appengine,!nacl,!go1.7
package chacha
func init() {
useSSE2 = true
useSSSE3 = supportsSSSE3()
useAVX2 = false
}
// This function is implemented in chacha_amd64.s
//go:noescape
func initialize(state *[64]byte, key []byte, nonce *[16]byte)
// This function is implemented in chacha_amd64.s
//go:noescape
func supportsSSSE3() bool
// This function is implemented in chacha_amd64.s
//go:noescape
func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
// This function is implemented in chacha_amd64.s
//go:noescape
func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte)
// This function is implemented in chacha_amd64.s
//go:noescape
func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
// This function is implemented in chacha_amd64.s
//go:noescape
func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int
func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) {
if useSSSE3 {
hChaCha20SSSE3(out, nonce, key)
} else if useSSE2 { // on amd64 this is always true - used to test generic on amd64
hChaCha20SSE2(out, nonce, key)
} else {
hChaCha20Generic(out, nonce, key)
}
}
func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int {
if useSSSE3 {
return xorKeyStreamSSSE3(dst, src, block, state, rounds)
} else if useSSE2 { // on amd64 this is always true - used to test generic on amd64
return xorKeyStreamSSE2(dst, src, block, state, rounds)
}
return xorKeyStreamGeneric(dst, src, block, state, rounds)
}

View file

@ -8,6 +8,13 @@ package chacha
import "encoding/binary"
func init() {
useSSE2 = false
useSSSE3 = false
useAVX = false
useAVX2 = false
}
func initialize(state *[64]byte, key []byte, nonce *[16]byte) {
binary.LittleEndian.PutUint32(state[0:], sigma[0])
binary.LittleEndian.PutUint32(state[4:], sigma[1])

53
vendor/github.com/aead/chacha20/chacha/const.s generated vendored Normal file
View file

@ -0,0 +1,53 @@
// Copyright (c) 2018 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build 386,!gccgo,!appengine,!nacl amd64,!gccgo,!appengine,!nacl
#include "textflag.h"
DATA ·sigma<>+0x00(SB)/4, $0x61707865
DATA ·sigma<>+0x04(SB)/4, $0x3320646e
DATA ·sigma<>+0x08(SB)/4, $0x79622d32
DATA ·sigma<>+0x0C(SB)/4, $0x6b206574
GLOBL ·sigma<>(SB), (NOPTR+RODATA), $16 // The 4 ChaCha initialization constants
// SSE2/SSE3/AVX constants
DATA ·one<>+0x00(SB)/8, $1
DATA ·one<>+0x08(SB)/8, $0
GLOBL ·one<>(SB), (NOPTR+RODATA), $16 // The constant 1 as 128 bit value
DATA ·rol16<>+0x00(SB)/8, $0x0504070601000302
DATA ·rol16<>+0x08(SB)/8, $0x0D0C0F0E09080B0A
GLOBL ·rol16<>(SB), (NOPTR+RODATA), $16 // The PSHUFB 16 bit left rotate constant
DATA ·rol8<>+0x00(SB)/8, $0x0605040702010003
DATA ·rol8<>+0x08(SB)/8, $0x0E0D0C0F0A09080B
GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16 // The PSHUFB 8 bit left rotate constant
// AVX2 constants
DATA ·one_AVX2<>+0x00(SB)/8, $0
DATA ·one_AVX2<>+0x08(SB)/8, $0
DATA ·one_AVX2<>+0x10(SB)/8, $1
DATA ·one_AVX2<>+0x18(SB)/8, $0
GLOBL ·one_AVX2<>(SB), (NOPTR+RODATA), $32 // The constant 1 as 256 bit value
DATA ·two_AVX2<>+0x00(SB)/8, $2
DATA ·two_AVX2<>+0x08(SB)/8, $0
DATA ·two_AVX2<>+0x10(SB)/8, $2
DATA ·two_AVX2<>+0x18(SB)/8, $0
GLOBL ·two_AVX2<>(SB), (NOPTR+RODATA), $32
DATA ·rol16_AVX2<>+0x00(SB)/8, $0x0504070601000302
DATA ·rol16_AVX2<>+0x08(SB)/8, $0x0D0C0F0E09080B0A
DATA ·rol16_AVX2<>+0x10(SB)/8, $0x0504070601000302
DATA ·rol16_AVX2<>+0x18(SB)/8, $0x0D0C0F0E09080B0A
GLOBL ·rol16_AVX2<>(SB), (NOPTR+RODATA), $32 // The VPSHUFB 16 bit left rotate constant
DATA ·rol8_AVX2<>+0x00(SB)/8, $0x0605040702010003
DATA ·rol8_AVX2<>+0x08(SB)/8, $0x0E0D0C0F0A09080B
DATA ·rol8_AVX2<>+0x10(SB)/8, $0x0605040702010003
DATA ·rol8_AVX2<>+0x18(SB)/8, $0x0E0D0C0F0A09080B
GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32 // The VPSHUFB 8 bit left rotate constant

163
vendor/github.com/aead/chacha20/chacha/macro.s generated vendored Normal file
View file

@ -0,0 +1,163 @@
// Copyright (c) 2018 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build 386,!gccgo,!appengine,!nacl amd64,!gccgo,!appengine,!nacl
// ROTL_SSE rotates all 4 32 bit values of the XMM register v
// left by n bits using SSE2 instructions (0 <= n <= 32).
// The XMM register t is used as a temp. register.
#define ROTL_SSE(n, t, v) \
MOVO v, t; \
PSLLL $n, t; \
PSRLL $(32-n), v; \
PXOR t, v
// ROTL_AVX rotates all 4/8 32 bit values of the AVX/AVX2 register v
// left by n bits using AVX/AVX2 instructions (0 <= n <= 32).
// The AVX/AVX2 register t is used as a temp. register.
#define ROTL_AVX(n, t, v) \
VPSLLD $n, v, t; \
VPSRLD $(32-n), v, v; \
VPXOR v, t, v
// CHACHA_QROUND_SSE2 performs a ChaCha quarter-round using the
// 4 XMM registers v0, v1, v2 and v3. It uses only ROTL_SSE2 for
// rotations. The XMM register t is used as a temp. register.
#define CHACHA_QROUND_SSE2(v0, v1, v2, v3, t) \
PADDL v1, v0; \
PXOR v0, v3; \
ROTL_SSE(16, t, v3); \
PADDL v3, v2; \
PXOR v2, v1; \
ROTL_SSE(12, t, v1); \
PADDL v1, v0; \
PXOR v0, v3; \
ROTL_SSE(8, t, v3); \
PADDL v3, v2; \
PXOR v2, v1; \
ROTL_SSE(7, t, v1)
// CHACHA_QROUND_SSSE3 performs a ChaCha quarter-round using the
// 4 XMM registers v0, v1, v2 and v3. It uses PSHUFB for 8/16 bit
// rotations. The XMM register t is used as a temp. register.
//
// r16 holds the PSHUFB constant for a 16 bit left rotate.
// r8 holds the PSHUFB constant for a 8 bit left rotate.
#define CHACHA_QROUND_SSSE3(v0, v1, v2, v3, t, r16, r8) \
PADDL v1, v0; \
PXOR v0, v3; \
PSHUFB r16, v3; \
PADDL v3, v2; \
PXOR v2, v1; \
ROTL_SSE(12, t, v1); \
PADDL v1, v0; \
PXOR v0, v3; \
PSHUFB r8, v3; \
PADDL v3, v2; \
PXOR v2, v1; \
ROTL_SSE(7, t, v1)
// CHACHA_QROUND_AVX performs a ChaCha quarter-round using the
// 4 AVX/AVX2 registers v0, v1, v2 and v3. It uses VPSHUFB for 8/16 bit
// rotations. The AVX/AVX2 register t is used as a temp. register.
//
// r16 holds the VPSHUFB constant for a 16 bit left rotate.
// r8 holds the VPSHUFB constant for a 8 bit left rotate.
#define CHACHA_QROUND_AVX(v0, v1, v2, v3, t, r16, r8) \
VPADDD v0, v1, v0; \
VPXOR v3, v0, v3; \
VPSHUFB r16, v3, v3; \
VPADDD v2, v3, v2; \
VPXOR v1, v2, v1; \
ROTL_AVX(12, t, v1); \
VPADDD v0, v1, v0; \
VPXOR v3, v0, v3; \
VPSHUFB r8, v3, v3; \
VPADDD v2, v3, v2; \
VPXOR v1, v2, v1; \
ROTL_AVX(7, t, v1)
// CHACHA_SHUFFLE_SSE performs a ChaCha shuffle using the
// 3 XMM registers v1, v2 and v3. The inverse shuffle is
// performed by switching v1 and v3: CHACHA_SHUFFLE_SSE(v3, v2, v1).
#define CHACHA_SHUFFLE_SSE(v1, v2, v3) \
PSHUFL $0x39, v1, v1; \
PSHUFL $0x4E, v2, v2; \
PSHUFL $0x93, v3, v3
// CHACHA_SHUFFLE_AVX performs a ChaCha shuffle using the
// 3 AVX/AVX2 registers v1, v2 and v3. The inverse shuffle is
// performed by switching v1 and v3: CHACHA_SHUFFLE_AVX(v3, v2, v1).
#define CHACHA_SHUFFLE_AVX(v1, v2, v3) \
VPSHUFD $0x39, v1, v1; \
VPSHUFD $0x4E, v2, v2; \
VPSHUFD $0x93, v3, v3
// XOR_SSE extracts 4x16 byte vectors from src at
// off, xors all vectors with the corresponding XMM
// register (v0 - v3) and writes the result to dst
// at off.
// The XMM register t is used as a temp. register.
#define XOR_SSE(dst, src, off, v0, v1, v2, v3, t) \
MOVOU 0+off(src), t; \
PXOR v0, t; \
MOVOU t, 0+off(dst); \
MOVOU 16+off(src), t; \
PXOR v1, t; \
MOVOU t, 16+off(dst); \
MOVOU 32+off(src), t; \
PXOR v2, t; \
MOVOU t, 32+off(dst); \
MOVOU 48+off(src), t; \
PXOR v3, t; \
MOVOU t, 48+off(dst)
// XOR_AVX extracts 4x16 byte vectors from src at
// off, xors all vectors with the corresponding AVX
// register (v0 - v3) and writes the result to dst
// at off.
// The XMM register t is used as a temp. register.
#define XOR_AVX(dst, src, off, v0, v1, v2, v3, t) \
VPXOR 0+off(src), v0, t; \
VMOVDQU t, 0+off(dst); \
VPXOR 16+off(src), v1, t; \
VMOVDQU t, 16+off(dst); \
VPXOR 32+off(src), v2, t; \
VMOVDQU t, 32+off(dst); \
VPXOR 48+off(src), v3, t; \
VMOVDQU t, 48+off(dst)
#define XOR_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \
VMOVDQU (0+off)(src), t0; \
VPERM2I128 $32, v1, v0, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (0+off)(dst); \
VMOVDQU (32+off)(src), t0; \
VPERM2I128 $32, v3, v2, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (32+off)(dst); \
VMOVDQU (64+off)(src), t0; \
VPERM2I128 $49, v1, v0, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (64+off)(dst); \
VMOVDQU (96+off)(src), t0; \
VPERM2I128 $49, v3, v2, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (96+off)(dst)
#define XOR_UPPER_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \
VMOVDQU (0+off)(src), t0; \
VPERM2I128 $32, v1, v0, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (0+off)(dst); \
VMOVDQU (32+off)(src), t0; \
VPERM2I128 $32, v3, v2, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (32+off)(dst); \
#define EXTRACT_LOWER(dst, v0, v1, v2, v3, t0) \
VPERM2I128 $49, v1, v0, t0; \
VMOVDQU t0, 0(dst); \
VPERM2I128 $49, v3, v2, t0; \
VMOVDQU t0, 32(dst)

View file

@ -1,21 +0,0 @@
The MIT License (MIT)
Copyright (c) 2016 Richard Barnes
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View file

@ -1,99 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mint
import "strconv"
type Alert uint8
const (
// alert level
AlertLevelWarning = 1
AlertLevelError = 2
)
const (
AlertCloseNotify Alert = 0
AlertUnexpectedMessage Alert = 10
AlertBadRecordMAC Alert = 20
AlertDecryptionFailed Alert = 21
AlertRecordOverflow Alert = 22
AlertDecompressionFailure Alert = 30
AlertHandshakeFailure Alert = 40
AlertBadCertificate Alert = 42
AlertUnsupportedCertificate Alert = 43
AlertCertificateRevoked Alert = 44
AlertCertificateExpired Alert = 45
AlertCertificateUnknown Alert = 46
AlertIllegalParameter Alert = 47
AlertUnknownCA Alert = 48
AlertAccessDenied Alert = 49
AlertDecodeError Alert = 50
AlertDecryptError Alert = 51
AlertProtocolVersion Alert = 70
AlertInsufficientSecurity Alert = 71
AlertInternalError Alert = 80
AlertInappropriateFallback Alert = 86
AlertUserCanceled Alert = 90
AlertNoRenegotiation Alert = 100
AlertMissingExtension Alert = 109
AlertUnsupportedExtension Alert = 110
AlertCertificateUnobtainable Alert = 111
AlertUnrecognizedName Alert = 112
AlertBadCertificateStatsResponse Alert = 113
AlertBadCertificateHashValue Alert = 114
AlertUnknownPSKIdentity Alert = 115
AlertNoApplicationProtocol Alert = 120
AlertWouldBlock Alert = 254
AlertNoAlert Alert = 255
)
var alertText = map[Alert]string{
AlertCloseNotify: "close notify",
AlertUnexpectedMessage: "unexpected message",
AlertBadRecordMAC: "bad record MAC",
AlertDecryptionFailed: "decryption failed",
AlertRecordOverflow: "record overflow",
AlertDecompressionFailure: "decompression failure",
AlertHandshakeFailure: "handshake failure",
AlertBadCertificate: "bad certificate",
AlertUnsupportedCertificate: "unsupported certificate",
AlertCertificateRevoked: "revoked certificate",
AlertCertificateExpired: "expired certificate",
AlertCertificateUnknown: "unknown certificate",
AlertIllegalParameter: "illegal parameter",
AlertUnknownCA: "unknown certificate authority",
AlertAccessDenied: "access denied",
AlertDecodeError: "error decoding message",
AlertDecryptError: "error decrypting message",
AlertProtocolVersion: "protocol version not supported",
AlertInsufficientSecurity: "insufficient security level",
AlertInternalError: "internal error",
AlertInappropriateFallback: "inappropriate fallback",
AlertUserCanceled: "user canceled",
AlertMissingExtension: "missing extension",
AlertUnsupportedExtension: "unsupported extension",
AlertCertificateUnobtainable: "certificate unobtainable",
AlertUnrecognizedName: "unrecognized name",
AlertBadCertificateStatsResponse: "bad certificate status response",
AlertBadCertificateHashValue: "bad certificate hash value",
AlertUnknownPSKIdentity: "unknown PSK identity",
AlertNoApplicationProtocol: "no application protocol",
AlertNoRenegotiation: "no renegotiation",
AlertWouldBlock: "would have blocked",
AlertNoAlert: "no alert",
}
func (e Alert) String() string {
s, ok := alertText[e]
if ok {
return s
}
return "alert(" + strconv.Itoa(int(e)) + ")"
}
func (e Alert) Error() string {
return e.String()
}

View file

@ -1,42 +0,0 @@
package main
import (
"flag"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"github.com/bifurcation/mint"
)
var url string
func main() {
url := flag.String("url", "https://localhost:4430", "URL to send request")
flag.Parse()
mintdial := func(network, addr string) (net.Conn, error) {
return mint.Dial(network, addr, nil)
}
tr := &http.Transport{
DialTLS: mintdial,
DisableCompression: true,
}
client := &http.Client{Transport: tr}
response, err := client.Get(*url)
if err != nil {
fmt.Println("err:", err)
return
}
defer response.Body.Close()
contents, err := ioutil.ReadAll(response.Body)
if err != nil {
fmt.Printf("%s", err)
os.Exit(1)
}
fmt.Printf("%s\n", string(contents))
}

View file

@ -1,37 +0,0 @@
package main
import (
"flag"
"fmt"
"github.com/bifurcation/mint"
)
var addr string
func main() {
flag.StringVar(&addr, "addr", "localhost:4430", "port")
flag.Parse()
conn, err := mint.Dial("tcp", addr, nil)
if err != nil {
fmt.Println("TLS handshake failed:", err)
return
}
request := "GET / HTTP/1.0\r\n\r\n"
conn.Write([]byte(request))
response := ""
buffer := make([]byte, 1024)
var read int
for err == nil {
read, err = conn.Read(buffer)
fmt.Println(" ~~ read: ", read)
response += string(buffer)
}
fmt.Println("err:", err)
fmt.Println("Received from server:")
fmt.Println(response)
}

View file

@ -1,226 +0,0 @@
package main
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"flag"
"fmt"
"io/ioutil"
"log"
"net/http"
"github.com/bifurcation/mint"
"golang.org/x/net/http2"
)
var (
port string
serverName string
certFile string
keyFile string
responseFile string
h2 bool
sendTickets bool
)
type responder []byte
func (rsp responder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Write(rsp)
}
// ParsePrivateKeyDER parses a PKCS #1, PKCS #8, or elliptic curve
// PEM-encoded private key.
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
func ParsePrivateKeyPEM(keyPEM []byte) (key crypto.Signer, err error) {
keyDER, _ := pem.Decode(keyPEM)
if keyDER == nil {
return nil, err
}
generalKey, err := x509.ParsePKCS8PrivateKey(keyDER.Bytes)
if err != nil {
generalKey, err = x509.ParsePKCS1PrivateKey(keyDER.Bytes)
if err != nil {
generalKey, err = x509.ParseECPrivateKey(keyDER.Bytes)
if err != nil {
// We don't include the actual error into
// the final error. The reason might be
// we don't want to leak any info about
// the private key.
return nil, fmt.Errorf("No successful private key decoder")
}
}
}
switch generalKey.(type) {
case *rsa.PrivateKey:
return generalKey.(*rsa.PrivateKey), nil
case *ecdsa.PrivateKey:
return generalKey.(*ecdsa.PrivateKey), nil
}
// should never reach here
return nil, fmt.Errorf("Should be unreachable")
}
// ParseOneCertificateFromPEM attempts to parse one PEM encoded certificate object,
// either a raw x509 certificate or a PKCS #7 structure possibly containing
// multiple certificates, from the top of certsPEM, which itself may
// contain multiple PEM encoded certificate objects.
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, error) {
block, rest := pem.Decode(certsPEM)
if block == nil {
return nil, rest, nil
}
cert, err := x509.ParseCertificate(block.Bytes)
var certs = []*x509.Certificate{cert}
return certs, rest, err
}
// ParseCertificatesPEM parses a sequence of PEM-encoded certificate and returns them,
// can handle PEM encoded PKCS #7 structures.
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
func ParseCertificatesPEM(certsPEM []byte) ([]*x509.Certificate, error) {
var certs []*x509.Certificate
var err error
certsPEM = bytes.TrimSpace(certsPEM)
for len(certsPEM) > 0 {
var cert []*x509.Certificate
cert, certsPEM, err = ParseOneCertificateFromPEM(certsPEM)
if err != nil {
return nil, err
} else if cert == nil {
break
}
certs = append(certs, cert...)
}
if len(certsPEM) > 0 {
return nil, fmt.Errorf("Trailing PEM data")
}
return certs, nil
}
func main() {
flag.StringVar(&port, "port", "4430", "port")
flag.StringVar(&serverName, "host", "example.com", "hostname")
flag.StringVar(&certFile, "cert", "", "certificate chain in PEM or DER")
flag.StringVar(&keyFile, "key", "", "private key in PEM format")
flag.StringVar(&responseFile, "response", "", "file to serve")
flag.BoolVar(&h2, "h2", false, "whether to use HTTP/2 (exclusively)")
flag.BoolVar(&sendTickets, "tickets", true, "whether to send session tickets")
flag.Parse()
var certChain []*x509.Certificate
var priv crypto.Signer
var response []byte
var err error
// Load the key and certificate chain
if certFile != "" {
certs, err := ioutil.ReadFile(certFile)
if err != nil {
log.Fatalf("Error: %v", err)
} else {
certChain, err = ParseCertificatesPEM(certs)
if err != nil {
certChain, err = x509.ParseCertificates(certs)
if err != nil {
log.Fatalf("Error parsing certificates: %v", err)
}
}
}
}
if keyFile != "" {
keyPEM, err := ioutil.ReadFile(keyFile)
if err != nil {
log.Fatalf("Error: %v", err)
} else {
priv, err = ParsePrivateKeyPEM(keyPEM)
if priv == nil || err != nil {
log.Fatalf("Error parsing private key: %v", err)
}
}
}
if err != nil {
log.Fatalf("Error: %v", err)
}
// Load response file
if responseFile != "" {
log.Printf("Loading response file: %v", responseFile)
response, err = ioutil.ReadFile(responseFile)
if err != nil {
log.Fatalf("Error: %v", err)
}
} else {
response = []byte("Welcome to the TLS 1.3 zone!")
}
handler := responder(response)
config := mint.Config{
SendSessionTickets: true,
ServerName: serverName,
NextProtos: []string{"http/1.1"},
}
if h2 {
config.NextProtos = []string{"h2"}
}
config.SendSessionTickets = sendTickets
if certChain != nil && priv != nil {
log.Printf("Loading cert: %v key: %v", certFile, keyFile)
config.Certificates = []*mint.Certificate{
{
Chain: certChain,
PrivateKey: priv,
},
}
}
config.Init(false)
service := "0.0.0.0:" + port
srv := &http.Server{Handler: handler}
log.Printf("Listening on port %v", port)
// Need the inner loop here because the h1 server errors on a dropped connection
// Need the outer loop here because the h2 server is per-connection
for {
listener, err := mint.Listen("tcp", service, &config)
if err != nil {
log.Printf("Listen Error: %v", err)
continue
}
if !h2 {
alert := srv.Serve(listener)
if alert != mint.AlertNoAlert {
log.Printf("Serve Error: %v", err)
}
} else {
srv2 := new(http2.Server)
opts := &http2.ServeConnOpts{
Handler: handler,
BaseConfig: srv,
}
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("Accept error: %v", err)
continue
}
go srv2.ServeConn(conn, opts)
}
}
}
}

View file

@ -1,65 +0,0 @@
package main
import (
"flag"
"log"
"net"
"github.com/bifurcation/mint"
)
var port string
func main() {
var config mint.Config
config.SendSessionTickets = true
config.ServerName = "localhost"
config.Init(false)
flag.StringVar(&port, "port", "4430", "port")
flag.Parse()
service := "0.0.0.0:" + port
listener, err := mint.Listen("tcp", service, &config)
if err != nil {
log.Fatalf("server: listen: %s", err)
}
log.Print("server: listening")
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("server: accept: %s", err)
break
}
defer conn.Close()
log.Printf("server: accepted from %s", conn.RemoteAddr())
go handleClient(conn)
}
}
func handleClient(conn net.Conn) {
defer conn.Close()
buf := make([]byte, 10)
for {
log.Print("server: conn: waiting")
n, err := conn.Read(buf)
if err != nil {
if err != nil {
log.Printf("server: conn: read: %s", err)
}
break
}
n, err = conn.Write([]byte("hello world"))
log.Printf("server: conn: wrote %d bytes", n)
if err != nil {
log.Printf("server: write: %s", err)
break
}
break
}
log.Println("server: conn: closed")
}

View file

@ -1,942 +0,0 @@
package mint
import (
"bytes"
"crypto"
"hash"
"time"
)
// Client State Machine
//
// START <----+
// Send ClientHello | | Recv HelloRetryRequest
// / v |
// | WAIT_SH ---+
// Can | | Recv ServerHello
// send | V
// early | WAIT_EE
// data | | Recv EncryptedExtensions
// | +--------+--------+
// | Using | | Using certificate
// | PSK | v
// | | WAIT_CERT_CR
// | | Recv | | Recv CertificateRequest
// | | Certificate | v
// | | | WAIT_CERT
// | | | | Recv Certificate
// | | v v
// | | WAIT_CV
// | | | Recv CertificateVerify
// | +> WAIT_FINISHED <+
// | | Recv Finished
// \ |
// | [Send EndOfEarlyData]
// | [Send Certificate [+ CertificateVerify]]
// | Send Finished
// Can send v
// app data --> CONNECTED
// after
// here
//
// State Instructions
// START Send(CH); [RekeyOut; SendEarlyData]
// WAIT_SH Send(CH) || RekeyIn
// WAIT_EE {}
// WAIT_CERT_CR {}
// WAIT_CERT {}
// WAIT_CV {}
// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut;
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
type ClientStateStart struct {
Caps Capabilities
Opts ConnectionOptions
Params ConnectionParameters
cookie []byte
firstClientHello *HandshakeMessage
helloRetryRequest *HandshakeMessage
}
func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm != nil {
logf(logTypeHandshake, "[ClientStateStart] Unexpected non-nil message")
return nil, nil, AlertUnexpectedMessage
}
// key_shares
offeredDH := map[NamedGroup][]byte{}
ks := KeyShareExtension{
HandshakeType: HandshakeTypeClientHello,
Shares: make([]KeyShareEntry, len(state.Caps.Groups)),
}
for i, group := range state.Caps.Groups {
pub, priv, err := newKeyShare(group)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err)
return nil, nil, AlertInternalError
}
ks.Shares[i].Group = group
ks.Shares[i].KeyExchange = pub
offeredDH[group] = priv
}
logf(logTypeHandshake, "opts: %+v", state.Opts)
// supported_versions, supported_groups, signature_algorithms, server_name
sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}}
sni := ServerNameExtension(state.Opts.ServerName)
sg := SupportedGroupsExtension{Groups: state.Caps.Groups}
sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
state.Params.ServerName = state.Opts.ServerName
// Application Layer Protocol Negotiation
var alpn *ALPNExtension
if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) {
alpn = &ALPNExtension{Protocols: state.Opts.NextProtos}
}
// Construct base ClientHello
ch := &ClientHelloBody{
CipherSuites: state.Caps.CipherSuites,
}
_, err := prng.Read(ch.Random[:])
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error creating ClientHello random [%v]", err)
return nil, nil, AlertInternalError
}
for _, ext := range []ExtensionBody{&sv, &sni, &ks, &sg, &sa} {
err := ch.Extensions.Add(ext)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error adding extension type=[%v] [%v]", ext.Type(), err)
return nil, nil, AlertInternalError
}
}
// XXX: These optional extensions can't be folded into the above because Go
// interface-typed values are never reported as nil
if alpn != nil {
err := ch.Extensions.Add(alpn)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
return nil, nil, AlertInternalError
}
}
if state.cookie != nil {
err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie})
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
return nil, nil, AlertInternalError
}
}
// Run the external extension handler.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err)
return nil, nil, AlertInternalError
}
}
// Handle PSK and EarlyData just before transmitting, so that we can
// calculate the PSK binder value
var psk *PreSharedKeyExtension
var ed *EarlyDataExtension
var offeredPSK PreSharedKey
var earlyHash crypto.Hash
var earlySecret []byte
var clientEarlyTrafficKeys keySet
var clientHello *HandshakeMessage
if key, ok := state.Caps.PSKs.Get(state.Opts.ServerName); ok {
offeredPSK = key
// Narrow ciphersuites to ones that match PSK hash
params, ok := cipherSuiteMap[key.CipherSuite]
if !ok {
logf(logTypeHandshake, "[ClientStateStart] PSK for unknown ciphersuite")
return nil, nil, AlertInternalError
}
compatibleSuites := []CipherSuite{}
for _, suite := range ch.CipherSuites {
if cipherSuiteMap[suite].Hash == params.Hash {
compatibleSuites = append(compatibleSuites, suite)
}
}
ch.CipherSuites = compatibleSuites
// Signal early data if we're going to do it
if len(state.Opts.EarlyData) > 0 {
state.Params.ClientSendingEarlyData = true
ed = &EarlyDataExtension{}
err = ch.Extensions.Add(ed)
if err != nil {
logf(logTypeHandshake, "Error adding early data extension: %v", err)
return nil, nil, AlertInternalError
}
}
// Signal supported PSK key exchange modes
if len(state.Caps.PSKModes) == 0 {
logf(logTypeHandshake, "PSK selected, but no PSKModes")
return nil, nil, AlertInternalError
}
kem := &PSKKeyExchangeModesExtension{KEModes: state.Caps.PSKModes}
err = ch.Extensions.Add(kem)
if err != nil {
logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err)
return nil, nil, AlertInternalError
}
// Add the shim PSK extension to the ClientHello
logf(logTypeHandshake, "Adding PSK extension with id = %x", key.Identity)
psk = &PreSharedKeyExtension{
HandshakeType: HandshakeTypeClientHello,
Identities: []PSKIdentity{
{
Identity: key.Identity,
ObfuscatedTicketAge: uint32(time.Since(key.ReceivedAt)/time.Millisecond) + key.TicketAgeAdd,
},
},
Binders: []PSKBinderEntry{
// Note: Stub to get the length fields right
{Binder: bytes.Repeat([]byte{0x00}, params.Hash.Size())},
},
}
ch.Extensions.Add(psk)
// Compute the binder key
h0 := params.Hash.New().Sum(nil)
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
earlyHash = params.Hash
earlySecret = HkdfExtract(params.Hash, zero, key.Key)
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
binderLabel := labelExternalBinder
if key.IsResumption {
binderLabel = labelResumptionBinder
}
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
logf(logTypeCrypto, "binder key: [%d] %x", len(binderKey), binderKey)
// Compute the binder value
trunc, err := ch.Truncated()
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error marshaling truncated ClientHello [%v]", err)
return nil, nil, AlertInternalError
}
truncHash := params.Hash.New()
truncHash.Write(trunc)
binder := computeFinishedData(params, binderKey, truncHash.Sum(nil))
// Replace the PSK extension
psk.Binders[0].Binder = binder
ch.Extensions.Add(psk)
// If we got here, the earlier marshal succeeded (in ch.Truncated()), so
// this one should too.
clientHello, _ = HandshakeMessageFromBody(ch)
// Compute early traffic keys
h := params.Hash.New()
h.Write(clientHello.Marshal())
chHash := h.Sum(nil)
earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret)
clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret)
} else if len(state.Opts.EarlyData) > 0 {
logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK")
return nil, nil, AlertInternalError
} else {
clientHello, err = HandshakeMessageFromBody(ch)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err)
return nil, nil, AlertInternalError
}
}
logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]")
nextState := ClientStateWaitSH{
Caps: state.Caps,
Opts: state.Opts,
Params: state.Params,
OfferedDH: offeredDH,
OfferedPSK: offeredPSK,
earlySecret: earlySecret,
earlyHash: earlyHash,
firstClientHello: state.firstClientHello,
helloRetryRequest: state.helloRetryRequest,
clientHello: clientHello,
}
toSend := []HandshakeAction{
SendHandshakeMessage{clientHello},
}
if state.Params.ClientSendingEarlyData {
toSend = append(toSend, []HandshakeAction{
RekeyOut{Label: "early", KeySet: clientEarlyTrafficKeys},
SendEarlyData{},
}...)
}
return nextState, toSend, AlertNoAlert
}
type ClientStateWaitSH struct {
Caps Capabilities
Opts ConnectionOptions
Params ConnectionParameters
OfferedDH map[NamedGroup][]byte
OfferedPSK PreSharedKey
PSK []byte
earlySecret []byte
earlyHash crypto.Hash
firstClientHello *HandshakeMessage
helloRetryRequest *HandshakeMessage
clientHello *HandshakeMessage
}
func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil {
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message")
return nil, nil, AlertUnexpectedMessage
}
bodyGeneric, err := hm.ToBody()
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
switch body := bodyGeneric.(type) {
case *HelloRetryRequestBody:
hrr := body
if state.helloRetryRequest != nil {
logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest")
return nil, nil, AlertUnexpectedMessage
}
// Check that the version sent by the server is the one we support
if hrr.Version != supportedVersion {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version)
return nil, nil, AlertProtocolVersion
}
// Check that the server provided a supported ciphersuite
supportedCipherSuite := false
for _, suite := range state.Caps.CipherSuites {
supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite)
}
if !supportedCipherSuite {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite)
return nil, nil, AlertHandshakeFailure
}
// Narrow the supported ciphersuites to the server-provided one
state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite}
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
}
}
// The only thing we know how to respond to in an HRR is the Cookie
// extension, so if there is either no Cookie extension or anything other
// than a Cookie extension, we have to fail.
serverCookie := new(CookieExtension)
foundCookie := hrr.Extensions.Find(serverCookie)
if !foundCookie || len(hrr.Extensions) != 1 {
logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions))
return nil, nil, AlertIllegalParameter
}
// Hash the body into a pseudo-message
// XXX: Ignoring some errors here
params := cipherSuiteMap[hrr.CipherSuite]
h := params.Hash.New()
h.Write(state.clientHello.Marshal())
firstClientHello := &HandshakeMessage{
msgType: HandshakeTypeMessageHash,
body: h.Sum(nil),
}
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]")
return ClientStateStart{
Caps: state.Caps,
Opts: state.Opts,
cookie: serverCookie.Cookie,
firstClientHello: firstClientHello,
helloRetryRequest: hm,
}.Next(nil)
case *ServerHelloBody:
sh := body
// Check that the version sent by the server is the one we support
if sh.Version != supportedVersion {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version)
return nil, nil, AlertProtocolVersion
}
// Check that the server provided a supported ciphersuite
supportedCipherSuite := false
for _, suite := range state.Caps.CipherSuites {
supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite)
}
if !supportedCipherSuite {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite)
return nil, nil, AlertHandshakeFailure
}
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
}
}
// Do PSK or key agreement depending on extensions
serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello}
serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello}
foundPSK := sh.Extensions.Find(&serverPSK)
foundKeyShare := sh.Extensions.Find(&serverKeyShare)
if foundPSK && (serverPSK.SelectedIdentity == 0) {
state.Params.UsingPSK = true
}
var dhSecret []byte
if foundKeyShare {
sks := serverKeyShare.Shares[0]
priv, ok := state.OfferedDH[sks.Group]
if !ok {
logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group")
return nil, nil, AlertIllegalParameter
}
state.Params.UsingDH = true
dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv)
}
suite := sh.CipherSuite
state.Params.CipherSuite = suite
params, ok := cipherSuiteMap[suite]
if !ok {
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite)
return nil, nil, AlertHandshakeFailure
}
// Start up the handshake hash
handshakeHash := params.Hash.New()
handshakeHash.Write(state.firstClientHello.Marshal())
handshakeHash.Write(state.helloRetryRequest.Marshal())
handshakeHash.Write(state.clientHello.Marshal())
handshakeHash.Write(hm.Marshal())
// Compute handshake secrets
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
var earlySecret []byte
if state.Params.UsingPSK {
if params.Hash != state.earlyHash {
logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]",
state.earlyHash, suite, params.Hash)
}
earlySecret = state.earlySecret
} else {
earlySecret = HkdfExtract(params.Hash, zero, zero)
}
if dhSecret == nil {
dhSecret = zero
}
h0 := params.Hash.New().Sum(nil)
h2 := handshakeHash.Sum(nil)
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret)
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
nextState := ClientStateWaitEE{
Caps: state.Caps,
Params: state.Params,
cryptoParams: params,
handshakeHash: handshakeHash,
certificates: state.Caps.Certificates,
masterSecret: masterSecret,
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: serverHandshakeTrafficSecret,
}
toSend := []HandshakeAction{
RekeyIn{Label: "handshake", KeySet: serverHandshakeKeys},
}
return nextState, toSend, AlertNoAlert
}
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType)
return nil, nil, AlertUnexpectedMessage
}
type ClientStateWaitEE struct {
Caps Capabilities
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
}
func (state ClientStateWaitEE) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions {
logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
ee := EncryptedExtensionsBody{}
_, err := ee.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
}
}
serverALPN := ALPNExtension{}
serverEarlyData := EarlyDataExtension{}
gotALPN := ee.Extensions.Find(&serverALPN)
state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData)
if gotALPN && len(serverALPN.Protocols) > 0 {
state.Params.NextProto = serverALPN.Protocols[0]
}
state.handshakeHash.Write(hm.Marshal())
if state.Params.UsingPSK {
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]")
nextState := ClientStateWaitFinished{
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
}
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]")
nextState := ClientStateWaitCertCR{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
}
type ClientStateWaitCertCR struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
}
func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil {
logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
bodyGeneric, err := hm.ToBody()
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCertCR] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
state.handshakeHash.Write(hm.Marshal())
switch body := bodyGeneric.(type) {
case *CertificateBody:
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]")
nextState := ClientStateWaitCV{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificate: body,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
case *CertificateRequestBody:
// A certificate request in the handshake should have a zero-length context
if len(body.CertificateRequestContext) > 0 {
logf(logTypeHandshake, "[ClientStateWaitCertCR] Certificate request with non-empty context: %v", err)
return nil, nil, AlertIllegalParameter
}
state.Params.UsingClientAuth = true
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]")
nextState := ClientStateWaitCert{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificateRequest: body,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
}
return nil, nil, AlertUnexpectedMessage
}
type ClientStateWaitCert struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
serverCertificateRequest *CertificateRequestBody
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
}
func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeCertificate {
logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
cert := &CertificateBody{}
_, err := cert.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
state.handshakeHash.Write(hm.Marshal())
logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]")
nextState := ClientStateWaitCV{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificate: cert,
serverCertificateRequest: state.serverCertificateRequest,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
}
type ClientStateWaitCV struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
serverCertificate *CertificateBody
serverCertificateRequest *CertificateRequestBody
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
}
func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
certVerify := CertificateVerifyBody{}
_, err := certVerify.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
hcv := state.handshakeHash.Sum(nil)
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey
if err := certVerify.Verify(serverPublicKey, hcv); err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify")
return nil, nil, AlertHandshakeFailure
}
if state.AuthCertificate != nil {
err := state.AuthCertificate(state.serverCertificate.CertificateList)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate")
return nil, nil, AlertBadCertificate
}
} else {
logf(logTypeHandshake, "[ClientStateWaitCV] WARNING: No verification of server certificate")
}
state.handshakeHash.Write(hm.Marshal())
logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]")
nextState := ClientStateWaitFinished{
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificateRequest: state.serverCertificateRequest,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
}
type ClientStateWaitFinished struct {
Params ConnectionParameters
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
serverCertificateRequest *CertificateRequestBody
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
}
func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeFinished {
logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
// Verify server's Finished
h3 := state.handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3)
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3)
serverFinishedData := computeFinishedData(state.cryptoParams, state.serverHandshakeTrafficSecret, h3)
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)}
_, err := fin.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
if !bytes.Equal(fin.VerifyData, serverFinishedData) {
logf(logTypeHandshake, "[ClientStateWaitFinished] Server's Finished failed to verify [%x] != [%x]",
fin.VerifyData, serverFinishedData)
return nil, nil, AlertHandshakeFailure
}
// Update the handshake hash with the Finished
state.handshakeHash.Write(hm.Marshal())
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(hm.Marshal()), hm.Marshal())
h4 := state.handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash 4 [%d]: %x", len(h4), h4)
// Compute traffic secrets and keys
clientTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelClientApplicationTrafficSecret, h4)
serverTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelServerApplicationTrafficSecret, h4)
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret)
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret)
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, clientTrafficSecret)
serverTrafficKeys := makeTrafficKeys(state.cryptoParams, serverTrafficSecret)
exporterSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelExporterSecret, h4)
logf(logTypeCrypto, "client exporter secret: [%d] %x", len(exporterSecret), exporterSecret)
// Assemble client's second flight
toSend := []HandshakeAction{}
if state.Params.UsingEarlyData {
// Note: We only send EOED if the server is actually going to use the early
// data. Otherwise, it will never see it, and the transcripts will
// mismatch.
// EOED marshal is infallible
eoedm, _ := HandshakeMessageFromBody(&EndOfEarlyDataBody{})
toSend = append(toSend, SendHandshakeMessage{eoedm})
state.handshakeHash.Write(eoedm.Marshal())
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal())
}
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
toSend = append(toSend, RekeyOut{Label: "handshake", KeySet: clientHandshakeKeys})
if state.Params.UsingClientAuth {
// Extract constraints from certicateRequest
schemes := SignatureAlgorithmsExtension{}
gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes)
if !gotSchemes {
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
return nil, nil, AlertIllegalParameter
}
// Select a certificate
cert, certScheme, err := CertificateSelection(nil, schemes.Algorithms, state.certificates)
if err != nil {
// XXX: Signal this to the application layer?
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
certificate := &CertificateBody{}
certm, err := HandshakeMessageFromBody(certificate)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
return nil, nil, AlertInternalError
}
toSend = append(toSend, SendHandshakeMessage{certm})
state.handshakeHash.Write(certm.Marshal())
} else {
// Create and send Certificate, CertificateVerify
certificate := &CertificateBody{
CertificateList: make([]CertificateEntry, len(cert.Chain)),
}
for i, entry := range cert.Chain {
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
}
certm, err := HandshakeMessageFromBody(certificate)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
return nil, nil, AlertInternalError
}
toSend = append(toSend, SendHandshakeMessage{certm})
state.handshakeHash.Write(certm.Marshal())
hcv := state.handshakeHash.Sum(nil)
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
certificateVerify := &CertificateVerifyBody{Algorithm: certScheme}
logf(logTypeHandshake, "Creating CertVerify: %04x %v", certScheme, state.cryptoParams.Hash)
err = certificateVerify.Sign(cert.PrivateKey, hcv)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err)
return nil, nil, AlertInternalError
}
certvm, err := HandshakeMessageFromBody(certificateVerify)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err)
return nil, nil, AlertInternalError
}
toSend = append(toSend, SendHandshakeMessage{certvm})
state.handshakeHash.Write(certvm.Marshal())
}
}
// Compute the client's Finished message
h5 := state.handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5)
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5)
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData)
fin = &FinishedBody{
VerifyDataLen: len(clientFinishedData),
VerifyData: clientFinishedData,
}
finm, err := HandshakeMessageFromBody(fin)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err)
return nil, nil, AlertInternalError
}
// Compute the resumption secret
state.handshakeHash.Write(finm.Marshal())
h6 := state.handshakeHash.Sum(nil)
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6)
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
toSend = append(toSend, []HandshakeAction{
SendHandshakeMessage{finm},
RekeyIn{Label: "application", KeySet: serverTrafficKeys},
RekeyOut{Label: "application", KeySet: clientTrafficKeys},
}...)
logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]")
nextState := StateConnected{
Params: state.Params,
isClient: true,
cryptoParams: state.cryptoParams,
resumptionSecret: resumptionSecret,
clientTrafficSecret: clientTrafficSecret,
serverTrafficSecret: serverTrafficSecret,
exporterSecret: exporterSecret,
}
return nextState, toSend, AlertNoAlert
}

View file

@ -1,152 +0,0 @@
package mint
import (
"fmt"
"strconv"
)
var (
supportedVersion uint16 = 0x7f15 // draft-21
// Flags for some minor compat issues
allowWrongVersionNumber = true
allowPKCS1 = true
)
// enum {...} ContentType;
type RecordType byte
const (
RecordTypeAlert RecordType = 21
RecordTypeHandshake RecordType = 22
RecordTypeApplicationData RecordType = 23
)
// enum {...} HandshakeType;
type HandshakeType byte
const (
// Omitted: *_RESERVED
HandshakeTypeClientHello HandshakeType = 1
HandshakeTypeServerHello HandshakeType = 2
HandshakeTypeNewSessionTicket HandshakeType = 4
HandshakeTypeEndOfEarlyData HandshakeType = 5
HandshakeTypeHelloRetryRequest HandshakeType = 6
HandshakeTypeEncryptedExtensions HandshakeType = 8
HandshakeTypeCertificate HandshakeType = 11
HandshakeTypeCertificateRequest HandshakeType = 13
HandshakeTypeCertificateVerify HandshakeType = 15
HandshakeTypeServerConfiguration HandshakeType = 17
HandshakeTypeFinished HandshakeType = 20
HandshakeTypeKeyUpdate HandshakeType = 24
HandshakeTypeMessageHash HandshakeType = 254
)
// uint8 CipherSuite[2];
type CipherSuite uint16
const (
// XXX: Actually TLS_NULL_WITH_NULL_NULL, but we need a way to label the zero
// value for this type so that we can detect when a field is set.
CIPHER_SUITE_UNKNOWN CipherSuite = 0x0000
TLS_AES_128_GCM_SHA256 CipherSuite = 0x1301
TLS_AES_256_GCM_SHA384 CipherSuite = 0x1302
TLS_CHACHA20_POLY1305_SHA256 CipherSuite = 0x1303
TLS_AES_128_CCM_SHA256 CipherSuite = 0x1304
TLS_AES_256_CCM_8_SHA256 CipherSuite = 0x1305
)
func (c CipherSuite) String() string {
switch c {
case CIPHER_SUITE_UNKNOWN:
return "unknown"
case TLS_AES_128_GCM_SHA256:
return "TLS_AES_128_GCM_SHA256"
case TLS_AES_256_GCM_SHA384:
return "TLS_AES_256_GCM_SHA384"
case TLS_CHACHA20_POLY1305_SHA256:
return "TLS_CHACHA20_POLY1305_SHA256"
case TLS_AES_128_CCM_SHA256:
return "TLS_AES_128_CCM_SHA256"
case TLS_AES_256_CCM_8_SHA256:
return "TLS_AES_256_CCM_8_SHA256"
}
// cannot use %x here, since it calls String(), leading to infinite recursion
return fmt.Sprintf("invalid CipherSuite value: 0x%s", strconv.FormatUint(uint64(c), 16))
}
// enum {...} SignatureScheme
type SignatureScheme uint16
const (
// RSASSA-PKCS1-v1_5 algorithms
RSA_PKCS1_SHA1 SignatureScheme = 0x0201
RSA_PKCS1_SHA256 SignatureScheme = 0x0401
RSA_PKCS1_SHA384 SignatureScheme = 0x0501
RSA_PKCS1_SHA512 SignatureScheme = 0x0601
// ECDSA algorithms
ECDSA_P256_SHA256 SignatureScheme = 0x0403
ECDSA_P384_SHA384 SignatureScheme = 0x0503
ECDSA_P521_SHA512 SignatureScheme = 0x0603
// RSASSA-PSS algorithms
RSA_PSS_SHA256 SignatureScheme = 0x0804
RSA_PSS_SHA384 SignatureScheme = 0x0805
RSA_PSS_SHA512 SignatureScheme = 0x0806
// EdDSA algorithms
Ed25519 SignatureScheme = 0x0807
Ed448 SignatureScheme = 0x0808
)
// enum {...} ExtensionType
type ExtensionType uint16
const (
ExtensionTypeServerName ExtensionType = 0
ExtensionTypeSupportedGroups ExtensionType = 10
ExtensionTypeSignatureAlgorithms ExtensionType = 13
ExtensionTypeALPN ExtensionType = 16
ExtensionTypeKeyShare ExtensionType = 40
ExtensionTypePreSharedKey ExtensionType = 41
ExtensionTypeEarlyData ExtensionType = 42
ExtensionTypeSupportedVersions ExtensionType = 43
ExtensionTypeCookie ExtensionType = 44
ExtensionTypePSKKeyExchangeModes ExtensionType = 45
ExtensionTypeTicketEarlyDataInfo ExtensionType = 46
)
// enum {...} NamedGroup
type NamedGroup uint16
const (
// Elliptic Curve Groups.
P256 NamedGroup = 23
P384 NamedGroup = 24
P521 NamedGroup = 25
// ECDH functions.
X25519 NamedGroup = 29
X448 NamedGroup = 30
// Finite field groups.
FFDHE2048 NamedGroup = 256
FFDHE3072 NamedGroup = 257
FFDHE4096 NamedGroup = 258
FFDHE6144 NamedGroup = 259
FFDHE8192 NamedGroup = 260
)
// enum {...} PskKeyExchangeMode;
type PSKKeyExchangeMode uint8
const (
PSKModeKE PSKKeyExchangeMode = 0
PSKModeDHEKE PSKKeyExchangeMode = 1
)
// enum {
// update_not_requested(0), update_requested(1), (255)
// } KeyUpdateRequest;
type KeyUpdateRequest uint8
const (
KeyUpdateNotRequested KeyUpdateRequest = 0
KeyUpdateRequested KeyUpdateRequest = 1
)

View file

@ -1,819 +0,0 @@
package mint
import (
"crypto"
"crypto/x509"
"encoding/hex"
"fmt"
"io"
"net"
"reflect"
"sync"
"time"
)
var WouldBlock = fmt.Errorf("Would have blocked")
type Certificate struct {
Chain []*x509.Certificate
PrivateKey crypto.Signer
}
type PreSharedKey struct {
CipherSuite CipherSuite
IsResumption bool
Identity []byte
Key []byte
NextProto string
ReceivedAt time.Time
ExpiresAt time.Time
TicketAgeAdd uint32
}
type PreSharedKeyCache interface {
Get(string) (PreSharedKey, bool)
Put(string, PreSharedKey)
Size() int
}
type PSKMapCache map[string]PreSharedKey
// A CookieHandler does two things:
// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
// - validates this byte string echoed by the client in the ClientHello
type CookieHandler interface {
Generate(*Conn) ([]byte, error)
Validate(*Conn, []byte) bool
}
func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) {
psk, ok = cache[key]
return
}
func (cache *PSKMapCache) Put(key string, psk PreSharedKey) {
(*cache)[key] = psk
}
func (cache PSKMapCache) Size() int {
return len(cache)
}
// Config is the struct used to pass configuration settings to a TLS client or
// server instance. The settings for client and server are pretty different,
// but we just throw them all in here.
type Config struct {
// Client fields
ServerName string
// Server fields
SendSessionTickets bool
TicketLifetime uint32
TicketLen int
EarlyDataLifetime uint32
AllowEarlyData bool
// Require the client to echo a cookie.
RequireCookie bool
// If cookies are required and no CookieHandler is set, a default cookie handler is used.
// The default cookie handler uses 32 random bytes as a cookie.
CookieHandler CookieHandler
RequireClientAuth bool
// Shared fields
Certificates []*Certificate
AuthCertificate func(chain []CertificateEntry) error
CipherSuites []CipherSuite
Groups []NamedGroup
SignatureSchemes []SignatureScheme
NextProtos []string
PSKs PreSharedKeyCache
PSKModes []PSKKeyExchangeMode
NonBlocking bool
// The same config object can be shared among different connections, so it
// needs its own mutex
mutex sync.RWMutex
}
// Clone returns a shallow clone of c. It is safe to clone a Config that is
// being used concurrently by a TLS client or server.
func (c *Config) Clone() *Config {
c.mutex.Lock()
defer c.mutex.Unlock()
return &Config{
ServerName: c.ServerName,
SendSessionTickets: c.SendSessionTickets,
TicketLifetime: c.TicketLifetime,
TicketLen: c.TicketLen,
EarlyDataLifetime: c.EarlyDataLifetime,
AllowEarlyData: c.AllowEarlyData,
RequireCookie: c.RequireCookie,
RequireClientAuth: c.RequireClientAuth,
Certificates: c.Certificates,
AuthCertificate: c.AuthCertificate,
CipherSuites: c.CipherSuites,
Groups: c.Groups,
SignatureSchemes: c.SignatureSchemes,
NextProtos: c.NextProtos,
PSKs: c.PSKs,
PSKModes: c.PSKModes,
NonBlocking: c.NonBlocking,
}
}
func (c *Config) Init(isClient bool) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// Set defaults
if len(c.CipherSuites) == 0 {
c.CipherSuites = defaultSupportedCipherSuites
}
if len(c.Groups) == 0 {
c.Groups = defaultSupportedGroups
}
if len(c.SignatureSchemes) == 0 {
c.SignatureSchemes = defaultSignatureSchemes
}
if c.TicketLen == 0 {
c.TicketLen = defaultTicketLen
}
if !reflect.ValueOf(c.PSKs).IsValid() {
c.PSKs = &PSKMapCache{}
}
if len(c.PSKModes) == 0 {
c.PSKModes = defaultPSKModes
}
// If there is no certificate, generate one
if !isClient && len(c.Certificates) == 0 {
logf(logTypeHandshake, "Generating key name=%v", c.ServerName)
priv, err := newSigningKey(RSA_PSS_SHA256)
if err != nil {
return err
}
cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv)
if err != nil {
return err
}
c.Certificates = []*Certificate{
{
Chain: []*x509.Certificate{cert},
PrivateKey: priv,
},
}
}
return nil
}
func (c *Config) ValidForServer() bool {
return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) ||
(len(c.Certificates) > 0 &&
len(c.Certificates[0].Chain) > 0 &&
c.Certificates[0].PrivateKey != nil)
}
func (c *Config) ValidForClient() bool {
return len(c.ServerName) > 0
}
var (
defaultSupportedCipherSuites = []CipherSuite{
TLS_AES_128_GCM_SHA256,
TLS_AES_256_GCM_SHA384,
}
defaultSupportedGroups = []NamedGroup{
P256,
P384,
FFDHE2048,
X25519,
}
defaultSignatureSchemes = []SignatureScheme{
RSA_PSS_SHA256,
RSA_PSS_SHA384,
RSA_PSS_SHA512,
ECDSA_P256_SHA256,
ECDSA_P384_SHA384,
ECDSA_P521_SHA512,
}
defaultTicketLen = 16
defaultPSKModes = []PSKKeyExchangeMode{
PSKModeKE,
PSKModeDHEKE,
}
)
type ConnectionState struct {
HandshakeState string // string representation of the handshake state.
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement
NextProto string // Selected ALPN proto
}
// Conn implements the net.Conn interface, as with "crypto/tls"
// * Read, Write, and Close are provided locally
// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn
type Conn struct {
config *Config
conn net.Conn
isClient bool
EarlyData []byte
state StateConnected
hState HandshakeState
handshakeMutex sync.Mutex
handshakeAlert Alert
handshakeComplete bool
readBuffer []byte
in, out *RecordLayer
hIn, hOut *HandshakeLayer
extHandler AppExtensionHandler
}
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
c := &Conn{conn: conn, config: config, isClient: isClient}
c.in = NewRecordLayer(c.conn)
c.out = NewRecordLayer(c.conn)
c.hIn = NewHandshakeLayer(c.in)
c.hIn.nonblocking = c.config.NonBlocking
c.hOut = NewHandshakeLayer(c.out)
return c
}
// Read up
func (c *Conn) consumeRecord() error {
pt, err := c.in.ReadRecord()
if pt == nil {
logf(logTypeIO, "extendBuffer returns error %v", err)
return err
}
switch pt.contentType {
case RecordTypeHandshake:
logf(logTypeHandshake, "Received post-handshake message")
// We do not support fragmentation of post-handshake handshake messages.
// TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
start := 0
for start < len(pt.fragment) {
if len(pt.fragment[start:]) < handshakeHeaderLen {
return fmt.Errorf("Post-handshake handshake message too short for header")
}
hm := &HandshakeMessage{}
hm.msgType = HandshakeType(pt.fragment[start])
hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3])
if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen {
return fmt.Errorf("Post-handshake handshake message too short for body")
}
hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen]
// Advance state machine
state, actions, alert := c.state.Next(hm)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error in state transition: %v", alert)
c.sendAlert(alert)
return io.EOF
}
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
c.sendAlert(alert)
return io.EOF
}
}
// XXX: If we want to support more advanced cases, e.g., post-handshake
// authentication, we'll need to allow transitions other than
// Connected -> Connected
var connected bool
c.state, connected = state.(StateConnected)
if !connected {
logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
c.sendAlert(alert)
return io.EOF
}
start += handshakeHeaderLen + hmLen
}
case RecordTypeAlert:
logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer)
if len(pt.fragment) != 2 {
c.sendAlert(AlertUnexpectedMessage)
return io.EOF
}
if Alert(pt.fragment[1]) == AlertCloseNotify {
return io.EOF
}
switch pt.fragment[0] {
case AlertLevelWarning:
// drop on the floor
case AlertLevelError:
return Alert(pt.fragment[1])
default:
c.sendAlert(AlertUnexpectedMessage)
return io.EOF
}
case RecordTypeApplicationData:
c.readBuffer = append(c.readBuffer, pt.fragment...)
logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
}
return err
}
// Read application data up to the size of buffer. Handshake and alert records
// are consumed by the Conn object directly.
func (c *Conn) Read(buffer []byte) (int, error) {
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
if alert := c.Handshake(); alert != AlertNoAlert {
return 0, alert
}
if len(buffer) == 0 {
return 0, nil
}
// Lock the input channel
c.in.Lock()
defer c.in.Unlock()
for len(c.readBuffer) == 0 {
err := c.consumeRecord()
// err can be nil if consumeRecord processed a non app-data
// record.
if err != nil {
if c.config.NonBlocking || err != WouldBlock {
logf(logTypeIO, "conn.Read returns err=%v", err)
return 0, err
}
}
}
var read int
n := len(buffer)
logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer))
if len(c.readBuffer) <= n {
buffer = buffer[:len(c.readBuffer)]
copy(buffer, c.readBuffer)
read = len(c.readBuffer)
c.readBuffer = c.readBuffer[:0]
} else {
logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n)
copy(buffer[:n], c.readBuffer[:n])
c.readBuffer = c.readBuffer[n:]
read = n
}
logf(logTypeVerbose, "Returning %v", string(buffer))
return read, nil
}
// Write application data
func (c *Conn) Write(buffer []byte) (int, error) {
// Lock the output channel
c.out.Lock()
defer c.out.Unlock()
// Send full-size fragments
var start int
sent := 0
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
err := c.out.WriteRecord(&TLSPlaintext{
contentType: RecordTypeApplicationData,
fragment: buffer[start : start+maxFragmentLen],
})
if err != nil {
return sent, err
}
sent += maxFragmentLen
}
// Send a final partial fragment if necessary
if start < len(buffer) {
err := c.out.WriteRecord(&TLSPlaintext{
contentType: RecordTypeApplicationData,
fragment: buffer[start:],
})
if err != nil {
return sent, err
}
sent += len(buffer[start:])
}
return sent, nil
}
// sendAlert sends a TLS alert message.
// c.out.Mutex <= L.
func (c *Conn) sendAlert(err Alert) error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
var level int
switch err {
case AlertNoRenegotiation, AlertCloseNotify:
level = AlertLevelWarning
default:
level = AlertLevelError
}
buf := []byte{byte(err), byte(level)}
c.out.WriteRecord(&TLSPlaintext{
contentType: RecordTypeAlert,
fragment: buf,
})
// close_notify and end_of_early_data are not actually errors
if level == AlertLevelWarning {
return &net.OpError{Op: "local error", Err: err}
}
return c.Close()
}
// Close closes the connection.
func (c *Conn) Close() error {
// XXX crypto/tls has an interlock with Write here. Do we need that?
return c.conn.Close()
}
// LocalAddr returns the local network address.
func (c *Conn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
// RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// SetDeadline sets the read and write deadlines associated with the connection.
// A zero value for t means Read and Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
func (c *Conn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
// SetReadDeadline sets the read deadline on the underlying connection.
// A zero value for t means Read will not time out.
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
// SetWriteDeadline sets the write deadline on the underlying connection.
// A zero value for t means Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
label := "[server]"
if c.isClient {
label = "[client]"
}
switch action := actionGeneric.(type) {
case SendHandshakeMessage:
err := c.hOut.WriteMessage(action.Message)
if err != nil {
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
return AlertInternalError
}
case RekeyIn:
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet)
err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
if err != nil {
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
return AlertInternalError
}
case RekeyOut:
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet)
err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
if err != nil {
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
return AlertInternalError
}
case SendEarlyData:
logf(logTypeHandshake, "%s Sending early data...", label)
_, err := c.Write(c.EarlyData)
if err != nil {
logf(logTypeHandshake, "%s Error writing early data: %v", label, err)
return AlertInternalError
}
case ReadPastEarlyData:
logf(logTypeHandshake, "%s Reading past early data...", label)
// Scan past all records that fail to decrypt
_, err := c.in.PeekRecordType(!c.config.NonBlocking)
if err == nil {
break
}
_, ok := err.(DecryptError)
for ok {
_, err = c.in.PeekRecordType(!c.config.NonBlocking)
if err == nil {
break
}
_, ok = err.(DecryptError)
}
case ReadEarlyData:
logf(logTypeHandshake, "%s Reading early data...", label)
t, err := c.in.PeekRecordType(!c.config.NonBlocking)
if err != nil {
logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err)
return AlertInternalError
}
logf(logTypeHandshake, "%s Got record type(1): %v", label, t)
for t == RecordTypeApplicationData {
// Read a record into the buffer. Note that this is safe
// in blocking mode because we read the record in in
// PeekRecordType.
pt, err := c.in.ReadRecord()
if err != nil {
logf(logTypeHandshake, "%s Error reading early data record: %v", label, err)
return AlertInternalError
}
logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment)
c.EarlyData = append(c.EarlyData, pt.fragment...)
t, err = c.in.PeekRecordType(!c.config.NonBlocking)
if err != nil {
logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err)
return AlertInternalError
}
logf(logTypeHandshake, "%s Got record type (2): %v", label, t)
}
logf(logTypeHandshake, "%s Done reading early data", label)
case StorePSK:
logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
if c.isClient {
// Clients look up PSKs based on server name
c.config.PSKs.Put(c.config.ServerName, action.PSK)
} else {
// Servers look them up based on the identity in the extension
c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK)
}
default:
logf(logTypeHandshake, "%s Unknown actionuction type", label)
return AlertInternalError
}
return AlertNoAlert
}
func (c *Conn) HandshakeSetup() Alert {
var state HandshakeState
var actions []HandshakeAction
var alert Alert
if err := c.config.Init(c.isClient); err != nil {
logf(logTypeHandshake, "Error initializing config: %v", err)
return AlertInternalError
}
// Set things up
caps := Capabilities{
CipherSuites: c.config.CipherSuites,
Groups: c.config.Groups,
SignatureSchemes: c.config.SignatureSchemes,
PSKs: c.config.PSKs,
PSKModes: c.config.PSKModes,
AllowEarlyData: c.config.AllowEarlyData,
RequireCookie: c.config.RequireCookie,
CookieHandler: c.config.CookieHandler,
RequireClientAuth: c.config.RequireClientAuth,
NextProtos: c.config.NextProtos,
Certificates: c.config.Certificates,
ExtensionHandler: c.extHandler,
}
opts := ConnectionOptions{
ServerName: c.config.ServerName,
NextProtos: c.config.NextProtos,
EarlyData: c.EarlyData,
}
if caps.RequireCookie && caps.CookieHandler == nil {
caps.CookieHandler = &defaultCookieHandler{}
}
if c.isClient {
state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error initializing client state: %v", alert)
return alert
}
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
return alert
}
}
} else {
state = ServerStateStart{Caps: caps, conn: c}
}
c.hState = state
return AlertNoAlert
}
// Handshake causes a TLS handshake on the connection. The `isClient` member
// determines whether a client or server handshake is performed. If a
// handshake has already been performed, then its result will be returned.
func (c *Conn) Handshake() Alert {
label := "[server]"
if c.isClient {
label = "[client]"
}
// TODO Lock handshakeMutex
// TODO Remove CloseNotify hack
if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify {
logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert)
return c.handshakeAlert
}
if c.handshakeComplete {
return AlertNoAlert
}
var alert Alert
if c.hState == nil {
logf(logTypeHandshake, "%s First time through handshake, setting up", label)
alert = c.HandshakeSetup()
if alert != AlertNoAlert {
return alert
}
} else {
logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState)
}
state := c.hState
_, connected := state.(StateConnected)
var actions []HandshakeAction
for !connected {
// Read a handshake message
hm, err := c.hIn.ReadMessage()
if err == WouldBlock {
logf(logTypeHandshake, "%s Would block reading message: %v", label, err)
return AlertWouldBlock
}
if err != nil {
logf(logTypeHandshake, "%s Error reading message: %v", label, err)
c.sendAlert(AlertCloseNotify)
return AlertCloseNotify
}
logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
// Advance the state machine
state, actions, alert = state.Next(hm)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error in state transition: %v", alert)
return alert
}
for index, action := range actions {
logf(logTypeHandshake, "%s taking next action (%d)", label, index)
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
c.sendAlert(alert)
return alert
}
}
c.hState = state
logf(logTypeHandshake, "state is now %s", c.GetHsState())
_, connected = state.(StateConnected)
}
c.state = state.(StateConnected)
// Send NewSessionTicket if acting as server
if !c.isClient && c.config.SendSessionTickets {
actions, alert := c.state.NewSessionTicket(
c.config.TicketLen,
c.config.TicketLifetime,
c.config.EarlyDataLifetime)
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
c.sendAlert(alert)
return alert
}
}
}
c.handshakeComplete = true
return AlertNoAlert
}
func (c *Conn) SendKeyUpdate(requestUpdate bool) error {
if !c.handshakeComplete {
return fmt.Errorf("Cannot update keys until after handshake")
}
request := KeyUpdateNotRequested
if requestUpdate {
request = KeyUpdateRequested
}
// Create the key update and update state
actions, alert := c.state.KeyUpdate(request)
if alert != AlertNoAlert {
c.sendAlert(alert)
return fmt.Errorf("Alert while generating key update: %v", alert)
}
// Take actions (send key update and rekey)
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
c.sendAlert(alert)
return fmt.Errorf("Alert during key update actions: %v", alert)
}
}
return nil
}
func (c *Conn) GetHsState() string {
return reflect.TypeOf(c.hState).Name()
}
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
_, connected := c.hState.(StateConnected)
if !connected {
return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
}
if c.state.exporterSecret == nil {
return nil, fmt.Errorf("Internal error: no exporter secret")
}
h0 := c.state.cryptoParams.Hash.New().Sum(nil)
tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0)
hc := c.state.cryptoParams.Hash.New().Sum(context)
return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil
}
func (c *Conn) State() ConnectionState {
state := ConnectionState{
HandshakeState: c.GetHsState(),
}
if c.handshakeComplete {
state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
state.NextProto = c.state.Params.NextProto
}
return state
}
func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error {
if c.hState != nil {
return fmt.Errorf("Can't set extension handler after setup")
}
c.extHandler = h
return nil
}

View file

@ -1,654 +0,0 @@
package mint
import (
"bytes"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/hmac"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"fmt"
"math/big"
"time"
"golang.org/x/crypto/curve25519"
// Blank includes to ensure hash support
_ "crypto/sha1"
_ "crypto/sha256"
_ "crypto/sha512"
)
var prng = rand.Reader
type aeadFactory func(key []byte) (cipher.AEAD, error)
type CipherSuiteParams struct {
Suite CipherSuite
Cipher aeadFactory // Cipher factory
Hash crypto.Hash // Hash function
KeyLen int // Key length in octets
IvLen int // IV length in octets
}
type signatureAlgorithm uint8
const (
signatureAlgorithmUnknown = iota
signatureAlgorithmRSA_PKCS1
signatureAlgorithmRSA_PSS
signatureAlgorithmECDSA
)
var (
hashMap = map[SignatureScheme]crypto.Hash{
RSA_PKCS1_SHA1: crypto.SHA1,
RSA_PKCS1_SHA256: crypto.SHA256,
RSA_PKCS1_SHA384: crypto.SHA384,
RSA_PKCS1_SHA512: crypto.SHA512,
ECDSA_P256_SHA256: crypto.SHA256,
ECDSA_P384_SHA384: crypto.SHA384,
ECDSA_P521_SHA512: crypto.SHA512,
RSA_PSS_SHA256: crypto.SHA256,
RSA_PSS_SHA384: crypto.SHA384,
RSA_PSS_SHA512: crypto.SHA512,
}
sigMap = map[SignatureScheme]signatureAlgorithm{
RSA_PKCS1_SHA1: signatureAlgorithmRSA_PKCS1,
RSA_PKCS1_SHA256: signatureAlgorithmRSA_PKCS1,
RSA_PKCS1_SHA384: signatureAlgorithmRSA_PKCS1,
RSA_PKCS1_SHA512: signatureAlgorithmRSA_PKCS1,
ECDSA_P256_SHA256: signatureAlgorithmECDSA,
ECDSA_P384_SHA384: signatureAlgorithmECDSA,
ECDSA_P521_SHA512: signatureAlgorithmECDSA,
RSA_PSS_SHA256: signatureAlgorithmRSA_PSS,
RSA_PSS_SHA384: signatureAlgorithmRSA_PSS,
RSA_PSS_SHA512: signatureAlgorithmRSA_PSS,
}
curveMap = map[SignatureScheme]NamedGroup{
ECDSA_P256_SHA256: P256,
ECDSA_P384_SHA384: P384,
ECDSA_P521_SHA512: P521,
}
newAESGCM = func(key []byte) (cipher.AEAD, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
// TLS always uses 12-byte nonces
return cipher.NewGCMWithNonceSize(block, 12)
}
cipherSuiteMap = map[CipherSuite]CipherSuiteParams{
TLS_AES_128_GCM_SHA256: {
Suite: TLS_AES_128_GCM_SHA256,
Cipher: newAESGCM,
Hash: crypto.SHA256,
KeyLen: 16,
IvLen: 12,
},
TLS_AES_256_GCM_SHA384: {
Suite: TLS_AES_256_GCM_SHA384,
Cipher: newAESGCM,
Hash: crypto.SHA384,
KeyLen: 32,
IvLen: 12,
},
}
x509AlgMap = map[SignatureScheme]x509.SignatureAlgorithm{
RSA_PKCS1_SHA1: x509.SHA1WithRSA,
RSA_PKCS1_SHA256: x509.SHA256WithRSA,
RSA_PKCS1_SHA384: x509.SHA384WithRSA,
RSA_PKCS1_SHA512: x509.SHA512WithRSA,
ECDSA_P256_SHA256: x509.ECDSAWithSHA256,
ECDSA_P384_SHA384: x509.ECDSAWithSHA384,
ECDSA_P521_SHA512: x509.ECDSAWithSHA512,
}
defaultRSAKeySize = 2048
)
func curveFromNamedGroup(group NamedGroup) (crv elliptic.Curve) {
switch group {
case P256:
crv = elliptic.P256()
case P384:
crv = elliptic.P384()
case P521:
crv = elliptic.P521()
}
return
}
func namedGroupFromECDSAKey(key *ecdsa.PublicKey) (g NamedGroup) {
switch key.Curve.Params().Name {
case elliptic.P256().Params().Name:
g = P256
case elliptic.P384().Params().Name:
g = P384
case elliptic.P521().Params().Name:
g = P521
}
return
}
func keyExchangeSizeFromNamedGroup(group NamedGroup) (size int) {
size = 0
switch group {
case X25519:
size = 32
case P256:
size = 65
case P384:
size = 97
case P521:
size = 133
case FFDHE2048:
size = 256
case FFDHE3072:
size = 384
case FFDHE4096:
size = 512
case FFDHE6144:
size = 768
case FFDHE8192:
size = 1024
}
return
}
func primeFromNamedGroup(group NamedGroup) (p *big.Int) {
switch group {
case FFDHE2048:
p = finiteFieldPrime2048
case FFDHE3072:
p = finiteFieldPrime3072
case FFDHE4096:
p = finiteFieldPrime4096
case FFDHE6144:
p = finiteFieldPrime6144
case FFDHE8192:
p = finiteFieldPrime8192
}
return
}
func schemeValidForKey(alg SignatureScheme, key crypto.Signer) bool {
sigType := sigMap[alg]
switch key.(type) {
case *rsa.PrivateKey:
return sigType == signatureAlgorithmRSA_PKCS1 || sigType == signatureAlgorithmRSA_PSS
case *ecdsa.PrivateKey:
return sigType == signatureAlgorithmECDSA
default:
return false
}
}
func ffdheKeyShareFromPrime(p *big.Int) (priv, pub *big.Int, err error) {
primeLen := len(p.Bytes())
for {
// g = 2 for all ffdhe groups
priv, err = rand.Int(prng, p)
if err != nil {
return
}
pub = big.NewInt(0)
pub.Exp(big.NewInt(2), priv, p)
if len(pub.Bytes()) == primeLen {
return
}
}
}
func newKeyShare(group NamedGroup) (pub []byte, priv []byte, err error) {
switch group {
case P256, P384, P521:
var x, y *big.Int
crv := curveFromNamedGroup(group)
priv, x, y, err = elliptic.GenerateKey(crv, prng)
if err != nil {
return
}
pub = elliptic.Marshal(crv, x, y)
return
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
p := primeFromNamedGroup(group)
x, X, err2 := ffdheKeyShareFromPrime(p)
if err2 != nil {
err = err2
return
}
priv = x.Bytes()
pubBytes := X.Bytes()
numBytes := keyExchangeSizeFromNamedGroup(group)
pub = make([]byte, numBytes)
copy(pub[numBytes-len(pubBytes):], pubBytes)
return
case X25519:
var private, public [32]byte
_, err = prng.Read(private[:])
if err != nil {
return
}
curve25519.ScalarBaseMult(&public, &private)
priv = private[:]
pub = public[:]
return
default:
return nil, nil, fmt.Errorf("tls.newkeyshare: Unsupported group %v", group)
}
}
func keyAgreement(group NamedGroup, pub []byte, priv []byte) ([]byte, error) {
switch group {
case P256, P384, P521:
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
}
crv := curveFromNamedGroup(group)
pubX, pubY := elliptic.Unmarshal(crv, pub)
x, _ := crv.Params().ScalarMult(pubX, pubY, priv)
xBytes := x.Bytes()
numBytes := len(crv.Params().P.Bytes())
ret := make([]byte, numBytes)
copy(ret[numBytes-len(xBytes):], xBytes)
return ret, nil
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
numBytes := keyExchangeSizeFromNamedGroup(group)
if len(pub) != numBytes {
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
}
p := primeFromNamedGroup(group)
x := big.NewInt(0).SetBytes(priv)
Y := big.NewInt(0).SetBytes(pub)
ZBytes := big.NewInt(0).Exp(Y, x, p).Bytes()
ret := make([]byte, numBytes)
copy(ret[numBytes-len(ZBytes):], ZBytes)
return ret, nil
case X25519:
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
}
var private, public, ret [32]byte
copy(private[:], priv)
copy(public[:], pub)
curve25519.ScalarMult(&ret, &private, &public)
return ret[:], nil
default:
return nil, fmt.Errorf("tls.keyagreement: Unsupported group %v", group)
}
}
func newSigningKey(sig SignatureScheme) (crypto.Signer, error) {
switch sig {
case RSA_PKCS1_SHA1, RSA_PKCS1_SHA256,
RSA_PKCS1_SHA384, RSA_PKCS1_SHA512,
RSA_PSS_SHA256, RSA_PSS_SHA384,
RSA_PSS_SHA512:
return rsa.GenerateKey(prng, defaultRSAKeySize)
case ECDSA_P256_SHA256:
return ecdsa.GenerateKey(elliptic.P256(), prng)
case ECDSA_P384_SHA384:
return ecdsa.GenerateKey(elliptic.P384(), prng)
case ECDSA_P521_SHA512:
return ecdsa.GenerateKey(elliptic.P521(), prng)
default:
return nil, fmt.Errorf("tls.newsigningkey: Unsupported signature algorithm [%04x]", sig)
}
}
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
sigAlg, ok := x509AlgMap[alg]
if !ok {
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
}
if len(name) == 0 {
return nil, fmt.Errorf("tls.selfsigned: No name provided")
}
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
if err != nil {
return nil, err
}
template := &x509.Certificate{
SerialNumber: serial,
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(0, 0, 1),
SignatureAlgorithm: sigAlg,
Subject: pkix.Name{CommonName: name},
DNSNames: []string{name},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
if err != nil {
return nil, err
}
// It is safe to ignore the error here because we're parsing known-good data
cert, _ := x509.ParseCertificate(der)
return cert, nil
}
// XXX(rlb): Copied from crypto/x509
type ecdsaSignature struct {
R, S *big.Int
}
func sign(alg SignatureScheme, privateKey crypto.Signer, sigInput []byte) ([]byte, error) {
var opts crypto.SignerOpts
hash := hashMap[alg]
if hash == crypto.SHA1 {
return nil, fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
}
sigType := sigMap[alg]
var realInput []byte
switch key := privateKey.(type) {
case *rsa.PrivateKey:
switch {
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
logf(logTypeCrypto, "signing with PKCS1, hashSize=[%d]", hash.Size())
opts = hash
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
fallthrough
case sigType == signatureAlgorithmRSA_PSS:
logf(logTypeCrypto, "signing with PSS, hashSize=[%d]", hash.Size())
opts = &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
default:
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for RSA key")
}
h := hash.New()
h.Write(sigInput)
realInput = h.Sum(nil)
case *ecdsa.PrivateKey:
if sigType != signatureAlgorithmECDSA {
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for ECDSA key")
}
algGroup := curveMap[alg]
keyGroup := namedGroupFromECDSAKey(key.Public().(*ecdsa.PublicKey))
if algGroup != keyGroup {
return nil, fmt.Errorf("tls.crypto.sign: Unsupported hash/curve combination")
}
h := hash.New()
h.Write(sigInput)
realInput = h.Sum(nil)
default:
return nil, fmt.Errorf("tls.crypto.sign: Unsupported private key type")
}
sig, err := privateKey.Sign(prng, realInput, opts)
logf(logTypeCrypto, "signature: %x", sig)
return sig, err
}
func verify(alg SignatureScheme, publicKey crypto.PublicKey, sigInput []byte, sig []byte) error {
hash := hashMap[alg]
if hash == crypto.SHA1 {
return fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
}
sigType := sigMap[alg]
switch pub := publicKey.(type) {
case *rsa.PublicKey:
switch {
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
logf(logTypeCrypto, "verifying with PKCS1, hashSize=[%d]", hash.Size())
h := hash.New()
h.Write(sigInput)
realInput := h.Sum(nil)
return rsa.VerifyPKCS1v15(pub, hash, realInput, sig)
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
fallthrough
case sigType == signatureAlgorithmRSA_PSS:
logf(logTypeCrypto, "verifying with PSS, hashSize=[%d]", hash.Size())
opts := &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
h := hash.New()
h.Write(sigInput)
realInput := h.Sum(nil)
return rsa.VerifyPSS(pub, hash, realInput, sig, opts)
default:
return fmt.Errorf("tls.verify: Unsupported algorithm for RSA key")
}
case *ecdsa.PublicKey:
if sigType != signatureAlgorithmECDSA {
return fmt.Errorf("tls.verify: Unsupported algorithm for ECDSA key")
}
if curveMap[alg] != namedGroupFromECDSAKey(pub) {
return fmt.Errorf("tls.verify: Unsupported curve for ECDSA key")
}
ecdsaSig := new(ecdsaSignature)
if rest, err := asn1.Unmarshal(sig, ecdsaSig); err != nil {
return err
} else if len(rest) != 0 {
return fmt.Errorf("tls.verify: trailing data after ECDSA signature")
}
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
return fmt.Errorf("tls.verify: ECDSA signature contained zero or negative values")
}
h := hash.New()
h.Write(sigInput)
realInput := h.Sum(nil)
if !ecdsa.Verify(pub, realInput, ecdsaSig.R, ecdsaSig.S) {
return fmt.Errorf("tls.verify: ECDSA verification failure")
}
return nil
default:
return fmt.Errorf("tls.verify: Unsupported key type")
}
}
// 0
// |
// v
// PSK -> HKDF-Extract = Early Secret
// |
// +-----> Derive-Secret(.,
// | "ext binder" |
// | "res binder",
// | "")
// | = binder_key
// |
// +-----> Derive-Secret(., "c e traffic",
// | ClientHello)
// | = client_early_traffic_secret
// |
// +-----> Derive-Secret(., "e exp master",
// | ClientHello)
// | = early_exporter_master_secret
// v
// Derive-Secret(., "derived", "")
// |
// v
// (EC)DHE -> HKDF-Extract = Handshake Secret
// |
// +-----> Derive-Secret(., "c hs traffic",
// | ClientHello...ServerHello)
// | = client_handshake_traffic_secret
// |
// +-----> Derive-Secret(., "s hs traffic",
// | ClientHello...ServerHello)
// | = server_handshake_traffic_secret
// v
// Derive-Secret(., "derived", "")
// |
// v
// 0 -> HKDF-Extract = Master Secret
// |
// +-----> Derive-Secret(., "c ap traffic",
// | ClientHello...server Finished)
// | = client_application_traffic_secret_0
// |
// +-----> Derive-Secret(., "s ap traffic",
// | ClientHello...server Finished)
// | = server_application_traffic_secret_0
// |
// +-----> Derive-Secret(., "exp master",
// | ClientHello...server Finished)
// | = exporter_master_secret
// |
// +-----> Derive-Secret(., "res master",
// ClientHello...client Finished)
// = resumption_master_secret
// From RFC 5869
// PRK = HMAC-Hash(salt, IKM)
func HkdfExtract(hash crypto.Hash, saltIn, input []byte) []byte {
salt := saltIn
// if [salt is] not provided, it is set to a string of HashLen zeros
if salt == nil {
salt = bytes.Repeat([]byte{0}, hash.Size())
}
h := hmac.New(hash.New, salt)
h.Write(input)
out := h.Sum(nil)
logf(logTypeCrypto, "HKDF Extract:\n")
logf(logTypeCrypto, "Salt [%d]: %x\n", len(salt), salt)
logf(logTypeCrypto, "Input [%d]: %x\n", len(input), input)
logf(logTypeCrypto, "Output [%d]: %x\n", len(out), out)
return out
}
const (
labelExternalBinder = "ext binder"
labelResumptionBinder = "res binder"
labelEarlyTrafficSecret = "c e traffic"
labelEarlyExporterSecret = "e exp master"
labelClientHandshakeTrafficSecret = "c hs traffic"
labelServerHandshakeTrafficSecret = "s hs traffic"
labelClientApplicationTrafficSecret = "c ap traffic"
labelServerApplicationTrafficSecret = "s ap traffic"
labelExporterSecret = "exp master"
labelResumptionSecret = "res master"
labelDerived = "derived"
labelFinished = "finished"
labelResumption = "resumption"
)
// struct HkdfLabel {
// uint16 length;
// opaque label<9..255>;
// opaque hash_value<0..255>;
// };
func hkdfEncodeLabel(labelIn string, hashValue []byte, outLen int) []byte {
label := "tls13 " + labelIn
labelLen := len(label)
hashLen := len(hashValue)
hkdfLabel := make([]byte, 2+1+labelLen+1+hashLen)
hkdfLabel[0] = byte(outLen >> 8)
hkdfLabel[1] = byte(outLen)
hkdfLabel[2] = byte(labelLen)
copy(hkdfLabel[3:3+labelLen], []byte(label))
hkdfLabel[3+labelLen] = byte(hashLen)
copy(hkdfLabel[3+labelLen+1:], hashValue)
return hkdfLabel
}
func HkdfExpand(hash crypto.Hash, prk, info []byte, outLen int) []byte {
out := []byte{}
T := []byte{}
i := byte(1)
for len(out) < outLen {
block := append(T, info...)
block = append(block, i)
h := hmac.New(hash.New, prk)
h.Write(block)
T = h.Sum(nil)
out = append(out, T...)
i++
}
return out[:outLen]
}
func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, hashValue []byte, outLen int) []byte {
info := hkdfEncodeLabel(label, hashValue, outLen)
derived := HkdfExpand(hash, secret, info, outLen)
logf(logTypeCrypto, "HKDF Expand: label=[tls13 ] + '%s',requested length=%d\n", label, outLen)
logf(logTypeCrypto, "PRK [%d]: %x\n", len(secret), secret)
logf(logTypeCrypto, "Hash [%d]: %x\n", len(hashValue), hashValue)
logf(logTypeCrypto, "Info [%d]: %x\n", len(info), info)
logf(logTypeCrypto, "Derived key [%d]: %x\n", len(derived), derived)
return derived
}
func deriveSecret(params CipherSuiteParams, secret []byte, label string, messageHash []byte) []byte {
return HkdfExpandLabel(params.Hash, secret, label, messageHash, params.Hash.Size())
}
func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte) []byte {
macKey := HkdfExpandLabel(params.Hash, baseKey, labelFinished, []byte{}, params.Hash.Size())
mac := hmac.New(params.Hash.New, macKey)
mac.Write(input)
return mac.Sum(nil)
}
type keySet struct {
cipher aeadFactory
key []byte
iv []byte
}
func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet {
logf(logTypeCrypto, "making traffic keys: secret=%x", secret)
return keySet{
cipher: params.Cipher,
key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen),
iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen),
}
}

View file

@ -1,586 +0,0 @@
package mint
import (
"bytes"
"fmt"
"github.com/bifurcation/mint/syntax"
)
type ExtensionBody interface {
Type() ExtensionType
Marshal() ([]byte, error)
Unmarshal(data []byte) (int, error)
}
// struct {
// ExtensionType extension_type;
// opaque extension_data<0..2^16-1>;
// } Extension;
type Extension struct {
ExtensionType ExtensionType
ExtensionData []byte `tls:"head=2"`
}
func (ext Extension) Marshal() ([]byte, error) {
return syntax.Marshal(ext)
}
func (ext *Extension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, ext)
}
type ExtensionList []Extension
type extensionListInner struct {
List []Extension `tls:"head=2"`
}
func (el ExtensionList) Marshal() ([]byte, error) {
return syntax.Marshal(extensionListInner{el})
}
func (el *ExtensionList) Unmarshal(data []byte) (int, error) {
var list extensionListInner
read, err := syntax.Unmarshal(data, &list)
if err != nil {
return 0, err
}
*el = list.List
return read, nil
}
func (el *ExtensionList) Add(src ExtensionBody) error {
data, err := src.Marshal()
if err != nil {
return err
}
if el == nil {
el = new(ExtensionList)
}
// If one already exists with this type, replace it
for i := range *el {
if (*el)[i].ExtensionType == src.Type() {
(*el)[i].ExtensionData = data
return nil
}
}
// Otherwise append
*el = append(*el, Extension{
ExtensionType: src.Type(),
ExtensionData: data,
})
return nil
}
func (el ExtensionList) Find(dst ExtensionBody) bool {
for _, ext := range el {
if ext.ExtensionType == dst.Type() {
_, err := dst.Unmarshal(ext.ExtensionData)
return err == nil
}
}
return false
}
// struct {
// NameType name_type;
// select (name_type) {
// case host_name: HostName;
// } name;
// } ServerName;
//
// enum {
// host_name(0), (255)
// } NameType;
//
// opaque HostName<1..2^16-1>;
//
// struct {
// ServerName server_name_list<1..2^16-1>
// } ServerNameList;
//
// But we only care about the case where there's a single DNS hostname. We
// will never create anything else, and throw if we receive something else
//
// 2 1 2
// | listLen | NameType | nameLen | name |
type ServerNameExtension string
type serverNameInner struct {
NameType uint8
HostName []byte `tls:"head=2,min=1"`
}
type serverNameListInner struct {
ServerNameList []serverNameInner `tls:"head=2,min=1"`
}
func (sni ServerNameExtension) Type() ExtensionType {
return ExtensionTypeServerName
}
func (sni ServerNameExtension) Marshal() ([]byte, error) {
list := serverNameListInner{
ServerNameList: []serverNameInner{{
NameType: 0x00, // host_name
HostName: []byte(sni),
}},
}
return syntax.Marshal(list)
}
func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) {
var list serverNameListInner
read, err := syntax.Unmarshal(data, &list)
if err != nil {
return 0, err
}
// Syntax requires at least one entry
// Entries beyond the first are ignored
if nameType := list.ServerNameList[0].NameType; nameType != 0x00 {
return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType)
}
*sni = ServerNameExtension(list.ServerNameList[0].HostName)
return read, nil
}
// struct {
// NamedGroup group;
// opaque key_exchange<1..2^16-1>;
// } KeyShareEntry;
//
// struct {
// select (Handshake.msg_type) {
// case client_hello:
// KeyShareEntry client_shares<0..2^16-1>;
//
// case hello_retry_request:
// NamedGroup selected_group;
//
// case server_hello:
// KeyShareEntry server_share;
// };
// } KeyShare;
type KeyShareEntry struct {
Group NamedGroup
KeyExchange []byte `tls:"head=2,min=1"`
}
func (kse KeyShareEntry) SizeValid() bool {
return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group)
}
type KeyShareExtension struct {
HandshakeType HandshakeType
SelectedGroup NamedGroup
Shares []KeyShareEntry
}
type KeyShareClientHelloInner struct {
ClientShares []KeyShareEntry `tls:"head=2,min=0"`
}
type KeyShareHelloRetryInner struct {
SelectedGroup NamedGroup
}
type KeyShareServerHelloInner struct {
ServerShare KeyShareEntry
}
func (ks KeyShareExtension) Type() ExtensionType {
return ExtensionTypeKeyShare
}
func (ks KeyShareExtension) Marshal() ([]byte, error) {
switch ks.HandshakeType {
case HandshakeTypeClientHello:
for _, share := range ks.Shares {
if !share.SizeValid() {
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
}
}
return syntax.Marshal(KeyShareClientHelloInner{ks.Shares})
case HandshakeTypeHelloRetryRequest:
if len(ks.Shares) > 0 {
return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest")
}
return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup})
case HandshakeTypeServerHello:
if len(ks.Shares) != 1 {
return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share")
}
if !ks.Shares[0].SizeValid() {
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
}
return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]})
default:
return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed")
}
}
func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) {
switch ks.HandshakeType {
case HandshakeTypeClientHello:
var inner KeyShareClientHelloInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
for _, share := range inner.ClientShares {
if !share.SizeValid() {
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
}
}
ks.Shares = inner.ClientShares
return read, nil
case HandshakeTypeHelloRetryRequest:
var inner KeyShareHelloRetryInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
ks.SelectedGroup = inner.SelectedGroup
return read, nil
case HandshakeTypeServerHello:
var inner KeyShareServerHelloInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
if !inner.ServerShare.SizeValid() {
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
}
ks.Shares = []KeyShareEntry{inner.ServerShare}
return read, nil
default:
return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed")
}
}
// struct {
// NamedGroup named_group_list<2..2^16-1>;
// } NamedGroupList;
type SupportedGroupsExtension struct {
Groups []NamedGroup `tls:"head=2,min=2"`
}
func (sg SupportedGroupsExtension) Type() ExtensionType {
return ExtensionTypeSupportedGroups
}
func (sg SupportedGroupsExtension) Marshal() ([]byte, error) {
return syntax.Marshal(sg)
}
func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sg)
}
// struct {
// SignatureScheme supported_signature_algorithms<2..2^16-2>;
// } SignatureSchemeList
type SignatureAlgorithmsExtension struct {
Algorithms []SignatureScheme `tls:"head=2,min=2"`
}
func (sa SignatureAlgorithmsExtension) Type() ExtensionType {
return ExtensionTypeSignatureAlgorithms
}
func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) {
return syntax.Marshal(sa)
}
func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sa)
}
// struct {
// opaque identity<1..2^16-1>;
// uint32 obfuscated_ticket_age;
// } PskIdentity;
//
// opaque PskBinderEntry<32..255>;
//
// struct {
// select (Handshake.msg_type) {
// case client_hello:
// PskIdentity identities<7..2^16-1>;
// PskBinderEntry binders<33..2^16-1>;
//
// case server_hello:
// uint16 selected_identity;
// };
//
// } PreSharedKeyExtension;
type PSKIdentity struct {
Identity []byte `tls:"head=2,min=1"`
ObfuscatedTicketAge uint32
}
type PSKBinderEntry struct {
Binder []byte `tls:"head=1,min=32"`
}
type PreSharedKeyExtension struct {
HandshakeType HandshakeType
Identities []PSKIdentity
Binders []PSKBinderEntry
SelectedIdentity uint16
}
type preSharedKeyClientInner struct {
Identities []PSKIdentity `tls:"head=2,min=7"`
Binders []PSKBinderEntry `tls:"head=2,min=33"`
}
type preSharedKeyServerInner struct {
SelectedIdentity uint16
}
func (psk PreSharedKeyExtension) Type() ExtensionType {
return ExtensionTypePreSharedKey
}
func (psk PreSharedKeyExtension) Marshal() ([]byte, error) {
switch psk.HandshakeType {
case HandshakeTypeClientHello:
return syntax.Marshal(preSharedKeyClientInner{
Identities: psk.Identities,
Binders: psk.Binders,
})
case HandshakeTypeServerHello:
if len(psk.Identities) > 0 || len(psk.Binders) > 0 {
return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index")
}
return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity})
default:
return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported")
}
}
func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) {
switch psk.HandshakeType {
case HandshakeTypeClientHello:
var inner preSharedKeyClientInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
if len(inner.Identities) != len(inner.Binders) {
return 0, fmt.Errorf("Lengths of identities and binders not equal")
}
psk.Identities = inner.Identities
psk.Binders = inner.Binders
return read, nil
case HandshakeTypeServerHello:
var inner preSharedKeyServerInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
psk.SelectedIdentity = inner.SelectedIdentity
return read, nil
default:
return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported")
}
}
func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) {
for i, localID := range psk.Identities {
if bytes.Equal(localID.Identity, id) {
return psk.Binders[i].Binder, true
}
}
return nil, false
}
// enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode;
//
// struct {
// PskKeyExchangeMode ke_modes<1..255>;
// } PskKeyExchangeModes;
type PSKKeyExchangeModesExtension struct {
KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"`
}
func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType {
return ExtensionTypePSKKeyExchangeModes
}
func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) {
return syntax.Marshal(pkem)
}
func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, pkem)
}
// struct {
// } EarlyDataIndication;
type EarlyDataExtension struct{}
func (ed EarlyDataExtension) Type() ExtensionType {
return ExtensionTypeEarlyData
}
func (ed EarlyDataExtension) Marshal() ([]byte, error) {
return []byte{}, nil
}
func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) {
return 0, nil
}
// struct {
// uint32 max_early_data_size;
// } TicketEarlyDataInfo;
type TicketEarlyDataInfoExtension struct {
MaxEarlyDataSize uint32
}
func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType {
return ExtensionTypeTicketEarlyDataInfo
}
func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) {
return syntax.Marshal(tedi)
}
func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, tedi)
}
// opaque ProtocolName<1..2^8-1>;
//
// struct {
// ProtocolName protocol_name_list<2..2^16-1>
// } ProtocolNameList;
type ALPNExtension struct {
Protocols []string
}
type protocolNameInner struct {
Name []byte `tls:"head=1,min=1"`
}
type alpnExtensionInner struct {
Protocols []protocolNameInner `tls:"head=2,min=2"`
}
func (alpn ALPNExtension) Type() ExtensionType {
return ExtensionTypeALPN
}
func (alpn ALPNExtension) Marshal() ([]byte, error) {
protocols := make([]protocolNameInner, len(alpn.Protocols))
for i, protocol := range alpn.Protocols {
protocols[i] = protocolNameInner{[]byte(protocol)}
}
return syntax.Marshal(alpnExtensionInner{protocols})
}
func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) {
var inner alpnExtensionInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
alpn.Protocols = make([]string, len(inner.Protocols))
for i, protocol := range inner.Protocols {
alpn.Protocols[i] = string(protocol.Name)
}
return read, nil
}
// struct {
// ProtocolVersion versions<2..254>;
// } SupportedVersions;
type SupportedVersionsExtension struct {
Versions []uint16 `tls:"head=1,min=2,max=254"`
}
func (sv SupportedVersionsExtension) Type() ExtensionType {
return ExtensionTypeSupportedVersions
}
func (sv SupportedVersionsExtension) Marshal() ([]byte, error) {
return syntax.Marshal(sv)
}
func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sv)
}
// struct {
// opaque cookie<1..2^16-1>;
// } Cookie;
type CookieExtension struct {
Cookie []byte `tls:"head=2,min=1"`
}
func (c CookieExtension) Type() ExtensionType {
return ExtensionTypeCookie
}
func (c CookieExtension) Marshal() ([]byte, error) {
return syntax.Marshal(c)
}
func (c *CookieExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, c)
}
// defaultCookieLength is the default length of a cookie
const defaultCookieLength = 32
type defaultCookieHandler struct {
data []byte
}
var _ CookieHandler = &defaultCookieHandler{}
// NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data
func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) {
h.data = make([]byte, defaultCookieLength)
if _, err := prng.Read(h.data); err != nil {
return nil, err
}
return h.data, nil
}
func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool {
return bytes.Equal(h.data, data)
}

View file

@ -1,147 +0,0 @@
package mint
import (
"encoding/hex"
"math/big"
)
var (
finiteFieldPrime2048hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B423861285C97FFFFFFFFFFFFFFFF"
finiteFieldPrime2048bytes, _ = hex.DecodeString(finiteFieldPrime2048hex)
finiteFieldPrime2048 = big.NewInt(0).SetBytes(finiteFieldPrime2048bytes)
finiteFieldPrime3072hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B66C62E37FFFFFFFFFFFFFFFF"
finiteFieldPrime3072bytes, _ = hex.DecodeString(finiteFieldPrime3072hex)
finiteFieldPrime3072 = big.NewInt(0).SetBytes(finiteFieldPrime3072bytes)
finiteFieldPrime4096hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E655F6A" +
"FFFFFFFFFFFFFFFF"
finiteFieldPrime4096bytes, _ = hex.DecodeString(finiteFieldPrime4096hex)
finiteFieldPrime4096 = big.NewInt(0).SetBytes(finiteFieldPrime4096bytes)
finiteFieldPrime6144hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
"A41D570D7938DAD4A40E329CD0E40E65FFFFFFFFFFFFFFFF"
finiteFieldPrime6144bytes, _ = hex.DecodeString(finiteFieldPrime6144hex)
finiteFieldPrime6144 = big.NewInt(0).SetBytes(finiteFieldPrime6144bytes)
finiteFieldPrime8192hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
"A41D570D7938DAD4A40E329CCFF46AAA36AD004CF600C838" +
"1E425A31D951AE64FDB23FCEC9509D43687FEB69EDD1CC5E" +
"0B8CC3BDF64B10EF86B63142A3AB8829555B2F747C932665" +
"CB2C0F1CC01BD70229388839D2AF05E454504AC78B758282" +
"2846C0BA35C35F5C59160CC046FD8251541FC68C9C86B022" +
"BB7099876A460E7451A8A93109703FEE1C217E6C3826E52C" +
"51AA691E0E423CFC99E9E31650C1217B624816CDAD9A95F9" +
"D5B8019488D9C0A0A1FE3075A577E23183F81D4A3F2FA457" +
"1EFC8CE0BA8A4FE8B6855DFE72B0A66EDED2FBABFBE58A30" +
"FAFABE1C5D71A87E2F741EF8C1FE86FEA6BBFDE530677F0D" +
"97D11D49F7A8443D0822E506A9F4614E011E2A94838FF88C" +
"D68C8BB7C5C6424CFFFFFFFFFFFFFFFF"
finiteFieldPrime8192bytes, _ = hex.DecodeString(finiteFieldPrime8192hex)
finiteFieldPrime8192 = big.NewInt(0).SetBytes(finiteFieldPrime8192bytes)
)

View file

@ -1,98 +0,0 @@
// Read a generic "framed" packet consisting of a header and a
// This is used for both TLS Records and TLS Handshake Messages
package mint
type framing interface {
headerLen() int
defaultReadLen() int
frameLen(hdr []byte) (int, error)
}
const (
kFrameReaderHdr = 0
kFrameReaderBody = 1
)
type frameNextAction func(f *frameReader) error
type frameReader struct {
details framing
state uint8
header []byte
body []byte
working []byte
writeOffset int
remainder []byte
}
func newFrameReader(d framing) *frameReader {
hdr := make([]byte, d.headerLen())
return &frameReader{
d,
kFrameReaderHdr,
hdr,
nil,
hdr,
0,
nil,
}
}
func dup(a []byte) []byte {
r := make([]byte, len(a))
copy(r, a)
return r
}
func (f *frameReader) needed() int {
tmp := (len(f.working) - f.writeOffset) - len(f.remainder)
if tmp < 0 {
return 0
}
return tmp
}
func (f *frameReader) addChunk(in []byte) {
// Append to the buffer.
logf(logTypeFrameReader, "Appending %v", len(in))
f.remainder = append(f.remainder, in...)
}
func (f *frameReader) process() (hdr []byte, body []byte, err error) {
for f.needed() == 0 {
logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset)
// Fill out our working block
copied := copy(f.working[f.writeOffset:], f.remainder)
f.remainder = f.remainder[copied:]
f.writeOffset += copied
if f.writeOffset < len(f.working) {
logf(logTypeFrameReader, "Read would have blocked 1")
return nil, nil, WouldBlock
}
// Reset the write offset, because we are now full.
f.writeOffset = 0
// We have read a full frame
if f.state == kFrameReaderBody {
logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder))
f.state = kFrameReaderHdr
f.working = f.header
return dup(f.header), dup(f.body), nil
}
// We have read the header
bodyLen, err := f.details.frameLen(f.header)
if err != nil {
return nil, nil, err
}
logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen)
f.body = make([]byte, bodyLen)
f.working = f.body
f.writeOffset = 0
f.state = kFrameReaderBody
}
logf(logTypeFrameReader, "Read would have blocked 2")
return nil, nil, WouldBlock
}

View file

@ -1,253 +0,0 @@
package mint
import (
"fmt"
"io"
"net"
)
const (
handshakeHeaderLen = 4 // handshake message header length
maxHandshakeMessageLen = 1 << 24 // max handshake message length
)
// struct {
// HandshakeType msg_type; /* handshake type */
// uint24 length; /* bytes in message */
// select (HandshakeType) {
// ...
// } body;
// } Handshake;
//
// We do the select{...} part in a different layer, so we treat the
// actual message body as opaque:
//
// struct {
// HandshakeType msg_type;
// opaque msg<0..2^24-1>
// } Handshake;
//
// TODO: File a spec bug
type HandshakeMessage struct {
// Omitted: length
msgType HandshakeType
body []byte
}
// Note: This could be done with the `syntax` module, using the simplified
// syntax as discussed above. However, since this is so simple, there's not
// much benefit to doing so.
func (hm *HandshakeMessage) Marshal() []byte {
if hm == nil {
return []byte{}
}
msgLen := len(hm.body)
data := make([]byte, 4+len(hm.body))
data[0] = byte(hm.msgType)
data[1] = byte(msgLen >> 16)
data[2] = byte(msgLen >> 8)
data[3] = byte(msgLen)
copy(data[4:], hm.body)
return data
}
func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body)
var body HandshakeMessageBody
switch hm.msgType {
case HandshakeTypeClientHello:
body = new(ClientHelloBody)
case HandshakeTypeServerHello:
body = new(ServerHelloBody)
case HandshakeTypeHelloRetryRequest:
body = new(HelloRetryRequestBody)
case HandshakeTypeEncryptedExtensions:
body = new(EncryptedExtensionsBody)
case HandshakeTypeCertificate:
body = new(CertificateBody)
case HandshakeTypeCertificateRequest:
body = new(CertificateRequestBody)
case HandshakeTypeCertificateVerify:
body = new(CertificateVerifyBody)
case HandshakeTypeFinished:
body = &FinishedBody{VerifyDataLen: len(hm.body)}
case HandshakeTypeNewSessionTicket:
body = new(NewSessionTicketBody)
case HandshakeTypeKeyUpdate:
body = new(KeyUpdateBody)
case HandshakeTypeEndOfEarlyData:
body = new(EndOfEarlyDataBody)
default:
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
}
_, err := body.Unmarshal(hm.body)
return body, err
}
func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
data, err := body.Marshal()
if err != nil {
return nil, err
}
return &HandshakeMessage{
msgType: body.Type(),
body: data,
}, nil
}
type HandshakeLayer struct {
nonblocking bool // Should we operate in nonblocking mode
conn *RecordLayer // Used for reading/writing records
frame *frameReader // The buffered frame reader
}
type handshakeLayerFrameDetails struct{}
func (d handshakeLayerFrameDetails) headerLen() int {
return handshakeHeaderLen
}
func (d handshakeLayerFrameDetails) defaultReadLen() int {
return handshakeHeaderLen + maxFragmentLen
}
func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
logf(logTypeIO, "Header=%x", hdr)
return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil
}
func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer {
h := HandshakeLayer{}
h.conn = r
h.frame = newFrameReader(&handshakeLayerFrameDetails{})
return &h
}
func (h *HandshakeLayer) readRecord() error {
logf(logTypeIO, "Trying to read record")
pt, err := h.conn.ReadRecord()
if err != nil {
return err
}
if pt.contentType != RecordTypeHandshake &&
pt.contentType != RecordTypeAlert {
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
}
if pt.contentType == RecordTypeAlert {
logf(logTypeIO, "read alert %v", pt.fragment[1])
if len(pt.fragment) < 2 {
h.sendAlert(AlertUnexpectedMessage)
return io.EOF
}
return Alert(pt.fragment[1])
}
logf(logTypeIO, "read handshake record of len %v", len(pt.fragment))
h.frame.addChunk(pt.fragment)
return nil
}
// sendAlert sends a TLS alert message.
func (h *HandshakeLayer) sendAlert(err Alert) error {
tmp := make([]byte, 2)
tmp[0] = AlertLevelError
tmp[1] = byte(err)
h.conn.WriteRecord(&TLSPlaintext{
contentType: RecordTypeAlert,
fragment: tmp},
)
// closeNotify is a special case in that it isn't an error:
if err != AlertCloseNotify {
return &net.OpError{Op: "local error", Err: err}
}
return nil
}
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
var hdr, body []byte
var err error
for {
logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder))
if h.frame.needed() > 0 {
logf(logTypeHandshake, "Trying to read a new record")
err = h.readRecord()
}
if err != nil && (h.nonblocking || err != WouldBlock) {
return nil, err
}
hdr, body, err = h.frame.process()
if err == nil {
break
}
if err != nil && (h.nonblocking || err != WouldBlock) {
return nil, err
}
}
logf(logTypeHandshake, "read handshake message")
hm := &HandshakeMessage{}
hm.msgType = HandshakeType(hdr[0])
hm.body = make([]byte, len(body))
copy(hm.body, body)
return hm, nil
}
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error {
return h.WriteMessages([]*HandshakeMessage{hm})
}
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error {
for _, hm := range hms {
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
}
// Write out headers and bodies
buffer := []byte{}
for _, msg := range hms {
msgLen := len(msg.body)
if msgLen > maxHandshakeMessageLen {
return fmt.Errorf("tls.handshakelayer: Message too large to send")
}
buffer = append(buffer, msg.Marshal()...)
}
// Send full-size fragments
var start int
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
err := h.conn.WriteRecord(&TLSPlaintext{
contentType: RecordTypeHandshake,
fragment: buffer[start : start+maxFragmentLen],
})
if err != nil {
return err
}
}
// Send a final partial fragment if necessary
if start < len(buffer) {
err := h.conn.WriteRecord(&TLSPlaintext{
contentType: RecordTypeHandshake,
fragment: buffer[start:],
})
if err != nil {
return err
}
}
return nil
}

View file

@ -1,450 +0,0 @@
package mint
import (
"bytes"
"crypto"
"crypto/x509"
"encoding/binary"
"fmt"
"github.com/bifurcation/mint/syntax"
)
type HandshakeMessageBody interface {
Type() HandshakeType
Marshal() ([]byte, error)
Unmarshal(data []byte) (int, error)
}
// struct {
// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */
// Random random;
// opaque legacy_session_id<0..32>;
// CipherSuite cipher_suites<2..2^16-2>;
// opaque legacy_compression_methods<1..2^8-1>;
// Extension extensions<0..2^16-1>;
// } ClientHello;
type ClientHelloBody struct {
// Omitted: clientVersion
// Omitted: legacySessionID
// Omitted: legacyCompressionMethods
Random [32]byte
CipherSuites []CipherSuite
Extensions ExtensionList
}
type clientHelloBodyInner struct {
LegacyVersion uint16
Random [32]byte
LegacySessionID []byte `tls:"head=1,max=32"`
CipherSuites []CipherSuite `tls:"head=2,min=2"`
LegacyCompressionMethods []byte `tls:"head=1,min=1"`
Extensions []Extension `tls:"head=2"`
}
func (ch ClientHelloBody) Type() HandshakeType {
return HandshakeTypeClientHello
}
func (ch ClientHelloBody) Marshal() ([]byte, error) {
return syntax.Marshal(clientHelloBodyInner{
LegacyVersion: 0x0303,
Random: ch.Random,
LegacySessionID: []byte{},
CipherSuites: ch.CipherSuites,
LegacyCompressionMethods: []byte{0},
Extensions: ch.Extensions,
})
}
func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
var inner clientHelloBodyInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
// We are strict about these things because we only support 1.3
if inner.LegacyVersion != 0x0303 {
return 0, fmt.Errorf("tls.clienthello: Incorrect version number")
}
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
}
ch.Random = inner.Random
ch.CipherSuites = inner.CipherSuites
ch.Extensions = inner.Extensions
return read, nil
}
// TODO: File a spec bug to clarify this
func (ch ClientHelloBody) Truncated() ([]byte, error) {
if len(ch.Extensions) == 0 {
return nil, fmt.Errorf("tls.clienthello.truncate: No extensions")
}
pskExt := ch.Extensions[len(ch.Extensions)-1]
if pskExt.ExtensionType != ExtensionTypePreSharedKey {
return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK")
}
chm, err := HandshakeMessageFromBody(&ch)
if err != nil {
return nil, err
}
chData := chm.Marshal()
psk := PreSharedKeyExtension{
HandshakeType: HandshakeTypeClientHello,
}
_, err = psk.Unmarshal(pskExt.ExtensionData)
if err != nil {
return nil, err
}
// Marshal just the binders so that we know how much to truncate
binders := struct {
Binders []PSKBinderEntry `tls:"head=2,min=33"`
}{Binders: psk.Binders}
binderData, _ := syntax.Marshal(binders)
binderLen := len(binderData)
chLen := len(chData)
return chData[:chLen-binderLen], nil
}
// struct {
// ProtocolVersion server_version;
// CipherSuite cipher_suite;
// Extension extensions<2..2^16-1>;
// } HelloRetryRequest;
type HelloRetryRequestBody struct {
Version uint16
CipherSuite CipherSuite
Extensions ExtensionList `tls:"head=2,min=2"`
}
func (hrr HelloRetryRequestBody) Type() HandshakeType {
return HandshakeTypeHelloRetryRequest
}
func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) {
return syntax.Marshal(hrr)
}
func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, hrr)
}
// struct {
// ProtocolVersion version;
// Random random;
// CipherSuite cipher_suite;
// Extension extensions<0..2^16-1>;
// } ServerHello;
type ServerHelloBody struct {
Version uint16
Random [32]byte
CipherSuite CipherSuite
Extensions ExtensionList `tls:"head=2"`
}
func (sh ServerHelloBody) Type() HandshakeType {
return HandshakeTypeServerHello
}
func (sh ServerHelloBody) Marshal() ([]byte, error) {
return syntax.Marshal(sh)
}
func (sh *ServerHelloBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sh)
}
// struct {
// opaque verify_data[verify_data_length];
// } Finished;
//
// verifyDataLen is not a field in the TLS struct, but we add it here so
// that calling code can tell us how much data to expect when we marshal /
// unmarshal. (We could add this to the marshal/unmarshal methods, but let's
// try to keep the signature consistent for now.)
//
// For similar reasons, we don't use the `syntax` module here, because this
// struct doesn't map well to standard TLS presentation language concepts.
//
// TODO: File a spec bug
type FinishedBody struct {
VerifyDataLen int
VerifyData []byte
}
func (fin FinishedBody) Type() HandshakeType {
return HandshakeTypeFinished
}
func (fin FinishedBody) Marshal() ([]byte, error) {
if len(fin.VerifyData) != fin.VerifyDataLen {
return nil, fmt.Errorf("tls.finished: data length mismatch")
}
body := make([]byte, len(fin.VerifyData))
copy(body, fin.VerifyData)
return body, nil
}
func (fin *FinishedBody) Unmarshal(data []byte) (int, error) {
if len(data) < fin.VerifyDataLen {
return 0, fmt.Errorf("tls.finished: Malformed finished; too short")
}
fin.VerifyData = make([]byte, fin.VerifyDataLen)
copy(fin.VerifyData, data[:fin.VerifyDataLen])
return fin.VerifyDataLen, nil
}
// struct {
// Extension extensions<0..2^16-1>;
// } EncryptedExtensions;
//
// Marshal() and Unmarshal() are handled by ExtensionList
type EncryptedExtensionsBody struct {
Extensions ExtensionList `tls:"head=2"`
}
func (ee EncryptedExtensionsBody) Type() HandshakeType {
return HandshakeTypeEncryptedExtensions
}
func (ee EncryptedExtensionsBody) Marshal() ([]byte, error) {
return syntax.Marshal(ee)
}
func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, ee)
}
// opaque ASN1Cert<1..2^24-1>;
//
// struct {
// ASN1Cert cert_data;
// Extension extensions<0..2^16-1>
// } CertificateEntry;
//
// struct {
// opaque certificate_request_context<0..2^8-1>;
// CertificateEntry certificate_list<0..2^24-1>;
// } Certificate;
type CertificateEntry struct {
CertData *x509.Certificate
Extensions ExtensionList
}
type CertificateBody struct {
CertificateRequestContext []byte
CertificateList []CertificateEntry
}
type certificateEntryInner struct {
CertData []byte `tls:"head=3,min=1"`
Extensions ExtensionList `tls:"head=2"`
}
type certificateBodyInner struct {
CertificateRequestContext []byte `tls:"head=1"`
CertificateList []certificateEntryInner `tls:"head=3"`
}
func (c CertificateBody) Type() HandshakeType {
return HandshakeTypeCertificate
}
func (c CertificateBody) Marshal() ([]byte, error) {
inner := certificateBodyInner{
CertificateRequestContext: c.CertificateRequestContext,
CertificateList: make([]certificateEntryInner, len(c.CertificateList)),
}
for i, entry := range c.CertificateList {
inner.CertificateList[i] = certificateEntryInner{
CertData: entry.CertData.Raw,
Extensions: entry.Extensions,
}
}
return syntax.Marshal(inner)
}
func (c *CertificateBody) Unmarshal(data []byte) (int, error) {
inner := certificateBodyInner{}
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return read, err
}
c.CertificateRequestContext = inner.CertificateRequestContext
c.CertificateList = make([]CertificateEntry, len(inner.CertificateList))
for i, entry := range inner.CertificateList {
c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData)
if err != nil {
return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err)
}
c.CertificateList[i].Extensions = entry.Extensions
}
return read, nil
}
// struct {
// SignatureScheme algorithm;
// opaque signature<0..2^16-1>;
// } CertificateVerify;
type CertificateVerifyBody struct {
Algorithm SignatureScheme
Signature []byte `tls:"head=2"`
}
func (cv CertificateVerifyBody) Type() HandshakeType {
return HandshakeTypeCertificateVerify
}
func (cv CertificateVerifyBody) Marshal() ([]byte, error) {
return syntax.Marshal(cv)
}
func (cv *CertificateVerifyBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, cv)
}
func (cv *CertificateVerifyBody) EncodeSignatureInput(data []byte) []byte {
// TODO: Change context for client auth
// TODO: Put this in a const
const context = "TLS 1.3, server CertificateVerify"
sigInput := bytes.Repeat([]byte{0x20}, 64)
sigInput = append(sigInput, []byte(context)...)
sigInput = append(sigInput, []byte{0}...)
sigInput = append(sigInput, data...)
return sigInput
}
func (cv *CertificateVerifyBody) Sign(privateKey crypto.Signer, handshakeHash []byte) (err error) {
sigInput := cv.EncodeSignatureInput(handshakeHash)
cv.Signature, err = sign(cv.Algorithm, privateKey, sigInput)
logf(logTypeHandshake, "Signed: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
return
}
func (cv *CertificateVerifyBody) Verify(publicKey crypto.PublicKey, handshakeHash []byte) error {
sigInput := cv.EncodeSignatureInput(handshakeHash)
logf(logTypeHandshake, "About to verify: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
return verify(cv.Algorithm, publicKey, sigInput, cv.Signature)
}
// struct {
// opaque certificate_request_context<0..2^8-1>;
// Extension extensions<2..2^16-1>;
// } CertificateRequest;
type CertificateRequestBody struct {
CertificateRequestContext []byte `tls:"head=1"`
Extensions ExtensionList `tls:"head=2"`
}
func (cr CertificateRequestBody) Type() HandshakeType {
return HandshakeTypeCertificateRequest
}
func (cr CertificateRequestBody) Marshal() ([]byte, error) {
return syntax.Marshal(cr)
}
func (cr *CertificateRequestBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, cr)
}
// struct {
// uint32 ticket_lifetime;
// uint32 ticket_age_add;
// opaque ticket_nonce<1..255>;
// opaque ticket<1..2^16-1>;
// Extension extensions<0..2^16-2>;
// } NewSessionTicket;
type NewSessionTicketBody struct {
TicketLifetime uint32
TicketAgeAdd uint32
TicketNonce []byte `tls:"head=1,min=1"`
Ticket []byte `tls:"head=2,min=1"`
Extensions ExtensionList `tls:"head=2"`
}
const ticketNonceLen = 16
func NewSessionTicket(ticketLen int, ticketLifetime uint32) (*NewSessionTicketBody, error) {
buf := make([]byte, 4+ticketNonceLen+ticketLen)
_, err := prng.Read(buf)
if err != nil {
return nil, err
}
tkt := &NewSessionTicketBody{
TicketLifetime: ticketLifetime,
TicketAgeAdd: binary.BigEndian.Uint32(buf[:4]),
TicketNonce: buf[4 : 4+ticketNonceLen],
Ticket: buf[4+ticketNonceLen:],
}
return tkt, err
}
func (tkt NewSessionTicketBody) Type() HandshakeType {
return HandshakeTypeNewSessionTicket
}
func (tkt NewSessionTicketBody) Marshal() ([]byte, error) {
return syntax.Marshal(tkt)
}
func (tkt *NewSessionTicketBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, tkt)
}
// enum {
// update_not_requested(0), update_requested(1), (255)
// } KeyUpdateRequest;
//
// struct {
// KeyUpdateRequest request_update;
// } KeyUpdate;
type KeyUpdateBody struct {
KeyUpdateRequest KeyUpdateRequest
}
func (ku KeyUpdateBody) Type() HandshakeType {
return HandshakeTypeKeyUpdate
}
func (ku KeyUpdateBody) Marshal() ([]byte, error) {
return syntax.Marshal(ku)
}
func (ku *KeyUpdateBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, ku)
}
// struct {} EndOfEarlyData;
type EndOfEarlyDataBody struct{}
func (eoed EndOfEarlyDataBody) Type() HandshakeType {
return HandshakeTypeEndOfEarlyData
}
func (eoed EndOfEarlyDataBody) Marshal() ([]byte, error) {
return []byte{}, nil
}
func (eoed *EndOfEarlyDataBody) Unmarshal(data []byte) (int, error) {
return 0, nil
}

View file

@ -1,55 +0,0 @@
package mint
import (
"fmt"
"log"
"os"
"strings"
)
// We use this environment variable to control logging. It should be a
// comma-separated list of log tags (see below) or "*" to enable all logging.
const logConfigVar = "MINT_LOG"
// Pre-defined log types
const (
logTypeCrypto = "crypto"
logTypeHandshake = "handshake"
logTypeNegotiation = "negotiation"
logTypeIO = "io"
logTypeFrameReader = "frame"
logTypeVerbose = "verbose"
)
var (
logFunction = log.Printf
logAll = false
logSettings = map[string]bool{}
)
func init() {
parseLogEnv(os.Environ())
}
func parseLogEnv(env []string) {
for _, stmt := range env {
if strings.HasPrefix(stmt, logConfigVar+"=") {
val := stmt[len(logConfigVar)+1:]
if val == "*" {
logAll = true
} else {
for _, t := range strings.Split(val, ",") {
logSettings[t] = true
}
}
}
}
}
func logf(tag string, format string, args ...interface{}) {
if logAll || logSettings[tag] {
fullFormat := fmt.Sprintf("[%s] %s", tag, format)
logFunction(fullFormat, args...)
}
}

View file

@ -1,217 +0,0 @@
package mint
import (
"bytes"
"encoding/hex"
"fmt"
"time"
)
func VersionNegotiation(offered, supported []uint16) (bool, uint16) {
for _, offeredVersion := range offered {
for _, supportedVersion := range supported {
logf(logTypeHandshake, "[server] version offered by client [%04x] <> [%04x]", offeredVersion, supportedVersion)
if offeredVersion == supportedVersion {
// XXX: Should probably be highest supported version, but for now, we
// only support one version, so it doesn't really matter.
return true, offeredVersion
}
}
}
return false, 0
}
func DHNegotiation(keyShares []KeyShareEntry, groups []NamedGroup) (bool, NamedGroup, []byte, []byte) {
for _, share := range keyShares {
for _, group := range groups {
if group != share.Group {
continue
}
pub, priv, err := newKeyShare(share.Group)
if err != nil {
// If we encounter an error, just keep looking
continue
}
dhSecret, err := keyAgreement(share.Group, share.KeyExchange, priv)
if err != nil {
// If we encounter an error, just keep looking
continue
}
return true, group, pub, dhSecret
}
}
return false, 0, nil, nil
}
const (
ticketAgeTolerance uint32 = 5 * 1000 // five seconds in milliseconds
)
func PSKNegotiation(identities []PSKIdentity, binders []PSKBinderEntry, context []byte, psks PreSharedKeyCache) (bool, int, *PreSharedKey, CipherSuiteParams, error) {
logf(logTypeNegotiation, "Negotiating PSK offered=[%d] supported=[%d]", len(identities), psks.Size())
for i, id := range identities {
identityHex := hex.EncodeToString(id.Identity)
psk, ok := psks.Get(identityHex)
if !ok {
logf(logTypeNegotiation, "No PSK for identity %x", identityHex)
continue
}
// For resumption, make sure the ticket age is correct
if psk.IsResumption {
extTicketAge := id.ObfuscatedTicketAge - psk.TicketAgeAdd
knownTicketAge := uint32(time.Since(psk.ReceivedAt) / time.Millisecond)
ticketAgeDelta := knownTicketAge - extTicketAge
if knownTicketAge < extTicketAge {
ticketAgeDelta = extTicketAge - knownTicketAge
}
if ticketAgeDelta > ticketAgeTolerance {
logf(logTypeNegotiation, "WARNING potential replay [%x]", psk.Identity)
logf(logTypeNegotiation, "Ticket age exceeds tolerance |%d - %d| = [%d] > [%d]",
extTicketAge, knownTicketAge, ticketAgeDelta, ticketAgeTolerance)
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("WARNING Potential replay for identity %x", psk.Identity)
}
}
params, ok := cipherSuiteMap[psk.CipherSuite]
if !ok {
err := fmt.Errorf("tls.cryptoinit: Unsupported ciphersuite from PSK [%04x]", psk.CipherSuite)
return false, 0, nil, CipherSuiteParams{}, err
}
// Compute binder
binderLabel := labelExternalBinder
if psk.IsResumption {
binderLabel = labelResumptionBinder
}
h0 := params.Hash.New().Sum(nil)
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
earlySecret := HkdfExtract(params.Hash, zero, psk.Key)
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
// context = ClientHello[truncated]
// context = ClientHello1 + HelloRetryRequest + ClientHello2[truncated]
ctxHash := params.Hash.New()
ctxHash.Write(context)
binder := computeFinishedData(params, binderKey, ctxHash.Sum(nil))
if !bytes.Equal(binder, binders[i].Binder) {
logf(logTypeNegotiation, "Binder check failed for identity %x; [%x] != [%x]", psk.Identity, binder, binders[i].Binder)
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("Binder check failed identity %x", psk.Identity)
}
logf(logTypeNegotiation, "Using PSK with identity %x", psk.Identity)
return true, i, &psk, params, nil
}
logf(logTypeNegotiation, "Failed to find a usable PSK")
return false, 0, nil, CipherSuiteParams{}, nil
}
func PSKModeNegotiation(canDoDH, canDoPSK bool, modes []PSKKeyExchangeMode) (bool, bool) {
logf(logTypeNegotiation, "Negotiating PSK modes [%v] [%v] [%+v]", canDoDH, canDoPSK, modes)
dhAllowed := false
dhRequired := true
for _, mode := range modes {
dhAllowed = dhAllowed || (mode == PSKModeDHEKE)
dhRequired = dhRequired && (mode == PSKModeDHEKE)
}
// Use PSK if we can meet DH requirement and modes were provided
usingPSK := canDoPSK && (!dhRequired || canDoDH) && (len(modes) > 0)
// Use DH if allowed
usingDH := canDoDH && (dhAllowed || !usingPSK)
logf(logTypeNegotiation, "Results of PSK mode negotiation: usingDH=[%v] usingPSK=[%v]", usingDH, usingPSK)
return usingDH, usingPSK
}
func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme, certs []*Certificate) (*Certificate, SignatureScheme, error) {
// Select for server name if provided
candidates := certs
if serverName != nil {
candidatesByName := []*Certificate{}
for _, cert := range certs {
for _, name := range cert.Chain[0].DNSNames {
if len(*serverName) > 0 && name == *serverName {
candidatesByName = append(candidatesByName, cert)
}
}
}
if len(candidatesByName) == 0 {
return nil, 0, fmt.Errorf("No certificates available for server name")
}
candidates = candidatesByName
}
// Select for signature scheme
for _, cert := range candidates {
for _, scheme := range signatureSchemes {
if !schemeValidForKey(scheme, cert.PrivateKey) {
continue
}
return cert, scheme, nil
}
}
return nil, 0, fmt.Errorf("No certificates compatible with signature schemes")
}
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool {
usingEarlyData := gotEarlyData && usingPSK && allowEarlyData
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData)
return usingEarlyData
}
func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) {
for _, s1 := range offered {
if psk != nil {
if s1 == psk.CipherSuite {
return s1, nil
}
continue
}
for _, s2 := range supported {
if s1 == s2 {
return s1, nil
}
}
}
return 0, fmt.Errorf("No overlap between offered and supproted ciphersuites (psk? [%v])", psk != nil)
}
func ALPNNegotiation(psk *PreSharedKey, offered, supported []string) (string, error) {
for _, p1 := range offered {
if psk != nil {
if p1 != psk.NextProto {
continue
}
}
for _, p2 := range supported {
if p1 == p2 {
return p1, nil
}
}
}
// If the client offers ALPN on resumption, it must match the earlier one
var err error
if psk != nil && psk.IsResumption && (len(offered) > 0) {
err = fmt.Errorf("ALPN for PSK not provided")
}
return "", err
}

View file

@ -1,296 +0,0 @@
package mint
import (
"bytes"
"crypto/cipher"
"fmt"
"io"
"sync"
)
const (
sequenceNumberLen = 8 // sequence number length
recordHeaderLen = 5 // record header length
maxFragmentLen = 1 << 14 // max number of bytes in a record
)
type DecryptError string
func (err DecryptError) Error() string {
return string(err)
}
// struct {
// ContentType type;
// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */
// uint16 length;
// opaque fragment[TLSPlaintext.length];
// } TLSPlaintext;
type TLSPlaintext struct {
// Omitted: record_version (static)
// Omitted: length (computed from fragment)
contentType RecordType
fragment []byte
}
type RecordLayer struct {
sync.Mutex
conn io.ReadWriter // The underlying connection
frame *frameReader // The buffered frame reader
nextData []byte // The next record to send
cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
cachedError error // Error on the last record read
ivLength int // Length of the seq and nonce fields
seq []byte // Zero-padded sequence number
nonce []byte // Buffer for per-record nonces
cipher cipher.AEAD // AEAD cipher
}
type recordLayerFrameDetails struct{}
func (d recordLayerFrameDetails) headerLen() int {
return recordHeaderLen
}
func (d recordLayerFrameDetails) defaultReadLen() int {
return recordHeaderLen + maxFragmentLen
}
func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
return (int(hdr[3]) << 8) | int(hdr[4]), nil
}
func NewRecordLayer(conn io.ReadWriter) *RecordLayer {
r := RecordLayer{}
r.conn = conn
r.frame = newFrameReader(recordLayerFrameDetails{})
r.ivLength = 0
return &r
}
func (r *RecordLayer) Rekey(cipher aeadFactory, key []byte, iv []byte) error {
var err error
r.cipher, err = cipher(key)
if err != nil {
return err
}
r.ivLength = len(iv)
r.seq = bytes.Repeat([]byte{0}, r.ivLength)
r.nonce = make([]byte, r.ivLength)
copy(r.nonce, iv)
return nil
}
func (r *RecordLayer) incrementSequenceNumber() {
if r.ivLength == 0 {
return
}
for i := r.ivLength - 1; i > r.ivLength-sequenceNumberLen; i-- {
r.seq[i]++
r.nonce[i] ^= (r.seq[i] - 1) ^ r.seq[i]
if r.seq[i] != 0 {
return
}
}
// Not allowed to let sequence number wrap.
// Instead, must renegotiate before it does.
// Not likely enough to bother.
panic("TLS: sequence number wraparound")
}
func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext {
// Expand the fragment to hold contentType, padding, and overhead
originalLen := len(pt.fragment)
plaintextLen := originalLen + 1 + padLen
ciphertextLen := plaintextLen + r.cipher.Overhead()
// Assemble the revised plaintext
out := &TLSPlaintext{
contentType: RecordTypeApplicationData,
fragment: make([]byte, ciphertextLen),
}
copy(out.fragment, pt.fragment)
out.fragment[originalLen] = byte(pt.contentType)
for i := 1; i <= padLen; i++ {
out.fragment[originalLen+i] = 0
}
// Encrypt the fragment
payload := out.fragment[:plaintextLen]
r.cipher.Seal(payload[:0], r.nonce, payload, nil)
return out
}
func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) {
if len(pt.fragment) < r.cipher.Overhead() {
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.Overhead())
return nil, 0, DecryptError(msg)
}
decryptLen := len(pt.fragment) - r.cipher.Overhead()
out := &TLSPlaintext{
contentType: pt.contentType,
fragment: make([]byte, decryptLen),
}
// Decrypt
_, err := r.cipher.Open(out.fragment[:0], r.nonce, pt.fragment, nil)
if err != nil {
return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
}
// Find the padding boundary
padLen := 0
for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ {
}
// Transfer the content type
newLen := decryptLen - padLen - 1
out.contentType = RecordType(out.fragment[newLen])
// Truncate the message to remove contentType, padding, overhead
out.fragment = out.fragment[:newLen]
return out, padLen, nil
}
func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
var pt *TLSPlaintext
var err error
for {
pt, err = r.nextRecord()
if err == nil {
break
}
if !block || err != WouldBlock {
return 0, err
}
}
return pt.contentType, nil
}
func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
pt, err := r.nextRecord()
// Consume the cached record if there was one
r.cachedRecord = nil
r.cachedError = nil
return pt, err
}
func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
if r.cachedRecord != nil {
logf(logTypeIO, "Returning cached record")
return r.cachedRecord, r.cachedError
}
// Loop until one of three things happens:
//
// 1. We get a frame
// 2. We try to read off the socket and get nothing, in which case
// return WouldBlock
// 3. We get an error.
err := WouldBlock
var header, body []byte
for err != nil {
if r.frame.needed() > 0 {
buf := make([]byte, recordHeaderLen+maxFragmentLen)
n, err := r.conn.Read(buf)
if err != nil {
logf(logTypeIO, "Error reading, %v", err)
return nil, err
}
if n == 0 {
return nil, WouldBlock
}
logf(logTypeIO, "Read %v bytes", n)
buf = buf[:n]
r.frame.addChunk(buf)
}
header, body, err = r.frame.process()
// Loop around on WouldBlock to see if some
// data is now available.
if err != nil && err != WouldBlock {
return nil, err
}
}
pt := &TLSPlaintext{}
// Validate content type
switch RecordType(header[0]) {
default:
return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData:
pt.contentType = RecordType(header[0])
}
// Validate version
if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) {
return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2])
}
// Validate size < max
size := (int(header[3]) << 8) + int(header[4])
if size > maxFragmentLen+256 {
return nil, fmt.Errorf("tls.record: Ciphertext size too big")
}
pt.fragment = make([]byte, size)
copy(pt.fragment, body)
// Attempt to decrypt fragment
if r.cipher != nil {
pt, _, err = r.decrypt(pt)
if err != nil {
return nil, err
}
}
// Check that plaintext length is not too long
if len(pt.fragment) > maxFragmentLen {
return nil, fmt.Errorf("tls.record: Plaintext size too big")
}
logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment)
r.cachedRecord = pt
r.incrementSequenceNumber()
return pt, nil
}
func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error {
return r.WriteRecordWithPadding(pt, 0)
}
func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
if r.cipher != nil {
pt = r.encrypt(pt, padLen)
} else if padLen > 0 {
return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
}
if len(pt.fragment) > maxFragmentLen {
return fmt.Errorf("tls.record: Record size too big")
}
length := len(pt.fragment)
header := []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)}
record := append(header, pt.fragment...)
logf(logTypeIO, "RecordLayer.WriteRecord [%d] [%x]", pt.contentType, pt.fragment)
r.incrementSequenceNumber()
_, err := r.conn.Write(record)
return err
}

View file

@ -1,898 +0,0 @@
package mint
import (
"bytes"
"hash"
"reflect"
)
// Server State Machine
//
// START <-----+
// Recv ClientHello | | Send HelloRetryRequest
// v |
// RECVD_CH ----+
// | Select parameters
// | Send ServerHello
// v
// NEGOTIATED
// | Send EncryptedExtensions
// | [Send CertificateRequest]
// Can send | [Send Certificate + CertificateVerify]
// app data --> | Send Finished
// after +--------+--------+
// here No 0-RTT | | 0-RTT
// | v
// | WAIT_EOED <---+
// | Recv | | | Recv
// | EndOfEarlyData | | | early data
// | | +-----+
// +> WAIT_FLIGHT2 <-+
// |
// +--------+--------+
// No auth | | Client auth
// | |
// | v
// | WAIT_CERT
// | Recv | | Recv Certificate
// | empty | v
// | Certificate | WAIT_CV
// | | | Recv
// | v | CertificateVerify
// +-> WAIT_FINISHED <---+
// | Recv Finished
// v
// CONNECTED
//
// NB: Not using state RECVD_CH
//
// State Instructions
// START {}
// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)]
// WAIT_EOED RekeyIn;
// WAIT_FLIGHT2 {}
// WAIT_CERT_CR {}
// WAIT_CERT {}
// WAIT_CV {}
// WAIT_FINISHED RekeyIn; RekeyOut;
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
type ServerStateStart struct {
Caps Capabilities
conn *Conn
cookieSent bool
firstClientHello *HandshakeMessage
helloRetryRequest *HandshakeMessage
}
func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeClientHello {
logf(logTypeHandshake, "[ServerStateStart] unexpected message")
return nil, nil, AlertUnexpectedMessage
}
ch := &ClientHelloBody{}
_, err := ch.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
clientHello := hm
connParams := ConnectionParameters{}
supportedVersions := new(SupportedVersionsExtension)
serverName := new(ServerNameExtension)
supportedGroups := new(SupportedGroupsExtension)
signatureAlgorithms := new(SignatureAlgorithmsExtension)
clientKeyShares := &KeyShareExtension{HandshakeType: HandshakeTypeClientHello}
clientPSK := &PreSharedKeyExtension{HandshakeType: HandshakeTypeClientHello}
clientEarlyData := &EarlyDataExtension{}
clientALPN := new(ALPNExtension)
clientPSKModes := new(PSKKeyExchangeModesExtension)
clientCookie := new(CookieExtension)
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
}
}
gotSupportedVersions := ch.Extensions.Find(supportedVersions)
gotServerName := ch.Extensions.Find(serverName)
gotSupportedGroups := ch.Extensions.Find(supportedGroups)
gotSignatureAlgorithms := ch.Extensions.Find(signatureAlgorithms)
gotEarlyData := ch.Extensions.Find(clientEarlyData)
ch.Extensions.Find(clientKeyShares)
ch.Extensions.Find(clientPSK)
ch.Extensions.Find(clientALPN)
ch.Extensions.Find(clientPSKModes)
ch.Extensions.Find(clientCookie)
if gotServerName {
connParams.ServerName = string(*serverName)
}
// If the client didn't send supportedVersions or doesn't support 1.3,
// then we're done here.
if !gotSupportedVersions {
logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions")
return nil, nil, AlertProtocolVersion
}
versionOK, _ := VersionNegotiation(supportedVersions.Versions, []uint16{supportedVersion})
if !versionOK {
logf(logTypeHandshake, "[ServerStateStart] Client does not support the same version")
return nil, nil, AlertProtocolVersion
}
if state.Caps.RequireCookie && state.cookieSent && !state.Caps.CookieHandler.Validate(state.conn, clientCookie.Cookie) {
logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch")
return nil, nil, AlertAccessDenied
}
// Figure out if we can do DH
canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups)
// Figure out if we can do PSK
canDoPSK := false
var selectedPSK int
var psk *PreSharedKey
var params CipherSuiteParams
if len(clientPSK.Identities) > 0 {
contextBase := []byte{}
if state.helloRetryRequest != nil {
chBytes := state.firstClientHello.Marshal()
hrrBytes := state.helloRetryRequest.Marshal()
contextBase = append(chBytes, hrrBytes...)
}
chTrunc, err := ch.Truncated()
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error computing truncated ClientHello [%v]", err)
return nil, nil, AlertDecodeError
}
context := append(contextBase, chTrunc...)
canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Caps.PSKs)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err)
return nil, nil, AlertInternalError
}
}
// Figure out if we actually should do DH / PSK
connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes)
// Select a ciphersuite
connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err)
return nil, nil, AlertHandshakeFailure
}
// Send a cookie if required
// NB: Need to do this here because it's after ciphersuite selection, which
// has to be after PSK selection.
// XXX: Doing this statefully for now, could be stateless
var cookieData []byte
if state.Caps.RequireCookie && !state.cookieSent {
var err error
cookieData, err = state.Caps.CookieHandler.Generate(state.conn)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err)
return nil, nil, AlertInternalError
}
}
if cookieData != nil {
// Ignoring errors because everything here is newly constructed, so there
// shouldn't be marshal errors
hrr := &HelloRetryRequestBody{
Version: supportedVersion,
CipherSuite: connParams.CipherSuite,
}
hrr.Extensions.Add(&CookieExtension{Cookie: cookieData})
// Run the external extension handler.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err)
return nil, nil, AlertInternalError
}
}
helloRetryRequest, err := HandshakeMessageFromBody(hrr)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error marshaling HRR [%v]", err)
return nil, nil, AlertInternalError
}
params := cipherSuiteMap[connParams.CipherSuite]
h := params.Hash.New()
h.Write(clientHello.Marshal())
firstClientHello := &HandshakeMessage{
msgType: HandshakeTypeMessageHash,
body: h.Sum(nil),
}
nextState := ServerStateStart{
Caps: state.Caps,
conn: state.conn,
cookieSent: true,
firstClientHello: firstClientHello,
helloRetryRequest: helloRetryRequest,
}
toSend := []HandshakeAction{SendHandshakeMessage{helloRetryRequest}}
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateStart]")
return nextState, toSend, AlertNoAlert
}
// If we've got no entropy to make keys from, fail
if !connParams.UsingDH && !connParams.UsingPSK {
logf(logTypeHandshake, "[ServerStateStart] Neither DH nor PSK negotiated")
return nil, nil, AlertHandshakeFailure
}
var pskSecret []byte
var cert *Certificate
var certScheme SignatureScheme
if connParams.UsingPSK {
pskSecret = psk.Key
} else {
psk = nil
// If we're not using a PSK mode, then we need to have certain extensions
if !gotServerName || !gotSupportedGroups || !gotSignatureAlgorithms {
logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v %v %v)",
gotServerName, gotSupportedGroups, gotSignatureAlgorithms)
return nil, nil, AlertMissingExtension
}
// Select a certificate
name := string(*serverName)
var err error
cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Caps.Certificates)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] No appropriate certificate found [%v]", err)
return nil, nil, AlertAccessDenied
}
}
if !connParams.UsingDH {
dhSecret = nil
}
// Figure out if we're going to do early data
var clientEarlyTrafficSecret []byte
connParams.ClientSendingEarlyData = gotEarlyData
connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, gotEarlyData, state.Caps.AllowEarlyData)
if connParams.UsingEarlyData {
h := params.Hash.New()
h.Write(clientHello.Marshal())
chHash := h.Sum(nil)
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
earlySecret := HkdfExtract(params.Hash, zero, pskSecret)
clientEarlyTrafficSecret = deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
}
// Select a next protocol
connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Caps.NextProtos)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] No common application-layer protocol found [%v]", err)
return nil, nil, AlertNoApplicationProtocol
}
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]")
return ServerStateNegotiated{
Caps: state.Caps,
Params: connParams,
dhGroup: dhGroup,
dhPublic: dhPublic,
dhSecret: dhSecret,
pskSecret: pskSecret,
selectedPSK: selectedPSK,
cert: cert,
certScheme: certScheme,
clientEarlyTrafficSecret: clientEarlyTrafficSecret,
firstClientHello: state.firstClientHello,
helloRetryRequest: state.helloRetryRequest,
clientHello: clientHello,
}.Next(nil)
}
type ServerStateNegotiated struct {
Caps Capabilities
Params ConnectionParameters
dhGroup NamedGroup
dhPublic []byte
dhSecret []byte
pskSecret []byte
clientEarlyTrafficSecret []byte
selectedPSK int
cert *Certificate
certScheme SignatureScheme
firstClientHello *HandshakeMessage
helloRetryRequest *HandshakeMessage
clientHello *HandshakeMessage
}
func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
// Create the ServerHello
sh := &ServerHelloBody{
Version: supportedVersion,
CipherSuite: state.Params.CipherSuite,
}
_, err := prng.Read(sh.Random[:])
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err)
return nil, nil, AlertInternalError
}
if state.Params.UsingDH {
logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension")
err = sh.Extensions.Add(&KeyShareExtension{
HandshakeType: HandshakeTypeServerHello,
Shares: []KeyShareEntry{{Group: state.dhGroup, KeyExchange: state.dhPublic}},
})
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding key_shares extension [%v]", err)
return nil, nil, AlertInternalError
}
}
if state.Params.UsingPSK {
logf(logTypeHandshake, "[ServerStateNegotiated] sending PSK extension")
err = sh.Extensions.Add(&PreSharedKeyExtension{
HandshakeType: HandshakeTypeServerHello,
SelectedIdentity: uint16(state.selectedPSK),
})
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding PSK extension [%v]", err)
return nil, nil, AlertInternalError
}
}
// Run the external extension handler.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
return nil, nil, AlertInternalError
}
}
serverHello, err := HandshakeMessageFromBody(sh)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling ServerHello [%v]", err)
return nil, nil, AlertInternalError
}
// Look up crypto params
params, ok := cipherSuiteMap[sh.CipherSuite]
if !ok {
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", sh.CipherSuite)
return nil, nil, AlertHandshakeFailure
}
// Start up the handshake hash
handshakeHash := params.Hash.New()
handshakeHash.Write(state.firstClientHello.Marshal())
handshakeHash.Write(state.helloRetryRequest.Marshal())
handshakeHash.Write(state.clientHello.Marshal())
handshakeHash.Write(serverHello.Marshal())
// Compute handshake secrets
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
var earlySecret []byte
if state.Params.UsingPSK {
earlySecret = HkdfExtract(params.Hash, zero, state.pskSecret)
} else {
earlySecret = HkdfExtract(params.Hash, zero, zero)
}
if state.dhSecret == nil {
state.dhSecret = zero
}
h0 := params.Hash.New().Sum(nil)
h2 := handshakeHash.Sum(nil)
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, state.dhSecret)
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
logf(logTypeCrypto, "early secret (init!): [%d] %x", len(earlySecret), earlySecret)
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
clientHandshakeKeys := makeTrafficKeys(params, clientHandshakeTrafficSecret)
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
// Send an EncryptedExtensions message (even if it's empty)
eeList := ExtensionList{}
if state.Params.NextProto != "" {
logf(logTypeHandshake, "[server] sending ALPN extension")
err = eeList.Add(&ALPNExtension{Protocols: []string{state.Params.NextProto}})
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding ALPN to EncryptedExtensions [%v]", err)
return nil, nil, AlertInternalError
}
}
if state.Params.UsingEarlyData {
logf(logTypeHandshake, "[server] sending EDI extension")
err = eeList.Add(&EarlyDataExtension{})
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding EDI to EncryptedExtensions [%v]", err)
return nil, nil, AlertInternalError
}
}
ee := &EncryptedExtensionsBody{eeList}
// Run the external extension handler.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
return nil, nil, AlertInternalError
}
}
eem, err := HandshakeMessageFromBody(ee)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling EncryptedExtensions [%v]", err)
return nil, nil, AlertInternalError
}
handshakeHash.Write(eem.Marshal())
toSend := []HandshakeAction{
SendHandshakeMessage{serverHello},
RekeyOut{Label: "handshake", KeySet: serverHandshakeKeys},
SendHandshakeMessage{eem},
}
// Authenticate with a certificate if required
if !state.Params.UsingPSK {
// Send a CertificateRequest message if we want client auth
if state.Caps.RequireClientAuth {
state.Params.UsingClientAuth = true
// XXX: We don't support sending any constraints besides a list of
// supported signature algorithms
cr := &CertificateRequestBody{}
schemes := &SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
err := cr.Extensions.Add(schemes)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported schemes to CertificateRequest [%v]", err)
return nil, nil, AlertInternalError
}
crm, err := HandshakeMessageFromBody(cr)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateRequest [%v]", err)
return nil, nil, AlertInternalError
}
//TODO state.state.serverCertificateRequest = cr
toSend = append(toSend, SendHandshakeMessage{crm})
handshakeHash.Write(crm.Marshal())
}
// Create and send Certificate, CertificateVerify
certificate := &CertificateBody{
CertificateList: make([]CertificateEntry, len(state.cert.Chain)),
}
for i, entry := range state.cert.Chain {
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
}
certm, err := HandshakeMessageFromBody(certificate)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err)
return nil, nil, AlertInternalError
}
toSend = append(toSend, SendHandshakeMessage{certm})
handshakeHash.Write(certm.Marshal())
certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme}
logf(logTypeHandshake, "Creating CertVerify: %04x %v", state.certScheme, params.Hash)
hcv := handshakeHash.Sum(nil)
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
err = certificateVerify.Sign(state.cert.PrivateKey, hcv)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error signing CertificateVerify [%v]", err)
return nil, nil, AlertInternalError
}
certvm, err := HandshakeMessageFromBody(certificateVerify)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateVerify [%v]", err)
return nil, nil, AlertInternalError
}
toSend = append(toSend, SendHandshakeMessage{certvm})
handshakeHash.Write(certvm.Marshal())
}
// Compute secrets resulting from the server's first flight
h3 := handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3)
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3)
serverFinishedData := computeFinishedData(params, serverHandshakeTrafficSecret, h3)
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
// Assemble the Finished message
fin := &FinishedBody{
VerifyDataLen: len(serverFinishedData),
VerifyData: serverFinishedData,
}
finm, _ := HandshakeMessageFromBody(fin)
toSend = append(toSend, SendHandshakeMessage{finm})
handshakeHash.Write(finm.Marshal())
// Compute traffic secrets
h4 := handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash 4 [%d] %x", len(h4), h4)
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h4), h4)
clientTrafficSecret := deriveSecret(params, masterSecret, labelClientApplicationTrafficSecret, h4)
serverTrafficSecret := deriveSecret(params, masterSecret, labelServerApplicationTrafficSecret, h4)
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret)
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret)
serverTrafficKeys := makeTrafficKeys(params, serverTrafficSecret)
toSend = append(toSend, RekeyOut{Label: "application", KeySet: serverTrafficKeys})
exporterSecret := deriveSecret(params, masterSecret, labelExporterSecret, h4)
logf(logTypeCrypto, "server exporter secret: [%d] %x", len(exporterSecret), exporterSecret)
if state.Params.UsingEarlyData {
clientEarlyTrafficKeys := makeTrafficKeys(params, state.clientEarlyTrafficSecret)
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitEOED]")
nextState := ServerStateWaitEOED{
AuthCertificate: state.Caps.AuthCertificate,
Params: state.Params,
cryptoParams: params,
handshakeHash: handshakeHash,
masterSecret: masterSecret,
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
clientTrafficSecret: clientTrafficSecret,
serverTrafficSecret: serverTrafficSecret,
exporterSecret: exporterSecret,
}
toSend = append(toSend, []HandshakeAction{
RekeyIn{Label: "early", KeySet: clientEarlyTrafficKeys},
ReadEarlyData{},
}...)
return nextState, toSend, AlertNoAlert
}
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]")
toSend = append(toSend, []HandshakeAction{
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys},
ReadPastEarlyData{},
}...)
waitFlight2 := ServerStateWaitFlight2{
AuthCertificate: state.Caps.AuthCertificate,
Params: state.Params,
cryptoParams: params,
handshakeHash: handshakeHash,
masterSecret: masterSecret,
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
clientTrafficSecret: clientTrafficSecret,
serverTrafficSecret: serverTrafficSecret,
exporterSecret: exporterSecret,
}
nextState, moreToSend, alert := waitFlight2.Next(nil)
toSend = append(toSend, moreToSend...)
return nextState, toSend, alert
}
type ServerStateWaitEOED struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
handshakeHash hash.Hash
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
}
func (state ServerStateWaitEOED) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeEndOfEarlyData {
logf(logTypeHandshake, "[ServerStateWaitEOED] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
if len(hm.body) > 0 {
logf(logTypeHandshake, "[ServerStateWaitEOED] Error decoding message [len > 0]")
return nil, nil, AlertDecodeError
}
state.handshakeHash.Write(hm.Marshal())
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
logf(logTypeHandshake, "[ServerStateWaitEOED] -> [ServerStateWaitFlight2]")
toSend := []HandshakeAction{
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys},
}
waitFlight2 := ServerStateWaitFlight2{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
nextState, moreToSend, alert := waitFlight2.Next(nil)
toSend = append(toSend, moreToSend...)
return nextState, toSend, alert
}
type ServerStateWaitFlight2 struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
handshakeHash hash.Hash
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
}
func (state ServerStateWaitFlight2) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm != nil {
logf(logTypeHandshake, "[ServerStateWaitFlight2] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
if state.Params.UsingClientAuth {
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitCert]")
nextState := ServerStateWaitCert{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
return nextState, nil, AlertNoAlert
}
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitFinished]")
nextState := ServerStateWaitFinished{
Params: state.Params,
cryptoParams: state.cryptoParams,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
handshakeHash: state.handshakeHash,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
return nextState, nil, AlertNoAlert
}
type ServerStateWaitCert struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
handshakeHash hash.Hash
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
}
func (state ServerStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeCertificate {
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
cert := &CertificateBody{}
_, err := cert.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
return nil, nil, AlertDecodeError
}
state.handshakeHash.Write(hm.Marshal())
if len(cert.CertificateList) == 0 {
logf(logTypeHandshake, "[ServerStateWaitCert] WARNING client did not provide a certificate")
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitFinished]")
nextState := ServerStateWaitFinished{
Params: state.Params,
cryptoParams: state.cryptoParams,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
handshakeHash: state.handshakeHash,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
return nextState, nil, AlertNoAlert
}
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitCV]")
nextState := ServerStateWaitCV{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
handshakeHash: state.handshakeHash,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
clientCertificate: cert,
exporterSecret: state.exporterSecret,
}
return nextState, nil, AlertNoAlert
}
type ServerStateWaitCV struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
handshakeHash hash.Hash
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
clientCertificate *CertificateBody
}
func (state ServerStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
logf(logTypeHandshake, "[ServerStateWaitCV] Unexpected message [%+v] [%s]", hm, reflect.TypeOf(hm))
return nil, nil, AlertUnexpectedMessage
}
certVerify := &CertificateVerifyBody{}
_, err := certVerify.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err)
return nil, nil, AlertDecodeError
}
// Verify client signature over handshake hash
hcv := state.handshakeHash.Sum(nil)
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
clientPublicKey := state.clientCertificate.CertificateList[0].CertData.PublicKey
if err := certVerify.Verify(clientPublicKey, hcv); err != nil {
logf(logTypeHandshake, "[ServerStateWaitCV] Failure in client auth verification [%v]", err)
return nil, nil, AlertHandshakeFailure
}
if state.AuthCertificate != nil {
err := state.AuthCertificate(state.clientCertificate.CertificateList)
if err != nil {
logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate")
return nil, nil, AlertBadCertificate
}
} else {
logf(logTypeHandshake, "[ServerStateWaitCV] WARNING: No verification of client certificate")
}
// If it passes, record the certificateVerify in the transcript hash
state.handshakeHash.Write(hm.Marshal())
logf(logTypeHandshake, "[ServerStateWaitCV] -> [ServerStateWaitFinished]")
nextState := ServerStateWaitFinished{
Params: state.Params,
cryptoParams: state.cryptoParams,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
handshakeHash: state.handshakeHash,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
return nextState, nil, AlertNoAlert
}
type ServerStateWaitFinished struct {
Params ConnectionParameters
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
handshakeHash hash.Hash
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
}
func (state ServerStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeFinished {
logf(logTypeHandshake, "[ServerStateWaitFinished] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()}
_, err := fin.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err)
return nil, nil, AlertDecodeError
}
// Verify client Finished data
h5 := state.handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5)
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5)
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData)
if !bytes.Equal(fin.VerifyData, clientFinishedData) {
logf(logTypeHandshake, "[ServerStateWaitFinished] Client's Finished failed to verify")
return nil, nil, AlertHandshakeFailure
}
// Compute the resumption secret
state.handshakeHash.Write(hm.Marshal())
h6 := state.handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash 6 [%d]: %x", len(h6), h6)
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6)
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
// Compute client traffic keys
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]")
nextState := StateConnected{
Params: state.Params,
isClient: false,
cryptoParams: state.cryptoParams,
resumptionSecret: resumptionSecret,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
toSend := []HandshakeAction{
RekeyIn{Label: "application", KeySet: clientTrafficKeys},
}
return nextState, toSend, AlertNoAlert
}

View file

@ -1,230 +0,0 @@
package mint
import (
"time"
)
// Marker interface for actions that an implementation should take based on
// state transitions.
type HandshakeAction interface{}
type SendHandshakeMessage struct {
Message *HandshakeMessage
}
type SendEarlyData struct{}
type ReadEarlyData struct{}
type ReadPastEarlyData struct{}
type RekeyIn struct {
Label string
KeySet keySet
}
type RekeyOut struct {
Label string
KeySet keySet
}
type StorePSK struct {
PSK PreSharedKey
}
type HandshakeState interface {
Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert)
}
type AppExtensionHandler interface {
Send(hs HandshakeType, el *ExtensionList) error
Receive(hs HandshakeType, el *ExtensionList) error
}
// Capabilities objects represent the capabilities of a TLS client or server,
// as an input to TLS negotiation
type Capabilities struct {
// For both client and server
CipherSuites []CipherSuite
Groups []NamedGroup
SignatureSchemes []SignatureScheme
PSKs PreSharedKeyCache
Certificates []*Certificate
AuthCertificate func(chain []CertificateEntry) error
ExtensionHandler AppExtensionHandler
// For client
PSKModes []PSKKeyExchangeMode
// For server
NextProtos []string
AllowEarlyData bool
RequireCookie bool
CookieHandler CookieHandler
RequireClientAuth bool
}
// ConnectionOptions objects represent per-connection settings for a client
// initiating a connection
type ConnectionOptions struct {
ServerName string
NextProtos []string
EarlyData []byte
}
// ConnectionParameters objects represent the parameters negotiated for a
// connection.
type ConnectionParameters struct {
UsingPSK bool
UsingDH bool
ClientSendingEarlyData bool
UsingEarlyData bool
UsingClientAuth bool
CipherSuite CipherSuite
ServerName string
NextProto string
}
// StateConnected is symmetric between client and server
type StateConnected struct {
Params ConnectionParameters
isClient bool
cryptoParams CipherSuiteParams
resumptionSecret []byte
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
}
func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) {
var trafficKeys keySet
if state.isClient {
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
} else {
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
}
kum, err := HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request})
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err)
return nil, AlertInternalError
}
toSend := []HandshakeAction{
SendHandshakeMessage{kum},
RekeyOut{Label: "update", KeySet: trafficKeys},
}
return toSend, AlertNoAlert
}
func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) {
tkt, err := NewSessionTicket(length, lifetime)
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err)
return nil, AlertInternalError
}
err = tkt.Extensions.Add(&TicketEarlyDataInfoExtension{earlyDataLifetime})
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error adding extension to NewSessionTicket: %v", err)
return nil, AlertInternalError
}
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
labelResumption, tkt.TicketNonce, state.cryptoParams.Hash.Size())
newPSK := PreSharedKey{
CipherSuite: state.cryptoParams.Suite,
IsResumption: true,
Identity: tkt.Ticket,
Key: resumptionKey,
NextProto: state.Params.NextProto,
ReceivedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Duration(tkt.TicketLifetime) * time.Second),
TicketAgeAdd: tkt.TicketAgeAdd,
}
tktm, err := HandshakeMessageFromBody(tkt)
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err)
return nil, AlertInternalError
}
toSend := []HandshakeAction{
StorePSK{newPSK},
SendHandshakeMessage{tktm},
}
return toSend, AlertNoAlert
}
func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil {
logf(logTypeHandshake, "[StateConnected] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
bodyGeneric, err := hm.ToBody()
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
switch body := bodyGeneric.(type) {
case *KeyUpdateBody:
var trafficKeys keySet
if !state.isClient {
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
} else {
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
}
toSend := []HandshakeAction{RekeyIn{Label: "update", KeySet: trafficKeys}}
// If requested, roll outbound keys and send a KeyUpdate
if body.KeyUpdateRequest == KeyUpdateRequested {
moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested)
if alert != AlertNoAlert {
return nil, nil, alert
}
toSend = append(toSend, moreToSend...)
}
return state, toSend, AlertNoAlert
case *NewSessionTicketBody:
// XXX: Allow NewSessionTicket in both directions?
if !state.isClient {
return nil, nil, AlertUnexpectedMessage
}
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size())
psk := PreSharedKey{
CipherSuite: state.cryptoParams.Suite,
IsResumption: true,
Identity: body.Ticket,
Key: resumptionKey,
NextProto: state.Params.NextProto,
ReceivedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Duration(body.TicketLifetime) * time.Second),
TicketAgeAdd: body.TicketAgeAdd,
}
toSend := []HandshakeAction{StorePSK{psk}}
return state, toSend, AlertNoAlert
}
logf(logTypeHandshake, "[StateConnected] Unexpected message type %v", hm.msgType)
return nil, nil, AlertUnexpectedMessage
}

View file

@ -1,243 +0,0 @@
package syntax
import (
"bytes"
"fmt"
"reflect"
"runtime"
)
func Unmarshal(data []byte, v interface{}) (int, error) {
// Check for well-formedness.
// Avoids filling out half a data structure
// before discovering a JSON syntax error.
d := decodeState{}
d.Write(data)
return d.unmarshal(v)
}
// These are the options that can be specified in the struct tag. Right now,
// all of them apply to variable-length vectors and nothing else
type decOpts struct {
head uint // length of length in bytes
min uint // minimum size in bytes
max uint // maximum size in bytes
}
type decodeState struct {
bytes.Buffer
}
func (d *decodeState) unmarshal(v interface{}) (read int, err error) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok {
panic(r)
}
if s, ok := r.(string); ok {
panic(s)
}
err = r.(error)
}
}()
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return 0, fmt.Errorf("Invalid unmarshal target (non-pointer or nil)")
}
read = d.value(rv)
return read, nil
}
func (e *decodeState) value(v reflect.Value) int {
return valueDecoder(v)(e, v, decOpts{})
}
type decoderFunc func(e *decodeState, v reflect.Value, opts decOpts) int
func valueDecoder(v reflect.Value) decoderFunc {
return typeDecoder(v.Type().Elem())
}
func typeDecoder(t reflect.Type) decoderFunc {
// Note: Omits the caching / wait-group things that encoding/json uses
return newTypeDecoder(t)
}
func newTypeDecoder(t reflect.Type) decoderFunc {
// Note: Does not support Marshaler, so don't need the allowAddr argument
switch t.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return uintDecoder
case reflect.Array:
return newArrayDecoder(t)
case reflect.Slice:
return newSliceDecoder(t)
case reflect.Struct:
return newStructDecoder(t)
default:
panic(fmt.Errorf("Unsupported type (%s)", t))
}
}
///// Specific decoders below
func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
var uintLen int
switch v.Elem().Kind() {
case reflect.Uint8:
uintLen = 1
case reflect.Uint16:
uintLen = 2
case reflect.Uint32:
uintLen = 4
case reflect.Uint64:
uintLen = 8
}
buf := make([]byte, uintLen)
n, err := d.Read(buf)
if err != nil {
panic(err)
}
if n != uintLen {
panic(fmt.Errorf("Insufficient data to read uint"))
}
val := uint64(0)
for _, b := range buf {
val = (val << 8) + uint64(b)
}
v.Elem().SetUint(val)
return uintLen
}
//////////
type arrayDecoder struct {
elemDec decoderFunc
}
func (ad *arrayDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
n := v.Elem().Type().Len()
read := 0
for i := 0; i < n; i += 1 {
read += ad.elemDec(d, v.Elem().Index(i).Addr(), opts)
}
return read
}
func newArrayDecoder(t reflect.Type) decoderFunc {
dec := &arrayDecoder{typeDecoder(t.Elem())}
return dec.decode
}
//////////
type sliceDecoder struct {
elementType reflect.Type
elementDec decoderFunc
}
func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
if opts.head == 0 {
panic(fmt.Errorf("Cannot decode a slice without a header length"))
}
lengthBytes := make([]byte, opts.head)
n, err := d.Read(lengthBytes)
if err != nil {
panic(err)
}
if uint(n) != opts.head {
panic(fmt.Errorf("Not enough data to read header"))
}
length := uint(0)
for _, b := range lengthBytes {
length = (length << 8) + uint(b)
}
if opts.max > 0 && length > opts.max {
panic(fmt.Errorf("Length of vector exceeds declared max"))
}
if length < opts.min {
panic(fmt.Errorf("Length of vector below declared min"))
}
data := make([]byte, length)
n, err = d.Read(data)
if err != nil {
panic(err)
}
if uint(n) != length {
panic(fmt.Errorf("Available data less than declared length [%04x < %04x]", n, length))
}
elemBuf := &decodeState{}
elemBuf.Write(data)
elems := []reflect.Value{}
read := int(opts.head)
for elemBuf.Len() > 0 {
elem := reflect.New(sd.elementType)
read += sd.elementDec(elemBuf, elem, opts)
elems = append(elems, elem)
}
v.Elem().Set(reflect.MakeSlice(v.Elem().Type(), len(elems), len(elems)))
for i := 0; i < len(elems); i += 1 {
v.Elem().Index(i).Set(elems[i].Elem())
}
return read
}
func newSliceDecoder(t reflect.Type) decoderFunc {
dec := &sliceDecoder{
elementType: t.Elem(),
elementDec: typeDecoder(t.Elem()),
}
return dec.decode
}
//////////
type structDecoder struct {
fieldOpts []decOpts
fieldDecs []decoderFunc
}
func (sd *structDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
read := 0
for i := range sd.fieldDecs {
read += sd.fieldDecs[i](d, v.Elem().Field(i).Addr(), sd.fieldOpts[i])
}
return read
}
func newStructDecoder(t reflect.Type) decoderFunc {
n := t.NumField()
sd := structDecoder{
fieldOpts: make([]decOpts, n),
fieldDecs: make([]decoderFunc, n),
}
for i := 0; i < n; i += 1 {
f := t.Field(i)
tag := f.Tag.Get("tls")
tagOpts := parseTag(tag)
sd.fieldOpts[i] = decOpts{
head: tagOpts["head"],
max: tagOpts["max"],
min: tagOpts["min"],
}
sd.fieldDecs[i] = typeDecoder(f.Type)
}
return sd.decode
}

View file

@ -1,187 +0,0 @@
package syntax
import (
"bytes"
"fmt"
"reflect"
"runtime"
)
func Marshal(v interface{}) ([]byte, error) {
e := &encodeState{}
err := e.marshal(v, encOpts{})
if err != nil {
return nil, err
}
return e.Bytes(), nil
}
// These are the options that can be specified in the struct tag. Right now,
// all of them apply to variable-length vectors and nothing else
type encOpts struct {
head uint // length of length in bytes
min uint // minimum size in bytes
max uint // maximum size in bytes
}
type encodeState struct {
bytes.Buffer
}
func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok {
panic(r)
}
if s, ok := r.(string); ok {
panic(s)
}
err = r.(error)
}
}()
e.reflectValue(reflect.ValueOf(v), opts)
return nil
}
func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) {
valueEncoder(v)(e, v, opts)
}
type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts)
func valueEncoder(v reflect.Value) encoderFunc {
if !v.IsValid() {
panic(fmt.Errorf("Cannot encode an invalid value"))
}
return typeEncoder(v.Type())
}
func typeEncoder(t reflect.Type) encoderFunc {
// Note: Omits the caching / wait-group things that encoding/json uses
return newTypeEncoder(t)
}
func newTypeEncoder(t reflect.Type) encoderFunc {
// Note: Does not support Marshaler, so don't need the allowAddr argument
switch t.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return uintEncoder
case reflect.Array:
return newArrayEncoder(t)
case reflect.Slice:
return newSliceEncoder(t)
case reflect.Struct:
return newStructEncoder(t)
default:
panic(fmt.Errorf("Unsupported type (%s)", t))
}
}
///// Specific encoders below
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
u := v.Uint()
switch v.Type().Kind() {
case reflect.Uint8:
e.WriteByte(byte(u))
case reflect.Uint16:
e.Write([]byte{byte(u >> 8), byte(u)})
case reflect.Uint32:
e.Write([]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
case reflect.Uint64:
e.Write([]byte{byte(u >> 56), byte(u >> 48), byte(u >> 40), byte(u >> 32),
byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
}
}
//////////
type arrayEncoder struct {
elemEnc encoderFunc
}
func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
n := v.Len()
for i := 0; i < n; i += 1 {
ae.elemEnc(e, v.Index(i), opts)
}
}
func newArrayEncoder(t reflect.Type) encoderFunc {
enc := &arrayEncoder{typeEncoder(t.Elem())}
return enc.encode
}
//////////
type sliceEncoder struct {
ae *arrayEncoder
}
func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
if opts.head == 0 {
panic(fmt.Errorf("Cannot encode a slice without a header length"))
}
arrayState := &encodeState{}
se.ae.encode(arrayState, v, opts)
n := uint(arrayState.Len())
if opts.max > 0 && n > opts.max {
panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max))
}
if n>>(8*opts.head) > 0 {
panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head))
}
if n < opts.min {
panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min))
}
for i := int(opts.head - 1); i >= 0; i -= 1 {
e.WriteByte(byte(n >> (8 * uint(i))))
}
e.Write(arrayState.Bytes())
}
func newSliceEncoder(t reflect.Type) encoderFunc {
enc := &sliceEncoder{&arrayEncoder{typeEncoder(t.Elem())}}
return enc.encode
}
//////////
type structEncoder struct {
fieldOpts []encOpts
fieldEncs []encoderFunc
}
func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
for i := range se.fieldEncs {
se.fieldEncs[i](e, v.Field(i), se.fieldOpts[i])
}
}
func newStructEncoder(t reflect.Type) encoderFunc {
n := t.NumField()
se := structEncoder{
fieldOpts: make([]encOpts, n),
fieldEncs: make([]encoderFunc, n),
}
for i := 0; i < n; i += 1 {
f := t.Field(i)
tag := f.Tag.Get("tls")
tagOpts := parseTag(tag)
se.fieldOpts[i] = encOpts{
head: tagOpts["head"],
max: tagOpts["max"],
min: tagOpts["min"],
}
se.fieldEncs[i] = typeEncoder(f.Type)
}
return se.encode
}

View file

@ -1,30 +0,0 @@
package syntax
import (
"strconv"
"strings"
)
// `tls:"head=2,min=2,max=255"`
type tagOptions map[string]uint
// parseTag parses a struct field's "tls" tag as a comma-separated list of
// name=value pairs, where the values MUST be unsigned integers
func parseTag(tag string) tagOptions {
opts := tagOptions{}
for _, token := range strings.Split(tag, ",") {
if strings.Index(token, "=") == -1 {
continue
}
parts := strings.Split(token, "=")
if len(parts[0]) == 0 {
continue
}
if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 {
opts[parts[0]] = uint(val)
}
}
return opts
}

View file

@ -1,168 +0,0 @@
package mint
// XXX(rlb): This file is borrowed pretty much wholesale from crypto/tls
import (
"errors"
"net"
"strings"
"time"
)
// Server returns a new TLS server side connection
// using conn as the underlying transport.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Server(conn net.Conn, config *Config) *Conn {
return NewConn(conn, config, false)
}
// Client returns a new TLS client side connection
// using conn as the underlying transport.
// The config cannot be nil: users must set either ServerName or
// InsecureSkipVerify in the config.
func Client(conn net.Conn, config *Config) *Conn {
return NewConn(conn, config, true)
}
// A listener implements a network listener (net.Listener) for TLS connections.
type Listener struct {
net.Listener
config *Config
}
// Accept waits for and returns the next incoming TLS connection.
// The returned connection c is a *tls.Conn.
func (l *Listener) Accept() (c net.Conn, err error) {
c, err = l.Listener.Accept()
if err != nil {
return
}
server := Server(c, l.config)
err = server.Handshake()
if err == AlertNoAlert {
err = nil
}
c = server
return
}
// NewListener creates a Listener which accepts connections from an inner
// Listener and wraps each connection with Server.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func NewListener(inner net.Listener, config *Config) net.Listener {
l := new(Listener)
l.Listener = inner
l.config = config
return l
}
// Listen creates a TLS listener accepting connections on the
// given network address using net.Listen.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Listen(network, laddr string, config *Config) (net.Listener, error) {
if config == nil || !config.ValidForServer() {
return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config")
}
l, err := net.Listen(network, laddr)
if err != nil {
return nil, err
}
return NewListener(l, config), nil
}
type TimeoutError struct{}
func (TimeoutError) Error() string { return "tls: DialWithDialer timed out" }
func (TimeoutError) Timeout() bool { return true }
func (TimeoutError) Temporary() bool { return true }
// DialWithDialer connects to the given network address using dialer.Dial and
// then initiates a TLS handshake, returning the resulting TLS connection. Any
// timeout or deadline given in the dialer apply to connection and TLS
// handshake as a whole.
//
// DialWithDialer interprets a nil configuration as equivalent to the zero
// configuration; see the documentation of Config for the defaults.
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
// We want the Timeout and Deadline values from dialer to cover the
// whole process: TCP connection and TLS handshake. This means that we
// also need to start our own timers now.
timeout := dialer.Timeout
if !dialer.Deadline.IsZero() {
deadlineTimeout := dialer.Deadline.Sub(time.Now())
if timeout == 0 || deadlineTimeout < timeout {
timeout = deadlineTimeout
}
}
var errChannel chan error
if timeout != 0 {
errChannel = make(chan error, 2)
time.AfterFunc(timeout, func() {
errChannel <- TimeoutError{}
})
}
rawConn, err := dialer.Dial(network, addr)
if err != nil {
return nil, err
}
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
if config == nil {
config = &Config{}
}
// If no ServerName is set, infer the ServerName
// from the hostname we're connecting to.
if config.ServerName == "" {
// Make a copy to avoid polluting argument or default.
c := config.Clone()
c.ServerName = hostname
config = c
}
conn := Client(rawConn, config)
if timeout == 0 {
err = conn.Handshake()
if err == AlertNoAlert {
err = nil
}
} else {
go func() {
errChannel <- conn.Handshake()
}()
err = <-errChannel
if err == AlertNoAlert {
err = nil
}
}
if err != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}
// Dial connects to the given network address using net.Dial
// and then initiates a TLS handshake, returning the resulting
// TLS connection.
// Dial interprets a nil configuration as equivalent to
// the zero configuration; see the documentation of Config
// for the defaults.
func Dial(network, addr string, config *Config) (*Conn, error) {
return DialWithDialer(new(net.Dialer), network, addr, config)
}

202
vendor/github.com/golang/mock/gomock/LICENSE generated vendored Normal file
View file

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

428
vendor/github.com/golang/mock/gomock/call.go generated vendored Normal file
View file

@ -0,0 +1,428 @@
// Copyright 2010 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gomock
import (
"fmt"
"reflect"
"strconv"
"strings"
)
// Call represents an expected call to a mock.
type Call struct {
t TestReporter // for triggering test failures on invalid call setup
receiver interface{} // the receiver of the method call
method string // the name of the method
methodType reflect.Type // the type of the method
args []Matcher // the args
origin string // file and line number of call setup
preReqs []*Call // prerequisite calls
// Expectations
minCalls, maxCalls int
numCalls int // actual number made
// actions are called when this Call is called. Each action gets the args and
// can set the return values by returning a non-nil slice. Actions run in the
// order they are created.
actions []func([]interface{}) []interface{}
}
// newCall creates a *Call. It requires the method type in order to support
// unexported methods.
func newCall(t TestReporter, receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {
if h, ok := t.(testHelper); ok {
h.Helper()
}
// TODO: check arity, types.
margs := make([]Matcher, len(args))
for i, arg := range args {
if m, ok := arg.(Matcher); ok {
margs[i] = m
} else if arg == nil {
// Handle nil specially so that passing a nil interface value
// will match the typed nils of concrete args.
margs[i] = Nil()
} else {
margs[i] = Eq(arg)
}
}
origin := callerInfo(3)
actions := []func([]interface{}) []interface{}{func([]interface{}) []interface{} {
// Synthesize the zero value for each of the return args' types.
rets := make([]interface{}, methodType.NumOut())
for i := 0; i < methodType.NumOut(); i++ {
rets[i] = reflect.Zero(methodType.Out(i)).Interface()
}
return rets
}}
return &Call{t: t, receiver: receiver, method: method, methodType: methodType,
args: margs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions}
}
// AnyTimes allows the expectation to be called 0 or more times
func (c *Call) AnyTimes() *Call {
c.minCalls, c.maxCalls = 0, 1e8 // close enough to infinity
return c
}
// MinTimes requires the call to occur at least n times. If AnyTimes or MaxTimes have not been called, MinTimes also
// sets the maximum number of calls to infinity.
func (c *Call) MinTimes(n int) *Call {
c.minCalls = n
if c.maxCalls == 1 {
c.maxCalls = 1e8
}
return c
}
// MaxTimes limits the number of calls to n times. If AnyTimes or MinTimes have not been called, MaxTimes also
// sets the minimum number of calls to 0.
func (c *Call) MaxTimes(n int) *Call {
c.maxCalls = n
if c.minCalls == 1 {
c.minCalls = 0
}
return c
}
// DoAndReturn declares the action to run when the call is matched.
// The return values from this function are returned by the mocked function.
// It takes an interface{} argument to support n-arity functions.
func (c *Call) DoAndReturn(f interface{}) *Call {
// TODO: Check arity and types here, rather than dying badly elsewhere.
v := reflect.ValueOf(f)
c.addAction(func(args []interface{}) []interface{} {
vargs := make([]reflect.Value, len(args))
ft := v.Type()
for i := 0; i < len(args); i++ {
if args[i] != nil {
vargs[i] = reflect.ValueOf(args[i])
} else {
// Use the zero value for the arg.
vargs[i] = reflect.Zero(ft.In(i))
}
}
vrets := v.Call(vargs)
rets := make([]interface{}, len(vrets))
for i, ret := range vrets {
rets[i] = ret.Interface()
}
return rets
})
return c
}
// Do declares the action to run when the call is matched. The function's
// return values are ignored to retain backward compatibility. To use the
// return values call DoAndReturn.
// It takes an interface{} argument to support n-arity functions.
func (c *Call) Do(f interface{}) *Call {
// TODO: Check arity and types here, rather than dying badly elsewhere.
v := reflect.ValueOf(f)
c.addAction(func(args []interface{}) []interface{} {
vargs := make([]reflect.Value, len(args))
ft := v.Type()
for i := 0; i < len(args); i++ {
if args[i] != nil {
vargs[i] = reflect.ValueOf(args[i])
} else {
// Use the zero value for the arg.
vargs[i] = reflect.Zero(ft.In(i))
}
}
v.Call(vargs)
return nil
})
return c
}
// Return declares the values to be returned by the mocked function call.
func (c *Call) Return(rets ...interface{}) *Call {
if h, ok := c.t.(testHelper); ok {
h.Helper()
}
mt := c.methodType
if len(rets) != mt.NumOut() {
c.t.Fatalf("wrong number of arguments to Return for %T.%v: got %d, want %d [%s]",
c.receiver, c.method, len(rets), mt.NumOut(), c.origin)
}
for i, ret := range rets {
if got, want := reflect.TypeOf(ret), mt.Out(i); got == want {
// Identical types; nothing to do.
} else if got == nil {
// Nil needs special handling.
switch want.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
// ok
default:
c.t.Fatalf("argument %d to Return for %T.%v is nil, but %v is not nillable [%s]",
i, c.receiver, c.method, want, c.origin)
}
} else if got.AssignableTo(want) {
// Assignable type relation. Make the assignment now so that the generated code
// can return the values with a type assertion.
v := reflect.New(want).Elem()
v.Set(reflect.ValueOf(ret))
rets[i] = v.Interface()
} else {
c.t.Fatalf("wrong type of argument %d to Return for %T.%v: %v is not assignable to %v [%s]",
i, c.receiver, c.method, got, want, c.origin)
}
}
c.addAction(func([]interface{}) []interface{} {
return rets
})
return c
}
// Times declares the exact number of times a function call is expected to be executed.
func (c *Call) Times(n int) *Call {
c.minCalls, c.maxCalls = n, n
return c
}
// SetArg declares an action that will set the nth argument's value,
// indirected through a pointer. Or, in the case of a slice, SetArg
// will copy value's elements into the nth argument.
func (c *Call) SetArg(n int, value interface{}) *Call {
if h, ok := c.t.(testHelper); ok {
h.Helper()
}
mt := c.methodType
// TODO: This will break on variadic methods.
// We will need to check those at invocation time.
if n < 0 || n >= mt.NumIn() {
c.t.Fatalf("SetArg(%d, ...) called for a method with %d args [%s]",
n, mt.NumIn(), c.origin)
}
// Permit setting argument through an interface.
// In the interface case, we don't (nay, can't) check the type here.
at := mt.In(n)
switch at.Kind() {
case reflect.Ptr:
dt := at.Elem()
if vt := reflect.TypeOf(value); !vt.AssignableTo(dt) {
c.t.Fatalf("SetArg(%d, ...) argument is a %v, not assignable to %v [%s]",
n, vt, dt, c.origin)
}
case reflect.Interface:
// nothing to do
case reflect.Slice:
// nothing to do
default:
c.t.Fatalf("SetArg(%d, ...) referring to argument of non-pointer non-interface non-slice type %v [%s]",
n, at, c.origin)
}
c.addAction(func(args []interface{}) []interface{} {
v := reflect.ValueOf(value)
switch reflect.TypeOf(args[n]).Kind() {
case reflect.Slice:
setSlice(args[n], v)
default:
reflect.ValueOf(args[n]).Elem().Set(v)
}
return nil
})
return c
}
// isPreReq returns true if other is a direct or indirect prerequisite to c.
func (c *Call) isPreReq(other *Call) bool {
for _, preReq := range c.preReqs {
if other == preReq || preReq.isPreReq(other) {
return true
}
}
return false
}
// After declares that the call may only match after preReq has been exhausted.
func (c *Call) After(preReq *Call) *Call {
if h, ok := c.t.(testHelper); ok {
h.Helper()
}
if c == preReq {
c.t.Fatalf("A call isn't allowed to be its own prerequisite")
}
if preReq.isPreReq(c) {
c.t.Fatalf("Loop in call order: %v is a prerequisite to %v (possibly indirectly).", c, preReq)
}
c.preReqs = append(c.preReqs, preReq)
return c
}
// Returns true if the minimum number of calls have been made.
func (c *Call) satisfied() bool {
return c.numCalls >= c.minCalls
}
// Returns true iff the maximum number of calls have been made.
func (c *Call) exhausted() bool {
return c.numCalls >= c.maxCalls
}
func (c *Call) String() string {
args := make([]string, len(c.args))
for i, arg := range c.args {
args[i] = arg.String()
}
arguments := strings.Join(args, ", ")
return fmt.Sprintf("%T.%v(%s) %s", c.receiver, c.method, arguments, c.origin)
}
// Tests if the given call matches the expected call.
// If yes, returns nil. If no, returns error with message explaining why it does not match.
func (c *Call) matches(args []interface{}) error {
if !c.methodType.IsVariadic() {
if len(args) != len(c.args) {
return fmt.Errorf("Expected call at %s has the wrong number of arguments. Got: %d, want: %d",
c.origin, len(args), len(c.args))
}
for i, m := range c.args {
if !m.Matches(args[i]) {
return fmt.Errorf("Expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v",
c.origin, strconv.Itoa(i), args[i], m)
}
}
} else {
if len(c.args) < c.methodType.NumIn()-1 {
return fmt.Errorf("Expected call at %s has the wrong number of matchers. Got: %d, want: %d",
c.origin, len(c.args), c.methodType.NumIn()-1)
}
if len(c.args) != c.methodType.NumIn() && len(args) != len(c.args) {
return fmt.Errorf("Expected call at %s has the wrong number of arguments. Got: %d, want: %d",
c.origin, len(args), len(c.args))
}
if len(args) < len(c.args)-1 {
return fmt.Errorf("Expected call at %s has the wrong number of arguments. Got: %d, want: greater than or equal to %d",
c.origin, len(args), len(c.args)-1)
}
for i, m := range c.args {
if i < c.methodType.NumIn()-1 {
// Non-variadic args
if !m.Matches(args[i]) {
return fmt.Errorf("Expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v",
c.origin, strconv.Itoa(i), args[i], m)
}
continue
}
// The last arg has a possibility of a variadic argument, so let it branch
// sample: Foo(a int, b int, c ...int)
if i < len(c.args) && i < len(args) {
if m.Matches(args[i]) {
// Got Foo(a, b, c) want Foo(matcherA, matcherB, gomock.Any())
// Got Foo(a, b, c) want Foo(matcherA, matcherB, someSliceMatcher)
// Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC)
// Got Foo(a, b) want Foo(matcherA, matcherB)
// Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD)
continue
}
}
// The number of actual args don't match the number of matchers,
// or the last matcher is a slice and the last arg is not.
// If this function still matches it is because the last matcher
// matches all the remaining arguments or the lack of any.
// Convert the remaining arguments, if any, into a slice of the
// expected type.
vargsType := c.methodType.In(c.methodType.NumIn() - 1)
vargs := reflect.MakeSlice(vargsType, 0, len(args)-i)
for _, arg := range args[i:] {
vargs = reflect.Append(vargs, reflect.ValueOf(arg))
}
if m.Matches(vargs.Interface()) {
// Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, gomock.Any())
// Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, someSliceMatcher)
// Got Foo(a, b) want Foo(matcherA, matcherB, gomock.Any())
// Got Foo(a, b) want Foo(matcherA, matcherB, someEmptySliceMatcher)
break
}
// Wrong number of matchers or not match. Fail.
// Got Foo(a, b) want Foo(matcherA, matcherB, matcherC, matcherD)
// Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC, matcherD)
// Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD, matcherE)
// Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, matcherC, matcherD)
// Got Foo(a, b, c) want Foo(matcherA, matcherB)
return fmt.Errorf("Expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v",
c.origin, strconv.Itoa(i), args[i:], c.args[i])
}
}
// Check that all prerequisite calls have been satisfied.
for _, preReqCall := range c.preReqs {
if !preReqCall.satisfied() {
return fmt.Errorf("Expected call at %s doesn't have a prerequisite call satisfied:\n%v\nshould be called before:\n%v",
c.origin, preReqCall, c)
}
}
// Check that the call is not exhausted.
if c.exhausted() {
return fmt.Errorf("Expected call at %s has already been called the max number of times.", c.origin)
}
return nil
}
// dropPrereqs tells the expected Call to not re-check prerequisite calls any
// longer, and to return its current set.
func (c *Call) dropPrereqs() (preReqs []*Call) {
preReqs = c.preReqs
c.preReqs = nil
return
}
func (c *Call) call(args []interface{}) []func([]interface{}) []interface{} {
c.numCalls++
return c.actions
}
// InOrder declares that the given calls should occur in order.
func InOrder(calls ...*Call) {
for i := 1; i < len(calls); i++ {
calls[i].After(calls[i-1])
}
}
func setSlice(arg interface{}, v reflect.Value) {
va := reflect.ValueOf(arg)
for i := 0; i < v.Len(); i++ {
va.Index(i).Set(v.Index(i))
}
}
func (c *Call) addAction(action func([]interface{}) []interface{}) {
c.actions = append(c.actions, action)
}

108
vendor/github.com/golang/mock/gomock/callset.go generated vendored Normal file
View file

@ -0,0 +1,108 @@
// Copyright 2011 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gomock
import (
"bytes"
"fmt"
)
// callSet represents a set of expected calls, indexed by receiver and method
// name.
type callSet struct {
// Calls that are still expected.
expected map[callSetKey][]*Call
// Calls that have been exhausted.
exhausted map[callSetKey][]*Call
}
// callSetKey is the key in the maps in callSet
type callSetKey struct {
receiver interface{}
fname string
}
func newCallSet() *callSet {
return &callSet{make(map[callSetKey][]*Call), make(map[callSetKey][]*Call)}
}
// Add adds a new expected call.
func (cs callSet) Add(call *Call) {
key := callSetKey{call.receiver, call.method}
m := cs.expected
if call.exhausted() {
m = cs.exhausted
}
m[key] = append(m[key], call)
}
// Remove removes an expected call.
func (cs callSet) Remove(call *Call) {
key := callSetKey{call.receiver, call.method}
calls := cs.expected[key]
for i, c := range calls {
if c == call {
// maintain order for remaining calls
cs.expected[key] = append(calls[:i], calls[i+1:]...)
cs.exhausted[key] = append(cs.exhausted[key], call)
break
}
}
}
// FindMatch searches for a matching call. Returns error with explanation message if no call matched.
func (cs callSet) FindMatch(receiver interface{}, method string, args []interface{}) (*Call, error) {
key := callSetKey{receiver, method}
// Search through the expected calls.
expected := cs.expected[key]
var callsErrors bytes.Buffer
for _, call := range expected {
err := call.matches(args)
if err != nil {
fmt.Fprintf(&callsErrors, "\n%v", err)
} else {
return call, nil
}
}
// If we haven't found a match then search through the exhausted calls so we
// get useful error messages.
exhausted := cs.exhausted[key]
for _, call := range exhausted {
if err := call.matches(args); err != nil {
fmt.Fprintf(&callsErrors, "\n%v", err)
}
}
if len(expected)+len(exhausted) == 0 {
fmt.Fprintf(&callsErrors, "there are no expected calls of the method %q for that receiver", method)
}
return nil, fmt.Errorf(callsErrors.String())
}
// Failures returns the calls that are not satisfied.
func (cs callSet) Failures() []*Call {
failures := make([]*Call, 0, len(cs.expected))
for _, calls := range cs.expected {
for _, call := range calls {
if !call.satisfied() {
failures = append(failures, call)
}
}
}
return failures
}

217
vendor/github.com/golang/mock/gomock/controller.go generated vendored Normal file
View file

@ -0,0 +1,217 @@
// Copyright 2010 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// GoMock - a mock framework for Go.
//
// Standard usage:
// (1) Define an interface that you wish to mock.
// type MyInterface interface {
// SomeMethod(x int64, y string)
// }
// (2) Use mockgen to generate a mock from the interface.
// (3) Use the mock in a test:
// func TestMyThing(t *testing.T) {
// mockCtrl := gomock.NewController(t)
// defer mockCtrl.Finish()
//
// mockObj := something.NewMockMyInterface(mockCtrl)
// mockObj.EXPECT().SomeMethod(4, "blah")
// // pass mockObj to a real object and play with it.
// }
//
// By default, expected calls are not enforced to run in any particular order.
// Call order dependency can be enforced by use of InOrder and/or Call.After.
// Call.After can create more varied call order dependencies, but InOrder is
// often more convenient.
//
// The following examples create equivalent call order dependencies.
//
// Example of using Call.After to chain expected call order:
//
// firstCall := mockObj.EXPECT().SomeMethod(1, "first")
// secondCall := mockObj.EXPECT().SomeMethod(2, "second").After(firstCall)
// mockObj.EXPECT().SomeMethod(3, "third").After(secondCall)
//
// Example of using InOrder to declare expected call order:
//
// gomock.InOrder(
// mockObj.EXPECT().SomeMethod(1, "first"),
// mockObj.EXPECT().SomeMethod(2, "second"),
// mockObj.EXPECT().SomeMethod(3, "third"),
// )
//
// TODO:
// - Handle different argument/return types (e.g. ..., chan, map, interface).
package gomock
import (
"fmt"
"golang.org/x/net/context"
"reflect"
"runtime"
"sync"
)
// A TestReporter is something that can be used to report test failures.
// It is satisfied by the standard library's *testing.T.
type TestReporter interface {
Errorf(format string, args ...interface{})
Fatalf(format string, args ...interface{})
}
// A Controller represents the top-level control of a mock ecosystem.
// It defines the scope and lifetime of mock objects, as well as their expectations.
// It is safe to call Controller's methods from multiple goroutines.
type Controller struct {
mu sync.Mutex
t TestReporter
expectedCalls *callSet
finished bool
}
func NewController(t TestReporter) *Controller {
return &Controller{
t: t,
expectedCalls: newCallSet(),
}
}
type cancelReporter struct {
t TestReporter
cancel func()
}
func (r *cancelReporter) Errorf(format string, args ...interface{}) { r.t.Errorf(format, args...) }
func (r *cancelReporter) Fatalf(format string, args ...interface{}) {
defer r.cancel()
r.t.Fatalf(format, args...)
}
// WithContext returns a new Controller and a Context, which is cancelled on any
// fatal failure.
func WithContext(ctx context.Context, t TestReporter) (*Controller, context.Context) {
ctx, cancel := context.WithCancel(ctx)
return NewController(&cancelReporter{t, cancel}), ctx
}
func (ctrl *Controller) RecordCall(receiver interface{}, method string, args ...interface{}) *Call {
if h, ok := ctrl.t.(testHelper); ok {
h.Helper()
}
recv := reflect.ValueOf(receiver)
for i := 0; i < recv.Type().NumMethod(); i++ {
if recv.Type().Method(i).Name == method {
return ctrl.RecordCallWithMethodType(receiver, method, recv.Method(i).Type(), args...)
}
}
ctrl.t.Fatalf("gomock: failed finding method %s on %T", method, receiver)
panic("unreachable")
}
func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {
if h, ok := ctrl.t.(testHelper); ok {
h.Helper()
}
call := newCall(ctrl.t, receiver, method, methodType, args...)
ctrl.mu.Lock()
defer ctrl.mu.Unlock()
ctrl.expectedCalls.Add(call)
return call
}
func (ctrl *Controller) Call(receiver interface{}, method string, args ...interface{}) []interface{} {
if h, ok := ctrl.t.(testHelper); ok {
h.Helper()
}
// Nest this code so we can use defer to make sure the lock is released.
actions := func() []func([]interface{}) []interface{} {
ctrl.mu.Lock()
defer ctrl.mu.Unlock()
expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args)
if err != nil {
origin := callerInfo(2)
ctrl.t.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, args, origin, err)
}
// Two things happen here:
// * the matching call no longer needs to check prerequite calls,
// * and the prerequite calls are no longer expected, so remove them.
preReqCalls := expected.dropPrereqs()
for _, preReqCall := range preReqCalls {
ctrl.expectedCalls.Remove(preReqCall)
}
actions := expected.call(args)
if expected.exhausted() {
ctrl.expectedCalls.Remove(expected)
}
return actions
}()
var rets []interface{}
for _, action := range actions {
if r := action(args); r != nil {
rets = r
}
}
return rets
}
func (ctrl *Controller) Finish() {
if h, ok := ctrl.t.(testHelper); ok {
h.Helper()
}
ctrl.mu.Lock()
defer ctrl.mu.Unlock()
if ctrl.finished {
ctrl.t.Fatalf("Controller.Finish was called more than once. It has to be called exactly once.")
}
ctrl.finished = true
// If we're currently panicking, probably because this is a deferred call,
// pass through the panic.
if err := recover(); err != nil {
panic(err)
}
// Check that all remaining expected calls are satisfied.
failures := ctrl.expectedCalls.Failures()
for _, call := range failures {
ctrl.t.Errorf("missing call(s) to %v", call)
}
if len(failures) != 0 {
ctrl.t.Fatalf("aborting test due to missing call(s)")
}
}
func callerInfo(skip int) string {
if _, file, line, ok := runtime.Caller(skip + 1); ok {
return fmt.Sprintf("%s:%d", file, line)
}
return "unknown file"
}
type testHelper interface {
TestReporter
Helper()
}

124
vendor/github.com/golang/mock/gomock/matchers.go generated vendored Normal file
View file

@ -0,0 +1,124 @@
//go:generate mockgen -destination mock_matcher/mock_matcher.go github.com/golang/mock/gomock Matcher
// Copyright 2010 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gomock
import (
"fmt"
"reflect"
)
// A Matcher is a representation of a class of values.
// It is used to represent the valid or expected arguments to a mocked method.
type Matcher interface {
// Matches returns whether x is a match.
Matches(x interface{}) bool
// String describes what the matcher matches.
String() string
}
type anyMatcher struct{}
func (anyMatcher) Matches(x interface{}) bool {
return true
}
func (anyMatcher) String() string {
return "is anything"
}
type eqMatcher struct {
x interface{}
}
func (e eqMatcher) Matches(x interface{}) bool {
return reflect.DeepEqual(e.x, x)
}
func (e eqMatcher) String() string {
return fmt.Sprintf("is equal to %v", e.x)
}
type nilMatcher struct{}
func (nilMatcher) Matches(x interface{}) bool {
if x == nil {
return true
}
v := reflect.ValueOf(x)
switch v.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
reflect.Ptr, reflect.Slice:
return v.IsNil()
}
return false
}
func (nilMatcher) String() string {
return "is nil"
}
type notMatcher struct {
m Matcher
}
func (n notMatcher) Matches(x interface{}) bool {
return !n.m.Matches(x)
}
func (n notMatcher) String() string {
// TODO: Improve this if we add a NotString method to the Matcher interface.
return "not(" + n.m.String() + ")"
}
type assignableToTypeOfMatcher struct {
targetType reflect.Type
}
func (m assignableToTypeOfMatcher) Matches(x interface{}) bool {
return reflect.TypeOf(x).AssignableTo(m.targetType)
}
func (m assignableToTypeOfMatcher) String() string {
return "is assignable to " + m.targetType.Name()
}
// Constructors
func Any() Matcher { return anyMatcher{} }
func Eq(x interface{}) Matcher { return eqMatcher{x} }
func Nil() Matcher { return nilMatcher{} }
func Not(x interface{}) Matcher {
if m, ok := x.(Matcher); ok {
return notMatcher{m}
}
return notMatcher{Eq(x)}
}
// AssignableToTypeOf is a Matcher that matches if the parameter to the mock
// function is assignable to the type of the parameter to this function.
//
// Example usage:
//
// dbMock.EXPECT().
// Insert(gomock.AssignableToTypeOf(&EmployeeRecord{})).
// Return(errors.New("DB error"))
//
func AssignableToTypeOf(x interface{}) Matcher {
return assignableToTypeOfMatcher{reflect.TypeOf(x)}
}

View file

@ -0,0 +1,57 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/golang/mock/gomock (interfaces: Matcher)
// Package mock_gomock is a generated GoMock package.
package mock_gomock
import (
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockMatcher is a mock of Matcher interface
type MockMatcher struct {
ctrl *gomock.Controller
recorder *MockMatcherMockRecorder
}
// MockMatcherMockRecorder is the mock recorder for MockMatcher
type MockMatcherMockRecorder struct {
mock *MockMatcher
}
// NewMockMatcher creates a new mock instance
func NewMockMatcher(ctrl *gomock.Controller) *MockMatcher {
mock := &MockMatcher{ctrl: ctrl}
mock.recorder = &MockMatcherMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockMatcher) EXPECT() *MockMatcherMockRecorder {
return m.recorder
}
// Matches mocks base method
func (m *MockMatcher) Matches(arg0 interface{}) bool {
ret := m.ctrl.Call(m, "Matches", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// Matches indicates an expected call of Matches
func (mr *MockMatcherMockRecorder) Matches(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Matches", reflect.TypeOf((*MockMatcher)(nil).Matches), arg0)
}
// String mocks base method
func (m *MockMatcher) String() string {
ret := m.ctrl.Call(m, "String")
ret0, _ := ret[0].(string)
return ret0
}
// String indicates an expected call of String
func (mr *MockMatcherMockRecorder) String() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*MockMatcher)(nil).String))
}

View file

@ -1,362 +0,0 @@
Mozilla Public License, version 2.0
1. Definitions
1.1. "Contributor"
means each individual or legal entity that creates, contributes to the
creation of, or owns Covered Software.
1.2. "Contributor Version"
means the combination of the Contributions of others (if any) used by a
Contributor and that particular Contributor's Contribution.
1.3. "Contribution"
means Covered Software of a particular Contributor.
1.4. "Covered Software"
means Source Code Form to which the initial Contributor has attached the
notice in Exhibit A, the Executable Form of such Source Code Form, and
Modifications of such Source Code Form, in each case including portions
thereof.
1.5. "Incompatible With Secondary Licenses"
means
a. that the initial Contributor has attached the notice described in
Exhibit B to the Covered Software; or
b. that the Covered Software was made available under the terms of
version 1.1 or earlier of the License, but not also under the terms of
a Secondary License.
1.6. "Executable Form"
means any form of the work other than Source Code Form.
1.7. "Larger Work"
means a work that combines Covered Software with other material, in a
separate file or files, that is not Covered Software.
1.8. "License"
means this document.
1.9. "Licensable"
means having the right to grant, to the maximum extent possible, whether
at the time of the initial grant or subsequently, any and all of the
rights conveyed by this License.
1.10. "Modifications"
means any of the following:
a. any file in Source Code Form that results from an addition to,
deletion from, or modification of the contents of Covered Software; or
b. any new file in Source Code Form that contains any Covered Software.
1.11. "Patent Claims" of a Contributor
means any patent claim(s), including without limitation, method,
process, and apparatus claims, in any patent Licensable by such
Contributor that would be infringed, but for the grant of the License,
by the making, using, selling, offering for sale, having made, import,
or transfer of either its Contributions or its Contributor Version.
1.12. "Secondary License"
means either the GNU General Public License, Version 2.0, the GNU Lesser
General Public License, Version 2.1, the GNU Affero General Public
License, Version 3.0, or any later versions of those licenses.
1.13. "Source Code Form"
means the form of the work preferred for making modifications.
1.14. "You" (or "Your")
means an individual or a legal entity exercising rights under this
License. For legal entities, "You" includes any entity that controls, is
controlled by, or is under common control with You. For purposes of this
definition, "control" means (a) the power, direct or indirect, to cause
the direction or management of such entity, whether by contract or
otherwise, or (b) ownership of more than fifty percent (50%) of the
outstanding shares or beneficial ownership of such entity.
2. License Grants and Conditions
2.1. Grants
Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:
a. under intellectual property rights (other than patent or trademark)
Licensable by such Contributor to use, reproduce, make available,
modify, display, perform, distribute, and otherwise exploit its
Contributions, either on an unmodified basis, with Modifications, or
as part of a Larger Work; and
b. under Patent Claims of such Contributor to make, use, sell, offer for
sale, have made, import, and otherwise transfer either its
Contributions or its Contributor Version.
2.2. Effective Date
The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.
2.3. Limitations on Grant Scope
The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:
a. for any code that a Contributor has removed from Covered Software; or
b. for infringements caused by: (i) Your and any other third party's
modifications of Covered Software, or (ii) the combination of its
Contributions with other software (except as part of its Contributor
Version); or
c. under Patent Claims infringed by Covered Software in the absence of
its Contributions.
This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).
2.4. Subsequent Licenses
No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).
2.5. Representation
Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights to
grant the rights to its Contributions conveyed by this License.
2.6. Fair Use
This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.
2.7. Conditions
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in
Section 2.1.
3. Responsibilities
3.1. Distribution of Source Form
All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.
3.2. Distribution of Executable Form
If You distribute Covered Software in Executable Form then:
a. such Covered Software must also be made available in Source Code Form,
as described in Section 3.1, and You must inform recipients of the
Executable Form how they can obtain a copy of such Source Code Form by
reasonable means in a timely manner, at a charge no more than the cost
of distribution to the recipient; and
b. You may distribute such Executable Form under the terms of this
License, or sublicense it under different terms, provided that the
license for the Executable Form does not attempt to limit or alter the
recipients' rights in the Source Code Form under this License.
3.3. Distribution of a Larger Work
You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).
3.4. Notices
You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty, or
limitations of liability) contained within the Source Code Form of the
Covered Software, except that You may alter any license notices to the
extent required to remedy known factual inaccuracies.
3.5. Application of Additional Terms
You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.
4. Inability to Comply Due to Statute or Regulation
If it is impossible for You to comply with any of the terms of this License
with respect to some or all of the Covered Software due to statute,
judicial order, or regulation then You must: (a) comply with the terms of
this License to the maximum extent possible; and (b) describe the
limitations and the code they affect. Such description must be placed in a
text file included with all distributions of the Covered Software under
this License. Except to the extent prohibited by statute or regulation,
such description must be sufficiently detailed for a recipient of ordinary
skill to be able to understand it.
5. Termination
5.1. The rights granted under this License will terminate automatically if You
fail to comply with any of its terms. However, if You become compliant,
then the rights granted under this License from a particular Contributor
are reinstated (a) provisionally, unless and until such Contributor
explicitly and finally terminates Your grants, and (b) on an ongoing
basis, if such Contributor fails to notify You of the non-compliance by
some reasonable means prior to 60 days after You have come back into
compliance. Moreover, Your grants from a particular Contributor are
reinstated on an ongoing basis if such Contributor notifies You of the
non-compliance by some reasonable means, this is the first time You have
received notice of non-compliance with this License from such
Contributor, and You become compliant prior to 30 days after Your receipt
of the notice.
5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.
5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user
license agreements (excluding distributors and resellers) which have been
validly granted by You or Your distributors under this License prior to
termination shall survive termination.
6. Disclaimer of Warranty
Covered Software is provided under this License on an "as is" basis,
without warranty of any kind, either expressed, implied, or statutory,
including, without limitation, warranties that the Covered Software is free
of defects, merchantable, fit for a particular purpose or non-infringing.
The entire risk as to the quality and performance of the Covered Software
is with You. Should any Covered Software prove defective in any respect,
You (not any Contributor) assume the cost of any necessary servicing,
repair, or correction. This disclaimer of warranty constitutes an essential
part of this License. No use of any Covered Software is authorized under
this License except under this disclaimer.
7. Limitation of Liability
Under no circumstances and under no legal theory, whether tort (including
negligence), contract, or otherwise, shall any Contributor, or anyone who
distributes Covered Software as permitted above, be liable to You for any
direct, indirect, special, incidental, or consequential damages of any
character including, without limitation, damages for lost profits, loss of
goodwill, work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses, even if such party shall have been
informed of the possibility of such damages. This limitation of liability
shall not apply to liability for death or personal injury resulting from
such party's negligence to the extent applicable law prohibits such
limitation. Some jurisdictions do not allow the exclusion or limitation of
incidental or consequential damages, so this exclusion and limitation may
not apply to You.
8. Litigation
Any litigation relating to this License may be brought only in the courts
of a jurisdiction where the defendant maintains its principal place of
business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions. Nothing
in this Section shall prevent a party's ability to bring cross-claims or
counter-claims.
9. Miscellaneous
This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides that
the language of a contract shall be construed against the drafter shall not
be used to construe this License against a Contributor.
10. Versions of the License
10.1. New Versions
Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.
10.2. Effect of New Versions
You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.
10.3. Modified Versions
If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).
10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses If You choose to distribute Source Code Form that is
Incompatible With Secondary Licenses under the terms of this version of
the License, the notice described in Exhibit B of this License must be
attached.
Exhibit A - Source Code Form License Notice
This Source Code Form is subject to the
terms of the Mozilla Public License, v.
2.0. If a copy of the MPL was not
distributed with this file, You can
obtain one at
http://mozilla.org/MPL/2.0/.
If it is not possible or desirable to put the notice in a particular file,
then You may include the notice in a location (such as a LICENSE file in a
relevant directory) where a recipient would be likely to look for such a
notice.
You may add additional accurate notices of copyright ownership.
Exhibit B - "Incompatible With Secondary Licenses" Notice
This Source Code Form is "Incompatible
With Secondary Licenses", as defined by
the Mozilla Public License, v. 2.0.

View file

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2016 Lucas Clemente
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -1,21 +0,0 @@
The MIT License (MIT)
Copyright (c) 2016 Lucas Clemente
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -1,87 +0,0 @@
// Package fnv128a implements FNV-1 and FNV-1a, non-cryptographic hash functions
// created by Glenn Fowler, Landon Curt Noll, and Phong Vo.
// See https://en.wikipedia.org/wiki/Fowler-Noll-Vo_hash_function.
//
// Write() algorithm taken and modified from github.com/romain-jacotin/quic
package fnv128a
import "hash"
// Hash128 is the common interface implemented by all 128-bit hash functions.
type Hash128 interface {
hash.Hash
Sum128() (uint64, uint64)
}
type sum128a struct {
v0, v1, v2, v3 uint64
}
var _ Hash128 = &sum128a{}
// New1 returns a new 128-bit FNV-1a hash.Hash.
func New() Hash128 {
s := &sum128a{}
s.Reset()
return s
}
func (s *sum128a) Reset() {
s.v0 = 0x6295C58D
s.v1 = 0x62B82175
s.v2 = 0x07BB0142
s.v3 = 0x6C62272E
}
func (s *sum128a) Sum128() (uint64, uint64) {
return s.v3<<32 | s.v2, s.v1<<32 | s.v0
}
func (s *sum128a) Write(data []byte) (int, error) {
var t0, t1, t2, t3 uint64
const fnv128PrimeLow = 0x0000013B
const fnv128PrimeShift = 24
for _, v := range data {
// xor the bottom with the current octet
s.v0 ^= uint64(v)
// multiply by the 128 bit FNV magic prime mod 2^128
// fnv_prime = 309485009821345068724781371 (decimal)
// = 0x0000000001000000000000000000013B (hexadecimal)
// = 0x00000000 0x01000000 0x00000000 0x0000013B (in 4*32 words)
// = 0x0 1<<fnv128PrimeShift 0x0 fnv128PrimeLow
//
// fnv128PrimeLow = 0x0000013B
// fnv128PrimeShift = 24
// multiply by the lowest order digit base 2^32 and by the other non-zero digit
t0 = s.v0 * fnv128PrimeLow
t1 = s.v1 * fnv128PrimeLow
t2 = s.v2*fnv128PrimeLow + s.v0<<fnv128PrimeShift
t3 = s.v3*fnv128PrimeLow + s.v1<<fnv128PrimeShift
// propagate carries
t1 += (t0 >> 32)
t2 += (t1 >> 32)
t3 += (t2 >> 32)
s.v0 = t0 & 0xffffffff
s.v1 = t1 & 0xffffffff
s.v2 = t2 & 0xffffffff
s.v3 = t3 // & 0xffffffff
// Doing a s.v3 &= 0xffffffff is not really needed since it simply
// removes multiples of 2^128. We can discard these excess bits
// outside of the loop when writing the hash in Little Endian.
}
return len(data), nil
}
func (s *sum128a) Size() int { return 16 }
func (s *sum128a) BlockSize() int { return 1 }
func (s *sum128a) Sum(in []byte) []byte {
panic("FNV: not supported")
}

View file

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2016 Lucas Clemente
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -2,14 +2,14 @@ package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
@ -20,37 +20,68 @@ import (
type client struct {
mutex sync.Mutex
conn connection
conn connection
// If the client is created with DialAddr, we create a packet conn.
// If it is started with Dial, we take a packet conn as a parameter.
createdPacketConn bool
hostname string
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
versionNegotiated bool // has the server accepted our version
packetHandlers packetHandlerManager
token []byte
numRetries int
versionNegotiated bool // has the server accepted our version
receivedVersionNegotiationPacket bool
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
tlsConf *tls.Config
config *Config
tls handshake.MintTLS // only used when using TLS
tlsConf *tls.Config
mintConf *mint.Config
config *Config
connectionID protocol.ConnectionID
srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
initialVersion protocol.VersionNumber
version protocol.VersionNumber
session packetHandler
handshakeChan chan struct{}
closeCallback func(protocol.ConnectionID)
session quicSession
logger utils.Logger
}
var _ packetHandler = &client{}
var (
// make it possible to mock connection ID generation in the tests
generateConnectionID = utils.GenerateConnectionID
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
generateConnectionID = protocol.GenerateConnectionID
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
errCloseSessionForRetry = errors.New("closing session in response to a stateless retry")
)
// DialAddr establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) {
func DialAddr(
addr string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
return DialAddrContext(context.Background(), addr, tlsConf, config)
}
// DialAddrContext establishes a new QUIC connection to a server using the provided context.
// The hostname for SNI is taken from the given address.
func DialAddrContext(
ctx context.Context,
addr string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
@ -59,7 +90,7 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
if err != nil {
return nil, err
}
return Dial(udpConn, udpAddr, addr, tlsConf, config)
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true)
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
@ -71,16 +102,69 @@ func Dial(
tlsConf *tls.Config,
config *Config,
) (Session, error) {
connID, err := generateConnectionID()
return DialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
}
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
// The host parameter is used for SNI.
func DialContext(
ctx context.Context,
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false)
}
func dialContext(
ctx context.Context,
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
createdPacketConn bool,
) (Session, error) {
config = populateClientConfig(config, createdPacketConn)
if !createdPacketConn {
for _, v := range config.Versions {
if v == protocol.Version44 {
return nil, errors.New("Cannot multiplex connections using gQUIC 44, see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/pE9NlLLjizE. Please disable gQUIC 44 in the quic.Config, or use DialAddr")
}
}
}
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
if err != nil {
return nil, err
}
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, packetHandlers.Remove, createdPacketConn)
if err != nil {
return nil, err
}
c.packetHandlers = packetHandlers
if err := c.dial(ctx); err != nil {
return nil, err
}
return c.session, nil
}
func newClient(
pconn net.PacketConn,
remoteAddr net.Addr,
config *Config,
tlsConf *tls.Config,
host string,
closeCallback func(protocol.ConnectionID),
createdPacketConn bool,
) (*client, error) {
var hostname string
if tlsConf != nil {
hostname = tlsConf.ServerName
}
if hostname == "" {
var err error
hostname, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
@ -95,29 +179,27 @@ func Dial(
}
}
}
clientConfig := populateClientConfig(config)
onClose := func(protocol.ConnectionID) {}
if closeCallback != nil {
onClose = closeCallback
}
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
version: clientConfig.Versions[0],
versionNegotiationChan: make(chan struct{}),
logger: utils.DefaultLogger,
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
createdPacketConn: createdPacketConn,
hostname: hostname,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
closeCallback: onClose,
logger: utils.DefaultLogger.WithPrefix("client"),
}
c.logger.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
if err := c.dial(); err != nil {
return nil, err
}
return c.session, nil
return c, c.generateConnectionIDs()
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config) *Config {
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
if config == nil {
config = &Config{}
}
@ -155,12 +237,22 @@ func populateClientConfig(config *Config) *Config {
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
connIDLen := config.ConnectionIDLength
if connIDLen == 0 && !createdPacketConn {
connIDLen = protocol.DefaultConnectionIDLength
}
for _, v := range versions {
if v == protocol.Version44 {
connIDLen = 0
}
}
return &Config{
Versions: versions,
HandshakeTimeout: handshakeTimeout,
IdleTimeout: idleTimeout,
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
ConnectionIDLength: connIDLen,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
MaxIncomingStreams: maxIncomingStreams,
@ -169,28 +261,54 @@ func populateClientConfig(config *Config) *Config {
}
}
func (c *client) dial() error {
func (c *client) generateConnectionIDs() error {
connIDLen := protocol.ConnectionIDLenGQUIC
if c.version.UsesTLS() {
connIDLen = c.config.ConnectionIDLength
}
srcConnID, err := generateConnectionID(connIDLen)
if err != nil {
return err
}
destConnID := srcConnID
if c.version.UsesTLS() {
destConnID, err = generateConnectionIDForInitial()
if err != nil {
return err
}
}
c.srcConnID = srcConnID
c.destConnID = destConnID
if c.version == protocol.Version44 {
c.srcConnID = nil
}
return nil
}
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
var err error
if c.version.UsesTLS() {
err = c.dialTLS()
err = c.dialTLS(ctx)
} else {
err = c.dialGQUIC()
}
if err == errCloseSessionForNewVersion {
return c.dial()
err = c.dialGQUIC(ctx)
}
return err
}
func (c *client) dialGQUIC() error {
func (c *client) dialGQUIC(ctx context.Context) error {
if err := c.createNewGQUICSession(); err != nil {
return err
}
go c.listen()
return c.establishSecureConnection()
err := c.establishSecureConnection(ctx)
if err == errCloseSessionForNewVersion {
return c.dial(ctx)
}
return err
}
func (c *client) dialTLS() error {
func (c *client) dialTLS(ctx context.Context) error {
params := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
@ -198,8 +316,8 @@ func (c *client) dialTLS() error {
OmitConnectionID: c.config.RequestConnectionIDOmission,
MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
DisableMigration: true,
}
csc := handshake.NewCryptoStreamConn(nil)
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
if err != nil {
@ -207,25 +325,16 @@ func (c *client) dialTLS() error {
}
mintConf.ExtensionHandler = extHandler
mintConf.ServerName = c.hostname
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
c.mintConf = mintConf
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err
}
go c.listen()
if err := c.establishSecureConnection(); err != nil {
if err != handshake.ErrCloseSessionForRetry {
return err
}
c.logger.Infof("Received a Retry packet. Recreating session.")
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err
}
if err := c.establishSecureConnection(); err != nil {
return err
}
err = c.establishSecureConnection(ctx)
if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
return c.dial(ctx)
}
return nil
return err
}
// establishSecureConnection runs the session, and tries to establish a secure connection
@ -234,134 +343,114 @@ func (c *client) dialTLS() error {
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
// - any other error that might occur
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
func (c *client) establishSecureConnection() error {
var runErr error
errorChan := make(chan struct{})
func (c *client) establishSecureConnection(ctx context.Context) error {
errorChan := make(chan error, 1)
go func() {
runErr = c.session.run() // returns as soon as the session is closed
close(errorChan)
c.logger.Infof("Connection %x closed.", c.connectionID)
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
err := c.session.run() // returns as soon as the session is closed
if err != errCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
c.conn.Close()
}
errorChan <- err
}()
// wait until the server accepts the QUIC version (or an error occurs)
select {
case <-errorChan:
return runErr
case <-c.versionNegotiationChan:
}
select {
case <-errorChan:
return runErr
case err := <-c.session.handshakeStatus():
case <-ctx.Done():
// The session will send a PeerGoingAway error to the server.
c.session.Close()
return ctx.Err()
case err := <-errorChan:
return err
case <-c.handshakeChan:
// handshake successfully completed
return nil
}
}
// Listen listens on the underlying connection and passes packets on for handling.
// It returns when the connection is closed.
func (c *client) listen() {
var err error
for {
var n int
var addr net.Addr
data := *getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize]
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable
n, addr, err = c.conn.Read(data)
if err != nil {
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
c.mutex.Lock()
if c.session != nil {
c.session.Close(err)
}
c.mutex.Unlock()
}
break
}
c.handlePacket(addr, data[:n])
func (c *client) handlePacket(p *receivedPacket) {
if err := c.handlePacketImpl(p); err != nil {
c.logger.Errorf("error handling packet: %s", err)
}
}
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
rcvTime := time.Now()
r := bytes.NewReader(packet)
hdr, err := wire.ParseHeaderSentByServer(r, c.version)
if err != nil {
c.logger.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
// drop this packet if we can't parse the header
return
}
// reject packets with truncated connection id if we didn't request truncation
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
return
}
hdr.Raw = packet[:len(packet)-r.Len()]
func (c *client) handlePacketImpl(p *receivedPacket) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// reject packets with the wrong connection ID
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
return
}
if hdr.ResetFlag {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID {
c.logger.Infof("Received a spoofed Public Reset. Ignoring.")
return
}
pr, err := wire.ParsePublicReset(r)
if err != nil {
c.logger.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
return
}
c.logger.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
return
}
// handle Version Negotiation Packets
if hdr.IsVersionNegotiation {
// ignore delayed / duplicated version negotiation packets
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
return
if p.header.IsVersionNegotiation {
err := c.handleVersionNegotiationPacket(p.header)
if err != nil {
c.session.destroy(err)
}
// version negotiation packets have no payload
if err := c.handleVersionNegotiationPacket(hdr); err != nil {
c.session.Close(err)
return err
}
if !c.version.UsesIETFHeaderFormat() {
connID := p.header.DestConnectionID
// reject packets with truncated connection id if we didn't request truncation
if !c.config.RequestConnectionIDOmission && connID.Len() == 0 {
return errors.New("received packet with truncated connection ID, but didn't request truncation")
}
// reject packets with the wrong connection ID
if connID.Len() > 0 && !connID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", connID, c.srcConnID)
}
if p.header.ResetFlag {
return c.handlePublicReset(p)
}
} else {
// reject packets with the wrong connection ID
if !p.header.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
}
}
if p.header.IsLongHeader {
switch p.header.Type {
case protocol.PacketTypeRetry:
c.handleRetryPacket(p.header)
return nil
case protocol.PacketTypeHandshake, protocol.PacketType0RTT:
default:
return fmt.Errorf("Received unsupported packet type: %s", p.header.Type)
}
return
}
// this is the first packet we are receiving
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
if !c.versionNegotiated {
c.versionNegotiated = true
close(c.versionNegotiationChan)
}
// TODO: validate packet number and connection ID on Retry packets (for IETF QUIC)
c.session.handlePacket(p)
return nil
}
c.session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr,
header: hdr,
data: packet[len(packet)-r.Len():],
rcvTime: rcvTime,
})
func (c *client) handlePublicReset(p *receivedPacket) error {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !p.header.DestConnectionID.Equal(c.srcConnID) {
return errors.New("Received a spoofed Public Reset")
}
pr, err := wire.ParsePublicReset(bytes.NewReader(p.data))
if err != nil {
return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err)
}
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber)
return nil
}
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
// ignore delayed / duplicated version negotiation packets
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
c.logger.Debugf("Received a delayed Version Negotiation Packet.")
return nil
}
for _, v := range hdr.SupportedVersions {
if v == c.version {
// the version negotiation packet contains the version that we offered
@ -372,7 +461,6 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
}
c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok {
return qerr.InvalidVersion
@ -383,49 +471,125 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
// switch to negotiated version
c.initialVersion = c.version
c.version = newVersion
var err error
c.connectionID, err = utils.GenerateConnectionID()
if err != nil {
if err := c.generateConnectionIDs(); err != nil {
return err
}
c.logger.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
c.session.Close(errCloseSessionForNewVersion)
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.session.destroy(errCloseSessionForNewVersion)
return nil
}
func (c *client) createNewGQUICSession() (err error) {
func (c *client) handleRetryPacket(hdr *wire.Header) {
c.logger.Debugf("<- Received Retry")
hdr.Log(c.logger)
// A server that performs multiple retries must use a source connection ID of at least 8 bytes.
// Only a server that won't send additional Retries can use shorter connection IDs.
if hdr.OrigDestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
c.logger.Debugf("Received a Retry with a too short Original Destination Connection ID: %d bytes, must have at least %d bytes.", hdr.OrigDestConnectionID.Len(), protocol.MinConnectionIDLenInitial)
return
}
if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
c.logger.Debugf("Received spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
return
}
c.numRetries++
if c.numRetries > protocol.MaxRetries {
c.session.destroy(qerr.CryptoTooManyRejects)
return
}
c.destConnID = hdr.SrcConnectionID
c.token = hdr.Token
c.session.destroy(errCloseSessionForRetry)
}
func (c *client) createNewGQUICSession() error {
c.mutex.Lock()
defer c.mutex.Unlock()
c.session, err = newClientSession(
runner := &runner{
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
removeConnectionIDImpl: c.closeCallback,
}
sess, err := newClientSession(
c.conn,
runner,
c.hostname,
c.version,
c.connectionID,
c.destConnID,
c.srcConnID,
c.tlsConf,
c.config,
c.initialVersion,
c.negotiatedVersions,
c.logger,
)
return err
if err != nil {
return err
}
c.session = sess
c.packetHandlers.Add(c.srcConnID, c)
if c.config.RequestConnectionIDOmission {
c.packetHandlers.Add(protocol.ConnectionID{}, c)
}
return nil
}
func (c *client) createNewTLSSession(
paramsChan <-chan handshake.TransportParameters,
version protocol.VersionNumber,
) (err error) {
) error {
c.mutex.Lock()
defer c.mutex.Unlock()
c.session, err = newTLSClientSession(
runner := &runner{
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
removeConnectionIDImpl: c.closeCallback,
}
sess, err := newTLSClientSession(
c.conn,
c.hostname,
c.version,
c.connectionID,
runner,
c.token,
c.destConnID,
c.srcConnID,
c.config,
c.tls,
c.mintConf,
paramsChan,
1,
c.logger,
c.version,
)
return err
if err != nil {
return err
}
c.session = sess
c.packetHandlers.Add(c.srcConnID, c)
return nil
}
func (c *client) Close() error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.session == nil {
return nil
}
return c.session.Close()
}
func (c *client) destroy(e error) {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.session == nil {
return
}
c.session.destroy(e)
}
func (c *client) GetVersion() protocol.VersionNumber {
c.mutex.Lock()
v := c.version
c.mutex.Unlock()
return v
}
func (c *client) GetPerspective() protocol.Perspective {
return protocol.PerspectiveClient
}

View file

@ -8,7 +8,7 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire"
)
type cryptoStreamI interface {
type cryptoStream interface {
StreamID() protocol.StreamID
io.Reader
io.Writer
@ -21,21 +21,21 @@ type cryptoStreamI interface {
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
}
type cryptoStream struct {
type cryptoStreamImpl struct {
*stream
}
var _ cryptoStreamI = &cryptoStream{}
var _ cryptoStream = &cryptoStreamImpl{}
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI {
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStream {
str := newStream(version.CryptoStreamID(), sender, flowController, version)
return &cryptoStream{str}
return &cryptoStreamImpl{str}
}
// SetReadOffset sets the read offset.
// It is only needed for the crypto stream.
// It must not be called concurrently with any other stream methods, especially Read and Write.
func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
func (s *cryptoStreamImpl) setReadOffset(offset protocol.ByteCount) {
s.receiveStream.readOffset = offset
s.receiveStream.frameQueue.readPosition = offset
s.receiveStream.frameQueue.readPos = offset
}

View file

@ -5,51 +5,55 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type streamFrameSorter struct {
queuedFrames map[protocol.ByteCount]*wire.StreamFrame
readPosition protocol.ByteCount
gaps *utils.ByteIntervalList
type frameSorter struct {
queue map[protocol.ByteCount][]byte
readPos protocol.ByteCount
finalOffset protocol.ByteCount
gaps *utils.ByteIntervalList
}
var (
errTooManyGapsInReceivedStreamData = errors.New("Too many gaps in received StreamFrame data")
errDuplicateStreamData = errors.New("Duplicate Stream Data")
errEmptyStreamData = errors.New("Stream Data empty")
)
var errDuplicateStreamData = errors.New("Duplicate Stream Data")
func newStreamFrameSorter() *streamFrameSorter {
s := streamFrameSorter{
gaps: utils.NewByteIntervalList(),
queuedFrames: make(map[protocol.ByteCount]*wire.StreamFrame),
func newFrameSorter() *frameSorter {
s := frameSorter{
gaps: utils.NewByteIntervalList(),
queue: make(map[protocol.ByteCount][]byte),
finalOffset: protocol.MaxByteCount,
}
s.gaps.PushFront(utils.ByteInterval{Start: 0, End: protocol.MaxByteCount})
return &s
}
func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error {
if frame.DataLen() == 0 {
if frame.FinBit {
s.queuedFrames[frame.Offset] = frame
return nil
}
return errEmptyStreamData
func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, fin bool) error {
err := s.push(data, offset, fin)
if err == errDuplicateStreamData {
return nil
}
return err
}
func (s *frameSorter) push(data []byte, offset protocol.ByteCount, fin bool) error {
if fin {
s.finalOffset = offset + protocol.ByteCount(len(data))
}
if len(data) == 0 {
return nil
}
var wasCut bool
if oldFrame, ok := s.queuedFrames[frame.Offset]; ok {
if frame.DataLen() <= oldFrame.DataLen() {
if oldData, ok := s.queue[offset]; ok {
if len(data) <= len(oldData) {
return errDuplicateStreamData
}
frame.Data = frame.Data[oldFrame.DataLen():]
frame.Offset += oldFrame.DataLen()
data = data[len(oldData):]
offset += protocol.ByteCount(len(oldData))
wasCut = true
}
start := frame.Offset
end := frame.Offset + frame.DataLen()
start := offset
end := offset + protocol.ByteCount(len(data))
// skip all gaps that are before this stream frame
var gap *utils.ByteIntervalElement
@ -69,9 +73,9 @@ func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error {
if start < gap.Value.Start {
add := gap.Value.Start - start
frame.Offset += add
offset += add
start += add
frame.Data = frame.Data[add:]
data = data[add:]
wasCut = true
}
@ -89,15 +93,15 @@ func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error {
break
}
// delete queued frames completely covered by the current frame
delete(s.queuedFrames, endGap.Value.End)
delete(s.queue, endGap.Value.End)
endGap = nextEndGap
}
if end > endGap.Value.End {
cutLen := end - endGap.Value.End
len := frame.DataLen() - cutLen
len := protocol.ByteCount(len(data)) - cutLen
end -= cutLen
frame.Data = frame.Data[:len]
data = data[:len]
wasCut = true
}
@ -130,32 +134,25 @@ func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error {
}
if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps {
return errTooManyGapsInReceivedStreamData
return errors.New("Too many gaps in received data")
}
if wasCut {
data := make([]byte, frame.DataLen())
copy(data, frame.Data)
frame.Data = data
newData := make([]byte, len(data))
copy(newData, data)
data = newData
}
s.queuedFrames[frame.Offset] = frame
s.queue[offset] = data
return nil
}
func (s *streamFrameSorter) Pop() *wire.StreamFrame {
frame := s.Head()
if frame != nil {
s.readPosition += frame.DataLen()
delete(s.queuedFrames, frame.Offset)
func (s *frameSorter) Pop() ([]byte /* data */, bool /* fin */) {
data, ok := s.queue[s.readPos]
if !ok {
return nil, s.readPos >= s.finalOffset
}
return frame
}
func (s *streamFrameSorter) Head() *wire.StreamFrame {
frame, ok := s.queuedFrames[s.readPosition]
if ok {
return frame
}
return nil
delete(s.queue, s.readPos)
s.readPos += protocol.ByteCount(len(data))
return data, s.readPos >= s.finalOffset
}

View file

@ -77,7 +77,7 @@ func newClient(
opts: opts,
headerErrored: make(chan struct{}),
dialer: dialer,
logger: utils.DefaultLogger,
logger: utils.DefaultLogger.WithPrefix("client"),
}
}
@ -172,7 +172,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
responseChan := make(chan *http.Response)
dataStream, err := c.session.OpenStreamSync()
if err != nil {
_ = c.CloseWithError(err)
_ = c.closeWithError(err)
return nil, err
}
c.mutex.Lock()
@ -187,7 +187,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
endStream := !hasBody
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
if err != nil {
_ = c.CloseWithError(err)
_ = c.closeWithError(err)
return nil, err
}
@ -230,7 +230,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, ctx.Err()
case <-c.headerErrored:
// an error occurred on the header stream
_ = c.CloseWithError(c.headerErr)
_ = c.closeWithError(c.headerErr)
return nil, c.headerErr
}
}
@ -275,16 +275,19 @@ func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e
return dataStream.Close()
}
// Close closes the client
func (c *client) CloseWithError(e error) error {
func (c *client) closeWithError(e error) error {
if c.session == nil {
return nil
}
return c.session.Close(e)
return c.session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
}
// Close closes the client
func (c *client) Close() error {
return c.CloseWithError(nil)
if c.session == nil {
return nil
}
return c.session.Close()
}
// copied from net/transport.go

View file

@ -70,9 +70,6 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
}
func hostnameFromRequest(req *http.Request) string {
if len(req.Host) > 0 {
return req.Host
}
if req.URL != nil {
return req.URL.Host
}

View file

@ -8,9 +8,9 @@ import (
"strings"
"sync"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/lex/httplex"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
@ -65,7 +65,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra
if host == "" {
host = req.URL.Host
}
host, err := httplex.PunycodeHostPort(host)
host, err := httpguts.PunycodeHostPort(host)
if err != nil {
return nil, err
}
@ -89,11 +89,11 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
for k, vv := range req.Header {
if !httplex.ValidHeaderFieldName(k) {
if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("invalid HTTP header name %q", k)
}
for _, v := range vv {
if !httplex.ValidHeaderFieldValue(v) {
if !httpguts.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
}
}

View file

@ -98,9 +98,6 @@ func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
// test that we implement http.Flusher
var _ http.Flusher = &responseWriter{}
// test that we implement http.CloseNotifier
var _ http.CloseNotifier = &responseWriter{}
// copied from http2/http2.go
// bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC 2616, section 4.4.

View file

@ -0,0 +1,9 @@
package h2quic
import "net/http"
// The CloseNotifier is a deprecated interface, and staticcheck will report that from Go 1.11.
// By defining it in a separate file, we can exclude this file from staticcheck.
// test that we implement http.CloseNotifier
var _ http.CloseNotifier = &responseWriter{}

View file

@ -11,7 +11,7 @@ import (
quic "github.com/lucas-clemente/quic-go"
"golang.org/x/net/lex/httplex"
"golang.org/x/net/http/httpguts"
)
type roundTripCloser interface {
@ -80,11 +80,11 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
if req.URL.Scheme == "https" {
for k, vv := range req.Header {
if !httplex.ValidHeaderFieldName(k) {
if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("quic: invalid http header field name %q", k)
}
for _, v := range vv {
if !httplex.ValidHeaderFieldValue(v) {
if !httpguts.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k)
}
}
@ -175,5 +175,5 @@ func validMethod(method string) bool {
// copied from net/http/http.go
func isNotToken(r rune) bool {
return !httplex.IsTokenRune(r)
return !httpguts.IsTokenRune(r)
}

View file

@ -90,7 +90,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server")
}
s.logger = utils.DefaultLogger
s.logger = utils.DefaultLogger.WithPrefix("server")
s.listenerMutex.Lock()
if s.closed {
s.listenerMutex.Unlock()
@ -127,7 +127,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
func (s *Server) handleHeaderStream(session streamCreator) {
stream, err := session.AcceptStream()
if err != nil {
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
session.CloseWithError(quic.ErrorCode(qerr.InvalidHeadersStreamData), err)
return
}
@ -140,10 +140,12 @@ func (s *Server) handleHeaderStream(session streamCreator) {
// QuicErrors must originate from stream.Read() returning an error.
// In this case, the session has already logged the error, so we don't
// need to log it again.
if _, ok := err.(*qerr.QuicError); !ok {
errorCode := qerr.InternalError
if qerr, ok := err.(*qerr.QuicError); !ok {
errorCode = qerr.ErrorCode
s.logger.Errorf("error handling h2 request: %s", err.Error())
}
session.Close(err)
session.CloseWithError(quic.ErrorCode(errorCode), err)
return
}
}
@ -154,10 +156,18 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
if err != nil {
return qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
}
h2headersFrame, ok := h2frame.(*http2.HeadersFrame)
if !ok {
var h2headersFrame *http2.HeadersFrame
switch f := h2frame.(type) {
case *http2.PriorityFrame:
// ignore PRIORITY frames
s.logger.Debugf("Ignoring H2 PRIORITY frame: %#v", f)
return nil
case *http2.HeadersFrame:
h2headersFrame = f
default:
return qerr.Error(qerr.InvalidHeadersStreamData, "expected a header frame")
}
if !h2headersFrame.HeadersEnded() {
return errors.New("http2 header continuation not implemented")
}
@ -238,7 +248,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
}
if s.CloseAfterFirstRequest {
time.Sleep(100 * time.Millisecond)
session.Close(nil)
session.Close()
}
}()

View file

@ -1,270 +0,0 @@
package quicproxy
import (
"net"
"sync"
"sync/atomic"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// Connection is a UDP connection
type connection struct {
ClientAddr *net.UDPAddr // Address of the client
ServerConn *net.UDPConn // UDP connection to server
incomingPacketCounter uint64
outgoingPacketCounter uint64
}
// Direction is the direction a packet is sent.
type Direction int
const (
// DirectionIncoming is the direction from the client to the server.
DirectionIncoming Direction = iota
// DirectionOutgoing is the direction from the server to the client.
DirectionOutgoing
// DirectionBoth is both incoming and outgoing
DirectionBoth
)
func (d Direction) String() string {
switch d {
case DirectionIncoming:
return "incoming"
case DirectionOutgoing:
return "outgoing"
case DirectionBoth:
return "both"
default:
panic("unknown direction")
}
}
// Is says if one direction matches another direction.
// For example, incoming matches both incoming and both, but not outgoing.
func (d Direction) Is(dir Direction) bool {
if d == DirectionBoth || dir == DirectionBoth {
return true
}
return d == dir
}
// DropCallback is a callback that determines which packet gets dropped.
type DropCallback func(dir Direction, packetCount uint64) bool
// NoDropper doesn't drop packets.
var NoDropper DropCallback = func(Direction, uint64) bool {
return false
}
// DelayCallback is a callback that determines how much delay to apply to a packet.
type DelayCallback func(dir Direction, packetCount uint64) time.Duration
// NoDelay doesn't apply a delay.
var NoDelay DelayCallback = func(Direction, uint64) time.Duration {
return 0
}
// Opts are proxy options.
type Opts struct {
// The address this proxy proxies packets to.
RemoteAddr string
// DropPacket determines whether a packet gets dropped.
DropPacket DropCallback
// DelayPacket determines how long a packet gets delayed. This allows
// simulating a connection with non-zero RTTs.
// Note that the RTT is the sum of the delay for the incoming and the outgoing packet.
DelayPacket DelayCallback
}
// QuicProxy is a QUIC proxy that can drop and delay packets.
type QuicProxy struct {
mutex sync.Mutex
version protocol.VersionNumber
conn *net.UDPConn
serverAddr *net.UDPAddr
dropPacket DropCallback
delayPacket DelayCallback
// Mapping from client addresses (as host:port) to connection
clientDict map[string]*connection
logger utils.Logger
}
// NewQuicProxy creates a new UDP proxy
func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*QuicProxy, error) {
if opts == nil {
opts = &Opts{}
}
laddr, err := net.ResolveUDPAddr("udp", local)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", laddr)
if err != nil {
return nil, err
}
raddr, err := net.ResolveUDPAddr("udp", opts.RemoteAddr)
if err != nil {
return nil, err
}
packetDropper := NoDropper
if opts.DropPacket != nil {
packetDropper = opts.DropPacket
}
packetDelayer := NoDelay
if opts.DelayPacket != nil {
packetDelayer = opts.DelayPacket
}
p := QuicProxy{
clientDict: make(map[string]*connection),
conn: conn,
serverAddr: raddr,
dropPacket: packetDropper,
delayPacket: packetDelayer,
version: version,
logger: utils.DefaultLogger,
}
p.logger.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr)
go p.runProxy()
return &p, nil
}
// Close stops the UDP Proxy
func (p *QuicProxy) Close() error {
p.mutex.Lock()
defer p.mutex.Unlock()
for _, c := range p.clientDict {
if err := c.ServerConn.Close(); err != nil {
return err
}
}
return p.conn.Close()
}
// LocalAddr is the address the proxy is listening on.
func (p *QuicProxy) LocalAddr() net.Addr {
return p.conn.LocalAddr()
}
// LocalPort is the UDP port number the proxy is listening on.
func (p *QuicProxy) LocalPort() int {
return p.conn.LocalAddr().(*net.UDPAddr).Port
}
func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
srvudp, err := net.DialUDP("udp", nil, p.serverAddr)
if err != nil {
return nil, err
}
return &connection{
ClientAddr: cliAddr,
ServerConn: srvudp,
}, nil
}
// runProxy listens on the proxy address and handles incoming packets.
func (p *QuicProxy) runProxy() error {
for {
buffer := make([]byte, protocol.MaxReceivePacketSize)
n, cliaddr, err := p.conn.ReadFromUDP(buffer)
if err != nil {
return err
}
raw := buffer[0:n]
saddr := cliaddr.String()
p.mutex.Lock()
conn, ok := p.clientDict[saddr]
if !ok {
conn, err = p.newConnection(cliaddr)
if err != nil {
p.mutex.Unlock()
return err
}
p.clientDict[saddr] = conn
go p.runConnection(conn)
}
p.mutex.Unlock()
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
if p.dropPacket(DirectionIncoming, packetCount) {
if p.logger.Debug() {
p.logger.Debugf("dropping incoming packet %d (%d bytes)", packetCount, n)
}
continue
}
// Send the packet to the server
delay := p.delayPacket(DirectionIncoming, packetCount)
if delay != 0 {
if p.logger.Debug() {
p.logger.Debugf("delaying incoming packet %d (%d bytes) to %s by %s", packetCount, n, conn.ServerConn.RemoteAddr(), delay)
}
time.AfterFunc(delay, func() {
// TODO: handle error
_, _ = conn.ServerConn.Write(raw)
})
} else {
if p.logger.Debug() {
p.logger.Debugf("forwarding incoming packet %d (%d bytes) to %s", packetCount, n, conn.ServerConn.RemoteAddr())
}
if _, err := conn.ServerConn.Write(raw); err != nil {
return err
}
}
}
}
// runConnection handles packets from server to a single client
func (p *QuicProxy) runConnection(conn *connection) error {
for {
buffer := make([]byte, protocol.MaxReceivePacketSize)
n, err := conn.ServerConn.Read(buffer)
if err != nil {
return err
}
raw := buffer[0:n]
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
if p.dropPacket(DirectionOutgoing, packetCount) {
if p.logger.Debug() {
p.logger.Debugf("dropping outgoing packet %d (%d bytes)", packetCount, n)
}
continue
}
delay := p.delayPacket(DirectionOutgoing, packetCount)
if delay != 0 {
if p.logger.Debug() {
p.logger.Debugf("delaying outgoing packet %d (%d bytes) to %s by %s", packetCount, n, conn.ClientAddr, delay)
}
time.AfterFunc(delay, func() {
// TODO: handle error
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr)
})
} else {
if p.logger.Debug() {
p.logger.Debugf("forwarding outgoing packet %d (%d bytes) to %s", packetCount, n, conn.ClientAddr)
}
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
return err
}
}
}
}

View file

@ -1,41 +0,0 @@
package testlog
import (
"flag"
"log"
"os"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var (
logFileName string // the log file set in the ginkgo flags
logFile *os.File
)
// read the logfile command line flag
// to set call ginkgo -- -logfile=log.txt
func init() {
flag.StringVar(&logFileName, "logfile", "", "log file")
}
var _ = BeforeEach(func() {
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
if len(logFileName) > 0 {
var err error
logFile, err = os.Create(logFileName)
Expect(err).ToNot(HaveOccurred())
log.SetOutput(logFile)
utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
}
})
var _ = AfterEach(func() {
if len(logFileName) > 0 {
_ = logFile.Close()
}
})

View file

@ -1,119 +0,0 @@
package testserver
import (
"io"
"io/ioutil"
"net"
"net/http"
"strconv"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
const (
dataLen = 500 * 1024 // 500 KB
dataLenLong = 50 * 1024 * 1024 // 50 MB
)
var (
// PRData contains dataLen bytes of pseudo-random data.
PRData = GeneratePRData(dataLen)
// PRDataLong contains dataLenLong bytes of pseudo-random data.
PRDataLong = GeneratePRData(dataLenLong)
server *h2quic.Server
stoppedServing chan struct{}
port string
)
func init() {
http.HandleFunc("/prdata", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
sl := r.URL.Query().Get("len")
if sl != "" {
var err error
l, err := strconv.Atoi(sl)
Expect(err).NotTo(HaveOccurred())
_, err = w.Write(GeneratePRData(l))
Expect(err).NotTo(HaveOccurred())
} else {
_, err := w.Write(PRData)
Expect(err).NotTo(HaveOccurred())
}
})
http.HandleFunc("/prdatalong", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
_, err := w.Write(PRDataLong)
Expect(err).NotTo(HaveOccurred())
})
http.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
_, err := io.WriteString(w, "Hello, World!\n")
Expect(err).NotTo(HaveOccurred())
})
http.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
body, err := ioutil.ReadAll(r.Body)
Expect(err).NotTo(HaveOccurred())
_, err = w.Write(body)
Expect(err).NotTo(HaveOccurred())
})
}
// See https://en.wikipedia.org/wiki/Lehmer_random_number_generator
func GeneratePRData(l int) []byte {
res := make([]byte, l)
seed := uint64(1)
for i := 0; i < l; i++ {
seed = seed * 48271 % 2147483647
res[i] = byte(seed)
}
return res
}
// StartQuicServer starts a h2quic.Server.
// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used.
func StartQuicServer(versions []protocol.VersionNumber) {
server = &h2quic.Server{
Server: &http.Server{
TLSConfig: testdata.GetTLSConfig(),
},
QuicConfig: &quic.Config{
Versions: versions,
},
}
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
Expect(err).NotTo(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).NotTo(HaveOccurred())
port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port)
stoppedServing = make(chan struct{})
go func() {
defer GinkgoRecover()
server.Serve(conn)
close(stoppedServing)
}()
}
// StopQuicServer stops the h2quic.Server.
func StopQuicServer() {
Expect(server.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing).Should(BeClosed())
}
// Port returns the UDP port of the QUIC server.
func Port() string {
return port
}

View file

@ -16,8 +16,16 @@ type StreamID = protocol.StreamID
// A VersionNumber is a QUIC version number.
type VersionNumber = protocol.VersionNumber
// VersionGQUIC39 is gQUIC version 39.
const VersionGQUIC39 = protocol.Version39
const (
// VersionGQUIC39 is gQUIC version 39.
VersionGQUIC39 = protocol.Version39
// VersionGQUIC43 is gQUIC version 43.
VersionGQUIC43 = protocol.Version43
// VersionGQUIC44 is gQUIC version 44.
VersionGQUIC44 = protocol.Version44
// VersionMilestone0_10_0 uses TLS
VersionMilestone0_10_0 = protocol.VersionMilestone0_10_0
)
// A Cookie can be used to verify the ownership of the client address.
type Cookie = handshake.Cookie
@ -139,8 +147,11 @@ type Session interface {
LocalAddr() net.Addr
// RemoteAddr returns the address of the peer.
RemoteAddr() net.Addr
// Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent.
Close(error) error
// Close the connection.
io.Closer
// Close the connection with an error.
// The error must not be nil.
CloseWithError(ErrorCode, error) error
// The context is cancelled when the session is closed.
// Warning: This API should not be considered stable and might change soon.
Context() context.Context
@ -159,6 +170,13 @@ type Config struct {
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
// Currently only valid for the client.
RequestConnectionIDOmission bool
// The length of the connection ID in bytes. Only valid for IETF QUIC.
// It can be 0, or any value between 4 and 18.
// If not set, the interpretation depends on where the Config is used:
// If used for dialing an address, a 0 byte connection ID will be used.
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
ConnectionIDLength int
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 10 seconds.

View file

@ -29,7 +29,8 @@ type SentPacketHandler interface {
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
DequeuePacketForRetransmission() (packet *Packet)
DequeuePacketForRetransmission() *Packet
DequeueProbePacket() (*Packet, error)
GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
GetAlarmTimeout() time.Time

View file

@ -25,6 +25,8 @@ type receivedPacketHandler struct {
ackAlarm time.Time
lastAck *wire.AckFrame
logger utils.Logger
version protocol.VersionNumber
}
@ -52,11 +54,16 @@ const (
)
// NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler(rttStats *congestion.RTTStats, version protocol.VersionNumber) ReceivedPacketHandler {
func NewReceivedPacketHandler(
rttStats *congestion.RTTStats,
logger utils.Logger,
version protocol.VersionNumber,
) ReceivedPacketHandler {
return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(),
ackSendDelay: ackSendDelay,
rttStats: rttStats,
logger: logger,
version: version,
}
}
@ -82,16 +89,22 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe
// IgnoreBelow sets a lower limit for acking packets.
// Packets with packet numbers smaller than p will not be acked.
func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
if p <= h.ignoreBelow {
return
}
h.ignoreBelow = p
h.packetHistory.DeleteBelow(p)
if h.logger.Debug() {
h.logger.Debugf("\tIgnoring all packets below %#x.", p)
}
}
// isMissing says if a packet was reported missing in the last ACK.
func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool {
if h.lastAck == nil {
if h.lastAck == nil || p < h.ignoreBelow {
return false
}
return p < h.lastAck.LargestAcked && !h.lastAck.AcksPacket(p)
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
}
func (h *receivedPacketHandler) hasNewMissingPackets() bool {
@ -99,7 +112,7 @@ func (h *receivedPacketHandler) hasNewMissingPackets() bool {
return false
}
highestRange := h.packetHistory.GetHighestAckRange()
return highestRange.First >= h.lastAck.LargestAcked && highestRange.Len() <= maxPacketsAfterNewMissing
return highestRange.Smallest >= h.lastAck.LargestAcked() && highestRange.Len() <= maxPacketsAfterNewMissing
}
// maybeQueueAck queues an ACK, if necessary.
@ -110,6 +123,7 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
// always ack the first packet
if h.lastAck == nil {
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
h.ackQueued = true
return
}
@ -118,6 +132,9 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
// Ack decimation with reordering relies on the timer to send an ACK, but if
// missing packets we reported in the previous ack, send an ACK immediately.
if wasMissing {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %#x was missing before.", packetNumber)
}
h.ackQueued = true
}
@ -128,26 +145,41 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
// ack up to 10 packets at once
if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
h.ackQueued = true
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, retransmittablePacketsBeforeAck)
}
} else if h.ackAlarm.IsZero() {
// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
h.ackAlarm = rcvTime.Add(ackDelay)
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to min(1/4 min-RTT, max ack delay): %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
}
}
} else {
// send an ACK every 2 retransmittable packets
if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, initialRetransmittablePacketsBeforeAck)
}
h.ackQueued = true
} else if h.ackAlarm.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", ackSendDelay)
}
h.ackAlarm = rcvTime.Add(ackSendDelay)
}
}
// If there are new missing packets to report, set a short timer to send an ACK.
if h.hasNewMissingPackets() {
// wait the minimum of 1/8 min RTT and the existing ack time
ackDelay := float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay)
ackTime := rcvTime.Add(time.Duration(ackDelay))
ackDelay := time.Duration(float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay))
ackTime := rcvTime.Add(ackDelay)
if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
h.ackAlarm = ackTime
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to 1/8 min-RTT: %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
}
}
}
}
@ -159,19 +191,17 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
}
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) {
now := time.Now()
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
return nil
}
ackRanges := h.packetHistory.GetAckRanges()
ack := &wire.AckFrame{
LargestAcked: h.largestObserved,
LowestAcked: ackRanges[len(ackRanges)-1].First,
PacketReceivedTime: h.largestObservedReceivedTime,
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
h.logger.Debugf("Sending ACK because the ACK timer expired.")
}
if len(ackRanges) > 1 {
ack.AckRanges = ackRanges
ack := &wire.AckFrame{
AckRanges: h.packetHistory.GetAckRanges(),
DelayTime: now.Sub(h.largestObservedReceivedTime),
}
h.lastAck = ack

View file

@ -104,7 +104,7 @@ func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange {
ackRanges := make([]wire.AckRange, h.ranges.Len())
i := 0
for el := h.ranges.Back(); el != nil; el = el.Prev() {
ackRanges[i] = wire.AckRange{First: el.Value.Start, Last: el.Value.End}
ackRanges[i] = wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End}
i++
}
return ackRanges
@ -114,8 +114,8 @@ func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
ackRange := wire.AckRange{}
if h.ranges.Len() > 0 {
r := h.ranges.Back().Value
ackRange.First = r.Start
ackRange.Last = r.End
ackRange.Smallest = r.Start
ackRange.Largest = r.End
}
return ackRange
}

View file

@ -14,7 +14,9 @@ const (
SendRetransmission
// SendRTO means that an RTO probe packet should be sent
SendRTO
// SendAny packet should be sent
// SendTLP means that a TLP probe packet should be sent
SendTLP
// SendAny means that any packet should be sent
SendAny
)
@ -28,6 +30,8 @@ func (s SendMode) String() string {
return "retransmission"
case SendRTO:
return "rto"
case SendTLP:
return "tlp"
case SendAny:
return "any"
default:

View file

@ -1,6 +1,7 @@
package ackhandler
import (
"errors"
"fmt"
"math"
"time"
@ -16,13 +17,12 @@ const (
// Maximum reordering in time space before time based loss detection considers a packet lost.
// In fraction of an RTT.
timeReorderingFraction = 1.0 / 8
// The default RTT used before an RTT sample is taken.
// Note: This constant is also defined in the congestion package.
defaultInitialRTT = 100 * time.Millisecond
// defaultRTOTimeout is the RTO time on new connections
defaultRTOTimeout = 500 * time.Millisecond
// Minimum time in the future a tail loss probe alarm may be set for.
minTPLTimeout = 10 * time.Millisecond
// Maximum number of tail loss probes before an RTO fires.
maxTLPs = 2
// Minimum time in the future an RTO alarm may be set for.
minRTOTimeout = 200 * time.Millisecond
// maxRTOTimeout is the maximum RTO time
@ -59,6 +59,10 @@ type sentPacketHandler struct {
// The number of times the handshake packets have been retransmitted without receiving an ack.
handshakeCount uint32
// The number of times a TLP has been sent without receiving an ack.
tlpCount uint32
allowTLP bool
// The number of times an RTO has been sent without receiving an ack.
rtoCount uint32
// The number of RTO probe packets that should be sent.
@ -71,10 +75,12 @@ type sentPacketHandler struct {
alarm time.Time
logger utils.Logger
version protocol.VersionNumber
}
// NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
@ -89,6 +95,7 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) Se
rttStats: rttStats,
congestion: congestion,
logger: logger,
version: version,
}
}
@ -100,6 +107,7 @@ func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber {
}
func (h *sentPacketHandler) SetHandshakeComplete() {
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
var queue []*Packet
for _, packet := range h.retransmissionQueue {
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
@ -150,7 +158,7 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt
if len(packet.Frames) > 0 {
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
packet.largestAcked = ackFrame.LargestAcked
packet.largestAcked = ackFrame.LargestAcked()
}
}
@ -168,6 +176,7 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt
if h.numRTOs > 0 {
h.numRTOs--
}
h.allowTLP = false
}
h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isRetransmittable)
@ -176,7 +185,8 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt
}
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
if ackFrame.LargestAcked > h.lastSentPacketNumber {
largestAcked := ackFrame.LargestAcked()
if largestAcked > h.lastSentPacketNumber {
return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
}
@ -186,13 +196,13 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
return nil
}
h.largestReceivedPacketWithAck = withPacketNumber
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, ackFrame.LargestAcked)
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked)
if h.skippedPacketsAcked(ackFrame) {
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
}
if rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime); rttUpdated {
if rttUpdated := h.maybeUpdateRTT(largestAcked, ackFrame.DelayTime, rcvTime); rttUpdated {
h.congestion.MaybeExitSlowStart()
}
@ -212,11 +222,11 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
if p.largestAcked != 0 {
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.largestAcked+1)
}
if err := h.onPacketAcked(p); err != nil {
if err := h.onPacketAcked(p, rcvTime); err != nil {
return err
}
if p.includedInBytesInFlight {
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight)
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
}
}
@ -238,27 +248,29 @@ func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNu
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*Packet, error) {
var ackedPackets []*Packet
ackRangeIndex := 0
lowestAcked := ackFrame.LowestAcked()
largestAcked := ackFrame.LargestAcked()
err := h.packetHistory.Iterate(func(p *Packet) (bool, error) {
// Ignore packets below the LowestAcked
if p.PacketNumber < ackFrame.LowestAcked {
// Ignore packets below the lowest acked
if p.PacketNumber < lowestAcked {
return true, nil
}
// Break after LargestAcked is reached
if p.PacketNumber > ackFrame.LargestAcked {
// Break after largest acked is reached
if p.PacketNumber > largestAcked {
return false, nil
}
if ackFrame.HasMissingRanges() {
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
for p.PacketNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 {
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ackFrame.AckRanges)-1 {
ackRangeIndex++
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
}
if p.PacketNumber >= ackRange.First { // packet i contained in ACK range
if p.PacketNumber > ackRange.Last {
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.First, ackRange.Last)
if p.PacketNumber >= ackRange.Smallest { // packet i contained in ACK range
if p.PacketNumber > ackRange.Largest {
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
}
ackedPackets = append(ackedPackets, p)
}
@ -267,12 +279,22 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame)
}
return true, nil
})
if h.logger.Debug() && len(ackedPackets) > 0 {
pns := make([]protocol.PacketNumber, len(ackedPackets))
for i, p := range ackedPackets {
pns[i] = p.PacketNumber
}
h.logger.Debugf("\tnewly acked packets (%d): %#x", len(pns), pns)
}
return ackedPackets, err
}
func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool {
if p := h.packetHistory.GetPacket(largestAcked); p != nil {
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
if h.logger.Debug() {
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
}
return true
}
return false
@ -280,20 +302,25 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a
func (h *sentPacketHandler) updateLossDetectionAlarm() {
// Cancel the alarm if no packets are outstanding
if h.packetHistory.Len() == 0 {
if !h.packetHistory.HasOutstandingPackets() {
h.alarm = time.Time{}
return
}
// TODO(#497): TLP
if !h.handshakeComplete {
if h.packetHistory.HasOutstandingHandshakePackets() {
h.alarm = h.lastSentHandshakePacketTime.Add(h.computeHandshakeTimeout())
} else if !h.lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = h.lossTime
} else {
// RTO
h.alarm = h.lastSentRetransmittablePacketTime.Add(h.computeRTOTimeout())
// RTO or TLP alarm
alarmDuration := h.computeRTOTimeout()
if h.tlpCount < maxTLPs {
tlpAlarm := h.computeTLPTimeout()
// if the RTO duration is shorter than the TLP duration, use the RTO duration
alarmDuration = utils.MinDuration(alarmDuration, tlpAlarm)
}
h.alarm = h.lastSentRetransmittablePacketTime.Add(alarmDuration)
}
}
@ -313,11 +340,21 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight proto
if timeSinceSent > delayUntilLost {
lostPackets = append(lostPackets, packet)
} else if h.lossTime.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("\tsetting loss timer for packet %#x to %s (in %s)", packet.PacketNumber, delayUntilLost, delayUntilLost-timeSinceSent)
}
// Note: This conditional is only entered once per call
h.lossTime = now.Add(delayUntilLost - timeSinceSent)
}
return true, nil
})
if h.logger.Debug() && len(lostPackets) > 0 {
pns := make([]protocol.PacketNumber, len(lostPackets))
for i, p := range lostPackets {
pns[i] = p.PacketNumber
}
h.logger.Debugf("\tlost packets (%d): %#x", len(pns), pns)
}
for _, p := range lostPackets {
// the bytes in flight need to be reduced no matter if this packet will be retransmitted
@ -327,7 +364,6 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight proto
}
if p.canBeRetransmitted {
// queue the packet for retransmission, and report the loss to the congestion controller
h.logger.Debugf("\tQueueing packet %#x because it was detected lost", p.PacketNumber)
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
@ -338,34 +374,57 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight proto
}
func (h *sentPacketHandler) OnAlarm() error {
now := time.Now()
// TODO(#497): TLP
var err error
if !h.handshakeComplete {
h.handshakeCount++
err = h.queueHandshakePacketsForRetransmission()
} else if !h.lossTime.IsZero() {
// Early retransmit or time loss detection
err = h.detectLostPackets(now, h.bytesInFlight)
} else {
// RTO
h.rtoCount++
h.numRTOs += 2
err = h.queueRTOs()
}
if err != nil {
return err
// When all outstanding are acknowledged, the alarm is canceled in
// updateLossDetectionAlarm. This doesn't reset the timer in the session though.
// When OnAlarm is called, we therefore need to make sure that there are
// actually packets outstanding.
if h.packetHistory.HasOutstandingPackets() {
if err := h.onVerifiedAlarm(); err != nil {
return err
}
}
h.updateLossDetectionAlarm()
return nil
}
func (h *sentPacketHandler) onVerifiedAlarm() error {
var err error
if h.packetHistory.HasOutstandingHandshakePackets() {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in handshake mode. Handshake count: %d", h.handshakeCount)
}
h.handshakeCount++
err = h.queueHandshakePacketsForRetransmission()
} else if !h.lossTime.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", h.lossTime)
}
// Early retransmit or time loss detection
err = h.detectLostPackets(time.Now(), h.bytesInFlight)
} else if h.tlpCount < maxTLPs { // TLP
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in TLP mode. TLP count: %d", h.tlpCount)
}
h.allowTLP = true
h.tlpCount++
} else { // RTO
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in RTO mode. RTO count: %d", h.rtoCount)
}
if h.rtoCount == 0 {
h.largestSentBeforeRTO = h.lastSentPacketNumber
}
h.rtoCount++
h.numRTOs += 2
}
return err
}
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
return h.alarm
}
func (h *sentPacketHandler) onPacketAcked(p *Packet) error {
func (h *sentPacketHandler) onPacketAcked(p *Packet, rcvTime time.Time) error {
// This happens if a packet and its retransmissions is acked in the same ACK.
// As soon as we process the first one, this will remove all the retransmissions,
// so we won't find the retransmitted packet number later.
@ -404,8 +463,8 @@ func (h *sentPacketHandler) onPacketAcked(p *Packet) error {
return err
}
h.rtoCount = 0
h.tlpCount = 0
h.handshakeCount = 0
// TODO(#497): h.tlpCount = 0
return h.packetHistory.Remove(p.PacketNumber)
}
@ -447,8 +506,21 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
return packet
}
func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
if len(h.retransmissionQueue) == 0 {
p := h.packetHistory.FirstOutstanding()
if p == nil {
return nil, errors.New("cannot dequeue a probe packet. No outstanding packets")
}
if err := h.queuePacketForRetransmission(p); err != nil {
return nil, err
}
}
return h.DequeuePacketForRetransmission(), nil
}
func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked())
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked(), h.version)
}
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
@ -463,15 +535,22 @@ func (h *sentPacketHandler) SendMode() SendMode {
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
// but still allow sending of retransmissions and ACKs.
if numTrackedPackets >= protocol.MaxTrackedSentPackets {
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
if h.logger.Debug() {
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
}
return SendNone
}
if h.allowTLP {
return SendTLP
}
if h.numRTOs > 0 {
return SendRTO
}
// Only send ACKs if we're congestion limited.
if cwnd := h.congestion.GetCongestionWindow(); h.bytesInFlight > cwnd {
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
if h.logger.Debug() {
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
}
return SendAck
}
// Send retransmissions first, if there are any.
@ -479,7 +558,9 @@ func (h *sentPacketHandler) SendMode() SendMode {
return SendRetransmission
}
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
if h.logger.Debug() {
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
}
return SendAck
}
return SendAny
@ -501,23 +582,6 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int {
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
}
// retransmit the oldest two packets
func (h *sentPacketHandler) queueRTOs() error {
h.largestSentBeforeRTO = h.lastSentPacketNumber
// Queue the first two outstanding packets for retransmission.
// This does NOT declare this packets as lost:
// They are still tracked in the packet history and count towards the bytes in flight.
for i := 0; i < 2; i++ {
if p := h.packetHistory.FirstOutstanding(); p != nil {
h.logger.Debugf("\tQueueing packet %#x for retransmission (RTO)", p.PacketNumber)
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
}
}
return nil
}
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
var handshakePackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
@ -527,7 +591,7 @@ func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
return true, nil
})
for _, p := range handshakePackets {
h.logger.Debugf("\tQueueing packet %#x as a handshake retransmission", p.PacketNumber)
h.logger.Debugf("Queueing packet %#x as a handshake retransmission", p.PacketNumber)
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
@ -548,16 +612,17 @@ func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
}
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
duration := 2 * h.rttStats.SmoothedRTT()
if duration == 0 {
duration = 2 * defaultInitialRTT
}
duration = utils.MaxDuration(duration, minTPLTimeout)
duration := utils.MaxDuration(2*h.rttStats.SmoothedOrInitialRTT(), minTPLTimeout)
// exponential backoff
// There's an implicit limit to this set by the handshake timeout.
return duration << h.handshakeCount
}
func (h *sentPacketHandler) computeTLPTimeout() time.Duration {
// TODO(#1236): include the max_ack_delay
return utils.MaxDuration(h.rttStats.SmoothedOrInitialRTT()*3/2, minTPLTimeout)
}
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
var rto time.Duration
rtt := h.rttStats.SmoothedRTT()

View file

@ -10,6 +10,9 @@ type sentPacketHistory struct {
packetList *PacketList
packetMap map[protocol.PacketNumber]*PacketElement
numOutstandingPackets int
numOutstandingHandshakePackets int
firstOutstanding *PacketElement
}
@ -30,6 +33,12 @@ func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
if h.firstOutstanding == nil {
h.firstOutstanding = el
}
if p.canBeRetransmitted {
h.numOutstandingPackets++
if p.EncryptionLevel < protocol.EncryptionForwardSecure {
h.numOutstandingHandshakePackets++
}
}
return el
}
@ -92,6 +101,18 @@ func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber)
if !ok {
return fmt.Errorf("sent packet history: packet %d not found", pn)
}
if el.Value.canBeRetransmitted {
h.numOutstandingPackets--
if h.numOutstandingPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
h.numOutstandingHandshakePackets--
if h.numOutstandingHandshakePackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
}
}
el.Value.canBeRetransmitted = false
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
@ -121,7 +142,27 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
if el.Value.canBeRetransmitted {
h.numOutstandingPackets--
if h.numOutstandingPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
h.numOutstandingHandshakePackets--
if h.numOutstandingHandshakePackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
}
}
h.packetList.Remove(el)
delete(h.packetMap, p)
return nil
}
func (h *sentPacketHistory) HasOutstandingPackets() bool {
return h.numOutstandingPackets > 0
}
func (h *sentPacketHistory) HasOutstandingHandshakePackets() bool {
return h.numOutstandingHandshakePackets > 0
}

View file

@ -30,8 +30,9 @@ func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFr
}
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
if ack.LargestAcked >= s.nextLeastUnacked {
s.nextLeastUnacked = ack.LargestAcked + 1
largestAcked := ack.LargestAcked()
if largestAcked >= s.nextLeastUnacked {
s.nextLeastUnacked = largestAcked + 1
}
}

View file

@ -16,11 +16,10 @@ import (
// allow a 10 shift right to divide.
// 1024*1024^3 (first 1024 is from 0.100^3)
// where 0.100 is 100 ms which is the scaling
// round trip time.
// where 0.100 is 100 ms which is the scaling round trip time.
const cubeScale = 40
const cubeCongestionWindowScale = 410
const cubeFactor protocol.PacketNumber = 1 << cubeScale / cubeCongestionWindowScale
const cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / protocol.DefaultTCPMSS
const defaultNumConnections = 2
@ -32,39 +31,35 @@ const beta float32 = 0.7
// new concurrent flows and speed up convergence.
const betaLastMax float32 = 0.85
// If true, Cubic's epoch is shifted when the sender is application-limited.
const shiftQuicCubicEpochWhenAppLimited = true
const maxCubicTimeInterval = 30 * time.Millisecond
// Cubic implements the cubic algorithm from TCP
type Cubic struct {
clock Clock
// Number of connections to simulate.
numConnections int
// Time when this cycle started, after last loss event.
epoch time.Time
// Time when sender went into application-limited period. Zero if not in
// application-limited period.
appLimitedStartTime time.Time
// Time when we updated last_congestion_window.
lastUpdateTime time.Time
// Last congestion window (in packets) used.
lastCongestionWindow protocol.PacketNumber
// Max congestion window (in packets) used just before last loss event.
// Max congestion window used just before last loss event.
// Note: to improve fairness to other streams an additional back off is
// applied to this value if the new value is below our latest value.
lastMaxCongestionWindow protocol.PacketNumber
// Number of acked packets since the cycle started (epoch).
ackedPacketsCount protocol.PacketNumber
lastMaxCongestionWindow protocol.ByteCount
// Number of acked bytes since the cycle started (epoch).
ackedBytesCount protocol.ByteCount
// TCP Reno equivalent congestion window in packets.
estimatedTCPcongestionWindow protocol.PacketNumber
estimatedTCPcongestionWindow protocol.ByteCount
// Origin point of cubic function.
originPointCongestionWindow protocol.PacketNumber
originPointCongestionWindow protocol.ByteCount
// Time to origin point of cubic function in 2^10 fractions of a second.
timeToOriginPoint uint32
// Last congestion window in packets computed by cubic function.
lastTargetCongestionWindow protocol.PacketNumber
lastTargetCongestionWindow protocol.ByteCount
}
// NewCubic returns a new Cubic instance
@ -80,11 +75,8 @@ func NewCubic(clock Clock) *Cubic {
// Reset is called after a timeout to reset the cubic state
func (c *Cubic) Reset() {
c.epoch = time.Time{}
c.appLimitedStartTime = time.Time{}
c.lastUpdateTime = time.Time{}
c.lastCongestionWindow = 0
c.lastMaxCongestionWindow = 0
c.ackedPacketsCount = 0
c.ackedBytesCount = 0
c.estimatedTCPcongestionWindow = 0
c.originPointCongestionWindow = 0
c.timeToOriginPoint = 0
@ -107,57 +99,59 @@ func (c *Cubic) beta() float32 {
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
}
func (c *Cubic) betaLastMax() float32 {
// betaLastMax is the additional backoff factor after loss for our
// N-connection emulation, which emulates the additional backoff of
// an ensemble of N TCP-Reno connections on a single loss event. The
// effective multiplier is computed as:
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
}
// OnApplicationLimited is called on ack arrival when sender is unable to use
// the available congestion window. Resets Cubic state during quiescence.
func (c *Cubic) OnApplicationLimited() {
if shiftQuicCubicEpochWhenAppLimited {
// When sender is not using the available congestion window, Cubic's epoch
// should not continue growing. Record the time when sender goes into an
// app-limited period here, to compensate later when cwnd growth happens.
if c.appLimitedStartTime.IsZero() {
c.appLimitedStartTime = c.clock.Now()
}
} else {
// When sender is not using the available congestion window, Cubic's epoch
// should not continue growing. Reset the epoch when in such a period.
c.epoch = time.Time{}
}
// When sender is not using the available congestion window, the window does
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
// using the entire window during the time since the beginning of the current
// "epoch" (the end of the last loss recovery period). Since
// application-limited periods break this assumption, we reset the epoch when
// in such a period. This reset effectively freezes congestion window growth
// through application-limited periods and allows Cubic growth to continue
// when the entire window is being used.
c.epoch = time.Time{}
}
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
// a loss event. Returns the new congestion window in packets. The new
// congestion window is a multiplicative decrease of our current window.
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.PacketNumber) protocol.PacketNumber {
if currentCongestionWindow < c.lastMaxCongestionWindow {
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount {
if currentCongestionWindow+protocol.DefaultTCPMSS < c.lastMaxCongestionWindow {
// We never reached the old max, so assume we are competing with another
// flow. Use our extra back off factor to allow the other flow to go up.
c.lastMaxCongestionWindow = protocol.PacketNumber(betaLastMax * float32(currentCongestionWindow))
c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
} else {
c.lastMaxCongestionWindow = currentCongestionWindow
}
c.epoch = time.Time{} // Reset time.
return protocol.PacketNumber(float32(currentCongestionWindow) * c.beta())
return protocol.ByteCount(float32(currentCongestionWindow) * c.beta())
}
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
// Returns the new congestion window in packets. The new congestion window
// follows a cubic function that depends on the time passed since last
// packet loss.
func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.PacketNumber, delayMin time.Duration) protocol.PacketNumber {
c.ackedPacketsCount++ // Packets acked.
currentTime := c.clock.Now()
// Cubic is "independent" of RTT, the update is limited by the time elapsed.
if c.lastCongestionWindow == currentCongestionWindow && (currentTime.Sub(c.lastUpdateTime) <= maxCubicTimeInterval) {
return utils.MaxPacketNumber(c.lastTargetCongestionWindow, c.estimatedTCPcongestionWindow)
}
c.lastCongestionWindow = currentCongestionWindow
c.lastUpdateTime = currentTime
func (c *Cubic) CongestionWindowAfterAck(
ackedBytes protocol.ByteCount,
currentCongestionWindow protocol.ByteCount,
delayMin time.Duration,
eventTime time.Time,
) protocol.ByteCount {
c.ackedBytesCount += ackedBytes
if c.epoch.IsZero() {
// First ACK after a loss event.
c.epoch = currentTime // Start of epoch.
c.ackedPacketsCount = 1 // Reset count.
c.epoch = eventTime // Start of epoch.
c.ackedBytesCount = ackedBytes // Reset count.
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
c.estimatedTCPcongestionWindow = currentCongestionWindow
if c.lastMaxCongestionWindow <= currentCongestionWindow {
@ -167,48 +161,37 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
c.originPointCongestionWindow = c.lastMaxCongestionWindow
}
} else {
// If sender was app-limited, then freeze congestion window growth during
// app-limited period. Continue growth now by shifting the epoch-start
// through the app-limited period.
if shiftQuicCubicEpochWhenAppLimited && !c.appLimitedStartTime.IsZero() {
shift := currentTime.Sub(c.appLimitedStartTime)
c.epoch = c.epoch.Add(shift)
c.appLimitedStartTime = time.Time{}
}
}
// Change the time unit from microseconds to 2^10 fractions per second. Take
// the round trip time in account. This is done to allow us to use shift as a
// divide operator.
elapsedTime := int64((currentTime.Add(delayMin).Sub(c.epoch)/time.Microsecond)<<10) / 1000000
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
// Right-shifts of negative, signed numbers have implementation-dependent
// behavior, so force the offset to be positive, as is done in the kernel.
offset := int64(c.timeToOriginPoint) - elapsedTime
// Right-shifts of negative, signed numbers have
// implementation-dependent behavior. Force the offset to be
// positive, similar to the kernel implementation.
if offset < 0 {
offset = -offset
}
deltaCongestionWindow := protocol.PacketNumber((cubeCongestionWindowScale * offset * offset * offset) >> cubeScale)
var targetCongestionWindow protocol.PacketNumber
deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * protocol.DefaultTCPMSS >> cubeScale
var targetCongestionWindow protocol.ByteCount
if elapsedTime > int64(c.timeToOriginPoint) {
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
} else {
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
}
// With dynamic beta/alpha based on number of active streams, it is possible
// for the required_ack_count to become much lower than acked_packets_count_
// suddenly, leading to more than one iteration through the following loop.
for {
// Update estimated TCP congestion_window.
requiredAckCount := protocol.PacketNumber(float32(c.estimatedTCPcongestionWindow) / c.alpha())
if c.ackedPacketsCount < requiredAckCount {
break
}
c.ackedPacketsCount -= requiredAckCount
c.estimatedTCPcongestionWindow++
}
// Limit the CWND increase to half the acked bytes.
targetCongestionWindow = utils.MinByteCount(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
// Increase the window by approximately Alpha * 1 MSS of bytes every
// time we ack an estimated tcp window of bytes. For small
// congestion windows (less than 25), the formula below will
// increase slightly slower than linearly per estimated tcp window
// of bytes.
c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(protocol.DefaultTCPMSS) / float32(c.estimatedTCPcongestionWindow))
c.ackedBytesCount = 0
// We have a new cubic congestion window.
c.lastTargetCongestionWindow = targetCongestionWindow
@ -218,7 +201,6 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
targetCongestionWindow = c.estimatedTCPcongestionWindow
}
return targetCongestionWindow
}

View file

@ -8,9 +8,9 @@ import (
)
const (
maxBurstBytes = 3 * protocol.DefaultTCPMSS
defaultMinimumCongestionWindow protocol.PacketNumber = 2
renoBeta float32 = 0.7 // Reno backoff factor.
maxBurstBytes = 3 * protocol.DefaultTCPMSS
renoBeta float32 = 0.7 // Reno backoff factor.
defaultMinimumCongestionWindow protocol.ByteCount = 2 * protocol.DefaultTCPMSS
)
type cubicSender struct {
@ -31,12 +31,6 @@ type cubicSender struct {
// Track the largest packet number outstanding when a CWND cutback occurs.
largestSentAtLastCutback protocol.PacketNumber
// Congestion window in packets.
congestionWindow protocol.PacketNumber
// Slow start congestion window in packets, aka ssthresh.
slowstartThreshold protocol.PacketNumber
// Whether the last loss event caused us to exit slowstart.
// Used for stats collection of slowstartPacketsLost
lastCutbackExitedSlowstart bool
@ -44,24 +38,35 @@ type cubicSender struct {
// When true, exit slow start with large cutback of congestion window.
slowStartLargeReduction bool
// Minimum congestion window in packets.
minCongestionWindow protocol.PacketNumber
// Congestion window in packets.
congestionWindow protocol.ByteCount
// Maximum number of outstanding packets for tcp.
maxTCPCongestionWindow protocol.PacketNumber
// Minimum congestion window in packets.
minCongestionWindow protocol.ByteCount
// Maximum congestion window.
maxCongestionWindow protocol.ByteCount
// Slow start congestion window in bytes, aka ssthresh.
slowstartThreshold protocol.ByteCount
// Number of connections to simulate.
numConnections int
// ACK counter for the Reno implementation.
congestionWindowCount protocol.ByteCount
numAckedPackets uint64
initialCongestionWindow protocol.PacketNumber
initialMaxCongestionWindow protocol.PacketNumber
initialCongestionWindow protocol.ByteCount
initialMaxCongestionWindow protocol.ByteCount
minSlowStartExitWindow protocol.ByteCount
}
var _ SendAlgorithm = &cubicSender{}
var _ SendAlgorithmWithDebugInfo = &cubicSender{}
// NewCubicSender makes a new cubic sender
func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.PacketNumber) SendAlgorithmWithDebugInfo {
func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) SendAlgorithmWithDebugInfo {
return &cubicSender{
rttStats: rttStats,
initialCongestionWindow: initialCongestionWindow,
@ -69,7 +74,7 @@ func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestio
congestionWindow: initialCongestionWindow,
minCongestionWindow: defaultMinimumCongestionWindow,
slowstartThreshold: initialMaxCongestionWindow,
maxTCPCongestionWindow: initialMaxCongestionWindow,
maxCongestionWindow: initialMaxCongestionWindow,
numConnections: defaultNumConnections,
cubic: NewCubic(clock),
reno: reno,
@ -80,21 +85,26 @@ func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestio
func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration {
if c.InRecovery() {
// PRR is used when in recovery.
if c.prr.TimeUntilSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) == 0 {
if c.prr.CanSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) {
return 0
}
}
delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow()/protocol.DefaultTCPMSS)
delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow())
if !c.InSlowStart() { // adjust delay, such that it's 1.25*cwd/rtt
delay = delay * 8 / 5
}
return delay
}
func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {
// Only update bytesInFlight for data packets.
func (c *cubicSender) OnPacketSent(
sentTime time.Time,
bytesInFlight protocol.ByteCount,
packetNumber protocol.PacketNumber,
bytes protocol.ByteCount,
isRetransmittable bool,
) {
if !isRetransmittable {
return false
return
}
if c.InRecovery() {
// PRR is used when in recovery.
@ -102,7 +112,6 @@ func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.By
}
c.largestSentPacketNumber = packetNumber
c.hybridSlowStart.OnPacketSent(packetNumber)
return true
}
func (c *cubicSender) InRecovery() bool {
@ -114,18 +123,18 @@ func (c *cubicSender) InSlowStart() bool {
}
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
return protocol.ByteCount(c.congestionWindow) * protocol.DefaultTCPMSS
return c.congestionWindow
}
func (c *cubicSender) GetSlowStartThreshold() protocol.ByteCount {
return protocol.ByteCount(c.slowstartThreshold) * protocol.DefaultTCPMSS
return c.slowstartThreshold
}
func (c *cubicSender) ExitSlowstart() {
c.slowstartThreshold = c.congestionWindow
}
func (c *cubicSender) SlowstartThreshold() protocol.PacketNumber {
func (c *cubicSender) SlowstartThreshold() protocol.ByteCount {
return c.slowstartThreshold
}
@ -135,20 +144,29 @@ func (c *cubicSender) MaybeExitSlowStart() {
}
}
func (c *cubicSender) OnPacketAcked(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
func (c *cubicSender) OnPacketAcked(
ackedPacketNumber protocol.PacketNumber,
ackedBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
eventTime time.Time,
) {
c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
if c.InRecovery() {
// PRR is used when in recovery.
c.prr.OnPacketAcked(ackedBytes)
return
}
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, bytesInFlight)
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
if c.InSlowStart() {
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
}
}
func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
func (c *cubicSender) OnPacketLost(
packetNumber protocol.PacketNumber,
lostBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
) {
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
// already sent should be treated as a single loss event, since it's expected.
if packetNumber <= c.largestSentAtLastCutback {
@ -156,10 +174,8 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
c.stats.slowstartPacketsLost++
c.stats.slowstartBytesLost += lostBytes
if c.slowStartLargeReduction {
if c.stats.slowstartPacketsLost == 1 || (c.stats.slowstartBytesLost/protocol.DefaultTCPMSS) > (c.stats.slowstartBytesLost-lostBytes)/protocol.DefaultTCPMSS {
// Reduce congestion window by 1 for every mss of bytes lost.
c.congestionWindow = utils.MaxPacketNumber(c.congestionWindow-1, c.minCongestionWindow)
}
// Reduce congestion window by lost_bytes for every loss.
c.congestionWindow = utils.MaxByteCount(c.congestionWindow-lostBytes, c.minSlowStartExitWindow)
c.slowstartThreshold = c.congestionWindow
}
}
@ -170,17 +186,19 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
c.stats.slowstartPacketsLost++
}
c.prr.OnPacketLost(bytesInFlight)
c.prr.OnPacketLost(priorInFlight)
// TODO(chromium): Separate out all of slow start into a separate class.
if c.slowStartLargeReduction && c.InSlowStart() {
c.congestionWindow = c.congestionWindow - 1
if c.congestionWindow >= 2*c.initialCongestionWindow {
c.minSlowStartExitWindow = c.congestionWindow / 2
}
c.congestionWindow = c.congestionWindow - protocol.DefaultTCPMSS
} else if c.reno {
c.congestionWindow = protocol.PacketNumber(float32(c.congestionWindow) * c.RenoBeta())
c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta())
} else {
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
}
// Enforce a minimum congestion window.
if c.congestionWindow < c.minCongestionWindow {
c.congestionWindow = c.minCongestionWindow
}
@ -188,7 +206,7 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
c.largestSentAtLastCutback = c.largestSentPacketNumber
// reset packet count from congestion avoidance mode. We start
// counting again when we're out of recovery.
c.congestionWindowCount = 0
c.numAckedPackets = 0
}
func (c *cubicSender) RenoBeta() float32 {
@ -201,32 +219,38 @@ func (c *cubicSender) RenoBeta() float32 {
// Called when we receive an ack. Normal TCP tracks how many packets one ack
// represents, but quic has a separate ack for each packet.
func (c *cubicSender) maybeIncreaseCwnd(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
func (c *cubicSender) maybeIncreaseCwnd(
ackedPacketNumber protocol.PacketNumber,
ackedBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
eventTime time.Time,
) {
// Do not increase the congestion window unless the sender is close to using
// the current window.
if !c.isCwndLimited(bytesInFlight) {
if !c.isCwndLimited(priorInFlight) {
c.cubic.OnApplicationLimited()
return
}
if c.congestionWindow >= c.maxTCPCongestionWindow {
if c.congestionWindow >= c.maxCongestionWindow {
return
}
if c.InSlowStart() {
// TCP slow start, exponential growth, increase by one for each ACK.
c.congestionWindow++
c.congestionWindow += protocol.DefaultTCPMSS
return
}
// Congestion avoidance
if c.reno {
// Classic Reno congestion avoidance.
c.congestionWindowCount++
c.numAckedPackets++
// Divide by num_connections to smoothly increase the CWND at a faster
// rate than conventional Reno.
if protocol.PacketNumber(c.congestionWindowCount*protocol.ByteCount(c.numConnections)) >= c.congestionWindow {
c.congestionWindow++
c.congestionWindowCount = 0
if c.numAckedPackets*uint64(c.numConnections) >= uint64(c.congestionWindow)/uint64(protocol.DefaultTCPMSS) {
c.congestionWindow += protocol.DefaultTCPMSS
c.numAckedPackets = 0
}
} else {
c.congestionWindow = utils.MinPacketNumber(c.maxTCPCongestionWindow, c.cubic.CongestionWindowAfterAck(c.congestionWindow, c.rttStats.MinRTT()))
c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow, c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
}
}
@ -282,10 +306,10 @@ func (c *cubicSender) OnConnectionMigration() {
c.largestSentAtLastCutback = 0
c.lastCutbackExitedSlowstart = false
c.cubic.Reset()
c.congestionWindowCount = 0
c.numAckedPackets = 0
c.congestionWindow = c.initialCongestionWindow
c.slowstartThreshold = c.initialMaxCongestionWindow
c.maxTCPCongestionWindow = c.initialMaxCongestionWindow
c.maxCongestionWindow = c.initialMaxCongestionWindow
}
// SetSlowStartLargeReduction allows enabling the SSLR experiment

View file

@ -9,11 +9,11 @@ import (
// A SendAlgorithm performs congestion control and calculates the congestion window
type SendAlgorithm interface {
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool)
GetCongestionWindow() protocol.ByteCount
MaybeExitSlowStart()
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount)
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount)
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time)
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
SetNumEmulatedConnections(n int)
OnRetransmissionTimeout(packetsRetransmitted bool)
OnConnectionMigration()
@ -30,7 +30,7 @@ type SendAlgorithmWithDebugInfo interface {
// Stuff only used in testing
HybridSlowStart() *HybridSlowStart
SlowstartThreshold() protocol.PacketNumber
SlowstartThreshold() protocol.ByteCount
RenoBeta() float32
InRecovery() bool
}

View file

@ -1,10 +1,7 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937
@ -23,9 +20,9 @@ func (p *PrrSender) OnPacketSent(sentBytes protocol.ByteCount) {
// OnPacketLost should be called on the first loss that triggers a recovery
// period and all other methods in this class should only be called when in
// recovery.
func (p *PrrSender) OnPacketLost(bytesInFlight protocol.ByteCount) {
func (p *PrrSender) OnPacketLost(priorInFlight protocol.ByteCount) {
p.bytesSentSinceLoss = 0
p.bytesInFlightBeforeLoss = bytesInFlight
p.bytesInFlightBeforeLoss = priorInFlight
p.bytesDeliveredSinceLoss = 0
p.ackCountSinceLoss = 0
}
@ -36,28 +33,22 @@ func (p *PrrSender) OnPacketAcked(ackedBytes protocol.ByteCount) {
p.ackCountSinceLoss++
}
// TimeUntilSend calculates the time until a packet can be sent
func (p *PrrSender) TimeUntilSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) time.Duration {
// CanSend returns if packets can be sent
func (p *PrrSender) CanSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) bool {
// Return QuicTime::Zero In order to ensure limited transmit always works.
if p.bytesSentSinceLoss == 0 || bytesInFlight < protocol.DefaultTCPMSS {
return 0
return true
}
if congestionWindow > bytesInFlight {
// During PRR-SSRB, limit outgoing packets to 1 extra MSS per ack, instead
// of sending the entire available window. This prevents burst retransmits
// when more packets are lost than the CWND reduction.
// limit = MAX(prr_delivered - prr_out, DeliveredData) + MSS
if p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS <= p.bytesSentSinceLoss {
return utils.InfDuration
}
return 0
return p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS > p.bytesSentSinceLoss
}
// Implement Proportional Rate Reduction (RFC6937).
// Checks a simplified version of the PRR formula that doesn't use division:
// AvailableSendWindow =
// CEIL(prr_delivered * ssthresh / BytesInFlightAtLoss) - prr_sent
if p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss {
return 0
}
return utils.InfDuration
return p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss
}

View file

@ -7,50 +7,27 @@ import (
)
const (
// Note: This constant is also defined in the ackhandler package.
initialRTTus = 100 * 1000
rttAlpha float32 = 0.125
oneMinusAlpha float32 = (1 - rttAlpha)
rttBeta float32 = 0.25
oneMinusBeta float32 = (1 - rttBeta)
halfWindow float32 = 0.5
quarterWindow float32 = 0.25
// The default RTT used before an RTT sample is taken.
defaultInitialRTT = 100 * time.Millisecond
)
type rttSample struct {
rtt time.Duration
time time.Time
}
// RTTStats provides round-trip statistics
type RTTStats struct {
initialRTTus int64
recentMinRTTwindow time.Duration
minRTT time.Duration
latestRTT time.Duration
smoothedRTT time.Duration
meanDeviation time.Duration
numMinRTTsamplesRemaining uint32
newMinRTT rttSample
recentMinRTT rttSample
halfWindowRTT rttSample
quarterWindowRTT rttSample
minRTT time.Duration
latestRTT time.Duration
smoothedRTT time.Duration
meanDeviation time.Duration
}
// NewRTTStats makes a properly initialized RTTStats object
func NewRTTStats() *RTTStats {
return &RTTStats{
initialRTTus: initialRTTus,
recentMinRTTwindow: utils.InfDuration,
}
return &RTTStats{}
}
// InitialRTTus is the initial RTT in us
func (r *RTTStats) InitialRTTus() int64 { return r.initialRTTus }
// MinRTT Returns the minRTT for the entire connection.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
@ -59,28 +36,22 @@ func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
// May return Zero if no valid updates have occurred.
func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT }
// RecentMinRTT the minRTT since SampleNewRecentMinRtt has been called, or the
// minRTT for the entire connection if SampleNewMinRtt was never called.
func (r *RTTStats) RecentMinRTT() time.Duration { return r.recentMinRTT.rtt }
// SmoothedRTT returns the EWMA smoothed RTT for the connection.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT }
// GetQuarterWindowRTT gets the quarter window RTT
func (r *RTTStats) GetQuarterWindowRTT() time.Duration { return r.quarterWindowRTT.rtt }
// GetHalfWindowRTT gets the half window RTT
func (r *RTTStats) GetHalfWindowRTT() time.Duration { return r.halfWindowRTT.rtt }
// SmoothedOrInitialRTT returns the EWMA smoothed RTT for the connection.
// If no valid updates have occurred, it returns the initial RTT.
func (r *RTTStats) SmoothedOrInitialRTT() time.Duration {
if r.smoothedRTT != 0 {
return r.smoothedRTT
}
return defaultInitialRTT
}
// MeanDeviation gets the mean deviation
func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation }
// SetRecentMinRTTwindow sets how old a recent min rtt sample can be.
func (r *RTTStats) SetRecentMinRTTwindow(recentMinRTTwindow time.Duration) {
r.recentMinRTTwindow = recentMinRTTwindow
}
// UpdateRTT updates the RTT based on a new sample.
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
if sendDelta == utils.InfDuration || sendDelta <= 0 {
@ -94,7 +65,6 @@ func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
if r.minRTT == 0 || r.minRTT > sendDelta {
r.minRTT = sendDelta
}
r.updateRecentMinRTT(sendDelta, now)
// Correct for ackDelay if information received from the peer results in a
// an RTT sample at least as large as minRTT. Otherwise, only use the
@ -114,63 +84,12 @@ func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
}
}
func (r *RTTStats) updateRecentMinRTT(sample time.Duration, now time.Time) { // Recent minRTT update.
if r.numMinRTTsamplesRemaining > 0 {
r.numMinRTTsamplesRemaining--
if r.newMinRTT.rtt == 0 || sample <= r.newMinRTT.rtt {
r.newMinRTT = rttSample{rtt: sample, time: now}
}
if r.numMinRTTsamplesRemaining == 0 {
r.recentMinRTT = r.newMinRTT
r.halfWindowRTT = r.newMinRTT
r.quarterWindowRTT = r.newMinRTT
}
}
// Update the three recent rtt samples.
if r.recentMinRTT.rtt == 0 || sample <= r.recentMinRTT.rtt {
r.recentMinRTT = rttSample{rtt: sample, time: now}
r.halfWindowRTT = r.recentMinRTT
r.quarterWindowRTT = r.recentMinRTT
} else if sample <= r.halfWindowRTT.rtt {
r.halfWindowRTT = rttSample{rtt: sample, time: now}
r.quarterWindowRTT = r.halfWindowRTT
} else if sample <= r.quarterWindowRTT.rtt {
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
}
// Expire old min rtt samples.
if r.recentMinRTT.time.Before(now.Add(-r.recentMinRTTwindow)) {
r.recentMinRTT = r.halfWindowRTT
r.halfWindowRTT = r.quarterWindowRTT
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
} else if r.halfWindowRTT.time.Before(now.Add(-time.Duration(float32(r.recentMinRTTwindow/time.Microsecond)*halfWindow) * time.Microsecond)) {
r.halfWindowRTT = r.quarterWindowRTT
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
} else if r.quarterWindowRTT.time.Before(now.Add(-time.Duration(float32(r.recentMinRTTwindow/time.Microsecond)*quarterWindow) * time.Microsecond)) {
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
}
}
// SampleNewRecentMinRTT forces RttStats to sample a new recent min rtt within the next
// |numSamples| UpdateRTT calls.
func (r *RTTStats) SampleNewRecentMinRTT(numSamples uint32) {
r.numMinRTTsamplesRemaining = numSamples
r.newMinRTT = rttSample{}
}
// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset.
func (r *RTTStats) OnConnectionMigration() {
r.latestRTT = 0
r.minRTT = 0
r.smoothedRTT = 0
r.meanDeviation = 0
r.initialRTTus = initialRTTus
r.numMinRTTsamplesRemaining = 0
r.recentMinRTTwindow = utils.InfDuration
r.recentMinRTT = rttSample{}
r.halfWindowRTT = rttSample{}
r.quarterWindowRTT = rttSample{}
}
// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt

View file

@ -15,7 +15,7 @@ const (
// A TLSExporter gets the negotiated ciphersuite and computes exporter
type TLSExporter interface {
GetCipherSuite() mint.CipherSuiteParams
ConnectionState() mint.ConnectionState
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
}
@ -49,7 +49,7 @@ func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
}
func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) {
cs := tls.GetCipherSuite()
cs := tls.ConnectionState().CipherSuite
secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size())
if err != nil {
return nil, nil, err

View file

@ -6,7 +6,6 @@ import (
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"golang.org/x/crypto/hkdf"
)
@ -42,7 +41,7 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol
} else {
info.Write([]byte("QUIC key expansion\x00"))
}
utils.BigEndian.WriteUint64(&info, uint64(connID))
info.Write(connID)
info.Write(chlo)
info.Write(scfg)
info.Write(cert)

View file

@ -2,7 +2,6 @@ package crypto
import (
"crypto"
"encoding/binary"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
@ -28,9 +27,7 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
}
func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
connID := make([]byte, 8)
binary.BigEndian.PutUint64(connID, uint64(connectionID))
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
handshakeSecret := mint.HkdfExtract(crypto.SHA256, quicVersion1Salt, connID)
clientSecret = qhkdfExpand(handshakeSecret, "client hs", crypto.SHA256.Size())
serverSecret = qhkdfExpand(handshakeSecret, "server hs", crypto.SHA256.Size())

View file

@ -11,8 +11,9 @@ import (
type baseFlowController struct {
// for sending data
bytesSent protocol.ByteCount
sendWindow protocol.ByteCount
bytesSent protocol.ByteCount
sendWindow protocol.ByteCount
lastBlockedAt protocol.ByteCount
// for receiving data
mutex sync.RWMutex
@ -29,6 +30,17 @@ type baseFlowController struct {
logger utils.Logger
}
// IsNewlyBlocked says if it is newly blocked by flow control.
// For every offset, it only returns true once.
// If it is blocked, the offset is returned.
func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
return false, 0
}
c.lastBlockedAt = c.sendWindow
return true, c.sendWindow
}
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
c.bytesSent += n
}

View file

@ -10,8 +10,9 @@ import (
)
type connectionFlowController struct {
lastBlockedAt protocol.ByteCount
baseFlowController
queueWindowUpdate func()
}
var _ ConnectionFlowController = &connectionFlowController{}
@ -21,6 +22,7 @@ var _ ConnectionFlowController = &connectionFlowController{}
func NewConnectionFlowController(
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
queueWindowUpdate func(),
rttStats *congestion.RTTStats,
logger utils.Logger,
) ConnectionFlowController {
@ -32,6 +34,7 @@ func NewConnectionFlowController(
maxReceiveWindowSize: maxReceiveWindow,
logger: logger,
},
queueWindowUpdate: queueWindowUpdate,
}
}
@ -39,17 +42,6 @@ func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
return c.baseFlowController.sendWindowSize()
}
// IsNewlyBlocked says if it is newly blocked by flow control.
// For every offset, it only returns true once.
// If it is blocked, the offset is returned.
func (c *connectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
return false, 0
}
c.lastBlockedAt = c.sendWindow
return true, c.sendWindow
}
// IncrementHighestReceived adds an increment to the highestReceived value
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
c.mutex.Lock()
@ -62,6 +54,15 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B
return nil
}
func (c *connectionFlowController) MaybeQueueWindowUpdate() {
c.mutex.Lock()
hasWindowUpdate := c.hasWindowUpdate()
c.mutex.Unlock()
if hasWindowUpdate {
c.queueWindowUpdate()
}
}
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
c.mutex.Lock()
oldWindowSize := c.receiveWindowSize
@ -78,6 +79,7 @@ func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
c.mutex.Lock()
if inc > c.receiveWindowSize {
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10))
c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize)
c.startNewAutoTuningEpoch()
}

View file

@ -10,26 +10,22 @@ type flowController interface {
// for receiving
AddBytesRead(protocol.ByteCount)
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
MaybeQueueWindowUpdate() // queues a window update, if necessary
IsNewlyBlocked() (bool, protocol.ByteCount)
}
// A StreamFlowController is a flow controller for a QUIC stream.
type StreamFlowController interface {
flowController
// for sending
IsBlocked() (bool, protocol.ByteCount)
// for receiving
// UpdateHighestReceived should be called when a new highest offset is received
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
// HasWindowUpdate says if it is necessary to update the window
HasWindowUpdate() bool
}
// The ConnectionFlowController is the flow controller for the connection.
type ConnectionFlowController interface {
flowController
// for sending
IsNewlyBlocked() (bool, protocol.ByteCount)
}
type connectionFlowControllerI interface {

View file

@ -14,6 +14,8 @@ type streamFlowController struct {
streamID protocol.StreamID
queueWindowUpdate func()
connection connectionFlowControllerI
contributesToConnection bool // does the stream contribute to connection level flow control
@ -30,6 +32,7 @@ func NewStreamFlowController(
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
initialSendWindow protocol.ByteCount,
queueWindowUpdate func(protocol.StreamID),
rttStats *congestion.RTTStats,
logger utils.Logger,
) StreamFlowController {
@ -37,6 +40,7 @@ func NewStreamFlowController(
streamID: streamID,
contributesToConnection: contributesToConnection,
connection: cfc.(connectionFlowControllerI),
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
@ -111,20 +115,16 @@ func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
return window
}
// IsBlocked says if it is blocked by stream-level flow control.
// If it is blocked, the offset is returned.
func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) {
if c.sendWindowSize() != 0 {
return false, 0
}
return true, c.sendWindow
}
func (c *streamFlowController) HasWindowUpdate() bool {
func (c *streamFlowController) MaybeQueueWindowUpdate() {
c.mutex.Lock()
hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate()
c.mutex.Unlock()
return hasWindowUpdate
if hasWindowUpdate {
c.queueWindowUpdate()
}
if c.contributesToConnection {
c.connection.MaybeQueueWindowUpdate()
}
}
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
@ -139,7 +139,7 @@ func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
if c.contributesToConnection {
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
}

View file

@ -5,8 +5,6 @@ import (
"fmt"
"net"
"time"
"github.com/bifurcation/mint"
)
const (
@ -29,12 +27,12 @@ type token struct {
// A CookieGenerator generates Cookies
type CookieGenerator struct {
cookieProtector mint.CookieProtector
cookieProtector cookieProtector
}
// NewCookieGenerator initializes a new CookieGenerator
func NewCookieGenerator() (*CookieGenerator, error) {
cookieProtector, err := mint.NewDefaultCookieProtector()
cookieProtector, err := newCookieProtector()
if err != nil {
return nil, err
}

View file

@ -1,51 +0,0 @@
package handshake
import (
"net"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A CookieHandler generates and validates cookies.
// The cookie is sent in the TLS Retry.
// By including the cookie in its ClientHello, a client can proof ownership of its source address.
type CookieHandler struct {
callback func(net.Addr, *Cookie) bool
cookieGenerator *CookieGenerator
logger utils.Logger
}
var _ mint.CookieHandler = &CookieHandler{}
// NewCookieHandler creates a new CookieHandler.
func NewCookieHandler(callback func(net.Addr, *Cookie) bool, logger utils.Logger) (*CookieHandler, error) {
cookieGenerator, err := NewCookieGenerator()
if err != nil {
return nil, err
}
return &CookieHandler{
callback: callback,
cookieGenerator: cookieGenerator,
logger: logger,
}, nil
}
// Generate a new cookie for a mint connection.
func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
if h.callback(conn.RemoteAddr(), nil) {
return nil, nil
}
return h.cookieGenerator.NewToken(conn.RemoteAddr())
}
// Validate a cookie.
func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
data, err := h.cookieGenerator.DecodeToken(token)
if err != nil {
h.logger.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
return false
}
return h.callback(conn.RemoteAddr(), data)
}

View file

@ -0,0 +1,86 @@
package handshake
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
)
// CookieProtector is used to create and verify a cookie
type cookieProtector interface {
// NewToken creates a new token
NewToken([]byte) ([]byte, error)
// DecodeToken decodes a token
DecodeToken([]byte) ([]byte, error)
}
const (
cookieSecretSize = 32
cookieNonceSize = 32
)
// cookieProtector is used to create and verify a cookie
type cookieProtectorImpl struct {
secret []byte
}
// newCookieProtector creates a source for source address tokens
func newCookieProtector() (cookieProtector, error) {
secret := make([]byte, cookieSecretSize)
if _, err := rand.Read(secret); err != nil {
return nil, err
}
return &cookieProtectorImpl{secret: secret}, nil
}
// NewToken encodes data into a new token.
func (s *cookieProtectorImpl) NewToken(data []byte) ([]byte, error) {
nonce := make([]byte, cookieNonceSize)
if _, err := rand.Read(nonce); err != nil {
return nil, err
}
aead, aeadNonce, err := s.createAEAD(nonce)
if err != nil {
return nil, err
}
return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil
}
// DecodeToken decodes a token.
func (s *cookieProtectorImpl) DecodeToken(p []byte) ([]byte, error) {
if len(p) < cookieNonceSize {
return nil, fmt.Errorf("Token too short: %d", len(p))
}
nonce := p[:cookieNonceSize]
aead, aeadNonce, err := s.createAEAD(nonce)
if err != nil {
return nil, err
}
return aead.Open(nil, aeadNonce, p[cookieNonceSize:], nil)
}
func (s *cookieProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go cookie source"))
key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
if _, err := io.ReadFull(h, key); err != nil {
return nil, nil, err
}
aeadNonce := make([]byte, 12)
if _, err := io.ReadFull(h, aeadNonce); err != nil {
return nil, nil, err
}
c, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}
aead, err := cipher.NewGCM(c)
if err != nil {
return nil, nil, err
}
return aead, aeadNonce, nil
}

View file

@ -38,7 +38,7 @@ type cryptoSetupClient struct {
lastSentCHLO []byte
certManager crypto.CertManager
divNonceChan <-chan []byte
divNonceChan chan struct{}
diversificationNonce []byte
clientHelloCounter int
@ -79,29 +79,31 @@ func NewCryptoSetupClient(
initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber,
logger utils.Logger,
) (CryptoSetup, chan<- []byte, error) {
) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil {
return nil, nil, err
return nil, err
}
divNonceChan := make(chan []byte)
divNonceChan := make(chan struct{})
cs := &cryptoSetupClient{
cryptoStream: cryptoStream,
hostname: hostname,
connID: connID,
version: version,
certManager: crypto.NewCertManager(tlsConfig),
params: params,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
nullAEAD: nullAEAD,
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
initialVersion: initialVersion,
negotiatedVersions: negotiatedVersions,
cryptoStream: cryptoStream,
hostname: hostname,
connID: connID,
version: version,
certManager: crypto.NewCertManager(tlsConfig),
params: params,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
nullAEAD: nullAEAD,
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
initialVersion: initialVersion,
// The server might have sent greased versions in the Version Negotiation packet.
// We need strip those from the list, since they won't be included in the handshake tag.
negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions),
divNonceChan: divNonceChan,
logger: logger,
}
return cs, divNonceChan, nil
return cs, nil
}
func (h *cryptoSetupClient) HandleCryptoStream() error {
@ -120,33 +122,26 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
}()
for {
err := h.maybeUpgradeCrypto()
if err != nil {
if err := h.maybeUpgradeCrypto(); err != nil {
return err
}
h.mutex.RLock()
sendCHLO := h.secureAEAD == nil
h.mutex.RUnlock()
if sendCHLO {
err = h.sendCHLO()
if err != nil {
if err := h.sendCHLO(); err != nil {
return err
}
}
var message HandshakeMessage
select {
case divNonce := <-h.divNonceChan:
if len(h.diversificationNonce) != 0 && !bytes.Equal(h.diversificationNonce, divNonce) {
return errConflictingDiversificationNonces
}
h.diversificationNonce = divNonce
case <-h.divNonceChan:
// there's no message to process, but we should try upgrading the crypto again
continue
case message = <-messageChan:
case err = <-errorChan:
case err := <-errorChan:
return err
}
@ -281,6 +276,7 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*Trans
if err != nil {
return nil, err
}
h.logger.Debugf("Creating AEAD for forward-secure encryption. Stopping to accept all lower encryption levels.")
params, err := readHelloMap(cryptoData)
if err != nil {
@ -326,6 +322,7 @@ func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNu
if h.secureAEAD != nil {
data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
h.receivedSecurePacket = true
return data, protocol.EncryptionSecure, nil
}
@ -386,6 +383,21 @@ func (h *cryptoSetupClient) ConnectionState() ConnectionState {
}
}
func (h *cryptoSetupClient) SetDiversificationNonce(divNonce []byte) error {
h.mutex.Lock()
if len(h.diversificationNonce) > 0 {
defer h.mutex.Unlock()
if !bytes.Equal(h.diversificationNonce, divNonce) {
return errConflictingDiversificationNonces
}
return nil
}
h.diversificationNonce = divNonce
h.mutex.Unlock()
h.divNonceChan <- struct{}{}
return nil
}
func (h *cryptoSetupClient) sendCHLO() error {
h.clientHelloCounter++
if h.clientHelloCounter > protocol.MaxClientHellos {
@ -501,6 +513,7 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
if err != nil {
return err
}
h.logger.Debugf("Creating AEAD for secure encryption.")
h.handshakeEvent <- struct{}{}
}
return nil

View file

@ -60,11 +60,6 @@ type cryptoSetupServer struct {
var _ CryptoSetup = &cryptoSetupServer{}
// ErrHOLExperiment is returned when the client sends the FHL2 tag in the CHLO.
// This is an experiment implemented by Chrome in QUIC 36, which we don't support.
// TODO: remove this when dropping support for QUIC 36
var ErrHOLExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "HOL experiment. Unsupported")
// ErrNSTPExperiment is returned when the client sends the NSTP tag in the CHLO.
// This is an experiment implemented by Chrome in QUIC 38, which we don't support at this point.
var ErrNSTPExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "NSTP experiment. Unsupported")
@ -132,9 +127,6 @@ func (h *cryptoSetupServer) HandleCryptoStream() error {
}
func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]byte) (bool, error) {
if _, isHOLExperiment := cryptoData[TagFHL2]; isHOLExperiment {
return false, ErrHOLExperiment
}
if _, isNSTPExperiment := cryptoData[TagNSTP]; isNSTPExperiment {
return false, ErrNSTPExperiment
}
@ -214,6 +206,7 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
h.logger.Debugf("Received first forward-secure packet. Stopping to accept all lower encryption levels.")
h.receivedForwardSecurePacket = true
// wait for the send on the handshakeEvent chan
<-h.sentSHLO
@ -228,6 +221,7 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
if h.secureAEAD != nil {
res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
h.receivedSecurePacket = true
return res, protocol.EncryptionSecure, nil
}
@ -400,6 +394,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
if err != nil {
return nil, err
}
h.logger.Debugf("Creating AEAD for secure encryption.")
h.handshakeEvent <- struct{}{}
// Generate a new curve instance to derive the forward secure key
@ -429,6 +424,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
if err != nil {
return nil, err
}
h.logger.Debugf("Creating AEAD for forward-secure encryption.")
replyMap := h.params.getHelloMap()
// add crypto parameters

View file

@ -11,9 +11,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// ErrCloseSessionForRetry is returned by HandleCryptoStream when the server wishes to perform a stateless retry
var ErrCloseSessionForRetry = errors.New("closing session in order to recreate after a retry")
// KeyDerivationFunction is used for key derivation
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
@ -26,8 +23,8 @@ type cryptoSetupTLS struct {
nullAEAD crypto.AEAD
aead crypto.AEAD
tls MintTLS
cryptoStream *CryptoStreamConn
tls mintTLS
conn *cryptoStreamConn
handshakeEvent chan<- struct{}
}
@ -35,39 +32,46 @@ var _ CryptoSetupTLS = &cryptoSetupTLS{}
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
func NewCryptoSetupTLSServer(
tls MintTLS,
cryptoStream *CryptoStreamConn,
nullAEAD crypto.AEAD,
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
config *mint.Config,
handshakeEvent chan<- struct{},
version protocol.VersionNumber,
) CryptoSetupTLS {
) (CryptoSetupTLS, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil {
return nil, err
}
conn := newCryptoStreamConn(cryptoStream)
tls := mint.Server(conn, config)
return &cryptoSetupTLS{
tls: tls,
cryptoStream: cryptoStream,
conn: conn,
nullAEAD: nullAEAD,
perspective: protocol.PerspectiveServer,
keyDerivation: crypto.DeriveAESKeys,
handshakeEvent: handshakeEvent,
}
}, nil
}
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
func NewCryptoSetupTLSClient(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
hostname string,
config *mint.Config,
handshakeEvent chan<- struct{},
tls MintTLS,
version protocol.VersionNumber,
) (CryptoSetupTLS, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil {
return nil, err
}
conn := newCryptoStreamConn(cryptoStream)
tls := mint.Client(conn, config)
return &cryptoSetupTLS{
perspective: protocol.PerspectiveClient,
tls: tls,
conn: conn,
perspective: protocol.PerspectiveClient,
nullAEAD: nullAEAD,
keyDerivation: crypto.DeriveAESKeys,
handshakeEvent: handshakeEvent,
@ -75,24 +79,16 @@ func NewCryptoSetupTLSClient(
}
func (h *cryptoSetupTLS) HandleCryptoStream() error {
if h.perspective == protocol.PerspectiveServer {
// mint already wrote the ServerHello, EncryptedExtensions and the certificate chain to the buffer
// send out that data now
if _, err := h.cryptoStream.Flush(); err != nil {
return err
}
}
handshakeLoop:
for {
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
}
switch h.tls.State() {
case mint.StateClientStart: // this happens if a stateless retry is performed
return ErrCloseSessionForRetry
case mint.StateClientConnected, mint.StateServerConnected:
break handshakeLoop
state := h.tls.ConnectionState().HandshakeState
if err := h.conn.Flush(); err != nil {
return err
}
if state == mint.StateClientConnected || state == mint.StateServerConnected {
break
}
}

View file

@ -7,95 +7,63 @@ import (
"time"
)
// The CryptoStreamConn is used as the net.Conn passed to mint.
// It has two operating modes:
// 1. It can read and write to bytes.Buffers.
// 2. It can use a quic.Stream for reading and writing.
// The buffer-mode is only used by the server, in order to statelessly handle retries.
type CryptoStreamConn struct {
remoteAddr net.Addr
// the buffers are used before the session is initialized
readBuf bytes.Buffer
writeBuf bytes.Buffer
// stream will be set once the session is initialized
type cryptoStreamConn struct {
buffer *bytes.Buffer
stream io.ReadWriter
}
var _ net.Conn = &CryptoStreamConn{}
var _ net.Conn = &cryptoStreamConn{}
// NewCryptoStreamConn creates a new CryptoStreamConn
func NewCryptoStreamConn(remoteAddr net.Addr) *CryptoStreamConn {
return &CryptoStreamConn{remoteAddr: remoteAddr}
}
func (c *CryptoStreamConn) Read(b []byte) (int, error) {
if c.stream != nil {
return c.stream.Read(b)
func newCryptoStreamConn(stream io.ReadWriter) *cryptoStreamConn {
return &cryptoStreamConn{
stream: stream,
buffer: &bytes.Buffer{},
}
return c.readBuf.Read(b)
}
// AddDataForReading adds data to the read buffer.
// This data will ONLY be read when the stream has not been set.
func (c *CryptoStreamConn) AddDataForReading(data []byte) {
c.readBuf.Write(data)
func (c *cryptoStreamConn) Read(b []byte) (int, error) {
return c.stream.Read(b)
}
func (c *CryptoStreamConn) Write(p []byte) (int, error) {
if c.stream != nil {
return c.stream.Write(p)
func (c *cryptoStreamConn) Write(p []byte) (int, error) {
return c.buffer.Write(p)
}
func (c *cryptoStreamConn) Flush() error {
if c.buffer.Len() == 0 {
return nil
}
return c.writeBuf.Write(p)
}
// GetDataForWriting returns all data currently in the write buffer, and resets this buffer.
func (c *CryptoStreamConn) GetDataForWriting() []byte {
defer c.writeBuf.Reset()
data := make([]byte, c.writeBuf.Len())
copy(data, c.writeBuf.Bytes())
return data
}
// SetStream sets the stream.
// After setting the stream, the read and write buffer won't be used any more.
func (c *CryptoStreamConn) SetStream(stream io.ReadWriter) {
c.stream = stream
}
// Flush copies the contents of the write buffer to the stream
func (c *CryptoStreamConn) Flush() (int, error) {
n, err := io.Copy(c.stream, &c.writeBuf)
return int(n), err
_, err := c.stream.Write(c.buffer.Bytes())
c.buffer.Reset()
return err
}
// Close is not implemented
func (c *CryptoStreamConn) Close() error {
func (c *cryptoStreamConn) Close() error {
return nil
}
// LocalAddr is not implemented
func (c *CryptoStreamConn) LocalAddr() net.Addr {
func (c *cryptoStreamConn) LocalAddr() net.Addr {
return nil
}
// RemoteAddr returns the remote address
func (c *CryptoStreamConn) RemoteAddr() net.Addr {
return c.remoteAddr
// RemoteAddr is not implemented
func (c *cryptoStreamConn) RemoteAddr() net.Addr {
return nil
}
// SetReadDeadline is not implemented
func (c *CryptoStreamConn) SetReadDeadline(time.Time) error {
func (c *cryptoStreamConn) SetReadDeadline(time.Time) error {
return nil
}
// SetWriteDeadline is not implemented
func (c *CryptoStreamConn) SetWriteDeadline(time.Time) error {
func (c *cryptoStreamConn) SetWriteDeadline(time.Time) error {
return nil
}
// SetDeadline is not implemented
func (c *CryptoStreamConn) SetDeadline(time.Time) error {
func (c *cryptoStreamConn) SetDeadline(time.Time) error {
return nil
}

View file

@ -35,7 +35,7 @@ func getEphermalKEX() (crypto.KeyExchange, error) {
kexMutex.Lock()
defer kexMutex.Unlock()
// Check if still unfulfilled
if kexCurrent == nil || time.Since(kexCurrentTime) > kexLifetime {
if kexCurrent == nil || time.Since(kexCurrentTime) >= kexLifetime {
kex, err := crypto.NewCurve25519KEX()
if err != nil {
return nil, err

View file

@ -2,7 +2,6 @@ package handshake
import (
"crypto/x509"
"io"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
@ -15,6 +14,12 @@ type Sealer interface {
Overhead() int
}
// mintTLS combines some methods needed to interact with mint.
type mintTLS interface {
crypto.TLSExporter
Handshake() mint.Alert
}
// A TLSExtensionHandler sends and received the QUIC TLS extension.
// It provides the parameters sent by the peer on a channel.
type TLSExtensionHandler interface {
@ -23,18 +28,6 @@ type TLSExtensionHandler interface {
GetPeerParams() <-chan TransportParameters
}
// MintTLS combines some methods needed to interact with mint.
type MintTLS interface {
crypto.TLSExporter
// additional methods
Handshake() mint.Alert
State() mint.State
ConnectionState() mint.ConnectionState
SetCryptoStream(io.ReadWriter)
}
type baseCryptoSetup interface {
HandleCryptoStream() error
ConnectionState() ConnectionState

View file

@ -0,0 +1,3 @@
package handshake
//go:generate sh -c "../mockgen_internal.sh handshake mock_mint_tls_test.go github.com/lucas-clemente/quic-go/internal/handshake mintTLS"

View file

@ -50,10 +50,6 @@ const (
// TagSFCW is the initial stream flow control receive window.
TagSFCW Tag = 'S' + 'F'<<8 + 'C'<<16 + 'W'<<24
// TagFHL2 forces head of line blocking.
// Chrome experiment (see https://codereview.chromium.org/2115033002)
// unsupported by quic-go
TagFHL2 Tag = 'F' + 'H'<<8 + 'L'<<16 + '2'<<24
// TagNSTP is the no STOP_WAITING experiment
// currently unsupported by quic-go
TagNSTP Tag = 'N' + 'S'<<8 + 'T'<<16 + 'P'<<24

View file

@ -1,38 +1,106 @@
package handshake
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type transportParameterID uint16
const quicTLSExtensionType = 26
const quicTLSExtensionType = 0xff5
const (
initialMaxStreamDataParameterID transportParameterID = 0x0
initialMaxDataParameterID transportParameterID = 0x1
initialMaxStreamsBiDiParameterID transportParameterID = 0x2
initialMaxBidiStreamsParameterID transportParameterID = 0x2
idleTimeoutParameterID transportParameterID = 0x3
omitConnectionIDParameterID transportParameterID = 0x4
maxPacketSizeParameterID transportParameterID = 0x5
statelessResetTokenParameterID transportParameterID = 0x6
initialMaxStreamsUniParameterID transportParameterID = 0x8
initialMaxUniStreamsParameterID transportParameterID = 0x8
disableMigrationParameterID transportParameterID = 0x9
)
type transportParameter struct {
Parameter transportParameterID
Value []byte `tls:"head=2"`
type clientHelloTransportParameters struct {
InitialVersion protocol.VersionNumber
Parameters TransportParameters
}
type clientHelloTransportParameters struct {
InitialVersion uint32 // actually a protocol.VersionNumber
Parameters []transportParameter `tls:"head=2"`
func (p *clientHelloTransportParameters) Marshal() []byte {
const lenOffset = 4
b := &bytes.Buffer{}
utils.BigEndian.WriteUint32(b, uint32(p.InitialVersion))
b.Write([]byte{0, 0}) // length. Will be replaced later
p.Parameters.marshal(b)
data := b.Bytes()
binary.BigEndian.PutUint16(data[lenOffset:lenOffset+2], uint16(len(data)-lenOffset-2))
return data
}
func (p *clientHelloTransportParameters) Unmarshal(data []byte) error {
if len(data) < 6 {
return errors.New("transport parameter data too short")
}
p.InitialVersion = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4]))
paramsLen := int(binary.BigEndian.Uint16(data[4:6]))
data = data[6:]
if len(data) != paramsLen {
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
}
return p.Parameters.unmarshal(data)
}
type encryptedExtensionsTransportParameters struct {
NegotiatedVersion uint32 // actually a protocol.VersionNumber
SupportedVersions []uint32 `tls:"head=1"` // actually a protocol.VersionNumber
Parameters []transportParameter `tls:"head=2"`
NegotiatedVersion protocol.VersionNumber
SupportedVersions []protocol.VersionNumber
Parameters TransportParameters
}
func (p *encryptedExtensionsTransportParameters) Marshal() []byte {
b := &bytes.Buffer{}
utils.BigEndian.WriteUint32(b, uint32(p.NegotiatedVersion))
b.WriteByte(uint8(4 * len(p.SupportedVersions)))
for _, v := range p.SupportedVersions {
utils.BigEndian.WriteUint32(b, uint32(v))
}
lenOffset := b.Len()
b.Write([]byte{0, 0}) // length. Will be replaced later
p.Parameters.marshal(b)
data := b.Bytes()
binary.BigEndian.PutUint16(data[lenOffset:lenOffset+2], uint16(len(data)-lenOffset-2))
return data
}
func (p *encryptedExtensionsTransportParameters) Unmarshal(data []byte) error {
if len(data) < 5 {
return errors.New("transport parameter data too short")
}
p.NegotiatedVersion = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4]))
numVersions := int(data[4])
if numVersions%4 != 0 {
return fmt.Errorf("invalid length for version list: %d", numVersions)
}
numVersions /= 4
data = data[5:]
if len(data) < 4*numVersions+2 /*length field for the parameter list */ {
return errors.New("transport parameter data too short")
}
p.SupportedVersions = make([]protocol.VersionNumber, numVersions)
for i := 0; i < numVersions; i++ {
p.SupportedVersions[i] = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4]))
data = data[4:]
}
paramsLen := int(binary.BigEndian.Uint16(data[:2]))
data = data[2:]
if len(data) != paramsLen {
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
}
return p.Parameters.unmarshal(data)
}
type tlsExtensionBody struct {

View file

@ -7,7 +7,6 @@ import (
"github.com/lucas-clemente/quic-go/qerr"
"github.com/bifurcation/mint"
"github.com/bifurcation/mint/syntax"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
@ -52,16 +51,12 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi
if hType != mint.HandshakeTypeClientHello {
return nil
}
h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
data, err := syntax.Marshal(clientHelloTransportParameters{
InitialVersion: uint32(h.initialVersion),
Parameters: h.ourParams.getTransportParameters(),
})
if err != nil {
return err
chtp := &clientHelloTransportParameters{
InitialVersion: h.initialVersion,
Parameters: *h.ourParams,
}
return el.Add(&tlsExtensionBody{data})
return el.Add(&tlsExtensionBody{data: chtp.Marshal()})
}
func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
@ -84,50 +79,31 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte
}
eetp := &encryptedExtensionsTransportParameters{}
if _, err := syntax.Unmarshal(ext.data, eetp); err != nil {
if err := eetp.Unmarshal(ext.data); err != nil {
return err
}
serverSupportedVersions := make([]protocol.VersionNumber, len(eetp.SupportedVersions))
for i, v := range eetp.SupportedVersions {
serverSupportedVersions[i] = protocol.VersionNumber(v)
}
// check that the negotiated_version is the current version
if protocol.VersionNumber(eetp.NegotiatedVersion) != h.version {
if eetp.NegotiatedVersion != h.version {
return qerr.Error(qerr.VersionNegotiationMismatch, "current version doesn't match negotiated_version")
}
// check that the current version is included in the supported versions
if !protocol.IsSupportedVersion(serverSupportedVersions, h.version) {
if !protocol.IsSupportedVersion(eetp.SupportedVersions, h.version) {
return qerr.Error(qerr.VersionNegotiationMismatch, "current version not included in the supported versions")
}
// if version negotiation was performed, check that we would have selected the current version based on the supported versions sent by the server
if h.version != h.initialVersion {
negotiatedVersion, ok := protocol.ChooseSupportedVersion(h.supportedVersions, serverSupportedVersions)
negotiatedVersion, ok := protocol.ChooseSupportedVersion(h.supportedVersions, eetp.SupportedVersions)
if !ok || h.version != negotiatedVersion {
return qerr.Error(qerr.VersionNegotiationMismatch, "would have picked a different version")
}
}
// check that the server sent the stateless reset token
var foundStatelessResetToken bool
for _, p := range eetp.Parameters {
if p.Parameter == statelessResetTokenParameterID {
if len(p.Value) != 16 {
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", len(p.Value))
}
foundStatelessResetToken = true
// TODO: handle this value
}
}
if !foundStatelessResetToken {
// TODO: return the right error here
// check that the server sent a stateless reset token
if len(eetp.Parameters.StatelessResetToken) == 0 {
return errors.New("server didn't sent stateless_reset_token")
}
params, err := readTransportParameters(eetp.Parameters)
if err != nil {
return err
}
h.logger.Debugf("Received Transport Parameters: %s", params)
h.paramsChan <- *params
h.logger.Debugf("Received Transport Parameters: %s", &eetp.Parameters)
h.paramsChan <- eetp.Parameters
return nil
}

View file

@ -1,14 +1,12 @@
package handshake
import (
"bytes"
"errors"
"fmt"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/bifurcation/mint"
"github.com/bifurcation/mint/syntax"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
@ -49,27 +47,13 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi
if hType != mint.HandshakeTypeEncryptedExtensions {
return nil
}
transportParams := append(
h.ourParams.getTransportParameters(),
// TODO(#855): generate a real token
transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)},
)
supportedVersions := protocol.GetGreasedVersions(h.supportedVersions)
versions := make([]uint32, len(supportedVersions))
for i, v := range supportedVersions {
versions[i] = uint32(v)
}
h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
data, err := syntax.Marshal(encryptedExtensionsTransportParameters{
NegotiatedVersion: uint32(h.version),
SupportedVersions: versions,
Parameters: transportParams,
})
if err != nil {
return err
eetp := &encryptedExtensionsTransportParameters{
NegotiatedVersion: h.version,
SupportedVersions: protocol.GetGreasedVersions(h.supportedVersions),
Parameters: *h.ourParams,
}
return el.Add(&tlsExtensionBody{data})
return el.Add(&tlsExtensionBody{data: eetp.Marshal()})
}
func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
@ -90,30 +74,24 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte
return errors.New("ClientHello didn't contain a QUIC extension")
}
chtp := &clientHelloTransportParameters{}
if _, err := syntax.Unmarshal(ext.data, chtp); err != nil {
if err := chtp.Unmarshal(ext.data); err != nil {
return err
}
initialVersion := protocol.VersionNumber(chtp.InitialVersion)
// perform the stateless version negotiation validation:
// make sure that we would have sent a Version Negotiation Packet if the client offered the initial version
// this is the case if and only if the initial version is not contained in the supported versions
if initialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, initialVersion) {
if chtp.InitialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, chtp.InitialVersion) {
return qerr.Error(qerr.VersionNegotiationMismatch, "Client should have used the initial version")
}
for _, p := range chtp.Parameters {
if p.Parameter == statelessResetTokenParameterID {
// TODO: return the correct error type
return errors.New("client sent a stateless reset token")
}
// check that the client didn't send a stateless reset token
if len(chtp.Parameters.StatelessResetToken) != 0 {
// TODO: return the correct error type
return errors.New("client sent a stateless reset token")
}
params, err := readTransportParameters(chtp.Parameters)
if err != nil {
return err
}
h.logger.Debugf("Received Transport Parameters: %s", params)
h.paramsChan <- *params
h.logger.Debugf("Received Transport Parameters: %s", &chtp.Parameters)
h.paramsChan <- chtp.Parameters
return nil
}

View file

@ -26,8 +26,10 @@ type TransportParameters struct {
MaxBidiStreams uint16 // only used for IETF QUIC
MaxStreams uint32 // only used for gQUIC
OmitConnectionID bool
IdleTimeout time.Duration
OmitConnectionID bool // only used for gQUIC
IdleTimeout time.Duration
DisableMigration bool // only used for IETF QUIC
StatelessResetToken []byte // only used for IETF QUIC
}
// readHelloMap reads the transport parameters from the tags sent in a gQUIC handshake message
@ -94,98 +96,114 @@ func (p *TransportParameters) getHelloMap() map[Tag][]byte {
return tags
}
// readTransportParameters reads the transport parameters sent in the QUIC TLS extension
func readTransportParameters(paramsList []transportParameter) (*TransportParameters, error) {
params := &TransportParameters{}
var foundInitialMaxStreamData bool
var foundInitialMaxData bool
func (p *TransportParameters) unmarshal(data []byte) error {
var foundIdleTimeout bool
for _, p := range paramsList {
switch p.Parameter {
for len(data) >= 4 {
paramID := binary.BigEndian.Uint16(data[:2])
paramLen := int(binary.BigEndian.Uint16(data[2:4]))
data = data[4:]
if len(data) < paramLen {
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(data), paramLen)
}
switch transportParameterID(paramID) {
case initialMaxStreamDataParameterID:
foundInitialMaxStreamData = true
if len(p.Value) != 4 {
return nil, fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", len(p.Value))
if paramLen != 4 {
return fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", paramLen)
}
params.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value))
p.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4]))
case initialMaxDataParameterID:
foundInitialMaxData = true
if len(p.Value) != 4 {
return nil, fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", len(p.Value))
if paramLen != 4 {
return fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", paramLen)
}
params.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value))
case initialMaxStreamsBiDiParameterID:
if len(p.Value) != 2 {
return nil, fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 2)", len(p.Value))
p.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4]))
case initialMaxBidiStreamsParameterID:
if paramLen != 2 {
return fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 2)", paramLen)
}
params.MaxBidiStreams = binary.BigEndian.Uint16(p.Value)
case initialMaxStreamsUniParameterID:
if len(p.Value) != 2 {
return nil, fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 2)", len(p.Value))
p.MaxBidiStreams = binary.BigEndian.Uint16(data[:2])
case initialMaxUniStreamsParameterID:
if paramLen != 2 {
return fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 2)", paramLen)
}
params.MaxUniStreams = binary.BigEndian.Uint16(p.Value)
p.MaxUniStreams = binary.BigEndian.Uint16(data[:2])
case idleTimeoutParameterID:
foundIdleTimeout = true
if len(p.Value) != 2 {
return nil, fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", len(p.Value))
if paramLen != 2 {
return fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", paramLen)
}
params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(p.Value))*time.Second)
case omitConnectionIDParameterID:
if len(p.Value) != 0 {
return nil, fmt.Errorf("wrong length for omit_connection_id: %d (expected empty)", len(p.Value))
}
params.OmitConnectionID = true
p.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(data[:2]))*time.Second)
case maxPacketSizeParameterID:
if len(p.Value) != 2 {
return nil, fmt.Errorf("wrong length for max_packet_size: %d (expected 2)", len(p.Value))
if paramLen != 2 {
return fmt.Errorf("wrong length for max_packet_size: %d (expected 2)", paramLen)
}
maxPacketSize := protocol.ByteCount(binary.BigEndian.Uint16(p.Value))
maxPacketSize := protocol.ByteCount(binary.BigEndian.Uint16(data[:2]))
if maxPacketSize < 1200 {
return nil, fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", maxPacketSize)
return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", maxPacketSize)
}
params.MaxPacketSize = maxPacketSize
p.MaxPacketSize = maxPacketSize
case disableMigrationParameterID:
if paramLen != 0 {
return fmt.Errorf("wrong length for disable_migration: %d (expected empty)", paramLen)
}
p.DisableMigration = true
case statelessResetTokenParameterID:
if paramLen != 16 {
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
}
p.StatelessResetToken = data[:16]
}
data = data[paramLen:]
}
if !(foundInitialMaxStreamData && foundInitialMaxData && foundIdleTimeout) {
return nil, errors.New("missing parameter")
if len(data) != 0 {
return fmt.Errorf("should have read all data. Still have %d bytes", len(data))
}
return params, nil
if !foundIdleTimeout {
return errors.New("missing parameter")
}
return nil
}
// GetTransportParameters gets the parameters needed for the TLS handshake.
// It doesn't send the initial_max_stream_id_uni parameter, so the peer isn't allowed to open any unidirectional streams.
func (p *TransportParameters) getTransportParameters() []transportParameter {
initialMaxStreamData := make([]byte, 4)
binary.BigEndian.PutUint32(initialMaxStreamData, uint32(p.StreamFlowControlWindow))
initialMaxData := make([]byte, 4)
binary.BigEndian.PutUint32(initialMaxData, uint32(p.ConnectionFlowControlWindow))
initialMaxBidiStreamID := make([]byte, 2)
binary.BigEndian.PutUint16(initialMaxBidiStreamID, p.MaxBidiStreams)
initialMaxUniStreamID := make([]byte, 2)
binary.BigEndian.PutUint16(initialMaxUniStreamID, p.MaxUniStreams)
idleTimeout := make([]byte, 2)
binary.BigEndian.PutUint16(idleTimeout, uint16(p.IdleTimeout/time.Second))
maxPacketSize := make([]byte, 2)
binary.BigEndian.PutUint16(maxPacketSize, uint16(protocol.MaxReceivePacketSize))
params := []transportParameter{
{initialMaxStreamDataParameterID, initialMaxStreamData},
{initialMaxDataParameterID, initialMaxData},
{initialMaxStreamsBiDiParameterID, initialMaxBidiStreamID},
{initialMaxStreamsUniParameterID, initialMaxUniStreamID},
{idleTimeoutParameterID, idleTimeout},
{maxPacketSizeParameterID, maxPacketSize},
func (p *TransportParameters) marshal(b *bytes.Buffer) {
// initial_max_stream_data
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataParameterID))
utils.BigEndian.WriteUint16(b, 4)
utils.BigEndian.WriteUint32(b, uint32(p.StreamFlowControlWindow))
// initial_max_data
utils.BigEndian.WriteUint16(b, uint16(initialMaxDataParameterID))
utils.BigEndian.WriteUint16(b, 4)
utils.BigEndian.WriteUint32(b, uint32(p.ConnectionFlowControlWindow))
// initial_max_bidi_streams
utils.BigEndian.WriteUint16(b, uint16(initialMaxBidiStreamsParameterID))
utils.BigEndian.WriteUint16(b, 2)
utils.BigEndian.WriteUint16(b, p.MaxBidiStreams)
// initial_max_uni_streams
utils.BigEndian.WriteUint16(b, uint16(initialMaxUniStreamsParameterID))
utils.BigEndian.WriteUint16(b, 2)
utils.BigEndian.WriteUint16(b, p.MaxUniStreams)
// idle_timeout
utils.BigEndian.WriteUint16(b, uint16(idleTimeoutParameterID))
utils.BigEndian.WriteUint16(b, 2)
utils.BigEndian.WriteUint16(b, uint16(p.IdleTimeout/time.Second))
// max_packet_size
utils.BigEndian.WriteUint16(b, uint16(maxPacketSizeParameterID))
utils.BigEndian.WriteUint16(b, 2)
utils.BigEndian.WriteUint16(b, uint16(protocol.MaxReceivePacketSize))
// disable_migration
if p.DisableMigration {
utils.BigEndian.WriteUint16(b, uint16(disableMigrationParameterID))
utils.BigEndian.WriteUint16(b, 0)
}
if p.OmitConnectionID {
params = append(params, transportParameter{omitConnectionIDParameterID, []byte{}})
if len(p.StatelessResetToken) > 0 {
utils.BigEndian.WriteUint16(b, uint16(statelessResetTokenParameterID))
utils.BigEndian.WriteUint16(b, uint16(len(p.StatelessResetToken))) // should always be 16 bytes
b.Write(p.StatelessResetToken)
}
return params
}
// String returns a string representation, intended for logging.
// It should only used for IETF QUIC.
func (p *TransportParameters) String() string {
return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, OmitConnectionID: %t, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.OmitConnectionID, p.IdleTimeout)
return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout)
}

View file

@ -49,6 +49,19 @@ func (mr *MockSentPacketHandlerMockRecorder) DequeuePacketForRetransmission() *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DequeuePacketForRetransmission", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeuePacketForRetransmission))
}
// DequeueProbePacket mocks base method
func (m *MockSentPacketHandler) DequeueProbePacket() (*ackhandler.Packet, error) {
ret := m.ctrl.Call(m, "DequeueProbePacket")
ret0, _ := ret[0].(*ackhandler.Packet)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DequeueProbePacket indicates an expected call of DequeueProbePacket
func (mr *MockSentPacketHandlerMockRecorder) DequeueProbePacket() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DequeueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeueProbePacket))
}
// GetAlarmTimeout mocks base method
func (m *MockSentPacketHandler) GetAlarmTimeout() time.Time {
ret := m.ctrl.Call(m, "GetAlarmTimeout")

View file

@ -68,13 +68,13 @@ func (mr *MockSendAlgorithmMockRecorder) OnConnectionMigration() *gomock.Call {
}
// OnPacketAcked mocks base method
func (m *MockSendAlgorithm) OnPacketAcked(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount) {
m.ctrl.Call(m, "OnPacketAcked", arg0, arg1, arg2)
func (m *MockSendAlgorithm) OnPacketAcked(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount, arg3 time.Time) {
m.ctrl.Call(m, "OnPacketAcked", arg0, arg1, arg2, arg3)
}
// OnPacketAcked indicates an expected call of OnPacketAcked
func (mr *MockSendAlgorithmMockRecorder) OnPacketAcked(arg0, arg1, arg2 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithm)(nil).OnPacketAcked), arg0, arg1, arg2)
func (mr *MockSendAlgorithmMockRecorder) OnPacketAcked(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithm)(nil).OnPacketAcked), arg0, arg1, arg2, arg3)
}
// OnPacketLost mocks base method
@ -88,10 +88,8 @@ func (mr *MockSendAlgorithmMockRecorder) OnPacketLost(arg0, arg1, arg2 interface
}
// OnPacketSent mocks base method
func (m *MockSendAlgorithm) OnPacketSent(arg0 time.Time, arg1 protocol.ByteCount, arg2 protocol.PacketNumber, arg3 protocol.ByteCount, arg4 bool) bool {
ret := m.ctrl.Call(m, "OnPacketSent", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(bool)
return ret0
func (m *MockSendAlgorithm) OnPacketSent(arg0 time.Time, arg1 protocol.ByteCount, arg2 protocol.PacketNumber, arg3 protocol.ByteCount, arg4 bool) {
m.ctrl.Call(m, "OnPacketSent", arg0, arg1, arg2, arg3, arg4)
}
// OnPacketSent indicates an expected call of OnPacketSent

View file

@ -79,6 +79,16 @@ func (mr *MockConnectionFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockConnectionFlowController)(nil).IsNewlyBlocked))
}
// MaybeQueueWindowUpdate mocks base method
func (m *MockConnectionFlowController) MaybeQueueWindowUpdate() {
m.ctrl.Call(m, "MaybeQueueWindowUpdate")
}
// MaybeQueueWindowUpdate indicates an expected call of MaybeQueueWindowUpdate
func (mr *MockConnectionFlowControllerMockRecorder) MaybeQueueWindowUpdate() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeQueueWindowUpdate", reflect.TypeOf((*MockConnectionFlowController)(nil).MaybeQueueWindowUpdate))
}
// SendWindowSize mocks base method
func (m *MockConnectionFlowController) SendWindowSize() protocol.ByteCount {
ret := m.ctrl.Call(m, "SendWindowSize")

View file

@ -1,11 +0,0 @@
package mocks
//go:generate sh -c "./mockgen_internal.sh mockhandshake handshake/mint_tls.go github.com/lucas-clemente/quic-go/internal/handshake MintTLS"
//go:generate sh -c "./mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler"
//go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
//go:generate sh -c "./mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler"
//go:generate sh -c "./mockgen_internal.sh mockackhandler ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler"
//go:generate sh -c "./mockgen_internal.sh mocks congestion.go github.com/lucas-clemente/quic-go/internal/congestion SendAlgorithm"
//go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController"
//go:generate sh -c "./mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD"
//go:generate sh -c "goimports -w ."

View file

@ -1,107 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: MintTLS)
// Package mockhandshake is a generated GoMock package.
package mockhandshake
import (
io "io"
reflect "reflect"
mint "github.com/bifurcation/mint"
gomock "github.com/golang/mock/gomock"
)
// MockMintTLS is a mock of MintTLS interface
type MockMintTLS struct {
ctrl *gomock.Controller
recorder *MockMintTLSMockRecorder
}
// MockMintTLSMockRecorder is the mock recorder for MockMintTLS
type MockMintTLSMockRecorder struct {
mock *MockMintTLS
}
// NewMockMintTLS creates a new mock instance
func NewMockMintTLS(ctrl *gomock.Controller) *MockMintTLS {
mock := &MockMintTLS{ctrl: ctrl}
mock.recorder = &MockMintTLSMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockMintTLS) EXPECT() *MockMintTLSMockRecorder {
return m.recorder
}
// ComputeExporter mocks base method
func (m *MockMintTLS) ComputeExporter(arg0 string, arg1 []byte, arg2 int) ([]byte, error) {
ret := m.ctrl.Call(m, "ComputeExporter", arg0, arg1, arg2)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ComputeExporter indicates an expected call of ComputeExporter
func (mr *MockMintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ComputeExporter", reflect.TypeOf((*MockMintTLS)(nil).ComputeExporter), arg0, arg1, arg2)
}
// ConnectionState mocks base method
func (m *MockMintTLS) ConnectionState() mint.ConnectionState {
ret := m.ctrl.Call(m, "ConnectionState")
ret0, _ := ret[0].(mint.ConnectionState)
return ret0
}
// ConnectionState indicates an expected call of ConnectionState
func (mr *MockMintTLSMockRecorder) ConnectionState() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockMintTLS)(nil).ConnectionState))
}
// GetCipherSuite mocks base method
func (m *MockMintTLS) GetCipherSuite() mint.CipherSuiteParams {
ret := m.ctrl.Call(m, "GetCipherSuite")
ret0, _ := ret[0].(mint.CipherSuiteParams)
return ret0
}
// GetCipherSuite indicates an expected call of GetCipherSuite
func (mr *MockMintTLSMockRecorder) GetCipherSuite() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCipherSuite", reflect.TypeOf((*MockMintTLS)(nil).GetCipherSuite))
}
// Handshake mocks base method
func (m *MockMintTLS) Handshake() mint.Alert {
ret := m.ctrl.Call(m, "Handshake")
ret0, _ := ret[0].(mint.Alert)
return ret0
}
// Handshake indicates an expected call of Handshake
func (mr *MockMintTLSMockRecorder) Handshake() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Handshake", reflect.TypeOf((*MockMintTLS)(nil).Handshake))
}
// SetCryptoStream mocks base method
func (m *MockMintTLS) SetCryptoStream(arg0 io.ReadWriter) {
m.ctrl.Call(m, "SetCryptoStream", arg0)
}
// SetCryptoStream indicates an expected call of SetCryptoStream
func (mr *MockMintTLSMockRecorder) SetCryptoStream(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCryptoStream", reflect.TypeOf((*MockMintTLS)(nil).SetCryptoStream), arg0)
}
// State mocks base method
func (m *MockMintTLS) State() mint.State {
ret := m.ctrl.Call(m, "State")
ret0, _ := ret[0].(mint.State)
return ret0
}
// State indicates an expected call of State
func (mr *MockMintTLSMockRecorder) State() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "State", reflect.TypeOf((*MockMintTLS)(nil).State))
}

View file

@ -0,0 +1,9 @@
package mocks
//go:generate sh -c "../mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler"
//go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler"
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler"
//go:generate sh -c "../mockgen_internal.sh mocks congestion.go github.com/lucas-clemente/quic-go/internal/congestion SendAlgorithm"
//go:generate sh -c "../mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController"
//go:generate sh -c "../mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD"

View file

@ -66,29 +66,27 @@ func (mr *MockStreamFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).GetWindowUpdate))
}
// HasWindowUpdate mocks base method
func (m *MockStreamFlowController) HasWindowUpdate() bool {
ret := m.ctrl.Call(m, "HasWindowUpdate")
ret0, _ := ret[0].(bool)
return ret0
}
// HasWindowUpdate indicates an expected call of HasWindowUpdate
func (mr *MockStreamFlowControllerMockRecorder) HasWindowUpdate() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).HasWindowUpdate))
}
// IsBlocked mocks base method
func (m *MockStreamFlowController) IsBlocked() (bool, protocol.ByteCount) {
ret := m.ctrl.Call(m, "IsBlocked")
// IsNewlyBlocked mocks base method
func (m *MockStreamFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
ret := m.ctrl.Call(m, "IsNewlyBlocked")
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(protocol.ByteCount)
return ret0, ret1
}
// IsBlocked indicates an expected call of IsBlocked
func (mr *MockStreamFlowControllerMockRecorder) IsBlocked() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsBlocked))
// IsNewlyBlocked indicates an expected call of IsNewlyBlocked
func (mr *MockStreamFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsNewlyBlocked))
}
// MaybeQueueWindowUpdate mocks base method
func (m *MockStreamFlowController) MaybeQueueWindowUpdate() {
m.ctrl.Call(m, "MaybeQueueWindowUpdate")
}
// MaybeQueueWindowUpdate indicates an expected call of MaybeQueueWindowUpdate
func (mr *MockStreamFlowControllerMockRecorder) MaybeQueueWindowUpdate() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeQueueWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).MaybeQueueWindowUpdate))
}
// SendWindowSize mocks base method

Some files were not shown because too many files have changed in this diff Show more