Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cleanup #29

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions siwe.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"
"time"

"github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto"
Expand Down Expand Up @@ -79,13 +80,13 @@ func InitMessage(domain, address, uri, nonce string, options map[string]interfac

var chainId int
if val, ok := options["chainId"]; ok {
switch val.(type) {
switch tv := val.(type) {
case float64:
chainId = int(val.(float64))
chainId = int(tv)
case int:
chainId = val.(int)
chainId = tv
case string:
parsed, err := strconv.Atoi(val.(string))
parsed, err := strconv.Atoi(tv)
if err != nil {
return nil, &InvalidMessage{"Invalid format for field `chainId`, must be an integer"}
}
Expand Down Expand Up @@ -136,10 +137,8 @@ func InitMessage(domain, address, uri, nonce string, options map[string]interfac

var resources []url.URL
if val, ok := options["resources"]; ok {
switch val.(type) {
case []url.URL:
resources = val.([]url.URL)
default:
resources, ok = val.([]url.URL)
if !ok {
return nil, &InvalidMessage{"`resources` must be a []url.URL"}
}
}
Expand All @@ -165,7 +164,6 @@ func InitMessage(domain, address, uri, nonce string, options map[string]interfac

func parseMessage(message string) (map[string]interface{}, error) {
match := _SIWE_MESSAGE.FindStringSubmatch(message)

if match == nil {
return nil, &InvalidMessage{"Message could not be parsed"}
}
Expand All @@ -177,18 +175,18 @@ func parseMessage(message string) (map[string]interface{}, error) {
}
}

if _, ok := result["domain"]; !ok {
domain, ok := result["domain"].(string)
if !ok {
return nil, &InvalidMessage{"`domain` must not be empty"}
}
domain := result["domain"].(string)
if ok, err := validateDomain(&domain); !ok {
return nil, err
}

if _, ok := result["uri"]; !ok {
uri, ok := result["uri"].(string)
if !ok {
return nil, &InvalidMessage{"`domain` must not be empty"}
}
uri := result["uri"].(string)
if _, err := validateURI(&uri); err != nil {
return nil, err
}
Expand All @@ -200,12 +198,20 @@ func parseMessage(message string) (map[string]interface{}, error) {
}

if val, ok := result["resources"]; ok {
resources := strings.Split(val.(string), "\n- ")[1:]
resourcesStr, ok := val.(string)
if !ok {
return nil, &InvalidMessage{fmt.Sprintf("resources is not a string but %T", val)}
}
resources := strings.Split(resourcesStr, "\n- ")
if len(resources) < 1 {
return nil, &InvalidMessage{"expected at least one resource"}
}
resources = resources[1:]
validateResources := make([]url.URL, len(resources))
for i, resource := range resources {
validateResource, err := url.Parse(resource)
if err != nil {
return nil, &InvalidMessage{fmt.Sprintf("Invalid format for field `resources` at position %d", i)}
return nil, &InvalidMessage{fmt.Sprintf("Invalid format for field `resources` at position %d: %s", i, err)}
}
validateResources[i] = *validateResource
}
Expand Down Expand Up @@ -237,11 +243,9 @@ func ParseMessage(message string) (*Message, error) {
return parsed, nil
}

func (m *Message) eip191Hash() common.Hash {
// Ref: https://stackoverflow.com/questions/49085737/geth-ecrecover-invalid-signature-recovery-id
func (m *Message) eip191Hash() []byte {
data := []byte(m.String())
msg := fmt.Sprintf("\x19Ethereum Signed Message:\n%d%s", len(data), data)
return crypto.Keccak256Hash([]byte(msg))
return accounts.TextHash(data)
}

// ValidNow validates the time constraints of the message at current time.
Expand Down Expand Up @@ -283,7 +287,7 @@ func (m *Message) VerifyEIP191(signature string) (*ecdsa.PublicKey, error) {
return nil, &InvalidSignature{"Invalid signature recovery byte"}
}

pkey, err := crypto.SigToPub(m.eip191Hash().Bytes(), sigBytes)
pkey, err := crypto.SigToPub(m.eip191Hash(), sigBytes)
if err != nil {
return nil, &InvalidSignature{"Failed to recover public key from signature"}
}
Expand Down
13 changes: 7 additions & 6 deletions siwe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package siwe
import (
"crypto/ecdsa"
"encoding/json"
"io/ioutil"
"io"
"net/url"
"os"
"strconv"
Expand Down Expand Up @@ -239,7 +239,7 @@ func TestValidate(t *testing.T) {
assert.Nil(t, err)

hash := message.eip191Hash()
signature, err := crypto.Sign(hash.Bytes(), privateKey)
signature, err := crypto.Sign(hash, privateKey)
signature[64] += 27

assert.Nil(t, err)
Expand All @@ -257,7 +257,7 @@ func TestValidateTampered(t *testing.T) {
assert.Nil(t, err)

hash := message.eip191Hash()
signature, err := crypto.Sign(hash.Bytes(), privateKey)
signature, err := crypto.Sign(hash, privateKey)
signature[64] += 27

assert.Nil(t, err)
Expand Down Expand Up @@ -408,11 +408,12 @@ func TestGlobalTestVector(t *testing.T) {
}

for test, file := range files {
data, _ := ioutil.ReadAll(file)
data, err := io.ReadAll(file)
assert.NoError(t, err, test)

var result map[string]interface{}
err := json.Unmarshal([]byte(data), &result)
assert.Nil(t, err)
err = json.Unmarshal([]byte(data), &result)
assert.NoError(t, err)

switch test {
case "parsing-negative":
Expand Down