From 87a1965dc71abe137b920d7969dfaac75fa3fb6b Mon Sep 17 00:00:00 2001 From: Jmgr Date: Fri, 22 Nov 2024 18:08:40 +0000 Subject: [PATCH] chore: more fixes --- nada_dsl/nada_types/collections.py | 9 ++++----- test-programs/ntuple_accessor.py | 3 ++- test-programs/object_accessor.py | 1 - 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/nada_dsl/nada_types/collections.py b/nada_dsl/nada_types/collections.py index 573c430..b526bdd 100644 --- a/nada_dsl/nada_types/collections.py +++ b/nada_dsl/nada_types/collections.py @@ -108,7 +108,7 @@ def store_in_ast(self, ty): ) -class TupleType(DslType): +class TupleType(NadaType): """Marker type for Tuples.""" is_compound = True @@ -171,7 +171,7 @@ def _generate_accessor(ty: Any, accessor: Any) -> DslType: return ty.instantiate(accessor) -class NTupleType(DslType): +class NTupleType(NadaType): """Marker type for NTuples.""" is_compound = True @@ -261,7 +261,7 @@ def store_in_ast(self, ty: object): ) -class ObjectType(DslType): +class ObjectType(NadaType): """Marker type for Objects.""" is_compound = True @@ -412,7 +412,7 @@ def store_in_ast(self, ty: DslTypeRepr): ) -class ArrayType(DslType): +class ArrayType(NadaType): """Marker type for arrays.""" is_compound = True @@ -533,7 +533,6 @@ def inner_product(self: "Array[T]", other: "Array[T]") -> T: if is_primitive_integer(self.contained_type) and is_primitive_integer( other.contained_type ): - return self.contained_type.instantiate( InnerProduct(left=self, right=other, source_ref=SourceRef.back_frame()) ) # type: ignore diff --git a/test-programs/ntuple_accessor.py b/test-programs/ntuple_accessor.py index a5e9e2a..08640b2 100644 --- a/test-programs/ntuple_accessor.py +++ b/test-programs/ntuple_accessor.py @@ -26,5 +26,6 @@ def add(acc: PublicInteger, a: PublicInteger) -> PublicInteger: return [Output(final, "my_output", party1)] + if __name__ == "__main__": - nada_main() \ No newline at end of file + nada_main() diff --git a/test-programs/object_accessor.py b/test-programs/object_accessor.py index 0f5679f..8c006b5 100644 --- a/test-programs/object_accessor.py +++ b/test-programs/object_accessor.py @@ -15,7 +15,6 @@ def nada_main(): array = object.b scalar_2 = object.c - def add(acc: PublicInteger, a: PublicInteger) -> PublicInteger: return acc + a