Skip to content

Commit

Permalink
refactor autocomplete to remove hardcoded '/' and '@' prefix
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelchia committed Sep 19, 2024
1 parent d93da8f commit 139e37a
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 56 deletions.
37 changes: 26 additions & 11 deletions packages/jupyter-ai/jupyter_ai/context_providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class ContextCommand(BaseModel):

@property
def id(self) -> str:
return self.cmd.partition(":")[0][1:]
return self.cmd.partition(":")[0]

@property
def arg(self) -> Optional[str]:
Expand All @@ -122,11 +122,17 @@ def __hash__(self) -> int:


class BaseCommandContextProvider(BaseContextProvider):
id_prefix: ClassVar[str] = "@"
only_start: ClassVar[bool] = False
requires_arg: ClassVar[bool] = False
remove_from_prompt: ClassVar[bool] = (
False # whether the command should be removed from prompt
)

@property
def command_id(self) -> str:
return self.id_prefix + self.id

@property
def pattern(self) -> str:
# arg pattern allows for arguments between quotes or spaces with escape character ('\ ')
Expand All @@ -153,20 +159,14 @@ def get_arg_options(self, arg_prefix: str) -> List[ListOptionsEntry]:
"""
if self.requires_arg:
# default implementation that should be modified if 'requires_arg' is True
return [
ListOptionsEntry.from_arg(
type="@",
id=self.id,
description=self.description,
arg=arg_prefix,
is_complete=True,
)
]
return [self._make_arg_option(arg_prefix)]
return []

def _find_commands(self, text: str) -> List[ContextCommand]:
# finds commands of the context provider in the text
matches = re.finditer(self.pattern, text)
matches = list(re.finditer(self.pattern, text))
if self.only_start:
matches = [match for match in matches if match.start() == 0]
results = []
for match in matches:
if not _is_within_backticks(match, text):
Expand All @@ -178,6 +178,21 @@ def _replace_command(self, command: ContextCommand) -> str:
return ""
return command.cmd

def _make_arg_option(
self,
arg: str,
*,
is_complete: bool = True,
description: Optional[str] = None,
) -> ListOptionsEntry:
return ListOptionsEntry.from_arg(
id=self.command_id,
description=description or self.description,
only_start=self.only_start,
arg=arg,
is_complete=is_complete,
)


def _is_within_backticks(match, text):
# potentially buggy if there is a stray backtick in text
Expand Down
16 changes: 7 additions & 9 deletions packages/jupyter-ai/jupyter_ai/context_providers/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,24 @@ def get_arg_options(self, arg_prefix: str) -> List[ListOptionsEntry]:
path_prefix = arg_prefix if is_abs else os.path.join(self.base_dir, arg_prefix)
path_prefix = path_prefix
return [
self._make_option(path, is_abs, is_dir)
self._make_arg_option(
arg=self._make_path(path, is_abs, is_dir),
description="Directory" if is_dir else "File",
is_complete=not is_dir,
)
for path in glob.glob(path_prefix + "*")
if (
(is_dir := os.path.isdir(path))
or os.path.splitext(path)[1] in SUPPORTED_EXTS
)
]

def _make_option(self, path: str, is_abs: bool, is_dir: bool) -> ListOptionsEntry:
def _make_path(self, path: str, is_abs: bool, is_dir: bool) -> str:
if not is_abs:
path = os.path.relpath(path, self.base_dir)
if is_dir:
path += "/"
return ListOptionsEntry.from_arg(
type="@",
id=self.id,
description="Directory" if is_dir else "File",
arg=path,
is_complete=not is_dir,
)
return path

async def make_context_prompt(self, message: HumanChatMessage) -> str:
commands = set(self._find_commands(message.prompt))
Expand Down
18 changes: 14 additions & 4 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,15 @@ def get(self):
def post(self):
try:
data = self.get_json_body()
context_provider = self.context_providers.get(data["id"])
context_provider = next(
(
cp
for cp in self.context_providers.values()
if isinstance(cp, BaseCommandContextProvider)
and cp.command_id == data["id"]
),
None,
)
cmd = data["cmd"]
response = ListOptionsResponse()

Expand Down Expand Up @@ -658,7 +666,9 @@ def _get_slash_command_options(self) -> List[ListOptionsEntry]:

options.append(
ListOptionsEntry.from_command(
type="/", id=routing_type.slash_id, description=chat_handler.help
id="/" + routing_type.slash_id,
description=chat_handler.help,
only_start=True,
)
)
options.sort(key=lambda opt: opt.id)
Expand All @@ -667,9 +677,9 @@ def _get_slash_command_options(self) -> List[ListOptionsEntry]:
def _get_context_provider_options(self) -> List[ListOptionsEntry]:
options = [
ListOptionsEntry.from_command(
type="@",
id=context_provider.id,
id=context_provider.command_id,
description=context_provider.description,
only_start=context_provider.only_start,
requires_arg=context_provider.requires_arg,
)
for context_provider in self.context_providers.values()
Expand Down
21 changes: 14 additions & 7 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,34 +266,41 @@ class ListSlashCommandsResponse(BaseModel):


class ListOptionsEntry(BaseModel):
type: Literal["/", "@"]
id: str
# includes the command prefix. e.g. "/clear", "@file".
label: str
description: str
only_start: bool
# only allows autocomplete to be triggered if command is at start of input

@classmethod
def from_command(
cls,
type: Literal["/", "@"],
id: str,
description: str,
only_start: bool = False,
requires_arg: bool = False,
):
label = type + id + (":" if requires_arg else " ")
return cls(type=type, id=id, description=description, label=label)
label = id + (":" if requires_arg else " ")
return cls(id=id, description=description, label=label, only_start=only_start)

@classmethod
def from_arg(
cls,
type: Literal["/", "@"],
id: str,
description: str,
arg: str,
only_start: bool = False,
is_complete: bool = True,
):
arg = arg.replace("\\ ", " ").replace(" ", "\\ ") # escape spaces
label = type + id + ":" + arg + (" " if is_complete else "")
return cls(type=type, id=id, description=description, label=label)
label = id + ":" + arg + (" " if is_complete else "")
return cls(
id=id,
description=description,
label=label,
only_start=only_start,
)


class ListOptionsResponse(BaseModel):
Expand Down
54 changes: 30 additions & 24 deletions packages/jupyter-ai/src/components/chat-input.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,15 @@ type ChatInputProps = {
* unclear whether custom icons should be defined within a Lumino plugin (in the
* frontend) or served from a static server route (in the backend).
*/
const DEFAULT_SLASH_COMMAND_ICONS: Record<string, JSX.Element> = {
ask: <FindInPage />,
clear: <HideSource />,
export: <Download />,
fix: <AutoFixNormal />,
generate: <MenuBook />,
help: <Help />,
learn: <School />,
const DEFAULT_COMMAND_ICONS: Record<string, JSX.Element> = {
'/ask': <FindInPage />,
'/clear': <HideSource />,
'/export': <Download />,
'/fix': <AutoFixNormal />,
'/generate': <MenuBook />,
'/help': <Help />,
'/learn': <School />,
'@file': <FindInPage />,
unknown: <MoreHoriz />
};

Expand All @@ -64,9 +65,9 @@ function renderAutocompleteOption(
option: AiService.AutocompleteOption
): JSX.Element {
const icon =
option.id in DEFAULT_SLASH_COMMAND_ICONS
? DEFAULT_SLASH_COMMAND_ICONS[option.id]
: DEFAULT_SLASH_COMMAND_ICONS.unknown;
option.id in DEFAULT_COMMAND_ICONS
? DEFAULT_COMMAND_ICONS[option.id]
: DEFAULT_COMMAND_ICONS.unknown;

return (
<li {...optionProps}>
Expand Down Expand Up @@ -120,12 +121,12 @@ export function ChatInput(props: ChatInputProps): JSX.Element {
useEffect(() => {
async function getAutocompleteArgOptions() {
let options: AiService.AutocompleteOption[] = [];
const lastWord = input.split(/(?<!\\)\s+/).pop() || '';
if (lastWord.startsWith('@') && lastWord.includes(':')) {
const lastWord = getLastWord(input);
if (lastWord.includes(':')) {
const id = lastWord.split(':', 1)[0];
// get option that matches the command
const option = autocompleteCommandOptions.find(
option => option.id === id.slice(1) && option.type === '@'
option => option.id === id
);
if (option) {
const response = await AiService.listAutocompleteArgOptions({
Expand All @@ -149,10 +150,10 @@ export function ChatInput(props: ChatInputProps): JSX.Element {
}
}, [autocompleteCommandOptions, autocompleteArgOptions]);

// whether any option is highlighted in the slash command autocomplete
// whether any option is highlighted in the autocomplete
const [highlighted, setHighlighted] = useState<boolean>(false);

// controls whether the slash command autocomplete is open
// controls whether the autocomplete is open
const [open, setOpen] = useState<boolean>(false);

// store reference to the input element to enable focusing it easily
Expand All @@ -178,7 +179,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element {
* chat input. Close the autocomplete when the user clears the chat input.
*/
useEffect(() => {
if (input === '/' || input.endsWith('@')) {
if (filterAutocompleteOptions(autocompleteOptions, input).length > 0) {
setOpen(true);
return;
}
Expand Down Expand Up @@ -284,14 +285,15 @@ export function ChatInput(props: ChatInputProps): JSX.Element {
options: AiService.AutocompleteOption[],
inputValue: string
): AiService.AutocompleteOption[] {
const lastWord = inputValue.split(/(?<!\\)\s+/).pop() || '';
if (
(lastWord.startsWith('/') && lastWord === inputValue) ||
lastWord.startsWith('@')
) {
return options.filter(option => option.label.startsWith(lastWord));
const lastWord = getLastWord(inputValue);
if (lastWord === '') {
return [];
}
return [];
const isStart = lastWord === inputValue;
return options.filter(
option =>
option.label.startsWith(lastWord) && (!option.only_start || isStart)
);
}

return (
Expand Down Expand Up @@ -387,3 +389,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element {
</Box>
);
}

function getLastWord(input: string): string {
return input.split(/(?<!\\)\s+/).pop() || '';
}
2 changes: 1 addition & 1 deletion packages/jupyter-ai/src/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,10 @@ export namespace AiService {
}

export type AutocompleteOption = {
type: '/' | '@';
id: string;
description: string;
label: string;
only_start: boolean;
};

export type ListAutocompleteOptionsResponse = {
Expand Down

0 comments on commit 139e37a

Please sign in to comment.