diff --git a/LeaderboardBackend/Models/Entities/ApplicationContext.cs b/LeaderboardBackend/Models/Entities/ApplicationContext.cs index 7ef6ebde..46db11ed 100644 --- a/LeaderboardBackend/Models/Entities/ApplicationContext.cs +++ b/LeaderboardBackend/Models/Entities/ApplicationContext.cs @@ -1,4 +1,3 @@ -using System.Collections.Concurrent; using System.Reflection; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.Options; @@ -10,15 +9,12 @@ public class ApplicationContext : DbContext { public const string CASE_INSENSITIVE_COLLATION = "case_insensitive"; - // the HashCode is calculated with all of the config's property because PostgresConfig is a record - private static readonly ConcurrentDictionary _dataSourceCache = new(); + private readonly AppContextDataSourceProvider _dataSourceProvider; - private readonly NpgsqlDataSource _dataSource; - - public ApplicationContext(DbContextOptions options, IOptions config) + public ApplicationContext(DbContextOptions options, AppContextDataSourceProvider dataSourceProvider) : base(options) { - _dataSource = CreateDataSource(config.Value.Pg); + _dataSourceProvider = dataSourceProvider; } public DbSet AccountRecoveries { get; set; } = null!; @@ -28,6 +24,28 @@ public ApplicationContext(DbContextOptions options, IOptions public DbSet Runs { get; set; } = null!; public DbSet Users { get; set; } = null!; + public static NpgsqlDataSource CreateDataSource(PostgresConfig config) + { + NpgsqlConnectionStringBuilder connectionBuilder = new() + { + Host = config.Host, + Username = config.User, + Password = config.Password, + Database = config.Db, + IncludeErrorDetail = true, + }; + + if (config.Port is not null) + { + connectionBuilder.Port = config.Port.Value; + } + + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionBuilder.ConnectionString); + dataSourceBuilder.UseNodaTime().MapEnum(); + + return dataSourceBuilder.Build(); + } + public void MigrateDatabase() { Database.Migrate(); @@ -75,32 +93,28 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) protected override void OnConfiguring(DbContextOptionsBuilder opt) { - opt.UseNpgsql(_dataSource, o => o.UseNodaTime()); + opt.UseNpgsql(_dataSourceProvider.Value, o => o.UseNodaTime()); opt.UseSnakeCaseNamingConvention(); } +} + +public class AppContextDataSourceProvider +{ + private static int _cacheKey; + private static NpgsqlDataSource? _cachedDataSource; - private static NpgsqlDataSource CreateDataSource(PostgresConfig c) + public NpgsqlDataSource Value => _cachedDataSource!; + + public AppContextDataSourceProvider(IOptions appContextConfig) { - return _dataSourceCache.GetOrAdd(c, config => + PostgresConfig config = appContextConfig.Value.Pg; + + int key = config.GetHashCode(); // a record's HashCode is calculated with all its properties values + if (_cacheKey != key) { - NpgsqlConnectionStringBuilder connectionBuilder = new() - { - Host = config.Host, - Username = config.User, - Password = config.Password, - Database = config.Db, - IncludeErrorDetail = true, - }; - - if (config.Port is not null) - { - connectionBuilder.Port = config.Port.Value; - } - - NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionBuilder.ConnectionString); - dataSourceBuilder.UseNodaTime().MapEnum(); - - return dataSourceBuilder.Build(); - }); + // if we ever want to parallelize tests, this code will need to be made thread-safe + _cacheKey = key; + _cachedDataSource = ApplicationContext.CreateDataSource(config); + } } } diff --git a/LeaderboardBackend/Program.cs b/LeaderboardBackend/Program.cs index 5637f444..bbc87c3e 100644 --- a/LeaderboardBackend/Program.cs +++ b/LeaderboardBackend/Program.cs @@ -59,6 +59,7 @@ .ValidateDataAnnotationsRecursively() .ValidateOnStart(); +builder.Services.AddSingleton(); builder.Services.AddDbContext(); // Add services to the container.