Skip to content

Commit

Permalink
[BugFix]Ensure that bf16 arrays are created as expected (#16436)
Browse files Browse the repository at this point in the history
Co-authored-by: Bin Li <[email protected]>
  • Loading branch information
sisleyli and Bin Li authored Jan 20, 2024
1 parent ffa404f commit ccca00a
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/tvm/runtime/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def copyfrom(self, source_array):
if (not source_array.flags["C_CONTIGUOUS"]) or (
dtype == "bfloat16" or dtype != np_dtype_str
):
if dtype == "bfloat16":
source_array = np.frombuffer(source_array.tobytes(), "uint16")
source_array = np.ascontiguousarray(
source_array, dtype="uint16" if dtype == "bfloat16" else dtype
)
Expand Down

0 comments on commit ccca00a

Please sign in to comment.