Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add back history and reset subcommand in magics #997

Merged
merged 12 commits into from
Oct 7, 2024
28 changes: 28 additions & 0 deletions docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,34 @@ A function that computes the lowest common multiples of two integers, and
a function that runs 5 test cases of the lowest common multiple function
```

### Configuring the amount of history to include in the context

By default, two previous Human/AI message exchanges are included in the context of the new prompt.
You can change this using the IPython `%config` magic, for example:

```python
%config AiMagics.max_history = 4
```

Note that old messages are still kept locally in memory,
so they will be included in the context of the next prompt after raising the `max_history` value.

You can configure the value for all notebooks
by specifying `c.AiMagics.max_history` traitlet in `ipython_config.py`, for example:

```python
c.AiMagics.max_history = 4
```

### Clearing the chat history

You can run the `%ai reset` line magic command to clear the chat history. After you do this,
previous magic commands you've run will no longer be added as context in requests.

```
%ai reset
```

### Interpolating in prompts

Using curly brace syntax, you can include variables and other Python expressions in your
Expand Down
44 changes: 41 additions & 3 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from jupyter_ai_magics.utils import decompose_model_id, get_lm_providers
from langchain.chains import LLMChain
from langchain.schema import HumanMessage
from langchain_core.messages import AIMessage

from ._version import __version__
from .parsers import (
Expand All @@ -24,6 +25,7 @@
HelpArgs,
ListArgs,
RegisterArgs,
ResetArgs,
UpdateArgs,
VersionArgs,
cell_magic_parser,
Expand Down Expand Up @@ -144,9 +146,18 @@ class AiMagics(Magics):
config=True,
)

max_history = traitlets.Int(
default_value=2,
allow_none=False,
help="""Maximum number of exchanges (user/assistant) to include in the history
when invoking a chat model, defaults to 2.
""",
config=True,
)

def __init__(self, shell):
super().__init__(shell)
self.transcript_openai = []
self.transcript = []

# suppress warning when using old Anthropic provider
warnings.filterwarnings(
Expand Down Expand Up @@ -437,6 +448,12 @@ def handle_error(self, args: ErrorArgs):

return self.run_ai_cell(cell_args, prompt)

def _append_exchange(self, prompt: str, output: str):
"""Appends a conversational exchange between user and an OpenAI Chat
model to a transcript that will be included in future exchanges."""
self.transcript.append(HumanMessage(prompt))
self.transcript.append(AIMessage(output))

dlqqq marked this conversation as resolved.
Show resolved Hide resolved
def _decompose_model_id(self, model_id: str):
"""Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate."""
# custom_model_registry maps keys to either a model name (a string) or an LLMChain.
Expand Down Expand Up @@ -500,6 +517,9 @@ def handle_list(self, args: ListArgs):
def handle_version(self, args: VersionArgs):
return __version__

def handle_reset(self, args: ResetArgs):
self.transcript = []

def run_ai_cell(self, args: CellArgs, prompt: str):
provider_id, local_model_id = self._decompose_model_id(args.model_id)

Expand Down Expand Up @@ -577,13 +597,29 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
ip = self.shell
prompt = prompt.format_map(FormatDict(ip.user_ns))

context = self.transcript[-2 * self.max_history :] if self.max_history else []
if provider.is_chat_provider:
result = provider.generate([[HumanMessage(content=prompt)]])
result = provider.generate([[*context, HumanMessage(content=prompt)]])
else:
# generate output from model via provider
result = provider.generate([prompt])
if context:
inputs = [
(
f"AI: {message.content}"
if message.type == "ai"
else f"{message.type.title()}: {message.content}"
)
for message in context + [HumanMessage(content=prompt)]
]
else:
inputs = [prompt]
result = provider.generate(inputs)

output = result.generations[0][0].text

# append exchange to transcript
self._append_exchange(prompt, output)

md = {"jupyter_ai": {"provider_id": provider_id, "model_id": local_model_id}}

return self.display_output(output, args.format, md)
Expand Down Expand Up @@ -628,6 +664,8 @@ def ai(self, line, cell=None):
return self.handle_update(args)
if args.type == "version":
return self.handle_version(args)
if args.type == "reset":
return self.handle_reset(args)
except ValueError as e:
print(e, file=sys.stderr)
return
Expand Down
13 changes: 13 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ class UpdateArgs(BaseModel):
target: str


class ResetArgs(BaseModel):
type: Literal["reset"] = "reset"


class LineMagicGroup(click.Group):
"""Helper class to print the help string for cell magics as well when
`%ai --help` is called."""
Expand Down Expand Up @@ -277,3 +281,12 @@ def register_subparser(**kwargs):
def register_subparser(**kwargs):
"""Update an alias called NAME to refer to the model or chain named TARGET."""
return UpdateArgs(**kwargs)


@line_magic_parser.command(
name="reset",
short_help="Clear the conversation transcript.",
)
def register_subparser(**kwargs):
"""Clear the conversation transcript."""
return ResetArgs()
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from unittest.mock import patch
import os
from unittest.mock import Mock, patch

import pytest
from IPython import InteractiveShell
from IPython.core.display import Markdown
from jupyter_ai_magics.magics import AiMagics
from langchain_core.messages import AIMessage, HumanMessage
from pytest import fixture
from traitlets.config.loader import Config

Expand Down Expand Up @@ -48,3 +52,71 @@ def test_default_model_error_line(ip):
assert mock_run.called
cell_args = mock_run.call_args.args[0]
assert cell_args.model_id == "my-favourite-llm"


PROMPT = HumanMessage(
content=("Write code for me please\n\nProduce output in markdown format only.")
)
RESPONSE = AIMessage(content="Leet code")
AI1 = AIMessage("ai1")
H1 = HumanMessage("h1")
AI2 = AIMessage("ai2")
H2 = HumanMessage("h2")
AI3 = AIMessage("ai3")


@pytest.mark.parametrize(
["transcript", "max_history", "expected_context"],
[
([], 3, [PROMPT]),
([AI1], 0, [PROMPT]),
([AI1], 1, [AI1, PROMPT]),
([H1, AI1], 0, [PROMPT]),
([H1, AI1], 1, [H1, AI1, PROMPT]),
([AI1, H1, AI2], 0, [PROMPT]),
([AI1, H1, AI2], 1, [H1, AI2, PROMPT]),
([AI1, H1, AI2], 2, [AI1, H1, AI2, PROMPT]),
([H1, AI1, H2, AI2], 0, [PROMPT]),
([H1, AI1, H2, AI2], 1, [H2, AI2, PROMPT]),
([H1, AI1, H2, AI2], 2, [H1, AI1, H2, AI2, PROMPT]),
([AI1, H1, AI2, H2, AI3], 0, [PROMPT]),
([AI1, H1, AI2, H2, AI3], 1, [H2, AI3, PROMPT]),
([AI1, H1, AI2, H2, AI3], 2, [H1, AI2, H2, AI3, PROMPT]),
([AI1, H1, AI2, H2, AI3], 3, [AI1, H1, AI2, H2, AI3, PROMPT]),
],
)
def test_max_history(ip, transcript, max_history, expected_context):
ip.extension_manager.load_extension("jupyter_ai_magics")
ai_magics = ip.magics_manager.registry["AiMagics"]
ai_magics.transcript = transcript.copy()
ai_magics.max_history = max_history
provider = ai_magics._get_provider("openrouter")
with patch.object(provider, "generate") as generate, patch.dict(
os.environ, OPENROUTER_API_KEY="123"
):
generate.return_value.generations = [[Mock(text="Leet code")]]
result = ip.run_cell_magic(
"ai",
"openrouter:anthropic/claude-3.5-sonnet",
cell="Write code for me please",
)
provider.generate.assert_called_once_with([expected_context])
assert isinstance(result, Markdown)
assert result.data == "Leet code"
assert result.filename is None
assert result.metadata == {
"jupyter_ai": {
"model_id": "anthropic/claude-3.5-sonnet",
"provider_id": "openrouter",
}
}
assert result.url is None
assert ai_magics.transcript == [*transcript, PROMPT, RESPONSE]


def test_reset(ip):
ip.extension_manager.load_extension("jupyter_ai_magics")
ai_magics = ip.magics_manager.registry["AiMagics"]
ai_magics.transcript = [AI1, H1, AI2, H2, AI3]
result = ip.run_line_magic("ai", "reset")
assert ai_magics.transcript == []
Loading