Skip to content

Commit

Permalink
feat: initial config schema (#135)
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 authored Feb 8, 2023
1 parent 7fdd599 commit d3488c9
Show file tree
Hide file tree
Showing 8 changed files with 1,021 additions and 754 deletions.
31 changes: 31 additions & 0 deletions numalogic/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2022 The Numaproj Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from numalogic.config._config import NumalogicConf, ModelInfo, LightningTrainerConf, RegistryConf
from numalogic.config.factory import (
ModelFactory,
PreprocessFactory,
PostprocessFactory,
ThresholdFactory,
)


# Names exported by the package: the config schema dataclasses first,
# followed by the object factories, in the same order they are imported above.
__all__ = [
    "NumalogicConf",
    "ModelInfo",
    "LightningTrainerConf",
    "RegistryConf",
    "ModelFactory",
    "PreprocessFactory",
    "PostprocessFactory",
    "ThresholdFactory",
]
76 changes: 76 additions & 0 deletions numalogic/config/_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright 2022 The Numaproj Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass, field
from typing import List, Optional, Any, Dict

from omegaconf import MISSING


@dataclass
class ModelInfo:
    """
    Schema for defining the model/estimator.

    Args:
        name: name of the model; this should map to a supported list of models
              mentioned in the factory file. Defaults to omegaconf's ``MISSING``
              marker instead of a concrete name.
        conf: kwargs for instantiating the model class
        stateful: flag indicating if the model is stateful or not
    """

    # No concrete default: MISSING is omegaconf's marker for a value that
    # still has to be supplied.
    name: str = MISSING
    # default_factory gives every instance its own dict rather than a shared one.
    conf: Dict[str, Any] = field(default_factory=dict)
    stateful: bool = True


@dataclass
class RegistryConf:
    """
    Base schema for the registry configuration.

    No fields are declared yet.
    """

    # TODO implement this


@dataclass
class LightningTrainerConf:
    """
    Schema for defining the Pytorch Lightning trainer behavior.

    The field names mirror keyword arguments of the Lightning ``Trainer``.
    More details on the arguments are provided here:
    https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-class-api

    Args:
        max_epochs: maximum number of training epochs
        logger: whether the trainer should use a logger
        check_val_every_n_epoch: run validation every N training epochs
        log_every_n_steps: logging frequency, in training steps
        enable_checkpointing: whether checkpointing is enabled
        enable_progress_bar: whether the progress bar is shown
        enable_model_summary: whether the model summary is printed
        limit_val_batches: number of validation batches to run; the
            default of 0 requests no validation batches
        callbacks: optional callback(s) handed to the trainer
    """

    max_epochs: int = 100
    logger: bool = False
    check_val_every_n_epoch: int = 5
    log_every_n_steps: int = 20
    enable_checkpointing: bool = False
    enable_progress_bar: bool = True
    enable_model_summary: bool = True
    # Annotated as int, not bool: this is a batch count whose default is the
    # integer 0. A bool annotation misdescribes the value and would let a
    # type-validating config loader collapse counts such as 5 into True.
    limit_val_batches: int = 0
    callbacks: Optional[Any] = None


@dataclass
class NumalogicConf:
    """
    Top level config schema for numalogic.

    Every field uses a ``default_factory``, so each instance gets its own
    freshly built sub-config objects.

    Args:
        model: schema entry describing the model
        trainer: Pytorch Lightning trainer settings
        registry: registry settings
        preprocess: list of schema entries, one per preprocessing step;
            empty by default
        threshold: schema entry describing the threshold estimator
        postprocess: schema entry describing the postprocessing step
    """

    model: ModelInfo = field(default_factory=ModelInfo)
    trainer: LightningTrainerConf = field(default_factory=LightningTrainerConf)
    registry: RegistryConf = field(default_factory=RegistryConf)
    # The only list-valued section; the others each hold a single ModelInfo.
    preprocess: List[ModelInfo] = field(default_factory=list)
    threshold: ModelInfo = field(default_factory=ModelInfo)
    postprocess: ModelInfo = field(default_factory=ModelInfo)
78 changes: 78 additions & 0 deletions numalogic/config/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2022 The Numaproj Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from sklearn.preprocessing import StandardScaler, MinMaxScaler, MaxAbsScaler, RobustScaler

from numalogic.config._config import ModelInfo
from numalogic.models.autoencoder.variants import (
VanillaAE,
SparseVanillaAE,
Conv1dAE,
SparseConv1dAE,
LSTMAE,
SparseLSTMAE,
TransformerAE,
SparseTransformerAE,
)
from numalogic.models.threshold import StdDevThreshold
from numalogic.postprocess import TanhNorm
from numalogic.preprocess import LogTransformer, StaticPowerTransformer
from numalogic.tools.exceptions import UnknownConfigArgsError


class _ObjectFactory:
    """
    Base factory that resolves a ``ModelInfo`` to a registered class.

    Subclasses override ``_CLS_MAP`` with a mapping from the name used in the
    config to the class it stands for.
    """

    _CLS_MAP = {}

    def get_instance(self, model_info: ModelInfo):
        """
        Build an instance of the class registered under ``model_info.name``.

        Args:
            model_info: schema entry holding the registered name and the
                keyword arguments (``conf``) for the constructor

        Returns:
            The registered class instantiated with ``**model_info.conf``.

        Raises:
            UnknownConfigArgsError: if ``model_info.name`` is not registered
        """
        # Single lookup path: resolution and error reporting live in get_cls.
        _cls = self.get_cls(model_info)
        return _cls(**model_info.conf)

    def get_cls(self, model_info: ModelInfo):
        """
        Return the class registered under ``model_info.name``.

        Args:
            model_info: schema entry holding the registered name

        Returns:
            The class object itself, not an instance.

        Raises:
            UnknownConfigArgsError: if ``model_info.name`` is not registered
        """
        try:
            return self._CLS_MAP[model_info.name]
        except KeyError as err:
            # Chain explicitly so the traceback shows the failed lookup as the
            # cause rather than as an error raised while handling another one.
            raise UnknownConfigArgsError(
                f"Invalid model info instance provided: {model_info}"
            ) from err


class PreprocessFactory(_ObjectFactory):
    """
    Factory for preprocessing transforms.

    The keys of ``_CLS_MAP`` are the names accepted in ``ModelInfo.name``.
    """

    _CLS_MAP = {
        # Scalers imported from sklearn.preprocessing.
        "StandardScaler": StandardScaler,
        "MinMaxScaler": MinMaxScaler,
        "MaxAbsScaler": MaxAbsScaler,
        "RobustScaler": RobustScaler,
        # Transformers imported from numalogic.preprocess.
        "LogTransformer": LogTransformer,
        "StaticPowerTransformer": StaticPowerTransformer,
    }


class PostprocessFactory(_ObjectFactory):
    """
    Factory for postprocessing transforms.

    The keys of ``_CLS_MAP`` are the names accepted in ``ModelInfo.name``.
    """

    # Imported from numalogic.postprocess.
    _CLS_MAP = {"TanhNorm": TanhNorm}


class ThresholdFactory(_ObjectFactory):
    """
    Factory for threshold estimators.

    The keys of ``_CLS_MAP`` are the names accepted in ``ModelInfo.name``.
    """

    # Imported from numalogic.models.threshold.
    _CLS_MAP = {"StdDevThreshold": StdDevThreshold}


class ModelFactory(_ObjectFactory):
    """
    Factory for the autoencoder model variants.

    The keys of ``_CLS_MAP`` are the names accepted in ``ModelInfo.name``.
    All classes are imported from ``numalogic.models.autoencoder.variants``.
    """

    _CLS_MAP = {
        # Each variant is listed together with its "Sparse" counterpart.
        "VanillaAE": VanillaAE,
        "SparseVanillaAE": SparseVanillaAE,
        "Conv1dAE": Conv1dAE,
        "SparseConv1dAE": SparseConv1dAE,
        "LSTMAE": LSTMAE,
        "SparseLSTMAE": SparseLSTMAE,
        "TransformerAE": TransformerAE,
        "SparseTransformerAE": SparseTransformerAE,
    }
4 changes: 4 additions & 0 deletions numalogic/tools/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,7 @@ class DataModuleError(Exception):

class InvalidDataShapeError(Exception):
    """Exception type for invalid data shape errors."""


class UnknownConfigArgsError(Exception):
    """Exception type for unknown or unsupported config arguments."""
Loading

0 comments on commit d3488c9

Please sign in to comment.