diff --git a/Directory.Packages.props b/Directory.Packages.props
index ca6246ee6..41efaafd0 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -42,7 +42,6 @@
-
@@ -78,11 +77,11 @@
all
runtime; build; native; contentfiles; analyzers; buildtransitive
-
+
all
runtime; build; native; contentfiles; analyzers; buildtransitive
-
+
all
runtime; build; native; contentfiles; analyzers; buildtransitive
diff --git a/extensions/Postgres/Postgres.FunctionalTests/ConcurrencyTests.cs b/extensions/Postgres/Postgres.FunctionalTests/ConcurrencyTests.cs
index 7e68ea2e8..2fc2a68fa 100644
--- a/extensions/Postgres/Postgres.FunctionalTests/ConcurrencyTests.cs
+++ b/extensions/Postgres/Postgres.FunctionalTests/ConcurrencyTests.cs
@@ -69,7 +69,7 @@ last_update TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL
var indexName = "create_index_test";
var vectorSize = 1536;
- var target = new PostgresMemory(config, new FakeEmbeddingGenerator());
+ using var target = new PostgresMemory(config, new FakeEmbeddingGenerator());
var tasks = new List();
for (int i = 0; i < concurrency; i++)
@@ -96,7 +96,7 @@ public async Task UpsertConcurrencyTest()
var vectorSize = 4;
var indexName = "upsert_test" + Guid.NewGuid().ToString("D");
- var target = new PostgresMemory(this.PostgresConfig, new FakeEmbeddingGenerator());
+ using var target = new PostgresMemory(this.PostgresConfig, new FakeEmbeddingGenerator());
await target.CreateIndexAsync(indexName, vectorSize);
diff --git a/extensions/Postgres/Postgres/Internals/PostgresDbClient.cs b/extensions/Postgres/Postgres/Internals/PostgresDbClient.cs
index f20f8c0ee..7d05b4b92 100644
--- a/extensions/Postgres/Postgres/Internals/PostgresDbClient.cs
+++ b/extensions/Postgres/Postgres/Internals/PostgresDbClient.cs
@@ -19,28 +19,11 @@ namespace Microsoft.KernelMemory.Postgres;
///
/// An implementation of a client for Postgres. This class is used to managing postgres database operations.
///
-internal sealed class PostgresDbClient
+internal sealed class PostgresDbClient : IDisposable, IAsyncDisposable
{
- // See: https://www.postgresql.org/docs/current/errcodes-appendix.html
- private const string PgErrUndefinedTable = "42P01"; // undefined_table
- private const string PgErrUniqueViolation = "23505"; // unique_violation
- private const string PgErrTypeDoesNotExist = "42704"; // undefined_object
- private const string PgErrDatabaseDoesNotExist = "3D000"; // invalid_catalog_name
-
+ // Dependencies
+ private readonly NpgsqlDataSource _dataSource;
private readonly ILogger _log;
- private readonly NpgsqlDataSourceBuilder _dataSourceBuilder;
-
- private readonly string _schema;
- private readonly string _tableNamePrefix;
- private readonly string _createTableSql;
- private readonly string _colId;
- private readonly string _colEmbedding;
- private readonly string _colTags;
- private readonly string _colContent;
- private readonly string _colPayload;
- private readonly string _columnsListNoEmbeddings;
- private readonly string _columnsListWithEmbeddings;
- private readonly bool _dbNamePresent;
///
/// Initializes a new instance of the class.
@@ -52,8 +35,9 @@ public PostgresDbClient(PostgresConfig config, ILoggerFactory? loggerFactory = n
config.Validate();
this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger();
- this._dataSourceBuilder = new(config.ConnectionString);
- this._dataSourceBuilder.UseVector();
+ NpgsqlDataSourceBuilder dataSourceBuilder = new(config.ConnectionString);
+ dataSourceBuilder.UseVector();
+ this._dataSource = dataSourceBuilder.Build();
this._dbNamePresent = config.ConnectionString.Contains("Database=", StringComparison.OrdinalIgnoreCase);
this._schema = config.Schema;
@@ -96,51 +80,48 @@ public async Task DoesTableExistAsync(
tableName = this.WithTableNamePrefix(tableName);
this._log.LogTrace("Checking if table {0} exists", tableName);
- var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false);
- await using (dataSource.ConfigureAwait(false))
+ NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false);
+ await using (connection)
{
- await using (connection)
+ try
{
- try
+ NpgsqlCommand cmd = connection.CreateCommand();
+ await using (cmd.ConfigureAwait(false))
{
- NpgsqlCommand cmd = connection.CreateCommand();
- await using (cmd.ConfigureAwait(false))
- {
#pragma warning disable CA2100 // SQL reviewed
- cmd.CommandText = $@"
- SELECT table_name
- FROM information_schema.tables
- WHERE table_schema = @schema
- AND table_name = @table
- AND table_type = 'BASE TABLE'
- LIMIT 1
- ";
-
- cmd.Parameters.AddWithValue("@schema", this._schema);
- cmd.Parameters.AddWithValue("@table", tableName);
+ cmd.CommandText = $@"
+ SELECT table_name
+ FROM information_schema.tables
+ WHERE table_schema = @schema
+ AND table_name = @table
+ AND table_type = 'BASE TABLE'
+ LIMIT 1
+ ";
+
+ cmd.Parameters.AddWithValue("@schema", this._schema);
+ cmd.Parameters.AddWithValue("@table", tableName);
#pragma warning restore CA2100
- this._log.LogTrace("Schema: {0}, Table: {1}, SQL: {2}", this._schema, tableName, cmd.CommandText);
+ this._log.LogTrace("Schema: {0}, Table: {1}, SQL: {2}", this._schema, tableName, cmd.CommandText);
- NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
- await using (dataReader.ConfigureAwait(false))
+ NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
+ await using (dataReader.ConfigureAwait(false))
+ {
+ if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
{
- if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
- {
- var name = dataReader.GetString(dataReader.GetOrdinal("table_name"));
-
- return string.Equals(name, tableName, StringComparison.OrdinalIgnoreCase);
- }
+ var name = dataReader.GetString(dataReader.GetOrdinal("table_name"));
- this._log.LogTrace("Table {0} does not exist", tableName);
- return false;
+ return string.Equals(name, tableName, StringComparison.OrdinalIgnoreCase);
}
+
+ this._log.LogTrace("Table {0} does not exist", tableName);
+ return false;
}
}
- finally
- {
- await connection.CloseAsync().ConfigureAwait(false);
- }
+ }
+ finally
+ {
+ await connection.CloseAsync().ConfigureAwait(false);
}
}
}
@@ -162,71 +143,68 @@ public async Task CreateTableAsync(
Npgsql.PostgresException? createErr = null;
- var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false);
- await using (dataSource.ConfigureAwait(false))
+ NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false);
+ await using (connection)
{
- await using (connection)
+ try
{
- try
+ NpgsqlCommand cmd = connection.CreateCommand();
+ await using (cmd.ConfigureAwait(false))
{
- NpgsqlCommand cmd = connection.CreateCommand();
- await using (cmd.ConfigureAwait(false))
- {
- var lockId = GenLockId(tableName);
+ var lockId = GenLockId(tableName);
#pragma warning disable CA2100 // SQL reviewed
- if (!string.IsNullOrEmpty(this._createTableSql))
- {
- cmd.CommandText = this._createTableSql
- .Replace(PostgresConfig.SqlPlaceholdersTableName, tableName, StringComparison.Ordinal)
- .Replace(PostgresConfig.SqlPlaceholdersVectorSize, $"{vectorSize}", StringComparison.Ordinal)
- .Replace(PostgresConfig.SqlPlaceholdersLockId, $"{lockId}", StringComparison.Ordinal);
+ if (!string.IsNullOrEmpty(this._createTableSql))
+ {
+ cmd.CommandText = this._createTableSql
+ .Replace(PostgresConfig.SqlPlaceholdersTableName, tableName, StringComparison.Ordinal)
+ .Replace(PostgresConfig.SqlPlaceholdersVectorSize, $"{vectorSize}", StringComparison.Ordinal)
+ .Replace(PostgresConfig.SqlPlaceholdersLockId, $"{lockId}", StringComparison.Ordinal);
- this._log.LogTrace("Creating table with custom SQL: {0}", cmd.CommandText);
- }
- else
- {
- cmd.CommandText = $@"
- BEGIN;
- SELECT pg_advisory_xact_lock({lockId});
- CREATE TABLE IF NOT EXISTS {tableName} (
- {this._colId} TEXT NOT NULL PRIMARY KEY,
- {this._colEmbedding} vector({vectorSize}),
- {this._colTags} TEXT[] DEFAULT '{{}}'::TEXT[] NOT NULL,
- {this._colContent} TEXT DEFAULT '' NOT NULL,
- {this._colPayload} JSONB DEFAULT '{{}}'::JSONB NOT NULL
- );
- CREATE INDEX IF NOT EXISTS idx_tags ON {tableName} USING GIN({this._colTags});
- COMMIT;
- ";
+ this._log.LogTrace("Creating table with custom SQL: {0}", cmd.CommandText);
+ }
+ else
+ {
+ cmd.CommandText = $@"
+ BEGIN;
+ SELECT pg_advisory_xact_lock({lockId});
+ CREATE TABLE IF NOT EXISTS {tableName} (
+ {this._colId} TEXT NOT NULL PRIMARY KEY,
+ {this._colEmbedding} vector({vectorSize}),
+ {this._colTags} TEXT[] DEFAULT '{{}}'::TEXT[] NOT NULL,
+ {this._colContent} TEXT DEFAULT '' NOT NULL,
+ {this._colPayload} JSONB DEFAULT '{{}}'::JSONB NOT NULL
+ );
+ CREATE INDEX IF NOT EXISTS idx_tags ON {tableName} USING GIN({this._colTags});
+ COMMIT;
+ ";
#pragma warning restore CA2100
- this._log.LogTrace("Creating table with default SQL: {0}", cmd.CommandText);
- }
-
- int result = await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
- this._log.LogTrace("Table '{0}' creation result: {1}", tableName, result);
+ this._log.LogTrace("Creating table with default SQL: {0}", cmd.CommandText);
}
- }
- catch (Npgsql.PostgresException e) when (IsVectorTypeDoesNotExistException(e))
- {
- this._log.LogError(e, "Vector type not installed, check 'SELECT * FROM pg_extension'");
- throw;
- }
- catch (Npgsql.PostgresException e) when (e.SqlState == PgErrUniqueViolation)
- {
- createErr = e;
- }
- catch (Exception e)
- {
- this._log.LogError(e, "Table '{0}' creation error: {1}. Err: {2}. InnerEx: {3}", tableName, e, e.Message, e.InnerException);
- throw;
- }
- finally
- {
- await connection.CloseAsync().ConfigureAwait(false);
+
+ int result = await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
+ this._log.LogTrace("Table '{0}' creation result: {1}", tableName, result);
}
}
+ catch (Npgsql.PostgresException e) when (IsVectorTypeDoesNotExistException(e))
+ {
+ this._log.LogError(e, "Vector type not installed, check 'SELECT * FROM pg_extension'");
+ throw;
+ }
+ catch (Npgsql.PostgresException e) when (e.SqlState == PgErrUniqueViolation)
+ {
+ createErr = e;
+ }
+ catch (Exception e)
+ {
+ this._log.LogError(e, "Table '{0}' creation error: {1}. Err: {2}. InnerEx: {3}", tableName, e, e.Message, e.InnerException);
+ throw;
+ }
+ finally
+ {
+ await connection.CloseAsync().ConfigureAwait(false);
+ }
}
if (createErr != null)
@@ -267,40 +245,37 @@ public async Task CreateTableAsync(
public async IAsyncEnumerable GetTablesAsync(
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false);
- await using (dataSource.ConfigureAwait(false))
+ NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false);
+ await using (connection)
{
- await using (connection)
+ try
{
- try
+ NpgsqlCommand cmd = connection.CreateCommand();
+ await using (cmd.ConfigureAwait(false))
{
- NpgsqlCommand cmd = connection.CreateCommand();
- await using (cmd.ConfigureAwait(false))
- {
- cmd.CommandText = @"SELECT table_name FROM information_schema.tables
+ cmd.CommandText = @"SELECT table_name FROM information_schema.tables
WHERE table_schema = @schema AND table_type = 'BASE TABLE';";
- cmd.Parameters.AddWithValue("@schema", this._schema);
+ cmd.Parameters.AddWithValue("@schema", this._schema);
- this._log.LogTrace("Fetching list of tables. SQL: {0}. Schema: {1}", cmd.CommandText, this._schema);
+ this._log.LogTrace("Fetching list of tables. SQL: {0}. Schema: {1}", cmd.CommandText, this._schema);
- NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
- await using (dataReader.ConfigureAwait(false))
+ NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
+ await using (dataReader.ConfigureAwait(false))
+ {
+ while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
{
- while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
+ var tableNameWithPrefix = dataReader.GetString(dataReader.GetOrdinal("table_name"));
+ if (tableNameWithPrefix.StartsWith(this._tableNamePrefix, StringComparison.OrdinalIgnoreCase))
{
- var tableNameWithPrefix = dataReader.GetString(dataReader.GetOrdinal("table_name"));
- if (tableNameWithPrefix.StartsWith(this._tableNamePrefix, StringComparison.OrdinalIgnoreCase))
- {
- yield return tableNameWithPrefix.Remove(0, this._tableNamePrefix.Length);
- }
+ yield return tableNameWithPrefix.Remove(0, this._tableNamePrefix.Length);
}
}
}
}
- finally
- {
- await connection.CloseAsync().ConfigureAwait(false);
- }
+ }
+ finally
+ {
+ await connection.CloseAsync().ConfigureAwait(false);
}
}
}
@@ -316,33 +291,30 @@ public async Task DeleteTableAsync(
{
tableName = this.WithSchemaAndTableNamePrefix(tableName);
- var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false);
- await using (dataSource.ConfigureAwait(false))
+ NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false);
+ await using (connection)
{
- await using (connection)
+ try
{
- try
+ NpgsqlCommand cmd = connection.CreateCommand();
+ await using (cmd.ConfigureAwait(false))
{
- NpgsqlCommand cmd = connection.CreateCommand();
- await using (cmd.ConfigureAwait(false))
- {
#pragma warning disable CA2100 // SQL reviewed
- cmd.CommandText = $"DROP TABLE IF EXISTS {tableName}";
+ cmd.CommandText = $"DROP TABLE IF EXISTS {tableName}";
#pragma warning restore CA2100
- this._log.LogTrace("Deleting table. SQL: {0}", cmd.CommandText);
- await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
- }
- }
- catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e))
- {
- this._log.LogTrace("Table not found: {0}", tableName);
- }
- finally
- {
- await connection.CloseAsync().ConfigureAwait(false);
+ this._log.LogTrace("Deleting table. SQL: {0}", cmd.CommandText);
+ await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
}
+ catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e))
+ {
+ this._log.LogTrace("Table not found: {0}", tableName);
+ }
+ finally
+ {
+ await connection.CloseAsync().ConfigureAwait(false);
+ }
}
}
@@ -363,55 +335,52 @@ public async Task UpsertAsync(
const string EmptyContent = "";
string[] emptyTags = [];
- var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false);
- await using (dataSource.ConfigureAwait(false))
+ NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false);
+ await using (connection)
{
- await using (connection)
+ try
{
- try
+ NpgsqlCommand cmd = connection.CreateCommand();
+ await using (cmd.ConfigureAwait(false))
{
- NpgsqlCommand cmd = connection.CreateCommand();
- await using (cmd.ConfigureAwait(false))
- {
#pragma warning disable CA2100 // SQL reviewed
- cmd.CommandText = $@"
- INSERT INTO {tableName}
- ({this._colId}, {this._colEmbedding}, {this._colTags}, {this._colContent}, {this._colPayload})
- VALUES
- (@id, @embedding, @tags, @content, @payload)
- ON CONFLICT ({this._colId})
- DO UPDATE SET
- {this._colEmbedding} = @embedding,
- {this._colTags} = @tags,
- {this._colContent} = @content,
- {this._colPayload} = @payload
- ";
-
- cmd.Parameters.AddWithValue("@id", record.Id);
- cmd.Parameters.AddWithValue("@embedding", record.Embedding);
- cmd.Parameters.AddWithValue("@tags", NpgsqlDbType.Array | NpgsqlDbType.Text, record.Tags.ToArray() ?? emptyTags);
- cmd.Parameters.AddWithValue("@content", NpgsqlDbType.Text, CleanContent(record.Content) ?? EmptyContent);
- cmd.Parameters.AddWithValue("@payload", NpgsqlDbType.Jsonb, record.Payload ?? EmptyPayload);
+ cmd.CommandText = $@"
+ INSERT INTO {tableName}
+ ({this._colId}, {this._colEmbedding}, {this._colTags}, {this._colContent}, {this._colPayload})
+ VALUES
+ (@id, @embedding, @tags, @content, @payload)
+ ON CONFLICT ({this._colId})
+ DO UPDATE SET
+ {this._colEmbedding} = @embedding,
+ {this._colTags} = @tags,
+ {this._colContent} = @content,
+ {this._colPayload} = @payload
+ ";
+
+ cmd.Parameters.AddWithValue("@id", record.Id);
+ cmd.Parameters.AddWithValue("@embedding", record.Embedding);
+ cmd.Parameters.AddWithValue("@tags", NpgsqlDbType.Array | NpgsqlDbType.Text, record.Tags.ToArray() ?? emptyTags);
+ cmd.Parameters.AddWithValue("@content", NpgsqlDbType.Text, CleanContent(record.Content) ?? EmptyContent);
+ cmd.Parameters.AddWithValue("@payload", NpgsqlDbType.Jsonb, record.Payload ?? EmptyPayload);
#pragma warning restore CA2100
- this._log.LogTrace("Upserting record '{0}' in table '{1}'", record.Id, tableName);
+ this._log.LogTrace("Upserting record '{0}' in table '{1}'", record.Id, tableName);
- await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
- }
- }
- catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e))
- {
- throw new IndexNotFoundException(e.Message, e);
- }
- catch (Exception e)
- {
- throw new PostgresException(e.Message, e);
- }
- finally
- {
- await connection.CloseAsync().ConfigureAwait(false);
+ await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
}
+ catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e))
+ {
+ throw new IndexNotFoundException(e.Message, e);
+ }
+ catch (Exception e)
+ {
+ throw new PostgresException(e.Message, e);
+ }
+ finally
+ {
+ await connection.CloseAsync().ConfigureAwait(false);
+ }
}
}
@@ -460,71 +429,68 @@ DO UPDATE SET
this._log.LogTrace("Searching by similarity. Table: {0}. Threshold: {1}. Limit: {2}. Offset: {3}. Using SQL filter: {4}",
tableName, minSimilarity, limit, offset, string.IsNullOrWhiteSpace(filterSql) ? "false" : "true");
- var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false);
- await using (dataSource.ConfigureAwait(false))
+ NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false);
+ await using (connection)
{
- await using (connection)
+ try
{
- try
+ NpgsqlCommand cmd = connection.CreateCommand();
+ await using (cmd.ConfigureAwait(false))
{
- NpgsqlCommand cmd = connection.CreateCommand();
- await using (cmd.ConfigureAwait(false))
- {
#pragma warning disable CA2100 // SQL reviewed
- string colDistance = "__distance";
-
- // When using 1 - (embedding <=> target) the index is not being used, therefore we calculate
- // the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause.
- cmd.CommandText = @$"
- SELECT {columns}, {this._colEmbedding} <=> @embedding AS {colDistance}
- FROM {tableName}
- WHERE {filterSql}
- ORDER BY {colDistance} ASC
- LIMIT @limit
- OFFSET @offset
- ";
-
- cmd.Parameters.AddWithValue("@embedding", target);
- cmd.Parameters.AddWithValue("@maxDistance", maxDistance);
- cmd.Parameters.AddWithValue("@limit", limit);
- cmd.Parameters.AddWithValue("@offset", offset);
-
- foreach (KeyValuePair kv in sqlUserValues)
- {
- cmd.Parameters.AddWithValue(kv.Key, kv.Value);
- }
+ string colDistance = "__distance";
+
+ // When using 1 - (embedding <=> target) the index is not being used, therefore we calculate
+ // the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause.
+ cmd.CommandText = @$"
+ SELECT {columns}, {this._colEmbedding} <=> @embedding AS {colDistance}
+ FROM {tableName}
+ WHERE {filterSql}
+ ORDER BY {colDistance} ASC
+ LIMIT @limit
+ OFFSET @offset
+ ";
+
+ cmd.Parameters.AddWithValue("@embedding", target);
+ cmd.Parameters.AddWithValue("@maxDistance", maxDistance);
+ cmd.Parameters.AddWithValue("@limit", limit);
+ cmd.Parameters.AddWithValue("@offset", offset);
+
+ foreach (KeyValuePair kv in sqlUserValues)
+ {
+ cmd.Parameters.AddWithValue(kv.Key, kv.Value);
+ }
#pragma warning restore CA2100
- // TODO: rewrite code to stream results (need to combine yield and try-catch)
- var result = new List<(PostgresMemoryRecord record, double similarity)>();
- try
+ // TODO: rewrite code to stream results (need to combine yield and try-catch)
+ var result = new List<(PostgresMemoryRecord record, double similarity)>();
+ try
+ {
+ NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
+ await using (dataReader.ConfigureAwait(false))
{
- NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
- await using (dataReader.ConfigureAwait(false))
+ while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
{
- while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
- {
- double distance = dataReader.GetDouble(dataReader.GetOrdinal(colDistance));
- double similarity = 1 - distance;
- result.Add((this.ReadEntry(dataReader, withEmbeddings), similarity));
- }
+ double distance = dataReader.GetDouble(dataReader.GetOrdinal(colDistance));
+ double similarity = 1 - distance;
+ result.Add((this.ReadEntry(dataReader, withEmbeddings), similarity));
}
}
- catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e))
- {
- this._log.LogTrace("Table not found: {0}", tableName);
- }
+ }
+ catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e))
+ {
+ this._log.LogTrace("Table not found: {0}", tableName);
+ }
- // TODO: rewrite code to stream results (need to combine yield and try-catch)
- foreach (var x in result)
- {
- yield return x;
- }
+ // TODO: rewrite code to stream results (need to combine yield and try-catch)
+ foreach (var x in result)
+ {
+ yield return x;
}
}
- finally
- {
- await connection.CloseAsync().ConfigureAwait(false);
- }
+ }
+ finally
+ {
+ await connection.CloseAsync().ConfigureAwait(false);
}
}
}
@@ -572,66 +538,63 @@ public async IAsyncEnumerable GetListAsync(
this._log.LogTrace("Fetching list of records. Table: {0}. Order by: {1}. Limit: {2}. Offset: {3}. Using SQL filter: {4}",
tableName, orderBySql, limit, offset, string.IsNullOrWhiteSpace(filterSql) ? "false" : "true");
- var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false);
- await using (dataSource.ConfigureAwait(false))
+ NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false);
+ await using (connection)
{
- await using (connection)
+ try
{
- try
+ NpgsqlCommand cmd = connection.CreateCommand();
+ await using (cmd.ConfigureAwait(false))
{
- NpgsqlCommand cmd = connection.CreateCommand();
- await using (cmd.ConfigureAwait(false))
- {
#pragma warning disable CA2100 // SQL reviewed
- cmd.CommandText = @$"
- SELECT {columns} FROM {tableName}
- WHERE {filterSql}
- ORDER BY {orderBySql}
- LIMIT @limit
- OFFSET @offset
- ";
-
- cmd.Parameters.AddWithValue("@limit", limit);
- cmd.Parameters.AddWithValue("@offset", offset);
-
- if (sqlUserValues != null)
+ cmd.CommandText = @$"
+ SELECT {columns} FROM {tableName}
+ WHERE {filterSql}
+ ORDER BY {orderBySql}
+ LIMIT @limit
+ OFFSET @offset
+ ";
+
+ cmd.Parameters.AddWithValue("@limit", limit);
+ cmd.Parameters.AddWithValue("@offset", offset);
+
+ if (sqlUserValues != null)
+ {
+ foreach (KeyValuePair kv in sqlUserValues)
{
- foreach (KeyValuePair kv in sqlUserValues)
- {
- cmd.Parameters.AddWithValue(kv.Key, kv.Value);
- }
+ cmd.Parameters.AddWithValue(kv.Key, kv.Value);
}
+ }
#pragma warning restore CA2100
- // TODO: rewrite code to stream results (need to combine yield and try-catch)
- var result = new List();
- try
+ // TODO: rewrite code to stream results (need to combine yield and try-catch)
+ var result = new List();
+ try
+ {
+ NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
+ await using (dataReader.ConfigureAwait(false))
{
- NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
- await using (dataReader.ConfigureAwait(false))
+ while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
{
- while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
- {
- result.Add(this.ReadEntry(dataReader, withEmbeddings));
- }
+ result.Add(this.ReadEntry(dataReader, withEmbeddings));
}
}
- catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e))
- {
- this._log.LogTrace("Table not found: {0}", tableName);
- }
+ }
+ catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e))
+ {
+ this._log.LogTrace("Table not found: {0}", tableName);
+ }
- // TODO: rewrite code to stream results (need to combine yield and try-catch)
- foreach (var x in result)
- {
- yield return x;
- }
+ // TODO: rewrite code to stream results (need to combine yield and try-catch)
+ foreach (var x in result)
+ {
+ yield return x;
}
}
- finally
- {
- await connection.CloseAsync().ConfigureAwait(false);
- }
+ }
+ finally
+ {
+ await connection.CloseAsync().ConfigureAwait(false);
}
}
}
@@ -650,50 +613,84 @@ public async Task DeleteAsync(
tableName = this.WithSchemaAndTableNamePrefix(tableName);
this._log.LogTrace("Deleting record '{0}' from table '{1}'", id, tableName);
- var (dataSource, connection) = await this.ConnectAsync(cancellationToken).ConfigureAwait(false);
- await using (dataSource.ConfigureAwait(false))
+ NpgsqlConnection connection = await this.ConnectAsync(cancellationToken).ConfigureAwait(false);
+ await using (connection)
{
- await using (connection)
+ try
{
- try
+ NpgsqlCommand cmd = connection.CreateCommand();
+ await using (cmd.ConfigureAwait(false))
{
- NpgsqlCommand cmd = connection.CreateCommand();
- await using (cmd.ConfigureAwait(false))
- {
#pragma warning disable CA2100 // SQL reviewed
- cmd.CommandText = $"DELETE FROM {tableName} WHERE {this._colId}=@id";
- cmd.Parameters.AddWithValue("@id", id);
+ cmd.CommandText = $"DELETE FROM {tableName} WHERE {this._colId}=@id";
+ cmd.Parameters.AddWithValue("@id", id);
#pragma warning restore CA2100
- try
- {
- await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
- }
- catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e))
- {
- this._log.LogTrace("Table not found: {0}", tableName);
- }
+ try
+ {
+ await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
+ }
+ catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e))
+ {
+ this._log.LogTrace("Table not found: {0}", tableName);
}
- }
- finally
- {
- await connection.CloseAsync().ConfigureAwait(false);
}
}
+ finally
+ {
+ await connection.CloseAsync().ConfigureAwait(false);
+ }
+ }
+ }
+
+ ///
+ public void Dispose()
+ {
+ this._dataSource?.Dispose();
+ }
+
+ ///
+ public async ValueTask DisposeAsync()
+ {
+ try
+ {
+ await this._dataSource.DisposeAsync().ConfigureAwait(false);
+ }
+ catch (NullReferenceException)
+ {
+ // ignore
}
}
+ #region private ================================================================================
+
+ // See: https://www.postgresql.org/docs/current/errcodes-appendix.html
+ private const string PgErrUndefinedTable = "42P01"; // undefined_table
+ private const string PgErrUniqueViolation = "23505"; // unique_violation
+ private const string PgErrTypeDoesNotExist = "42704"; // undefined_object
+ private const string PgErrDatabaseDoesNotExist = "3D000"; // invalid_catalog_name
+
+ private readonly string _schema;
+ private readonly string _tableNamePrefix;
+ private readonly string _createTableSql;
+ private readonly string _colId;
+ private readonly string _colEmbedding;
+ private readonly string _colTags;
+ private readonly string _colContent;
+ private readonly string _colPayload;
+ private readonly string _columnsListNoEmbeddings;
+ private readonly string _columnsListWithEmbeddings;
+ private readonly bool _dbNamePresent;
+
///
/// Try to connect to PG, handling exceptions in case the DB doesn't exist
///
///
- ///
- private async Task<(NpgsqlDataSource DataSource, NpgsqlConnection Connection)> ConnectAsync(CancellationToken cancellationToken = default)
+ private async Task ConnectAsync(CancellationToken cancellationToken = default)
{
try
{
- var dataSource = this._dataSourceBuilder.Build();
- return (dataSource, await dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false));
+ return await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
}
catch (Npgsql.PostgresException e) when (IsDbNotFoundException(e))
{
@@ -786,4 +783,6 @@ private static long GenLockId(string resourceId)
return BitConverter.ToUInt32(SHA256.HashData(Encoding.UTF8.GetBytes(resourceId)), 0)
% short.MaxValue;
}
+
+ #endregion
}
diff --git a/extensions/Postgres/Postgres/Postgres.csproj b/extensions/Postgres/Postgres/Postgres.csproj
index 2c80dec15..a76e6fcf4 100644
--- a/extensions/Postgres/Postgres/Postgres.csproj
+++ b/extensions/Postgres/Postgres/Postgres.csproj
@@ -14,7 +14,6 @@
-
diff --git a/extensions/Postgres/Postgres/PostgresMemory.cs b/extensions/Postgres/Postgres/PostgresMemory.cs
index 8e6a6bac3..bd7854c9c 100644
--- a/extensions/Postgres/Postgres/PostgresMemory.cs
+++ b/extensions/Postgres/Postgres/PostgresMemory.cs
@@ -21,11 +21,12 @@ namespace Microsoft.KernelMemory.Postgres;
/// Postgres connector for Kernel Memory.
///
[Experimental("KMEXP03")]
-public sealed class PostgresMemory : IMemoryDb
+public sealed class PostgresMemory : IMemoryDb, IDisposable, IAsyncDisposable
{
- private readonly ILogger _log;
- private readonly ITextEmbeddingGenerator _embeddingGenerator;
+ // Dependencies
private readonly PostgresDbClient _db;
+ private readonly ITextEmbeddingGenerator _embeddingGenerator;
+ private readonly ILogger _log;
///
/// Create a new instance of Postgres KM connector
@@ -209,6 +210,25 @@ public Task DeleteAsync(
return this._db.DeleteAsync(tableName: index, id: record.Id, cancellationToken);
}
+ ///
+ public void Dispose()
+ {
+ this._db?.Dispose();
+ }
+
+ ///
+ public async ValueTask DisposeAsync()
+ {
+ try
+ {
+ await this._db.DisposeAsync().ConfigureAwait(false);
+ }
+ catch (NullReferenceException)
+ {
+ // ignore
+ }
+ }
+
#region private ================================================================================
// Note: "_" is allowed in Postgres, but we normalize it to "-" for consistency with other DBs