-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sourcery refactored master branch #1
base: master
Are you sure you want to change the base?
Conversation
train_df['idx'] = range(0, len(train_df)) | ||
train_df['idx'] = range(len(train_df)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function Dataset.load_dataset
refactored with the following changes:
- Replace range(0, x) with range(x) (
remove-zero-from-range
)
f = open(output_dir, 'w') | ||
|
||
model = model.to(DEVICE) | ||
model.eval() | ||
state_h, state_c = model.init_state() | ||
state_h = state_h.to(DEVICE) | ||
state_c = state_h.to(DEVICE) | ||
|
||
i = 0 | ||
for batch, (user_id, sequence) in enumerate(dataloder): | ||
sequence = sequence[:,1:].to(DEVICE) | ||
|
||
#y_pred, _ = model(sequence, (state_h, state_c)) # when use lgtrnet_v4 model, you need to use this line code since it doesn't have state_h and state_c. | ||
y_pred, (state_h, state_c) = model(sequence, (state_h, state_c)) | ||
#y = int(torch.argmax(y_pred).data) | ||
#f.write('%s\n' % y) | ||
topk = torch.topk(y_pred, 10)[1].data[0].tolist() | ||
f.write('%s\n' % topk) | ||
|
||
i += 1 | ||
#if i > 3 : break | ||
f.close() | ||
with open(output_dir, 'w') as f: | ||
model = model.to(DEVICE) | ||
model.eval() | ||
state_h, state_c = model.init_state() | ||
state_h = state_h.to(DEVICE) | ||
state_c = state_h.to(DEVICE) | ||
|
||
i = 0 | ||
for batch, (user_id, sequence) in enumerate(dataloder): | ||
sequence = sequence[:,1:].to(DEVICE) | ||
|
||
#y_pred, _ = model(sequence, (state_h, state_c)) # when use lgtrnet_v4 model, you need to use this line code since it doesn't have state_h and state_c. | ||
y_pred, (state_h, state_c) = model(sequence, (state_h, state_c)) | ||
#y = int(torch.argmax(y_pred).data) | ||
#f.write('%s\n' % y) | ||
topk = torch.topk(y_pred, 10)[1].data[0].tolist() | ||
f.write('%s\n' % topk) | ||
|
||
i += 1 | ||
#if i > 3 : break |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function inference
refactored with the following changes:
- Use
with
when opening file to ensure closure (ensure-file-closed
)
|
||
best_loss = float('inf') | ||
for epoch in range(args.max_epochs): | ||
print('Epoch {}/{}'.format(epoch+1, args.max_epochs),) | ||
epoch_loss = {} | ||
epoch_loss = {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function train_model
refactored with the following changes:
- Use previously assigned local variable (
use-assigned-variable
) - Replace unneeded comprehension with generator (
comprehension-to-generator
)
DEVICE = torch.device("cuda" if USE_CUDA else "cpu") | ||
return DEVICE | ||
return torch.device("cuda" if USE_CUDA else "cpu") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function get_device
refactored with the following changes:
- Inline variable that is immediately returned (
inline-immediately-returned-variable
)
if not 'SM_CHANNEL_TRAIN' in os.environ : | ||
if 'SM_CHANNEL_TRAIN' not in os.environ: | ||
os.environ['SM_CHANNEL_TRAIN'] = '%s/data-%s/' % (root_path, kind) | ||
if not 'SM_MODEL_DIR' in os.environ: | ||
if 'SM_MODEL_DIR' not in os.environ: | ||
os.environ['SM_MODEL_DIR'] = '%s/model/' % root_path | ||
|
||
# for inference | ||
if not 'SM_CHANNEL_EVAL' in os.environ : | ||
if 'SM_CHANNEL_EVAL' not in os.environ: | ||
os.environ['SM_CHANNEL_EVAL'] = '%s/data-%s/' % (root_path, kind) | ||
if not 'SM_CHANNEL_MODEL' in os.environ : | ||
if 'SM_CHANNEL_MODEL' not in os.environ: | ||
os.environ['SM_CHANNEL_MODEL'] = '%s/model/' % root_path | ||
if not 'SM_OUTPUT_DATA_DIR' in os.environ : | ||
if 'SM_OUTPUT_DATA_DIR' not in os.environ: | ||
os.environ['SM_OUTPUT_DATA_DIR'] = '%s/output/' % root_path | ||
|
||
args = get_args() | ||
|
||
return args | ||
return get_args() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function set_env
refactored with the following changes:
- Inline variable that is immediately returned (
inline-immediately-returned-variable
) - Simplify logical expression using De Morgan identities (
de-morgan
)
if len_seq < max_len: | ||
ls_zero = [0 for i in range(max_len - len_seq)] | ||
ls_zero.extend(seq) | ||
_seq = ls_zero | ||
else: | ||
_seq = seq[-max_len:] | ||
if len_seq >= max_len: | ||
return seq[-max_len:] | ||
|
||
return _seq | ||
ls_zero = [0 for i in range(max_len - len_seq)] | ||
ls_zero.extend(seq) | ||
return ls_zero |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function list_max_len
refactored with the following changes:
- Lift return into if (
lift-return-into-if
) - Swap if/else branches (
swap-if-else-branches
) - Remove unnecessary else after guard condition (
remove-unnecessary-else
)
if eval(item_id) in item_dict.keys(): | ||
if eval(item_id) in item_dict: | ||
item_dict[eval(item_id)] += 1 | ||
else: | ||
item_dict[eval(item_id)] = 1 | ||
#print('the user of %s has been done!', user) | ||
#print('the user of %s has been done!', user) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function item_stat
refactored with the following changes:
- Remove unnecessary call to keys() (
remove-dict-keys
)
if flag == True: | ||
if flag: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lines 54-85
refactored with the following changes:
- Hoist repeated code outside conditional statement (
hoist-statement-from-if
) - Simplify comparison to boolean (
simplify-boolean-comparison
)
train_df['idx'] = range(0, len(train_df)) | ||
train_df['idx'] = range(len(train_df)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function DataProcess.load_dataset
refactored with the following changes:
- Replace range(0, x) with range(x) (
remove-zero-from-range
)
if user_id in user_dict.keys(): | ||
if user_id in user_dict: | ||
user_dict[eval(user_id)] +=1 | ||
else: | ||
user_dict[eval(user_id)] = 1 | ||
|
||
for item_id in num: | ||
if item_id in item_dict.keys(): | ||
if item_id in item_dict: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function get_user_item_id
refactored with the following changes:
- Remove unnecessary call to keys() (
remove-dict-keys
)
Sourcery Code Quality Report✅ Merging this PR will increase code quality in the affected files by 0.04%.
Here are some functions in these files that still need a tune-up:
Legend and ExplanationThe emojis denote the absolute quality of the code:
The 👍 and 👎 indicate whether the quality has improved or gotten worse with this pull request. Please see our documentation here for details on how these metrics are calculated. We are actively working on this report - lots more documentation and extra metrics to come! Let us know what you think of it by mentioning @sourcery-ai in a comment. |
Branch
master
refactored by Sourcery.If you're happy with these changes, merge this Pull Request using the Squash and merge strategy.
See our documentation here.
Run Sourcery locally
Reduce the feedback loop during development by using the Sourcery editor plugin:
Review changes via command line
To manually merge these changes, make sure you're on the
master
branch, then run: