From e22fc340bb94cd47f62d9eef559dcdc11dd99da0 Mon Sep 17 00:00:00 2001 From: Heungsub Hans Lee Date: Mon, 30 Sep 2019 15:40:22 +0900 Subject: [PATCH] Import torch.nn consistently --- examples/amoebanetd_speed_benchmark/main.py | 4 ++-- examples/resnet/__init__.py | 2 +- examples/resnet/bottleneck.py | 3 +-- examples/resnet101_accuracy_benchmark/main.py | 8 ++++---- examples/resnet101_speed_benchmark/main.py | 4 ++-- tests/test_bugs.py | 4 ++-- tests/test_deferred_batch_norm.py | 3 +-- tests/test_gpipe.py | 2 +- tests/test_inplace.py | 2 +- tests/test_pipeline.py | 2 +- tests/test_transparency.py | 2 +- torchgpipe/batchnorm.py | 3 +-- torchgpipe/gpipe.py | 3 +-- 13 files changed, 19 insertions(+), 23 deletions(-) diff --git a/examples/amoebanetd_speed_benchmark/main.py b/examples/amoebanetd_speed_benchmark/main.py index 47ca920..dc7a7f4 100644 --- a/examples/amoebanetd_speed_benchmark/main.py +++ b/examples/amoebanetd_speed_benchmark/main.py @@ -6,8 +6,8 @@ import click import torch -import torch.nn as nn -from torch.nn import functional as F +from torch import nn +import torch.nn.functional as F from torch.optim import SGD from torch.utils.data import DataLoader diff --git a/examples/resnet/__init__.py b/examples/resnet/__init__.py index f1eadb8..7a5a1fa 100644 --- a/examples/resnet/__init__.py +++ b/examples/resnet/__init__.py @@ -7,7 +7,7 @@ from collections import OrderedDict from typing import Any, List -import torch.nn as nn +from torch import nn from resnet.bottleneck import bottleneck from resnet.flatten_sequential import flatten_sequential diff --git a/examples/resnet/bottleneck.py b/examples/resnet/bottleneck.py index 2f821d9..f149d95 100644 --- a/examples/resnet/bottleneck.py +++ b/examples/resnet/bottleneck.py @@ -2,8 +2,7 @@ from collections import OrderedDict from typing import TYPE_CHECKING, Optional, Tuple, Union -from torch import Tensor -import torch.nn as nn +from torch import Tensor, nn __all__ = ['bottleneck'] diff --git a/examples/resnet101_accuracy_benchmark/main.py b/examples/resnet101_accuracy_benchmark/main.py index c38cae9..5b76afb 100644 --- a/examples/resnet101_accuracy_benchmark/main.py +++ b/examples/resnet101_accuracy_benchmark/main.py @@ -5,8 +5,8 @@ import click import torch -import torch.nn as nn -from torch.nn import functional as F +from torch import nn +import torch.nn.functional as F from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader @@ -34,7 +34,7 @@ def dataparallel256(model: nn.Module, devices: List[int]) -> Stuffs: devices = [devices[0], devices[1]] model.to(devices[0]) - model = torch.nn.DataParallel(model, device_ids=devices, output_device=devices[-1]) + model = nn.DataParallel(model, device_ids=devices, output_device=devices[-1]) return model, batch_size, [torch.device(device) for device in devices] @@ -44,7 +44,7 @@ def dataparallel1k(model: nn.Module, devices: List[int]) -> Stuffs: devices = [devices[0], devices[1], devices[2], devices[3]] model.to(devices[0]) - model = torch.nn.DataParallel(model, device_ids=devices, output_device=devices[-1]) + model = nn.DataParallel(model, device_ids=devices, output_device=devices[-1]) return model, batch_size, [torch.device(device) for device in devices] diff --git a/examples/resnet101_speed_benchmark/main.py b/examples/resnet101_speed_benchmark/main.py index 5428dd3..08d348d 100644 --- a/examples/resnet101_speed_benchmark/main.py +++ b/examples/resnet101_speed_benchmark/main.py @@ -5,8 +5,8 @@ import click import torch -import torch.nn as nn -from torch.nn import functional as F +from torch import nn +import torch.nn.functional as F from torch.optim import SGD from resnet import resnet101 diff --git a/tests/test_bugs.py b/tests/test_bugs.py index 12ea030..383feb8 100644 --- a/tests/test_bugs.py +++ b/tests/test_bugs.py @@ -1,6 +1,6 @@ import pytest import torch -import torch.nn as nn +from torch import nn from torchgpipe import GPipe @@ -25,7 +25,7 @@ def forward(ctx, input): def backward(ctx, grad): return grad - class M(torch.nn.Module): + class M(nn.Module): def forward(self, input): return Identity.apply(input) diff --git a/tests/test_deferred_batch_norm.py b/tests/test_deferred_batch_norm.py index a41fed2..7f64a8c 100644 --- a/tests/test_deferred_batch_norm.py +++ b/tests/test_deferred_batch_norm.py @@ -3,8 +3,7 @@ import pytest import torch -import torch.nn as nn -import torch.optim as optim +from torch import nn, optim from torchgpipe.batchnorm import DeferredBatchNorm diff --git a/tests/test_gpipe.py b/tests/test_gpipe.py index b86be55..783e245 100644 --- a/tests/test_gpipe.py +++ b/tests/test_gpipe.py @@ -4,7 +4,7 @@ import pytest import torch -import torch.nn as nn +from torch import nn from torchgpipe import GPipe from torchgpipe.gpipe import verify_module diff --git a/tests/test_inplace.py b/tests/test_inplace.py index 9dd2ea3..d809c35 100644 --- a/tests/test_inplace.py +++ b/tests/test_inplace.py @@ -1,6 +1,6 @@ import pytest import torch -import torch.nn as nn +from torch import nn from torchgpipe import GPipe diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 0cf7156..3ac3dcc 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,7 +1,7 @@ import time import torch -import torch.nn as nn +from torch import nn from torchgpipe.microbatch import Batch from torchgpipe.pipeline import Pipeline diff --git a/tests/test_transparency.py b/tests/test_transparency.py index e9c5a96..31f9c9f 100644 --- a/tests/test_transparency.py +++ b/tests/test_transparency.py @@ -1,5 +1,5 @@ import torch -import torch.nn as nn +from torch import nn from torchgpipe import GPipe diff --git a/torchgpipe/batchnorm.py b/torchgpipe/batchnorm.py index f436e57..805cff9 100644 --- a/torchgpipe/batchnorm.py +++ b/torchgpipe/batchnorm.py @@ -2,8 +2,7 @@ from typing import Optional, TypeVar, cast import torch -from torch import Tensor -import torch.nn as nn +from torch import Tensor, nn import torch.nn.functional as F from torch.nn.modules.batchnorm import _BatchNorm diff --git a/torchgpipe/gpipe.py b/torchgpipe/gpipe.py index b2a3f0b..82d899e 100644 --- a/torchgpipe/gpipe.py +++ b/torchgpipe/gpipe.py @@ -3,10 +3,9 @@ from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast import torch -from torch import Tensor +from torch import Tensor, nn import torch.autograd import torch.cuda -import torch.nn as nn from torchgpipe.batchnorm import DeferredBatchNorm from torchgpipe.microbatch import check, gather, scatter