Skip to content

Commit

Permalink
Implement cross-session vector store, command to clear it, and TTL se…
Browse files Browse the repository at this point in the history
…tting (#126)

* Implement cross-session vectorstore persistence

* Add command to clear local vector store

* Refactor ChainFactory and create VectorDBManager

* Add setting for TTL

* 2.4.2
  • Loading branch information
logancyang authored Aug 10, 2023
1 parent 106c5db commit 287f5d3
Show file tree
Hide file tree
Showing 11 changed files with 195 additions and 33 deletions.
2 changes: 1 addition & 1 deletion manifest.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"id": "copilot",
"name": "Copilot",
"version": "2.4.1",
"version": "2.4.2",
"minAppVersion": "0.15.0",
"description": "A ChatGPT Copilot in Obsidian.",
"author": "Logan Yang",
Expand Down
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "obsidian-copilot",
"version": "2.4.1",
"version": "2.4.2",
"description": "ChatGPT integration for Obsidian",
"main": "main.js",
"scripts": {
Expand Down
16 changes: 10 additions & 6 deletions src/aiState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
} from '@/constants';
import { ChatMessage } from '@/sharedState';
import { getModelName, isSupportedChain } from '@/utils';
import VectorDBManager, { MemoryVector } from '@/vectorDBManager';
import {
BaseChain,
ConversationChain,
Expand All @@ -27,7 +28,6 @@ import {
import { ChatAnthropic } from 'langchain/chat_models/anthropic';
import { BaseChatModel } from 'langchain/chat_models/base';
import { ChatOpenAI } from 'langchain/chat_models/openai';
import { VectorStore } from 'langchain/dist/vectorstores/base';
import { Embeddings } from "langchain/embeddings/base";
import { CohereEmbeddings } from "langchain/embeddings/cohere";
import { HuggingFaceInferenceEmbeddings } from "langchain/embeddings/hf";
Expand Down Expand Up @@ -105,7 +105,7 @@ class AIState {
private static conversationalRetrievalChain: ConversationalRetrievalQAChain;

private chatPrompt: ChatPromptTemplate;
private vectorStore: VectorStore;
private vectorStore: MemoryVectorStore;

memory: BufferWindowMemory;
langChainParams: LangChainParams;
Expand Down Expand Up @@ -434,9 +434,12 @@ class AIState {
}

this.setNoteContent(options.noteContent);
const docHash = ChainFactory.getDocumentHash(options.noteContent);
const vectorStore = ChainFactory.vectorStoreMap.get(docHash);
if (vectorStore) {
const docHash = VectorDBManager.getDocumentHash(options.noteContent);
const parsedMemoryVectors: MemoryVector[] | undefined = await VectorDBManager.getMemoryVectors(docHash);
if (parsedMemoryVectors) {
const vectorStore = await VectorDBManager.rebuildMemoryVectorStore(
parsedMemoryVectors, this.getEmbeddingsAPI()
);
AIState.retrievalChain = RetrievalQAChain.fromLLM(
AIState.chatModel,
vectorStore.asRetriever(),
Expand Down Expand Up @@ -486,7 +489,8 @@ class AIState {
this.vectorStore = await MemoryVectorStore.fromDocuments(
docs, embeddingsAPI,
);
ChainFactory.setVectorStore(this.vectorStore, docHash);
// Serialize and save vector store to PouchDB
VectorDBManager.setMemoryVectors(this.vectorStore.memoryVectors, docHash);
console.log('Vector store created successfully.');
new Notice('Vector store created successfully.');
} catch (error) {
Expand Down
16 changes: 0 additions & 16 deletions src/chainFactory.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import { MD5 } from 'crypto-js';
import { BaseLanguageModel } from "langchain/base_language";
import {
BaseChain,
ConversationChain,
ConversationalRetrievalQAChain,
LLMChainInput
} from "langchain/chains";
import { VectorStore } from 'langchain/dist/vectorstores/base';
import { BaseRetriever } from "langchain/schema";


export interface RetrievalChainParams {
llm: BaseLanguageModel;
retriever: BaseRetriever;
Expand Down Expand Up @@ -42,7 +39,6 @@ export enum ChainType {

class ChainFactory {
public static instances: Map<string, BaseChain> = new Map();
public static vectorStoreMap: Map<string, VectorStore> = new Map();

public static createNewLLMChain(args: LLMChainInput): BaseChain {
const instance = new ConversationChain(args as LLMChainInput);
Expand All @@ -59,18 +55,6 @@ class ChainFactory {
return instance;
}

public static getDocumentHash(sourceDocument: string): string {
return MD5(sourceDocument).toString();
}

public static setVectorStore(vectorStore: VectorStore, docHash: string): void {
ChainFactory.vectorStoreMap.set(docHash, vectorStore);
}

public static getVectorStore(docHash: string): VectorStore | undefined {
return ChainFactory.vectorStoreMap.get(docHash);
}

public static createConversationalRetrievalChain(
args: ConversationalRetrievalChainParams
): ConversationalRetrievalQAChain {
Expand Down
5 changes: 3 additions & 2 deletions src/components/Chat.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import AIState, { useAIState } from '@/aiState';
import ChainFactory, { ChainType } from '@/chainFactory';
import { ChainType } from '@/chainFactory';
import ChatIcons from '@/components/ChatComponents/ChatIcons';
import ChatInput from '@/components/ChatComponents/ChatInput';
import ChatMessages from '@/components/ChatComponents/ChatMessages';
Expand Down Expand Up @@ -31,6 +31,7 @@ import {
summarizePrompt,
tocPrompt,
} from '@/utils';
import VectorDBManager from '@/vectorDBManager';
import { EventEmitter } from 'events';
import { Notice, TFile } from 'obsidian';
import React, {
Expand Down Expand Up @@ -141,7 +142,7 @@ const Chat: React.FC<ChatProps> = ({
return;
}

const docHash = ChainFactory.getDocumentHash(noteContent);
const docHash = VectorDBManager.getDocumentHash(noteContent);
await aiState.buildIndex(noteContent, docHash);
const activeNoteOnMessage: ChatMessage = {
sender: AI_SENDER,
Expand Down
1 change: 1 addition & 0 deletions src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ export const DEFAULT_SETTINGS: CopilotSettings = {
userSystemPrompt: '',
openAIProxyBaseUrl: '',
localAIModel: '',
ttlDays: 30,
stream: true,
embeddingProvider: OPENAI,
debug: false,
Expand Down
32 changes: 28 additions & 4 deletions src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@ import {
import { CopilotSettingTab } from '@/settings';
import SharedState from '@/sharedState';
import { sanitizeSettings } from "@/utils";
import cors from '@koa/cors';
import VectorDBManager, { VectorStoreDocument } from '@/vectorDBManager';
import { Server } from 'http';
import Koa from 'koa';
import proxy from 'koa-proxies';
import net from 'net';
import { Editor, Notice, Plugin, WorkspaceLeaf } from 'obsidian';
import PouchDB from 'pouchdb';

Expand All @@ -39,6 +36,7 @@ export interface CopilotSettings {
userSystemPrompt: string;
openAIProxyBaseUrl: string;
localAIModel: string;
ttlDays: number;
stream: boolean;
embeddingProvider: string;
debug: boolean;
Expand All @@ -59,6 +57,7 @@ export default class CopilotPlugin extends Plugin {
activateViewPromise: Promise<void> | null = null;
chatIsVisible = false;
dbPrompts: PouchDB.Database;
dbVectorStores: PouchDB.Database;
server: Server| null = null;

isChatVisible = () => this.chatIsVisible;
Expand All @@ -72,6 +71,13 @@ export default class CopilotPlugin extends Plugin {
this.aiState = new AIState(langChainParams);

this.dbPrompts = new PouchDB<CustomPrompt>('copilot_custom_prompts');
this.dbVectorStores = new PouchDB<VectorStoreDocument>('copilot_vector_stores');

VectorDBManager.initializeDB(this.dbVectorStores);
// Remove documents older than TTL days on load
VectorDBManager.removeOldDocuments(
this.settings.ttlDays * 24 * 60 * 60 * 1000
);

this.registerView(
CHAT_VIEWTYPE,
Expand Down Expand Up @@ -366,6 +372,24 @@ export default class CopilotPlugin extends Plugin {
return true;
},
});

this.addCommand({
id: 'clear-local-vector-store',
name: 'Clear local vector store',
callback: async () => {
try {
// Clear the vectorstore db
await this.dbVectorStores.destroy();
// Reinitialize the database
this.dbVectorStores = new PouchDB<VectorStoreDocument>('copilot_vector_stores'); //
new Notice('Local vector store cleared successfully.');
console.log('Local vector store cleared successfully.');
} catch (err) {
console.error("Error clearing the local vector store:", err);
new Notice('An error occurred while clearing the local vector store.');
}
}
});
}

processSelection(editor: Editor, eventType: string, eventSubtype?: string) {
Expand Down
17 changes: 17 additions & 0 deletions src/settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,23 @@ export class CopilotSettingTab extends PluginSettingTab {
});
});

new Setting(containerEl)
.setName("TTL (Days)")
.setDesc("Specify the Time To Live (TTL) for the saved embeddings in days. Default is 30 days. Embeddings older than the TTL will be deleted automatically to save storage space.")
.addText((text) => {
text
.setPlaceholder("30")
.setValue(this.plugin.settings.ttlDays ? this.plugin.settings.ttlDays.toString() : '')
.onChange(async (value: string) => {
const intValue = parseInt(value);
if (!isNaN(intValue)) {
this.plugin.settings.ttlDays = intValue;
await this.plugin.saveSettings();
}
});
});


new Setting(containerEl)
.setName("Your CohereAI trial API key")
.setDesc(
Expand Down
130 changes: 130 additions & 0 deletions src/vectorDBManager.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import { MD5 } from 'crypto-js';
import { Document } from "langchain/document";
import { Embeddings } from 'langchain/embeddings/base';
import { MemoryVectorStore } from "langchain/vectorstores/memory";

export interface VectorStoreDocument {
_id: string;
_rev?: string;
memory_vectors: string;
created_at: number;
}

export interface MemoryVector {
content: string;
embedding: number[];
metadata: Record<string, any>;
}

class VectorDBManager {
public static db: PouchDB.Database | null = null;

public static initializeDB(db: PouchDB.Database): void {
this.db = db;
}

public static getDocumentHash(sourceDocument: string): string {
return MD5(sourceDocument).toString();
}

public static async rebuildMemoryVectorStore(
memoryVectors: MemoryVector[], embeddingsAPI: Embeddings
) {
if (!Array.isArray(memoryVectors)) {
throw new TypeError("Expected memoryVectors to be an array");
}
// Extract the embeddings and documents from the deserialized memoryVectors
const embeddingsArray: number[][] = memoryVectors.map(
memoryVector => memoryVector.embedding
);
const documentsArray = memoryVectors.map(
memoryVector => new Document({
pageContent: memoryVector.content,
metadata: memoryVector.metadata
})
);

// Create a new MemoryVectorStore instance
const memoryVectorStore = new MemoryVectorStore(embeddingsAPI);
await memoryVectorStore.addVectors(embeddingsArray, documentsArray);
return memoryVectorStore;
}

public static async setMemoryVectors(memoryVectors: MemoryVector[], docHash: string): Promise<void> {
if (!this.db) throw new Error("DB not initialized");
if (!Array.isArray(memoryVectors)) {
throw new TypeError("Expected memoryVectors to be an array");
}
const serializedMemoryVectors = JSON.stringify(memoryVectors);
try {
// Attempt to fetch the existing document, if it exists.
const existingDoc = await this.db.get(docHash).catch(err => null);

// Prepare the document to be saved.
const docToSave = {
_id: docHash,
memory_vectors: serializedMemoryVectors,
created_at: Date.now(),
_rev: existingDoc?._rev // Add the current revision if the document exists.
};

// Save the document.
await this.db.put(docToSave);
} catch (err) {
console.error("Error storing vectors in VectorDB:", err);
}
}

public static async getMemoryVectors(docHash: string): Promise<MemoryVector[] | undefined> {
if (!this.db) throw new Error("DB not initialized");
try {
const doc: VectorStoreDocument = await this.db.get(docHash);
if (doc && doc.memory_vectors) {
return JSON.parse(doc.memory_vectors);
}
} catch (err) {
console.log("No vectors found in VectorDB for dochash:", docHash);
}
}

public static async removeOldDocuments(ttl: number): Promise<void> {
if (!this.db) throw new Error("DB not initialized");

try {
const thresholdTime = Date.now() - ttl;

// Fetch all documents from the database
const allDocsResponse = await this.db.allDocs<{ created_at: number }>({ include_docs: true });

// Filter out the documents older than 2 weeks
const oldDocs = allDocsResponse.rows.filter(row => {
// Assert the doc type
const doc = row.doc as VectorStoreDocument;
return doc && doc.created_at < thresholdTime;
});

if (oldDocs.length === 0) {
return;
}
// Prepare the documents for deletion
const docsToDelete = oldDocs.map(row => ({
_id: row.id,
_rev: (row.doc as VectorStoreDocument)._rev,
_deleted: true
}));

// Delete the old documents
await this.db.bulkDocs(docsToDelete);
console.log("Deleted old documents from VectorDB");
}
catch (err) {
console.error("Error removing old documents from VectorDB:", err);
}
}

// TODO: Implement advanced stale document removal.
// NOTE: Cannot just rely on note title + ts because a "document" here is a chunk from
// the original note. Need a better strategy.
}

export default VectorDBManager;
3 changes: 2 additions & 1 deletion versions.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@
"2.3.5": "0.15.0",
"2.3.6": "0.15.0",
"2.4.0": "0.15.0",
"2.4.1": "0.15.0"
"2.4.1": "0.15.0",
"2.4.2": "0.15.0"
}

0 comments on commit 287f5d3

Please sign in to comment.