summaryrefslogtreecommitdiff
path: root/Services/OcrService.cs
blob: d43db2eafcf0e961fc67716158b1518eeba2d802 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
using HyperBooru.Util;
using Microsoft.EntityFrameworkCore;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;
using Tesseract;

namespace HyperBooru.Services;

/// <summary>
/// Hosted service that periodically runs OCR over image media items that do
/// not have OCR data yet and stores the recognised text with each item.
/// </summary>
/// <remarks>
/// A timer starts a pass <see cref="StartupDelay"/> after startup and then
/// every <see cref="ProcessInterval"/>. A tick is ignored while the previous
/// pass is still running. Nothing is scheduled unless
/// <c>IConfigService.EnableOcr</c> is set.
/// </remarks>
public class OcrService : IHostedService, IDisposable {
    // MIME types that are excluded from OCR even though they are images.
    private readonly string[] InvalidMimeTypes = [ "image/heic", "image/webp" ];

    private readonly TimeSpan ProcessInterval = TimeSpan.FromMinutes(30);
    private readonly TimeSpan StartupDelay    = TimeSpan.FromSeconds(30);

    // Collapses every run of characters outside [0-9a-z] into one separator.
    private readonly Regex SpaceRegex = new(@"[^0-9a-z]+", RegexOptions.Compiled);

    // Guards `task` and `cts`, which are touched by both the timer thread and
    // StopAsync/Dispose.
    private readonly object gate = new();

    private Task? task;
    private CancellationTokenSource cts = new();

    private readonly Timer timer;

    private readonly IConfigService               configService;
    private readonly IServiceScopeFactory         scopeFactory;
    private readonly ILogger<OcrService>          logger;
    private readonly IDbContextFactory<HBContext> dbFactory;

    /// <summary>
    /// Creates the service. The timer is created disarmed; it is only armed
    /// by <see cref="StartAsync"/>.
    /// </summary>
    public OcrService(
        IConfigService configService,
        IServiceScopeFactory scopeFactory,
        ILogger<OcrService> logger,
        IDbContextFactory<HBContext> dbFactory) {

        this.configService = configService;
        this.scopeFactory  = scopeFactory;
        this.logger        = logger;
        this.dbFactory     = dbFactory;

        timer = new(OnTimerTick);
    }

    /// <summary>
    /// Arms the timer when OCR is enabled in the configuration.
    /// </summary>
    public Task StartAsync(CancellationToken ct) {
        if(configService.EnableOcr) {
            logger.LogInformation("Service starting...");
            timer.Change(StartupDelay, ProcessInterval);
        }

        return Task.CompletedTask;
    }

    /// <summary>
    /// Disarms the timer, cancels the running pass (if any) and waits for it
    /// to wind down, or until <paramref name="ct"/> signals that the host is
    /// no longer willing to wait.
    /// </summary>
    public async Task StopAsync(CancellationToken ct) {
        logger.LogInformation("Service stopping...");
        timer.Change(Timeout.Infinite, Timeout.Infinite);

        Task? running;
        lock(gate) {
            cts.Cancel();
            running = task;
        }

        if(running is null)
            return;

        // WhenAny never throws, so a cancelled delay (host gave up waiting)
        // simply ends the wait without surfacing an exception.
        await Task.WhenAny(running, Task.Delay(Timeout.Infinite, ct));
    }

    /// <summary>
    /// Releases the timer and the current cancellation source.
    /// </summary>
    public void Dispose() {
        timer.Dispose();
        lock(gate)
            cts.Dispose();
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Timer callback: starts a new pass unless the previous one is still
    /// running.
    /// </summary>
    private void OnTimerTick(object? state) {
        lock(gate) {
            if(task is not null && !task.IsCompleted)
                return;

            // The previous pass has finished, so nothing observes the old
            // token any more and its source can be released.
            cts.Dispose();
            cts = new();

            // Task.Run returns immediately, so `task` is assigned before any
            // of the pass (including its database query) has executed and a
            // concurrent tick cannot start a second, overlapping pass.
            var token = cts.Token;
            task = Task.Run(() => RunPassAsync(token));
        }
    }

    /// <summary>
    /// Runs one pass and logs its outcome. Nothing awaits the stored task
    /// except <see cref="StopAsync"/>, so failures must be reported here or
    /// they would go unnoticed.
    /// </summary>
    private async Task RunPassAsync(CancellationToken ct) {
        try {
            await ProcessAllAsync(ct);
        } catch(OperationCanceledException) {
            logger.LogInformation("OCR pass cancelled");
        } catch(Exception ex) {
            logger.LogError(ex, "OCR pass failed");
        }
    }

    /// <summary>
    /// Collects the GUIDs of all image media items without OCR data and runs
    /// <see cref="Process"/> for each of them on a
    /// <c>LimitedConcurrencyTaskScheduler</c>.
    /// </summary>
    /// <param name="ct">
    /// Cancels the query and any work item that has not started yet. Items
    /// that are already being processed run to completion.
    /// </param>
    async Task ProcessAllAsync(CancellationToken ct) {
        using var scope = scopeFactory.CreateScope();
        var mediaService = scope.ServiceProvider
            .GetRequiredService<IMediaService>();

        // The context is only needed for this query; scope it so that it is
        // released before the (potentially long) OCR work begins.
        Guid[] guids;
        using(var db = dbFactory.CreateDbContext()) {
            guids = await db.Media
                .AsNoTracking()
                .Include(m => m.CurrentUploadedFile)
                .Include(m => m.OcrData)
                .Where(m => m.OcrData == null)
                .Where(m => m.CurrentUploadedFile!.MimeType.Contains("image/"))
                .Where(m => !InvalidMimeTypes.Contains(m.CurrentUploadedFile!.MimeType))
                .Select(m => m.Guid)
                .ToArrayAsync(ct);
        }

        logger.LogInformation(
            "Performing OCR pass on {Count} media items", guids.Length);

        var factory = new TaskFactory(new LimitedConcurrencyTaskScheduler());
        var tasks   = new List<Task>(guids.Length);

        var stopwatch = Stopwatch.StartNew();

        foreach(var guid in guids)
            tasks.Add(factory.StartNew(() => Process(guid, mediaService), ct));

        // Process handles its own failures, so this only throws when queued
        // work items were cancelled through `ct`.
        await Task.WhenAll(tasks);
        stopwatch.Stop();

        var time = stopwatch.Elapsed.ToStringHumanReadable();
        logger.LogInformation(
            "Performed OCR pass on {Count} media items in {Time}",
            guids.Length, time);
    }

    /// <summary>
    /// Runs OCR on a single media item and saves the recognised text together
    /// with a normalised, searchable form of it.
    /// </summary>
    /// <remarks>
    /// Failures are logged and swallowed so that one unreadable file does not
    /// fail the whole pass. A failed item keeps having no OCR data and is
    /// therefore picked up again by the next pass.
    /// </remarks>
    /// <param name="media">GUID of the media item to process.</param>
    /// <param name="mediaService">Used to resolve the item's file path.</param>
    private void Process(Guid media, IMediaService mediaService) {
        logger.LogDebug("Performing OCR on media item {Media}", media);

        try {
            using var db = dbFactory.CreateDbContext();
            var item = db.Media
                .Include(m => m.OcrData)
                .FirstOrDefault(m => m.Guid == media);

            // The item may have been deleted since the pass enumerated it.
            if(item is null) {
                logger.LogWarning(
                    "Skipping OCR for media item {Media}: it no longer exists",
                    media);
                return;
            }

            OcrData o = item.OcrData ?? new();

            using var engine = new TesseractEngine("tessdata", "eng", EngineMode.Default);
            using var image  = Pix.LoadFromFile(mediaService.GetPath(item));
            engine.SetVariable("debug_file", NullFile);

            // The page wraps native resources and must be disposed as well.
            using var page = engine.Process(image);

            o.Timestamp = DateTime.UtcNow;
            o.Text      = page.GetText().Trim();

            // Invariant lowering: a culture-sensitive ToLower() can yield
            // letters outside a-z (e.g. dotless i), which the regex would
            // then strip from the searchable text.
            o.SearchableText = SpaceRegex
                .Replace(o.Text.ToLowerInvariant(), " ")
                .Trim();

            item.OcrData = o;
            db.SaveChanges();
        } catch(Exception ex) {
            logger.LogError(ex, "OCR failed for media item {Media}", media);
        }
    }

    /// <summary>
    /// Path of the platform's null device: <c>NUL</c> on Windows,
    /// <c>/dev/null</c> elsewhere.
    /// </summary>
    private string NullFile =>
        RuntimeInformation.IsOSPlatform(OSPlatform.Windows)
            ? "NUL"
            : "/dev/null";
}