Skip to content

Commit

Permalink
abstraction: get part from full encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
ahangsu committed Jan 3, 2023
1 parent b0793c6 commit 44318bb
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 60 deletions.
43 changes: 21 additions & 22 deletions pyteal/ast/abi/array_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from pyteal.ast.abi.tuple import _encode_tuple
from pyteal.ast.abi.bool import Bool, BoolTypeSpec
from pyteal.ast.abi.uint import Uint16, Uint16TypeSpec
from pyteal.ast.abi.util import substring_for_decoding
from pyteal.ast.abi.util import (
substring_for_decoding,
_get_encoding_or_store_from_encoded_bytes,
)

T = TypeVar("T", bound=BaseType)

Expand Down Expand Up @@ -223,11 +226,9 @@ def __prototype_encoding_store_into(self, output: T | None = None) -> Expr:
bitIndex = self.index
if arrayType.is_dynamic():
bitIndex = bitIndex + Int(Uint16TypeSpec().bit_size())

if output is not None:
return cast(Bool, output).decode_bit(encodedArray, bitIndex)
else:
return SetBit(Bytes(b"\x00"), Int(0), GetBit(encodedArray, bitIndex))
return _get_encoding_or_store_from_encoded_bytes(
BoolTypeSpec(), encodedArray, output, start_index=bitIndex
)

# Compute the byteIndex (first byte indicating the element encoding)
# (If the array is dynamic, add 2 to byte index for dynamic array length uint16 prefix)
Expand Down Expand Up @@ -265,28 +266,26 @@ def __prototype_encoding_store_into(self, output: T | None = None) -> Expr:
.Else(nextValueStart)
)

if output is not None:
return output.decode(
encodedArray, start_index=valueStart, end_index=valueEnd
)
else:
return substring_for_decoding(
encodedArray, start_index=valueStart, end_index=valueEnd
)
return _get_encoding_or_store_from_encoded_bytes(
arrayType.value_type_spec(),
encodedArray,
output,
start_index=valueStart,
end_index=valueEnd,
)

# Handling case for array elements are static:
# since array._stride() is element's static byte length
# we partition the substring for array element.
valueStart = byteIndex
valueLength = Int(arrayType._stride())
if output is not None:
return output.decode(
encodedArray, start_index=valueStart, length=valueLength
)
else:
return substring_for_decoding(
encodedArray, start_index=valueStart, length=valueLength
)
return _get_encoding_or_store_from_encoded_bytes(
arrayType.value_type_spec(),
encodedArray,
output,
start_index=valueStart,
length=valueLength,
)

def store_into(self, output: T) -> Expr:
"""Partitions the byte string of the given ABI array and stores the byte string of array
Expand Down
72 changes: 34 additions & 38 deletions pyteal/ast/abi/tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from collections import OrderedDict

from pyteal.types import TealType
from pyteal.types import TealType, require_type
from pyteal.errors import TealInputError, TealInternalError
from pyteal.ast.expr import Expr
from pyteal.ast.seq import Seq
Expand All @@ -38,7 +38,11 @@
_bool_aware_static_byte_length,
)
from pyteal.ast.abi.uint import NUM_BITS_IN_BYTE, Uint16
from pyteal.ast.abi.util import substring_for_decoding, type_spec_from_annotation
from pyteal.ast.abi.util import (
substring_for_decoding,
type_spec_from_annotation,
_get_encoding_or_store_from_encoded_bytes,
)


def _encode_tuple(values: Sequence[BaseType]) -> Expr:
Expand Down Expand Up @@ -124,6 +128,9 @@ class _IndexTuple:
value_types: Sequence[TypeSpec]
encoded: Expr

def __post_init__(self):
require_type(self.encoded, TealType.bytes)

def __call__(self, index: int, output: BaseType | None = None) -> Expr:
if index not in range(len(self.value_types)):
raise ValueError("Index outside of range")
Expand Down Expand Up @@ -165,16 +172,12 @@ def __call__(self, index: int, output: BaseType | None = None) -> Expr:
# value is the beginning of a bool sequence (or a single bool)
bitOffsetInEncoded = offset * NUM_BITS_IN_BYTE

if output is None:
return SetBit(
Bytes(b"\x00"),
Int(0),
GetBit(self.encoded, Int(bitOffsetInEncoded)),
)
else:
return cast(Bool, output).decode_bit(
self.encoded, Int(bitOffsetInEncoded)
)
return _get_encoding_or_store_from_encoded_bytes(
BoolTypeSpec(),
self.encoded,
output,
start_index=Int(bitOffsetInEncoded),
)

if valueType.is_dynamic():
hasNextDynamicValue = False
Expand Down Expand Up @@ -203,22 +206,20 @@ def __call__(self, index: int, output: BaseType | None = None) -> Expr:
if not hasNextDynamicValue:
# This is the final dynamic value, so decode the substring from start_index to the end of
# encoded
if output is None:
return substring_for_decoding(self.encoded, start_index=start_index)
else:
return output.decode(self.encoded, start_index=start_index)
return _get_encoding_or_store_from_encoded_bytes(
valueType, self.encoded, output, start_index=start_index
)

# There is a dynamic value after this one, and end_index is where its tail starts, so decode
# the substring from start_index to end_index
end_index = ExtractUint16(self.encoded, Int(nextDynamicValueOffset))
if output is None:
return substring_for_decoding(
self.encoded, start_index=start_index, end_index=end_index
)
else:
return output.decode(
self.encoded, start_index=start_index, end_index=end_index
)
return _get_encoding_or_store_from_encoded_bytes(
valueType,
self.encoded,
output,
start_index=start_index,
end_index=end_index,
)

start_index = Int(offset)
length = Int(valueType.byte_length_static())
Expand All @@ -232,25 +233,20 @@ def __call__(self, index: int, output: BaseType | None = None) -> Expr:
return output.decode(self.encoded)
# This is the last value in the tuple, so decode the substring from start_index to the end of
# encoded
if output is None:
return substring_for_decoding(self.encoded, start_index=start_index)
else:
return output.decode(self.encoded, start_index=start_index)
return _get_encoding_or_store_from_encoded_bytes(
valueType, self.encoded, output, start_index=start_index
)

if offset == 0:
# This is the first value in the tuple, so decode the substring from 0 with length length
if output is None:
return substring_for_decoding(self.encoded, length=length)
else:
return output.decode(self.encoded, length=length)
return _get_encoding_or_store_from_encoded_bytes(
valueType, self.encoded, output, length=length
)

# This is not the first or last value, so decode the substring from start_index with length length
if output is None:
return substring_for_decoding(
self.encoded, start_index=start_index, length=length
)
else:
return output.decode(self.encoded, start_index=start_index, length=length)
return _get_encoding_or_store_from_encoded_bytes(
valueType, self.encoded, output, start_index=start_index, length=length
)


class TupleTypeSpec(TypeSpec):
Expand Down
49 changes: 49 additions & 0 deletions pyteal/ast/abi/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import algosdk.abi

from pyteal.errors import TealInputError
from pyteal.types import require_type, TealType
from pyteal.ast.expr import Expr
from pyteal.ast.int import Int
from pyteal.ast.substring import Extract, Substring, Suffix
Expand Down Expand Up @@ -593,3 +594,51 @@ def type_spec_is_assignable_to(a: TypeSpec, b: TypeSpec) -> bool:
return True

return False


def _get_encoding_or_store_from_encoded_bytes(
encoding_type: TypeSpec,
full_encoding: Expr,
output: BaseType | None = None,
*,
start_index: Expr | None = None,
end_index: Expr | None = None,
length: Expr | None = None,
) -> Expr:
from pyteal.ast.abi import BoolTypeSpec, Bool
from pyteal.ast.bytes import Bytes
from pyteal.ast.binaryexpr import GetBit
from pyteal.ast.ternaryexpr import SetBit

require_type(full_encoding, TealType.bytes)

match encoding_type:
case BoolTypeSpec():
if start_index is None:
raise TealInputError(
"on BoolTypeSpec, requiring start index to be not None."
)

if output is None:
return SetBit(
Bytes(b"\x00"),
Int(0),
GetBit(full_encoding, start_index),
)
else:
return cast(Bool, output).decode_bit(full_encoding, start_index)
case _:
if output is None:
return substring_for_decoding(
encoded=full_encoding,
start_index=start_index,
end_index=end_index,
length=length,
)
else:
return output.decode(
full_encoding,
start_index=start_index,
end_index=end_index,
length=length,
)

0 comments on commit 44318bb

Please sign in to comment.