using Microsoft.Extensions.Logging;
using NPoco;
using Umbraco.Cms.Core;
using Umbraco.Cms.Core.Cache;
using Umbraco.Cms.Core.Models.Entities;
using Umbraco.Cms.Core.Persistence;
using Umbraco.Cms.Core.Persistence.Querying;
using Umbraco.Cms.Core.Scoping;
using Umbraco.Cms.Infrastructure.Persistence.Querying;
using Umbraco.Cms.Infrastructure.Scoping;
using Umbraco.Extensions;

namespace Umbraco.Cms.Infrastructure.Persistence.Repositories.Implement;

/// <summary>
///     Provides a base class to all <see cref="IEntity" /> based repositories.
/// </summary>
/// <typeparam name="TId">The type of the entity's unique identifier.</typeparam>
/// <typeparam name="TEntity">The type of the entity managed by this repository.</typeparam>
public abstract class EntityRepositoryBase<TId, TEntity> : RepositoryBase, IReadWriteQueryRepository<TId, TEntity>
    where TEntity : class, IEntity
{
    // Per-instance on purpose. The count delegate stored in these options closes over `this`
    // (AmbientScope, _hasIdQuery and the virtual PerformCount). A static field is shared by every
    // repository with the same <TId, TEntity> pair, so all of them would count through whichever
    // instance happened to be created first - possibly a different subclass with a different base
    // query - and that instance would stay reachable forever.
    private RepositoryCachePolicyOptions? _defaultOptions;
    private IRepositoryCachePolicy<TEntity, TId>? _cachePolicy;
    private IQuery<TEntity>? _hasIdQuery;

    /// <summary>
    ///     Initializes a new instance of the <see cref="EntityRepositoryBase{TId, TEntity}" /> class.
    /// </summary>
    /// <param name="scopeAccessor">Accessor for the ambient scope, passed to <see cref="RepositoryBase" />.</param>
    /// <param name="appCaches">The application caches, passed to <see cref="RepositoryBase" />.</param>
    /// <param name="logger">The logger; must not be null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="logger" /> is null.</exception>
    protected EntityRepositoryBase(
        IScopeAccessor scopeAccessor,
        AppCaches appCaches,
        ILogger<EntityRepositoryBase<TId, TEntity>> logger)
        : base(scopeAccessor, appCaches) =>
        Logger = logger ?? throw new ArgumentNullException(nameof(logger));

    /// <summary>
    ///     Gets the logger.
    /// </summary>
    protected ILogger<EntityRepositoryBase<TId, TEntity>> Logger { get; }

    /// <summary>
    ///     Gets the global isolated cache for <typeparamref name="TEntity" />, regardless of the ambient scope.
    /// </summary>
    protected IAppPolicyCache GlobalIsolatedCache => AppCaches.IsolatedCaches.GetOrCreate<TEntity>();

    /// <summary>
    ///     Gets the isolated cache for <typeparamref name="TEntity" />.
    /// </summary>
    /// <remarks>
    ///     Depends on the ambient scope cache mode: the global isolated cache for
    ///     <see cref="RepositoryCacheMode.Default" />, the scope's own isolated cache for
    ///     <see cref="RepositoryCacheMode.Scoped" />, and a no-op cache for <see cref="RepositoryCacheMode.None" />.
    /// </remarks>
    /// <exception cref="Exception">The ambient scope has any other cache mode.</exception>
    protected IAppPolicyCache IsolatedCache => AmbientScope.RepositoryCacheMode switch
    {
        RepositoryCacheMode.Default => AppCaches.IsolatedCaches.GetOrCreate<TEntity>(),
        RepositoryCacheMode.Scoped => AmbientScope.IsolatedCaches.GetOrCreate<TEntity>(),
        RepositoryCacheMode.None => NoAppCache.Instance,
        _ => throw new Exception("oops: cache mode."),
    };

    /// <summary>
    ///     Gets the default <see cref="RepositoryCachePolicyOptions" /> used by <see cref="CreateCachePolicy" />.
    /// </summary>
    /// <remarks>
    ///     The options' count delegate counts all <typeparamref name="TEntity" /> rows with a non-zero id
    ///     through <see cref="PerformCount" />, so cached results can be checked against the database.
    /// </remarks>
    protected virtual RepositoryCachePolicyOptions DefaultOptions => _defaultOptions ??= new RepositoryCachePolicyOptions(() =>
    {
        // Build the "has an id" query lazily, once per repository instance (no locking needed:
        // a race only builds an equivalent query twice).
        IQuery<TEntity> query = _hasIdQuery ??= AmbientScope.SqlContext.Query<TEntity>().Where(x => x.Id != 0);
        return PerformCount(query);
    });

    /// <summary>
    ///     Gets the repository cache policy.
    /// </summary>
    /// <remarks>
    ///     Returns a no-cache policy when <see cref="AppCaches.NoCache" /> is in use or the ambient scope's
    ///     cache mode is <see cref="RepositoryCacheMode.None" />; otherwise the policy created once by
    ///     <see cref="CreateCachePolicy" />.
    /// </remarks>
    /// <exception cref="Exception">The ambient scope has an unhandled cache mode.</exception>
    protected IRepositoryCachePolicy<TEntity, TId> CachePolicy
    {
        get
        {
            if (AppCaches == AppCaches.NoCache)
            {
                return NoCacheRepositoryCachePolicy<TEntity, TId>.Instance;
            }

            // create the cache policy using IsolatedCache which is either global
            // or scoped depending on the repository cache mode for the current scope
            switch (AmbientScope.RepositoryCacheMode)
            {
                case RepositoryCacheMode.Default:
                case RepositoryCacheMode.Scoped:
                    // return the same cache policy in both cases - the cache policy is
                    // supposed to pick either the global or scope cache depending on the
                    // scope cache mode
                    return _cachePolicy ??= CreateCachePolicy();
                case RepositoryCacheMode.None:
                    return NoCacheRepositoryCachePolicy<TEntity, TId>.Instance;
                default:
                    throw new Exception("oops: cache mode.");
            }
        }
    }

    /// <summary>
    ///     Adds or updates an entity of type <typeparamref name="TEntity" />.
    /// </summary>
    /// <param name="entity">The entity to persist.</param>
    /// <remarks>
    ///     Entities without an identity are created via <see cref="PersistNewItem" />; others are updated via
    ///     <see cref="PersistUpdatedItem" />. Both go through the <see cref="CachePolicy" />.
    /// </remarks>
    public virtual void Save(TEntity entity)
    {
        if (entity.HasIdentity == false)
        {
            CachePolicy.Create(entity, PersistNewItem);
        }
        else
        {
            CachePolicy.Update(entity, PersistUpdatedItem);
        }
    }

    /// <summary>
    ///     Deletes the passed-in entity through the <see cref="CachePolicy" />.
    /// </summary>
    /// <param name="entity">The entity to delete.</param>
    public virtual void Delete(TEntity entity) => CachePolicy.Delete(entity, PersistDeletedItem);

    /// <summary>
    ///     Gets an entity by the passed-in id, utilizing the repository's cache policy.
    /// </summary>
    /// <param name="id">The entity id.</param>
    /// <returns>The entity, or null if the cache policy resolves none.</returns>
    public TEntity? Get(TId? id) => CachePolicy.Get(id, PerformGet, PerformGetAll);

    /// <summary>
    ///     Gets all entities of type <typeparamref name="TEntity" />, or those matching the passed-in ids.
    /// </summary>
    /// <param name="ids">
    ///     The ids to fetch. A null array is passed through to the cache policy unchanged (used by the cache
    ///     policy as "get all"). Duplicate ids are removed before querying.
    /// </param>
    /// <returns>The matching entities.</returns>
    /// <remarks>
    ///     More than <see cref="Constants.Sql.MaxParameterCount" /> ids are fetched in batches of that size,
    ///     since a single SQL statement cannot take that many parameters.
    /// </remarks>
    public IEnumerable<TEntity> GetMany(params TId[]? ids)
    {
        // Handle null explicitly: the lifted comparison `ids?.Length <= max` is false for null, which used to
        // send a null array into the batching path below instead of to the cache policy's "get all".
        if (ids is null)
        {
            return CachePolicy.GetAll(null, PerformGetAll);
        }

        // ensure they are de-duplicated, easy win if people don't do this as this can cause many excess queries
        // TODO: consider also skipping default(TId) values (like a zero) in case of accidental calls with invalid ids
        ids = ids.Distinct().ToArray();

        if (ids.Length <= Constants.Sql.MaxParameterCount)
        {
            return CachePolicy.GetAll(ids, PerformGetAll);
        }

        // can't query more than MaxParameterCount ids at a time... but if someone is really querying that many
        // entities, the additional overhead of fetching them in groups is minimal compared to the lookup time of each group
        var entities = new List<TEntity>();
        foreach (IEnumerable<TId> group in ids.InGroupsOf(Constants.Sql.MaxParameterCount))
        {
            TEntity[] groupEntities = CachePolicy.GetAll(group.ToArray(), PerformGetAll);
            entities.AddRange(groupEntities);
        }

        return entities;
    }

    /// <summary>
    ///     Gets a list of entities by the passed-in query. Bypasses the cache policy.
    /// </summary>
    /// <param name="query">The query to execute.</param>
    /// <returns>The matching entities, with null entries removed.</returns>
    public IEnumerable<TEntity> Get(IQuery<TEntity> query) =>
        // ensure we don't include any null refs in the returned collection!
        PerformGetByQuery(query).WhereNotNull();

    /// <summary>
    ///     Returns a value indicating whether an entity with the passed id exists.
    /// </summary>
    /// <param name="id">The entity id.</param>
    /// <returns><c>true</c> if the cache policy reports the entity exists; otherwise <c>false</c>.</returns>
    public bool Exists(TId id) => CachePolicy.Exists(id, PerformExists, PerformGetAll);

    /// <summary>
    ///     Returns the count of entities matching the passed-in query. Bypasses the cache policy.
    /// </summary>
    /// <param name="query">The query to count.</param>
    /// <returns>The number of matching entities.</returns>
    public int Count(IQuery<TEntity> query) => PerformCount(query);

    /// <summary>
    ///     Gets the entity id for the <typeparamref name="TEntity" />.
    /// </summary>
    /// <param name="entity">The entity.</param>
    /// <returns>
    ///     The entity's <see cref="IEntity.Id" /> unboxed as <typeparamref name="TId" />. This cast throws when
    ///     <typeparamref name="TId" /> is not the runtime type of <c>Id</c>; such repositories must override it.
    /// </returns>
    protected virtual TId GetEntityId(TEntity entity) => (TId)(object)entity.Id;

    /// <summary>
    ///     Creates the repository cache policy.
    /// </summary>
    /// <returns>
    ///     A <see cref="DefaultRepositoryCachePolicy{TEntity, TId}" /> over <see cref="GlobalIsolatedCache" />
    ///     configured with <see cref="DefaultOptions" />.
    /// </returns>
    protected virtual IRepositoryCachePolicy<TEntity, TId> CreateCachePolicy() =>
        new DefaultRepositoryCachePolicy<TEntity, TId>(GlobalIsolatedCache, ScopeAccessor, DefaultOptions);

    /// <summary>Fetches a single entity from the database.</summary>
    protected abstract TEntity? PerformGet(TId? id);

    /// <summary>Fetches the entities with the given ids from the database (all entities when <paramref name="ids" /> is null or empty, per the cache policy's usage).</summary>
    protected abstract IEnumerable<TEntity> PerformGetAll(params TId[]? ids);

    /// <summary>Fetches the entities matching a query from the database.</summary>
    protected abstract IEnumerable<TEntity> PerformGetByQuery(IQuery<TEntity> query);

    /// <summary>Inserts a new entity into the database.</summary>
    protected abstract void PersistNewItem(TEntity item);

    /// <summary>Updates an existing entity in the database.</summary>
    protected abstract void PersistUpdatedItem(TEntity item);

    // TODO: obsolete, use QueryType instead everywhere like GetBaseQuery(QueryType queryType);

    /// <summary>Gets the base SELECT (or SELECT COUNT when <paramref name="isCount" /> is true) for this entity.</summary>
    protected abstract Sql<ISqlContext> GetBaseQuery(bool isCount);

    /// <summary>Gets the WHERE clause selecting a single entity by the <c>id</c> parameter.</summary>
    protected abstract string GetBaseWhereClause();

    /// <summary>Gets the DELETE statements, executed in order with an <c>id</c> parameter, that remove an entity.</summary>
    protected abstract IEnumerable<string> GetDeleteClauses();

    /// <summary>
    ///     Checks in the database whether an entity with the given id exists.
    /// </summary>
    /// <param name="id">The entity id, bound to the <c>id</c> parameter of <see cref="GetBaseWhereClause" />.</param>
    /// <returns><c>true</c> only when the count query returns exactly one row.</returns>
    protected virtual bool PerformExists(TId id)
    {
        Sql<ISqlContext> sql = GetBaseQuery(true);
        sql.Where(GetBaseWhereClause(), new { id });
        var count = Database.ExecuteScalar<int>(sql);
        return count == 1;
    }

    /// <summary>
    ///     Counts, in the database, the entities matching the query.
    /// </summary>
    /// <param name="query">The query, translated onto the count variant of <see cref="GetBaseQuery" />.</param>
    /// <returns>The scalar count returned by the database.</returns>
    protected virtual int PerformCount(IQuery<TEntity> query)
    {
        Sql<ISqlContext> sqlClause = GetBaseQuery(true);
        var translator = new SqlTranslator<TEntity>(sqlClause, query);
        Sql<ISqlContext> sql = translator.Translate();

        return Database.ExecuteScalar<int>(sql);
    }

    /// <summary>
    ///     Deletes an entity from the database by executing every statement from <see cref="GetDeleteClauses" />.
    /// </summary>
    /// <param name="entity">The entity; its id is bound as the <c>id</c> parameter of each statement.</param>
    /// <remarks>Sets <see cref="IEntity.DeleteDate" /> on the entity afterwards.</remarks>
    protected virtual void PersistDeletedItem(TEntity entity)
    {
        IEnumerable<string> deletes = GetDeleteClauses();
        foreach (var delete in deletes)
        {
            Database.Execute(delete, new { id = GetEntityId(entity) });
        }

        // NOTE(review): local time via DateTime.Now - presumably matches how other entity dates are stored; confirm.
        entity.DeleteDate = DateTime.Now;
    }
}