using System.Collections;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using System.Text.RegularExpressions;
using NPoco;
using Umbraco.Cms.Core;
using Umbraco.Cms.Infrastructure.Persistence;
using Umbraco.Cms.Infrastructure.Persistence.Querying;
using Umbraco.Cms.Infrastructure.Persistence.SqlSyntax;

namespace Umbraco.Extensions
{
    public static partial class NPocoSqlExtensions
    {
        #region Where

        /// <summary>
        /// Appends a WHERE clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="predicate">A predicate to transform and append to the Sql statement.</param>
        /// <param name="alias">An optional alias for the table.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> Where<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, bool>> predicate, string? alias = null)
        {
            var (s, a) = sql.SqlContext.VisitDto(predicate, alias);
            return sql.Where(s, a);
        }

        /// <summary>
        /// Appends a WHERE clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto1">The type of Dto 1.</typeparam>
        /// <typeparam name="TDto2">The type of Dto 2.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="predicate">A predicate to transform and append to the Sql statement.</param>
        /// <param name="alias1">An optional alias for Dto 1 table.</param>
        /// <param name="alias2">An optional alias for Dto 2 table.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> Where<TDto1, TDto2>(this Sql<ISqlContext> sql, Expression<Func<TDto1, TDto2, bool>> predicate, string? alias1 = null, string? alias2 = null)
        {
            var (s, a) = sql.SqlContext.VisitDto(predicate, alias1, alias2);
            return sql.Where(s, a);
        }

        /// <summary>
        /// Appends a WHERE IN clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="field">An expression specifying the field.</param>
        /// <param name="values">The values, passed as the named parameter <c>@values</c>.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> WhereIn<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>> field, IEnumerable? values)
        {
            var fieldName = sql.SqlContext.SqlSyntax.GetFieldName(field);
            sql.Where(fieldName + " IN (@values)", new { values });
            return sql;
        }

        /// <summary>
        /// Appends a WHERE IN clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="field">An expression specifying the field.</param>
        /// <param name="values">The values, passed as the named parameter <c>@values</c>.</param>
        /// <param name="alias">The table alias used to qualify the field.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> WhereIn<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>> field, IEnumerable? values, string alias)
        {
            var fieldName = sql.SqlContext.SqlSyntax.GetFieldName(field, alias);
            sql.Where(fieldName + " IN (@values)", new { values });
            return sql;
        }

        /// <summary>
        /// Appends a WHERE IN clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="field">An expression specifying the field.</param>
        /// <param name="values">A subquery returning the value.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> WhereIn<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>> field, Sql<ISqlContext>? values)
        {
            return WhereIn(sql, field, values, false, null);
        }

        /// <summary>
        /// Appends a WHERE IN clause to the Sql statement, qualifying the field with a table alias.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="field">An expression specifying the field.</param>
        /// <param name="values">A subquery returning the value.</param>
        /// <param name="tableAlias">The table alias used to qualify the field.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> WhereIn<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>> field, Sql<ISqlContext>? values, string tableAlias)
        {
            return sql.WhereIn(field, values, false, tableAlias);
        }

        /// <summary>
        /// Appends a WHERE ... LIKE (subquery) clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="fieldSelector">An expression specifying the field.</param>
        /// <param name="valuesSql">A Sql fragment whose text and arguments are embedded in the LIKE clause.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> WhereLike<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>> fieldSelector, Sql<ISqlContext>? valuesSql)
        {
            var fieldName = sql.SqlContext.SqlSyntax.GetFieldName(fieldSelector);
            sql.Where(fieldName + " LIKE (" + valuesSql?.SQL + ")", valuesSql?.Arguments);
            return sql;
        }

        /// <summary>
        /// Appends " UNION " followed by another Sql statement.
        /// </summary>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="sql2">The statement to union with.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> Union(this Sql<ISqlContext> sql, Sql<ISqlContext> sql2)
        {
            return sql.Append(" UNION ").Append(sql2);
        }

        /// <summary>
        /// Appends an INNER JOIN on a nested query, aliased as <c>[alias]</c>.
        /// </summary>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="nestedQuery">The nested query; its arguments are appended along with it.</param>
        /// <param name="alias">The alias for the nested query (written in square brackets, not via the syntax provider).</param>
        /// <returns>A SqlJoin statement.</returns>
        public static Sql<ISqlContext>.SqlJoinClause<ISqlContext> InnerJoinNested(this Sql<ISqlContext> sql, Sql<ISqlContext> nestedQuery, string alias)
        {
            return new Sql<ISqlContext>.SqlJoinClause<ISqlContext>(sql.Append("INNER JOIN (").Append(nestedQuery)
                .Append($") [{alias}]"));
        }

        /// <summary>
        /// Appends a WHERE ... LIKE ('value') clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="fieldSelector">An expression specifying the field.</param>
        /// <param name="likeValue">The LIKE pattern.</param>
        /// <returns>The Sql statement.</returns>
        /// <remarks>
        /// <paramref name="likeValue"/> is embedded verbatim between single quotes and is NOT parameterized
        /// or escaped; it must never come from untrusted input.
        /// </remarks>
        public static Sql<ISqlContext> WhereLike<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>> fieldSelector, string likeValue)
        {
            var fieldName = sql.SqlContext.SqlSyntax.GetFieldName(fieldSelector);
            sql.Where(fieldName + " LIKE ('" + likeValue + "')");
            return sql;
        }

        /// <summary>
        /// Appends a WHERE NOT IN clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="field">An expression specifying the field.</param>
        /// <param name="values">The values, passed as the named parameter <c>@values</c>.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> WhereNotIn<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>> field, IEnumerable values)
        {
            var fieldName = sql.SqlContext.SqlSyntax.GetFieldName(field);
            sql.Where(fieldName + " NOT IN (@values)", new { values });
            return sql;
        }

        /// <summary>
        /// Appends a WHERE NOT IN clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="field">An expression specifying the field.</param>
        /// <param name="values">A subquery returning the value.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> WhereNotIn<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>> field, Sql<ISqlContext> values)
        {
            return sql.WhereIn(field, values, true);
        }

        /// <summary>
        /// Appends a single WHERE clause of the form <c>(f1 IN (@values) OR f2 IN (@values) ...)</c>.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="fields">Expressions specifying the fields.</param>
        /// <param name="values">The values, shared by every IN predicate as <c>@values</c>.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> WhereAnyIn<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>>[] fields, IEnumerable values)
        {
            ISqlSyntaxProvider sqlSyntax = sql.SqlContext.SqlSyntax;
            var fieldNames = fields.Select(x => sqlSyntax.GetFieldName(x)).ToArray();
            var sb = new StringBuilder();
            sb.Append("(");
            for (var i = 0; i < fieldNames.Length; i++)
            {
                if (i > 0)
                {
                    sb.Append(" OR ");
                }

                sb.Append(fieldNames[i]);

                // The IN predicate belongs to the clause being built, not to the outer statement.
                sb.Append(" IN (@values)");
            }

            sb.Append(")");
            sql.Where(sb.ToString(), new { values });
            return sql;
        }

        // Appends "field [NOT] IN (subquery)" with no table alias.
        private static Sql<ISqlContext> WhereIn<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>> fieldSelector, Sql<ISqlContext>? valuesSql, bool not)
        {
            return WhereIn(sql, fieldSelector, valuesSql, not, null);
        }

        // Appends "field [NOT] IN (subquery)", forwarding the subquery's arguments.
        private static Sql<ISqlContext> WhereIn<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>> fieldSelector, Sql<ISqlContext>? valuesSql, bool not, string? tableAlias)
        {
            var fieldName = sql.SqlContext.SqlSyntax.GetFieldName(fieldSelector, tableAlias);
            sql.Where(fieldName + (not ? " NOT" : "") + " IN (" + valuesSql?.SQL + ")", valuesSql?.Arguments);
            return sql;
        }

        /// <summary>
        /// Appends multiple OR WHERE clauses to the Sql statement.
        /// </summary>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="predicates">The WHERE predicates; each builds its own fragment, and a leading "WHERE " is stripped.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> WhereAny(this Sql<ISqlContext> sql, params Func<Sql<ISqlContext>, Sql<ISqlContext>>[] predicates)
        {
            var wsql = new Sql<ISqlContext>(sql.SqlContext);

            wsql.Append("(");
            for (var i = 0; i < predicates.Length; i++)
            {
                if (i > 0)
                {
                    wsql.Append(") OR (");
                }

                var temp = new Sql<ISqlContext>(sql.SqlContext);
                temp = predicates[i](temp);
                wsql.Append(temp.SQL.TrimStartExact("WHERE "), temp.Arguments);
            }

            wsql.Append(")");

            return sql.Where(wsql.SQL, wsql.Arguments);
        }

        /// <summary>
        /// Appends a WHERE NOT NULL clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="field">Expression specifying the field.</param>
        /// <param name="tableAlias">An optional alias for the table.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> WhereNotNull<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>> field, string? tableAlias = null)
        {
            return sql.WhereNull(field, tableAlias, true);
        }

        /// <summary>
        /// Appends a WHERE [NOT] NULL clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="field">Expression specifying the field.</param>
        /// <param name="tableAlias">An optional alias for the table.</param>
        /// <param name="not">A value indicating whether to NOT NULL.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> WhereNull<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>> field, string? tableAlias = null, bool not = false)
        {
            var column = sql.GetColumns(columnExpressions: new[] { field }, tableAlias: tableAlias, withAlias: false).First();
            return sql.Where("(" + column + " IS " + (not ? "NOT " : "") + "NULL)");
        }

        #endregion

        #region From

        /// <summary>
        /// Appends a FROM clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="alias">An optional table alias; ignored when null or whitespace.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> From<TDto>(this Sql<ISqlContext> sql, string? alias = null)
        {
            Type type = typeof(TDto);
            var tableName = type.GetTableName();

            var from = sql.SqlContext.SqlSyntax.GetQuotedTableName(tableName);
            if (!string.IsNullOrWhiteSpace(alias))
            {
                from += " " + sql.SqlContext.SqlSyntax.GetQuotedTableName(alias);
            }

            sql.From(from);

            return sql;
        }

        #endregion

        #region OrderBy, GroupBy

        /// <summary>
        /// Appends an ORDER BY clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="field">An expression specifying the field.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> OrderBy<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>> field)
        {
            return sql.OrderBy("(" + sql.SqlContext.SqlSyntax.GetFieldName(field) + ")");
        }

        /// <summary>
        /// Appends an ORDER BY clause to the Sql statement, qualifying the field with a table alias.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="field">An expression specifying the field.</param>
        /// <param name="alias">The table alias.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> OrderBy<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>> field, string alias)
        {
            return sql.OrderBy("(" + sql.SqlContext.SqlSyntax.GetFieldName(field, alias) + ")");
        }

        /// <summary>
        /// Appends an ORDER BY clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="fields">Expression specifying the fields; when empty, all Dto columns are used.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> OrderBy<TDto>(this Sql<ISqlContext> sql, params Expression<Func<TDto, object?>>[] fields)
        {
            ISqlSyntaxProvider sqlSyntax = sql.SqlContext.SqlSyntax;
            var columns = fields.Length == 0
                ? sql.GetColumns<TDto>(withAlias: false)
                : fields.Select(x => sqlSyntax.GetFieldName(x)).ToArray();
            return sql.OrderBy(columns);
        }

        /// <summary>
        /// Appends an ORDER BY DESC clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="field">An expression specifying the field.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> OrderByDescending<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>> field)
        {
            return sql.OrderByDescending(sql.SqlContext.SqlSyntax.GetFieldName(field));
        }

        /// <summary>
        /// Appends an ORDER BY DESC clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="fields">Expression specifying the fields; when empty, all Dto columns are used.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> OrderByDescending<TDto>(this Sql<ISqlContext> sql, params Expression<Func<TDto, object?>>[] fields)
        {
            ISqlSyntaxProvider sqlSyntax = sql.SqlContext.SqlSyntax;
            var columns = fields.Length == 0
                ? sql.GetColumns<TDto>(withAlias: false)
                : fields.Select(x => sqlSyntax.GetFieldName(x)).ToArray();
            return sql.OrderByDescending(columns);
        }

        /// <summary>
        /// Appends an ORDER BY DESC clause to the Sql statement.
        /// </summary>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="fields">Fields, each suffixed with " DESC".</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> OrderByDescending(this Sql<ISqlContext> sql, params string?[] fields)
        {
            return sql.Append("ORDER BY " + string.Join(", ", fields.Select(x => x + " DESC")));
        }

        /// <summary>
        /// Appends a GROUP BY clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="field">An expression specifying the field.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> GroupBy<TDto>(this Sql<ISqlContext> sql, Expression<Func<TDto, object?>> field)
        {
            return sql.GroupBy(sql.SqlContext.SqlSyntax.GetFieldName(field));
        }

        /// <summary>
        /// Appends a GROUP BY clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="fields">Expression specifying the fields; when empty, all Dto columns are used.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> GroupBy<TDto>(this Sql<ISqlContext> sql, params Expression<Func<TDto, object?>>[] fields)
        {
            ISqlSyntaxProvider sqlSyntax = sql.SqlContext.SqlSyntax;
            var columns = fields.Length == 0
                ? sql.GetColumns<TDto>(withAlias: false)
                : fields.Select(x => sqlSyntax.GetFieldName(x)).ToArray();
            return sql.GroupBy(columns);
        }

        /// <summary>
        /// Appends more ORDER BY or GROUP BY fields to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="fields">Expressions specifying the fields; when empty, all Dto columns are used.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> AndBy<TDto>(this Sql<ISqlContext> sql, params Expression<Func<TDto, object?>>[] fields)
        {
            ISqlSyntaxProvider sqlSyntax = sql.SqlContext.SqlSyntax;
            var columns = fields.Length == 0
                ? sql.GetColumns<TDto>(withAlias: false)
                : fields.Select(x => sqlSyntax.GetFieldName(x)).ToArray();
            return sql.Append(", " + string.Join(", ", columns));
        }

        /// <summary>
        /// Appends more ORDER BY or GROUP BY fields to the Sql statement, qualified with a table alias.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="tableAlias">The table alias used to qualify every field.</param>
        /// <param name="fields">Expressions specifying the fields; when empty, all Dto columns are used.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> AndBy<TDto>(this Sql<ISqlContext> sql, string tableAlias, params Expression<Func<TDto, object?>>[] fields)
        {
            ISqlSyntaxProvider sqlSyntax = sql.SqlContext.SqlSyntax;

            // The alias applies to the "all columns" fallback too, matching the explicit-fields path.
            var columns = fields.Length == 0
                ? sql.GetColumns<TDto>(tableAlias: tableAlias, withAlias: false)
                : fields.Select(x => sqlSyntax.GetFieldName(x, tableAlias)).ToArray();
            return sql.Append(", " + string.Join(", ", columns));
        }

        /// <summary>
        /// Appends more ORDER BY DESC fields to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="fields">Expressions specifying the fields; when empty, all Dto columns are used.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> AndByDescending<TDto>(this Sql<ISqlContext> sql, params Expression<Func<TDto, object?>>[] fields)
        {
            ISqlSyntaxProvider sqlSyntax = sql.SqlContext.SqlSyntax;
            var columns = fields.Length == 0
                ? sql.GetColumns<TDto>(withAlias: false)
                : fields.Select(x => sqlSyntax.GetFieldName(x)).ToArray();
            return sql.Append(", " + string.Join(", ", columns.Select(x => x + " DESC")));
        }

        #endregion

        #region Joins

        /// <summary>
        /// Appends a CROSS JOIN clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="alias">An optional alias for the joined table.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> CrossJoin<TDto>(this Sql<ISqlContext> sql, string? alias = null)
        {
            Type type = typeof(TDto);
            var tableName = type.GetTableName();
            var join = sql.SqlContext.SqlSyntax.GetQuotedTableName(tableName);
            if (alias != null)
            {
                join += " " + sql.SqlContext.SqlSyntax.GetQuotedTableName(alias);
            }

            return sql.Append("CROSS JOIN " + join);
        }

        /// <summary>
        /// Appends an INNER JOIN clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="alias">An optional alias for the joined table.</param>
        /// <returns>A SqlJoin statement.</returns>
        public static Sql<ISqlContext>.SqlJoinClause<ISqlContext> InnerJoin<TDto>(this Sql<ISqlContext> sql, string? alias = null)
        {
            Type type = typeof(TDto);
            var tableName = type.GetTableName();
            var join = sql.SqlContext.SqlSyntax.GetQuotedTableName(tableName);
            if (alias != null)
            {
                join += " " + sql.SqlContext.SqlSyntax.GetQuotedTableName(alias);
            }

            return sql.InnerJoin(join);
        }

        /// <summary>
        /// Appends an INNER JOIN clause using a nested query.
        /// </summary>
        /// <param name="sql">The SQL statement.</param>
        /// <param name="nestedSelect">The nested sql query; its arguments are carried over to the outer statement.</param>
        /// <param name="alias">An optional alias for the joined table.</param>
        /// <returns>A SqlJoin statement.</returns>
        public static Sql<ISqlContext>.SqlJoinClause<ISqlContext> InnerJoin(this Sql<ISqlContext> sql, Sql<ISqlContext> nestedSelect, string? alias = null)
        {
            var join = $"({nestedSelect.SQL})";
            if (alias is not null)
            {
                join += " " + sql.SqlContext.SqlSyntax.GetQuotedTableName(alias);
            }

            // Append the join text together with the nested arguments: embedding only nestedSelect.SQL
            // would leave any @n placeholders of the nested query without values.
            return new Sql<ISqlContext>.SqlJoinClause<ISqlContext>(sql.Append("INNER JOIN " + join, nestedSelect.Arguments));
        }

        /// <summary>
        /// Appends a LEFT JOIN clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="alias">An optional alias for the joined table.</param>
        /// <returns>A SqlJoin statement.</returns>
        public static Sql<ISqlContext>.SqlJoinClause<ISqlContext> LeftJoin<TDto>(this Sql<ISqlContext> sql, string? alias = null)
        {
            Type type = typeof(TDto);
            var tableName = type.GetTableName();
            var join = sql.SqlContext.SqlSyntax.GetQuotedTableName(tableName);
            if (alias != null)
            {
                join += " " + sql.SqlContext.SqlSyntax.GetQuotedTableName(alias);
            }

            return sql.LeftJoin(join);
        }

        /// <summary>
        /// Appends a LEFT JOIN clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="nestedJoin">A nested join statement.</param>
        /// <param name="alias">An optional alias for the joined table.</param>
        /// <returns>A SqlJoin statement.</returns>
        /// <remarks>Nested statement produces LEFT JOIN xxx JOIN yyy ON ... ON ...</remarks>
        public static Sql<ISqlContext>.SqlJoinClause<ISqlContext> LeftJoin<TDto>(
            this Sql<ISqlContext> sql,
            Func<Sql<ISqlContext>, Sql<ISqlContext>> nestedJoin,
            string? alias = null) =>
            sql.SqlContext.SqlSyntax.LeftJoinWithNestedJoin<TDto>(sql, nestedJoin, alias);

        /// <summary>
        /// Appends a LEFT JOIN clause using a nested query.
        /// </summary>
        /// <param name="sql">The SQL statement.</param>
        /// <param name="nestedSelect">The nested sql query; its arguments are carried over to the outer statement.</param>
        /// <param name="alias">An optional alias for the joined table.</param>
        /// <returns>A SqlJoin statement.</returns>
        public static Sql<ISqlContext>.SqlJoinClause<ISqlContext> LeftJoin(this Sql<ISqlContext> sql, Sql<ISqlContext> nestedSelect, string? alias = null)
        {
            var join = $"({nestedSelect.SQL})";
            if (alias is not null)
            {
                join += " " + sql.SqlContext.SqlSyntax.GetQuotedTableName(alias);
            }

            // Append the join text together with the nested arguments: embedding only nestedSelect.SQL
            // would leave any @n placeholders of the nested query without values.
            return new Sql<ISqlContext>.SqlJoinClause<ISqlContext>(sql.Append("LEFT JOIN " + join, nestedSelect.Arguments));
        }

        /// <summary>
        /// Appends a RIGHT JOIN clause to the Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="alias">An optional alias for the joined table.</param>
        /// <returns>A SqlJoin statement.</returns>
        public static Sql<ISqlContext>.SqlJoinClause<ISqlContext> RightJoin<TDto>(this Sql<ISqlContext> sql, string? alias = null)
        {
            Type type = typeof(TDto);
            var tableName = type.GetTableName();
            var join = sql.SqlContext.SqlSyntax.GetQuotedTableName(tableName);
            if (alias != null)
            {
                join += " " + sql.SqlContext.SqlSyntax.GetQuotedTableName(alias);
            }

            return sql.RightJoin(join);
        }

        /// <summary>
        /// Appends an ON clause to a SqlJoin statement.
        /// </summary>
        /// <typeparam name="TLeft">The type of the left Dto.</typeparam>
        /// <typeparam name="TRight">The type of the right Dto.</typeparam>
        /// <param name="sqlJoin">The Sql join statement.</param>
        /// <param name="leftField">An expression specifying the left field.</param>
        /// <param name="rightField">An expression specifying the right field.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> On<TLeft, TRight>(this Sql<ISqlContext>.SqlJoinClause<ISqlContext> sqlJoin, Expression<Func<TLeft, object?>> leftField, Expression<Func<TRight, object?>> rightField)
        {
            // TODO: ugly - should define on SqlContext!
            var xLeft = new Sql<ISqlContext>(sqlJoin.SqlContext).Columns(leftField);
            var xRight = new Sql<ISqlContext>(sqlJoin.SqlContext).Columns(rightField);
            return sqlJoin.On(xLeft + " = " + xRight);
        }

        /// <summary>
        /// Appends an ON clause to a SqlJoin statement.
        /// </summary>
        /// <param name="sqlJoin">The Sql join statement.</param>
        /// <param name="on">A Sql fragment to use as the ON clause body; a leading "WHERE" is stripped.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> On(this Sql<ISqlContext>.SqlJoinClause<ISqlContext> sqlJoin, Func<Sql<ISqlContext>, Sql<ISqlContext>> on)
        {
            var sql = new Sql<ISqlContext>(sqlJoin.SqlContext);
            sql = on(sql);
            var text = sql.SQL.Trim().TrimStartExact("WHERE").Trim();
            return sqlJoin.On(text, sql.Arguments);
        }

        /// <summary>
        /// Appends an ON clause to a SqlJoin statement.
        /// </summary>
        /// <typeparam name="TDto1">The type of Dto 1.</typeparam>
        /// <typeparam name="TDto2">The type of Dto 2.</typeparam>
        /// <param name="sqlJoin">The SqlJoin statement.</param>
        /// <param name="predicate">A predicate to transform and use as the ON clause body.</param>
        /// <param name="aliasLeft">An optional alias for Dto 1 table.</param>
        /// <param name="aliasRight">An optional alias for Dto 2 table.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> On<TDto1, TDto2>(this Sql<ISqlContext>.SqlJoinClause<ISqlContext> sqlJoin, Expression<Func<TDto1, TDto2, bool>> predicate, string? aliasLeft = null, string? aliasRight = null)
        {
            var expresionist = new PocoToSqlExpressionVisitor<TDto1, TDto2>(sqlJoin.SqlContext, aliasLeft, aliasRight);
            var onExpression = expresionist.Visit(predicate);
            return sqlJoin.On(onExpression, expresionist.GetSqlParameters());
        }

        /// <summary>
        /// Appends an ON clause to a SqlJoin statement.
        /// </summary>
        /// <typeparam name="TDto1">The type of Dto 1.</typeparam>
        /// <typeparam name="TDto2">The type of Dto 2.</typeparam>
        /// <typeparam name="TDto3">The type of Dto 3.</typeparam>
        /// <param name="sqlJoin">The SqlJoin statement.</param>
        /// <param name="predicate">A predicate to transform and use as the ON clause body.</param>
        /// <param name="aliasLeft">An optional alias for Dto 1 table.</param>
        /// <param name="aliasRight">An optional alias for Dto 2 table.</param>
        /// <param name="aliasOther">An optional alias for Dto 3 table.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> On<TDto1, TDto2, TDto3>(this Sql<ISqlContext>.SqlJoinClause<ISqlContext> sqlJoin, Expression<Func<TDto1, TDto2, TDto3, bool>> predicate, string? aliasLeft = null, string? aliasRight = null, string? aliasOther = null)
        {
            var expresionist = new PocoToSqlExpressionVisitor<TDto1, TDto2, TDto3>(sqlJoin.SqlContext, aliasLeft, aliasRight, aliasOther);
            var onExpression = expresionist.Visit(predicate);
            return sqlJoin.On(onExpression, expresionist.GetSqlParameters());
        }

        #endregion

        #region Select

        /// <summary>
        /// Alters a Sql statement to return a maximum amount of rows.
        /// </summary>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="count">The maximum number of rows to return.</param>
        /// <returns>The Sql statement.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static Sql<ISqlContext> SelectTop(this Sql<ISqlContext> sql, int count)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            return sql.SqlContext.SqlSyntax.SelectTop(sql, count);
        }

        /// <summary>
        /// Creates a SELECT COUNT(*) Sql statement.
        /// </summary>
        /// <param name="sql">The origin sql.</param>
        /// <param name="alias">An optional alias.</param>
        /// <returns>The Sql statement.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static Sql<ISqlContext> SelectCount(this Sql<ISqlContext> sql, string? alias = null)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            var text = "COUNT(*)";
            if (alias != null)
            {
                text += " AS " + sql.SqlContext.SqlSyntax.GetQuotedColumnName(alias);
            }

            return sql.Select(text);
        }

        /// <summary>
        /// Creates a SELECT COUNT Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the DTO to count.</typeparam>
        /// <param name="sql">The origin sql.</param>
        /// <param name="fields">Expressions indicating the columns to count.</param>
        /// <returns>The Sql statement.</returns>
        /// <remarks>
        /// <para>If <paramref name="fields"/> is empty, all columns are counted.</para>
        /// </remarks>
        public static Sql<ISqlContext> SelectCount<TDto>(this Sql<ISqlContext> sql, params Expression<Func<TDto, object?>>[] fields)
            => sql.SelectCount(null, fields);

        /// <summary>
        /// Creates a SELECT COUNT Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the DTO to count.</typeparam>
        /// <param name="sql">The origin sql.</param>
        /// <param name="alias">An alias.</param>
        /// <param name="fields">Expressions indicating the columns to count.</param>
        /// <returns>The Sql statement.</returns>
        /// <remarks>
        /// <para>If <paramref name="fields"/> is empty, all columns are counted.</para>
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static Sql<ISqlContext> SelectCount<TDto>(this Sql<ISqlContext> sql, string? alias, params Expression<Func<TDto, object?>>[] fields)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            ISqlSyntaxProvider sqlSyntax = sql.SqlContext.SqlSyntax;
            var columns = fields.Length == 0
                ? sql.GetColumns<TDto>(withAlias: false)
                : fields.Select(x => sqlSyntax.GetFieldName(x)).ToArray();
            var text = "COUNT (" + string.Join(", ", columns) + ")";
            if (alias != null)
            {
                text += " AS " + sql.SqlContext.SqlSyntax.GetQuotedColumnName(alias);
            }

            return sql.Select(text);
        }

        /// <summary>
        /// Creates a SELECT * Sql statement.
        /// </summary>
        /// <param name="sql">The origin sql.</param>
        /// <returns>The Sql statement.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static Sql<ISqlContext> SelectAll(this Sql<ISqlContext> sql)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            return sql.Select("*");
        }

        /// <summary>
        /// Creates a SELECT Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the DTO to select.</typeparam>
        /// <param name="sql">The origin sql.</param>
        /// <param name="fields">Expressions indicating the columns to select.</param>
        /// <returns>The Sql statement.</returns>
        /// <remarks>
        /// <para>If <paramref name="fields"/> is empty, all columns are selected.</para>
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static Sql<ISqlContext> Select<TDto>(this Sql<ISqlContext> sql, params Expression<Func<TDto, object?>>[] fields)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            return sql.Select(sql.GetColumns(columnExpressions: fields));
        }

        /// <summary>
        /// Creates a SELECT DISTINCT Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the DTO to select.</typeparam>
        /// <param name="sql">The origin sql.</param>
        /// <param name="fields">Expressions indicating the columns to select.</param>
        /// <returns>The Sql statement.</returns>
        /// <remarks>
        /// <para>If <paramref name="fields"/> is empty, all columns are selected.</para>
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static Sql<ISqlContext> SelectDistinct<TDto>(this Sql<ISqlContext> sql, params Expression<Func<TDto, object?>>[] fields)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            var columns = sql.GetColumns(columnExpressions: fields);
            sql.Append("SELECT DISTINCT " + string.Join(", ", columns));
            return sql;
        }

        /// <summary>
        /// Creates a SELECT DISTINCT Sql statement from raw column expressions.
        /// </summary>
        /// <param name="sql">The origin sql.</param>
        /// <param name="columns">The columns, joined with ", " using their string representation.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> SelectDistinct(this Sql<ISqlContext> sql, params object[] columns)
        {
            sql.Append("SELECT DISTINCT " + string.Join(", ", columns));
            return sql;
        }

        /// <summary>
        /// Creates a SELECT Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the DTO to select.</typeparam>
        /// <param name="sql">The origin sql.</param>
        /// <param name="tableAlias">A table alias.</param>
        /// <param name="fields">Expressions indicating the columns to select.</param>
        /// <returns>The Sql statement.</returns>
        /// <remarks>
        /// <para>If <paramref name="fields"/> is empty, all columns are selected.</para>
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static Sql<ISqlContext> Select<TDto>(this Sql<ISqlContext> sql, string tableAlias, params Expression<Func<TDto, object?>>[] fields)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            return sql.Select(sql.GetColumns(tableAlias: tableAlias, columnExpressions: fields));
        }

        /// <summary>
        /// Adds columns to a SELECT Sql statement.
        /// </summary>
        /// <param name="sql">The origin sql.</param>
        /// <param name="fields">Columns to select.</param>
        /// <returns>The Sql statement.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static Sql<ISqlContext> AndSelect(this Sql<ISqlContext> sql, params string[] fields)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            return sql.Append(", " + string.Join(", ", fields));
        }

        /// <summary>
        /// Adds columns to a SELECT Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the DTO to select.</typeparam>
        /// <param name="sql">The origin sql.</param>
        /// <param name="fields">Expressions indicating the columns to select.</param>
        /// <returns>The Sql statement.</returns>
        /// <remarks>
        /// <para>If <paramref name="fields"/> is empty, all columns are selected.</para>
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static Sql<ISqlContext> AndSelect<TDto>(this Sql<ISqlContext> sql, params Expression<Func<TDto, object?>>[] fields)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            return sql.Append(", " + string.Join(", ", sql.GetColumns(columnExpressions: fields)));
        }

        /// <summary>
        /// Adds columns to a SELECT Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the DTO to select.</typeparam>
        /// <param name="sql">The origin sql.</param>
        /// <param name="tableAlias">A table alias.</param>
        /// <param name="fields">Expressions indicating the columns to select.</param>
        /// <returns>The Sql statement.</returns>
        /// <remarks>
        /// <para>If <paramref name="fields"/> is empty, all columns are selected.</para>
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static Sql<ISqlContext> AndSelect<TDto>(this Sql<ISqlContext> sql, string tableAlias, params Expression<Func<TDto, object?>>[] fields)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            return sql.Append(", " + string.Join(", ", sql.GetColumns(tableAlias: tableAlias, columnExpressions: fields)));
        }

        /// <summary>
        /// Adds a COUNT(*) to a SELECT Sql statement.
        /// </summary>
        /// <param name="sql">The origin sql.</param>
        /// <param name="alias">An optional alias.</param>
        /// <returns>The Sql statement.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static Sql<ISqlContext> AndSelectCount(this Sql<ISqlContext> sql, string? alias = null)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            var text = ", COUNT(*)";
            if (alias != null)
            {
                text += " AS " + sql.SqlContext.SqlSyntax.GetQuotedColumnName(alias);
            }

            return sql.Append(text);
        }

        /// <summary>
        /// Adds a COUNT to a SELECT Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the DTO to count.</typeparam>
        /// <param name="sql">The origin sql.</param>
        /// <param name="fields">Expressions indicating the columns to count.</param>
        /// <returns>The Sql statement.</returns>
        /// <remarks>
        /// <para>If <paramref name="fields"/> is empty, all columns are counted.</para>
        /// </remarks>
        public static Sql<ISqlContext> AndSelectCount<TDto>(this Sql<ISqlContext> sql, params Expression<Func<TDto, object?>>[] fields)
            => sql.AndSelectCount(null, fields);

        /// <summary>
        /// Adds a COUNT to a SELECT Sql statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the DTO to count.</typeparam>
        /// <param name="sql">The origin sql.</param>
        /// <param name="alias">An alias.</param>
        /// <param name="fields">Expressions indicating the columns to count.</param>
        /// <returns>The Sql statement.</returns>
        /// <remarks>
        /// <para>If <paramref name="fields"/> is empty, all columns are counted.</para>
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static Sql<ISqlContext> AndSelectCount<TDto>(this Sql<ISqlContext> sql, string? alias = null, params Expression<Func<TDto, object?>>[] fields)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            ISqlSyntaxProvider sqlSyntax = sql.SqlContext.SqlSyntax;
            var columns = fields.Length == 0
                ? sql.GetColumns<TDto>(withAlias: false)
                : fields.Select(x => sqlSyntax.GetFieldName(x)).ToArray();
            var text = ", COUNT (" + string.Join(", ", columns) + ")";
            if (alias != null)
            {
                text += " AS " + sql.SqlContext.SqlSyntax.GetQuotedColumnName(alias);
            }

            return sql.Append(text);
        }

        /// <summary>
        /// Creates a SELECT Sql statement with a referenced Dto.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto to select.</typeparam>
        /// <param name="sql">The origin Sql.</param>
        /// <param name="reference">An expression specifying the reference; may be null.</param>
        /// <returns>The Sql statement.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static Sql<ISqlContext> Select<TDto>(this Sql<ISqlContext> sql, Func<SqlRef<TDto>, SqlRef<TDto>> reference)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            sql.Select(sql.GetColumns<TDto>());

            reference?.Invoke(new SqlRef<TDto>(sql, null));
            return sql;
        }

        /// <summary>
        /// Creates a SELECT Sql statement with a referenced Dto.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto to select.</typeparam>
        /// <param name="sql">The origin Sql.</param>
        /// <param name="reference">An expression specifying the reference.</param>
        /// <param name="sqlexpr">An expression to apply to the Sql statement before adding the reference selection.</param>
        /// <returns>The Sql statement.</returns>
        /// <remarks>The expression applies to the Sql statement before the reference selection
        /// is added, so that it is possible to add (e.g. calculated) columns to the referencing Dto.</remarks>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static Sql<ISqlContext> Select<TDto>(this Sql<ISqlContext> sql, Func<SqlRef<TDto>, SqlRef<TDto>> reference, Func<Sql<ISqlContext>, Sql<ISqlContext>> sqlexpr)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            sql.Select(sql.GetColumns<TDto>());

            sql = sqlexpr(sql);

            reference(new SqlRef<TDto>(sql, null));
            return sql;
        }

        /// <summary>
        /// Creates a SELECT CASE WHEN EXISTS query, which returns 1 if the sub query returns any results, and 0 if not.
        /// </summary>
        /// <param name="sql">The original SQL.</param>
        /// <param name="nestedSelect">The nested select to run the query against.</param>
        /// <returns>The updated Sql statement.</returns>
        public static Sql<ISqlContext> SelectAnyIfExists(this Sql<ISqlContext> sql, Sql<ISqlContext> nestedSelect)
        {
            sql.Append("SELECT CASE WHEN EXISTS (");
            sql.Append(nestedSelect);
            sql.Append(")");
            sql.Append("THEN 1 ELSE 0 END");
            return sql;
        }

        /// <summary>
        /// Represents a Dto reference expression.
        /// </summary>
        /// <typeparam name="TDto">The type of the referencing Dto.</typeparam>
        public class SqlRef<TDto>
        {
            /// <summary>
            /// Initializes a new Dto reference expression.
            /// </summary>
            /// <param name="sql">The original Sql expression.</param>
            /// <param name="prefix">The current Dtos prefix.</param>
            public SqlRef(Sql<ISqlContext> sql, string? prefix)
            {
                Sql = sql;
                Prefix = prefix;
            }

            /// <summary>
            /// Gets the original Sql expression.
            /// </summary>
            public Sql<ISqlContext> Sql { get; }

            /// <summary>
            /// Gets the current Dtos prefix.
            /// </summary>
            public string? Prefix { get; }

            /// <summary>
            /// Appends fields for a referenced Dto.
            /// </summary>
            /// <typeparam name="TRefDto">The type of the referenced Dto.</typeparam>
            /// <param name="field">An expression specifying the referencing field.</param>
            /// <param name="reference">An optional expression representing a nested reference selection.</param>
            /// <returns>A SqlRef statement.</returns>
            public SqlRef<TDto> Select<TRefDto>(Expression<Func<TDto, TRefDto>> field, Func<SqlRef<TRefDto>, SqlRef<TRefDto>>? reference = null)
                => Select(field, null, reference);

            /// <summary>
            /// Appends fields for a referenced Dto.
            /// </summary>
            /// <typeparam name="TRefDto">The type of the referenced Dto.</typeparam>
            /// <param name="field">An expression specifying the referencing field.</param>
            /// <param name="tableAlias">The referenced Dto table alias.</param>
            /// <param name="reference">An optional expression representing a nested reference selection.</param>
            /// <returns>A SqlRef statement.</returns>
            public SqlRef<TDto> Select<TRefDto>(Expression<Func<TDto, TRefDto>> field, string? tableAlias, Func<SqlRef<TRefDto>, SqlRef<TRefDto>>? reference = null)
            {
                PropertyInfo? property = field == null ? null : ExpressionHelper.FindProperty(field).Item1 as PropertyInfo;
                return Select(property, tableAlias, reference);
            }

            /// <summary>
            /// Selects referenced DTOs.
            /// </summary>
            /// <typeparam name="TRefDto">The type of the referenced DTOs.</typeparam>
            /// <param name="field">An expression specifying the referencing field.</param>
            /// <param name="reference">An optional expression representing a nested reference selection.</param>
            /// <returns>A referenced DTO expression.</returns>
            /// <remarks>
            /// <para>The referencing property has to be a <see cref="List{T}"/>.</para>
            /// </remarks>
            public SqlRef<TDto> Select<TRefDto>(Expression<Func<TDto, List<TRefDto>>> field, Func<SqlRef<TRefDto>, SqlRef<TRefDto>>? reference = null)
                => Select(field, null, reference);

            /// <summary>
            /// Selects referenced DTOs.
            /// </summary>
            /// <typeparam name="TRefDto">The type of the referenced DTOs.</typeparam>
            /// <param name="field">An expression specifying the referencing field.</param>
            /// <param name="tableAlias">The DTO table alias.</param>
            /// <param name="reference">An optional expression representing a nested reference selection.</param>
            /// <returns>A referenced DTO expression.</returns>
            /// <remarks>
            /// <para>The referencing property has to be a <see cref="List{T}"/>.</para>
            /// </remarks>
            public SqlRef<TDto> Select<TRefDto>(Expression<Func<TDto, List<TRefDto>>> field, string? tableAlias, Func<SqlRef<TRefDto>, SqlRef<TRefDto>>? reference = null)
            {
                PropertyInfo? property = field == null ? null : ExpressionHelper.FindProperty(field).Item1 as PropertyInfo;
                return Select(property, tableAlias, reference);
            }

            // Appends the referenced Dto's columns, prefixed with the (possibly nested) reference name,
            // then recurses into the optional nested reference selection.
            private SqlRef<TDto> Select<TRefDto>(PropertyInfo? propertyInfo, string? tableAlias, Func<SqlRef<TRefDto>, SqlRef<TRefDto>>? nested = null)
            {
                var referenceName = propertyInfo?.Name ?? typeof(TDto).Name;
                if (Prefix != null)
                {
                    referenceName = Prefix + PocoData.Separator + referenceName;
                }

                var columns = Sql.GetColumns<TRefDto>(tableAlias, referenceName);
                Sql.Append(", " + string.Join(", ", columns));

                nested?.Invoke(new SqlRef<TRefDto>(Sql, referenceName));
                return this;
            }
        }

        /// <summary>
        /// Gets fields for a Dto.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The origin sql.</param>
        /// <param name="fields">Expressions specifying the fields.</param>
        /// <returns>The comma-separated list of fields.</returns>
        /// <remarks>
        /// <para>If <paramref name="fields"/> is empty, all fields are selected.</para>
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static string Columns<TDto>(this Sql<ISqlContext> sql, params Expression<Func<TDto, object?>>[] fields)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            return string.Join(", ", sql.GetColumns(columnExpressions: fields, withAlias: false));
        }

        /// <summary>
        /// Gets fields for a Dto, formatted for use in an INSERT column list.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The origin sql.</param>
        /// <param name="fields">Expressions specifying the fields; when null or empty, all fields are used.</param>
        /// <returns>The comma-separated list of fields.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static string ColumnsForInsert<TDto>(this Sql<ISqlContext> sql, params Expression<Func<TDto, object?>>[]? fields)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            return string.Join(", ", sql.GetColumns(columnExpressions: fields, withAlias: false, forInsert: true));
        }

        /// <summary>
        /// Gets fields for a Dto.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The origin sql.</param>
        /// <param name="alias">The Dto table alias.</param>
        /// <param name="fields">Expressions specifying the fields.</param>
        /// <returns>The comma-separated list of fields.</returns>
        /// <remarks>
        /// <para>If <paramref name="fields"/> is empty, all fields are selected.</para>
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null.</exception>
        public static string Columns<TDto>(this Sql<ISqlContext> sql, string alias, params Expression<Func<TDto, object?>>[] fields)
        {
            if (sql == null)
            {
                throw new ArgumentNullException(nameof(sql));
            }

            return string.Join(", ", sql.GetColumns(columnExpressions: fields, withAlias: false, tableAlias: alias));
        }

        #endregion

        #region Delete

        /// <summary>
        /// Appends a bare DELETE keyword to the Sql statement.
        /// </summary>
        /// <param name="sql">The Sql statement.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> Delete(this Sql<ISqlContext> sql)
        {
            sql.Append("DELETE");
            return sql;
        }

        /// <summary>
        /// Appends a DELETE FROM clause for the Dto's table.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> Delete<TDto>(this Sql<ISqlContext> sql)
        {
            Type type = typeof(TDto);
            var tableName = type.GetTableName();

            // FROM optional SQL server, but not elsewhere.
            sql.Append($"DELETE FROM {sql.SqlContext.SqlSyntax.GetQuotedTableName(tableName)}");
            return sql;
        }

        #endregion

        #region Update

        /// <summary>
        /// Appends a bare UPDATE keyword to the Sql statement.
        /// </summary>
        /// <param name="sql">The Sql statement.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> Update(this Sql<ISqlContext> sql)
        {
            sql.Append("UPDATE");
            return sql;
        }

        /// <summary>
        /// Appends an UPDATE clause for the Dto's table.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> Update<TDto>(this Sql<ISqlContext> sql)
        {
            Type type = typeof(TDto);
            var tableName = type.GetTableName();

            sql.Append($"UPDATE {sql.SqlContext.SqlSyntax.GetQuotedTableName(tableName)}");
            return sql;
        }

        /// <summary>
        /// Appends an UPDATE ... SET clause for the Dto's table, built from the given set expressions.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto.</typeparam>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="updates">A function registering the column assignments.</param>
        /// <returns>The Sql statement.</returns>
        /// <remarks>
        /// <para>Null values are written as <c>=NULL</c>, empty strings as the literal <c>=''</c>,
        /// and every other value as a positional parameter.</para>
        /// </remarks>
        public static Sql<ISqlContext> Update<TDto>(this Sql<ISqlContext> sql, Func<SqlUpd<TDto>, SqlUpd<TDto>> updates)
        {
            Type type = typeof(TDto);
            var tableName = type.GetTableName();

            sql.Append($"UPDATE {sql.SqlContext.SqlSyntax.GetQuotedTableName(tableName)} SET");

            var u = new SqlUpd<TDto>(sql.SqlContext);
            u = updates(u);
            var first = true;
            foreach (Tuple<string, object?> setExpression in u.SetExpressions)
            {
                switch (setExpression.Item2)
                {
                    case null:
                        sql.Append((first ? "" : ",") + " " + setExpression.Item1 + "=NULL");
                        break;
                    case string s when s == string.Empty:
                        sql.Append((first ? "" : ",") + " " + setExpression.Item1 + "=''");
                        break;
                    default:
                        sql.Append((first ? "" : ",") + " " + setExpression.Item1 + "=@0", setExpression.Item2);
                        break;
                }

                first = false;
            }

            if (!first)
            {
                sql.Append(" ");
            }

            return sql;
        }

        /// <summary>
        /// Collects the column assignments of an UPDATE ... SET statement.
        /// </summary>
        /// <typeparam name="TDto">The type of the Dto being updated.</typeparam>
        public class SqlUpd<TDto>
        {
            private readonly ISqlContext _sqlContext;
            private readonly List<Tuple<string, object?>> _setExpressions = new List<Tuple<string, object?>>();

            /// <summary>
            /// Initializes a new instance of the <see cref="SqlUpd{TDto}"/> class.
            /// </summary>
            /// <param name="sqlContext">The Sql context used to resolve column names.</param>
            public SqlUpd(ISqlContext sqlContext)
            {
                _sqlContext = sqlContext;
            }

            /// <summary>
            /// Registers an assignment of <paramref name="value"/> to the selected field.
            /// </summary>
            /// <param name="fieldSelector">An expression specifying the field.</param>
            /// <param name="value">The value to assign.</param>
            /// <returns>This instance, for chaining.</returns>
            public SqlUpd<TDto> Set(Expression<Func<TDto, object?>> fieldSelector, object? value)
            {
                var fieldName = _sqlContext.SqlSyntax.GetFieldNameForUpdate(fieldSelector);
                _setExpressions.Add(new Tuple<string, object?>(fieldName, value));
                return this;
            }

            /// <summary>
            /// Gets the registered (column name, value) assignments, in registration order.
            /// </summary>
            public List<Tuple<string, object?>> SetExpressions => _setExpressions;
        }

        #endregion

        #region Hints

        /// <summary>
        /// Appends the relevant ForUpdate hint.
        /// </summary>
        /// <param name="sql">The Sql statement.</param>
        /// <returns>The Sql statement.</returns>
        /// <remarks>
        /// NOTE: This method will not work for all queries, only simple ones!
        /// </remarks>
        public static Sql<ISqlContext> ForUpdate(this Sql<ISqlContext> sql)
            => sql.SqlContext.SqlSyntax.InsertForUpdateHint(sql);

        /// <summary>
        /// Appends the syntax provider's ForUpdate hint at the end of the statement.
        /// </summary>
        /// <param name="sql">The Sql statement.</param>
        /// <returns>The Sql statement.</returns>
        public static Sql<ISqlContext> AppendForUpdateHint(this Sql<ISqlContext> sql)
            => sql.SqlContext.SqlSyntax.AppendForUpdateHint(sql);

        #endregion

        #region Aliasing

        internal static string GetAliasedField(this Sql<ISqlContext> sql, string field)
        {
            // get alias, if aliased
            //
            // regex looks for pattern "([\w+].[\w+]) AS ([\w+])" ie "(field) AS (alias)"
            // and, if found & a group's field matches the field name, returns the alias
            //
            // so... if query contains "[umbracoNode].[nodeId] AS [umbracoNode__nodeId]"
            // then GetAliased for "[umbracoNode].[nodeId]" returns "[umbracoNode__nodeId]"
            MatchCollection matches = sql.SqlContext.SqlSyntax.AliasRegex.Matches(sql.SQL);
            Match? match = matches.Cast<Match>().FirstOrDefault(m => m.Groups[1].Value.InvariantEquals(field));
            return match == null ? field : match.Groups[2].Value;
        }

        #endregion

        #region Utilities

        // Builds the column list for TDto, optionally restricted to (and ordered by) columnExpressions.
        // Per-expression aliases take precedence; otherwise, when withAlias is set, the column alias
        // (or member key) is used.
        private static string[] GetColumns<TDto>(this Sql<ISqlContext> sql, string? tableAlias = null, string? referenceName = null, Expression<Func<TDto, object?>>[]? columnExpressions = null, bool withAlias = true, bool forInsert = false)
        {
            PocoData? pd = sql.SqlContext.PocoDataFactory.ForType(typeof(TDto));
            var tableName = tableAlias ?? pd.TableInfo.TableName;
            var queryColumns = pd.QueryColumns.ToList();

            Dictionary<string, string>? aliases = null;

            if (columnExpressions != null && columnExpressions.Length > 0)
            {
                var names = columnExpressions.Select(x =>
                {
                    (MemberInfo member, var alias) = ExpressionHelper.FindProperty(x);
                    var field = member as PropertyInfo;
                    var fieldName = field?.GetColumnName();
                    if (alias != null && fieldName is not null)
                    {
                        if (aliases == null)
                        {
                            aliases = new Dictionary<string, string>();
                        }

                        aliases[fieldName] = alias;
                    }

                    return fieldName;
                }).ToArray();

                // only get the columns that exist in the selected names
                queryColumns = queryColumns.Where(x => names.Contains(x.Key)).ToList();

                // ensure the order of the columns in the expressions is the order in the result
                queryColumns.Sort((a, b) => names.IndexOf(a.Key).CompareTo(names.IndexOf(b.Key)));
            }

            string? GetAlias(PocoColumn column)
            {
                if (aliases != null && aliases.TryGetValue(column.ColumnName, out var alias))
                {
                    return alias;
                }

                return withAlias ? (string.IsNullOrEmpty(column.ColumnAlias) ? column.MemberInfoKey : column.ColumnAlias) : null;
            }

            return queryColumns
                .Select(x => sql.SqlContext.SqlSyntax.GetColumn(sql.SqlContext.DatabaseType, tableName, x.Value.ColumnName, GetAlias(x.Value)!, referenceName, forInsert: forInsert))
                .ToArray();
        }

        /// <summary>
        /// Gets the table name declared by the type's <see cref="TableNameAttribute"/>.
        /// </summary>
        /// <param name="type">The Dto type.</param>
        /// <returns>The table name, or <see cref="string.Empty"/> when the attribute is missing or blank.</returns>
        public static string GetTableName(this Type type)
        {
            // TODO: returning string.Empty for now
            // BUT the code bits that calls this method cannot deal with string.Empty so we
            // should either throw, or fix these code bits...
            TableNameAttribute? attr = type.FirstAttribute<TableNameAttribute>();
            return string.IsNullOrWhiteSpace(attr?.Value) ? string.Empty : attr.Value;
        }

        // Gets the column name from the property's ColumnAttribute, falling back to the property name.
        private static string GetColumnName(this PropertyInfo column)
        {
            ColumnAttribute? attr = column.FirstAttribute<ColumnAttribute>();
            return string.IsNullOrWhiteSpace(attr?.Name) ? column.Name : attr.Name;
        }

        /// <summary>
        /// Renders a Sql statement and its arguments as text (e.g. for logging).
        /// </summary>
        /// <param name="sql">The Sql statement.</param>
        /// <returns>The statement text, followed by a "--" line listing the arguments when there are any.</returns>
        public static string ToText(this Sql sql)
        {
            var text = new StringBuilder();
            sql.ToText(text);
            return text.ToString();
        }

        /// <summary>
        /// Renders a Sql statement and its arguments into a <see cref="StringBuilder"/>.
        /// </summary>
        /// <param name="sql">The Sql statement.</param>
        /// <param name="text">The builder to append to.</param>
        public static void ToText(this Sql sql, StringBuilder text)
        {
            ToText(sql.SQL, sql.Arguments, text);
        }

        /// <summary>
        /// Renders Sql text and arguments into a <see cref="StringBuilder"/>.
        /// </summary>
        /// <param name="sql">The Sql text, appended as a line.</param>
        /// <param name="arguments">The arguments; when null or empty nothing else is appended.</param>
        /// <param name="text">The builder to append to.</param>
        /// <remarks>Arguments are appended on one line as " -- @0:value @1:value ...".</remarks>
        public static void ToText(string? sql, object[]? arguments, StringBuilder text)
        {
            text.AppendLine(sql);

            if (arguments == null || arguments.Length == 0)
            {
                return;
            }

            text.Append(" --");

            var i = 0;
            foreach (var arg in arguments)
            {
                text.Append(" @");
                text.Append(i++);
                text.Append(":");
                text.Append(arg);
            }

            text.AppendLine();
        }

        #endregion
    }
}