-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Avik Basu <[email protected]>
- Loading branch information
Showing
8 changed files
with
1,021 additions
and
754 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Copyright 2022 The Numaproj Authors. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from numalogic.config._config import NumalogicConf, ModelInfo, LightningTrainerConf, RegistryConf | ||
from numalogic.config.factory import ( | ||
ModelFactory, | ||
PreprocessFactory, | ||
PostprocessFactory, | ||
ThresholdFactory, | ||
) | ||
|
||
|
||
__all__ = [ | ||
"NumalogicConf", | ||
"ModelInfo", | ||
"LightningTrainerConf", | ||
"RegistryConf", | ||
"ModelFactory", | ||
"PreprocessFactory", | ||
"PostprocessFactory", | ||
"ThresholdFactory", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Copyright 2022 The Numaproj Authors. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from dataclasses import dataclass, field | ||
from typing import List, Optional, Any, Dict | ||
|
||
from omegaconf import MISSING | ||
|
||
|
||
@dataclass | ||
class ModelInfo: | ||
""" | ||
Schema for defining the model/estimator. | ||
Args: | ||
name: name of the model; this should map to a supported list of models | ||
mentioned in the factory file | ||
conf: kwargs for instantiating the model class | ||
stateful: flag indicating if the model is stateful or not | ||
""" | ||
|
||
name: str = MISSING | ||
conf: Dict[str, Any] = field(default_factory=dict) | ||
stateful: bool = True | ||
|
||
|
||
@dataclass | ||
class RegistryConf: | ||
# TODO implement this | ||
""" | ||
Registry config base class | ||
""" | ||
pass | ||
|
||
|
||
@dataclass | ||
class LightningTrainerConf: | ||
""" | ||
Schema for defining the Pytorch Lightning trainer behavior. | ||
More details on the arguments are provided here: | ||
https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-class-api | ||
""" | ||
|
||
max_epochs: int = 100 | ||
logger: bool = False | ||
check_val_every_n_epoch: int = 5 | ||
log_every_n_steps: int = 20 | ||
enable_checkpointing: bool = False | ||
enable_progress_bar: bool = True | ||
enable_model_summary: bool = True | ||
limit_val_batches: bool = 0 | ||
callbacks: Optional[Any] = None | ||
|
||
|
||
@dataclass | ||
class NumalogicConf: | ||
""" | ||
Top level config schema for numalogic. | ||
""" | ||
|
||
model: ModelInfo = field(default_factory=ModelInfo) | ||
trainer: LightningTrainerConf = field(default_factory=LightningTrainerConf) | ||
registry: RegistryConf = field(default_factory=RegistryConf) | ||
preprocess: List[ModelInfo] = field(default_factory=list) | ||
threshold: ModelInfo = field(default_factory=ModelInfo) | ||
postprocess: ModelInfo = field(default_factory=ModelInfo) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright 2022 The Numaproj Authors. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from sklearn.preprocessing import StandardScaler, MinMaxScaler, MaxAbsScaler, RobustScaler | ||
|
||
from numalogic.config._config import ModelInfo | ||
from numalogic.models.autoencoder.variants import ( | ||
VanillaAE, | ||
SparseVanillaAE, | ||
Conv1dAE, | ||
SparseConv1dAE, | ||
LSTMAE, | ||
SparseLSTMAE, | ||
TransformerAE, | ||
SparseTransformerAE, | ||
) | ||
from numalogic.models.threshold import StdDevThreshold | ||
from numalogic.postprocess import TanhNorm | ||
from numalogic.preprocess import LogTransformer, StaticPowerTransformer | ||
from numalogic.tools.exceptions import UnknownConfigArgsError | ||
|
||
|
||
class _ObjectFactory: | ||
_CLS_MAP = {} | ||
|
||
def get_instance(self, model_info: ModelInfo): | ||
try: | ||
_cls = self._CLS_MAP[model_info.name] | ||
except KeyError: | ||
raise UnknownConfigArgsError(f"Invalid model info instance provided: {model_info}") | ||
return _cls(**model_info.conf) | ||
|
||
def get_cls(self, model_info: ModelInfo): | ||
try: | ||
return self._CLS_MAP[model_info.name] | ||
except KeyError: | ||
raise UnknownConfigArgsError(f"Invalid model info instance provided: {model_info}") | ||
|
||
|
||
class PreprocessFactory(_ObjectFactory): | ||
_CLS_MAP = { | ||
"StandardScaler": StandardScaler, | ||
"MinMaxScaler": MinMaxScaler, | ||
"MaxAbsScaler": MaxAbsScaler, | ||
"RobustScaler": RobustScaler, | ||
"LogTransformer": LogTransformer, | ||
"StaticPowerTransformer": StaticPowerTransformer, | ||
} | ||
|
||
|
||
class PostprocessFactory(_ObjectFactory): | ||
_CLS_MAP = {"TanhNorm": TanhNorm} | ||
|
||
|
||
class ThresholdFactory(_ObjectFactory): | ||
_CLS_MAP = {"StdDevThreshold": StdDevThreshold} | ||
|
||
|
||
class ModelFactory(_ObjectFactory): | ||
_CLS_MAP = { | ||
"VanillaAE": VanillaAE, | ||
"SparseVanillaAE": SparseVanillaAE, | ||
"Conv1dAE": Conv1dAE, | ||
"SparseConv1dAE": SparseConv1dAE, | ||
"LSTMAE": LSTMAE, | ||
"SparseLSTMAE": SparseLSTMAE, | ||
"TransformerAE": TransformerAE, | ||
"SparseTransformerAE": SparseTransformerAE, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.