Skip to content

Commit

Permalink
Import torch.nn consistently
Browse files Browse the repository at this point in the history
  • Loading branch information
sublee committed Sep 30, 2019
1 parent db3e610 commit e22fc34
Show file tree
Hide file tree
Showing 13 changed files with 19 additions and 23 deletions.
4 changes: 2 additions & 2 deletions examples/amoebanetd_speed_benchmark/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import click
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import nn
import torch.nn.functional as F
from torch.optim import SGD
from torch.utils.data import DataLoader

Expand Down
2 changes: 1 addition & 1 deletion examples/resnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections import OrderedDict
from typing import Any, List

import torch.nn as nn
from torch import nn

from resnet.bottleneck import bottleneck
from resnet.flatten_sequential import flatten_sequential
Expand Down
3 changes: 1 addition & 2 deletions examples/resnet/bottleneck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from collections import OrderedDict
from typing import TYPE_CHECKING, Optional, Tuple, Union

from torch import Tensor
import torch.nn as nn
from torch import Tensor, nn

__all__ = ['bottleneck']

Expand Down
8 changes: 4 additions & 4 deletions examples/resnet101_accuracy_benchmark/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import click
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import nn
import torch.nn.functional as F
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -34,7 +34,7 @@ def dataparallel256(model: nn.Module, devices: List[int]) -> Stuffs:

devices = [devices[0], devices[1]]
model.to(devices[0])
model = torch.nn.DataParallel(model, device_ids=devices, output_device=devices[-1])
model = nn.DataParallel(model, device_ids=devices, output_device=devices[-1])

return model, batch_size, [torch.device(device) for device in devices]

Expand All @@ -44,7 +44,7 @@ def dataparallel1k(model: nn.Module, devices: List[int]) -> Stuffs:

devices = [devices[0], devices[1], devices[2], devices[3]]
model.to(devices[0])
model = torch.nn.DataParallel(model, device_ids=devices, output_device=devices[-1])
model = nn.DataParallel(model, device_ids=devices, output_device=devices[-1])

return model, batch_size, [torch.device(device) for device in devices]

Expand Down
4 changes: 2 additions & 2 deletions examples/resnet101_speed_benchmark/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import click
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import nn
import torch.nn.functional as F
from torch.optim import SGD

from resnet import resnet101
Expand Down
4 changes: 2 additions & 2 deletions tests/test_bugs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import torch
import torch.nn as nn
from torch import nn

from torchgpipe import GPipe

Expand All @@ -25,7 +25,7 @@ def forward(ctx, input):
def backward(ctx, grad):
return grad

class M(torch.nn.Module):
class M(nn.Module):
def forward(self, input):
return Identity.apply(input)

Expand Down
3 changes: 1 addition & 2 deletions tests/test_deferred_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

import pytest
import torch
import torch.nn as nn
import torch.optim as optim
from torch import nn, optim

from torchgpipe.batchnorm import DeferredBatchNorm

Expand Down
2 changes: 1 addition & 1 deletion tests/test_gpipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest
import torch
import torch.nn as nn
from torch import nn

from torchgpipe import GPipe
from torchgpipe.gpipe import verify_module
Expand Down
2 changes: 1 addition & 1 deletion tests/test_inplace.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import torch
import torch.nn as nn
from torch import nn

from torchgpipe import GPipe

Expand Down
2 changes: 1 addition & 1 deletion tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time

import torch
import torch.nn as nn
from torch import nn

from torchgpipe.microbatch import Batch
from torchgpipe.pipeline import Pipeline
Expand Down
2 changes: 1 addition & 1 deletion tests/test_transparency.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
import torch.nn as nn
from torch import nn

from torchgpipe import GPipe

Expand Down
3 changes: 1 addition & 2 deletions torchgpipe/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from typing import Optional, TypeVar, cast

import torch
from torch import Tensor
import torch.nn as nn
from torch import Tensor, nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm

Expand Down
3 changes: 1 addition & 2 deletions torchgpipe/gpipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast

import torch
from torch import Tensor
from torch import Tensor, nn
import torch.autograd
import torch.cuda
import torch.nn as nn

from torchgpipe.batchnorm import DeferredBatchNorm
from torchgpipe.microbatch import check, gather, scatter
Expand Down

0 comments on commit e22fc34

Please sign in to comment.