From 56d20fc0a3854ef6fa60aa30470ff985041964c9 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Wed, 18 Dec 2024 18:12:54 -0300 Subject: [PATCH] Reuse prompt editor across buffer and terminal assist (#22188) Builds on https://github.com/zed-industries/zed/pull/22160 and extracts the rest of `PromptEditor` so it can be shared across terminal and inline assistants. This will help avoid the UI drifting as we have already observed. Note: This is mostly a mechanical refactor. I imagine some things could be factored in a better way by someone with more context, but I think this is a good start. Release Notes: - N/A --------- Co-authored-by: Richard Feldman --- crates/assistant2/src/assistant.rs | 2 + crates/assistant2/src/buffer_codegen.rs | 1475 +++++++++++ crates/assistant2/src/inline_assistant.rs | 2304 +---------------- crates/assistant2/src/inline_prompt_editor.rs | 1216 ++++++++- crates/assistant2/src/terminal_codegen.rs | 192 ++ .../src/terminal_inline_assistant.rs | 607 +---- 6 files changed, 2839 insertions(+), 2957 deletions(-) create mode 100644 crates/assistant2/src/buffer_codegen.rs create mode 100644 crates/assistant2/src/terminal_codegen.rs diff --git a/crates/assistant2/src/assistant.rs b/crates/assistant2/src/assistant.rs index 2feee77010fe9c..24c0793c9a5155 100644 --- a/crates/assistant2/src/assistant.rs +++ b/crates/assistant2/src/assistant.rs @@ -1,6 +1,7 @@ mod active_thread; mod assistant_panel; mod assistant_settings; +mod buffer_codegen; mod context; mod context_picker; mod context_store; @@ -10,6 +11,7 @@ mod inline_prompt_editor; mod message_editor; mod prompts; mod streaming_diff; +mod terminal_codegen; mod terminal_inline_assistant; mod thread; mod thread_history; diff --git a/crates/assistant2/src/buffer_codegen.rs b/crates/assistant2/src/buffer_codegen.rs new file mode 100644 index 00000000000000..54a8aa0383652d --- /dev/null +++ b/crates/assistant2/src/buffer_codegen.rs @@ -0,0 +1,1475 @@ +use crate::context::attach_context_to_message; +use crate::context_store::ContextStore; +use crate::inline_prompt_editor::CodegenStatus; +use crate::{ + prompts::PromptBuilder, + streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff}, +}; +use anyhow::{Context as _, Result}; +use client::telemetry::Telemetry; +use collections::HashSet; +use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint}; +use futures::{channel::mpsc, future::LocalBoxFuture, join, SinkExt, Stream, StreamExt}; +use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscription, Task}; +use language::{Buffer, IndentKind, Point, TransactionId}; +use language_model::{ + LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, + LanguageModelTextStream, Role, +}; +use language_models::report_assistant_event; +use multi_buffer::MultiBufferRow; +use parking_lot::Mutex; +use rope::Rope; +use smol::future::FutureExt; +use std::{ + cmp, + future::Future, + iter, + ops::{Range, RangeInclusive}, + pin::Pin, + sync::Arc, + task::{self, Poll}, + time::Instant, +}; +use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase}; + +pub struct BufferCodegen { + alternatives: Vec>, + pub active_alternative: usize, + seen_alternatives: HashSet, + subscriptions: Vec, + buffer: Model, + range: Range, + initial_transaction_id: Option, + context_store: Model, + telemetry: Arc, + builder: Arc, + pub is_insertion: bool, +} + +impl BufferCodegen { + pub fn new( + buffer: Model, + range: Range, + initial_transaction_id: Option, + context_store: Model, + telemetry: Arc, + builder: Arc, + cx: &mut ModelContext, + ) -> Self { + let codegen = cx.new_model(|cx| { + CodegenAlternative::new( + buffer.clone(), + range.clone(), + false, + Some(context_store.clone()), + Some(telemetry.clone()), + builder.clone(), + cx, + ) + }); + let mut this = Self { + is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(), + alternatives: vec![codegen], + active_alternative: 0, + seen_alternatives: HashSet::default(), + subscriptions: Vec::new(), + buffer, + range, + initial_transaction_id, + context_store, + telemetry, + builder, + }; + this.activate(0, cx); + this + } + + fn subscribe_to_alternative(&mut self, cx: &mut ModelContext) { + let codegen = self.active_alternative().clone(); + self.subscriptions.clear(); + self.subscriptions + .push(cx.observe(&codegen, |_, _, cx| cx.notify())); + self.subscriptions + .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event))); + } + + pub fn active_alternative(&self) -> &Model { + &self.alternatives[self.active_alternative] + } + + pub fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus { + &self.active_alternative().read(cx).status + } + + pub fn alternative_count(&self, cx: &AppContext) -> usize { + LanguageModelRegistry::read_global(cx) + .inline_alternative_models() + .len() + + 1 + } + + pub fn cycle_prev(&mut self, cx: &mut ModelContext) { + let next_active_ix = if self.active_alternative == 0 { + self.alternatives.len() - 1 + } else { + self.active_alternative - 1 + }; + self.activate(next_active_ix, cx); + } + + pub fn cycle_next(&mut self, cx: &mut ModelContext) { + let next_active_ix = (self.active_alternative + 1) % self.alternatives.len(); + self.activate(next_active_ix, cx); + } + + fn activate(&mut self, index: usize, cx: &mut ModelContext) { + self.active_alternative() + .update(cx, |codegen, cx| codegen.set_active(false, cx)); + self.seen_alternatives.insert(index); + self.active_alternative = index; + self.active_alternative() + .update(cx, |codegen, cx| codegen.set_active(true, cx)); + self.subscribe_to_alternative(cx); + cx.notify(); + } + + pub fn start(&mut self, user_prompt: String, cx: &mut ModelContext) -> Result<()> { + let alternative_models = LanguageModelRegistry::read_global(cx) + .inline_alternative_models() + .to_vec(); + + self.active_alternative() + .update(cx, |alternative, cx| alternative.undo(cx)); + self.activate(0, cx); + self.alternatives.truncate(1); + + for _ in 0..alternative_models.len() { + self.alternatives.push(cx.new_model(|cx| { + CodegenAlternative::new( + self.buffer.clone(), + self.range.clone(), + false, + Some(self.context_store.clone()), + Some(self.telemetry.clone()), + self.builder.clone(), + cx, + ) + })); + } + + let primary_model = LanguageModelRegistry::read_global(cx) + .active_model() + .context("no active model")?; + + for (model, alternative) in iter::once(primary_model) + .chain(alternative_models) + .zip(&self.alternatives) + { + alternative.update(cx, |alternative, cx| { + alternative.start(user_prompt.clone(), model.clone(), cx) + })?; + } + + Ok(()) + } + + pub fn stop(&mut self, cx: &mut ModelContext) { + for codegen in &self.alternatives { + codegen.update(cx, |codegen, cx| codegen.stop(cx)); + } + } + + pub fn undo(&mut self, cx: &mut ModelContext) { + self.active_alternative() + .update(cx, |codegen, cx| codegen.undo(cx)); + + self.buffer.update(cx, |buffer, cx| { + if let Some(transaction_id) = self.initial_transaction_id.take() { + buffer.undo_transaction(transaction_id, cx); + buffer.refresh_preview(cx); + } + }); + } + + pub fn buffer(&self, cx: &AppContext) -> Model { + self.active_alternative().read(cx).buffer.clone() + } + + pub fn old_buffer(&self, cx: &AppContext) -> Model { + self.active_alternative().read(cx).old_buffer.clone() + } + + pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot { + self.active_alternative().read(cx).snapshot.clone() + } + + pub fn edit_position(&self, cx: &AppContext) -> Option { + self.active_alternative().read(cx).edit_position + } + + pub fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff { + &self.active_alternative().read(cx).diff + } + + pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range] { + self.active_alternative().read(cx).last_equal_ranges() + } +} + +impl EventEmitter for BufferCodegen {} + +pub struct CodegenAlternative { + buffer: Model, + old_buffer: Model, + snapshot: MultiBufferSnapshot, + edit_position: Option, + range: Range, + last_equal_ranges: Vec>, + transformation_transaction_id: Option, + status: CodegenStatus, + generation: Task<()>, + diff: Diff, + context_store: Option>, + telemetry: Option>, + _subscription: gpui::Subscription, + builder: Arc, + active: bool, + edits: Vec<(Range, String)>, + line_operations: Vec, + request: Option, + elapsed_time: Option, + completion: Option, + pub message_id: Option, +} + +impl EventEmitter for CodegenAlternative {} + +impl CodegenAlternative { + pub fn new( + buffer: Model, + range: Range, + active: bool, + context_store: Option>, + telemetry: Option>, + builder: Arc, + cx: &mut ModelContext, + ) -> Self { + let snapshot = buffer.read(cx).snapshot(cx); + + let (old_buffer, _, _) = buffer + .read(cx) + .range_to_buffer_ranges(range.clone(), cx) + .pop() + .unwrap(); + let old_buffer = cx.new_model(|cx| { + let old_buffer = old_buffer.read(cx); + let text = old_buffer.as_rope().clone(); + let line_ending = old_buffer.line_ending(); + let language = old_buffer.language().cloned(); + let language_registry = old_buffer.language_registry(); + + let mut buffer = Buffer::local_normalized(text, line_ending, cx); + buffer.set_language(language, cx); + if let Some(language_registry) = language_registry { + buffer.set_language_registry(language_registry) + } + buffer + }); + + Self { + buffer: buffer.clone(), + old_buffer, + edit_position: None, + message_id: None, + snapshot, + last_equal_ranges: Default::default(), + transformation_transaction_id: None, + status: CodegenStatus::Idle, + generation: Task::ready(()), + diff: Diff::default(), + context_store, + telemetry, + _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), + builder, + active, + edits: Vec::new(), + line_operations: Vec::new(), + range, + request: None, + elapsed_time: None, + completion: None, + } + } + + pub fn set_active(&mut self, active: bool, cx: &mut ModelContext) { + if active != self.active { + self.active = active; + + if self.active { + let edits = self.edits.clone(); + self.apply_edits(edits, cx); + if matches!(self.status, CodegenStatus::Pending) { + let line_operations = self.line_operations.clone(); + self.reapply_line_based_diff(line_operations, cx); + } else { + self.reapply_batch_diff(cx).detach(); + } + } else if let Some(transaction_id) = self.transformation_transaction_id.take() { + self.buffer.update(cx, |buffer, cx| { + buffer.undo_transaction(transaction_id, cx); + buffer.forget_transaction(transaction_id, cx); + }); + } + } + } + + fn handle_buffer_event( + &mut self, + _buffer: Model, + event: &multi_buffer::Event, + cx: &mut ModelContext, + ) { + if let multi_buffer::Event::TransactionUndone { transaction_id } = event { + if self.transformation_transaction_id == Some(*transaction_id) { + self.transformation_transaction_id = None; + self.generation = Task::ready(()); + cx.emit(CodegenEvent::Undone); + } + } + } + + pub fn last_equal_ranges(&self) -> &[Range] { + &self.last_equal_ranges + } + + pub fn start( + &mut self, + user_prompt: String, + model: Arc, + cx: &mut ModelContext, + ) -> Result<()> { + if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() { + self.buffer.update(cx, |buffer, cx| { + buffer.undo_transaction(transformation_transaction_id, cx); + }); + } + + self.edit_position = Some(self.range.start.bias_right(&self.snapshot)); + + let api_key = model.api_key(cx); + let telemetry_id = model.telemetry_id(); + let provider_id = model.provider_id(); + let stream: LocalBoxFuture> = + if user_prompt.trim().to_lowercase() == "delete" { + async { Ok(LanguageModelTextStream::default()) }.boxed_local() + } else { + let request = self.build_request(user_prompt, cx)?; + self.request = Some(request.clone()); + + cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await }) + .boxed_local() + }; + self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx); + Ok(()) + } + + fn build_request( + &self, + user_prompt: String, + cx: &mut AppContext, + ) -> Result { + let buffer = self.buffer.read(cx).snapshot(cx); + let language = buffer.language_at(self.range.start); + let language_name = if let Some(language) = language.as_ref() { + if Arc::ptr_eq(language, &language::PLAIN_TEXT) { + None + } else { + Some(language.name()) + } + } else { + None + }; + + let language_name = language_name.as_ref(); + let start = buffer.point_to_buffer_offset(self.range.start); + let end = buffer.point_to_buffer_offset(self.range.end); + let (buffer, range) = if let Some((start, end)) = start.zip(end) { + let (start_buffer, start_buffer_offset) = start; + let (end_buffer, end_buffer_offset) = end; + if start_buffer.remote_id() == end_buffer.remote_id() { + (start_buffer.clone(), start_buffer_offset..end_buffer_offset) + } else { + return Err(anyhow::anyhow!("invalid transformation range")); + } + } else { + return Err(anyhow::anyhow!("invalid transformation range")); + }; + + let prompt = self + .builder + .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range) + .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?; + + let mut request_message = LanguageModelRequestMessage { + role: Role::User, + content: Vec::new(), + cache: false, + }; + + if let Some(context_store) = &self.context_store { + let context = context_store.update(cx, |this, _cx| this.context().clone()); + attach_context_to_message(&mut request_message, context); + } + + request_message.content.push(prompt.into()); + + Ok(LanguageModelRequest { + tools: Vec::new(), + stop: Vec::new(), + temperature: None, + messages: vec![request_message], + }) + } + + pub fn handle_stream( + &mut self, + model_telemetry_id: String, + model_provider_id: String, + model_api_key: Option, + stream: impl 'static + Future>, + cx: &mut ModelContext, + ) { + let start_time = Instant::now(); + let snapshot = self.snapshot.clone(); + let selected_text = snapshot + .text_for_range(self.range.start..self.range.end) + .collect::(); + + let selection_start = self.range.start.to_point(&snapshot); + + // Start with the indentation of the first line in the selection + let mut suggested_line_indent = snapshot + .suggested_indents(selection_start.row..=selection_start.row, cx) + .into_values() + .next() + .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row))); + + // If the first line in the selection does not have indentation, check the following lines + if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space { + for row in selection_start.row..=self.range.end.to_point(&snapshot).row { + let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row)); + // Prefer tabs if a line in the selection uses tabs as indentation + if line_indent.kind == IndentKind::Tab { + suggested_line_indent.kind = IndentKind::Tab; + break; + } + } + } + + let http_client = cx.http_client().clone(); + let telemetry = self.telemetry.clone(); + let language_name = { + let multibuffer = self.buffer.read(cx); + let ranges = multibuffer.range_to_buffer_ranges(self.range.clone(), cx); + ranges + .first() + .and_then(|(buffer, _, _)| buffer.read(cx).language()) + .map(|language| language.name()) + }; + + self.diff = Diff::default(); + self.status = CodegenStatus::Pending; + let mut edit_start = self.range.start.to_offset(&snapshot); + let completion = Arc::new(Mutex::new(String::new())); + let completion_clone = completion.clone(); + + self.generation = cx.spawn(|codegen, mut cx| { + async move { + let stream = stream.await; + let message_id = stream + .as_ref() + .ok() + .and_then(|stream| stream.message_id.clone()); + let generate = async { + let (mut diff_tx, mut diff_rx) = mpsc::channel(1); + let executor = cx.background_executor().clone(); + let message_id = message_id.clone(); + let line_based_stream_diff: Task> = + cx.background_executor().spawn(async move { + let mut response_latency = None; + let request_start = Instant::now(); + let diff = async { + let chunks = StripInvalidSpans::new(stream?.stream); + futures::pin_mut!(chunks); + let mut diff = StreamingDiff::new(selected_text.to_string()); + let mut line_diff = LineDiff::default(); + + let mut new_text = String::new(); + let mut base_indent = None; + let mut line_indent = None; + let mut first_line = true; + + while let Some(chunk) = chunks.next().await { + if response_latency.is_none() { + response_latency = Some(request_start.elapsed()); + } + let chunk = chunk?; + completion_clone.lock().push_str(&chunk); + + let mut lines = chunk.split('\n').peekable(); + while let Some(line) = lines.next() { + new_text.push_str(line); + if line_indent.is_none() { + if let Some(non_whitespace_ch_ix) = + new_text.find(|ch: char| !ch.is_whitespace()) + { + line_indent = Some(non_whitespace_ch_ix); + base_indent = base_indent.or(line_indent); + + let line_indent = line_indent.unwrap(); + let base_indent = base_indent.unwrap(); + let indent_delta = + line_indent as i32 - base_indent as i32; + let mut corrected_indent_len = cmp::max( + 0, + suggested_line_indent.len as i32 + indent_delta, + ) + as usize; + if first_line { + corrected_indent_len = corrected_indent_len + .saturating_sub( + selection_start.column as usize, + ); + } + + let indent_char = suggested_line_indent.char(); + let mut indent_buffer = [0; 4]; + let indent_str = + indent_char.encode_utf8(&mut indent_buffer); + new_text.replace_range( + ..line_indent, + &indent_str.repeat(corrected_indent_len), + ); + } + } + + if line_indent.is_some() { + let char_ops = diff.push_new(&new_text); + line_diff + .push_char_operations(&char_ops, &selected_text); + diff_tx + .send((char_ops, line_diff.line_operations())) + .await?; + new_text.clear(); + } + + if lines.peek().is_some() { + let char_ops = diff.push_new("\n"); + line_diff + .push_char_operations(&char_ops, &selected_text); + diff_tx + .send((char_ops, line_diff.line_operations())) + .await?; + if line_indent.is_none() { + // Don't write out the leading indentation in empty lines on the next line + // This is the case where the above if statement didn't clear the buffer + new_text.clear(); + } + line_indent = None; + first_line = false; + } + } + } + + let mut char_ops = diff.push_new(&new_text); + char_ops.extend(diff.finish()); + line_diff.push_char_operations(&char_ops, &selected_text); + line_diff.finish(&selected_text); + diff_tx + .send((char_ops, line_diff.line_operations())) + .await?; + + anyhow::Ok(()) + }; + + let result = diff.await; + + let error_message = + result.as_ref().err().map(|error| error.to_string()); + report_assistant_event( + AssistantEvent { + conversation_id: None, + message_id, + kind: AssistantKind::Inline, + phase: AssistantPhase::Response, + model: model_telemetry_id, + model_provider: model_provider_id.to_string(), + response_latency, + error_message, + language_name: language_name.map(|name| name.to_proto()), + }, + telemetry, + http_client, + model_api_key, + &executor, + ); + + result?; + Ok(()) + }); + + while let Some((char_ops, line_ops)) = diff_rx.next().await { + codegen.update(&mut cx, |codegen, cx| { + codegen.last_equal_ranges.clear(); + + let edits = char_ops + .into_iter() + .filter_map(|operation| match operation { + CharOperation::Insert { text } => { + let edit_start = snapshot.anchor_after(edit_start); + Some((edit_start..edit_start, text)) + } + CharOperation::Delete { bytes } => { + let edit_end = edit_start + bytes; + let edit_range = snapshot.anchor_after(edit_start) + ..snapshot.anchor_before(edit_end); + edit_start = edit_end; + Some((edit_range, String::new())) + } + CharOperation::Keep { bytes } => { + let edit_end = edit_start + bytes; + let edit_range = snapshot.anchor_after(edit_start) + ..snapshot.anchor_before(edit_end); + edit_start = edit_end; + codegen.last_equal_ranges.push(edit_range); + None + } + }) + .collect::>(); + + if codegen.active { + codegen.apply_edits(edits.iter().cloned(), cx); + codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx); + } + codegen.edits.extend(edits); + codegen.line_operations = line_ops; + codegen.edit_position = Some(snapshot.anchor_after(edit_start)); + + cx.notify(); + })?; + } + + // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer. + // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff. + // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`. + let batch_diff_task = + codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?; + let (line_based_stream_diff, ()) = + join!(line_based_stream_diff, batch_diff_task); + line_based_stream_diff?; + + anyhow::Ok(()) + }; + + let result = generate.await; + let elapsed_time = start_time.elapsed().as_secs_f64(); + + codegen + .update(&mut cx, |this, cx| { + this.message_id = message_id; + this.last_equal_ranges.clear(); + if let Err(error) = result { + this.status = CodegenStatus::Error(error); + } else { + this.status = CodegenStatus::Done; + } + this.elapsed_time = Some(elapsed_time); + this.completion = Some(completion.lock().clone()); + cx.emit(CodegenEvent::Finished); + cx.notify(); + }) + .ok(); + } + }); + cx.notify(); + } + + pub fn stop(&mut self, cx: &mut ModelContext) { + self.last_equal_ranges.clear(); + if self.diff.is_empty() { + self.status = CodegenStatus::Idle; + } else { + self.status = CodegenStatus::Done; + } + self.generation = Task::ready(()); + cx.emit(CodegenEvent::Finished); + cx.notify(); + } + + pub fn undo(&mut self, cx: &mut ModelContext) { + self.buffer.update(cx, |buffer, cx| { + if let Some(transaction_id) = self.transformation_transaction_id.take() { + buffer.undo_transaction(transaction_id, cx); + buffer.refresh_preview(cx); + } + }); + } + + fn apply_edits( + &mut self, + edits: impl IntoIterator, String)>, + cx: &mut ModelContext, + ) { + let transaction = self.buffer.update(cx, |buffer, cx| { + // Avoid grouping assistant edits with user edits. + buffer.finalize_last_transaction(cx); + buffer.start_transaction(cx); + buffer.edit(edits, None, cx); + buffer.end_transaction(cx) + }); + + if let Some(transaction) = transaction { + if let Some(first_transaction) = self.transformation_transaction_id { + // Group all assistant edits into the first transaction. + self.buffer.update(cx, |buffer, cx| { + buffer.merge_transactions(transaction, first_transaction, cx) + }); + } else { + self.transformation_transaction_id = Some(transaction); + self.buffer + .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx)); + } + } + } + + fn reapply_line_based_diff( + &mut self, + line_operations: impl IntoIterator, + cx: &mut ModelContext, + ) { + let old_snapshot = self.snapshot.clone(); + let old_range = self.range.to_point(&old_snapshot); + let new_snapshot = self.buffer.read(cx).snapshot(cx); + let new_range = self.range.to_point(&new_snapshot); + + let mut old_row = old_range.start.row; + let mut new_row = new_range.start.row; + + self.diff.deleted_row_ranges.clear(); + self.diff.inserted_row_ranges.clear(); + for operation in line_operations { + match operation { + LineOperation::Keep { lines } => { + old_row += lines; + new_row += lines; + } + LineOperation::Delete { lines } => { + let old_end_row = old_row + lines - 1; + let new_row = new_snapshot.anchor_before(Point::new(new_row, 0)); + + if let Some((_, last_deleted_row_range)) = + self.diff.deleted_row_ranges.last_mut() + { + if *last_deleted_row_range.end() + 1 == old_row { + *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row; + } else { + self.diff + .deleted_row_ranges + .push((new_row, old_row..=old_end_row)); + } + } else { + self.diff + .deleted_row_ranges + .push((new_row, old_row..=old_end_row)); + } + + old_row += lines; + } + LineOperation::Insert { lines } => { + let new_end_row = new_row + lines - 1; + let start = new_snapshot.anchor_before(Point::new(new_row, 0)); + let end = new_snapshot.anchor_before(Point::new( + new_end_row, + new_snapshot.line_len(MultiBufferRow(new_end_row)), + )); + self.diff.inserted_row_ranges.push(start..end); + new_row += lines; + } + } + + cx.notify(); + } + } + + fn reapply_batch_diff(&mut self, cx: &mut ModelContext) -> Task<()> { + let old_snapshot = self.snapshot.clone(); + let old_range = self.range.to_point(&old_snapshot); + let new_snapshot = self.buffer.read(cx).snapshot(cx); + let new_range = self.range.to_point(&new_snapshot); + + cx.spawn(|codegen, mut cx| async move { + let (deleted_row_ranges, inserted_row_ranges) = cx + .background_executor() + .spawn(async move { + let old_text = old_snapshot + .text_for_range( + Point::new(old_range.start.row, 0) + ..Point::new( + old_range.end.row, + old_snapshot.line_len(MultiBufferRow(old_range.end.row)), + ), + ) + .collect::(); + let new_text = new_snapshot + .text_for_range( + Point::new(new_range.start.row, 0) + ..Point::new( + new_range.end.row, + new_snapshot.line_len(MultiBufferRow(new_range.end.row)), + ), + ) + .collect::(); + + let mut old_row = old_range.start.row; + let mut new_row = new_range.start.row; + let batch_diff = + similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str()); + + let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive)> = Vec::new(); + let mut inserted_row_ranges = Vec::new(); + for change in batch_diff.iter_all_changes() { + let line_count = change.value().lines().count() as u32; + match change.tag() { + similar::ChangeTag::Equal => { + old_row += line_count; + new_row += line_count; + } + similar::ChangeTag::Delete => { + let old_end_row = old_row + line_count - 1; + let new_row = new_snapshot.anchor_before(Point::new(new_row, 0)); + + if let Some((_, last_deleted_row_range)) = + deleted_row_ranges.last_mut() + { + if *last_deleted_row_range.end() + 1 == old_row { + *last_deleted_row_range = + *last_deleted_row_range.start()..=old_end_row; + } else { + deleted_row_ranges.push((new_row, old_row..=old_end_row)); + } + } else { + deleted_row_ranges.push((new_row, old_row..=old_end_row)); + } + + old_row += line_count; + } + similar::ChangeTag::Insert => { + let new_end_row = new_row + line_count - 1; + let start = new_snapshot.anchor_before(Point::new(new_row, 0)); + let end = new_snapshot.anchor_before(Point::new( + new_end_row, + new_snapshot.line_len(MultiBufferRow(new_end_row)), + )); + inserted_row_ranges.push(start..end); + new_row += line_count; + } + } + } + + (deleted_row_ranges, inserted_row_ranges) + }) + .await; + + codegen + .update(&mut cx, |codegen, cx| { + codegen.diff.deleted_row_ranges = deleted_row_ranges; + codegen.diff.inserted_row_ranges = inserted_row_ranges; + cx.notify(); + }) + .ok(); + }) + } +} + +#[derive(Copy, Clone, Debug)] +pub enum CodegenEvent { + Finished, + Undone, +} + +struct StripInvalidSpans { + stream: T, + stream_done: bool, + buffer: String, + first_line: bool, + line_end: bool, + starts_with_code_block: bool, +} + +impl StripInvalidSpans +where + T: Stream>, +{ + fn new(stream: T) -> Self { + Self { + stream, + stream_done: false, + buffer: String::new(), + first_line: true, + line_end: false, + starts_with_code_block: false, + } + } +} + +impl Stream for StripInvalidSpans +where + T: Stream>, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { + const CODE_BLOCK_DELIMITER: &str = "```"; + const CURSOR_SPAN: &str = "<|CURSOR|>"; + + let this = unsafe { self.get_unchecked_mut() }; + loop { + if !this.stream_done { + let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) }; + match stream.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { + this.buffer.push_str(&chunk); + } + Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))), + Poll::Ready(None) => { + this.stream_done = true; + } + Poll::Pending => return Poll::Pending, + } + } + + let mut chunk = String::new(); + let mut consumed = 0; + if !this.buffer.is_empty() { + let mut lines = this.buffer.split('\n').enumerate().peekable(); + while let Some((line_ix, line)) = lines.next() { + if line_ix > 0 { + this.first_line = false; + } + + if this.first_line { + let trimmed_line = line.trim(); + if lines.peek().is_some() { + if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) { + consumed += line.len() + 1; + this.starts_with_code_block = true; + continue; + } + } else if trimmed_line.is_empty() + || prefixes(CODE_BLOCK_DELIMITER) + .any(|prefix| trimmed_line.starts_with(prefix)) + { + break; + } + } + + let line_without_cursor = line.replace(CURSOR_SPAN, ""); + if lines.peek().is_some() { + if this.line_end { + chunk.push('\n'); + } + + chunk.push_str(&line_without_cursor); + this.line_end = true; + consumed += line.len() + 1; + } else if this.stream_done { + if !this.starts_with_code_block + || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER) + { + if this.line_end { + chunk.push('\n'); + } + + chunk.push_str(&line); + } + + consumed += line.len(); + } else { + let trimmed_line = line.trim(); + if trimmed_line.is_empty() + || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix)) + || prefixes(CODE_BLOCK_DELIMITER) + .any(|prefix| trimmed_line.ends_with(prefix)) + { + break; + } else { + if this.line_end { + chunk.push('\n'); + this.line_end = false; + } + + chunk.push_str(&line_without_cursor); + consumed += line.len(); + } + } + } + } + + this.buffer = this.buffer.split_off(consumed); + if !chunk.is_empty() { + return Poll::Ready(Some(Ok(chunk))); + } else if this.stream_done { + return Poll::Ready(None); + } + } + } +} + +fn prefixes(text: &str) -> impl Iterator { + (0..text.len() - 1).map(|ix| &text[..ix + 1]) +} + +#[derive(Default)] +pub struct Diff { + pub deleted_row_ranges: Vec<(Anchor, RangeInclusive)>, + pub inserted_row_ranges: Vec>, +} + +impl Diff { + fn is_empty(&self) -> bool { + self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::{ + stream::{self}, + Stream, + }; + use gpui::{Context, TestAppContext}; + use indoc::indoc; + use language::{ + language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher, + Point, + }; + use language_model::LanguageModelRegistry; + use rand::prelude::*; + use serde::Serialize; + use settings::SettingsStore; + use std::{future, sync::Arc}; + + #[derive(Serialize)] + pub struct DummyCompletionRequest { + pub name: String, + } + + #[gpui::test(iterations = 10)] + async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) { + cx.set_global(cx.update(SettingsStore::test)); + cx.update(language_model::LanguageModelRegistry::test); + cx.update(language_settings::init); + + let text = indoc! {" + fn main() { + let x = 0; + for _ in 0..10 { + x += 1; + } + } + "}; + let buffer = + cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); + let range = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) + }); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let codegen = cx.new_model(|cx| { + CodegenAlternative::new( + buffer.clone(), + range.clone(), + true, + None, + None, + prompt_builder, + cx, + ) + }); + + let chunks_tx = simulate_response_stream(codegen.clone(), cx); + + let mut new_text = concat!( + " let mut x = 0;\n", + " while x < 10 {\n", + " x += 1;\n", + " }", + ); + while !new_text.is_empty() { + let max_len = cmp::min(new_text.len(), 10); + let len = rng.gen_range(1..=max_len); + let (chunk, suffix) = new_text.split_at(len); + chunks_tx.unbounded_send(chunk.to_string()).unwrap(); + new_text = suffix; + cx.background_executor.run_until_parked(); + } + drop(chunks_tx); + cx.background_executor.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + while x < 10 { + x += 1; + } + } + "} + ); + } + + #[gpui::test(iterations = 10)] + async fn test_autoindent_when_generating_past_indentation( + cx: &mut TestAppContext, + mut rng: StdRng, + ) { + cx.set_global(cx.update(SettingsStore::test)); + cx.update(language_settings::init); + + let text = indoc! {" + fn main() { + le + } + "}; + let buffer = + cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); + let range = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6)) + }); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let codegen = cx.new_model(|cx| { + CodegenAlternative::new( + buffer.clone(), + range.clone(), + true, + None, + None, + prompt_builder, + cx, + ) + }); + + let chunks_tx = simulate_response_stream(codegen.clone(), cx); + + cx.background_executor.run_until_parked(); + + let mut new_text = concat!( + "t mut x = 0;\n", + "while x < 10 {\n", + " x += 1;\n", + "}", // + ); + while !new_text.is_empty() { + let max_len = cmp::min(new_text.len(), 10); + let len = rng.gen_range(1..=max_len); + let (chunk, suffix) = new_text.split_at(len); + chunks_tx.unbounded_send(chunk.to_string()).unwrap(); + new_text = suffix; + cx.background_executor.run_until_parked(); + } + drop(chunks_tx); + cx.background_executor.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + while x < 10 { + x += 1; + } + } + "} + ); + } + + #[gpui::test(iterations = 10)] + async fn test_autoindent_when_generating_before_indentation( + cx: &mut TestAppContext, + mut rng: StdRng, + ) { + cx.update(LanguageModelRegistry::test); + cx.set_global(cx.update(SettingsStore::test)); + cx.update(language_settings::init); + + let text = concat!( + "fn main() {\n", + " \n", + "}\n" // + ); + let buffer = + cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); + let range = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2)) + }); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let codegen = cx.new_model(|cx| { + CodegenAlternative::new( + buffer.clone(), + range.clone(), + true, + None, + None, + prompt_builder, + cx, + ) + }); + + let chunks_tx = simulate_response_stream(codegen.clone(), cx); + + cx.background_executor.run_until_parked(); + + let mut new_text = concat!( + "let mut x = 0;\n", + "while x < 10 {\n", + " x += 1;\n", + "}", // + ); + while !new_text.is_empty() { + let max_len = cmp::min(new_text.len(), 10); + let len = rng.gen_range(1..=max_len); + let (chunk, suffix) = new_text.split_at(len); + chunks_tx.unbounded_send(chunk.to_string()).unwrap(); + new_text = suffix; + cx.background_executor.run_until_parked(); + } + drop(chunks_tx); + cx.background_executor.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + while x < 10 { + x += 1; + } + } + "} + ); + } + + #[gpui::test(iterations = 10)] + async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) { + cx.update(LanguageModelRegistry::test); + cx.set_global(cx.update(SettingsStore::test)); + cx.update(language_settings::init); + + let text = indoc! {" + func main() { + \tx := 0 + \tfor i := 0; i < 10; i++ { + \t\tx++ + \t} + } + "}; + let buffer = cx.new_model(|cx| Buffer::local(text, cx)); + let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); + let range = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2)) + }); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let codegen = cx.new_model(|cx| { + CodegenAlternative::new( + buffer.clone(), + range.clone(), + true, + None, + None, + prompt_builder, + cx, + ) + }); + + let chunks_tx = simulate_response_stream(codegen.clone(), cx); + let new_text = concat!( + "func main() {\n", + "\tx := 0\n", + "\tfor x < 10 {\n", + "\t\tx++\n", + "\t}", // + ); + chunks_tx.unbounded_send(new_text.to_string()).unwrap(); + drop(chunks_tx); + cx.background_executor.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + func main() { + \tx := 0 + \tfor x < 10 { + \t\tx++ + \t} + } + "} + ); + } + + #[gpui::test] + async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) { + cx.update(LanguageModelRegistry::test); + cx.set_global(cx.update(SettingsStore::test)); + cx.update(language_settings::init); + + let text = indoc! {" + fn main() { + let x = 0; + } + "}; + let buffer = + cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); + let range = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14)) + }); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let codegen = cx.new_model(|cx| { + CodegenAlternative::new( + buffer.clone(), + range.clone(), + false, + None, + None, + prompt_builder, + cx, + ) + }); + + let chunks_tx = simulate_response_stream(codegen.clone(), cx); + chunks_tx + .unbounded_send("let mut x = 0;\nx += 1;".to_string()) + .unwrap(); + drop(chunks_tx); + cx.run_until_parked(); + + // The codegen is inactive, so the buffer doesn't get modified. + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + text + ); + + // Activating the codegen applies the changes. + codegen.update(cx, |codegen, cx| codegen.set_active(true, cx)); + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + x += 1; + } + "} + ); + + // Deactivating the codegen undoes the changes. + codegen.update(cx, |codegen, cx| codegen.set_active(false, cx)); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + text + ); + } + + #[gpui::test] + async fn test_strip_invalid_spans_from_codeblock() { + assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await; + assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await; + assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await; + assert_chunks( + "```html\n```js\nLorem ipsum dolor\n```\n```", + "```js\nLorem ipsum dolor\n```", + ) + .await; + assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await; + assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await; + assert_chunks("Lorem ipsum", "Lorem ipsum").await; + assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await; + + async fn assert_chunks(text: &str, expected_text: &str) { + for chunk_size in 1..=text.len() { + let actual_text = StripInvalidSpans::new(chunks(text, chunk_size)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await; + assert_eq!( + actual_text, expected_text, + "failed to strip invalid spans, chunk size: {}", + chunk_size + ); + } + } + + fn chunks(text: &str, size: usize) -> impl Stream> { + stream::iter( + text.chars() + .collect::>() + .chunks(size) + .map(|chunk| Ok(chunk.iter().collect::())) + .collect::>(), + ) + } + } + + fn simulate_response_stream( + codegen: Model, + cx: &mut TestAppContext, + ) -> mpsc::UnboundedSender { + let (chunks_tx, chunks_rx) = mpsc::unbounded(); + codegen.update(cx, |codegen, cx| { + codegen.handle_stream( + String::new(), + String::new(), + None, + future::ready(Ok(LanguageModelTextStream { + message_id: None, + stream: chunks_rx.map(Ok).boxed(), + })), + cx, + ); + }); + chunks_tx + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_indents_query( + r#" + (call_expression) @indent + (field_expression) @indent + (_ "(" ")" @end) @indent + (_ "{" "}" @end) @indent + "#, + ) + .unwrap() + } +} diff --git a/crates/assistant2/src/inline_assistant.rs b/crates/assistant2/src/inline_assistant.rs index 593ca61e43e89f..9535531895019e 100644 --- a/crates/assistant2/src/inline_assistant.rs +++ b/crates/assistant2/src/inline_assistant.rs @@ -1,73 +1,44 @@ -use crate::context::attach_context_to_message; -use crate::context_picker::ContextPicker; +use crate::buffer_codegen::{BufferCodegen, CodegenAlternative, CodegenEvent}; use crate::context_store::ContextStore; -use crate::context_strip::ContextStrip; -use crate::inline_prompt_editor::{ - render_cancel_button, CodegenStatus, PromptEditorEvent, PromptMode, -}; +use crate::inline_prompt_editor::{CodegenStatus, InlineAssistId, PromptEditor, PromptEditorEvent}; use crate::thread_store::ThreadStore; +use crate::AssistantPanel; use crate::{ - assistant_settings::AssistantSettings, - prompts::PromptBuilder, - streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff}, + assistant_settings::AssistantSettings, prompts::PromptBuilder, terminal_inline_assistant::TerminalInlineAssistant, - CycleNextInlineAssist, CyclePreviousInlineAssist, }; -use crate::{AssistantPanel, ToggleContextPicker}; use anyhow::{Context as _, Result}; -use client::{telemetry::Telemetry, ErrorExt}; +use client::telemetry::Telemetry; use collections::{hash_map, HashMap, HashSet, VecDeque}; use editor::{ - actions::{MoveDown, MoveUp, SelectAll}, + actions::SelectAll, display_map::{ BlockContext, BlockPlacement, BlockProperties, BlockStyle, CustomBlockId, RenderBlock, ToDisplayPoint, }, - Anchor, AnchorRangeExt, CodeActionProvider, Editor, EditorElement, EditorEvent, EditorMode, - EditorStyle, ExcerptId, ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, - ToOffset as _, ToPoint, + Anchor, AnchorRangeExt, CodeActionProvider, Editor, EditorEvent, ExcerptId, ExcerptRange, + GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint, }; -use feature_flags::{FeatureFlagAppExt as _, ZedPro}; use fs::Fs; -use futures::{channel::mpsc, future::LocalBoxFuture, join, SinkExt, Stream, StreamExt}; +use util::ResultExt; + use gpui::{ - anchored, deferred, point, AnyElement, AppContext, ClickEvent, CursorStyle, EventEmitter, - FocusHandle, FocusableView, FontWeight, Global, HighlightStyle, Model, ModelContext, - Subscription, Task, TextStyle, UpdateGlobal, View, ViewContext, WeakModel, WeakView, - WindowContext, + point, AppContext, FocusableView, Global, HighlightStyle, Model, Subscription, Task, + UpdateGlobal, View, ViewContext, WeakModel, WeakView, WindowContext, }; -use language::{Buffer, IndentKind, Point, Selection, TransactionId}; -use language_model::{ - LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelTextStream, Role, -}; -use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu}; +use language::{Buffer, Point, Selection, TransactionId}; +use language_model::LanguageModelRegistry; use language_models::report_assistant_event; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; use project::{CodeAction, ProjectTransaction}; -use rope::Rope; -use settings::{update_settings_file, Settings, SettingsStore}; -use smol::future::FutureExt; -use std::{ - cmp, - future::Future, - iter, mem, - ops::{Range, RangeInclusive}, - pin::Pin, - rc::Rc, - sync::Arc, - task::{self, Poll}, - time::Instant, -}; +use settings::{Settings, SettingsStore}; +use std::{cmp, mem, ops::Range, rc::Rc, sync::Arc}; use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase}; use terminal_view::{terminal_panel::TerminalPanel, TerminalView}; use text::{OffsetRangeExt, ToPoint as _}; -use theme::ThemeSettings; -use ui::{ - prelude::*, CheckboxWithLabel, IconButtonShape, KeyBinding, Popover, PopoverMenuHandle, Tooltip, -}; -use util::{RangeExt, ResultExt}; +use ui::prelude::*; +use util::RangeExt; use workspace::{dock::Panel, ShowConfiguration}; use workspace::{notifications::NotificationId, ItemHandle, Toast, Workspace}; @@ -366,7 +337,7 @@ impl InlineAssistant { let assist_id = self.next_assist_id.post_inc(); let context_store = cx.new_model(|_cx| ContextStore::new()); let codegen = cx.new_model(|cx| { - Codegen::new( + BufferCodegen::new( editor.read(cx).buffer().clone(), range.clone(), None, @@ -379,7 +350,7 @@ impl InlineAssistant { let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default())); let prompt_editor = cx.new_view(|cx| { - PromptEditor::new( + PromptEditor::new_buffer( assist_id, gutter_dimensions.clone(), self.prompt_history.clone(), @@ -422,6 +393,8 @@ impl InlineAssistant { .or_insert_with(|| EditorInlineAssists::new(&editor, cx)); let mut assist_group = InlineAssistGroup::new(); for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists { + let codegen = prompt_editor.read(cx).codegen().clone(); + self.assists.insert( assist_id, InlineAssist::new( @@ -432,7 +405,7 @@ impl InlineAssistant { prompt_block_id, end_block_id, range, - prompt_editor.read(cx).codegen.clone(), + codegen, workspace.clone(), cx, ), @@ -475,7 +448,7 @@ impl InlineAssistant { let context_store = cx.new_model(|_cx| ContextStore::new()); let codegen = cx.new_model(|cx| { - Codegen::new( + BufferCodegen::new( editor.read(cx).buffer().clone(), range.clone(), initial_transaction_id, @@ -488,7 +461,7 @@ impl InlineAssistant { let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default())); let prompt_editor = cx.new_view(|cx| { - PromptEditor::new( + PromptEditor::new_buffer( assist_id, gutter_dimensions.clone(), self.prompt_history.clone(), @@ -521,7 +494,7 @@ impl InlineAssistant { prompt_block_id, end_block_id, range, - prompt_editor.read(cx).codegen.clone(), + codegen.clone(), workspace.clone(), cx, ), @@ -541,7 +514,7 @@ impl InlineAssistant { &self, editor: &View, range: &Range, - prompt_editor: &View, + prompt_editor: &View>, cx: &mut WindowContext, ) -> [CustomBlockId; 2] { let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| { @@ -643,11 +616,11 @@ impl InlineAssistant { fn handle_prompt_editor_event( &mut self, - prompt_editor: View, + prompt_editor: View>, event: &PromptEditorEvent, cx: &mut WindowContext, ) { - let assist_id = prompt_editor.read(cx).id; + let assist_id = prompt_editor.read(cx).id(); match event { PromptEditorEvent::StartRequested => { self.start_assist(assist_id, cx); @@ -665,7 +638,7 @@ impl InlineAssistant { self.dismiss_assist(assist_id, cx); } PromptEditorEvent::Resized { .. } => { - // This only matters for the terminal inline + // This only matters for the terminal inline assistant } } } @@ -1451,25 +1424,17 @@ impl InlineAssistGroup { } } -fn build_assist_editor_renderer(editor: &View) -> RenderBlock { +fn build_assist_editor_renderer(editor: &View>) -> RenderBlock { let editor = editor.clone(); + Arc::new(move |cx: &mut BlockContext| { - *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions; + let gutter_dimensions = editor.read(cx).gutter_dimensions(); + + *gutter_dimensions.lock() = *cx.gutter_dimensions; editor.clone().into_any_element() }) } -#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] -pub struct InlineAssistId(usize); - -impl InlineAssistId { - fn post_inc(&mut self) -> InlineAssistId { - let id = *self; - self.0 += 1; - id - } -} - #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] struct InlineAssistGroupId(usize); @@ -1481,689 +1446,12 @@ impl InlineAssistGroupId { } } -struct PromptEditor { - id: InlineAssistId, - editor: View, - context_strip: View, - context_picker_menu_handle: PopoverMenuHandle, - language_model_selector: View, - edited_since_done: bool, - gutter_dimensions: Arc>, - prompt_history: VecDeque, - prompt_history_ix: Option, - pending_prompt: String, - codegen: Model, - _codegen_subscription: Subscription, - editor_subscriptions: Vec, - show_rate_limit_notice: bool, -} - -impl EventEmitter for PromptEditor {} - -impl Render for PromptEditor { - fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - let gutter_dimensions = *self.gutter_dimensions.lock(); - let mut buttons = Vec::new(); - let codegen = self.codegen.read(cx); - if codegen.alternative_count(cx) > 1 { - buttons.push(self.render_cycle_controls(cx)); - } - let prompt_mode = if codegen.is_insertion { - PromptMode::Generate { - supports_execute: false, - } - } else { - PromptMode::Transform - }; - - buttons.extend(render_cancel_button( - codegen.status(cx).into(), - self.edited_since_done, - prompt_mode, - cx, - )); - - v_flex() - .border_y_1() - .border_color(cx.theme().status().info_border) - .size_full() - .py(cx.line_height() / 2.5) - .child( - h_flex() - .key_context("PromptEditor") - .bg(cx.theme().colors().editor_background) - .block_mouse_down() - .cursor(CursorStyle::Arrow) - .on_action(cx.listener(Self::toggle_context_picker)) - .on_action(cx.listener(Self::confirm)) - .on_action(cx.listener(Self::cancel)) - .on_action(cx.listener(Self::move_up)) - .on_action(cx.listener(Self::move_down)) - .capture_action(cx.listener(Self::cycle_prev)) - .capture_action(cx.listener(Self::cycle_next)) - .child( - h_flex() - .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0)) - .justify_center() - .gap_2() - .child(LanguageModelSelectorPopoverMenu::new( - self.language_model_selector.clone(), - IconButton::new("context", IconName::SettingsAlt) - .shape(IconButtonShape::Square) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .tooltip(move |cx| { - Tooltip::with_meta( - format!( - "Using {}", - LanguageModelRegistry::read_global(cx) - .active_model() - .map(|model| model.name().0) - .unwrap_or_else(|| "No model selected".into()), - ), - None, - "Change Model", - cx, - ) - }), - )) - .map(|el| { - let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx) - else { - return el; - }; - - let error_message = SharedString::from(error.to_string()); - if error.error_code() == proto::ErrorCode::RateLimitExceeded - && cx.has_flag::() - { - el.child( - v_flex() - .child( - IconButton::new( - "rate-limit-error", - IconName::XCircle, - ) - .toggle_state(self.show_rate_limit_notice) - .shape(IconButtonShape::Square) - .icon_size(IconSize::Small) - .on_click( - cx.listener(Self::toggle_rate_limit_notice), - ), - ) - .children(self.show_rate_limit_notice.then(|| { - deferred( - anchored() - .position_mode( - gpui::AnchoredPositionMode::Local, - ) - .position(point(px(0.), px(24.))) - .anchor(gpui::Corner::TopLeft) - .child(self.render_rate_limit_notice(cx)), - ) - })), - ) - } else { - el.child( - div() - .id("error") - .tooltip(move |cx| { - Tooltip::text(error_message.clone(), cx) - }) - .child( - Icon::new(IconName::XCircle) - .size(IconSize::Small) - .color(Color::Error), - ), - ) - } - }), - ) - .child(div().flex_1().child(self.render_editor(cx))) - .child(h_flex().gap_2().pr_6().children(buttons)), - ) - .child( - h_flex() - .child( - h_flex() - .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0)) - .justify_center() - .gap_2(), - ) - .child(self.context_strip.clone()), - ) - } -} - -impl FocusableView for PromptEditor { - fn focus_handle(&self, cx: &AppContext) -> FocusHandle { - self.editor.focus_handle(cx) - } -} - -impl PromptEditor { - const MAX_LINES: u8 = 8; - - #[allow(clippy::too_many_arguments)] - fn new( - id: InlineAssistId, - gutter_dimensions: Arc>, - prompt_history: VecDeque, - prompt_buffer: Model, - codegen: Model, - fs: Arc, - context_store: Model, - workspace: WeakView, - thread_store: Option>, - cx: &mut ViewContext, - ) -> Self { - let prompt_editor = cx.new_view(|cx| { - let mut editor = Editor::new( - EditorMode::AutoHeight { - max_lines: Self::MAX_LINES as usize, - }, - prompt_buffer, - None, - false, - cx, - ); - editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx); - // Since the prompt editors for all inline assistants are linked, - // always show the cursor (even when it isn't focused) because - // typing in one will make what you typed appear in all of them. - editor.set_show_cursor_when_unfocused(true, cx); - editor.set_placeholder_text(Self::placeholder_text(codegen.read(cx), cx), cx); - editor - }); - let context_picker_menu_handle = PopoverMenuHandle::default(); - - let mut this = Self { - id, - editor: prompt_editor.clone(), - context_strip: cx.new_view(|cx| { - ContextStrip::new( - context_store, - workspace.clone(), - thread_store.clone(), - prompt_editor.focus_handle(cx), - context_picker_menu_handle.clone(), - cx, - ) - }), - context_picker_menu_handle, - language_model_selector: cx.new_view(|cx| { - let fs = fs.clone(); - LanguageModelSelector::new( - move |model, cx| { - update_settings_file::( - fs.clone(), - cx, - move |settings, _| settings.set_model(model.clone()), - ); - }, - cx, - ) - }), - edited_since_done: false, - gutter_dimensions, - prompt_history, - prompt_history_ix: None, - pending_prompt: String::new(), - _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed), - editor_subscriptions: Vec::new(), - codegen, - show_rate_limit_notice: false, - }; - this.subscribe_to_editor(cx); - this - } - - fn subscribe_to_editor(&mut self, cx: &mut ViewContext) { - self.editor_subscriptions.clear(); - self.editor_subscriptions - .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events)); - } - - fn set_show_cursor_when_unfocused( - &mut self, - show_cursor_when_unfocused: bool, - cx: &mut ViewContext, - ) { - self.editor.update(cx, |editor, cx| { - editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx) - }); - } - - fn unlink(&mut self, cx: &mut ViewContext) { - let prompt = self.prompt(cx); - let focus = self.editor.focus_handle(cx).contains_focused(cx); - self.editor = cx.new_view(|cx| { - let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx); - editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx); - editor.set_placeholder_text(Self::placeholder_text(self.codegen.read(cx), cx), cx); - editor.set_placeholder_text("Add a prompt…", cx); - editor.set_text(prompt, cx); - if focus { - editor.focus(cx); - } - editor - }); - self.subscribe_to_editor(cx); - } - - fn placeholder_text(codegen: &Codegen, cx: &WindowContext) -> String { - let action = if codegen.is_insertion { - "Generate" - } else { - "Transform" - }; - let assistant_panel_keybinding = ui::text_for_action(&crate::ToggleFocus, cx) - .map(|keybinding| format!("{keybinding} to chat ― ")) - .unwrap_or_default(); - - format!("{action}… ({assistant_panel_keybinding}↓↑ for history)") - } - - fn prompt(&self, cx: &AppContext) -> String { - self.editor.read(cx).text(cx) - } - - fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext) { - self.show_rate_limit_notice = !self.show_rate_limit_notice; - if self.show_rate_limit_notice { - cx.focus_view(&self.editor); - } - cx.notify(); - } - - fn handle_prompt_editor_events( - &mut self, - _: View, - event: &EditorEvent, - cx: &mut ViewContext, - ) { - match event { - EditorEvent::Edited { .. } => { - if let Some(workspace) = cx.window_handle().downcast::() { - workspace - .update(cx, |workspace, cx| { - let is_via_ssh = workspace - .project() - .update(cx, |project, _| project.is_via_ssh()); - - workspace - .client() - .telemetry() - .log_edit_event("inline assist", is_via_ssh); - }) - .log_err(); - } - let prompt = self.editor.read(cx).text(cx); - if self - .prompt_history_ix - .map_or(true, |ix| self.prompt_history[ix] != prompt) - { - self.prompt_history_ix.take(); - self.pending_prompt = prompt; - } - - self.edited_since_done = true; - cx.notify(); - } - EditorEvent::Blurred => { - if self.show_rate_limit_notice { - self.show_rate_limit_notice = false; - cx.notify(); - } - } - _ => {} - } - } - - fn handle_codegen_changed(&mut self, _: Model, cx: &mut ViewContext) { - match self.codegen.read(cx).status(cx) { - CodegenStatus::Idle => { - self.editor - .update(cx, |editor, _| editor.set_read_only(false)); - } - CodegenStatus::Pending => { - self.editor - .update(cx, |editor, _| editor.set_read_only(true)); - } - CodegenStatus::Done => { - self.edited_since_done = false; - self.editor - .update(cx, |editor, _| editor.set_read_only(false)); - } - CodegenStatus::Error(error) => { - if cx.has_flag::() - && error.error_code() == proto::ErrorCode::RateLimitExceeded - && !dismissed_rate_limit_notice() - { - self.show_rate_limit_notice = true; - cx.notify(); - } - - self.edited_since_done = false; - self.editor - .update(cx, |editor, _| editor.set_read_only(false)); - } - } - } - - fn toggle_context_picker(&mut self, _: &ToggleContextPicker, cx: &mut ViewContext) { - self.context_picker_menu_handle.toggle(cx); - } - - fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext) { - match self.codegen.read(cx).status(cx) { - CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => { - cx.emit(PromptEditorEvent::CancelRequested); - } - CodegenStatus::Pending => { - cx.emit(PromptEditorEvent::StopRequested); - } - } - } - - fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { - match self.codegen.read(cx).status(cx) { - CodegenStatus::Idle => { - cx.emit(PromptEditorEvent::StartRequested); - } - CodegenStatus::Pending => { - cx.emit(PromptEditorEvent::DismissRequested); - } - CodegenStatus::Done => { - if self.edited_since_done { - cx.emit(PromptEditorEvent::StartRequested); - } else { - cx.emit(PromptEditorEvent::ConfirmRequested { execute: false }); - } - } - CodegenStatus::Error(_) => { - cx.emit(PromptEditorEvent::StartRequested); - } - } - } - - fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext) { - if let Some(ix) = self.prompt_history_ix { - if ix > 0 { - self.prompt_history_ix = Some(ix - 1); - let prompt = self.prompt_history[ix - 1].as_str(); - self.editor.update(cx, |editor, cx| { - editor.set_text(prompt, cx); - editor.move_to_beginning(&Default::default(), cx); - }); - } - } else if !self.prompt_history.is_empty() { - self.prompt_history_ix = Some(self.prompt_history.len() - 1); - let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str(); - self.editor.update(cx, |editor, cx| { - editor.set_text(prompt, cx); - editor.move_to_beginning(&Default::default(), cx); - }); - } - } - - fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext) { - if let Some(ix) = self.prompt_history_ix { - if ix < self.prompt_history.len() - 1 { - self.prompt_history_ix = Some(ix + 1); - let prompt = self.prompt_history[ix + 1].as_str(); - self.editor.update(cx, |editor, cx| { - editor.set_text(prompt, cx); - editor.move_to_end(&Default::default(), cx) - }); - } else { - self.prompt_history_ix = None; - let prompt = self.pending_prompt.as_str(); - self.editor.update(cx, |editor, cx| { - editor.set_text(prompt, cx); - editor.move_to_end(&Default::default(), cx) - }); - } - } - } - - fn cycle_prev(&mut self, _: &CyclePreviousInlineAssist, cx: &mut ViewContext) { - self.codegen - .update(cx, |codegen, cx| codegen.cycle_prev(cx)); - } - - fn cycle_next(&mut self, _: &CycleNextInlineAssist, cx: &mut ViewContext) { - self.codegen - .update(cx, |codegen, cx| codegen.cycle_next(cx)); - } - - fn render_cycle_controls(&self, cx: &ViewContext) -> AnyElement { - let codegen = self.codegen.read(cx); - let disabled = matches!(codegen.status(cx), CodegenStatus::Idle); - - let model_registry = LanguageModelRegistry::read_global(cx); - let default_model = model_registry.active_model(); - let alternative_models = model_registry.inline_alternative_models(); - - let get_model_name = |index: usize| -> String { - let name = |model: &Arc| model.name().0.to_string(); - - match index { - 0 => default_model.as_ref().map_or_else(String::new, name), - index if index <= alternative_models.len() => alternative_models - .get(index - 1) - .map_or_else(String::new, name), - _ => String::new(), - } - }; - - let total_models = alternative_models.len() + 1; - - if total_models <= 1 { - return div().into_any_element(); - } - - let current_index = codegen.active_alternative; - let prev_index = (current_index + total_models - 1) % total_models; - let next_index = (current_index + 1) % total_models; - - let prev_model_name = get_model_name(prev_index); - let next_model_name = get_model_name(next_index); - - h_flex() - .child( - IconButton::new("previous", IconName::ChevronLeft) - .icon_color(Color::Muted) - .disabled(disabled || current_index == 0) - .shape(IconButtonShape::Square) - .tooltip({ - let focus_handle = self.editor.focus_handle(cx); - move |cx| { - cx.new_view(|cx| { - let mut tooltip = Tooltip::new("Previous Alternative").key_binding( - KeyBinding::for_action_in( - &CyclePreviousInlineAssist, - &focus_handle, - cx, - ), - ); - if !disabled && current_index != 0 { - tooltip = tooltip.meta(prev_model_name.clone()); - } - tooltip - }) - .into() - } - }) - .on_click(cx.listener(|this, _, cx| { - this.codegen - .update(cx, |codegen, cx| codegen.cycle_prev(cx)) - })), - ) - .child( - Label::new(format!( - "{}/{}", - codegen.active_alternative + 1, - codegen.alternative_count(cx) - )) - .size(LabelSize::Small) - .color(if disabled { - Color::Disabled - } else { - Color::Muted - }), - ) - .child( - IconButton::new("next", IconName::ChevronRight) - .icon_color(Color::Muted) - .disabled(disabled || current_index == total_models - 1) - .shape(IconButtonShape::Square) - .tooltip({ - let focus_handle = self.editor.focus_handle(cx); - move |cx| { - cx.new_view(|cx| { - let mut tooltip = Tooltip::new("Next Alternative").key_binding( - KeyBinding::for_action_in( - &CycleNextInlineAssist, - &focus_handle, - cx, - ), - ); - if !disabled && current_index != total_models - 1 { - tooltip = tooltip.meta(next_model_name.clone()); - } - tooltip - }) - .into() - } - }) - .on_click(cx.listener(|this, _, cx| { - this.codegen - .update(cx, |codegen, cx| codegen.cycle_next(cx)) - })), - ) - .into_any_element() - } - - fn render_rate_limit_notice(&self, cx: &mut ViewContext) -> impl IntoElement { - Popover::new().child( - v_flex() - .occlude() - .p_2() - .child( - Label::new("Out of Tokens") - .size(LabelSize::Small) - .weight(FontWeight::BOLD), - ) - .child(Label::new( - "Try Zed Pro for higher limits, a wider range of models, and more.", - )) - .child( - h_flex() - .justify_between() - .child(CheckboxWithLabel::new( - "dont-show-again", - Label::new("Don't show again"), - if dismissed_rate_limit_notice() { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - |selection, cx| { - let is_dismissed = match selection { - ui::ToggleState::Unselected => false, - ui::ToggleState::Indeterminate => return, - ui::ToggleState::Selected => true, - }; - - set_rate_limit_notice_dismissed(is_dismissed, cx) - }, - )) - .child( - h_flex() - .gap_2() - .child( - Button::new("dismiss", "Dismiss") - .style(ButtonStyle::Transparent) - .on_click(cx.listener(Self::toggle_rate_limit_notice)), - ) - .child(Button::new("more-info", "More Info").on_click( - |_event, cx| { - cx.dispatch_action(Box::new( - zed_actions::OpenAccountSettings, - )) - }, - )), - ), - ), - ) - } - - fn render_editor(&mut self, cx: &mut ViewContext) -> AnyElement { - let font_size = TextSize::Default.rems(cx); - let line_height = font_size.to_pixels(cx.rem_size()) * 1.3; - - v_flex() - .key_context("MessageEditor") - .size_full() - .gap_2() - .p_2() - .bg(cx.theme().colors().editor_background) - .child({ - let settings = ThemeSettings::get_global(cx); - let text_style = TextStyle { - color: cx.theme().colors().editor_foreground, - font_family: settings.ui_font.family.clone(), - font_features: settings.ui_font.features.clone(), - font_size: font_size.into(), - font_weight: settings.ui_font.weight, - line_height: line_height.into(), - ..Default::default() - }; - - EditorElement::new( - &self.editor, - EditorStyle { - background: cx.theme().colors().editor_background, - local_player: cx.theme().players().local(), - text: text_style, - ..Default::default() - }, - ) - }) - .into_any_element() - } -} - -const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice"; - -fn dismissed_rate_limit_notice() -> bool { - db::kvp::KEY_VALUE_STORE - .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY) - .log_err() - .map_or(false, |s| s.is_some()) -} - -fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) { - db::write_and_log(cx, move || async move { - if is_dismissed { - db::kvp::KEY_VALUE_STORE - .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into()) - .await - } else { - db::kvp::KEY_VALUE_STORE - .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into()) - .await - } - }) -} - pub struct InlineAssist { group_id: InlineAssistGroupId, range: Range, editor: WeakView, decorations: Option, - codegen: Model, + codegen: Model, _subscriptions: Vec, workspace: WeakView, } @@ -2174,11 +1462,11 @@ impl InlineAssist { assist_id: InlineAssistId, group_id: InlineAssistGroupId, editor: &View, - prompt_editor: &View, + prompt_editor: &View>, prompt_block_id: CustomBlockId, end_block_id: CustomBlockId, range: Range, - codegen: Model, + codegen: Model, workspace: WeakView, cx: &mut WindowContext, ) -> Self { @@ -2273,1060 +1561,55 @@ impl InlineAssist { struct InlineAssistDecorations { prompt_block_id: CustomBlockId, - prompt_editor: View, + prompt_editor: View>, removed_line_block_ids: HashSet, end_block_id: CustomBlockId, } -#[derive(Copy, Clone, Debug)] -pub enum CodegenEvent { - Finished, - Undone, -} - -pub struct Codegen { - alternatives: Vec>, - active_alternative: usize, - seen_alternatives: HashSet, - subscriptions: Vec, - buffer: Model, - range: Range, - initial_transaction_id: Option, - context_store: Model, - telemetry: Arc, - builder: Arc, - is_insertion: bool, +struct AssistantCodeActionProvider { + editor: WeakView, + workspace: WeakView, + thread_store: Option>, } -impl Codegen { - pub fn new( - buffer: Model, - range: Range, - initial_transaction_id: Option, - context_store: Model, - telemetry: Arc, - builder: Arc, - cx: &mut ModelContext, - ) -> Self { - let codegen = cx.new_model(|cx| { - CodegenAlternative::new( - buffer.clone(), - range.clone(), - false, - Some(context_store.clone()), - Some(telemetry.clone()), - builder.clone(), - cx, - ) - }); - let mut this = Self { - is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(), - alternatives: vec![codegen], - active_alternative: 0, - seen_alternatives: HashSet::default(), - subscriptions: Vec::new(), - buffer, - range, - initial_transaction_id, - context_store, - telemetry, - builder, - }; - this.activate(0, cx); - this - } +impl CodeActionProvider for AssistantCodeActionProvider { + fn code_actions( + &self, + buffer: &Model, + range: Range, + cx: &mut WindowContext, + ) -> Task>> { + if !AssistantSettings::get_global(cx).enabled { + return Task::ready(Ok(Vec::new())); + } - fn subscribe_to_alternative(&mut self, cx: &mut ModelContext) { - let codegen = self.active_alternative().clone(); - self.subscriptions.clear(); - self.subscriptions - .push(cx.observe(&codegen, |_, _, cx| cx.notify())); - self.subscriptions - .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event))); - } + let snapshot = buffer.read(cx).snapshot(); + let mut range = range.to_point(&snapshot); - fn active_alternative(&self) -> &Model { - &self.alternatives[self.active_alternative] - } + // Expand the range to line boundaries. + range.start.column = 0; + range.end.column = snapshot.line_len(range.end.row); - fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus { - &self.active_alternative().read(cx).status - } + let mut has_diagnostics = false; + for diagnostic in snapshot.diagnostics_in_range::<_, Point>(range.clone(), false) { + range.start = cmp::min(range.start, diagnostic.range.start); + range.end = cmp::max(range.end, diagnostic.range.end); + has_diagnostics = true; + } + if has_diagnostics { + if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None) { + if let Some(symbol) = symbols_containing_start.last() { + range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot)); + range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot)); + } + } - fn alternative_count(&self, cx: &AppContext) -> usize { - LanguageModelRegistry::read_global(cx) - .inline_alternative_models() - .len() - + 1 - } - - pub fn cycle_prev(&mut self, cx: &mut ModelContext) { - let next_active_ix = if self.active_alternative == 0 { - self.alternatives.len() - 1 - } else { - self.active_alternative - 1 - }; - self.activate(next_active_ix, cx); - } - - pub fn cycle_next(&mut self, cx: &mut ModelContext) { - let next_active_ix = (self.active_alternative + 1) % self.alternatives.len(); - self.activate(next_active_ix, cx); - } - - fn activate(&mut self, index: usize, cx: &mut ModelContext) { - self.active_alternative() - .update(cx, |codegen, cx| codegen.set_active(false, cx)); - self.seen_alternatives.insert(index); - self.active_alternative = index; - self.active_alternative() - .update(cx, |codegen, cx| codegen.set_active(true, cx)); - self.subscribe_to_alternative(cx); - cx.notify(); - } - - pub fn start(&mut self, user_prompt: String, cx: &mut ModelContext) -> Result<()> { - let alternative_models = LanguageModelRegistry::read_global(cx) - .inline_alternative_models() - .to_vec(); - - self.active_alternative() - .update(cx, |alternative, cx| alternative.undo(cx)); - self.activate(0, cx); - self.alternatives.truncate(1); - - for _ in 0..alternative_models.len() { - self.alternatives.push(cx.new_model(|cx| { - CodegenAlternative::new( - self.buffer.clone(), - self.range.clone(), - false, - Some(self.context_store.clone()), - Some(self.telemetry.clone()), - self.builder.clone(), - cx, - ) - })); - } - - let primary_model = LanguageModelRegistry::read_global(cx) - .active_model() - .context("no active model")?; - - for (model, alternative) in iter::once(primary_model) - .chain(alternative_models) - .zip(&self.alternatives) - { - alternative.update(cx, |alternative, cx| { - alternative.start(user_prompt.clone(), model.clone(), cx) - })?; - } - - Ok(()) - } - - pub fn stop(&mut self, cx: &mut ModelContext) { - for codegen in &self.alternatives { - codegen.update(cx, |codegen, cx| codegen.stop(cx)); - } - } - - pub fn undo(&mut self, cx: &mut ModelContext) { - self.active_alternative() - .update(cx, |codegen, cx| codegen.undo(cx)); - - self.buffer.update(cx, |buffer, cx| { - if let Some(transaction_id) = self.initial_transaction_id.take() { - buffer.undo_transaction(transaction_id, cx); - buffer.refresh_preview(cx); - } - }); - } - - pub fn buffer(&self, cx: &AppContext) -> Model { - self.active_alternative().read(cx).buffer.clone() - } - - pub fn old_buffer(&self, cx: &AppContext) -> Model { - self.active_alternative().read(cx).old_buffer.clone() - } - - pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot { - self.active_alternative().read(cx).snapshot.clone() - } - - pub fn edit_position(&self, cx: &AppContext) -> Option { - self.active_alternative().read(cx).edit_position - } - - fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff { - &self.active_alternative().read(cx).diff - } - - pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range] { - self.active_alternative().read(cx).last_equal_ranges() - } -} - -impl EventEmitter for Codegen {} - -pub struct CodegenAlternative { - buffer: Model, - old_buffer: Model, - snapshot: MultiBufferSnapshot, - edit_position: Option, - range: Range, - last_equal_ranges: Vec>, - transformation_transaction_id: Option, - status: CodegenStatus, - generation: Task<()>, - diff: Diff, - context_store: Option>, - telemetry: Option>, - _subscription: gpui::Subscription, - builder: Arc, - active: bool, - edits: Vec<(Range, String)>, - line_operations: Vec, - request: Option, - elapsed_time: Option, - completion: Option, - message_id: Option, -} - -#[derive(Default)] -struct Diff { - deleted_row_ranges: Vec<(Anchor, RangeInclusive)>, - inserted_row_ranges: Vec>, -} - -impl Diff { - fn is_empty(&self) -> bool { - self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty() - } -} - -impl EventEmitter for CodegenAlternative {} - -impl CodegenAlternative { - pub fn new( - buffer: Model, - range: Range, - active: bool, - context_store: Option>, - telemetry: Option>, - builder: Arc, - cx: &mut ModelContext, - ) -> Self { - let snapshot = buffer.read(cx).snapshot(cx); - - let (old_buffer, _, _) = buffer - .read(cx) - .range_to_buffer_ranges(range.clone(), cx) - .pop() - .unwrap(); - let old_buffer = cx.new_model(|cx| { - let old_buffer = old_buffer.read(cx); - let text = old_buffer.as_rope().clone(); - let line_ending = old_buffer.line_ending(); - let language = old_buffer.language().cloned(); - let language_registry = old_buffer.language_registry(); - - let mut buffer = Buffer::local_normalized(text, line_ending, cx); - buffer.set_language(language, cx); - if let Some(language_registry) = language_registry { - buffer.set_language_registry(language_registry) - } - buffer - }); - - Self { - buffer: buffer.clone(), - old_buffer, - edit_position: None, - message_id: None, - snapshot, - last_equal_ranges: Default::default(), - transformation_transaction_id: None, - status: CodegenStatus::Idle, - generation: Task::ready(()), - diff: Diff::default(), - context_store, - telemetry, - _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), - builder, - active, - edits: Vec::new(), - line_operations: Vec::new(), - range, - request: None, - elapsed_time: None, - completion: None, - } - } - - fn set_active(&mut self, active: bool, cx: &mut ModelContext) { - if active != self.active { - self.active = active; - - if self.active { - let edits = self.edits.clone(); - self.apply_edits(edits, cx); - if matches!(self.status, CodegenStatus::Pending) { - let line_operations = self.line_operations.clone(); - self.reapply_line_based_diff(line_operations, cx); - } else { - self.reapply_batch_diff(cx).detach(); - } - } else if let Some(transaction_id) = self.transformation_transaction_id.take() { - self.buffer.update(cx, |buffer, cx| { - buffer.undo_transaction(transaction_id, cx); - buffer.forget_transaction(transaction_id, cx); - }); - } - } - } - - fn handle_buffer_event( - &mut self, - _buffer: Model, - event: &multi_buffer::Event, - cx: &mut ModelContext, - ) { - if let multi_buffer::Event::TransactionUndone { transaction_id } = event { - if self.transformation_transaction_id == Some(*transaction_id) { - self.transformation_transaction_id = None; - self.generation = Task::ready(()); - cx.emit(CodegenEvent::Undone); - } - } - } - - pub fn last_equal_ranges(&self) -> &[Range] { - &self.last_equal_ranges - } - - pub fn start( - &mut self, - user_prompt: String, - model: Arc, - cx: &mut ModelContext, - ) -> Result<()> { - if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() { - self.buffer.update(cx, |buffer, cx| { - buffer.undo_transaction(transformation_transaction_id, cx); - }); - } - - self.edit_position = Some(self.range.start.bias_right(&self.snapshot)); - - let api_key = model.api_key(cx); - let telemetry_id = model.telemetry_id(); - let provider_id = model.provider_id(); - let stream: LocalBoxFuture> = - if user_prompt.trim().to_lowercase() == "delete" { - async { Ok(LanguageModelTextStream::default()) }.boxed_local() - } else { - let request = self.build_request(user_prompt, cx)?; - self.request = Some(request.clone()); - - cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await }) - .boxed_local() - }; - self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx); - Ok(()) - } - - fn build_request( - &self, - user_prompt: String, - cx: &mut AppContext, - ) -> Result { - let buffer = self.buffer.read(cx).snapshot(cx); - let language = buffer.language_at(self.range.start); - let language_name = if let Some(language) = language.as_ref() { - if Arc::ptr_eq(language, &language::PLAIN_TEXT) { - None - } else { - Some(language.name()) - } - } else { - None - }; - - let language_name = language_name.as_ref(); - let start = buffer.point_to_buffer_offset(self.range.start); - let end = buffer.point_to_buffer_offset(self.range.end); - let (buffer, range) = if let Some((start, end)) = start.zip(end) { - let (start_buffer, start_buffer_offset) = start; - let (end_buffer, end_buffer_offset) = end; - if start_buffer.remote_id() == end_buffer.remote_id() { - (start_buffer.clone(), start_buffer_offset..end_buffer_offset) - } else { - return Err(anyhow::anyhow!("invalid transformation range")); - } - } else { - return Err(anyhow::anyhow!("invalid transformation range")); - }; - - let prompt = self - .builder - .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range) - .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?; - - let mut request_message = LanguageModelRequestMessage { - role: Role::User, - content: Vec::new(), - cache: false, - }; - - if let Some(context_store) = &self.context_store { - let context = context_store.update(cx, |this, _cx| this.context().clone()); - attach_context_to_message(&mut request_message, context); - } - - request_message.content.push(prompt.into()); - - Ok(LanguageModelRequest { - tools: Vec::new(), - stop: Vec::new(), - temperature: None, - messages: vec![request_message], - }) - } - - pub fn handle_stream( - &mut self, - model_telemetry_id: String, - model_provider_id: String, - model_api_key: Option, - stream: impl 'static + Future>, - cx: &mut ModelContext, - ) { - let start_time = Instant::now(); - let snapshot = self.snapshot.clone(); - let selected_text = snapshot - .text_for_range(self.range.start..self.range.end) - .collect::(); - - let selection_start = self.range.start.to_point(&snapshot); - - // Start with the indentation of the first line in the selection - let mut suggested_line_indent = snapshot - .suggested_indents(selection_start.row..=selection_start.row, cx) - .into_values() - .next() - .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row))); - - // If the first line in the selection does not have indentation, check the following lines - if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space { - for row in selection_start.row..=self.range.end.to_point(&snapshot).row { - let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row)); - // Prefer tabs if a line in the selection uses tabs as indentation - if line_indent.kind == IndentKind::Tab { - suggested_line_indent.kind = IndentKind::Tab; - break; - } - } - } - - let http_client = cx.http_client().clone(); - let telemetry = self.telemetry.clone(); - let language_name = { - let multibuffer = self.buffer.read(cx); - let ranges = multibuffer.range_to_buffer_ranges(self.range.clone(), cx); - ranges - .first() - .and_then(|(buffer, _, _)| buffer.read(cx).language()) - .map(|language| language.name()) - }; - - self.diff = Diff::default(); - self.status = CodegenStatus::Pending; - let mut edit_start = self.range.start.to_offset(&snapshot); - let completion = Arc::new(Mutex::new(String::new())); - let completion_clone = completion.clone(); - - self.generation = cx.spawn(|codegen, mut cx| { - async move { - let stream = stream.await; - let message_id = stream - .as_ref() - .ok() - .and_then(|stream| stream.message_id.clone()); - let generate = async { - let (mut diff_tx, mut diff_rx) = mpsc::channel(1); - let executor = cx.background_executor().clone(); - let message_id = message_id.clone(); - let line_based_stream_diff: Task> = - cx.background_executor().spawn(async move { - let mut response_latency = None; - let request_start = Instant::now(); - let diff = async { - let chunks = StripInvalidSpans::new(stream?.stream); - futures::pin_mut!(chunks); - let mut diff = StreamingDiff::new(selected_text.to_string()); - let mut line_diff = LineDiff::default(); - - let mut new_text = String::new(); - let mut base_indent = None; - let mut line_indent = None; - let mut first_line = true; - - while let Some(chunk) = chunks.next().await { - if response_latency.is_none() { - response_latency = Some(request_start.elapsed()); - } - let chunk = chunk?; - completion_clone.lock().push_str(&chunk); - - let mut lines = chunk.split('\n').peekable(); - while let Some(line) = lines.next() { - new_text.push_str(line); - if line_indent.is_none() { - if let Some(non_whitespace_ch_ix) = - new_text.find(|ch: char| !ch.is_whitespace()) - { - line_indent = Some(non_whitespace_ch_ix); - base_indent = base_indent.or(line_indent); - - let line_indent = line_indent.unwrap(); - let base_indent = base_indent.unwrap(); - let indent_delta = - line_indent as i32 - base_indent as i32; - let mut corrected_indent_len = cmp::max( - 0, - suggested_line_indent.len as i32 + indent_delta, - ) - as usize; - if first_line { - corrected_indent_len = corrected_indent_len - .saturating_sub( - selection_start.column as usize, - ); - } - - let indent_char = suggested_line_indent.char(); - let mut indent_buffer = [0; 4]; - let indent_str = - indent_char.encode_utf8(&mut indent_buffer); - new_text.replace_range( - ..line_indent, - &indent_str.repeat(corrected_indent_len), - ); - } - } - - if line_indent.is_some() { - let char_ops = diff.push_new(&new_text); - line_diff - .push_char_operations(&char_ops, &selected_text); - diff_tx - .send((char_ops, line_diff.line_operations())) - .await?; - new_text.clear(); - } - - if lines.peek().is_some() { - let char_ops = diff.push_new("\n"); - line_diff - .push_char_operations(&char_ops, &selected_text); - diff_tx - .send((char_ops, line_diff.line_operations())) - .await?; - if line_indent.is_none() { - // Don't write out the leading indentation in empty lines on the next line - // This is the case where the above if statement didn't clear the buffer - new_text.clear(); - } - line_indent = None; - first_line = false; - } - } - } - - let mut char_ops = diff.push_new(&new_text); - char_ops.extend(diff.finish()); - line_diff.push_char_operations(&char_ops, &selected_text); - line_diff.finish(&selected_text); - diff_tx - .send((char_ops, line_diff.line_operations())) - .await?; - - anyhow::Ok(()) - }; - - let result = diff.await; - - let error_message = - result.as_ref().err().map(|error| error.to_string()); - report_assistant_event( - AssistantEvent { - conversation_id: None, - message_id, - kind: AssistantKind::Inline, - phase: AssistantPhase::Response, - model: model_telemetry_id, - model_provider: model_provider_id.to_string(), - response_latency, - error_message, - language_name: language_name.map(|name| name.to_proto()), - }, - telemetry, - http_client, - model_api_key, - &executor, - ); - - result?; - Ok(()) - }); - - while let Some((char_ops, line_ops)) = diff_rx.next().await { - codegen.update(&mut cx, |codegen, cx| { - codegen.last_equal_ranges.clear(); - - let edits = char_ops - .into_iter() - .filter_map(|operation| match operation { - CharOperation::Insert { text } => { - let edit_start = snapshot.anchor_after(edit_start); - Some((edit_start..edit_start, text)) - } - CharOperation::Delete { bytes } => { - let edit_end = edit_start + bytes; - let edit_range = snapshot.anchor_after(edit_start) - ..snapshot.anchor_before(edit_end); - edit_start = edit_end; - Some((edit_range, String::new())) - } - CharOperation::Keep { bytes } => { - let edit_end = edit_start + bytes; - let edit_range = snapshot.anchor_after(edit_start) - ..snapshot.anchor_before(edit_end); - edit_start = edit_end; - codegen.last_equal_ranges.push(edit_range); - None - } - }) - .collect::>(); - - if codegen.active { - codegen.apply_edits(edits.iter().cloned(), cx); - codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx); - } - codegen.edits.extend(edits); - codegen.line_operations = line_ops; - codegen.edit_position = Some(snapshot.anchor_after(edit_start)); - - cx.notify(); - })?; - } - - // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer. - // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff. - // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`. - let batch_diff_task = - codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?; - let (line_based_stream_diff, ()) = - join!(line_based_stream_diff, batch_diff_task); - line_based_stream_diff?; - - anyhow::Ok(()) - }; - - let result = generate.await; - let elapsed_time = start_time.elapsed().as_secs_f64(); - - codegen - .update(&mut cx, |this, cx| { - this.message_id = message_id; - this.last_equal_ranges.clear(); - if let Err(error) = result { - this.status = CodegenStatus::Error(error); - } else { - this.status = CodegenStatus::Done; - } - this.elapsed_time = Some(elapsed_time); - this.completion = Some(completion.lock().clone()); - cx.emit(CodegenEvent::Finished); - cx.notify(); - }) - .ok(); - } - }); - cx.notify(); - } - - pub fn stop(&mut self, cx: &mut ModelContext) { - self.last_equal_ranges.clear(); - if self.diff.is_empty() { - self.status = CodegenStatus::Idle; - } else { - self.status = CodegenStatus::Done; - } - self.generation = Task::ready(()); - cx.emit(CodegenEvent::Finished); - cx.notify(); - } - - pub fn undo(&mut self, cx: &mut ModelContext) { - self.buffer.update(cx, |buffer, cx| { - if let Some(transaction_id) = self.transformation_transaction_id.take() { - buffer.undo_transaction(transaction_id, cx); - buffer.refresh_preview(cx); - } - }); - } - - fn apply_edits( - &mut self, - edits: impl IntoIterator, String)>, - cx: &mut ModelContext, - ) { - let transaction = self.buffer.update(cx, |buffer, cx| { - // Avoid grouping assistant edits with user edits. - buffer.finalize_last_transaction(cx); - buffer.start_transaction(cx); - buffer.edit(edits, None, cx); - buffer.end_transaction(cx) - }); - - if let Some(transaction) = transaction { - if let Some(first_transaction) = self.transformation_transaction_id { - // Group all assistant edits into the first transaction. - self.buffer.update(cx, |buffer, cx| { - buffer.merge_transactions(transaction, first_transaction, cx) - }); - } else { - self.transformation_transaction_id = Some(transaction); - self.buffer - .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx)); - } - } - } - - fn reapply_line_based_diff( - &mut self, - line_operations: impl IntoIterator, - cx: &mut ModelContext, - ) { - let old_snapshot = self.snapshot.clone(); - let old_range = self.range.to_point(&old_snapshot); - let new_snapshot = self.buffer.read(cx).snapshot(cx); - let new_range = self.range.to_point(&new_snapshot); - - let mut old_row = old_range.start.row; - let mut new_row = new_range.start.row; - - self.diff.deleted_row_ranges.clear(); - self.diff.inserted_row_ranges.clear(); - for operation in line_operations { - match operation { - LineOperation::Keep { lines } => { - old_row += lines; - new_row += lines; - } - LineOperation::Delete { lines } => { - let old_end_row = old_row + lines - 1; - let new_row = new_snapshot.anchor_before(Point::new(new_row, 0)); - - if let Some((_, last_deleted_row_range)) = - self.diff.deleted_row_ranges.last_mut() - { - if *last_deleted_row_range.end() + 1 == old_row { - *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row; - } else { - self.diff - .deleted_row_ranges - .push((new_row, old_row..=old_end_row)); - } - } else { - self.diff - .deleted_row_ranges - .push((new_row, old_row..=old_end_row)); - } - - old_row += lines; - } - LineOperation::Insert { lines } => { - let new_end_row = new_row + lines - 1; - let start = new_snapshot.anchor_before(Point::new(new_row, 0)); - let end = new_snapshot.anchor_before(Point::new( - new_end_row, - new_snapshot.line_len(MultiBufferRow(new_end_row)), - )); - self.diff.inserted_row_ranges.push(start..end); - new_row += lines; - } - } - - cx.notify(); - } - } - - fn reapply_batch_diff(&mut self, cx: &mut ModelContext) -> Task<()> { - let old_snapshot = self.snapshot.clone(); - let old_range = self.range.to_point(&old_snapshot); - let new_snapshot = self.buffer.read(cx).snapshot(cx); - let new_range = self.range.to_point(&new_snapshot); - - cx.spawn(|codegen, mut cx| async move { - let (deleted_row_ranges, inserted_row_ranges) = cx - .background_executor() - .spawn(async move { - let old_text = old_snapshot - .text_for_range( - Point::new(old_range.start.row, 0) - ..Point::new( - old_range.end.row, - old_snapshot.line_len(MultiBufferRow(old_range.end.row)), - ), - ) - .collect::(); - let new_text = new_snapshot - .text_for_range( - Point::new(new_range.start.row, 0) - ..Point::new( - new_range.end.row, - new_snapshot.line_len(MultiBufferRow(new_range.end.row)), - ), - ) - .collect::(); - - let mut old_row = old_range.start.row; - let mut new_row = new_range.start.row; - let batch_diff = - similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str()); - - let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive)> = Vec::new(); - let mut inserted_row_ranges = Vec::new(); - for change in batch_diff.iter_all_changes() { - let line_count = change.value().lines().count() as u32; - match change.tag() { - similar::ChangeTag::Equal => { - old_row += line_count; - new_row += line_count; - } - similar::ChangeTag::Delete => { - let old_end_row = old_row + line_count - 1; - let new_row = new_snapshot.anchor_before(Point::new(new_row, 0)); - - if let Some((_, last_deleted_row_range)) = - deleted_row_ranges.last_mut() - { - if *last_deleted_row_range.end() + 1 == old_row { - *last_deleted_row_range = - *last_deleted_row_range.start()..=old_end_row; - } else { - deleted_row_ranges.push((new_row, old_row..=old_end_row)); - } - } else { - deleted_row_ranges.push((new_row, old_row..=old_end_row)); - } - - old_row += line_count; - } - similar::ChangeTag::Insert => { - let new_end_row = new_row + line_count - 1; - let start = new_snapshot.anchor_before(Point::new(new_row, 0)); - let end = new_snapshot.anchor_before(Point::new( - new_end_row, - new_snapshot.line_len(MultiBufferRow(new_end_row)), - )); - inserted_row_ranges.push(start..end); - new_row += line_count; - } - } - } - - (deleted_row_ranges, inserted_row_ranges) - }) - .await; - - codegen - .update(&mut cx, |codegen, cx| { - codegen.diff.deleted_row_ranges = deleted_row_ranges; - codegen.diff.inserted_row_ranges = inserted_row_ranges; - cx.notify(); - }) - .ok(); - }) - } -} - -struct StripInvalidSpans { - stream: T, - stream_done: bool, - buffer: String, - first_line: bool, - line_end: bool, - starts_with_code_block: bool, -} - -impl StripInvalidSpans -where - T: Stream>, -{ - fn new(stream: T) -> Self { - Self { - stream, - stream_done: false, - buffer: String::new(), - first_line: true, - line_end: false, - starts_with_code_block: false, - } - } -} - -impl Stream for StripInvalidSpans -where - T: Stream>, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { - const CODE_BLOCK_DELIMITER: &str = "```"; - const CURSOR_SPAN: &str = "<|CURSOR|>"; - - let this = unsafe { self.get_unchecked_mut() }; - loop { - if !this.stream_done { - let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) }; - match stream.as_mut().poll_next(cx) { - Poll::Ready(Some(Ok(chunk))) => { - this.buffer.push_str(&chunk); - } - Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))), - Poll::Ready(None) => { - this.stream_done = true; - } - Poll::Pending => return Poll::Pending, - } - } - - let mut chunk = String::new(); - let mut consumed = 0; - if !this.buffer.is_empty() { - let mut lines = this.buffer.split('\n').enumerate().peekable(); - while let Some((line_ix, line)) = lines.next() { - if line_ix > 0 { - this.first_line = false; - } - - if this.first_line { - let trimmed_line = line.trim(); - if lines.peek().is_some() { - if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) { - consumed += line.len() + 1; - this.starts_with_code_block = true; - continue; - } - } else if trimmed_line.is_empty() - || prefixes(CODE_BLOCK_DELIMITER) - .any(|prefix| trimmed_line.starts_with(prefix)) - { - break; - } - } - - let line_without_cursor = line.replace(CURSOR_SPAN, ""); - if lines.peek().is_some() { - if this.line_end { - chunk.push('\n'); - } - - chunk.push_str(&line_without_cursor); - this.line_end = true; - consumed += line.len() + 1; - } else if this.stream_done { - if !this.starts_with_code_block - || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER) - { - if this.line_end { - chunk.push('\n'); - } - - chunk.push_str(&line); - } - - consumed += line.len(); - } else { - let trimmed_line = line.trim(); - if trimmed_line.is_empty() - || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix)) - || prefixes(CODE_BLOCK_DELIMITER) - .any(|prefix| trimmed_line.ends_with(prefix)) - { - break; - } else { - if this.line_end { - chunk.push('\n'); - this.line_end = false; - } - - chunk.push_str(&line_without_cursor); - consumed += line.len(); - } - } - } - } - - this.buffer = this.buffer.split_off(consumed); - if !chunk.is_empty() { - return Poll::Ready(Some(Ok(chunk))); - } else if this.stream_done { - return Poll::Ready(None); - } - } - } -} - -struct AssistantCodeActionProvider { - editor: WeakView, - workspace: WeakView, - thread_store: Option>, -} - -impl CodeActionProvider for AssistantCodeActionProvider { - fn code_actions( - &self, - buffer: &Model, - range: Range, - cx: &mut WindowContext, - ) -> Task>> { - if !AssistantSettings::get_global(cx).enabled { - return Task::ready(Ok(Vec::new())); - } - - let snapshot = buffer.read(cx).snapshot(); - let mut range = range.to_point(&snapshot); - - // Expand the range to line boundaries. - range.start.column = 0; - range.end.column = snapshot.line_len(range.end.row); - - let mut has_diagnostics = false; - for diagnostic in snapshot.diagnostics_in_range::<_, Point>(range.clone(), false) { - range.start = cmp::min(range.start, diagnostic.range.start); - range.end = cmp::max(range.end, diagnostic.range.end); - has_diagnostics = true; - } - if has_diagnostics { - if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None) { - if let Some(symbol) = symbols_containing_start.last() { - range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot)); - range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot)); - } - } - - if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None) { - if let Some(symbol) = symbols_containing_end.last() { - range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot)); - range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot)); - } - } + if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None) { + if let Some(symbol) = symbols_containing_end.last() { + range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot)); + range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot)); + } + } Task::ready(Ok(vec![CodeAction { server_id: language::LanguageServerId(0), @@ -3410,10 +1693,6 @@ impl CodeActionProvider for AssistantCodeActionProvider { } } -fn prefixes(text: &str) -> impl Iterator { - (0..text.len() - 1).map(|ix| &text[..ix + 1]) -} - fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { ranges.sort_unstable_by(|a, b| { a.start @@ -3435,432 +1714,3 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { } } } - -#[cfg(test)] -mod tests { - use super::*; - use futures::stream::{self}; - use gpui::{Context, TestAppContext}; - use indoc::indoc; - use language::{ - language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher, - Point, - }; - use language_model::LanguageModelRegistry; - use rand::prelude::*; - use serde::Serialize; - use settings::SettingsStore; - use std::{future, sync::Arc}; - - #[derive(Serialize)] - pub struct DummyCompletionRequest { - pub name: String, - } - - #[gpui::test(iterations = 10)] - async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) { - cx.set_global(cx.update(SettingsStore::test)); - cx.update(language_model::LanguageModelRegistry::test); - cx.update(language_settings::init); - - let text = indoc! {" - fn main() { - let x = 0; - for _ in 0..10 { - x += 1; - } - } - "}; - let buffer = - cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); - let range = buffer.read_with(cx, |buffer, cx| { - let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) - }); - let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let codegen = cx.new_model(|cx| { - CodegenAlternative::new( - buffer.clone(), - range.clone(), - true, - None, - None, - prompt_builder, - cx, - ) - }); - - let chunks_tx = simulate_response_stream(codegen.clone(), cx); - - let mut new_text = concat!( - " let mut x = 0;\n", - " while x < 10 {\n", - " x += 1;\n", - " }", - ); - while !new_text.is_empty() { - let max_len = cmp::min(new_text.len(), 10); - let len = rng.gen_range(1..=max_len); - let (chunk, suffix) = new_text.split_at(len); - chunks_tx.unbounded_send(chunk.to_string()).unwrap(); - new_text = suffix; - cx.background_executor.run_until_parked(); - } - drop(chunks_tx); - cx.background_executor.run_until_parked(); - - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - indoc! {" - fn main() { - let mut x = 0; - while x < 10 { - x += 1; - } - } - "} - ); - } - - #[gpui::test(iterations = 10)] - async fn test_autoindent_when_generating_past_indentation( - cx: &mut TestAppContext, - mut rng: StdRng, - ) { - cx.set_global(cx.update(SettingsStore::test)); - cx.update(language_settings::init); - - let text = indoc! {" - fn main() { - le - } - "}; - let buffer = - cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); - let range = buffer.read_with(cx, |buffer, cx| { - let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6)) - }); - let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let codegen = cx.new_model(|cx| { - CodegenAlternative::new( - buffer.clone(), - range.clone(), - true, - None, - None, - prompt_builder, - cx, - ) - }); - - let chunks_tx = simulate_response_stream(codegen.clone(), cx); - - cx.background_executor.run_until_parked(); - - let mut new_text = concat!( - "t mut x = 0;\n", - "while x < 10 {\n", - " x += 1;\n", - "}", // - ); - while !new_text.is_empty() { - let max_len = cmp::min(new_text.len(), 10); - let len = rng.gen_range(1..=max_len); - let (chunk, suffix) = new_text.split_at(len); - chunks_tx.unbounded_send(chunk.to_string()).unwrap(); - new_text = suffix; - cx.background_executor.run_until_parked(); - } - drop(chunks_tx); - cx.background_executor.run_until_parked(); - - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - indoc! {" - fn main() { - let mut x = 0; - while x < 10 { - x += 1; - } - } - "} - ); - } - - #[gpui::test(iterations = 10)] - async fn test_autoindent_when_generating_before_indentation( - cx: &mut TestAppContext, - mut rng: StdRng, - ) { - cx.update(LanguageModelRegistry::test); - cx.set_global(cx.update(SettingsStore::test)); - cx.update(language_settings::init); - - let text = concat!( - "fn main() {\n", - " \n", - "}\n" // - ); - let buffer = - cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); - let range = buffer.read_with(cx, |buffer, cx| { - let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2)) - }); - let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let codegen = cx.new_model(|cx| { - CodegenAlternative::new( - buffer.clone(), - range.clone(), - true, - None, - None, - prompt_builder, - cx, - ) - }); - - let chunks_tx = simulate_response_stream(codegen.clone(), cx); - - cx.background_executor.run_until_parked(); - - let mut new_text = concat!( - "let mut x = 0;\n", - "while x < 10 {\n", - " x += 1;\n", - "}", // - ); - while !new_text.is_empty() { - let max_len = cmp::min(new_text.len(), 10); - let len = rng.gen_range(1..=max_len); - let (chunk, suffix) = new_text.split_at(len); - chunks_tx.unbounded_send(chunk.to_string()).unwrap(); - new_text = suffix; - cx.background_executor.run_until_parked(); - } - drop(chunks_tx); - cx.background_executor.run_until_parked(); - - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - indoc! {" - fn main() { - let mut x = 0; - while x < 10 { - x += 1; - } - } - "} - ); - } - - #[gpui::test(iterations = 10)] - async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) { - cx.update(LanguageModelRegistry::test); - cx.set_global(cx.update(SettingsStore::test)); - cx.update(language_settings::init); - - let text = indoc! {" - func main() { - \tx := 0 - \tfor i := 0; i < 10; i++ { - \t\tx++ - \t} - } - "}; - let buffer = cx.new_model(|cx| Buffer::local(text, cx)); - let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); - let range = buffer.read_with(cx, |buffer, cx| { - let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2)) - }); - let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let codegen = cx.new_model(|cx| { - CodegenAlternative::new( - buffer.clone(), - range.clone(), - true, - None, - None, - prompt_builder, - cx, - ) - }); - - let chunks_tx = simulate_response_stream(codegen.clone(), cx); - let new_text = concat!( - "func main() {\n", - "\tx := 0\n", - "\tfor x < 10 {\n", - "\t\tx++\n", - "\t}", // - ); - chunks_tx.unbounded_send(new_text.to_string()).unwrap(); - drop(chunks_tx); - cx.background_executor.run_until_parked(); - - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - indoc! {" - func main() { - \tx := 0 - \tfor x < 10 { - \t\tx++ - \t} - } - "} - ); - } - - #[gpui::test] - async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) { - cx.update(LanguageModelRegistry::test); - cx.set_global(cx.update(SettingsStore::test)); - cx.update(language_settings::init); - - let text = indoc! {" - fn main() { - let x = 0; - } - "}; - let buffer = - cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); - let range = buffer.read_with(cx, |buffer, cx| { - let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14)) - }); - let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let codegen = cx.new_model(|cx| { - CodegenAlternative::new( - buffer.clone(), - range.clone(), - false, - None, - None, - prompt_builder, - cx, - ) - }); - - let chunks_tx = simulate_response_stream(codegen.clone(), cx); - chunks_tx - .unbounded_send("let mut x = 0;\nx += 1;".to_string()) - .unwrap(); - drop(chunks_tx); - cx.run_until_parked(); - - // The codegen is inactive, so the buffer doesn't get modified. - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - text - ); - - // Activating the codegen applies the changes. - codegen.update(cx, |codegen, cx| codegen.set_active(true, cx)); - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - indoc! {" - fn main() { - let mut x = 0; - x += 1; - } - "} - ); - - // Deactivating the codegen undoes the changes. - codegen.update(cx, |codegen, cx| codegen.set_active(false, cx)); - cx.run_until_parked(); - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - text - ); - } - - #[gpui::test] - async fn test_strip_invalid_spans_from_codeblock() { - assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await; - assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await; - assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await; - assert_chunks( - "```html\n```js\nLorem ipsum dolor\n```\n```", - "```js\nLorem ipsum dolor\n```", - ) - .await; - assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await; - assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await; - assert_chunks("Lorem ipsum", "Lorem ipsum").await; - assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await; - - async fn assert_chunks(text: &str, expected_text: &str) { - for chunk_size in 1..=text.len() { - let actual_text = StripInvalidSpans::new(chunks(text, chunk_size)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await; - assert_eq!( - actual_text, expected_text, - "failed to strip invalid spans, chunk size: {}", - chunk_size - ); - } - } - - fn chunks(text: &str, size: usize) -> impl Stream> { - stream::iter( - text.chars() - .collect::>() - .chunks(size) - .map(|chunk| Ok(chunk.iter().collect::())) - .collect::>(), - ) - } - } - - fn simulate_response_stream( - codegen: Model, - cx: &mut TestAppContext, - ) -> mpsc::UnboundedSender { - let (chunks_tx, chunks_rx) = mpsc::unbounded(); - codegen.update(cx, |codegen, cx| { - codegen.handle_stream( - String::new(), - String::new(), - None, - future::ready(Ok(LanguageModelTextStream { - message_id: None, - stream: chunks_rx.map(Ok).boxed(), - })), - cx, - ); - }); - chunks_tx - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_indents_query( - r#" - (call_expression) @indent - (field_expression) @indent - (_ "(" ")" @end) @indent - (_ "{" "}" @end) @indent - "#, - ) - .unwrap() - } -} diff --git a/crates/assistant2/src/inline_prompt_editor.rs b/crates/assistant2/src/inline_prompt_editor.rs index 587ca3113dd31b..c146256341109a 100644 --- a/crates/assistant2/src/inline_prompt_editor.rs +++ b/crates/assistant2/src/inline_prompt_editor.rs @@ -1,5 +1,1068 @@ -use gpui::{AnyElement, EventEmitter}; -use ui::{prelude::*, IconButtonShape, Tooltip}; +use crate::buffer_codegen::BufferCodegen; +use crate::context_picker::ContextPicker; +use crate::context_store::ContextStore; +use crate::context_strip::ContextStrip; +use crate::terminal_codegen::TerminalCodegen; +use crate::thread_store::ThreadStore; +use crate::ToggleContextPicker; +use crate::{ + assistant_settings::AssistantSettings, CycleNextInlineAssist, CyclePreviousInlineAssist, +}; +use client::ErrorExt; +use collections::VecDeque; +use editor::{ + actions::{MoveDown, MoveUp}, + Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, GutterDimensions, MultiBuffer, +}; +use feature_flags::{FeatureFlagAppExt as _, ZedPro}; +use fs::Fs; +use gpui::{ + anchored, deferred, point, AnyElement, AppContext, ClickEvent, CursorStyle, EventEmitter, + FocusHandle, FocusableView, FontWeight, Model, Subscription, TextStyle, View, ViewContext, + WeakModel, WeakView, WindowContext, +}; +use language_model::{LanguageModel, LanguageModelRegistry}; +use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu}; +use parking_lot::Mutex; +use settings::{update_settings_file, Settings}; +use std::cmp; +use std::sync::Arc; +use theme::ThemeSettings; +use ui::{ + prelude::*, CheckboxWithLabel, IconButtonShape, KeyBinding, Popover, PopoverMenuHandle, Tooltip, +}; +use util::ResultExt; +use workspace::Workspace; + +pub struct PromptEditor { + pub editor: View, + mode: PromptEditorMode, + context_strip: View, + context_picker_menu_handle: PopoverMenuHandle, + language_model_selector: View, + edited_since_done: bool, + prompt_history: VecDeque, + prompt_history_ix: Option, + pending_prompt: String, + _codegen_subscription: Subscription, + editor_subscriptions: Vec, + show_rate_limit_notice: bool, + _phantom: std::marker::PhantomData, +} + +impl EventEmitter for PromptEditor {} + +impl Render for PromptEditor { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + let mut buttons = Vec::new(); + + let spacing = match &self.mode { + PromptEditorMode::Buffer { + id: _, + codegen, + gutter_dimensions, + } => { + let codegen = codegen.read(cx); + + if codegen.alternative_count(cx) > 1 { + buttons.push(self.render_cycle_controls(&codegen, cx)); + } + + let gutter_dimensions = gutter_dimensions.lock(); + + gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0) + } + PromptEditorMode::Terminal { .. } => Pixels::ZERO, + }; + + buttons.extend(self.render_buttons(cx)); + + v_flex() + .border_y_1() + .border_color(cx.theme().status().info_border) + .size_full() + .py(cx.line_height() / 2.5) + .child( + h_flex() + .key_context("PromptEditor") + .bg(cx.theme().colors().editor_background) + .block_mouse_down() + .cursor(CursorStyle::Arrow) + .on_action(cx.listener(Self::toggle_context_picker)) + .on_action(cx.listener(Self::confirm)) + .on_action(cx.listener(Self::cancel)) + .on_action(cx.listener(Self::move_up)) + .on_action(cx.listener(Self::move_down)) + .capture_action(cx.listener(Self::cycle_prev)) + .capture_action(cx.listener(Self::cycle_next)) + .child( + h_flex() + .w(spacing) + .justify_center() + .gap_2() + .child(LanguageModelSelectorPopoverMenu::new( + self.language_model_selector.clone(), + IconButton::new("context", IconName::SettingsAlt) + .shape(IconButtonShape::Square) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .tooltip(move |cx| { + Tooltip::with_meta( + format!( + "Using {}", + LanguageModelRegistry::read_global(cx) + .active_model() + .map(|model| model.name().0) + .unwrap_or_else(|| "No model selected".into()), + ), + None, + "Change Model", + cx, + ) + }), + )) + .map(|el| { + let CodegenStatus::Error(error) = self.codegen_status(cx) else { + return el; + }; + + let error_message = SharedString::from(error.to_string()); + if error.error_code() == proto::ErrorCode::RateLimitExceeded + && cx.has_flag::() + { + el.child( + v_flex() + .child( + IconButton::new( + "rate-limit-error", + IconName::XCircle, + ) + .toggle_state(self.show_rate_limit_notice) + .shape(IconButtonShape::Square) + .icon_size(IconSize::Small) + .on_click( + cx.listener(Self::toggle_rate_limit_notice), + ), + ) + .children(self.show_rate_limit_notice.then(|| { + deferred( + anchored() + .position_mode( + gpui::AnchoredPositionMode::Local, + ) + .position(point(px(0.), px(24.))) + .anchor(gpui::Corner::TopLeft) + .child(self.render_rate_limit_notice(cx)), + ) + })), + ) + } else { + el.child( + div() + .id("error") + .tooltip(move |cx| { + Tooltip::text(error_message.clone(), cx) + }) + .child( + Icon::new(IconName::XCircle) + .size(IconSize::Small) + .color(Color::Error), + ), + ) + } + }), + ) + .child(div().flex_1().child(self.render_editor(cx))) + .child(h_flex().gap_2().pr_6().children(buttons)), + ) + .child( + h_flex() + .child(h_flex().w(spacing).justify_center().gap_2()) + .child(self.context_strip.clone()), + ) + } +} + +impl FocusableView for PromptEditor { + fn focus_handle(&self, cx: &AppContext) -> FocusHandle { + self.editor.focus_handle(cx) + } +} + +impl PromptEditor { + const MAX_LINES: u8 = 8; + + fn codegen_status<'a>(&'a self, cx: &'a AppContext) -> &'a CodegenStatus { + match &self.mode { + PromptEditorMode::Buffer { codegen, .. } => codegen.read(cx).status(cx), + PromptEditorMode::Terminal { codegen, .. } => &codegen.read(cx).status, + } + } + + fn subscribe_to_editor(&mut self, cx: &mut ViewContext) { + self.editor_subscriptions.clear(); + self.editor_subscriptions + .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events)); + } + + pub fn set_show_cursor_when_unfocused( + &mut self, + show_cursor_when_unfocused: bool, + cx: &mut ViewContext, + ) { + self.editor.update(cx, |editor, cx| { + editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx) + }); + } + + pub fn unlink(&mut self, cx: &mut ViewContext) { + let prompt = self.prompt(cx); + let focus = self.editor.focus_handle(cx).contains_focused(cx); + self.editor = cx.new_view(|cx| { + let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx); + editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx); + editor.set_placeholder_text(Self::placeholder_text(&self.mode, cx), cx); + editor.set_placeholder_text("Add a prompt…", cx); + editor.set_text(prompt, cx); + if focus { + editor.focus(cx); + } + editor + }); + self.subscribe_to_editor(cx); + } + + pub fn placeholder_text(mode: &PromptEditorMode, cx: &WindowContext) -> String { + let action = match mode { + PromptEditorMode::Buffer { codegen, .. } => { + if codegen.read(cx).is_insertion { + "Generate" + } else { + "Transform" + } + } + PromptEditorMode::Terminal { .. } => "Generate", + }; + + let assistant_panel_keybinding = ui::text_for_action(&crate::ToggleFocus, cx) + .map(|keybinding| format!("{keybinding} to chat ― ")) + .unwrap_or_default(); + + format!("{action}… ({assistant_panel_keybinding}↓↑ for history)") + } + + pub fn prompt(&self, cx: &AppContext) -> String { + self.editor.read(cx).text(cx) + } + + fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext) { + self.show_rate_limit_notice = !self.show_rate_limit_notice; + if self.show_rate_limit_notice { + cx.focus_view(&self.editor); + } + cx.notify(); + } + + fn handle_prompt_editor_events( + &mut self, + _: View, + event: &EditorEvent, + cx: &mut ViewContext, + ) { + match event { + EditorEvent::Edited { .. } => { + if let Some(workspace) = cx.window_handle().downcast::() { + workspace + .update(cx, |workspace, cx| { + let is_via_ssh = workspace + .project() + .update(cx, |project, _| project.is_via_ssh()); + + workspace + .client() + .telemetry() + .log_edit_event("inline assist", is_via_ssh); + }) + .log_err(); + } + let prompt = self.editor.read(cx).text(cx); + if self + .prompt_history_ix + .map_or(true, |ix| self.prompt_history[ix] != prompt) + { + self.prompt_history_ix.take(); + self.pending_prompt = prompt; + } + + self.edited_since_done = true; + cx.notify(); + } + EditorEvent::Blurred => { + if self.show_rate_limit_notice { + self.show_rate_limit_notice = false; + cx.notify(); + } + } + _ => {} + } + } + + fn toggle_context_picker(&mut self, _: &ToggleContextPicker, cx: &mut ViewContext) { + self.context_picker_menu_handle.toggle(cx); + } + + fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext) { + match self.codegen_status(cx) { + CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => { + cx.emit(PromptEditorEvent::CancelRequested); + } + CodegenStatus::Pending => { + cx.emit(PromptEditorEvent::StopRequested); + } + } + } + + fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { + match self.codegen_status(cx) { + CodegenStatus::Idle => { + cx.emit(PromptEditorEvent::StartRequested); + } + CodegenStatus::Pending => { + cx.emit(PromptEditorEvent::DismissRequested); + } + CodegenStatus::Done => { + if self.edited_since_done { + cx.emit(PromptEditorEvent::StartRequested); + } else { + cx.emit(PromptEditorEvent::ConfirmRequested { execute: false }); + } + } + CodegenStatus::Error(_) => { + cx.emit(PromptEditorEvent::StartRequested); + } + } + } + + fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext) { + if let Some(ix) = self.prompt_history_ix { + if ix > 0 { + self.prompt_history_ix = Some(ix - 1); + let prompt = self.prompt_history[ix - 1].as_str(); + self.editor.update(cx, |editor, cx| { + editor.set_text(prompt, cx); + editor.move_to_beginning(&Default::default(), cx); + }); + } + } else if !self.prompt_history.is_empty() { + self.prompt_history_ix = Some(self.prompt_history.len() - 1); + let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str(); + self.editor.update(cx, |editor, cx| { + editor.set_text(prompt, cx); + editor.move_to_beginning(&Default::default(), cx); + }); + } + } + + fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext) { + if let Some(ix) = self.prompt_history_ix { + if ix < self.prompt_history.len() - 1 { + self.prompt_history_ix = Some(ix + 1); + let prompt = self.prompt_history[ix + 1].as_str(); + self.editor.update(cx, |editor, cx| { + editor.set_text(prompt, cx); + editor.move_to_end(&Default::default(), cx) + }); + } else { + self.prompt_history_ix = None; + let prompt = self.pending_prompt.as_str(); + self.editor.update(cx, |editor, cx| { + editor.set_text(prompt, cx); + editor.move_to_end(&Default::default(), cx) + }); + } + } + } + + fn render_buttons(&self, cx: &mut ViewContext) -> Vec { + let mode = match &self.mode { + PromptEditorMode::Buffer { codegen, .. } => { + let codegen = codegen.read(cx); + if codegen.is_insertion { + GenerationMode::Generate + } else { + GenerationMode::Transform + } + } + PromptEditorMode::Terminal { .. } => GenerationMode::Generate, + }; + + let codegen_status = self.codegen_status(cx); + + match codegen_status { + CodegenStatus::Idle => { + vec![ + IconButton::new("cancel", IconName::Close) + .icon_color(Color::Muted) + .shape(IconButtonShape::Square) + .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx)) + .on_click( + cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)), + ) + .into_any_element(), + Button::new("start", mode.start_label()) + .icon(IconName::Return) + .icon_color(Color::Muted) + .on_click( + cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)), + ) + .into_any_element(), + ] + } + CodegenStatus::Pending => vec![ + IconButton::new("cancel", IconName::Close) + .icon_color(Color::Muted) + .shape(IconButtonShape::Square) + .tooltip(|cx| Tooltip::text("Cancel Assist", cx)) + .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested))) + .into_any_element(), + IconButton::new("stop", IconName::Stop) + .icon_color(Color::Error) + .shape(IconButtonShape::Square) + .tooltip(move |cx| { + Tooltip::with_meta( + mode.tooltip_interrupt(), + Some(&menu::Cancel), + "Changes won't be discarded", + cx, + ) + }) + .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested))) + .into_any_element(), + ], + CodegenStatus::Done | CodegenStatus::Error(_) => { + let cancel = IconButton::new("cancel", IconName::Close) + .icon_color(Color::Muted) + .shape(IconButtonShape::Square) + .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx)) + .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested))) + .into_any_element(); + + let has_error = matches!(codegen_status, CodegenStatus::Error(_)); + if has_error || self.edited_since_done { + vec![ + cancel, + IconButton::new("restart", IconName::RotateCw) + .icon_color(Color::Info) + .shape(IconButtonShape::Square) + .tooltip(move |cx| { + Tooltip::with_meta( + mode.tooltip_restart(), + Some(&menu::Confirm), + "Changes will be discarded", + cx, + ) + }) + .on_click(cx.listener(|_, _, cx| { + cx.emit(PromptEditorEvent::StartRequested); + })) + .into_any_element(), + ] + } else { + let accept = IconButton::new("accept", IconName::Check) + .icon_color(Color::Info) + .shape(IconButtonShape::Square) + .tooltip(move |cx| { + Tooltip::for_action(mode.tooltip_accept(), &menu::Confirm, cx) + }) + .on_click(cx.listener(|_, _, cx| { + cx.emit(PromptEditorEvent::ConfirmRequested { execute: false }); + })) + .into_any_element(); + + match &self.mode { + PromptEditorMode::Terminal { .. } => vec![ + accept, + cancel, + IconButton::new("confirm", IconName::Play) + .icon_color(Color::Info) + .shape(IconButtonShape::Square) + .tooltip(|cx| { + Tooltip::for_action( + "Execute Generated Command", + &menu::SecondaryConfirm, + cx, + ) + }) + .on_click(cx.listener(|_, _, cx| { + cx.emit(PromptEditorEvent::ConfirmRequested { execute: true }); + })) + .into_any_element(), + ], + PromptEditorMode::Buffer { .. } => vec![accept, cancel], + } + } + } + } + } + + fn cycle_prev(&mut self, _: &CyclePreviousInlineAssist, cx: &mut ViewContext) { + match &self.mode { + PromptEditorMode::Buffer { codegen, .. } => { + codegen.update(cx, |codegen, cx| codegen.cycle_prev(cx)); + } + PromptEditorMode::Terminal { .. } => { + // no cycle buttons in terminal mode + } + } + } + + fn cycle_next(&mut self, _: &CycleNextInlineAssist, cx: &mut ViewContext) { + match &self.mode { + PromptEditorMode::Buffer { codegen, .. } => { + codegen.update(cx, |codegen, cx| codegen.cycle_next(cx)); + } + PromptEditorMode::Terminal { .. } => { + // no cycle buttons in terminal mode + } + } + } + + fn render_cycle_controls(&self, codegen: &BufferCodegen, cx: &ViewContext) -> AnyElement { + let disabled = matches!(codegen.status(cx), CodegenStatus::Idle); + + let model_registry = LanguageModelRegistry::read_global(cx); + let default_model = model_registry.active_model(); + let alternative_models = model_registry.inline_alternative_models(); + + let get_model_name = |index: usize| -> String { + let name = |model: &Arc| model.name().0.to_string(); + + match index { + 0 => default_model.as_ref().map_or_else(String::new, name), + index if index <= alternative_models.len() => alternative_models + .get(index - 1) + .map_or_else(String::new, name), + _ => String::new(), + } + }; + + let total_models = alternative_models.len() + 1; + + if total_models <= 1 { + return div().into_any_element(); + } + + let current_index = codegen.active_alternative; + let prev_index = (current_index + total_models - 1) % total_models; + let next_index = (current_index + 1) % total_models; + + let prev_model_name = get_model_name(prev_index); + let next_model_name = get_model_name(next_index); + + h_flex() + .child( + IconButton::new("previous", IconName::ChevronLeft) + .icon_color(Color::Muted) + .disabled(disabled || current_index == 0) + .shape(IconButtonShape::Square) + .tooltip({ + let focus_handle = self.editor.focus_handle(cx); + move |cx| { + cx.new_view(|cx| { + let mut tooltip = Tooltip::new("Previous Alternative").key_binding( + KeyBinding::for_action_in( + &CyclePreviousInlineAssist, + &focus_handle, + cx, + ), + ); + if !disabled && current_index != 0 { + tooltip = tooltip.meta(prev_model_name.clone()); + } + tooltip + }) + .into() + } + }) + .on_click(cx.listener(|this, _, cx| { + this.cycle_prev(&CyclePreviousInlineAssist, cx); + })), + ) + .child( + Label::new(format!( + "{}/{}", + codegen.active_alternative + 1, + codegen.alternative_count(cx) + )) + .size(LabelSize::Small) + .color(if disabled { + Color::Disabled + } else { + Color::Muted + }), + ) + .child( + IconButton::new("next", IconName::ChevronRight) + .icon_color(Color::Muted) + .disabled(disabled || current_index == total_models - 1) + .shape(IconButtonShape::Square) + .tooltip({ + let focus_handle = self.editor.focus_handle(cx); + move |cx| { + cx.new_view(|cx| { + let mut tooltip = Tooltip::new("Next Alternative").key_binding( + KeyBinding::for_action_in( + &CycleNextInlineAssist, + &focus_handle, + cx, + ), + ); + if !disabled && current_index != total_models - 1 { + tooltip = tooltip.meta(next_model_name.clone()); + } + tooltip + }) + .into() + } + }) + .on_click( + cx.listener(|this, _, cx| this.cycle_next(&CycleNextInlineAssist, cx)), + ), + ) + .into_any_element() + } + + fn render_rate_limit_notice(&self, cx: &mut ViewContext) -> impl IntoElement { + Popover::new().child( + v_flex() + .occlude() + .p_2() + .child( + Label::new("Out of Tokens") + .size(LabelSize::Small) + .weight(FontWeight::BOLD), + ) + .child(Label::new( + "Try Zed Pro for higher limits, a wider range of models, and more.", + )) + .child( + h_flex() + .justify_between() + .child(CheckboxWithLabel::new( + "dont-show-again", + Label::new("Don't show again"), + if dismissed_rate_limit_notice() { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + |selection, cx| { + let is_dismissed = match selection { + ui::ToggleState::Unselected => false, + ui::ToggleState::Indeterminate => return, + ui::ToggleState::Selected => true, + }; + + set_rate_limit_notice_dismissed(is_dismissed, cx) + }, + )) + .child( + h_flex() + .gap_2() + .child( + Button::new("dismiss", "Dismiss") + .style(ButtonStyle::Transparent) + .on_click(cx.listener(Self::toggle_rate_limit_notice)), + ) + .child(Button::new("more-info", "More Info").on_click( + |_event, cx| { + cx.dispatch_action(Box::new( + zed_actions::OpenAccountSettings, + )) + }, + )), + ), + ), + ) + } + + fn render_editor(&mut self, cx: &mut ViewContext) -> AnyElement { + let font_size = TextSize::Default.rems(cx); + let line_height = font_size.to_pixels(cx.rem_size()) * 1.3; + + v_flex() + .key_context("MessageEditor") + .size_full() + .gap_2() + .p_2() + .bg(cx.theme().colors().editor_background) + .child({ + let settings = ThemeSettings::get_global(cx); + let text_style = TextStyle { + color: cx.theme().colors().editor_foreground, + font_family: settings.ui_font.family.clone(), + font_features: settings.ui_font.features.clone(), + font_size: font_size.into(), + font_weight: settings.ui_font.weight, + line_height: line_height.into(), + ..Default::default() + }; + + EditorElement::new( + &self.editor, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + ..Default::default() + }, + ) + }) + .into_any_element() + } +} + +pub enum PromptEditorMode { + Buffer { + id: InlineAssistId, + codegen: Model, + gutter_dimensions: Arc>, + }, + Terminal { + id: TerminalInlineAssistId, + codegen: Model, + height_in_lines: u8, + }, +} + +pub enum PromptEditorEvent { + StartRequested, + StopRequested, + ConfirmRequested { execute: bool }, + CancelRequested, + DismissRequested, + Resized { height_in_lines: u8 }, +} + +#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] +pub struct InlineAssistId(pub usize); + +impl InlineAssistId { + pub fn post_inc(&mut self) -> InlineAssistId { + let id = *self; + self.0 += 1; + id + } +} + +impl PromptEditor { + #[allow(clippy::too_many_arguments)] + pub fn new_buffer( + id: InlineAssistId, + gutter_dimensions: Arc>, + prompt_history: VecDeque, + prompt_buffer: Model, + codegen: Model, + fs: Arc, + context_store: Model, + workspace: WeakView, + thread_store: Option>, + cx: &mut ViewContext>, + ) -> PromptEditor { + let codegen_subscription = cx.observe(&codegen, Self::handle_codegen_changed); + let mode = PromptEditorMode::Buffer { + id, + codegen, + gutter_dimensions, + }; + + let prompt_editor = cx.new_view(|cx| { + let mut editor = Editor::new( + EditorMode::AutoHeight { + max_lines: Self::MAX_LINES as usize, + }, + prompt_buffer, + None, + false, + cx, + ); + editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx); + // Since the prompt editors for all inline assistants are linked, + // always show the cursor (even when it isn't focused) because + // typing in one will make what you typed appear in all of them. + editor.set_show_cursor_when_unfocused(true, cx); + editor.set_placeholder_text(Self::placeholder_text(&mode, cx), cx); + editor + }); + let context_picker_menu_handle = PopoverMenuHandle::default(); + + let mut this: PromptEditor = PromptEditor { + editor: prompt_editor.clone(), + context_strip: cx.new_view(|cx| { + ContextStrip::new( + context_store, + workspace.clone(), + thread_store.clone(), + prompt_editor.focus_handle(cx), + context_picker_menu_handle.clone(), + cx, + ) + }), + context_picker_menu_handle, + language_model_selector: cx.new_view(|cx| { + let fs = fs.clone(); + LanguageModelSelector::new( + move |model, cx| { + update_settings_file::( + fs.clone(), + cx, + move |settings, _| settings.set_model(model.clone()), + ); + }, + cx, + ) + }), + edited_since_done: false, + prompt_history, + prompt_history_ix: None, + pending_prompt: String::new(), + _codegen_subscription: codegen_subscription, + editor_subscriptions: Vec::new(), + show_rate_limit_notice: false, + mode, + _phantom: Default::default(), + }; + + this.subscribe_to_editor(cx); + this + } + + fn handle_codegen_changed( + &mut self, + _: Model, + cx: &mut ViewContext>, + ) { + match self.codegen_status(cx) { + CodegenStatus::Idle => { + self.editor + .update(cx, |editor, _| editor.set_read_only(false)); + } + CodegenStatus::Pending => { + self.editor + .update(cx, |editor, _| editor.set_read_only(true)); + } + CodegenStatus::Done => { + self.edited_since_done = false; + self.editor + .update(cx, |editor, _| editor.set_read_only(false)); + } + CodegenStatus::Error(error) => { + if cx.has_flag::() + && error.error_code() == proto::ErrorCode::RateLimitExceeded + && !dismissed_rate_limit_notice() + { + self.show_rate_limit_notice = true; + cx.notify(); + } + + self.edited_since_done = false; + self.editor + .update(cx, |editor, _| editor.set_read_only(false)); + } + } + } + + pub fn id(&self) -> InlineAssistId { + match &self.mode { + PromptEditorMode::Buffer { id, .. } => *id, + PromptEditorMode::Terminal { .. } => unreachable!(), + } + } + + pub fn codegen(&self) -> &Model { + match &self.mode { + PromptEditorMode::Buffer { codegen, .. } => codegen, + PromptEditorMode::Terminal { .. } => unreachable!(), + } + } + + pub fn gutter_dimensions(&self) -> &Arc> { + match &self.mode { + PromptEditorMode::Buffer { + gutter_dimensions, .. + } => gutter_dimensions, + PromptEditorMode::Terminal { .. } => unreachable!(), + } + } +} + +#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] +pub struct TerminalInlineAssistId(pub usize); + +impl TerminalInlineAssistId { + pub fn post_inc(&mut self) -> TerminalInlineAssistId { + let id = *self; + self.0 += 1; + id + } +} + +impl PromptEditor { + #[allow(clippy::too_many_arguments)] + pub fn new_terminal( + id: TerminalInlineAssistId, + prompt_history: VecDeque, + prompt_buffer: Model, + codegen: Model, + fs: Arc, + context_store: Model, + workspace: WeakView, + thread_store: Option>, + cx: &mut ViewContext, + ) -> Self { + let codegen_subscription = cx.observe(&codegen, Self::handle_codegen_changed); + let mode = PromptEditorMode::Terminal { + id, + codegen, + height_in_lines: 1, + }; + + let prompt_editor = cx.new_view(|cx| { + let mut editor = Editor::new( + EditorMode::AutoHeight { + max_lines: Self::MAX_LINES as usize, + }, + prompt_buffer, + None, + false, + cx, + ); + editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx); + editor.set_placeholder_text(Self::placeholder_text(&mode, cx), cx); + editor + }); + let context_picker_menu_handle = PopoverMenuHandle::default(); + + let mut this = Self { + editor: prompt_editor.clone(), + context_strip: cx.new_view(|cx| { + ContextStrip::new( + context_store, + workspace.clone(), + thread_store.clone(), + prompt_editor.focus_handle(cx), + context_picker_menu_handle.clone(), + cx, + ) + }), + context_picker_menu_handle, + language_model_selector: cx.new_view(|cx| { + let fs = fs.clone(); + LanguageModelSelector::new( + move |model, cx| { + update_settings_file::( + fs.clone(), + cx, + move |settings, _| settings.set_model(model.clone()), + ); + }, + cx, + ) + }), + edited_since_done: false, + prompt_history, + prompt_history_ix: None, + pending_prompt: String::new(), + _codegen_subscription: codegen_subscription, + editor_subscriptions: Vec::new(), + mode, + show_rate_limit_notice: false, + _phantom: Default::default(), + }; + this.count_lines(cx); + this.subscribe_to_editor(cx); + this + } + + fn count_lines(&mut self, cx: &mut ViewContext) { + let height_in_lines = cmp::max( + 2, // Make the editor at least two lines tall, to account for padding and buttons. + cmp::min( + self.editor + .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1), + Self::MAX_LINES as u32, + ), + ) as u8; + + match &mut self.mode { + PromptEditorMode::Terminal { + height_in_lines: current_height, + .. + } => { + if height_in_lines != *current_height { + *current_height = height_in_lines; + cx.emit(PromptEditorEvent::Resized { height_in_lines }); + } + } + PromptEditorMode::Buffer { .. } => unreachable!(), + } + } + + fn handle_codegen_changed(&mut self, _: Model, cx: &mut ViewContext) { + match &self.codegen().read(cx).status { + CodegenStatus::Idle => { + self.editor + .update(cx, |editor, _| editor.set_read_only(false)); + } + CodegenStatus::Pending => { + self.editor + .update(cx, |editor, _| editor.set_read_only(true)); + } + CodegenStatus::Done | CodegenStatus::Error(_) => { + self.edited_since_done = false; + self.editor + .update(cx, |editor, _| editor.set_read_only(false)); + } + } + } + + pub fn codegen(&self) -> &Model { + match &self.mode { + PromptEditorMode::Buffer { .. } => unreachable!(), + PromptEditorMode::Terminal { codegen, .. } => codegen, + } + } + + pub fn id(&self) -> TerminalInlineAssistId { + match &self.mode { + PromptEditorMode::Buffer { .. } => unreachable!(), + PromptEditorMode::Terminal { id, .. } => *id, + } + } +} + +const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice"; + +fn dismissed_rate_limit_notice() -> bool { + db::kvp::KEY_VALUE_STORE + .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY) + .log_err() + .map_or(false, |s| s.is_some()) +} + +fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) { + db::write_and_log(cx, move || async move { + if is_dismissed { + db::kvp::KEY_VALUE_STORE + .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into()) + .await + } else { + db::kvp::KEY_VALUE_STORE + .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into()) + .await + } + }) +} pub enum CodegenStatus { Idle, @@ -29,163 +1092,36 @@ impl Into for &CodegenStatus { } #[derive(Copy, Clone)] -pub enum PromptMode { - Generate { supports_execute: bool }, +pub enum GenerationMode { + Generate, Transform, } -impl PromptMode { +impl GenerationMode { fn start_label(self) -> &'static str { match self { - PromptMode::Generate { .. } => "Generate", - PromptMode::Transform => "Transform", + GenerationMode::Generate { .. } => "Generate", + GenerationMode::Transform => "Transform", } } fn tooltip_interrupt(self) -> &'static str { match self { - PromptMode::Generate { .. } => "Interrupt Generation", - PromptMode::Transform => "Interrupt Transform", + GenerationMode::Generate { .. } => "Interrupt Generation", + GenerationMode::Transform => "Interrupt Transform", } } fn tooltip_restart(self) -> &'static str { match self { - PromptMode::Generate { .. } => "Restart Generation", - PromptMode::Transform => "Restart Transform", + GenerationMode::Generate { .. } => "Restart Generation", + GenerationMode::Transform => "Restart Transform", } } fn tooltip_accept(self) -> &'static str { match self { - PromptMode::Generate { .. } => "Accept Generation", - PromptMode::Transform => "Accept Transform", - } - } -} - -pub fn render_cancel_button>( - cancel_button_state: CancelButtonState, - edited_since_done: bool, - mode: PromptMode, - cx: &mut ViewContext, -) -> Vec { - match cancel_button_state { - CancelButtonState::Idle => { - vec![ - IconButton::new("cancel", IconName::Close) - .icon_color(Color::Muted) - .shape(IconButtonShape::Square) - .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx)) - .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested))) - .into_any_element(), - Button::new("start", mode.start_label()) - .icon(IconName::Return) - .icon_color(Color::Muted) - .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested))) - .into_any_element(), - ] - } - CancelButtonState::Pending => vec![ - IconButton::new("cancel", IconName::Close) - .icon_color(Color::Muted) - .shape(IconButtonShape::Square) - .tooltip(|cx| Tooltip::text("Cancel Assist", cx)) - .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested))) - .into_any_element(), - IconButton::new("stop", IconName::Stop) - .icon_color(Color::Error) - .shape(IconButtonShape::Square) - .tooltip(move |cx| { - Tooltip::with_meta( - mode.tooltip_interrupt(), - Some(&menu::Cancel), - "Changes won't be discarded", - cx, - ) - }) - .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested))) - .into_any_element(), - ], - CancelButtonState::Done | CancelButtonState::Error => { - let cancel = IconButton::new("cancel", IconName::Close) - .icon_color(Color::Muted) - .shape(IconButtonShape::Square) - .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx)) - .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested))) - .into_any_element(); - - let has_error = matches!(cancel_button_state, CancelButtonState::Error); - if has_error || edited_since_done { - vec![ - cancel, - IconButton::new("restart", IconName::RotateCw) - .icon_color(Color::Info) - .shape(IconButtonShape::Square) - .tooltip(move |cx| { - Tooltip::with_meta( - mode.tooltip_restart(), - Some(&menu::Confirm), - "Changes will be discarded", - cx, - ) - }) - .on_click(cx.listener(|_, _, cx| { - cx.emit(PromptEditorEvent::StartRequested); - })) - .into_any_element(), - ] - } else { - let mut buttons = vec![ - cancel, - IconButton::new("accept", IconName::Check) - .icon_color(Color::Info) - .shape(IconButtonShape::Square) - .tooltip(move |cx| { - Tooltip::for_action(mode.tooltip_accept(), &menu::Confirm, cx) - }) - .on_click(cx.listener(|_, _, cx| { - cx.emit(PromptEditorEvent::ConfirmRequested { execute: false }); - })) - .into_any_element(), - ]; - - match mode { - PromptMode::Generate { supports_execute } => { - if supports_execute { - buttons.push( - IconButton::new("confirm", IconName::Play) - .icon_color(Color::Info) - .shape(IconButtonShape::Square) - .tooltip(|cx| { - Tooltip::for_action( - "Execute Generated Command", - &menu::SecondaryConfirm, - cx, - ) - }) - .on_click(cx.listener(|_, _, cx| { - cx.emit(PromptEditorEvent::ConfirmRequested { - execute: true, - }); - })) - .into_any_element(), - ) - } - } - PromptMode::Transform => {} - } - - buttons - } + GenerationMode::Generate { .. } => "Accept Generation", + GenerationMode::Transform => "Accept Transform", } } } - -pub enum PromptEditorEvent { - StartRequested, - StopRequested, - ConfirmRequested { execute: bool }, - CancelRequested, - DismissRequested, - Resized { height_in_lines: u8 }, -} diff --git a/crates/assistant2/src/terminal_codegen.rs b/crates/assistant2/src/terminal_codegen.rs new file mode 100644 index 00000000000000..97cb18e4400bbb --- /dev/null +++ b/crates/assistant2/src/terminal_codegen.rs @@ -0,0 +1,192 @@ +use crate::inline_prompt_editor::CodegenStatus; +use client::telemetry::Telemetry; +use futures::{channel::mpsc, SinkExt, StreamExt}; +use gpui::{AppContext, EventEmitter, Model, ModelContext, Task}; +use language_model::{LanguageModelRegistry, LanguageModelRequest}; +use language_models::report_assistant_event; +use std::{sync::Arc, time::Instant}; +use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase}; +use terminal::Terminal; + +pub struct TerminalCodegen { + pub status: CodegenStatus, + pub telemetry: Option>, + terminal: Model, + generation: Task<()>, + pub message_id: Option, + transaction: Option, +} + +impl EventEmitter for TerminalCodegen {} + +impl TerminalCodegen { + pub fn new(terminal: Model, telemetry: Option>) -> Self { + Self { + terminal, + telemetry, + status: CodegenStatus::Idle, + generation: Task::ready(()), + message_id: None, + transaction: None, + } + } + + pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext) { + let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else { + return; + }; + + let model_api_key = model.api_key(cx); + let http_client = cx.http_client(); + let telemetry = self.telemetry.clone(); + self.status = CodegenStatus::Pending; + self.transaction = Some(TerminalTransaction::start(self.terminal.clone())); + self.generation = cx.spawn(|this, mut cx| async move { + let model_telemetry_id = model.telemetry_id(); + let model_provider_id = model.provider_id(); + let response = model.stream_completion_text(prompt, &cx).await; + let generate = async { + let message_id = response + .as_ref() + .ok() + .and_then(|response| response.message_id.clone()); + + let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); + + let task = cx.background_executor().spawn({ + let message_id = message_id.clone(); + let executor = cx.background_executor().clone(); + async move { + let mut response_latency = None; + let request_start = Instant::now(); + let task = async { + let mut chunks = response?.stream; + while let Some(chunk) = chunks.next().await { + if response_latency.is_none() { + response_latency = Some(request_start.elapsed()); + } + let chunk = chunk?; + hunks_tx.send(chunk).await?; + } + + anyhow::Ok(()) + }; + + let result = task.await; + + let error_message = result.as_ref().err().map(|error| error.to_string()); + report_assistant_event( + AssistantEvent { + conversation_id: None, + kind: AssistantKind::InlineTerminal, + message_id, + phase: AssistantPhase::Response, + model: model_telemetry_id, + model_provider: model_provider_id.to_string(), + response_latency, + error_message, + language_name: None, + }, + telemetry, + http_client, + model_api_key, + &executor, + ); + + result?; + anyhow::Ok(()) + } + }); + + this.update(&mut cx, |this, _| { + this.message_id = message_id; + })?; + + while let Some(hunk) = hunks_rx.next().await { + this.update(&mut cx, |this, cx| { + if let Some(transaction) = &mut this.transaction { + transaction.push(hunk, cx); + cx.notify(); + } + })?; + } + + task.await?; + anyhow::Ok(()) + }; + + let result = generate.await; + + this.update(&mut cx, |this, cx| { + if let Err(error) = result { + this.status = CodegenStatus::Error(error); + } else { + this.status = CodegenStatus::Done; + } + cx.emit(CodegenEvent::Finished); + cx.notify(); + }) + .ok(); + }); + cx.notify(); + } + + pub fn stop(&mut self, cx: &mut ModelContext) { + self.status = CodegenStatus::Done; + self.generation = Task::ready(()); + cx.emit(CodegenEvent::Finished); + cx.notify(); + } + + pub fn complete(&mut self, cx: &mut ModelContext) { + if let Some(transaction) = self.transaction.take() { + transaction.complete(cx); + } + } + + pub fn undo(&mut self, cx: &mut ModelContext) { + if let Some(transaction) = self.transaction.take() { + transaction.undo(cx); + } + } +} + +#[derive(Copy, Clone, Debug)] +pub enum CodegenEvent { + Finished, +} + +pub const CLEAR_INPUT: &str = "\x15"; +const CARRIAGE_RETURN: &str = "\x0d"; + +struct TerminalTransaction { + terminal: Model, +} + +impl TerminalTransaction { + pub fn start(terminal: Model) -> Self { + Self { terminal } + } + + pub fn push(&mut self, hunk: String, cx: &mut AppContext) { + // Ensure that the assistant cannot accidentally execute commands that are streamed into the terminal + let input = Self::sanitize_input(hunk); + self.terminal + .update(cx, |terminal, _| terminal.input(input)); + } + + pub fn undo(&self, cx: &mut AppContext) { + self.terminal + .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string())); + } + + pub fn complete(&self, cx: &mut AppContext) { + self.terminal.update(cx, |terminal, _| { + terminal.input(CARRIAGE_RETURN.to_string()) + }); + } + + fn sanitize_input(input: String) -> String { + input.replace(['\r', '\n'], "") + } +} diff --git a/crates/assistant2/src/terminal_inline_assistant.rs b/crates/assistant2/src/terminal_inline_assistant.rs index c72c7d607120b4..ad8153293a7be0 100644 --- a/crates/assistant2/src/terminal_inline_assistant.rs +++ b/crates/assistant2/src/terminal_inline_assistant.rs @@ -1,38 +1,29 @@ use crate::context::attach_context_to_message; -use crate::context_picker::ContextPicker; use crate::context_store::ContextStore; -use crate::context_strip::ContextStrip; -use crate::inline_prompt_editor::{CodegenStatus, PromptEditorEvent, PromptMode}; +use crate::inline_prompt_editor::{ + CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId, +}; use crate::prompts::PromptBuilder; +use crate::terminal_codegen::{CodegenEvent, TerminalCodegen, CLEAR_INPUT}; use crate::thread_store::ThreadStore; -use crate::ToggleContextPicker; -use crate::{assistant_settings::AssistantSettings, inline_prompt_editor::render_cancel_button}; use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; use collections::{HashMap, VecDeque}; -use editor::{ - actions::{MoveDown, MoveUp, SelectAll}, - Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, MultiBuffer, -}; +use editor::{actions::SelectAll, MultiBuffer}; use fs::Fs; -use futures::{channel::mpsc, SinkExt, StreamExt}; use gpui::{ - AppContext, Context, EventEmitter, FocusHandle, FocusableView, Global, Model, ModelContext, - Subscription, Task, TextStyle, UpdateGlobal, View, WeakModel, WeakView, + AppContext, Context, FocusableView, Global, Model, Subscription, UpdateGlobal, View, WeakModel, + WeakView, }; use language::Buffer; use language_model::{ LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, }; -use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu}; use language_models::report_assistant_event; -use settings::{update_settings_file, Settings}; -use std::{cmp, sync::Arc, time::Instant}; +use std::sync::Arc; use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase}; -use terminal::Terminal; use terminal_view::TerminalView; -use theme::ThemeSettings; -use ui::{prelude::*, text_for_action, IconButtonShape, PopoverMenuHandle, Tooltip}; +use ui::prelude::*; use util::ResultExt; use workspace::{notifications::NotificationId, Toast, Workspace}; @@ -48,17 +39,6 @@ pub fn init( const DEFAULT_CONTEXT_LINES: usize = 50; const PROMPT_HISTORY_MAX_LEN: usize = 20; -#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] -struct TerminalInlineAssistId(usize); - -impl TerminalInlineAssistId { - fn post_inc(&mut self) -> TerminalInlineAssistId { - let id = *self; - self.0 += 1; - id - } -} - pub struct TerminalInlineAssistant { next_assist_id: TerminalInlineAssistId, assists: HashMap, @@ -99,10 +79,10 @@ impl TerminalInlineAssistant { MultiBuffer::singleton(cx.new_model(|cx| Buffer::local(String::new(), cx)), cx) }); let context_store = cx.new_model(|_cx| ContextStore::new()); - let codegen = cx.new_model(|_| Codegen::new(terminal, self.telemetry.clone())); + let codegen = cx.new_model(|_| TerminalCodegen::new(terminal, self.telemetry.clone())); let prompt_editor = cx.new_view(|cx| { - PromptEditor::new( + PromptEditor::new_terminal( assist_id, self.prompt_history.clone(), prompt_buffer.clone(), @@ -151,11 +131,11 @@ impl TerminalInlineAssistant { fn handle_prompt_editor_event( &mut self, - prompt_editor: View, + prompt_editor: View>, event: &PromptEditorEvent, cx: &mut WindowContext, ) { - let assist_id = prompt_editor.read(cx).id; + let assist_id = prompt_editor.read(cx).id(); match event { PromptEditorEvent::StartRequested => { self.start_assist(assist_id, cx); @@ -381,8 +361,8 @@ impl TerminalInlineAssistant { struct TerminalInlineAssist { terminal: WeakView, - prompt_editor: Option>, - codegen: Model, + prompt_editor: Option>>, + codegen: Model, workspace: WeakView, context_store: Model, _subscriptions: Vec, @@ -392,12 +372,12 @@ impl TerminalInlineAssist { pub fn new( assist_id: TerminalInlineAssistId, terminal: &View, - prompt_editor: View, + prompt_editor: View>, workspace: WeakView, context_store: Model, cx: &mut WindowContext, ) -> Self { - let codegen = prompt_editor.read(cx).codegen.clone(); + let codegen = prompt_editor.read(cx).codegen().clone(); Self { terminal: terminal.downgrade(), prompt_editor: Some(prompt_editor.clone()), @@ -448,556 +428,3 @@ impl TerminalInlineAssist { } } } - -struct PromptEditor { - id: TerminalInlineAssistId, - height_in_lines: u8, - editor: View, - context_strip: View, - context_picker_menu_handle: PopoverMenuHandle, - language_model_selector: View, - edited_since_done: bool, - prompt_history: VecDeque, - prompt_history_ix: Option, - pending_prompt: String, - codegen: Model, - _codegen_subscription: Subscription, - editor_subscriptions: Vec, -} - -impl EventEmitter for PromptEditor {} - -impl Render for PromptEditor { - fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - let mut buttons = Vec::new(); - - buttons.extend(render_cancel_button( - (&self.codegen.read(cx).status).into(), - self.edited_since_done, - PromptMode::Generate { - supports_execute: true, - }, - cx, - )); - - v_flex() - .border_y_1() - .border_color(cx.theme().status().info_border) - .py_2() - .size_full() - .child( - h_flex() - .key_context("PromptEditor") - .bg(cx.theme().colors().editor_background) - .on_action(cx.listener(Self::toggle_context_picker)) - .on_action(cx.listener(Self::confirm)) - .on_action(cx.listener(Self::secondary_confirm)) - .on_action(cx.listener(Self::cancel)) - .on_action(cx.listener(Self::move_up)) - .on_action(cx.listener(Self::move_down)) - .child( - h_flex() - .w_12() - .justify_center() - .gap_2() - .child(LanguageModelSelectorPopoverMenu::new( - self.language_model_selector.clone(), - IconButton::new("context", IconName::SettingsAlt) - .shape(IconButtonShape::Square) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .tooltip(move |cx| { - Tooltip::with_meta( - format!( - "Using {}", - LanguageModelRegistry::read_global(cx) - .active_model() - .map(|model| model.name().0) - .unwrap_or_else(|| "No model selected".into()), - ), - None, - "Change Model", - cx, - ) - }), - )) - .children( - if let CodegenStatus::Error(error) = &self.codegen.read(cx).status { - let error_message = SharedString::from(error.to_string()); - Some( - div() - .id("error") - .tooltip(move |cx| { - Tooltip::text(error_message.clone(), cx) - }) - .child( - Icon::new(IconName::XCircle) - .size(IconSize::Small) - .color(Color::Error), - ), - ) - } else { - None - }, - ), - ) - .child(div().flex_1().child(self.render_prompt_editor(cx))) - .child(h_flex().gap_1().pr_4().children(buttons)), - ) - .child(h_flex().child(self.context_strip.clone())) - } -} - -impl FocusableView for PromptEditor { - fn focus_handle(&self, cx: &AppContext) -> FocusHandle { - self.editor.focus_handle(cx) - } -} - -impl PromptEditor { - const MAX_LINES: u8 = 8; - - #[allow(clippy::too_many_arguments)] - fn new( - id: TerminalInlineAssistId, - prompt_history: VecDeque, - prompt_buffer: Model, - codegen: Model, - fs: Arc, - context_store: Model, - workspace: WeakView, - thread_store: Option>, - cx: &mut ViewContext, - ) -> Self { - let prompt_editor = cx.new_view(|cx| { - let mut editor = Editor::new( - EditorMode::AutoHeight { - max_lines: Self::MAX_LINES as usize, - }, - prompt_buffer, - None, - false, - cx, - ); - editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx); - editor.set_placeholder_text(Self::placeholder_text(cx), cx); - editor - }); - let context_picker_menu_handle = PopoverMenuHandle::default(); - - let mut this = Self { - id, - height_in_lines: 1, - editor: prompt_editor.clone(), - context_strip: cx.new_view(|cx| { - ContextStrip::new( - context_store, - workspace.clone(), - thread_store.clone(), - prompt_editor.focus_handle(cx), - context_picker_menu_handle.clone(), - cx, - ) - }), - context_picker_menu_handle, - language_model_selector: cx.new_view(|cx| { - let fs = fs.clone(); - LanguageModelSelector::new( - move |model, cx| { - update_settings_file::( - fs.clone(), - cx, - move |settings, _| settings.set_model(model.clone()), - ); - }, - cx, - ) - }), - edited_since_done: false, - prompt_history, - prompt_history_ix: None, - pending_prompt: String::new(), - _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed), - editor_subscriptions: Vec::new(), - codegen, - }; - this.count_lines(cx); - this.subscribe_to_editor(cx); - this - } - - fn placeholder_text(cx: &WindowContext) -> String { - let context_keybinding = text_for_action(&crate::ToggleFocus, cx) - .map(|keybinding| format!(" • {keybinding} for context")) - .unwrap_or_default(); - - format!("Generate…{context_keybinding} ↓↑ for history") - } - - fn subscribe_to_editor(&mut self, cx: &mut ViewContext) { - self.editor_subscriptions.clear(); - self.editor_subscriptions - .push(cx.observe(&self.editor, Self::handle_prompt_editor_changed)); - self.editor_subscriptions - .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events)); - } - - fn prompt(&self, cx: &AppContext) -> String { - self.editor.read(cx).text(cx) - } - - fn count_lines(&mut self, cx: &mut ViewContext) { - let height_in_lines = cmp::max( - 2, // Make the editor at least two lines tall, to account for padding and buttons. - cmp::min( - self.editor - .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1), - Self::MAX_LINES as u32, - ), - ) as u8; - - if height_in_lines != self.height_in_lines { - self.height_in_lines = height_in_lines; - cx.emit(PromptEditorEvent::Resized { height_in_lines }); - } - } - - fn handle_prompt_editor_changed(&mut self, _: View, cx: &mut ViewContext) { - self.count_lines(cx); - } - - fn handle_prompt_editor_events( - &mut self, - _: View, - event: &EditorEvent, - cx: &mut ViewContext, - ) { - match event { - EditorEvent::Edited { .. } => { - let prompt = self.editor.read(cx).text(cx); - if self - .prompt_history_ix - .map_or(true, |ix| self.prompt_history[ix] != prompt) - { - self.prompt_history_ix.take(); - self.pending_prompt = prompt; - } - - self.edited_since_done = true; - cx.notify(); - } - _ => {} - } - } - - fn handle_codegen_changed(&mut self, _: Model, cx: &mut ViewContext) { - match &self.codegen.read(cx).status { - CodegenStatus::Idle => { - self.editor - .update(cx, |editor, _| editor.set_read_only(false)); - } - CodegenStatus::Pending => { - self.editor - .update(cx, |editor, _| editor.set_read_only(true)); - } - CodegenStatus::Done | CodegenStatus::Error(_) => { - self.edited_since_done = false; - self.editor - .update(cx, |editor, _| editor.set_read_only(false)); - } - } - } - - fn toggle_context_picker(&mut self, _: &ToggleContextPicker, cx: &mut ViewContext) { - self.context_picker_menu_handle.toggle(cx); - } - - fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext) { - match &self.codegen.read(cx).status { - CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => { - cx.emit(PromptEditorEvent::CancelRequested); - } - CodegenStatus::Pending => { - cx.emit(PromptEditorEvent::StopRequested); - } - } - } - - fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { - match &self.codegen.read(cx).status { - CodegenStatus::Idle => { - if !self.editor.read(cx).text(cx).trim().is_empty() { - cx.emit(PromptEditorEvent::StartRequested); - } - } - CodegenStatus::Pending => { - cx.emit(PromptEditorEvent::DismissRequested); - } - CodegenStatus::Done => { - if self.edited_since_done { - cx.emit(PromptEditorEvent::StartRequested); - } else { - cx.emit(PromptEditorEvent::ConfirmRequested { execute: false }); - } - } - CodegenStatus::Error(_) => { - cx.emit(PromptEditorEvent::StartRequested); - } - } - } - - fn secondary_confirm(&mut self, _: &menu::SecondaryConfirm, cx: &mut ViewContext) { - if matches!(self.codegen.read(cx).status, CodegenStatus::Done) { - cx.emit(PromptEditorEvent::ConfirmRequested { execute: true }); - } - } - - fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext) { - if let Some(ix) = self.prompt_history_ix { - if ix > 0 { - self.prompt_history_ix = Some(ix - 1); - let prompt = self.prompt_history[ix - 1].as_str(); - self.editor.update(cx, |editor, cx| { - editor.set_text(prompt, cx); - editor.move_to_beginning(&Default::default(), cx); - }); - } - } else if !self.prompt_history.is_empty() { - self.prompt_history_ix = Some(self.prompt_history.len() - 1); - let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str(); - self.editor.update(cx, |editor, cx| { - editor.set_text(prompt, cx); - editor.move_to_beginning(&Default::default(), cx); - }); - } - } - - fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext) { - if let Some(ix) = self.prompt_history_ix { - if ix < self.prompt_history.len() - 1 { - self.prompt_history_ix = Some(ix + 1); - let prompt = self.prompt_history[ix + 1].as_str(); - self.editor.update(cx, |editor, cx| { - editor.set_text(prompt, cx); - editor.move_to_end(&Default::default(), cx) - }); - } else { - self.prompt_history_ix = None; - let prompt = self.pending_prompt.as_str(); - self.editor.update(cx, |editor, cx| { - editor.set_text(prompt, cx); - editor.move_to_end(&Default::default(), cx) - }); - } - } - } - - fn render_prompt_editor(&self, cx: &mut ViewContext) -> impl IntoElement { - let settings = ThemeSettings::get_global(cx); - let text_style = TextStyle { - color: if self.editor.read(cx).read_only(cx) { - cx.theme().colors().text_disabled - } else { - cx.theme().colors().text - }, - font_family: settings.buffer_font.family.clone(), - font_fallbacks: settings.buffer_font.fallbacks.clone(), - font_size: settings.buffer_font_size.into(), - font_weight: settings.buffer_font.weight, - line_height: relative(settings.buffer_line_height.value()), - ..Default::default() - }; - EditorElement::new( - &self.editor, - EditorStyle { - background: cx.theme().colors().editor_background, - local_player: cx.theme().players().local(), - text: text_style, - ..Default::default() - }, - ) - } -} - -#[derive(Debug)] -pub enum CodegenEvent { - Finished, -} - -impl EventEmitter for Codegen {} - -const CLEAR_INPUT: &str = "\x15"; -const CARRIAGE_RETURN: &str = "\x0d"; - -struct TerminalTransaction { - terminal: Model, -} - -impl TerminalTransaction { - pub fn start(terminal: Model) -> Self { - Self { terminal } - } - - pub fn push(&mut self, hunk: String, cx: &mut AppContext) { - // Ensure that the assistant cannot accidentally execute commands that are streamed into the terminal - let input = Self::sanitize_input(hunk); - self.terminal - .update(cx, |terminal, _| terminal.input(input)); - } - - pub fn undo(&self, cx: &mut AppContext) { - self.terminal - .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string())); - } - - pub fn complete(&self, cx: &mut AppContext) { - self.terminal.update(cx, |terminal, _| { - terminal.input(CARRIAGE_RETURN.to_string()) - }); - } - - fn sanitize_input(input: String) -> String { - input.replace(['\r', '\n'], "") - } -} - -pub struct Codegen { - status: CodegenStatus, - telemetry: Option>, - terminal: Model, - generation: Task<()>, - message_id: Option, - transaction: Option, -} - -impl Codegen { - pub fn new(terminal: Model, telemetry: Option>) -> Self { - Self { - terminal, - telemetry, - status: CodegenStatus::Idle, - generation: Task::ready(()), - message_id: None, - transaction: None, - } - } - - pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext) { - let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else { - return; - }; - - let model_api_key = model.api_key(cx); - let http_client = cx.http_client(); - let telemetry = self.telemetry.clone(); - self.status = CodegenStatus::Pending; - self.transaction = Some(TerminalTransaction::start(self.terminal.clone())); - self.generation = cx.spawn(|this, mut cx| async move { - let model_telemetry_id = model.telemetry_id(); - let model_provider_id = model.provider_id(); - let response = model.stream_completion_text(prompt, &cx).await; - let generate = async { - let message_id = response - .as_ref() - .ok() - .and_then(|response| response.message_id.clone()); - - let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); - - let task = cx.background_executor().spawn({ - let message_id = message_id.clone(); - let executor = cx.background_executor().clone(); - async move { - let mut response_latency = None; - let request_start = Instant::now(); - let task = async { - let mut chunks = response?.stream; - while let Some(chunk) = chunks.next().await { - if response_latency.is_none() { - response_latency = Some(request_start.elapsed()); - } - let chunk = chunk?; - hunks_tx.send(chunk).await?; - } - - anyhow::Ok(()) - }; - - let result = task.await; - - let error_message = result.as_ref().err().map(|error| error.to_string()); - report_assistant_event( - AssistantEvent { - conversation_id: None, - kind: AssistantKind::InlineTerminal, - message_id, - phase: AssistantPhase::Response, - model: model_telemetry_id, - model_provider: model_provider_id.to_string(), - response_latency, - error_message, - language_name: None, - }, - telemetry, - http_client, - model_api_key, - &executor, - ); - - result?; - anyhow::Ok(()) - } - }); - - this.update(&mut cx, |this, _| { - this.message_id = message_id; - })?; - - while let Some(hunk) = hunks_rx.next().await { - this.update(&mut cx, |this, cx| { - if let Some(transaction) = &mut this.transaction { - transaction.push(hunk, cx); - cx.notify(); - } - })?; - } - - task.await?; - anyhow::Ok(()) - }; - - let result = generate.await; - - this.update(&mut cx, |this, cx| { - if let Err(error) = result { - this.status = CodegenStatus::Error(error); - } else { - this.status = CodegenStatus::Done; - } - cx.emit(CodegenEvent::Finished); - cx.notify(); - }) - .ok(); - }); - cx.notify(); - } - - pub fn stop(&mut self, cx: &mut ModelContext) { - self.status = CodegenStatus::Done; - self.generation = Task::ready(()); - cx.emit(CodegenEvent::Finished); - cx.notify(); - } - - pub fn complete(&mut self, cx: &mut ModelContext) { - if let Some(transaction) = self.transaction.take() { - transaction.complete(cx); - } - } - - pub fn undo(&mut self, cx: &mut ModelContext) { - if let Some(transaction) = self.transaction.take() { - transaction.undo(cx); - } - } -}