From f3e9e282faac9142b92c0f4fc9d9100983f96ea9 Mon Sep 17 00:00:00 2001 From: Stephan Date: Tue, 6 Nov 2018 15:25:36 +0100 Subject: [PATCH] Fix expression visitor for WHERE IN --- .../Querying/ExpressionVisitorBase.cs | 389 ++++++------------ .../Persistence/Querying/ExpressionTests.cs | 20 +- 2 files changed, 141 insertions(+), 268 deletions(-) diff --git a/src/Umbraco.Core/Persistence/Querying/ExpressionVisitorBase.cs b/src/Umbraco.Core/Persistence/Querying/ExpressionVisitorBase.cs index 76116a8d03..d4d27fa4df 100644 --- a/src/Umbraco.Core/Persistence/Querying/ExpressionVisitorBase.cs +++ b/src/Umbraco.Core/Persistence/Querying/ExpressionVisitorBase.cs @@ -32,7 +32,7 @@ namespace Umbraco.Core.Persistence.Querying /// /// Gets or sets the SQL syntax provider for the current database. /// - protected ISqlSyntaxProvider SqlSyntax { get; private set; } + protected ISqlSyntaxProvider SqlSyntax { get; } /// /// Gets the list of SQL parameters. @@ -56,6 +56,8 @@ namespace Umbraco.Core.Persistence.Querying /// Also populates the SQL parameters. public virtual string Visit(Expression expression) { + if (expression == null) return string.Empty; + // if the expression is a CachedExpression, // visit the inner expression if not already visited var cachedExpression = expression as CachedExpression; @@ -65,8 +67,6 @@ namespace Umbraco.Core.Persistence.Querying expression = cachedExpression.InnerExpression; } - if (expression == null) return string.Empty; - string result; switch (expression.NodeType) @@ -135,35 +135,28 @@ namespace Umbraco.Core.Persistence.Querying // if the expression is a CachedExpression, // and is not already compiled, assign the result - if (cachedExpression != null) - { - if (cachedExpression.Visited == false) - cachedExpression.VisitResult = result; - result = cachedExpression.VisitResult; - } - - return result; + if (cachedExpression == null) + return result; + if (!cachedExpression.Visited) + cachedExpression.VisitResult = result; + return cachedExpression.VisitResult; } protected abstract string VisitMemberAccess(MemberExpression m); protected virtual string VisitLambda(LambdaExpression lambda) { - if (lambda.Body.NodeType == ExpressionType.MemberAccess) - { - var m = lambda.Body as MemberExpression; - - if (m != null && m.Expression != null) + if (lambda.Body.NodeType == ExpressionType.MemberAccess && + lambda.Body is MemberExpression memberExpression && memberExpression.Expression != null) { //This deals with members that are boolean (i.e. x => IsTrashed ) - var r = VisitMemberAccess(m); + var result = VisitMemberAccess(memberExpression); SqlParameters.Add(true); - return Visited ? string.Empty : string.Format("{0} = @{1}", r, SqlParameters.Count - 1); + return Visited ? string.Empty : $"{result} = @{SqlParameters.Count - 1}"; } - } return Visit(lambda.Body); } @@ -248,21 +241,10 @@ namespace Umbraco.Core.Persistence.Querying { case "MOD": case "COALESCE": - //don't execute if compiled - if (Visited == false) - { - return string.Format("{0}({1},{2})", operand, left, right); - } - //already compiled, return - return string.Empty; + return Visited ? string.Empty : $"{operand}({left},{right})"; + default: - //don't execute if compiled - if (Visited == false) - { - return string.Concat("(", left, " ", operand, " ", right, ")"); - } - //already compiled, return - return string.Empty; + return Visited ? string.Empty : $"({left} {operand} {right})"; } } @@ -284,10 +266,10 @@ namespace Umbraco.Core.Persistence.Querying return list; } - protected virtual string VisitNew(NewExpression nex) + protected virtual string VisitNew(NewExpression newExpression) { // TODO : check ! - var member = Expression.Convert(nex, typeof(object)); + var member = Expression.Convert(newExpression, typeof(object)); var lambda = Expression.Lambda>(member); try { @@ -295,20 +277,16 @@ namespace Umbraco.Core.Persistence.Querying var o = getter(); SqlParameters.Add(o); + return Visited ? string.Empty : $"@{SqlParameters.Count - 1}"; } catch (InvalidOperationException) { - if (Visited) return string.Empty; + if (Visited) + return string.Empty; - var exprs = VisitExpressionList(nex.Arguments); - var r = new StringBuilder(); - foreach (var e in exprs) - { - if (r.Length > 0) r.Append(","); - r.Append(e); - } - return r.ToString(); + var exprs = VisitExpressionList(newExpression.Arguments); + return string.Join(",", exprs); } } @@ -323,6 +301,7 @@ namespace Umbraco.Core.Persistence.Querying return "null"; SqlParameters.Add(c.Value); + return Visited ? string.Empty : $"@{SqlParameters.Count - 1}"; } @@ -375,27 +354,11 @@ namespace Umbraco.Core.Persistence.Querying protected virtual string VisitNewArray(NewArrayExpression na) { var exprs = VisitExpressionList(na.Expressions); - - //don't execute if compiled - if (Visited == false) - { - var r = new StringBuilder(); - foreach (var e in exprs) - { - r.Append(r.Length > 0 ? "," + e : e); - } - - return r.ToString(); - } - //already compiled, return - return string.Empty; + return Visited ? string.Empty : string.Join(",", exprs); } protected virtual List VisitNewArrayFromExpressionList(NewArrayExpression na) - { - var exprs = VisitExpressionList(na.Expressions); - return exprs; - } + => VisitExpressionList(na.Expressions); protected virtual string BindOperant(ExpressionType e) { @@ -436,50 +399,60 @@ namespace Umbraco.Core.Persistence.Querying protected virtual string VisitMethodCall(MethodCallExpression m) { - //Here's what happens with a MethodCallExpression: - // If a method is called that contains a single argument, - // then m.Object is the object on the left hand side of the method call, example: - // x.Path.StartsWith(content.Path) - // m.Object = x.Path - // and m.Arguments.Length == 1, therefor m.Arguments[0] == content.Path - // If a method is called that contains multiple arguments, then m.Object == null and the - // m.Arguments collection contains the left hand side of the method call, example: - // x.Path.SqlStartsWith(content.Path, TextColumnType.NVarchar) - // m.Object == null - // m.Arguments.Length == 3, therefor, m.Arguments[0] == x.Path, m.Arguments[1] == content.Path, m.Arguments[2] == TextColumnType.NVarchar - // So, we need to cater for these scenarios. + // m.Object is the expression that represent the instance for instance method class, or null for static method calls + // m.Arguments is the collection of expressions that represent arguments of the called method + // m.MethodInfo is the method info for the method to be called - var objectForMethod = m.Object ?? m.Arguments[0]; - var visitedObjectForMethod = Visit(objectForMethod); + // assume that static methods are extension methods (probably not ok) + // and then, the method object is its first argument - get "safe" object + var methodObject = m.Object ?? m.Arguments[0]; + var visitedMethodObject = Visit(methodObject); + // and then, "safe" arguments are what would come after the first arg var methodArgs = m.Object == null - ? m.Arguments.Skip(1).ToArray() - : m.Arguments.ToArray(); + ? new ReadOnlyCollection(m.Arguments.Skip(1).ToList()) + : m.Arguments; switch (m.Method.Name) { case "ToString": - SqlParameters.Add(objectForMethod.ToString()); - //don't execute if compiled - if (Visited == false) - return string.Format("@{0}", SqlParameters.Count - 1); - //already compiled, return - return string.Empty; + SqlParameters.Add(methodObject.ToString()); + return Visited ? string.Empty : $"@{SqlParameters.Count - 1}"; + case "ToUpper": - //don't execute if compiled - if (Visited == false) - return string.Format("upper({0})", visitedObjectForMethod); - //already compiled, return - return string.Empty; + return Visited ? string.Empty : $"upper({visitedMethodObject})"; + case "ToLower": - //don't execute if compiled - if (Visited == false) - return string.Format("lower({0})", visitedObjectForMethod); - //already compiled, return - return string.Empty; + return Visited ? string.Empty : $"lower({visitedMethodObject})"; + + case "Contains": + // for 'Contains', it can either be the string.Contains(string) method, or a collection Contains + // method, which would then need to become a SQL IN clause - but beware that string is + // an enumerable of char, and string.Contains(char) is an extension method - but NOT an SQL IN + + var isCollectionContains = + ( + m.Object == null && // static (extension?) method + m.Arguments.Count == 2 && // with two args + m.Arguments[0].Type != typeof(string) && // but not for string + TypeHelper.IsTypeAssignableFrom(m.Arguments[0].Type) && // first arg being an enumerable + m.Arguments[1].NodeType == ExpressionType.MemberAccess // second arg being a member access + ) || + ( + m.Object != null && // instance method + TypeHelper.IsTypeAssignableFrom(m.Object.Type) && // of an enumerable + m.Arguments.Count == 1 && // with 1 arg + m.Arguments[0].NodeType == ExpressionType.MemberAccess // arg being a member access + ); + + if (isCollectionContains) + goto case "SqlIn"; + else + goto case "Contains**String"; + case "SqlWildcard": case "StartsWith": case "EndsWith": - case "Contains": + case "Contains**String": // see "Contains" above case "Equals": case "SqlStartsWith": case "SqlEndsWith": @@ -490,18 +463,6 @@ namespace Umbraco.Core.Persistence.Querying case "InvariantContains": case "InvariantEquals": - //special case, if it is 'Contains' and the argumet that Contains is being called on is - //Enumerable and the methodArgs is the actual member access, then it's an SQL IN claus - if (m.Object == null - && m.Arguments[0].Type != typeof(string) - && m.Arguments.Count == 2 - && methodArgs.Length == 1 - && methodArgs[0].NodeType == ExpressionType.MemberAccess - && TypeHelper.IsTypeAssignableFrom(m.Arguments[0].Type)) - { - goto case "SqlIn"; - } - string compareValue; if (methodArgs[0].NodeType != ExpressionType.Constant) @@ -526,7 +487,7 @@ namespace Umbraco.Core.Persistence.Querying //then check if the col type argument has been passed to the current method (this will be the case for methods like // SqlContains and other Sql methods) - if (methodArgs.Length > 1) + if (methodArgs.Count > 1) { var colTypeArg = methodArgs.FirstOrDefault(x => x is ConstantExpression && x.Type == typeof(TextColumnType)); if (colTypeArg != null) @@ -535,7 +496,7 @@ namespace Umbraco.Core.Persistence.Querying } } - return HandleStringComparison(visitedObjectForMethod, compareValue, m.Method.Name, colType); + return HandleStringComparison(visitedMethodObject, compareValue, m.Method.Name, colType); case "Replace": string searchValue; @@ -581,14 +542,10 @@ namespace Umbraco.Core.Persistence.Querying } SqlParameters.Add(RemoveQuote(searchValue)); - SqlParameters.Add(RemoveQuote(replaceValue)); //don't execute if compiled - if (Visited == false) - return string.Format("replace({0}, @{1}, @{2})", visitedObjectForMethod, SqlParameters.Count - 2, SqlParameters.Count - 1); - //already compiled, return - return string.Empty; + return Visited ? string.Empty : $"replace({visitedMethodObject}, @{SqlParameters.Count - 2}, @{SqlParameters.Count - 1})"; //case "Substring": // var startIndex = Int32.Parse(args[0].ToString()) + 1; @@ -624,30 +581,33 @@ namespace Umbraco.Core.Persistence.Querying case "SqlIn": - if (m.Object == null && methodArgs.Length == 1 && methodArgs[0].NodeType == ExpressionType.MemberAccess) + if (methodArgs.Count != 1 || methodArgs[0].NodeType != ExpressionType.MemberAccess) + throw new NotSupportedException("SqlIn must contain the member being accessed."); + + var memberAccess = VisitMemberAccess((MemberExpression) methodArgs[0]); + + var inMember = Expression.Convert(methodObject, typeof(object)); + var inLambda = Expression.Lambda>(inMember); + var inGetter = inLambda.Compile(); + + var inArgs = (IEnumerable) inGetter(); + + var inBuilder = new StringBuilder(); + var inFirst = true; + + inBuilder.Append(memberAccess); + inBuilder.Append(" IN ("); + + foreach (var e in inArgs) { - var memberAccess = VisitMemberAccess((MemberExpression) methodArgs[0]); - - var member = Expression.Convert(m.Arguments[0], typeof(object)); - var lambda = Expression.Lambda>(member); - var getter = lambda.Compile(); - - var inArgs = (IEnumerable)getter(); - - var sIn = new StringBuilder(); - foreach (var e in inArgs) - { - SqlParameters.Add(e); - - sIn.AppendFormat("{0}{1}", - sIn.Length > 0 ? "," : "", - string.Format("@{0}", SqlParameters.Count - 1)); - } - - return string.Format("{0} IN ({1})", memberAccess, sIn); + SqlParameters.Add(e); + if (inFirst) inFirst = false; else inBuilder.Append(","); + inBuilder.Append("@"); + inBuilder.Append(SqlParameters.Count - 1); } - throw new NotSupportedException("SqlIn must contain the member being accessed"); + inBuilder.Append(")"); + return inBuilder.ToString(); //case "Desc": // return string.Format("{0} DESC", r); @@ -706,19 +666,13 @@ namespace Umbraco.Core.Persistence.Querying } public virtual string GetQuotedTableName(string tableName) - { - return Visited ? tableName : string.Format("\"{0}\"", tableName); - } + => GetQuotedName(tableName); public virtual string GetQuotedColumnName(string columnName) - { - return Visited ? columnName : string.Format("\"{0}\"", columnName); - } + => GetQuotedName(columnName); public virtual string GetQuotedName(string name) - { - return Visited ? name : string.Format("\"{0}\"", name); - } + => Visited ? name : "\"" + name + "\""; protected string HandleStringComparison(string col, string val, string verb, TextColumnType columnType) { @@ -726,115 +680,38 @@ namespace Umbraco.Core.Persistence.Querying { case "SqlWildcard": SqlParameters.Add(RemoveQuote(val)); - //don't execute if compiled - if (Visited == false) - return SqlSyntax.GetStringColumnWildcardComparison(col, SqlParameters.Count - 1, columnType); - //already compiled, return - return string.Empty; + return Visited ? string.Empty : SqlSyntax.GetStringColumnWildcardComparison(col, SqlParameters.Count - 1, columnType); + case "Equals": - SqlParameters.Add(RemoveQuote(val)); - //don't execute if compiled - if (Visited == false) - return SqlSyntax.GetStringColumnEqualComparison(col, SqlParameters.Count - 1, columnType); - //already compiled, return - return string.Empty; - case "StartsWith": - SqlParameters.Add(string.Format("{0}{1}", - RemoveQuote(val), - SqlSyntax.GetWildcardPlaceholder())); - //don't execute if compiled - if (Visited == false) - return SqlSyntax.GetStringColumnWildcardComparison(col, SqlParameters.Count - 1, columnType); - //already compiled, return - return string.Empty; - case "EndsWith": - SqlParameters.Add(string.Format("{0}{1}", - SqlSyntax.GetWildcardPlaceholder(), - RemoveQuote(val))); - //don't execute if compiled - if (Visited == false) - return SqlSyntax.GetStringColumnWildcardComparison(col, SqlParameters.Count - 1, columnType); - //already compiled, return - return string.Empty; - case "Contains": - SqlParameters.Add(string.Format("{0}{1}{0}", - SqlSyntax.GetWildcardPlaceholder(), - RemoveQuote(val))); - //don't execute if compiled - if (Visited == false) - return SqlSyntax.GetStringColumnWildcardComparison(col, SqlParameters.Count - 1, columnType); - //already compiled, return - return string.Empty; case "InvariantEquals": case "SqlEquals": - //recurse - return HandleStringComparison(col, val, "Equals", columnType); + SqlParameters.Add(RemoveQuote(val)); + return Visited ? string.Empty : SqlSyntax.GetStringColumnEqualComparison(col, SqlParameters.Count - 1, columnType); + + case "StartsWith": case "InvariantStartsWith": case "SqlStartsWith": - //recurse - return HandleStringComparison(col, val, "StartsWith", columnType); + SqlParameters.Add(RemoveQuote(val) + SqlSyntax.GetWildcardPlaceholder()); + return Visited ? string.Empty : SqlSyntax.GetStringColumnWildcardComparison(col, SqlParameters.Count - 1, columnType); + + case "EndsWith": case "InvariantEndsWith": case "SqlEndsWith": - //recurse - return HandleStringComparison(col, val, "EndsWith", columnType); + SqlParameters.Add(SqlSyntax.GetWildcardPlaceholder() + RemoveQuote(val)); + return Visited ? string.Empty : SqlSyntax.GetStringColumnWildcardComparison(col, SqlParameters.Count - 1, columnType); + + case "Contains": case "InvariantContains": case "SqlContains": - //recurse - return HandleStringComparison(col, val, "Contains", columnType); + var wildcardPlaceholder = SqlSyntax.GetWildcardPlaceholder(); + SqlParameters.Add(wildcardPlaceholder + RemoveQuote(val) + wildcardPlaceholder); + return Visited ? string.Empty : SqlSyntax.GetStringColumnWildcardComparison(col, SqlParameters.Count - 1, columnType); + default: - throw new ArgumentOutOfRangeException("verb"); + throw new ArgumentOutOfRangeException(nameof(verb)); } } - //public virtual string GetQuotedValue(object value, Type fieldType, Func escapeCallback = null, Func shouldQuoteCallback = null) - //{ - // if (value == null) return "NULL"; - - // if (escapeCallback == null) - // { - // escapeCallback = EscapeParam; - // } - // if (shouldQuoteCallback == null) - // { - // shouldQuoteCallback = ShouldQuoteValue; - // } - - // if (!fieldType.UnderlyingSystemType.IsValueType && fieldType != typeof(string)) - // { - // //if (TypeSerializer.CanCreateFromString(fieldType)) - // //{ - // // return "'" + escapeCallback(TypeSerializer.SerializeToString(value)) + "'"; - // //} - - // throw new NotSupportedException( - // string.Format("Property of type: {0} is not supported", fieldType.FullName)); - // } - - // if (fieldType == typeof(int)) - // return ((int)value).ToString(CultureInfo.InvariantCulture); - - // if (fieldType == typeof(float)) - // return ((float)value).ToString(CultureInfo.InvariantCulture); - - // if (fieldType == typeof(double)) - // return ((double)value).ToString(CultureInfo.InvariantCulture); - - // if (fieldType == typeof(decimal)) - // return ((decimal)value).ToString(CultureInfo.InvariantCulture); - - // if (fieldType == typeof(DateTime)) - // { - // return "'" + escapeCallback(((DateTime)value).ToIsoString()) + "'"; - // } - - // if (fieldType == typeof(bool)) - // return ((bool)value) ? Convert.ToString(1, CultureInfo.InvariantCulture) : Convert.ToString(0, CultureInfo.InvariantCulture); - - // return shouldQuoteCallback(fieldType) - // ? "'" + escapeCallback(value) + "'" - // : value.ToString(); - //} - public virtual string EscapeParam(object paramValue, ISqlSyntaxProvider sqlSyntax) { return paramValue == null @@ -842,34 +719,14 @@ namespace Umbraco.Core.Persistence.Querying : sqlSyntax.EscapeString(paramValue.ToString()); } - public virtual bool ShouldQuoteValue(Type fieldType) - { - return true; - } - protected virtual string RemoveQuote(string exp) { - if ((exp.StartsWith("\"") || exp.StartsWith("`") || exp.StartsWith("'")) - && - (exp.EndsWith("\"") || exp.EndsWith("`") || exp.EndsWith("'"))) - { - exp = exp.Remove(0, 1); - exp = exp.Remove(exp.Length - 1, 1); - } - return exp; + if (exp.IsNullOrWhiteSpace()) return exp; + + var c = exp[0]; + return (c == '"' || c == '`' || c == '\'') && exp[exp.Length - 1] == c + ? exp.Substring(1, exp.Length - 2) + : exp; } - - //protected virtual string RemoveQuoteFromAlias(string expression) - //{ - - // if ((expression.StartsWith("\"") || expression.StartsWith("`") || expression.StartsWith("'")) - // && - // (expression.EndsWith("\"") || expression.EndsWith("`") || expression.EndsWith("'"))) - // { - // expression = expression.Remove(0, 1); - // expression = expression.Remove(expression.Length - 1, 1); - // } - // return expression; - //} } } diff --git a/src/Umbraco.Tests/Persistence/Querying/ExpressionTests.cs b/src/Umbraco.Tests/Persistence/Querying/ExpressionTests.cs index f73831e8bc..70d70d4a31 100644 --- a/src/Umbraco.Tests/Persistence/Querying/ExpressionTests.cs +++ b/src/Umbraco.Tests/Persistence/Querying/ExpressionTests.cs @@ -52,9 +52,9 @@ namespace Umbraco.Tests.Persistence.Querying } [Test] - public void Can_Query_With_Content_Type_Aliases() + public void Can_Query_With_Content_Type_Aliases_IEnumerable() { - //Arrange + //Arrange - Contains is IEnumerable.Contains extension method var aliases = new[] { "Test1", "Test2" }; Expression> predicate = content => aliases.Contains(content.ContentType.Alias); var modelToSqlExpressionHelper = new ModelToSqlExpressionVisitor(SqlContext.SqlSyntax, Mappers); @@ -67,6 +67,22 @@ namespace Umbraco.Tests.Persistence.Querying Assert.AreEqual("Test2", modelToSqlExpressionHelper.GetSqlParameters()[2]); } + [Test] + public void Can_Query_With_Content_Type_Aliases_List() + { + //Arrange - Contains is List.Contains instance method + var aliases = new System.Collections.Generic.List { "Test1", "Test2" }; + Expression> predicate = content => aliases.Contains(content.ContentType.Alias); + var modelToSqlExpressionHelper = new ModelToSqlExpressionVisitor(SqlContext.SqlSyntax, Mappers); + var result = modelToSqlExpressionHelper.Visit(predicate); + + Debug.Print("Model to Sql ExpressionHelper: \n" + result); + + Assert.AreEqual("[cmsContentType].[alias] IN (@1,@2)", result); + Assert.AreEqual("Test1", modelToSqlExpressionHelper.GetSqlParameters()[1]); + Assert.AreEqual("Test2", modelToSqlExpressionHelper.GetSqlParameters()[2]); + } + [Test] public void CachedExpression_Can_Verify_Path_StartsWith_Predicate_In_Same_Result() {