using System.Data;
using System.Data.Common;
using System.Text;
using Microsoft.Extensions.Logging;
using NPoco;
using Umbraco.Cms.Infrastructure.Migrations.Install;
using Umbraco.Extensions;

namespace Umbraco.Cms.Infrastructure.Persistence;

/// <summary>
///     Extends NPoco Database for Umbraco.
/// </summary>
/// <remarks>
///     <para>
///         Is used everywhere in place of the original NPoco Database object, and provides additional features
///         such as SQL tracing, command counting and recording, schema validation and exception logging.
///     </para>
///     <para>Is never created directly but obtained from the <see cref="UmbracoDatabaseFactory" />.</para>
/// </remarks>
public class UmbracoDatabase : Database, IUmbracoDatabase
{
    private readonly ILogger<UmbracoDatabase> _logger;
    private readonly IBulkSqlInsertProvider? _bulkSqlInsertProvider;
    private readonly DatabaseSchemaCreatorFactory? _databaseSchemaCreatorFactory;
    private readonly IEnumerable<IMapper>? _mapperCollection;
    private readonly Guid _instanceGuid = Guid.NewGuid();

    // Non-null only while LogCommands is true; receives one entry per executed command.
    private List<CommandInfo>? _commands;

    #region Ctor

    /// <summary>
    ///     Initializes a new instance of the <see cref="UmbracoDatabase" /> class.
    /// </summary>
    /// <param name="connectionString">The connection string.</param>
    /// <param name="sqlContext">The SQL context; supplies the database type, SQL syntax and default isolation level.</param>
    /// <param name="provider">The ADO.NET provider factory used to create connections.</param>
    /// <param name="logger">The logger.</param>
    /// <param name="bulkSqlInsertProvider">The bulk insert provider, if any; see <see cref="BulkInsertRecords{T}" />.</param>
    /// <param name="databaseSchemaCreatorFactory">The factory used by <see cref="ValidateSchema" />.</param>
    /// <param name="mapperCollection">Optional additional NPoco mappers, added to <c>Mappers</c>.</param>
    /// <remarks>
    ///     <para>Used by UmbracoDatabaseFactory to create databases.</para>
    ///     <para>Also used by DatabaseBuilder for creating databases and installing/upgrading.</para>
    /// </remarks>
    public UmbracoDatabase(
        string connectionString,
        ISqlContext sqlContext,
        DbProviderFactory provider,
        ILogger<UmbracoDatabase> logger,
        IBulkSqlInsertProvider? bulkSqlInsertProvider,
        DatabaseSchemaCreatorFactory databaseSchemaCreatorFactory,
        IEnumerable<IMapper>? mapperCollection = null)
        : base(connectionString, sqlContext.DatabaseType, provider, sqlContext.SqlSyntax.DefaultIsolationLevel)
    {
        SqlContext = sqlContext;
        _logger = logger;
        _bulkSqlInsertProvider = bulkSqlInsertProvider;
        _databaseSchemaCreatorFactory = databaseSchemaCreatorFactory;
        _mapperCollection = mapperCollection;

        Init();
    }

    /// <summary>
    ///     Initializes a new instance of the <see cref="UmbracoDatabase" /> class over an existing connection.
    /// </summary>
    /// <remarks>
    ///     Internal for unit tests only. No schema creator factory or extra mappers are configured, so
    ///     <see cref="ValidateSchema" /> returns an empty result.
    /// </remarks>
    internal UmbracoDatabase(
        DbConnection connection,
        ISqlContext sqlContext,
        ILogger<UmbracoDatabase> logger,
        IBulkSqlInsertProvider bulkSqlInsertProvider)
        : base(connection, sqlContext.DatabaseType, sqlContext.SqlSyntax.DefaultIsolationLevel)
    {
        SqlContext = sqlContext;
        _logger = logger;
        _bulkSqlInsertProvider = bulkSqlInsertProvider;

        Init();
    }

    // Shared constructor tail: applies the default trace setting and registers any extra mappers.
    private void Init()
    {
        EnableSqlTrace = EnableSqlTraceDefault;

        if (_mapperCollection != null)
        {
            Mappers.AddRange(_mapperCollection);
        }
    }

    #endregion

    /// <inheritdoc />
    public ISqlContext SqlContext { get; }

    #region Testing, Debugging and Troubleshooting

    private bool _enableCount;

#if DEBUG_DATABASES
    private int _spid = -1;
    private const bool EnableSqlTraceDefault = true;
#else
    private string? _instanceId;
    private const bool EnableSqlTraceDefault = false;
#endif

    /// <inheritdoc />
    public string InstanceId =>
#if DEBUG_DATABASES
        _instanceGuid.ToString("N").Substring(0, 8) + ':' + _spid;
#else
        _instanceId ??= _instanceGuid.ToString("N").Substring(0, 8);
#endif

    /// <inheritdoc />
    /// <remarks>
    ///     Set to <c>true</c> when a transaction begins, and back to <c>false</c> when it is aborted or completed.
    /// </remarks>
    public bool InTransaction { get; private set; }

    protected override void OnBeginTransaction()
    {
        base.OnBeginTransaction();
        InTransaction = true;
    }

    protected override void OnAbortTransaction()
    {
        InTransaction = false;
        base.OnAbortTransaction();
    }

    protected override void OnCompleteTransaction()
    {
        InTransaction = false;
        base.OnCompleteTransaction();
    }

    /// <summary>
    ///     Gets or sets a value indicating whether to log (at Debug level) all executed Sql statements.
    /// </summary>
    internal bool EnableSqlTrace { get; set; }

    /// <summary>
    ///     Gets or sets a value indicating whether to count all executed Sql statements.
    /// </summary>
    /// <remarks>Setting this to <c>false</c> also resets <see cref="SqlCount" /> to zero.</remarks>
    public bool EnableSqlCount
    {
        get => _enableCount;
        set
        {
            _enableCount = value;

            if (_enableCount == false)
            {
                SqlCount = 0;
            }
        }
    }

    /// <summary>
    ///     Gets the count of all executed Sql statements.
    /// </summary>
    public int SqlCount { get; private set; }

    /// <summary>
    ///     Gets or sets a value indicating whether executed commands are recorded in <see cref="Commands" />.
    /// </summary>
    /// <remarks>
    ///     Setting this to <c>true</c> starts a new, empty recording; setting it to <c>false</c> discards
    ///     whatever was recorded.
    /// </remarks>
    internal bool LogCommands
    {
        get => _commands != null;
        set => _commands = value ? new List<CommandInfo>() : null;
    }

    /// <summary>
    ///     Gets the commands recorded since <see cref="LogCommands" /> was last set to <c>true</c>,
    ///     or <c>null</c> when recording is off.
    /// </summary>
    internal IEnumerable<CommandInfo>? Commands => _commands;

    /// <inheritdoc />
    /// <remarks>Returns 0 when no bulk insert provider was supplied.</remarks>
    public int BulkInsertRecords<T>(IEnumerable<T> records) =>
        _bulkSqlInsertProvider?.BulkInsertRecords(this, records) ?? 0;

    /// <summary>
    ///     Returns the <see cref="DatabaseSchemaResult" /> for the database.
    /// </summary>
    /// <remarks>
    ///     When no schema creator factory is available, an empty <see cref="DatabaseSchemaResult" /> is returned.
    /// </remarks>
    public DatabaseSchemaResult ValidateSchema()
    {
        DatabaseSchemaCreator? dbSchema = _databaseSchemaCreatorFactory?.Create(this);
        DatabaseSchemaResult? databaseSchemaValidationResult = dbSchema?.ValidateSchema();

        return databaseSchemaValidationResult ?? new DatabaseSchemaResult();
    }

    /// <summary>
    ///     Returns true if Umbraco database tables are detected to be installed.
    /// </summary>
    public bool IsUmbracoInstalled() => ValidateSchema().DetermineHasInstalledVersion();

    #endregion

    #region OnSomething

    protected override DbConnection OnConnectionOpened(DbConnection connection)
    {
        ArgumentNullException.ThrowIfNull(connection);

        // TODO: this should probably move to a SQL Server ProviderSpecificInterceptor.
#if DEBUG_DATABASES
        // determines the database connection SPID for debugging
        if (DatabaseType.IsSqlServer())
        {
            using (var command = connection.CreateCommand())
            {
                command.CommandText = "SELECT @@SPID";
                _spid = Convert.ToInt32(command.ExecuteScalar());
            }
        }
        else
        {
            // includes SqlCE
            _spid = 0;
        }
#endif

        return connection;
    }

#if DEBUG_DATABASES
    protected override void OnConnectionClosing(DbConnection conn)
    {
        _spid = -1;
        base.OnConnectionClosing(conn);
    }
#endif

    /// <summary>
    ///     Logs the exception (with this instance's id) and, at Debug level, the current stack trace and
    ///     the last SQL statement, then defers to the base implementation.
    /// </summary>
    protected override void OnException(Exception ex)
    {
        _logger.LogError(ex, "Exception ({InstanceId}).", InstanceId);

        // Environment.StackTrace and the SQL text are costly to build; only build them when Debug is enabled.
        if (_logger.IsEnabled(LogLevel.Debug))
        {
            _logger.LogDebug("At:\r\n{StackTrace}", Environment.StackTrace);

            // When tracing is on, every command's SQL is already logged as it executes (OnExecutingCommand),
            // so it is only logged here when tracing is off.
            if (EnableSqlTrace == false)
            {
                _logger.LogDebug("Sql:\r\n{Sql}", CommandToString(LastSQL, LastArgs));
            }
        }

        base.OnException(ex);
    }

    /// <summary>
    ///     Adjusts the command timeout and, when <see cref="EnableSqlTrace" /> is on, logs the command's SQL.
    /// </summary>
    protected override void OnExecutingCommand(DbCommand cmd)
    {
        // If no timeout has been specified (neither one-time nor database-wide) and the connection has a
        // timeout longer than 30 seconds, use the connection's timeout for the command.
        if (OneTimeCommandTimeout == 0 && CommandTimeout == 0 && cmd.Connection is { ConnectionTimeout: > 30 } connection)
        {
            cmd.CommandTimeout = connection.ConnectionTimeout;
        }

        if (EnableSqlTrace && _logger.IsEnabled(LogLevel.Debug))
        {
            // The SQL is a structured-logging argument, not part of the message template, so its braces
            // must not be escaped: escaping them would double every brace in the logged output.
            _logger.LogDebug("SQL Trace:\r\n{Sql}", CommandToString(cmd));
        }

#if DEBUG_DATABASES
        // detects whether the command is already in use (eg still has an open reader...)
        DatabaseDebugHelper.SetCommand(cmd, InstanceId + " [T" + System.Threading.Thread.CurrentThread.ManagedThreadId + "]");
        var refsobj = DatabaseDebugHelper.GetReferencedObjects(cmd.Connection);
        if (refsobj != null)
        {
            _logger.LogDebug("Oops!" + Environment.NewLine + refsobj);
        }
#endif

        base.OnExecutingCommand(cmd);
    }

    // Renders the command text with its (non-null) parameter values inlined, for logging.
    private string CommandToString(DbCommand cmd) =>
        CommandToString(
            cmd.CommandText,
            cmd.Parameters.Cast<DbParameter>().Select(x => x.Value).WhereNotNull().ToArray());

    // Renders SQL with its arguments inlined (via NPocoSqlExtensions.ToText), for logging.
    private string CommandToString(string? sql, object[]? args)
    {
        var text = new StringBuilder();
#if DEBUG_DATABASES
        text.Append(InstanceId);
        text.Append(": ");
#endif
        NPocoSqlExtensions.ToText(sql, args, text);
        return text.ToString();
    }

    /// <summary>
    ///     Updates <see cref="SqlCount" /> and the <see cref="Commands" /> recording, when enabled.
    /// </summary>
    protected override void OnExecutedCommand(DbCommand cmd)
    {
        if (_enableCount)
        {
            SqlCount++;
        }

        _commands?.Add(new CommandInfo(cmd));
        base.OnExecutedCommand(cmd);
    }

    #endregion

    /// <summary>
    ///     Snapshot of an executed command's text and parameters; used for tracking commands
    ///     when <see cref="LogCommands" /> is enabled.
    /// </summary>
    public class CommandInfo
    {
        public CommandInfo(IDbCommand cmd)
        {
            Text = cmd.CommandText;
            Parameters = cmd.Parameters
                .Cast<IDbDataParameter>()
                .Select(parameter => new ParameterInfo(parameter))
                .ToArray();
        }

        /// <summary>Gets the command text.</summary>
        public string Text { get; }

        /// <summary>Gets the command parameters, in collection order.</summary>
        public ParameterInfo[] Parameters { get; }
    }

    /// <summary>
    ///     Snapshot of a single command parameter; used for tracking commands.
    /// </summary>
    public class ParameterInfo
    {
        public ParameterInfo(IDbDataParameter parameter)
        {
            Name = parameter.ParameterName;
            Value = parameter.Value;
            DbType = parameter.DbType;
            Size = parameter.Size;
        }

        public string Name { get; }

        public object? Value { get; }

        public DbType DbType { get; }

        public int Size { get; }
    }

    /// <inheritdoc />
    public new T ExecuteScalar<T>(string sql, params object[] args)
        => ExecuteScalar<T>(new Sql(sql, args));

    /// <inheritdoc />
    public new T ExecuteScalar<T>(Sql sql)
        => ExecuteScalar<T>(sql.SQL, CommandType.Text, sql.Arguments);

    /// <inheritdoc />
    /// <remarks>
    ///     When the SQL syntax provides a scalar mapper for <typeparamref name="T" />, the raw scalar is read
    ///     as an object and converted through that mapper; otherwise NPoco's own conversion is used.
    ///     It would be nice if this were handled upstream in NPoco (see the corresponding NPoco GitHub issue).
    /// </remarks>
    public new T ExecuteScalar<T>(string sql, CommandType commandType, params object[] args)
    {
        if (SqlContext.SqlSyntax.ScalarMappers is null
            || !SqlContext.SqlSyntax.ScalarMappers.TryGetValue(typeof(T), out IScalarMapper? mapper))
        {
            return base.ExecuteScalar<T>(sql, commandType, args);
        }

        var result = base.ExecuteScalar<object>(sql, commandType, args);
        return (T)mapper.Map(result);
    }
}