summaryrefslogtreecommitdiff
path: root/PrincipalProviders/LocalPrincipalProvider.cs
diff options
context:
space:
mode:
Diffstat (limited to 'PrincipalProviders/LocalPrincipalProvider.cs')
-rw-r--r--PrincipalProviders/LocalPrincipalProvider.cs69
1 files changed, 46 insertions, 23 deletions
diff --git a/PrincipalProviders/LocalPrincipalProvider.cs b/PrincipalProviders/LocalPrincipalProvider.cs
index 7bee800..8035ce8 100644
--- a/PrincipalProviders/LocalPrincipalProvider.cs
+++ b/PrincipalProviders/LocalPrincipalProvider.cs
@@ -1,4 +1,5 @@
using HyperBooru.Services;
+using Microsoft.AspNetCore.Cryptography.KeyDerivation;
using Microsoft.EntityFrameworkCore;
namespace HyperBooru.PrincipalProviders;
@@ -9,41 +10,63 @@ public class LocalPrincipalProvider : PrincipalProvider {
public LocalPrincipalProvider(IDbContextFactory<HBContext> dbFactory) =>
this.dbFactory = dbFactory;
- public override Principal? GetPrincipal(string name) {
+ public override IPrincipal? GetPrincipal(string name) {
using var db = dbFactory.CreateDbContext();
+ return db.Principals.FirstOrDefault(p => p.Name == name);
+ }
- LocalPrincipal? principal = db.Principals.FirstOrDefault(p => p.Name == name);
- if(principal is null)
- return null;
+ public override IUser? GetUser(string name) {
+ using var db = dbFactory.CreateDbContext();
+ return db.Users.FirstOrDefault(p => p.Name == name);
+ }
- return principal;
+ public override IGroup? GetGroup(string name) {
+ using var db = dbFactory.CreateDbContext();
+ return db.Groups.FirstOrDefault(p => p.Name == name);
}
- public override User? GetUser(string name) {
+ public override IGroup[] GetGroups(IPrincipal principal, bool recurse) {
using var db = dbFactory.CreateDbContext();
- LocalUser? user = db.Users.FirstOrDefault(p => p.Name == name);
- if(user is null)
- return null;
+ List<LocalGroup> groups = db.Principals
+ .First(p => p.Sid == principal.Sid)
+ .MemberOf;
- return user;
- }
+ if(!recurse)
+ return groups.ToArray();
- public override Group? GetGroup(string name) {
- using var db = dbFactory.CreateDbContext();
+ var allGroups = db.Groups
+ .Include(g => g.MemberOf)
+ .ToArray();
- LocalGroup? group = db.Groups.FirstOrDefault(p => p.Name == name);
- if(group is null)
- return null;
+ groups = allGroups
+ .IntersectBy(groups.Select(g => g.Sid), g => g.Sid)
+ .ToList();
- return group;
- }
+ while(true) {
+ var toAdd = groups
+ .SelectMany(g => g.MemberOf)
+ .ExceptBy(groups.Select(g => g.Sid), g => g.Sid)
+ .ToArray();
- public override Group[] GetGroups(Principal principal, bool recurse) {
- throw new NotImplementedException();
- }
+ if(toAdd.Count() == 0)
+ break;
+
+ groups.AddRange(toAdd);
+ }
- public override bool ValidatePassword(User principal, string password) {
- throw new NotImplementedException();
+ return groups.ToArray();
}
+
+ public override bool ValidatePassword(IUser user, string password) =>
+ ((LocalUser) user).PasswordHash == HashPassword(password);
+
+ public static string HashPassword(string password) =>
+ Convert.ToBase64String(
+ KeyDerivation.Pbkdf2(
+ password,
+ Array.Empty<byte>(),
+ KeyDerivationPrf.HMACSHA512,
+ 100_000,
+ 512 / 8));
}