diff --git a/SemiBin/main.py b/SemiBin/main.py index 3f74e5b..4b0e5e2 100644 --- a/SemiBin/main.py +++ b/SemiBin/main.py @@ -989,17 +989,6 @@ def training(logger, contig_fasta, num_process, is_combined = False if training_mode == 'semi': - contig_fasta_unzip = [] - for fasta_index,temp_fasta in enumerate(contig_fasta): - if temp_fasta.endswith('.gz') or temp_fasta.endswith('.bz2') or temp_fasta.endswith('.xz'): - temp_fasta_unzip = os.path.join(output, 'unzip_contig_{}.fa'.format(fasta_index)) - with open(temp_fasta_unzip, 'wt') as out: - for h,seq in fasta_iter(temp_fasta): - out.write(f'>{h}\n{seq}\n') - contig_fasta_unzip.append(temp_fasta_unzip) - else: - contig_fasta_unzip.append(temp_fasta) - if mode == 'single': binned_lengths.append( utils.compute_min_length(min_length, contig_fasta[0], ratio)) @@ -1010,7 +999,7 @@ def training(logger, contig_fasta, num_process, model = train( output, - contig_fasta_unzip, + contig_fasta, binned_lengths, logger, data, diff --git a/SemiBin/utils.py b/SemiBin/utils.py index 6202bd5..b555223 100644 --- a/SemiBin/utils.py +++ b/SemiBin/utils.py @@ -383,6 +383,17 @@ def extract_seeds(vs, sel): counts = data.groupby('gene')['orf'].count() return extract_seeds(counts, data) +def maybe_uncompress(fafile, tdir): + if fafile.endswith('.gz') or \ + fafile.endswith('.bz2') or \ + fafile.endswith('.xz'): + oname = f'{tdir}/expanded.fa' + with open(oname, 'wt') as out: + for header, seq in fasta_iter(fafile): + out.write(f'>{header}\n{seq}\n') + return oname + return fafile + def cal_num_bins(fasta_path, binned_length, num_process, multi_mode=False, output = None, orf_finder = 'prodigal', prodigal_output_faa=None): '''Estimate number of bins from a FASTA file @@ -395,6 +406,7 @@ def cal_num_bins(fasta_path, binned_length, num_process, multi_mode=False, outpu ''' from .orffinding import run_orffinder with tempfile.TemporaryDirectory() as tdir: + fasta_path = maybe_uncompress(fasta_path, tdir) if output is not None: if os.path.exists(os.path.join(output, 'markers.hmmout')): return get_marker(os.path.join(output, 'markers.hmmout'), fasta_path, binned_length, multi_mode, orf_finder=orf_finder)