diff --git a/Dapper/SqlMapper.Async.cs b/Dapper/SqlMapper.Async.cs index dba737fc6..68cbe6b90 100644 --- a/Dapper/SqlMapper.Async.cs +++ b/Dapper/SqlMapper.Async.cs @@ -967,7 +967,7 @@ private static async Task> MultiMapAsync(this IDbC } } - private static IEnumerable ExecuteReaderSync(IDataReader reader, Func func, object parameters) + private static IEnumerable ExecuteReaderSync(DbDataReader reader, Func func, object parameters) { using (reader) { @@ -1004,7 +1004,7 @@ public static async Task QueryMultipleAsync(this IDbConnection cnn, CacheInfo info = GetCacheInfo(identity, param, command.AddToCache); DbCommand cmd = null; - IDataReader reader = null; + DbDataReader reader = null; bool wasClosed = cnn.State == ConnectionState.Closed; try { diff --git a/Dapper/SqlMapper.CacheInfo.cs b/Dapper/SqlMapper.CacheInfo.cs index 409ea3411..63540aa3a 100644 --- a/Dapper/SqlMapper.CacheInfo.cs +++ b/Dapper/SqlMapper.CacheInfo.cs @@ -1,5 +1,6 @@ using System; using System.Data; +using System.Data.Common; using System.Threading; namespace Dapper @@ -9,7 +10,7 @@ public static partial class SqlMapper private class CacheInfo { public DeserializerState Deserializer { get; set; } - public Func[] OtherDeserializers { get; set; } + public Func[] OtherDeserializers { get; set; } public Action ParamReader { get; set; } private int hitCount; public int GetHitCount() { return Interlocked.CompareExchange(ref hitCount, 0, 0); } diff --git a/Dapper/SqlMapper.DeserializerState.cs b/Dapper/SqlMapper.DeserializerState.cs index 26b176cc1..bf1c2fce1 100644 --- a/Dapper/SqlMapper.DeserializerState.cs +++ b/Dapper/SqlMapper.DeserializerState.cs @@ -1,5 +1,6 @@ using System; using System.Data; +using System.Data.Common; namespace Dapper { @@ -8,9 +9,9 @@ public static partial class SqlMapper private struct DeserializerState { public readonly int Hash; - public readonly Func Func; + public readonly Func Func; - public DeserializerState(int hash, Func func) + public DeserializerState(int hash, Func func) { Hash = hash; Func = func; diff --git a/Dapper/SqlMapper.GridReader.Async.cs b/Dapper/SqlMapper.GridReader.Async.cs index bec4967be..f1c5a7fb4 100644 --- a/Dapper/SqlMapper.GridReader.Async.cs +++ b/Dapper/SqlMapper.GridReader.Async.cs @@ -14,7 +14,7 @@ public static partial class SqlMapper public partial class GridReader { private readonly CancellationToken cancel; - internal GridReader(IDbCommand command, IDataReader reader, Identity identity, DynamicParameters dynamicParams, bool addToCache, CancellationToken cancel) + internal GridReader(IDbCommand command, DbDataReader reader, Identity identity, DynamicParameters dynamicParams, bool addToCache, CancellationToken cancel) : this(command, reader, identity, dynamicParams, addToCache) { this.cancel = cancel; @@ -225,7 +225,7 @@ private async Task ReadRowAsyncImplViaDbReader(DbDataReader reader, Type t return result; } - private async Task> ReadBufferedAsync(int index, Func deserializer) + private async Task> ReadBufferedAsync(int index, Func deserializer) { try { diff --git a/Dapper/SqlMapper.GridReader.cs b/Dapper/SqlMapper.GridReader.cs index 15311cf4d..7a59050f0 100644 --- a/Dapper/SqlMapper.GridReader.cs +++ b/Dapper/SqlMapper.GridReader.cs @@ -4,6 +4,7 @@ using System.Linq; using System.Globalization; using System.Runtime.CompilerServices; +using System.Data.Common; namespace Dapper { @@ -14,11 +15,11 @@ public static partial class SqlMapper /// public partial class GridReader : IDisposable { - private IDataReader reader; + private DbDataReader reader; private readonly Identity identity; private readonly bool addToCache; - internal GridReader(IDbCommand command, IDataReader reader, Identity identity, IParameterCallbacks callbacks, bool addToCache) + internal GridReader(IDbCommand command, DbDataReader reader, Identity identity, IParameterCallbacks callbacks, bool addToCache) { Command = command; this.reader = reader; @@ -351,7 +352,7 @@ public IEnumerable Read(Type[] types, Func return buffered ? result.ToList() : result; } - private IEnumerable ReadDeferred(int index, Func deserializer, Type effectiveType) + private IEnumerable ReadDeferred(int index, Func deserializer, Type effectiveType) { try { diff --git a/Dapper/SqlMapper.IDataReader.cs b/Dapper/SqlMapper.IDataReader.cs index dca4ebc6d..0fa7e0719 100644 --- a/Dapper/SqlMapper.IDataReader.cs +++ b/Dapper/SqlMapper.IDataReader.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Data; +using System.Data.Common; namespace Dapper { @@ -13,14 +14,15 @@ public static partial class SqlMapper /// The data reader to parse results from. public static IEnumerable Parse(this IDataReader reader) { - if (reader.Read()) + var dbReader = GetDbDataReader(reader, false); + if (dbReader.Read()) { var effectiveType = typeof(T); - var deser = GetDeserializer(effectiveType, reader, 0, -1, false); + var deser = GetDeserializer(effectiveType, dbReader, 0, -1, false); var convertToType = Nullable.GetUnderlyingType(effectiveType) ?? effectiveType; do { - object val = deser(reader); + object val = deser(dbReader); if (val == null || val is T) { yield return (T)val; @@ -29,7 +31,7 @@ public static IEnumerable Parse(this IDataReader reader) { yield return (T)Convert.ChangeType(val, convertToType, System.Globalization.CultureInfo.InvariantCulture); } - } while (reader.Read()); + } while (dbReader.Read()); } } @@ -40,13 +42,14 @@ public static IEnumerable Parse(this IDataReader reader) /// The type to parse from the . public static IEnumerable Parse(this IDataReader reader, Type type) { - if (reader.Read()) + var dbReader = GetDbDataReader(reader, false); + if (dbReader.Read()) { - var deser = GetDeserializer(type, reader, 0, -1, false); + var deser = GetDeserializer(type, dbReader, 0, -1, false); do { - yield return deser(reader); - } while (reader.Read()); + yield return deser(dbReader); + } while (dbReader.Read()); } } @@ -56,13 +59,14 @@ public static IEnumerable Parse(this IDataReader reader, Type type) /// The data reader to parse results from. public static IEnumerable Parse(this IDataReader reader) { - if (reader.Read()) + var dbReader = GetDbDataReader(reader, false); + if (dbReader.Read()) { - var deser = GetDapperRowDeserializer(reader, 0, -1, false); + var deser = GetDapperRowDeserializer(dbReader, 0, -1, false); do { - yield return deser(reader); - } while (reader.Read()); + yield return deser(dbReader); + } while (dbReader.Read()); } } @@ -76,12 +80,47 @@ public static IEnumerable Parse(this IDataReader reader) /// The length of columns to read (default -1 = all fields following startIndex) /// Return null if we can't find the first column? (default false) /// A parser for this specific object from this row. +#if DEBUG // make sure we're not using this internally + [Obsolete(nameof(DbDataReader) + " API should be preferred")] +#endif public static Func GetRowParser(this IDataReader reader, Type type, int startIndex = 0, int length = -1, bool returnNullIfFirstMissing = false) + { + return WrapObjectReader(GetDeserializer(type, GetDbDataReader(reader, false), startIndex, length, returnNullIfFirstMissing)); + } + + /// + /// Gets the row parser for a specific row on a data reader. This allows for type switching every row based on, for example, a TypeId column. + /// You could return a collection of the base type but have each more specific. + /// + /// The data reader to get the parser for the current row from + /// The type to get the parser for + /// The start column index of the object (default 0) + /// The length of columns to read (default -1 = all fields following startIndex) + /// Return null if we can't find the first column? (default false) + /// A parser for this specific object from this row. + public static Func GetRowParser(this DbDataReader reader, Type type, + int startIndex = 0, int length = -1, bool returnNullIfFirstMissing = false) { return GetDeserializer(type, reader, startIndex, length, returnNullIfFirstMissing); } + /// +#if DEBUG // make sure we're not using this internally + [Obsolete(nameof(DbDataReader) + " API should be preferred")] +#endif + public static Func GetRowParser(this IDataReader reader, Type concreteType = null, + int startIndex = 0, int length = -1, bool returnNullIfFirstMissing = false) + { + concreteType ??= typeof(T); + var func = GetDeserializer(concreteType, GetDbDataReader(reader, false), startIndex, length, returnNullIfFirstMissing); + return Wrap(func); + + // this is just to be very clear about what we're capturing + static Func Wrap(Func func) + => reader => (T)func(GetDbDataReader(reader, false)); + } + /// /// Gets the row parser for a specific row on a data reader. This allows for type switching every row based on, for example, a TypeId column. /// You could return a collection of the base type but have each more specific. @@ -135,7 +174,7 @@ public static Func GetRowParser(this IDataReader reader, Ty /// public override int Type => 2; /// } /// - public static Func GetRowParser(this IDataReader reader, Type concreteType = null, + public static Func GetRowParser(this DbDataReader reader, Type concreteType = null, int startIndex = 0, int length = -1, bool returnNullIfFirstMissing = false) { concreteType ??= typeof(T); @@ -146,7 +185,7 @@ public static Func GetRowParser(this IDataReader reader, Type } else { - return (Func)(Delegate)func; + return (Func)(Delegate)func; } } } diff --git a/Dapper/SqlMapper.TypeDeserializerCache.cs b/Dapper/SqlMapper.TypeDeserializerCache.cs index 088f464c7..65fd3bf40 100644 --- a/Dapper/SqlMapper.TypeDeserializerCache.cs +++ b/Dapper/SqlMapper.TypeDeserializerCache.cs @@ -3,6 +3,7 @@ using System.Collections; using System.Collections.Generic; using System.Text; +using System.Data.Common; namespace Dapper { @@ -33,7 +34,7 @@ internal static void Purge() } } - internal static Func GetReader(Type type, IDataReader reader, int startBound, int length, bool returnNullIfFirstMissing) + internal static Func GetReader(Type type, DbDataReader reader, int startBound, int length, bool returnNullIfFirstMissing) { var found = (TypeDeserializerCache)byType[type]; if (found == null) @@ -50,18 +51,18 @@ internal static Func GetReader(Type type, IDataReader reade return found.GetReader(reader, startBound, length, returnNullIfFirstMissing); } - private readonly Dictionary> readers = new Dictionary>(); + private readonly Dictionary> readers = new Dictionary>(); private struct DeserializerKey : IEquatable { private readonly int startBound, length; private readonly bool returnNullIfFirstMissing; - private readonly IDataReader reader; + private readonly DbDataReader reader; private readonly string[] names; private readonly Type[] types; private readonly int hashCode; - public DeserializerKey(int hashCode, int startBound, int length, bool returnNullIfFirstMissing, IDataReader reader, bool copyDown) + public DeserializerKey(int hashCode, int startBound, int length, bool returnNullIfFirstMissing, DbDataReader reader, bool copyDown) { this.hashCode = hashCode; this.startBound = startBound; @@ -136,14 +137,14 @@ public bool Equals(DeserializerKey other) } } - private Func GetReader(IDataReader reader, int startBound, int length, bool returnNullIfFirstMissing) + private Func GetReader(DbDataReader reader, int startBound, int length, bool returnNullIfFirstMissing) { if (length < 0) length = reader.FieldCount - startBound; int hash = GetColumnHash(reader, startBound, length); if (returnNullIfFirstMissing) hash *= -27; // get a cheap key first: false means don't copy the values down var key = new DeserializerKey(hash, startBound, length, returnNullIfFirstMissing, reader, false); - Func deser; + Func deser; lock (readers) { if (readers.TryGetValue(key, out deser)) return deser; diff --git a/Dapper/SqlMapper.cs b/Dapper/SqlMapper.cs index ee49c2abd..f61f89c8b 100644 --- a/Dapper/SqlMapper.cs +++ b/Dapper/SqlMapper.cs @@ -8,6 +8,8 @@ using System.Collections.Generic; using System.ComponentModel; using System.Data; +using System.Data.Common; +using System.Data.SqlTypes; using System.Globalization; using System.Linq; using System.Reflection; @@ -31,7 +33,7 @@ private class PropertyInfoByNameComparer : IComparer { public int Compare(PropertyInfo x, PropertyInfo y) => string.CompareOrdinal(x.Name, y.Name); } - private static int GetColumnHash(IDataReader reader, int startBound = 0, int length = -1) + private static int GetColumnHash(DbDataReader reader, int startBound = 0, int length = -1) { unchecked { @@ -163,11 +165,39 @@ where pair.Value > 1 select Tuple.Create(pair.Key, pair.Value); } - private static Dictionary typeMap; + private static Dictionary typeMap; + + [Flags] + internal enum TypeMapEntryFlags + { + None = 0, + SetType = 1 << 0, + UseGetFieldValue = 1 << 1, + } + internal readonly struct TypeMapEntry : IEquatable + { + public readonly DbType DbType { get; } + public readonly TypeMapEntryFlags Flags; + public TypeMapEntry(DbType dbType, TypeMapEntryFlags flags) + { + DbType = dbType; + Flags = flags; + } + public override int GetHashCode() => (int)DbType ^ (int)Flags; + public override string ToString() => $"{DbType}, {Flags}"; + public override bool Equals(object obj) => obj is TypeMapEntry other && Equals(other); + public bool Equals(TypeMapEntry other) => other.DbType == DbType && other.Flags == Flags; + public static readonly TypeMapEntry + DoNotSet = new TypeMapEntry((DbType)(-2), TypeMapEntryFlags.None), + DecimalFieldValue = new TypeMapEntry(DbType.Decimal, TypeMapEntryFlags.SetType | TypeMapEntryFlags.UseGetFieldValue); + + public static implicit operator TypeMapEntry(DbType dbType) + => new TypeMapEntry(dbType, TypeMapEntryFlags.SetType); + } static SqlMapper() { - typeMap = new Dictionary(37) + typeMap = new Dictionary(41) { [typeof(byte)] = DbType.Byte, [typeof(sbyte)] = DbType.SByte, @@ -184,9 +214,9 @@ static SqlMapper() [typeof(string)] = DbType.String, [typeof(char)] = DbType.StringFixedLength, [typeof(Guid)] = DbType.Guid, - [typeof(DateTime)] = null, + [typeof(DateTime)] = TypeMapEntry.DoNotSet, [typeof(DateTimeOffset)] = DbType.DateTimeOffset, - [typeof(TimeSpan)] = null, + [typeof(TimeSpan)] = TypeMapEntry.DoNotSet, [typeof(byte[])] = DbType.Binary, [typeof(byte?)] = DbType.Byte, [typeof(sbyte?)] = DbType.SByte, @@ -202,10 +232,14 @@ static SqlMapper() [typeof(bool?)] = DbType.Boolean, [typeof(char?)] = DbType.StringFixedLength, [typeof(Guid?)] = DbType.Guid, - [typeof(DateTime?)] = null, + [typeof(DateTime?)] = TypeMapEntry.DoNotSet, [typeof(DateTimeOffset?)] = DbType.DateTimeOffset, - [typeof(TimeSpan?)] = null, - [typeof(object)] = DbType.Object + [typeof(TimeSpan?)] = TypeMapEntry.DoNotSet, + [typeof(object)] = DbType.Object, + [typeof(SqlDecimal)] = TypeMapEntry.DecimalFieldValue, + [typeof(SqlDecimal?)] = TypeMapEntry.DecimalFieldValue, + [typeof(SqlMoney)] = TypeMapEntry.DecimalFieldValue, + [typeof(SqlMoney?)] = TypeMapEntry.DecimalFieldValue, }; ResetTypeHandlers(false); } @@ -230,13 +264,42 @@ private static void ResetTypeHandlers(bool clone) /// The type to map from. /// The database type to map to. public static void AddTypeMap(Type type, DbType dbType) + => AddTypeMap(type, dbType, false); + + /// + /// Configure the specified type to be mapped to a given db-type. + /// + /// The type to map from. + /// The database type to map to. + /// Whether to prefer over . + public static void AddTypeMap(Type type, DbType dbType, bool useGetFieldValue) { // use clone, mutate, replace to avoid threading issues var snapshot = typeMap; + var flags = TypeMapEntryFlags.None; + if (dbType >= 0) + { + flags |= TypeMapEntryFlags.SetType; + } + if (useGetFieldValue) + { + flags |= TypeMapEntryFlags.UseGetFieldValue; + } + var value = new TypeMapEntry(dbType, flags); + if (snapshot.TryGetValue(type, out var oldValue) && oldValue.Equals(value)) return; // nothing to do - if (snapshot.TryGetValue(type, out var oldValue) && oldValue == dbType) return; // nothing to do + SetTypeMap(new Dictionary(snapshot) { [type] = value }); + } - typeMap = new Dictionary(snapshot) { [type] = dbType }; + private static void SetTypeMap(Dictionary value) + { + typeMap = value; + + // this cache is predicated on the contents of the type-map; reset it + lock (s_ReadViaGetFieldValueCache) + { + s_ReadViaGetFieldValueCache.Clear(); + } } /// @@ -250,10 +313,10 @@ public static void RemoveTypeMap(Type type) if (!snapshot.ContainsKey(type)) return; // nothing to do - var newCopy = new Dictionary(snapshot); + var newCopy = new Dictionary(snapshot); newCopy.Remove(type); - typeMap = newCopy; + SetTypeMap(newCopy); } /// @@ -371,9 +434,13 @@ public static void SetDbType(IDataParameter parameter, object value) { type = Enum.GetUnderlyingType(type); } - if (typeMap.TryGetValue(type, out var dbType)) + if (typeMap.TryGetValue(type, out var mapEntry)) { - return dbType; + if ((mapEntry.Flags & TypeMapEntryFlags.SetType) == 0) + { + return null; + } + return mapEntry.DbType; } if (type.FullName == LinqBinary) { @@ -1024,7 +1091,7 @@ private static GridReader QueryMultipleImpl(this IDbConnection cnn, ref CommandD CacheInfo info = GetCacheInfo(identity, param, command.AddToCache); IDbCommand cmd = null; - IDataReader reader = null; + DbDataReader reader = null; bool wasClosed = cnn.State == ConnectionState.Closed; try { @@ -1059,18 +1126,18 @@ private static GridReader QueryMultipleImpl(this IDbConnection cnn, ref CommandD } } - private static IDataReader ExecuteReaderWithFlagsFallback(IDbCommand cmd, bool wasClosed, CommandBehavior behavior) + private static DbDataReader ExecuteReaderWithFlagsFallback(IDbCommand cmd, bool wasClosed, CommandBehavior behavior) { try { - return cmd.ExecuteReader(GetBehavior(wasClosed, behavior)); + return GetDbDataReader(cmd.ExecuteReader(GetBehavior(wasClosed, behavior))); } catch (ArgumentException ex) { // thanks, Sqlite! if (Settings.DisableCommandBehaviorOptimizations(behavior, ex)) { // we can retry; this time it will have different flags - return cmd.ExecuteReader(GetBehavior(wasClosed, behavior)); + return GetDbDataReader(cmd.ExecuteReader(GetBehavior(wasClosed, behavior))); } throw; } @@ -1083,7 +1150,7 @@ private static IEnumerable QueryImpl(this IDbConnection cnn, CommandDefini var info = GetCacheInfo(identity, param, command.AddToCache); IDbCommand cmd = null; - IDataReader reader = null; + DbDataReader reader = null; bool wasClosed = cnn.State == ConnectionState.Closed; try @@ -1176,7 +1243,7 @@ private static T QueryRowImpl(IDbConnection cnn, Row row, ref CommandDefiniti var info = GetCacheInfo(identity, param, command.AddToCache); IDbCommand cmd = null; - IDataReader reader = null; + DbDataReader reader = null; bool wasClosed = cnn.State == ConnectionState.Closed; try @@ -1234,7 +1301,7 @@ private static T QueryRowImpl(IDbConnection cnn, Row row, ref CommandDefiniti /// Shared value deserialization path for QueryRowImpl and QueryRowAsync /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static T ReadRow(CacheInfo info, Identity identity, ref CommandDefinition command, Type effectiveType, IDataReader reader) + private static T ReadRow(CacheInfo info, Identity identity, ref CommandDefinition command, Type effectiveType, DbDataReader reader) { var tuple = info.Deserializer; int hash = GetColumnHash(reader); @@ -1250,7 +1317,7 @@ private static T ReadRow(CacheInfo info, Identity identity, ref CommandDefini } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static T GetValue(IDataReader reader, Type effectiveType, object val) + private static T GetValue(DbDataReader reader, Type effectiveType, object val) { if (val is T tVal) { @@ -1451,14 +1518,14 @@ private static IEnumerable MultiMap MultiMapImpl(this IDbConnection cnn, CommandDefinition command, Delegate map, string splitOn, IDataReader reader, Identity identity, bool finalize) + private static IEnumerable MultiMapImpl(this IDbConnection cnn, CommandDefinition command, Delegate map, string splitOn, DbDataReader reader, Identity identity, bool finalize) { object param = command.Parameters; identity ??= new Identity(command.CommandText, command.CommandType, cnn, typeof(TFirst), param?.GetType()); CacheInfo cinfo = GetCacheInfo(identity, param, command.AddToCache); IDbCommand ownedCommand = null; - IDataReader ownedReader = null; + DbDataReader ownedReader = null; bool wasClosed = cnn?.State == ConnectionState.Closed; try @@ -1471,7 +1538,7 @@ private static IEnumerable MultiMapImpl[] otherDeserializers; + Func[] otherDeserializers; int hash = GetColumnHash(reader); if ((deserializer = cinfo.Deserializer).Func == null || (otherDeserializers = cinfo.OtherDeserializers) == null || hash != deserializer.Hash) @@ -1482,7 +1549,7 @@ private static IEnumerable MultiMapImpl mapIt = GenerateMapper(deserializer.Func, otherDeserializers, map); + Func mapIt = GenerateMapper(deserializer.Func, otherDeserializers, map); if (mapIt != null) { @@ -1517,7 +1584,7 @@ private static CommandBehavior GetBehavior(bool close, CommandBehavior @default) return (close ? (@default | CommandBehavior.CloseConnection) : @default) & Settings.AllowedCommandBehaviors; } - private static IEnumerable MultiMapImpl(this IDbConnection cnn, CommandDefinition command, Type[] types, Func map, string splitOn, IDataReader reader, Identity identity, bool finalize) + private static IEnumerable MultiMapImpl(this IDbConnection cnn, CommandDefinition command, Type[] types, Func map, string splitOn, DbDataReader reader, Identity identity, bool finalize) { if (types.Length < 1) { @@ -1529,7 +1596,7 @@ private static IEnumerable MultiMapImpl(this IDbConnection cnn CacheInfo cinfo = GetCacheInfo(identity, param, command.AddToCache); IDbCommand ownedCommand = null; - IDataReader ownedReader = null; + DbDataReader ownedReader = null; bool wasClosed = cnn?.State == ConnectionState.Closed; try @@ -1542,7 +1609,7 @@ private static IEnumerable MultiMapImpl(this IDbConnection cnn reader = ownedReader; } DeserializerState deserializer; - Func[] otherDeserializers; + Func[] otherDeserializers; int hash = GetColumnHash(reader); if ((deserializer = cinfo.Deserializer).Func == null || (otherDeserializers = cinfo.OtherDeserializers) == null || hash != deserializer.Hash) @@ -1553,7 +1620,7 @@ private static IEnumerable MultiMapImpl(this IDbConnection cnn SetQueryCache(identity, cinfo); } - Func mapIt = GenerateMapper(types.Length, deserializer.Func, otherDeserializers, map); + Func mapIt = GenerateMapper(types.Length, deserializer.Func, otherDeserializers, map); if (mapIt != null) { @@ -1583,7 +1650,7 @@ private static IEnumerable MultiMapImpl(this IDbConnection cnn } } - private static Func GenerateMapper(Func deserializer, Func[] otherDeserializers, object map) + private static Func GenerateMapper(Func deserializer, Func[] otherDeserializers, object map) => otherDeserializers.Length switch { 1 => r => ((Func)map)((TFirst)deserializer(r), (TSecond)otherDeserializers[0](r)), @@ -1595,7 +1662,7 @@ private static Func GenerateMapper throw new NotSupportedException(), }; - private static Func GenerateMapper(int length, Func deserializer, Func[] otherDeserializers, Func map) + private static Func GenerateMapper(int length, Func deserializer, Func[] otherDeserializers, Func map) { return r => { @@ -1611,9 +1678,9 @@ private static Func GenerateMapper(int length, Fu }; } - private static Func[] GenerateDeserializers(Identity identity, string splitOn, IDataReader reader) + private static Func[] GenerateDeserializers(Identity identity, string splitOn, DbDataReader reader) { - var deserializers = new List>(); + var deserializers = new List>(); var splits = splitOn.Split(',').Select(s => s.Trim()).ToArray(); bool isMultiSplit = splits.Length > 1; @@ -1680,7 +1747,7 @@ private static Func[] GenerateDeserializers(Identity identi return deserializers.ToArray(); } - private static int GetNextSplitDynamic(int startIdx, string splitOn, IDataReader reader) + private static int GetNextSplitDynamic(int startIdx, string splitOn, DbDataReader reader) { if (startIdx == reader.FieldCount) { @@ -1703,7 +1770,7 @@ private static int GetNextSplitDynamic(int startIdx, string splitOn, IDataReader return reader.FieldCount; } - private static int GetNextSplit(int startIdx, string splitOn, IDataReader reader) + private static int GetNextSplit(int startIdx, string splitOn, DbDataReader reader) { if (splitOn == "*") { @@ -1817,15 +1884,41 @@ private static void PassByPosition(IDbCommand cmd) }); } - private static Func GetDeserializer(Type type, IDataReader reader, int startBound, int length, bool returnNullIfFirstMissing) + static DbDataReader GetDbDataReader(IDataReader reader, bool disposeOnFail = true) + { + return reader as DbDataReader ?? Throw(reader, disposeOnFail); + static DbDataReader Throw(IDataReader reader, bool disposeOnFail) + { + if (reader is null) + { + throw new ArgumentNullException(nameof(reader)); + } + if (disposeOnFail) + { + reader.Dispose(); // don't leak + } + // in reality, all providers have satisfied this since forever; we should have made Dapper target DbConnection, oops! + throw new NotSupportedException("The provided reader is required to be a DbDataReader, and is not"); + } + } + + private static Func GetDeserializer(Type type, DbDataReader reader, int startBound, int length, bool returnNullIfFirstMissing) { + + // dynamic is passed in as Object ... by c# design if (type == typeof(object) || type == typeof(DapperRow)) { return GetDapperRowDeserializer(reader, startBound, length, returnNullIfFirstMissing); } + Type underlyingType = null; - if (!(typeMap.ContainsKey(type) || type.IsEnum || type.IsArray || type.FullName == LinqBinary + bool useGetFieldValue = false; + if (typeMap.TryGetValue(type, out var mapEntry)) + { + useGetFieldValue = (mapEntry.Flags & TypeMapEntryFlags.UseGetFieldValue) != 0; + } + else if (!(type.IsEnum || type.IsArray || type.FullName == LinqBinary || (type.IsValueType && (underlyingType = Nullable.GetUnderlyingType(type)) != null && underlyingType.IsEnum))) { if (typeHandlers.TryGetValue(type, out ITypeHandler handler)) @@ -1834,10 +1927,10 @@ private static Func GetDeserializer(Type type, IDataReader } return GetTypeDeserializer(type, reader, startBound, length, returnNullIfFirstMissing); } - return GetStructDeserializer(type, underlyingType ?? type, startBound); + return GetStructDeserializer(type, underlyingType ?? type, startBound, useGetFieldValue); } - private static Func GetHandlerDeserializer(ITypeHandler handler, Type type, int startBound) + private static Func GetHandlerDeserializer(ITypeHandler handler, Type type, int startBound) { return reader => handler.Parse(type, reader.GetValue(startBound)); } @@ -1861,7 +1954,7 @@ private static Exception MultiMapException(IDataRecord reader, string splitOnCol } } - internal static Func GetDapperRowDeserializer(IDataRecord reader, int startBound, int length, bool returnNullIfFirstMissing) + internal static Func GetDapperRowDeserializer(DbDataReader reader, int startBound, int length, bool returnNullIfFirstMissing) { var fieldCount = reader.FieldCount; if (length == -1) @@ -2886,7 +2979,7 @@ private static T ExecuteScalarImpl(IDbConnection cnn, ref CommandDefinition c return Parse(result); } - private static IDataReader ExecuteReaderImpl(IDbConnection cnn, ref CommandDefinition command, CommandBehavior commandBehavior, out IDbCommand cmd) + private static DbDataReader ExecuteReaderImpl(IDbConnection cnn, ref CommandDefinition command, CommandBehavior commandBehavior, out IDbCommand cmd) { Action paramReader = GetParameterReader(cnn, ref command); cmd = null; @@ -2932,7 +3025,7 @@ private static Action GetParameterReader(IDbConnection cnn, return paramReader; } - private static Func GetStructDeserializer(Type type, Type effectiveType, int index) + private static Func GetStructDeserializer(Type type, Type effectiveType, int index, bool useGetFieldValue) { // no point using special per-type handling here; it boils down to the same, plus not all are supported anyway (see: SqlDataReader.GetChar - not supported!) #pragma warning disable 618 @@ -2970,12 +3063,41 @@ private static Func GetStructDeserializer(Type type, Type e return val is DBNull ? null : handler.Parse(type, val); }; } - return r => + return useGetFieldValue ? ReadViaGetFieldValueFactory(type, index) : ReadViaGetValue(index); + + static Func ReadViaGetValue(int index) + => reader => + { + var val = reader.GetValue(index); + return val is DBNull ? null : val; + }; + } + + static Func ReadViaGetFieldValueFactory(Type type, int index) + { + type = Nullable.GetUnderlyingType(type) ?? type; + var factory = (Func>)s_ReadViaGetFieldValueCache[type]; + if (factory is null) { - var val = r.GetValue(index); - return val is DBNull ? null : val; - }; + factory = (Func>)Delegate.CreateDelegate( + typeof(Func>), null, typeof(SqlMapper).GetMethod( + nameof(UnderlyingReadViaGetFieldValueFactory), BindingFlags.Static | BindingFlags.NonPublic) + .MakeGenericMethod(type)); + lock (s_ReadViaGetFieldValueCache) + { + s_ReadViaGetFieldValueCache[type] = factory; + } + } + return factory(index); } + // cache of ReadViaGetFieldValueFactory for per-value T + static readonly Hashtable s_ReadViaGetFieldValueCache = new Hashtable(); + + static Func UnderlyingReadViaGetFieldValueFactory(int index) + => reader => reader.IsDBNull(index) ? null : reader.GetFieldValue(index); + + static bool UseGetFieldValue(Type type) => typeMap.TryGetValue(type, out var mapEntry) + && (mapEntry.Flags & TypeMapEntryFlags.UseGetFieldValue) != 0; private static T Parse(object value) { @@ -3000,9 +3122,13 @@ private static T Parse(object value) private static readonly MethodInfo enumParse = typeof(Enum).GetMethod(nameof(Enum.Parse), new Type[] { typeof(Type), typeof(string), typeof(bool) }), - getItem = typeof(IDataRecord).GetProperties(BindingFlags.Instance | BindingFlags.Public) + getItem = typeof(DbDataReader).GetProperties(BindingFlags.Instance | BindingFlags.Public) .Where(p => p.GetIndexParameters().Length > 0 && p.GetIndexParameters()[0].ParameterType == typeof(int)) - .Select(p => p.GetGetMethod()).First(); + .Select(p => p.GetGetMethod()).First(), + getFieldValueT = typeof(DbDataReader).GetMethod(nameof(DbDataReader.GetFieldValue), + BindingFlags.Instance | BindingFlags.Public, null, new Type[] { typeof(int) }, null), + isDbNull = typeof(DbDataReader).GetMethod(nameof(DbDataReader.IsDBNull), + BindingFlags.Instance | BindingFlags.Public, null, new Type[] { typeof(int) }, null); /// /// Gets type-map for the given type @@ -3078,9 +3204,31 @@ public static void SetTypeMap(Type type, ITypeMap map) /// /// /// +#if DEBUG // make sure we're not using this internally + [Obsolete(nameof(DbDataReader) + " API should be preferred")] +#endif public static Func GetTypeDeserializer( Type type, IDataReader reader, int startBound = 0, int length = -1, bool returnNullIfFirstMissing = false ) + { + return WrapObjectReader(GetTypeDeserializer(type, GetDbDataReader(reader, false), startBound, length, returnNullIfFirstMissing)); + } + + private static Func WrapObjectReader(Func dbReader) + => reader => dbReader(GetDbDataReader(reader)); // we'll eat the extra layer here; this is not a core API + + + /// + /// Internal use only + /// + /// + /// + /// + /// + /// + public static Func GetTypeDeserializer( + Type type, DbDataReader reader, int startBound = 0, int length = -1, bool returnNullIfFirstMissing = false + ) { return TypeDeserializerCache.GetReader(type, reader, startBound, length, returnNullIfFirstMissing); } @@ -3104,8 +3252,8 @@ private static LocalBuilder GetTempLocal(ILGenerator il, ref Dictionary GetTypeDeserializerImpl( - Type type, IDataReader reader, int startBound = 0, int length = -1, bool returnNullIfFirstMissing = false + private static Func GetTypeDeserializerImpl( + Type type, DbDataReader reader, int startBound = 0, int length = -1, bool returnNullIfFirstMissing = false ) { if (length == -1) @@ -3119,7 +3267,7 @@ private static Func GetTypeDeserializerImpl( } var returnType = type.IsValueType ? typeof(object) : type; - var dm = new DynamicMethod("Deserialize" + Guid.NewGuid().ToString(), returnType, new[] { typeof(IDataReader) }, type, true); + var dm = new DynamicMethod("Deserialize" + Guid.NewGuid().ToString(), returnType, new[] { typeof(DbDataReader) }, type, true); var il = dm.GetILGenerator(); if (IsValueTuple(type)) @@ -3131,11 +3279,11 @@ private static Func GetTypeDeserializerImpl( GenerateDeserializerFromMap(type, reader, startBound, length, returnNullIfFirstMissing, il); } - var funcType = System.Linq.Expressions.Expression.GetFuncType(typeof(IDataReader), returnType); - return (Func)dm.CreateDelegate(funcType); + var funcType = System.Linq.Expressions.Expression.GetFuncType(typeof(DbDataReader), returnType); + return (Func)dm.CreateDelegate(funcType); } - private static void GenerateValueTupleDeserializer(Type valueTupleType, IDataReader reader, int startBound, int length, ILGenerator il) + private static void GenerateValueTupleDeserializer(Type valueTupleType, DbDataReader reader, int startBound, int length, ILGenerator il) { var nullableUnderlyingType = Nullable.GetUnderlyingType(valueTupleType); var currentValueTupleType = nullableUnderlyingType ?? valueTupleType; @@ -3201,12 +3349,15 @@ private static void GenerateValueTupleDeserializer(Type valueTupleType, IDataRea valueCopyLocal: null, reader.GetFieldType(startBound + i), targetType, - out var isDbNullLabel); + out var isDbNullLabel, out bool popWhenNull); var finishLabel = il.DefineLabel(); il.Emit(OpCodes.Br_S, finishLabel); il.MarkLabel(isDbNullLabel); - il.Emit(OpCodes.Pop); + if (popWhenNull) + { + il.Emit(OpCodes.Pop); + } LoadDefaultValue(il, targetType); @@ -3234,7 +3385,7 @@ private static void GenerateValueTupleDeserializer(Type valueTupleType, IDataRea il.Emit(OpCodes.Ret); } - private static void GenerateDeserializerFromMap(Type type, IDataReader reader, int startBound, int length, bool returnNullIfFirstMissing, ILGenerator il) + private static void GenerateDeserializerFromMap(Type type, DbDataReader reader, int startBound, int length, bool returnNullIfFirstMissing, ILGenerator il) { var currentIndexDiagnosticLocal = il.DeclareLocal(typeof(int)); var returnValueLocal = il.DeclareLocal(type); @@ -3348,7 +3499,7 @@ private static void GenerateDeserializerFromMap(Type type, IDataReader reader, i EmitInt32(il, index); il.Emit(OpCodes.Stloc, currentIndexDiagnosticLocal); - LoadReaderValueOrBranchToDBNullLabel(il, index, ref stringEnumLocal, valueCopyDiagnosticLocal, reader.GetFieldType(index), memberType, out var isDbNullLabel); + LoadReaderValueOrBranchToDBNullLabel(il, index, ref stringEnumLocal, valueCopyDiagnosticLocal, reader.GetFieldType(index), memberType, out var isDbNullLabel, out bool popWhenNull); if (specializedConstructor == null) { @@ -3365,15 +3516,14 @@ private static void GenerateDeserializerFromMap(Type type, IDataReader reader, i il.Emit(OpCodes.Br_S, finishLabel); // stack is now [target] - il.MarkLabel(isDbNullLabel); // incoming stack: [target][target][value] + il.MarkLabel(isDbNullLabel); // incoming stack: [target][target][(and possibly value)] + if (popWhenNull) il.Emit(OpCodes.Pop); // stack is now [target][target] if (specializedConstructor != null) { - il.Emit(OpCodes.Pop); LoadDefaultValue(il, item.MemberType); } else if (applyNullSetting && (!memberType.IsValueType || Nullable.GetUnderlyingType(memberType) != null)) { - il.Emit(OpCodes.Pop); // stack is now [target][target] // can load a null with this value if (memberType.IsValueType) { // must be Nullable for some T @@ -3397,7 +3547,6 @@ private static void GenerateDeserializerFromMap(Type type, IDataReader reader, i } else { - il.Emit(OpCodes.Pop); // stack is now [target][target] il.Emit(OpCodes.Pop); // stack is now [target] } @@ -3462,11 +3611,50 @@ private static void LoadDefaultValue(ILGenerator il, Type type) } } - private static void LoadReaderValueOrBranchToDBNullLabel(ILGenerator il, int index, ref LocalBuilder stringEnumLocal, LocalBuilder valueCopyLocal, Type colType, Type memberType, out Label isDbNullLabel) + private static void LoadReaderValueViaGetFieldValue(ILGenerator il, int index, Type memberType, LocalBuilder valueCopyLocal, Label isDbNullLabel, out bool popWhenNull) + { + popWhenNull = false; + var underlyingType = Nullable.GetUnderlyingType(memberType) ?? memberType; + + // for consistency, always do a null check (the GetValue approach always tests for DbNull and jumps) + il.Emit(OpCodes.Ldarg_0); // stack is now [...][reader] + EmitInt32(il, index); // stack is now [...][reader][index] + il.Emit(OpCodes.Callvirt, isDbNull); // stack is now [...][bool] + il.Emit(OpCodes.Brtrue_S, isDbNullLabel); + + // DB reports not null; read the value + il.Emit(OpCodes.Ldarg_0); // stack is now [...][reader] + EmitInt32(il, index); // stack is now [...][reader][index] + il.Emit(OpCodes.Callvirt, getFieldValueT.MakeGenericMethod(underlyingType)); // stack is now [...][T] + if (valueCopyLocal is not null) + { + il.Emit(OpCodes.Dup); // stack is now [...][T][T] + if (underlyingType.IsValueType) + { + il.Emit(OpCodes.Box, underlyingType); // stack is now [...][T][value-as-object] + } + il.Emit(OpCodes.Stloc, valueCopyLocal); // stack is now [...][T] + } + if (underlyingType != memberType) + { + // Nullable; wrap it + il.Emit(OpCodes.Newobj, memberType.GetConstructor(new[] { underlyingType })); // stack is now [...][T?] + } + } + + private static void LoadReaderValueOrBranchToDBNullLabel(ILGenerator il, int index, ref LocalBuilder stringEnumLocal, LocalBuilder valueCopyLocal, Type colType, Type memberType, out Label isDbNullLabel, out bool popWhenNull) { isDbNullLabel = il.DefineLabel(); + if (UseGetFieldValue(memberType)) + { + LoadReaderValueViaGetFieldValue(il, index, memberType, valueCopyLocal, isDbNullLabel, out popWhenNull); + return; + } + + popWhenNull = true; il.Emit(OpCodes.Ldarg_0); // stack is now [...][reader] EmitInt32(il, index); // stack is now [...][reader][index] + // default impl: use GetValue il.Emit(OpCodes.Callvirt, getItem); // stack is now [...][value-as-object] if (valueCopyLocal != null) diff --git a/Dapper/WrappedReader.cs b/Dapper/WrappedReader.cs index f092823f5..ded629b04 100644 --- a/Dapper/WrappedReader.cs +++ b/Dapper/WrappedReader.cs @@ -80,18 +80,7 @@ internal static class WrappedReader { // the purpose of wrapping here is to allow closing a reader to *also* close // the command, without having to explicitly hand the command back to the - // caller; what that actually looks like depends on what we get: if we are - // given a DbDataReader, we will surface a DbDataReader; if we are given - // a raw IDataReader, we will surface that; and if null: null - public static IDataReader Create(IDbCommand cmd, IDataReader reader) - { - if (cmd == null) return reader; // no need to wrap if no command - - if (reader is DbDataReader dbr) return new DbWrappedReader(cmd, dbr); - if (reader != null) return new BasicWrappedReader(cmd, reader); - cmd.Dispose(); - return null; // GIGO - } + // caller public static DbDataReader Create(IDbCommand cmd, DbDataReader reader) { if (cmd == null) return reader; // no need to wrap if no command @@ -210,97 +199,6 @@ public override long GetChars(int i, long fieldoffset, char[] buffer, int buffer public override Task NextResultAsync(CancellationToken cancellationToken) => _reader.NextResultAsync(cancellationToken); public override Task ReadAsync(CancellationToken cancellationToken) => _reader.ReadAsync(cancellationToken); public override int VisibleFieldCount => _reader.VisibleFieldCount; - protected override DbDataReader GetDbDataReader(int ordinal) => (((IDataReader)_reader).GetData(ordinal) as DbDataReader) ?? throw new NotSupportedException(); - } - - internal class BasicWrappedReader : IWrappedDataReader - { - private IDataReader _reader; - private IDbCommand _cmd; - - IDataReader IWrappedDataReader.Reader => _reader; - - IDbCommand IWrappedDataReader.Command => _cmd; - - public BasicWrappedReader(IDbCommand cmd, IDataReader reader) - { - _cmd = cmd; - _reader = reader; - } - - void IDataReader.Close() => _reader.Close(); - - int IDataReader.Depth => _reader.Depth; - - DataTable IDataReader.GetSchemaTable() => _reader.GetSchemaTable(); - - bool IDataReader.IsClosed => _reader.IsClosed; - - bool IDataReader.NextResult() => _reader.NextResult(); - - bool IDataReader.Read() => _reader.Read(); - - int IDataReader.RecordsAffected => _reader.RecordsAffected; - - void IDisposable.Dispose() - { - _reader.Close(); - _reader.Dispose(); - _reader = DisposedReader.Instance; - _cmd?.Dispose(); - _cmd = null; - } - - int IDataRecord.FieldCount => _reader.FieldCount; - - bool IDataRecord.GetBoolean(int i) => _reader.GetBoolean(i); - - byte IDataRecord.GetByte(int i) => _reader.GetByte(i); - - long IDataRecord.GetBytes(int i, long fieldOffset, byte[] buffer, int bufferoffset, int length) => - _reader.GetBytes(i, fieldOffset, buffer, bufferoffset, length); - - char IDataRecord.GetChar(int i) => _reader.GetChar(i); - - long IDataRecord.GetChars(int i, long fieldoffset, char[] buffer, int bufferoffset, int length) => - _reader.GetChars(i, fieldoffset, buffer, bufferoffset, length); - - IDataReader IDataRecord.GetData(int i) => _reader.GetData(i); - - string IDataRecord.GetDataTypeName(int i) => _reader.GetDataTypeName(i); - - DateTime IDataRecord.GetDateTime(int i) => _reader.GetDateTime(i); - - decimal IDataRecord.GetDecimal(int i) => _reader.GetDecimal(i); - - double IDataRecord.GetDouble(int i) => _reader.GetDouble(i); - - Type IDataRecord.GetFieldType(int i) => _reader.GetFieldType(i); - - float IDataRecord.GetFloat(int i) => _reader.GetFloat(i); - - Guid IDataRecord.GetGuid(int i) => _reader.GetGuid(i); - - short IDataRecord.GetInt16(int i) => _reader.GetInt16(i); - - int IDataRecord.GetInt32(int i) => _reader.GetInt32(i); - - long IDataRecord.GetInt64(int i) => _reader.GetInt64(i); - - string IDataRecord.GetName(int i) => _reader.GetName(i); - - int IDataRecord.GetOrdinal(string name) => _reader.GetOrdinal(name); - - string IDataRecord.GetString(int i) => _reader.GetString(i); - - object IDataRecord.GetValue(int i) => _reader.GetValue(i); - - int IDataRecord.GetValues(object[] values) => _reader.GetValues(values); - - bool IDataRecord.IsDBNull(int i) => _reader.IsDBNull(i); - - object IDataRecord.this[string name] => _reader[name]; - - object IDataRecord.this[int i] => _reader[i]; + protected override DbDataReader GetDbDataReader(int ordinal) => _reader.GetData(ordinal); } } diff --git a/Directory.Build.props b/Directory.Build.props index 0cb2bf374..9b3e9cf67 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -21,8 +21,9 @@ false true - + 9.0 + false diff --git a/docs/index.md b/docs/index.md index 84f237b51..ea0af98bf 100644 --- a/docs/index.md +++ b/docs/index.md @@ -22,6 +22,10 @@ Note: to get the latest pre-release build, add ` -Pre` to the end of the command ### unreleased +- add support for `SqlDecimal` and other types that need to be accessed via `DbDataReader.GetFieldValue` +- add an overload of `AddTypeMap` that supports `DbDataReader.GetFieldValue` for additional types +- acknowledge that in reality we only support `DbDataReader`; this has been true (via `DbConnection`) for `async` forever + (note: new PRs will not be merged until they add release note wording here) ### 2.0.123 diff --git a/tests/Dapper.Tests/Dapper.Tests.csproj b/tests/Dapper.Tests/Dapper.Tests.csproj index b9f8b9347..35af8a509 100644 --- a/tests/Dapper.Tests/Dapper.Tests.csproj +++ b/tests/Dapper.Tests/Dapper.Tests.csproj @@ -2,7 +2,7 @@ Dapper.Tests Dapper Core Test Suite - netcoreapp3.1;net462;net472;net5.0 + net462;net472;net6.0 $(DefineConstants);MSSQLCLIENT $(NoWarn);IDE0017;IDE0034;IDE0037;IDE0039;IDE0042;IDE0044;IDE0051;IDE0052;IDE0059;IDE0060;IDE0063;IDE1006;xUnit1004;CA1806;CA1816;CA1822;CA1825;CA2208 @@ -17,7 +17,6 @@ - diff --git a/tests/Dapper.Tests/DataReaderTests.cs b/tests/Dapper.Tests/DataReaderTests.cs index ab28619bc..ea6b9e1c2 100644 --- a/tests/Dapper.Tests/DataReaderTests.cs +++ b/tests/Dapper.Tests/DataReaderTests.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Data.Common; using System.Linq; using Xunit; @@ -14,14 +15,17 @@ public sealed class MicrosoftSqlClientDataReaderTests : DataReaderTests : TestBase where TProvider : DatabaseProvider { [Fact] - public void GetSameReaderForSameShape() + public void GetSameReaderForSameShape_IDataReader() { var origReader = connection.ExecuteReader("select 'abc' as Name, 123 as Id"); +#pragma warning disable CS0618 // Type or member is obsolete var origParser = origReader.GetRowParser(typeof(HazNameId)); var typedParser = origReader.GetRowParser(); +#pragma warning restore CS0618 // Type or member is obsolete - Assert.True(ReferenceEquals(origParser, typedParser)); + // because wrapped for IDataReader, not same instance each time + Assert.False(ReferenceEquals(origParser, typedParser)); var list = origReader.Parse().ToList(); Assert.Single(list); @@ -30,6 +34,40 @@ public void GetSameReaderForSameShape() origReader.Dispose(); var secondReader = connection.ExecuteReader("select 'abc' as Name, 123 as Id"); +#pragma warning disable CS0618 // Type or member is obsolete + var secondParser = secondReader.GetRowParser(typeof(HazNameId)); + var thirdParser = secondReader.GetRowParser(typeof(HazNameId), 1); +#pragma warning restore CS0618 // Type or member is obsolete + + list = secondReader.Parse().ToList(); + Assert.Single(list); + Assert.Equal("abc", list[0].Name); + Assert.Equal(123, list[0].Id); + secondReader.Dispose(); + + // now: should be different readers, and because wrapped for IDataReader, not same parser + Assert.False(ReferenceEquals(origReader, secondReader)); + Assert.False(ReferenceEquals(origParser, secondParser)); + Assert.False(ReferenceEquals(secondParser, thirdParser)); + } + + [Fact] + public void GetSameReaderForSameShape_DbDataReader() + { + var origReader = Assert.IsAssignableFrom(connection.ExecuteReader("select 'abc' as Name, 123 as Id")); + var origParser = origReader.GetRowParser(typeof(HazNameId)); + + var typedParser = origReader.GetRowParser(); + + Assert.True(ReferenceEquals(origParser, typedParser)); + + var list = origReader.Parse().ToList(); + Assert.Single(list); + Assert.Equal("abc", list[0].Name); + Assert.Equal(123, list[0].Id); + origReader.Dispose(); + + var secondReader = Assert.IsAssignableFrom(connection.ExecuteReader("select 'abc' as Name, 123 as Id")); var secondParser = secondReader.GetRowParser(typeof(HazNameId)); var thirdParser = secondReader.GetRowParser(typeof(HazNameId), 1); @@ -56,7 +94,7 @@ public void TestTreatIntAsABool() } [Fact] - public void DiscriminatedUnion() + public void DiscriminatedUnion_IDataReader() { List result = new List(); using (var reader = connection.ExecuteReader(@" @@ -66,8 +104,10 @@ union all { if (reader.Read()) { +#pragma warning disable CS0618 var toFoo = reader.GetRowParser(typeof(Discriminated_Foo)); var toBar = reader.GetRowParser(typeof(Discriminated_Bar)); +#pragma warning restore CS0618 var col = reader.GetOrdinal("Type"); do @@ -95,7 +135,46 @@ union all } [Fact] - public void DiscriminatedUnionWithMultiMapping() + public void DiscriminatedUnion_DbDataReader() + { + List result = new List(); + using (var reader = Assert.IsAssignableFrom(connection.ExecuteReader(@" +select 'abc' as Name, 1 as Type, 3.0 as Value +union all +select 'def' as Name, 2 as Type, 4.0 as Value"))) + { + if (reader.Read()) + { + var toFoo = reader.GetRowParser(typeof(Discriminated_Foo)); + var toBar = reader.GetRowParser(typeof(Discriminated_Bar)); + + var col = reader.GetOrdinal("Type"); + do + { + switch (reader.GetInt32(col)) + { + case 1: + result.Add(toFoo(reader)); + break; + case 2: + result.Add(toBar(reader)); + break; + } + } while (reader.Read()); + } + } + + Assert.Equal(2, result.Count); + Assert.Equal(1, result[0].Type); + Assert.Equal(2, result[1].Type); + var foo = (Discriminated_Foo)result[0]; + Assert.Equal("abc", foo.Name); + var bar = (Discriminated_Bar)result[1]; + Assert.Equal(bar.Value, (float)4.0); + } + + [Fact] + public void DiscriminatedUnionWithMultiMapping_IDataReader() { var result = new List(); using (var reader = connection.ExecuteReader(@" @@ -108,6 +187,60 @@ union all var col = reader.GetOrdinal("Type"); var splitOn = reader.GetOrdinal("Id"); +#pragma warning disable CS0618 + var toFoo = reader.GetRowParser(typeof(DiscriminatedWithMultiMapping_Foo), 0, splitOn); + var toBar = reader.GetRowParser(typeof(DiscriminatedWithMultiMapping_Bar), 0, splitOn); + var toHaz = reader.GetRowParser(typeof(HazNameId), splitOn, reader.FieldCount - splitOn); +#pragma warning restore CS0618 + + do + { + DiscriminatedWithMultiMapping_BaseType obj = null; + switch (reader.GetInt32(col)) + { + case 1: + obj = toFoo(reader); + break; + case 2: + obj = toBar(reader); + break; + } + + Assert.NotNull(obj); + obj.HazNameIdObject = toHaz(reader); + result.Add(obj); + + } while (reader.Read()); + } + } + + Assert.Equal(2, result.Count); + Assert.Equal(1, result[0].Type); + Assert.Equal(2, result[1].Type); + var foo = (DiscriminatedWithMultiMapping_Foo)result[0]; + Assert.Equal("abc", foo.Name); + Assert.Equal(1, foo.HazNameIdObject.Id); + Assert.Equal("zxc", foo.HazNameIdObject.Name); + var bar = (DiscriminatedWithMultiMapping_Bar)result[1]; + Assert.Equal(bar.Value, (float)4.0); + Assert.Equal(2, bar.HazNameIdObject.Id); + Assert.Equal("qwe", bar.HazNameIdObject.Name); + } + + [Fact] + public void DiscriminatedUnionWithMultiMapping_DbDataReader() + { + var result = new List(); + using (var reader = Assert.IsAssignableFrom(connection.ExecuteReader(@" +select 'abc' as Name, 1 as Type, 3.0 as Value, 1 as Id, 'zxc' as Name +union all +select 'def' as Name, 2 as Type, 4.0 as Value, 2 as Id, 'qwe' as Name"))) + { + if (reader.Read()) + { + var col = reader.GetOrdinal("Type"); + var splitOn = reader.GetOrdinal("Id"); + var toFoo = reader.GetRowParser(typeof(DiscriminatedWithMultiMapping_Foo), 0, splitOn); var toBar = reader.GetRowParser(typeof(DiscriminatedWithMultiMapping_Bar), 0, splitOn); var toHaz = reader.GetRowParser(typeof(HazNameId), splitOn, reader.FieldCount - splitOn); diff --git a/tests/Dapper.Tests/ParameterTests.cs b/tests/Dapper.Tests/ParameterTests.cs index 7179563ad..a6a17e24f 100644 --- a/tests/Dapper.Tests/ParameterTests.cs +++ b/tests/Dapper.Tests/ParameterTests.cs @@ -3,13 +3,14 @@ using System.Collections.Generic; using System.ComponentModel; using System.Data; +using System.Data.Common; using System.Data.SqlTypes; +using System.Diagnostics; using System.Dynamic; -using System.Linq; -using Xunit; using System.Globalization; +using System.Linq; using System.Text.RegularExpressions; -using System.Diagnostics; +using Xunit; #if ENTITY_FRAMEWORK using System.Data.Entity.Spatial; @@ -1593,5 +1594,116 @@ private static int GetExpectedListExpansionCount(int count, bool enabled) if (delta != 0) blocks++; return blocks * padFactor; } + + [Fact] + public void Issue1907_SqlDecimalPreciseValues() + { + bool close = false; + try + { + if (connection.State != ConnectionState.Open) + { + connection.Open(); + close = true; + } + connection.Execute(@" +create table #Issue1907 ( + Id int not null primary key identity(1,1), + Value numeric(30,15) not null);"); + + const string PreciseValue = "999999999999999.999999999999999"; + SqlDecimal sentValue = SqlDecimal.Parse(PreciseValue), recvValue; + connection.Execute("insert #Issue1907 (Value) values (@value)", new { value = sentValue }); + + // access via vendor-specific API; if this fails, nothing else can work + using (var wrappedReader = connection.ExecuteReader("select Id, Value from #Issue1907")) + { + var reader = Assert.IsAssignableFrom(wrappedReader).Reader; + Assert.True(reader.Read()); + if (reader is Microsoft.Data.SqlClient.SqlDataReader msReader) + { + recvValue = msReader.GetSqlDecimal(1); + } + else if (reader is System.Data.SqlClient.SqlDataReader sdReader) + { + recvValue = sdReader.GetSqlDecimal(1); + } + else + { + throw new InvalidOperationException($"unexpected reader type: {reader.GetType().FullName}"); + } + Assert.Equal(sentValue, recvValue); + Assert.Equal(recvValue.ToString(), PreciseValue); + + Assert.False(reader.Read()); + Assert.False(reader.NextResult()); + } + + // access via generic API + using (var wrappedReader = connection.ExecuteReader("select Id, Value from #Issue1907")) + { + var reader = Assert.IsAssignableFrom(Assert.IsAssignableFrom(wrappedReader).Reader); + Assert.True(reader.Read()); + recvValue = reader.GetFieldValue(1); + Assert.Equal(sentValue, recvValue); + Assert.Equal(recvValue.ToString(), PreciseValue); + + Assert.False(reader.Read()); + Assert.False(reader.NextResult()); + } + + // prove that we **cannot** fix ExecuteScalar, because ADO.NET itself doesn't work for that + Assert.Throws(() => + { + using var cmd = connection.CreateCommand(); + cmd.CommandText = "select Value from #Issue1907"; + cmd.CommandType = CommandType.Text; + _ = cmd.ExecuteScalar(); + }); + + // prove that simple read: works + recvValue = connection.QuerySingle("select Value from #Issue1907"); + Assert.Equal(sentValue, recvValue); + Assert.Equal(recvValue.ToString(), PreciseValue); + + recvValue = connection.QuerySingle("select Value from #Issue1907").Value; + Assert.Equal(sentValue, recvValue); + Assert.Equal(recvValue.ToString(), PreciseValue); + + // prove that object read: works + recvValue = connection.QuerySingle("select Id, Value from #Issue1907").Value; + Assert.Equal(sentValue, recvValue); + Assert.Equal(recvValue.ToString(), PreciseValue); + + recvValue = connection.QuerySingle("select Id, Value from #Issue1907").Value.Value; + Assert.Equal(sentValue, recvValue); + Assert.Equal(recvValue.ToString(), PreciseValue); + + // prove that value-tuple read: works + recvValue = connection.QuerySingle<(int Id, SqlDecimal Value)>("select Id, Value from #Issue1907").Value; + Assert.Equal(sentValue, recvValue); + Assert.Equal(recvValue.ToString(), PreciseValue); + + recvValue = connection.QuerySingle<(int Id, SqlDecimal? Value)>("select Id, Value from #Issue1907").Value.Value; + Assert.Equal(sentValue, recvValue); + Assert.Equal(recvValue.ToString(), PreciseValue); + } + finally + { + if (close) connection.Close(); + } + + } + class HazSqlDecimal + { + public int Id { get; set; } + public SqlDecimal Value { get; set; } + } + + class HazNullableSqlDecimal + { + public int Id { get; set; } + public SqlDecimal? Value { get; set; } + } } }