Skip to content

Commit

Permalink
FEAT: load model and continue to train
Browse files Browse the repository at this point in the history
  • Loading branch information
Paitesanshi committed Jan 15, 2023
1 parent c13fc38 commit 2dadb4a
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 2 deletions.
2 changes: 1 addition & 1 deletion recbole/quick_start/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from recbole.quick_start.quick_start import run_recbole, objective_function, load_data_and_model
from recbole.quick_start.quick_start import run_recbole, objective_function, load_data_and_model, continue_train
31 changes: 30 additions & 1 deletion recbole/quick_start/quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,36 @@
from recbole.data import create_dataset, data_preparation, save_split_dataloaders, load_split_dataloaders
from recbole.utils import init_logger, get_model, get_trainer, init_seed, set_color

def continue_train(file_path):
logger = getLogger()
config, model, dataset, train_data, valid_data, test_data = load_data_and_model(
model_file=file_path,
)
# ------changed config--------#

# config['base_learning_rate']=1e-6
##############################
# trainer loading and initialization
trainer = get_trainer(config['MODEL_TYPE'], config['model'], config['task'], config['robust'])(config, model)

# model training
best_valid_score, best_valid_result = trainer.fit(
train_data, valid_data, saved=True, show_progress=config['show_progress']
)

# model evaluation
test_result = trainer.evaluate(test_data, load_best_model=True, show_progress=config['show_progress'])

logger.info(set_color('best valid ', 'yellow') + f': {best_valid_result}')
logger.info(set_color('test result', 'yellow') + f': {test_result}')
return {
'best_valid_score': best_valid_score,
'valid_score_bigger': config['valid_metric_bigger'],
'best_valid_result': best_valid_result,
'test_result': test_result
}


def run_recbole(model=None, dataset=None, config_file_list=None, config_dict=None, saved=True):
r""" A fast running api, which includes the complete process of
training and testing a model on a specified dataset
Expand Down Expand Up @@ -49,7 +79,6 @@ def run_recbole(model=None, dataset=None, config_file_list=None, config_dict=Non
init_seed(config['seed'], config['reproducibility'])
model = get_model(config['model'])(config, train_data.dataset).to(config['device'])
logger.info(model)

# trainer loading and initialization
trainer = get_trainer(config['MODEL_TYPE'], config['model'],config['task'],config['robust'])(config, model)

Expand Down
23 changes: 23 additions & 0 deletions run_continue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# @Time : 2020/7/20
# @Author : Shanlei Mu
# @Email : [email protected]

# UPDATE
# @Time : 2020/10/3, 2020/10/1
# @Author : Yupeng Hou, Zihan Lin
# @Email : [email protected], [email protected]


import argparse

from recbole.quick_start import run_recbole, continue_train

if __name__ == '__main__':
parser = argparse.ArgumentParser()

parser.add_argument('--file_path', '-f', type=str, default=None, help='config files')

args, _ = parser.parse_known_args()

# config_file_list = args.config_files.strip().split(' ') if args.config_files else None
continue_train(args.file_path)

0 comments on commit 2dadb4a

Please sign in to comment.