summary refs log tree commit diff
path: root/Services/OcrService.cs
blob: 4d217054d9985565677bab216da3584f7bdd6edf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
using HyperBooru.Util;
using Microsoft.EntityFrameworkCore;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;
using Tesseract;

namespace HyperBooru.Services;

/// <summary>
/// Hosted background service that periodically runs Tesseract OCR over image
/// media that has no <see cref="OcrData"/> yet and stores the recognised text
/// (raw and in a normalised, searchable form) back into the database.
/// </summary>
public class OcrService : IHostedService, IDisposable {
    // Time between two OCR passes, and the delay before the first pass after startup.
    private static readonly TimeSpan ProcessInterval = TimeSpan.FromMinutes(30);
    private static readonly TimeSpan StartupDelay    = TimeSpan.FromSeconds(30);

    // Collapses every run of characters outside [0-9a-z] into a single space.
    private static readonly Regex SpaceRegex = new(@"[^0-9a-z]+", RegexOptions.Compiled);

    // The pass that is currently running (or the last one that ran), if any.
    private Task? task;
    private CancellationTokenSource cts = new();

    private readonly Timer timer;

    private readonly IServiceScopeFactory         scopeFactory;
    private readonly ILogger<OcrService>          logger;
    private readonly IDbContextFactory<HBContext> dbFactory;

    /// <summary>
    /// Creates the service. The timer is created idle; it is only armed by
    /// <see cref="StartAsync"/>.
    /// </summary>
    public OcrService(
        IServiceScopeFactory scopeFactory,
        ILogger<OcrService> logger,
        IDbContextFactory<HBContext> dbFactory) {

        this.scopeFactory = scopeFactory;
        this.logger       = logger;
        this.dbFactory    = dbFactory;

        timer = new((object? state) => {
            // Skip this tick while the previous pass is still running.
            if(task is not null && !task.IsCompleted)
                return;
            cts  = new();
            task = ProcessAllAsync(cts.Token);
        });
    }

    /// <summary>Arms the timer: first pass after the startup delay, then periodically.</summary>
    public Task StartAsync(CancellationToken ct) {
        logger.LogInformation("Service starting...");
        timer.Change(StartupDelay, ProcessInterval);
        return Task.CompletedTask;
    }

    /// <summary>Stops the timer and requests cancellation of the running pass.</summary>
    public Task StopAsync(CancellationToken ct) {
        logger.LogInformation("Service stopping...");
        timer.Change(Timeout.Infinite, Timeout.Infinite);
        cts.Cancel();
        return Task.CompletedTask;
    }

    /// <summary>Releases the timer and the current cancellation token source.</summary>
    public void Dispose() {
        timer.Dispose();
        cts.Dispose();
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Runs one OCR pass over every image media item that has no OCR data yet.
    /// </summary>
    /// <param name="ct">Token used to cancel work items that have not started yet.</param>
    async Task ProcessAllAsync(CancellationToken ct) {
        // NOTE(review): the pass is short-circuited here, so everything below is
        // currently unreachable. This looks like a deliberate kill switch and is
        // preserved as-is - confirm before removing.
        return;

        using var scope = scopeFactory.CreateScope();
        var mediaService = scope.ServiceProvider
            .GetRequiredService<IMediaService>();

        // Block-scoped so the context is released (exactly once) before the
        // long-running OCR work starts.
        Guid[] guids;
        using(var db = dbFactory.CreateDbContext()) {
            guids = db.Media
                .Include(m => m.CurrentUploadedFile)
                .Include(m => m.OcrData)
                .Where(m => m.OcrData == null)
                .Where(m => m.CurrentUploadedFile.MimeType.Contains("image/"))
                .Select(m => m.Guid)
                .ToArray();
        }

        logger.LogInformation("Performing OCR pass on {Count} media items", guids.Length);

        var factory = new TaskFactory(new LimitedConcurrencyTaskScheduler());
        var tasks   = new List<Task>(guids.Length);

        var stopwatch = Stopwatch.StartNew();

        foreach(var guid in guids)
            tasks.Add(factory.StartNew(() => Process(guid, mediaService), ct));

        // Nobody awaits the task returned by this method, so failures have to be
        // reported here or they would be lost silently.
        try {
            await Task.WhenAll(tasks);
        } catch(OperationCanceledException) {
            logger.LogInformation("OCR pass cancelled");
            return;
        } catch(Exception ex) {
            logger.LogError(ex, "OCR pass failed");
            return;
        }
        stopwatch.Stop();

        var time = stopwatch.Elapsed.ToStringHumanReadable();
        logger.LogInformation(
            "Performed OCR pass on {Count} media items in {Elapsed}", guids.Length, time);
    }

    /// <summary>
    /// Runs OCR on a single media item and saves the result as its OCR data.
    /// </summary>
    /// <param name="media">Guid of the media item to process.</param>
    /// <param name="mediaService">Service used to resolve the item's file path.</param>
    private void Process(Guid media, IMediaService mediaService) {
        logger.LogDebug("Performing OCR on media item {Media}", media);

        using var db = dbFactory.CreateDbContext();
        var item = db.Media
            .Include(x => x.OcrData)
            .First(x => x.Guid == media);

        OcrData ocr = item.OcrData ?? new();

        using var engine = new TesseractEngine("tessdata", "eng", EngineMode.Default);
        using var image  = Pix.LoadFromFile(mediaService.GetPath(item));
        // Point Tesseract's debug output at the platform's null device.
        engine.SetVariable("debug_file", NullFile);

        ocr.Timestamp      = DateTime.UtcNow;
        ocr.Text           = engine.Process(image).GetText().Trim();
        // Invariant lowercasing: the regex only keeps ASCII [0-9a-z], so a
        // culture-specific mapping (e.g. Turkish 'I' -> dotless i) would drop letters.
        ocr.SearchableText = SpaceRegex.Replace(ocr.Text.ToLowerInvariant(), " ").Trim();

        item.OcrData = ocr;
        db.SaveChanges();
    }

    /// <summary>Path of the null device on the current operating system.</summary>
    private static string NullFile =>
        RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "NUL" : "/dev/null";
}