using Microsoft.Extensions.Logging;
using NPoco;
using Umbraco.Cms.Core;
using Umbraco.Cms.Core.Cache;
using Umbraco.Cms.Core.Models.Entities;
using Umbraco.Cms.Core.Persistence.Querying;
using Umbraco.Cms.Core.Persistence.Repositories;
using Umbraco.Cms.Core.Security;
using Umbraco.Cms.Infrastructure.Persistence.Dtos;
using Umbraco.Cms.Infrastructure.Persistence.Factories;
using Umbraco.Cms.Infrastructure.Persistence.Querying;
using Umbraco.Cms.Infrastructure.Scoping;
using Umbraco.Extensions;

namespace Umbraco.Cms.Infrastructure.Persistence.Repositories.Implement;

/// <summary>
///     Repository for external logins (<see cref="IIdentityUserLogin" />) and the tokens
///     (<see cref="IIdentityUserToken" />) attached to them, keyed by user or member key.
/// </summary>
internal class ExternalLoginRepository : EntityRepositoryBase<int, IIdentityUserLogin>, IExternalLoginWithKeyRepository
{
    public ExternalLoginRepository(IScopeAccessor scopeAccessor, AppCaches cache, ILogger<ExternalLoginRepository> logger)
        : base(scopeAccessor, cache, logger)
    {
    }

    /// <summary>
    ///     Queries user tokens.
    /// </summary>
    /// <param name="query">The token query to filter on, or <c>null</c> to return every token.</param>
    /// <returns>The matching tokens, built from the joined token/login rows.</returns>
    public IEnumerable<IIdentityUserToken> Get(IQuery<IIdentityUserToken>? query)
    {
        Sql<ISqlContext> sqlClause = GetBaseTokenQuery(false);

        var translator = new SqlTranslator<IIdentityUserToken>(sqlClause, query);
        Sql<ISqlContext> sql = translator.Translate();

        List<ExternalLoginTokenDto> dtos = Database.Fetch<ExternalLoginTokenDto>(sql);

        foreach (ExternalLoginTokenDto dto in dtos)
        {
            yield return ExternalLoginFactory.BuildEntity(dto);
        }
    }

    /// <summary>
    ///     Counts user tokens.
    /// </summary>
    /// <param name="query">The token query to filter on, or <c>null</c> to count every token.</param>
    /// <returns>The number of token rows matching <paramref name="query" />.</returns>
    public int Count(IQuery<IIdentityUserToken>? query)
    {
        // Count over the same token -> login join that Get(query) selects from, so the query's
        // where-clauses are applied to the same row set as Get. (Previously the query was ignored
        // and the external *login* table was counted instead of tokens.)
        Sql<ISqlContext> sqlClause = Sql()
            .SelectCount()
            .From<ExternalLoginTokenDto>()
            .InnerJoin<ExternalLoginDto>()
            .On<ExternalLoginTokenDto, ExternalLoginDto>(x => x.ExternalLoginId, x => x.Id);

        var translator = new SqlTranslator<IIdentityUserToken>(sqlClause, query);
        Sql<ISqlContext> sql = translator.Translate();

        return Database.ExecuteScalar<int>(sql);
    }

    /// <inheritdoc />
    public void DeleteUserLogins(Guid userOrMemberKey) =>
        Database.Delete<ExternalLoginDto>("WHERE userOrMemberKey=@userOrMemberKey", new { userOrMemberKey });

    /// <inheritdoc />
    /// <remarks>
    ///     Synchronises the stored logins of <paramref name="userOrMemberKey" /> with <paramref name="logins" />:
    ///     existing rows that match an incoming login (provider + key, case-insensitive) are updated, existing rows
    ///     with no match are deleted together with their tokens, and the remaining incoming logins are inserted.
    /// </remarks>
    public void Save(Guid userOrMemberKey, IEnumerable<IExternalLogin> logins)
    {
        // The user's existing logins, selected with an update lock hint.
        Sql<ISqlContext> sql = Sql()
            .Select<ExternalLoginDto>()
            .From<ExternalLoginDto>()
            .Where<ExternalLoginDto>(x => x.UserOrMemberKey == userOrMemberKey)
            .ForUpdate();

        // Deduplicate with the same case-insensitive (provider, key) comparison used below to match existing
        // rows. A tuple key also avoids the false collisions of a concatenated key ("ab" + "c" vs "a" + "bc").
        List<IExternalLogin> distinctLogins = logins
            .DistinctBy(x => (x.LoginProvider, x.ProviderKey), InvariantIgnoreCasePairComparer.Instance)
            .ToList();

        var toUpdate = new Dictionary<int, IExternalLogin>();
        var toDelete = new List<int>();
        var toInsert = new List<IExternalLogin>(distinctLogins);

        List<ExternalLoginDto> existingLogins = Database.Fetch<ExternalLoginDto>(sql);

        foreach (ExternalLoginDto existing in existingLogins)
        {
            IExternalLogin? found = distinctLogins.FirstOrDefault(x =>
                x.LoginProvider.Equals(existing.LoginProvider, StringComparison.InvariantCultureIgnoreCase)
                && x.ProviderKey.Equals(existing.ProviderKey, StringComparison.InvariantCultureIgnoreCase));

            if (found != null)
            {
                toUpdate.Add(existing.Id, found);

                // if it's an update then it's not an insert; use the same comparison as the lookup above
                toInsert.RemoveAll(x =>
                    x.LoginProvider.Equals(found.LoginProvider, StringComparison.InvariantCultureIgnoreCase)
                    && x.ProviderKey.Equals(found.ProviderKey, StringComparison.InvariantCultureIgnoreCase));
            }
            else
            {
                toDelete.Add(existing.Id);
            }
        }

        // do the deletes, updates and inserts
        if (toDelete.Count > 0)
        {
            // Before we can remove the external login, we must remove the external login tokens associated with that external login,
            // otherwise we'll get foreign key constraint errors
            Database.DeleteMany<ExternalLoginTokenDto>().Where(x => toDelete.Contains(x.ExternalLoginId)).Execute();
            Database.DeleteMany<ExternalLoginDto>().Where(x => toDelete.Contains(x.Id)).Execute();
        }

        foreach (KeyValuePair<int, IExternalLogin> u in toUpdate)
        {
            Database.Update(ExternalLoginFactory.BuildDto(userOrMemberKey, u.Value, u.Key));
        }

        Database.InsertBulk(toInsert.Select(i => ExternalLoginFactory.BuildDto(userOrMemberKey, i)));
    }

    /// <inheritdoc />
    /// <remarks>
    ///     Only tokens belonging to the login providers present in <paramref name="tokens" /> are considered:
    ///     existing tokens of those providers that match an incoming token (provider + name, case-insensitive) are
    ///     updated, the others are deleted, and the remaining incoming tokens are inserted.
    /// </remarks>
    /// <exception cref="InvalidOperationException">
    ///     A token must be inserted for a login provider that the user has no external login for.
    /// </exception>
    public void Save(Guid userOrMemberKey, IEnumerable<IExternalLoginToken> tokens)
    {
        // get the existing logins (provider + id)
        var existingUserLogins = Database
            .Fetch<ExternalLoginDto>(GetBaseQuery(false)
                .Where<ExternalLoginDto>(x => x.UserOrMemberKey == userOrMemberKey))
            .ToDictionary(x => x.LoginProvider, x => x.Id);

        // Deduplicate with the same case-insensitive (provider, name) comparison used below to match existing
        // tokens; a tuple key avoids the false collisions of a concatenated key.
        List<IExternalLoginToken> distinctTokens = tokens
            .DistinctBy(x => (x.LoginProvider, x.Name), InvariantIgnoreCasePairComparer.Instance)
            .ToList();

        var providers = distinctTokens.Select(x => x.LoginProvider).Distinct().ToList();

        Sql<ISqlContext> sql = GetBaseTokenQuery(true)
            .WhereIn<ExternalLoginDto>(x => x.LoginProvider, providers)
            .Where<ExternalLoginDto>(x => x.UserOrMemberKey == userOrMemberKey);

        var toUpdate = new Dictionary<int, (IExternalLoginToken externalLoginToken, int externalLoginId)>();
        var toDelete = new List<int>();
        var toInsert = new List<IExternalLoginToken>(distinctTokens);

        List<ExternalLoginTokenDto> existingTokens = Database.Fetch<ExternalLoginTokenDto>(sql);

        foreach (ExternalLoginTokenDto existing in existingTokens)
        {
            IExternalLoginToken? found = distinctTokens.FirstOrDefault(x =>
                x.LoginProvider.InvariantEquals(existing.ExternalLoginDto.LoginProvider)
                && x.Name.InvariantEquals(existing.Name));

            if (found != null)
            {
                toUpdate.Add(existing.Id, (found, existing.ExternalLoginId));

                // if it's an update then it's not an insert
                toInsert.RemoveAll(x =>
                    x.LoginProvider.InvariantEquals(found.LoginProvider) && x.Name.InvariantEquals(found.Name));
            }
            else
            {
                toDelete.Add(existing.Id);
            }
        }

        // do the deletes, updates and inserts
        if (toDelete.Count > 0)
        {
            Database.DeleteMany<ExternalLoginTokenDto>().Where(x => toDelete.Contains(x.Id)).Execute();
        }

        foreach (KeyValuePair<int, (IExternalLoginToken externalLoginToken, int externalLoginId)> u in toUpdate)
        {
            Database.Update(ExternalLoginFactory.BuildDto(u.Value.externalLoginId, u.Value.externalLoginToken, u.Key));
        }

        var insertDtos = new List<ExternalLoginTokenDto>();
        foreach (IExternalLoginToken t in toInsert)
        {
            if (!existingUserLogins.TryGetValue(t.LoginProvider, out var externalLoginId))
            {
                throw new InvalidOperationException(
                    $"A token was attempted to be saved for login provider {t.LoginProvider} which is not assigned to this user");
            }

            insertDtos.Add(ExternalLoginFactory.BuildDto(externalLoginId, t));
        }

        Database.InsertBulk(insertDtos);
    }

    /// <summary>
    ///     Loads a single external login by id, or returns <c>null</c> when no row exists.
    /// </summary>
    protected override IIdentityUserLogin? PerformGet(int id)
    {
        Sql<ISqlContext> sql = GetBaseQuery(false);
        sql.Where(GetBaseWhereClause(), new { id });

        ExternalLoginDto? dto = Database.Fetch<ExternalLoginDto>(SqlSyntax.SelectTop(sql, 1)).FirstOrDefault();
        if (dto == null)
        {
            return null;
        }

        IIdentityUserLogin entity = ExternalLoginFactory.BuildEntity(dto);

        // reset dirty initial properties (U4-1946)
        entity.ResetDirtyProperties(false);

        return entity;
    }

    /// <summary>
    ///     Loads the given external logins by id, or all of them (newest first) when no ids are supplied.
    /// </summary>
    protected override IEnumerable<IIdentityUserLogin> PerformGetAll(params int[]? ids)
    {
        if (ids is { Length: > 0 })
        {
            return PerformGetAllOnIds(ids);
        }

        Sql<ISqlContext> sql = GetBaseQuery(false).OrderByDescending<ExternalLoginDto>(x => x.CreateDate);

        return ConvertFromDtos(Database.Fetch<ExternalLoginDto>(sql))
            .ToArray(); // we don't want to re-iterate again!
    }

    /// <summary>
    ///     Loads the external logins matching <paramref name="query" />.
    /// </summary>
    protected override IEnumerable<IIdentityUserLogin> PerformGetByQuery(IQuery<IIdentityUserLogin> query)
    {
        Sql<ISqlContext> sqlClause = GetBaseQuery(false);

        var translator = new SqlTranslator<IIdentityUserLogin>(sqlClause, query);
        Sql<ISqlContext> sql = translator.Translate();

        List<ExternalLoginDto> dtos = Database.Fetch<ExternalLoginDto>(sql);

        foreach (ExternalLoginDto dto in dtos)
        {
            yield return ExternalLoginFactory.BuildEntity(dto);
        }
    }

    /// <summary>
    ///     Resolves each id through <c>Get(id)</c>, skipping ids that do not resolve to a login.
    /// </summary>
    private IEnumerable<IIdentityUserLogin> PerformGetAllOnIds(params int[] ids)
    {
        if (ids.Length == 0)
        {
            yield break;
        }

        foreach (var id in ids)
        {
            IIdentityUserLogin? identityUserLogin = Get(id);
            if (identityUserLogin is not null)
            {
                yield return identityUserLogin;
            }
        }
    }

    /// <summary>
    ///     Builds entities from DTOs and clears their dirty-tracking state.
    /// </summary>
    private IEnumerable<IIdentityUserLogin> ConvertFromDtos(IEnumerable<ExternalLoginDto> dtos)
    {
        foreach (IIdentityUserLogin entity in dtos.Select(ExternalLoginFactory.BuildEntity))
        {
            // reset dirty initial properties (U4-1946)
            ((BeingDirtyBase)entity).ResetDirtyProperties(false);

            yield return entity;
        }
    }

    /// <summary>
    ///     Builds <c>SELECT COUNT</c> or <c>SELECT *</c> over the external login table.
    /// </summary>
    protected override Sql<ISqlContext> GetBaseQuery(bool isCount)
    {
        Sql<ISqlContext> sql = Sql();
        if (isCount)
        {
            sql.SelectCount();
        }
        else
        {
            sql.SelectAll();
        }

        sql.From<ExternalLoginDto>();
        return sql;
    }

    protected override string GetBaseWhereClause() => $"{Constants.DatabaseSchema.Tables.ExternalLogin}.id = @id";

    /// <summary>
    ///     Statements run (in order, with <c>@id</c> bound to the login id) when an external login is deleted.
    /// </summary>
    protected override IEnumerable<string> GetDeleteClauses()
    {
        var list = new List<string>
        {
            // Tokens reference the login; remove them first to avoid foreign key constraint errors
            // (the same ordering Save(logins) uses when it deletes logins).
            $"DELETE FROM {Constants.DatabaseSchema.Tables.ExternalLoginToken} WHERE externalLoginId = @id",
            $"DELETE FROM {Constants.DatabaseSchema.Tables.ExternalLogin} WHERE id = @id",
        };
        return list;
    }

    protected override void PersistNewItem(IIdentityUserLogin entity)
    {
        entity.AddingEntity();

        ExternalLoginDto dto = ExternalLoginFactory.BuildDto(entity);

        var id = Convert.ToInt32(Database.Insert(dto));
        entity.Id = id;

        entity.ResetDirtyProperties();
    }

    protected override void PersistUpdatedItem(IIdentityUserLogin entity)
    {
        entity.UpdatingEntity();

        ExternalLoginDto dto = ExternalLoginFactory.BuildDto(entity);

        Database.Update(dto);

        entity.ResetDirtyProperties();
    }

    /// <summary>
    ///     Builds the token query joined to its owning login. With <paramref name="forUpdate" /> the token rows are
    ///     selected with an update lock hint and the login is loaded as a reference; otherwise only the login's
    ///     provider and user/member key columns are added to the selection.
    /// </summary>
    private Sql<ISqlContext> GetBaseTokenQuery(bool forUpdate) => forUpdate
        ? Sql()
            .Select<ExternalLoginTokenDto>(r => r.Select(x => x.ExternalLoginDto))
            .From<ExternalLoginTokenDto>()
            .AppendForUpdateHint() // ensure these table values are locked for updates, the ForUpdate ext method does not work here
            .InnerJoin<ExternalLoginDto>()
            .On<ExternalLoginTokenDto, ExternalLoginDto>(x => x.ExternalLoginId, x => x.Id)
        : Sql()
            .Select<ExternalLoginTokenDto>()
            .AndSelect<ExternalLoginDto>(x => x.LoginProvider, x => x.UserOrMemberKey)
            .From<ExternalLoginTokenDto>()
            .InnerJoin<ExternalLoginDto>()
            .On<ExternalLoginTokenDto, ExternalLoginDto>(x => x.ExternalLoginId, x => x.Id);

    /// <summary>
    ///     Compares two-part string keys part by part with <see cref="StringComparer.InvariantCultureIgnoreCase" />,
    ///     the comparison used when matching incoming logins against existing rows.
    /// </summary>
    private sealed class InvariantIgnoreCasePairComparer : IEqualityComparer<(string, string)>
    {
        public static readonly InvariantIgnoreCasePairComparer Instance = new();

        private static readonly StringComparer PartComparer = StringComparer.InvariantCultureIgnoreCase;

        public bool Equals((string, string) x, (string, string) y) =>
            PartComparer.Equals(x.Item1, y.Item1) && PartComparer.Equals(x.Item2, y.Item2);

        public int GetHashCode((string, string) obj) =>
            HashCode.Combine(PartComparer.GetHashCode(obj.Item1), PartComparer.GetHashCode(obj.Item2));
    }
}