Skip to content

Commit

Permalink
Add data type check
Browse files Browse the repository at this point in the history
  • Loading branch information
ChenglongMa committed Sep 27, 2023
1 parent 85634a6 commit d6c1cf2
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions recbole/quick_start/quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,16 @@
########################
"""
import logging
from logging import getLogger

import sys
from collections.abc import MutableMapping
from logging import getLogger


import pickle
from ray import tune

from recbole.config import Config
from recbole.data import (
create_dataset,
data_preparation,
save_split_dataloaders,
load_split_dataloaders,
)
from recbole.data.transform import construct_transform
from recbole.utils import (
Expand Down Expand Up @@ -69,13 +65,15 @@ def run(
queue = mp.get_context('spawn').SimpleQueue()

config_dict = config_dict or {}
config_dict.update({
"world_size": world_size,
"ip": ip,
"port": port,
"nproc": nproc,
"offset": group_offset,
})
config_dict.update(
{
"world_size": world_size,
"ip": ip,
"port": port,
"nproc": nproc,
"offset": group_offset,
}
)
kwargs = {
"config_dict": config_dict,
"queue": queue,
Expand All @@ -94,7 +92,12 @@ def run(


def run_recbole(
model=None, dataset=None, config_file_list=None, config_dict=None, saved=True, queue=None
model=None,
dataset=None,
config_file_list=None,
config_dict=None,
saved=True,
queue=None,
):
r"""A fast running api, which includes the complete process of
training and testing a model on a specified dataset
Expand Down Expand Up @@ -169,11 +172,15 @@ def run_recbole(
if config["local_rank"] == 0 and queue is not None:
queue.put(result) # for multiprocessing, e.g., mp.spawn

return result # for the single process
return result # for the single process


def run_recboles(rank, *args):
kwargs = args[-1]
if not isinstance(kwargs, MutableMapping):
raise ValueError(
f"The last argument of run_recboles should be a dict, but got {type(kwargs)}"
)
kwargs["config_dict"] = kwargs.get("config_dict", {})
kwargs["config_dict"]["local_rank"] = rank
run_recbole(
Expand Down

0 comments on commit d6c1cf2

Please sign in to comment.