Fix expression visitor for WHERE IN

This commit is contained in:
Stephan
2018-11-06 15:25:36 +01:00
parent cfa3f58c33
commit f3e9e282fa
2 changed files with 141 additions and 268 deletions

View File

@@ -32,7 +32,7 @@ namespace Umbraco.Core.Persistence.Querying
/// <summary>
/// Gets or sets the SQL syntax provider for the current database.
/// </summary>
protected ISqlSyntaxProvider SqlSyntax { get; private set; }
protected ISqlSyntaxProvider SqlSyntax { get; }
/// <summary>
/// Gets the list of SQL parameters.
@@ -56,6 +56,8 @@ namespace Umbraco.Core.Persistence.Querying
/// <remarks>Also populates the SQL parameters.</remarks>
public virtual string Visit(Expression expression)
{
if (expression == null) return string.Empty;
// if the expression is a CachedExpression,
// visit the inner expression if not already visited
var cachedExpression = expression as CachedExpression;
@@ -65,8 +67,6 @@ namespace Umbraco.Core.Persistence.Querying
expression = cachedExpression.InnerExpression;
}
if (expression == null) return string.Empty;
string result;
switch (expression.NodeType)
@@ -135,35 +135,28 @@ namespace Umbraco.Core.Persistence.Querying
// if the expression is a CachedExpression,
// and is not already compiled, assign the result
if (cachedExpression != null)
{
if (cachedExpression.Visited == false)
cachedExpression.VisitResult = result;
result = cachedExpression.VisitResult;
}
return result;
if (cachedExpression == null)
return result;
if (!cachedExpression.Visited)
cachedExpression.VisitResult = result;
return cachedExpression.VisitResult;
}
protected abstract string VisitMemberAccess(MemberExpression m);
protected virtual string VisitLambda(LambdaExpression lambda)
{
if (lambda.Body.NodeType == ExpressionType.MemberAccess)
{
var m = lambda.Body as MemberExpression;
if (m != null && m.Expression != null)
if (lambda.Body.NodeType == ExpressionType.MemberAccess &&
lambda.Body is MemberExpression memberExpression && memberExpression.Expression != null)
{
//This deals with members that are boolean (i.e. x => IsTrashed )
var r = VisitMemberAccess(m);
var result = VisitMemberAccess(memberExpression);
SqlParameters.Add(true);
return Visited ? string.Empty : string.Format("{0} = @{1}", r, SqlParameters.Count - 1);
return Visited ? string.Empty : $"{result} = @{SqlParameters.Count - 1}";
}
}
return Visit(lambda.Body);
}
@@ -248,21 +241,10 @@ namespace Umbraco.Core.Persistence.Querying
{
case "MOD":
case "COALESCE":
//don't execute if compiled
if (Visited == false)
{
return string.Format("{0}({1},{2})", operand, left, right);
}
//already compiled, return
return string.Empty;
return Visited ? string.Empty : $"{operand}({left},{right})";
default:
//don't execute if compiled
if (Visited == false)
{
return string.Concat("(", left, " ", operand, " ", right, ")");
}
//already compiled, return
return string.Empty;
return Visited ? string.Empty : $"({left} {operand} {right})";
}
}
@@ -284,10 +266,10 @@ namespace Umbraco.Core.Persistence.Querying
return list;
}
protected virtual string VisitNew(NewExpression nex)
protected virtual string VisitNew(NewExpression newExpression)
{
// TODO : check !
var member = Expression.Convert(nex, typeof(object));
var member = Expression.Convert(newExpression, typeof(object));
var lambda = Expression.Lambda<Func<object>>(member);
try
{
@@ -295,20 +277,16 @@ namespace Umbraco.Core.Persistence.Querying
var o = getter();
SqlParameters.Add(o);
return Visited ? string.Empty : $"@{SqlParameters.Count - 1}";
}
catch (InvalidOperationException)
{
if (Visited) return string.Empty;
if (Visited)
return string.Empty;
var exprs = VisitExpressionList(nex.Arguments);
var r = new StringBuilder();
foreach (var e in exprs)
{
if (r.Length > 0) r.Append(",");
r.Append(e);
}
return r.ToString();
var exprs = VisitExpressionList(newExpression.Arguments);
return string.Join(",", exprs);
}
}
@@ -323,6 +301,7 @@ namespace Umbraco.Core.Persistence.Querying
return "null";
SqlParameters.Add(c.Value);
return Visited ? string.Empty : $"@{SqlParameters.Count - 1}";
}
@@ -375,27 +354,11 @@ namespace Umbraco.Core.Persistence.Querying
protected virtual string VisitNewArray(NewArrayExpression na)
{
var exprs = VisitExpressionList(na.Expressions);
//don't execute if compiled
if (Visited == false)
{
var r = new StringBuilder();
foreach (var e in exprs)
{
r.Append(r.Length > 0 ? "," + e : e);
}
return r.ToString();
}
//already compiled, return
return string.Empty;
return Visited ? string.Empty : string.Join(",", exprs);
}
protected virtual List<object> VisitNewArrayFromExpressionList(NewArrayExpression na)
{
var exprs = VisitExpressionList(na.Expressions);
return exprs;
}
=> VisitExpressionList(na.Expressions);
protected virtual string BindOperant(ExpressionType e)
{
@@ -436,50 +399,60 @@ namespace Umbraco.Core.Persistence.Querying
protected virtual string VisitMethodCall(MethodCallExpression m)
{
//Here's what happens with a MethodCallExpression:
// If a method is called that contains a single argument,
// then m.Object is the object on the left hand side of the method call, example:
// x.Path.StartsWith(content.Path)
// m.Object = x.Path
// and m.Arguments.Length == 1, therefor m.Arguments[0] == content.Path
// If a method is called that contains multiple arguments, then m.Object == null and the
// m.Arguments collection contains the left hand side of the method call, example:
// x.Path.SqlStartsWith(content.Path, TextColumnType.NVarchar)
// m.Object == null
// m.Arguments.Length == 3, therefor, m.Arguments[0] == x.Path, m.Arguments[1] == content.Path, m.Arguments[2] == TextColumnType.NVarchar
// So, we need to cater for these scenarios.
// m.Object is the expression that represent the instance for instance method class, or null for static method calls
// m.Arguments is the collection of expressions that represent arguments of the called method
// m.MethodInfo is the method info for the method to be called
var objectForMethod = m.Object ?? m.Arguments[0];
var visitedObjectForMethod = Visit(objectForMethod);
// assume that static methods are extension methods (probably not ok)
// and then, the method object is its first argument - get "safe" object
var methodObject = m.Object ?? m.Arguments[0];
var visitedMethodObject = Visit(methodObject);
// and then, "safe" arguments are what would come after the first arg
var methodArgs = m.Object == null
? m.Arguments.Skip(1).ToArray()
: m.Arguments.ToArray();
? new ReadOnlyCollection<Expression>(m.Arguments.Skip(1).ToList())
: m.Arguments;
switch (m.Method.Name)
{
case "ToString":
SqlParameters.Add(objectForMethod.ToString());
//don't execute if compiled
if (Visited == false)
return string.Format("@{0}", SqlParameters.Count - 1);
//already compiled, return
return string.Empty;
SqlParameters.Add(methodObject.ToString());
return Visited ? string.Empty : $"@{SqlParameters.Count - 1}";
case "ToUpper":
//don't execute if compiled
if (Visited == false)
return string.Format("upper({0})", visitedObjectForMethod);
//already compiled, return
return string.Empty;
return Visited ? string.Empty : $"upper({visitedMethodObject})";
case "ToLower":
//don't execute if compiled
if (Visited == false)
return string.Format("lower({0})", visitedObjectForMethod);
//already compiled, return
return string.Empty;
return Visited ? string.Empty : $"lower({visitedMethodObject})";
case "Contains":
// for 'Contains', it can either be the string.Contains(string) method, or a collection Contains
// method, which would then need to become a SQL IN clause - but beware that string is
// an enumerable of char, and string.Contains(char) is an extension method - but NOT an SQL IN
var isCollectionContains =
(
m.Object == null && // static (extension?) method
m.Arguments.Count == 2 && // with two args
m.Arguments[0].Type != typeof(string) && // but not for string
TypeHelper.IsTypeAssignableFrom<IEnumerable>(m.Arguments[0].Type) && // first arg being an enumerable
m.Arguments[1].NodeType == ExpressionType.MemberAccess // second arg being a member access
) ||
(
m.Object != null && // instance method
TypeHelper.IsTypeAssignableFrom<IEnumerable>(m.Object.Type) && // of an enumerable
m.Arguments.Count == 1 && // with 1 arg
m.Arguments[0].NodeType == ExpressionType.MemberAccess // arg being a member access
);
if (isCollectionContains)
goto case "SqlIn";
else
goto case "Contains**String";
case "SqlWildcard":
case "StartsWith":
case "EndsWith":
case "Contains":
case "Contains**String": // see "Contains" above
case "Equals":
case "SqlStartsWith":
case "SqlEndsWith":
@@ -490,18 +463,6 @@ namespace Umbraco.Core.Persistence.Querying
case "InvariantContains":
case "InvariantEquals":
//special case, if it is 'Contains' and the argumet that Contains is being called on is
//Enumerable and the methodArgs is the actual member access, then it's an SQL IN claus
if (m.Object == null
&& m.Arguments[0].Type != typeof(string)
&& m.Arguments.Count == 2
&& methodArgs.Length == 1
&& methodArgs[0].NodeType == ExpressionType.MemberAccess
&& TypeHelper.IsTypeAssignableFrom<IEnumerable>(m.Arguments[0].Type))
{
goto case "SqlIn";
}
string compareValue;
if (methodArgs[0].NodeType != ExpressionType.Constant)
@@ -526,7 +487,7 @@ namespace Umbraco.Core.Persistence.Querying
//then check if the col type argument has been passed to the current method (this will be the case for methods like
// SqlContains and other Sql methods)
if (methodArgs.Length > 1)
if (methodArgs.Count > 1)
{
var colTypeArg = methodArgs.FirstOrDefault(x => x is ConstantExpression && x.Type == typeof(TextColumnType));
if (colTypeArg != null)
@@ -535,7 +496,7 @@ namespace Umbraco.Core.Persistence.Querying
}
}
return HandleStringComparison(visitedObjectForMethod, compareValue, m.Method.Name, colType);
return HandleStringComparison(visitedMethodObject, compareValue, m.Method.Name, colType);
case "Replace":
string searchValue;
@@ -581,14 +542,10 @@ namespace Umbraco.Core.Persistence.Querying
}
SqlParameters.Add(RemoveQuote(searchValue));
SqlParameters.Add(RemoveQuote(replaceValue));
//don't execute if compiled
if (Visited == false)
return string.Format("replace({0}, @{1}, @{2})", visitedObjectForMethod, SqlParameters.Count - 2, SqlParameters.Count - 1);
//already compiled, return
return string.Empty;
return Visited ? string.Empty : $"replace({visitedMethodObject}, @{SqlParameters.Count - 2}, @{SqlParameters.Count - 1})";
//case "Substring":
// var startIndex = Int32.Parse(args[0].ToString()) + 1;
@@ -624,30 +581,33 @@ namespace Umbraco.Core.Persistence.Querying
case "SqlIn":
if (m.Object == null && methodArgs.Length == 1 && methodArgs[0].NodeType == ExpressionType.MemberAccess)
if (methodArgs.Count != 1 || methodArgs[0].NodeType != ExpressionType.MemberAccess)
throw new NotSupportedException("SqlIn must contain the member being accessed.");
var memberAccess = VisitMemberAccess((MemberExpression) methodArgs[0]);
var inMember = Expression.Convert(methodObject, typeof(object));
var inLambda = Expression.Lambda<Func<object>>(inMember);
var inGetter = inLambda.Compile();
var inArgs = (IEnumerable) inGetter();
var inBuilder = new StringBuilder();
var inFirst = true;
inBuilder.Append(memberAccess);
inBuilder.Append(" IN (");
foreach (var e in inArgs)
{
var memberAccess = VisitMemberAccess((MemberExpression) methodArgs[0]);
var member = Expression.Convert(m.Arguments[0], typeof(object));
var lambda = Expression.Lambda<Func<object>>(member);
var getter = lambda.Compile();
var inArgs = (IEnumerable)getter();
var sIn = new StringBuilder();
foreach (var e in inArgs)
{
SqlParameters.Add(e);
sIn.AppendFormat("{0}{1}",
sIn.Length > 0 ? "," : "",
string.Format("@{0}", SqlParameters.Count - 1));
}
return string.Format("{0} IN ({1})", memberAccess, sIn);
SqlParameters.Add(e);
if (inFirst) inFirst = false; else inBuilder.Append(",");
inBuilder.Append("@");
inBuilder.Append(SqlParameters.Count - 1);
}
throw new NotSupportedException("SqlIn must contain the member being accessed");
inBuilder.Append(")");
return inBuilder.ToString();
//case "Desc":
// return string.Format("{0} DESC", r);
@@ -706,19 +666,13 @@ namespace Umbraco.Core.Persistence.Querying
}
public virtual string GetQuotedTableName(string tableName)
{
return Visited ? tableName : string.Format("\"{0}\"", tableName);
}
=> GetQuotedName(tableName);
public virtual string GetQuotedColumnName(string columnName)
{
return Visited ? columnName : string.Format("\"{0}\"", columnName);
}
=> GetQuotedName(columnName);
public virtual string GetQuotedName(string name)
{
return Visited ? name : string.Format("\"{0}\"", name);
}
=> Visited ? name : "\"" + name + "\"";
protected string HandleStringComparison(string col, string val, string verb, TextColumnType columnType)
{
@@ -726,115 +680,38 @@ namespace Umbraco.Core.Persistence.Querying
{
case "SqlWildcard":
SqlParameters.Add(RemoveQuote(val));
//don't execute if compiled
if (Visited == false)
return SqlSyntax.GetStringColumnWildcardComparison(col, SqlParameters.Count - 1, columnType);
//already compiled, return
return string.Empty;
return Visited ? string.Empty : SqlSyntax.GetStringColumnWildcardComparison(col, SqlParameters.Count - 1, columnType);
case "Equals":
SqlParameters.Add(RemoveQuote(val));
//don't execute if compiled
if (Visited == false)
return SqlSyntax.GetStringColumnEqualComparison(col, SqlParameters.Count - 1, columnType);
//already compiled, return
return string.Empty;
case "StartsWith":
SqlParameters.Add(string.Format("{0}{1}",
RemoveQuote(val),
SqlSyntax.GetWildcardPlaceholder()));
//don't execute if compiled
if (Visited == false)
return SqlSyntax.GetStringColumnWildcardComparison(col, SqlParameters.Count - 1, columnType);
//already compiled, return
return string.Empty;
case "EndsWith":
SqlParameters.Add(string.Format("{0}{1}",
SqlSyntax.GetWildcardPlaceholder(),
RemoveQuote(val)));
//don't execute if compiled
if (Visited == false)
return SqlSyntax.GetStringColumnWildcardComparison(col, SqlParameters.Count - 1, columnType);
//already compiled, return
return string.Empty;
case "Contains":
SqlParameters.Add(string.Format("{0}{1}{0}",
SqlSyntax.GetWildcardPlaceholder(),
RemoveQuote(val)));
//don't execute if compiled
if (Visited == false)
return SqlSyntax.GetStringColumnWildcardComparison(col, SqlParameters.Count - 1, columnType);
//already compiled, return
return string.Empty;
case "InvariantEquals":
case "SqlEquals":
//recurse
return HandleStringComparison(col, val, "Equals", columnType);
SqlParameters.Add(RemoveQuote(val));
return Visited ? string.Empty : SqlSyntax.GetStringColumnEqualComparison(col, SqlParameters.Count - 1, columnType);
case "StartsWith":
case "InvariantStartsWith":
case "SqlStartsWith":
//recurse
return HandleStringComparison(col, val, "StartsWith", columnType);
SqlParameters.Add(RemoveQuote(val) + SqlSyntax.GetWildcardPlaceholder());
return Visited ? string.Empty : SqlSyntax.GetStringColumnWildcardComparison(col, SqlParameters.Count - 1, columnType);
case "EndsWith":
case "InvariantEndsWith":
case "SqlEndsWith":
//recurse
return HandleStringComparison(col, val, "EndsWith", columnType);
SqlParameters.Add(SqlSyntax.GetWildcardPlaceholder() + RemoveQuote(val));
return Visited ? string.Empty : SqlSyntax.GetStringColumnWildcardComparison(col, SqlParameters.Count - 1, columnType);
case "Contains":
case "InvariantContains":
case "SqlContains":
//recurse
return HandleStringComparison(col, val, "Contains", columnType);
var wildcardPlaceholder = SqlSyntax.GetWildcardPlaceholder();
SqlParameters.Add(wildcardPlaceholder + RemoveQuote(val) + wildcardPlaceholder);
return Visited ? string.Empty : SqlSyntax.GetStringColumnWildcardComparison(col, SqlParameters.Count - 1, columnType);
default:
throw new ArgumentOutOfRangeException("verb");
throw new ArgumentOutOfRangeException(nameof(verb));
}
}
//public virtual string GetQuotedValue(object value, Type fieldType, Func<object, string> escapeCallback = null, Func<Type, bool> shouldQuoteCallback = null)
//{
// if (value == null) return "NULL";
// if (escapeCallback == null)
// {
// escapeCallback = EscapeParam;
// }
// if (shouldQuoteCallback == null)
// {
// shouldQuoteCallback = ShouldQuoteValue;
// }
// if (!fieldType.UnderlyingSystemType.IsValueType && fieldType != typeof(string))
// {
// //if (TypeSerializer.CanCreateFromString(fieldType))
// //{
// // return "'" + escapeCallback(TypeSerializer.SerializeToString(value)) + "'";
// //}
// throw new NotSupportedException(
// string.Format("Property of type: {0} is not supported", fieldType.FullName));
// }
// if (fieldType == typeof(int))
// return ((int)value).ToString(CultureInfo.InvariantCulture);
// if (fieldType == typeof(float))
// return ((float)value).ToString(CultureInfo.InvariantCulture);
// if (fieldType == typeof(double))
// return ((double)value).ToString(CultureInfo.InvariantCulture);
// if (fieldType == typeof(decimal))
// return ((decimal)value).ToString(CultureInfo.InvariantCulture);
// if (fieldType == typeof(DateTime))
// {
// return "'" + escapeCallback(((DateTime)value).ToIsoString()) + "'";
// }
// if (fieldType == typeof(bool))
// return ((bool)value) ? Convert.ToString(1, CultureInfo.InvariantCulture) : Convert.ToString(0, CultureInfo.InvariantCulture);
// return shouldQuoteCallback(fieldType)
// ? "'" + escapeCallback(value) + "'"
// : value.ToString();
//}
public virtual string EscapeParam(object paramValue, ISqlSyntaxProvider sqlSyntax)
{
return paramValue == null
@@ -842,34 +719,14 @@ namespace Umbraco.Core.Persistence.Querying
: sqlSyntax.EscapeString(paramValue.ToString());
}
public virtual bool ShouldQuoteValue(Type fieldType)
{
return true;
}
protected virtual string RemoveQuote(string exp)
{
if ((exp.StartsWith("\"") || exp.StartsWith("`") || exp.StartsWith("'"))
&&
(exp.EndsWith("\"") || exp.EndsWith("`") || exp.EndsWith("'")))
{
exp = exp.Remove(0, 1);
exp = exp.Remove(exp.Length - 1, 1);
}
return exp;
if (exp.IsNullOrWhiteSpace()) return exp;
var c = exp[0];
return (c == '"' || c == '`' || c == '\'') && exp[exp.Length - 1] == c
? exp.Substring(1, exp.Length - 2)
: exp;
}
//protected virtual string RemoveQuoteFromAlias(string expression)
//{
// if ((expression.StartsWith("\"") || expression.StartsWith("`") || expression.StartsWith("'"))
// &&
// (expression.EndsWith("\"") || expression.EndsWith("`") || expression.EndsWith("'")))
// {
// expression = expression.Remove(0, 1);
// expression = expression.Remove(expression.Length - 1, 1);
// }
// return expression;
//}
}
}

View File

@@ -52,9 +52,9 @@ namespace Umbraco.Tests.Persistence.Querying
}
[Test]
public void Can_Query_With_Content_Type_Aliases()
public void Can_Query_With_Content_Type_Aliases_IEnumerable()
{
//Arrange
//Arrange - Contains is IEnumerable.Contains extension method
var aliases = new[] { "Test1", "Test2" };
Expression<Func<IMedia, bool>> predicate = content => aliases.Contains(content.ContentType.Alias);
var modelToSqlExpressionHelper = new ModelToSqlExpressionVisitor<IContent>(SqlContext.SqlSyntax, Mappers);
@@ -67,6 +67,22 @@ namespace Umbraco.Tests.Persistence.Querying
Assert.AreEqual("Test2", modelToSqlExpressionHelper.GetSqlParameters()[2]);
}
[Test]
public void Can_Query_With_Content_Type_Aliases_List()
{
//Arrange - Contains is List.Contains instance method
var aliases = new System.Collections.Generic.List<string> { "Test1", "Test2" };
Expression<Func<IMedia, bool>> predicate = content => aliases.Contains(content.ContentType.Alias);
var modelToSqlExpressionHelper = new ModelToSqlExpressionVisitor<IContent>(SqlContext.SqlSyntax, Mappers);
var result = modelToSqlExpressionHelper.Visit(predicate);
Debug.Print("Model to Sql ExpressionHelper: \n" + result);
Assert.AreEqual("[cmsContentType].[alias] IN (@1,@2)", result);
Assert.AreEqual("Test1", modelToSqlExpressionHelper.GetSqlParameters()[1]);
Assert.AreEqual("Test2", modelToSqlExpressionHelper.GetSqlParameters()[2]);
}
[Test]
public void CachedExpression_Can_Verify_Path_StartsWith_Predicate_In_Same_Result()
{