From d6c1cf2f8258630de5fed8c0ed9339edd7d03676 Mon Sep 17 00:00:00 2001 From: ChenglongMa Date: Wed, 27 Sep 2023 10:44:07 +1000 Subject: [PATCH] Add data type check --- recbole/quick_start/quick_start.py | 37 ++++++++++++++++++------------ 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/recbole/quick_start/quick_start.py b/recbole/quick_start/quick_start.py index 097acc6d9..4d8368e9c 100644 --- a/recbole/quick_start/quick_start.py +++ b/recbole/quick_start/quick_start.py @@ -12,20 +12,16 @@ ######################## """ import logging -from logging import getLogger - import sys +from collections.abc import MutableMapping +from logging import getLogger - -import pickle from ray import tune from recbole.config import Config from recbole.data import ( create_dataset, data_preparation, - save_split_dataloaders, - load_split_dataloaders, ) from recbole.data.transform import construct_transform from recbole.utils import ( @@ -69,13 +65,15 @@ def run( queue = mp.get_context('spawn').SimpleQueue() config_dict = config_dict or {} - config_dict.update({ - "world_size": world_size, - "ip": ip, - "port": port, - "nproc": nproc, - "offset": group_offset, - }) + config_dict.update( + { + "world_size": world_size, + "ip": ip, + "port": port, + "nproc": nproc, + "offset": group_offset, + } + ) kwargs = { "config_dict": config_dict, "queue": queue, @@ -94,7 +92,12 @@ def run( def run_recbole( - model=None, dataset=None, config_file_list=None, config_dict=None, saved=True, queue=None + model=None, + dataset=None, + config_file_list=None, + config_dict=None, + saved=True, + queue=None, ): r"""A fast running api, which includes the complete process of training and testing a model on a specified dataset @@ -169,11 +172,15 @@ def run_recbole( if config["local_rank"] == 0 and queue is not None: queue.put(result) # for multiprocessing, e.g., mp.spawn - return result # for the single process + return result # for the single process def run_recboles(rank, *args): kwargs = args[-1] + if not isinstance(kwargs, MutableMapping): + raise ValueError( + f"The last argument of run_recboles should be a dict, but got {type(kwargs)}" + ) kwargs["config_dict"] = kwargs.get("config_dict", {}) kwargs["config_dict"]["local_rank"] = rank run_recbole(