Skip to content

Commit

Permalink
Add utilities for instantiating JAX models from config files.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 662080465
  • Loading branch information
mjanusz authored and copybara-github committed Aug 12, 2024
1 parent 1b96457 commit 2060884
Showing 1 changed file with 5 additions and 29 deletions.
34 changes: 5 additions & 29 deletions ffn/training/import_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,10 @@
# ==============================================================================
"""Functions for dynamically importing symbols from modules."""

import importlib
from absl import logging
from connectomics.common import import_util


def import_symbol(specifier, default_packages='ffn.training.models'):
"""Imports a symbol from a python module.
The calling module must have the target module for the import as dependency.
Args:
specifier: full path specifier in format
[<packages>.]<module_name>.<model_class>, if packages is missing
``default_packages`` is used.
default_packages: chain of packages before module in format
<top_pack>.<sub_pack>.<subsub_pack> etc.
Returns:
symbol: object from module
"""
module_path, symbol_name = specifier.rsplit('.', 1)
try:
logging.info('Importing symbol %s from %s.%s',
symbol_name, default_packages, module_path)
module = importlib.import_module(default_packages + '.' + module_path)
except ImportError as e:
logging.info(e)
logging.info('Importing symbol %s from %s', symbol_name, module_path)
module = importlib.import_module(module_path)

symbol = getattr(module, symbol_name)
return symbol
def import_symbol(
specifier: str, default_packages: str = 'ffn.training.models'
):
return import_util.import_symbol(specifier, default_packages)

0 comments on commit 2060884

Please sign in to comment.