Fix zFPKM Calculation #210

Open
wants to merge 5 commits into base: remove-hardcoded-paths/merge-xomics
Changes from 3 commits
137 changes: 80 additions & 57 deletions main/como/rnaseq_gen.py
@@ -171,8 +171,10 @@ async def _build_matrix_results(
:param taxon: The NCBI Taxon ID
:return: A dataclass `ReadMatrixResults`
"""
gene_info = gene_info_migrations(gene_info)
conversion = await ensembl_to_gene_id_and_symbol(ids=matrix["ensembl_gene_id"].tolist(), taxon=taxon)
conversion["ensembl_gene_id"] = conversion["ensembl_gene_id"].str.split(",")
conversion = conversion.explode("ensembl_gene_id")
conversion.reset_index(inplace=True, drop=True)
matrix = matrix.merge(conversion, on="ensembl_gene_id", how="left")

# Only include Entrez and Ensembl Gene IDs that are present in `gene_info`
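> The split-and-explode step above is the heart of this part of the fix: a single conversion row can carry several comma-separated Ensembl IDs, which the old code merged as one unmatched string. A minimal sketch of the pattern, with invented IDs:

```python
import pandas as pd

# Invented IDs for illustration; one conversion row maps to two Ensembl genes.
conversion = pd.DataFrame(
    {"ensembl_gene_id": ["ENSG0001,ENSG0002", "ENSG0003"], "entrez_gene_id": [111, 222]}
)
conversion["ensembl_gene_id"] = conversion["ensembl_gene_id"].str.split(",")
conversion = conversion.explode("ensembl_gene_id").reset_index(drop=True)
# Each Ensembl ID now sits on its own row, so the left merge on
# "ensembl_gene_id" no longer silently drops multi-mapped genes.
```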
@@ -181,6 +183,7 @@ async def _build_matrix_results(
matrix = matrix.replace(to_replace="-", value=pd.NA).dropna()
matrix["entrez_gene_id"] = matrix["entrez_gene_id"].astype(int)

gene_info = gene_info_migrations(gene_info)
gene_info = gene_info.replace(to_replace="-", value=pd.NA).dropna()
gene_info["entrez_gene_id"] = gene_info["entrez_gene_id"].astype(int)

@@ -189,6 +192,7 @@ async def _build_matrix_results(
on=["entrez_gene_id", "ensembl_gene_id"],
how="inner",
)

gene_info = gene_info.merge(
counts_matrix[["entrez_gene_id", "ensembl_gene_id"]],
on=["entrez_gene_id", "ensembl_gene_id"],
@@ -272,7 +276,7 @@ def calculate_fpkm(metrics: NamedMetrics) -> NamedMetrics:
return metrics


def _zfpkm_calculation(col: pd.Series, kernel: KernelDensity, peak_parameters: tuple[float, float]) -> _ZFPKMResult:
def _zfpkm_calculation(row: pd.Series, kernel: KernelDensity, peak_parameters: tuple[float, float]) -> _ZFPKMResult:
"""Log2 Transformations.

Stabilize the variance in the data to make the distribution more symmetric; this is helpful for Gaussian fitting
@@ -316,12 +320,12 @@ def _zfpkm_calculation(col: pd.Series, kernel: KernelDensity, peak_parameters: t
a threshold for calling a gene as "expressed"
See https://doi.org/10.1186/1471-2164-14-778
"""
col_log2: npt.NDArray = np.log2(col + 1)
col_log2 = np.nan_to_num(col_log2, nan=0)
refit: KernelDensity = kernel.fit(col_log2.reshape(-1, 1)) # type: ignore
row_log2: npt.NDArray = np.log2(row + 1)
row_log2 = np.nan_to_num(row_log2, nan=0)
refit: KernelDensity = kernel.fit(row_log2.reshape(-1, 1)) # type: ignore

# kde: KernelDensity = KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(col_log2.reshape(-1, 1))
x_range = np.linspace(col_log2.min(), col_log2.max(), 1000)
# kde: KernelDensity = KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(row_log2.reshape(-1, 1))
x_range = np.linspace(row_log2.min(), row_log2.max(), 1000)
density = np.exp(refit.score_samples(x_range.reshape(-1, 1)))
peaks, _ = find_peaks(density, height=peak_parameters[0], distance=peak_parameters[1])
peak_positions = x_range[peaks]
@@ -333,9 +337,9 @@ def _zfpkm_calculation(col: pd.Series, kernel: KernelDensity, peak_parameters: t
if len(peaks) != 0:
mu = peak_positions.max()
max_fpkm = density[peaks[np.argmax(peak_positions)]]
u = col_log2[col_log2 > mu].mean()
u = row_log2[row_log2 > mu].mean()
stddev = (u - mu) * np.sqrt(np.pi / 2)
zfpkm = pd.Series((col_log2 - mu) / stddev, dtype=np.float32, name=col.name)
zfpkm = pd.Series((row_log2 - mu) / stddev, dtype=np.float32, name=row.name)

return _ZFPKMResult(zfpkm=zfpkm, density=Density(x_range, density), mu=mu, std_dev=stddev, max_fpkm=max_fpkm)
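> For readers following the diff, the full per-vector computation (log2 transform, Gaussian KDE fit, rightmost-peak location, and the half-normal standard-deviation estimate from Hart et al., doi:10.1186/1471-2164-14-778) can be sketched standalone. This is a simplified illustration, not the module's exact code; it skips the height/distance peak parameters and assumes at least one density peak exists:

```python
import numpy as np
import pandas as pd
from scipy.signal import find_peaks
from sklearn.neighbors import KernelDensity

def zfpkm_sketch(fpkm: pd.Series, bandwidth: float = 0.5) -> pd.Series:
    log2_fpkm = np.nan_to_num(np.log2(fpkm.to_numpy() + 1), nan=0.0)
    kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(log2_fpkm.reshape(-1, 1))
    x = np.linspace(log2_fpkm.min(), log2_fpkm.max(), 1000)
    density = np.exp(kde.score_samples(x.reshape(-1, 1)))
    peaks, _ = find_peaks(density)            # simplified: no height/distance filtering
    mu = x[peaks].max()                       # rightmost peak = mode of the "expressed" population
    u = log2_fpkm[log2_fpkm > mu].mean()      # mean of points to the right of the peak
    stddev = (u - mu) * np.sqrt(np.pi / 2)    # half-normal relationship from the paper
    return pd.Series((log2_fpkm - mu) / stddev, index=fpkm.index, name=fpkm.name)
```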

@@ -354,49 +358,51 @@ def zfpkm_transform(
)
update_every_percent /= 100

total = len(fpkm_df.columns)
update_per_step: int = int(np.ceil(total * update_every_percent))
cores = min(multiprocessing.cpu_count() - 2, total)
logger.debug(f"Processing {total:,} samples through zFPKM transform using {cores} cores")
total_rows = len(fpkm_df)
update_per_step: int = int(np.ceil(total_rows * update_every_percent))
cores = max(min(multiprocessing.cpu_count() - 2, total_rows), 1) # Get at least 1 core and at most cpu_count() - 2
logger.debug(
f"Will update every {update_per_step:,} steps as this is approximately "
f"{update_every_percent:.1%} of {total:,}"
f"zFPKM transforming {total_rows:,} gene(s) across {len(fpkm_df.columns)} sample(s) using {cores} cores"
)
logger.debug(f"Will update every {update_per_step:,} steps (~{update_every_percent:.1%} of {total_rows:,})")

with Pool(processes=cores) as pool:
kernel = KernelDensity(kernel="gaussian", bandwidth=bandwidth)
chunksize = int(math.ceil(total_rows / (4 * cores)))  # iteration is over rows, so size chunks by row count
partial_func = partial(_zfpkm_calculation, kernel=kernel, peak_parameters=peak_parameters)
chunk_time = time.time()
start_time = time.time()
log_padding = len(str(f"{total_rows:,}"))

log_padding = len(str(f"{total:,}"))
zfpkm_df = pd.DataFrame(data=0, index=fpkm_df.index, columns=fpkm_df.columns)
zfpkm_series: list[pd.Series | None] = [None] * total_rows
results: dict[str, _ZFPKMResult] = {}
result: _ZFPKMResult
for i, result in enumerate(
pool.imap(
partial_func,
(fpkm_df[col] for col in fpkm_df.columns),
(row for _, row in fpkm_df.iterrows()),
chunksize=chunksize,
)
):
key = str(result.zfpkm.name)
results[key] = result
zfpkm_df[key] = result.zfpkm
zfpkm_series[i] = result.zfpkm

# show updates every X% and at the end, but skip on first iteration
if i != 0 and (i % update_per_step == 0 or i == total):
if i != 0 and (i % update_per_step == 0 or i == total_rows - 1):
current_time = time.time()
chunk = current_time - chunk_time
total_time = current_time - start_time
formatted = f"{i:,}"
chunk_num = f"{i:,}"
logger.debug(
f"Processed {formatted:>{log_padding}} of {total:,} - "
f"Processed {chunk_num:>{log_padding}} of {total_rows:,} - "
f"chunk took {chunk:.1f} seconds - "
f"running for {total_time:.1f} seconds"
)
chunk_time = current_time

zfpkm_df = pd.concat(zfpkm_series, axis=1)

return results, zfpkm_df
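> The rewritten loop preallocates `zfpkm_series`, fills it by position, and concatenates once after the pool drains; `imap` yields results in input order, so index `i` lines up with row `i`. A minimal sketch of the same fan-out pattern, with a toy function and data:

```python
import math
from functools import partial
from multiprocessing import Pool

def scale(value: float, factor: float) -> float:
    return value * factor

if __name__ == "__main__":
    items = list(range(100))
    cores = 4
    chunksize = math.ceil(len(items) / (4 * cores))  # roughly four chunks per worker
    out: list[float | None] = [None] * len(items)
    with Pool(processes=cores) as pool:
        # imap preserves input order, so positional assignment is safe
        for i, result in enumerate(pool.imap(partial(scale, factor=2.0), items, chunksize=chunksize)):
            out[i] = result
```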


@@ -408,7 +414,7 @@ def zfpkm_plot(results, *, plot_xfloor: int = -4, subplot_titles: bool = True):
:param plot_xfloor: Lower limit for the x-axis.
:param subplot_titles: Whether to display facet titles (sample names).
"""
mega_df = pd.DataFrame(columns=["sample_name", "log2fpkm", "fpkm_density", "fitted_density_scaled"])
to_concat: list[pd.DataFrame] = []
for name, result in results.items():
stddev = result.std_dev
x = np.array(result.density.x)
@@ -419,15 +425,18 @@ def zfpkm_plot(results, *, plot_xfloor: int = -4, subplot_titles: bool = True):
max_fitted = fitted.max()
scale_fitted = fitted * (max_fpkm / max_fitted)

df = pd.DataFrame(
{
"sample_name": [name] * len(x),
"log2fpkm": x,
"fpkm_density": y,
"fitted_density_scaled": scale_fitted,
}
to_concat.append(
pd.DataFrame(
{
"sample_name": [name] * len(x),
"log2fpkm": x,
"fpkm_density": y,
"fitted_density_scaled": scale_fitted,
}
)
)
mega_df = pd.concat([mega_df, df], ignore_index=True)
mega_df = pd.concat(to_concat, ignore_index=True)
mega_df.columns = pd.Series(data=["sample_name", "log2fpkm", "fpkm_density", "fitted_density_scaled"])

mega_df = mega_df.melt(id_vars=["log2fpkm", "sample_name"], var_name="source", value_name="density")
subplot_titles = list(results.keys()) if subplot_titles else None
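> Replacing the per-iteration `pd.concat` with a collect-then-concatenate pattern avoids quadratic copying: each `pd.concat` inside a loop reallocates the full accumulated frame. A toy comparison, assuming nothing beyond pandas:

```python
import pandas as pd

frames: list[pd.DataFrame] = []
for i in range(3):
    frames.append(pd.DataFrame({"sample_name": [f"s{i}"], "log2fpkm": [float(i)]}))
combined = pd.concat(frames, ignore_index=True)  # one O(n) copy instead of n reallocations
```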
@@ -547,7 +556,7 @@ def tpm_quantile_filter(*, metrics: NamedMetrics, filtering_options: _FilteringO

# Only keep `entrez_gene_ids` that pass `min_genes`
metric.entrez_gene_ids = [gene for gene, keep in zip(entrez_ids, min_genes) if keep]
metric.gene_sizes = [gene for gene, keep in zip(gene_size, min_genes) if keep]
metric.gene_sizes = np.array([gene for gene, keep in zip(gene_size, min_genes) if keep])
metric.count_matrix = metric.count_matrix.iloc[min_genes, :]
metric.normalization_matrix = metrics[sample].normalization_matrix.iloc[min_genes, :]
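> One subtlety worth noting on the `gene_sizes` line: `np.array` does not consume a bare generator expression; it wraps it in a zero-dimensional object array, which is why the list comprehension brackets above are required. A quick demonstration:

```python
import numpy as np

gene_size = [100, 200, 300]
keep = [True, False, True]

bad = np.array(g for g, k in zip(gene_size, keep) if k)     # 0-d object array holding a generator
good = np.array([g for g, k in zip(gene_size, keep) if k])  # array([100, 300])
assert bad.shape == () and good.shape == (2,)
```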

@@ -564,15 +573,13 @@ def zfpkm_filter(*, metrics: NamedMetrics, filtering_options: _FilteringOptions,
min_sample_expression = filtering_options.replicate_ratio
high_confidence_sample_expression = filtering_options.high_replicate_ratio
cut_off = filtering_options.cut_off

if calcualte_fpkm:
metrics = calculate_fpkm(metrics)
metrics = calculate_fpkm(metrics) if calcualte_fpkm else metrics

metric: _StudyMetrics
for metric in metrics.values():
# if fpkm was not calculated, the normalization matrix will be empty; collect the count matrix instead
matrix = metric.count_matrix if metric.normalization_matrix.empty else metric.normalization_matrix
matrix = matrix[matrix.sum(axis=1) > 0]
matrix = matrix[matrix.sum(axis=1) > 0] # remove rows (genes) that have no counts

minimums = matrix == 0
results, zfpkm_df = zfpkm_transform(matrix)
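> Downstream of `zfpkm_transform`, expression calls are made against a cutoff; Hart et al. suggest zFPKM >= -3 separates expressed genes from background, while here the threshold comes from `filtering_options.cut_off`. A hypothetical, self-contained thresholding sketch (names and data invented):

```python
import numpy as np
import pandas as pd

# Hypothetical zFPKM matrix (genes x samples); the real one comes from zfpkm_transform above.
zfpkm_df = pd.DataFrame(np.random.default_rng(0).normal(-2, 2, size=(5, 4)))
expressed_mask = zfpkm_df >= -3                   # -3 is the cutoff suggested by Hart et al.
fraction_expressed = expressed_mask.mean(axis=1)  # per-gene fraction of samples called expressed
active_genes = fraction_expressed[fraction_expressed >= 0.5].index
```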
@@ -623,7 +630,6 @@ async def _save_rnaseq_tests(
rnaseq_matrix: pd.DataFrame,
metadata_df: pd.DataFrame,
gene_info_df: pd.DataFrame,
output_filepath: Path,
prep: RNAPrepMethod,
taxon: int,
replicate_ratio: float,
Expand All @@ -632,6 +638,8 @@ async def _save_rnaseq_tests(
high_batch_ratio: float,
technique: FilteringTechnique,
cut_off: int | float,
output_boolean_activity_filepath: Path,
output_zscore_normalization_filepath: Path,
):
"""Save the results of the RNA-Seq tests to a CSV file."""
filtering_options = _FilteringOptions(
@@ -651,19 +659,32 @@ async def _save_rnaseq_tests(
metrics = read_counts_results.metrics
entrez_gene_ids = read_counts_results.entrez_gene_ids

metrics = filter_counts(
metrics: NamedMetrics = filter_counts(
context_name=context_name,
metrics=metrics,
technique=technique,
filtering_options=filtering_options,
prep=prep,
)

merged_zscore_df = pd.DataFrame()
expressed_genes: list[str] = []
top_genes: list[str] = []
for metric in metrics.values():
expressed_genes.extend(metric.entrez_gene_ids)
top_genes.extend(metric.high_confidence_entrez_gene_ids)
if metric.normalization_matrix is not None:
merged_zscore_df = (
metric.normalization_matrix
if merged_zscore_df.empty
else pd.concat(
[merged_zscore_df, metric.normalization_matrix],
axis=1,
)
)
merged_zscore_df.index = pd.Series(entrez_gene_ids, name="entrez_gene_id")
merged_zscore_df.to_csv(output_zscore_normalization_filepath, index=True)
logger.success(f"Wrote z-score normalization matrix to {output_zscore_normalization_filepath}")

expression_frequency = pd.Series(expressed_genes).value_counts()
expression_df = pd.DataFrame(
@@ -687,11 +708,11 @@ async def _save_rnaseq_tests(
expressed_count = len(boolean_matrix[boolean_matrix["expressed"] == 1])
high_confidence_count = len(boolean_matrix[boolean_matrix["high"] == 1])

boolean_matrix.to_csv(output_filepath, index=False)
boolean_matrix.to_csv(output_boolean_activity_filepath, index=False)
logger.info(
f"{context_name} - Found {expressed_count} expressed and {high_confidence_count} confidently expressed genes"
)
logger.success(f"Wrote boolean matrix to {output_filepath}")
logger.success(f"Wrote boolean matrix to {output_boolean_activity_filepath}")


async def _create_metadata_df(path: Path) -> pd.DataFrame:
@@ -703,15 +724,15 @@ async def _create_metadata_df(path: Path) -> pd.DataFrame:
return pd.read_excel(path)


async def rnaseq_gen( # noqa: C901, allow complex function
async def rnaseq_gen(
context_name: str,
input_rnaseq_filepath: Path,
input_gene_info_filepath: Path,
output_rnaseq_filepath: Path,
prep: RNAPrepMethod,
taxon: int,
input_metadata_filepath: Path | None = None,
input_metadata_df: pd.DataFrame | None = None,
taxon_id: int,
output_boolean_activity_filepath: Path,
output_zscore_normalization_filepath: Path,
input_metadata_filepath_or_df: Path | pd.DataFrame,
replicate_ratio: float = 0.5,
high_replicate_ratio: float = 1.0,
batch_ratio: float = 0.5,
@@ -728,11 +749,11 @@ async def rnaseq_gen( # noqa: C901, allow complex function
:param context_name: The name of the context being processed
:param input_rnaseq_filepath: The filepath to the gene count matrix
:param input_gene_info_filepath: The filepath to the gene info file
:param output_rnaseq_filepath: The filepath to write the output gene count matrix
:param output_boolean_activity_filepath: The filepath to write the boolean gene activity matrix
:param output_zscore_normalization_filepath: The filepath to write the output z-score normalization matrix
:param prep: The preparation method
:param taxon: The NCBI Taxon ID
:param input_metadata_filepath: The filepath to the metadata file
:param input_metadata_df: The metadata dataframe
:param taxon_id: The NCBI Taxon ID
:param input_metadata_filepath_or_df: The filepath or dataframe containing metadata information
:param replicate_ratio: The percentage of replicates that a gene must
appear in to be marked as "active" in a batch/study
:param batch_ratio: The percentage of batches that a gene must appear in to be marked as "active"
@@ -744,9 +765,6 @@ async def rnaseq_gen( # noqa: C901, allow complex function
:param cutoff: The cutoff value to use for the provided filtering technique
:return: None
"""
if not input_metadata_df and not input_metadata_filepath:
raise ValueError("At least one of input_metadata_filepath or input_metadata_df must be provided")

technique = (
FilteringTechnique.from_string(str(technique).lower()) if isinstance(technique, (str, int)) else technique
)
@@ -782,20 +800,25 @@ async def rnaseq_gen( # noqa: C901, allow complex function
)

logger.debug(f"Starting '{context_name}'")
output_rnaseq_filepath.parent.mkdir(parents=True, exist_ok=True)

output_boolean_activity_filepath.parent.mkdir(parents=True, exist_ok=True)
metadata_df = (
input_metadata_filepath_or_df
if isinstance(input_metadata_filepath_or_df, pd.DataFrame)
else await _create_metadata_df(input_metadata_filepath_or_df)
)
await _save_rnaseq_tests(
context_name=context_name,
rnaseq_matrix=await _read_counts(input_rnaseq_filepath),
metadata_df=input_metadata_df or await _create_metadata_df(input_metadata_filepath),
metadata_df=metadata_df,
gene_info_df=pd.read_csv(input_gene_info_filepath),
output_filepath=output_rnaseq_filepath,
prep=prep,
taxon=taxon,
taxon=taxon_id,
replicate_ratio=replicate_ratio,
batch_ratio=batch_ratio,
high_replicate_ratio=high_replicate_ratio,
high_batch_ratio=high_batch_ratio,
technique=technique,
cut_off=cutoff,
output_boolean_activity_filepath=output_boolean_activity_filepath,
output_zscore_normalization_filepath=output_zscore_normalization_filepath,
)
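> With the split output paths and the merged metadata parameter, a call now looks roughly like the following. Every path, the context name, and the enum members are invented for illustration; check them against the actual `RNAPrepMethod` and `FilteringTechnique` definitions and the package's import layout:

```python
import asyncio
from pathlib import Path

# Assumed import layout; adjust to where these names actually live in como.
from como.rnaseq_gen import rnaseq_gen, RNAPrepMethod

asyncio.run(
    rnaseq_gen(
        context_name="naiveB",                                   # hypothetical context
        input_rnaseq_filepath=Path("data/naiveB_counts.csv"),
        input_gene_info_filepath=Path("data/gene_info.csv"),
        prep=RNAPrepMethod.TOTAL,                                # assumed member name
        taxon_id=9606,
        output_boolean_activity_filepath=Path("results/naiveB_activity.csv"),
        output_zscore_normalization_filepath=Path("results/naiveB_zscore.csv"),
        input_metadata_filepath_or_df=Path("data/metadata.xlsx"),
        technique="zfpkm",                                       # parsed via FilteringTechnique.from_string
    )
)
```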
22 changes: 11 additions & 11 deletions main/como/rnaseq_preprocess.py
@@ -578,12 +578,12 @@ async def rnaseq_preprocess(
output_gene_info_filepath: Path,
como_context_dir: Path | None = None,
input_matrix_filepath: Path | list[Path] | None = None,
output_trna_config_filepath: Path | None = None,
output_mrna_config_filepath: Path | None = None,
output_trna_metadata_filepath: Path | None = None,
output_mrna_metadata_filepath: Path | None = None,
output_trna_count_matrix_filepath: Path | None = None,
output_mrna_count_matrix_filepath: Path | None = None,
cache: bool = True,
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO",
log_level: LOG_LEVEL = "INFO",
log_location: str | TextIOWrapper = sys.stderr,
) -> None:
"""Preprocesses RNA-seq data for downstream analysis.
@@ -594,8 +594,8 @@ async def rnaseq_preprocess(
:param context_name: The context/cell type being processed
:param taxon: The NCBI taxonomy ID
:param output_gene_info_filepath: Path to the output gene information CSV file
:param output_trna_config_filepath: Path to the output tRNA config file (if in "create" mode)
:param output_mrna_config_filepath: Path to the output mRNA config file (if in "create" mode)
:param output_trna_metadata_filepath: Path to the output tRNA metadata file (if in "create" mode)
:param output_mrna_metadata_filepath: Path to the output mRNA metadata file (if in "create" mode)
:param output_trna_count_matrix_filepath: The path to write total RNA count matrices
:param output_mrna_count_matrix_filepath: The path to write messenger RNA count matrices
:param como_context_dir: If in "create" mode, the input path(s) to the COMO_input directory of the current context
@@ -616,11 +616,11 @@ async def rnaseq_preprocess(
output_gene_info_filepath = output_gene_info_filepath.resolve()
como_context_dir = como_context_dir.resolve()
input_matrix_filepath = [i.resolve() for i in _listify(input_matrix_filepath)] if input_matrix_filepath else None
output_trna_config_filepath = (
output_trna_config_filepath.resolve() if output_trna_config_filepath else output_trna_config_filepath
output_trna_metadata_filepath = (
output_trna_metadata_filepath.resolve() if output_trna_metadata_filepath else output_trna_metadata_filepath
)
output_mrna_config_filepath = (
output_mrna_config_filepath.resolve() if output_mrna_config_filepath else output_mrna_config_filepath
output_mrna_metadata_filepath = (
output_mrna_metadata_filepath.resolve() if output_mrna_metadata_filepath else output_mrna_metadata_filepath
)
output_trna_count_matrix_filepath = (
output_trna_count_matrix_filepath.resolve()
@@ -640,8 +640,8 @@ async def rnaseq_preprocess(
como_context_dir=como_context_dir,
input_matrix_filepath=input_matrix_filepath,
output_gene_info_filepath=output_gene_info_filepath,
output_trna_config_filepath=output_trna_config_filepath,
output_mrna_config_filepath=output_mrna_config_filepath,
output_trna_config_filepath=output_trna_metadata_filepath,
output_mrna_config_filepath=output_mrna_metadata_filepath,
output_trna_matrix_filepath=output_trna_count_matrix_filepath,
output_mrna_matrix_filepath=output_mrna_count_matrix_filepath,
cache=cache,