Skip to content

Commit

Permalink
Expose Task and Backbone (keras-team#1506)
Browse files Browse the repository at this point in the history
These are already exposed on KerasCV, and I think it is time to
also expose these in KerasNLP. This will give us a class to document
common model functionality to all backbones such as `enable_lora`
and `token_embedding` on keras.io.

It can also open up a path for writing a custom architecture outside
the library itself.
  • Loading branch information
mattdangerw authored and abuelnasr0 committed Apr 2, 2024
1 parent 8698f84 commit 29a87cb
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
2 changes: 2 additions & 0 deletions keras_nlp/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.bart.bart_backbone import BartBackbone
from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor
from keras_nlp.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM
Expand Down Expand Up @@ -130,6 +131,7 @@
from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer
from keras_nlp.models.t5.t5_backbone import T5Backbone
from keras_nlp.models.t5.t5_tokenizer import T5Tokenizer
from keras_nlp.models.task import Task
from keras_nlp.models.whisper.whisper_audio_feature_extractor import (
WhisperAudioFeatureExtractor,
)
Expand Down
3 changes: 2 additions & 1 deletion keras_nlp/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.utils.preset_utils import check_preset_class
Expand All @@ -20,7 +21,7 @@
from keras_nlp.utils.python_utils import format_docstring


@keras.saving.register_keras_serializable(package="keras_nlp")
@keras_nlp_export("keras_nlp.models.Backbone")
class Backbone(keras.Model):
def __init__(self, *args, dtype=None, **kwargs):
super().__init__(*args, **kwargs)
Expand Down
3 changes: 2 additions & 1 deletion keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from rich import markup
from rich import table as rich_table

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.utils.keras_utils import print_msg
Expand All @@ -26,7 +27,7 @@
from keras_nlp.utils.python_utils import format_docstring


@keras.saving.register_keras_serializable(package="keras_nlp")
@keras_nlp_export("keras_nlp.models.Task")
class Task(PipelineModel):
"""Base class for Task models."""

Expand Down

0 comments on commit 29a87cb

Please sign in to comment.