diff --git a/deepcell_spots/applications/polaris.py b/deepcell_spots/applications/polaris.py index 6b9719a..7b7ca51 100644 --- a/deepcell_spots/applications/polaris.py +++ b/deepcell_spots/applications/polaris.py @@ -474,14 +474,18 @@ def _predict(self, dec_prob_im[b, x, y] = prob - decoded_spots_locations = max_cp_array_to_point_list_max(dec_prob_im, - threshold=None, min_distance=1) mask = [] - for i in range(np.shape(decoded_spots_locations)[1]): - x = decoded_spots_locations[0][i, 0] - y = decoded_spots_locations[0][i, 1] + for b in range(spots_image.shape[0]): + decoded_spots_locations = max_cp_array_to_point_list_max(dec_prob_im[b:b+1], + threshold=None, min_distance=1) - mask.append(df_results.loc[(df_results.x==x) & (df_results.y==y)].index[0]) + for i in range(np.shape(decoded_spots_locations)[1]): + x = decoded_spots_locations[0][i, 0] + y = decoded_spots_locations[0][i, 1] + + mask.append(df_results.loc[(df_results.x==x) & + (df_results.y==y) & + (df_results.batch_id==b)].index[0]) df_results = df_results.loc[mask]