diff --git a/holodeck/librarian/combine.py b/holodeck/librarian/combine.py index 202fa369..ede1a380 100644 --- a/holodeck/librarian/combine.py +++ b/holodeck/librarian/combine.py @@ -142,6 +142,7 @@ def sam_lib_combine( param_names = pspace.param_names param_samples = pspace.param_samples + # Standard Library: vary all parameters together if library: if param_samples == None: # noqa : use `== None` to match arrays log.error(f"`library`={library} but `param_samples`={param_samples}`") @@ -150,6 +151,8 @@ def sam_lib_combine( nsamp_all, ndim = param_samples.shape log.debug(f"{nsamp_all=}, {ndim=}, {param_names=}") + + # Domain: Vary only one parameter at a time to explore the domain else: err = f"Expected 'domain' but `param_samples`={param_samples} is not `None`!" assert param_samples == None, err # noqa : use `== None` to match arrays @@ -158,6 +161,9 @@ def sam_lib_combine( nsamp_dim = args.nsamples # get the total number of samples nsamp_all = ndim * nsamp_dim + # for 'domain', param_samples will eventually be shaped (ndim, nsamp_dim, ndim), but load + # the data first as `(nsamp_all, ndim)`, then we will reshape. + param_samples = np.zeros((nsamp_all, ndim)) # ---- make sure all files exist; get shape information from files @@ -197,8 +203,8 @@ def sam_lib_combine( sspar = None bgpar = None - gwb, hc_ss, hc_bg, sspar, bgpar, bad_files = _load_library_from_all_files( - path_sims, gwb, hc_ss, hc_bg, sspar, bgpar, log, library + gwb, hc_ss, hc_bg, sspar, bgpar, param_samples, bad_files = _load_library_from_all_files( + path_sims, gwb, hc_ss, hc_bg, sspar, bgpar, param_samples, log, library ) if has_gwb: log.info(f"Loaded data from all library files | {holo.utils.stats(gwb)=}") @@ -211,8 +217,7 @@ def sam_lib_combine( with h5py.File(lib_path, 'w') as h5: h5.create_dataset('fobs_cents', data=fobs_cents) h5.create_dataset('fobs_edges', data=fobs_edges) - if param_samples is not None: - h5.create_dataset('sample_params', data=param_samples) + h5.create_dataset('sample_params', data=param_samples) if gwb is not None: # if 'domain', reshape to include each dimension if not library: @@ -354,7 +359,10 @@ def _check_files_and_load_shapes(log, path_sims, nsamp, library): return fobs_cents, fobs_edges, nreals, nloudest, has_gwb, has_ss, has_params -def _load_library_from_all_files(path_sims, gwb, hc_ss, hc_bg, sspar, bgpar, log, library): +def _load_library_from_all_files( + path_sims, gwb, hc_ss, hc_bg, sspar, bgpar, param_samples, + log, library, +): """Load data from all individual simulation files. Arguments @@ -368,6 +376,7 @@ def _load_library_from_all_files(path_sims, gwb, hc_ss, hc_bg, sspar, bgpar, log hc_bg sspar bgpar + param_samples log : ``logging.Logger`` Logging instance. library : bool @@ -375,18 +384,18 @@ def _load_library_from_all_files(path_sims, gwb, hc_ss, hc_bg, sspar, bgpar, log """ if hc_bg is not None: - nsamp = hc_bg.shape[0] + nsamp_all = hc_bg.shape[0] elif gwb is not None: - nsamp = gwb.shape[0] + nsamp_all = gwb.shape[0] else: err = "Unable to get shape from either `hc_bg` or `gwb`!" log.exception(err) raise RuntimeError(err) - log.info(f"Collecting data from {nsamp} files") - bad_files = np.zeros(nsamp, dtype=bool) #: track which files contain UN-useable data + log.info(f"Collecting data from {nsamp_all} files") + bad_files = np.zeros(nsamp_all, dtype=bool) #: track which files contain UN-useable data msg = None - for pnum in tqdm.trange(nsamp): + for pnum in tqdm.trange(nsamp_all): fname = libraries._get_sim_fname(path_sims, pnum, library=library) temp = np.load(fname, allow_pickle=True) # When a processor fails for a given parameter, the output file is still created with the 'fail' key added @@ -404,6 +413,11 @@ def _load_library_from_all_files(path_sims, gwb, hc_ss, hc_bg, sspar, bgpar, log bad_files[pnum] = True continue + # for 'domain' simulations, we need to load the parameters. For 'library' runs, we + # already have them. + if not library: + param_samples[pnum, :] = temp['params'][:] + # store the GWB from this file if gwb is not None: gwb[pnum, :, :] = temp['gwb'][...] @@ -418,7 +432,7 @@ def _load_library_from_all_files(path_sims, gwb, hc_ss, hc_bg, sspar, bgpar, log log.info(f"{holo.utils.frac_str(bad_files)} files are failures") - return gwb, hc_ss, hc_bg, sspar, bgpar, bad_files + return gwb, hc_ss, hc_bg, sspar, bgpar, param_samples, bad_files if __name__ == "__main__": diff --git a/holodeck/librarian/gen_lib.py b/holodeck/librarian/gen_lib.py index decb3f60..48e8814d 100644 --- a/holodeck/librarian/gen_lib.py +++ b/holodeck/librarian/gen_lib.py @@ -307,6 +307,8 @@ def run_sam_at_pspace_params(args, space, pnum, params): gwb_flag=args.gwb_flag, singles_flag=args.ss_flag, details_flag=False, params_flag=args.params_flag, log=log, ) + data['params'] = params + data['param_names'] = space.param_names rv = True except Exception as err: