Skip to content

Commit

Permalink
Merge pull request #214 from pass-culture/staging
Browse files Browse the repository at this point in the history
MEP #20240830
  • Loading branch information
cdelabre authored Aug 30, 2024
2 parents 1d7fe13 + 41cdf99 commit 1f5dc3a
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 57 deletions.
26 changes: 18 additions & 8 deletions apps/recommendation/api/src/huggy/core/endpoint/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from abc import ABC
from dataclasses import dataclass
from typing import Optional

from huggy.schemas.item import RecommendableItem
from aiocache import Cache
from aiocache.serializers import PickleSerializer
from huggy.utils.hash import hash_from_keys


@dataclass
class PredictionResultIem:
model_version: str
model_display_name: str
recommendable_items: list[RecommendableItem]
VERTEX_CACHE = Cache(
Cache.MEMORY, ttl=6000, serializer=PickleSerializer(), namespace="vertex_cache"
)


class AbstractEndpoint(ABC): # noqa: B024
Expand Down Expand Up @@ -45,3 +44,14 @@ async def to_dict(self):
"cached": self.cached,
"use_cache": self.use_cache,
}

@staticmethod
def _get_instance_hash(instance: dict, ignore_keys: Optional[list] = None) -> str:
    """Build a deterministic cache key from *instance*.

    Keys listed in *ignore_keys* (default: ``["call_id"]``, which changes on
    every request) are excluded so identical payloads hash identically.
    """
    excluded = ["call_id"] if ignore_keys is None else ignore_keys
    hashed_keys = [key for key in instance if key not in excluded]
    return hash_from_keys(instance, keys=hashed_keys)
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
from abc import abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

from aiocache import Cache
from aiocache.serializers import PickleSerializer
from fastapi.encoders import jsonable_encoder
from huggy.core.endpoint import AbstractEndpoint, PredictionResultIem
from huggy.core.endpoint import VERTEX_CACHE, AbstractEndpoint
from huggy.schemas.item import RecommendableItem
from huggy.schemas.offer import Offer
from huggy.schemas.playlist_params import PlaylistParams
from huggy.schemas.user import UserContext
from huggy.utils.cloud_logging import logger
from huggy.utils.hash import hash_from_keys
from huggy.utils.vertex_ai import endpoint_score

VERTEX_CACHE = Cache(
Cache.MEMORY, ttl=6000, serializer=PickleSerializer(), namespace="vertex_cache"
)

@dataclass
class RetrievalPredictionResultItem:
    """Result of a retrieval-endpoint call, suitable for caching as one unit."""

    # Version identifier of the model that produced the prediction (for logging).
    model_version: str
    # Human-readable model name (for logging).
    model_display_name: str
    # Items returned by the retrieval endpoint.
    recommendable_items: list[RecommendableItem]


def to_datetime(ts):
Expand Down Expand Up @@ -108,18 +107,6 @@ def init_input(self, user: UserContext, params_in: PlaylistParams, call_id: str)
self.user_input = str(self.user.user_id)
self.is_geolocated = self.user.is_geolocated

def _get_instance_hash(
self, instance: dict, ignore_keys: Optional[list] = None
) -> str:
"""
Generate a hash from the instance to use as a key for caching
"""
# drop call_id from instance
if ignore_keys is None:
ignore_keys = ["call_id"]
keys = [k for k in instance if k not in ignore_keys]
return hash_from_keys(instance, keys=keys)

@abstractmethod
def get_instance(self, size):
pass
Expand Down Expand Up @@ -205,14 +192,16 @@ def get_params(self):

return filters

async def _vertex_retrieval_score(self, instance: dict) -> PredictionResultIem:
async def _vertex_retrieval_score(
self, instance: dict
) -> RetrievalPredictionResultItem:
prediction_result = await endpoint_score(
instances=instance,
endpoint_name=self.endpoint_name,
fallback_endpoints=self.fallback_endpoints,
)

return PredictionResultIem(
return RetrievalPredictionResultItem(
model_version=prediction_result.model_version,
model_display_name=prediction_result.model_display_name,
recommendable_items=[
Expand Down Expand Up @@ -248,29 +237,27 @@ async def _vertex_retrieval_score(self, instance: dict) -> PredictionResultIem:
)

async def model_score(self) -> list[RecommendableItem]:
    """Return recommendable items for the current instance, with caching.

    Looks up ``VERTEX_CACHE`` first (key: endpoint name + hash of the
    instance, ignoring ``call_id``); on a miss, calls the retrieval
    endpoint and stores the result. Also records the model version and
    display name on ``self`` for logging purposes.
    """
    # TODO same for ranking
    result: Optional[RetrievalPredictionResultItem] = None
    instance = self.get_instance(self.size)

    # Retrieve cached result if caching is enabled and a match exists.
    cache_key = None
    if self.use_cache:
        instance_hash = self._get_instance_hash(instance)
        cache_key = f"{self.endpoint_name}:{instance_hash}"
        result = await VERTEX_CACHE.get(cache_key)

    if result is not None:
        self.cached = True
        self._log_cache_usage(cache_key, "Used")
    # Compute retrieval if cache not found or not used
    else:
        result = await self._vertex_retrieval_score(instance)
        if self.use_cache and result is not None:
            await VERTEX_CACHE.set(cache_key, result)
            self._log_cache_usage(cache_key, "Set")

    # set model version and model display name for logging purposes
    self.model_display_name = result.model_display_name
    self.model_version = result.model_version

    return result.recommendable_items

def _log_cache_usage(self, cache_key: str, action: str) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@
from huggy.schemas.offer import Offer
from huggy.schemas.user import UserContext
from huggy.schemas.utils import parse_input
from huggy.utils.env_vars import DEFAULT_RECO_MODEL, DEFAULT_SIMILAR_OFFER_MODEL
from huggy.utils.env_vars import (
DEFAULT_RECO_MODEL,
DEFAULT_RECO_MODEL_DESCRIPTION,
DEFAULT_SIMILAR_OFFER_DESCRIPTION,
DEFAULT_SIMILAR_OFFER_MODEL,
)
from pydantic import ValidationError

RECOMMENDATION_ENDPOINTS = {
# Default endpoint
"default": RecoModelConfigurationInput(
name="default",
description="""Default Configuration.""",
description=DEFAULT_RECO_MODEL_DESCRIPTION,
diversification_params=DiversificationParamsInput(
diversication_type=DiversificationChoices.ON,
),
Expand Down Expand Up @@ -162,7 +167,7 @@
SIMILAR_OFFER_ENDPOINTS = {
"default": SimilarModelConfigurationInput(
name="default",
description="""Default similar offer configuration (cache).""",
description=DEFAULT_SIMILAR_OFFER_DESCRIPTION,
diversification_params=DiversificationParamsInput(
diversication_type=DiversificationChoices.OFF,
),
Expand Down
53 changes: 44 additions & 9 deletions apps/recommendation/api/src/huggy/core/scorer/offer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import asyncio
import typing as t
from dataclasses import dataclass

import huggy.schemas.item as i
import huggy.schemas.offer as o
import huggy.schemas.playlist_params as pp
import huggy.schemas.recommendable_offer as r_o
import huggy.schemas.user as u
from aiocache import Cache
from aiocache.serializers import PickleSerializer
from huggy.core.endpoint.ranking_endpoint import RankingEndpoint
from huggy.core.endpoint.retrieval_endpoint import RetrievalEndpoint
from huggy.crud.non_recommendable_offer import get_non_recommendable_items
Expand All @@ -14,9 +17,19 @@
from huggy.utils.cloud_logging import logger
from huggy.utils.distance import haversine_distance
from huggy.utils.exception import log_error
from huggy.utils.hash import hash_from_keys
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.ext.asyncio import AsyncSession

OFFER_DB_CACHE = Cache(
Cache.MEMORY, ttl=6000, serializer=PickleSerializer(), namespace="offer_db_cache"
)


@dataclass
class RecommendableOfferResult:
    """Wrapper around a list of recommendable offers so the DB lookup result
    can be stored/retrieved from ``OFFER_DB_CACHE`` as a single object."""

    # Offers returned by the nearest-offers query (singular name kept for
    # compatibility with existing callers).
    recommendable_offer: list[r_o.RecommendableOffer]


class OfferScorer:
def __init__(
Expand All @@ -35,6 +48,8 @@ def __init__(
self.params_in = params_in
self.retrieval_endpoints = retrieval_endpoints
self.ranking_endpoint = ranking_endpoint
self.use_cache = True
self.offer_cached = False

async def to_dict(self):
return {
Expand Down Expand Up @@ -84,6 +99,8 @@ async def get_scoring(
"call_id": call_id,
"user_id": self.user.user_id,
"total_offers": len(recommendable_offers),
"offer_cached": self.offer_cached,
"use_cache": self.use_cache,
},
)

Expand Down Expand Up @@ -118,13 +135,31 @@ async def get_recommendable_offers(
for item in recommendable_items
if item.item_id not in non_recommendable_items and item.item_id is not None
}
return await self.get_nearest_offers(
db,
self.user,
recommendable_items_ids,
offer=self.offer,
query_order=self.model_params.query_order,
)

result: RecommendableOfferResult = None
if self.use_cache:
instance_hash = hash_from_keys(
{"item_ids": sorted(recommendable_items_ids.keys())}
)
cache_key = f"{self.user.iris_id}:{instance_hash}"
result: RecommendableOfferResult = await OFFER_DB_CACHE.get(cache_key)

if result is not None:
self.offer_cached = True

else:
# get nearest offers
result = await self.get_nearest_offers(
db,
self.user,
recommendable_items_ids,
offer=self.offer,
query_order=self.model_params.query_order,
)
if self.use_cache and result is not None:
await OFFER_DB_CACHE.set(cache_key, result)

return result.recommendable_offer

async def get_distance(
self,
Expand Down Expand Up @@ -163,7 +198,7 @@ async def get_nearest_offers(
limit: int = 500,
offer: t.Optional[o.Offer] = None,
query_order: QueryOrderChoices = QueryOrderChoices.ITEM_RANK,
) -> list[r_o.RecommendableOffer]:
) -> RecommendableOfferResult:
recommendable_offers = []
multiple_item_offers = []
for v in recommendable_items_ids.values():
Expand Down Expand Up @@ -207,4 +242,4 @@ async def get_nearest_offers(

except ProgrammingError as exc:
log_error(exc, message="Exception error on get_nearest_offers")
return recommendable_offers
return RecommendableOfferResult(recommendable_offer=recommendable_offers)
14 changes: 13 additions & 1 deletion apps/recommendation/api/src/huggy/utils/env_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,21 @@

# config
DEFAULT_SIMILAR_OFFER_MODEL = os.environ.get("DEFAULT_SIMILAR_OFFER_MODEL", "default")
DEFAULT_SIMILAR_OFFER_DESCRIPTION = os.environ.get(
"DEFAULT_SIMILAR_OFFER_DESCRIPTION", "Similar Offer Configuration (default)"
)

DEFAULT_RECO_MODEL = os.environ.get("DEFAULT_RECO_MODEL", "default")
DEFAULT_RECO_MODEL_DESCRIPTION = os.environ.get(
"DEFAULT_RECO_MODEL_DESCRIPTION", "Recommendation Configuration (default)"
)

# endpoints
RANKING_VERSION_B_ENDPOINT_NAME = os.environ.get(
"RANKING_VERSION_B_ENDPOINT_NAME",
f"recommendation_user_ranking_version_b_{ENV_SHORT_NAME}",
)
NUMBER_OF_RECOMMENDATIONS = 40
# Maximum number of recommendations to return. os.environ.get returns a str
# when the variable is set (but the int default 40 when unset), so cast to
# int to guarantee a consistent type for downstream arithmetic/slicing.
NUMBER_OF_RECOMMENDATIONS = int(
    os.environ.get(
        "NUMBER_OF_RECOMMENDATIONS",
        40,
    )
)

0 comments on commit 1f5dc3a

Please sign in to comment.