diff --git a/SemiBin/semi_supervised_model.py b/SemiBin/semi_supervised_model.py index feca9b2..4543bb9 100644 --- a/SemiBin/semi_supervised_model.py +++ b/SemiBin/semi_supervised_model.py @@ -263,7 +263,8 @@ def train(out, contig_fastas, binned_lengths, logger, datas, data_splits, cannot dataset=dataset_unlabeled, batch_size=batchsize, shuffle=True, - num_workers=0) + num_workers=0, + drop_last=True) for train_input1, train_input2, train_label in train_loader: model.train()