-
Notifications
You must be signed in to change notification settings - Fork 20
/
main.py
68 lines (52 loc) · 2.13 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import argparse
from logging import getLogger
from recbole.config import Config
from recbole.data import create_dataset, data_preparation
from recbole.utils import init_logger, init_seed, set_color
from ncl import NCL
from trainer import NCLTrainer
def run_single_model(args):
    """Train NCL on ``args.dataset`` and log the best-validation and test results.

    Args:
        args: namespace-like object; this function reads ``args.dataset`` and
            ``args.config_file_list`` (the ordered list of YAML config paths).
    """
    # Assemble the configuration for the NCL model, then seed RNGs from it.
    cfg = Config(
        model=NCL,
        dataset=args.dataset,
        config_file_list=args.config_file_list,
    )
    init_seed(cfg['seed'], cfg['reproducibility'])

    # Logging is configured from cfg before the first message is emitted.
    init_logger(cfg)
    log = getLogger()
    log.info(cfg)

    # Build the dataset, then split it into train / valid / test data.
    full_dataset = create_dataset(cfg)
    log.info(full_dataset)
    train_split, valid_split, test_split = data_preparation(cfg, full_dataset)

    # The model is built from the training split's dataset and moved to cfg['device'].
    ncl_model = NCL(cfg, train_split.dataset).to(cfg['device'])
    log.info(ncl_model)

    trainer = NCLTrainer(cfg, ncl_model)
    # Only the best-validation result is reported; the score is discarded.
    _, best_valid_result = trainer.fit(
        train_split,
        valid_split,
        saved=True,
        show_progress=cfg['show_progress'],
    )
    test_result = trainer.evaluate(
        test_split,
        load_best_model=True,
        show_progress=cfg['show_progress'],
    )

    for label, outcome in (('best valid ', best_valid_result),
                           ('test result', test_result)):
        log.info(set_color(label, 'yellow') + f': {outcome}')
# Datasets that have a dedicated ``properties/<dataset>.yaml`` config file.
KNOWN_DATASETS = ('ml-1m', 'yelp', 'amazon-books', 'gowalla-merged', 'alibaba')


def build_config_file_list(dataset, extra_config=''):
    """Return the ordered list of YAML config files for a run on ``dataset``.

    Args:
        dataset: dataset name; a per-dataset YAML is appended only when it is
            one of ``KNOWN_DATASETS``.
        extra_config: optional path to an external config file, appended last.
            Empty string or ``None`` means "no extra config".

    Returns:
        A new list of config file paths: the overall and NCL defaults, then
        the dataset-specific file (if any), then ``extra_config`` (if given).
    """
    config_files = ['properties/overall.yaml', 'properties/NCL.yaml']
    if dataset in KNOWN_DATASETS:
        config_files.append(f'properties/{dataset}.yaml')
    # Truthiness rather than ``is not ''``: identity comparison with a string
    # literal is implementation-dependent and a SyntaxWarning on Python 3.8+.
    if extra_config:
        config_files.append(extra_config)
    return config_files


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset', type=str, default='yelp',
        help='The datasets can be: ' + ', '.join(KNOWN_DATASETS) + '.',
    )
    parser.add_argument('--config', type=str, default='', help='External config file name.')
    # Unknown CLI arguments are tolerated and ignored here.
    args, _ = parser.parse_known_args()

    args.config_file_list = build_config_file_list(args.dataset, args.config)
    run_single_model(args)