diff --git a/bayesflow/trainers.py b/bayesflow/trainers.py index d4779befb..16fb9f2bb 100644 --- a/bayesflow/trainers.py +++ b/bayesflow/trainers.py @@ -21,9 +21,9 @@ import logging import os from pickle import load as pickle_load -import tensorflow as tf import numpy as np +import tensorflow as tf from tqdm.autonotebook import tqdm from bayesflow.amortizers import ( @@ -737,7 +737,10 @@ def train_from_presimulation( input_dict = self.configurator(epoch_data[index]) # Like the number of iterations, the batch size is inferred from presimulated dictionary or list - batch_size = epoch_data[index][DEFAULT_KEYS["sim_data"]].shape[0] + if isinstance(self.amortizer, AmortizedModelComparison): + batch_size = input_dict[DEFAULT_KEYS["summary_conditions"]].shape[0] + else: + batch_size = epoch_data[index][DEFAULT_KEYS["sim_data"]].shape[0] loss = self._train_step(batch_size, _backprop_step, input_dict, **kwargs) # Store returned loss