diff --git a/deepcell_spots/applications/spot_decoding.py b/deepcell_spots/applications/spot_decoding.py index b51a25f..37ac5d7 100644 --- a/deepcell_spots/applications/spot_decoding.py +++ b/deepcell_spots/applications/spot_decoding.py @@ -220,7 +220,7 @@ def _decoding_output_to_dict(self, out): 'prediction', len(decoded_dict['probability'])).astype('U25') return decoded_dict - def _threshold_unknown_by_prob(self, decoded_dict, unknown_index, pred_prob_thresh=0.5): + def _threshold_unknown_by_prob(self, decoded_dict, unknown_index, pred_prob_thresh=0.95): """Threshold the decoded spots to identify unknown. If the highest probability if below a certain threshold, the spot will be classfied as Unknown. @@ -451,7 +451,7 @@ def predict(self, spots_intensities_vec, num_iter=500, batch_size=1000, - pred_prob_thresh=0.5, + pred_prob_thresh=0.95, rescue_errors=True, rescue_mixed=False): """Predict the gene assignment of each spot. @@ -461,7 +461,7 @@ def predict(self, `[num_spots, (rounds * channels)]`. num_iter (int): Number of iterations for training. Defaults to 500. batch_size (int): Size of batches for training. Defaults to 1000. - pred_prob_thresh (float): The threshold of unknown category, within [0,1]. Defaults to 0.5. + pred_prob_thresh (float): The threshold of unknown category, within [0,1]. Defaults to 0.95. rescue_errors (bool): Whether to check if `'Background'`- and `'Unknown'`-assigned spots have a Hamming distance of 1 to other barcodes. rescue_mixed (bool): Whether to check if low probability predictions are the result of diff --git a/deepcell_spots/applications/spot_decoding_test.py b/deepcell_spots/applications/spot_decoding_test.py index c1de046..a0e8728 100644 --- a/deepcell_spots/applications/spot_decoding_test.py +++ b/deepcell_spots/applications/spot_decoding_test.py @@ -76,7 +76,8 @@ def test_spot_decoding_app(self): spots_intensities_vec21 = np.ones((100, 6)) decoding_dict_trunc21 = app2.predict( - spots_intensities_vec=spots_intensities_vec21, num_iter=20, batch_size=100 + spots_intensities_vec=spots_intensities_vec21, num_iter=20, batch_size=100, + pred_prob_thresh=0.5 ) self.assertListEqual( decoding_dict_trunc21["predicted_id"].tolist(), (2 * np.ones((100,))).tolist() @@ -85,7 +86,8 @@ def test_spot_decoding_app(self): spots_intensities_vec22 = np.zeros((100, 6)) decoding_dict_trunc22 = app2.predict( - spots_intensities_vec=spots_intensities_vec22, num_iter=20, batch_size=100 + spots_intensities_vec=spots_intensities_vec22, num_iter=20, batch_size=100, + pred_prob_thresh=0.5 ) self.assertListEqual( decoding_dict_trunc22["predicted_id"].tolist(), np.ones((100,)).tolist()