Skip to content

Commit

Permalink
Merge pull request #105 from golang-fips/md4md5
Browse files Browse the repository at this point in the history
Support for MD4 and MD5
  • Loading branch information
qmuntal authored Aug 31, 2023
2 parents c96fdff + 9e9cc11 commit 13f20f3
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 69 deletions.
32 changes: 19 additions & 13 deletions evp.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,21 @@ func cryptoHashToMD(ch crypto.Hash) (md C.GO_EVP_MD_PTR) {
}
cacheMD.Store(ch, md)
}()
// SupportsHash returns false for MD5 and MD5SHA1 because we don't
// provide a hash.Hash implementation for them. Yet, they can
// SupportsHash returns false for MD5SHA1 because we don't
// provide a hash.Hash implementation for it. Yet, it can
// still be used when signing/verifying with an RSA key.
switch ch {
case crypto.MD5:
return C.go_openssl_EVP_md5()
case crypto.MD5SHA1:
if ch == crypto.MD5SHA1 {
if vMajor == 1 && vMinor == 0 {
return C.go_openssl_EVP_md5_sha1_backport()
} else {
return C.go_openssl_EVP_md5_sha1()
}
}
if !SupportsHash(ch) {
return nil
}
switch ch {
case crypto.MD4:
return C.go_openssl_EVP_md4()
case crypto.MD5:
return C.go_openssl_EVP_md5()
case crypto.SHA1:
return C.go_openssl_EVP_sha1()
case crypto.SHA224:
Expand All @@ -88,13 +86,21 @@ func cryptoHashToMD(ch crypto.Hash) (md C.GO_EVP_MD_PTR) {
case crypto.SHA512:
return C.go_openssl_EVP_sha512()
case crypto.SHA3_224:
return C.go_openssl_EVP_sha3_224()
if version1_1_1_or_above() {
return C.go_openssl_EVP_sha3_224()
}
case crypto.SHA3_256:
return C.go_openssl_EVP_sha3_256()
if version1_1_1_or_above() {
return C.go_openssl_EVP_sha3_256()
}
case crypto.SHA3_384:
return C.go_openssl_EVP_sha3_384()
if version1_1_1_or_above() {
return C.go_openssl_EVP_sha3_384()
}
case crypto.SHA3_512:
return C.go_openssl_EVP_sha3_512()
if version1_1_1_or_above() {
return C.go_openssl_EVP_sha3_512()
}
}
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions goopenssl.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ FOR_ALL_OPENSSL_FUNCTIONS
#undef DEFINEFUNC_RENAMED_1_1
#undef DEFINEFUNC_RENAMED_3_0

// go_sha_sum copies ctx into ctx2 and calls EVP_DigestFinal using ctx2.
// go_hash_sum copies ctx into ctx2 and calls EVP_DigestFinal using ctx2.
// This is necessary because Go hash.Hash mandates that Sum has no effect
// on the underlying stream. In particular it is OK to Sum, then Write more,
// then Sum again, and the second Sum acts as if the first didn't happen.
// It is written in C because Sum() tend to be in the hot path,
// and doing one cgo call instead of two is a significant performance win.
static inline int
go_sha_sum(GO_EVP_MD_CTX_PTR ctx, GO_EVP_MD_CTX_PTR ctx2, unsigned char *out)
go_hash_sum(GO_EVP_MD_CTX_PTR ctx, GO_EVP_MD_CTX_PTR ctx2, unsigned char *out)
{
if (go_openssl_EVP_MD_CTX_copy(ctx2, ctx) != 1)
return 0;
Expand Down
163 changes: 130 additions & 33 deletions sha.go → hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,81 +24,87 @@ import (
// and applying a noescape along the way.
// This is all to preserve compatibility with the allocation behavior of the non-openssl implementations.

func shaX(ch crypto.Hash, p []byte, sum []byte) bool {
func hashOneShot(ch crypto.Hash, p []byte, sum []byte) bool {
return C.go_openssl_EVP_Digest(unsafe.Pointer(&*addr(p)), C.size_t(len(p)), (*C.uchar)(unsafe.Pointer(&*addr(sum))), nil, cryptoHashToMD(ch), nil) != 0
}

func MD4(p []byte) (sum [16]byte) {
if !hashOneShot(crypto.MD4, p, sum[:]) {
panic("openssl: MD4 failed")
}
return
}

func MD5(p []byte) (sum [16]byte) {
if !hashOneShot(crypto.MD5, p, sum[:]) {
panic("openssl: MD5 failed")
}
return
}

func SHA1(p []byte) (sum [20]byte) {
if !shaX(crypto.SHA1, p, sum[:]) {
if !hashOneShot(crypto.SHA1, p, sum[:]) {
panic("openssl: SHA1 failed")
}
return
}

func SHA224(p []byte) (sum [28]byte) {
if !shaX(crypto.SHA224, p, sum[:]) {
if !hashOneShot(crypto.SHA224, p, sum[:]) {
panic("openssl: SHA224 failed")
}
return
}

func SHA256(p []byte) (sum [32]byte) {
if !shaX(crypto.SHA256, p, sum[:]) {
if !hashOneShot(crypto.SHA256, p, sum[:]) {
panic("openssl: SHA256 failed")
}
return
}

func SHA384(p []byte) (sum [48]byte) {
if !shaX(crypto.SHA384, p, sum[:]) {
if !hashOneShot(crypto.SHA384, p, sum[:]) {
panic("openssl: SHA384 failed")
}
return
}

func SHA512(p []byte) (sum [64]byte) {
if !shaX(crypto.SHA512, p, sum[:]) {
if !hashOneShot(crypto.SHA512, p, sum[:]) {
panic("openssl: SHA512 failed")
}
return
}

// SupportsHash returns true if a hash.Hash implementation is supported for h.
func SupportsHash(h crypto.Hash) bool {
switch h {
case crypto.SHA1, crypto.SHA224, crypto.SHA256, crypto.SHA384, crypto.SHA512:
return true
case crypto.SHA3_224, crypto.SHA3_256, crypto.SHA3_384, crypto.SHA3_512:
return vMajor > 1 ||
(vMajor >= 1 && vMinor > 1) ||
(vMajor >= 1 && vMinor >= 1 && vPatch >= 1)
}
return false
return cryptoHashToMD(h) != nil
}

func SHA3_224(p []byte) (sum [28]byte) {
if !shaX(crypto.SHA3_224, p, sum[:]) {
if !hashOneShot(crypto.SHA3_224, p, sum[:]) {
panic("openssl: SHA3_224 failed")
}
return
}

func SHA3_256(p []byte) (sum [32]byte) {
if !shaX(crypto.SHA3_256, p, sum[:]) {
if !hashOneShot(crypto.SHA3_256, p, sum[:]) {
panic("openssl: SHA3_256 failed")
}
return
}

func SHA3_384(p []byte) (sum [48]byte) {
if !shaX(crypto.SHA3_384, p, sum[:]) {
if !hashOneShot(crypto.SHA3_384, p, sum[:]) {
panic("openssl: SHA3_384 failed")
}
return
}

func SHA3_512(p []byte) (sum [64]byte) {
if !shaX(crypto.SHA3_512, p, sum[:]) {
if !hashOneShot(crypto.SHA3_512, p, sum[:]) {
panic("openssl: SHA3_512 failed")
}
return
Expand Down Expand Up @@ -183,17 +189,17 @@ func (h *evpHash) BlockSize() int {
}

func (h *evpHash) sum(out []byte) {
if C.go_sha_sum(h.ctx, h.ctx2, base(out)) != 1 {
panic(newOpenSSLError("go_sha_sum"))
if C.go_hash_sum(h.ctx, h.ctx2, base(out)) != 1 {
panic(newOpenSSLError("go_hash_sum"))
}
runtime.KeepAlive(h)
}

// shaState returns a pointer to the internal sha structure.
// hashState returns a pointer to the internal hash structure.
//
// The EVP_MD_CTX memory layout has changed in OpenSSL 3
// and the property holding the internal structure is no longer md_data but algctx.
func (h *evpHash) shaState() unsafe.Pointer {
func (h *evpHash) hashState() unsafe.Pointer {
switch vMajor {
case 1:
// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/crypto/evp/evp_local.h#L12.
Expand All @@ -217,6 +223,97 @@ func (h *evpHash) shaState() unsafe.Pointer {
}
}

// NewMD4 returns a new MD4 hash.
// The returned hash doesn't implement encoding.BinaryMarshaler and
// encoding.BinaryUnmarshaler.
func NewMD4() hash.Hash {
return &md4Hash{
evpHash: newEvpHash(crypto.MD4, 16, 64),
}
}

type md4Hash struct {
*evpHash
out [16]byte
}

func (h *md4Hash) Sum(in []byte) []byte {
h.sum(h.out[:])
return append(in, h.out[:]...)
}

// NewMD5 returns a new MD5 hash.
func NewMD5() hash.Hash {
return &md5Hash{
evpHash: newEvpHash(crypto.MD5, 16, 64),
}
}

// md5State layout is taken from
// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/md5.h#L33.
type md5State struct {
h [4]uint32
nl, nh uint32
x [64]byte
nx uint32
}

type md5Hash struct {
*evpHash
out [16]byte
}

func (h *md5Hash) Sum(in []byte) []byte {
h.sum(h.out[:])
return append(in, h.out[:]...)
}

const (
md5Magic = "md5\x01"
md5MarshaledSize = len(md5Magic) + 4*4 + 64 + 8
)

func (h *md5Hash) MarshalBinary() ([]byte, error) {
d := (*md5State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/md5: can't retrieve hash state")
}
b := make([]byte, 0, md5MarshaledSize)
b = append(b, md5Magic...)
b = appendUint32(b, d.h[0])
b = appendUint32(b, d.h[1])
b = appendUint32(b, d.h[2])
b = appendUint32(b, d.h[3])
b = append(b, d.x[:d.nx]...)
b = b[:len(b)+len(d.x)-int(d.nx)] // already zero
b = appendUint64(b, uint64(d.nl)>>3|uint64(d.nh)<<29)
return b, nil
}

func (h *md5Hash) UnmarshalBinary(b []byte) error {
if len(b) < len(md5Magic) || string(b[:len(md5Magic)]) != md5Magic {
return errors.New("crypto/md5: invalid hash state identifier")
}
if len(b) != md5MarshaledSize {
return errors.New("crypto/md5: invalid hash state size")
}
d := (*md5State)(h.hashState())
if d == nil {
return errors.New("crypto/md5: can't retrieve hash state")
}
b = b[len(md5Magic):]
b, d.h[0] = consumeUint32(b)
b, d.h[1] = consumeUint32(b)
b, d.h[2] = consumeUint32(b)
b, d.h[3] = consumeUint32(b)
b = b[copy(d.x[:], b):]
_, n := consumeUint64(b)
d.nl = uint32(n << 3)
d.nh = uint32(n >> 29)
d.nx = uint32(n) % 64
return nil
}

// NewSHA1 returns a new SHA1 hash.
func NewSHA1() hash.Hash {
return &sha1Hash{
Expand Down Expand Up @@ -249,7 +346,7 @@ const (
)

func (h *sha1Hash) MarshalBinary() ([]byte, error) {
d := (*sha1State)(h.shaState())
d := (*sha1State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha1: can't retrieve hash state")
}
Expand All @@ -273,7 +370,7 @@ func (h *sha1Hash) UnmarshalBinary(b []byte) error {
if len(b) != sha1MarshaledSize {
return errors.New("crypto/sha1: invalid hash state size")
}
d := (*sha1State)(h.shaState())
d := (*sha1State)(h.hashState())
if d == nil {
return errors.New("crypto/sha1: can't retrieve hash state")
}
Expand Down Expand Up @@ -341,7 +438,7 @@ type sha256State struct {
}

func (h *sha224Hash) MarshalBinary() ([]byte, error) {
d := (*sha256State)(h.shaState())
d := (*sha256State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha256: can't retrieve hash state")
}
Expand All @@ -362,7 +459,7 @@ func (h *sha224Hash) MarshalBinary() ([]byte, error) {
}

func (h *sha256Hash) MarshalBinary() ([]byte, error) {
d := (*sha256State)(h.shaState())
d := (*sha256State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha256: can't retrieve hash state")
}
Expand All @@ -389,7 +486,7 @@ func (h *sha224Hash) UnmarshalBinary(b []byte) error {
if len(b) != marshaledSize256 {
return errors.New("crypto/sha256: invalid hash state size")
}
d := (*sha256State)(h.shaState())
d := (*sha256State)(h.hashState())
if d == nil {
return errors.New("crypto/sha256: can't retrieve hash state")
}
Expand Down Expand Up @@ -417,7 +514,7 @@ func (h *sha256Hash) UnmarshalBinary(b []byte) error {
if len(b) != marshaledSize256 {
return errors.New("crypto/sha256: invalid hash state size")
}
d := (*sha256State)(h.shaState())
d := (*sha256State)(h.hashState())
if d == nil {
return errors.New("crypto/sha256: can't retrieve hash state")
}
Expand Down Expand Up @@ -490,7 +587,7 @@ const (
)

func (h *sha384Hash) MarshalBinary() ([]byte, error) {
d := (*sha512State)(h.shaState())
d := (*sha512State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha512: can't retrieve hash state")
}
Expand All @@ -511,7 +608,7 @@ func (h *sha384Hash) MarshalBinary() ([]byte, error) {
}

func (h *sha512Hash) MarshalBinary() ([]byte, error) {
d := (*sha512State)(h.shaState())
d := (*sha512State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha512: can't retrieve hash state")
}
Expand Down Expand Up @@ -541,7 +638,7 @@ func (h *sha384Hash) UnmarshalBinary(b []byte) error {
if len(b) != marshaledSize512 {
return errors.New("crypto/sha512: invalid hash state size")
}
d := (*sha512State)(h.shaState())
d := (*sha512State)(h.hashState())
if d == nil {
return errors.New("crypto/sha512: can't retrieve hash state")
}
Expand Down Expand Up @@ -572,7 +669,7 @@ func (h *sha512Hash) UnmarshalBinary(b []byte) error {
if len(b) != marshaledSize512 {
return errors.New("crypto/sha512: invalid hash state size")
}
d := (*sha512State)(h.shaState())
d := (*sha512State)(h.hashState())
if d == nil {
return errors.New("crypto/sha512: can't retrieve hash state")
}
Expand Down
Loading

0 comments on commit 13f20f3

Please sign in to comment.