Skip to content

Commit

Permalink
Support raw SQL projections in Select and OrderBy statements.
Browse files Browse the repository at this point in the history
  • Loading branch information
rwasef1830 committed Oct 13, 2021
1 parent e03ecc8 commit 71ecf9d
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 14 deletions.
45 changes: 45 additions & 0 deletions src/Marten.Testing/CoreFunctionality/query_by_sql.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Linq;
using System.Threading.Tasks;
using Marten.Linq.MatchesSql;
using Marten.Linq.SqlProjection;
using Marten.Testing.Documents;
using Marten.Testing.Harness;
using Shouldly;
Expand Down Expand Up @@ -324,6 +325,50 @@ public async Task query_with_select_in_query_async()
}
}

[Fact]
public async Task query_with_select_sql_projection_async()
{
using (var session = theStore.OpenSession())
{
var u = new User {FirstName = "Jeremy", LastName = "Miller", Age = 1337};
session.Store(u);
session.SaveChanges();

#region sample_using-sql-projection-queryasync

var users = await session.Query<User>()
.Select(x => new { Age = x.SqlProjection<int>("data->>'Age'::integer") })
.ToListAsync();
var user = users.Single();

#endregion

user.Age.ShouldBe(1337);
}
}

[Fact]
public async Task query_with_order_by_sql_projection_async()
{
using (var session = theStore.OpenSession())
{
var u = new User {FirstName = "Jeremy", LastName = "Miller"};
session.Store(u);
session.SaveChanges();

#region sample_using-sql-projection-queryasync

var users = await session.Query<User>()
.OrderBy(x => x.SqlProjection<string>("data->>'FirstName'"))
.ToListAsync();
var user = users.Single();

#endregion

user.FirstName.ShouldBe("Jeremy");
}
}

[Fact]
public async Task get_sum_of_integers_asynchronously()
{
Expand Down
17 changes: 17 additions & 0 deletions src/Marten.Testing/Linq/SqlProjection/SqlProjectionTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using System;
using Marten.Linq.SqlProjection;
using Shouldly;
using Xunit;

namespace Marten.Testing.Linq.SqlProjection
{
public class SqlProjectionTests
{
[Fact]
public void Throws_NotSupportedException_when_called_directly()
{
Should.Throw<NotSupportedException>(
() => new object().SqlProjection<string>("COALESCE(d.data ->> 'UserName', ?)", "baz"));
}
}
}
100 changes: 89 additions & 11 deletions src/Marten/Linq/Parsing/SelectTransformBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
using System.Reflection;
using Baseline;
using Marten.Linq.Fields;
using Marten.Linq.SqlProjection;
using Remotion.Linq.Parsing;
using Weasel.Postgresql.SqlGeneration;

namespace Marten.Linq.Parsing
{
internal class SelectTransformBuilder : RelinqExpressionVisitor
{
private TargetObject _target;
private SelectedField _currentField;
private BindingTarget _currentTarget;

public SelectTransformBuilder(Expression clause, IFieldMapping fields, ISerializer serializer)
{
Expand All @@ -35,7 +37,7 @@ protected override Expression VisitNew(NewExpression expression)

for (var i = 0; i < parameters.Length; i++)
{
_currentField = _target.StartBinding(parameters[i].Name);
_currentTarget = _target.StartBinding(parameters[i].Name);
Visit(expression.Arguments[i]);
}

Expand All @@ -44,21 +46,76 @@ protected override Expression VisitNew(NewExpression expression)

protected override Expression VisitMember(MemberExpression node)
{
_currentField.Add(node.Member);
_currentTarget.AddMember(node.Member);
return base.VisitMember(node);
}

protected override MemberBinding VisitMemberBinding(MemberBinding node)
{
_currentField = _target.StartBinding(node.Member.Name);
_currentTarget = _target.StartBinding(node.Member.Name);

return base.VisitMemberBinding(node);
}

protected override Expression VisitMethodCall(MethodCallExpression node)
{
var fragment = SqlProjectionSqlFragment.TryParse(node);
if (fragment == null)
{
throw new NotSupportedException(
$"Method {node.Method.DeclaringType?.FullName}.{node.Method.Name} is not supported.");
}

_currentTarget.AddSqlProjection(fragment);

return base.VisitMethodCall(node);
}

public class BindingTarget : TargetObject.ISetterBinding
{
private readonly string _name;
private TargetObject.SetterBinding _field;
private TargetObject.SqlProjectionBinding _sqlProjection;

public BindingTarget(string name)
{
_name = name;
}

public void AddMember(MemberInfo memberInfo)
{
if (_sqlProjection != null)
{
throw new InvalidOperationException(
"Cannot bind to a member after having bound to a sql projection");
}

_field ??= new TargetObject.SetterBinding(_name);
_field.Field.Add(memberInfo);
}

public void AddSqlProjection(ISqlFragment sqlProjectionClause)
{
if (_field != null)
{
throw new InvalidOperationException(
"Cannot bind to a sql projection after having bound to a member.");
}

_sqlProjection = new TargetObject.SqlProjectionBinding(_name, sqlProjectionClause);
}

public string ToJsonBuildObjectPair(IFieldMapping mapping, ISerializer serializer)
{
return _field?.ToJsonBuildObjectPair(mapping, serializer)
?? _sqlProjection?.ToJsonBuildObjectPair(mapping, serializer)
?? string.Empty;
}
}

public class TargetObject
{
private readonly IList<SetterBinding> _setters = new List<SetterBinding>();
private readonly IList<ISetterBinding> _setters = new List<ISetterBinding>();

public TargetObject(Type type)
{
Expand All @@ -67,12 +124,11 @@ public TargetObject(Type type)

public Type Type { get; }

public SelectedField StartBinding(string bindingName)
public BindingTarget StartBinding(string bindingName)
{
var setter = new SetterBinding(bindingName);
_setters.Add(setter);

return setter.Field;
var bindingTarget = new BindingTarget(bindingName);
_setters.Add(bindingTarget);
return bindingTarget;
}

public string ToSelectField(IFieldMapping fields, ISerializer serializer)
Expand All @@ -81,7 +137,12 @@ public string ToSelectField(IFieldMapping fields, ISerializer serializer)
return $"jsonb_build_object({jsonBuildObjectArgs})";
}

private class SetterBinding
public interface ISetterBinding
{
string ToJsonBuildObjectPair(IFieldMapping mapping, ISerializer serializer);
}

public class SetterBinding: ISetterBinding
{
public SetterBinding(string name)
{
Expand All @@ -101,6 +162,23 @@ public string ToJsonBuildObjectPair(IFieldMapping mapping, ISerializer serialize
return $"'{Name}', {locator}";
}
}

public class SqlProjectionBinding: ISetterBinding
{
public SqlProjectionBinding(string name, ISqlFragment projectionFragment)
{
Name = name;
ProjectionFragment = projectionFragment;
}

private string Name { get; }
private ISqlFragment ProjectionFragment { get; }

public string ToJsonBuildObjectPair(IFieldMapping mapping, ISerializer serializer)
{
return $"'{Name}', ({ProjectionFragment.ToSql()})";
}
}
}

public class SelectedField: IEnumerable<MemberInfo>
Expand Down
24 changes: 21 additions & 3 deletions src/Marten/Linq/SqlGeneration/Statement.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using Baseline;
using Marten.Internal;
using Marten.Linq.Fields;
using Marten.Linq.Parsing;
using Marten.Linq.SqlProjection;
using Weasel.Postgresql;
using Npgsql;
using Remotion.Linq.Clauses;
Expand Down Expand Up @@ -95,9 +98,24 @@ protected void writeWhereClause(CommandBuilder sql)

protected void writeOrderByFragment(CommandBuilder sql, Ordering clause)
{
var field = Fields.FieldFor(clause.Expression);
var locator = field.ToOrderExpression(clause.Expression);
sql.Append(locator);
var handled = false;

if (clause.Expression is MethodCallExpression methodCallExpression)
{
var sqlProjectionFragment = SqlProjectionSqlFragment.TryParse(methodCallExpression);
if (sqlProjectionFragment != null)
{
sqlProjectionFragment.Apply(sql);
handled = true;
}
}

if (!handled)
{
var field = Fields.FieldFor(clause.Expression);
var locator = field.ToOrderExpression(clause.Expression);
sql.Append(locator);
}

if (clause.OrderingDirection == OrderingDirection.Desc) sql.Append(" desc");
}
Expand Down
18 changes: 18 additions & 0 deletions src/Marten/Linq/SqlProjection/SqlProjectionExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
using System;
using System.Reflection;

namespace Marten.Linq.SqlProjection
{
public static class SqlProjectionExtensions
{
public static readonly MethodInfo MethodInfo = typeof(SqlProjectionExtensions)
.GetMethod(nameof(SqlProjection),
BindingFlags.Public | BindingFlags.Static);

public static T SqlProjection<T>(this object doc, string sql, params object[] parameters)
{
throw new NotSupportedException(
$"{nameof(SqlProjection)} extension method can only be used in Marten Linq queries.");
}
}
}
38 changes: 38 additions & 0 deletions src/Marten/Linq/SqlProjection/SqlProjectionSqlFragment.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using System;
using System.Linq.Expressions;
using Weasel.Postgresql.SqlGeneration;

namespace Marten.Linq.SqlProjection
{
public static class SqlProjectionSqlFragment
{
public static ISqlFragment TryParse(MethodCallExpression node, Func<Expression, Expression> visit = null)
{
if (node == null)
{
return null;
}

visit ??= x => x;

if (!node.Method.IsGenericMethod ||
node.Method.GetGenericMethodDefinition() != SqlProjectionExtensions.MethodInfo)
{
return null;
}

if (visit(node.Arguments[1]) is not ConstantExpression { Value: string sql })
{
throw new NotSupportedException("SqlProjection first parameter needs to resolve to a string");
}

if (visit(node.Arguments[2]) is not ConstantExpression { Value: object[] sqlArguments })
{
throw new NotSupportedException("SqlProjection second parameter needs to resolve to an object[]");
}

var whereFragment = new WhereFragment(sql, sqlArguments);
return whereFragment;
}
}
}

0 comments on commit 71ecf9d

Please sign in to comment.