Skip to content

Commit

Permalink
ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
LBerth committed Sep 5, 2024
1 parent 704d1f7 commit bfa8474
Show file tree
Hide file tree
Showing 12 changed files with 44 additions and 59 deletions.
15 changes: 2 additions & 13 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,11 @@
from lightning.pytorch.cli import LightningCLI

from mfai.torch.segmentation_module import SegmentationLightningModule

from mfai.torch.models import (
DeepLabV3,
DeepLabV3Plus,
HalfUNet,
Segformer,
SwinUNETR,
UNet,
CustomUnet,
UNETRPP,
)
from mfai.torch.dummy_dataset import DummyDataModule
from mfai.torch.segmentation_module import SegmentationLightningModule


def cli_main():
cli = LightningCLI(SegmentationLightningModule, DummyDataModule)
cli = LightningCLI(SegmentationLightningModule, DummyDataModule) # noqa: F841


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions mfai/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path

import numpy
import onnx
import onnxruntime
Expand Down
2 changes: 1 addition & 1 deletion mfai/torch/dummy_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from lightning.pytorch.core import LightningDataModule
import torch
from lightning.pytorch.core import LightningDataModule
from torch.utils.data import DataLoader, Dataset


Expand Down
7 changes: 4 additions & 3 deletions mfai/torch/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from pathlib import Path
from typing import Optional, Tuple, Literal
from typing import Optional, Tuple

from torch import nn

from .deeplabv3 import DeepLabV3, DeepLabV3Plus
from .half_unet import HalfUNet
from .segformer import Segformer
from .swinunetr import SwinUNETR
from .unet import UNet, CustomUnet
from .unet import CustomUnet, UNet
from .unetrpp import UNETRPP


all_nn_architectures = (
DeepLabV3,
DeepLabV3Plus,
Expand Down
3 changes: 2 additions & 1 deletion mfai/torch/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from dataclasses_json import dataclass_json
from torch import nn

from .base import ModelABC
from mfai.torch.models.encoders import get_encoder

from .base import ModelABC


class DoubleConv(nn.Module):
def __init__(self, in_channels: int, out_channels: int, name: str):
Expand Down
4 changes: 2 additions & 2 deletions mfai/torch/models/unetrpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
Added 2d support and Bilinear interpolation for upsampling.
"""

import functools
import math
import operator
import warnings
from dataclasses import dataclass
import operator
from typing import Sequence, Tuple, Union
import functools

import torch
import torch.nn as nn
Expand Down
3 changes: 2 additions & 1 deletion mfai/torch/namedtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

from copy import deepcopy
from dataclasses import dataclass
from itertools import chain
from functools import cached_property
from itertools import chain
from typing import List, Union

import torch
from tabulate import tabulate

Expand Down
11 changes: 5 additions & 6 deletions mfai/torch/segmentation_module.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from pathlib import Path
from typing import Callable, Literal

import lightning.pytorch as pl
import pandas as pd
import torch
import torchmetrics as tm
import pandas as pd
from pytorch_lightning.utilities import rank_zero_only

from typing import Literal, Callable
from pathlib import Path
from mfai.torch.models.base import ModelABC

from pytorch_lightning.utilities import rank_zero_only

# define custom scalar in tensorboard, to have 2 lines on same graph
layout = {
"Check Overfit": {
Expand Down Expand Up @@ -230,4 +230,3 @@ def probabilities_to_classes(self, y_hat):
# Default detection threshold = 0.5
y_hat = (y_hat > 0.5).int()
return y_hat

30 changes: 16 additions & 14 deletions tests/test_lightning.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import torch
from lightning.pytorch.cli import ArgsType, LightningCLI

from mfai.torch.segmentation_module import SegmentationLightningModule

from mfai.torch.models import DeepLabV3, DeepLabV3Plus, HalfUNet, Segformer, SwinUNETR, UNet, CustomUnet, UNETRPP
from mfai.torch.dummy_dataset import DummyDataModule
import torch
from mfai.torch.models import UNet
from mfai.torch.segmentation_module import SegmentationLightningModule


def test_init_train_forward():
arch = UNet(in_channels=1, out_channels=1, input_shape=[64, 64])
loss = torch.nn.MSELoss()
model = SegmentationLightningModule(arch, 'binary', loss)
model = SegmentationLightningModule(arch, "binary", loss)
x = torch.randn((1, 1, 64, 64)).float()
y = torch.randint(0, 1, (1, 1, 64, 64)).float()
model.training_step((x, y), 0)
Expand All @@ -25,15 +24,18 @@ def cli_main(args: ArgsType = None):


def test_cli():
cli_main([
"--model.model=Segformer",
"--model.type_segmentation=binary",
"--model.loss=torch.nn.BCEWithLogitsLoss",
"--in_channels=2",
"--out_channels=1",
"--input_shape=[64, 64]",
"--trainer.fast_dev_run=True"
])
cli_main(
[
"--model.model=Segformer",
"--model.type_segmentation=binary",
"--model.loss=torch.nn.BCEWithLogitsLoss",
"--in_channels=2",
"--out_channels=1",
"--input_shape=[64, 64]",
"--trainer.fast_dev_run=True",
]
)


def test_cli_with_config_file():
cli_main(["--config=mfai/config/cli_fit_test.yaml", "--trainer.fast_dev_run=True"])
11 changes: 6 additions & 5 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,20 @@
4. onnx loaded and used for inference
"""

from pathlib import Path
import tempfile
from pathlib import Path
from typing import Tuple

from marshmallow.exceptions import ValidationError
import torch
import pytest
import torch
from marshmallow.exceptions import ValidationError

from mfai.torch import export_to_onnx, onnx_load_and_infer
from mfai.torch.models import (
DeepLabV3Plus,
HalfUNet,
all_nn_architectures,
load_from_settings_file,
HalfUNet,
DeepLabV3Plus,
)


Expand Down
3 changes: 2 additions & 1 deletion tests/test_namedtensors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import pytest
import torch

from mfai.torch.namedtensor import NamedTensor


Expand Down
13 changes: 1 addition & 12 deletions train_and_test.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,7 @@
from lightning.pytorch.cli import LightningCLI

from mfai.torch.segmentation_module import SegmentationLightningModule

from mfai.torch.models import (
DeepLabV3,
DeepLabV3Plus,
HalfUNet,
Segformer,
SwinUNETR,
UNet,
CustomUnet,
UNETRPP,
)
from mfai.torch.dummy_dataset import DummyDataModule
from mfai.torch.segmentation_module import SegmentationLightningModule


def cli_main():
Expand Down

0 comments on commit bfa8474

Please sign in to comment.