diff --git a/STMiner/SPFinder.py b/STMiner/SPFinder.py index 53dbce0..4148579 100644 --- a/STMiner/SPFinder.py +++ b/STMiner/SPFinder.py @@ -85,13 +85,13 @@ def compare_gene_to_genes(self, gene_name): return compare_gmm_distance(gene_gmm, self.patterns) def get_genes_csr_array( - self, - min_cells, - normalize=True, - exclude_highly_expressed=False, - log1p=False, - vmax=99, - gene_list=None, + self, + min_cells: int, + normalize: bool = True, + exclude_highly_expressed: bool = False, + log1p: bool = False, + vmax: int = 99, + gene_list: list = None, ): error_gene_list = [] self.csr_dict = {} @@ -125,7 +125,7 @@ def get_genes_csr_array( print("Error when parse gene " + gene + "\nError: ") print(e) - def spatial_high_variable_genes(self, vmax=99, thread=1): + def spatial_high_variable_genes(self, vmax: int = 99, thread: int = 1): """ Compute the optimal transport (OT) distance matrix for high variable genes. @@ -154,7 +154,7 @@ def spatial_high_variable_genes(self, vmax=99, thread=1): if (not isinstance(thread, int)) or (thread <= 1): distance_dict = {} for key in tqdm( - list(self.csr_dict.keys()), desc="Computing ot distances..." + list(self.csr_dict.keys()), desc="Computing ot distances..." ): try: distance_dict[key] = calculate_ot_distance( @@ -191,15 +191,15 @@ def _mpl_worker(self, global_matrix, key, result_dict): result_dict[key] = res def fit_pattern( - self, - n_top_genes=-1, - n_comp=20, - normalize=True, - exclude_highly_expressed=False, - log1p=False, - min_cells=20, - gene_list=None, - remove_low_exp_spots=False, + self, + n_top_genes: int = -1, + n_comp: int = 20, + normalize: bool = True, + exclude_highly_expressed: bool = False, + log1p: bool = False, + min_cells: int = 20, + gene_list: list = None, + remove_low_exp_spots: bool = False, ): """ Given a distance matrix with the distances between each pair of objects in a set, and a chosen number of @@ -247,7 +247,12 @@ def fit_pattern( print(e) def preprocess( - self, normalize, exclude_highly_expressed, log1p, min_cells, n_top_genes=2000 + self, + normalize: bool, + exclude_highly_expressed: bool, + log1p: bool, + min_cells: int, + n_top_genes: int = 2000 ): sc.pp.filter_genes(self.adata, min_cells=min_cells) sc.pp.highly_variable_genes( @@ -264,54 +269,77 @@ def preprocess( sc.pp.log1p(self.adata) def build_distance_array(self, method="gmm", gene_list=None): + """ + Build a distance array for genes based on the specified method. + + This function supports four distance calculation methods: "gmm" (Gaussian Mixture Model), + "mse" (Mean Squared Error), "cs" (Cosine Similarity), and "ot" (Optimal Transport). + If no gene list is provided, all genes are used. + + Parameters: + - method (str): The distance calculation method to use. Default is "gmm". + - gene_list (list): A list of specific genes to use. If not provided, all genes are used. + + Returns: + No direct return value, but updates the `self.genes_distance_array` attribute with the calculated distances. + """ + # If no gene list is provided, use all genes if gene_list is None: gene_list = list(self.adata.var.index) + + # Build the distance array based on the specified method if method == "gmm": self.genes_distance_array = build_gmm_distance_array(self.patterns) elif method == "mse": self.genes_distance_array = build_mse_distance_array(self.adata, gene_list) elif method == "cs": - self.genes_distance_array = build_cosine_similarity_array( - self.adata, gene_list - ) + self.genes_distance_array = build_cosine_similarity_array(self.adata, gene_list) elif method == "ot": - self.genes_distance_array = build_ot_distance_array( - self.csr_dict, gene_list - ) + self.genes_distance_array = build_ot_distance_array(self.csr_dict, gene_list) + else: + # Raise an error if the method is unknown + raise ValueError("Unknown method, method should be one of gmm, mse, cs, ot") - def get_pattern_array(self, vote_rate=0): - self.patterns_binary_matrix_dict = {} - label_list = set(self.genes_labels["labels"]) - for label in label_list: - gene_list = list( - self.genes_labels[self.genes_labels["labels"] == label]["gene_id"] - ) - total_count = np.zeros(get_exp_array(self.adata, gene_list[0]).shape) - total_coo_list = [] - vote_array = np.zeros(get_exp_array(self.adata, gene_list[0]).shape) - for gene in gene_list: - exp_matrix = get_exp_array(self.adata, gene) - # calculate nonzero index - non_zero_coo_list = np.vstack((np.nonzero(exp_matrix))).T.tolist() - for coo in non_zero_coo_list: - total_coo_list.append(tuple(coo)) - total_count = scale_array(exp_matrix, total_count) - count_dict = Counter(total_coo_list) - for ele, count in count_dict.items(): - if int(count) / len(gene_list) >= vote_rate: - vote_array[ele] = 1 - total_count = total_count * vote_array - binary_arr = np.where(total_count != 0, 1, total_count) - self.patterns_matrix_dict[label] = total_count - self.patterns_binary_matrix_dict[label] = binary_arr + def get_pattern_array(self, vote_rate: int = 0, mode: str = "vote"): + if mode == "vote": + self.patterns_binary_matrix_dict = {} + label_list = set(self.genes_labels["labels"]) + for label in label_list: + gene_list = list( + self.genes_labels[self.genes_labels["labels"] == label]["gene_id"] + ) + total_count = np.zeros(get_exp_array(self.adata, gene_list[0]).shape) + total_coo_list = [] + vote_array = np.zeros(get_exp_array(self.adata, gene_list[0]).shape) + for gene in gene_list: + exp_matrix = get_exp_array(self.adata, gene) + # calculate nonzero index + non_zero_coo_list = np.vstack((np.nonzero(exp_matrix))).T.tolist() + for coo in non_zero_coo_list: + total_coo_list.append(tuple(coo)) + total_count = scale_array(exp_matrix, total_count) + count_dict = Counter(total_coo_list) + for ele, count in count_dict.items(): + if int(count) / len(gene_list) >= vote_rate: + vote_array[ele] = 1 + total_count = total_count * vote_array + binary_arr = np.where(total_count != 0, 1, total_count) + self.patterns_matrix_dict[label] = total_count + self.patterns_binary_matrix_dict[label] = binary_arr + elif mode == "test": + p_value_threshold = 0.05 + # TODO: rewrite test mode, improve run time. + + pass + else: + raise ValueError("mode should be vote or test") def cluster_gene( - self, - n_clusters, - mds_components=20, - use_highly_variable_gene=False, - n_top_genes=500, - ): + self, + n_clusters: int, + mds_components=20, + use_highly_variable_gene=False, + n_top_genes=500): # TODO: genes_labels should be int not float if use_highly_variable_gene: df = pd.DataFrame(self.genes_distance_array.mean(axis=1), columns=["mean"])