summaryrefslogtreecommitdiff
path: root/Services
diff options
context:
space:
mode:
Diffstat (limited to 'Services')
-rw-r--r--Services/OcrService.cs117
-rw-r--r--Services/SearchService.cs60
2 files changed, 169 insertions, 8 deletions
diff --git a/Services/OcrService.cs b/Services/OcrService.cs
new file mode 100644
index 0000000..2f65e43
--- /dev/null
+++ b/Services/OcrService.cs
@@ -0,0 +1,117 @@
+using HyperBooru.Util;
+using Microsoft.EntityFrameworkCore;
+using System.Diagnostics;
+using System.Runtime.InteropServices;
+using System.Text.RegularExpressions;
+using Tesseract;
+
+namespace HyperBooru.Services;
+
+/// <summary>
+/// Hosted background service that periodically runs Tesseract OCR over every
+/// image media item that has no OCR data yet and stores the recognised text
+/// (raw and in a normalised, searchable form) in the database.
+/// </summary>
+public class OcrService : IHostedService, IDisposable {
+    private static readonly TimeSpan ProcessInterval = TimeSpan.FromMinutes(30);
+    private static readonly TimeSpan StartupDelay = TimeSpan.FromSeconds(30);
+
+    // Matches every run of characters outside [0-9a-z]; used to collapse
+    // punctuation and whitespace in lower-cased OCR text into single spaces.
+    private static readonly Regex SpaceRegex = new(@"[^0-9a-z]+", RegexOptions.Compiled);
+
+    // The currently running (or last finished) OCR pass and its cancellation source.
+    private Task? task;
+    private CancellationTokenSource cts = new();
+
+    private readonly Timer timer;
+
+    private readonly IServiceScopeFactory scopeFactory;
+    private readonly ILogger<OcrService> logger;
+    private readonly IDbContextFactory<HBContext> dbFactory;
+
+    /// <summary>
+    /// Creates the service. The timer is created unscheduled; it only starts
+    /// firing once <see cref="StartAsync"/> has been called.
+    /// </summary>
+    public OcrService(
+        IServiceScopeFactory scopeFactory,
+        ILogger<OcrService> logger,
+        IDbContextFactory<HBContext> dbFactory) {
+
+        this.scopeFactory = scopeFactory;
+        this.logger = logger;
+        this.dbFactory = dbFactory;
+
+        timer = new((object? state) => {
+            // Never run two passes at once; skip this tick if one is still going.
+            if(task is not null && !task.IsCompleted)
+                return;
+            // The previous source is intentionally not disposed here: StopAsync
+            // may still be holding a reference to it. It owns no unmanaged
+            // resources, so letting the GC reclaim it is harmless.
+            cts = new();
+            task = ProcessAllAsync(cts.Token);
+        });
+    }
+
+    /// <summary>Schedules the first OCR pass after the startup delay.</summary>
+    public Task StartAsync(CancellationToken ct) {
+        logger.LogInformation("Service starting...");
+        timer.Change(StartupDelay, ProcessInterval);
+        return Task.CompletedTask;
+    }
+
+    /// <summary>
+    /// Stops the timer, cancels queued work and waits for the running pass to
+    /// wind down, or for <paramref name="ct"/> to signal that the graceful
+    /// shutdown period is over, whichever comes first.
+    /// </summary>
+    public async Task StopAsync(CancellationToken ct) {
+        logger.LogInformation("Service stopping...");
+        timer.Change(Timeout.Infinite, Timeout.Infinite);
+        cts.Cancel();
+
+        // Items that already started cannot be interrupted, so give them a
+        // chance to finish and save before the host tears everything down.
+        if(task is { IsCompleted: false } running)
+            await Task.WhenAny(running, Task.Delay(Timeout.Infinite, ct));
+    }
+
+    /// <summary>Releases the timer and the current cancellation source.</summary>
+    public void Dispose() {
+        timer.Dispose();
+        cts.Dispose();
+        GC.SuppressFinalize(this);
+    }
+
+    /// <summary>
+    /// Runs one OCR pass over all pending media items. Never throws: the
+    /// returned task is not awaited by the timer callback, so cancellation and
+    /// unexpected failures are logged here instead of being lost.
+    /// </summary>
+    async Task ProcessAllAsync(CancellationToken ct) {
+        try {
+            using var scope = scopeFactory.CreateScope();
+            var mediaService = scope.ServiceProvider
+                .GetRequiredService<IMediaService>();
+
+            Guid[] guids = FindPendingMedia();
+            logger.LogInformation(
+                "Performing OCR pass on {Count} media items", guids.Length);
+
+            var factory = new TaskFactory(new LimitedConcurrencyTaskScheduler());
+            var stopwatch = Stopwatch.StartNew();
+
+            // Passing ct to StartNew cancels items that have not started yet.
+            var tasks = guids
+                .Select(guid => factory.StartNew(
+                    () => ProcessSafely(guid, mediaService), ct))
+                .ToArray();
+
+            await Task.WhenAll(tasks);
+            stopwatch.Stop();
+
+            var time = stopwatch.Elapsed.ToStringHumanReadable();
+            logger.LogInformation(
+                "Performed OCR pass on {Count} media items in {Elapsed}",
+                guids.Length, time);
+        } catch(OperationCanceledException) {
+            logger.LogInformation("OCR pass cancelled");
+        } catch(Exception e) {
+            logger.LogError(e, "OCR pass failed");
+        }
+    }
+
+    /// <summary>
+    /// Returns the GUIDs of all media whose MIME type contains "image/" and
+    /// that have no OCR data yet.
+    /// </summary>
+    private Guid[] FindPendingMedia() {
+        using var db = dbFactory.CreateDbContext();
+        return db.Media
+            .Where(m => m.OcrData == null)
+            .Where(m => m.MimeType.Contains("image/"))
+            .Select(m => m.Guid)
+            .ToArray();
+    }
+
+    /// <summary>
+    /// Processes one item and logs (rather than propagates) any failure, so a
+    /// single unreadable image does not fault the whole pass. A failed item
+    /// keeps no OCR data and is therefore picked up again by the next pass.
+    /// </summary>
+    private void ProcessSafely(Guid media, IMediaService mediaService) {
+        try {
+            Process(media, mediaService);
+        } catch(Exception e) {
+            logger.LogError(e, "OCR failed for media item {Media}", media);
+        }
+    }
+
+    /// <summary>
+    /// Runs OCR on a single media item and saves the recognised text together
+    /// with its searchable form and a UTC timestamp.
+    /// </summary>
+    /// <param name="media">GUID of the media item to process.</param>
+    /// <param name="mediaService">Service used to resolve the item's file path.</param>
+    private void Process(Guid media, IMediaService mediaService) {
+        logger.LogDebug("Performing OCR on media item {Media}", media);
+
+        using var db = dbFactory.CreateDbContext();
+        var item = db.Media
+            .Include(m => m.OcrData)
+            .FirstOrDefault(m => m.Guid == media);
+
+        // The item may have been deleted since the pass collected its GUID.
+        if(item is null) {
+            logger.LogDebug("Media item {Media} no longer exists, skipping OCR", media);
+            return;
+        }
+
+        OcrData o = item.OcrData ?? new();
+
+        using var engine = new TesseractEngine("tessdata", "eng", EngineMode.Default);
+        using var image = Pix.LoadFromFile(mediaService.GetPath(item));
+        // Send Tesseract's debug output to the platform's null device.
+        engine.SetVariable("debug_file", NullFile);
+
+        using var page = engine.Process(image);
+
+        o.Timestamp = DateTime.UtcNow;
+        o.Text = page.GetText().Trim();
+        // Invariant lower-casing: the regex only keeps ASCII a-z, so a
+        // culture-specific mapping (e.g. Turkish I -> dotless i) would turn
+        // ordinary letters into separators.
+        o.SearchableText = SpaceRegex.Replace(o.Text.ToLowerInvariant(), " ").Trim();
+
+        item.OcrData = o;
+        db.SaveChanges();
+    }
+
+    // Path of the null device on the current operating system.
+    private static string NullFile =>
+        RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "NUL" : "/dev/null";
+}
diff --git a/Services/SearchService.cs b/Services/SearchService.cs
index e8e497d..bb2963d 100644
--- a/Services/SearchService.cs
+++ b/Services/SearchService.cs
@@ -24,33 +24,77 @@ public class SearchService : ISearchService {
query = query.ToLower();
+ int[] descriptionResults = SearchDescription(query);
+ int[] ocrResults = SearchOcr(query);
+
var matchedTag = db.TagDefinitions
.FirstOrDefault(td => td.Name.ToLower() == query);
int[] tags;
-
if(matchedTag is not null) {
tags = tagService
.TagsThatImply(matchedTag)
.Select(td => td.ObjectId)
.ToArray();
} else {
- // TODO: expand scope to all tags that imply
+ // TODO: Expand scope to all tags that imply
tags = db.TagDefinitions
.Where(td => td.Name.ToLower().Contains(query))
.Select(td => td.ObjectId)
.ToArray();
}
+ int[] tagResults = SearchTags(tags);
+
+ int[] mediaIds = descriptionResults
+ .Union(ocrResults)
+ .Union(tagResults)
+ .OrderDescending()
+ .ToArray();
+
return db.Media
.Include(m => m.Tags)
- .AsEnumerable()
- .Where(m => m.Tags.IntersectBy(tags, t => t.TagDefinitionId).Any())
- .Concat(db.Media
+ .Where(m => mediaIds.Contains(m.ObjectId))
+ .ToArray();
+ }
+
+ // TODO: Make asynchronous
+ // Returns the ObjectIds of all media carrying at least one tag whose
+ // TagDefinitionId is in `tags`. The filter is expressed so it can run in the
+ // database instead of loading every media row and its tags into memory, and
+ // the query runs directly on the calling thread rather than blocking on a
+ // Task.Run wrapper.
+ private int[] SearchTags(int[] tags) {
+ using var db = dbFactory.CreateDbContext();
+ return db.Media
+ .Where(m => m.Tags.Any(t => tags.Contains(t.TagDefinitionId)))
+ .Select(m => m.ObjectId)
+ .ToArray();
+ }
+
+ // TODO: Make asynchronous
+ // Returns the ObjectIds of all media whose short or long description contains
+ // `query`, compared case-insensitively by lower-casing both sides. Runs
+ // directly on the calling thread; wrapping a synchronous query in Task.Run
+ // and blocking on it only tied up a second thread.
+ private int[] SearchDescription(string query) {
+ using var db = dbFactory.CreateDbContext();
+ query = query.ToLower();
+ return db.Media
 .Where(m =>
 (m.ShortDescription != null && m.ShortDescription.ToLower().Contains(query)) ||
- (m.LongDescription != null && m.LongDescription.ToLower().Contains(query))))
- .DistinctBy(m => m.ObjectId)
- .ToArray();
+ (m.LongDescription != null && m.LongDescription.ToLower().Contains(query)))
+ .Select(m => m.ObjectId)
+ .ToArray();
+ }
+
+ // TODO: Make asynchronous
+ // Returns the ObjectIds of all media whose OCR text contains `query`.
+ // SearchableText is stored lower-cased with every run of non-alphanumeric
+ // characters collapsed to one space, so the query is normalised the same way
+ // first; otherwise any query containing punctuation could never match.
+ private int[] SearchOcr(string query) {
+ query = System.Text.RegularExpressions.Regex
+ .Replace(query.ToLowerInvariant(), "[^0-9a-z]+", " ")
+ .Trim();
+ // An empty needle would match every row, so treat it as "no results".
+ if(query.Length == 0)
+ return Array.Empty<int>();
+
+ using var db = dbFactory.CreateDbContext();
+ return db.OcrData
+ .Where(o => o.SearchableText.Contains(query))
+ .Select(o => o.Media.ObjectId)
+ .ToArray();
 }
}