Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update decoding probability threshold #69

Merged
merged 2 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions deepcell_spots/applications/spot_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def _decoding_output_to_dict(self, out):
'prediction', len(decoded_dict['probability'])).astype('U25')
return decoded_dict

def _threshold_unknown_by_prob(self, decoded_dict, unknown_index, pred_prob_thresh=0.5):
def _threshold_unknown_by_prob(self, decoded_dict, unknown_index, pred_prob_thresh=0.95):
"""Threshold the decoded spots to identify unknown. If the highest probability
if below a certain threshold, the spot will be classfied as Unknown.

Expand Down Expand Up @@ -451,7 +451,7 @@ def predict(self,
spots_intensities_vec,
num_iter=500,
batch_size=1000,
pred_prob_thresh=0.5,
pred_prob_thresh=0.95,
rescue_errors=True,
rescue_mixed=False):
"""Predict the gene assignment of each spot.
Expand All @@ -461,7 +461,7 @@ def predict(self,
`[num_spots, (rounds * channels)]`.
num_iter (int): Number of iterations for training. Defaults to 500.
batch_size (int): Size of batches for training. Defaults to 1000.
pred_prob_thresh (float): The threshold of unknown category, within [0,1]. Defaults to 0.5.
pred_prob_thresh (float): The threshold of unknown category, within [0,1]. Defaults to 0.95.
rescue_errors (bool): Whether to check if `'Background'`- and `'Unknown'`-assigned
spots have a Hamming distance of 1 to other barcodes.
rescue_mixed (bool): Whether to check if low probability predictions are the result of
Expand Down
6 changes: 4 additions & 2 deletions deepcell_spots/applications/spot_decoding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def test_spot_decoding_app(self):

spots_intensities_vec21 = np.ones((100, 6))
decoding_dict_trunc21 = app2.predict(
spots_intensities_vec=spots_intensities_vec21, num_iter=20, batch_size=100
spots_intensities_vec=spots_intensities_vec21, num_iter=20, batch_size=100,
pred_prob_thresh=0.5
)
self.assertListEqual(
decoding_dict_trunc21["predicted_id"].tolist(), (2 * np.ones((100,))).tolist()
Expand All @@ -85,7 +86,8 @@ def test_spot_decoding_app(self):

spots_intensities_vec22 = np.zeros((100, 6))
decoding_dict_trunc22 = app2.predict(
spots_intensities_vec=spots_intensities_vec22, num_iter=20, batch_size=100
spots_intensities_vec=spots_intensities_vec22, num_iter=20, batch_size=100,
pred_prob_thresh=0.5
)
self.assertListEqual(
decoding_dict_trunc22["predicted_id"].tolist(), np.ones((100,)).tolist()
Expand Down