Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for MD4 and MD5 #105

Merged
merged 5 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions evp.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,21 @@ func cryptoHashToMD(ch crypto.Hash) (md C.GO_EVP_MD_PTR) {
}
cacheMD.Store(ch, md)
}()
// SupportsHash returns false for MD5 and MD5SHA1 because we don't
// provide a hash.Hash implementation for them. Yet, they can
// SupportsHash returns false for MD5SHA1 because we don't
// provide a hash.Hash implementation for it. Yet, it can
// still be used when signing/verifying with an RSA key.
switch ch {
case crypto.MD5:
return C.go_openssl_EVP_md5()
case crypto.MD5SHA1:
if ch == crypto.MD5SHA1 {
if vMajor == 1 && vMinor == 0 {
return C.go_openssl_EVP_md5_sha1_backport()
} else {
return C.go_openssl_EVP_md5_sha1()
}
}
if !SupportsHash(ch) {
return nil
}
switch ch {
case crypto.MD4:
return C.go_openssl_EVP_md4()
case crypto.MD5:
return C.go_openssl_EVP_md5()
case crypto.SHA1:
return C.go_openssl_EVP_sha1()
case crypto.SHA224:
Expand All @@ -88,13 +86,21 @@ func cryptoHashToMD(ch crypto.Hash) (md C.GO_EVP_MD_PTR) {
case crypto.SHA512:
return C.go_openssl_EVP_sha512()
case crypto.SHA3_224:
return C.go_openssl_EVP_sha3_224()
if version1_1_1_or_above() {
return C.go_openssl_EVP_sha3_224()
}
case crypto.SHA3_256:
return C.go_openssl_EVP_sha3_256()
if version1_1_1_or_above() {
return C.go_openssl_EVP_sha3_256()
}
case crypto.SHA3_384:
return C.go_openssl_EVP_sha3_384()
if version1_1_1_or_above() {
return C.go_openssl_EVP_sha3_384()
}
case crypto.SHA3_512:
return C.go_openssl_EVP_sha3_512()
if version1_1_1_or_above() {
return C.go_openssl_EVP_sha3_512()
}
}
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions goopenssl.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ FOR_ALL_OPENSSL_FUNCTIONS
#undef DEFINEFUNC_RENAMED_1_1
#undef DEFINEFUNC_RENAMED_3_0

// go_sha_sum copies ctx into ctx2 and calls EVP_DigestFinal using ctx2.
// go_hash_sum copies ctx into ctx2 and calls EVP_DigestFinal using ctx2.
// This is necessary because Go hash.Hash mandates that Sum has no effect
// on the underlying stream. In particular it is OK to Sum, then Write more,
// then Sum again, and the second Sum acts as if the first didn't happen.
// It is written in C because Sum() tend to be in the hot path,
// and doing one cgo call instead of two is a significant performance win.
static inline int
go_sha_sum(GO_EVP_MD_CTX_PTR ctx, GO_EVP_MD_CTX_PTR ctx2, unsigned char *out)
go_hash_sum(GO_EVP_MD_CTX_PTR ctx, GO_EVP_MD_CTX_PTR ctx2, unsigned char *out)
{
if (go_openssl_EVP_MD_CTX_copy(ctx2, ctx) != 1)
return 0;
Expand Down
163 changes: 130 additions & 33 deletions sha.go → hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,81 +24,87 @@ import (
// and applying a noescape along the way.
// This is all to preserve compatibility with the allocation behavior of the non-openssl implementations.

func shaX(ch crypto.Hash, p []byte, sum []byte) bool {
func hashOneShot(ch crypto.Hash, p []byte, sum []byte) bool {
return C.go_openssl_EVP_Digest(unsafe.Pointer(&*addr(p)), C.size_t(len(p)), (*C.uchar)(unsafe.Pointer(&*addr(sum))), nil, cryptoHashToMD(ch), nil) != 0
}

func MD4(p []byte) (sum [16]byte) {
if !hashOneShot(crypto.MD4, p, sum[:]) {
panic("openssl: MD4 failed")
}
return
}

func MD5(p []byte) (sum [16]byte) {
if !hashOneShot(crypto.MD5, p, sum[:]) {
panic("openssl: MD5 failed")
}
return
}

func SHA1(p []byte) (sum [20]byte) {
if !shaX(crypto.SHA1, p, sum[:]) {
if !hashOneShot(crypto.SHA1, p, sum[:]) {
panic("openssl: SHA1 failed")
}
return
}

func SHA224(p []byte) (sum [28]byte) {
if !shaX(crypto.SHA224, p, sum[:]) {
if !hashOneShot(crypto.SHA224, p, sum[:]) {
panic("openssl: SHA224 failed")
}
return
}

func SHA256(p []byte) (sum [32]byte) {
if !shaX(crypto.SHA256, p, sum[:]) {
if !hashOneShot(crypto.SHA256, p, sum[:]) {
panic("openssl: SHA256 failed")
}
return
}

func SHA384(p []byte) (sum [48]byte) {
if !shaX(crypto.SHA384, p, sum[:]) {
if !hashOneShot(crypto.SHA384, p, sum[:]) {
panic("openssl: SHA384 failed")
}
return
}

func SHA512(p []byte) (sum [64]byte) {
if !shaX(crypto.SHA512, p, sum[:]) {
if !hashOneShot(crypto.SHA512, p, sum[:]) {
panic("openssl: SHA512 failed")
}
return
}

// SupportsHash returns true if a hash.Hash implementation is supported for h.
func SupportsHash(h crypto.Hash) bool {
switch h {
case crypto.SHA1, crypto.SHA224, crypto.SHA256, crypto.SHA384, crypto.SHA512:
return true
case crypto.SHA3_224, crypto.SHA3_256, crypto.SHA3_384, crypto.SHA3_512:
return vMajor > 1 ||
(vMajor >= 1 && vMinor > 1) ||
(vMajor >= 1 && vMinor >= 1 && vPatch >= 1)
}
return false
return cryptoHashToMD(h) != nil
}

func SHA3_224(p []byte) (sum [28]byte) {
if !shaX(crypto.SHA3_224, p, sum[:]) {
if !hashOneShot(crypto.SHA3_224, p, sum[:]) {
panic("openssl: SHA3_224 failed")
}
return
}

func SHA3_256(p []byte) (sum [32]byte) {
if !shaX(crypto.SHA3_256, p, sum[:]) {
if !hashOneShot(crypto.SHA3_256, p, sum[:]) {
panic("openssl: SHA3_256 failed")
}
return
}

func SHA3_384(p []byte) (sum [48]byte) {
if !shaX(crypto.SHA3_384, p, sum[:]) {
if !hashOneShot(crypto.SHA3_384, p, sum[:]) {
panic("openssl: SHA3_384 failed")
}
return
}

func SHA3_512(p []byte) (sum [64]byte) {
if !shaX(crypto.SHA3_512, p, sum[:]) {
if !hashOneShot(crypto.SHA3_512, p, sum[:]) {
panic("openssl: SHA3_512 failed")
}
return
Expand Down Expand Up @@ -183,17 +189,17 @@ func (h *evpHash) BlockSize() int {
}

func (h *evpHash) sum(out []byte) {
if C.go_sha_sum(h.ctx, h.ctx2, base(out)) != 1 {
panic(newOpenSSLError("go_sha_sum"))
if C.go_hash_sum(h.ctx, h.ctx2, base(out)) != 1 {
panic(newOpenSSLError("go_hash_sum"))
}
runtime.KeepAlive(h)
}

// shaState returns a pointer to the internal sha structure.
// hashState returns a pointer to the internal hash structure.
//
// The EVP_MD_CTX memory layout has changed in OpenSSL 3
// and the property holding the internal structure is no longer md_data but algctx.
func (h *evpHash) shaState() unsafe.Pointer {
func (h *evpHash) hashState() unsafe.Pointer {
switch vMajor {
case 1:
// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/crypto/evp/evp_local.h#L12.
Expand All @@ -217,6 +223,97 @@ func (h *evpHash) shaState() unsafe.Pointer {
}
}

// NewMD4 returns a new MD4 hash.
// The returned hash doesn't implement encoding.BinaryMarshaler and
// encoding.BinaryUnmarshaler.
func NewMD4() hash.Hash {
return &md4Hash{
evpHash: newEvpHash(crypto.MD4, 16, 64),
}
}

type md4Hash struct {
*evpHash
out [16]byte
}

func (h *md4Hash) Sum(in []byte) []byte {
h.sum(h.out[:])
return append(in, h.out[:]...)
}

// NewMD5 returns a new MD5 hash.
func NewMD5() hash.Hash {
return &md5Hash{
evpHash: newEvpHash(crypto.MD5, 16, 64),
}
}

// md5State layout is taken from
// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/md5.h#L33.
type md5State struct {
h [4]uint32
nl, nh uint32
x [64]byte
nx uint32
}

type md5Hash struct {
*evpHash
out [16]byte
}

func (h *md5Hash) Sum(in []byte) []byte {
h.sum(h.out[:])
return append(in, h.out[:]...)
}

const (
md5Magic = "md5\x01"
md5MarshaledSize = len(md5Magic) + 4*4 + 64 + 8
)

func (h *md5Hash) MarshalBinary() ([]byte, error) {
d := (*md5State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/md5: can't retrieve hash state")
}
b := make([]byte, 0, md5MarshaledSize)
b = append(b, md5Magic...)
b = appendUint32(b, d.h[0])
b = appendUint32(b, d.h[1])
b = appendUint32(b, d.h[2])
b = appendUint32(b, d.h[3])
b = append(b, d.x[:d.nx]...)
b = b[:len(b)+len(d.x)-int(d.nx)] // already zero
b = appendUint64(b, uint64(d.nl)>>3|uint64(d.nh)<<29)
return b, nil
}

func (h *md5Hash) UnmarshalBinary(b []byte) error {
if len(b) < len(md5Magic) || string(b[:len(md5Magic)]) != md5Magic {
return errors.New("crypto/md5: invalid hash state identifier")
}
if len(b) != md5MarshaledSize {
return errors.New("crypto/md5: invalid hash state size")
}
d := (*md5State)(h.hashState())
if d == nil {
return errors.New("crypto/md5: can't retrieve hash state")
}
b = b[len(md5Magic):]
b, d.h[0] = consumeUint32(b)
b, d.h[1] = consumeUint32(b)
b, d.h[2] = consumeUint32(b)
b, d.h[3] = consumeUint32(b)
b = b[copy(d.x[:], b):]
_, n := consumeUint64(b)
d.nl = uint32(n << 3)
d.nh = uint32(n >> 29)
d.nx = uint32(n) % 64
return nil
}

// NewSHA1 returns a new SHA1 hash.
func NewSHA1() hash.Hash {
return &sha1Hash{
Expand Down Expand Up @@ -249,7 +346,7 @@ const (
)

func (h *sha1Hash) MarshalBinary() ([]byte, error) {
d := (*sha1State)(h.shaState())
d := (*sha1State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha1: can't retrieve hash state")
}
Expand All @@ -273,7 +370,7 @@ func (h *sha1Hash) UnmarshalBinary(b []byte) error {
if len(b) != sha1MarshaledSize {
return errors.New("crypto/sha1: invalid hash state size")
}
d := (*sha1State)(h.shaState())
d := (*sha1State)(h.hashState())
if d == nil {
return errors.New("crypto/sha1: can't retrieve hash state")
}
Expand Down Expand Up @@ -341,7 +438,7 @@ type sha256State struct {
}

func (h *sha224Hash) MarshalBinary() ([]byte, error) {
d := (*sha256State)(h.shaState())
d := (*sha256State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha256: can't retrieve hash state")
}
Expand All @@ -362,7 +459,7 @@ func (h *sha224Hash) MarshalBinary() ([]byte, error) {
}

func (h *sha256Hash) MarshalBinary() ([]byte, error) {
d := (*sha256State)(h.shaState())
d := (*sha256State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha256: can't retrieve hash state")
}
Expand All @@ -389,7 +486,7 @@ func (h *sha224Hash) UnmarshalBinary(b []byte) error {
if len(b) != marshaledSize256 {
return errors.New("crypto/sha256: invalid hash state size")
}
d := (*sha256State)(h.shaState())
d := (*sha256State)(h.hashState())
if d == nil {
return errors.New("crypto/sha256: can't retrieve hash state")
}
Expand Down Expand Up @@ -417,7 +514,7 @@ func (h *sha256Hash) UnmarshalBinary(b []byte) error {
if len(b) != marshaledSize256 {
return errors.New("crypto/sha256: invalid hash state size")
}
d := (*sha256State)(h.shaState())
d := (*sha256State)(h.hashState())
if d == nil {
return errors.New("crypto/sha256: can't retrieve hash state")
}
Expand Down Expand Up @@ -490,7 +587,7 @@ const (
)

func (h *sha384Hash) MarshalBinary() ([]byte, error) {
d := (*sha512State)(h.shaState())
d := (*sha512State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha512: can't retrieve hash state")
}
Expand All @@ -511,7 +608,7 @@ func (h *sha384Hash) MarshalBinary() ([]byte, error) {
}

func (h *sha512Hash) MarshalBinary() ([]byte, error) {
d := (*sha512State)(h.shaState())
d := (*sha512State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha512: can't retrieve hash state")
}
Expand Down Expand Up @@ -541,7 +638,7 @@ func (h *sha384Hash) UnmarshalBinary(b []byte) error {
if len(b) != marshaledSize512 {
return errors.New("crypto/sha512: invalid hash state size")
}
d := (*sha512State)(h.shaState())
d := (*sha512State)(h.hashState())
if d == nil {
return errors.New("crypto/sha512: can't retrieve hash state")
}
Expand Down Expand Up @@ -572,7 +669,7 @@ func (h *sha512Hash) UnmarshalBinary(b []byte) error {
if len(b) != marshaledSize512 {
return errors.New("crypto/sha512: invalid hash state size")
}
d := (*sha512State)(h.shaState())
d := (*sha512State)(h.hashState())
if d == nil {
return errors.New("crypto/sha512: can't retrieve hash state")
}
Expand Down
Loading