From 6a31dd5fa71818f7a8e4b82dfa71fcd50f3cbe67 Mon Sep 17 00:00:00 2001 From: Gil LaHaye Date: Sat, 9 Dec 2023 09:29:16 -0800 Subject: [PATCH] Use chatId from URL rather than from payload for chats (#700) ### Motivation and Context The verify access to a chat, we use HandleRequest() with the chatId provided. Currently, we get this from the payload, which can differ from the chatId from the URL, which opens us to a security problem where a user could inject an arbitrary chatId in the payload, which doesn't match what's in the URL. ### Description - Use chatId from URL and only from URL - Add integrations test to validate this ### Contribution Checklist - [ ] The code builds clean without any errors or warnings - [ ] The PR follows the [Contribution Guidelines](https://github.com/microsoft/chat-copilot/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/chat-copilot/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [ ] All unit tests pass, and I have added new tests where possible - [ ] I didn't break anyone :smile: --- integration-tests/ChatTests.cs | 50 ++++++++++++++++++++++++++ webapi/Controllers/ChatController.cs | 28 +++++++++++---- webapi/Extensions/ServiceExtensions.cs | 5 --- webapi/Program.cs | 1 - webapi/Utilities/AskConverter.cs | 41 --------------------- 5 files changed, 72 insertions(+), 53 deletions(-) create mode 100644 integration-tests/ChatTests.cs delete mode 100644 webapi/Utilities/AskConverter.cs diff --git a/integration-tests/ChatTests.cs b/integration-tests/ChatTests.cs new file mode 100644 index 000000000..6e05d5b7b --- /dev/null +++ b/integration-tests/ChatTests.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Net.Http; +using System.Net.Http.Json; +using System.Text.Json; +using CopilotChat.WebApi.Models.Request; +using CopilotChat.WebApi.Models.Response; +using Xunit; +using static CopilotChat.WebApi.Models.Storage.CopilotChatMessage; + +namespace ChatCopilotIntegrationTests; + +public class ChatTests : ChatCopilotIntegrationTest +{ + [Fact] + public async void ChatMessagePostSucceedsWithValidInput() + { + await this.SetUpAuth(); + + // Create chat session + var createChatParams = new CreateChatParameters() { Title = nameof(ChatMessagePostSucceedsWithValidInput) }; + HttpResponseMessage response = await this._httpClient.PostAsJsonAsync("chats", createChatParams); + response.EnsureSuccessStatusCode(); + + var contentStream = await response.Content.ReadAsStreamAsync(); + var createChatResponse = await JsonSerializer.DeserializeAsync(contentStream, new JsonSerializerOptions { PropertyNameCaseInsensitive = true }); + Assert.NotNull(createChatResponse); + + // Ask something to the bot + var ask = new Ask + { + Input = "Who is Satya Nadella?", + Variables = new KeyValuePair[] { new("MessageType", ChatMessageType.Message.ToString()) } + }; + response = await this._httpClient.PostAsJsonAsync($"chats/{createChatResponse.ChatSession.Id}/messages", ask); + response.EnsureSuccessStatusCode(); + + contentStream = await response.Content.ReadAsStreamAsync(); + var askResult = await JsonSerializer.DeserializeAsync(contentStream, new JsonSerializerOptions { PropertyNameCaseInsensitive = true }); + Assert.NotNull(askResult); + Assert.False(string.IsNullOrEmpty(askResult.Value)); + + + // Clean up + response = await this._httpClient.DeleteAsync($"chats/{createChatResponse.ChatSession.Id}"); + response.EnsureSuccessStatusCode(); + } +} + diff --git a/webapi/Controllers/ChatController.cs b/webapi/Controllers/ChatController.cs index 2519303b0..22de83e09 100644 --- a/webapi/Controllers/ChatController.cs +++ b/webapi/Controllers/ChatController.cs @@ -99,7 +99,6 @@ public async Task ChatAsync( [FromServices] IKernel kernel, [FromServices] IHubContext messageRelayHubContext, [FromServices] CopilotChatPlanner planner, - [FromServices] AskConverter askConverter, [FromServices] ChatSessionRepository chatSessionRepository, [FromServices] ChatParticipantRepository chatParticipantRepository, [FromServices] IAuthInfo authInfo, @@ -108,7 +107,7 @@ public async Task ChatAsync( { this._logger.LogDebug("Chat message received."); - return await this.HandleRequest(ChatFunctionName, kernel, messageRelayHubContext, planner, askConverter, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString()); + return await this.HandleRequest(ChatFunctionName, kernel, messageRelayHubContext, planner, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString()); } /// @@ -135,7 +134,6 @@ public async Task ProcessPlanAsync( [FromServices] IKernel kernel, [FromServices] IHubContext messageRelayHubContext, [FromServices] CopilotChatPlanner planner, - [FromServices] AskConverter askConverter, [FromServices] ChatSessionRepository chatSessionRepository, [FromServices] ChatParticipantRepository chatParticipantRepository, [FromServices] IAuthInfo authInfo, @@ -144,7 +142,7 @@ public async Task ProcessPlanAsync( { this._logger.LogDebug("plan request received."); - return await this.HandleRequest(ProcessPlanFunctionName, kernel, messageRelayHubContext, planner, askConverter, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString()); + return await this.HandleRequest(ProcessPlanFunctionName, kernel, messageRelayHubContext, planner, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString()); } /// @@ -166,7 +164,6 @@ private async Task HandleRequest( IKernel kernel, IHubContext messageRelayHubContext, CopilotChatPlanner planner, - AskConverter askConverter, ChatSessionRepository chatSessionRepository, ChatParticipantRepository chatParticipantRepository, IAuthInfo authInfo, @@ -174,7 +171,7 @@ private async Task HandleRequest( string chatId) { // Put ask's variables in the context we will use. - var contextVariables = askConverter.GetContextVariables(ask); + var contextVariables = GetContextVariables(ask, authInfo, chatId); // Verify that the chat exists and that the user has access to it. ChatSession? chat = null; @@ -415,6 +412,25 @@ await planner.Kernel.ImportOpenAIPluginFunctionsAsync( return; } + private static ContextVariables GetContextVariables(Ask ask, IAuthInfo authInfo, string chatId) + { + const string UserIdKey = "userId"; + const string UserNameKey = "userName"; + const string ChatIdKey = "chatId"; + + var contextVariables = new ContextVariables(ask.Input); + foreach (var variable in ask.Variables) + { + contextVariables.Set(variable.Key, variable.Value); + } + + contextVariables.Set(UserIdKey, authInfo.UserId); + contextVariables.Set(UserNameKey, authInfo.Name); + contextVariables.Set(ChatIdKey, chatId); + + return contextVariables; + } + /// /// Dispose of the object. /// diff --git a/webapi/Extensions/ServiceExtensions.cs b/webapi/Extensions/ServiceExtensions.cs index 120821e80..a5e790aa2 100644 --- a/webapi/Extensions/ServiceExtensions.cs +++ b/webapi/Extensions/ServiceExtensions.cs @@ -81,11 +81,6 @@ internal static void AddOptions(this IServiceCollection services, ICon .PostConfigure(TrimStringProperties); } - internal static IServiceCollection AddUtilities(this IServiceCollection services) - { - return services.AddScoped(); - } - internal static IServiceCollection AddPlugins(this IServiceCollection services, IConfiguration configuration) { var plugins = configuration.GetSection("Plugins").Get>() ?? new List(); diff --git a/webapi/Program.cs b/webapi/Program.cs index eaec533c5..051e0ddcb 100644 --- a/webapi/Program.cs +++ b/webapi/Program.cs @@ -45,7 +45,6 @@ public static async Task Main(string[] args) .AddOptions(builder.Configuration) .AddPersistentChatStore() .AddPlugins(builder.Configuration) - .AddUtilities() .AddChatCopilotAuthentication(builder.Configuration) .AddChatCopilotAuthorization(); diff --git a/webapi/Utilities/AskConverter.cs b/webapi/Utilities/AskConverter.cs deleted file mode 100644 index ad911b8f3..000000000 --- a/webapi/Utilities/AskConverter.cs +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using CopilotChat.WebApi.Auth; -using CopilotChat.WebApi.Models.Request; -using Microsoft.SemanticKernel.Orchestration; - -namespace CopilotChat.WebApi.Utilities; - -/// -/// Converts variables to , inserting some system variables along the way. -/// -public class AskConverter -{ - private readonly IAuthInfo _authInfo; - - public AskConverter(IAuthInfo authInfo) - { - this._authInfo = authInfo; - } - - /// - /// Converts variables to , inserting some system variables along the way. - /// - public ContextVariables GetContextVariables(Ask ask) - { - const string userIdKey = "userId"; - const string userNameKey = "userName"; - var contextVariables = new ContextVariables(ask.Input); - foreach (var variable in ask.Variables) - { - if (variable.Key != userIdKey && variable.Key != userNameKey) - { - contextVariables.Set(variable.Key, variable.Value); - } - } - - contextVariables.Set(userIdKey, this._authInfo.UserId); - contextVariables.Set(userNameKey, this._authInfo.Name); - return contextVariables; - } -}