Skip to content

Commit

Permalink
Fix logging of GPU ID (#95)
Browse files Browse the repository at this point in the history
* Fix logging

* Fix torch CI GPU issue

* Fix CI GPU install problem

* One more attempt to fix  torch CI GPU issue

* One more chance Fix torch CI GPU issue

* Fix typo
  • Loading branch information
VibhuJawa authored Oct 9, 2024
1 parent ba49e35 commit d7e2643
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
15 changes: 6 additions & 9 deletions ci/test_gpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,15 @@ fi
echo "Installing pytorch,transformers and pytest to the environment for crossfit tests..."
mamba install \
cuda-version=$CUDA_VERSION \
conda-forge::pytorch \
conda-forge::transformers \
conda-forge::pytest \
"pytorch>=2.*=*cuda*" \
transformers \
pytest \
sentence-transformers \
sentencepiece \
-c conda-forge \
--override-channels \
-c nvidia \
--yes

# Have to install sentence-transformers from pip
# because conda-forge leads to a torch vision conflict
# which leads to it being installed on CPUs
pip3 install sentence-transformers sentencepiece

# Install the crossfit package in editable mode with test dependencies
pip3 install -e '.[test]'
# Running tests
Expand Down
5 changes: 4 additions & 1 deletion crossfit/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ def __init__(self, pre=None, cols=False, keep_cols=None):
self.pre = pre
self.cols = cols
self.keep_cols = keep_cols or []
self.worker_name = getattr(self.get_worker(), "name", 0)

@property
def worker_name(self):
return getattr(self.get_worker(), "name", 0)

def setup(self):
pass
Expand Down

0 comments on commit d7e2643

Please sign in to comment.