Skip to content

Commit

Permalink
set PJRT_DEVICE before loading torch_xla
Browse files Browse the repository at this point in the history
Mitigation for #326

PiperOrigin-RevId: 692234642
  • Loading branch information
chunnienc authored and copybara-github committed Nov 1, 2024
1 parent 284fe62 commit bd09407
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions ai_edge_torch/lowertools/torch_xla_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,14 @@
import gc
import itertools
import logging
import os
import tempfile
from typing import Any, Dict, Optional, Tuple, Union

if "PJRT_DEVICE" not in os.environ:
# https://github.com/google-ai-edge/ai-edge-torch/issues/326
os.environ["PJRT_DEVICE"] = "CPU"

from ai_edge_torch import model
from ai_edge_torch._convert import conversion_utils
from ai_edge_torch._convert import signature as signature_module
Expand Down

0 comments on commit bd09407

Please sign in to comment.