Skip to content

Commit

Permalink
update to latest embit
Browse files Browse the repository at this point in the history
  • Loading branch information
Stepan Snigirev committed Dec 3, 2023
1 parent e0e439d commit db3ce3e
Show file tree
Hide file tree
Showing 33 changed files with 1,959 additions and 916 deletions.
54 changes: 33 additions & 21 deletions libs/common/embit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,57 +9,56 @@ class EmbitError(Exception):
pass


def copy(a:bytes) -> bytes:
"""Ugly copy that works everywhere incl micropython"""
if len(a) == 0:
return b""
return a[:1] + a[1:]

class EmbitBase:
@classmethod
def read_from(cls, stream, *args, **kwargs):
"""All classes should be readable from stream"""
raise NotImplementedError(
"%s doesn't implement reading from stream" % type(cls).__name__
"%s doesn't implement reading from stream" % cls.__name__
)

@classmethod
def parse(cls, s, *args, **kwargs):
"""Parses a string or a byte sequence"""
if isinstance(s, str):
s = s.encode()
def parse(cls, s: bytes, *args, **kwargs):
"""Parse raw bytes"""
stream = BytesIO(s)
res = cls.read_from(stream, *args, **kwargs)
if len(stream.read(1)) > 0:
raise EmbitError("Unexpected extra bytes")
return res

def write_to(self, stream, *args, **kwargs):
def write_to(self, stream, *args, **kwargs) -> int:
"""All classes should be writable to stream"""
raise NotImplementedError(
"%s doesn't implement writing to stream" % type(self).__name__
)

def serialize(self, *args, **kwargs):
def serialize(self, *args, **kwargs) -> bytes:
"""Serialize instance to raw bytes"""
stream = BytesIO()
self.write_to(stream, *args, **kwargs)
return stream.getvalue()

def to_string(self, *args, **kwargs):
"""Default string representation is hex of serialized instance or base58 if available"""
def to_string(self, *args, **kwargs) -> str:
"""
String representation.
If not implemented - uses hex or calls to_base58() method if defined.
"""
if hasattr(self, "to_base58"):
return self.to_base58(*args, **kwargs)
res = self.to_base58(*args, **kwargs)
if not isinstance(res, str):
raise ValueError("to_base58() must return string")
return res
return hexlify(self.serialize(*args, **kwargs)).decode()

@classmethod
def from_string(cls, s, *args, **kwargs):
"""Default string representation is hex of serialized instance or base58 if availabe"""
"""Create class instance from string"""
if hasattr(cls, "from_base58"):
return cls.from_base58(s, *args, **kwargs)
return cls.parse(unhexlify(s))

def __str__(self):
"""to_string() can accept kwargs with defaults so str() should work"""
"""Internally calls `to_string()` method with no arguments"""
return self.to_string()

def __repr__(self):
Expand All @@ -69,6 +68,9 @@ def __repr__(self):
return type(self).__name__ + "()"

def __eq__(self, other):
"""Compare two objects by checking their serializations"""
if not hasattr(other, "serialize"):
return False
return self.serialize() == other.serialize()

def __ne__(self, other):
Expand All @@ -79,15 +81,25 @@ def __hash__(self):


class EmbitKey(EmbitBase):
def sec(self):
"""Any EmbitKey should implement sec() method that returns sec-serialized public key"""
def sec(self) -> bytes:
"""
Any EmbitKey should implement sec() method that returns
a sec-serialized public key
"""
raise NotImplementedError(
"%s doesn't implement sec() method" % type(self).__name__
)

def xonly(self) -> bytes:
"""xonly representation of the key"""
return self.sec()[1:33]

@property
def is_private(self) -> bool:
"""Any EmbitKey should implement is_private property"""
"""
Any EmbitKey should implement `is_private` property to distinguish
between private and public keys.
"""
raise NotImplementedError(
"%s doesn't implement is_private property" % type(self).__name__
)
Expand Down
16 changes: 8 additions & 8 deletions libs/common/embit/base58.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,29 @@
B58_DIGITS = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"


def encode(b):
def encode(b: bytes) -> str:
"""Encode bytes to a base58-encoded string"""

# Convert big-endian bytes to integer
n = int("0x0" + binascii.hexlify(b).decode("utf8"), 16)

# Divide that integer into bas58
res = []
chars = []
while n > 0:
n, r = divmod(n, 58)
res.append(B58_DIGITS[r])
res = "".join(res[::-1])
chars.append(B58_DIGITS[r])
result = "".join(chars[::-1])

pad = 0
for c in b:
if c == 0:
pad += 1
else:
break
return B58_DIGITS[0] * pad + res
return B58_DIGITS[0] * pad + result


def decode(s):
def decode(s: str) -> bytes:
"""Decode a base58-encoding string, returning bytes"""
if not s:
return b""
Expand Down Expand Up @@ -61,12 +61,12 @@ def decode(s):
return b"\x00" * pad + res


def encode_check(b):
def encode_check(b: bytes) -> str:
"""Encode bytes to a base58-encoded string with a checksum"""
return encode(b + hashes.double_sha256(b)[0:4])


def decode_check(s):
def decode_check(s: str) -> bytes:
"""Decode a base58-encoding string with checksum check.
Returns bytes without checksum
"""
Expand Down
31 changes: 19 additions & 12 deletions libs/common/embit/bech32.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,33 @@
# THE SOFTWARE.

"""Reference implementation for Bech32 and segwit addresses."""
from .misc import const

CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
BECH32_CONST = 1
BECH32M_CONST = 0x2bc830a3
BECH32_CONST = const(1)
BECH32M_CONST = const(0x2BC830A3)


class Encoding:
"""Enumeration type to list the various supported encodings."""

BECH32 = 1
BECH32M = 2


def bech32_polymod(values):
"""Internal function that computes the Bech32 checksum."""
generator = [0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3]
generator = [0x3B6A57B2, 0x26508E6D, 0x1EA119FA, 0x3D4233DD, 0x2A1462B3]
chk = 1
for value in values:
top = chk >> 25
chk = (chk & 0x1ffffff) << 5 ^ value
chk = (chk & 0x1FFFFFF) << 5 ^ value
for i in range(5):
chk ^= generator[i] if ((top >> i) & 1) else 0
return chk


def bech32_hrp_expand(hrp):
def bech32_hrp_expand(hrp: str):
"""Expand the HRP into values for checksum computation."""
return [ord(x) >> 5 for x in hrp] + [0] + [ord(x) & 31 for x in hrp]

Expand All @@ -57,6 +60,7 @@ def bech32_verify_checksum(hrp, data):
else:
return None


def bech32_create_checksum(encoding, hrp, data):
"""Compute the checksum values given HRP and data."""
values = bech32_hrp_expand(hrp) + data
Expand All @@ -68,22 +72,23 @@ def bech32_create_checksum(encoding, hrp, data):
def bech32_encode(encoding, hrp, data):
"""Compute a Bech32 or Bech32m string given HRP and data values."""
combined = data + bech32_create_checksum(encoding, hrp, data)
return hrp + '1' + ''.join([CHARSET[d] for d in combined])
return hrp + "1" + "".join([CHARSET[d] for d in combined])


def bech32_decode(bech):
"""Validate a Bech32/Bech32m string, and determine HRP and data."""
if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or
(bech.lower() != bech and bech.upper() != bech)):
if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or (
bech.lower() != bech and bech.upper() != bech
):
return (None, None, None)
bech = bech.lower()
pos = bech.rfind('1')
pos = bech.rfind("1")
if pos < 1 or pos + 7 > len(bech) or len(bech) > 90:
return (None, None, None)
if not all(x in CHARSET for x in bech[pos+1:]):
if not all(x in CHARSET for x in bech[pos + 1 :]):
return (None, None, None)
hrp = bech[:pos]
data = [CHARSET.find(x) for x in bech[pos+1:]]
data = [CHARSET.find(x) for x in bech[pos + 1 :]]
encoding = bech32_verify_checksum(hrp, data)
if encoding is None:
return (None, None, None)
Expand Down Expand Up @@ -125,7 +130,9 @@ def decode(hrp, addr):
return (None, None)
if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32:
return (None, None)
if (data[0] == 0 and encoding != Encoding.BECH32) or (data[0] != 0 and encoding != Encoding.BECH32M):
if (data[0] == 0 and encoding != Encoding.BECH32) or (
data[0] != 0 and encoding != Encoding.BECH32M
):
return (None, None)
return (data[0], decoded)

Expand Down
Loading

0 comments on commit db3ce3e

Please sign in to comment.