diff --git a/vendor/github.com/aead/chacha20/chacha/chacha.go b/vendor/github.com/aead/chacha20/chacha/chacha.go index 8c387a97..c2b39da4 100644 --- a/vendor/github.com/aead/chacha20/chacha/chacha.go +++ b/vendor/github.com/aead/chacha20/chacha/chacha.go @@ -9,6 +9,7 @@ package chacha // import "github.com/aead/chacha20/chacha" import ( "encoding/binary" "errors" + "math" ) const ( @@ -28,6 +29,7 @@ const ( var ( useSSE2 bool useSSSE3 bool + useAVX bool useAVX2 bool ) @@ -55,7 +57,7 @@ func setup(state *[64]byte, nonce, key []byte) (err error) { copy(hNonce[:], nonce[:16]) copy(tmpKey[:], key) - hChaCha20(&tmpKey, &hNonce, &tmpKey) + HChaCha20(&tmpKey, &hNonce, &tmpKey) copy(Nonce[8:], nonce[16:]) initialize(state, tmpKey[:], &Nonce) @@ -161,6 +163,21 @@ func (c *Cipher) XORKeyStream(dst, src []byte) { c.off = 0 } + // check for counter overflow + blocksToXOR := len(src) / 64 + if len(src)%64 != 0 { + blocksToXOR++ + } + var overflow bool + if c.noncesize == INonceSize { + overflow = binary.LittleEndian.Uint32(c.state[48:]) > math.MaxUint32-uint32(blocksToXOR) + } else { + overflow = binary.LittleEndian.Uint64(c.state[48:]) > math.MaxUint64-uint64(blocksToXOR) + } + if overflow { + panic("chacha20/chacha: counter overflow") + } + c.off += xorKeyStream(dst, src, &(c.block), &(c.state), c.rounds) } @@ -174,3 +191,7 @@ func (c *Cipher) SetCounter(ctr uint64) { } c.off = 0 } + +// HChaCha20 generates 32 pseudo-random bytes from a 128 bit nonce and a 256 bit secret key. +// It can be used as a key-derivation-function (KDF). +func HChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { hChaCha20(out, nonce, key) } diff --git a/vendor/github.com/aead/chacha20/chacha/chachaAVX2_amd64.s b/vendor/github.com/aead/chacha20/chacha/chachaAVX2_amd64.s index 8d022332..c2b5f52f 100644 --- a/vendor/github.com/aead/chacha20/chacha/chachaAVX2_amd64.s +++ b/vendor/github.com/aead/chacha20/chacha/chachaAVX2_amd64.s @@ -2,111 +2,10 @@ // Use of this source code is governed by a license that can be // found in the LICENSE file. -// +build go1.7,amd64,!gccgo,!appengine,!nacl +// +build amd64,!gccgo,!appengine,!nacl -#include "textflag.h" - -DATA ·sigma_AVX<>+0x00(SB)/4, $0x61707865 -DATA ·sigma_AVX<>+0x04(SB)/4, $0x3320646e -DATA ·sigma_AVX<>+0x08(SB)/4, $0x79622d32 -DATA ·sigma_AVX<>+0x0C(SB)/4, $0x6b206574 -GLOBL ·sigma_AVX<>(SB), (NOPTR+RODATA), $16 - -DATA ·one_AVX<>+0x00(SB)/8, $1 -DATA ·one_AVX<>+0x08(SB)/8, $0 -GLOBL ·one_AVX<>(SB), (NOPTR+RODATA), $16 - -DATA ·one_AVX2<>+0x00(SB)/8, $0 -DATA ·one_AVX2<>+0x08(SB)/8, $0 -DATA ·one_AVX2<>+0x10(SB)/8, $1 -DATA ·one_AVX2<>+0x18(SB)/8, $0 -GLOBL ·one_AVX2<>(SB), (NOPTR+RODATA), $32 - -DATA ·two_AVX2<>+0x00(SB)/8, $2 -DATA ·two_AVX2<>+0x08(SB)/8, $0 -DATA ·two_AVX2<>+0x10(SB)/8, $2 -DATA ·two_AVX2<>+0x18(SB)/8, $0 -GLOBL ·two_AVX2<>(SB), (NOPTR+RODATA), $32 - -DATA ·rol16_AVX2<>+0x00(SB)/8, $0x0504070601000302 -DATA ·rol16_AVX2<>+0x08(SB)/8, $0x0D0C0F0E09080B0A -DATA ·rol16_AVX2<>+0x10(SB)/8, $0x0504070601000302 -DATA ·rol16_AVX2<>+0x18(SB)/8, $0x0D0C0F0E09080B0A -GLOBL ·rol16_AVX2<>(SB), (NOPTR+RODATA), $32 - -DATA ·rol8_AVX2<>+0x00(SB)/8, $0x0605040702010003 -DATA ·rol8_AVX2<>+0x08(SB)/8, $0x0E0D0C0F0A09080B -DATA ·rol8_AVX2<>+0x10(SB)/8, $0x0605040702010003 -DATA ·rol8_AVX2<>+0x18(SB)/8, $0x0E0D0C0F0A09080B -GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32 - -#define ROTL(n, t, v) \ - VPSLLD $n, v, t; \ - VPSRLD $(32-n), v, v; \ - VPXOR v, t, v - -#define CHACHA_QROUND(v0, v1, v2, v3, t, c16, c8) \ - VPADDD v0, v1, v0; \ - VPXOR v3, v0, v3; \ - VPSHUFB c16, v3, v3; \ - VPADDD v2, v3, v2; \ - VPXOR v1, v2, v1; \ - ROTL(12, t, v1); \ - VPADDD v0, v1, v0; \ - VPXOR v3, v0, v3; \ - VPSHUFB c8, v3, v3; \ - VPADDD v2, v3, v2; \ - VPXOR v1, v2, v1; \ - ROTL(7, t, v1) - -#define CHACHA_SHUFFLE(v1, v2, v3) \ - VPSHUFD $0x39, v1, v1; \ - VPSHUFD $0x4E, v2, v2; \ - VPSHUFD $-109, v3, v3 - -#define XOR_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \ - VMOVDQU (0+off)(src), t0; \ - VPERM2I128 $32, v1, v0, t1; \ - VPXOR t0, t1, t0; \ - VMOVDQU t0, (0+off)(dst); \ - VMOVDQU (32+off)(src), t0; \ - VPERM2I128 $32, v3, v2, t1; \ - VPXOR t0, t1, t0; \ - VMOVDQU t0, (32+off)(dst); \ - VMOVDQU (64+off)(src), t0; \ - VPERM2I128 $49, v1, v0, t1; \ - VPXOR t0, t1, t0; \ - VMOVDQU t0, (64+off)(dst); \ - VMOVDQU (96+off)(src), t0; \ - VPERM2I128 $49, v3, v2, t1; \ - VPXOR t0, t1, t0; \ - VMOVDQU t0, (96+off)(dst) - -#define XOR_UPPER_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \ - VMOVDQU (0+off)(src), t0; \ - VPERM2I128 $32, v1, v0, t1; \ - VPXOR t0, t1, t0; \ - VMOVDQU t0, (0+off)(dst); \ - VMOVDQU (32+off)(src), t0; \ - VPERM2I128 $32, v3, v2, t1; \ - VPXOR t0, t1, t0; \ - VMOVDQU t0, (32+off)(dst); \ - -#define EXTRACT_LOWER(dst, v0, v1, v2, v3, t0) \ - VPERM2I128 $49, v1, v0, t0; \ - VMOVDQU t0, 0(dst); \ - VPERM2I128 $49, v3, v2, t0; \ - VMOVDQU t0, 32(dst) - -#define XOR_AVX(dst, src, off, v0, v1, v2, v3, t0) \ - VPXOR 0+off(src), v0, t0; \ - VMOVDQU t0, 0+off(dst); \ - VPXOR 16+off(src), v1, t0; \ - VMOVDQU t0, 16+off(dst); \ - VPXOR 32+off(src), v2, t0; \ - VMOVDQU t0, 32+off(dst); \ - VPXOR 48+off(src), v3, t0; \ - VMOVDQU t0, 48+off(dst) +#include "const.s" +#include "macro.s" #define TWO 0(SP) #define C16 32(SP) @@ -122,10 +21,10 @@ GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32 TEXT ·xorKeyStreamAVX2(SB), 4, $320-80 MOVQ dst_base+0(FP), DI MOVQ src_base+24(FP), SI - MOVQ src_len+32(FP), CX MOVQ block+48(FP), BX MOVQ state+56(FP), AX MOVQ rounds+64(FP), DX + MOVQ src_len+32(FP), CX MOVQ SP, R8 ADDQ $32, SP @@ -185,28 +84,28 @@ at_least_512: chacha_loop_512: VMOVDQA Y8, TMP_0 - CHACHA_QROUND(Y0, Y1, Y2, Y3, Y8, C16, C8) - CHACHA_QROUND(Y4, Y5, Y6, Y7, Y8, C16, C8) + CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y8, C16, C8) + CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y8, C16, C8) VMOVDQA TMP_0, Y8 VMOVDQA Y0, TMP_0 - CHACHA_QROUND(Y8, Y9, Y10, Y11, Y0, C16, C8) - CHACHA_QROUND(Y12, Y13, Y14, Y15, Y0, C16, C8) - CHACHA_SHUFFLE(Y1, Y2, Y3) - CHACHA_SHUFFLE(Y5, Y6, Y7) - CHACHA_SHUFFLE(Y9, Y10, Y11) - CHACHA_SHUFFLE(Y13, Y14, Y15) + CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y0, C16, C8) + CHACHA_QROUND_AVX(Y12, Y13, Y14, Y15, Y0, C16, C8) + CHACHA_SHUFFLE_AVX(Y1, Y2, Y3) + CHACHA_SHUFFLE_AVX(Y5, Y6, Y7) + CHACHA_SHUFFLE_AVX(Y9, Y10, Y11) + CHACHA_SHUFFLE_AVX(Y13, Y14, Y15) - CHACHA_QROUND(Y12, Y13, Y14, Y15, Y0, C16, C8) - CHACHA_QROUND(Y8, Y9, Y10, Y11, Y0, C16, C8) + CHACHA_QROUND_AVX(Y12, Y13, Y14, Y15, Y0, C16, C8) + CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y0, C16, C8) VMOVDQA TMP_0, Y0 VMOVDQA Y8, TMP_0 - CHACHA_QROUND(Y4, Y5, Y6, Y7, Y8, C16, C8) - CHACHA_QROUND(Y0, Y1, Y2, Y3, Y8, C16, C8) + CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y8, C16, C8) + CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y8, C16, C8) VMOVDQA TMP_0, Y8 - CHACHA_SHUFFLE(Y3, Y2, Y1) - CHACHA_SHUFFLE(Y7, Y6, Y5) - CHACHA_SHUFFLE(Y11, Y10, Y9) - CHACHA_SHUFFLE(Y15, Y14, Y13) + CHACHA_SHUFFLE_AVX(Y3, Y2, Y1) + CHACHA_SHUFFLE_AVX(Y7, Y6, Y5) + CHACHA_SHUFFLE_AVX(Y11, Y10, Y9) + CHACHA_SHUFFLE_AVX(Y15, Y14, Y13) SUBQ $2, R9 JA chacha_loop_512 @@ -289,18 +188,18 @@ between_320_and_448: MOVQ DX, R9 chacha_loop_384: - CHACHA_QROUND(Y0, Y1, Y2, Y3, Y13, Y14, Y15) - CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) - CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) - CHACHA_SHUFFLE(Y1, Y2, Y3) - CHACHA_SHUFFLE(Y5, Y6, Y7) - CHACHA_SHUFFLE(Y9, Y10, Y11) - CHACHA_QROUND(Y0, Y1, Y2, Y3, Y13, Y14, Y15) - CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) - CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) - CHACHA_SHUFFLE(Y3, Y2, Y1) - CHACHA_SHUFFLE(Y7, Y6, Y5) - CHACHA_SHUFFLE(Y11, Y10, Y9) + CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y13, Y14, Y15) + CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15) + CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15) + CHACHA_SHUFFLE_AVX(Y1, Y2, Y3) + CHACHA_SHUFFLE_AVX(Y5, Y6, Y7) + CHACHA_SHUFFLE_AVX(Y9, Y10, Y11) + CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y13, Y14, Y15) + CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15) + CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15) + CHACHA_SHUFFLE_AVX(Y3, Y2, Y1) + CHACHA_SHUFFLE_AVX(Y7, Y6, Y5) + CHACHA_SHUFFLE_AVX(Y11, Y10, Y9) SUBQ $2, R9 JA chacha_loop_384 @@ -361,14 +260,14 @@ between_192_and_320: MOVQ DX, R9 chacha_loop_256: - CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) - CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) - CHACHA_SHUFFLE(Y5, Y6, Y7) - CHACHA_SHUFFLE(Y9, Y10, Y11) - CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) - CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) - CHACHA_SHUFFLE(Y7, Y6, Y5) - CHACHA_SHUFFLE(Y11, Y10, Y9) + CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15) + CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15) + CHACHA_SHUFFLE_AVX(Y5, Y6, Y7) + CHACHA_SHUFFLE_AVX(Y9, Y10, Y11) + CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15) + CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15) + CHACHA_SHUFFLE_AVX(Y7, Y6, Y5) + CHACHA_SHUFFLE_AVX(Y11, Y10, Y9) SUBQ $2, R9 JA chacha_loop_256 @@ -413,10 +312,10 @@ between_64_and_192: MOVQ DX, R9 chacha_loop_128: - CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) - CHACHA_SHUFFLE(Y5, Y6, Y7) - CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) - CHACHA_SHUFFLE(Y7, Y6, Y5) + CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15) + CHACHA_SHUFFLE_AVX(Y5, Y6, Y7) + CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15) + CHACHA_SHUFFLE_AVX(Y7, Y6, Y5) SUBQ $2, R9 JA chacha_loop_128 @@ -455,10 +354,10 @@ between_0_and_64: MOVQ DX, R9 chacha_loop_64: - CHACHA_QROUND(X4, X5, X6, X7, X13, X14, X15) - CHACHA_SHUFFLE(X5, X6, X7) - CHACHA_QROUND(X4, X5, X6, X7, X13, X14, X15) - CHACHA_SHUFFLE(X7, X6, X5) + CHACHA_QROUND_AVX(X4, X5, X6, X7, X13, X14, X15) + CHACHA_SHUFFLE_AVX(X5, X6, X7) + CHACHA_QROUND_AVX(X4, X5, X6, X7, X13, X14, X15) + CHACHA_SHUFFLE_AVX(X7, X6, X5) SUBQ $2, R9 JA chacha_loop_64 @@ -466,7 +365,7 @@ chacha_loop_64: VPADDD X1, X5, X5 VPADDD X2, X6, X6 VPADDD X3, X7, X7 - VMOVDQU ·one_AVX<>(SB), X0 + VMOVDQU ·one<>(SB), X0 VPADDQ X0, X3, X3 CMPQ CX, $64 @@ -505,38 +404,3 @@ done: MOVQ CX, ret+72(FP) RET -// func hChaCha20AVX(out *[32]byte, nonce *[16]byte, key *[32]byte) -TEXT ·hChaCha20AVX(SB), 4, $0-24 - MOVQ out+0(FP), DI - MOVQ nonce+8(FP), AX - MOVQ key+16(FP), BX - - VMOVDQU ·sigma_AVX<>(SB), X0 - VMOVDQU 0(BX), X1 - VMOVDQU 16(BX), X2 - VMOVDQU 0(AX), X3 - VMOVDQU ·rol16_AVX2<>(SB), X5 - VMOVDQU ·rol8_AVX2<>(SB), X6 - - MOVQ $20, CX - -chacha_loop: - CHACHA_QROUND(X0, X1, X2, X3, X4, X5, X6) - CHACHA_SHUFFLE(X1, X2, X3) - CHACHA_QROUND(X0, X1, X2, X3, X4, X5, X6) - CHACHA_SHUFFLE(X3, X2, X1) - SUBQ $2, CX - JNZ chacha_loop - - VMOVDQU X0, 0(DI) - VMOVDQU X3, 16(DI) - VZEROUPPER - RET - -// func supportsAVX2() bool -TEXT ·supportsAVX2(SB), 4, $0-1 - MOVQ runtime·support_avx(SB), AX - MOVQ runtime·support_avx2(SB), BX - ANDQ AX, BX - MOVB BX, ret+0(FP) - RET diff --git a/vendor/github.com/aead/chacha20/chacha/chacha_386.go b/vendor/github.com/aead/chacha20/chacha/chacha_386.go index e3135efb..97e533d3 100644 --- a/vendor/github.com/aead/chacha20/chacha/chacha_386.go +++ b/vendor/github.com/aead/chacha20/chacha/chacha_386.go @@ -6,11 +6,16 @@ package chacha -import "encoding/binary" +import ( + "encoding/binary" + + "golang.org/x/sys/cpu" +) func init() { - useSSE2 = supportsSSE2() - useSSSE3 = supportsSSSE3() + useSSE2 = cpu.X86.HasSSE2 + useSSSE3 = cpu.X86.HasSSSE3 + useAVX = false useAVX2 = false } @@ -23,14 +28,6 @@ func initialize(state *[64]byte, key []byte, nonce *[16]byte) { copy(state[48:], nonce[:]) } -// This function is implemented in chacha_386.s -//go:noescape -func supportsSSE2() bool - -// This function is implemented in chacha_386.s -//go:noescape -func supportsSSSE3() bool - // This function is implemented in chacha_386.s //go:noescape func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) @@ -43,25 +40,21 @@ func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte) //go:noescape func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int -// This function is implemented in chacha_386.s -//go:noescape -func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int - func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { - if useSSSE3 { + switch { + case useSSSE3: hChaCha20SSSE3(out, nonce, key) - } else if useSSE2 { + case useSSE2: hChaCha20SSE2(out, nonce, key) - } else { + default: hChaCha20Generic(out, nonce, key) } } func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int { - if useSSSE3 { - return xorKeyStreamSSSE3(dst, src, block, state, rounds) - } else if useSSE2 { + if useSSE2 { return xorKeyStreamSSE2(dst, src, block, state, rounds) + } else { + return xorKeyStreamGeneric(dst, src, block, state, rounds) } - return xorKeyStreamGeneric(dst, src, block, state, rounds) } diff --git a/vendor/github.com/aead/chacha20/chacha/chacha_386.s b/vendor/github.com/aead/chacha20/chacha/chacha_386.s index d7bba759..262fc869 100644 --- a/vendor/github.com/aead/chacha20/chacha/chacha_386.s +++ b/vendor/github.com/aead/chacha20/chacha/chacha_386.s @@ -4,126 +4,125 @@ // +build 386,!gccgo,!appengine,!nacl -#include "textflag.h" - -DATA ·sigma<>+0x00(SB)/4, $0x61707865 -DATA ·sigma<>+0x04(SB)/4, $0x3320646e -DATA ·sigma<>+0x08(SB)/4, $0x79622d32 -DATA ·sigma<>+0x0C(SB)/4, $0x6b206574 -GLOBL ·sigma<>(SB), (NOPTR+RODATA), $16 - -DATA ·one<>+0x00(SB)/8, $1 -DATA ·one<>+0x08(SB)/8, $0 -GLOBL ·one<>(SB), (NOPTR+RODATA), $16 - -DATA ·rol16<>+0x00(SB)/8, $0x0504070601000302 -DATA ·rol16<>+0x08(SB)/8, $0x0D0C0F0E09080B0A -GLOBL ·rol16<>(SB), (NOPTR+RODATA), $16 - -DATA ·rol8<>+0x00(SB)/8, $0x0605040702010003 -DATA ·rol8<>+0x08(SB)/8, $0x0E0D0C0F0A09080B -GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16 - -#define ROTL_SSE2(n, t, v) \ - MOVO v, t; \ - PSLLL $n, t; \ - PSRLL $(32-n), v; \ - PXOR t, v - -#define CHACHA_QROUND_SSE2(v0, v1, v2, v3, t0) \ - PADDL v1, v0; \ - PXOR v0, v3; \ - ROTL_SSE2(16, t0, v3); \ - PADDL v3, v2; \ - PXOR v2, v1; \ - ROTL_SSE2(12, t0, v1); \ - PADDL v1, v0; \ - PXOR v0, v3; \ - ROTL_SSE2(8, t0, v3); \ - PADDL v3, v2; \ - PXOR v2, v1; \ - ROTL_SSE2(7, t0, v1) - -#define CHACHA_QROUND_SSSE3(v0, v1, v2, v3, t0, r16, r8) \ - PADDL v1, v0; \ - PXOR v0, v3; \ - PSHUFB r16, v3; \ - PADDL v3, v2; \ - PXOR v2, v1; \ - ROTL_SSE2(12, t0, v1); \ - PADDL v1, v0; \ - PXOR v0, v3; \ - PSHUFB r8, v3; \ - PADDL v3, v2; \ - PXOR v2, v1; \ - ROTL_SSE2(7, t0, v1) - -#define CHACHA_SHUFFLE(v1, v2, v3) \ - PSHUFL $0x39, v1, v1; \ - PSHUFL $0x4E, v2, v2; \ - PSHUFL $0x93, v3, v3 - -#define XOR(dst, src, off, v0, v1, v2, v3, t0) \ - MOVOU 0+off(src), t0; \ - PXOR v0, t0; \ - MOVOU t0, 0+off(dst); \ - MOVOU 16+off(src), t0; \ - PXOR v1, t0; \ - MOVOU t0, 16+off(dst); \ - MOVOU 32+off(src), t0; \ - PXOR v2, t0; \ - MOVOU t0, 32+off(dst); \ - MOVOU 48+off(src), t0; \ - PXOR v3, t0; \ - MOVOU t0, 48+off(dst) +#include "const.s" +#include "macro.s" +// FINALIZE xors len bytes from src and block using +// the temp. registers t0 and t1 and writes the result +// to dst. #define FINALIZE(dst, src, block, len, t0, t1) \ - XORL t0, t0; \ - XORL t1, t1; \ - finalize: \ - MOVB 0(src), t0; \ - MOVB 0(block), t1; \ - XORL t0, t1; \ - MOVB t1, 0(dst); \ - INCL src; \ - INCL block; \ - INCL dst; \ - DECL len; \ - JA finalize \ + XORL t0, t0; \ + XORL t1, t1; \ + FINALIZE_LOOP:; \ + MOVB 0(src), t0; \ + MOVB 0(block), t1; \ + XORL t0, t1; \ + MOVB t1, 0(dst); \ + INCL src; \ + INCL block; \ + INCL dst; \ + DECL len; \ + JG FINALIZE_LOOP \ + +#define Dst DI +#define Nonce AX +#define Key BX +#define Rounds DX + +// func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) +TEXT ·hChaCha20SSE2(SB), 4, $0-12 + MOVL out+0(FP), Dst + MOVL nonce+4(FP), Nonce + MOVL key+8(FP), Key + + MOVOU ·sigma<>(SB), X0 + MOVOU 0*16(Key), X1 + MOVOU 1*16(Key), X2 + MOVOU 0*16(Nonce), X3 + MOVL $20, Rounds + +chacha_loop: + CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) + CHACHA_SHUFFLE_SSE(X1, X2, X3) + CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) + CHACHA_SHUFFLE_SSE(X3, X2, X1) + SUBL $2, Rounds + JNZ chacha_loop + + MOVOU X0, 0*16(Dst) + MOVOU X3, 1*16(Dst) + RET + +// func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte) +TEXT ·hChaCha20SSSE3(SB), 4, $0-12 + MOVL out+0(FP), Dst + MOVL nonce+4(FP), Nonce + MOVL key+8(FP), Key + + MOVOU ·sigma<>(SB), X0 + MOVOU 0*16(Key), X1 + MOVOU 1*16(Key), X2 + MOVOU 0*16(Nonce), X3 + MOVL $20, Rounds + + MOVOU ·rol16<>(SB), X5 + MOVOU ·rol8<>(SB), X6 + +chacha_loop: + CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) + CHACHA_SHUFFLE_SSE(X1, X2, X3) + CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) + CHACHA_SHUFFLE_SSE(X3, X2, X1) + SUBL $2, Rounds + JNZ chacha_loop + + MOVOU X0, 0*16(Dst) + MOVOU X3, 1*16(Dst) + RET + +#undef Dst +#undef Nonce +#undef Key +#undef Rounds + +#define State AX +#define Dst DI +#define Src SI +#define Len DX +#define Tmp0 BX +#define Tmp1 BP // func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int TEXT ·xorKeyStreamSSE2(SB), 4, $0-40 - MOVL dst_base+0(FP), DI - MOVL src_base+12(FP), SI - MOVL src_len+16(FP), CX - MOVL state+28(FP), AX - MOVL rounds+32(FP), DX + MOVL dst_base+0(FP), Dst + MOVL src_base+12(FP), Src + MOVL state+28(FP), State + MOVL src_len+16(FP), Len + MOVL $0, ret+36(FP) // Number of bytes written to the keystream buffer - 0 iff len mod 64 == 0 - MOVOU 0(AX), X0 - MOVOU 16(AX), X1 - MOVOU 32(AX), X2 - MOVOU 48(AX), X3 + MOVOU 0*16(State), X0 + MOVOU 1*16(State), X1 + MOVOU 2*16(State), X2 + MOVOU 3*16(State), X3 + TESTL Len, Len + JZ DONE - TESTL CX, CX - JZ done - -at_least_64: +GENERATE_KEYSTREAM: MOVO X0, X4 MOVO X1, X5 MOVO X2, X6 MOVO X3, X7 + MOVL rounds+32(FP), Tmp0 - MOVL DX, BX - -chacha_loop: +CHACHA_LOOP: CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0) - CHACHA_SHUFFLE(X5, X6, X7) + CHACHA_SHUFFLE_SSE(X5, X6, X7) CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0) - CHACHA_SHUFFLE(X7, X6, X5) - SUBL $2, BX - JA chacha_loop + CHACHA_SHUFFLE_SSE(X7, X6, X5) + SUBL $2, Tmp0 + JA CHACHA_LOOP - MOVOU 0(AX), X0 + MOVOU 0*16(State), X0 // Restore X0 from state PADDL X0, X4 PADDL X1, X5 PADDL X2, X6 @@ -131,181 +130,34 @@ chacha_loop: MOVOU ·one<>(SB), X0 PADDQ X0, X3 - CMPL CX, $64 - JB less_than_64 + CMPL Len, $64 + JL BUFFER_KEYSTREAM - XOR(DI, SI, 0, X4, X5, X6, X7, X0) - MOVOU 0(AX), X0 - ADDL $64, SI - ADDL $64, DI - SUBL $64, CX - JNZ at_least_64 + XOR_SSE(Dst, Src, 0, X4, X5, X6, X7, X0) + MOVOU 0*16(State), X0 // Restore X0 from state + ADDL $64, Src + ADDL $64, Dst + SUBL $64, Len + JZ DONE + JMP GENERATE_KEYSTREAM // There is at least one more plaintext byte -less_than_64: - MOVL CX, BP - TESTL BP, BP - JZ done +BUFFER_KEYSTREAM: + MOVL block+24(FP), State + MOVOU X4, 0(State) + MOVOU X5, 16(State) + MOVOU X6, 32(State) + MOVOU X7, 48(State) + MOVL Len, ret+36(FP) // Number of bytes written to the keystream buffer - 0 < Len < 64 + FINALIZE(Dst, Src, State, Len, Tmp0, Tmp1) - MOVL block+24(FP), BX - MOVOU X4, 0(BX) - MOVOU X5, 16(BX) - MOVOU X6, 32(BX) - MOVOU X7, 48(BX) - FINALIZE(DI, SI, BX, BP, AX, DX) - -done: - MOVL state+28(FP), AX - MOVOU X3, 48(AX) - MOVL CX, ret+36(FP) +DONE: + MOVL state+28(FP), State + MOVOU X3, 3*16(State) RET -// func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int -TEXT ·xorKeyStreamSSSE3(SB), 4, $64-40 - MOVL dst_base+0(FP), DI - MOVL src_base+12(FP), SI - MOVL src_len+16(FP), CX - MOVL state+28(FP), AX - MOVL rounds+32(FP), DX - - MOVOU 48(AX), X3 - TESTL CX, CX - JZ done - - MOVL SP, BP - ADDL $16, SP - ANDL $-16, SP - - MOVOU ·one<>(SB), X0 - MOVOU 16(AX), X1 - MOVOU 32(AX), X2 - MOVO X0, 0(SP) - MOVO X1, 16(SP) - MOVO X2, 32(SP) - - MOVOU 0(AX), X0 - MOVOU ·rol16<>(SB), X1 - MOVOU ·rol8<>(SB), X2 - -at_least_64: - MOVO X0, X4 - MOVO 16(SP), X5 - MOVO 32(SP), X6 - MOVO X3, X7 - - MOVL DX, BX - -chacha_loop: - CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2) - CHACHA_SHUFFLE(X5, X6, X7) - CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2) - CHACHA_SHUFFLE(X7, X6, X5) - SUBL $2, BX - JA chacha_loop - - MOVOU 0(AX), X0 - PADDL X0, X4 - PADDL 16(SP), X5 - PADDL 32(SP), X6 - PADDL X3, X7 - PADDQ 0(SP), X3 - - CMPL CX, $64 - JB less_than_64 - - XOR(DI, SI, 0, X4, X5, X6, X7, X0) - MOVOU 0(AX), X0 - ADDL $64, SI - ADDL $64, DI - SUBL $64, CX - JNZ at_least_64 - -less_than_64: - MOVL BP, SP - MOVL CX, BP - TESTL BP, BP - JE done - - MOVL block+24(FP), BX - MOVOU X4, 0(BX) - MOVOU X5, 16(BX) - MOVOU X6, 32(BX) - MOVOU X7, 48(BX) - FINALIZE(DI, SI, BX, BP, AX, DX) - -done: - MOVL state+28(FP), AX - MOVOU X3, 48(AX) - MOVL CX, ret+36(FP) - RET - -// func supportsSSE2() bool -TEXT ·supportsSSE2(SB), NOSPLIT, $0-1 - XORL AX, AX - INCL AX - CPUID - SHRL $26, DX - ANDL $1, DX - MOVB DX, ret+0(FP) - RET - -// func supportsSSSE3() bool -TEXT ·supportsSSSE3(SB), NOSPLIT, $0-1 - XORL AX, AX - INCL AX - CPUID - SHRL $9, CX - ANDL $1, CX - MOVB CX, ret+0(FP) - RET - -// func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) -TEXT ·hChaCha20SSE2(SB), 4, $0-12 - MOVL out+0(FP), DI - MOVL nonce+4(FP), AX - MOVL key+8(FP), BX - - MOVOU ·sigma<>(SB), X0 - MOVOU 0(BX), X1 - MOVOU 16(BX), X2 - MOVOU 0(AX), X3 - - MOVL $20, CX - -chacha_loop: - CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) - CHACHA_SHUFFLE(X1, X2, X3) - CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) - CHACHA_SHUFFLE(X3, X2, X1) - SUBL $2, CX - JNZ chacha_loop - - MOVOU X0, 0(DI) - MOVOU X3, 16(DI) - RET - -// func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte) -TEXT ·hChaCha20SSSE3(SB), 4, $0-12 - MOVL out+0(FP), DI - MOVL nonce+4(FP), AX - MOVL key+8(FP), BX - - MOVOU ·sigma<>(SB), X0 - MOVOU 0(BX), X1 - MOVOU 16(BX), X2 - MOVOU 0(AX), X3 - MOVOU ·rol16<>(SB), X5 - MOVOU ·rol8<>(SB), X6 - - MOVL $20, CX - -chacha_loop: - CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) - CHACHA_SHUFFLE(X1, X2, X3) - CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) - CHACHA_SHUFFLE(X3, X2, X1) - SUBL $2, CX - JNZ chacha_loop - - MOVOU X0, 0(DI) - MOVOU X3, 16(DI) - RET +#undef State +#undef Dst +#undef Src +#undef Len +#undef Tmp0 +#undef Tmp1 diff --git a/vendor/github.com/aead/chacha20/chacha/chacha_go17_amd64.go b/vendor/github.com/aead/chacha20/chacha/chacha_amd64.go similarity index 76% rename from vendor/github.com/aead/chacha20/chacha/chacha_go17_amd64.go rename to vendor/github.com/aead/chacha20/chacha/chacha_amd64.go index 9ff41cf2..635f7de8 100644 --- a/vendor/github.com/aead/chacha20/chacha/chacha_go17_amd64.go +++ b/vendor/github.com/aead/chacha20/chacha/chacha_amd64.go @@ -6,24 +6,19 @@ package chacha +import "golang.org/x/sys/cpu" + func init() { - useSSE2 = true - useSSSE3 = supportsSSSE3() - useAVX2 = supportsAVX2() + useSSE2 = cpu.X86.HasSSE2 + useSSSE3 = cpu.X86.HasSSSE3 + useAVX = cpu.X86.HasAVX + useAVX2 = cpu.X86.HasAVX2 } // This function is implemented in chacha_amd64.s //go:noescape func initialize(state *[64]byte, key []byte, nonce *[16]byte) -// This function is implemented in chacha_amd64.s -//go:noescape -func supportsSSSE3() bool - -// This function is implemented in chachaAVX2_amd64.s -//go:noescape -func supportsAVX2() bool - // This function is implemented in chacha_amd64.s //go:noescape func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) @@ -44,29 +39,38 @@ func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int //go:noescape func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int +// This function is implemented in chacha_amd64.s +//go:noescape +func xorKeyStreamAVX(dst, src []byte, block, state *[64]byte, rounds int) int + // This function is implemented in chachaAVX2_amd64.s //go:noescape func xorKeyStreamAVX2(dst, src []byte, block, state *[64]byte, rounds int) int func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { - if useAVX2 { + switch { + case useAVX: hChaCha20AVX(out, nonce, key) - } else if useSSSE3 { + case useSSSE3: hChaCha20SSSE3(out, nonce, key) - } else if useSSE2 { // on amd64 this is always true - neccessary for testing generic on amd64 + case useSSE2: hChaCha20SSE2(out, nonce, key) - } else { + default: hChaCha20Generic(out, nonce, key) } } func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int { - if useAVX2 { + switch { + case useAVX2: return xorKeyStreamAVX2(dst, src, block, state, rounds) - } else if useSSSE3 { + case useAVX: + return xorKeyStreamAVX(dst, src, block, state, rounds) + case useSSSE3: return xorKeyStreamSSSE3(dst, src, block, state, rounds) - } else if useSSE2 { // on amd64 this is always true - neccessary for testing generic on amd64 + case useSSE2: return xorKeyStreamSSE2(dst, src, block, state, rounds) + default: + return xorKeyStreamGeneric(dst, src, block, state, rounds) } - return xorKeyStreamGeneric(dst, src, block, state, rounds) } diff --git a/vendor/github.com/aead/chacha20/chacha/chacha_amd64.s b/vendor/github.com/aead/chacha20/chacha/chacha_amd64.s index 5bc41ef7..26a23835 100644 --- a/vendor/github.com/aead/chacha20/chacha/chacha_amd64.s +++ b/vendor/github.com/aead/chacha20/chacha/chacha_amd64.s @@ -4,785 +4,1069 @@ // +build amd64,!gccgo,!appengine,!nacl -#include "textflag.h" - -DATA ·sigma<>+0x00(SB)/4, $0x61707865 -DATA ·sigma<>+0x04(SB)/4, $0x3320646e -DATA ·sigma<>+0x08(SB)/4, $0x79622d32 -DATA ·sigma<>+0x0C(SB)/4, $0x6b206574 -GLOBL ·sigma<>(SB), (NOPTR+RODATA), $16 - -DATA ·one<>+0x00(SB)/8, $1 -DATA ·one<>+0x08(SB)/8, $0 -GLOBL ·one<>(SB), (NOPTR+RODATA), $16 - -DATA ·rol16<>+0x00(SB)/8, $0x0504070601000302 -DATA ·rol16<>+0x08(SB)/8, $0x0D0C0F0E09080B0A -GLOBL ·rol16<>(SB), (NOPTR+RODATA), $16 - -DATA ·rol8<>+0x00(SB)/8, $0x0605040702010003 -DATA ·rol8<>+0x08(SB)/8, $0x0E0D0C0F0A09080B -GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16 - -#define ROTL_SSE2(n, t, v) \ - MOVO v, t; \ - PSLLL $n, t; \ - PSRLL $(32-n), v; \ - PXOR t, v - -#define CHACHA_QROUND_SSE2(v0, v1, v2, v3, t0) \ - PADDL v1, v0; \ - PXOR v0, v3; \ - ROTL_SSE2(16, t0, v3); \ - PADDL v3, v2; \ - PXOR v2, v1; \ - ROTL_SSE2(12, t0, v1); \ - PADDL v1, v0; \ - PXOR v0, v3; \ - ROTL_SSE2(8, t0, v3); \ - PADDL v3, v2; \ - PXOR v2, v1; \ - ROTL_SSE2(7, t0, v1) - -#define CHACHA_QROUND_SSSE3(v0, v1, v2, v3, t0, r16, r8) \ - PADDL v1, v0; \ - PXOR v0, v3; \ - PSHUFB r16, v3; \ - PADDL v3, v2; \ - PXOR v2, v1; \ - ROTL_SSE2(12, t0, v1); \ - PADDL v1, v0; \ - PXOR v0, v3; \ - PSHUFB r8, v3; \ - PADDL v3, v2; \ - PXOR v2, v1; \ - ROTL_SSE2(7, t0, v1) - -#define CHACHA_SHUFFLE(v1, v2, v3) \ - PSHUFL $0x39, v1, v1; \ - PSHUFL $0x4E, v2, v2; \ - PSHUFL $0x93, v3, v3 - -#define XOR(dst, src, off, v0, v1, v2, v3, t0) \ - MOVOU 0+off(src), t0; \ - PXOR v0, t0; \ - MOVOU t0, 0+off(dst); \ - MOVOU 16+off(src), t0; \ - PXOR v1, t0; \ - MOVOU t0, 16+off(dst); \ - MOVOU 32+off(src), t0; \ - PXOR v2, t0; \ - MOVOU t0, 32+off(dst); \ - MOVOU 48+off(src), t0; \ - PXOR v3, t0; \ - MOVOU t0, 48+off(dst) - -// func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int -TEXT ·xorKeyStreamSSE2(SB), 4, $112-80 - MOVQ dst_base+0(FP), DI - MOVQ src_base+24(FP), SI - MOVQ src_len+32(FP), CX - MOVQ block+48(FP), BX - MOVQ state+56(FP), AX - MOVQ rounds+64(FP), DX - - MOVQ SP, R9 - ADDQ $16, SP - ANDQ $-16, SP - - MOVOU 0(AX), X0 - MOVOU 16(AX), X1 - MOVOU 32(AX), X2 - MOVOU 48(AX), X3 - MOVOU ·one<>(SB), X15 - - TESTQ CX, CX - JZ done - - CMPQ CX, $64 - JBE between_0_and_64 - - CMPQ CX, $128 - JBE between_64_and_128 - - MOVO X0, 0(SP) - MOVO X1, 16(SP) - MOVO X2, 32(SP) - MOVO X3, 48(SP) - MOVO X15, 64(SP) - - CMPQ CX, $192 - JBE between_128_and_192 - - MOVQ $192, R14 - -at_least_256: - MOVO X0, X4 - MOVO X1, X5 - MOVO X2, X6 - MOVO X3, X7 - PADDQ 64(SP), X7 - MOVO X0, X12 - MOVO X1, X13 - MOVO X2, X14 - MOVO X7, X15 - PADDQ 64(SP), X15 - MOVO X0, X8 - MOVO X1, X9 - MOVO X2, X10 - MOVO X15, X11 - PADDQ 64(SP), X11 - - MOVQ DX, R8 - -chacha_loop_256: - MOVO X8, 80(SP) - CHACHA_QROUND_SSE2(X0, X1, X2, X3, X8) - CHACHA_QROUND_SSE2(X4, X5, X6, X7, X8) - MOVO 80(SP), X8 - - MOVO X0, 80(SP) - CHACHA_QROUND_SSE2(X12, X13, X14, X15, X0) - CHACHA_QROUND_SSE2(X8, X9, X10, X11, X0) - MOVO 80(SP), X0 - - CHACHA_SHUFFLE(X1, X2, X3) - CHACHA_SHUFFLE(X5, X6, X7) - CHACHA_SHUFFLE(X13, X14, X15) - CHACHA_SHUFFLE(X9, X10, X11) - - MOVO X8, 80(SP) - CHACHA_QROUND_SSE2(X0, X1, X2, X3, X8) - CHACHA_QROUND_SSE2(X4, X5, X6, X7, X8) - MOVO 80(SP), X8 - - MOVO X0, 80(SP) - CHACHA_QROUND_SSE2(X12, X13, X14, X15, X0) - CHACHA_QROUND_SSE2(X8, X9, X10, X11, X0) - MOVO 80(SP), X0 - - CHACHA_SHUFFLE(X3, X2, X1) - CHACHA_SHUFFLE(X7, X6, X5) - CHACHA_SHUFFLE(X15, X14, X13) - CHACHA_SHUFFLE(X11, X10, X9) - SUBQ $2, R8 - JA chacha_loop_256 - - MOVO X8, 80(SP) - - PADDL 0(SP), X0 - PADDL 16(SP), X1 - PADDL 32(SP), X2 - PADDL 48(SP), X3 - XOR(DI, SI, 0, X0, X1, X2, X3, X8) - - MOVO 0(SP), X0 - MOVO 16(SP), X1 - MOVO 32(SP), X2 - MOVO 48(SP), X3 - PADDQ 64(SP), X3 - - PADDL X0, X4 - PADDL X1, X5 - PADDL X2, X6 - PADDL X3, X7 - PADDQ 64(SP), X3 - XOR(DI, SI, 64, X4, X5, X6, X7, X8) - - MOVO 64(SP), X5 - MOVO 80(SP), X8 - - PADDL X0, X12 - PADDL X1, X13 - PADDL X2, X14 - PADDL X3, X15 - PADDQ X5, X3 - XOR(DI, SI, 128, X12, X13, X14, X15, X4) - - PADDL X0, X8 - PADDL X1, X9 - PADDL X2, X10 - PADDL X3, X11 - PADDQ X5, X3 - - CMPQ CX, $256 - JB less_than_64 - - XOR(DI, SI, 192, X8, X9, X10, X11, X4) - MOVO X3, 48(SP) - ADDQ $256, SI - ADDQ $256, DI - SUBQ $256, CX - CMPQ CX, $192 - JA at_least_256 - - TESTQ CX, CX - JZ done - MOVO 64(SP), X15 - CMPQ CX, $64 - JBE between_0_and_64 - CMPQ CX, $128 - JBE between_64_and_128 - -between_128_and_192: - MOVQ $128, R14 - MOVO X0, X4 - MOVO X1, X5 - MOVO X2, X6 - MOVO X3, X7 - PADDQ X15, X7 - MOVO X0, X8 - MOVO X1, X9 - MOVO X2, X10 - MOVO X7, X11 - PADDQ X15, X11 - - MOVQ DX, R8 - -chacha_loop_192: - CHACHA_QROUND_SSE2(X0, X1, X2, X3, X12) - CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12) - CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) - CHACHA_SHUFFLE(X1, X2, X3) - CHACHA_SHUFFLE(X5, X6, X7) - CHACHA_SHUFFLE(X9, X10, X11) - CHACHA_QROUND_SSE2(X0, X1, X2, X3, X12) - CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12) - CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) - CHACHA_SHUFFLE(X3, X2, X1) - CHACHA_SHUFFLE(X7, X6, X5) - CHACHA_SHUFFLE(X11, X10, X9) - SUBQ $2, R8 - JA chacha_loop_192 - - PADDL 0(SP), X0 - PADDL 16(SP), X1 - PADDL 32(SP), X2 - PADDL 48(SP), X3 - XOR(DI, SI, 0, X0, X1, X2, X3, X12) - - MOVO 0(SP), X0 - MOVO 16(SP), X1 - MOVO 32(SP), X2 - MOVO 48(SP), X3 - PADDQ X15, X3 - - PADDL X0, X4 - PADDL X1, X5 - PADDL X2, X6 - PADDL X3, X7 - PADDQ X15, X3 - XOR(DI, SI, 64, X4, X5, X6, X7, X12) - - PADDL X0, X8 - PADDL X1, X9 - PADDL X2, X10 - PADDL X3, X11 - PADDQ X15, X3 - - CMPQ CX, $192 - JB less_than_64 - - XOR(DI, SI, 128, X8, X9, X10, X11, X12) - SUBQ $192, CX - JMP done - -between_64_and_128: - MOVQ $64, R14 - MOVO X0, X4 - MOVO X1, X5 - MOVO X2, X6 - MOVO X3, X7 - MOVO X0, X8 - MOVO X1, X9 - MOVO X2, X10 - MOVO X3, X11 - PADDQ X15, X11 - - MOVQ DX, R8 - -chacha_loop_128: - CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12) - CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) - CHACHA_SHUFFLE(X5, X6, X7) - CHACHA_SHUFFLE(X9, X10, X11) - CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12) - CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) - CHACHA_SHUFFLE(X7, X6, X5) - CHACHA_SHUFFLE(X11, X10, X9) - SUBQ $2, R8 - JA chacha_loop_128 - - PADDL X0, X4 - PADDL X1, X5 - PADDL X2, X6 - PADDL X3, X7 - PADDQ X15, X3 - PADDL X0, X8 - PADDL X1, X9 - PADDL X2, X10 - PADDL X3, X11 - PADDQ X15, X3 - XOR(DI, SI, 0, X4, X5, X6, X7, X12) - - CMPQ CX, $128 - JB less_than_64 - - XOR(DI, SI, 64, X8, X9, X10, X11, X12) - SUBQ $128, CX - JMP done - -between_0_and_64: - MOVQ $0, R14 - MOVO X0, X8 - MOVO X1, X9 - MOVO X2, X10 - MOVO X3, X11 - MOVQ DX, R8 - -chacha_loop_64: - CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) - CHACHA_SHUFFLE(X9, X10, X11) - CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) - CHACHA_SHUFFLE(X11, X10, X9) - SUBQ $2, R8 - JA chacha_loop_64 - - PADDL X0, X8 - PADDL X1, X9 - PADDL X2, X10 - PADDL X3, X11 - PADDQ X15, X3 - CMPQ CX, $64 - JB less_than_64 - - XOR(DI, SI, 0, X8, X9, X10, X11, X12) - SUBQ $64, CX - JMP done - -less_than_64: - // R14 contains the num of bytes already xor'd - ADDQ R14, SI - ADDQ R14, DI - SUBQ R14, CX - MOVOU X8, 0(BX) - MOVOU X9, 16(BX) - MOVOU X10, 32(BX) - MOVOU X11, 48(BX) - XORQ R11, R11 - XORQ R12, R12 - MOVQ CX, BP - -xor_loop: - MOVB 0(SI), R11 - MOVB 0(BX), R12 - XORQ R11, R12 - MOVB R12, 0(DI) - INCQ SI - INCQ BX - INCQ DI - DECQ BP - JA xor_loop - -done: - MOVOU X3, 48(AX) - MOVQ R9, SP - MOVQ CX, ret+72(FP) - RET - -// func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int -TEXT ·xorKeyStreamSSSE3(SB), 4, $144-80 - MOVQ dst_base+0(FP), DI - MOVQ src_base+24(FP), SI - MOVQ src_len+32(FP), CX - MOVQ block+48(FP), BX - MOVQ state+56(FP), AX - MOVQ rounds+64(FP), DX - - MOVQ SP, R9 - ADDQ $16, SP - ANDQ $-16, SP - - MOVOU 0(AX), X0 - MOVOU 16(AX), X1 - MOVOU 32(AX), X2 - MOVOU 48(AX), X3 - MOVOU ·rol16<>(SB), X13 - MOVOU ·rol8<>(SB), X14 - MOVOU ·one<>(SB), X15 - - TESTQ CX, CX - JZ done - - CMPQ CX, $64 - JBE between_0_and_64 - - CMPQ CX, $128 - JBE between_64_and_128 - - MOVO X0, 0(SP) - MOVO X1, 16(SP) - MOVO X2, 32(SP) - MOVO X3, 48(SP) - MOVO X15, 64(SP) - - CMPQ CX, $192 - JBE between_128_and_192 - - MOVO X13, 96(SP) - MOVO X14, 112(SP) - MOVQ $192, R14 - -at_least_256: - MOVO X0, X4 - MOVO X1, X5 - MOVO X2, X6 - MOVO X3, X7 - PADDQ 64(SP), X7 - MOVO X0, X12 - MOVO X1, X13 - MOVO X2, X14 - MOVO X7, X15 - PADDQ 64(SP), X15 - MOVO X0, X8 - MOVO X1, X9 - MOVO X2, X10 - MOVO X15, X11 - PADDQ 64(SP), X11 - - MOVQ DX, R8 - -chacha_loop_256: - MOVO X8, 80(SP) - CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X8, 96(SP), 112(SP)) - CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X8, 96(SP), 112(SP)) - MOVO 80(SP), X8 - - MOVO X0, 80(SP) - CHACHA_QROUND_SSSE3(X12, X13, X14, X15, X0, 96(SP), 112(SP)) - CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X0, 96(SP), 112(SP)) - MOVO 80(SP), X0 - - CHACHA_SHUFFLE(X1, X2, X3) - CHACHA_SHUFFLE(X5, X6, X7) - CHACHA_SHUFFLE(X13, X14, X15) - CHACHA_SHUFFLE(X9, X10, X11) - - MOVO X8, 80(SP) - CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X8, 96(SP), 112(SP)) - CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X8, 96(SP), 112(SP)) - MOVO 80(SP), X8 - - MOVO X0, 80(SP) - CHACHA_QROUND_SSSE3(X12, X13, X14, X15, X0, 96(SP), 112(SP)) - CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X0, 96(SP), 112(SP)) - MOVO 80(SP), X0 - - CHACHA_SHUFFLE(X3, X2, X1) - CHACHA_SHUFFLE(X7, X6, X5) - CHACHA_SHUFFLE(X15, X14, X13) - CHACHA_SHUFFLE(X11, X10, X9) - SUBQ $2, R8 - JA chacha_loop_256 - - MOVO X8, 80(SP) - - PADDL 0(SP), X0 - PADDL 16(SP), X1 - PADDL 32(SP), X2 - PADDL 48(SP), X3 - XOR(DI, SI, 0, X0, X1, X2, X3, X8) - MOVO 0(SP), X0 - MOVO 16(SP), X1 - MOVO 32(SP), X2 - MOVO 48(SP), X3 - PADDQ 64(SP), X3 - - PADDL X0, X4 - PADDL X1, X5 - PADDL X2, X6 - PADDL X3, X7 - PADDQ 64(SP), X3 - XOR(DI, SI, 64, X4, X5, X6, X7, X8) - - MOVO 64(SP), X5 - MOVO 80(SP), X8 - - PADDL X0, X12 - PADDL X1, X13 - PADDL X2, X14 - PADDL X3, X15 - PADDQ X5, X3 - XOR(DI, SI, 128, X12, X13, X14, X15, X4) - - PADDL X0, X8 - PADDL X1, X9 - PADDL X2, X10 - PADDL X3, X11 - PADDQ X5, X3 - - CMPQ CX, $256 - JB less_than_64 - - XOR(DI, SI, 192, X8, X9, X10, X11, X4) - MOVO X3, 48(SP) - ADDQ $256, SI - ADDQ $256, DI - SUBQ $256, CX - CMPQ CX, $192 - JA at_least_256 - - TESTQ CX, CX - JZ done - MOVOU ·rol16<>(SB), X13 - MOVOU ·rol8<>(SB), X14 - MOVO 64(SP), X15 - CMPQ CX, $64 - JBE between_0_and_64 - CMPQ CX, $128 - JBE between_64_and_128 - -between_128_and_192: - MOVQ $128, R14 - MOVO X0, X4 - MOVO X1, X5 - MOVO X2, X6 - MOVO X3, X7 - PADDQ X15, X7 - MOVO X0, X8 - MOVO X1, X9 - MOVO X2, X10 - MOVO X7, X11 - PADDQ X15, X11 - - MOVQ DX, R8 - -chacha_loop_192: - CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X12, X13, X14) - CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14) - CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) - CHACHA_SHUFFLE(X1, X2, X3) - CHACHA_SHUFFLE(X5, X6, X7) - CHACHA_SHUFFLE(X9, X10, X11) - CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X12, X13, X14) - CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14) - CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) - CHACHA_SHUFFLE(X3, X2, X1) - CHACHA_SHUFFLE(X7, X6, X5) - CHACHA_SHUFFLE(X11, X10, X9) - SUBQ $2, R8 - JA chacha_loop_192 - - PADDL 0(SP), X0 - PADDL 16(SP), X1 - PADDL 32(SP), X2 - PADDL 48(SP), X3 - XOR(DI, SI, 0, X0, X1, X2, X3, X12) - - MOVO 0(SP), X0 - MOVO 16(SP), X1 - MOVO 32(SP), X2 - MOVO 48(SP), X3 - PADDQ X15, X3 - - PADDL X0, X4 - PADDL X1, X5 - PADDL X2, X6 - PADDL X3, X7 - PADDQ X15, X3 - XOR(DI, SI, 64, X4, X5, X6, X7, X12) - - PADDL X0, X8 - PADDL X1, X9 - PADDL X2, X10 - PADDL X3, X11 - PADDQ X15, X3 - - CMPQ CX, $192 - JB less_than_64 - - XOR(DI, SI, 128, X8, X9, X10, X11, X12) - SUBQ $192, CX - JMP done - -between_64_and_128: - MOVQ $64, R14 - MOVO X0, X4 - MOVO X1, X5 - MOVO X2, X6 - MOVO X3, X7 - MOVO X0, X8 - MOVO X1, X9 - MOVO X2, X10 - MOVO X3, X11 - PADDQ X15, X11 - - MOVQ DX, R8 - -chacha_loop_128: - CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14) - CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) - CHACHA_SHUFFLE(X5, X6, X7) - CHACHA_SHUFFLE(X9, X10, X11) - CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14) - CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) - CHACHA_SHUFFLE(X7, X6, X5) - CHACHA_SHUFFLE(X11, X10, X9) - SUBQ $2, R8 - JA chacha_loop_128 - - PADDL X0, X4 - PADDL X1, X5 - PADDL X2, X6 - PADDL X3, X7 - PADDQ X15, X3 - PADDL X0, X8 - PADDL X1, X9 - PADDL X2, X10 - PADDL X3, X11 - PADDQ X15, X3 - XOR(DI, SI, 0, X4, X5, X6, X7, X12) - - CMPQ CX, $128 - JB less_than_64 - - XOR(DI, SI, 64, X8, X9, X10, X11, X12) - SUBQ $128, CX - JMP done - -between_0_and_64: - MOVQ $0, R14 - MOVO X0, X8 - MOVO X1, X9 - MOVO X2, X10 - MOVO X3, X11 - MOVQ DX, R8 - -chacha_loop_64: - CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) - CHACHA_SHUFFLE(X9, X10, X11) - CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) - CHACHA_SHUFFLE(X11, X10, X9) - SUBQ $2, R8 - JA chacha_loop_64 - - PADDL X0, X8 - PADDL X1, X9 - PADDL X2, X10 - PADDL X3, X11 - PADDQ X15, X3 - CMPQ CX, $64 - JB less_than_64 - - XOR(DI, SI, 0, X8, X9, X10, X11, X12) - SUBQ $64, CX - JMP done - -less_than_64: - // R14 contains the num of bytes already xor'd - ADDQ R14, SI - ADDQ R14, DI - SUBQ R14, CX - MOVOU X8, 0(BX) - MOVOU X9, 16(BX) - MOVOU X10, 32(BX) - MOVOU X11, 48(BX) - XORQ R11, R11 - XORQ R12, R12 - MOVQ CX, BP - -xor_loop: - MOVB 0(SI), R11 - MOVB 0(BX), R12 - XORQ R11, R12 - MOVB R12, 0(DI) - INCQ SI - INCQ BX - INCQ DI - DECQ BP - JA xor_loop - -done: - MOVQ R9, SP - MOVOU X3, 48(AX) - MOVQ CX, ret+72(FP) - RET - -// func supportsSSSE3() bool -TEXT ·supportsSSSE3(SB), NOSPLIT, $0-1 - XORQ AX, AX - INCQ AX - CPUID - SHRQ $9, CX - ANDQ $1, CX - MOVB CX, ret+0(FP) - RET +#include "const.s" +#include "macro.s" + +// FINALIZE xors len bytes from src and block using +// the temp. registers t0 and t1 and writes the result +// to dst. +#define FINALIZE(dst, src, block, len, t0, t1) \ + XORQ t0, t0; \ + XORQ t1, t1; \ + FINALIZE_LOOP:; \ + MOVB 0(src), t0; \ + MOVB 0(block), t1; \ + XORQ t0, t1; \ + MOVB t1, 0(dst); \ + INCQ src; \ + INCQ block; \ + INCQ dst; \ + DECQ len; \ + JG FINALIZE_LOOP \ + +#define Dst DI +#define Nonce AX +#define Key BX +#define Rounds DX // func initialize(state *[64]byte, key []byte, nonce *[16]byte) TEXT ·initialize(SB), 4, $0-40 - MOVQ state+0(FP), DI - MOVQ key+8(FP), AX - MOVQ nonce+32(FP), BX + MOVQ state+0(FP), Dst + MOVQ key+8(FP), Key + MOVQ nonce+32(FP), Nonce MOVOU ·sigma<>(SB), X0 - MOVOU 0(AX), X1 - MOVOU 16(AX), X2 - MOVOU 0(BX), X3 + MOVOU 0*16(Key), X1 + MOVOU 1*16(Key), X2 + MOVOU 0*16(Nonce), X3 - MOVOU X0, 0(DI) - MOVOU X1, 16(DI) - MOVOU X2, 32(DI) - MOVOU X3, 48(DI) + MOVOU X0, 0*16(Dst) + MOVOU X1, 1*16(Dst) + MOVOU X2, 2*16(Dst) + MOVOU X3, 3*16(Dst) + RET + +// func hChaCha20AVX(out *[32]byte, nonce *[16]byte, key *[32]byte) +TEXT ·hChaCha20AVX(SB), 4, $0-24 + MOVQ out+0(FP), Dst + MOVQ nonce+8(FP), Nonce + MOVQ key+16(FP), Key + + VMOVDQU ·sigma<>(SB), X0 + VMOVDQU 0*16(Key), X1 + VMOVDQU 1*16(Key), X2 + VMOVDQU 0*16(Nonce), X3 + VMOVDQU ·rol16_AVX2<>(SB), X5 + VMOVDQU ·rol8_AVX2<>(SB), X6 + MOVQ $20, Rounds + +CHACHA_LOOP: + CHACHA_QROUND_AVX(X0, X1, X2, X3, X4, X5, X6) + CHACHA_SHUFFLE_AVX(X1, X2, X3) + CHACHA_QROUND_AVX(X0, X1, X2, X3, X4, X5, X6) + CHACHA_SHUFFLE_AVX(X3, X2, X1) + SUBQ $2, Rounds + JNZ CHACHA_LOOP + + VMOVDQU X0, 0*16(Dst) + VMOVDQU X3, 1*16(Dst) + VZEROUPPER RET // func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) TEXT ·hChaCha20SSE2(SB), 4, $0-24 - MOVQ out+0(FP), DI - MOVQ nonce+8(FP), AX - MOVQ key+16(FP), BX + MOVQ out+0(FP), Dst + MOVQ nonce+8(FP), Nonce + MOVQ key+16(FP), Key MOVOU ·sigma<>(SB), X0 - MOVOU 0(BX), X1 - MOVOU 16(BX), X2 - MOVOU 0(AX), X3 + MOVOU 0*16(Key), X1 + MOVOU 1*16(Key), X2 + MOVOU 0*16(Nonce), X3 + MOVQ $20, Rounds - MOVQ $20, CX - -chacha_loop: +CHACHA_LOOP: CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) - CHACHA_SHUFFLE(X1, X2, X3) + CHACHA_SHUFFLE_SSE(X1, X2, X3) CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) - CHACHA_SHUFFLE(X3, X2, X1) - SUBQ $2, CX - JNZ chacha_loop + CHACHA_SHUFFLE_SSE(X3, X2, X1) + SUBQ $2, Rounds + JNZ CHACHA_LOOP - MOVOU X0, 0(DI) - MOVOU X3, 16(DI) + MOVOU X0, 0*16(Dst) + MOVOU X3, 1*16(Dst) RET // func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte) TEXT ·hChaCha20SSSE3(SB), 4, $0-24 - MOVQ out+0(FP), DI - MOVQ nonce+8(FP), AX - MOVQ key+16(FP), BX + MOVQ out+0(FP), Dst + MOVQ nonce+8(FP), Nonce + MOVQ key+16(FP), Key MOVOU ·sigma<>(SB), X0 - MOVOU 0(BX), X1 - MOVOU 16(BX), X2 - MOVOU 0(AX), X3 + MOVOU 0*16(Key), X1 + MOVOU 1*16(Key), X2 + MOVOU 0*16(Nonce), X3 MOVOU ·rol16<>(SB), X5 MOVOU ·rol8<>(SB), X6 - - MOVQ $20, CX + MOVQ $20, Rounds chacha_loop: CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) - CHACHA_SHUFFLE(X1, X2, X3) + CHACHA_SHUFFLE_SSE(X1, X2, X3) CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) - CHACHA_SHUFFLE(X3, X2, X1) - SUBQ $2, CX + CHACHA_SHUFFLE_SSE(X3, X2, X1) + SUBQ $2, Rounds JNZ chacha_loop - MOVOU X0, 0(DI) - MOVOU X3, 16(DI) + MOVOU X0, 0*16(Dst) + MOVOU X3, 1*16(Dst) RET + +#undef Dst +#undef Nonce +#undef Key +#undef Rounds + +#define Dst DI +#define Src SI +#define Len R12 +#define Rounds DX +#define Buffer BX +#define State AX +#define Stack SP +#define SavedSP R8 +#define Tmp0 R9 +#define Tmp1 R10 +#define Tmp2 R11 + +// func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int +TEXT ·xorKeyStreamSSE2(SB), 4, $112-80 + MOVQ dst_base+0(FP), Dst + MOVQ src_base+24(FP), Src + MOVQ block+48(FP), Buffer + MOVQ state+56(FP), State + MOVQ rounds+64(FP), Rounds + MOVQ src_len+32(FP), Len + + MOVOU 0*16(State), X0 + MOVOU 1*16(State), X1 + MOVOU 2*16(State), X2 + MOVOU 3*16(State), X3 + + MOVQ Stack, SavedSP + ADDQ $16, Stack + ANDQ $-16, Stack + + TESTQ Len, Len + JZ DONE + + MOVOU ·one<>(SB), X4 + MOVO X0, 0*16(Stack) + MOVO X1, 1*16(Stack) + MOVO X2, 2*16(Stack) + MOVO X3, 3*16(Stack) + MOVO X4, 4*16(Stack) + + CMPQ Len, $64 + JLE GENERATE_KEYSTREAM_64 + CMPQ Len, $128 + JLE GENERATE_KEYSTREAM_128 + CMPQ Len, $192 + JLE GENERATE_KEYSTREAM_192 + +GENERATE_KEYSTREAM_256: + MOVO X0, X12 + MOVO X1, X13 + MOVO X2, X14 + MOVO X3, X15 + PADDQ 4*16(Stack), X15 + MOVO X0, X8 + MOVO X1, X9 + MOVO X2, X10 + MOVO X15, X11 + PADDQ 4*16(Stack), X11 + MOVO X0, X4 + MOVO X1, X5 + MOVO X2, X6 + MOVO X11, X7 + PADDQ 4*16(Stack), X7 + MOVQ Rounds, Tmp0 + + MOVO X3, 3*16(Stack) // Save X3 + +CHACHA_LOOP_256: + MOVO X4, 5*16(Stack) + CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) + CHACHA_QROUND_SSE2(X12, X13, X14, X15, X4) + MOVO 5*16(Stack), X4 + MOVO X0, 5*16(Stack) + CHACHA_QROUND_SSE2(X8, X9, X10, X11, X0) + CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0) + MOVO 5*16(Stack), X0 + CHACHA_SHUFFLE_SSE(X1, X2, X3) + CHACHA_SHUFFLE_SSE(X13, X14, X15) + CHACHA_SHUFFLE_SSE(X9, X10, X11) + CHACHA_SHUFFLE_SSE(X5, X6, X7) + MOVO X4, 5*16(Stack) + CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) + CHACHA_QROUND_SSE2(X12, X13, X14, X15, X4) + MOVO 5*16(Stack), X4 + MOVO X0, 5*16(Stack) + CHACHA_QROUND_SSE2(X8, X9, X10, X11, X0) + CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0) + MOVO 5*16(Stack), X0 + CHACHA_SHUFFLE_SSE(X3, X2, X1) + CHACHA_SHUFFLE_SSE(X15, X14, X13) + CHACHA_SHUFFLE_SSE(X11, X10, X9) + CHACHA_SHUFFLE_SSE(X7, X6, X5) + SUBQ $2, Tmp0 + JNZ CHACHA_LOOP_256 + + PADDL 0*16(Stack), X0 + PADDL 1*16(Stack), X1 + PADDL 2*16(Stack), X2 + PADDL 3*16(Stack), X3 + MOVO X4, 5*16(Stack) // Save X4 + XOR_SSE(Dst, Src, 0, X0, X1, X2, X3, X4) + MOVO 5*16(Stack), X4 // Restore X4 + + MOVO 0*16(Stack), X0 + MOVO 1*16(Stack), X1 + MOVO 2*16(Stack), X2 + MOVO 3*16(Stack), X3 + PADDQ 4*16(Stack), X3 + + PADDL X0, X12 + PADDL X1, X13 + PADDL X2, X14 + PADDL X3, X15 + PADDQ 4*16(Stack), X3 + PADDL X0, X8 + PADDL X1, X9 + PADDL X2, X10 + PADDL X3, X11 + PADDQ 4*16(Stack), X3 + PADDL X0, X4 + PADDL X1, X5 + PADDL X2, X6 + PADDL X3, X7 + PADDQ 4*16(Stack), X3 + + XOR_SSE(Dst, Src, 64, X12, X13, X14, X15, X0) + XOR_SSE(Dst, Src, 128, X8, X9, X10, X11, X0) + MOVO 0*16(Stack), X0 // Restore X0 + ADDQ $192, Dst + ADDQ $192, Src + SUBQ $192, Len + + CMPQ Len, $64 + JL BUFFER_KEYSTREAM + + XOR_SSE(Dst, Src, 0, X4, X5, X6, X7, X8) + ADDQ $64, Dst + ADDQ $64, Src + SUBQ $64, Len + JZ DONE + CMPQ Len, $64 // If Len <= 64 -> gen. only 64 byte keystream. + JLE GENERATE_KEYSTREAM_64 + CMPQ Len, $128 // If 64 < Len <= 128 -> gen. only 128 byte keystream. + JLE GENERATE_KEYSTREAM_128 + CMPQ Len, $192 // If Len > 192 -> repeat, otherwise Len > 128 && Len <= 192 -> gen. 192 byte keystream + JG GENERATE_KEYSTREAM_256 + +GENERATE_KEYSTREAM_192: + MOVO X0, X12 + MOVO X1, X13 + MOVO X2, X14 + MOVO X3, X15 + MOVO X0, X8 + MOVO X1, X9 + MOVO X2, X10 + MOVO X3, X11 + PADDQ 4*16(Stack), X11 + MOVO X0, X4 + MOVO X1, X5 + MOVO X2, X6 + MOVO X11, X7 + PADDQ 4*16(Stack), X7 + MOVQ Rounds, Tmp0 + +CHACHA_LOOP_192: + CHACHA_QROUND_SSE2(X12, X13, X14, X15, X0) + CHACHA_QROUND_SSE2(X8, X9, X10, X11, X0) + CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0) + CHACHA_SHUFFLE_SSE(X13, X14, X15) + CHACHA_SHUFFLE_SSE(X9, X10, X11) + CHACHA_SHUFFLE_SSE(X5, X6, X7) + CHACHA_QROUND_SSE2(X12, X13, X14, X15, X0) + CHACHA_QROUND_SSE2(X8, X9, X10, X11, X0) + CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0) + CHACHA_SHUFFLE_SSE(X15, X14, X13) + CHACHA_SHUFFLE_SSE(X11, X10, X9) + CHACHA_SHUFFLE_SSE(X7, X6, X5) + SUBQ $2, Tmp0 + JNZ CHACHA_LOOP_192 + + MOVO 0*16(Stack), X0 // Restore X0 + PADDL X0, X12 + PADDL X1, X13 + PADDL X2, X14 + PADDL X3, X15 + PADDQ 4*16(Stack), X3 + PADDL X0, X8 + PADDL X1, X9 + PADDL X2, X10 + PADDL X3, X11 + PADDQ 4*16(Stack), X3 + PADDL X0, X4 + PADDL X1, X5 + PADDL X2, X6 + PADDL X3, X7 + PADDQ 4*16(Stack), X3 + + XOR_SSE(Dst, Src, 0, X12, X13, X14, X15, X0) + XOR_SSE(Dst, Src, 64, X8, X9, X10, X11, X0) + MOVO 0*16(Stack), X0 // Restore X0 + ADDQ $128, Dst + ADDQ $128, Src + SUBQ $128, Len + + CMPQ Len, $64 + JL BUFFER_KEYSTREAM + + XOR_SSE(Dst, Src, 0, X4, X5, X6, X7, X8) + ADDQ $64, Dst + ADDQ $64, Src + SUBQ $64, Len + JZ DONE + CMPQ Len, $64 // If Len <= 64 -> gen. only 64 byte keystream. + JLE GENERATE_KEYSTREAM_64 + +GENERATE_KEYSTREAM_128: + MOVO X0, X8 + MOVO X1, X9 + MOVO X2, X10 + MOVO X3, X11 + MOVO X0, X4 + MOVO X1, X5 + MOVO X2, X6 + MOVO X3, X7 + PADDQ 4*16(Stack), X7 + MOVQ Rounds, Tmp0 + +CHACHA_LOOP_128: + CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) + CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12) + CHACHA_SHUFFLE_SSE(X9, X10, X11) + CHACHA_SHUFFLE_SSE(X5, X6, X7) + CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) + CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12) + CHACHA_SHUFFLE_SSE(X11, X10, X9) + CHACHA_SHUFFLE_SSE(X7, X6, X5) + SUBQ $2, Tmp0 + JNZ CHACHA_LOOP_128 + + PADDL X0, X8 + PADDL X1, X9 + PADDL X2, X10 + PADDL X3, X11 + PADDQ 4*16(Stack), X3 + PADDL X0, X4 + PADDL X1, X5 + PADDL X2, X6 + PADDL X3, X7 + PADDQ 4*16(Stack), X3 + + XOR_SSE(Dst, Src, 0, X8, X9, X10, X11, X12) + ADDQ $64, Dst + ADDQ $64, Src + SUBQ $64, Len + + CMPQ Len, $64 + JL BUFFER_KEYSTREAM + + XOR_SSE(Dst, Src, 0, X4, X5, X6, X7, X8) + ADDQ $64, Dst + ADDQ $64, Src + SUBQ $64, Len + JZ DONE // If Len == 0 -> DONE, otherwise Len <= 64 -> gen 64 byte keystream + +GENERATE_KEYSTREAM_64: + MOVO X0, X4 + MOVO X1, X5 + MOVO X2, X6 + MOVO X3, X7 + MOVQ Rounds, Tmp0 + +CHACHA_LOOP_64: + CHACHA_QROUND_SSE2(X4, X5, X6, X7, X8) + CHACHA_SHUFFLE_SSE(X5, X6, X7) + CHACHA_QROUND_SSE2(X4, X5, X6, X7, X8) + CHACHA_SHUFFLE_SSE(X7, X6, X5) + SUBQ $2, Tmp0 + JNZ CHACHA_LOOP_64 + + PADDL X0, X4 + PADDL X1, X5 + PADDL X2, X6 + PADDL X3, X7 + PADDQ 4*16(Stack), X3 + + CMPQ Len, $64 + JL BUFFER_KEYSTREAM + + XOR_SSE(Dst, Src, 0, X4, X5, X6, X7, X8) + ADDQ $64, Src + ADDQ $64, Dst + SUBQ $64, Len + JMP DONE // jump directly to DONE - there is no keystream to buffer, Len == 0 always true. + +BUFFER_KEYSTREAM: + MOVOU X4, 0*16(Buffer) + MOVOU X5, 1*16(Buffer) + MOVOU X6, 2*16(Buffer) + MOVOU X7, 3*16(Buffer) + MOVQ Len, Tmp0 + FINALIZE(Dst, Src, Buffer, Tmp0, Tmp1, Tmp2) + +DONE: + MOVQ SavedSP, Stack // Restore stack pointer + MOVOU X3, 3*16(State) + MOVQ Len, ret+72(FP) + RET + +// func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int +TEXT ·xorKeyStreamSSSE3(SB), 4, $144-80 + MOVQ dst_base+0(FP), Dst + MOVQ src_base+24(FP), Src + MOVQ block+48(FP), Buffer + MOVQ state+56(FP), State + MOVQ rounds+64(FP), Rounds + MOVQ src_len+32(FP), Len + + MOVOU 0*16(State), X0 + MOVOU 1*16(State), X1 + MOVOU 2*16(State), X2 + MOVOU 3*16(State), X3 + + MOVQ Stack, SavedSP + ADDQ $16, Stack + ANDQ $-16, Stack + + TESTQ Len, Len + JZ DONE + + MOVOU ·one<>(SB), X4 + MOVOU ·rol16<>(SB), X5 + MOVOU ·rol8<>(SB), X6 + MOVO X0, 0*16(Stack) + MOVO X1, 1*16(Stack) + MOVO X2, 2*16(Stack) + MOVO X3, 3*16(Stack) + MOVO X4, 4*16(Stack) + MOVO X5, 6*16(Stack) + MOVO X6, 7*16(Stack) + + CMPQ Len, $64 + JLE GENERATE_KEYSTREAM_64 + CMPQ Len, $128 + JLE GENERATE_KEYSTREAM_128 + CMPQ Len, $192 + JLE GENERATE_KEYSTREAM_192 + +GENERATE_KEYSTREAM_256: + MOVO X0, X12 + MOVO X1, X13 + MOVO X2, X14 + MOVO X3, X15 + PADDQ 4*16(Stack), X15 + MOVO X0, X8 + MOVO X1, X9 + MOVO X2, X10 + MOVO X15, X11 + PADDQ 4*16(Stack), X11 + MOVO X0, X4 + MOVO X1, X5 + MOVO X2, X6 + MOVO X11, X7 + PADDQ 4*16(Stack), X7 + MOVQ Rounds, Tmp0 + + MOVO X3, 3*16(Stack) // Save X3 + +CHACHA_LOOP_256: + MOVO X4, 5*16(Stack) + CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, 6*16(Stack), 7*16(Stack)) + CHACHA_QROUND_SSSE3(X12, X13, X14, X15, X4, 6*16(Stack), 7*16(Stack)) + MOVO 5*16(Stack), X4 + MOVO X0, 5*16(Stack) + CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X0, 6*16(Stack), 7*16(Stack)) + CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, 6*16(Stack), 7*16(Stack)) + MOVO 5*16(Stack), X0 + CHACHA_SHUFFLE_SSE(X1, X2, X3) + CHACHA_SHUFFLE_SSE(X13, X14, X15) + CHACHA_SHUFFLE_SSE(X9, X10, X11) + CHACHA_SHUFFLE_SSE(X5, X6, X7) + MOVO X4, 5*16(Stack) + CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, 6*16(Stack), 7*16(Stack)) + CHACHA_QROUND_SSSE3(X12, X13, X14, X15, X4, 6*16(Stack), 7*16(Stack)) + MOVO 5*16(Stack), X4 + MOVO X0, 5*16(Stack) + CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X0, 6*16(Stack), 7*16(Stack)) + CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, 6*16(Stack), 7*16(Stack)) + MOVO 5*16(Stack), X0 + CHACHA_SHUFFLE_SSE(X3, X2, X1) + CHACHA_SHUFFLE_SSE(X15, X14, X13) + CHACHA_SHUFFLE_SSE(X11, X10, X9) + CHACHA_SHUFFLE_SSE(X7, X6, X5) + SUBQ $2, Tmp0 + JNZ CHACHA_LOOP_256 + + PADDL 0*16(Stack), X0 + PADDL 1*16(Stack), X1 + PADDL 2*16(Stack), X2 + PADDL 3*16(Stack), X3 + MOVO X4, 5*16(Stack) // Save X4 + XOR_SSE(Dst, Src, 0, X0, X1, X2, X3, X4) + MOVO 5*16(Stack), X4 // Restore X4 + + MOVO 0*16(Stack), X0 + MOVO 1*16(Stack), X1 + MOVO 2*16(Stack), X2 + MOVO 3*16(Stack), X3 + PADDQ 4*16(Stack), X3 + + PADDL X0, X12 + PADDL X1, X13 + PADDL X2, X14 + PADDL X3, X15 + PADDQ 4*16(Stack), X3 + PADDL X0, X8 + PADDL X1, X9 + PADDL X2, X10 + PADDL X3, X11 + PADDQ 4*16(Stack), X3 + PADDL X0, X4 + PADDL X1, X5 + PADDL X2, X6 + PADDL X3, X7 + PADDQ 4*16(Stack), X3 + + XOR_SSE(Dst, Src, 64, X12, X13, X14, X15, X0) + XOR_SSE(Dst, Src, 128, X8, X9, X10, X11, X0) + MOVO 0*16(Stack), X0 // Restore X0 + ADDQ $192, Dst + ADDQ $192, Src + SUBQ $192, Len + + CMPQ Len, $64 + JL BUFFER_KEYSTREAM + + XOR_SSE(Dst, Src, 0, X4, X5, X6, X7, X8) + ADDQ $64, Dst + ADDQ $64, Src + SUBQ $64, Len + JZ DONE + CMPQ Len, $64 // If Len <= 64 -> gen. only 64 byte keystream. + JLE GENERATE_KEYSTREAM_64 + CMPQ Len, $128 // If 64 < Len <= 128 -> gen. only 128 byte keystream. + JLE GENERATE_KEYSTREAM_128 + CMPQ Len, $192 // If Len > 192 -> repeat, otherwise Len > 128 && Len <= 192 -> gen. 192 byte keystream + JG GENERATE_KEYSTREAM_256 + +GENERATE_KEYSTREAM_192: + MOVO X0, X12 + MOVO X1, X13 + MOVO X2, X14 + MOVO X3, X15 + MOVO X0, X8 + MOVO X1, X9 + MOVO X2, X10 + MOVO X3, X11 + PADDQ 4*16(Stack), X11 + MOVO X0, X4 + MOVO X1, X5 + MOVO X2, X6 + MOVO X11, X7 + PADDQ 4*16(Stack), X7 + MOVQ Rounds, Tmp0 + + MOVO 6*16(Stack), X1 // Load 16 bit rotate-left constant + MOVO 7*16(Stack), X2 // Load 8 bit rotate-left constant + +CHACHA_LOOP_192: + CHACHA_QROUND_SSSE3(X12, X13, X14, X15, X0, X1, X2) + CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X0, X1, X2) + CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2) + CHACHA_SHUFFLE_SSE(X13, X14, X15) + CHACHA_SHUFFLE_SSE(X9, X10, X11) + CHACHA_SHUFFLE_SSE(X5, X6, X7) + CHACHA_QROUND_SSSE3(X12, X13, X14, X15, X0, X1, X2) + CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X0, X1, X2) + CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2) + CHACHA_SHUFFLE_SSE(X15, X14, X13) + CHACHA_SHUFFLE_SSE(X11, X10, X9) + CHACHA_SHUFFLE_SSE(X7, X6, X5) + SUBQ $2, Tmp0 + JNZ CHACHA_LOOP_192 + + MOVO 0*16(Stack), X0 // Restore X0 + MOVO 1*16(Stack), X1 // Restore X1 + MOVO 2*16(Stack), X2 // Restore X2 + PADDL X0, X12 + PADDL X1, X13 + PADDL X2, X14 + PADDL X3, X15 + PADDQ 4*16(Stack), X3 + PADDL X0, X8 + PADDL X1, X9 + PADDL X2, X10 + PADDL X3, X11 + PADDQ 4*16(Stack), X3 + PADDL X0, X4 + PADDL X1, X5 + PADDL X2, X6 + PADDL X3, X7 + PADDQ 4*16(Stack), X3 + + XOR_SSE(Dst, Src, 0, X12, X13, X14, X15, X0) + XOR_SSE(Dst, Src, 64, X8, X9, X10, X11, X0) + MOVO 0*16(Stack), X0 // Restore X0 + ADDQ $128, Dst + ADDQ $128, Src + SUBQ $128, Len + + CMPQ Len, $64 + JL BUFFER_KEYSTREAM + + XOR_SSE(Dst, Src, 0, X4, X5, X6, X7, X8) + ADDQ $64, Dst + ADDQ $64, Src + SUBQ $64, Len + JZ DONE + CMPQ Len, $64 // If Len <= 64 -> gen. only 64 byte keystream. + JLE GENERATE_KEYSTREAM_64 + +GENERATE_KEYSTREAM_128: + MOVO X0, X8 + MOVO X1, X9 + MOVO X2, X10 + MOVO X3, X11 + MOVO X0, X4 + MOVO X1, X5 + MOVO X2, X6 + MOVO X3, X7 + PADDQ 4*16(Stack), X7 + MOVQ Rounds, Tmp0 + + MOVO 6*16(Stack), X13 // Load 16 bit rotate-left constant + MOVO 7*16(Stack), X14 // Load 8 bit rotate-left constant + +CHACHA_LOOP_128: + CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) + CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14) + CHACHA_SHUFFLE_SSE(X9, X10, X11) + CHACHA_SHUFFLE_SSE(X5, X6, X7) + CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) + CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14) + CHACHA_SHUFFLE_SSE(X11, X10, X9) + CHACHA_SHUFFLE_SSE(X7, X6, X5) + SUBQ $2, Tmp0 + JNZ CHACHA_LOOP_128 + + PADDL X0, X8 + PADDL X1, X9 + PADDL X2, X10 + PADDL X3, X11 + PADDQ 4*16(Stack), X3 + PADDL X0, X4 + PADDL X1, X5 + PADDL X2, X6 + PADDL X3, X7 + PADDQ 4*16(Stack), X3 + + XOR_SSE(Dst, Src, 0, X8, X9, X10, X11, X12) + ADDQ $64, Dst + ADDQ $64, Src + SUBQ $64, Len + + CMPQ Len, $64 + JL BUFFER_KEYSTREAM + + XOR_SSE(Dst, Src, 0, X4, X5, X6, X7, X8) + ADDQ $64, Dst + ADDQ $64, Src + SUBQ $64, Len + JZ DONE // If Len == 0 -> DONE, otherwise Len <= 64 -> gen 64 byte keystream + +GENERATE_KEYSTREAM_64: + MOVO X0, X4 + MOVO X1, X5 + MOVO X2, X6 + MOVO X3, X7 + MOVQ Rounds, Tmp0 + + MOVO 6*16(Stack), X9 // Load 16 bit rotate-left constant + MOVO 7*16(Stack), X10 // Load 8 bit rotate-left constant + +CHACHA_LOOP_64: + CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X8, X9, X10) + CHACHA_SHUFFLE_SSE(X5, X6, X7) + CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X8, X9, X10) + CHACHA_SHUFFLE_SSE(X7, X6, X5) + SUBQ $2, Tmp0 + JNZ CHACHA_LOOP_64 + + PADDL X0, X4 + PADDL X1, X5 + PADDL X2, X6 + PADDL X3, X7 + PADDQ 4*16(Stack), X3 + + CMPQ Len, $64 + JL BUFFER_KEYSTREAM + + XOR_SSE(Dst, Src, 0, X4, X5, X6, X7, X8) + ADDQ $64, Src + ADDQ $64, Dst + SUBQ $64, Len + JMP DONE // jump directly to DONE - there is no keystream to buffer, Len == 0 always true. + +BUFFER_KEYSTREAM: + MOVOU X4, 0*16(Buffer) + MOVOU X5, 1*16(Buffer) + MOVOU X6, 2*16(Buffer) + MOVOU X7, 3*16(Buffer) + MOVQ Len, Tmp0 + FINALIZE(Dst, Src, Buffer, Tmp0, Tmp1, Tmp2) + +DONE: + MOVQ SavedSP, Stack // Restore stack pointer + MOVOU X3, 3*16(State) + MOVQ Len, ret+72(FP) + RET + +// func xorKeyStreamAVX(dst, src []byte, block, state *[64]byte, rounds int) int +TEXT ·xorKeyStreamAVX(SB), 4, $144-80 + MOVQ dst_base+0(FP), Dst + MOVQ src_base+24(FP), Src + MOVQ block+48(FP), Buffer + MOVQ state+56(FP), State + MOVQ rounds+64(FP), Rounds + MOVQ src_len+32(FP), Len + + VMOVDQU 0*16(State), X0 + VMOVDQU 1*16(State), X1 + VMOVDQU 2*16(State), X2 + VMOVDQU 3*16(State), X3 + + MOVQ Stack, SavedSP + ADDQ $16, Stack + ANDQ $-16, Stack + + TESTQ Len, Len + JZ DONE + + VMOVDQU ·one<>(SB), X4 + VMOVDQU ·rol16<>(SB), X5 + VMOVDQU ·rol8<>(SB), X6 + VMOVDQA X0, 0*16(Stack) + VMOVDQA X1, 1*16(Stack) + VMOVDQA X2, 2*16(Stack) + VMOVDQA X3, 3*16(Stack) + VMOVDQA X4, 4*16(Stack) + VMOVDQA X5, 6*16(Stack) + VMOVDQA X6, 7*16(Stack) + + CMPQ Len, $64 + JLE GENERATE_KEYSTREAM_64 + CMPQ Len, $128 + JLE GENERATE_KEYSTREAM_128 + CMPQ Len, $192 + JLE GENERATE_KEYSTREAM_192 + +GENERATE_KEYSTREAM_256: + VMOVDQA X0, X12 + VMOVDQA X1, X13 + VMOVDQA X2, X14 + VMOVDQA X3, X15 + VPADDQ 4*16(Stack), X15, X15 + VMOVDQA X0, X8 + VMOVDQA X1, X9 + VMOVDQA X2, X10 + VMOVDQA X15, X11 + VPADDQ 4*16(Stack), X11, X11 + VMOVDQA X0, X4 + VMOVDQA X1, X5 + VMOVDQA X2, X6 + VMOVDQA X11, X7 + VPADDQ 4*16(Stack), X7, X7 + MOVQ Rounds, Tmp0 + + VMOVDQA X3, 3*16(Stack) // Save X3 + +CHACHA_LOOP_256: + VMOVDQA X4, 5*16(Stack) + CHACHA_QROUND_AVX(X0, X1, X2, X3, X4, 6*16(Stack), 7*16(Stack)) + CHACHA_QROUND_AVX(X12, X13, X14, X15, X4, 6*16(Stack), 7*16(Stack)) + VMOVDQA 5*16(Stack), X4 + VMOVDQA X0, 5*16(Stack) + CHACHA_QROUND_AVX(X8, X9, X10, X11, X0, 6*16(Stack), 7*16(Stack)) + CHACHA_QROUND_AVX(X4, X5, X6, X7, X0, 6*16(Stack), 7*16(Stack)) + VMOVDQA 5*16(Stack), X0 + CHACHA_SHUFFLE_AVX(X1, X2, X3) + CHACHA_SHUFFLE_AVX(X13, X14, X15) + CHACHA_SHUFFLE_AVX(X9, X10, X11) + CHACHA_SHUFFLE_AVX(X5, X6, X7) + VMOVDQA X4, 5*16(Stack) + CHACHA_QROUND_AVX(X0, X1, X2, X3, X4, 6*16(Stack), 7*16(Stack)) + CHACHA_QROUND_AVX(X12, X13, X14, X15, X4, 6*16(Stack), 7*16(Stack)) + VMOVDQA 5*16(Stack), X4 + VMOVDQA X0, 5*16(Stack) + CHACHA_QROUND_AVX(X8, X9, X10, X11, X0, 6*16(Stack), 7*16(Stack)) + CHACHA_QROUND_AVX(X4, X5, X6, X7, X0, 6*16(Stack), 7*16(Stack)) + VMOVDQA 5*16(Stack), X0 + CHACHA_SHUFFLE_AVX(X3, X2, X1) + CHACHA_SHUFFLE_AVX(X15, X14, X13) + CHACHA_SHUFFLE_AVX(X11, X10, X9) + CHACHA_SHUFFLE_AVX(X7, X6, X5) + SUBQ $2, Tmp0 + JNZ CHACHA_LOOP_256 + + VPADDD 0*16(Stack), X0, X0 + VPADDD 1*16(Stack), X1, X1 + VPADDD 2*16(Stack), X2, X2 + VPADDD 3*16(Stack), X3, X3 + VMOVDQA X4, 5*16(Stack) // Save X4 + XOR_AVX(Dst, Src, 0, X0, X1, X2, X3, X4) + VMOVDQA 5*16(Stack), X4 // Restore X4 + + VMOVDQA 0*16(Stack), X0 + VMOVDQA 1*16(Stack), X1 + VMOVDQA 2*16(Stack), X2 + VMOVDQA 3*16(Stack), X3 + VPADDQ 4*16(Stack), X3, X3 + + VPADDD X0, X12, X12 + VPADDD X1, X13, X13 + VPADDD X2, X14, X14 + VPADDD X3, X15, X15 + VPADDQ 4*16(Stack), X3, X3 + VPADDD X0, X8, X8 + VPADDD X1, X9, X9 + VPADDD X2, X10, X10 + VPADDD X3, X11, X11 + VPADDQ 4*16(Stack), X3, X3 + VPADDD X0, X4, X4 + VPADDD X1, X5, X5 + VPADDD X2, X6, X6 + VPADDD X3, X7, X7 + VPADDQ 4*16(Stack), X3, X3 + + XOR_AVX(Dst, Src, 64, X12, X13, X14, X15, X0) + XOR_AVX(Dst, Src, 128, X8, X9, X10, X11, X0) + VMOVDQA 0*16(Stack), X0 // Restore X0 + ADDQ $192, Dst + ADDQ $192, Src + SUBQ $192, Len + + CMPQ Len, $64 + JL BUFFER_KEYSTREAM + + XOR_AVX(Dst, Src, 0, X4, X5, X6, X7, X8) + ADDQ $64, Dst + ADDQ $64, Src + SUBQ $64, Len + JZ DONE + CMPQ Len, $64 // If Len <= 64 -> gen. only 64 byte keystream. + JLE GENERATE_KEYSTREAM_64 + CMPQ Len, $128 // If 64 < Len <= 128 -> gen. only 128 byte keystream. + JLE GENERATE_KEYSTREAM_128 + CMPQ Len, $192 // If Len > 192 -> repeat, otherwise Len > 128 && Len <= 192 -> gen. 192 byte keystream + JG GENERATE_KEYSTREAM_256 + +GENERATE_KEYSTREAM_192: + VMOVDQA X0, X12 + VMOVDQA X1, X13 + VMOVDQA X2, X14 + VMOVDQA X3, X15 + VMOVDQA X0, X8 + VMOVDQA X1, X9 + VMOVDQA X2, X10 + VMOVDQA X3, X11 + VPADDQ 4*16(Stack), X11, X11 + VMOVDQA X0, X4 + VMOVDQA X1, X5 + VMOVDQA X2, X6 + VMOVDQA X11, X7 + VPADDQ 4*16(Stack), X7, X7 + MOVQ Rounds, Tmp0 + + VMOVDQA 6*16(Stack), X1 // Load 16 bit rotate-left constant + VMOVDQA 7*16(Stack), X2 // Load 8 bit rotate-left constant + +CHACHA_LOOP_192: + CHACHA_QROUND_AVX(X12, X13, X14, X15, X0, X1, X2) + CHACHA_QROUND_AVX(X8, X9, X10, X11, X0, X1, X2) + CHACHA_QROUND_AVX(X4, X5, X6, X7, X0, X1, X2) + CHACHA_SHUFFLE_AVX(X13, X14, X15) + CHACHA_SHUFFLE_AVX(X9, X10, X11) + CHACHA_SHUFFLE_AVX(X5, X6, X7) + CHACHA_QROUND_AVX(X12, X13, X14, X15, X0, X1, X2) + CHACHA_QROUND_AVX(X8, X9, X10, X11, X0, X1, X2) + CHACHA_QROUND_AVX(X4, X5, X6, X7, X0, X1, X2) + CHACHA_SHUFFLE_AVX(X15, X14, X13) + CHACHA_SHUFFLE_AVX(X11, X10, X9) + CHACHA_SHUFFLE_AVX(X7, X6, X5) + SUBQ $2, Tmp0 + JNZ CHACHA_LOOP_192 + + VMOVDQA 0*16(Stack), X0 // Restore X0 + VMOVDQA 1*16(Stack), X1 // Restore X1 + VMOVDQA 2*16(Stack), X2 // Restore X2 + VPADDD X0, X12, X12 + VPADDD X1, X13, X13 + VPADDD X2, X14, X14 + VPADDD X3, X15, X15 + VPADDQ 4*16(Stack), X3, X3 + VPADDD X0, X8, X8 + VPADDD X1, X9, X9 + VPADDD X2, X10, X10 + VPADDD X3, X11, X11 + VPADDQ 4*16(Stack), X3, X3 + VPADDD X0, X4, X4 + VPADDD X1, X5, X5 + VPADDD X2, X6, X6 + VPADDD X3, X7, X7 + VPADDQ 4*16(Stack), X3, X3 + + XOR_AVX(Dst, Src, 0, X12, X13, X14, X15, X0) + XOR_AVX(Dst, Src, 64, X8, X9, X10, X11, X0) + VMOVDQA 0*16(Stack), X0 // Restore X0 + ADDQ $128, Dst + ADDQ $128, Src + SUBQ $128, Len + + CMPQ Len, $64 + JL BUFFER_KEYSTREAM + + XOR_AVX(Dst, Src, 0, X4, X5, X6, X7, X8) + ADDQ $64, Dst + ADDQ $64, Src + SUBQ $64, Len + JZ DONE + CMPQ Len, $64 // If Len <= 64 -> gen. only 64 byte keystream. + JLE GENERATE_KEYSTREAM_64 + +GENERATE_KEYSTREAM_128: + VMOVDQA X0, X8 + VMOVDQA X1, X9 + VMOVDQA X2, X10 + VMOVDQA X3, X11 + VMOVDQA X0, X4 + VMOVDQA X1, X5 + VMOVDQA X2, X6 + VMOVDQA X3, X7 + VPADDQ 4*16(Stack), X7, X7 + MOVQ Rounds, Tmp0 + + VMOVDQA 6*16(Stack), X13 // Load 16 bit rotate-left constant + VMOVDQA 7*16(Stack), X14 // Load 8 bit rotate-left constant + +CHACHA_LOOP_128: + CHACHA_QROUND_AVX(X8, X9, X10, X11, X12, X13, X14) + CHACHA_QROUND_AVX(X4, X5, X6, X7, X12, X13, X14) + CHACHA_SHUFFLE_AVX(X9, X10, X11) + CHACHA_SHUFFLE_AVX(X5, X6, X7) + CHACHA_QROUND_AVX(X8, X9, X10, X11, X12, X13, X14) + CHACHA_QROUND_AVX(X4, X5, X6, X7, X12, X13, X14) + CHACHA_SHUFFLE_AVX(X11, X10, X9) + CHACHA_SHUFFLE_AVX(X7, X6, X5) + SUBQ $2, Tmp0 + JNZ CHACHA_LOOP_128 + + VPADDD X0, X8, X8 + VPADDD X1, X9, X9 + VPADDD X2, X10, X10 + VPADDD X3, X11, X11 + VPADDQ 4*16(Stack), X3, X3 + VPADDD X0, X4, X4 + VPADDD X1, X5, X5 + VPADDD X2, X6, X6 + VPADDD X3, X7, X7 + VPADDQ 4*16(Stack), X3, X3 + + XOR_AVX(Dst, Src, 0, X8, X9, X10, X11, X12) + ADDQ $64, Dst + ADDQ $64, Src + SUBQ $64, Len + + CMPQ Len, $64 + JL BUFFER_KEYSTREAM + + XOR_AVX(Dst, Src, 0, X4, X5, X6, X7, X8) + ADDQ $64, Dst + ADDQ $64, Src + SUBQ $64, Len + JZ DONE // If Len == 0 -> DONE, otherwise Len <= 64 -> gen 64 byte keystream + +GENERATE_KEYSTREAM_64: + VMOVDQA X0, X4 + VMOVDQA X1, X5 + VMOVDQA X2, X6 + VMOVDQA X3, X7 + MOVQ Rounds, Tmp0 + + VMOVDQA 6*16(Stack), X9 // Load 16 bit rotate-left constant + VMOVDQA 7*16(Stack), X10 // Load 8 bit rotate-left constant + +CHACHA_LOOP_64: + CHACHA_QROUND_AVX(X4, X5, X6, X7, X8, X9, X10) + CHACHA_SHUFFLE_AVX(X5, X6, X7) + CHACHA_QROUND_AVX(X4, X5, X6, X7, X8, X9, X10) + CHACHA_SHUFFLE_AVX(X7, X6, X5) + SUBQ $2, Tmp0 + JNZ CHACHA_LOOP_64 + + VPADDD X0, X4, X4 + VPADDD X1, X5, X5 + VPADDD X2, X6, X6 + VPADDD X3, X7, X7 + VPADDQ 4*16(Stack), X3, X3 + + CMPQ Len, $64 + JL BUFFER_KEYSTREAM + + XOR_AVX(Dst, Src, 0, X4, X5, X6, X7, X8) + ADDQ $64, Src + ADDQ $64, Dst + SUBQ $64, Len + JMP DONE // jump directly to DONE - there is no keystream to buffer, Len == 0 always true. + +BUFFER_KEYSTREAM: + VMOVDQU X4, 0*16(Buffer) + VMOVDQU X5, 1*16(Buffer) + VMOVDQU X6, 2*16(Buffer) + VMOVDQU X7, 3*16(Buffer) + MOVQ Len, Tmp0 + FINALIZE(Dst, Src, Buffer, Tmp0, Tmp1, Tmp2) + +DONE: + MOVQ SavedSP, Stack // Restore stack pointer + VMOVDQU X3, 3*16(State) + VZEROUPPER + MOVQ Len, ret+72(FP) + RET + +#undef Dst +#undef Src +#undef Len +#undef Rounds +#undef Buffer +#undef State +#undef Stack +#undef SavedSP +#undef Tmp0 +#undef Tmp1 +#undef Tmp2 diff --git a/vendor/github.com/aead/chacha20/chacha/chacha_go16_amd64.go b/vendor/github.com/aead/chacha20/chacha/chacha_go16_amd64.go deleted file mode 100644 index 0dcb3027..00000000 --- a/vendor/github.com/aead/chacha20/chacha/chacha_go16_amd64.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) 2017 Andreas Auernhammer. All rights reserved. -// Use of this source code is governed by a license that can be -// found in the LICENSE file. - -// +build amd64,!gccgo,!appengine,!nacl,!go1.7 - -package chacha - -func init() { - useSSE2 = true - useSSSE3 = supportsSSSE3() - useAVX2 = false -} - -// This function is implemented in chacha_amd64.s -//go:noescape -func initialize(state *[64]byte, key []byte, nonce *[16]byte) - -// This function is implemented in chacha_amd64.s -//go:noescape -func supportsSSSE3() bool - -// This function is implemented in chacha_amd64.s -//go:noescape -func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) - -// This function is implemented in chacha_amd64.s -//go:noescape -func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte) - -// This function is implemented in chacha_amd64.s -//go:noescape -func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int - -// This function is implemented in chacha_amd64.s -//go:noescape -func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int - -func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { - if useSSSE3 { - hChaCha20SSSE3(out, nonce, key) - } else if useSSE2 { // on amd64 this is always true - used to test generic on amd64 - hChaCha20SSE2(out, nonce, key) - } else { - hChaCha20Generic(out, nonce, key) - } -} - -func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int { - if useSSSE3 { - return xorKeyStreamSSSE3(dst, src, block, state, rounds) - } else if useSSE2 { // on amd64 this is always true - used to test generic on amd64 - return xorKeyStreamSSE2(dst, src, block, state, rounds) - } - return xorKeyStreamGeneric(dst, src, block, state, rounds) -} diff --git a/vendor/github.com/aead/chacha20/chacha/chacha_ref.go b/vendor/github.com/aead/chacha20/chacha/chacha_ref.go index 2c95a0c8..526877c2 100644 --- a/vendor/github.com/aead/chacha20/chacha/chacha_ref.go +++ b/vendor/github.com/aead/chacha20/chacha/chacha_ref.go @@ -8,6 +8,13 @@ package chacha import "encoding/binary" +func init() { + useSSE2 = false + useSSSE3 = false + useAVX = false + useAVX2 = false +} + func initialize(state *[64]byte, key []byte, nonce *[16]byte) { binary.LittleEndian.PutUint32(state[0:], sigma[0]) binary.LittleEndian.PutUint32(state[4:], sigma[1]) diff --git a/vendor/github.com/aead/chacha20/chacha/const.s b/vendor/github.com/aead/chacha20/chacha/const.s new file mode 100644 index 00000000..c7a94a47 --- /dev/null +++ b/vendor/github.com/aead/chacha20/chacha/const.s @@ -0,0 +1,53 @@ +// Copyright (c) 2018 Andreas Auernhammer. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +// +build 386,!gccgo,!appengine,!nacl amd64,!gccgo,!appengine,!nacl + +#include "textflag.h" + +DATA ·sigma<>+0x00(SB)/4, $0x61707865 +DATA ·sigma<>+0x04(SB)/4, $0x3320646e +DATA ·sigma<>+0x08(SB)/4, $0x79622d32 +DATA ·sigma<>+0x0C(SB)/4, $0x6b206574 +GLOBL ·sigma<>(SB), (NOPTR+RODATA), $16 // The 4 ChaCha initialization constants + +// SSE2/SSE3/AVX constants + +DATA ·one<>+0x00(SB)/8, $1 +DATA ·one<>+0x08(SB)/8, $0 +GLOBL ·one<>(SB), (NOPTR+RODATA), $16 // The constant 1 as 128 bit value + +DATA ·rol16<>+0x00(SB)/8, $0x0504070601000302 +DATA ·rol16<>+0x08(SB)/8, $0x0D0C0F0E09080B0A +GLOBL ·rol16<>(SB), (NOPTR+RODATA), $16 // The PSHUFB 16 bit left rotate constant + +DATA ·rol8<>+0x00(SB)/8, $0x0605040702010003 +DATA ·rol8<>+0x08(SB)/8, $0x0E0D0C0F0A09080B +GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16 // The PSHUFB 8 bit left rotate constant + +// AVX2 constants + +DATA ·one_AVX2<>+0x00(SB)/8, $0 +DATA ·one_AVX2<>+0x08(SB)/8, $0 +DATA ·one_AVX2<>+0x10(SB)/8, $1 +DATA ·one_AVX2<>+0x18(SB)/8, $0 +GLOBL ·one_AVX2<>(SB), (NOPTR+RODATA), $32 // The constant 1 as 256 bit value + +DATA ·two_AVX2<>+0x00(SB)/8, $2 +DATA ·two_AVX2<>+0x08(SB)/8, $0 +DATA ·two_AVX2<>+0x10(SB)/8, $2 +DATA ·two_AVX2<>+0x18(SB)/8, $0 +GLOBL ·two_AVX2<>(SB), (NOPTR+RODATA), $32 + +DATA ·rol16_AVX2<>+0x00(SB)/8, $0x0504070601000302 +DATA ·rol16_AVX2<>+0x08(SB)/8, $0x0D0C0F0E09080B0A +DATA ·rol16_AVX2<>+0x10(SB)/8, $0x0504070601000302 +DATA ·rol16_AVX2<>+0x18(SB)/8, $0x0D0C0F0E09080B0A +GLOBL ·rol16_AVX2<>(SB), (NOPTR+RODATA), $32 // The VPSHUFB 16 bit left rotate constant + +DATA ·rol8_AVX2<>+0x00(SB)/8, $0x0605040702010003 +DATA ·rol8_AVX2<>+0x08(SB)/8, $0x0E0D0C0F0A09080B +DATA ·rol8_AVX2<>+0x10(SB)/8, $0x0605040702010003 +DATA ·rol8_AVX2<>+0x18(SB)/8, $0x0E0D0C0F0A09080B +GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32 // The VPSHUFB 8 bit left rotate constant diff --git a/vendor/github.com/aead/chacha20/chacha/macro.s b/vendor/github.com/aead/chacha20/chacha/macro.s new file mode 100644 index 00000000..780108f8 --- /dev/null +++ b/vendor/github.com/aead/chacha20/chacha/macro.s @@ -0,0 +1,163 @@ +// Copyright (c) 2018 Andreas Auernhammer. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +// +build 386,!gccgo,!appengine,!nacl amd64,!gccgo,!appengine,!nacl + +// ROTL_SSE rotates all 4 32 bit values of the XMM register v +// left by n bits using SSE2 instructions (0 <= n <= 32). +// The XMM register t is used as a temp. register. +#define ROTL_SSE(n, t, v) \ + MOVO v, t; \ + PSLLL $n, t; \ + PSRLL $(32-n), v; \ + PXOR t, v + +// ROTL_AVX rotates all 4/8 32 bit values of the AVX/AVX2 register v +// left by n bits using AVX/AVX2 instructions (0 <= n <= 32). +// The AVX/AVX2 register t is used as a temp. register. +#define ROTL_AVX(n, t, v) \ + VPSLLD $n, v, t; \ + VPSRLD $(32-n), v, v; \ + VPXOR v, t, v + +// CHACHA_QROUND_SSE2 performs a ChaCha quarter-round using the +// 4 XMM registers v0, v1, v2 and v3. It uses only ROTL_SSE2 for +// rotations. The XMM register t is used as a temp. register. +#define CHACHA_QROUND_SSE2(v0, v1, v2, v3, t) \ + PADDL v1, v0; \ + PXOR v0, v3; \ + ROTL_SSE(16, t, v3); \ + PADDL v3, v2; \ + PXOR v2, v1; \ + ROTL_SSE(12, t, v1); \ + PADDL v1, v0; \ + PXOR v0, v3; \ + ROTL_SSE(8, t, v3); \ + PADDL v3, v2; \ + PXOR v2, v1; \ + ROTL_SSE(7, t, v1) + +// CHACHA_QROUND_SSSE3 performs a ChaCha quarter-round using the +// 4 XMM registers v0, v1, v2 and v3. It uses PSHUFB for 8/16 bit +// rotations. The XMM register t is used as a temp. register. +// +// r16 holds the PSHUFB constant for a 16 bit left rotate. +// r8 holds the PSHUFB constant for a 8 bit left rotate. +#define CHACHA_QROUND_SSSE3(v0, v1, v2, v3, t, r16, r8) \ + PADDL v1, v0; \ + PXOR v0, v3; \ + PSHUFB r16, v3; \ + PADDL v3, v2; \ + PXOR v2, v1; \ + ROTL_SSE(12, t, v1); \ + PADDL v1, v0; \ + PXOR v0, v3; \ + PSHUFB r8, v3; \ + PADDL v3, v2; \ + PXOR v2, v1; \ + ROTL_SSE(7, t, v1) + +// CHACHA_QROUND_AVX performs a ChaCha quarter-round using the +// 4 AVX/AVX2 registers v0, v1, v2 and v3. It uses VPSHUFB for 8/16 bit +// rotations. The AVX/AVX2 register t is used as a temp. register. +// +// r16 holds the VPSHUFB constant for a 16 bit left rotate. +// r8 holds the VPSHUFB constant for a 8 bit left rotate. +#define CHACHA_QROUND_AVX(v0, v1, v2, v3, t, r16, r8) \ + VPADDD v0, v1, v0; \ + VPXOR v3, v0, v3; \ + VPSHUFB r16, v3, v3; \ + VPADDD v2, v3, v2; \ + VPXOR v1, v2, v1; \ + ROTL_AVX(12, t, v1); \ + VPADDD v0, v1, v0; \ + VPXOR v3, v0, v3; \ + VPSHUFB r8, v3, v3; \ + VPADDD v2, v3, v2; \ + VPXOR v1, v2, v1; \ + ROTL_AVX(7, t, v1) + +// CHACHA_SHUFFLE_SSE performs a ChaCha shuffle using the +// 3 XMM registers v1, v2 and v3. The inverse shuffle is +// performed by switching v1 and v3: CHACHA_SHUFFLE_SSE(v3, v2, v1). +#define CHACHA_SHUFFLE_SSE(v1, v2, v3) \ + PSHUFL $0x39, v1, v1; \ + PSHUFL $0x4E, v2, v2; \ + PSHUFL $0x93, v3, v3 + +// CHACHA_SHUFFLE_AVX performs a ChaCha shuffle using the +// 3 AVX/AVX2 registers v1, v2 and v3. The inverse shuffle is +// performed by switching v1 and v3: CHACHA_SHUFFLE_AVX(v3, v2, v1). +#define CHACHA_SHUFFLE_AVX(v1, v2, v3) \ + VPSHUFD $0x39, v1, v1; \ + VPSHUFD $0x4E, v2, v2; \ + VPSHUFD $0x93, v3, v3 + +// XOR_SSE extracts 4x16 byte vectors from src at +// off, xors all vectors with the corresponding XMM +// register (v0 - v3) and writes the result to dst +// at off. +// The XMM register t is used as a temp. register. +#define XOR_SSE(dst, src, off, v0, v1, v2, v3, t) \ + MOVOU 0+off(src), t; \ + PXOR v0, t; \ + MOVOU t, 0+off(dst); \ + MOVOU 16+off(src), t; \ + PXOR v1, t; \ + MOVOU t, 16+off(dst); \ + MOVOU 32+off(src), t; \ + PXOR v2, t; \ + MOVOU t, 32+off(dst); \ + MOVOU 48+off(src), t; \ + PXOR v3, t; \ + MOVOU t, 48+off(dst) + +// XOR_AVX extracts 4x16 byte vectors from src at +// off, xors all vectors with the corresponding AVX +// register (v0 - v3) and writes the result to dst +// at off. +// The XMM register t is used as a temp. register. +#define XOR_AVX(dst, src, off, v0, v1, v2, v3, t) \ + VPXOR 0+off(src), v0, t; \ + VMOVDQU t, 0+off(dst); \ + VPXOR 16+off(src), v1, t; \ + VMOVDQU t, 16+off(dst); \ + VPXOR 32+off(src), v2, t; \ + VMOVDQU t, 32+off(dst); \ + VPXOR 48+off(src), v3, t; \ + VMOVDQU t, 48+off(dst) + +#define XOR_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \ + VMOVDQU (0+off)(src), t0; \ + VPERM2I128 $32, v1, v0, t1; \ + VPXOR t0, t1, t0; \ + VMOVDQU t0, (0+off)(dst); \ + VMOVDQU (32+off)(src), t0; \ + VPERM2I128 $32, v3, v2, t1; \ + VPXOR t0, t1, t0; \ + VMOVDQU t0, (32+off)(dst); \ + VMOVDQU (64+off)(src), t0; \ + VPERM2I128 $49, v1, v0, t1; \ + VPXOR t0, t1, t0; \ + VMOVDQU t0, (64+off)(dst); \ + VMOVDQU (96+off)(src), t0; \ + VPERM2I128 $49, v3, v2, t1; \ + VPXOR t0, t1, t0; \ + VMOVDQU t0, (96+off)(dst) + +#define XOR_UPPER_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \ + VMOVDQU (0+off)(src), t0; \ + VPERM2I128 $32, v1, v0, t1; \ + VPXOR t0, t1, t0; \ + VMOVDQU t0, (0+off)(dst); \ + VMOVDQU (32+off)(src), t0; \ + VPERM2I128 $32, v3, v2, t1; \ + VPXOR t0, t1, t0; \ + VMOVDQU t0, (32+off)(dst); \ + +#define EXTRACT_LOWER(dst, v0, v1, v2, v3, t0) \ + VPERM2I128 $49, v1, v0, t0; \ + VMOVDQU t0, 0(dst); \ + VPERM2I128 $49, v3, v2, t0; \ + VMOVDQU t0, 32(dst) diff --git a/vendor/github.com/bifurcation/mint/LICENSE.md b/vendor/github.com/bifurcation/mint/LICENSE.md deleted file mode 100644 index 63858124..00000000 --- a/vendor/github.com/bifurcation/mint/LICENSE.md +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2016 Richard Barnes - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. diff --git a/vendor/github.com/bifurcation/mint/alert.go b/vendor/github.com/bifurcation/mint/alert.go deleted file mode 100644 index 5e31035a..00000000 --- a/vendor/github.com/bifurcation/mint/alert.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package mint - -import "strconv" - -type Alert uint8 - -const ( - // alert level - AlertLevelWarning = 1 - AlertLevelError = 2 -) - -const ( - AlertCloseNotify Alert = 0 - AlertUnexpectedMessage Alert = 10 - AlertBadRecordMAC Alert = 20 - AlertDecryptionFailed Alert = 21 - AlertRecordOverflow Alert = 22 - AlertDecompressionFailure Alert = 30 - AlertHandshakeFailure Alert = 40 - AlertBadCertificate Alert = 42 - AlertUnsupportedCertificate Alert = 43 - AlertCertificateRevoked Alert = 44 - AlertCertificateExpired Alert = 45 - AlertCertificateUnknown Alert = 46 - AlertIllegalParameter Alert = 47 - AlertUnknownCA Alert = 48 - AlertAccessDenied Alert = 49 - AlertDecodeError Alert = 50 - AlertDecryptError Alert = 51 - AlertProtocolVersion Alert = 70 - AlertInsufficientSecurity Alert = 71 - AlertInternalError Alert = 80 - AlertInappropriateFallback Alert = 86 - AlertUserCanceled Alert = 90 - AlertNoRenegotiation Alert = 100 - AlertMissingExtension Alert = 109 - AlertUnsupportedExtension Alert = 110 - AlertCertificateUnobtainable Alert = 111 - AlertUnrecognizedName Alert = 112 - AlertBadCertificateStatsResponse Alert = 113 - AlertBadCertificateHashValue Alert = 114 - AlertUnknownPSKIdentity Alert = 115 - AlertNoApplicationProtocol Alert = 120 - AlertWouldBlock Alert = 254 - AlertNoAlert Alert = 255 -) - -var alertText = map[Alert]string{ - AlertCloseNotify: "close notify", - AlertUnexpectedMessage: "unexpected message", - AlertBadRecordMAC: "bad record MAC", - AlertDecryptionFailed: "decryption failed", - AlertRecordOverflow: "record overflow", - AlertDecompressionFailure: "decompression failure", - AlertHandshakeFailure: "handshake failure", - AlertBadCertificate: "bad certificate", - AlertUnsupportedCertificate: "unsupported certificate", - AlertCertificateRevoked: "revoked certificate", - AlertCertificateExpired: "expired certificate", - AlertCertificateUnknown: "unknown certificate", - AlertIllegalParameter: "illegal parameter", - AlertUnknownCA: "unknown certificate authority", - AlertAccessDenied: "access denied", - AlertDecodeError: "error decoding message", - AlertDecryptError: "error decrypting message", - AlertProtocolVersion: "protocol version not supported", - AlertInsufficientSecurity: "insufficient security level", - AlertInternalError: "internal error", - AlertInappropriateFallback: "inappropriate fallback", - AlertUserCanceled: "user canceled", - AlertMissingExtension: "missing extension", - AlertUnsupportedExtension: "unsupported extension", - AlertCertificateUnobtainable: "certificate unobtainable", - AlertUnrecognizedName: "unrecognized name", - AlertBadCertificateStatsResponse: "bad certificate status response", - AlertBadCertificateHashValue: "bad certificate hash value", - AlertUnknownPSKIdentity: "unknown PSK identity", - AlertNoApplicationProtocol: "no application protocol", - AlertNoRenegotiation: "no renegotiation", - AlertWouldBlock: "would have blocked", - AlertNoAlert: "no alert", -} - -func (e Alert) String() string { - s, ok := alertText[e] - if ok { - return s - } - return "alert(" + strconv.Itoa(int(e)) + ")" -} - -func (e Alert) Error() string { - return e.String() -} diff --git a/vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go b/vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go deleted file mode 100644 index 4efe2f55..00000000 --- a/vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go +++ /dev/null @@ -1,42 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "io/ioutil" - "net" - "net/http" - "os" - - "github.com/bifurcation/mint" -) - -var url string - -func main() { - url := flag.String("url", "https://localhost:4430", "URL to send request") - flag.Parse() - mintdial := func(network, addr string) (net.Conn, error) { - return mint.Dial(network, addr, nil) - } - - tr := &http.Transport{ - DialTLS: mintdial, - DisableCompression: true, - } - client := &http.Client{Transport: tr} - - response, err := client.Get(*url) - if err != nil { - fmt.Println("err:", err) - return - } - defer response.Body.Close() - - contents, err := ioutil.ReadAll(response.Body) - if err != nil { - fmt.Printf("%s", err) - os.Exit(1) - } - fmt.Printf("%s\n", string(contents)) -} diff --git a/vendor/github.com/bifurcation/mint/bin/mint-client/main.go b/vendor/github.com/bifurcation/mint/bin/mint-client/main.go deleted file mode 100644 index 27b0f253..00000000 --- a/vendor/github.com/bifurcation/mint/bin/mint-client/main.go +++ /dev/null @@ -1,37 +0,0 @@ -package main - -import ( - "flag" - "fmt" - - "github.com/bifurcation/mint" -) - -var addr string - -func main() { - flag.StringVar(&addr, "addr", "localhost:4430", "port") - flag.Parse() - - conn, err := mint.Dial("tcp", addr, nil) - - if err != nil { - fmt.Println("TLS handshake failed:", err) - return - } - - request := "GET / HTTP/1.0\r\n\r\n" - conn.Write([]byte(request)) - - response := "" - buffer := make([]byte, 1024) - var read int - for err == nil { - read, err = conn.Read(buffer) - fmt.Println(" ~~ read: ", read) - response += string(buffer) - } - fmt.Println("err:", err) - fmt.Println("Received from server:") - fmt.Println(response) -} diff --git a/vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go b/vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go deleted file mode 100644 index 7ac0e60e..00000000 --- a/vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go +++ /dev/null @@ -1,226 +0,0 @@ -package main - -import ( - "bytes" - "crypto" - "crypto/ecdsa" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "flag" - "fmt" - "io/ioutil" - "log" - "net/http" - - "github.com/bifurcation/mint" - "golang.org/x/net/http2" -) - -var ( - port string - serverName string - certFile string - keyFile string - responseFile string - h2 bool - sendTickets bool -) - -type responder []byte - -func (rsp responder) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.Write(rsp) -} - -// ParsePrivateKeyDER parses a PKCS #1, PKCS #8, or elliptic curve -// PEM-encoded private key. -// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module -func ParsePrivateKeyPEM(keyPEM []byte) (key crypto.Signer, err error) { - keyDER, _ := pem.Decode(keyPEM) - if keyDER == nil { - return nil, err - } - - generalKey, err := x509.ParsePKCS8PrivateKey(keyDER.Bytes) - if err != nil { - generalKey, err = x509.ParsePKCS1PrivateKey(keyDER.Bytes) - if err != nil { - generalKey, err = x509.ParseECPrivateKey(keyDER.Bytes) - if err != nil { - // We don't include the actual error into - // the final error. The reason might be - // we don't want to leak any info about - // the private key. - return nil, fmt.Errorf("No successful private key decoder") - } - } - } - - switch generalKey.(type) { - case *rsa.PrivateKey: - return generalKey.(*rsa.PrivateKey), nil - case *ecdsa.PrivateKey: - return generalKey.(*ecdsa.PrivateKey), nil - } - - // should never reach here - return nil, fmt.Errorf("Should be unreachable") -} - -// ParseOneCertificateFromPEM attempts to parse one PEM encoded certificate object, -// either a raw x509 certificate or a PKCS #7 structure possibly containing -// multiple certificates, from the top of certsPEM, which itself may -// contain multiple PEM encoded certificate objects. -// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module -func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, error) { - block, rest := pem.Decode(certsPEM) - if block == nil { - return nil, rest, nil - } - - cert, err := x509.ParseCertificate(block.Bytes) - var certs = []*x509.Certificate{cert} - return certs, rest, err -} - -// ParseCertificatesPEM parses a sequence of PEM-encoded certificate and returns them, -// can handle PEM encoded PKCS #7 structures. -// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module -func ParseCertificatesPEM(certsPEM []byte) ([]*x509.Certificate, error) { - var certs []*x509.Certificate - var err error - certsPEM = bytes.TrimSpace(certsPEM) - for len(certsPEM) > 0 { - var cert []*x509.Certificate - cert, certsPEM, err = ParseOneCertificateFromPEM(certsPEM) - if err != nil { - return nil, err - } else if cert == nil { - break - } - - certs = append(certs, cert...) - } - if len(certsPEM) > 0 { - return nil, fmt.Errorf("Trailing PEM data") - } - return certs, nil -} - -func main() { - flag.StringVar(&port, "port", "4430", "port") - flag.StringVar(&serverName, "host", "example.com", "hostname") - flag.StringVar(&certFile, "cert", "", "certificate chain in PEM or DER") - flag.StringVar(&keyFile, "key", "", "private key in PEM format") - flag.StringVar(&responseFile, "response", "", "file to serve") - flag.BoolVar(&h2, "h2", false, "whether to use HTTP/2 (exclusively)") - flag.BoolVar(&sendTickets, "tickets", true, "whether to send session tickets") - flag.Parse() - - var certChain []*x509.Certificate - var priv crypto.Signer - var response []byte - var err error - - // Load the key and certificate chain - if certFile != "" { - certs, err := ioutil.ReadFile(certFile) - if err != nil { - log.Fatalf("Error: %v", err) - } else { - certChain, err = ParseCertificatesPEM(certs) - if err != nil { - certChain, err = x509.ParseCertificates(certs) - if err != nil { - log.Fatalf("Error parsing certificates: %v", err) - } - } - } - } - if keyFile != "" { - keyPEM, err := ioutil.ReadFile(keyFile) - if err != nil { - log.Fatalf("Error: %v", err) - } else { - priv, err = ParsePrivateKeyPEM(keyPEM) - if priv == nil || err != nil { - log.Fatalf("Error parsing private key: %v", err) - } - } - } - if err != nil { - log.Fatalf("Error: %v", err) - } - - // Load response file - if responseFile != "" { - log.Printf("Loading response file: %v", responseFile) - response, err = ioutil.ReadFile(responseFile) - if err != nil { - log.Fatalf("Error: %v", err) - } - } else { - response = []byte("Welcome to the TLS 1.3 zone!") - } - handler := responder(response) - - config := mint.Config{ - SendSessionTickets: true, - ServerName: serverName, - NextProtos: []string{"http/1.1"}, - } - - if h2 { - config.NextProtos = []string{"h2"} - } - - config.SendSessionTickets = sendTickets - - if certChain != nil && priv != nil { - log.Printf("Loading cert: %v key: %v", certFile, keyFile) - config.Certificates = []*mint.Certificate{ - { - Chain: certChain, - PrivateKey: priv, - }, - } - } - config.Init(false) - - service := "0.0.0.0:" + port - srv := &http.Server{Handler: handler} - - log.Printf("Listening on port %v", port) - // Need the inner loop here because the h1 server errors on a dropped connection - // Need the outer loop here because the h2 server is per-connection - for { - listener, err := mint.Listen("tcp", service, &config) - if err != nil { - log.Printf("Listen Error: %v", err) - continue - } - - if !h2 { - alert := srv.Serve(listener) - if alert != mint.AlertNoAlert { - log.Printf("Serve Error: %v", err) - } - } else { - srv2 := new(http2.Server) - opts := &http2.ServeConnOpts{ - Handler: handler, - BaseConfig: srv, - } - - for { - conn, err := listener.Accept() - if err != nil { - log.Printf("Accept error: %v", err) - continue - } - go srv2.ServeConn(conn, opts) - } - } - } -} diff --git a/vendor/github.com/bifurcation/mint/bin/mint-server/main.go b/vendor/github.com/bifurcation/mint/bin/mint-server/main.go deleted file mode 100644 index 216f8acb..00000000 --- a/vendor/github.com/bifurcation/mint/bin/mint-server/main.go +++ /dev/null @@ -1,65 +0,0 @@ -package main - -import ( - "flag" - "log" - "net" - - "github.com/bifurcation/mint" -) - -var port string - -func main() { - var config mint.Config - config.SendSessionTickets = true - config.ServerName = "localhost" - config.Init(false) - - flag.StringVar(&port, "port", "4430", "port") - flag.Parse() - - service := "0.0.0.0:" + port - listener, err := mint.Listen("tcp", service, &config) - - if err != nil { - log.Fatalf("server: listen: %s", err) - } - log.Print("server: listening") - - for { - conn, err := listener.Accept() - if err != nil { - log.Printf("server: accept: %s", err) - break - } - defer conn.Close() - log.Printf("server: accepted from %s", conn.RemoteAddr()) - go handleClient(conn) - } -} - -func handleClient(conn net.Conn) { - defer conn.Close() - buf := make([]byte, 10) - for { - log.Print("server: conn: waiting") - n, err := conn.Read(buf) - if err != nil { - if err != nil { - log.Printf("server: conn: read: %s", err) - } - break - } - - n, err = conn.Write([]byte("hello world")) - log.Printf("server: conn: wrote %d bytes", n) - - if err != nil { - log.Printf("server: write: %s", err) - break - } - break - } - log.Println("server: conn: closed") -} diff --git a/vendor/github.com/bifurcation/mint/client-state-machine.go b/vendor/github.com/bifurcation/mint/client-state-machine.go deleted file mode 100644 index 290a9303..00000000 --- a/vendor/github.com/bifurcation/mint/client-state-machine.go +++ /dev/null @@ -1,942 +0,0 @@ -package mint - -import ( - "bytes" - "crypto" - "hash" - "time" -) - -// Client State Machine -// -// START <----+ -// Send ClientHello | | Recv HelloRetryRequest -// / v | -// | WAIT_SH ---+ -// Can | | Recv ServerHello -// send | V -// early | WAIT_EE -// data | | Recv EncryptedExtensions -// | +--------+--------+ -// | Using | | Using certificate -// | PSK | v -// | | WAIT_CERT_CR -// | | Recv | | Recv CertificateRequest -// | | Certificate | v -// | | | WAIT_CERT -// | | | | Recv Certificate -// | | v v -// | | WAIT_CV -// | | | Recv CertificateVerify -// | +> WAIT_FINISHED <+ -// | | Recv Finished -// \ | -// | [Send EndOfEarlyData] -// | [Send Certificate [+ CertificateVerify]] -// | Send Finished -// Can send v -// app data --> CONNECTED -// after -// here -// -// State Instructions -// START Send(CH); [RekeyOut; SendEarlyData] -// WAIT_SH Send(CH) || RekeyIn -// WAIT_EE {} -// WAIT_CERT_CR {} -// WAIT_CERT {} -// WAIT_CV {} -// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut; -// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) - -type ClientStateStart struct { - Caps Capabilities - Opts ConnectionOptions - Params ConnectionParameters - - cookie []byte - firstClientHello *HandshakeMessage - helloRetryRequest *HandshakeMessage -} - -func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm != nil { - logf(logTypeHandshake, "[ClientStateStart] Unexpected non-nil message") - return nil, nil, AlertUnexpectedMessage - } - - // key_shares - offeredDH := map[NamedGroup][]byte{} - ks := KeyShareExtension{ - HandshakeType: HandshakeTypeClientHello, - Shares: make([]KeyShareEntry, len(state.Caps.Groups)), - } - for i, group := range state.Caps.Groups { - pub, priv, err := newKeyShare(group) - if err != nil { - logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err) - return nil, nil, AlertInternalError - } - - ks.Shares[i].Group = group - ks.Shares[i].KeyExchange = pub - offeredDH[group] = priv - } - - logf(logTypeHandshake, "opts: %+v", state.Opts) - - // supported_versions, supported_groups, signature_algorithms, server_name - sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}} - sni := ServerNameExtension(state.Opts.ServerName) - sg := SupportedGroupsExtension{Groups: state.Caps.Groups} - sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes} - - state.Params.ServerName = state.Opts.ServerName - - // Application Layer Protocol Negotiation - var alpn *ALPNExtension - if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) { - alpn = &ALPNExtension{Protocols: state.Opts.NextProtos} - } - - // Construct base ClientHello - ch := &ClientHelloBody{ - CipherSuites: state.Caps.CipherSuites, - } - _, err := prng.Read(ch.Random[:]) - if err != nil { - logf(logTypeHandshake, "[ClientStateStart] Error creating ClientHello random [%v]", err) - return nil, nil, AlertInternalError - } - for _, ext := range []ExtensionBody{&sv, &sni, &ks, &sg, &sa} { - err := ch.Extensions.Add(ext) - if err != nil { - logf(logTypeHandshake, "[ClientStateStart] Error adding extension type=[%v] [%v]", ext.Type(), err) - return nil, nil, AlertInternalError - } - } - // XXX: These optional extensions can't be folded into the above because Go - // interface-typed values are never reported as nil - if alpn != nil { - err := ch.Extensions.Add(alpn) - if err != nil { - logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err) - return nil, nil, AlertInternalError - } - } - if state.cookie != nil { - err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie}) - if err != nil { - logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err) - return nil, nil, AlertInternalError - } - } - - // Run the external extension handler. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions) - if err != nil { - logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err) - return nil, nil, AlertInternalError - } - } - - // Handle PSK and EarlyData just before transmitting, so that we can - // calculate the PSK binder value - var psk *PreSharedKeyExtension - var ed *EarlyDataExtension - var offeredPSK PreSharedKey - var earlyHash crypto.Hash - var earlySecret []byte - var clientEarlyTrafficKeys keySet - var clientHello *HandshakeMessage - if key, ok := state.Caps.PSKs.Get(state.Opts.ServerName); ok { - offeredPSK = key - - // Narrow ciphersuites to ones that match PSK hash - params, ok := cipherSuiteMap[key.CipherSuite] - if !ok { - logf(logTypeHandshake, "[ClientStateStart] PSK for unknown ciphersuite") - return nil, nil, AlertInternalError - } - - compatibleSuites := []CipherSuite{} - for _, suite := range ch.CipherSuites { - if cipherSuiteMap[suite].Hash == params.Hash { - compatibleSuites = append(compatibleSuites, suite) - } - } - ch.CipherSuites = compatibleSuites - - // Signal early data if we're going to do it - if len(state.Opts.EarlyData) > 0 { - state.Params.ClientSendingEarlyData = true - ed = &EarlyDataExtension{} - err = ch.Extensions.Add(ed) - if err != nil { - logf(logTypeHandshake, "Error adding early data extension: %v", err) - return nil, nil, AlertInternalError - } - } - - // Signal supported PSK key exchange modes - if len(state.Caps.PSKModes) == 0 { - logf(logTypeHandshake, "PSK selected, but no PSKModes") - return nil, nil, AlertInternalError - } - kem := &PSKKeyExchangeModesExtension{KEModes: state.Caps.PSKModes} - err = ch.Extensions.Add(kem) - if err != nil { - logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err) - return nil, nil, AlertInternalError - } - - // Add the shim PSK extension to the ClientHello - logf(logTypeHandshake, "Adding PSK extension with id = %x", key.Identity) - psk = &PreSharedKeyExtension{ - HandshakeType: HandshakeTypeClientHello, - Identities: []PSKIdentity{ - { - Identity: key.Identity, - ObfuscatedTicketAge: uint32(time.Since(key.ReceivedAt)/time.Millisecond) + key.TicketAgeAdd, - }, - }, - Binders: []PSKBinderEntry{ - // Note: Stub to get the length fields right - {Binder: bytes.Repeat([]byte{0x00}, params.Hash.Size())}, - }, - } - ch.Extensions.Add(psk) - - // Compute the binder key - h0 := params.Hash.New().Sum(nil) - zero := bytes.Repeat([]byte{0}, params.Hash.Size()) - - earlyHash = params.Hash - earlySecret = HkdfExtract(params.Hash, zero, key.Key) - logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret) - - binderLabel := labelExternalBinder - if key.IsResumption { - binderLabel = labelResumptionBinder - } - binderKey := deriveSecret(params, earlySecret, binderLabel, h0) - logf(logTypeCrypto, "binder key: [%d] %x", len(binderKey), binderKey) - - // Compute the binder value - trunc, err := ch.Truncated() - if err != nil { - logf(logTypeHandshake, "[ClientStateStart] Error marshaling truncated ClientHello [%v]", err) - return nil, nil, AlertInternalError - } - - truncHash := params.Hash.New() - truncHash.Write(trunc) - - binder := computeFinishedData(params, binderKey, truncHash.Sum(nil)) - - // Replace the PSK extension - psk.Binders[0].Binder = binder - ch.Extensions.Add(psk) - - // If we got here, the earlier marshal succeeded (in ch.Truncated()), so - // this one should too. - clientHello, _ = HandshakeMessageFromBody(ch) - - // Compute early traffic keys - h := params.Hash.New() - h.Write(clientHello.Marshal()) - chHash := h.Sum(nil) - - earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) - logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret) - clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret) - } else if len(state.Opts.EarlyData) > 0 { - logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK") - return nil, nil, AlertInternalError - } else { - clientHello, err = HandshakeMessageFromBody(ch) - if err != nil { - logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err) - return nil, nil, AlertInternalError - } - } - - logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]") - nextState := ClientStateWaitSH{ - Caps: state.Caps, - Opts: state.Opts, - Params: state.Params, - OfferedDH: offeredDH, - OfferedPSK: offeredPSK, - - earlySecret: earlySecret, - earlyHash: earlyHash, - - firstClientHello: state.firstClientHello, - helloRetryRequest: state.helloRetryRequest, - clientHello: clientHello, - } - - toSend := []HandshakeAction{ - SendHandshakeMessage{clientHello}, - } - if state.Params.ClientSendingEarlyData { - toSend = append(toSend, []HandshakeAction{ - RekeyOut{Label: "early", KeySet: clientEarlyTrafficKeys}, - SendEarlyData{}, - }...) - } - - return nextState, toSend, AlertNoAlert -} - -type ClientStateWaitSH struct { - Caps Capabilities - Opts ConnectionOptions - Params ConnectionParameters - OfferedDH map[NamedGroup][]byte - OfferedPSK PreSharedKey - PSK []byte - - earlySecret []byte - earlyHash crypto.Hash - - firstClientHello *HandshakeMessage - helloRetryRequest *HandshakeMessage - clientHello *HandshakeMessage -} - -func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm == nil { - logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message") - return nil, nil, AlertUnexpectedMessage - } - - bodyGeneric, err := hm.ToBody() - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err) - return nil, nil, AlertDecodeError - } - - switch body := bodyGeneric.(type) { - case *HelloRetryRequestBody: - hrr := body - - if state.helloRetryRequest != nil { - logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest") - return nil, nil, AlertUnexpectedMessage - } - - // Check that the version sent by the server is the one we support - if hrr.Version != supportedVersion { - logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version) - return nil, nil, AlertProtocolVersion - } - - // Check that the server provided a supported ciphersuite - supportedCipherSuite := false - for _, suite := range state.Caps.CipherSuites { - supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite) - } - if !supportedCipherSuite { - logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite) - return nil, nil, AlertHandshakeFailure - } - - // Narrow the supported ciphersuites to the server-provided one - state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite} - - // Handle external extensions. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions) - if err != nil { - logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) - return nil, nil, AlertInternalError - } - } - - // The only thing we know how to respond to in an HRR is the Cookie - // extension, so if there is either no Cookie extension or anything other - // than a Cookie extension, we have to fail. - serverCookie := new(CookieExtension) - foundCookie := hrr.Extensions.Find(serverCookie) - if !foundCookie || len(hrr.Extensions) != 1 { - logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions)) - return nil, nil, AlertIllegalParameter - } - - // Hash the body into a pseudo-message - // XXX: Ignoring some errors here - params := cipherSuiteMap[hrr.CipherSuite] - h := params.Hash.New() - h.Write(state.clientHello.Marshal()) - firstClientHello := &HandshakeMessage{ - msgType: HandshakeTypeMessageHash, - body: h.Sum(nil), - } - - logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]") - return ClientStateStart{ - Caps: state.Caps, - Opts: state.Opts, - cookie: serverCookie.Cookie, - firstClientHello: firstClientHello, - helloRetryRequest: hm, - }.Next(nil) - - case *ServerHelloBody: - sh := body - - // Check that the version sent by the server is the one we support - if sh.Version != supportedVersion { - logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version) - return nil, nil, AlertProtocolVersion - } - - // Check that the server provided a supported ciphersuite - supportedCipherSuite := false - for _, suite := range state.Caps.CipherSuites { - supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite) - } - if !supportedCipherSuite { - logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite) - return nil, nil, AlertHandshakeFailure - } - - // Handle external extensions. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions) - if err != nil { - logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) - return nil, nil, AlertInternalError - } - } - - // Do PSK or key agreement depending on extensions - serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello} - serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello} - - foundPSK := sh.Extensions.Find(&serverPSK) - foundKeyShare := sh.Extensions.Find(&serverKeyShare) - - if foundPSK && (serverPSK.SelectedIdentity == 0) { - state.Params.UsingPSK = true - } - - var dhSecret []byte - if foundKeyShare { - sks := serverKeyShare.Shares[0] - priv, ok := state.OfferedDH[sks.Group] - if !ok { - logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group") - return nil, nil, AlertIllegalParameter - } - - state.Params.UsingDH = true - dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv) - } - - suite := sh.CipherSuite - state.Params.CipherSuite = suite - - params, ok := cipherSuiteMap[suite] - if !ok { - logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite) - return nil, nil, AlertHandshakeFailure - } - - // Start up the handshake hash - handshakeHash := params.Hash.New() - handshakeHash.Write(state.firstClientHello.Marshal()) - handshakeHash.Write(state.helloRetryRequest.Marshal()) - handshakeHash.Write(state.clientHello.Marshal()) - handshakeHash.Write(hm.Marshal()) - - // Compute handshake secrets - zero := bytes.Repeat([]byte{0}, params.Hash.Size()) - - var earlySecret []byte - if state.Params.UsingPSK { - if params.Hash != state.earlyHash { - logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]", - state.earlyHash, suite, params.Hash) - } - - earlySecret = state.earlySecret - } else { - earlySecret = HkdfExtract(params.Hash, zero, zero) - } - - if dhSecret == nil { - dhSecret = zero - } - - h0 := params.Hash.New().Sum(nil) - h2 := handshakeHash.Sum(nil) - preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0) - handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret) - clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2) - serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2) - preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0) - masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero) - - logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret) - logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret) - logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret) - logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret) - logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) - - serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) - - logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]") - nextState := ClientStateWaitEE{ - Caps: state.Caps, - Params: state.Params, - cryptoParams: params, - handshakeHash: handshakeHash, - certificates: state.Caps.Certificates, - masterSecret: masterSecret, - clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, - serverHandshakeTrafficSecret: serverHandshakeTrafficSecret, - } - toSend := []HandshakeAction{ - RekeyIn{Label: "handshake", KeySet: serverHandshakeKeys}, - } - return nextState, toSend, AlertNoAlert - } - - logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType) - return nil, nil, AlertUnexpectedMessage -} - -type ClientStateWaitEE struct { - Caps Capabilities - AuthCertificate func(chain []CertificateEntry) error - Params ConnectionParameters - cryptoParams CipherSuiteParams - handshakeHash hash.Hash - certificates []*Certificate - masterSecret []byte - clientHandshakeTrafficSecret []byte - serverHandshakeTrafficSecret []byte -} - -func (state ClientStateWaitEE) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions { - logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - ee := EncryptedExtensionsBody{} - _, err := ee.Unmarshal(hm.body) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err) - return nil, nil, AlertDecodeError - } - - // Handle external extensions. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions) - if err != nil { - logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err) - return nil, nil, AlertInternalError - } - } - - serverALPN := ALPNExtension{} - serverEarlyData := EarlyDataExtension{} - - gotALPN := ee.Extensions.Find(&serverALPN) - state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData) - - if gotALPN && len(serverALPN.Protocols) > 0 { - state.Params.NextProto = serverALPN.Protocols[0] - } - - state.handshakeHash.Write(hm.Marshal()) - - if state.Params.UsingPSK { - logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]") - nextState := ClientStateWaitFinished{ - Params: state.Params, - cryptoParams: state.cryptoParams, - handshakeHash: state.handshakeHash, - certificates: state.certificates, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, - } - return nextState, nil, AlertNoAlert - } - - logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]") - nextState := ClientStateWaitCertCR{ - AuthCertificate: state.AuthCertificate, - Params: state.Params, - cryptoParams: state.cryptoParams, - handshakeHash: state.handshakeHash, - certificates: state.certificates, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, - } - return nextState, nil, AlertNoAlert -} - -type ClientStateWaitCertCR struct { - AuthCertificate func(chain []CertificateEntry) error - Params ConnectionParameters - cryptoParams CipherSuiteParams - handshakeHash hash.Hash - certificates []*Certificate - masterSecret []byte - clientHandshakeTrafficSecret []byte - serverHandshakeTrafficSecret []byte -} - -func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm == nil { - logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - bodyGeneric, err := hm.ToBody() - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitCertCR] Error decoding message: %v", err) - return nil, nil, AlertDecodeError - } - - state.handshakeHash.Write(hm.Marshal()) - - switch body := bodyGeneric.(type) { - case *CertificateBody: - logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]") - nextState := ClientStateWaitCV{ - AuthCertificate: state.AuthCertificate, - Params: state.Params, - cryptoParams: state.cryptoParams, - handshakeHash: state.handshakeHash, - certificates: state.certificates, - serverCertificate: body, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, - } - return nextState, nil, AlertNoAlert - - case *CertificateRequestBody: - // A certificate request in the handshake should have a zero-length context - if len(body.CertificateRequestContext) > 0 { - logf(logTypeHandshake, "[ClientStateWaitCertCR] Certificate request with non-empty context: %v", err) - return nil, nil, AlertIllegalParameter - } - - state.Params.UsingClientAuth = true - - logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]") - nextState := ClientStateWaitCert{ - AuthCertificate: state.AuthCertificate, - Params: state.Params, - cryptoParams: state.cryptoParams, - handshakeHash: state.handshakeHash, - certificates: state.certificates, - serverCertificateRequest: body, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, - } - return nextState, nil, AlertNoAlert - } - - return nil, nil, AlertUnexpectedMessage -} - -type ClientStateWaitCert struct { - AuthCertificate func(chain []CertificateEntry) error - Params ConnectionParameters - cryptoParams CipherSuiteParams - handshakeHash hash.Hash - - certificates []*Certificate - serverCertificateRequest *CertificateRequestBody - - masterSecret []byte - clientHandshakeTrafficSecret []byte - serverHandshakeTrafficSecret []byte -} - -func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm == nil || hm.msgType != HandshakeTypeCertificate { - logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - cert := &CertificateBody{} - _, err := cert.Unmarshal(hm.body) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err) - return nil, nil, AlertDecodeError - } - - state.handshakeHash.Write(hm.Marshal()) - - logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]") - nextState := ClientStateWaitCV{ - AuthCertificate: state.AuthCertificate, - Params: state.Params, - cryptoParams: state.cryptoParams, - handshakeHash: state.handshakeHash, - certificates: state.certificates, - serverCertificate: cert, - serverCertificateRequest: state.serverCertificateRequest, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, - } - return nextState, nil, AlertNoAlert -} - -type ClientStateWaitCV struct { - AuthCertificate func(chain []CertificateEntry) error - Params ConnectionParameters - cryptoParams CipherSuiteParams - handshakeHash hash.Hash - - certificates []*Certificate - serverCertificate *CertificateBody - serverCertificateRequest *CertificateRequestBody - - masterSecret []byte - clientHandshakeTrafficSecret []byte - serverHandshakeTrafficSecret []byte -} - -func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm == nil || hm.msgType != HandshakeTypeCertificateVerify { - logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - certVerify := CertificateVerifyBody{} - _, err := certVerify.Unmarshal(hm.body) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err) - return nil, nil, AlertDecodeError - } - - hcv := state.handshakeHash.Sum(nil) - logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) - - serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey - if err := certVerify.Verify(serverPublicKey, hcv); err != nil { - logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify") - return nil, nil, AlertHandshakeFailure - } - - if state.AuthCertificate != nil { - err := state.AuthCertificate(state.serverCertificate.CertificateList) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate") - return nil, nil, AlertBadCertificate - } - } else { - logf(logTypeHandshake, "[ClientStateWaitCV] WARNING: No verification of server certificate") - } - - state.handshakeHash.Write(hm.Marshal()) - - logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]") - nextState := ClientStateWaitFinished{ - Params: state.Params, - cryptoParams: state.cryptoParams, - handshakeHash: state.handshakeHash, - certificates: state.certificates, - serverCertificateRequest: state.serverCertificateRequest, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, - } - return nextState, nil, AlertNoAlert -} - -type ClientStateWaitFinished struct { - Params ConnectionParameters - cryptoParams CipherSuiteParams - handshakeHash hash.Hash - - certificates []*Certificate - serverCertificateRequest *CertificateRequestBody - - masterSecret []byte - clientHandshakeTrafficSecret []byte - serverHandshakeTrafficSecret []byte -} - -func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm == nil || hm.msgType != HandshakeTypeFinished { - logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - // Verify server's Finished - h3 := state.handshakeHash.Sum(nil) - logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3) - logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3) - - serverFinishedData := computeFinishedData(state.cryptoParams, state.serverHandshakeTrafficSecret, h3) - logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData) - - fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)} - _, err := fin.Unmarshal(hm.body) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err) - return nil, nil, AlertDecodeError - } - - if !bytes.Equal(fin.VerifyData, serverFinishedData) { - logf(logTypeHandshake, "[ClientStateWaitFinished] Server's Finished failed to verify [%x] != [%x]", - fin.VerifyData, serverFinishedData) - return nil, nil, AlertHandshakeFailure - } - - // Update the handshake hash with the Finished - state.handshakeHash.Write(hm.Marshal()) - logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(hm.Marshal()), hm.Marshal()) - h4 := state.handshakeHash.Sum(nil) - logf(logTypeCrypto, "handshake hash 4 [%d]: %x", len(h4), h4) - - // Compute traffic secrets and keys - clientTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelClientApplicationTrafficSecret, h4) - serverTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelServerApplicationTrafficSecret, h4) - logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret) - logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret) - - clientTrafficKeys := makeTrafficKeys(state.cryptoParams, clientTrafficSecret) - serverTrafficKeys := makeTrafficKeys(state.cryptoParams, serverTrafficSecret) - - exporterSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelExporterSecret, h4) - logf(logTypeCrypto, "client exporter secret: [%d] %x", len(exporterSecret), exporterSecret) - - // Assemble client's second flight - toSend := []HandshakeAction{} - - if state.Params.UsingEarlyData { - // Note: We only send EOED if the server is actually going to use the early - // data. Otherwise, it will never see it, and the transcripts will - // mismatch. - // EOED marshal is infallible - eoedm, _ := HandshakeMessageFromBody(&EndOfEarlyDataBody{}) - toSend = append(toSend, SendHandshakeMessage{eoedm}) - state.handshakeHash.Write(eoedm.Marshal()) - logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal()) - } - - clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret) - toSend = append(toSend, RekeyOut{Label: "handshake", KeySet: clientHandshakeKeys}) - - if state.Params.UsingClientAuth { - // Extract constraints from certicateRequest - schemes := SignatureAlgorithmsExtension{} - gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes) - if !gotSchemes { - logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err) - return nil, nil, AlertIllegalParameter - } - - // Select a certificate - cert, certScheme, err := CertificateSelection(nil, schemes.Algorithms, state.certificates) - if err != nil { - // XXX: Signal this to the application layer? - logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err) - - certificate := &CertificateBody{} - certm, err := HandshakeMessageFromBody(certificate) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) - return nil, nil, AlertInternalError - } - - toSend = append(toSend, SendHandshakeMessage{certm}) - state.handshakeHash.Write(certm.Marshal()) - } else { - // Create and send Certificate, CertificateVerify - certificate := &CertificateBody{ - CertificateList: make([]CertificateEntry, len(cert.Chain)), - } - for i, entry := range cert.Chain { - certificate.CertificateList[i] = CertificateEntry{CertData: entry} - } - certm, err := HandshakeMessageFromBody(certificate) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) - return nil, nil, AlertInternalError - } - - toSend = append(toSend, SendHandshakeMessage{certm}) - state.handshakeHash.Write(certm.Marshal()) - - hcv := state.handshakeHash.Sum(nil) - logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) - - certificateVerify := &CertificateVerifyBody{Algorithm: certScheme} - logf(logTypeHandshake, "Creating CertVerify: %04x %v", certScheme, state.cryptoParams.Hash) - - err = certificateVerify.Sign(cert.PrivateKey, hcv) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err) - return nil, nil, AlertInternalError - } - certvm, err := HandshakeMessageFromBody(certificateVerify) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err) - return nil, nil, AlertInternalError - } - - toSend = append(toSend, SendHandshakeMessage{certvm}) - state.handshakeHash.Write(certvm.Marshal()) - } - } - - // Compute the client's Finished message - h5 := state.handshakeHash.Sum(nil) - logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5) - - clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5) - logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData) - - fin = &FinishedBody{ - VerifyDataLen: len(clientFinishedData), - VerifyData: clientFinishedData, - } - finm, err := HandshakeMessageFromBody(fin) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err) - return nil, nil, AlertInternalError - } - - // Compute the resumption secret - state.handshakeHash.Write(finm.Marshal()) - h6 := state.handshakeHash.Sum(nil) - - resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6) - logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret) - - toSend = append(toSend, []HandshakeAction{ - SendHandshakeMessage{finm}, - RekeyIn{Label: "application", KeySet: serverTrafficKeys}, - RekeyOut{Label: "application", KeySet: clientTrafficKeys}, - }...) - - logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]") - nextState := StateConnected{ - Params: state.Params, - isClient: true, - cryptoParams: state.cryptoParams, - resumptionSecret: resumptionSecret, - clientTrafficSecret: clientTrafficSecret, - serverTrafficSecret: serverTrafficSecret, - exporterSecret: exporterSecret, - } - return nextState, toSend, AlertNoAlert -} diff --git a/vendor/github.com/bifurcation/mint/common.go b/vendor/github.com/bifurcation/mint/common.go deleted file mode 100644 index dfda7c3e..00000000 --- a/vendor/github.com/bifurcation/mint/common.go +++ /dev/null @@ -1,152 +0,0 @@ -package mint - -import ( - "fmt" - "strconv" -) - -var ( - supportedVersion uint16 = 0x7f15 // draft-21 - - // Flags for some minor compat issues - allowWrongVersionNumber = true - allowPKCS1 = true -) - -// enum {...} ContentType; -type RecordType byte - -const ( - RecordTypeAlert RecordType = 21 - RecordTypeHandshake RecordType = 22 - RecordTypeApplicationData RecordType = 23 -) - -// enum {...} HandshakeType; -type HandshakeType byte - -const ( - // Omitted: *_RESERVED - HandshakeTypeClientHello HandshakeType = 1 - HandshakeTypeServerHello HandshakeType = 2 - HandshakeTypeNewSessionTicket HandshakeType = 4 - HandshakeTypeEndOfEarlyData HandshakeType = 5 - HandshakeTypeHelloRetryRequest HandshakeType = 6 - HandshakeTypeEncryptedExtensions HandshakeType = 8 - HandshakeTypeCertificate HandshakeType = 11 - HandshakeTypeCertificateRequest HandshakeType = 13 - HandshakeTypeCertificateVerify HandshakeType = 15 - HandshakeTypeServerConfiguration HandshakeType = 17 - HandshakeTypeFinished HandshakeType = 20 - HandshakeTypeKeyUpdate HandshakeType = 24 - HandshakeTypeMessageHash HandshakeType = 254 -) - -// uint8 CipherSuite[2]; -type CipherSuite uint16 - -const ( - // XXX: Actually TLS_NULL_WITH_NULL_NULL, but we need a way to label the zero - // value for this type so that we can detect when a field is set. - CIPHER_SUITE_UNKNOWN CipherSuite = 0x0000 - TLS_AES_128_GCM_SHA256 CipherSuite = 0x1301 - TLS_AES_256_GCM_SHA384 CipherSuite = 0x1302 - TLS_CHACHA20_POLY1305_SHA256 CipherSuite = 0x1303 - TLS_AES_128_CCM_SHA256 CipherSuite = 0x1304 - TLS_AES_256_CCM_8_SHA256 CipherSuite = 0x1305 -) - -func (c CipherSuite) String() string { - switch c { - case CIPHER_SUITE_UNKNOWN: - return "unknown" - case TLS_AES_128_GCM_SHA256: - return "TLS_AES_128_GCM_SHA256" - case TLS_AES_256_GCM_SHA384: - return "TLS_AES_256_GCM_SHA384" - case TLS_CHACHA20_POLY1305_SHA256: - return "TLS_CHACHA20_POLY1305_SHA256" - case TLS_AES_128_CCM_SHA256: - return "TLS_AES_128_CCM_SHA256" - case TLS_AES_256_CCM_8_SHA256: - return "TLS_AES_256_CCM_8_SHA256" - } - // cannot use %x here, since it calls String(), leading to infinite recursion - return fmt.Sprintf("invalid CipherSuite value: 0x%s", strconv.FormatUint(uint64(c), 16)) -} - -// enum {...} SignatureScheme -type SignatureScheme uint16 - -const ( - // RSASSA-PKCS1-v1_5 algorithms - RSA_PKCS1_SHA1 SignatureScheme = 0x0201 - RSA_PKCS1_SHA256 SignatureScheme = 0x0401 - RSA_PKCS1_SHA384 SignatureScheme = 0x0501 - RSA_PKCS1_SHA512 SignatureScheme = 0x0601 - // ECDSA algorithms - ECDSA_P256_SHA256 SignatureScheme = 0x0403 - ECDSA_P384_SHA384 SignatureScheme = 0x0503 - ECDSA_P521_SHA512 SignatureScheme = 0x0603 - // RSASSA-PSS algorithms - RSA_PSS_SHA256 SignatureScheme = 0x0804 - RSA_PSS_SHA384 SignatureScheme = 0x0805 - RSA_PSS_SHA512 SignatureScheme = 0x0806 - // EdDSA algorithms - Ed25519 SignatureScheme = 0x0807 - Ed448 SignatureScheme = 0x0808 -) - -// enum {...} ExtensionType -type ExtensionType uint16 - -const ( - ExtensionTypeServerName ExtensionType = 0 - ExtensionTypeSupportedGroups ExtensionType = 10 - ExtensionTypeSignatureAlgorithms ExtensionType = 13 - ExtensionTypeALPN ExtensionType = 16 - ExtensionTypeKeyShare ExtensionType = 40 - ExtensionTypePreSharedKey ExtensionType = 41 - ExtensionTypeEarlyData ExtensionType = 42 - ExtensionTypeSupportedVersions ExtensionType = 43 - ExtensionTypeCookie ExtensionType = 44 - ExtensionTypePSKKeyExchangeModes ExtensionType = 45 - ExtensionTypeTicketEarlyDataInfo ExtensionType = 46 -) - -// enum {...} NamedGroup -type NamedGroup uint16 - -const ( - // Elliptic Curve Groups. - P256 NamedGroup = 23 - P384 NamedGroup = 24 - P521 NamedGroup = 25 - // ECDH functions. - X25519 NamedGroup = 29 - X448 NamedGroup = 30 - // Finite field groups. - FFDHE2048 NamedGroup = 256 - FFDHE3072 NamedGroup = 257 - FFDHE4096 NamedGroup = 258 - FFDHE6144 NamedGroup = 259 - FFDHE8192 NamedGroup = 260 -) - -// enum {...} PskKeyExchangeMode; -type PSKKeyExchangeMode uint8 - -const ( - PSKModeKE PSKKeyExchangeMode = 0 - PSKModeDHEKE PSKKeyExchangeMode = 1 -) - -// enum { -// update_not_requested(0), update_requested(1), (255) -// } KeyUpdateRequest; -type KeyUpdateRequest uint8 - -const ( - KeyUpdateNotRequested KeyUpdateRequest = 0 - KeyUpdateRequested KeyUpdateRequest = 1 -) diff --git a/vendor/github.com/bifurcation/mint/conn.go b/vendor/github.com/bifurcation/mint/conn.go deleted file mode 100644 index 08eb58df..00000000 --- a/vendor/github.com/bifurcation/mint/conn.go +++ /dev/null @@ -1,819 +0,0 @@ -package mint - -import ( - "crypto" - "crypto/x509" - "encoding/hex" - "fmt" - "io" - "net" - "reflect" - "sync" - "time" -) - -var WouldBlock = fmt.Errorf("Would have blocked") - -type Certificate struct { - Chain []*x509.Certificate - PrivateKey crypto.Signer -} - -type PreSharedKey struct { - CipherSuite CipherSuite - IsResumption bool - Identity []byte - Key []byte - NextProto string - ReceivedAt time.Time - ExpiresAt time.Time - TicketAgeAdd uint32 -} - -type PreSharedKeyCache interface { - Get(string) (PreSharedKey, bool) - Put(string, PreSharedKey) - Size() int -} - -type PSKMapCache map[string]PreSharedKey - -// A CookieHandler does two things: -// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest -// - validates this byte string echoed by the client in the ClientHello -type CookieHandler interface { - Generate(*Conn) ([]byte, error) - Validate(*Conn, []byte) bool -} - -func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) { - psk, ok = cache[key] - return -} - -func (cache *PSKMapCache) Put(key string, psk PreSharedKey) { - (*cache)[key] = psk -} - -func (cache PSKMapCache) Size() int { - return len(cache) -} - -// Config is the struct used to pass configuration settings to a TLS client or -// server instance. The settings for client and server are pretty different, -// but we just throw them all in here. -type Config struct { - // Client fields - ServerName string - - // Server fields - SendSessionTickets bool - TicketLifetime uint32 - TicketLen int - EarlyDataLifetime uint32 - AllowEarlyData bool - // Require the client to echo a cookie. - RequireCookie bool - // If cookies are required and no CookieHandler is set, a default cookie handler is used. - // The default cookie handler uses 32 random bytes as a cookie. - CookieHandler CookieHandler - RequireClientAuth bool - - // Shared fields - Certificates []*Certificate - AuthCertificate func(chain []CertificateEntry) error - CipherSuites []CipherSuite - Groups []NamedGroup - SignatureSchemes []SignatureScheme - NextProtos []string - PSKs PreSharedKeyCache - PSKModes []PSKKeyExchangeMode - NonBlocking bool - - // The same config object can be shared among different connections, so it - // needs its own mutex - mutex sync.RWMutex -} - -// Clone returns a shallow clone of c. It is safe to clone a Config that is -// being used concurrently by a TLS client or server. -func (c *Config) Clone() *Config { - c.mutex.Lock() - defer c.mutex.Unlock() - - return &Config{ - ServerName: c.ServerName, - - SendSessionTickets: c.SendSessionTickets, - TicketLifetime: c.TicketLifetime, - TicketLen: c.TicketLen, - EarlyDataLifetime: c.EarlyDataLifetime, - AllowEarlyData: c.AllowEarlyData, - RequireCookie: c.RequireCookie, - RequireClientAuth: c.RequireClientAuth, - - Certificates: c.Certificates, - AuthCertificate: c.AuthCertificate, - CipherSuites: c.CipherSuites, - Groups: c.Groups, - SignatureSchemes: c.SignatureSchemes, - NextProtos: c.NextProtos, - PSKs: c.PSKs, - PSKModes: c.PSKModes, - NonBlocking: c.NonBlocking, - } -} - -func (c *Config) Init(isClient bool) error { - c.mutex.Lock() - defer c.mutex.Unlock() - - // Set defaults - if len(c.CipherSuites) == 0 { - c.CipherSuites = defaultSupportedCipherSuites - } - if len(c.Groups) == 0 { - c.Groups = defaultSupportedGroups - } - if len(c.SignatureSchemes) == 0 { - c.SignatureSchemes = defaultSignatureSchemes - } - if c.TicketLen == 0 { - c.TicketLen = defaultTicketLen - } - if !reflect.ValueOf(c.PSKs).IsValid() { - c.PSKs = &PSKMapCache{} - } - if len(c.PSKModes) == 0 { - c.PSKModes = defaultPSKModes - } - - // If there is no certificate, generate one - if !isClient && len(c.Certificates) == 0 { - logf(logTypeHandshake, "Generating key name=%v", c.ServerName) - priv, err := newSigningKey(RSA_PSS_SHA256) - if err != nil { - return err - } - - cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv) - if err != nil { - return err - } - - c.Certificates = []*Certificate{ - { - Chain: []*x509.Certificate{cert}, - PrivateKey: priv, - }, - } - } - - return nil -} - -func (c *Config) ValidForServer() bool { - return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) || - (len(c.Certificates) > 0 && - len(c.Certificates[0].Chain) > 0 && - c.Certificates[0].PrivateKey != nil) -} - -func (c *Config) ValidForClient() bool { - return len(c.ServerName) > 0 -} - -var ( - defaultSupportedCipherSuites = []CipherSuite{ - TLS_AES_128_GCM_SHA256, - TLS_AES_256_GCM_SHA384, - } - - defaultSupportedGroups = []NamedGroup{ - P256, - P384, - FFDHE2048, - X25519, - } - - defaultSignatureSchemes = []SignatureScheme{ - RSA_PSS_SHA256, - RSA_PSS_SHA384, - RSA_PSS_SHA512, - ECDSA_P256_SHA256, - ECDSA_P384_SHA384, - ECDSA_P521_SHA512, - } - - defaultTicketLen = 16 - - defaultPSKModes = []PSKKeyExchangeMode{ - PSKModeKE, - PSKModeDHEKE, - } -) - -type ConnectionState struct { - HandshakeState string // string representation of the handshake state. - CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...) - PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement - NextProto string // Selected ALPN proto -} - -// Conn implements the net.Conn interface, as with "crypto/tls" -// * Read, Write, and Close are provided locally -// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn -type Conn struct { - config *Config - conn net.Conn - isClient bool - - EarlyData []byte - - state StateConnected - hState HandshakeState - handshakeMutex sync.Mutex - handshakeAlert Alert - handshakeComplete bool - - readBuffer []byte - in, out *RecordLayer - hIn, hOut *HandshakeLayer - - extHandler AppExtensionHandler -} - -func NewConn(conn net.Conn, config *Config, isClient bool) *Conn { - c := &Conn{conn: conn, config: config, isClient: isClient} - c.in = NewRecordLayer(c.conn) - c.out = NewRecordLayer(c.conn) - c.hIn = NewHandshakeLayer(c.in) - c.hIn.nonblocking = c.config.NonBlocking - c.hOut = NewHandshakeLayer(c.out) - return c -} - -// Read up -func (c *Conn) consumeRecord() error { - pt, err := c.in.ReadRecord() - if pt == nil { - logf(logTypeIO, "extendBuffer returns error %v", err) - return err - } - - switch pt.contentType { - case RecordTypeHandshake: - logf(logTypeHandshake, "Received post-handshake message") - // We do not support fragmentation of post-handshake handshake messages. - // TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage() - start := 0 - for start < len(pt.fragment) { - if len(pt.fragment[start:]) < handshakeHeaderLen { - return fmt.Errorf("Post-handshake handshake message too short for header") - } - - hm := &HandshakeMessage{} - hm.msgType = HandshakeType(pt.fragment[start]) - hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3]) - - if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen { - return fmt.Errorf("Post-handshake handshake message too short for body") - } - hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen] - - // Advance state machine - state, actions, alert := c.state.Next(hm) - - if alert != AlertNoAlert { - logf(logTypeHandshake, "Error in state transition: %v", alert) - c.sendAlert(alert) - return io.EOF - } - - for _, action := range actions { - alert = c.takeAction(action) - if alert != AlertNoAlert { - logf(logTypeHandshake, "Error during handshake actions: %v", alert) - c.sendAlert(alert) - return io.EOF - } - } - - // XXX: If we want to support more advanced cases, e.g., post-handshake - // authentication, we'll need to allow transitions other than - // Connected -> Connected - var connected bool - c.state, connected = state.(StateConnected) - if !connected { - logf(logTypeHandshake, "Disconnected after state transition: %v", alert) - c.sendAlert(alert) - return io.EOF - } - - start += handshakeHeaderLen + hmLen - } - case RecordTypeAlert: - logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer) - if len(pt.fragment) != 2 { - c.sendAlert(AlertUnexpectedMessage) - return io.EOF - } - if Alert(pt.fragment[1]) == AlertCloseNotify { - return io.EOF - } - - switch pt.fragment[0] { - case AlertLevelWarning: - // drop on the floor - case AlertLevelError: - return Alert(pt.fragment[1]) - default: - c.sendAlert(AlertUnexpectedMessage) - return io.EOF - } - - case RecordTypeApplicationData: - c.readBuffer = append(c.readBuffer, pt.fragment...) - logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer) - } - - return err -} - -// Read application data up to the size of buffer. Handshake and alert records -// are consumed by the Conn object directly. -func (c *Conn) Read(buffer []byte) (int, error) { - logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer)) - if alert := c.Handshake(); alert != AlertNoAlert { - return 0, alert - } - - if len(buffer) == 0 { - return 0, nil - } - - // Lock the input channel - c.in.Lock() - defer c.in.Unlock() - for len(c.readBuffer) == 0 { - err := c.consumeRecord() - - // err can be nil if consumeRecord processed a non app-data - // record. - if err != nil { - if c.config.NonBlocking || err != WouldBlock { - logf(logTypeIO, "conn.Read returns err=%v", err) - return 0, err - } - } - } - - var read int - n := len(buffer) - logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer)) - if len(c.readBuffer) <= n { - buffer = buffer[:len(c.readBuffer)] - copy(buffer, c.readBuffer) - read = len(c.readBuffer) - c.readBuffer = c.readBuffer[:0] - } else { - logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n) - copy(buffer[:n], c.readBuffer[:n]) - c.readBuffer = c.readBuffer[n:] - read = n - } - - logf(logTypeVerbose, "Returning %v", string(buffer)) - return read, nil -} - -// Write application data -func (c *Conn) Write(buffer []byte) (int, error) { - // Lock the output channel - c.out.Lock() - defer c.out.Unlock() - - // Send full-size fragments - var start int - sent := 0 - for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen { - err := c.out.WriteRecord(&TLSPlaintext{ - contentType: RecordTypeApplicationData, - fragment: buffer[start : start+maxFragmentLen], - }) - - if err != nil { - return sent, err - } - sent += maxFragmentLen - } - - // Send a final partial fragment if necessary - if start < len(buffer) { - err := c.out.WriteRecord(&TLSPlaintext{ - contentType: RecordTypeApplicationData, - fragment: buffer[start:], - }) - - if err != nil { - return sent, err - } - sent += len(buffer[start:]) - } - return sent, nil -} - -// sendAlert sends a TLS alert message. -// c.out.Mutex <= L. -func (c *Conn) sendAlert(err Alert) error { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - var level int - switch err { - case AlertNoRenegotiation, AlertCloseNotify: - level = AlertLevelWarning - default: - level = AlertLevelError - } - - buf := []byte{byte(err), byte(level)} - c.out.WriteRecord(&TLSPlaintext{ - contentType: RecordTypeAlert, - fragment: buf, - }) - - // close_notify and end_of_early_data are not actually errors - if level == AlertLevelWarning { - return &net.OpError{Op: "local error", Err: err} - } - - return c.Close() -} - -// Close closes the connection. -func (c *Conn) Close() error { - // XXX crypto/tls has an interlock with Write here. Do we need that? - - return c.conn.Close() -} - -// LocalAddr returns the local network address. -func (c *Conn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -// RemoteAddr returns the remote network address. -func (c *Conn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -// SetDeadline sets the read and write deadlines associated with the connection. -// A zero value for t means Read and Write will not time out. -// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. -func (c *Conn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -// SetReadDeadline sets the read deadline on the underlying connection. -// A zero value for t means Read will not time out. -func (c *Conn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -// SetWriteDeadline sets the write deadline on the underlying connection. -// A zero value for t means Write will not time out. -// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. -func (c *Conn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} - -func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { - label := "[server]" - if c.isClient { - label = "[client]" - } - - switch action := actionGeneric.(type) { - case SendHandshakeMessage: - err := c.hOut.WriteMessage(action.Message) - if err != nil { - logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err) - return AlertInternalError - } - - case RekeyIn: - logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet) - err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) - if err != nil { - logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err) - return AlertInternalError - } - - case RekeyOut: - logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet) - err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) - if err != nil { - logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err) - return AlertInternalError - } - - case SendEarlyData: - logf(logTypeHandshake, "%s Sending early data...", label) - _, err := c.Write(c.EarlyData) - if err != nil { - logf(logTypeHandshake, "%s Error writing early data: %v", label, err) - return AlertInternalError - } - - case ReadPastEarlyData: - logf(logTypeHandshake, "%s Reading past early data...", label) - // Scan past all records that fail to decrypt - _, err := c.in.PeekRecordType(!c.config.NonBlocking) - if err == nil { - break - } - _, ok := err.(DecryptError) - - for ok { - _, err = c.in.PeekRecordType(!c.config.NonBlocking) - if err == nil { - break - } - _, ok = err.(DecryptError) - } - - case ReadEarlyData: - logf(logTypeHandshake, "%s Reading early data...", label) - t, err := c.in.PeekRecordType(!c.config.NonBlocking) - if err != nil { - logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err) - return AlertInternalError - } - logf(logTypeHandshake, "%s Got record type(1): %v", label, t) - - for t == RecordTypeApplicationData { - // Read a record into the buffer. Note that this is safe - // in blocking mode because we read the record in in - // PeekRecordType. - pt, err := c.in.ReadRecord() - if err != nil { - logf(logTypeHandshake, "%s Error reading early data record: %v", label, err) - return AlertInternalError - } - - logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment) - c.EarlyData = append(c.EarlyData, pt.fragment...) - - t, err = c.in.PeekRecordType(!c.config.NonBlocking) - if err != nil { - logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err) - return AlertInternalError - } - logf(logTypeHandshake, "%s Got record type (2): %v", label, t) - } - logf(logTypeHandshake, "%s Done reading early data", label) - - case StorePSK: - logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity) - if c.isClient { - // Clients look up PSKs based on server name - c.config.PSKs.Put(c.config.ServerName, action.PSK) - } else { - // Servers look them up based on the identity in the extension - c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK) - } - - default: - logf(logTypeHandshake, "%s Unknown actionuction type", label) - return AlertInternalError - } - - return AlertNoAlert -} - -func (c *Conn) HandshakeSetup() Alert { - var state HandshakeState - var actions []HandshakeAction - var alert Alert - - if err := c.config.Init(c.isClient); err != nil { - logf(logTypeHandshake, "Error initializing config: %v", err) - return AlertInternalError - } - - // Set things up - caps := Capabilities{ - CipherSuites: c.config.CipherSuites, - Groups: c.config.Groups, - SignatureSchemes: c.config.SignatureSchemes, - PSKs: c.config.PSKs, - PSKModes: c.config.PSKModes, - AllowEarlyData: c.config.AllowEarlyData, - RequireCookie: c.config.RequireCookie, - CookieHandler: c.config.CookieHandler, - RequireClientAuth: c.config.RequireClientAuth, - NextProtos: c.config.NextProtos, - Certificates: c.config.Certificates, - ExtensionHandler: c.extHandler, - } - opts := ConnectionOptions{ - ServerName: c.config.ServerName, - NextProtos: c.config.NextProtos, - EarlyData: c.EarlyData, - } - - if caps.RequireCookie && caps.CookieHandler == nil { - caps.CookieHandler = &defaultCookieHandler{} - } - - if c.isClient { - state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil) - if alert != AlertNoAlert { - logf(logTypeHandshake, "Error initializing client state: %v", alert) - return alert - } - - for _, action := range actions { - alert = c.takeAction(action) - if alert != AlertNoAlert { - logf(logTypeHandshake, "Error during handshake actions: %v", alert) - return alert - } - } - } else { - state = ServerStateStart{Caps: caps, conn: c} - } - - c.hState = state - - return AlertNoAlert -} - -// Handshake causes a TLS handshake on the connection. The `isClient` member -// determines whether a client or server handshake is performed. If a -// handshake has already been performed, then its result will be returned. -func (c *Conn) Handshake() Alert { - label := "[server]" - if c.isClient { - label = "[client]" - } - - // TODO Lock handshakeMutex - // TODO Remove CloseNotify hack - if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify { - logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert) - return c.handshakeAlert - } - if c.handshakeComplete { - return AlertNoAlert - } - - var alert Alert - if c.hState == nil { - logf(logTypeHandshake, "%s First time through handshake, setting up", label) - alert = c.HandshakeSetup() - if alert != AlertNoAlert { - return alert - } - } else { - logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState) - } - - state := c.hState - _, connected := state.(StateConnected) - - var actions []HandshakeAction - - for !connected { - // Read a handshake message - hm, err := c.hIn.ReadMessage() - if err == WouldBlock { - logf(logTypeHandshake, "%s Would block reading message: %v", label, err) - return AlertWouldBlock - } - if err != nil { - logf(logTypeHandshake, "%s Error reading message: %v", label, err) - c.sendAlert(AlertCloseNotify) - return AlertCloseNotify - } - logf(logTypeHandshake, "Read message with type: %v", hm.msgType) - - // Advance the state machine - state, actions, alert = state.Next(hm) - - if alert != AlertNoAlert { - logf(logTypeHandshake, "Error in state transition: %v", alert) - return alert - } - - for index, action := range actions { - logf(logTypeHandshake, "%s taking next action (%d)", label, index) - alert = c.takeAction(action) - if alert != AlertNoAlert { - logf(logTypeHandshake, "Error during handshake actions: %v", alert) - c.sendAlert(alert) - return alert - } - } - - c.hState = state - logf(logTypeHandshake, "state is now %s", c.GetHsState()) - - _, connected = state.(StateConnected) - } - - c.state = state.(StateConnected) - - // Send NewSessionTicket if acting as server - if !c.isClient && c.config.SendSessionTickets { - actions, alert := c.state.NewSessionTicket( - c.config.TicketLen, - c.config.TicketLifetime, - c.config.EarlyDataLifetime) - - for _, action := range actions { - alert = c.takeAction(action) - if alert != AlertNoAlert { - logf(logTypeHandshake, "Error during handshake actions: %v", alert) - c.sendAlert(alert) - return alert - } - } - } - - c.handshakeComplete = true - return AlertNoAlert -} - -func (c *Conn) SendKeyUpdate(requestUpdate bool) error { - if !c.handshakeComplete { - return fmt.Errorf("Cannot update keys until after handshake") - } - - request := KeyUpdateNotRequested - if requestUpdate { - request = KeyUpdateRequested - } - - // Create the key update and update state - actions, alert := c.state.KeyUpdate(request) - if alert != AlertNoAlert { - c.sendAlert(alert) - return fmt.Errorf("Alert while generating key update: %v", alert) - } - - // Take actions (send key update and rekey) - for _, action := range actions { - alert = c.takeAction(action) - if alert != AlertNoAlert { - c.sendAlert(alert) - return fmt.Errorf("Alert during key update actions: %v", alert) - } - } - - return nil -} - -func (c *Conn) GetHsState() string { - return reflect.TypeOf(c.hState).Name() -} - -func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { - _, connected := c.hState.(StateConnected) - if !connected { - return nil, fmt.Errorf("Cannot compute exporter when state is not connected") - } - - if c.state.exporterSecret == nil { - return nil, fmt.Errorf("Internal error: no exporter secret") - } - - h0 := c.state.cryptoParams.Hash.New().Sum(nil) - tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0) - - hc := c.state.cryptoParams.Hash.New().Sum(context) - return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil -} - -func (c *Conn) State() ConnectionState { - state := ConnectionState{ - HandshakeState: c.GetHsState(), - } - - if c.handshakeComplete { - state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite] - state.NextProto = c.state.Params.NextProto - } - - return state -} - -func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error { - if c.hState != nil { - return fmt.Errorf("Can't set extension handler after setup") - } - - c.extHandler = h - return nil -} diff --git a/vendor/github.com/bifurcation/mint/crypto.go b/vendor/github.com/bifurcation/mint/crypto.go deleted file mode 100644 index 60d34377..00000000 --- a/vendor/github.com/bifurcation/mint/crypto.go +++ /dev/null @@ -1,654 +0,0 @@ -package mint - -import ( - "bytes" - "crypto" - "crypto/aes" - "crypto/cipher" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/hmac" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/asn1" - "fmt" - "math/big" - "time" - - "golang.org/x/crypto/curve25519" - - // Blank includes to ensure hash support - _ "crypto/sha1" - _ "crypto/sha256" - _ "crypto/sha512" -) - -var prng = rand.Reader - -type aeadFactory func(key []byte) (cipher.AEAD, error) - -type CipherSuiteParams struct { - Suite CipherSuite - Cipher aeadFactory // Cipher factory - Hash crypto.Hash // Hash function - KeyLen int // Key length in octets - IvLen int // IV length in octets -} - -type signatureAlgorithm uint8 - -const ( - signatureAlgorithmUnknown = iota - signatureAlgorithmRSA_PKCS1 - signatureAlgorithmRSA_PSS - signatureAlgorithmECDSA -) - -var ( - hashMap = map[SignatureScheme]crypto.Hash{ - RSA_PKCS1_SHA1: crypto.SHA1, - RSA_PKCS1_SHA256: crypto.SHA256, - RSA_PKCS1_SHA384: crypto.SHA384, - RSA_PKCS1_SHA512: crypto.SHA512, - ECDSA_P256_SHA256: crypto.SHA256, - ECDSA_P384_SHA384: crypto.SHA384, - ECDSA_P521_SHA512: crypto.SHA512, - RSA_PSS_SHA256: crypto.SHA256, - RSA_PSS_SHA384: crypto.SHA384, - RSA_PSS_SHA512: crypto.SHA512, - } - - sigMap = map[SignatureScheme]signatureAlgorithm{ - RSA_PKCS1_SHA1: signatureAlgorithmRSA_PKCS1, - RSA_PKCS1_SHA256: signatureAlgorithmRSA_PKCS1, - RSA_PKCS1_SHA384: signatureAlgorithmRSA_PKCS1, - RSA_PKCS1_SHA512: signatureAlgorithmRSA_PKCS1, - ECDSA_P256_SHA256: signatureAlgorithmECDSA, - ECDSA_P384_SHA384: signatureAlgorithmECDSA, - ECDSA_P521_SHA512: signatureAlgorithmECDSA, - RSA_PSS_SHA256: signatureAlgorithmRSA_PSS, - RSA_PSS_SHA384: signatureAlgorithmRSA_PSS, - RSA_PSS_SHA512: signatureAlgorithmRSA_PSS, - } - - curveMap = map[SignatureScheme]NamedGroup{ - ECDSA_P256_SHA256: P256, - ECDSA_P384_SHA384: P384, - ECDSA_P521_SHA512: P521, - } - - newAESGCM = func(key []byte) (cipher.AEAD, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - - // TLS always uses 12-byte nonces - return cipher.NewGCMWithNonceSize(block, 12) - } - - cipherSuiteMap = map[CipherSuite]CipherSuiteParams{ - TLS_AES_128_GCM_SHA256: { - Suite: TLS_AES_128_GCM_SHA256, - Cipher: newAESGCM, - Hash: crypto.SHA256, - KeyLen: 16, - IvLen: 12, - }, - TLS_AES_256_GCM_SHA384: { - Suite: TLS_AES_256_GCM_SHA384, - Cipher: newAESGCM, - Hash: crypto.SHA384, - KeyLen: 32, - IvLen: 12, - }, - } - - x509AlgMap = map[SignatureScheme]x509.SignatureAlgorithm{ - RSA_PKCS1_SHA1: x509.SHA1WithRSA, - RSA_PKCS1_SHA256: x509.SHA256WithRSA, - RSA_PKCS1_SHA384: x509.SHA384WithRSA, - RSA_PKCS1_SHA512: x509.SHA512WithRSA, - ECDSA_P256_SHA256: x509.ECDSAWithSHA256, - ECDSA_P384_SHA384: x509.ECDSAWithSHA384, - ECDSA_P521_SHA512: x509.ECDSAWithSHA512, - } - - defaultRSAKeySize = 2048 -) - -func curveFromNamedGroup(group NamedGroup) (crv elliptic.Curve) { - switch group { - case P256: - crv = elliptic.P256() - case P384: - crv = elliptic.P384() - case P521: - crv = elliptic.P521() - } - return -} - -func namedGroupFromECDSAKey(key *ecdsa.PublicKey) (g NamedGroup) { - switch key.Curve.Params().Name { - case elliptic.P256().Params().Name: - g = P256 - case elliptic.P384().Params().Name: - g = P384 - case elliptic.P521().Params().Name: - g = P521 - } - return -} - -func keyExchangeSizeFromNamedGroup(group NamedGroup) (size int) { - size = 0 - switch group { - case X25519: - size = 32 - case P256: - size = 65 - case P384: - size = 97 - case P521: - size = 133 - case FFDHE2048: - size = 256 - case FFDHE3072: - size = 384 - case FFDHE4096: - size = 512 - case FFDHE6144: - size = 768 - case FFDHE8192: - size = 1024 - } - return -} - -func primeFromNamedGroup(group NamedGroup) (p *big.Int) { - switch group { - case FFDHE2048: - p = finiteFieldPrime2048 - case FFDHE3072: - p = finiteFieldPrime3072 - case FFDHE4096: - p = finiteFieldPrime4096 - case FFDHE6144: - p = finiteFieldPrime6144 - case FFDHE8192: - p = finiteFieldPrime8192 - } - return -} - -func schemeValidForKey(alg SignatureScheme, key crypto.Signer) bool { - sigType := sigMap[alg] - switch key.(type) { - case *rsa.PrivateKey: - return sigType == signatureAlgorithmRSA_PKCS1 || sigType == signatureAlgorithmRSA_PSS - case *ecdsa.PrivateKey: - return sigType == signatureAlgorithmECDSA - default: - return false - } -} - -func ffdheKeyShareFromPrime(p *big.Int) (priv, pub *big.Int, err error) { - primeLen := len(p.Bytes()) - for { - // g = 2 for all ffdhe groups - priv, err = rand.Int(prng, p) - if err != nil { - return - } - - pub = big.NewInt(0) - pub.Exp(big.NewInt(2), priv, p) - - if len(pub.Bytes()) == primeLen { - return - } - } -} - -func newKeyShare(group NamedGroup) (pub []byte, priv []byte, err error) { - switch group { - case P256, P384, P521: - var x, y *big.Int - crv := curveFromNamedGroup(group) - priv, x, y, err = elliptic.GenerateKey(crv, prng) - if err != nil { - return - } - - pub = elliptic.Marshal(crv, x, y) - return - - case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192: - p := primeFromNamedGroup(group) - x, X, err2 := ffdheKeyShareFromPrime(p) - if err2 != nil { - err = err2 - return - } - - priv = x.Bytes() - pubBytes := X.Bytes() - - numBytes := keyExchangeSizeFromNamedGroup(group) - - pub = make([]byte, numBytes) - copy(pub[numBytes-len(pubBytes):], pubBytes) - - return - - case X25519: - var private, public [32]byte - _, err = prng.Read(private[:]) - if err != nil { - return - } - - curve25519.ScalarBaseMult(&public, &private) - priv = private[:] - pub = public[:] - return - - default: - return nil, nil, fmt.Errorf("tls.newkeyshare: Unsupported group %v", group) - } -} - -func keyAgreement(group NamedGroup, pub []byte, priv []byte) ([]byte, error) { - switch group { - case P256, P384, P521: - if len(pub) != keyExchangeSizeFromNamedGroup(group) { - return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") - } - - crv := curveFromNamedGroup(group) - pubX, pubY := elliptic.Unmarshal(crv, pub) - x, _ := crv.Params().ScalarMult(pubX, pubY, priv) - xBytes := x.Bytes() - - numBytes := len(crv.Params().P.Bytes()) - - ret := make([]byte, numBytes) - copy(ret[numBytes-len(xBytes):], xBytes) - - return ret, nil - - case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192: - numBytes := keyExchangeSizeFromNamedGroup(group) - if len(pub) != numBytes { - return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") - } - p := primeFromNamedGroup(group) - x := big.NewInt(0).SetBytes(priv) - Y := big.NewInt(0).SetBytes(pub) - ZBytes := big.NewInt(0).Exp(Y, x, p).Bytes() - - ret := make([]byte, numBytes) - copy(ret[numBytes-len(ZBytes):], ZBytes) - - return ret, nil - - case X25519: - if len(pub) != keyExchangeSizeFromNamedGroup(group) { - return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") - } - - var private, public, ret [32]byte - copy(private[:], priv) - copy(public[:], pub) - curve25519.ScalarMult(&ret, &private, &public) - - return ret[:], nil - - default: - return nil, fmt.Errorf("tls.keyagreement: Unsupported group %v", group) - } -} - -func newSigningKey(sig SignatureScheme) (crypto.Signer, error) { - switch sig { - case RSA_PKCS1_SHA1, RSA_PKCS1_SHA256, - RSA_PKCS1_SHA384, RSA_PKCS1_SHA512, - RSA_PSS_SHA256, RSA_PSS_SHA384, - RSA_PSS_SHA512: - return rsa.GenerateKey(prng, defaultRSAKeySize) - case ECDSA_P256_SHA256: - return ecdsa.GenerateKey(elliptic.P256(), prng) - case ECDSA_P384_SHA384: - return ecdsa.GenerateKey(elliptic.P384(), prng) - case ECDSA_P521_SHA512: - return ecdsa.GenerateKey(elliptic.P521(), prng) - default: - return nil, fmt.Errorf("tls.newsigningkey: Unsupported signature algorithm [%04x]", sig) - } -} - -func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) { - sigAlg, ok := x509AlgMap[alg] - if !ok { - return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg) - } - if len(name) == 0 { - return nil, fmt.Errorf("tls.selfsigned: No name provided") - } - - serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0)) - if err != nil { - return nil, err - } - - template := &x509.Certificate{ - SerialNumber: serial, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(0, 0, 1), - SignatureAlgorithm: sigAlg, - Subject: pkix.Name{CommonName: name}, - DNSNames: []string{name}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - } - der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv) - if err != nil { - return nil, err - } - - // It is safe to ignore the error here because we're parsing known-good data - cert, _ := x509.ParseCertificate(der) - return cert, nil -} - -// XXX(rlb): Copied from crypto/x509 -type ecdsaSignature struct { - R, S *big.Int -} - -func sign(alg SignatureScheme, privateKey crypto.Signer, sigInput []byte) ([]byte, error) { - var opts crypto.SignerOpts - - hash := hashMap[alg] - if hash == crypto.SHA1 { - return nil, fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden") - } - - sigType := sigMap[alg] - var realInput []byte - switch key := privateKey.(type) { - case *rsa.PrivateKey: - switch { - case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: - logf(logTypeCrypto, "signing with PKCS1, hashSize=[%d]", hash.Size()) - opts = hash - case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: - fallthrough - case sigType == signatureAlgorithmRSA_PSS: - logf(logTypeCrypto, "signing with PSS, hashSize=[%d]", hash.Size()) - opts = &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash} - default: - return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for RSA key") - } - - h := hash.New() - h.Write(sigInput) - realInput = h.Sum(nil) - case *ecdsa.PrivateKey: - if sigType != signatureAlgorithmECDSA { - return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for ECDSA key") - } - - algGroup := curveMap[alg] - keyGroup := namedGroupFromECDSAKey(key.Public().(*ecdsa.PublicKey)) - if algGroup != keyGroup { - return nil, fmt.Errorf("tls.crypto.sign: Unsupported hash/curve combination") - } - - h := hash.New() - h.Write(sigInput) - realInput = h.Sum(nil) - default: - return nil, fmt.Errorf("tls.crypto.sign: Unsupported private key type") - } - - sig, err := privateKey.Sign(prng, realInput, opts) - logf(logTypeCrypto, "signature: %x", sig) - return sig, err -} - -func verify(alg SignatureScheme, publicKey crypto.PublicKey, sigInput []byte, sig []byte) error { - hash := hashMap[alg] - - if hash == crypto.SHA1 { - return fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden") - } - - sigType := sigMap[alg] - switch pub := publicKey.(type) { - case *rsa.PublicKey: - switch { - case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: - logf(logTypeCrypto, "verifying with PKCS1, hashSize=[%d]", hash.Size()) - - h := hash.New() - h.Write(sigInput) - realInput := h.Sum(nil) - return rsa.VerifyPKCS1v15(pub, hash, realInput, sig) - case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: - fallthrough - case sigType == signatureAlgorithmRSA_PSS: - logf(logTypeCrypto, "verifying with PSS, hashSize=[%d]", hash.Size()) - opts := &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash} - - h := hash.New() - h.Write(sigInput) - realInput := h.Sum(nil) - return rsa.VerifyPSS(pub, hash, realInput, sig, opts) - default: - return fmt.Errorf("tls.verify: Unsupported algorithm for RSA key") - } - - case *ecdsa.PublicKey: - if sigType != signatureAlgorithmECDSA { - return fmt.Errorf("tls.verify: Unsupported algorithm for ECDSA key") - } - - if curveMap[alg] != namedGroupFromECDSAKey(pub) { - return fmt.Errorf("tls.verify: Unsupported curve for ECDSA key") - } - - ecdsaSig := new(ecdsaSignature) - if rest, err := asn1.Unmarshal(sig, ecdsaSig); err != nil { - return err - } else if len(rest) != 0 { - return fmt.Errorf("tls.verify: trailing data after ECDSA signature") - } - if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { - return fmt.Errorf("tls.verify: ECDSA signature contained zero or negative values") - } - - h := hash.New() - h.Write(sigInput) - realInput := h.Sum(nil) - if !ecdsa.Verify(pub, realInput, ecdsaSig.R, ecdsaSig.S) { - return fmt.Errorf("tls.verify: ECDSA verification failure") - } - return nil - default: - return fmt.Errorf("tls.verify: Unsupported key type") - } -} - -// 0 -// | -// v -// PSK -> HKDF-Extract = Early Secret -// | -// +-----> Derive-Secret(., -// | "ext binder" | -// | "res binder", -// | "") -// | = binder_key -// | -// +-----> Derive-Secret(., "c e traffic", -// | ClientHello) -// | = client_early_traffic_secret -// | -// +-----> Derive-Secret(., "e exp master", -// | ClientHello) -// | = early_exporter_master_secret -// v -// Derive-Secret(., "derived", "") -// | -// v -// (EC)DHE -> HKDF-Extract = Handshake Secret -// | -// +-----> Derive-Secret(., "c hs traffic", -// | ClientHello...ServerHello) -// | = client_handshake_traffic_secret -// | -// +-----> Derive-Secret(., "s hs traffic", -// | ClientHello...ServerHello) -// | = server_handshake_traffic_secret -// v -// Derive-Secret(., "derived", "") -// | -// v -// 0 -> HKDF-Extract = Master Secret -// | -// +-----> Derive-Secret(., "c ap traffic", -// | ClientHello...server Finished) -// | = client_application_traffic_secret_0 -// | -// +-----> Derive-Secret(., "s ap traffic", -// | ClientHello...server Finished) -// | = server_application_traffic_secret_0 -// | -// +-----> Derive-Secret(., "exp master", -// | ClientHello...server Finished) -// | = exporter_master_secret -// | -// +-----> Derive-Secret(., "res master", -// ClientHello...client Finished) -// = resumption_master_secret - -// From RFC 5869 -// PRK = HMAC-Hash(salt, IKM) -func HkdfExtract(hash crypto.Hash, saltIn, input []byte) []byte { - salt := saltIn - - // if [salt is] not provided, it is set to a string of HashLen zeros - if salt == nil { - salt = bytes.Repeat([]byte{0}, hash.Size()) - } - - h := hmac.New(hash.New, salt) - h.Write(input) - out := h.Sum(nil) - - logf(logTypeCrypto, "HKDF Extract:\n") - logf(logTypeCrypto, "Salt [%d]: %x\n", len(salt), salt) - logf(logTypeCrypto, "Input [%d]: %x\n", len(input), input) - logf(logTypeCrypto, "Output [%d]: %x\n", len(out), out) - - return out -} - -const ( - labelExternalBinder = "ext binder" - labelResumptionBinder = "res binder" - labelEarlyTrafficSecret = "c e traffic" - labelEarlyExporterSecret = "e exp master" - labelClientHandshakeTrafficSecret = "c hs traffic" - labelServerHandshakeTrafficSecret = "s hs traffic" - labelClientApplicationTrafficSecret = "c ap traffic" - labelServerApplicationTrafficSecret = "s ap traffic" - labelExporterSecret = "exp master" - labelResumptionSecret = "res master" - labelDerived = "derived" - labelFinished = "finished" - labelResumption = "resumption" -) - -// struct HkdfLabel { -// uint16 length; -// opaque label<9..255>; -// opaque hash_value<0..255>; -// }; -func hkdfEncodeLabel(labelIn string, hashValue []byte, outLen int) []byte { - label := "tls13 " + labelIn - - labelLen := len(label) - hashLen := len(hashValue) - hkdfLabel := make([]byte, 2+1+labelLen+1+hashLen) - hkdfLabel[0] = byte(outLen >> 8) - hkdfLabel[1] = byte(outLen) - hkdfLabel[2] = byte(labelLen) - copy(hkdfLabel[3:3+labelLen], []byte(label)) - hkdfLabel[3+labelLen] = byte(hashLen) - copy(hkdfLabel[3+labelLen+1:], hashValue) - - return hkdfLabel -} - -func HkdfExpand(hash crypto.Hash, prk, info []byte, outLen int) []byte { - out := []byte{} - T := []byte{} - i := byte(1) - for len(out) < outLen { - block := append(T, info...) - block = append(block, i) - - h := hmac.New(hash.New, prk) - h.Write(block) - - T = h.Sum(nil) - out = append(out, T...) - i++ - } - return out[:outLen] -} - -func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, hashValue []byte, outLen int) []byte { - info := hkdfEncodeLabel(label, hashValue, outLen) - derived := HkdfExpand(hash, secret, info, outLen) - - logf(logTypeCrypto, "HKDF Expand: label=[tls13 ] + '%s',requested length=%d\n", label, outLen) - logf(logTypeCrypto, "PRK [%d]: %x\n", len(secret), secret) - logf(logTypeCrypto, "Hash [%d]: %x\n", len(hashValue), hashValue) - logf(logTypeCrypto, "Info [%d]: %x\n", len(info), info) - logf(logTypeCrypto, "Derived key [%d]: %x\n", len(derived), derived) - - return derived -} - -func deriveSecret(params CipherSuiteParams, secret []byte, label string, messageHash []byte) []byte { - return HkdfExpandLabel(params.Hash, secret, label, messageHash, params.Hash.Size()) -} - -func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte) []byte { - macKey := HkdfExpandLabel(params.Hash, baseKey, labelFinished, []byte{}, params.Hash.Size()) - mac := hmac.New(params.Hash.New, macKey) - mac.Write(input) - return mac.Sum(nil) -} - -type keySet struct { - cipher aeadFactory - key []byte - iv []byte -} - -func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet { - logf(logTypeCrypto, "making traffic keys: secret=%x", secret) - return keySet{ - cipher: params.Cipher, - key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen), - iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen), - } -} diff --git a/vendor/github.com/bifurcation/mint/extensions.go b/vendor/github.com/bifurcation/mint/extensions.go deleted file mode 100644 index 1dbe7bd2..00000000 --- a/vendor/github.com/bifurcation/mint/extensions.go +++ /dev/null @@ -1,586 +0,0 @@ -package mint - -import ( - "bytes" - "fmt" - - "github.com/bifurcation/mint/syntax" -) - -type ExtensionBody interface { - Type() ExtensionType - Marshal() ([]byte, error) - Unmarshal(data []byte) (int, error) -} - -// struct { -// ExtensionType extension_type; -// opaque extension_data<0..2^16-1>; -// } Extension; -type Extension struct { - ExtensionType ExtensionType - ExtensionData []byte `tls:"head=2"` -} - -func (ext Extension) Marshal() ([]byte, error) { - return syntax.Marshal(ext) -} - -func (ext *Extension) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, ext) -} - -type ExtensionList []Extension - -type extensionListInner struct { - List []Extension `tls:"head=2"` -} - -func (el ExtensionList) Marshal() ([]byte, error) { - return syntax.Marshal(extensionListInner{el}) -} - -func (el *ExtensionList) Unmarshal(data []byte) (int, error) { - var list extensionListInner - read, err := syntax.Unmarshal(data, &list) - if err != nil { - return 0, err - } - - *el = list.List - return read, nil -} - -func (el *ExtensionList) Add(src ExtensionBody) error { - data, err := src.Marshal() - if err != nil { - return err - } - - if el == nil { - el = new(ExtensionList) - } - - // If one already exists with this type, replace it - for i := range *el { - if (*el)[i].ExtensionType == src.Type() { - (*el)[i].ExtensionData = data - return nil - } - } - - // Otherwise append - *el = append(*el, Extension{ - ExtensionType: src.Type(), - ExtensionData: data, - }) - return nil -} - -func (el ExtensionList) Find(dst ExtensionBody) bool { - for _, ext := range el { - if ext.ExtensionType == dst.Type() { - _, err := dst.Unmarshal(ext.ExtensionData) - return err == nil - } - } - return false -} - -// struct { -// NameType name_type; -// select (name_type) { -// case host_name: HostName; -// } name; -// } ServerName; -// -// enum { -// host_name(0), (255) -// } NameType; -// -// opaque HostName<1..2^16-1>; -// -// struct { -// ServerName server_name_list<1..2^16-1> -// } ServerNameList; -// -// But we only care about the case where there's a single DNS hostname. We -// will never create anything else, and throw if we receive something else -// -// 2 1 2 -// | listLen | NameType | nameLen | name | -type ServerNameExtension string - -type serverNameInner struct { - NameType uint8 - HostName []byte `tls:"head=2,min=1"` -} - -type serverNameListInner struct { - ServerNameList []serverNameInner `tls:"head=2,min=1"` -} - -func (sni ServerNameExtension) Type() ExtensionType { - return ExtensionTypeServerName -} - -func (sni ServerNameExtension) Marshal() ([]byte, error) { - list := serverNameListInner{ - ServerNameList: []serverNameInner{{ - NameType: 0x00, // host_name - HostName: []byte(sni), - }}, - } - - return syntax.Marshal(list) -} - -func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) { - var list serverNameListInner - read, err := syntax.Unmarshal(data, &list) - if err != nil { - return 0, err - } - - // Syntax requires at least one entry - // Entries beyond the first are ignored - if nameType := list.ServerNameList[0].NameType; nameType != 0x00 { - return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType) - } - - *sni = ServerNameExtension(list.ServerNameList[0].HostName) - return read, nil -} - -// struct { -// NamedGroup group; -// opaque key_exchange<1..2^16-1>; -// } KeyShareEntry; -// -// struct { -// select (Handshake.msg_type) { -// case client_hello: -// KeyShareEntry client_shares<0..2^16-1>; -// -// case hello_retry_request: -// NamedGroup selected_group; -// -// case server_hello: -// KeyShareEntry server_share; -// }; -// } KeyShare; -type KeyShareEntry struct { - Group NamedGroup - KeyExchange []byte `tls:"head=2,min=1"` -} - -func (kse KeyShareEntry) SizeValid() bool { - return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group) -} - -type KeyShareExtension struct { - HandshakeType HandshakeType - SelectedGroup NamedGroup - Shares []KeyShareEntry -} - -type KeyShareClientHelloInner struct { - ClientShares []KeyShareEntry `tls:"head=2,min=0"` -} -type KeyShareHelloRetryInner struct { - SelectedGroup NamedGroup -} -type KeyShareServerHelloInner struct { - ServerShare KeyShareEntry -} - -func (ks KeyShareExtension) Type() ExtensionType { - return ExtensionTypeKeyShare -} - -func (ks KeyShareExtension) Marshal() ([]byte, error) { - switch ks.HandshakeType { - case HandshakeTypeClientHello: - for _, share := range ks.Shares { - if !share.SizeValid() { - return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group") - } - } - return syntax.Marshal(KeyShareClientHelloInner{ks.Shares}) - - case HandshakeTypeHelloRetryRequest: - if len(ks.Shares) > 0 { - return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest") - } - - return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup}) - - case HandshakeTypeServerHello: - if len(ks.Shares) != 1 { - return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share") - } - - if !ks.Shares[0].SizeValid() { - return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group") - } - - return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]}) - - default: - return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed") - } -} - -func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) { - switch ks.HandshakeType { - case HandshakeTypeClientHello: - var inner KeyShareClientHelloInner - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return 0, err - } - - for _, share := range inner.ClientShares { - if !share.SizeValid() { - return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group") - } - } - - ks.Shares = inner.ClientShares - return read, nil - - case HandshakeTypeHelloRetryRequest: - var inner KeyShareHelloRetryInner - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return 0, err - } - - ks.SelectedGroup = inner.SelectedGroup - return read, nil - - case HandshakeTypeServerHello: - var inner KeyShareServerHelloInner - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return 0, err - } - - if !inner.ServerShare.SizeValid() { - return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group") - } - - ks.Shares = []KeyShareEntry{inner.ServerShare} - return read, nil - - default: - return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed") - } -} - -// struct { -// NamedGroup named_group_list<2..2^16-1>; -// } NamedGroupList; -type SupportedGroupsExtension struct { - Groups []NamedGroup `tls:"head=2,min=2"` -} - -func (sg SupportedGroupsExtension) Type() ExtensionType { - return ExtensionTypeSupportedGroups -} - -func (sg SupportedGroupsExtension) Marshal() ([]byte, error) { - return syntax.Marshal(sg) -} - -func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, sg) -} - -// struct { -// SignatureScheme supported_signature_algorithms<2..2^16-2>; -// } SignatureSchemeList -type SignatureAlgorithmsExtension struct { - Algorithms []SignatureScheme `tls:"head=2,min=2"` -} - -func (sa SignatureAlgorithmsExtension) Type() ExtensionType { - return ExtensionTypeSignatureAlgorithms -} - -func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) { - return syntax.Marshal(sa) -} - -func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, sa) -} - -// struct { -// opaque identity<1..2^16-1>; -// uint32 obfuscated_ticket_age; -// } PskIdentity; -// -// opaque PskBinderEntry<32..255>; -// -// struct { -// select (Handshake.msg_type) { -// case client_hello: -// PskIdentity identities<7..2^16-1>; -// PskBinderEntry binders<33..2^16-1>; -// -// case server_hello: -// uint16 selected_identity; -// }; -// -// } PreSharedKeyExtension; -type PSKIdentity struct { - Identity []byte `tls:"head=2,min=1"` - ObfuscatedTicketAge uint32 -} - -type PSKBinderEntry struct { - Binder []byte `tls:"head=1,min=32"` -} - -type PreSharedKeyExtension struct { - HandshakeType HandshakeType - Identities []PSKIdentity - Binders []PSKBinderEntry - SelectedIdentity uint16 -} - -type preSharedKeyClientInner struct { - Identities []PSKIdentity `tls:"head=2,min=7"` - Binders []PSKBinderEntry `tls:"head=2,min=33"` -} - -type preSharedKeyServerInner struct { - SelectedIdentity uint16 -} - -func (psk PreSharedKeyExtension) Type() ExtensionType { - return ExtensionTypePreSharedKey -} - -func (psk PreSharedKeyExtension) Marshal() ([]byte, error) { - switch psk.HandshakeType { - case HandshakeTypeClientHello: - return syntax.Marshal(preSharedKeyClientInner{ - Identities: psk.Identities, - Binders: psk.Binders, - }) - - case HandshakeTypeServerHello: - if len(psk.Identities) > 0 || len(psk.Binders) > 0 { - return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index") - } - return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity}) - - default: - return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported") - } -} - -func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) { - switch psk.HandshakeType { - case HandshakeTypeClientHello: - var inner preSharedKeyClientInner - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return 0, err - } - - if len(inner.Identities) != len(inner.Binders) { - return 0, fmt.Errorf("Lengths of identities and binders not equal") - } - - psk.Identities = inner.Identities - psk.Binders = inner.Binders - return read, nil - - case HandshakeTypeServerHello: - var inner preSharedKeyServerInner - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return 0, err - } - - psk.SelectedIdentity = inner.SelectedIdentity - return read, nil - - default: - return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported") - } -} - -func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) { - for i, localID := range psk.Identities { - if bytes.Equal(localID.Identity, id) { - return psk.Binders[i].Binder, true - } - } - return nil, false -} - -// enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode; -// -// struct { -// PskKeyExchangeMode ke_modes<1..255>; -// } PskKeyExchangeModes; -type PSKKeyExchangeModesExtension struct { - KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"` -} - -func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType { - return ExtensionTypePSKKeyExchangeModes -} - -func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) { - return syntax.Marshal(pkem) -} - -func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, pkem) -} - -// struct { -// } EarlyDataIndication; - -type EarlyDataExtension struct{} - -func (ed EarlyDataExtension) Type() ExtensionType { - return ExtensionTypeEarlyData -} - -func (ed EarlyDataExtension) Marshal() ([]byte, error) { - return []byte{}, nil -} - -func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) { - return 0, nil -} - -// struct { -// uint32 max_early_data_size; -// } TicketEarlyDataInfo; - -type TicketEarlyDataInfoExtension struct { - MaxEarlyDataSize uint32 -} - -func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType { - return ExtensionTypeTicketEarlyDataInfo -} - -func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) { - return syntax.Marshal(tedi) -} - -func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, tedi) -} - -// opaque ProtocolName<1..2^8-1>; -// -// struct { -// ProtocolName protocol_name_list<2..2^16-1> -// } ProtocolNameList; -type ALPNExtension struct { - Protocols []string -} - -type protocolNameInner struct { - Name []byte `tls:"head=1,min=1"` -} - -type alpnExtensionInner struct { - Protocols []protocolNameInner `tls:"head=2,min=2"` -} - -func (alpn ALPNExtension) Type() ExtensionType { - return ExtensionTypeALPN -} - -func (alpn ALPNExtension) Marshal() ([]byte, error) { - protocols := make([]protocolNameInner, len(alpn.Protocols)) - for i, protocol := range alpn.Protocols { - protocols[i] = protocolNameInner{[]byte(protocol)} - } - return syntax.Marshal(alpnExtensionInner{protocols}) -} - -func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) { - var inner alpnExtensionInner - read, err := syntax.Unmarshal(data, &inner) - - if err != nil { - return 0, err - } - - alpn.Protocols = make([]string, len(inner.Protocols)) - for i, protocol := range inner.Protocols { - alpn.Protocols[i] = string(protocol.Name) - } - return read, nil -} - -// struct { -// ProtocolVersion versions<2..254>; -// } SupportedVersions; -type SupportedVersionsExtension struct { - Versions []uint16 `tls:"head=1,min=2,max=254"` -} - -func (sv SupportedVersionsExtension) Type() ExtensionType { - return ExtensionTypeSupportedVersions -} - -func (sv SupportedVersionsExtension) Marshal() ([]byte, error) { - return syntax.Marshal(sv) -} - -func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, sv) -} - -// struct { -// opaque cookie<1..2^16-1>; -// } Cookie; -type CookieExtension struct { - Cookie []byte `tls:"head=2,min=1"` -} - -func (c CookieExtension) Type() ExtensionType { - return ExtensionTypeCookie -} - -func (c CookieExtension) Marshal() ([]byte, error) { - return syntax.Marshal(c) -} - -func (c *CookieExtension) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, c) -} - -// defaultCookieLength is the default length of a cookie -const defaultCookieLength = 32 - -type defaultCookieHandler struct { - data []byte -} - -var _ CookieHandler = &defaultCookieHandler{} - -// NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data -func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) { - h.data = make([]byte, defaultCookieLength) - if _, err := prng.Read(h.data); err != nil { - return nil, err - } - return h.data, nil -} - -func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool { - return bytes.Equal(h.data, data) -} diff --git a/vendor/github.com/bifurcation/mint/ffdhe.go b/vendor/github.com/bifurcation/mint/ffdhe.go deleted file mode 100644 index 59d1f7f9..00000000 --- a/vendor/github.com/bifurcation/mint/ffdhe.go +++ /dev/null @@ -1,147 +0,0 @@ -package mint - -import ( - "encoding/hex" - "math/big" -) - -var ( - finiteFieldPrime2048hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + - "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + - "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + - "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + - "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + - "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + - "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + - "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + - "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + - "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + - "886B423861285C97FFFFFFFFFFFFFFFF" - finiteFieldPrime2048bytes, _ = hex.DecodeString(finiteFieldPrime2048hex) - finiteFieldPrime2048 = big.NewInt(0).SetBytes(finiteFieldPrime2048bytes) - - finiteFieldPrime3072hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + - "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + - "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + - "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + - "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + - "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + - "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + - "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + - "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + - "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + - "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + - "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + - "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + - "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + - "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + - "3C1B20EE3FD59D7C25E41D2B66C62E37FFFFFFFFFFFFFFFF" - finiteFieldPrime3072bytes, _ = hex.DecodeString(finiteFieldPrime3072hex) - finiteFieldPrime3072 = big.NewInt(0).SetBytes(finiteFieldPrime3072bytes) - - finiteFieldPrime4096hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + - "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + - "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + - "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + - "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + - "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + - "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + - "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + - "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + - "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + - "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + - "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + - "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + - "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + - "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + - "3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + - "7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + - "87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + - "A907600A918130C46DC778F971AD0038092999A333CB8B7A" + - "1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + - "8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E655F6A" + - "FFFFFFFFFFFFFFFF" - finiteFieldPrime4096bytes, _ = hex.DecodeString(finiteFieldPrime4096hex) - finiteFieldPrime4096 = big.NewInt(0).SetBytes(finiteFieldPrime4096bytes) - - finiteFieldPrime6144hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + - "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + - "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + - "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + - "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + - "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + - "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + - "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + - "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + - "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + - "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + - "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + - "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + - "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + - "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + - "3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + - "7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + - "87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + - "A907600A918130C46DC778F971AD0038092999A333CB8B7A" + - "1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + - "8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" + - "0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" + - "3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" + - "CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" + - "A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" + - "0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" + - "763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" + - "B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" + - "D72B03746AE77F5E62292C311562A846505DC82DB854338A" + - "E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" + - "5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" + - "A41D570D7938DAD4A40E329CD0E40E65FFFFFFFFFFFFFFFF" - finiteFieldPrime6144bytes, _ = hex.DecodeString(finiteFieldPrime6144hex) - finiteFieldPrime6144 = big.NewInt(0).SetBytes(finiteFieldPrime6144bytes) - - finiteFieldPrime8192hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + - "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + - "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + - "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + - "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + - "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + - "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + - "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + - "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + - "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + - "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + - "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + - "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + - "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + - "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + - "3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + - "7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + - "87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + - "A907600A918130C46DC778F971AD0038092999A333CB8B7A" + - "1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + - "8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" + - "0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" + - "3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" + - "CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" + - "A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" + - "0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" + - "763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" + - "B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" + - "D72B03746AE77F5E62292C311562A846505DC82DB854338A" + - "E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" + - "5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" + - "A41D570D7938DAD4A40E329CCFF46AAA36AD004CF600C838" + - "1E425A31D951AE64FDB23FCEC9509D43687FEB69EDD1CC5E" + - "0B8CC3BDF64B10EF86B63142A3AB8829555B2F747C932665" + - "CB2C0F1CC01BD70229388839D2AF05E454504AC78B758282" + - "2846C0BA35C35F5C59160CC046FD8251541FC68C9C86B022" + - "BB7099876A460E7451A8A93109703FEE1C217E6C3826E52C" + - "51AA691E0E423CFC99E9E31650C1217B624816CDAD9A95F9" + - "D5B8019488D9C0A0A1FE3075A577E23183F81D4A3F2FA457" + - "1EFC8CE0BA8A4FE8B6855DFE72B0A66EDED2FBABFBE58A30" + - "FAFABE1C5D71A87E2F741EF8C1FE86FEA6BBFDE530677F0D" + - "97D11D49F7A8443D0822E506A9F4614E011E2A94838FF88C" + - "D68C8BB7C5C6424CFFFFFFFFFFFFFFFF" - finiteFieldPrime8192bytes, _ = hex.DecodeString(finiteFieldPrime8192hex) - finiteFieldPrime8192 = big.NewInt(0).SetBytes(finiteFieldPrime8192bytes) -) diff --git a/vendor/github.com/bifurcation/mint/frame-reader.go b/vendor/github.com/bifurcation/mint/frame-reader.go deleted file mode 100644 index 99ea470d..00000000 --- a/vendor/github.com/bifurcation/mint/frame-reader.go +++ /dev/null @@ -1,98 +0,0 @@ -// Read a generic "framed" packet consisting of a header and a -// This is used for both TLS Records and TLS Handshake Messages -package mint - -type framing interface { - headerLen() int - defaultReadLen() int - frameLen(hdr []byte) (int, error) -} - -const ( - kFrameReaderHdr = 0 - kFrameReaderBody = 1 -) - -type frameNextAction func(f *frameReader) error - -type frameReader struct { - details framing - state uint8 - header []byte - body []byte - working []byte - writeOffset int - remainder []byte -} - -func newFrameReader(d framing) *frameReader { - hdr := make([]byte, d.headerLen()) - return &frameReader{ - d, - kFrameReaderHdr, - hdr, - nil, - hdr, - 0, - nil, - } -} - -func dup(a []byte) []byte { - r := make([]byte, len(a)) - copy(r, a) - return r -} - -func (f *frameReader) needed() int { - tmp := (len(f.working) - f.writeOffset) - len(f.remainder) - if tmp < 0 { - return 0 - } - return tmp -} - -func (f *frameReader) addChunk(in []byte) { - // Append to the buffer. - logf(logTypeFrameReader, "Appending %v", len(in)) - f.remainder = append(f.remainder, in...) -} - -func (f *frameReader) process() (hdr []byte, body []byte, err error) { - for f.needed() == 0 { - logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset) - // Fill out our working block - copied := copy(f.working[f.writeOffset:], f.remainder) - f.remainder = f.remainder[copied:] - f.writeOffset += copied - if f.writeOffset < len(f.working) { - logf(logTypeFrameReader, "Read would have blocked 1") - return nil, nil, WouldBlock - } - // Reset the write offset, because we are now full. - f.writeOffset = 0 - - // We have read a full frame - if f.state == kFrameReaderBody { - logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder)) - f.state = kFrameReaderHdr - f.working = f.header - return dup(f.header), dup(f.body), nil - } - - // We have read the header - bodyLen, err := f.details.frameLen(f.header) - if err != nil { - return nil, nil, err - } - logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen) - - f.body = make([]byte, bodyLen) - f.working = f.body - f.writeOffset = 0 - f.state = kFrameReaderBody - } - - logf(logTypeFrameReader, "Read would have blocked 2") - return nil, nil, WouldBlock -} diff --git a/vendor/github.com/bifurcation/mint/handshake-layer.go b/vendor/github.com/bifurcation/mint/handshake-layer.go deleted file mode 100644 index 2b04ac5c..00000000 --- a/vendor/github.com/bifurcation/mint/handshake-layer.go +++ /dev/null @@ -1,253 +0,0 @@ -package mint - -import ( - "fmt" - "io" - "net" -) - -const ( - handshakeHeaderLen = 4 // handshake message header length - maxHandshakeMessageLen = 1 << 24 // max handshake message length -) - -// struct { -// HandshakeType msg_type; /* handshake type */ -// uint24 length; /* bytes in message */ -// select (HandshakeType) { -// ... -// } body; -// } Handshake; -// -// We do the select{...} part in a different layer, so we treat the -// actual message body as opaque: -// -// struct { -// HandshakeType msg_type; -// opaque msg<0..2^24-1> -// } Handshake; -// -// TODO: File a spec bug -type HandshakeMessage struct { - // Omitted: length - msgType HandshakeType - body []byte -} - -// Note: This could be done with the `syntax` module, using the simplified -// syntax as discussed above. However, since this is so simple, there's not -// much benefit to doing so. -func (hm *HandshakeMessage) Marshal() []byte { - if hm == nil { - return []byte{} - } - - msgLen := len(hm.body) - data := make([]byte, 4+len(hm.body)) - data[0] = byte(hm.msgType) - data[1] = byte(msgLen >> 16) - data[2] = byte(msgLen >> 8) - data[3] = byte(msgLen) - copy(data[4:], hm.body) - return data -} - -func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) { - logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body) - - var body HandshakeMessageBody - switch hm.msgType { - case HandshakeTypeClientHello: - body = new(ClientHelloBody) - case HandshakeTypeServerHello: - body = new(ServerHelloBody) - case HandshakeTypeHelloRetryRequest: - body = new(HelloRetryRequestBody) - case HandshakeTypeEncryptedExtensions: - body = new(EncryptedExtensionsBody) - case HandshakeTypeCertificate: - body = new(CertificateBody) - case HandshakeTypeCertificateRequest: - body = new(CertificateRequestBody) - case HandshakeTypeCertificateVerify: - body = new(CertificateVerifyBody) - case HandshakeTypeFinished: - body = &FinishedBody{VerifyDataLen: len(hm.body)} - case HandshakeTypeNewSessionTicket: - body = new(NewSessionTicketBody) - case HandshakeTypeKeyUpdate: - body = new(KeyUpdateBody) - case HandshakeTypeEndOfEarlyData: - body = new(EndOfEarlyDataBody) - default: - return body, fmt.Errorf("tls.handshakemessage: Unsupported body type") - } - - _, err := body.Unmarshal(hm.body) - return body, err -} - -func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) { - data, err := body.Marshal() - if err != nil { - return nil, err - } - - return &HandshakeMessage{ - msgType: body.Type(), - body: data, - }, nil -} - -type HandshakeLayer struct { - nonblocking bool // Should we operate in nonblocking mode - conn *RecordLayer // Used for reading/writing records - frame *frameReader // The buffered frame reader -} - -type handshakeLayerFrameDetails struct{} - -func (d handshakeLayerFrameDetails) headerLen() int { - return handshakeHeaderLen -} - -func (d handshakeLayerFrameDetails) defaultReadLen() int { - return handshakeHeaderLen + maxFragmentLen -} - -func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) { - logf(logTypeIO, "Header=%x", hdr) - return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil -} - -func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer { - h := HandshakeLayer{} - h.conn = r - h.frame = newFrameReader(&handshakeLayerFrameDetails{}) - return &h -} - -func (h *HandshakeLayer) readRecord() error { - logf(logTypeIO, "Trying to read record") - pt, err := h.conn.ReadRecord() - if err != nil { - return err - } - - if pt.contentType != RecordTypeHandshake && - pt.contentType != RecordTypeAlert { - return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType) - } - - if pt.contentType == RecordTypeAlert { - logf(logTypeIO, "read alert %v", pt.fragment[1]) - if len(pt.fragment) < 2 { - h.sendAlert(AlertUnexpectedMessage) - return io.EOF - } - return Alert(pt.fragment[1]) - } - - logf(logTypeIO, "read handshake record of len %v", len(pt.fragment)) - h.frame.addChunk(pt.fragment) - - return nil -} - -// sendAlert sends a TLS alert message. -func (h *HandshakeLayer) sendAlert(err Alert) error { - tmp := make([]byte, 2) - tmp[0] = AlertLevelError - tmp[1] = byte(err) - h.conn.WriteRecord(&TLSPlaintext{ - contentType: RecordTypeAlert, - fragment: tmp}, - ) - - // closeNotify is a special case in that it isn't an error: - if err != AlertCloseNotify { - return &net.OpError{Op: "local error", Err: err} - } - return nil -} - -func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { - var hdr, body []byte - var err error - - for { - logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder)) - if h.frame.needed() > 0 { - logf(logTypeHandshake, "Trying to read a new record") - err = h.readRecord() - } - if err != nil && (h.nonblocking || err != WouldBlock) { - return nil, err - } - - hdr, body, err = h.frame.process() - if err == nil { - break - } - if err != nil && (h.nonblocking || err != WouldBlock) { - return nil, err - } - } - - logf(logTypeHandshake, "read handshake message") - - hm := &HandshakeMessage{} - hm.msgType = HandshakeType(hdr[0]) - - hm.body = make([]byte, len(body)) - copy(hm.body, body) - - return hm, nil -} - -func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error { - return h.WriteMessages([]*HandshakeMessage{hm}) -} - -func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error { - for _, hm := range hms { - logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body) - } - - // Write out headers and bodies - buffer := []byte{} - for _, msg := range hms { - msgLen := len(msg.body) - if msgLen > maxHandshakeMessageLen { - return fmt.Errorf("tls.handshakelayer: Message too large to send") - } - - buffer = append(buffer, msg.Marshal()...) - } - - // Send full-size fragments - var start int - for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen { - err := h.conn.WriteRecord(&TLSPlaintext{ - contentType: RecordTypeHandshake, - fragment: buffer[start : start+maxFragmentLen], - }) - - if err != nil { - return err - } - } - - // Send a final partial fragment if necessary - if start < len(buffer) { - err := h.conn.WriteRecord(&TLSPlaintext{ - contentType: RecordTypeHandshake, - fragment: buffer[start:], - }) - - if err != nil { - return err - } - } - return nil -} diff --git a/vendor/github.com/bifurcation/mint/handshake-messages.go b/vendor/github.com/bifurcation/mint/handshake-messages.go deleted file mode 100644 index 339bbcd0..00000000 --- a/vendor/github.com/bifurcation/mint/handshake-messages.go +++ /dev/null @@ -1,450 +0,0 @@ -package mint - -import ( - "bytes" - "crypto" - "crypto/x509" - "encoding/binary" - "fmt" - - "github.com/bifurcation/mint/syntax" -) - -type HandshakeMessageBody interface { - Type() HandshakeType - Marshal() ([]byte, error) - Unmarshal(data []byte) (int, error) -} - -// struct { -// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */ -// Random random; -// opaque legacy_session_id<0..32>; -// CipherSuite cipher_suites<2..2^16-2>; -// opaque legacy_compression_methods<1..2^8-1>; -// Extension extensions<0..2^16-1>; -// } ClientHello; -type ClientHelloBody struct { - // Omitted: clientVersion - // Omitted: legacySessionID - // Omitted: legacyCompressionMethods - Random [32]byte - CipherSuites []CipherSuite - Extensions ExtensionList -} - -type clientHelloBodyInner struct { - LegacyVersion uint16 - Random [32]byte - LegacySessionID []byte `tls:"head=1,max=32"` - CipherSuites []CipherSuite `tls:"head=2,min=2"` - LegacyCompressionMethods []byte `tls:"head=1,min=1"` - Extensions []Extension `tls:"head=2"` -} - -func (ch ClientHelloBody) Type() HandshakeType { - return HandshakeTypeClientHello -} - -func (ch ClientHelloBody) Marshal() ([]byte, error) { - return syntax.Marshal(clientHelloBodyInner{ - LegacyVersion: 0x0303, - Random: ch.Random, - LegacySessionID: []byte{}, - CipherSuites: ch.CipherSuites, - LegacyCompressionMethods: []byte{0}, - Extensions: ch.Extensions, - }) -} - -func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) { - var inner clientHelloBodyInner - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return 0, err - } - - // We are strict about these things because we only support 1.3 - if inner.LegacyVersion != 0x0303 { - return 0, fmt.Errorf("tls.clienthello: Incorrect version number") - } - - if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 { - return 0, fmt.Errorf("tls.clienthello: Invalid compression method") - } - - ch.Random = inner.Random - ch.CipherSuites = inner.CipherSuites - ch.Extensions = inner.Extensions - return read, nil -} - -// TODO: File a spec bug to clarify this -func (ch ClientHelloBody) Truncated() ([]byte, error) { - if len(ch.Extensions) == 0 { - return nil, fmt.Errorf("tls.clienthello.truncate: No extensions") - } - - pskExt := ch.Extensions[len(ch.Extensions)-1] - if pskExt.ExtensionType != ExtensionTypePreSharedKey { - return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK") - } - - chm, err := HandshakeMessageFromBody(&ch) - if err != nil { - return nil, err - } - chData := chm.Marshal() - - psk := PreSharedKeyExtension{ - HandshakeType: HandshakeTypeClientHello, - } - _, err = psk.Unmarshal(pskExt.ExtensionData) - if err != nil { - return nil, err - } - - // Marshal just the binders so that we know how much to truncate - binders := struct { - Binders []PSKBinderEntry `tls:"head=2,min=33"` - }{Binders: psk.Binders} - binderData, _ := syntax.Marshal(binders) - binderLen := len(binderData) - - chLen := len(chData) - return chData[:chLen-binderLen], nil -} - -// struct { -// ProtocolVersion server_version; -// CipherSuite cipher_suite; -// Extension extensions<2..2^16-1>; -// } HelloRetryRequest; -type HelloRetryRequestBody struct { - Version uint16 - CipherSuite CipherSuite - Extensions ExtensionList `tls:"head=2,min=2"` -} - -func (hrr HelloRetryRequestBody) Type() HandshakeType { - return HandshakeTypeHelloRetryRequest -} - -func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) { - return syntax.Marshal(hrr) -} - -func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, hrr) -} - -// struct { -// ProtocolVersion version; -// Random random; -// CipherSuite cipher_suite; -// Extension extensions<0..2^16-1>; -// } ServerHello; -type ServerHelloBody struct { - Version uint16 - Random [32]byte - CipherSuite CipherSuite - Extensions ExtensionList `tls:"head=2"` -} - -func (sh ServerHelloBody) Type() HandshakeType { - return HandshakeTypeServerHello -} - -func (sh ServerHelloBody) Marshal() ([]byte, error) { - return syntax.Marshal(sh) -} - -func (sh *ServerHelloBody) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, sh) -} - -// struct { -// opaque verify_data[verify_data_length]; -// } Finished; -// -// verifyDataLen is not a field in the TLS struct, but we add it here so -// that calling code can tell us how much data to expect when we marshal / -// unmarshal. (We could add this to the marshal/unmarshal methods, but let's -// try to keep the signature consistent for now.) -// -// For similar reasons, we don't use the `syntax` module here, because this -// struct doesn't map well to standard TLS presentation language concepts. -// -// TODO: File a spec bug -type FinishedBody struct { - VerifyDataLen int - VerifyData []byte -} - -func (fin FinishedBody) Type() HandshakeType { - return HandshakeTypeFinished -} - -func (fin FinishedBody) Marshal() ([]byte, error) { - if len(fin.VerifyData) != fin.VerifyDataLen { - return nil, fmt.Errorf("tls.finished: data length mismatch") - } - - body := make([]byte, len(fin.VerifyData)) - copy(body, fin.VerifyData) - return body, nil -} - -func (fin *FinishedBody) Unmarshal(data []byte) (int, error) { - if len(data) < fin.VerifyDataLen { - return 0, fmt.Errorf("tls.finished: Malformed finished; too short") - } - - fin.VerifyData = make([]byte, fin.VerifyDataLen) - copy(fin.VerifyData, data[:fin.VerifyDataLen]) - return fin.VerifyDataLen, nil -} - -// struct { -// Extension extensions<0..2^16-1>; -// } EncryptedExtensions; -// -// Marshal() and Unmarshal() are handled by ExtensionList -type EncryptedExtensionsBody struct { - Extensions ExtensionList `tls:"head=2"` -} - -func (ee EncryptedExtensionsBody) Type() HandshakeType { - return HandshakeTypeEncryptedExtensions -} - -func (ee EncryptedExtensionsBody) Marshal() ([]byte, error) { - return syntax.Marshal(ee) -} - -func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, ee) -} - -// opaque ASN1Cert<1..2^24-1>; -// -// struct { -// ASN1Cert cert_data; -// Extension extensions<0..2^16-1> -// } CertificateEntry; -// -// struct { -// opaque certificate_request_context<0..2^8-1>; -// CertificateEntry certificate_list<0..2^24-1>; -// } Certificate; -type CertificateEntry struct { - CertData *x509.Certificate - Extensions ExtensionList -} - -type CertificateBody struct { - CertificateRequestContext []byte - CertificateList []CertificateEntry -} - -type certificateEntryInner struct { - CertData []byte `tls:"head=3,min=1"` - Extensions ExtensionList `tls:"head=2"` -} - -type certificateBodyInner struct { - CertificateRequestContext []byte `tls:"head=1"` - CertificateList []certificateEntryInner `tls:"head=3"` -} - -func (c CertificateBody) Type() HandshakeType { - return HandshakeTypeCertificate -} - -func (c CertificateBody) Marshal() ([]byte, error) { - inner := certificateBodyInner{ - CertificateRequestContext: c.CertificateRequestContext, - CertificateList: make([]certificateEntryInner, len(c.CertificateList)), - } - - for i, entry := range c.CertificateList { - inner.CertificateList[i] = certificateEntryInner{ - CertData: entry.CertData.Raw, - Extensions: entry.Extensions, - } - } - - return syntax.Marshal(inner) -} - -func (c *CertificateBody) Unmarshal(data []byte) (int, error) { - inner := certificateBodyInner{} - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return read, err - } - - c.CertificateRequestContext = inner.CertificateRequestContext - c.CertificateList = make([]CertificateEntry, len(inner.CertificateList)) - - for i, entry := range inner.CertificateList { - c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData) - if err != nil { - return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err) - } - - c.CertificateList[i].Extensions = entry.Extensions - } - - return read, nil -} - -// struct { -// SignatureScheme algorithm; -// opaque signature<0..2^16-1>; -// } CertificateVerify; -type CertificateVerifyBody struct { - Algorithm SignatureScheme - Signature []byte `tls:"head=2"` -} - -func (cv CertificateVerifyBody) Type() HandshakeType { - return HandshakeTypeCertificateVerify -} - -func (cv CertificateVerifyBody) Marshal() ([]byte, error) { - return syntax.Marshal(cv) -} - -func (cv *CertificateVerifyBody) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, cv) -} - -func (cv *CertificateVerifyBody) EncodeSignatureInput(data []byte) []byte { - // TODO: Change context for client auth - // TODO: Put this in a const - const context = "TLS 1.3, server CertificateVerify" - sigInput := bytes.Repeat([]byte{0x20}, 64) - sigInput = append(sigInput, []byte(context)...) - sigInput = append(sigInput, []byte{0}...) - sigInput = append(sigInput, data...) - return sigInput -} - -func (cv *CertificateVerifyBody) Sign(privateKey crypto.Signer, handshakeHash []byte) (err error) { - sigInput := cv.EncodeSignatureInput(handshakeHash) - cv.Signature, err = sign(cv.Algorithm, privateKey, sigInput) - logf(logTypeHandshake, "Signed: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature) - return -} - -func (cv *CertificateVerifyBody) Verify(publicKey crypto.PublicKey, handshakeHash []byte) error { - sigInput := cv.EncodeSignatureInput(handshakeHash) - logf(logTypeHandshake, "About to verify: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature) - return verify(cv.Algorithm, publicKey, sigInput, cv.Signature) -} - -// struct { -// opaque certificate_request_context<0..2^8-1>; -// Extension extensions<2..2^16-1>; -// } CertificateRequest; -type CertificateRequestBody struct { - CertificateRequestContext []byte `tls:"head=1"` - Extensions ExtensionList `tls:"head=2"` -} - -func (cr CertificateRequestBody) Type() HandshakeType { - return HandshakeTypeCertificateRequest -} - -func (cr CertificateRequestBody) Marshal() ([]byte, error) { - return syntax.Marshal(cr) -} - -func (cr *CertificateRequestBody) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, cr) -} - -// struct { -// uint32 ticket_lifetime; -// uint32 ticket_age_add; -// opaque ticket_nonce<1..255>; -// opaque ticket<1..2^16-1>; -// Extension extensions<0..2^16-2>; -// } NewSessionTicket; -type NewSessionTicketBody struct { - TicketLifetime uint32 - TicketAgeAdd uint32 - TicketNonce []byte `tls:"head=1,min=1"` - Ticket []byte `tls:"head=2,min=1"` - Extensions ExtensionList `tls:"head=2"` -} - -const ticketNonceLen = 16 - -func NewSessionTicket(ticketLen int, ticketLifetime uint32) (*NewSessionTicketBody, error) { - buf := make([]byte, 4+ticketNonceLen+ticketLen) - _, err := prng.Read(buf) - if err != nil { - return nil, err - } - - tkt := &NewSessionTicketBody{ - TicketLifetime: ticketLifetime, - TicketAgeAdd: binary.BigEndian.Uint32(buf[:4]), - TicketNonce: buf[4 : 4+ticketNonceLen], - Ticket: buf[4+ticketNonceLen:], - } - - return tkt, err -} - -func (tkt NewSessionTicketBody) Type() HandshakeType { - return HandshakeTypeNewSessionTicket -} - -func (tkt NewSessionTicketBody) Marshal() ([]byte, error) { - return syntax.Marshal(tkt) -} - -func (tkt *NewSessionTicketBody) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, tkt) -} - -// enum { -// update_not_requested(0), update_requested(1), (255) -// } KeyUpdateRequest; -// -// struct { -// KeyUpdateRequest request_update; -// } KeyUpdate; -type KeyUpdateBody struct { - KeyUpdateRequest KeyUpdateRequest -} - -func (ku KeyUpdateBody) Type() HandshakeType { - return HandshakeTypeKeyUpdate -} - -func (ku KeyUpdateBody) Marshal() ([]byte, error) { - return syntax.Marshal(ku) -} - -func (ku *KeyUpdateBody) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, ku) -} - -// struct {} EndOfEarlyData; -type EndOfEarlyDataBody struct{} - -func (eoed EndOfEarlyDataBody) Type() HandshakeType { - return HandshakeTypeEndOfEarlyData -} - -func (eoed EndOfEarlyDataBody) Marshal() ([]byte, error) { - return []byte{}, nil -} - -func (eoed *EndOfEarlyDataBody) Unmarshal(data []byte) (int, error) { - return 0, nil -} diff --git a/vendor/github.com/bifurcation/mint/log.go b/vendor/github.com/bifurcation/mint/log.go deleted file mode 100644 index 2fba90de..00000000 --- a/vendor/github.com/bifurcation/mint/log.go +++ /dev/null @@ -1,55 +0,0 @@ -package mint - -import ( - "fmt" - "log" - "os" - "strings" -) - -// We use this environment variable to control logging. It should be a -// comma-separated list of log tags (see below) or "*" to enable all logging. -const logConfigVar = "MINT_LOG" - -// Pre-defined log types -const ( - logTypeCrypto = "crypto" - logTypeHandshake = "handshake" - logTypeNegotiation = "negotiation" - logTypeIO = "io" - logTypeFrameReader = "frame" - logTypeVerbose = "verbose" -) - -var ( - logFunction = log.Printf - logAll = false - logSettings = map[string]bool{} -) - -func init() { - parseLogEnv(os.Environ()) -} - -func parseLogEnv(env []string) { - for _, stmt := range env { - if strings.HasPrefix(stmt, logConfigVar+"=") { - val := stmt[len(logConfigVar)+1:] - - if val == "*" { - logAll = true - } else { - for _, t := range strings.Split(val, ",") { - logSettings[t] = true - } - } - } - } -} - -func logf(tag string, format string, args ...interface{}) { - if logAll || logSettings[tag] { - fullFormat := fmt.Sprintf("[%s] %s", tag, format) - logFunction(fullFormat, args...) - } -} diff --git a/vendor/github.com/bifurcation/mint/negotiation.go b/vendor/github.com/bifurcation/mint/negotiation.go deleted file mode 100644 index f4ead72e..00000000 --- a/vendor/github.com/bifurcation/mint/negotiation.go +++ /dev/null @@ -1,217 +0,0 @@ -package mint - -import ( - "bytes" - "encoding/hex" - "fmt" - "time" -) - -func VersionNegotiation(offered, supported []uint16) (bool, uint16) { - for _, offeredVersion := range offered { - for _, supportedVersion := range supported { - logf(logTypeHandshake, "[server] version offered by client [%04x] <> [%04x]", offeredVersion, supportedVersion) - if offeredVersion == supportedVersion { - // XXX: Should probably be highest supported version, but for now, we - // only support one version, so it doesn't really matter. - return true, offeredVersion - } - } - } - - return false, 0 -} - -func DHNegotiation(keyShares []KeyShareEntry, groups []NamedGroup) (bool, NamedGroup, []byte, []byte) { - for _, share := range keyShares { - for _, group := range groups { - if group != share.Group { - continue - } - - pub, priv, err := newKeyShare(share.Group) - if err != nil { - // If we encounter an error, just keep looking - continue - } - - dhSecret, err := keyAgreement(share.Group, share.KeyExchange, priv) - if err != nil { - // If we encounter an error, just keep looking - continue - } - - return true, group, pub, dhSecret - } - } - - return false, 0, nil, nil -} - -const ( - ticketAgeTolerance uint32 = 5 * 1000 // five seconds in milliseconds -) - -func PSKNegotiation(identities []PSKIdentity, binders []PSKBinderEntry, context []byte, psks PreSharedKeyCache) (bool, int, *PreSharedKey, CipherSuiteParams, error) { - logf(logTypeNegotiation, "Negotiating PSK offered=[%d] supported=[%d]", len(identities), psks.Size()) - for i, id := range identities { - identityHex := hex.EncodeToString(id.Identity) - - psk, ok := psks.Get(identityHex) - if !ok { - logf(logTypeNegotiation, "No PSK for identity %x", identityHex) - continue - } - - // For resumption, make sure the ticket age is correct - if psk.IsResumption { - extTicketAge := id.ObfuscatedTicketAge - psk.TicketAgeAdd - knownTicketAge := uint32(time.Since(psk.ReceivedAt) / time.Millisecond) - ticketAgeDelta := knownTicketAge - extTicketAge - if knownTicketAge < extTicketAge { - ticketAgeDelta = extTicketAge - knownTicketAge - } - if ticketAgeDelta > ticketAgeTolerance { - logf(logTypeNegotiation, "WARNING potential replay [%x]", psk.Identity) - logf(logTypeNegotiation, "Ticket age exceeds tolerance |%d - %d| = [%d] > [%d]", - extTicketAge, knownTicketAge, ticketAgeDelta, ticketAgeTolerance) - return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("WARNING Potential replay for identity %x", psk.Identity) - } - } - - params, ok := cipherSuiteMap[psk.CipherSuite] - if !ok { - err := fmt.Errorf("tls.cryptoinit: Unsupported ciphersuite from PSK [%04x]", psk.CipherSuite) - return false, 0, nil, CipherSuiteParams{}, err - } - - // Compute binder - binderLabel := labelExternalBinder - if psk.IsResumption { - binderLabel = labelResumptionBinder - } - - h0 := params.Hash.New().Sum(nil) - zero := bytes.Repeat([]byte{0}, params.Hash.Size()) - earlySecret := HkdfExtract(params.Hash, zero, psk.Key) - binderKey := deriveSecret(params, earlySecret, binderLabel, h0) - - // context = ClientHello[truncated] - // context = ClientHello1 + HelloRetryRequest + ClientHello2[truncated] - ctxHash := params.Hash.New() - ctxHash.Write(context) - - binder := computeFinishedData(params, binderKey, ctxHash.Sum(nil)) - if !bytes.Equal(binder, binders[i].Binder) { - logf(logTypeNegotiation, "Binder check failed for identity %x; [%x] != [%x]", psk.Identity, binder, binders[i].Binder) - return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("Binder check failed identity %x", psk.Identity) - } - - logf(logTypeNegotiation, "Using PSK with identity %x", psk.Identity) - return true, i, &psk, params, nil - } - - logf(logTypeNegotiation, "Failed to find a usable PSK") - return false, 0, nil, CipherSuiteParams{}, nil -} - -func PSKModeNegotiation(canDoDH, canDoPSK bool, modes []PSKKeyExchangeMode) (bool, bool) { - logf(logTypeNegotiation, "Negotiating PSK modes [%v] [%v] [%+v]", canDoDH, canDoPSK, modes) - dhAllowed := false - dhRequired := true - for _, mode := range modes { - dhAllowed = dhAllowed || (mode == PSKModeDHEKE) - dhRequired = dhRequired && (mode == PSKModeDHEKE) - } - - // Use PSK if we can meet DH requirement and modes were provided - usingPSK := canDoPSK && (!dhRequired || canDoDH) && (len(modes) > 0) - - // Use DH if allowed - usingDH := canDoDH && (dhAllowed || !usingPSK) - - logf(logTypeNegotiation, "Results of PSK mode negotiation: usingDH=[%v] usingPSK=[%v]", usingDH, usingPSK) - return usingDH, usingPSK -} - -func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme, certs []*Certificate) (*Certificate, SignatureScheme, error) { - // Select for server name if provided - candidates := certs - if serverName != nil { - candidatesByName := []*Certificate{} - for _, cert := range certs { - for _, name := range cert.Chain[0].DNSNames { - if len(*serverName) > 0 && name == *serverName { - candidatesByName = append(candidatesByName, cert) - } - } - } - - if len(candidatesByName) == 0 { - return nil, 0, fmt.Errorf("No certificates available for server name") - } - - candidates = candidatesByName - } - - // Select for signature scheme - for _, cert := range candidates { - for _, scheme := range signatureSchemes { - if !schemeValidForKey(scheme, cert.PrivateKey) { - continue - } - - return cert, scheme, nil - } - } - - return nil, 0, fmt.Errorf("No certificates compatible with signature schemes") -} - -func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool { - usingEarlyData := gotEarlyData && usingPSK && allowEarlyData - logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData) - return usingEarlyData -} - -func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) { - for _, s1 := range offered { - if psk != nil { - if s1 == psk.CipherSuite { - return s1, nil - } - continue - } - - for _, s2 := range supported { - if s1 == s2 { - return s1, nil - } - } - } - - return 0, fmt.Errorf("No overlap between offered and supproted ciphersuites (psk? [%v])", psk != nil) -} - -func ALPNNegotiation(psk *PreSharedKey, offered, supported []string) (string, error) { - for _, p1 := range offered { - if psk != nil { - if p1 != psk.NextProto { - continue - } - } - - for _, p2 := range supported { - if p1 == p2 { - return p1, nil - } - } - } - - // If the client offers ALPN on resumption, it must match the earlier one - var err error - if psk != nil && psk.IsResumption && (len(offered) > 0) { - err = fmt.Errorf("ALPN for PSK not provided") - } - return "", err -} diff --git a/vendor/github.com/bifurcation/mint/record-layer.go b/vendor/github.com/bifurcation/mint/record-layer.go deleted file mode 100644 index bcef6136..00000000 --- a/vendor/github.com/bifurcation/mint/record-layer.go +++ /dev/null @@ -1,296 +0,0 @@ -package mint - -import ( - "bytes" - "crypto/cipher" - "fmt" - "io" - "sync" -) - -const ( - sequenceNumberLen = 8 // sequence number length - recordHeaderLen = 5 // record header length - maxFragmentLen = 1 << 14 // max number of bytes in a record -) - -type DecryptError string - -func (err DecryptError) Error() string { - return string(err) -} - -// struct { -// ContentType type; -// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */ -// uint16 length; -// opaque fragment[TLSPlaintext.length]; -// } TLSPlaintext; -type TLSPlaintext struct { - // Omitted: record_version (static) - // Omitted: length (computed from fragment) - contentType RecordType - fragment []byte -} - -type RecordLayer struct { - sync.Mutex - - conn io.ReadWriter // The underlying connection - frame *frameReader // The buffered frame reader - nextData []byte // The next record to send - cachedRecord *TLSPlaintext // Last record read, cached to enable "peek" - cachedError error // Error on the last record read - - ivLength int // Length of the seq and nonce fields - seq []byte // Zero-padded sequence number - nonce []byte // Buffer for per-record nonces - cipher cipher.AEAD // AEAD cipher -} - -type recordLayerFrameDetails struct{} - -func (d recordLayerFrameDetails) headerLen() int { - return recordHeaderLen -} - -func (d recordLayerFrameDetails) defaultReadLen() int { - return recordHeaderLen + maxFragmentLen -} - -func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) { - return (int(hdr[3]) << 8) | int(hdr[4]), nil -} - -func NewRecordLayer(conn io.ReadWriter) *RecordLayer { - r := RecordLayer{} - r.conn = conn - r.frame = newFrameReader(recordLayerFrameDetails{}) - r.ivLength = 0 - return &r -} - -func (r *RecordLayer) Rekey(cipher aeadFactory, key []byte, iv []byte) error { - var err error - r.cipher, err = cipher(key) - if err != nil { - return err - } - - r.ivLength = len(iv) - r.seq = bytes.Repeat([]byte{0}, r.ivLength) - r.nonce = make([]byte, r.ivLength) - copy(r.nonce, iv) - return nil -} - -func (r *RecordLayer) incrementSequenceNumber() { - if r.ivLength == 0 { - return - } - - for i := r.ivLength - 1; i > r.ivLength-sequenceNumberLen; i-- { - r.seq[i]++ - r.nonce[i] ^= (r.seq[i] - 1) ^ r.seq[i] - if r.seq[i] != 0 { - return - } - } - - // Not allowed to let sequence number wrap. - // Instead, must renegotiate before it does. - // Not likely enough to bother. - panic("TLS: sequence number wraparound") -} - -func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext { - // Expand the fragment to hold contentType, padding, and overhead - originalLen := len(pt.fragment) - plaintextLen := originalLen + 1 + padLen - ciphertextLen := plaintextLen + r.cipher.Overhead() - - // Assemble the revised plaintext - out := &TLSPlaintext{ - contentType: RecordTypeApplicationData, - fragment: make([]byte, ciphertextLen), - } - copy(out.fragment, pt.fragment) - out.fragment[originalLen] = byte(pt.contentType) - for i := 1; i <= padLen; i++ { - out.fragment[originalLen+i] = 0 - } - - // Encrypt the fragment - payload := out.fragment[:plaintextLen] - r.cipher.Seal(payload[:0], r.nonce, payload, nil) - return out -} - -func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) { - if len(pt.fragment) < r.cipher.Overhead() { - msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.Overhead()) - return nil, 0, DecryptError(msg) - } - - decryptLen := len(pt.fragment) - r.cipher.Overhead() - out := &TLSPlaintext{ - contentType: pt.contentType, - fragment: make([]byte, decryptLen), - } - - // Decrypt - _, err := r.cipher.Open(out.fragment[:0], r.nonce, pt.fragment, nil) - if err != nil { - return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed") - } - - // Find the padding boundary - padLen := 0 - for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ { - } - - // Transfer the content type - newLen := decryptLen - padLen - 1 - out.contentType = RecordType(out.fragment[newLen]) - - // Truncate the message to remove contentType, padding, overhead - out.fragment = out.fragment[:newLen] - return out, padLen, nil -} - -func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) { - var pt *TLSPlaintext - var err error - - for { - pt, err = r.nextRecord() - if err == nil { - break - } - if !block || err != WouldBlock { - return 0, err - } - } - return pt.contentType, nil -} - -func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) { - pt, err := r.nextRecord() - - // Consume the cached record if there was one - r.cachedRecord = nil - r.cachedError = nil - - return pt, err -} - -func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { - if r.cachedRecord != nil { - logf(logTypeIO, "Returning cached record") - return r.cachedRecord, r.cachedError - } - - // Loop until one of three things happens: - // - // 1. We get a frame - // 2. We try to read off the socket and get nothing, in which case - // return WouldBlock - // 3. We get an error. - err := WouldBlock - var header, body []byte - - for err != nil { - if r.frame.needed() > 0 { - buf := make([]byte, recordHeaderLen+maxFragmentLen) - n, err := r.conn.Read(buf) - if err != nil { - logf(logTypeIO, "Error reading, %v", err) - return nil, err - } - - if n == 0 { - return nil, WouldBlock - } - - logf(logTypeIO, "Read %v bytes", n) - - buf = buf[:n] - r.frame.addChunk(buf) - } - - header, body, err = r.frame.process() - // Loop around on WouldBlock to see if some - // data is now available. - if err != nil && err != WouldBlock { - return nil, err - } - } - - pt := &TLSPlaintext{} - // Validate content type - switch RecordType(header[0]) { - default: - return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0]) - case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData: - pt.contentType = RecordType(header[0]) - } - - // Validate version - if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) { - return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2]) - } - - // Validate size < max - size := (int(header[3]) << 8) + int(header[4]) - if size > maxFragmentLen+256 { - return nil, fmt.Errorf("tls.record: Ciphertext size too big") - } - - pt.fragment = make([]byte, size) - copy(pt.fragment, body) - - // Attempt to decrypt fragment - if r.cipher != nil { - pt, _, err = r.decrypt(pt) - if err != nil { - return nil, err - } - } - - // Check that plaintext length is not too long - if len(pt.fragment) > maxFragmentLen { - return nil, fmt.Errorf("tls.record: Plaintext size too big") - } - - logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment) - - r.cachedRecord = pt - r.incrementSequenceNumber() - return pt, nil -} - -func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error { - return r.WriteRecordWithPadding(pt, 0) -} - -func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error { - if r.cipher != nil { - pt = r.encrypt(pt, padLen) - } else if padLen > 0 { - return fmt.Errorf("tls.record: Padding can only be done on encrypted records") - } - - if len(pt.fragment) > maxFragmentLen { - return fmt.Errorf("tls.record: Record size too big") - } - - length := len(pt.fragment) - header := []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)} - record := append(header, pt.fragment...) - - logf(logTypeIO, "RecordLayer.WriteRecord [%d] [%x]", pt.contentType, pt.fragment) - - r.incrementSequenceNumber() - _, err := r.conn.Write(record) - return err -} diff --git a/vendor/github.com/bifurcation/mint/server-state-machine.go b/vendor/github.com/bifurcation/mint/server-state-machine.go deleted file mode 100644 index 60df9b64..00000000 --- a/vendor/github.com/bifurcation/mint/server-state-machine.go +++ /dev/null @@ -1,898 +0,0 @@ -package mint - -import ( - "bytes" - "hash" - "reflect" -) - -// Server State Machine -// -// START <-----+ -// Recv ClientHello | | Send HelloRetryRequest -// v | -// RECVD_CH ----+ -// | Select parameters -// | Send ServerHello -// v -// NEGOTIATED -// | Send EncryptedExtensions -// | [Send CertificateRequest] -// Can send | [Send Certificate + CertificateVerify] -// app data --> | Send Finished -// after +--------+--------+ -// here No 0-RTT | | 0-RTT -// | v -// | WAIT_EOED <---+ -// | Recv | | | Recv -// | EndOfEarlyData | | | early data -// | | +-----+ -// +> WAIT_FLIGHT2 <-+ -// | -// +--------+--------+ -// No auth | | Client auth -// | | -// | v -// | WAIT_CERT -// | Recv | | Recv Certificate -// | empty | v -// | Certificate | WAIT_CV -// | | | Recv -// | v | CertificateVerify -// +-> WAIT_FINISHED <---+ -// | Recv Finished -// v -// CONNECTED -// -// NB: Not using state RECVD_CH -// -// State Instructions -// START {} -// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)] -// WAIT_EOED RekeyIn; -// WAIT_FLIGHT2 {} -// WAIT_CERT_CR {} -// WAIT_CERT {} -// WAIT_CV {} -// WAIT_FINISHED RekeyIn; RekeyOut; -// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) - -type ServerStateStart struct { - Caps Capabilities - conn *Conn - - cookieSent bool - firstClientHello *HandshakeMessage - helloRetryRequest *HandshakeMessage -} - -func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm == nil || hm.msgType != HandshakeTypeClientHello { - logf(logTypeHandshake, "[ServerStateStart] unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - ch := &ClientHelloBody{} - _, err := ch.Unmarshal(hm.body) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err) - return nil, nil, AlertDecodeError - } - - clientHello := hm - connParams := ConnectionParameters{} - - supportedVersions := new(SupportedVersionsExtension) - serverName := new(ServerNameExtension) - supportedGroups := new(SupportedGroupsExtension) - signatureAlgorithms := new(SignatureAlgorithmsExtension) - clientKeyShares := &KeyShareExtension{HandshakeType: HandshakeTypeClientHello} - clientPSK := &PreSharedKeyExtension{HandshakeType: HandshakeTypeClientHello} - clientEarlyData := &EarlyDataExtension{} - clientALPN := new(ALPNExtension) - clientPSKModes := new(PSKKeyExchangeModesExtension) - clientCookie := new(CookieExtension) - - // Handle external extensions. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error running external extension handler [%v]", err) - return nil, nil, AlertInternalError - } - } - - gotSupportedVersions := ch.Extensions.Find(supportedVersions) - gotServerName := ch.Extensions.Find(serverName) - gotSupportedGroups := ch.Extensions.Find(supportedGroups) - gotSignatureAlgorithms := ch.Extensions.Find(signatureAlgorithms) - gotEarlyData := ch.Extensions.Find(clientEarlyData) - ch.Extensions.Find(clientKeyShares) - ch.Extensions.Find(clientPSK) - ch.Extensions.Find(clientALPN) - ch.Extensions.Find(clientPSKModes) - ch.Extensions.Find(clientCookie) - - if gotServerName { - connParams.ServerName = string(*serverName) - } - - // If the client didn't send supportedVersions or doesn't support 1.3, - // then we're done here. - if !gotSupportedVersions { - logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions") - return nil, nil, AlertProtocolVersion - } - versionOK, _ := VersionNegotiation(supportedVersions.Versions, []uint16{supportedVersion}) - if !versionOK { - logf(logTypeHandshake, "[ServerStateStart] Client does not support the same version") - return nil, nil, AlertProtocolVersion - } - - if state.Caps.RequireCookie && state.cookieSent && !state.Caps.CookieHandler.Validate(state.conn, clientCookie.Cookie) { - logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch") - return nil, nil, AlertAccessDenied - } - - // Figure out if we can do DH - canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups) - - // Figure out if we can do PSK - canDoPSK := false - var selectedPSK int - var psk *PreSharedKey - var params CipherSuiteParams - if len(clientPSK.Identities) > 0 { - contextBase := []byte{} - if state.helloRetryRequest != nil { - chBytes := state.firstClientHello.Marshal() - hrrBytes := state.helloRetryRequest.Marshal() - contextBase = append(chBytes, hrrBytes...) - } - - chTrunc, err := ch.Truncated() - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error computing truncated ClientHello [%v]", err) - return nil, nil, AlertDecodeError - } - - context := append(contextBase, chTrunc...) - - canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Caps.PSKs) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err) - return nil, nil, AlertInternalError - } - } - - // Figure out if we actually should do DH / PSK - connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes) - - // Select a ciphersuite - connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err) - return nil, nil, AlertHandshakeFailure - } - - // Send a cookie if required - // NB: Need to do this here because it's after ciphersuite selection, which - // has to be after PSK selection. - // XXX: Doing this statefully for now, could be stateless - var cookieData []byte - if state.Caps.RequireCookie && !state.cookieSent { - var err error - cookieData, err = state.Caps.CookieHandler.Generate(state.conn) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err) - return nil, nil, AlertInternalError - } - } - if cookieData != nil { - // Ignoring errors because everything here is newly constructed, so there - // shouldn't be marshal errors - hrr := &HelloRetryRequestBody{ - Version: supportedVersion, - CipherSuite: connParams.CipherSuite, - } - hrr.Extensions.Add(&CookieExtension{Cookie: cookieData}) - - // Run the external extension handler. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err) - return nil, nil, AlertInternalError - } - } - - helloRetryRequest, err := HandshakeMessageFromBody(hrr) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error marshaling HRR [%v]", err) - return nil, nil, AlertInternalError - } - - params := cipherSuiteMap[connParams.CipherSuite] - h := params.Hash.New() - h.Write(clientHello.Marshal()) - firstClientHello := &HandshakeMessage{ - msgType: HandshakeTypeMessageHash, - body: h.Sum(nil), - } - - nextState := ServerStateStart{ - Caps: state.Caps, - conn: state.conn, - cookieSent: true, - firstClientHello: firstClientHello, - helloRetryRequest: helloRetryRequest, - } - toSend := []HandshakeAction{SendHandshakeMessage{helloRetryRequest}} - logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateStart]") - return nextState, toSend, AlertNoAlert - } - - // If we've got no entropy to make keys from, fail - if !connParams.UsingDH && !connParams.UsingPSK { - logf(logTypeHandshake, "[ServerStateStart] Neither DH nor PSK negotiated") - return nil, nil, AlertHandshakeFailure - } - - var pskSecret []byte - var cert *Certificate - var certScheme SignatureScheme - if connParams.UsingPSK { - pskSecret = psk.Key - } else { - psk = nil - - // If we're not using a PSK mode, then we need to have certain extensions - if !gotServerName || !gotSupportedGroups || !gotSignatureAlgorithms { - logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v %v %v)", - gotServerName, gotSupportedGroups, gotSignatureAlgorithms) - return nil, nil, AlertMissingExtension - } - - // Select a certificate - name := string(*serverName) - var err error - cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Caps.Certificates) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] No appropriate certificate found [%v]", err) - return nil, nil, AlertAccessDenied - } - } - - if !connParams.UsingDH { - dhSecret = nil - } - - // Figure out if we're going to do early data - var clientEarlyTrafficSecret []byte - connParams.ClientSendingEarlyData = gotEarlyData - connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, gotEarlyData, state.Caps.AllowEarlyData) - if connParams.UsingEarlyData { - - h := params.Hash.New() - h.Write(clientHello.Marshal()) - chHash := h.Sum(nil) - - zero := bytes.Repeat([]byte{0}, params.Hash.Size()) - earlySecret := HkdfExtract(params.Hash, zero, pskSecret) - clientEarlyTrafficSecret = deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) - } - - // Select a next protocol - connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Caps.NextProtos) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] No common application-layer protocol found [%v]", err) - return nil, nil, AlertNoApplicationProtocol - } - - logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]") - return ServerStateNegotiated{ - Caps: state.Caps, - Params: connParams, - - dhGroup: dhGroup, - dhPublic: dhPublic, - dhSecret: dhSecret, - pskSecret: pskSecret, - selectedPSK: selectedPSK, - cert: cert, - certScheme: certScheme, - clientEarlyTrafficSecret: clientEarlyTrafficSecret, - - firstClientHello: state.firstClientHello, - helloRetryRequest: state.helloRetryRequest, - clientHello: clientHello, - }.Next(nil) -} - -type ServerStateNegotiated struct { - Caps Capabilities - Params ConnectionParameters - - dhGroup NamedGroup - dhPublic []byte - dhSecret []byte - pskSecret []byte - clientEarlyTrafficSecret []byte - selectedPSK int - cert *Certificate - certScheme SignatureScheme - - firstClientHello *HandshakeMessage - helloRetryRequest *HandshakeMessage - clientHello *HandshakeMessage -} - -func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - // Create the ServerHello - sh := &ServerHelloBody{ - Version: supportedVersion, - CipherSuite: state.Params.CipherSuite, - } - _, err := prng.Read(sh.Random[:]) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err) - return nil, nil, AlertInternalError - } - if state.Params.UsingDH { - logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension") - err = sh.Extensions.Add(&KeyShareExtension{ - HandshakeType: HandshakeTypeServerHello, - Shares: []KeyShareEntry{{Group: state.dhGroup, KeyExchange: state.dhPublic}}, - }) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error adding key_shares extension [%v]", err) - return nil, nil, AlertInternalError - } - } - if state.Params.UsingPSK { - logf(logTypeHandshake, "[ServerStateNegotiated] sending PSK extension") - err = sh.Extensions.Add(&PreSharedKeyExtension{ - HandshakeType: HandshakeTypeServerHello, - SelectedIdentity: uint16(state.selectedPSK), - }) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error adding PSK extension [%v]", err) - return nil, nil, AlertInternalError - } - } - - // Run the external extension handler. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err) - return nil, nil, AlertInternalError - } - } - - serverHello, err := HandshakeMessageFromBody(sh) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling ServerHello [%v]", err) - return nil, nil, AlertInternalError - } - - // Look up crypto params - params, ok := cipherSuiteMap[sh.CipherSuite] - if !ok { - logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", sh.CipherSuite) - return nil, nil, AlertHandshakeFailure - } - - // Start up the handshake hash - handshakeHash := params.Hash.New() - handshakeHash.Write(state.firstClientHello.Marshal()) - handshakeHash.Write(state.helloRetryRequest.Marshal()) - handshakeHash.Write(state.clientHello.Marshal()) - handshakeHash.Write(serverHello.Marshal()) - - // Compute handshake secrets - zero := bytes.Repeat([]byte{0}, params.Hash.Size()) - - var earlySecret []byte - if state.Params.UsingPSK { - earlySecret = HkdfExtract(params.Hash, zero, state.pskSecret) - } else { - earlySecret = HkdfExtract(params.Hash, zero, zero) - } - - if state.dhSecret == nil { - state.dhSecret = zero - } - - h0 := params.Hash.New().Sum(nil) - h2 := handshakeHash.Sum(nil) - preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0) - handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, state.dhSecret) - clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2) - serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2) - preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0) - masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero) - - logf(logTypeCrypto, "early secret (init!): [%d] %x", len(earlySecret), earlySecret) - logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret) - logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret) - logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret) - logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) - - clientHandshakeKeys := makeTrafficKeys(params, clientHandshakeTrafficSecret) - serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) - - // Send an EncryptedExtensions message (even if it's empty) - eeList := ExtensionList{} - if state.Params.NextProto != "" { - logf(logTypeHandshake, "[server] sending ALPN extension") - err = eeList.Add(&ALPNExtension{Protocols: []string{state.Params.NextProto}}) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error adding ALPN to EncryptedExtensions [%v]", err) - return nil, nil, AlertInternalError - } - } - if state.Params.UsingEarlyData { - logf(logTypeHandshake, "[server] sending EDI extension") - err = eeList.Add(&EarlyDataExtension{}) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error adding EDI to EncryptedExtensions [%v]", err) - return nil, nil, AlertInternalError - } - } - ee := &EncryptedExtensionsBody{eeList} - - // Run the external extension handler. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err) - return nil, nil, AlertInternalError - } - } - - eem, err := HandshakeMessageFromBody(ee) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling EncryptedExtensions [%v]", err) - return nil, nil, AlertInternalError - } - - handshakeHash.Write(eem.Marshal()) - - toSend := []HandshakeAction{ - SendHandshakeMessage{serverHello}, - RekeyOut{Label: "handshake", KeySet: serverHandshakeKeys}, - SendHandshakeMessage{eem}, - } - - // Authenticate with a certificate if required - if !state.Params.UsingPSK { - // Send a CertificateRequest message if we want client auth - if state.Caps.RequireClientAuth { - state.Params.UsingClientAuth = true - - // XXX: We don't support sending any constraints besides a list of - // supported signature algorithms - cr := &CertificateRequestBody{} - schemes := &SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes} - err := cr.Extensions.Add(schemes) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported schemes to CertificateRequest [%v]", err) - return nil, nil, AlertInternalError - } - - crm, err := HandshakeMessageFromBody(cr) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateRequest [%v]", err) - return nil, nil, AlertInternalError - } - //TODO state.state.serverCertificateRequest = cr - - toSend = append(toSend, SendHandshakeMessage{crm}) - handshakeHash.Write(crm.Marshal()) - } - - // Create and send Certificate, CertificateVerify - certificate := &CertificateBody{ - CertificateList: make([]CertificateEntry, len(state.cert.Chain)), - } - for i, entry := range state.cert.Chain { - certificate.CertificateList[i] = CertificateEntry{CertData: entry} - } - certm, err := HandshakeMessageFromBody(certificate) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err) - return nil, nil, AlertInternalError - } - - toSend = append(toSend, SendHandshakeMessage{certm}) - handshakeHash.Write(certm.Marshal()) - - certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme} - logf(logTypeHandshake, "Creating CertVerify: %04x %v", state.certScheme, params.Hash) - - hcv := handshakeHash.Sum(nil) - logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) - - err = certificateVerify.Sign(state.cert.PrivateKey, hcv) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error signing CertificateVerify [%v]", err) - return nil, nil, AlertInternalError - } - certvm, err := HandshakeMessageFromBody(certificateVerify) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateVerify [%v]", err) - return nil, nil, AlertInternalError - } - - toSend = append(toSend, SendHandshakeMessage{certvm}) - handshakeHash.Write(certvm.Marshal()) - } - - // Compute secrets resulting from the server's first flight - h3 := handshakeHash.Sum(nil) - logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3) - logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3) - - serverFinishedData := computeFinishedData(params, serverHandshakeTrafficSecret, h3) - logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData) - - // Assemble the Finished message - fin := &FinishedBody{ - VerifyDataLen: len(serverFinishedData), - VerifyData: serverFinishedData, - } - finm, _ := HandshakeMessageFromBody(fin) - - toSend = append(toSend, SendHandshakeMessage{finm}) - handshakeHash.Write(finm.Marshal()) - - // Compute traffic secrets - h4 := handshakeHash.Sum(nil) - logf(logTypeCrypto, "handshake hash 4 [%d] %x", len(h4), h4) - logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h4), h4) - - clientTrafficSecret := deriveSecret(params, masterSecret, labelClientApplicationTrafficSecret, h4) - serverTrafficSecret := deriveSecret(params, masterSecret, labelServerApplicationTrafficSecret, h4) - logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret) - logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret) - - serverTrafficKeys := makeTrafficKeys(params, serverTrafficSecret) - toSend = append(toSend, RekeyOut{Label: "application", KeySet: serverTrafficKeys}) - - exporterSecret := deriveSecret(params, masterSecret, labelExporterSecret, h4) - logf(logTypeCrypto, "server exporter secret: [%d] %x", len(exporterSecret), exporterSecret) - - if state.Params.UsingEarlyData { - clientEarlyTrafficKeys := makeTrafficKeys(params, state.clientEarlyTrafficSecret) - - logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitEOED]") - nextState := ServerStateWaitEOED{ - AuthCertificate: state.Caps.AuthCertificate, - Params: state.Params, - cryptoParams: params, - handshakeHash: handshakeHash, - masterSecret: masterSecret, - clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, - clientTrafficSecret: clientTrafficSecret, - serverTrafficSecret: serverTrafficSecret, - exporterSecret: exporterSecret, - } - toSend = append(toSend, []HandshakeAction{ - RekeyIn{Label: "early", KeySet: clientEarlyTrafficKeys}, - ReadEarlyData{}, - }...) - return nextState, toSend, AlertNoAlert - } - - logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]") - toSend = append(toSend, []HandshakeAction{ - RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys}, - ReadPastEarlyData{}, - }...) - waitFlight2 := ServerStateWaitFlight2{ - AuthCertificate: state.Caps.AuthCertificate, - Params: state.Params, - cryptoParams: params, - handshakeHash: handshakeHash, - masterSecret: masterSecret, - clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, - clientTrafficSecret: clientTrafficSecret, - serverTrafficSecret: serverTrafficSecret, - exporterSecret: exporterSecret, - } - nextState, moreToSend, alert := waitFlight2.Next(nil) - toSend = append(toSend, moreToSend...) - return nextState, toSend, alert -} - -type ServerStateWaitEOED struct { - AuthCertificate func(chain []CertificateEntry) error - Params ConnectionParameters - cryptoParams CipherSuiteParams - masterSecret []byte - clientHandshakeTrafficSecret []byte - handshakeHash hash.Hash - clientTrafficSecret []byte - serverTrafficSecret []byte - exporterSecret []byte -} - -func (state ServerStateWaitEOED) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm == nil || hm.msgType != HandshakeTypeEndOfEarlyData { - logf(logTypeHandshake, "[ServerStateWaitEOED] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - if len(hm.body) > 0 { - logf(logTypeHandshake, "[ServerStateWaitEOED] Error decoding message [len > 0]") - return nil, nil, AlertDecodeError - } - - state.handshakeHash.Write(hm.Marshal()) - - clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret) - - logf(logTypeHandshake, "[ServerStateWaitEOED] -> [ServerStateWaitFlight2]") - toSend := []HandshakeAction{ - RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys}, - } - waitFlight2 := ServerStateWaitFlight2{ - AuthCertificate: state.AuthCertificate, - Params: state.Params, - cryptoParams: state.cryptoParams, - handshakeHash: state.handshakeHash, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - clientTrafficSecret: state.clientTrafficSecret, - serverTrafficSecret: state.serverTrafficSecret, - exporterSecret: state.exporterSecret, - } - nextState, moreToSend, alert := waitFlight2.Next(nil) - toSend = append(toSend, moreToSend...) - return nextState, toSend, alert -} - -type ServerStateWaitFlight2 struct { - AuthCertificate func(chain []CertificateEntry) error - Params ConnectionParameters - cryptoParams CipherSuiteParams - masterSecret []byte - clientHandshakeTrafficSecret []byte - handshakeHash hash.Hash - clientTrafficSecret []byte - serverTrafficSecret []byte - exporterSecret []byte -} - -func (state ServerStateWaitFlight2) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm != nil { - logf(logTypeHandshake, "[ServerStateWaitFlight2] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - if state.Params.UsingClientAuth { - logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitCert]") - nextState := ServerStateWaitCert{ - AuthCertificate: state.AuthCertificate, - Params: state.Params, - cryptoParams: state.cryptoParams, - handshakeHash: state.handshakeHash, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - clientTrafficSecret: state.clientTrafficSecret, - serverTrafficSecret: state.serverTrafficSecret, - exporterSecret: state.exporterSecret, - } - return nextState, nil, AlertNoAlert - } - - logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitFinished]") - nextState := ServerStateWaitFinished{ - Params: state.Params, - cryptoParams: state.cryptoParams, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - handshakeHash: state.handshakeHash, - clientTrafficSecret: state.clientTrafficSecret, - serverTrafficSecret: state.serverTrafficSecret, - exporterSecret: state.exporterSecret, - } - return nextState, nil, AlertNoAlert -} - -type ServerStateWaitCert struct { - AuthCertificate func(chain []CertificateEntry) error - Params ConnectionParameters - cryptoParams CipherSuiteParams - masterSecret []byte - clientHandshakeTrafficSecret []byte - handshakeHash hash.Hash - clientTrafficSecret []byte - serverTrafficSecret []byte - exporterSecret []byte -} - -func (state ServerStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm == nil || hm.msgType != HandshakeTypeCertificate { - logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - cert := &CertificateBody{} - _, err := cert.Unmarshal(hm.body) - if err != nil { - logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message") - return nil, nil, AlertDecodeError - } - - state.handshakeHash.Write(hm.Marshal()) - - if len(cert.CertificateList) == 0 { - logf(logTypeHandshake, "[ServerStateWaitCert] WARNING client did not provide a certificate") - - logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitFinished]") - nextState := ServerStateWaitFinished{ - Params: state.Params, - cryptoParams: state.cryptoParams, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - handshakeHash: state.handshakeHash, - clientTrafficSecret: state.clientTrafficSecret, - serverTrafficSecret: state.serverTrafficSecret, - exporterSecret: state.exporterSecret, - } - return nextState, nil, AlertNoAlert - } - - logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitCV]") - nextState := ServerStateWaitCV{ - AuthCertificate: state.AuthCertificate, - Params: state.Params, - cryptoParams: state.cryptoParams, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - handshakeHash: state.handshakeHash, - clientTrafficSecret: state.clientTrafficSecret, - serverTrafficSecret: state.serverTrafficSecret, - clientCertificate: cert, - exporterSecret: state.exporterSecret, - } - return nextState, nil, AlertNoAlert -} - -type ServerStateWaitCV struct { - AuthCertificate func(chain []CertificateEntry) error - Params ConnectionParameters - cryptoParams CipherSuiteParams - - masterSecret []byte - clientHandshakeTrafficSecret []byte - - handshakeHash hash.Hash - clientTrafficSecret []byte - serverTrafficSecret []byte - exporterSecret []byte - - clientCertificate *CertificateBody -} - -func (state ServerStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm == nil || hm.msgType != HandshakeTypeCertificateVerify { - logf(logTypeHandshake, "[ServerStateWaitCV] Unexpected message [%+v] [%s]", hm, reflect.TypeOf(hm)) - return nil, nil, AlertUnexpectedMessage - } - - certVerify := &CertificateVerifyBody{} - _, err := certVerify.Unmarshal(hm.body) - if err != nil { - logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err) - return nil, nil, AlertDecodeError - } - - // Verify client signature over handshake hash - hcv := state.handshakeHash.Sum(nil) - logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) - - clientPublicKey := state.clientCertificate.CertificateList[0].CertData.PublicKey - if err := certVerify.Verify(clientPublicKey, hcv); err != nil { - logf(logTypeHandshake, "[ServerStateWaitCV] Failure in client auth verification [%v]", err) - return nil, nil, AlertHandshakeFailure - } - - if state.AuthCertificate != nil { - err := state.AuthCertificate(state.clientCertificate.CertificateList) - if err != nil { - logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate") - return nil, nil, AlertBadCertificate - } - } else { - logf(logTypeHandshake, "[ServerStateWaitCV] WARNING: No verification of client certificate") - } - - // If it passes, record the certificateVerify in the transcript hash - state.handshakeHash.Write(hm.Marshal()) - - logf(logTypeHandshake, "[ServerStateWaitCV] -> [ServerStateWaitFinished]") - nextState := ServerStateWaitFinished{ - Params: state.Params, - cryptoParams: state.cryptoParams, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - handshakeHash: state.handshakeHash, - clientTrafficSecret: state.clientTrafficSecret, - serverTrafficSecret: state.serverTrafficSecret, - exporterSecret: state.exporterSecret, - } - return nextState, nil, AlertNoAlert -} - -type ServerStateWaitFinished struct { - Params ConnectionParameters - cryptoParams CipherSuiteParams - - masterSecret []byte - clientHandshakeTrafficSecret []byte - - handshakeHash hash.Hash - clientTrafficSecret []byte - serverTrafficSecret []byte - exporterSecret []byte -} - -func (state ServerStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm == nil || hm.msgType != HandshakeTypeFinished { - logf(logTypeHandshake, "[ServerStateWaitFinished] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()} - _, err := fin.Unmarshal(hm.body) - if err != nil { - logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err) - return nil, nil, AlertDecodeError - } - - // Verify client Finished data - h5 := state.handshakeHash.Sum(nil) - logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5) - - clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5) - logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData) - - if !bytes.Equal(fin.VerifyData, clientFinishedData) { - logf(logTypeHandshake, "[ServerStateWaitFinished] Client's Finished failed to verify") - return nil, nil, AlertHandshakeFailure - } - - // Compute the resumption secret - state.handshakeHash.Write(hm.Marshal()) - h6 := state.handshakeHash.Sum(nil) - logf(logTypeCrypto, "handshake hash 6 [%d]: %x", len(h6), h6) - - resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6) - logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret) - - // Compute client traffic keys - clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) - - logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]") - nextState := StateConnected{ - Params: state.Params, - isClient: false, - cryptoParams: state.cryptoParams, - resumptionSecret: resumptionSecret, - clientTrafficSecret: state.clientTrafficSecret, - serverTrafficSecret: state.serverTrafficSecret, - exporterSecret: state.exporterSecret, - } - toSend := []HandshakeAction{ - RekeyIn{Label: "application", KeySet: clientTrafficKeys}, - } - return nextState, toSend, AlertNoAlert -} diff --git a/vendor/github.com/bifurcation/mint/state-machine.go b/vendor/github.com/bifurcation/mint/state-machine.go deleted file mode 100644 index 4eb468c6..00000000 --- a/vendor/github.com/bifurcation/mint/state-machine.go +++ /dev/null @@ -1,230 +0,0 @@ -package mint - -import ( - "time" -) - -// Marker interface for actions that an implementation should take based on -// state transitions. -type HandshakeAction interface{} - -type SendHandshakeMessage struct { - Message *HandshakeMessage -} - -type SendEarlyData struct{} - -type ReadEarlyData struct{} - -type ReadPastEarlyData struct{} - -type RekeyIn struct { - Label string - KeySet keySet -} - -type RekeyOut struct { - Label string - KeySet keySet -} - -type StorePSK struct { - PSK PreSharedKey -} - -type HandshakeState interface { - Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) -} - -type AppExtensionHandler interface { - Send(hs HandshakeType, el *ExtensionList) error - Receive(hs HandshakeType, el *ExtensionList) error -} - -// Capabilities objects represent the capabilities of a TLS client or server, -// as an input to TLS negotiation -type Capabilities struct { - // For both client and server - CipherSuites []CipherSuite - Groups []NamedGroup - SignatureSchemes []SignatureScheme - PSKs PreSharedKeyCache - Certificates []*Certificate - AuthCertificate func(chain []CertificateEntry) error - ExtensionHandler AppExtensionHandler - - // For client - PSKModes []PSKKeyExchangeMode - - // For server - NextProtos []string - AllowEarlyData bool - RequireCookie bool - CookieHandler CookieHandler - RequireClientAuth bool -} - -// ConnectionOptions objects represent per-connection settings for a client -// initiating a connection -type ConnectionOptions struct { - ServerName string - NextProtos []string - EarlyData []byte -} - -// ConnectionParameters objects represent the parameters negotiated for a -// connection. -type ConnectionParameters struct { - UsingPSK bool - UsingDH bool - ClientSendingEarlyData bool - UsingEarlyData bool - UsingClientAuth bool - - CipherSuite CipherSuite - ServerName string - NextProto string -} - -// StateConnected is symmetric between client and server -type StateConnected struct { - Params ConnectionParameters - isClient bool - cryptoParams CipherSuiteParams - resumptionSecret []byte - clientTrafficSecret []byte - serverTrafficSecret []byte - exporterSecret []byte -} - -func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) { - var trafficKeys keySet - if state.isClient { - state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret, - labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) - trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) - } else { - state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret, - labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) - trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret) - } - - kum, err := HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request}) - if err != nil { - logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err) - return nil, AlertInternalError - } - - toSend := []HandshakeAction{ - SendHandshakeMessage{kum}, - RekeyOut{Label: "update", KeySet: trafficKeys}, - } - return toSend, AlertNoAlert -} - -func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) { - tkt, err := NewSessionTicket(length, lifetime) - if err != nil { - logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err) - return nil, AlertInternalError - } - - err = tkt.Extensions.Add(&TicketEarlyDataInfoExtension{earlyDataLifetime}) - if err != nil { - logf(logTypeHandshake, "[StateConnected] Error adding extension to NewSessionTicket: %v", err) - return nil, AlertInternalError - } - - resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret, - labelResumption, tkt.TicketNonce, state.cryptoParams.Hash.Size()) - - newPSK := PreSharedKey{ - CipherSuite: state.cryptoParams.Suite, - IsResumption: true, - Identity: tkt.Ticket, - Key: resumptionKey, - NextProto: state.Params.NextProto, - ReceivedAt: time.Now(), - ExpiresAt: time.Now().Add(time.Duration(tkt.TicketLifetime) * time.Second), - TicketAgeAdd: tkt.TicketAgeAdd, - } - - tktm, err := HandshakeMessageFromBody(tkt) - if err != nil { - logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err) - return nil, AlertInternalError - } - - toSend := []HandshakeAction{ - StorePSK{newPSK}, - SendHandshakeMessage{tktm}, - } - return toSend, AlertNoAlert -} - -func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm == nil { - logf(logTypeHandshake, "[StateConnected] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - bodyGeneric, err := hm.ToBody() - if err != nil { - logf(logTypeHandshake, "[StateConnected] Error decoding message: %v", err) - return nil, nil, AlertDecodeError - } - - switch body := bodyGeneric.(type) { - case *KeyUpdateBody: - var trafficKeys keySet - if !state.isClient { - state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret, - labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) - trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) - } else { - state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret, - labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) - trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret) - } - - toSend := []HandshakeAction{RekeyIn{Label: "update", KeySet: trafficKeys}} - - // If requested, roll outbound keys and send a KeyUpdate - if body.KeyUpdateRequest == KeyUpdateRequested { - moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested) - if alert != AlertNoAlert { - return nil, nil, alert - } - - toSend = append(toSend, moreToSend...) - } - - return state, toSend, AlertNoAlert - - case *NewSessionTicketBody: - // XXX: Allow NewSessionTicket in both directions? - if !state.isClient { - return nil, nil, AlertUnexpectedMessage - } - - resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret, - labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size()) - - psk := PreSharedKey{ - CipherSuite: state.cryptoParams.Suite, - IsResumption: true, - Identity: body.Ticket, - Key: resumptionKey, - NextProto: state.Params.NextProto, - ReceivedAt: time.Now(), - ExpiresAt: time.Now().Add(time.Duration(body.TicketLifetime) * time.Second), - TicketAgeAdd: body.TicketAgeAdd, - } - - toSend := []HandshakeAction{StorePSK{psk}} - return state, toSend, AlertNoAlert - } - - logf(logTypeHandshake, "[StateConnected] Unexpected message type %v", hm.msgType) - return nil, nil, AlertUnexpectedMessage -} diff --git a/vendor/github.com/bifurcation/mint/syntax/decode.go b/vendor/github.com/bifurcation/mint/syntax/decode.go deleted file mode 100644 index cd5aadaf..00000000 --- a/vendor/github.com/bifurcation/mint/syntax/decode.go +++ /dev/null @@ -1,243 +0,0 @@ -package syntax - -import ( - "bytes" - "fmt" - "reflect" - "runtime" -) - -func Unmarshal(data []byte, v interface{}) (int, error) { - // Check for well-formedness. - // Avoids filling out half a data structure - // before discovering a JSON syntax error. - d := decodeState{} - d.Write(data) - return d.unmarshal(v) -} - -// These are the options that can be specified in the struct tag. Right now, -// all of them apply to variable-length vectors and nothing else -type decOpts struct { - head uint // length of length in bytes - min uint // minimum size in bytes - max uint // maximum size in bytes -} - -type decodeState struct { - bytes.Buffer -} - -func (d *decodeState) unmarshal(v interface{}) (read int, err error) { - defer func() { - if r := recover(); r != nil { - if _, ok := r.(runtime.Error); ok { - panic(r) - } - if s, ok := r.(string); ok { - panic(s) - } - err = r.(error) - } - }() - - rv := reflect.ValueOf(v) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return 0, fmt.Errorf("Invalid unmarshal target (non-pointer or nil)") - } - - read = d.value(rv) - return read, nil -} - -func (e *decodeState) value(v reflect.Value) int { - return valueDecoder(v)(e, v, decOpts{}) -} - -type decoderFunc func(e *decodeState, v reflect.Value, opts decOpts) int - -func valueDecoder(v reflect.Value) decoderFunc { - return typeDecoder(v.Type().Elem()) -} - -func typeDecoder(t reflect.Type) decoderFunc { - // Note: Omits the caching / wait-group things that encoding/json uses - return newTypeDecoder(t) -} - -func newTypeDecoder(t reflect.Type) decoderFunc { - // Note: Does not support Marshaler, so don't need the allowAddr argument - - switch t.Kind() { - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return uintDecoder - case reflect.Array: - return newArrayDecoder(t) - case reflect.Slice: - return newSliceDecoder(t) - case reflect.Struct: - return newStructDecoder(t) - default: - panic(fmt.Errorf("Unsupported type (%s)", t)) - } -} - -///// Specific decoders below - -func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int { - var uintLen int - switch v.Elem().Kind() { - case reflect.Uint8: - uintLen = 1 - case reflect.Uint16: - uintLen = 2 - case reflect.Uint32: - uintLen = 4 - case reflect.Uint64: - uintLen = 8 - } - - buf := make([]byte, uintLen) - n, err := d.Read(buf) - if err != nil { - panic(err) - } - if n != uintLen { - panic(fmt.Errorf("Insufficient data to read uint")) - } - - val := uint64(0) - for _, b := range buf { - val = (val << 8) + uint64(b) - } - - v.Elem().SetUint(val) - return uintLen -} - -////////// - -type arrayDecoder struct { - elemDec decoderFunc -} - -func (ad *arrayDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { - n := v.Elem().Type().Len() - read := 0 - for i := 0; i < n; i += 1 { - read += ad.elemDec(d, v.Elem().Index(i).Addr(), opts) - } - return read -} - -func newArrayDecoder(t reflect.Type) decoderFunc { - dec := &arrayDecoder{typeDecoder(t.Elem())} - return dec.decode -} - -////////// - -type sliceDecoder struct { - elementType reflect.Type - elementDec decoderFunc -} - -func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { - if opts.head == 0 { - panic(fmt.Errorf("Cannot decode a slice without a header length")) - } - - lengthBytes := make([]byte, opts.head) - n, err := d.Read(lengthBytes) - if err != nil { - panic(err) - } - if uint(n) != opts.head { - panic(fmt.Errorf("Not enough data to read header")) - } - - length := uint(0) - for _, b := range lengthBytes { - length = (length << 8) + uint(b) - } - - if opts.max > 0 && length > opts.max { - panic(fmt.Errorf("Length of vector exceeds declared max")) - } - if length < opts.min { - panic(fmt.Errorf("Length of vector below declared min")) - } - - data := make([]byte, length) - n, err = d.Read(data) - if err != nil { - panic(err) - } - if uint(n) != length { - panic(fmt.Errorf("Available data less than declared length [%04x < %04x]", n, length)) - } - - elemBuf := &decodeState{} - elemBuf.Write(data) - elems := []reflect.Value{} - read := int(opts.head) - for elemBuf.Len() > 0 { - elem := reflect.New(sd.elementType) - read += sd.elementDec(elemBuf, elem, opts) - elems = append(elems, elem) - } - - v.Elem().Set(reflect.MakeSlice(v.Elem().Type(), len(elems), len(elems))) - for i := 0; i < len(elems); i += 1 { - v.Elem().Index(i).Set(elems[i].Elem()) - } - return read -} - -func newSliceDecoder(t reflect.Type) decoderFunc { - dec := &sliceDecoder{ - elementType: t.Elem(), - elementDec: typeDecoder(t.Elem()), - } - return dec.decode -} - -////////// - -type structDecoder struct { - fieldOpts []decOpts - fieldDecs []decoderFunc -} - -func (sd *structDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { - read := 0 - for i := range sd.fieldDecs { - read += sd.fieldDecs[i](d, v.Elem().Field(i).Addr(), sd.fieldOpts[i]) - } - return read -} - -func newStructDecoder(t reflect.Type) decoderFunc { - n := t.NumField() - sd := structDecoder{ - fieldOpts: make([]decOpts, n), - fieldDecs: make([]decoderFunc, n), - } - - for i := 0; i < n; i += 1 { - f := t.Field(i) - - tag := f.Tag.Get("tls") - tagOpts := parseTag(tag) - - sd.fieldOpts[i] = decOpts{ - head: tagOpts["head"], - max: tagOpts["max"], - min: tagOpts["min"], - } - - sd.fieldDecs[i] = typeDecoder(f.Type) - } - - return sd.decode -} diff --git a/vendor/github.com/bifurcation/mint/syntax/encode.go b/vendor/github.com/bifurcation/mint/syntax/encode.go deleted file mode 100644 index 2874f404..00000000 --- a/vendor/github.com/bifurcation/mint/syntax/encode.go +++ /dev/null @@ -1,187 +0,0 @@ -package syntax - -import ( - "bytes" - "fmt" - "reflect" - "runtime" -) - -func Marshal(v interface{}) ([]byte, error) { - e := &encodeState{} - err := e.marshal(v, encOpts{}) - if err != nil { - return nil, err - } - return e.Bytes(), nil -} - -// These are the options that can be specified in the struct tag. Right now, -// all of them apply to variable-length vectors and nothing else -type encOpts struct { - head uint // length of length in bytes - min uint // minimum size in bytes - max uint // maximum size in bytes -} - -type encodeState struct { - bytes.Buffer -} - -func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) { - defer func() { - if r := recover(); r != nil { - if _, ok := r.(runtime.Error); ok { - panic(r) - } - if s, ok := r.(string); ok { - panic(s) - } - err = r.(error) - } - }() - e.reflectValue(reflect.ValueOf(v), opts) - return nil -} - -func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) { - valueEncoder(v)(e, v, opts) -} - -type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts) - -func valueEncoder(v reflect.Value) encoderFunc { - if !v.IsValid() { - panic(fmt.Errorf("Cannot encode an invalid value")) - } - return typeEncoder(v.Type()) -} - -func typeEncoder(t reflect.Type) encoderFunc { - // Note: Omits the caching / wait-group things that encoding/json uses - return newTypeEncoder(t) -} - -func newTypeEncoder(t reflect.Type) encoderFunc { - // Note: Does not support Marshaler, so don't need the allowAddr argument - - switch t.Kind() { - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return uintEncoder - case reflect.Array: - return newArrayEncoder(t) - case reflect.Slice: - return newSliceEncoder(t) - case reflect.Struct: - return newStructEncoder(t) - default: - panic(fmt.Errorf("Unsupported type (%s)", t)) - } -} - -///// Specific encoders below - -func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) { - u := v.Uint() - switch v.Type().Kind() { - case reflect.Uint8: - e.WriteByte(byte(u)) - case reflect.Uint16: - e.Write([]byte{byte(u >> 8), byte(u)}) - case reflect.Uint32: - e.Write([]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)}) - case reflect.Uint64: - e.Write([]byte{byte(u >> 56), byte(u >> 48), byte(u >> 40), byte(u >> 32), - byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)}) - } -} - -////////// - -type arrayEncoder struct { - elemEnc encoderFunc -} - -func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { - n := v.Len() - for i := 0; i < n; i += 1 { - ae.elemEnc(e, v.Index(i), opts) - } -} - -func newArrayEncoder(t reflect.Type) encoderFunc { - enc := &arrayEncoder{typeEncoder(t.Elem())} - return enc.encode -} - -////////// - -type sliceEncoder struct { - ae *arrayEncoder -} - -func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { - if opts.head == 0 { - panic(fmt.Errorf("Cannot encode a slice without a header length")) - } - - arrayState := &encodeState{} - se.ae.encode(arrayState, v, opts) - - n := uint(arrayState.Len()) - if opts.max > 0 && n > opts.max { - panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max)) - } - if n>>(8*opts.head) > 0 { - panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head)) - } - if n < opts.min { - panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min)) - } - - for i := int(opts.head - 1); i >= 0; i -= 1 { - e.WriteByte(byte(n >> (8 * uint(i)))) - } - e.Write(arrayState.Bytes()) -} - -func newSliceEncoder(t reflect.Type) encoderFunc { - enc := &sliceEncoder{&arrayEncoder{typeEncoder(t.Elem())}} - return enc.encode -} - -////////// - -type structEncoder struct { - fieldOpts []encOpts - fieldEncs []encoderFunc -} - -func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { - for i := range se.fieldEncs { - se.fieldEncs[i](e, v.Field(i), se.fieldOpts[i]) - } -} - -func newStructEncoder(t reflect.Type) encoderFunc { - n := t.NumField() - se := structEncoder{ - fieldOpts: make([]encOpts, n), - fieldEncs: make([]encoderFunc, n), - } - - for i := 0; i < n; i += 1 { - f := t.Field(i) - tag := f.Tag.Get("tls") - tagOpts := parseTag(tag) - - se.fieldOpts[i] = encOpts{ - head: tagOpts["head"], - max: tagOpts["max"], - min: tagOpts["min"], - } - se.fieldEncs[i] = typeEncoder(f.Type) - } - - return se.encode -} diff --git a/vendor/github.com/bifurcation/mint/syntax/tags.go b/vendor/github.com/bifurcation/mint/syntax/tags.go deleted file mode 100644 index a6c9c88d..00000000 --- a/vendor/github.com/bifurcation/mint/syntax/tags.go +++ /dev/null @@ -1,30 +0,0 @@ -package syntax - -import ( - "strconv" - "strings" -) - -// `tls:"head=2,min=2,max=255"` - -type tagOptions map[string]uint - -// parseTag parses a struct field's "tls" tag as a comma-separated list of -// name=value pairs, where the values MUST be unsigned integers -func parseTag(tag string) tagOptions { - opts := tagOptions{} - for _, token := range strings.Split(tag, ",") { - if strings.Index(token, "=") == -1 { - continue - } - - parts := strings.Split(token, "=") - if len(parts[0]) == 0 { - continue - } - if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 { - opts[parts[0]] = uint(val) - } - } - return opts -} diff --git a/vendor/github.com/bifurcation/mint/tls.go b/vendor/github.com/bifurcation/mint/tls.go deleted file mode 100644 index 0c57aba5..00000000 --- a/vendor/github.com/bifurcation/mint/tls.go +++ /dev/null @@ -1,168 +0,0 @@ -package mint - -// XXX(rlb): This file is borrowed pretty much wholesale from crypto/tls - -import ( - "errors" - "net" - "strings" - "time" -) - -// Server returns a new TLS server side connection -// using conn as the underlying transport. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func Server(conn net.Conn, config *Config) *Conn { - return NewConn(conn, config, false) -} - -// Client returns a new TLS client side connection -// using conn as the underlying transport. -// The config cannot be nil: users must set either ServerName or -// InsecureSkipVerify in the config. -func Client(conn net.Conn, config *Config) *Conn { - return NewConn(conn, config, true) -} - -// A listener implements a network listener (net.Listener) for TLS connections. -type Listener struct { - net.Listener - config *Config -} - -// Accept waits for and returns the next incoming TLS connection. -// The returned connection c is a *tls.Conn. -func (l *Listener) Accept() (c net.Conn, err error) { - c, err = l.Listener.Accept() - if err != nil { - return - } - server := Server(c, l.config) - err = server.Handshake() - if err == AlertNoAlert { - err = nil - } - c = server - return -} - -// NewListener creates a Listener which accepts connections from an inner -// Listener and wraps each connection with Server. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func NewListener(inner net.Listener, config *Config) net.Listener { - l := new(Listener) - l.Listener = inner - l.config = config - return l -} - -// Listen creates a TLS listener accepting connections on the -// given network address using net.Listen. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func Listen(network, laddr string, config *Config) (net.Listener, error) { - if config == nil || !config.ValidForServer() { - return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config") - } - l, err := net.Listen(network, laddr) - if err != nil { - return nil, err - } - return NewListener(l, config), nil -} - -type TimeoutError struct{} - -func (TimeoutError) Error() string { return "tls: DialWithDialer timed out" } -func (TimeoutError) Timeout() bool { return true } -func (TimeoutError) Temporary() bool { return true } - -// DialWithDialer connects to the given network address using dialer.Dial and -// then initiates a TLS handshake, returning the resulting TLS connection. Any -// timeout or deadline given in the dialer apply to connection and TLS -// handshake as a whole. -// -// DialWithDialer interprets a nil configuration as equivalent to the zero -// configuration; see the documentation of Config for the defaults. -func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { - // We want the Timeout and Deadline values from dialer to cover the - // whole process: TCP connection and TLS handshake. This means that we - // also need to start our own timers now. - timeout := dialer.Timeout - - if !dialer.Deadline.IsZero() { - deadlineTimeout := dialer.Deadline.Sub(time.Now()) - if timeout == 0 || deadlineTimeout < timeout { - timeout = deadlineTimeout - } - } - - var errChannel chan error - - if timeout != 0 { - errChannel = make(chan error, 2) - time.AfterFunc(timeout, func() { - errChannel <- TimeoutError{} - }) - } - - rawConn, err := dialer.Dial(network, addr) - if err != nil { - return nil, err - } - - colonPos := strings.LastIndex(addr, ":") - if colonPos == -1 { - colonPos = len(addr) - } - hostname := addr[:colonPos] - - if config == nil { - config = &Config{} - } - // If no ServerName is set, infer the ServerName - // from the hostname we're connecting to. - if config.ServerName == "" { - // Make a copy to avoid polluting argument or default. - c := config.Clone() - c.ServerName = hostname - config = c - } - - conn := Client(rawConn, config) - - if timeout == 0 { - err = conn.Handshake() - if err == AlertNoAlert { - err = nil - } - } else { - go func() { - errChannel <- conn.Handshake() - }() - - err = <-errChannel - if err == AlertNoAlert { - err = nil - } - } - - if err != nil { - rawConn.Close() - return nil, err - } - - return conn, nil -} - -// Dial connects to the given network address using net.Dial -// and then initiates a TLS handshake, returning the resulting -// TLS connection. -// Dial interprets a nil configuration as equivalent to -// the zero configuration; see the documentation of Config -// for the defaults. -func Dial(network, addr string, config *Config) (*Conn, error) { - return DialWithDialer(new(net.Dialer), network, addr, config) -} diff --git a/vendor/github.com/golang/mock/gomock/LICENSE b/vendor/github.com/golang/mock/gomock/LICENSE new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/vendor/github.com/golang/mock/gomock/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/golang/mock/gomock/call.go b/vendor/github.com/golang/mock/gomock/call.go new file mode 100644 index 00000000..a3fa1ae4 --- /dev/null +++ b/vendor/github.com/golang/mock/gomock/call.go @@ -0,0 +1,428 @@ +// Copyright 2010 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomock + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +// Call represents an expected call to a mock. +type Call struct { + t TestReporter // for triggering test failures on invalid call setup + + receiver interface{} // the receiver of the method call + method string // the name of the method + methodType reflect.Type // the type of the method + args []Matcher // the args + origin string // file and line number of call setup + + preReqs []*Call // prerequisite calls + + // Expectations + minCalls, maxCalls int + + numCalls int // actual number made + + // actions are called when this Call is called. Each action gets the args and + // can set the return values by returning a non-nil slice. Actions run in the + // order they are created. + actions []func([]interface{}) []interface{} +} + +// newCall creates a *Call. It requires the method type in order to support +// unexported methods. +func newCall(t TestReporter, receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call { + if h, ok := t.(testHelper); ok { + h.Helper() + } + + // TODO: check arity, types. + margs := make([]Matcher, len(args)) + for i, arg := range args { + if m, ok := arg.(Matcher); ok { + margs[i] = m + } else if arg == nil { + // Handle nil specially so that passing a nil interface value + // will match the typed nils of concrete args. + margs[i] = Nil() + } else { + margs[i] = Eq(arg) + } + } + + origin := callerInfo(3) + actions := []func([]interface{}) []interface{}{func([]interface{}) []interface{} { + // Synthesize the zero value for each of the return args' types. + rets := make([]interface{}, methodType.NumOut()) + for i := 0; i < methodType.NumOut(); i++ { + rets[i] = reflect.Zero(methodType.Out(i)).Interface() + } + return rets + }} + return &Call{t: t, receiver: receiver, method: method, methodType: methodType, + args: margs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions} +} + +// AnyTimes allows the expectation to be called 0 or more times +func (c *Call) AnyTimes() *Call { + c.minCalls, c.maxCalls = 0, 1e8 // close enough to infinity + return c +} + +// MinTimes requires the call to occur at least n times. If AnyTimes or MaxTimes have not been called, MinTimes also +// sets the maximum number of calls to infinity. +func (c *Call) MinTimes(n int) *Call { + c.minCalls = n + if c.maxCalls == 1 { + c.maxCalls = 1e8 + } + return c +} + +// MaxTimes limits the number of calls to n times. If AnyTimes or MinTimes have not been called, MaxTimes also +// sets the minimum number of calls to 0. +func (c *Call) MaxTimes(n int) *Call { + c.maxCalls = n + if c.minCalls == 1 { + c.minCalls = 0 + } + return c +} + +// DoAndReturn declares the action to run when the call is matched. +// The return values from this function are returned by the mocked function. +// It takes an interface{} argument to support n-arity functions. +func (c *Call) DoAndReturn(f interface{}) *Call { + // TODO: Check arity and types here, rather than dying badly elsewhere. + v := reflect.ValueOf(f) + + c.addAction(func(args []interface{}) []interface{} { + vargs := make([]reflect.Value, len(args)) + ft := v.Type() + for i := 0; i < len(args); i++ { + if args[i] != nil { + vargs[i] = reflect.ValueOf(args[i]) + } else { + // Use the zero value for the arg. + vargs[i] = reflect.Zero(ft.In(i)) + } + } + vrets := v.Call(vargs) + rets := make([]interface{}, len(vrets)) + for i, ret := range vrets { + rets[i] = ret.Interface() + } + return rets + }) + return c +} + +// Do declares the action to run when the call is matched. The function's +// return values are ignored to retain backward compatibility. To use the +// return values call DoAndReturn. +// It takes an interface{} argument to support n-arity functions. +func (c *Call) Do(f interface{}) *Call { + // TODO: Check arity and types here, rather than dying badly elsewhere. + v := reflect.ValueOf(f) + + c.addAction(func(args []interface{}) []interface{} { + vargs := make([]reflect.Value, len(args)) + ft := v.Type() + for i := 0; i < len(args); i++ { + if args[i] != nil { + vargs[i] = reflect.ValueOf(args[i]) + } else { + // Use the zero value for the arg. + vargs[i] = reflect.Zero(ft.In(i)) + } + } + v.Call(vargs) + return nil + }) + return c +} + +// Return declares the values to be returned by the mocked function call. +func (c *Call) Return(rets ...interface{}) *Call { + if h, ok := c.t.(testHelper); ok { + h.Helper() + } + + mt := c.methodType + if len(rets) != mt.NumOut() { + c.t.Fatalf("wrong number of arguments to Return for %T.%v: got %d, want %d [%s]", + c.receiver, c.method, len(rets), mt.NumOut(), c.origin) + } + for i, ret := range rets { + if got, want := reflect.TypeOf(ret), mt.Out(i); got == want { + // Identical types; nothing to do. + } else if got == nil { + // Nil needs special handling. + switch want.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + // ok + default: + c.t.Fatalf("argument %d to Return for %T.%v is nil, but %v is not nillable [%s]", + i, c.receiver, c.method, want, c.origin) + } + } else if got.AssignableTo(want) { + // Assignable type relation. Make the assignment now so that the generated code + // can return the values with a type assertion. + v := reflect.New(want).Elem() + v.Set(reflect.ValueOf(ret)) + rets[i] = v.Interface() + } else { + c.t.Fatalf("wrong type of argument %d to Return for %T.%v: %v is not assignable to %v [%s]", + i, c.receiver, c.method, got, want, c.origin) + } + } + + c.addAction(func([]interface{}) []interface{} { + return rets + }) + + return c +} + +// Times declares the exact number of times a function call is expected to be executed. +func (c *Call) Times(n int) *Call { + c.minCalls, c.maxCalls = n, n + return c +} + +// SetArg declares an action that will set the nth argument's value, +// indirected through a pointer. Or, in the case of a slice, SetArg +// will copy value's elements into the nth argument. +func (c *Call) SetArg(n int, value interface{}) *Call { + if h, ok := c.t.(testHelper); ok { + h.Helper() + } + + mt := c.methodType + // TODO: This will break on variadic methods. + // We will need to check those at invocation time. + if n < 0 || n >= mt.NumIn() { + c.t.Fatalf("SetArg(%d, ...) called for a method with %d args [%s]", + n, mt.NumIn(), c.origin) + } + // Permit setting argument through an interface. + // In the interface case, we don't (nay, can't) check the type here. + at := mt.In(n) + switch at.Kind() { + case reflect.Ptr: + dt := at.Elem() + if vt := reflect.TypeOf(value); !vt.AssignableTo(dt) { + c.t.Fatalf("SetArg(%d, ...) argument is a %v, not assignable to %v [%s]", + n, vt, dt, c.origin) + } + case reflect.Interface: + // nothing to do + case reflect.Slice: + // nothing to do + default: + c.t.Fatalf("SetArg(%d, ...) referring to argument of non-pointer non-interface non-slice type %v [%s]", + n, at, c.origin) + } + + c.addAction(func(args []interface{}) []interface{} { + v := reflect.ValueOf(value) + switch reflect.TypeOf(args[n]).Kind() { + case reflect.Slice: + setSlice(args[n], v) + default: + reflect.ValueOf(args[n]).Elem().Set(v) + } + return nil + }) + return c +} + +// isPreReq returns true if other is a direct or indirect prerequisite to c. +func (c *Call) isPreReq(other *Call) bool { + for _, preReq := range c.preReqs { + if other == preReq || preReq.isPreReq(other) { + return true + } + } + return false +} + +// After declares that the call may only match after preReq has been exhausted. +func (c *Call) After(preReq *Call) *Call { + if h, ok := c.t.(testHelper); ok { + h.Helper() + } + + if c == preReq { + c.t.Fatalf("A call isn't allowed to be its own prerequisite") + } + if preReq.isPreReq(c) { + c.t.Fatalf("Loop in call order: %v is a prerequisite to %v (possibly indirectly).", c, preReq) + } + + c.preReqs = append(c.preReqs, preReq) + return c +} + +// Returns true if the minimum number of calls have been made. +func (c *Call) satisfied() bool { + return c.numCalls >= c.minCalls +} + +// Returns true iff the maximum number of calls have been made. +func (c *Call) exhausted() bool { + return c.numCalls >= c.maxCalls +} + +func (c *Call) String() string { + args := make([]string, len(c.args)) + for i, arg := range c.args { + args[i] = arg.String() + } + arguments := strings.Join(args, ", ") + return fmt.Sprintf("%T.%v(%s) %s", c.receiver, c.method, arguments, c.origin) +} + +// Tests if the given call matches the expected call. +// If yes, returns nil. If no, returns error with message explaining why it does not match. +func (c *Call) matches(args []interface{}) error { + if !c.methodType.IsVariadic() { + if len(args) != len(c.args) { + return fmt.Errorf("Expected call at %s has the wrong number of arguments. Got: %d, want: %d", + c.origin, len(args), len(c.args)) + } + + for i, m := range c.args { + if !m.Matches(args[i]) { + return fmt.Errorf("Expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v", + c.origin, strconv.Itoa(i), args[i], m) + } + } + } else { + if len(c.args) < c.methodType.NumIn()-1 { + return fmt.Errorf("Expected call at %s has the wrong number of matchers. Got: %d, want: %d", + c.origin, len(c.args), c.methodType.NumIn()-1) + } + if len(c.args) != c.methodType.NumIn() && len(args) != len(c.args) { + return fmt.Errorf("Expected call at %s has the wrong number of arguments. Got: %d, want: %d", + c.origin, len(args), len(c.args)) + } + if len(args) < len(c.args)-1 { + return fmt.Errorf("Expected call at %s has the wrong number of arguments. Got: %d, want: greater than or equal to %d", + c.origin, len(args), len(c.args)-1) + } + + for i, m := range c.args { + if i < c.methodType.NumIn()-1 { + // Non-variadic args + if !m.Matches(args[i]) { + return fmt.Errorf("Expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v", + c.origin, strconv.Itoa(i), args[i], m) + } + continue + } + // The last arg has a possibility of a variadic argument, so let it branch + + // sample: Foo(a int, b int, c ...int) + if i < len(c.args) && i < len(args) { + if m.Matches(args[i]) { + // Got Foo(a, b, c) want Foo(matcherA, matcherB, gomock.Any()) + // Got Foo(a, b, c) want Foo(matcherA, matcherB, someSliceMatcher) + // Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC) + // Got Foo(a, b) want Foo(matcherA, matcherB) + // Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD) + continue + } + } + + // The number of actual args don't match the number of matchers, + // or the last matcher is a slice and the last arg is not. + // If this function still matches it is because the last matcher + // matches all the remaining arguments or the lack of any. + // Convert the remaining arguments, if any, into a slice of the + // expected type. + vargsType := c.methodType.In(c.methodType.NumIn() - 1) + vargs := reflect.MakeSlice(vargsType, 0, len(args)-i) + for _, arg := range args[i:] { + vargs = reflect.Append(vargs, reflect.ValueOf(arg)) + } + if m.Matches(vargs.Interface()) { + // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, gomock.Any()) + // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, someSliceMatcher) + // Got Foo(a, b) want Foo(matcherA, matcherB, gomock.Any()) + // Got Foo(a, b) want Foo(matcherA, matcherB, someEmptySliceMatcher) + break + } + // Wrong number of matchers or not match. Fail. + // Got Foo(a, b) want Foo(matcherA, matcherB, matcherC, matcherD) + // Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC, matcherD) + // Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD, matcherE) + // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, matcherC, matcherD) + // Got Foo(a, b, c) want Foo(matcherA, matcherB) + return fmt.Errorf("Expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v", + c.origin, strconv.Itoa(i), args[i:], c.args[i]) + + } + } + + // Check that all prerequisite calls have been satisfied. + for _, preReqCall := range c.preReqs { + if !preReqCall.satisfied() { + return fmt.Errorf("Expected call at %s doesn't have a prerequisite call satisfied:\n%v\nshould be called before:\n%v", + c.origin, preReqCall, c) + } + } + + // Check that the call is not exhausted. + if c.exhausted() { + return fmt.Errorf("Expected call at %s has already been called the max number of times.", c.origin) + } + + return nil +} + +// dropPrereqs tells the expected Call to not re-check prerequisite calls any +// longer, and to return its current set. +func (c *Call) dropPrereqs() (preReqs []*Call) { + preReqs = c.preReqs + c.preReqs = nil + return +} + +func (c *Call) call(args []interface{}) []func([]interface{}) []interface{} { + c.numCalls++ + return c.actions +} + +// InOrder declares that the given calls should occur in order. +func InOrder(calls ...*Call) { + for i := 1; i < len(calls); i++ { + calls[i].After(calls[i-1]) + } +} + +func setSlice(arg interface{}, v reflect.Value) { + va := reflect.ValueOf(arg) + for i := 0; i < v.Len(); i++ { + va.Index(i).Set(v.Index(i)) + } +} + +func (c *Call) addAction(action func([]interface{}) []interface{}) { + c.actions = append(c.actions, action) +} diff --git a/vendor/github.com/golang/mock/gomock/callset.go b/vendor/github.com/golang/mock/gomock/callset.go new file mode 100644 index 00000000..c44a8a58 --- /dev/null +++ b/vendor/github.com/golang/mock/gomock/callset.go @@ -0,0 +1,108 @@ +// Copyright 2011 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomock + +import ( + "bytes" + "fmt" +) + +// callSet represents a set of expected calls, indexed by receiver and method +// name. +type callSet struct { + // Calls that are still expected. + expected map[callSetKey][]*Call + // Calls that have been exhausted. + exhausted map[callSetKey][]*Call +} + +// callSetKey is the key in the maps in callSet +type callSetKey struct { + receiver interface{} + fname string +} + +func newCallSet() *callSet { + return &callSet{make(map[callSetKey][]*Call), make(map[callSetKey][]*Call)} +} + +// Add adds a new expected call. +func (cs callSet) Add(call *Call) { + key := callSetKey{call.receiver, call.method} + m := cs.expected + if call.exhausted() { + m = cs.exhausted + } + m[key] = append(m[key], call) +} + +// Remove removes an expected call. +func (cs callSet) Remove(call *Call) { + key := callSetKey{call.receiver, call.method} + calls := cs.expected[key] + for i, c := range calls { + if c == call { + // maintain order for remaining calls + cs.expected[key] = append(calls[:i], calls[i+1:]...) + cs.exhausted[key] = append(cs.exhausted[key], call) + break + } + } +} + +// FindMatch searches for a matching call. Returns error with explanation message if no call matched. +func (cs callSet) FindMatch(receiver interface{}, method string, args []interface{}) (*Call, error) { + key := callSetKey{receiver, method} + + // Search through the expected calls. + expected := cs.expected[key] + var callsErrors bytes.Buffer + for _, call := range expected { + err := call.matches(args) + if err != nil { + fmt.Fprintf(&callsErrors, "\n%v", err) + } else { + return call, nil + } + } + + // If we haven't found a match then search through the exhausted calls so we + // get useful error messages. + exhausted := cs.exhausted[key] + for _, call := range exhausted { + if err := call.matches(args); err != nil { + fmt.Fprintf(&callsErrors, "\n%v", err) + } + } + + if len(expected)+len(exhausted) == 0 { + fmt.Fprintf(&callsErrors, "there are no expected calls of the method %q for that receiver", method) + } + + return nil, fmt.Errorf(callsErrors.String()) +} + +// Failures returns the calls that are not satisfied. +func (cs callSet) Failures() []*Call { + failures := make([]*Call, 0, len(cs.expected)) + for _, calls := range cs.expected { + for _, call := range calls { + if !call.satisfied() { + failures = append(failures, call) + } + } + } + return failures +} diff --git a/vendor/github.com/golang/mock/gomock/controller.go b/vendor/github.com/golang/mock/gomock/controller.go new file mode 100644 index 00000000..a7b79188 --- /dev/null +++ b/vendor/github.com/golang/mock/gomock/controller.go @@ -0,0 +1,217 @@ +// Copyright 2010 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// GoMock - a mock framework for Go. +// +// Standard usage: +// (1) Define an interface that you wish to mock. +// type MyInterface interface { +// SomeMethod(x int64, y string) +// } +// (2) Use mockgen to generate a mock from the interface. +// (3) Use the mock in a test: +// func TestMyThing(t *testing.T) { +// mockCtrl := gomock.NewController(t) +// defer mockCtrl.Finish() +// +// mockObj := something.NewMockMyInterface(mockCtrl) +// mockObj.EXPECT().SomeMethod(4, "blah") +// // pass mockObj to a real object and play with it. +// } +// +// By default, expected calls are not enforced to run in any particular order. +// Call order dependency can be enforced by use of InOrder and/or Call.After. +// Call.After can create more varied call order dependencies, but InOrder is +// often more convenient. +// +// The following examples create equivalent call order dependencies. +// +// Example of using Call.After to chain expected call order: +// +// firstCall := mockObj.EXPECT().SomeMethod(1, "first") +// secondCall := mockObj.EXPECT().SomeMethod(2, "second").After(firstCall) +// mockObj.EXPECT().SomeMethod(3, "third").After(secondCall) +// +// Example of using InOrder to declare expected call order: +// +// gomock.InOrder( +// mockObj.EXPECT().SomeMethod(1, "first"), +// mockObj.EXPECT().SomeMethod(2, "second"), +// mockObj.EXPECT().SomeMethod(3, "third"), +// ) +// +// TODO: +// - Handle different argument/return types (e.g. ..., chan, map, interface). +package gomock + +import ( + "fmt" + "golang.org/x/net/context" + "reflect" + "runtime" + "sync" +) + +// A TestReporter is something that can be used to report test failures. +// It is satisfied by the standard library's *testing.T. +type TestReporter interface { + Errorf(format string, args ...interface{}) + Fatalf(format string, args ...interface{}) +} + +// A Controller represents the top-level control of a mock ecosystem. +// It defines the scope and lifetime of mock objects, as well as their expectations. +// It is safe to call Controller's methods from multiple goroutines. +type Controller struct { + mu sync.Mutex + t TestReporter + expectedCalls *callSet + finished bool +} + +func NewController(t TestReporter) *Controller { + return &Controller{ + t: t, + expectedCalls: newCallSet(), + } +} + +type cancelReporter struct { + t TestReporter + cancel func() +} + +func (r *cancelReporter) Errorf(format string, args ...interface{}) { r.t.Errorf(format, args...) } +func (r *cancelReporter) Fatalf(format string, args ...interface{}) { + defer r.cancel() + r.t.Fatalf(format, args...) +} + +// WithContext returns a new Controller and a Context, which is cancelled on any +// fatal failure. +func WithContext(ctx context.Context, t TestReporter) (*Controller, context.Context) { + ctx, cancel := context.WithCancel(ctx) + return NewController(&cancelReporter{t, cancel}), ctx +} + +func (ctrl *Controller) RecordCall(receiver interface{}, method string, args ...interface{}) *Call { + if h, ok := ctrl.t.(testHelper); ok { + h.Helper() + } + + recv := reflect.ValueOf(receiver) + for i := 0; i < recv.Type().NumMethod(); i++ { + if recv.Type().Method(i).Name == method { + return ctrl.RecordCallWithMethodType(receiver, method, recv.Method(i).Type(), args...) + } + } + ctrl.t.Fatalf("gomock: failed finding method %s on %T", method, receiver) + panic("unreachable") +} + +func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call { + if h, ok := ctrl.t.(testHelper); ok { + h.Helper() + } + + call := newCall(ctrl.t, receiver, method, methodType, args...) + + ctrl.mu.Lock() + defer ctrl.mu.Unlock() + ctrl.expectedCalls.Add(call) + + return call +} + +func (ctrl *Controller) Call(receiver interface{}, method string, args ...interface{}) []interface{} { + if h, ok := ctrl.t.(testHelper); ok { + h.Helper() + } + + // Nest this code so we can use defer to make sure the lock is released. + actions := func() []func([]interface{}) []interface{} { + ctrl.mu.Lock() + defer ctrl.mu.Unlock() + + expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args) + if err != nil { + origin := callerInfo(2) + ctrl.t.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, args, origin, err) + } + + // Two things happen here: + // * the matching call no longer needs to check prerequite calls, + // * and the prerequite calls are no longer expected, so remove them. + preReqCalls := expected.dropPrereqs() + for _, preReqCall := range preReqCalls { + ctrl.expectedCalls.Remove(preReqCall) + } + + actions := expected.call(args) + if expected.exhausted() { + ctrl.expectedCalls.Remove(expected) + } + return actions + }() + + var rets []interface{} + for _, action := range actions { + if r := action(args); r != nil { + rets = r + } + } + + return rets +} + +func (ctrl *Controller) Finish() { + if h, ok := ctrl.t.(testHelper); ok { + h.Helper() + } + + ctrl.mu.Lock() + defer ctrl.mu.Unlock() + + if ctrl.finished { + ctrl.t.Fatalf("Controller.Finish was called more than once. It has to be called exactly once.") + } + ctrl.finished = true + + // If we're currently panicking, probably because this is a deferred call, + // pass through the panic. + if err := recover(); err != nil { + panic(err) + } + + // Check that all remaining expected calls are satisfied. + failures := ctrl.expectedCalls.Failures() + for _, call := range failures { + ctrl.t.Errorf("missing call(s) to %v", call) + } + if len(failures) != 0 { + ctrl.t.Fatalf("aborting test due to missing call(s)") + } +} + +func callerInfo(skip int) string { + if _, file, line, ok := runtime.Caller(skip + 1); ok { + return fmt.Sprintf("%s:%d", file, line) + } + return "unknown file" +} + +type testHelper interface { + TestReporter + Helper() +} diff --git a/vendor/github.com/golang/mock/gomock/matchers.go b/vendor/github.com/golang/mock/gomock/matchers.go new file mode 100644 index 00000000..65ad8bab --- /dev/null +++ b/vendor/github.com/golang/mock/gomock/matchers.go @@ -0,0 +1,124 @@ +//go:generate mockgen -destination mock_matcher/mock_matcher.go github.com/golang/mock/gomock Matcher + +// Copyright 2010 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomock + +import ( + "fmt" + "reflect" +) + +// A Matcher is a representation of a class of values. +// It is used to represent the valid or expected arguments to a mocked method. +type Matcher interface { + // Matches returns whether x is a match. + Matches(x interface{}) bool + + // String describes what the matcher matches. + String() string +} + +type anyMatcher struct{} + +func (anyMatcher) Matches(x interface{}) bool { + return true +} + +func (anyMatcher) String() string { + return "is anything" +} + +type eqMatcher struct { + x interface{} +} + +func (e eqMatcher) Matches(x interface{}) bool { + return reflect.DeepEqual(e.x, x) +} + +func (e eqMatcher) String() string { + return fmt.Sprintf("is equal to %v", e.x) +} + +type nilMatcher struct{} + +func (nilMatcher) Matches(x interface{}) bool { + if x == nil { + return true + } + + v := reflect.ValueOf(x) + switch v.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, + reflect.Ptr, reflect.Slice: + return v.IsNil() + } + + return false +} + +func (nilMatcher) String() string { + return "is nil" +} + +type notMatcher struct { + m Matcher +} + +func (n notMatcher) Matches(x interface{}) bool { + return !n.m.Matches(x) +} + +func (n notMatcher) String() string { + // TODO: Improve this if we add a NotString method to the Matcher interface. + return "not(" + n.m.String() + ")" +} + +type assignableToTypeOfMatcher struct { + targetType reflect.Type +} + +func (m assignableToTypeOfMatcher) Matches(x interface{}) bool { + return reflect.TypeOf(x).AssignableTo(m.targetType) +} + +func (m assignableToTypeOfMatcher) String() string { + return "is assignable to " + m.targetType.Name() +} + +// Constructors +func Any() Matcher { return anyMatcher{} } +func Eq(x interface{}) Matcher { return eqMatcher{x} } +func Nil() Matcher { return nilMatcher{} } +func Not(x interface{}) Matcher { + if m, ok := x.(Matcher); ok { + return notMatcher{m} + } + return notMatcher{Eq(x)} +} + +// AssignableToTypeOf is a Matcher that matches if the parameter to the mock +// function is assignable to the type of the parameter to this function. +// +// Example usage: +// +// dbMock.EXPECT(). +// Insert(gomock.AssignableToTypeOf(&EmployeeRecord{})). +// Return(errors.New("DB error")) +// +func AssignableToTypeOf(x interface{}) Matcher { + return assignableToTypeOfMatcher{reflect.TypeOf(x)} +} diff --git a/vendor/github.com/golang/mock/gomock/mock_matcher/mock_matcher.go b/vendor/github.com/golang/mock/gomock/mock_matcher/mock_matcher.go new file mode 100644 index 00000000..7e4b4c8a --- /dev/null +++ b/vendor/github.com/golang/mock/gomock/mock_matcher/mock_matcher.go @@ -0,0 +1,57 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/golang/mock/gomock (interfaces: Matcher) + +// Package mock_gomock is a generated GoMock package. +package mock_gomock + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockMatcher is a mock of Matcher interface +type MockMatcher struct { + ctrl *gomock.Controller + recorder *MockMatcherMockRecorder +} + +// MockMatcherMockRecorder is the mock recorder for MockMatcher +type MockMatcherMockRecorder struct { + mock *MockMatcher +} + +// NewMockMatcher creates a new mock instance +func NewMockMatcher(ctrl *gomock.Controller) *MockMatcher { + mock := &MockMatcher{ctrl: ctrl} + mock.recorder = &MockMatcherMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockMatcher) EXPECT() *MockMatcherMockRecorder { + return m.recorder +} + +// Matches mocks base method +func (m *MockMatcher) Matches(arg0 interface{}) bool { + ret := m.ctrl.Call(m, "Matches", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// Matches indicates an expected call of Matches +func (mr *MockMatcherMockRecorder) Matches(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Matches", reflect.TypeOf((*MockMatcher)(nil).Matches), arg0) +} + +// String mocks base method +func (m *MockMatcher) String() string { + ret := m.ctrl.Call(m, "String") + ret0, _ := ret[0].(string) + return ret0 +} + +// String indicates an expected call of String +func (mr *MockMatcherMockRecorder) String() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*MockMatcher)(nil).String)) +} diff --git a/vendor/github.com/hashicorp/golang-lru/LICENSE b/vendor/github.com/hashicorp/golang-lru/LICENSE deleted file mode 100644 index be2cc4df..00000000 --- a/vendor/github.com/hashicorp/golang-lru/LICENSE +++ /dev/null @@ -1,362 +0,0 @@ -Mozilla Public License, version 2.0 - -1. Definitions - -1.1. "Contributor" - - means each individual or legal entity that creates, contributes to the - creation of, or owns Covered Software. - -1.2. "Contributor Version" - - means the combination of the Contributions of others (if any) used by a - Contributor and that particular Contributor's Contribution. - -1.3. "Contribution" - - means Covered Software of a particular Contributor. - -1.4. "Covered Software" - - means Source Code Form to which the initial Contributor has attached the - notice in Exhibit A, the Executable Form of such Source Code Form, and - Modifications of such Source Code Form, in each case including portions - thereof. - -1.5. "Incompatible With Secondary Licenses" - means - - a. that the initial Contributor has attached the notice described in - Exhibit B to the Covered Software; or - - b. that the Covered Software was made available under the terms of - version 1.1 or earlier of the License, but not also under the terms of - a Secondary License. - -1.6. "Executable Form" - - means any form of the work other than Source Code Form. - -1.7. "Larger Work" - - means a work that combines Covered Software with other material, in a - separate file or files, that is not Covered Software. - -1.8. "License" - - means this document. - -1.9. "Licensable" - - means having the right to grant, to the maximum extent possible, whether - at the time of the initial grant or subsequently, any and all of the - rights conveyed by this License. - -1.10. "Modifications" - - means any of the following: - - a. any file in Source Code Form that results from an addition to, - deletion from, or modification of the contents of Covered Software; or - - b. any new file in Source Code Form that contains any Covered Software. - -1.11. "Patent Claims" of a Contributor - - means any patent claim(s), including without limitation, method, - process, and apparatus claims, in any patent Licensable by such - Contributor that would be infringed, but for the grant of the License, - by the making, using, selling, offering for sale, having made, import, - or transfer of either its Contributions or its Contributor Version. - -1.12. "Secondary License" - - means either the GNU General Public License, Version 2.0, the GNU Lesser - General Public License, Version 2.1, the GNU Affero General Public - License, Version 3.0, or any later versions of those licenses. - -1.13. "Source Code Form" - - means the form of the work preferred for making modifications. - -1.14. "You" (or "Your") - - means an individual or a legal entity exercising rights under this - License. For legal entities, "You" includes any entity that controls, is - controlled by, or is under common control with You. For purposes of this - definition, "control" means (a) the power, direct or indirect, to cause - the direction or management of such entity, whether by contract or - otherwise, or (b) ownership of more than fifty percent (50%) of the - outstanding shares or beneficial ownership of such entity. - - -2. License Grants and Conditions - -2.1. Grants - - Each Contributor hereby grants You a world-wide, royalty-free, - non-exclusive license: - - a. under intellectual property rights (other than patent or trademark) - Licensable by such Contributor to use, reproduce, make available, - modify, display, perform, distribute, and otherwise exploit its - Contributions, either on an unmodified basis, with Modifications, or - as part of a Larger Work; and - - b. under Patent Claims of such Contributor to make, use, sell, offer for - sale, have made, import, and otherwise transfer either its - Contributions or its Contributor Version. - -2.2. Effective Date - - The licenses granted in Section 2.1 with respect to any Contribution - become effective for each Contribution on the date the Contributor first - distributes such Contribution. - -2.3. Limitations on Grant Scope - - The licenses granted in this Section 2 are the only rights granted under - this License. No additional rights or licenses will be implied from the - distribution or licensing of Covered Software under this License. - Notwithstanding Section 2.1(b) above, no patent license is granted by a - Contributor: - - a. for any code that a Contributor has removed from Covered Software; or - - b. for infringements caused by: (i) Your and any other third party's - modifications of Covered Software, or (ii) the combination of its - Contributions with other software (except as part of its Contributor - Version); or - - c. under Patent Claims infringed by Covered Software in the absence of - its Contributions. - - This License does not grant any rights in the trademarks, service marks, - or logos of any Contributor (except as may be necessary to comply with - the notice requirements in Section 3.4). - -2.4. Subsequent Licenses - - No Contributor makes additional grants as a result of Your choice to - distribute the Covered Software under a subsequent version of this - License (see Section 10.2) or under the terms of a Secondary License (if - permitted under the terms of Section 3.3). - -2.5. Representation - - Each Contributor represents that the Contributor believes its - Contributions are its original creation(s) or it has sufficient rights to - grant the rights to its Contributions conveyed by this License. - -2.6. Fair Use - - This License is not intended to limit any rights You have under - applicable copyright doctrines of fair use, fair dealing, or other - equivalents. - -2.7. Conditions - - Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in - Section 2.1. - - -3. Responsibilities - -3.1. Distribution of Source Form - - All distribution of Covered Software in Source Code Form, including any - Modifications that You create or to which You contribute, must be under - the terms of this License. You must inform recipients that the Source - Code Form of the Covered Software is governed by the terms of this - License, and how they can obtain a copy of this License. You may not - attempt to alter or restrict the recipients' rights in the Source Code - Form. - -3.2. Distribution of Executable Form - - If You distribute Covered Software in Executable Form then: - - a. such Covered Software must also be made available in Source Code Form, - as described in Section 3.1, and You must inform recipients of the - Executable Form how they can obtain a copy of such Source Code Form by - reasonable means in a timely manner, at a charge no more than the cost - of distribution to the recipient; and - - b. You may distribute such Executable Form under the terms of this - License, or sublicense it under different terms, provided that the - license for the Executable Form does not attempt to limit or alter the - recipients' rights in the Source Code Form under this License. - -3.3. Distribution of a Larger Work - - You may create and distribute a Larger Work under terms of Your choice, - provided that You also comply with the requirements of this License for - the Covered Software. If the Larger Work is a combination of Covered - Software with a work governed by one or more Secondary Licenses, and the - Covered Software is not Incompatible With Secondary Licenses, this - License permits You to additionally distribute such Covered Software - under the terms of such Secondary License(s), so that the recipient of - the Larger Work may, at their option, further distribute the Covered - Software under the terms of either this License or such Secondary - License(s). - -3.4. Notices - - You may not remove or alter the substance of any license notices - (including copyright notices, patent notices, disclaimers of warranty, or - limitations of liability) contained within the Source Code Form of the - Covered Software, except that You may alter any license notices to the - extent required to remedy known factual inaccuracies. - -3.5. Application of Additional Terms - - You may choose to offer, and to charge a fee for, warranty, support, - indemnity or liability obligations to one or more recipients of Covered - Software. However, You may do so only on Your own behalf, and not on - behalf of any Contributor. You must make it absolutely clear that any - such warranty, support, indemnity, or liability obligation is offered by - You alone, and You hereby agree to indemnify every Contributor for any - liability incurred by such Contributor as a result of warranty, support, - indemnity or liability terms You offer. You may include additional - disclaimers of warranty and limitations of liability specific to any - jurisdiction. - -4. Inability to Comply Due to Statute or Regulation - - If it is impossible for You to comply with any of the terms of this License - with respect to some or all of the Covered Software due to statute, - judicial order, or regulation then You must: (a) comply with the terms of - this License to the maximum extent possible; and (b) describe the - limitations and the code they affect. Such description must be placed in a - text file included with all distributions of the Covered Software under - this License. Except to the extent prohibited by statute or regulation, - such description must be sufficiently detailed for a recipient of ordinary - skill to be able to understand it. - -5. Termination - -5.1. The rights granted under this License will terminate automatically if You - fail to comply with any of its terms. However, if You become compliant, - then the rights granted under this License from a particular Contributor - are reinstated (a) provisionally, unless and until such Contributor - explicitly and finally terminates Your grants, and (b) on an ongoing - basis, if such Contributor fails to notify You of the non-compliance by - some reasonable means prior to 60 days after You have come back into - compliance. Moreover, Your grants from a particular Contributor are - reinstated on an ongoing basis if such Contributor notifies You of the - non-compliance by some reasonable means, this is the first time You have - received notice of non-compliance with this License from such - Contributor, and You become compliant prior to 30 days after Your receipt - of the notice. - -5.2. If You initiate litigation against any entity by asserting a patent - infringement claim (excluding declaratory judgment actions, - counter-claims, and cross-claims) alleging that a Contributor Version - directly or indirectly infringes any patent, then the rights granted to - You by any and all Contributors for the Covered Software under Section - 2.1 of this License shall terminate. - -5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user - license agreements (excluding distributors and resellers) which have been - validly granted by You or Your distributors under this License prior to - termination shall survive termination. - -6. Disclaimer of Warranty - - Covered Software is provided under this License on an "as is" basis, - without warranty of any kind, either expressed, implied, or statutory, - including, without limitation, warranties that the Covered Software is free - of defects, merchantable, fit for a particular purpose or non-infringing. - The entire risk as to the quality and performance of the Covered Software - is with You. Should any Covered Software prove defective in any respect, - You (not any Contributor) assume the cost of any necessary servicing, - repair, or correction. This disclaimer of warranty constitutes an essential - part of this License. No use of any Covered Software is authorized under - this License except under this disclaimer. - -7. Limitation of Liability - - Under no circumstances and under no legal theory, whether tort (including - negligence), contract, or otherwise, shall any Contributor, or anyone who - distributes Covered Software as permitted above, be liable to You for any - direct, indirect, special, incidental, or consequential damages of any - character including, without limitation, damages for lost profits, loss of - goodwill, work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses, even if such party shall have been - informed of the possibility of such damages. This limitation of liability - shall not apply to liability for death or personal injury resulting from - such party's negligence to the extent applicable law prohibits such - limitation. Some jurisdictions do not allow the exclusion or limitation of - incidental or consequential damages, so this exclusion and limitation may - not apply to You. - -8. Litigation - - Any litigation relating to this License may be brought only in the courts - of a jurisdiction where the defendant maintains its principal place of - business and such litigation shall be governed by laws of that - jurisdiction, without reference to its conflict-of-law provisions. Nothing - in this Section shall prevent a party's ability to bring cross-claims or - counter-claims. - -9. Miscellaneous - - This License represents the complete agreement concerning the subject - matter hereof. If any provision of this License is held to be - unenforceable, such provision shall be reformed only to the extent - necessary to make it enforceable. Any law or regulation which provides that - the language of a contract shall be construed against the drafter shall not - be used to construe this License against a Contributor. - - -10. Versions of the License - -10.1. New Versions - - Mozilla Foundation is the license steward. Except as provided in Section - 10.3, no one other than the license steward has the right to modify or - publish new versions of this License. Each version will be given a - distinguishing version number. - -10.2. Effect of New Versions - - You may distribute the Covered Software under the terms of the version - of the License under which You originally received the Covered Software, - or under the terms of any subsequent version published by the license - steward. - -10.3. Modified Versions - - If you create software not governed by this License, and you want to - create a new license for such software, you may create and use a - modified version of this License if you rename the license and remove - any references to the name of the license steward (except to note that - such modified license differs from this License). - -10.4. Distributing Source Code Form that is Incompatible With Secondary - Licenses If You choose to distribute Source Code Form that is - Incompatible With Secondary Licenses under the terms of this version of - the License, the notice described in Exhibit B of this License must be - attached. - -Exhibit A - Source Code Form License Notice - - This Source Code Form is subject to the - terms of the Mozilla Public License, v. - 2.0. If a copy of the MPL was not - distributed with this file, You can - obtain one at - http://mozilla.org/MPL/2.0/. - -If it is not possible or desirable to put the notice in a particular file, -then You may include the notice in a location (such as a LICENSE file in a -relevant directory) where a recipient would be likely to look for such a -notice. - -You may add additional accurate notices of copyright ownership. - -Exhibit B - "Incompatible With Secondary Licenses" Notice - - This Source Code Form is "Incompatible - With Secondary Licenses", as defined by - the Mozilla Public License, v. 2.0. diff --git a/vendor/github.com/lucas-clemente/aes12/LICENSE b/vendor/github.com/lucas-clemente/aes12/LICENSE deleted file mode 100644 index 2c08ae24..00000000 --- a/vendor/github.com/lucas-clemente/aes12/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2016 Lucas Clemente - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/vendor/github.com/lucas-clemente/fnv128a/LICENSE b/vendor/github.com/lucas-clemente/fnv128a/LICENSE deleted file mode 100644 index 06dc795d..00000000 --- a/vendor/github.com/lucas-clemente/fnv128a/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2016 Lucas Clemente - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/vendor/github.com/lucas-clemente/fnv128a/fnv128a.go b/vendor/github.com/lucas-clemente/fnv128a/fnv128a.go deleted file mode 100644 index 59212396..00000000 --- a/vendor/github.com/lucas-clemente/fnv128a/fnv128a.go +++ /dev/null @@ -1,87 +0,0 @@ -// Package fnv128a implements FNV-1 and FNV-1a, non-cryptographic hash functions -// created by Glenn Fowler, Landon Curt Noll, and Phong Vo. -// See https://en.wikipedia.org/wiki/Fowler-Noll-Vo_hash_function. -// -// Write() algorithm taken and modified from github.com/romain-jacotin/quic -package fnv128a - -import "hash" - -// Hash128 is the common interface implemented by all 128-bit hash functions. -type Hash128 interface { - hash.Hash - Sum128() (uint64, uint64) -} - -type sum128a struct { - v0, v1, v2, v3 uint64 -} - -var _ Hash128 = &sum128a{} - -// New1 returns a new 128-bit FNV-1a hash.Hash. -func New() Hash128 { - s := &sum128a{} - s.Reset() - return s -} - -func (s *sum128a) Reset() { - s.v0 = 0x6295C58D - s.v1 = 0x62B82175 - s.v2 = 0x07BB0142 - s.v3 = 0x6C62272E -} - -func (s *sum128a) Sum128() (uint64, uint64) { - return s.v3<<32 | s.v2, s.v1<<32 | s.v0 -} - -func (s *sum128a) Write(data []byte) (int, error) { - var t0, t1, t2, t3 uint64 - const fnv128PrimeLow = 0x0000013B - const fnv128PrimeShift = 24 - - for _, v := range data { - // xor the bottom with the current octet - s.v0 ^= uint64(v) - - // multiply by the 128 bit FNV magic prime mod 2^128 - // fnv_prime = 309485009821345068724781371 (decimal) - // = 0x0000000001000000000000000000013B (hexadecimal) - // = 0x00000000 0x01000000 0x00000000 0x0000013B (in 4*32 words) - // = 0x0 1<> 32) - t2 += (t1 >> 32) - t3 += (t2 >> 32) - - s.v0 = t0 & 0xffffffff - s.v1 = t1 & 0xffffffff - s.v2 = t2 & 0xffffffff - s.v3 = t3 // & 0xffffffff - // Doing a s.v3 &= 0xffffffff is not really needed since it simply - // removes multiples of 2^128. We can discard these excess bits - // outside of the loop when writing the hash in Little Endian. - } - - return len(data), nil -} - -func (s *sum128a) Size() int { return 16 } - -func (s *sum128a) BlockSize() int { return 1 } - -func (s *sum128a) Sum(in []byte) []byte { - panic("FNV: not supported") -} diff --git a/vendor/github.com/lucas-clemente/quic-go-certificates/LICENSE b/vendor/github.com/lucas-clemente/quic-go-certificates/LICENSE deleted file mode 100644 index 2c08ae24..00000000 --- a/vendor/github.com/lucas-clemente/quic-go-certificates/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2016 Lucas Clemente - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/vendor/github.com/lucas-clemente/quic-go/client.go b/vendor/github.com/lucas-clemente/quic-go/client.go index 1906abdf..71e72dce 100644 --- a/vendor/github.com/lucas-clemente/quic-go/client.go +++ b/vendor/github.com/lucas-clemente/quic-go/client.go @@ -2,14 +2,14 @@ package quic import ( "bytes" + "context" "crypto/tls" "errors" "fmt" "net" - "strings" "sync" - "time" + "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -20,37 +20,68 @@ import ( type client struct { mutex sync.Mutex - conn connection + conn connection + // If the client is created with DialAddr, we create a packet conn. + // If it is started with Dial, we take a packet conn as a parameter. + createdPacketConn bool + hostname string - versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version - versionNegotiated bool // has the server accepted our version + packetHandlers packetHandlerManager + + token []byte + numRetries int + + versionNegotiated bool // has the server accepted our version receivedVersionNegotiationPacket bool negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet - tlsConf *tls.Config - config *Config - tls handshake.MintTLS // only used when using TLS + tlsConf *tls.Config + mintConf *mint.Config + config *Config - connectionID protocol.ConnectionID + srcConnID protocol.ConnectionID + destConnID protocol.ConnectionID initialVersion protocol.VersionNumber version protocol.VersionNumber - session packetHandler + handshakeChan chan struct{} + closeCallback func(protocol.ConnectionID) + + session quicSession logger utils.Logger } +var _ packetHandler = &client{} + var ( // make it possible to mock connection ID generation in the tests - generateConnectionID = utils.GenerateConnectionID - errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") + generateConnectionID = protocol.GenerateConnectionID + generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial + errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") + errCloseSessionForRetry = errors.New("closing session in response to a stateless retry") ) // DialAddr establishes a new QUIC connection to a server. // The hostname for SNI is taken from the given address. -func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) { +func DialAddr( + addr string, + tlsConf *tls.Config, + config *Config, +) (Session, error) { + return DialAddrContext(context.Background(), addr, tlsConf, config) +} + +// DialAddrContext establishes a new QUIC connection to a server using the provided context. +// The hostname for SNI is taken from the given address. +func DialAddrContext( + ctx context.Context, + addr string, + tlsConf *tls.Config, + config *Config, +) (Session, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err @@ -59,7 +90,7 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) if err != nil { return nil, err } - return Dial(udpConn, udpAddr, addr, tlsConf, config) + return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true) } // Dial establishes a new QUIC connection to a server using a net.PacketConn. @@ -71,16 +102,69 @@ func Dial( tlsConf *tls.Config, config *Config, ) (Session, error) { - connID, err := generateConnectionID() + return DialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config) +} + +// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context. +// The host parameter is used for SNI. +func DialContext( + ctx context.Context, + pconn net.PacketConn, + remoteAddr net.Addr, + host string, + tlsConf *tls.Config, + config *Config, +) (Session, error) { + return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false) +} + +func dialContext( + ctx context.Context, + pconn net.PacketConn, + remoteAddr net.Addr, + host string, + tlsConf *tls.Config, + config *Config, + createdPacketConn bool, +) (Session, error) { + config = populateClientConfig(config, createdPacketConn) + if !createdPacketConn { + for _, v := range config.Versions { + if v == protocol.Version44 { + return nil, errors.New("Cannot multiplex connections using gQUIC 44, see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/pE9NlLLjizE. Please disable gQUIC 44 in the quic.Config, or use DialAddr") + } + } + } + packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength) if err != nil { return nil, err } + c, err := newClient(pconn, remoteAddr, config, tlsConf, host, packetHandlers.Remove, createdPacketConn) + if err != nil { + return nil, err + } + c.packetHandlers = packetHandlers + if err := c.dial(ctx); err != nil { + return nil, err + } + return c.session, nil +} +func newClient( + pconn net.PacketConn, + remoteAddr net.Addr, + config *Config, + tlsConf *tls.Config, + host string, + closeCallback func(protocol.ConnectionID), + createdPacketConn bool, +) (*client, error) { var hostname string if tlsConf != nil { hostname = tlsConf.ServerName } if hostname == "" { + var err error hostname, _, err = net.SplitHostPort(host) if err != nil { return nil, err @@ -95,29 +179,27 @@ func Dial( } } } - clientConfig := populateClientConfig(config) + onClose := func(protocol.ConnectionID) {} + if closeCallback != nil { + onClose = closeCallback + } c := &client{ - conn: &conn{pconn: pconn, currentAddr: remoteAddr}, - connectionID: connID, - hostname: hostname, - tlsConf: tlsConf, - config: clientConfig, - version: clientConfig.Versions[0], - versionNegotiationChan: make(chan struct{}), - logger: utils.DefaultLogger, + conn: &conn{pconn: pconn, currentAddr: remoteAddr}, + createdPacketConn: createdPacketConn, + hostname: hostname, + tlsConf: tlsConf, + config: config, + version: config.Versions[0], + handshakeChan: make(chan struct{}), + closeCallback: onClose, + logger: utils.DefaultLogger.WithPrefix("client"), } - - c.logger.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) - - if err := c.dial(); err != nil { - return nil, err - } - return c.session, nil + return c, c.generateConnectionIDs() } // populateClientConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil -func populateClientConfig(config *Config) *Config { +func populateClientConfig(config *Config, createdPacketConn bool) *Config { if config == nil { config = &Config{} } @@ -155,12 +237,22 @@ func populateClientConfig(config *Config) *Config { } else if maxIncomingUniStreams < 0 { maxIncomingUniStreams = 0 } + connIDLen := config.ConnectionIDLength + if connIDLen == 0 && !createdPacketConn { + connIDLen = protocol.DefaultConnectionIDLength + } + for _, v := range versions { + if v == protocol.Version44 { + connIDLen = 0 + } + } return &Config{ Versions: versions, HandshakeTimeout: handshakeTimeout, IdleTimeout: idleTimeout, RequestConnectionIDOmission: config.RequestConnectionIDOmission, + ConnectionIDLength: connIDLen, MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, MaxIncomingStreams: maxIncomingStreams, @@ -169,28 +261,54 @@ func populateClientConfig(config *Config) *Config { } } -func (c *client) dial() error { +func (c *client) generateConnectionIDs() error { + connIDLen := protocol.ConnectionIDLenGQUIC + if c.version.UsesTLS() { + connIDLen = c.config.ConnectionIDLength + } + srcConnID, err := generateConnectionID(connIDLen) + if err != nil { + return err + } + destConnID := srcConnID + if c.version.UsesTLS() { + destConnID, err = generateConnectionIDForInitial() + if err != nil { + return err + } + } + c.srcConnID = srcConnID + c.destConnID = destConnID + if c.version == protocol.Version44 { + c.srcConnID = nil + } + return nil +} + +func (c *client) dial(ctx context.Context) error { + c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) + var err error if c.version.UsesTLS() { - err = c.dialTLS() + err = c.dialTLS(ctx) } else { - err = c.dialGQUIC() - } - if err == errCloseSessionForNewVersion { - return c.dial() + err = c.dialGQUIC(ctx) } return err } -func (c *client) dialGQUIC() error { +func (c *client) dialGQUIC(ctx context.Context) error { if err := c.createNewGQUICSession(); err != nil { return err } - go c.listen() - return c.establishSecureConnection() + err := c.establishSecureConnection(ctx) + if err == errCloseSessionForNewVersion { + return c.dial(ctx) + } + return err } -func (c *client) dialTLS() error { +func (c *client) dialTLS(ctx context.Context) error { params := &handshake.TransportParameters{ StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, @@ -198,8 +316,8 @@ func (c *client) dialTLS() error { OmitConnectionID: c.config.RequestConnectionIDOmission, MaxBidiStreams: uint16(c.config.MaxIncomingStreams), MaxUniStreams: uint16(c.config.MaxIncomingUniStreams), + DisableMigration: true, } - csc := handshake.NewCryptoStreamConn(nil) extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger) mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient) if err != nil { @@ -207,25 +325,16 @@ func (c *client) dialTLS() error { } mintConf.ExtensionHandler = extHandler mintConf.ServerName = c.hostname - c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient) + c.mintConf = mintConf if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil { return err } - go c.listen() - if err := c.establishSecureConnection(); err != nil { - if err != handshake.ErrCloseSessionForRetry { - return err - } - c.logger.Infof("Received a Retry packet. Recreating session.") - if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil { - return err - } - if err := c.establishSecureConnection(); err != nil { - return err - } + err = c.establishSecureConnection(ctx) + if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion { + return c.dial(ctx) } - return nil + return err } // establishSecureConnection runs the session, and tries to establish a secure connection @@ -234,134 +343,114 @@ func (c *client) dialTLS() error { // - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC) // - any other error that might occur // - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC) -func (c *client) establishSecureConnection() error { - var runErr error - errorChan := make(chan struct{}) +func (c *client) establishSecureConnection(ctx context.Context) error { + errorChan := make(chan error, 1) + go func() { - runErr = c.session.run() // returns as soon as the session is closed - close(errorChan) - c.logger.Infof("Connection %x closed.", c.connectionID) - if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion { + err := c.session.run() // returns as soon as the session is closed + if err != errCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn { c.conn.Close() } + errorChan <- err }() - // wait until the server accepts the QUIC version (or an error occurs) select { - case <-errorChan: - return runErr - case <-c.versionNegotiationChan: - } - - select { - case <-errorChan: - return runErr - case err := <-c.session.handshakeStatus(): + case <-ctx.Done(): + // The session will send a PeerGoingAway error to the server. + c.session.Close() + return ctx.Err() + case err := <-errorChan: return err + case <-c.handshakeChan: + // handshake successfully completed + return nil } } -// Listen listens on the underlying connection and passes packets on for handling. -// It returns when the connection is closed. -func (c *client) listen() { - var err error - - for { - var n int - var addr net.Addr - data := *getPacketBuffer() - data = data[:protocol.MaxReceivePacketSize] - // The packet size should not exceed protocol.MaxReceivePacketSize bytes - // If it does, we only read a truncated packet, which will then end up undecryptable - n, addr, err = c.conn.Read(data) - if err != nil { - if !strings.HasSuffix(err.Error(), "use of closed network connection") { - c.mutex.Lock() - if c.session != nil { - c.session.Close(err) - } - c.mutex.Unlock() - } - break - } - c.handlePacket(addr, data[:n]) +func (c *client) handlePacket(p *receivedPacket) { + if err := c.handlePacketImpl(p); err != nil { + c.logger.Errorf("error handling packet: %s", err) } } -func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { - rcvTime := time.Now() - - r := bytes.NewReader(packet) - hdr, err := wire.ParseHeaderSentByServer(r, c.version) - if err != nil { - c.logger.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) - // drop this packet if we can't parse the header - return - } - // reject packets with truncated connection id if we didn't request truncation - if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission { - return - } - hdr.Raw = packet[:len(packet)-r.Len()] - +func (c *client) handlePacketImpl(p *receivedPacket) error { c.mutex.Lock() defer c.mutex.Unlock() - // reject packets with the wrong connection ID - if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID { - return - } - - if hdr.ResetFlag { - cr := c.conn.RemoteAddr() - // check if the remote address and the connection ID match - // otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection - if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID { - c.logger.Infof("Received a spoofed Public Reset. Ignoring.") - return - } - pr, err := wire.ParsePublicReset(r) - if err != nil { - c.logger.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err) - return - } - c.logger.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber) - c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber))) - return - } - // handle Version Negotiation Packets - if hdr.IsVersionNegotiation { - // ignore delayed / duplicated version negotiation packets - if c.receivedVersionNegotiationPacket || c.versionNegotiated { - return + if p.header.IsVersionNegotiation { + err := c.handleVersionNegotiationPacket(p.header) + if err != nil { + c.session.destroy(err) } - // version negotiation packets have no payload - if err := c.handleVersionNegotiationPacket(hdr); err != nil { - c.session.Close(err) + return err + } + + if !c.version.UsesIETFHeaderFormat() { + connID := p.header.DestConnectionID + // reject packets with truncated connection id if we didn't request truncation + if !c.config.RequestConnectionIDOmission && connID.Len() == 0 { + return errors.New("received packet with truncated connection ID, but didn't request truncation") + } + // reject packets with the wrong connection ID + if connID.Len() > 0 && !connID.Equal(c.srcConnID) { + return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", connID, c.srcConnID) + } + if p.header.ResetFlag { + return c.handlePublicReset(p) + } + } else { + // reject packets with the wrong connection ID + if !p.header.DestConnectionID.Equal(c.srcConnID) { + return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID) + } + } + + if p.header.IsLongHeader { + switch p.header.Type { + case protocol.PacketTypeRetry: + c.handleRetryPacket(p.header) + return nil + case protocol.PacketTypeHandshake, protocol.PacketType0RTT: + default: + return fmt.Errorf("Received unsupported packet type: %s", p.header.Type) } - return } // this is the first packet we are receiving // since it is not a Version Negotiation Packet, this means the server supports the suggested version if !c.versionNegotiated { c.versionNegotiated = true - close(c.versionNegotiationChan) } - // TODO: validate packet number and connection ID on Retry packets (for IETF QUIC) + c.session.handlePacket(p) + return nil +} - c.session.handlePacket(&receivedPacket{ - remoteAddr: remoteAddr, - header: hdr, - data: packet[len(packet)-r.Len():], - rcvTime: rcvTime, - }) +func (c *client) handlePublicReset(p *receivedPacket) error { + cr := c.conn.RemoteAddr() + // check if the remote address and the connection ID match + // otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection + if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !p.header.DestConnectionID.Equal(c.srcConnID) { + return errors.New("Received a spoofed Public Reset") + } + pr, err := wire.ParsePublicReset(bytes.NewReader(p.data)) + if err != nil { + return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err) + } + c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber))) + c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber) + return nil } func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { + // ignore delayed / duplicated version negotiation packets + if c.receivedVersionNegotiationPacket || c.versionNegotiated { + c.logger.Debugf("Received a delayed Version Negotiation Packet.") + return nil + } + for _, v := range hdr.SupportedVersions { if v == c.version { // the version negotiation packet contains the version that we offered @@ -372,7 +461,6 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { } c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions) - newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) if !ok { return qerr.InvalidVersion @@ -383,49 +471,125 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { // switch to negotiated version c.initialVersion = c.version c.version = newVersion - var err error - c.connectionID, err = utils.GenerateConnectionID() - if err != nil { + if err := c.generateConnectionIDs(); err != nil { return err } - c.logger.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID) - c.session.Close(errCloseSessionForNewVersion) + + c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID) + c.session.destroy(errCloseSessionForNewVersion) return nil } -func (c *client) createNewGQUICSession() (err error) { +func (c *client) handleRetryPacket(hdr *wire.Header) { + c.logger.Debugf("<- Received Retry") + hdr.Log(c.logger) + // A server that performs multiple retries must use a source connection ID of at least 8 bytes. + // Only a server that won't send additional Retries can use shorter connection IDs. + if hdr.OrigDestConnectionID.Len() < protocol.MinConnectionIDLenInitial { + c.logger.Debugf("Received a Retry with a too short Original Destination Connection ID: %d bytes, must have at least %d bytes.", hdr.OrigDestConnectionID.Len(), protocol.MinConnectionIDLenInitial) + return + } + if !hdr.OrigDestConnectionID.Equal(c.destConnID) { + c.logger.Debugf("Received spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID) + return + } + c.numRetries++ + if c.numRetries > protocol.MaxRetries { + c.session.destroy(qerr.CryptoTooManyRejects) + return + } + c.destConnID = hdr.SrcConnectionID + c.token = hdr.Token + c.session.destroy(errCloseSessionForRetry) +} + +func (c *client) createNewGQUICSession() error { c.mutex.Lock() defer c.mutex.Unlock() - c.session, err = newClientSession( + runner := &runner{ + onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) }, + removeConnectionIDImpl: c.closeCallback, + } + sess, err := newClientSession( c.conn, + runner, c.hostname, c.version, - c.connectionID, + c.destConnID, + c.srcConnID, c.tlsConf, c.config, c.initialVersion, c.negotiatedVersions, c.logger, ) - return err + if err != nil { + return err + } + c.session = sess + c.packetHandlers.Add(c.srcConnID, c) + if c.config.RequestConnectionIDOmission { + c.packetHandlers.Add(protocol.ConnectionID{}, c) + } + return nil } func (c *client) createNewTLSSession( paramsChan <-chan handshake.TransportParameters, version protocol.VersionNumber, -) (err error) { +) error { c.mutex.Lock() defer c.mutex.Unlock() - c.session, err = newTLSClientSession( + runner := &runner{ + onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) }, + removeConnectionIDImpl: c.closeCallback, + } + sess, err := newTLSClientSession( c.conn, - c.hostname, - c.version, - c.connectionID, + runner, + c.token, + c.destConnID, + c.srcConnID, c.config, - c.tls, + c.mintConf, paramsChan, 1, c.logger, + c.version, ) - return err + if err != nil { + return err + } + c.session = sess + c.packetHandlers.Add(c.srcConnID, c) + return nil +} + +func (c *client) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + if c.session == nil { + return nil + } + return c.session.Close() +} + +func (c *client) destroy(e error) { + c.mutex.Lock() + defer c.mutex.Unlock() + if c.session == nil { + return + } + c.session.destroy(e) +} + +func (c *client) GetVersion() protocol.VersionNumber { + c.mutex.Lock() + v := c.version + c.mutex.Unlock() + return v +} + +func (c *client) GetPerspective() protocol.Perspective { + return protocol.PerspectiveClient } diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go b/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go index 8e96ec10..a5ec4ecf 100644 --- a/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go +++ b/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go @@ -8,7 +8,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) -type cryptoStreamI interface { +type cryptoStream interface { StreamID() protocol.StreamID io.Reader io.Writer @@ -21,21 +21,21 @@ type cryptoStreamI interface { handleMaxStreamDataFrame(*wire.MaxStreamDataFrame) } -type cryptoStream struct { +type cryptoStreamImpl struct { *stream } -var _ cryptoStreamI = &cryptoStream{} +var _ cryptoStream = &cryptoStreamImpl{} -func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI { +func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStream { str := newStream(version.CryptoStreamID(), sender, flowController, version) - return &cryptoStream{str} + return &cryptoStreamImpl{str} } // SetReadOffset sets the read offset. // It is only needed for the crypto stream. // It must not be called concurrently with any other stream methods, especially Read and Write. -func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) { +func (s *cryptoStreamImpl) setReadOffset(offset protocol.ByteCount) { s.receiveStream.readOffset = offset - s.receiveStream.frameQueue.readPosition = offset + s.receiveStream.frameQueue.readPos = offset } diff --git a/vendor/github.com/lucas-clemente/quic-go/stream_frame_sorter.go b/vendor/github.com/lucas-clemente/quic-go/frame_sorter.go similarity index 56% rename from vendor/github.com/lucas-clemente/quic-go/stream_frame_sorter.go rename to vendor/github.com/lucas-clemente/quic-go/frame_sorter.go index e3a3a807..47062c06 100644 --- a/vendor/github.com/lucas-clemente/quic-go/stream_frame_sorter.go +++ b/vendor/github.com/lucas-clemente/quic-go/frame_sorter.go @@ -5,51 +5,55 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" ) -type streamFrameSorter struct { - queuedFrames map[protocol.ByteCount]*wire.StreamFrame - readPosition protocol.ByteCount - gaps *utils.ByteIntervalList +type frameSorter struct { + queue map[protocol.ByteCount][]byte + readPos protocol.ByteCount + finalOffset protocol.ByteCount + gaps *utils.ByteIntervalList } -var ( - errTooManyGapsInReceivedStreamData = errors.New("Too many gaps in received StreamFrame data") - errDuplicateStreamData = errors.New("Duplicate Stream Data") - errEmptyStreamData = errors.New("Stream Data empty") -) +var errDuplicateStreamData = errors.New("Duplicate Stream Data") -func newStreamFrameSorter() *streamFrameSorter { - s := streamFrameSorter{ - gaps: utils.NewByteIntervalList(), - queuedFrames: make(map[protocol.ByteCount]*wire.StreamFrame), +func newFrameSorter() *frameSorter { + s := frameSorter{ + gaps: utils.NewByteIntervalList(), + queue: make(map[protocol.ByteCount][]byte), + finalOffset: protocol.MaxByteCount, } s.gaps.PushFront(utils.ByteInterval{Start: 0, End: protocol.MaxByteCount}) return &s } -func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error { - if frame.DataLen() == 0 { - if frame.FinBit { - s.queuedFrames[frame.Offset] = frame - return nil - } - return errEmptyStreamData +func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, fin bool) error { + err := s.push(data, offset, fin) + if err == errDuplicateStreamData { + return nil + } + return err +} + +func (s *frameSorter) push(data []byte, offset protocol.ByteCount, fin bool) error { + if fin { + s.finalOffset = offset + protocol.ByteCount(len(data)) + } + if len(data) == 0 { + return nil } var wasCut bool - if oldFrame, ok := s.queuedFrames[frame.Offset]; ok { - if frame.DataLen() <= oldFrame.DataLen() { + if oldData, ok := s.queue[offset]; ok { + if len(data) <= len(oldData) { return errDuplicateStreamData } - frame.Data = frame.Data[oldFrame.DataLen():] - frame.Offset += oldFrame.DataLen() + data = data[len(oldData):] + offset += protocol.ByteCount(len(oldData)) wasCut = true } - start := frame.Offset - end := frame.Offset + frame.DataLen() + start := offset + end := offset + protocol.ByteCount(len(data)) // skip all gaps that are before this stream frame var gap *utils.ByteIntervalElement @@ -69,9 +73,9 @@ func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error { if start < gap.Value.Start { add := gap.Value.Start - start - frame.Offset += add + offset += add start += add - frame.Data = frame.Data[add:] + data = data[add:] wasCut = true } @@ -89,15 +93,15 @@ func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error { break } // delete queued frames completely covered by the current frame - delete(s.queuedFrames, endGap.Value.End) + delete(s.queue, endGap.Value.End) endGap = nextEndGap } if end > endGap.Value.End { cutLen := end - endGap.Value.End - len := frame.DataLen() - cutLen + len := protocol.ByteCount(len(data)) - cutLen end -= cutLen - frame.Data = frame.Data[:len] + data = data[:len] wasCut = true } @@ -130,32 +134,25 @@ func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error { } if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps { - return errTooManyGapsInReceivedStreamData + return errors.New("Too many gaps in received data") } if wasCut { - data := make([]byte, frame.DataLen()) - copy(data, frame.Data) - frame.Data = data + newData := make([]byte, len(data)) + copy(newData, data) + data = newData } - s.queuedFrames[frame.Offset] = frame + s.queue[offset] = data return nil } -func (s *streamFrameSorter) Pop() *wire.StreamFrame { - frame := s.Head() - if frame != nil { - s.readPosition += frame.DataLen() - delete(s.queuedFrames, frame.Offset) +func (s *frameSorter) Pop() ([]byte /* data */, bool /* fin */) { + data, ok := s.queue[s.readPos] + if !ok { + return nil, s.readPos >= s.finalOffset } - return frame -} - -func (s *streamFrameSorter) Head() *wire.StreamFrame { - frame, ok := s.queuedFrames[s.readPosition] - if ok { - return frame - } - return nil + delete(s.queue, s.readPos) + s.readPos += protocol.ByteCount(len(data)) + return data, s.readPos >= s.finalOffset } diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go index 40980882..ac28a7f0 100644 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go @@ -77,7 +77,7 @@ func newClient( opts: opts, headerErrored: make(chan struct{}), dialer: dialer, - logger: utils.DefaultLogger, + logger: utils.DefaultLogger.WithPrefix("client"), } } @@ -172,7 +172,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { responseChan := make(chan *http.Response) dataStream, err := c.session.OpenStreamSync() if err != nil { - _ = c.CloseWithError(err) + _ = c.closeWithError(err) return nil, err } c.mutex.Lock() @@ -187,7 +187,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { endStream := !hasBody err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip) if err != nil { - _ = c.CloseWithError(err) + _ = c.closeWithError(err) return nil, err } @@ -230,7 +230,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { return nil, ctx.Err() case <-c.headerErrored: // an error occurred on the header stream - _ = c.CloseWithError(c.headerErr) + _ = c.closeWithError(c.headerErr) return nil, c.headerErr } } @@ -275,16 +275,19 @@ func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e return dataStream.Close() } -// Close closes the client -func (c *client) CloseWithError(e error) error { +func (c *client) closeWithError(e error) error { if c.session == nil { return nil } - return c.session.Close(e) + return c.session.CloseWithError(quic.ErrorCode(qerr.InternalError), e) } +// Close closes the client func (c *client) Close() error { - return c.CloseWithError(nil) + if c.session == nil { + return nil + } + return c.session.Close() } // copied from net/transport.go diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/request.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/request.go index 911485ef..b27e37e2 100644 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/request.go +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/request.go @@ -70,9 +70,6 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) { } func hostnameFromRequest(req *http.Request) string { - if len(req.Host) > 0 { - return req.Host - } if req.URL != nil { return req.URL.Host } diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go index ddaaa741..b5ee9751 100644 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go @@ -8,9 +8,9 @@ import ( "strings" "sync" + "golang.org/x/net/http/httpguts" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" - "golang.org/x/net/lex/httplex" quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -65,7 +65,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra if host == "" { host = req.URL.Host } - host, err := httplex.PunycodeHostPort(host) + host, err := httpguts.PunycodeHostPort(host) if err != nil { return nil, err } @@ -89,11 +89,11 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra // potentially pollute our hpack state. (We want to be able to // continue to reuse the hpack encoder for future requests) for k, vv := range req.Header { - if !httplex.ValidHeaderFieldName(k) { + if !httpguts.ValidHeaderFieldName(k) { return nil, fmt.Errorf("invalid HTTP header name %q", k) } for _, v := range vv { - if !httplex.ValidHeaderFieldValue(v) { + if !httpguts.ValidHeaderFieldValue(v) { return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k) } } diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go index 25b77a54..02841227 100644 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go @@ -98,9 +98,6 @@ func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) } // test that we implement http.Flusher var _ http.Flusher = &responseWriter{} -// test that we implement http.CloseNotifier -var _ http.CloseNotifier = &responseWriter{} - // copied from http2/http2.go // bodyAllowedForStatus reports whether a given response status code // permits a body. See RFC 2616, section 4.4. diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer_closenotifier.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer_closenotifier.go new file mode 100644 index 00000000..b26f91c1 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer_closenotifier.go @@ -0,0 +1,9 @@ +package h2quic + +import "net/http" + +// The CloseNotifier is a deprecated interface, and staticcheck will report that from Go 1.11. +// By defining it in a separate file, we can exclude this file from staticcheck. + +// test that we implement http.CloseNotifier +var _ http.CloseNotifier = &responseWriter{} diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go index f6c170b0..27732b5e 100644 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go @@ -11,7 +11,7 @@ import ( quic "github.com/lucas-clemente/quic-go" - "golang.org/x/net/lex/httplex" + "golang.org/x/net/http/httpguts" ) type roundTripCloser interface { @@ -80,11 +80,11 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. if req.URL.Scheme == "https" { for k, vv := range req.Header { - if !httplex.ValidHeaderFieldName(k) { + if !httpguts.ValidHeaderFieldName(k) { return nil, fmt.Errorf("quic: invalid http header field name %q", k) } for _, v := range vv { - if !httplex.ValidHeaderFieldValue(v) { + if !httpguts.ValidHeaderFieldValue(v) { return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k) } } @@ -175,5 +175,5 @@ func validMethod(method string) bool { // copied from net/http/http.go func isNotToken(r rune) bool { - return !httplex.IsTokenRune(r) + return !httpguts.IsTokenRune(r) } diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go index a2412bd1..0d787bdd 100644 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go @@ -90,7 +90,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { if s.Server == nil { return errors.New("use of h2quic.Server without http.Server") } - s.logger = utils.DefaultLogger + s.logger = utils.DefaultLogger.WithPrefix("server") s.listenerMutex.Lock() if s.closed { s.listenerMutex.Unlock() @@ -127,7 +127,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { func (s *Server) handleHeaderStream(session streamCreator) { stream, err := session.AcceptStream() if err != nil { - session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error())) + session.CloseWithError(quic.ErrorCode(qerr.InvalidHeadersStreamData), err) return } @@ -140,10 +140,12 @@ func (s *Server) handleHeaderStream(session streamCreator) { // QuicErrors must originate from stream.Read() returning an error. // In this case, the session has already logged the error, so we don't // need to log it again. - if _, ok := err.(*qerr.QuicError); !ok { + errorCode := qerr.InternalError + if qerr, ok := err.(*qerr.QuicError); !ok { + errorCode = qerr.ErrorCode s.logger.Errorf("error handling h2 request: %s", err.Error()) } - session.Close(err) + session.CloseWithError(quic.ErrorCode(errorCode), err) return } } @@ -154,10 +156,18 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, if err != nil { return qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame") } - h2headersFrame, ok := h2frame.(*http2.HeadersFrame) - if !ok { + var h2headersFrame *http2.HeadersFrame + switch f := h2frame.(type) { + case *http2.PriorityFrame: + // ignore PRIORITY frames + s.logger.Debugf("Ignoring H2 PRIORITY frame: %#v", f) + return nil + case *http2.HeadersFrame: + h2headersFrame = f + default: return qerr.Error(qerr.InvalidHeadersStreamData, "expected a header frame") } + if !h2headersFrame.HeadersEnded() { return errors.New("http2 header continuation not implemented") } @@ -238,7 +248,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, } if s.CloseAfterFirstRequest { time.Sleep(100 * time.Millisecond) - session.Close(nil) + session.Close() } }() diff --git a/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go b/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go deleted file mode 100644 index d12a3bae..00000000 --- a/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go +++ /dev/null @@ -1,270 +0,0 @@ -package quicproxy - -import ( - "net" - "sync" - "sync/atomic" - "time" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" -) - -// Connection is a UDP connection -type connection struct { - ClientAddr *net.UDPAddr // Address of the client - ServerConn *net.UDPConn // UDP connection to server - - incomingPacketCounter uint64 - outgoingPacketCounter uint64 -} - -// Direction is the direction a packet is sent. -type Direction int - -const ( - // DirectionIncoming is the direction from the client to the server. - DirectionIncoming Direction = iota - // DirectionOutgoing is the direction from the server to the client. - DirectionOutgoing - // DirectionBoth is both incoming and outgoing - DirectionBoth -) - -func (d Direction) String() string { - switch d { - case DirectionIncoming: - return "incoming" - case DirectionOutgoing: - return "outgoing" - case DirectionBoth: - return "both" - default: - panic("unknown direction") - } -} - -// Is says if one direction matches another direction. -// For example, incoming matches both incoming and both, but not outgoing. -func (d Direction) Is(dir Direction) bool { - if d == DirectionBoth || dir == DirectionBoth { - return true - } - return d == dir -} - -// DropCallback is a callback that determines which packet gets dropped. -type DropCallback func(dir Direction, packetCount uint64) bool - -// NoDropper doesn't drop packets. -var NoDropper DropCallback = func(Direction, uint64) bool { - return false -} - -// DelayCallback is a callback that determines how much delay to apply to a packet. -type DelayCallback func(dir Direction, packetCount uint64) time.Duration - -// NoDelay doesn't apply a delay. -var NoDelay DelayCallback = func(Direction, uint64) time.Duration { - return 0 -} - -// Opts are proxy options. -type Opts struct { - // The address this proxy proxies packets to. - RemoteAddr string - // DropPacket determines whether a packet gets dropped. - DropPacket DropCallback - // DelayPacket determines how long a packet gets delayed. This allows - // simulating a connection with non-zero RTTs. - // Note that the RTT is the sum of the delay for the incoming and the outgoing packet. - DelayPacket DelayCallback -} - -// QuicProxy is a QUIC proxy that can drop and delay packets. -type QuicProxy struct { - mutex sync.Mutex - - version protocol.VersionNumber - - conn *net.UDPConn - serverAddr *net.UDPAddr - - dropPacket DropCallback - delayPacket DelayCallback - - // Mapping from client addresses (as host:port) to connection - clientDict map[string]*connection - - logger utils.Logger -} - -// NewQuicProxy creates a new UDP proxy -func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*QuicProxy, error) { - if opts == nil { - opts = &Opts{} - } - laddr, err := net.ResolveUDPAddr("udp", local) - if err != nil { - return nil, err - } - conn, err := net.ListenUDP("udp", laddr) - if err != nil { - return nil, err - } - raddr, err := net.ResolveUDPAddr("udp", opts.RemoteAddr) - if err != nil { - return nil, err - } - - packetDropper := NoDropper - if opts.DropPacket != nil { - packetDropper = opts.DropPacket - } - - packetDelayer := NoDelay - if opts.DelayPacket != nil { - packetDelayer = opts.DelayPacket - } - - p := QuicProxy{ - clientDict: make(map[string]*connection), - conn: conn, - serverAddr: raddr, - dropPacket: packetDropper, - delayPacket: packetDelayer, - version: version, - logger: utils.DefaultLogger, - } - - p.logger.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr) - go p.runProxy() - return &p, nil -} - -// Close stops the UDP Proxy -func (p *QuicProxy) Close() error { - p.mutex.Lock() - defer p.mutex.Unlock() - for _, c := range p.clientDict { - if err := c.ServerConn.Close(); err != nil { - return err - } - } - return p.conn.Close() -} - -// LocalAddr is the address the proxy is listening on. -func (p *QuicProxy) LocalAddr() net.Addr { - return p.conn.LocalAddr() -} - -// LocalPort is the UDP port number the proxy is listening on. -func (p *QuicProxy) LocalPort() int { - return p.conn.LocalAddr().(*net.UDPAddr).Port -} - -func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) { - srvudp, err := net.DialUDP("udp", nil, p.serverAddr) - if err != nil { - return nil, err - } - return &connection{ - ClientAddr: cliAddr, - ServerConn: srvudp, - }, nil -} - -// runProxy listens on the proxy address and handles incoming packets. -func (p *QuicProxy) runProxy() error { - for { - buffer := make([]byte, protocol.MaxReceivePacketSize) - n, cliaddr, err := p.conn.ReadFromUDP(buffer) - if err != nil { - return err - } - raw := buffer[0:n] - - saddr := cliaddr.String() - p.mutex.Lock() - conn, ok := p.clientDict[saddr] - - if !ok { - conn, err = p.newConnection(cliaddr) - if err != nil { - p.mutex.Unlock() - return err - } - p.clientDict[saddr] = conn - go p.runConnection(conn) - } - p.mutex.Unlock() - - packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1) - - if p.dropPacket(DirectionIncoming, packetCount) { - if p.logger.Debug() { - p.logger.Debugf("dropping incoming packet %d (%d bytes)", packetCount, n) - } - continue - } - - // Send the packet to the server - delay := p.delayPacket(DirectionIncoming, packetCount) - if delay != 0 { - if p.logger.Debug() { - p.logger.Debugf("delaying incoming packet %d (%d bytes) to %s by %s", packetCount, n, conn.ServerConn.RemoteAddr(), delay) - } - time.AfterFunc(delay, func() { - // TODO: handle error - _, _ = conn.ServerConn.Write(raw) - }) - } else { - if p.logger.Debug() { - p.logger.Debugf("forwarding incoming packet %d (%d bytes) to %s", packetCount, n, conn.ServerConn.RemoteAddr()) - } - if _, err := conn.ServerConn.Write(raw); err != nil { - return err - } - } - } -} - -// runConnection handles packets from server to a single client -func (p *QuicProxy) runConnection(conn *connection) error { - for { - buffer := make([]byte, protocol.MaxReceivePacketSize) - n, err := conn.ServerConn.Read(buffer) - if err != nil { - return err - } - raw := buffer[0:n] - - packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1) - - if p.dropPacket(DirectionOutgoing, packetCount) { - if p.logger.Debug() { - p.logger.Debugf("dropping outgoing packet %d (%d bytes)", packetCount, n) - } - continue - } - - delay := p.delayPacket(DirectionOutgoing, packetCount) - if delay != 0 { - if p.logger.Debug() { - p.logger.Debugf("delaying outgoing packet %d (%d bytes) to %s by %s", packetCount, n, conn.ClientAddr, delay) - } - time.AfterFunc(delay, func() { - // TODO: handle error - _, _ = p.conn.WriteToUDP(raw, conn.ClientAddr) - }) - } else { - if p.logger.Debug() { - p.logger.Debugf("forwarding outgoing packet %d (%d bytes) to %s", packetCount, n, conn.ClientAddr) - } - if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil { - return err - } - } - } -} diff --git a/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go b/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go deleted file mode 100644 index c987ddb7..00000000 --- a/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go +++ /dev/null @@ -1,41 +0,0 @@ -package testlog - -import ( - "flag" - "log" - "os" - - "github.com/lucas-clemente/quic-go/internal/utils" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var ( - logFileName string // the log file set in the ginkgo flags - logFile *os.File -) - -// read the logfile command line flag -// to set call ginkgo -- -logfile=log.txt -func init() { - flag.StringVar(&logFileName, "logfile", "", "log file") -} - -var _ = BeforeEach(func() { - log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds) - - if len(logFileName) > 0 { - var err error - logFile, err = os.Create(logFileName) - Expect(err).ToNot(HaveOccurred()) - log.SetOutput(logFile) - utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug) - } -}) - -var _ = AfterEach(func() { - if len(logFileName) > 0 { - _ = logFile.Close() - } -}) diff --git a/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go b/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go deleted file mode 100644 index 70ba2dda..00000000 --- a/vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go +++ /dev/null @@ -1,119 +0,0 @@ -package testserver - -import ( - "io" - "io/ioutil" - "net" - "net/http" - "strconv" - - quic "github.com/lucas-clemente/quic-go" - "github.com/lucas-clemente/quic-go/h2quic" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/testdata" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -const ( - dataLen = 500 * 1024 // 500 KB - dataLenLong = 50 * 1024 * 1024 // 50 MB -) - -var ( - // PRData contains dataLen bytes of pseudo-random data. - PRData = GeneratePRData(dataLen) - // PRDataLong contains dataLenLong bytes of pseudo-random data. - PRDataLong = GeneratePRData(dataLenLong) - - server *h2quic.Server - stoppedServing chan struct{} - port string -) - -func init() { - http.HandleFunc("/prdata", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - sl := r.URL.Query().Get("len") - if sl != "" { - var err error - l, err := strconv.Atoi(sl) - Expect(err).NotTo(HaveOccurred()) - _, err = w.Write(GeneratePRData(l)) - Expect(err).NotTo(HaveOccurred()) - } else { - _, err := w.Write(PRData) - Expect(err).NotTo(HaveOccurred()) - } - }) - - http.HandleFunc("/prdatalong", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - _, err := w.Write(PRDataLong) - Expect(err).NotTo(HaveOccurred()) - }) - - http.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - _, err := io.WriteString(w, "Hello, World!\n") - Expect(err).NotTo(HaveOccurred()) - }) - - http.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - body, err := ioutil.ReadAll(r.Body) - Expect(err).NotTo(HaveOccurred()) - _, err = w.Write(body) - Expect(err).NotTo(HaveOccurred()) - }) -} - -// See https://en.wikipedia.org/wiki/Lehmer_random_number_generator -func GeneratePRData(l int) []byte { - res := make([]byte, l) - seed := uint64(1) - for i := 0; i < l; i++ { - seed = seed * 48271 % 2147483647 - res[i] = byte(seed) - } - return res -} - -// StartQuicServer starts a h2quic.Server. -// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used. -func StartQuicServer(versions []protocol.VersionNumber) { - server = &h2quic.Server{ - Server: &http.Server{ - TLSConfig: testdata.GetTLSConfig(), - }, - QuicConfig: &quic.Config{ - Versions: versions, - }, - } - - addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0") - Expect(err).NotTo(HaveOccurred()) - conn, err := net.ListenUDP("udp", addr) - Expect(err).NotTo(HaveOccurred()) - port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port) - - stoppedServing = make(chan struct{}) - - go func() { - defer GinkgoRecover() - server.Serve(conn) - close(stoppedServing) - }() -} - -// StopQuicServer stops the h2quic.Server. -func StopQuicServer() { - Expect(server.Close()).NotTo(HaveOccurred()) - Eventually(stoppedServing).Should(BeClosed()) -} - -// Port returns the UDP port of the QUIC server. -func Port() string { - return port -} diff --git a/vendor/github.com/lucas-clemente/quic-go/interface.go b/vendor/github.com/lucas-clemente/quic-go/interface.go index 3ab64afd..d7048097 100644 --- a/vendor/github.com/lucas-clemente/quic-go/interface.go +++ b/vendor/github.com/lucas-clemente/quic-go/interface.go @@ -16,8 +16,16 @@ type StreamID = protocol.StreamID // A VersionNumber is a QUIC version number. type VersionNumber = protocol.VersionNumber -// VersionGQUIC39 is gQUIC version 39. -const VersionGQUIC39 = protocol.Version39 +const ( + // VersionGQUIC39 is gQUIC version 39. + VersionGQUIC39 = protocol.Version39 + // VersionGQUIC43 is gQUIC version 43. + VersionGQUIC43 = protocol.Version43 + // VersionGQUIC44 is gQUIC version 44. + VersionGQUIC44 = protocol.Version44 + // VersionMilestone0_10_0 uses TLS + VersionMilestone0_10_0 = protocol.VersionMilestone0_10_0 +) // A Cookie can be used to verify the ownership of the client address. type Cookie = handshake.Cookie @@ -139,8 +147,11 @@ type Session interface { LocalAddr() net.Addr // RemoteAddr returns the address of the peer. RemoteAddr() net.Addr - // Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent. - Close(error) error + // Close the connection. + io.Closer + // Close the connection with an error. + // The error must not be nil. + CloseWithError(ErrorCode, error) error // The context is cancelled when the session is closed. // Warning: This API should not be considered stable and might change soon. Context() context.Context @@ -159,6 +170,13 @@ type Config struct { // This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated. // Currently only valid for the client. RequestConnectionIDOmission bool + // The length of the connection ID in bytes. Only valid for IETF QUIC. + // It can be 0, or any value between 4 and 18. + // If not set, the interpretation depends on where the Config is used: + // If used for dialing an address, a 0 byte connection ID will be used. + // If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used. + // When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call. + ConnectionIDLength int // HandshakeTimeout is the maximum duration that the cryptographic handshake may take. // If the timeout is exceeded, the connection is closed. // If this value is zero, the timeout is set to 10 seconds. diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go index 43027dcf..1924cdc9 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go @@ -29,7 +29,8 @@ type SentPacketHandler interface { GetStopWaitingFrame(force bool) *wire.StopWaitingFrame GetLowestPacketNotConfirmedAcked() protocol.PacketNumber - DequeuePacketForRetransmission() (packet *Packet) + DequeuePacketForRetransmission() *Packet + DequeueProbePacket() (*Packet, error) GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen GetAlarmTimeout() time.Time diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go index 10200f4c..8af21324 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go @@ -25,6 +25,8 @@ type receivedPacketHandler struct { ackAlarm time.Time lastAck *wire.AckFrame + logger utils.Logger + version protocol.VersionNumber } @@ -52,11 +54,16 @@ const ( ) // NewReceivedPacketHandler creates a new receivedPacketHandler -func NewReceivedPacketHandler(rttStats *congestion.RTTStats, version protocol.VersionNumber) ReceivedPacketHandler { +func NewReceivedPacketHandler( + rttStats *congestion.RTTStats, + logger utils.Logger, + version protocol.VersionNumber, +) ReceivedPacketHandler { return &receivedPacketHandler{ packetHistory: newReceivedPacketHistory(), ackSendDelay: ackSendDelay, rttStats: rttStats, + logger: logger, version: version, } } @@ -82,16 +89,22 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe // IgnoreBelow sets a lower limit for acking packets. // Packets with packet numbers smaller than p will not be acked. func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) { + if p <= h.ignoreBelow { + return + } h.ignoreBelow = p h.packetHistory.DeleteBelow(p) + if h.logger.Debug() { + h.logger.Debugf("\tIgnoring all packets below %#x.", p) + } } // isMissing says if a packet was reported missing in the last ACK. func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool { - if h.lastAck == nil { + if h.lastAck == nil || p < h.ignoreBelow { return false } - return p < h.lastAck.LargestAcked && !h.lastAck.AcksPacket(p) + return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p) } func (h *receivedPacketHandler) hasNewMissingPackets() bool { @@ -99,7 +112,7 @@ func (h *receivedPacketHandler) hasNewMissingPackets() bool { return false } highestRange := h.packetHistory.GetHighestAckRange() - return highestRange.First >= h.lastAck.LargestAcked && highestRange.Len() <= maxPacketsAfterNewMissing + return highestRange.Smallest >= h.lastAck.LargestAcked() && highestRange.Len() <= maxPacketsAfterNewMissing } // maybeQueueAck queues an ACK, if necessary. @@ -110,6 +123,7 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber // always ack the first packet if h.lastAck == nil { + h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.") h.ackQueued = true return } @@ -118,6 +132,9 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber // Ack decimation with reordering relies on the timer to send an ACK, but if // missing packets we reported in the previous ack, send an ACK immediately. if wasMissing { + if h.logger.Debug() { + h.logger.Debugf("\tQueueing ACK because packet %#x was missing before.", packetNumber) + } h.ackQueued = true } @@ -128,26 +145,41 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber // ack up to 10 packets at once if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck { h.ackQueued = true + if h.logger.Debug() { + h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, retransmittablePacketsBeforeAck) + } } else if h.ackAlarm.IsZero() { // wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay))) h.ackAlarm = rcvTime.Add(ackDelay) + if h.logger.Debug() { + h.logger.Debugf("\tSetting ACK timer to min(1/4 min-RTT, max ack delay): %s (%s from now)", ackDelay, time.Until(h.ackAlarm)) + } } } else { // send an ACK every 2 retransmittable packets if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck { + if h.logger.Debug() { + h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, initialRetransmittablePacketsBeforeAck) + } h.ackQueued = true } else if h.ackAlarm.IsZero() { + if h.logger.Debug() { + h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", ackSendDelay) + } h.ackAlarm = rcvTime.Add(ackSendDelay) } } // If there are new missing packets to report, set a short timer to send an ACK. if h.hasNewMissingPackets() { // wait the minimum of 1/8 min RTT and the existing ack time - ackDelay := float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay) - ackTime := rcvTime.Add(time.Duration(ackDelay)) + ackDelay := time.Duration(float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay)) + ackTime := rcvTime.Add(ackDelay) if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) { h.ackAlarm = ackTime + if h.logger.Debug() { + h.logger.Debugf("\tSetting ACK timer to 1/8 min-RTT: %s (%s from now)", ackDelay, time.Until(h.ackAlarm)) + } } } } @@ -159,19 +191,17 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber } func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame { - if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) { + now := time.Now() + if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) { return nil } - - ackRanges := h.packetHistory.GetAckRanges() - ack := &wire.AckFrame{ - LargestAcked: h.largestObserved, - LowestAcked: ackRanges[len(ackRanges)-1].First, - PacketReceivedTime: h.largestObservedReceivedTime, + if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() { + h.logger.Debugf("Sending ACK because the ACK timer expired.") } - if len(ackRanges) > 1 { - ack.AckRanges = ackRanges + ack := &wire.AckFrame{ + AckRanges: h.packetHistory.GetAckRanges(), + DelayTime: now.Sub(h.largestObservedReceivedTime), } h.lastAck = ack diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go index ba119544..758286df 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go @@ -104,7 +104,7 @@ func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange { ackRanges := make([]wire.AckRange, h.ranges.Len()) i := 0 for el := h.ranges.Back(); el != nil; el = el.Prev() { - ackRanges[i] = wire.AckRange{First: el.Value.Start, Last: el.Value.End} + ackRanges[i] = wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End} i++ } return ackRanges @@ -114,8 +114,8 @@ func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange { ackRange := wire.AckRange{} if h.ranges.Len() > 0 { r := h.ranges.Back().Value - ackRange.First = r.Start - ackRange.Last = r.End + ackRange.Smallest = r.Start + ackRange.Largest = r.End } return ackRange } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go index 61573a47..76c833c4 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go @@ -14,7 +14,9 @@ const ( SendRetransmission // SendRTO means that an RTO probe packet should be sent SendRTO - // SendAny packet should be sent + // SendTLP means that a TLP probe packet should be sent + SendTLP + // SendAny means that any packet should be sent SendAny ) @@ -28,6 +30,8 @@ func (s SendMode) String() string { return "retransmission" case SendRTO: return "rto" + case SendTLP: + return "tlp" case SendAny: return "any" default: diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go index 72ab4df8..4fdb8c36 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go @@ -1,6 +1,7 @@ package ackhandler import ( + "errors" "fmt" "math" "time" @@ -16,13 +17,12 @@ const ( // Maximum reordering in time space before time based loss detection considers a packet lost. // In fraction of an RTT. timeReorderingFraction = 1.0 / 8 - // The default RTT used before an RTT sample is taken. - // Note: This constant is also defined in the congestion package. - defaultInitialRTT = 100 * time.Millisecond // defaultRTOTimeout is the RTO time on new connections defaultRTOTimeout = 500 * time.Millisecond // Minimum time in the future a tail loss probe alarm may be set for. minTPLTimeout = 10 * time.Millisecond + // Maximum number of tail loss probes before an RTO fires. + maxTLPs = 2 // Minimum time in the future an RTO alarm may be set for. minRTOTimeout = 200 * time.Millisecond // maxRTOTimeout is the maximum RTO time @@ -59,6 +59,10 @@ type sentPacketHandler struct { // The number of times the handshake packets have been retransmitted without receiving an ack. handshakeCount uint32 + // The number of times a TLP has been sent without receiving an ack. + tlpCount uint32 + allowTLP bool + // The number of times an RTO has been sent without receiving an ack. rtoCount uint32 // The number of RTO probe packets that should be sent. @@ -71,10 +75,12 @@ type sentPacketHandler struct { alarm time.Time logger utils.Logger + + version protocol.VersionNumber } // NewSentPacketHandler creates a new sentPacketHandler -func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler { +func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler { congestion := congestion.NewCubicSender( congestion.DefaultClock{}, rttStats, @@ -89,6 +95,7 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) Se rttStats: rttStats, congestion: congestion, logger: logger, + version: version, } } @@ -100,6 +107,7 @@ func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber { } func (h *sentPacketHandler) SetHandshakeComplete() { + h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.") var queue []*Packet for _, packet := range h.retransmissionQueue { if packet.EncryptionLevel == protocol.EncryptionForwardSecure { @@ -150,7 +158,7 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt if len(packet.Frames) > 0 { if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok { - packet.largestAcked = ackFrame.LargestAcked + packet.largestAcked = ackFrame.LargestAcked() } } @@ -168,6 +176,7 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt if h.numRTOs > 0 { h.numRTOs-- } + h.allowTLP = false } h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isRetransmittable) @@ -176,7 +185,8 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt } func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error { - if ackFrame.LargestAcked > h.lastSentPacketNumber { + largestAcked := ackFrame.LargestAcked() + if largestAcked > h.lastSentPacketNumber { return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package") } @@ -186,13 +196,13 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe return nil } h.largestReceivedPacketWithAck = withPacketNumber - h.largestAcked = utils.MaxPacketNumber(h.largestAcked, ackFrame.LargestAcked) + h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked) if h.skippedPacketsAcked(ackFrame) { return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number") } - if rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime); rttUpdated { + if rttUpdated := h.maybeUpdateRTT(largestAcked, ackFrame.DelayTime, rcvTime); rttUpdated { h.congestion.MaybeExitSlowStart() } @@ -212,11 +222,11 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe if p.largestAcked != 0 { h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.largestAcked+1) } - if err := h.onPacketAcked(p); err != nil { + if err := h.onPacketAcked(p, rcvTime); err != nil { return err } if p.includedInBytesInFlight { - h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight) + h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime) } } @@ -238,27 +248,29 @@ func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNu func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*Packet, error) { var ackedPackets []*Packet ackRangeIndex := 0 + lowestAcked := ackFrame.LowestAcked() + largestAcked := ackFrame.LargestAcked() err := h.packetHistory.Iterate(func(p *Packet) (bool, error) { - // Ignore packets below the LowestAcked - if p.PacketNumber < ackFrame.LowestAcked { + // Ignore packets below the lowest acked + if p.PacketNumber < lowestAcked { return true, nil } - // Break after LargestAcked is reached - if p.PacketNumber > ackFrame.LargestAcked { + // Break after largest acked is reached + if p.PacketNumber > largestAcked { return false, nil } if ackFrame.HasMissingRanges() { ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex] - for p.PacketNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 { + for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ackFrame.AckRanges)-1 { ackRangeIndex++ ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex] } - if p.PacketNumber >= ackRange.First { // packet i contained in ACK range - if p.PacketNumber > ackRange.Last { - return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.First, ackRange.Last) + if p.PacketNumber >= ackRange.Smallest { // packet i contained in ACK range + if p.PacketNumber > ackRange.Largest { + return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.Smallest, ackRange.Largest) } ackedPackets = append(ackedPackets, p) } @@ -267,12 +279,22 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) } return true, nil }) + if h.logger.Debug() && len(ackedPackets) > 0 { + pns := make([]protocol.PacketNumber, len(ackedPackets)) + for i, p := range ackedPackets { + pns[i] = p.PacketNumber + } + h.logger.Debugf("\tnewly acked packets (%d): %#x", len(pns), pns) + } return ackedPackets, err } func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool { if p := h.packetHistory.GetPacket(largestAcked); p != nil { h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime) + if h.logger.Debug() { + h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) + } return true } return false @@ -280,20 +302,25 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a func (h *sentPacketHandler) updateLossDetectionAlarm() { // Cancel the alarm if no packets are outstanding - if h.packetHistory.Len() == 0 { + if !h.packetHistory.HasOutstandingPackets() { h.alarm = time.Time{} return } - // TODO(#497): TLP - if !h.handshakeComplete { + if h.packetHistory.HasOutstandingHandshakePackets() { h.alarm = h.lastSentHandshakePacketTime.Add(h.computeHandshakeTimeout()) } else if !h.lossTime.IsZero() { // Early retransmit timer or time loss detection. h.alarm = h.lossTime } else { - // RTO - h.alarm = h.lastSentRetransmittablePacketTime.Add(h.computeRTOTimeout()) + // RTO or TLP alarm + alarmDuration := h.computeRTOTimeout() + if h.tlpCount < maxTLPs { + tlpAlarm := h.computeTLPTimeout() + // if the RTO duration is shorter than the TLP duration, use the RTO duration + alarmDuration = utils.MinDuration(alarmDuration, tlpAlarm) + } + h.alarm = h.lastSentRetransmittablePacketTime.Add(alarmDuration) } } @@ -313,11 +340,21 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight proto if timeSinceSent > delayUntilLost { lostPackets = append(lostPackets, packet) } else if h.lossTime.IsZero() { + if h.logger.Debug() { + h.logger.Debugf("\tsetting loss timer for packet %#x to %s (in %s)", packet.PacketNumber, delayUntilLost, delayUntilLost-timeSinceSent) + } // Note: This conditional is only entered once per call h.lossTime = now.Add(delayUntilLost - timeSinceSent) } return true, nil }) + if h.logger.Debug() && len(lostPackets) > 0 { + pns := make([]protocol.PacketNumber, len(lostPackets)) + for i, p := range lostPackets { + pns[i] = p.PacketNumber + } + h.logger.Debugf("\tlost packets (%d): %#x", len(pns), pns) + } for _, p := range lostPackets { // the bytes in flight need to be reduced no matter if this packet will be retransmitted @@ -327,7 +364,6 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight proto } if p.canBeRetransmitted { // queue the packet for retransmission, and report the loss to the congestion controller - h.logger.Debugf("\tQueueing packet %#x because it was detected lost", p.PacketNumber) if err := h.queuePacketForRetransmission(p); err != nil { return err } @@ -338,34 +374,57 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight proto } func (h *sentPacketHandler) OnAlarm() error { - now := time.Now() - - // TODO(#497): TLP - var err error - if !h.handshakeComplete { - h.handshakeCount++ - err = h.queueHandshakePacketsForRetransmission() - } else if !h.lossTime.IsZero() { - // Early retransmit or time loss detection - err = h.detectLostPackets(now, h.bytesInFlight) - } else { - // RTO - h.rtoCount++ - h.numRTOs += 2 - err = h.queueRTOs() - } - if err != nil { - return err + // When all outstanding are acknowledged, the alarm is canceled in + // updateLossDetectionAlarm. This doesn't reset the timer in the session though. + // When OnAlarm is called, we therefore need to make sure that there are + // actually packets outstanding. + if h.packetHistory.HasOutstandingPackets() { + if err := h.onVerifiedAlarm(); err != nil { + return err + } } h.updateLossDetectionAlarm() return nil } +func (h *sentPacketHandler) onVerifiedAlarm() error { + var err error + if h.packetHistory.HasOutstandingHandshakePackets() { + if h.logger.Debug() { + h.logger.Debugf("Loss detection alarm fired in handshake mode. Handshake count: %d", h.handshakeCount) + } + h.handshakeCount++ + err = h.queueHandshakePacketsForRetransmission() + } else if !h.lossTime.IsZero() { + if h.logger.Debug() { + h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", h.lossTime) + } + // Early retransmit or time loss detection + err = h.detectLostPackets(time.Now(), h.bytesInFlight) + } else if h.tlpCount < maxTLPs { // TLP + if h.logger.Debug() { + h.logger.Debugf("Loss detection alarm fired in TLP mode. TLP count: %d", h.tlpCount) + } + h.allowTLP = true + h.tlpCount++ + } else { // RTO + if h.logger.Debug() { + h.logger.Debugf("Loss detection alarm fired in RTO mode. RTO count: %d", h.rtoCount) + } + if h.rtoCount == 0 { + h.largestSentBeforeRTO = h.lastSentPacketNumber + } + h.rtoCount++ + h.numRTOs += 2 + } + return err +} + func (h *sentPacketHandler) GetAlarmTimeout() time.Time { return h.alarm } -func (h *sentPacketHandler) onPacketAcked(p *Packet) error { +func (h *sentPacketHandler) onPacketAcked(p *Packet, rcvTime time.Time) error { // This happens if a packet and its retransmissions is acked in the same ACK. // As soon as we process the first one, this will remove all the retransmissions, // so we won't find the retransmitted packet number later. @@ -404,8 +463,8 @@ func (h *sentPacketHandler) onPacketAcked(p *Packet) error { return err } h.rtoCount = 0 + h.tlpCount = 0 h.handshakeCount = 0 - // TODO(#497): h.tlpCount = 0 return h.packetHistory.Remove(p.PacketNumber) } @@ -447,8 +506,21 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet { return packet } +func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) { + if len(h.retransmissionQueue) == 0 { + p := h.packetHistory.FirstOutstanding() + if p == nil { + return nil, errors.New("cannot dequeue a probe packet. No outstanding packets") + } + if err := h.queuePacketForRetransmission(p); err != nil { + return nil, err + } + } + return h.DequeuePacketForRetransmission(), nil +} + func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen { - return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked()) + return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked(), h.version) } func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame { @@ -463,15 +535,22 @@ func (h *sentPacketHandler) SendMode() SendMode { // we will stop sending out new data when reaching MaxOutstandingSentPackets, // but still allow sending of retransmissions and ACKs. if numTrackedPackets >= protocol.MaxTrackedSentPackets { - h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets) + if h.logger.Debug() { + h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets) + } return SendNone } + if h.allowTLP { + return SendTLP + } if h.numRTOs > 0 { return SendRTO } // Only send ACKs if we're congestion limited. if cwnd := h.congestion.GetCongestionWindow(); h.bytesInFlight > cwnd { - h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd) + if h.logger.Debug() { + h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd) + } return SendAck } // Send retransmissions first, if there are any. @@ -479,7 +558,9 @@ func (h *sentPacketHandler) SendMode() SendMode { return SendRetransmission } if numTrackedPackets >= protocol.MaxOutstandingSentPackets { - h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets) + if h.logger.Debug() { + h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets) + } return SendAck } return SendAny @@ -501,23 +582,6 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int { return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay))) } -// retransmit the oldest two packets -func (h *sentPacketHandler) queueRTOs() error { - h.largestSentBeforeRTO = h.lastSentPacketNumber - // Queue the first two outstanding packets for retransmission. - // This does NOT declare this packets as lost: - // They are still tracked in the packet history and count towards the bytes in flight. - for i := 0; i < 2; i++ { - if p := h.packetHistory.FirstOutstanding(); p != nil { - h.logger.Debugf("\tQueueing packet %#x for retransmission (RTO)", p.PacketNumber) - if err := h.queuePacketForRetransmission(p); err != nil { - return err - } - } - } - return nil -} - func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error { var handshakePackets []*Packet h.packetHistory.Iterate(func(p *Packet) (bool, error) { @@ -527,7 +591,7 @@ func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error { return true, nil }) for _, p := range handshakePackets { - h.logger.Debugf("\tQueueing packet %#x as a handshake retransmission", p.PacketNumber) + h.logger.Debugf("Queueing packet %#x as a handshake retransmission", p.PacketNumber) if err := h.queuePacketForRetransmission(p); err != nil { return err } @@ -548,16 +612,17 @@ func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error { } func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration { - duration := 2 * h.rttStats.SmoothedRTT() - if duration == 0 { - duration = 2 * defaultInitialRTT - } - duration = utils.MaxDuration(duration, minTPLTimeout) + duration := utils.MaxDuration(2*h.rttStats.SmoothedOrInitialRTT(), minTPLTimeout) // exponential backoff // There's an implicit limit to this set by the handshake timeout. return duration << h.handshakeCount } +func (h *sentPacketHandler) computeTLPTimeout() time.Duration { + // TODO(#1236): include the max_ack_delay + return utils.MaxDuration(h.rttStats.SmoothedOrInitialRTT()*3/2, minTPLTimeout) +} + func (h *sentPacketHandler) computeRTOTimeout() time.Duration { var rto time.Duration rtt := h.rttStats.SmoothedRTT() diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go index 38a2a0e5..91aa2697 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go @@ -10,6 +10,9 @@ type sentPacketHistory struct { packetList *PacketList packetMap map[protocol.PacketNumber]*PacketElement + numOutstandingPackets int + numOutstandingHandshakePackets int + firstOutstanding *PacketElement } @@ -30,6 +33,12 @@ func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement { if h.firstOutstanding == nil { h.firstOutstanding = el } + if p.canBeRetransmitted { + h.numOutstandingPackets++ + if p.EncryptionLevel < protocol.EncryptionForwardSecure { + h.numOutstandingHandshakePackets++ + } + } return el } @@ -92,6 +101,18 @@ func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber) if !ok { return fmt.Errorf("sent packet history: packet %d not found", pn) } + if el.Value.canBeRetransmitted { + h.numOutstandingPackets-- + if h.numOutstandingPackets < 0 { + panic("numOutstandingHandshakePackets negative") + } + if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure { + h.numOutstandingHandshakePackets-- + if h.numOutstandingHandshakePackets < 0 { + panic("numOutstandingHandshakePackets negative") + } + } + } el.Value.canBeRetransmitted = false if el == h.firstOutstanding { h.readjustFirstOutstanding() @@ -121,7 +142,27 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error { if el == h.firstOutstanding { h.readjustFirstOutstanding() } + if el.Value.canBeRetransmitted { + h.numOutstandingPackets-- + if h.numOutstandingPackets < 0 { + panic("numOutstandingHandshakePackets negative") + } + if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure { + h.numOutstandingHandshakePackets-- + if h.numOutstandingHandshakePackets < 0 { + panic("numOutstandingHandshakePackets negative") + } + } + } h.packetList.Remove(el) delete(h.packetMap, p) return nil } + +func (h *sentPacketHistory) HasOutstandingPackets() bool { + return h.numOutstandingPackets > 0 +} + +func (h *sentPacketHistory) HasOutstandingHandshakePackets() bool { + return h.numOutstandingHandshakePackets > 0 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager.go index 04cb61f9..40ad88cd 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager.go @@ -30,8 +30,9 @@ func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFr } func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) { - if ack.LargestAcked >= s.nextLeastUnacked { - s.nextLeastUnacked = ack.LargestAcked + 1 + largestAcked := ack.LargestAcked() + if largestAcked >= s.nextLeastUnacked { + s.nextLeastUnacked = largestAcked + 1 } } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go index 3922f476..dcf91fc6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go @@ -16,11 +16,10 @@ import ( // allow a 10 shift right to divide. // 1024*1024^3 (first 1024 is from 0.100^3) -// where 0.100 is 100 ms which is the scaling -// round trip time. +// where 0.100 is 100 ms which is the scaling round trip time. const cubeScale = 40 const cubeCongestionWindowScale = 410 -const cubeFactor protocol.PacketNumber = 1 << cubeScale / cubeCongestionWindowScale +const cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / protocol.DefaultTCPMSS const defaultNumConnections = 2 @@ -32,39 +31,35 @@ const beta float32 = 0.7 // new concurrent flows and speed up convergence. const betaLastMax float32 = 0.85 -// If true, Cubic's epoch is shifted when the sender is application-limited. -const shiftQuicCubicEpochWhenAppLimited = true - -const maxCubicTimeInterval = 30 * time.Millisecond - // Cubic implements the cubic algorithm from TCP type Cubic struct { clock Clock + // Number of connections to simulate. numConnections int + // Time when this cycle started, after last loss event. epoch time.Time - // Time when sender went into application-limited period. Zero if not in - // application-limited period. - appLimitedStartTime time.Time - // Time when we updated last_congestion_window. - lastUpdateTime time.Time - // Last congestion window (in packets) used. - lastCongestionWindow protocol.PacketNumber - // Max congestion window (in packets) used just before last loss event. + + // Max congestion window used just before last loss event. // Note: to improve fairness to other streams an additional back off is // applied to this value if the new value is below our latest value. - lastMaxCongestionWindow protocol.PacketNumber - // Number of acked packets since the cycle started (epoch). - ackedPacketsCount protocol.PacketNumber + lastMaxCongestionWindow protocol.ByteCount + + // Number of acked bytes since the cycle started (epoch). + ackedBytesCount protocol.ByteCount + // TCP Reno equivalent congestion window in packets. - estimatedTCPcongestionWindow protocol.PacketNumber + estimatedTCPcongestionWindow protocol.ByteCount + // Origin point of cubic function. - originPointCongestionWindow protocol.PacketNumber + originPointCongestionWindow protocol.ByteCount + // Time to origin point of cubic function in 2^10 fractions of a second. timeToOriginPoint uint32 + // Last congestion window in packets computed by cubic function. - lastTargetCongestionWindow protocol.PacketNumber + lastTargetCongestionWindow protocol.ByteCount } // NewCubic returns a new Cubic instance @@ -80,11 +75,8 @@ func NewCubic(clock Clock) *Cubic { // Reset is called after a timeout to reset the cubic state func (c *Cubic) Reset() { c.epoch = time.Time{} - c.appLimitedStartTime = time.Time{} - c.lastUpdateTime = time.Time{} - c.lastCongestionWindow = 0 c.lastMaxCongestionWindow = 0 - c.ackedPacketsCount = 0 + c.ackedBytesCount = 0 c.estimatedTCPcongestionWindow = 0 c.originPointCongestionWindow = 0 c.timeToOriginPoint = 0 @@ -107,57 +99,59 @@ func (c *Cubic) beta() float32 { return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections) } +func (c *Cubic) betaLastMax() float32 { + // betaLastMax is the additional backoff factor after loss for our + // N-connection emulation, which emulates the additional backoff of + // an ensemble of N TCP-Reno connections on a single loss event. The + // effective multiplier is computed as: + return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections) +} + // OnApplicationLimited is called on ack arrival when sender is unable to use // the available congestion window. Resets Cubic state during quiescence. func (c *Cubic) OnApplicationLimited() { - if shiftQuicCubicEpochWhenAppLimited { - // When sender is not using the available congestion window, Cubic's epoch - // should not continue growing. Record the time when sender goes into an - // app-limited period here, to compensate later when cwnd growth happens. - if c.appLimitedStartTime.IsZero() { - c.appLimitedStartTime = c.clock.Now() - } - } else { - // When sender is not using the available congestion window, Cubic's epoch - // should not continue growing. Reset the epoch when in such a period. - c.epoch = time.Time{} - } + // When sender is not using the available congestion window, the window does + // not grow. But to be RTT-independent, Cubic assumes that the sender has been + // using the entire window during the time since the beginning of the current + // "epoch" (the end of the last loss recovery period). Since + // application-limited periods break this assumption, we reset the epoch when + // in such a period. This reset effectively freezes congestion window growth + // through application-limited periods and allows Cubic growth to continue + // when the entire window is being used. + c.epoch = time.Time{} } // CongestionWindowAfterPacketLoss computes a new congestion window to use after // a loss event. Returns the new congestion window in packets. The new // congestion window is a multiplicative decrease of our current window. -func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.PacketNumber) protocol.PacketNumber { - if currentCongestionWindow < c.lastMaxCongestionWindow { +func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount { + if currentCongestionWindow+protocol.DefaultTCPMSS < c.lastMaxCongestionWindow { // We never reached the old max, so assume we are competing with another // flow. Use our extra back off factor to allow the other flow to go up. - c.lastMaxCongestionWindow = protocol.PacketNumber(betaLastMax * float32(currentCongestionWindow)) + c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow)) } else { c.lastMaxCongestionWindow = currentCongestionWindow } c.epoch = time.Time{} // Reset time. - return protocol.PacketNumber(float32(currentCongestionWindow) * c.beta()) + return protocol.ByteCount(float32(currentCongestionWindow) * c.beta()) } // CongestionWindowAfterAck computes a new congestion window to use after a received ACK. // Returns the new congestion window in packets. The new congestion window // follows a cubic function that depends on the time passed since last // packet loss. -func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.PacketNumber, delayMin time.Duration) protocol.PacketNumber { - c.ackedPacketsCount++ // Packets acked. - currentTime := c.clock.Now() - - // Cubic is "independent" of RTT, the update is limited by the time elapsed. - if c.lastCongestionWindow == currentCongestionWindow && (currentTime.Sub(c.lastUpdateTime) <= maxCubicTimeInterval) { - return utils.MaxPacketNumber(c.lastTargetCongestionWindow, c.estimatedTCPcongestionWindow) - } - c.lastCongestionWindow = currentCongestionWindow - c.lastUpdateTime = currentTime +func (c *Cubic) CongestionWindowAfterAck( + ackedBytes protocol.ByteCount, + currentCongestionWindow protocol.ByteCount, + delayMin time.Duration, + eventTime time.Time, +) protocol.ByteCount { + c.ackedBytesCount += ackedBytes if c.epoch.IsZero() { // First ACK after a loss event. - c.epoch = currentTime // Start of epoch. - c.ackedPacketsCount = 1 // Reset count. + c.epoch = eventTime // Start of epoch. + c.ackedBytesCount = ackedBytes // Reset count. // Reset estimated_tcp_congestion_window_ to be in sync with cubic. c.estimatedTCPcongestionWindow = currentCongestionWindow if c.lastMaxCongestionWindow <= currentCongestionWindow { @@ -167,48 +161,37 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow)))) c.originPointCongestionWindow = c.lastMaxCongestionWindow } - } else { - // If sender was app-limited, then freeze congestion window growth during - // app-limited period. Continue growth now by shifting the epoch-start - // through the app-limited period. - if shiftQuicCubicEpochWhenAppLimited && !c.appLimitedStartTime.IsZero() { - shift := currentTime.Sub(c.appLimitedStartTime) - c.epoch = c.epoch.Add(shift) - c.appLimitedStartTime = time.Time{} - } } // Change the time unit from microseconds to 2^10 fractions per second. Take // the round trip time in account. This is done to allow us to use shift as a // divide operator. - elapsedTime := int64((currentTime.Add(delayMin).Sub(c.epoch)/time.Microsecond)<<10) / 1000000 + elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000) + // Right-shifts of negative, signed numbers have implementation-dependent + // behavior, so force the offset to be positive, as is done in the kernel. offset := int64(c.timeToOriginPoint) - elapsedTime - // Right-shifts of negative, signed numbers have - // implementation-dependent behavior. Force the offset to be - // positive, similar to the kernel implementation. if offset < 0 { offset = -offset } - deltaCongestionWindow := protocol.PacketNumber((cubeCongestionWindowScale * offset * offset * offset) >> cubeScale) - var targetCongestionWindow protocol.PacketNumber + + deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * protocol.DefaultTCPMSS >> cubeScale + var targetCongestionWindow protocol.ByteCount if elapsedTime > int64(c.timeToOriginPoint) { targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow } else { targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow } - // With dynamic beta/alpha based on number of active streams, it is possible - // for the required_ack_count to become much lower than acked_packets_count_ - // suddenly, leading to more than one iteration through the following loop. - for { - // Update estimated TCP congestion_window. - requiredAckCount := protocol.PacketNumber(float32(c.estimatedTCPcongestionWindow) / c.alpha()) - if c.ackedPacketsCount < requiredAckCount { - break - } - c.ackedPacketsCount -= requiredAckCount - c.estimatedTCPcongestionWindow++ - } + // Limit the CWND increase to half the acked bytes. + targetCongestionWindow = utils.MinByteCount(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2) + + // Increase the window by approximately Alpha * 1 MSS of bytes every + // time we ack an estimated tcp window of bytes. For small + // congestion windows (less than 25), the formula below will + // increase slightly slower than linearly per estimated tcp window + // of bytes. + c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(protocol.DefaultTCPMSS) / float32(c.estimatedTCPcongestionWindow)) + c.ackedBytesCount = 0 // We have a new cubic congestion window. c.lastTargetCongestionWindow = targetCongestionWindow @@ -218,7 +201,6 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet if targetCongestionWindow < c.estimatedTCPcongestionWindow { targetCongestionWindow = c.estimatedTCPcongestionWindow } - return targetCongestionWindow } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go index 21f01942..b9f67e6c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go @@ -8,9 +8,9 @@ import ( ) const ( - maxBurstBytes = 3 * protocol.DefaultTCPMSS - defaultMinimumCongestionWindow protocol.PacketNumber = 2 - renoBeta float32 = 0.7 // Reno backoff factor. + maxBurstBytes = 3 * protocol.DefaultTCPMSS + renoBeta float32 = 0.7 // Reno backoff factor. + defaultMinimumCongestionWindow protocol.ByteCount = 2 * protocol.DefaultTCPMSS ) type cubicSender struct { @@ -31,12 +31,6 @@ type cubicSender struct { // Track the largest packet number outstanding when a CWND cutback occurs. largestSentAtLastCutback protocol.PacketNumber - // Congestion window in packets. - congestionWindow protocol.PacketNumber - - // Slow start congestion window in packets, aka ssthresh. - slowstartThreshold protocol.PacketNumber - // Whether the last loss event caused us to exit slowstart. // Used for stats collection of slowstartPacketsLost lastCutbackExitedSlowstart bool @@ -44,24 +38,35 @@ type cubicSender struct { // When true, exit slow start with large cutback of congestion window. slowStartLargeReduction bool - // Minimum congestion window in packets. - minCongestionWindow protocol.PacketNumber + // Congestion window in packets. + congestionWindow protocol.ByteCount - // Maximum number of outstanding packets for tcp. - maxTCPCongestionWindow protocol.PacketNumber + // Minimum congestion window in packets. + minCongestionWindow protocol.ByteCount + + // Maximum congestion window. + maxCongestionWindow protocol.ByteCount + + // Slow start congestion window in bytes, aka ssthresh. + slowstartThreshold protocol.ByteCount // Number of connections to simulate. numConnections int // ACK counter for the Reno implementation. - congestionWindowCount protocol.ByteCount + numAckedPackets uint64 - initialCongestionWindow protocol.PacketNumber - initialMaxCongestionWindow protocol.PacketNumber + initialCongestionWindow protocol.ByteCount + initialMaxCongestionWindow protocol.ByteCount + + minSlowStartExitWindow protocol.ByteCount } +var _ SendAlgorithm = &cubicSender{} +var _ SendAlgorithmWithDebugInfo = &cubicSender{} + // NewCubicSender makes a new cubic sender -func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.PacketNumber) SendAlgorithmWithDebugInfo { +func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) SendAlgorithmWithDebugInfo { return &cubicSender{ rttStats: rttStats, initialCongestionWindow: initialCongestionWindow, @@ -69,7 +74,7 @@ func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestio congestionWindow: initialCongestionWindow, minCongestionWindow: defaultMinimumCongestionWindow, slowstartThreshold: initialMaxCongestionWindow, - maxTCPCongestionWindow: initialMaxCongestionWindow, + maxCongestionWindow: initialMaxCongestionWindow, numConnections: defaultNumConnections, cubic: NewCubic(clock), reno: reno, @@ -80,21 +85,26 @@ func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestio func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration { if c.InRecovery() { // PRR is used when in recovery. - if c.prr.TimeUntilSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) == 0 { + if c.prr.CanSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) { return 0 } } - delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow()/protocol.DefaultTCPMSS) + delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow()) if !c.InSlowStart() { // adjust delay, such that it's 1.25*cwd/rtt delay = delay * 8 / 5 } return delay } -func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool { - // Only update bytesInFlight for data packets. +func (c *cubicSender) OnPacketSent( + sentTime time.Time, + bytesInFlight protocol.ByteCount, + packetNumber protocol.PacketNumber, + bytes protocol.ByteCount, + isRetransmittable bool, +) { if !isRetransmittable { - return false + return } if c.InRecovery() { // PRR is used when in recovery. @@ -102,7 +112,6 @@ func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.By } c.largestSentPacketNumber = packetNumber c.hybridSlowStart.OnPacketSent(packetNumber) - return true } func (c *cubicSender) InRecovery() bool { @@ -114,18 +123,18 @@ func (c *cubicSender) InSlowStart() bool { } func (c *cubicSender) GetCongestionWindow() protocol.ByteCount { - return protocol.ByteCount(c.congestionWindow) * protocol.DefaultTCPMSS + return c.congestionWindow } func (c *cubicSender) GetSlowStartThreshold() protocol.ByteCount { - return protocol.ByteCount(c.slowstartThreshold) * protocol.DefaultTCPMSS + return c.slowstartThreshold } func (c *cubicSender) ExitSlowstart() { c.slowstartThreshold = c.congestionWindow } -func (c *cubicSender) SlowstartThreshold() protocol.PacketNumber { +func (c *cubicSender) SlowstartThreshold() protocol.ByteCount { return c.slowstartThreshold } @@ -135,20 +144,29 @@ func (c *cubicSender) MaybeExitSlowStart() { } } -func (c *cubicSender) OnPacketAcked(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) { +func (c *cubicSender) OnPacketAcked( + ackedPacketNumber protocol.PacketNumber, + ackedBytes protocol.ByteCount, + priorInFlight protocol.ByteCount, + eventTime time.Time, +) { c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber) if c.InRecovery() { // PRR is used when in recovery. c.prr.OnPacketAcked(ackedBytes) return } - c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, bytesInFlight) + c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime) if c.InSlowStart() { c.hybridSlowStart.OnPacketAcked(ackedPacketNumber) } } -func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) { +func (c *cubicSender) OnPacketLost( + packetNumber protocol.PacketNumber, + lostBytes protocol.ByteCount, + priorInFlight protocol.ByteCount, +) { // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets // already sent should be treated as a single loss event, since it's expected. if packetNumber <= c.largestSentAtLastCutback { @@ -156,10 +174,8 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes c.stats.slowstartPacketsLost++ c.stats.slowstartBytesLost += lostBytes if c.slowStartLargeReduction { - if c.stats.slowstartPacketsLost == 1 || (c.stats.slowstartBytesLost/protocol.DefaultTCPMSS) > (c.stats.slowstartBytesLost-lostBytes)/protocol.DefaultTCPMSS { - // Reduce congestion window by 1 for every mss of bytes lost. - c.congestionWindow = utils.MaxPacketNumber(c.congestionWindow-1, c.minCongestionWindow) - } + // Reduce congestion window by lost_bytes for every loss. + c.congestionWindow = utils.MaxByteCount(c.congestionWindow-lostBytes, c.minSlowStartExitWindow) c.slowstartThreshold = c.congestionWindow } } @@ -170,17 +186,19 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes c.stats.slowstartPacketsLost++ } - c.prr.OnPacketLost(bytesInFlight) + c.prr.OnPacketLost(priorInFlight) // TODO(chromium): Separate out all of slow start into a separate class. if c.slowStartLargeReduction && c.InSlowStart() { - c.congestionWindow = c.congestionWindow - 1 + if c.congestionWindow >= 2*c.initialCongestionWindow { + c.minSlowStartExitWindow = c.congestionWindow / 2 + } + c.congestionWindow = c.congestionWindow - protocol.DefaultTCPMSS } else if c.reno { - c.congestionWindow = protocol.PacketNumber(float32(c.congestionWindow) * c.RenoBeta()) + c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta()) } else { c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow) } - // Enforce a minimum congestion window. if c.congestionWindow < c.minCongestionWindow { c.congestionWindow = c.minCongestionWindow } @@ -188,7 +206,7 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes c.largestSentAtLastCutback = c.largestSentPacketNumber // reset packet count from congestion avoidance mode. We start // counting again when we're out of recovery. - c.congestionWindowCount = 0 + c.numAckedPackets = 0 } func (c *cubicSender) RenoBeta() float32 { @@ -201,32 +219,38 @@ func (c *cubicSender) RenoBeta() float32 { // Called when we receive an ack. Normal TCP tracks how many packets one ack // represents, but quic has a separate ack for each packet. -func (c *cubicSender) maybeIncreaseCwnd(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) { +func (c *cubicSender) maybeIncreaseCwnd( + ackedPacketNumber protocol.PacketNumber, + ackedBytes protocol.ByteCount, + priorInFlight protocol.ByteCount, + eventTime time.Time, +) { // Do not increase the congestion window unless the sender is close to using // the current window. - if !c.isCwndLimited(bytesInFlight) { + if !c.isCwndLimited(priorInFlight) { c.cubic.OnApplicationLimited() return } - if c.congestionWindow >= c.maxTCPCongestionWindow { + if c.congestionWindow >= c.maxCongestionWindow { return } if c.InSlowStart() { // TCP slow start, exponential growth, increase by one for each ACK. - c.congestionWindow++ + c.congestionWindow += protocol.DefaultTCPMSS return } + // Congestion avoidance if c.reno { // Classic Reno congestion avoidance. - c.congestionWindowCount++ + c.numAckedPackets++ // Divide by num_connections to smoothly increase the CWND at a faster // rate than conventional Reno. - if protocol.PacketNumber(c.congestionWindowCount*protocol.ByteCount(c.numConnections)) >= c.congestionWindow { - c.congestionWindow++ - c.congestionWindowCount = 0 + if c.numAckedPackets*uint64(c.numConnections) >= uint64(c.congestionWindow)/uint64(protocol.DefaultTCPMSS) { + c.congestionWindow += protocol.DefaultTCPMSS + c.numAckedPackets = 0 } } else { - c.congestionWindow = utils.MinPacketNumber(c.maxTCPCongestionWindow, c.cubic.CongestionWindowAfterAck(c.congestionWindow, c.rttStats.MinRTT())) + c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow, c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) } } @@ -282,10 +306,10 @@ func (c *cubicSender) OnConnectionMigration() { c.largestSentAtLastCutback = 0 c.lastCutbackExitedSlowstart = false c.cubic.Reset() - c.congestionWindowCount = 0 + c.numAckedPackets = 0 c.congestionWindow = c.initialCongestionWindow c.slowstartThreshold = c.initialMaxCongestionWindow - c.maxTCPCongestionWindow = c.initialMaxCongestionWindow + c.maxCongestionWindow = c.initialMaxCongestionWindow } // SetSlowStartLargeReduction allows enabling the SSLR experiment diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go index 28950dd0..7c27da64 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go @@ -9,11 +9,11 @@ import ( // A SendAlgorithm performs congestion control and calculates the congestion window type SendAlgorithm interface { TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration - OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool + OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) GetCongestionWindow() protocol.ByteCount MaybeExitSlowStart() - OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) - OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) + OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time) + OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount) SetNumEmulatedConnections(n int) OnRetransmissionTimeout(packetsRetransmitted bool) OnConnectionMigration() @@ -30,7 +30,7 @@ type SendAlgorithmWithDebugInfo interface { // Stuff only used in testing HybridSlowStart() *HybridSlowStart - SlowstartThreshold() protocol.PacketNumber + SlowstartThreshold() protocol.ByteCount RenoBeta() float32 InRecovery() bool } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/prr_sender.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/prr_sender.go index 18a3736a..5c807d19 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/prr_sender.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/prr_sender.go @@ -1,10 +1,7 @@ package congestion import ( - "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" ) // PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937 @@ -23,9 +20,9 @@ func (p *PrrSender) OnPacketSent(sentBytes protocol.ByteCount) { // OnPacketLost should be called on the first loss that triggers a recovery // period and all other methods in this class should only be called when in // recovery. -func (p *PrrSender) OnPacketLost(bytesInFlight protocol.ByteCount) { +func (p *PrrSender) OnPacketLost(priorInFlight protocol.ByteCount) { p.bytesSentSinceLoss = 0 - p.bytesInFlightBeforeLoss = bytesInFlight + p.bytesInFlightBeforeLoss = priorInFlight p.bytesDeliveredSinceLoss = 0 p.ackCountSinceLoss = 0 } @@ -36,28 +33,22 @@ func (p *PrrSender) OnPacketAcked(ackedBytes protocol.ByteCount) { p.ackCountSinceLoss++ } -// TimeUntilSend calculates the time until a packet can be sent -func (p *PrrSender) TimeUntilSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) time.Duration { +// CanSend returns if packets can be sent +func (p *PrrSender) CanSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) bool { // Return QuicTime::Zero In order to ensure limited transmit always works. if p.bytesSentSinceLoss == 0 || bytesInFlight < protocol.DefaultTCPMSS { - return 0 + return true } if congestionWindow > bytesInFlight { // During PRR-SSRB, limit outgoing packets to 1 extra MSS per ack, instead // of sending the entire available window. This prevents burst retransmits // when more packets are lost than the CWND reduction. // limit = MAX(prr_delivered - prr_out, DeliveredData) + MSS - if p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS <= p.bytesSentSinceLoss { - return utils.InfDuration - } - return 0 + return p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS > p.bytesSentSinceLoss } // Implement Proportional Rate Reduction (RFC6937). // Checks a simplified version of the PRR formula that doesn't use division: // AvailableSendWindow = // CEIL(prr_delivered * ssthresh / BytesInFlightAtLoss) - prr_sent - if p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss { - return 0 - } - return utils.InfDuration + return p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go index 599e350f..f0ebbb23 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go @@ -7,50 +7,27 @@ import ( ) const ( - // Note: This constant is also defined in the ackhandler package. - initialRTTus = 100 * 1000 rttAlpha float32 = 0.125 oneMinusAlpha float32 = (1 - rttAlpha) rttBeta float32 = 0.25 oneMinusBeta float32 = (1 - rttBeta) - halfWindow float32 = 0.5 - quarterWindow float32 = 0.25 + // The default RTT used before an RTT sample is taken. + defaultInitialRTT = 100 * time.Millisecond ) -type rttSample struct { - rtt time.Duration - time time.Time -} - // RTTStats provides round-trip statistics type RTTStats struct { - initialRTTus int64 - - recentMinRTTwindow time.Duration - minRTT time.Duration - latestRTT time.Duration - smoothedRTT time.Duration - meanDeviation time.Duration - - numMinRTTsamplesRemaining uint32 - - newMinRTT rttSample - recentMinRTT rttSample - halfWindowRTT rttSample - quarterWindowRTT rttSample + minRTT time.Duration + latestRTT time.Duration + smoothedRTT time.Duration + meanDeviation time.Duration } // NewRTTStats makes a properly initialized RTTStats object func NewRTTStats() *RTTStats { - return &RTTStats{ - initialRTTus: initialRTTus, - recentMinRTTwindow: utils.InfDuration, - } + return &RTTStats{} } -// InitialRTTus is the initial RTT in us -func (r *RTTStats) InitialRTTus() int64 { return r.initialRTTus } - // MinRTT Returns the minRTT for the entire connection. // May return Zero if no valid updates have occurred. func (r *RTTStats) MinRTT() time.Duration { return r.minRTT } @@ -59,28 +36,22 @@ func (r *RTTStats) MinRTT() time.Duration { return r.minRTT } // May return Zero if no valid updates have occurred. func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT } -// RecentMinRTT the minRTT since SampleNewRecentMinRtt has been called, or the -// minRTT for the entire connection if SampleNewMinRtt was never called. -func (r *RTTStats) RecentMinRTT() time.Duration { return r.recentMinRTT.rtt } - // SmoothedRTT returns the EWMA smoothed RTT for the connection. // May return Zero if no valid updates have occurred. func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT } -// GetQuarterWindowRTT gets the quarter window RTT -func (r *RTTStats) GetQuarterWindowRTT() time.Duration { return r.quarterWindowRTT.rtt } - -// GetHalfWindowRTT gets the half window RTT -func (r *RTTStats) GetHalfWindowRTT() time.Duration { return r.halfWindowRTT.rtt } +// SmoothedOrInitialRTT returns the EWMA smoothed RTT for the connection. +// If no valid updates have occurred, it returns the initial RTT. +func (r *RTTStats) SmoothedOrInitialRTT() time.Duration { + if r.smoothedRTT != 0 { + return r.smoothedRTT + } + return defaultInitialRTT +} // MeanDeviation gets the mean deviation func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation } -// SetRecentMinRTTwindow sets how old a recent min rtt sample can be. -func (r *RTTStats) SetRecentMinRTTwindow(recentMinRTTwindow time.Duration) { - r.recentMinRTTwindow = recentMinRTTwindow -} - // UpdateRTT updates the RTT based on a new sample. func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { if sendDelta == utils.InfDuration || sendDelta <= 0 { @@ -94,7 +65,6 @@ func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { if r.minRTT == 0 || r.minRTT > sendDelta { r.minRTT = sendDelta } - r.updateRecentMinRTT(sendDelta, now) // Correct for ackDelay if information received from the peer results in a // an RTT sample at least as large as minRTT. Otherwise, only use the @@ -114,63 +84,12 @@ func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { } } -func (r *RTTStats) updateRecentMinRTT(sample time.Duration, now time.Time) { // Recent minRTT update. - if r.numMinRTTsamplesRemaining > 0 { - r.numMinRTTsamplesRemaining-- - if r.newMinRTT.rtt == 0 || sample <= r.newMinRTT.rtt { - r.newMinRTT = rttSample{rtt: sample, time: now} - } - if r.numMinRTTsamplesRemaining == 0 { - r.recentMinRTT = r.newMinRTT - r.halfWindowRTT = r.newMinRTT - r.quarterWindowRTT = r.newMinRTT - } - } - - // Update the three recent rtt samples. - if r.recentMinRTT.rtt == 0 || sample <= r.recentMinRTT.rtt { - r.recentMinRTT = rttSample{rtt: sample, time: now} - r.halfWindowRTT = r.recentMinRTT - r.quarterWindowRTT = r.recentMinRTT - } else if sample <= r.halfWindowRTT.rtt { - r.halfWindowRTT = rttSample{rtt: sample, time: now} - r.quarterWindowRTT = r.halfWindowRTT - } else if sample <= r.quarterWindowRTT.rtt { - r.quarterWindowRTT = rttSample{rtt: sample, time: now} - } - - // Expire old min rtt samples. - if r.recentMinRTT.time.Before(now.Add(-r.recentMinRTTwindow)) { - r.recentMinRTT = r.halfWindowRTT - r.halfWindowRTT = r.quarterWindowRTT - r.quarterWindowRTT = rttSample{rtt: sample, time: now} - } else if r.halfWindowRTT.time.Before(now.Add(-time.Duration(float32(r.recentMinRTTwindow/time.Microsecond)*halfWindow) * time.Microsecond)) { - r.halfWindowRTT = r.quarterWindowRTT - r.quarterWindowRTT = rttSample{rtt: sample, time: now} - } else if r.quarterWindowRTT.time.Before(now.Add(-time.Duration(float32(r.recentMinRTTwindow/time.Microsecond)*quarterWindow) * time.Microsecond)) { - r.quarterWindowRTT = rttSample{rtt: sample, time: now} - } -} - -// SampleNewRecentMinRTT forces RttStats to sample a new recent min rtt within the next -// |numSamples| UpdateRTT calls. -func (r *RTTStats) SampleNewRecentMinRTT(numSamples uint32) { - r.numMinRTTsamplesRemaining = numSamples - r.newMinRTT = rttSample{} -} - // OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset. func (r *RTTStats) OnConnectionMigration() { r.latestRTT = 0 r.minRTT = 0 r.smoothedRTT = 0 r.meanDeviation = 0 - r.initialRTTus = initialRTTus - r.numMinRTTsamplesRemaining = 0 - r.recentMinRTTwindow = utils.InfDuration - r.recentMinRTT = rttSample{} - r.halfWindowRTT = rttSample{} - r.quarterWindowRTT = rttSample{} } // ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go index 89b9e1f1..8aa187ff 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go @@ -15,7 +15,7 @@ const ( // A TLSExporter gets the negotiated ciphersuite and computes exporter type TLSExporter interface { - GetCipherSuite() mint.CipherSuiteParams + ConnectionState() mint.ConnectionState ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) } @@ -49,7 +49,7 @@ func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) { } func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) { - cs := tls.GetCipherSuite() + cs := tls.ConnectionState().CipherSuite secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size()) if err != nil { return nil, nil, err diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go index 28f6c2cc..6c294178 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go @@ -6,7 +6,6 @@ import ( "io" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" "golang.org/x/crypto/hkdf" ) @@ -42,7 +41,7 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol } else { info.Write([]byte("QUIC key expansion\x00")) } - utils.BigEndian.WriteUint64(&info, uint64(connID)) + info.Write(connID) info.Write(chlo) info.Write(scfg) info.Write(cert) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go index 92c6362f..4abc6229 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go @@ -2,7 +2,6 @@ package crypto import ( "crypto" - "encoding/binary" "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -28,9 +27,7 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec return NewAEADAESGCM(otherKey, myKey, otherIV, myIV) } -func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) { - connID := make([]byte, 8) - binary.BigEndian.PutUint64(connID, uint64(connectionID)) +func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) { handshakeSecret := mint.HkdfExtract(crypto.SHA256, quicVersion1Salt, connID) clientSecret = qhkdfExpand(handshakeSecret, "client hs", crypto.SHA256.Size()) serverSecret = qhkdfExpand(handshakeSecret, "server hs", crypto.SHA256.Size()) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go index fb92f084..6a0aa3c5 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go @@ -11,8 +11,9 @@ import ( type baseFlowController struct { // for sending data - bytesSent protocol.ByteCount - sendWindow protocol.ByteCount + bytesSent protocol.ByteCount + sendWindow protocol.ByteCount + lastBlockedAt protocol.ByteCount // for receiving data mutex sync.RWMutex @@ -29,6 +30,17 @@ type baseFlowController struct { logger utils.Logger } +// IsNewlyBlocked says if it is newly blocked by flow control. +// For every offset, it only returns true once. +// If it is blocked, the offset is returned. +func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { + if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt { + return false, 0 + } + c.lastBlockedAt = c.sendWindow + return true, c.sendWindow +} + func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { c.bytesSent += n } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go index c4f6e125..d18eaf55 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go @@ -10,8 +10,9 @@ import ( ) type connectionFlowController struct { - lastBlockedAt protocol.ByteCount baseFlowController + + queueWindowUpdate func() } var _ ConnectionFlowController = &connectionFlowController{} @@ -21,6 +22,7 @@ var _ ConnectionFlowController = &connectionFlowController{} func NewConnectionFlowController( receiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount, + queueWindowUpdate func(), rttStats *congestion.RTTStats, logger utils.Logger, ) ConnectionFlowController { @@ -32,6 +34,7 @@ func NewConnectionFlowController( maxReceiveWindowSize: maxReceiveWindow, logger: logger, }, + queueWindowUpdate: queueWindowUpdate, } } @@ -39,17 +42,6 @@ func (c *connectionFlowController) SendWindowSize() protocol.ByteCount { return c.baseFlowController.sendWindowSize() } -// IsNewlyBlocked says if it is newly blocked by flow control. -// For every offset, it only returns true once. -// If it is blocked, the offset is returned. -func (c *connectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { - if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt { - return false, 0 - } - c.lastBlockedAt = c.sendWindow - return true, c.sendWindow -} - // IncrementHighestReceived adds an increment to the highestReceived value func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error { c.mutex.Lock() @@ -62,6 +54,15 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B return nil } +func (c *connectionFlowController) MaybeQueueWindowUpdate() { + c.mutex.Lock() + hasWindowUpdate := c.hasWindowUpdate() + c.mutex.Unlock() + if hasWindowUpdate { + c.queueWindowUpdate() + } +} + func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { c.mutex.Lock() oldWindowSize := c.receiveWindowSize @@ -78,6 +79,7 @@ func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) { c.mutex.Lock() if inc > c.receiveWindowSize { + c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10)) c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize) c.startNewAutoTuningEpoch() } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go index 61d57e31..5ee53e52 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go @@ -10,26 +10,22 @@ type flowController interface { // for receiving AddBytesRead(protocol.ByteCount) GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary + MaybeQueueWindowUpdate() // queues a window update, if necessary + IsNewlyBlocked() (bool, protocol.ByteCount) } // A StreamFlowController is a flow controller for a QUIC stream. type StreamFlowController interface { flowController - // for sending - IsBlocked() (bool, protocol.ByteCount) // for receiving // UpdateHighestReceived should be called when a new highest offset is received // final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame UpdateHighestReceived(offset protocol.ByteCount, final bool) error - // HasWindowUpdate says if it is necessary to update the window - HasWindowUpdate() bool } // The ConnectionFlowController is the flow controller for the connection. type ConnectionFlowController interface { flowController - // for sending - IsNewlyBlocked() (bool, protocol.ByteCount) } type connectionFlowControllerI interface { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go index 16bef261..d9878606 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go @@ -14,6 +14,8 @@ type streamFlowController struct { streamID protocol.StreamID + queueWindowUpdate func() + connection connectionFlowControllerI contributesToConnection bool // does the stream contribute to connection level flow control @@ -30,6 +32,7 @@ func NewStreamFlowController( receiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount, initialSendWindow protocol.ByteCount, + queueWindowUpdate func(protocol.StreamID), rttStats *congestion.RTTStats, logger utils.Logger, ) StreamFlowController { @@ -37,6 +40,7 @@ func NewStreamFlowController( streamID: streamID, contributesToConnection: contributesToConnection, connection: cfc.(connectionFlowControllerI), + queueWindowUpdate: func() { queueWindowUpdate(streamID) }, baseFlowController: baseFlowController{ rttStats: rttStats, receiveWindow: receiveWindow, @@ -111,20 +115,16 @@ func (c *streamFlowController) SendWindowSize() protocol.ByteCount { return window } -// IsBlocked says if it is blocked by stream-level flow control. -// If it is blocked, the offset is returned. -func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) { - if c.sendWindowSize() != 0 { - return false, 0 - } - return true, c.sendWindow -} - -func (c *streamFlowController) HasWindowUpdate() bool { +func (c *streamFlowController) MaybeQueueWindowUpdate() { c.mutex.Lock() hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate() c.mutex.Unlock() - return hasWindowUpdate + if hasWindowUpdate { + c.queueWindowUpdate() + } + if c.contributesToConnection { + c.connection.MaybeQueueWindowUpdate() + } } func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { @@ -139,7 +139,7 @@ func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { oldWindowSize := c.receiveWindowSize offset := c.baseFlowController.getWindowUpdate() if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size - c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) + c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10)) if c.contributesToConnection { c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go index 97accb73..00f6e7ef 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go @@ -5,8 +5,6 @@ import ( "fmt" "net" "time" - - "github.com/bifurcation/mint" ) const ( @@ -29,12 +27,12 @@ type token struct { // A CookieGenerator generates Cookies type CookieGenerator struct { - cookieProtector mint.CookieProtector + cookieProtector cookieProtector } // NewCookieGenerator initializes a new CookieGenerator func NewCookieGenerator() (*CookieGenerator, error) { - cookieProtector, err := mint.NewDefaultCookieProtector() + cookieProtector, err := newCookieProtector() if err != nil { return nil, err } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go deleted file mode 100644 index bc2bd8e5..00000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go +++ /dev/null @@ -1,51 +0,0 @@ -package handshake - -import ( - "net" - - "github.com/bifurcation/mint" - "github.com/lucas-clemente/quic-go/internal/utils" -) - -// A CookieHandler generates and validates cookies. -// The cookie is sent in the TLS Retry. -// By including the cookie in its ClientHello, a client can proof ownership of its source address. -type CookieHandler struct { - callback func(net.Addr, *Cookie) bool - cookieGenerator *CookieGenerator - - logger utils.Logger -} - -var _ mint.CookieHandler = &CookieHandler{} - -// NewCookieHandler creates a new CookieHandler. -func NewCookieHandler(callback func(net.Addr, *Cookie) bool, logger utils.Logger) (*CookieHandler, error) { - cookieGenerator, err := NewCookieGenerator() - if err != nil { - return nil, err - } - return &CookieHandler{ - callback: callback, - cookieGenerator: cookieGenerator, - logger: logger, - }, nil -} - -// Generate a new cookie for a mint connection. -func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) { - if h.callback(conn.RemoteAddr(), nil) { - return nil, nil - } - return h.cookieGenerator.NewToken(conn.RemoteAddr()) -} - -// Validate a cookie. -func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool { - data, err := h.cookieGenerator.DecodeToken(token) - if err != nil { - h.logger.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error()) - return false - } - return h.callback(conn.RemoteAddr(), data) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_protector.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_protector.go new file mode 100644 index 00000000..7ebdfa18 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_protector.go @@ -0,0 +1,86 @@ +package handshake + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "fmt" + "io" + + "golang.org/x/crypto/hkdf" +) + +// CookieProtector is used to create and verify a cookie +type cookieProtector interface { + // NewToken creates a new token + NewToken([]byte) ([]byte, error) + // DecodeToken decodes a token + DecodeToken([]byte) ([]byte, error) +} + +const ( + cookieSecretSize = 32 + cookieNonceSize = 32 +) + +// cookieProtector is used to create and verify a cookie +type cookieProtectorImpl struct { + secret []byte +} + +// newCookieProtector creates a source for source address tokens +func newCookieProtector() (cookieProtector, error) { + secret := make([]byte, cookieSecretSize) + if _, err := rand.Read(secret); err != nil { + return nil, err + } + return &cookieProtectorImpl{secret: secret}, nil +} + +// NewToken encodes data into a new token. +func (s *cookieProtectorImpl) NewToken(data []byte) ([]byte, error) { + nonce := make([]byte, cookieNonceSize) + if _, err := rand.Read(nonce); err != nil { + return nil, err + } + aead, aeadNonce, err := s.createAEAD(nonce) + if err != nil { + return nil, err + } + return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil +} + +// DecodeToken decodes a token. +func (s *cookieProtectorImpl) DecodeToken(p []byte) ([]byte, error) { + if len(p) < cookieNonceSize { + return nil, fmt.Errorf("Token too short: %d", len(p)) + } + nonce := p[:cookieNonceSize] + aead, aeadNonce, err := s.createAEAD(nonce) + if err != nil { + return nil, err + } + return aead.Open(nil, aeadNonce, p[cookieNonceSize:], nil) +} + +func (s *cookieProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) { + h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go cookie source")) + key := make([]byte, 32) // use a 32 byte key, in order to select AES-256 + if _, err := io.ReadFull(h, key); err != nil { + return nil, nil, err + } + aeadNonce := make([]byte, 12) + if _, err := io.ReadFull(h, aeadNonce); err != nil { + return nil, nil, err + } + c, err := aes.NewCipher(key) + if err != nil { + return nil, nil, err + } + aead, err := cipher.NewGCM(c) + if err != nil { + return nil, nil, err + } + return aead, aeadNonce, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go index 0700399a..6687b834 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go @@ -38,7 +38,7 @@ type cryptoSetupClient struct { lastSentCHLO []byte certManager crypto.CertManager - divNonceChan <-chan []byte + divNonceChan chan struct{} diversificationNonce []byte clientHelloCounter int @@ -79,29 +79,31 @@ func NewCryptoSetupClient( initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, logger utils.Logger, -) (CryptoSetup, chan<- []byte, error) { +) (CryptoSetup, error) { nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) if err != nil { - return nil, nil, err + return nil, err } - divNonceChan := make(chan []byte) + divNonceChan := make(chan struct{}) cs := &cryptoSetupClient{ - cryptoStream: cryptoStream, - hostname: hostname, - connID: connID, - version: version, - certManager: crypto.NewCertManager(tlsConfig), - params: params, - keyDerivation: crypto.DeriveQuicCryptoAESKeys, - nullAEAD: nullAEAD, - paramsChan: paramsChan, - handshakeEvent: handshakeEvent, - initialVersion: initialVersion, - negotiatedVersions: negotiatedVersions, + cryptoStream: cryptoStream, + hostname: hostname, + connID: connID, + version: version, + certManager: crypto.NewCertManager(tlsConfig), + params: params, + keyDerivation: crypto.DeriveQuicCryptoAESKeys, + nullAEAD: nullAEAD, + paramsChan: paramsChan, + handshakeEvent: handshakeEvent, + initialVersion: initialVersion, + // The server might have sent greased versions in the Version Negotiation packet. + // We need strip those from the list, since they won't be included in the handshake tag. + negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions), divNonceChan: divNonceChan, logger: logger, } - return cs, divNonceChan, nil + return cs, nil } func (h *cryptoSetupClient) HandleCryptoStream() error { @@ -120,33 +122,26 @@ func (h *cryptoSetupClient) HandleCryptoStream() error { }() for { - err := h.maybeUpgradeCrypto() - if err != nil { + if err := h.maybeUpgradeCrypto(); err != nil { return err } h.mutex.RLock() sendCHLO := h.secureAEAD == nil h.mutex.RUnlock() - if sendCHLO { - err = h.sendCHLO() - if err != nil { + if err := h.sendCHLO(); err != nil { return err } } var message HandshakeMessage select { - case divNonce := <-h.divNonceChan: - if len(h.diversificationNonce) != 0 && !bytes.Equal(h.diversificationNonce, divNonce) { - return errConflictingDiversificationNonces - } - h.diversificationNonce = divNonce + case <-h.divNonceChan: // there's no message to process, but we should try upgrading the crypto again continue case message = <-messageChan: - case err = <-errorChan: + case err := <-errorChan: return err } @@ -281,6 +276,7 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*Trans if err != nil { return nil, err } + h.logger.Debugf("Creating AEAD for forward-secure encryption. Stopping to accept all lower encryption levels.") params, err := readHelloMap(cryptoData) if err != nil { @@ -326,6 +322,7 @@ func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNu if h.secureAEAD != nil { data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData) if err == nil { + h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.") h.receivedSecurePacket = true return data, protocol.EncryptionSecure, nil } @@ -386,6 +383,21 @@ func (h *cryptoSetupClient) ConnectionState() ConnectionState { } } +func (h *cryptoSetupClient) SetDiversificationNonce(divNonce []byte) error { + h.mutex.Lock() + if len(h.diversificationNonce) > 0 { + defer h.mutex.Unlock() + if !bytes.Equal(h.diversificationNonce, divNonce) { + return errConflictingDiversificationNonces + } + return nil + } + h.diversificationNonce = divNonce + h.mutex.Unlock() + h.divNonceChan <- struct{}{} + return nil +} + func (h *cryptoSetupClient) sendCHLO() error { h.clientHelloCounter++ if h.clientHelloCounter > protocol.MaxClientHellos { @@ -501,6 +513,7 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error { if err != nil { return err } + h.logger.Debugf("Creating AEAD for secure encryption.") h.handshakeEvent <- struct{}{} } return nil diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go index d977f655..552e8297 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go @@ -60,11 +60,6 @@ type cryptoSetupServer struct { var _ CryptoSetup = &cryptoSetupServer{} -// ErrHOLExperiment is returned when the client sends the FHL2 tag in the CHLO. -// This is an experiment implemented by Chrome in QUIC 36, which we don't support. -// TODO: remove this when dropping support for QUIC 36 -var ErrHOLExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "HOL experiment. Unsupported") - // ErrNSTPExperiment is returned when the client sends the NSTP tag in the CHLO. // This is an experiment implemented by Chrome in QUIC 38, which we don't support at this point. var ErrNSTPExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "NSTP experiment. Unsupported") @@ -132,9 +127,6 @@ func (h *cryptoSetupServer) HandleCryptoStream() error { } func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]byte) (bool, error) { - if _, isHOLExperiment := cryptoData[TagFHL2]; isHOLExperiment { - return false, ErrHOLExperiment - } if _, isNSTPExperiment := cryptoData[TagNSTP]; isNSTPExperiment { return false, ErrNSTPExperiment } @@ -214,6 +206,7 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData) if err == nil { if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client + h.logger.Debugf("Received first forward-secure packet. Stopping to accept all lower encryption levels.") h.receivedForwardSecurePacket = true // wait for the send on the handshakeEvent chan <-h.sentSHLO @@ -228,6 +221,7 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu if h.secureAEAD != nil { res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData) if err == nil { + h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.") h.receivedSecurePacket = true return res, protocol.EncryptionSecure, nil } @@ -400,6 +394,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T if err != nil { return nil, err } + h.logger.Debugf("Creating AEAD for secure encryption.") h.handshakeEvent <- struct{}{} // Generate a new curve instance to derive the forward secure key @@ -429,6 +424,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T if err != nil { return nil, err } + h.logger.Debugf("Creating AEAD for forward-secure encryption.") replyMap := h.params.getHelloMap() // add crypto parameters diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go index 43f85406..2e5dd025 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go @@ -11,9 +11,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" ) -// ErrCloseSessionForRetry is returned by HandleCryptoStream when the server wishes to perform a stateless retry -var ErrCloseSessionForRetry = errors.New("closing session in order to recreate after a retry") - // KeyDerivationFunction is used for key derivation type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) @@ -26,8 +23,8 @@ type cryptoSetupTLS struct { nullAEAD crypto.AEAD aead crypto.AEAD - tls MintTLS - cryptoStream *CryptoStreamConn + tls mintTLS + conn *cryptoStreamConn handshakeEvent chan<- struct{} } @@ -35,39 +32,46 @@ var _ CryptoSetupTLS = &cryptoSetupTLS{} // NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server func NewCryptoSetupTLSServer( - tls MintTLS, - cryptoStream *CryptoStreamConn, - nullAEAD crypto.AEAD, + cryptoStream io.ReadWriter, + connID protocol.ConnectionID, + config *mint.Config, handshakeEvent chan<- struct{}, version protocol.VersionNumber, -) CryptoSetupTLS { +) (CryptoSetupTLS, error) { + nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) + if err != nil { + return nil, err + } + conn := newCryptoStreamConn(cryptoStream) + tls := mint.Server(conn, config) return &cryptoSetupTLS{ tls: tls, - cryptoStream: cryptoStream, + conn: conn, nullAEAD: nullAEAD, perspective: protocol.PerspectiveServer, keyDerivation: crypto.DeriveAESKeys, handshakeEvent: handshakeEvent, - } + }, nil } // NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client func NewCryptoSetupTLSClient( cryptoStream io.ReadWriter, connID protocol.ConnectionID, - hostname string, + config *mint.Config, handshakeEvent chan<- struct{}, - tls MintTLS, version protocol.VersionNumber, ) (CryptoSetupTLS, error) { nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) if err != nil { return nil, err } - + conn := newCryptoStreamConn(cryptoStream) + tls := mint.Client(conn, config) return &cryptoSetupTLS{ - perspective: protocol.PerspectiveClient, tls: tls, + conn: conn, + perspective: protocol.PerspectiveClient, nullAEAD: nullAEAD, keyDerivation: crypto.DeriveAESKeys, handshakeEvent: handshakeEvent, @@ -75,24 +79,16 @@ func NewCryptoSetupTLSClient( } func (h *cryptoSetupTLS) HandleCryptoStream() error { - if h.perspective == protocol.PerspectiveServer { - // mint already wrote the ServerHello, EncryptedExtensions and the certificate chain to the buffer - // send out that data now - if _, err := h.cryptoStream.Flush(); err != nil { - return err - } - } - -handshakeLoop: for { if alert := h.tls.Handshake(); alert != mint.AlertNoAlert { return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert) } - switch h.tls.State() { - case mint.StateClientStart: // this happens if a stateless retry is performed - return ErrCloseSessionForRetry - case mint.StateClientConnected, mint.StateServerConnected: - break handshakeLoop + state := h.tls.ConnectionState().HandshakeState + if err := h.conn.Flush(); err != nil { + return err + } + if state == mint.StateClientConnected || state == mint.StateServerConnected { + break } } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go index 03825c41..a031f90c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go @@ -7,95 +7,63 @@ import ( "time" ) -// The CryptoStreamConn is used as the net.Conn passed to mint. -// It has two operating modes: -// 1. It can read and write to bytes.Buffers. -// 2. It can use a quic.Stream for reading and writing. -// The buffer-mode is only used by the server, in order to statelessly handle retries. -type CryptoStreamConn struct { - remoteAddr net.Addr - - // the buffers are used before the session is initialized - readBuf bytes.Buffer - writeBuf bytes.Buffer - - // stream will be set once the session is initialized +type cryptoStreamConn struct { + buffer *bytes.Buffer stream io.ReadWriter } -var _ net.Conn = &CryptoStreamConn{} +var _ net.Conn = &cryptoStreamConn{} -// NewCryptoStreamConn creates a new CryptoStreamConn -func NewCryptoStreamConn(remoteAddr net.Addr) *CryptoStreamConn { - return &CryptoStreamConn{remoteAddr: remoteAddr} -} - -func (c *CryptoStreamConn) Read(b []byte) (int, error) { - if c.stream != nil { - return c.stream.Read(b) +func newCryptoStreamConn(stream io.ReadWriter) *cryptoStreamConn { + return &cryptoStreamConn{ + stream: stream, + buffer: &bytes.Buffer{}, } - return c.readBuf.Read(b) } -// AddDataForReading adds data to the read buffer. -// This data will ONLY be read when the stream has not been set. -func (c *CryptoStreamConn) AddDataForReading(data []byte) { - c.readBuf.Write(data) +func (c *cryptoStreamConn) Read(b []byte) (int, error) { + return c.stream.Read(b) } -func (c *CryptoStreamConn) Write(p []byte) (int, error) { - if c.stream != nil { - return c.stream.Write(p) +func (c *cryptoStreamConn) Write(p []byte) (int, error) { + return c.buffer.Write(p) +} + +func (c *cryptoStreamConn) Flush() error { + if c.buffer.Len() == 0 { + return nil } - return c.writeBuf.Write(p) -} - -// GetDataForWriting returns all data currently in the write buffer, and resets this buffer. -func (c *CryptoStreamConn) GetDataForWriting() []byte { - defer c.writeBuf.Reset() - data := make([]byte, c.writeBuf.Len()) - copy(data, c.writeBuf.Bytes()) - return data -} - -// SetStream sets the stream. -// After setting the stream, the read and write buffer won't be used any more. -func (c *CryptoStreamConn) SetStream(stream io.ReadWriter) { - c.stream = stream -} - -// Flush copies the contents of the write buffer to the stream -func (c *CryptoStreamConn) Flush() (int, error) { - n, err := io.Copy(c.stream, &c.writeBuf) - return int(n), err + _, err := c.stream.Write(c.buffer.Bytes()) + c.buffer.Reset() + return err } // Close is not implemented -func (c *CryptoStreamConn) Close() error { +func (c *cryptoStreamConn) Close() error { return nil } // LocalAddr is not implemented -func (c *CryptoStreamConn) LocalAddr() net.Addr { +func (c *cryptoStreamConn) LocalAddr() net.Addr { return nil } -// RemoteAddr returns the remote address -func (c *CryptoStreamConn) RemoteAddr() net.Addr { - return c.remoteAddr +// RemoteAddr is not implemented +func (c *cryptoStreamConn) RemoteAddr() net.Addr { + return nil } // SetReadDeadline is not implemented -func (c *CryptoStreamConn) SetReadDeadline(time.Time) error { +func (c *cryptoStreamConn) SetReadDeadline(time.Time) error { return nil } // SetWriteDeadline is not implemented -func (c *CryptoStreamConn) SetWriteDeadline(time.Time) error { +func (c *cryptoStreamConn) SetWriteDeadline(time.Time) error { return nil } // SetDeadline is not implemented -func (c *CryptoStreamConn) SetDeadline(time.Time) error { +func (c *cryptoStreamConn) SetDeadline(time.Time) error { return nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go index 72871583..eb1824d9 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go @@ -35,7 +35,7 @@ func getEphermalKEX() (crypto.KeyExchange, error) { kexMutex.Lock() defer kexMutex.Unlock() // Check if still unfulfilled - if kexCurrent == nil || time.Since(kexCurrentTime) > kexLifetime { + if kexCurrent == nil || time.Since(kexCurrentTime) >= kexLifetime { kex, err := crypto.NewCurve25519KEX() if err != nil { return nil, err diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go index 8d8fd545..5813fdd9 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go @@ -2,7 +2,6 @@ package handshake import ( "crypto/x509" - "io" "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/crypto" @@ -15,6 +14,12 @@ type Sealer interface { Overhead() int } +// mintTLS combines some methods needed to interact with mint. +type mintTLS interface { + crypto.TLSExporter + Handshake() mint.Alert +} + // A TLSExtensionHandler sends and received the QUIC TLS extension. // It provides the parameters sent by the peer on a channel. type TLSExtensionHandler interface { @@ -23,18 +28,6 @@ type TLSExtensionHandler interface { GetPeerParams() <-chan TransportParameters } -// MintTLS combines some methods needed to interact with mint. -type MintTLS interface { - crypto.TLSExporter - - // additional methods - Handshake() mint.Alert - State() mint.State - ConnectionState() mint.ConnectionState - - SetCryptoStream(io.ReadWriter) -} - type baseCryptoSetup interface { HandleCryptoStream() error ConnectionState() ConnectionState diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/mockgen.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/mockgen.go new file mode 100644 index 00000000..86232720 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/mockgen.go @@ -0,0 +1,3 @@ +package handshake + +//go:generate sh -c "../mockgen_internal.sh handshake mock_mint_tls_test.go github.com/lucas-clemente/quic-go/internal/handshake mintTLS" diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tags.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tags.go index 19ec78d3..cf2a7562 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tags.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tags.go @@ -50,10 +50,6 @@ const ( // TagSFCW is the initial stream flow control receive window. TagSFCW Tag = 'S' + 'F'<<8 + 'C'<<16 + 'W'<<24 - // TagFHL2 forces head of line blocking. - // Chrome experiment (see https://codereview.chromium.org/2115033002) - // unsupported by quic-go - TagFHL2 Tag = 'F' + 'H'<<8 + 'L'<<16 + '2'<<24 // TagNSTP is the no STOP_WAITING experiment // currently unsupported by quic-go TagNSTP Tag = 'N' + 'S'<<8 + 'T'<<16 + 'P'<<24 diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go index 98ad3a57..b3665dfe 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go @@ -1,38 +1,106 @@ package handshake import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" ) type transportParameterID uint16 -const quicTLSExtensionType = 26 +const quicTLSExtensionType = 0xff5 const ( initialMaxStreamDataParameterID transportParameterID = 0x0 initialMaxDataParameterID transportParameterID = 0x1 - initialMaxStreamsBiDiParameterID transportParameterID = 0x2 + initialMaxBidiStreamsParameterID transportParameterID = 0x2 idleTimeoutParameterID transportParameterID = 0x3 - omitConnectionIDParameterID transportParameterID = 0x4 maxPacketSizeParameterID transportParameterID = 0x5 statelessResetTokenParameterID transportParameterID = 0x6 - initialMaxStreamsUniParameterID transportParameterID = 0x8 + initialMaxUniStreamsParameterID transportParameterID = 0x8 + disableMigrationParameterID transportParameterID = 0x9 ) -type transportParameter struct { - Parameter transportParameterID - Value []byte `tls:"head=2"` +type clientHelloTransportParameters struct { + InitialVersion protocol.VersionNumber + Parameters TransportParameters } -type clientHelloTransportParameters struct { - InitialVersion uint32 // actually a protocol.VersionNumber - Parameters []transportParameter `tls:"head=2"` +func (p *clientHelloTransportParameters) Marshal() []byte { + const lenOffset = 4 + b := &bytes.Buffer{} + utils.BigEndian.WriteUint32(b, uint32(p.InitialVersion)) + b.Write([]byte{0, 0}) // length. Will be replaced later + p.Parameters.marshal(b) + data := b.Bytes() + binary.BigEndian.PutUint16(data[lenOffset:lenOffset+2], uint16(len(data)-lenOffset-2)) + return data +} + +func (p *clientHelloTransportParameters) Unmarshal(data []byte) error { + if len(data) < 6 { + return errors.New("transport parameter data too short") + } + p.InitialVersion = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4])) + paramsLen := int(binary.BigEndian.Uint16(data[4:6])) + data = data[6:] + if len(data) != paramsLen { + return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data)) + } + return p.Parameters.unmarshal(data) } type encryptedExtensionsTransportParameters struct { - NegotiatedVersion uint32 // actually a protocol.VersionNumber - SupportedVersions []uint32 `tls:"head=1"` // actually a protocol.VersionNumber - Parameters []transportParameter `tls:"head=2"` + NegotiatedVersion protocol.VersionNumber + SupportedVersions []protocol.VersionNumber + Parameters TransportParameters +} + +func (p *encryptedExtensionsTransportParameters) Marshal() []byte { + b := &bytes.Buffer{} + utils.BigEndian.WriteUint32(b, uint32(p.NegotiatedVersion)) + b.WriteByte(uint8(4 * len(p.SupportedVersions))) + for _, v := range p.SupportedVersions { + utils.BigEndian.WriteUint32(b, uint32(v)) + } + lenOffset := b.Len() + b.Write([]byte{0, 0}) // length. Will be replaced later + p.Parameters.marshal(b) + data := b.Bytes() + binary.BigEndian.PutUint16(data[lenOffset:lenOffset+2], uint16(len(data)-lenOffset-2)) + return data +} + +func (p *encryptedExtensionsTransportParameters) Unmarshal(data []byte) error { + if len(data) < 5 { + return errors.New("transport parameter data too short") + } + p.NegotiatedVersion = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4])) + numVersions := int(data[4]) + if numVersions%4 != 0 { + return fmt.Errorf("invalid length for version list: %d", numVersions) + } + numVersions /= 4 + data = data[5:] + if len(data) < 4*numVersions+2 /*length field for the parameter list */ { + return errors.New("transport parameter data too short") + } + p.SupportedVersions = make([]protocol.VersionNumber, numVersions) + for i := 0; i < numVersions; i++ { + p.SupportedVersions[i] = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4])) + data = data[4:] + } + paramsLen := int(binary.BigEndian.Uint16(data[:2])) + data = data[2:] + if len(data) != paramsLen { + return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data)) + } + return p.Parameters.unmarshal(data) } type tlsExtensionBody struct { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go index 8e711be5..d03021ae 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go @@ -7,7 +7,6 @@ import ( "github.com/lucas-clemente/quic-go/qerr" "github.com/bifurcation/mint" - "github.com/bifurcation/mint/syntax" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" ) @@ -52,16 +51,12 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi if hType != mint.HandshakeTypeClientHello { return nil } - h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams) - data, err := syntax.Marshal(clientHelloTransportParameters{ - InitialVersion: uint32(h.initialVersion), - Parameters: h.ourParams.getTransportParameters(), - }) - if err != nil { - return err + chtp := &clientHelloTransportParameters{ + InitialVersion: h.initialVersion, + Parameters: *h.ourParams, } - return el.Add(&tlsExtensionBody{data}) + return el.Add(&tlsExtensionBody{data: chtp.Marshal()}) } func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { @@ -84,50 +79,31 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte } eetp := &encryptedExtensionsTransportParameters{} - if _, err := syntax.Unmarshal(ext.data, eetp); err != nil { + if err := eetp.Unmarshal(ext.data); err != nil { return err } - serverSupportedVersions := make([]protocol.VersionNumber, len(eetp.SupportedVersions)) - for i, v := range eetp.SupportedVersions { - serverSupportedVersions[i] = protocol.VersionNumber(v) - } // check that the negotiated_version is the current version - if protocol.VersionNumber(eetp.NegotiatedVersion) != h.version { + if eetp.NegotiatedVersion != h.version { return qerr.Error(qerr.VersionNegotiationMismatch, "current version doesn't match negotiated_version") } // check that the current version is included in the supported versions - if !protocol.IsSupportedVersion(serverSupportedVersions, h.version) { + if !protocol.IsSupportedVersion(eetp.SupportedVersions, h.version) { return qerr.Error(qerr.VersionNegotiationMismatch, "current version not included in the supported versions") } // if version negotiation was performed, check that we would have selected the current version based on the supported versions sent by the server if h.version != h.initialVersion { - negotiatedVersion, ok := protocol.ChooseSupportedVersion(h.supportedVersions, serverSupportedVersions) + negotiatedVersion, ok := protocol.ChooseSupportedVersion(h.supportedVersions, eetp.SupportedVersions) if !ok || h.version != negotiatedVersion { return qerr.Error(qerr.VersionNegotiationMismatch, "would have picked a different version") } } - // check that the server sent the stateless reset token - var foundStatelessResetToken bool - for _, p := range eetp.Parameters { - if p.Parameter == statelessResetTokenParameterID { - if len(p.Value) != 16 { - return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", len(p.Value)) - } - foundStatelessResetToken = true - // TODO: handle this value - } - } - if !foundStatelessResetToken { - // TODO: return the right error here + // check that the server sent a stateless reset token + if len(eetp.Parameters.StatelessResetToken) == 0 { return errors.New("server didn't sent stateless_reset_token") } - params, err := readTransportParameters(eetp.Parameters) - if err != nil { - return err - } - h.logger.Debugf("Received Transport Parameters: %s", params) - h.paramsChan <- *params + h.logger.Debugf("Received Transport Parameters: %s", &eetp.Parameters) + h.paramsChan <- eetp.Parameters return nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go index 138fc21b..2d75d693 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go @@ -1,14 +1,12 @@ package handshake import ( - "bytes" "errors" "fmt" "github.com/lucas-clemente/quic-go/qerr" "github.com/bifurcation/mint" - "github.com/bifurcation/mint/syntax" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" ) @@ -49,27 +47,13 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi if hType != mint.HandshakeTypeEncryptedExtensions { return nil } - - transportParams := append( - h.ourParams.getTransportParameters(), - // TODO(#855): generate a real token - transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)}, - ) - supportedVersions := protocol.GetGreasedVersions(h.supportedVersions) - versions := make([]uint32, len(supportedVersions)) - for i, v := range supportedVersions { - versions[i] = uint32(v) - } h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams) - data, err := syntax.Marshal(encryptedExtensionsTransportParameters{ - NegotiatedVersion: uint32(h.version), - SupportedVersions: versions, - Parameters: transportParams, - }) - if err != nil { - return err + eetp := &encryptedExtensionsTransportParameters{ + NegotiatedVersion: h.version, + SupportedVersions: protocol.GetGreasedVersions(h.supportedVersions), + Parameters: *h.ourParams, } - return el.Add(&tlsExtensionBody{data}) + return el.Add(&tlsExtensionBody{data: eetp.Marshal()}) } func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { @@ -90,30 +74,24 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte return errors.New("ClientHello didn't contain a QUIC extension") } chtp := &clientHelloTransportParameters{} - if _, err := syntax.Unmarshal(ext.data, chtp); err != nil { + if err := chtp.Unmarshal(ext.data); err != nil { return err } - initialVersion := protocol.VersionNumber(chtp.InitialVersion) // perform the stateless version negotiation validation: // make sure that we would have sent a Version Negotiation Packet if the client offered the initial version // this is the case if and only if the initial version is not contained in the supported versions - if initialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, initialVersion) { + if chtp.InitialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, chtp.InitialVersion) { return qerr.Error(qerr.VersionNegotiationMismatch, "Client should have used the initial version") } - for _, p := range chtp.Parameters { - if p.Parameter == statelessResetTokenParameterID { - // TODO: return the correct error type - return errors.New("client sent a stateless reset token") - } + // check that the client didn't send a stateless reset token + if len(chtp.Parameters.StatelessResetToken) != 0 { + // TODO: return the correct error type + return errors.New("client sent a stateless reset token") } - params, err := readTransportParameters(chtp.Parameters) - if err != nil { - return err - } - h.logger.Debugf("Received Transport Parameters: %s", params) - h.paramsChan <- *params + h.logger.Debugf("Received Transport Parameters: %s", &chtp.Parameters) + h.paramsChan <- chtp.Parameters return nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go index fce1e3f2..9ed0aeb5 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go @@ -26,8 +26,10 @@ type TransportParameters struct { MaxBidiStreams uint16 // only used for IETF QUIC MaxStreams uint32 // only used for gQUIC - OmitConnectionID bool - IdleTimeout time.Duration + OmitConnectionID bool // only used for gQUIC + IdleTimeout time.Duration + DisableMigration bool // only used for IETF QUIC + StatelessResetToken []byte // only used for IETF QUIC } // readHelloMap reads the transport parameters from the tags sent in a gQUIC handshake message @@ -94,98 +96,114 @@ func (p *TransportParameters) getHelloMap() map[Tag][]byte { return tags } -// readTransportParameters reads the transport parameters sent in the QUIC TLS extension -func readTransportParameters(paramsList []transportParameter) (*TransportParameters, error) { - params := &TransportParameters{} - - var foundInitialMaxStreamData bool - var foundInitialMaxData bool +func (p *TransportParameters) unmarshal(data []byte) error { var foundIdleTimeout bool - for _, p := range paramsList { - switch p.Parameter { + for len(data) >= 4 { + paramID := binary.BigEndian.Uint16(data[:2]) + paramLen := int(binary.BigEndian.Uint16(data[2:4])) + data = data[4:] + if len(data) < paramLen { + return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(data), paramLen) + } + switch transportParameterID(paramID) { case initialMaxStreamDataParameterID: - foundInitialMaxStreamData = true - if len(p.Value) != 4 { - return nil, fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", len(p.Value)) + if paramLen != 4 { + return fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", paramLen) } - params.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value)) + p.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4])) case initialMaxDataParameterID: - foundInitialMaxData = true - if len(p.Value) != 4 { - return nil, fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", len(p.Value)) + if paramLen != 4 { + return fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", paramLen) } - params.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value)) - case initialMaxStreamsBiDiParameterID: - if len(p.Value) != 2 { - return nil, fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 2)", len(p.Value)) + p.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4])) + case initialMaxBidiStreamsParameterID: + if paramLen != 2 { + return fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 2)", paramLen) } - params.MaxBidiStreams = binary.BigEndian.Uint16(p.Value) - case initialMaxStreamsUniParameterID: - if len(p.Value) != 2 { - return nil, fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 2)", len(p.Value)) + p.MaxBidiStreams = binary.BigEndian.Uint16(data[:2]) + case initialMaxUniStreamsParameterID: + if paramLen != 2 { + return fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 2)", paramLen) } - params.MaxUniStreams = binary.BigEndian.Uint16(p.Value) + p.MaxUniStreams = binary.BigEndian.Uint16(data[:2]) case idleTimeoutParameterID: foundIdleTimeout = true - if len(p.Value) != 2 { - return nil, fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", len(p.Value)) + if paramLen != 2 { + return fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", paramLen) } - params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(p.Value))*time.Second) - case omitConnectionIDParameterID: - if len(p.Value) != 0 { - return nil, fmt.Errorf("wrong length for omit_connection_id: %d (expected empty)", len(p.Value)) - } - params.OmitConnectionID = true + p.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(data[:2]))*time.Second) case maxPacketSizeParameterID: - if len(p.Value) != 2 { - return nil, fmt.Errorf("wrong length for max_packet_size: %d (expected 2)", len(p.Value)) + if paramLen != 2 { + return fmt.Errorf("wrong length for max_packet_size: %d (expected 2)", paramLen) } - maxPacketSize := protocol.ByteCount(binary.BigEndian.Uint16(p.Value)) + maxPacketSize := protocol.ByteCount(binary.BigEndian.Uint16(data[:2])) if maxPacketSize < 1200 { - return nil, fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", maxPacketSize) + return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", maxPacketSize) } - params.MaxPacketSize = maxPacketSize + p.MaxPacketSize = maxPacketSize + case disableMigrationParameterID: + if paramLen != 0 { + return fmt.Errorf("wrong length for disable_migration: %d (expected empty)", paramLen) + } + p.DisableMigration = true + case statelessResetTokenParameterID: + if paramLen != 16 { + return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen) + } + p.StatelessResetToken = data[:16] } + data = data[paramLen:] } - if !(foundInitialMaxStreamData && foundInitialMaxData && foundIdleTimeout) { - return nil, errors.New("missing parameter") + if len(data) != 0 { + return fmt.Errorf("should have read all data. Still have %d bytes", len(data)) } - return params, nil + if !foundIdleTimeout { + return errors.New("missing parameter") + } + return nil } -// GetTransportParameters gets the parameters needed for the TLS handshake. -// It doesn't send the initial_max_stream_id_uni parameter, so the peer isn't allowed to open any unidirectional streams. -func (p *TransportParameters) getTransportParameters() []transportParameter { - initialMaxStreamData := make([]byte, 4) - binary.BigEndian.PutUint32(initialMaxStreamData, uint32(p.StreamFlowControlWindow)) - initialMaxData := make([]byte, 4) - binary.BigEndian.PutUint32(initialMaxData, uint32(p.ConnectionFlowControlWindow)) - initialMaxBidiStreamID := make([]byte, 2) - binary.BigEndian.PutUint16(initialMaxBidiStreamID, p.MaxBidiStreams) - initialMaxUniStreamID := make([]byte, 2) - binary.BigEndian.PutUint16(initialMaxUniStreamID, p.MaxUniStreams) - idleTimeout := make([]byte, 2) - binary.BigEndian.PutUint16(idleTimeout, uint16(p.IdleTimeout/time.Second)) - maxPacketSize := make([]byte, 2) - binary.BigEndian.PutUint16(maxPacketSize, uint16(protocol.MaxReceivePacketSize)) - params := []transportParameter{ - {initialMaxStreamDataParameterID, initialMaxStreamData}, - {initialMaxDataParameterID, initialMaxData}, - {initialMaxStreamsBiDiParameterID, initialMaxBidiStreamID}, - {initialMaxStreamsUniParameterID, initialMaxUniStreamID}, - {idleTimeoutParameterID, idleTimeout}, - {maxPacketSizeParameterID, maxPacketSize}, +func (p *TransportParameters) marshal(b *bytes.Buffer) { + // initial_max_stream_data + utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataParameterID)) + utils.BigEndian.WriteUint16(b, 4) + utils.BigEndian.WriteUint32(b, uint32(p.StreamFlowControlWindow)) + // initial_max_data + utils.BigEndian.WriteUint16(b, uint16(initialMaxDataParameterID)) + utils.BigEndian.WriteUint16(b, 4) + utils.BigEndian.WriteUint32(b, uint32(p.ConnectionFlowControlWindow)) + // initial_max_bidi_streams + utils.BigEndian.WriteUint16(b, uint16(initialMaxBidiStreamsParameterID)) + utils.BigEndian.WriteUint16(b, 2) + utils.BigEndian.WriteUint16(b, p.MaxBidiStreams) + // initial_max_uni_streams + utils.BigEndian.WriteUint16(b, uint16(initialMaxUniStreamsParameterID)) + utils.BigEndian.WriteUint16(b, 2) + utils.BigEndian.WriteUint16(b, p.MaxUniStreams) + // idle_timeout + utils.BigEndian.WriteUint16(b, uint16(idleTimeoutParameterID)) + utils.BigEndian.WriteUint16(b, 2) + utils.BigEndian.WriteUint16(b, uint16(p.IdleTimeout/time.Second)) + // max_packet_size + utils.BigEndian.WriteUint16(b, uint16(maxPacketSizeParameterID)) + utils.BigEndian.WriteUint16(b, 2) + utils.BigEndian.WriteUint16(b, uint16(protocol.MaxReceivePacketSize)) + // disable_migration + if p.DisableMigration { + utils.BigEndian.WriteUint16(b, uint16(disableMigrationParameterID)) + utils.BigEndian.WriteUint16(b, 0) } - if p.OmitConnectionID { - params = append(params, transportParameter{omitConnectionIDParameterID, []byte{}}) + if len(p.StatelessResetToken) > 0 { + utils.BigEndian.WriteUint16(b, uint16(statelessResetTokenParameterID)) + utils.BigEndian.WriteUint16(b, uint16(len(p.StatelessResetToken))) // should always be 16 bytes + b.Write(p.StatelessResetToken) } - return params } // String returns a string representation, intended for logging. // It should only used for IETF QUIC. func (p *TransportParameters) String() string { - return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, OmitConnectionID: %t, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.OmitConnectionID, p.IdleTimeout) + return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/ackhandler/sent_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/ackhandler/sent_packet_handler.go index b11d999a..aff5a1e1 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/ackhandler/sent_packet_handler.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/ackhandler/sent_packet_handler.go @@ -49,6 +49,19 @@ func (mr *MockSentPacketHandlerMockRecorder) DequeuePacketForRetransmission() *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DequeuePacketForRetransmission", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeuePacketForRetransmission)) } +// DequeueProbePacket mocks base method +func (m *MockSentPacketHandler) DequeueProbePacket() (*ackhandler.Packet, error) { + ret := m.ctrl.Call(m, "DequeueProbePacket") + ret0, _ := ret[0].(*ackhandler.Packet) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DequeueProbePacket indicates an expected call of DequeueProbePacket +func (mr *MockSentPacketHandlerMockRecorder) DequeueProbePacket() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DequeueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeueProbePacket)) +} + // GetAlarmTimeout mocks base method func (m *MockSentPacketHandler) GetAlarmTimeout() time.Time { ret := m.ctrl.Call(m, "GetAlarmTimeout") diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/congestion.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/congestion.go index b86ccfea..d0749252 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/congestion.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/congestion.go @@ -68,13 +68,13 @@ func (mr *MockSendAlgorithmMockRecorder) OnConnectionMigration() *gomock.Call { } // OnPacketAcked mocks base method -func (m *MockSendAlgorithm) OnPacketAcked(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount) { - m.ctrl.Call(m, "OnPacketAcked", arg0, arg1, arg2) +func (m *MockSendAlgorithm) OnPacketAcked(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount, arg3 time.Time) { + m.ctrl.Call(m, "OnPacketAcked", arg0, arg1, arg2, arg3) } // OnPacketAcked indicates an expected call of OnPacketAcked -func (mr *MockSendAlgorithmMockRecorder) OnPacketAcked(arg0, arg1, arg2 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithm)(nil).OnPacketAcked), arg0, arg1, arg2) +func (mr *MockSendAlgorithmMockRecorder) OnPacketAcked(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithm)(nil).OnPacketAcked), arg0, arg1, arg2, arg3) } // OnPacketLost mocks base method @@ -88,10 +88,8 @@ func (mr *MockSendAlgorithmMockRecorder) OnPacketLost(arg0, arg1, arg2 interface } // OnPacketSent mocks base method -func (m *MockSendAlgorithm) OnPacketSent(arg0 time.Time, arg1 protocol.ByteCount, arg2 protocol.PacketNumber, arg3 protocol.ByteCount, arg4 bool) bool { - ret := m.ctrl.Call(m, "OnPacketSent", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(bool) - return ret0 +func (m *MockSendAlgorithm) OnPacketSent(arg0 time.Time, arg1 protocol.ByteCount, arg2 protocol.PacketNumber, arg3 protocol.ByteCount, arg4 bool) { + m.ctrl.Call(m, "OnPacketSent", arg0, arg1, arg2, arg3, arg4) } // OnPacketSent indicates an expected call of OnPacketSent diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/connection_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/connection_flow_controller.go index ae10e785..1a47362b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/connection_flow_controller.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/connection_flow_controller.go @@ -79,6 +79,16 @@ func (mr *MockConnectionFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockConnectionFlowController)(nil).IsNewlyBlocked)) } +// MaybeQueueWindowUpdate mocks base method +func (m *MockConnectionFlowController) MaybeQueueWindowUpdate() { + m.ctrl.Call(m, "MaybeQueueWindowUpdate") +} + +// MaybeQueueWindowUpdate indicates an expected call of MaybeQueueWindowUpdate +func (mr *MockConnectionFlowControllerMockRecorder) MaybeQueueWindowUpdate() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeQueueWindowUpdate", reflect.TypeOf((*MockConnectionFlowController)(nil).MaybeQueueWindowUpdate)) +} + // SendWindowSize mocks base method func (m *MockConnectionFlowController) SendWindowSize() protocol.ByteCount { ret := m.ctrl.Call(m, "SendWindowSize") diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/gen.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/gen.go deleted file mode 100644 index bd33e7d0..00000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/gen.go +++ /dev/null @@ -1,11 +0,0 @@ -package mocks - -//go:generate sh -c "./mockgen_internal.sh mockhandshake handshake/mint_tls.go github.com/lucas-clemente/quic-go/internal/handshake MintTLS" -//go:generate sh -c "./mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler" -//go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController" -//go:generate sh -c "./mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler" -//go:generate sh -c "./mockgen_internal.sh mockackhandler ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler" -//go:generate sh -c "./mockgen_internal.sh mocks congestion.go github.com/lucas-clemente/quic-go/internal/congestion SendAlgorithm" -//go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController" -//go:generate sh -c "./mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD" -//go:generate sh -c "goimports -w ." diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/handshake/mint_tls.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/handshake/mint_tls.go deleted file mode 100644 index 0a0714db..00000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/handshake/mint_tls.go +++ /dev/null @@ -1,107 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: MintTLS) - -// Package mockhandshake is a generated GoMock package. -package mockhandshake - -import ( - io "io" - reflect "reflect" - - mint "github.com/bifurcation/mint" - gomock "github.com/golang/mock/gomock" -) - -// MockMintTLS is a mock of MintTLS interface -type MockMintTLS struct { - ctrl *gomock.Controller - recorder *MockMintTLSMockRecorder -} - -// MockMintTLSMockRecorder is the mock recorder for MockMintTLS -type MockMintTLSMockRecorder struct { - mock *MockMintTLS -} - -// NewMockMintTLS creates a new mock instance -func NewMockMintTLS(ctrl *gomock.Controller) *MockMintTLS { - mock := &MockMintTLS{ctrl: ctrl} - mock.recorder = &MockMintTLSMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockMintTLS) EXPECT() *MockMintTLSMockRecorder { - return m.recorder -} - -// ComputeExporter mocks base method -func (m *MockMintTLS) ComputeExporter(arg0 string, arg1 []byte, arg2 int) ([]byte, error) { - ret := m.ctrl.Call(m, "ComputeExporter", arg0, arg1, arg2) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ComputeExporter indicates an expected call of ComputeExporter -func (mr *MockMintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ComputeExporter", reflect.TypeOf((*MockMintTLS)(nil).ComputeExporter), arg0, arg1, arg2) -} - -// ConnectionState mocks base method -func (m *MockMintTLS) ConnectionState() mint.ConnectionState { - ret := m.ctrl.Call(m, "ConnectionState") - ret0, _ := ret[0].(mint.ConnectionState) - return ret0 -} - -// ConnectionState indicates an expected call of ConnectionState -func (mr *MockMintTLSMockRecorder) ConnectionState() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockMintTLS)(nil).ConnectionState)) -} - -// GetCipherSuite mocks base method -func (m *MockMintTLS) GetCipherSuite() mint.CipherSuiteParams { - ret := m.ctrl.Call(m, "GetCipherSuite") - ret0, _ := ret[0].(mint.CipherSuiteParams) - return ret0 -} - -// GetCipherSuite indicates an expected call of GetCipherSuite -func (mr *MockMintTLSMockRecorder) GetCipherSuite() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCipherSuite", reflect.TypeOf((*MockMintTLS)(nil).GetCipherSuite)) -} - -// Handshake mocks base method -func (m *MockMintTLS) Handshake() mint.Alert { - ret := m.ctrl.Call(m, "Handshake") - ret0, _ := ret[0].(mint.Alert) - return ret0 -} - -// Handshake indicates an expected call of Handshake -func (mr *MockMintTLSMockRecorder) Handshake() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Handshake", reflect.TypeOf((*MockMintTLS)(nil).Handshake)) -} - -// SetCryptoStream mocks base method -func (m *MockMintTLS) SetCryptoStream(arg0 io.ReadWriter) { - m.ctrl.Call(m, "SetCryptoStream", arg0) -} - -// SetCryptoStream indicates an expected call of SetCryptoStream -func (mr *MockMintTLSMockRecorder) SetCryptoStream(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCryptoStream", reflect.TypeOf((*MockMintTLS)(nil).SetCryptoStream), arg0) -} - -// State mocks base method -func (m *MockMintTLS) State() mint.State { - ret := m.ctrl.Call(m, "State") - ret0, _ := ret[0].(mint.State) - return ret0 -} - -// State indicates an expected call of State -func (mr *MockMintTLSMockRecorder) State() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "State", reflect.TypeOf((*MockMintTLS)(nil).State)) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/mockgen.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/mockgen.go new file mode 100644 index 00000000..5bfcdc22 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/mockgen.go @@ -0,0 +1,9 @@ +package mocks + +//go:generate sh -c "../mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler" +//go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController" +//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler" +//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler" +//go:generate sh -c "../mockgen_internal.sh mocks congestion.go github.com/lucas-clemente/quic-go/internal/congestion SendAlgorithm" +//go:generate sh -c "../mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController" +//go:generate sh -c "../mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD" diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/stream_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/stream_flow_controller.go index a69e73f1..955f5509 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/mocks/stream_flow_controller.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/mocks/stream_flow_controller.go @@ -66,29 +66,27 @@ func (mr *MockStreamFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).GetWindowUpdate)) } -// HasWindowUpdate mocks base method -func (m *MockStreamFlowController) HasWindowUpdate() bool { - ret := m.ctrl.Call(m, "HasWindowUpdate") - ret0, _ := ret[0].(bool) - return ret0 -} - -// HasWindowUpdate indicates an expected call of HasWindowUpdate -func (mr *MockStreamFlowControllerMockRecorder) HasWindowUpdate() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).HasWindowUpdate)) -} - -// IsBlocked mocks base method -func (m *MockStreamFlowController) IsBlocked() (bool, protocol.ByteCount) { - ret := m.ctrl.Call(m, "IsBlocked") +// IsNewlyBlocked mocks base method +func (m *MockStreamFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { + ret := m.ctrl.Call(m, "IsNewlyBlocked") ret0, _ := ret[0].(bool) ret1, _ := ret[1].(protocol.ByteCount) return ret0, ret1 } -// IsBlocked indicates an expected call of IsBlocked -func (mr *MockStreamFlowControllerMockRecorder) IsBlocked() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsBlocked)) +// IsNewlyBlocked indicates an expected call of IsNewlyBlocked +func (mr *MockStreamFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsNewlyBlocked)) +} + +// MaybeQueueWindowUpdate mocks base method +func (m *MockStreamFlowController) MaybeQueueWindowUpdate() { + m.ctrl.Call(m, "MaybeQueueWindowUpdate") +} + +// MaybeQueueWindowUpdate indicates an expected call of MaybeQueueWindowUpdate +func (mr *MockStreamFlowControllerMockRecorder) MaybeQueueWindowUpdate() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeQueueWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).MaybeQueueWindowUpdate)) } // SendWindowSize mocks base method diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/connection_id.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/connection_id.go new file mode 100644 index 00000000..f99461b2 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/connection_id.go @@ -0,0 +1,69 @@ +package protocol + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" +) + +// A ConnectionID in QUIC +type ConnectionID []byte + +const maxConnectionIDLen = 18 + +// GenerateConnectionID generates a connection ID using cryptographic random +func GenerateConnectionID(len int) (ConnectionID, error) { + b := make([]byte, len) + if _, err := rand.Read(b); err != nil { + return nil, err + } + return ConnectionID(b), nil +} + +// GenerateConnectionIDForInitial generates a connection ID for the Initial packet. +// It uses a length randomly chosen between 8 and 18 bytes. +func GenerateConnectionIDForInitial() (ConnectionID, error) { + r := make([]byte, 1) + if _, err := rand.Read(r); err != nil { + return nil, err + } + len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1) + return GenerateConnectionID(len) +} + +// ReadConnectionID reads a connection ID of length len from the given io.Reader. +// It returns io.EOF if there are not enough bytes to read. +func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) { + if len == 0 { + return nil, nil + } + c := make(ConnectionID, len) + _, err := io.ReadFull(r, c) + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return c, err +} + +// Equal says if two connection IDs are equal +func (c ConnectionID) Equal(other ConnectionID) bool { + return bytes.Equal(c, other) +} + +// Len returns the length of the connection ID in bytes +func (c ConnectionID) Len() int { + return len(c) +} + +// Bytes returns the byte representation +func (c ConnectionID) Bytes() []byte { + return []byte(c) +} + +func (c ConnectionID) String() string { + if c.Len() == 0 { + return "(empty)" + } + return fmt.Sprintf("%#x", c.Bytes()) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go index 4bc8bfc9..41f002f5 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go @@ -1,8 +1,25 @@ package protocol // InferPacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number -func InferPacketNumber(packetNumberLength PacketNumberLen, lastPacketNumber PacketNumber, wirePacketNumber PacketNumber) PacketNumber { - epochDelta := PacketNumber(1) << (uint8(packetNumberLength) * 8) +func InferPacketNumber( + packetNumberLength PacketNumberLen, + lastPacketNumber PacketNumber, + wirePacketNumber PacketNumber, + version VersionNumber, +) PacketNumber { + var epochDelta PacketNumber + if version.UsesVarintPacketNumbers() { + switch packetNumberLength { + case PacketNumberLen1: + epochDelta = PacketNumber(1) << 7 + case PacketNumberLen2: + epochDelta = PacketNumber(1) << 14 + case PacketNumberLen4: + epochDelta = PacketNumber(1) << 30 + } + } else { + epochDelta = PacketNumber(1) << (uint8(packetNumberLength) * 8) + } epoch := lastPacketNumber & ^(epochDelta - 1) prevEpochBegin := epoch - epochDelta nextEpochBegin := epoch + epochDelta @@ -29,9 +46,10 @@ func delta(a, b PacketNumber) PacketNumber { // GetPacketNumberLengthForHeader gets the length of the packet number for the public header // it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances -func GetPacketNumberLengthForHeader(packetNumber PacketNumber, leastUnacked PacketNumber) PacketNumberLen { +func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen { diff := uint64(packetNumber - leastUnacked) - if diff < (1 << (uint8(PacketNumberLen2)*8 - 1)) { + if version.UsesVarintPacketNumbers() && diff < (1<<(14-1)) || + !version.UsesVarintPacketNumbers() && diff < (1<<(16-1)) { return PacketNumberLen2 } return PacketNumberLen4 diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go index 948e371a..43358fec 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go @@ -9,6 +9,11 @@ const ( PerspectiveClient Perspective = 2 ) +// Opposite returns the perspective of the peer +func (p Perspective) Opposite() Perspective { + return 3 - p +} + func (p Perspective) String() string { switch p { case PerspectiveServer: diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go index 2821d2cd..2a52f895 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go @@ -52,9 +52,6 @@ func (t PacketType) String() string { } } -// A ConnectionID in QUIC -type ConnectionID uint64 - // A ByteCount in QUIC type ByteCount uint64 @@ -85,3 +82,9 @@ const MinInitialPacketSize = 1200 // * one failure due to an incorrect or missing source-address token // * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token const MaxClientHellos = 3 + +// ConnectionIDLenGQUIC is the length of the source Connection ID used on gQUIC QUIC packets. +const ConnectionIDLenGQUIC = 8 + +// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet. +const MinConnectionIDLenInitial = 8 diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go index 96ebbdcd..aa61b3d7 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go @@ -12,11 +12,13 @@ const MaxPacketSizeIPv6 = 1232 // This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames const NonForwardSecurePacketSizeReduction = 50 +const defaultMaxCongestionWindowPackets = 1000 + // DefaultMaxCongestionWindow is the default for the max congestion window -const DefaultMaxCongestionWindow = 1000 +const DefaultMaxCongestionWindow ByteCount = defaultMaxCongestionWindowPackets * DefaultTCPMSS // InitialCongestionWindow is the initial congestion window in QUIC packets -const InitialCongestionWindow = 32 +const InitialCongestionWindow ByteCount = 32 * DefaultTCPMSS // MaxUndecryptablePackets limits the number of undecryptable packets that a // session queues for later until it sends a public reset. @@ -70,7 +72,7 @@ const MaxStreamsMultiplier = 1.1 const MaxStreamsMinimumIncrement = 10 // MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed. -const MaxSessionUnprocessedPackets = DefaultMaxCongestionWindow +const MaxSessionUnprocessedPackets = defaultMaxCongestionWindowPackets // SkipPacketAveragePeriodLength is the average period length in which one packet number is skipped to prevent an Optimistic ACK attack const SkipPacketAveragePeriodLength PacketNumber = 500 @@ -84,7 +86,7 @@ const CookieExpiryTime = 24 * time.Hour // MaxOutstandingSentPackets is maximum number of packets saved for retransmission. // When reached, it imposes a soft limit on sending new packets: // Sending ACKs and retransmission is still allowed, but now new regular packets can be sent. -const MaxOutstandingSentPackets = 2 * DefaultMaxCongestionWindow +const MaxOutstandingSentPackets = 2 * defaultMaxCongestionWindowPackets // MaxTrackedSentPackets is maximum number of sent packets saved for retransmission. // When reached, no more packets will be sent. @@ -92,7 +94,7 @@ const MaxOutstandingSentPackets = 2 * DefaultMaxCongestionWindow const MaxTrackedSentPackets = MaxOutstandingSentPackets * 5 / 4 // MaxTrackedReceivedAckRanges is the maximum number of ACK ranges tracked -const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow +const MaxTrackedReceivedAckRanges = defaultMaxCongestionWindowPackets // MaxNonRetransmittableAcks is the maximum number of packets containing an ACK, but no retransmittable frames, that we send in a row const MaxNonRetransmittableAcks = 19 @@ -133,7 +135,20 @@ const NumCachedCertificates = 128 // 2. it reduces the head-of-line blocking, when a packet is lost const MinStreamFrameSize ByteCount = 128 +// MaxAckFrameSize is the maximum size for an (IETF QUIC) ACK frame that we write +// Due to the varint encoding, ACK frames can grow (almost) indefinitely large. +// The MaxAckFrameSize should be large enough to encode many ACK range, +// but must ensure that a maximum size ACK frame fits into one packet. +const MaxAckFrameSize ByteCount = 1000 + // MinPacingDelay is the minimum duration that is used for packet pacing // If the packet packing frequency is higher, multiple packets might be sent at once. // Example: For a packet pacing delay of 20 microseconds, we would send 5 packets at once, wait for 100 microseconds, and so forth. const MinPacingDelay time.Duration = 100 * time.Microsecond + +// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections +// if no other value is configured. +const DefaultConnectionIDLength = 4 + +// MaxRetries is the maximum number of Retries a client will do before failing the connection. +const MaxRetries = 3 diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go index d5f2f37b..9e1963dc 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go @@ -18,26 +18,32 @@ const ( // The version numbers, making grepping easier const ( - Version39 VersionNumber = gquicVersion0 + 3*0x100 + 0x9 + iota + Version39 VersionNumber = gquicVersion0 + 3*0x100 + 0x9 + Version43 VersionNumber = gquicVersion0 + 4*0x100 + 0x3 + Version44 VersionNumber = gquicVersion0 + 4*0x100 + 0x4 VersionTLS VersionNumber = 101 VersionWhatever VersionNumber = 0 // for when the version doesn't matter VersionUnknown VersionNumber = math.MaxUint32 + + VersionMilestone0_10_0 VersionNumber = 0x51474f02 ) // SupportedVersions lists the versions that the server supports // must be in sorted descending order var SupportedVersions = []VersionNumber{ + Version44, + Version43, Version39, } // IsValidVersion says if the version is known to quic-go func IsValidVersion(v VersionNumber) bool { - return v == VersionTLS || IsSupportedVersion(SupportedVersions, v) + return v == VersionTLS || v == VersionMilestone0_10_0 || IsSupportedVersion(SupportedVersions, v) } // UsesTLS says if this QUIC version uses TLS 1.3 for the handshake func (vn VersionNumber) UsesTLS() bool { - return vn == VersionTLS + return !vn.isGQUIC() } func (vn VersionNumber) String() string { @@ -46,6 +52,8 @@ func (vn VersionNumber) String() string { return "whatever" case VersionUnknown: return "unknown" + case VersionMilestone0_10_0: + return "quic-go Milestone 0.10.0" case VersionTLS: return "TLS dev version (WIP)" default: @@ -74,12 +82,32 @@ func (vn VersionNumber) CryptoStreamID() StreamID { // UsesIETFFrameFormat tells if this version uses the IETF frame format func (vn VersionNumber) UsesIETFFrameFormat() bool { - return vn != Version39 + return !vn.isGQUIC() +} + +// UsesIETFHeaderFormat tells if this version uses the IETF header format +func (vn VersionNumber) UsesIETFHeaderFormat() bool { + return !vn.isGQUIC() || vn >= Version44 +} + +// UsesLengthInHeader tells if this version uses the Length field in the IETF header +func (vn VersionNumber) UsesLengthInHeader() bool { + return !vn.isGQUIC() +} + +// UsesTokenInHeader tells if this version uses the Token field in the IETF header +func (vn VersionNumber) UsesTokenInHeader() bool { + return !vn.isGQUIC() } // UsesStopWaitingFrames tells if this version uses STOP_WAITING frames func (vn VersionNumber) UsesStopWaitingFrames() bool { - return vn == Version39 + return vn.isGQUIC() && vn <= Version43 +} + +// UsesVarintPacketNumbers tells if this version uses 7/14/30 bit packet numbers +func (vn VersionNumber) UsesVarintPacketNumbers() bool { + return !vn.isGQUIC() } // StreamContributesToConnectionFlowControl says if a stream contributes to connection-level flow control @@ -144,3 +172,14 @@ func GetGreasedVersions(supported []VersionNumber) []VersionNumber { copy(greased[randPos+1:], supported[randPos:]) return greased } + +// StripGreasedVersions strips all greased versions from a slice of versions +func StripGreasedVersions(versions []VersionNumber) []VersionNumber { + realVersions := make([]VersionNumber, 0, len(versions)) + for _, v := range versions { + if v&0x0f0f0f0f != 0x0a0a0a0a { + realVersions = append(realVersions, v) + } + } + return realVersions +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/connection_id.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/connection_id.go deleted file mode 100644 index b4af4e78..00000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/connection_id.go +++ /dev/null @@ -1,18 +0,0 @@ -package utils - -import ( - "crypto/rand" - "encoding/binary" - - "github.com/lucas-clemente/quic-go/internal/protocol" -) - -// GenerateConnectionID generates a connection ID using cryptographic random -func GenerateConnectionID() (protocol.ConnectionID, error) { - b := make([]byte, 8) - _, err := rand.Read(b) - if err != nil { - return 0, err - } - return protocol.ConnectionID(binary.LittleEndian.Uint64(b)), nil -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go index 62a3d075..e27f01b4 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go @@ -28,6 +28,7 @@ const logEnv = "QUIC_GO_LOG_LEVEL" type Logger interface { SetLogLevel(LogLevel) SetLogTimeFormat(format string) + WithPrefix(prefix string) Logger Debug() bool Errorf(format string, args ...interface{}) @@ -39,6 +40,8 @@ type Logger interface { var DefaultLogger Logger type defaultLogger struct { + prefix string + logLevel LogLevel timeFormat string } @@ -79,10 +82,25 @@ func (l *defaultLogger) Errorf(format string, args ...interface{}) { } func (l *defaultLogger) logMessage(format string, args ...interface{}) { + var pre string + if len(l.timeFormat) > 0 { - log.Printf(time.Now().Format(l.timeFormat)+" "+format, args...) - } else { - log.Printf(format, args...) + pre = time.Now().Format(l.timeFormat) + " " + } + if len(l.prefix) > 0 { + pre += l.prefix + " " + } + log.Printf(pre+format, args...) +} + +func (l *defaultLogger) WithPrefix(prefix string) Logger { + if len(l.prefix) > 0 { + prefix = l.prefix + " " + prefix + } + return &defaultLogger{ + logLevel: l.logLevel, + timeFormat: l.timeFormat, + prefix: prefix, } } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go index ef71c7fa..4394ab04 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go @@ -82,6 +82,14 @@ func MinByteCount(a, b protocol.ByteCount) protocol.ByteCount { return b } +// MaxByteCount returns the maximum of two ByteCounts +func MaxByteCount(a, b protocol.ByteCount) protocol.ByteCount { + if a < b { + return b + } + return a +} + // MaxDuration returns the max duration func MaxDuration(a, b time.Duration) time.Duration { if a > b { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/varint_packetnumber.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/varint_packetnumber.go new file mode 100644 index 00000000..b05afd4b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/varint_packetnumber.go @@ -0,0 +1,50 @@ +package utils + +import ( + "bytes" + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// ReadVarIntPacketNumber reads a number in the QUIC varint packet number format +func ReadVarIntPacketNumber(b *bytes.Reader) (protocol.PacketNumber, protocol.PacketNumberLen, error) { + b1, err := b.ReadByte() + if err != nil { + return 0, 0, err + } + if b1&0x80 == 0 { + return protocol.PacketNumber(b1), protocol.PacketNumberLen1, nil + } + b2, err := b.ReadByte() + if err != nil { + return 0, 0, err + } + if b1&0x40 == 0 { + return protocol.PacketNumber(uint64(b1&0x3f)<<8 + uint64(b2)), protocol.PacketNumberLen2, nil + } + b3, err := b.ReadByte() + if err != nil { + return 0, 0, err + } + b4, err := b.ReadByte() + if err != nil { + return 0, 0, err + } + return protocol.PacketNumber(uint64(b1&0x3f)<<24 + uint64(b2)<<16 + uint64(b3)<<8 + uint64(b4)), protocol.PacketNumberLen4, nil +} + +// WriteVarIntPacketNumber writes a packet number in the QUIC varint packet number format +func WriteVarIntPacketNumber(b *bytes.Buffer, i protocol.PacketNumber, len protocol.PacketNumberLen) error { + switch len { + case protocol.PacketNumberLen1: + b.WriteByte(uint8(i & 0x7f)) + case protocol.PacketNumberLen2: + b.Write([]byte{(uint8(i>>8) & 0x3f) | 0x80, uint8(i)}) + case protocol.PacketNumberLen4: + b.Write([]byte{(uint8(i>>24) & 0x3f) | 0xc0, uint8(i >> 16), uint8(i >> 8), uint8(i)}) + default: + return fmt.Errorf("invalid packet number length: %d", len) + } + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go index 1a95bb9f..00759db4 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go @@ -3,6 +3,7 @@ package wire import ( "bytes" "errors" + "sort" "time" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -14,18 +15,20 @@ const ackDelayExponent = 3 // An AckFrame is an ACK frame type AckFrame struct { - LargestAcked protocol.PacketNumber - LowestAcked protocol.PacketNumber - AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last + AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last + DelayTime time.Duration +} - // time when the LargestAcked was receiveid - // this field will not be set for received ACKs frames - PacketReceivedTime time.Time - DelayTime time.Duration +func parseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) { + return parseAckOrAckEcnFrame(r, false, version) +} + +func parseAckEcnFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) { + return parseAckOrAckEcnFrame(r, true, version) } // parseAckFrame reads an ACK frame -func parseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) { +func parseAckOrAckEcnFrame(r *bytes.Reader, ecn bool, version protocol.VersionNumber) (*AckFrame, error) { if !version.UsesIETFFrameFormat() { return parseAckFrameLegacy(r, version) } @@ -36,16 +39,25 @@ func parseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, frame := &AckFrame{} - largestAcked, err := utils.ReadVarInt(r) + la, err := utils.ReadVarInt(r) if err != nil { return nil, err } - frame.LargestAcked = protocol.PacketNumber(largestAcked) + largestAcked := protocol.PacketNumber(la) delay, err := utils.ReadVarInt(r) if err != nil { return nil, err } frame.DelayTime = time.Duration(delay*1< frame.LargestAcked { + if ackBlock > largestAcked { return nil, errors.New("invalid first ACK range") } - smallest := frame.LargestAcked - ackBlock + smallest := largestAcked - ackBlock // read all the other ACK ranges - if numBlocks > 0 { - frame.AckRanges = append(frame.AckRanges, AckRange{First: smallest, Last: frame.LargestAcked}) - } + frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked}) for i := uint64(0); i < numBlocks; i++ { g, err := utils.ReadVarInt(r) if err != nil { @@ -87,14 +97,12 @@ func parseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, return nil, errInvalidAckRanges } smallest = largest - ackBlock - frame.AckRanges = append(frame.AckRanges, AckRange{First: smallest, Last: largest}) + frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest}) } - frame.LowestAcked = smallest if !frame.validateAckRanges() { return nil, errInvalidAckRanges } - return frame, nil } @@ -104,36 +112,22 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error return f.writeLegacy(b, version) } - b.WriteByte(0xe) - utils.WriteVarInt(b, uint64(f.LargestAcked)) + b.WriteByte(0x0d) + utils.WriteVarInt(b, uint64(f.LargestAcked())) utils.WriteVarInt(b, encodeAckDelay(f.DelayTime)) - // TODO: limit the number of ACK ranges, such that the frame doesn't grow larger than an upper bound - var lowestInFirstRange protocol.PacketNumber - if f.HasMissingRanges() { - utils.WriteVarInt(b, uint64(len(f.AckRanges)-1)) - lowestInFirstRange = f.AckRanges[0].First - } else { - utils.WriteVarInt(b, 0) - lowestInFirstRange = f.LowestAcked - } + numRanges := f.numEncodableAckRanges() + utils.WriteVarInt(b, uint64(numRanges-1)) // write the first range - utils.WriteVarInt(b, uint64(f.LargestAcked-lowestInFirstRange)) + _, firstRange := f.encodeAckRange(0) + utils.WriteVarInt(b, firstRange) // write all the other range - if !f.HasMissingRanges() { - return nil - } - var lowest protocol.PacketNumber - for i, ackRange := range f.AckRanges { - if i == 0 { - lowest = lowestInFirstRange - continue - } - utils.WriteVarInt(b, uint64(lowest-ackRange.Last-2)) - utils.WriteVarInt(b, uint64(ackRange.Last-ackRange.First)) - lowest = ackRange.First + for i := 1; i < numRanges; i++ { + gap, len := f.encodeAckRange(i) + utils.WriteVarInt(b, gap) + utils.WriteVarInt(b, len) } return nil } @@ -144,56 +138,62 @@ func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount { return f.lengthLegacy(version) } - length := 1 + utils.VarIntLen(uint64(f.LargestAcked)) + utils.VarIntLen(encodeAckDelay(f.DelayTime)) + largestAcked := f.AckRanges[0].Largest + numRanges := f.numEncodableAckRanges() - var lowestInFirstRange protocol.PacketNumber - if f.HasMissingRanges() { - length += utils.VarIntLen(uint64(len(f.AckRanges) - 1)) - lowestInFirstRange = f.AckRanges[0].First - } else { - length += utils.VarIntLen(0) - lowestInFirstRange = f.LowestAcked - } - length += utils.VarIntLen(uint64(f.LargestAcked - lowestInFirstRange)) + length := 1 + utils.VarIntLen(uint64(largestAcked)) + utils.VarIntLen(encodeAckDelay(f.DelayTime)) - if !f.HasMissingRanges() { - return length - } - var lowest protocol.PacketNumber - for i, ackRange := range f.AckRanges { - if i == 0 { - lowest = ackRange.First - continue - } - length += utils.VarIntLen(uint64(lowest - ackRange.Last - 2)) - length += utils.VarIntLen(uint64(ackRange.Last - ackRange.First)) - lowest = ackRange.First + length += utils.VarIntLen(uint64(numRanges - 1)) + lowestInFirstRange := f.AckRanges[0].Smallest + length += utils.VarIntLen(uint64(largestAcked - lowestInFirstRange)) + + for i := 1; i < numRanges; i++ { + gap, len := f.encodeAckRange(i) + length += utils.VarIntLen(gap) + length += utils.VarIntLen(len) } return length } +// gets the number of ACK ranges that can be encoded +// such that the resulting frame is smaller than the maximum ACK frame size +func (f *AckFrame) numEncodableAckRanges() int { + length := 1 + utils.VarIntLen(uint64(f.LargestAcked())) + utils.VarIntLen(encodeAckDelay(f.DelayTime)) + length += 2 // assume that the number of ranges will consume 2 bytes + for i := 1; i < len(f.AckRanges); i++ { + gap, len := f.encodeAckRange(i) + rangeLen := utils.VarIntLen(gap) + utils.VarIntLen(len) + if length+rangeLen > protocol.MaxAckFrameSize { + // Writing range i would exceed the MaxAckFrameSize. + // So encode one range less than that. + return i - 1 + } + length += rangeLen + } + return len(f.AckRanges) +} + +func (f *AckFrame) encodeAckRange(i int) (uint64 /* gap */, uint64 /* length */) { + if i == 0 { + return 0, uint64(f.AckRanges[0].Largest - f.AckRanges[0].Smallest) + } + return uint64(f.AckRanges[i-1].Smallest - f.AckRanges[i].Largest - 2), + uint64(f.AckRanges[i].Largest - f.AckRanges[i].Smallest) +} + // HasMissingRanges returns if this frame reports any missing packets func (f *AckFrame) HasMissingRanges() bool { - return len(f.AckRanges) > 0 + return len(f.AckRanges) > 1 } func (f *AckFrame) validateAckRanges() bool { if len(f.AckRanges) == 0 { - return true - } - - // if there are missing packets, there will always be at least 2 ACK ranges - if len(f.AckRanges) == 1 { - return false - } - - if f.AckRanges[0].Last != f.LargestAcked { return false } // check the validity of every single ACK range for _, ackRange := range f.AckRanges { - if ackRange.First > ackRange.Last { + if ackRange.Smallest > ackRange.Largest { return false } } @@ -204,10 +204,10 @@ func (f *AckFrame) validateAckRanges() bool { continue } lastAckRange := f.AckRanges[i-1] - if lastAckRange.First <= ackRange.First { + if lastAckRange.Smallest <= ackRange.Smallest { return false } - if lastAckRange.First <= ackRange.Last+1 { + if lastAckRange.Smallest <= ackRange.Largest+1 { return false } } @@ -215,23 +215,27 @@ func (f *AckFrame) validateAckRanges() bool { return true } +// LargestAcked is the largest acked packet number +func (f *AckFrame) LargestAcked() protocol.PacketNumber { + return f.AckRanges[0].Largest +} + +// LowestAcked is the lowest acked packet number +func (f *AckFrame) LowestAcked() protocol.PacketNumber { + return f.AckRanges[len(f.AckRanges)-1].Smallest +} + // AcksPacket determines if this ACK frame acks a certain packet number func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool { - if p < f.LowestAcked || p > f.LargestAcked { // this is just a performance optimization + if p < f.LowestAcked() || p > f.LargestAcked() { return false } - if f.HasMissingRanges() { - // TODO: this could be implemented as a binary search - for _, ackRange := range f.AckRanges { - if p >= ackRange.First && p <= ackRange.Last { - return true - } - } - return false - } - // if packet doesn't have missing ranges - return (p >= f.LowestAcked && p <= f.LargestAcked) + i := sort.Search(len(f.AckRanges), func(i int) bool { + return p >= f.AckRanges[i].Smallest + }) + // i will always be < len(f.AckRanges), since we checked above that p is not bigger than the largest acked + return p <= f.AckRanges[i].Largest } func encodeAckDelay(delay time.Duration) uint64 { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame_legacy.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame_legacy.go index 1f1c22e9..c2a71e0b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame_legacy.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame_legacy.go @@ -9,11 +9,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) -var ( - errInconsistentAckLargestAcked = errors.New("internal inconsistency: LargestAcked does not match ACK ranges") - errInconsistentAckLowestAcked = errors.New("internal inconsistency: LowestAcked does not match ACK ranges") - errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges") -) +var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges") func parseAckFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (*AckFrame, error) { frame := &AckFrame{} @@ -23,11 +19,7 @@ func parseAckFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (*AckFrame, return nil, err } - hasMissingRanges := false - if typeByte&0x20 == 0x20 { - hasMissingRanges = true - } - + hasMissingRanges := typeByte&0x20 == 0x20 largestAckedLen := 2 * ((typeByte & 0x0C) >> 2) if largestAckedLen == 0 { largestAckedLen = 1 @@ -38,11 +30,11 @@ func parseAckFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (*AckFrame, missingSequenceNumberDeltaLen = 1 } - largestAcked, err := utils.BigEndian.ReadUintN(r, largestAckedLen) + la, err := utils.BigEndian.ReadUintN(r, largestAckedLen) if err != nil { return nil, err } - frame.LargestAcked = protocol.PacketNumber(largestAcked) + largestAcked := protocol.PacketNumber(la) delay, err := utils.BigEndian.ReadUfloat16(r) if err != nil { @@ -62,11 +54,12 @@ func parseAckFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (*AckFrame, return nil, errInvalidAckRanges } - ackBlockLength, err := utils.BigEndian.ReadUintN(r, missingSequenceNumberDeltaLen) + abl, err := utils.BigEndian.ReadUintN(r, missingSequenceNumberDeltaLen) if err != nil { return nil, err } - if frame.LargestAcked > 0 && ackBlockLength < 1 { + ackBlockLength := protocol.PacketNumber(abl) + if largestAcked > 0 && ackBlockLength < 1 { return nil, errors.New("invalid first ACK range") } @@ -76,8 +69,8 @@ func parseAckFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (*AckFrame, if hasMissingRanges { ackRange := AckRange{ - First: protocol.PacketNumber(largestAcked-ackBlockLength) + 1, - Last: frame.LargestAcked, + Smallest: largestAcked - ackBlockLength + 1, + Largest: largestAcked, } frame.AckRanges = append(frame.AckRanges, ackRange) @@ -90,29 +83,27 @@ func parseAckFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (*AckFrame, return nil, err } - ackBlockLength, err = utils.BigEndian.ReadUintN(r, missingSequenceNumberDeltaLen) + abl, err := utils.BigEndian.ReadUintN(r, missingSequenceNumberDeltaLen) if err != nil { return nil, err } - - length := protocol.PacketNumber(ackBlockLength) + ackBlockLength := protocol.PacketNumber(abl) if inLongBlock { - frame.AckRanges[len(frame.AckRanges)-1].First -= protocol.PacketNumber(gap) + length - frame.AckRanges[len(frame.AckRanges)-1].Last -= protocol.PacketNumber(gap) + frame.AckRanges[len(frame.AckRanges)-1].Smallest -= protocol.PacketNumber(gap) + ackBlockLength + frame.AckRanges[len(frame.AckRanges)-1].Largest -= protocol.PacketNumber(gap) } else { lastRangeComplete = false ackRange := AckRange{ - Last: frame.AckRanges[len(frame.AckRanges)-1].First - protocol.PacketNumber(gap) - 1, + Largest: frame.AckRanges[len(frame.AckRanges)-1].Smallest - protocol.PacketNumber(gap) - 1, } - ackRange.First = ackRange.Last - length + 1 + ackRange.Smallest = ackRange.Largest - ackBlockLength + 1 frame.AckRanges = append(frame.AckRanges, ackRange) } - if length > 0 { + if ackBlockLength > 0 { lastRangeComplete = true } - inLongBlock = (ackBlockLength == 0) } @@ -121,13 +112,11 @@ func parseAckFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (*AckFrame, if !lastRangeComplete { frame.AckRanges = frame.AckRanges[:len(frame.AckRanges)-1] } - - frame.LowestAcked = frame.AckRanges[len(frame.AckRanges)-1].First } else { - if frame.LargestAcked == 0 { - frame.LowestAcked = 0 - } else { - frame.LowestAcked = protocol.PacketNumber(largestAcked + 1 - ackBlockLength) + frame.AckRanges = make([]AckRange, 1) + if largestAcked != 0 { + frame.AckRanges[0].Largest = largestAcked + frame.AckRanges[0].Smallest = largestAcked + 1 - ackBlockLength } } @@ -171,7 +160,8 @@ func parseAckFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (*AckFrame, } func (f *AckFrame) writeLegacy(b *bytes.Buffer, _ protocol.VersionNumber) error { - largestAckedLen := protocol.GetPacketNumberLength(f.LargestAcked) + largestAcked := f.LargestAcked() + largestAckedLen := protocol.GetPacketNumberLength(largestAcked) typeByte := uint8(0x40) @@ -192,16 +182,15 @@ func (f *AckFrame) writeLegacy(b *bytes.Buffer, _ protocol.VersionNumber) error switch largestAckedLen { case protocol.PacketNumberLen1: - b.WriteByte(uint8(f.LargestAcked)) + b.WriteByte(uint8(largestAcked)) case protocol.PacketNumberLen2: - utils.BigEndian.WriteUint16(b, uint16(f.LargestAcked)) + utils.BigEndian.WriteUint16(b, uint16(largestAcked)) case protocol.PacketNumberLen4: - utils.BigEndian.WriteUint32(b, uint32(f.LargestAcked)) + utils.BigEndian.WriteUint32(b, uint32(largestAcked)) case protocol.PacketNumberLen6: - utils.BigEndian.WriteUint48(b, uint64(f.LargestAcked)&(1<<48-1)) + utils.BigEndian.WriteUint48(b, uint64(largestAcked)&(1<<48-1)) } - f.DelayTime = time.Since(f.PacketReceivedTime) utils.BigEndian.WriteUfloat16(b, uint64(f.DelayTime/time.Microsecond)) var numRanges uint64 @@ -216,15 +205,9 @@ func (f *AckFrame) writeLegacy(b *bytes.Buffer, _ protocol.VersionNumber) error var firstAckBlockLength protocol.PacketNumber if !f.HasMissingRanges() { - firstAckBlockLength = f.LargestAcked - f.LowestAcked + 1 + firstAckBlockLength = largestAcked - f.LowestAcked() + 1 } else { - if f.LargestAcked != f.AckRanges[0].Last { - return errInconsistentAckLargestAcked - } - if f.LowestAcked != f.AckRanges[len(f.AckRanges)-1].First { - return errInconsistentAckLowestAcked - } - firstAckBlockLength = f.LargestAcked - f.AckRanges[0].First + 1 + firstAckBlockLength = largestAcked - f.AckRanges[0].Smallest + 1 numRangesWritten++ } @@ -244,8 +227,8 @@ func (f *AckFrame) writeLegacy(b *bytes.Buffer, _ protocol.VersionNumber) error continue } - length := ackRange.Last - ackRange.First + 1 - gap := f.AckRanges[i-1].First - ackRange.Last - 1 + length := ackRange.Largest - ackRange.Smallest + 1 + gap := f.AckRanges[i-1].Smallest - ackRange.Largest - 1 num := gap/0xFF + 1 if gap%0xFF == 0 { @@ -310,7 +293,7 @@ func (f *AckFrame) writeLegacy(b *bytes.Buffer, _ protocol.VersionNumber) error func (f *AckFrame) lengthLegacy(_ protocol.VersionNumber) protocol.ByteCount { length := protocol.ByteCount(1 + 2 + 1) // 1 TypeByte, 2 ACK delay time, 1 Num Timestamp - length += protocol.ByteCount(protocol.GetPacketNumberLength(f.LargestAcked)) + length += protocol.ByteCount(protocol.GetPacketNumberLength(f.LargestAcked())) missingSequenceNumberDeltaLen := protocol.ByteCount(f.getMissingSequenceNumberDeltaLen()) @@ -337,7 +320,7 @@ func (f *AckFrame) numWritableNackRanges() uint64 { } lastAckRange := f.AckRanges[i-1] - gap := lastAckRange.First - ackRange.Last - 1 + gap := lastAckRange.Smallest - ackRange.Largest - 1 rangeLength := 1 + uint64(gap)/0xFF if uint64(gap)%0xFF == 0 { rangeLength-- @@ -358,13 +341,13 @@ func (f *AckFrame) getMissingSequenceNumberDeltaLen() protocol.PacketNumberLen { if f.HasMissingRanges() { for _, ackRange := range f.AckRanges { - rangeLength := ackRange.Last - ackRange.First + 1 + rangeLength := ackRange.Largest - ackRange.Smallest + 1 if rangeLength > maxRangeLength { maxRangeLength = rangeLength } } } else { - maxRangeLength = f.LargestAcked - f.LowestAcked + 1 + maxRangeLength = f.LargestAcked() - f.LowestAcked() + 1 } if maxRangeLength <= 0xFF { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go index 783528e6..0f418580 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go @@ -4,11 +4,11 @@ import "github.com/lucas-clemente/quic-go/internal/protocol" // AckRange is an ACK range type AckRange struct { - First protocol.PacketNumber - Last protocol.PacketNumber + Smallest protocol.PacketNumber + Largest protocol.PacketNumber } // Len returns the number of packets contained in this ACK range func (r AckRange) Len() protocol.PacketNumber { - return r.Last - r.First + 1 + return r.Largest - r.Smallest + 1 } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go index 6911446f..ce558e36 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "errors" "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -11,19 +12,19 @@ import ( // ParseNextFrame parses the next frame // It skips PADDING frames. func ParseNextFrame(r *bytes.Reader, hdr *Header, v protocol.VersionNumber) (Frame, error) { - if r.Len() == 0 { - return nil, nil - } - typeByte, _ := r.ReadByte() - if typeByte == 0x0 { // PADDING frame - return ParseNextFrame(r, hdr, v) - } - r.UnreadByte() + for r.Len() != 0 { + typeByte, _ := r.ReadByte() + if typeByte == 0x0 { // PADDING frame + continue + } + r.UnreadByte() - if !v.UsesIETFFrameFormat() { - return parseGQUICFrame(r, typeByte, hdr, v) + if !v.UsesIETFFrameFormat() { + return parseGQUICFrame(r, typeByte, hdr, v) + } + return parseIETFFrame(r, typeByte, v) } - return parseIETFFrame(r, typeByte, v) + return nil, nil } func parseIETFFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame, error) { @@ -85,11 +86,26 @@ func parseIETFFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (F if err != nil { err = qerr.Error(qerr.InvalidFrameData, err.Error()) } - case 0xe: + case 0xd: frame, err = parseAckFrame(r, v) if err != nil { err = qerr.Error(qerr.InvalidAckData, err.Error()) } + case 0xe: + frame, err = parsePathChallengeFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidFrameData, err.Error()) + } + case 0xf: + frame, err = parsePathResponseFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidFrameData, err.Error()) + } + case 0x1a: + frame, err = parseAckEcnFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidAckData, err.Error()) + } default: err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte)) } @@ -139,6 +155,10 @@ func parseGQUICFrame(r *bytes.Reader, typeByte byte, hdr *Header, v protocol.Ver err = qerr.Error(qerr.InvalidBlockedData, err.Error()) } case 0x6: + if !v.UsesStopWaitingFrames() { + err = errors.New("STOP_WAITING frames not supported by this QUIC version") + break + } frame, err = parseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, v) if err != nil { err = qerr.Error(qerr.InvalidStopWaitingData, err.Error()) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go index fc346f3f..41a322b8 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go @@ -2,6 +2,9 @@ package wire import ( "bytes" + "crypto/rand" + "errors" + "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -10,12 +13,18 @@ import ( // Header is the header of a QUIC packet. // It contains fields that are only needed for the gQUIC Public Header and the IETF draft Header. type Header struct { - Raw []byte - ConnectionID protocol.ConnectionID - OmitConnectionID bool - PacketNumberLen protocol.PacketNumberLen - PacketNumber protocol.PacketNumber - Version protocol.VersionNumber // VersionNumber sent by the client + IsPublicHeader bool + + Raw []byte + + Version protocol.VersionNumber + + DestConnectionID protocol.ConnectionID + SrcConnectionID protocol.ConnectionID + OrigDestConnectionID protocol.ConnectionID // only needed in the Retry packet + + PacketNumberLen protocol.PacketNumberLen + PacketNumber protocol.PacketNumber IsVersionNegotiation bool SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server @@ -29,85 +38,296 @@ type Header struct { Type protocol.PacketType IsLongHeader bool KeyPhase int - - // only needed for logging - isPublicHeader bool + PayloadLen protocol.ByteCount + Token []byte } -// ParseHeaderSentByServer parses the header for a packet that was sent by the server. -func ParseHeaderSentByServer(b *bytes.Reader, version protocol.VersionNumber) (*Header, error) { - typeByte, err := b.ReadByte() - if err != nil { - return nil, err - } - _ = b.UnreadByte() // unread the type byte - - var isPublicHeader bool - if typeByte&0x80 > 0 { // gQUIC always has 0x80 unset. IETF Long Header or Version Negotiation - isPublicHeader = false - } else if typeByte&0xcf == 0x9 { // gQUIC Version Negotiation Packet - isPublicHeader = true - } else { - // the client knows the version that this packet was sent with - isPublicHeader = !version.UsesTLS() - } - - return parsePacketHeader(b, protocol.PerspectiveServer, isPublicHeader) -} - -// ParseHeaderSentByClient parses the header for a packet that was sent by the client. -func ParseHeaderSentByClient(b *bytes.Reader) (*Header, error) { - typeByte, err := b.ReadByte() - if err != nil { - return nil, err - } - _ = b.UnreadByte() // unread the type byte - - // In an IETF QUIC packet header - // * either 0x80 is set (for the Long Header) - // * or 0x8 is unset (for the Short Header) - // In a gQUIC Public Header - // * 0x80 is always unset and - // * and 0x8 is always set (this is the Connection ID flag, which the client always sets) - isPublicHeader := typeByte&0x88 == 0x8 - return parsePacketHeader(b, protocol.PerspectiveClient, isPublicHeader) -} - -func parsePacketHeader(b *bytes.Reader, sentBy protocol.Perspective, isPublicHeader bool) (*Header, error) { - // This is a gQUIC Public Header. - if isPublicHeader { - hdr, err := parsePublicHeader(b, sentBy) - if err != nil { - return nil, err - } - hdr.isPublicHeader = true // save that this is a Public Header, so we can log it correctly later - return hdr, nil - } - return parseHeader(b, sentBy) -} +var errInvalidPacketNumberLen = errors.New("invalid packet number length") // Write writes the Header. -func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, version protocol.VersionNumber) error { - if !version.UsesTLS() { - h.isPublicHeader = true // save that this is a Public Header, so we can log it correctly later - return h.writePublicHeader(b, pers, version) +func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error { + if !ver.UsesIETFHeaderFormat() { + h.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later + return h.writePublicHeader(b, pers, ver) } - return h.writeHeader(b) + // write an IETF QUIC header + if h.IsLongHeader { + return h.writeLongHeader(b, ver) + } + return h.writeShortHeader(b, ver) +} + +// TODO: add support for the key phase +func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error { + b.WriteByte(byte(0x80 | h.Type)) + utils.BigEndian.WriteUint32(b, uint32(h.Version)) + connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID) + if err != nil { + return err + } + b.WriteByte(connIDLen) + b.Write(h.DestConnectionID.Bytes()) + b.Write(h.SrcConnectionID.Bytes()) + + if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() { + utils.WriteVarInt(b, uint64(len(h.Token))) + b.Write(h.Token) + } + + if h.Type == protocol.PacketTypeRetry { + odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID) + if err != nil { + return err + } + // randomize the first 4 bits + odcilByte := make([]byte, 1) + _, _ = rand.Read(odcilByte) // it's safe to ignore the error here + odcilByte[0] = (odcilByte[0] & 0xf0) | odcil + b.Write(odcilByte) + b.Write(h.OrigDestConnectionID.Bytes()) + b.Write(h.Token) + return nil + } + + if v.UsesLengthInHeader() { + utils.WriteVarInt(b, uint64(h.PayloadLen)) + } + if v.UsesVarintPacketNumbers() { + return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen) + } + utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) + if h.Type == protocol.PacketType0RTT && v == protocol.Version44 { + if len(h.DiversificationNonce) != 32 { + return errors.New("invalid diversification nonce length") + } + b.Write(h.DiversificationNonce) + } + return nil +} + +func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error { + typeByte := byte(0x30) + typeByte |= byte(h.KeyPhase << 6) + if !v.UsesVarintPacketNumbers() { + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + case protocol.PacketNumberLen2: + typeByte |= 0x1 + case protocol.PacketNumberLen4: + typeByte |= 0x2 + default: + return errInvalidPacketNumberLen + } + } + + b.WriteByte(typeByte) + b.Write(h.DestConnectionID.Bytes()) + + if !v.UsesVarintPacketNumbers() { + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + b.WriteByte(uint8(h.PacketNumber)) + case protocol.PacketNumberLen2: + utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) + case protocol.PacketNumberLen4: + utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) + } + return nil + } + return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen) +} + +// writePublicHeader writes a Public Header. +func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error { + if h.ResetFlag || (h.VersionFlag && pers == protocol.PerspectiveServer) { + return errors.New("PublicHeader: Can only write regular packets") + } + if h.SrcConnectionID.Len() != 0 { + return errors.New("PublicHeader: SrcConnectionID must not be set") + } + if len(h.DestConnectionID) != 0 && len(h.DestConnectionID) != 8 { + return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID)) + } + + publicFlagByte := uint8(0x00) + if h.VersionFlag { + publicFlagByte |= 0x01 + } + if h.DestConnectionID.Len() > 0 { + publicFlagByte |= 0x08 + } + if len(h.DiversificationNonce) > 0 { + if len(h.DiversificationNonce) != 32 { + return errors.New("invalid diversification nonce length") + } + publicFlagByte |= 0x04 + } + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + publicFlagByte |= 0x00 + case protocol.PacketNumberLen2: + publicFlagByte |= 0x10 + case protocol.PacketNumberLen4: + publicFlagByte |= 0x20 + } + b.WriteByte(publicFlagByte) + + if h.DestConnectionID.Len() > 0 { + b.Write(h.DestConnectionID) + } + if h.VersionFlag && pers == protocol.PerspectiveClient { + utils.BigEndian.WriteUint32(b, uint32(h.Version)) + } + if len(h.DiversificationNonce) > 0 { + b.Write(h.DiversificationNonce) + } + + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + b.WriteByte(uint8(h.PacketNumber)) + case protocol.PacketNumberLen2: + utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) + case protocol.PacketNumberLen4: + utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) + case protocol.PacketNumberLen6: + return errInvalidPacketNumberLen + default: + return errors.New("PublicHeader: PacketNumberLen not set") + } + + return nil } // GetLength determines the length of the Header. -func (h *Header) GetLength(pers protocol.Perspective, version protocol.VersionNumber) (protocol.ByteCount, error) { - if !version.UsesTLS() { - return h.getPublicHeaderLength(pers) +func (h *Header) GetLength(v protocol.VersionNumber) (protocol.ByteCount, error) { + if !v.UsesIETFHeaderFormat() { + return h.getPublicHeaderLength() } - return h.getHeaderLength() + return h.getHeaderLength(v) +} + +func (h *Header) getHeaderLength(v protocol.VersionNumber) (protocol.ByteCount, error) { + if h.IsLongHeader { + length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + if v.UsesLengthInHeader() { + length += utils.VarIntLen(uint64(h.PayloadLen)) + } + if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() { + length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) + } + if h.Type == protocol.PacketType0RTT && v == protocol.Version44 { + length += protocol.ByteCount(len(h.DiversificationNonce)) + } + return length, nil + } + + length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len()) + if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 { + return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) + } + length += protocol.ByteCount(h.PacketNumberLen) + return length, nil +} + +// getPublicHeaderLength gets the length of the publicHeader in bytes. +// It can only be called for regular packets. +func (h *Header) getPublicHeaderLength() (protocol.ByteCount, error) { + length := protocol.ByteCount(1) // 1 byte for public flags + if h.PacketNumberLen == protocol.PacketNumberLen6 { + return 0, errInvalidPacketNumberLen + } + if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 { + return 0, errPacketNumberLenNotSet + } + length += protocol.ByteCount(h.PacketNumberLen) + length += protocol.ByteCount(h.DestConnectionID.Len()) + // Version Number in packets sent by the client + if h.VersionFlag { + length += 4 + } + length += protocol.ByteCount(len(h.DiversificationNonce)) + return length, nil } // Log logs the Header func (h *Header) Log(logger utils.Logger) { - if h.isPublicHeader { + if h.IsPublicHeader { h.logPublicHeader(logger) } else { h.logHeader(logger) } } + +func (h *Header) logHeader(logger utils.Logger) { + if h.IsLongHeader { + if h.Version == 0 { + logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions) + } else { + var token string + if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry { + if len(h.Token) == 0 { + token = "Token: (empty), " + } else { + token = fmt.Sprintf("Token: %#x, ", h.Token) + } + } + if h.Type == protocol.PacketTypeRetry { + logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version) + return + } + if h.Version == protocol.Version44 { + var divNonce string + if h.Type == protocol.PacketType0RTT { + divNonce = fmt.Sprintf("Diversification Nonce: %#x, ", h.DiversificationNonce) + } + logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PacketNumberLen, divNonce, h.Version) + return + } + logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version) + } + } else { + logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) + } +} + +func (h *Header) logPublicHeader(logger utils.Logger) { + ver := "(unset)" + if h.Version != 0 { + ver = h.Version.String() + } + logger.Debugf("\tPublic Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce) +} + +func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) { + dcil, err := encodeSingleConnIDLen(dest) + if err != nil { + return 0, err + } + scil, err := encodeSingleConnIDLen(src) + if err != nil { + return 0, err + } + return scil | dcil<<4, nil +} + +func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) { + len := id.Len() + if len == 0 { + return 0, nil + } + if len < 4 || len > 18 { + return 0, fmt.Errorf("invalid connection ID length: %d bytes", len) + } + return byte(len - 3), nil +} + +func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) { + return decodeSingleConnIDLen(enc >> 4), decodeSingleConnIDLen(enc & 0xf) +} + +func decodeSingleConnIDLen(enc uint8) int { + if enc == 0 { + return 0 + } + return int(enc) + 3 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/header_parser.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/header_parser.go new file mode 100644 index 00000000..08f5b406 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/header_parser.go @@ -0,0 +1,273 @@ +package wire + +import ( + "bytes" + "fmt" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" +) + +// The InvariantHeader is the version independent part of the header +type InvariantHeader struct { + IsLongHeader bool + Version protocol.VersionNumber + SrcConnectionID protocol.ConnectionID + DestConnectionID protocol.ConnectionID + + typeByte byte +} + +// ParseInvariantHeader parses the version independent part of the header +func ParseInvariantHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*InvariantHeader, error) { + typeByte, err := b.ReadByte() + if err != nil { + return nil, err + } + + h := &InvariantHeader{typeByte: typeByte} + h.IsLongHeader = typeByte&0x80 > 0 + + // If this is not a Long Header, it could either be a Public Header or a Short Header. + if !h.IsLongHeader { + // In the Public Header 0x8 is the Connection ID Flag. + // In the IETF Short Header: + // * 0x8 it is the gQUIC Demultiplexing bit, and always 0. + // * 0x20 and 0x10 are always 1. + var connIDLen int + if typeByte&0x8 > 0 { // Public Header containing a connection ID + connIDLen = 8 + } + if typeByte&0x38 == 0x30 { // Short Header + connIDLen = shortHeaderConnIDLen + } + if connIDLen > 0 { + h.DestConnectionID, err = protocol.ReadConnectionID(b, connIDLen) + if err != nil { + return nil, err + } + } + return h, nil + } + // Long Header + v, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, err + } + h.Version = protocol.VersionNumber(v) + connIDLenByte, err := b.ReadByte() + if err != nil { + return nil, err + } + dcil, scil := decodeConnIDLen(connIDLenByte) + h.DestConnectionID, err = protocol.ReadConnectionID(b, dcil) + if err != nil { + return nil, err + } + h.SrcConnectionID, err = protocol.ReadConnectionID(b, scil) + if err != nil { + return nil, err + } + return h, nil +} + +// Parse parses the version dependent part of the header +func (iv *InvariantHeader) Parse(b *bytes.Reader, sentBy protocol.Perspective, ver protocol.VersionNumber) (*Header, error) { + if iv.IsLongHeader { + if iv.Version == 0 { // Version Negotiation Packet + return iv.parseVersionNegotiationPacket(b) + } + return iv.parseLongHeader(b, sentBy, ver) + } + // The Public Header never uses 6 byte packet numbers. + // Therefore, the third and fourth bit will never be 11. + // For the Short Header, the third and fourth bit are always 11. + if iv.typeByte&0x30 != 0x30 { + if sentBy == protocol.PerspectiveServer && iv.typeByte&0x1 > 0 { + return iv.parseVersionNegotiationPacket(b) + } + return iv.parsePublicHeader(b, sentBy, ver) + } + return iv.parseShortHeader(b, ver) +} + +func (iv *InvariantHeader) toHeader() *Header { + return &Header{ + IsLongHeader: iv.IsLongHeader, + DestConnectionID: iv.DestConnectionID, + SrcConnectionID: iv.SrcConnectionID, + Version: iv.Version, + } +} + +func (iv *InvariantHeader) parseVersionNegotiationPacket(b *bytes.Reader) (*Header, error) { + h := iv.toHeader() + h.VersionFlag = true + if b.Len() == 0 { + return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list") + } + h.IsVersionNegotiation = true + h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4) + for i := 0; b.Len() > 0; i++ { + v, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, qerr.InvalidVersionNegotiationPacket + } + h.SupportedVersions[i] = protocol.VersionNumber(v) + } + return h, nil +} + +func (iv *InvariantHeader) parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, v protocol.VersionNumber) (*Header, error) { + h := iv.toHeader() + h.Type = protocol.PacketType(iv.typeByte & 0x7f) + + if h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketType0RTT && h.Type != protocol.PacketTypeHandshake { + return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type)) + } + + if h.Type == protocol.PacketTypeRetry { + odcilByte, err := b.ReadByte() + if err != nil { + return nil, err + } + odcil := decodeSingleConnIDLen(odcilByte & 0xf) + h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil) + if err != nil { + return nil, err + } + h.Token = make([]byte, b.Len()) + if _, err := io.ReadFull(b, h.Token); err != nil { + return nil, err + } + return h, nil + } + + if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() { + tokenLen, err := utils.ReadVarInt(b) + if err != nil { + return nil, err + } + if tokenLen > uint64(b.Len()) { + return nil, io.EOF + } + h.Token = make([]byte, tokenLen) + if _, err := io.ReadFull(b, h.Token); err != nil { + return nil, err + } + } + + if v.UsesLengthInHeader() { + pl, err := utils.ReadVarInt(b) + if err != nil { + return nil, err + } + h.PayloadLen = protocol.ByteCount(pl) + } + if v.UsesVarintPacketNumbers() { + pn, pnLen, err := utils.ReadVarIntPacketNumber(b) + if err != nil { + return nil, err + } + h.PacketNumber = pn + h.PacketNumberLen = pnLen + } else { + pn, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, err + } + h.PacketNumber = protocol.PacketNumber(pn) + h.PacketNumberLen = protocol.PacketNumberLen4 + } + if h.Type == protocol.PacketType0RTT && v == protocol.Version44 && sentBy == protocol.PerspectiveServer { + h.DiversificationNonce = make([]byte, 32) + if _, err := io.ReadFull(b, h.DiversificationNonce); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + } + + return h, nil +} + +func (iv *InvariantHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionNumber) (*Header, error) { + h := iv.toHeader() + h.KeyPhase = int(iv.typeByte&0x40) >> 6 + + if v.UsesVarintPacketNumbers() { + pn, pnLen, err := utils.ReadVarIntPacketNumber(b) + if err != nil { + return nil, err + } + h.PacketNumber = pn + h.PacketNumberLen = pnLen + } else { + switch iv.typeByte & 0x3 { + case 0x0: + h.PacketNumberLen = protocol.PacketNumberLen1 + case 0x1: + h.PacketNumberLen = protocol.PacketNumberLen2 + case 0x2: + h.PacketNumberLen = protocol.PacketNumberLen4 + default: + return nil, errInvalidPacketNumberLen + } + p, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen)) + if err != nil { + return nil, err + } + h.PacketNumber = protocol.PacketNumber(p) + } + return h, nil +} + +func (iv *InvariantHeader) parsePublicHeader(b *bytes.Reader, sentBy protocol.Perspective, ver protocol.VersionNumber) (*Header, error) { + h := iv.toHeader() + h.IsPublicHeader = true + h.ResetFlag = iv.typeByte&0x2 > 0 + if h.ResetFlag { + return h, nil + } + + h.VersionFlag = iv.typeByte&0x1 > 0 + if h.VersionFlag && sentBy == protocol.PerspectiveClient { + v, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, err + } + h.Version = protocol.VersionNumber(v) + } + + // Contrary to what the gQUIC wire spec says, the 0x4 bit only indicates the presence of the diversification nonce for packets sent by the server. + // It doesn't have any meaning when sent by the client. + if sentBy == protocol.PerspectiveServer && iv.typeByte&0x4 > 0 { + h.DiversificationNonce = make([]byte, 32) + if _, err := io.ReadFull(b, h.DiversificationNonce); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + } + + switch iv.typeByte & 0x30 { + case 0x00: + h.PacketNumberLen = protocol.PacketNumberLen1 + case 0x10: + h.PacketNumberLen = protocol.PacketNumberLen2 + case 0x20: + h.PacketNumberLen = protocol.PacketNumberLen4 + } + + pn, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen)) + if err != nil { + return nil, err + } + h.PacketNumber = protocol.PacketNumber(pn) + + return h, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go deleted file mode 100644 index 01bf0a26..00000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go +++ /dev/null @@ -1,187 +0,0 @@ -package wire - -import ( - "bytes" - "errors" - "fmt" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/qerr" -) - -// parseHeader parses the header. -func parseHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Header, error) { - typeByte, err := b.ReadByte() - if err != nil { - return nil, err - } - if typeByte&0x80 > 0 { - return parseLongHeader(b, packetSentBy, typeByte) - } - return parseShortHeader(b, typeByte) -} - -// parse long header and version negotiation packets -func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte) (*Header, error) { - connID, err := utils.BigEndian.ReadUint64(b) - if err != nil { - return nil, err - } - v, err := utils.BigEndian.ReadUint32(b) - if err != nil { - return nil, err - } - h := &Header{ - ConnectionID: protocol.ConnectionID(connID), - Version: protocol.VersionNumber(v), - } - if v == 0 { // version negotiation packet - if sentBy == protocol.PerspectiveClient { - return nil, qerr.InvalidVersion - } - if b.Len() == 0 { - return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list") - } - h.IsVersionNegotiation = true - h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4) - for i := 0; b.Len() > 0; i++ { - v, err := utils.BigEndian.ReadUint32(b) - if err != nil { - return nil, qerr.InvalidVersionNegotiationPacket - } - h.SupportedVersions[i] = protocol.VersionNumber(v) - } - return h, nil - } - h.IsLongHeader = true - pn, err := utils.BigEndian.ReadUint32(b) - if err != nil { - return nil, err - } - h.PacketNumber = protocol.PacketNumber(pn) - h.PacketNumberLen = protocol.PacketNumberLen4 - h.Type = protocol.PacketType(typeByte & 0x7f) - if sentBy == protocol.PerspectiveClient && (h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeHandshake && h.Type != protocol.PacketType0RTT) { - return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type)) - } - if sentBy == protocol.PerspectiveServer && (h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketTypeHandshake) { - return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type)) - } - return h, nil -} - -func parseShortHeader(b *bytes.Reader, typeByte byte) (*Header, error) { - omitConnID := typeByte&0x40 > 0 - var connID uint64 - if !omitConnID { - var err error - connID, err = utils.BigEndian.ReadUint64(b) - if err != nil { - return nil, err - } - } - // bit 4 must be set, bit 5 must be unset - if typeByte&0x18 != 0x10 { - return nil, errors.New("invalid bit 4 and 5") - } - var pnLen protocol.PacketNumberLen - switch typeByte & 0x7 { - case 0x0: - pnLen = protocol.PacketNumberLen1 - case 0x1: - pnLen = protocol.PacketNumberLen2 - case 0x2: - pnLen = protocol.PacketNumberLen4 - default: - return nil, errors.New("invalid short header type") - } - pn, err := utils.BigEndian.ReadUintN(b, uint8(pnLen)) - if err != nil { - return nil, err - } - return &Header{ - KeyPhase: int(typeByte&0x20) >> 5, - OmitConnectionID: omitConnID, - ConnectionID: protocol.ConnectionID(connID), - PacketNumber: protocol.PacketNumber(pn), - PacketNumberLen: pnLen, - }, nil -} - -// writeHeader writes the Header. -func (h *Header) writeHeader(b *bytes.Buffer) error { - if h.IsLongHeader { - return h.writeLongHeader(b) - } - return h.writeShortHeader(b) -} - -// TODO: add support for the key phase -func (h *Header) writeLongHeader(b *bytes.Buffer) error { - b.WriteByte(byte(0x80 | h.Type)) - utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) - utils.BigEndian.WriteUint32(b, uint32(h.Version)) - utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) - return nil -} - -func (h *Header) writeShortHeader(b *bytes.Buffer) error { - typeByte := byte(0x10) - typeByte ^= byte(h.KeyPhase << 5) - if h.OmitConnectionID { - typeByte ^= 0x40 - } - switch h.PacketNumberLen { - case protocol.PacketNumberLen1: - case protocol.PacketNumberLen2: - typeByte ^= 0x1 - case protocol.PacketNumberLen4: - typeByte ^= 0x2 - default: - return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) - } - b.WriteByte(typeByte) - - if !h.OmitConnectionID { - utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) - } - switch h.PacketNumberLen { - case protocol.PacketNumberLen1: - b.WriteByte(uint8(h.PacketNumber)) - case protocol.PacketNumberLen2: - utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) - case protocol.PacketNumberLen4: - utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) - } - return nil -} - -// getHeaderLength gets the length of the Header in bytes. -func (h *Header) getHeaderLength() (protocol.ByteCount, error) { - if h.IsLongHeader { - return 1 + 8 + 4 + 4, nil - } - - length := protocol.ByteCount(1) // type byte - if !h.OmitConnectionID { - length += 8 - } - if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 { - return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) - } - length += protocol.ByteCount(h.PacketNumberLen) - return length, nil -} - -func (h *Header) logHeader(logger utils.Logger) { - if h.IsLongHeader { - logger.Debugf(" Long Header{Type: %s, ConnectionID: %#x, PacketNumber: %#x, Version: %s}", h.Type, h.ConnectionID, h.PacketNumber, h.Version) - } else { - connID := "(omitted)" - if !h.OmitConnectionID { - connID = fmt.Sprintf("%#x", h.ConnectionID) - } - logger.Debugf(" Short Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", connID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) - } -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go index eaf5b1ea..465e82ab 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go @@ -1,6 +1,11 @@ package wire -import "github.com/lucas-clemente/quic-go/internal/utils" +import ( + "fmt" + "strings" + + "github.com/lucas-clemente/quic-go/internal/utils" +) // LogFrame logs a frame, either sent or received func LogFrame(logger utils.Logger, frame Frame, sent bool) { @@ -21,7 +26,15 @@ func LogFrame(logger utils.Logger, frame Frame, sent bool) { logger.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked) } case *AckFrame: - logger.Debugf("\t%s &wire.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String()) + if len(f.AckRanges) > 1 { + ackRanges := make([]string, len(f.AckRanges)) + for i, r := range f.AckRanges { + ackRanges[i] = fmt.Sprintf("{Largest: %#x, Smallest: %#x}", r.Largest, r.Smallest) + } + logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %#x, LowestAcked: %#x, AckRanges: {%s}, DelayTime: %s}", dir, f.LargestAcked(), f.LowestAcked(), strings.Join(ackRanges, ", "), f.DelayTime.String()) + } else { + logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %#x, LowestAcked: %#x, DelayTime: %s}", dir, f.LargestAcked(), f.LowestAcked(), f.DelayTime.String()) + } default: logger.Debugf("\t%s %#v", dir, frame) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_challenge_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_challenge_frame.go new file mode 100644 index 00000000..f2a27d84 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_challenge_frame.go @@ -0,0 +1,39 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// A PathChallengeFrame is a PATH_CHALLENGE frame +type PathChallengeFrame struct { + Data [8]byte +} + +func parsePathChallengeFrame(r *bytes.Reader, version protocol.VersionNumber) (*PathChallengeFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + frame := &PathChallengeFrame{} + if _, err := io.ReadFull(r, frame.Data[:]); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + return frame, nil +} + +func (f *PathChallengeFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + typeByte := uint8(0x0e) + b.WriteByte(typeByte) + b.Write(f.Data[:]) + return nil +} + +// Length of a written frame +func (f *PathChallengeFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + 8 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_response_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_response_frame.go new file mode 100644 index 00000000..2ab2fcda --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_response_frame.go @@ -0,0 +1,39 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// A PathResponseFrame is a PATH_RESPONSE frame +type PathResponseFrame struct { + Data [8]byte +} + +func parsePathResponseFrame(r *bytes.Reader, version protocol.VersionNumber) (*PathResponseFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + frame := &PathResponseFrame{} + if _, err := io.ReadFull(r, frame.Data[:]); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + return frame, nil +} + +func (f *PathResponseFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + typeByte := uint8(0x0f) + b.WriteByte(typeByte) + b.Write(f.Data[:]) + return nil +} + +// Length of a written frame +func (f *PathResponseFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + 8 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_header.go deleted file mode 100644 index af996b29..00000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_header.go +++ /dev/null @@ -1,244 +0,0 @@ -package wire - -import ( - "bytes" - "errors" - "fmt" - "io" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/qerr" -) - -var ( - errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time") - errReceivedOmittedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with omitted ConnectionID is not supported") - errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0") - errGetLengthNotForVersionNegotiation = errors.New("PublicHeader: GetLength cannot be called for VersionNegotiation packets") -) - -// writePublicHeader writes a Public Header. -func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error { - if h.VersionFlag && pers == protocol.PerspectiveServer { - return errors.New("PublicHeader: Writing of Version Negotiation Packets not supported") - } - if h.VersionFlag && h.ResetFlag { - return errResetAndVersionFlagSet - } - - publicFlagByte := uint8(0x00) - if h.VersionFlag { - publicFlagByte |= 0x01 - } - if h.ResetFlag { - publicFlagByte |= 0x02 - } - if !h.OmitConnectionID { - publicFlagByte |= 0x08 - } - if len(h.DiversificationNonce) > 0 { - if len(h.DiversificationNonce) != 32 { - return errors.New("invalid diversification nonce length") - } - publicFlagByte |= 0x04 - } - // only set PacketNumberLen bits if a packet number will be written - if h.hasPacketNumber(pers) { - switch h.PacketNumberLen { - case protocol.PacketNumberLen1: - publicFlagByte |= 0x00 - case protocol.PacketNumberLen2: - publicFlagByte |= 0x10 - case protocol.PacketNumberLen4: - publicFlagByte |= 0x20 - case protocol.PacketNumberLen6: - publicFlagByte |= 0x30 - } - } - b.WriteByte(publicFlagByte) - - if !h.OmitConnectionID { - utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) - } - if h.VersionFlag && pers == protocol.PerspectiveClient { - utils.BigEndian.WriteUint32(b, uint32(h.Version)) - } - if len(h.DiversificationNonce) > 0 { - b.Write(h.DiversificationNonce) - } - // if we're a server, and the VersionFlag is set, we must not include anything else in the packet - if !h.hasPacketNumber(pers) { - return nil - } - - switch h.PacketNumberLen { - case protocol.PacketNumberLen1: - b.WriteByte(uint8(h.PacketNumber)) - case protocol.PacketNumberLen2: - utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) - case protocol.PacketNumberLen4: - utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) - case protocol.PacketNumberLen6: - utils.BigEndian.WriteUint48(b, uint64(h.PacketNumber)&(1<<48-1)) - default: - return errors.New("PublicHeader: PacketNumberLen not set") - } - - return nil -} - -// parsePublicHeader parses a QUIC packet's Public Header. -// The packetSentBy is the perspective of the peer that sent this PublicHeader, i.e. if we're the server, packetSentBy should be PerspectiveClient. -func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Header, error) { - header := &Header{} - - // First byte - publicFlagByte, err := b.ReadByte() - if err != nil { - return nil, err - } - header.ResetFlag = publicFlagByte&0x02 > 0 - header.VersionFlag = publicFlagByte&0x01 > 0 - - // TODO: activate this check once Chrome sends the correct value - // see https://github.com/lucas-clemente/quic-go/issues/232 - // if publicFlagByte&0x04 > 0 { - // return nil, errors.New("diversification nonces should only be sent by servers") - // } - - header.OmitConnectionID = publicFlagByte&0x08 == 0 - if header.OmitConnectionID && packetSentBy == protocol.PerspectiveClient { - return nil, errReceivedOmittedConnectionID - } - if header.hasPacketNumber(packetSentBy) { - switch publicFlagByte & 0x30 { - case 0x30: - header.PacketNumberLen = protocol.PacketNumberLen6 - case 0x20: - header.PacketNumberLen = protocol.PacketNumberLen4 - case 0x10: - header.PacketNumberLen = protocol.PacketNumberLen2 - case 0x00: - header.PacketNumberLen = protocol.PacketNumberLen1 - } - } - - // Connection ID - if !header.OmitConnectionID { - var connID uint64 - connID, err = utils.BigEndian.ReadUint64(b) - if err != nil { - return nil, err - } - header.ConnectionID = protocol.ConnectionID(connID) - if header.ConnectionID == 0 { - return nil, errInvalidConnectionID - } - } - - // Contrary to what the gQUIC wire spec says, the 0x4 bit only indicates the presence of the diversification nonce for packets sent by the server. - // It doesn't have any meaning when sent by the client. - if packetSentBy == protocol.PerspectiveServer && publicFlagByte&0x04 > 0 { - if !header.VersionFlag && !header.ResetFlag { - header.DiversificationNonce = make([]byte, 32) - if _, err := io.ReadFull(b, header.DiversificationNonce); err != nil { - return nil, err - } - } - } - - // Version (optional) - if !header.ResetFlag && header.VersionFlag { - if packetSentBy == protocol.PerspectiveServer { // parse the version negotiation packet - if b.Len() == 0 { - return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list") - } - if b.Len()%4 != 0 { - return nil, qerr.InvalidVersionNegotiationPacket - } - header.IsVersionNegotiation = true - header.SupportedVersions = make([]protocol.VersionNumber, 0) - for { - var versionTag uint32 - versionTag, err = utils.BigEndian.ReadUint32(b) - if err != nil { - break - } - v := protocol.VersionNumber(versionTag) - header.SupportedVersions = append(header.SupportedVersions, v) - } - // a version negotiation packet doesn't have a packet number - return header, nil - } - // packet was sent by the client. Read the version number - var versionTag uint32 - versionTag, err = utils.BigEndian.ReadUint32(b) - if err != nil { - return nil, err - } - header.Version = protocol.VersionNumber(versionTag) - } - - // Packet number - if header.hasPacketNumber(packetSentBy) { - packetNumber, err := utils.BigEndian.ReadUintN(b, uint8(header.PacketNumberLen)) - if err != nil { - return nil, err - } - header.PacketNumber = protocol.PacketNumber(packetNumber) - } - return header, nil -} - -// getPublicHeaderLength gets the length of the publicHeader in bytes. -// It can only be called for regular packets. -func (h *Header) getPublicHeaderLength(pers protocol.Perspective) (protocol.ByteCount, error) { - if h.VersionFlag && h.ResetFlag { - return 0, errResetAndVersionFlagSet - } - if h.VersionFlag && pers == protocol.PerspectiveServer { - return 0, errGetLengthNotForVersionNegotiation - } - - length := protocol.ByteCount(1) // 1 byte for public flags - if h.hasPacketNumber(pers) { - if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 { - return 0, errPacketNumberLenNotSet - } - length += protocol.ByteCount(h.PacketNumberLen) - } - if !h.OmitConnectionID { - length += 8 // 8 bytes for the connection ID - } - // Version Number in packets sent by the client - if h.VersionFlag { - length += 4 - } - length += protocol.ByteCount(len(h.DiversificationNonce)) - return length, nil -} - -// hasPacketNumber determines if this Public Header will contain a packet number -// this depends on the ResetFlag, the VersionFlag and who sent the packet -func (h *Header) hasPacketNumber(packetSentBy protocol.Perspective) bool { - if h.ResetFlag { - return false - } - if h.VersionFlag && packetSentBy == protocol.PerspectiveServer { - return false - } - return true -} - -func (h *Header) logPublicHeader(logger utils.Logger) { - connID := "(omitted)" - if !h.OmitConnectionID { - connID = fmt.Sprintf("%#x", h.ConnectionID) - } - ver := "(unset)" - if h.Version != 0 { - ver = h.Version.String() - } - logger.Debugf(" Public Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", connID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_reset.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_reset.go index 6adc9f69..b57ea7ad 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_reset.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_reset.go @@ -16,11 +16,11 @@ type PublicReset struct { Nonce uint64 } -// WritePublicReset writes a Public Reset +// WritePublicReset writes a PUBLIC_RESET func WritePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber protocol.PacketNumber, nonceProof uint64) []byte { b := &bytes.Buffer{} b.WriteByte(0x0a) - utils.BigEndian.WriteUint64(b, uint64(connectionID)) + b.Write(connectionID) utils.LittleEndian.WriteUint32(b, uint32(handshake.TagPRST)) utils.LittleEndian.WriteUint32(b, 2) utils.LittleEndian.WriteUint32(b, uint32(handshake.TagRNON)) @@ -32,7 +32,7 @@ func WritePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber p return b.Bytes() } -// ParsePublicReset parses a Public Reset +// ParsePublicReset parses a PUBLIC_RESET func ParsePublicReset(r *bytes.Reader) (*PublicReset, error) { pr := PublicReset{} msg, err := handshake.ParseHandshakeMessage(r) @@ -44,7 +44,7 @@ func ParsePublicReset(r *bytes.Reader) (*PublicReset, error) { } // The RSEQ tag is mandatory according to the gQUIC wire spec. - // However, Google doesn't send RSEQ in their Public Resets. + // However, Google doesn't send RSEQ in their PUBLIC_RESETs. // Therefore, we'll treat RSEQ as an optional field. if rseq, ok := msg.Data[handshake.TagRSEQ]; ok { if len(rseq) != 8 { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go index d848127a..192319e6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go @@ -76,10 +76,6 @@ func parseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamF if frame.Offset+frame.DataLen() > protocol.MaxByteCount { return nil, qerr.Error(qerr.InvalidStreamData, "data overflows maximum offset") } - // empty frames are only allowed if they have offset 0 or the FIN bit set - if frame.DataLen() == 0 && !frame.FinBit && frame.Offset != 0 { - return nil, qerr.EmptyStreamFrameNoFin - } return frame, nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go index cf72fc2e..df8b1f2c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go @@ -12,7 +12,7 @@ import ( func ComposeGQUICVersionNegotiation(connID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { buf := bytes.NewBuffer(make([]byte, 0, 1+8+len(versions)*4)) buf.Write([]byte{0x1 | 0x8}) // type byte - utils.BigEndian.WriteUint64(buf, uint64(connID)) + buf.Write(connID) for _, v := range versions { utils.BigEndian.WriteUint32(buf, uint32(v)) } @@ -20,19 +20,23 @@ func ComposeGQUICVersionNegotiation(connID protocol.ConnectionID, versions []pro } // ComposeVersionNegotiation composes a Version Negotiation according to the IETF draft -func ComposeVersionNegotiation( - connID protocol.ConnectionID, - versions []protocol.VersionNumber, -) []byte { +func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []protocol.VersionNumber) ([]byte, error) { greasedVersions := protocol.GetGreasedVersions(versions) - buf := bytes.NewBuffer(make([]byte, 0, 1+8+4+len(greasedVersions)*4)) + expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* connection ID length field */ + destConnID.Len() + srcConnID.Len() + len(greasedVersions)*4 + buf := bytes.NewBuffer(make([]byte, 0, expectedLen)) r := make([]byte, 1) _, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here. buf.WriteByte(r[0] | 0x80) - utils.BigEndian.WriteUint64(buf, uint64(connID)) utils.BigEndian.WriteUint32(buf, 0) // version 0 + connIDLen, err := encodeConnIDLen(destConnID, srcConnID) + if err != nil { + return nil, err + } + buf.WriteByte(connIDLen) + buf.Write(destConnID) + buf.Write(srcConnID) for _, v := range greasedVersions { utils.BigEndian.WriteUint32(buf, uint32(v)) } - return buf.Bytes() + return buf.Bytes(), nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/mint_utils.go b/vendor/github.com/lucas-clemente/quic-go/mint_utils.go index b32a0905..657adb56 100644 --- a/vendor/github.com/lucas-clemente/quic-go/mint_utils.go +++ b/vendor/github.com/lucas-clemente/quic-go/mint_utils.go @@ -1,70 +1,15 @@ package quic import ( - "bytes" gocrypto "crypto" "crypto/tls" "crypto/x509" "errors" - "fmt" - "io" "github.com/bifurcation/mint" - "github.com/lucas-clemente/quic-go/internal/crypto" - "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" ) -type mintController struct { - csc *handshake.CryptoStreamConn - conn *mint.Conn -} - -var _ handshake.MintTLS = &mintController{} - -func newMintController( - csc *handshake.CryptoStreamConn, - mconf *mint.Config, - pers protocol.Perspective, -) handshake.MintTLS { - var conn *mint.Conn - if pers == protocol.PerspectiveClient { - conn = mint.Client(csc, mconf) - } else { - conn = mint.Server(csc, mconf) - } - return &mintController{ - csc: csc, - conn: conn, - } -} - -func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams { - return mc.conn.ConnectionState().CipherSuite -} - -func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { - return mc.conn.ComputeExporter(label, context, keyLength) -} - -func (mc *mintController) Handshake() mint.Alert { - return mc.conn.Handshake() -} - -func (mc *mintController) State() mint.State { - return mc.conn.ConnectionState().HandshakeState -} - -func (mc *mintController) ConnectionState() mint.ConnectionState { - return mc.conn.ConnectionState() -} - -func (mc *mintController) SetCryptoStream(stream io.ReadWriter) { - mc.csc.SetStream(stream) -} - func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) { mconf := &mint.Config{ NonBlocking: true, @@ -105,64 +50,3 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf } return mconf, nil } - -// unpackInitialOrRetryPacket unpacks packets Initial and Retry packets -// These packets must contain a STREAM_FRAME for the crypto stream, starting at offset 0. -func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, logger utils.Logger, version protocol.VersionNumber) (*wire.StreamFrame, error) { - decrypted, err := aead.Open(data[:0], data, hdr.PacketNumber, hdr.Raw) - if err != nil { - return nil, err - } - var frame *wire.StreamFrame - r := bytes.NewReader(decrypted) - for { - f, err := wire.ParseNextFrame(r, hdr, version) - if err != nil { - return nil, err - } - var ok bool - if frame, ok = f.(*wire.StreamFrame); ok || frame == nil { - break - } - } - if frame == nil { - return nil, errors.New("Packet doesn't contain a STREAM_FRAME") - } - if frame.StreamID != version.CryptoStreamID() { - return nil, fmt.Errorf("Received STREAM_FRAME for wrong stream (Stream ID %d)", frame.StreamID) - } - // We don't need a check for the stream ID here. - // The packetUnpacker checks that there's no unencrypted stream data except for the crypto stream. - if frame.Offset != 0 { - return nil, errors.New("received stream data with non-zero offset") - } - if logger.Debug() { - logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID) - hdr.Log(logger) - wire.LogFrame(logger, frame, false) - } - return frame, nil -} - -// packUnencryptedPacket provides a low-overhead way to pack a packet. -// It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available. -func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, pers protocol.Perspective, logger utils.Logger) ([]byte, error) { - raw := *getPacketBuffer() - buffer := bytes.NewBuffer(raw[:0]) - if err := hdr.Write(buffer, pers, hdr.Version); err != nil { - return nil, err - } - payloadStartIndex := buffer.Len() - if err := f.Write(buffer, hdr.Version); err != nil { - return nil, err - } - raw = raw[0:buffer.Len()] - _ = aead.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex]) - raw = raw[0 : buffer.Len()+aead.Overhead()] - if logger.Debug() { - logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted) - hdr.Log(logger) - wire.LogFrame(logger, f, true) - } - return raw, nil -} diff --git a/vendor/github.com/lucas-clemente/quic-go/mockgen.go b/vendor/github.com/lucas-clemente/quic-go/mockgen.go index 65f38546..5b7cd4f0 100644 --- a/vendor/github.com/lucas-clemente/quic-go/mockgen.go +++ b/vendor/github.com/lucas-clemente/quic-go/mockgen.go @@ -1,16 +1,19 @@ package quic -//go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI StreamI" -//go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/lucas-clemente/quic-go receiveStreamI ReceiveStreamI" -//go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/lucas-clemente/quic-go sendStreamI SendStreamI" -//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender StreamSender" -//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter StreamGetter" -//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource StreamFrameSource" -//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream" -//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager StreamManager" -//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go mock_stream_manager_test.go" -//go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker Unpacker" -//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_unpacker_test.go mock_unpacker_test.go" -//go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD QuicAEAD" -//go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD" -//go:generate sh -c "goimports -w mock*_test.go" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI" +//go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/lucas-clemente/quic-go receiveStreamI" +//go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/lucas-clemente/quic-go sendStreamI" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource" +//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStream" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager" +//go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker" +//go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD" +//go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD" +//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner" +//go:generate sh -c "./mockgen_private.sh quic mock_quic_session_test.go github.com/lucas-clemente/quic-go quicSession" +//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler" +//go:generate sh -c "./mockgen_private.sh quic mock_unknown_packet_handler_test.go github.com/lucas-clemente/quic-go unknownPacketHandler" +//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_manager_test.go github.com/lucas-clemente/quic-go packetHandlerManager" +//go:generate sh -c "./mockgen_private.sh quic mock_multiplexer_test.go github.com/lucas-clemente/quic-go multiplexer" diff --git a/vendor/github.com/lucas-clemente/quic-go/multiplexer.go b/vendor/github.com/lucas-clemente/quic-go/multiplexer.go new file mode 100644 index 00000000..c4482ac2 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/multiplexer.go @@ -0,0 +1,63 @@ +package quic + +import ( + "fmt" + "net" + "sync" + + "github.com/lucas-clemente/quic-go/internal/utils" +) + +var ( + connMuxerOnce sync.Once + connMuxer multiplexer +) + +type multiplexer interface { + AddConn(net.PacketConn, int) (packetHandlerManager, error) +} + +type connManager struct { + connIDLen int + manager packetHandlerManager +} + +// The connMultiplexer listens on multiple net.PacketConns and dispatches +// incoming packets to the session handler. +type connMultiplexer struct { + mutex sync.Mutex + + conns map[net.PacketConn]connManager + newPacketHandlerManager func(net.PacketConn, int, utils.Logger) packetHandlerManager // so it can be replaced in the tests + + logger utils.Logger +} + +var _ multiplexer = &connMultiplexer{} + +func getMultiplexer() multiplexer { + connMuxerOnce.Do(func() { + connMuxer = &connMultiplexer{ + conns: make(map[net.PacketConn]connManager), + logger: utils.DefaultLogger.WithPrefix("muxer"), + newPacketHandlerManager: newPacketHandlerMap, + } + }) + return connMuxer +} + +func (m *connMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHandlerManager, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + p, ok := m.conns[c] + if !ok { + manager := m.newPacketHandlerManager(c, connIDLen, m.logger) + p = connManager{connIDLen: connIDLen, manager: manager} + m.conns[c] = p + } + if p.connIDLen != connIDLen { + return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) + } + return p.manager, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go b/vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go new file mode 100644 index 00000000..35f9bdfa --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go @@ -0,0 +1,198 @@ +package quic + +import ( + "bytes" + "fmt" + "net" + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +// The packetHandlerMap stores packetHandlers, identified by connection ID. +// It is used: +// * by the server to store sessions +// * when multiplexing outgoing connections to store clients +type packetHandlerMap struct { + mutex sync.RWMutex + + conn net.PacketConn + connIDLen int + + handlers map[string] /* string(ConnectionID)*/ packetHandler + server unknownPacketHandler + closed bool + + deleteClosedSessionsAfter time.Duration + + logger utils.Logger +} + +var _ packetHandlerManager = &packetHandlerMap{} + +func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger) packetHandlerManager { + m := &packetHandlerMap{ + conn: conn, + connIDLen: connIDLen, + handlers: make(map[string]packetHandler), + deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout, + logger: logger, + } + go m.listen() + return m +} + +func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) { + h.mutex.Lock() + h.handlers[string(id)] = handler + h.mutex.Unlock() +} + +func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { + h.removeByConnectionIDAsString(string(id)) +} + +func (h *packetHandlerMap) removeByConnectionIDAsString(id string) { + h.mutex.Lock() + h.handlers[id] = nil + h.mutex.Unlock() + + time.AfterFunc(h.deleteClosedSessionsAfter, func() { + h.mutex.Lock() + delete(h.handlers, id) + h.mutex.Unlock() + }) +} + +func (h *packetHandlerMap) SetServer(s unknownPacketHandler) { + h.mutex.Lock() + h.server = s + h.mutex.Unlock() +} + +func (h *packetHandlerMap) CloseServer() { + h.mutex.Lock() + h.server = nil + var wg sync.WaitGroup + for id, handler := range h.handlers { + if handler != nil && handler.GetPerspective() == protocol.PerspectiveServer { + wg.Add(1) + go func(id string, handler packetHandler) { + // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped + _ = handler.Close() + h.removeByConnectionIDAsString(id) + wg.Done() + }(id, handler) + } + } + h.mutex.Unlock() + wg.Wait() +} + +func (h *packetHandlerMap) close(e error) error { + h.mutex.Lock() + if h.closed { + h.mutex.Unlock() + return nil + } + h.closed = true + + var wg sync.WaitGroup + for _, handler := range h.handlers { + if handler != nil { + wg.Add(1) + go func(handler packetHandler) { + handler.destroy(e) + wg.Done() + }(handler) + } + } + + if h.server != nil { + h.server.closeWithError(e) + } + h.mutex.Unlock() + wg.Wait() + return nil +} + +func (h *packetHandlerMap) listen() { + for { + data := *getPacketBuffer() + data = data[:protocol.MaxReceivePacketSize] + // The packet size should not exceed protocol.MaxReceivePacketSize bytes + // If it does, we only read a truncated packet, which will then end up undecryptable + n, addr, err := h.conn.ReadFrom(data) + if err != nil { + h.close(err) + return + } + data = data[:n] + + if err := h.handlePacket(addr, data); err != nil { + h.logger.Debugf("error handling packet from %s: %s", addr, err) + } + } +} + +func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error { + rcvTime := time.Now() + + r := bytes.NewReader(data) + iHdr, err := wire.ParseInvariantHeader(r, h.connIDLen) + // drop the packet if we can't parse the header + if err != nil { + return fmt.Errorf("error parsing invariant header: %s", err) + } + + h.mutex.RLock() + handler, ok := h.handlers[string(iHdr.DestConnectionID)] + server := h.server + h.mutex.RUnlock() + + var sentBy protocol.Perspective + var version protocol.VersionNumber + var handlePacket func(*receivedPacket) + if ok && handler == nil { + // Late packet for closed session + return nil + } + if !ok { + if server == nil { // no server set + return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID) + } + handlePacket = server.handlePacket + sentBy = protocol.PerspectiveClient + version = iHdr.Version + } else { + sentBy = handler.GetPerspective().Opposite() + version = handler.GetVersion() + handlePacket = handler.handlePacket + } + + hdr, err := iHdr.Parse(r, sentBy, version) + if err != nil { + return fmt.Errorf("error parsing header: %s", err) + } + hdr.Raw = data[:len(data)-r.Len()] + packetData := data[len(data)-r.Len():] + + if hdr.IsLongHeader && hdr.Version.UsesLengthInHeader() { + if protocol.ByteCount(len(packetData)) < hdr.PayloadLen { + return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen) + } + packetData = packetData[:int(hdr.PayloadLen)] + // TODO(#1312): implement parsing of compound packets + } + + handlePacket(&receivedPacket{ + remoteAddr: addr, + header: hdr, + data: packetData, + rcvTime: rcvTime, + }) + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_packer.go b/vendor/github.com/lucas-clemente/quic-go/packet_packer.go index e4fa653a..9b2cccd5 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_packer.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_packer.go @@ -46,11 +46,15 @@ type streamFrameSource interface { } type packetPacker struct { - connectionID protocol.ConnectionID - perspective protocol.Perspective - version protocol.VersionNumber - divNonce []byte - cryptoSetup sealingManager + destConnID protocol.ConnectionID + srcConnID protocol.ConnectionID + + perspective protocol.Perspective + version protocol.VersionNumber + cryptoSetup sealingManager + + token []byte + divNonce []byte packetNumberGenerator *packetNumberGenerator getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen @@ -67,10 +71,13 @@ type packetPacker struct { numNonRetransmittableAcks int } -func newPacketPacker(connectionID protocol.ConnectionID, +func newPacketPacker( + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, initialPacketNumber protocol.PacketNumber, getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen, remoteAddr net.Addr, // only used for determining the max packet size + token []byte, divNonce []byte, cryptoSetup sealingManager, streamFramer streamFrameSource, @@ -93,7 +100,9 @@ func newPacketPacker(connectionID protocol.ConnectionID, return &packetPacker{ cryptoSetup: cryptoSetup, divNonce: divNonce, - connectionID: connectionID, + token: token, + destConnID: destConnID, + srcConnID: srcConnID, perspective: perspective, version: version, streams: streamFramer, @@ -167,7 +176,7 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP var payloadLength protocol.ByteCount header := p.getHeader(encLevel) - headerLength, err := header.GetLength(p.perspective, p.version) + headerLength, err := header.GetLength(p.version) if err != nil { return nil, err } @@ -293,7 +302,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { encLevel, sealer := p.cryptoSetup.GetSealer() header := p.getHeader(encLevel) - headerLength, err := header.GetLength(p.perspective, p.version) + headerLength, err := header.GetLength(p.version) if err != nil { return nil, err } @@ -347,7 +356,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { func (p *packetPacker) packCryptoPacket() (*packedPacket, error) { encLevel, sealer := p.cryptoSetup.GetSealerForCryptoStream() header := p.getHeader(encLevel) - headerLength, err := header.GetLength(p.perspective, p.version) + headerLength, err := header.GetLength(p.version) if err != nil { return nil, err } @@ -446,35 +455,38 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header packetNumberLen := p.getPacketNumberLen(pnum) header := &wire.Header{ - ConnectionID: p.connectionID, PacketNumber: pnum, PacketNumberLen: packetNumberLen, + Version: p.version, } - if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure { - header.PacketNumberLen = protocol.PacketNumberLen4 + if p.version.UsesIETFHeaderFormat() && encLevel != protocol.EncryptionForwardSecure { header.IsLongHeader = true + header.SrcConnectionID = p.srcConnID + if !p.version.UsesVarintPacketNumbers() { + header.PacketNumberLen = protocol.PacketNumberLen4 + } + // Set the payload len to maximum size. + // Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns. + header.PayloadLen = p.maxPacketSize if !p.hasSentPacket && p.perspective == protocol.PerspectiveClient { header.Type = protocol.PacketTypeInitial + header.Token = p.token } else { header.Type = protocol.PacketTypeHandshake } } - if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure { - header.OmitConnectionID = true + if !p.omitConnectionID || encLevel != protocol.EncryptionForwardSecure { + header.DestConnectionID = p.destConnID } if !p.version.UsesTLS() { if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure { + header.Type = protocol.PacketType0RTT header.DiversificationNonce = p.divNonce } if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure { header.VersionFlag = true - header.Version = p.version - } - } else { - if encLevel != protocol.EncryptionForwardSecure { - header.Version = p.version } } return header @@ -488,6 +500,20 @@ func (p *packetPacker) writeAndSealPacket( raw := *getPacketBuffer() buffer := bytes.NewBuffer(raw[:0]) + // the payload length is only needed for Long Headers + if header.IsLongHeader { + if header.Type == protocol.PacketTypeInitial { + headerLen, _ := header.GetLength(p.version) + header.PayloadLen = protocol.ByteCount(protocol.MinInitialPacketSize) - headerLen + } else { + payloadLen := protocol.ByteCount(sealer.Overhead()) + for _, frame := range payloadFrames { + payloadLen += frame.Length(p.version) + } + header.PayloadLen = payloadLen + } + } + if err := header.Write(buffer, p.perspective, p.version); err != nil { return nil, err } @@ -541,6 +567,10 @@ func (p *packetPacker) SetOmitConnectionID() { p.omitConnectionID = true } +func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) { + p.destConnID = connID +} + func (p *packetPacker) SetMaxPacketSize(size protocol.ByteCount) { p.maxPacketSize = utils.MinByteCount(p.maxPacketSize, size) } diff --git a/vendor/github.com/lucas-clemente/quic-go/receive_stream.go b/vendor/github.com/lucas-clemente/quic-go/receive_stream.go index 9fc158f1..43c7bcf6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/receive_stream.go +++ b/vendor/github.com/lucas-clemente/quic-go/receive_stream.go @@ -8,7 +8,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -28,9 +27,12 @@ type receiveStream struct { sender streamSender - frameQueue *streamFrameSorter - readPosInFrame int - readOffset protocol.ByteCount + frameQueue *frameSorter + readOffset protocol.ByteCount + + currentFrame []byte + currentFrameIsLast bool // is the currentFrame the last frame on this stream + readPosInFrame int closeForShutdownErr error cancelReadErr error @@ -61,7 +63,7 @@ func newReceiveStream( streamID: streamID, sender: sender, flowController: flowController, - frameQueue: newStreamFrameSorter(), + frameQueue: newFrameSorter(), readChan: make(chan struct{}, 1), version: version, } @@ -73,48 +75,57 @@ func (s *receiveStream) StreamID() protocol.StreamID { // Read implements io.Reader. It is not thread safe! func (s *receiveStream) Read(p []byte) (int, error) { + completed, n, err := s.readImpl(p) + if completed { + s.sender.onStreamCompleted(s.streamID) + } + return n, err +} + +func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, error) { s.mutex.Lock() defer s.mutex.Unlock() if s.finRead { - return 0, io.EOF + return false, 0, io.EOF } if s.canceledRead { - return 0, s.cancelReadErr + return false, 0, s.cancelReadErr } if s.resetRemotely { - return 0, s.resetRemotelyErr + return false, 0, s.resetRemotelyErr } if s.closedForShutdown { - return 0, s.closeForShutdownErr + return false, 0, s.closeForShutdownErr } bytesRead := 0 for bytesRead < len(p) { - frame := s.frameQueue.Head() - if frame == nil && bytesRead > 0 { - return bytesRead, s.closeForShutdownErr + if s.currentFrame == nil || s.readPosInFrame >= len(s.currentFrame) { + s.dequeueNextFrame() + } + if s.currentFrame == nil && bytesRead > 0 { + return false, bytesRead, s.closeForShutdownErr } for { // Stop waiting on errors if s.closedForShutdown { - return bytesRead, s.closeForShutdownErr + return false, bytesRead, s.closeForShutdownErr } if s.canceledRead { - return bytesRead, s.cancelReadErr + return false, bytesRead, s.cancelReadErr } if s.resetRemotely { - return bytesRead, s.resetRemotelyErr + return false, bytesRead, s.resetRemotelyErr } deadline := s.readDeadline if !deadline.IsZero() && !time.Now().Before(deadline) { - return bytesRead, errDeadline + return false, bytesRead, errDeadline } - if frame != nil { - s.readPosInFrame = int(s.readOffset - frame.Offset) + if s.currentFrame != nil || s.currentFrameIsLast { break } @@ -128,20 +139,21 @@ func (s *receiveStream) Read(p []byte) (int, error) { } } s.mutex.Lock() - frame = s.frameQueue.Head() + if s.currentFrame == nil { + s.dequeueNextFrame() + } } if bytesRead > len(p) { - return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) + return false, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) } - if s.readPosInFrame > int(frame.DataLen()) { - return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, frame.DataLen()) + if s.readPosInFrame > len(s.currentFrame) { + return false, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame)) } s.mutex.Unlock() - copy(p[bytesRead:], frame.Data[s.readPosInFrame:]) - m := utils.Min(len(p)-bytesRead, int(frame.DataLen())-s.readPosInFrame) + m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:]) s.readPosInFrame += m bytesRead += m s.readOffset += protocol.ByteCount(m) @@ -151,21 +163,20 @@ func (s *receiveStream) Read(p []byte) (int, error) { if !s.resetRemotely { s.flowController.AddBytesRead(protocol.ByteCount(m)) } - // this call triggers the flow controller to increase the flow control window, if necessary - if s.flowController.HasWindowUpdate() { - s.sender.onHasWindowUpdate(s.streamID) - } + // increase the flow control window, if necessary + s.flowController.MaybeQueueWindowUpdate() - if s.readPosInFrame >= int(frame.DataLen()) { - s.frameQueue.Pop() - s.finRead = frame.FinBit - if frame.FinBit { - s.sender.onStreamCompleted(s.streamID) - return bytesRead, io.EOF - } + if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast { + s.finRead = true + return true, bytesRead, io.EOF } } - return bytesRead, nil + return false, bytesRead, nil +} + +func (s *receiveStream) dequeueNextFrame() { + s.currentFrame, s.currentFrameIsLast = s.frameQueue.Pop() + s.readPosInFrame = 0 } func (s *receiveStream) CancelRead(errorCode protocol.ApplicationErrorCode) error { @@ -198,7 +209,7 @@ func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error { s.mutex.Lock() defer s.mutex.Unlock() - if err := s.frameQueue.Push(frame); err != nil && err != errDuplicateStreamData { + if err := s.frameQueue.Push(frame.Data, frame.Offset, frame.FinBit); err != nil { return err } s.signalRead() @@ -206,25 +217,33 @@ func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error { } func (s *receiveStream) handleRstStreamFrame(frame *wire.RstStreamFrame) error { + completed, err := s.handleRstStreamFrameImpl(frame) + if completed { + s.sender.onStreamCompleted(s.streamID) + } + return err +} + +func (s *receiveStream) handleRstStreamFrameImpl(frame *wire.RstStreamFrame) (bool /*completed */, error) { s.mutex.Lock() defer s.mutex.Unlock() if s.closedForShutdown { - return nil + return false, nil } if err := s.flowController.UpdateHighestReceived(frame.ByteOffset, true); err != nil { - return err + return false, err } // In gQUIC, error code 0 has a special meaning. // The peer will reliably continue transmitting, but is not interested in reading from the stream. // We should therefore just continue reading from the stream, until we encounter the FIN bit. if !s.version.UsesIETFFrameFormat() && frame.ErrorCode == 0 { - return nil + return false, nil } // ignore duplicate RST_STREAM frames for this stream (after checking their final offset) if s.resetRemotely { - return nil + return false, nil } s.resetRemotely = true s.resetRemotelyErr = streamCanceledError{ @@ -232,8 +251,7 @@ func (s *receiveStream) handleRstStreamFrame(frame *wire.RstStreamFrame) error { error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode), } s.signalRead() - s.sender.onStreamCompleted(s.streamID) - return nil + return true, nil } func (s *receiveStream) CloseRemote(offset protocol.ByteCount) { diff --git a/vendor/github.com/lucas-clemente/quic-go/send_stream.go b/vendor/github.com/lucas-clemente/quic-go/send_stream.go index 62ef4456..eee66b6e 100644 --- a/vendor/github.com/lucas-clemente/quic-go/send_stream.go +++ b/vendor/github.com/lucas-clemente/quic-go/send_stream.go @@ -133,11 +133,19 @@ func (s *sendStream) Write(p []byte) (int, error) { // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream // maxBytes is the maximum length this frame (including frame header) will have. func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) { + completed, frame, hasMoreData := s.popStreamFrameImpl(maxBytes) + if completed { + s.sender.onStreamCompleted(s.streamID) + } + return frame, hasMoreData +} + +func (s *sendStream) popStreamFrameImpl(maxBytes protocol.ByteCount) (bool /* completed */, *wire.StreamFrame, bool /* has more data to send */) { s.mutex.Lock() defer s.mutex.Unlock() if s.closeForShutdownErr != nil { - return nil, false + return false, nil, false } frame := &wire.StreamFrame{ @@ -147,7 +155,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFr } maxDataLen := frame.MaxDataLen(maxBytes, s.version) if maxDataLen == 0 { // a STREAM frame must have at least one byte of data - return nil, s.dataForWriting != nil + return false, nil, s.dataForWriting != nil } frame.Data, frame.FinBit = s.getDataForWriting(maxDataLen) if len(frame.Data) == 0 && !frame.FinBit { @@ -156,24 +164,21 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFr // - there's data for writing, but the stream is stream-level flow control blocked // - there's data for writing, but the stream is connection-level flow control blocked if s.dataForWriting == nil { - return nil, false + return false, nil, false } - isBlocked, _ := s.flowController.IsBlocked() - return nil, !isBlocked - } - if frame.FinBit { - s.finSent = true - s.sender.onStreamCompleted(s.streamID) - } else if s.streamID != s.version.CryptoStreamID() { // TODO(#657): Flow control for the crypto stream - if isBlocked, offset := s.flowController.IsBlocked(); isBlocked { + if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked { s.sender.queueControlFrame(&wire.StreamBlockedFrame{ StreamID: s.streamID, Offset: offset, }) - return frame, false + return false, nil, false } + return false, nil, true } - return frame, s.dataForWriting != nil + if frame.FinBit { + s.finSent = true + } + return frame.FinBit, frame, s.dataForWriting != nil } func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { @@ -218,18 +223,22 @@ func (s *sendStream) Close() error { func (s *sendStream) CancelWrite(errorCode protocol.ApplicationErrorCode) error { s.mutex.Lock() - defer s.mutex.Unlock() + completed, err := s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode)) + s.mutex.Unlock() - return s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode)) + if completed { + s.sender.onStreamCompleted(s.streamID) + } + return err } // must be called after locking the mutex -func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, writeErr error) error { +func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, writeErr error) (bool /*completed */, error) { if s.canceledWrite { - return nil + return false, nil } if s.finishedWriting { - return fmt.Errorf("CancelWrite for closed stream %d", s.streamID) + return false, fmt.Errorf("CancelWrite for closed stream %d", s.streamID) } s.canceledWrite = true s.cancelWriteErr = writeErr @@ -241,14 +250,13 @@ func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, wr }) // TODO(#991): cancel retransmissions for this stream s.ctxCancel() - s.sender.onStreamCompleted(s.streamID) - return nil + return true, nil } func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { - s.mutex.Lock() - defer s.mutex.Unlock() - s.handleStopSendingFrameImpl(frame) + if completed := s.handleStopSendingFrameImpl(frame); completed { + s.sender.onStreamCompleted(s.streamID) + } } func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) { @@ -261,7 +269,10 @@ func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) { } // must be called after locking the mutex -func (s *sendStream) handleStopSendingFrameImpl(frame *wire.StopSendingFrame) { +func (s *sendStream) handleStopSendingFrameImpl(frame *wire.StopSendingFrame) bool /*completed*/ { + s.mutex.Lock() + defer s.mutex.Unlock() + writeErr := streamCanceledError{ errorCode: frame.ErrorCode, error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode), @@ -270,7 +281,8 @@ func (s *sendStream) handleStopSendingFrameImpl(frame *wire.StopSendingFrame) { if !s.version.UsesIETFFrameFormat() { errorCode = errorCodeStoppingGQUIC } - s.cancelWriteImpl(errorCode, writeErr) + completed, _ := s.cancelWriteImpl(errorCode, writeErr) + return completed } func (s *sendStream) Context() context.Context { diff --git a/vendor/github.com/lucas-clemente/quic-go/server.go b/vendor/github.com/lucas-clemente/quic-go/server.go index 1e56f0b9..d9af6a43 100644 --- a/vendor/github.com/lucas-clemente/quic-go/server.go +++ b/vendor/github.com/lucas-clemente/quic-go/server.go @@ -1,10 +1,10 @@ package quic import ( - "bytes" "crypto/tls" "errors" "fmt" + "io" "net" "sync" "time" @@ -14,26 +14,64 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/lucas-clemente/quic-go/qerr" ) // packetHandler handles packets type packetHandler interface { + handlePacket(*receivedPacket) + io.Closer + destroy(error) + GetVersion() protocol.VersionNumber + GetPerspective() protocol.Perspective +} + +type unknownPacketHandler interface { + handlePacket(*receivedPacket) + closeWithError(error) error +} + +type packetHandlerManager interface { + Add(protocol.ConnectionID, packetHandler) + SetServer(unknownPacketHandler) + Remove(protocol.ConnectionID) + CloseServer() +} + +type quicSession interface { Session - getCryptoStream() cryptoStreamI - handshakeStatus() <-chan error handlePacket(*receivedPacket) GetVersion() protocol.VersionNumber run() error + destroy(error) closeRemote(error) } +type sessionRunner interface { + onHandshakeComplete(Session) + removeConnectionID(protocol.ConnectionID) +} + +type runner struct { + onHandshakeCompleteImpl func(Session) + removeConnectionIDImpl func(protocol.ConnectionID) +} + +func (r *runner) onHandshakeComplete(s Session) { r.onHandshakeCompleteImpl(s) } +func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectionIDImpl(c) } + +var _ sessionRunner = &runner{} + // A Listener of QUIC type server struct { + mutex sync.Mutex + tlsConf *tls.Config config *Config conn net.PacketConn + // If the server is started with ListenAddr, we create a packet conn. + // If it is started with Listen, we take a packet conn as a parameter. + createdPacketConn bool supportsTLS bool serverTLS *serverTLS @@ -41,25 +79,25 @@ type server struct { certChain crypto.CertChain scfg *handshake.ServerConfig - sessionsMutex sync.RWMutex - sessions map[protocol.ConnectionID]packetHandler - closed bool + sessionHandler packetHandlerManager + + serverError error + errorChan chan struct{} + closed bool - serverError error sessionQueue chan Session - errorChan chan struct{} - // set as members, so they can be set in the tests - newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config, logger utils.Logger) (packetHandler, error) - deleteClosedSessionsAfter time.Duration + sessionRunner sessionRunner + // set as a member, so they can be set in the tests + newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (quicSession, error) logger utils.Logger } var _ Listener = &server{} +var _ unknownPacketHandler = &server{} // ListenAddr creates a QUIC server listening on a given address. -// The listener is not active until Serve() is called. // The tls.Config must not be nil, the quic.Config may be nil. func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) @@ -70,13 +108,21 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err if err != nil { return nil, err } - return Listen(conn, tlsConf, config) + serv, err := listen(conn, tlsConf, config) + if err != nil { + return nil, err + } + serv.createdPacketConn = true + return serv, nil } // Listen listens for QUIC connections on a given net.PacketConn. -// The listener is not active until Serve() is called. // The tls.Config must not be nil, the quic.Config may be nil. func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { + return listen(conn, tlsConf, config) +} + +func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, error) { certChain := crypto.NewCertChain(tlsConf) kex, err := crypto.NewCurve25519KEX() if err != nil { @@ -100,36 +146,43 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, } } - s := &server{ - conn: conn, - tlsConf: tlsConf, - config: config, - certChain: certChain, - scfg: scfg, - sessions: map[protocol.ConnectionID]packetHandler{}, - newSession: newSession, - deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout, - sessionQueue: make(chan Session, 5), - errorChan: make(chan struct{}), - supportsTLS: supportsTLS, - logger: utils.DefaultLogger, + sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength) + if err != nil { + return nil, err } + s := &server{ + conn: conn, + tlsConf: tlsConf, + config: config, + certChain: certChain, + scfg: scfg, + newSession: newSession, + sessionHandler: sessionHandler, + sessionQueue: make(chan Session, 5), + errorChan: make(chan struct{}), + supportsTLS: supportsTLS, + logger: utils.DefaultLogger.WithPrefix("server"), + } + s.setup() if supportsTLS { if err := s.setupTLS(); err != nil { return nil, err } } - go s.serve() + sessionHandler.SetServer(s) s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) return s, nil } -func (s *server) setupTLS() error { - cookieHandler, err := handshake.NewCookieHandler(s.config.AcceptCookie, s.logger) - if err != nil { - return err +func (s *server) setup() { + s.sessionRunner = &runner{ + onHandshakeCompleteImpl: func(sess Session) { s.sessionQueue <- sess }, + removeConnectionIDImpl: s.sessionHandler.Remove, } - serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf, s.logger) +} + +func (s *server) setupTLS() error { + serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, s.sessionRunner, s.tlsConf, s.logger) if err != nil { return err } @@ -141,16 +194,10 @@ func (s *server) setupTLS() error { case <-s.errorChan: return case tlsSession := <-sessionChan: - connID := tlsSession.connID - sess := tlsSession.sess - s.sessionsMutex.Lock() - if _, ok := s.sessions[connID]; ok { // drop this session if it already exists - s.sessionsMutex.Unlock() - continue - } - s.sessions[connID] = sess - s.sessionsMutex.Unlock() - s.runHandshakeAndSession(sess, connID) + // The connection ID is a randomly chosen value. + // It is safe to assume that it doesn't collide with other randomly chosen values. + serverSession := newServerSession(tlsSession.sess, s.config, s.logger) + s.sessionHandler.Add(tlsSession.connID, serverSession) } } }() @@ -218,6 +265,15 @@ func populateServerConfig(config *Config) *Config { } else if maxIncomingUniStreams < 0 { maxIncomingUniStreams = 0 } + connIDLen := config.ConnectionIDLength + if connIDLen == 0 { + connIDLen = protocol.DefaultConnectionIDLength + } + for _, v := range versions { + if v == protocol.Version44 { + connIDLen = protocol.ConnectionIDLenGQUIC + } + } return &Config{ Versions: versions, @@ -229,27 +285,7 @@ func populateServerConfig(config *Config) *Config { MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, MaxIncomingStreams: maxIncomingStreams, MaxIncomingUniStreams: maxIncomingUniStreams, - } -} - -// serve listens on an existing PacketConn -func (s *server) serve() { - for { - data := *getPacketBuffer() - data = data[:protocol.MaxReceivePacketSize] - // The packet size should not exceed protocol.MaxReceivePacketSize bytes - // If it does, we only read a truncated packet, which will then end up undecryptable - n, remoteAddr, err := s.conn.ReadFrom(data) - if err != nil { - s.serverError = err - close(s.errorChan) - _ = s.Close() - return - } - data = data[:n] - if err := s.handlePacket(s.conn, remoteAddr, data); err != nil { - s.logger.Errorf("error handling packet: %s", err.Error()) - } + ConnectionIDLength: connIDLen, } } @@ -266,171 +302,119 @@ func (s *server) Accept() (Session, error) { // Close the server func (s *server) Close() error { - s.sessionsMutex.Lock() + s.mutex.Lock() + defer s.mutex.Unlock() if s.closed { - s.sessionsMutex.Unlock() return nil } - s.closed = true + return s.closeWithMutex() +} - var wg sync.WaitGroup - for _, session := range s.sessions { - if session != nil { - wg.Add(1) - go func(sess packetHandler) { - // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped - _ = sess.Close(nil) - wg.Done() - }(session) - } +func (s *server) closeWithMutex() error { + s.sessionHandler.CloseServer() + if s.serverError == nil { + s.serverError = errors.New("server closed") } - s.sessionsMutex.Unlock() - wg.Wait() - - err := s.conn.Close() - <-s.errorChan // wait for serve() to return + var err error + // If the server was started with ListenAddr, we created the packet conn. + // We need to close it in order to make the go routine reading from that conn return. + if s.createdPacketConn { + err = s.conn.Close() + } + s.closed = true + close(s.errorChan) return err } +func (s *server) closeWithError(e error) error { + s.mutex.Lock() + defer s.mutex.Unlock() + if s.closed { + return nil + } + s.serverError = e + return s.closeWithMutex() +} + // Addr returns the server's network address func (s *server) Addr() net.Addr { return s.conn.LocalAddr() } -func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet []byte) error { - rcvTime := time.Now() - - r := bytes.NewReader(packet) - hdr, err := wire.ParseHeaderSentByClient(r) - if err != nil { - return qerr.Error(qerr.InvalidPacketHeader, err.Error()) +func (s *server) handlePacket(p *receivedPacket) { + if err := s.handlePacketImpl(p); err != nil { + s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err) } - hdr.Raw = packet[:len(packet)-r.Len()] - packetData := packet[len(packet)-r.Len():] - connID := hdr.ConnectionID +} - if hdr.Type == protocol.PacketTypeInitial { - if s.supportsTLS { - go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData) +func (s *server) handlePacketImpl(p *receivedPacket) error { + hdr := p.header + + if hdr.VersionFlag || hdr.IsLongHeader { + // send a Version Negotiation Packet if the client is speaking a different protocol version + if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { + return s.sendVersionNegotiationPacket(p) } + } + if hdr.Type == protocol.PacketTypeInitial && hdr.Version.UsesTLS() { + go s.serverTLS.HandleInitial(p) return nil } - s.sessionsMutex.RLock() - session, sessionKnown := s.sessions[connID] - s.sessionsMutex.RUnlock() - - if sessionKnown && session == nil { - // Late packet for closed session - return nil - } - - // ignore all Public Reset packets - if hdr.ResetFlag { - if sessionKnown { - var pr *wire.PublicReset - pr, err = wire.ParsePublicReset(r) - if err != nil { - s.logger.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.", hdr.ConnectionID) - } else { - s.logger.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.RejectedPacketNumber) - } - } else { - s.logger.Infof("Received Public Reset for unknown connection %x.", hdr.ConnectionID) - } - return nil - } - - // If we don't have a session for this connection, and this packet cannot open a new connection, send a Public Reset - // This should only happen after a server restart, when we still receive packets for connections that we lost the state for. - // TODO(#943): implement sending of IETF draft style stateless resets - if !sessionKnown && (!hdr.VersionFlag && hdr.Type != protocol.PacketTypeInitial) { - _, err = pconn.WriteTo(wire.WritePublicReset(connID, 0, 0), remoteAddr) - return err - } - - // a session is only created once the client sent a supported version - // if we receive a packet for a connection that already has session, it's probably an old packet that was sent by the client before the version was negotiated - // it is safe to drop it - if sessionKnown && hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { - return nil - } - - // send a Version Negotiation Packet if the client is speaking a different protocol version - // since the client send a Public Header (only gQUIC has a Version Flag), we need to send a gQUIC Version Negotiation Packet - if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { - // drop packets that are too small to be valid first packets - if len(packet) < protocol.MinClientHelloSize+len(hdr.Raw) { - return errors.New("dropping small packet with unknown version") - } - s.logger.Infof("Client offered version %s, sending Version Negotiation Packet", hdr.Version) - _, err := pconn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.ConnectionID, s.config.Versions), remoteAddr) + // TODO(#943): send Stateless Reset, if this an IETF QUIC packet + if !hdr.VersionFlag && !hdr.Version.UsesIETFHeaderFormat() { + _, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), p.remoteAddr) return err } // This is (potentially) a Client Hello. // Make sure it has the minimum required size before spending any more ressources on it. - if !sessionKnown && len(packet) < protocol.MinClientHelloSize+len(hdr.Raw) { + if len(p.data) < protocol.MinClientHelloSize { return errors.New("dropping small packet for unknown connection") } - if !sessionKnown { - version := hdr.Version - if !protocol.IsSupportedVersion(s.config.Versions, version) { - return errors.New("Server BUG: negotiated version not supported") - } - - s.logger.Infof("Serving new connection: %x, version %s from %v", hdr.ConnectionID, version, remoteAddr) - session, err = s.newSession( - &conn{pconn: pconn, currentAddr: remoteAddr}, - version, - hdr.ConnectionID, - s.scfg, - s.tlsConf, - s.config, - s.logger, - ) - if err != nil { - return err - } - s.sessionsMutex.Lock() - s.sessions[connID] = session - s.sessionsMutex.Unlock() - - s.runHandshakeAndSession(session, connID) + var destConnID, srcConnID protocol.ConnectionID + if hdr.Version.UsesIETFHeaderFormat() { + srcConnID = hdr.DestConnectionID + } else { + destConnID = hdr.DestConnectionID + srcConnID = hdr.DestConnectionID } - session.handlePacket(&receivedPacket{ - remoteAddr: remoteAddr, - header: hdr, - data: packetData, - rcvTime: rcvTime, - }) + s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, hdr.Version, p.remoteAddr) + sess, err := s.newSession( + &conn{pconn: s.conn, currentAddr: p.remoteAddr}, + s.sessionRunner, + hdr.Version, + destConnID, + srcConnID, + s.scfg, + s.tlsConf, + s.config, + s.logger, + ) + if err != nil { + return err + } + s.sessionHandler.Add(hdr.DestConnectionID, newServerSession(sess, s.config, s.logger)) + go sess.run() + sess.handlePacket(p) return nil } -func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.ConnectionID) { - go func() { - _ = session.run() - // session.run() returns as soon as the session is closed - s.removeConnection(connID) - }() +func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error { + hdr := p.header + s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) - go func() { - if err := <-session.handshakeStatus(); err != nil { - return + var data []byte + if hdr.IsPublicHeader { + data = wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions) + } else { + var err error + data, err = wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) + if err != nil { + return err } - s.sessionQueue <- session - }() -} - -func (s *server) removeConnection(id protocol.ConnectionID) { - s.sessionsMutex.Lock() - s.sessions[id] = nil - s.sessionsMutex.Unlock() - - time.AfterFunc(s.deleteClosedSessionsAfter, func() { - s.sessionsMutex.Lock() - delete(s.sessions, id) - s.sessionsMutex.Unlock() - }) + } + _, err := s.conn.WriteTo(data, p.remoteAddr) + return err } diff --git a/vendor/github.com/lucas-clemente/quic-go/server_session.go b/vendor/github.com/lucas-clemente/quic-go/server_session.go new file mode 100644 index 00000000..51743b3a --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/server_session.go @@ -0,0 +1,63 @@ +package quic + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type serverSession struct { + quicSession + + config *Config + + logger utils.Logger +} + +var _ packetHandler = &serverSession{} + +func newServerSession(sess quicSession, config *Config, logger utils.Logger) packetHandler { + return &serverSession{ + quicSession: sess, + config: config, + logger: logger, + } +} + +func (s *serverSession) handlePacket(p *receivedPacket) { + if err := s.handlePacketImpl(p); err != nil { + s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err) + } +} + +func (s *serverSession) handlePacketImpl(p *receivedPacket) error { + hdr := p.header + // ignore all Public Reset packets + if hdr.ResetFlag { + return fmt.Errorf("Received unexpected Public Reset for connection %s", hdr.DestConnectionID) + } + + // Probably an old packet that was sent by the client before the version was negotiated. + // It is safe to drop it. + if (hdr.VersionFlag || hdr.IsLongHeader) && hdr.Version != s.quicSession.GetVersion() { + return nil + } + + if hdr.IsLongHeader { + switch hdr.Type { + case protocol.PacketTypeHandshake, protocol.PacketType0RTT: // 0-RTT accepted for gQUIC 44 + // nothing to do here. Packet will be passed to the session. + default: + // Note that this also drops 0-RTT packets. + return fmt.Errorf("Received unsupported packet type: %s", hdr.Type) + } + } + + s.quicSession.handlePacket(p) + return nil +} + +func (s *serverSession) GetPerspective() protocol.Perspective { + return protocol.PerspectiveServer +} diff --git a/vendor/github.com/lucas-clemente/quic-go/server_tls.go b/vendor/github.com/lucas-clemente/quic-go/server_tls.go index 9f387409..01508df3 100644 --- a/vendor/github.com/lucas-clemente/quic-go/server_tls.go +++ b/vendor/github.com/lucas-clemente/quic-go/server_tls.go @@ -1,48 +1,34 @@ package quic import ( + "bytes" "crypto/tls" "errors" - "fmt" "net" "github.com/bifurcation/mint" - "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/lucas-clemente/quic-go/qerr" ) -type nullAEAD struct { - aead crypto.AEAD -} - -var _ quicAEAD = &nullAEAD{} - -func (n *nullAEAD) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { - return n.aead.Open(dst, src, packetNumber, associatedData) -} - -func (n *nullAEAD) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { - return nil, errors.New("no 1-RTT keys") -} - type tlsSession struct { connID protocol.ConnectionID - sess packetHandler + sess quicSession } type serverTLS struct { - conn net.PacketConn - config *Config - supportedVersions []protocol.VersionNumber - mintConf *mint.Config - params *handshake.TransportParameters - newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) + conn net.PacketConn + config *Config + mintConf *mint.Config + params *handshake.TransportParameters + cookieGenerator *handshake.CookieGenerator - sessionChan chan<- tlsSession + newSession func(connection, sessionRunner, protocol.ConnectionID, protocol.ConnectionID, protocol.ConnectionID, protocol.PacketNumber, *Config, *mint.Config, *handshake.TransportParameters, utils.Logger, protocol.VersionNumber) (quicSession, error) + + sessionRunner sessionRunner + sessionChan chan<- tlsSession logger utils.Logger } @@ -50,45 +36,48 @@ type serverTLS struct { func newServerTLS( conn net.PacketConn, config *Config, - cookieHandler *handshake.CookieHandler, + runner sessionRunner, tlsConf *tls.Config, logger utils.Logger, ) (*serverTLS, <-chan tlsSession, error) { + cookieGenerator, err := handshake.NewCookieGenerator() + if err != nil { + return nil, nil, err + } + params := &handshake.TransportParameters{ + StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, + ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + IdleTimeout: config.IdleTimeout, + MaxBidiStreams: uint16(config.MaxIncomingStreams), + MaxUniStreams: uint16(config.MaxIncomingUniStreams), + DisableMigration: true, + // TODO(#855): generate a real token + StatelessResetToken: bytes.Repeat([]byte{42}, 16), + } mconf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveServer) if err != nil { return nil, nil, err } - mconf.RequireCookie = true - cs, err := mint.NewDefaultCookieProtector() - if err != nil { - return nil, nil, err - } - mconf.CookieProtector = cs - mconf.CookieHandler = cookieHandler sessionChan := make(chan tlsSession) s := &serverTLS{ - conn: conn, - config: config, - supportedVersions: config.Versions, - mintConf: mconf, - sessionChan: sessionChan, - params: &handshake.TransportParameters{ - StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, - ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, - IdleTimeout: config.IdleTimeout, - MaxBidiStreams: uint16(config.MaxIncomingStreams), - MaxUniStreams: uint16(config.MaxIncomingUniStreams), - }, - logger: logger, + conn: conn, + config: config, + mintConf: mconf, + sessionRunner: runner, + sessionChan: sessionChan, + cookieGenerator: cookieGenerator, + params: params, + newSession: newTLSServerSession, + logger: logger, } - s.newMintConn = s.newMintConnImpl return s, sessionChan, nil } -func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []byte) { - s.logger.Debugf("Received a Packet. Handling it statelessly.") - sess, err := s.handleInitialImpl(remoteAddr, hdr, data) +func (s *serverTLS) HandleInitial(p *receivedPacket) { + // TODO: add a check that DestConnID == SrcConnID + s.logger.Debugf("<- Received Initial packet.") + sess, connID, err := s.handleInitialImpl(p) if err != nil { s.logger.Errorf("Error occurred handling initial packet: %s", err) return @@ -97,130 +86,93 @@ func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data [] return } s.sessionChan <- tlsSession{ - connID: hdr.ConnectionID, + connID: connID, sess: sess, } } -// will be set to s.newMintConn by the constructor -func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) { - extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, v, s.logger) - conf := s.mintConf.Clone() - conf.ExtensionHandler = extHandler - return newMintController(bc, conf, protocol.PerspectiveServer), extHandler.GetPeerParams(), nil +func (s *serverTLS) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) { + hdr := p.header + if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { + return nil, nil, errors.New("dropping Initial packet with too short connection ID") + } + if len(hdr.Raw)+len(p.data) < protocol.MinInitialPacketSize { + return nil, nil, errors.New("dropping too small Initial packet") + } + + var cookie *handshake.Cookie + if len(hdr.Token) > 0 { + c, err := s.cookieGenerator.DecodeToken(hdr.Token) + if err == nil { + cookie = c + } + } + if !s.config.AcceptCookie(p.remoteAddr, cookie) { + // Log the Initial packet now. + // If no Retry is sent, the packet will be logged by the session. + p.header.Log(s.logger) + return nil, nil, s.sendRetry(p.remoteAddr, hdr) + } + + extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, hdr.Version, s.logger) + mconf := s.mintConf.Clone() + mconf.ExtensionHandler = extHandler + + // A server is allowed to perform multiple Retries. + // It doesn't make much sense, but it's something that our API allows. + // In that case it must use a source connection ID of at least 8 bytes. + connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) + if err != nil { + return nil, nil, err + } + s.logger.Debugf("Changing connection ID to %s.", connID) + sess, err := s.newSession( + &conn{pconn: s.conn, currentAddr: p.remoteAddr}, + s.sessionRunner, + hdr.DestConnectionID, + hdr.SrcConnectionID, + connID, + 1, + s.config, + mconf, + s.params, + s.logger, + hdr.Version, + ) + if err != nil { + return nil, nil, err + } + go sess.run() + sess.handlePacket(p) + return sess, connID, nil } -func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Header, aead crypto.AEAD, closeErr error) error { - ccf := &wire.ConnectionCloseFrame{ - ErrorCode: qerr.HandshakeFailed, - ReasonPhrase: closeErr.Error(), - } - replyHdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - ConnectionID: clientHdr.ConnectionID, // echo the client's connection ID - PacketNumber: 1, // random packet number - Version: clientHdr.Version, - } - data, err := packUnencryptedPacket(aead, replyHdr, ccf, protocol.PerspectiveServer, s.logger) +func (s *serverTLS) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { + token, err := s.cookieGenerator.NewToken(remoteAddr) if err != nil { return err } - _, err = s.conn.WriteTo(data, remoteAddr) - return err -} - -func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) { - if len(hdr.Raw)+len(data) < protocol.MinInitialPacketSize { - return nil, errors.New("dropping too small Initial packet") - } - // check version, if not matching send VNP - if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) { - s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) - _, err := s.conn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, s.supportedVersions), remoteAddr) - return nil, err - } - - // unpack packet and check stream frame contents - aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, hdr.Version) - if err != nil { - return nil, err - } - frame, err := unpackInitialPacket(aead, hdr, data, s.logger, hdr.Version) - if err != nil { - s.logger.Debugf("Error unpacking initial packet: %s", err) - return nil, nil - } - sess, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead) - if err != nil { - if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil { - s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", ccerr) - } - return nil, err - } - return sess, nil -} - -func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, error) { - version := hdr.Version - bc := handshake.NewCryptoStreamConn(remoteAddr) - bc.AddDataForReading(frame.Data) - tls, paramsChan, err := s.newMintConn(bc, version) - if err != nil { - return nil, err - } - alert := tls.Handshake() - if alert == mint.AlertStatelessRetry { - // the HelloRetryRequest was written to the bufferConn - // Take that data and write send a Retry packet - replyHdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - ConnectionID: hdr.ConnectionID, // echo the client's connection ID - PacketNumber: hdr.PacketNumber, // echo the client's packet number - Version: version, - } - f := &wire.StreamFrame{ - StreamID: version.CryptoStreamID(), - Data: bc.GetDataForWriting(), - } - data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer, s.logger) - if err != nil { - return nil, err - } - _, err = s.conn.WriteTo(data, remoteAddr) - return nil, err - } - if alert != mint.AlertNoAlert { - return nil, alert - } - if tls.State() != mint.StateServerNegotiated { - return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerNegotiated, tls.State()) - } - if alert := tls.Handshake(); alert != mint.AlertNoAlert { - return nil, alert - } - if tls.State() != mint.StateServerWaitFlight2 { - return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State()) - } - params := <-paramsChan - sess, err := newTLSServerSession( - &conn{pconn: s.conn, currentAddr: remoteAddr}, - hdr.ConnectionID, // TODO: we can use a server-chosen connection ID here - protocol.PacketNumber(1), // TODO: use a random packet number here - s.config, - tls, - bc, - aead, - ¶ms, - version, - s.logger, - ) - if err != nil { - return nil, err - } - cs := sess.getCryptoStream() - cs.setReadOffset(frame.DataLen()) - bc.SetStream(cs) - return sess, nil + connID, err := protocol.GenerateConnectionIDForInitial() + if err != nil { + return err + } + replyHdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + Version: hdr.Version, + SrcConnectionID: connID, + DestConnectionID: hdr.SrcConnectionID, + OrigDestConnectionID: hdr.DestConnectionID, + Token: token, + } + s.logger.Debugf("Changing connection ID to %s.\n-> Sending Retry", connID) + replyHdr.Log(s.logger) + buf := &bytes.Buffer{} + if err := replyHdr.Write(buf, protocol.PerspectiveServer, hdr.Version); err != nil { + return err + } + if _, err := s.conn.WriteTo(buf.Bytes(), remoteAddr); err != nil { + s.logger.Debugf("Error sending Retry: %s", err) + } + return nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/session.go b/vendor/github.com/lucas-clemente/quic-go/session.go index ed73637c..81422d1f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/session.go +++ b/vendor/github.com/lucas-clemente/quic-go/session.go @@ -10,9 +10,9 @@ import ( "sync" "time" + "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/ackhandler" "github.com/lucas-clemente/quic-go/internal/congestion" - "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -50,6 +50,10 @@ type cryptoStreamHandler interface { ConnectionState() handshake.ConnectionState } +type divNonceSetter interface { + SetDiversificationNonce([]byte) error +} + type receivedPacket struct { remoteAddr net.Addr header *wire.Header @@ -63,21 +67,26 @@ var ( ) type closeError struct { - err error - remote bool + err error + remote bool + sendClose bool } // A Session is a QUIC session type session struct { - connectionID protocol.ConnectionID - perspective protocol.Perspective - version protocol.VersionNumber - config *Config + sessionRunner sessionRunner + + destConnID protocol.ConnectionID + srcConnID protocol.ConnectionID + + perspective protocol.Perspective + version protocol.VersionNumber + config *Config conn connection streamsMap streamManager - cryptoStream cryptoStreamI + cryptoStream cryptoStream rttStats *congestion.RTTStats @@ -91,7 +100,6 @@ type session struct { packer *packetPacker cryptoStreamHandler cryptoStreamHandler - divNonceChan chan<- []byte // only set for the client receivedPackets chan *receivedPacket sendingScheduled chan struct{} @@ -111,11 +119,7 @@ type session struct { paramsChan <-chan handshake.TransportParameters // the handshakeEvent channel is passed to the CryptoSetup. // It receives when it makes sense to try decrypting undecryptable packets. - handshakeEvent <-chan struct{} - // handshakeChan is returned by handshakeStatus. - // It receives any error that might occur during the handshake. - // It is closed when the handshake is complete. - handshakeChan chan error + handshakeEvent <-chan struct{} handshakeComplete bool receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this @@ -146,18 +150,23 @@ var _ streamSender = &session{} // newSession makes a new session func newSession( conn connection, + sessionRunner sessionRunner, v protocol.VersionNumber, - connectionID protocol.ConnectionID, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, scfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config, logger utils.Logger, -) (packetHandler, error) { +) (quicSession, error) { + logger.Debugf("Creating new session. Destination Connection ID: %s, Source Connection ID: %s", destConnID, srcConnID) paramsChan := make(chan handshake.TransportParameters) handshakeEvent := make(chan struct{}, 1) s := &session{ conn: conn, - connectionID: connectionID, + sessionRunner: sessionRunner, + srcConnID: srcConnID, + destConnID: destConnID, perspective: protocol.PerspectiveServer, version: v, config: config, @@ -178,7 +187,7 @@ func newSession( } cs, err := newCryptoSetup( s.cryptoStream, - s.connectionID, + srcConnID, s.conn.RemoteAddr(), s.version, divNonce, @@ -197,10 +206,13 @@ func newSession( s.unpacker = newPacketUnpackerGQUIC(cs, s.version) s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective) s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) - s.packer = newPacketPacker(s.connectionID, + s.packer = newPacketPacker( + destConnID, + srcConnID, 1, s.sentPacketHandler.GetPacketNumberLen, s.RemoteAddr(), + nil, // no token divNonce, cs, s.streamFramer, @@ -213,20 +225,25 @@ func newSession( // declare this as a variable, so that we can it mock it in the tests var newClientSession = func( conn connection, + sessionRunner sessionRunner, hostname string, v protocol.VersionNumber, - connectionID protocol.ConnectionID, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, tlsConf *tls.Config, config *Config, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiation logger utils.Logger, -) (packetHandler, error) { +) (quicSession, error) { + logger.Debugf("Creating new session. Destination Connection ID: %s, Source Connection ID: %s", destConnID, srcConnID) paramsChan := make(chan handshake.TransportParameters) handshakeEvent := make(chan struct{}, 1) s := &session{ conn: conn, - connectionID: connectionID, + sessionRunner: sessionRunner, + srcConnID: srcConnID, + destConnID: destConnID, perspective: protocol.PerspectiveClient, version: v, config: config, @@ -242,10 +259,10 @@ var newClientSession = func( IdleTimeout: s.config.IdleTimeout, OmitConnectionID: s.config.RequestConnectionIDOmission, } - cs, divNonceChan, err := newCryptoSetupClient( + cs, err := newCryptoSetupClient( s.cryptoStream, hostname, - s.connectionID, + destConnID, s.version, tlsConf, transportParams, @@ -259,14 +276,16 @@ var newClientSession = func( return nil, err } s.cryptoStreamHandler = cs - s.divNonceChan = divNonceChan s.unpacker = newPacketUnpackerGQUIC(cs, s.version) s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective) s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) - s.packer = newPacketPacker(s.connectionID, + s.packer = newPacketPacker( + destConnID, + srcConnID, 1, s.sentPacketHandler.GetPacketNumberLen, s.RemoteAddr(), + nil, // no token nil, // no diversification nonce cs, s.streamFramer, @@ -278,41 +297,50 @@ var newClientSession = func( func newTLSServerSession( conn connection, - connectionID protocol.ConnectionID, + runner sessionRunner, + origConnID protocol.ConnectionID, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, initialPacketNumber protocol.PacketNumber, config *Config, - tls handshake.MintTLS, - cryptoStreamConn *handshake.CryptoStreamConn, - nullAEAD crypto.AEAD, + mintConf *mint.Config, peerParams *handshake.TransportParameters, - v protocol.VersionNumber, logger utils.Logger, -) (packetHandler, error) { + v protocol.VersionNumber, +) (quicSession, error) { handshakeEvent := make(chan struct{}, 1) s := &session{ conn: conn, + sessionRunner: runner, config: config, - connectionID: connectionID, + srcConnID: srcConnID, + destConnID: destConnID, perspective: protocol.PerspectiveServer, version: v, handshakeEvent: handshakeEvent, logger: logger, } s.preSetup() - cs := handshake.NewCryptoSetupTLSServer( - tls, - cryptoStreamConn, - nullAEAD, + cs, err := handshake.NewCryptoSetupTLSServer( + s.cryptoStream, + origConnID, + mintConf, handshakeEvent, v, ) + if err != nil { + return nil, err + } s.cryptoStreamHandler = cs s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version) s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) - s.packer = newPacketPacker(s.connectionID, + s.packer = newPacketPacker( + s.destConnID, + s.srcConnID, initialPacketNumber, s.sentPacketHandler.GetPacketNumberLen, s.RemoteAddr(), + nil, // no token nil, // no diversification nonce cs, s.streamFramer, @@ -331,20 +359,24 @@ func newTLSServerSession( // declare this as a variable, such that we can it mock it in the tests var newTLSClientSession = func( conn connection, - hostname string, - v protocol.VersionNumber, - connectionID protocol.ConnectionID, - config *Config, - tls handshake.MintTLS, + runner sessionRunner, + token []byte, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, + conf *Config, + mintConf *mint.Config, paramsChan <-chan handshake.TransportParameters, initialPacketNumber protocol.PacketNumber, logger utils.Logger, -) (packetHandler, error) { + v protocol.VersionNumber, +) (quicSession, error) { handshakeEvent := make(chan struct{}, 1) s := &session{ conn: conn, - config: config, - connectionID: connectionID, + sessionRunner: runner, + config: conf, + srcConnID: srcConnID, + destConnID: destConnID, perspective: protocol.PerspectiveClient, version: v, handshakeEvent: handshakeEvent, @@ -352,13 +384,11 @@ var newTLSClientSession = func( logger: logger, } s.preSetup() - tls.SetCryptoStream(s.cryptoStream) cs, err := handshake.NewCryptoSetupTLSClient( s.cryptoStream, - s.connectionID, - hostname, + s.destConnID, + mintConf, handshakeEvent, - tls, v, ) if err != nil { @@ -368,10 +398,13 @@ var newTLSClientSession = func( s.unpacker = newPacketUnpacker(cs, s.version) s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version) s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) - s.packer = newPacketPacker(s.connectionID, + s.packer = newPacketPacker( + s.destConnID, + s.srcConnID, initialPacketNumber, s.sentPacketHandler.GetPacketNumberLen, s.RemoteAddr(), + token, nil, // no diversification nonce cs, s.streamFramer, @@ -383,10 +416,11 @@ var newTLSClientSession = func( func (s *session) preSetup() { s.rttStats = &congestion.RTTStats{} - s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger) + s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger, s.version) s.connFlowController = flowcontrol.NewConnectionFlowController( protocol.ReceiveConnectionFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), + s.onHasConnectionWindowUpdate, s.rttStats, s.logger, ) @@ -394,7 +428,6 @@ func (s *session) preSetup() { } func (s *session) postSetup() error { - s.handshakeChan = make(chan error, 1) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.closeChan = make(chan closeError, 1) s.sendingScheduled = make(chan struct{}, 1) @@ -406,8 +439,8 @@ func (s *session) postSetup() error { s.lastNetworkActivityTime = now s.sessionCreationTime = now - s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.version) - s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.cryptoStream, s.packer.QueueControlFrame) + s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version) + s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.cryptoStream, s.connFlowController, s.packer.QueueControlFrame) return nil } @@ -417,7 +450,7 @@ func (s *session) run() error { go func() { if err := s.cryptoStreamHandler.HandleCryptoStream(); err != nil { - s.Close(err) + s.closeLocal(err) } }() @@ -482,7 +515,8 @@ runLoop: pacingDeadline = s.sentPacketHandler.TimeUntilSend() } if s.config.KeepAlive && !s.keepAlivePingSent && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.peerParams.IdleTimeout/2 { - // send the PING frame since there is no activity in the session + // send a PING frame since there is no activity in the session + s.logger.Debugf("Sending a keep-alive ping to keep the connection alive.") s.packer.QueueControlFrame(&wire.PingFrame{}) s.keepAlivePingSent = true } else if !pacingDeadline.IsZero() && now.Before(pacingDeadline) { @@ -508,12 +542,11 @@ runLoop: } } - // only send the error the handshakeChan when the handshake is not completed yet - // otherwise this chan will already be closed - if !s.handshakeComplete { - s.handshakeChan <- closeErr.err + if err := s.handleCloseError(closeErr); err != nil { + s.logger.Infof("Handling close error failed: %s", err) } - s.handleCloseError(closeErr) + s.logger.Infof("Connection %s closed.", s.srcConnID) + s.sessionRunner.removeConnectionID(s.srcConnID) return closeErr.err } @@ -560,20 +593,35 @@ func (s *session) handleHandshakeEvent(completed bool) { } s.handshakeComplete = true s.handshakeEvent = nil // prevent this case from ever being selected again - if !s.version.UsesTLS() && s.perspective == protocol.PerspectiveClient { - // In gQUIC, there's no equivalent to the Finished message in TLS - // The server knows that the handshake is complete when it receives the first forward-secure packet sent by the client. - // We need to make sure that the client actually sends such a packet. - s.packer.QueueControlFrame(&wire.PingFrame{}) - s.scheduleSending() + s.sessionRunner.onHandshakeComplete(s) + + // In gQUIC, the server completes the handshake first (after sending the SHLO). + // In TLS 1.3, the client completes the handshake first (after sending the CFIN). + // We need to make sure they learn about the peer completing the handshake, + // in order to stop retransmitting handshake packets. + // They will stop retransmitting handshake packets when receiving the first forward-secure packet. + // We need to make sure that a retransmittable forward-secure packet is sent, + // independent from the application protocol. + if (!s.version.UsesTLS() && s.perspective == protocol.PerspectiveClient) || + (s.version.UsesTLS() && s.perspective == protocol.PerspectiveServer) { + s.queueControlFrame(&wire.PingFrame{}) + s.sentPacketHandler.SetHandshakeComplete() } - close(s.handshakeChan) } func (s *session) handlePacketImpl(p *receivedPacket) error { + hdr := p.header + // The server can change the source connection ID with the first Handshake packet. + // After this, all packets with a different source connection have to be ignored. + if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) { + s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.header.SrcConnectionID, s.destConnID) + return nil + } if s.perspective == protocol.PerspectiveClient { if divNonce := p.header.DiversificationNonce; len(divNonce) > 0 { - s.divNonceChan <- divNonce + if err := s.cryptoStreamHandler.(divNonceSetter).SetDiversificationNonce(divNonce); err != nil { + return err + } } } @@ -582,25 +630,20 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { p.rcvTime = time.Now() } - s.receivedFirstPacket = true - s.lastNetworkActivityTime = p.rcvTime - s.keepAlivePingSent = false - hdr := p.header - data := p.data - // Calculate packet number hdr.PacketNumber = protocol.InferPacketNumber( hdr.PacketNumberLen, s.largestRcvdPacketNumber, hdr.PacketNumber, + s.version, ) - packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data) + packet, err := s.unpacker.Unpack(hdr.Raw, hdr, p.data) if s.logger.Debug() { if err != nil { - s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID) + s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(p.data)+len(hdr.Raw), hdr.DestConnectionID) } else { - s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID, packet.encryptionLevel) + s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s, %s", hdr.PacketNumber, len(p.data)+len(hdr.Raw), hdr.DestConnectionID, packet.encryptionLevel) } hdr.Log(s.logger) } @@ -609,13 +652,26 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { return err } - // In TLS 1.3, the client considers the handshake complete as soon as - // it received the server's Finished message and sent its Finished. - // We have to wait for the first forward-secure packet from the server before - // deleting all handshake packets from the history. - if !s.receivedFirstForwardSecurePacket && packet.encryptionLevel == protocol.EncryptionForwardSecure { - s.receivedFirstForwardSecurePacket = true - s.sentPacketHandler.SetHandshakeComplete() + // The server can change the source connection ID with the first Handshake packet. + if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) { + s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", hdr.SrcConnectionID) + s.destConnID = hdr.SrcConnectionID + s.packer.ChangeDestConnectionID(s.destConnID) + } + + s.receivedFirstPacket = true + s.lastNetworkActivityTime = p.rcvTime + s.keepAlivePingSent = false + + // In gQUIC, the server completes the handshake first (after sending the SHLO). + // In TLS 1.3, the client completes the handshake first (after sending the CFIN). + // We know that the peer completed the handshake as soon as we receive a forward-secure packet. + if (!s.version.UsesTLS() && s.perspective == protocol.PerspectiveServer) || + (s.version.UsesTLS() && s.perspective == protocol.PerspectiveClient) { + if !s.receivedFirstForwardSecurePacket && packet.encryptionLevel == protocol.EncryptionForwardSecure { + s.receivedFirstForwardSecurePacket = true + s.sentPacketHandler.SetHandshakeComplete() + } } s.lastRcvdPacketNumber = hdr.PacketNumber @@ -662,6 +718,11 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve case *wire.StopSendingFrame: err = s.handleStopSendingFrame(frame) case *wire.PingFrame: + case *wire.PathChallengeFrame: + s.handlePathChallengeFrame(frame) + case *wire.PathResponseFrame: + // since we don't send PATH_CHALLENGEs, we don't expect PATH_RESPONSEs + err = errors.New("unexpected PATH_RESPONSE frame") default: return errors.New("Session BUG: unexpected frame type") } @@ -760,6 +821,10 @@ func (s *session) handleStopSendingFrame(frame *wire.StopSendingFrame) error { return nil } +func (s *session) handlePathChallengeFrame(frame *wire.PathChallengeFrame) { + s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data}) +} + func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil { return err @@ -768,9 +833,17 @@ func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encrypt return nil } +// closeLocal closes the session and send a CONNECTION_CLOSE containing the error func (s *session) closeLocal(e error) { s.closeOnce.Do(func() { - s.closeChan <- closeError{err: e, remote: false} + s.closeChan <- closeError{err: e, sendClose: true, remote: false} + }) +} + +// destroy closes the session without sending the error on the wire +func (s *session) destroy(e error) { + s.closeOnce.Do(func() { + s.closeChan <- closeError{err: e, sendClose: false, remote: false} }) } @@ -780,10 +853,16 @@ func (s *session) closeRemote(e error) { }) } -// Close the connection. If err is nil it will be set to qerr.PeerGoingAway. +// Close the connection. It sends a qerr.PeerGoingAway. // It waits until the run loop has stopped before returning -func (s *session) Close(e error) error { - s.closeLocal(e) +func (s *session) Close() error { + s.closeLocal(nil) + <-s.ctx.Done() + return nil +} + +func (s *session) CloseWithError(code protocol.ApplicationErrorCode, e error) error { + s.closeLocal(qerr.Error(qerr.ErrorCode(code), e.Error())) <-s.ctx.Done() return nil } @@ -800,7 +879,7 @@ func (s *session) handleCloseError(closeErr closeError) error { } // Don't log 'normal' reasons if quicErr.ErrorCode == qerr.PeerGoingAway || quicErr.ErrorCode == qerr.NetworkIdleTimeout { - s.logger.Infof("Closing connection %x", s.connectionID) + s.logger.Infof("Closing connection %s.", s.srcConnID) } else { s.logger.Errorf("Closing session with error: %s", closeErr.err.Error()) } @@ -808,7 +887,7 @@ func (s *session) handleCloseError(closeErr closeError) error { s.cryptoStream.closeForShutdown(quicErr) s.streamsMap.CloseWithError(quicErr) - if closeErr.err == errCloseSessionForNewVersion || closeErr.err == handshake.ErrCloseSessionForRetry { + if !closeErr.sendClose { return nil } @@ -818,7 +897,6 @@ func (s *session) handleCloseError(closeErr closeError) error { } if quicErr.ErrorCode == qerr.DecryptionFailure || - quicErr == handshake.ErrHOLExperiment || quicErr == handshake.ErrNSTPExperiment { return s.sendPublicReset(s.lastRcvdPacketNumber) } @@ -859,24 +937,10 @@ sendLoop: // There will only be a new ACK after receiving new packets. // SendAck is only returned when we're congestion limited, so we don't need to set the pacingt timer. return s.maybeSendAckOnlyPacket() - case ackhandler.SendRTO: - // try to send a retransmission first - sentPacket, err := s.maybeSendRetransmission() - if err != nil { + case ackhandler.SendTLP, ackhandler.SendRTO: + if err := s.sendProbePacket(); err != nil { return err } - if !sentPacket { - // In RTO mode, a probe packet has to be sent. - // Add a PING frame to make sure a (retransmittable) packet will be sent. - s.queueControlFrame(&wire.PingFrame{}) - sentPacket, err := s.sendPacket() - if err != nil { - return err - } - if !sentPacket { - return errors.New("session BUG: expected a packet to be sent in RTO mode") - } - } numPacketsSent++ case ackhandler.SendRetransmission: sentPacket, err := s.maybeSendRetransmission() @@ -954,9 +1018,9 @@ func (s *session) maybeSendRetransmission() (bool, error) { } if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure { - s.logger.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) + s.logger.Debugf("Dequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) } else { - s.logger.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) + s.logger.Debugf("Dequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) } if s.version.UsesStopWaitingFrames() { @@ -979,10 +1043,34 @@ func (s *session) maybeSendRetransmission() (bool, error) { return true, nil } -func (s *session) sendPacket() (bool, error) { - if offset := s.connFlowController.GetWindowUpdate(); offset != 0 { - s.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: offset}) +func (s *session) sendProbePacket() error { + p, err := s.sentPacketHandler.DequeueProbePacket() + if err != nil { + return err } + s.logger.Debugf("Sending a retransmission for %#x as a probe packet.", p.PacketNumber) + + if s.version.UsesStopWaitingFrames() { + s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true)) + } + packets, err := s.packer.PackRetransmission(p) + if err != nil { + return err + } + ackhandlerPackets := make([]*ackhandler.Packet, len(packets)) + for i, packet := range packets { + ackhandlerPackets[i] = packet.ToAckHandlerPacket() + } + s.sentPacketHandler.SentPacketsAsRetransmission(ackhandlerPackets, p.PacketNumber) + for _, packet := range packets { + if err := s.sendPackedPacket(packet); err != nil { + return err + } + } + return nil +} + +func (s *session) sendPacket() (bool, error) { if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { s.packer.QueueControlFrame(&wire.BlockedFrame{Offset: offset}) } @@ -1031,7 +1119,7 @@ func (s *session) logPacket(packet *packedPacket) { // We don't need to allocate the slices for calling the format functions return } - s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", packet.header.PacketNumber, len(packet.raw), s.connectionID, packet.encryptionLevel) + s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, len(packet.raw), s.srcConnID, packet.encryptionLevel) packet.header.Log(s.logger) for _, frame := range packet.frames { wire.LogFrame(s.logger, frame, true) @@ -1096,12 +1184,13 @@ func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlow protocol.ReceiveStreamFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), initialSendWindow, + s.onHasStreamWindowUpdate, s.rttStats, s.logger, ) } -func (s *session) newCryptoStream() cryptoStreamI { +func (s *session) newCryptoStream() cryptoStream { id := s.version.CryptoStreamID() flowController := flowcontrol.NewStreamFlowController( id, @@ -1110,6 +1199,7 @@ func (s *session) newCryptoStream() cryptoStreamI { protocol.ReceiveStreamFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), 0, + s.onHasStreamWindowUpdate, s.rttStats, s.logger, ) @@ -1117,8 +1207,8 @@ func (s *session) newCryptoStream() cryptoStreamI { } func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { - s.logger.Infof("Sending public reset for connection %x, packet number %d", s.connectionID, rejectedPacketNumber) - return s.conn.Write(wire.WritePublicReset(s.connectionID, rejectedPacketNumber, 0)) + s.logger.Infof("Sending PUBLIC_RESET for connection %s, packet number %d", s.destConnID, rejectedPacketNumber) + return s.conn.Write(wire.WritePublicReset(s.destConnID, rejectedPacketNumber, 0)) } // scheduleSending signals that we have data for sending @@ -1159,8 +1249,13 @@ func (s *session) queueControlFrame(f wire.Frame) { s.scheduleSending() } -func (s *session) onHasWindowUpdate(id protocol.StreamID) { - s.windowUpdateQueue.Add(id) +func (s *session) onHasStreamWindowUpdate(id protocol.StreamID) { + s.windowUpdateQueue.AddStream(id) + s.scheduleSending() +} + +func (s *session) onHasConnectionWindowUpdate() { + s.windowUpdateQueue.AddConnection() s.scheduleSending() } @@ -1171,7 +1266,7 @@ func (s *session) onHasStreamData(id protocol.StreamID) { func (s *session) onStreamCompleted(id protocol.StreamID) { if err := s.streamsMap.DeleteStream(id); err != nil { - s.Close(err) + s.closeLocal(err) } } @@ -1183,14 +1278,6 @@ func (s *session) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } -func (s *session) handshakeStatus() <-chan error { - return s.handshakeChan -} - -func (s *session) getCryptoStream() cryptoStreamI { - return s.cryptoStream -} - func (s *session) GetVersion() protocol.VersionNumber { return s.version } diff --git a/vendor/github.com/lucas-clemente/quic-go/stream.go b/vendor/github.com/lucas-clemente/quic-go/stream.go index 83123493..5d6ce671 100644 --- a/vendor/github.com/lucas-clemente/quic-go/stream.go +++ b/vendor/github.com/lucas-clemente/quic-go/stream.go @@ -18,8 +18,8 @@ const ( // The streamSender is notified by the stream about various events. type streamSender interface { queueControlFrame(wire.Frame) - onHasWindowUpdate(protocol.StreamID) onHasStreamData(protocol.StreamID) + // must be called without holding the mutex that is acquired by closeForShutdown onStreamCompleted(protocol.StreamID) } @@ -34,10 +34,6 @@ func (s *uniStreamSender) queueControlFrame(f wire.Frame) { s.streamSender.queueControlFrame(f) } -func (s *uniStreamSender) onHasWindowUpdate(id protocol.StreamID) { - s.streamSender.onHasWindowUpdate(id) -} - func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) { s.streamSender.onHasStreamData(id) } @@ -105,7 +101,7 @@ func newStream(streamID protocol.StreamID, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber, ) *stream { - s := &stream{sender: sender} + s := &stream{sender: sender, version: version} senderForSendStream := &uniStreamSender{ streamSender: sender, onStreamCompletedImpl: func() { diff --git a/vendor/github.com/lucas-clemente/quic-go/stream_framer.go b/vendor/github.com/lucas-clemente/quic-go/stream_framer.go index c453f864..aabfac9f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/stream_framer.go +++ b/vendor/github.com/lucas-clemente/quic-go/stream_framer.go @@ -9,7 +9,7 @@ import ( type streamFramer struct { streamGetter streamGetter - cryptoStream cryptoStreamI + cryptoStream cryptoStream version protocol.VersionNumber streamQueueMutex sync.Mutex @@ -19,7 +19,7 @@ type streamFramer struct { } func newStreamFramer( - cryptoStream cryptoStreamI, + cryptoStream cryptoStream, streamGetter streamGetter, v protocol.VersionNumber, ) *streamFramer { diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/client-state-machine.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/client-state-machine.go index ffca45ef..07e7f53f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/client-state-machine.go +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/client-state-machine.go @@ -58,7 +58,7 @@ type clientStateStart struct { cookie []byte firstClientHello *HandshakeMessage helloRetryRequest *HandshakeMessage - hsCtx HandshakeContext + hsCtx *HandshakeContext } var _ HandshakeState = &clientStateStart{} @@ -172,8 +172,10 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ } ch.CipherSuites = compatibleSuites + // TODO(ekr@rtfm.com): Check that the ticket can be used for early + // data. // Signal early data if we're going to do it - if len(state.Opts.EarlyData) > 0 { + if state.Config.AllowEarlyData && state.helloRetryRequest == nil { state.Params.ClientSendingEarlyData = true ed = &EarlyDataExtension{} err = ch.Extensions.Add(ed) @@ -255,9 +257,6 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret) clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret) - } else if len(state.Opts.EarlyData) > 0 { - logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK") - return nil, nil, AlertInternalError } else { clientHello, err = state.hsCtx.hOut.HandshakeMessageFromBody(ch) if err != nil { @@ -291,7 +290,6 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ if state.Params.ClientSendingEarlyData { toSend = append(toSend, []HandshakeAction{ RekeyOut{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys}, - SendEarlyData{}, }...) } @@ -302,7 +300,7 @@ type clientStateWaitSH struct { Config *Config Opts ConnectionOptions Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext OfferedDH map[NamedGroup][]byte OfferedPSK PreSharedKey PSK []byte @@ -412,6 +410,11 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, body: h.Sum(nil), } + state.hsCtx.receivedEndOfFlight() + + // TODO(ekr@rtfm.com): Need to rekey with cleartext if we are on 0-RTT + // mode. In DTLS, we also need to bump the sequence number. + // This is a pre-existing defect in Mint. Issue #175. logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]") return clientStateStart{ Config: state.Config, @@ -420,7 +423,7 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, cookie: serverCookie.Cookie, firstClientHello: firstClientHello, helloRetryRequest: hm, - }, nil, AlertNoAlert + }, []HandshakeAction{ResetOut{1}}, AlertNoAlert } // This is SH. @@ -515,7 +518,6 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) - logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]") nextState := clientStateWaitEE{ Config: state.Config, @@ -530,13 +532,20 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, toSend := []HandshakeAction{ RekeyIn{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys}, } + // We're definitely not going to have to send anything with + // early data. + if !state.Params.ClientSendingEarlyData { + toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, + KeySet: makeTrafficKeys(params, clientHandshakeTrafficSecret)}) + } + return nextState, toSend, AlertNoAlert } type clientStateWaitEE struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash masterSecret []byte @@ -596,6 +605,14 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, state.handshakeHash.Write(hm.Marshal()) + toSend := []HandshakeAction{} + + if state.Params.ClientSendingEarlyData && !state.Params.UsingEarlyData { + // We didn't get 0-RTT, so rekey to handshake. + toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, + KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)}) + } + if state.Params.UsingPSK { logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]") nextState := clientStateWaitFinished{ @@ -608,7 +625,7 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, } - return nextState, nil, AlertNoAlert + return nextState, toSend, AlertNoAlert } logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]") @@ -622,13 +639,13 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, } - return nextState, nil, AlertNoAlert + return nextState, toSend, AlertNoAlert } type clientStateWaitCertCR struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash masterSecret []byte @@ -706,7 +723,7 @@ func (state clientStateWaitCertCR) Next(hr handshakeMessageReader) (HandshakeSta type clientStateWaitCert struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash @@ -760,7 +777,7 @@ func (state clientStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState type clientStateWaitCV struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash @@ -861,7 +878,7 @@ func (state clientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, type clientStateWaitFinished struct { Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash @@ -933,6 +950,7 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS toSend := []HandshakeAction{} if state.Params.UsingEarlyData { + logf(logTypeHandshake, "Sending end of early data") // Note: We only send EOED if the server is actually going to use the early // data. Otherwise, it will never see it, and the transcripts will // mismatch. @@ -942,10 +960,11 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS state.handshakeHash.Write(eoedm.Marshal()) logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal()) - } - clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret) - toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys}) + // And then rekey to handshake + toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, + KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)}) + } if state.Params.UsingClientAuth { // Extract constraints from certicateRequest @@ -1045,6 +1064,8 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS RekeyOut{epoch: EpochApplicationData, KeySet: clientTrafficKeys}, }...) + state.hsCtx.receivedEndOfFlight() + logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]") nextState := stateConnected{ Params: state.Params, diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/common.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/common.go index 0fdba602..05af3e95 100644 --- a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/common.go +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/common.go @@ -25,6 +25,7 @@ const ( RecordTypeAlert RecordType = 21 RecordTypeHandshake RecordType = 22 RecordTypeApplicationData RecordType = 23 + RecordTypeAck RecordType = 25 ) // enum {...} HandshakeType; @@ -166,6 +167,8 @@ const ( type State uint8 const ( + StateInit = 0 + // states valid for the client StateClientStart State = iota StateClientWaitSH @@ -179,6 +182,7 @@ const ( StateServerStart State = iota StateServerRecvdCH StateServerNegotiated + StateServerReadPastEarlyData StateServerWaitEOED StateServerWaitFlight2 StateServerWaitCert @@ -211,6 +215,8 @@ func (s State) String() string { return "Server RECVD_CH" case StateServerNegotiated: return "Server NEGOTIATED" + case StateServerReadPastEarlyData: + return "Server READ_PAST_EARLY_DATA" case StateServerWaitEOED: return "Server WAIT_EOED" case StateServerWaitFlight2: @@ -252,3 +258,9 @@ func (e Epoch) label() string { } return "Application data (updated)" } + +func assert(b bool) { + if !b { + panic("Assertion failed") + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/conn.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/conn.go index 0ce05b2a..12a99171 100644 --- a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/conn.go +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/conn.go @@ -13,8 +13,6 @@ import ( "time" ) -var WouldBlock = fmt.Errorf("Would have blocked") - type Certificate struct { Chain []*x509.Certificate PrivateKey crypto.Signer @@ -253,6 +251,8 @@ type ConnectionState struct { PeerCertificates []*x509.Certificate // certificate chain presented by remote peer VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates NextProto string // Selected ALPN proto + UsingPSK bool // Are we using PSK. + UsingEarlyData bool // Did we negotiate 0-RTT. } // Conn implements the net.Conn interface, as with "crypto/tls" @@ -263,8 +263,6 @@ type Conn struct { conn net.Conn isClient bool - EarlyData []byte - state stateConnected hState HandshakeState handshakeMutex sync.Mutex @@ -273,22 +271,27 @@ type Conn struct { readBuffer []byte in, out *RecordLayer - hsCtx HandshakeContext + hsCtx *HandshakeContext } func NewConn(conn net.Conn, config *Config, isClient bool) *Conn { - c := &Conn{conn: conn, config: config, isClient: isClient} + c := &Conn{conn: conn, config: config, isClient: isClient, hsCtx: &HandshakeContext{}} if !config.UseDTLS { - c.in = NewRecordLayerTLS(c.conn) - c.out = NewRecordLayerTLS(c.conn) - c.hsCtx.hIn = NewHandshakeLayerTLS(c.in) - c.hsCtx.hOut = NewHandshakeLayerTLS(c.out) + c.in = NewRecordLayerTLS(c.conn, directionRead) + c.out = NewRecordLayerTLS(c.conn, directionWrite) + c.hsCtx.hIn = NewHandshakeLayerTLS(c.hsCtx, c.in) + c.hsCtx.hOut = NewHandshakeLayerTLS(c.hsCtx, c.out) } else { - c.in = NewRecordLayerDTLS(c.conn) - c.out = NewRecordLayerDTLS(c.conn) - c.hsCtx.hIn = NewHandshakeLayerDTLS(c.in) - c.hsCtx.hOut = NewHandshakeLayerDTLS(c.out) + c.in = NewRecordLayerDTLS(c.conn, directionRead) + c.out = NewRecordLayerDTLS(c.conn, directionWrite) + c.hsCtx.hIn = NewHandshakeLayerDTLS(c.hsCtx, c.in) + c.hsCtx.hOut = NewHandshakeLayerDTLS(c.hsCtx, c.out) + c.hsCtx.timeoutMS = initialTimeout + c.hsCtx.timers = newTimerSet() + c.hsCtx.waitingNextFlight = true } + c.in.label = c.label() + c.out.label = c.label() c.hsCtx.hIn.nonblocking = c.config.NonBlocking return c } @@ -374,20 +377,54 @@ func (c *Conn) consumeRecord() error { return io.EOF } + case RecordTypeAck: + if !c.hsCtx.hIn.datagram { + logf(logTypeHandshake, "Received ACK in TLS mode") + return AlertUnexpectedMessage + } + return c.hsCtx.processAck(pt.fragment) + case RecordTypeApplicationData: c.readBuffer = append(c.readBuffer, pt.fragment...) logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer) + } return err } +func readPartial(in *[]byte, buffer []byte) int { + logf(logTypeIO, "conn.Read input buffer now has len %d", len((*in))) + read := copy(buffer, *in) + *in = (*in)[read:] + + logf(logTypeVerbose, "Returning %v", string(buffer)) + return read +} + // Read application data up to the size of buffer. Handshake and alert records // are consumed by the Conn object directly. func (c *Conn) Read(buffer []byte) (int, error) { if _, connected := c.hState.(stateConnected); !connected { - return 0, errors.New("Read called before the handshake completed") + // Clients can't call Read prior to handshake completion. + if c.isClient { + return 0, errors.New("Read called before the handshake completed") + } + + // Neither can servers that don't allow early data. + if !c.config.AllowEarlyData { + return 0, errors.New("Read called before the handshake completed") + } + + // If there's no early data, then return WouldBlock + if len(c.hsCtx.earlyData) == 0 { + return 0, AlertWouldBlock + } + + return readPartial(&c.hsCtx.earlyData, buffer), nil } + + // The handshake is now connected. logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer)) if alert := c.Handshake(); alert != AlertNoAlert { return 0, alert @@ -397,6 +434,13 @@ func (c *Conn) Read(buffer []byte) (int, error) { return 0, nil } + // Run our timers. + if c.config.UseDTLS { + if err := c.hsCtx.timers.check(time.Now()); err != nil { + return 0, AlertInternalError + } + } + // Lock the input channel c.in.Lock() defer c.in.Unlock() @@ -406,30 +450,14 @@ func (c *Conn) Read(buffer []byte) (int, error) { // err can be nil if consumeRecord processed a non app-data // record. if err != nil { - if c.config.NonBlocking || err != WouldBlock { + if c.config.NonBlocking || err != AlertWouldBlock { logf(logTypeIO, "conn.Read returns err=%v", err) return 0, err } } } - var read int - n := len(buffer) - logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer)) - if len(c.readBuffer) <= n { - buffer = buffer[:len(c.readBuffer)] - copy(buffer, c.readBuffer) - read = len(c.readBuffer) - c.readBuffer = c.readBuffer[:0] - } else { - logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n) - copy(buffer[:n], c.readBuffer[:n]) - c.readBuffer = c.readBuffer[n:] - read = n - } - - logf(logTypeVerbose, "Returning %v", string(buffer)) - return read, nil + return readPartial(&c.readBuffer, buffer), nil } // Write application data @@ -438,6 +466,10 @@ func (c *Conn) Write(buffer []byte) (int, error) { c.out.Lock() defer c.out.Unlock() + if !c.Writable() { + return 0, errors.New("Write called before the handshake completed (and early data not in use)") + } + // Send full-size fragments var start int sent := 0 @@ -549,13 +581,23 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { } case SendQueuedHandshake: - err := c.hsCtx.hOut.SendQueuedMessages() + _, err := c.hsCtx.hOut.SendQueuedMessages() if err != nil { logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err) return AlertInternalError } + if c.config.UseDTLS { + c.hsCtx.timers.start(retransmitTimerLabel, + c.hsCtx.handshakeRetransmit, + c.hsCtx.timeoutMS) + } case RekeyIn: logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.epoch.label(), action.KeySet) + // Check that we don't have an input data in the handshake frame parser. + if len(c.hsCtx.hIn.frame.remainder) > 0 { + logf(logTypeHandshake, "%s Rekey with data still in handshake buffers", label) + return AlertDecodeError + } err := c.in.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) if err != nil { logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err) @@ -570,61 +612,9 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { return AlertInternalError } - case SendEarlyData: - logf(logTypeHandshake, "%s Sending early data...", label) - _, err := c.Write(c.EarlyData) - if err != nil { - logf(logTypeHandshake, "%s Error writing early data: %v", label, err) - return AlertInternalError - } - - case ReadPastEarlyData: - logf(logTypeHandshake, "%s Reading past early data...", label) - // Scan past all records that fail to decrypt - _, err := c.in.PeekRecordType(!c.config.NonBlocking) - if err == nil { - break - } - _, ok := err.(DecryptError) - - for ok { - _, err = c.in.PeekRecordType(!c.config.NonBlocking) - if err == nil { - break - } - _, ok = err.(DecryptError) - } - - case ReadEarlyData: - logf(logTypeHandshake, "%s Reading early data...", label) - t, err := c.in.PeekRecordType(!c.config.NonBlocking) - if err != nil { - logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err) - return AlertInternalError - } - logf(logTypeHandshake, "%s Got record type(1): %v", label, t) - - for t == RecordTypeApplicationData { - // Read a record into the buffer. Note that this is safe - // in blocking mode because we read the record in in - // PeekRecordType. - pt, err := c.in.ReadRecord() - if err != nil { - logf(logTypeHandshake, "%s Error reading early data record: %v", label, err) - return AlertInternalError - } - - logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment) - c.EarlyData = append(c.EarlyData, pt.fragment...) - - t, err = c.in.PeekRecordType(!c.config.NonBlocking) - if err != nil { - logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err) - return AlertInternalError - } - logf(logTypeHandshake, "%s Got record type (2): %v", label, t) - } - logf(logTypeHandshake, "%s Done reading early data", label) + case ResetOut: + logf(logTypeHandshake, "%s Rekeying out to %s seq=%v", label, EpochClear, action.seq) + c.out.ResetClear(action.seq) case StorePSK: logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity) @@ -637,7 +627,8 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { } default: - logf(logTypeHandshake, "%s Unknown actionuction type", label) + logf(logTypeHandshake, "%s Unknown action type", label) + assert(false) return AlertInternalError } @@ -657,7 +648,6 @@ func (c *Conn) HandshakeSetup() Alert { opts := ConnectionOptions{ ServerName: c.config.ServerName, NextProtos: c.config.NextProtos, - EarlyData: c.EarlyData, } if c.isClient { @@ -706,18 +696,21 @@ type handshakeMessageReaderImpl struct { var _ handshakeMessageReader = &handshakeMessageReaderImpl{} func (r *handshakeMessageReaderImpl) ReadMessage() (*HandshakeMessage, Alert) { - hm, err := r.hsCtx.hIn.ReadMessage() - if err == WouldBlock { - return nil, AlertWouldBlock + var hm *HandshakeMessage + var err error + for { + hm, err = r.hsCtx.hIn.ReadMessage() + if err == AlertWouldBlock { + return nil, AlertWouldBlock + } + if err != nil { + logf(logTypeHandshake, "Error reading message: %v", err) + return nil, AlertCloseNotify + } + if hm != nil { + break + } } - if err != nil { - logf(logTypeHandshake, "[client] Error reading message: %v", err) - return nil, AlertCloseNotify - } - - // Once you have read a message, you no longer need the outgoing queue - // for DTLS. - r.hsCtx.hOut.ClearQueuedMessages() return hm, AlertNoAlert } @@ -753,14 +746,21 @@ func (c *Conn) Handshake() Alert { state := c.hState _, connected := state.(stateConnected) - hmr := &handshakeMessageReaderImpl{hsCtx: &c.hsCtx} + hmr := &handshakeMessageReaderImpl{hsCtx: c.hsCtx} for !connected { var alert Alert var actions []HandshakeAction + // Advance the state machine state, actions, alert = state.Next(hmr) - if alert == WouldBlock { + if alert == AlertWouldBlock { logf(logTypeHandshake, "%s Would block reading message: %s", label, alert) + // If we blocked, then run our timers to see if any have expired. + if c.hsCtx.hIn.datagram { + if err := c.hsCtx.timers.check(time.Now()); err != nil { + return AlertInternalError + } + } return AlertWouldBlock } if alert == AlertCloseNotify { @@ -788,6 +788,34 @@ func (c *Conn) Handshake() Alert { if connected { c.state = state.(stateConnected) c.handshakeComplete = true + + if !c.isClient { + // Send NewSessionTicket if configured to + if c.config.SendSessionTickets { + actions, alert := c.state.NewSessionTicket( + c.config.TicketLen, + c.config.TicketLifetime, + c.config.EarlyDataLifetime) + + for _, action := range actions { + alert = c.takeAction(action) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error during handshake actions: %v", alert) + c.sendAlert(alert) + return alert + } + } + } + + // If there is early data, move it into the main buffer + if c.hsCtx.earlyData != nil { + c.readBuffer = c.hsCtx.earlyData + c.hsCtx.earlyData = nil + } + + } else { + assert(c.hsCtx.earlyData == nil) + } } if c.config.NonBlocking { @@ -798,23 +826,6 @@ func (c *Conn) Handshake() Alert { } } - // Send NewSessionTicket if acting as server - if !c.isClient && c.config.SendSessionTickets { - actions, alert := c.state.NewSessionTicket( - c.config.TicketLen, - c.config.TicketLifetime, - c.config.EarlyDataLifetime) - - for _, action := range actions { - alert = c.takeAction(action) - if alert != AlertNoAlert { - logf(logTypeHandshake, "Error during handshake actions: %v", alert) - c.sendAlert(alert) - return alert - } - } - } - return AlertNoAlert } @@ -848,6 +859,9 @@ func (c *Conn) SendKeyUpdate(requestUpdate bool) error { } func (c *Conn) GetHsState() State { + if c.hState == nil { + return StateInit + } return c.hState.State() } @@ -878,7 +892,30 @@ func (c *Conn) ConnectionState() ConnectionState { state.NextProto = c.state.Params.NextProto state.VerifiedChains = c.state.verifiedChains state.PeerCertificates = c.state.peerCertificates + state.UsingPSK = c.state.Params.UsingPSK + state.UsingEarlyData = c.state.Params.UsingEarlyData } return state } + +func (c *Conn) Writable() bool { + // If we're connected, we're writable. + if _, connected := c.hState.(stateConnected); connected { + return true + } + + // If we're a client in 0-RTT, then we're writable. + if c.isClient && c.out.cipher.epoch == EpochEarlyData { + return true + } + + return false +} + +func (c *Conn) label() string { + if c.isClient { + return "client" + } + return "server" +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/dtls.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/dtls.go index df4f1aa1..aa914e3e 100644 --- a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/dtls.go +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/dtls.go @@ -2,14 +2,33 @@ package mint import ( "fmt" + "github.com/bifurcation/mint/syntax" + "time" ) -// This file is a placeholder. DTLS-specific stuff (timer management, -// ACKs, retransmits, etc. will eventually go here. const ( - initialMtu = 1200 + initialMtu = 1200 + initialTimeout = 100 ) +// labels for timers +const ( + retransmitTimerLabel = "handshake retransmit" + ackTimerLabel = "ack timer" +) + +type SentHandshakeFragment struct { + seq uint32 + offset int + fragLength int + record uint64 + acked bool +} + +type DtlsAck struct { + RecordNumbers []uint64 `tls:"head=2"` +} + func wireVersion(h *HandshakeLayer) uint16 { if h.datagram { return dtls12WireVersion @@ -26,3 +45,178 @@ func dtlsConvertVersion(version uint16) uint16 { } panic(fmt.Sprintf("Internal error, unexpected version=%d", version)) } + +// TODO(ekr@rtfm.com): Move these to state-machine.go +func (h *HandshakeContext) handshakeRetransmit() error { + if _, err := h.hOut.SendQueuedMessages(); err != nil { + return err + } + + h.timers.start(retransmitTimerLabel, + h.handshakeRetransmit, + h.timeoutMS) + + // TODO(ekr@rtfm.com): Back off timer + return nil +} + +func (h *HandshakeContext) sendAck() error { + toack := h.hIn.recvdRecords + + count := (initialMtu - 2) / 8 // TODO(ekr@rtfm.com): Current MTU + if len(toack) > count { + toack = toack[:count] + } + logf(logTypeHandshake, "Sending ACK: [%x]", toack) + + ack := &DtlsAck{toack} + body, err := syntax.Marshal(&ack) + if err != nil { + return err + } + err = h.hOut.conn.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeAck, + fragment: body, + }) + if err != nil { + return err + } + return nil +} + +func (h *HandshakeContext) processAck(data []byte) error { + // Cancel the retransmit timer because we will be resending + // and possibly re-arming later. + h.timers.cancel(retransmitTimerLabel) + + ack := &DtlsAck{} + read, err := syntax.Unmarshal(data, &ack) + if err != nil { + return err + } + if len(data) != read { + return fmt.Errorf("Invalid encoding: Extra data not consumed") + } + logf(logTypeHandshake, "ACK: [%x]", ack.RecordNumbers) + + for _, r := range ack.RecordNumbers { + for _, m := range h.sentFragments { + if r == m.record { + logf(logTypeHandshake, "Marking %v %v(%v) as acked", + m.seq, m.offset, m.fragLength) + m.acked = true + } + } + } + + count, err := h.hOut.SendQueuedMessages() + if err != nil { + return err + } + + if count == 0 { + logf(logTypeHandshake, "All messages ACKed") + h.hOut.ClearQueuedMessages() + return nil + } + + // Reset the timer + h.timers.start(retransmitTimerLabel, + h.handshakeRetransmit, + h.timeoutMS) + + return nil +} + +func (c *Conn) GetDTLSTimeout() (bool, time.Duration) { + return c.hsCtx.timers.remaining() +} + +func (h *HandshakeContext) receivedHandshakeMessage() { + logf(logTypeHandshake, "%p Received handshake, waiting for start of flight = %v", h, h.waitingNextFlight) + // This just enables tests. + if h.hIn == nil { + return + } + + if !h.hIn.datagram { + return + } + + if h.waitingNextFlight { + logf(logTypeHandshake, "Received the start of the flight") + + // Clear the outgoing DTLS queue and terminate the retransmit timer + h.hOut.ClearQueuedMessages() + h.timers.cancel(retransmitTimerLabel) + + // OK, we're not waiting any more. + h.waitingNextFlight = false + } + + // Now pre-emptively arm the ACK timer if it's not armed already. + // We'll automatically dis-arm it at the end of the handshake. + if h.timers.getTimer(ackTimerLabel) == nil { + h.timers.start(ackTimerLabel, h.sendAck, h.timeoutMS/4) + } +} + +func (h *HandshakeContext) receivedEndOfFlight() { + logf(logTypeHandshake, "%p Received the end of the flight", h) + if !h.hIn.datagram { + return + } + + // Empty incoming queue + h.hIn.queued = nil + + // Note that we are waiting for the next flight. + h.waitingNextFlight = true + + // Clear the ACK queue. + h.hIn.recvdRecords = nil + + // Disarm the ACK timer + h.timers.cancel(ackTimerLabel) +} + +func (h *HandshakeContext) receivedFinalFlight() { + logf(logTypeHandshake, "%p Received final flight", h) + if !h.hIn.datagram { + return + } + + // Disarm the ACK timer + h.timers.cancel(ackTimerLabel) + + // But send an ACK immediately. + h.sendAck() +} + +func (h *HandshakeContext) fragmentAcked(seq uint32, offset int, fraglen int) bool { + logf(logTypeHandshake, "Looking to see if fragment %v %v(%v) was acked", seq, offset, fraglen) + for _, f := range h.sentFragments { + if !f.acked { + continue + } + + if f.seq != seq { + continue + } + + if f.offset > offset { + continue + } + + // At this point, we know that the stored fragment starts + // at or before what we want to send, so check where the end + // is. + if f.offset+f.fragLength < offset+fraglen { + continue + } + + return true + } + + return false +} diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/frame-reader.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/frame-reader.go index 54f40ce2..4ccfc23f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/frame-reader.go +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/frame-reader.go @@ -67,7 +67,7 @@ func (f *frameReader) process() (hdr []byte, body []byte, err error) { f.writeOffset += copied if f.writeOffset < len(f.working) { logf(logTypeVerbose, "Read would have blocked 1") - return nil, nil, WouldBlock + return nil, nil, AlertWouldBlock } // Reset the write offset, because we are now full. f.writeOffset = 0 @@ -94,5 +94,5 @@ func (f *frameReader) process() (hdr []byte, body []byte, err error) { } logf(logTypeVerbose, "Read would have blocked 2") - return nil, nil, WouldBlock + return nil, nil, AlertWouldBlock } diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/handshake-layer.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/handshake-layer.go index 888c5f36..de17b30b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/handshake-layer.go +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/handshake-layer.go @@ -35,7 +35,6 @@ type HandshakeMessage struct { datagram bool offset uint32 // Used for DTLS length uint32 - records []uint64 // Used for DTLS cipher *cipherState } @@ -119,6 +118,7 @@ func (h *HandshakeLayer) HandshakeMessageFromBody(body HandshakeMessageBody) (*H } type HandshakeLayer struct { + ctx *HandshakeContext // The handshake we are attached to nonblocking bool // Should we operate in nonblocking mode conn *RecordLayer // Used for reading/writing records frame *frameReader // The buffered frame reader @@ -126,6 +126,7 @@ type HandshakeLayer struct { msgSeq uint32 // The DTLS message sequence number queued []*HandshakeMessage // In/out queue sent []*HandshakeMessage // Sent messages for DTLS + recvdRecords []uint64 // Records we have received. maxFragmentLen int } @@ -152,8 +153,9 @@ func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) { return int(val), nil } -func NewHandshakeLayerTLS(r *RecordLayer) *HandshakeLayer { +func NewHandshakeLayerTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer { h := HandshakeLayer{} + h.ctx = c h.conn = r h.datagram = false h.frame = newFrameReader(&handshakeLayerFrameDetails{false}) @@ -161,8 +163,9 @@ func NewHandshakeLayerTLS(r *RecordLayer) *HandshakeLayer { return &h } -func NewHandshakeLayerDTLS(r *RecordLayer) *HandshakeLayer { +func NewHandshakeLayerDTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer { h := HandshakeLayer{} + h.ctx = c h.conn = r h.datagram = true h.frame = newFrameReader(&handshakeLayerFrameDetails{true}) @@ -172,16 +175,25 @@ func NewHandshakeLayerDTLS(r *RecordLayer) *HandshakeLayer { func (h *HandshakeLayer) readRecord() error { logf(logTypeVerbose, "Trying to read record") - pt, err := h.conn.ReadRecord() + pt, err := h.conn.readRecordAnyEpoch() if err != nil { return err } - if pt.contentType != RecordTypeHandshake && - pt.contentType != RecordTypeAlert { + switch pt.contentType { + case RecordTypeHandshake, RecordTypeAlert, RecordTypeAck: + default: return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType) } + if pt.contentType == RecordTypeAck { + if !h.datagram { + return fmt.Errorf("tls.handshakelayer: can't have ACK with TLS") + } + logf(logTypeIO, "read ACK") + return h.ctx.processAck(pt.fragment) + } + if pt.contentType == RecordTypeAlert { logf(logTypeIO, "read alert %v", pt.fragment[1]) if len(pt.fragment) < 2 { @@ -191,6 +203,19 @@ func (h *HandshakeLayer) readRecord() error { return Alert(pt.fragment[1]) } + assert(h.ctx.hIn.conn != nil) + if pt.epoch != h.ctx.hIn.conn.cipher.epoch { + // This is out of order but we're dropping it. + // TODO(ekr@rtfm.com): If server, need to retransmit Finished. + if pt.epoch == EpochClear || pt.epoch == EpochHandshakeData { + return nil + } + + // Anything else shouldn't happen. + return AlertIllegalParameter + } + + h.recvdRecords = append(h.recvdRecords, pt.seq) h.frame.addChunk(pt.fragment) return nil @@ -227,9 +252,13 @@ func (h *HandshakeLayer) noteMessageDelivered(seq uint32) { func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMessage, error) { if hm.seq < h.msgSeq { - return nil, WouldBlock + return nil, nil } + // TODO(ekr@rtfm.com): Send an ACK immediately if we got something + // out of order. + h.ctx.receivedHandshakeMessage() + if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) { // TODO(ekr@rtfm.com): Check the length? // This is complete. @@ -259,12 +288,12 @@ func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMe func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) { if len(h.queued) == 0 { - return nil, WouldBlock + return nil, nil } hm := h.queued[0] if hm.seq != h.msgSeq { - return nil, WouldBlock + return nil, nil } if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) { @@ -307,7 +336,7 @@ func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) { } - return nil, WouldBlock + return nil, nil } func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { @@ -315,19 +344,19 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { var err error hm, err := h.checkMessageAvailable() - if err == nil { - return hm, err - } - if err != WouldBlock { + if err != nil { return nil, err } + if hm != nil { + return hm, nil + } for { logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder)) if h.frame.needed() > 0 { logf(logTypeVerbose, "Trying to read a new record") err = h.readRecord() - if err != nil && (h.nonblocking || err != WouldBlock) { + if err != nil && (h.nonblocking || err != AlertWouldBlock) { return nil, err } } @@ -336,7 +365,7 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { if err == nil { break } - if err != nil && (h.nonblocking || err != WouldBlock) { + if err != nil && (h.nonblocking || err != AlertWouldBlock) { return nil, err } } @@ -370,12 +399,13 @@ func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error { return nil } -func (h *HandshakeLayer) SendQueuedMessages() error { +func (h *HandshakeLayer) SendQueuedMessages() (int, error) { logf(logTypeHandshake, "Sending outgoing messages") - err := h.WriteMessages(h.queued) - h.ClearQueuedMessages() // This isn't going to work for DTLS, but we'll - // get there. - return err + count, err := h.WriteMessages(h.queued) + if !h.datagram { + h.ClearQueuedMessages() + } + return count, err } func (h *HandshakeLayer) ClearQueuedMessages() { @@ -383,7 +413,7 @@ func (h *HandshakeLayer) ClearQueuedMessages() { h.queued = nil } -func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (int, error) { +func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (bool, int, error) { var buf []byte // Figure out if we're going to want the full header or just @@ -408,17 +438,35 @@ func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int } body := hm.body[start : start+bodylen] + // Now see if this chunk has been ACKed. This doesn't produce ideal + // retransmission but is simple. + if h.ctx.fragmentAcked(hm.seq, start, bodylen) { + logf(logTypeHandshake, "Fragment %v %v(%v) already acked. Skipping", hm.seq, start, bodylen) + return false, start + bodylen, nil + } + // Encode the data. if hdrlen > 0 { hm2 := *hm hm2.offset = uint32(start) hm2.body = body buf = hm2.Marshal() + hm = &hm2 } else { buf = body } - return start + bodylen, h.conn.writeRecordWithPadding( + if h.datagram { + // Remember that we sent this. + h.ctx.sentFragments = append(h.ctx.sentFragments, &SentHandshakeFragment{ + hm.seq, + start, + len(body), + h.conn.cipher.combineSeq(true), + false, + }) + } + return true, start + bodylen, h.conn.writeRecordWithPadding( &TLSPlaintext{ contentType: RecordTypeHandshake, fragment: buf, @@ -426,38 +474,46 @@ func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int hm.cipher, 0) } -func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error { +func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) (int, error) { start := int(0) if len(hm.body) > maxHandshakeMessageLen { - return fmt.Errorf("Tried to write a handshake message that's too long") + return 0, fmt.Errorf("Tried to write a handshake message that's too long") } + written := 0 + wrote := false + // Always make one pass through to allow EOED (which is empty). for { var err error - start, err = h.writeFragment(hm, start, h.maxFragmentLen) + wrote, start, err = h.writeFragment(hm, start, h.maxFragmentLen) if err != nil { - return err + return 0, err + } + if wrote { + written++ } if start >= len(hm.body) { break } } - return nil + return written, nil } -func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error { +func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) (int, error) { + written := 0 for _, hm := range hms { logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body) - err := h.WriteMessage(hm) + wrote, err := h.WriteMessage(hm) if err != nil { - return err + return 0, err } + written += wrote } - return nil + return written, nil } func encodeUint(v uint64, size int, out []byte) []byte { diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/negotiation.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/negotiation.go index 4697bbc8..2c80b8d7 100644 --- a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/negotiation.go +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/negotiation.go @@ -168,10 +168,11 @@ func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme return nil, 0, fmt.Errorf("No certificates compatible with signature schemes") } -func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool { - usingEarlyData := gotEarlyData && usingPSK && allowEarlyData - logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData) - return usingEarlyData +func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) (using bool, rejected bool) { + using = gotEarlyData && usingPSK && allowEarlyData + rejected = gotEarlyData && !using + logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v, %v", usingPSK, gotEarlyData, allowEarlyData, using, rejected) + return } func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) { diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/record-layer.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/record-layer.go index 761a868d..5cf8ae2c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/record-layer.go +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/record-layer.go @@ -1,7 +1,6 @@ package mint import ( - "bytes" "crypto/cipher" "fmt" "io" @@ -21,6 +20,13 @@ func (err DecryptError) Error() string { return string(err) } +type direction uint8 + +const ( + directionWrite = direction(1) + directionRead = direction(2) +) + // struct { // ContentType type; // ProtocolVersion record_version [0301 for CH, 0303 for others] @@ -31,20 +37,23 @@ type TLSPlaintext struct { // Omitted: record_version (static) // Omitted: length (computed from fragment) contentType RecordType + epoch Epoch + seq uint64 fragment []byte } type cipherState struct { epoch Epoch // DTLS epoch ivLength int // Length of the seq and nonce fields - seq []byte // Zero-padded sequence number + seq uint64 // Zero-padded sequence number iv []byte // Buffer for the IV cipher cipher.AEAD // AEAD cipher } type RecordLayer struct { sync.Mutex - + label string + direction direction version uint16 // The current version number conn io.ReadWriter // The underlying connection frame *frameReader // The buffered frame reader @@ -52,7 +61,9 @@ type RecordLayer struct { cachedRecord *TLSPlaintext // Last record read, cached to enable "peek" cachedError error // Error on the last record read - cipher *cipherState + cipher *cipherState + readCiphers map[Epoch]*cipherState + datagram bool } @@ -76,7 +87,7 @@ func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) { } func newCipherStateNull() *cipherState { - return &cipherState{EpochClear, 0, bytes.Repeat([]byte{0}, sequenceNumberLen), nil, nil} + return &cipherState{EpochClear, 0, 0, nil, nil} } func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) (*cipherState, error) { @@ -85,11 +96,13 @@ func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) return nil, err } - return &cipherState{epoch, len(iv), bytes.Repeat([]byte{0}, sequenceNumberLen), iv, cipher}, nil + return &cipherState{epoch, len(iv), 0, iv, cipher}, nil } -func NewRecordLayerTLS(conn io.ReadWriter) *RecordLayer { +func NewRecordLayerTLS(conn io.ReadWriter, dir direction) *RecordLayer { r := RecordLayer{} + r.label = "" + r.direction = dir r.conn = conn r.frame = newFrameReader(recordLayerFrameDetails{false}) r.cipher = newCipherStateNull() @@ -97,11 +110,15 @@ func NewRecordLayerTLS(conn io.ReadWriter) *RecordLayer { return &r } -func NewRecordLayerDTLS(conn io.ReadWriter) *RecordLayer { +func NewRecordLayerDTLS(conn io.ReadWriter, dir direction) *RecordLayer { r := RecordLayer{} + r.label = "" + r.direction = dir r.conn = conn r.frame = newFrameReader(recordLayerFrameDetails{true}) r.cipher = newCipherStateNull() + r.readCiphers = make(map[Epoch]*cipherState, 0) + r.readCiphers[0] = r.cipher r.datagram = true return &r } @@ -110,53 +127,67 @@ func (r *RecordLayer) SetVersion(v uint16) { r.version = v } +func (r *RecordLayer) ResetClear(seq uint64) { + r.cipher = newCipherStateNull() + r.cipher.seq = seq +} + func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []byte) error { cipher, err := newCipherStateAead(epoch, factory, key, iv) if err != nil { return err } r.cipher = cipher + if r.datagram && r.direction == directionRead { + r.readCiphers[epoch] = cipher + } return nil } -func (c *cipherState) formatSeq(datagram bool) []byte { - seq := append([]byte{}, c.seq...) +// TODO(ekr@rtfm.com): This is never used, which is a bug. +func (r *RecordLayer) DiscardReadKey(epoch Epoch) { + if !r.datagram { + return + } + + _, ok := r.readCiphers[epoch] + assert(ok) + delete(r.readCiphers, epoch) +} + +func (c *cipherState) combineSeq(datagram bool) uint64 { + seq := c.seq if datagram { - seq[0] = byte(c.epoch >> 8) - seq[1] = byte(c.epoch & 0xff) + seq |= uint64(c.epoch) << 48 } return seq } -func (c *cipherState) computeNonce(seq []byte) []byte { +func (c *cipherState) computeNonce(seq uint64) []byte { nonce := make([]byte, len(c.iv)) copy(nonce, c.iv) - offset := len(c.iv) - len(seq) - for i, b := range seq { - nonce[i+offset] ^= b + s := seq + + offset := len(c.iv) + for i := 0; i < 8; i++ { + nonce[(offset-i)-1] ^= byte(s & 0xff) + s >>= 8 } + logf(logTypeCrypto, "Computing nonce for sequence # %x -> %x", seq, nonce) return nonce } func (c *cipherState) incrementSequenceNumber() { - var i int - for i = len(c.seq) - 1; i >= 0; i-- { - c.seq[i]++ - if c.seq[i] != 0 { - break - } - } - - if i < 0 { + if c.seq >= (1<<48 - 1) { // Not allowed to let sequence number wrap. // Instead, must renegotiate before it does. - // Not likely enough to bother. - // TODO(ekr@rtfm.com): Check for DTLS here - // because the limit is sooner. + // Not likely enough to bother. This is the + // DTLS limit. panic("TLS: sequence number wraparound") } + c.seq++ } func (c *cipherState) overhead() int { @@ -166,8 +197,9 @@ func (c *cipherState) overhead() int { return c.cipher.Overhead() } -func (r *RecordLayer) encrypt(cipher *cipherState, seq []byte, pt *TLSPlaintext, padLen int) *TLSPlaintext { - logf(logTypeIO, "Encrypt seq=[%x]", seq) +func (r *RecordLayer) encrypt(cipher *cipherState, seq uint64, pt *TLSPlaintext, padLen int) *TLSPlaintext { + assert(r.direction == directionWrite) + logf(logTypeIO, "%s Encrypt seq=[%x]", r.label, seq) // Expand the fragment to hold contentType, padding, and overhead originalLen := len(pt.fragment) plaintextLen := originalLen + 1 + padLen @@ -191,8 +223,9 @@ func (r *RecordLayer) encrypt(cipher *cipherState, seq []byte, pt *TLSPlaintext, return out } -func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq []byte) (*TLSPlaintext, int, error) { - logf(logTypeIO, "Decrypt seq=[%x]", seq) +func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq uint64) (*TLSPlaintext, int, error) { + assert(r.direction == directionRead) + logf(logTypeIO, "%s Decrypt seq=[%x]", r.label, seq) if len(pt.fragment) < r.cipher.overhead() { msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.overhead()) return nil, 0, DecryptError(msg) @@ -207,7 +240,7 @@ func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq []byte) (*TLSPlaintext, int, // Decrypt _, err := r.cipher.cipher.Open(out.fragment[:0], r.cipher.computeNonce(seq), pt.fragment, nil) if err != nil { - logf(logTypeIO, "AEAD decryption failure [%x]", pt) + logf(logTypeIO, "%s AEAD decryption failure [%x]", r.label, pt) return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed") } @@ -222,6 +255,7 @@ func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq []byte) (*TLSPlaintext, int, // Truncate the message to remove contentType, padding, overhead out.fragment = out.fragment[:newLen] + out.seq = seq return out, padLen, nil } @@ -230,11 +264,11 @@ func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) { var err error for { - pt, err = r.nextRecord() + pt, err = r.nextRecord(false) if err == nil { break } - if !block || err != WouldBlock { + if !block || err != AlertWouldBlock { return 0, err } } @@ -242,7 +276,7 @@ func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) { } func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) { - pt, err := r.nextRecord() + pt, err := r.nextRecord(false) // Consume the cached record if there was one r.cachedRecord = nil @@ -251,10 +285,20 @@ func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) { return pt, err } -func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { +func (r *RecordLayer) readRecordAnyEpoch() (*TLSPlaintext, error) { + pt, err := r.nextRecord(true) + + // Consume the cached record if there was one + r.cachedRecord = nil + r.cachedError = nil + + return pt, err +} + +func (r *RecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, error) { cipher := r.cipher if r.cachedRecord != nil { - logf(logTypeIO, "Returning cached record") + logf(logTypeIO, "%s Returning cached record", r.label) return r.cachedRecord, r.cachedError } @@ -262,9 +306,10 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { // // 1. We get a frame // 2. We try to read off the socket and get nothing, in which case - // return WouldBlock + // returnAlertWouldBlock // 3. We get an error. - err := WouldBlock + var err error + err = AlertWouldBlock var header, body []byte for err != nil { @@ -272,24 +317,24 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen) n, err := r.conn.Read(buf) if err != nil { - logf(logTypeIO, "Error reading, %v", err) + logf(logTypeIO, "%s Error reading, %v", r.label, err) return nil, err } if n == 0 { - return nil, WouldBlock + return nil, AlertWouldBlock } - logf(logTypeIO, "Read %v bytes", n) + logf(logTypeIO, "%s Read %v bytes", r.label, n) buf = buf[:n] r.frame.addChunk(buf) } header, body, err = r.frame.process() - // Loop around on WouldBlock to see if some + // Loop around onAlertWouldBlock to see if some // data is now available. - if err != nil && err != WouldBlock { + if err != nil && err != AlertWouldBlock { return nil, err } } @@ -299,7 +344,7 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { switch RecordType(header[0]) { default: return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0]) - case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData: + case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData, RecordTypeAck: pt.contentType = RecordType(header[0]) } @@ -318,28 +363,48 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { pt.fragment = make([]byte, size) copy(pt.fragment, body) + // TODO(ekr@rtfm.com): Enforce that for epoch > 0, the content type is app data. + // Attempt to decrypt fragment - if cipher.cipher != nil { - seq := cipher.seq - if r.datagram { - seq = header[3:11] - } - // TODO(ekr@rtfm.com): Handle the wrong epoch. + seq := cipher.seq + if r.datagram { // TODO(ekr@rtfm.com): Handle duplicates. - logf(logTypeIO, "RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", cipher.epoch.label(), seq, pt.contentType, pt.fragment) + seq, _ = decodeUint(header[3:11], 8) + epoch := Epoch(seq >> 48) + + // Look up the cipher suite from the epoch + c, ok := r.readCiphers[epoch] + if !ok { + logf(logTypeIO, "%s Message from unknown epoch: [%v]", r.label, epoch) + return nil, AlertWouldBlock + } + + if epoch != cipher.epoch { + logf(logTypeIO, "%s Message from non-current epoch: [%v != %v] out-of-epoch reads=%v", r.label, epoch, + cipher.epoch, allowOldEpoch) + if !allowOldEpoch { + return nil, AlertWouldBlock + } + cipher = c + } + } + + if cipher.cipher != nil { + logf(logTypeIO, "%s RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), seq, pt.contentType, pt.fragment) pt, _, err = r.decrypt(pt, seq) if err != nil { - logf(logTypeIO, "Decryption failed") + logf(logTypeIO, "%s Decryption failed", r.label) return nil, err } } + pt.epoch = cipher.epoch // Check that plaintext length is not too long if len(pt.fragment) > maxFragmentLen { return nil, fmt.Errorf("tls.record: Plaintext size too big") } - logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment) + logf(logTypeIO, "%s RecordLayer.ReadRecord [%d] [%x]", r.label, pt.contentType, pt.fragment) r.cachedRecord = pt cipher.incrementSequenceNumber() @@ -355,10 +420,9 @@ func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error } func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error { - seq := cipher.formatSeq(r.datagram) - + seq := cipher.combineSeq(r.datagram) if cipher.cipher != nil { - logf(logTypeIO, "RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) + logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) pt = r.encrypt(cipher, seq, pt, padLen) } else if padLen > 0 { return fmt.Errorf("tls.record: Padding can only be done on encrypted records") @@ -376,16 +440,17 @@ func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherSta byte(r.version >> 8), byte(r.version & 0xff), byte(length >> 8), byte(length)} } else { + header = make([]byte, 13) version := dtlsConvertVersion(r.version) - header = []byte{byte(pt.contentType), + copy(header, []byte{byte(pt.contentType), byte(version >> 8), byte(version & 0xff), - seq[0], seq[1], seq[2], seq[3], - seq[4], seq[5], seq[6], seq[7], - byte(length >> 8), byte(length)} + }) + encodeUint(seq, 8, header[3:]) + encodeUint(uint64(length), 2, header[11:]) } record := append(header, pt.fragment...) - logf(logTypeIO, "RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) + logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) cipher.incrementSequenceNumber() _, err := r.conn.Write(record) diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/server-state-machine.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/server-state-machine.go index 0b851f40..f91b22e4 100644 --- a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/server-state-machine.go +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/server-state-machine.go @@ -24,14 +24,17 @@ import ( // | [Send CertificateRequest] // Can send | [Send Certificate + CertificateVerify] // app data --> | Send Finished -// after +--------+--------+ -// here No 0-RTT | | 0-RTT -// | v -// | WAIT_EOED <---+ -// | Recv | | | Recv -// | EndOfEarlyData | | | early data -// | | +-----+ -// +> WAIT_FLIGHT2 <-+ +// after here | +// +-----------+--------+ +// | | | +// Rejected 0-RTT | No | | 0-RTT +// | 0-RTT | | +// | | v +// +---->READ_PAST | WAIT_EOED <---+ +// Decrypt | | | Decrypt | Recv | | | Recv +// error | | | OK + HS | EOED | | | early data +// +-----+ | V | +-----+ +// +---> WAIT_FLIGHT2 <-+ // | // +--------+--------+ // No auth | | Client auth @@ -50,16 +53,17 @@ import ( // // NB: Not using state RECVD_CH // -// State Instructions -// START {} -// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)] -// WAIT_EOED RekeyIn; -// WAIT_FLIGHT2 {} -// WAIT_CERT_CR {} -// WAIT_CERT {} -// WAIT_CV {} -// WAIT_FINISHED RekeyIn; RekeyOut; -// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) +// State Instructions +// START {} +// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)] +// WAIT_EOED RekeyIn; +// READ_PAST {} +// WAIT_FLIGHT2 {} +// WAIT_CERT_CR {} +// WAIT_CERT {} +// WAIT_CV {} +// WAIT_FINISHED RekeyIn; RekeyOut; +// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) // A cookie can be sent to the client in a HRR. type cookie struct { @@ -74,7 +78,7 @@ type cookie struct { type serverStateStart struct { Config *Config conn *Conn - hsCtx HandshakeContext + hsCtx *HandshakeContext } var _ HandshakeState = &serverStateStart{} @@ -235,10 +239,6 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err) return nil, nil, AlertInternalError } - if clientSentCookie && initialCipherSuite.Suite != params.Suite { - logf(logTypeHandshake, "[ServerStateStart] Would have selected a different CipherSuite after receiving the client's Cookie") - return nil, nil, AlertInternalError - } } // Figure out if we actually should do DH / PSK @@ -361,7 +361,7 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ // Figure out if we're going to do early data var clientEarlyTrafficSecret []byte connParams.ClientSendingEarlyData = foundExts[ExtensionTypeEarlyData] - connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, foundExts[ExtensionTypeEarlyData], state.Config.AllowEarlyData) + connParams.UsingEarlyData, connParams.RejectedEarlyData = EarlyDataNegotiation(connParams.UsingPSK, foundExts[ExtensionTypeEarlyData], state.Config.AllowEarlyData) if connParams.UsingEarlyData { h := params.Hash.New() h.Write(clientHello.Marshal()) @@ -379,6 +379,8 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ return nil, nil, AlertNoApplicationProtocol } + state.hsCtx.receivedEndOfFlight() + logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]") state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2. return serverStateNegotiated{ @@ -445,7 +447,7 @@ func (state *serverStateStart) generateHRR(cs CipherSuite, legacySessionId []byt type serverStateNegotiated struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext dhGroup NamedGroup dhPublic []byte dhSecret []byte @@ -731,7 +733,6 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat } toSend = append(toSend, []HandshakeAction{ RekeyIn{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys}, - ReadEarlyData{}, }...) return nextState, toSend, AlertNoAlert } @@ -739,9 +740,9 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]") toSend = append(toSend, []HandshakeAction{ RekeyIn{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys}, - ReadPastEarlyData{}, }...) - waitFlight2 := serverStateWaitFlight2{ + var nextState HandshakeState + nextState = serverStateWaitFlight2{ Config: state.Config, Params: state.Params, hsCtx: state.hsCtx, @@ -753,13 +754,19 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat serverTrafficSecret: serverTrafficSecret, exporterSecret: exporterSecret, } - return waitFlight2, toSend, AlertNoAlert + if state.Params.RejectedEarlyData { + nextState = serverStateReadPastEarlyData{ + hsCtx: state.hsCtx, + next: &nextState, + } + } + return nextState, toSend, AlertNoAlert } type serverStateWaitEOED struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte clientHandshakeTrafficSecret []byte @@ -776,6 +783,38 @@ func (state serverStateWaitEOED) State() State { } func (state serverStateWaitEOED) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + for { + logf(logTypeHandshake, "Server reading early data...") + assert(state.hsCtx.hIn.conn.cipher.epoch == EpochEarlyData) + t, err := state.hsCtx.hIn.conn.PeekRecordType(!state.hsCtx.hIn.nonblocking) + if err == AlertWouldBlock { + return nil, nil, AlertWouldBlock + } + + if err != nil { + logf(logTypeHandshake, "Server Error reading record type (1): %v", err) + return nil, nil, AlertBadRecordMAC + } + + logf(logTypeHandshake, "Server got record type(1): %v", t) + + if t != RecordTypeApplicationData { + break + } + + // Read a record into the buffer. Note that this is safe + // in blocking mode because we read the record in + // PeekRecordType. + pt, err := state.hsCtx.hIn.conn.ReadRecord() + if err != nil { + logf(logTypeHandshake, "Server error reading early data record: %v", err) + return nil, nil, AlertInternalError + } + + logf(logTypeHandshake, "Server read early data: %x", pt.fragment) + state.hsCtx.earlyData = append(state.hsCtx.earlyData, pt.fragment...) + } + hm, alert := hr.ReadMessage() if alert != AlertNoAlert { return nil, nil, alert @@ -813,10 +852,44 @@ func (state serverStateWaitEOED) Next(hr handshakeMessageReader) (HandshakeState return waitFlight2, toSend, AlertNoAlert } +var _ HandshakeState = &serverStateReadPastEarlyData{} + +type serverStateReadPastEarlyData struct { + hsCtx *HandshakeContext + next *HandshakeState +} + +func (state serverStateReadPastEarlyData) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + for { + logf(logTypeHandshake, "Server reading past early data...") + // Scan past all records that fail to decrypt + _, err := state.hsCtx.hIn.conn.PeekRecordType(!state.hsCtx.hIn.nonblocking) + if err == nil { + break + } + + if err == AlertWouldBlock { + return nil, nil, AlertWouldBlock + } + + // Continue on DecryptError + _, ok := err.(DecryptError) + if !ok { + return nil, nil, AlertInternalError // Really need something else. + } + } + + return *state.next, nil, AlertNoAlert +} + +func (state serverStateReadPastEarlyData) State() State { + return StateServerReadPastEarlyData +} + type serverStateWaitFlight2 struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte clientHandshakeTrafficSecret []byte @@ -868,7 +941,7 @@ func (state serverStateWaitFlight2) Next(_ handshakeMessageReader) (HandshakeSta type serverStateWaitCert struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte clientHandshakeTrafficSecret []byte @@ -940,7 +1013,7 @@ func (state serverStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState type serverStateWaitCV struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte @@ -1023,7 +1096,7 @@ func (state serverStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, type serverStateWaitFinished struct { Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte @@ -1082,6 +1155,8 @@ func (state serverStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS // Compute client traffic keys clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) + state.hsCtx.receivedFinalFlight() + logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]") nextState := stateConnected{ Params: state.Params, diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/state-machine.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/state-machine.go index 7639c5f6..558b76cc 100644 --- a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/state-machine.go +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/state-machine.go @@ -17,10 +17,6 @@ type SendQueuedHandshake struct{} type SendEarlyData struct{} -type ReadEarlyData struct{} - -type ReadPastEarlyData struct{} - type RekeyIn struct { epoch Epoch KeySet keySet @@ -31,6 +27,10 @@ type RekeyOut struct { KeySet keySet } +type ResetOut struct { + seq uint64 +} + type StorePSK struct { PSK PreSharedKey } @@ -50,7 +50,6 @@ type AppExtensionHandler interface { type ConnectionOptions struct { ServerName string NextProtos []string - EarlyData []byte } // ConnectionParameters objects represent the parameters negotiated for a @@ -60,6 +59,7 @@ type ConnectionParameters struct { UsingDH bool ClientSendingEarlyData bool UsingEarlyData bool + RejectedEarlyData bool UsingClientAuth bool CipherSuite CipherSuite @@ -69,7 +69,13 @@ type ConnectionParameters struct { // Working state for the handshake. type HandshakeContext struct { - hIn, hOut *HandshakeLayer + timeoutMS uint32 + timers *timerSet + recvdRecords []uint64 + sentFragments []*SentHandshakeFragment + hIn, hOut *HandshakeLayer + waitingNextFlight bool + earlyData []byte } func (hc *HandshakeContext) SetVersion(version uint16) { @@ -84,7 +90,7 @@ func (hc *HandshakeContext) SetVersion(version uint16) { // stateConnected is symmetric between client and server type stateConnected struct { Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext isClient bool cryptoParams CipherSuiteParams resumptionSecret []byte diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/timer.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/timer.go new file mode 100644 index 00000000..0b7f7aff --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/bifurcation/mint/timer.go @@ -0,0 +1,122 @@ +package mint + +import ( + "time" +) + +// This is a simple timer implementation. Timers are stored in a sorted +// list. +// TODO(ekr@rtfm.com): Add a way to uncouple these from the system +// clock. +type timerCb func() error + +type timer struct { + label string + cb timerCb + deadline time.Time + duration uint32 +} + +type timerSet struct { + ts []*timer +} + +func newTimerSet() *timerSet { + return &timerSet{} +} + +func (ts *timerSet) start(label string, cb timerCb, delayMs uint32) *timer { + now := time.Now() + t := timer{ + label, + cb, + now.Add(time.Millisecond * time.Duration(delayMs)), + delayMs, + } + logf(logTypeHandshake, "Timer %s set [%v -> %v]", t.label, now, t.deadline) + + var i int + ntimers := len(ts.ts) + for i = 0; i < ntimers; i++ { + if t.deadline.Before(ts.ts[i].deadline) { + break + } + } + + tmp := make([]*timer, 0, ntimers+1) + tmp = append(tmp, ts.ts[:i]...) + tmp = append(tmp, &t) + tmp = append(tmp, ts.ts[i:]...) + ts.ts = tmp + + return &t +} + +// TODO(ekr@rtfm.com): optimize this now that the list is sorted. +// We should be able to do just one list manipulation, as long +// as we're careful about how we handle inserts during callbacks. +func (ts *timerSet) check(now time.Time) error { + for i, t := range ts.ts { + if now.After(t.deadline) { + ts.ts = append(ts.ts[:i], ts.ts[:i+1]...) + if t.cb != nil { + logf(logTypeHandshake, "Timer %s expired [%v > %v]", t.label, now, t.deadline) + cb := t.cb + t.cb = nil + err := cb() + if err != nil { + return err + } + } + } else { + break + } + } + return nil +} + +// Returns the next time any of the timers would fire. +func (ts *timerSet) remaining() (bool, time.Duration) { + for _, t := range ts.ts { + if t.cb != nil { + return true, time.Until(t.deadline) + } + } + + return false, time.Duration(0) +} + +func (ts *timerSet) cancel(label string) { + for _, t := range ts.ts { + if t.label == label { + t.cancel() + } + } +} + +func (ts *timerSet) getTimer(label string) *timer { + for _, t := range ts.ts { + if t.label == label && t.cb != nil { + return t + } + } + return nil +} + +func (ts *timerSet) getAllTimers() []string { + var ret []string + + for _, t := range ts.ts { + if t.cb != nil { + ret = append(ret, t.label) + } + } + + return ret +} + +func (t *timer) cancel() { + logf(logTypeHandshake, "Timer %s cancelled", t.label) + t.cb = nil + t.label = "" +} diff --git a/vendor/github.com/hashicorp/golang-lru/2q.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/2q.go similarity index 86% rename from vendor/github.com/hashicorp/golang-lru/2q.go rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/2q.go index 337d9632..e474cd07 100644 --- a/vendor/github.com/hashicorp/golang-lru/2q.go +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/2q.go @@ -30,9 +30,9 @@ type TwoQueueCache struct { size int recentSize int - recent *simplelru.LRU - frequent *simplelru.LRU - recentEvict *simplelru.LRU + recent simplelru.LRUCache + frequent simplelru.LRUCache + recentEvict simplelru.LRUCache lock sync.RWMutex } @@ -84,7 +84,8 @@ func New2QParams(size int, recentRatio float64, ghostRatio float64) (*TwoQueueCa return c, nil } -func (c *TwoQueueCache) Get(key interface{}) (interface{}, bool) { +// Get looks up a key's value from the cache. +func (c *TwoQueueCache) Get(key interface{}) (value interface{}, ok bool) { c.lock.Lock() defer c.lock.Unlock() @@ -105,6 +106,7 @@ func (c *TwoQueueCache) Get(key interface{}) (interface{}, bool) { return nil, false } +// Add adds a value to the cache. func (c *TwoQueueCache) Add(key, value interface{}) { c.lock.Lock() defer c.lock.Unlock() @@ -160,12 +162,15 @@ func (c *TwoQueueCache) ensureSpace(recentEvict bool) { c.frequent.RemoveOldest() } +// Len returns the number of items in the cache. func (c *TwoQueueCache) Len() int { c.lock.RLock() defer c.lock.RUnlock() return c.recent.Len() + c.frequent.Len() } +// Keys returns a slice of the keys in the cache. +// The frequently used keys are first in the returned slice. func (c *TwoQueueCache) Keys() []interface{} { c.lock.RLock() defer c.lock.RUnlock() @@ -174,6 +179,7 @@ func (c *TwoQueueCache) Keys() []interface{} { return append(k1, k2...) } +// Remove removes the provided key from the cache. func (c *TwoQueueCache) Remove(key interface{}) { c.lock.Lock() defer c.lock.Unlock() @@ -188,6 +194,7 @@ func (c *TwoQueueCache) Remove(key interface{}) { } } +// Purge is used to completely clear the cache. func (c *TwoQueueCache) Purge() { c.lock.Lock() defer c.lock.Unlock() @@ -196,13 +203,17 @@ func (c *TwoQueueCache) Purge() { c.recentEvict.Purge() } +// Contains is used to check if the cache contains a key +// without updating recency or frequency. func (c *TwoQueueCache) Contains(key interface{}) bool { c.lock.RLock() defer c.lock.RUnlock() return c.frequent.Contains(key) || c.recent.Contains(key) } -func (c *TwoQueueCache) Peek(key interface{}) (interface{}, bool) { +// Peek is used to inspect the cache value of a key +// without updating recency or frequency. +func (c *TwoQueueCache) Peek(key interface{}) (value interface{}, ok bool) { c.lock.RLock() defer c.lock.RUnlock() if val, ok := c.frequent.Peek(key); ok { diff --git a/vendor/github.com/hashicorp/golang-lru/arc.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/arc.go similarity index 91% rename from vendor/github.com/hashicorp/golang-lru/arc.go rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/arc.go index a2a25281..555225a2 100644 --- a/vendor/github.com/hashicorp/golang-lru/arc.go +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/arc.go @@ -18,11 +18,11 @@ type ARCCache struct { size int // Size is the total capacity of the cache p int // P is the dynamic preference towards T1 or T2 - t1 *simplelru.LRU // T1 is the LRU for recently accessed items - b1 *simplelru.LRU // B1 is the LRU for evictions from t1 + t1 simplelru.LRUCache // T1 is the LRU for recently accessed items + b1 simplelru.LRUCache // B1 is the LRU for evictions from t1 - t2 *simplelru.LRU // T2 is the LRU for frequently accessed items - b2 *simplelru.LRU // B2 is the LRU for evictions from t2 + t2 simplelru.LRUCache // T2 is the LRU for frequently accessed items + b2 simplelru.LRUCache // B2 is the LRU for evictions from t2 lock sync.RWMutex } @@ -60,11 +60,11 @@ func NewARC(size int) (*ARCCache, error) { } // Get looks up a key's value from the cache. -func (c *ARCCache) Get(key interface{}) (interface{}, bool) { +func (c *ARCCache) Get(key interface{}) (value interface{}, ok bool) { c.lock.Lock() defer c.lock.Unlock() - // Ff the value is contained in T1 (recent), then + // If the value is contained in T1 (recent), then // promote it to T2 (frequent) if val, ok := c.t1.Peek(key); ok { c.t1.Remove(key) @@ -153,7 +153,7 @@ func (c *ARCCache) Add(key, value interface{}) { // Remove from B2 c.b2.Remove(key) - // Add the key to the frequntly used list + // Add the key to the frequently used list c.t2.Add(key, value) return } @@ -247,7 +247,7 @@ func (c *ARCCache) Contains(key interface{}) bool { // Peek is used to inspect the cache value of a key // without updating recency or frequency. -func (c *ARCCache) Peek(key interface{}) (interface{}, bool) { +func (c *ARCCache) Peek(key interface{}) (value interface{}, ok bool) { c.lock.RLock() defer c.lock.RUnlock() if val, ok := c.t1.Peek(key); ok { diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/doc.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/doc.go new file mode 100644 index 00000000..2547df97 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/doc.go @@ -0,0 +1,21 @@ +// Package lru provides three different LRU caches of varying sophistication. +// +// Cache is a simple LRU cache. It is based on the +// LRU implementation in groupcache: +// https://github.com/golang/groupcache/tree/master/lru +// +// TwoQueueCache tracks frequently used and recently used entries separately. +// This avoids a burst of accesses from taking out frequently used entries, +// at the cost of about 2x computational overhead and some extra bookkeeping. +// +// ARCCache is an adaptive replacement cache. It tracks recent evictions as +// well as recent usage in both the frequent and recent caches. Its +// computational overhead is comparable to TwoQueueCache, but the memory +// overhead is linear with the size of the cache. +// +// ARC has been patented by IBM, so do not use it if that is problematic for +// your program. +// +// All caches in this package take locks while operating, and are therefore +// thread-safe for consumers. +package lru diff --git a/vendor/github.com/hashicorp/golang-lru/lru.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/lru.go similarity index 75% rename from vendor/github.com/hashicorp/golang-lru/lru.go rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/lru.go index a6285f98..c8d9b0a2 100644 --- a/vendor/github.com/hashicorp/golang-lru/lru.go +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/lru.go @@ -1,6 +1,3 @@ -// This package provides a simple LRU cache. It is based on the -// LRU implementation in groupcache: -// https://github.com/golang/groupcache/tree/master/lru package lru import ( @@ -11,11 +8,11 @@ import ( // Cache is a thread-safe fixed size LRU cache. type Cache struct { - lru *simplelru.LRU + lru simplelru.LRUCache lock sync.RWMutex } -// New creates an LRU of the given size +// New creates an LRU of the given size. func New(size int) (*Cache, error) { return NewWithEvict(size, nil) } @@ -33,7 +30,7 @@ func NewWithEvict(size int, onEvicted func(key interface{}, value interface{})) return c, nil } -// Purge is used to completely clear the cache +// Purge is used to completely clear the cache. func (c *Cache) Purge() { c.lock.Lock() c.lru.Purge() @@ -41,30 +38,30 @@ func (c *Cache) Purge() { } // Add adds a value to the cache. Returns true if an eviction occurred. -func (c *Cache) Add(key, value interface{}) bool { +func (c *Cache) Add(key, value interface{}) (evicted bool) { c.lock.Lock() defer c.lock.Unlock() return c.lru.Add(key, value) } // Get looks up a key's value from the cache. -func (c *Cache) Get(key interface{}) (interface{}, bool) { +func (c *Cache) Get(key interface{}) (value interface{}, ok bool) { c.lock.Lock() defer c.lock.Unlock() return c.lru.Get(key) } -// Check if a key is in the cache, without updating the recent-ness -// or deleting it for being stale. +// Contains checks if a key is in the cache, without updating the +// recent-ness or deleting it for being stale. func (c *Cache) Contains(key interface{}) bool { c.lock.RLock() defer c.lock.RUnlock() return c.lru.Contains(key) } -// Returns the key value (or undefined if not found) without updating +// Peek returns the key value (or undefined if not found) without updating // the "recently used"-ness of the key. -func (c *Cache) Peek(key interface{}) (interface{}, bool) { +func (c *Cache) Peek(key interface{}) (value interface{}, ok bool) { c.lock.RLock() defer c.lock.RUnlock() return c.lru.Peek(key) @@ -73,16 +70,15 @@ func (c *Cache) Peek(key interface{}) (interface{}, bool) { // ContainsOrAdd checks if a key is in the cache without updating the // recent-ness or deleting it for being stale, and if not, adds the value. // Returns whether found and whether an eviction occurred. -func (c *Cache) ContainsOrAdd(key, value interface{}) (ok, evict bool) { +func (c *Cache) ContainsOrAdd(key, value interface{}) (ok, evicted bool) { c.lock.Lock() defer c.lock.Unlock() if c.lru.Contains(key) { return true, false - } else { - evict := c.lru.Add(key, value) - return false, evict } + evicted = c.lru.Add(key, value) + return false, evicted } // Remove removes the provided key from the cache. diff --git a/vendor/github.com/hashicorp/golang-lru/simplelru/lru.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/simplelru/lru.go similarity index 86% rename from vendor/github.com/hashicorp/golang-lru/simplelru/lru.go rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/simplelru/lru.go index cb416b39..5673773b 100644 --- a/vendor/github.com/hashicorp/golang-lru/simplelru/lru.go +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/simplelru/lru.go @@ -36,7 +36,7 @@ func NewLRU(size int, onEvict EvictCallback) (*LRU, error) { return c, nil } -// Purge is used to completely clear the cache +// Purge is used to completely clear the cache. func (c *LRU) Purge() { for k, v := range c.items { if c.onEvict != nil { @@ -48,7 +48,7 @@ func (c *LRU) Purge() { } // Add adds a value to the cache. Returns true if an eviction occurred. -func (c *LRU) Add(key, value interface{}) bool { +func (c *LRU) Add(key, value interface{}) (evicted bool) { // Check for existing item if ent, ok := c.items[key]; ok { c.evictList.MoveToFront(ent) @@ -78,17 +78,18 @@ func (c *LRU) Get(key interface{}) (value interface{}, ok bool) { return } -// Check if a key is in the cache, without updating the recent-ness +// Contains checks if a key is in the cache, without updating the recent-ness // or deleting it for being stale. func (c *LRU) Contains(key interface{}) (ok bool) { _, ok = c.items[key] return ok } -// Returns the key value (or undefined if not found) without updating +// Peek returns the key value (or undefined if not found) without updating // the "recently used"-ness of the key. func (c *LRU) Peek(key interface{}) (value interface{}, ok bool) { - if ent, ok := c.items[key]; ok { + var ent *list.Element + if ent, ok = c.items[key]; ok { return ent.Value.(*entry).value, true } return nil, ok @@ -96,7 +97,7 @@ func (c *LRU) Peek(key interface{}) (value interface{}, ok bool) { // Remove removes the provided key from the cache, returning if the // key was contained. -func (c *LRU) Remove(key interface{}) bool { +func (c *LRU) Remove(key interface{}) (present bool) { if ent, ok := c.items[key]; ok { c.removeElement(ent) return true @@ -105,7 +106,7 @@ func (c *LRU) Remove(key interface{}) bool { } // RemoveOldest removes the oldest item from the cache. -func (c *LRU) RemoveOldest() (interface{}, interface{}, bool) { +func (c *LRU) RemoveOldest() (key interface{}, value interface{}, ok bool) { ent := c.evictList.Back() if ent != nil { c.removeElement(ent) @@ -116,7 +117,7 @@ func (c *LRU) RemoveOldest() (interface{}, interface{}, bool) { } // GetOldest returns the oldest entry -func (c *LRU) GetOldest() (interface{}, interface{}, bool) { +func (c *LRU) GetOldest() (key interface{}, value interface{}, ok bool) { ent := c.evictList.Back() if ent != nil { kv := ent.Value.(*entry) diff --git a/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/simplelru/lru_interface.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/simplelru/lru_interface.go new file mode 100644 index 00000000..744cac01 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/hashicorp/golang-lru/simplelru/lru_interface.go @@ -0,0 +1,37 @@ +package simplelru + + +// LRUCache is the interface for simple LRU cache. +type LRUCache interface { + // Adds a value to the cache, returns true if an eviction occurred and + // updates the "recently used"-ness of the key. + Add(key, value interface{}) bool + + // Returns key's value from the cache and + // updates the "recently used"-ness of the key. #value, isFound + Get(key interface{}) (value interface{}, ok bool) + + // Check if a key exsists in cache without updating the recent-ness. + Contains(key interface{}) (ok bool) + + // Returns key's value without updating the "recently used"-ness of the key. + Peek(key interface{}) (value interface{}, ok bool) + + // Removes a key from the cache. + Remove(key interface{}) bool + + // Removes the oldest entry from cache. + RemoveOldest() (interface{}, interface{}, bool) + + // Returns the oldest entry from the cache. #key, value, isFound + GetOldest() (interface{}, interface{}, bool) + + // Returns a slice of the keys in the cache, from oldest to newest. + Keys() []interface{} + + // Returns the number of items in the cache. + Len() int + + // Clear all cache entries + Purge() +} diff --git a/vendor/github.com/lucas-clemente/aes12/aes_gcm.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/aes_gcm.go similarity index 100% rename from vendor/github.com/lucas-clemente/aes12/aes_gcm.go rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/aes_gcm.go diff --git a/vendor/github.com/lucas-clemente/aes12/asm_amd64.s b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/asm_amd64.s similarity index 100% rename from vendor/github.com/lucas-clemente/aes12/asm_amd64.s rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/asm_amd64.s diff --git a/vendor/github.com/lucas-clemente/aes12/block.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/block.go similarity index 100% rename from vendor/github.com/lucas-clemente/aes12/block.go rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/block.go diff --git a/vendor/github.com/lucas-clemente/aes12/cipher.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/cipher.go similarity index 100% rename from vendor/github.com/lucas-clemente/aes12/cipher.go rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/cipher.go diff --git a/vendor/github.com/lucas-clemente/aes12/cipher_2.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/cipher_2.go similarity index 100% rename from vendor/github.com/lucas-clemente/aes12/cipher_2.go rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/cipher_2.go diff --git a/vendor/github.com/lucas-clemente/aes12/cipher_amd64.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/cipher_amd64.go similarity index 100% rename from vendor/github.com/lucas-clemente/aes12/cipher_amd64.go rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/cipher_amd64.go diff --git a/vendor/github.com/lucas-clemente/aes12/cipher_generic.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/cipher_generic.go similarity index 100% rename from vendor/github.com/lucas-clemente/aes12/cipher_generic.go rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/cipher_generic.go diff --git a/vendor/github.com/lucas-clemente/aes12/const.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/const.go similarity index 100% rename from vendor/github.com/lucas-clemente/aes12/const.go rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/const.go diff --git a/vendor/github.com/lucas-clemente/aes12/gcm.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/gcm.go similarity index 100% rename from vendor/github.com/lucas-clemente/aes12/gcm.go rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/gcm.go diff --git a/vendor/github.com/lucas-clemente/aes12/gcm_amd64.s b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/gcm_amd64.s similarity index 100% rename from vendor/github.com/lucas-clemente/aes12/gcm_amd64.s rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/gcm_amd64.s diff --git a/vendor/github.com/lucas-clemente/aes12/xor.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/xor.go similarity index 100% rename from vendor/github.com/lucas-clemente/aes12/xor.go rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12/xor.go diff --git a/vendor/github.com/lucas-clemente/quic-go-certificates/cert_set_2.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/quic-go-certificates/cert_set_2.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go-certificates/cert_set_2.go rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/quic-go-certificates/cert_set_2.go diff --git a/vendor/github.com/lucas-clemente/quic-go-certificates/cert_set_3.go b/vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/quic-go-certificates/cert_set_3.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go-certificates/cert_set_3.go rename to vendor/github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/quic-go-certificates/cert_set_3.go diff --git a/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go b/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go index ed006aa2..6cd359e5 100644 --- a/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go +++ b/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go @@ -3,6 +3,7 @@ package quic import ( "sync" + "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -10,29 +11,50 @@ import ( type windowUpdateQueue struct { mutex sync.Mutex - queue map[protocol.StreamID]bool // used as a set - callback func(wire.Frame) - cryptoStream cryptoStreamI - streamGetter streamGetter + queue map[protocol.StreamID]bool // used as a set + queuedConn bool // connection-level window update + + cryptoStream cryptoStream + streamGetter streamGetter + connFlowController flowcontrol.ConnectionFlowController + callback func(wire.Frame) } -func newWindowUpdateQueue(streamGetter streamGetter, cryptoStream cryptoStreamI, cb func(wire.Frame)) *windowUpdateQueue { +func newWindowUpdateQueue( + streamGetter streamGetter, + cryptoStream cryptoStream, + connFC flowcontrol.ConnectionFlowController, + cb func(wire.Frame), +) *windowUpdateQueue { return &windowUpdateQueue{ - queue: make(map[protocol.StreamID]bool), - streamGetter: streamGetter, - cryptoStream: cryptoStream, - callback: cb, + queue: make(map[protocol.StreamID]bool), + streamGetter: streamGetter, + cryptoStream: cryptoStream, + connFlowController: connFC, + callback: cb, } } -func (q *windowUpdateQueue) Add(id protocol.StreamID) { +func (q *windowUpdateQueue) AddStream(id protocol.StreamID) { q.mutex.Lock() q.queue[id] = true q.mutex.Unlock() } +func (q *windowUpdateQueue) AddConnection() { + q.mutex.Lock() + q.queuedConn = true + q.mutex.Unlock() +} + func (q *windowUpdateQueue) QueueAll() { q.mutex.Lock() + // queue a connection-level window update + if q.queuedConn { + q.callback(&wire.MaxDataFrame{ByteOffset: q.connFlowController.GetWindowUpdate()}) + q.queuedConn = false + } + // queue all stream-level window updates var offset protocol.ByteCount for id := range q.queue { if id == q.cryptoStream.StreamID() { diff --git a/vendor/golang.org/x/net/http/httpguts/LICENSE b/vendor/golang.org/x/net/http/httpguts/LICENSE new file mode 100644 index 00000000..6a66aea5 --- /dev/null +++ b/vendor/golang.org/x/net/http/httpguts/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/golang.org/x/net/http/httpguts/guts.go b/vendor/golang.org/x/net/http/httpguts/guts.go new file mode 100644 index 00000000..e6cd0ced --- /dev/null +++ b/vendor/golang.org/x/net/http/httpguts/guts.go @@ -0,0 +1,50 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httpguts provides functions implementing various details +// of the HTTP specification. +// +// This package is shared by the standard library (which vendors it) +// and x/net/http2. It comes with no API stability promise. +package httpguts + +import ( + "net/textproto" + "strings" +) + +// ValidTrailerHeader reports whether name is a valid header field name to appear +// in trailers. +// See RFC 7230, Section 4.1.2 +func ValidTrailerHeader(name string) bool { + name = textproto.CanonicalMIMEHeaderKey(name) + if strings.HasPrefix(name, "If-") || badTrailer[name] { + return false + } + return true +} + +var badTrailer = map[string]bool{ + "Authorization": true, + "Cache-Control": true, + "Connection": true, + "Content-Encoding": true, + "Content-Length": true, + "Content-Range": true, + "Content-Type": true, + "Expect": true, + "Host": true, + "Keep-Alive": true, + "Max-Forwards": true, + "Pragma": true, + "Proxy-Authenticate": true, + "Proxy-Authorization": true, + "Proxy-Connection": true, + "Range": true, + "Realm": true, + "Te": true, + "Trailer": true, + "Transfer-Encoding": true, + "Www-Authenticate": true, +} diff --git a/vendor/golang.org/x/net/http/httpguts/httplex.go b/vendor/golang.org/x/net/http/httpguts/httplex.go new file mode 100644 index 00000000..e7de24ee --- /dev/null +++ b/vendor/golang.org/x/net/http/httpguts/httplex.go @@ -0,0 +1,346 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httpguts + +import ( + "net" + "strings" + "unicode/utf8" + + "golang.org/x/net/idna" +) + +var isTokenTable = [127]bool{ + '!': true, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '*': true, + '+': true, + '-': true, + '.': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'W': true, + 'V': true, + 'X': true, + 'Y': true, + 'Z': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '|': true, + '~': true, +} + +func IsTokenRune(r rune) bool { + i := int(r) + return i < len(isTokenTable) && isTokenTable[i] +} + +func isNotToken(r rune) bool { + return !IsTokenRune(r) +} + +// HeaderValuesContainsToken reports whether any string in values +// contains the provided token, ASCII case-insensitively. +func HeaderValuesContainsToken(values []string, token string) bool { + for _, v := range values { + if headerValueContainsToken(v, token) { + return true + } + } + return false +} + +// isOWS reports whether b is an optional whitespace byte, as defined +// by RFC 7230 section 3.2.3. +func isOWS(b byte) bool { return b == ' ' || b == '\t' } + +// trimOWS returns x with all optional whitespace removes from the +// beginning and end. +func trimOWS(x string) string { + // TODO: consider using strings.Trim(x, " \t") instead, + // if and when it's fast enough. See issue 10292. + // But this ASCII-only code will probably always beat UTF-8 + // aware code. + for len(x) > 0 && isOWS(x[0]) { + x = x[1:] + } + for len(x) > 0 && isOWS(x[len(x)-1]) { + x = x[:len(x)-1] + } + return x +} + +// headerValueContainsToken reports whether v (assumed to be a +// 0#element, in the ABNF extension described in RFC 7230 section 7) +// contains token amongst its comma-separated tokens, ASCII +// case-insensitively. +func headerValueContainsToken(v string, token string) bool { + v = trimOWS(v) + if comma := strings.IndexByte(v, ','); comma != -1 { + return tokenEqual(trimOWS(v[:comma]), token) || headerValueContainsToken(v[comma+1:], token) + } + return tokenEqual(v, token) +} + +// lowerASCII returns the ASCII lowercase version of b. +func lowerASCII(b byte) byte { + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} + +// tokenEqual reports whether t1 and t2 are equal, ASCII case-insensitively. +func tokenEqual(t1, t2 string) bool { + if len(t1) != len(t2) { + return false + } + for i, b := range t1 { + if b >= utf8.RuneSelf { + // No UTF-8 or non-ASCII allowed in tokens. + return false + } + if lowerASCII(byte(b)) != lowerASCII(t2[i]) { + return false + } + } + return true +} + +// isLWS reports whether b is linear white space, according +// to http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 +// LWS = [CRLF] 1*( SP | HT ) +func isLWS(b byte) bool { return b == ' ' || b == '\t' } + +// isCTL reports whether b is a control byte, according +// to http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 +// CTL = +func isCTL(b byte) bool { + const del = 0x7f // a CTL + return b < ' ' || b == del +} + +// ValidHeaderFieldName reports whether v is a valid HTTP/1.x header name. +// HTTP/2 imposes the additional restriction that uppercase ASCII +// letters are not allowed. +// +// RFC 7230 says: +// header-field = field-name ":" OWS field-value OWS +// field-name = token +// token = 1*tchar +// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / +// "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA +func ValidHeaderFieldName(v string) bool { + if len(v) == 0 { + return false + } + for _, r := range v { + if !IsTokenRune(r) { + return false + } + } + return true +} + +// ValidHostHeader reports whether h is a valid host header. +func ValidHostHeader(h string) bool { + // The latest spec is actually this: + // + // http://tools.ietf.org/html/rfc7230#section-5.4 + // Host = uri-host [ ":" port ] + // + // Where uri-host is: + // http://tools.ietf.org/html/rfc3986#section-3.2.2 + // + // But we're going to be much more lenient for now and just + // search for any byte that's not a valid byte in any of those + // expressions. + for i := 0; i < len(h); i++ { + if !validHostByte[h[i]] { + return false + } + } + return true +} + +// See the validHostHeader comment. +var validHostByte = [256]bool{ + '0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true, + '8': true, '9': true, + + 'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true, 'h': true, + 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, 'o': true, 'p': true, + 'q': true, 'r': true, 's': true, 't': true, 'u': true, 'v': true, 'w': true, 'x': true, + 'y': true, 'z': true, + + 'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true, + 'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true, 'P': true, + 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true, 'W': true, 'X': true, + 'Y': true, 'Z': true, + + '!': true, // sub-delims + '$': true, // sub-delims + '%': true, // pct-encoded (and used in IPv6 zones) + '&': true, // sub-delims + '(': true, // sub-delims + ')': true, // sub-delims + '*': true, // sub-delims + '+': true, // sub-delims + ',': true, // sub-delims + '-': true, // unreserved + '.': true, // unreserved + ':': true, // IPv6address + Host expression's optional port + ';': true, // sub-delims + '=': true, // sub-delims + '[': true, + '\'': true, // sub-delims + ']': true, + '_': true, // unreserved + '~': true, // unreserved +} + +// ValidHeaderFieldValue reports whether v is a valid "field-value" according to +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 : +// +// message-header = field-name ":" [ field-value ] +// field-value = *( field-content | LWS ) +// field-content = +// +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 : +// +// TEXT = +// LWS = [CRLF] 1*( SP | HT ) +// CTL = +// +// RFC 7230 says: +// field-value = *( field-content / obs-fold ) +// obj-fold = N/A to http2, and deprecated +// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] +// field-vchar = VCHAR / obs-text +// obs-text = %x80-FF +// VCHAR = "any visible [USASCII] character" +// +// http2 further says: "Similarly, HTTP/2 allows header field values +// that are not valid. While most of the values that can be encoded +// will not alter header field parsing, carriage return (CR, ASCII +// 0xd), line feed (LF, ASCII 0xa), and the zero character (NUL, ASCII +// 0x0) might be exploited by an attacker if they are translated +// verbatim. Any request or response that contains a character not +// permitted in a header field value MUST be treated as malformed +// (Section 8.1.2.6). Valid characters are defined by the +// field-content ABNF rule in Section 3.2 of [RFC7230]." +// +// This function does not (yet?) properly handle the rejection of +// strings that begin or end with SP or HTAB. +func ValidHeaderFieldValue(v string) bool { + for i := 0; i < len(v); i++ { + b := v[i] + if isCTL(b) && !isLWS(b) { + return false + } + } + return true +} + +func isASCII(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + return false + } + } + return true +} + +// PunycodeHostPort returns the IDNA Punycode version +// of the provided "host" or "host:port" string. +func PunycodeHostPort(v string) (string, error) { + if isASCII(v) { + return v, nil + } + + host, port, err := net.SplitHostPort(v) + if err != nil { + // The input 'v' argument was just a "host" argument, + // without a port. This error should not be returned + // to the caller. + host = v + port = "" + } + host, err = idna.ToASCII(host) + if err != nil { + // Non-UTF-8? Not representable in Punycode, in any + // case. + return "", err + } + if port == "" { + return host, nil + } + return net.JoinHostPort(host, port), nil +} diff --git a/vendor/golang.org/x/sys/cpu/LICENSE b/vendor/golang.org/x/sys/cpu/LICENSE new file mode 100644 index 00000000..6a66aea5 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/golang.org/x/sys/cpu/cpu.go b/vendor/golang.org/x/sys/cpu/cpu.go new file mode 100644 index 00000000..3d88f866 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu.go @@ -0,0 +1,38 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package cpu implements processor feature detection for +// various CPU architectures. +package cpu + +// CacheLinePad is used to pad structs to avoid false sharing. +type CacheLinePad struct{ _ [cacheLineSize]byte } + +// X86 contains the supported CPU features of the +// current X86/AMD64 platform. If the current platform +// is not X86/AMD64 then all feature flags are false. +// +// X86 is padded to avoid false sharing. Further the HasAVX +// and HasAVX2 are only set if the OS supports XMM and YMM +// registers in addition to the CPUID feature bit being set. +var X86 struct { + _ CacheLinePad + HasAES bool // AES hardware implementation (AES NI) + HasADX bool // Multi-precision add-carry instruction extensions + HasAVX bool // Advanced vector extension + HasAVX2 bool // Advanced vector extension 2 + HasBMI1 bool // Bit manipulation instruction set 1 + HasBMI2 bool // Bit manipulation instruction set 2 + HasERMS bool // Enhanced REP for MOVSB and STOSB + HasFMA bool // Fused-multiply-add instructions + HasOSXSAVE bool // OS supports XSAVE/XRESTOR for saving/restoring XMM registers. + HasPCLMULQDQ bool // PCLMULQDQ instruction - most often used for AES-GCM + HasPOPCNT bool // Hamming weight instruction POPCNT. + HasSSE2 bool // Streaming SIMD extension 2 (always available on amd64) + HasSSE3 bool // Streaming SIMD extension 3 + HasSSSE3 bool // Supplemental streaming SIMD extension 3 + HasSSE41 bool // Streaming SIMD extension 4 and 4.1 + HasSSE42 bool // Streaming SIMD extension 4 and 4.2 + _ CacheLinePad +} diff --git a/vendor/golang.org/x/sys/cpu/cpu_arm.go b/vendor/golang.org/x/sys/cpu/cpu_arm.go new file mode 100644 index 00000000..d93036f7 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_arm.go @@ -0,0 +1,7 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const cacheLineSize = 32 diff --git a/vendor/golang.org/x/sys/cpu/cpu_arm64.go b/vendor/golang.org/x/sys/cpu/cpu_arm64.go new file mode 100644 index 00000000..1d2ab290 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_arm64.go @@ -0,0 +1,7 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const cacheLineSize = 64 diff --git a/vendor/golang.org/x/sys/cpu/cpu_gc_x86.go b/vendor/golang.org/x/sys/cpu/cpu_gc_x86.go new file mode 100644 index 00000000..f7cb4697 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_gc_x86.go @@ -0,0 +1,16 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build 386 amd64 amd64p32 +// +build !gccgo + +package cpu + +// cpuid is implemented in cpu_x86.s for gc compiler +// and in cpu_gccgo.c for gccgo. +func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32) + +// xgetbv with ecx = 0 is implemented in cpu_x86.s for gc compiler +// and in cpu_gccgo.c for gccgo. +func xgetbv() (eax, edx uint32) diff --git a/vendor/golang.org/x/sys/cpu/cpu_gccgo.c b/vendor/golang.org/x/sys/cpu/cpu_gccgo.c new file mode 100644 index 00000000..e363c7d1 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_gccgo.c @@ -0,0 +1,43 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build 386 amd64 amd64p32 +// +build gccgo + +#include +#include + +// Need to wrap __get_cpuid_count because it's declared as static. +int +gccgoGetCpuidCount(uint32_t leaf, uint32_t subleaf, + uint32_t *eax, uint32_t *ebx, + uint32_t *ecx, uint32_t *edx) +{ + return __get_cpuid_count(leaf, subleaf, eax, ebx, ecx, edx); +} + +// xgetbv reads the contents of an XCR (Extended Control Register) +// specified in the ECX register into registers EDX:EAX. +// Currently, the only supported value for XCR is 0. +// +// TODO: Replace with a better alternative: +// +// #include +// +// #pragma GCC target("xsave") +// +// void gccgoXgetbv(uint32_t *eax, uint32_t *edx) { +// unsigned long long x = _xgetbv(0); +// *eax = x & 0xffffffff; +// *edx = (x >> 32) & 0xffffffff; +// } +// +// Note that _xgetbv is defined starting with GCC 8. +void +gccgoXgetbv(uint32_t *eax, uint32_t *edx) +{ + __asm(" xorl %%ecx, %%ecx\n" + " xgetbv" + : "=a"(*eax), "=d"(*edx)); +} diff --git a/vendor/golang.org/x/sys/cpu/cpu_gccgo.go b/vendor/golang.org/x/sys/cpu/cpu_gccgo.go new file mode 100644 index 00000000..ba49b91b --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_gccgo.go @@ -0,0 +1,26 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build 386 amd64 amd64p32 +// +build gccgo + +package cpu + +//extern gccgoGetCpuidCount +func gccgoGetCpuidCount(eaxArg, ecxArg uint32, eax, ebx, ecx, edx *uint32) + +func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32) { + var a, b, c, d uint32 + gccgoGetCpuidCount(eaxArg, ecxArg, &a, &b, &c, &d) + return a, b, c, d +} + +//extern gccgoXgetbv +func gccgoXgetbv(eax, edx *uint32) + +func xgetbv() (eax, edx uint32) { + var a, d uint32 + gccgoXgetbv(&a, &d) + return a, d +} diff --git a/vendor/golang.org/x/sys/cpu/cpu_mips64x.go b/vendor/golang.org/x/sys/cpu/cpu_mips64x.go new file mode 100644 index 00000000..6165f121 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_mips64x.go @@ -0,0 +1,9 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build mips64 mips64le + +package cpu + +const cacheLineSize = 32 diff --git a/vendor/golang.org/x/sys/cpu/cpu_mipsx.go b/vendor/golang.org/x/sys/cpu/cpu_mipsx.go new file mode 100644 index 00000000..1269eee8 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_mipsx.go @@ -0,0 +1,9 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build mips mipsle + +package cpu + +const cacheLineSize = 32 diff --git a/vendor/golang.org/x/sys/cpu/cpu_ppc64x.go b/vendor/golang.org/x/sys/cpu/cpu_ppc64x.go new file mode 100644 index 00000000..d10759a5 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_ppc64x.go @@ -0,0 +1,9 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ppc64 ppc64le + +package cpu + +const cacheLineSize = 128 diff --git a/vendor/golang.org/x/sys/cpu/cpu_s390x.go b/vendor/golang.org/x/sys/cpu/cpu_s390x.go new file mode 100644 index 00000000..684c4f00 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_s390x.go @@ -0,0 +1,7 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const cacheLineSize = 256 diff --git a/vendor/golang.org/x/sys/cpu/cpu_x86.go b/vendor/golang.org/x/sys/cpu/cpu_x86.go new file mode 100644 index 00000000..71e288b0 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_x86.go @@ -0,0 +1,55 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build 386 amd64 amd64p32 + +package cpu + +const cacheLineSize = 64 + +func init() { + maxID, _, _, _ := cpuid(0, 0) + + if maxID < 1 { + return + } + + _, _, ecx1, edx1 := cpuid(1, 0) + X86.HasSSE2 = isSet(26, edx1) + + X86.HasSSE3 = isSet(0, ecx1) + X86.HasPCLMULQDQ = isSet(1, ecx1) + X86.HasSSSE3 = isSet(9, ecx1) + X86.HasFMA = isSet(12, ecx1) + X86.HasSSE41 = isSet(19, ecx1) + X86.HasSSE42 = isSet(20, ecx1) + X86.HasPOPCNT = isSet(23, ecx1) + X86.HasAES = isSet(25, ecx1) + X86.HasOSXSAVE = isSet(27, ecx1) + + osSupportsAVX := false + // For XGETBV, OSXSAVE bit is required and sufficient. + if X86.HasOSXSAVE { + eax, _ := xgetbv() + // Check if XMM and YMM registers have OS support. + osSupportsAVX = isSet(1, eax) && isSet(2, eax) + } + + X86.HasAVX = isSet(28, ecx1) && osSupportsAVX + + if maxID < 7 { + return + } + + _, ebx7, _, _ := cpuid(7, 0) + X86.HasBMI1 = isSet(3, ebx7) + X86.HasAVX2 = isSet(5, ebx7) && osSupportsAVX + X86.HasBMI2 = isSet(8, ebx7) + X86.HasERMS = isSet(9, ebx7) + X86.HasADX = isSet(19, ebx7) +} + +func isSet(bitpos uint, value uint32) bool { + return value&(1<