Skip to content

Commit

Permalink
Allow control over interfaces of generated DataLoader. (#7621)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib authored Oct 17, 2024
1 parent 720cb95 commit 6ddcc7c
Show file tree
Hide file tree
Showing 30 changed files with 313 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,9 @@ public sealed class DataLoaderDefaultsAttribute : Attribute
/// Specifies if module registration code for DataLoaders shall be generated.
/// </summary>
public bool GenerateRegistrationCode { get; set; } = true;

/// <summary>
/// Specifies if interfaces for DataLoaders shall be generated.
/// </summary>
public bool GenerateInterfaces { get; set; } = true;
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ public void WriteBeginDataLoaderClass(
bool isPublic,
DataLoaderKind kind,
ITypeSymbol key,
ITypeSymbol value)
ITypeSymbol value,
bool withInterface)
{
_writer.WriteIndentedLine(
"{0} sealed class {1}",
"{0} sealed partial class {1}",
isPublic
? "public"
: "internal",
Expand All @@ -99,7 +100,10 @@ kind is DataLoaderKind.Group
: ": global::GreenDonut.DataLoaderBase<{0}, {1}>",
key.ToFullyQualified(),
value.ToFullyQualified());
_writer.WriteIndentedLine(", {0}", interfaceName);
if (withInterface)
{
_writer.WriteIndentedLine(", {0}", interfaceName);
}
_writer.DecreaseIndent();
_writer.WriteIndentedLine("{");
_writer.IncreaseIndent();
Expand Down Expand Up @@ -509,22 +513,34 @@ kind is DataLoaderKind.Cache

public void WriteDataLoaderGroupClass(
string groupClassName,
IReadOnlyList<GroupedDataLoaderInfo> dataLoaders)
IReadOnlyList<GroupedDataLoaderInfo> dataLoaders,
bool withInterface)
{
_writer.WriteIndentedLine("public interface I{0}", groupClassName);
_writer.WriteIndentedLine("{");
_writer.IncreaseIndent();

foreach (var dataLoader in dataLoaders)
if (withInterface)
{
_writer.WriteIndentedLine("{0} {1} {{ get; }}", dataLoader.InterfaceName, dataLoader.Name);
_writer.WriteIndentedLine("public interface I{0}", groupClassName);
_writer.WriteIndentedLine("{");
_writer.IncreaseIndent();

foreach (var dataLoader in dataLoaders)
{
_writer.WriteIndentedLine("{0} {1} {{ get; }}", dataLoader.InterfaceName, dataLoader.Name);
}

_writer.DecreaseIndent();
_writer.WriteIndentedLine("}");
_writer.WriteLine();
}

_writer.DecreaseIndent();
_writer.WriteIndentedLine("}");
_writer.WriteLine();
if (withInterface)
{
_writer.WriteIndentedLine("public sealed partial class {0} : I{0}", groupClassName);
}
else
{
_writer.WriteIndentedLine("public sealed partial class {0}", groupClassName);
}

_writer.WriteIndentedLine("public sealed class {0} : I{0}", groupClassName);
_writer.WriteIndentedLine("{");
_writer.IncreaseIndent();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,26 @@ public void WriteAddDataLoader(string dataLoaderType)
dataLoaderType);
}

public void WriteAddDataLoader(string dataLoaderType, string dataLoaderInterfaceType)
public void WriteAddDataLoader(
string dataLoaderType,
string dataLoaderInterfaceType,
bool withInterface)
{
_writer.WriteIndentedLine(
"global::{0}.AddDataLoader<global::{1}, global::{2}>(services);",
WellKnownTypes.DataLoaderServiceCollectionExtension,
dataLoaderInterfaceType,
dataLoaderType);
if (withInterface)
{
_writer.WriteIndentedLine(
"global::{0}.AddDataLoader<global::{1}, global::{2}>(services);",
WellKnownTypes.DataLoaderServiceCollectionExtension,
dataLoaderInterfaceType,
dataLoaderType);
}
else
{
_writer.WriteIndentedLine(
"global::{0}.AddDataLoader<global::{1}>(services);",
WellKnownTypes.DataLoaderServiceCollectionExtension,
dataLoaderType);
}
}

public void WriteAddDataLoaderGroup(string groupType, string groupInterfaceType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ public void WriteRegisterObjectTypeExtensionHelpers()
using (_writer.IncreaseIndent())
{
_writer.WriteIndentedLine(
"var hooks = (global::System.Collections.Generic.List<" +
"Action<IObjectTypeDescriptor<T>>>)" +
"descriptor.Extend().Context.ContextData[hooksKey]!;");
"var hooks = (global::System.Collections.Generic.List<"
+ "Action<IObjectTypeDescriptor<T>>>)"
+ "descriptor.Extend().Context.ContextData[hooksKey]!;");
_writer.WriteIndentedLine("foreach (var configure in hooks)");
_writer.WriteIndentedLine("{");

Expand Down Expand Up @@ -169,8 +169,8 @@ public void WriteRegisterObjectTypeExtensionHelpers()
_writer.WriteIndentedLine("}");
_writer.WriteLine();
_writer.WriteIndentedLine(
"((System.Collections.Generic.List<Action<IObjectTypeDescriptor<T>>>)value!)" +
".Add(initialize);");
"((System.Collections.Generic.List<Action<IObjectTypeDescriptor<T>>>)value!)"
+ ".Add(initialize);");
}

_writer.WriteIndentedLine("});");
Expand Down Expand Up @@ -217,9 +217,9 @@ public void WriteRegisterInterfaceTypeExtensionHelpers()
using (_writer.IncreaseIndent())
{
_writer.WriteIndentedLine(
"var hooks = (global::System.Collections.Generic.List<" +
"Action<IInterfaceTypeDescriptor<T>>>)" +
"descriptor.Extend().Context.ContextData[hooksKey]!;");
"var hooks = (global::System.Collections.Generic.List<"
+ "Action<IInterfaceTypeDescriptor<T>>>)"
+ "descriptor.Extend().Context.ContextData[hooksKey]!;");
_writer.WriteIndentedLine("foreach (var configure in hooks)");
_writer.WriteIndentedLine("{");

Expand Down Expand Up @@ -253,8 +253,8 @@ public void WriteRegisterInterfaceTypeExtensionHelpers()
_writer.WriteIndentedLine("}");
_writer.WriteLine();
_writer.WriteIndentedLine(
"((System.Collections.Generic.List<Action<IInterfaceTypeDescriptor<T>>>)value!)" +
".Add(initialize);");
"((System.Collections.Generic.List<Action<IInterfaceTypeDescriptor<T>>>)value!)"
+ ".Add(initialize);");
}

_writer.WriteIndentedLine("});");
Expand All @@ -266,8 +266,25 @@ public void WriteRegisterInterfaceTypeExtensionHelpers()
public void WriteRegisterDataLoader(string typeName)
=> _writer.WriteIndentedLine("builder.AddDataLoader<global::{0}>();", typeName);

public void WriteRegisterDataLoader(string typeName, string interfaceTypeName)
=> _writer.WriteIndentedLine("builder.AddDataLoader<global::{0}, global::{1}>();", interfaceTypeName, typeName);
public void WriteRegisterDataLoader(
string typeName,
string interfaceTypeName,
bool withInterface)
{
if (withInterface)
{
_writer.WriteIndentedLine(
"builder.AddDataLoader<global::{0}, global::{1}>();",
interfaceTypeName,
typeName);
}
else
{
_writer.WriteIndentedLine(
"builder.AddDataLoader<global::{0}>();",
typeName);
}
}

public void WriteRegisterDataLoaderGroup(string typeName, string interfaceTypeName)
=> _writer.WriteIndentedLine(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ private static void WriteDataLoader(
buffer ??= new();
buffer.Clear();
buffer.AddRange(dataLoaderGroups);
generator.WriteDataLoaderGroupClass(dataLoaderGroup.Key, buffer);
generator.WriteDataLoaderGroupClass(dataLoaderGroup.Key, buffer, defaults.GenerateInterfaces);
}

generator.WriteEndNamespace();
Expand All @@ -133,15 +133,19 @@ private static void GenerateDataLoader(
var isPublic = dataLoader.IsPublic ?? defaults.IsPublic ?? true;
var isInterfacePublic = dataLoader.IsInterfacePublic ?? defaults.IsInterfacePublic ?? true;

generator.WriteDataLoaderInterface(dataLoader.InterfaceName, isInterfacePublic, kind, keyType, valueType);
if (defaults.GenerateInterfaces)
{
generator.WriteDataLoaderInterface(dataLoader.InterfaceName, isInterfacePublic, kind, keyType, valueType);
}

generator.WriteBeginDataLoaderClass(
dataLoader.Name,
dataLoader.InterfaceName,
isPublic,
kind,
keyType,
valueType);
valueType,
defaults.GenerateInterfaces);
generator.WriteDataLoaderConstructor(
dataLoader.Name,
kind,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Collections.Immutable;
using HotChocolate.Types.Analyzers.FileBuilders;
using HotChocolate.Types.Analyzers.Helpers;
using HotChocolate.Types.Analyzers.Models;
using Microsoft.CodeAnalysis;

Expand All @@ -13,6 +14,7 @@ public void Generate(
ImmutableArray<SyntaxInfo> syntaxInfos)
{
var module = GetDataLoaderModuleInfo(syntaxInfos);
var dataLoaderDefaults = syntaxInfos.GetDataLoaderDefaults();

if (module is null || !syntaxInfos.Any(t => t is DataLoaderInfo or RegisterDataLoaderInfo))
{
Expand Down Expand Up @@ -43,7 +45,7 @@ public void Generate(
case DataLoaderInfo dataLoader:
var typeName = $"{dataLoader.Namespace}.{dataLoader.Name}";
var interfaceTypeName = $"{dataLoader.Namespace}.{dataLoader.InterfaceName}";
generator.WriteAddDataLoader(typeName, interfaceTypeName);
generator.WriteAddDataLoader(typeName, interfaceTypeName, dataLoaderDefaults.GenerateInterfaces);

if(dataLoader.Groups.Count > 0)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ private static void WriteConfiguration(
List<SyntaxInfo> syntaxInfos,
ModuleInfo module)
{
var dataLoaderDefaults = syntaxInfos.GetDataLoaderDefaults();
HashSet<(string InterfaceName, string ClassName)>? groups = null;
using var generator = new ModuleFileBuilder(module.ModuleName, "Microsoft.Extensions.DependencyInjection");

Expand Down Expand Up @@ -109,15 +110,20 @@ private static void WriteConfiguration(
var typeName = $"{dataLoader.Namespace}.{dataLoader.Name}";
var interfaceTypeName = $"{dataLoader.Namespace}.{dataLoader.InterfaceName}";

generator.WriteRegisterDataLoader(typeName, interfaceTypeName);
generator.WriteRegisterDataLoader(
typeName,
interfaceTypeName,
dataLoaderDefaults.GenerateInterfaces);
hasConfigurations = true;

if(dataLoader.Groups.Count > 0)
{
groups ??= [];
foreach (var groupName in dataLoader.Groups)
{
groups.Add(($"{dataLoader.Namespace}.I{groupName}", $"{dataLoader.Namespace}.{groupName}"));
groups.Add((
$"{dataLoader.Namespace}.I{groupName}",
$"{dataLoader.Namespace}.{groupName}"));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,27 @@ public static bool RegisterService(
return true;
}

public static bool GenerateInterfaces(
this SeparatedSyntaxList<AttributeArgumentSyntax> arguments,
GeneratorSyntaxContext context)
{
var argumentSyntax = arguments.FirstOrDefault(
t => t.NameEquals?.Name.ToFullString().Trim() == "GenerateInterfaces");

if (argumentSyntax is not null)
{
var valueExpression = argumentSyntax.Expression;
var value = context.SemanticModel.GetConstantValue(valueExpression).Value;

if (value is not null)
{
return (bool)value;
}
}

return true;
}

public static bool TryGetName(
this AttributeData attribute,
[NotNullWhen(true)] out string? name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,21 @@ public static DataLoaderDefaultsInfo GetDataLoaderDefaults(
}
}

return new DataLoaderDefaultsInfo(null, null, true, true);
return new DataLoaderDefaultsInfo(null, null, true, true, true);
}

public static DataLoaderDefaultsInfo GetDataLoaderDefaults(
this List<SyntaxInfo> syntaxInfos)
{
foreach (var syntaxInfo in syntaxInfos)
{
if (syntaxInfo is DataLoaderDefaultsInfo defaults)
{
return defaults;
}
}

return new DataLoaderDefaultsInfo(null, null, true, true, true);
}

public static string CreateModuleName(string? assemblyName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ public bool TryHandle(
attribList.Arguments.IsScoped(context),
attribList.Arguments.IsPublic(context),
attribList.Arguments.IsInterfacePublic(context),
attribList.Arguments.RegisterService(context));
attribList.Arguments.RegisterService(context),
attribList.Arguments.GenerateInterfaces(context));
return true;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ public sealed class DataLoaderDefaultsInfo(
bool? scoped,
bool? isPublic,
bool? isInterfacePublic,
bool registerServices)
bool registerServices,
bool generateInterfaces)
: SyntaxInfo
{
public bool? Scoped { get; } = scoped;
Expand All @@ -15,6 +16,8 @@ public sealed class DataLoaderDefaultsInfo(

public bool RegisterServices { get; } = registerServices;

public bool GenerateInterfaces { get; } = generateInterfaces;

public override bool Equals(object? obj)
=> obj is DataLoaderDefaultsInfo other && Equals(other);

Expand All @@ -25,8 +28,9 @@ private bool Equals(DataLoaderDefaultsInfo other)
=> Scoped.Equals(other.Scoped)
&& IsPublic.Equals(other.IsPublic)
&& IsInterfacePublic.Equals(other.IsInterfacePublic)
&& RegisterServices.Equals(other.RegisterServices);
&& RegisterServices.Equals(other.RegisterServices)
&& GenerateInterfaces.Equals(other.GenerateInterfaces);

public override int GetHashCode()
=> HashCode.Combine(Scoped, IsPublic, IsInterfacePublic, RegisterServices);
=> HashCode.Combine(Scoped, IsPublic, IsInterfacePublic, RegisterServices, GenerateInterfaces);
}
Loading

0 comments on commit 6ddcc7c

Please sign in to comment.