diff options
Diffstat (limited to 'Services')
| -rw-r--r-- | Services/MediaService.cs | 8 | ||||
| -rw-r--r-- | Services/PrincipalProvider.cs | 9 | ||||
| -rw-r--r-- | Services/SecurityService.cs | 37 |
3 files changed, 32 insertions, 22 deletions
diff --git a/Services/MediaService.cs b/Services/MediaService.cs index c779d8e..abc026f 100644 --- a/Services/MediaService.cs +++ b/Services/MediaService.cs @@ -63,13 +63,13 @@ public class MediaService : IMediaService { .ThenInclude(t => t.TagDefinition) .First(m => m.Guid == media.Guid); var ingestTag = db.TagDefinitions - .First(td => td.Guid == HBContext.IngestTag); + .First(td => td.Guid == HBObjectGuid.IngestTag); if(ingest) { - if(!media.Tags.Select(t => t.TagDefinition.Guid).Contains(HBContext.IngestTag)) + if(!media.Tags.Select(t => t.TagDefinition.Guid).Contains(HBObjectGuid.IngestTag)) media.Tags.Add(new(ingestTag)); } else { - media.Tags.RemoveAll(t => t.TagDefinition.Guid == HBContext.IngestTag); + media.Tags.RemoveAll(t => t.TagDefinition.Guid == HBObjectGuid.IngestTag); } db.SaveChanges(); @@ -151,7 +151,7 @@ public class MediaService : IMediaService { if(media is null) { var ingestTagDef = db.TagDefinitions - .First(td => td.Guid == HBContext.IngestTag); + .First(td => td.Guid == HBObjectGuid.IngestTag); media = new() { CurrentUploadedFile = fileRecord, diff --git a/Services/PrincipalProvider.cs b/Services/PrincipalProvider.cs new file mode 100644 index 0000000..e75c6c7 --- /dev/null +++ b/Services/PrincipalProvider.cs @@ -0,0 +1,9 @@ +namespace HyperBooru.Services; + +public abstract class PrincipalProvider { + public abstract bool ValidatePassword(HBPrincipal principal, string password); + + public abstract HBPrincipal GetPrincipal(string username); + + public abstract Group[] GetAllGroups(HBPrincipal principal); +} diff --git a/Services/SecurityService.cs b/Services/SecurityService.cs index f0ebd70..f1444c1 100644 --- a/Services/SecurityService.cs +++ b/Services/SecurityService.cs @@ -7,7 +7,7 @@ namespace HyperBooru.Services; public class SecurityService { private IDbContextFactory<HBContext> dbFactory; - private MemoryCache<int, HBPrincipal> principalCache; + private MemoryCache<SidStruct, HBPrincipal> principalCache; private MemoryCache<int, Acl> aclCache; public SecurityService(IDbContextFactory<HBContext> dbFactory) { @@ -17,11 +17,11 @@ public class SecurityService { principalCache = new() { MaxItems = 10_000, MaxAge = TimeSpan.FromMinutes(10), - DataSource = (int id) => { + DataSource = (SidStruct sid) => { using var db = dbFactory.CreateDbContext(); return db.Principals .Include(p => p.MemberOf) - .FirstOrDefault(p => p.ObjectId == id); + .FirstOrDefault(p => p.Sid.SidStruct.Equals(sid)); } }; @@ -32,7 +32,7 @@ public class SecurityService { using var db = dbFactory.CreateDbContext(); return db.Acls .Include(a => a.Rules) - .FirstOrDefault(a => a.ObjectId == id); + .FirstOrDefault(a => a.AclId == id); } }; } @@ -66,26 +66,27 @@ public class SecurityService { if(acl is null) return ulong.MaxValue; - bool hasAllowRules = acl.Rules - .Any(r => r.Action == AclRuleAction.Allow); - - ulong permissions = hasAllowRules ? 0 : ulong.MaxValue; + ulong permissions = 0; var principals = GetGroupMemberShip(principal) .Cast<HBPrincipal>() .Concat(new[] { principal }) + .Select(p => p.Sid) .ToArray(); - acl.Rules.IntersectBy(principals, r => r.Principal); + var allowRules = acl.Rules.Where(r => r.Action == AclRuleAction.Allow); + var denyRules = acl.Rules.Where(r => r.Action == AclRuleAction.Deny); - foreach(var rule in acl.Rules) { + foreach(var rule in allowRules) { if(!principals.Contains(rule.Principal)) continue; + permissions |= rule.Permissions; + } - if(rule.Action == AclRuleAction.Allow) - permissions |= rule.Permissions; - else - permissions &= ~rule.Permissions; + foreach(var rule in denyRules) { + if(!principals.Contains(rule.Principal)) + continue; + permissions &= ~rule.Permissions; } return permissions; @@ -101,15 +102,15 @@ public class SecurityService { while(true) { var toAdd = groups .SelectMany(g => g.MemberOf) - .Select(g => g.ObjectId) - .Where(id => !groups.Select(g => g.ObjectId).Contains(id)) + .Select(g => g.Sid.SidStruct) + .Where(sid => !groups.Select(g => g.Sid.SidStruct).Contains(sid)) .ToArray(); if(toAdd.Count() == 0) break; - foreach(var id in toAdd) - groups.Add((Group) principalCache[id]); + foreach(var sid in toAdd) + groups.Add((Group) principalCache[sid]); } return groups; |
