Skip to content

Commit

Permalink
Fix transforms in parameter wrapped objects not following modes (#514)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #514

Models (and possibly some generators) can change modes via `model.train()` and `model.eval()`. When such an object is parameter-wrapped, these calls still work, but the wrapped transforms do not follow the mode change.

For base objects that cannot change modes, only the transforms' mode is changed, and a warning is emitted.

Reviewed By: crasanders

Differential Revision: D68033333

fbshipit-source-id: d0eaec89322a2cd1db43eb1b5d4b20bd5611e089
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Jan 11, 2025
1 parent 85c97e2 commit 126d286
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 1 deletion.
70 changes: 70 additions & 0 deletions aepsych/transforms/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,41 @@ def acqf_kwargs(self) -> dict | None:
def acqf_kwargs(self, value: dict):
self._base_obj.acqf_kwargs = value

@property
def training(self) -> bool:
# Check if generator has it first
try:
return self._base_obj.training # type: ignore
except AttributeError:
warnings.warn(
f"{self._base_obj.__class__.__name__} has no attribute 'training', returning transforms' `training`"
)
return self.transforms.training

def train(self):
"""Set transforms to train mode and attempts to set the underlying generator to
train as well."""
self.transforms.train()

try:
self._base_obj.train()
except AttributeError:
warnings.warn(
f"{self._base_obj.__class__.__name__} has no attribute 'train'"
)

def eval(self):
"""Set transforms to eval mode and attempts to set the underlying generator to
eval as well."""
self.transforms.eval()

try:
self._base_obj.eval()
except AttributeError:
warnings.warn(
f"{self._base_obj.__class__.__name__} has no attribute 'eval'"
)

@classmethod
def get_config_options(
cls,
Expand Down Expand Up @@ -671,6 +706,41 @@ def p_below_threshold(
x = self.transforms.transform(x)
return self._base_obj.p_below_threshold(x, f_thresh)

@property
def training(self) -> bool:
# Check if model has it first
try:
return self._base_obj.training # type: ignore
except AttributeError:
warnings.warn(
f"{self._base_obj.__class__.__name__} has no attribute 'training', returning transforms' 'training'"
)
return self.transforms.training

def train(self):
"""Set transforms to train mode and attempts to set the underlying model to
train as well."""
self.transforms.train()

try:
self._base_obj.train()
except AttributeError:
warnings.warn(
f"{self._base_obj.__class__.__name__} has no attribute 'train'"
)

def eval(self):
"""Set transforms to eval mode and attempts to set the underlying model to
eval as well."""
self.transforms.eval()

try:
self._base_obj.eval()
except AttributeError:
warnings.warn(
f"{self._base_obj.__class__.__name__} has no attribute 'eval'"
)

@classmethod
def get_config_options(
cls,
Expand Down
49 changes: 48 additions & 1 deletion tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,54 @@
ParameterTransforms,
)
from aepsych.transforms.ops import Fixed, Log10Plus, NormalizeScale, Round
from aepsych.transforms.parameters import Log10Plus, NormalizeScale


class TransformsWrapperTest(unittest.TestCase):
    def test_model_mode_change(self):
        """Mode changes on a wrapped model should propagate to its transforms."""
        wrapped = ParameterTransformedModel(
            GPClassificationModel,
            dim=3,
            transforms=ParameterTransforms(norm=NormalizeScale(d=3)),
        )

        # Both the model and its transforms begin in train mode.
        self.assertTrue(wrapped.training)
        self.assertTrue(wrapped.transforms.training)

        # eval() flips both flags off...
        wrapped.eval()
        self.assertFalse(wrapped.training)
        self.assertFalse(wrapped.transforms.training)

        # ...and train() flips them back on.
        wrapped.train()
        self.assertTrue(wrapped.training)
        self.assertTrue(wrapped.transforms.training)

    def test_generator_mode_change(self):
        """Mode changes on a wrapped generator should propagate to its transforms."""
        wrapped = ParameterTransformedGenerator(
            SobolGenerator,
            lb=torch.tensor([0, 0, 0]),
            ub=torch.tensor([1, 1, 1]),
            transforms=ParameterTransforms(norm=NormalizeScale(d=3)),
        )

        # Sobol has no mode of its own, so reading `training` warns and falls
        # back to the transforms' flag, which starts in train mode.
        with self.assertWarns(Warning):
            self.assertTrue(wrapped.training)
            self.assertTrue(wrapped.transforms.training)

        # eval() flips both flags off...
        wrapped.eval()
        self.assertFalse(wrapped.training)
        self.assertFalse(wrapped.transforms.training)

        # ...and train() flips them back on.
        wrapped.train()
        self.assertTrue(wrapped.training)
        self.assertTrue(wrapped.transforms.training)


class TransformsConfigTest(unittest.TestCase):
Expand Down

0 comments on commit 126d286

Please sign in to comment.