diff --git a/mala/datahandling/data_shuffler.py b/mala/datahandling/data_shuffler.py index e7d7a07cb..223f51b99 100644 --- a/mala/datahandling/data_shuffler.py +++ b/mala/datahandling/data_shuffler.py @@ -130,6 +130,29 @@ def __shuffle_numpy( ) ) + # if the number of new snapshots is not a divisor of the grid size + # then we have to trim the original snapshots to size + # the indicies to be removed are selected at random + if self.data_points_to_remove is not None: + if self.parameters.shuffling_seed is not None: + np.random.seed(idx * self.parameters.shuffling_seed) + ngrid = descriptor_data[idx].shape[0] + n_descriptor = descriptor_data[idx].shape[-1] + n_target = target_data[idx].shape[-1] + + current_target = target_data[idx].reshape(-1, n_target) + current_descriptor = descriptor_data[idx].reshape( + -1, n_descriptor + ) + + indices = np.random.choice( + ngrid**3, + size=ngrid**3 - self.data_points_to_remove[idx], + ) + + descriptor_data[idx] = current_descriptor[indices] + target_data[idx] = current_target[indices] + # Do the actual shuffling. target_name_openpmd = os.path.join( target_save_path, save_name.replace("*", "%T") @@ -165,16 +188,12 @@ def __shuffle_numpy( ) new_descriptors[ last_start : current_chunk + last_start - ] = descriptor_data[j].reshape( - current_grid_size, self.input_dimension - )[ + ] = descriptor_data[j].reshape(-1, self.input_dimension)[ i * current_chunk : (i + 1) * current_chunk, : ] new_targets[ last_start : current_chunk + last_start - ] = target_data[j].reshape( - current_grid_size, self.output_dimension - )[ + ] = target_data[j].reshape(-1, self.output_dimension)[ i * current_chunk : (i + 1) * current_chunk, : ] @@ -240,7 +259,6 @@ def __shuffle_numpy( # It will be executed one after another for both of them. # Use this class to parameterize which of both should be shuffled. class __DescriptorOrTarget: - def __init__( self, save_path, @@ -258,7 +276,6 @@ def __init__( self.dimension = dimension class __MockedMPIComm: - def __init__(self): self.rank = 0 self.size = 1 @@ -521,6 +538,8 @@ def shuffle_snapshots( ] number_of_data_points = np.sum(snapshot_size_list) + self.data_points_to_remove = None + if number_of_shuffled_snapshots is None: # If the user does not tell us how many snapshots to use, # we have to check if the number of snapshots is straightforward. @@ -584,10 +603,40 @@ def shuffle_snapshots( del specified_number_of_new_snapshots if number_of_data_points % number_of_new_snapshots != 0: - raise Exception( - "Cannot create this number of snapshots " - "from data provided." - ) + if snapshot_type == "numpy": + self.data_points_to_remove = [] + for i in range(0, self.nr_snapshots): + gridsize = self.parameters.snapshot_directories_list[ + i + ].grid_size + shuffled_gridsize = int( + gridsize / number_of_new_snapshots + ) + self.data_points_to_remove.append( + gridsize + - shuffled_gridsize * number_of_new_snapshots + ) + tot_points_missing = sum(self.data_points_to_remove) + + printout( + "Warning: number of requested snapshots is not a divisor of", + "the original grid sizes.\n", + f"{tot_points_missing} / {number_of_data_points} data points", + "will be left out of the shuffled snapshots." + ) + + shuffle_dimensions = [ + int(number_of_data_points / number_of_new_snapshots), + 1, + 1, + ] + + elif snapshot_type == "openpmd": + # TODO implement arbitrary grid sizes for openpmd + raise Exception( + "Cannot create this number of snapshots " + "from data provided." + ) else: shuffle_dimensions = [ int(number_of_data_points / number_of_new_snapshots), @@ -606,7 +655,6 @@ def shuffle_snapshots( permutations = [] seeds = [] for i in range(0, number_of_new_snapshots): - # This makes the shuffling deterministic, if specified by the user. if self.parameters.shuffling_seed is not None: np.random.seed(i * self.parameters.shuffling_seed)