diff --git a/armi/bookkeeping/db/database.py b/armi/bookkeeping/db/database.py index 5a718cda8..85d555d0e 100644 --- a/armi/bookkeeping/db/database.py +++ b/armi/bookkeeping/db/database.py @@ -829,11 +829,11 @@ def loadReadOnly(self, cycle, node, statePointName=None): return r @staticmethod - def _setParamsBeforeFreezing(r): + def _setParamsBeforeFreezing(r: Reactor): """Set some special case parameters before they are made read-only.""" - for child in r.getChildren(deep=True): - if not isinstance(child, Component): - continue + for child in r.iterChildren( + deep=True, predicate=lambda c: isinstance(c, Component) + ): # calling Component.getVolume() sets the volume parameter child.getVolume() diff --git a/armi/bookkeeping/report/reportInterface.py b/armi/bookkeeping/report/reportInterface.py index 0105f70bb..c970db6e4 100644 --- a/armi/bookkeeping/report/reportInterface.py +++ b/armi/bookkeeping/report/reportInterface.py @@ -184,7 +184,7 @@ def reportSFP(sfp): runLog.important(title) runLog.important("-" * len(title)) totFis = 0.0 - for a in sfp.getChildren(): + for a in sfp: runLog.important( "{assembly:15s} discharged at t={dTime:10f} after {residence:10f} yrs. It entered at cycle: {cycle}. " "It has {fiss:10f} kg (x {mult}) fissile and peak BU={bu:.2f} %.".format( @@ -208,16 +208,16 @@ def reportSFP(sfp): @staticmethod def countAssembliesSFP(sfp): """Report on the count of assemblies in the SFP at each timestep.""" - if not sfp.getChildren(): + if not len(sfp): return runLog.important("Count:") totCount = 0 thisTimeCount = 0 - a = sfp.getChildren()[0] + a = sfp[0] lastTime = a.getAge() / units.DAYS_PER_YEAR + a.p.chargeTime - for a in sfp.getChildren(): + for a in sfp: thisTime = a.getAge() / units.DAYS_PER_YEAR + a.p.chargeTime if thisTime != lastTime: diff --git a/armi/physics/fuelCycle/fuelHandlers.py b/armi/physics/fuelCycle/fuelHandlers.py index 9f99d5a8f..5996f3d15 100644 --- a/armi/physics/fuelCycle/fuelHandlers.py +++ b/armi/physics/fuelCycle/fuelHandlers.py @@ -654,7 +654,7 @@ def _getAssembliesInRings( f"or otherwise instantiate a SpentFuelPool object as r.excore['sfp']" ) else: - sfpAssems = self.r.excore["sfp"].getChildren() + sfpAssems = list(self.r.excore["sfp"]) assemblyList = [[] for _i in range(len(ringList))] # empty lists for each ring if exclusions is None: diff --git a/armi/physics/fuelCycle/tests/test_fuelHandlers.py b/armi/physics/fuelCycle/tests/test_fuelHandlers.py index b6bbc14a0..305a7ab67 100644 --- a/armi/physics/fuelCycle/tests/test_fuelHandlers.py +++ b/armi/physics/fuelCycle/tests/test_fuelHandlers.py @@ -468,7 +468,7 @@ def runShuffling(self, fh): self.r.p.cycle = cycle fh.cycle = cycle fh.manageFuel(cycle) - for a in self.r.excore["sfp"].getChildren(): + for a in self.r.excore["sfp"]: self.assertEqual(a.getLocation(), "SFP") for b in self.r.core.getBlocks(Flags.FUEL): self.assertGreater(b.p.kgHM, 0.0, "b.p.kgHM not populated!") @@ -498,7 +498,7 @@ def test_repeatShuffles(self): ensure repeatability. """ # check labels before shuffling: - for a in self.r.excore["sfp"].getChildren(): + for a in self.r.excore["sfp"]: self.assertEqual(a.getLocation(), "SFP") # do some shuffles @@ -532,7 +532,7 @@ def test_repeatShuffles(self): # make sure the shuffle was repeated perfectly. for a in self.r.core.getAssemblies(): self.assertEqual(a.getName(), firstPassResults[a.getLocation()]) - for a in self.r.excore["sfp"].getChildren(): + for a in self.r.excore["sfp"]: self.assertEqual(a.getLocation(), "SFP") # Do some cleanup, since the fuelHandler Interface has code that gets @@ -811,7 +811,7 @@ def test_dischargeSwap(self): # grab an arbitrary fuel assembly from the core and from the SFP a1 = self.r.core.getAssemblies(Flags.FUEL)[0] - a2 = self.r.excore["sfp"].getChildren(Flags.FUEL)[0] + a2 = self.r.excore["sfp"].getChildrenWithFlags(Flags.FUEL)[0] # grab the stationary blocks pre swap a1PreSwapStationaryBlocks = [ diff --git a/armi/physics/neutronics/globalFlux/globalFluxInterface.py b/armi/physics/neutronics/globalFlux/globalFluxInterface.py index 9e51adab5..bb5cf8342 100644 --- a/armi/physics/neutronics/globalFlux/globalFluxInterface.py +++ b/armi/physics/neutronics/globalFlux/globalFluxInterface.py @@ -277,7 +277,7 @@ def getTightCouplingValue(self): return self.r.core.p.keff if self.coupler.parameter == "power": scaledCorePowerDistribution = [] - for a in self.r.core.getChildren(): + for a in self.r.core: scaledPower = [] assemPower = sum(b.p.power for b in a) for b in a: diff --git a/armi/physics/neutronics/globalFlux/tests/test_globalFluxInterface.py b/armi/physics/neutronics/globalFlux/tests/test_globalFluxInterface.py index 09adacaf8..ae9c0e96c 100644 --- a/armi/physics/neutronics/globalFlux/tests/test_globalFluxInterface.py +++ b/armi/physics/neutronics/globalFlux/tests/test_globalFluxInterface.py @@ -286,7 +286,7 @@ def test_getTightCouplingValue(self): self._setTightCouplingTrue() self.assertEqual(self.gfi.getTightCouplingValue(), 1.0) # set in setUp self.gfi.coupler.parameter = "power" - for a in self.r.core.getChildren(): + for a in self.r.core: for b in a: b.p.power = 10.0 self.assertEqual( diff --git a/armi/physics/neutronics/isotopicDepletion/crossSectionTable.py b/armi/physics/neutronics/isotopicDepletion/crossSectionTable.py index a7921e3bc..3103f7f31 100644 --- a/armi/physics/neutronics/isotopicDepletion/crossSectionTable.py +++ b/armi/physics/neutronics/isotopicDepletion/crossSectionTable.py @@ -223,7 +223,7 @@ def makeReactionRateTable(obj, nuclides: List = None): for nucName in nuclides } - for armiObject in obj.getChildren(): + for armiObject in obj: for nucName in nuclides: rxnRates = armiObject.getReactionRates(nucName, nDensity=1.0) for rxName, rxRate in rxnRates.items(): diff --git a/armi/reactor/assemblies.py b/armi/reactor/assemblies.py index da55c7694..c46a73aa6 100644 --- a/armi/reactor/assemblies.py +++ b/armi/reactor/assemblies.py @@ -326,10 +326,8 @@ def getPinPlenumVolumeInCubicMeters(self): ------- This is a bit design-specific for pinned assemblies """ - plenumBlocks = self.getBlocks(Flags.PLENUM) - plenumVolume = 0.0 - for b in plenumBlocks: + for b in self.iterChildrenWithFlags(Flags.PLENUM): cladId = b.getComponent(Flags.CLAD).getDimension("id") length = b.getHeight() plenumVolume += ( @@ -339,7 +337,7 @@ def getPinPlenumVolumeInCubicMeters(self): def getAveragePlenumTemperature(self): """Return the average of the plenum block outlet temperatures.""" - plenumBlocks = self.getBlocks(Flags.PLENUM) + plenumBlocks = self.iterChildrenWithFlags(Flags.PLENUM) plenumTemps = [b.p.THcoolantOutletT for b in plenumBlocks] # no plenum blocks, use the top block of the assembly for plenum temperature @@ -814,7 +812,7 @@ def dump(self, fName=None): with open(fName, "w") as pkl: pickle.dump(self, pkl) - def getBlocks(self, typeSpec=None, exact=False): + def getBlocks(self, typeSpec=None, exact=False) -> list[blocks.Block]: """ Get blocks in an assembly from bottom to top. @@ -831,9 +829,10 @@ def getBlocks(self, typeSpec=None, exact=False): List of blocks. """ if typeSpec is None: - return self.getChildren() + items = iter(self) else: - return self.getChildrenWithFlags(typeSpec, exactMatch=exact) + items = self.iterChildrenWithFlags(typeSpec, exact) + return list(items) def getBlocksAndZ(self, typeSpec=None, returnBottomZ=False, returnTopZ=False): """ @@ -881,26 +880,40 @@ def getBlocksAndZ(self, typeSpec=None, returnBottomZ=False, returnTopZ=False): return zip(blocks, zCoords) def hasContinuousCoolantChannel(self): - return all( - b.containsAtLeastOneChildWithFlags(Flags.COOLANT) for b in self.getBlocks() - ) + return all(b.containsAtLeastOneChildWithFlags(Flags.COOLANT) for b in self) def getFirstBlock(self, typeSpec=None, exact=False): - bs = self.getBlocks(typeSpec, exact=exact) - if bs: - return bs[0] + """Find the first block that matches the spec. + + Parameters + ---------- + typeSpec : flag or list of flags, optional + Specification to require on the returned block. + exact : bool, optional + Require block to exactly match ``typeSpec`` + + Returns + ------- + Block or None + First block that matches if such a block could be found. + """ + if typeSpec is None: + items = iter(self) else: + items = self.iterChildrenWithFlags(typeSpec, exact) + try: + # Create an iterator and attempt to advance it to the first value. + return next(items) + except StopIteration: + # No items found in the iteration -> no blocks match the request return None def getFirstBlockByType(self, typeName): - bs = [ - b - for b in self.getChildren(deep=False) - if isinstance(b, blocks.Block) and b.getType() == typeName - ] - if bs: - return bs[0] - return None + blocks = filter(lambda b: b.getType() == typeName, self) + try: + return next(blocks) + except StopIteration: + return None def getBlockAtElevation(self, elevation): """ @@ -1191,7 +1204,11 @@ def countBlocksWithFlags(self, blockTypeSpec=None): blockCounter : int number of blocks of this type """ - return len(self.getBlocks(blockTypeSpec)) + if blockTypeSpec is None: + items = iter(self) + else: + items = self.iterChildrenWithFlags(blockTypeSpec) + return sum(1 for _ in items) def getDim(self, typeSpec, dimName): """ diff --git a/armi/reactor/blocks.py b/armi/reactor/blocks.py index 023e86361..3e0e1fda0 100644 --- a/armi/reactor/blocks.py +++ b/armi/reactor/blocks.py @@ -627,7 +627,7 @@ def getArea(self, cold=False): return area a = 0.0 - for c in self.getChildren(): + for c in self: myArea = c.getArea(cold=cold) a += myArea fullArea = a @@ -930,7 +930,7 @@ def add(self, c): self._updatePitchComponent(c) def removeAll(self, recomputeAreaFractions=True): - for c in self.getChildren(): + for c in list(self): self.remove(c, recomputeAreaFractions=False) if recomputeAreaFractions: # only do this once self.getVolumeFractions() @@ -1283,7 +1283,7 @@ def getPinPitch(self, cold=False): def getDimensions(self, dimension): """Return dimensional values of the specified dimension.""" dimVals = set() - for c in self.getChildren(): + for c in self: try: dimVal = c.getDimension(dimension) except parameters.ParameterError: @@ -1532,7 +1532,7 @@ def getPinLocations(self) -> list[grids.IndexLocation]: :meth:`getPinCoordinates` - companion for this method. """ items = [] - for clad in self.getChildrenWithFlags(Flags.CLAD): + for clad in self.iterChildrenWithFlags(Flags.CLAD): if isinstance(clad.spatialLocator, grids.MultiIndexLocation): items.extend(clad.spatialLocator) else: diff --git a/armi/reactor/blueprints/__init__.py b/armi/reactor/blueprints/__init__.py index 48532ed58..3bee854e4 100644 --- a/armi/reactor/blueprints/__init__.py +++ b/armi/reactor/blueprints/__init__.py @@ -519,7 +519,7 @@ def _checkAssemblyAreaConsistency(self, cs): runLog.error("CURRENT COMPARISON BLOCK:") b.printContents(includeNuclides=False) - for c in b.getChildren(): + for c in b: runLog.error( "{0} area {1} effective area {2}" "".format(c, c.getArea(), c.getVolume() / b.getHeight()) diff --git a/armi/reactor/components/__init__.py b/armi/reactor/components/__init__.py index 516cdae81..f4fe2c6d8 100644 --- a/armi/reactor/components/__init__.py +++ b/armi/reactor/components/__init__.py @@ -383,7 +383,7 @@ def _deriveVolumeAndArea(self): # Determine the volume/areas of the non-derived shape components within the parent. siblingVolume = 0.0 siblingArea = 0.0 - for sibling in self.parent.getChildren(): + for sibling in self.parent: if sibling is self: continue elif not self and isinstance(sibling, DerivedShape): diff --git a/armi/reactor/components/component.py b/armi/reactor/components/component.py index 5a4da2f3b..be2ce62db 100644 --- a/armi/reactor/components/component.py +++ b/armi/reactor/components/component.py @@ -964,7 +964,7 @@ def clearLinkedCache(self): def getLinkedComponents(self): """Find other components that are linked to this component.""" dependents = [] - for child in self.parent.getChildren(): + for child in self.parent: for dimName in child.DIMENSION_NAMES: isLinked = child.dimensionIsLinked(dimName) if isLinked and child.p[dimName].getLinkedComponent() is self: diff --git a/armi/reactor/composites.py b/armi/reactor/composites.py index 62a0e329d..25e024c1e 100644 --- a/armi/reactor/composites.py +++ b/armi/reactor/composites.py @@ -36,7 +36,7 @@ import itertools import operator import timeit -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import Dict, List, Optional, Tuple, Type, Union, Iterator, Callable import numpy as np import six @@ -471,7 +471,7 @@ def duplicate(self): def clearCache(self): """Clear the cache so all new values are recomputed.""" self.cached = {} - for child in self.getChildren(): + for child in self: child.clearCache() def _getCached(self, name): # TODO: stop the "returns None" nonsense? @@ -521,13 +521,40 @@ def updateParamsFrom(self, new): for paramName, val in new.p.items(): self.p[paramName] = val - def getChildren(self, deep=False, generationNum=1, includeMaterials=False): + def iterChildren( + self, + deep=False, + generationNum=1, + predicate: Optional[Callable[["ArmiObject"], bool]] = None, + ) -> Iterator["ArmiObject"]: + """Iterate over children of this object.""" + raise NotImplementedError() + + def getChildren( + self, deep=False, generationNum=1, includeMaterials=False + ) -> list["ArmiObject"]: """Return the children of this object.""" - raise NotImplementedError + raise NotImplementedError() - def getChildrenWithFlags(self, typeSpec: TypeSpec, exactMatch=True): + def iterChildrenWithFlags( + self, typeSpec: TypeSpec, exactMatch=False + ) -> Iterator["ArmiObject"]: + """Produce an iterator of children that have given flags.""" + return self.iterChildren(predicate=lambda o: o.hasFlags(typeSpec, exactMatch)) + + def getChildrenWithFlags( + self, typeSpec: TypeSpec, exactMatch=False + ) -> list["ArmiObject"]: """Get all children that have given flags.""" - raise NotImplementedError + return list(self.iterChildrenWithFlags(typeSpec, exactMatch)) + + def iterChildrenOfType(self, typeName: str) -> Iterator["ArmiObject"]: + """Iterate over children that have a specific input type name.""" + return self.iterChildren(predicate=lambda o: o.getType() == typeName) + + def getChildrenOfType(self, typeName: str) -> list["ArmiObject"]: + """Produce a list of children that have a specific input type name.""" + return list(self.iterChildrenOfType(typeName)) def getComponents(self, typeSpec: TypeSpec = None, exact=False): """ @@ -613,7 +640,7 @@ def copyParamsToChildren(self, paramNames): """ for paramName in paramNames: myVal = self.p[paramName] - for c in self.getChildren(): + for c in self: c.p[paramName] = myVal @classmethod @@ -894,7 +921,7 @@ def getMaxArea(self): """ raise NotImplementedError() - def getMass(self, nuclideNames=None): + def getMass(self, nuclideNames=None) -> float: """ Determine the mass in grams of nuclide(s) and/or elements in this object. @@ -917,7 +944,7 @@ def getMass(self, nuclideNames=None): mass : float The mass in grams. """ - return sum([c.getMass(nuclideNames=nuclideNames) for c in self]) + return sum(c.getMass(nuclideNames=nuclideNames) for c in self) def getMassFrac(self, nucName): """ @@ -1627,7 +1654,7 @@ def setLumpedFissionProducts(self, lfpCollection): self._lumpedFissionProducts = lfpCollection def setChildrenLumpedFissionProducts(self, lfpCollection): - for c in self.getChildren(): + for c in self: c.setLumpedFissionProducts(lfpCollection) def getFissileMassEnrich(self): @@ -1928,7 +1955,7 @@ def getNuclides(self): List of nuclide names that exist in this """ nucs = set() - for child in self.getChildren(): + for child in self: nucs.update(child.getNuclides()) return nucs @@ -2516,6 +2543,8 @@ class is a child-class of the ARMIObject class and provides a structure """ + _children: list["Composite"] + def __init__(self, name): ArmiObject.__init__(self, name) self.childrenByLocator = {} @@ -2613,9 +2642,106 @@ def setChildren(self, items): for c in items: self.add(c) + def iterChildren( + self, + deep=False, + generationNum=1, + predicate: Optional[Callable[["Composite"], bool]] = None, + ) -> Iterator["Composite"]: + """Iterate over children objects of this composite. + + Parameters + ---------- + deep : bool, optional + If true, traverse the entire composite tree. Otherwise, go as far as ``generationNum``. + generationNum: int, optional + Produce composites at this depth. A depth of ``1`` includes children of ``self``, ``2`` + is children of children, and so on. + predicate: f(Composite) -> bool, optional + Function to check on a composite before producing it. All items in the iteration + will pass this check. + + Returns + ------- + iterator of Composite + + See Also + -------- + :meth:`getChildren` produces a list for situations where you need to perform + multiple iterations or do list operations (append, indexing, sorting, containment, etc.) + + Composites are naturally iterable. The following are identical:: + + >>> for child in c.getChildren(): + ... pass + >>> for child in c.iterChildren(): + ... pass + >>> for child in c: + ... pass + + If you do not need any depth-traversal, natural iteration should be sufficient. + + The :func:`filter` command may be sufficient if you do not wish to pass a predicate. The following + are identical:: + >>> checker = lambda c: len(c.name) % 3 + >>> for child in c.getChildren(predicate=checker): + ... pass + >>> for child in c.iterChildren(predicate=checker): + ... pass + >>> for child in filter(checker, c): + ... pass + + If you're going to be doing traversal beyond the first generation, this method will help you. + + """ + if deep and generationNum > 1: + raise RuntimeError( + "Cannot get children with a generation number set and the deep flag set" + ) + if predicate is None: + checker = lambda _: True + else: + checker = predicate + yield from self._iterChildren(deep, generationNum, checker) + + def _iterChildren( + self, deep: bool, generationNum: int, checker: Callable[["Composite"], bool] + ) -> Iterator["Composite"]: + if deep or generationNum == 1: + yield from filter(checker, self) + if deep or generationNum > 1: + for c in self: + yield from c._iterChildren(deep, generationNum - 1, checker) + + def iterChildrenWithMaterials(self, *args, **kwargs) -> Iterator: + """Produce an iterator that also includes any materials found on descendants. + + Arguments are forwarded to :meth:`iterChildren` and control the depth of traversal + and filtering of objects. + + This is useful for sending state across MPI tasks where you need a more full + representation of the composite tree. Which includes the materials attached + to components. + """ + children = self.iterChildren(*args, **kwargs) + # Each entry is either (c, ) or (c, c.material) if the child has a material attribute + stitched = map( + lambda c: ( + (c,) if getattr(c, "material", None) is None else (c, c.material) + ), + children, + ) + # Iterator that iterates over each "sub" iterator. If we have ((c0, ), (c1, m1)), this produces a single + # iterator of (c0, c1, m1) + return itertools.chain.from_iterable(stitched) + def getChildren( - self, deep=False, generationNum=1, includeMaterials=False, predicate=None - ): + self, + deep=False, + generationNum=1, + includeMaterials=False, + predicate: Optional[Callable[["Composite"], bool]] = None, + ) -> list["Composite"]: """ Return the children objects of this composite. @@ -2657,6 +2783,11 @@ def getChildren( to meet the predicate only affects the object in question; children will still be considered. + See Also + -------- + :meth:`iterChildren` if you do not need to produce a full list, e.g., just iterating + over objects. + Examples -------- >>> obj.getChildren() @@ -2673,49 +2804,15 @@ def getChildren( [grandchild1, grandchild3] """ - _pred = predicate or (lambda x: True) - if deep and generationNum > 1: - raise RuntimeError( - "Cannot get children with a generation number set and the deep flag set" + if not includeMaterials: + items = self.iterChildren( + deep=deep, generationNum=generationNum, predicate=predicate ) - - children = [] - for child in self._children: - if generationNum == 1 or deep: - if _pred(child): - children.append(child) - - if generationNum > 1 or deep: - children.extend( - child.getChildren( - deep=deep, - generationNum=generationNum - 1, - includeMaterials=includeMaterials, - predicate=predicate, - ) - ) - if includeMaterials: - material = getattr(self, "material", None) - if material: - children.append(material) - - return children - - def getChildrenWithFlags(self, typeSpec: TypeSpec, exactMatch=False): - """Get all children of a specific type.""" - children = [] - for child in self: - if child.hasFlags(typeSpec, exact=exactMatch): - children.append(child) - return children - - def getChildrenOfType(self, typeName): - """Get children that have a specific input type name.""" - children = [] - for child in self: - if child.getType() == typeName: - children.append(child) - return children + else: + items = self.iterChildrenWithMaterials( + deep=deep, generationNum=generationNum, predicate=predicate + ) + return list(items) def getComponents(self, typeSpec: TypeSpec = None, exact=False): return list(self.iterComponents(typeSpec, exact)) @@ -2769,8 +2866,11 @@ def syncMpiState(self): startTime = timeit.default_timer() # sync parameters... - allComps = [self] + self.getChildren(deep=True, includeMaterials=True) - allComps = [c for c in allComps if hasattr(c, "p")] + genItems = itertools.chain( + [self], + self.iterChildrenWithMaterials(deep=True), + ) + allComps = [c for c in genItems if hasattr(c, "p")] sendBuf = [c.p.getSyncData() for c in allComps] runLog.debug("syncMpiState has {} comps".format(len(allComps))) @@ -2875,7 +2975,11 @@ def _markSynchronized(self): SINCE_LAST_DISTRIBUTE_STATE. """ paramDefs = set() - for child in [self] + self.getChildren(deep=True, includeMaterials=True): + items = itertools.chain( + [self], + self.iterChildrenWithMaterials(deep=True), + ) + for child in items: # Materials don't have a "p" / Parameter attribute to sync if hasattr(child, "p"): # below reads as: assigned & everything_but(SINCE_LAST_DISTRIBUTE_STATE) @@ -2956,7 +3060,7 @@ def getLumpedFissionProductCollection(self): """ lfps = ArmiObject.getLumpedFissionProductCollection(self) if lfps is None: - for c in self.getChildren(): + for c in self: lfps = c.getLumpedFissionProductCollection() if lfps is not None: break @@ -3094,7 +3198,7 @@ def getReactionRates(self, nucName, nDensity=None): def printContents(self, includeNuclides=True): """Display information about all the comprising children in this object.""" runLog.important(self) - for c in self.getChildren(): + for c in self: c.printContents(includeNuclides=includeNuclides) def _genChildByLocationLookupTable(self): @@ -3140,7 +3244,7 @@ class StateRetainer: """ - def __init__(self, composite, paramsToApply=None): + def __init__(self, composite: Composite, paramsToApply=None): """ Create an instance of a StateRetainer. @@ -3168,9 +3272,11 @@ def _enterExitHelper(self, func): ``backUp()`` or ``restoreBackup()``. """ paramDefs = set() - for child in [self.composite] + self.composite.getChildren( - deep=True, includeMaterials=True - ): + items = itertools.chain( + (self.composite,), + self.composite.iterChildrenWithMaterials(deep=True), + ) + for child in items: if hasattr(child, "p"): # materials don't have Parameters paramDefs.update(child.p.paramDefs) diff --git a/armi/reactor/converters/axialExpansionChanger/expansionData.py b/armi/reactor/converters/axialExpansionChanger/expansionData.py index 39d4dacaa..985de6f13 100644 --- a/armi/reactor/converters/axialExpansionChanger/expansionData.py +++ b/armi/reactor/converters/axialExpansionChanger/expansionData.py @@ -288,19 +288,17 @@ def determineTargetComponent( if flagOfInterest is None: # Follow expansion of most neutronically important component, fuel then control/poison for targetFlag in TARGET_FLAGS_IN_PREFERRED_ORDER: - candidates = [c for c in b.getChildren() if c.hasFlags(targetFlag)] + candidates = b.getChildrenWithFlags(targetFlag) if candidates: break # some blocks/components are not included in the above list but should still be found if not candidates: candidates = [c for c in b.getChildren() if c.p.flags in b.p.flags] else: - candidates = [c for c in b.getChildren() if c.hasFlags(flagOfInterest)] + candidates = b.getChildrenWithFlags(flagOfInterest) if len(candidates) == 0: # if only 1 solid, be smart enought to snag it - solidMaterials = list( - c for c in b if not isinstance(c.material, material.Fluid) - ) + solidMaterials = getSolidComponents(b) if len(solidMaterials) == 1: candidates = solidMaterials if len(candidates) == 0: diff --git a/armi/reactor/reactorParameters.py b/armi/reactor/reactorParameters.py index 90eeb185b..3f7577d27 100644 --- a/armi/reactor/reactorParameters.py +++ b/armi/reactor/reactorParameters.py @@ -789,5 +789,5 @@ def makeParametersReadOnly(r): Once you make one Reactor read-only, you cannot make it writeable again. """ r.p.readOnly = True - for child in r.getChildren(deep=True): + for child in r.iterChildren(deep=True): child.p.readOnly = True diff --git a/armi/reactor/spentFuelPool.py b/armi/reactor/spentFuelPool.py index 666c646b4..a3e374de6 100644 --- a/armi/reactor/spentFuelPool.py +++ b/armi/reactor/spentFuelPool.py @@ -77,7 +77,7 @@ def add(self, assem, loc=None): def getAssembly(self, name): """Get a specific assembly by name.""" - for a in self.getChildren(): + for a in self: if a.getName() == name: return a @@ -123,7 +123,7 @@ def normalizeNames(self, startIndex=0): The new max Assembly number. """ ind = startIndex - for a in self.getChildren(): + for a in self: oldName = a.getName() newName = a.makeNameFromAssemNum(ind) if oldName == newName: diff --git a/armi/reactor/tests/test_components.py b/armi/reactor/tests/test_components.py index 7ce082a54..6c4d70a38 100644 --- a/armi/reactor/tests/test_components.py +++ b/armi/reactor/tests/test_components.py @@ -165,8 +165,9 @@ def getHeight(self): def clearCache(self): pass - def getChildren(self): - return [] + def __iter__(self): + """Act like an iterator but don't actually iterate.""" + return iter(()) derivedMustUpdate = False diff --git a/armi/reactor/tests/test_composites.py b/armi/reactor/tests/test_composites.py index 5a60cc5a1..318acd34c 100644 --- a/armi/reactor/tests/test_composites.py +++ b/armi/reactor/tests/test_composites.py @@ -15,6 +15,7 @@ """Tests for the composite pattern.""" from copy import deepcopy import logging +import itertools import unittest from armi import nuclearDataIO @@ -76,6 +77,8 @@ def __init__(self, name, i=0): composites.Composite.__init__(self, name) self.p.type = name self.spatialLocator = grids.IndexLocation(i, i, i, _testGrid) + # Some special material attribute for testing getChildren(includeMaterials=True) + self.material = ("hello", "world") def getChildren( self, deep=False, generationNum=1, includeMaterials=False, predicate=None @@ -106,12 +109,21 @@ def setUp(self): container.add(leaf) nested = DummyComposite("clad", 98) nested.setType("clad") + self.cladChild = nested self.secondGen = DummyComposite("liner", 97) self.thirdGen = DummyLeaf("pin 77", 33) self.secondGen.add(self.thirdGen) nested.add(self.secondGen) container.add(nested) self.container = container + # Composite tree structure in list of lists for testing + # tree[i] contains the children at "generation" or "depth" i + self.tree: list[list[composites.Composite]] = [ + [self.container], + list(self.container), + [self.secondGen], + [self.thirdGen], + ] def test_composite(self): """Test basic Composite things. @@ -139,22 +151,87 @@ def test_getChildren(self): :id: T_ARMI_CMP1 :tests: R_ARMI_CMP """ - # There are 5 leaves and 1 composite in container. The composite has one leaf. firstGen = self.container.getChildren() - self.assertEqual(len(firstGen), 6) + self.assertEqual(firstGen, self.tree[1]) + secondGen = self.container.getChildren(generationNum=2) - self.assertEqual(len(secondGen), 1) + self.assertEqual(secondGen, self.tree[2]) + self.assertIs(secondGen[0], self.secondGen) third = self.container.getChildren(generationNum=3) - self.assertEqual(len(third), 1) + self.assertEqual(third, self.tree[3]) self.assertIs(third[0], self.thirdGen) + allC = self.container.getChildren(deep=True) - self.assertEqual(len(allC), 8) + expected = self.tree[1] + self.tree[2] + self.tree[3] + self.assertTrue( + all(a is e for a, e in itertools.zip_longest(allC, expected)), + msg=f"Deep traversal differs: {allC=} != {expected=}", + ) onlyLiner = self.container.getChildren( deep=True, predicate=lambda o: o.p.type == "liner" ) self.assertEqual(len(onlyLiner), 1) + self.assertIs(onlyLiner[0], self.secondGen) + + def test_getChildrenWithMaterials(self): + """Test the ability for getChildren to place the material after the object.""" + withMaterials = self.container.getChildren(deep=True, includeMaterials=True) + # Grab the iterable so we can control the progression + items = iter(withMaterials) + for item in items: + expectedMat = getattr(item, "material", None) + if expectedMat is None: + continue + # Material should be the next item in the list + actualMat = next(items) + self.assertIs(actualMat, expectedMat) + break + else: + raise RuntimeError("No materials found with includeMaterials=True") + + def test_iterChildren(self): + """Detailed testing on Composite.iterChildren.""" + + def compareIterables(actual, expected: list[composites.Composite]): + for e in expected: + a = next(actual) + self.assertIs(a, e) + # Ensure we've consumed the actual iterator and there's nothing left + with self.assertRaises(StopIteration): + next(actual) + + compareIterables(self.container.iterChildren(), self.tree[1]) + compareIterables(self.container.iterChildren(generationNum=2), self.tree[2]) + compareIterables(self.container.iterChildren(generationNum=3), self.tree[3]) + compareIterables( + self.container.iterChildren(deep=True), + self.tree[1] + self.tree[2] + self.tree[3], + ) + + def test_iterAndGetChildren(self): + """Compare that iter children and get children are consistent.""" + self._compareIterGetChildren() + self._compareIterGetChildren(deep=True) + self._compareIterGetChildren(generationNum=2) + # Some wacky predicate just to check we can use that too + self._compareIterGetChildren(deep=True, predicate=lambda c: len(c.name) % 3) + + def _compareIterGetChildren(self, **kwargs): + fromIter = self.container.iterChildren(**kwargs) + fromGetter = self.container.getChildren(**kwargs) + msg = repr(kwargs) + # Use zip longest just in case one iterator comes up short + for count, (it, gt) in enumerate(itertools.zip_longest(fromIter, fromGetter)): + self.assertIs(it, gt, msg=f"{count=} :: {msg}") + + def test_simpleIterChildren(self): + """Test that C.iterChildren() is identical to iter(C).""" + for count, (fromNative, fromIterChildren) in enumerate( + itertools.zip_longest(self.container, self.container.iterChildren()) + ): + self.assertIs(fromIterChildren, fromNative, msg=count) def test_getName(self): """Test the getName method. @@ -170,26 +247,26 @@ def test_getName(self): def test_sort(self): # in this case, the children should start sorted - c0 = [c.name for c in self.container.getChildren()] + c0 = [c.name for c in self.container] self.container.sort() - c1 = [c.name for c in self.container.getChildren()] + c1 = [c.name for c in self.container] self.assertNotEqual(c0, c1) # verify repeated sortings behave for _ in range(3): self.container.sort() - ci = [c.name for c in self.container.getChildren()] + ci = [c.name for c in self.container] self.assertEqual(c1, ci) # break the order children = self.container.getChildren() self.container._children = children[2:] + children[:2] - c2 = [c.name for c in self.container.getChildren()] + c2 = [c.name for c in self.container] self.assertNotEqual(c1, c2) # verify the sort order self.container.sort() - c3 = [c.name for c in self.container.getChildren()] + c3 = [c.name for c in self.container] self.assertEqual(c1, c3) def test_areChildernOfType(self): @@ -320,12 +397,12 @@ def test_setChildrenLumpedFissionProducts(self): # validate that the LFP collection is None self.container.setChildrenLumpedFissionProducts(None) - for c in self.container.getChildren(): + for c in self.container: self.assertIsNone(c._lumpedFissionProducts) # validate that the LFP collection is not None self.container.setChildrenLumpedFissionProducts(lfps) - for c in self.container.getChildren(): + for c in self.container: self.assertIsNotNone(c._lumpedFissionProducts) def test_requiresLumpedFissionProducts(self): @@ -417,6 +494,43 @@ def test_syncParameters(self): numSynced = self.container._syncParameters(data, {}) self.assertEqual(numSynced, 2) + def test_iterChildrenWithFlags(self): + expectedChildren = {c for c in self.container if c.hasFlags(Flags.DUCT)} + found = set() + for c in self.container.iterChildrenWithFlags(Flags.DUCT): + self.assertIn(c, expectedChildren) + found.add(c) + self.assertSetEqual(found, expectedChildren) + + def test_iterChildrenOfType(self): + clads = self.container.iterChildrenOfType("clad") + first = next(clads) + self.assertIs(first, self.cladChild) + with self.assertRaises(StopIteration): + next(clads) + + def test_removeAll(self): + """Test the ability to remove all children of a composite.""" + self.container.removeAll() + self.assertEqual(len(self.container), 0) + # Nothing to iterate over + items = iter(self.container) + with self.assertRaises(StopIteration): + next(items) + for child in self.tree[1]: + self.assertIsNone(child.parent) + + def test_setChildren(self): + """Test the ability to override children on a composite.""" + newChildren = self.tree[2] + self.tree[3] + oldChildren = list(self.container) + self.container.setChildren(newChildren) + self.assertEqual(len(self.container), len(newChildren)) + for old in oldChildren: + self.assertIsNone(old.parent) + for actualNew, expectedNew in zip(newChildren, self.container): + self.assertIs(actualNew, expectedNew) + class TestCompositeTree(unittest.TestCase): blueprintYaml = """ @@ -786,7 +900,7 @@ def test_getNumberDensities(self): # sum nuc densities from children components totalVolume = self.obj.getVolume() childDensities = {} - for o in self.obj.getChildren(): + for o in self.obj: m = o.getVolume() d = o.getNumberDensities() for nuc, val in d.items(): @@ -824,7 +938,7 @@ def test_getNumberDensitiesWithExpandedFissionProducts(self): # sum nuc densities from children components totalVolume = self.obj.getVolume() childDensities = {} - for o in self.obj.getChildren(): + for o in self.obj: # get the number densities with and without fission products d0 = o.getNumberDensities(expandFissionProducts=False) d = o.getNumberDensities(expandFissionProducts=True) diff --git a/armi/reactor/tests/test_reactors.py b/armi/reactor/tests/test_reactors.py index 6187155e6..69790e6c7 100644 --- a/armi/reactor/tests/test_reactors.py +++ b/armi/reactor/tests/test_reactors.py @@ -360,7 +360,7 @@ def test_growToFullCore(self): self.assertFalse(self.r.core.isFullCore) self.r.core.growToFullCore(self.o.cs) aNums = [] - for a in self.r.core.getChildren(): + for a in self.r.core: self.assertNotIn(a.getNum(), aNums) aNums.append(a.getNum()) diff --git a/doc/release/0.5.rst b/doc/release/0.5.rst index f730979d6..2eab84399 100644 --- a/doc/release/0.5.rst +++ b/doc/release/0.5.rst @@ -9,6 +9,9 @@ Release Date: TBD New Features ------------ #. Move instead of copy files from TemporaryDirectoryChanger. (`PR#2022 `_) +#. Provide ``Composite.iterChildren``, ``Composite.iterChildrenWithFlags``, ``Composite.iterChildrenOfType``, + ``Composite.iterChildrenWithMaterials``, for efficient composite tree traversal. + (`PR#2031 `_) API Changes -----------