using System;
using System.Collections;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using Umbraco.Core.Persistence.Querying;
using Umbraco.Core.Persistence.SqlSyntax;

namespace Umbraco.Core.Persistence
{
    /// <summary>
    /// Extension methods adding strong types to PetaPoco's Sql Builder.
    /// </summary>
    public static class PetaPocoSqlExtensions
    {
        /// <summary>
        /// Defines the columns to select in the generated SQL query.
        /// </summary>
        /// <typeparam name="T">The POCO type the selected columns are read from.</typeparam>
        /// <param name="sql">Sql object</param>
        /// <param name="sqlSyntax">Sql syntax used to quote table and column names</param>
        /// <param name="fields">
        /// Columns to select. When no selector is given, <c>quotedTable.*</c> for
        /// <typeparamref name="T"/>'s table is selected instead.
        /// </param>
        /// <returns>The Sql object returned by <c>Sql.Select</c>.</returns>
        /// <exception cref="ArgumentException">A selector does not resolve to a property.</exception>
        public static Sql Select<T>(this Sql sql, ISqlSyntaxProvider sqlSyntax, params Expression<Func<T, object>>[] fields)
        {
            return sql.Select(GetFieldNames(sqlSyntax, fields));
        }

        /// <summary>
        /// Adds another set of fields to select. This method must be used with "Select" when fetching
        /// fields from different tables.
        /// </summary>
        /// <typeparam name="T">The POCO type the additional columns are read from.</typeparam>
        /// <param name="sql">Sql object</param>
        /// <param name="sqlSyntax">Sql syntax used to quote table and column names</param>
        /// <param name="fields">
        /// Additional columns to select. When no selector is given, <c>quotedTable.*</c> for
        /// <typeparamref name="T"/>'s table is selected instead.
        /// </param>
        /// <returns>The Sql object returned by <c>Sql.AndSelect</c>.</returns>
        /// <exception cref="ArgumentException">A selector does not resolve to a property.</exception>
        public static Sql AndSelect<T>(this Sql sql, ISqlSyntaxProvider sqlSyntax, params Expression<Func<T, object>>[] fields)
        {
            return sql.AndSelect(GetFieldNames(sqlSyntax, fields));
        }

        /// <summary>
        /// Appends a FROM clause for <typeparamref name="T"/>'s table using the ambient
        /// <c>SqlSyntaxContext.SqlSyntaxProvider</c>.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql From<T>(this Sql sql)
        {
            return From<T>(sql, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Appends a FROM clause for <typeparamref name="T"/>'s table.
        /// </summary>
        /// <typeparam name="T">The POCO type whose table is selected from.</typeparam>
        /// <param name="sql">Sql object</param>
        /// <param name="sqlSyntax">Sql syntax used to quote the table name</param>
        /// <returns>The Sql object returned by <c>Sql.From</c>.</returns>
        public static Sql From<T>(this Sql sql, ISqlSyntaxProvider sqlSyntax)
        {
            return sql.From(sqlSyntax.GetQuotedTableName(typeof(T).GetTableName()));
        }

        /// <summary>
        /// Appends a WHERE clause translated from <paramref name="predicate"/> by a
        /// <c>PocoToSqlExpressionVisitor</c> constructed without an explicit syntax provider.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql Where<T>(this Sql sql, Expression<Func<T, bool>> predicate)
        {
            var visitor = new PocoToSqlExpressionVisitor<T>();
            var whereExpression = visitor.Visit(predicate);
            return sql.Where(whereExpression, visitor.GetSqlParameters());
        }

        /// <summary>
        /// Appends a WHERE clause translated from a strongly-typed predicate.
        /// </summary>
        /// <typeparam name="T">The POCO type the predicate is written against.</typeparam>
        /// <param name="sql">Sql object</param>
        /// <param name="predicate">Predicate to translate into SQL</param>
        /// <param name="sqlSyntax">Sql syntax handed to the expression visitor</param>
        /// <returns>The Sql object returned by <c>Sql.Where</c>.</returns>
        public static Sql Where<T>(this Sql sql, Expression<Func<T, bool>> predicate, ISqlSyntaxProvider sqlSyntax)
        {
            var visitor = new PocoToSqlExpressionVisitor<T>(sqlSyntax);
            var whereExpression = visitor.Visit(predicate);
            return sql.Where(whereExpression, visitor.GetSqlParameters());
        }

        /// <summary>
        /// Resolves the property picked by <paramref name="selector"/>.
        /// </summary>
        /// <exception cref="ArgumentException">
        /// The expression does not resolve to a <see cref="PropertyInfo"/> (for example it targets a field
        /// or a computed value); without this guard the caller would fail later with a NullReferenceException.
        /// </exception>
        private static PropertyInfo GetSelectedProperty<T>(Expression<Func<T, object>> selector)
        {
            var property = ExpressionHelper.FindProperty(selector) as PropertyInfo;
            if (property == null)
            {
                throw new ArgumentException("The expression '" + selector + "' does not select a property.");
            }
            return property;
        }

        /// <summary>
        /// Builds the fully-qualified, quoted <c>table.column</c> name for the selected property.
        /// </summary>
        private static string GetFieldName<T>(Expression<Func<T, object>> fieldSelector, ISqlSyntaxProvider sqlSyntax)
        {
            var fieldName = GetSelectedProperty(fieldSelector).GetColumnName();
            var tableName = typeof(T).GetTableName();
            return sqlSyntax.GetQuotedTableName(tableName) + "." + sqlSyntax.GetQuotedColumnName(fieldName);
        }

        /// <summary>
        /// Maps each selector to its quoted <c>table.column</c> name, or to <c>table.*</c> when
        /// no selector is given.
        /// </summary>
        private static string[] GetFieldNames<T>(ISqlSyntaxProvider sqlSyntax, params Expression<Func<T, object>>[] fields)
        {
            if (fields.Length == 0)
            {
                return new[] { string.Format("{0}.*", sqlSyntax.GetQuotedTableName(typeof(T).GetTableName())) };
            }

            return fields.Select(field => GetFieldName(field, sqlSyntax)).ToArray();
        }

        /// <summary>
        /// Appends <c>WHERE field IN (@values)</c> using the ambient <c>SqlSyntaxContext.SqlSyntaxProvider</c>.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql WhereIn<T>(this Sql sql, Expression<Func<T, object>> fieldSelector, IEnumerable values)
        {
            return sql.WhereIn(fieldSelector, values, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Appends <c>WHERE field IN (@values)</c>, binding <paramref name="values"/> as the <c>values</c> parameter.
        /// </summary>
        /// <typeparam name="T">The POCO type declaring the field.</typeparam>
        /// <param name="sql">Sql object</param>
        /// <param name="fieldSelector">Selects the column to test</param>
        /// <param name="values">Values passed as the <c>@values</c> argument</param>
        /// <param name="sqlSyntax">Sql syntax used to quote the column name</param>
        /// <returns>The Sql object returned by <c>Sql.Where</c>.</returns>
        /// <exception cref="ArgumentException">The selector does not resolve to a property.</exception>
        public static Sql WhereIn<T>(this Sql sql, Expression<Func<T, object>> fieldSelector, IEnumerable values, ISqlSyntaxProvider sqlSyntax)
        {
            var fieldName = GetFieldName(fieldSelector, sqlSyntax);
            return sql.Where(fieldName + " IN (@values)", new { values });
        }

        /// <summary>
        /// Appends <c>WHERE (f1 IN (@values) OR f2 IN (@values) ...)</c> using the ambient
        /// <c>SqlSyntaxContext.SqlSyntaxProvider</c>.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql WhereAnyIn<T>(this Sql sql, Expression<Func<T, object>>[] fieldSelectors, IEnumerable values)
        {
            return sql.WhereAnyIn(fieldSelectors, values, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Appends <c>WHERE (f1 IN (@values) OR f2 IN (@values) ...)</c>: matches rows where any of the
        /// selected columns contains one of <paramref name="values"/>.
        /// </summary>
        /// <typeparam name="T">The POCO type declaring the fields.</typeparam>
        /// <param name="sql">Sql object</param>
        /// <param name="fieldSelectors">Selects the columns to test; must contain at least one selector</param>
        /// <param name="values">Values passed as the shared <c>@values</c> argument</param>
        /// <param name="sqlSyntax">Sql syntax used to quote the column names</param>
        /// <returns>The Sql object returned by <c>Sql.Where</c>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="fieldSelectors"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="fieldSelectors"/> is empty (which would otherwise produce the invalid clause
        /// <c>()</c>), or a selector does not resolve to a property.
        /// </exception>
        public static Sql WhereAnyIn<T>(this Sql sql, Expression<Func<T, object>>[] fieldSelectors, IEnumerable values, ISqlSyntaxProvider sqlSyntax)
        {
            if (fieldSelectors == null)
            {
                throw new ArgumentNullException("fieldSelectors");
            }
            if (fieldSelectors.Length == 0)
            {
                throw new ArgumentException("At least one field selector is required.", "fieldSelectors");
            }

            // every column is tested against the same @values argument; the tests are OR-ed together
            var conditions = fieldSelectors.Select(x => GetFieldName(x, sqlSyntax) + " IN (@values)");
            return sql.Where("(" + string.Join(" OR ", conditions) + ")", new { values });
        }

        /// <summary>
        /// Appends an ascending ORDER BY on the selected column using the ambient
        /// <c>SqlSyntaxContext.SqlSyntaxProvider</c>.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql OrderBy<TColumn>(this Sql sql, Expression<Func<TColumn, object>> columnMember)
        {
            return OrderBy(sql, columnMember, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Appends an ascending ORDER BY on the selected column.
        /// </summary>
        /// <typeparam name="TColumn">The POCO type declaring the column.</typeparam>
        /// <param name="sql">Sql object</param>
        /// <param name="columnMember">Selects the column to order by</param>
        /// <param name="sqlSyntax">Sql syntax used to quote table and column names</param>
        /// <returns>The Sql object returned by <c>Sql.OrderBy</c>.</returns>
        /// <exception cref="ArgumentException">The selector does not resolve to a property.</exception>
        public static Sql OrderBy<TColumn>(this Sql sql, Expression<Func<TColumn, object>> columnMember, ISqlSyntaxProvider sqlSyntax)
        {
            //need to ensure the order by is in brackets, see: https://github.com/toptensoftware/PetaPoco/issues/177
            var syntax = "(" + GetFieldName(columnMember, sqlSyntax) + ")";
            return sql.OrderBy(syntax);
        }

        /// <summary>
        /// Appends a descending ORDER BY on the selected column using the ambient
        /// <c>SqlSyntaxContext.SqlSyntaxProvider</c>.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql OrderByDescending<TColumn>(this Sql sql, Expression<Func<TColumn, object>> columnMember)
        {
            return OrderByDescending(sql, columnMember, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Appends a descending ORDER BY on the selected column.
        /// </summary>
        /// <typeparam name="TColumn">The POCO type declaring the column.</typeparam>
        /// <param name="sql">Sql object</param>
        /// <param name="columnMember">Selects the column to order by</param>
        /// <param name="sqlSyntax">Sql syntax used to quote table and column names</param>
        /// <returns>The Sql object returned by <c>Sql.OrderBy</c>.</returns>
        /// <exception cref="ArgumentException">The selector does not resolve to a property.</exception>
        public static Sql OrderByDescending<TColumn>(this Sql sql, Expression<Func<TColumn, object>> columnMember, ISqlSyntaxProvider sqlSyntax)
        {
            //need to ensure the order by is in brackets, see: https://github.com/toptensoftware/PetaPoco/issues/177
            var syntax = "(" + GetFieldName(columnMember, sqlSyntax) + ") DESC";
            return sql.OrderBy(syntax);
        }

        /// <summary>
        /// Appends a GROUP BY on the selected column using the ambient
        /// <c>SqlSyntaxContext.SqlSyntaxProvider</c>.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql GroupBy<TColumn>(this Sql sql, Expression<Func<TColumn, object>> columnMember)
        {
            return GroupBy(sql, columnMember, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Appends a GROUP BY on the selected column. Only the quoted column name is emitted
        /// (it is not prefixed with the table name).
        /// </summary>
        /// <typeparam name="TColumn">The POCO type declaring the column.</typeparam>
        /// <param name="sql">Sql object</param>
        /// <param name="columnMember">Selects the column to group by</param>
        /// <param name="sqlProvider">Sql syntax used to quote the column name</param>
        /// <returns>The Sql object returned by <c>Sql.GroupBy</c>.</returns>
        /// <exception cref="ArgumentException">The selector does not resolve to a property.</exception>
        public static Sql GroupBy<TColumn>(this Sql sql, Expression<Func<TColumn, object>> columnMember, ISqlSyntaxProvider sqlProvider)
        {
            var columnName = GetSelectedProperty(columnMember).GetColumnName();
            return sql.GroupBy(sqlProvider.GetQuotedColumnName(columnName));
        }

        /// <summary>
        /// Starts an INNER JOIN on <typeparamref name="T"/>'s table using the ambient
        /// <c>SqlSyntaxContext.SqlSyntaxProvider</c>.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql.SqlJoinClause InnerJoin<T>(this Sql sql)
        {
            return InnerJoin<T>(sql, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Starts an INNER JOIN on <typeparamref name="T"/>'s table; complete it with <c>On</c>.
        /// </summary>
        public static Sql.SqlJoinClause InnerJoin<T>(this Sql sql, ISqlSyntaxProvider sqlSyntax)
        {
            return sql.InnerJoin(sqlSyntax.GetQuotedTableName(typeof(T).GetTableName()));
        }

        /// <summary>
        /// Starts a LEFT JOIN on <typeparamref name="T"/>'s table using the ambient
        /// <c>SqlSyntaxContext.SqlSyntaxProvider</c>.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql.SqlJoinClause LeftJoin<T>(this Sql sql)
        {
            return LeftJoin<T>(sql, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Starts a LEFT JOIN on <typeparamref name="T"/>'s table; complete it with <c>On</c>.
        /// </summary>
        public static Sql.SqlJoinClause LeftJoin<T>(this Sql sql, ISqlSyntaxProvider sqlSyntax)
        {
            return sql.LeftJoin(sqlSyntax.GetQuotedTableName(typeof(T).GetTableName()));
        }

        /// <summary>
        /// Starts a LEFT OUTER JOIN on <typeparamref name="T"/>'s table using the ambient
        /// <c>SqlSyntaxContext.SqlSyntaxProvider</c>.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql.SqlJoinClause LeftOuterJoin<T>(this Sql sql)
        {
            return LeftOuterJoin<T>(sql, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Starts a LEFT OUTER JOIN on <typeparamref name="T"/>'s table; complete it with <c>On</c>.
        /// </summary>
        public static Sql.SqlJoinClause LeftOuterJoin<T>(this Sql sql, ISqlSyntaxProvider sqlSyntax)
        {
            return sql.LeftOuterJoin(sqlSyntax.GetQuotedTableName(typeof(T).GetTableName()));
        }

        /// <summary>
        /// Starts a RIGHT JOIN on <typeparamref name="T"/>'s table using the ambient
        /// <c>SqlSyntaxContext.SqlSyntaxProvider</c>.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql.SqlJoinClause RightJoin<T>(this Sql sql)
        {
            return RightJoin<T>(sql, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Starts a RIGHT JOIN on <typeparamref name="T"/>'s table; complete it with <c>On</c>.
        /// </summary>
        public static Sql.SqlJoinClause RightJoin<T>(this Sql sql, ISqlSyntaxProvider sqlSyntax)
        {
            return sql.RightJoin(sqlSyntax.GetQuotedTableName(typeof(T).GetTableName()));
        }

        /// <summary>
        /// Completes a join with <c>ON left.col = right.col</c> using the ambient
        /// <c>SqlSyntaxContext.SqlSyntaxProvider</c>.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql On<TLeft, TRight>(this Sql.SqlJoinClause sql, Expression<Func<TLeft, object>> leftMember, Expression<Func<TRight, object>> rightMember, params object[] args)
        {
            return On(sql, SqlSyntaxContext.SqlSyntaxProvider, leftMember, rightMember, args);
        }

        /// <summary>
        /// Completes a join with <c>ON leftTable.leftColumn = rightTable.rightColumn</c>.
        /// </summary>
        /// <typeparam name="TLeft">The POCO type declaring the left column.</typeparam>
        /// <typeparam name="TRight">The POCO type declaring the right column.</typeparam>
        /// <param name="sql">The pending join clause</param>
        /// <param name="sqlSyntax">Sql syntax used to quote table and column names</param>
        /// <param name="leftMember">Selects the left-hand column</param>
        /// <param name="rightMember">Selects the right-hand column</param>
        /// <param name="args">Arguments forwarded to <c>SqlJoinClause.On</c></param>
        /// <returns>The Sql object returned by <c>SqlJoinClause.On</c>.</returns>
        /// <exception cref="ArgumentException">A selector does not resolve to a property.</exception>
        public static Sql On<TLeft, TRight>(this Sql.SqlJoinClause sql, ISqlSyntaxProvider sqlSyntax, Expression<Func<TLeft, object>> leftMember, Expression<Func<TRight, object>> rightMember, params object[] args)
        {
            var onClause = GetFieldName(leftMember, sqlSyntax) + " = " + GetFieldName(rightMember, sqlSyntax);
            // args are part of the public signature, so pass them through instead of discarding them
            return sql.On(onClause, args);
        }

        /// <summary>
        /// Appends <c>ORDER BY c1 DESC, c2 DESC, ...</c> for raw (already formatted) column expressions.
        /// </summary>
        /// <param name="sql">Sql object</param>
        /// <param name="columns">Column expressions; each is converted to text and suffixed with <c> DESC</c></param>
        /// <returns>The Sql object returned by <c>Sql.Append</c>.</returns>
        public static Sql OrderByDescending(this Sql sql, params object[] columns)
        {
            var descending = columns.Select(x => x + " DESC").ToArray();
            return sql.Append(new Sql("ORDER BY " + string.Join(", ", descending)));
        }

        /// <summary>
        /// Reads the table name from the type's <c>TableNameAttribute</c>; returns
        /// <see cref="string.Empty"/> when the attribute is missing or blank.
        /// </summary>
        private static string GetTableName(this Type type)
        {
            // todo: returning string.Empty for now
            // BUT the code bits that calls this method cannot deal with string.Empty so we
            // should either throw, or fix these code bits...
            var attr = type.FirstAttribute<TableNameAttribute>();
            return attr == null || string.IsNullOrWhiteSpace(attr.Value) ? string.Empty : attr.Value;
        }

        /// <summary>
        /// Reads the column name from the property's <c>ColumnAttribute</c>, falling back to the
        /// property name when the attribute is missing or its name is blank.
        /// </summary>
        private static string GetColumnName(this PropertyInfo column)
        {
            var attr = column.FirstAttribute<ColumnAttribute>();
            return attr == null || string.IsNullOrWhiteSpace(attr.Name) ? column.Name : attr.Name;
        }
    }
}