-
Notifications
You must be signed in to change notification settings - Fork 1
/
count_patches.py
67 lines (59 loc) · 2.74 KB
/
count_patches.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import numpy as np
from utils.utils import get_simple_loader
from datasets.dataset_generic import Generic_MIL_Dataset
import argparse
# Command-line interface: locations of the features, dataset CSV and coords,
# plus the number of classes the labels are mapped onto.
parser = argparse.ArgumentParser(description='Configurations for WSI Training')
parser.add_argument('--data_root_dir', type=str, default=None,
                    help='directory containing features folders')
parser.add_argument('--features_folder', type=str, default=None,
                    help='folder within data_root_dir containing the features - must contain pt_files/h5_files subfolder')
parser.add_argument('--csv_path', type=str, default=None,
                    help='path to dataset_csv file')
parser.add_argument('--coords_path', type=str, default=None,
                    help='path to coords folder')
parser.add_argument('--n_classes', type=int, default=2,
                    help='number of classes')
args = parser.parse_args()

# Validate with parser.error instead of assert: assert statements are removed
# under `python -O`, which would let an unsupported class count through and
# produce a patch total that silently ignores every label other than 0 and 1.
if args.n_classes != 2:
    parser.error("currently only implemented for binary classification")
def count_patches(dataset):
    """Count the patches of every item in ``dataset``, split by binary label.

    Calls ``dataset.load_from_h5(True)``, wraps the dataset with
    ``get_simple_loader`` and walks the loader once, printing the running
    per-class totals after each item.

    Args:
        dataset: object exposing ``load_from_h5`` and accepted by
            ``get_simple_loader``. Each loader item must unpack as
            ``(data, label, coords, ids)``.
            NOTE(review): the printed messages treat one loader item as one
            slide - confirm against ``get_simple_loader``.

    Returns:
        tuple: ``(patches, all_counts)``. ``patches`` is the summed patch
        count of the items whose label equals 0 or 1; ``all_counts`` holds the
        patch count of every item, in loader order, whatever its label.
    """
    dataset.load_from_h5(True)
    loader = get_simple_loader(dataset)
    patches0 = 0
    patches1 = 0
    print("slides: ", len(loader))
    # Per-slide [ids, count] pairs; only consumed by the optional CSV export
    # that is commented out below.
    patch_counts = []
    all_counts = []
    for batch_idx, (data, label, coords, ids) in enumerate(loader):
        # The number of coordinate entries is used as the patch count.
        count = len(coords)
        if label == 0:
            patches0 = patches0 + count
        elif label == 1:
            patches1 = patches1 + count
        # append keeps accumulation linear; `lst = lst + [x]` copied the whole
        # list on every iteration.
        patch_counts.append([ids, count])
        all_counts.append(count)
        print("number", batch_idx, " slide", ids, " class 0 patches: ", patches0, " class 1 patches: ", patches1)
    patches = patches0 + patches1
    # Optional per-slide export (pandas would need to be imported as pd):
    #pd.DataFrame(patch_counts,columns=["slide","patches"]).to_csv("results/patch_counts/ESGO_available_staging.csv",index=False)
    return patches, all_counts
# Driver: build the dataset from the parsed CLI arguments, count its patches
# and print summary statistics of the per-slide patch counts.
csv_path = args.csv_path
data_root_dir = args.data_root_dir
features_folder = args.features_folder
coords_path = args.coords_path
n_classes = args.n_classes

# Binary label mapping: 'high_grade' is class 0, every other listed subtype is
# class 1. Alternative mapping kept for reference:
#label_dict = {'invalid': 0,'effective': 1}
label_dict = {'high_grade': 0, 'low_grade': 1, 'clear_cell': 1, 'endometrioid': 1, 'mucinous': 1}

dataset = Generic_MIL_Dataset(
    csv_path=csv_path,
    data_dir=os.path.join(data_root_dir, features_folder),
    small_data_dir=None,
    coords_path=coords_path,
    small_coords_path=None,
    shuffle=False,
    seed=0,
    print_info=True,
    label_dict=label_dict,
    patient_strat=False,
    ignore=[],
    # presumably large enough that no slide is truncated - confirm against
    # Generic_MIL_Dataset
    max_patches_per_slide=100000000,
)
split_dataset = dataset

patches, all_counts = count_patches(split_dataset)
print("{} patches".format(patches))
if all_counts:
    print("min patches:", min(all_counts))
    print("max patches:", max(all_counts))
    print("mean patches:", np.mean(all_counts))
    print("sd patches:", np.std(all_counts))
else:
    # min()/max() raise ValueError on an empty sequence and np.mean/np.std
    # return nan with a warning, so report the empty case explicitly instead.
    print("no slides found: min/max/mean/sd patches are undefined")