Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved Extensions API #6354

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions docs/07 - Extensions.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ The extensions framework is based on special functions and variables that you ca
| Function | Description |
|-------------|-------------|
| `def setup()` | Is executed when the extension gets imported. |
| `def ui()` | Creates custom gradio elements when the UI is launched. |
| `def ui()` | Obsolete, but still supported. Creates custom gradio elements when the UI is launched. |
| `def ui_block()` | Creates custom Gradio elements at the bottom of the Chat, Default and Notebook tabs. |
| `def ui_tab()` | Creates a tab for a large interface of extension. Similar to the deprecated is_tab=true in params. |
| `def ui_params()` | Creates a tab for extension settings in Parameters. |
| `def custom_css()` | Returns custom CSS as a string. It is applied whenever the web UI is loaded. |
| `def custom_js()` | Same as above but for javascript. |
| `def input_modifier(string, state, is_chat=False)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. |
Expand All @@ -48,7 +51,7 @@ The extensions framework is based on special functions and variables that you ca
| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See the `multimodal` extension for an example. |
| `def custom_tokenized_length(prompt)` | Used in conjunction with `tokenizer_modifier`, returns the length in tokens of `prompt`. See the `multimodal` extension for an example. |

Additionally, you can define a special `params` dictionary. In it, the `display_name` key is used to define the displayed name of the extension in the UI, and the `is_tab` key is used to define whether the extension should appear in a new tab. By default, extensions appear at the bottom of the "Text generation" tab.
Additionally, you can define a special `params` dictionary. In it, the `display_name` key is used to define the displayed name of the extension in the UI. The `is_tab` key is deprecated and it is better to write UIs in `def ui_tab():` instead, but is still supported if the UI is created in the deprecated `def ui():`

Example:

Expand Down Expand Up @@ -230,10 +233,31 @@ def setup():
"""
pass

def ui():
def ui_block():
"""
Gets executed when the UI is drawn. Custom gradio elements and
their corresponding event handlers should be defined here.
Gets executed when the UI is drawn. The custom gradio elements
that are used most often and their corresponding event handlers
should be defined here.

To learn about gradio components, check out the docs:
https://gradio.app/docs/
"""
pass

def ui_tab():
"""
Gets executed when the UI is drawn and creates a tab for the big UI.
Its gradio elements and corresponding event handlers should be defined here.

To learn about gradio components, check out the docs:
https://gradio.app/docs/
"""
pass

def ui_params():
"""
Executed when the user interface is rendered. Elements of the extension
settings and event handlers corresponding to them should be defined here.

To learn about gradio components, check out the docs:
https://gradio.app/docs/
Expand Down
27 changes: 24 additions & 3 deletions extensions/example/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,31 @@ def setup():
"""
pass

def ui():
def ui_block():
"""
Gets executed when the UI is drawn. Custom gradio elements and
their corresponding event handlers should be defined here.
Gets executed when the UI is drawn. The custom gradio elements
that are used most often and their corresponding event handlers
should be defined here.

To learn about gradio components, check out the docs:
https://gradio.app/docs/
"""
pass

def ui_tab():
"""
Gets executed when the UI is drawn and creates a tab for the big UI.
Its gradio elements and corresponding event handlers should be defined here.

To learn about gradio components, check out the docs:
https://gradio.app/docs/
"""
pass

def ui_params():
"""
Executed when the user interface is rendered. Elements of the extension
settings and event handlers corresponding to them should be defined here.

To learn about gradio components, check out the docs:
https://gradio.app/docs/
Expand Down
26 changes: 22 additions & 4 deletions modules/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,24 +188,42 @@ def _apply_custom_js():
def create_extensions_block():
to_display = []
for extension, name in iterator():
if hasattr(extension, "ui") and not (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
# Use ui_block if it is defined, otherwise use the old ui
if hasattr(extension, "ui_block"):
to_display.append((extension, name))
elif hasattr(extension, "ui") and not (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
to_display.append((extension, name))

# Creating the extension ui elements
if len(to_display) > 0:
with gr.Column(elem_id="extensions"):
for row in to_display:
extension, _ = row
extension.ui()
if hasattr(extension, "ui_block"):
extension.ui_block()
else:
extension.ui()


def create_extensions_tabs():
for extension, name in iterator():
if hasattr(extension, "ui") and (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
# Use ui_tab if it is defined, otherwise use the old ui with the is_tab parameter
if hasattr(extension, "ui_tab"):
display_name = getattr(extension, 'params', {}).get('display_name', name)
with gr.Tab(display_name, elem_classes="extension-tab"):
extension.ui_tab()
elif hasattr(extension, "ui") and (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
display_name = getattr(extension, 'params', {}).get('display_name', name)
with gr.Tab(display_name, elem_classes="extension-tab"):
extension.ui()

# Creates a tab in Parameters to hold the extension settings
def create_extensions_params():
for extension, name in iterator():
if hasattr(extension, "ui_params"):
display_name = getattr(extension, 'params', {}).get('display_name', name)
with gr.Tab(display_name):
extension.ui_params()


EXTENSION_MAP = {
"input": partial(_apply_string_extensions, "input_modifier"),
Expand Down
3 changes: 2 additions & 1 deletion modules/ui_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import gradio as gr

from modules import loaders, presets, shared, ui, ui_chat, utils
from modules import loaders, presets, shared, ui, ui_chat, utils, extensions
from modules.utils import gradio


Expand Down Expand Up @@ -102,6 +102,7 @@ def create_ui(default_preset):
shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming')

ui_chat.create_chat_settings_ui()
extensions.create_extensions_params()


def create_event_handlers():
Expand Down