BUG: RuntimeError: Couldn't instantiate class <class 'mistral_inference.args.TransformerArgs'> using init args dict_keys(['dim', 'n_layers', 'vocab_size', 'model_type']) #221

Open
NM5035 opened this issue Sep 20, 2024 · 0 comments
Labels
bug Something isn't working


Python -VV

Python 3.10

Pip Freeze

absl-py==1.4.0
accelerate==0.34.2
aiohappyeyeballs==2.4.0
aiohttp==3.10.5
aiosignal==1.3.1
alabaster==0.7.16
albucore==0.0.15
albumentations==1.4.15
altair==4.2.2
annotated-types==0.7.0
anyio==3.7.1
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
array_record==0.5.1
arviz==0.19.0
astropy==6.1.3
astropy-iers-data==0.2024.9.16.0.32.21
astunparse==1.6.3
async-timeout==4.0.3
atpublic==4.1.0
attrs==24.2.0
audioread==3.0.1
autograd==1.7.0
babel==2.16.0
backcall==0.2.0
beautifulsoup4==4.12.3
bidict==0.23.1
bigframes==1.17.0
bigquery-magics==0.2.0
bleach==6.1.0
blinker==1.4
blis==0.7.11
blosc2==2.0.0
bokeh==3.4.3
bqplot==0.12.43
branca==0.7.2
build==1.2.2
CacheControl==0.14.0
cachetools==5.5.0
catalogue==2.0.10
causal-conv1d==1.4.0
certifi==2024.8.30
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.3.2
chex==0.1.86
clarabel==0.9.0
click==8.1.7
cloudpathlib==0.19.0
cloudpickle==2.2.1
cmake==3.30.3
cmdstanpy==1.2.4
colorcet==3.1.0
colorlover==0.3.0
colour==0.1.5
community==1.0.0b1
confection==0.1.5
cons==0.4.6
contextlib2==21.6.0
contourpy==1.3.0
cryptography==43.0.1
cuda-python==12.2.1
cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-24.4.1-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=57366e7ef09dc63e0b389aff20df6c37d91e2790065861ee31a4720149f5b694
cufflinks==0.17.3
cupy-cuda12x==12.2.0
cvxopt==1.3.2
cvxpy==1.5.3
cycler==0.12.1
cymem==2.0.8
Cython==3.0.11
dask==2024.8.0
datascience==0.17.6
db-dtypes==1.3.0
dbus-python==1.2.18
debugpy==1.6.6
decorator==4.4.2
defusedxml==0.7.1
distributed==2024.8.0
distro==1.7.0
dlib==19.24.2
dm-tree==0.1.8
docstring_parser==0.16
docutils==0.18.1
dopamine_rl==4.0.9
duckdb==1.1.0
earthengine-api==1.0.0
easydict==1.13
ecos==2.0.14
editdistance==0.8.1
eerepr==0.0.4
einops==0.8.0
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
entrypoints==0.4
et-xmlfile==1.1.0
etils==1.9.4
etuples==0.3.9
eval_type_backport==0.2.0
exceptiongroup==1.2.2
fastai==2.7.17
fastcore==1.7.8
fastdownload==0.0.7
fastjsonschema==2.20.0
fastprogress==1.0.3
fastrlock==0.8.2
filelock==3.16.0
fire==0.6.0
firebase-admin==6.5.0
Flask==2.2.5
flatbuffers==24.3.25
flax==0.8.4
folium==0.17.0
fonttools==4.53.1
frozendict==2.4.4
frozenlist==1.4.1
fsspec==2024.6.1
future==1.0.0
gast==0.6.0
gcsfs==2024.6.1
GDAL==3.6.4
gdown==5.1.0
geemap==0.34.1
gensim==4.3.3
geocoder==1.38.1
geographiclib==2.0
geopandas==1.0.1
geopy==2.4.1
gin-config==0.5.0
glob2==0.7
google==2.0.3
google-ai-generativelanguage==0.6.6
google-api-core==2.19.2
google-api-python-client==2.137.0
google-auth==2.27.0
google-auth-httplib2==0.2.0
google-auth-oauthlib==1.2.1
google-cloud-aiplatform==1.66.0
google-cloud-bigquery==3.25.0
google-cloud-bigquery-connection==1.15.5
google-cloud-bigquery-storage==2.26.0
google-cloud-bigtable==2.26.0
google-cloud-core==2.4.1
google-cloud-datastore==2.19.0
google-cloud-firestore==2.16.1
google-cloud-functions==1.16.5
google-cloud-iam==2.15.2
google-cloud-language==2.13.4
google-cloud-pubsub==2.23.1
google-cloud-resource-manager==1.12.5
google-cloud-storage==2.8.0
google-cloud-translate==3.15.5
google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz#sha256=e86f433b6968bdd1e17b9e4cf8e0b4f105b3f6a409ffcd410336d06853a81096
google-crc32c==1.6.0
google-generativeai==0.7.2
google-pasta==0.2.0
google-resumable-media==2.7.2
googleapis-common-protos==1.65.0
googledrivedownloader==0.4
graphviz==0.20.3
greenlet==3.1.0
grpc-google-iam-v1==0.13.1
grpcio==1.64.1
grpcio-status==1.48.2
gspread==6.0.2
gspread-dataframe==3.3.1
gym==0.25.2
gym-notices==0.0.8
h5netcdf==1.3.0
h5py==3.11.0
holidays==0.57
holoviews==1.19.1
html5lib==1.1
httpimport==1.4.0
httplib2==0.22.0
huggingface-hub==0.24.7
humanize==4.10.0
hyperopt==0.2.7
ibis-framework==8.0.0
idna==3.10
imageio==2.35.1
imageio-ffmpeg==0.5.1
imagesize==1.4.1
imbalanced-learn==0.12.3
imgaug==0.4.0
immutabledict==4.2.0
importlib_metadata==8.5.0
importlib_resources==6.4.5
imutils==0.5.4
inflect==7.4.0
iniconfig==2.0.0
intel-cmplr-lib-ur==2024.2.1
intel-openmp==2024.2.1
ipyevents==2.0.2
ipyfilechooser==0.6.0
ipykernel==5.5.6
ipyleaflet==0.19.2
ipyparallel==8.8.0
ipython==7.34.0
ipython-genutils==0.2.0
ipython-sql==0.5.0
ipytree==0.2.2
ipywidgets==7.7.1
itsdangerous==2.2.0
jax==0.4.26
jaxlib @ https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.26+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl#sha256=813cf1fe3e7ca4dbf5327d6e7b4fc8521e92d8bba073ee645ae0d5d036a25750
jeepney==0.7.1
jellyfish==1.1.0
jieba==0.42.1
Jinja2==3.1.4
joblib==1.4.2
jsonpickle==3.3.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter-client==6.1.12
jupyter-console==6.1.0
jupyter-leaflet==0.19.2
jupyter-server==1.24.0
jupyter_core==5.7.2
jupyterlab_pygments==0.3.0
jupyterlab_widgets==3.0.13
kaggle==1.6.17
kagglehub==0.2.9
keras==3.4.1
keyring==23.5.0
kiwisolver==1.4.7
langcodes==3.4.0
language_data==1.2.0
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
lazy_loader==0.4
libclang==18.1.1
librosa==0.10.2.post1
lightgbm==4.5.0
linkify-it-py==2.0.3
llvmlite==0.43.0
locket==1.0.0
logical-unification==0.4.6
lxml==4.9.4
mamba-ssm==2.2.2
marisa-trie==1.2.0
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.7.1
matplotlib-inline==0.1.7
matplotlib-venn==1.1.1
mdit-py-plugins==0.4.2
mdurl==0.1.2
miniKanren==1.0.3
missingno==0.5.2
mistral_common==1.4.2
mistral_inference==1.4.0
mistune==0.8.4
mizani==0.11.4
mkl==2024.2.2
ml-dtypes==0.4.1
mlxtend==0.23.1
more-itertools==10.5.0
moviepy==1.0.3
mpmath==1.3.0
msgpack==1.0.8
multidict==6.1.0
multipledispatch==1.0.0
multitasking==0.0.11
murmurhash==1.0.10
music21==9.1.0
namex==0.0.8
natsort==8.4.0
nbclassic==1.1.0
nbclient==0.10.0
nbconvert==6.5.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.3
nibabel==5.2.1
ninja==1.11.1.1
nltk==3.8.1
notebook==6.5.5
notebook_shim==0.2.4
numba==0.60.0
numexpr==2.10.1
numpy==1.26.4
nvidia-nccl-cu12==2.23.4
nvtx==0.2.10
oauth2client==4.1.3
oauthlib==3.2.2
opencv-contrib-python==4.10.0.84
opencv-python==4.10.0.84
opencv-python-headless==4.10.0.84
openpyxl==3.1.5
opt-einsum==3.3.0
optax==0.2.2
optree==0.12.1
orbax-checkpoint==0.6.4
osqp==0.6.7.post0
packaging==24.1
pandas==2.1.4
pandas-datareader==0.10.0
pandas-gbq==0.23.1
pandas-stubs==2.1.4.231227
pandocfilters==1.5.1
panel==1.4.5
param==2.1.1
parso==0.8.4
parsy==2.1
partd==1.4.2
pathlib==1.0.1
patsy==0.5.6
peewee==3.17.6
pexpect==4.9.0
pickleshare==0.7.5
pillow==10.4.0
pip-tools==7.4.1
platformdirs==4.3.4
plotly==5.15.0
plotnine==0.13.6
pluggy==1.5.0
polars==1.6.0
pooch==1.8.2
portpicker==1.5.2
prefetch_generator==1.0.3
preshed==3.0.9
prettytable==3.11.0
proglog==0.1.10
progressbar2==4.5.0
prometheus_client==0.20.0
promise==2.3
prompt_toolkit==3.0.47
prophet==1.1.5
proto-plus==1.24.0
protobuf==3.20.3
psutil==5.9.5
psycopg2==2.9.9
ptyprocess==0.7.0
py-cpuinfo==9.0.0
py4j==0.10.9.7
pyarrow==14.0.2
pyarrow-hotfix==0.6
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycocotools==2.0.8
pycparser==2.22
pydantic==2.9.2
pydantic_core==2.23.4
pydata-google-auth==1.8.2
pydot==3.0.1
pydot-ng==2.0.0
pydotplus==2.0.2
PyDrive==1.3.1
PyDrive2==1.20.0
pyerfa==2.0.1.4
pygame==2.6.0
Pygments==2.18.0
PyGObject==3.42.1
PyJWT==2.9.0
pymc==5.16.2
pymystem3==0.2.0
pynvjitlink-cu12==0.3.0
pyogrio==0.9.0
PyOpenGL==3.1.7
pyOpenSSL==24.2.1
pyparsing==3.1.4
pyperclip==1.9.0
pyproj==3.6.1
pyproject_hooks==1.1.0
pyshp==2.3.1
PySocks==1.7.1
pytensor==2.25.4
pytest==7.4.4
python-apt==2.4.0
python-box==7.2.0
python-dateutil==2.8.2
python-louvain==0.16
python-slugify==8.0.4
python-utils==3.8.2
pytz==2024.2
pyviz_comms==3.0.3
PyYAML==6.0.2
pyzmq==24.0.1
qdldl==0.1.7.post4
ratelim==0.1.6
referencing==0.35.1
regex==2024.9.11
requests==2.32.3
requests-oauthlib==1.3.1
requirements-parser==0.9.0
rich==13.8.1
rmm-cu12==24.4.0
rpds-py==0.20.0
rpy2==3.4.2
rsa==4.9
safetensors==0.4.5
scikit-image==0.23.2
scikit-learn==1.3.2
scipy==1.13.1
scooby==0.10.0
scs==3.2.7
seaborn==0.13.1
SecretStorage==3.3.1
Send2Trash==1.8.3
sentencepiece==0.2.0
shapely==2.0.6
shellingham==1.5.4
simple-parsing==0.1.6
six==1.16.0
sklearn-pandas==2.2.0
smart-open==7.0.4
sniffio==1.3.1
snowballstemmer==2.2.0
sortedcontainers==2.4.0
soundfile==0.12.1
soupsieve==2.6
soxr==0.5.0.post1
spacy==3.7.6
spacy-legacy==3.0.12
spacy-loggers==1.0.5
Sphinx==5.0.2
sphinxcontrib-applehelp==2.0.0
sphinxcontrib-devhelp==2.0.0
sphinxcontrib-htmlhelp==2.1.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==2.0.0
sphinxcontrib-serializinghtml==2.0.0
SQLAlchemy==2.0.35
sqlglot==20.11.0
sqlparse==0.5.1
srsly==2.4.8
stanio==0.5.1
statsmodels==0.14.3
StrEnum==0.4.15
sympy==1.13.2
tables==3.8.0
tabulate==0.9.0
tbb==2021.13.1
tblib==3.0.0
tenacity==9.0.0
tensorboard==2.17.0
tensorboard-data-server==0.7.2
tensorflow==2.17.0
tensorflow-datasets==4.9.6
tensorflow-hub==0.16.1
tensorflow-io-gcs-filesystem==0.37.1
tensorflow-metadata==1.15.0
tensorflow-probability==0.24.0
tensorstore==0.1.65
termcolor==2.4.0
terminado==0.18.1
text-unidecode==1.3
textblob==0.17.1
tf-slim==1.1.0
tf_keras==2.17.0
thinc==8.2.5
threadpoolctl==3.5.0
tifffile==2024.8.30
tiktoken==0.7.0
tinycss2==1.3.0
tokenizers==0.19.1
toml==0.10.2
tomli==2.0.1
toolz==0.12.1
torch @ https://download.pytorch.org/whl/cu121_full/torch-2.4.1%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=f3ed9a2b7f8671b2b32a2f036d1b81055eb3ad9b18ba43b705aa34bae4289e1a
torchaudio @ https://download.pytorch.org/whl/cu121_full/torchaudio-2.4.1%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=da8c87c80a1c1376a48dc33eef30b03bbdf1df25a05bd2b1c620b8811c7b19be
torchsummary==1.5.1
torchvision @ https://download.pytorch.org/whl/cu121_full/torchvision-0.19.1%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=b8cc4bf381b75522995b601e07a1b433b5fd925dc3e34a7fa6cd22f449d65379
tornado==6.3.3
tqdm==4.66.5
traitlets==5.7.1
traittypes==0.2.1
transformers==4.44.2
triton==3.0.0
tweepy==4.14.0
typeguard==4.3.0
typer==0.12.5
types-pytz==2024.2.0.20240913
types-setuptools==75.1.0.20240917
typing_extensions==4.12.2
tzdata==2024.1
tzlocal==5.2
uc-micro-py==1.0.3
uritemplate==4.1.1
urllib3==2.0.7
vega-datasets==0.9.0
wadllib==1.3.6
wasabi==1.1.3
wcwidth==0.2.13
weasel==0.4.1
webcolors==24.8.0
webencodings==0.5.1
websocket-client==1.8.0
Werkzeug==3.0.4
widgetsnbextension==3.6.9
wordcloud==1.9.3
wrapt==1.16.0
xarray==2024.9.0
xarray-einstats==0.7.0
xformers==0.0.28.post1
xgboost==2.1.1
xlrd==2.0.1
xyzservices==2024.9.0
yarl==1.11.1
yellowbrick==1.5
yfinance==0.2.43
zict==3.0.0
zipp==3.20.2

Reproduction Steps

  1. Install the dependencies: pip install mamba_ssm causal-conv1d
  2. Apply the fix from state-spaces/mamba#501 ("Fix issue 496").
  3. Download mistralai/Mamba-Codestral-7B-v0.1 to a local directory under /content in the Colab environment.
  4. Run inference with Mamba-Codestral-7B:
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate

from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest

# load tokenizer
mistral_tokenizer = MistralTokenizer.from_file("/content/mistral_models/mamba-codestral/tokenizer.model.v3")
# chat completion request
completion_request = ChatCompletionRequest(messages=[UserMessage(content="Write a sample of C code.")])
# encode message
tokens = mistral_tokenizer.encode_chat_completion(completion_request).tokens
# load model
model = Transformer.from_folder("/content/mistral_models/mamba-codestral")
# generate results
out_tokens, _ = generate([tokens], model, max_tokens=64, temperature=0.0, eos_id=mistral_tokenizer.instruct_tokenizer.tokenizer.eos_id)
# decode generated tokens
result = mistral_tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
print(result)
  5. Got this error:
WARNING:simple_parsing.helpers.serialization.serializable:Couldn't find the field 'head_dim' in the dict with keys ['dim', 'n_layers', 'vocab_size', 'n_groups', 'rms_norm', 'residual_in_fp32', 'fused_add_norm', 'pad_vocab_size_multiple', 'tie_embeddings', 'model_type']
WARNING:simple_parsing.helpers.serialization.serializable:Couldn't find the field 'hidden_dim' in the dict with keys ['dim', 'n_layers', 'vocab_size', 'n_groups', 'rms_norm', 'residual_in_fp32', 'fused_add_norm', 'pad_vocab_size_multiple', 'tie_embeddings', 'model_type']
WARNING:simple_parsing.helpers.serialization.serializable:Couldn't find the field 'n_heads' in the dict with keys ['dim', 'n_layers', 'vocab_size', 'n_groups', 'rms_norm', 'residual_in_fp32', 'fused_add_norm', 'pad_vocab_size_multiple', 'tie_embeddings', 'model_type']
WARNING:simple_parsing.helpers.serialization.serializable:Couldn't find the field 'n_kv_heads' in the dict with keys ['dim', 'n_layers', 'vocab_size', 'n_groups', 'rms_norm', 'residual_in_fp32', 'fused_add_norm', 'pad_vocab_size_multiple', 'tie_embeddings', 'model_type']
WARNING:simple_parsing.helpers.serialization.serializable:Couldn't find the field 'norm_eps' in the dict with keys ['dim', 'n_layers', 'vocab_size', 'n_groups', 'rms_norm', 'residual_in_fp32', 'fused_add_norm', 'pad_vocab_size_multiple', 'tie_embeddings', 'model_type']
WARNING:simple_parsing.helpers.serialization.serializable:Dropping extra args {'n_groups': 8, 'rms_norm': True, 'residual_in_fp32': True, 'fused_add_norm': True, 'pad_vocab_size_multiple': 1, 'tie_embeddings': False}
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/simple_parsing/helpers/serialization/serializable.py in from_dict(cls, d, drop_extra_fields)
    896     try:
--> 897         instance = cls(**init_args)  # type: ignore
    898     except TypeError as e:

TypeError: TransformerArgs.__init__() missing 5 required positional arguments: 'head_dim', 'hidden_dim', 'n_heads', 'n_kv_heads', and 'norm_eps'

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
3 frames
<ipython-input-3-8c814f063ac4> in <cell line: 17>()
     15 tokens = mistral_tokenizer.encode_chat_completion(completion_request).tokens
     16 # load model
---> 17 model = Transformer.from_folder("/content/mistral_models/mamba-codestral")
     18 # generate results
     19 out_tokens, _ = generate([tokens], model, max_tokens=64, temperature=0.0, eos_id=mistral_tokenizer.instruct_tokenizer.tokenizer.eos_id)

/usr/local/lib/python3.10/dist-packages/mistral_inference/transformer.py in from_folder(folder, max_batch_size, num_pipeline_ranks, device, dtype)
    260     ) -> "Transformer":
    261         with open(Path(folder) / "params.json", "r") as f:
--> 262             model_args = TransformerArgs.from_dict(json.load(f))
    263         model_args.max_batch_size = max_batch_size
    264         if num_pipeline_ranks > 1:

/usr/local/lib/python3.10/dist-packages/simple_parsing/helpers/serialization/serializable.py in from_dict(cls, obj, drop_extra_fields)
    252         Passing `drop_extra_fields=False` forces the above-mentioned behaviour.
    253         """
--> 254         return from_dict(cls, obj, drop_extra_fields=drop_extra_fields)
    255 
    256     def dump(self, fp: IO[str], dump_fn: DumpFn = json.dump) -> None:

/usr/local/lib/python3.10/dist-packages/simple_parsing/helpers/serialization/serializable.py in from_dict(cls, d, drop_extra_fields)
    898     except TypeError as e:
    899         # raise RuntimeError(f"Couldn't instantiate class {cls} using init args {init_args}.")
--> 900         raise RuntimeError(
    901             f"Couldn't instantiate class {cls} using init args {init_args.keys()}: {e}"
    902         )

RuntimeError: Couldn't instantiate class <class 'mistral_inference.args.TransformerArgs'> using init args dict_keys(['dim', 'n_layers', 'vocab_size', 'model_type']): TransformerArgs.__init__() missing 5 required positional arguments: 'head_dim', 'hidden_dim', 'n_heads', 'n_kv_heads', and 'norm_eps'
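
Reading the log, the checkpoint's params.json seems to carry Mamba-style keys while the loader builds a transformer-style argument class from it. The sketch below reproduces that mismatch in isolation. `TransformerArgsSketch` is a hypothetical stand-in, not the real `mistral_inference.args.TransformerArgs`: only the field names quoted in the traceback are mirrored, the field types are assumed, and the `"mamba"` value and the zero placeholders are assumptions because the log prints keys only.

```python
from dataclasses import dataclass, fields


# Hypothetical stand-in for mistral_inference.args.TransformerArgs.
# Field names come from the traceback; the types are assumed.
@dataclass
class TransformerArgsSketch:
    dim: int
    n_layers: int
    vocab_size: int
    head_dim: int
    hidden_dim: int
    n_heads: int
    n_kv_heads: int
    norm_eps: float
    model_type: str = "transformer"


# Keys listed in the warnings for this checkpoint's params.json.
# The zeros and the "mamba" value are placeholders (assumed); the
# other values are the ones printed in the "Dropping extra args" line.
mamba_params = {
    "dim": 0, "n_layers": 0, "vocab_size": 0,
    "n_groups": 8, "rms_norm": True, "residual_in_fp32": True,
    "fused_add_norm": True, "pad_vocab_size_multiple": 1,
    "tie_embeddings": False, "model_type": "mamba",
}


def split_keys(cls, params):
    """Split params into kept init args, dropped extras, and missing fields."""
    names = [f.name for f in fields(cls)]
    kept = {k: v for k, v in params.items() if k in names}
    dropped = sorted(k for k in params if k not in names)
    missing = [n for n in names if n not in params]
    return kept, dropped, missing


kept, dropped, missing = split_keys(TransformerArgsSketch, mamba_params)
print("kept:", list(kept))
print("dropped:", dropped)
print("missing:", missing)

# Building the dataclass from the kept keys fails the same way as above.
try:
    TransformerArgsSketch(**kept)
    error_message = None
except TypeError as exc:
    error_message = str(exc)
    print(error_message)
```

The kept keys match the `dict_keys([...])` in the RuntimeError, and the five missing names match the TypeError, which suggests the failure comes from the choice of argument class rather than from a corrupted download.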

Expected Behavior

I expect Mamba-Codestral to return an answer to my query.

Additional Context

No response

Suggested Solutions

No response
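
One possible direction, not confirmed in this report: choose the loader class from the `model_type` field in params.json instead of always calling `Transformer.from_folder`. The sketch below is a minimal illustration under stated assumptions. `read_model_type` and `load_model` are hypothetical helper names, the `"transformer"` default is assumed, and `mistral_inference.mamba.Mamba` with a `from_folder` method is assumed to exist in the installed version (1.4.0 here) and should be checked before relying on it.

```python
import json
import tempfile
from pathlib import Path


def read_model_type(folder, default="transformer"):
    """Return the "model_type" value from <folder>/params.json.

    The default for checkpoints without the key is an assumption.
    """
    with open(Path(folder) / "params.json", "r") as f:
        return json.load(f).get("model_type", default)


def load_model(folder):
    """Hypothetical helper: pick the loader class from params.json."""
    if read_model_type(folder) == "mamba":
        # Assumption: this module and class exist in the installed version.
        from mistral_inference.mamba import Mamba
        return Mamba.from_folder(folder)
    from mistral_inference.transformer import Transformer
    return Transformer.from_folder(folder)


# Self-contained check of the detection step with a throwaway params.json.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "params.json").write_text(
        json.dumps({"dim": 1, "model_type": "mamba"})
    )
    detected = read_model_type(tmp)
print(detected)
```

With the folder from the reproduction steps, the call would be `load_model("/content/mistral_models/mamba-codestral")`. Whether `generate` accepts the resulting model unchanged is not verified here.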

NM5035 added the bug label Sep 20, 2024