Skip to content

Commit

Permalink
Merge branch 'f/fix_vision_deps' into 'main'
Browse files Browse the repository at this point in the history
Fix dependency on computer vision datasets

See merge request es/ai/hannah/hannah!353
  • Loading branch information
cgerum committed Nov 15, 2023
2 parents 1dab7a2 + e75363e commit 3ca2320
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 42 deletions.
2 changes: 1 addition & 1 deletion hannah/conf/module/anomaly_detection.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: hannah.modules.AnomalyDetectionModule
_target_: hannah.modules.vision.AnomalyDetectionModule
num_workers: 0
batch_size: 128
shuffle_all_dataloaders: False
Expand Down
2 changes: 1 addition & 1 deletion hannah/conf/module/image_classifier.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
## See the License for the specific language governing permissions and
## limitations under the License.
##
_target_: hannah.modules.ImageClassifierModule
_target_: hannah.modules.vision.ImageClassifierModule
num_workers: 0
batch_size: 128
shuffle_all_dataloaders: False
43 changes: 17 additions & 26 deletions hannah/datasets/fake1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,44 +26,35 @@

from ..utils.utils import extract_from_download_cache, list_all_files
from .base import AbstractDataset, DatasetType
from .vision.base import TorchvisionDatasetBase


class Fake1dDataset(TorchvisionDatasetBase):
class Fake1dDataset(AbstractDataset):
def __init__(self, config, size):
self.config = config
self.size = size

self.data = torch.randn((size, config["channels"], config["resolution"])).split(
1, 0
)
self.target = torch.randn(
(size, config.size), dtype=torch.int32, min=0, max=config["num_classes"]
)

@classmethod
def prepare(cls, config):
pass

@classmethod
def splits(cls, config):
resolution = config.resolution
channels = config.channels

test_data = torchvision.datasets.FakeData(
size=128,
image_size=(channels, resolution),
num_classes=config.num_classes,
)
val_data = torchvision.datasets.FakeData(
size=128,
image_size=(channels, resolution),
num_classes=config.num_classes,
)
train_data = torchvision.datasets.FakeData(
size=512,
image_size=(channels, resolution),
num_classes=config.num_classes,
)

return cls(config, train_data), cls(config, val_data), cls(config, test_data)
return cls(config, size=128), cls(config, size=32), cls(config, size=32)

def __getitem__(self, index):
data, target = self.dataset[index]
data = np.array(data).astype(np.float32) / 255
data = self.transform(image=data)["image"]
data = torch.squeeze(data)
data, target = self.data[index], self.target[index]
return data, target

@property
def class_names(self):
return [f"class{n}" for n in range(self.config.num_classes)]

def __len__(self):
return len(self.targets)
11 changes: 2 additions & 9 deletions hannah/datasets/vision/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,13 @@
#
import logging
import re
import tarfile
from collections import Counter, namedtuple
from typing import Dict, List, Optional
from collections import Counter
from typing import List

import albumentations as A
import cv2
import numpy as np
import pandas as pd
import requests
import torch
import torchvision
from albumentations.pytorch import ToTensorV2
from omegaconf import DictConfig
from sklearn.model_selection import train_test_split

from ..base import AbstractDataset

Expand Down
2 changes: 0 additions & 2 deletions hannah/datasets/vision/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@
import logging
import os

import albumentations as A
import torch.utils.data as data
import torchvision
from albumentations.pytorch.transforms import ToTensorV2
from torchvision import datasets

from .base import TorchvisionDatasetBase
Expand Down
3 changes: 0 additions & 3 deletions hannah/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,11 @@
StreamClassifierModule,
)
from .object_detection import ObjectDetectionModule
from .vision import AnomalyDetectionModule, ImageClassifierModule

__all__ = [
"AnomalyDetectionModule",
"CrossValidationStreamClassifierModule",
"SpeechClassifierModule",
"StreamClassifierModule",
"ImageClassifierModule",
"AnomalyDetectionModule",
"ObjectDetectionModule",
"CartesianClassifierModule",
Expand Down

0 comments on commit 3ca2320

Please sign in to comment.