Skip to content

Commit

Permalink
fix up some async paths
Browse files Browse the repository at this point in the history
  • Loading branch information
quinchs committed Jan 9, 2024
1 parent f3a091c commit 1bdbc4f
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 58 deletions.
10 changes: 6 additions & 4 deletions src/EdgeDB.Net.Driver/Binary/Builders/ObjectBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ internal sealed class ObjectBuilder
private static readonly ConcurrentDictionary<Type, (int Version, ICodec Codec)> _codecVisitorStateTable = new();
private static readonly object _visitorLock = new();

public static async Task<PreheatedCodec> PreheatCodec<T>(EdgeDBBinaryClient client, ICodec codec)
public static async Task<PreheatedCodec> PreheatCodecAsync<T>(EdgeDBBinaryClient client, ICodec codec,
CancellationToken token)
{
// if the codec has been visited before and we have the most up-to-date version, return it.
if (
Expand All @@ -30,7 +31,7 @@ public static async Task<PreheatedCodec> PreheatCodec<T>(EdgeDBBinaryClient clie
var visitor = new TypeVisitor(client);
visitor.SetTargetType(typeof(T));
var reference = new Ref<ICodec>(codec);
await visitor.VisitAsync(reference);
await visitor.VisitAsync(reference, token);

if (typeof(T) != typeof(object))
_codecVisitorStateTable[typeof(T)] = (version, reference.Value);
Expand All @@ -56,8 +57,9 @@ public static async Task<PreheatedCodec> PreheatCodec<T>(EdgeDBBinaryClient clie
return (T?)ConvertTo(typeof(T), value);
}

public static async Task<T?> BuildResultAsync<T>(EdgeDBBinaryClient client, ICodec codec, ReadOnlyMemory<byte> data)
=> BuildResult<T>(client, await PreheatCodec<T>(client, codec), data);
public static async Task<T?> BuildResultAsync<T>(
EdgeDBBinaryClient client, ICodec codec, ReadOnlyMemory<byte> data, CancellationToken token)
=> BuildResult<T>(client, await PreheatCodecAsync<T>(client, codec, token), data);

public static object? ConvertTo(Type type, object? value)
{
Expand Down
2 changes: 1 addition & 1 deletion src/EdgeDB.Net.Driver/Binary/Codecs/CompilableCodec.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,5 @@ public Type GetInnerType()
: InnerCodec.ConverterType;

public override string ToString()
=> $"compilable({_rootCodecType.Name})";
=> $"compilable({InnerCodec})";
}
10 changes: 6 additions & 4 deletions src/EdgeDB.Net.Driver/Binary/Codecs/Visitors/ArgumentVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ public ArgumentVisitor(EdgeDBBinaryClient client, IDictionary<string, object?> a
_arguments = args;
}

protected override async Task VisitCodecAsync(Ref<ICodec> codec)
protected override async Task VisitCodecAsync(Ref<ICodec> codec, CancellationToken token)
{
token.ThrowIfCancellationRequested();

switch (codec.Value)
{
case ObjectCodec objectCodec:
Expand All @@ -28,7 +30,7 @@ protected override async Task VisitCodecAsync(Ref<ICodec> codec)
type = v?.GetType() ?? typeof(object);

var reference = new Ref<ICodec>(x);
await VisitAsync(reference, type);
await VisitAsync(reference, type, token);

return reference.Value;
})
Expand All @@ -41,10 +43,10 @@ protected override async Task VisitCodecAsync(Ref<ICodec> codec)
}
}

private async ValueTask VisitAsync(Ref<ICodec> codec, Type type)
private async ValueTask VisitAsync(Ref<ICodec> codec, Type type, CancellationToken token)
{
var typeVisitor = new TypeVisitor(_client);
typeVisitor.SetTargetType(type);
await typeVisitor.VisitAsync(codec);
await typeVisitor.VisitAsync(codec, token);
}
}
4 changes: 2 additions & 2 deletions src/EdgeDB.Net.Driver/Binary/Codecs/Visitors/CodecVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ namespace EdgeDB.Binary.Codecs;

internal abstract class CodecVisitor
{
public Task VisitAsync(Ref<ICodec> codec) => VisitCodecAsync(codec);
public Task VisitAsync(Ref<ICodec> codec, CancellationToken token) => VisitCodecAsync(codec, token);

protected abstract Task VisitCodecAsync(Ref<ICodec> codec);
protected abstract Task VisitCodecAsync(Ref<ICodec> codec, CancellationToken token);
}
19 changes: 9 additions & 10 deletions src/EdgeDB.Net.Driver/Binary/Codecs/Visitors/TypeVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ protected override
#if DEBUG
async
#endif
Task VisitCodecAsync(Ref<ICodec> codec)
Task VisitCodecAsync(Ref<ICodec> codec, CancellationToken token)
{
if (_context is null)
throw new EdgeDBException("Context was not initialized for type walking");

#if DEBUG
var sw = Stopwatch.StartNew();
await VisitCodecAsync(codec, _context);
await VisitCodecAsync(codec, _context, token);
sw.Stop();
_logger.CodecVisitorTimingTrace(
this,
Expand All @@ -50,14 +50,13 @@ Task VisitCodecAsync(Ref<ICodec> codec)
#else
if(_logger.IsEnabled(LogLevel.Trace))
_logger.CodecTree(CodecFormatter.FormatCodecAsTree(codec.Value).ToString());
return VisitCodecAsync(codec, _context);
return VisitCodecAsync(codec, _context, token);
#endif
}

private async Task VisitCodecAsync(Ref<ICodec> codec, TypeVisitorContext context)
private async Task VisitCodecAsync(Ref<ICodec> codec, TypeVisitorContext context, CancellationToken token)
{
// TODO: if dynamic or object was passed in, return the default type
// from the complex OR based on config.
token.ThrowIfCancellationRequested();

if (context.Type == typeof(void))
return;
Expand All @@ -74,7 +73,7 @@ private async Task VisitCodecAsync(Ref<ICodec> codec, TypeVisitorContext context
? context.Type
: context.Type.GetWrappingType();

await VisitCodecAsync(reference, context with { Type = type });
await VisitCodecAsync(reference, context with { Type = type }, token);
compiled.InnerCodec = reference.Value;
}
break;
Expand Down Expand Up @@ -121,7 +120,7 @@ await VisitCodecAsync(reference, subContext with
: x.ConverterType,
InnerRealType = !hasPropInfo && x is CompilableWrappingCodec,
Name = obj.PropertyNames[i]
});
}, token);

obj.InnerCodecs[i] = reference.Value;
}));
Expand All @@ -147,7 +146,7 @@ await VisitCodecAsync(reference, context with
Type = tupleTypes is not null
? tupleTypes[i]
: typeof(object)
});
}, token);
tuple.InnerCodecs[i] = reference.Value;
}));

Expand All @@ -164,7 +163,7 @@ await VisitCodecAsync(reference, context with
: context.Type.GetWrappingType();

var reference = new Ref<ICodec>(compilable.InnerCodec);
await VisitCodecAsync(reference, context with { Type = innerType });
await VisitCodecAsync(reference, context with { Type = innerType }, token);
codec.Value = compilable.Compile(_client.ProtocolProvider, innerType, reference.Value);
}
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ public virtual async Task<ExecuteResult> ExecuteQueryAsync(QueryParameters query
var inCodec = new Ref<ICodec>(parseResult.InCodecInfo.Codec);

var argumentVisitor = new ArgumentVisitor(_client, queryParameters.Arguments);
await argumentVisitor.VisitAsync(inCodec);
await argumentVisitor.VisitAsync(inCodec, token);

if (inCodec.Value is not IArgumentCodec argumentCodec)
throw new MissingCodecException(
Expand Down
83 changes: 47 additions & 36 deletions src/EdgeDB.Net.Driver/Clients/EdgeDBBinaryClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,20 @@ public async Task SyncAsync(CancellationToken token = default)
}
}

internal readonly record struct ExecuteResult(
ProtocolExecuteResult ProtocolResult,
ObjectBuilder.PreheatedCodec? PreheatedCodec
);

/// <exception cref="EdgeDBException">A general error occored.</exception>
/// <exception cref="EdgeDBErrorException">The client received an <see cref="IProtocolError" />.</exception>
/// <exception cref="UnexpectedMessageException">The client received an unexpected message.</exception>
/// <exception cref="MissingCodecException">A codec could not be found for the given input arguments or the result.</exception>
internal virtual async Task<ProtocolExecuteResult> ExecuteInternalAsync(string query,
internal virtual async Task<ExecuteResult> ExecuteInternalAsync(string query,
IDictionary<string, object?>? args = null, Cardinality? cardinality = null,
Capabilities? capabilities = Capabilities.Modifications, IOFormat format = IOFormat.Binary,
bool isRetry = false, bool implicitTypeName = false,
Func<ParseResult, Task>? preheat = null,
Func<ParseResult, Task<ObjectBuilder.PreheatedCodec>>? preheat = null,
CancellationToken token = default)
{
// if the current client is not connected, reconnect it
Expand Down Expand Up @@ -176,12 +181,13 @@ internal virtual async Task<ProtocolExecuteResult> ExecuteInternalAsync(string q
_protocolProvider.ExecuteQueryAsync(arguments, parseResult, linkedTokenSource.Token);

#if DEBUG
async Task PreheatWithTrace(ParseResult p)
async Task<ObjectBuilder.PreheatedCodec> PreheatWithTrace(ParseResult p)
{
var stopwatch = Stopwatch.StartNew();
await preheat(p);
var preheated = await preheat(p);
stopwatch.Stop();
Logger.LogDebug("Preheating of codecs took {@PreheatTime}ms", Math.Round(stopwatch.Elapsed.TotalMilliseconds, 4));
return preheated;
}

var preheatTask = PreheatWithTrace(parseResult);
Expand All @@ -194,10 +200,10 @@ await Task.WhenAll(
executeTask
);

return executeTask.Result;
return new(executeTask.Result, preheatTask.Result);
}

return await _protocolProvider.ExecuteQueryAsync(arguments, parseResult, linkedTokenSource.Token);
return new(await _protocolProvider.ExecuteQueryAsync(arguments, parseResult, linkedTokenSource.Token), null);
}
catch (OperationCanceledException ce)
{
Expand Down Expand Up @@ -244,6 +250,15 @@ await Task.WhenAll(
}
}

private ValueTask<ObjectBuilder.PreheatedCodec> GetPreheatedCodecAsync<T>(ExecuteResult result, CancellationToken token)
{
if (result.PreheatedCodec.HasValue)
return new ValueTask<ObjectBuilder.PreheatedCodec>(result.PreheatedCodec.Value);

return new ValueTask<ObjectBuilder.PreheatedCodec>(
ObjectBuilder.PreheatCodecAsync<T>(this, result.ProtocolResult.OutCodecInfo.Codec, token));
}

/// <inheritdoc />
/// <exception cref="EdgeDBException">A general error occored.</exception>
/// <exception cref="EdgeDBErrorException">The client received an <see cref="IProtocolError" />.</exception>
Expand Down Expand Up @@ -275,30 +290,26 @@ public override async Task ExecuteAsync(string query, IDictionary<string, object
Cardinality.Many,
capabilities,
implicitTypeName: implicitTypeName,
preheat: parseResult =>
Task.Run(() => ObjectBuilder.PreheatCodec<TResult>(this, parseResult.OutCodecInfo.Codec), token),
preheat: parseResult => ObjectBuilder.PreheatCodecAsync<TResult>(this, parseResult.OutCodecInfo.Codec, token),
token: token);

var array = new TResult?[result.Data.Length];
var array = new TResult?[result.ProtocolResult.Data.Length];

var codec = result.OutCodecInfo.Codec;
var codec = await GetPreheatedCodecAsync<TResult>(result, token);

if (result.Data.Length <= 7)
if (result.ProtocolResult.Data.Length <= 7)
{
for (var i = 0; i != result.Data.Length; i++)
for (var i = 0; i != result.ProtocolResult.Data.Length; i++)
{
var obj = await ObjectBuilder.BuildResultAsync<TResult>(this, codec, result.Data[i]);
array[i] = obj;
array[i] = ObjectBuilder.BuildResult<TResult>(this, codec, result.ProtocolResult.Data[i]);
}
}
else
{
await Task.WhenAll(
result.Data.Select(async (x, i) =>
{
array[i] = await ObjectBuilder.BuildResultAsync<TResult>(this, codec, x);
})
);
Parallel.ForEach(result.ProtocolResult.Data, (x, state, i) =>
{
array[i] = ObjectBuilder.BuildResult<TResult>(this, codec, x);
});
}

return array.ToImmutableArray();
Expand Down Expand Up @@ -326,16 +337,17 @@ await Task.WhenAll(
Cardinality.AtMostOne,
capabilities,
implicitTypeName: implicitTypeName,
preheat: parseResult =>
Task.Run(() => ObjectBuilder.PreheatCodec<TResult>(this, parseResult.OutCodecInfo.Codec), token),
preheat: parseResult => ObjectBuilder.PreheatCodecAsync<TResult>(this, parseResult.OutCodecInfo.Codec, token),
token: token);

if (result.Data.Length > 1)
if (result.ProtocolResult.Data.Length > 1)
throw new ResultCardinalityMismatchException(Cardinality.AtMostOne, Cardinality.Many);

return result.Data.Length == 0
var codec = await GetPreheatedCodecAsync<TResult>(result, token);

return result.ProtocolResult.Data.Length == 0
? default
: await ObjectBuilder.BuildResultAsync<TResult>(this, result.OutCodecInfo.Codec, result.Data[0]);
: ObjectBuilder.BuildResult<TResult>(this, codec, result.ProtocolResult.Data[0]);
}

/// <inheritdoc />
Expand All @@ -360,19 +372,18 @@ public override async Task<TResult> QueryRequiredSingleAsync<TResult>(string que
Cardinality.AtMostOne,
capabilities,
implicitTypeName: implicitTypeName,
preheat: parseResult =>
Task.Run(() => ObjectBuilder.PreheatCodec<TResult>(this, parseResult.OutCodecInfo.Codec), token),
preheat: parseResult => ObjectBuilder.PreheatCodecAsync<TResult>(this, parseResult.OutCodecInfo.Codec, token),
token: token);

if (result.Data.Length is > 1 or 0)
if (result.ProtocolResult.Data.Length is > 1 or 0)
throw new ResultCardinalityMismatchException(Cardinality.One,
result.Data.Length > 1 ? Cardinality.Many : Cardinality.AtMostOne);
result.ProtocolResult.Data.Length > 1 ? Cardinality.Many : Cardinality.AtMostOne);


return result.Data.Length != 1
return result.ProtocolResult.Data.Length != 1
? throw new MissingRequiredException()
: await ObjectBuilder.BuildResultAsync<TResult>(this, result.OutCodecInfo.Codec, result.Data[0])
?? throw new ResultCardinalityMismatchException(Cardinality.One, Cardinality.NoResult);
: ObjectBuilder.BuildResult<TResult>(this, await GetPreheatedCodecAsync<TResult>(result, token),
result.ProtocolResult.Data[0])!;
}

/// <inheritdoc />
Expand All @@ -387,8 +398,8 @@ public override async Task<Json> QueryJsonAsync(string query, IDictionary<string
var result =
await ExecuteInternalAsync(query, args, Cardinality.Many, capabilities, IOFormat.Json, token: token);

return result.Data.Length == 1
? (string)result.OutCodecInfo.Codec.Deserialize(this, in result.Data[0])!
return result.ProtocolResult.Data.Length == 1
? (string)result.ProtocolResult.OutCodecInfo.Codec.Deserialize(this, in result.ProtocolResult.Data[0])!
: "[]";
}

Expand All @@ -404,8 +415,8 @@ public override async Task<IReadOnlyCollection<Json>> QueryJsonElementsAsync(str
var result = await ExecuteInternalAsync(query, args, Cardinality.Many, capabilities, IOFormat.JsonElements,
token: token);

return result.Data.Any()
? result.Data.Select(x => new Json((string?)result.OutCodecInfo.Codec.Deserialize(this, in x)))
return result.ProtocolResult.Data.Any()
? result.ProtocolResult.Data.Select(x => new Json((string?)result.ProtocolResult.OutCodecInfo.Codec.Deserialize(this, in x)))
.ToImmutableArray()
: ImmutableArray<Json>.Empty;
}
Expand Down

0 comments on commit 1bdbc4f

Please sign in to comment.