Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test coverage for data source enum mapping support #2629

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 54 additions & 25 deletions test/EFCore.PG.FunctionalTests/Query/EnumQueryTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ public EnumQueryTest(EnumFixture fixture, ITestOutputHelper testOutputHelper)
Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
}

#region Roundtrip

[Fact]
public void Roundtrip()
{
Expand All @@ -22,10 +20,6 @@ public void Roundtrip()
Assert.Equal(MappedEnum.Happy, x.MappedEnum);
}

#endregion

#region Where

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task Where_with_constant(bool async)
Expand All @@ -39,7 +33,7 @@ await AssertQuery(

AssertSql(
"""
SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."MappedEnum" = 'sad'::test.mapped_enum
""");
Expand All @@ -58,7 +52,7 @@ await AssertQuery(

AssertSql(
"""
SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."SchemaQualifiedEnum" = 'Happy (PgName)'::test.schema_qualified_enum
""");
Expand All @@ -80,7 +74,7 @@ await AssertQuery(
"""
@__sad_0='Sad' (DbType = Object)

SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."MappedEnum" = @__sad_0
""");
Expand All @@ -102,7 +96,7 @@ await AssertQuery(
"""
@__sad_0='1'

SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."UnmappedEnum" = @__sad_0
""");
Expand All @@ -124,7 +118,7 @@ await AssertQuery(
"""
@__sad_0='1'

SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."UnmappedEnum" = @__sad_0
""");
Expand All @@ -146,7 +140,7 @@ await AssertQuery(
"""
@__sad_0='Sad' (DbType = Object)

SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."MappedEnum" = @__sad_0
""");
Expand All @@ -166,7 +160,7 @@ await AssertQuery(

AssertSql(
"""
SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE strpos(s."MappedEnum"::text, 'sa') > 0
""");
Expand All @@ -189,7 +183,7 @@ await AssertQuery(
"""
@__values_0='0x01' (DbType = Object)

SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."ByteEnum" = ANY (@__values_0)
""");
Expand All @@ -211,13 +205,30 @@ await AssertQuery(
"""
@__values_0='0x01' (DbType = Object)

SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."UnmappedByteEnum" = ANY (@__values_0)
""");
}

#endregion
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task Global_enum_mapping(bool async)
{
using var ctx = CreateContext();

await AssertQuery(
async,
ss => ss.Set<SomeEnumEntity>().Where(e => e.GloballyMappedEnum == GloballyMappedEnum.Sad),
entryCount: 1);

AssertSql(
"""
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."GloballyMappedEnum" = 'sad'::test.globally_mapped_enum
""");
}

#region Support

Expand All @@ -234,21 +245,20 @@ public class EnumContext : PoolableDbContext
static EnumContext()
{
#pragma warning disable CS0618 // NpgsqlConnection.GlobalTypeMapper is obsolete
NpgsqlConnection.GlobalTypeMapper.MapEnum<MappedEnum>("test.mapped_enum");
NpgsqlConnection.GlobalTypeMapper.MapEnum<InferredEnum>("test.inferred_enum");
NpgsqlConnection.GlobalTypeMapper.MapEnum<ByteEnum>("test.byte_enum");
NpgsqlConnection.GlobalTypeMapper.MapEnum<SchemaQualifiedEnum>("test.schema_qualified_enum");
NpgsqlConnection.GlobalTypeMapper.MapEnum<GloballyMappedEnum>("test.globally_mapped_enum");
#pragma warning restore CS0618
}

public EnumContext(DbContextOptions options) : base(options) {}

protected override void OnModelCreating(ModelBuilder builder)
=> builder.HasPostgresEnum("mapped_enum", new[] { "happy", "sad" })
=> builder
.HasPostgresEnum("mapped_enum", new[] { "happy", "sad" })
.HasPostgresEnum<InferredEnum>()
.HasPostgresEnum<ByteEnum>()
.HasDefaultSchema("test")
.HasPostgresEnum<SchemaQualifiedEnum>();
.HasPostgresEnum<SchemaQualifiedEnum>()
.HasPostgresEnum<GloballyMappedEnum>();

public static void Seed(EnumContext context)
{
Expand All @@ -270,6 +280,7 @@ public class SomeEnumEntity
public ByteEnum ByteEnum { get; set; }
public UnmappedByteEnum UnmappedByteEnum { get; set; }
public int EnumValue { get; set; }
public GloballyMappedEnum GloballyMappedEnum { get; set; }
}

public enum MappedEnum
Expand All @@ -290,6 +301,12 @@ public enum InferredEnum
Sad
}

public enum GloballyMappedEnum
{
Happy,
Sad
}

public enum SchemaQualifiedEnum
{
[PgName("Happy (PgName)")]
Expand All @@ -313,7 +330,16 @@ public enum UnmappedByteEnum : byte
public class EnumFixture : SharedStoreFixtureBase<EnumContext>, IQueryFixtureBase
{
protected override string StoreName => "EnumQueryTest";
protected override ITestStoreFactory TestStoreFactory => NpgsqlTestStoreFactory.Instance;

protected override ITestStoreFactory TestStoreFactory
=> NpgsqlTestStoreFactory.WithDataSourceConfiguration(
b =>
b
.MapEnum<MappedEnum>("test.mapped_enum")
.MapEnum<InferredEnum>("test.inferred_enum")
.MapEnum<ByteEnum>("test.byte_enum")
.MapEnum<SchemaQualifiedEnum>("test.schema_qualified_enum"));

public TestSqlLoggerFactory TestSqlLoggerFactory => (TestSqlLoggerFactory)ListLoggerFactory;

private EnumData _expectedData;
Expand Down Expand Up @@ -350,6 +376,7 @@ public IReadOnlyDictionary<Type, object> EntityAsserters
Assert.Equal(ee.ByteEnum, aa.ByteEnum);
Assert.Equal(ee.UnmappedByteEnum, aa.UnmappedByteEnum);
Assert.Equal(ee.EnumValue, aa.EnumValue);
Assert.Equal(ee.GloballyMappedEnum, aa.GloballyMappedEnum);
}
}
}
Expand Down Expand Up @@ -386,7 +413,8 @@ public static IReadOnlyList<SomeEnumEntity> CreateSomeEnumEntities()
SchemaQualifiedEnum = SchemaQualifiedEnum.Happy,
ByteEnum = ByteEnum.Happy,
UnmappedByteEnum = UnmappedByteEnum.Happy,
EnumValue = (int)MappedEnum.Happy
EnumValue = (int)MappedEnum.Happy,
GloballyMappedEnum = GloballyMappedEnum.Happy
},
new()
{
Expand All @@ -397,7 +425,8 @@ public static IReadOnlyList<SomeEnumEntity> CreateSomeEnumEntities()
SchemaQualifiedEnum = SchemaQualifiedEnum.Sad,
ByteEnum = ByteEnum.Sad,
UnmappedByteEnum = UnmappedByteEnum.Sad,
EnumValue = (int)MappedEnum.Sad
EnumValue = (int)MappedEnum.Sad,
GloballyMappedEnum = GloballyMappedEnum.Sad
}
};
}
Expand Down
2 changes: 1 addition & 1 deletion test/EFCore.PG.FunctionalTests/Query/TimestampQueryTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ public class TimestampQueryFixture : SharedStoreFixtureBase<TimestampQueryContex
// don't depend on the database's time zone, and also that operations which shouldn't take TimeZone into account indeed
// don't.
protected override ITestStoreFactory TestStoreFactory
=> NpgsqlTestStoreFactory.WithConnectionStringOptions("-c TimeZone=Europe/Berlin");
=> NpgsqlTestStoreFactory.WithDataSourceConfiguration(b => b.ConnectionStringBuilder.Options = "-c TimeZone=Europe/Berlin");

public TestSqlLoggerFactory TestSqlLoggerFactory => (TestSqlLoggerFactory)ListLoggerFactory;

Expand Down
48 changes: 27 additions & 21 deletions test/EFCore.PG.FunctionalTests/TestUtilities/NpgsqlTestStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@

namespace Npgsql.EntityFrameworkCore.PostgreSQL.TestUtilities;

// ReSharper disable VirtualMemberCallInConstructor

public class NpgsqlTestStore : RelationalTestStore
{
private readonly NpgsqlDataSource _dataSource;

private readonly string _scriptPath;
private readonly string _additionalSql;

Expand All @@ -27,11 +31,14 @@ public static NpgsqlTestStore GetOrCreate(
string name,
string scriptPath = null,
string additionalSql = null,
string connectionStringOptions = null)
=> new(name, scriptPath, additionalSql, connectionStringOptions);
string connectionStringOptions = null,
Action<NpgsqlDataSourceBuilder> dataSourceBuilderAction = null)
=> new(name, scriptPath, additionalSql, dataSourceBuilderAction);

public static NpgsqlTestStore Create(string name, string connectionStringOptions = null)
=> new(name, connectionStringOptions: connectionStringOptions, shared: false);
public static NpgsqlTestStore Create(
string name,
Action<NpgsqlDataSourceBuilder> dataSourceBuilderAction = null)
=> new(name, dataSourceBuilderAction: dataSourceBuilderAction, shared: false);

public static NpgsqlTestStore CreateInitialized(string name)
=> new NpgsqlTestStore(name, shared: false)
Expand All @@ -41,7 +48,7 @@ private NpgsqlTestStore(
string name,
string scriptPath = null,
string additionalSql = null,
string connectionStringOptions = null,
Action<NpgsqlDataSourceBuilder> dataSourceBuilderAction = null,
bool shared = true)
: base(name, shared)
{
Expand All @@ -55,10 +62,11 @@ private NpgsqlTestStore(

_additionalSql = additionalSql;

// ReSharper disable VirtualMemberCallInConstructor
ConnectionString = CreateConnectionString(Name, connectionStringOptions);
Connection = new NpgsqlConnection(ConnectionString);
// ReSharper restore VirtualMemberCallInConstructor
ConnectionString = CreateConnectionString(Name);
var dataSourceBuilder = new NpgsqlDataSourceBuilder(ConnectionString);
dataSourceBuilderAction?.Invoke(dataSourceBuilder);
_dataSource = dataSourceBuilder.Build();
Connection = _dataSource.CreateConnection();
}

// ReSharper disable once MemberCanBePrivate.Global
Expand Down Expand Up @@ -100,7 +108,7 @@ protected override void Initialize(Func<DbContext> createContext, Action<DbConte
}

public override DbContextOptionsBuilder AddProviderOptions(DbContextOptionsBuilder builder)
=> builder.UseNpgsql(Connection, b => b.ApplyConfiguration()
=> builder.UseNpgsql(_dataSource, b => b.ApplyConfiguration()
.CommandTimeout(CommandTimeout)
// The tests are written with the assumption that NULLs are sorted first (SQL Server and .NET behavior), but PostgreSQL
// sorts NULLs last by default. This configures the provider to emit NULLS FIRST.
Expand Down Expand Up @@ -415,20 +423,18 @@ private static DbCommand CreateCommand(
return command;
}

public static string CreateConnectionString(string name, string options = null)
{
var builder = new NpgsqlConnectionStringBuilder(TestEnvironment.DefaultConnection) { Database = name };

if (options is not null)
{
builder.Options = options;
}

return builder.ConnectionString;
}
public static string CreateConnectionString(string name)
=> new NpgsqlConnectionStringBuilder(TestEnvironment.DefaultConnection) { Database = name }.ConnectionString;

private static string CreateAdminConnectionString() => CreateConnectionString("postgres");

public override void Clean(DbContext context)
=> context.Database.EnsureClean();

public override void Dispose()
{
base.Dispose();

_dataSource.Dispose();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@

public class NpgsqlTestStoreFactory : RelationalTestStoreFactory
{
private string _connectionStringOptions;
private readonly Action<NpgsqlDataSourceBuilder> _dataSourceBuilderAction;

public static NpgsqlTestStoreFactory Instance { get; } = new();

public static NpgsqlTestStoreFactory WithConnectionStringOptions(string connectionStringOptions)
=> new(connectionStringOptions);
public static NpgsqlTestStoreFactory WithDataSourceConfiguration(Action<NpgsqlDataSourceBuilder> dataSourceBuilderAction)
=> new(dataSourceBuilderAction);

protected NpgsqlTestStoreFactory(string connectionStringOptions = null)
=> _connectionStringOptions = connectionStringOptions;
protected NpgsqlTestStoreFactory(Action<NpgsqlDataSourceBuilder> dataSourceBuilderAction = null)
=> _dataSourceBuilderAction = dataSourceBuilderAction;

public override TestStore Create(string storeName)
=> NpgsqlTestStore.Create(storeName, _connectionStringOptions);
=> NpgsqlTestStore.Create(storeName, _dataSourceBuilderAction);

public override TestStore GetOrCreate(string storeName)
=> NpgsqlTestStore.GetOrCreate(storeName, connectionStringOptions: _connectionStringOptions);
=> NpgsqlTestStore.GetOrCreate(storeName, dataSourceBuilderAction: _dataSourceBuilderAction);

public override IServiceCollection AddProviderServices(IServiceCollection serviceCollection)
=> serviceCollection.AddEntityFrameworkNpgsql();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1910,7 +1910,7 @@ public NodaTimeQueryNpgsqlFixture()
// don't depend on the database's time zone, and also that operations which shouldn't take TimeZone into account indeed
// don't.
protected override ITestStoreFactory TestStoreFactory
=> NpgsqlTestStoreFactory.WithConnectionStringOptions("-c TimeZone=Europe/Berlin");
=> NpgsqlTestStoreFactory.WithDataSourceConfiguration(b => b.ConnectionStringBuilder.Options = "-c TimeZone=Europe/Berlin");

public TestSqlLoggerFactory TestSqlLoggerFactory => (TestSqlLoggerFactory)ListLoggerFactory;

Expand Down