forked from brucechou1983/CheXNet-Keras
-
Notifications
You must be signed in to change notification settings - Fork 0
/
weights.py
30 lines (24 loc) · 1013 Bytes
/
weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import numpy as np
def get_class_weights(total_counts, class_positive_counts, multiply):
    """
    Calculate the per-class binary class_weight dicts used in training.

    For a class with ``pos`` positive samples out of ``total_counts``
    (so ``neg = total_counts - pos`` negative samples) the weights are::

        denominator = neg * multiply + pos
        weight[0] = pos / denominator                  # negative label
        weight[1] = (denominator - pos) / denominator  # positive label

    The two weights sum to 1 and their ratio weight[1] / weight[0] equals
    ``multiply * neg / pos``. Rarer positives therefore get a larger weight,
    and ``multiply`` scales that emphasis further.

    Arguments:
    total_counts - int, total number of samples
    class_positive_counts - dict of int, ex: {"Effusion": 300, "Infiltration": 500 ...}
    multiply - int, positive weighting multiplier

    Returns:
    class_weights - list of dict, one entry per class in the iteration order
        of class_positive_counts, ex: [{0: 0.01, 1: 0.99}, ...]
        NOTE(review): callers presumably rely on this order matching the
        model's outputs; confirm against the training script.

    Raises:
    ValueError - if a class's weighting denominator is 0 (e.g. total_counts
        is 0). Otherwise the weights would be undefined (NaN).
    """
    def get_single_class_weight(class_name, pos_counts):
        denominator = (total_counts - pos_counts) * multiply + pos_counts
        if denominator == 0:
            # Without this guard, numpy scalars would yield NaN weights (only a
            # RuntimeWarning) and plain ints would raise ZeroDivisionError;
            # fail loudly with context instead.
            raise ValueError(
                "Cannot compute class weight for {!r}: denominator is 0 "
                "(total_counts={}, positive count={}, multiply={})".format(
                    class_name, total_counts, pos_counts, multiply))
        return {
            0: pos_counts / denominator,
            1: (denominator - pos_counts) / denominator,
        }

    return [
        get_single_class_weight(class_name, pos_counts)
        for class_name, pos_counts in class_positive_counts.items()
    ]