diff --git a/extensions/Postgres/Postgres.FunctionalTests/ConcurrencyTests.cs b/extensions/Postgres/Postgres.FunctionalTests/ConcurrencyTests.cs index 7e68ea2e8..2fc2a68fa 100644 --- a/extensions/Postgres/Postgres.FunctionalTests/ConcurrencyTests.cs +++ b/extensions/Postgres/Postgres.FunctionalTests/ConcurrencyTests.cs @@ -69,7 +69,7 @@ last_update TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL var indexName = "create_index_test"; var vectorSize = 1536; - var target = new PostgresMemory(config, new FakeEmbeddingGenerator()); + using var target = new PostgresMemory(config, new FakeEmbeddingGenerator()); var tasks = new List(); for (int i = 0; i < concurrency; i++) @@ -96,7 +96,7 @@ public async Task UpsertConcurrencyTest() var vectorSize = 4; var indexName = "upsert_test" + Guid.NewGuid().ToString("D"); - var target = new PostgresMemory(this.PostgresConfig, new FakeEmbeddingGenerator()); + using var target = new PostgresMemory(this.PostgresConfig, new FakeEmbeddingGenerator()); await target.CreateIndexAsync(indexName, vectorSize); diff --git a/extensions/Postgres/Postgres/Internals/PostgresDbClient.cs b/extensions/Postgres/Postgres/Internals/PostgresDbClient.cs index f20f8c0ee..34efbcfab 100644 --- a/extensions/Postgres/Postgres/Internals/PostgresDbClient.cs +++ b/extensions/Postgres/Postgres/Internals/PostgresDbClient.cs @@ -19,7 +19,7 @@ namespace Microsoft.KernelMemory.Postgres; /// /// An implementation of a client for Postgres. This class is used to managing postgres database operations. /// -internal sealed class PostgresDbClient +internal sealed class PostgresDbClient : IDisposable { // See: https://www.postgresql.org/docs/current/errcodes-appendix.html private const string PgErrUndefinedTable = "42P01"; // undefined_table @@ -28,7 +28,7 @@ internal sealed class PostgresDbClient private const string PgErrDatabaseDoesNotExist = "3D000"; // invalid_catalog_name private readonly ILogger _log; - private readonly NpgsqlDataSourceBuilder _dataSourceBuilder; + private readonly NpgsqlDataSource _dataSource; private readonly string _schema; private readonly string _tableNamePrefix; @@ -52,8 +52,9 @@ public PostgresDbClient(PostgresConfig config, ILoggerFactory? loggerFactory = n config.Validate(); this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger(); - this._dataSourceBuilder = new(config.ConnectionString); - this._dataSourceBuilder.UseVector(); + NpgsqlDataSourceBuilder dataSourceBuilder = new(config.ConnectionString); + dataSourceBuilder.UseVector(); + this._dataSource = dataSourceBuilder.Build(); this._dbNamePresent = config.ConnectionString.Contains("Database=", StringComparison.OrdinalIgnoreCase); this._schema = config.Schema; @@ -96,51 +97,48 @@ public async Task DoesTableExistAsync( tableName = this.WithTableNamePrefix(tableName); this._log.LogTrace("Checking if table {0} exists", tableName); - var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); - await using (dataSource.ConfigureAwait(false)) + NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); + await using (connection) { - await using (connection) + try { - try + NpgsqlCommand cmd = connection.CreateCommand(); + await using (cmd.ConfigureAwait(false)) { - NpgsqlCommand cmd = connection.CreateCommand(); - await using (cmd.ConfigureAwait(false)) - { #pragma warning disable CA2100 // SQL reviewed - cmd.CommandText = $@" - SELECT table_name - FROM information_schema.tables - WHERE table_schema = @schema - AND table_name = @table - AND table_type = 'BASE TABLE' - LIMIT 1 - "; - - cmd.Parameters.AddWithValue("@schema", this._schema); - cmd.Parameters.AddWithValue("@table", tableName); + cmd.CommandText = $@" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = @schema + AND table_name = @table + AND table_type = 'BASE TABLE' + LIMIT 1 + "; + + cmd.Parameters.AddWithValue("@schema", this._schema); + cmd.Parameters.AddWithValue("@table", tableName); #pragma warning restore CA2100 - this._log.LogTrace("Schema: {0}, Table: {1}, SQL: {2}", this._schema, tableName, cmd.CommandText); + this._log.LogTrace("Schema: {0}, Table: {1}, SQL: {2}", this._schema, tableName, cmd.CommandText); - NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - await using (dataReader.ConfigureAwait(false)) + NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + await using (dataReader.ConfigureAwait(false)) + { + if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) { - if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) - { - var name = dataReader.GetString(dataReader.GetOrdinal("table_name")); - - return string.Equals(name, tableName, StringComparison.OrdinalIgnoreCase); - } + var name = dataReader.GetString(dataReader.GetOrdinal("table_name")); - this._log.LogTrace("Table {0} does not exist", tableName); - return false; + return string.Equals(name, tableName, StringComparison.OrdinalIgnoreCase); } + + this._log.LogTrace("Table {0} does not exist", tableName); + return false; } } - finally - { - await connection.CloseAsync().ConfigureAwait(false); - } + } + finally + { + await connection.CloseAsync().ConfigureAwait(false); } } } @@ -162,70 +160,67 @@ public async Task CreateTableAsync( Npgsql.PostgresException? createErr = null; - var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); - await using (dataSource.ConfigureAwait(false)) + NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); + await using (connection) { - await using (connection) + try { - try + NpgsqlCommand cmd = connection.CreateCommand(); + await using (cmd.ConfigureAwait(false)) { - NpgsqlCommand cmd = connection.CreateCommand(); - await using (cmd.ConfigureAwait(false)) - { - var lockId = GenLockId(tableName); + var lockId = GenLockId(tableName); #pragma warning disable CA2100 // SQL reviewed - if (!string.IsNullOrEmpty(this._createTableSql)) - { - cmd.CommandText = this._createTableSql - .Replace(PostgresConfig.SqlPlaceholdersTableName, tableName, StringComparison.Ordinal) - .Replace(PostgresConfig.SqlPlaceholdersVectorSize, $"{vectorSize}", StringComparison.Ordinal) - .Replace(PostgresConfig.SqlPlaceholdersLockId, $"{lockId}", StringComparison.Ordinal); + if (!string.IsNullOrEmpty(this._createTableSql)) + { + cmd.CommandText = this._createTableSql + .Replace(PostgresConfig.SqlPlaceholdersTableName, tableName, StringComparison.Ordinal) + .Replace(PostgresConfig.SqlPlaceholdersVectorSize, $"{vectorSize}", StringComparison.Ordinal) + .Replace(PostgresConfig.SqlPlaceholdersLockId, $"{lockId}", StringComparison.Ordinal); - this._log.LogTrace("Creating table with custom SQL: {0}", cmd.CommandText); - } - else - { - cmd.CommandText = $@" - BEGIN; - SELECT pg_advisory_xact_lock({lockId}); - CREATE TABLE IF NOT EXISTS {tableName} ( - {this._colId} TEXT NOT NULL PRIMARY KEY, - {this._colEmbedding} vector({vectorSize}), - {this._colTags} TEXT[] DEFAULT '{{}}'::TEXT[] NOT NULL, - {this._colContent} TEXT DEFAULT '' NOT NULL, - {this._colPayload} JSONB DEFAULT '{{}}'::JSONB NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_tags ON {tableName} USING GIN({this._colTags}); - COMMIT; - "; + this._log.LogTrace("Creating table with custom SQL: {0}", cmd.CommandText); + } + else + { + cmd.CommandText = $@" + BEGIN; + SELECT pg_advisory_xact_lock({lockId}); + CREATE TABLE IF NOT EXISTS {tableName} ( + {this._colId} TEXT NOT NULL PRIMARY KEY, + {this._colEmbedding} vector({vectorSize}), + {this._colTags} TEXT[] DEFAULT '{{}}'::TEXT[] NOT NULL, + {this._colContent} TEXT DEFAULT '' NOT NULL, + {this._colPayload} JSONB DEFAULT '{{}}'::JSONB NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_tags ON {tableName} USING GIN({this._colTags}); + COMMIT; + "; #pragma warning restore CA2100 - this._log.LogTrace("Creating table with default SQL: {0}", cmd.CommandText); - } - - int result = await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - this._log.LogTrace("Table '{0}' creation result: {1}", tableName, result); + this._log.LogTrace("Creating table with default SQL: {0}", cmd.CommandText); } + + int result = await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + this._log.LogTrace("Table '{0}' creation result: {1}", tableName, result); } - catch (Npgsql.PostgresException e) when (IsVectorTypeDoesNotExistException(e)) - { - this._log.LogError(e, "Vector type not installed, check 'SELECT * FROM pg_extension'"); - throw; - } - catch (Npgsql.PostgresException e) when (e.SqlState == PgErrUniqueViolation) - { - createErr = e; - } - catch (Exception e) - { - this._log.LogError(e, "Table '{0}' creation error: {1}. Err: {2}. InnerEx: {3}", tableName, e, e.Message, e.InnerException); - throw; - } - finally - { - await connection.CloseAsync().ConfigureAwait(false); - } + } + catch (Npgsql.PostgresException e) when (IsVectorTypeDoesNotExistException(e)) + { + this._log.LogError(e, "Vector type not installed, check 'SELECT * FROM pg_extension'"); + throw; + } + catch (Npgsql.PostgresException e) when (e.SqlState == PgErrUniqueViolation) + { + createErr = e; + } + catch (Exception e) + { + this._log.LogError(e, "Table '{0}' creation error: {1}. Err: {2}. InnerEx: {3}", tableName, e, e.Message, e.InnerException); + throw; + } + finally + { + await connection.CloseAsync().ConfigureAwait(false); } } @@ -267,40 +262,37 @@ public async Task CreateTableAsync( public async IAsyncEnumerable GetTablesAsync( [EnumeratorCancellation] CancellationToken cancellationToken = default) { - var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); - await using (dataSource.ConfigureAwait(false)) + NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); + await using (connection) { - await using (connection) + try { - try + NpgsqlCommand cmd = connection.CreateCommand(); + await using (cmd.ConfigureAwait(false)) { - NpgsqlCommand cmd = connection.CreateCommand(); - await using (cmd.ConfigureAwait(false)) - { - cmd.CommandText = @"SELECT table_name FROM information_schema.tables + cmd.CommandText = @"SELECT table_name FROM information_schema.tables WHERE table_schema = @schema AND table_type = 'BASE TABLE';"; - cmd.Parameters.AddWithValue("@schema", this._schema); + cmd.Parameters.AddWithValue("@schema", this._schema); - this._log.LogTrace("Fetching list of tables. SQL: {0}. Schema: {1}", cmd.CommandText, this._schema); + this._log.LogTrace("Fetching list of tables. SQL: {0}. Schema: {1}", cmd.CommandText, this._schema); - NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - await using (dataReader.ConfigureAwait(false)) + NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + await using (dataReader.ConfigureAwait(false)) + { + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) { - while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + var tableNameWithPrefix = dataReader.GetString(dataReader.GetOrdinal("table_name")); + if (tableNameWithPrefix.StartsWith(this._tableNamePrefix, StringComparison.OrdinalIgnoreCase)) { - var tableNameWithPrefix = dataReader.GetString(dataReader.GetOrdinal("table_name")); - if (tableNameWithPrefix.StartsWith(this._tableNamePrefix, StringComparison.OrdinalIgnoreCase)) - { - yield return tableNameWithPrefix.Remove(0, this._tableNamePrefix.Length); - } + yield return tableNameWithPrefix.Remove(0, this._tableNamePrefix.Length); } } } } - finally - { - await connection.CloseAsync().ConfigureAwait(false); - } + } + finally + { + await connection.CloseAsync().ConfigureAwait(false); } } } @@ -316,33 +308,30 @@ public async Task DeleteTableAsync( { tableName = this.WithSchemaAndTableNamePrefix(tableName); - var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); - await using (dataSource.ConfigureAwait(false)) + NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); + await using (connection) { - await using (connection) + try { - try + NpgsqlCommand cmd = connection.CreateCommand(); + await using (cmd.ConfigureAwait(false)) { - NpgsqlCommand cmd = connection.CreateCommand(); - await using (cmd.ConfigureAwait(false)) - { #pragma warning disable CA2100 // SQL reviewed - cmd.CommandText = $"DROP TABLE IF EXISTS {tableName}"; + cmd.CommandText = $"DROP TABLE IF EXISTS {tableName}"; #pragma warning restore CA2100 - this._log.LogTrace("Deleting table. SQL: {0}", cmd.CommandText); - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } - } - catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) - { - this._log.LogTrace("Table not found: {0}", tableName); - } - finally - { - await connection.CloseAsync().ConfigureAwait(false); + this._log.LogTrace("Deleting table. SQL: {0}", cmd.CommandText); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } } + catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) + { + this._log.LogTrace("Table not found: {0}", tableName); + } + finally + { + await connection.CloseAsync().ConfigureAwait(false); + } } } @@ -363,55 +352,52 @@ public async Task UpsertAsync( const string EmptyContent = ""; string[] emptyTags = []; - var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); - await using (dataSource.ConfigureAwait(false)) + NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); + await using (connection) { - await using (connection) + try { - try + NpgsqlCommand cmd = connection.CreateCommand(); + await using (cmd.ConfigureAwait(false)) { - NpgsqlCommand cmd = connection.CreateCommand(); - await using (cmd.ConfigureAwait(false)) - { #pragma warning disable CA2100 // SQL reviewed - cmd.CommandText = $@" - INSERT INTO {tableName} - ({this._colId}, {this._colEmbedding}, {this._colTags}, {this._colContent}, {this._colPayload}) - VALUES - (@id, @embedding, @tags, @content, @payload) - ON CONFLICT ({this._colId}) - DO UPDATE SET - {this._colEmbedding} = @embedding, - {this._colTags} = @tags, - {this._colContent} = @content, - {this._colPayload} = @payload - "; - - cmd.Parameters.AddWithValue("@id", record.Id); - cmd.Parameters.AddWithValue("@embedding", record.Embedding); - cmd.Parameters.AddWithValue("@tags", NpgsqlDbType.Array | NpgsqlDbType.Text, record.Tags.ToArray() ?? emptyTags); - cmd.Parameters.AddWithValue("@content", NpgsqlDbType.Text, CleanContent(record.Content) ?? EmptyContent); - cmd.Parameters.AddWithValue("@payload", NpgsqlDbType.Jsonb, record.Payload ?? EmptyPayload); + cmd.CommandText = $@" + INSERT INTO {tableName} + ({this._colId}, {this._colEmbedding}, {this._colTags}, {this._colContent}, {this._colPayload}) + VALUES + (@id, @embedding, @tags, @content, @payload) + ON CONFLICT ({this._colId}) + DO UPDATE SET + {this._colEmbedding} = @embedding, + {this._colTags} = @tags, + {this._colContent} = @content, + {this._colPayload} = @payload + "; + + cmd.Parameters.AddWithValue("@id", record.Id); + cmd.Parameters.AddWithValue("@embedding", record.Embedding); + cmd.Parameters.AddWithValue("@tags", NpgsqlDbType.Array | NpgsqlDbType.Text, record.Tags.ToArray() ?? emptyTags); + cmd.Parameters.AddWithValue("@content", NpgsqlDbType.Text, CleanContent(record.Content) ?? EmptyContent); + cmd.Parameters.AddWithValue("@payload", NpgsqlDbType.Jsonb, record.Payload ?? EmptyPayload); #pragma warning restore CA2100 - this._log.LogTrace("Upserting record '{0}' in table '{1}'", record.Id, tableName); + this._log.LogTrace("Upserting record '{0}' in table '{1}'", record.Id, tableName); - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } - } - catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) - { - throw new IndexNotFoundException(e.Message, e); - } - catch (Exception e) - { - throw new PostgresException(e.Message, e); - } - finally - { - await connection.CloseAsync().ConfigureAwait(false); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } } + catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) + { + throw new IndexNotFoundException(e.Message, e); + } + catch (Exception e) + { + throw new PostgresException(e.Message, e); + } + finally + { + await connection.CloseAsync().ConfigureAwait(false); + } } } @@ -460,71 +446,68 @@ DO UPDATE SET this._log.LogTrace("Searching by similarity. Table: {0}. Threshold: {1}. Limit: {2}. Offset: {3}. Using SQL filter: {4}", tableName, minSimilarity, limit, offset, string.IsNullOrWhiteSpace(filterSql) ? "false" : "true"); - var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); - await using (dataSource.ConfigureAwait(false)) + NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); + await using (connection) { - await using (connection) + try { - try + NpgsqlCommand cmd = connection.CreateCommand(); + await using (cmd.ConfigureAwait(false)) { - NpgsqlCommand cmd = connection.CreateCommand(); - await using (cmd.ConfigureAwait(false)) - { #pragma warning disable CA2100 // SQL reviewed - string colDistance = "__distance"; - - // When using 1 - (embedding <=> target) the index is not being used, therefore we calculate - // the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause. - cmd.CommandText = @$" - SELECT {columns}, {this._colEmbedding} <=> @embedding AS {colDistance} - FROM {tableName} - WHERE {filterSql} - ORDER BY {colDistance} ASC - LIMIT @limit - OFFSET @offset - "; - - cmd.Parameters.AddWithValue("@embedding", target); - cmd.Parameters.AddWithValue("@maxDistance", maxDistance); - cmd.Parameters.AddWithValue("@limit", limit); - cmd.Parameters.AddWithValue("@offset", offset); - - foreach (KeyValuePair kv in sqlUserValues) - { - cmd.Parameters.AddWithValue(kv.Key, kv.Value); - } + string colDistance = "__distance"; + + // When using 1 - (embedding <=> target) the index is not being used, therefore we calculate + // the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause. + cmd.CommandText = @$" + SELECT {columns}, {this._colEmbedding} <=> @embedding AS {colDistance} + FROM {tableName} + WHERE {filterSql} + ORDER BY {colDistance} ASC + LIMIT @limit + OFFSET @offset + "; + + cmd.Parameters.AddWithValue("@embedding", target); + cmd.Parameters.AddWithValue("@maxDistance", maxDistance); + cmd.Parameters.AddWithValue("@limit", limit); + cmd.Parameters.AddWithValue("@offset", offset); + + foreach (KeyValuePair kv in sqlUserValues) + { + cmd.Parameters.AddWithValue(kv.Key, kv.Value); + } #pragma warning restore CA2100 - // TODO: rewrite code to stream results (need to combine yield and try-catch) - var result = new List<(PostgresMemoryRecord record, double similarity)>(); - try + // TODO: rewrite code to stream results (need to combine yield and try-catch) + var result = new List<(PostgresMemoryRecord record, double similarity)>(); + try + { + NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + await using (dataReader.ConfigureAwait(false)) { - NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - await using (dataReader.ConfigureAwait(false)) + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) { - while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) - { - double distance = dataReader.GetDouble(dataReader.GetOrdinal(colDistance)); - double similarity = 1 - distance; - result.Add((this.ReadEntry(dataReader, withEmbeddings), similarity)); - } + double distance = dataReader.GetDouble(dataReader.GetOrdinal(colDistance)); + double similarity = 1 - distance; + result.Add((this.ReadEntry(dataReader, withEmbeddings), similarity)); } } - catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) - { - this._log.LogTrace("Table not found: {0}", tableName); - } + } + catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) + { + this._log.LogTrace("Table not found: {0}", tableName); + } - // TODO: rewrite code to stream results (need to combine yield and try-catch) - foreach (var x in result) - { - yield return x; - } + // TODO: rewrite code to stream results (need to combine yield and try-catch) + foreach (var x in result) + { + yield return x; } } - finally - { - await connection.CloseAsync().ConfigureAwait(false); - } + } + finally + { + await connection.CloseAsync().ConfigureAwait(false); } } } @@ -572,66 +555,63 @@ public async IAsyncEnumerable GetListAsync( this._log.LogTrace("Fetching list of records. Table: {0}. Order by: {1}. Limit: {2}. Offset: {3}. Using SQL filter: {4}", tableName, orderBySql, limit, offset, string.IsNullOrWhiteSpace(filterSql) ? "false" : "true"); - var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); - await using (dataSource.ConfigureAwait(false)) + NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); + await using (connection) { - await using (connection) + try { - try + NpgsqlCommand cmd = connection.CreateCommand(); + await using (cmd.ConfigureAwait(false)) { - NpgsqlCommand cmd = connection.CreateCommand(); - await using (cmd.ConfigureAwait(false)) - { #pragma warning disable CA2100 // SQL reviewed - cmd.CommandText = @$" - SELECT {columns} FROM {tableName} - WHERE {filterSql} - ORDER BY {orderBySql} - LIMIT @limit - OFFSET @offset - "; - - cmd.Parameters.AddWithValue("@limit", limit); - cmd.Parameters.AddWithValue("@offset", offset); - - if (sqlUserValues != null) + cmd.CommandText = @$" + SELECT {columns} FROM {tableName} + WHERE {filterSql} + ORDER BY {orderBySql} + LIMIT @limit + OFFSET @offset + "; + + cmd.Parameters.AddWithValue("@limit", limit); + cmd.Parameters.AddWithValue("@offset", offset); + + if (sqlUserValues != null) + { + foreach (KeyValuePair kv in sqlUserValues) { - foreach (KeyValuePair kv in sqlUserValues) - { - cmd.Parameters.AddWithValue(kv.Key, kv.Value); - } + cmd.Parameters.AddWithValue(kv.Key, kv.Value); } + } #pragma warning restore CA2100 - // TODO: rewrite code to stream results (need to combine yield and try-catch) - var result = new List(); - try + // TODO: rewrite code to stream results (need to combine yield and try-catch) + var result = new List(); + try + { + NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + await using (dataReader.ConfigureAwait(false)) { - NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - await using (dataReader.ConfigureAwait(false)) + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) { - while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) - { - result.Add(this.ReadEntry(dataReader, withEmbeddings)); - } + result.Add(this.ReadEntry(dataReader, withEmbeddings)); } } - catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) - { - this._log.LogTrace("Table not found: {0}", tableName); - } + } + catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) + { + this._log.LogTrace("Table not found: {0}", tableName); + } - // TODO: rewrite code to stream results (need to combine yield and try-catch) - foreach (var x in result) - { - yield return x; - } + // TODO: rewrite code to stream results (need to combine yield and try-catch) + foreach (var x in result) + { + yield return x; } } - finally - { - await connection.CloseAsync().ConfigureAwait(false); - } + } + finally + { + await connection.CloseAsync().ConfigureAwait(false); } } } @@ -650,36 +630,51 @@ public async Task DeleteAsync( tableName = this.WithSchemaAndTableNamePrefix(tableName); this._log.LogTrace("Deleting record '{0}' from table '{1}'", id, tableName); - var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); - await using (dataSource.ConfigureAwait(false)) + NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); + await using (connection) { - await using (connection) + try { - try + NpgsqlCommand cmd = connection.CreateCommand(); + await using (cmd.ConfigureAwait(false)) { - NpgsqlCommand cmd = connection.CreateCommand(); - await using (cmd.ConfigureAwait(false)) - { #pragma warning disable CA2100 // SQL reviewed - cmd.CommandText = $"DELETE FROM {tableName} WHERE {this._colId}=@id"; - cmd.Parameters.AddWithValue("@id", id); + cmd.CommandText = $"DELETE FROM {tableName} WHERE {this._colId}=@id"; + cmd.Parameters.AddWithValue("@id", id); #pragma warning restore CA2100 - try - { - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } - catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) - { - this._log.LogTrace("Table not found: {0}", tableName); - } + try + { + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) + { + this._log.LogTrace("Table not found: {0}", tableName); } - } - finally - { - await connection.CloseAsync().ConfigureAwait(false); } } + finally + { + await connection.CloseAsync().ConfigureAwait(false); + } + } + } + + /// + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// Disposes the managed resources + /// + private void Dispose(bool disposing) + { + if (disposing) + { + (this._dataSource as IDisposable)?.Dispose(); } } @@ -688,12 +683,11 @@ public async Task DeleteAsync( /// /// /// - private async Task<(NpgsqlDataSource DataSource, NpgsqlConnection Connection)> ConnectAsync(CancellationToken cancellationToken = default) + private async Task ConnectAsync(CancellationToken cancellationToken = default) { try { - var dataSource = this._dataSourceBuilder.Build(); - return (dataSource, await dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false)); + return await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); } catch (Npgsql.PostgresException e) when (IsDbNotFoundException(e)) { @@ -786,4 +780,4 @@ private static long GenLockId(string resourceId) return BitConverter.ToUInt32(SHA256.HashData(Encoding.UTF8.GetBytes(resourceId)), 0) % short.MaxValue; } -} +} \ No newline at end of file diff --git a/extensions/Postgres/Postgres/PostgresMemory.cs b/extensions/Postgres/Postgres/PostgresMemory.cs index 8e6a6bac3..fd20311b0 100644 --- a/extensions/Postgres/Postgres/PostgresMemory.cs +++ b/extensions/Postgres/Postgres/PostgresMemory.cs @@ -21,7 +21,7 @@ namespace Microsoft.KernelMemory.Postgres; /// Postgres connector for Kernel Memory. /// [Experimental("KMEXP03")] -public sealed class PostgresMemory : IMemoryDb +public sealed class PostgresMemory : IMemoryDb, IDisposable { private readonly ILogger _log; private readonly ITextEmbeddingGenerator _embeddingGenerator; @@ -209,6 +209,24 @@ public Task DeleteAsync( return this._db.DeleteAsync(tableName: index, id: record.Id, cancellationToken); } + /// + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// Disposes the managed resources. + /// + private void Dispose(bool disposing) + { + if (disposing) + { + (this._db as IDisposable)?.Dispose(); + } + } + #region private ================================================================================ // Note: "_" is allowed in Postgres, but we normalize it to "-" for consistency with other DBs @@ -288,4 +306,4 @@ private static string NormalizeTableNamePrefix(string? name) } #endregion -} +} \ No newline at end of file