Skip to content

Commit

Permalink
Add environment variable backed properties to config
Browse files Browse the repository at this point in the history
* Allow properties to be defined by environment
variables.
  • Loading branch information
8W9aG committed Nov 8, 2024
1 parent e46c4f3 commit c72742f
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 0 deletions.
20 changes: 20 additions & 0 deletions python/cog/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .base_input import BaseInput
from .base_predictor import BasePredictor
from .code_xforms import load_module_from_string, strip_model_source_code
from .env_property import env_property
from .errors import ConfigDoesNotExist
from .mode import Mode
from .predictor import (
Expand All @@ -23,6 +24,11 @@
from .types import CogConfig

COG_YAML_FILE = "cog.yaml"
COG_PREDICT_TYPE_STUB_ENV_VAR = "COG_PREDICT_TYPE_STUB"
COG_TRAIN_TYPE_STUB_ENV_VAR = "COG_TRAIN_TYPE_STUB"
COG_PREDICT_CODE_STRIP_ENV_VAR = "COG_PREDICT_CODE_STRIP"
COG_TRAIN_CODE_STRIP_ENV_VAR = "COG_TRAIN_CODE_STRIP"
COG_GPU_ENV_VAR = "COG_GPU"
PREDICT_METHOD_NAME = "predict"
TRAIN_METHOD_NAME = "train"

Expand All @@ -37,6 +43,14 @@ def _method_name_from_mode(mode: Mode) -> str:
raise ValueError(f"Mode {mode} not recognised for method name mapping")


def _env_var_from_mode(mode: Mode) -> str:
if mode == Mode.PREDICT:
return COG_PREDICT_CODE_STRIP_ENV_VAR
elif mode == Mode.TRAIN:
return COG_TRAIN_CODE_STRIP_ENV_VAR
raise ValueError(f"Mode {mode} not recognised for env var mapping")


class Config:
"""A class for reading the cog.yaml properties."""

Expand Down Expand Up @@ -65,16 +79,19 @@ def _cog_config(self) -> CogConfig:
return config

@property
@env_property(COG_PREDICT_TYPE_STUB_ENV_VAR)
def predictor_predict_ref(self) -> Optional[str]:
"""Find the predictor ref for the predict mode."""
return self._cog_config.get(str(Mode.PREDICT))

@property
@env_property(COG_TRAIN_TYPE_STUB_ENV_VAR)
def predictor_train_ref(self) -> Optional[str]:
"""Find the predictor ref for the train mode."""
return self._cog_config.get(str(Mode.TRAIN))

@property
@env_property(COG_GPU_ENV_VAR)
def requires_gpu(self) -> bool:
"""Whether this cog requires the use of a GPU."""
return bool(self._cog_config.get("build", {}).get("gpu", False))
Expand All @@ -87,6 +104,9 @@ def _predictor_code(
mode: Mode,
module_name: str,
) -> Optional[str]:
source_code = os.environ.get(_env_var_from_mode(mode))
if source_code is not None:
return source_code
if sys.version_info >= (3, 9):
with open(module_path, encoding="utf-8") as file:
return strip_model_source_code(file.read(), [class_name], [method_name])
Expand Down
42 changes: 42 additions & 0 deletions python/cog/env_property.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, Union

R = TypeVar("R")


def _get_origin(typ: Any) -> Any:
if hasattr(typ, "__origin__"):
return typ.__origin__
return None


def _get_args(typ: Any) -> Any:
if hasattr(typ, "__args__"):
return typ.__args__
return ()


def env_property(
env_var: str,
) -> Callable[[Callable[[Any], R]], Callable[[Any], R]]:
"""Wraps a class property in an environment variable check."""

def decorator(func: Callable[[Any], R]) -> Callable[[Any], R]:
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> R:
result = os.environ.get(env_var)
if result is not None:
expected_type = func.__annotations__.get("return", str)
if (
_get_origin(expected_type) is Optional
or _get_origin(expected_type) is Union
):
expected_type = _get_args(expected_type)[0]
return expected_type(result)
result = func(*args, **kwargs)
return result

return wrapper

return decorator
125 changes: 125 additions & 0 deletions python/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,110 @@
import os
import tempfile

import pytest

from cog.config import (
COG_GPU_ENV_VAR,
COG_PREDICT_CODE_STRIP_ENV_VAR,
COG_PREDICT_TYPE_STUB_ENV_VAR,
COG_TRAIN_TYPE_STUB_ENV_VAR,
COG_YAML_FILE,
Config,
)
from cog.errors import ConfigDoesNotExist
from cog.mode import Mode


def test_predictor_predict_ref_env_var():
predict_ref = "predict.py:Predictor"
os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR] = predict_ref
config = Config()
config_predict_ref = config.predictor_predict_ref
del os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR]
assert (
config_predict_ref == predict_ref
), "Predict Reference should come from the environment variable."


def test_predictor_predict_ref_no_env_var():
if COG_PREDICT_TYPE_STUB_ENV_VAR in os.environ:
del os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR]
pwd = os.getcwd()
with tempfile.TemporaryDirectory() as tmpdir:
os.chdir(tmpdir)
with open(COG_YAML_FILE, "w", encoding="utf-8") as handle:
handle.write("""
build:
python_version: "3.11"
predict: "predict.py:Predictor"
""")
config = Config()
config_predict_ref = config.predictor_predict_ref
assert (
config_predict_ref == "predict.py:Predictor"
), "Predict Reference should come from the cog config file."
os.chdir(pwd)


def test_config_no_config_file():
if COG_PREDICT_TYPE_STUB_ENV_VAR in os.environ:
del os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR]
config = Config()
with pytest.raises(ConfigDoesNotExist):
_ = config.predictor_predict_ref


def test_config_initial_values():
if COG_PREDICT_TYPE_STUB_ENV_VAR in os.environ:
del os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR]
config = Config(config={"predict": "predict.py:Predictor"})
config_predict_ref = config.predictor_predict_ref
assert (
config_predict_ref == "predict.py:Predictor"
), "Predict Reference should come from the initial config dictionary."


def test_predictor_train_ref_env_var():
train_ref = "predict.py:Predictor"
os.environ[COG_TRAIN_TYPE_STUB_ENV_VAR] = train_ref
config = Config()
config_train_ref = config.predictor_train_ref
del os.environ[COG_TRAIN_TYPE_STUB_ENV_VAR]
assert (
config_train_ref == train_ref
), "Train Reference should come from the environment variable."


def test_predictor_train_ref_no_env_var():
train_ref = "predict.py:Predictor"
if COG_TRAIN_TYPE_STUB_ENV_VAR in os.environ:
del os.environ[COG_TRAIN_TYPE_STUB_ENV_VAR]
config = Config(config={"train": train_ref})
config_train_ref = config.predictor_train_ref
assert (
config_train_ref == train_ref
), "Train Reference should come from the initial config dictionary."


def test_requires_gpu_env_var():
gpu = True
os.environ[COG_GPU_ENV_VAR] = str(gpu)
config = Config()
config_gpu = config.requires_gpu
del os.environ[COG_GPU_ENV_VAR]
assert config_gpu, "Requires GPU should come from the environment variable."


def test_requires_gpu_no_env_var():
if COG_GPU_ENV_VAR in os.environ:
del os.environ[COG_GPU_ENV_VAR]
config = Config(config={"build": {"gpu": False}})
config_gpu = config.requires_gpu
assert (
not config_gpu
), "Requires GPU should come from the initial config dictionary."


def test_get_predictor_ref_predict():
train_ref = "predict.py:Predictor"
config = Config(config={"train": train_ref})
Expand All @@ -25,6 +123,33 @@ def test_get_predictor_ref_train():
), "The predict ref should equal the config predict ref."


def test_get_predictor_types_with_env_var():
predict_ref = "predict.py:Predictor"
os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR] = predict_ref
os.environ[COG_PREDICT_CODE_STRIP_ENV_VAR] = """
from cog import BasePredictor, Path
from typing import Optional
from pydantic import BaseModel
class ModelOutput(BaseModel):
success: bool
error: Optional[str]
segmentedImage: Optional[Path]
class Predictor(BasePredictor):
def predict(self, msg: str) -> ModelOutput:
return None
"""
config = Config()
input_type, output_type = config.get_predictor_types(Mode.PREDICT)
del os.environ[COG_PREDICT_CODE_STRIP_ENV_VAR]
del os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR]
assert (
str(input_type) == "<class 'cog.predictor.Input'>"
), "Predict input type should be the predictor Input."
assert (
str(output_type) == "<class 'cog.predictor.get_output_type.<locals>.Output'>"
), "Predict output type should be the predictor Output."


def test_get_predictor_types():
with tempfile.TemporaryDirectory() as tmpdir:
predict_python_file = os.path.join(tmpdir, "predict.py")
Expand Down

0 comments on commit c72742f

Please sign in to comment.