tagger and trainerXL is not working #250

Open
conneblock opened this issue Jan 8, 2025 · 4 comments

Comments

@conneblock

env: PYTHONPATH=/content/kohya-trainer
2025-01-08 17:43:27.177299: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-01-08 17:43:27.177353: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-01-08 17:43:27.179812: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/diffusers/utils/import_utils.py", line 710, in _get_module
    return importlib.import_module("." + module_name, self.__name__)
  File "/usr/lib/python3.10/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion/__init__.py", line 95, in <module>
    from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
  File "/usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_pndm_flax.py", line 20, in <module>
    import flax
  File "/usr/local/lib/python3.10/dist-packages/flax/__init__.py", line 24, in <module>
    from flax import core
  File "/usr/local/lib/python3.10/dist-packages/flax/core/__init__.py", line 24, in <module>
    from .lift import (
  File "/usr/local/lib/python3.10/dist-packages/flax/core/lift.py", line 27, in <module>
    from flax import traverse_util
  File "/usr/local/lib/python3.10/dist-packages/flax/traverse_util.py", line 66, in <module>
    from flax.core.scope import VariableDict
  File "/usr/local/lib/python3.10/dist-packages/flax/core/scope.py", line 55, in <module>
    from . import meta, partial_eval, tracers
  File "/usr/local/lib/python3.10/dist-packages/flax/core/meta.py", line 188, in <module>
    class Partitioned(struct.PyTreeNode, AxisMetadata[A]):
  File "/usr/lib/python3.10/abc.py", line 106, in __new__
    cls = super().__new__(mcls, name, bases, namespace, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/flax/struct.py", line 235, in __init_subclass__
    dataclass(cls, **kwargs)  # pytype: disable=wrong-arg-types
  File "/usr/local/lib/python3.10/dist-packages/flax/struct.py", line 150, in dataclass
    jax.tree_util.register_dataclass(data_clz, data_fields, meta_fields)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.tree_util' has no attribute 'register_dataclass'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/content/kohya-trainer/finetune/tag_images_by_wd14_tagger.py", line 15, in <module>
    import library.train_util as train_util
  File "/content/kohya-trainer/library/train_util.py", line 41, in <module>
    from diffusers import (
  File "<frozen importlib._bootstrap>", line 1075, in _handle_fromlist
  File "/usr/local/lib/python3.10/dist-packages/diffusers/utils/import_utils.py", line 701, in __getattr__
    value = getattr(module, name)
  File "/usr/local/lib/python3.10/dist-packages/diffusers/utils/import_utils.py", line 700, in __getattr__
    module = self._get_module(self._class_to_module[name])
  File "/usr/local/lib/python3.10/dist-packages/diffusers/utils/import_utils.py", line 712, in _get_module
    raise RuntimeError(
RuntimeError: Failed to import diffusers.pipelines.stable_diffusion because of the following error (look up to see its traceback):
module 'jax.tree_util' has no attribute 'register_dataclass'
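Both scripts fail on the same attribute lookup: the installed flax calls jax.tree_util.register_dataclass, and the installed jax does not provide it. A minimal diagnostic sketch, assuming Python 3.8+ (for importlib.metadata); the helper names installed_version, jax_has_register_dataclass and report are hypothetical and not part of kohya-trainer. It reports the versions in the diffusers -> flax -> jax chain without importing diffusers or flax:

```python
import importlib
import importlib.metadata


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def jax_has_register_dataclass():
    """Check the exact attribute whose absence raises the AttributeError above.

    Returns None when jax cannot be imported at all.
    """
    try:
        tree_util = importlib.import_module("jax.tree_util")
    except ImportError:
        return None
    # hasattr() returns False when the module's __getattr__ raises AttributeError.
    return hasattr(tree_util, "register_dataclass")


def report():
    """Collect the versions involved in the failing import chain."""
    return {
        "diffusers": installed_version("diffusers"),
        "flax": installed_version("flax"),
        "jax": installed_version("jax"),
        "jax.tree_util.register_dataclass": jax_has_register_dataclass(),
    }


if __name__ == "__main__":
    for name, value in report().items():
        print(f"{name}: {value}")
```

If the last line prints False while flax is installed, the environment matches the failure reported here.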

conneblock changed the title from "tagger is not working" to "tagger and trainerXL is not working" on Jan 8, 2025
@Khanykov01

Can confirm, I just got the same issue. It was fine in the afternoon, but now it gives an error.

@TheRamosOnline

TheRamosOnline commented Jan 8, 2025

Same here: the trainer was working fine all day, but now it doesn't work.

Traceback (most recent call last):
  File "/content/kohya-trainer/train_network_xl_wrapper.py", line 2, in <module>
    from sdxl_train_network import setup_parser, SdxlNetworkTrainer
  File "/content/kohya-trainer/sdxl_train_network.py", line 3, in <module>
    from library import sdxl_model_util, sdxl_train_util, train_util
  File "/content/kohya-trainer/library/sdxl_model_util.py", line 7, in <module>
    from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py", line 38, in <module>
    from .models import (
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/__init__.py", line 33, in <module>
    from .controlnet_flax import FlaxControlNetModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/controlnet_flax.py", line 16, in <module>
    import flax
  File "/usr/local/lib/python3.10/dist-packages/flax/__init__.py", line 24, in <module>
    from flax import core
  File "/usr/local/lib/python3.10/dist-packages/flax/core/__init__.py", line 24, in <module>
    from .lift import (
  File "/usr/local/lib/python3.10/dist-packages/flax/core/lift.py", line 27, in <module>
    from flax import traverse_util
  File "/usr/local/lib/python3.10/dist-packages/flax/traverse_util.py", line 66, in <module>
    from flax.core.scope import VariableDict
  File "/usr/local/lib/python3.10/dist-packages/flax/core/scope.py", line 55, in <module>
    from . import meta, partial_eval, tracers
  File "/usr/local/lib/python3.10/dist-packages/flax/core/meta.py", line 188, in <module>
    class Partitioned(struct.PyTreeNode, AxisMetadata[A]):
  File "/usr/lib/python3.10/abc.py", line 106, in __new__
    cls = super().__new__(mcls, name, bases, namespace, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/flax/struct.py", line 235, in __init_subclass__
    dataclass(cls, **kwargs)  # pytype: disable=wrong-arg-types
  File "/usr/local/lib/python3.10/dist-packages/flax/struct.py", line 150, in dataclass
    jax.tree_util.register_dataclass(data_clz, data_fields, meta_fields)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.tree_util' has no attribute 'register_dataclass'

@gwhitez

gwhitez commented Jan 8, 2025

Adding pip install flax==0.7.5 to the dependencies resolves the error.

It's already corrected here:
https://colab.research.google.com/github/gwhitez/Lora-Trainer-XL/blob/main/Fix_Lora_Trainer_XL.ipynb
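The pin can also be applied conditionally from a notebook cell. This is a sketch, assuming the flax==0.7.5 pin from the comment above is still the right one for the current Colab image; FLAX_PIN, pin_command, needs_pin and apply_pin are hypothetical names, not part of either notebook:

```python
import importlib.metadata
import subprocess
import sys

# Pin suggested in this thread; an assumption to re-check whenever the
# Colab image (and therefore its jax version) changes.
FLAX_PIN = "flax==0.7.5"


def pin_command(pin=FLAX_PIN):
    """Return the pip command that installs `pin` into the current interpreter."""
    return [sys.executable, "-m", "pip", "install", pin]


def needs_pin(pin=FLAX_PIN):
    """Return True if flax is installed at a version other than the pinned one."""
    _, wanted = pin.split("==", 1)
    try:
        installed = importlib.metadata.version("flax")
    except importlib.metadata.PackageNotFoundError:
        # Assumption: with no flax installed, diffusers skips its Flax modules,
        # so there is nothing to pin.
        return False
    return installed != wanted


def apply_pin(pin=FLAX_PIN):
    """Run pip only when needed; return True if pip was invoked."""
    if not needs_pin(pin):
        return False
    subprocess.check_call(pin_command(pin))
    return True


if __name__ == "__main__":
    print("pip ran:", apply_pin())
```

Modules already imported in a running Python process keep their old version, so the pin should be applied before the tagger or trainer script imports diffusers.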

@hollowstrawberry
Owner

Implemented the fix. Thank you, @gwhitez.
