From 2b931ff25977f9174a4111820bb9e12de83d7dfe Mon Sep 17 00:00:00 2001 From: psj1997 <1176894699@qq.com> Date: Thu, 18 Jan 2024 12:46:46 +0800 Subject: [PATCH] Norm when abundance values are large --- SemiBin/cluster.py | 3 ++- SemiBin/long_read_cluster.py | 3 ++- SemiBin/self_supervised_model.py | 4 ++-- SemiBin/semi_supervised_model.py | 3 ++- SemiBin/utils.py | 13 +++++++++++++ 5 files changed, 21 insertions(+), 5 deletions(-) diff --git a/SemiBin/cluster.py b/SemiBin/cluster.py index 144be9c..791fca2 100644 --- a/SemiBin/cluster.py +++ b/SemiBin/cluster.py @@ -96,11 +96,12 @@ def run_embed_infomap(logger, model, data, * , from scipy import sparse import torch import numpy as np + from .utils import norm_abundance train_data_input = data.values[:, 0:136] if not is_combined else data.values if is_combined: - if train_data_input.shape[1] - 136 > 20: + if norm_abundance(train_data_input): train_data_kmer = train_data_input[:, 0:136] train_data_depth = train_data_input[:, 136:len(data.values[0])] from sklearn.preprocessing import normalize diff --git a/SemiBin/long_read_cluster.py b/SemiBin/long_read_cluster.py index a5de1ea..adafe85 100644 --- a/SemiBin/long_read_cluster.py +++ b/SemiBin/long_read_cluster.py @@ -50,12 +50,13 @@ def cluster_long_read(logger, model, data, device, is_combined, n_sample, out, contig_dict, *, binned_length, args, minfasta): import pandas as pd + from .utils import norm_abundance contig_list = data.index.tolist() if not is_combined: train_data_input = data.values[:, 0:136] else: train_data_input = data.values - if train_data_input.shape[1] - 136 > 20: + if norm_abundance(train_data_input): train_data_kmer = train_data_input[:, 0:136] train_data_depth = train_data_input[:, 136:len(data.values[0])] from sklearn.preprocessing import normalize diff --git a/SemiBin/self_supervised_model.py b/SemiBin/self_supervised_model.py index 2ced8b2..24441dd 100644 --- a/SemiBin/self_supervised_model.py +++ b/SemiBin/self_supervised_model.py @@ -4,6 +4,7 
@@ from torch.optim import lr_scheduler import sys from .semi_supervised_model import Semi_encoding_single, Semi_encoding_multiple, feature_Dataset +from .utils import norm_abundance def loss_function(embedding1, embedding2, label): relu = torch.nn.ReLU() @@ -70,7 +71,7 @@ def train_self(logger, out : str, datapaths, data_splits, is_combined=True, if not is_combined: train_data = train_data[:, :136] else: - if train_data.shape[1] - 136 > 20: + if norm_abundance(train_data): train_data_kmer = train_data[:, :136] train_data_depth = train_data[:, 136:] train_data_depth = normalize(train_data_depth, axis=1, norm='l1') @@ -82,7 +83,6 @@ def train_self(logger, out : str, datapaths, data_splits, is_combined=True, train_data_split = np.concatenate((train_data_split_kmer, train_data_split_depth), axis = 1) data_length = len(train_data) - # cannot link data is sampled randomly n_cannot_link = min(len(train_data_split) * 1000 // 2, 4_000_000) indices1 = np.random.choice(data_length, size=n_cannot_link) diff --git a/SemiBin/semi_supervised_model.py b/SemiBin/semi_supervised_model.py index d9dad56..415f0ae 100644 --- a/SemiBin/semi_supervised_model.py +++ b/SemiBin/semi_supervised_model.py @@ -151,6 +151,7 @@ def train(logger, out, contig_fastas, binned_lengths, datas, data_splits, cannot import pandas as pd from sklearn.preprocessing import normalize import numpy as np + from .utils import norm_abundance train_data = pd.read_csv(datas[0], index_col=0).values if not is_combined: @@ -215,7 +216,7 @@ def train(logger, out, contig_fastas, binned_lengths, datas, data_splits, cannot train_data_input = train_data[:, 0:136] train_data_split_input = train_data_must_link else: - if train_data.shape[1] - 136 > 20: + if norm_abundance(train_data): train_data_kmer = train_data[:, :136] train_data_depth = train_data[:, 136:] train_data_depth = normalize(train_data_depth, axis=1, norm='l1') diff --git a/SemiBin/utils.py b/SemiBin/utils.py index 4e99e31..89d05d2 100644 --- a/SemiBin/utils.py 
def norm_abundance(data):
    """Decide whether the columns after the first 136 should be normalized.

    Callers split ``data`` into ``data[:, :136]`` (k-mer features) and
    ``data[:, 136:]`` (depth/abundance features) and normalize the latter
    when this function returns ``True``.

    Parameters
    ----------
    data : 2-D array
        Feature matrix with one row per contig; only ``data.shape[1]`` and
        ``data[:, 136:]`` are read.

    Returns
    -------
    bool
        ``True`` when there are at least 20 columns past index 136, or
        when there are between 5 and 19 such columns and the mean of
        their per-row sums is greater than 2. ``False`` otherwise.
    """
    import numpy as np

    n_kmer_columns = 136
    n_extra = data.shape[1] - n_kmer_columns

    # Many extra columns: always normalize, no need to look at the values.
    if n_extra >= 20:
        return True

    # Too few extra columns: never normalize.
    if n_extra < 5:
        return False

    # In between: normalize only when the values are large on average.
    row_totals = np.sum(data[:, n_kmer_columns:], axis=1)
    return bool(np.mean(row_totals) > 2)