Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(DE-233) feat(api-reco): include by default list of items for retrieval + add factory #239

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import abstractmethod
cdelabre marked this conversation as resolved.
Show resolved Hide resolved
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

from fastapi.encoders import jsonable_encoder
from huggy.core.endpoint import VERTEX_CACHE, AbstractEndpoint
Expand Down Expand Up @@ -271,12 +272,12 @@ def _log_cache_usage(self, cache_key: str, action: str) -> None:
)


class FilterRetrievalEndpoint(RetrievalEndpoint):
MODEL_TYPE = "filter"
class BookingNumberRetrievalEndpoint(RetrievalEndpoint):
MODEL_TYPE = "tops"

def get_instance(self, size: int):
return {
"model_type": "filter",
"model_type": "tops",
"size": size,
"params": self.get_params(),
"call_id": self.call_id,
Expand All @@ -287,11 +288,11 @@ def get_instance(self, size: int):


class CreationTrendRetrievalEndpoint(RetrievalEndpoint):
MODEL_TYPE = "filter"
MODEL_TYPE = "tops"

def get_instance(self, size: int):
return {
"model_type": "filter",
"model_type": "tops",
"size": size,
"params": self.get_params(),
"call_id": self.call_id,
Expand All @@ -302,11 +303,11 @@ def get_instance(self, size: int):


class ReleaseTrendRetrievalEndpoint(RetrievalEndpoint):
MODEL_TYPE = "filter"
MODEL_TYPE = "tops"

def get_instance(self, size: int):
return {
"model_type": "filter",
"model_type": "tops",
"size": size,
"params": self.get_params(),
"call_id": self.call_id,
Expand Down Expand Up @@ -339,19 +340,16 @@ class OfferRetrievalEndpoint(RetrievalEndpoint):
def init_input(
self,
user: UserContext,
offer: Offer,
input_offers: Optional[list[Offer]],
params_in: PlaylistParams,
call_id: str,
):
self.user = user
self.offer = offer
self.input_offers = input_offers
self.call_id = call_id
if params_in.offers:
self.items = [offer.item_id for offer in params_in.offers]
else:
self.items = [str(self.offer.item_id)]
self.params_in = params_in
self.is_geolocated = self.offer.is_geolocated if self.offer else False
self.items = [offer.item_id for offer in self.input_offers]
cdelabre marked this conversation as resolved.
Show resolved Hide resolved
self.is_geolocated = any(offer.is_geolocated for offer in self.input_offers)

def get_instance(self, size: int):
return {
Expand All @@ -374,27 +372,28 @@ class OfferSemanticRetrievalEndpoint(OfferRetrievalEndpoint):
def get_instance(self, size: int):
return {
"model_type": "similar_offer",
"offer_id": str(self.item_id),
"items": self.items,
"size": size,
"params": self.get_params(),
"call_id": self.call_id,
"debug": 1,
"similarity_metric": "l2",
"prefilter": 1,
"vector_column_name": "raw_embeddings",
"user_id": str(self.user.user_id),
cdelabre marked this conversation as resolved.
Show resolved Hide resolved
}


class OfferFilterRetrievalEndpoint(OfferRetrievalEndpoint):
MODEL_TYPE = "filter"
class OfferBookingNumberRetrievalEndpoint(OfferRetrievalEndpoint):
MODEL_TYPE = "tops"

def get_instance(self, size: int):
return {
"model_type": "filter",
"model_type": "tops",
"size": size,
"params": self.get_params(),
"call_id": self.call_id,
"debug": 1,
"vector_column_name": "booking_number_desc",
"user_id": self.user.user_id,
"similarity_metric": "dot",
}
33 changes: 19 additions & 14 deletions apps/recommendation/api/src/huggy/core/model_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
from abc import ABC, abstractmethod
from typing import Optional

import huggy.schemas.offer as o
import pytz
Expand All @@ -26,7 +27,7 @@ class ModelEngine(ABC):
params_in (PlaylistParams): The playlist parameters.
call_id (str): The call ID.
context (str): The context.
offer (o.Offer, optional): The offer. Defaults to None.
input_offers (list[o.Offer], optional): The input offers. Defaults to None.
cdelabre marked this conversation as resolved.
Show resolved Hide resolved
reco_origin (str): The recommendation origin. One of "unknown", "cold_start", "algo".
model_origin (str): The model origin.
model_params (ModelConfiguration): The model configuration.
Expand All @@ -47,13 +48,10 @@ def __init__(
params_in: PlaylistParams,
call_id: str,
context: str,
offer: o.Offer = None,
input_offers: Optional[list[o.Offer]] = None,
):
self.user = user
self.offer = offer
self.offers = (
list(params_in.offers) if isinstance(params_in.offers, list) else [offer]
)
self.input_offers = input_offers
self.params_in = params_in
self.call_id = call_id
self.context = context
Expand Down Expand Up @@ -89,7 +87,7 @@ def get_scorer(self) -> OfferScorer:
model_params=self.model_params,
retrieval_endpoints=self.model_params.retrieval_endpoints,
ranking_endpoint=self.model_params.ranking_endpoint,
offer=self.offer,
input_offers=self.input_offers,
)

async def get_scoring(self, db: AsyncSession) -> list[str]:
Expand All @@ -109,7 +107,7 @@ async def get_scoring(self, db: AsyncSession) -> list[str]:
# apply diversification filter
if diversification_params.is_active:
scored_offers = order_offers_by_score_and_diversify_features(
offers=scored_offers,
scored_offers=scored_offers,
score_column=diversification_params.order_column,
score_order_ascending=diversification_params.order_ascending,
shuffle_recommendation=diversification_params.is_reco_shuffled,
Expand All @@ -121,7 +119,7 @@ async def get_scoring(self, db: AsyncSession) -> list[str]:
scoring_size = min(len(scored_offers), NUMBER_OF_RECOMMENDATIONS)
await self.save_context(
session=db,
offers=scored_offers[:scoring_size],
scored_offers=scored_offers[:scoring_size],
context=self.context,
user=self.user,
)
Expand All @@ -131,18 +129,25 @@ async def get_scoring(self, db: AsyncSession) -> list[str]:
async def save_context(
self,
session: AsyncSession,
offers: list[RankedOffer],
scored_offers: list[RankedOffer],
context: str,
user: UserContext,
) -> None:
if len(offers) > 0:
if len(scored_offers) > 0:
date = datetime.datetime.now(pytz.utc)
context_extra_data = await self.log_extra_data()
# add similar offer_id origin input.
if self.offer is not None:
context_extra_data["offer_origin_id"] = self.offer.offer_id
if self.input_offers is not None:
if len(self.input_offers) == 1:
context_extra_data["offer_origin_id"] = self.input_offers[
0
].offer_id
else:
context_extra_data["offer_origin_ids"] = ":".join(
[offer.offer_id for offer in self.input_offers]
)
cdelabre marked this conversation as resolved.
Show resolved Hide resolved

for idx, o in enumerate(offers):
for idx, o in enumerate(scored_offers):
session.add(
PastOfferContext(
call_id=self.call_id,
Expand Down
153 changes: 153 additions & 0 deletions apps/recommendation/api/src/huggy/core/model_engine/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from typing import Optional

import huggy.schemas.offer as o
from huggy.core.model_engine import ModelEngine
from huggy.core.model_engine.recommendation import Recommendation
from huggy.core.model_engine.similar_offer import SimilarOffer
from huggy.schemas.playlist_params import PlaylistParams
from huggy.schemas.user import UserContext
from sqlalchemy.ext.asyncio import AsyncSession


class ModelEngineOut:
    """Result wrapper returned by ``ModelEngineFactory.handle_prediction``.

    Bug fix: the original version only declared bare class-level
    annotations with no ``__init__``, so the existing call
    ``ModelEngineOut(model=..., results=...)`` raised ``TypeError``.
    An explicit constructor makes that instantiation work.
    """

    def __init__(self, model: "ModelEngine", results: list[str]) -> None:
        # model: the engine instance that actually produced the results
        self.model = model
        # results: ordered list of recommended item ids
        self.results = results


class ModelEngineFactory:
    """
    Factory for creating the appropriate model engine handler.

    Chooses between the similar-offer and recommendation engines based on
    the request context and the presence of seed offers, and can retry
    with a fallback recommendation model when the first pass is empty.
    """

    @staticmethod
    async def handle_prediction(
        db: AsyncSession,
        user: UserContext,
        params_in: PlaylistParams,
        call_id: str,
        context: str,
        *,
        use_fallback: bool,
        input_offers: Optional[list[o.Offer]] = None,
    ) -> ModelEngineOut:
        """
        Select the appropriate model engine, score it, and optionally fall
        back to the default recommendation model if no results are found.

        Args:
            db: Async database session used for scoring.
            user: Context of the user the prediction is for.
            params_in: Playlist parameters driving model selection.
            call_id: Identifier propagated to engines for tracing.
            context: Either "similar_offer" or "recommendation".
            use_fallback: When True, retry with the fallback model on
                empty results.
            input_offers: Optional seed offers; treated as an empty list
                when omitted.

        Returns:
            ModelEngineOut holding the engine actually used and its results.

        Raises:
            ValueError: If *context* is not a supported value.
        """
        input_offers = input_offers or []

        model_engine = ModelEngineFactory._determine_model_engine(
            user, params_in, call_id, context, input_offers
        )

        # Get results from the selected model engine.
        results = await model_engine.get_scoring(db)

        # Retry once with the generic fallback model when the primary
        # engine produced nothing and the caller opted into fallback.
        if use_fallback and not results:
            model_engine = await ModelEngineFactory._handle_fallback(
                user, params_in, call_id, input_offers
            )
            results = await model_engine.get_scoring(db)

        return ModelEngineOut(model=model_engine, results=results)

    @staticmethod
    def _determine_model_engine(
        user: UserContext,
        params_in: PlaylistParams,
        call_id: str,
        context: str,
        input_offers: list[o.Offer],
    ) -> ModelEngine:
        """
        Determine the appropriate model engine for *context*.

        Raises:
            ValueError: If *context* is not "similar_offer" or
                "recommendation". (Was a bare ``Exception``; ``ValueError``
                is more precise and still caught by existing
                ``except Exception`` handlers.)
        """
        if context == "similar_offer":
            return ModelEngineFactory._get_similar_offer_model(
                user, params_in, call_id, input_offers
            )
        if context == "recommendation":
            return ModelEngineFactory._get_recommendation_model(
                user, params_in, call_id, input_offers
            )
        raise ValueError(f"context {context} is not available")

    @staticmethod
    def _get_similar_offer_model(
        user: UserContext,
        params_in: PlaylistParams,
        call_id: str,
        input_offers: list[o.Offer],
    ) -> ModelEngine:
        """
        Select a model engine for the 'similar_offer' context.

        Sensitive seed offers must not drive similarity retrieval, and an
        empty seed list gives similarity nothing to work with; both cases
        fall back to the generic recommendation model.
        """
        if input_offers and not any(offer.is_sensitive for offer in input_offers):
            return SimilarOffer(
                user=user,
                params_in=params_in,
                call_id=call_id,
                context="similar_offer",
                input_offers=input_offers,
            )
        # Fallback path: sensitive offer present, or no seed offers at all.
        return Recommendation(
            user=user,
            params_in=params_in,
            call_id=call_id,
            context="recommendation_fallback",
        )

    @staticmethod
    def _get_recommendation_model(
        user: UserContext,
        params_in: PlaylistParams,
        call_id: str,
        input_offers: list[o.Offer],
    ) -> ModelEngine:
        """
        Select a model engine for the 'recommendation' context.

        Seed offers turn the request into a hybrid recommendation (similar
        offers to the seeds); otherwise the plain recommendation model runs.
        """
        if input_offers:
            return SimilarOffer(
                user=user,
                params_in=params_in,
                call_id=call_id,
                context="hybrid_recommendation",
                input_offers=input_offers,
            )
        return Recommendation(
            user=user,
            params_in=params_in,
            call_id=call_id,
            context="recommendation",
        )

    @staticmethod
    async def _handle_fallback(
        user: UserContext,
        params_in: PlaylistParams,
        call_id: str,
        input_offers: list[o.Offer],
    ) -> ModelEngine:
        """
        Build the 'recommendation_fallback' model engine.

        NOTE(review): declared ``async`` although it awaits nothing — kept
        async because ``handle_prediction`` awaits it; confirm whether a
        plain function was intended.
        """
        return Recommendation(
            user=user,
            params_in=params_in,
            call_id=call_id,
            context="recommendation_fallback",
            input_offers=input_offers,
        )
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from huggy.core.model_selection.model_configuration.configuration import ForkOut
from huggy.schemas.playlist_params import PlaylistParams
from huggy.schemas.user import UserContext
from sqlalchemy.ext.asyncio import AsyncSession


class SimilarOffer(ModelEngine):
Expand All @@ -25,15 +24,15 @@ def get_model_configuration(
self, user: UserContext, params_in: PlaylistParams
) -> ForkOut:
return select_sim_model_params(
params_in.model_endpoint, offer=self.offer, offers=params_in.offers
params_in.model_endpoint, input_offers=self.input_offers
)

def get_scorer(self):
# init input
for endpoint in self.model_params.retrieval_endpoints:
endpoint.init_input(
user=self.user,
offer=self.offer,
input_offers=self.input_offers,
params_in=self.params_in,
call_id=self.call_id,
)
Expand All @@ -49,14 +48,5 @@ def get_scorer(self):
model_params=self.model_params,
retrieval_endpoints=self.model_params.retrieval_endpoints,
ranking_endpoint=self.model_params.ranking_endpoint,
offer=self.offer,
input_offers=self.input_offers,
)

async def get_scoring(self, db: AsyncSession) -> list[str]:
if self.offer is not None and self.offer.item_id is None:
return []
if (
len(self.offers) > 0 and len([offer.item_id for offer in self.offers]) == 0
): # to review
return []
return await super().get_scoring(db)
Loading
Loading