
Creating a task loader for list of data sets (similar to meta-data set) #333

Closed
brando90 opened this issue May 10, 2022 · 3 comments
brando90 commented May 10, 2022

I want a few-shot learning dataset that works similarly to Meta-Dataset (as a first step toward that), i.e., sample a dataset first, then create an n-way, k-shot task from it. Based on the following Slack discussion:

Is there any few-shot learning benchmark that is dataset-based (like Meta-Dataset) supported by learn2learn? E.g., Meta-Dataset samples a dataset first, then samples n-way classes from the selected dataset, and then creates the few-shot learning task. Is a benchmark like that supported by learn2learn?

The Slack discussion suggested creating an indexable dataset and a task transform that indexes it, then giving that to TaskDataset. I don't think that works, because the transforms require the dataset at creation time. Instead, what I did is create a single transform that dynamically samples the dataset and then builds the remaining task transforms with it.

I think it works, since the print statements display different n-way class indices and the image sizes look correct to me. Posting it here in case it's useful to someone else and, most importantly, so it can be corrected if it's wrong (since it doesn't follow what @seba-1511 initially suggested):

import random
from typing import Callable

import learn2learn as l2l
import numpy as np
import torch
from learn2learn.data import TaskDataset, MetaDataset, DataDescription
from learn2learn.data.transforms import TaskTransform
from torch.utils.data import Dataset


class IndexableDataSet(Dataset):

    def __init__(self, datasets):
        self.datasets = datasets

    def __len__(self) -> int:
        return len(self.datasets)

    def __getitem__(self, idx: int):
        return self.datasets[idx]


class SingleDatasetPerTaskTransform(Callable):
    """
    Transform that samples a data set first, then creates a task (e.g. n-way, k-shot) and finally
    applies the remaining task transforms.
    """

    def __init__(self, indexable_dataset: IndexableDataSet, cons_remaining_task_transforms: Callable):
        """

        :param: cons_remaining_task_transforms; constructor that builds the remaining task transforms. Cannot be a list
        of transforms because we don't know apriori which is the data set we will use. So this function should be of
        type MetaDataset -> list[TaskTransforms] i.e. given the dataset it returns the transforms for it.
        """
        self.indexable_dataset = MetaDataset(indexable_dataset)
        self.cons_remaining_task_transforms = cons_remaining_task_transforms

    def __call__(self, task_description: list):
        """
        idea:
        - receives the index of the dataset to use
        - then use the normal NWays l2l function
        """
        # - ideally this would go in a separate callable transform, but the l2l transforms take the dataset
        # a priori (not dynamically), so we sample it here instead.
        i = random.randint(0, len(self.indexable_dataset) - 1)
        task_description = [DataDescription(index=i)]  # using this to follow the l2l convention

        # - get the sampled data set
        dataset_index = task_description[0].index
        dataset = self.indexable_dataset[dataset_index]
        dataset = MetaDataset(dataset)

        # - use the sampled data set to create task
        remaining_task_transforms: list[TaskTransform] = self.cons_remaining_task_transforms(dataset)
        description = None
        for transform in remaining_task_transforms:
            description = transform(description)
        return description


def sample_dataset(dataset):
    def sample_random_dataset(x):
        i = random.randint(0, len(dataset) - 1)
        return [DataDescription(index=i)]

    return sample_random_dataset


def get_task_transforms(dataset: IndexableDataSet) -> list[TaskTransform]:
    """
    Returns the task transforms, with the dataset-sampling transform first.
    """
    transforms = [
        sample_dataset(dataset),
        l2l.data.transforms.NWays(dataset, n=5),
        l2l.data.transforms.KShots(dataset, k=5),
        l2l.data.transforms.LoadData(dataset),
        l2l.data.transforms.RemapLabels(dataset),
        l2l.data.transforms.ConsecutiveLabels(dataset),
    ]
    return transforms


def print_datasets(dataset_lst: list):
    for dataset in dataset_lst:
        print(f'\n{dataset=}\n')


def get_indexable_list_of_datasets_mi_and_cifarfs(root: str = '~/data/l2l_data/') -> IndexableDataSet:
    from learn2learn.vision.benchmarks import mini_imagenet_tasksets
    datasets, transforms = mini_imagenet_tasksets(root=root)
    mi = datasets[0].dataset

    from learn2learn.vision.benchmarks import cifarfs_tasksets
    datasets, transforms = cifarfs_tasksets(root=root)
    cifarfs = datasets[0].dataset

    dataset_list = [mi, cifarfs]

    dataset_list = [l2l.data.MetaDataset(dataset) for dataset in dataset_list]
    dataset = IndexableDataSet(dataset_list)
    return dataset


# -- tests

def loop_through_l2l_indexable_datasets_test():
    """
    Sanity check: sample a few tasks and print their shapes and labels.
    """
    # - for determinism
    random.seed(0)
    torch.manual_seed(0)
    np.random.seed(0)

    # - options for number of tasks/meta-batch size
    batch_size: int = 10

    # - create indexable data set
    indexable_dataset: IndexableDataSet = get_indexable_list_of_datasets_mi_and_cifarfs()

    # - get task transforms
    def get_remaining_transforms(dataset: MetaDataset) -> list[TaskTransform]:
        remaining_task_transforms = [
            l2l.data.transforms.NWays(dataset, n=5),
            l2l.data.transforms.KShots(dataset, k=5),
            l2l.data.transforms.LoadData(dataset),
            l2l.data.transforms.RemapLabels(dataset),
            l2l.data.transforms.ConsecutiveLabels(dataset),
        ]
        return remaining_task_transforms
    task_transforms: TaskTransform = SingleDatasetPerTaskTransform(indexable_dataset, get_remaining_transforms)

    # -
    taskset: TaskDataset = TaskDataset(dataset=indexable_dataset, task_transforms=task_transforms)

    # - loop through tasks
    for task_num in range(batch_size):
        print(f'{task_num=}')
        X, y = taskset.sample()
        print(f'{X.size()=}')
        print(f'{y.size()=}')
        print(f'{y=}')
        print()

    print('-- end of test --')


# -- Run experiment

if __name__ == "__main__":
    import time
    from uutils import report_times

    start = time.time()
    # - run experiment
    loop_through_l2l_indexable_datasets_test()
    # - Done
    print(f"\nSuccess Done!: {report_times(start)}\a")

output:

task_num=0
X.size()=torch.Size([25, 3, 32, 32])
y.size()=torch.Size([25])
y=tensor([0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 1, 1, 1, 1,
        1])
task_num=1
X.size()=torch.Size([25, 3, 32, 32])
y.size()=torch.Size([25])
y=tensor([4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 3, 3, 3, 3,
        3])
task_num=2
X.size()=torch.Size([25, 3, 84, 84])
y.size()=torch.Size([25])
y=tensor([4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 1, 1, 1, 1,
        1])
task_num=3
X.size()=torch.Size([25, 3, 84, 84])
y.size()=torch.Size([25])
y=tensor([1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 4, 4, 4, 4,
        4])
task_num=4
X.size()=torch.Size([25, 3, 84, 84])
y.size()=torch.Size([25])
y=tensor([1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 4, 4, 4, 4,
        4])
task_num=5
X.size()=torch.Size([25, 3, 32, 32])
y.size()=torch.Size([25])
y=tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 3, 3, 3, 3,
        3])
task_num=6
X.size()=torch.Size([25, 3, 32, 32])
y.size()=torch.Size([25])
y=tensor([3, 3, 3, 3, 3, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 4, 4, 4, 4,
        4])
task_num=7
X.size()=torch.Size([25, 3, 84, 84])
y.size()=torch.Size([25])
y=tensor([0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 1, 1, 1, 1,
        1])
task_num=8
X.size()=torch.Size([25, 3, 84, 84])
y.size()=torch.Size([25])
y=tensor([1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 3, 3, 3, 3,
        3])
task_num=9
X.size()=torch.Size([25, 3, 32, 32])
y.size()=torch.Size([25])
y=tensor([2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 1, 1, 1, 1,
        1])
-- end of test --
Success Done!: time passed: hours:0.030430123541090225, minutes=1.8258074124654133, seconds=109.5484447479248

Related: Meta-Dataset GitHub issue: #286
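For context, the dataset-then-task sampling pattern the code above implements can be sketched without learn2learn at all. The datasets and names below are hypothetical stand-ins, not the real l2l objects:

```python
import random

# Hypothetical stand-ins for the wrapped datasets: each maps class name -> samples.
mini_imagenet = {f'mi_class_{c}': [f'mi_img_{c}_{i}' for i in range(20)] for c in range(10)}
cifarfs = {f'cf_class_{c}': [f'cf_img_{c}_{i}' for i in range(20)] for c in range(10)}
datasets = [mini_imagenet, cifarfs]

def sample_task(datasets, n_ways=5, k_shots=5):
    """Meta-Dataset-style sampling: pick a dataset first, then an n-way k-shot task from it."""
    ds = random.choice(datasets)                   # step 1: sample a dataset
    classes = random.sample(sorted(ds), n_ways)    # step 2: sample n classes from it
    task = []
    for new_label, cls in enumerate(classes):      # step 3: k samples per class,
        for x in random.sample(ds[cls], k_shots):  # with remapped consecutive labels
            task.append((x, new_label))
    return task

task = sample_task(datasets)
print(len(task))  # 25 = 5 ways * 5 shots
```

This is just the sampling logic; the real code additionally wraps each dataset in MetaDataset so l2l can index samples by class efficiently.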


brando90 commented May 12, 2022

Cropping is done after padding. I don't know why, but it is; it seems weird to me.

https://pytorch.org/vision/main/generated/torchvision.transforms.RandomCrop.html

pad_if_needed (boolean) – It will pad the image if smaller than the desired size to avoid raising an exception. Since cropping is done after padding, the padding seems to be done at a random offset.
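To see why the padding ends up at a random offset, here is a minimal sketch of the pad-then-crop logic in plain torch (a simplified reimplementation for illustration, not torchvision's actual code):

```python
import random
import torch
import torch.nn.functional as F

def pad_then_random_crop(img: torch.Tensor, size: int) -> torch.Tensor:
    """Pad a (C, H, W) image up to at least `size` per side, then crop at a random offset."""
    _, h, w = img.shape
    pad_h, pad_w = max(size - h, 0), max(size - w, 0)
    img = F.pad(img, (pad_w, pad_w, pad_h, pad_h))  # (left, right, top, bottom)
    _, h, w = img.shape
    top, left = random.randint(0, h - size), random.randint(0, w - size)
    # The crop window lands anywhere in the padded image, so the original
    # content sits at a random position relative to the padding.
    return img[:, top:top + size, left:left + size]

out = pad_then_random_crop(torch.ones(3, 32, 32), 40)
print(out.shape)  # torch.Size([3, 40, 40])
```

Because the crop is sampled over the padded image, the padding is not symmetric in the output: whatever the random offset leaves over on each side is the "random offset" padding the docs describe.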

@seba-1511 (Member) commented

Closing: inactive.
