namespace HyperBooru.Util;

/// <summary>
/// A <see cref="TaskScheduler"/> that executes tasks on the .NET thread pool while keeping
/// at most <see cref="MaximumConcurrencyLevel"/> worker delegates queued or running at once.
/// Queued tasks are started in FIFO order.
/// </summary>
public class LimitedConcurrencyTaskScheduler : TaskScheduler
{
    /// <summary>
    /// The scheduler whose worker loop is running on the current thread, or <see langword="null"/>
    /// when the current thread is not executing such a loop. Tracked per instance (not as a plain
    /// flag) so that one scheduler's worker thread never lets a different scheduler inline its
    /// tasks, which would let that other scheduler exceed its own concurrency limit.
    /// </summary>
    [ThreadStatic]
    private static LimitedConcurrencyTaskScheduler? currentThreadScheduler;

    /// <summary>Tasks waiting to run, oldest first. Also serves as the lock for all mutable scheduler state.</summary>
    private readonly LinkedList<Task> tasks = new();

    /// <summary>Upper bound on the number of worker delegates queued to or running on the thread pool.</summary>
    private readonly int maxConcurrency;

    /// <summary>Number of worker delegates currently queued to or running on the thread pool. Guarded by <see cref="tasks"/>.</summary>
    private int delegatesQueuedOrRunning;

    /// <inheritdoc/>
    public sealed override int MaximumConcurrencyLevel => maxConcurrency;

    /// <summary>Creates a scheduler limited to <see cref="Environment.ProcessorCount"/> concurrent workers.</summary>
    public LimitedConcurrencyTaskScheduler()
        : this(Environment.ProcessorCount)
    {
    }

    /// <summary>Creates a scheduler that runs at most <paramref name="maxConcurrency"/> tasks at a time.</summary>
    /// <param name="maxConcurrency">Maximum number of concurrent worker delegates; must be at least 1.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxConcurrency"/> is less than 1.</exception>
    public LimitedConcurrencyTaskScheduler(int maxConcurrency)
    {
        if (maxConcurrency < 1)
        {
            // The single-string constructor interprets its argument as the parameter *name*,
            // so pass the name, the offending value and the message explicitly.
            throw new ArgumentOutOfRangeException(nameof(maxConcurrency), maxConcurrency, "maxConcurrency must be greater than 0");
        }

        this.maxConcurrency = maxConcurrency;
    }

    /// <summary>
    /// Appends <paramref name="task"/> to the queue and, if fewer than
    /// <see cref="MaximumConcurrencyLevel"/> workers exist, starts another worker.
    /// </summary>
    protected sealed override void QueueTask(Task task)
    {
        lock (tasks)
        {
            tasks.AddLast(task);
            if (delegatesQueuedOrRunning < maxConcurrency)
            {
                delegatesQueuedOrRunning++;
                NotifyThreadPoolOfPendingWork();
            }
        }
    }

    /// <summary>Queues one worker delegate to the thread pool; the worker drains the task queue.</summary>
    private void NotifyThreadPoolOfPendingWork()
    {
        ThreadPool.UnsafeQueueUserWorkItem(_ => ProcessQueuedTasks(), null);
    }

    /// <summary>
    /// Worker loop: repeatedly dequeues the oldest task and executes it on the current thread,
    /// retiring this worker (decrementing <see cref="delegatesQueuedOrRunning"/>) once the queue is empty.
    /// </summary>
    private void ProcessQueuedTasks()
    {
        var previous = currentThreadScheduler;
        currentThreadScheduler = this;
        try
        {
            while (true)
            {
                Task item;
                lock (tasks)
                {
                    var head = tasks.First;
                    if (head is null)
                    {
                        // Decrement under the same lock QueueTask uses: a concurrent enqueue either
                        // happens before this check (so the task is picked up here) or afterwards
                        // (and then sees the lower count and starts a fresh worker).
                        delegatesQueuedOrRunning--;
                        return;
                    }

                    item = head.Value;
                    tasks.RemoveFirst();
                }

                TryExecuteTask(item);
            }
        }
        finally
        {
            currentThreadScheduler = previous;
        }
    }

    /// <summary>
    /// Runs <paramref name="task"/> synchronously, but only on a thread that is already executing
    /// this scheduler's worker loop, so inlining never adds concurrency beyond the limit.
    /// A previously queued task is inlined only if it can still be removed from the queue.
    /// </summary>
    protected sealed override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
    {
        if (!ReferenceEquals(currentThreadScheduler, this))
        {
            return false;
        }

        if (taskWasPreviouslyQueued && !TryDequeue(task))
        {
            return false;
        }

        return TryExecuteTask(task);
    }

    /// <summary>Removes <paramref name="task"/> from the queue if it has not been dequeued by a worker yet.</summary>
    protected sealed override bool TryDequeue(Task task)
    {
        lock (tasks)
        {
            return tasks.Remove(task);
        }
    }

    /// <summary>
    /// Returns a snapshot of the tasks currently waiting in the queue (used by debuggers).
    /// </summary>
    /// <exception cref="NotSupportedException">The queue lock could not be acquired immediately.</exception>
    protected sealed override IEnumerable<Task> GetScheduledTasks()
    {
        bool lockTaken = false;
        try
        {
            // Non-blocking acquire: a debugger may have frozen a thread that holds the lock.
            Monitor.TryEnter(tasks, ref lockTaken);
            if (!lockTaken)
            {
                throw new NotSupportedException();
            }

            // Copy while still holding the lock; handing out the live list would let callers
            // enumerate it concurrently with QueueTask/worker mutations.
            var snapshot = new Task[tasks.Count];
            tasks.CopyTo(snapshot, 0);
            return snapshot;
        }
        finally
        {
            if (lockTaken)
            {
                Monitor.Exit(tasks);
            }
        }
    }
}