Skip to content

Commit

Permalink
Switch to passing args rather than unrolling
Browse files Browse the repository at this point in the history
  • Loading branch information
brryan committed Jan 6, 2025
1 parent 7f3abac commit 1174432
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 78 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@ jobs:
--input ../inputs/stepdiff_smr_ddmc.in --use_mpiexec
./stepdiff_smr.py --executable $JAYBENNE_EXECUTABLE \
--input ../inputs/stepdiff_smr_ddmc.in --use_mpiexec \
--mpi_nthreads 8 --oversubscribe
--mpi_nthreads 8 --mpi_oversubscribe
./stepdiff_smr.py --executable $JAYBENNE_EXECUTABLE \
--input ../inputs/stepdiff_smr_hybrid.in --use_mpiexec
./stepdiff_smr.py --executable $JAYBENNE_EXECUTABLE \
--input ../inputs/stepdiff_smr_hybrid.in --use_mpiexec \
--mpi_nthreads 8 --oversubscribe
--mpi_nthreads 8 --mpi_oversubscribe
112 changes: 52 additions & 60 deletions tst/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,12 @@ def get_default_parser():
action="store_true",
help="Whether to launch the executable with mpiexec",
)
parser.add_argument(
"--mpi_oversubscribe",
dest="mpi_oversubscribe",
action="store_true",
help="Allow MPI to oversubscribe cores",
)
parser.add_argument(
"--mpi_nthreads",
type=int,
Expand Down Expand Up @@ -203,10 +209,10 @@ def get_default_parser():
help="Method of comparison against true solution at each point or via mean [pointwise, mean]",
)
parser.add_argument(
"--oversubscribe",
dest="use_mpiexec",
"--visualize",
dest="visualize",
action="store_true",
help="Allow MPI to oversubscribe cores",
help="Whether to generate visual output.",
)
return parser

Expand Down Expand Up @@ -312,19 +318,11 @@ def run_problem(
# -- Run test problem with previously built code, input file, and modified inputs, and
# compare to analytic expectation
def analytic_comparison(
args,
variables,
solutions,
input_file,
modified_inputs={},
executable=None,
use_mpiexec=False,
oversubscribe=False,
mpi_nthreads=1,
build_type="Release",
tolerance=1.0e-10,
cleanup=False,
comparison="mean",
visualize=False,
):

input_file = os.path.join(
Expand All @@ -333,33 +331,33 @@ def analytic_comparison(
problem = read_input_value("parthenon/job", "problem_id", input_file)
print("\n=== ANALYTIC COMPARISON TEST PROBLEM ===")
print(f"= problem: {problem}")
print(f"= executable: {executable}")
print(f"= use_mpiexec: {use_mpiexec}")
if use_mpiexec:
print(f"= oversubscribe: {oversubscribe}")
print(f"= mpi_nthreads: {mpi_nthreads}")
print(f"= build_type: {build_type}")
print(f"= executable: {args.executable}")
print(f"= use_mpiexec: {args.use_mpiexec}")
if args.use_mpiexec:
print(f"= oversubscribe: {args.mpi_oversubscribe}")
print(f"= mpi_nthreads: {args.mpi_nthreads}")
print(f"= build_type: {args.build_type}")
print(f"= tolerance: {tolerance}")
print(f"= cleanup: {cleanup}")
print(f"= comparison: {comparison}")
print(f"= visualize: {visualize}")
print(f"= cleanup: {args.cleanup}")
print(f"= comparison: {args.comparison}")
print(f"= visualize: {args.visualize}")
print("========================================\n")

assert (
comparison == "mean"
or comparison == "pointwise"
or comparison == "weighted_mean"
args.comparison == "mean"
or args.comparison == "pointwise"
or args.comparison == "weighted_mean"
), 'Invalid "comparison" option!'

dump = run_problem(
executable,
build_type,
args.executable,
args.build_type,
input_file,
modified_inputs,
cleanup,
use_mpiexec,
oversubscribe,
mpi_nthreads,
args.cleanup,
args.use_mpiexec,
args.mpi_oversubscribe,
args.mpi_nthreads,
)

# Loop over meshblocks and cells and compare each variable to its corresponding solution
Expand Down Expand Up @@ -399,7 +397,7 @@ def analytic_comparison(
t, x, y, z
)
weighted_norm += solutions[nv](t, x, y, z)
if comparison == "pointwise":
if args.comparison == "pointwise":
if frac_error > tolerance:
success = False

Expand All @@ -413,14 +411,14 @@ def analytic_comparison(
print(f"Max error: {max_error:.2e}")
print(f"Max fractional error: {max_frac_error:.2e}")

if comparison == "mean":
if args.comparison == "mean":
if mean_frac_error > tolerance:
success = False
elif comparison == "weighted_mean":
elif args.comparison == "weighted_mean":
if mean_frac_error_weighted > tolerance:
success = False

if visualize:
if args.visualize:
import matplotlib.pyplot as plt

for nv, variable_name in enumerate(variables):
Expand All @@ -436,7 +434,7 @@ def analytic_comparison(
plt.savefig(f"../analytic_compare_{variable_name}.png")
plt.clf()

if cleanup == True:
if args.cleanup == True:
clean_up()

if success:
Expand All @@ -450,48 +448,42 @@ def analytic_comparison(
# -- Run test problem with previously built code, input file, and modified inputs, and
# compare to gold output
def gold_comparison(
args,
variables,
input_file,
modified_inputs={},
executable=None,
use_mpiexec=False,
oversubscribe=False,
build_type="Release",
upgold=False,
compression_factor=1,
tolerance=0.2,
cleanup=False,
comparison="mean",
):
input_file = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "../inputs/", input_file
)
problem = read_input_value("parthenon/job", "problem_id", input_file)
print("\n=== GOLD COMPARISON TEST PROBLEM ===")
print(f"= problem: {problem}")
print(f"= executable: {executable}")
print(f"= use_mpiexec: {use_mpiexec}")
if use_mpiexec:
print(f"= oversubscribe: {oversubscribe}")
print(f"= build_type: {build_type}")
print(f"= executable: {args.executable}")
print(f"= use_mpiexec: {args.use_mpiexec}")
if args.use_mpiexec:
print(f"= oversubscribe: {args.mpi_oversubscribe}")
print(f"= build_type: {args.build_type}")
print(f"= compression: {compression_factor}")
print(f"= tolerance: {tolerance}")
print(f"= cleanup: {cleanup}")
print(f"= comparison: {comparison}")
print(f"= cleanup: {args.cleanup}")
print(f"= comparison: {args.comparison}")
print("====================================\n")

assert (
comparison == "mean" or comparison == "pointwise"
args.comparison == "mean" or args.comparison == "pointwise"
), 'Invalid "comparison" option!'

dump = run_problem(
executable,
build_type,
args.executable,
args.build_type,
input_file,
modified_inputs,
cleanup,
use_mpiexec,
oversubscribe,
args.cleanup,
args.use_mpiexec,
args.mpi_oversubscribe,
)

# Construct array of results values
Expand All @@ -515,7 +507,7 @@ def gold_comparison(
# Write gold file, or compare to existing gold file
success = True
gold_name = os.path.join("../", SCRIPT_NAME) + ".gold"
if upgold:
if args.upgold:
np.savetxt(gold_name, variables_data, newline="\n")
else:
gold_variables = np.loadtxt(gold_name)
Expand All @@ -525,7 +517,7 @@ def gold_comparison(
else:
for n in range(len(gold_variables)):
if not soft_equiv(variables_data[n], gold_variables[n], tol=tolerance):
if comparison == "pointwise":
if args.comparison == "pointwise":
success = False

norm = np.clip((variables_data + gold_variables) / 2, 1.0e-100, None)
Expand All @@ -534,15 +526,15 @@ def gold_comparison(
max_error = np.max(np.fabs(variables_data - gold_variables))
max_frac_error = np.max(np.fabs(variables_data - gold_variables) / norm)

if comparison == "mean":
if args.comparison == "mean":
if mean_frac_error > tolerance:
success = False

if cleanup == True:
if args.cleanup == True:
clean_up()

# Report upgolding, success, or failure
if upgold:
if args.upgold:
print(f"Gold file {gold_name} updated!")
else:
print(f"Mean error: {mean_error:.2e}")
Expand Down
9 changes: 1 addition & 8 deletions tst/stepdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,11 @@ def ur_solution(t, x, y, z):


code = rt.analytic_comparison(
args=args,
variables=["field.jaybenne.energy_tally"],
solutions=[ur_solution],
input_file=args.input,
modified_inputs=modified_inputs,
executable=args.executable,
use_mpiexec=args.use_mpiexec,
oversubscribe=args.oversubscribe,
mpi_nthreads=args.mpi_nthreads,
build_type=args.build_type,
tolerance=0.05,
cleanup=args.cleanup,
comparison=args.comparison,
)

sys.exit(code)
9 changes: 1 addition & 8 deletions tst/stepdiff_smr.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,11 @@ def ur_solution(t, x, y, z):


code = rt.analytic_comparison(
args=args,
variables=["field.jaybenne.energy_tally"],
solutions=[ur_solution],
input_file=args.input,
modified_inputs=modified_inputs,
executable=args.executable,
use_mpiexec=args.use_mpiexec,
oversubscribe=args.oversubscribe,
mpi_nthreads=args.mpi_nthreads,
build_type=args.build_type,
tolerance=0.3,
cleanup=args.cleanup,
comparison=args.comparison,
)

sys.exit(code)

0 comments on commit 1174432

Please sign in to comment.