diff --git a/src/Umbraco.Core/Persistence/NPocoSqlExtensions.cs b/src/Umbraco.Core/Persistence/NPocoSqlExtensions.cs index 9c1f0d9a07..0d3faa0d68 100644 --- a/src/Umbraco.Core/Persistence/NPocoSqlExtensions.cs +++ b/src/Umbraco.Core/Persistence/NPocoSqlExtensions.cs @@ -7,7 +7,6 @@ using System.Reflection; using System.Text; using NPoco; using Umbraco.Core.Persistence.Querying; -using Umbraco.Core.Persistence.SqlSyntax; namespace Umbraco.Core.Persistence { @@ -74,10 +73,27 @@ namespace Umbraco.Core.Persistence /// The Sql statement. public static Sql Where(this Sql sql, Expression> predicate, string alias = null) { - var expresionist = new PocoToSqlExpressionVisitor(sql.SqlContext, alias); - var whereExpression = expresionist.Visit(predicate); - sql.Where(whereExpression, expresionist.GetSqlParameters()); - return sql; + var (s, a) = sql.SqlContext.Visit(predicate, alias); + return sql.Where(s, a); + } + + /// + /// Appends an AND clause to a WHERE Sql statement. + /// + /// The type of the Dto. + /// The Sql statement. + /// A predicate to transform and append to the Sql statement. + /// An optional alias for the table. + /// The Sql statement. + /// + /// Chaining .Where(...).Where(...) in NPoco works because it merges the two WHERE statements, + /// however if the first statement is not an explicit WHERE statement, chaining fails and two WHERE + /// statements appear in the resulting Sql. This allows for adding an AND clause without problems. + /// + public static Sql AndWhere(this Sql sql, Expression> predicate, string alias = null) + { + var (s, a) = sql.SqlContext.Visit(predicate, alias); + return sql.Append("AND (" + s + ")", a); } /// @@ -92,10 +108,8 @@ namespace Umbraco.Core.Persistence /// The Sql statement. public static Sql Where(this Sql sql, Expression> predicate, string alias1 = null, string alias2 = null) { - var expresionist = new PocoToSqlExpressionVisitor(sql.SqlContext, alias1, alias2); - var whereExpression = expresionist.Visit(predicate); - sql.Where(whereExpression, expresionist.GetSqlParameters()); - return sql; + var (s, a) = sql.SqlContext.Visit(predicate, alias1, alias2); + return sql.Where(s, a); } /// @@ -108,7 +122,7 @@ namespace Umbraco.Core.Persistence /// The Sql statement. public static Sql WhereIn(this Sql sql, Expression> field, IEnumerable values) { - var fieldName = GetFieldName(field, sql.SqlContext.SqlSyntax); + var fieldName = sql.SqlContext.SqlSyntax.GetFieldName(field); sql.Where(fieldName + " IN (@values)", new { values }); return sql; } @@ -136,7 +150,7 @@ namespace Umbraco.Core.Persistence /// The Sql statement. public static Sql WhereNotIn(this Sql sql, Expression> field, IEnumerable values) { - var fieldName = GetFieldName(field, sql.SqlContext.SqlSyntax); + var fieldName = sql.SqlContext.SqlSyntax.GetFieldName(field); sql.Where(fieldName + " NOT IN (@values)", new { values }); return sql; } @@ -164,7 +178,8 @@ namespace Umbraco.Core.Persistence /// The Sql statement. public static Sql WhereAnyIn(this Sql sql, Expression>[] fields, IEnumerable values) { - var fieldNames = fields.Select(x => GetFieldName(x, sql.SqlContext.SqlSyntax)).ToArray(); + var sqlSyntax = sql.SqlContext.SqlSyntax; + var fieldNames = fields.Select(x => sqlSyntax.GetFieldName(x)).ToArray(); var sb = new StringBuilder(); sb.Append("("); for (var i = 0; i < fieldNames.Length; i++) @@ -180,7 +195,7 @@ namespace Umbraco.Core.Persistence private static Sql WhereIn(this Sql sql, Expression> fieldSelector, Sql valuesSql, bool not) { - var fieldName = GetFieldName(fieldSelector, sql.SqlContext.SqlSyntax); + var fieldName = sql.SqlContext.SqlSyntax.GetFieldName(fieldSelector); sql.Where(fieldName + (not ? " NOT" : "") +" IN (" + valuesSql.SQL + ")", valuesSql.Arguments); return sql; } @@ -274,7 +289,7 @@ namespace Umbraco.Core.Persistence /// The Sql statement. public static Sql OrderBy(this Sql sql, Expression> field) { - return sql.OrderBy("(" + GetFieldName(field, sql.SqlContext.SqlSyntax) + ")"); + return sql.OrderBy("(" + sql.SqlContext.SqlSyntax.GetFieldName(field) + ")"); } /// @@ -286,9 +301,10 @@ namespace Umbraco.Core.Persistence /// The Sql statement. public static Sql OrderBy(this Sql sql, params Expression>[] fields) { + var sqlSyntax = sql.SqlContext.SqlSyntax; var columns = fields.Length == 0 ? sql.GetColumns(withAlias: false) - : fields.Select(x => GetFieldName(x, sql.SqlContext.SqlSyntax)).ToArray(); + : fields.Select(x => sqlSyntax.GetFieldName(x)).ToArray(); return sql.OrderBy(columns); } @@ -301,7 +317,7 @@ namespace Umbraco.Core.Persistence /// The Sql statement. public static Sql OrderByDescending(this Sql sql, Expression> field) { - return sql.OrderBy("(" + GetFieldName(field, sql.SqlContext.SqlSyntax) + ") DESC"); + return sql.OrderBy("(" + sql.SqlContext.SqlSyntax.GetFieldName(field) + ") DESC"); } /// @@ -313,9 +329,10 @@ namespace Umbraco.Core.Persistence /// The Sql statement. public static Sql OrderByDescending(this Sql sql, params Expression>[] fields) { + var sqlSyntax = sql.SqlContext.SqlSyntax; var columns = fields.Length == 0 ? sql.GetColumns(withAlias: false) - : fields.Select(x => GetFieldName(x, sql.SqlContext.SqlSyntax)).ToArray(); + : fields.Select(x => sqlSyntax.GetFieldName(x)).ToArray(); return sql.OrderBy(columns.Select(x => x + " DESC")); } @@ -339,7 +356,7 @@ namespace Umbraco.Core.Persistence /// The Sql statement. public static Sql GroupBy(this Sql sql, Expression> field) { - return sql.GroupBy(GetFieldName(field, sql.SqlContext.SqlSyntax)); + return sql.GroupBy(sql.SqlContext.SqlSyntax.GetFieldName(field)); } /// @@ -351,9 +368,10 @@ namespace Umbraco.Core.Persistence /// The Sql statement. public static Sql GroupBy(this Sql sql, params Expression>[] fields) { + var sqlSyntax = sql.SqlContext.SqlSyntax; var columns = fields.Length == 0 ? sql.GetColumns(withAlias: false) - : fields.Select(x => GetFieldName(x, sql.SqlContext.SqlSyntax)).ToArray(); + : fields.Select(x => sqlSyntax.GetFieldName(x)).ToArray(); return sql.GroupBy(columns); } @@ -366,9 +384,10 @@ namespace Umbraco.Core.Persistence /// The Sql statement. public static Sql AndBy(this Sql sql, params Expression>[] fields) { + var sqlSyntax = sql.SqlContext.SqlSyntax; var columns = fields.Length == 0 ? sql.GetColumns(withAlias: false) - : fields.Select(x => GetFieldName(x, sql.SqlContext.SqlSyntax)).ToArray(); + : fields.Select(x => sqlSyntax.GetFieldName(x)).ToArray(); return sql.Append(", " + string.Join(", ", columns)); } @@ -381,9 +400,10 @@ namespace Umbraco.Core.Persistence /// The Sql statement. public static Sql AndByDescending(this Sql sql, params Expression>[] fields) { + var sqlSyntax = sql.SqlContext.SqlSyntax; var columns = fields.Length == 0 ? sql.GetColumns(withAlias: false) - : fields.Select(x => GetFieldName(x, sql.SqlContext.SqlSyntax)).ToArray(); + : fields.Select(x => sqlSyntax.GetFieldName(x)).ToArray(); return sql.Append(", " + string.Join(", ", columns.Select(x => x + " DESC"))); } @@ -572,9 +592,10 @@ namespace Umbraco.Core.Persistence public static Sql SelectCount(this Sql sql, params Expression>[] fields) { if (sql == null) throw new ArgumentNullException(nameof(sql)); + var sqlSyntax = sql.SqlContext.SqlSyntax; var columns = fields.Length == 0 ? sql.GetColumns(withAlias: false) - : fields.Select(x => GetFieldName(x, sql.SqlContext.SqlSyntax)).ToArray(); + : fields.Select(x => sqlSyntax.GetFieldName(x)).ToArray(); return sql.Select("COUNT (" + string.Join(", ", columns) + ")"); } @@ -906,7 +927,7 @@ namespace Umbraco.Core.Persistence public SqlUpd Set(Expression> fieldSelector, object value) { - var fieldName = GetFieldName(fieldSelector, _sqlContext.SqlSyntax); + var fieldName = _sqlContext.SqlSyntax.GetFieldName(fieldSelector); _setExpressions.Add(new Tuple(fieldName, value)); return this; } @@ -1062,17 +1083,6 @@ namespace Umbraco.Core.Persistence return string.IsNullOrWhiteSpace(attr?.Name) ? column.Name : attr.Name; } - private static string GetFieldName(Expression> fieldSelector, ISqlSyntaxProvider sqlSyntax) - { - var field = ExpressionHelper.FindProperty(fieldSelector).Item1 as PropertyInfo; - var fieldName = field.GetColumnName(); - - var type = typeof (TDto); - var tableName = type.GetTableName(); - - return sqlSyntax.GetQuotedTableName(tableName) + "." + sqlSyntax.GetQuotedColumnName(fieldName); - } - internal static void WriteToConsole(this Sql sql) { Console.WriteLine(sql.SQL); diff --git a/src/Umbraco.Core/Persistence/SqlContextExtensions.cs b/src/Umbraco.Core/Persistence/SqlContextExtensions.cs new file mode 100644 index 0000000000..e28816b6a4 --- /dev/null +++ b/src/Umbraco.Core/Persistence/SqlContextExtensions.cs @@ -0,0 +1,78 @@ +using System; +using System.Linq.Expressions; +using Umbraco.Core.Persistence.Querying; + +namespace Umbraco.Core.Persistence +{ + /// + /// Provides extension methods to . + /// + public static class SqlContextExtensions + { + /// + /// Visit an expression. + /// + /// The type of the DTO. + /// An . + /// An expression to visit. + /// An optional table alias. + /// A SQL statement, and arguments, corresponding to the expression. + public static (string Sql, object[] Args) Visit(this ISqlContext sqlContext, Expression> expression, string alias = null) + { + var expresionist = new PocoToSqlExpressionVisitor(sqlContext, alias); + var visited = expresionist.Visit(expression); + return (visited, expresionist.GetSqlParameters()); + } + + /// + /// Visit an expression. + /// + /// The type of the DTO. + /// The type returned by the expression. + /// An . + /// An expression to visit. + /// An optional table alias. + /// A SQL statement, and arguments, corresponding to the expression. + public static (string Sql, object[] Args) Visit(this ISqlContext sqlContext, Expression> expression, string alias = null) + { + var expresionist = new PocoToSqlExpressionVisitor(sqlContext, alias); + var visited = expresionist.Visit(expression); + return (visited, expresionist.GetSqlParameters()); + } + + /// + /// Visit an expression. + /// + /// The type of the first DTO. + /// The type of the second DTO. + /// An . + /// An expression to visit. + /// An optional table alias for the first DTO. + /// An optional table alias for the second DTO. + /// A SQL statement, and arguments, corresponding to the expression. + public static (string Sql, object[] Args) Visit(this ISqlContext sqlContext, Expression> expression, string alias1 = null, string alias2 = null) + { + var expresionist = new PocoToSqlExpressionVisitor(sqlContext, alias1, alias2); + var visited = expresionist.Visit(expression); + return (visited, expresionist.GetSqlParameters()); + } + + /// + /// Visit an expression. + /// + /// The type of the first DTO. + /// The type of the second DTO. + /// The type returned by the expression. + /// An . + /// An expression to visit. + /// An optional table alias for the first DTO. + /// An optional table alias for the second DTO. + /// A SQL statement, and arguments, corresponding to the expression. + public static (string Sql, object[] Args) Visit(this ISqlContext sqlContext, Expression> expression, string alias1 = null, string alias2 = null) + { + var expresionist = new PocoToSqlExpressionVisitor(sqlContext, alias1, alias2); + var visited = expresionist.Visit(expression); + return (visited, expresionist.GetSqlParameters()); + } + } +} \ No newline at end of file diff --git a/src/Umbraco.Core/Persistence/SqlSyntaxExtensions.cs b/src/Umbraco.Core/Persistence/SqlSyntaxExtensions.cs new file mode 100644 index 0000000000..43ef03327b --- /dev/null +++ b/src/Umbraco.Core/Persistence/SqlSyntaxExtensions.cs @@ -0,0 +1,48 @@ +using System; +using System.Linq.Expressions; +using System.Reflection; +using NPoco; +using Umbraco.Core.Persistence.SqlSyntax; + +namespace Umbraco.Core.Persistence +{ + /// + /// Provides extension methods to . + /// + public static class SqlSyntaxExtensions + { + private static string GetTableName(this Type type) + { + // todo: returning string.Empty for now + // BUT the code bits that calls this method cannot deal with string.Empty so we + // should either throw, or fix these code bits... + var attr = type.FirstAttribute(); + return string.IsNullOrWhiteSpace(attr?.Value) ? string.Empty : attr.Value; + } + + private static string GetColumnName(this PropertyInfo column) + { + var attr = column.FirstAttribute(); + return string.IsNullOrWhiteSpace(attr?.Name) ? column.Name : attr.Name; + } + + /// + /// Gets a quoted table and field name. + /// + /// The type of the DTO. + /// An . + /// An expression specifying the field. + /// An optional table alias. + /// + public static string GetFieldName(this ISqlSyntaxProvider sqlSyntax, Expression> fieldSelector, string tableAlias = null) + { + var field = ExpressionHelper.FindProperty(fieldSelector).Item1 as PropertyInfo; + var fieldName = field.GetColumnName(); + + var type = typeof(TDto); + var tableName = tableAlias ?? type.GetTableName(); + + return sqlSyntax.GetQuotedTableName(tableName) + "." + sqlSyntax.GetQuotedColumnName(fieldName); + } + } +} diff --git a/src/Umbraco.Core/Umbraco.Core.csproj b/src/Umbraco.Core/Umbraco.Core.csproj index b747b0960f..e36f2eaf81 100644 --- a/src/Umbraco.Core/Umbraco.Core.csproj +++ b/src/Umbraco.Core/Umbraco.Core.csproj @@ -406,6 +406,8 @@ + +