mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-14 23:06:27 +03:00
update to quic-go v0.10.0 (#2288)
quic-go now vendors all of its dependencies, so we don't need to vendor them here. Created by running: gvt delete github.com/lucas-clemente/quic-go gvt delete github.com/bifurcation/mint gvt delete github.com/lucas-clemente/aes12 gvt delete github.com/lucas-clemente/fnv128a gvt delete github.com/lucas-clemente/quic-go-certificates gvt delete github.com/aead/chacha20 gvt delete github.com/hashicorp/golang-lru gvt fetch -tag v0.10.0-no-integrationtests github.com/lucas-clemente/quic-go
This commit is contained in:
parent
9edc16e4d6
commit
dfbc2e81e3
185 changed files with 7807 additions and 12112 deletions
23
vendor/github.com/aead/chacha20/chacha/chacha.go
generated
vendored
23
vendor/github.com/aead/chacha20/chacha/chacha.go
generated
vendored
|
@ -9,6 +9,7 @@ package chacha // import "github.com/aead/chacha20/chacha"
|
|||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"math"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -28,6 +29,7 @@ const (
|
|||
var (
|
||||
useSSE2 bool
|
||||
useSSSE3 bool
|
||||
useAVX bool
|
||||
useAVX2 bool
|
||||
)
|
||||
|
||||
|
@ -55,7 +57,7 @@ func setup(state *[64]byte, nonce, key []byte) (err error) {
|
|||
|
||||
copy(hNonce[:], nonce[:16])
|
||||
copy(tmpKey[:], key)
|
||||
hChaCha20(&tmpKey, &hNonce, &tmpKey)
|
||||
HChaCha20(&tmpKey, &hNonce, &tmpKey)
|
||||
copy(Nonce[8:], nonce[16:])
|
||||
initialize(state, tmpKey[:], &Nonce)
|
||||
|
||||
|
@ -161,6 +163,21 @@ func (c *Cipher) XORKeyStream(dst, src []byte) {
|
|||
c.off = 0
|
||||
}
|
||||
|
||||
// check for counter overflow
|
||||
blocksToXOR := len(src) / 64
|
||||
if len(src)%64 != 0 {
|
||||
blocksToXOR++
|
||||
}
|
||||
var overflow bool
|
||||
if c.noncesize == INonceSize {
|
||||
overflow = binary.LittleEndian.Uint32(c.state[48:]) > math.MaxUint32-uint32(blocksToXOR)
|
||||
} else {
|
||||
overflow = binary.LittleEndian.Uint64(c.state[48:]) > math.MaxUint64-uint64(blocksToXOR)
|
||||
}
|
||||
if overflow {
|
||||
panic("chacha20/chacha: counter overflow")
|
||||
}
|
||||
|
||||
c.off += xorKeyStream(dst, src, &(c.block), &(c.state), c.rounds)
|
||||
}
|
||||
|
||||
|
@ -174,3 +191,7 @@ func (c *Cipher) SetCounter(ctr uint64) {
|
|||
}
|
||||
c.off = 0
|
||||
}
|
||||
|
||||
// HChaCha20 generates 32 pseudo-random bytes from a 128 bit nonce and a 256 bit secret key.
|
||||
// It can be used as a key-derivation-function (KDF).
|
||||
func HChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { hChaCha20(out, nonce, key) }
|
||||
|
|
234
vendor/github.com/aead/chacha20/chacha/chachaAVX2_amd64.s
generated
vendored
234
vendor/github.com/aead/chacha20/chacha/chachaAVX2_amd64.s
generated
vendored
|
@ -2,111 +2,10 @@
|
|||
// Use of this source code is governed by a license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
// +build go1.7,amd64,!gccgo,!appengine,!nacl
|
||||
// +build amd64,!gccgo,!appengine,!nacl
|
||||
|
||||
#include "textflag.h"
|
||||
|
||||
DATA ·sigma_AVX<>+0x00(SB)/4, $0x61707865
|
||||
DATA ·sigma_AVX<>+0x04(SB)/4, $0x3320646e
|
||||
DATA ·sigma_AVX<>+0x08(SB)/4, $0x79622d32
|
||||
DATA ·sigma_AVX<>+0x0C(SB)/4, $0x6b206574
|
||||
GLOBL ·sigma_AVX<>(SB), (NOPTR+RODATA), $16
|
||||
|
||||
DATA ·one_AVX<>+0x00(SB)/8, $1
|
||||
DATA ·one_AVX<>+0x08(SB)/8, $0
|
||||
GLOBL ·one_AVX<>(SB), (NOPTR+RODATA), $16
|
||||
|
||||
DATA ·one_AVX2<>+0x00(SB)/8, $0
|
||||
DATA ·one_AVX2<>+0x08(SB)/8, $0
|
||||
DATA ·one_AVX2<>+0x10(SB)/8, $1
|
||||
DATA ·one_AVX2<>+0x18(SB)/8, $0
|
||||
GLOBL ·one_AVX2<>(SB), (NOPTR+RODATA), $32
|
||||
|
||||
DATA ·two_AVX2<>+0x00(SB)/8, $2
|
||||
DATA ·two_AVX2<>+0x08(SB)/8, $0
|
||||
DATA ·two_AVX2<>+0x10(SB)/8, $2
|
||||
DATA ·two_AVX2<>+0x18(SB)/8, $0
|
||||
GLOBL ·two_AVX2<>(SB), (NOPTR+RODATA), $32
|
||||
|
||||
DATA ·rol16_AVX2<>+0x00(SB)/8, $0x0504070601000302
|
||||
DATA ·rol16_AVX2<>+0x08(SB)/8, $0x0D0C0F0E09080B0A
|
||||
DATA ·rol16_AVX2<>+0x10(SB)/8, $0x0504070601000302
|
||||
DATA ·rol16_AVX2<>+0x18(SB)/8, $0x0D0C0F0E09080B0A
|
||||
GLOBL ·rol16_AVX2<>(SB), (NOPTR+RODATA), $32
|
||||
|
||||
DATA ·rol8_AVX2<>+0x00(SB)/8, $0x0605040702010003
|
||||
DATA ·rol8_AVX2<>+0x08(SB)/8, $0x0E0D0C0F0A09080B
|
||||
DATA ·rol8_AVX2<>+0x10(SB)/8, $0x0605040702010003
|
||||
DATA ·rol8_AVX2<>+0x18(SB)/8, $0x0E0D0C0F0A09080B
|
||||
GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32
|
||||
|
||||
#define ROTL(n, t, v) \
|
||||
VPSLLD $n, v, t; \
|
||||
VPSRLD $(32-n), v, v; \
|
||||
VPXOR v, t, v
|
||||
|
||||
#define CHACHA_QROUND(v0, v1, v2, v3, t, c16, c8) \
|
||||
VPADDD v0, v1, v0; \
|
||||
VPXOR v3, v0, v3; \
|
||||
VPSHUFB c16, v3, v3; \
|
||||
VPADDD v2, v3, v2; \
|
||||
VPXOR v1, v2, v1; \
|
||||
ROTL(12, t, v1); \
|
||||
VPADDD v0, v1, v0; \
|
||||
VPXOR v3, v0, v3; \
|
||||
VPSHUFB c8, v3, v3; \
|
||||
VPADDD v2, v3, v2; \
|
||||
VPXOR v1, v2, v1; \
|
||||
ROTL(7, t, v1)
|
||||
|
||||
#define CHACHA_SHUFFLE(v1, v2, v3) \
|
||||
VPSHUFD $0x39, v1, v1; \
|
||||
VPSHUFD $0x4E, v2, v2; \
|
||||
VPSHUFD $-109, v3, v3
|
||||
|
||||
#define XOR_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \
|
||||
VMOVDQU (0+off)(src), t0; \
|
||||
VPERM2I128 $32, v1, v0, t1; \
|
||||
VPXOR t0, t1, t0; \
|
||||
VMOVDQU t0, (0+off)(dst); \
|
||||
VMOVDQU (32+off)(src), t0; \
|
||||
VPERM2I128 $32, v3, v2, t1; \
|
||||
VPXOR t0, t1, t0; \
|
||||
VMOVDQU t0, (32+off)(dst); \
|
||||
VMOVDQU (64+off)(src), t0; \
|
||||
VPERM2I128 $49, v1, v0, t1; \
|
||||
VPXOR t0, t1, t0; \
|
||||
VMOVDQU t0, (64+off)(dst); \
|
||||
VMOVDQU (96+off)(src), t0; \
|
||||
VPERM2I128 $49, v3, v2, t1; \
|
||||
VPXOR t0, t1, t0; \
|
||||
VMOVDQU t0, (96+off)(dst)
|
||||
|
||||
#define XOR_UPPER_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \
|
||||
VMOVDQU (0+off)(src), t0; \
|
||||
VPERM2I128 $32, v1, v0, t1; \
|
||||
VPXOR t0, t1, t0; \
|
||||
VMOVDQU t0, (0+off)(dst); \
|
||||
VMOVDQU (32+off)(src), t0; \
|
||||
VPERM2I128 $32, v3, v2, t1; \
|
||||
VPXOR t0, t1, t0; \
|
||||
VMOVDQU t0, (32+off)(dst); \
|
||||
|
||||
#define EXTRACT_LOWER(dst, v0, v1, v2, v3, t0) \
|
||||
VPERM2I128 $49, v1, v0, t0; \
|
||||
VMOVDQU t0, 0(dst); \
|
||||
VPERM2I128 $49, v3, v2, t0; \
|
||||
VMOVDQU t0, 32(dst)
|
||||
|
||||
#define XOR_AVX(dst, src, off, v0, v1, v2, v3, t0) \
|
||||
VPXOR 0+off(src), v0, t0; \
|
||||
VMOVDQU t0, 0+off(dst); \
|
||||
VPXOR 16+off(src), v1, t0; \
|
||||
VMOVDQU t0, 16+off(dst); \
|
||||
VPXOR 32+off(src), v2, t0; \
|
||||
VMOVDQU t0, 32+off(dst); \
|
||||
VPXOR 48+off(src), v3, t0; \
|
||||
VMOVDQU t0, 48+off(dst)
|
||||
#include "const.s"
|
||||
#include "macro.s"
|
||||
|
||||
#define TWO 0(SP)
|
||||
#define C16 32(SP)
|
||||
|
@ -122,10 +21,10 @@ GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32
|
|||
TEXT ·xorKeyStreamAVX2(SB), 4, $320-80
|
||||
MOVQ dst_base+0(FP), DI
|
||||
MOVQ src_base+24(FP), SI
|
||||
MOVQ src_len+32(FP), CX
|
||||
MOVQ block+48(FP), BX
|
||||
MOVQ state+56(FP), AX
|
||||
MOVQ rounds+64(FP), DX
|
||||
MOVQ src_len+32(FP), CX
|
||||
|
||||
MOVQ SP, R8
|
||||
ADDQ $32, SP
|
||||
|
@ -185,28 +84,28 @@ at_least_512:
|
|||
|
||||
chacha_loop_512:
|
||||
VMOVDQA Y8, TMP_0
|
||||
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y8, C16, C8)
|
||||
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y8, C16, C8)
|
||||
CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y8, C16, C8)
|
||||
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y8, C16, C8)
|
||||
VMOVDQA TMP_0, Y8
|
||||
VMOVDQA Y0, TMP_0
|
||||
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y0, C16, C8)
|
||||
CHACHA_QROUND(Y12, Y13, Y14, Y15, Y0, C16, C8)
|
||||
CHACHA_SHUFFLE(Y1, Y2, Y3)
|
||||
CHACHA_SHUFFLE(Y5, Y6, Y7)
|
||||
CHACHA_SHUFFLE(Y9, Y10, Y11)
|
||||
CHACHA_SHUFFLE(Y13, Y14, Y15)
|
||||
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y0, C16, C8)
|
||||
CHACHA_QROUND_AVX(Y12, Y13, Y14, Y15, Y0, C16, C8)
|
||||
CHACHA_SHUFFLE_AVX(Y1, Y2, Y3)
|
||||
CHACHA_SHUFFLE_AVX(Y5, Y6, Y7)
|
||||
CHACHA_SHUFFLE_AVX(Y9, Y10, Y11)
|
||||
CHACHA_SHUFFLE_AVX(Y13, Y14, Y15)
|
||||
|
||||
CHACHA_QROUND(Y12, Y13, Y14, Y15, Y0, C16, C8)
|
||||
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y0, C16, C8)
|
||||
CHACHA_QROUND_AVX(Y12, Y13, Y14, Y15, Y0, C16, C8)
|
||||
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y0, C16, C8)
|
||||
VMOVDQA TMP_0, Y0
|
||||
VMOVDQA Y8, TMP_0
|
||||
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y8, C16, C8)
|
||||
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y8, C16, C8)
|
||||
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y8, C16, C8)
|
||||
CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y8, C16, C8)
|
||||
VMOVDQA TMP_0, Y8
|
||||
CHACHA_SHUFFLE(Y3, Y2, Y1)
|
||||
CHACHA_SHUFFLE(Y7, Y6, Y5)
|
||||
CHACHA_SHUFFLE(Y11, Y10, Y9)
|
||||
CHACHA_SHUFFLE(Y15, Y14, Y13)
|
||||
CHACHA_SHUFFLE_AVX(Y3, Y2, Y1)
|
||||
CHACHA_SHUFFLE_AVX(Y7, Y6, Y5)
|
||||
CHACHA_SHUFFLE_AVX(Y11, Y10, Y9)
|
||||
CHACHA_SHUFFLE_AVX(Y15, Y14, Y13)
|
||||
SUBQ $2, R9
|
||||
JA chacha_loop_512
|
||||
|
||||
|
@ -289,18 +188,18 @@ between_320_and_448:
|
|||
MOVQ DX, R9
|
||||
|
||||
chacha_loop_384:
|
||||
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y13, Y14, Y15)
|
||||
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
|
||||
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
|
||||
CHACHA_SHUFFLE(Y1, Y2, Y3)
|
||||
CHACHA_SHUFFLE(Y5, Y6, Y7)
|
||||
CHACHA_SHUFFLE(Y9, Y10, Y11)
|
||||
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y13, Y14, Y15)
|
||||
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
|
||||
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
|
||||
CHACHA_SHUFFLE(Y3, Y2, Y1)
|
||||
CHACHA_SHUFFLE(Y7, Y6, Y5)
|
||||
CHACHA_SHUFFLE(Y11, Y10, Y9)
|
||||
CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y13, Y14, Y15)
|
||||
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
|
||||
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
|
||||
CHACHA_SHUFFLE_AVX(Y1, Y2, Y3)
|
||||
CHACHA_SHUFFLE_AVX(Y5, Y6, Y7)
|
||||
CHACHA_SHUFFLE_AVX(Y9, Y10, Y11)
|
||||
CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y13, Y14, Y15)
|
||||
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
|
||||
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
|
||||
CHACHA_SHUFFLE_AVX(Y3, Y2, Y1)
|
||||
CHACHA_SHUFFLE_AVX(Y7, Y6, Y5)
|
||||
CHACHA_SHUFFLE_AVX(Y11, Y10, Y9)
|
||||
SUBQ $2, R9
|
||||
JA chacha_loop_384
|
||||
|
||||
|
@ -361,14 +260,14 @@ between_192_and_320:
|
|||
MOVQ DX, R9
|
||||
|
||||
chacha_loop_256:
|
||||
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
|
||||
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
|
||||
CHACHA_SHUFFLE(Y5, Y6, Y7)
|
||||
CHACHA_SHUFFLE(Y9, Y10, Y11)
|
||||
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
|
||||
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
|
||||
CHACHA_SHUFFLE(Y7, Y6, Y5)
|
||||
CHACHA_SHUFFLE(Y11, Y10, Y9)
|
||||
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
|
||||
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
|
||||
CHACHA_SHUFFLE_AVX(Y5, Y6, Y7)
|
||||
CHACHA_SHUFFLE_AVX(Y9, Y10, Y11)
|
||||
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
|
||||
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
|
||||
CHACHA_SHUFFLE_AVX(Y7, Y6, Y5)
|
||||
CHACHA_SHUFFLE_AVX(Y11, Y10, Y9)
|
||||
SUBQ $2, R9
|
||||
JA chacha_loop_256
|
||||
|
||||
|
@ -413,10 +312,10 @@ between_64_and_192:
|
|||
MOVQ DX, R9
|
||||
|
||||
chacha_loop_128:
|
||||
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
|
||||
CHACHA_SHUFFLE(Y5, Y6, Y7)
|
||||
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
|
||||
CHACHA_SHUFFLE(Y7, Y6, Y5)
|
||||
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
|
||||
CHACHA_SHUFFLE_AVX(Y5, Y6, Y7)
|
||||
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
|
||||
CHACHA_SHUFFLE_AVX(Y7, Y6, Y5)
|
||||
SUBQ $2, R9
|
||||
JA chacha_loop_128
|
||||
|
||||
|
@ -455,10 +354,10 @@ between_0_and_64:
|
|||
MOVQ DX, R9
|
||||
|
||||
chacha_loop_64:
|
||||
CHACHA_QROUND(X4, X5, X6, X7, X13, X14, X15)
|
||||
CHACHA_SHUFFLE(X5, X6, X7)
|
||||
CHACHA_QROUND(X4, X5, X6, X7, X13, X14, X15)
|
||||
CHACHA_SHUFFLE(X7, X6, X5)
|
||||
CHACHA_QROUND_AVX(X4, X5, X6, X7, X13, X14, X15)
|
||||
CHACHA_SHUFFLE_AVX(X5, X6, X7)
|
||||
CHACHA_QROUND_AVX(X4, X5, X6, X7, X13, X14, X15)
|
||||
CHACHA_SHUFFLE_AVX(X7, X6, X5)
|
||||
SUBQ $2, R9
|
||||
JA chacha_loop_64
|
||||
|
||||
|
@ -466,7 +365,7 @@ chacha_loop_64:
|
|||
VPADDD X1, X5, X5
|
||||
VPADDD X2, X6, X6
|
||||
VPADDD X3, X7, X7
|
||||
VMOVDQU ·one_AVX<>(SB), X0
|
||||
VMOVDQU ·one<>(SB), X0
|
||||
VPADDQ X0, X3, X3
|
||||
|
||||
CMPQ CX, $64
|
||||
|
@ -505,38 +404,3 @@ done:
|
|||
MOVQ CX, ret+72(FP)
|
||||
RET
|
||||
|
||||
// func hChaCha20AVX(out *[32]byte, nonce *[16]byte, key *[32]byte)
|
||||
TEXT ·hChaCha20AVX(SB), 4, $0-24
|
||||
MOVQ out+0(FP), DI
|
||||
MOVQ nonce+8(FP), AX
|
||||
MOVQ key+16(FP), BX
|
||||
|
||||
VMOVDQU ·sigma_AVX<>(SB), X0
|
||||
VMOVDQU 0(BX), X1
|
||||
VMOVDQU 16(BX), X2
|
||||
VMOVDQU 0(AX), X3
|
||||
VMOVDQU ·rol16_AVX2<>(SB), X5
|
||||
VMOVDQU ·rol8_AVX2<>(SB), X6
|
||||
|
||||
MOVQ $20, CX
|
||||
|
||||
chacha_loop:
|
||||
CHACHA_QROUND(X0, X1, X2, X3, X4, X5, X6)
|
||||
CHACHA_SHUFFLE(X1, X2, X3)
|
||||
CHACHA_QROUND(X0, X1, X2, X3, X4, X5, X6)
|
||||
CHACHA_SHUFFLE(X3, X2, X1)
|
||||
SUBQ $2, CX
|
||||
JNZ chacha_loop
|
||||
|
||||
VMOVDQU X0, 0(DI)
|
||||
VMOVDQU X3, 16(DI)
|
||||
VZEROUPPER
|
||||
RET
|
||||
|
||||
// func supportsAVX2() bool
|
||||
TEXT ·supportsAVX2(SB), 4, $0-1
|
||||
MOVQ runtime·support_avx(SB), AX
|
||||
MOVQ runtime·support_avx2(SB), BX
|
||||
ANDQ AX, BX
|
||||
MOVB BX, ret+0(FP)
|
||||
RET
|
||||
|
|
37
vendor/github.com/aead/chacha20/chacha/chacha_386.go
generated
vendored
37
vendor/github.com/aead/chacha20/chacha/chacha_386.go
generated
vendored
|
@ -6,11 +6,16 @@
|
|||
|
||||
package chacha
|
||||
|
||||
import "encoding/binary"
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"golang.org/x/sys/cpu"
|
||||
)
|
||||
|
||||
func init() {
|
||||
useSSE2 = supportsSSE2()
|
||||
useSSSE3 = supportsSSSE3()
|
||||
useSSE2 = cpu.X86.HasSSE2
|
||||
useSSSE3 = cpu.X86.HasSSSE3
|
||||
useAVX = false
|
||||
useAVX2 = false
|
||||
}
|
||||
|
||||
|
@ -23,14 +28,6 @@ func initialize(state *[64]byte, key []byte, nonce *[16]byte) {
|
|||
copy(state[48:], nonce[:])
|
||||
}
|
||||
|
||||
// This function is implemented in chacha_386.s
|
||||
//go:noescape
|
||||
func supportsSSE2() bool
|
||||
|
||||
// This function is implemented in chacha_386.s
|
||||
//go:noescape
|
||||
func supportsSSSE3() bool
|
||||
|
||||
// This function is implemented in chacha_386.s
|
||||
//go:noescape
|
||||
func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
|
||||
|
@ -43,25 +40,21 @@ func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte)
|
|||
//go:noescape
|
||||
func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
|
||||
|
||||
// This function is implemented in chacha_386.s
|
||||
//go:noescape
|
||||
func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int
|
||||
|
||||
func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) {
|
||||
if useSSSE3 {
|
||||
switch {
|
||||
case useSSSE3:
|
||||
hChaCha20SSSE3(out, nonce, key)
|
||||
} else if useSSE2 {
|
||||
case useSSE2:
|
||||
hChaCha20SSE2(out, nonce, key)
|
||||
} else {
|
||||
default:
|
||||
hChaCha20Generic(out, nonce, key)
|
||||
}
|
||||
}
|
||||
|
||||
func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int {
|
||||
if useSSSE3 {
|
||||
return xorKeyStreamSSSE3(dst, src, block, state, rounds)
|
||||
} else if useSSE2 {
|
||||
if useSSE2 {
|
||||
return xorKeyStreamSSE2(dst, src, block, state, rounds)
|
||||
} else {
|
||||
return xorKeyStreamGeneric(dst, src, block, state, rounds)
|
||||
}
|
||||
return xorKeyStreamGeneric(dst, src, block, state, rounds)
|
||||
}
|
||||
|
|
408
vendor/github.com/aead/chacha20/chacha/chacha_386.s
generated
vendored
408
vendor/github.com/aead/chacha20/chacha/chacha_386.s
generated
vendored
|
@ -4,126 +4,125 @@
|
|||
|
||||
// +build 386,!gccgo,!appengine,!nacl
|
||||
|
||||
#include "textflag.h"
|
||||
|
||||
DATA ·sigma<>+0x00(SB)/4, $0x61707865
|
||||
DATA ·sigma<>+0x04(SB)/4, $0x3320646e
|
||||
DATA ·sigma<>+0x08(SB)/4, $0x79622d32
|
||||
DATA ·sigma<>+0x0C(SB)/4, $0x6b206574
|
||||
GLOBL ·sigma<>(SB), (NOPTR+RODATA), $16
|
||||
|
||||
DATA ·one<>+0x00(SB)/8, $1
|
||||
DATA ·one<>+0x08(SB)/8, $0
|
||||
GLOBL ·one<>(SB), (NOPTR+RODATA), $16
|
||||
|
||||
DATA ·rol16<>+0x00(SB)/8, $0x0504070601000302
|
||||
DATA ·rol16<>+0x08(SB)/8, $0x0D0C0F0E09080B0A
|
||||
GLOBL ·rol16<>(SB), (NOPTR+RODATA), $16
|
||||
|
||||
DATA ·rol8<>+0x00(SB)/8, $0x0605040702010003
|
||||
DATA ·rol8<>+0x08(SB)/8, $0x0E0D0C0F0A09080B
|
||||
GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16
|
||||
|
||||
#define ROTL_SSE2(n, t, v) \
|
||||
MOVO v, t; \
|
||||
PSLLL $n, t; \
|
||||
PSRLL $(32-n), v; \
|
||||
PXOR t, v
|
||||
|
||||
#define CHACHA_QROUND_SSE2(v0, v1, v2, v3, t0) \
|
||||
PADDL v1, v0; \
|
||||
PXOR v0, v3; \
|
||||
ROTL_SSE2(16, t0, v3); \
|
||||
PADDL v3, v2; \
|
||||
PXOR v2, v1; \
|
||||
ROTL_SSE2(12, t0, v1); \
|
||||
PADDL v1, v0; \
|
||||
PXOR v0, v3; \
|
||||
ROTL_SSE2(8, t0, v3); \
|
||||
PADDL v3, v2; \
|
||||
PXOR v2, v1; \
|
||||
ROTL_SSE2(7, t0, v1)
|
||||
|
||||
#define CHACHA_QROUND_SSSE3(v0, v1, v2, v3, t0, r16, r8) \
|
||||
PADDL v1, v0; \
|
||||
PXOR v0, v3; \
|
||||
PSHUFB r16, v3; \
|
||||
PADDL v3, v2; \
|
||||
PXOR v2, v1; \
|
||||
ROTL_SSE2(12, t0, v1); \
|
||||
PADDL v1, v0; \
|
||||
PXOR v0, v3; \
|
||||
PSHUFB r8, v3; \
|
||||
PADDL v3, v2; \
|
||||
PXOR v2, v1; \
|
||||
ROTL_SSE2(7, t0, v1)
|
||||
|
||||
#define CHACHA_SHUFFLE(v1, v2, v3) \
|
||||
PSHUFL $0x39, v1, v1; \
|
||||
PSHUFL $0x4E, v2, v2; \
|
||||
PSHUFL $0x93, v3, v3
|
||||
|
||||
#define XOR(dst, src, off, v0, v1, v2, v3, t0) \
|
||||
MOVOU 0+off(src), t0; \
|
||||
PXOR v0, t0; \
|
||||
MOVOU t0, 0+off(dst); \
|
||||
MOVOU 16+off(src), t0; \
|
||||
PXOR v1, t0; \
|
||||
MOVOU t0, 16+off(dst); \
|
||||
MOVOU 32+off(src), t0; \
|
||||
PXOR v2, t0; \
|
||||
MOVOU t0, 32+off(dst); \
|
||||
MOVOU 48+off(src), t0; \
|
||||
PXOR v3, t0; \
|
||||
MOVOU t0, 48+off(dst)
|
||||
#include "const.s"
|
||||
#include "macro.s"
|
||||
|
||||
// FINALIZE xors len bytes from src and block using
|
||||
// the temp. registers t0 and t1 and writes the result
|
||||
// to dst.
|
||||
#define FINALIZE(dst, src, block, len, t0, t1) \
|
||||
XORL t0, t0; \
|
||||
XORL t1, t1; \
|
||||
finalize: \
|
||||
MOVB 0(src), t0; \
|
||||
MOVB 0(block), t1; \
|
||||
XORL t0, t1; \
|
||||
MOVB t1, 0(dst); \
|
||||
INCL src; \
|
||||
INCL block; \
|
||||
INCL dst; \
|
||||
DECL len; \
|
||||
JA finalize \
|
||||
XORL t0, t0; \
|
||||
XORL t1, t1; \
|
||||
FINALIZE_LOOP:; \
|
||||
MOVB 0(src), t0; \
|
||||
MOVB 0(block), t1; \
|
||||
XORL t0, t1; \
|
||||
MOVB t1, 0(dst); \
|
||||
INCL src; \
|
||||
INCL block; \
|
||||
INCL dst; \
|
||||
DECL len; \
|
||||
JG FINALIZE_LOOP \
|
||||
|
||||
#define Dst DI
|
||||
#define Nonce AX
|
||||
#define Key BX
|
||||
#define Rounds DX
|
||||
|
||||
// func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
|
||||
TEXT ·hChaCha20SSE2(SB), 4, $0-12
|
||||
MOVL out+0(FP), Dst
|
||||
MOVL nonce+4(FP), Nonce
|
||||
MOVL key+8(FP), Key
|
||||
|
||||
MOVOU ·sigma<>(SB), X0
|
||||
MOVOU 0*16(Key), X1
|
||||
MOVOU 1*16(Key), X2
|
||||
MOVOU 0*16(Nonce), X3
|
||||
MOVL $20, Rounds
|
||||
|
||||
chacha_loop:
|
||||
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4)
|
||||
CHACHA_SHUFFLE_SSE(X1, X2, X3)
|
||||
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4)
|
||||
CHACHA_SHUFFLE_SSE(X3, X2, X1)
|
||||
SUBL $2, Rounds
|
||||
JNZ chacha_loop
|
||||
|
||||
MOVOU X0, 0*16(Dst)
|
||||
MOVOU X3, 1*16(Dst)
|
||||
RET
|
||||
|
||||
// func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte)
|
||||
TEXT ·hChaCha20SSSE3(SB), 4, $0-12
|
||||
MOVL out+0(FP), Dst
|
||||
MOVL nonce+4(FP), Nonce
|
||||
MOVL key+8(FP), Key
|
||||
|
||||
MOVOU ·sigma<>(SB), X0
|
||||
MOVOU 0*16(Key), X1
|
||||
MOVOU 1*16(Key), X2
|
||||
MOVOU 0*16(Nonce), X3
|
||||
MOVL $20, Rounds
|
||||
|
||||
MOVOU ·rol16<>(SB), X5
|
||||
MOVOU ·rol8<>(SB), X6
|
||||
|
||||
chacha_loop:
|
||||
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6)
|
||||
CHACHA_SHUFFLE_SSE(X1, X2, X3)
|
||||
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6)
|
||||
CHACHA_SHUFFLE_SSE(X3, X2, X1)
|
||||
SUBL $2, Rounds
|
||||
JNZ chacha_loop
|
||||
|
||||
MOVOU X0, 0*16(Dst)
|
||||
MOVOU X3, 1*16(Dst)
|
||||
RET
|
||||
|
||||
#undef Dst
|
||||
#undef Nonce
|
||||
#undef Key
|
||||
#undef Rounds
|
||||
|
||||
#define State AX
|
||||
#define Dst DI
|
||||
#define Src SI
|
||||
#define Len DX
|
||||
#define Tmp0 BX
|
||||
#define Tmp1 BP
|
||||
|
||||
// func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
|
||||
TEXT ·xorKeyStreamSSE2(SB), 4, $0-40
|
||||
MOVL dst_base+0(FP), DI
|
||||
MOVL src_base+12(FP), SI
|
||||
MOVL src_len+16(FP), CX
|
||||
MOVL state+28(FP), AX
|
||||
MOVL rounds+32(FP), DX
|
||||
MOVL dst_base+0(FP), Dst
|
||||
MOVL src_base+12(FP), Src
|
||||
MOVL state+28(FP), State
|
||||
MOVL src_len+16(FP), Len
|
||||
MOVL $0, ret+36(FP) // Number of bytes written to the keystream buffer - 0 iff len mod 64 == 0
|
||||
|
||||
MOVOU 0(AX), X0
|
||||
MOVOU 16(AX), X1
|
||||
MOVOU 32(AX), X2
|
||||
MOVOU 48(AX), X3
|
||||
MOVOU 0*16(State), X0
|
||||
MOVOU 1*16(State), X1
|
||||
MOVOU 2*16(State), X2
|
||||
MOVOU 3*16(State), X3
|
||||
TESTL Len, Len
|
||||
JZ DONE
|
||||
|
||||
TESTL CX, CX
|
||||
JZ done
|
||||
|
||||
at_least_64:
|
||||
GENERATE_KEYSTREAM:
|
||||
MOVO X0, X4
|
||||
MOVO X1, X5
|
||||
MOVO X2, X6
|
||||
MOVO X3, X7
|
||||
MOVL rounds+32(FP), Tmp0
|
||||
|
||||
MOVL DX, BX
|
||||
|
||||
chacha_loop:
|
||||
CHACHA_LOOP:
|
||||
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0)
|
||||
CHACHA_SHUFFLE(X5, X6, X7)
|
||||
CHACHA_SHUFFLE_SSE(X5, X6, X7)
|
||||
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0)
|
||||
CHACHA_SHUFFLE(X7, X6, X5)
|
||||
SUBL $2, BX
|
||||
JA chacha_loop
|
||||
CHACHA_SHUFFLE_SSE(X7, X6, X5)
|
||||
SUBL $2, Tmp0
|
||||
JA CHACHA_LOOP
|
||||
|
||||
MOVOU 0(AX), X0
|
||||
MOVOU 0*16(State), X0 // Restore X0 from state
|
||||
PADDL X0, X4
|
||||
PADDL X1, X5
|
||||
PADDL X2, X6
|
||||
|
@ -131,181 +130,34 @@ chacha_loop:
|
|||
MOVOU ·one<>(SB), X0
|
||||
PADDQ X0, X3
|
||||
|
||||
CMPL CX, $64
|
||||
JB less_than_64
|
||||
CMPL Len, $64
|
||||
JL BUFFER_KEYSTREAM
|
||||
|
||||
XOR(DI, SI, 0, X4, X5, X6, X7, X0)
|
||||
MOVOU 0(AX), X0
|
||||
ADDL $64, SI
|
||||
ADDL $64, DI
|
||||
SUBL $64, CX
|
||||
JNZ at_least_64
|
||||
XOR_SSE(Dst, Src, 0, X4, X5, X6, X7, X0)
|
||||
MOVOU 0*16(State), X0 // Restore X0 from state
|
||||
ADDL $64, Src
|
||||
ADDL $64, Dst
|
||||
SUBL $64, Len
|
||||
JZ DONE
|
||||
JMP GENERATE_KEYSTREAM // There is at least one more plaintext byte
|
||||
|
||||
less_than_64:
|
||||
MOVL CX, BP
|
||||
TESTL BP, BP
|
||||
JZ done
|
||||
BUFFER_KEYSTREAM:
|
||||
MOVL block+24(FP), State
|
||||
MOVOU X4, 0(State)
|
||||
MOVOU X5, 16(State)
|
||||
MOVOU X6, 32(State)
|
||||
MOVOU X7, 48(State)
|
||||
MOVL Len, ret+36(FP) // Number of bytes written to the keystream buffer - 0 < Len < 64
|
||||
FINALIZE(Dst, Src, State, Len, Tmp0, Tmp1)
|
||||
|
||||
MOVL block+24(FP), BX
|
||||
MOVOU X4, 0(BX)
|
||||
MOVOU X5, 16(BX)
|
||||
MOVOU X6, 32(BX)
|
||||
MOVOU X7, 48(BX)
|
||||
FINALIZE(DI, SI, BX, BP, AX, DX)
|
||||
|
||||
done:
|
||||
MOVL state+28(FP), AX
|
||||
MOVOU X3, 48(AX)
|
||||
MOVL CX, ret+36(FP)
|
||||
DONE:
|
||||
MOVL state+28(FP), State
|
||||
MOVOU X3, 3*16(State)
|
||||
RET
|
||||
|
||||
// func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int
|
||||
TEXT ·xorKeyStreamSSSE3(SB), 4, $64-40
|
||||
MOVL dst_base+0(FP), DI
|
||||
MOVL src_base+12(FP), SI
|
||||
MOVL src_len+16(FP), CX
|
||||
MOVL state+28(FP), AX
|
||||
MOVL rounds+32(FP), DX
|
||||
|
||||
MOVOU 48(AX), X3
|
||||
TESTL CX, CX
|
||||
JZ done
|
||||
|
||||
MOVL SP, BP
|
||||
ADDL $16, SP
|
||||
ANDL $-16, SP
|
||||
|
||||
MOVOU ·one<>(SB), X0
|
||||
MOVOU 16(AX), X1
|
||||
MOVOU 32(AX), X2
|
||||
MOVO X0, 0(SP)
|
||||
MOVO X1, 16(SP)
|
||||
MOVO X2, 32(SP)
|
||||
|
||||
MOVOU 0(AX), X0
|
||||
MOVOU ·rol16<>(SB), X1
|
||||
MOVOU ·rol8<>(SB), X2
|
||||
|
||||
at_least_64:
|
||||
MOVO X0, X4
|
||||
MOVO 16(SP), X5
|
||||
MOVO 32(SP), X6
|
||||
MOVO X3, X7
|
||||
|
||||
MOVL DX, BX
|
||||
|
||||
chacha_loop:
|
||||
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2)
|
||||
CHACHA_SHUFFLE(X5, X6, X7)
|
||||
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2)
|
||||
CHACHA_SHUFFLE(X7, X6, X5)
|
||||
SUBL $2, BX
|
||||
JA chacha_loop
|
||||
|
||||
MOVOU 0(AX), X0
|
||||
PADDL X0, X4
|
||||
PADDL 16(SP), X5
|
||||
PADDL 32(SP), X6
|
||||
PADDL X3, X7
|
||||
PADDQ 0(SP), X3
|
||||
|
||||
CMPL CX, $64
|
||||
JB less_than_64
|
||||
|
||||
XOR(DI, SI, 0, X4, X5, X6, X7, X0)
|
||||
MOVOU 0(AX), X0
|
||||
ADDL $64, SI
|
||||
ADDL $64, DI
|
||||
SUBL $64, CX
|
||||
JNZ at_least_64
|
||||
|
||||
less_than_64:
|
||||
MOVL BP, SP
|
||||
MOVL CX, BP
|
||||
TESTL BP, BP
|
||||
JE done
|
||||
|
||||
MOVL block+24(FP), BX
|
||||
MOVOU X4, 0(BX)
|
||||
MOVOU X5, 16(BX)
|
||||
MOVOU X6, 32(BX)
|
||||
MOVOU X7, 48(BX)
|
||||
FINALIZE(DI, SI, BX, BP, AX, DX)
|
||||
|
||||
done:
|
||||
MOVL state+28(FP), AX
|
||||
MOVOU X3, 48(AX)
|
||||
MOVL CX, ret+36(FP)
|
||||
RET
|
||||
|
||||
// func supportsSSE2() bool
|
||||
TEXT ·supportsSSE2(SB), NOSPLIT, $0-1
|
||||
XORL AX, AX
|
||||
INCL AX
|
||||
CPUID
|
||||
SHRL $26, DX
|
||||
ANDL $1, DX
|
||||
MOVB DX, ret+0(FP)
|
||||
RET
|
||||
|
||||
// func supportsSSSE3() bool
|
||||
TEXT ·supportsSSSE3(SB), NOSPLIT, $0-1
|
||||
XORL AX, AX
|
||||
INCL AX
|
||||
CPUID
|
||||
SHRL $9, CX
|
||||
ANDL $1, CX
|
||||
MOVB CX, ret+0(FP)
|
||||
RET
|
||||
|
||||
// func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
|
||||
TEXT ·hChaCha20SSE2(SB), 4, $0-12
|
||||
MOVL out+0(FP), DI
|
||||
MOVL nonce+4(FP), AX
|
||||
MOVL key+8(FP), BX
|
||||
|
||||
MOVOU ·sigma<>(SB), X0
|
||||
MOVOU 0(BX), X1
|
||||
MOVOU 16(BX), X2
|
||||
MOVOU 0(AX), X3
|
||||
|
||||
MOVL $20, CX
|
||||
|
||||
chacha_loop:
|
||||
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4)
|
||||
CHACHA_SHUFFLE(X1, X2, X3)
|
||||
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4)
|
||||
CHACHA_SHUFFLE(X3, X2, X1)
|
||||
SUBL $2, CX
|
||||
JNZ chacha_loop
|
||||
|
||||
MOVOU X0, 0(DI)
|
||||
MOVOU X3, 16(DI)
|
||||
RET
|
||||
|
||||
// func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte)
|
||||
TEXT ·hChaCha20SSSE3(SB), 4, $0-12
|
||||
MOVL out+0(FP), DI
|
||||
MOVL nonce+4(FP), AX
|
||||
MOVL key+8(FP), BX
|
||||
|
||||
MOVOU ·sigma<>(SB), X0
|
||||
MOVOU 0(BX), X1
|
||||
MOVOU 16(BX), X2
|
||||
MOVOU 0(AX), X3
|
||||
MOVOU ·rol16<>(SB), X5
|
||||
MOVOU ·rol8<>(SB), X6
|
||||
|
||||
MOVL $20, CX
|
||||
|
||||
chacha_loop:
|
||||
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6)
|
||||
CHACHA_SHUFFLE(X1, X2, X3)
|
||||
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6)
|
||||
CHACHA_SHUFFLE(X3, X2, X1)
|
||||
SUBL $2, CX
|
||||
JNZ chacha_loop
|
||||
|
||||
MOVOU X0, 0(DI)
|
||||
MOVOU X3, 16(DI)
|
||||
RET
|
||||
#undef State
|
||||
#undef Dst
|
||||
#undef Src
|
||||
#undef Len
|
||||
#undef Tmp0
|
||||
#undef Tmp1
|
||||
|
|
|
@ -6,24 +6,19 @@
|
|||
|
||||
package chacha
|
||||
|
||||
import "golang.org/x/sys/cpu"
|
||||
|
||||
func init() {
|
||||
useSSE2 = true
|
||||
useSSSE3 = supportsSSSE3()
|
||||
useAVX2 = supportsAVX2()
|
||||
useSSE2 = cpu.X86.HasSSE2
|
||||
useSSSE3 = cpu.X86.HasSSSE3
|
||||
useAVX = cpu.X86.HasAVX
|
||||
useAVX2 = cpu.X86.HasAVX2
|
||||
}
|
||||
|
||||
// This function is implemented in chacha_amd64.s
|
||||
//go:noescape
|
||||
func initialize(state *[64]byte, key []byte, nonce *[16]byte)
|
||||
|
||||
// This function is implemented in chacha_amd64.s
|
||||
//go:noescape
|
||||
func supportsSSSE3() bool
|
||||
|
||||
// This function is implemented in chachaAVX2_amd64.s
|
||||
//go:noescape
|
||||
func supportsAVX2() bool
|
||||
|
||||
// This function is implemented in chacha_amd64.s
|
||||
//go:noescape
|
||||
func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
|
||||
|
@ -44,29 +39,38 @@ func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
|
|||
//go:noescape
|
||||
func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int
|
||||
|
||||
// This function is implemented in chacha_amd64.s
|
||||
//go:noescape
|
||||
func xorKeyStreamAVX(dst, src []byte, block, state *[64]byte, rounds int) int
|
||||
|
||||
// This function is implemented in chachaAVX2_amd64.s
|
||||
//go:noescape
|
||||
func xorKeyStreamAVX2(dst, src []byte, block, state *[64]byte, rounds int) int
|
||||
|
||||
func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) {
|
||||
if useAVX2 {
|
||||
switch {
|
||||
case useAVX:
|
||||
hChaCha20AVX(out, nonce, key)
|
||||
} else if useSSSE3 {
|
||||
case useSSSE3:
|
||||
hChaCha20SSSE3(out, nonce, key)
|
||||
} else if useSSE2 { // on amd64 this is always true - neccessary for testing generic on amd64
|
||||
case useSSE2:
|
||||
hChaCha20SSE2(out, nonce, key)
|
||||
} else {
|
||||
default:
|
||||
hChaCha20Generic(out, nonce, key)
|
||||
}
|
||||
}
|
||||
|
||||
func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int {
|
||||
if useAVX2 {
|
||||
switch {
|
||||
case useAVX2:
|
||||
return xorKeyStreamAVX2(dst, src, block, state, rounds)
|
||||
} else if useSSSE3 {
|
||||
case useAVX:
|
||||
return xorKeyStreamAVX(dst, src, block, state, rounds)
|
||||
case useSSSE3:
|
||||
return xorKeyStreamSSSE3(dst, src, block, state, rounds)
|
||||
} else if useSSE2 { // on amd64 this is always true - neccessary for testing generic on amd64
|
||||
case useSSE2:
|
||||
return xorKeyStreamSSE2(dst, src, block, state, rounds)
|
||||
default:
|
||||
return xorKeyStreamGeneric(dst, src, block, state, rounds)
|
||||
}
|
||||
return xorKeyStreamGeneric(dst, src, block, state, rounds)
|
||||
}
|
1786
vendor/github.com/aead/chacha20/chacha/chacha_amd64.s
generated
vendored
1786
vendor/github.com/aead/chacha20/chacha/chacha_amd64.s
generated
vendored
File diff suppressed because it is too large
Load diff
56
vendor/github.com/aead/chacha20/chacha/chacha_go16_amd64.go
generated
vendored
56
vendor/github.com/aead/chacha20/chacha/chacha_go16_amd64.go
generated
vendored
|
@ -1,56 +0,0 @@
|
|||
// Copyright (c) 2017 Andreas Auernhammer. All rights reserved.
|
||||
// Use of this source code is governed by a license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
// +build amd64,!gccgo,!appengine,!nacl,!go1.7
|
||||
|
||||
package chacha
|
||||
|
||||
func init() {
|
||||
useSSE2 = true
|
||||
useSSSE3 = supportsSSSE3()
|
||||
useAVX2 = false
|
||||
}
|
||||
|
||||
// This function is implemented in chacha_amd64.s
|
||||
//go:noescape
|
||||
func initialize(state *[64]byte, key []byte, nonce *[16]byte)
|
||||
|
||||
// This function is implemented in chacha_amd64.s
|
||||
//go:noescape
|
||||
func supportsSSSE3() bool
|
||||
|
||||
// This function is implemented in chacha_amd64.s
|
||||
//go:noescape
|
||||
func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
|
||||
|
||||
// This function is implemented in chacha_amd64.s
|
||||
//go:noescape
|
||||
func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte)
|
||||
|
||||
// This function is implemented in chacha_amd64.s
|
||||
//go:noescape
|
||||
func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
|
||||
|
||||
// This function is implemented in chacha_amd64.s
|
||||
//go:noescape
|
||||
func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int
|
||||
|
||||
func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) {
|
||||
if useSSSE3 {
|
||||
hChaCha20SSSE3(out, nonce, key)
|
||||
} else if useSSE2 { // on amd64 this is always true - used to test generic on amd64
|
||||
hChaCha20SSE2(out, nonce, key)
|
||||
} else {
|
||||
hChaCha20Generic(out, nonce, key)
|
||||
}
|
||||
}
|
||||
|
||||
func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int {
|
||||
if useSSSE3 {
|
||||
return xorKeyStreamSSSE3(dst, src, block, state, rounds)
|
||||
} else if useSSE2 { // on amd64 this is always true - used to test generic on amd64
|
||||
return xorKeyStreamSSE2(dst, src, block, state, rounds)
|
||||
}
|
||||
return xorKeyStreamGeneric(dst, src, block, state, rounds)
|
||||
}
|
7
vendor/github.com/aead/chacha20/chacha/chacha_ref.go
generated
vendored
7
vendor/github.com/aead/chacha20/chacha/chacha_ref.go
generated
vendored
|
@ -8,6 +8,13 @@ package chacha
|
|||
|
||||
import "encoding/binary"
|
||||
|
||||
func init() {
|
||||
useSSE2 = false
|
||||
useSSSE3 = false
|
||||
useAVX = false
|
||||
useAVX2 = false
|
||||
}
|
||||
|
||||
func initialize(state *[64]byte, key []byte, nonce *[16]byte) {
|
||||
binary.LittleEndian.PutUint32(state[0:], sigma[0])
|
||||
binary.LittleEndian.PutUint32(state[4:], sigma[1])
|
||||
|
|
53
vendor/github.com/aead/chacha20/chacha/const.s
generated
vendored
Normal file
53
vendor/github.com/aead/chacha20/chacha/const.s
generated
vendored
Normal file
|
@ -0,0 +1,53 @@
|
|||
// Copyright (c) 2018 Andreas Auernhammer. All rights reserved.
|
||||
// Use of this source code is governed by a license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
// +build 386,!gccgo,!appengine,!nacl amd64,!gccgo,!appengine,!nacl
|
||||
|
||||
#include "textflag.h"
|
||||
|
||||
DATA ·sigma<>+0x00(SB)/4, $0x61707865
|
||||
DATA ·sigma<>+0x04(SB)/4, $0x3320646e
|
||||
DATA ·sigma<>+0x08(SB)/4, $0x79622d32
|
||||
DATA ·sigma<>+0x0C(SB)/4, $0x6b206574
|
||||
GLOBL ·sigma<>(SB), (NOPTR+RODATA), $16 // The 4 ChaCha initialization constants
|
||||
|
||||
// SSE2/SSE3/AVX constants
|
||||
|
||||
DATA ·one<>+0x00(SB)/8, $1
|
||||
DATA ·one<>+0x08(SB)/8, $0
|
||||
GLOBL ·one<>(SB), (NOPTR+RODATA), $16 // The constant 1 as 128 bit value
|
||||
|
||||
DATA ·rol16<>+0x00(SB)/8, $0x0504070601000302
|
||||
DATA ·rol16<>+0x08(SB)/8, $0x0D0C0F0E09080B0A
|
||||
GLOBL ·rol16<>(SB), (NOPTR+RODATA), $16 // The PSHUFB 16 bit left rotate constant
|
||||
|
||||
DATA ·rol8<>+0x00(SB)/8, $0x0605040702010003
|
||||
DATA ·rol8<>+0x08(SB)/8, $0x0E0D0C0F0A09080B
|
||||
GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16 // The PSHUFB 8 bit left rotate constant
|
||||
|
||||
// AVX2 constants
|
||||
|
||||
DATA ·one_AVX2<>+0x00(SB)/8, $0
|
||||
DATA ·one_AVX2<>+0x08(SB)/8, $0
|
||||
DATA ·one_AVX2<>+0x10(SB)/8, $1
|
||||
DATA ·one_AVX2<>+0x18(SB)/8, $0
|
||||
GLOBL ·one_AVX2<>(SB), (NOPTR+RODATA), $32 // The constant 1 as 256 bit value
|
||||
|
||||
DATA ·two_AVX2<>+0x00(SB)/8, $2
|
||||
DATA ·two_AVX2<>+0x08(SB)/8, $0
|
||||
DATA ·two_AVX2<>+0x10(SB)/8, $2
|
||||
DATA ·two_AVX2<>+0x18(SB)/8, $0
|
||||
GLOBL ·two_AVX2<>(SB), (NOPTR+RODATA), $32
|
||||
|
||||
DATA ·rol16_AVX2<>+0x00(SB)/8, $0x0504070601000302
|
||||
DATA ·rol16_AVX2<>+0x08(SB)/8, $0x0D0C0F0E09080B0A
|
||||
DATA ·rol16_AVX2<>+0x10(SB)/8, $0x0504070601000302
|
||||
DATA ·rol16_AVX2<>+0x18(SB)/8, $0x0D0C0F0E09080B0A
|
||||
GLOBL ·rol16_AVX2<>(SB), (NOPTR+RODATA), $32 // The VPSHUFB 16 bit left rotate constant
|
||||
|
||||
DATA ·rol8_AVX2<>+0x00(SB)/8, $0x0605040702010003
|
||||
DATA ·rol8_AVX2<>+0x08(SB)/8, $0x0E0D0C0F0A09080B
|
||||
DATA ·rol8_AVX2<>+0x10(SB)/8, $0x0605040702010003
|
||||
DATA ·rol8_AVX2<>+0x18(SB)/8, $0x0E0D0C0F0A09080B
|
||||
GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32 // The VPSHUFB 8 bit left rotate constant
|
163
vendor/github.com/aead/chacha20/chacha/macro.s
generated
vendored
Normal file
163
vendor/github.com/aead/chacha20/chacha/macro.s
generated
vendored
Normal file
|
@ -0,0 +1,163 @@
|
|||
// Copyright (c) 2018 Andreas Auernhammer. All rights reserved.
|
||||
// Use of this source code is governed by a license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
// +build 386,!gccgo,!appengine,!nacl amd64,!gccgo,!appengine,!nacl
|
||||
|
||||
// ROTL_SSE rotates all 4 32 bit values of the XMM register v
|
||||
// left by n bits using SSE2 instructions (0 <= n <= 32).
|
||||
// The XMM register t is used as a temp. register.
|
||||
#define ROTL_SSE(n, t, v) \
|
||||
MOVO v, t; \
|
||||
PSLLL $n, t; \
|
||||
PSRLL $(32-n), v; \
|
||||
PXOR t, v
|
||||
|
||||
// ROTL_AVX rotates all 4/8 32 bit values of the AVX/AVX2 register v
|
||||
// left by n bits using AVX/AVX2 instructions (0 <= n <= 32).
|
||||
// The AVX/AVX2 register t is used as a temp. register.
|
||||
#define ROTL_AVX(n, t, v) \
|
||||
VPSLLD $n, v, t; \
|
||||
VPSRLD $(32-n), v, v; \
|
||||
VPXOR v, t, v
|
||||
|
||||
// CHACHA_QROUND_SSE2 performs a ChaCha quarter-round using the
|
||||
// 4 XMM registers v0, v1, v2 and v3. It uses only ROTL_SSE2 for
|
||||
// rotations. The XMM register t is used as a temp. register.
|
||||
#define CHACHA_QROUND_SSE2(v0, v1, v2, v3, t) \
|
||||
PADDL v1, v0; \
|
||||
PXOR v0, v3; \
|
||||
ROTL_SSE(16, t, v3); \
|
||||
PADDL v3, v2; \
|
||||
PXOR v2, v1; \
|
||||
ROTL_SSE(12, t, v1); \
|
||||
PADDL v1, v0; \
|
||||
PXOR v0, v3; \
|
||||
ROTL_SSE(8, t, v3); \
|
||||
PADDL v3, v2; \
|
||||
PXOR v2, v1; \
|
||||
ROTL_SSE(7, t, v1)
|
||||
|
||||
// CHACHA_QROUND_SSSE3 performs a ChaCha quarter-round using the
|
||||
// 4 XMM registers v0, v1, v2 and v3. It uses PSHUFB for 8/16 bit
|
||||
// rotations. The XMM register t is used as a temp. register.
|
||||
//
|
||||
// r16 holds the PSHUFB constant for a 16 bit left rotate.
|
||||
// r8 holds the PSHUFB constant for a 8 bit left rotate.
|
||||
#define CHACHA_QROUND_SSSE3(v0, v1, v2, v3, t, r16, r8) \
|
||||
PADDL v1, v0; \
|
||||
PXOR v0, v3; \
|
||||
PSHUFB r16, v3; \
|
||||
PADDL v3, v2; \
|
||||
PXOR v2, v1; \
|
||||
ROTL_SSE(12, t, v1); \
|
||||
PADDL v1, v0; \
|
||||
PXOR v0, v3; \
|
||||
PSHUFB r8, v3; \
|
||||
PADDL v3, v2; \
|
||||
PXOR v2, v1; \
|
||||
ROTL_SSE(7, t, v1)
|
||||
|
||||
// CHACHA_QROUND_AVX performs a ChaCha quarter-round using the
|
||||
// 4 AVX/AVX2 registers v0, v1, v2 and v3. It uses VPSHUFB for 8/16 bit
|
||||
// rotations. The AVX/AVX2 register t is used as a temp. register.
|
||||
//
|
||||
// r16 holds the VPSHUFB constant for a 16 bit left rotate.
|
||||
// r8 holds the VPSHUFB constant for a 8 bit left rotate.
|
||||
#define CHACHA_QROUND_AVX(v0, v1, v2, v3, t, r16, r8) \
|
||||
VPADDD v0, v1, v0; \
|
||||
VPXOR v3, v0, v3; \
|
||||
VPSHUFB r16, v3, v3; \
|
||||
VPADDD v2, v3, v2; \
|
||||
VPXOR v1, v2, v1; \
|
||||
ROTL_AVX(12, t, v1); \
|
||||
VPADDD v0, v1, v0; \
|
||||
VPXOR v3, v0, v3; \
|
||||
VPSHUFB r8, v3, v3; \
|
||||
VPADDD v2, v3, v2; \
|
||||
VPXOR v1, v2, v1; \
|
||||
ROTL_AVX(7, t, v1)
|
||||
|
||||
// CHACHA_SHUFFLE_SSE performs a ChaCha shuffle using the
|
||||
// 3 XMM registers v1, v2 and v3. The inverse shuffle is
|
||||
// performed by switching v1 and v3: CHACHA_SHUFFLE_SSE(v3, v2, v1).
|
||||
#define CHACHA_SHUFFLE_SSE(v1, v2, v3) \
|
||||
PSHUFL $0x39, v1, v1; \
|
||||
PSHUFL $0x4E, v2, v2; \
|
||||
PSHUFL $0x93, v3, v3
|
||||
|
||||
// CHACHA_SHUFFLE_AVX performs a ChaCha shuffle using the
|
||||
// 3 AVX/AVX2 registers v1, v2 and v3. The inverse shuffle is
|
||||
// performed by switching v1 and v3: CHACHA_SHUFFLE_AVX(v3, v2, v1).
|
||||
#define CHACHA_SHUFFLE_AVX(v1, v2, v3) \
|
||||
VPSHUFD $0x39, v1, v1; \
|
||||
VPSHUFD $0x4E, v2, v2; \
|
||||
VPSHUFD $0x93, v3, v3
|
||||
|
||||
// XOR_SSE extracts 4x16 byte vectors from src at
|
||||
// off, xors all vectors with the corresponding XMM
|
||||
// register (v0 - v3) and writes the result to dst
|
||||
// at off.
|
||||
// The XMM register t is used as a temp. register.
|
||||
#define XOR_SSE(dst, src, off, v0, v1, v2, v3, t) \
|
||||
MOVOU 0+off(src), t; \
|
||||
PXOR v0, t; \
|
||||
MOVOU t, 0+off(dst); \
|
||||
MOVOU 16+off(src), t; \
|
||||
PXOR v1, t; \
|
||||
MOVOU t, 16+off(dst); \
|
||||
MOVOU 32+off(src), t; \
|
||||
PXOR v2, t; \
|
||||
MOVOU t, 32+off(dst); \
|
||||
MOVOU 48+off(src), t; \
|
||||
PXOR v3, t; \
|
||||
MOVOU t, 48+off(dst)
|
||||
|
||||
// XOR_AVX extracts 4x16 byte vectors from src at
|
||||
// off, xors all vectors with the corresponding AVX
|
||||
// register (v0 - v3) and writes the result to dst
|
||||
// at off.
|
||||
// The XMM register t is used as a temp. register.
|
||||
#define XOR_AVX(dst, src, off, v0, v1, v2, v3, t) \
|
||||
VPXOR 0+off(src), v0, t; \
|
||||
VMOVDQU t, 0+off(dst); \
|
||||
VPXOR 16+off(src), v1, t; \
|
||||
VMOVDQU t, 16+off(dst); \
|
||||
VPXOR 32+off(src), v2, t; \
|
||||
VMOVDQU t, 32+off(dst); \
|
||||
VPXOR 48+off(src), v3, t; \
|
||||
VMOVDQU t, 48+off(dst)
|
||||
|
||||
#define XOR_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \
|
||||
VMOVDQU (0+off)(src), t0; \
|
||||
VPERM2I128 $32, v1, v0, t1; \
|
||||
VPXOR t0, t1, t0; \
|
||||
VMOVDQU t0, (0+off)(dst); \
|
||||
VMOVDQU (32+off)(src), t0; \
|
||||
VPERM2I128 $32, v3, v2, t1; \
|
||||
VPXOR t0, t1, t0; \
|
||||
VMOVDQU t0, (32+off)(dst); \
|
||||
VMOVDQU (64+off)(src), t0; \
|
||||
VPERM2I128 $49, v1, v0, t1; \
|
||||
VPXOR t0, t1, t0; \
|
||||
VMOVDQU t0, (64+off)(dst); \
|
||||
VMOVDQU (96+off)(src), t0; \
|
||||
VPERM2I128 $49, v3, v2, t1; \
|
||||
VPXOR t0, t1, t0; \
|
||||
VMOVDQU t0, (96+off)(dst)
|
||||
|
||||
#define XOR_UPPER_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \
|
||||
VMOVDQU (0+off)(src), t0; \
|
||||
VPERM2I128 $32, v1, v0, t1; \
|
||||
VPXOR t0, t1, t0; \
|
||||
VMOVDQU t0, (0+off)(dst); \
|
||||
VMOVDQU (32+off)(src), t0; \
|
||||
VPERM2I128 $32, v3, v2, t1; \
|
||||
VPXOR t0, t1, t0; \
|
||||
VMOVDQU t0, (32+off)(dst); \
|
||||
|
||||
#define EXTRACT_LOWER(dst, v0, v1, v2, v3, t0) \
|
||||
VPERM2I128 $49, v1, v0, t0; \
|
||||
VMOVDQU t0, 0(dst); \
|
||||
VPERM2I128 $49, v3, v2, t0; \
|
||||
VMOVDQU t0, 32(dst)
|
21
vendor/github.com/bifurcation/mint/LICENSE.md
generated
vendored
21
vendor/github.com/bifurcation/mint/LICENSE.md
generated
vendored
|
@ -1,21 +0,0 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2016 Richard Barnes
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
99
vendor/github.com/bifurcation/mint/alert.go
generated
vendored
99
vendor/github.com/bifurcation/mint/alert.go
generated
vendored
|
@ -1,99 +0,0 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mint
|
||||
|
||||
import "strconv"
|
||||
|
||||
type Alert uint8
|
||||
|
||||
const (
|
||||
// alert level
|
||||
AlertLevelWarning = 1
|
||||
AlertLevelError = 2
|
||||
)
|
||||
|
||||
const (
|
||||
AlertCloseNotify Alert = 0
|
||||
AlertUnexpectedMessage Alert = 10
|
||||
AlertBadRecordMAC Alert = 20
|
||||
AlertDecryptionFailed Alert = 21
|
||||
AlertRecordOverflow Alert = 22
|
||||
AlertDecompressionFailure Alert = 30
|
||||
AlertHandshakeFailure Alert = 40
|
||||
AlertBadCertificate Alert = 42
|
||||
AlertUnsupportedCertificate Alert = 43
|
||||
AlertCertificateRevoked Alert = 44
|
||||
AlertCertificateExpired Alert = 45
|
||||
AlertCertificateUnknown Alert = 46
|
||||
AlertIllegalParameter Alert = 47
|
||||
AlertUnknownCA Alert = 48
|
||||
AlertAccessDenied Alert = 49
|
||||
AlertDecodeError Alert = 50
|
||||
AlertDecryptError Alert = 51
|
||||
AlertProtocolVersion Alert = 70
|
||||
AlertInsufficientSecurity Alert = 71
|
||||
AlertInternalError Alert = 80
|
||||
AlertInappropriateFallback Alert = 86
|
||||
AlertUserCanceled Alert = 90
|
||||
AlertNoRenegotiation Alert = 100
|
||||
AlertMissingExtension Alert = 109
|
||||
AlertUnsupportedExtension Alert = 110
|
||||
AlertCertificateUnobtainable Alert = 111
|
||||
AlertUnrecognizedName Alert = 112
|
||||
AlertBadCertificateStatsResponse Alert = 113
|
||||
AlertBadCertificateHashValue Alert = 114
|
||||
AlertUnknownPSKIdentity Alert = 115
|
||||
AlertNoApplicationProtocol Alert = 120
|
||||
AlertWouldBlock Alert = 254
|
||||
AlertNoAlert Alert = 255
|
||||
)
|
||||
|
||||
var alertText = map[Alert]string{
|
||||
AlertCloseNotify: "close notify",
|
||||
AlertUnexpectedMessage: "unexpected message",
|
||||
AlertBadRecordMAC: "bad record MAC",
|
||||
AlertDecryptionFailed: "decryption failed",
|
||||
AlertRecordOverflow: "record overflow",
|
||||
AlertDecompressionFailure: "decompression failure",
|
||||
AlertHandshakeFailure: "handshake failure",
|
||||
AlertBadCertificate: "bad certificate",
|
||||
AlertUnsupportedCertificate: "unsupported certificate",
|
||||
AlertCertificateRevoked: "revoked certificate",
|
||||
AlertCertificateExpired: "expired certificate",
|
||||
AlertCertificateUnknown: "unknown certificate",
|
||||
AlertIllegalParameter: "illegal parameter",
|
||||
AlertUnknownCA: "unknown certificate authority",
|
||||
AlertAccessDenied: "access denied",
|
||||
AlertDecodeError: "error decoding message",
|
||||
AlertDecryptError: "error decrypting message",
|
||||
AlertProtocolVersion: "protocol version not supported",
|
||||
AlertInsufficientSecurity: "insufficient security level",
|
||||
AlertInternalError: "internal error",
|
||||
AlertInappropriateFallback: "inappropriate fallback",
|
||||
AlertUserCanceled: "user canceled",
|
||||
AlertMissingExtension: "missing extension",
|
||||
AlertUnsupportedExtension: "unsupported extension",
|
||||
AlertCertificateUnobtainable: "certificate unobtainable",
|
||||
AlertUnrecognizedName: "unrecognized name",
|
||||
AlertBadCertificateStatsResponse: "bad certificate status response",
|
||||
AlertBadCertificateHashValue: "bad certificate hash value",
|
||||
AlertUnknownPSKIdentity: "unknown PSK identity",
|
||||
AlertNoApplicationProtocol: "no application protocol",
|
||||
AlertNoRenegotiation: "no renegotiation",
|
||||
AlertWouldBlock: "would have blocked",
|
||||
AlertNoAlert: "no alert",
|
||||
}
|
||||
|
||||
func (e Alert) String() string {
|
||||
s, ok := alertText[e]
|
||||
if ok {
|
||||
return s
|
||||
}
|
||||
return "alert(" + strconv.Itoa(int(e)) + ")"
|
||||
}
|
||||
|
||||
func (e Alert) Error() string {
|
||||
return e.String()
|
||||
}
|
42
vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go
generated
vendored
42
vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go
generated
vendored
|
@ -1,42 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
)
|
||||
|
||||
var url string
|
||||
|
||||
func main() {
|
||||
url := flag.String("url", "https://localhost:4430", "URL to send request")
|
||||
flag.Parse()
|
||||
mintdial := func(network, addr string) (net.Conn, error) {
|
||||
return mint.Dial(network, addr, nil)
|
||||
}
|
||||
|
||||
tr := &http.Transport{
|
||||
DialTLS: mintdial,
|
||||
DisableCompression: true,
|
||||
}
|
||||
client := &http.Client{Transport: tr}
|
||||
|
||||
response, err := client.Get(*url)
|
||||
if err != nil {
|
||||
fmt.Println("err:", err)
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
contents, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
fmt.Printf("%s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("%s\n", string(contents))
|
||||
}
|
37
vendor/github.com/bifurcation/mint/bin/mint-client/main.go
generated
vendored
37
vendor/github.com/bifurcation/mint/bin/mint-client/main.go
generated
vendored
|
@ -1,37 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
)
|
||||
|
||||
var addr string
|
||||
|
||||
func main() {
|
||||
flag.StringVar(&addr, "addr", "localhost:4430", "port")
|
||||
flag.Parse()
|
||||
|
||||
conn, err := mint.Dial("tcp", addr, nil)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("TLS handshake failed:", err)
|
||||
return
|
||||
}
|
||||
|
||||
request := "GET / HTTP/1.0\r\n\r\n"
|
||||
conn.Write([]byte(request))
|
||||
|
||||
response := ""
|
||||
buffer := make([]byte, 1024)
|
||||
var read int
|
||||
for err == nil {
|
||||
read, err = conn.Read(buffer)
|
||||
fmt.Println(" ~~ read: ", read)
|
||||
response += string(buffer)
|
||||
}
|
||||
fmt.Println("err:", err)
|
||||
fmt.Println("Received from server:")
|
||||
fmt.Println(response)
|
||||
}
|
226
vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go
generated
vendored
226
vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go
generated
vendored
|
@ -1,226 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
var (
|
||||
port string
|
||||
serverName string
|
||||
certFile string
|
||||
keyFile string
|
||||
responseFile string
|
||||
h2 bool
|
||||
sendTickets bool
|
||||
)
|
||||
|
||||
type responder []byte
|
||||
|
||||
func (rsp responder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(rsp)
|
||||
}
|
||||
|
||||
// ParsePrivateKeyDER parses a PKCS #1, PKCS #8, or elliptic curve
|
||||
// PEM-encoded private key.
|
||||
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
|
||||
func ParsePrivateKeyPEM(keyPEM []byte) (key crypto.Signer, err error) {
|
||||
keyDER, _ := pem.Decode(keyPEM)
|
||||
if keyDER == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
generalKey, err := x509.ParsePKCS8PrivateKey(keyDER.Bytes)
|
||||
if err != nil {
|
||||
generalKey, err = x509.ParsePKCS1PrivateKey(keyDER.Bytes)
|
||||
if err != nil {
|
||||
generalKey, err = x509.ParseECPrivateKey(keyDER.Bytes)
|
||||
if err != nil {
|
||||
// We don't include the actual error into
|
||||
// the final error. The reason might be
|
||||
// we don't want to leak any info about
|
||||
// the private key.
|
||||
return nil, fmt.Errorf("No successful private key decoder")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch generalKey.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return generalKey.(*rsa.PrivateKey), nil
|
||||
case *ecdsa.PrivateKey:
|
||||
return generalKey.(*ecdsa.PrivateKey), nil
|
||||
}
|
||||
|
||||
// should never reach here
|
||||
return nil, fmt.Errorf("Should be unreachable")
|
||||
}
|
||||
|
||||
// ParseOneCertificateFromPEM attempts to parse one PEM encoded certificate object,
|
||||
// either a raw x509 certificate or a PKCS #7 structure possibly containing
|
||||
// multiple certificates, from the top of certsPEM, which itself may
|
||||
// contain multiple PEM encoded certificate objects.
|
||||
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
|
||||
func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, error) {
|
||||
block, rest := pem.Decode(certsPEM)
|
||||
if block == nil {
|
||||
return nil, rest, nil
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
var certs = []*x509.Certificate{cert}
|
||||
return certs, rest, err
|
||||
}
|
||||
|
||||
// ParseCertificatesPEM parses a sequence of PEM-encoded certificate and returns them,
|
||||
// can handle PEM encoded PKCS #7 structures.
|
||||
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
|
||||
func ParseCertificatesPEM(certsPEM []byte) ([]*x509.Certificate, error) {
|
||||
var certs []*x509.Certificate
|
||||
var err error
|
||||
certsPEM = bytes.TrimSpace(certsPEM)
|
||||
for len(certsPEM) > 0 {
|
||||
var cert []*x509.Certificate
|
||||
cert, certsPEM, err = ParseOneCertificateFromPEM(certsPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if cert == nil {
|
||||
break
|
||||
}
|
||||
|
||||
certs = append(certs, cert...)
|
||||
}
|
||||
if len(certsPEM) > 0 {
|
||||
return nil, fmt.Errorf("Trailing PEM data")
|
||||
}
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.StringVar(&port, "port", "4430", "port")
|
||||
flag.StringVar(&serverName, "host", "example.com", "hostname")
|
||||
flag.StringVar(&certFile, "cert", "", "certificate chain in PEM or DER")
|
||||
flag.StringVar(&keyFile, "key", "", "private key in PEM format")
|
||||
flag.StringVar(&responseFile, "response", "", "file to serve")
|
||||
flag.BoolVar(&h2, "h2", false, "whether to use HTTP/2 (exclusively)")
|
||||
flag.BoolVar(&sendTickets, "tickets", true, "whether to send session tickets")
|
||||
flag.Parse()
|
||||
|
||||
var certChain []*x509.Certificate
|
||||
var priv crypto.Signer
|
||||
var response []byte
|
||||
var err error
|
||||
|
||||
// Load the key and certificate chain
|
||||
if certFile != "" {
|
||||
certs, err := ioutil.ReadFile(certFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Error: %v", err)
|
||||
} else {
|
||||
certChain, err = ParseCertificatesPEM(certs)
|
||||
if err != nil {
|
||||
certChain, err = x509.ParseCertificates(certs)
|
||||
if err != nil {
|
||||
log.Fatalf("Error parsing certificates: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if keyFile != "" {
|
||||
keyPEM, err := ioutil.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Error: %v", err)
|
||||
} else {
|
||||
priv, err = ParsePrivateKeyPEM(keyPEM)
|
||||
if priv == nil || err != nil {
|
||||
log.Fatalf("Error parsing private key: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatalf("Error: %v", err)
|
||||
}
|
||||
|
||||
// Load response file
|
||||
if responseFile != "" {
|
||||
log.Printf("Loading response file: %v", responseFile)
|
||||
response, err = ioutil.ReadFile(responseFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Error: %v", err)
|
||||
}
|
||||
} else {
|
||||
response = []byte("Welcome to the TLS 1.3 zone!")
|
||||
}
|
||||
handler := responder(response)
|
||||
|
||||
config := mint.Config{
|
||||
SendSessionTickets: true,
|
||||
ServerName: serverName,
|
||||
NextProtos: []string{"http/1.1"},
|
||||
}
|
||||
|
||||
if h2 {
|
||||
config.NextProtos = []string{"h2"}
|
||||
}
|
||||
|
||||
config.SendSessionTickets = sendTickets
|
||||
|
||||
if certChain != nil && priv != nil {
|
||||
log.Printf("Loading cert: %v key: %v", certFile, keyFile)
|
||||
config.Certificates = []*mint.Certificate{
|
||||
{
|
||||
Chain: certChain,
|
||||
PrivateKey: priv,
|
||||
},
|
||||
}
|
||||
}
|
||||
config.Init(false)
|
||||
|
||||
service := "0.0.0.0:" + port
|
||||
srv := &http.Server{Handler: handler}
|
||||
|
||||
log.Printf("Listening on port %v", port)
|
||||
// Need the inner loop here because the h1 server errors on a dropped connection
|
||||
// Need the outer loop here because the h2 server is per-connection
|
||||
for {
|
||||
listener, err := mint.Listen("tcp", service, &config)
|
||||
if err != nil {
|
||||
log.Printf("Listen Error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !h2 {
|
||||
alert := srv.Serve(listener)
|
||||
if alert != mint.AlertNoAlert {
|
||||
log.Printf("Serve Error: %v", err)
|
||||
}
|
||||
} else {
|
||||
srv2 := new(http2.Server)
|
||||
opts := &http2.ServeConnOpts{
|
||||
Handler: handler,
|
||||
BaseConfig: srv,
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("Accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
go srv2.ServeConn(conn, opts)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
65
vendor/github.com/bifurcation/mint/bin/mint-server/main.go
generated
vendored
65
vendor/github.com/bifurcation/mint/bin/mint-server/main.go
generated
vendored
|
@ -1,65 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"net"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
)
|
||||
|
||||
var port string
|
||||
|
||||
func main() {
|
||||
var config mint.Config
|
||||
config.SendSessionTickets = true
|
||||
config.ServerName = "localhost"
|
||||
config.Init(false)
|
||||
|
||||
flag.StringVar(&port, "port", "4430", "port")
|
||||
flag.Parse()
|
||||
|
||||
service := "0.0.0.0:" + port
|
||||
listener, err := mint.Listen("tcp", service, &config)
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("server: listen: %s", err)
|
||||
}
|
||||
log.Print("server: listening")
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("server: accept: %s", err)
|
||||
break
|
||||
}
|
||||
defer conn.Close()
|
||||
log.Printf("server: accepted from %s", conn.RemoteAddr())
|
||||
go handleClient(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func handleClient(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 10)
|
||||
for {
|
||||
log.Print("server: conn: waiting")
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if err != nil {
|
||||
log.Printf("server: conn: read: %s", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
n, err = conn.Write([]byte("hello world"))
|
||||
log.Printf("server: conn: wrote %d bytes", n)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("server: write: %s", err)
|
||||
break
|
||||
}
|
||||
break
|
||||
}
|
||||
log.Println("server: conn: closed")
|
||||
}
|
942
vendor/github.com/bifurcation/mint/client-state-machine.go
generated
vendored
942
vendor/github.com/bifurcation/mint/client-state-machine.go
generated
vendored
|
@ -1,942 +0,0 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"hash"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client State Machine
|
||||
//
|
||||
// START <----+
|
||||
// Send ClientHello | | Recv HelloRetryRequest
|
||||
// / v |
|
||||
// | WAIT_SH ---+
|
||||
// Can | | Recv ServerHello
|
||||
// send | V
|
||||
// early | WAIT_EE
|
||||
// data | | Recv EncryptedExtensions
|
||||
// | +--------+--------+
|
||||
// | Using | | Using certificate
|
||||
// | PSK | v
|
||||
// | | WAIT_CERT_CR
|
||||
// | | Recv | | Recv CertificateRequest
|
||||
// | | Certificate | v
|
||||
// | | | WAIT_CERT
|
||||
// | | | | Recv Certificate
|
||||
// | | v v
|
||||
// | | WAIT_CV
|
||||
// | | | Recv CertificateVerify
|
||||
// | +> WAIT_FINISHED <+
|
||||
// | | Recv Finished
|
||||
// \ |
|
||||
// | [Send EndOfEarlyData]
|
||||
// | [Send Certificate [+ CertificateVerify]]
|
||||
// | Send Finished
|
||||
// Can send v
|
||||
// app data --> CONNECTED
|
||||
// after
|
||||
// here
|
||||
//
|
||||
// State Instructions
|
||||
// START Send(CH); [RekeyOut; SendEarlyData]
|
||||
// WAIT_SH Send(CH) || RekeyIn
|
||||
// WAIT_EE {}
|
||||
// WAIT_CERT_CR {}
|
||||
// WAIT_CERT {}
|
||||
// WAIT_CV {}
|
||||
// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut;
|
||||
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
|
||||
|
||||
type ClientStateStart struct {
|
||||
Caps Capabilities
|
||||
Opts ConnectionOptions
|
||||
Params ConnectionParameters
|
||||
|
||||
cookie []byte
|
||||
firstClientHello *HandshakeMessage
|
||||
helloRetryRequest *HandshakeMessage
|
||||
}
|
||||
|
||||
func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Unexpected non-nil message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
// key_shares
|
||||
offeredDH := map[NamedGroup][]byte{}
|
||||
ks := KeyShareExtension{
|
||||
HandshakeType: HandshakeTypeClientHello,
|
||||
Shares: make([]KeyShareEntry, len(state.Caps.Groups)),
|
||||
}
|
||||
for i, group := range state.Caps.Groups {
|
||||
pub, priv, err := newKeyShare(group)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
ks.Shares[i].Group = group
|
||||
ks.Shares[i].KeyExchange = pub
|
||||
offeredDH[group] = priv
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "opts: %+v", state.Opts)
|
||||
|
||||
// supported_versions, supported_groups, signature_algorithms, server_name
|
||||
sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}}
|
||||
sni := ServerNameExtension(state.Opts.ServerName)
|
||||
sg := SupportedGroupsExtension{Groups: state.Caps.Groups}
|
||||
sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
|
||||
|
||||
state.Params.ServerName = state.Opts.ServerName
|
||||
|
||||
// Application Layer Protocol Negotiation
|
||||
var alpn *ALPNExtension
|
||||
if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) {
|
||||
alpn = &ALPNExtension{Protocols: state.Opts.NextProtos}
|
||||
}
|
||||
|
||||
// Construct base ClientHello
|
||||
ch := &ClientHelloBody{
|
||||
CipherSuites: state.Caps.CipherSuites,
|
||||
}
|
||||
_, err := prng.Read(ch.Random[:])
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error creating ClientHello random [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
for _, ext := range []ExtensionBody{&sv, &sni, &ks, &sg, &sa} {
|
||||
err := ch.Extensions.Add(ext)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error adding extension type=[%v] [%v]", ext.Type(), err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
// XXX: These optional extensions can't be folded into the above because Go
|
||||
// interface-typed values are never reported as nil
|
||||
if alpn != nil {
|
||||
err := ch.Extensions.Add(alpn)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
if state.cookie != nil {
|
||||
err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Run the external extension handler.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Handle PSK and EarlyData just before transmitting, so that we can
|
||||
// calculate the PSK binder value
|
||||
var psk *PreSharedKeyExtension
|
||||
var ed *EarlyDataExtension
|
||||
var offeredPSK PreSharedKey
|
||||
var earlyHash crypto.Hash
|
||||
var earlySecret []byte
|
||||
var clientEarlyTrafficKeys keySet
|
||||
var clientHello *HandshakeMessage
|
||||
if key, ok := state.Caps.PSKs.Get(state.Opts.ServerName); ok {
|
||||
offeredPSK = key
|
||||
|
||||
// Narrow ciphersuites to ones that match PSK hash
|
||||
params, ok := cipherSuiteMap[key.CipherSuite]
|
||||
if !ok {
|
||||
logf(logTypeHandshake, "[ClientStateStart] PSK for unknown ciphersuite")
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
compatibleSuites := []CipherSuite{}
|
||||
for _, suite := range ch.CipherSuites {
|
||||
if cipherSuiteMap[suite].Hash == params.Hash {
|
||||
compatibleSuites = append(compatibleSuites, suite)
|
||||
}
|
||||
}
|
||||
ch.CipherSuites = compatibleSuites
|
||||
|
||||
// Signal early data if we're going to do it
|
||||
if len(state.Opts.EarlyData) > 0 {
|
||||
state.Params.ClientSendingEarlyData = true
|
||||
ed = &EarlyDataExtension{}
|
||||
err = ch.Extensions.Add(ed)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "Error adding early data extension: %v", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Signal supported PSK key exchange modes
|
||||
if len(state.Caps.PSKModes) == 0 {
|
||||
logf(logTypeHandshake, "PSK selected, but no PSKModes")
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
kem := &PSKKeyExchangeModesExtension{KEModes: state.Caps.PSKModes}
|
||||
err = ch.Extensions.Add(kem)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
// Add the shim PSK extension to the ClientHello
|
||||
logf(logTypeHandshake, "Adding PSK extension with id = %x", key.Identity)
|
||||
psk = &PreSharedKeyExtension{
|
||||
HandshakeType: HandshakeTypeClientHello,
|
||||
Identities: []PSKIdentity{
|
||||
{
|
||||
Identity: key.Identity,
|
||||
ObfuscatedTicketAge: uint32(time.Since(key.ReceivedAt)/time.Millisecond) + key.TicketAgeAdd,
|
||||
},
|
||||
},
|
||||
Binders: []PSKBinderEntry{
|
||||
// Note: Stub to get the length fields right
|
||||
{Binder: bytes.Repeat([]byte{0x00}, params.Hash.Size())},
|
||||
},
|
||||
}
|
||||
ch.Extensions.Add(psk)
|
||||
|
||||
// Compute the binder key
|
||||
h0 := params.Hash.New().Sum(nil)
|
||||
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||
|
||||
earlyHash = params.Hash
|
||||
earlySecret = HkdfExtract(params.Hash, zero, key.Key)
|
||||
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
|
||||
|
||||
binderLabel := labelExternalBinder
|
||||
if key.IsResumption {
|
||||
binderLabel = labelResumptionBinder
|
||||
}
|
||||
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
|
||||
logf(logTypeCrypto, "binder key: [%d] %x", len(binderKey), binderKey)
|
||||
|
||||
// Compute the binder value
|
||||
trunc, err := ch.Truncated()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error marshaling truncated ClientHello [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
truncHash := params.Hash.New()
|
||||
truncHash.Write(trunc)
|
||||
|
||||
binder := computeFinishedData(params, binderKey, truncHash.Sum(nil))
|
||||
|
||||
// Replace the PSK extension
|
||||
psk.Binders[0].Binder = binder
|
||||
ch.Extensions.Add(psk)
|
||||
|
||||
// If we got here, the earlier marshal succeeded (in ch.Truncated()), so
|
||||
// this one should too.
|
||||
clientHello, _ = HandshakeMessageFromBody(ch)
|
||||
|
||||
// Compute early traffic keys
|
||||
h := params.Hash.New()
|
||||
h.Write(clientHello.Marshal())
|
||||
chHash := h.Sum(nil)
|
||||
|
||||
earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
|
||||
logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret)
|
||||
clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret)
|
||||
} else if len(state.Opts.EarlyData) > 0 {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK")
|
||||
return nil, nil, AlertInternalError
|
||||
} else {
|
||||
clientHello, err = HandshakeMessageFromBody(ch)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]")
|
||||
nextState := ClientStateWaitSH{
|
||||
Caps: state.Caps,
|
||||
Opts: state.Opts,
|
||||
Params: state.Params,
|
||||
OfferedDH: offeredDH,
|
||||
OfferedPSK: offeredPSK,
|
||||
|
||||
earlySecret: earlySecret,
|
||||
earlyHash: earlyHash,
|
||||
|
||||
firstClientHello: state.firstClientHello,
|
||||
helloRetryRequest: state.helloRetryRequest,
|
||||
clientHello: clientHello,
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{
|
||||
SendHandshakeMessage{clientHello},
|
||||
}
|
||||
if state.Params.ClientSendingEarlyData {
|
||||
toSend = append(toSend, []HandshakeAction{
|
||||
RekeyOut{Label: "early", KeySet: clientEarlyTrafficKeys},
|
||||
SendEarlyData{},
|
||||
}...)
|
||||
}
|
||||
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
type ClientStateWaitSH struct {
|
||||
Caps Capabilities
|
||||
Opts ConnectionOptions
|
||||
Params ConnectionParameters
|
||||
OfferedDH map[NamedGroup][]byte
|
||||
OfferedPSK PreSharedKey
|
||||
PSK []byte
|
||||
|
||||
earlySecret []byte
|
||||
earlyHash crypto.Hash
|
||||
|
||||
firstClientHello *HandshakeMessage
|
||||
helloRetryRequest *HandshakeMessage
|
||||
clientHello *HandshakeMessage
|
||||
}
|
||||
|
||||
func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
bodyGeneric, err := hm.ToBody()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
switch body := bodyGeneric.(type) {
|
||||
case *HelloRetryRequestBody:
|
||||
hrr := body
|
||||
|
||||
if state.helloRetryRequest != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
// Check that the version sent by the server is the one we support
|
||||
if hrr.Version != supportedVersion {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version)
|
||||
return nil, nil, AlertProtocolVersion
|
||||
}
|
||||
|
||||
// Check that the server provided a supported ciphersuite
|
||||
supportedCipherSuite := false
|
||||
for _, suite := range state.Caps.CipherSuites {
|
||||
supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite)
|
||||
}
|
||||
if !supportedCipherSuite {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Narrow the supported ciphersuites to the server-provided one
|
||||
state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite}
|
||||
|
||||
// Handle external extensions.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// The only thing we know how to respond to in an HRR is the Cookie
|
||||
// extension, so if there is either no Cookie extension or anything other
|
||||
// than a Cookie extension, we have to fail.
|
||||
serverCookie := new(CookieExtension)
|
||||
foundCookie := hrr.Extensions.Find(serverCookie)
|
||||
if !foundCookie || len(hrr.Extensions) != 1 {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions))
|
||||
return nil, nil, AlertIllegalParameter
|
||||
}
|
||||
|
||||
// Hash the body into a pseudo-message
|
||||
// XXX: Ignoring some errors here
|
||||
params := cipherSuiteMap[hrr.CipherSuite]
|
||||
h := params.Hash.New()
|
||||
h.Write(state.clientHello.Marshal())
|
||||
firstClientHello := &HandshakeMessage{
|
||||
msgType: HandshakeTypeMessageHash,
|
||||
body: h.Sum(nil),
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]")
|
||||
return ClientStateStart{
|
||||
Caps: state.Caps,
|
||||
Opts: state.Opts,
|
||||
cookie: serverCookie.Cookie,
|
||||
firstClientHello: firstClientHello,
|
||||
helloRetryRequest: hm,
|
||||
}.Next(nil)
|
||||
|
||||
case *ServerHelloBody:
|
||||
sh := body
|
||||
|
||||
// Check that the version sent by the server is the one we support
|
||||
if sh.Version != supportedVersion {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version)
|
||||
return nil, nil, AlertProtocolVersion
|
||||
}
|
||||
|
||||
// Check that the server provided a supported ciphersuite
|
||||
supportedCipherSuite := false
|
||||
for _, suite := range state.Caps.CipherSuites {
|
||||
supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite)
|
||||
}
|
||||
if !supportedCipherSuite {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Handle external extensions.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Do PSK or key agreement depending on extensions
|
||||
serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello}
|
||||
serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello}
|
||||
|
||||
foundPSK := sh.Extensions.Find(&serverPSK)
|
||||
foundKeyShare := sh.Extensions.Find(&serverKeyShare)
|
||||
|
||||
if foundPSK && (serverPSK.SelectedIdentity == 0) {
|
||||
state.Params.UsingPSK = true
|
||||
}
|
||||
|
||||
var dhSecret []byte
|
||||
if foundKeyShare {
|
||||
sks := serverKeyShare.Shares[0]
|
||||
priv, ok := state.OfferedDH[sks.Group]
|
||||
if !ok {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group")
|
||||
return nil, nil, AlertIllegalParameter
|
||||
}
|
||||
|
||||
state.Params.UsingDH = true
|
||||
dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv)
|
||||
}
|
||||
|
||||
suite := sh.CipherSuite
|
||||
state.Params.CipherSuite = suite
|
||||
|
||||
params, ok := cipherSuiteMap[suite]
|
||||
if !ok {
|
||||
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Start up the handshake hash
|
||||
handshakeHash := params.Hash.New()
|
||||
handshakeHash.Write(state.firstClientHello.Marshal())
|
||||
handshakeHash.Write(state.helloRetryRequest.Marshal())
|
||||
handshakeHash.Write(state.clientHello.Marshal())
|
||||
handshakeHash.Write(hm.Marshal())
|
||||
|
||||
// Compute handshake secrets
|
||||
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||
|
||||
var earlySecret []byte
|
||||
if state.Params.UsingPSK {
|
||||
if params.Hash != state.earlyHash {
|
||||
logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]",
|
||||
state.earlyHash, suite, params.Hash)
|
||||
}
|
||||
|
||||
earlySecret = state.earlySecret
|
||||
} else {
|
||||
earlySecret = HkdfExtract(params.Hash, zero, zero)
|
||||
}
|
||||
|
||||
if dhSecret == nil {
|
||||
dhSecret = zero
|
||||
}
|
||||
|
||||
h0 := params.Hash.New().Sum(nil)
|
||||
h2 := handshakeHash.Sum(nil)
|
||||
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
|
||||
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret)
|
||||
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
|
||||
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
|
||||
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
|
||||
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
|
||||
|
||||
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
|
||||
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
|
||||
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
|
||||
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
|
||||
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
|
||||
|
||||
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
|
||||
nextState := ClientStateWaitEE{
|
||||
Caps: state.Caps,
|
||||
Params: state.Params,
|
||||
cryptoParams: params,
|
||||
handshakeHash: handshakeHash,
|
||||
certificates: state.Caps.Certificates,
|
||||
masterSecret: masterSecret,
|
||||
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: serverHandshakeTrafficSecret,
|
||||
}
|
||||
toSend := []HandshakeAction{
|
||||
RekeyIn{Label: "handshake", KeySet: serverHandshakeKeys},
|
||||
}
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType)
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
type ClientStateWaitEE struct {
|
||||
Caps Capabilities
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
handshakeHash hash.Hash
|
||||
certificates []*Certificate
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
serverHandshakeTrafficSecret []byte
|
||||
}
|
||||
|
||||
func (state ClientStateWaitEE) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions {
|
||||
logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
ee := EncryptedExtensionsBody{}
|
||||
_, err := ee.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
// Handle external extensions.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
serverALPN := ALPNExtension{}
|
||||
serverEarlyData := EarlyDataExtension{}
|
||||
|
||||
gotALPN := ee.Extensions.Find(&serverALPN)
|
||||
state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData)
|
||||
|
||||
if gotALPN && len(serverALPN.Protocols) > 0 {
|
||||
state.Params.NextProto = serverALPN.Protocols[0]
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
if state.Params.UsingPSK {
|
||||
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]")
|
||||
nextState := ClientStateWaitFinished{
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]")
|
||||
nextState := ClientStateWaitCertCR{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ClientStateWaitCertCR struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
handshakeHash hash.Hash
|
||||
certificates []*Certificate
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
serverHandshakeTrafficSecret []byte
|
||||
}
|
||||
|
||||
func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
bodyGeneric, err := hm.ToBody()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
switch body := bodyGeneric.(type) {
|
||||
case *CertificateBody:
|
||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]")
|
||||
nextState := ClientStateWaitCV{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
serverCertificate: body,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
|
||||
case *CertificateRequestBody:
|
||||
// A certificate request in the handshake should have a zero-length context
|
||||
if len(body.CertificateRequestContext) > 0 {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] Certificate request with non-empty context: %v", err)
|
||||
return nil, nil, AlertIllegalParameter
|
||||
}
|
||||
|
||||
state.Params.UsingClientAuth = true
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]")
|
||||
nextState := ClientStateWaitCert{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
serverCertificateRequest: body,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
type ClientStateWaitCert struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
handshakeHash hash.Hash
|
||||
|
||||
certificates []*Certificate
|
||||
serverCertificateRequest *CertificateRequestBody
|
||||
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
serverHandshakeTrafficSecret []byte
|
||||
}
|
||||
|
||||
func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeCertificate {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
cert := &CertificateBody{}
|
||||
_, err := cert.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]")
|
||||
nextState := ClientStateWaitCV{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
serverCertificate: cert,
|
||||
serverCertificateRequest: state.serverCertificateRequest,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ClientStateWaitCV struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
handshakeHash hash.Hash
|
||||
|
||||
certificates []*Certificate
|
||||
serverCertificate *CertificateBody
|
||||
serverCertificateRequest *CertificateRequestBody
|
||||
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
serverHandshakeTrafficSecret []byte
|
||||
}
|
||||
|
||||
func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
certVerify := CertificateVerifyBody{}
|
||||
_, err := certVerify.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
hcv := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||
|
||||
serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey
|
||||
if err := certVerify.Verify(serverPublicKey, hcv); err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify")
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
if state.AuthCertificate != nil {
|
||||
err := state.AuthCertificate(state.serverCertificate.CertificateList)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate")
|
||||
return nil, nil, AlertBadCertificate
|
||||
}
|
||||
} else {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] WARNING: No verification of server certificate")
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]")
|
||||
nextState := ClientStateWaitFinished{
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
serverCertificateRequest: state.serverCertificateRequest,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ClientStateWaitFinished struct {
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
handshakeHash hash.Hash
|
||||
|
||||
certificates []*Certificate
|
||||
serverCertificateRequest *CertificateRequestBody
|
||||
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
serverHandshakeTrafficSecret []byte
|
||||
}
|
||||
|
||||
func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeFinished {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
// Verify server's Finished
|
||||
h3 := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3)
|
||||
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3)
|
||||
|
||||
serverFinishedData := computeFinishedData(state.cryptoParams, state.serverHandshakeTrafficSecret, h3)
|
||||
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
|
||||
|
||||
fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)}
|
||||
_, err := fin.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
if !bytes.Equal(fin.VerifyData, serverFinishedData) {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Server's Finished failed to verify [%x] != [%x]",
|
||||
fin.VerifyData, serverFinishedData)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Update the handshake hash with the Finished
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(hm.Marshal()), hm.Marshal())
|
||||
h4 := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash 4 [%d]: %x", len(h4), h4)
|
||||
|
||||
// Compute traffic secrets and keys
|
||||
clientTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelClientApplicationTrafficSecret, h4)
|
||||
serverTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelServerApplicationTrafficSecret, h4)
|
||||
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret)
|
||||
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret)
|
||||
|
||||
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, clientTrafficSecret)
|
||||
serverTrafficKeys := makeTrafficKeys(state.cryptoParams, serverTrafficSecret)
|
||||
|
||||
exporterSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelExporterSecret, h4)
|
||||
logf(logTypeCrypto, "client exporter secret: [%d] %x", len(exporterSecret), exporterSecret)
|
||||
|
||||
// Assemble client's second flight
|
||||
toSend := []HandshakeAction{}
|
||||
|
||||
if state.Params.UsingEarlyData {
|
||||
// Note: We only send EOED if the server is actually going to use the early
|
||||
// data. Otherwise, it will never see it, and the transcripts will
|
||||
// mismatch.
|
||||
// EOED marshal is infallible
|
||||
eoedm, _ := HandshakeMessageFromBody(&EndOfEarlyDataBody{})
|
||||
toSend = append(toSend, SendHandshakeMessage{eoedm})
|
||||
state.handshakeHash.Write(eoedm.Marshal())
|
||||
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal())
|
||||
}
|
||||
|
||||
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
|
||||
toSend = append(toSend, RekeyOut{Label: "handshake", KeySet: clientHandshakeKeys})
|
||||
|
||||
if state.Params.UsingClientAuth {
|
||||
// Extract constraints from certicateRequest
|
||||
schemes := SignatureAlgorithmsExtension{}
|
||||
gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes)
|
||||
if !gotSchemes {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
|
||||
return nil, nil, AlertIllegalParameter
|
||||
}
|
||||
|
||||
// Select a certificate
|
||||
cert, certScheme, err := CertificateSelection(nil, schemes.Algorithms, state.certificates)
|
||||
if err != nil {
|
||||
// XXX: Signal this to the application layer?
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
|
||||
|
||||
certificate := &CertificateBody{}
|
||||
certm, err := HandshakeMessageFromBody(certificate)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{certm})
|
||||
state.handshakeHash.Write(certm.Marshal())
|
||||
} else {
|
||||
// Create and send Certificate, CertificateVerify
|
||||
certificate := &CertificateBody{
|
||||
CertificateList: make([]CertificateEntry, len(cert.Chain)),
|
||||
}
|
||||
for i, entry := range cert.Chain {
|
||||
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
|
||||
}
|
||||
certm, err := HandshakeMessageFromBody(certificate)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{certm})
|
||||
state.handshakeHash.Write(certm.Marshal())
|
||||
|
||||
hcv := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||
|
||||
certificateVerify := &CertificateVerifyBody{Algorithm: certScheme}
|
||||
logf(logTypeHandshake, "Creating CertVerify: %04x %v", certScheme, state.cryptoParams.Hash)
|
||||
|
||||
err = certificateVerify.Sign(cert.PrivateKey, hcv)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
certvm, err := HandshakeMessageFromBody(certificateVerify)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{certvm})
|
||||
state.handshakeHash.Write(certvm.Marshal())
|
||||
}
|
||||
}
|
||||
|
||||
// Compute the client's Finished message
|
||||
h5 := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5)
|
||||
|
||||
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5)
|
||||
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData)
|
||||
|
||||
fin = &FinishedBody{
|
||||
VerifyDataLen: len(clientFinishedData),
|
||||
VerifyData: clientFinishedData,
|
||||
}
|
||||
finm, err := HandshakeMessageFromBody(fin)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
// Compute the resumption secret
|
||||
state.handshakeHash.Write(finm.Marshal())
|
||||
h6 := state.handshakeHash.Sum(nil)
|
||||
|
||||
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6)
|
||||
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
|
||||
|
||||
toSend = append(toSend, []HandshakeAction{
|
||||
SendHandshakeMessage{finm},
|
||||
RekeyIn{Label: "application", KeySet: serverTrafficKeys},
|
||||
RekeyOut{Label: "application", KeySet: clientTrafficKeys},
|
||||
}...)
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]")
|
||||
nextState := StateConnected{
|
||||
Params: state.Params,
|
||||
isClient: true,
|
||||
cryptoParams: state.cryptoParams,
|
||||
resumptionSecret: resumptionSecret,
|
||||
clientTrafficSecret: clientTrafficSecret,
|
||||
serverTrafficSecret: serverTrafficSecret,
|
||||
exporterSecret: exporterSecret,
|
||||
}
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
152
vendor/github.com/bifurcation/mint/common.go
generated
vendored
152
vendor/github.com/bifurcation/mint/common.go
generated
vendored
|
@ -1,152 +0,0 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var (
|
||||
supportedVersion uint16 = 0x7f15 // draft-21
|
||||
|
||||
// Flags for some minor compat issues
|
||||
allowWrongVersionNumber = true
|
||||
allowPKCS1 = true
|
||||
)
|
||||
|
||||
// enum {...} ContentType;
|
||||
type RecordType byte
|
||||
|
||||
const (
|
||||
RecordTypeAlert RecordType = 21
|
||||
RecordTypeHandshake RecordType = 22
|
||||
RecordTypeApplicationData RecordType = 23
|
||||
)
|
||||
|
||||
// enum {...} HandshakeType;
|
||||
type HandshakeType byte
|
||||
|
||||
const (
|
||||
// Omitted: *_RESERVED
|
||||
HandshakeTypeClientHello HandshakeType = 1
|
||||
HandshakeTypeServerHello HandshakeType = 2
|
||||
HandshakeTypeNewSessionTicket HandshakeType = 4
|
||||
HandshakeTypeEndOfEarlyData HandshakeType = 5
|
||||
HandshakeTypeHelloRetryRequest HandshakeType = 6
|
||||
HandshakeTypeEncryptedExtensions HandshakeType = 8
|
||||
HandshakeTypeCertificate HandshakeType = 11
|
||||
HandshakeTypeCertificateRequest HandshakeType = 13
|
||||
HandshakeTypeCertificateVerify HandshakeType = 15
|
||||
HandshakeTypeServerConfiguration HandshakeType = 17
|
||||
HandshakeTypeFinished HandshakeType = 20
|
||||
HandshakeTypeKeyUpdate HandshakeType = 24
|
||||
HandshakeTypeMessageHash HandshakeType = 254
|
||||
)
|
||||
|
||||
// uint8 CipherSuite[2];
|
||||
type CipherSuite uint16
|
||||
|
||||
const (
|
||||
// XXX: Actually TLS_NULL_WITH_NULL_NULL, but we need a way to label the zero
|
||||
// value for this type so that we can detect when a field is set.
|
||||
CIPHER_SUITE_UNKNOWN CipherSuite = 0x0000
|
||||
TLS_AES_128_GCM_SHA256 CipherSuite = 0x1301
|
||||
TLS_AES_256_GCM_SHA384 CipherSuite = 0x1302
|
||||
TLS_CHACHA20_POLY1305_SHA256 CipherSuite = 0x1303
|
||||
TLS_AES_128_CCM_SHA256 CipherSuite = 0x1304
|
||||
TLS_AES_256_CCM_8_SHA256 CipherSuite = 0x1305
|
||||
)
|
||||
|
||||
func (c CipherSuite) String() string {
|
||||
switch c {
|
||||
case CIPHER_SUITE_UNKNOWN:
|
||||
return "unknown"
|
||||
case TLS_AES_128_GCM_SHA256:
|
||||
return "TLS_AES_128_GCM_SHA256"
|
||||
case TLS_AES_256_GCM_SHA384:
|
||||
return "TLS_AES_256_GCM_SHA384"
|
||||
case TLS_CHACHA20_POLY1305_SHA256:
|
||||
return "TLS_CHACHA20_POLY1305_SHA256"
|
||||
case TLS_AES_128_CCM_SHA256:
|
||||
return "TLS_AES_128_CCM_SHA256"
|
||||
case TLS_AES_256_CCM_8_SHA256:
|
||||
return "TLS_AES_256_CCM_8_SHA256"
|
||||
}
|
||||
// cannot use %x here, since it calls String(), leading to infinite recursion
|
||||
return fmt.Sprintf("invalid CipherSuite value: 0x%s", strconv.FormatUint(uint64(c), 16))
|
||||
}
|
||||
|
||||
// enum {...} SignatureScheme
|
||||
type SignatureScheme uint16
|
||||
|
||||
const (
|
||||
// RSASSA-PKCS1-v1_5 algorithms
|
||||
RSA_PKCS1_SHA1 SignatureScheme = 0x0201
|
||||
RSA_PKCS1_SHA256 SignatureScheme = 0x0401
|
||||
RSA_PKCS1_SHA384 SignatureScheme = 0x0501
|
||||
RSA_PKCS1_SHA512 SignatureScheme = 0x0601
|
||||
// ECDSA algorithms
|
||||
ECDSA_P256_SHA256 SignatureScheme = 0x0403
|
||||
ECDSA_P384_SHA384 SignatureScheme = 0x0503
|
||||
ECDSA_P521_SHA512 SignatureScheme = 0x0603
|
||||
// RSASSA-PSS algorithms
|
||||
RSA_PSS_SHA256 SignatureScheme = 0x0804
|
||||
RSA_PSS_SHA384 SignatureScheme = 0x0805
|
||||
RSA_PSS_SHA512 SignatureScheme = 0x0806
|
||||
// EdDSA algorithms
|
||||
Ed25519 SignatureScheme = 0x0807
|
||||
Ed448 SignatureScheme = 0x0808
|
||||
)
|
||||
|
||||
// enum {...} ExtensionType
|
||||
type ExtensionType uint16
|
||||
|
||||
const (
|
||||
ExtensionTypeServerName ExtensionType = 0
|
||||
ExtensionTypeSupportedGroups ExtensionType = 10
|
||||
ExtensionTypeSignatureAlgorithms ExtensionType = 13
|
||||
ExtensionTypeALPN ExtensionType = 16
|
||||
ExtensionTypeKeyShare ExtensionType = 40
|
||||
ExtensionTypePreSharedKey ExtensionType = 41
|
||||
ExtensionTypeEarlyData ExtensionType = 42
|
||||
ExtensionTypeSupportedVersions ExtensionType = 43
|
||||
ExtensionTypeCookie ExtensionType = 44
|
||||
ExtensionTypePSKKeyExchangeModes ExtensionType = 45
|
||||
ExtensionTypeTicketEarlyDataInfo ExtensionType = 46
|
||||
)
|
||||
|
||||
// enum {...} NamedGroup
|
||||
type NamedGroup uint16
|
||||
|
||||
const (
|
||||
// Elliptic Curve Groups.
|
||||
P256 NamedGroup = 23
|
||||
P384 NamedGroup = 24
|
||||
P521 NamedGroup = 25
|
||||
// ECDH functions.
|
||||
X25519 NamedGroup = 29
|
||||
X448 NamedGroup = 30
|
||||
// Finite field groups.
|
||||
FFDHE2048 NamedGroup = 256
|
||||
FFDHE3072 NamedGroup = 257
|
||||
FFDHE4096 NamedGroup = 258
|
||||
FFDHE6144 NamedGroup = 259
|
||||
FFDHE8192 NamedGroup = 260
|
||||
)
|
||||
|
||||
// enum {...} PskKeyExchangeMode;
|
||||
type PSKKeyExchangeMode uint8
|
||||
|
||||
const (
|
||||
PSKModeKE PSKKeyExchangeMode = 0
|
||||
PSKModeDHEKE PSKKeyExchangeMode = 1
|
||||
)
|
||||
|
||||
// enum {
|
||||
// update_not_requested(0), update_requested(1), (255)
|
||||
// } KeyUpdateRequest;
|
||||
type KeyUpdateRequest uint8
|
||||
|
||||
const (
|
||||
KeyUpdateNotRequested KeyUpdateRequest = 0
|
||||
KeyUpdateRequested KeyUpdateRequest = 1
|
||||
)
|
819
vendor/github.com/bifurcation/mint/conn.go
generated
vendored
819
vendor/github.com/bifurcation/mint/conn.go
generated
vendored
|
@ -1,819 +0,0 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var WouldBlock = fmt.Errorf("Would have blocked")
|
||||
|
||||
type Certificate struct {
|
||||
Chain []*x509.Certificate
|
||||
PrivateKey crypto.Signer
|
||||
}
|
||||
|
||||
type PreSharedKey struct {
|
||||
CipherSuite CipherSuite
|
||||
IsResumption bool
|
||||
Identity []byte
|
||||
Key []byte
|
||||
NextProto string
|
||||
ReceivedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
TicketAgeAdd uint32
|
||||
}
|
||||
|
||||
type PreSharedKeyCache interface {
|
||||
Get(string) (PreSharedKey, bool)
|
||||
Put(string, PreSharedKey)
|
||||
Size() int
|
||||
}
|
||||
|
||||
type PSKMapCache map[string]PreSharedKey
|
||||
|
||||
// A CookieHandler does two things:
|
||||
// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
|
||||
// - validates this byte string echoed by the client in the ClientHello
|
||||
type CookieHandler interface {
|
||||
Generate(*Conn) ([]byte, error)
|
||||
Validate(*Conn, []byte) bool
|
||||
}
|
||||
|
||||
func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) {
|
||||
psk, ok = cache[key]
|
||||
return
|
||||
}
|
||||
|
||||
func (cache *PSKMapCache) Put(key string, psk PreSharedKey) {
|
||||
(*cache)[key] = psk
|
||||
}
|
||||
|
||||
func (cache PSKMapCache) Size() int {
|
||||
return len(cache)
|
||||
}
|
||||
|
||||
// Config is the struct used to pass configuration settings to a TLS client or
|
||||
// server instance. The settings for client and server are pretty different,
|
||||
// but we just throw them all in here.
|
||||
type Config struct {
|
||||
// Client fields
|
||||
ServerName string
|
||||
|
||||
// Server fields
|
||||
SendSessionTickets bool
|
||||
TicketLifetime uint32
|
||||
TicketLen int
|
||||
EarlyDataLifetime uint32
|
||||
AllowEarlyData bool
|
||||
// Require the client to echo a cookie.
|
||||
RequireCookie bool
|
||||
// If cookies are required and no CookieHandler is set, a default cookie handler is used.
|
||||
// The default cookie handler uses 32 random bytes as a cookie.
|
||||
CookieHandler CookieHandler
|
||||
RequireClientAuth bool
|
||||
|
||||
// Shared fields
|
||||
Certificates []*Certificate
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
CipherSuites []CipherSuite
|
||||
Groups []NamedGroup
|
||||
SignatureSchemes []SignatureScheme
|
||||
NextProtos []string
|
||||
PSKs PreSharedKeyCache
|
||||
PSKModes []PSKKeyExchangeMode
|
||||
NonBlocking bool
|
||||
|
||||
// The same config object can be shared among different connections, so it
|
||||
// needs its own mutex
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// Clone returns a shallow clone of c. It is safe to clone a Config that is
|
||||
// being used concurrently by a TLS client or server.
|
||||
func (c *Config) Clone() *Config {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
return &Config{
|
||||
ServerName: c.ServerName,
|
||||
|
||||
SendSessionTickets: c.SendSessionTickets,
|
||||
TicketLifetime: c.TicketLifetime,
|
||||
TicketLen: c.TicketLen,
|
||||
EarlyDataLifetime: c.EarlyDataLifetime,
|
||||
AllowEarlyData: c.AllowEarlyData,
|
||||
RequireCookie: c.RequireCookie,
|
||||
RequireClientAuth: c.RequireClientAuth,
|
||||
|
||||
Certificates: c.Certificates,
|
||||
AuthCertificate: c.AuthCertificate,
|
||||
CipherSuites: c.CipherSuites,
|
||||
Groups: c.Groups,
|
||||
SignatureSchemes: c.SignatureSchemes,
|
||||
NextProtos: c.NextProtos,
|
||||
PSKs: c.PSKs,
|
||||
PSKModes: c.PSKModes,
|
||||
NonBlocking: c.NonBlocking,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) Init(isClient bool) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Set defaults
|
||||
if len(c.CipherSuites) == 0 {
|
||||
c.CipherSuites = defaultSupportedCipherSuites
|
||||
}
|
||||
if len(c.Groups) == 0 {
|
||||
c.Groups = defaultSupportedGroups
|
||||
}
|
||||
if len(c.SignatureSchemes) == 0 {
|
||||
c.SignatureSchemes = defaultSignatureSchemes
|
||||
}
|
||||
if c.TicketLen == 0 {
|
||||
c.TicketLen = defaultTicketLen
|
||||
}
|
||||
if !reflect.ValueOf(c.PSKs).IsValid() {
|
||||
c.PSKs = &PSKMapCache{}
|
||||
}
|
||||
if len(c.PSKModes) == 0 {
|
||||
c.PSKModes = defaultPSKModes
|
||||
}
|
||||
|
||||
// If there is no certificate, generate one
|
||||
if !isClient && len(c.Certificates) == 0 {
|
||||
logf(logTypeHandshake, "Generating key name=%v", c.ServerName)
|
||||
priv, err := newSigningKey(RSA_PSS_SHA256)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Certificates = []*Certificate{
|
||||
{
|
||||
Chain: []*x509.Certificate{cert},
|
||||
PrivateKey: priv,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) ValidForServer() bool {
|
||||
return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) ||
|
||||
(len(c.Certificates) > 0 &&
|
||||
len(c.Certificates[0].Chain) > 0 &&
|
||||
c.Certificates[0].PrivateKey != nil)
|
||||
}
|
||||
|
||||
func (c *Config) ValidForClient() bool {
|
||||
return len(c.ServerName) > 0
|
||||
}
|
||||
|
||||
var (
|
||||
defaultSupportedCipherSuites = []CipherSuite{
|
||||
TLS_AES_128_GCM_SHA256,
|
||||
TLS_AES_256_GCM_SHA384,
|
||||
}
|
||||
|
||||
defaultSupportedGroups = []NamedGroup{
|
||||
P256,
|
||||
P384,
|
||||
FFDHE2048,
|
||||
X25519,
|
||||
}
|
||||
|
||||
defaultSignatureSchemes = []SignatureScheme{
|
||||
RSA_PSS_SHA256,
|
||||
RSA_PSS_SHA384,
|
||||
RSA_PSS_SHA512,
|
||||
ECDSA_P256_SHA256,
|
||||
ECDSA_P384_SHA384,
|
||||
ECDSA_P521_SHA512,
|
||||
}
|
||||
|
||||
defaultTicketLen = 16
|
||||
|
||||
defaultPSKModes = []PSKKeyExchangeMode{
|
||||
PSKModeKE,
|
||||
PSKModeDHEKE,
|
||||
}
|
||||
)
|
||||
|
||||
type ConnectionState struct {
|
||||
HandshakeState string // string representation of the handshake state.
|
||||
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
|
||||
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement
|
||||
NextProto string // Selected ALPN proto
|
||||
}
|
||||
|
||||
// Conn implements the net.Conn interface, as with "crypto/tls"
|
||||
// * Read, Write, and Close are provided locally
|
||||
// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn
|
||||
type Conn struct {
|
||||
config *Config
|
||||
conn net.Conn
|
||||
isClient bool
|
||||
|
||||
EarlyData []byte
|
||||
|
||||
state StateConnected
|
||||
hState HandshakeState
|
||||
handshakeMutex sync.Mutex
|
||||
handshakeAlert Alert
|
||||
handshakeComplete bool
|
||||
|
||||
readBuffer []byte
|
||||
in, out *RecordLayer
|
||||
hIn, hOut *HandshakeLayer
|
||||
|
||||
extHandler AppExtensionHandler
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
|
||||
c := &Conn{conn: conn, config: config, isClient: isClient}
|
||||
c.in = NewRecordLayer(c.conn)
|
||||
c.out = NewRecordLayer(c.conn)
|
||||
c.hIn = NewHandshakeLayer(c.in)
|
||||
c.hIn.nonblocking = c.config.NonBlocking
|
||||
c.hOut = NewHandshakeLayer(c.out)
|
||||
return c
|
||||
}
|
||||
|
||||
// Read up
|
||||
func (c *Conn) consumeRecord() error {
|
||||
pt, err := c.in.ReadRecord()
|
||||
if pt == nil {
|
||||
logf(logTypeIO, "extendBuffer returns error %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
switch pt.contentType {
|
||||
case RecordTypeHandshake:
|
||||
logf(logTypeHandshake, "Received post-handshake message")
|
||||
// We do not support fragmentation of post-handshake handshake messages.
|
||||
// TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
|
||||
start := 0
|
||||
for start < len(pt.fragment) {
|
||||
if len(pt.fragment[start:]) < handshakeHeaderLen {
|
||||
return fmt.Errorf("Post-handshake handshake message too short for header")
|
||||
}
|
||||
|
||||
hm := &HandshakeMessage{}
|
||||
hm.msgType = HandshakeType(pt.fragment[start])
|
||||
hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3])
|
||||
|
||||
if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen {
|
||||
return fmt.Errorf("Post-handshake handshake message too short for body")
|
||||
}
|
||||
hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen]
|
||||
|
||||
// Advance state machine
|
||||
state, actions, alert := c.state.Next(hm)
|
||||
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
for _, action := range actions {
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
// XXX: If we want to support more advanced cases, e.g., post-handshake
|
||||
// authentication, we'll need to allow transitions other than
|
||||
// Connected -> Connected
|
||||
var connected bool
|
||||
c.state, connected = state.(StateConnected)
|
||||
if !connected {
|
||||
logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
start += handshakeHeaderLen + hmLen
|
||||
}
|
||||
case RecordTypeAlert:
|
||||
logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer)
|
||||
if len(pt.fragment) != 2 {
|
||||
c.sendAlert(AlertUnexpectedMessage)
|
||||
return io.EOF
|
||||
}
|
||||
if Alert(pt.fragment[1]) == AlertCloseNotify {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
switch pt.fragment[0] {
|
||||
case AlertLevelWarning:
|
||||
// drop on the floor
|
||||
case AlertLevelError:
|
||||
return Alert(pt.fragment[1])
|
||||
default:
|
||||
c.sendAlert(AlertUnexpectedMessage)
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
case RecordTypeApplicationData:
|
||||
c.readBuffer = append(c.readBuffer, pt.fragment...)
|
||||
logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Read application data up to the size of buffer. Handshake and alert records
|
||||
// are consumed by the Conn object directly.
|
||||
func (c *Conn) Read(buffer []byte) (int, error) {
|
||||
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
|
||||
if alert := c.Handshake(); alert != AlertNoAlert {
|
||||
return 0, alert
|
||||
}
|
||||
|
||||
if len(buffer) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Lock the input channel
|
||||
c.in.Lock()
|
||||
defer c.in.Unlock()
|
||||
for len(c.readBuffer) == 0 {
|
||||
err := c.consumeRecord()
|
||||
|
||||
// err can be nil if consumeRecord processed a non app-data
|
||||
// record.
|
||||
if err != nil {
|
||||
if c.config.NonBlocking || err != WouldBlock {
|
||||
logf(logTypeIO, "conn.Read returns err=%v", err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var read int
|
||||
n := len(buffer)
|
||||
logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer))
|
||||
if len(c.readBuffer) <= n {
|
||||
buffer = buffer[:len(c.readBuffer)]
|
||||
copy(buffer, c.readBuffer)
|
||||
read = len(c.readBuffer)
|
||||
c.readBuffer = c.readBuffer[:0]
|
||||
} else {
|
||||
logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n)
|
||||
copy(buffer[:n], c.readBuffer[:n])
|
||||
c.readBuffer = c.readBuffer[n:]
|
||||
read = n
|
||||
}
|
||||
|
||||
logf(logTypeVerbose, "Returning %v", string(buffer))
|
||||
return read, nil
|
||||
}
|
||||
|
||||
// Write application data
|
||||
func (c *Conn) Write(buffer []byte) (int, error) {
|
||||
// Lock the output channel
|
||||
c.out.Lock()
|
||||
defer c.out.Unlock()
|
||||
|
||||
// Send full-size fragments
|
||||
var start int
|
||||
sent := 0
|
||||
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
|
||||
err := c.out.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeApplicationData,
|
||||
fragment: buffer[start : start+maxFragmentLen],
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return sent, err
|
||||
}
|
||||
sent += maxFragmentLen
|
||||
}
|
||||
|
||||
// Send a final partial fragment if necessary
|
||||
if start < len(buffer) {
|
||||
err := c.out.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeApplicationData,
|
||||
fragment: buffer[start:],
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return sent, err
|
||||
}
|
||||
sent += len(buffer[start:])
|
||||
}
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
// sendAlert sends a TLS alert message.
|
||||
// c.out.Mutex <= L.
|
||||
func (c *Conn) sendAlert(err Alert) error {
|
||||
c.handshakeMutex.Lock()
|
||||
defer c.handshakeMutex.Unlock()
|
||||
|
||||
var level int
|
||||
switch err {
|
||||
case AlertNoRenegotiation, AlertCloseNotify:
|
||||
level = AlertLevelWarning
|
||||
default:
|
||||
level = AlertLevelError
|
||||
}
|
||||
|
||||
buf := []byte{byte(err), byte(level)}
|
||||
c.out.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeAlert,
|
||||
fragment: buf,
|
||||
})
|
||||
|
||||
// close_notify and end_of_early_data are not actually errors
|
||||
if level == AlertLevelWarning {
|
||||
return &net.OpError{Op: "local error", Err: err}
|
||||
}
|
||||
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (c *Conn) Close() error {
|
||||
// XXX crypto/tls has an interlock with Write here. Do we need that?
|
||||
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// LocalAddr returns the local network address.
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote network address.
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
// SetDeadline sets the read and write deadlines associated with the connection.
|
||||
// A zero value for t means Read and Write will not time out.
|
||||
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return c.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline on the underlying connection.
|
||||
// A zero value for t means Read will not time out.
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the write deadline on the underlying connection.
|
||||
// A zero value for t means Write will not time out.
|
||||
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
|
||||
label := "[server]"
|
||||
if c.isClient {
|
||||
label = "[client]"
|
||||
}
|
||||
|
||||
switch action := actionGeneric.(type) {
|
||||
case SendHandshakeMessage:
|
||||
err := c.hOut.WriteMessage(action.Message)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
case RekeyIn:
|
||||
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet)
|
||||
err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
case RekeyOut:
|
||||
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet)
|
||||
err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
case SendEarlyData:
|
||||
logf(logTypeHandshake, "%s Sending early data...", label)
|
||||
_, err := c.Write(c.EarlyData)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error writing early data: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
case ReadPastEarlyData:
|
||||
logf(logTypeHandshake, "%s Reading past early data...", label)
|
||||
// Scan past all records that fail to decrypt
|
||||
_, err := c.in.PeekRecordType(!c.config.NonBlocking)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
_, ok := err.(DecryptError)
|
||||
|
||||
for ok {
|
||||
_, err = c.in.PeekRecordType(!c.config.NonBlocking)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
_, ok = err.(DecryptError)
|
||||
}
|
||||
|
||||
case ReadEarlyData:
|
||||
logf(logTypeHandshake, "%s Reading early data...", label)
|
||||
t, err := c.in.PeekRecordType(!c.config.NonBlocking)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
logf(logTypeHandshake, "%s Got record type(1): %v", label, t)
|
||||
|
||||
for t == RecordTypeApplicationData {
|
||||
// Read a record into the buffer. Note that this is safe
|
||||
// in blocking mode because we read the record in in
|
||||
// PeekRecordType.
|
||||
pt, err := c.in.ReadRecord()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error reading early data record: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment)
|
||||
c.EarlyData = append(c.EarlyData, pt.fragment...)
|
||||
|
||||
t, err = c.in.PeekRecordType(!c.config.NonBlocking)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
logf(logTypeHandshake, "%s Got record type (2): %v", label, t)
|
||||
}
|
||||
logf(logTypeHandshake, "%s Done reading early data", label)
|
||||
|
||||
case StorePSK:
|
||||
logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
|
||||
if c.isClient {
|
||||
// Clients look up PSKs based on server name
|
||||
c.config.PSKs.Put(c.config.ServerName, action.PSK)
|
||||
} else {
|
||||
// Servers look them up based on the identity in the extension
|
||||
c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK)
|
||||
}
|
||||
|
||||
default:
|
||||
logf(logTypeHandshake, "%s Unknown actionuction type", label)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
return AlertNoAlert
|
||||
}
|
||||
|
||||
func (c *Conn) HandshakeSetup() Alert {
|
||||
var state HandshakeState
|
||||
var actions []HandshakeAction
|
||||
var alert Alert
|
||||
|
||||
if err := c.config.Init(c.isClient); err != nil {
|
||||
logf(logTypeHandshake, "Error initializing config: %v", err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
// Set things up
|
||||
caps := Capabilities{
|
||||
CipherSuites: c.config.CipherSuites,
|
||||
Groups: c.config.Groups,
|
||||
SignatureSchemes: c.config.SignatureSchemes,
|
||||
PSKs: c.config.PSKs,
|
||||
PSKModes: c.config.PSKModes,
|
||||
AllowEarlyData: c.config.AllowEarlyData,
|
||||
RequireCookie: c.config.RequireCookie,
|
||||
CookieHandler: c.config.CookieHandler,
|
||||
RequireClientAuth: c.config.RequireClientAuth,
|
||||
NextProtos: c.config.NextProtos,
|
||||
Certificates: c.config.Certificates,
|
||||
ExtensionHandler: c.extHandler,
|
||||
}
|
||||
opts := ConnectionOptions{
|
||||
ServerName: c.config.ServerName,
|
||||
NextProtos: c.config.NextProtos,
|
||||
EarlyData: c.EarlyData,
|
||||
}
|
||||
|
||||
if caps.RequireCookie && caps.CookieHandler == nil {
|
||||
caps.CookieHandler = &defaultCookieHandler{}
|
||||
}
|
||||
|
||||
if c.isClient {
|
||||
state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error initializing client state: %v", alert)
|
||||
return alert
|
||||
}
|
||||
|
||||
for _, action := range actions {
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||
return alert
|
||||
}
|
||||
}
|
||||
} else {
|
||||
state = ServerStateStart{Caps: caps, conn: c}
|
||||
}
|
||||
|
||||
c.hState = state
|
||||
|
||||
return AlertNoAlert
|
||||
}
|
||||
|
||||
// Handshake causes a TLS handshake on the connection. The `isClient` member
|
||||
// determines whether a client or server handshake is performed. If a
|
||||
// handshake has already been performed, then its result will be returned.
|
||||
func (c *Conn) Handshake() Alert {
|
||||
label := "[server]"
|
||||
if c.isClient {
|
||||
label = "[client]"
|
||||
}
|
||||
|
||||
// TODO Lock handshakeMutex
|
||||
// TODO Remove CloseNotify hack
|
||||
if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify {
|
||||
logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert)
|
||||
return c.handshakeAlert
|
||||
}
|
||||
if c.handshakeComplete {
|
||||
return AlertNoAlert
|
||||
}
|
||||
|
||||
var alert Alert
|
||||
if c.hState == nil {
|
||||
logf(logTypeHandshake, "%s First time through handshake, setting up", label)
|
||||
alert = c.HandshakeSetup()
|
||||
if alert != AlertNoAlert {
|
||||
return alert
|
||||
}
|
||||
} else {
|
||||
logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState)
|
||||
}
|
||||
|
||||
state := c.hState
|
||||
_, connected := state.(StateConnected)
|
||||
|
||||
var actions []HandshakeAction
|
||||
|
||||
for !connected {
|
||||
// Read a handshake message
|
||||
hm, err := c.hIn.ReadMessage()
|
||||
if err == WouldBlock {
|
||||
logf(logTypeHandshake, "%s Would block reading message: %v", label, err)
|
||||
return AlertWouldBlock
|
||||
}
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error reading message: %v", label, err)
|
||||
c.sendAlert(AlertCloseNotify)
|
||||
return AlertCloseNotify
|
||||
}
|
||||
logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
|
||||
|
||||
// Advance the state machine
|
||||
state, actions, alert = state.Next(hm)
|
||||
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
||||
return alert
|
||||
}
|
||||
|
||||
for index, action := range actions {
|
||||
logf(logTypeHandshake, "%s taking next action (%d)", label, index)
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return alert
|
||||
}
|
||||
}
|
||||
|
||||
c.hState = state
|
||||
logf(logTypeHandshake, "state is now %s", c.GetHsState())
|
||||
|
||||
_, connected = state.(StateConnected)
|
||||
}
|
||||
|
||||
c.state = state.(StateConnected)
|
||||
|
||||
// Send NewSessionTicket if acting as server
|
||||
if !c.isClient && c.config.SendSessionTickets {
|
||||
actions, alert := c.state.NewSessionTicket(
|
||||
c.config.TicketLen,
|
||||
c.config.TicketLifetime,
|
||||
c.config.EarlyDataLifetime)
|
||||
|
||||
for _, action := range actions {
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return alert
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.handshakeComplete = true
|
||||
return AlertNoAlert
|
||||
}
|
||||
|
||||
func (c *Conn) SendKeyUpdate(requestUpdate bool) error {
|
||||
if !c.handshakeComplete {
|
||||
return fmt.Errorf("Cannot update keys until after handshake")
|
||||
}
|
||||
|
||||
request := KeyUpdateNotRequested
|
||||
if requestUpdate {
|
||||
request = KeyUpdateRequested
|
||||
}
|
||||
|
||||
// Create the key update and update state
|
||||
actions, alert := c.state.KeyUpdate(request)
|
||||
if alert != AlertNoAlert {
|
||||
c.sendAlert(alert)
|
||||
return fmt.Errorf("Alert while generating key update: %v", alert)
|
||||
}
|
||||
|
||||
// Take actions (send key update and rekey)
|
||||
for _, action := range actions {
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
c.sendAlert(alert)
|
||||
return fmt.Errorf("Alert during key update actions: %v", alert)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) GetHsState() string {
|
||||
return reflect.TypeOf(c.hState).Name()
|
||||
}
|
||||
|
||||
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||
_, connected := c.hState.(StateConnected)
|
||||
if !connected {
|
||||
return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
|
||||
}
|
||||
|
||||
if c.state.exporterSecret == nil {
|
||||
return nil, fmt.Errorf("Internal error: no exporter secret")
|
||||
}
|
||||
|
||||
h0 := c.state.cryptoParams.Hash.New().Sum(nil)
|
||||
tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0)
|
||||
|
||||
hc := c.state.cryptoParams.Hash.New().Sum(context)
|
||||
return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil
|
||||
}
|
||||
|
||||
func (c *Conn) State() ConnectionState {
|
||||
state := ConnectionState{
|
||||
HandshakeState: c.GetHsState(),
|
||||
}
|
||||
|
||||
if c.handshakeComplete {
|
||||
state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
|
||||
state.NextProto = c.state.Params.NextProto
|
||||
}
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error {
|
||||
if c.hState != nil {
|
||||
return fmt.Errorf("Can't set extension handler after setup")
|
||||
}
|
||||
|
||||
c.extHandler = h
|
||||
return nil
|
||||
}
|
654
vendor/github.com/bifurcation/mint/crypto.go
generated
vendored
654
vendor/github.com/bifurcation/mint/crypto.go
generated
vendored
|
@ -1,654 +0,0 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
|
||||
// Blank includes to ensure hash support
|
||||
_ "crypto/sha1"
|
||||
_ "crypto/sha256"
|
||||
_ "crypto/sha512"
|
||||
)
|
||||
|
||||
var prng = rand.Reader
|
||||
|
||||
type aeadFactory func(key []byte) (cipher.AEAD, error)
|
||||
|
||||
type CipherSuiteParams struct {
|
||||
Suite CipherSuite
|
||||
Cipher aeadFactory // Cipher factory
|
||||
Hash crypto.Hash // Hash function
|
||||
KeyLen int // Key length in octets
|
||||
IvLen int // IV length in octets
|
||||
}
|
||||
|
||||
type signatureAlgorithm uint8
|
||||
|
||||
const (
|
||||
signatureAlgorithmUnknown = iota
|
||||
signatureAlgorithmRSA_PKCS1
|
||||
signatureAlgorithmRSA_PSS
|
||||
signatureAlgorithmECDSA
|
||||
)
|
||||
|
||||
var (
|
||||
hashMap = map[SignatureScheme]crypto.Hash{
|
||||
RSA_PKCS1_SHA1: crypto.SHA1,
|
||||
RSA_PKCS1_SHA256: crypto.SHA256,
|
||||
RSA_PKCS1_SHA384: crypto.SHA384,
|
||||
RSA_PKCS1_SHA512: crypto.SHA512,
|
||||
ECDSA_P256_SHA256: crypto.SHA256,
|
||||
ECDSA_P384_SHA384: crypto.SHA384,
|
||||
ECDSA_P521_SHA512: crypto.SHA512,
|
||||
RSA_PSS_SHA256: crypto.SHA256,
|
||||
RSA_PSS_SHA384: crypto.SHA384,
|
||||
RSA_PSS_SHA512: crypto.SHA512,
|
||||
}
|
||||
|
||||
sigMap = map[SignatureScheme]signatureAlgorithm{
|
||||
RSA_PKCS1_SHA1: signatureAlgorithmRSA_PKCS1,
|
||||
RSA_PKCS1_SHA256: signatureAlgorithmRSA_PKCS1,
|
||||
RSA_PKCS1_SHA384: signatureAlgorithmRSA_PKCS1,
|
||||
RSA_PKCS1_SHA512: signatureAlgorithmRSA_PKCS1,
|
||||
ECDSA_P256_SHA256: signatureAlgorithmECDSA,
|
||||
ECDSA_P384_SHA384: signatureAlgorithmECDSA,
|
||||
ECDSA_P521_SHA512: signatureAlgorithmECDSA,
|
||||
RSA_PSS_SHA256: signatureAlgorithmRSA_PSS,
|
||||
RSA_PSS_SHA384: signatureAlgorithmRSA_PSS,
|
||||
RSA_PSS_SHA512: signatureAlgorithmRSA_PSS,
|
||||
}
|
||||
|
||||
curveMap = map[SignatureScheme]NamedGroup{
|
||||
ECDSA_P256_SHA256: P256,
|
||||
ECDSA_P384_SHA384: P384,
|
||||
ECDSA_P521_SHA512: P521,
|
||||
}
|
||||
|
||||
newAESGCM = func(key []byte) (cipher.AEAD, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TLS always uses 12-byte nonces
|
||||
return cipher.NewGCMWithNonceSize(block, 12)
|
||||
}
|
||||
|
||||
cipherSuiteMap = map[CipherSuite]CipherSuiteParams{
|
||||
TLS_AES_128_GCM_SHA256: {
|
||||
Suite: TLS_AES_128_GCM_SHA256,
|
||||
Cipher: newAESGCM,
|
||||
Hash: crypto.SHA256,
|
||||
KeyLen: 16,
|
||||
IvLen: 12,
|
||||
},
|
||||
TLS_AES_256_GCM_SHA384: {
|
||||
Suite: TLS_AES_256_GCM_SHA384,
|
||||
Cipher: newAESGCM,
|
||||
Hash: crypto.SHA384,
|
||||
KeyLen: 32,
|
||||
IvLen: 12,
|
||||
},
|
||||
}
|
||||
|
||||
x509AlgMap = map[SignatureScheme]x509.SignatureAlgorithm{
|
||||
RSA_PKCS1_SHA1: x509.SHA1WithRSA,
|
||||
RSA_PKCS1_SHA256: x509.SHA256WithRSA,
|
||||
RSA_PKCS1_SHA384: x509.SHA384WithRSA,
|
||||
RSA_PKCS1_SHA512: x509.SHA512WithRSA,
|
||||
ECDSA_P256_SHA256: x509.ECDSAWithSHA256,
|
||||
ECDSA_P384_SHA384: x509.ECDSAWithSHA384,
|
||||
ECDSA_P521_SHA512: x509.ECDSAWithSHA512,
|
||||
}
|
||||
|
||||
defaultRSAKeySize = 2048
|
||||
)
|
||||
|
||||
func curveFromNamedGroup(group NamedGroup) (crv elliptic.Curve) {
|
||||
switch group {
|
||||
case P256:
|
||||
crv = elliptic.P256()
|
||||
case P384:
|
||||
crv = elliptic.P384()
|
||||
case P521:
|
||||
crv = elliptic.P521()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func namedGroupFromECDSAKey(key *ecdsa.PublicKey) (g NamedGroup) {
|
||||
switch key.Curve.Params().Name {
|
||||
case elliptic.P256().Params().Name:
|
||||
g = P256
|
||||
case elliptic.P384().Params().Name:
|
||||
g = P384
|
||||
case elliptic.P521().Params().Name:
|
||||
g = P521
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func keyExchangeSizeFromNamedGroup(group NamedGroup) (size int) {
|
||||
size = 0
|
||||
switch group {
|
||||
case X25519:
|
||||
size = 32
|
||||
case P256:
|
||||
size = 65
|
||||
case P384:
|
||||
size = 97
|
||||
case P521:
|
||||
size = 133
|
||||
case FFDHE2048:
|
||||
size = 256
|
||||
case FFDHE3072:
|
||||
size = 384
|
||||
case FFDHE4096:
|
||||
size = 512
|
||||
case FFDHE6144:
|
||||
size = 768
|
||||
case FFDHE8192:
|
||||
size = 1024
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func primeFromNamedGroup(group NamedGroup) (p *big.Int) {
|
||||
switch group {
|
||||
case FFDHE2048:
|
||||
p = finiteFieldPrime2048
|
||||
case FFDHE3072:
|
||||
p = finiteFieldPrime3072
|
||||
case FFDHE4096:
|
||||
p = finiteFieldPrime4096
|
||||
case FFDHE6144:
|
||||
p = finiteFieldPrime6144
|
||||
case FFDHE8192:
|
||||
p = finiteFieldPrime8192
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func schemeValidForKey(alg SignatureScheme, key crypto.Signer) bool {
|
||||
sigType := sigMap[alg]
|
||||
switch key.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return sigType == signatureAlgorithmRSA_PKCS1 || sigType == signatureAlgorithmRSA_PSS
|
||||
case *ecdsa.PrivateKey:
|
||||
return sigType == signatureAlgorithmECDSA
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func ffdheKeyShareFromPrime(p *big.Int) (priv, pub *big.Int, err error) {
|
||||
primeLen := len(p.Bytes())
|
||||
for {
|
||||
// g = 2 for all ffdhe groups
|
||||
priv, err = rand.Int(prng, p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pub = big.NewInt(0)
|
||||
pub.Exp(big.NewInt(2), priv, p)
|
||||
|
||||
if len(pub.Bytes()) == primeLen {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newKeyShare(group NamedGroup) (pub []byte, priv []byte, err error) {
|
||||
switch group {
|
||||
case P256, P384, P521:
|
||||
var x, y *big.Int
|
||||
crv := curveFromNamedGroup(group)
|
||||
priv, x, y, err = elliptic.GenerateKey(crv, prng)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pub = elliptic.Marshal(crv, x, y)
|
||||
return
|
||||
|
||||
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
|
||||
p := primeFromNamedGroup(group)
|
||||
x, X, err2 := ffdheKeyShareFromPrime(p)
|
||||
if err2 != nil {
|
||||
err = err2
|
||||
return
|
||||
}
|
||||
|
||||
priv = x.Bytes()
|
||||
pubBytes := X.Bytes()
|
||||
|
||||
numBytes := keyExchangeSizeFromNamedGroup(group)
|
||||
|
||||
pub = make([]byte, numBytes)
|
||||
copy(pub[numBytes-len(pubBytes):], pubBytes)
|
||||
|
||||
return
|
||||
|
||||
case X25519:
|
||||
var private, public [32]byte
|
||||
_, err = prng.Read(private[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
curve25519.ScalarBaseMult(&public, &private)
|
||||
priv = private[:]
|
||||
pub = public[:]
|
||||
return
|
||||
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("tls.newkeyshare: Unsupported group %v", group)
|
||||
}
|
||||
}
|
||||
|
||||
func keyAgreement(group NamedGroup, pub []byte, priv []byte) ([]byte, error) {
|
||||
switch group {
|
||||
case P256, P384, P521:
|
||||
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
|
||||
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
|
||||
}
|
||||
|
||||
crv := curveFromNamedGroup(group)
|
||||
pubX, pubY := elliptic.Unmarshal(crv, pub)
|
||||
x, _ := crv.Params().ScalarMult(pubX, pubY, priv)
|
||||
xBytes := x.Bytes()
|
||||
|
||||
numBytes := len(crv.Params().P.Bytes())
|
||||
|
||||
ret := make([]byte, numBytes)
|
||||
copy(ret[numBytes-len(xBytes):], xBytes)
|
||||
|
||||
return ret, nil
|
||||
|
||||
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
|
||||
numBytes := keyExchangeSizeFromNamedGroup(group)
|
||||
if len(pub) != numBytes {
|
||||
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
|
||||
}
|
||||
p := primeFromNamedGroup(group)
|
||||
x := big.NewInt(0).SetBytes(priv)
|
||||
Y := big.NewInt(0).SetBytes(pub)
|
||||
ZBytes := big.NewInt(0).Exp(Y, x, p).Bytes()
|
||||
|
||||
ret := make([]byte, numBytes)
|
||||
copy(ret[numBytes-len(ZBytes):], ZBytes)
|
||||
|
||||
return ret, nil
|
||||
|
||||
case X25519:
|
||||
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
|
||||
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
|
||||
}
|
||||
|
||||
var private, public, ret [32]byte
|
||||
copy(private[:], priv)
|
||||
copy(public[:], pub)
|
||||
curve25519.ScalarMult(&ret, &private, &public)
|
||||
|
||||
return ret[:], nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.keyagreement: Unsupported group %v", group)
|
||||
}
|
||||
}
|
||||
|
||||
func newSigningKey(sig SignatureScheme) (crypto.Signer, error) {
|
||||
switch sig {
|
||||
case RSA_PKCS1_SHA1, RSA_PKCS1_SHA256,
|
||||
RSA_PKCS1_SHA384, RSA_PKCS1_SHA512,
|
||||
RSA_PSS_SHA256, RSA_PSS_SHA384,
|
||||
RSA_PSS_SHA512:
|
||||
return rsa.GenerateKey(prng, defaultRSAKeySize)
|
||||
case ECDSA_P256_SHA256:
|
||||
return ecdsa.GenerateKey(elliptic.P256(), prng)
|
||||
case ECDSA_P384_SHA384:
|
||||
return ecdsa.GenerateKey(elliptic.P384(), prng)
|
||||
case ECDSA_P521_SHA512:
|
||||
return ecdsa.GenerateKey(elliptic.P521(), prng)
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.newsigningkey: Unsupported signature algorithm [%04x]", sig)
|
||||
}
|
||||
}
|
||||
|
||||
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
|
||||
sigAlg, ok := x509AlgMap[alg]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
|
||||
}
|
||||
if len(name) == 0 {
|
||||
return nil, fmt.Errorf("tls.selfsigned: No name provided")
|
||||
}
|
||||
|
||||
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(0, 0, 1),
|
||||
SignatureAlgorithm: sigAlg,
|
||||
Subject: pkix.Name{CommonName: name},
|
||||
DNSNames: []string{name},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// It is safe to ignore the error here because we're parsing known-good data
|
||||
cert, _ := x509.ParseCertificate(der)
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// XXX(rlb): Copied from crypto/x509
|
||||
type ecdsaSignature struct {
|
||||
R, S *big.Int
|
||||
}
|
||||
|
||||
func sign(alg SignatureScheme, privateKey crypto.Signer, sigInput []byte) ([]byte, error) {
|
||||
var opts crypto.SignerOpts
|
||||
|
||||
hash := hashMap[alg]
|
||||
if hash == crypto.SHA1 {
|
||||
return nil, fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
|
||||
}
|
||||
|
||||
sigType := sigMap[alg]
|
||||
var realInput []byte
|
||||
switch key := privateKey.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
switch {
|
||||
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||
logf(logTypeCrypto, "signing with PKCS1, hashSize=[%d]", hash.Size())
|
||||
opts = hash
|
||||
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||
fallthrough
|
||||
case sigType == signatureAlgorithmRSA_PSS:
|
||||
logf(logTypeCrypto, "signing with PSS, hashSize=[%d]", hash.Size())
|
||||
opts = &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for RSA key")
|
||||
}
|
||||
|
||||
h := hash.New()
|
||||
h.Write(sigInput)
|
||||
realInput = h.Sum(nil)
|
||||
case *ecdsa.PrivateKey:
|
||||
if sigType != signatureAlgorithmECDSA {
|
||||
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for ECDSA key")
|
||||
}
|
||||
|
||||
algGroup := curveMap[alg]
|
||||
keyGroup := namedGroupFromECDSAKey(key.Public().(*ecdsa.PublicKey))
|
||||
if algGroup != keyGroup {
|
||||
return nil, fmt.Errorf("tls.crypto.sign: Unsupported hash/curve combination")
|
||||
}
|
||||
|
||||
h := hash.New()
|
||||
h.Write(sigInput)
|
||||
realInput = h.Sum(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.crypto.sign: Unsupported private key type")
|
||||
}
|
||||
|
||||
sig, err := privateKey.Sign(prng, realInput, opts)
|
||||
logf(logTypeCrypto, "signature: %x", sig)
|
||||
return sig, err
|
||||
}
|
||||
|
||||
func verify(alg SignatureScheme, publicKey crypto.PublicKey, sigInput []byte, sig []byte) error {
|
||||
hash := hashMap[alg]
|
||||
|
||||
if hash == crypto.SHA1 {
|
||||
return fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
|
||||
}
|
||||
|
||||
sigType := sigMap[alg]
|
||||
switch pub := publicKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
switch {
|
||||
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||
logf(logTypeCrypto, "verifying with PKCS1, hashSize=[%d]", hash.Size())
|
||||
|
||||
h := hash.New()
|
||||
h.Write(sigInput)
|
||||
realInput := h.Sum(nil)
|
||||
return rsa.VerifyPKCS1v15(pub, hash, realInput, sig)
|
||||
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||
fallthrough
|
||||
case sigType == signatureAlgorithmRSA_PSS:
|
||||
logf(logTypeCrypto, "verifying with PSS, hashSize=[%d]", hash.Size())
|
||||
opts := &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
|
||||
|
||||
h := hash.New()
|
||||
h.Write(sigInput)
|
||||
realInput := h.Sum(nil)
|
||||
return rsa.VerifyPSS(pub, hash, realInput, sig, opts)
|
||||
default:
|
||||
return fmt.Errorf("tls.verify: Unsupported algorithm for RSA key")
|
||||
}
|
||||
|
||||
case *ecdsa.PublicKey:
|
||||
if sigType != signatureAlgorithmECDSA {
|
||||
return fmt.Errorf("tls.verify: Unsupported algorithm for ECDSA key")
|
||||
}
|
||||
|
||||
if curveMap[alg] != namedGroupFromECDSAKey(pub) {
|
||||
return fmt.Errorf("tls.verify: Unsupported curve for ECDSA key")
|
||||
}
|
||||
|
||||
ecdsaSig := new(ecdsaSignature)
|
||||
if rest, err := asn1.Unmarshal(sig, ecdsaSig); err != nil {
|
||||
return err
|
||||
} else if len(rest) != 0 {
|
||||
return fmt.Errorf("tls.verify: trailing data after ECDSA signature")
|
||||
}
|
||||
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
|
||||
return fmt.Errorf("tls.verify: ECDSA signature contained zero or negative values")
|
||||
}
|
||||
|
||||
h := hash.New()
|
||||
h.Write(sigInput)
|
||||
realInput := h.Sum(nil)
|
||||
if !ecdsa.Verify(pub, realInput, ecdsaSig.R, ecdsaSig.S) {
|
||||
return fmt.Errorf("tls.verify: ECDSA verification failure")
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("tls.verify: Unsupported key type")
|
||||
}
|
||||
}
|
||||
|
||||
// 0
|
||||
// |
|
||||
// v
|
||||
// PSK -> HKDF-Extract = Early Secret
|
||||
// |
|
||||
// +-----> Derive-Secret(.,
|
||||
// | "ext binder" |
|
||||
// | "res binder",
|
||||
// | "")
|
||||
// | = binder_key
|
||||
// |
|
||||
// +-----> Derive-Secret(., "c e traffic",
|
||||
// | ClientHello)
|
||||
// | = client_early_traffic_secret
|
||||
// |
|
||||
// +-----> Derive-Secret(., "e exp master",
|
||||
// | ClientHello)
|
||||
// | = early_exporter_master_secret
|
||||
// v
|
||||
// Derive-Secret(., "derived", "")
|
||||
// |
|
||||
// v
|
||||
// (EC)DHE -> HKDF-Extract = Handshake Secret
|
||||
// |
|
||||
// +-----> Derive-Secret(., "c hs traffic",
|
||||
// | ClientHello...ServerHello)
|
||||
// | = client_handshake_traffic_secret
|
||||
// |
|
||||
// +-----> Derive-Secret(., "s hs traffic",
|
||||
// | ClientHello...ServerHello)
|
||||
// | = server_handshake_traffic_secret
|
||||
// v
|
||||
// Derive-Secret(., "derived", "")
|
||||
// |
|
||||
// v
|
||||
// 0 -> HKDF-Extract = Master Secret
|
||||
// |
|
||||
// +-----> Derive-Secret(., "c ap traffic",
|
||||
// | ClientHello...server Finished)
|
||||
// | = client_application_traffic_secret_0
|
||||
// |
|
||||
// +-----> Derive-Secret(., "s ap traffic",
|
||||
// | ClientHello...server Finished)
|
||||
// | = server_application_traffic_secret_0
|
||||
// |
|
||||
// +-----> Derive-Secret(., "exp master",
|
||||
// | ClientHello...server Finished)
|
||||
// | = exporter_master_secret
|
||||
// |
|
||||
// +-----> Derive-Secret(., "res master",
|
||||
// ClientHello...client Finished)
|
||||
// = resumption_master_secret
|
||||
|
||||
// From RFC 5869
|
||||
// PRK = HMAC-Hash(salt, IKM)
|
||||
func HkdfExtract(hash crypto.Hash, saltIn, input []byte) []byte {
|
||||
salt := saltIn
|
||||
|
||||
// if [salt is] not provided, it is set to a string of HashLen zeros
|
||||
if salt == nil {
|
||||
salt = bytes.Repeat([]byte{0}, hash.Size())
|
||||
}
|
||||
|
||||
h := hmac.New(hash.New, salt)
|
||||
h.Write(input)
|
||||
out := h.Sum(nil)
|
||||
|
||||
logf(logTypeCrypto, "HKDF Extract:\n")
|
||||
logf(logTypeCrypto, "Salt [%d]: %x\n", len(salt), salt)
|
||||
logf(logTypeCrypto, "Input [%d]: %x\n", len(input), input)
|
||||
logf(logTypeCrypto, "Output [%d]: %x\n", len(out), out)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
const (
|
||||
labelExternalBinder = "ext binder"
|
||||
labelResumptionBinder = "res binder"
|
||||
labelEarlyTrafficSecret = "c e traffic"
|
||||
labelEarlyExporterSecret = "e exp master"
|
||||
labelClientHandshakeTrafficSecret = "c hs traffic"
|
||||
labelServerHandshakeTrafficSecret = "s hs traffic"
|
||||
labelClientApplicationTrafficSecret = "c ap traffic"
|
||||
labelServerApplicationTrafficSecret = "s ap traffic"
|
||||
labelExporterSecret = "exp master"
|
||||
labelResumptionSecret = "res master"
|
||||
labelDerived = "derived"
|
||||
labelFinished = "finished"
|
||||
labelResumption = "resumption"
|
||||
)
|
||||
|
||||
// struct HkdfLabel {
|
||||
// uint16 length;
|
||||
// opaque label<9..255>;
|
||||
// opaque hash_value<0..255>;
|
||||
// };
|
||||
func hkdfEncodeLabel(labelIn string, hashValue []byte, outLen int) []byte {
|
||||
label := "tls13 " + labelIn
|
||||
|
||||
labelLen := len(label)
|
||||
hashLen := len(hashValue)
|
||||
hkdfLabel := make([]byte, 2+1+labelLen+1+hashLen)
|
||||
hkdfLabel[0] = byte(outLen >> 8)
|
||||
hkdfLabel[1] = byte(outLen)
|
||||
hkdfLabel[2] = byte(labelLen)
|
||||
copy(hkdfLabel[3:3+labelLen], []byte(label))
|
||||
hkdfLabel[3+labelLen] = byte(hashLen)
|
||||
copy(hkdfLabel[3+labelLen+1:], hashValue)
|
||||
|
||||
return hkdfLabel
|
||||
}
|
||||
|
||||
func HkdfExpand(hash crypto.Hash, prk, info []byte, outLen int) []byte {
|
||||
out := []byte{}
|
||||
T := []byte{}
|
||||
i := byte(1)
|
||||
for len(out) < outLen {
|
||||
block := append(T, info...)
|
||||
block = append(block, i)
|
||||
|
||||
h := hmac.New(hash.New, prk)
|
||||
h.Write(block)
|
||||
|
||||
T = h.Sum(nil)
|
||||
out = append(out, T...)
|
||||
i++
|
||||
}
|
||||
return out[:outLen]
|
||||
}
|
||||
|
||||
func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, hashValue []byte, outLen int) []byte {
|
||||
info := hkdfEncodeLabel(label, hashValue, outLen)
|
||||
derived := HkdfExpand(hash, secret, info, outLen)
|
||||
|
||||
logf(logTypeCrypto, "HKDF Expand: label=[tls13 ] + '%s',requested length=%d\n", label, outLen)
|
||||
logf(logTypeCrypto, "PRK [%d]: %x\n", len(secret), secret)
|
||||
logf(logTypeCrypto, "Hash [%d]: %x\n", len(hashValue), hashValue)
|
||||
logf(logTypeCrypto, "Info [%d]: %x\n", len(info), info)
|
||||
logf(logTypeCrypto, "Derived key [%d]: %x\n", len(derived), derived)
|
||||
|
||||
return derived
|
||||
}
|
||||
|
||||
func deriveSecret(params CipherSuiteParams, secret []byte, label string, messageHash []byte) []byte {
|
||||
return HkdfExpandLabel(params.Hash, secret, label, messageHash, params.Hash.Size())
|
||||
}
|
||||
|
||||
func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte) []byte {
|
||||
macKey := HkdfExpandLabel(params.Hash, baseKey, labelFinished, []byte{}, params.Hash.Size())
|
||||
mac := hmac.New(params.Hash.New, macKey)
|
||||
mac.Write(input)
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
type keySet struct {
|
||||
cipher aeadFactory
|
||||
key []byte
|
||||
iv []byte
|
||||
}
|
||||
|
||||
func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet {
|
||||
logf(logTypeCrypto, "making traffic keys: secret=%x", secret)
|
||||
return keySet{
|
||||
cipher: params.Cipher,
|
||||
key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen),
|
||||
iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen),
|
||||
}
|
||||
}
|
586
vendor/github.com/bifurcation/mint/extensions.go
generated
vendored
586
vendor/github.com/bifurcation/mint/extensions.go
generated
vendored
|
@ -1,586 +0,0 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/bifurcation/mint/syntax"
|
||||
)
|
||||
|
||||
type ExtensionBody interface {
|
||||
Type() ExtensionType
|
||||
Marshal() ([]byte, error)
|
||||
Unmarshal(data []byte) (int, error)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ExtensionType extension_type;
|
||||
// opaque extension_data<0..2^16-1>;
|
||||
// } Extension;
|
||||
type Extension struct {
|
||||
ExtensionType ExtensionType
|
||||
ExtensionData []byte `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (ext Extension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(ext)
|
||||
}
|
||||
|
||||
func (ext *Extension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, ext)
|
||||
}
|
||||
|
||||
type ExtensionList []Extension
|
||||
|
||||
type extensionListInner struct {
|
||||
List []Extension `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (el ExtensionList) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(extensionListInner{el})
|
||||
}
|
||||
|
||||
func (el *ExtensionList) Unmarshal(data []byte) (int, error) {
|
||||
var list extensionListInner
|
||||
read, err := syntax.Unmarshal(data, &list)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
*el = list.List
|
||||
return read, nil
|
||||
}
|
||||
|
||||
func (el *ExtensionList) Add(src ExtensionBody) error {
|
||||
data, err := src.Marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if el == nil {
|
||||
el = new(ExtensionList)
|
||||
}
|
||||
|
||||
// If one already exists with this type, replace it
|
||||
for i := range *el {
|
||||
if (*el)[i].ExtensionType == src.Type() {
|
||||
(*el)[i].ExtensionData = data
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise append
|
||||
*el = append(*el, Extension{
|
||||
ExtensionType: src.Type(),
|
||||
ExtensionData: data,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (el ExtensionList) Find(dst ExtensionBody) bool {
|
||||
for _, ext := range el {
|
||||
if ext.ExtensionType == dst.Type() {
|
||||
_, err := dst.Unmarshal(ext.ExtensionData)
|
||||
return err == nil
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// struct {
|
||||
// NameType name_type;
|
||||
// select (name_type) {
|
||||
// case host_name: HostName;
|
||||
// } name;
|
||||
// } ServerName;
|
||||
//
|
||||
// enum {
|
||||
// host_name(0), (255)
|
||||
// } NameType;
|
||||
//
|
||||
// opaque HostName<1..2^16-1>;
|
||||
//
|
||||
// struct {
|
||||
// ServerName server_name_list<1..2^16-1>
|
||||
// } ServerNameList;
|
||||
//
|
||||
// But we only care about the case where there's a single DNS hostname. We
|
||||
// will never create anything else, and throw if we receive something else
|
||||
//
|
||||
// 2 1 2
|
||||
// | listLen | NameType | nameLen | name |
|
||||
type ServerNameExtension string
|
||||
|
||||
type serverNameInner struct {
|
||||
NameType uint8
|
||||
HostName []byte `tls:"head=2,min=1"`
|
||||
}
|
||||
|
||||
type serverNameListInner struct {
|
||||
ServerNameList []serverNameInner `tls:"head=2,min=1"`
|
||||
}
|
||||
|
||||
func (sni ServerNameExtension) Type() ExtensionType {
|
||||
return ExtensionTypeServerName
|
||||
}
|
||||
|
||||
func (sni ServerNameExtension) Marshal() ([]byte, error) {
|
||||
list := serverNameListInner{
|
||||
ServerNameList: []serverNameInner{{
|
||||
NameType: 0x00, // host_name
|
||||
HostName: []byte(sni),
|
||||
}},
|
||||
}
|
||||
|
||||
return syntax.Marshal(list)
|
||||
}
|
||||
|
||||
func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) {
|
||||
var list serverNameListInner
|
||||
read, err := syntax.Unmarshal(data, &list)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Syntax requires at least one entry
|
||||
// Entries beyond the first are ignored
|
||||
if nameType := list.ServerNameList[0].NameType; nameType != 0x00 {
|
||||
return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType)
|
||||
}
|
||||
|
||||
*sni = ServerNameExtension(list.ServerNameList[0].HostName)
|
||||
return read, nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// NamedGroup group;
|
||||
// opaque key_exchange<1..2^16-1>;
|
||||
// } KeyShareEntry;
|
||||
//
|
||||
// struct {
|
||||
// select (Handshake.msg_type) {
|
||||
// case client_hello:
|
||||
// KeyShareEntry client_shares<0..2^16-1>;
|
||||
//
|
||||
// case hello_retry_request:
|
||||
// NamedGroup selected_group;
|
||||
//
|
||||
// case server_hello:
|
||||
// KeyShareEntry server_share;
|
||||
// };
|
||||
// } KeyShare;
|
||||
type KeyShareEntry struct {
|
||||
Group NamedGroup
|
||||
KeyExchange []byte `tls:"head=2,min=1"`
|
||||
}
|
||||
|
||||
func (kse KeyShareEntry) SizeValid() bool {
|
||||
return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group)
|
||||
}
|
||||
|
||||
type KeyShareExtension struct {
|
||||
HandshakeType HandshakeType
|
||||
SelectedGroup NamedGroup
|
||||
Shares []KeyShareEntry
|
||||
}
|
||||
|
||||
type KeyShareClientHelloInner struct {
|
||||
ClientShares []KeyShareEntry `tls:"head=2,min=0"`
|
||||
}
|
||||
type KeyShareHelloRetryInner struct {
|
||||
SelectedGroup NamedGroup
|
||||
}
|
||||
type KeyShareServerHelloInner struct {
|
||||
ServerShare KeyShareEntry
|
||||
}
|
||||
|
||||
func (ks KeyShareExtension) Type() ExtensionType {
|
||||
return ExtensionTypeKeyShare
|
||||
}
|
||||
|
||||
func (ks KeyShareExtension) Marshal() ([]byte, error) {
|
||||
switch ks.HandshakeType {
|
||||
case HandshakeTypeClientHello:
|
||||
for _, share := range ks.Shares {
|
||||
if !share.SizeValid() {
|
||||
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||
}
|
||||
}
|
||||
return syntax.Marshal(KeyShareClientHelloInner{ks.Shares})
|
||||
|
||||
case HandshakeTypeHelloRetryRequest:
|
||||
if len(ks.Shares) > 0 {
|
||||
return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest")
|
||||
}
|
||||
|
||||
return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup})
|
||||
|
||||
case HandshakeTypeServerHello:
|
||||
if len(ks.Shares) != 1 {
|
||||
return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share")
|
||||
}
|
||||
|
||||
if !ks.Shares[0].SizeValid() {
|
||||
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||
}
|
||||
|
||||
return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]})
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) {
|
||||
switch ks.HandshakeType {
|
||||
case HandshakeTypeClientHello:
|
||||
var inner KeyShareClientHelloInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for _, share := range inner.ClientShares {
|
||||
if !share.SizeValid() {
|
||||
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||
}
|
||||
}
|
||||
|
||||
ks.Shares = inner.ClientShares
|
||||
return read, nil
|
||||
|
||||
case HandshakeTypeHelloRetryRequest:
|
||||
var inner KeyShareHelloRetryInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
ks.SelectedGroup = inner.SelectedGroup
|
||||
return read, nil
|
||||
|
||||
case HandshakeTypeServerHello:
|
||||
var inner KeyShareServerHelloInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if !inner.ServerShare.SizeValid() {
|
||||
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||
}
|
||||
|
||||
ks.Shares = []KeyShareEntry{inner.ServerShare}
|
||||
return read, nil
|
||||
|
||||
default:
|
||||
return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// struct {
|
||||
// NamedGroup named_group_list<2..2^16-1>;
|
||||
// } NamedGroupList;
|
||||
type SupportedGroupsExtension struct {
|
||||
Groups []NamedGroup `tls:"head=2,min=2"`
|
||||
}
|
||||
|
||||
func (sg SupportedGroupsExtension) Type() ExtensionType {
|
||||
return ExtensionTypeSupportedGroups
|
||||
}
|
||||
|
||||
func (sg SupportedGroupsExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(sg)
|
||||
}
|
||||
|
||||
func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, sg)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// SignatureScheme supported_signature_algorithms<2..2^16-2>;
|
||||
// } SignatureSchemeList
|
||||
type SignatureAlgorithmsExtension struct {
|
||||
Algorithms []SignatureScheme `tls:"head=2,min=2"`
|
||||
}
|
||||
|
||||
func (sa SignatureAlgorithmsExtension) Type() ExtensionType {
|
||||
return ExtensionTypeSignatureAlgorithms
|
||||
}
|
||||
|
||||
func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(sa)
|
||||
}
|
||||
|
||||
func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, sa)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// opaque identity<1..2^16-1>;
|
||||
// uint32 obfuscated_ticket_age;
|
||||
// } PskIdentity;
|
||||
//
|
||||
// opaque PskBinderEntry<32..255>;
|
||||
//
|
||||
// struct {
|
||||
// select (Handshake.msg_type) {
|
||||
// case client_hello:
|
||||
// PskIdentity identities<7..2^16-1>;
|
||||
// PskBinderEntry binders<33..2^16-1>;
|
||||
//
|
||||
// case server_hello:
|
||||
// uint16 selected_identity;
|
||||
// };
|
||||
//
|
||||
// } PreSharedKeyExtension;
|
||||
type PSKIdentity struct {
|
||||
Identity []byte `tls:"head=2,min=1"`
|
||||
ObfuscatedTicketAge uint32
|
||||
}
|
||||
|
||||
type PSKBinderEntry struct {
|
||||
Binder []byte `tls:"head=1,min=32"`
|
||||
}
|
||||
|
||||
type PreSharedKeyExtension struct {
|
||||
HandshakeType HandshakeType
|
||||
Identities []PSKIdentity
|
||||
Binders []PSKBinderEntry
|
||||
SelectedIdentity uint16
|
||||
}
|
||||
|
||||
type preSharedKeyClientInner struct {
|
||||
Identities []PSKIdentity `tls:"head=2,min=7"`
|
||||
Binders []PSKBinderEntry `tls:"head=2,min=33"`
|
||||
}
|
||||
|
||||
type preSharedKeyServerInner struct {
|
||||
SelectedIdentity uint16
|
||||
}
|
||||
|
||||
func (psk PreSharedKeyExtension) Type() ExtensionType {
|
||||
return ExtensionTypePreSharedKey
|
||||
}
|
||||
|
||||
func (psk PreSharedKeyExtension) Marshal() ([]byte, error) {
|
||||
switch psk.HandshakeType {
|
||||
case HandshakeTypeClientHello:
|
||||
return syntax.Marshal(preSharedKeyClientInner{
|
||||
Identities: psk.Identities,
|
||||
Binders: psk.Binders,
|
||||
})
|
||||
|
||||
case HandshakeTypeServerHello:
|
||||
if len(psk.Identities) > 0 || len(psk.Binders) > 0 {
|
||||
return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index")
|
||||
}
|
||||
return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity})
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported")
|
||||
}
|
||||
}
|
||||
|
||||
func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) {
|
||||
switch psk.HandshakeType {
|
||||
case HandshakeTypeClientHello:
|
||||
var inner preSharedKeyClientInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(inner.Identities) != len(inner.Binders) {
|
||||
return 0, fmt.Errorf("Lengths of identities and binders not equal")
|
||||
}
|
||||
|
||||
psk.Identities = inner.Identities
|
||||
psk.Binders = inner.Binders
|
||||
return read, nil
|
||||
|
||||
case HandshakeTypeServerHello:
|
||||
var inner preSharedKeyServerInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
psk.SelectedIdentity = inner.SelectedIdentity
|
||||
return read, nil
|
||||
|
||||
default:
|
||||
return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported")
|
||||
}
|
||||
}
|
||||
|
||||
func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) {
|
||||
for i, localID := range psk.Identities {
|
||||
if bytes.Equal(localID.Identity, id) {
|
||||
return psk.Binders[i].Binder, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode;
|
||||
//
|
||||
// struct {
|
||||
// PskKeyExchangeMode ke_modes<1..255>;
|
||||
// } PskKeyExchangeModes;
|
||||
type PSKKeyExchangeModesExtension struct {
|
||||
KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"`
|
||||
}
|
||||
|
||||
func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType {
|
||||
return ExtensionTypePSKKeyExchangeModes
|
||||
}
|
||||
|
||||
func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(pkem)
|
||||
}
|
||||
|
||||
func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, pkem)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// } EarlyDataIndication;
|
||||
|
||||
type EarlyDataExtension struct{}
|
||||
|
||||
func (ed EarlyDataExtension) Type() ExtensionType {
|
||||
return ExtensionTypeEarlyData
|
||||
}
|
||||
|
||||
func (ed EarlyDataExtension) Marshal() ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// uint32 max_early_data_size;
|
||||
// } TicketEarlyDataInfo;
|
||||
|
||||
type TicketEarlyDataInfoExtension struct {
|
||||
MaxEarlyDataSize uint32
|
||||
}
|
||||
|
||||
func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType {
|
||||
return ExtensionTypeTicketEarlyDataInfo
|
||||
}
|
||||
|
||||
func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(tedi)
|
||||
}
|
||||
|
||||
func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, tedi)
|
||||
}
|
||||
|
||||
// opaque ProtocolName<1..2^8-1>;
|
||||
//
|
||||
// struct {
|
||||
// ProtocolName protocol_name_list<2..2^16-1>
|
||||
// } ProtocolNameList;
|
||||
type ALPNExtension struct {
|
||||
Protocols []string
|
||||
}
|
||||
|
||||
type protocolNameInner struct {
|
||||
Name []byte `tls:"head=1,min=1"`
|
||||
}
|
||||
|
||||
type alpnExtensionInner struct {
|
||||
Protocols []protocolNameInner `tls:"head=2,min=2"`
|
||||
}
|
||||
|
||||
func (alpn ALPNExtension) Type() ExtensionType {
|
||||
return ExtensionTypeALPN
|
||||
}
|
||||
|
||||
func (alpn ALPNExtension) Marshal() ([]byte, error) {
|
||||
protocols := make([]protocolNameInner, len(alpn.Protocols))
|
||||
for i, protocol := range alpn.Protocols {
|
||||
protocols[i] = protocolNameInner{[]byte(protocol)}
|
||||
}
|
||||
return syntax.Marshal(alpnExtensionInner{protocols})
|
||||
}
|
||||
|
||||
func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) {
|
||||
var inner alpnExtensionInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
alpn.Protocols = make([]string, len(inner.Protocols))
|
||||
for i, protocol := range inner.Protocols {
|
||||
alpn.Protocols[i] = string(protocol.Name)
|
||||
}
|
||||
return read, nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion versions<2..254>;
|
||||
// } SupportedVersions;
|
||||
type SupportedVersionsExtension struct {
|
||||
Versions []uint16 `tls:"head=1,min=2,max=254"`
|
||||
}
|
||||
|
||||
func (sv SupportedVersionsExtension) Type() ExtensionType {
|
||||
return ExtensionTypeSupportedVersions
|
||||
}
|
||||
|
||||
func (sv SupportedVersionsExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(sv)
|
||||
}
|
||||
|
||||
func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, sv)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// opaque cookie<1..2^16-1>;
|
||||
// } Cookie;
|
||||
type CookieExtension struct {
|
||||
Cookie []byte `tls:"head=2,min=1"`
|
||||
}
|
||||
|
||||
func (c CookieExtension) Type() ExtensionType {
|
||||
return ExtensionTypeCookie
|
||||
}
|
||||
|
||||
func (c CookieExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(c)
|
||||
}
|
||||
|
||||
func (c *CookieExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, c)
|
||||
}
|
||||
|
||||
// defaultCookieLength is the default length of a cookie
|
||||
const defaultCookieLength = 32
|
||||
|
||||
type defaultCookieHandler struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
var _ CookieHandler = &defaultCookieHandler{}
|
||||
|
||||
// NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data
|
||||
func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) {
|
||||
h.data = make([]byte, defaultCookieLength)
|
||||
if _, err := prng.Read(h.data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h.data, nil
|
||||
}
|
||||
|
||||
func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool {
|
||||
return bytes.Equal(h.data, data)
|
||||
}
|
147
vendor/github.com/bifurcation/mint/ffdhe.go
generated
vendored
147
vendor/github.com/bifurcation/mint/ffdhe.go
generated
vendored
|
@ -1,147 +0,0 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
var (
|
||||
finiteFieldPrime2048hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||
"886B423861285C97FFFFFFFFFFFFFFFF"
|
||||
finiteFieldPrime2048bytes, _ = hex.DecodeString(finiteFieldPrime2048hex)
|
||||
finiteFieldPrime2048 = big.NewInt(0).SetBytes(finiteFieldPrime2048bytes)
|
||||
|
||||
finiteFieldPrime3072hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||
"3C1B20EE3FD59D7C25E41D2B66C62E37FFFFFFFFFFFFFFFF"
|
||||
finiteFieldPrime3072bytes, _ = hex.DecodeString(finiteFieldPrime3072hex)
|
||||
finiteFieldPrime3072 = big.NewInt(0).SetBytes(finiteFieldPrime3072bytes)
|
||||
|
||||
finiteFieldPrime4096hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
|
||||
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
|
||||
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
|
||||
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
|
||||
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
|
||||
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E655F6A" +
|
||||
"FFFFFFFFFFFFFFFF"
|
||||
finiteFieldPrime4096bytes, _ = hex.DecodeString(finiteFieldPrime4096hex)
|
||||
finiteFieldPrime4096 = big.NewInt(0).SetBytes(finiteFieldPrime4096bytes)
|
||||
|
||||
finiteFieldPrime6144hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
|
||||
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
|
||||
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
|
||||
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
|
||||
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
|
||||
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
|
||||
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
|
||||
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
|
||||
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
|
||||
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
|
||||
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
|
||||
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
|
||||
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
|
||||
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
|
||||
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
|
||||
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
|
||||
"A41D570D7938DAD4A40E329CD0E40E65FFFFFFFFFFFFFFFF"
|
||||
finiteFieldPrime6144bytes, _ = hex.DecodeString(finiteFieldPrime6144hex)
|
||||
finiteFieldPrime6144 = big.NewInt(0).SetBytes(finiteFieldPrime6144bytes)
|
||||
|
||||
finiteFieldPrime8192hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
|
||||
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
|
||||
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
|
||||
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
|
||||
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
|
||||
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
|
||||
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
|
||||
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
|
||||
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
|
||||
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
|
||||
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
|
||||
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
|
||||
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
|
||||
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
|
||||
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
|
||||
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
|
||||
"A41D570D7938DAD4A40E329CCFF46AAA36AD004CF600C838" +
|
||||
"1E425A31D951AE64FDB23FCEC9509D43687FEB69EDD1CC5E" +
|
||||
"0B8CC3BDF64B10EF86B63142A3AB8829555B2F747C932665" +
|
||||
"CB2C0F1CC01BD70229388839D2AF05E454504AC78B758282" +
|
||||
"2846C0BA35C35F5C59160CC046FD8251541FC68C9C86B022" +
|
||||
"BB7099876A460E7451A8A93109703FEE1C217E6C3826E52C" +
|
||||
"51AA691E0E423CFC99E9E31650C1217B624816CDAD9A95F9" +
|
||||
"D5B8019488D9C0A0A1FE3075A577E23183F81D4A3F2FA457" +
|
||||
"1EFC8CE0BA8A4FE8B6855DFE72B0A66EDED2FBABFBE58A30" +
|
||||
"FAFABE1C5D71A87E2F741EF8C1FE86FEA6BBFDE530677F0D" +
|
||||
"97D11D49F7A8443D0822E506A9F4614E011E2A94838FF88C" +
|
||||
"D68C8BB7C5C6424CFFFFFFFFFFFFFFFF"
|
||||
finiteFieldPrime8192bytes, _ = hex.DecodeString(finiteFieldPrime8192hex)
|
||||
finiteFieldPrime8192 = big.NewInt(0).SetBytes(finiteFieldPrime8192bytes)
|
||||
)
|
98
vendor/github.com/bifurcation/mint/frame-reader.go
generated
vendored
98
vendor/github.com/bifurcation/mint/frame-reader.go
generated
vendored
|
@ -1,98 +0,0 @@
|
|||
// Read a generic "framed" packet consisting of a header and a
|
||||
// This is used for both TLS Records and TLS Handshake Messages
|
||||
package mint
|
||||
|
||||
type framing interface {
|
||||
headerLen() int
|
||||
defaultReadLen() int
|
||||
frameLen(hdr []byte) (int, error)
|
||||
}
|
||||
|
||||
const (
|
||||
kFrameReaderHdr = 0
|
||||
kFrameReaderBody = 1
|
||||
)
|
||||
|
||||
type frameNextAction func(f *frameReader) error
|
||||
|
||||
type frameReader struct {
|
||||
details framing
|
||||
state uint8
|
||||
header []byte
|
||||
body []byte
|
||||
working []byte
|
||||
writeOffset int
|
||||
remainder []byte
|
||||
}
|
||||
|
||||
func newFrameReader(d framing) *frameReader {
|
||||
hdr := make([]byte, d.headerLen())
|
||||
return &frameReader{
|
||||
d,
|
||||
kFrameReaderHdr,
|
||||
hdr,
|
||||
nil,
|
||||
hdr,
|
||||
0,
|
||||
nil,
|
||||
}
|
||||
}
|
||||
|
||||
func dup(a []byte) []byte {
|
||||
r := make([]byte, len(a))
|
||||
copy(r, a)
|
||||
return r
|
||||
}
|
||||
|
||||
func (f *frameReader) needed() int {
|
||||
tmp := (len(f.working) - f.writeOffset) - len(f.remainder)
|
||||
if tmp < 0 {
|
||||
return 0
|
||||
}
|
||||
return tmp
|
||||
}
|
||||
|
||||
func (f *frameReader) addChunk(in []byte) {
|
||||
// Append to the buffer.
|
||||
logf(logTypeFrameReader, "Appending %v", len(in))
|
||||
f.remainder = append(f.remainder, in...)
|
||||
}
|
||||
|
||||
func (f *frameReader) process() (hdr []byte, body []byte, err error) {
|
||||
for f.needed() == 0 {
|
||||
logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset)
|
||||
// Fill out our working block
|
||||
copied := copy(f.working[f.writeOffset:], f.remainder)
|
||||
f.remainder = f.remainder[copied:]
|
||||
f.writeOffset += copied
|
||||
if f.writeOffset < len(f.working) {
|
||||
logf(logTypeFrameReader, "Read would have blocked 1")
|
||||
return nil, nil, WouldBlock
|
||||
}
|
||||
// Reset the write offset, because we are now full.
|
||||
f.writeOffset = 0
|
||||
|
||||
// We have read a full frame
|
||||
if f.state == kFrameReaderBody {
|
||||
logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder))
|
||||
f.state = kFrameReaderHdr
|
||||
f.working = f.header
|
||||
return dup(f.header), dup(f.body), nil
|
||||
}
|
||||
|
||||
// We have read the header
|
||||
bodyLen, err := f.details.frameLen(f.header)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen)
|
||||
|
||||
f.body = make([]byte, bodyLen)
|
||||
f.working = f.body
|
||||
f.writeOffset = 0
|
||||
f.state = kFrameReaderBody
|
||||
}
|
||||
|
||||
logf(logTypeFrameReader, "Read would have blocked 2")
|
||||
return nil, nil, WouldBlock
|
||||
}
|
253
vendor/github.com/bifurcation/mint/handshake-layer.go
generated
vendored
253
vendor/github.com/bifurcation/mint/handshake-layer.go
generated
vendored
|
@ -1,253 +0,0 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
handshakeHeaderLen = 4 // handshake message header length
|
||||
maxHandshakeMessageLen = 1 << 24 // max handshake message length
|
||||
)
|
||||
|
||||
// struct {
|
||||
// HandshakeType msg_type; /* handshake type */
|
||||
// uint24 length; /* bytes in message */
|
||||
// select (HandshakeType) {
|
||||
// ...
|
||||
// } body;
|
||||
// } Handshake;
|
||||
//
|
||||
// We do the select{...} part in a different layer, so we treat the
|
||||
// actual message body as opaque:
|
||||
//
|
||||
// struct {
|
||||
// HandshakeType msg_type;
|
||||
// opaque msg<0..2^24-1>
|
||||
// } Handshake;
|
||||
//
|
||||
// TODO: File a spec bug
|
||||
type HandshakeMessage struct {
|
||||
// Omitted: length
|
||||
msgType HandshakeType
|
||||
body []byte
|
||||
}
|
||||
|
||||
// Note: This could be done with the `syntax` module, using the simplified
|
||||
// syntax as discussed above. However, since this is so simple, there's not
|
||||
// much benefit to doing so.
|
||||
func (hm *HandshakeMessage) Marshal() []byte {
|
||||
if hm == nil {
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
msgLen := len(hm.body)
|
||||
data := make([]byte, 4+len(hm.body))
|
||||
data[0] = byte(hm.msgType)
|
||||
data[1] = byte(msgLen >> 16)
|
||||
data[2] = byte(msgLen >> 8)
|
||||
data[3] = byte(msgLen)
|
||||
copy(data[4:], hm.body)
|
||||
return data
|
||||
}
|
||||
|
||||
func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
|
||||
logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body)
|
||||
|
||||
var body HandshakeMessageBody
|
||||
switch hm.msgType {
|
||||
case HandshakeTypeClientHello:
|
||||
body = new(ClientHelloBody)
|
||||
case HandshakeTypeServerHello:
|
||||
body = new(ServerHelloBody)
|
||||
case HandshakeTypeHelloRetryRequest:
|
||||
body = new(HelloRetryRequestBody)
|
||||
case HandshakeTypeEncryptedExtensions:
|
||||
body = new(EncryptedExtensionsBody)
|
||||
case HandshakeTypeCertificate:
|
||||
body = new(CertificateBody)
|
||||
case HandshakeTypeCertificateRequest:
|
||||
body = new(CertificateRequestBody)
|
||||
case HandshakeTypeCertificateVerify:
|
||||
body = new(CertificateVerifyBody)
|
||||
case HandshakeTypeFinished:
|
||||
body = &FinishedBody{VerifyDataLen: len(hm.body)}
|
||||
case HandshakeTypeNewSessionTicket:
|
||||
body = new(NewSessionTicketBody)
|
||||
case HandshakeTypeKeyUpdate:
|
||||
body = new(KeyUpdateBody)
|
||||
case HandshakeTypeEndOfEarlyData:
|
||||
body = new(EndOfEarlyDataBody)
|
||||
default:
|
||||
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
|
||||
}
|
||||
|
||||
_, err := body.Unmarshal(hm.body)
|
||||
return body, err
|
||||
}
|
||||
|
||||
func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
|
||||
data, err := body.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &HandshakeMessage{
|
||||
msgType: body.Type(),
|
||||
body: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type HandshakeLayer struct {
|
||||
nonblocking bool // Should we operate in nonblocking mode
|
||||
conn *RecordLayer // Used for reading/writing records
|
||||
frame *frameReader // The buffered frame reader
|
||||
}
|
||||
|
||||
type handshakeLayerFrameDetails struct{}
|
||||
|
||||
func (d handshakeLayerFrameDetails) headerLen() int {
|
||||
return handshakeHeaderLen
|
||||
}
|
||||
|
||||
func (d handshakeLayerFrameDetails) defaultReadLen() int {
|
||||
return handshakeHeaderLen + maxFragmentLen
|
||||
}
|
||||
|
||||
func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
||||
logf(logTypeIO, "Header=%x", hdr)
|
||||
return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil
|
||||
}
|
||||
|
||||
func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer {
|
||||
h := HandshakeLayer{}
|
||||
h.conn = r
|
||||
h.frame = newFrameReader(&handshakeLayerFrameDetails{})
|
||||
return &h
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) readRecord() error {
|
||||
logf(logTypeIO, "Trying to read record")
|
||||
pt, err := h.conn.ReadRecord()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pt.contentType != RecordTypeHandshake &&
|
||||
pt.contentType != RecordTypeAlert {
|
||||
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
|
||||
}
|
||||
|
||||
if pt.contentType == RecordTypeAlert {
|
||||
logf(logTypeIO, "read alert %v", pt.fragment[1])
|
||||
if len(pt.fragment) < 2 {
|
||||
h.sendAlert(AlertUnexpectedMessage)
|
||||
return io.EOF
|
||||
}
|
||||
return Alert(pt.fragment[1])
|
||||
}
|
||||
|
||||
logf(logTypeIO, "read handshake record of len %v", len(pt.fragment))
|
||||
h.frame.addChunk(pt.fragment)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendAlert sends a TLS alert message.
|
||||
func (h *HandshakeLayer) sendAlert(err Alert) error {
|
||||
tmp := make([]byte, 2)
|
||||
tmp[0] = AlertLevelError
|
||||
tmp[1] = byte(err)
|
||||
h.conn.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeAlert,
|
||||
fragment: tmp},
|
||||
)
|
||||
|
||||
// closeNotify is a special case in that it isn't an error:
|
||||
if err != AlertCloseNotify {
|
||||
return &net.OpError{Op: "local error", Err: err}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
|
||||
var hdr, body []byte
|
||||
var err error
|
||||
|
||||
for {
|
||||
logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder))
|
||||
if h.frame.needed() > 0 {
|
||||
logf(logTypeHandshake, "Trying to read a new record")
|
||||
err = h.readRecord()
|
||||
}
|
||||
if err != nil && (h.nonblocking || err != WouldBlock) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hdr, body, err = h.frame.process()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if err != nil && (h.nonblocking || err != WouldBlock) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "read handshake message")
|
||||
|
||||
hm := &HandshakeMessage{}
|
||||
hm.msgType = HandshakeType(hdr[0])
|
||||
|
||||
hm.body = make([]byte, len(body))
|
||||
copy(hm.body, body)
|
||||
|
||||
return hm, nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error {
|
||||
return h.WriteMessages([]*HandshakeMessage{hm})
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error {
|
||||
for _, hm := range hms {
|
||||
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
|
||||
}
|
||||
|
||||
// Write out headers and bodies
|
||||
buffer := []byte{}
|
||||
for _, msg := range hms {
|
||||
msgLen := len(msg.body)
|
||||
if msgLen > maxHandshakeMessageLen {
|
||||
return fmt.Errorf("tls.handshakelayer: Message too large to send")
|
||||
}
|
||||
|
||||
buffer = append(buffer, msg.Marshal()...)
|
||||
}
|
||||
|
||||
// Send full-size fragments
|
||||
var start int
|
||||
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
|
||||
err := h.conn.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeHandshake,
|
||||
fragment: buffer[start : start+maxFragmentLen],
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Send a final partial fragment if necessary
|
||||
if start < len(buffer) {
|
||||
err := h.conn.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeHandshake,
|
||||
fragment: buffer[start:],
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
450
vendor/github.com/bifurcation/mint/handshake-messages.go
generated
vendored
450
vendor/github.com/bifurcation/mint/handshake-messages.go
generated
vendored
|
@ -1,450 +0,0 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/bifurcation/mint/syntax"
|
||||
)
|
||||
|
||||
type HandshakeMessageBody interface {
|
||||
Type() HandshakeType
|
||||
Marshal() ([]byte, error)
|
||||
Unmarshal(data []byte) (int, error)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */
|
||||
// Random random;
|
||||
// opaque legacy_session_id<0..32>;
|
||||
// CipherSuite cipher_suites<2..2^16-2>;
|
||||
// opaque legacy_compression_methods<1..2^8-1>;
|
||||
// Extension extensions<0..2^16-1>;
|
||||
// } ClientHello;
|
||||
type ClientHelloBody struct {
|
||||
// Omitted: clientVersion
|
||||
// Omitted: legacySessionID
|
||||
// Omitted: legacyCompressionMethods
|
||||
Random [32]byte
|
||||
CipherSuites []CipherSuite
|
||||
Extensions ExtensionList
|
||||
}
|
||||
|
||||
type clientHelloBodyInner struct {
|
||||
LegacyVersion uint16
|
||||
Random [32]byte
|
||||
LegacySessionID []byte `tls:"head=1,max=32"`
|
||||
CipherSuites []CipherSuite `tls:"head=2,min=2"`
|
||||
LegacyCompressionMethods []byte `tls:"head=1,min=1"`
|
||||
Extensions []Extension `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (ch ClientHelloBody) Type() HandshakeType {
|
||||
return HandshakeTypeClientHello
|
||||
}
|
||||
|
||||
func (ch ClientHelloBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(clientHelloBodyInner{
|
||||
LegacyVersion: 0x0303,
|
||||
Random: ch.Random,
|
||||
LegacySessionID: []byte{},
|
||||
CipherSuites: ch.CipherSuites,
|
||||
LegacyCompressionMethods: []byte{0},
|
||||
Extensions: ch.Extensions,
|
||||
})
|
||||
}
|
||||
|
||||
func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
|
||||
var inner clientHelloBodyInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// We are strict about these things because we only support 1.3
|
||||
if inner.LegacyVersion != 0x0303 {
|
||||
return 0, fmt.Errorf("tls.clienthello: Incorrect version number")
|
||||
}
|
||||
|
||||
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
|
||||
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
|
||||
}
|
||||
|
||||
ch.Random = inner.Random
|
||||
ch.CipherSuites = inner.CipherSuites
|
||||
ch.Extensions = inner.Extensions
|
||||
return read, nil
|
||||
}
|
||||
|
||||
// TODO: File a spec bug to clarify this
|
||||
func (ch ClientHelloBody) Truncated() ([]byte, error) {
|
||||
if len(ch.Extensions) == 0 {
|
||||
return nil, fmt.Errorf("tls.clienthello.truncate: No extensions")
|
||||
}
|
||||
|
||||
pskExt := ch.Extensions[len(ch.Extensions)-1]
|
||||
if pskExt.ExtensionType != ExtensionTypePreSharedKey {
|
||||
return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK")
|
||||
}
|
||||
|
||||
chm, err := HandshakeMessageFromBody(&ch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chData := chm.Marshal()
|
||||
|
||||
psk := PreSharedKeyExtension{
|
||||
HandshakeType: HandshakeTypeClientHello,
|
||||
}
|
||||
_, err = psk.Unmarshal(pskExt.ExtensionData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Marshal just the binders so that we know how much to truncate
|
||||
binders := struct {
|
||||
Binders []PSKBinderEntry `tls:"head=2,min=33"`
|
||||
}{Binders: psk.Binders}
|
||||
binderData, _ := syntax.Marshal(binders)
|
||||
binderLen := len(binderData)
|
||||
|
||||
chLen := len(chData)
|
||||
return chData[:chLen-binderLen], nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion server_version;
|
||||
// CipherSuite cipher_suite;
|
||||
// Extension extensions<2..2^16-1>;
|
||||
// } HelloRetryRequest;
|
||||
type HelloRetryRequestBody struct {
|
||||
Version uint16
|
||||
CipherSuite CipherSuite
|
||||
Extensions ExtensionList `tls:"head=2,min=2"`
|
||||
}
|
||||
|
||||
func (hrr HelloRetryRequestBody) Type() HandshakeType {
|
||||
return HandshakeTypeHelloRetryRequest
|
||||
}
|
||||
|
||||
func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(hrr)
|
||||
}
|
||||
|
||||
func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, hrr)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion version;
|
||||
// Random random;
|
||||
// CipherSuite cipher_suite;
|
||||
// Extension extensions<0..2^16-1>;
|
||||
// } ServerHello;
|
||||
type ServerHelloBody struct {
|
||||
Version uint16
|
||||
Random [32]byte
|
||||
CipherSuite CipherSuite
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (sh ServerHelloBody) Type() HandshakeType {
|
||||
return HandshakeTypeServerHello
|
||||
}
|
||||
|
||||
func (sh ServerHelloBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(sh)
|
||||
}
|
||||
|
||||
func (sh *ServerHelloBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, sh)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// opaque verify_data[verify_data_length];
|
||||
// } Finished;
|
||||
//
|
||||
// verifyDataLen is not a field in the TLS struct, but we add it here so
|
||||
// that calling code can tell us how much data to expect when we marshal /
|
||||
// unmarshal. (We could add this to the marshal/unmarshal methods, but let's
|
||||
// try to keep the signature consistent for now.)
|
||||
//
|
||||
// For similar reasons, we don't use the `syntax` module here, because this
|
||||
// struct doesn't map well to standard TLS presentation language concepts.
|
||||
//
|
||||
// TODO: File a spec bug
|
||||
type FinishedBody struct {
|
||||
VerifyDataLen int
|
||||
VerifyData []byte
|
||||
}
|
||||
|
||||
func (fin FinishedBody) Type() HandshakeType {
|
||||
return HandshakeTypeFinished
|
||||
}
|
||||
|
||||
func (fin FinishedBody) Marshal() ([]byte, error) {
|
||||
if len(fin.VerifyData) != fin.VerifyDataLen {
|
||||
return nil, fmt.Errorf("tls.finished: data length mismatch")
|
||||
}
|
||||
|
||||
body := make([]byte, len(fin.VerifyData))
|
||||
copy(body, fin.VerifyData)
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (fin *FinishedBody) Unmarshal(data []byte) (int, error) {
|
||||
if len(data) < fin.VerifyDataLen {
|
||||
return 0, fmt.Errorf("tls.finished: Malformed finished; too short")
|
||||
}
|
||||
|
||||
fin.VerifyData = make([]byte, fin.VerifyDataLen)
|
||||
copy(fin.VerifyData, data[:fin.VerifyDataLen])
|
||||
return fin.VerifyDataLen, nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// Extension extensions<0..2^16-1>;
|
||||
// } EncryptedExtensions;
|
||||
//
|
||||
// Marshal() and Unmarshal() are handled by ExtensionList
|
||||
type EncryptedExtensionsBody struct {
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (ee EncryptedExtensionsBody) Type() HandshakeType {
|
||||
return HandshakeTypeEncryptedExtensions
|
||||
}
|
||||
|
||||
func (ee EncryptedExtensionsBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(ee)
|
||||
}
|
||||
|
||||
func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, ee)
|
||||
}
|
||||
|
||||
// opaque ASN1Cert<1..2^24-1>;
|
||||
//
|
||||
// struct {
|
||||
// ASN1Cert cert_data;
|
||||
// Extension extensions<0..2^16-1>
|
||||
// } CertificateEntry;
|
||||
//
|
||||
// struct {
|
||||
// opaque certificate_request_context<0..2^8-1>;
|
||||
// CertificateEntry certificate_list<0..2^24-1>;
|
||||
// } Certificate;
|
||||
type CertificateEntry struct {
|
||||
CertData *x509.Certificate
|
||||
Extensions ExtensionList
|
||||
}
|
||||
|
||||
type CertificateBody struct {
|
||||
CertificateRequestContext []byte
|
||||
CertificateList []CertificateEntry
|
||||
}
|
||||
|
||||
type certificateEntryInner struct {
|
||||
CertData []byte `tls:"head=3,min=1"`
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
}
|
||||
|
||||
type certificateBodyInner struct {
|
||||
CertificateRequestContext []byte `tls:"head=1"`
|
||||
CertificateList []certificateEntryInner `tls:"head=3"`
|
||||
}
|
||||
|
||||
func (c CertificateBody) Type() HandshakeType {
|
||||
return HandshakeTypeCertificate
|
||||
}
|
||||
|
||||
func (c CertificateBody) Marshal() ([]byte, error) {
|
||||
inner := certificateBodyInner{
|
||||
CertificateRequestContext: c.CertificateRequestContext,
|
||||
CertificateList: make([]certificateEntryInner, len(c.CertificateList)),
|
||||
}
|
||||
|
||||
for i, entry := range c.CertificateList {
|
||||
inner.CertificateList[i] = certificateEntryInner{
|
||||
CertData: entry.CertData.Raw,
|
||||
Extensions: entry.Extensions,
|
||||
}
|
||||
}
|
||||
|
||||
return syntax.Marshal(inner)
|
||||
}
|
||||
|
||||
func (c *CertificateBody) Unmarshal(data []byte) (int, error) {
|
||||
inner := certificateBodyInner{}
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return read, err
|
||||
}
|
||||
|
||||
c.CertificateRequestContext = inner.CertificateRequestContext
|
||||
c.CertificateList = make([]CertificateEntry, len(inner.CertificateList))
|
||||
|
||||
for i, entry := range inner.CertificateList {
|
||||
c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err)
|
||||
}
|
||||
|
||||
c.CertificateList[i].Extensions = entry.Extensions
|
||||
}
|
||||
|
||||
return read, nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// SignatureScheme algorithm;
|
||||
// opaque signature<0..2^16-1>;
|
||||
// } CertificateVerify;
|
||||
type CertificateVerifyBody struct {
|
||||
Algorithm SignatureScheme
|
||||
Signature []byte `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (cv CertificateVerifyBody) Type() HandshakeType {
|
||||
return HandshakeTypeCertificateVerify
|
||||
}
|
||||
|
||||
func (cv CertificateVerifyBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(cv)
|
||||
}
|
||||
|
||||
func (cv *CertificateVerifyBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, cv)
|
||||
}
|
||||
|
||||
func (cv *CertificateVerifyBody) EncodeSignatureInput(data []byte) []byte {
|
||||
// TODO: Change context for client auth
|
||||
// TODO: Put this in a const
|
||||
const context = "TLS 1.3, server CertificateVerify"
|
||||
sigInput := bytes.Repeat([]byte{0x20}, 64)
|
||||
sigInput = append(sigInput, []byte(context)...)
|
||||
sigInput = append(sigInput, []byte{0}...)
|
||||
sigInput = append(sigInput, data...)
|
||||
return sigInput
|
||||
}
|
||||
|
||||
func (cv *CertificateVerifyBody) Sign(privateKey crypto.Signer, handshakeHash []byte) (err error) {
|
||||
sigInput := cv.EncodeSignatureInput(handshakeHash)
|
||||
cv.Signature, err = sign(cv.Algorithm, privateKey, sigInput)
|
||||
logf(logTypeHandshake, "Signed: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
|
||||
return
|
||||
}
|
||||
|
||||
func (cv *CertificateVerifyBody) Verify(publicKey crypto.PublicKey, handshakeHash []byte) error {
|
||||
sigInput := cv.EncodeSignatureInput(handshakeHash)
|
||||
logf(logTypeHandshake, "About to verify: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
|
||||
return verify(cv.Algorithm, publicKey, sigInput, cv.Signature)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// opaque certificate_request_context<0..2^8-1>;
|
||||
// Extension extensions<2..2^16-1>;
|
||||
// } CertificateRequest;
|
||||
type CertificateRequestBody struct {
|
||||
CertificateRequestContext []byte `tls:"head=1"`
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (cr CertificateRequestBody) Type() HandshakeType {
|
||||
return HandshakeTypeCertificateRequest
|
||||
}
|
||||
|
||||
func (cr CertificateRequestBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(cr)
|
||||
}
|
||||
|
||||
func (cr *CertificateRequestBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, cr)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// uint32 ticket_lifetime;
|
||||
// uint32 ticket_age_add;
|
||||
// opaque ticket_nonce<1..255>;
|
||||
// opaque ticket<1..2^16-1>;
|
||||
// Extension extensions<0..2^16-2>;
|
||||
// } NewSessionTicket;
|
||||
type NewSessionTicketBody struct {
|
||||
TicketLifetime uint32
|
||||
TicketAgeAdd uint32
|
||||
TicketNonce []byte `tls:"head=1,min=1"`
|
||||
Ticket []byte `tls:"head=2,min=1"`
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
}
|
||||
|
||||
const ticketNonceLen = 16
|
||||
|
||||
func NewSessionTicket(ticketLen int, ticketLifetime uint32) (*NewSessionTicketBody, error) {
|
||||
buf := make([]byte, 4+ticketNonceLen+ticketLen)
|
||||
_, err := prng.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tkt := &NewSessionTicketBody{
|
||||
TicketLifetime: ticketLifetime,
|
||||
TicketAgeAdd: binary.BigEndian.Uint32(buf[:4]),
|
||||
TicketNonce: buf[4 : 4+ticketNonceLen],
|
||||
Ticket: buf[4+ticketNonceLen:],
|
||||
}
|
||||
|
||||
return tkt, err
|
||||
}
|
||||
|
||||
func (tkt NewSessionTicketBody) Type() HandshakeType {
|
||||
return HandshakeTypeNewSessionTicket
|
||||
}
|
||||
|
||||
func (tkt NewSessionTicketBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(tkt)
|
||||
}
|
||||
|
||||
func (tkt *NewSessionTicketBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, tkt)
|
||||
}
|
||||
|
||||
// enum {
|
||||
// update_not_requested(0), update_requested(1), (255)
|
||||
// } KeyUpdateRequest;
|
||||
//
|
||||
// struct {
|
||||
// KeyUpdateRequest request_update;
|
||||
// } KeyUpdate;
|
||||
type KeyUpdateBody struct {
|
||||
KeyUpdateRequest KeyUpdateRequest
|
||||
}
|
||||
|
||||
func (ku KeyUpdateBody) Type() HandshakeType {
|
||||
return HandshakeTypeKeyUpdate
|
||||
}
|
||||
|
||||
func (ku KeyUpdateBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(ku)
|
||||
}
|
||||
|
||||
func (ku *KeyUpdateBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, ku)
|
||||
}
|
||||
|
||||
// struct {} EndOfEarlyData;
|
||||
type EndOfEarlyDataBody struct{}
|
||||
|
||||
func (eoed EndOfEarlyDataBody) Type() HandshakeType {
|
||||
return HandshakeTypeEndOfEarlyData
|
||||
}
|
||||
|
||||
func (eoed EndOfEarlyDataBody) Marshal() ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
func (eoed *EndOfEarlyDataBody) Unmarshal(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
55
vendor/github.com/bifurcation/mint/log.go
generated
vendored
55
vendor/github.com/bifurcation/mint/log.go
generated
vendored
|
@ -1,55 +0,0 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// We use this environment variable to control logging. It should be a
|
||||
// comma-separated list of log tags (see below) or "*" to enable all logging.
|
||||
const logConfigVar = "MINT_LOG"
|
||||
|
||||
// Pre-defined log types
|
||||
const (
|
||||
logTypeCrypto = "crypto"
|
||||
logTypeHandshake = "handshake"
|
||||
logTypeNegotiation = "negotiation"
|
||||
logTypeIO = "io"
|
||||
logTypeFrameReader = "frame"
|
||||
logTypeVerbose = "verbose"
|
||||
)
|
||||
|
||||
var (
|
||||
logFunction = log.Printf
|
||||
logAll = false
|
||||
logSettings = map[string]bool{}
|
||||
)
|
||||
|
||||
func init() {
|
||||
parseLogEnv(os.Environ())
|
||||
}
|
||||
|
||||
func parseLogEnv(env []string) {
|
||||
for _, stmt := range env {
|
||||
if strings.HasPrefix(stmt, logConfigVar+"=") {
|
||||
val := stmt[len(logConfigVar)+1:]
|
||||
|
||||
if val == "*" {
|
||||
logAll = true
|
||||
} else {
|
||||
for _, t := range strings.Split(val, ",") {
|
||||
logSettings[t] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func logf(tag string, format string, args ...interface{}) {
|
||||
if logAll || logSettings[tag] {
|
||||
fullFormat := fmt.Sprintf("[%s] %s", tag, format)
|
||||
logFunction(fullFormat, args...)
|
||||
}
|
||||
}
|
217
vendor/github.com/bifurcation/mint/negotiation.go
generated
vendored
217
vendor/github.com/bifurcation/mint/negotiation.go
generated
vendored
|
@ -1,217 +0,0 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func VersionNegotiation(offered, supported []uint16) (bool, uint16) {
|
||||
for _, offeredVersion := range offered {
|
||||
for _, supportedVersion := range supported {
|
||||
logf(logTypeHandshake, "[server] version offered by client [%04x] <> [%04x]", offeredVersion, supportedVersion)
|
||||
if offeredVersion == supportedVersion {
|
||||
// XXX: Should probably be highest supported version, but for now, we
|
||||
// only support one version, so it doesn't really matter.
|
||||
return true, offeredVersion
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, 0
|
||||
}
|
||||
|
||||
func DHNegotiation(keyShares []KeyShareEntry, groups []NamedGroup) (bool, NamedGroup, []byte, []byte) {
|
||||
for _, share := range keyShares {
|
||||
for _, group := range groups {
|
||||
if group != share.Group {
|
||||
continue
|
||||
}
|
||||
|
||||
pub, priv, err := newKeyShare(share.Group)
|
||||
if err != nil {
|
||||
// If we encounter an error, just keep looking
|
||||
continue
|
||||
}
|
||||
|
||||
dhSecret, err := keyAgreement(share.Group, share.KeyExchange, priv)
|
||||
if err != nil {
|
||||
// If we encounter an error, just keep looking
|
||||
continue
|
||||
}
|
||||
|
||||
return true, group, pub, dhSecret
|
||||
}
|
||||
}
|
||||
|
||||
return false, 0, nil, nil
|
||||
}
|
||||
|
||||
const (
|
||||
ticketAgeTolerance uint32 = 5 * 1000 // five seconds in milliseconds
|
||||
)
|
||||
|
||||
func PSKNegotiation(identities []PSKIdentity, binders []PSKBinderEntry, context []byte, psks PreSharedKeyCache) (bool, int, *PreSharedKey, CipherSuiteParams, error) {
|
||||
logf(logTypeNegotiation, "Negotiating PSK offered=[%d] supported=[%d]", len(identities), psks.Size())
|
||||
for i, id := range identities {
|
||||
identityHex := hex.EncodeToString(id.Identity)
|
||||
|
||||
psk, ok := psks.Get(identityHex)
|
||||
if !ok {
|
||||
logf(logTypeNegotiation, "No PSK for identity %x", identityHex)
|
||||
continue
|
||||
}
|
||||
|
||||
// For resumption, make sure the ticket age is correct
|
||||
if psk.IsResumption {
|
||||
extTicketAge := id.ObfuscatedTicketAge - psk.TicketAgeAdd
|
||||
knownTicketAge := uint32(time.Since(psk.ReceivedAt) / time.Millisecond)
|
||||
ticketAgeDelta := knownTicketAge - extTicketAge
|
||||
if knownTicketAge < extTicketAge {
|
||||
ticketAgeDelta = extTicketAge - knownTicketAge
|
||||
}
|
||||
if ticketAgeDelta > ticketAgeTolerance {
|
||||
logf(logTypeNegotiation, "WARNING potential replay [%x]", psk.Identity)
|
||||
logf(logTypeNegotiation, "Ticket age exceeds tolerance |%d - %d| = [%d] > [%d]",
|
||||
extTicketAge, knownTicketAge, ticketAgeDelta, ticketAgeTolerance)
|
||||
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("WARNING Potential replay for identity %x", psk.Identity)
|
||||
}
|
||||
}
|
||||
|
||||
params, ok := cipherSuiteMap[psk.CipherSuite]
|
||||
if !ok {
|
||||
err := fmt.Errorf("tls.cryptoinit: Unsupported ciphersuite from PSK [%04x]", psk.CipherSuite)
|
||||
return false, 0, nil, CipherSuiteParams{}, err
|
||||
}
|
||||
|
||||
// Compute binder
|
||||
binderLabel := labelExternalBinder
|
||||
if psk.IsResumption {
|
||||
binderLabel = labelResumptionBinder
|
||||
}
|
||||
|
||||
h0 := params.Hash.New().Sum(nil)
|
||||
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||
earlySecret := HkdfExtract(params.Hash, zero, psk.Key)
|
||||
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
|
||||
|
||||
// context = ClientHello[truncated]
|
||||
// context = ClientHello1 + HelloRetryRequest + ClientHello2[truncated]
|
||||
ctxHash := params.Hash.New()
|
||||
ctxHash.Write(context)
|
||||
|
||||
binder := computeFinishedData(params, binderKey, ctxHash.Sum(nil))
|
||||
if !bytes.Equal(binder, binders[i].Binder) {
|
||||
logf(logTypeNegotiation, "Binder check failed for identity %x; [%x] != [%x]", psk.Identity, binder, binders[i].Binder)
|
||||
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("Binder check failed identity %x", psk.Identity)
|
||||
}
|
||||
|
||||
logf(logTypeNegotiation, "Using PSK with identity %x", psk.Identity)
|
||||
return true, i, &psk, params, nil
|
||||
}
|
||||
|
||||
logf(logTypeNegotiation, "Failed to find a usable PSK")
|
||||
return false, 0, nil, CipherSuiteParams{}, nil
|
||||
}
|
||||
|
||||
func PSKModeNegotiation(canDoDH, canDoPSK bool, modes []PSKKeyExchangeMode) (bool, bool) {
|
||||
logf(logTypeNegotiation, "Negotiating PSK modes [%v] [%v] [%+v]", canDoDH, canDoPSK, modes)
|
||||
dhAllowed := false
|
||||
dhRequired := true
|
||||
for _, mode := range modes {
|
||||
dhAllowed = dhAllowed || (mode == PSKModeDHEKE)
|
||||
dhRequired = dhRequired && (mode == PSKModeDHEKE)
|
||||
}
|
||||
|
||||
// Use PSK if we can meet DH requirement and modes were provided
|
||||
usingPSK := canDoPSK && (!dhRequired || canDoDH) && (len(modes) > 0)
|
||||
|
||||
// Use DH if allowed
|
||||
usingDH := canDoDH && (dhAllowed || !usingPSK)
|
||||
|
||||
logf(logTypeNegotiation, "Results of PSK mode negotiation: usingDH=[%v] usingPSK=[%v]", usingDH, usingPSK)
|
||||
return usingDH, usingPSK
|
||||
}
|
||||
|
||||
func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme, certs []*Certificate) (*Certificate, SignatureScheme, error) {
|
||||
// Select for server name if provided
|
||||
candidates := certs
|
||||
if serverName != nil {
|
||||
candidatesByName := []*Certificate{}
|
||||
for _, cert := range certs {
|
||||
for _, name := range cert.Chain[0].DNSNames {
|
||||
if len(*serverName) > 0 && name == *serverName {
|
||||
candidatesByName = append(candidatesByName, cert)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(candidatesByName) == 0 {
|
||||
return nil, 0, fmt.Errorf("No certificates available for server name")
|
||||
}
|
||||
|
||||
candidates = candidatesByName
|
||||
}
|
||||
|
||||
// Select for signature scheme
|
||||
for _, cert := range candidates {
|
||||
for _, scheme := range signatureSchemes {
|
||||
if !schemeValidForKey(scheme, cert.PrivateKey) {
|
||||
continue
|
||||
}
|
||||
|
||||
return cert, scheme, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, 0, fmt.Errorf("No certificates compatible with signature schemes")
|
||||
}
|
||||
|
||||
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool {
|
||||
usingEarlyData := gotEarlyData && usingPSK && allowEarlyData
|
||||
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData)
|
||||
return usingEarlyData
|
||||
}
|
||||
|
||||
func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) {
|
||||
for _, s1 := range offered {
|
||||
if psk != nil {
|
||||
if s1 == psk.CipherSuite {
|
||||
return s1, nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
for _, s2 := range supported {
|
||||
if s1 == s2 {
|
||||
return s1, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("No overlap between offered and supproted ciphersuites (psk? [%v])", psk != nil)
|
||||
}
|
||||
|
||||
func ALPNNegotiation(psk *PreSharedKey, offered, supported []string) (string, error) {
|
||||
for _, p1 := range offered {
|
||||
if psk != nil {
|
||||
if p1 != psk.NextProto {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for _, p2 := range supported {
|
||||
if p1 == p2 {
|
||||
return p1, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the client offers ALPN on resumption, it must match the earlier one
|
||||
var err error
|
||||
if psk != nil && psk.IsResumption && (len(offered) > 0) {
|
||||
err = fmt.Errorf("ALPN for PSK not provided")
|
||||
}
|
||||
return "", err
|
||||
}
|
296
vendor/github.com/bifurcation/mint/record-layer.go
generated
vendored
296
vendor/github.com/bifurcation/mint/record-layer.go
generated
vendored
|
@ -1,296 +0,0 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/cipher"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
sequenceNumberLen = 8 // sequence number length
|
||||
recordHeaderLen = 5 // record header length
|
||||
maxFragmentLen = 1 << 14 // max number of bytes in a record
|
||||
)
|
||||
|
||||
type DecryptError string
|
||||
|
||||
func (err DecryptError) Error() string {
|
||||
return string(err)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ContentType type;
|
||||
// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */
|
||||
// uint16 length;
|
||||
// opaque fragment[TLSPlaintext.length];
|
||||
// } TLSPlaintext;
|
||||
type TLSPlaintext struct {
|
||||
// Omitted: record_version (static)
|
||||
// Omitted: length (computed from fragment)
|
||||
contentType RecordType
|
||||
fragment []byte
|
||||
}
|
||||
|
||||
type RecordLayer struct {
|
||||
sync.Mutex
|
||||
|
||||
conn io.ReadWriter // The underlying connection
|
||||
frame *frameReader // The buffered frame reader
|
||||
nextData []byte // The next record to send
|
||||
cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
|
||||
cachedError error // Error on the last record read
|
||||
|
||||
ivLength int // Length of the seq and nonce fields
|
||||
seq []byte // Zero-padded sequence number
|
||||
nonce []byte // Buffer for per-record nonces
|
||||
cipher cipher.AEAD // AEAD cipher
|
||||
}
|
||||
|
||||
type recordLayerFrameDetails struct{}
|
||||
|
||||
func (d recordLayerFrameDetails) headerLen() int {
|
||||
return recordHeaderLen
|
||||
}
|
||||
|
||||
func (d recordLayerFrameDetails) defaultReadLen() int {
|
||||
return recordHeaderLen + maxFragmentLen
|
||||
}
|
||||
|
||||
func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
||||
return (int(hdr[3]) << 8) | int(hdr[4]), nil
|
||||
}
|
||||
|
||||
func NewRecordLayer(conn io.ReadWriter) *RecordLayer {
|
||||
r := RecordLayer{}
|
||||
r.conn = conn
|
||||
r.frame = newFrameReader(recordLayerFrameDetails{})
|
||||
r.ivLength = 0
|
||||
return &r
|
||||
}
|
||||
|
||||
func (r *RecordLayer) Rekey(cipher aeadFactory, key []byte, iv []byte) error {
|
||||
var err error
|
||||
r.cipher, err = cipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.ivLength = len(iv)
|
||||
r.seq = bytes.Repeat([]byte{0}, r.ivLength)
|
||||
r.nonce = make([]byte, r.ivLength)
|
||||
copy(r.nonce, iv)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RecordLayer) incrementSequenceNumber() {
|
||||
if r.ivLength == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for i := r.ivLength - 1; i > r.ivLength-sequenceNumberLen; i-- {
|
||||
r.seq[i]++
|
||||
r.nonce[i] ^= (r.seq[i] - 1) ^ r.seq[i]
|
||||
if r.seq[i] != 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Not allowed to let sequence number wrap.
|
||||
// Instead, must renegotiate before it does.
|
||||
// Not likely enough to bother.
|
||||
panic("TLS: sequence number wraparound")
|
||||
}
|
||||
|
||||
func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext {
|
||||
// Expand the fragment to hold contentType, padding, and overhead
|
||||
originalLen := len(pt.fragment)
|
||||
plaintextLen := originalLen + 1 + padLen
|
||||
ciphertextLen := plaintextLen + r.cipher.Overhead()
|
||||
|
||||
// Assemble the revised plaintext
|
||||
out := &TLSPlaintext{
|
||||
contentType: RecordTypeApplicationData,
|
||||
fragment: make([]byte, ciphertextLen),
|
||||
}
|
||||
copy(out.fragment, pt.fragment)
|
||||
out.fragment[originalLen] = byte(pt.contentType)
|
||||
for i := 1; i <= padLen; i++ {
|
||||
out.fragment[originalLen+i] = 0
|
||||
}
|
||||
|
||||
// Encrypt the fragment
|
||||
payload := out.fragment[:plaintextLen]
|
||||
r.cipher.Seal(payload[:0], r.nonce, payload, nil)
|
||||
return out
|
||||
}
|
||||
|
||||
func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) {
|
||||
if len(pt.fragment) < r.cipher.Overhead() {
|
||||
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.Overhead())
|
||||
return nil, 0, DecryptError(msg)
|
||||
}
|
||||
|
||||
decryptLen := len(pt.fragment) - r.cipher.Overhead()
|
||||
out := &TLSPlaintext{
|
||||
contentType: pt.contentType,
|
||||
fragment: make([]byte, decryptLen),
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
_, err := r.cipher.Open(out.fragment[:0], r.nonce, pt.fragment, nil)
|
||||
if err != nil {
|
||||
return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
|
||||
}
|
||||
|
||||
// Find the padding boundary
|
||||
padLen := 0
|
||||
for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ {
|
||||
}
|
||||
|
||||
// Transfer the content type
|
||||
newLen := decryptLen - padLen - 1
|
||||
out.contentType = RecordType(out.fragment[newLen])
|
||||
|
||||
// Truncate the message to remove contentType, padding, overhead
|
||||
out.fragment = out.fragment[:newLen]
|
||||
return out, padLen, nil
|
||||
}
|
||||
|
||||
func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
|
||||
var pt *TLSPlaintext
|
||||
var err error
|
||||
|
||||
for {
|
||||
pt, err = r.nextRecord()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if !block || err != WouldBlock {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return pt.contentType, nil
|
||||
}
|
||||
|
||||
func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
|
||||
pt, err := r.nextRecord()
|
||||
|
||||
// Consume the cached record if there was one
|
||||
r.cachedRecord = nil
|
||||
r.cachedError = nil
|
||||
|
||||
return pt, err
|
||||
}
|
||||
|
||||
func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
|
||||
if r.cachedRecord != nil {
|
||||
logf(logTypeIO, "Returning cached record")
|
||||
return r.cachedRecord, r.cachedError
|
||||
}
|
||||
|
||||
// Loop until one of three things happens:
|
||||
//
|
||||
// 1. We get a frame
|
||||
// 2. We try to read off the socket and get nothing, in which case
|
||||
// return WouldBlock
|
||||
// 3. We get an error.
|
||||
err := WouldBlock
|
||||
var header, body []byte
|
||||
|
||||
for err != nil {
|
||||
if r.frame.needed() > 0 {
|
||||
buf := make([]byte, recordHeaderLen+maxFragmentLen)
|
||||
n, err := r.conn.Read(buf)
|
||||
if err != nil {
|
||||
logf(logTypeIO, "Error reading, %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return nil, WouldBlock
|
||||
}
|
||||
|
||||
logf(logTypeIO, "Read %v bytes", n)
|
||||
|
||||
buf = buf[:n]
|
||||
r.frame.addChunk(buf)
|
||||
}
|
||||
|
||||
header, body, err = r.frame.process()
|
||||
// Loop around on WouldBlock to see if some
|
||||
// data is now available.
|
||||
if err != nil && err != WouldBlock {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
pt := &TLSPlaintext{}
|
||||
// Validate content type
|
||||
switch RecordType(header[0]) {
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
|
||||
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData:
|
||||
pt.contentType = RecordType(header[0])
|
||||
}
|
||||
|
||||
// Validate version
|
||||
if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) {
|
||||
return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2])
|
||||
}
|
||||
|
||||
// Validate size < max
|
||||
size := (int(header[3]) << 8) + int(header[4])
|
||||
if size > maxFragmentLen+256 {
|
||||
return nil, fmt.Errorf("tls.record: Ciphertext size too big")
|
||||
}
|
||||
|
||||
pt.fragment = make([]byte, size)
|
||||
copy(pt.fragment, body)
|
||||
|
||||
// Attempt to decrypt fragment
|
||||
if r.cipher != nil {
|
||||
pt, _, err = r.decrypt(pt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Check that plaintext length is not too long
|
||||
if len(pt.fragment) > maxFragmentLen {
|
||||
return nil, fmt.Errorf("tls.record: Plaintext size too big")
|
||||
}
|
||||
|
||||
logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment)
|
||||
|
||||
r.cachedRecord = pt
|
||||
r.incrementSequenceNumber()
|
||||
return pt, nil
|
||||
}
|
||||
|
||||
func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error {
|
||||
return r.WriteRecordWithPadding(pt, 0)
|
||||
}
|
||||
|
||||
func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
|
||||
if r.cipher != nil {
|
||||
pt = r.encrypt(pt, padLen)
|
||||
} else if padLen > 0 {
|
||||
return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
|
||||
}
|
||||
|
||||
if len(pt.fragment) > maxFragmentLen {
|
||||
return fmt.Errorf("tls.record: Record size too big")
|
||||
}
|
||||
|
||||
length := len(pt.fragment)
|
||||
header := []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)}
|
||||
record := append(header, pt.fragment...)
|
||||
|
||||
logf(logTypeIO, "RecordLayer.WriteRecord [%d] [%x]", pt.contentType, pt.fragment)
|
||||
|
||||
r.incrementSequenceNumber()
|
||||
_, err := r.conn.Write(record)
|
||||
return err
|
||||
}
|
898
vendor/github.com/bifurcation/mint/server-state-machine.go
generated
vendored
898
vendor/github.com/bifurcation/mint/server-state-machine.go
generated
vendored
|
@ -1,898 +0,0 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"hash"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Server State Machine
|
||||
//
|
||||
// START <-----+
|
||||
// Recv ClientHello | | Send HelloRetryRequest
|
||||
// v |
|
||||
// RECVD_CH ----+
|
||||
// | Select parameters
|
||||
// | Send ServerHello
|
||||
// v
|
||||
// NEGOTIATED
|
||||
// | Send EncryptedExtensions
|
||||
// | [Send CertificateRequest]
|
||||
// Can send | [Send Certificate + CertificateVerify]
|
||||
// app data --> | Send Finished
|
||||
// after +--------+--------+
|
||||
// here No 0-RTT | | 0-RTT
|
||||
// | v
|
||||
// | WAIT_EOED <---+
|
||||
// | Recv | | | Recv
|
||||
// | EndOfEarlyData | | | early data
|
||||
// | | +-----+
|
||||
// +> WAIT_FLIGHT2 <-+
|
||||
// |
|
||||
// +--------+--------+
|
||||
// No auth | | Client auth
|
||||
// | |
|
||||
// | v
|
||||
// | WAIT_CERT
|
||||
// | Recv | | Recv Certificate
|
||||
// | empty | v
|
||||
// | Certificate | WAIT_CV
|
||||
// | | | Recv
|
||||
// | v | CertificateVerify
|
||||
// +-> WAIT_FINISHED <---+
|
||||
// | Recv Finished
|
||||
// v
|
||||
// CONNECTED
|
||||
//
|
||||
// NB: Not using state RECVD_CH
|
||||
//
|
||||
// State Instructions
|
||||
// START {}
|
||||
// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)]
|
||||
// WAIT_EOED RekeyIn;
|
||||
// WAIT_FLIGHT2 {}
|
||||
// WAIT_CERT_CR {}
|
||||
// WAIT_CERT {}
|
||||
// WAIT_CV {}
|
||||
// WAIT_FINISHED RekeyIn; RekeyOut;
|
||||
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
|
||||
|
||||
type ServerStateStart struct {
|
||||
Caps Capabilities
|
||||
conn *Conn
|
||||
|
||||
cookieSent bool
|
||||
firstClientHello *HandshakeMessage
|
||||
helloRetryRequest *HandshakeMessage
|
||||
}
|
||||
|
||||
func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeClientHello {
|
||||
logf(logTypeHandshake, "[ServerStateStart] unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
ch := &ClientHelloBody{}
|
||||
_, err := ch.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
clientHello := hm
|
||||
connParams := ConnectionParameters{}
|
||||
|
||||
supportedVersions := new(SupportedVersionsExtension)
|
||||
serverName := new(ServerNameExtension)
|
||||
supportedGroups := new(SupportedGroupsExtension)
|
||||
signatureAlgorithms := new(SignatureAlgorithmsExtension)
|
||||
clientKeyShares := &KeyShareExtension{HandshakeType: HandshakeTypeClientHello}
|
||||
clientPSK := &PreSharedKeyExtension{HandshakeType: HandshakeTypeClientHello}
|
||||
clientEarlyData := &EarlyDataExtension{}
|
||||
clientALPN := new(ALPNExtension)
|
||||
clientPSKModes := new(PSKKeyExchangeModesExtension)
|
||||
clientCookie := new(CookieExtension)
|
||||
|
||||
// Handle external extensions.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error running external extension handler [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
gotSupportedVersions := ch.Extensions.Find(supportedVersions)
|
||||
gotServerName := ch.Extensions.Find(serverName)
|
||||
gotSupportedGroups := ch.Extensions.Find(supportedGroups)
|
||||
gotSignatureAlgorithms := ch.Extensions.Find(signatureAlgorithms)
|
||||
gotEarlyData := ch.Extensions.Find(clientEarlyData)
|
||||
ch.Extensions.Find(clientKeyShares)
|
||||
ch.Extensions.Find(clientPSK)
|
||||
ch.Extensions.Find(clientALPN)
|
||||
ch.Extensions.Find(clientPSKModes)
|
||||
ch.Extensions.Find(clientCookie)
|
||||
|
||||
if gotServerName {
|
||||
connParams.ServerName = string(*serverName)
|
||||
}
|
||||
|
||||
// If the client didn't send supportedVersions or doesn't support 1.3,
|
||||
// then we're done here.
|
||||
if !gotSupportedVersions {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions")
|
||||
return nil, nil, AlertProtocolVersion
|
||||
}
|
||||
versionOK, _ := VersionNegotiation(supportedVersions.Versions, []uint16{supportedVersion})
|
||||
if !versionOK {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Client does not support the same version")
|
||||
return nil, nil, AlertProtocolVersion
|
||||
}
|
||||
|
||||
if state.Caps.RequireCookie && state.cookieSent && !state.Caps.CookieHandler.Validate(state.conn, clientCookie.Cookie) {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch")
|
||||
return nil, nil, AlertAccessDenied
|
||||
}
|
||||
|
||||
// Figure out if we can do DH
|
||||
canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups)
|
||||
|
||||
// Figure out if we can do PSK
|
||||
canDoPSK := false
|
||||
var selectedPSK int
|
||||
var psk *PreSharedKey
|
||||
var params CipherSuiteParams
|
||||
if len(clientPSK.Identities) > 0 {
|
||||
contextBase := []byte{}
|
||||
if state.helloRetryRequest != nil {
|
||||
chBytes := state.firstClientHello.Marshal()
|
||||
hrrBytes := state.helloRetryRequest.Marshal()
|
||||
contextBase = append(chBytes, hrrBytes...)
|
||||
}
|
||||
|
||||
chTrunc, err := ch.Truncated()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error computing truncated ClientHello [%v]", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
context := append(contextBase, chTrunc...)
|
||||
|
||||
canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Caps.PSKs)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Figure out if we actually should do DH / PSK
|
||||
connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes)
|
||||
|
||||
// Select a ciphersuite
|
||||
connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Send a cookie if required
|
||||
// NB: Need to do this here because it's after ciphersuite selection, which
|
||||
// has to be after PSK selection.
|
||||
// XXX: Doing this statefully for now, could be stateless
|
||||
var cookieData []byte
|
||||
if state.Caps.RequireCookie && !state.cookieSent {
|
||||
var err error
|
||||
cookieData, err = state.Caps.CookieHandler.Generate(state.conn)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
if cookieData != nil {
|
||||
// Ignoring errors because everything here is newly constructed, so there
|
||||
// shouldn't be marshal errors
|
||||
hrr := &HelloRetryRequestBody{
|
||||
Version: supportedVersion,
|
||||
CipherSuite: connParams.CipherSuite,
|
||||
}
|
||||
hrr.Extensions.Add(&CookieExtension{Cookie: cookieData})
|
||||
|
||||
// Run the external extension handler.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
helloRetryRequest, err := HandshakeMessageFromBody(hrr)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error marshaling HRR [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
params := cipherSuiteMap[connParams.CipherSuite]
|
||||
h := params.Hash.New()
|
||||
h.Write(clientHello.Marshal())
|
||||
firstClientHello := &HandshakeMessage{
|
||||
msgType: HandshakeTypeMessageHash,
|
||||
body: h.Sum(nil),
|
||||
}
|
||||
|
||||
nextState := ServerStateStart{
|
||||
Caps: state.Caps,
|
||||
conn: state.conn,
|
||||
cookieSent: true,
|
||||
firstClientHello: firstClientHello,
|
||||
helloRetryRequest: helloRetryRequest,
|
||||
}
|
||||
toSend := []HandshakeAction{SendHandshakeMessage{helloRetryRequest}}
|
||||
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateStart]")
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
// If we've got no entropy to make keys from, fail
|
||||
if !connParams.UsingDH && !connParams.UsingPSK {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Neither DH nor PSK negotiated")
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
var pskSecret []byte
|
||||
var cert *Certificate
|
||||
var certScheme SignatureScheme
|
||||
if connParams.UsingPSK {
|
||||
pskSecret = psk.Key
|
||||
} else {
|
||||
psk = nil
|
||||
|
||||
// If we're not using a PSK mode, then we need to have certain extensions
|
||||
if !gotServerName || !gotSupportedGroups || !gotSignatureAlgorithms {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v %v %v)",
|
||||
gotServerName, gotSupportedGroups, gotSignatureAlgorithms)
|
||||
return nil, nil, AlertMissingExtension
|
||||
}
|
||||
|
||||
// Select a certificate
|
||||
name := string(*serverName)
|
||||
var err error
|
||||
cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Caps.Certificates)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] No appropriate certificate found [%v]", err)
|
||||
return nil, nil, AlertAccessDenied
|
||||
}
|
||||
}
|
||||
|
||||
if !connParams.UsingDH {
|
||||
dhSecret = nil
|
||||
}
|
||||
|
||||
// Figure out if we're going to do early data
|
||||
var clientEarlyTrafficSecret []byte
|
||||
connParams.ClientSendingEarlyData = gotEarlyData
|
||||
connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, gotEarlyData, state.Caps.AllowEarlyData)
|
||||
if connParams.UsingEarlyData {
|
||||
|
||||
h := params.Hash.New()
|
||||
h.Write(clientHello.Marshal())
|
||||
chHash := h.Sum(nil)
|
||||
|
||||
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||
earlySecret := HkdfExtract(params.Hash, zero, pskSecret)
|
||||
clientEarlyTrafficSecret = deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
|
||||
}
|
||||
|
||||
// Select a next protocol
|
||||
connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Caps.NextProtos)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] No common application-layer protocol found [%v]", err)
|
||||
return nil, nil, AlertNoApplicationProtocol
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]")
|
||||
return ServerStateNegotiated{
|
||||
Caps: state.Caps,
|
||||
Params: connParams,
|
||||
|
||||
dhGroup: dhGroup,
|
||||
dhPublic: dhPublic,
|
||||
dhSecret: dhSecret,
|
||||
pskSecret: pskSecret,
|
||||
selectedPSK: selectedPSK,
|
||||
cert: cert,
|
||||
certScheme: certScheme,
|
||||
clientEarlyTrafficSecret: clientEarlyTrafficSecret,
|
||||
|
||||
firstClientHello: state.firstClientHello,
|
||||
helloRetryRequest: state.helloRetryRequest,
|
||||
clientHello: clientHello,
|
||||
}.Next(nil)
|
||||
}
|
||||
|
||||
type ServerStateNegotiated struct {
|
||||
Caps Capabilities
|
||||
Params ConnectionParameters
|
||||
|
||||
dhGroup NamedGroup
|
||||
dhPublic []byte
|
||||
dhSecret []byte
|
||||
pskSecret []byte
|
||||
clientEarlyTrafficSecret []byte
|
||||
selectedPSK int
|
||||
cert *Certificate
|
||||
certScheme SignatureScheme
|
||||
|
||||
firstClientHello *HandshakeMessage
|
||||
helloRetryRequest *HandshakeMessage
|
||||
clientHello *HandshakeMessage
|
||||
}
|
||||
|
||||
func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
// Create the ServerHello
|
||||
sh := &ServerHelloBody{
|
||||
Version: supportedVersion,
|
||||
CipherSuite: state.Params.CipherSuite,
|
||||
}
|
||||
_, err := prng.Read(sh.Random[:])
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
if state.Params.UsingDH {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension")
|
||||
err = sh.Extensions.Add(&KeyShareExtension{
|
||||
HandshakeType: HandshakeTypeServerHello,
|
||||
Shares: []KeyShareEntry{{Group: state.dhGroup, KeyExchange: state.dhPublic}},
|
||||
})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding key_shares extension [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
if state.Params.UsingPSK {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] sending PSK extension")
|
||||
err = sh.Extensions.Add(&PreSharedKeyExtension{
|
||||
HandshakeType: HandshakeTypeServerHello,
|
||||
SelectedIdentity: uint16(state.selectedPSK),
|
||||
})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding PSK extension [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Run the external extension handler.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
serverHello, err := HandshakeMessageFromBody(sh)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling ServerHello [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
// Look up crypto params
|
||||
params, ok := cipherSuiteMap[sh.CipherSuite]
|
||||
if !ok {
|
||||
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", sh.CipherSuite)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Start up the handshake hash
|
||||
handshakeHash := params.Hash.New()
|
||||
handshakeHash.Write(state.firstClientHello.Marshal())
|
||||
handshakeHash.Write(state.helloRetryRequest.Marshal())
|
||||
handshakeHash.Write(state.clientHello.Marshal())
|
||||
handshakeHash.Write(serverHello.Marshal())
|
||||
|
||||
// Compute handshake secrets
|
||||
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||
|
||||
var earlySecret []byte
|
||||
if state.Params.UsingPSK {
|
||||
earlySecret = HkdfExtract(params.Hash, zero, state.pskSecret)
|
||||
} else {
|
||||
earlySecret = HkdfExtract(params.Hash, zero, zero)
|
||||
}
|
||||
|
||||
if state.dhSecret == nil {
|
||||
state.dhSecret = zero
|
||||
}
|
||||
|
||||
h0 := params.Hash.New().Sum(nil)
|
||||
h2 := handshakeHash.Sum(nil)
|
||||
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
|
||||
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, state.dhSecret)
|
||||
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
|
||||
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
|
||||
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
|
||||
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
|
||||
|
||||
logf(logTypeCrypto, "early secret (init!): [%d] %x", len(earlySecret), earlySecret)
|
||||
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
|
||||
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
|
||||
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
|
||||
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
|
||||
|
||||
clientHandshakeKeys := makeTrafficKeys(params, clientHandshakeTrafficSecret)
|
||||
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
|
||||
|
||||
// Send an EncryptedExtensions message (even if it's empty)
|
||||
eeList := ExtensionList{}
|
||||
if state.Params.NextProto != "" {
|
||||
logf(logTypeHandshake, "[server] sending ALPN extension")
|
||||
err = eeList.Add(&ALPNExtension{Protocols: []string{state.Params.NextProto}})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding ALPN to EncryptedExtensions [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
if state.Params.UsingEarlyData {
|
||||
logf(logTypeHandshake, "[server] sending EDI extension")
|
||||
err = eeList.Add(&EarlyDataExtension{})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding EDI to EncryptedExtensions [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
ee := &EncryptedExtensionsBody{eeList}
|
||||
|
||||
// Run the external extension handler.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
eem, err := HandshakeMessageFromBody(ee)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling EncryptedExtensions [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
handshakeHash.Write(eem.Marshal())
|
||||
|
||||
toSend := []HandshakeAction{
|
||||
SendHandshakeMessage{serverHello},
|
||||
RekeyOut{Label: "handshake", KeySet: serverHandshakeKeys},
|
||||
SendHandshakeMessage{eem},
|
||||
}
|
||||
|
||||
// Authenticate with a certificate if required
|
||||
if !state.Params.UsingPSK {
|
||||
// Send a CertificateRequest message if we want client auth
|
||||
if state.Caps.RequireClientAuth {
|
||||
state.Params.UsingClientAuth = true
|
||||
|
||||
// XXX: We don't support sending any constraints besides a list of
|
||||
// supported signature algorithms
|
||||
cr := &CertificateRequestBody{}
|
||||
schemes := &SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
|
||||
err := cr.Extensions.Add(schemes)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported schemes to CertificateRequest [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
crm, err := HandshakeMessageFromBody(cr)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateRequest [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
//TODO state.state.serverCertificateRequest = cr
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{crm})
|
||||
handshakeHash.Write(crm.Marshal())
|
||||
}
|
||||
|
||||
// Create and send Certificate, CertificateVerify
|
||||
certificate := &CertificateBody{
|
||||
CertificateList: make([]CertificateEntry, len(state.cert.Chain)),
|
||||
}
|
||||
for i, entry := range state.cert.Chain {
|
||||
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
|
||||
}
|
||||
certm, err := HandshakeMessageFromBody(certificate)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{certm})
|
||||
handshakeHash.Write(certm.Marshal())
|
||||
|
||||
certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme}
|
||||
logf(logTypeHandshake, "Creating CertVerify: %04x %v", state.certScheme, params.Hash)
|
||||
|
||||
hcv := handshakeHash.Sum(nil)
|
||||
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||
|
||||
err = certificateVerify.Sign(state.cert.PrivateKey, hcv)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error signing CertificateVerify [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
certvm, err := HandshakeMessageFromBody(certificateVerify)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateVerify [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{certvm})
|
||||
handshakeHash.Write(certvm.Marshal())
|
||||
}
|
||||
|
||||
// Compute secrets resulting from the server's first flight
|
||||
h3 := handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3)
|
||||
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3)
|
||||
|
||||
serverFinishedData := computeFinishedData(params, serverHandshakeTrafficSecret, h3)
|
||||
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
|
||||
|
||||
// Assemble the Finished message
|
||||
fin := &FinishedBody{
|
||||
VerifyDataLen: len(serverFinishedData),
|
||||
VerifyData: serverFinishedData,
|
||||
}
|
||||
finm, _ := HandshakeMessageFromBody(fin)
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{finm})
|
||||
handshakeHash.Write(finm.Marshal())
|
||||
|
||||
// Compute traffic secrets
|
||||
h4 := handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash 4 [%d] %x", len(h4), h4)
|
||||
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h4), h4)
|
||||
|
||||
clientTrafficSecret := deriveSecret(params, masterSecret, labelClientApplicationTrafficSecret, h4)
|
||||
serverTrafficSecret := deriveSecret(params, masterSecret, labelServerApplicationTrafficSecret, h4)
|
||||
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret)
|
||||
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret)
|
||||
|
||||
serverTrafficKeys := makeTrafficKeys(params, serverTrafficSecret)
|
||||
toSend = append(toSend, RekeyOut{Label: "application", KeySet: serverTrafficKeys})
|
||||
|
||||
exporterSecret := deriveSecret(params, masterSecret, labelExporterSecret, h4)
|
||||
logf(logTypeCrypto, "server exporter secret: [%d] %x", len(exporterSecret), exporterSecret)
|
||||
|
||||
if state.Params.UsingEarlyData {
|
||||
clientEarlyTrafficKeys := makeTrafficKeys(params, state.clientEarlyTrafficSecret)
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitEOED]")
|
||||
nextState := ServerStateWaitEOED{
|
||||
AuthCertificate: state.Caps.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: params,
|
||||
handshakeHash: handshakeHash,
|
||||
masterSecret: masterSecret,
|
||||
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
||||
clientTrafficSecret: clientTrafficSecret,
|
||||
serverTrafficSecret: serverTrafficSecret,
|
||||
exporterSecret: exporterSecret,
|
||||
}
|
||||
toSend = append(toSend, []HandshakeAction{
|
||||
RekeyIn{Label: "early", KeySet: clientEarlyTrafficKeys},
|
||||
ReadEarlyData{},
|
||||
}...)
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]")
|
||||
toSend = append(toSend, []HandshakeAction{
|
||||
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys},
|
||||
ReadPastEarlyData{},
|
||||
}...)
|
||||
waitFlight2 := ServerStateWaitFlight2{
|
||||
AuthCertificate: state.Caps.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: params,
|
||||
handshakeHash: handshakeHash,
|
||||
masterSecret: masterSecret,
|
||||
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
||||
clientTrafficSecret: clientTrafficSecret,
|
||||
serverTrafficSecret: serverTrafficSecret,
|
||||
exporterSecret: exporterSecret,
|
||||
}
|
||||
nextState, moreToSend, alert := waitFlight2.Next(nil)
|
||||
toSend = append(toSend, moreToSend...)
|
||||
return nextState, toSend, alert
|
||||
}
|
||||
|
||||
type ServerStateWaitEOED struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
handshakeHash hash.Hash
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
}
|
||||
|
||||
func (state ServerStateWaitEOED) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeEndOfEarlyData {
|
||||
logf(logTypeHandshake, "[ServerStateWaitEOED] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
if len(hm.body) > 0 {
|
||||
logf(logTypeHandshake, "[ServerStateWaitEOED] Error decoding message [len > 0]")
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitEOED] -> [ServerStateWaitFlight2]")
|
||||
toSend := []HandshakeAction{
|
||||
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys},
|
||||
}
|
||||
waitFlight2 := ServerStateWaitFlight2{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
nextState, moreToSend, alert := waitFlight2.Next(nil)
|
||||
toSend = append(toSend, moreToSend...)
|
||||
return nextState, toSend, alert
|
||||
}
|
||||
|
||||
type ServerStateWaitFlight2 struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
handshakeHash hash.Hash
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
}
|
||||
|
||||
func (state ServerStateWaitFlight2) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitFlight2] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
if state.Params.UsingClientAuth {
|
||||
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitCert]")
|
||||
nextState := ServerStateWaitCert{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitFinished]")
|
||||
nextState := ServerStateWaitFinished{
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
handshakeHash: state.handshakeHash,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ServerStateWaitCert struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
handshakeHash hash.Hash
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
}
|
||||
|
||||
func (state ServerStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeCertificate {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
cert := &CertificateBody{}
|
||||
_, err := cert.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
if len(cert.CertificateList) == 0 {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] WARNING client did not provide a certificate")
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitFinished]")
|
||||
nextState := ServerStateWaitFinished{
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
handshakeHash: state.handshakeHash,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitCV]")
|
||||
nextState := ServerStateWaitCV{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
handshakeHash: state.handshakeHash,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
clientCertificate: cert,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ServerStateWaitCV struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
|
||||
handshakeHash hash.Hash
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
|
||||
clientCertificate *CertificateBody
|
||||
}
|
||||
|
||||
func (state ServerStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCV] Unexpected message [%+v] [%s]", hm, reflect.TypeOf(hm))
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
certVerify := &CertificateVerifyBody{}
|
||||
_, err := certVerify.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
// Verify client signature over handshake hash
|
||||
hcv := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||
|
||||
clientPublicKey := state.clientCertificate.CertificateList[0].CertData.PublicKey
|
||||
if err := certVerify.Verify(clientPublicKey, hcv); err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCV] Failure in client auth verification [%v]", err)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
if state.AuthCertificate != nil {
|
||||
err := state.AuthCertificate(state.clientCertificate.CertificateList)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate")
|
||||
return nil, nil, AlertBadCertificate
|
||||
}
|
||||
} else {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCV] WARNING: No verification of client certificate")
|
||||
}
|
||||
|
||||
// If it passes, record the certificateVerify in the transcript hash
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitCV] -> [ServerStateWaitFinished]")
|
||||
nextState := ServerStateWaitFinished{
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
handshakeHash: state.handshakeHash,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ServerStateWaitFinished struct {
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
|
||||
handshakeHash hash.Hash
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
}
|
||||
|
||||
func (state ServerStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeFinished {
|
||||
logf(logTypeHandshake, "[ServerStateWaitFinished] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()}
|
||||
_, err := fin.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
// Verify client Finished data
|
||||
h5 := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5)
|
||||
|
||||
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5)
|
||||
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData)
|
||||
|
||||
if !bytes.Equal(fin.VerifyData, clientFinishedData) {
|
||||
logf(logTypeHandshake, "[ServerStateWaitFinished] Client's Finished failed to verify")
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Compute the resumption secret
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
h6 := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash 6 [%d]: %x", len(h6), h6)
|
||||
|
||||
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6)
|
||||
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
|
||||
|
||||
// Compute client traffic keys
|
||||
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]")
|
||||
nextState := StateConnected{
|
||||
Params: state.Params,
|
||||
isClient: false,
|
||||
cryptoParams: state.cryptoParams,
|
||||
resumptionSecret: resumptionSecret,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
toSend := []HandshakeAction{
|
||||
RekeyIn{Label: "application", KeySet: clientTrafficKeys},
|
||||
}
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
230
vendor/github.com/bifurcation/mint/state-machine.go
generated
vendored
230
vendor/github.com/bifurcation/mint/state-machine.go
generated
vendored
|
@ -1,230 +0,0 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Marker interface for actions that an implementation should take based on
|
||||
// state transitions.
|
||||
type HandshakeAction interface{}
|
||||
|
||||
type SendHandshakeMessage struct {
|
||||
Message *HandshakeMessage
|
||||
}
|
||||
|
||||
type SendEarlyData struct{}
|
||||
|
||||
type ReadEarlyData struct{}
|
||||
|
||||
type ReadPastEarlyData struct{}
|
||||
|
||||
type RekeyIn struct {
|
||||
Label string
|
||||
KeySet keySet
|
||||
}
|
||||
|
||||
type RekeyOut struct {
|
||||
Label string
|
||||
KeySet keySet
|
||||
}
|
||||
|
||||
type StorePSK struct {
|
||||
PSK PreSharedKey
|
||||
}
|
||||
|
||||
type HandshakeState interface {
|
||||
Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert)
|
||||
}
|
||||
|
||||
type AppExtensionHandler interface {
|
||||
Send(hs HandshakeType, el *ExtensionList) error
|
||||
Receive(hs HandshakeType, el *ExtensionList) error
|
||||
}
|
||||
|
||||
// Capabilities objects represent the capabilities of a TLS client or server,
|
||||
// as an input to TLS negotiation
|
||||
type Capabilities struct {
|
||||
// For both client and server
|
||||
CipherSuites []CipherSuite
|
||||
Groups []NamedGroup
|
||||
SignatureSchemes []SignatureScheme
|
||||
PSKs PreSharedKeyCache
|
||||
Certificates []*Certificate
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
ExtensionHandler AppExtensionHandler
|
||||
|
||||
// For client
|
||||
PSKModes []PSKKeyExchangeMode
|
||||
|
||||
// For server
|
||||
NextProtos []string
|
||||
AllowEarlyData bool
|
||||
RequireCookie bool
|
||||
CookieHandler CookieHandler
|
||||
RequireClientAuth bool
|
||||
}
|
||||
|
||||
// ConnectionOptions objects represent per-connection settings for a client
|
||||
// initiating a connection
|
||||
type ConnectionOptions struct {
|
||||
ServerName string
|
||||
NextProtos []string
|
||||
EarlyData []byte
|
||||
}
|
||||
|
||||
// ConnectionParameters objects represent the parameters negotiated for a
|
||||
// connection.
|
||||
type ConnectionParameters struct {
|
||||
UsingPSK bool
|
||||
UsingDH bool
|
||||
ClientSendingEarlyData bool
|
||||
UsingEarlyData bool
|
||||
UsingClientAuth bool
|
||||
|
||||
CipherSuite CipherSuite
|
||||
ServerName string
|
||||
NextProto string
|
||||
}
|
||||
|
||||
// StateConnected is symmetric between client and server
|
||||
type StateConnected struct {
|
||||
Params ConnectionParameters
|
||||
isClient bool
|
||||
cryptoParams CipherSuiteParams
|
||||
resumptionSecret []byte
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
}
|
||||
|
||||
func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) {
|
||||
var trafficKeys keySet
|
||||
if state.isClient {
|
||||
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
|
||||
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
|
||||
} else {
|
||||
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
|
||||
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
|
||||
}
|
||||
|
||||
kum, err := HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err)
|
||||
return nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{
|
||||
SendHandshakeMessage{kum},
|
||||
RekeyOut{Label: "update", KeySet: trafficKeys},
|
||||
}
|
||||
return toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) {
|
||||
tkt, err := NewSessionTicket(length, lifetime)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err)
|
||||
return nil, AlertInternalError
|
||||
}
|
||||
|
||||
err = tkt.Extensions.Add(&TicketEarlyDataInfoExtension{earlyDataLifetime})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error adding extension to NewSessionTicket: %v", err)
|
||||
return nil, AlertInternalError
|
||||
}
|
||||
|
||||
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
|
||||
labelResumption, tkt.TicketNonce, state.cryptoParams.Hash.Size())
|
||||
|
||||
newPSK := PreSharedKey{
|
||||
CipherSuite: state.cryptoParams.Suite,
|
||||
IsResumption: true,
|
||||
Identity: tkt.Ticket,
|
||||
Key: resumptionKey,
|
||||
NextProto: state.Params.NextProto,
|
||||
ReceivedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Duration(tkt.TicketLifetime) * time.Second),
|
||||
TicketAgeAdd: tkt.TicketAgeAdd,
|
||||
}
|
||||
|
||||
tktm, err := HandshakeMessageFromBody(tkt)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err)
|
||||
return nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{
|
||||
StorePSK{newPSK},
|
||||
SendHandshakeMessage{tktm},
|
||||
}
|
||||
return toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
bodyGeneric, err := hm.ToBody()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
switch body := bodyGeneric.(type) {
|
||||
case *KeyUpdateBody:
|
||||
var trafficKeys keySet
|
||||
if !state.isClient {
|
||||
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
|
||||
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
|
||||
} else {
|
||||
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
|
||||
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{RekeyIn{Label: "update", KeySet: trafficKeys}}
|
||||
|
||||
// If requested, roll outbound keys and send a KeyUpdate
|
||||
if body.KeyUpdateRequest == KeyUpdateRequested {
|
||||
moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested)
|
||||
if alert != AlertNoAlert {
|
||||
return nil, nil, alert
|
||||
}
|
||||
|
||||
toSend = append(toSend, moreToSend...)
|
||||
}
|
||||
|
||||
return state, toSend, AlertNoAlert
|
||||
|
||||
case *NewSessionTicketBody:
|
||||
// XXX: Allow NewSessionTicket in both directions?
|
||||
if !state.isClient {
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
|
||||
labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size())
|
||||
|
||||
psk := PreSharedKey{
|
||||
CipherSuite: state.cryptoParams.Suite,
|
||||
IsResumption: true,
|
||||
Identity: body.Ticket,
|
||||
Key: resumptionKey,
|
||||
NextProto: state.Params.NextProto,
|
||||
ReceivedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Duration(body.TicketLifetime) * time.Second),
|
||||
TicketAgeAdd: body.TicketAgeAdd,
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{StorePSK{psk}}
|
||||
return state, toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[StateConnected] Unexpected message type %v", hm.msgType)
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
243
vendor/github.com/bifurcation/mint/syntax/decode.go
generated
vendored
243
vendor/github.com/bifurcation/mint/syntax/decode.go
generated
vendored
|
@ -1,243 +0,0 @@
|
|||
package syntax
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
func Unmarshal(data []byte, v interface{}) (int, error) {
|
||||
// Check for well-formedness.
|
||||
// Avoids filling out half a data structure
|
||||
// before discovering a JSON syntax error.
|
||||
d := decodeState{}
|
||||
d.Write(data)
|
||||
return d.unmarshal(v)
|
||||
}
|
||||
|
||||
// These are the options that can be specified in the struct tag. Right now,
|
||||
// all of them apply to variable-length vectors and nothing else
|
||||
type decOpts struct {
|
||||
head uint // length of length in bytes
|
||||
min uint // minimum size in bytes
|
||||
max uint // maximum size in bytes
|
||||
}
|
||||
|
||||
type decodeState struct {
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (d *decodeState) unmarshal(v interface{}) (read int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if _, ok := r.(runtime.Error); ok {
|
||||
panic(r)
|
||||
}
|
||||
if s, ok := r.(string); ok {
|
||||
panic(s)
|
||||
}
|
||||
err = r.(error)
|
||||
}
|
||||
}()
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() != reflect.Ptr || rv.IsNil() {
|
||||
return 0, fmt.Errorf("Invalid unmarshal target (non-pointer or nil)")
|
||||
}
|
||||
|
||||
read = d.value(rv)
|
||||
return read, nil
|
||||
}
|
||||
|
||||
func (e *decodeState) value(v reflect.Value) int {
|
||||
return valueDecoder(v)(e, v, decOpts{})
|
||||
}
|
||||
|
||||
type decoderFunc func(e *decodeState, v reflect.Value, opts decOpts) int
|
||||
|
||||
func valueDecoder(v reflect.Value) decoderFunc {
|
||||
return typeDecoder(v.Type().Elem())
|
||||
}
|
||||
|
||||
func typeDecoder(t reflect.Type) decoderFunc {
|
||||
// Note: Omits the caching / wait-group things that encoding/json uses
|
||||
return newTypeDecoder(t)
|
||||
}
|
||||
|
||||
func newTypeDecoder(t reflect.Type) decoderFunc {
|
||||
// Note: Does not support Marshaler, so don't need the allowAddr argument
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return uintDecoder
|
||||
case reflect.Array:
|
||||
return newArrayDecoder(t)
|
||||
case reflect.Slice:
|
||||
return newSliceDecoder(t)
|
||||
case reflect.Struct:
|
||||
return newStructDecoder(t)
|
||||
default:
|
||||
panic(fmt.Errorf("Unsupported type (%s)", t))
|
||||
}
|
||||
}
|
||||
|
||||
///// Specific decoders below
|
||||
|
||||
func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
var uintLen int
|
||||
switch v.Elem().Kind() {
|
||||
case reflect.Uint8:
|
||||
uintLen = 1
|
||||
case reflect.Uint16:
|
||||
uintLen = 2
|
||||
case reflect.Uint32:
|
||||
uintLen = 4
|
||||
case reflect.Uint64:
|
||||
uintLen = 8
|
||||
}
|
||||
|
||||
buf := make([]byte, uintLen)
|
||||
n, err := d.Read(buf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if n != uintLen {
|
||||
panic(fmt.Errorf("Insufficient data to read uint"))
|
||||
}
|
||||
|
||||
val := uint64(0)
|
||||
for _, b := range buf {
|
||||
val = (val << 8) + uint64(b)
|
||||
}
|
||||
|
||||
v.Elem().SetUint(val)
|
||||
return uintLen
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type arrayDecoder struct {
|
||||
elemDec decoderFunc
|
||||
}
|
||||
|
||||
func (ad *arrayDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
n := v.Elem().Type().Len()
|
||||
read := 0
|
||||
for i := 0; i < n; i += 1 {
|
||||
read += ad.elemDec(d, v.Elem().Index(i).Addr(), opts)
|
||||
}
|
||||
return read
|
||||
}
|
||||
|
||||
func newArrayDecoder(t reflect.Type) decoderFunc {
|
||||
dec := &arrayDecoder{typeDecoder(t.Elem())}
|
||||
return dec.decode
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type sliceDecoder struct {
|
||||
elementType reflect.Type
|
||||
elementDec decoderFunc
|
||||
}
|
||||
|
||||
func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
if opts.head == 0 {
|
||||
panic(fmt.Errorf("Cannot decode a slice without a header length"))
|
||||
}
|
||||
|
||||
lengthBytes := make([]byte, opts.head)
|
||||
n, err := d.Read(lengthBytes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if uint(n) != opts.head {
|
||||
panic(fmt.Errorf("Not enough data to read header"))
|
||||
}
|
||||
|
||||
length := uint(0)
|
||||
for _, b := range lengthBytes {
|
||||
length = (length << 8) + uint(b)
|
||||
}
|
||||
|
||||
if opts.max > 0 && length > opts.max {
|
||||
panic(fmt.Errorf("Length of vector exceeds declared max"))
|
||||
}
|
||||
if length < opts.min {
|
||||
panic(fmt.Errorf("Length of vector below declared min"))
|
||||
}
|
||||
|
||||
data := make([]byte, length)
|
||||
n, err = d.Read(data)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if uint(n) != length {
|
||||
panic(fmt.Errorf("Available data less than declared length [%04x < %04x]", n, length))
|
||||
}
|
||||
|
||||
elemBuf := &decodeState{}
|
||||
elemBuf.Write(data)
|
||||
elems := []reflect.Value{}
|
||||
read := int(opts.head)
|
||||
for elemBuf.Len() > 0 {
|
||||
elem := reflect.New(sd.elementType)
|
||||
read += sd.elementDec(elemBuf, elem, opts)
|
||||
elems = append(elems, elem)
|
||||
}
|
||||
|
||||
v.Elem().Set(reflect.MakeSlice(v.Elem().Type(), len(elems), len(elems)))
|
||||
for i := 0; i < len(elems); i += 1 {
|
||||
v.Elem().Index(i).Set(elems[i].Elem())
|
||||
}
|
||||
return read
|
||||
}
|
||||
|
||||
func newSliceDecoder(t reflect.Type) decoderFunc {
|
||||
dec := &sliceDecoder{
|
||||
elementType: t.Elem(),
|
||||
elementDec: typeDecoder(t.Elem()),
|
||||
}
|
||||
return dec.decode
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type structDecoder struct {
|
||||
fieldOpts []decOpts
|
||||
fieldDecs []decoderFunc
|
||||
}
|
||||
|
||||
func (sd *structDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
read := 0
|
||||
for i := range sd.fieldDecs {
|
||||
read += sd.fieldDecs[i](d, v.Elem().Field(i).Addr(), sd.fieldOpts[i])
|
||||
}
|
||||
return read
|
||||
}
|
||||
|
||||
func newStructDecoder(t reflect.Type) decoderFunc {
|
||||
n := t.NumField()
|
||||
sd := structDecoder{
|
||||
fieldOpts: make([]decOpts, n),
|
||||
fieldDecs: make([]decoderFunc, n),
|
||||
}
|
||||
|
||||
for i := 0; i < n; i += 1 {
|
||||
f := t.Field(i)
|
||||
|
||||
tag := f.Tag.Get("tls")
|
||||
tagOpts := parseTag(tag)
|
||||
|
||||
sd.fieldOpts[i] = decOpts{
|
||||
head: tagOpts["head"],
|
||||
max: tagOpts["max"],
|
||||
min: tagOpts["min"],
|
||||
}
|
||||
|
||||
sd.fieldDecs[i] = typeDecoder(f.Type)
|
||||
}
|
||||
|
||||
return sd.decode
|
||||
}
|
187
vendor/github.com/bifurcation/mint/syntax/encode.go
generated
vendored
187
vendor/github.com/bifurcation/mint/syntax/encode.go
generated
vendored
|
@ -1,187 +0,0 @@
|
|||
package syntax
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
func Marshal(v interface{}) ([]byte, error) {
|
||||
e := &encodeState{}
|
||||
err := e.marshal(v, encOpts{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return e.Bytes(), nil
|
||||
}
|
||||
|
||||
// These are the options that can be specified in the struct tag. Right now,
|
||||
// all of them apply to variable-length vectors and nothing else
|
||||
type encOpts struct {
|
||||
head uint // length of length in bytes
|
||||
min uint // minimum size in bytes
|
||||
max uint // maximum size in bytes
|
||||
}
|
||||
|
||||
type encodeState struct {
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if _, ok := r.(runtime.Error); ok {
|
||||
panic(r)
|
||||
}
|
||||
if s, ok := r.(string); ok {
|
||||
panic(s)
|
||||
}
|
||||
err = r.(error)
|
||||
}
|
||||
}()
|
||||
e.reflectValue(reflect.ValueOf(v), opts)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) {
|
||||
valueEncoder(v)(e, v, opts)
|
||||
}
|
||||
|
||||
type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts)
|
||||
|
||||
func valueEncoder(v reflect.Value) encoderFunc {
|
||||
if !v.IsValid() {
|
||||
panic(fmt.Errorf("Cannot encode an invalid value"))
|
||||
}
|
||||
return typeEncoder(v.Type())
|
||||
}
|
||||
|
||||
func typeEncoder(t reflect.Type) encoderFunc {
|
||||
// Note: Omits the caching / wait-group things that encoding/json uses
|
||||
return newTypeEncoder(t)
|
||||
}
|
||||
|
||||
func newTypeEncoder(t reflect.Type) encoderFunc {
|
||||
// Note: Does not support Marshaler, so don't need the allowAddr argument
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return uintEncoder
|
||||
case reflect.Array:
|
||||
return newArrayEncoder(t)
|
||||
case reflect.Slice:
|
||||
return newSliceEncoder(t)
|
||||
case reflect.Struct:
|
||||
return newStructEncoder(t)
|
||||
default:
|
||||
panic(fmt.Errorf("Unsupported type (%s)", t))
|
||||
}
|
||||
}
|
||||
|
||||
///// Specific encoders below
|
||||
|
||||
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
u := v.Uint()
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Uint8:
|
||||
e.WriteByte(byte(u))
|
||||
case reflect.Uint16:
|
||||
e.Write([]byte{byte(u >> 8), byte(u)})
|
||||
case reflect.Uint32:
|
||||
e.Write([]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
|
||||
case reflect.Uint64:
|
||||
e.Write([]byte{byte(u >> 56), byte(u >> 48), byte(u >> 40), byte(u >> 32),
|
||||
byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
|
||||
}
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type arrayEncoder struct {
|
||||
elemEnc encoderFunc
|
||||
}
|
||||
|
||||
func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
n := v.Len()
|
||||
for i := 0; i < n; i += 1 {
|
||||
ae.elemEnc(e, v.Index(i), opts)
|
||||
}
|
||||
}
|
||||
|
||||
func newArrayEncoder(t reflect.Type) encoderFunc {
|
||||
enc := &arrayEncoder{typeEncoder(t.Elem())}
|
||||
return enc.encode
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type sliceEncoder struct {
|
||||
ae *arrayEncoder
|
||||
}
|
||||
|
||||
func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
if opts.head == 0 {
|
||||
panic(fmt.Errorf("Cannot encode a slice without a header length"))
|
||||
}
|
||||
|
||||
arrayState := &encodeState{}
|
||||
se.ae.encode(arrayState, v, opts)
|
||||
|
||||
n := uint(arrayState.Len())
|
||||
if opts.max > 0 && n > opts.max {
|
||||
panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max))
|
||||
}
|
||||
if n>>(8*opts.head) > 0 {
|
||||
panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head))
|
||||
}
|
||||
if n < opts.min {
|
||||
panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min))
|
||||
}
|
||||
|
||||
for i := int(opts.head - 1); i >= 0; i -= 1 {
|
||||
e.WriteByte(byte(n >> (8 * uint(i))))
|
||||
}
|
||||
e.Write(arrayState.Bytes())
|
||||
}
|
||||
|
||||
func newSliceEncoder(t reflect.Type) encoderFunc {
|
||||
enc := &sliceEncoder{&arrayEncoder{typeEncoder(t.Elem())}}
|
||||
return enc.encode
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type structEncoder struct {
|
||||
fieldOpts []encOpts
|
||||
fieldEncs []encoderFunc
|
||||
}
|
||||
|
||||
func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
for i := range se.fieldEncs {
|
||||
se.fieldEncs[i](e, v.Field(i), se.fieldOpts[i])
|
||||
}
|
||||
}
|
||||
|
||||
func newStructEncoder(t reflect.Type) encoderFunc {
|
||||
n := t.NumField()
|
||||
se := structEncoder{
|
||||
fieldOpts: make([]encOpts, n),
|
||||
fieldEncs: make([]encoderFunc, n),
|
||||
}
|
||||
|
||||
for i := 0; i < n; i += 1 {
|
||||
f := t.Field(i)
|
||||
tag := f.Tag.Get("tls")
|
||||
tagOpts := parseTag(tag)
|
||||
|
||||
se.fieldOpts[i] = encOpts{
|
||||
head: tagOpts["head"],
|
||||
max: tagOpts["max"],
|
||||
min: tagOpts["min"],
|
||||
}
|
||||
se.fieldEncs[i] = typeEncoder(f.Type)
|
||||
}
|
||||
|
||||
return se.encode
|
||||
}
|
30
vendor/github.com/bifurcation/mint/syntax/tags.go
generated
vendored
30
vendor/github.com/bifurcation/mint/syntax/tags.go
generated
vendored
|
@ -1,30 +0,0 @@
|
|||
package syntax
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// `tls:"head=2,min=2,max=255"`
|
||||
|
||||
type tagOptions map[string]uint
|
||||
|
||||
// parseTag parses a struct field's "tls" tag as a comma-separated list of
|
||||
// name=value pairs, where the values MUST be unsigned integers
|
||||
func parseTag(tag string) tagOptions {
|
||||
opts := tagOptions{}
|
||||
for _, token := range strings.Split(tag, ",") {
|
||||
if strings.Index(token, "=") == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.Split(token, "=")
|
||||
if len(parts[0]) == 0 {
|
||||
continue
|
||||
}
|
||||
if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 {
|
||||
opts[parts[0]] = uint(val)
|
||||
}
|
||||
}
|
||||
return opts
|
||||
}
|
168
vendor/github.com/bifurcation/mint/tls.go
generated
vendored
168
vendor/github.com/bifurcation/mint/tls.go
generated
vendored
|
@ -1,168 +0,0 @@
|
|||
package mint
|
||||
|
||||
// XXX(rlb): This file is borrowed pretty much wholesale from crypto/tls
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Server returns a new TLS server side connection
|
||||
// using conn as the underlying transport.
|
||||
// The configuration config must be non-nil and must include
|
||||
// at least one certificate or else set GetCertificate.
|
||||
func Server(conn net.Conn, config *Config) *Conn {
|
||||
return NewConn(conn, config, false)
|
||||
}
|
||||
|
||||
// Client returns a new TLS client side connection
|
||||
// using conn as the underlying transport.
|
||||
// The config cannot be nil: users must set either ServerName or
|
||||
// InsecureSkipVerify in the config.
|
||||
func Client(conn net.Conn, config *Config) *Conn {
|
||||
return NewConn(conn, config, true)
|
||||
}
|
||||
|
||||
// A listener implements a network listener (net.Listener) for TLS connections.
|
||||
type Listener struct {
|
||||
net.Listener
|
||||
config *Config
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next incoming TLS connection.
|
||||
// The returned connection c is a *tls.Conn.
|
||||
func (l *Listener) Accept() (c net.Conn, err error) {
|
||||
c, err = l.Listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
server := Server(c, l.config)
|
||||
err = server.Handshake()
|
||||
if err == AlertNoAlert {
|
||||
err = nil
|
||||
}
|
||||
c = server
|
||||
return
|
||||
}
|
||||
|
||||
// NewListener creates a Listener which accepts connections from an inner
|
||||
// Listener and wraps each connection with Server.
|
||||
// The configuration config must be non-nil and must include
|
||||
// at least one certificate or else set GetCertificate.
|
||||
func NewListener(inner net.Listener, config *Config) net.Listener {
|
||||
l := new(Listener)
|
||||
l.Listener = inner
|
||||
l.config = config
|
||||
return l
|
||||
}
|
||||
|
||||
// Listen creates a TLS listener accepting connections on the
|
||||
// given network address using net.Listen.
|
||||
// The configuration config must be non-nil and must include
|
||||
// at least one certificate or else set GetCertificate.
|
||||
func Listen(network, laddr string, config *Config) (net.Listener, error) {
|
||||
if config == nil || !config.ValidForServer() {
|
||||
return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config")
|
||||
}
|
||||
l, err := net.Listen(network, laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewListener(l, config), nil
|
||||
}
|
||||
|
||||
type TimeoutError struct{}
|
||||
|
||||
func (TimeoutError) Error() string { return "tls: DialWithDialer timed out" }
|
||||
func (TimeoutError) Timeout() bool { return true }
|
||||
func (TimeoutError) Temporary() bool { return true }
|
||||
|
||||
// DialWithDialer connects to the given network address using dialer.Dial and
|
||||
// then initiates a TLS handshake, returning the resulting TLS connection. Any
|
||||
// timeout or deadline given in the dialer apply to connection and TLS
|
||||
// handshake as a whole.
|
||||
//
|
||||
// DialWithDialer interprets a nil configuration as equivalent to the zero
|
||||
// configuration; see the documentation of Config for the defaults.
|
||||
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
||||
// We want the Timeout and Deadline values from dialer to cover the
|
||||
// whole process: TCP connection and TLS handshake. This means that we
|
||||
// also need to start our own timers now.
|
||||
timeout := dialer.Timeout
|
||||
|
||||
if !dialer.Deadline.IsZero() {
|
||||
deadlineTimeout := dialer.Deadline.Sub(time.Now())
|
||||
if timeout == 0 || deadlineTimeout < timeout {
|
||||
timeout = deadlineTimeout
|
||||
}
|
||||
}
|
||||
|
||||
var errChannel chan error
|
||||
|
||||
if timeout != 0 {
|
||||
errChannel = make(chan error, 2)
|
||||
time.AfterFunc(timeout, func() {
|
||||
errChannel <- TimeoutError{}
|
||||
})
|
||||
}
|
||||
|
||||
rawConn, err := dialer.Dial(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
colonPos := strings.LastIndex(addr, ":")
|
||||
if colonPos == -1 {
|
||||
colonPos = len(addr)
|
||||
}
|
||||
hostname := addr[:colonPos]
|
||||
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
// If no ServerName is set, infer the ServerName
|
||||
// from the hostname we're connecting to.
|
||||
if config.ServerName == "" {
|
||||
// Make a copy to avoid polluting argument or default.
|
||||
c := config.Clone()
|
||||
c.ServerName = hostname
|
||||
config = c
|
||||
}
|
||||
|
||||
conn := Client(rawConn, config)
|
||||
|
||||
if timeout == 0 {
|
||||
err = conn.Handshake()
|
||||
if err == AlertNoAlert {
|
||||
err = nil
|
||||
}
|
||||
} else {
|
||||
go func() {
|
||||
errChannel <- conn.Handshake()
|
||||
}()
|
||||
|
||||
err = <-errChannel
|
||||
if err == AlertNoAlert {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
rawConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Dial connects to the given network address using net.Dial
|
||||
// and then initiates a TLS handshake, returning the resulting
|
||||
// TLS connection.
|
||||
// Dial interprets a nil configuration as equivalent to
|
||||
// the zero configuration; see the documentation of Config
|
||||
// for the defaults.
|
||||
func Dial(network, addr string, config *Config) (*Conn, error) {
|
||||
return DialWithDialer(new(net.Dialer), network, addr, config)
|
||||
}
|
202
vendor/github.com/golang/mock/gomock/LICENSE
generated
vendored
Normal file
202
vendor/github.com/golang/mock/gomock/LICENSE
generated
vendored
Normal file
|
@ -0,0 +1,202 @@
|
|||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
428
vendor/github.com/golang/mock/gomock/call.go
generated
vendored
Normal file
428
vendor/github.com/golang/mock/gomock/call.go
generated
vendored
Normal file
|
@ -0,0 +1,428 @@
|
|||
// Copyright 2010 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package gomock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Call represents an expected call to a mock.
|
||||
type Call struct {
|
||||
t TestReporter // for triggering test failures on invalid call setup
|
||||
|
||||
receiver interface{} // the receiver of the method call
|
||||
method string // the name of the method
|
||||
methodType reflect.Type // the type of the method
|
||||
args []Matcher // the args
|
||||
origin string // file and line number of call setup
|
||||
|
||||
preReqs []*Call // prerequisite calls
|
||||
|
||||
// Expectations
|
||||
minCalls, maxCalls int
|
||||
|
||||
numCalls int // actual number made
|
||||
|
||||
// actions are called when this Call is called. Each action gets the args and
|
||||
// can set the return values by returning a non-nil slice. Actions run in the
|
||||
// order they are created.
|
||||
actions []func([]interface{}) []interface{}
|
||||
}
|
||||
|
||||
// newCall creates a *Call. It requires the method type in order to support
|
||||
// unexported methods.
|
||||
func newCall(t TestReporter, receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {
|
||||
if h, ok := t.(testHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
|
||||
// TODO: check arity, types.
|
||||
margs := make([]Matcher, len(args))
|
||||
for i, arg := range args {
|
||||
if m, ok := arg.(Matcher); ok {
|
||||
margs[i] = m
|
||||
} else if arg == nil {
|
||||
// Handle nil specially so that passing a nil interface value
|
||||
// will match the typed nils of concrete args.
|
||||
margs[i] = Nil()
|
||||
} else {
|
||||
margs[i] = Eq(arg)
|
||||
}
|
||||
}
|
||||
|
||||
origin := callerInfo(3)
|
||||
actions := []func([]interface{}) []interface{}{func([]interface{}) []interface{} {
|
||||
// Synthesize the zero value for each of the return args' types.
|
||||
rets := make([]interface{}, methodType.NumOut())
|
||||
for i := 0; i < methodType.NumOut(); i++ {
|
||||
rets[i] = reflect.Zero(methodType.Out(i)).Interface()
|
||||
}
|
||||
return rets
|
||||
}}
|
||||
return &Call{t: t, receiver: receiver, method: method, methodType: methodType,
|
||||
args: margs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions}
|
||||
}
|
||||
|
||||
// AnyTimes allows the expectation to be called 0 or more times
|
||||
func (c *Call) AnyTimes() *Call {
|
||||
c.minCalls, c.maxCalls = 0, 1e8 // close enough to infinity
|
||||
return c
|
||||
}
|
||||
|
||||
// MinTimes requires the call to occur at least n times. If AnyTimes or MaxTimes have not been called, MinTimes also
|
||||
// sets the maximum number of calls to infinity.
|
||||
func (c *Call) MinTimes(n int) *Call {
|
||||
c.minCalls = n
|
||||
if c.maxCalls == 1 {
|
||||
c.maxCalls = 1e8
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// MaxTimes limits the number of calls to n times. If AnyTimes or MinTimes have not been called, MaxTimes also
|
||||
// sets the minimum number of calls to 0.
|
||||
func (c *Call) MaxTimes(n int) *Call {
|
||||
c.maxCalls = n
|
||||
if c.minCalls == 1 {
|
||||
c.minCalls = 0
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn declares the action to run when the call is matched.
|
||||
// The return values from this function are returned by the mocked function.
|
||||
// It takes an interface{} argument to support n-arity functions.
|
||||
func (c *Call) DoAndReturn(f interface{}) *Call {
|
||||
// TODO: Check arity and types here, rather than dying badly elsewhere.
|
||||
v := reflect.ValueOf(f)
|
||||
|
||||
c.addAction(func(args []interface{}) []interface{} {
|
||||
vargs := make([]reflect.Value, len(args))
|
||||
ft := v.Type()
|
||||
for i := 0; i < len(args); i++ {
|
||||
if args[i] != nil {
|
||||
vargs[i] = reflect.ValueOf(args[i])
|
||||
} else {
|
||||
// Use the zero value for the arg.
|
||||
vargs[i] = reflect.Zero(ft.In(i))
|
||||
}
|
||||
}
|
||||
vrets := v.Call(vargs)
|
||||
rets := make([]interface{}, len(vrets))
|
||||
for i, ret := range vrets {
|
||||
rets[i] = ret.Interface()
|
||||
}
|
||||
return rets
|
||||
})
|
||||
return c
|
||||
}
|
||||
|
||||
// Do declares the action to run when the call is matched. The function's
|
||||
// return values are ignored to retain backward compatibility. To use the
|
||||
// return values call DoAndReturn.
|
||||
// It takes an interface{} argument to support n-arity functions.
|
||||
func (c *Call) Do(f interface{}) *Call {
|
||||
// TODO: Check arity and types here, rather than dying badly elsewhere.
|
||||
v := reflect.ValueOf(f)
|
||||
|
||||
c.addAction(func(args []interface{}) []interface{} {
|
||||
vargs := make([]reflect.Value, len(args))
|
||||
ft := v.Type()
|
||||
for i := 0; i < len(args); i++ {
|
||||
if args[i] != nil {
|
||||
vargs[i] = reflect.ValueOf(args[i])
|
||||
} else {
|
||||
// Use the zero value for the arg.
|
||||
vargs[i] = reflect.Zero(ft.In(i))
|
||||
}
|
||||
}
|
||||
v.Call(vargs)
|
||||
return nil
|
||||
})
|
||||
return c
|
||||
}
|
||||
|
||||
// Return declares the values to be returned by the mocked function call.
|
||||
func (c *Call) Return(rets ...interface{}) *Call {
|
||||
if h, ok := c.t.(testHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
|
||||
mt := c.methodType
|
||||
if len(rets) != mt.NumOut() {
|
||||
c.t.Fatalf("wrong number of arguments to Return for %T.%v: got %d, want %d [%s]",
|
||||
c.receiver, c.method, len(rets), mt.NumOut(), c.origin)
|
||||
}
|
||||
for i, ret := range rets {
|
||||
if got, want := reflect.TypeOf(ret), mt.Out(i); got == want {
|
||||
// Identical types; nothing to do.
|
||||
} else if got == nil {
|
||||
// Nil needs special handling.
|
||||
switch want.Kind() {
|
||||
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
|
||||
// ok
|
||||
default:
|
||||
c.t.Fatalf("argument %d to Return for %T.%v is nil, but %v is not nillable [%s]",
|
||||
i, c.receiver, c.method, want, c.origin)
|
||||
}
|
||||
} else if got.AssignableTo(want) {
|
||||
// Assignable type relation. Make the assignment now so that the generated code
|
||||
// can return the values with a type assertion.
|
||||
v := reflect.New(want).Elem()
|
||||
v.Set(reflect.ValueOf(ret))
|
||||
rets[i] = v.Interface()
|
||||
} else {
|
||||
c.t.Fatalf("wrong type of argument %d to Return for %T.%v: %v is not assignable to %v [%s]",
|
||||
i, c.receiver, c.method, got, want, c.origin)
|
||||
}
|
||||
}
|
||||
|
||||
c.addAction(func([]interface{}) []interface{} {
|
||||
return rets
|
||||
})
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Times declares the exact number of times a function call is expected to be executed.
|
||||
func (c *Call) Times(n int) *Call {
|
||||
c.minCalls, c.maxCalls = n, n
|
||||
return c
|
||||
}
|
||||
|
||||
// SetArg declares an action that will set the nth argument's value,
|
||||
// indirected through a pointer. Or, in the case of a slice, SetArg
|
||||
// will copy value's elements into the nth argument.
|
||||
func (c *Call) SetArg(n int, value interface{}) *Call {
|
||||
if h, ok := c.t.(testHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
|
||||
mt := c.methodType
|
||||
// TODO: This will break on variadic methods.
|
||||
// We will need to check those at invocation time.
|
||||
if n < 0 || n >= mt.NumIn() {
|
||||
c.t.Fatalf("SetArg(%d, ...) called for a method with %d args [%s]",
|
||||
n, mt.NumIn(), c.origin)
|
||||
}
|
||||
// Permit setting argument through an interface.
|
||||
// In the interface case, we don't (nay, can't) check the type here.
|
||||
at := mt.In(n)
|
||||
switch at.Kind() {
|
||||
case reflect.Ptr:
|
||||
dt := at.Elem()
|
||||
if vt := reflect.TypeOf(value); !vt.AssignableTo(dt) {
|
||||
c.t.Fatalf("SetArg(%d, ...) argument is a %v, not assignable to %v [%s]",
|
||||
n, vt, dt, c.origin)
|
||||
}
|
||||
case reflect.Interface:
|
||||
// nothing to do
|
||||
case reflect.Slice:
|
||||
// nothing to do
|
||||
default:
|
||||
c.t.Fatalf("SetArg(%d, ...) referring to argument of non-pointer non-interface non-slice type %v [%s]",
|
||||
n, at, c.origin)
|
||||
}
|
||||
|
||||
c.addAction(func(args []interface{}) []interface{} {
|
||||
v := reflect.ValueOf(value)
|
||||
switch reflect.TypeOf(args[n]).Kind() {
|
||||
case reflect.Slice:
|
||||
setSlice(args[n], v)
|
||||
default:
|
||||
reflect.ValueOf(args[n]).Elem().Set(v)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return c
|
||||
}
|
||||
|
||||
// isPreReq returns true if other is a direct or indirect prerequisite to c.
|
||||
func (c *Call) isPreReq(other *Call) bool {
|
||||
for _, preReq := range c.preReqs {
|
||||
if other == preReq || preReq.isPreReq(other) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// After declares that the call may only match after preReq has been exhausted.
|
||||
func (c *Call) After(preReq *Call) *Call {
|
||||
if h, ok := c.t.(testHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
|
||||
if c == preReq {
|
||||
c.t.Fatalf("A call isn't allowed to be its own prerequisite")
|
||||
}
|
||||
if preReq.isPreReq(c) {
|
||||
c.t.Fatalf("Loop in call order: %v is a prerequisite to %v (possibly indirectly).", c, preReq)
|
||||
}
|
||||
|
||||
c.preReqs = append(c.preReqs, preReq)
|
||||
return c
|
||||
}
|
||||
|
||||
// Returns true if the minimum number of calls have been made.
|
||||
func (c *Call) satisfied() bool {
|
||||
return c.numCalls >= c.minCalls
|
||||
}
|
||||
|
||||
// Returns true iff the maximum number of calls have been made.
|
||||
func (c *Call) exhausted() bool {
|
||||
return c.numCalls >= c.maxCalls
|
||||
}
|
||||
|
||||
func (c *Call) String() string {
|
||||
args := make([]string, len(c.args))
|
||||
for i, arg := range c.args {
|
||||
args[i] = arg.String()
|
||||
}
|
||||
arguments := strings.Join(args, ", ")
|
||||
return fmt.Sprintf("%T.%v(%s) %s", c.receiver, c.method, arguments, c.origin)
|
||||
}
|
||||
|
||||
// Tests if the given call matches the expected call.
|
||||
// If yes, returns nil. If no, returns error with message explaining why it does not match.
|
||||
func (c *Call) matches(args []interface{}) error {
|
||||
if !c.methodType.IsVariadic() {
|
||||
if len(args) != len(c.args) {
|
||||
return fmt.Errorf("Expected call at %s has the wrong number of arguments. Got: %d, want: %d",
|
||||
c.origin, len(args), len(c.args))
|
||||
}
|
||||
|
||||
for i, m := range c.args {
|
||||
if !m.Matches(args[i]) {
|
||||
return fmt.Errorf("Expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v",
|
||||
c.origin, strconv.Itoa(i), args[i], m)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if len(c.args) < c.methodType.NumIn()-1 {
|
||||
return fmt.Errorf("Expected call at %s has the wrong number of matchers. Got: %d, want: %d",
|
||||
c.origin, len(c.args), c.methodType.NumIn()-1)
|
||||
}
|
||||
if len(c.args) != c.methodType.NumIn() && len(args) != len(c.args) {
|
||||
return fmt.Errorf("Expected call at %s has the wrong number of arguments. Got: %d, want: %d",
|
||||
c.origin, len(args), len(c.args))
|
||||
}
|
||||
if len(args) < len(c.args)-1 {
|
||||
return fmt.Errorf("Expected call at %s has the wrong number of arguments. Got: %d, want: greater than or equal to %d",
|
||||
c.origin, len(args), len(c.args)-1)
|
||||
}
|
||||
|
||||
for i, m := range c.args {
|
||||
if i < c.methodType.NumIn()-1 {
|
||||
// Non-variadic args
|
||||
if !m.Matches(args[i]) {
|
||||
return fmt.Errorf("Expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v",
|
||||
c.origin, strconv.Itoa(i), args[i], m)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// The last arg has a possibility of a variadic argument, so let it branch
|
||||
|
||||
// sample: Foo(a int, b int, c ...int)
|
||||
if i < len(c.args) && i < len(args) {
|
||||
if m.Matches(args[i]) {
|
||||
// Got Foo(a, b, c) want Foo(matcherA, matcherB, gomock.Any())
|
||||
// Got Foo(a, b, c) want Foo(matcherA, matcherB, someSliceMatcher)
|
||||
// Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC)
|
||||
// Got Foo(a, b) want Foo(matcherA, matcherB)
|
||||
// Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// The number of actual args don't match the number of matchers,
|
||||
// or the last matcher is a slice and the last arg is not.
|
||||
// If this function still matches it is because the last matcher
|
||||
// matches all the remaining arguments or the lack of any.
|
||||
// Convert the remaining arguments, if any, into a slice of the
|
||||
// expected type.
|
||||
vargsType := c.methodType.In(c.methodType.NumIn() - 1)
|
||||
vargs := reflect.MakeSlice(vargsType, 0, len(args)-i)
|
||||
for _, arg := range args[i:] {
|
||||
vargs = reflect.Append(vargs, reflect.ValueOf(arg))
|
||||
}
|
||||
if m.Matches(vargs.Interface()) {
|
||||
// Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, gomock.Any())
|
||||
// Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, someSliceMatcher)
|
||||
// Got Foo(a, b) want Foo(matcherA, matcherB, gomock.Any())
|
||||
// Got Foo(a, b) want Foo(matcherA, matcherB, someEmptySliceMatcher)
|
||||
break
|
||||
}
|
||||
// Wrong number of matchers or not match. Fail.
|
||||
// Got Foo(a, b) want Foo(matcherA, matcherB, matcherC, matcherD)
|
||||
// Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC, matcherD)
|
||||
// Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD, matcherE)
|
||||
// Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, matcherC, matcherD)
|
||||
// Got Foo(a, b, c) want Foo(matcherA, matcherB)
|
||||
return fmt.Errorf("Expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v",
|
||||
c.origin, strconv.Itoa(i), args[i:], c.args[i])
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Check that all prerequisite calls have been satisfied.
|
||||
for _, preReqCall := range c.preReqs {
|
||||
if !preReqCall.satisfied() {
|
||||
return fmt.Errorf("Expected call at %s doesn't have a prerequisite call satisfied:\n%v\nshould be called before:\n%v",
|
||||
c.origin, preReqCall, c)
|
||||
}
|
||||
}
|
||||
|
||||
// Check that the call is not exhausted.
|
||||
if c.exhausted() {
|
||||
return fmt.Errorf("Expected call at %s has already been called the max number of times.", c.origin)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dropPrereqs tells the expected Call to not re-check prerequisite calls any
|
||||
// longer, and to return its current set.
|
||||
func (c *Call) dropPrereqs() (preReqs []*Call) {
|
||||
preReqs = c.preReqs
|
||||
c.preReqs = nil
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Call) call(args []interface{}) []func([]interface{}) []interface{} {
|
||||
c.numCalls++
|
||||
return c.actions
|
||||
}
|
||||
|
||||
// InOrder declares that the given calls should occur in order.
|
||||
func InOrder(calls ...*Call) {
|
||||
for i := 1; i < len(calls); i++ {
|
||||
calls[i].After(calls[i-1])
|
||||
}
|
||||
}
|
||||
|
||||
func setSlice(arg interface{}, v reflect.Value) {
|
||||
va := reflect.ValueOf(arg)
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
va.Index(i).Set(v.Index(i))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Call) addAction(action func([]interface{}) []interface{}) {
|
||||
c.actions = append(c.actions, action)
|
||||
}
|
108
vendor/github.com/golang/mock/gomock/callset.go
generated
vendored
Normal file
108
vendor/github.com/golang/mock/gomock/callset.go
generated
vendored
Normal file
|
@ -0,0 +1,108 @@
|
|||
// Copyright 2011 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package gomock
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// callSet represents a set of expected calls, indexed by receiver and method
|
||||
// name.
|
||||
type callSet struct {
|
||||
// Calls that are still expected.
|
||||
expected map[callSetKey][]*Call
|
||||
// Calls that have been exhausted.
|
||||
exhausted map[callSetKey][]*Call
|
||||
}
|
||||
|
||||
// callSetKey is the key in the maps in callSet
|
||||
type callSetKey struct {
|
||||
receiver interface{}
|
||||
fname string
|
||||
}
|
||||
|
||||
func newCallSet() *callSet {
|
||||
return &callSet{make(map[callSetKey][]*Call), make(map[callSetKey][]*Call)}
|
||||
}
|
||||
|
||||
// Add adds a new expected call.
|
||||
func (cs callSet) Add(call *Call) {
|
||||
key := callSetKey{call.receiver, call.method}
|
||||
m := cs.expected
|
||||
if call.exhausted() {
|
||||
m = cs.exhausted
|
||||
}
|
||||
m[key] = append(m[key], call)
|
||||
}
|
||||
|
||||
// Remove removes an expected call.
|
||||
func (cs callSet) Remove(call *Call) {
|
||||
key := callSetKey{call.receiver, call.method}
|
||||
calls := cs.expected[key]
|
||||
for i, c := range calls {
|
||||
if c == call {
|
||||
// maintain order for remaining calls
|
||||
cs.expected[key] = append(calls[:i], calls[i+1:]...)
|
||||
cs.exhausted[key] = append(cs.exhausted[key], call)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FindMatch searches for a matching call. Returns error with explanation message if no call matched.
|
||||
func (cs callSet) FindMatch(receiver interface{}, method string, args []interface{}) (*Call, error) {
|
||||
key := callSetKey{receiver, method}
|
||||
|
||||
// Search through the expected calls.
|
||||
expected := cs.expected[key]
|
||||
var callsErrors bytes.Buffer
|
||||
for _, call := range expected {
|
||||
err := call.matches(args)
|
||||
if err != nil {
|
||||
fmt.Fprintf(&callsErrors, "\n%v", err)
|
||||
} else {
|
||||
return call, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If we haven't found a match then search through the exhausted calls so we
|
||||
// get useful error messages.
|
||||
exhausted := cs.exhausted[key]
|
||||
for _, call := range exhausted {
|
||||
if err := call.matches(args); err != nil {
|
||||
fmt.Fprintf(&callsErrors, "\n%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(expected)+len(exhausted) == 0 {
|
||||
fmt.Fprintf(&callsErrors, "there are no expected calls of the method %q for that receiver", method)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf(callsErrors.String())
|
||||
}
|
||||
|
||||
// Failures returns the calls that are not satisfied.
|
||||
func (cs callSet) Failures() []*Call {
|
||||
failures := make([]*Call, 0, len(cs.expected))
|
||||
for _, calls := range cs.expected {
|
||||
for _, call := range calls {
|
||||
if !call.satisfied() {
|
||||
failures = append(failures, call)
|
||||
}
|
||||
}
|
||||
}
|
||||
return failures
|
||||
}
|
217
vendor/github.com/golang/mock/gomock/controller.go
generated
vendored
Normal file
217
vendor/github.com/golang/mock/gomock/controller.go
generated
vendored
Normal file
|
@ -0,0 +1,217 @@
|
|||
// Copyright 2010 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// GoMock - a mock framework for Go.
|
||||
//
|
||||
// Standard usage:
|
||||
// (1) Define an interface that you wish to mock.
|
||||
// type MyInterface interface {
|
||||
// SomeMethod(x int64, y string)
|
||||
// }
|
||||
// (2) Use mockgen to generate a mock from the interface.
|
||||
// (3) Use the mock in a test:
|
||||
// func TestMyThing(t *testing.T) {
|
||||
// mockCtrl := gomock.NewController(t)
|
||||
// defer mockCtrl.Finish()
|
||||
//
|
||||
// mockObj := something.NewMockMyInterface(mockCtrl)
|
||||
// mockObj.EXPECT().SomeMethod(4, "blah")
|
||||
// // pass mockObj to a real object and play with it.
|
||||
// }
|
||||
//
|
||||
// By default, expected calls are not enforced to run in any particular order.
|
||||
// Call order dependency can be enforced by use of InOrder and/or Call.After.
|
||||
// Call.After can create more varied call order dependencies, but InOrder is
|
||||
// often more convenient.
|
||||
//
|
||||
// The following examples create equivalent call order dependencies.
|
||||
//
|
||||
// Example of using Call.After to chain expected call order:
|
||||
//
|
||||
// firstCall := mockObj.EXPECT().SomeMethod(1, "first")
|
||||
// secondCall := mockObj.EXPECT().SomeMethod(2, "second").After(firstCall)
|
||||
// mockObj.EXPECT().SomeMethod(3, "third").After(secondCall)
|
||||
//
|
||||
// Example of using InOrder to declare expected call order:
|
||||
//
|
||||
// gomock.InOrder(
|
||||
// mockObj.EXPECT().SomeMethod(1, "first"),
|
||||
// mockObj.EXPECT().SomeMethod(2, "second"),
|
||||
// mockObj.EXPECT().SomeMethod(3, "third"),
|
||||
// )
|
||||
//
|
||||
// TODO:
|
||||
// - Handle different argument/return types (e.g. ..., chan, map, interface).
|
||||
package gomock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"golang.org/x/net/context"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// A TestReporter is something that can be used to report test failures.
|
||||
// It is satisfied by the standard library's *testing.T.
|
||||
type TestReporter interface {
|
||||
Errorf(format string, args ...interface{})
|
||||
Fatalf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// A Controller represents the top-level control of a mock ecosystem.
|
||||
// It defines the scope and lifetime of mock objects, as well as their expectations.
|
||||
// It is safe to call Controller's methods from multiple goroutines.
|
||||
type Controller struct {
|
||||
mu sync.Mutex
|
||||
t TestReporter
|
||||
expectedCalls *callSet
|
||||
finished bool
|
||||
}
|
||||
|
||||
func NewController(t TestReporter) *Controller {
|
||||
return &Controller{
|
||||
t: t,
|
||||
expectedCalls: newCallSet(),
|
||||
}
|
||||
}
|
||||
|
||||
type cancelReporter struct {
|
||||
t TestReporter
|
||||
cancel func()
|
||||
}
|
||||
|
||||
func (r *cancelReporter) Errorf(format string, args ...interface{}) { r.t.Errorf(format, args...) }
|
||||
func (r *cancelReporter) Fatalf(format string, args ...interface{}) {
|
||||
defer r.cancel()
|
||||
r.t.Fatalf(format, args...)
|
||||
}
|
||||
|
||||
// WithContext returns a new Controller and a Context, which is cancelled on any
|
||||
// fatal failure.
|
||||
func WithContext(ctx context.Context, t TestReporter) (*Controller, context.Context) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return NewController(&cancelReporter{t, cancel}), ctx
|
||||
}
|
||||
|
||||
func (ctrl *Controller) RecordCall(receiver interface{}, method string, args ...interface{}) *Call {
|
||||
if h, ok := ctrl.t.(testHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
|
||||
recv := reflect.ValueOf(receiver)
|
||||
for i := 0; i < recv.Type().NumMethod(); i++ {
|
||||
if recv.Type().Method(i).Name == method {
|
||||
return ctrl.RecordCallWithMethodType(receiver, method, recv.Method(i).Type(), args...)
|
||||
}
|
||||
}
|
||||
ctrl.t.Fatalf("gomock: failed finding method %s on %T", method, receiver)
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {
|
||||
if h, ok := ctrl.t.(testHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
|
||||
call := newCall(ctrl.t, receiver, method, methodType, args...)
|
||||
|
||||
ctrl.mu.Lock()
|
||||
defer ctrl.mu.Unlock()
|
||||
ctrl.expectedCalls.Add(call)
|
||||
|
||||
return call
|
||||
}
|
||||
|
||||
func (ctrl *Controller) Call(receiver interface{}, method string, args ...interface{}) []interface{} {
|
||||
if h, ok := ctrl.t.(testHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
|
||||
// Nest this code so we can use defer to make sure the lock is released.
|
||||
actions := func() []func([]interface{}) []interface{} {
|
||||
ctrl.mu.Lock()
|
||||
defer ctrl.mu.Unlock()
|
||||
|
||||
expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args)
|
||||
if err != nil {
|
||||
origin := callerInfo(2)
|
||||
ctrl.t.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, args, origin, err)
|
||||
}
|
||||
|
||||
// Two things happen here:
|
||||
// * the matching call no longer needs to check prerequite calls,
|
||||
// * and the prerequite calls are no longer expected, so remove them.
|
||||
preReqCalls := expected.dropPrereqs()
|
||||
for _, preReqCall := range preReqCalls {
|
||||
ctrl.expectedCalls.Remove(preReqCall)
|
||||
}
|
||||
|
||||
actions := expected.call(args)
|
||||
if expected.exhausted() {
|
||||
ctrl.expectedCalls.Remove(expected)
|
||||
}
|
||||
return actions
|
||||
}()
|
||||
|
||||
var rets []interface{}
|
||||
for _, action := range actions {
|
||||
if r := action(args); r != nil {
|
||||
rets = r
|
||||
}
|
||||
}
|
||||
|
||||
return rets
|
||||
}
|
||||
|
||||
func (ctrl *Controller) Finish() {
|
||||
if h, ok := ctrl.t.(testHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
|
||||
ctrl.mu.Lock()
|
||||
defer ctrl.mu.Unlock()
|
||||
|
||||
if ctrl.finished {
|
||||
ctrl.t.Fatalf("Controller.Finish was called more than once. It has to be called exactly once.")
|
||||
}
|
||||
ctrl.finished = true
|
||||
|
||||
// If we're currently panicking, probably because this is a deferred call,
|
||||
// pass through the panic.
|
||||
if err := recover(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Check that all remaining expected calls are satisfied.
|
||||
failures := ctrl.expectedCalls.Failures()
|
||||
for _, call := range failures {
|
||||
ctrl.t.Errorf("missing call(s) to %v", call)
|
||||
}
|
||||
if len(failures) != 0 {
|
||||
ctrl.t.Fatalf("aborting test due to missing call(s)")
|
||||
}
|
||||
}
|
||||
|
||||
func callerInfo(skip int) string {
|
||||
if _, file, line, ok := runtime.Caller(skip + 1); ok {
|
||||
return fmt.Sprintf("%s:%d", file, line)
|
||||
}
|
||||
return "unknown file"
|
||||
}
|
||||
|
||||
type testHelper interface {
|
||||
TestReporter
|
||||
Helper()
|
||||
}
|
124
vendor/github.com/golang/mock/gomock/matchers.go
generated
vendored
Normal file
124
vendor/github.com/golang/mock/gomock/matchers.go
generated
vendored
Normal file
|
@ -0,0 +1,124 @@
|
|||
//go:generate mockgen -destination mock_matcher/mock_matcher.go github.com/golang/mock/gomock Matcher
|
||||
|
||||
// Copyright 2010 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package gomock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// A Matcher is a representation of a class of values.
|
||||
// It is used to represent the valid or expected arguments to a mocked method.
|
||||
type Matcher interface {
|
||||
// Matches returns whether x is a match.
|
||||
Matches(x interface{}) bool
|
||||
|
||||
// String describes what the matcher matches.
|
||||
String() string
|
||||
}
|
||||
|
||||
type anyMatcher struct{}
|
||||
|
||||
func (anyMatcher) Matches(x interface{}) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (anyMatcher) String() string {
|
||||
return "is anything"
|
||||
}
|
||||
|
||||
type eqMatcher struct {
|
||||
x interface{}
|
||||
}
|
||||
|
||||
func (e eqMatcher) Matches(x interface{}) bool {
|
||||
return reflect.DeepEqual(e.x, x)
|
||||
}
|
||||
|
||||
func (e eqMatcher) String() string {
|
||||
return fmt.Sprintf("is equal to %v", e.x)
|
||||
}
|
||||
|
||||
type nilMatcher struct{}
|
||||
|
||||
func (nilMatcher) Matches(x interface{}) bool {
|
||||
if x == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(x)
|
||||
switch v.Kind() {
|
||||
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
|
||||
reflect.Ptr, reflect.Slice:
|
||||
return v.IsNil()
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (nilMatcher) String() string {
|
||||
return "is nil"
|
||||
}
|
||||
|
||||
type notMatcher struct {
|
||||
m Matcher
|
||||
}
|
||||
|
||||
func (n notMatcher) Matches(x interface{}) bool {
|
||||
return !n.m.Matches(x)
|
||||
}
|
||||
|
||||
func (n notMatcher) String() string {
|
||||
// TODO: Improve this if we add a NotString method to the Matcher interface.
|
||||
return "not(" + n.m.String() + ")"
|
||||
}
|
||||
|
||||
type assignableToTypeOfMatcher struct {
|
||||
targetType reflect.Type
|
||||
}
|
||||
|
||||
func (m assignableToTypeOfMatcher) Matches(x interface{}) bool {
|
||||
return reflect.TypeOf(x).AssignableTo(m.targetType)
|
||||
}
|
||||
|
||||
func (m assignableToTypeOfMatcher) String() string {
|
||||
return "is assignable to " + m.targetType.Name()
|
||||
}
|
||||
|
||||
// Constructors
|
||||
func Any() Matcher { return anyMatcher{} }
|
||||
func Eq(x interface{}) Matcher { return eqMatcher{x} }
|
||||
func Nil() Matcher { return nilMatcher{} }
|
||||
func Not(x interface{}) Matcher {
|
||||
if m, ok := x.(Matcher); ok {
|
||||
return notMatcher{m}
|
||||
}
|
||||
return notMatcher{Eq(x)}
|
||||
}
|
||||
|
||||
// AssignableToTypeOf is a Matcher that matches if the parameter to the mock
|
||||
// function is assignable to the type of the parameter to this function.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// dbMock.EXPECT().
|
||||
// Insert(gomock.AssignableToTypeOf(&EmployeeRecord{})).
|
||||
// Return(errors.New("DB error"))
|
||||
//
|
||||
func AssignableToTypeOf(x interface{}) Matcher {
|
||||
return assignableToTypeOfMatcher{reflect.TypeOf(x)}
|
||||
}
|
57
vendor/github.com/golang/mock/gomock/mock_matcher/mock_matcher.go
generated
vendored
Normal file
57
vendor/github.com/golang/mock/gomock/mock_matcher/mock_matcher.go
generated
vendored
Normal file
|
@ -0,0 +1,57 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/golang/mock/gomock (interfaces: Matcher)
|
||||
|
||||
// Package mock_gomock is a generated GoMock package.
|
||||
package mock_gomock
|
||||
|
||||
import (
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockMatcher is a mock of Matcher interface
|
||||
type MockMatcher struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockMatcherMockRecorder
|
||||
}
|
||||
|
||||
// MockMatcherMockRecorder is the mock recorder for MockMatcher
|
||||
type MockMatcherMockRecorder struct {
|
||||
mock *MockMatcher
|
||||
}
|
||||
|
||||
// NewMockMatcher creates a new mock instance
|
||||
func NewMockMatcher(ctrl *gomock.Controller) *MockMatcher {
|
||||
mock := &MockMatcher{ctrl: ctrl}
|
||||
mock.recorder = &MockMatcherMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockMatcher) EXPECT() *MockMatcherMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Matches mocks base method
|
||||
func (m *MockMatcher) Matches(arg0 interface{}) bool {
|
||||
ret := m.ctrl.Call(m, "Matches", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Matches indicates an expected call of Matches
|
||||
func (mr *MockMatcherMockRecorder) Matches(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Matches", reflect.TypeOf((*MockMatcher)(nil).Matches), arg0)
|
||||
}
|
||||
|
||||
// String mocks base method
|
||||
func (m *MockMatcher) String() string {
|
||||
ret := m.ctrl.Call(m, "String")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// String indicates an expected call of String
|
||||
func (mr *MockMatcherMockRecorder) String() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*MockMatcher)(nil).String))
|
||||
}
|
362
vendor/github.com/hashicorp/golang-lru/LICENSE
generated
vendored
362
vendor/github.com/hashicorp/golang-lru/LICENSE
generated
vendored
|
@ -1,362 +0,0 @@
|
|||
Mozilla Public License, version 2.0
|
||||
|
||||
1. Definitions
|
||||
|
||||
1.1. "Contributor"
|
||||
|
||||
means each individual or legal entity that creates, contributes to the
|
||||
creation of, or owns Covered Software.
|
||||
|
||||
1.2. "Contributor Version"
|
||||
|
||||
means the combination of the Contributions of others (if any) used by a
|
||||
Contributor and that particular Contributor's Contribution.
|
||||
|
||||
1.3. "Contribution"
|
||||
|
||||
means Covered Software of a particular Contributor.
|
||||
|
||||
1.4. "Covered Software"
|
||||
|
||||
means Source Code Form to which the initial Contributor has attached the
|
||||
notice in Exhibit A, the Executable Form of such Source Code Form, and
|
||||
Modifications of such Source Code Form, in each case including portions
|
||||
thereof.
|
||||
|
||||
1.5. "Incompatible With Secondary Licenses"
|
||||
means
|
||||
|
||||
a. that the initial Contributor has attached the notice described in
|
||||
Exhibit B to the Covered Software; or
|
||||
|
||||
b. that the Covered Software was made available under the terms of
|
||||
version 1.1 or earlier of the License, but not also under the terms of
|
||||
a Secondary License.
|
||||
|
||||
1.6. "Executable Form"
|
||||
|
||||
means any form of the work other than Source Code Form.
|
||||
|
||||
1.7. "Larger Work"
|
||||
|
||||
means a work that combines Covered Software with other material, in a
|
||||
separate file or files, that is not Covered Software.
|
||||
|
||||
1.8. "License"
|
||||
|
||||
means this document.
|
||||
|
||||
1.9. "Licensable"
|
||||
|
||||
means having the right to grant, to the maximum extent possible, whether
|
||||
at the time of the initial grant or subsequently, any and all of the
|
||||
rights conveyed by this License.
|
||||
|
||||
1.10. "Modifications"
|
||||
|
||||
means any of the following:
|
||||
|
||||
a. any file in Source Code Form that results from an addition to,
|
||||
deletion from, or modification of the contents of Covered Software; or
|
||||
|
||||
b. any new file in Source Code Form that contains any Covered Software.
|
||||
|
||||
1.11. "Patent Claims" of a Contributor
|
||||
|
||||
means any patent claim(s), including without limitation, method,
|
||||
process, and apparatus claims, in any patent Licensable by such
|
||||
Contributor that would be infringed, but for the grant of the License,
|
||||
by the making, using, selling, offering for sale, having made, import,
|
||||
or transfer of either its Contributions or its Contributor Version.
|
||||
|
||||
1.12. "Secondary License"
|
||||
|
||||
means either the GNU General Public License, Version 2.0, the GNU Lesser
|
||||
General Public License, Version 2.1, the GNU Affero General Public
|
||||
License, Version 3.0, or any later versions of those licenses.
|
||||
|
||||
1.13. "Source Code Form"
|
||||
|
||||
means the form of the work preferred for making modifications.
|
||||
|
||||
1.14. "You" (or "Your")
|
||||
|
||||
means an individual or a legal entity exercising rights under this
|
||||
License. For legal entities, "You" includes any entity that controls, is
|
||||
controlled by, or is under common control with You. For purposes of this
|
||||
definition, "control" means (a) the power, direct or indirect, to cause
|
||||
the direction or management of such entity, whether by contract or
|
||||
otherwise, or (b) ownership of more than fifty percent (50%) of the
|
||||
outstanding shares or beneficial ownership of such entity.
|
||||
|
||||
|
||||
2. License Grants and Conditions
|
||||
|
||||
2.1. Grants
|
||||
|
||||
Each Contributor hereby grants You a world-wide, royalty-free,
|
||||
non-exclusive license:
|
||||
|
||||
a. under intellectual property rights (other than patent or trademark)
|
||||
Licensable by such Contributor to use, reproduce, make available,
|
||||
modify, display, perform, distribute, and otherwise exploit its
|
||||
Contributions, either on an unmodified basis, with Modifications, or
|
||||
as part of a Larger Work; and
|
||||
|
||||
b. under Patent Claims of such Contributor to make, use, sell, offer for
|
||||
sale, have made, import, and otherwise transfer either its
|
||||
Contributions or its Contributor Version.
|
||||
|
||||
2.2. Effective Date
|
||||
|
||||
The licenses granted in Section 2.1 with respect to any Contribution
|
||||
become effective for each Contribution on the date the Contributor first
|
||||
distributes such Contribution.
|
||||
|
||||
2.3. Limitations on Grant Scope
|
||||
|
||||
The licenses granted in this Section 2 are the only rights granted under
|
||||
this License. No additional rights or licenses will be implied from the
|
||||
distribution or licensing of Covered Software under this License.
|
||||
Notwithstanding Section 2.1(b) above, no patent license is granted by a
|
||||
Contributor:
|
||||
|
||||
a. for any code that a Contributor has removed from Covered Software; or
|
||||
|
||||
b. for infringements caused by: (i) Your and any other third party's
|
||||
modifications of Covered Software, or (ii) the combination of its
|
||||
Contributions with other software (except as part of its Contributor
|
||||
Version); or
|
||||
|
||||
c. under Patent Claims infringed by Covered Software in the absence of
|
||||
its Contributions.
|
||||
|
||||
This License does not grant any rights in the trademarks, service marks,
|
||||
or logos of any Contributor (except as may be necessary to comply with
|
||||
the notice requirements in Section 3.4).
|
||||
|
||||
2.4. Subsequent Licenses
|
||||
|
||||
No Contributor makes additional grants as a result of Your choice to
|
||||
distribute the Covered Software under a subsequent version of this
|
||||
License (see Section 10.2) or under the terms of a Secondary License (if
|
||||
permitted under the terms of Section 3.3).
|
||||
|
||||
2.5. Representation
|
||||
|
||||
Each Contributor represents that the Contributor believes its
|
||||
Contributions are its original creation(s) or it has sufficient rights to
|
||||
grant the rights to its Contributions conveyed by this License.
|
||||
|
||||
2.6. Fair Use
|
||||
|
||||
This License is not intended to limit any rights You have under
|
||||
applicable copyright doctrines of fair use, fair dealing, or other
|
||||
equivalents.
|
||||
|
||||
2.7. Conditions
|
||||
|
||||
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in
|
||||
Section 2.1.
|
||||
|
||||
|
||||
3. Responsibilities
|
||||
|
||||
3.1. Distribution of Source Form
|
||||
|
||||
All distribution of Covered Software in Source Code Form, including any
|
||||
Modifications that You create or to which You contribute, must be under
|
||||
the terms of this License. You must inform recipients that the Source
|
||||
Code Form of the Covered Software is governed by the terms of this
|
||||
License, and how they can obtain a copy of this License. You may not
|
||||
attempt to alter or restrict the recipients' rights in the Source Code
|
||||
Form.
|
||||
|
||||
3.2. Distribution of Executable Form
|
||||
|
||||
If You distribute Covered Software in Executable Form then:
|
||||
|
||||
a. such Covered Software must also be made available in Source Code Form,
|
||||
as described in Section 3.1, and You must inform recipients of the
|
||||
Executable Form how they can obtain a copy of such Source Code Form by
|
||||
reasonable means in a timely manner, at a charge no more than the cost
|
||||
of distribution to the recipient; and
|
||||
|
||||
b. You may distribute such Executable Form under the terms of this
|
||||
License, or sublicense it under different terms, provided that the
|
||||
license for the Executable Form does not attempt to limit or alter the
|
||||
recipients' rights in the Source Code Form under this License.
|
||||
|
||||
3.3. Distribution of a Larger Work
|
||||
|
||||
You may create and distribute a Larger Work under terms of Your choice,
|
||||
provided that You also comply with the requirements of this License for
|
||||
the Covered Software. If the Larger Work is a combination of Covered
|
||||
Software with a work governed by one or more Secondary Licenses, and the
|
||||
Covered Software is not Incompatible With Secondary Licenses, this
|
||||
License permits You to additionally distribute such Covered Software
|
||||
under the terms of such Secondary License(s), so that the recipient of
|
||||
the Larger Work may, at their option, further distribute the Covered
|
||||
Software under the terms of either this License or such Secondary
|
||||
License(s).
|
||||
|
||||
3.4. Notices
|
||||
|
||||
You may not remove or alter the substance of any license notices
|
||||
(including copyright notices, patent notices, disclaimers of warranty, or
|
||||
limitations of liability) contained within the Source Code Form of the
|
||||
Covered Software, except that You may alter any license notices to the
|
||||
extent required to remedy known factual inaccuracies.
|
||||
|
||||
3.5. Application of Additional Terms
|
||||
|
||||
You may choose to offer, and to charge a fee for, warranty, support,
|
||||
indemnity or liability obligations to one or more recipients of Covered
|
||||
Software. However, You may do so only on Your own behalf, and not on
|
||||
behalf of any Contributor. You must make it absolutely clear that any
|
||||
such warranty, support, indemnity, or liability obligation is offered by
|
||||
You alone, and You hereby agree to indemnify every Contributor for any
|
||||
liability incurred by such Contributor as a result of warranty, support,
|
||||
indemnity or liability terms You offer. You may include additional
|
||||
disclaimers of warranty and limitations of liability specific to any
|
||||
jurisdiction.
|
||||
|
||||
4. Inability to Comply Due to Statute or Regulation
|
||||
|
||||
If it is impossible for You to comply with any of the terms of this License
|
||||
with respect to some or all of the Covered Software due to statute,
|
||||
judicial order, or regulation then You must: (a) comply with the terms of
|
||||
this License to the maximum extent possible; and (b) describe the
|
||||
limitations and the code they affect. Such description must be placed in a
|
||||
text file included with all distributions of the Covered Software under
|
||||
this License. Except to the extent prohibited by statute or regulation,
|
||||
such description must be sufficiently detailed for a recipient of ordinary
|
||||
skill to be able to understand it.
|
||||
|
||||
5. Termination
|
||||
|
||||
5.1. The rights granted under this License will terminate automatically if You
|
||||
fail to comply with any of its terms. However, if You become compliant,
|
||||
then the rights granted under this License from a particular Contributor
|
||||
are reinstated (a) provisionally, unless and until such Contributor
|
||||
explicitly and finally terminates Your grants, and (b) on an ongoing
|
||||
basis, if such Contributor fails to notify You of the non-compliance by
|
||||
some reasonable means prior to 60 days after You have come back into
|
||||
compliance. Moreover, Your grants from a particular Contributor are
|
||||
reinstated on an ongoing basis if such Contributor notifies You of the
|
||||
non-compliance by some reasonable means, this is the first time You have
|
||||
received notice of non-compliance with this License from such
|
||||
Contributor, and You become compliant prior to 30 days after Your receipt
|
||||
of the notice.
|
||||
|
||||
5.2. If You initiate litigation against any entity by asserting a patent
|
||||
infringement claim (excluding declaratory judgment actions,
|
||||
counter-claims, and cross-claims) alleging that a Contributor Version
|
||||
directly or indirectly infringes any patent, then the rights granted to
|
||||
You by any and all Contributors for the Covered Software under Section
|
||||
2.1 of this License shall terminate.
|
||||
|
||||
5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user
|
||||
license agreements (excluding distributors and resellers) which have been
|
||||
validly granted by You or Your distributors under this License prior to
|
||||
termination shall survive termination.
|
||||
|
||||
6. Disclaimer of Warranty
|
||||
|
||||
Covered Software is provided under this License on an "as is" basis,
|
||||
without warranty of any kind, either expressed, implied, or statutory,
|
||||
including, without limitation, warranties that the Covered Software is free
|
||||
of defects, merchantable, fit for a particular purpose or non-infringing.
|
||||
The entire risk as to the quality and performance of the Covered Software
|
||||
is with You. Should any Covered Software prove defective in any respect,
|
||||
You (not any Contributor) assume the cost of any necessary servicing,
|
||||
repair, or correction. This disclaimer of warranty constitutes an essential
|
||||
part of this License. No use of any Covered Software is authorized under
|
||||
this License except under this disclaimer.
|
||||
|
||||
7. Limitation of Liability
|
||||
|
||||
Under no circumstances and under no legal theory, whether tort (including
|
||||
negligence), contract, or otherwise, shall any Contributor, or anyone who
|
||||
distributes Covered Software as permitted above, be liable to You for any
|
||||
direct, indirect, special, incidental, or consequential damages of any
|
||||
character including, without limitation, damages for lost profits, loss of
|
||||
goodwill, work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses, even if such party shall have been
|
||||
informed of the possibility of such damages. This limitation of liability
|
||||
shall not apply to liability for death or personal injury resulting from
|
||||
such party's negligence to the extent applicable law prohibits such
|
||||
limitation. Some jurisdictions do not allow the exclusion or limitation of
|
||||
incidental or consequential damages, so this exclusion and limitation may
|
||||
not apply to You.
|
||||
|
||||
8. Litigation
|
||||
|
||||
Any litigation relating to this License may be brought only in the courts
|
||||
of a jurisdiction where the defendant maintains its principal place of
|
||||
business and such litigation shall be governed by laws of that
|
||||
jurisdiction, without reference to its conflict-of-law provisions. Nothing
|
||||
in this Section shall prevent a party's ability to bring cross-claims or
|
||||
counter-claims.
|
||||
|
||||
9. Miscellaneous
|
||||
|
||||
This License represents the complete agreement concerning the subject
|
||||
matter hereof. If any provision of this License is held to be
|
||||
unenforceable, such provision shall be reformed only to the extent
|
||||
necessary to make it enforceable. Any law or regulation which provides that
|
||||
the language of a contract shall be construed against the drafter shall not
|
||||
be used to construe this License against a Contributor.
|
||||
|
||||
|
||||
10. Versions of the License
|
||||
|
||||
10.1. New Versions
|
||||
|
||||
Mozilla Foundation is the license steward. Except as provided in Section
|
||||
10.3, no one other than the license steward has the right to modify or
|
||||
publish new versions of this License. Each version will be given a
|
||||
distinguishing version number.
|
||||
|
||||
10.2. Effect of New Versions
|
||||
|
||||
You may distribute the Covered Software under the terms of the version
|
||||
of the License under which You originally received the Covered Software,
|
||||
or under the terms of any subsequent version published by the license
|
||||
steward.
|
||||
|
||||
10.3. Modified Versions
|
||||
|
||||
If you create software not governed by this License, and you want to
|
||||
create a new license for such software, you may create and use a
|
||||
modified version of this License if you rename the license and remove
|
||||
any references to the name of the license steward (except to note that
|
||||
such modified license differs from this License).
|
||||
|
||||
10.4. Distributing Source Code Form that is Incompatible With Secondary
|
||||
Licenses If You choose to distribute Source Code Form that is
|
||||
Incompatible With Secondary Licenses under the terms of this version of
|
||||
the License, the notice described in Exhibit B of this License must be
|
||||
attached.
|
||||
|
||||
Exhibit A - Source Code Form License Notice
|
||||
|
||||
This Source Code Form is subject to the
|
||||
terms of the Mozilla Public License, v.
|
||||
2.0. If a copy of the MPL was not
|
||||
distributed with this file, You can
|
||||
obtain one at
|
||||
http://mozilla.org/MPL/2.0/.
|
||||
|
||||
If it is not possible or desirable to put the notice in a particular file,
|
||||
then You may include the notice in a location (such as a LICENSE file in a
|
||||
relevant directory) where a recipient would be likely to look for such a
|
||||
notice.
|
||||
|
||||
You may add additional accurate notices of copyright ownership.
|
||||
|
||||
Exhibit B - "Incompatible With Secondary Licenses" Notice
|
||||
|
||||
This Source Code Form is "Incompatible
|
||||
With Secondary Licenses", as defined by
|
||||
the Mozilla Public License, v. 2.0.
|
21
vendor/github.com/lucas-clemente/aes12/LICENSE
generated
vendored
21
vendor/github.com/lucas-clemente/aes12/LICENSE
generated
vendored
|
@ -1,21 +0,0 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2016 Lucas Clemente
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
21
vendor/github.com/lucas-clemente/fnv128a/LICENSE
generated
vendored
21
vendor/github.com/lucas-clemente/fnv128a/LICENSE
generated
vendored
|
@ -1,21 +0,0 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2016 Lucas Clemente
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
87
vendor/github.com/lucas-clemente/fnv128a/fnv128a.go
generated
vendored
87
vendor/github.com/lucas-clemente/fnv128a/fnv128a.go
generated
vendored
|
@ -1,87 +0,0 @@
|
|||
// Package fnv128a implements FNV-1 and FNV-1a, non-cryptographic hash functions
|
||||
// created by Glenn Fowler, Landon Curt Noll, and Phong Vo.
|
||||
// See https://en.wikipedia.org/wiki/Fowler-Noll-Vo_hash_function.
|
||||
//
|
||||
// Write() algorithm taken and modified from github.com/romain-jacotin/quic
|
||||
package fnv128a
|
||||
|
||||
import "hash"
|
||||
|
||||
// Hash128 is the common interface implemented by all 128-bit hash functions.
|
||||
type Hash128 interface {
|
||||
hash.Hash
|
||||
Sum128() (uint64, uint64)
|
||||
}
|
||||
|
||||
type sum128a struct {
|
||||
v0, v1, v2, v3 uint64
|
||||
}
|
||||
|
||||
var _ Hash128 = &sum128a{}
|
||||
|
||||
// New1 returns a new 128-bit FNV-1a hash.Hash.
|
||||
func New() Hash128 {
|
||||
s := &sum128a{}
|
||||
s.Reset()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *sum128a) Reset() {
|
||||
s.v0 = 0x6295C58D
|
||||
s.v1 = 0x62B82175
|
||||
s.v2 = 0x07BB0142
|
||||
s.v3 = 0x6C62272E
|
||||
}
|
||||
|
||||
func (s *sum128a) Sum128() (uint64, uint64) {
|
||||
return s.v3<<32 | s.v2, s.v1<<32 | s.v0
|
||||
}
|
||||
|
||||
func (s *sum128a) Write(data []byte) (int, error) {
|
||||
var t0, t1, t2, t3 uint64
|
||||
const fnv128PrimeLow = 0x0000013B
|
||||
const fnv128PrimeShift = 24
|
||||
|
||||
for _, v := range data {
|
||||
// xor the bottom with the current octet
|
||||
s.v0 ^= uint64(v)
|
||||
|
||||
// multiply by the 128 bit FNV magic prime mod 2^128
|
||||
// fnv_prime = 309485009821345068724781371 (decimal)
|
||||
// = 0x0000000001000000000000000000013B (hexadecimal)
|
||||
// = 0x00000000 0x01000000 0x00000000 0x0000013B (in 4*32 words)
|
||||
// = 0x0 1<<fnv128PrimeShift 0x0 fnv128PrimeLow
|
||||
//
|
||||
// fnv128PrimeLow = 0x0000013B
|
||||
// fnv128PrimeShift = 24
|
||||
|
||||
// multiply by the lowest order digit base 2^32 and by the other non-zero digit
|
||||
t0 = s.v0 * fnv128PrimeLow
|
||||
t1 = s.v1 * fnv128PrimeLow
|
||||
t2 = s.v2*fnv128PrimeLow + s.v0<<fnv128PrimeShift
|
||||
t3 = s.v3*fnv128PrimeLow + s.v1<<fnv128PrimeShift
|
||||
|
||||
// propagate carries
|
||||
t1 += (t0 >> 32)
|
||||
t2 += (t1 >> 32)
|
||||
t3 += (t2 >> 32)
|
||||
|
||||
s.v0 = t0 & 0xffffffff
|
||||
s.v1 = t1 & 0xffffffff
|
||||
s.v2 = t2 & 0xffffffff
|
||||
s.v3 = t3 // & 0xffffffff
|
||||
// Doing a s.v3 &= 0xffffffff is not really needed since it simply
|
||||
// removes multiples of 2^128. We can discard these excess bits
|
||||
// outside of the loop when writing the hash in Little Endian.
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (s *sum128a) Size() int { return 16 }
|
||||
|
||||
func (s *sum128a) BlockSize() int { return 1 }
|
||||
|
||||
func (s *sum128a) Sum(in []byte) []byte {
|
||||
panic("FNV: not supported")
|
||||
}
|
21
vendor/github.com/lucas-clemente/quic-go-certificates/LICENSE
generated
vendored
21
vendor/github.com/lucas-clemente/quic-go-certificates/LICENSE
generated
vendored
|
@ -1,21 +0,0 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2016 Lucas Clemente
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
504
vendor/github.com/lucas-clemente/quic-go/client.go
generated
vendored
504
vendor/github.com/lucas-clemente/quic-go/client.go
generated
vendored
|
@ -2,14 +2,14 @@ package quic
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
@ -20,37 +20,68 @@ import (
|
|||
type client struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
conn connection
|
||||
conn connection
|
||||
// If the client is created with DialAddr, we create a packet conn.
|
||||
// If it is started with Dial, we take a packet conn as a parameter.
|
||||
createdPacketConn bool
|
||||
|
||||
hostname string
|
||||
|
||||
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
|
||||
versionNegotiated bool // has the server accepted our version
|
||||
packetHandlers packetHandlerManager
|
||||
|
||||
token []byte
|
||||
numRetries int
|
||||
|
||||
versionNegotiated bool // has the server accepted our version
|
||||
receivedVersionNegotiationPacket bool
|
||||
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
tls handshake.MintTLS // only used when using TLS
|
||||
tlsConf *tls.Config
|
||||
mintConf *mint.Config
|
||||
config *Config
|
||||
|
||||
connectionID protocol.ConnectionID
|
||||
srcConnID protocol.ConnectionID
|
||||
destConnID protocol.ConnectionID
|
||||
|
||||
initialVersion protocol.VersionNumber
|
||||
version protocol.VersionNumber
|
||||
|
||||
session packetHandler
|
||||
handshakeChan chan struct{}
|
||||
closeCallback func(protocol.ConnectionID)
|
||||
|
||||
session quicSession
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ packetHandler = &client{}
|
||||
|
||||
var (
|
||||
// make it possible to mock connection ID generation in the tests
|
||||
generateConnectionID = utils.GenerateConnectionID
|
||||
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
||||
generateConnectionID = protocol.GenerateConnectionID
|
||||
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
|
||||
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
||||
errCloseSessionForRetry = errors.New("closing session in response to a stateless retry")
|
||||
)
|
||||
|
||||
// DialAddr establishes a new QUIC connection to a server.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) {
|
||||
func DialAddr(
|
||||
addr string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
return DialAddrContext(context.Background(), addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialAddrContext establishes a new QUIC connection to a server using the provided context.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
func DialAddrContext(
|
||||
ctx context.Context,
|
||||
addr string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -59,7 +90,7 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Dial(udpConn, udpAddr, addr, tlsConf, config)
|
||||
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true)
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
|
@ -71,16 +102,69 @@ func Dial(
|
|||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
connID, err := generateConnectionID()
|
||||
return DialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
|
||||
// The host parameter is used for SNI.
|
||||
func DialContext(
|
||||
ctx context.Context,
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false)
|
||||
}
|
||||
|
||||
func dialContext(
|
||||
ctx context.Context,
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
createdPacketConn bool,
|
||||
) (Session, error) {
|
||||
config = populateClientConfig(config, createdPacketConn)
|
||||
if !createdPacketConn {
|
||||
for _, v := range config.Versions {
|
||||
if v == protocol.Version44 {
|
||||
return nil, errors.New("Cannot multiplex connections using gQUIC 44, see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/pE9NlLLjizE. Please disable gQUIC 44 in the quic.Config, or use DialAddr")
|
||||
}
|
||||
}
|
||||
}
|
||||
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, packetHandlers.Remove, createdPacketConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.packetHandlers = packetHandlers
|
||||
if err := c.dial(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.session, nil
|
||||
}
|
||||
|
||||
func newClient(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
config *Config,
|
||||
tlsConf *tls.Config,
|
||||
host string,
|
||||
closeCallback func(protocol.ConnectionID),
|
||||
createdPacketConn bool,
|
||||
) (*client, error) {
|
||||
var hostname string
|
||||
if tlsConf != nil {
|
||||
hostname = tlsConf.ServerName
|
||||
}
|
||||
if hostname == "" {
|
||||
var err error
|
||||
hostname, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -95,29 +179,27 @@ func Dial(
|
|||
}
|
||||
}
|
||||
}
|
||||
clientConfig := populateClientConfig(config)
|
||||
onClose := func(protocol.ConnectionID) {}
|
||||
if closeCallback != nil {
|
||||
onClose = closeCallback
|
||||
}
|
||||
c := &client{
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
connectionID: connID,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: clientConfig,
|
||||
version: clientConfig.Versions[0],
|
||||
versionNegotiationChan: make(chan struct{}),
|
||||
logger: utils.DefaultLogger,
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
createdPacketConn: createdPacketConn,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
version: config.Versions[0],
|
||||
handshakeChan: make(chan struct{}),
|
||||
closeCallback: onClose,
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}
|
||||
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||
|
||||
if err := c.dial(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.session, nil
|
||||
return c, c.generateConnectionIDs()
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateClientConfig(config *Config) *Config {
|
||||
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
|
@ -155,12 +237,22 @@ func populateClientConfig(config *Config) *Config {
|
|||
} else if maxIncomingUniStreams < 0 {
|
||||
maxIncomingUniStreams = 0
|
||||
}
|
||||
connIDLen := config.ConnectionIDLength
|
||||
if connIDLen == 0 && !createdPacketConn {
|
||||
connIDLen = protocol.DefaultConnectionIDLength
|
||||
}
|
||||
for _, v := range versions {
|
||||
if v == protocol.Version44 {
|
||||
connIDLen = 0
|
||||
}
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Versions: versions,
|
||||
HandshakeTimeout: handshakeTimeout,
|
||||
IdleTimeout: idleTimeout,
|
||||
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
|
||||
ConnectionIDLength: connIDLen,
|
||||
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
||||
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
||||
MaxIncomingStreams: maxIncomingStreams,
|
||||
|
@ -169,28 +261,54 @@ func populateClientConfig(config *Config) *Config {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *client) dial() error {
|
||||
func (c *client) generateConnectionIDs() error {
|
||||
connIDLen := protocol.ConnectionIDLenGQUIC
|
||||
if c.version.UsesTLS() {
|
||||
connIDLen = c.config.ConnectionIDLength
|
||||
}
|
||||
srcConnID, err := generateConnectionID(connIDLen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
destConnID := srcConnID
|
||||
if c.version.UsesTLS() {
|
||||
destConnID, err = generateConnectionIDForInitial()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.srcConnID = srcConnID
|
||||
c.destConnID = destConnID
|
||||
if c.version == protocol.Version44 {
|
||||
c.srcConnID = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) dial(ctx context.Context) error {
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
|
||||
var err error
|
||||
if c.version.UsesTLS() {
|
||||
err = c.dialTLS()
|
||||
err = c.dialTLS(ctx)
|
||||
} else {
|
||||
err = c.dialGQUIC()
|
||||
}
|
||||
if err == errCloseSessionForNewVersion {
|
||||
return c.dial()
|
||||
err = c.dialGQUIC(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) dialGQUIC() error {
|
||||
func (c *client) dialGQUIC(ctx context.Context) error {
|
||||
if err := c.createNewGQUICSession(); err != nil {
|
||||
return err
|
||||
}
|
||||
go c.listen()
|
||||
return c.establishSecureConnection()
|
||||
err := c.establishSecureConnection(ctx)
|
||||
if err == errCloseSessionForNewVersion {
|
||||
return c.dial(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) dialTLS() error {
|
||||
func (c *client) dialTLS(ctx context.Context) error {
|
||||
params := &handshake.TransportParameters{
|
||||
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||
|
@ -198,8 +316,8 @@ func (c *client) dialTLS() error {
|
|||
OmitConnectionID: c.config.RequestConnectionIDOmission,
|
||||
MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
|
||||
MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
|
||||
DisableMigration: true,
|
||||
}
|
||||
csc := handshake.NewCryptoStreamConn(nil)
|
||||
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
|
||||
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
|
||||
if err != nil {
|
||||
|
@ -207,25 +325,16 @@ func (c *client) dialTLS() error {
|
|||
}
|
||||
mintConf.ExtensionHandler = extHandler
|
||||
mintConf.ServerName = c.hostname
|
||||
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
|
||||
c.mintConf = mintConf
|
||||
|
||||
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||
return err
|
||||
}
|
||||
go c.listen()
|
||||
if err := c.establishSecureConnection(); err != nil {
|
||||
if err != handshake.ErrCloseSessionForRetry {
|
||||
return err
|
||||
}
|
||||
c.logger.Infof("Received a Retry packet. Recreating session.")
|
||||
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.establishSecureConnection(); err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.establishSecureConnection(ctx)
|
||||
if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
|
||||
return c.dial(ctx)
|
||||
}
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// establishSecureConnection runs the session, and tries to establish a secure connection
|
||||
|
@ -234,134 +343,114 @@ func (c *client) dialTLS() error {
|
|||
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
|
||||
// - any other error that might occur
|
||||
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
|
||||
func (c *client) establishSecureConnection() error {
|
||||
var runErr error
|
||||
errorChan := make(chan struct{})
|
||||
func (c *client) establishSecureConnection(ctx context.Context) error {
|
||||
errorChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
runErr = c.session.run() // returns as soon as the session is closed
|
||||
close(errorChan)
|
||||
c.logger.Infof("Connection %x closed.", c.connectionID)
|
||||
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
|
||||
err := c.session.run() // returns as soon as the session is closed
|
||||
if err != errCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
|
||||
c.conn.Close()
|
||||
}
|
||||
errorChan <- err
|
||||
}()
|
||||
|
||||
// wait until the server accepts the QUIC version (or an error occurs)
|
||||
select {
|
||||
case <-errorChan:
|
||||
return runErr
|
||||
case <-c.versionNegotiationChan:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-errorChan:
|
||||
return runErr
|
||||
case err := <-c.session.handshakeStatus():
|
||||
case <-ctx.Done():
|
||||
// The session will send a PeerGoingAway error to the server.
|
||||
c.session.Close()
|
||||
return ctx.Err()
|
||||
case err := <-errorChan:
|
||||
return err
|
||||
case <-c.handshakeChan:
|
||||
// handshake successfully completed
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Listen listens on the underlying connection and passes packets on for handling.
|
||||
// It returns when the connection is closed.
|
||||
func (c *client) listen() {
|
||||
var err error
|
||||
|
||||
for {
|
||||
var n int
|
||||
var addr net.Addr
|
||||
data := *getPacketBuffer()
|
||||
data = data[:protocol.MaxReceivePacketSize]
|
||||
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
|
||||
// If it does, we only read a truncated packet, which will then end up undecryptable
|
||||
n, addr, err = c.conn.Read(data)
|
||||
if err != nil {
|
||||
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||
c.mutex.Lock()
|
||||
if c.session != nil {
|
||||
c.session.Close(err)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
break
|
||||
}
|
||||
c.handlePacket(addr, data[:n])
|
||||
func (c *client) handlePacket(p *receivedPacket) {
|
||||
if err := c.handlePacketImpl(p); err != nil {
|
||||
c.logger.Errorf("error handling packet: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
||||
rcvTime := time.Now()
|
||||
|
||||
r := bytes.NewReader(packet)
|
||||
hdr, err := wire.ParseHeaderSentByServer(r, c.version)
|
||||
if err != nil {
|
||||
c.logger.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
|
||||
// drop this packet if we can't parse the header
|
||||
return
|
||||
}
|
||||
// reject packets with truncated connection id if we didn't request truncation
|
||||
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
|
||||
return
|
||||
}
|
||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||
|
||||
func (c *client) handlePacketImpl(p *receivedPacket) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// reject packets with the wrong connection ID
|
||||
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.ResetFlag {
|
||||
cr := c.conn.RemoteAddr()
|
||||
// check if the remote address and the connection ID match
|
||||
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
|
||||
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID {
|
||||
c.logger.Infof("Received a spoofed Public Reset. Ignoring.")
|
||||
return
|
||||
}
|
||||
pr, err := wire.ParsePublicReset(r)
|
||||
if err != nil {
|
||||
c.logger.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
||||
return
|
||||
}
|
||||
c.logger.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
|
||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
||||
return
|
||||
}
|
||||
|
||||
// handle Version Negotiation Packets
|
||||
if hdr.IsVersionNegotiation {
|
||||
// ignore delayed / duplicated version negotiation packets
|
||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
||||
return
|
||||
if p.header.IsVersionNegotiation {
|
||||
err := c.handleVersionNegotiationPacket(p.header)
|
||||
if err != nil {
|
||||
c.session.destroy(err)
|
||||
}
|
||||
|
||||
// version negotiation packets have no payload
|
||||
if err := c.handleVersionNegotiationPacket(hdr); err != nil {
|
||||
c.session.Close(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !c.version.UsesIETFHeaderFormat() {
|
||||
connID := p.header.DestConnectionID
|
||||
// reject packets with truncated connection id if we didn't request truncation
|
||||
if !c.config.RequestConnectionIDOmission && connID.Len() == 0 {
|
||||
return errors.New("received packet with truncated connection ID, but didn't request truncation")
|
||||
}
|
||||
// reject packets with the wrong connection ID
|
||||
if connID.Len() > 0 && !connID.Equal(c.srcConnID) {
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", connID, c.srcConnID)
|
||||
}
|
||||
if p.header.ResetFlag {
|
||||
return c.handlePublicReset(p)
|
||||
}
|
||||
} else {
|
||||
// reject packets with the wrong connection ID
|
||||
if !p.header.DestConnectionID.Equal(c.srcConnID) {
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
|
||||
}
|
||||
}
|
||||
|
||||
if p.header.IsLongHeader {
|
||||
switch p.header.Type {
|
||||
case protocol.PacketTypeRetry:
|
||||
c.handleRetryPacket(p.header)
|
||||
return nil
|
||||
case protocol.PacketTypeHandshake, protocol.PacketType0RTT:
|
||||
default:
|
||||
return fmt.Errorf("Received unsupported packet type: %s", p.header.Type)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// this is the first packet we are receiving
|
||||
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
|
||||
if !c.versionNegotiated {
|
||||
c.versionNegotiated = true
|
||||
close(c.versionNegotiationChan)
|
||||
}
|
||||
|
||||
// TODO: validate packet number and connection ID on Retry packets (for IETF QUIC)
|
||||
c.session.handlePacket(p)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.session.handlePacket(&receivedPacket{
|
||||
remoteAddr: remoteAddr,
|
||||
header: hdr,
|
||||
data: packet[len(packet)-r.Len():],
|
||||
rcvTime: rcvTime,
|
||||
})
|
||||
func (c *client) handlePublicReset(p *receivedPacket) error {
|
||||
cr := c.conn.RemoteAddr()
|
||||
// check if the remote address and the connection ID match
|
||||
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
|
||||
if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !p.header.DestConnectionID.Equal(c.srcConnID) {
|
||||
return errors.New("Received a spoofed Public Reset")
|
||||
}
|
||||
pr, err := wire.ParsePublicReset(bytes.NewReader(p.data))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
||||
}
|
||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
||||
c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||
// ignore delayed / duplicated version negotiation packets
|
||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
||||
c.logger.Debugf("Received a delayed Version Negotiation Packet.")
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, v := range hdr.SupportedVersions {
|
||||
if v == c.version {
|
||||
// the version negotiation packet contains the version that we offered
|
||||
|
@ -372,7 +461,6 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
|||
}
|
||||
|
||||
c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
|
||||
|
||||
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
||||
if !ok {
|
||||
return qerr.InvalidVersion
|
||||
|
@ -383,49 +471,125 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
|||
// switch to negotiated version
|
||||
c.initialVersion = c.version
|
||||
c.version = newVersion
|
||||
var err error
|
||||
c.connectionID, err = utils.GenerateConnectionID()
|
||||
if err != nil {
|
||||
if err := c.generateConnectionIDs(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.logger.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
|
||||
c.session.Close(errCloseSessionForNewVersion)
|
||||
|
||||
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
|
||||
c.session.destroy(errCloseSessionForNewVersion)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) createNewGQUICSession() (err error) {
|
||||
func (c *client) handleRetryPacket(hdr *wire.Header) {
|
||||
c.logger.Debugf("<- Received Retry")
|
||||
hdr.Log(c.logger)
|
||||
// A server that performs multiple retries must use a source connection ID of at least 8 bytes.
|
||||
// Only a server that won't send additional Retries can use shorter connection IDs.
|
||||
if hdr.OrigDestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
||||
c.logger.Debugf("Received a Retry with a too short Original Destination Connection ID: %d bytes, must have at least %d bytes.", hdr.OrigDestConnectionID.Len(), protocol.MinConnectionIDLenInitial)
|
||||
return
|
||||
}
|
||||
if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
|
||||
c.logger.Debugf("Received spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
|
||||
return
|
||||
}
|
||||
c.numRetries++
|
||||
if c.numRetries > protocol.MaxRetries {
|
||||
c.session.destroy(qerr.CryptoTooManyRejects)
|
||||
return
|
||||
}
|
||||
c.destConnID = hdr.SrcConnectionID
|
||||
c.token = hdr.Token
|
||||
c.session.destroy(errCloseSessionForRetry)
|
||||
}
|
||||
|
||||
func (c *client) createNewGQUICSession() error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.session, err = newClientSession(
|
||||
runner := &runner{
|
||||
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
||||
removeConnectionIDImpl: c.closeCallback,
|
||||
}
|
||||
sess, err := newClientSession(
|
||||
c.conn,
|
||||
runner,
|
||||
c.hostname,
|
||||
c.version,
|
||||
c.connectionID,
|
||||
c.destConnID,
|
||||
c.srcConnID,
|
||||
c.tlsConf,
|
||||
c.config,
|
||||
c.initialVersion,
|
||||
c.negotiatedVersions,
|
||||
c.logger,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.session = sess
|
||||
c.packetHandlers.Add(c.srcConnID, c)
|
||||
if c.config.RequestConnectionIDOmission {
|
||||
c.packetHandlers.Add(protocol.ConnectionID{}, c)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) createNewTLSSession(
|
||||
paramsChan <-chan handshake.TransportParameters,
|
||||
version protocol.VersionNumber,
|
||||
) (err error) {
|
||||
) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.session, err = newTLSClientSession(
|
||||
runner := &runner{
|
||||
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
||||
removeConnectionIDImpl: c.closeCallback,
|
||||
}
|
||||
sess, err := newTLSClientSession(
|
||||
c.conn,
|
||||
c.hostname,
|
||||
c.version,
|
||||
c.connectionID,
|
||||
runner,
|
||||
c.token,
|
||||
c.destConnID,
|
||||
c.srcConnID,
|
||||
c.config,
|
||||
c.tls,
|
||||
c.mintConf,
|
||||
paramsChan,
|
||||
1,
|
||||
c.logger,
|
||||
c.version,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.session = sess
|
||||
c.packetHandlers.Add(c.srcConnID, c)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) Close() error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
if c.session == nil {
|
||||
return nil
|
||||
}
|
||||
return c.session.Close()
|
||||
}
|
||||
|
||||
func (c *client) destroy(e error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
if c.session == nil {
|
||||
return
|
||||
}
|
||||
c.session.destroy(e)
|
||||
}
|
||||
|
||||
func (c *client) GetVersion() protocol.VersionNumber {
|
||||
c.mutex.Lock()
|
||||
v := c.version
|
||||
c.mutex.Unlock()
|
||||
return v
|
||||
}
|
||||
|
||||
func (c *client) GetPerspective() protocol.Perspective {
|
||||
return protocol.PerspectiveClient
|
||||
}
|
||||
|
|
14
vendor/github.com/lucas-clemente/quic-go/crypto_stream.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/crypto_stream.go
generated
vendored
|
@ -8,7 +8,7 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type cryptoStreamI interface {
|
||||
type cryptoStream interface {
|
||||
StreamID() protocol.StreamID
|
||||
io.Reader
|
||||
io.Writer
|
||||
|
@ -21,21 +21,21 @@ type cryptoStreamI interface {
|
|||
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
|
||||
}
|
||||
|
||||
type cryptoStream struct {
|
||||
type cryptoStreamImpl struct {
|
||||
*stream
|
||||
}
|
||||
|
||||
var _ cryptoStreamI = &cryptoStream{}
|
||||
var _ cryptoStream = &cryptoStreamImpl{}
|
||||
|
||||
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI {
|
||||
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStream {
|
||||
str := newStream(version.CryptoStreamID(), sender, flowController, version)
|
||||
return &cryptoStream{str}
|
||||
return &cryptoStreamImpl{str}
|
||||
}
|
||||
|
||||
// SetReadOffset sets the read offset.
|
||||
// It is only needed for the crypto stream.
|
||||
// It must not be called concurrently with any other stream methods, especially Read and Write.
|
||||
func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
|
||||
func (s *cryptoStreamImpl) setReadOffset(offset protocol.ByteCount) {
|
||||
s.receiveStream.readOffset = offset
|
||||
s.receiveStream.frameQueue.readPosition = offset
|
||||
s.receiveStream.frameQueue.readPos = offset
|
||||
}
|
||||
|
|
|
@ -5,51 +5,55 @@ import (
|
|||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type streamFrameSorter struct {
|
||||
queuedFrames map[protocol.ByteCount]*wire.StreamFrame
|
||||
readPosition protocol.ByteCount
|
||||
gaps *utils.ByteIntervalList
|
||||
type frameSorter struct {
|
||||
queue map[protocol.ByteCount][]byte
|
||||
readPos protocol.ByteCount
|
||||
finalOffset protocol.ByteCount
|
||||
gaps *utils.ByteIntervalList
|
||||
}
|
||||
|
||||
var (
|
||||
errTooManyGapsInReceivedStreamData = errors.New("Too many gaps in received StreamFrame data")
|
||||
errDuplicateStreamData = errors.New("Duplicate Stream Data")
|
||||
errEmptyStreamData = errors.New("Stream Data empty")
|
||||
)
|
||||
var errDuplicateStreamData = errors.New("Duplicate Stream Data")
|
||||
|
||||
func newStreamFrameSorter() *streamFrameSorter {
|
||||
s := streamFrameSorter{
|
||||
gaps: utils.NewByteIntervalList(),
|
||||
queuedFrames: make(map[protocol.ByteCount]*wire.StreamFrame),
|
||||
func newFrameSorter() *frameSorter {
|
||||
s := frameSorter{
|
||||
gaps: utils.NewByteIntervalList(),
|
||||
queue: make(map[protocol.ByteCount][]byte),
|
||||
finalOffset: protocol.MaxByteCount,
|
||||
}
|
||||
s.gaps.PushFront(utils.ByteInterval{Start: 0, End: protocol.MaxByteCount})
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error {
|
||||
if frame.DataLen() == 0 {
|
||||
if frame.FinBit {
|
||||
s.queuedFrames[frame.Offset] = frame
|
||||
return nil
|
||||
}
|
||||
return errEmptyStreamData
|
||||
func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, fin bool) error {
|
||||
err := s.push(data, offset, fin)
|
||||
if err == errDuplicateStreamData {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *frameSorter) push(data []byte, offset protocol.ByteCount, fin bool) error {
|
||||
if fin {
|
||||
s.finalOffset = offset + protocol.ByteCount(len(data))
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var wasCut bool
|
||||
if oldFrame, ok := s.queuedFrames[frame.Offset]; ok {
|
||||
if frame.DataLen() <= oldFrame.DataLen() {
|
||||
if oldData, ok := s.queue[offset]; ok {
|
||||
if len(data) <= len(oldData) {
|
||||
return errDuplicateStreamData
|
||||
}
|
||||
frame.Data = frame.Data[oldFrame.DataLen():]
|
||||
frame.Offset += oldFrame.DataLen()
|
||||
data = data[len(oldData):]
|
||||
offset += protocol.ByteCount(len(oldData))
|
||||
wasCut = true
|
||||
}
|
||||
|
||||
start := frame.Offset
|
||||
end := frame.Offset + frame.DataLen()
|
||||
start := offset
|
||||
end := offset + protocol.ByteCount(len(data))
|
||||
|
||||
// skip all gaps that are before this stream frame
|
||||
var gap *utils.ByteIntervalElement
|
||||
|
@ -69,9 +73,9 @@ func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error {
|
|||
|
||||
if start < gap.Value.Start {
|
||||
add := gap.Value.Start - start
|
||||
frame.Offset += add
|
||||
offset += add
|
||||
start += add
|
||||
frame.Data = frame.Data[add:]
|
||||
data = data[add:]
|
||||
wasCut = true
|
||||
}
|
||||
|
||||
|
@ -89,15 +93,15 @@ func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error {
|
|||
break
|
||||
}
|
||||
// delete queued frames completely covered by the current frame
|
||||
delete(s.queuedFrames, endGap.Value.End)
|
||||
delete(s.queue, endGap.Value.End)
|
||||
endGap = nextEndGap
|
||||
}
|
||||
|
||||
if end > endGap.Value.End {
|
||||
cutLen := end - endGap.Value.End
|
||||
len := frame.DataLen() - cutLen
|
||||
len := protocol.ByteCount(len(data)) - cutLen
|
||||
end -= cutLen
|
||||
frame.Data = frame.Data[:len]
|
||||
data = data[:len]
|
||||
wasCut = true
|
||||
}
|
||||
|
||||
|
@ -130,32 +134,25 @@ func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error {
|
|||
}
|
||||
|
||||
if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps {
|
||||
return errTooManyGapsInReceivedStreamData
|
||||
return errors.New("Too many gaps in received data")
|
||||
}
|
||||
|
||||
if wasCut {
|
||||
data := make([]byte, frame.DataLen())
|
||||
copy(data, frame.Data)
|
||||
frame.Data = data
|
||||
newData := make([]byte, len(data))
|
||||
copy(newData, data)
|
||||
data = newData
|
||||
}
|
||||
|
||||
s.queuedFrames[frame.Offset] = frame
|
||||
s.queue[offset] = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *streamFrameSorter) Pop() *wire.StreamFrame {
|
||||
frame := s.Head()
|
||||
if frame != nil {
|
||||
s.readPosition += frame.DataLen()
|
||||
delete(s.queuedFrames, frame.Offset)
|
||||
func (s *frameSorter) Pop() ([]byte /* data */, bool /* fin */) {
|
||||
data, ok := s.queue[s.readPos]
|
||||
if !ok {
|
||||
return nil, s.readPos >= s.finalOffset
|
||||
}
|
||||
return frame
|
||||
}
|
||||
|
||||
func (s *streamFrameSorter) Head() *wire.StreamFrame {
|
||||
frame, ok := s.queuedFrames[s.readPosition]
|
||||
if ok {
|
||||
return frame
|
||||
}
|
||||
return nil
|
||||
delete(s.queue, s.readPos)
|
||||
s.readPos += protocol.ByteCount(len(data))
|
||||
return data, s.readPos >= s.finalOffset
|
||||
}
|
19
vendor/github.com/lucas-clemente/quic-go/h2quic/client.go
generated
vendored
19
vendor/github.com/lucas-clemente/quic-go/h2quic/client.go
generated
vendored
|
@ -77,7 +77,7 @@ func newClient(
|
|||
opts: opts,
|
||||
headerErrored: make(chan struct{}),
|
||||
dialer: dialer,
|
||||
logger: utils.DefaultLogger,
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -172,7 +172,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
responseChan := make(chan *http.Response)
|
||||
dataStream, err := c.session.OpenStreamSync()
|
||||
if err != nil {
|
||||
_ = c.CloseWithError(err)
|
||||
_ = c.closeWithError(err)
|
||||
return nil, err
|
||||
}
|
||||
c.mutex.Lock()
|
||||
|
@ -187,7 +187,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
endStream := !hasBody
|
||||
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
|
||||
if err != nil {
|
||||
_ = c.CloseWithError(err)
|
||||
_ = c.closeWithError(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -230,7 +230,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
return nil, ctx.Err()
|
||||
case <-c.headerErrored:
|
||||
// an error occurred on the header stream
|
||||
_ = c.CloseWithError(c.headerErr)
|
||||
_ = c.closeWithError(c.headerErr)
|
||||
return nil, c.headerErr
|
||||
}
|
||||
}
|
||||
|
@ -275,16 +275,19 @@ func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e
|
|||
return dataStream.Close()
|
||||
}
|
||||
|
||||
// Close closes the client
|
||||
func (c *client) CloseWithError(e error) error {
|
||||
func (c *client) closeWithError(e error) error {
|
||||
if c.session == nil {
|
||||
return nil
|
||||
}
|
||||
return c.session.Close(e)
|
||||
return c.session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
|
||||
}
|
||||
|
||||
// Close closes the client
|
||||
func (c *client) Close() error {
|
||||
return c.CloseWithError(nil)
|
||||
if c.session == nil {
|
||||
return nil
|
||||
}
|
||||
return c.session.Close()
|
||||
}
|
||||
|
||||
// copied from net/transport.go
|
||||
|
|
3
vendor/github.com/lucas-clemente/quic-go/h2quic/request.go
generated
vendored
3
vendor/github.com/lucas-clemente/quic-go/h2quic/request.go
generated
vendored
|
@ -70,9 +70,6 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
|
|||
}
|
||||
|
||||
func hostnameFromRequest(req *http.Request) string {
|
||||
if len(req.Host) > 0 {
|
||||
return req.Host
|
||||
}
|
||||
if req.URL != nil {
|
||||
return req.URL.Host
|
||||
}
|
||||
|
|
8
vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go
generated
vendored
8
vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go
generated
vendored
|
@ -8,9 +8,9 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
"golang.org/x/net/lex/httplex"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
@ -65,7 +65,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra
|
|||
if host == "" {
|
||||
host = req.URL.Host
|
||||
}
|
||||
host, err := httplex.PunycodeHostPort(host)
|
||||
host, err := httpguts.PunycodeHostPort(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -89,11 +89,11 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra
|
|||
// potentially pollute our hpack state. (We want to be able to
|
||||
// continue to reuse the hpack encoder for future requests)
|
||||
for k, vv := range req.Header {
|
||||
if !httplex.ValidHeaderFieldName(k) {
|
||||
if !httpguts.ValidHeaderFieldName(k) {
|
||||
return nil, fmt.Errorf("invalid HTTP header name %q", k)
|
||||
}
|
||||
for _, v := range vv {
|
||||
if !httplex.ValidHeaderFieldValue(v) {
|
||||
if !httpguts.ValidHeaderFieldValue(v) {
|
||||
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
|
||||
}
|
||||
}
|
||||
|
|
3
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go
generated
vendored
3
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go
generated
vendored
|
@ -98,9 +98,6 @@ func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
|
|||
// test that we implement http.Flusher
|
||||
var _ http.Flusher = &responseWriter{}
|
||||
|
||||
// test that we implement http.CloseNotifier
|
||||
var _ http.CloseNotifier = &responseWriter{}
|
||||
|
||||
// copied from http2/http2.go
|
||||
// bodyAllowedForStatus reports whether a given response status code
|
||||
// permits a body. See RFC 2616, section 4.4.
|
||||
|
|
9
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer_closenotifier.go
generated
vendored
Normal file
9
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer_closenotifier.go
generated
vendored
Normal file
|
@ -0,0 +1,9 @@
|
|||
package h2quic
|
||||
|
||||
import "net/http"
|
||||
|
||||
// The CloseNotifier is a deprecated interface, and staticcheck will report that from Go 1.11.
|
||||
// By defining it in a separate file, we can exclude this file from staticcheck.
|
||||
|
||||
// test that we implement http.CloseNotifier
|
||||
var _ http.CloseNotifier = &responseWriter{}
|
8
vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go
generated
vendored
8
vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go
generated
vendored
|
@ -11,7 +11,7 @@ import (
|
|||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
|
||||
"golang.org/x/net/lex/httplex"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
type roundTripCloser interface {
|
||||
|
@ -80,11 +80,11 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
|
|||
|
||||
if req.URL.Scheme == "https" {
|
||||
for k, vv := range req.Header {
|
||||
if !httplex.ValidHeaderFieldName(k) {
|
||||
if !httpguts.ValidHeaderFieldName(k) {
|
||||
return nil, fmt.Errorf("quic: invalid http header field name %q", k)
|
||||
}
|
||||
for _, v := range vv {
|
||||
if !httplex.ValidHeaderFieldValue(v) {
|
||||
if !httpguts.ValidHeaderFieldValue(v) {
|
||||
return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k)
|
||||
}
|
||||
}
|
||||
|
@ -175,5 +175,5 @@ func validMethod(method string) bool {
|
|||
|
||||
// copied from net/http/http.go
|
||||
func isNotToken(r rune) bool {
|
||||
return !httplex.IsTokenRune(r)
|
||||
return !httpguts.IsTokenRune(r)
|
||||
}
|
||||
|
|
24
vendor/github.com/lucas-clemente/quic-go/h2quic/server.go
generated
vendored
24
vendor/github.com/lucas-clemente/quic-go/h2quic/server.go
generated
vendored
|
@ -90,7 +90,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
|||
if s.Server == nil {
|
||||
return errors.New("use of h2quic.Server without http.Server")
|
||||
}
|
||||
s.logger = utils.DefaultLogger
|
||||
s.logger = utils.DefaultLogger.WithPrefix("server")
|
||||
s.listenerMutex.Lock()
|
||||
if s.closed {
|
||||
s.listenerMutex.Unlock()
|
||||
|
@ -127,7 +127,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
|||
func (s *Server) handleHeaderStream(session streamCreator) {
|
||||
stream, err := session.AcceptStream()
|
||||
if err != nil {
|
||||
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
|
||||
session.CloseWithError(quic.ErrorCode(qerr.InvalidHeadersStreamData), err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -140,10 +140,12 @@ func (s *Server) handleHeaderStream(session streamCreator) {
|
|||
// QuicErrors must originate from stream.Read() returning an error.
|
||||
// In this case, the session has already logged the error, so we don't
|
||||
// need to log it again.
|
||||
if _, ok := err.(*qerr.QuicError); !ok {
|
||||
errorCode := qerr.InternalError
|
||||
if qerr, ok := err.(*qerr.QuicError); !ok {
|
||||
errorCode = qerr.ErrorCode
|
||||
s.logger.Errorf("error handling h2 request: %s", err.Error())
|
||||
}
|
||||
session.Close(err)
|
||||
session.CloseWithError(quic.ErrorCode(errorCode), err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -154,10 +156,18 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||
if err != nil {
|
||||
return qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
|
||||
}
|
||||
h2headersFrame, ok := h2frame.(*http2.HeadersFrame)
|
||||
if !ok {
|
||||
var h2headersFrame *http2.HeadersFrame
|
||||
switch f := h2frame.(type) {
|
||||
case *http2.PriorityFrame:
|
||||
// ignore PRIORITY frames
|
||||
s.logger.Debugf("Ignoring H2 PRIORITY frame: %#v", f)
|
||||
return nil
|
||||
case *http2.HeadersFrame:
|
||||
h2headersFrame = f
|
||||
default:
|
||||
return qerr.Error(qerr.InvalidHeadersStreamData, "expected a header frame")
|
||||
}
|
||||
|
||||
if !h2headersFrame.HeadersEnded() {
|
||||
return errors.New("http2 header continuation not implemented")
|
||||
}
|
||||
|
@ -238,7 +248,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||
}
|
||||
if s.CloseAfterFirstRequest {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
session.Close(nil)
|
||||
session.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
|
|
270
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go
generated
vendored
270
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go
generated
vendored
|
@ -1,270 +0,0 @@
|
|||
package quicproxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// Connection is a UDP connection
|
||||
type connection struct {
|
||||
ClientAddr *net.UDPAddr // Address of the client
|
||||
ServerConn *net.UDPConn // UDP connection to server
|
||||
|
||||
incomingPacketCounter uint64
|
||||
outgoingPacketCounter uint64
|
||||
}
|
||||
|
||||
// Direction is the direction a packet is sent.
|
||||
type Direction int
|
||||
|
||||
const (
|
||||
// DirectionIncoming is the direction from the client to the server.
|
||||
DirectionIncoming Direction = iota
|
||||
// DirectionOutgoing is the direction from the server to the client.
|
||||
DirectionOutgoing
|
||||
// DirectionBoth is both incoming and outgoing
|
||||
DirectionBoth
|
||||
)
|
||||
|
||||
func (d Direction) String() string {
|
||||
switch d {
|
||||
case DirectionIncoming:
|
||||
return "incoming"
|
||||
case DirectionOutgoing:
|
||||
return "outgoing"
|
||||
case DirectionBoth:
|
||||
return "both"
|
||||
default:
|
||||
panic("unknown direction")
|
||||
}
|
||||
}
|
||||
|
||||
// Is says if one direction matches another direction.
|
||||
// For example, incoming matches both incoming and both, but not outgoing.
|
||||
func (d Direction) Is(dir Direction) bool {
|
||||
if d == DirectionBoth || dir == DirectionBoth {
|
||||
return true
|
||||
}
|
||||
return d == dir
|
||||
}
|
||||
|
||||
// DropCallback is a callback that determines which packet gets dropped.
|
||||
type DropCallback func(dir Direction, packetCount uint64) bool
|
||||
|
||||
// NoDropper doesn't drop packets.
|
||||
var NoDropper DropCallback = func(Direction, uint64) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// DelayCallback is a callback that determines how much delay to apply to a packet.
|
||||
type DelayCallback func(dir Direction, packetCount uint64) time.Duration
|
||||
|
||||
// NoDelay doesn't apply a delay.
|
||||
var NoDelay DelayCallback = func(Direction, uint64) time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Opts are proxy options.
|
||||
type Opts struct {
|
||||
// The address this proxy proxies packets to.
|
||||
RemoteAddr string
|
||||
// DropPacket determines whether a packet gets dropped.
|
||||
DropPacket DropCallback
|
||||
// DelayPacket determines how long a packet gets delayed. This allows
|
||||
// simulating a connection with non-zero RTTs.
|
||||
// Note that the RTT is the sum of the delay for the incoming and the outgoing packet.
|
||||
DelayPacket DelayCallback
|
||||
}
|
||||
|
||||
// QuicProxy is a QUIC proxy that can drop and delay packets.
|
||||
type QuicProxy struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
version protocol.VersionNumber
|
||||
|
||||
conn *net.UDPConn
|
||||
serverAddr *net.UDPAddr
|
||||
|
||||
dropPacket DropCallback
|
||||
delayPacket DelayCallback
|
||||
|
||||
// Mapping from client addresses (as host:port) to connection
|
||||
clientDict map[string]*connection
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
// NewQuicProxy creates a new UDP proxy
|
||||
func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*QuicProxy, error) {
|
||||
if opts == nil {
|
||||
opts = &Opts{}
|
||||
}
|
||||
laddr, err := net.ResolveUDPAddr("udp", local)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := net.ListenUDP("udp", laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raddr, err := net.ResolveUDPAddr("udp", opts.RemoteAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
packetDropper := NoDropper
|
||||
if opts.DropPacket != nil {
|
||||
packetDropper = opts.DropPacket
|
||||
}
|
||||
|
||||
packetDelayer := NoDelay
|
||||
if opts.DelayPacket != nil {
|
||||
packetDelayer = opts.DelayPacket
|
||||
}
|
||||
|
||||
p := QuicProxy{
|
||||
clientDict: make(map[string]*connection),
|
||||
conn: conn,
|
||||
serverAddr: raddr,
|
||||
dropPacket: packetDropper,
|
||||
delayPacket: packetDelayer,
|
||||
version: version,
|
||||
logger: utils.DefaultLogger,
|
||||
}
|
||||
|
||||
p.logger.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr)
|
||||
go p.runProxy()
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// Close stops the UDP Proxy
|
||||
func (p *QuicProxy) Close() error {
|
||||
p.mutex.Lock()
|
||||
defer p.mutex.Unlock()
|
||||
for _, c := range p.clientDict {
|
||||
if err := c.ServerConn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return p.conn.Close()
|
||||
}
|
||||
|
||||
// LocalAddr is the address the proxy is listening on.
|
||||
func (p *QuicProxy) LocalAddr() net.Addr {
|
||||
return p.conn.LocalAddr()
|
||||
}
|
||||
|
||||
// LocalPort is the UDP port number the proxy is listening on.
|
||||
func (p *QuicProxy) LocalPort() int {
|
||||
return p.conn.LocalAddr().(*net.UDPAddr).Port
|
||||
}
|
||||
|
||||
func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
|
||||
srvudp, err := net.DialUDP("udp", nil, p.serverAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &connection{
|
||||
ClientAddr: cliAddr,
|
||||
ServerConn: srvudp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// runProxy listens on the proxy address and handles incoming packets.
|
||||
func (p *QuicProxy) runProxy() error {
|
||||
for {
|
||||
buffer := make([]byte, protocol.MaxReceivePacketSize)
|
||||
n, cliaddr, err := p.conn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
raw := buffer[0:n]
|
||||
|
||||
saddr := cliaddr.String()
|
||||
p.mutex.Lock()
|
||||
conn, ok := p.clientDict[saddr]
|
||||
|
||||
if !ok {
|
||||
conn, err = p.newConnection(cliaddr)
|
||||
if err != nil {
|
||||
p.mutex.Unlock()
|
||||
return err
|
||||
}
|
||||
p.clientDict[saddr] = conn
|
||||
go p.runConnection(conn)
|
||||
}
|
||||
p.mutex.Unlock()
|
||||
|
||||
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
|
||||
|
||||
if p.dropPacket(DirectionIncoming, packetCount) {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("dropping incoming packet %d (%d bytes)", packetCount, n)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Send the packet to the server
|
||||
delay := p.delayPacket(DirectionIncoming, packetCount)
|
||||
if delay != 0 {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("delaying incoming packet %d (%d bytes) to %s by %s", packetCount, n, conn.ServerConn.RemoteAddr(), delay)
|
||||
}
|
||||
time.AfterFunc(delay, func() {
|
||||
// TODO: handle error
|
||||
_, _ = conn.ServerConn.Write(raw)
|
||||
})
|
||||
} else {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("forwarding incoming packet %d (%d bytes) to %s", packetCount, n, conn.ServerConn.RemoteAddr())
|
||||
}
|
||||
if _, err := conn.ServerConn.Write(raw); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runConnection handles packets from server to a single client
|
||||
func (p *QuicProxy) runConnection(conn *connection) error {
|
||||
for {
|
||||
buffer := make([]byte, protocol.MaxReceivePacketSize)
|
||||
n, err := conn.ServerConn.Read(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
raw := buffer[0:n]
|
||||
|
||||
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
|
||||
|
||||
if p.dropPacket(DirectionOutgoing, packetCount) {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("dropping outgoing packet %d (%d bytes)", packetCount, n)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
delay := p.delayPacket(DirectionOutgoing, packetCount)
|
||||
if delay != 0 {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("delaying outgoing packet %d (%d bytes) to %s by %s", packetCount, n, conn.ClientAddr, delay)
|
||||
}
|
||||
time.AfterFunc(delay, func() {
|
||||
// TODO: handle error
|
||||
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr)
|
||||
})
|
||||
} else {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("forwarding outgoing packet %d (%d bytes) to %s", packetCount, n, conn.ClientAddr)
|
||||
}
|
||||
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
41
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go
generated
vendored
41
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go
generated
vendored
|
@ -1,41 +0,0 @@
|
|||
package testlog
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var (
|
||||
logFileName string // the log file set in the ginkgo flags
|
||||
logFile *os.File
|
||||
)
|
||||
|
||||
// read the logfile command line flag
|
||||
// to set call ginkgo -- -logfile=log.txt
|
||||
func init() {
|
||||
flag.StringVar(&logFileName, "logfile", "", "log file")
|
||||
}
|
||||
|
||||
var _ = BeforeEach(func() {
|
||||
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
|
||||
|
||||
if len(logFileName) > 0 {
|
||||
var err error
|
||||
logFile, err = os.Create(logFileName)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
log.SetOutput(logFile)
|
||||
utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
|
||||
}
|
||||
})
|
||||
|
||||
var _ = AfterEach(func() {
|
||||
if len(logFileName) > 0 {
|
||||
_ = logFile.Close()
|
||||
}
|
||||
})
|
119
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go
generated
vendored
119
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go
generated
vendored
|
@ -1,119 +0,0 @@
|
|||
package testserver
|
||||
|
||||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/h2quic"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const (
|
||||
dataLen = 500 * 1024 // 500 KB
|
||||
dataLenLong = 50 * 1024 * 1024 // 50 MB
|
||||
)
|
||||
|
||||
var (
|
||||
// PRData contains dataLen bytes of pseudo-random data.
|
||||
PRData = GeneratePRData(dataLen)
|
||||
// PRDataLong contains dataLenLong bytes of pseudo-random data.
|
||||
PRDataLong = GeneratePRData(dataLenLong)
|
||||
|
||||
server *h2quic.Server
|
||||
stoppedServing chan struct{}
|
||||
port string
|
||||
)
|
||||
|
||||
func init() {
|
||||
http.HandleFunc("/prdata", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
sl := r.URL.Query().Get("len")
|
||||
if sl != "" {
|
||||
var err error
|
||||
l, err := strconv.Atoi(sl)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
_, err = w.Write(GeneratePRData(l))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
} else {
|
||||
_, err := w.Write(PRData)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
})
|
||||
|
||||
http.HandleFunc("/prdatalong", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
_, err := w.Write(PRDataLong)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
http.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
_, err := io.WriteString(w, "Hello, World!\n")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
http.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
_, err = w.Write(body)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
}
|
||||
|
||||
// See https://en.wikipedia.org/wiki/Lehmer_random_number_generator
|
||||
func GeneratePRData(l int) []byte {
|
||||
res := make([]byte, l)
|
||||
seed := uint64(1)
|
||||
for i := 0; i < l; i++ {
|
||||
seed = seed * 48271 % 2147483647
|
||||
res[i] = byte(seed)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// StartQuicServer starts a h2quic.Server.
|
||||
// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used.
|
||||
func StartQuicServer(versions []protocol.VersionNumber) {
|
||||
server = &h2quic.Server{
|
||||
Server: &http.Server{
|
||||
TLSConfig: testdata.GetTLSConfig(),
|
||||
},
|
||||
QuicConfig: &quic.Config{
|
||||
Versions: versions,
|
||||
},
|
||||
}
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port)
|
||||
|
||||
stoppedServing = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
server.Serve(conn)
|
||||
close(stoppedServing)
|
||||
}()
|
||||
}
|
||||
|
||||
// StopQuicServer stops the h2quic.Server.
|
||||
func StopQuicServer() {
|
||||
Expect(server.Close()).NotTo(HaveOccurred())
|
||||
Eventually(stoppedServing).Should(BeClosed())
|
||||
}
|
||||
|
||||
// Port returns the UDP port of the QUIC server.
|
||||
func Port() string {
|
||||
return port
|
||||
}
|
26
vendor/github.com/lucas-clemente/quic-go/interface.go
generated
vendored
26
vendor/github.com/lucas-clemente/quic-go/interface.go
generated
vendored
|
@ -16,8 +16,16 @@ type StreamID = protocol.StreamID
|
|||
// A VersionNumber is a QUIC version number.
|
||||
type VersionNumber = protocol.VersionNumber
|
||||
|
||||
// VersionGQUIC39 is gQUIC version 39.
|
||||
const VersionGQUIC39 = protocol.Version39
|
||||
const (
|
||||
// VersionGQUIC39 is gQUIC version 39.
|
||||
VersionGQUIC39 = protocol.Version39
|
||||
// VersionGQUIC43 is gQUIC version 43.
|
||||
VersionGQUIC43 = protocol.Version43
|
||||
// VersionGQUIC44 is gQUIC version 44.
|
||||
VersionGQUIC44 = protocol.Version44
|
||||
// VersionMilestone0_10_0 uses TLS
|
||||
VersionMilestone0_10_0 = protocol.VersionMilestone0_10_0
|
||||
)
|
||||
|
||||
// A Cookie can be used to verify the ownership of the client address.
|
||||
type Cookie = handshake.Cookie
|
||||
|
@ -139,8 +147,11 @@ type Session interface {
|
|||
LocalAddr() net.Addr
|
||||
// RemoteAddr returns the address of the peer.
|
||||
RemoteAddr() net.Addr
|
||||
// Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent.
|
||||
Close(error) error
|
||||
// Close the connection.
|
||||
io.Closer
|
||||
// Close the connection with an error.
|
||||
// The error must not be nil.
|
||||
CloseWithError(ErrorCode, error) error
|
||||
// The context is cancelled when the session is closed.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
Context() context.Context
|
||||
|
@ -159,6 +170,13 @@ type Config struct {
|
|||
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
|
||||
// Currently only valid for the client.
|
||||
RequestConnectionIDOmission bool
|
||||
// The length of the connection ID in bytes. Only valid for IETF QUIC.
|
||||
// It can be 0, or any value between 4 and 18.
|
||||
// If not set, the interpretation depends on where the Config is used:
|
||||
// If used for dialing an address, a 0 byte connection ID will be used.
|
||||
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
|
||||
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
|
||||
ConnectionIDLength int
|
||||
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
|
||||
// If the timeout is exceeded, the connection is closed.
|
||||
// If this value is zero, the timeout is set to 10 seconds.
|
||||
|
|
3
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
generated
vendored
3
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
generated
vendored
|
@ -29,7 +29,8 @@ type SentPacketHandler interface {
|
|||
|
||||
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
|
||||
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
||||
DequeuePacketForRetransmission() (packet *Packet)
|
||||
DequeuePacketForRetransmission() *Packet
|
||||
DequeueProbePacket() (*Packet, error)
|
||||
GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
|
|
60
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
60
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
|
@ -25,6 +25,8 @@ type receivedPacketHandler struct {
|
|||
ackAlarm time.Time
|
||||
lastAck *wire.AckFrame
|
||||
|
||||
logger utils.Logger
|
||||
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
|
@ -52,11 +54,16 @@ const (
|
|||
)
|
||||
|
||||
// NewReceivedPacketHandler creates a new receivedPacketHandler
|
||||
func NewReceivedPacketHandler(rttStats *congestion.RTTStats, version protocol.VersionNumber) ReceivedPacketHandler {
|
||||
func NewReceivedPacketHandler(
|
||||
rttStats *congestion.RTTStats,
|
||||
logger utils.Logger,
|
||||
version protocol.VersionNumber,
|
||||
) ReceivedPacketHandler {
|
||||
return &receivedPacketHandler{
|
||||
packetHistory: newReceivedPacketHistory(),
|
||||
ackSendDelay: ackSendDelay,
|
||||
rttStats: rttStats,
|
||||
logger: logger,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
@ -82,16 +89,22 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe
|
|||
// IgnoreBelow sets a lower limit for acking packets.
|
||||
// Packets with packet numbers smaller than p will not be acked.
|
||||
func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
|
||||
if p <= h.ignoreBelow {
|
||||
return
|
||||
}
|
||||
h.ignoreBelow = p
|
||||
h.packetHistory.DeleteBelow(p)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tIgnoring all packets below %#x.", p)
|
||||
}
|
||||
}
|
||||
|
||||
// isMissing says if a packet was reported missing in the last ACK.
|
||||
func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool {
|
||||
if h.lastAck == nil {
|
||||
if h.lastAck == nil || p < h.ignoreBelow {
|
||||
return false
|
||||
}
|
||||
return p < h.lastAck.LargestAcked && !h.lastAck.AcksPacket(p)
|
||||
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) hasNewMissingPackets() bool {
|
||||
|
@ -99,7 +112,7 @@ func (h *receivedPacketHandler) hasNewMissingPackets() bool {
|
|||
return false
|
||||
}
|
||||
highestRange := h.packetHistory.GetHighestAckRange()
|
||||
return highestRange.First >= h.lastAck.LargestAcked && highestRange.Len() <= maxPacketsAfterNewMissing
|
||||
return highestRange.Smallest >= h.lastAck.LargestAcked() && highestRange.Len() <= maxPacketsAfterNewMissing
|
||||
}
|
||||
|
||||
// maybeQueueAck queues an ACK, if necessary.
|
||||
|
@ -110,6 +123,7 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
|
|||
|
||||
// always ack the first packet
|
||||
if h.lastAck == nil {
|
||||
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
|
||||
h.ackQueued = true
|
||||
return
|
||||
}
|
||||
|
@ -118,6 +132,9 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
|
|||
// Ack decimation with reordering relies on the timer to send an ACK, but if
|
||||
// missing packets we reported in the previous ack, send an ACK immediately.
|
||||
if wasMissing {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %#x was missing before.", packetNumber)
|
||||
}
|
||||
h.ackQueued = true
|
||||
}
|
||||
|
||||
|
@ -128,26 +145,41 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
|
|||
// ack up to 10 packets at once
|
||||
if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
|
||||
h.ackQueued = true
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, retransmittablePacketsBeforeAck)
|
||||
}
|
||||
} else if h.ackAlarm.IsZero() {
|
||||
// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
|
||||
ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
|
||||
h.ackAlarm = rcvTime.Add(ackDelay)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tSetting ACK timer to min(1/4 min-RTT, max ack delay): %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// send an ACK every 2 retransmittable packets
|
||||
if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, initialRetransmittablePacketsBeforeAck)
|
||||
}
|
||||
h.ackQueued = true
|
||||
} else if h.ackAlarm.IsZero() {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", ackSendDelay)
|
||||
}
|
||||
h.ackAlarm = rcvTime.Add(ackSendDelay)
|
||||
}
|
||||
}
|
||||
// If there are new missing packets to report, set a short timer to send an ACK.
|
||||
if h.hasNewMissingPackets() {
|
||||
// wait the minimum of 1/8 min RTT and the existing ack time
|
||||
ackDelay := float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay)
|
||||
ackTime := rcvTime.Add(time.Duration(ackDelay))
|
||||
ackDelay := time.Duration(float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay))
|
||||
ackTime := rcvTime.Add(ackDelay)
|
||||
if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
|
||||
h.ackAlarm = ackTime
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tSetting ACK timer to 1/8 min-RTT: %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -159,19 +191,17 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
|
|||
}
|
||||
|
||||
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
|
||||
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) {
|
||||
now := time.Now()
|
||||
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
|
||||
return nil
|
||||
}
|
||||
|
||||
ackRanges := h.packetHistory.GetAckRanges()
|
||||
ack := &wire.AckFrame{
|
||||
LargestAcked: h.largestObserved,
|
||||
LowestAcked: ackRanges[len(ackRanges)-1].First,
|
||||
PacketReceivedTime: h.largestObservedReceivedTime,
|
||||
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
|
||||
h.logger.Debugf("Sending ACK because the ACK timer expired.")
|
||||
}
|
||||
|
||||
if len(ackRanges) > 1 {
|
||||
ack.AckRanges = ackRanges
|
||||
ack := &wire.AckFrame{
|
||||
AckRanges: h.packetHistory.GetAckRanges(),
|
||||
DelayTime: now.Sub(h.largestObservedReceivedTime),
|
||||
}
|
||||
|
||||
h.lastAck = ack
|
||||
|
|
|
@ -104,7 +104,7 @@ func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange {
|
|||
ackRanges := make([]wire.AckRange, h.ranges.Len())
|
||||
i := 0
|
||||
for el := h.ranges.Back(); el != nil; el = el.Prev() {
|
||||
ackRanges[i] = wire.AckRange{First: el.Value.Start, Last: el.Value.End}
|
||||
ackRanges[i] = wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End}
|
||||
i++
|
||||
}
|
||||
return ackRanges
|
||||
|
@ -114,8 +114,8 @@ func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
|
|||
ackRange := wire.AckRange{}
|
||||
if h.ranges.Len() > 0 {
|
||||
r := h.ranges.Back().Value
|
||||
ackRange.First = r.Start
|
||||
ackRange.Last = r.End
|
||||
ackRange.Smallest = r.Start
|
||||
ackRange.Largest = r.End
|
||||
}
|
||||
return ackRange
|
||||
}
|
||||
|
|
6
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go
generated
vendored
6
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go
generated
vendored
|
@ -14,7 +14,9 @@ const (
|
|||
SendRetransmission
|
||||
// SendRTO means that an RTO probe packet should be sent
|
||||
SendRTO
|
||||
// SendAny packet should be sent
|
||||
// SendTLP means that a TLP probe packet should be sent
|
||||
SendTLP
|
||||
// SendAny means that any packet should be sent
|
||||
SendAny
|
||||
)
|
||||
|
||||
|
@ -28,6 +30,8 @@ func (s SendMode) String() string {
|
|||
return "retransmission"
|
||||
case SendRTO:
|
||||
return "rto"
|
||||
case SendTLP:
|
||||
return "tlp"
|
||||
case SendAny:
|
||||
return "any"
|
||||
default:
|
||||
|
|
207
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
207
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
|
@ -1,6 +1,7 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
@ -16,13 +17,12 @@ const (
|
|||
// Maximum reordering in time space before time based loss detection considers a packet lost.
|
||||
// In fraction of an RTT.
|
||||
timeReorderingFraction = 1.0 / 8
|
||||
// The default RTT used before an RTT sample is taken.
|
||||
// Note: This constant is also defined in the congestion package.
|
||||
defaultInitialRTT = 100 * time.Millisecond
|
||||
// defaultRTOTimeout is the RTO time on new connections
|
||||
defaultRTOTimeout = 500 * time.Millisecond
|
||||
// Minimum time in the future a tail loss probe alarm may be set for.
|
||||
minTPLTimeout = 10 * time.Millisecond
|
||||
// Maximum number of tail loss probes before an RTO fires.
|
||||
maxTLPs = 2
|
||||
// Minimum time in the future an RTO alarm may be set for.
|
||||
minRTOTimeout = 200 * time.Millisecond
|
||||
// maxRTOTimeout is the maximum RTO time
|
||||
|
@ -59,6 +59,10 @@ type sentPacketHandler struct {
|
|||
// The number of times the handshake packets have been retransmitted without receiving an ack.
|
||||
handshakeCount uint32
|
||||
|
||||
// The number of times a TLP has been sent without receiving an ack.
|
||||
tlpCount uint32
|
||||
allowTLP bool
|
||||
|
||||
// The number of times an RTO has been sent without receiving an ack.
|
||||
rtoCount uint32
|
||||
// The number of RTO probe packets that should be sent.
|
||||
|
@ -71,10 +75,12 @@ type sentPacketHandler struct {
|
|||
alarm time.Time
|
||||
|
||||
logger utils.Logger
|
||||
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
// NewSentPacketHandler creates a new sentPacketHandler
|
||||
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
|
||||
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler {
|
||||
congestion := congestion.NewCubicSender(
|
||||
congestion.DefaultClock{},
|
||||
rttStats,
|
||||
|
@ -89,6 +95,7 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) Se
|
|||
rttStats: rttStats,
|
||||
congestion: congestion,
|
||||
logger: logger,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -100,6 +107,7 @@ func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber {
|
|||
}
|
||||
|
||||
func (h *sentPacketHandler) SetHandshakeComplete() {
|
||||
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
|
||||
var queue []*Packet
|
||||
for _, packet := range h.retransmissionQueue {
|
||||
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
|
||||
|
@ -150,7 +158,7 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt
|
|||
|
||||
if len(packet.Frames) > 0 {
|
||||
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
|
||||
packet.largestAcked = ackFrame.LargestAcked
|
||||
packet.largestAcked = ackFrame.LargestAcked()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -168,6 +176,7 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt
|
|||
if h.numRTOs > 0 {
|
||||
h.numRTOs--
|
||||
}
|
||||
h.allowTLP = false
|
||||
}
|
||||
h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isRetransmittable)
|
||||
|
||||
|
@ -176,7 +185,8 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt
|
|||
}
|
||||
|
||||
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
|
||||
if ackFrame.LargestAcked > h.lastSentPacketNumber {
|
||||
largestAcked := ackFrame.LargestAcked()
|
||||
if largestAcked > h.lastSentPacketNumber {
|
||||
return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
|
||||
}
|
||||
|
||||
|
@ -186,13 +196,13 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
|||
return nil
|
||||
}
|
||||
h.largestReceivedPacketWithAck = withPacketNumber
|
||||
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, ackFrame.LargestAcked)
|
||||
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked)
|
||||
|
||||
if h.skippedPacketsAcked(ackFrame) {
|
||||
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
|
||||
}
|
||||
|
||||
if rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime); rttUpdated {
|
||||
if rttUpdated := h.maybeUpdateRTT(largestAcked, ackFrame.DelayTime, rcvTime); rttUpdated {
|
||||
h.congestion.MaybeExitSlowStart()
|
||||
}
|
||||
|
||||
|
@ -212,11 +222,11 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
|||
if p.largestAcked != 0 {
|
||||
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.largestAcked+1)
|
||||
}
|
||||
if err := h.onPacketAcked(p); err != nil {
|
||||
if err := h.onPacketAcked(p, rcvTime); err != nil {
|
||||
return err
|
||||
}
|
||||
if p.includedInBytesInFlight {
|
||||
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight)
|
||||
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -238,27 +248,29 @@ func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNu
|
|||
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*Packet, error) {
|
||||
var ackedPackets []*Packet
|
||||
ackRangeIndex := 0
|
||||
lowestAcked := ackFrame.LowestAcked()
|
||||
largestAcked := ackFrame.LargestAcked()
|
||||
err := h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||
// Ignore packets below the LowestAcked
|
||||
if p.PacketNumber < ackFrame.LowestAcked {
|
||||
// Ignore packets below the lowest acked
|
||||
if p.PacketNumber < lowestAcked {
|
||||
return true, nil
|
||||
}
|
||||
// Break after LargestAcked is reached
|
||||
if p.PacketNumber > ackFrame.LargestAcked {
|
||||
// Break after largest acked is reached
|
||||
if p.PacketNumber > largestAcked {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if ackFrame.HasMissingRanges() {
|
||||
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
||||
|
||||
for p.PacketNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 {
|
||||
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ackFrame.AckRanges)-1 {
|
||||
ackRangeIndex++
|
||||
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
||||
}
|
||||
|
||||
if p.PacketNumber >= ackRange.First { // packet i contained in ACK range
|
||||
if p.PacketNumber > ackRange.Last {
|
||||
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.First, ackRange.Last)
|
||||
if p.PacketNumber >= ackRange.Smallest { // packet i contained in ACK range
|
||||
if p.PacketNumber > ackRange.Largest {
|
||||
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
|
||||
}
|
||||
ackedPackets = append(ackedPackets, p)
|
||||
}
|
||||
|
@ -267,12 +279,22 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame)
|
|||
}
|
||||
return true, nil
|
||||
})
|
||||
if h.logger.Debug() && len(ackedPackets) > 0 {
|
||||
pns := make([]protocol.PacketNumber, len(ackedPackets))
|
||||
for i, p := range ackedPackets {
|
||||
pns[i] = p.PacketNumber
|
||||
}
|
||||
h.logger.Debugf("\tnewly acked packets (%d): %#x", len(pns), pns)
|
||||
}
|
||||
return ackedPackets, err
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool {
|
||||
if p := h.packetHistory.GetPacket(largestAcked); p != nil {
|
||||
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
@ -280,20 +302,25 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a
|
|||
|
||||
func (h *sentPacketHandler) updateLossDetectionAlarm() {
|
||||
// Cancel the alarm if no packets are outstanding
|
||||
if h.packetHistory.Len() == 0 {
|
||||
if !h.packetHistory.HasOutstandingPackets() {
|
||||
h.alarm = time.Time{}
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(#497): TLP
|
||||
if !h.handshakeComplete {
|
||||
if h.packetHistory.HasOutstandingHandshakePackets() {
|
||||
h.alarm = h.lastSentHandshakePacketTime.Add(h.computeHandshakeTimeout())
|
||||
} else if !h.lossTime.IsZero() {
|
||||
// Early retransmit timer or time loss detection.
|
||||
h.alarm = h.lossTime
|
||||
} else {
|
||||
// RTO
|
||||
h.alarm = h.lastSentRetransmittablePacketTime.Add(h.computeRTOTimeout())
|
||||
// RTO or TLP alarm
|
||||
alarmDuration := h.computeRTOTimeout()
|
||||
if h.tlpCount < maxTLPs {
|
||||
tlpAlarm := h.computeTLPTimeout()
|
||||
// if the RTO duration is shorter than the TLP duration, use the RTO duration
|
||||
alarmDuration = utils.MinDuration(alarmDuration, tlpAlarm)
|
||||
}
|
||||
h.alarm = h.lastSentRetransmittablePacketTime.Add(alarmDuration)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -313,11 +340,21 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight proto
|
|||
if timeSinceSent > delayUntilLost {
|
||||
lostPackets = append(lostPackets, packet)
|
||||
} else if h.lossTime.IsZero() {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tsetting loss timer for packet %#x to %s (in %s)", packet.PacketNumber, delayUntilLost, delayUntilLost-timeSinceSent)
|
||||
}
|
||||
// Note: This conditional is only entered once per call
|
||||
h.lossTime = now.Add(delayUntilLost - timeSinceSent)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
if h.logger.Debug() && len(lostPackets) > 0 {
|
||||
pns := make([]protocol.PacketNumber, len(lostPackets))
|
||||
for i, p := range lostPackets {
|
||||
pns[i] = p.PacketNumber
|
||||
}
|
||||
h.logger.Debugf("\tlost packets (%d): %#x", len(pns), pns)
|
||||
}
|
||||
|
||||
for _, p := range lostPackets {
|
||||
// the bytes in flight need to be reduced no matter if this packet will be retransmitted
|
||||
|
@ -327,7 +364,6 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight proto
|
|||
}
|
||||
if p.canBeRetransmitted {
|
||||
// queue the packet for retransmission, and report the loss to the congestion controller
|
||||
h.logger.Debugf("\tQueueing packet %#x because it was detected lost", p.PacketNumber)
|
||||
if err := h.queuePacketForRetransmission(p); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -338,34 +374,57 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight proto
|
|||
}
|
||||
|
||||
func (h *sentPacketHandler) OnAlarm() error {
|
||||
now := time.Now()
|
||||
|
||||
// TODO(#497): TLP
|
||||
var err error
|
||||
if !h.handshakeComplete {
|
||||
h.handshakeCount++
|
||||
err = h.queueHandshakePacketsForRetransmission()
|
||||
} else if !h.lossTime.IsZero() {
|
||||
// Early retransmit or time loss detection
|
||||
err = h.detectLostPackets(now, h.bytesInFlight)
|
||||
} else {
|
||||
// RTO
|
||||
h.rtoCount++
|
||||
h.numRTOs += 2
|
||||
err = h.queueRTOs()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
// When all outstanding are acknowledged, the alarm is canceled in
|
||||
// updateLossDetectionAlarm. This doesn't reset the timer in the session though.
|
||||
// When OnAlarm is called, we therefore need to make sure that there are
|
||||
// actually packets outstanding.
|
||||
if h.packetHistory.HasOutstandingPackets() {
|
||||
if err := h.onVerifiedAlarm(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
h.updateLossDetectionAlarm()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) onVerifiedAlarm() error {
|
||||
var err error
|
||||
if h.packetHistory.HasOutstandingHandshakePackets() {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Loss detection alarm fired in handshake mode. Handshake count: %d", h.handshakeCount)
|
||||
}
|
||||
h.handshakeCount++
|
||||
err = h.queueHandshakePacketsForRetransmission()
|
||||
} else if !h.lossTime.IsZero() {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", h.lossTime)
|
||||
}
|
||||
// Early retransmit or time loss detection
|
||||
err = h.detectLostPackets(time.Now(), h.bytesInFlight)
|
||||
} else if h.tlpCount < maxTLPs { // TLP
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Loss detection alarm fired in TLP mode. TLP count: %d", h.tlpCount)
|
||||
}
|
||||
h.allowTLP = true
|
||||
h.tlpCount++
|
||||
} else { // RTO
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Loss detection alarm fired in RTO mode. RTO count: %d", h.rtoCount)
|
||||
}
|
||||
if h.rtoCount == 0 {
|
||||
h.largestSentBeforeRTO = h.lastSentPacketNumber
|
||||
}
|
||||
h.rtoCount++
|
||||
h.numRTOs += 2
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
|
||||
return h.alarm
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) onPacketAcked(p *Packet) error {
|
||||
func (h *sentPacketHandler) onPacketAcked(p *Packet, rcvTime time.Time) error {
|
||||
// This happens if a packet and its retransmissions is acked in the same ACK.
|
||||
// As soon as we process the first one, this will remove all the retransmissions,
|
||||
// so we won't find the retransmitted packet number later.
|
||||
|
@ -404,8 +463,8 @@ func (h *sentPacketHandler) onPacketAcked(p *Packet) error {
|
|||
return err
|
||||
}
|
||||
h.rtoCount = 0
|
||||
h.tlpCount = 0
|
||||
h.handshakeCount = 0
|
||||
// TODO(#497): h.tlpCount = 0
|
||||
return h.packetHistory.Remove(p.PacketNumber)
|
||||
}
|
||||
|
||||
|
@ -447,8 +506,21 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
|
|||
return packet
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
|
||||
if len(h.retransmissionQueue) == 0 {
|
||||
p := h.packetHistory.FirstOutstanding()
|
||||
if p == nil {
|
||||
return nil, errors.New("cannot dequeue a probe packet. No outstanding packets")
|
||||
}
|
||||
if err := h.queuePacketForRetransmission(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return h.DequeuePacketForRetransmission(), nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
|
||||
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked())
|
||||
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked(), h.version)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
||||
|
@ -463,15 +535,22 @@ func (h *sentPacketHandler) SendMode() SendMode {
|
|||
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
|
||||
// but still allow sending of retransmissions and ACKs.
|
||||
if numTrackedPackets >= protocol.MaxTrackedSentPackets {
|
||||
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
|
||||
}
|
||||
return SendNone
|
||||
}
|
||||
if h.allowTLP {
|
||||
return SendTLP
|
||||
}
|
||||
if h.numRTOs > 0 {
|
||||
return SendRTO
|
||||
}
|
||||
// Only send ACKs if we're congestion limited.
|
||||
if cwnd := h.congestion.GetCongestionWindow(); h.bytesInFlight > cwnd {
|
||||
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
|
||||
}
|
||||
return SendAck
|
||||
}
|
||||
// Send retransmissions first, if there are any.
|
||||
|
@ -479,7 +558,9 @@ func (h *sentPacketHandler) SendMode() SendMode {
|
|||
return SendRetransmission
|
||||
}
|
||||
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
|
||||
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
|
||||
}
|
||||
return SendAck
|
||||
}
|
||||
return SendAny
|
||||
|
@ -501,23 +582,6 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int {
|
|||
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
|
||||
}
|
||||
|
||||
// retransmit the oldest two packets
|
||||
func (h *sentPacketHandler) queueRTOs() error {
|
||||
h.largestSentBeforeRTO = h.lastSentPacketNumber
|
||||
// Queue the first two outstanding packets for retransmission.
|
||||
// This does NOT declare this packets as lost:
|
||||
// They are still tracked in the packet history and count towards the bytes in flight.
|
||||
for i := 0; i < 2; i++ {
|
||||
if p := h.packetHistory.FirstOutstanding(); p != nil {
|
||||
h.logger.Debugf("\tQueueing packet %#x for retransmission (RTO)", p.PacketNumber)
|
||||
if err := h.queuePacketForRetransmission(p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
|
||||
var handshakePackets []*Packet
|
||||
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||
|
@ -527,7 +591,7 @@ func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
|
|||
return true, nil
|
||||
})
|
||||
for _, p := range handshakePackets {
|
||||
h.logger.Debugf("\tQueueing packet %#x as a handshake retransmission", p.PacketNumber)
|
||||
h.logger.Debugf("Queueing packet %#x as a handshake retransmission", p.PacketNumber)
|
||||
if err := h.queuePacketForRetransmission(p); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -548,16 +612,17 @@ func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
|
|||
}
|
||||
|
||||
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
|
||||
duration := 2 * h.rttStats.SmoothedRTT()
|
||||
if duration == 0 {
|
||||
duration = 2 * defaultInitialRTT
|
||||
}
|
||||
duration = utils.MaxDuration(duration, minTPLTimeout)
|
||||
duration := utils.MaxDuration(2*h.rttStats.SmoothedOrInitialRTT(), minTPLTimeout)
|
||||
// exponential backoff
|
||||
// There's an implicit limit to this set by the handshake timeout.
|
||||
return duration << h.handshakeCount
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) computeTLPTimeout() time.Duration {
|
||||
// TODO(#1236): include the max_ack_delay
|
||||
return utils.MaxDuration(h.rttStats.SmoothedOrInitialRTT()*3/2, minTPLTimeout)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
|
||||
var rto time.Duration
|
||||
rtt := h.rttStats.SmoothedRTT()
|
||||
|
|
41
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
41
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
|
@ -10,6 +10,9 @@ type sentPacketHistory struct {
|
|||
packetList *PacketList
|
||||
packetMap map[protocol.PacketNumber]*PacketElement
|
||||
|
||||
numOutstandingPackets int
|
||||
numOutstandingHandshakePackets int
|
||||
|
||||
firstOutstanding *PacketElement
|
||||
}
|
||||
|
||||
|
@ -30,6 +33,12 @@ func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
|
|||
if h.firstOutstanding == nil {
|
||||
h.firstOutstanding = el
|
||||
}
|
||||
if p.canBeRetransmitted {
|
||||
h.numOutstandingPackets++
|
||||
if p.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
h.numOutstandingHandshakePackets++
|
||||
}
|
||||
}
|
||||
return el
|
||||
}
|
||||
|
||||
|
@ -92,6 +101,18 @@ func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber)
|
|||
if !ok {
|
||||
return fmt.Errorf("sent packet history: packet %d not found", pn)
|
||||
}
|
||||
if el.Value.canBeRetransmitted {
|
||||
h.numOutstandingPackets--
|
||||
if h.numOutstandingPackets < 0 {
|
||||
panic("numOutstandingHandshakePackets negative")
|
||||
}
|
||||
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
h.numOutstandingHandshakePackets--
|
||||
if h.numOutstandingHandshakePackets < 0 {
|
||||
panic("numOutstandingHandshakePackets negative")
|
||||
}
|
||||
}
|
||||
}
|
||||
el.Value.canBeRetransmitted = false
|
||||
if el == h.firstOutstanding {
|
||||
h.readjustFirstOutstanding()
|
||||
|
@ -121,7 +142,27 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
|
|||
if el == h.firstOutstanding {
|
||||
h.readjustFirstOutstanding()
|
||||
}
|
||||
if el.Value.canBeRetransmitted {
|
||||
h.numOutstandingPackets--
|
||||
if h.numOutstandingPackets < 0 {
|
||||
panic("numOutstandingHandshakePackets negative")
|
||||
}
|
||||
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
h.numOutstandingHandshakePackets--
|
||||
if h.numOutstandingHandshakePackets < 0 {
|
||||
panic("numOutstandingHandshakePackets negative")
|
||||
}
|
||||
}
|
||||
}
|
||||
h.packetList.Remove(el)
|
||||
delete(h.packetMap, p)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) HasOutstandingPackets() bool {
|
||||
return h.numOutstandingPackets > 0
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) HasOutstandingHandshakePackets() bool {
|
||||
return h.numOutstandingHandshakePackets > 0
|
||||
}
|
||||
|
|
5
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager.go
generated
vendored
5
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager.go
generated
vendored
|
@ -30,8 +30,9 @@ func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFr
|
|||
}
|
||||
|
||||
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
|
||||
if ack.LargestAcked >= s.nextLeastUnacked {
|
||||
s.nextLeastUnacked = ack.LargestAcked + 1
|
||||
largestAcked := ack.LargestAcked()
|
||||
if largestAcked >= s.nextLeastUnacked {
|
||||
s.nextLeastUnacked = largestAcked + 1
|
||||
}
|
||||
}
|
||||
|
||||
|
|
146
vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go
generated
vendored
146
vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go
generated
vendored
|
@ -16,11 +16,10 @@ import (
|
|||
// allow a 10 shift right to divide.
|
||||
|
||||
// 1024*1024^3 (first 1024 is from 0.100^3)
|
||||
// where 0.100 is 100 ms which is the scaling
|
||||
// round trip time.
|
||||
// where 0.100 is 100 ms which is the scaling round trip time.
|
||||
const cubeScale = 40
|
||||
const cubeCongestionWindowScale = 410
|
||||
const cubeFactor protocol.PacketNumber = 1 << cubeScale / cubeCongestionWindowScale
|
||||
const cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / protocol.DefaultTCPMSS
|
||||
|
||||
const defaultNumConnections = 2
|
||||
|
||||
|
@ -32,39 +31,35 @@ const beta float32 = 0.7
|
|||
// new concurrent flows and speed up convergence.
|
||||
const betaLastMax float32 = 0.85
|
||||
|
||||
// If true, Cubic's epoch is shifted when the sender is application-limited.
|
||||
const shiftQuicCubicEpochWhenAppLimited = true
|
||||
|
||||
const maxCubicTimeInterval = 30 * time.Millisecond
|
||||
|
||||
// Cubic implements the cubic algorithm from TCP
|
||||
type Cubic struct {
|
||||
clock Clock
|
||||
|
||||
// Number of connections to simulate.
|
||||
numConnections int
|
||||
|
||||
// Time when this cycle started, after last loss event.
|
||||
epoch time.Time
|
||||
// Time when sender went into application-limited period. Zero if not in
|
||||
// application-limited period.
|
||||
appLimitedStartTime time.Time
|
||||
// Time when we updated last_congestion_window.
|
||||
lastUpdateTime time.Time
|
||||
// Last congestion window (in packets) used.
|
||||
lastCongestionWindow protocol.PacketNumber
|
||||
// Max congestion window (in packets) used just before last loss event.
|
||||
|
||||
// Max congestion window used just before last loss event.
|
||||
// Note: to improve fairness to other streams an additional back off is
|
||||
// applied to this value if the new value is below our latest value.
|
||||
lastMaxCongestionWindow protocol.PacketNumber
|
||||
// Number of acked packets since the cycle started (epoch).
|
||||
ackedPacketsCount protocol.PacketNumber
|
||||
lastMaxCongestionWindow protocol.ByteCount
|
||||
|
||||
// Number of acked bytes since the cycle started (epoch).
|
||||
ackedBytesCount protocol.ByteCount
|
||||
|
||||
// TCP Reno equivalent congestion window in packets.
|
||||
estimatedTCPcongestionWindow protocol.PacketNumber
|
||||
estimatedTCPcongestionWindow protocol.ByteCount
|
||||
|
||||
// Origin point of cubic function.
|
||||
originPointCongestionWindow protocol.PacketNumber
|
||||
originPointCongestionWindow protocol.ByteCount
|
||||
|
||||
// Time to origin point of cubic function in 2^10 fractions of a second.
|
||||
timeToOriginPoint uint32
|
||||
|
||||
// Last congestion window in packets computed by cubic function.
|
||||
lastTargetCongestionWindow protocol.PacketNumber
|
||||
lastTargetCongestionWindow protocol.ByteCount
|
||||
}
|
||||
|
||||
// NewCubic returns a new Cubic instance
|
||||
|
@ -80,11 +75,8 @@ func NewCubic(clock Clock) *Cubic {
|
|||
// Reset is called after a timeout to reset the cubic state
|
||||
func (c *Cubic) Reset() {
|
||||
c.epoch = time.Time{}
|
||||
c.appLimitedStartTime = time.Time{}
|
||||
c.lastUpdateTime = time.Time{}
|
||||
c.lastCongestionWindow = 0
|
||||
c.lastMaxCongestionWindow = 0
|
||||
c.ackedPacketsCount = 0
|
||||
c.ackedBytesCount = 0
|
||||
c.estimatedTCPcongestionWindow = 0
|
||||
c.originPointCongestionWindow = 0
|
||||
c.timeToOriginPoint = 0
|
||||
|
@ -107,57 +99,59 @@ func (c *Cubic) beta() float32 {
|
|||
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
|
||||
}
|
||||
|
||||
func (c *Cubic) betaLastMax() float32 {
|
||||
// betaLastMax is the additional backoff factor after loss for our
|
||||
// N-connection emulation, which emulates the additional backoff of
|
||||
// an ensemble of N TCP-Reno connections on a single loss event. The
|
||||
// effective multiplier is computed as:
|
||||
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
|
||||
}
|
||||
|
||||
// OnApplicationLimited is called on ack arrival when sender is unable to use
|
||||
// the available congestion window. Resets Cubic state during quiescence.
|
||||
func (c *Cubic) OnApplicationLimited() {
|
||||
if shiftQuicCubicEpochWhenAppLimited {
|
||||
// When sender is not using the available congestion window, Cubic's epoch
|
||||
// should not continue growing. Record the time when sender goes into an
|
||||
// app-limited period here, to compensate later when cwnd growth happens.
|
||||
if c.appLimitedStartTime.IsZero() {
|
||||
c.appLimitedStartTime = c.clock.Now()
|
||||
}
|
||||
} else {
|
||||
// When sender is not using the available congestion window, Cubic's epoch
|
||||
// should not continue growing. Reset the epoch when in such a period.
|
||||
c.epoch = time.Time{}
|
||||
}
|
||||
// When sender is not using the available congestion window, the window does
|
||||
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
|
||||
// using the entire window during the time since the beginning of the current
|
||||
// "epoch" (the end of the last loss recovery period). Since
|
||||
// application-limited periods break this assumption, we reset the epoch when
|
||||
// in such a period. This reset effectively freezes congestion window growth
|
||||
// through application-limited periods and allows Cubic growth to continue
|
||||
// when the entire window is being used.
|
||||
c.epoch = time.Time{}
|
||||
}
|
||||
|
||||
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
|
||||
// a loss event. Returns the new congestion window in packets. The new
|
||||
// congestion window is a multiplicative decrease of our current window.
|
||||
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.PacketNumber) protocol.PacketNumber {
|
||||
if currentCongestionWindow < c.lastMaxCongestionWindow {
|
||||
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount {
|
||||
if currentCongestionWindow+protocol.DefaultTCPMSS < c.lastMaxCongestionWindow {
|
||||
// We never reached the old max, so assume we are competing with another
|
||||
// flow. Use our extra back off factor to allow the other flow to go up.
|
||||
c.lastMaxCongestionWindow = protocol.PacketNumber(betaLastMax * float32(currentCongestionWindow))
|
||||
c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
|
||||
} else {
|
||||
c.lastMaxCongestionWindow = currentCongestionWindow
|
||||
}
|
||||
c.epoch = time.Time{} // Reset time.
|
||||
return protocol.PacketNumber(float32(currentCongestionWindow) * c.beta())
|
||||
return protocol.ByteCount(float32(currentCongestionWindow) * c.beta())
|
||||
}
|
||||
|
||||
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
|
||||
// Returns the new congestion window in packets. The new congestion window
|
||||
// follows a cubic function that depends on the time passed since last
|
||||
// packet loss.
|
||||
func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.PacketNumber, delayMin time.Duration) protocol.PacketNumber {
|
||||
c.ackedPacketsCount++ // Packets acked.
|
||||
currentTime := c.clock.Now()
|
||||
|
||||
// Cubic is "independent" of RTT, the update is limited by the time elapsed.
|
||||
if c.lastCongestionWindow == currentCongestionWindow && (currentTime.Sub(c.lastUpdateTime) <= maxCubicTimeInterval) {
|
||||
return utils.MaxPacketNumber(c.lastTargetCongestionWindow, c.estimatedTCPcongestionWindow)
|
||||
}
|
||||
c.lastCongestionWindow = currentCongestionWindow
|
||||
c.lastUpdateTime = currentTime
|
||||
func (c *Cubic) CongestionWindowAfterAck(
|
||||
ackedBytes protocol.ByteCount,
|
||||
currentCongestionWindow protocol.ByteCount,
|
||||
delayMin time.Duration,
|
||||
eventTime time.Time,
|
||||
) protocol.ByteCount {
|
||||
c.ackedBytesCount += ackedBytes
|
||||
|
||||
if c.epoch.IsZero() {
|
||||
// First ACK after a loss event.
|
||||
c.epoch = currentTime // Start of epoch.
|
||||
c.ackedPacketsCount = 1 // Reset count.
|
||||
c.epoch = eventTime // Start of epoch.
|
||||
c.ackedBytesCount = ackedBytes // Reset count.
|
||||
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
|
||||
c.estimatedTCPcongestionWindow = currentCongestionWindow
|
||||
if c.lastMaxCongestionWindow <= currentCongestionWindow {
|
||||
|
@ -167,48 +161,37 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet
|
|||
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
|
||||
c.originPointCongestionWindow = c.lastMaxCongestionWindow
|
||||
}
|
||||
} else {
|
||||
// If sender was app-limited, then freeze congestion window growth during
|
||||
// app-limited period. Continue growth now by shifting the epoch-start
|
||||
// through the app-limited period.
|
||||
if shiftQuicCubicEpochWhenAppLimited && !c.appLimitedStartTime.IsZero() {
|
||||
shift := currentTime.Sub(c.appLimitedStartTime)
|
||||
c.epoch = c.epoch.Add(shift)
|
||||
c.appLimitedStartTime = time.Time{}
|
||||
}
|
||||
}
|
||||
|
||||
// Change the time unit from microseconds to 2^10 fractions per second. Take
|
||||
// the round trip time in account. This is done to allow us to use shift as a
|
||||
// divide operator.
|
||||
elapsedTime := int64((currentTime.Add(delayMin).Sub(c.epoch)/time.Microsecond)<<10) / 1000000
|
||||
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
|
||||
|
||||
// Right-shifts of negative, signed numbers have implementation-dependent
|
||||
// behavior, so force the offset to be positive, as is done in the kernel.
|
||||
offset := int64(c.timeToOriginPoint) - elapsedTime
|
||||
// Right-shifts of negative, signed numbers have
|
||||
// implementation-dependent behavior. Force the offset to be
|
||||
// positive, similar to the kernel implementation.
|
||||
if offset < 0 {
|
||||
offset = -offset
|
||||
}
|
||||
deltaCongestionWindow := protocol.PacketNumber((cubeCongestionWindowScale * offset * offset * offset) >> cubeScale)
|
||||
var targetCongestionWindow protocol.PacketNumber
|
||||
|
||||
deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * protocol.DefaultTCPMSS >> cubeScale
|
||||
var targetCongestionWindow protocol.ByteCount
|
||||
if elapsedTime > int64(c.timeToOriginPoint) {
|
||||
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
|
||||
} else {
|
||||
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
|
||||
}
|
||||
// With dynamic beta/alpha based on number of active streams, it is possible
|
||||
// for the required_ack_count to become much lower than acked_packets_count_
|
||||
// suddenly, leading to more than one iteration through the following loop.
|
||||
for {
|
||||
// Update estimated TCP congestion_window.
|
||||
requiredAckCount := protocol.PacketNumber(float32(c.estimatedTCPcongestionWindow) / c.alpha())
|
||||
if c.ackedPacketsCount < requiredAckCount {
|
||||
break
|
||||
}
|
||||
c.ackedPacketsCount -= requiredAckCount
|
||||
c.estimatedTCPcongestionWindow++
|
||||
}
|
||||
// Limit the CWND increase to half the acked bytes.
|
||||
targetCongestionWindow = utils.MinByteCount(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
|
||||
|
||||
// Increase the window by approximately Alpha * 1 MSS of bytes every
|
||||
// time we ack an estimated tcp window of bytes. For small
|
||||
// congestion windows (less than 25), the formula below will
|
||||
// increase slightly slower than linearly per estimated tcp window
|
||||
// of bytes.
|
||||
c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(protocol.DefaultTCPMSS) / float32(c.estimatedTCPcongestionWindow))
|
||||
c.ackedBytesCount = 0
|
||||
|
||||
// We have a new cubic congestion window.
|
||||
c.lastTargetCongestionWindow = targetCongestionWindow
|
||||
|
@ -218,7 +201,6 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet
|
|||
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
|
||||
targetCongestionWindow = c.estimatedTCPcongestionWindow
|
||||
}
|
||||
|
||||
return targetCongestionWindow
|
||||
}
|
||||
|
||||
|
|
124
vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go
generated
vendored
124
vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go
generated
vendored
|
@ -8,9 +8,9 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
maxBurstBytes = 3 * protocol.DefaultTCPMSS
|
||||
defaultMinimumCongestionWindow protocol.PacketNumber = 2
|
||||
renoBeta float32 = 0.7 // Reno backoff factor.
|
||||
maxBurstBytes = 3 * protocol.DefaultTCPMSS
|
||||
renoBeta float32 = 0.7 // Reno backoff factor.
|
||||
defaultMinimumCongestionWindow protocol.ByteCount = 2 * protocol.DefaultTCPMSS
|
||||
)
|
||||
|
||||
type cubicSender struct {
|
||||
|
@ -31,12 +31,6 @@ type cubicSender struct {
|
|||
// Track the largest packet number outstanding when a CWND cutback occurs.
|
||||
largestSentAtLastCutback protocol.PacketNumber
|
||||
|
||||
// Congestion window in packets.
|
||||
congestionWindow protocol.PacketNumber
|
||||
|
||||
// Slow start congestion window in packets, aka ssthresh.
|
||||
slowstartThreshold protocol.PacketNumber
|
||||
|
||||
// Whether the last loss event caused us to exit slowstart.
|
||||
// Used for stats collection of slowstartPacketsLost
|
||||
lastCutbackExitedSlowstart bool
|
||||
|
@ -44,24 +38,35 @@ type cubicSender struct {
|
|||
// When true, exit slow start with large cutback of congestion window.
|
||||
slowStartLargeReduction bool
|
||||
|
||||
// Minimum congestion window in packets.
|
||||
minCongestionWindow protocol.PacketNumber
|
||||
// Congestion window in packets.
|
||||
congestionWindow protocol.ByteCount
|
||||
|
||||
// Maximum number of outstanding packets for tcp.
|
||||
maxTCPCongestionWindow protocol.PacketNumber
|
||||
// Minimum congestion window in packets.
|
||||
minCongestionWindow protocol.ByteCount
|
||||
|
||||
// Maximum congestion window.
|
||||
maxCongestionWindow protocol.ByteCount
|
||||
|
||||
// Slow start congestion window in bytes, aka ssthresh.
|
||||
slowstartThreshold protocol.ByteCount
|
||||
|
||||
// Number of connections to simulate.
|
||||
numConnections int
|
||||
|
||||
// ACK counter for the Reno implementation.
|
||||
congestionWindowCount protocol.ByteCount
|
||||
numAckedPackets uint64
|
||||
|
||||
initialCongestionWindow protocol.PacketNumber
|
||||
initialMaxCongestionWindow protocol.PacketNumber
|
||||
initialCongestionWindow protocol.ByteCount
|
||||
initialMaxCongestionWindow protocol.ByteCount
|
||||
|
||||
minSlowStartExitWindow protocol.ByteCount
|
||||
}
|
||||
|
||||
var _ SendAlgorithm = &cubicSender{}
|
||||
var _ SendAlgorithmWithDebugInfo = &cubicSender{}
|
||||
|
||||
// NewCubicSender makes a new cubic sender
|
||||
func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.PacketNumber) SendAlgorithmWithDebugInfo {
|
||||
func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) SendAlgorithmWithDebugInfo {
|
||||
return &cubicSender{
|
||||
rttStats: rttStats,
|
||||
initialCongestionWindow: initialCongestionWindow,
|
||||
|
@ -69,7 +74,7 @@ func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestio
|
|||
congestionWindow: initialCongestionWindow,
|
||||
minCongestionWindow: defaultMinimumCongestionWindow,
|
||||
slowstartThreshold: initialMaxCongestionWindow,
|
||||
maxTCPCongestionWindow: initialMaxCongestionWindow,
|
||||
maxCongestionWindow: initialMaxCongestionWindow,
|
||||
numConnections: defaultNumConnections,
|
||||
cubic: NewCubic(clock),
|
||||
reno: reno,
|
||||
|
@ -80,21 +85,26 @@ func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestio
|
|||
func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration {
|
||||
if c.InRecovery() {
|
||||
// PRR is used when in recovery.
|
||||
if c.prr.TimeUntilSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) == 0 {
|
||||
if c.prr.CanSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow()/protocol.DefaultTCPMSS)
|
||||
delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow())
|
||||
if !c.InSlowStart() { // adjust delay, such that it's 1.25*cwd/rtt
|
||||
delay = delay * 8 / 5
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {
|
||||
// Only update bytesInFlight for data packets.
|
||||
func (c *cubicSender) OnPacketSent(
|
||||
sentTime time.Time,
|
||||
bytesInFlight protocol.ByteCount,
|
||||
packetNumber protocol.PacketNumber,
|
||||
bytes protocol.ByteCount,
|
||||
isRetransmittable bool,
|
||||
) {
|
||||
if !isRetransmittable {
|
||||
return false
|
||||
return
|
||||
}
|
||||
if c.InRecovery() {
|
||||
// PRR is used when in recovery.
|
||||
|
@ -102,7 +112,6 @@ func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.By
|
|||
}
|
||||
c.largestSentPacketNumber = packetNumber
|
||||
c.hybridSlowStart.OnPacketSent(packetNumber)
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *cubicSender) InRecovery() bool {
|
||||
|
@ -114,18 +123,18 @@ func (c *cubicSender) InSlowStart() bool {
|
|||
}
|
||||
|
||||
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
|
||||
return protocol.ByteCount(c.congestionWindow) * protocol.DefaultTCPMSS
|
||||
return c.congestionWindow
|
||||
}
|
||||
|
||||
func (c *cubicSender) GetSlowStartThreshold() protocol.ByteCount {
|
||||
return protocol.ByteCount(c.slowstartThreshold) * protocol.DefaultTCPMSS
|
||||
return c.slowstartThreshold
|
||||
}
|
||||
|
||||
func (c *cubicSender) ExitSlowstart() {
|
||||
c.slowstartThreshold = c.congestionWindow
|
||||
}
|
||||
|
||||
func (c *cubicSender) SlowstartThreshold() protocol.PacketNumber {
|
||||
func (c *cubicSender) SlowstartThreshold() protocol.ByteCount {
|
||||
return c.slowstartThreshold
|
||||
}
|
||||
|
||||
|
@ -135,20 +144,29 @@ func (c *cubicSender) MaybeExitSlowStart() {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketAcked(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
|
||||
func (c *cubicSender) OnPacketAcked(
|
||||
ackedPacketNumber protocol.PacketNumber,
|
||||
ackedBytes protocol.ByteCount,
|
||||
priorInFlight protocol.ByteCount,
|
||||
eventTime time.Time,
|
||||
) {
|
||||
c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
|
||||
if c.InRecovery() {
|
||||
// PRR is used when in recovery.
|
||||
c.prr.OnPacketAcked(ackedBytes)
|
||||
return
|
||||
}
|
||||
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, bytesInFlight)
|
||||
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
|
||||
if c.InSlowStart() {
|
||||
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
|
||||
func (c *cubicSender) OnPacketLost(
|
||||
packetNumber protocol.PacketNumber,
|
||||
lostBytes protocol.ByteCount,
|
||||
priorInFlight protocol.ByteCount,
|
||||
) {
|
||||
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
|
||||
// already sent should be treated as a single loss event, since it's expected.
|
||||
if packetNumber <= c.largestSentAtLastCutback {
|
||||
|
@ -156,10 +174,8 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
|
|||
c.stats.slowstartPacketsLost++
|
||||
c.stats.slowstartBytesLost += lostBytes
|
||||
if c.slowStartLargeReduction {
|
||||
if c.stats.slowstartPacketsLost == 1 || (c.stats.slowstartBytesLost/protocol.DefaultTCPMSS) > (c.stats.slowstartBytesLost-lostBytes)/protocol.DefaultTCPMSS {
|
||||
// Reduce congestion window by 1 for every mss of bytes lost.
|
||||
c.congestionWindow = utils.MaxPacketNumber(c.congestionWindow-1, c.minCongestionWindow)
|
||||
}
|
||||
// Reduce congestion window by lost_bytes for every loss.
|
||||
c.congestionWindow = utils.MaxByteCount(c.congestionWindow-lostBytes, c.minSlowStartExitWindow)
|
||||
c.slowstartThreshold = c.congestionWindow
|
||||
}
|
||||
}
|
||||
|
@ -170,17 +186,19 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
|
|||
c.stats.slowstartPacketsLost++
|
||||
}
|
||||
|
||||
c.prr.OnPacketLost(bytesInFlight)
|
||||
c.prr.OnPacketLost(priorInFlight)
|
||||
|
||||
// TODO(chromium): Separate out all of slow start into a separate class.
|
||||
if c.slowStartLargeReduction && c.InSlowStart() {
|
||||
c.congestionWindow = c.congestionWindow - 1
|
||||
if c.congestionWindow >= 2*c.initialCongestionWindow {
|
||||
c.minSlowStartExitWindow = c.congestionWindow / 2
|
||||
}
|
||||
c.congestionWindow = c.congestionWindow - protocol.DefaultTCPMSS
|
||||
} else if c.reno {
|
||||
c.congestionWindow = protocol.PacketNumber(float32(c.congestionWindow) * c.RenoBeta())
|
||||
c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta())
|
||||
} else {
|
||||
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
|
||||
}
|
||||
// Enforce a minimum congestion window.
|
||||
if c.congestionWindow < c.minCongestionWindow {
|
||||
c.congestionWindow = c.minCongestionWindow
|
||||
}
|
||||
|
@ -188,7 +206,7 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
|
|||
c.largestSentAtLastCutback = c.largestSentPacketNumber
|
||||
// reset packet count from congestion avoidance mode. We start
|
||||
// counting again when we're out of recovery.
|
||||
c.congestionWindowCount = 0
|
||||
c.numAckedPackets = 0
|
||||
}
|
||||
|
||||
func (c *cubicSender) RenoBeta() float32 {
|
||||
|
@ -201,32 +219,38 @@ func (c *cubicSender) RenoBeta() float32 {
|
|||
|
||||
// Called when we receive an ack. Normal TCP tracks how many packets one ack
|
||||
// represents, but quic has a separate ack for each packet.
|
||||
func (c *cubicSender) maybeIncreaseCwnd(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
|
||||
func (c *cubicSender) maybeIncreaseCwnd(
|
||||
ackedPacketNumber protocol.PacketNumber,
|
||||
ackedBytes protocol.ByteCount,
|
||||
priorInFlight protocol.ByteCount,
|
||||
eventTime time.Time,
|
||||
) {
|
||||
// Do not increase the congestion window unless the sender is close to using
|
||||
// the current window.
|
||||
if !c.isCwndLimited(bytesInFlight) {
|
||||
if !c.isCwndLimited(priorInFlight) {
|
||||
c.cubic.OnApplicationLimited()
|
||||
return
|
||||
}
|
||||
if c.congestionWindow >= c.maxTCPCongestionWindow {
|
||||
if c.congestionWindow >= c.maxCongestionWindow {
|
||||
return
|
||||
}
|
||||
if c.InSlowStart() {
|
||||
// TCP slow start, exponential growth, increase by one for each ACK.
|
||||
c.congestionWindow++
|
||||
c.congestionWindow += protocol.DefaultTCPMSS
|
||||
return
|
||||
}
|
||||
// Congestion avoidance
|
||||
if c.reno {
|
||||
// Classic Reno congestion avoidance.
|
||||
c.congestionWindowCount++
|
||||
c.numAckedPackets++
|
||||
// Divide by num_connections to smoothly increase the CWND at a faster
|
||||
// rate than conventional Reno.
|
||||
if protocol.PacketNumber(c.congestionWindowCount*protocol.ByteCount(c.numConnections)) >= c.congestionWindow {
|
||||
c.congestionWindow++
|
||||
c.congestionWindowCount = 0
|
||||
if c.numAckedPackets*uint64(c.numConnections) >= uint64(c.congestionWindow)/uint64(protocol.DefaultTCPMSS) {
|
||||
c.congestionWindow += protocol.DefaultTCPMSS
|
||||
c.numAckedPackets = 0
|
||||
}
|
||||
} else {
|
||||
c.congestionWindow = utils.MinPacketNumber(c.maxTCPCongestionWindow, c.cubic.CongestionWindowAfterAck(c.congestionWindow, c.rttStats.MinRTT()))
|
||||
c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow, c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -282,10 +306,10 @@ func (c *cubicSender) OnConnectionMigration() {
|
|||
c.largestSentAtLastCutback = 0
|
||||
c.lastCutbackExitedSlowstart = false
|
||||
c.cubic.Reset()
|
||||
c.congestionWindowCount = 0
|
||||
c.numAckedPackets = 0
|
||||
c.congestionWindow = c.initialCongestionWindow
|
||||
c.slowstartThreshold = c.initialMaxCongestionWindow
|
||||
c.maxTCPCongestionWindow = c.initialMaxCongestionWindow
|
||||
c.maxCongestionWindow = c.initialMaxCongestionWindow
|
||||
}
|
||||
|
||||
// SetSlowStartLargeReduction allows enabling the SSLR experiment
|
||||
|
|
8
vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go
generated
vendored
8
vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go
generated
vendored
|
@ -9,11 +9,11 @@ import (
|
|||
// A SendAlgorithm performs congestion control and calculates the congestion window
|
||||
type SendAlgorithm interface {
|
||||
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration
|
||||
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool
|
||||
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool)
|
||||
GetCongestionWindow() protocol.ByteCount
|
||||
MaybeExitSlowStart()
|
||||
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount)
|
||||
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount)
|
||||
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time)
|
||||
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
|
||||
SetNumEmulatedConnections(n int)
|
||||
OnRetransmissionTimeout(packetsRetransmitted bool)
|
||||
OnConnectionMigration()
|
||||
|
@ -30,7 +30,7 @@ type SendAlgorithmWithDebugInfo interface {
|
|||
// Stuff only used in testing
|
||||
|
||||
HybridSlowStart() *HybridSlowStart
|
||||
SlowstartThreshold() protocol.PacketNumber
|
||||
SlowstartThreshold() protocol.ByteCount
|
||||
RenoBeta() float32
|
||||
InRecovery() bool
|
||||
}
|
||||
|
|
23
vendor/github.com/lucas-clemente/quic-go/internal/congestion/prr_sender.go
generated
vendored
23
vendor/github.com/lucas-clemente/quic-go/internal/congestion/prr_sender.go
generated
vendored
|
@ -1,10 +1,7 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937
|
||||
|
@ -23,9 +20,9 @@ func (p *PrrSender) OnPacketSent(sentBytes protocol.ByteCount) {
|
|||
// OnPacketLost should be called on the first loss that triggers a recovery
|
||||
// period and all other methods in this class should only be called when in
|
||||
// recovery.
|
||||
func (p *PrrSender) OnPacketLost(bytesInFlight protocol.ByteCount) {
|
||||
func (p *PrrSender) OnPacketLost(priorInFlight protocol.ByteCount) {
|
||||
p.bytesSentSinceLoss = 0
|
||||
p.bytesInFlightBeforeLoss = bytesInFlight
|
||||
p.bytesInFlightBeforeLoss = priorInFlight
|
||||
p.bytesDeliveredSinceLoss = 0
|
||||
p.ackCountSinceLoss = 0
|
||||
}
|
||||
|
@ -36,28 +33,22 @@ func (p *PrrSender) OnPacketAcked(ackedBytes protocol.ByteCount) {
|
|||
p.ackCountSinceLoss++
|
||||
}
|
||||
|
||||
// TimeUntilSend calculates the time until a packet can be sent
|
||||
func (p *PrrSender) TimeUntilSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) time.Duration {
|
||||
// CanSend returns if packets can be sent
|
||||
func (p *PrrSender) CanSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) bool {
|
||||
// Return QuicTime::Zero In order to ensure limited transmit always works.
|
||||
if p.bytesSentSinceLoss == 0 || bytesInFlight < protocol.DefaultTCPMSS {
|
||||
return 0
|
||||
return true
|
||||
}
|
||||
if congestionWindow > bytesInFlight {
|
||||
// During PRR-SSRB, limit outgoing packets to 1 extra MSS per ack, instead
|
||||
// of sending the entire available window. This prevents burst retransmits
|
||||
// when more packets are lost than the CWND reduction.
|
||||
// limit = MAX(prr_delivered - prr_out, DeliveredData) + MSS
|
||||
if p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS <= p.bytesSentSinceLoss {
|
||||
return utils.InfDuration
|
||||
}
|
||||
return 0
|
||||
return p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS > p.bytesSentSinceLoss
|
||||
}
|
||||
// Implement Proportional Rate Reduction (RFC6937).
|
||||
// Checks a simplified version of the PRR formula that doesn't use division:
|
||||
// AvailableSendWindow =
|
||||
// CEIL(prr_delivered * ssthresh / BytesInFlightAtLoss) - prr_sent
|
||||
if p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss {
|
||||
return 0
|
||||
}
|
||||
return utils.InfDuration
|
||||
return p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss
|
||||
}
|
||||
|
|
111
vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go
generated
vendored
111
vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go
generated
vendored
|
@ -7,50 +7,27 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
// Note: This constant is also defined in the ackhandler package.
|
||||
initialRTTus = 100 * 1000
|
||||
rttAlpha float32 = 0.125
|
||||
oneMinusAlpha float32 = (1 - rttAlpha)
|
||||
rttBeta float32 = 0.25
|
||||
oneMinusBeta float32 = (1 - rttBeta)
|
||||
halfWindow float32 = 0.5
|
||||
quarterWindow float32 = 0.25
|
||||
// The default RTT used before an RTT sample is taken.
|
||||
defaultInitialRTT = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
type rttSample struct {
|
||||
rtt time.Duration
|
||||
time time.Time
|
||||
}
|
||||
|
||||
// RTTStats provides round-trip statistics
|
||||
type RTTStats struct {
|
||||
initialRTTus int64
|
||||
|
||||
recentMinRTTwindow time.Duration
|
||||
minRTT time.Duration
|
||||
latestRTT time.Duration
|
||||
smoothedRTT time.Duration
|
||||
meanDeviation time.Duration
|
||||
|
||||
numMinRTTsamplesRemaining uint32
|
||||
|
||||
newMinRTT rttSample
|
||||
recentMinRTT rttSample
|
||||
halfWindowRTT rttSample
|
||||
quarterWindowRTT rttSample
|
||||
minRTT time.Duration
|
||||
latestRTT time.Duration
|
||||
smoothedRTT time.Duration
|
||||
meanDeviation time.Duration
|
||||
}
|
||||
|
||||
// NewRTTStats makes a properly initialized RTTStats object
|
||||
func NewRTTStats() *RTTStats {
|
||||
return &RTTStats{
|
||||
initialRTTus: initialRTTus,
|
||||
recentMinRTTwindow: utils.InfDuration,
|
||||
}
|
||||
return &RTTStats{}
|
||||
}
|
||||
|
||||
// InitialRTTus is the initial RTT in us
|
||||
func (r *RTTStats) InitialRTTus() int64 { return r.initialRTTus }
|
||||
|
||||
// MinRTT Returns the minRTT for the entire connection.
|
||||
// May return Zero if no valid updates have occurred.
|
||||
func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
|
||||
|
@ -59,28 +36,22 @@ func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
|
|||
// May return Zero if no valid updates have occurred.
|
||||
func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT }
|
||||
|
||||
// RecentMinRTT the minRTT since SampleNewRecentMinRtt has been called, or the
|
||||
// minRTT for the entire connection if SampleNewMinRtt was never called.
|
||||
func (r *RTTStats) RecentMinRTT() time.Duration { return r.recentMinRTT.rtt }
|
||||
|
||||
// SmoothedRTT returns the EWMA smoothed RTT for the connection.
|
||||
// May return Zero if no valid updates have occurred.
|
||||
func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT }
|
||||
|
||||
// GetQuarterWindowRTT gets the quarter window RTT
|
||||
func (r *RTTStats) GetQuarterWindowRTT() time.Duration { return r.quarterWindowRTT.rtt }
|
||||
|
||||
// GetHalfWindowRTT gets the half window RTT
|
||||
func (r *RTTStats) GetHalfWindowRTT() time.Duration { return r.halfWindowRTT.rtt }
|
||||
// SmoothedOrInitialRTT returns the EWMA smoothed RTT for the connection.
|
||||
// If no valid updates have occurred, it returns the initial RTT.
|
||||
func (r *RTTStats) SmoothedOrInitialRTT() time.Duration {
|
||||
if r.smoothedRTT != 0 {
|
||||
return r.smoothedRTT
|
||||
}
|
||||
return defaultInitialRTT
|
||||
}
|
||||
|
||||
// MeanDeviation gets the mean deviation
|
||||
func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation }
|
||||
|
||||
// SetRecentMinRTTwindow sets how old a recent min rtt sample can be.
|
||||
func (r *RTTStats) SetRecentMinRTTwindow(recentMinRTTwindow time.Duration) {
|
||||
r.recentMinRTTwindow = recentMinRTTwindow
|
||||
}
|
||||
|
||||
// UpdateRTT updates the RTT based on a new sample.
|
||||
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
|
||||
if sendDelta == utils.InfDuration || sendDelta <= 0 {
|
||||
|
@ -94,7 +65,6 @@ func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
|
|||
if r.minRTT == 0 || r.minRTT > sendDelta {
|
||||
r.minRTT = sendDelta
|
||||
}
|
||||
r.updateRecentMinRTT(sendDelta, now)
|
||||
|
||||
// Correct for ackDelay if information received from the peer results in a
|
||||
// an RTT sample at least as large as minRTT. Otherwise, only use the
|
||||
|
@ -114,63 +84,12 @@ func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *RTTStats) updateRecentMinRTT(sample time.Duration, now time.Time) { // Recent minRTT update.
|
||||
if r.numMinRTTsamplesRemaining > 0 {
|
||||
r.numMinRTTsamplesRemaining--
|
||||
if r.newMinRTT.rtt == 0 || sample <= r.newMinRTT.rtt {
|
||||
r.newMinRTT = rttSample{rtt: sample, time: now}
|
||||
}
|
||||
if r.numMinRTTsamplesRemaining == 0 {
|
||||
r.recentMinRTT = r.newMinRTT
|
||||
r.halfWindowRTT = r.newMinRTT
|
||||
r.quarterWindowRTT = r.newMinRTT
|
||||
}
|
||||
}
|
||||
|
||||
// Update the three recent rtt samples.
|
||||
if r.recentMinRTT.rtt == 0 || sample <= r.recentMinRTT.rtt {
|
||||
r.recentMinRTT = rttSample{rtt: sample, time: now}
|
||||
r.halfWindowRTT = r.recentMinRTT
|
||||
r.quarterWindowRTT = r.recentMinRTT
|
||||
} else if sample <= r.halfWindowRTT.rtt {
|
||||
r.halfWindowRTT = rttSample{rtt: sample, time: now}
|
||||
r.quarterWindowRTT = r.halfWindowRTT
|
||||
} else if sample <= r.quarterWindowRTT.rtt {
|
||||
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
|
||||
}
|
||||
|
||||
// Expire old min rtt samples.
|
||||
if r.recentMinRTT.time.Before(now.Add(-r.recentMinRTTwindow)) {
|
||||
r.recentMinRTT = r.halfWindowRTT
|
||||
r.halfWindowRTT = r.quarterWindowRTT
|
||||
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
|
||||
} else if r.halfWindowRTT.time.Before(now.Add(-time.Duration(float32(r.recentMinRTTwindow/time.Microsecond)*halfWindow) * time.Microsecond)) {
|
||||
r.halfWindowRTT = r.quarterWindowRTT
|
||||
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
|
||||
} else if r.quarterWindowRTT.time.Before(now.Add(-time.Duration(float32(r.recentMinRTTwindow/time.Microsecond)*quarterWindow) * time.Microsecond)) {
|
||||
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
|
||||
}
|
||||
}
|
||||
|
||||
// SampleNewRecentMinRTT forces RttStats to sample a new recent min rtt within the next
|
||||
// |numSamples| UpdateRTT calls.
|
||||
func (r *RTTStats) SampleNewRecentMinRTT(numSamples uint32) {
|
||||
r.numMinRTTsamplesRemaining = numSamples
|
||||
r.newMinRTT = rttSample{}
|
||||
}
|
||||
|
||||
// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset.
|
||||
func (r *RTTStats) OnConnectionMigration() {
|
||||
r.latestRTT = 0
|
||||
r.minRTT = 0
|
||||
r.smoothedRTT = 0
|
||||
r.meanDeviation = 0
|
||||
r.initialRTTus = initialRTTus
|
||||
r.numMinRTTsamplesRemaining = 0
|
||||
r.recentMinRTTwindow = utils.InfDuration
|
||||
r.recentMinRTT = rttSample{}
|
||||
r.halfWindowRTT = rttSample{}
|
||||
r.quarterWindowRTT = rttSample{}
|
||||
}
|
||||
|
||||
// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt
|
||||
|
|
4
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go
generated
vendored
4
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go
generated
vendored
|
@ -15,7 +15,7 @@ const (
|
|||
|
||||
// A TLSExporter gets the negotiated ciphersuite and computes exporter
|
||||
type TLSExporter interface {
|
||||
GetCipherSuite() mint.CipherSuiteParams
|
||||
ConnectionState() mint.ConnectionState
|
||||
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
|
||||
}
|
||||
|
||||
|
@ -49,7 +49,7 @@ func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
|
|||
}
|
||||
|
||||
func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) {
|
||||
cs := tls.GetCipherSuite()
|
||||
cs := tls.ConnectionState().CipherSuite
|
||||
secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
@ -42,7 +41,7 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol
|
|||
} else {
|
||||
info.Write([]byte("QUIC key expansion\x00"))
|
||||
}
|
||||
utils.BigEndian.WriteUint64(&info, uint64(connID))
|
||||
info.Write(connID)
|
||||
info.Write(chlo)
|
||||
info.Write(scfg)
|
||||
info.Write(cert)
|
||||
|
|
5
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go
generated
vendored
5
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go
generated
vendored
|
@ -2,7 +2,6 @@ package crypto
|
|||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
@ -28,9 +27,7 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec
|
|||
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
|
||||
}
|
||||
|
||||
func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
||||
connID := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(connID, uint64(connectionID))
|
||||
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
||||
handshakeSecret := mint.HkdfExtract(crypto.SHA256, quicVersion1Salt, connID)
|
||||
clientSecret = qhkdfExpand(handshakeSecret, "client hs", crypto.SHA256.Size())
|
||||
serverSecret = qhkdfExpand(handshakeSecret, "server hs", crypto.SHA256.Size())
|
||||
|
|
16
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
16
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
|
@ -11,8 +11,9 @@ import (
|
|||
|
||||
type baseFlowController struct {
|
||||
// for sending data
|
||||
bytesSent protocol.ByteCount
|
||||
sendWindow protocol.ByteCount
|
||||
bytesSent protocol.ByteCount
|
||||
sendWindow protocol.ByteCount
|
||||
lastBlockedAt protocol.ByteCount
|
||||
|
||||
// for receiving data
|
||||
mutex sync.RWMutex
|
||||
|
@ -29,6 +30,17 @@ type baseFlowController struct {
|
|||
logger utils.Logger
|
||||
}
|
||||
|
||||
// IsNewlyBlocked says if it is newly blocked by flow control.
|
||||
// For every offset, it only returns true once.
|
||||
// If it is blocked, the offset is returned.
|
||||
func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
|
||||
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
|
||||
return false, 0
|
||||
}
|
||||
c.lastBlockedAt = c.sendWindow
|
||||
return true, c.sendWindow
|
||||
}
|
||||
|
||||
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||
c.bytesSent += n
|
||||
}
|
||||
|
|
|
@ -10,8 +10,9 @@ import (
|
|||
)
|
||||
|
||||
type connectionFlowController struct {
|
||||
lastBlockedAt protocol.ByteCount
|
||||
baseFlowController
|
||||
|
||||
queueWindowUpdate func()
|
||||
}
|
||||
|
||||
var _ ConnectionFlowController = &connectionFlowController{}
|
||||
|
@ -21,6 +22,7 @@ var _ ConnectionFlowController = &connectionFlowController{}
|
|||
func NewConnectionFlowController(
|
||||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
queueWindowUpdate func(),
|
||||
rttStats *congestion.RTTStats,
|
||||
logger utils.Logger,
|
||||
) ConnectionFlowController {
|
||||
|
@ -32,6 +34,7 @@ func NewConnectionFlowController(
|
|||
maxReceiveWindowSize: maxReceiveWindow,
|
||||
logger: logger,
|
||||
},
|
||||
queueWindowUpdate: queueWindowUpdate,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -39,17 +42,6 @@ func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
|
|||
return c.baseFlowController.sendWindowSize()
|
||||
}
|
||||
|
||||
// IsNewlyBlocked says if it is newly blocked by flow control.
|
||||
// For every offset, it only returns true once.
|
||||
// If it is blocked, the offset is returned.
|
||||
func (c *connectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
|
||||
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
|
||||
return false, 0
|
||||
}
|
||||
c.lastBlockedAt = c.sendWindow
|
||||
return true, c.sendWindow
|
||||
}
|
||||
|
||||
// IncrementHighestReceived adds an increment to the highestReceived value
|
||||
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
|
||||
c.mutex.Lock()
|
||||
|
@ -62,6 +54,15 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) MaybeQueueWindowUpdate() {
|
||||
c.mutex.Lock()
|
||||
hasWindowUpdate := c.hasWindowUpdate()
|
||||
c.mutex.Unlock()
|
||||
if hasWindowUpdate {
|
||||
c.queueWindowUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
c.mutex.Lock()
|
||||
oldWindowSize := c.receiveWindowSize
|
||||
|
@ -78,6 +79,7 @@ func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
|
|||
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
|
||||
c.mutex.Lock()
|
||||
if inc > c.receiveWindowSize {
|
||||
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10))
|
||||
c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize)
|
||||
c.startNewAutoTuningEpoch()
|
||||
}
|
||||
|
|
8
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go
generated
vendored
8
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go
generated
vendored
|
@ -10,26 +10,22 @@ type flowController interface {
|
|||
// for receiving
|
||||
AddBytesRead(protocol.ByteCount)
|
||||
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
|
||||
MaybeQueueWindowUpdate() // queues a window update, if necessary
|
||||
IsNewlyBlocked() (bool, protocol.ByteCount)
|
||||
}
|
||||
|
||||
// A StreamFlowController is a flow controller for a QUIC stream.
|
||||
type StreamFlowController interface {
|
||||
flowController
|
||||
// for sending
|
||||
IsBlocked() (bool, protocol.ByteCount)
|
||||
// for receiving
|
||||
// UpdateHighestReceived should be called when a new highest offset is received
|
||||
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
|
||||
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
|
||||
// HasWindowUpdate says if it is necessary to update the window
|
||||
HasWindowUpdate() bool
|
||||
}
|
||||
|
||||
// The ConnectionFlowController is the flow controller for the connection.
|
||||
type ConnectionFlowController interface {
|
||||
flowController
|
||||
// for sending
|
||||
IsNewlyBlocked() (bool, protocol.ByteCount)
|
||||
}
|
||||
|
||||
type connectionFlowControllerI interface {
|
||||
|
|
24
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
24
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
|
@ -14,6 +14,8 @@ type streamFlowController struct {
|
|||
|
||||
streamID protocol.StreamID
|
||||
|
||||
queueWindowUpdate func()
|
||||
|
||||
connection connectionFlowControllerI
|
||||
contributesToConnection bool // does the stream contribute to connection level flow control
|
||||
|
||||
|
@ -30,6 +32,7 @@ func NewStreamFlowController(
|
|||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
initialSendWindow protocol.ByteCount,
|
||||
queueWindowUpdate func(protocol.StreamID),
|
||||
rttStats *congestion.RTTStats,
|
||||
logger utils.Logger,
|
||||
) StreamFlowController {
|
||||
|
@ -37,6 +40,7 @@ func NewStreamFlowController(
|
|||
streamID: streamID,
|
||||
contributesToConnection: contributesToConnection,
|
||||
connection: cfc.(connectionFlowControllerI),
|
||||
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
|
||||
baseFlowController: baseFlowController{
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
|
@ -111,20 +115,16 @@ func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
|||
return window
|
||||
}
|
||||
|
||||
// IsBlocked says if it is blocked by stream-level flow control.
|
||||
// If it is blocked, the offset is returned.
|
||||
func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) {
|
||||
if c.sendWindowSize() != 0 {
|
||||
return false, 0
|
||||
}
|
||||
return true, c.sendWindow
|
||||
}
|
||||
|
||||
func (c *streamFlowController) HasWindowUpdate() bool {
|
||||
func (c *streamFlowController) MaybeQueueWindowUpdate() {
|
||||
c.mutex.Lock()
|
||||
hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate()
|
||||
c.mutex.Unlock()
|
||||
return hasWindowUpdate
|
||||
if hasWindowUpdate {
|
||||
c.queueWindowUpdate()
|
||||
}
|
||||
if c.contributesToConnection {
|
||||
c.connection.MaybeQueueWindowUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
|
@ -139,7 +139,7 @@ func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
|||
oldWindowSize := c.receiveWindowSize
|
||||
offset := c.baseFlowController.getWindowUpdate()
|
||||
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
|
||||
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
|
||||
if c.contributesToConnection {
|
||||
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
|
||||
}
|
||||
|
|
6
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go
generated
vendored
6
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go
generated
vendored
|
@ -5,8 +5,6 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -29,12 +27,12 @@ type token struct {
|
|||
|
||||
// A CookieGenerator generates Cookies
|
||||
type CookieGenerator struct {
|
||||
cookieProtector mint.CookieProtector
|
||||
cookieProtector cookieProtector
|
||||
}
|
||||
|
||||
// NewCookieGenerator initializes a new CookieGenerator
|
||||
func NewCookieGenerator() (*CookieGenerator, error) {
|
||||
cookieProtector, err := mint.NewDefaultCookieProtector()
|
||||
cookieProtector, err := newCookieProtector()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
51
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go
generated
vendored
51
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go
generated
vendored
|
@ -1,51 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// A CookieHandler generates and validates cookies.
|
||||
// The cookie is sent in the TLS Retry.
|
||||
// By including the cookie in its ClientHello, a client can proof ownership of its source address.
|
||||
type CookieHandler struct {
|
||||
callback func(net.Addr, *Cookie) bool
|
||||
cookieGenerator *CookieGenerator
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ mint.CookieHandler = &CookieHandler{}
|
||||
|
||||
// NewCookieHandler creates a new CookieHandler.
|
||||
func NewCookieHandler(callback func(net.Addr, *Cookie) bool, logger utils.Logger) (*CookieHandler, error) {
|
||||
cookieGenerator, err := NewCookieGenerator()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CookieHandler{
|
||||
callback: callback,
|
||||
cookieGenerator: cookieGenerator,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Generate a new cookie for a mint connection.
|
||||
func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
|
||||
if h.callback(conn.RemoteAddr(), nil) {
|
||||
return nil, nil
|
||||
}
|
||||
return h.cookieGenerator.NewToken(conn.RemoteAddr())
|
||||
}
|
||||
|
||||
// Validate a cookie.
|
||||
func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
|
||||
data, err := h.cookieGenerator.DecodeToken(token)
|
||||
if err != nil {
|
||||
h.logger.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
|
||||
return false
|
||||
}
|
||||
return h.callback(conn.RemoteAddr(), data)
|
||||
}
|
86
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_protector.go
generated
vendored
Normal file
86
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_protector.go
generated
vendored
Normal file
|
@ -0,0 +1,86 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// CookieProtector is used to create and verify a cookie
|
||||
type cookieProtector interface {
|
||||
// NewToken creates a new token
|
||||
NewToken([]byte) ([]byte, error)
|
||||
// DecodeToken decodes a token
|
||||
DecodeToken([]byte) ([]byte, error)
|
||||
}
|
||||
|
||||
const (
|
||||
cookieSecretSize = 32
|
||||
cookieNonceSize = 32
|
||||
)
|
||||
|
||||
// cookieProtector is used to create and verify a cookie
|
||||
type cookieProtectorImpl struct {
|
||||
secret []byte
|
||||
}
|
||||
|
||||
// newCookieProtector creates a source for source address tokens
|
||||
func newCookieProtector() (cookieProtector, error) {
|
||||
secret := make([]byte, cookieSecretSize)
|
||||
if _, err := rand.Read(secret); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cookieProtectorImpl{secret: secret}, nil
|
||||
}
|
||||
|
||||
// NewToken encodes data into a new token.
|
||||
func (s *cookieProtectorImpl) NewToken(data []byte) ([]byte, error) {
|
||||
nonce := make([]byte, cookieNonceSize)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aead, aeadNonce, err := s.createAEAD(nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil
|
||||
}
|
||||
|
||||
// DecodeToken decodes a token.
|
||||
func (s *cookieProtectorImpl) DecodeToken(p []byte) ([]byte, error) {
|
||||
if len(p) < cookieNonceSize {
|
||||
return nil, fmt.Errorf("Token too short: %d", len(p))
|
||||
}
|
||||
nonce := p[:cookieNonceSize]
|
||||
aead, aeadNonce, err := s.createAEAD(nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return aead.Open(nil, aeadNonce, p[cookieNonceSize:], nil)
|
||||
}
|
||||
|
||||
func (s *cookieProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
|
||||
h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go cookie source"))
|
||||
key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
|
||||
if _, err := io.ReadFull(h, key); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
aeadNonce := make([]byte, 12)
|
||||
if _, err := io.ReadFull(h, aeadNonce); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
c, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
aead, err := cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return aead, aeadNonce, nil
|
||||
}
|
69
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go
generated
vendored
69
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go
generated
vendored
|
@ -38,7 +38,7 @@ type cryptoSetupClient struct {
|
|||
lastSentCHLO []byte
|
||||
certManager crypto.CertManager
|
||||
|
||||
divNonceChan <-chan []byte
|
||||
divNonceChan chan struct{}
|
||||
diversificationNonce []byte
|
||||
|
||||
clientHelloCounter int
|
||||
|
@ -79,29 +79,31 @@ func NewCryptoSetupClient(
|
|||
initialVersion protocol.VersionNumber,
|
||||
negotiatedVersions []protocol.VersionNumber,
|
||||
logger utils.Logger,
|
||||
) (CryptoSetup, chan<- []byte, error) {
|
||||
) (CryptoSetup, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
divNonceChan := make(chan []byte)
|
||||
divNonceChan := make(chan struct{})
|
||||
cs := &cryptoSetupClient{
|
||||
cryptoStream: cryptoStream,
|
||||
hostname: hostname,
|
||||
connID: connID,
|
||||
version: version,
|
||||
certManager: crypto.NewCertManager(tlsConfig),
|
||||
params: params,
|
||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||
nullAEAD: nullAEAD,
|
||||
paramsChan: paramsChan,
|
||||
handshakeEvent: handshakeEvent,
|
||||
initialVersion: initialVersion,
|
||||
negotiatedVersions: negotiatedVersions,
|
||||
cryptoStream: cryptoStream,
|
||||
hostname: hostname,
|
||||
connID: connID,
|
||||
version: version,
|
||||
certManager: crypto.NewCertManager(tlsConfig),
|
||||
params: params,
|
||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||
nullAEAD: nullAEAD,
|
||||
paramsChan: paramsChan,
|
||||
handshakeEvent: handshakeEvent,
|
||||
initialVersion: initialVersion,
|
||||
// The server might have sent greased versions in the Version Negotiation packet.
|
||||
// We need strip those from the list, since they won't be included in the handshake tag.
|
||||
negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions),
|
||||
divNonceChan: divNonceChan,
|
||||
logger: logger,
|
||||
}
|
||||
return cs, divNonceChan, nil
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) HandleCryptoStream() error {
|
||||
|
@ -120,33 +122,26 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
|
|||
}()
|
||||
|
||||
for {
|
||||
err := h.maybeUpgradeCrypto()
|
||||
if err != nil {
|
||||
if err := h.maybeUpgradeCrypto(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.mutex.RLock()
|
||||
sendCHLO := h.secureAEAD == nil
|
||||
h.mutex.RUnlock()
|
||||
|
||||
if sendCHLO {
|
||||
err = h.sendCHLO()
|
||||
if err != nil {
|
||||
if err := h.sendCHLO(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var message HandshakeMessage
|
||||
select {
|
||||
case divNonce := <-h.divNonceChan:
|
||||
if len(h.diversificationNonce) != 0 && !bytes.Equal(h.diversificationNonce, divNonce) {
|
||||
return errConflictingDiversificationNonces
|
||||
}
|
||||
h.diversificationNonce = divNonce
|
||||
case <-h.divNonceChan:
|
||||
// there's no message to process, but we should try upgrading the crypto again
|
||||
continue
|
||||
case message = <-messageChan:
|
||||
case err = <-errorChan:
|
||||
case err := <-errorChan:
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -281,6 +276,7 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*Trans
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.logger.Debugf("Creating AEAD for forward-secure encryption. Stopping to accept all lower encryption levels.")
|
||||
|
||||
params, err := readHelloMap(cryptoData)
|
||||
if err != nil {
|
||||
|
@ -326,6 +322,7 @@ func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNu
|
|||
if h.secureAEAD != nil {
|
||||
data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err == nil {
|
||||
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
|
||||
h.receivedSecurePacket = true
|
||||
return data, protocol.EncryptionSecure, nil
|
||||
}
|
||||
|
@ -386,6 +383,21 @@ func (h *cryptoSetupClient) ConnectionState() ConnectionState {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) SetDiversificationNonce(divNonce []byte) error {
|
||||
h.mutex.Lock()
|
||||
if len(h.diversificationNonce) > 0 {
|
||||
defer h.mutex.Unlock()
|
||||
if !bytes.Equal(h.diversificationNonce, divNonce) {
|
||||
return errConflictingDiversificationNonces
|
||||
}
|
||||
return nil
|
||||
}
|
||||
h.diversificationNonce = divNonce
|
||||
h.mutex.Unlock()
|
||||
h.divNonceChan <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) sendCHLO() error {
|
||||
h.clientHelloCounter++
|
||||
if h.clientHelloCounter > protocol.MaxClientHellos {
|
||||
|
@ -501,6 +513,7 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.logger.Debugf("Creating AEAD for secure encryption.")
|
||||
h.handshakeEvent <- struct{}{}
|
||||
}
|
||||
return nil
|
||||
|
|
12
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go
generated
vendored
12
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go
generated
vendored
|
@ -60,11 +60,6 @@ type cryptoSetupServer struct {
|
|||
|
||||
var _ CryptoSetup = &cryptoSetupServer{}
|
||||
|
||||
// ErrHOLExperiment is returned when the client sends the FHL2 tag in the CHLO.
|
||||
// This is an experiment implemented by Chrome in QUIC 36, which we don't support.
|
||||
// TODO: remove this when dropping support for QUIC 36
|
||||
var ErrHOLExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "HOL experiment. Unsupported")
|
||||
|
||||
// ErrNSTPExperiment is returned when the client sends the NSTP tag in the CHLO.
|
||||
// This is an experiment implemented by Chrome in QUIC 38, which we don't support at this point.
|
||||
var ErrNSTPExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "NSTP experiment. Unsupported")
|
||||
|
@ -132,9 +127,6 @@ func (h *cryptoSetupServer) HandleCryptoStream() error {
|
|||
}
|
||||
|
||||
func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]byte) (bool, error) {
|
||||
if _, isHOLExperiment := cryptoData[TagFHL2]; isHOLExperiment {
|
||||
return false, ErrHOLExperiment
|
||||
}
|
||||
if _, isNSTPExperiment := cryptoData[TagNSTP]; isNSTPExperiment {
|
||||
return false, ErrNSTPExperiment
|
||||
}
|
||||
|
@ -214,6 +206,7 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
|
|||
res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err == nil {
|
||||
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
|
||||
h.logger.Debugf("Received first forward-secure packet. Stopping to accept all lower encryption levels.")
|
||||
h.receivedForwardSecurePacket = true
|
||||
// wait for the send on the handshakeEvent chan
|
||||
<-h.sentSHLO
|
||||
|
@ -228,6 +221,7 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
|
|||
if h.secureAEAD != nil {
|
||||
res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err == nil {
|
||||
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
|
||||
h.receivedSecurePacket = true
|
||||
return res, protocol.EncryptionSecure, nil
|
||||
}
|
||||
|
@ -400,6 +394,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.logger.Debugf("Creating AEAD for secure encryption.")
|
||||
h.handshakeEvent <- struct{}{}
|
||||
|
||||
// Generate a new curve instance to derive the forward secure key
|
||||
|
@ -429,6 +424,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.logger.Debugf("Creating AEAD for forward-secure encryption.")
|
||||
|
||||
replyMap := h.params.getHelloMap()
|
||||
// add crypto parameters
|
||||
|
|
54
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go
generated
vendored
54
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go
generated
vendored
|
@ -11,9 +11,6 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// ErrCloseSessionForRetry is returned by HandleCryptoStream when the server wishes to perform a stateless retry
|
||||
var ErrCloseSessionForRetry = errors.New("closing session in order to recreate after a retry")
|
||||
|
||||
// KeyDerivationFunction is used for key derivation
|
||||
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
|
||||
|
||||
|
@ -26,8 +23,8 @@ type cryptoSetupTLS struct {
|
|||
nullAEAD crypto.AEAD
|
||||
aead crypto.AEAD
|
||||
|
||||
tls MintTLS
|
||||
cryptoStream *CryptoStreamConn
|
||||
tls mintTLS
|
||||
conn *cryptoStreamConn
|
||||
handshakeEvent chan<- struct{}
|
||||
}
|
||||
|
||||
|
@ -35,39 +32,46 @@ var _ CryptoSetupTLS = &cryptoSetupTLS{}
|
|||
|
||||
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
|
||||
func NewCryptoSetupTLSServer(
|
||||
tls MintTLS,
|
||||
cryptoStream *CryptoStreamConn,
|
||||
nullAEAD crypto.AEAD,
|
||||
cryptoStream io.ReadWriter,
|
||||
connID protocol.ConnectionID,
|
||||
config *mint.Config,
|
||||
handshakeEvent chan<- struct{},
|
||||
version protocol.VersionNumber,
|
||||
) CryptoSetupTLS {
|
||||
) (CryptoSetupTLS, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn := newCryptoStreamConn(cryptoStream)
|
||||
tls := mint.Server(conn, config)
|
||||
return &cryptoSetupTLS{
|
||||
tls: tls,
|
||||
cryptoStream: cryptoStream,
|
||||
conn: conn,
|
||||
nullAEAD: nullAEAD,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
handshakeEvent: handshakeEvent,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
|
||||
func NewCryptoSetupTLSClient(
|
||||
cryptoStream io.ReadWriter,
|
||||
connID protocol.ConnectionID,
|
||||
hostname string,
|
||||
config *mint.Config,
|
||||
handshakeEvent chan<- struct{},
|
||||
tls MintTLS,
|
||||
version protocol.VersionNumber,
|
||||
) (CryptoSetupTLS, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := newCryptoStreamConn(cryptoStream)
|
||||
tls := mint.Client(conn, config)
|
||||
return &cryptoSetupTLS{
|
||||
perspective: protocol.PerspectiveClient,
|
||||
tls: tls,
|
||||
conn: conn,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
nullAEAD: nullAEAD,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
handshakeEvent: handshakeEvent,
|
||||
|
@ -75,24 +79,16 @@ func NewCryptoSetupTLSClient(
|
|||
}
|
||||
|
||||
func (h *cryptoSetupTLS) HandleCryptoStream() error {
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
// mint already wrote the ServerHello, EncryptedExtensions and the certificate chain to the buffer
|
||||
// send out that data now
|
||||
if _, err := h.cryptoStream.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
handshakeLoop:
|
||||
for {
|
||||
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
|
||||
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
||||
}
|
||||
switch h.tls.State() {
|
||||
case mint.StateClientStart: // this happens if a stateless retry is performed
|
||||
return ErrCloseSessionForRetry
|
||||
case mint.StateClientConnected, mint.StateServerConnected:
|
||||
break handshakeLoop
|
||||
state := h.tls.ConnectionState().HandshakeState
|
||||
if err := h.conn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
if state == mint.StateClientConnected || state == mint.StateServerConnected {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
|
|
86
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go
generated
vendored
86
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go
generated
vendored
|
@ -7,95 +7,63 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// The CryptoStreamConn is used as the net.Conn passed to mint.
|
||||
// It has two operating modes:
|
||||
// 1. It can read and write to bytes.Buffers.
|
||||
// 2. It can use a quic.Stream for reading and writing.
|
||||
// The buffer-mode is only used by the server, in order to statelessly handle retries.
|
||||
type CryptoStreamConn struct {
|
||||
remoteAddr net.Addr
|
||||
|
||||
// the buffers are used before the session is initialized
|
||||
readBuf bytes.Buffer
|
||||
writeBuf bytes.Buffer
|
||||
|
||||
// stream will be set once the session is initialized
|
||||
type cryptoStreamConn struct {
|
||||
buffer *bytes.Buffer
|
||||
stream io.ReadWriter
|
||||
}
|
||||
|
||||
var _ net.Conn = &CryptoStreamConn{}
|
||||
var _ net.Conn = &cryptoStreamConn{}
|
||||
|
||||
// NewCryptoStreamConn creates a new CryptoStreamConn
|
||||
func NewCryptoStreamConn(remoteAddr net.Addr) *CryptoStreamConn {
|
||||
return &CryptoStreamConn{remoteAddr: remoteAddr}
|
||||
}
|
||||
|
||||
func (c *CryptoStreamConn) Read(b []byte) (int, error) {
|
||||
if c.stream != nil {
|
||||
return c.stream.Read(b)
|
||||
func newCryptoStreamConn(stream io.ReadWriter) *cryptoStreamConn {
|
||||
return &cryptoStreamConn{
|
||||
stream: stream,
|
||||
buffer: &bytes.Buffer{},
|
||||
}
|
||||
return c.readBuf.Read(b)
|
||||
}
|
||||
|
||||
// AddDataForReading adds data to the read buffer.
|
||||
// This data will ONLY be read when the stream has not been set.
|
||||
func (c *CryptoStreamConn) AddDataForReading(data []byte) {
|
||||
c.readBuf.Write(data)
|
||||
func (c *cryptoStreamConn) Read(b []byte) (int, error) {
|
||||
return c.stream.Read(b)
|
||||
}
|
||||
|
||||
func (c *CryptoStreamConn) Write(p []byte) (int, error) {
|
||||
if c.stream != nil {
|
||||
return c.stream.Write(p)
|
||||
func (c *cryptoStreamConn) Write(p []byte) (int, error) {
|
||||
return c.buffer.Write(p)
|
||||
}
|
||||
|
||||
func (c *cryptoStreamConn) Flush() error {
|
||||
if c.buffer.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.writeBuf.Write(p)
|
||||
}
|
||||
|
||||
// GetDataForWriting returns all data currently in the write buffer, and resets this buffer.
|
||||
func (c *CryptoStreamConn) GetDataForWriting() []byte {
|
||||
defer c.writeBuf.Reset()
|
||||
data := make([]byte, c.writeBuf.Len())
|
||||
copy(data, c.writeBuf.Bytes())
|
||||
return data
|
||||
}
|
||||
|
||||
// SetStream sets the stream.
|
||||
// After setting the stream, the read and write buffer won't be used any more.
|
||||
func (c *CryptoStreamConn) SetStream(stream io.ReadWriter) {
|
||||
c.stream = stream
|
||||
}
|
||||
|
||||
// Flush copies the contents of the write buffer to the stream
|
||||
func (c *CryptoStreamConn) Flush() (int, error) {
|
||||
n, err := io.Copy(c.stream, &c.writeBuf)
|
||||
return int(n), err
|
||||
_, err := c.stream.Write(c.buffer.Bytes())
|
||||
c.buffer.Reset()
|
||||
return err
|
||||
}
|
||||
|
||||
// Close is not implemented
|
||||
func (c *CryptoStreamConn) Close() error {
|
||||
func (c *cryptoStreamConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr is not implemented
|
||||
func (c *CryptoStreamConn) LocalAddr() net.Addr {
|
||||
func (c *cryptoStreamConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address
|
||||
func (c *CryptoStreamConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
// RemoteAddr is not implemented
|
||||
func (c *cryptoStreamConn) RemoteAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline is not implemented
|
||||
func (c *CryptoStreamConn) SetReadDeadline(time.Time) error {
|
||||
func (c *cryptoStreamConn) SetReadDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline is not implemented
|
||||
func (c *CryptoStreamConn) SetWriteDeadline(time.Time) error {
|
||||
func (c *cryptoStreamConn) SetWriteDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDeadline is not implemented
|
||||
func (c *CryptoStreamConn) SetDeadline(time.Time) error {
|
||||
func (c *cryptoStreamConn) SetDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
|
2
vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go
generated
vendored
2
vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go
generated
vendored
|
@ -35,7 +35,7 @@ func getEphermalKEX() (crypto.KeyExchange, error) {
|
|||
kexMutex.Lock()
|
||||
defer kexMutex.Unlock()
|
||||
// Check if still unfulfilled
|
||||
if kexCurrent == nil || time.Since(kexCurrentTime) > kexLifetime {
|
||||
if kexCurrent == nil || time.Since(kexCurrentTime) >= kexLifetime {
|
||||
kex, err := crypto.NewCurve25519KEX()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
19
vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go
generated
vendored
19
vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go
generated
vendored
|
@ -2,7 +2,6 @@ package handshake
|
|||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"io"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
|
@ -15,6 +14,12 @@ type Sealer interface {
|
|||
Overhead() int
|
||||
}
|
||||
|
||||
// mintTLS combines some methods needed to interact with mint.
|
||||
type mintTLS interface {
|
||||
crypto.TLSExporter
|
||||
Handshake() mint.Alert
|
||||
}
|
||||
|
||||
// A TLSExtensionHandler sends and received the QUIC TLS extension.
|
||||
// It provides the parameters sent by the peer on a channel.
|
||||
type TLSExtensionHandler interface {
|
||||
|
@ -23,18 +28,6 @@ type TLSExtensionHandler interface {
|
|||
GetPeerParams() <-chan TransportParameters
|
||||
}
|
||||
|
||||
// MintTLS combines some methods needed to interact with mint.
|
||||
type MintTLS interface {
|
||||
crypto.TLSExporter
|
||||
|
||||
// additional methods
|
||||
Handshake() mint.Alert
|
||||
State() mint.State
|
||||
ConnectionState() mint.ConnectionState
|
||||
|
||||
SetCryptoStream(io.ReadWriter)
|
||||
}
|
||||
|
||||
type baseCryptoSetup interface {
|
||||
HandleCryptoStream() error
|
||||
ConnectionState() ConnectionState
|
||||
|
|
3
vendor/github.com/lucas-clemente/quic-go/internal/handshake/mockgen.go
generated
vendored
Normal file
3
vendor/github.com/lucas-clemente/quic-go/internal/handshake/mockgen.go
generated
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
package handshake
|
||||
|
||||
//go:generate sh -c "../mockgen_internal.sh handshake mock_mint_tls_test.go github.com/lucas-clemente/quic-go/internal/handshake mintTLS"
|
4
vendor/github.com/lucas-clemente/quic-go/internal/handshake/tags.go
generated
vendored
4
vendor/github.com/lucas-clemente/quic-go/internal/handshake/tags.go
generated
vendored
|
@ -50,10 +50,6 @@ const (
|
|||
// TagSFCW is the initial stream flow control receive window.
|
||||
TagSFCW Tag = 'S' + 'F'<<8 + 'C'<<16 + 'W'<<24
|
||||
|
||||
// TagFHL2 forces head of line blocking.
|
||||
// Chrome experiment (see https://codereview.chromium.org/2115033002)
|
||||
// unsupported by quic-go
|
||||
TagFHL2 Tag = 'F' + 'H'<<8 + 'L'<<16 + '2'<<24
|
||||
// TagNSTP is the no STOP_WAITING experiment
|
||||
// currently unsupported by quic-go
|
||||
TagNSTP Tag = 'N' + 'S'<<8 + 'T'<<16 + 'P'<<24
|
||||
|
|
94
vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go
generated
vendored
94
vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go
generated
vendored
|
@ -1,38 +1,106 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type transportParameterID uint16
|
||||
|
||||
const quicTLSExtensionType = 26
|
||||
const quicTLSExtensionType = 0xff5
|
||||
|
||||
const (
|
||||
initialMaxStreamDataParameterID transportParameterID = 0x0
|
||||
initialMaxDataParameterID transportParameterID = 0x1
|
||||
initialMaxStreamsBiDiParameterID transportParameterID = 0x2
|
||||
initialMaxBidiStreamsParameterID transportParameterID = 0x2
|
||||
idleTimeoutParameterID transportParameterID = 0x3
|
||||
omitConnectionIDParameterID transportParameterID = 0x4
|
||||
maxPacketSizeParameterID transportParameterID = 0x5
|
||||
statelessResetTokenParameterID transportParameterID = 0x6
|
||||
initialMaxStreamsUniParameterID transportParameterID = 0x8
|
||||
initialMaxUniStreamsParameterID transportParameterID = 0x8
|
||||
disableMigrationParameterID transportParameterID = 0x9
|
||||
)
|
||||
|
||||
type transportParameter struct {
|
||||
Parameter transportParameterID
|
||||
Value []byte `tls:"head=2"`
|
||||
type clientHelloTransportParameters struct {
|
||||
InitialVersion protocol.VersionNumber
|
||||
Parameters TransportParameters
|
||||
}
|
||||
|
||||
type clientHelloTransportParameters struct {
|
||||
InitialVersion uint32 // actually a protocol.VersionNumber
|
||||
Parameters []transportParameter `tls:"head=2"`
|
||||
func (p *clientHelloTransportParameters) Marshal() []byte {
|
||||
const lenOffset = 4
|
||||
b := &bytes.Buffer{}
|
||||
utils.BigEndian.WriteUint32(b, uint32(p.InitialVersion))
|
||||
b.Write([]byte{0, 0}) // length. Will be replaced later
|
||||
p.Parameters.marshal(b)
|
||||
data := b.Bytes()
|
||||
binary.BigEndian.PutUint16(data[lenOffset:lenOffset+2], uint16(len(data)-lenOffset-2))
|
||||
return data
|
||||
}
|
||||
|
||||
func (p *clientHelloTransportParameters) Unmarshal(data []byte) error {
|
||||
if len(data) < 6 {
|
||||
return errors.New("transport parameter data too short")
|
||||
}
|
||||
p.InitialVersion = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4]))
|
||||
paramsLen := int(binary.BigEndian.Uint16(data[4:6]))
|
||||
data = data[6:]
|
||||
if len(data) != paramsLen {
|
||||
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
|
||||
}
|
||||
return p.Parameters.unmarshal(data)
|
||||
}
|
||||
|
||||
type encryptedExtensionsTransportParameters struct {
|
||||
NegotiatedVersion uint32 // actually a protocol.VersionNumber
|
||||
SupportedVersions []uint32 `tls:"head=1"` // actually a protocol.VersionNumber
|
||||
Parameters []transportParameter `tls:"head=2"`
|
||||
NegotiatedVersion protocol.VersionNumber
|
||||
SupportedVersions []protocol.VersionNumber
|
||||
Parameters TransportParameters
|
||||
}
|
||||
|
||||
func (p *encryptedExtensionsTransportParameters) Marshal() []byte {
|
||||
b := &bytes.Buffer{}
|
||||
utils.BigEndian.WriteUint32(b, uint32(p.NegotiatedVersion))
|
||||
b.WriteByte(uint8(4 * len(p.SupportedVersions)))
|
||||
for _, v := range p.SupportedVersions {
|
||||
utils.BigEndian.WriteUint32(b, uint32(v))
|
||||
}
|
||||
lenOffset := b.Len()
|
||||
b.Write([]byte{0, 0}) // length. Will be replaced later
|
||||
p.Parameters.marshal(b)
|
||||
data := b.Bytes()
|
||||
binary.BigEndian.PutUint16(data[lenOffset:lenOffset+2], uint16(len(data)-lenOffset-2))
|
||||
return data
|
||||
}
|
||||
|
||||
func (p *encryptedExtensionsTransportParameters) Unmarshal(data []byte) error {
|
||||
if len(data) < 5 {
|
||||
return errors.New("transport parameter data too short")
|
||||
}
|
||||
p.NegotiatedVersion = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4]))
|
||||
numVersions := int(data[4])
|
||||
if numVersions%4 != 0 {
|
||||
return fmt.Errorf("invalid length for version list: %d", numVersions)
|
||||
}
|
||||
numVersions /= 4
|
||||
data = data[5:]
|
||||
if len(data) < 4*numVersions+2 /*length field for the parameter list */ {
|
||||
return errors.New("transport parameter data too short")
|
||||
}
|
||||
p.SupportedVersions = make([]protocol.VersionNumber, numVersions)
|
||||
for i := 0; i < numVersions; i++ {
|
||||
p.SupportedVersions[i] = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4]))
|
||||
data = data[4:]
|
||||
}
|
||||
paramsLen := int(binary.BigEndian.Uint16(data[:2]))
|
||||
data = data[2:]
|
||||
if len(data) != paramsLen {
|
||||
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
|
||||
}
|
||||
return p.Parameters.unmarshal(data)
|
||||
}
|
||||
|
||||
type tlsExtensionBody struct {
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/bifurcation/mint/syntax"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
@ -52,16 +51,12 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi
|
|||
if hType != mint.HandshakeTypeClientHello {
|
||||
return nil
|
||||
}
|
||||
|
||||
h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
|
||||
data, err := syntax.Marshal(clientHelloTransportParameters{
|
||||
InitialVersion: uint32(h.initialVersion),
|
||||
Parameters: h.ourParams.getTransportParameters(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
chtp := &clientHelloTransportParameters{
|
||||
InitialVersion: h.initialVersion,
|
||||
Parameters: *h.ourParams,
|
||||
}
|
||||
return el.Add(&tlsExtensionBody{data})
|
||||
return el.Add(&tlsExtensionBody{data: chtp.Marshal()})
|
||||
}
|
||||
|
||||
func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
|
||||
|
@ -84,50 +79,31 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte
|
|||
}
|
||||
|
||||
eetp := &encryptedExtensionsTransportParameters{}
|
||||
if _, err := syntax.Unmarshal(ext.data, eetp); err != nil {
|
||||
if err := eetp.Unmarshal(ext.data); err != nil {
|
||||
return err
|
||||
}
|
||||
serverSupportedVersions := make([]protocol.VersionNumber, len(eetp.SupportedVersions))
|
||||
for i, v := range eetp.SupportedVersions {
|
||||
serverSupportedVersions[i] = protocol.VersionNumber(v)
|
||||
}
|
||||
// check that the negotiated_version is the current version
|
||||
if protocol.VersionNumber(eetp.NegotiatedVersion) != h.version {
|
||||
if eetp.NegotiatedVersion != h.version {
|
||||
return qerr.Error(qerr.VersionNegotiationMismatch, "current version doesn't match negotiated_version")
|
||||
}
|
||||
// check that the current version is included in the supported versions
|
||||
if !protocol.IsSupportedVersion(serverSupportedVersions, h.version) {
|
||||
if !protocol.IsSupportedVersion(eetp.SupportedVersions, h.version) {
|
||||
return qerr.Error(qerr.VersionNegotiationMismatch, "current version not included in the supported versions")
|
||||
}
|
||||
// if version negotiation was performed, check that we would have selected the current version based on the supported versions sent by the server
|
||||
if h.version != h.initialVersion {
|
||||
negotiatedVersion, ok := protocol.ChooseSupportedVersion(h.supportedVersions, serverSupportedVersions)
|
||||
negotiatedVersion, ok := protocol.ChooseSupportedVersion(h.supportedVersions, eetp.SupportedVersions)
|
||||
if !ok || h.version != negotiatedVersion {
|
||||
return qerr.Error(qerr.VersionNegotiationMismatch, "would have picked a different version")
|
||||
}
|
||||
}
|
||||
|
||||
// check that the server sent the stateless reset token
|
||||
var foundStatelessResetToken bool
|
||||
for _, p := range eetp.Parameters {
|
||||
if p.Parameter == statelessResetTokenParameterID {
|
||||
if len(p.Value) != 16 {
|
||||
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", len(p.Value))
|
||||
}
|
||||
foundStatelessResetToken = true
|
||||
// TODO: handle this value
|
||||
}
|
||||
}
|
||||
if !foundStatelessResetToken {
|
||||
// TODO: return the right error here
|
||||
// check that the server sent a stateless reset token
|
||||
if len(eetp.Parameters.StatelessResetToken) == 0 {
|
||||
return errors.New("server didn't sent stateless_reset_token")
|
||||
}
|
||||
params, err := readTransportParameters(eetp.Parameters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.logger.Debugf("Received Transport Parameters: %s", params)
|
||||
h.paramsChan <- *params
|
||||
h.logger.Debugf("Received Transport Parameters: %s", &eetp.Parameters)
|
||||
h.paramsChan <- eetp.Parameters
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/bifurcation/mint/syntax"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
@ -49,27 +47,13 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi
|
|||
if hType != mint.HandshakeTypeEncryptedExtensions {
|
||||
return nil
|
||||
}
|
||||
|
||||
transportParams := append(
|
||||
h.ourParams.getTransportParameters(),
|
||||
// TODO(#855): generate a real token
|
||||
transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)},
|
||||
)
|
||||
supportedVersions := protocol.GetGreasedVersions(h.supportedVersions)
|
||||
versions := make([]uint32, len(supportedVersions))
|
||||
for i, v := range supportedVersions {
|
||||
versions[i] = uint32(v)
|
||||
}
|
||||
h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
|
||||
data, err := syntax.Marshal(encryptedExtensionsTransportParameters{
|
||||
NegotiatedVersion: uint32(h.version),
|
||||
SupportedVersions: versions,
|
||||
Parameters: transportParams,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
eetp := &encryptedExtensionsTransportParameters{
|
||||
NegotiatedVersion: h.version,
|
||||
SupportedVersions: protocol.GetGreasedVersions(h.supportedVersions),
|
||||
Parameters: *h.ourParams,
|
||||
}
|
||||
return el.Add(&tlsExtensionBody{data})
|
||||
return el.Add(&tlsExtensionBody{data: eetp.Marshal()})
|
||||
}
|
||||
|
||||
func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
|
||||
|
@ -90,30 +74,24 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte
|
|||
return errors.New("ClientHello didn't contain a QUIC extension")
|
||||
}
|
||||
chtp := &clientHelloTransportParameters{}
|
||||
if _, err := syntax.Unmarshal(ext.data, chtp); err != nil {
|
||||
if err := chtp.Unmarshal(ext.data); err != nil {
|
||||
return err
|
||||
}
|
||||
initialVersion := protocol.VersionNumber(chtp.InitialVersion)
|
||||
|
||||
// perform the stateless version negotiation validation:
|
||||
// make sure that we would have sent a Version Negotiation Packet if the client offered the initial version
|
||||
// this is the case if and only if the initial version is not contained in the supported versions
|
||||
if initialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, initialVersion) {
|
||||
if chtp.InitialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, chtp.InitialVersion) {
|
||||
return qerr.Error(qerr.VersionNegotiationMismatch, "Client should have used the initial version")
|
||||
}
|
||||
|
||||
for _, p := range chtp.Parameters {
|
||||
if p.Parameter == statelessResetTokenParameterID {
|
||||
// TODO: return the correct error type
|
||||
return errors.New("client sent a stateless reset token")
|
||||
}
|
||||
// check that the client didn't send a stateless reset token
|
||||
if len(chtp.Parameters.StatelessResetToken) != 0 {
|
||||
// TODO: return the correct error type
|
||||
return errors.New("client sent a stateless reset token")
|
||||
}
|
||||
params, err := readTransportParameters(chtp.Parameters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.logger.Debugf("Received Transport Parameters: %s", params)
|
||||
h.paramsChan <- *params
|
||||
h.logger.Debugf("Received Transport Parameters: %s", &chtp.Parameters)
|
||||
h.paramsChan <- chtp.Parameters
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
154
vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go
generated
vendored
154
vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go
generated
vendored
|
@ -26,8 +26,10 @@ type TransportParameters struct {
|
|||
MaxBidiStreams uint16 // only used for IETF QUIC
|
||||
MaxStreams uint32 // only used for gQUIC
|
||||
|
||||
OmitConnectionID bool
|
||||
IdleTimeout time.Duration
|
||||
OmitConnectionID bool // only used for gQUIC
|
||||
IdleTimeout time.Duration
|
||||
DisableMigration bool // only used for IETF QUIC
|
||||
StatelessResetToken []byte // only used for IETF QUIC
|
||||
}
|
||||
|
||||
// readHelloMap reads the transport parameters from the tags sent in a gQUIC handshake message
|
||||
|
@ -94,98 +96,114 @@ func (p *TransportParameters) getHelloMap() map[Tag][]byte {
|
|||
return tags
|
||||
}
|
||||
|
||||
// readTransportParameters reads the transport parameters sent in the QUIC TLS extension
|
||||
func readTransportParameters(paramsList []transportParameter) (*TransportParameters, error) {
|
||||
params := &TransportParameters{}
|
||||
|
||||
var foundInitialMaxStreamData bool
|
||||
var foundInitialMaxData bool
|
||||
func (p *TransportParameters) unmarshal(data []byte) error {
|
||||
var foundIdleTimeout bool
|
||||
|
||||
for _, p := range paramsList {
|
||||
switch p.Parameter {
|
||||
for len(data) >= 4 {
|
||||
paramID := binary.BigEndian.Uint16(data[:2])
|
||||
paramLen := int(binary.BigEndian.Uint16(data[2:4]))
|
||||
data = data[4:]
|
||||
if len(data) < paramLen {
|
||||
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(data), paramLen)
|
||||
}
|
||||
switch transportParameterID(paramID) {
|
||||
case initialMaxStreamDataParameterID:
|
||||
foundInitialMaxStreamData = true
|
||||
if len(p.Value) != 4 {
|
||||
return nil, fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", len(p.Value))
|
||||
if paramLen != 4 {
|
||||
return fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", paramLen)
|
||||
}
|
||||
params.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value))
|
||||
p.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4]))
|
||||
case initialMaxDataParameterID:
|
||||
foundInitialMaxData = true
|
||||
if len(p.Value) != 4 {
|
||||
return nil, fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", len(p.Value))
|
||||
if paramLen != 4 {
|
||||
return fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", paramLen)
|
||||
}
|
||||
params.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value))
|
||||
case initialMaxStreamsBiDiParameterID:
|
||||
if len(p.Value) != 2 {
|
||||
return nil, fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 2)", len(p.Value))
|
||||
p.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4]))
|
||||
case initialMaxBidiStreamsParameterID:
|
||||
if paramLen != 2 {
|
||||
return fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 2)", paramLen)
|
||||
}
|
||||
params.MaxBidiStreams = binary.BigEndian.Uint16(p.Value)
|
||||
case initialMaxStreamsUniParameterID:
|
||||
if len(p.Value) != 2 {
|
||||
return nil, fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 2)", len(p.Value))
|
||||
p.MaxBidiStreams = binary.BigEndian.Uint16(data[:2])
|
||||
case initialMaxUniStreamsParameterID:
|
||||
if paramLen != 2 {
|
||||
return fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 2)", paramLen)
|
||||
}
|
||||
params.MaxUniStreams = binary.BigEndian.Uint16(p.Value)
|
||||
p.MaxUniStreams = binary.BigEndian.Uint16(data[:2])
|
||||
case idleTimeoutParameterID:
|
||||
foundIdleTimeout = true
|
||||
if len(p.Value) != 2 {
|
||||
return nil, fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", len(p.Value))
|
||||
if paramLen != 2 {
|
||||
return fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", paramLen)
|
||||
}
|
||||
params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(p.Value))*time.Second)
|
||||
case omitConnectionIDParameterID:
|
||||
if len(p.Value) != 0 {
|
||||
return nil, fmt.Errorf("wrong length for omit_connection_id: %d (expected empty)", len(p.Value))
|
||||
}
|
||||
params.OmitConnectionID = true
|
||||
p.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(data[:2]))*time.Second)
|
||||
case maxPacketSizeParameterID:
|
||||
if len(p.Value) != 2 {
|
||||
return nil, fmt.Errorf("wrong length for max_packet_size: %d (expected 2)", len(p.Value))
|
||||
if paramLen != 2 {
|
||||
return fmt.Errorf("wrong length for max_packet_size: %d (expected 2)", paramLen)
|
||||
}
|
||||
maxPacketSize := protocol.ByteCount(binary.BigEndian.Uint16(p.Value))
|
||||
maxPacketSize := protocol.ByteCount(binary.BigEndian.Uint16(data[:2]))
|
||||
if maxPacketSize < 1200 {
|
||||
return nil, fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", maxPacketSize)
|
||||
return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", maxPacketSize)
|
||||
}
|
||||
params.MaxPacketSize = maxPacketSize
|
||||
p.MaxPacketSize = maxPacketSize
|
||||
case disableMigrationParameterID:
|
||||
if paramLen != 0 {
|
||||
return fmt.Errorf("wrong length for disable_migration: %d (expected empty)", paramLen)
|
||||
}
|
||||
p.DisableMigration = true
|
||||
case statelessResetTokenParameterID:
|
||||
if paramLen != 16 {
|
||||
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
|
||||
}
|
||||
p.StatelessResetToken = data[:16]
|
||||
}
|
||||
data = data[paramLen:]
|
||||
}
|
||||
|
||||
if !(foundInitialMaxStreamData && foundInitialMaxData && foundIdleTimeout) {
|
||||
return nil, errors.New("missing parameter")
|
||||
if len(data) != 0 {
|
||||
return fmt.Errorf("should have read all data. Still have %d bytes", len(data))
|
||||
}
|
||||
return params, nil
|
||||
if !foundIdleTimeout {
|
||||
return errors.New("missing parameter")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTransportParameters gets the parameters needed for the TLS handshake.
|
||||
// It doesn't send the initial_max_stream_id_uni parameter, so the peer isn't allowed to open any unidirectional streams.
|
||||
func (p *TransportParameters) getTransportParameters() []transportParameter {
|
||||
initialMaxStreamData := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(initialMaxStreamData, uint32(p.StreamFlowControlWindow))
|
||||
initialMaxData := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(initialMaxData, uint32(p.ConnectionFlowControlWindow))
|
||||
initialMaxBidiStreamID := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(initialMaxBidiStreamID, p.MaxBidiStreams)
|
||||
initialMaxUniStreamID := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(initialMaxUniStreamID, p.MaxUniStreams)
|
||||
idleTimeout := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(idleTimeout, uint16(p.IdleTimeout/time.Second))
|
||||
maxPacketSize := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(maxPacketSize, uint16(protocol.MaxReceivePacketSize))
|
||||
params := []transportParameter{
|
||||
{initialMaxStreamDataParameterID, initialMaxStreamData},
|
||||
{initialMaxDataParameterID, initialMaxData},
|
||||
{initialMaxStreamsBiDiParameterID, initialMaxBidiStreamID},
|
||||
{initialMaxStreamsUniParameterID, initialMaxUniStreamID},
|
||||
{idleTimeoutParameterID, idleTimeout},
|
||||
{maxPacketSizeParameterID, maxPacketSize},
|
||||
func (p *TransportParameters) marshal(b *bytes.Buffer) {
|
||||
// initial_max_stream_data
|
||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataParameterID))
|
||||
utils.BigEndian.WriteUint16(b, 4)
|
||||
utils.BigEndian.WriteUint32(b, uint32(p.StreamFlowControlWindow))
|
||||
// initial_max_data
|
||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxDataParameterID))
|
||||
utils.BigEndian.WriteUint16(b, 4)
|
||||
utils.BigEndian.WriteUint32(b, uint32(p.ConnectionFlowControlWindow))
|
||||
// initial_max_bidi_streams
|
||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxBidiStreamsParameterID))
|
||||
utils.BigEndian.WriteUint16(b, 2)
|
||||
utils.BigEndian.WriteUint16(b, p.MaxBidiStreams)
|
||||
// initial_max_uni_streams
|
||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxUniStreamsParameterID))
|
||||
utils.BigEndian.WriteUint16(b, 2)
|
||||
utils.BigEndian.WriteUint16(b, p.MaxUniStreams)
|
||||
// idle_timeout
|
||||
utils.BigEndian.WriteUint16(b, uint16(idleTimeoutParameterID))
|
||||
utils.BigEndian.WriteUint16(b, 2)
|
||||
utils.BigEndian.WriteUint16(b, uint16(p.IdleTimeout/time.Second))
|
||||
// max_packet_size
|
||||
utils.BigEndian.WriteUint16(b, uint16(maxPacketSizeParameterID))
|
||||
utils.BigEndian.WriteUint16(b, 2)
|
||||
utils.BigEndian.WriteUint16(b, uint16(protocol.MaxReceivePacketSize))
|
||||
// disable_migration
|
||||
if p.DisableMigration {
|
||||
utils.BigEndian.WriteUint16(b, uint16(disableMigrationParameterID))
|
||||
utils.BigEndian.WriteUint16(b, 0)
|
||||
}
|
||||
if p.OmitConnectionID {
|
||||
params = append(params, transportParameter{omitConnectionIDParameterID, []byte{}})
|
||||
if len(p.StatelessResetToken) > 0 {
|
||||
utils.BigEndian.WriteUint16(b, uint16(statelessResetTokenParameterID))
|
||||
utils.BigEndian.WriteUint16(b, uint16(len(p.StatelessResetToken))) // should always be 16 bytes
|
||||
b.Write(p.StatelessResetToken)
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// String returns a string representation, intended for logging.
|
||||
// It should only used for IETF QUIC.
|
||||
func (p *TransportParameters) String() string {
|
||||
return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, OmitConnectionID: %t, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.OmitConnectionID, p.IdleTimeout)
|
||||
return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout)
|
||||
}
|
||||
|
|
|
@ -49,6 +49,19 @@ func (mr *MockSentPacketHandlerMockRecorder) DequeuePacketForRetransmission() *g
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DequeuePacketForRetransmission", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeuePacketForRetransmission))
|
||||
}
|
||||
|
||||
// DequeueProbePacket mocks base method
|
||||
func (m *MockSentPacketHandler) DequeueProbePacket() (*ackhandler.Packet, error) {
|
||||
ret := m.ctrl.Call(m, "DequeueProbePacket")
|
||||
ret0, _ := ret[0].(*ackhandler.Packet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DequeueProbePacket indicates an expected call of DequeueProbePacket
|
||||
func (mr *MockSentPacketHandlerMockRecorder) DequeueProbePacket() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DequeueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeueProbePacket))
|
||||
}
|
||||
|
||||
// GetAlarmTimeout mocks base method
|
||||
func (m *MockSentPacketHandler) GetAlarmTimeout() time.Time {
|
||||
ret := m.ctrl.Call(m, "GetAlarmTimeout")
|
||||
|
|
14
vendor/github.com/lucas-clemente/quic-go/internal/mocks/congestion.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/internal/mocks/congestion.go
generated
vendored
|
@ -68,13 +68,13 @@ func (mr *MockSendAlgorithmMockRecorder) OnConnectionMigration() *gomock.Call {
|
|||
}
|
||||
|
||||
// OnPacketAcked mocks base method
|
||||
func (m *MockSendAlgorithm) OnPacketAcked(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount) {
|
||||
m.ctrl.Call(m, "OnPacketAcked", arg0, arg1, arg2)
|
||||
func (m *MockSendAlgorithm) OnPacketAcked(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount, arg3 time.Time) {
|
||||
m.ctrl.Call(m, "OnPacketAcked", arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// OnPacketAcked indicates an expected call of OnPacketAcked
|
||||
func (mr *MockSendAlgorithmMockRecorder) OnPacketAcked(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithm)(nil).OnPacketAcked), arg0, arg1, arg2)
|
||||
func (mr *MockSendAlgorithmMockRecorder) OnPacketAcked(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithm)(nil).OnPacketAcked), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// OnPacketLost mocks base method
|
||||
|
@ -88,10 +88,8 @@ func (mr *MockSendAlgorithmMockRecorder) OnPacketLost(arg0, arg1, arg2 interface
|
|||
}
|
||||
|
||||
// OnPacketSent mocks base method
|
||||
func (m *MockSendAlgorithm) OnPacketSent(arg0 time.Time, arg1 protocol.ByteCount, arg2 protocol.PacketNumber, arg3 protocol.ByteCount, arg4 bool) bool {
|
||||
ret := m.ctrl.Call(m, "OnPacketSent", arg0, arg1, arg2, arg3, arg4)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
func (m *MockSendAlgorithm) OnPacketSent(arg0 time.Time, arg1 protocol.ByteCount, arg2 protocol.PacketNumber, arg3 protocol.ByteCount, arg4 bool) {
|
||||
m.ctrl.Call(m, "OnPacketSent", arg0, arg1, arg2, arg3, arg4)
|
||||
}
|
||||
|
||||
// OnPacketSent indicates an expected call of OnPacketSent
|
||||
|
|
10
vendor/github.com/lucas-clemente/quic-go/internal/mocks/connection_flow_controller.go
generated
vendored
10
vendor/github.com/lucas-clemente/quic-go/internal/mocks/connection_flow_controller.go
generated
vendored
|
@ -79,6 +79,16 @@ func (mr *MockConnectionFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Cal
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockConnectionFlowController)(nil).IsNewlyBlocked))
|
||||
}
|
||||
|
||||
// MaybeQueueWindowUpdate mocks base method
|
||||
func (m *MockConnectionFlowController) MaybeQueueWindowUpdate() {
|
||||
m.ctrl.Call(m, "MaybeQueueWindowUpdate")
|
||||
}
|
||||
|
||||
// MaybeQueueWindowUpdate indicates an expected call of MaybeQueueWindowUpdate
|
||||
func (mr *MockConnectionFlowControllerMockRecorder) MaybeQueueWindowUpdate() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeQueueWindowUpdate", reflect.TypeOf((*MockConnectionFlowController)(nil).MaybeQueueWindowUpdate))
|
||||
}
|
||||
|
||||
// SendWindowSize mocks base method
|
||||
func (m *MockConnectionFlowController) SendWindowSize() protocol.ByteCount {
|
||||
ret := m.ctrl.Call(m, "SendWindowSize")
|
||||
|
|
11
vendor/github.com/lucas-clemente/quic-go/internal/mocks/gen.go
generated
vendored
11
vendor/github.com/lucas-clemente/quic-go/internal/mocks/gen.go
generated
vendored
|
@ -1,11 +0,0 @@
|
|||
package mocks
|
||||
|
||||
//go:generate sh -c "./mockgen_internal.sh mockhandshake handshake/mint_tls.go github.com/lucas-clemente/quic-go/internal/handshake MintTLS"
|
||||
//go:generate sh -c "./mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler"
|
||||
//go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
|
||||
//go:generate sh -c "./mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler"
|
||||
//go:generate sh -c "./mockgen_internal.sh mockackhandler ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler"
|
||||
//go:generate sh -c "./mockgen_internal.sh mocks congestion.go github.com/lucas-clemente/quic-go/internal/congestion SendAlgorithm"
|
||||
//go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController"
|
||||
//go:generate sh -c "./mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD"
|
||||
//go:generate sh -c "goimports -w ."
|
107
vendor/github.com/lucas-clemente/quic-go/internal/mocks/handshake/mint_tls.go
generated
vendored
107
vendor/github.com/lucas-clemente/quic-go/internal/mocks/handshake/mint_tls.go
generated
vendored
|
@ -1,107 +0,0 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: MintTLS)
|
||||
|
||||
// Package mockhandshake is a generated GoMock package.
|
||||
package mockhandshake
|
||||
|
||||
import (
|
||||
io "io"
|
||||
reflect "reflect"
|
||||
|
||||
mint "github.com/bifurcation/mint"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockMintTLS is a mock of MintTLS interface
|
||||
type MockMintTLS struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockMintTLSMockRecorder
|
||||
}
|
||||
|
||||
// MockMintTLSMockRecorder is the mock recorder for MockMintTLS
|
||||
type MockMintTLSMockRecorder struct {
|
||||
mock *MockMintTLS
|
||||
}
|
||||
|
||||
// NewMockMintTLS creates a new mock instance
|
||||
func NewMockMintTLS(ctrl *gomock.Controller) *MockMintTLS {
|
||||
mock := &MockMintTLS{ctrl: ctrl}
|
||||
mock.recorder = &MockMintTLSMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockMintTLS) EXPECT() *MockMintTLSMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// ComputeExporter mocks base method
|
||||
func (m *MockMintTLS) ComputeExporter(arg0 string, arg1 []byte, arg2 int) ([]byte, error) {
|
||||
ret := m.ctrl.Call(m, "ComputeExporter", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ComputeExporter indicates an expected call of ComputeExporter
|
||||
func (mr *MockMintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ComputeExporter", reflect.TypeOf((*MockMintTLS)(nil).ComputeExporter), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// ConnectionState mocks base method
|
||||
func (m *MockMintTLS) ConnectionState() mint.ConnectionState {
|
||||
ret := m.ctrl.Call(m, "ConnectionState")
|
||||
ret0, _ := ret[0].(mint.ConnectionState)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ConnectionState indicates an expected call of ConnectionState
|
||||
func (mr *MockMintTLSMockRecorder) ConnectionState() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockMintTLS)(nil).ConnectionState))
|
||||
}
|
||||
|
||||
// GetCipherSuite mocks base method
|
||||
func (m *MockMintTLS) GetCipherSuite() mint.CipherSuiteParams {
|
||||
ret := m.ctrl.Call(m, "GetCipherSuite")
|
||||
ret0, _ := ret[0].(mint.CipherSuiteParams)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetCipherSuite indicates an expected call of GetCipherSuite
|
||||
func (mr *MockMintTLSMockRecorder) GetCipherSuite() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCipherSuite", reflect.TypeOf((*MockMintTLS)(nil).GetCipherSuite))
|
||||
}
|
||||
|
||||
// Handshake mocks base method
|
||||
func (m *MockMintTLS) Handshake() mint.Alert {
|
||||
ret := m.ctrl.Call(m, "Handshake")
|
||||
ret0, _ := ret[0].(mint.Alert)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Handshake indicates an expected call of Handshake
|
||||
func (mr *MockMintTLSMockRecorder) Handshake() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Handshake", reflect.TypeOf((*MockMintTLS)(nil).Handshake))
|
||||
}
|
||||
|
||||
// SetCryptoStream mocks base method
|
||||
func (m *MockMintTLS) SetCryptoStream(arg0 io.ReadWriter) {
|
||||
m.ctrl.Call(m, "SetCryptoStream", arg0)
|
||||
}
|
||||
|
||||
// SetCryptoStream indicates an expected call of SetCryptoStream
|
||||
func (mr *MockMintTLSMockRecorder) SetCryptoStream(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCryptoStream", reflect.TypeOf((*MockMintTLS)(nil).SetCryptoStream), arg0)
|
||||
}
|
||||
|
||||
// State mocks base method
|
||||
func (m *MockMintTLS) State() mint.State {
|
||||
ret := m.ctrl.Call(m, "State")
|
||||
ret0, _ := ret[0].(mint.State)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// State indicates an expected call of State
|
||||
func (mr *MockMintTLSMockRecorder) State() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "State", reflect.TypeOf((*MockMintTLS)(nil).State))
|
||||
}
|
9
vendor/github.com/lucas-clemente/quic-go/internal/mocks/mockgen.go
generated
vendored
Normal file
9
vendor/github.com/lucas-clemente/quic-go/internal/mocks/mockgen.go
generated
vendored
Normal file
|
@ -0,0 +1,9 @@
|
|||
package mocks
|
||||
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
|
||||
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler"
|
||||
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks congestion.go github.com/lucas-clemente/quic-go/internal/congestion SendAlgorithm"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController"
|
||||
//go:generate sh -c "../mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD"
|
34
vendor/github.com/lucas-clemente/quic-go/internal/mocks/stream_flow_controller.go
generated
vendored
34
vendor/github.com/lucas-clemente/quic-go/internal/mocks/stream_flow_controller.go
generated
vendored
|
@ -66,29 +66,27 @@ func (mr *MockStreamFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).GetWindowUpdate))
|
||||
}
|
||||
|
||||
// HasWindowUpdate mocks base method
|
||||
func (m *MockStreamFlowController) HasWindowUpdate() bool {
|
||||
ret := m.ctrl.Call(m, "HasWindowUpdate")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// HasWindowUpdate indicates an expected call of HasWindowUpdate
|
||||
func (mr *MockStreamFlowControllerMockRecorder) HasWindowUpdate() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).HasWindowUpdate))
|
||||
}
|
||||
|
||||
// IsBlocked mocks base method
|
||||
func (m *MockStreamFlowController) IsBlocked() (bool, protocol.ByteCount) {
|
||||
ret := m.ctrl.Call(m, "IsBlocked")
|
||||
// IsNewlyBlocked mocks base method
|
||||
func (m *MockStreamFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
|
||||
ret := m.ctrl.Call(m, "IsNewlyBlocked")
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(protocol.ByteCount)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsBlocked indicates an expected call of IsBlocked
|
||||
func (mr *MockStreamFlowControllerMockRecorder) IsBlocked() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsBlocked))
|
||||
// IsNewlyBlocked indicates an expected call of IsNewlyBlocked
|
||||
func (mr *MockStreamFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsNewlyBlocked))
|
||||
}
|
||||
|
||||
// MaybeQueueWindowUpdate mocks base method
|
||||
func (m *MockStreamFlowController) MaybeQueueWindowUpdate() {
|
||||
m.ctrl.Call(m, "MaybeQueueWindowUpdate")
|
||||
}
|
||||
|
||||
// MaybeQueueWindowUpdate indicates an expected call of MaybeQueueWindowUpdate
|
||||
func (mr *MockStreamFlowControllerMockRecorder) MaybeQueueWindowUpdate() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeQueueWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).MaybeQueueWindowUpdate))
|
||||
}
|
||||
|
||||
// SendWindowSize mocks base method
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue