Skip to content

Commit

Permalink
chore: Upgrade to Numaflow 0.11 and Fix: Trainer Pod getting the slav…
Browse files Browse the repository at this point in the history
…e redis client connection (#173)
  • Loading branch information
shashank10456 authored Nov 22, 2023
1 parent e8c600b commit e7063e9
Show file tree
Hide file tree
Showing 19 changed files with 574 additions and 561 deletions.
9 changes: 9 additions & 0 deletions numaprom/clients/sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def get_redis_client(
mastername: str,
recreate: bool = False,
master_node: bool = True,
reset: bool = False,
) -> redis_client_t:
"""Return a master redis client for sentinel connections, with retry.
Expand All @@ -40,7 +41,12 @@ def get_redis_client(
"""
global SENTINEL_CLIENT

if reset:
LOGGER.info("Reset Sentinel Client to None")
SENTINEL_CLIENT = None

if not recreate and SENTINEL_CLIENT:
LOGGER.info("Reusing Existing Sentinel Client")
return SENTINEL_CLIENT

retry = Retry(
Expand Down Expand Up @@ -70,14 +76,17 @@ def get_redis_client(
**conn_kwargs
)
if master_node:
LOGGER.info("Creating Master Sentinel Redis Client")
SENTINEL_CLIENT = sentinel.master_for(mastername)
else:
LOGGER.info("Creating Slave Sentinel Redis Client")
SENTINEL_CLIENT = sentinel.slave_for(mastername)
LOGGER.info(
"Sentinel redis params: {args}, master_node: {is_master}",
args=conn_kwargs,
is_master=master_node,
)

return SENTINEL_CLIENT


Expand Down
4 changes: 2 additions & 2 deletions numaprom/factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Callable

from pynumaflow.function import Messages
from pynumaflow.sink import Responses
from pynumaflow.mapper import Messages
from pynumaflow.sinker import Responses

from numaprom.udf import preprocess, postprocess, window, inference, threshold
from numaprom.udsink import train, train_rollout
Expand Down
2 changes: 1 addition & 1 deletion numaprom/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytz
from numalogic.config import PostprocessFactory
from numalogic.models.threshold import SigmoidThreshold
from pynumaflow.function import Messages, Message
from pynumaflow.mapper import Messages, Message
from numaprom import LOGGER, MetricConf
from numaprom.clients.prometheus import Prometheus
from numaprom.entities import TrainerPayload, StreamPayload
Expand Down
2 changes: 1 addition & 1 deletion numaprom/udf/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from numalogic.tools.data import StreamingDataset
from numalogic.tools.exceptions import RedisRegistryError
from orjson import orjson
from pynumaflow.function import Datum
from pynumaflow.mapper import Datum
from torch.utils.data import DataLoader

from numaprom import LOGGER
Expand Down
2 changes: 1 addition & 1 deletion numaprom/udf/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
from orjson import orjson
from pynumaflow.function import Datum
from pynumaflow.mapper import Datum
from redis.exceptions import RedisError, RedisClusterException
from redis.sentinel import MasterNotFoundError

Expand Down
3 changes: 2 additions & 1 deletion numaprom/udf/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import orjson
from numalogic.registry import RedisRegistry, LocalLRUCache
from numalogic.tools.exceptions import RedisRegistryError
from pynumaflow.function import Datum
from pynumaflow.mapper import Datum


from numaprom import LOGGER
from numaprom.clients.sentinel import get_redis_client
Expand Down
2 changes: 1 addition & 1 deletion numaprom/udf/threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numalogic.registry import RedisRegistry, LocalLRUCache
from numalogic.tools.exceptions import RedisRegistryError
from orjson import orjson
from pynumaflow.function import Datum
from pynumaflow.mapper import Datum

from numaprom import LOGGER
from numaprom._constants import TRAIN_VTX_KEY, POSTPROC_VTX_KEY
Expand Down
2 changes: 1 addition & 1 deletion numaprom/udf/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import numpy.typing as npt
from orjson import orjson
from pynumaflow.function import Datum
from pynumaflow.mapper import Datum
from redis.exceptions import RedisError, RedisClusterException

from numaprom import LOGGER
Expand Down
2 changes: 1 addition & 1 deletion numaprom/udsink/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from numalogic.tools.exceptions import RedisRegistryError
from numalogic.tools.types import redis_client_t
from orjson import orjson
from pynumaflow.sink import Datum, Responses, Response
from pynumaflow.sinker import Datum, Responses, Response
from sklearn.pipeline import make_pipeline
from torch.utils.data import DataLoader

Expand Down
7 changes: 5 additions & 2 deletions numaprom/udsink/train_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from numalogic.tools.exceptions import RedisRegistryError
from numalogic.tools.types import redis_client_t
from orjson import orjson
from pynumaflow.sink import Datum, Responses, Response
from pynumaflow.sinker import Datum, Responses, Response
from sklearn.pipeline import make_pipeline
from torch.utils.data import DataLoader

Expand All @@ -23,6 +23,8 @@
from numaprom.watcher import ConfigManager

REQUEST_EXPIRY = int(os.getenv("REQUEST_EXPIRY", "300"))
# REDIS_CLIENT = get_redis_client_from_conf(master_node=True, recreate=True)
REDIS_CLIENT_MASTER = get_redis_client_from_conf(master_node=True, reset=True)


# TODO: extract all good hashes, including when there are 2 hashes at a time
Expand Down Expand Up @@ -107,8 +109,9 @@ def get_model_config(metric_config):


def train_rollout(datums: Iterator[Datum]) -> Responses:
global REDIS_CLIENT_MASTER
redis_client = REDIS_CLIENT_MASTER
responses = Responses()
redis_client = get_redis_client_from_conf()

for _datum in datums:
payload = TrainerPayload(**orjson.loads(_datum.value))
Expand Down
Loading

0 comments on commit e7063e9

Please sign in to comment.