diff --git a/archai/discrete_search/evaluators/nlp/parameters.py b/archai/discrete_search/evaluators/nlp/parameters.py index a894f40ee..7734bd691 100644 --- a/archai/discrete_search/evaluators/nlp/parameters.py +++ b/archai/discrete_search/evaluators/nlp/parameters.py @@ -29,7 +29,7 @@ def __init__(self, exclude_cls: Optional[List[nn.Module]] = None, trainable_only """ - self.exclude_cls = [nn.Embedding] or exclude_cls + self.exclude_cls = exclude_cls or [nn.Embedding] self.trainable_only = trainable_only @overrides