diff --git a/src/brevitas/nn/__init__.py b/src/brevitas/nn/__init__.py index 7138da2bd..cc96889b6 100644 --- a/src/brevitas/nn/__init__.py +++ b/src/brevitas/nn/__init__.py @@ -17,7 +17,6 @@ from .quant_conv import QuantConv2d from .quant_convtranspose import QuantConvTranspose1d from .quant_convtranspose import QuantConvTranspose2d -from .quant_dropout import QuantDropout from .quant_eltwise import QuantCat from .quant_eltwise import QuantEltwiseAdd from .quant_embedding import QuantEmbedding diff --git a/src/brevitas/nn/quant_dropout.py b/src/brevitas/nn/quant_dropout.py deleted file mode 100644 index 388752a9d..000000000 --- a/src/brevitas/nn/quant_dropout.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -from typing import Union - -from torch import Tensor -from torch.nn import Dropout - -from brevitas.quant_tensor import QuantTensor - -from .mixin.base import QuantLayerMixin - - -class QuantDropout(QuantLayerMixin, Dropout): - - def __init__(self, p: float = 0.5, return_quant_tensor: bool = True): - Dropout.__init__(self, p=p, inplace=False) - QuantLayerMixin.__init__(self, return_quant_tensor=return_quant_tensor) - - @property - def channelwise_separable(self) -> bool: - return True - - @property - def requires_export_handler(self): - return False - - def forward(self, input: Union[Tensor, QuantTensor]): - x = self.unpack_input(input) - x = x.set(value=super().forward(x.value)) - return self.pack_output(x)