Implement EM convergence with tolerance threshold of log-likelihood change #20

Merged · 3 commits · Sep 19, 2024
25 changes: 17 additions & 8 deletions irescue/count.py
@@ -82,7 +82,8 @@ def parse_maps(maps_file, feature_index):
eqcl = [EquivalenceClass(i, umi, feat, count)]
yield it, eqcl

def compute_cell_counts(equivalence_classes, features_index, dumpEC):
def compute_cell_counts(equivalence_classes, features_index, max_iters,
tolerance, dumpEC):
"""
Calculate TE counts of a single cell, given a list of equivalence classes.

@@ -204,6 +205,8 @@ def compute_cell_counts(equivalence_classes, features_index, dumpEC):
for x in path_:
# add parent's UMI sequence and dedup features
dump[x] += (dump[parent_][0], features[i])
# EM stats placeholder in case of no multimapped UMIs
em_stats = (None, None)
if em_array:
# optimize the assignment of UMI from multimapping reads
em_array = np.array(em_array)
@@ -213,12 +216,16 @@ def compute_cell_counts(equivalence_classes, features_index, dumpEC):
todel = np.argwhere(np.all(em_array[..., :] == 0, axis=0))
em_array = np.delete(em_array, todel, axis=1)
# run EM
em_counts = run_em(em_array, cycles=100)
em_counts, em_stats = run_em(
em_array,
cycles=max_iters,
tolerance=tolerance
)
em_counts = [x*em_array.shape[0] for x in em_counts]
for i, c in zip(tokeep, em_counts):
if c > 0:
counts[i] += c
return dict(counts), dump
return dict(counts), dump, em_stats

def split_barcodes(barcodes_file, n):
"""
@@ -232,8 +239,8 @@ def split_barcodes(barcodes_file, n):
for i, chunk in enumerate(get_ranges(nBarcodes, n)):
yield i, {next(f).strip(): x+1 for x in chunk}

def run_count(maps_file, features_index, tmpdir, dumpEC, verbose,
barcodes_set):
def run_count(maps_file, features_index, tmpdir, dumpEC, max_iters, tolerance,
verbose, barcodes_set):
# NB: keep args order consistent with main.countFun
taskn, barcodes = barcodes_set
matrix_file = os.path.join(tmpdir, f'{taskn}_matrix.mtx.gz')
@@ -250,14 +257,16 @@ def run_count(maps_file, features_index, tmpdir, dumpEC, verbose,
f'{cellidx} ({cellbarcode.decode()})',
level=2, send=verbose
)
cellcounts, dump = compute_cell_counts(
cellcounts, dump, em_stats = compute_cell_counts(
equivalence_classes=cellmaps,
features_index=features_index,
max_iters=max_iters,
tolerance=tolerance,
dumpEC=dumpEC
)
writerr(
f'[{taskn}] Write count for cell '
f'{cellidx} ({cellbarcode.decode()})',
f'[{taskn}] Write cell {cellidx} ({cellbarcode.decode()}). '
f'EM cycles: {em_stats[0]}. Converged: {em_stats[1]}.',
level=1, send=verbose
)
# round counts to 3rd decimal point and write to matrix file
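
Taken together with the placeholder above, the new third return value gives callers a uniform contract: em_stats is (None, None) when a cell has no multimapped UMIs and the EM step is skipped, otherwise the (cycles run, converged flag) pair reported by run_em. A minimal sketch of how a caller can interpret it; describe_em is a hypothetical helper for illustration, not part of the PR:

```python
def describe_em(em_stats):
    # (None, None) is the placeholder set when the cell has no
    # multimapped UMIs; otherwise run_em reports (cycles, converged).
    cycles, converged = em_stats
    if cycles is None:
        return "EM skipped (no multimapped UMIs)"
    return f"EM cycles: {cycles}. Converged: {converged}."

print(describe_em((None, None)))  # only uniquely mapped UMIs
print(describe_em((12, True)))    # converged after 12 cycles
print(describe_em((100, False)))  # hit the --max-iters cap
```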
41 changes: 36 additions & 5 deletions irescue/em.py
@@ -17,7 +17,15 @@ def m_step(matrix):
counts = matrix.sum(axis=0) / matrix.sum()
return(counts)

def run_em(matrix, cycles=100):
def log_likelihood(matrix, counts):
"""
Compute log-likelihood of data.
"""
likelihoods = (matrix * counts).sum(axis=1)
log_likelihood = np.sum(np.log(likelihoods + np.finfo(float).eps))
return log_likelihood

def run_em(matrix, cycles=100, tolerance=1e-5):
"""
Run Expectation-Maximization (EM) algorithm to redistribute read counts
across a set of features.
@@ -28,22 +36,45 @@ def run_em(matrix, cycles=100):
Reads-features compatibility matrix.
cycles : int, optional
Number of EM cycles.
tolerance : float, optional
Absolute change in log-likelihood below which convergence is assumed.

Returns
-------
out : list
Optimized relative feature abundances.
cycle : int
Number of EM cycles performed.
converged : bool
Whether convergence was reached before the cycle limit
(cycle and converged are returned together as a single tuple).
"""

# calculate initial estimation of relative abundance.
# (let the sum of counts of features be 1,
# will be multiplied by the real UMI count later)
nFeatures = matrix.shape[1]
counts = np.array([1 / nFeatures] * nFeatures)

# run EM for n cycles
for _ in range(cycles):
# Initial log-likelihood
prev_loglik = log_likelihood(matrix, counts)

converged = False
curr_cycle = 0

# Run EM iterations
while curr_cycle < cycles:
curr_cycle += 1
e_matrix = e_step(matrix=matrix, counts=counts)
counts = m_step(matrix=e_matrix)

return(counts)
# Compute the new log-likelihood
loglik = log_likelihood(matrix, counts)

# Check for convergence
if np.abs(loglik - prev_loglik) < tolerance:
converged = True
break

prev_loglik = loglik

return counts, (curr_cycle, converged)
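
The convergence loop is easy to demonstrate on a toy matrix. The sketch below replays the logic added to run_em end to end; m_step and log_likelihood match the diff, while e_step is not shown here, so its body (weight reads by current abundances, then normalize each row) is an assumption based on standard EM rather than the exact irescue implementation:

```python
import numpy as np

def e_step(matrix, counts):
    # Assumed E-step (not shown in this diff): weight each read's
    # compatible features by current abundances, normalize per read.
    e = matrix * counts
    return e / e.sum(axis=1, keepdims=True)

def m_step(matrix):
    # As in irescue/em.py: normalized column sums.
    return matrix.sum(axis=0) / matrix.sum()

def log_likelihood(matrix, counts):
    # As in the diff: sum of per-read log-likelihoods, with eps
    # guarding against log(0).
    return np.sum(np.log((matrix * counts).sum(axis=1) + np.finfo(float).eps))

# Toy compatibility matrix: reads 0 and 1 map uniquely, read 2 maps
# to both features.
matrix = np.array([[1., 0.],
                   [0., 1.],
                   [1., 1.]])

counts = np.full(matrix.shape[1], 1 / matrix.shape[1])
prev_loglik = log_likelihood(matrix, counts)
converged, cycle = False, 0
while cycle < 100:                        # cycles=100, tolerance=1e-5
    cycle += 1
    counts = m_step(e_step(matrix, counts))
    loglik = log_likelihood(matrix, counts)
    if np.abs(loglik - prev_loglik) < 1e-5:
        converged = True
        break
    prev_loglik = loglik

print(counts, cycle, converged)  # [0.5 0.5] 1 True: the symmetric toy
                                 # data is already at a fixed point
```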
8 changes: 7 additions & 1 deletion irescue/main.py
@@ -55,6 +55,12 @@ def parseArguments():
help="Minimum overlap between read and TE"
" as a fraction of read's alignment"
" (i.e. 0.00 <= NUM <= 1.00) (Default: disabled).")
parser.add_argument('--max-iters', type=int, metavar='INT', default=100,
help="Maximum number of EM iterations "
"(Default: %(default)s).")
parser.add_argument('--tolerance', type=float, metavar='FLOAT',
default=1e-4, help="Log-likelihood change below which "
"convergence is assumed (Default: %(default)s).")
parser.add_argument('--dump-ec', action='store_true',
help="Write a description log file of Equivalence "
"Classes.")
@@ -190,7 +196,7 @@ def main():
# calculate TE counts
countFun = partial(
run_count, mappings_file, feature_index, dirs['tmp'],
args.dump_ec, args.verbose
args.dump_ec, args.max_iters, args.tolerance, args.verbose
)
if args.threads > 1:
mtxFiles = pool.map(countFun, bc_per_thread)
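
Because partial() binds its arguments positionally, the "keep args order consistent with main.countFun" note in count.py is load-bearing: max_iters and tolerance must sit between dumpEC and verbose in both the partial() call and run_count's signature, with barcodes_set supplied last by pool.map. This is also where the parsed --max-iters and --tolerance values flow into the counting step. A simplified sketch of that wiring, with a stub run_count and dummy values in place of the real parsed arguments:

```python
from functools import partial

# Stub with the same parameter order as irescue.count.run_count.
def run_count(maps_file, features_index, tmpdir, dumpEC, max_iters,
              tolerance, verbose, barcodes_set):
    taskn, barcodes = barcodes_set
    return f"task {taskn}: EM capped at {max_iters} iters, tol={tolerance}"

# Bind everything except barcodes_set, exactly as main() does.
countFun = partial(run_count, "mappings.tsv.gz", {}, "tmp/",
                   False, 100, 1e-4, True)

# pool.map supplies each (task number, barcode dict) pair as the last arg.
print(countFun((0, {b"AAACCTGA": 1})))
```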