Skip to content

Commit

Permalink
add support for fetching values via GetFieldValue<T> (#1910)
Browse files Browse the repository at this point in the history
add support for fetching values via GetFieldValue<T>; this requires
DbDataReader, but in reality the reader is *always* DbDataReader;
acknowledge this, and stop pretending to use IDataReader internally
(no breaks to public API)
  • Loading branch information
mgravell authored Jun 9, 2023
1 parent a31dfd3 commit 01f03ef
Show file tree
Hide file tree
Showing 14 changed files with 586 additions and 208 deletions.
4 changes: 2 additions & 2 deletions Dapper/SqlMapper.Async.cs
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ private static async Task<IEnumerable<TReturn>> MultiMapAsync<TReturn>(this IDbC
}
}

private static IEnumerable<T> ExecuteReaderSync<T>(IDataReader reader, Func<IDataReader, object> func, object parameters)
private static IEnumerable<T> ExecuteReaderSync<T>(DbDataReader reader, Func<DbDataReader, object> func, object parameters)
{
using (reader)
{
Expand Down Expand Up @@ -1004,7 +1004,7 @@ public static async Task<GridReader> QueryMultipleAsync(this IDbConnection cnn,
CacheInfo info = GetCacheInfo(identity, param, command.AddToCache);

DbCommand cmd = null;
IDataReader reader = null;
DbDataReader reader = null;
bool wasClosed = cnn.State == ConnectionState.Closed;
try
{
Expand Down
3 changes: 2 additions & 1 deletion Dapper/SqlMapper.CacheInfo.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Data;
using System.Data.Common;
using System.Threading;

namespace Dapper
Expand All @@ -9,7 +10,7 @@ public static partial class SqlMapper
private class CacheInfo
{
public DeserializerState Deserializer { get; set; }
public Func<IDataReader, object>[] OtherDeserializers { get; set; }
public Func<DbDataReader, object>[] OtherDeserializers { get; set; }
public Action<IDbCommand, object> ParamReader { get; set; }
private int hitCount;
public int GetHitCount() { return Interlocked.CompareExchange(ref hitCount, 0, 0); }
Expand Down
5 changes: 3 additions & 2 deletions Dapper/SqlMapper.DeserializerState.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Data;
using System.Data.Common;

namespace Dapper
{
Expand All @@ -8,9 +9,9 @@ public static partial class SqlMapper
private struct DeserializerState
{
public readonly int Hash;
public readonly Func<IDataReader, object> Func;
public readonly Func<DbDataReader, object> Func;

public DeserializerState(int hash, Func<IDataReader, object> func)
public DeserializerState(int hash, Func<DbDataReader, object> func)
{
Hash = hash;
Func = func;
Expand Down
4 changes: 2 additions & 2 deletions Dapper/SqlMapper.GridReader.Async.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public static partial class SqlMapper
public partial class GridReader
{
private readonly CancellationToken cancel;
internal GridReader(IDbCommand command, IDataReader reader, Identity identity, DynamicParameters dynamicParams, bool addToCache, CancellationToken cancel)
internal GridReader(IDbCommand command, DbDataReader reader, Identity identity, DynamicParameters dynamicParams, bool addToCache, CancellationToken cancel)
: this(command, reader, identity, dynamicParams, addToCache)
{
this.cancel = cancel;
Expand Down Expand Up @@ -225,7 +225,7 @@ private async Task<T> ReadRowAsyncImplViaDbReader<T>(DbDataReader reader, Type t
return result;
}

private async Task<IEnumerable<T>> ReadBufferedAsync<T>(int index, Func<IDataReader, object> deserializer)
private async Task<IEnumerable<T>> ReadBufferedAsync<T>(int index, Func<DbDataReader, object> deserializer)
{
try
{
Expand Down
7 changes: 4 additions & 3 deletions Dapper/SqlMapper.GridReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Linq;
using System.Globalization;
using System.Runtime.CompilerServices;
using System.Data.Common;

namespace Dapper
{
Expand All @@ -14,11 +15,11 @@ public static partial class SqlMapper
/// </summary>
public partial class GridReader : IDisposable
{
private IDataReader reader;
private DbDataReader reader;
private readonly Identity identity;
private readonly bool addToCache;

internal GridReader(IDbCommand command, IDataReader reader, Identity identity, IParameterCallbacks callbacks, bool addToCache)
internal GridReader(IDbCommand command, DbDataReader reader, Identity identity, IParameterCallbacks callbacks, bool addToCache)
{
Command = command;
this.reader = reader;
Expand Down Expand Up @@ -351,7 +352,7 @@ public IEnumerable<TReturn> Read<TReturn>(Type[] types, Func<object[], TReturn>
return buffered ? result.ToList() : result;
}

private IEnumerable<T> ReadDeferred<T>(int index, Func<IDataReader, object> deserializer, Type effectiveType)
private IEnumerable<T> ReadDeferred<T>(int index, Func<DbDataReader, object> deserializer, Type effectiveType)
{
try
{
Expand Down
67 changes: 53 additions & 14 deletions Dapper/SqlMapper.IDataReader.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;

namespace Dapper
{
Expand All @@ -13,14 +14,15 @@ public static partial class SqlMapper
/// <param name="reader">The data reader to parse results from.</param>
public static IEnumerable<T> Parse<T>(this IDataReader reader)
{
if (reader.Read())
var dbReader = GetDbDataReader(reader, false);
if (dbReader.Read())
{
var effectiveType = typeof(T);
var deser = GetDeserializer(effectiveType, reader, 0, -1, false);
var deser = GetDeserializer(effectiveType, dbReader, 0, -1, false);
var convertToType = Nullable.GetUnderlyingType(effectiveType) ?? effectiveType;
do
{
object val = deser(reader);
object val = deser(dbReader);
if (val == null || val is T)
{
yield return (T)val;
Expand All @@ -29,7 +31,7 @@ public static IEnumerable<T> Parse<T>(this IDataReader reader)
{
yield return (T)Convert.ChangeType(val, convertToType, System.Globalization.CultureInfo.InvariantCulture);
}
} while (reader.Read());
} while (dbReader.Read());
}
}

Expand All @@ -40,13 +42,14 @@ public static IEnumerable<T> Parse<T>(this IDataReader reader)
/// <param name="type">The type to parse from the <paramref name="reader"/>.</param>
public static IEnumerable<object> Parse(this IDataReader reader, Type type)
{
if (reader.Read())
var dbReader = GetDbDataReader(reader, false);
if (dbReader.Read())
{
var deser = GetDeserializer(type, reader, 0, -1, false);
var deser = GetDeserializer(type, dbReader, 0, -1, false);
do
{
yield return deser(reader);
} while (reader.Read());
yield return deser(dbReader);
} while (dbReader.Read());
}
}

Expand All @@ -56,13 +59,14 @@ public static IEnumerable<object> Parse(this IDataReader reader, Type type)
/// <param name="reader">The data reader to parse results from.</param>
public static IEnumerable<dynamic> Parse(this IDataReader reader)
{
if (reader.Read())
var dbReader = GetDbDataReader(reader, false);
if (dbReader.Read())
{
var deser = GetDapperRowDeserializer(reader, 0, -1, false);
var deser = GetDapperRowDeserializer(dbReader, 0, -1, false);
do
{
yield return deser(reader);
} while (reader.Read());
yield return deser(dbReader);
} while (dbReader.Read());
}
}

Expand All @@ -76,12 +80,47 @@ public static IEnumerable<dynamic> Parse(this IDataReader reader)
/// <param name="length">The length of columns to read (default -1 = all fields following startIndex)</param>
/// <param name="returnNullIfFirstMissing">Return null if we can't find the first column? (default false)</param>
/// <returns>A parser for this specific object from this row.</returns>
#if DEBUG // make sure we're not using this internally
[Obsolete(nameof(DbDataReader) + " API should be preferred")]
#endif
public static Func<IDataReader, object> GetRowParser(this IDataReader reader, Type type,
int startIndex = 0, int length = -1, bool returnNullIfFirstMissing = false)
{
return WrapObjectReader(GetDeserializer(type, GetDbDataReader(reader, false), startIndex, length, returnNullIfFirstMissing));
}

/// <summary>
/// Gets the row parser for a specific row on a data reader. This allows for type switching every row based on, for example, a TypeId column.
/// You could return a collection of the base type but have each more specific.
/// </summary>
/// <param name="reader">The data reader to get the parser for the current row from</param>
/// <param name="type">The type to get the parser for</param>
/// <param name="startIndex">The start column index of the object (default 0)</param>
/// <param name="length">The length of columns to read (default -1 = all fields following startIndex)</param>
/// <param name="returnNullIfFirstMissing">Return null if we can't find the first column? (default false)</param>
/// <returns>A parser for this specific object from this row.</returns>
public static Func<DbDataReader, object> GetRowParser(this DbDataReader reader, Type type,
int startIndex = 0, int length = -1, bool returnNullIfFirstMissing = false)
{
return GetDeserializer(type, reader, startIndex, length, returnNullIfFirstMissing);
}

/// <inheritdoc cref="GetRowParser{T}(DbDataReader, Type, int, int, bool)"/>
#if DEBUG // make sure we're not using this internally
[Obsolete(nameof(DbDataReader) + " API should be preferred")]
#endif
public static Func<IDataReader, T> GetRowParser<T>(this IDataReader reader, Type concreteType = null,
int startIndex = 0, int length = -1, bool returnNullIfFirstMissing = false)
{
concreteType ??= typeof(T);
var func = GetDeserializer(concreteType, GetDbDataReader(reader, false), startIndex, length, returnNullIfFirstMissing);
return Wrap(func);

// this is just to be very clear about what we're capturing
static Func<IDataReader, T> Wrap(Func<DbDataReader, object> func)
=> reader => (T)func(GetDbDataReader(reader, false));
}

/// <summary>
/// Gets the row parser for a specific row on a data reader. This allows for type switching every row based on, for example, a TypeId column.
/// You could return a collection of the base type but have each more specific.
Expand Down Expand Up @@ -135,7 +174,7 @@ public static Func<IDataReader, object> GetRowParser(this IDataReader reader, Ty
/// public override int Type =&gt; 2;
/// }
/// </example>
public static Func<IDataReader, T> GetRowParser<T>(this IDataReader reader, Type concreteType = null,
public static Func<DbDataReader, T> GetRowParser<T>(this DbDataReader reader, Type concreteType = null,
int startIndex = 0, int length = -1, bool returnNullIfFirstMissing = false)
{
concreteType ??= typeof(T);
Expand All @@ -146,7 +185,7 @@ public static Func<IDataReader, T> GetRowParser<T>(this IDataReader reader, Type
}
else
{
return (Func<IDataReader, T>)(Delegate)func;
return (Func<DbDataReader, T>)(Delegate)func;
}
}
}
Expand Down
13 changes: 7 additions & 6 deletions Dapper/SqlMapper.TypeDeserializerCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Collections;
using System.Collections.Generic;
using System.Text;
using System.Data.Common;

namespace Dapper
{
Expand Down Expand Up @@ -33,7 +34,7 @@ internal static void Purge()
}
}

internal static Func<IDataReader, object> GetReader(Type type, IDataReader reader, int startBound, int length, bool returnNullIfFirstMissing)
internal static Func<DbDataReader, object> GetReader(Type type, DbDataReader reader, int startBound, int length, bool returnNullIfFirstMissing)
{
var found = (TypeDeserializerCache)byType[type];
if (found == null)
Expand All @@ -50,18 +51,18 @@ internal static Func<IDataReader, object> GetReader(Type type, IDataReader reade
return found.GetReader(reader, startBound, length, returnNullIfFirstMissing);
}

private readonly Dictionary<DeserializerKey, Func<IDataReader, object>> readers = new Dictionary<DeserializerKey, Func<IDataReader, object>>();
private readonly Dictionary<DeserializerKey, Func<DbDataReader, object>> readers = new Dictionary<DeserializerKey, Func<DbDataReader, object>>();

private struct DeserializerKey : IEquatable<DeserializerKey>
{
private readonly int startBound, length;
private readonly bool returnNullIfFirstMissing;
private readonly IDataReader reader;
private readonly DbDataReader reader;
private readonly string[] names;
private readonly Type[] types;
private readonly int hashCode;

public DeserializerKey(int hashCode, int startBound, int length, bool returnNullIfFirstMissing, IDataReader reader, bool copyDown)
public DeserializerKey(int hashCode, int startBound, int length, bool returnNullIfFirstMissing, DbDataReader reader, bool copyDown)
{
this.hashCode = hashCode;
this.startBound = startBound;
Expand Down Expand Up @@ -136,14 +137,14 @@ public bool Equals(DeserializerKey other)
}
}

private Func<IDataReader, object> GetReader(IDataReader reader, int startBound, int length, bool returnNullIfFirstMissing)
private Func<DbDataReader, object> GetReader(DbDataReader reader, int startBound, int length, bool returnNullIfFirstMissing)
{
if (length < 0) length = reader.FieldCount - startBound;
int hash = GetColumnHash(reader, startBound, length);
if (returnNullIfFirstMissing) hash *= -27;
// get a cheap key first: false means don't copy the values down
var key = new DeserializerKey(hash, startBound, length, returnNullIfFirstMissing, reader, false);
Func<IDataReader, object> deser;
Func<DbDataReader, object> deser;
lock (readers)
{
if (readers.TryGetValue(key, out deser)) return deser;
Expand Down
Loading

0 comments on commit 01f03ef

Please sign in to comment.