diff --git a/lime/lime_image.py b/lime/lime_image.py index ea3940e2..7eae0a5c 100644 --- a/lime/lime_image.py +++ b/lime/lime_image.py @@ -7,7 +7,7 @@ import numpy as np import sklearn from sklearn.utils import check_random_state -from skimage.color import gray2rgb +from skimage.color import gray2rgb, rgb2gray from tqdm.auto import tqdm @@ -261,7 +261,7 @@ def data_labels(self, for z in zeros: mask[segments == z] = True temp[mask] = fudged_image[mask] - imgs.append(temp) + imgs.append(rgb2gray(temp).reshape(temp.shape[0], temp.shape[1], 1)) if len(imgs) == batch_size: preds = classifier_fn(np.array(imgs)) labels.extend(preds)