From bf0abea7c31e8709c132703b4d7bd162fe9ffd82 Mon Sep 17 00:00:00 2001 From: mkhe93 Date: Sat, 30 Nov 2024 16:18:21 +0100 Subject: [PATCH 1/3] [patch] Extended the existing ItemKNN approach to include UserKNN by passing the appropriate knn_method parameter in the config. --- recbole/model/general_recommender/itemknn.py | 40 +++++++++++++------- recbole/properties/model/ItemKNN.yaml | 3 +- tests/model/test_model_auto.py | 8 ++++ 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/recbole/model/general_recommender/itemknn.py b/recbole/model/general_recommender/itemknn.py index 47aef23c3..ed6711730 100644 --- a/recbole/model/general_recommender/itemknn.py +++ b/recbole/model/general_recommender/itemknn.py @@ -20,7 +20,7 @@ class ComputeSimilarity: - def __init__(self, dataMatrix, topk=100, shrink=0, normalize=True): + def __init__(self, dataMatrix, topk=100, shrink=0, method='item', normalize=True): r"""Computes the cosine similarity of dataMatrix If it is computed on :math:`URM=|users| \times |items|`, pass the URM. @@ -31,6 +31,7 @@ def __init__(self, dataMatrix, topk=100, shrink=0, normalize=True): dataMatrix (scipy.sparse.csr_matrix): The sparse data matrix. topk (int) : The k value in KNN. shrink (int) : hyper-parameter in calculate cosine distance. + method (str) : Calculate the similarity of users if method is 'user', otherwise, calculate the similarity of items. normalize (bool): If True divide the dot product by the product of the norms. """ @@ -38,17 +39,21 @@ def __init__(self, dataMatrix, topk=100, shrink=0, normalize=True): self.shrink = shrink self.normalize = normalize + self.method = method self.n_rows, self.n_columns = dataMatrix.shape - self.TopK = min(topk, self.n_columns) + + if self.method == 'user': + self.TopK = min(topk, self.n_rows) + else: + self.TopK = min(topk, self.n_columns) self.dataMatrix = dataMatrix.copy() - def compute_similarity(self, method, block_size=100): + def compute_similarity(self, block_size=100): r"""Compute the similarity for the given dataset Args: - method (str) : Caculate the similarity of users if method is 'user', otherwise, calculate the similarity of items. block_size (int): divide matrix to :math:`n\_rows \div block\_size` to calculate cosine_distance if method is 'user', otherwise, divide matrix to :math:`n\_columns \div block\_size`. 
@@ -68,10 +73,10 @@ def compute_similarity(self, method, block_size=100): self.dataMatrix = self.dataMatrix.astype(np.float32) # Compute sum of squared values to be used in normalization - if method == "user": + if self.method == "user": sumOfSquared = np.array(self.dataMatrix.power(2).sum(axis=1)).ravel() end_local = self.n_rows - elif method == "item": + elif self.method == "item": sumOfSquared = np.array(self.dataMatrix.power(2).sum(axis=0)).ravel() end_local = self.n_columns else: @@ -86,7 +91,7 @@ def compute_similarity(self, method, block_size=100): this_block_size = end_block - start_block # All data points for a given user or item - if method == "user": + if self.method == "user": data = self.dataMatrix[start_block:end_block, :] else: data = self.dataMatrix[:, start_block:end_block] @@ -94,7 +99,7 @@ def compute_similarity(self, method, block_size=100): # Compute similarities - if method == "user": + if self.method == "user": this_block_weights = self.dataMatrix.dot(data.T) else: this_block_weights = self.dataMatrix.T.dot(data) @@ -134,7 +139,7 @@ def compute_similarity(self, method, block_size=100): numNotZeros = np.sum(notZerosMask) values.extend(this_line_weights[top_k_idx][notZerosMask]) - if method == "user": + if self.method == "user": rows.extend(np.ones(numNotZeros) * Index) cols.extend(top_k_idx[notZerosMask]) else: @@ -144,7 +149,7 @@ def compute_similarity(self, method, block_size=100): start_block += block_size # End while - if method == "user": + if self.method == "user": W_sparse = sp.csr_matrix( (values, (rows, cols)), shape=(self.n_rows, self.n_rows), @@ -160,7 +165,9 @@ def compute_similarity(self, method, block_size=100): class ItemKNN(GeneralRecommender): - r"""ItemKNN is a basic model that compute item similarity with the interaction matrix.""" + r"""ItemKNN is a basic model that compute item similarity with the interaction matrix. + Adjusting the value of 'knn_method' in the config file sets the method to either ItemKNN or UserKNN, respectively. + """ input_type = InputType.POINTWISE type = ModelType.TRADITIONAL @@ -170,15 +177,20 @@ def __init__(self, config, dataset): # load parameters info self.k = config["k"] + self.method = config["knn_method"] self.shrink = config["shrink"] if "shrink" in config else 0.0 self.interaction_matrix = dataset.inter_matrix(form="csr").astype(np.float32) shape = self.interaction_matrix.shape assert self.n_users == shape[0] and self.n_items == shape[1] _, self.w = ComputeSimilarity( - self.interaction_matrix, topk=self.k, shrink=self.shrink - ).compute_similarity("item") - self.pred_mat = self.interaction_matrix.dot(self.w).tolil() + self.interaction_matrix, topk=self.k, shrink=self.shrink, method=self.method + ).compute_similarity() + + if self.method == "user": + self.pred_mat = self.w.dot(self.interaction_matrix).tolil() + else: + self.pred_mat = self.interaction_matrix.dot(self.w).tolil() self.fake_loss = torch.nn.Parameter(torch.zeros(1)) self.other_parameter_name = ["w", "pred_mat"] diff --git a/recbole/properties/model/ItemKNN.yaml b/recbole/properties/model/ItemKNN.yaml index 9ce30000c..155f2f835 100644 --- a/recbole/properties/model/ItemKNN.yaml +++ b/recbole/properties/model/ItemKNN.yaml @@ -1,2 +1,3 @@ k: 100 # (int) The neighborhood size. -shrink: 0.0 # (float) A normalization parameter to calculate cosine distance. \ No newline at end of file +shrink: 0.0 # (float) A normalization parameter to calculate cosine distance. 
+knn_method: 'item'          # (string) The method to calculate the similarity matrix ['item','user']
\ No newline at end of file
diff --git a/tests/model/test_model_auto.py b/tests/model/test_model_auto.py
index 9c18b56c6..965dc728e 100644
--- a/tests/model/test_model_auto.py
+++ b/tests/model/test_model_auto.py
@@ -39,6 +39,14 @@ def test_random(self):
     def test_itemknn(self):
         config_dict = {
             "model": "ItemKNN",
+            "knn_method": "item"
+        }
+        quick_test(config_dict)
+
+    def test_userknn(self):
+        config_dict = {
+            "model": "ItemKNN",
+            "knn_method": "user"
         }
         quick_test(config_dict)
 

From d72d4cbc7f317616178463e6996cd0b675ae48c4 Mon Sep 17 00:00:00 2001
From: mkhe93
Date: Sun, 1 Dec 2024 11:08:34 +0100
Subject: [PATCH 2/3] Implemented AsymKNN as in Aiolli (2013), https://dl.acm.org/doi/pdf/10.1145/2507157.2507189

---
 recbole/model/general_recommender/asymknn.py | 225 +++++++++++++++++++
 1 file changed, 225 insertions(+)
 create mode 100644 recbole/model/general_recommender/asymknn.py

diff --git a/recbole/model/general_recommender/asymknn.py b/recbole/model/general_recommender/asymknn.py
new file mode 100644
index 000000000..3a0820e49
--- /dev/null
+++ b/recbole/model/general_recommender/asymknn.py
@@ -0,0 +1,225 @@
+import numpy as np
+import scipy.sparse as sp
+import torch
+from recbole.model.abstract_recommender import GeneralRecommender
+from recbole.utils import InputType, ModelType
+
+class ComputeSimilarity:
+    def __init__(self, dataMatrix, topk=100, alpha=0.5, method='item'):
+        r"""Computes the asymmetric cosine similarity of dataMatrix with alpha parameter.
+
+        Args:
+            dataMatrix (scipy.sparse.csr_matrix): The sparse data matrix.
+            topk (int) : The k value in KNN.
+            alpha (float): Asymmetry control parameter in cosine similarity calculation.
+            method (str) : Calculate the similarity of users if method is 'user', otherwise, calculate the similarity of items.
+        """
+
+        super(ComputeSimilarity, self).__init__()
+
+        self.method = method
+        self.alpha = alpha
+
+        self.n_rows, self.n_columns = dataMatrix.shape
+
+        if self.method == 'user':
+            self.TopK = min(topk, self.n_rows)
+        else:
+            self.TopK = min(topk, self.n_columns)
+
+        self.dataMatrix = dataMatrix.copy()
+
+    def compute_asym_similarity(self, block_size=100):
+        r"""Compute the asymmetric cosine similarity for the given dataset.
+
+        Args:
+            block_size (int): Divide matrix into blocks for efficient calculation.
+
+        Returns:
+            list: The similar nodes, if method is 'user', the shape is [number of users, neigh_num],
+                  else, the shape is [number of items, neigh_num].
+            scipy.sparse.csr_matrix: sparse matrix W, if method is 'user', the shape is [self.n_rows, self.n_rows],
+                  else, the shape is [self.n_columns, self.n_columns].
+ """ + + values = [] + rows = [] + cols = [] + neigh = [] + + self.dataMatrix = self.dataMatrix.astype(np.float32) + + if self.method == "user": + sumOfMatrix = np.array(self.dataMatrix.sum(axis=1)).ravel() + end_local = self.n_rows + elif self.method == "item": + sumOfMatrix = np.array(self.dataMatrix.sum(axis=0)).ravel() + end_local = self.n_columns + else: + raise NotImplementedError("Make sure 'method' is in ['user', 'item']!") + + start_block = 0 + + # Compute all similarities using vectorization + while start_block < end_local: + end_block = min(start_block + block_size, end_local) + this_block_size = end_block - start_block + + # All data points for a given user or item + if self.method == "user": + data = self.dataMatrix[start_block:end_block, :] + else: + data = self.dataMatrix[:, start_block:end_block] + data = data.toarray() + + # Compute similarities + if self.method == "user": + this_block_weights = self.dataMatrix.dot(data.T) + else: + this_block_weights = self.dataMatrix.T.dot(data) + + for index_in_block in range(this_block_size): + this_line_weights = this_block_weights[:, index_in_block] + + Index = index_in_block + start_block + this_line_weights[Index] = 0.0 + + # Apply asymmetric cosine normalization + denominator = ( + (sumOfMatrix[Index] ** self.alpha) * + (sumOfMatrix ** (1 - self.alpha)) + 1e-6 + ) + this_line_weights = np.multiply(this_line_weights, 1 / denominator) + + # Sort indices and select TopK + relevant_partition = (-this_line_weights).argpartition(self.TopK - 1)[0:self.TopK] + relevant_partition_sorting = np.argsort(-this_line_weights[relevant_partition]) + top_k_idx = relevant_partition[relevant_partition_sorting] + neigh.append(top_k_idx) + + # Incrementally build sparse matrix, do not add zeros + notZerosMask = this_line_weights[top_k_idx] != 0.0 + numNotZeros = np.sum(notZerosMask) + + values.extend(this_line_weights[top_k_idx][notZerosMask]) + if self.method == "user": + rows.extend(np.ones(numNotZeros) * Index) + cols.extend(top_k_idx[notZerosMask]) + else: + rows.extend(top_k_idx[notZerosMask]) + cols.extend(np.ones(numNotZeros) * Index) + + start_block += block_size + + # End while + if self.method == "user": + W_sparse = sp.csr_matrix( + (values, (rows, cols)), + shape=(self.n_rows, self.n_rows), + dtype=np.float32, + ) + else: + W_sparse = sp.csr_matrix( + (values, (rows, cols)), + shape=(self.n_columns, self.n_columns), + dtype=np.float32, + ) + return neigh, W_sparse.tocsc() + + +class AsymKNN(GeneralRecommender): + r"""AsymKNN: A traditional recommender model based on asymmetric cosine similarity and score prediction. + + AsymKNN computes user-item recommendations by leveraging asymmetric cosine similarity + over the interaction matrix. This model allows for flexible adjustment of similarity + calculations and scoring normalization via several tunable parameters. + + Config: + k (int): Number of neighbors to consider in the similarity calculation. + method (str): Specifies whether to calculate similarities based on users or items. + Valid options are 'user' or 'item'. + alpha (float): Weight parameter for asymmetric cosine similarity, controlling + the importance of the interaction matrix in the similarity computation. + Must be in the range [0, 1]. + q (int): Exponent for adjusting the 'locality of scoring function' after similarity computation. + beta (float): Parameter for controlling the balance between factors in the + final score normalization. Must be in the range [0, 1]. + + Reference: + Aiolli,F et al. 
Efficient top-n recommendation for very large scale binary rated datasets.
+        In Proceedings of the 7th ACM conference on Recommender systems (pp. 273-280). ACM.
+    """
+
+    input_type = InputType.POINTWISE
+    type = ModelType.TRADITIONAL
+
+    def __init__(self, config, dataset):
+        super(AsymKNN, self).__init__(config, dataset)
+
+        # load parameters info
+        self.k = config["k"]  # Size of neighborhood for cosine
+        self.method = config["knn_method"]  # Calculate the similarity of users if method is 'user', otherwise, calculate the similarity of items.
+        self.alpha = config['alpha'] if 'alpha' in config else 0.5  # Asymmetric cosine parameter
+        self.q = config['q'] if 'q' in config else 1  # Weight adjustment exponent (kept as int to satisfy the assert below)
+        self.beta = config['beta'] if 'beta' in config else 0.5  # Beta for final score normalization
+
+        assert 0 <= self.alpha <= 1, f"The asymmetric parameter 'alpha' must be a value in [0,1], but got {self.alpha}"
+        assert 0 <= self.beta <= 1, f"The asymmetric parameter 'beta' must be a value in [0,1], but got {self.beta}"
+        assert isinstance(self.k, int), f"The neighborhood parameter 'k' must be an integer, but got {self.k}"
+        assert isinstance(self.q, int), f"The exponent parameter 'q' must be an integer, but got {self.q}"
+
+        self.interaction_matrix = dataset.inter_matrix(form="csr").astype(np.float32)
+        shape = self.interaction_matrix.shape
+        assert self.n_users == shape[0] and self.n_items == shape[1]
+        _, self.w = ComputeSimilarity(
+            self.interaction_matrix, topk=self.k, alpha=self.alpha, method=self.method
+        ).compute_asym_similarity()
+
+        if self.method == "user":
+            numerator = self.w.dot(self.interaction_matrix)
+            factor1 = np.power(np.sqrt(self.w.power(2).sum(axis=1)),2*self.beta)
+            factor2 = np.power(np.sqrt(self.interaction_matrix.power(2).sum(axis=0)),2*(1-self.beta))
+            denominator = factor1.dot(factor2) + 1e-6
+        else:
+            numerator = self.interaction_matrix.dot(self.w)
+            factor1 = np.power(np.sqrt(self.interaction_matrix.power(2).sum(axis=1)),2*self.beta)
+            factor2 = np.power(np.sqrt(self.w.power(2).sum(axis=1)),2*(1-self.beta))
+            denominator = factor1.dot(factor2.T) + 1e-6
+
+        self.pred_mat = (numerator / denominator).tolil()
+
+        # Apply 'locality of scoring function' via q: f(w) = w^q
+        self.pred_mat = self.pred_mat.power(self.q)
+
+        self.fake_loss = torch.nn.Parameter(torch.zeros(1))
+        self.other_parameter_name = ["w", "pred_mat"]
+
+    def forward(self, user, item):
+        pass
+
+    def calculate_loss(self, interaction):
+        return torch.nn.Parameter(torch.zeros(1))
+
+    def predict(self, interaction):
+        user = interaction[self.USER_ID]
+        item = interaction[self.ITEM_ID]
+        user = user.cpu().numpy().astype(int)
+        item = item.cpu().numpy().astype(int)
+        result = []
+
+        for index in range(len(user)):
+            uid = user[index]
+            iid = item[index]
+            score = self.pred_mat[uid, iid]
+            result.append(score)
+        result = torch.from_numpy(np.array(result)).to(self.device)
+        return result
+
+    def full_sort_predict(self, interaction):
+        user = interaction[self.USER_ID]
+        user = user.cpu().numpy()
+
+        score = self.pred_mat[user, :].toarray().flatten()
+        result = torch.from_numpy(score).to(self.device)
+
+        return result
\ No newline at end of file

From deb0972b64a8c45608188f6ba11d5dfa6698580b Mon Sep 17 00:00:00 2001
From: mkhe93
Date: Sun, 1 Dec 2024 11:09:26 +0100
Subject: [PATCH 3/3] Documentation and tests for AsymKNN

---
 asset/model_list.json                         | 14 +++
 ...bole.model.general_recommender.asymknn.rst |  4 +
 .../recbole.model.general_recommender.rst     |  1 +
 .../user_guide/model/general/asymknn.rst      |
 88 +++++++++++++++++++
 docs/source/user_guide/model_intro.rst        |  1 +
 recbole/model/general_recommender/__init__.py |  1 +
 recbole/properties/model/AsymKNN.yaml         |  5 ++
 tests/model/test_model_auto.py                | 14 +++
 8 files changed, 128 insertions(+)
 create mode 100644 docs/source/recbole/recbole.model.general_recommender.asymknn.rst
 create mode 100644 docs/source/user_guide/model/general/asymknn.rst
 create mode 100644 recbole/properties/model/AsymKNN.yaml

diff --git a/asset/model_list.json b/asset/model_list.json
index b28b66ded..a2ba8cf85 100644
--- a/asset/model_list.json
+++ b/asset/model_list.json
@@ -154,6 +154,20 @@
         "repository": "RecBole",
         "repo_link": "https://github.com/RUCAIBox/RecBole"
     },
+    {
+        "category": "General Recommendation",
+        "cate_link": "/docs/user_guide/model_intro.html#general-recommendation",
+        "year": "2013",
+        "pub": "RecSys'13",
+        "model": "AsymKNN",
+        "model_link": "/docs/user_guide/model/general/asymknn.html",
+        "paper": "Efficient Top-N Recommendation for Very Large Scale Binary Rated Datasets",
+        "paper_link": "https://doi.org/10.1145/2507157.2507189",
+        "authors": "Fabio Aiolli",
+        "ref_code": "",
+        "repository": "RecBole",
+        "repo_link": "https://github.com/RUCAIBox/RecBole"
+    },
     {
         "category": "General Recommendation",
         "cate_link": "/docs/user_guide/model_intro.html#general-recommendation",
diff --git a/docs/source/recbole/recbole.model.general_recommender.asymknn.rst b/docs/source/recbole/recbole.model.general_recommender.asymknn.rst
new file mode 100644
index 000000000..55f1cbce0
--- /dev/null
+++ b/docs/source/recbole/recbole.model.general_recommender.asymknn.rst
@@ -0,0 +1,4 @@
+.. automodule:: recbole.model.general_recommender.asymknn
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/recbole/recbole.model.general_recommender.rst b/docs/source/recbole/recbole.model.general_recommender.rst
index 9436371ad..2089ced97 100644
--- a/docs/source/recbole/recbole.model.general_recommender.rst
+++ b/docs/source/recbole/recbole.model.general_recommender.rst
@@ -4,6 +4,7 @@ recbole.model.general\_recommender
 .. toctree::
    :maxdepth: 4
 
+   recbole.model.general_recommender.asymknn
    recbole.model.general_recommender.admmslim
    recbole.model.general_recommender.bpr
    recbole.model.general_recommender.cdae
diff --git a/docs/source/user_guide/model/general/asymknn.rst b/docs/source/user_guide/model/general/asymknn.rst
new file mode 100644
index 000000000..97844143e
--- /dev/null
+++ b/docs/source/user_guide/model/general/asymknn.rst
@@ -0,0 +1,88 @@
+AsymKNN
+===========
+
+Introduction
+---------------------
+
+`[paper] <https://doi.org/10.1145/2507157.2507189>`_
+
+**Title:** Efficient Top-N Recommendation for Very Large Scale Binary Rated Datasets
+
+**Authors:** Fabio Aiolli
+
+**Abstract:** We present a simple and scalable algorithm for top-N recommendation able to deal with very large datasets and (binary rated) implicit feedback. We focus on memory-based collaborative filtering
+algorithms similar to the well known neighbor based technique for explicit feedback. The major difference, that makes the algorithm particularly scalable, is that it uses positive feedback only
+and no explicit computation of the complete (user-by-user or item-by-item) similarity matrix needs to be performed.
+The study of the proposed algorithm has been conducted on data from the Million Songs Dataset (MSD) challenge whose task was to suggest a set of songs (out of more than 380k available songs) to more than 100k users given half of the user listening history and
+complete listening history of other 1 million people.
+In particular, we investigate on the entire recommendation pipeline, starting from the definition of suitable similarity and scoring functions and suggestions on how to aggregate multiple ranking strategies to define the overall recommendation. The technique we are
+proposing extends and improves the one that already won the MSD challenge last year.
+
+In this article, we introduce a versatile class of recommendation algorithms that calculate either user-to-user or item-to-item similarities as the foundation for generating recommendations. This approach enables the flexibility to switch between UserKNN and ItemKNN models depending on the desired application.
+
+A distinguishing feature of this class of algorithms, exemplified by AsymKNN, is its use of asymmetric cosine similarity, which generalizes the traditional cosine similarity. Specifically, when the asymmetry parameter
+``alpha = 0.5``, the method reduces to the standard cosine similarity, while other values of ``alpha`` allow for tailored emphasis on specific aspects of the interaction data. Furthermore, setting the parameter
+``beta = 1.0`` yields a traditional UserKNN or ItemKNN, since each user's final scores are then only divided by a positive constant, which preserves the order of recommendations.
+
+Running with RecBole
+-------------------------
+
+**Model Hyper-Parameters:**
+
+- ``k (int)`` : The neighborhood size. Defaults to ``100``.
+
+- ``alpha (float)`` : Weight parameter for asymmetric cosine similarity. Defaults to ``0.5``.
+
+- ``beta (float)`` : Parameter for controlling the balance between factors in the final score normalization. Defaults to ``1.0``.
+
+- ``q (int)`` : The 'locality of scoring function' parameter. Defaults to ``1``.
+
+**Additional Parameters:**
+
+- ``knn_method (str)`` : Calculates the similarity of users if the method is 'user'; otherwise, calculates the similarity of items. Defaults to ``item``.
+
+
+**A Running Example:**
+
+Write the following code to a python file, such as `run.py`
+
+.. code:: python
+
+    from recbole.quick_start import run_recbole
+
+    run_recbole(model='AsymKNN', dataset='ml-100k')
+
+And then:
+
+.. code:: bash
+
+    python run.py
+
+Tuning Hyper Parameters
+-------------------------
+
+If you want to use ``HyperTuning`` to tune hyper parameters of this model, you can copy the following settings and name it as ``hyper.test``.
+
+.. code:: bash
+
+    k choice [10,50,100,200,250,300,400,500,1000,1500,2000,2500]
+    alpha choice [0.0,0.2,0.5,0.8,1.0]
+    beta choice [0.0,0.2,0.5,0.8,1.0]
+    q choice [1,2,3,4,5,6]
+
+Note that we just provide these hyper parameter ranges for reference only, and we cannot guarantee that they are the optimal ranges for this model.
+
+Then, with the source code of RecBole (you can download it from GitHub), you can run ``run_hyper.py`` for tuning:
+
+.. code:: bash
+
+    python run_hyper.py --model=[model_name] --dataset=[dataset_name] --config_files=[config_files_path] --params_file=hyper.test
+
+For more details about Parameter Tuning, refer to :doc:`../../../user_guide/usage/parameter_tuning`.
+ +If you want to change parameters, dataset or evaluation settings, take a look at + +- :doc:`../../../user_guide/config_settings` +- :doc:`../../../user_guide/data_intro` +- :doc:`../../../user_guide/train_eval_intro` +- :doc:`../../../user_guide/usage` \ No newline at end of file diff --git a/docs/source/user_guide/model_intro.rst b/docs/source/user_guide/model_intro.rst index 8b4c59d78..7de3f59e6 100644 --- a/docs/source/user_guide/model_intro.rst +++ b/docs/source/user_guide/model_intro.rst @@ -13,6 +13,7 @@ task of top-n recommendation. All the collaborative filter(CF) based models are .. toctree:: :maxdepth: 1 + model/general/asymknn model/general/pop model/general/itemknn model/general/bpr diff --git a/recbole/model/general_recommender/__init__.py b/recbole/model/general_recommender/__init__.py index e71f2b4ec..d5ec68e23 100644 --- a/recbole/model/general_recommender/__init__.py +++ b/recbole/model/general_recommender/__init__.py @@ -1,3 +1,4 @@ +from recbole.model.general_recommender.asymknn import AsymKNN from recbole.model.general_recommender.bpr import BPR from recbole.model.general_recommender.cdae import CDAE from recbole.model.general_recommender.convncf import ConvNCF diff --git a/recbole/properties/model/AsymKNN.yaml b/recbole/properties/model/AsymKNN.yaml new file mode 100644 index 000000000..f711d6860 --- /dev/null +++ b/recbole/properties/model/AsymKNN.yaml @@ -0,0 +1,5 @@ +k: 100 # Number of neighbors to consider in the similarity calculation. +q: 1 # Exponent for adjusting the 'locality of scoring function' after similarity computation. +beta: 1.0 # Parameter for controlling the balance between factors in the final score normalization. +alpha: 0.5 # Weight parameter for asymmetric cosine similarity +knn_method: 'item' # Calculate the similarity of users if method is 'user', otherwise, calculate the similarity of items. \ No newline at end of file diff --git a/tests/model/test_model_auto.py b/tests/model/test_model_auto.py index 965dc728e..fe809ee32 100644 --- a/tests/model/test_model_auto.py +++ b/tests/model/test_model_auto.py @@ -50,6 +50,20 @@ def test_userknn(self): } quick_test(config_dict) + def test_asymitemknn(self): + config_dict = { + "model": "AsymKNN", + "knn_method": "item" + } + quick_test(config_dict) + + def test_asymuserknn(self): + config_dict = { + "model": "AsymKNN", + "knn_method": "user" + } + quick_test(config_dict) + def test_bpr(self): config_dict = { "model": "BPR",
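
The scoring pipeline added by these patches can be sanity-checked outside RecBole. The following self-contained sketch is not part of the patches; the toy interaction matrix, the reduced neighborhood size ``k``, and the variable names are made up for illustration. It mirrors the item-based branch of AsymKNN (asymmetric cosine similarity followed by beta/q score normalization) with plain NumPy/SciPy:

.. code:: python

    import numpy as np
    import scipy.sparse as sp

    # Hyper-parameters mirroring AsymKNN.yaml defaults (k reduced for the tiny example)
    alpha, beta, q, k = 0.5, 1.0, 1, 2

    # Toy binary interaction matrix: 4 users x 5 items (illustrative only)
    X = sp.csr_matrix(np.array([
        [1, 1, 0, 0, 1],
        [1, 0, 1, 0, 0],
        [0, 1, 1, 1, 0],
        [1, 0, 0, 1, 1],
    ], dtype=np.float32))

    # Item-item asymmetric cosine, as in ComputeSimilarity.compute_asym_similarity:
    # sim(i, j) = (x_i . x_j) / (pop(i)^alpha * pop(j)^(1 - alpha)), self-similarity zeroed.
    pop = np.asarray(X.sum(axis=0)).ravel()   # item popularities (column sums)
    sim = (X.T @ X).toarray()                 # co-occurrence counts
    np.fill_diagonal(sim, 0.0)
    sim /= np.power(pop[:, None], alpha) * np.power(pop[None, :], 1.0 - alpha) + 1e-6

    # Keep only the top-k neighbours per target item (column), like the TopK pruning
    for j in range(sim.shape[1]):
        drop = np.argsort(-sim[:, j])[k:]
        sim[drop, j] = 0.0

    # Score users against items with beta normalization and locality exponent q
    numerator = X.toarray() @ sim
    factor1 = np.power(np.linalg.norm(X.toarray(), axis=1, keepdims=True), 2 * beta)
    factor2 = np.power(np.linalg.norm(sim, axis=1, keepdims=True), 2 * (1 - beta))
    scores = np.power(numerator / (factor1 @ factor2.T + 1e-6), q)

    print(np.round(scores, 3))   # one row of item scores per user

With ``alpha = 0.5`` and ``beta = 1.0`` this reduces to a plain cosine ItemKNN up to a per-user constant, which is the behavior the documentation above describes.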