Skip to content

Commit

Permalink
Update SPFinder.py
Browse files Browse the repository at this point in the history
  • Loading branch information
PSSUN committed Dec 16, 2024
1 parent d2caeff commit c125107
Showing 1 changed file with 84 additions and 56 deletions.
140 changes: 84 additions & 56 deletions STMiner/SPFinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ def compare_gene_to_genes(self, gene_name):
return compare_gmm_distance(gene_gmm, self.patterns)

def get_genes_csr_array(
self,
min_cells,
normalize=True,
exclude_highly_expressed=False,
log1p=False,
vmax=99,
gene_list=None,
self,
min_cells: int,
normalize: bool = True,
exclude_highly_expressed: bool = False,
log1p: bool = False,
vmax: int = 99,
gene_list: list = None,
):
error_gene_list = []
self.csr_dict = {}
Expand Down Expand Up @@ -125,7 +125,7 @@ def get_genes_csr_array(
print("Error when parse gene " + gene + "\nError: ")
print(e)

def spatial_high_variable_genes(self, vmax=99, thread=1):
def spatial_high_variable_genes(self, vmax: int = 99, thread: int = 1):
"""
Compute the optimal transport (OT) distance matrix for high variable genes.
Expand Down Expand Up @@ -154,7 +154,7 @@ def spatial_high_variable_genes(self, vmax=99, thread=1):
if (not isinstance(thread, int)) or (thread <= 1):
distance_dict = {}
for key in tqdm(
list(self.csr_dict.keys()), desc="Computing ot distances..."
list(self.csr_dict.keys()), desc="Computing ot distances..."
):
try:
distance_dict[key] = calculate_ot_distance(
Expand Down Expand Up @@ -191,15 +191,15 @@ def _mpl_worker(self, global_matrix, key, result_dict):
result_dict[key] = res

def fit_pattern(
self,
n_top_genes=-1,
n_comp=20,
normalize=True,
exclude_highly_expressed=False,
log1p=False,
min_cells=20,
gene_list=None,
remove_low_exp_spots=False,
self,
n_top_genes: int = -1,
n_comp: int = 20,
normalize: bool = True,
exclude_highly_expressed: bool = False,
log1p: bool = False,
min_cells: int = 20,
gene_list: list = None,
remove_low_exp_spots: bool = False,
):
"""
Given a distance matrix with the distances between each pair of objects in a set, and a chosen number of
Expand Down Expand Up @@ -247,7 +247,12 @@ def fit_pattern(
print(e)

def preprocess(
self, normalize, exclude_highly_expressed, log1p, min_cells, n_top_genes=2000
self,
normalize: bool,
exclude_highly_expressed: bool,
log1p: bool,
min_cells: int,
n_top_genes: int = 2000
):
sc.pp.filter_genes(self.adata, min_cells=min_cells)
sc.pp.highly_variable_genes(
Expand All @@ -264,54 +269,77 @@ def preprocess(
sc.pp.log1p(self.adata)

def build_distance_array(self, method="gmm", gene_list=None):
"""
Build a distance array for genes based on the specified method.
This function supports four distance calculation methods: "gmm" (Gaussian Mixture Model),
"mse" (Mean Squared Error), "cs" (Cosine Similarity), and "ot" (Optimal Transport).
If no gene list is provided, all genes are used.
Parameters:
- method (str): The distance calculation method to use. Default is "gmm".
- gene_list (list): A list of specific genes to use. If not provided, all genes are used.
Returns:
No direct return value, but updates the `self.genes_distance_array` attribute with the calculated distances.
"""
# If no gene list is provided, use all genes
if gene_list is None:
gene_list = list(self.adata.var.index)

# Build the distance array based on the specified method
if method == "gmm":
self.genes_distance_array = build_gmm_distance_array(self.patterns)
elif method == "mse":
self.genes_distance_array = build_mse_distance_array(self.adata, gene_list)
elif method == "cs":
self.genes_distance_array = build_cosine_similarity_array(
self.adata, gene_list
)
self.genes_distance_array = build_cosine_similarity_array(self.adata, gene_list)
elif method == "ot":
self.genes_distance_array = build_ot_distance_array(
self.csr_dict, gene_list
)
self.genes_distance_array = build_ot_distance_array(self.csr_dict, gene_list)
else:
# Raise an error if the method is unknown
raise ValueError("Unknown method, method should be one of gmm, mse, cs, ot")

def get_pattern_array(self, vote_rate=0):
self.patterns_binary_matrix_dict = {}
label_list = set(self.genes_labels["labels"])
for label in label_list:
gene_list = list(
self.genes_labels[self.genes_labels["labels"] == label]["gene_id"]
)
total_count = np.zeros(get_exp_array(self.adata, gene_list[0]).shape)
total_coo_list = []
vote_array = np.zeros(get_exp_array(self.adata, gene_list[0]).shape)
for gene in gene_list:
exp_matrix = get_exp_array(self.adata, gene)
# calculate nonzero index
non_zero_coo_list = np.vstack((np.nonzero(exp_matrix))).T.tolist()
for coo in non_zero_coo_list:
total_coo_list.append(tuple(coo))
total_count = scale_array(exp_matrix, total_count)
count_dict = Counter(total_coo_list)
for ele, count in count_dict.items():
if int(count) / len(gene_list) >= vote_rate:
vote_array[ele] = 1
total_count = total_count * vote_array
binary_arr = np.where(total_count != 0, 1, total_count)
self.patterns_matrix_dict[label] = total_count
self.patterns_binary_matrix_dict[label] = binary_arr
def get_pattern_array(self, vote_rate: int = 0, mode: str = "vote"):
if mode == "vote":
self.patterns_binary_matrix_dict = {}
label_list = set(self.genes_labels["labels"])
for label in label_list:
gene_list = list(
self.genes_labels[self.genes_labels["labels"] == label]["gene_id"]
)
total_count = np.zeros(get_exp_array(self.adata, gene_list[0]).shape)
total_coo_list = []
vote_array = np.zeros(get_exp_array(self.adata, gene_list[0]).shape)
for gene in gene_list:
exp_matrix = get_exp_array(self.adata, gene)
# calculate nonzero index
non_zero_coo_list = np.vstack((np.nonzero(exp_matrix))).T.tolist()
for coo in non_zero_coo_list:
total_coo_list.append(tuple(coo))
total_count = scale_array(exp_matrix, total_count)
count_dict = Counter(total_coo_list)
for ele, count in count_dict.items():
if int(count) / len(gene_list) >= vote_rate:
vote_array[ele] = 1
total_count = total_count * vote_array
binary_arr = np.where(total_count != 0, 1, total_count)
self.patterns_matrix_dict[label] = total_count
self.patterns_binary_matrix_dict[label] = binary_arr
elif mode == "test":
p_value_threshold = 0.05
# TODO: rewrite test mode, improve run time.

pass
else:
raise ValueError("mode should be vote or test")

def cluster_gene(
self,
n_clusters,
mds_components=20,
use_highly_variable_gene=False,
n_top_genes=500,
):
self,
n_clusters: int,
mds_components=20,
use_highly_variable_gene=False,
n_top_genes=500):
# TODO: genes_labels should be int not float
if use_highly_variable_gene:
df = pd.DataFrame(self.genes_distance_array.mean(axis=1), columns=["mean"])
Expand Down

0 comments on commit c125107

Please sign in to comment.