
opt-125m-gptq pytorch model support #260

Open
AmosLewis opened this issue Jun 12, 2024 · 0 comments

| tests | model-run | onnx-import | torch-mlir | iree-compile | inference |
|---|---|---|---|---|---|
| pytorch/models/opt-125m-gptq | failed | notrun | notrun | notrun | notrun |

```shell
python ./run.py --torchmlirbuild ../../torch-mlir/build --tolerance 0.001 0.001 --cachedir ./huggingface_cache --ireebuild ../../iree-build -f pytorch -g models --mode onnx --report --torchtolinalg --tests pytorch/models/opt-125m-gptq
```

iree candidate-20240610.920

torch-mlir

```
commit 7e0e23c66820d1db548103acbdf1337f701dc5a3 (upstream/main)
Author: Sambhav Jain <[email protected]>
Date:   Sun Jun 9 00:32:49 2024 -0700

    Test custom op import with symbolic shapes (#3431)

    Tests the basic constructs of registering a custom op and its abstract
    implementations (with FakeTensors) in python, going through TorchDynamo
    export, followed by importing the shape expressions in the Torch
    dialect.

    Also fixes the importer were previously the symbolic bind op insertion
    was not gated in one place.
```

model-run.log

```
Traceback (most recent call last):
  File "/home/chi/src/SHARK-TestSuite/e2eshark/test-run/pytorch/models/opt-125m-gptq/runmodel.py", line 149, in <module>
    onnx_program = torch.onnx.export(model, E2ESHARK_CHECK["input"], onnx_name)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/torch-mlir/mlir_venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/home/chi/src/torch-mlir/mlir_venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1612, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/home/chi/src/torch-mlir/mlir_venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1138, in _model_to_graph
    graph = _optimize_graph(
            ^^^^^^^^^^^^^^^^
  File "/home/chi/src/torch-mlir/mlir_venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 677, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/torch-mlir/mlir_venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1969, in _run_symbolic_function
    raise errors.UnsupportedOperatorError(
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::bitwise_right_shift' to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.
```
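For context on why `aten::bitwise_right_shift` shows up at all: GPTQ checkpoints store several low-bit weights packed into each 32-bit integer, and dequantization recovers them with right shifts and masks. The sketch below (not from the issue; `unpack_4bit` is a hypothetical helper, and the 4-bit/8-per-word layout is an assumption for illustration) shows the pattern in plain Python:

```python
# Hypothetical illustration: assume eight 4-bit weight values are packed
# into one 32-bit integer, as GPTQ-style quantization does. Unpacking
# them uses right shifts and masks, which is why the traced model graph
# contains aten::bitwise_right_shift during ONNX export.

def unpack_4bit(packed: int, count: int = 8) -> list[int]:
    """Unpack `count` 4-bit values from one packed integer (low bits first)."""
    return [(packed >> (4 * i)) & 0xF for i in range(count)]

# Pack the values 1..8 into a single integer, then recover them.
packed = 0
for i, v in enumerate(range(1, 9)):
    packed |= v << (4 * i)

print(unpack_4bit(packed))  # [1, 2, 3, 4, 5, 6, 7, 8]
```

As possible workarounds (untested here), one could try exporting with a newer `opset_version`, since later exporter opsets added symbolics for the bitwise shift ops, or register a custom symbolic for `aten::bitwise_right_shift` via `torch.onnx.register_custom_op_symbolic` that lowers to the ONNX `BitShift` operator.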
