Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the check for a jagged array during _writeParams #2051

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
22 changes: 12 additions & 10 deletions armi/bookkeeping/db/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,16 +914,18 @@ def _compose(self, comps, cs, parent=None):

return comp

def _writeParams(self, h5group, comps) -> tuple:
def _getShape(arr: [np.ndarray, List, Tuple]):
"""Get the shape of a np.ndarray, list, or tuple."""
if isinstance(arr, np.ndarray):
return arr.shape
elif isinstance(arr, (list, tuple)):
return (len(arr),)
else:
return (1,)
@staticmethod
def _getArrayShape(arr: [np.ndarray, List, Tuple]):
"""Get the shape of a np.ndarray, list, or tuple."""
if isinstance(arr, np.ndarray):
return arr.shape
elif isinstance(arr, (list, tuple)):
return (len(arr),)
else:
# not a list, tuple, or array (likely int, float, or None)
return 1

def _writeParams(self, h5group, comps) -> tuple:
c = comps[0]
groupName = c.__class__.__name__
if groupName not in h5group:
Expand Down Expand Up @@ -969,7 +971,7 @@ def _getShape(arr: [np.ndarray, List, Tuple]):
else:
# check if temp is a jagged array
if any(isinstance(x, (np.ndarray, list)) for x in temp):
jagged = len(set([_getShape(x) for x in temp])) != 1
jagged = len(set([self._getArrayShape(x) for x in temp])) != 1
else:
jagged = False
data = (
Expand Down
14 changes: 14 additions & 0 deletions armi/bookkeeping/db/tests/test_database3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np

from armi.bookkeeping.db import _getH5File, database, loadOperator
from armi.bookkeeping.db.database import Database
from armi.bookkeeping.db.databaseInterface import DatabaseInterface
from armi.bookkeeping.db.jaggedArray import JaggedArray
from armi.reactor import parameters
Expand Down Expand Up @@ -283,6 +284,19 @@ def _compareRoundTrip(self, data):
roundTrip = database.unpackSpecialData(packed, attrs, "testing")
self._compareArrays(data, roundTrip)

def test_getArrayShape(self):
"""Tests a helper method for ``_writeParams``."""
base = [1, 2, 3, 4]
self.assertEqual(Database._getArrayShape(base), (4,))
self.assertEqual(Database._getArrayShape(tuple(base)), (4,))
arr = np.array(base)
self.assertEqual(Database._getArrayShape(arr), (4,))
arr = np.array([base])
self.assertEqual(Database._getArrayShape(arr), (1, 4))
# not array type
self.assertEqual(Database._getArrayShape(1), 1)
self.assertEqual(Database._getArrayShape(None), 1)

def test_writeToDB(self):
"""Test writing to the database.

Expand Down
1 change: 1 addition & 0 deletions doc/release/0.5.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ API Changes

Bug Fixes
---------
#. Fix check for jagged arrays during ``_writeParams``. (`PR#2051 <https://github.com/terrapower/armi/pull/2051>`_)
john-science marked this conversation as resolved.
Show resolved Hide resolved
#. TBD

Quality Work
Expand Down