From 7ec659eb8f6eb81a14259250c9c5e97dbd483ffb Mon Sep 17 00:00:00 2001 From: xuyang2 Date: Tue, 2 Jul 2024 15:30:02 +0800 Subject: [PATCH] add SCryptPasswordEncoder --- password/scrypt.go | 99 +++++++++++++++++++++++++++++++++++++++++ password/scrypt_test.go | 75 +++++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+) create mode 100644 password/scrypt.go create mode 100644 password/scrypt_test.go diff --git a/password/scrypt.go b/password/scrypt.go new file mode 100644 index 0000000..e2e8933 --- /dev/null +++ b/password/scrypt.go @@ -0,0 +1,99 @@ +package password + +import ( + "bytes" + "encoding/base64" + "math" + "strconv" + "strings" + + "golang.org/x/crypto/scrypt" + + "github.com/xuyang2/password-encoder/keygen" +) + +type SCryptPasswordEncoder struct { + saltGen keygen.BytesKeyGenerator + + cpuCost int // cpu cost of the algorithm (as defined in scrypt this is N) + memoryCost int // memory cost of the algorithm (as defined in scrypt this is r) + parallelization int // the parallelization of the algorithm (as defined in scrypt this is p) + keyLen int +} + +var _ PasswordEncoder = (*SCryptPasswordEncoder)(nil) + +func DefaultSCryptPasswordEncoder() *SCryptPasswordEncoder { + saltLen := 16 + saltGen := keygen.NewSecureRandomBytesKeyGenerator(saltLen) + return &SCryptPasswordEncoder{ + saltGen: saltGen, + cpuCost: 65536, + memoryCost: 8, + parallelization: 1, + keyLen: 32, + } +} + +func (e *SCryptPasswordEncoder) Encode(rawPassword string) (string, error) { + salt, err := e.saltGen.GenerateKey() + if err != nil { + return "", err + } + + derived, err := scrypt.Key([]byte(rawPassword), salt, e.cpuCost, e.memoryCost, e.parallelization, e.keyLen) + return e.encode(derived, salt), nil +} + +func (e *SCryptPasswordEncoder) encode(derived, salt []byte) string { + params := ((int)(math.Log2(float64(e.cpuCost))) << 16) | e.memoryCost<<8 | e.parallelization + var sb strings.Builder + sb.WriteString("$") + sb.WriteString(strconv.FormatInt(int64(params), 16)) + sb.WriteString("$") + sb.WriteString(e.encodePart(salt)) + sb.WriteString("$") + sb.WriteString(e.encodePart(derived)) + return sb.String() +} + +func (e *SCryptPasswordEncoder) encodePart(part []byte) string { + encoded := base64.StdEncoding.EncodeToString(part) + // encoded = strings.Replace(encoded, "\n", "", -1) + return encoded +} + +func (e *SCryptPasswordEncoder) Matches(rawPassword string, encodedPassword string) bool { + parts := strings.Split(encodedPassword, "$") + if len(parts) != 4 { // ["", params, salt, derived] + return false + } + + params, err := strconv.ParseInt(parts[1], 16, 64) + if err != nil { + return false + } + + salt, err := base64.StdEncoding.DecodeString(parts[2]) + if err != nil { + return false + } + + derived, err := base64.StdEncoding.DecodeString(parts[3]) + if err != nil { + return false + } + + cpuCost := int(math.Pow(2, float64(params>>16&0xffff))) + memoryCost := int(params) >> 8 & 0xff + parallelization := int(params) & 0xff + + generated, err := scrypt.Key([]byte(rawPassword), salt, cpuCost, memoryCost, parallelization, e.keyLen) + + return bytes.Equal(derived, generated) +} + +func (e *SCryptPasswordEncoder) UpgradeEncoding(encodedPassword string) bool { + // TODO: compare cost + return false +} diff --git a/password/scrypt_test.go b/password/scrypt_test.go new file mode 100644 index 0000000..c8dc70e --- /dev/null +++ b/password/scrypt_test.go @@ -0,0 +1,75 @@ +package password + +import ( + "errors" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/xuyang2/password-encoder/keygen/keygentest" +) + +func TestSCryptPasswordEncoder_Matches(t *testing.T) { + t.Run("ok", func(t *testing.T) { + encoder := DefaultSCryptPasswordEncoder() + + rawPassword := "myPassword" + encodedPassword, err := encoder.Encode(rawPassword) + + assert.NoError(t, err) + assert.NotEqual(t, rawPassword, encodedPassword) + + assert.True(t, encoder.Matches(rawPassword, encodedPassword)) + assert.False(t, encoder.Matches(rawPassword+"a", encodedPassword)) + + parts := strings.Split(encodedPassword, "$") + assert.Len(t, parts, 4) + assert.True(t, encoder.Matches(rawPassword, strings.Join([]string{"", parts[1], parts[2], parts[3]}, "$"))) + assert.False(t, encoder.Matches(rawPassword, strings.Join([]string{"", "_", parts[2], parts[3]}, "$"))) + assert.False(t, encoder.Matches(rawPassword, strings.Join([]string{"", parts[1], "_", parts[3]}, "$"))) + assert.False(t, encoder.Matches(rawPassword, strings.Join([]string{"", parts[1], parts[2], "_"}, "$"))) + + assert.False(t, encoder.Matches(rawPassword, "")) + }) + + t.Run("spring-security encoded", func(t *testing.T) { + encoder := DefaultSCryptPasswordEncoder() + + // SCryptPasswordEncoder encoder = SCryptPasswordEncoder.defaultsForSpringSecurity_v5_8(); + // String encodedPassword = encoder.encode("myPassword"); + rawPassword := "myPassword" + encodedPassword := "$100801$4P6llsBJYk/EbyFZaq6yyw==$+G59NWVc3S/n67Eo5+bxjY7RP9NsDAclJzorgIet0Rs=" + + assert.True(t, encoder.Matches(rawPassword, encodedPassword)) + assert.False(t, encoder.Matches(rawPassword+"a", encodedPassword)) + }) +} + +func TestSCryptPasswordEncoder_Encode(t *testing.T) { + t.Run("no err", func(t *testing.T) { + encoder := DefaultSCryptPasswordEncoder() + encoded, err := encoder.Encode("?") + assert.NoError(t, err) + assert.True(t, encoded != "") + }) + + t.Run("err saltGen", func(t *testing.T) { + encoder := DefaultSCryptPasswordEncoder() + encoder.saltGen = keygentest.ErrBytesKeyGenerator(errors.New("oops"), 8) + _, err := encoder.Encode("?") + assert.Error(t, err) + }) +} + +func TestSCryptPasswordEncoder_UpgradeEncoding(t *testing.T) { + t.Run("always false", func(t *testing.T) { + encoder := DefaultSCryptPasswordEncoder() + + encodedPassword, err := encoder.Encode("password") + require.NoError(t, err) + + assert.Equal(t, false, encoder.UpgradeEncoding(encodedPassword)) + }) +}