diff --git a/experimental/shoshin/configs/celeb_a_resnet_config.py b/experimental/shoshin/configs/celeb_a_resnet_config.py
index 1d34b1fec..407548885 100644
--- a/experimental/shoshin/configs/celeb_a_resnet_config.py
+++ b/experimental/shoshin/configs/celeb_a_resnet_config.py
@@ -23,8 +23,13 @@ def get_config() -> ml_collections.ConfigDict:
   """Get mlp config."""
   config = base_config.get_config()
 
+  config.data.subgroup_ids = (
+      'Blond_Hair',
+  )  # Currently only a single subgroup attribute is supported.
+  config.data.subgroup_proportions = (0.01,)
+
   data = config.data
-  data.name = 'celeb_a'
+  data.name = 'local_celeb_a'
   data.num_classes = 2
 
   model = config.model
diff --git a/experimental/shoshin/data.py b/experimental/shoshin/data.py
index 59b4d68d2..de4426f81 100644
--- a/experimental/shoshin/data.py
+++ b/experimental/shoshin/data.py
@@ -534,10 +534,215 @@ def get_waterbirds_dataset(
       eval_ds=eval_datasets)
 
 
-@register_dataset('celeb_a')
+IMG_ALIGNED_DATA = ('https://drive.google.com/uc?export=download&'
+                    'id=0B7EVK8r0v71pZjFTYXZWM3FlRnM')
+EVAL_LIST = ('https://drive.google.com/uc?export=download&'
+             'id=0B7EVK8r0v71pY0NSMzRuSXJEVkk')
+# Landmark coordinates: left_eye, right_eye etc.
+LANDMARKS_DATA = ('https://drive.google.com/uc?export=download&'
+                  'id=0B7EVK8r0v71pd0FJY3Blby1HUTQ')
+
+# Attributes in the image (Eyeglasses, Mustache etc).
+ATTR_DATA = ('https://drive.google.com/uc?export=download&'
+             'id=0B7EVK8r0v71pblRyaVFSWGxPY0U')
+
+LANDMARK_HEADINGS = ('lefteye_x lefteye_y righteye_x righteye_y '
+                     'nose_x nose_y leftmouth_x leftmouth_y rightmouth_x '
+                     'rightmouth_y').split()
+ATTR_HEADINGS = (
+    '5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs '
+    'Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair '
+    'Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair '
+    'Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache '
+    'Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline '
+    'Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings '
+    'Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young'
+).split()
+
+_CITATION = """\
+@inproceedings{conf/iccv/LiuLWT15,
+  added-at = {2018-10-09T00:00:00.000+0200},
+  author = {Liu, Ziwei and Luo, Ping and Wang, Xiaogang and Tang, Xiaoou},
+  biburl = {https://www.bibsonomy.org/bibtex/250e4959be61db325d2f02c1d8cd7bfbb/dblp},
+  booktitle = {ICCV},
+  crossref = {conf/iccv/2015},
+  ee = {http://doi.ieeecomputersociety.org/10.1109/ICCV.2015.425},
+  interhash = {3f735aaa11957e73914bbe2ca9d5e702},
+  intrahash = {50e4959be61db325d2f02c1d8cd7bfbb},
+  isbn = {978-1-4673-8391-2},
+  keywords = {dblp},
+  pages = {3730-3738},
+  publisher = {IEEE Computer Society},
+  timestamp = {2018-10-11T11:43:28.000+0200},
+  title = {Deep Learning Face Attributes in the Wild.},
+  url = {http://dblp.uni-trier.de/db/conf/iccv/iccv2015.html#LiuLWT15},
+  year = 2015
+}
+"""
+
+
+class LocalCelebADataset(tfds.core.GeneratorBasedBuilder):
+  """CelebA (aligned and cropped) re-emitted with subgroup metadata.
+
+  Examples are read from the TFDS `celeb_a` dataset, subsampled according to
+  the first attribute in `subgroup_ids`, and re-emitted with subgroup
+  annotations. Only a single subgroup attribute is currently supported; any
+  further entries in `subgroup_ids` are ignored.
+  """
+
+  VERSION = tfds.core.Version('2.0.1')
+  SUPPORTED_VERSIONS = [
+      tfds.core.Version('2.0.0'),
+  ]
+  RELEASE_NOTES = {
+      '2.0.1': 'New split API (https://tensorflow.org/datasets/splits)',
+  }
+
+  def __init__(self,
+               subgroup_ids: List[str],
+               subgroup_proportions: Optional[List[float]] = None,
+               label_attr: Optional[str] = 'Male',
+               **kwargs):
+    """Initializes the builder.
+
+    Args:
+      subgroup_ids: CelebA attribute names defining the subgroups. Only the
+        first entry is used.
+      subgroup_proportions: Fraction of the nominal dataset size to sample from
+        each subgroup. Defaults to 1.0 for every subgroup.
+      label_attr: CelebA attribute used as the binary label.
+      **kwargs: Forwarded to the parent builder's constructor.
+
+    Raises:
+      ValueError: If `subgroup_ids` is empty, an attribute name is unknown, or
+        the first subgroup proportion is outside [0, 1].
+    """
+    if not subgroup_ids:
+      raise ValueError('`subgroup_ids` must name at least one attribute.')
+    for attr in (subgroup_ids[0], label_attr):
+      if attr not in ATTR_HEADINGS:
+        raise ValueError(f'Unknown CelebA attribute: {attr!r}.')
+    super().__init__(**kwargs)
+    self.subgroup_ids = subgroup_ids
+    self.label_attr = label_attr
+    if subgroup_proportions:
+      self.subgroup_proportions = subgroup_proportions
+    else:
+      self.subgroup_proportions = [1.] * len(subgroup_ids)
+    if not 0. <= self.subgroup_proportions[0] <= 1.:
+      raise ValueError('`subgroup_proportions[0]` must be in [0, 1], got '
+                       f'{self.subgroup_proportions[0]}.')
+
+  def _info(self):
+    """Returns the dataset metadata, including the emitted feature schema."""
+    return tfds.core.DatasetInfo(
+        builder=self,
+        features=tfds.features.FeaturesDict({
+            'example_id':
+                tfds.features.Text(),
+            'subgroup_id':
+                tfds.features.Text(),
+            'subgroup_label':
+                tfds.features.ClassLabel(num_classes=2),
+            'feature':
+                tfds.features.Image(
+                    shape=(RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 3),
+                    encoding_format='jpeg'),
+            'label':
+                tfds.features.ClassLabel(num_classes=2),
+            'image_filename':
+                tfds.features.Text(),
+        }),
+        supervised_keys=('feature', 'label', 'example_id'),
+        citation=_CITATION,
+    )
+
+  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
+    """Defines the train/validation/test splits.
+
+    Nothing is downloaded here: examples are re-read from the TFDS `celeb_a`
+    dataset, so `dl_manager` is unused.
+    """
+    del dl_manager  # Unused; the data comes from `tfds.load('celeb_a')`.
+    return {
+        'train': self._generate_examples('train', is_training=True),
+        'validation': self._generate_examples('validation'),
+        'test': self._generate_examples('test'),
+    }
+
+  def _generate_examples(self,
+                         split: str,
+                         is_training: Optional[bool] = False
+                        ) -> Iterator[Tuple[str, Dict[str, Any]]]:
+    """Yields `(key, example)` pairs for one split.
+
+    Args:
+      split: Name of the `celeb_a` split to read.
+      is_training: If true, the source dataset is read non-deterministically.
+
+    Yields:
+      Tuples of the source example's TFDS id and a feature dictionary with the
+      keys declared in `_info`.
+    """
+    subgroup_attr = self.subgroup_ids[0]
+    subgroup_proportion = self.subgroup_proportions[0]
+
+    read_config = tfds.ReadConfig()
+    read_config.add_tfds_id = True  # Set `True` to return the 'tfds_id' key
+
+    dataset = tfds.load('celeb_a', read_config=read_config, split=split)
+    if is_training:
+      options = tf.data.Options()
+      options.experimental_deterministic = False
+      dataset = dataset.with_options(options)
+
+    # Pre-computed dataset size or large number >= estimated dataset size.
+    # NOTE: the sample sizes below are fractions of this nominal size, not of
+    # the actual split size.
+    dataset_size = 300000
+    dataset = dataset.shuffle(dataset_size)
+
+    # Take up to `subgroup_proportion` of the nominal size from the subgroup,
+    # and up to the complementary share from the examples outside of it.
+    subgroup_dataset = dataset.filter(
+        lambda element: element['attributes'][subgroup_attr])
+    subgroup_dataset = subgroup_dataset.take(
+        int(dataset_size * subgroup_proportion))
+    # Use `tf.math.logical_not` rather than Python `not` so the predicate does
+    # not rely on AutoGraph to negate a tensor.
+    remaining_dataset = dataset.filter(
+        lambda element: tf.math.logical_not(
+            element['attributes'][subgroup_attr]))
+    remaining_dataset = remaining_dataset.take(
+        int(dataset_size * (1. - subgroup_proportion)))
+
+    dataset = subgroup_dataset.concatenate(remaining_dataset)
+    dataset = dataset.shuffle(dataset_size)
+
+    for example in dataset:
+      file_name = example['tfds_id'].numpy()
+      in_subgroup = bool(example['attributes'][subgroup_attr])
+      subgroup_id = subgroup_attr if in_subgroup else f'Not_{subgroup_attr}'
+      record = {
+          'example_id': file_name,
+          'subgroup_id': subgroup_id,
+          'subgroup_label': int(in_subgroup),
+          'feature': tf.image.resize(
+              example['image'], [RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE],
+              method='nearest').numpy(),
+          'label': int(bool(example['attributes'][self.label_attr])),
+          'image_filename': file_name,
+      }
+      yield file_name, record
+
+
+@register_dataset('local_celeb_a')
 def get_celeba_dataset(
-    num_splits: int, initial_sample_proportion: float,
-    subgroup_ids: List[str], subgroup_proportions: List[float],
+    num_splits: int,
+    initial_sample_proportion: float,
+    subgroup_ids: List[str],
+    subgroup_proportions: List[float],
+    is_training: Optional[bool] = True,
 ) -> Dataloader:
   """Returns datasets for training, validation, and possibly test sets.
 
@@ -548,54 +753,67 @@ def get_celeba_dataset(
     subgroup_ids: List of strings of IDs indicating subgroups.
     subgroup_proportions: List of floats indicating proportion that each
       subgroup should take in initial training dataset.
+    is_training: Passed as `as_supervised` to every `tfds.load` call. If true,
+      datasets yield `(feature, label, example_id)` tuples; otherwise they
+      yield full feature dictionaries.
 
   Returns:
     A tuple containing the split training data, split validation data, the
     combined training dataset, and a dictionary mapping evaluation dataset
     names to their respective combined datasets.
+
+  Raises:
+    ValueError: If `initial_sample_proportion / num_splits` rounds down to a
+      split size of 0%.
   """
-  del subgroup_proportions, subgroup_ids
   read_config = tfds.ReadConfig()
   read_config.add_tfds_id = True  # Set `True` to return the 'tfds_id' key
+
   split_size_in_pct = int(100 * initial_sample_proportion / num_splits)
   reduced_dataset_sz = int(100 * initial_sample_proportion)
+  if split_size_in_pct < 1:
+    raise ValueError(
+        'Each split must cover at least 1% of the data, but '
+        f'initial_sample_proportion={initial_sample_proportion} and '
+        f'num_splits={num_splits} give a split size of 0%.')
+  builder_kwargs = {
+      'subgroup_ids': subgroup_ids,
+      'subgroup_proportions': subgroup_proportions
+  }
 
   train_splits = tfds.load(
-      'celeb_a',
+      'local_celeb_a_dataset',
       read_config=read_config,
       split=[
-          f'train[:{k}%]+train[{k+split_size_in_pct}%:]'
+          f'train[{k}%:{k+split_size_in_pct}%]'
           for k in range(0, reduced_dataset_sz, split_size_in_pct)
       ],
+      builder_kwargs=builder_kwargs,
       data_dir=DATA_DIR,
       try_gcs=False,
-      as_supervised=True
+      as_supervised=is_training,
   )
   val_splits = tfds.load(
-      'celeb_a',
+      'local_celeb_a_dataset',
       read_config=read_config,
       split=[
           f'validation[{k}%:{k+split_size_in_pct}%]'
           for k in range(0, reduced_dataset_sz, split_size_in_pct)
       ],
+      builder_kwargs=builder_kwargs,
       data_dir=DATA_DIR,
+      as_supervised=is_training,
       try_gcs=False,
-      as_supervised=True
   )
-  train_sample = tfds.load(
-      'celeb_a',
-      split='train_sample',
-      data_dir=DATA_DIR,
-      try_gcs=False,
-      as_supervised=True,
-      with_info=False)
   test_ds = tfds.load(
-      'celeb_a',
+      'local_celeb_a_dataset',
       split='test',
+      builder_kwargs=builder_kwargs,
       data_dir=DATA_DIR,
       try_gcs=False,
-      as_supervised=True,
-      with_info=False)
+      as_supervised=is_training,
+      with_info=False,
+  )
 
   train_ds = gather_data_splits(list(range(num_splits)), train_splits)
   val_ds = gather_data_splits(list(range(num_splits)), val_splits)
@@ -607,5 +825,4 @@ def get_celeba_dataset(
       train_splits,
       val_splits,
       train_ds,
-      train_sample_ds=train_sample,
       eval_ds=eval_datasets)
diff --git a/experimental/shoshin/sampling_policies.py b/experimental/shoshin/sampling_policies.py
index a0256074d..521f3278d 100644
--- a/experimental/shoshin/sampling_policies.py
+++ b/experimental/shoshin/sampling_policies.py
@@ -56,6 +56,11 @@ def compute_ids_to_sample(
   elif sampling_score == 'bias':
     sample_avg = predictions_df[prediction_bias_cols].mean(axis=1).to_numpy()
     predictions_df['sampling_score'] = 1 - sample_avg
+  elif sampling_score == 'random':
+    # A random permutation gives every row a distinct rank and, unlike
+    # `np.random.randint(1, n, n)`, is valid when there are fewer than 2 rows.
+    predictions_df['sampling_score'] = np.random.permutation(
+        predictions_df.shape[0])
   predictions_df = predictions_df.sort_values(
       by='sampling_score', ascending=True)
   return predictions_df.head(num_samples)['example_id'].to_numpy()