diff options
Diffstat (limited to 'PrincipalProviders')
| -rw-r--r-- | PrincipalProviders/LocalPrincipalProvider.cs | 69 |
1 files changed, 46 insertions, 23 deletions
diff --git a/PrincipalProviders/LocalPrincipalProvider.cs b/PrincipalProviders/LocalPrincipalProvider.cs index 7bee800..8035ce8 100644 --- a/PrincipalProviders/LocalPrincipalProvider.cs +++ b/PrincipalProviders/LocalPrincipalProvider.cs @@ -1,4 +1,5 @@ using HyperBooru.Services; +using Microsoft.AspNetCore.Cryptography.KeyDerivation; using Microsoft.EntityFrameworkCore; namespace HyperBooru.PrincipalProviders; @@ -9,41 +10,63 @@ public class LocalPrincipalProvider : PrincipalProvider { public LocalPrincipalProvider(IDbContextFactory<HBContext> dbFactory) => this.dbFactory = dbFactory; - public override Principal? GetPrincipal(string name) { + public override IPrincipal? GetPrincipal(string name) { using var db = dbFactory.CreateDbContext(); + return db.Principals.FirstOrDefault(p => p.Name == name); + } - LocalPrincipal? principal = db.Principals.FirstOrDefault(p => p.Name == name); - if(principal is null) - return null; + public override IUser? GetUser(string name) { + using var db = dbFactory.CreateDbContext(); + return db.Users.FirstOrDefault(p => p.Name == name); + } - return principal; + public override IGroup? GetGroup(string name) { + using var db = dbFactory.CreateDbContext(); + return db.Groups.FirstOrDefault(p => p.Name == name); } - public override User? GetUser(string name) { + public override IGroup[] GetGroups(IPrincipal principal, bool recurse) { using var db = dbFactory.CreateDbContext(); - LocalUser? user = db.Users.FirstOrDefault(p => p.Name == name); - if(user is null) - return null; + List<LocalGroup> groups = db.Principals + .First(p => p.Sid == principal.Sid) + .MemberOf; - return user; - } + if(!recurse) + return groups.ToArray(); - public override Group? GetGroup(string name) { - using var db = dbFactory.CreateDbContext(); + var allGroups = db.Groups + .Include(g => g.MemberOf) + .ToArray(); - LocalGroup? group = db.Groups.FirstOrDefault(p => p.Name == name); - if(group is null) - return null; + groups = allGroups + .IntersectBy(groups.Select(g => g.Sid), g => g.Sid) + .ToList(); - return group; - } + while(true) { + var toAdd = groups + .SelectMany(g => g.MemberOf) + .ExceptBy(groups.Select(g => g.Sid), g => g.Sid) + .ToArray(); - public override Group[] GetGroups(Principal principal, bool recurse) { - throw new NotImplementedException(); - } + if(toAdd.Count() == 0) + break; + + groups.AddRange(toAdd); + } - public override bool ValidatePassword(User principal, string password) { - throw new NotImplementedException(); + return groups.ToArray(); } + + public override bool ValidatePassword(IUser user, string password) => + ((LocalUser) user).PasswordHash == HashPassword(password); + + public static string HashPassword(string password) => + Convert.ToBase64String( + KeyDerivation.Pbkdf2( + password, + Array.Empty<byte>(), + KeyDerivationPrf.HMACSHA512, + 100_000, + 512 / 8)); } |
