diff --git a/siwe.go b/siwe.go index e9b9e5b..7be0341 100644 --- a/siwe.go +++ b/siwe.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/crypto" @@ -79,13 +80,13 @@ func InitMessage(domain, address, uri, nonce string, options map[string]interfac var chainId int if val, ok := options["chainId"]; ok { - switch val.(type) { + switch tv := val.(type) { case float64: - chainId = int(val.(float64)) + chainId = int(tv) case int: - chainId = val.(int) + chainId = tv case string: - parsed, err := strconv.Atoi(val.(string)) + parsed, err := strconv.Atoi(tv) if err != nil { return nil, &InvalidMessage{"Invalid format for field `chainId`, must be an integer"} } @@ -136,10 +137,8 @@ func InitMessage(domain, address, uri, nonce string, options map[string]interfac var resources []url.URL if val, ok := options["resources"]; ok { - switch val.(type) { - case []url.URL: - resources = val.([]url.URL) - default: + resources, ok = val.([]url.URL) + if !ok { return nil, &InvalidMessage{"`resources` must be a []url.URL"} } } @@ -165,7 +164,6 @@ func InitMessage(domain, address, uri, nonce string, options map[string]interfac func parseMessage(message string) (map[string]interface{}, error) { match := _SIWE_MESSAGE.FindStringSubmatch(message) - if match == nil { return nil, &InvalidMessage{"Message could not be parsed"} } @@ -177,18 +175,18 @@ func parseMessage(message string) (map[string]interface{}, error) { } } - if _, ok := result["domain"]; !ok { + domain, ok := result["domain"].(string) + if !ok { return nil, &InvalidMessage{"`domain` must not be empty"} } - domain := result["domain"].(string) if ok, err := validateDomain(&domain); !ok { return nil, err } - if _, ok := result["uri"]; !ok { + uri, ok := result["uri"].(string) + if !ok { return nil, &InvalidMessage{"`domain` must not be empty"} } - uri := result["uri"].(string) if _, err := validateURI(&uri); err != nil { return nil, err } @@ -200,12 +198,20 @@ func parseMessage(message string) (map[string]interface{}, error) { } if val, ok := result["resources"]; ok { - resources := strings.Split(val.(string), "\n- ")[1:] + resourcesStr, ok := val.(string) + if !ok { + return nil, &InvalidMessage{fmt.Sprintf("resources is not a string but %T", val)} + } + resources := strings.Split(resourcesStr, "\n- ") + if len(resources) < 1 { + return nil, &InvalidMessage{"expected at least one resource"} + } + resources = resources[1:] validateResources := make([]url.URL, len(resources)) for i, resource := range resources { validateResource, err := url.Parse(resource) if err != nil { - return nil, &InvalidMessage{fmt.Sprintf("Invalid format for field `resources` at position %d", i)} + return nil, &InvalidMessage{fmt.Sprintf("Invalid format for field `resources` at position %d: %s", i, err)} } validateResources[i] = *validateResource } @@ -237,11 +243,9 @@ func ParseMessage(message string) (*Message, error) { return parsed, nil } -func (m *Message) eip191Hash() common.Hash { - // Ref: https://stackoverflow.com/questions/49085737/geth-ecrecover-invalid-signature-recovery-id +func (m *Message) eip191Hash() []byte { data := []byte(m.String()) - msg := fmt.Sprintf("\x19Ethereum Signed Message:\n%d%s", len(data), data) - return crypto.Keccak256Hash([]byte(msg)) + return accounts.TextHash(data) } // ValidNow validates the time constraints of the message at current time. @@ -283,7 +287,7 @@ func (m *Message) VerifyEIP191(signature string) (*ecdsa.PublicKey, error) { return nil, &InvalidSignature{"Invalid signature recovery byte"} } - pkey, err := crypto.SigToPub(m.eip191Hash().Bytes(), sigBytes) + pkey, err := crypto.SigToPub(m.eip191Hash(), sigBytes) if err != nil { return nil, &InvalidSignature{"Failed to recover public key from signature"} } diff --git a/siwe_test.go b/siwe_test.go index 5c4bb5e..a47c349 100644 --- a/siwe_test.go +++ b/siwe_test.go @@ -3,7 +3,7 @@ package siwe import ( "crypto/ecdsa" "encoding/json" - "io/ioutil" + "io" "net/url" "os" "strconv" @@ -239,7 +239,7 @@ func TestValidate(t *testing.T) { assert.Nil(t, err) hash := message.eip191Hash() - signature, err := crypto.Sign(hash.Bytes(), privateKey) + signature, err := crypto.Sign(hash, privateKey) signature[64] += 27 assert.Nil(t, err) @@ -257,7 +257,7 @@ func TestValidateTampered(t *testing.T) { assert.Nil(t, err) hash := message.eip191Hash() - signature, err := crypto.Sign(hash.Bytes(), privateKey) + signature, err := crypto.Sign(hash, privateKey) signature[64] += 27 assert.Nil(t, err) @@ -408,11 +408,12 @@ func TestGlobalTestVector(t *testing.T) { } for test, file := range files { - data, _ := ioutil.ReadAll(file) + data, err := io.ReadAll(file) + assert.NoError(t, err, test) var result map[string]interface{} - err := json.Unmarshal([]byte(data), &result) - assert.Nil(t, err) + err = json.Unmarshal([]byte(data), &result) + assert.NoError(t, err) switch test { case "parsing-negative":