Skip to content

Commit

Permalink
tls1prf: require callers to pass in the result buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal committed Sep 8, 2023
1 parent 13f20f3 commit 3e5af4f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 22 deletions.
40 changes: 21 additions & 19 deletions tls1prf.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ func SupportsTLS1PRF() bool {
(vMajor >= 1 && vMinor >= 1)
}

func TLS1PRF(secret, label, seed []byte, keyLen int, h func() hash.Hash) ([]byte, error) {
// TLS1PRF implements the TLS 1.0/1.1 pseudo-random function if h is nil or crypto.MD5SHA1,
// else it implements the TLS 1.2 pseudo-random function.
// The pseudo-random number will be written to result and will be of length len(result).
func TLS1PRF(result, secret, label, seed []byte, h func() hash.Hash) error {
var md C.GO_EVP_MD_PTR
if h == nil {
// TLS 1.0/1.1 PRF doesn't allow to specify the hash function,
Expand All @@ -29,70 +32,69 @@ func TLS1PRF(secret, label, seed []byte, keyLen int, h func() hash.Hash) ([]byte
md = hashToMD(h())
}
if md == nil {
return nil, errors.New("unsupported hash function")
return errors.New("unsupported hash function")
}

ctx := C.go_openssl_EVP_PKEY_CTX_new_id(C.GO_EVP_PKEY_TLS1_PRF, nil)
if ctx == nil {
return nil, newOpenSSLError("EVP_PKEY_CTX_new_id")
return newOpenSSLError("EVP_PKEY_CTX_new_id")
}
defer func() {
C.go_openssl_EVP_PKEY_CTX_free(ctx)
}()

if C.go_openssl_EVP_PKEY_derive_init(ctx) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
return newOpenSSLError("EVP_PKEY_derive_init")
}
switch vMajor {
case 3:
if C.go_openssl_EVP_PKEY_CTX_set_tls1_prf_md(ctx, md) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set_tls1_prf_md")
return newOpenSSLError("EVP_PKEY_CTX_set_tls1_prf_md")
}
if C.go_openssl_EVP_PKEY_CTX_set1_tls1_prf_secret(ctx,
base(secret), C.int(len(secret))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set1_tls1_prf_secret")
return newOpenSSLError("EVP_PKEY_CTX_set1_tls1_prf_secret")
}
if C.go_openssl_EVP_PKEY_CTX_add1_tls1_prf_seed(ctx,
base(label), C.int(len(label))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed")
return newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed")
}
if C.go_openssl_EVP_PKEY_CTX_add1_tls1_prf_seed(ctx,
base(seed), C.int(len(seed))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed")
return newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed")
}
case 1:
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1,
C.GO1_EVP_PKEY_OP_DERIVE,
C.GO_EVP_PKEY_CTRL_TLS_MD,
0, unsafe.Pointer(md)) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set_tls1_prf_md")
return newOpenSSLError("EVP_PKEY_CTX_set_tls1_prf_md")
}
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1,
C.GO1_EVP_PKEY_OP_DERIVE,
C.GO_EVP_PKEY_CTRL_TLS_SECRET,
C.int(len(secret)), unsafe.Pointer(base(secret))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set1_tls1_prf_secret")
return newOpenSSLError("EVP_PKEY_CTX_set1_tls1_prf_secret")
}
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1,
C.GO1_EVP_PKEY_OP_DERIVE,
C.GO_EVP_PKEY_CTRL_TLS_SEED,
C.int(len(label)), unsafe.Pointer(base(label))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed")
return newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed")
}
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1,
C.GO1_EVP_PKEY_OP_DERIVE,
C.GO_EVP_PKEY_CTRL_TLS_SEED,
C.int(len(seed)), unsafe.Pointer(base(seed))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed")
return newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed")
}
}
outLen := C.size_t(keyLen)
out := make([]byte, outLen)
if C.go_openssl_EVP_PKEY_derive(ctx, base(out), &outLen) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive")
outLen := C.size_t(len(result))
if C.go_openssl_EVP_PKEY_derive(ctx, base(result), &outLen) != 1 {
return newOpenSSLError("EVP_PKEY_derive")
}
if outLen != C.size_t(keyLen) {
return nil, errors.New("tls1-prf: entropy limit reached")
if outLen != C.size_t(len(result)) {
return errors.New("tls1-prf: entropy limit reached")
}
return out[:outLen], nil
return nil
}
7 changes: 4 additions & 3 deletions tls1prf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,13 @@ func TestTLS1PRF(t *testing.T) {
if !openssl.SupportsHash(tt.hash) {
t.Skip("skipping: hash not supported")
}
out, err := openssl.TLS1PRF(tt.secret, tt.label, tt.seed, len(tt.out), cryptoToHash(tt.hash))
result := make([]byte, len(tt.out))
err := openssl.TLS1PRF(result, tt.secret, tt.label, tt.seed, cryptoToHash(tt.hash))
if err != nil {
t.Fatalf("error deriving TLS 1.2 PRF: %v.", err)
}
if !bytes.Equal(out, tt.out) {
t.Errorf("incorrect key output: have %v, need %v.", out, tt.out)
if !bytes.Equal(result, tt.out) {
t.Errorf("incorrect key output: have %v, need %v.", result, tt.out)
}
})
}
Expand Down

0 comments on commit 3e5af4f

Please sign in to comment.