Skip to content

Commit

Permalink
FIX: add base.yaml and fix tmtf
Browse files Browse the repository at this point in the history
  • Loading branch information
Paitesanshi committed Nov 25, 2022
1 parent 034c8c5 commit b0f5252
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 6 deletions.
58 changes: 58 additions & 0 deletions debias_config/base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@

# general
gpu_id: 0
use_gpu: True
seed: 2020
state: INFO
reproducibility: True
data_path: 'dataset/'
checkpoint_dir: 'saved'
show_progress: True
save_dataset: False
dataset_save_path: ~
save_dataloaders: False
dataloaders_save_path: ~
log_wandb: False
wandb_project: 'recbole'
LABEL_FIELD : 'rating'
RATING_FIELD : 'rating'
TIME_FIELD: 'wday'
#WDAY_FIELD: 'wday'
# training settings
epochs: 100
train_batch_size: 1024
learner: adam
learning_rate: 0.001
neg_sampling:
uniform: 10

eval_step: 1
stopping_step: 50
clip_grad_norm: ~
# clip_grad_norm: {'max_norm': 5, 'norm_type': 2}
weight_decay: 0.0
loss_decimal_place: 4
require_pow: False

# evaluation settings
eval_args:
split: {'RS':[0.8,0.1,0.1]}
group_by: user
order: TO
mode: labeled
repeatable: False
metrics: ["MSE","RMSE","MAE","AUC","LogLoss"]
topk: [10]
load_col: {'inter': ['user_id', 'item_id','rating','timestamp','wday']}
valid_metric: rmse
valid_metric_bigger: True
eval_batch_size: 2048
metric_decimal_place: 4
normalize_field: ['timestamp','rating']
normalize_all: ~
K: 7
T: 100
M: 3
sig: 0.5
embedding_size: 64
benchmark_filename: ['train','valid','test']
2 changes: 1 addition & 1 deletion debias_config/get_dps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,4 @@ gamma_v: 1.1
gamma_t: 1.1
task: 'ps' #ps,ips,dr,dips
psv_path: 'init_ps/item_ps_week.pth'
pst_path: 'init_ps/tmtf_time_ps_week.pth'
pst_path: 'init_ps/tmtf_time_ps_global.pth'
4 changes: 2 additions & 2 deletions debias_config/get_ps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,5 @@ base_freq: 1
gamma_v: 1.1
gamma_t: 1.1
task: 'ps' #ps,ips,dr,dips
psv_path: 'init_ps/item_ps_week.pth'
pst_path: 'init_ps/time_ps_week.pth'
psv_path: 'init_ps/item_ps_global.pth'
pst_path: 'init_ps/time_ps_global.pth'
5 changes: 2 additions & 3 deletions recbole/model/general_recommender/tmtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def __init__(self, config, dataset):
self.LABEL = config['LABEL_FIELD']
self.RATING = config['RATING_FIELD']
self.TIME = config['TIME_FIELD']
self.WDAY = config['WDAY_FIELD']
# load parameters info
self.embedding_size = config['embedding_size']
self.K = config['K']
Expand Down Expand Up @@ -73,7 +72,7 @@ def calculate_loss(self, interaction, weight=None):
user = interaction[self.USER_ID]
item = interaction[self.ITEM_ID]
label = interaction[self.LABEL]
time = interaction[self.WDAY].long()
time = interaction[self.TIME].long()
pred = self.forward(user, item, time)

loss = self.loss_fct(pred, label)
Expand All @@ -83,7 +82,7 @@ def calculate_loss(self, interaction, weight=None):
def predict(self, interaction):
user = interaction[self.USER_ID]
item = interaction[self.ITEM_ID]
time = interaction[self.WDAY]
time = interaction[self.TIME].long()
pred = self.forward(user, item, time)
return pred

Expand Down

0 comments on commit b0f5252

Please sign in to comment.