diff --git a/crypto/crypto.go b/crypto/crypto.go index 7e38866..7661573 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -1,10 +1,18 @@ package crypto import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" "encoding/base64" + "errors" + "io" + "strings" "github.com/gogo/protobuf/proto" ic "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-core/peer" ) func Sign(key ic.PrivKey, channelMessage proto.Message) ([]byte, error) { @@ -34,6 +42,117 @@ func ToPrivKey(privKey string) (ic.PrivKey, error) { } // public key string to ic.PubKey interface -func ToPubKey(pubKey []byte) (ic.PubKey, error) { +func ToPubKey(pubKey string) (ic.PubKey, error) { + raw, err := base64.StdEncoding.DecodeString(pubKey) + if err != nil { + return nil, err + } + return ic.UnmarshalSecp256k1PublicKey(raw) +} + +// Secp256k1 private key string to ic.PrivKey interface +func ToPrivKeyRaw(privKey []byte) (ic.PrivKey, error) { + return ic.UnmarshalSecp256k1PrivateKey(privKey) +} + +// public key string to ic.PubKey interface +func ToPubKeyRaw(pubKey []byte) (ic.PubKey, error) { return ic.UnmarshalSecp256k1PublicKey(pubKey) } + +func GenKeyPairs() (ic.PrivKey, ic.PubKey, error) { + return ic.GenerateSecp256k1Key(rand.Reader) +} + +func addBase64Padding(text []byte) string { + value := string(text) + m := len(value) % 4 + if m != 0 { + value += strings.Repeat("=", 4-m) + } + + return value +} + +func removeBase64Padding(value string) string { + return strings.Replace(value, "=", "", -1) +} + +// pkcs7 padding +func Pad(src []byte) []byte { + padding := aes.BlockSize - len(src)%aes.BlockSize + padtext := bytes.Repeat([]byte{byte(padding)}, padding) + return append(src, padtext...) +} + +func Unpad(src []byte) ([]byte, error) { + length := len(src) + unpadding := int(src[length-1]) + + if unpadding > length { + return nil, errors.New("unpad error. This could happen when incorrect encryption key is used") + } + + return src[:(length - unpadding)], nil +} + +func Encrypt(key, text []byte) (string, error) { + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + msg := Pad(text) + ciphertext := make([]byte, aes.BlockSize+len(msg)) + iv := ciphertext[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return "", err + } + + cfb := cipher.NewCFBEncrypter(block, iv) + cfb.XORKeyStream(ciphertext[aes.BlockSize:], []byte(msg)) + finalMsg := removeBase64Padding(base64.URLEncoding.EncodeToString(ciphertext)) + return finalMsg, nil +} + +func Decrypt(key, text []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + decodedMsg, err := base64.URLEncoding.DecodeString(addBase64Padding(text)) + if err != nil { + return nil, err + } + + if (len(decodedMsg) % aes.BlockSize) != 0 { + return nil, errors.New("blocksize must be multiple of decoded message length") + } + + iv := decodedMsg[:aes.BlockSize] + msg := decodedMsg[aes.BlockSize:] + + cfb := cipher.NewCFBDecrypter(block, iv) + cfb.XORKeyStream(msg, msg) + + unpadMsg, err := Unpad(msg) + if err != nil { + return nil, err + } + + return unpadMsg, nil +} + +func getPubKeyFromPeerId(pid string) (ic.PubKey, error) { + peerId, err := peer.IDB58Decode(pid) + if err != nil { + return nil, err + } + pubKey, err2 := peerId.ExtractPublicKey() + if err2 != nil { + return nil, err2 + } + return pubKey, nil +} + diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go index 69b35aa..b09e4ff 100644 --- a/crypto/crypto_test.go +++ b/crypto/crypto_test.go @@ -8,6 +8,7 @@ import ( const ( KeyString = "CAISIJFNZZd5ZSvi9OlJP/mz/vvUobvlrr2//QN4DzX/EShP" + EncryptKey = "Tron2theMoon1234" ) func TestSignVerify(t *testing.T) { @@ -23,9 +24,9 @@ func TestSignVerify(t *testing.T) { t.Error("get raw public key from privKey failed") return } - pubKey, err := ToPubKey(rawPubKey) + pubKey, err := ToPubKeyRaw(rawPubKey) if err != nil { - t.Error("ToPubKey failed") + t.Error("ToPubKeyRaw failed") return } @@ -44,3 +45,13 @@ func TestSignVerify(t *testing.T) { t.Error("Verify with public key failed") } } + +func TestEncryptDecrypt(t *testing.T) { + origin := "Hello World" + key := []byte(EncryptKey) + encryptMsg, _ := Encrypt(key, []byte(origin)) + msg, _ := Decrypt(key, []byte(encryptMsg)) + if string(msg) != origin { + t.Errorf("Decrypt failed") + } +}