diff --git a/Directory.Packages.props b/Directory.Packages.props index ca6246ee6..41efaafd0 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -42,7 +42,6 @@ - @@ -78,11 +77,11 @@ all runtime; build; native; contentfiles; analyzers; buildtransitive - + all runtime; build; native; contentfiles; analyzers; buildtransitive - + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/extensions/Postgres/Postgres.FunctionalTests/ConcurrencyTests.cs b/extensions/Postgres/Postgres.FunctionalTests/ConcurrencyTests.cs index 7e68ea2e8..2fc2a68fa 100644 --- a/extensions/Postgres/Postgres.FunctionalTests/ConcurrencyTests.cs +++ b/extensions/Postgres/Postgres.FunctionalTests/ConcurrencyTests.cs @@ -69,7 +69,7 @@ last_update TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL var indexName = "create_index_test"; var vectorSize = 1536; - var target = new PostgresMemory(config, new FakeEmbeddingGenerator()); + using var target = new PostgresMemory(config, new FakeEmbeddingGenerator()); var tasks = new List(); for (int i = 0; i < concurrency; i++) @@ -96,7 +96,7 @@ public async Task UpsertConcurrencyTest() var vectorSize = 4; var indexName = "upsert_test" + Guid.NewGuid().ToString("D"); - var target = new PostgresMemory(this.PostgresConfig, new FakeEmbeddingGenerator()); + using var target = new PostgresMemory(this.PostgresConfig, new FakeEmbeddingGenerator()); await target.CreateIndexAsync(indexName, vectorSize); diff --git a/extensions/Postgres/Postgres/Internals/PostgresDbClient.cs b/extensions/Postgres/Postgres/Internals/PostgresDbClient.cs index f20f8c0ee..7d05b4b92 100644 --- a/extensions/Postgres/Postgres/Internals/PostgresDbClient.cs +++ b/extensions/Postgres/Postgres/Internals/PostgresDbClient.cs @@ -19,28 +19,11 @@ namespace Microsoft.KernelMemory.Postgres; /// /// An implementation of a client for Postgres. This class is used to managing postgres database operations. /// -internal sealed class PostgresDbClient +internal sealed class PostgresDbClient : IDisposable, IAsyncDisposable { - // See: https://www.postgresql.org/docs/current/errcodes-appendix.html - private const string PgErrUndefinedTable = "42P01"; // undefined_table - private const string PgErrUniqueViolation = "23505"; // unique_violation - private const string PgErrTypeDoesNotExist = "42704"; // undefined_object - private const string PgErrDatabaseDoesNotExist = "3D000"; // invalid_catalog_name - + // Dependencies + private readonly NpgsqlDataSource _dataSource; private readonly ILogger _log; - private readonly NpgsqlDataSourceBuilder _dataSourceBuilder; - - private readonly string _schema; - private readonly string _tableNamePrefix; - private readonly string _createTableSql; - private readonly string _colId; - private readonly string _colEmbedding; - private readonly string _colTags; - private readonly string _colContent; - private readonly string _colPayload; - private readonly string _columnsListNoEmbeddings; - private readonly string _columnsListWithEmbeddings; - private readonly bool _dbNamePresent; /// /// Initializes a new instance of the class. @@ -52,8 +35,9 @@ public PostgresDbClient(PostgresConfig config, ILoggerFactory? loggerFactory = n config.Validate(); this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger(); - this._dataSourceBuilder = new(config.ConnectionString); - this._dataSourceBuilder.UseVector(); + NpgsqlDataSourceBuilder dataSourceBuilder = new(config.ConnectionString); + dataSourceBuilder.UseVector(); + this._dataSource = dataSourceBuilder.Build(); this._dbNamePresent = config.ConnectionString.Contains("Database=", StringComparison.OrdinalIgnoreCase); this._schema = config.Schema; @@ -96,51 +80,48 @@ public async Task DoesTableExistAsync( tableName = this.WithTableNamePrefix(tableName); this._log.LogTrace("Checking if table {0} exists", tableName); - var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); - await using (dataSource.ConfigureAwait(false)) + NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); + await using (connection) { - await using (connection) + try { - try + NpgsqlCommand cmd = connection.CreateCommand(); + await using (cmd.ConfigureAwait(false)) { - NpgsqlCommand cmd = connection.CreateCommand(); - await using (cmd.ConfigureAwait(false)) - { #pragma warning disable CA2100 // SQL reviewed - cmd.CommandText = $@" - SELECT table_name - FROM information_schema.tables - WHERE table_schema = @schema - AND table_name = @table - AND table_type = 'BASE TABLE' - LIMIT 1 - "; - - cmd.Parameters.AddWithValue("@schema", this._schema); - cmd.Parameters.AddWithValue("@table", tableName); + cmd.CommandText = $@" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = @schema + AND table_name = @table + AND table_type = 'BASE TABLE' + LIMIT 1 + "; + + cmd.Parameters.AddWithValue("@schema", this._schema); + cmd.Parameters.AddWithValue("@table", tableName); #pragma warning restore CA2100 - this._log.LogTrace("Schema: {0}, Table: {1}, SQL: {2}", this._schema, tableName, cmd.CommandText); + this._log.LogTrace("Schema: {0}, Table: {1}, SQL: {2}", this._schema, tableName, cmd.CommandText); - NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - await using (dataReader.ConfigureAwait(false)) + NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + await using (dataReader.ConfigureAwait(false)) + { + if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) { - if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) - { - var name = dataReader.GetString(dataReader.GetOrdinal("table_name")); - - return string.Equals(name, tableName, StringComparison.OrdinalIgnoreCase); - } + var name = dataReader.GetString(dataReader.GetOrdinal("table_name")); - this._log.LogTrace("Table {0} does not exist", tableName); - return false; + return string.Equals(name, tableName, StringComparison.OrdinalIgnoreCase); } + + this._log.LogTrace("Table {0} does not exist", tableName); + return false; } } - finally - { - await connection.CloseAsync().ConfigureAwait(false); - } + } + finally + { + await connection.CloseAsync().ConfigureAwait(false); } } } @@ -162,71 +143,68 @@ public async Task CreateTableAsync( Npgsql.PostgresException? createErr = null; - var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); - await using (dataSource.ConfigureAwait(false)) + NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); + await using (connection) { - await using (connection) + try { - try + NpgsqlCommand cmd = connection.CreateCommand(); + await using (cmd.ConfigureAwait(false)) { - NpgsqlCommand cmd = connection.CreateCommand(); - await using (cmd.ConfigureAwait(false)) - { - var lockId = GenLockId(tableName); + var lockId = GenLockId(tableName); #pragma warning disable CA2100 // SQL reviewed - if (!string.IsNullOrEmpty(this._createTableSql)) - { - cmd.CommandText = this._createTableSql - .Replace(PostgresConfig.SqlPlaceholdersTableName, tableName, StringComparison.Ordinal) - .Replace(PostgresConfig.SqlPlaceholdersVectorSize, $"{vectorSize}", StringComparison.Ordinal) - .Replace(PostgresConfig.SqlPlaceholdersLockId, $"{lockId}", StringComparison.Ordinal); + if (!string.IsNullOrEmpty(this._createTableSql)) + { + cmd.CommandText = this._createTableSql + .Replace(PostgresConfig.SqlPlaceholdersTableName, tableName, StringComparison.Ordinal) + .Replace(PostgresConfig.SqlPlaceholdersVectorSize, $"{vectorSize}", StringComparison.Ordinal) + .Replace(PostgresConfig.SqlPlaceholdersLockId, $"{lockId}", StringComparison.Ordinal); - this._log.LogTrace("Creating table with custom SQL: {0}", cmd.CommandText); - } - else - { - cmd.CommandText = $@" - BEGIN; - SELECT pg_advisory_xact_lock({lockId}); - CREATE TABLE IF NOT EXISTS {tableName} ( - {this._colId} TEXT NOT NULL PRIMARY KEY, - {this._colEmbedding} vector({vectorSize}), - {this._colTags} TEXT[] DEFAULT '{{}}'::TEXT[] NOT NULL, - {this._colContent} TEXT DEFAULT '' NOT NULL, - {this._colPayload} JSONB DEFAULT '{{}}'::JSONB NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_tags ON {tableName} USING GIN({this._colTags}); - COMMIT; - "; + this._log.LogTrace("Creating table with custom SQL: {0}", cmd.CommandText); + } + else + { + cmd.CommandText = $@" + BEGIN; + SELECT pg_advisory_xact_lock({lockId}); + CREATE TABLE IF NOT EXISTS {tableName} ( + {this._colId} TEXT NOT NULL PRIMARY KEY, + {this._colEmbedding} vector({vectorSize}), + {this._colTags} TEXT[] DEFAULT '{{}}'::TEXT[] NOT NULL, + {this._colContent} TEXT DEFAULT '' NOT NULL, + {this._colPayload} JSONB DEFAULT '{{}}'::JSONB NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_tags ON {tableName} USING GIN({this._colTags}); + COMMIT; + "; #pragma warning restore CA2100 - this._log.LogTrace("Creating table with default SQL: {0}", cmd.CommandText); - } - - int result = await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - this._log.LogTrace("Table '{0}' creation result: {1}", tableName, result); + this._log.LogTrace("Creating table with default SQL: {0}", cmd.CommandText); } - } - catch (Npgsql.PostgresException e) when (IsVectorTypeDoesNotExistException(e)) - { - this._log.LogError(e, "Vector type not installed, check 'SELECT * FROM pg_extension'"); - throw; - } - catch (Npgsql.PostgresException e) when (e.SqlState == PgErrUniqueViolation) - { - createErr = e; - } - catch (Exception e) - { - this._log.LogError(e, "Table '{0}' creation error: {1}. Err: {2}. InnerEx: {3}", tableName, e, e.Message, e.InnerException); - throw; - } - finally - { - await connection.CloseAsync().ConfigureAwait(false); + + int result = await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + this._log.LogTrace("Table '{0}' creation result: {1}", tableName, result); } } + catch (Npgsql.PostgresException e) when (IsVectorTypeDoesNotExistException(e)) + { + this._log.LogError(e, "Vector type not installed, check 'SELECT * FROM pg_extension'"); + throw; + } + catch (Npgsql.PostgresException e) when (e.SqlState == PgErrUniqueViolation) + { + createErr = e; + } + catch (Exception e) + { + this._log.LogError(e, "Table '{0}' creation error: {1}. Err: {2}. InnerEx: {3}", tableName, e, e.Message, e.InnerException); + throw; + } + finally + { + await connection.CloseAsync().ConfigureAwait(false); + } } if (createErr != null) @@ -267,40 +245,37 @@ public async Task CreateTableAsync( public async IAsyncEnumerable GetTablesAsync( [EnumeratorCancellation] CancellationToken cancellationToken = default) { - var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); - await using (dataSource.ConfigureAwait(false)) + NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); + await using (connection) { - await using (connection) + try { - try + NpgsqlCommand cmd = connection.CreateCommand(); + await using (cmd.ConfigureAwait(false)) { - NpgsqlCommand cmd = connection.CreateCommand(); - await using (cmd.ConfigureAwait(false)) - { - cmd.CommandText = @"SELECT table_name FROM information_schema.tables + cmd.CommandText = @"SELECT table_name FROM information_schema.tables WHERE table_schema = @schema AND table_type = 'BASE TABLE';"; - cmd.Parameters.AddWithValue("@schema", this._schema); + cmd.Parameters.AddWithValue("@schema", this._schema); - this._log.LogTrace("Fetching list of tables. SQL: {0}. Schema: {1}", cmd.CommandText, this._schema); + this._log.LogTrace("Fetching list of tables. SQL: {0}. Schema: {1}", cmd.CommandText, this._schema); - NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - await using (dataReader.ConfigureAwait(false)) + NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + await using (dataReader.ConfigureAwait(false)) + { + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) { - while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + var tableNameWithPrefix = dataReader.GetString(dataReader.GetOrdinal("table_name")); + if (tableNameWithPrefix.StartsWith(this._tableNamePrefix, StringComparison.OrdinalIgnoreCase)) { - var tableNameWithPrefix = dataReader.GetString(dataReader.GetOrdinal("table_name")); - if (tableNameWithPrefix.StartsWith(this._tableNamePrefix, StringComparison.OrdinalIgnoreCase)) - { - yield return tableNameWithPrefix.Remove(0, this._tableNamePrefix.Length); - } + yield return tableNameWithPrefix.Remove(0, this._tableNamePrefix.Length); } } } } - finally - { - await connection.CloseAsync().ConfigureAwait(false); - } + } + finally + { + await connection.CloseAsync().ConfigureAwait(false); } } } @@ -316,33 +291,30 @@ public async Task DeleteTableAsync( { tableName = this.WithSchemaAndTableNamePrefix(tableName); - var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); - await using (dataSource.ConfigureAwait(false)) + NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); + await using (connection) { - await using (connection) + try { - try + NpgsqlCommand cmd = connection.CreateCommand(); + await using (cmd.ConfigureAwait(false)) { - NpgsqlCommand cmd = connection.CreateCommand(); - await using (cmd.ConfigureAwait(false)) - { #pragma warning disable CA2100 // SQL reviewed - cmd.CommandText = $"DROP TABLE IF EXISTS {tableName}"; + cmd.CommandText = $"DROP TABLE IF EXISTS {tableName}"; #pragma warning restore CA2100 - this._log.LogTrace("Deleting table. SQL: {0}", cmd.CommandText); - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } - } - catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) - { - this._log.LogTrace("Table not found: {0}", tableName); - } - finally - { - await connection.CloseAsync().ConfigureAwait(false); + this._log.LogTrace("Deleting table. SQL: {0}", cmd.CommandText); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } } + catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) + { + this._log.LogTrace("Table not found: {0}", tableName); + } + finally + { + await connection.CloseAsync().ConfigureAwait(false); + } } } @@ -363,55 +335,52 @@ public async Task UpsertAsync( const string EmptyContent = ""; string[] emptyTags = []; - var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); - await using (dataSource.ConfigureAwait(false)) + NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); + await using (connection) { - await using (connection) + try { - try + NpgsqlCommand cmd = connection.CreateCommand(); + await using (cmd.ConfigureAwait(false)) { - NpgsqlCommand cmd = connection.CreateCommand(); - await using (cmd.ConfigureAwait(false)) - { #pragma warning disable CA2100 // SQL reviewed - cmd.CommandText = $@" - INSERT INTO {tableName} - ({this._colId}, {this._colEmbedding}, {this._colTags}, {this._colContent}, {this._colPayload}) - VALUES - (@id, @embedding, @tags, @content, @payload) - ON CONFLICT ({this._colId}) - DO UPDATE SET - {this._colEmbedding} = @embedding, - {this._colTags} = @tags, - {this._colContent} = @content, - {this._colPayload} = @payload - "; - - cmd.Parameters.AddWithValue("@id", record.Id); - cmd.Parameters.AddWithValue("@embedding", record.Embedding); - cmd.Parameters.AddWithValue("@tags", NpgsqlDbType.Array | NpgsqlDbType.Text, record.Tags.ToArray() ?? emptyTags); - cmd.Parameters.AddWithValue("@content", NpgsqlDbType.Text, CleanContent(record.Content) ?? EmptyContent); - cmd.Parameters.AddWithValue("@payload", NpgsqlDbType.Jsonb, record.Payload ?? EmptyPayload); + cmd.CommandText = $@" + INSERT INTO {tableName} + ({this._colId}, {this._colEmbedding}, {this._colTags}, {this._colContent}, {this._colPayload}) + VALUES + (@id, @embedding, @tags, @content, @payload) + ON CONFLICT ({this._colId}) + DO UPDATE SET + {this._colEmbedding} = @embedding, + {this._colTags} = @tags, + {this._colContent} = @content, + {this._colPayload} = @payload + "; + + cmd.Parameters.AddWithValue("@id", record.Id); + cmd.Parameters.AddWithValue("@embedding", record.Embedding); + cmd.Parameters.AddWithValue("@tags", NpgsqlDbType.Array | NpgsqlDbType.Text, record.Tags.ToArray() ?? emptyTags); + cmd.Parameters.AddWithValue("@content", NpgsqlDbType.Text, CleanContent(record.Content) ?? EmptyContent); + cmd.Parameters.AddWithValue("@payload", NpgsqlDbType.Jsonb, record.Payload ?? EmptyPayload); #pragma warning restore CA2100 - this._log.LogTrace("Upserting record '{0}' in table '{1}'", record.Id, tableName); + this._log.LogTrace("Upserting record '{0}' in table '{1}'", record.Id, tableName); - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } - } - catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) - { - throw new IndexNotFoundException(e.Message, e); - } - catch (Exception e) - { - throw new PostgresException(e.Message, e); - } - finally - { - await connection.CloseAsync().ConfigureAwait(false); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } } + catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) + { + throw new IndexNotFoundException(e.Message, e); + } + catch (Exception e) + { + throw new PostgresException(e.Message, e); + } + finally + { + await connection.CloseAsync().ConfigureAwait(false); + } } } @@ -460,71 +429,68 @@ DO UPDATE SET this._log.LogTrace("Searching by similarity. Table: {0}. Threshold: {1}. Limit: {2}. Offset: {3}. Using SQL filter: {4}", tableName, minSimilarity, limit, offset, string.IsNullOrWhiteSpace(filterSql) ? "false" : "true"); - var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); - await using (dataSource.ConfigureAwait(false)) + NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); + await using (connection) { - await using (connection) + try { - try + NpgsqlCommand cmd = connection.CreateCommand(); + await using (cmd.ConfigureAwait(false)) { - NpgsqlCommand cmd = connection.CreateCommand(); - await using (cmd.ConfigureAwait(false)) - { #pragma warning disable CA2100 // SQL reviewed - string colDistance = "__distance"; - - // When using 1 - (embedding <=> target) the index is not being used, therefore we calculate - // the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause. - cmd.CommandText = @$" - SELECT {columns}, {this._colEmbedding} <=> @embedding AS {colDistance} - FROM {tableName} - WHERE {filterSql} - ORDER BY {colDistance} ASC - LIMIT @limit - OFFSET @offset - "; - - cmd.Parameters.AddWithValue("@embedding", target); - cmd.Parameters.AddWithValue("@maxDistance", maxDistance); - cmd.Parameters.AddWithValue("@limit", limit); - cmd.Parameters.AddWithValue("@offset", offset); - - foreach (KeyValuePair kv in sqlUserValues) - { - cmd.Parameters.AddWithValue(kv.Key, kv.Value); - } + string colDistance = "__distance"; + + // When using 1 - (embedding <=> target) the index is not being used, therefore we calculate + // the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause. + cmd.CommandText = @$" + SELECT {columns}, {this._colEmbedding} <=> @embedding AS {colDistance} + FROM {tableName} + WHERE {filterSql} + ORDER BY {colDistance} ASC + LIMIT @limit + OFFSET @offset + "; + + cmd.Parameters.AddWithValue("@embedding", target); + cmd.Parameters.AddWithValue("@maxDistance", maxDistance); + cmd.Parameters.AddWithValue("@limit", limit); + cmd.Parameters.AddWithValue("@offset", offset); + + foreach (KeyValuePair kv in sqlUserValues) + { + cmd.Parameters.AddWithValue(kv.Key, kv.Value); + } #pragma warning restore CA2100 - // TODO: rewrite code to stream results (need to combine yield and try-catch) - var result = new List<(PostgresMemoryRecord record, double similarity)>(); - try + // TODO: rewrite code to stream results (need to combine yield and try-catch) + var result = new List<(PostgresMemoryRecord record, double similarity)>(); + try + { + NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + await using (dataReader.ConfigureAwait(false)) { - NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - await using (dataReader.ConfigureAwait(false)) + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) { - while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) - { - double distance = dataReader.GetDouble(dataReader.GetOrdinal(colDistance)); - double similarity = 1 - distance; - result.Add((this.ReadEntry(dataReader, withEmbeddings), similarity)); - } + double distance = dataReader.GetDouble(dataReader.GetOrdinal(colDistance)); + double similarity = 1 - distance; + result.Add((this.ReadEntry(dataReader, withEmbeddings), similarity)); } } - catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) - { - this._log.LogTrace("Table not found: {0}", tableName); - } + } + catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) + { + this._log.LogTrace("Table not found: {0}", tableName); + } - // TODO: rewrite code to stream results (need to combine yield and try-catch) - foreach (var x in result) - { - yield return x; - } + // TODO: rewrite code to stream results (need to combine yield and try-catch) + foreach (var x in result) + { + yield return x; } } - finally - { - await connection.CloseAsync().ConfigureAwait(false); - } + } + finally + { + await connection.CloseAsync().ConfigureAwait(false); } } } @@ -572,66 +538,63 @@ public async IAsyncEnumerable GetListAsync( this._log.LogTrace("Fetching list of records. Table: {0}. Order by: {1}. Limit: {2}. Offset: {3}. Using SQL filter: {4}", tableName, orderBySql, limit, offset, string.IsNullOrWhiteSpace(filterSql) ? "false" : "true"); - var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); - await using (dataSource.ConfigureAwait(false)) + NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); + await using (connection) { - await using (connection) + try { - try + NpgsqlCommand cmd = connection.CreateCommand(); + await using (cmd.ConfigureAwait(false)) { - NpgsqlCommand cmd = connection.CreateCommand(); - await using (cmd.ConfigureAwait(false)) - { #pragma warning disable CA2100 // SQL reviewed - cmd.CommandText = @$" - SELECT {columns} FROM {tableName} - WHERE {filterSql} - ORDER BY {orderBySql} - LIMIT @limit - OFFSET @offset - "; - - cmd.Parameters.AddWithValue("@limit", limit); - cmd.Parameters.AddWithValue("@offset", offset); - - if (sqlUserValues != null) + cmd.CommandText = @$" + SELECT {columns} FROM {tableName} + WHERE {filterSql} + ORDER BY {orderBySql} + LIMIT @limit + OFFSET @offset + "; + + cmd.Parameters.AddWithValue("@limit", limit); + cmd.Parameters.AddWithValue("@offset", offset); + + if (sqlUserValues != null) + { + foreach (KeyValuePair kv in sqlUserValues) { - foreach (KeyValuePair kv in sqlUserValues) - { - cmd.Parameters.AddWithValue(kv.Key, kv.Value); - } + cmd.Parameters.AddWithValue(kv.Key, kv.Value); } + } #pragma warning restore CA2100 - // TODO: rewrite code to stream results (need to combine yield and try-catch) - var result = new List(); - try + // TODO: rewrite code to stream results (need to combine yield and try-catch) + var result = new List(); + try + { + NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + await using (dataReader.ConfigureAwait(false)) { - NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - await using (dataReader.ConfigureAwait(false)) + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) { - while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) - { - result.Add(this.ReadEntry(dataReader, withEmbeddings)); - } + result.Add(this.ReadEntry(dataReader, withEmbeddings)); } } - catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) - { - this._log.LogTrace("Table not found: {0}", tableName); - } + } + catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) + { + this._log.LogTrace("Table not found: {0}", tableName); + } - // TODO: rewrite code to stream results (need to combine yield and try-catch) - foreach (var x in result) - { - yield return x; - } + // TODO: rewrite code to stream results (need to combine yield and try-catch) + foreach (var x in result) + { + yield return x; } } - finally - { - await connection.CloseAsync().ConfigureAwait(false); - } + } + finally + { + await connection.CloseAsync().ConfigureAwait(false); } } } @@ -650,50 +613,84 @@ public async Task DeleteAsync( tableName = this.WithSchemaAndTableNamePrefix(tableName); this._log.LogTrace("Deleting record '{0}' from table '{1}'", id, tableName); - var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); - await using (dataSource.ConfigureAwait(false)) + NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false); + await using (connection) { - await using (connection) + try { - try + NpgsqlCommand cmd = connection.CreateCommand(); + await using (cmd.ConfigureAwait(false)) { - NpgsqlCommand cmd = connection.CreateCommand(); - await using (cmd.ConfigureAwait(false)) - { #pragma warning disable CA2100 // SQL reviewed - cmd.CommandText = $"DELETE FROM {tableName} WHERE {this._colId}=@id"; - cmd.Parameters.AddWithValue("@id", id); + cmd.CommandText = $"DELETE FROM {tableName} WHERE {this._colId}=@id"; + cmd.Parameters.AddWithValue("@id", id); #pragma warning restore CA2100 - try - { - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } - catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) - { - this._log.LogTrace("Table not found: {0}", tableName); - } + try + { + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) + { + this._log.LogTrace("Table not found: {0}", tableName); } - } - finally - { - await connection.CloseAsync().ConfigureAwait(false); } } + finally + { + await connection.CloseAsync().ConfigureAwait(false); + } + } + } + + /// + public void Dispose() + { + this._dataSource?.Dispose(); + } + + /// + public async ValueTask DisposeAsync() + { + try + { + await this._dataSource.DisposeAsync().ConfigureAwait(false); + } + catch (NullReferenceException) + { + // ignore } } + #region private ================================================================================ + + // See: https://www.postgresql.org/docs/current/errcodes-appendix.html + private const string PgErrUndefinedTable = "42P01"; // undefined_table + private const string PgErrUniqueViolation = "23505"; // unique_violation + private const string PgErrTypeDoesNotExist = "42704"; // undefined_object + private const string PgErrDatabaseDoesNotExist = "3D000"; // invalid_catalog_name + + private readonly string _schema; + private readonly string _tableNamePrefix; + private readonly string _createTableSql; + private readonly string _colId; + private readonly string _colEmbedding; + private readonly string _colTags; + private readonly string _colContent; + private readonly string _colPayload; + private readonly string _columnsListNoEmbeddings; + private readonly string _columnsListWithEmbeddings; + private readonly bool _dbNamePresent; + /// /// Try to connect to PG, handling exceptions in case the DB doesn't exist /// /// - /// - private async Task<(NpgsqlDataSource DataSource, NpgsqlConnection Connection)> ConnectAsync(CancellationToken cancellationToken = default) + private async Task ConnectAsync(CancellationToken cancellationToken = default) { try { - var dataSource = this._dataSourceBuilder.Build(); - return (dataSource, await dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false)); + return await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); } catch (Npgsql.PostgresException e) when (IsDbNotFoundException(e)) { @@ -786,4 +783,6 @@ private static long GenLockId(string resourceId) return BitConverter.ToUInt32(SHA256.HashData(Encoding.UTF8.GetBytes(resourceId)), 0) % short.MaxValue; } + + #endregion } diff --git a/extensions/Postgres/Postgres/Postgres.csproj b/extensions/Postgres/Postgres/Postgres.csproj index 2c80dec15..a76e6fcf4 100644 --- a/extensions/Postgres/Postgres/Postgres.csproj +++ b/extensions/Postgres/Postgres/Postgres.csproj @@ -14,7 +14,6 @@ - diff --git a/extensions/Postgres/Postgres/PostgresMemory.cs b/extensions/Postgres/Postgres/PostgresMemory.cs index 8e6a6bac3..bd7854c9c 100644 --- a/extensions/Postgres/Postgres/PostgresMemory.cs +++ b/extensions/Postgres/Postgres/PostgresMemory.cs @@ -21,11 +21,12 @@ namespace Microsoft.KernelMemory.Postgres; /// Postgres connector for Kernel Memory. /// [Experimental("KMEXP03")] -public sealed class PostgresMemory : IMemoryDb +public sealed class PostgresMemory : IMemoryDb, IDisposable, IAsyncDisposable { - private readonly ILogger _log; - private readonly ITextEmbeddingGenerator _embeddingGenerator; + // Dependencies private readonly PostgresDbClient _db; + private readonly ITextEmbeddingGenerator _embeddingGenerator; + private readonly ILogger _log; /// /// Create a new instance of Postgres KM connector @@ -209,6 +210,25 @@ public Task DeleteAsync( return this._db.DeleteAsync(tableName: index, id: record.Id, cancellationToken); } + /// + public void Dispose() + { + this._db?.Dispose(); + } + + /// + public async ValueTask DisposeAsync() + { + try + { + await this._db.DisposeAsync().ConfigureAwait(false); + } + catch (NullReferenceException) + { + // ignore + } + } + #region private ================================================================================ // Note: "_" is allowed in Postgres, but we normalize it to "-" for consistency with other DBs