Skip to content

Commit

Permalink
[refactor] Default dtype of ndarray type should be None instead of f32 (
Browse files Browse the repository at this point in the history
  • Loading branch information
ailzhang authored Jul 12, 2022
1 parent f6b40de commit 7f97f9c
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 12 deletions.
2 changes: 1 addition & 1 deletion python/taichi/aot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def produce_injected_args(kernel, symbolic_args=None):
raise TaichiCompilationError(
f'{field_dim} from Arg {arg.name} doesn\'t match kernel\'s annotated field_dim={anno.field_dim}'
)
if dtype != anno.dtype:
if anno.dtype is not None and dtype != anno.dtype:
raise TaichiCompilationError(
f' Arg {arg.name}\'s dtype {dtype.to_string()} doesn\'t match kernel\'s annotated dtype={anno.dtype.to_string()}'
)
Expand Down
16 changes: 6 additions & 10 deletions python/taichi/types/ndarray_type.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from taichi.types.primitive_types import f32


class NdarrayTypeMetadata:
def __init__(self, element_type, shape=None, layout=None):
self.element_type = element_type
Expand All @@ -20,13 +17,12 @@ class NdarrayType:
field_dim (Union[Int, NoneType]): None if not specified, number of field dimensions. This argument is ignored for external arrays for now.
layout (Union[Layout, NoneType], optional): None if not specified (will be treated as Layout.AOS for external arrays), Layout.AOS or Layout.SOA.
"""
def __init__(
self,
dtype=f32, # TODO: default should be None
element_dim=None,
element_shape=None,
field_dim=None,
layout=None):
def __init__(self,
dtype=None,
element_dim=None,
element_shape=None,
field_dim=None,
layout=None):
if element_dim is not None and (element_dim < 0 or element_dim > 2):
raise ValueError(
"Only scalars, vectors, and matrices are allowed as elements of ti.types.ndarray()"
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_arg_mismatched_ndarray_dtype():
n = 4

@ti.kernel
def test(pos: ti.types.ndarray(field_dim=1)):
def test(pos: ti.types.ndarray(dtype=ti.f32, field_dim=1)):
for i in range(n):
pos[i] = 2.5

Expand Down

0 comments on commit 7f97f9c

Please sign in to comment.