Skip to content

Commit

Permalink
Using semantic memory type enum instead of string in API contract. (m…
Browse files Browse the repository at this point in the history
…icrosoft#410)

### Motivation and Context

Currently the memory retrieval API requires the name of the memory as
input, however, the memoryName is something that's configured in the
backend and is not known by external clients. This PR refactors things
to use a memoryType based on a stable enum as the public interface to
the API which allows the memoryName to remain a configuration detail of
the backend.  
-->

See issue microsoft#388 

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [Contribution
Guidelines](https://github.com/microsoft/chat-copilot/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/chat-copilot/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

---------

Co-authored-by: Ben Thomas <[email protected]>
  • Loading branch information
alliscode and Ben Thomas authored Sep 28, 2023
1 parent 91dbfea commit 969394b
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 31 deletions.
38 changes: 14 additions & 24 deletions webapi/Controllers/ChatMemoryController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Threading.Tasks;
using CopilotChat.WebApi.Auth;
using CopilotChat.WebApi.Extensions;
using CopilotChat.WebApi.Models.Request;
using CopilotChat.WebApi.Options;
using CopilotChat.WebApi.Storage;
using Microsoft.AspNetCore.Authorization;
Expand Down Expand Up @@ -50,21 +51,27 @@ public ChatMemoryController(
/// </summary>
/// <param name="semanticTextMemory">The semantic text memory instance.</param>
/// <param name="chatId">The chat id.</param>
/// <param name="memoryName">Name of the memory type.</param>
/// <param name="memoryType">Type of memory. Must map to a member of <see cref="SemanticMemoryType"/>.</param>
[HttpGet]
[Route("chatMemory/{chatId:guid}/{memoryName}")]
[Route("chatMemory/{chatId:guid}/{memoryType}")]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
[Authorize(Policy = AuthPolicyName.RequireChatParticipant)]
public async Task<IActionResult> GetSemanticMemoriesAsync(
[FromServices] ISemanticMemoryClient memoryClient,
[FromRoute] string chatId,
[FromRoute] string memoryName)
[FromRoute] string memoryType)
{
// Sanitize the log input by removing new line characters.
// https://github.com/microsoft/chat-copilot/security/code-scanning/1
var sanitizedChatId = GetSanitizedParameter(chatId);
var sanitizedMemoryName = GetSanitizedParameter(memoryName);

// Map the requested memoryType to the memory store container name
if (!this._promptOptions.TryGetMemoryContainerName(memoryType, out string memoryContainerName))
{
this._logger.LogWarning("Memory type: {0} is invalid.", memoryType);
return this.BadRequest($"Memory type: {memoryType} is invalid.");
}

// Make sure the chat session exists.
if (!await this._chatSessionRepository.TryFindByIdAsync(chatId))
Expand All @@ -73,13 +80,6 @@ public async Task<IActionResult> GetSemanticMemoriesAsync(
return this.BadRequest($"Chat session: {sanitizedChatId} does not exist.");
}

// Make sure the memory name is valid.
if (!this.ValidateMemoryName(sanitizedMemoryName))
{
this._logger.LogWarning("Memory name: {0} is invalid.", sanitizedMemoryName);
return this.BadRequest($"Memory name: {sanitizedMemoryName} is invalid.");
}

// Gather the requested semantic memory.
// Will use a dummy query since we don't care about relevance.
// minRelevanceScore is set to 0.0 to return all memories.
Expand All @@ -89,7 +89,7 @@ public async Task<IActionResult> GetSemanticMemoriesAsync(
// Search if there is already a memory item that has a high similarity score with the new item.
var filter = new MemoryFilter();
filter.ByTag("chatid", chatId);
filter.ByTag("memory", sanitizedMemoryName);
filter.ByTag("memory", memoryContainerName);
filter.MinRelevance = 0;

var searchResult =
Expand All @@ -99,7 +99,7 @@ await memoryClient.SearchMemoryAsync(
relevanceThreshold: 0,
resultCount: 1,
chatId,
sanitizedMemoryName)
memoryContainerName)
.ConfigureAwait(false);

foreach (var memory in searchResult.Results.SelectMany(c => c.Partitions))
Expand All @@ -110,7 +110,7 @@ await memoryClient.SearchMemoryAsync(
catch (Exception connectorException) when (!connectorException.IsCriticalException())
{
// A store exception might be thrown if the collection does not exist, depending on the memory store connector.
this._logger.LogError(connectorException, "Cannot search collection {0}", sanitizedMemoryName);
this._logger.LogError(connectorException, "Cannot search collection {0}", memoryContainerName);
}

return this.Ok(memories);
Expand All @@ -123,15 +123,5 @@ private static string GetSanitizedParameter(string parameterValue)
return parameterValue.Replace(Environment.NewLine, string.Empty, StringComparison.Ordinal);
}

/// <summary>
/// Validates the memory name.
/// </summary>
/// <param name="memoryName">Name of the memory requested.</param>
/// <returns>True if the memory name is valid.</returns>
private bool ValidateMemoryName(string memoryName)
{
return this._promptOptions.MemoryMap.ContainsKey(memoryName);
}

# endregion
}
12 changes: 12 additions & 0 deletions webapi/Models/Request/SemanticMemoryType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.

namespace CopilotChat.WebApi.Models.Request;

/// <summary>
/// Types of semantic memories supported by chat-copilot.
/// </summary>
public enum SemanticMemoryType
{
LongTermMemory,
WorkingMemory
}
27 changes: 27 additions & 0 deletions webapi/Options/PromptsOptions.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using CopilotChat.WebApi.Models.Request;

namespace CopilotChat.WebApi.Options;

Expand Down Expand Up @@ -161,4 +163,29 @@ public class PromptsOptions
/// </summary>
/// <returns>A shallow copy of the options.</returns>
internal PromptsOptions Copy() => (PromptsOptions)this.MemberwiseClone();

/// <summary>
/// Tries to retrieve the memoryContainerName associated with the specified memory type.
/// </summary>
internal bool TryGetMemoryContainerName(string memoryType, out string memoryContainerName)
{
memoryContainerName = "";
if (!Enum.TryParse<SemanticMemoryType>(memoryType, true, out SemanticMemoryType semanticMemoryType))
{
return false;
}

switch (semanticMemoryType)
{
case SemanticMemoryType.LongTermMemory:
memoryContainerName = this.LongTermMemoryName;
return true;

case SemanticMemoryType.WorkingMemory:
memoryContainerName = this.WorkingMemoryName;
return true;

default: return false;
}
}
}
16 changes: 11 additions & 5 deletions webapi/Skills/ChatSkills/SemanticChatMemoryExtractor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Threading;
using System.Threading.Tasks;
using CopilotChat.WebApi.Extensions;
using CopilotChat.WebApi.Models.Request;
using CopilotChat.WebApi.Options;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
Expand Down Expand Up @@ -37,11 +38,16 @@ public static async Task ExtractSemanticChatMemoryAsync(
ILogger logger,
CancellationToken cancellationToken)
{
foreach (var memoryName in options.MemoryMap.Keys)
foreach (string memoryType in Enum.GetNames(typeof(SemanticMemoryType)))
{
try
{
var semanticMemory = await ExtractCognitiveMemoryAsync(memoryName, logger);
if (!options.TryGetMemoryContainerName(memoryType, out var memoryName))
{
logger.LogInformation("Unable to extract semantic memory for invalid memory type {0}. Continuing...", memoryType);
continue;
}
var semanticMemory = await ExtractCognitiveMemoryAsync(memoryType, memoryName, logger);
foreach (var item in semanticMemory.Items)
{
await CreateMemoryAsync(memoryName, item.ToFormattedString());
Expand All @@ -51,15 +57,15 @@ public static async Task ExtractSemanticChatMemoryAsync(
{
// Skip semantic memory extraction for this item if it fails.
// We cannot rely on the model to response with perfect Json each time.
logger.LogInformation("Unable to extract semantic memory for {0}: {1}. Continuing...", memoryName, ex.Message);
logger.LogInformation("Unable to extract semantic memory for {0}: {1}. Continuing...", memoryType, ex.Message);
continue;
}
}

/// <summary>
/// Extracts the semantic chat memory from the chat session.
/// </summary>
async Task<SemanticChatMemory> ExtractCognitiveMemoryAsync(string memoryName, ILogger logger)
async Task<SemanticChatMemory> ExtractCognitiveMemoryAsync(string memoryType, string memoryName, ILogger logger)
{
if (!options.MemoryMap.TryGetValue(memoryName, out var memoryPrompt))
{
Expand Down Expand Up @@ -87,7 +93,7 @@ async Task<SemanticChatMemory> ExtractCognitiveMemoryAsync(string memoryName, IL

// Get token usage from ChatCompletion result and add to context
// Since there are multiple memory types, total token usage is calculated by cumulating the token usage of each memory type.
TokenUtilities.GetFunctionTokenUsage(result, context, logger, $"SystemCognitive_{memoryName}");
TokenUtilities.GetFunctionTokenUsage(result, context, logger, $"SystemCognitive_{memoryType}");

SemanticChatMemory memory = SemanticChatMemory.FromJson(result.ToString());
return memory;
Expand Down
4 changes: 2 additions & 2 deletions webapi/appsettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@
"MemoryFormat": "{\"items\": [{\"label\": string, \"details\": string }]}",
"MemoryAntiHallucination": "IMPORTANT: DO NOT INCLUDE ANY OF THE ABOVE INFORMATION IN THE GENERATED RESPONSE AND ALSO DO NOT MAKE UP OR INFER ANY ADDITIONAL INFORMATION THAT IS NOT INCLUDED BELOW. ALSO DO NOT RESPOND IF THE LAST MESSAGE WAS NOT ADDRESSED TO YOU.",
"MemoryContinuation": "Generate a well-formed JSON of extracted context data. DO NOT include a preamble in the response. DO NOT give a list of possible responses. Only provide a single response of the json block.\nResponse:",
"WorkingMemoryName": "WorkingMemory",
"WorkingMemoryName": "WorkingMemory", // The name used for the container that stores Working Memory in the Semantic Memory database. This should not be changed once memories are established.
"WorkingMemoryExtraction": "Extract information for a short period of time, such as a few seconds or minutes. It should be useful for performing complex cognitive tasks that require attention, concentration, or mental calculation.",
"LongTermMemoryName": "LongTermMemory",
"LongTermMemoryName": "LongTermMemory", // The name used for the container that stores Long Term Memory in the Semantic Memory database. This should not be changed once memories are established.
"LongTermMemoryExtraction": "Extract information that is encoded and consolidated from other memory types, such as working memory or sensory memory. It should be useful for maintaining and recalling one's personal identity, history, and knowledge over time.",
"DocumentMemoryName": "DocumentMemory",
"MemoryIndexName": "chatmemory"
Expand Down

0 comments on commit 969394b

Please sign in to comment.