diff --git a/mala/network/predictor.py b/mala/network/predictor.py index 5a4a44588..0e1c6e484 100644 --- a/mala/network/predictor.py +++ b/mala/network/predictor.py @@ -1,4 +1,4 @@ -"""Tester class for testing a network.""" +"""Predictor class.""" import numpy as np import torch @@ -59,13 +59,6 @@ def predict_from_qeout(self, path_to_file, gather_ldos=False): predicted_ldos : numpy.array Precicted LDOS for these atomic positions. """ - self.data.grid_dimension = self.parameters.inference_data_grid - self.data.grid_size = ( - self.data.grid_dimension[0] - * self.data.grid_dimension[1] - * self.data.grid_dimension[2] - ) - self.data.target_calculator.read_additional_calculation_data( path_to_file, "espresso-out" ) @@ -230,18 +223,17 @@ def _forward_snap_descriptors( ) for i in range(0, self.number_of_batches_per_snapshot): - inputs = snap_descriptors[ - i - * self.parameters.mini_batch_size : (i + 1) - * self.parameters.mini_batch_size - ] - inputs = inputs.to(self.parameters._configuration["device"]) - predicted_outputs[ - i - * self.parameters.mini_batch_size : (i + 1) - * self.parameters.mini_batch_size - ] = self.data.output_data_scaler.inverse_transform( - self.network(inputs).to("cpu"), as_numpy=True + sl = slice( + i * self.parameters.mini_batch_size, + (i + 1) * self.parameters.mini_batch_size, + ) + inputs = snap_descriptors[sl].to( + self.parameters._configuration["device"] + ) + predicted_outputs[sl] = ( + self.data.output_data_scaler.inverse_transform( + self.network(inputs).to("cpu"), as_numpy=True + ) ) # Restricting the actual quantities to physical meaningful values,