Skip to content

Commit

Permalink
add SCryptPasswordEncoder
Browse files Browse the repository at this point in the history
  • Loading branch information
xuyang2 committed Jul 2, 2024
1 parent 94082c1 commit 7ec659e
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 0 deletions.
99 changes: 99 additions & 0 deletions password/scrypt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package password

import (
"bytes"
"encoding/base64"
"math"
"strconv"
"strings"

"golang.org/x/crypto/scrypt"

"github.com/xuyang2/password-encoder/keygen"
)

type SCryptPasswordEncoder struct {
saltGen keygen.BytesKeyGenerator

cpuCost int // cpu cost of the algorithm (as defined in scrypt this is N)
memoryCost int // memory cost of the algorithm (as defined in scrypt this is r)
parallelization int // the parallelization of the algorithm (as defined in scrypt this is p)
keyLen int
}

var _ PasswordEncoder = (*SCryptPasswordEncoder)(nil)

func DefaultSCryptPasswordEncoder() *SCryptPasswordEncoder {
saltLen := 16
saltGen := keygen.NewSecureRandomBytesKeyGenerator(saltLen)
return &SCryptPasswordEncoder{
saltGen: saltGen,
cpuCost: 65536,
memoryCost: 8,
parallelization: 1,
keyLen: 32,
}
}

func (e *SCryptPasswordEncoder) Encode(rawPassword string) (string, error) {
salt, err := e.saltGen.GenerateKey()
if err != nil {
return "", err
}

derived, err := scrypt.Key([]byte(rawPassword), salt, e.cpuCost, e.memoryCost, e.parallelization, e.keyLen)
return e.encode(derived, salt), nil
}

func (e *SCryptPasswordEncoder) encode(derived, salt []byte) string {
params := ((int)(math.Log2(float64(e.cpuCost))) << 16) | e.memoryCost<<8 | e.parallelization
var sb strings.Builder
sb.WriteString("$")
sb.WriteString(strconv.FormatInt(int64(params), 16))
sb.WriteString("$")
sb.WriteString(e.encodePart(salt))
sb.WriteString("$")
sb.WriteString(e.encodePart(derived))
return sb.String()
}

func (e *SCryptPasswordEncoder) encodePart(part []byte) string {
encoded := base64.StdEncoding.EncodeToString(part)
// encoded = strings.Replace(encoded, "\n", "", -1)
return encoded
}

func (e *SCryptPasswordEncoder) Matches(rawPassword string, encodedPassword string) bool {
parts := strings.Split(encodedPassword, "$")
if len(parts) != 4 { // ["", params, salt, derived]
return false
}

params, err := strconv.ParseInt(parts[1], 16, 64)
if err != nil {
return false
}

salt, err := base64.StdEncoding.DecodeString(parts[2])
if err != nil {
return false
}

derived, err := base64.StdEncoding.DecodeString(parts[3])
if err != nil {
return false
}

cpuCost := int(math.Pow(2, float64(params>>16&0xffff)))
memoryCost := int(params) >> 8 & 0xff
parallelization := int(params) & 0xff

generated, err := scrypt.Key([]byte(rawPassword), salt, cpuCost, memoryCost, parallelization, e.keyLen)

return bytes.Equal(derived, generated)
}

func (e *SCryptPasswordEncoder) UpgradeEncoding(encodedPassword string) bool {
// TODO: compare cost
return false
}
75 changes: 75 additions & 0 deletions password/scrypt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package password

import (
"errors"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/xuyang2/password-encoder/keygen/keygentest"
)

func TestSCryptPasswordEncoder_Matches(t *testing.T) {
t.Run("ok", func(t *testing.T) {
encoder := DefaultSCryptPasswordEncoder()

rawPassword := "myPassword"
encodedPassword, err := encoder.Encode(rawPassword)

assert.NoError(t, err)
assert.NotEqual(t, rawPassword, encodedPassword)

assert.True(t, encoder.Matches(rawPassword, encodedPassword))
assert.False(t, encoder.Matches(rawPassword+"a", encodedPassword))

parts := strings.Split(encodedPassword, "$")
assert.Len(t, parts, 4)
assert.True(t, encoder.Matches(rawPassword, strings.Join([]string{"", parts[1], parts[2], parts[3]}, "$")))
assert.False(t, encoder.Matches(rawPassword, strings.Join([]string{"", "_", parts[2], parts[3]}, "$")))
assert.False(t, encoder.Matches(rawPassword, strings.Join([]string{"", parts[1], "_", parts[3]}, "$")))
assert.False(t, encoder.Matches(rawPassword, strings.Join([]string{"", parts[1], parts[2], "_"}, "$")))

assert.False(t, encoder.Matches(rawPassword, ""))
})

t.Run("spring-security encoded", func(t *testing.T) {
encoder := DefaultSCryptPasswordEncoder()

// SCryptPasswordEncoder encoder = SCryptPasswordEncoder.defaultsForSpringSecurity_v5_8();
// String encodedPassword = encoder.encode("myPassword");
rawPassword := "myPassword"
encodedPassword := "$100801$4P6llsBJYk/EbyFZaq6yyw==$+G59NWVc3S/n67Eo5+bxjY7RP9NsDAclJzorgIet0Rs="

assert.True(t, encoder.Matches(rawPassword, encodedPassword))
assert.False(t, encoder.Matches(rawPassword+"a", encodedPassword))
})
}

func TestSCryptPasswordEncoder_Encode(t *testing.T) {
t.Run("no err", func(t *testing.T) {
encoder := DefaultSCryptPasswordEncoder()
encoded, err := encoder.Encode("?")
assert.NoError(t, err)
assert.True(t, encoded != "")
})

t.Run("err saltGen", func(t *testing.T) {
encoder := DefaultSCryptPasswordEncoder()
encoder.saltGen = keygentest.ErrBytesKeyGenerator(errors.New("oops"), 8)
_, err := encoder.Encode("?")
assert.Error(t, err)
})
}

func TestSCryptPasswordEncoder_UpgradeEncoding(t *testing.T) {
t.Run("always false", func(t *testing.T) {
encoder := DefaultSCryptPasswordEncoder()

encodedPassword, err := encoder.Encode("password")
require.NoError(t, err)

assert.Equal(t, false, encoder.UpgradeEncoding(encodedPassword))
})
}

0 comments on commit 7ec659e

Please sign in to comment.