From b286a0b0f1fcdb511d2dbb8886039cfb0182c89b Mon Sep 17 00:00:00 2001 From: Jake Mannens Date: Fri, 1 Sep 2023 13:03:57 +1000 Subject: Merged OCR functionality --- Services/OcrService.cs | 117 ++++++++++++++++++++++++++++++++++++++++++++++ Services/SearchService.cs | 60 ++++++++++++++++++++---- 2 files changed, 169 insertions(+), 8 deletions(-) create mode 100644 Services/OcrService.cs (limited to 'Services') diff --git a/Services/OcrService.cs b/Services/OcrService.cs new file mode 100644 index 0000000..2f65e43 --- /dev/null +++ b/Services/OcrService.cs @@ -0,0 +1,117 @@ +using HyperBooru.Util; +using Microsoft.EntityFrameworkCore; +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Text.RegularExpressions; +using Tesseract; + +namespace HyperBooru.Services; + +public class OcrService : IHostedService { + private readonly TimeSpan ProcessInterval = TimeSpan.FromMinutes(30); + private readonly TimeSpan StartupDelay = TimeSpan.FromSeconds(30); + + private readonly Regex SpaceRegex = new(@"[^0-9a-z]+", RegexOptions.Compiled); + + private Task? task; + private CancellationTokenSource cts = new(); + + private Timer timer; + + private IServiceScopeFactory scopeFactory; + private ILogger logger; + private IDbContextFactory dbFactory; + + public OcrService( + IServiceScopeFactory scopeFactory, + ILogger logger, + IDbContextFactory dbFactory) { + + this.scopeFactory = scopeFactory; + this.logger = logger; + this.dbFactory = dbFactory; + + timer = new((object? state) => { + if(task is not null && !task.IsCompleted) + return; + cts = new(); + task = ProcessAllAsync(cts.Token); + }); + } + + public Task StartAsync(CancellationToken ct) { + logger.LogInformation("Service starting..."); + timer.Change(StartupDelay, ProcessInterval); + return Task.CompletedTask; + } + + public Task StopAsync(CancellationToken ct) { + logger.LogInformation("Service stopping..."); + timer.Change(Timeout.Infinite, Timeout.Infinite); + cts.Cancel(); + return Task.CompletedTask; + } + + async Task ProcessAllAsync(CancellationToken ct) { + using var scope = scopeFactory.CreateScope(); + var mediaService = scope.ServiceProvider + .GetRequiredService(); + + using var db = dbFactory.CreateDbContext(); + Guid[] guids = db.Media + .Include(m => m.OcrData) + .Where(m => m.OcrData == null) + .Where(m => m.MimeType.Contains("image/")) + .Select(m => m.Guid) + .ToArray(); + db.Dispose(); + + logger.LogInformation($"Performing OCR pass on {guids.Count()} media items"); + + var factory = new TaskFactory(new LimitedConcurrencyTaskScheduler()); + var tasks = new List(); + + var stopwatch = Stopwatch.StartNew(); + + foreach(var guid in guids) + tasks.Add(factory.StartNew(() => Process(guid, mediaService), ct)); + + await Task.WhenAll(tasks); + stopwatch.Stop(); + + var time = stopwatch.Elapsed.ToStringHumanReadable(); + logger.LogInformation( + $"Performed OCR pass on {guids.Count()} media items in {time}"); + } + + private void Process(Guid media, IMediaService mediaService) { + logger.LogDebug($"Performing OCR on media item {media}"); + + using var db = dbFactory.CreateDbContext(); + var m = db.Media + .Include(m => m.OcrData) + .First(m => m.Guid == media); + + OcrData o = m.OcrData ?? new(); + + using var engine = new TesseractEngine("tessdata", "eng", EngineMode.Default); + using var image = Pix.LoadFromFile(mediaService.GetPath(m)); + engine.SetVariable("debug_file", NullFile); + + o.Timestamp = DateTime.UtcNow; + o.Text = engine.Process(image).GetText().Trim(); + o.SearchableText = SpaceRegex.Replace(o.Text.ToLower(), " ").Trim(); + + m.OcrData = o; + db.SaveChanges(); + } + + private string NullFile { + get { + if(RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + return "NUL"; + else + return "/dev/null"; + } + } +} diff --git a/Services/SearchService.cs b/Services/SearchService.cs index e8e497d..bb2963d 100644 --- a/Services/SearchService.cs +++ b/Services/SearchService.cs @@ -24,33 +24,77 @@ public class SearchService : ISearchService { query = query.ToLower(); + int[] descriptionResults = SearchDescription(query); + int[] ocrResults = SearchOcr(query); + var matchedTag = db.TagDefinitions .FirstOrDefault(td => td.Name.ToLower() == query); int[] tags; - if(matchedTag is not null) { tags = tagService .TagsThatImply(matchedTag) .Select(td => td.ObjectId) .ToArray(); } else { - // TODO: expand scope to all tags that imply + // TODO: Expand scope to all tags that imply tags = db.TagDefinitions .Where(td => td.Name.ToLower().Contains(query)) .Select(td => td.ObjectId) .ToArray(); } + int[] tagResults = SearchTags(tags); + + int[] mediaIds = descriptionResults + .Union(ocrResults) + .Union(tagResults) + .OrderDescending() + .ToArray(); + return db.Media .Include(m => m.Tags) - .AsEnumerable() - .Where(m => m.Tags.IntersectBy(tags, t => t.TagDefinitionId).Any()) - .Concat(db.Media + .Where(m => mediaIds.Contains(m.ObjectId)) + .ToArray(); + } + + // TODO: Make asynchronous + private int[] SearchTags(int[] tags) { + return Task.Run(() => { + using var db = dbFactory.CreateDbContext(); + return db.Media + .Include(m => m.Tags) + .AsEnumerable() + .Where(m => m.Tags.IntersectBy(tags, t => t.TagDefinitionId).Any()) + .Select(m => m.ObjectId) + .ToArray(); + }).GetAwaiter().GetResult(); + } + + // TODO: Make asynchronous + private int[] SearchDescription(string query) { + return Task.Run(() => { + using var db = dbFactory.CreateDbContext(); + query = query.ToLower(); + return db.Media .Where(m => (m.ShortDescription != null && m.ShortDescription.ToLower().Contains(query)) || - (m.LongDescription != null && m.LongDescription.ToLower().Contains(query)))) - .DistinctBy(m => m.ObjectId) - .ToArray(); + (m.LongDescription != null && m.LongDescription.ToLower().Contains(query))) + .Select(m => m.ObjectId) + .ToArray(); + }).GetAwaiter().GetResult(); + } + + // TODO: Make asynchronous + private int[] SearchOcr(string query) { + return Task.Run(() => { + using var db = dbFactory.CreateDbContext(); + query = query.ToLower(); + return db.OcrData + .Include(o => o.Media) + .Where(o => o.SearchableText.Contains(query)) + .Select(o => o.Media.ObjectId) + .ToArray(); + }).GetAwaiter().GetResult(); } } -- cgit v1.3