
Sourcery refactored master branch #1

Open · sourcery-ai[bot] wants to merge 1 commit into master from sourcery/master
Conversation

@sourcery-ai sourcery-ai bot commented Feb 22, 2021

Branch master refactored by Sourcery.

If you're happy with these changes, merge this Pull Request using the Squash and merge strategy.

See our documentation here.

Run Sourcery locally

Reduce the feedback loop during development by using the Sourcery editor plugin.

Review changes via command line

To manually merge these changes, make sure you're on the master branch, then run:

git fetch origin sourcery/master   # fetch the Sourcery branch from the remote
git merge --ff-only FETCH_HEAD     # fast-forward your local master onto the fetched commit
git reset HEAD^                    # step back one commit, leaving the Sourcery changes unstaged in your working tree for review

- train_df['idx'] = range(0, len(train_df))
+ train_df['idx'] = range(len(train_df))
Function Dataset.load_dataset refactored: the redundant 0 start argument is dropped from the range() call.
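
The two calls are equivalent: range(0, n) and range(n) yield the same values. A minimal sketch, assuming train_df is a pandas DataFrame (the placeholder rows below are not from the project):

import pandas as pd

train_df = pd.DataFrame({'user_id': [3, 7, 7, 9]})  # hypothetical placeholder rows
train_df['idx'] = range(len(train_df))              # identical result to range(0, len(train_df))
print(train_df['idx'].tolist())                     # [0, 1, 2, 3]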

Comment on lines -10 to +29
- f = open(output_dir, 'w')
-
- model = model.to(DEVICE)
- model.eval()
- state_h, state_c = model.init_state()
- state_h = state_h.to(DEVICE)
- state_c = state_h.to(DEVICE)
-
- i = 0
- for batch, (user_id, sequence) in enumerate(dataloder):
-     sequence = sequence[:,1:].to(DEVICE)
-
-     #y_pred, _ = model(sequence, (state_h, state_c)) # when use lgtrnet_v4 model, you need to use this line code since it doesn't have state_h and state_c.
-     y_pred, (state_h, state_c) = model(sequence, (state_h, state_c))
-     #y = int(torch.argmax(y_pred).data)
-     #f.write('%s\n' % y)
-     topk = torch.topk(y_pred, 10)[1].data[0].tolist()
-     f.write('%s\n' % topk)
-
-     i += 1
-     #if i > 3 : break
- f.close()
+ with open(output_dir, 'w') as f:
+     model = model.to(DEVICE)
+     model.eval()
+     state_h, state_c = model.init_state()
+     state_h = state_h.to(DEVICE)
+     state_c = state_h.to(DEVICE)
+
+     i = 0
+     for batch, (user_id, sequence) in enumerate(dataloder):
+         sequence = sequence[:,1:].to(DEVICE)
+
+         #y_pred, _ = model(sequence, (state_h, state_c)) # when use lgtrnet_v4 model, you need to use this line code since it doesn't have state_h and state_c.
+         y_pred, (state_h, state_c) = model(sequence, (state_h, state_c))
+         #y = int(torch.argmax(y_pred).data)
+         #f.write('%s\n' % y)
+         topk = torch.topk(y_pred, 10)[1].data[0].tolist()
+         f.write('%s\n' % topk)
+
+         i += 1
+         #if i > 3 : break
Function inference refactored: the manual open()/f.close() pair is replaced with a with statement so the output file is always closed.
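
The practical benefit of the with statement: the file is closed even if an exception is raised inside the loop, whereas the original f.close() would be skipped on an error. A minimal standalone sketch (write_topk and the sample rows are hypothetical, not the project's inference code):

def write_topk(output_path, rows):
    # the context manager closes the file on normal exit and on exceptions
    with open(output_path, 'w') as f:
        for topk in rows:
            f.write('%s\n' % topk)

write_topk('predictions.txt', [[3, 1, 4], [1, 5, 9]])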

Comment on lines -28 to +32

  best_loss = float('inf')
  for epoch in range(args.max_epochs):
      print('Epoch {}/{}'.format(epoch+1, args.max_epochs),)
-     epoch_loss = {}
+     epoch_loss = {}
Function train_model refactored.

- DEVICE = torch.device("cuda" if USE_CUDA else "cpu")
- return DEVICE
+ return torch.device("cuda" if USE_CUDA else "cpu")
Function get_device refactored: the intermediate DEVICE variable is inlined into the return statement.
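
For context, a sketch of what the refactored helper amounts to, assuming USE_CUDA comes from torch.cuda.is_available() (its definition sits outside this hunk):

import torch

def get_device():
    USE_CUDA = torch.cuda.is_available()  # assumption: how USE_CUDA is defined
    return torch.device("cuda" if USE_CUDA else "cpu")

x = torch.zeros(3).to(get_device())  # e.g. move a tensor to the selected device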

Comment on lines -30 to +42
- if not 'SM_CHANNEL_TRAIN' in os.environ :
+ if 'SM_CHANNEL_TRAIN' not in os.environ:
      os.environ['SM_CHANNEL_TRAIN'] = '%s/data-%s/' % (root_path, kind)
- if not 'SM_MODEL_DIR' in os.environ:
+ if 'SM_MODEL_DIR' not in os.environ:
      os.environ['SM_MODEL_DIR'] = '%s/model/' % root_path

  # for inference
- if not 'SM_CHANNEL_EVAL' in os.environ :
+ if 'SM_CHANNEL_EVAL' not in os.environ:
      os.environ['SM_CHANNEL_EVAL'] = '%s/data-%s/' % (root_path, kind)
- if not 'SM_CHANNEL_MODEL' in os.environ :
+ if 'SM_CHANNEL_MODEL' not in os.environ:
      os.environ['SM_CHANNEL_MODEL'] = '%s/model/' % root_path
- if not 'SM_OUTPUT_DATA_DIR' in os.environ :
+ if 'SM_OUTPUT_DATA_DIR' not in os.environ:
      os.environ['SM_OUTPUT_DATA_DIR'] = '%s/output/' % root_path

- args = get_args()
-
- return args
+ return get_args()
Function set_env refactored: each "not ... in os.environ" test is rewritten as "... not in os.environ", and the temporary args variable is inlined into the return.
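
As an aside, the same guard-then-assign pattern can be written with os.environ.setdefault, which only assigns when the key is absent; a minimal sketch with placeholder root_path and kind values, not taken from the project:

import os

def set_env_defaults(root_path='.', kind='train'):
    # setdefault assigns the value only if the key is not already present,
    # matching the "if key not in os.environ" guards above
    os.environ.setdefault('SM_CHANNEL_TRAIN', '%s/data-%s/' % (root_path, kind))
    os.environ.setdefault('SM_MODEL_DIR', '%s/model/' % root_path)
    os.environ.setdefault('SM_CHANNEL_EVAL', '%s/data-%s/' % (root_path, kind))
    os.environ.setdefault('SM_CHANNEL_MODEL', '%s/model/' % root_path)
    os.environ.setdefault('SM_OUTPUT_DATA_DIR', '%s/output/' % root_path)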

Comment on lines -7 to +12
- if len_seq < max_len:
-     ls_zero = [0 for i in range(max_len - len_seq)]
-     ls_zero.extend(seq)
-     _seq = ls_zero
- else:
-     _seq = seq[-max_len:]
-
- return _seq
+ if len_seq >= max_len:
+     return seq[-max_len:]
+
+ ls_zero = [0 for i in range(max_len - len_seq)]
+ ls_zero.extend(seq)
+ return ls_zero
Function list_max_len refactored: the if/else and the temporary _seq variable are replaced with an early return, so over-long sequences are truncated up front and shorter ones are left-padded with zeros.
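
Put together, the refactored helper behaves as below; the signature list_max_len(seq, max_len) and the line len_seq = len(seq) are assumptions, since they sit outside the diffed lines:

def list_max_len(seq, max_len):
    len_seq = len(seq)
    if len_seq >= max_len:
        # keep only the most recent max_len items
        return seq[-max_len:]
    # left-pad shorter sequences with zeros
    ls_zero = [0 for _ in range(max_len - len_seq)]
    ls_zero.extend(seq)
    return ls_zero

# list_max_len([5, 6, 7], 5)          -> [0, 0, 5, 6, 7]
# list_max_len([1, 2, 3, 4, 5, 6], 5) -> [2, 3, 4, 5, 6]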

Comment on lines -40 to +42
- if eval(item_id) in item_dict.keys():
+ if eval(item_id) in item_dict:
      item_dict[eval(item_id)] += 1
  else:
      item_dict[eval(item_id)] = 1
- #print('the user of %s has been done!', user)
+ #print('the user of %s has been done!', user)
Function item_stat refactored: the membership test uses "in item_dict" directly instead of "in item_dict.keys()".
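
For reference, the surrounding if/else counting pattern can also be collapsed with dict.get or collections.Counter; this is a hypothetical alternative, not part of the Sourcery diff, and int() stands in for the eval() used on the numeric ID strings:

from collections import Counter

item_ids = ['10', '7', '10']  # placeholder ID strings
item_dict = {}
for item_id in item_ids:
    key = int(item_id)
    item_dict[key] = item_dict.get(key, 0) + 1  # default to 0 when the key is new

counts = Counter(int(i) for i in item_ids)      # equivalent one-liner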

Comment on lines -54 to +52
- if flag == True:
+ if flag:
Lines 54-85 refactored: the explicit comparison "if flag == True:" is simplified to "if flag:".

- train_df['idx'] = range(0, len(train_df))
+ train_df['idx'] = range(len(train_df))
Function DataProcess.load_dataset refactored: the same range(0, len(train_df)) to range(len(train_df)) simplification as in Dataset.load_dataset above.

Comment on lines -12 to +18
- if user_id in user_dict.keys():
+ if user_id in user_dict:
      user_dict[eval(user_id)] +=1
  else:
      user_dict[eval(user_id)] = 1

  for item_id in num:
-     if item_id in item_dict.keys():
+     if item_id in item_dict:
Function get_user_item_id refactored: .keys() is dropped from both dictionary membership tests.

sourcery-ai bot commented Feb 22, 2021

Sourcery Code Quality Report

✅  Merging this PR will increase code quality in the affected files by 0.04%.

| Quality metrics | Before | After | Change |
| --- | --- | --- | --- |
| Complexity | 2.17 ⭐ | 2.17 ⭐ | 0.00 |
| Method Length | 82.44 🙂 | 81.89 🙂 | -0.55 👍 |
| Working memory | 11.44 😞 | 11.42 😞 | -0.02 👍 |
| Quality | 63.94% 🙂 | 63.98% 🙂 | 0.04% 👍 |

| Other metrics | Before | After | Change |
| --- | --- | --- | --- |
| Lines | 1680 | 1706 | 26 |
| Changed files | Quality Before | Quality After | Quality Change |
| --- | --- | --- | --- |
| dataset.py | 80.81% ⭐ | 80.83% ⭐ | 0.02% 👍 |
| inference.py | 62.54% 🙂 | 62.65% 🙂 | 0.11% 👍 |
| train.py | 42.41% 😞 | 42.50% 😞 | 0.09% 👍 |
| util.py | 80.02% ⭐ | 80.35% ⭐ | 0.33% 👍 |
| network/lgtrnet_v1.py | 62.86% 🙂 | 62.72% 🙂 | -0.14% 👎 |
| network/lgtrnet_v2.py | 60.51% 🙂 | 60.36% 🙂 | -0.15% 👎 |
| network/lgtrnet_v2_2.py | 63.16% 🙂 | 63.06% 🙂 | -0.10% 👎 |
| network/lgtrnet_v3.py | 64.90% 🙂 | 64.88% 🙂 | -0.02% 👎 |
| network/lgtrnet_v4.py | 65.46% 🙂 | 65.35% 🙂 | -0.11% 👎 |
| network/lstm.py | 54.80% 🙂 | 54.67% 🙂 | -0.13% 👎 |
| tools/data_vis.py | 58.29% 🙂 | 59.17% 🙂 | 0.88% 👍 |
| tools/dataload.py | 76.45% ⭐ | 76.46% ⭐ | 0.01% 👍 |
| tools/get_user_item_id.py | 52.99% 🙂 | 53.31% 🙂 | 0.32% 👍 |

Here are some functions in these files that still need a tune-up:

| File | Function | Complexity | Length | Working Memory | Quality | Recommendation |
| --- | --- | --- | --- | --- | --- | --- |
| train.py | train_model | 25 😞 | 234 ⛔ | 16 ⛔ | 25.98% 😞 | Refactor to reduce nesting. Try splitting into smaller methods. Extract out complex expressions |
| network/lstm.py | Model.__init__ | 4 ⭐ | 187 😞 | 19 ⛔ | 43.84% 😞 | Try splitting into smaller methods. Extract out complex expressions |
| network/lgtrnet_v2.py | Model.__init__ | 0 ⭐ | 239 ⛔ | 16 ⛔ | 46.87% 😞 | Try splitting into smaller methods. Extract out complex expressions |
| network/lgtrnet_v1.py | LgTrBlock.__init__ | 2 ⭐ | 171 😞 | 15 😞 | 50.97% 🙂 | Try splitting into smaller methods. Extract out complex expressions |
| network/lgtrnet_v2.py | LgTrBlock.__init__ | 2 ⭐ | 171 😞 | 15 😞 | 50.97% 🙂 | Try splitting into smaller methods. Extract out complex expressions |
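
To make the recurring "try splitting into smaller methods" recommendation concrete, here is a generic illustration of extracting builder methods from a long nn.Module __init__; it is not the actual lstm.py or lgtrnet code, and all layer names and sizes are placeholders:

import torch.nn as nn

class Model(nn.Module):
    def __init__(self, vocab_size=1000, embed_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()
        # each block of layer construction moves into its own small method
        self.embedding = self._build_embedding(vocab_size, embed_dim)
        self.encoder = self._build_encoder(embed_dim, hidden_dim, num_layers)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def _build_embedding(self, vocab_size, embed_dim):
        return nn.Embedding(vocab_size, embed_dim)

    def _build_encoder(self, embed_dim, hidden_dim, num_layers):
        return nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)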

Legend and Explanation

The emojis denote the absolute quality of the code:

  • ⭐ excellent
  • 🙂 good
  • 😞 poor
  • ⛔ very poor

The 👍 and 👎 indicate whether the quality has improved or gotten worse with this pull request.


Please see our documentation here for details on how these metrics are calculated.

We are actively working on this report - lots more documentation and extra metrics to come!

Let us know what you think of it by mentioning @sourcery-ai in a comment.
