using System;
using System.Collections;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using Umbraco.Core.Persistence.Querying;
using Umbraco.Core.Persistence.SqlSyntax;
namespace Umbraco.Core.Persistence
{
    /// <summary>
    /// Extension methods adding strong types to PetaPoco's Sql builder.
    /// </summary>
    /// <remarks>
    /// Table names come from the <c>TableNameAttribute</c> on the DTO type and column names from the
    /// <c>ColumnAttribute</c> on its properties (falling back to the property name); both are quoted
    /// through the supplied <see cref="ISqlSyntaxProvider"/>.
    /// </remarks>
    public static class PetaPocoSqlExtensions
    {
        /// <summary>
        /// Defines the columns to select in the generated SQL query.
        /// </summary>
        /// <typeparam name="T">The DTO type whose table the columns belong to.</typeparam>
        /// <param name="sql">Sql object.</param>
        /// <param name="sqlSyntax">Sql syntax used to quote table and column names.</param>
        /// <param name="fields">Columns to select; when none are given, <c>table.*</c> is selected.</param>
        /// <returns>The resulting Sql object.</returns>
        public static Sql Select<T>(this Sql sql, ISqlSyntaxProvider sqlSyntax, params Expression<Func<T, object>>[] fields)
        {
            return sql.Select(GetFieldNames(sqlSyntax, fields));
        }

        /// <summary>
        /// Adds another set of fields to select. This method must be used together with "Select" when
        /// fetching fields from different tables.
        /// </summary>
        /// <typeparam name="T">The DTO type whose table the columns belong to.</typeparam>
        /// <param name="sql">Sql object.</param>
        /// <param name="sqlSyntax">Sql syntax used to quote table and column names.</param>
        /// <param name="fields">Additional columns to select; when none are given, <c>table.*</c> is selected.</param>
        /// <returns>The resulting Sql object.</returns>
        public static Sql AndSelect<T>(this Sql sql, ISqlSyntaxProvider sqlSyntax, params Expression<Func<T, object>>[] fields)
        {
            return sql.AndSelect(GetFieldNames(sqlSyntax, fields));
        }

        /// <summary>
        /// Appends a FROM clause for the table mapped by <typeparamref name="T"/>, using the ambient syntax provider.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql From<T>(this Sql sql)
        {
            return From<T>(sql, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Appends a FROM clause for the table mapped by <typeparamref name="T"/>.
        /// </summary>
        /// <param name="sql">Sql object.</param>
        /// <param name="sqlSyntax">Sql syntax used to quote the table name.</param>
        /// <returns>The resulting Sql object.</returns>
        public static Sql From<T>(this Sql sql, ISqlSyntaxProvider sqlSyntax)
        {
            return sql.From(QuotedTableNameOf<T>(sqlSyntax));
        }

        /// <summary>
        /// Appends a WHERE clause translated from <paramref name="predicate"/>, using the
        /// expression visitor's default syntax provider.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql Where<T>(this Sql sql, Expression<Func<T, bool>> predicate)
        {
            var visitor = new PocoToSqlExpressionVisitor<T>();
            var whereClause = visitor.Visit(predicate);
            return sql.Where(whereClause, visitor.GetSqlParameters());
        }

        /// <summary>
        /// Appends a WHERE clause translated from <paramref name="predicate"/>.
        /// </summary>
        /// <param name="sql">Sql object.</param>
        /// <param name="predicate">Predicate over the DTO, translated to SQL by the expression visitor.</param>
        /// <param name="sqlSyntax">Sql syntax handed to the expression visitor.</param>
        /// <returns>The resulting Sql object.</returns>
        public static Sql Where<T>(this Sql sql, Expression<Func<T, bool>> predicate, ISqlSyntaxProvider sqlSyntax)
        {
            var visitor = new PocoToSqlExpressionVisitor<T>(sqlSyntax);
            var whereClause = visitor.Visit(predicate);
            return sql.Where(whereClause, visitor.GetSqlParameters());
        }

        /// <summary>
        /// Returns the quoted table name of <typeparamref name="T"/>.
        /// </summary>
        private static string QuotedTableNameOf<T>(ISqlSyntaxProvider sqlSyntax)
        {
            return sqlSyntax.GetQuotedTableName(typeof(T).GetTableName());
        }

        /// <summary>
        /// Resolves the property selected by <paramref name="selector"/>.
        /// </summary>
        /// <param name="selector">Lambda selecting a property of <typeparamref name="T"/>.</param>
        /// <param name="paramName">Name of the caller's parameter, reported in the exception.</param>
        /// <exception cref="ArgumentException">The expression does not resolve to a property.</exception>
        private static PropertyInfo GetSelectedProperty<T>(Expression<Func<T, object>> selector, string paramName)
        {
            // FindProperty's result is not guaranteed to be a PropertyInfo (hence the "as" cast); without
            // this check a non-property selector would fail later with an unhelpful null dereference.
            var property = ExpressionHelper.FindProperty(selector) as PropertyInfo;
            if (property == null)
            {
                throw new ArgumentException(
                    string.Format("The expression must select a property of {0}.", typeof(T).FullName),
                    paramName);
            }
            return property;
        }

        /// <summary>
        /// Returns the fully qualified, quoted column name (<c>table.column</c>) selected by <paramref name="fieldSelector"/>.
        /// </summary>
        private static string GetFieldName<T>(Expression<Func<T, object>> fieldSelector, ISqlSyntaxProvider sqlSyntax, string paramName)
        {
            var columnName = GetSelectedProperty(fieldSelector, paramName).GetColumnName();
            return QuotedTableNameOf<T>(sqlSyntax) + "." + sqlSyntax.GetQuotedColumnName(columnName);
        }

        /// <summary>
        /// Returns the qualified column names for <paramref name="fields"/>, or <c>table.*</c> when none are given.
        /// </summary>
        private static string[] GetFieldNames<T>(ISqlSyntaxProvider sqlSyntax, params Expression<Func<T, object>>[] fields)
        {
            // A null params array (explicit "null" argument) is treated like "no fields".
            if (fields == null || fields.Length == 0)
            {
                return new[] { string.Format("{0}.*", QuotedTableNameOf<T>(sqlSyntax)) };
            }
            return fields.Select(field => GetFieldName(field, sqlSyntax, "fields")).ToArray();
        }

        /// <summary>
        /// Appends <c>WHERE column IN (@values)</c>, using the ambient syntax provider.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql WhereIn<T>(this Sql sql, Expression<Func<T, object>> fieldSelector, IEnumerable values)
        {
            return sql.WhereIn(fieldSelector, values, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Appends <c>WHERE table.column IN (@values)</c> for the selected column.
        /// </summary>
        /// <param name="sql">Sql object.</param>
        /// <param name="fieldSelector">Column to test.</param>
        /// <param name="values">Values bound to the <c>@values</c> parameter.</param>
        /// <param name="sqlSyntax">Sql syntax used to quote table and column names.</param>
        /// <returns>The resulting Sql object.</returns>
        public static Sql WhereIn<T>(this Sql sql, Expression<Func<T, object>> fieldSelector, IEnumerable values, ISqlSyntaxProvider sqlSyntax)
        {
            var fieldName = GetFieldName(fieldSelector, sqlSyntax, "fieldSelector");
            return sql.Where(fieldName + " IN (@values)", new { values });
        }

        /// <summary>
        /// Appends a WHERE clause matching rows where any of the selected columns is IN <paramref name="values"/>,
        /// using the ambient syntax provider.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql WhereAnyIn<T>(this Sql sql, Expression<Func<T, object>>[] fieldSelectors, IEnumerable values)
        {
            return sql.WhereAnyIn(fieldSelectors, values, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Appends <c>WHERE (c1 IN (@values) OR c2 IN (@values) ...)</c> for the selected columns.
        /// </summary>
        /// <param name="sql">Sql object.</param>
        /// <param name="fieldSelectors">Columns to test; at least one is required.</param>
        /// <param name="values">Values bound to the <c>@values</c> parameter (shared by every column test).</param>
        /// <param name="sqlSyntax">Sql syntax used to quote table and column names.</param>
        /// <returns>The resulting Sql object.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="fieldSelectors"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="fieldSelectors"/> is empty (would produce an empty "()" condition).</exception>
        public static Sql WhereAnyIn<T>(this Sql sql, Expression<Func<T, object>>[] fieldSelectors, IEnumerable values, ISqlSyntaxProvider sqlSyntax)
        {
            if (fieldSelectors == null)
            {
                throw new ArgumentNullException("fieldSelectors");
            }
            if (fieldSelectors.Length == 0)
            {
                throw new ArgumentException("At least one field selector is required.", "fieldSelectors");
            }

            var conditions = fieldSelectors
                .Select(selector => GetFieldName(selector, sqlSyntax, "fieldSelectors") + " IN (@values)");
            return sql.Where("(" + string.Join(" OR ", conditions) + ")", new { values });
        }

        /// <summary>
        /// Appends an ascending ORDER BY on the selected column, using the ambient syntax provider.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql OrderBy<TColumn>(this Sql sql, Expression<Func<TColumn, object>> columnMember)
        {
            return OrderBy(sql, columnMember, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Appends an ascending ORDER BY on the selected, table-qualified column.
        /// </summary>
        /// <param name="sql">Sql object.</param>
        /// <param name="columnMember">Column to order by.</param>
        /// <param name="sqlSyntax">Sql syntax used to quote table and column names.</param>
        /// <returns>The resulting Sql object.</returns>
        public static Sql OrderBy<TColumn>(this Sql sql, Expression<Func<TColumn, object>> columnMember, ISqlSyntaxProvider sqlSyntax)
        {
            //need to ensure the order by is in brackets, see: https://github.com/toptensoftware/PetaPoco/issues/177
            var syntax = "(" + GetFieldName(columnMember, sqlSyntax, "columnMember") + ")";
            return sql.OrderBy(syntax);
        }

        /// <summary>
        /// Appends a descending ORDER BY on the selected column, using the ambient syntax provider.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql OrderByDescending<TColumn>(this Sql sql, Expression<Func<TColumn, object>> columnMember)
        {
            return OrderByDescending(sql, columnMember, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Appends a descending ORDER BY on the selected, table-qualified column.
        /// </summary>
        /// <param name="sql">Sql object.</param>
        /// <param name="columnMember">Column to order by.</param>
        /// <param name="sqlSyntax">Sql syntax used to quote table and column names.</param>
        /// <returns>The resulting Sql object.</returns>
        public static Sql OrderByDescending<TColumn>(this Sql sql, Expression<Func<TColumn, object>> columnMember, ISqlSyntaxProvider sqlSyntax)
        {
            //need to ensure the order by is in brackets, see: https://github.com/toptensoftware/PetaPoco/issues/177
            var syntax = "(" + GetFieldName(columnMember, sqlSyntax, "columnMember") + ") DESC";
            return sql.OrderBy(syntax);
        }

        /// <summary>
        /// Appends a GROUP BY on the selected column, using the ambient syntax provider.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql GroupBy<TColumn>(this Sql sql, Expression<Func<TColumn, object>> columnMember)
        {
            return GroupBy(sql, columnMember, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Appends a GROUP BY on the selected column.
        /// </summary>
        /// <remarks>
        /// Unlike OrderBy/OrderByDescending, only the quoted column name is emitted; it is not
        /// qualified with the table name.
        /// </remarks>
        /// <param name="sql">Sql object.</param>
        /// <param name="columnMember">Column to group by.</param>
        /// <param name="sqlProvider">Sql syntax used to quote the column name.</param>
        /// <returns>The resulting Sql object.</returns>
        public static Sql GroupBy<TColumn>(this Sql sql, Expression<Func<TColumn, object>> columnMember, ISqlSyntaxProvider sqlProvider)
        {
            var columnName = GetSelectedProperty(columnMember, "columnMember").GetColumnName();
            return sql.GroupBy(sqlProvider.GetQuotedColumnName(columnName));
        }

        /// <summary>
        /// Starts an INNER JOIN on the table mapped by <typeparamref name="T"/>, using the ambient syntax provider.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql.SqlJoinClause InnerJoin<T>(this Sql sql)
        {
            return InnerJoin<T>(sql, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Starts an INNER JOIN on the table mapped by <typeparamref name="T"/>; complete it with On.
        /// </summary>
        public static Sql.SqlJoinClause InnerJoin<T>(this Sql sql, ISqlSyntaxProvider sqlSyntax)
        {
            return sql.InnerJoin(QuotedTableNameOf<T>(sqlSyntax));
        }

        /// <summary>
        /// Starts a LEFT JOIN on the table mapped by <typeparamref name="T"/>, using the ambient syntax provider.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql.SqlJoinClause LeftJoin<T>(this Sql sql)
        {
            return LeftJoin<T>(sql, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Starts a LEFT JOIN on the table mapped by <typeparamref name="T"/>; complete it with On.
        /// </summary>
        public static Sql.SqlJoinClause LeftJoin<T>(this Sql sql, ISqlSyntaxProvider sqlSyntax)
        {
            return sql.LeftJoin(QuotedTableNameOf<T>(sqlSyntax));
        }

        /// <summary>
        /// Starts a LEFT OUTER JOIN on the table mapped by <typeparamref name="T"/>, using the ambient syntax provider.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql.SqlJoinClause LeftOuterJoin<T>(this Sql sql)
        {
            return LeftOuterJoin<T>(sql, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Starts a LEFT OUTER JOIN on the table mapped by <typeparamref name="T"/>; complete it with On.
        /// </summary>
        public static Sql.SqlJoinClause LeftOuterJoin<T>(this Sql sql, ISqlSyntaxProvider sqlSyntax)
        {
            return sql.LeftOuterJoin(QuotedTableNameOf<T>(sqlSyntax));
        }

        /// <summary>
        /// Starts a RIGHT JOIN on the table mapped by <typeparamref name="T"/>, using the ambient syntax provider.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql.SqlJoinClause RightJoin<T>(this Sql sql)
        {
            return RightJoin<T>(sql, SqlSyntaxContext.SqlSyntaxProvider);
        }

        /// <summary>
        /// Starts a RIGHT JOIN on the table mapped by <typeparamref name="T"/>; complete it with On.
        /// </summary>
        public static Sql.SqlJoinClause RightJoin<T>(this Sql sql, ISqlSyntaxProvider sqlSyntax)
        {
            return sql.RightJoin(QuotedTableNameOf<T>(sqlSyntax));
        }

        /// <summary>
        /// Completes a join with <c>left.column = right.column</c>, using the ambient syntax provider.
        /// </summary>
        [Obsolete("Use the overload specifying ISqlSyntaxProvider instead")]
        public static Sql On<TLeft, TRight>(this Sql.SqlJoinClause sql, Expression<Func<TLeft, object>> leftMember,
            Expression<Func<TRight, object>> rightMember, params object[] args)
        {
            return On(sql, SqlSyntaxContext.SqlSyntaxProvider, leftMember, rightMember, args);
        }

        /// <summary>
        /// Completes a join with <c>leftTable.leftColumn = rightTable.rightColumn</c>.
        /// </summary>
        /// <param name="sql">Join clause started by InnerJoin/LeftJoin/LeftOuterJoin/RightJoin.</param>
        /// <param name="sqlSyntax">Sql syntax used to quote table and column names.</param>
        /// <param name="leftMember">Column on the left-hand table.</param>
        /// <param name="rightMember">Column on the right-hand table.</param>
        /// <param name="args">Currently ignored: the generated ON clause contains no parameter placeholders.</param>
        /// <returns>The resulting Sql object.</returns>
        public static Sql On<TLeft, TRight>(this Sql.SqlJoinClause sql, ISqlSyntaxProvider sqlSyntax, Expression<Func<TLeft, object>> leftMember,
            Expression<Func<TRight, object>> rightMember, params object[] args)
        {
            var onClause = GetFieldName(leftMember, sqlSyntax, "leftMember")
                + " = "
                + GetFieldName(rightMember, sqlSyntax, "rightMember");
            return sql.On(onClause);
        }

        /// <summary>
        /// Appends <c>ORDER BY c1 DESC, c2 DESC, ...</c> using the given raw column expressions.
        /// </summary>
        /// <param name="sql">Sql object.</param>
        /// <param name="columns">Raw column expressions (emitted via their string form, unquoted); at least one is required.</param>
        /// <returns>The resulting Sql object.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="columns"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="columns"/> is empty (would produce a bare "ORDER BY").</exception>
        public static Sql OrderByDescending(this Sql sql, params object[] columns)
        {
            if (columns == null)
            {
                throw new ArgumentNullException("columns");
            }
            if (columns.Length == 0)
            {
                throw new ArgumentException("At least one column is required.", "columns");
            }

            var descendingColumns = columns.Select(column => column + " DESC").ToArray();
            return sql.Append(new Sql("ORDER BY " + string.Join(", ", descendingColumns)));
        }

        /// <summary>
        /// Returns the table name declared by the type's TableNameAttribute, or string.Empty when there is none.
        /// </summary>
        private static string GetTableName(this Type type)
        {
            // todo: returning string.Empty for now
            // BUT the code bits that calls this method cannot deal with string.Empty so we
            // should either throw, or fix these code bits...
            var attr = type.FirstAttribute<TableNameAttribute>();
            return attr == null || string.IsNullOrWhiteSpace(attr.Value) ? string.Empty : attr.Value;
        }

        /// <summary>
        /// Returns the column name declared by the property's ColumnAttribute, falling back to the property name.
        /// </summary>
        private static string GetColumnName(this PropertyInfo column)
        {
            var attr = column.FirstAttribute<ColumnAttribute>();
            return attr == null || string.IsNullOrWhiteSpace(attr.Name) ? column.Name : attr.Name;
        }
    }
}