Skip to content

Commit

Permalink
RFCT Simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
luispedro committed Nov 2, 2023
1 parent 5baeaab commit 7525011
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 14 deletions.
1 change: 1 addition & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
Unreleased
* citation: Add citation subcommand
* SemiBin1: Introduce separate SemiBin1 command
* internal: Code simplification and refactor

Version 2.0.2 Oct 31 2023 by BigDataBiology
* multi_easy_bin: Fix multi_easy_bin with --write-pre-recluster (#128)
Expand Down
10 changes: 4 additions & 6 deletions SemiBin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,13 +986,11 @@ def training(logger, contig_fasta, num_process,
is_combined = False

if training_type == 'semi':
if mode == 'single':
for fafile in contig_fasta:
binned_lengths.append(
utils.compute_min_length(min_length, contig_fasta[0], ratio))
else:
for fafile in contig_fasta:
binned_lengths.append(
utils.compute_min_length(min_length, fafile, ratio))
utils.compute_min_length(min_length, fafile, ratio))
if mode == 'single':
break

model = train(
logger,
Expand Down
16 changes: 8 additions & 8 deletions SemiBin/self_supervised_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,30 @@ def train_self(logger, out : str, datapaths, data_splits, is_combined=True,
train_data = pd.read_csv(datapaths[0], index_col=0).values

if not is_combined:
train_data_input = train_data[:, 0:136]
else:
train_data_input = train_data
train_data = train_data[:, :136]

torch.set_num_threads(num_process)

logger.info('Training model...')

if not is_combined:
model = Semi_encoding_single(train_data_input.shape[1]).to(device)
model = Semi_encoding_single(train_data.shape[1])
else:
model = Semi_encoding_multiple(train_data_input.shape[1]).to(device)
model = Semi_encoding_multiple(train_data.shape[1])

model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

for epoch in tqdm(range(epoches)):
for data_index in range(len(datapaths)):
for data_index, (datapath, data_split_path) in enumerate(zip(datapaths, data_splits)):
if epoch == 0:
logger.debug(f'Reading training data for index {data_index}...')

data = pd.read_csv(datapaths[data_index], index_col=0)
data = pd.read_csv(datapath, index_col=0)
data.index = data.index.astype(str)
data_split = pd.read_csv(data_splits[data_index], index_col=0)
data_split = pd.read_csv(data_split_path, index_col=0)

if mode == 'several':
if data.shape[1] != 138 or data_split.shape[1] != 136:
Expand Down

0 comments on commit 7525011

Please sign in to comment.