From 24ed99d0189c2ee5975bdffc15c804b59a620bd5 Mon Sep 17 00:00:00 2001 From: Tommaso Morganti Date: Sat, 11 Apr 2026 23:50:43 +0200 Subject: [PATCH 1/2] fix: ai correctly awaits message batches --- .../auto-moderation-stack/ai-moderation.ts | 37 +- tests/awaiter.test.ts | 383 ++++++++++++++++++ 2 files changed, 393 insertions(+), 27 deletions(-) create mode 100644 tests/awaiter.test.ts diff --git a/src/middlewares/auto-moderation-stack/ai-moderation.ts b/src/middlewares/auto-moderation-stack/ai-moderation.ts index 6c05d6e..8c6a1c8 100644 --- a/src/middlewares/auto-moderation-stack/ai-moderation.ts +++ b/src/middlewares/auto-moderation-stack/ai-moderation.ts @@ -1,10 +1,10 @@ -import { EventEmitter } from "node:events" import type { Filter } from "grammy" import OpenAI from "openai" import { env } from "@/env" import { logger } from "@/logger" import { getText } from "@/utils/messages" import type { Context } from "@/utils/types" +import { Awaiter } from "@/utils/wait" import { DELETION_THRESHOLDS } from "./constants" import type { Category, FlaggedCategory, ModerationCandidate, ModerationResult } from "./types" @@ -22,9 +22,7 @@ import type { Category, FlaggedCategory, ModerationCandidate, ModerationResult } * * More info on the API here: https://platform.openai.com/docs/guides/moderation */ -export class AIModeration extends EventEmitter<{ - results: [ModerationResult[]] -}> { +export class AIModeration { /** * Takes each category, and for the flagged ones takes the score (highest among related results) and * confronts it with predefined thresholds @@ -64,10 +62,11 @@ export class AIModeration extends EventEmitter<{ private client: OpenAI | null private checkQueue: ModerationCandidate[] = [] private timeout: NodeJS.Timeout | null = null + private responseAwaiter: Awaiter = new Awaiter() constructor() { - super() this.client = env.OPENAI_API_KEY ? new OpenAI({ apiKey: env.OPENAI_API_KEY }) : null + this.responseAwaiter.resolve([]) // initialize it as resolved with empty results if (!this.client) logger.warn("[AI Mod] Missing env OPENAI_API_KEY, automatic moderation will not work.") else logger.debug("[AI Mod] OpenAI client initialized for moderation.") @@ -82,41 +81,27 @@ export class AIModeration extends EventEmitter<{ if (!this.client) return if (this.checkQueue.length === 0) return + this.responseAwaiter = new Awaiter() // reset the awaiter for the next batch const candidates = this.checkQueue.splice(0, this.checkQueue.length) void this.client.moderations .create({ input: candidates, model: "omni-moderation-latest" }) .then((response) => { - this.emit("results", response.results) + this.responseAwaiter.resolve(response.results) }) .catch((error: unknown) => { logger.error({ error }, "[AI Mod] Error during moderation check") + this.responseAwaiter.resolve([]) // fail open: return empty results on error }) } - /** - * Wait for the moderation results to be emitted. - * - * This is done to allow batching of moderation checks. - * @returns A promise that resolves with the moderation results, mapped as they were queued. - */ - private waitForResults(): Promise { - return new Promise((resolve, reject) => { - this.once("results", (results) => { - resolve(results) - }) - setTimeout(() => { - reject(new Error("Moderation Check timed out")) - }, 1000 * 30) - }) - } - /** * Add a candidate to the moderation check queue, returns the result if found. * @param candidate the candidate to add to the queue, either text or image * @returns A promise that resolves with the moderation result, or null if not found or timed out. */ - private addToCheckQueue(candidate: ModerationCandidate): Promise { + private async addToCheckQueue(candidate: ModerationCandidate): Promise { + await this.responseAwaiter // wait for the previous batch to be processed const index = this.checkQueue.push(candidate) - 1 if (this.timeout === null) { // throttle a check every 10 seconds @@ -125,9 +110,7 @@ export class AIModeration extends EventEmitter<{ this.timeout = null }, 10 * 1000) } - return this.waitForResults() - .then((results) => results[index] ?? null) - .catch(() => null) // check timed out + return this.responseAwaiter.then((results) => results[index] ?? null) } /** diff --git a/tests/awaiter.test.ts b/tests/awaiter.test.ts new file mode 100644 index 0000000..53573fc --- /dev/null +++ b/tests/awaiter.test.ts @@ -0,0 +1,383 @@ +import { describe, expect, it, vi } from "vitest" +import { Awaiter } from "@/utils/wait" + +describe("Awaiter: PromiseLike implementation with manual resolution", () => { + describe("basic resolution", () => { + it("should resolve with a value", async () => { + const awaiter = new Awaiter() + + setTimeout(() => { + awaiter.resolve("test value") + }, 10) + + const result = await awaiter + expect(result).toBe("test value") + }) + + it("should resolve with different types", async () => { + const stringAwaiter = new Awaiter() + stringAwaiter.resolve("hello") + expect(await stringAwaiter).toBe("hello") + + const numberAwaiter = new Awaiter() + numberAwaiter.resolve(42) + expect(await numberAwaiter).toBe(42) + + const booleanAwaiter = new Awaiter() + booleanAwaiter.resolve(true) + expect(await booleanAwaiter).toBe(true) + + const objectAwaiter = new Awaiter<{ id: number; name: string }>() + objectAwaiter.resolve({ id: 1, name: "test" }) + expect(await objectAwaiter).toEqual({ id: 1, name: "test" }) + }) + + it("should resolve with null and undefined", async () => { + const nullAwaiter = new Awaiter() + nullAwaiter.resolve(null) + expect(await nullAwaiter).toBe(null) + + const undefinedAwaiter = new Awaiter() + undefinedAwaiter.resolve(undefined) + expect(await undefinedAwaiter).toBe(undefined) + }) + }) + + describe("instant resolution after already resolved", () => { + it("should instantly resolve when awaited after resolution", async () => { + const awaiter = new Awaiter() + awaiter.resolve("immediate") + + const start = Date.now() + const result = await awaiter + const elapsed = Date.now() - start + + expect(result).toBe("immediate") + expect(elapsed).toBeLessThan(2) // Should be nearly instant + }) + + it("should always return the same value on multiple awaits", async () => { + const awaiter = new Awaiter() + awaiter.resolve("same value") + + const result1 = await awaiter + const result2 = await awaiter + const result3 = await awaiter + + expect(result1).toBe("same value") + expect(result2).toBe("same value") + expect(result3).toBe("same value") + }) + + it("should handle rapid sequential awaits", async () => { + const awaiter = new Awaiter() + awaiter.resolve(123) + + const results = await Promise.all([awaiter, awaiter, awaiter, awaiter, awaiter]) + + expect(results).toEqual([123, 123, 123, 123, 123]) + }) + + it("should work with concurrent awaits before resolution", async () => { + const awaiter = new Awaiter() + + const promise1 = awaiter + const promise2 = awaiter + const promise3 = awaiter + + // Resolve after promises are created but before awaited + setTimeout(() => { + awaiter.resolve("concurrent value") + }, 10) + + const [result1, result2, result3] = await Promise.all([promise1, promise2, promise3]) + + expect(result1).toBe("concurrent value") + expect(result2).toBe("concurrent value") + expect(result3).toBe("concurrent value") + }) + }) + + describe("then() chaining", () => { + it("should support then() method", async () => { + const awaiter = new Awaiter() + awaiter.resolve(5) + + const result = await awaiter.then((value) => value * 2) + expect(result).toBe(10) + }) + + it("should chain multiple then() calls", async () => { + const awaiter = new Awaiter() + awaiter.resolve(2) + + const result = await awaiter + .then((value) => value * 2) + .then((value) => value + 3) + .then((value) => value * 10) + + expect(result).toBe(70) // ((2 * 2) + 3) * 10 + }) + + it("should handle then() with value transformation", async () => { + const awaiter = new Awaiter<{ count: number }>() + awaiter.resolve({ count: 5 }) + + const result = await awaiter.then((obj) => obj.count.toString()) + expect(result).toBe("5") + }) + + it("should support then() that returns a promise", async () => { + const awaiter = new Awaiter() + awaiter.resolve("hello") + + const result = await awaiter.then((value) => Promise.resolve(value.toUpperCase())) + expect(result).toBe("HELLO") + }) + + it("should resolve then() instantly when awaiter is already resolved", async () => { + const awaiter = new Awaiter() + awaiter.resolve(42) + + const start = Date.now() + const result = await awaiter.then((value) => value * 2) + const elapsed = Date.now() - start + + expect(result).toBe(84) + expect(elapsed).toBeLessThan(2) + }) + + it("should handle null in then() callback", async () => { + const awaiter = new Awaiter() + awaiter.resolve("value") + + const result = await awaiter.then(null) + expect(result).toBe("value") + }) + + it("should handle undefined in then() callback", async () => { + const awaiter = new Awaiter() + awaiter.resolve("value") + + const result = await awaiter.then(undefined) + expect(result).toBe("value") + }) + }) + + describe("onrejected handler", () => { + it("should support onrejected handler in then()", async () => { + const awaiter = new Awaiter() + // Note: The current implementation doesn't have a reject method, + // but we can test that onrejected handler is accepted + awaiter.resolve("success") + + const result = await awaiter.then( + (value) => value, + () => "error handler" + ) + expect(result).toBe("success") + }) + }) + + describe("promiselike integration", () => { + it("should be usable with Promise.resolve()", async () => { + const awaiter = new Awaiter() + awaiter.resolve("resolved") + + const result = await Promise.resolve(awaiter) + expect(result).toBe("resolved") + }) + + it("should be usable with Promise.all()", async () => { + const awaiter1 = new Awaiter() + const awaiter2 = new Awaiter() + + awaiter1.resolve(1) + awaiter2.resolve(2) + + const results = await Promise.all([awaiter1, awaiter2]) + expect(results).toEqual([1, 2]) + }) + + it("should be usable with Promise.race()", async () => { + const awaiter1 = new Awaiter() + const awaiter2 = new Awaiter() + + awaiter1.resolve("first") + // awaiter2 not resolved, so it loses the race + + const result = await Promise.race([awaiter1, awaiter2]) + expect(result).toBe("first") + }) + + it("should work in async/await context", async () => { + const awaiter = new Awaiter() + + const asyncFunction = async () => { + const value = await awaiter + return value.toUpperCase() + } + + awaiter.resolve("hello") + const result = await asyncFunction() + expect(result).toBe("HELLO") + }) + }) + + describe("multiple resolves and timing", () => { + it("should only use the first resolve call", async () => { + const awaiter = new Awaiter() + + awaiter.resolve("first") + awaiter.resolve("second") // Should be ignored + awaiter.resolve("third") // Should be ignored + + const result = await awaiter + expect(result).toBe("first") + }) + + it("should handle delayed resolution followed by immediate awaits", async () => { + const awaiter = new Awaiter() + + setTimeout(() => { + awaiter.resolve("delayed") + }, 20) + + const promises = [ + awaiter, + awaiter, + (async () => { + await new Promise((r) => setTimeout(r, 10)) + return awaiter + })(), + ] + + const results = await Promise.all(promises) + expect(results[0]).toBe("delayed") + expect(results[1]).toBe("delayed") + expect(results[2]).toBe("delayed") + }) + }) + + describe("real-world use cases", () => { + it("should work as a deferred value holder for async operations", async () => { + const awaiter = new Awaiter<{ data: string }>() + + // Simulate an async operation that resolves the awaiter + const asyncOp = async () => { + await new Promise((r) => setTimeout(r, 20)) + awaiter.resolve({ data: "fetched" }) + } + + void asyncOp() + + // Multiple consumers can await the same value + const [result1, result2] = await Promise.all([awaiter, awaiter]) + expect(result1).toEqual({ data: "fetched" }) + expect(result2).toEqual({ data: "fetched" }) + }) + + it("should work with setTimeout async pattern", async () => { + const awaiter = new Awaiter() + let counter = 0 + + const interval = setInterval(() => { + counter++ + if (counter >= 3) { + clearInterval(interval) + awaiter.resolve(counter) + } + }, 10) + + const result = await awaiter + expect(result).toBe(3) + }) + + it("should integrate with event-like patterns", async () => { + const awaiter = new Awaiter() + const events: string[] = [] + + const listener = (event: string) => { + events.push(event) + if (event === "ready") { + awaiter.resolve("listening complete") + } + } + + // Simulate event emission + setTimeout(() => listener("start"), 5) + setTimeout(() => listener("process"), 10) + setTimeout(() => listener("ready"), 15) + + const result = await awaiter + expect(result).toBe("listening complete") + expect(events).toEqual(["start", "process", "ready"]) + }) + + it("should handle multiple concurrent waiters on same awaiter", async () => { + const awaiter = new Awaiter() + const callbacks = vi.fn() + + const waiter1 = awaiter.then((val) => { + callbacks("waiter1") + return val + }) + + const waiter2 = awaiter.then((val) => { + callbacks("waiter2") + return val + }) + + const waiter3 = awaiter.then((val) => { + callbacks("waiter3") + return val + }) + + awaiter.resolve("value") + + await Promise.all([waiter1, waiter2, waiter3]) + expect(callbacks).toHaveBeenCalledTimes(3) + expect(callbacks).toHaveBeenCalledWith("waiter1") + expect(callbacks).toHaveBeenCalledWith("waiter2") + expect(callbacks).toHaveBeenCalledWith("waiter3") + }) + }) + + describe("edge cases", () => { + it("should handle resolution with 0", async () => { + const awaiter = new Awaiter() + awaiter.resolve(0) + expect(await awaiter).toBe(0) + }) + + it("should handle resolution with empty string", async () => { + const awaiter = new Awaiter() + awaiter.resolve("") + expect(await awaiter).toBe("") + }) + + it("should handle resolution with empty array", async () => { + const awaiter = new Awaiter() + awaiter.resolve([]) + expect(await awaiter).toEqual([]) + }) + + it("should handle resolution with empty object", async () => { + const awaiter = new Awaiter>() + awaiter.resolve({}) + expect(await awaiter).toEqual({}) + }) + + it("should preserve promise identity across multiple awaits", async () => { + const awaiter = new Awaiter() + awaiter.resolve("value") + + const promise1 = awaiter.then((v) => v) + const promise2 = awaiter.then((v) => v) + + const results = await Promise.all([promise1, promise2]) + expect(results[0]).toBe("value") + expect(results[1]).toBe("value") + }) + }) +}) From 0b216aac7c2c007bc115689369bc8b030bcc3133 Mon Sep 17 00:00:00 2001 From: Tommaso Morganti Date: Sat, 11 Apr 2026 23:50:52 +0200 Subject: [PATCH 2/2] chore: remove AI --- .../auto-moderation-stack/index.ts | 114 +++++++++--------- 1 file changed, 57 insertions(+), 57 deletions(-) diff --git a/src/middlewares/auto-moderation-stack/index.ts b/src/middlewares/auto-moderation-stack/index.ts index acd8582..7b12711 100644 --- a/src/middlewares/auto-moderation-stack/index.ts +++ b/src/middlewares/auto-moderation-stack/index.ts @@ -7,14 +7,14 @@ import { modules } from "@/modules" import { Moderation } from "@/modules/moderation" import { measureForkDuration, type TelemetryContextFlavor, TrackedMiddleware } from "@/modules/telemetry" import { redis } from "@/redis" -import { defer } from "@/utils/deferred-middleware" +// import { defer } from "@/utils/deferred-middleware" import { duration } from "@/utils/duration" import { fmt, fmtUser } from "@/utils/format" import { createFakeMessage, ephemeral, getText } from "@/utils/messages" import { throttle } from "@/utils/throttle" import type { Context } from "@/utils/types" import { MessageUserStorage } from "../message-user-storage" -import { AIModeration } from "./ai-moderation" +// import { AIModeration } from "./ai-moderation" import { MULTI_CHAT_SPAM, NON_LATIN } from "./constants" import { checkForAllowedLinks } from "./functions" @@ -42,14 +42,14 @@ const debouncedError = throttle((error: unknown, msg: string) => { * - [x] Links handler * - [x] Harmful content handler * - [x] Multichat spam handler for similar messages - * - [ ] Avoid deletion for messages explicitly allowed by Direttivo or from privileged users + * - [x] Avoid deletion for messages explicitly allowed by Direttivo or from privileged users * - [x] handle non-latin characters */ export class AutoModerationStack> extends TrackedMiddleware< ModerationFlavor > { // AI moderation instance - private aiModeration: AIModeration = new AIModeration() + // private aiModeration: AIModeration = new AIModeration() constructor() { super("auto_moderation_stack") @@ -75,7 +75,7 @@ export class AutoModerationStack> exte .use(measureForkDuration("auto_moderation_link_duration")) .use((ctx) => this.linkHandler(ctx)) // AI takes too long to measure, completely defer it to avoid blocking the main stack, and just log any errors - filtered.fork().use(defer((ctx) => this.harmfulContentHandler(ctx))) + // filtered.fork().use(defer((ctx) => this.harmfulContentHandler(ctx))) filtered .on([":text", ":caption"]) .fork() @@ -168,61 +168,61 @@ export class AutoModerationStack> exte ) } - /** - * Checks messages for harmful content using AI moderation. - * If harmful content is detected, mutes the user and deletes the message. - */ - private async harmfulContentHandler(ctx: ModerationContext) { - const message = ctx.msg - const flaggedCategories = await this.aiModeration.checkForHarmfulContent(ctx) + // /** + // * Checks messages for harmful content using AI moderation. + // * If harmful content is detected, mutes the user and deletes the message. + // */ + // private async harmfulContentHandler(ctx: ModerationContext) { + // const message = ctx.msg + // const flaggedCategories = await this.aiModeration.checkForHarmfulContent(ctx) - if (flaggedCategories.length > 0) { - const reasons = flaggedCategories.map((cat) => ` - ${cat.category} (${(cat.score * 100).toFixed(1)}%)`).join("\n") + // if (flaggedCategories.length > 0) { + // const reasons = flaggedCategories.map((cat) => ` - ${cat.category} (${(cat.score * 100).toFixed(1)}%)`).join("\n") - if (flaggedCategories.some((cat) => cat.aboveThreshold)) { - if (ctx.whitelisted) { - // log the action but do not mute - if (ctx.whitelisted.role === "user") - await modules.get("tgLogger").grants({ - action: "USAGE", - from: ctx.from, - chat: ctx.chat, - message, - }) - } else { - // above threshold, mute user and delete the message - const res = await Moderation.mute( - ctx.from, - ctx.chat, - ctx.me, - duration.zod.parse("1d"), - [message], - `Automatic moderation detected harmful content\n${reasons}` - ) + // if (flaggedCategories.some((cat) => cat.aboveThreshold)) { + // if (ctx.whitelisted) { + // // log the action but do not mute + // if (ctx.whitelisted.role === "user") + // await modules.get("tgLogger").grants({ + // action: "USAGE", + // from: ctx.from, + // chat: ctx.chat, + // message, + // }) + // } else { + // // above threshold, mute user and delete the message + // const res = await Moderation.mute( + // ctx.from, + // ctx.chat, + // ctx.me, + // duration.zod.parse("1d"), + // [message], + // `Automatic moderation detected harmful content\n${reasons}` + // ) - void ephemeral( - ctx.reply( - res.isOk() - ? fmt(({ i, b }) => [ - b`⚠️ Message from ${fmtUser(ctx.from)} was deleted automatically due to harmful content.`, - i`If you think this is a mistake, please contact the group administrators.`, - ]) - : res.error.fmtError - ) - ) - } - } else { - // no flagged category is above the threshold, still log it for manual review - await modules.get("tgLogger").moderationAction({ - action: "SILENT", - from: ctx.me, - chat: ctx.chat, - target: ctx.from, - reason: `Message flagged for moderation: \n${reasons}`, - }) - } - } - } + // void ephemeral( + // ctx.reply( + // res.isOk() + // ? fmt(({ i, b }) => [ + // b`⚠️ Message from ${fmtUser(ctx.from)} was deleted automatically due to harmful content.`, + // i`If you think this is a mistake, please contact the group administrators.`, + // ]) + // : res.error.fmtError + // ) + // ) + // } + // } else { + // // no flagged category is above the threshold, still log it for manual review + // await modules.get("tgLogger").moderationAction({ + // action: "SILENT", + // from: ctx.me, + // chat: ctx.chat, + // target: ctx.from, + // reason: `Message flagged for moderation: \n${reasons}`, + // }) + // } + // } + // } /** * Handles messages containing a high percentage of non-latin characters to avoid most spam bots.