diff --git a/manifest.json b/manifest.json index 1ff1cfaf..cbb406cb 100644 --- a/manifest.json +++ b/manifest.json @@ -1,7 +1,7 @@ { "id": "copilot", "name": "Copilot", - "version": "2.4.1", + "version": "2.4.2", "minAppVersion": "0.15.0", "description": "A ChatGPT Copilot in Obsidian.", "author": "Logan Yang", diff --git a/package-lock.json b/package-lock.json index ea41418f..78c6b201 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "obsidian-copilot", - "version": "2.4.1", + "version": "2.4.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "obsidian-copilot", - "version": "2.4.1", + "version": "2.4.2", "license": "AGPL-3.0", "dependencies": { "@huggingface/inference": "^1.8.0", diff --git a/package.json b/package.json index 9b95a5f5..87748e90 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "obsidian-copilot", - "version": "2.4.1", + "version": "2.4.2", "description": "ChatGPT integration for Obsidian", "main": "main.js", "scripts": { diff --git a/src/aiState.ts b/src/aiState.ts index bc9eab9d..a355baad 100644 --- a/src/aiState.ts +++ b/src/aiState.ts @@ -18,6 +18,7 @@ import { } from '@/constants'; import { ChatMessage } from '@/sharedState'; import { getModelName, isSupportedChain } from '@/utils'; +import VectorDBManager, { MemoryVector } from '@/vectorDBManager'; import { BaseChain, ConversationChain, @@ -27,7 +28,6 @@ import { import { ChatAnthropic } from 'langchain/chat_models/anthropic'; import { BaseChatModel } from 'langchain/chat_models/base'; import { ChatOpenAI } from 'langchain/chat_models/openai'; -import { VectorStore } from 'langchain/dist/vectorstores/base'; import { Embeddings } from "langchain/embeddings/base"; import { CohereEmbeddings } from "langchain/embeddings/cohere"; import { HuggingFaceInferenceEmbeddings } from "langchain/embeddings/hf"; @@ -105,7 +105,7 @@ class AIState { private static conversationalRetrievalChain: ConversationalRetrievalQAChain; private 
chatPrompt: ChatPromptTemplate; - private vectorStore: VectorStore; + private vectorStore: MemoryVectorStore; memory: BufferWindowMemory; langChainParams: LangChainParams; @@ -434,9 +434,12 @@ class AIState { } this.setNoteContent(options.noteContent); - const docHash = ChainFactory.getDocumentHash(options.noteContent); - const vectorStore = ChainFactory.vectorStoreMap.get(docHash); - if (vectorStore) { + const docHash = VectorDBManager.getDocumentHash(options.noteContent); + const parsedMemoryVectors: MemoryVector[] | undefined = await VectorDBManager.getMemoryVectors(docHash); + if (parsedMemoryVectors) { + const vectorStore = await VectorDBManager.rebuildMemoryVectorStore( + parsedMemoryVectors, this.getEmbeddingsAPI() + ); AIState.retrievalChain = RetrievalQAChain.fromLLM( AIState.chatModel, vectorStore.asRetriever(), @@ -486,7 +489,8 @@ class AIState { this.vectorStore = await MemoryVectorStore.fromDocuments( docs, embeddingsAPI, ); - ChainFactory.setVectorStore(this.vectorStore, docHash); + // Serialize and save vector store to PouchDB + VectorDBManager.setMemoryVectors(this.vectorStore.memoryVectors, docHash); console.log('Vector store created successfully.'); new Notice('Vector store created successfully.'); } catch (error) { diff --git a/src/chainFactory.ts b/src/chainFactory.ts index d06b735a..d28064ef 100644 --- a/src/chainFactory.ts +++ b/src/chainFactory.ts @@ -1,4 +1,3 @@ -import { MD5 } from 'crypto-js'; import { BaseLanguageModel } from "langchain/base_language"; import { BaseChain, @@ -6,10 +5,8 @@ import { ConversationalRetrievalQAChain, LLMChainInput } from "langchain/chains"; -import { VectorStore } from 'langchain/dist/vectorstores/base'; import { BaseRetriever } from "langchain/schema"; - export interface RetrievalChainParams { llm: BaseLanguageModel; retriever: BaseRetriever; @@ -42,7 +39,6 @@ export enum ChainType { class ChainFactory { public static instances: Map = new Map(); - public static vectorStoreMap: Map = new Map(); public static 
createNewLLMChain(args: LLMChainInput): BaseChain { const instance = new ConversationChain(args as LLMChainInput); @@ -59,18 +55,6 @@ class ChainFactory { return instance; } - public static getDocumentHash(sourceDocument: string): string { - return MD5(sourceDocument).toString(); - } - - public static setVectorStore(vectorStore: VectorStore, docHash: string): void { - ChainFactory.vectorStoreMap.set(docHash, vectorStore); - } - - public static getVectorStore(docHash: string): VectorStore | undefined { - return ChainFactory.vectorStoreMap.get(docHash); - } - public static createConversationalRetrievalChain( args: ConversationalRetrievalChainParams ): ConversationalRetrievalQAChain { diff --git a/src/components/Chat.tsx b/src/components/Chat.tsx index 2e1591db..896a92e5 100644 --- a/src/components/Chat.tsx +++ b/src/components/Chat.tsx @@ -1,5 +1,5 @@ import AIState, { useAIState } from '@/aiState'; -import ChainFactory, { ChainType } from '@/chainFactory'; +import { ChainType } from '@/chainFactory'; import ChatIcons from '@/components/ChatComponents/ChatIcons'; import ChatInput from '@/components/ChatComponents/ChatInput'; import ChatMessages from '@/components/ChatComponents/ChatMessages'; @@ -31,6 +31,7 @@ import { summarizePrompt, tocPrompt, } from '@/utils'; +import VectorDBManager from '@/vectorDBManager'; import { EventEmitter } from 'events'; import { Notice, TFile } from 'obsidian'; import React, { @@ -141,7 +142,7 @@ const Chat: React.FC = ({ return; } - const docHash = ChainFactory.getDocumentHash(noteContent); + const docHash = VectorDBManager.getDocumentHash(noteContent); await aiState.buildIndex(noteContent, docHash); const activeNoteOnMessage: ChatMessage = { sender: AI_SENDER, diff --git a/src/constants.ts b/src/constants.ts index 629bff58..f96a1977 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -111,6 +111,7 @@ export const DEFAULT_SETTINGS: CopilotSettings = { userSystemPrompt: '', openAIProxyBaseUrl: '', localAIModel: '', + ttlDays: 30, 
stream: true, embeddingProvider: OPENAI, debug: false, diff --git a/src/main.ts b/src/main.ts index d4ff75a3..e0a08661 100644 --- a/src/main.ts +++ b/src/main.ts @@ -11,11 +11,8 @@ import { import { CopilotSettingTab } from '@/settings'; import SharedState from '@/sharedState'; import { sanitizeSettings } from "@/utils"; -import cors from '@koa/cors'; +import VectorDBManager, { VectorStoreDocument } from '@/vectorDBManager'; import { Server } from 'http'; -import Koa from 'koa'; -import proxy from 'koa-proxies'; -import net from 'net'; import { Editor, Notice, Plugin, WorkspaceLeaf } from 'obsidian'; import PouchDB from 'pouchdb'; @@ -39,6 +36,7 @@ export interface CopilotSettings { userSystemPrompt: string; openAIProxyBaseUrl: string; localAIModel: string; + ttlDays: number; stream: boolean; embeddingProvider: string; debug: boolean; @@ -59,6 +57,7 @@ export default class CopilotPlugin extends Plugin { activateViewPromise: Promise | null = null; chatIsVisible = false; dbPrompts: PouchDB.Database; + dbVectorStores: PouchDB.Database; server: Server| null = null; isChatVisible = () => this.chatIsVisible; @@ -72,6 +71,13 @@ export default class CopilotPlugin extends Plugin { this.aiState = new AIState(langChainParams); this.dbPrompts = new PouchDB('copilot_custom_prompts'); + this.dbVectorStores = new PouchDB('copilot_vector_stores'); + + VectorDBManager.initializeDB(this.dbVectorStores); + // Remove documents older than TTL days on load + VectorDBManager.removeOldDocuments( + this.settings.ttlDays * 24 * 60 * 60 * 1000 + ); this.registerView( CHAT_VIEWTYPE, @@ -366,6 +372,25 @@ export default class CopilotPlugin extends Plugin { return true; }, }); + + this.addCommand({ + id: 'clear-local-vector-store', + name: 'Clear local vector store', + callback: async () => { + try { + // Clear the vectorstore db + await this.dbVectorStores.destroy(); + // Recreate the database and hand the new handle to VectorDBManager, which otherwise keeps using the destroyed one + this.dbVectorStores = new PouchDB('copilot_vector_stores'); + VectorDBManager.initializeDB(this.dbVectorStores); + new Notice('Local vector store
cleared successfully.'); + console.log('Local vector store cleared successfully.'); + } catch (err) { + console.error("Error clearing the local vector store:", err); + new Notice('An error occurred while clearing the local vector store.'); + } + } + }); } processSelection(editor: Editor, eventType: string, eventSubtype?: string) { diff --git a/src/settings.ts b/src/settings.ts index e820aa44..f0303a27 100644 --- a/src/settings.ts +++ b/src/settings.ts @@ -336,6 +336,23 @@ export class CopilotSettingTab extends PluginSettingTab { }); }); + new Setting(containerEl) + .setName("TTL (Days)") + .setDesc("Specify the Time To Live (TTL) for the saved embeddings in days. Default is 30 days. Embeddings older than the TTL will be deleted automatically to save storage space.") + .addText((text) => { + text + .setPlaceholder("30") + .setValue(this.plugin.settings.ttlDays ? this.plugin.settings.ttlDays.toString() : '') + .onChange(async (value: string) => { + const intValue = parseInt(value); + if (!isNaN(intValue)) { + this.plugin.settings.ttlDays = intValue; + await this.plugin.saveSettings(); + } + }); + }); + + new Setting(containerEl) .setName("Your CohereAI trial API key") .setDesc( diff --git a/src/vectorDBManager.ts b/src/vectorDBManager.ts new file mode 100644 index 00000000..3591af13 --- /dev/null +++ b/src/vectorDBManager.ts @@ -0,0 +1,130 @@ +import { MD5 } from 'crypto-js'; +import { Document } from "langchain/document"; +import { Embeddings } from 'langchain/embeddings/base'; +import { MemoryVectorStore } from "langchain/vectorstores/memory"; + +export interface VectorStoreDocument { + _id: string; + _rev?: string; + memory_vectors: string; + created_at: number; +} + +export interface MemoryVector { + content: string; + embedding: number[]; + metadata: Record; +} + +class VectorDBManager { + public static db: PouchDB.Database | null = null; + + public static initializeDB(db: PouchDB.Database): void { + this.db = db; + } + + public static 
getDocumentHash(sourceDocument: string): string { + return MD5(sourceDocument).toString(); + } + + public static async rebuildMemoryVectorStore( + memoryVectors: MemoryVector[], embeddingsAPI: Embeddings + ) { + if (!Array.isArray(memoryVectors)) { + throw new TypeError("Expected memoryVectors to be an array"); + } + // Extract the embeddings and documents from the deserialized memoryVectors + const embeddingsArray: number[][] = memoryVectors.map( + memoryVector => memoryVector.embedding + ); + const documentsArray = memoryVectors.map( + memoryVector => new Document({ + pageContent: memoryVector.content, + metadata: memoryVector.metadata + }) + ); + + // Create a new MemoryVectorStore instance + const memoryVectorStore = new MemoryVectorStore(embeddingsAPI); + await memoryVectorStore.addVectors(embeddingsArray, documentsArray); + return memoryVectorStore; + } + + public static async setMemoryVectors(memoryVectors: MemoryVector[], docHash: string): Promise { + if (!this.db) throw new Error("DB not initialized"); + if (!Array.isArray(memoryVectors)) { + throw new TypeError("Expected memoryVectors to be an array"); + } + const serializedMemoryVectors = JSON.stringify(memoryVectors); + try { + // Attempt to fetch the existing document, if it exists. + const existingDoc = await this.db.get(docHash).catch(err => null); + + // Prepare the document to be saved. + const docToSave = { + _id: docHash, + memory_vectors: serializedMemoryVectors, + created_at: Date.now(), + _rev: existingDoc?._rev // Add the current revision if the document exists. + }; + + // Save the document. 
+ await this.db.put(docToSave); + } catch (err) { + console.error("Error storing vectors in VectorDB:", err); + } + } + + public static async getMemoryVectors(docHash: string): Promise { + if (!this.db) throw new Error("DB not initialized"); + try { + const doc: VectorStoreDocument = await this.db.get(docHash); + if (doc && doc.memory_vectors) { + return JSON.parse(doc.memory_vectors); + } + } catch (err) { + console.log("No vectors found in VectorDB for dochash:", docHash); + } + } + + public static async removeOldDocuments(ttl: number): Promise { + if (!this.db) throw new Error("DB not initialized"); + + try { + const thresholdTime = Date.now() - ttl; + + // Fetch all documents from the database + const allDocsResponse = await this.db.allDocs<{ created_at: number }>({ include_docs: true }); + + // Select the documents created before the TTL threshold + const oldDocs = allDocsResponse.rows.filter(row => { + // Assert the doc type + const doc = row.doc as VectorStoreDocument; + return doc && doc.created_at < thresholdTime; + }); + + if (oldDocs.length === 0) { + return; + } + // Prepare the documents for deletion + const docsToDelete = oldDocs.map(row => ({ + _id: row.id, + _rev: (row.doc as VectorStoreDocument)._rev, + _deleted: true + })); + + // Delete the old documents + await this.db.bulkDocs(docsToDelete); + console.log("Deleted old documents from VectorDB"); + } + catch (err) { + console.error("Error removing old documents from VectorDB:", err); + } + } + + // TODO: Implement advanced stale document removal. + // NOTE: Cannot just rely on note title + ts because a "document" here is a chunk from + // the original note. Need a better strategy. +} + +export default VectorDBManager; diff --git a/versions.json b/versions.json index 64776420..613c4e2e 100644 --- a/versions.json +++ b/versions.json @@ -22,5 +22,6 @@ "2.3.5": "0.15.0", "2.3.6": "0.15.0", "2.4.0": "0.15.0", - "2.4.1": "0.15.0" + "2.4.1": "0.15.0", + "2.4.2": "0.15.0" } \ No newline at end of file