Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 10 additions & 27 deletions src/middlewares/auto-moderation-stack/ai-moderation.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { EventEmitter } from "node:events"
import type { Filter } from "grammy"
import OpenAI from "openai"
import { env } from "@/env"
import { logger } from "@/logger"
import { getText } from "@/utils/messages"
import type { Context } from "@/utils/types"
import { Awaiter } from "@/utils/wait"
import { DELETION_THRESHOLDS } from "./constants"
import type { Category, FlaggedCategory, ModerationCandidate, ModerationResult } from "./types"

Expand All @@ -22,9 +22,7 @@ import type { Category, FlaggedCategory, ModerationCandidate, ModerationResult }
*
* More info on the API here: https://platform.openai.com/docs/guides/moderation
*/
export class AIModeration<C extends Context> extends EventEmitter<{
results: [ModerationResult[]]
}> {
export class AIModeration<C extends Context> {
/**
* Takes each category, and for the flagged ones takes the score (highest among related results) and
* compares it against predefined thresholds
Expand Down Expand Up @@ -64,10 +62,11 @@ export class AIModeration<C extends Context> extends EventEmitter<{
private client: OpenAI | null
private checkQueue: ModerationCandidate[] = []
private timeout: NodeJS.Timeout | null = null
private responseAwaiter: Awaiter<ModerationResult[]> = new Awaiter()

constructor() {
super()
this.client = env.OPENAI_API_KEY ? new OpenAI({ apiKey: env.OPENAI_API_KEY }) : null
this.responseAwaiter.resolve([]) // initialize it as resolved with empty results

if (!this.client) logger.warn("[AI Mod] Missing env OPENAI_API_KEY, automatic moderation will not work.")
else logger.debug("[AI Mod] OpenAI client initialized for moderation.")
Expand All @@ -82,41 +81,27 @@ export class AIModeration<C extends Context> extends EventEmitter<{
if (!this.client) return
if (this.checkQueue.length === 0) return

this.responseAwaiter = new Awaiter() // reset the awaiter for the next batch
const candidates = this.checkQueue.splice(0, this.checkQueue.length)

void this.client.moderations
.create({ input: candidates, model: "omni-moderation-latest" })
.then((response) => {
this.emit("results", response.results)
this.responseAwaiter.resolve(response.results)
})
.catch((error: unknown) => {
logger.error({ error }, "[AI Mod] Error during moderation check")
this.responseAwaiter.resolve([]) // fail open: return empty results on error
})
}

/**
 * Wait for the next "results" event emitted by this instance.
 *
 * This is done to allow batching of moderation checks: every caller queued in the same
 * batch receives the same results array and picks its own entry by index.
 *
 * The listener and the timeout guard clean each other up, so neither a dangling timer
 * nor a stale "results" listener is left behind once the promise settles.
 *
 * @returns A promise that resolves with the moderation results, mapped as they were queued.
 * Rejects with an Error("Moderation Check timed out") if no results arrive within 30 seconds.
 */
private waitForResults(): Promise<ModerationResult[]> {
  return new Promise((resolve, reject) => {
    const onResults = (results: ModerationResult[]): void => {
      // results arrived in time: cancel the timeout guard so it does not linger for 30s
      clearTimeout(timer)
      resolve(results)
    }

    const timer = setTimeout(() => {
      // nothing was emitted in time: drop the pending listener so it does not accumulate
      this.off("results", onResults)
      reject(new Error("Moderation Check timed out"))
    }, 1000 * 30)

    this.once("results", onResults)
  })
}

/**
* Add a candidate to the moderation check queue, returns the result if found.
* @param candidate the candidate to add to the queue, either text or image
* @returns A promise that resolves with the moderation result, or null if not found or timed out.
*/
private addToCheckQueue(candidate: ModerationCandidate): Promise<ModerationResult | null> {
private async addToCheckQueue(candidate: ModerationCandidate): Promise<ModerationResult | null> {
await this.responseAwaiter // wait for the previous batch to be processed
const index = this.checkQueue.push(candidate) - 1
if (this.timeout === null) {
// throttle a check every 10 seconds
Expand All @@ -125,9 +110,7 @@ export class AIModeration<C extends Context> extends EventEmitter<{
this.timeout = null
}, 10 * 1000)
}
return this.waitForResults()
.then((results) => results[index] ?? null)
.catch(() => null) // check timed out
return this.responseAwaiter.then((results) => results[index] ?? null)
}

/**
Expand Down
114 changes: 57 additions & 57 deletions src/middlewares/auto-moderation-stack/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ import { modules } from "@/modules"
import { Moderation } from "@/modules/moderation"
import { measureForkDuration, type TelemetryContextFlavor, TrackedMiddleware } from "@/modules/telemetry"
import { redis } from "@/redis"
import { defer } from "@/utils/deferred-middleware"
// import { defer } from "@/utils/deferred-middleware"
import { duration } from "@/utils/duration"
import { fmt, fmtUser } from "@/utils/format"
import { createFakeMessage, ephemeral, getText } from "@/utils/messages"
import { throttle } from "@/utils/throttle"
import type { Context } from "@/utils/types"
import { MessageUserStorage } from "../message-user-storage"
import { AIModeration } from "./ai-moderation"
// import { AIModeration } from "./ai-moderation"
import { MULTI_CHAT_SPAM, NON_LATIN } from "./constants"
import { checkForAllowedLinks } from "./functions"

Expand Down Expand Up @@ -42,14 +42,14 @@ const debouncedError = throttle((error: unknown, msg: string) => {
* - [x] Links handler
* - [x] Harmful content handler
* - [x] Multichat spam handler for similar messages
* - [ ] Avoid deletion for messages explicitly allowed by Direttivo or from privileged users
* - [x] Avoid deletion for messages explicitly allowed by Direttivo or from privileged users
* - [x] handle non-latin characters
*/
export class AutoModerationStack<C extends TelemetryContextFlavor<Context>> extends TrackedMiddleware<
ModerationFlavor<C>
> {
// AI moderation instance
private aiModeration: AIModeration<C> = new AIModeration<C>()
// private aiModeration: AIModeration<C> = new AIModeration<C>()

constructor() {
super("auto_moderation_stack")
Expand All @@ -75,7 +75,7 @@ export class AutoModerationStack<C extends TelemetryContextFlavor<Context>> exte
.use(measureForkDuration("auto_moderation_link_duration"))
.use((ctx) => this.linkHandler(ctx))
// AI takes too long to measure, completely defer it to avoid blocking the main stack, and just log any errors
filtered.fork().use(defer((ctx) => this.harmfulContentHandler(ctx)))
// filtered.fork().use(defer((ctx) => this.harmfulContentHandler(ctx)))
filtered
.on([":text", ":caption"])
.fork()
Expand Down Expand Up @@ -168,61 +168,61 @@ export class AutoModerationStack<C extends TelemetryContextFlavor<Context>> exte
)
}

/**
* Checks messages for harmful content using AI moderation.
* If harmful content is detected, mutes the user and deletes the message.
*/
private async harmfulContentHandler(ctx: ModerationContext<C>) {
const message = ctx.msg
const flaggedCategories = await this.aiModeration.checkForHarmfulContent(ctx)
// /**
// * Checks messages for harmful content using AI moderation.
// * If harmful content is detected, mutes the user and deletes the message.
// */
// private async harmfulContentHandler(ctx: ModerationContext<C>) {
// const message = ctx.msg
// const flaggedCategories = await this.aiModeration.checkForHarmfulContent(ctx)

if (flaggedCategories.length > 0) {
const reasons = flaggedCategories.map((cat) => ` - ${cat.category} (${(cat.score * 100).toFixed(1)}%)`).join("\n")
// if (flaggedCategories.length > 0) {
// const reasons = flaggedCategories.map((cat) => ` - ${cat.category} (${(cat.score * 100).toFixed(1)}%)`).join("\n")

if (flaggedCategories.some((cat) => cat.aboveThreshold)) {
if (ctx.whitelisted) {
// log the action but do not mute
if (ctx.whitelisted.role === "user")
await modules.get("tgLogger").grants({
action: "USAGE",
from: ctx.from,
chat: ctx.chat,
message,
})
} else {
// above threshold, mute user and delete the message
const res = await Moderation.mute(
ctx.from,
ctx.chat,
ctx.me,
duration.zod.parse("1d"),
[message],
`Automatic moderation detected harmful content\n${reasons}`
)
// if (flaggedCategories.some((cat) => cat.aboveThreshold)) {
// if (ctx.whitelisted) {
// // log the action but do not mute
// if (ctx.whitelisted.role === "user")
// await modules.get("tgLogger").grants({
// action: "USAGE",
// from: ctx.from,
// chat: ctx.chat,
// message,
// })
// } else {
// // above threshold, mute user and delete the message
// const res = await Moderation.mute(
// ctx.from,
// ctx.chat,
// ctx.me,
// duration.zod.parse("1d"),
// [message],
// `Automatic moderation detected harmful content\n${reasons}`
// )

void ephemeral(
ctx.reply(
res.isOk()
? fmt(({ i, b }) => [
b`⚠️ Message from ${fmtUser(ctx.from)} was deleted automatically due to harmful content.`,
i`If you think this is a mistake, please contact the group administrators.`,
])
: res.error.fmtError
)
)
}
} else {
// no flagged category is above the threshold, still log it for manual review
await modules.get("tgLogger").moderationAction({
action: "SILENT",
from: ctx.me,
chat: ctx.chat,
target: ctx.from,
reason: `Message flagged for moderation: \n${reasons}`,
})
}
}
}
// void ephemeral(
// ctx.reply(
// res.isOk()
// ? fmt(({ i, b }) => [
// b`⚠️ Message from ${fmtUser(ctx.from)} was deleted automatically due to harmful content.`,
// i`If you think this is a mistake, please contact the group administrators.`,
// ])
// : res.error.fmtError
// )
// )
// }
// } else {
// // no flagged category is above the threshold, still log it for manual review
// await modules.get("tgLogger").moderationAction({
// action: "SILENT",
// from: ctx.me,
// chat: ctx.chat,
// target: ctx.from,
// reason: `Message flagged for moderation: \n${reasons}`,
// })
// }
// }
// }

/**
* Handles messages containing a high percentage of non-latin characters to avoid most spam bots.
Expand Down
Loading
Loading