Skip to content

Commit

Permalink
fix assert
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Mar 17, 2023
1 parent 4ae1f6d commit 570ebb7
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion palm_rlhf_pytorch/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def __init__(
self.causal = causal
self.attn_dropout = nn.Dropout(dropout)

assert version.parse(torch.__version__) >= version.parse('2.0.0'), 'in order to use flash attention, you must be using pytorch 2.0 or above'
self.use_flash_attn = use_flash_attn
assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

self.register_buffer("mask", None, persistent=False)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'PaLM-rlhf-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.0',
version = '0.1.1',
license='MIT',
description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 570ebb7

Please sign in to comment.