Skip to content

Commit

Permalink
[AB-42](ab-test) Update config for retrieval version b endpoint (#231)
Browse files Browse the repository at this point in the history
  • Loading branch information
cdelabre authored Oct 11, 2024
1 parent 565ffc7 commit 61b8713
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def get_instance(self, size: int):
"call_id": self.call_id,
"debug": 1,
"vector_column_name": "booking_number_desc",
"similarity_metric": "dot",
}


Expand All @@ -296,6 +297,7 @@ def get_instance(self, size: int):
"call_id": self.call_id,
"debug": 1,
"vector_column_name": "booking_creation_trend_desc",
"similarity_metric": "dot",
}


Expand All @@ -310,6 +312,7 @@ def get_instance(self, size: int):
"call_id": self.call_id,
"debug": 1,
"vector_column_name": "booking_release_trend_desc",
"similarity_metric": "dot",
}


Expand All @@ -326,6 +329,7 @@ def get_instance(self, size: int):
"debug": 1,
"prefilter": 1,
"vector_column_name": "raw_embeddings",
"similarity_metric": "dot",
}


Expand Down Expand Up @@ -360,6 +364,7 @@ def get_instance(self, size: int):
"similarity_metric": "l2",
"prefilter": 1,
"vector_column_name": "raw_embeddings",
"user_id": str(self.user.user_id),
}


Expand Down Expand Up @@ -391,4 +396,5 @@ def get_instance(self, size: int):
"call_id": self.call_id,
"debug": 1,
"vector_column_name": "booking_number_desc",
"similarity_metric": "dot",
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
query_order=QueryOrderChoices.ITEM_RANK,
),
cold_start_model_type=ModelTypeInput(
retrieval=RetrievalChoices.TOPS,
retrieval=RetrievalChoices.MIX_TOPS,
ranking=RankingChoices.MODEL,
query_order=QueryOrderChoices.ITEM_RANK,
),
Expand All @@ -62,13 +62,13 @@
diversication_type=DiversificationChoices.ON,
),
warn_model_type=ModelTypeInput(
retrieval=RetrievalChoices.MIX,
ranking=RankingChoices.VERSION_B,
retrieval=RetrievalChoices.MIX_VERSION_B,
ranking=RankingChoices.MODEL,
query_order=QueryOrderChoices.ITEM_RANK,
),
cold_start_model_type=ModelTypeInput(
retrieval=RetrievalChoices.TOPS,
ranking=RankingChoices.VERSION_B,
retrieval=RetrievalChoices.MIX_TOPS_VERSION_B,
ranking=RankingChoices.MODEL,
query_order=QueryOrderChoices.ITEM_RANK,
),
fork_params=ForkParamsInput(
Expand Down Expand Up @@ -194,13 +194,13 @@
diversication_type=DiversificationChoices.OFF,
),
warn_model_type=ModelTypeInput(
retrieval=RetrievalChoices.MIX,
ranking=RankingChoices.VERSION_B,
retrieval=RetrievalChoices.MIX_VERSION_B,
ranking=RankingChoices.MODEL,
query_order=QueryOrderChoices.ITEM_RANK,
),
cold_start_model_type=ModelTypeInput(
retrieval=RetrievalChoices.MIX,
ranking=RankingChoices.VERSION_B,
retrieval=RetrievalChoices.MIX_VERSION_B,
ranking=RankingChoices.MODEL,
query_order=QueryOrderChoices.ITEM_RANK,
),
fork_params=ForkParamsInput(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
offer_retrieval_endpoint = OfferRetrievalEndpoint(
endpoint_name=RetrievalEndpointName.recommendation_user_retrieval,
size=100,
fallback_endpoints=[
RetrievalEndpointName.recommendation_semantic_retrieval,
],
use_cache=True,
)

offer_retrieval_endpoint_version_b = OfferRetrievalEndpoint(
endpoint_name=RetrievalEndpointName.recommendation_user_retrieval_version_b,
size=100,
use_cache=True,
)

semantic_offer_retrieval_endpoint = OfferSemanticRetrievalEndpoint(
endpoint_name=RetrievalEndpointName.recommendation_semantic_retrieval,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
)
from huggy.core.model_selection.endpoint import RetrievalEndpointName

# default
filter_retrieval_endpoint = FilterRetrievalEndpoint(
endpoint_name=RetrievalEndpointName.recommendation_user_retrieval,
size=150,
Expand All @@ -18,16 +19,39 @@
use_cache=True,
)


trend_release_date_retrieval_endpoint = ReleaseTrendRetrievalEndpoint(
endpoint_name=RetrievalEndpointName.recommendation_user_retrieval,
size=150,
use_cache=True,
)


trend_creation_date_retrieval_endpoint = CreationTrendRetrievalEndpoint(
endpoint_name=RetrievalEndpointName.recommendation_user_retrieval,
size=150,
use_cache=True,
)

# version B
filter_retrieval_endpoint_version_b = FilterRetrievalEndpoint(
endpoint_name=RetrievalEndpointName.recommendation_user_retrieval_version_b,
size=150,
use_cache=True,
)

recommendation_retrieval_endpoint_version_b = RecommendationRetrievalEndpoint(
endpoint_name=RetrievalEndpointName.recommendation_user_retrieval_version_b,
size=150,
use_cache=True,
)

trend_release_date_retrieval_endpoint_version_b = ReleaseTrendRetrievalEndpoint(
endpoint_name=RetrievalEndpointName.recommendation_user_retrieval_version_b,
size=150,
use_cache=True,
)

trend_creation_date_retrieval_endpoint_version_b = CreationTrendRetrievalEndpoint(
endpoint_name=RetrievalEndpointName.recommendation_user_retrieval_version_b,
size=150,
use_cache=True,
)
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,22 @@ def get_retrieval(self, model_type) -> list[RetrievalEndpoint]:
]
return {
RetrievalChoices.MIX: default,
RetrievalChoices.MIX_VERSION_B: [
user_retrieval.filter_retrieval_endpoint_version_b,
user_retrieval.recommendation_retrieval_endpoint_version_b,
user_retrieval.trend_release_date_retrieval_endpoint_version_b,
user_retrieval.trend_creation_date_retrieval_endpoint_version_b,
],
RetrievalChoices.MIX_TOPS: [
user_retrieval.filter_retrieval_endpoint,
user_retrieval.trend_release_date_retrieval_endpoint,
user_retrieval.trend_creation_date_retrieval_endpoint,
],
RetrievalChoices.MIX_TOPS_VERSION_B: [
user_retrieval.filter_retrieval_endpoint_version_b,
user_retrieval.trend_release_date_retrieval_endpoint_version_b,
user_retrieval.trend_creation_date_retrieval_endpoint_version_b,
],
RetrievalChoices.MIX_RECOMMENDATION: [
user_retrieval.recommendation_retrieval_endpoint,
user_retrieval.filter_retrieval_endpoint,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,13 @@ class SimilarModelConfigurationInput(ModelConfigurationInput):
def get_retrieval(self, model_type) -> list[RetrievalEndpoint]:
default = [
offer_retrieval.offer_retrieval_endpoint,
# offer_retrieval.semantic_offer_retrieval_endpoint,
]
return {
RetrievalChoices.MIX: default,
RetrievalChoices.SEMANTIC: [
offer_retrieval.semantic_offer_retrieval_endpoint,
RetrievalChoices.MIX_VERSION_B: [
offer_retrieval.offer_retrieval_endpoint_version_b
],
RetrievalChoices.MIX_TOPS: [
offer_retrieval.offer_retrieval_endpoint,
offer_retrieval.offer_filter_retrieval_endpoint,
RetrievalChoices.SEMANTIC: [
offer_retrieval.semantic_offer_retrieval_endpoint,
],
}.get(model_type, default)
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class RetrievalChoices(Enum):
RAW_RECOMMENDATION = "raw_reco"
RECOMMENDATION_VERSION_B = "reco_b"
MIX_VERSION_B = "mix_b"
MIX_TOPS_VERSION_B = "m&t_b"
SEMANTIC = "sm"


Expand Down

0 comments on commit 61b8713

Please sign in to comment.