diff --git a/irescue/count.py b/irescue/count.py
index ba8f3ed..de32e77 100644
--- a/irescue/count.py
+++ b/irescue/count.py
@@ -82,7 +82,8 @@ def parse_maps(maps_file, feature_index):
             eqcl = [EquivalenceClass(i, umi, feat, count)]
     yield it, eqcl
 
-def compute_cell_counts(equivalence_classes, features_index, dumpEC):
+def compute_cell_counts(equivalence_classes, features_index, max_iters,
+                        tolerance, dumpEC):
     """
     Calculate TE counts of a single cell, given a list of equivalence classes.
 
@@ -204,6 +205,8 @@ def compute_cell_counts(equivalence_classes, features_index, dumpEC):
                 for x in path_:
                     # add parent's UMI sequence and dedup features
                     dump[x] += (dump[parent_][0], features[i])
+    # EM stats placeholder in case of no multimapped UMIs
+    em_stats = (None, None)
     if em_array:
         # optimize the assignment of UMI from multimapping reads
         em_array = np.array(em_array)
@@ -213,12 +216,16 @@ def compute_cell_counts(equivalence_classes, features_index, dumpEC):
         todel = np.argwhere(np.all(em_array[..., :] == 0, axis=0))
         em_array = np.delete(em_array, todel, axis=1)
         # run EM
-        em_counts = run_em(em_array, cycles=100)
+        em_counts, em_stats = run_em(
+            em_array,
+            cycles=max_iters,
+            tolerance=tolerance
+        )
         em_counts = [x*em_array.shape[0] for x in em_counts]
         for i, c in zip(tokeep, em_counts):
             if c > 0:
                 counts[i] += c
-    return dict(counts), dump
+    return dict(counts), dump, em_stats
 
 def split_barcodes(barcodes_file, n):
     """
@@ -232,8 +239,8 @@ def split_barcodes(barcodes_file, n):
     for i, chunk in enumerate(get_ranges(nBarcodes, n)):
         yield i, {next(f).strip(): x+1 for x in chunk}
 
-def run_count(maps_file, features_index, tmpdir, dumpEC, verbose,
-              barcodes_set):
+def run_count(maps_file, features_index, tmpdir, dumpEC, max_iters, tolerance,
+              verbose, barcodes_set):
     # NB: keep args order consistent with main.countFun
     taskn, barcodes = barcodes_set
     matrix_file = os.path.join(tmpdir, f'{taskn}_matrix.mtx.gz')
@@ -250,14 +257,16 @@ def run_count(maps_file, features_index, tmpdir, dumpEC, verbose,
             f'{cellidx} ({cellbarcode.decode()})',
             level=2, send=verbose
         )
-        cellcounts, dump = compute_cell_counts(
+        cellcounts, dump, em_stats = compute_cell_counts(
             equivalence_classes=cellmaps,
             features_index=features_index,
+            max_iters=max_iters,
+            tolerance=tolerance,
             dumpEC=dumpEC
         )
         writerr(
-            f'[{taskn}] Write count for cell '
-            f'{cellidx} ({cellbarcode.decode()})',
+            f'[{taskn}] Write cell {cellidx} ({cellbarcode.decode()}). '
+            f'EM cycles: {em_stats[0]}. Converged: {em_stats[1]}.',
             level=1, send=verbose
         )
         # round counts to 3rd decimal point and write to matrix file
diff --git a/irescue/em.py b/irescue/em.py
index 39443df..7aa4c7d 100644
--- a/irescue/em.py
+++ b/irescue/em.py
@@ -17,7 +17,15 @@ def m_step(matrix):
     counts = matrix.sum(axis=0) / matrix.sum()
     return(counts)
 
-def run_em(matrix, cycles=100):
+def log_likelihood(matrix, counts):
+    """
+    Compute the log-likelihood of the data.
+    """
+    likelihoods = (matrix * counts).sum(axis=1)
+    log_likelihood = np.sum(np.log(likelihoods + np.finfo(float).eps))
+    return log_likelihood
+
+def run_em(matrix, cycles=100, tolerance=1e-5):
     """
     Run Expectation-Maximization (EM) algorithm to redistribute read counts
     across a set of features.
@@ -28,22 +36,45 @@ def run_em(matrix, cycles=100):
         Reads-features compatibility matrix.
     cycles : int, optional
         Number of EM cycles.
+    tolerance : float, optional
+        Tolerance threshold of log-likelihood difference to infer convergence.
 
     Returns
     -------
     out : list
         Optimized relative feature abundances.
+    cycle : int
+        Number of EM cycles run.
+    converged : bool
+        Indicates if convergence was reached before the cycles threshold.
     """
-
     # calculate initial estimation of relative abundance.
     # (let the sum of counts of features be 1,
     # will be multiplied by the real UMI count later)
     nFeatures = matrix.shape[1]
     counts = np.array([1 / nFeatures] * nFeatures)
-    # run EM for n cycles
-    for _ in range(cycles):
+
+    # Initial log-likelihood
+    prev_loglik = log_likelihood(matrix, counts)
+
+    converged = False
+    curr_cycle = 0
+
+    # Run EM iterations
+    while curr_cycle < cycles:
+        curr_cycle += 1
         e_matrix = e_step(matrix=matrix, counts=counts)
         counts = m_step(matrix=e_matrix)
-    return(counts)
\ No newline at end of file
+
+        # Compute the new log-likelihood
+        loglik = log_likelihood(matrix, counts)
+
+        # Check for convergence
+        if np.abs(loglik - prev_loglik) < tolerance:
+            converged = True
+            break
+
+        prev_loglik = loglik
+
+    return counts, (curr_cycle, converged)
diff --git a/irescue/main.py b/irescue/main.py
index 3a6315a..a7ee712 100644
--- a/irescue/main.py
+++ b/irescue/main.py
@@ -55,6 +55,12 @@ def parseArguments():
                         help="Minimum overlap between read and TE"
                         " as a fraction of read's alignment"
                         " (i.e. 0.00 <= NUM <= 1.00) (Default: disabled).")
+    parser.add_argument('--max-iters', type=int, metavar='INT', default=100,
+                        help="Maximum number of EM iterations "
+                        "(Default: %(default)s).")
+    parser.add_argument('--tolerance', type=float, metavar='FLOAT',
+                        default=1e-4, help="Log-likelihood change below which "
+                        "convergence is assumed (Default: %(default)s).")
     parser.add_argument('--dump-ec', action='store_true',
                         help="Write a description log file of Equivalence "
                         "Classes.")
@@ -190,7 +196,7 @@ def main():
     # calculate TE counts
     countFun = partial(
         run_count, mappings_file, feature_index, dirs['tmp'],
-        args.dump_ec, args.verbose
+        args.dump_ec, args.max_iters, args.tolerance, args.verbose
    )
     if args.threads > 1:
         mtxFiles = pool.map(countFun, bc_per_thread)
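
For context on what the new stopping criterion computes: `log_likelihood` scores the current abundance vector against the read-feature compatibility matrix, and `run_em` now stops as soon as the improvement between cycles drops below `tolerance`. In symbols (this is just the expression coded in the patch, with M the compatibility matrix, theta the relative abundances, and epsilon the machine epsilon guarding against log(0)):

```latex
\ell(\theta) = \sum_{r} \log\!\left(\sum_{f} M_{rf}\,\theta_f + \varepsilon\right),
\qquad
\text{stop when } \left|\ell(\theta^{(t)}) - \ell(\theta^{(t-1)})\right| < \text{tolerance}.
```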
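
For anyone who wants to poke at the convergence behaviour outside the package, below is a minimal, self-contained sketch of the patched loop. `m_step`, `log_likelihood`, and the `while`/`break` logic mirror the diff; the body of `e_step` is not shown in this diff, so the posterior-reassignment version here is an assumption, and the toy matrix is made up for illustration.

```python
import numpy as np

def e_step(matrix, counts):
    # Assumed E-step: redistribute each read across its compatible
    # features in proportion to the current abundance estimates.
    weighted = matrix * counts
    return weighted / weighted.sum(axis=1, keepdims=True)

def m_step(matrix):
    # As in em.py: new abundances are column sums normalized to sum to 1.
    return matrix.sum(axis=0) / matrix.sum()

def log_likelihood(matrix, counts):
    # As in the patch: per-read likelihood mass over compatible features,
    # with machine epsilon guarding against log(0).
    likelihoods = (matrix * counts).sum(axis=1)
    return np.sum(np.log(likelihoods + np.finfo(float).eps))

def run_em(matrix, cycles=100, tolerance=1e-5):
    # Mirrors the patched run_em: iterate E/M steps until the
    # log-likelihood improvement falls below `tolerance` or the
    # cycle budget is exhausted.
    nFeatures = matrix.shape[1]
    counts = np.full(nFeatures, 1 / nFeatures)  # uniform start
    prev_loglik = log_likelihood(matrix, counts)
    converged = False
    curr_cycle = 0
    while curr_cycle < cycles:
        curr_cycle += 1
        counts = m_step(e_step(matrix, counts))
        loglik = log_likelihood(matrix, counts)
        if np.abs(loglik - prev_loglik) < tolerance:
            converged = True
            break
        prev_loglik = loglik
    return counts, (curr_cycle, converged)

# Toy case: three reads unique to feature 0, one read shared by both features.
m = np.array([[1., 0.], [1., 0.], [1., 0.], [1., 1.]])
counts, (n_cycles, converged) = run_em(m, tolerance=1e-4)
print(counts, n_cycles, converged)
# -> abundances near [1, 0], converged after a handful of cycles
```

At the CLI level these knobs surface as `--max-iters` (default 100) and `--tolerance` (default 1e-4), which `main.py` threads through `run_count` into `compute_cell_counts`.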