diff --git a/packages/server/package.json b/packages/server/package.json index 9b9b7dbd..4d4293b0 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -46,6 +46,7 @@ "license": "SEE LICENSE IN LICENSE.md", "dependencies": { "@oclif/core": "^1.13.10", + "async-mutex": "^0.4.0", "axios": "^0.27.2", "cors": "^2.8.5", "crypto-js": "^4.1.1", diff --git a/packages/server/src/Interface.ts b/packages/server/src/Interface.ts index 92e3054d..a83e556e 100644 --- a/packages/server/src/Interface.ts +++ b/packages/server/src/Interface.ts @@ -15,6 +15,9 @@ export interface IChatFlow { isPublic?: boolean apikeyid?: string chatbotConfig?: string + rateLimit?: number + rateLimitDuration?: number + rateLimitMsg?: string } export interface IChatMessage { diff --git a/packages/server/src/entity/ChatFlow.ts b/packages/server/src/entity/ChatFlow.ts index 4c37e083..e8ed861b 100644 --- a/packages/server/src/entity/ChatFlow.ts +++ b/packages/server/src/entity/ChatFlow.ts @@ -25,6 +25,15 @@ export class ChatFlow implements IChatFlow { @Column({ nullable: true }) chatbotConfig?: string + @Column({ nullable: true }) + rateLimit?: number + + @Column({ nullable: true }) + rateLimitDuration?: number + + @Column({ nullable: true }) + rateLimitMsg?: string + @CreateDateColumn() createdDate: Date diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index 15709ad9..f6df0c30 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -54,7 +54,7 @@ import { Credential } from './entity/Credential' import { Tool } from './entity/Tool' import { ChatflowPool } from './ChatflowPool' import { ICommonObject, INodeOptionsValue } from 'flowise-components' -import { createRateLimiter, getRateLimiter } from './utils/rateLimit' +import { createRateLimiter, getRateLimiter, initializeRateLimiter } from './utils/rateLimit' export class App { app: express.Application @@ -84,6 +84,10 @@ export class App { // Initialize encryption key await getEncryptionKey() + + // Initialize Rate Limit + const AllChatFlow: IChatFlow[] = await getAllChatFlow() + await initializeRateLimiter(AllChatFlow) }) .catch((err) => { logger.error('❌ [server]: Error during Data Source initialization:', err) @@ -246,7 +250,7 @@ export class App { // Get all chatflows this.app.get('/api/v1/chatflows', async (req: Request, res: Response) => { - const chatflows: IChatFlow[] = await this.AppDataSource.getRepository(ChatFlow).find() + const chatflows: IChatFlow[] = await getAllChatFlow() return res.json(chatflows) }) @@ -655,21 +659,6 @@ export class App { // Prediction // ---------------------------------------- - this.app.get( - '/api/v1/rate-limit/:id', - upload.array('files'), - (req: Request, res: Response, next: NextFunction) => getRateLimiter(req, res, next), - // specificRouteLimiter, - async (req: Request, res: Response) => { - res.send("you're fine") - } - ) - - this.app.post('/api/v1/rate-limit/', async (req: Request, res: Response) => { - createRateLimiter(req) - res.send('Created/Updated rate limit') - }) - // Send input message and get prediction result (External) this.app.post('/api/v1/prediction/:id', upload.array('files'), async (req: Request, res: Response) => { await this.processPrediction(req, res, socketIO) @@ -768,6 +757,39 @@ export class App { } }) + // ---------------------------------------- + // Rate Limit + // ---------------------------------------- + + this.app.get( + '/api/v1/rate-limit/:id', + upload.array('files'), + (req: Request, res: Response, next: NextFunction) => getRateLimiter(req, res, next), + // specificRouteLimiter, + async (req: Request, res: Response) => { + res.send("you're fine") + } + ) + + this.app.post('/api/v1/rate-limit/', async (req: Request, res: Response) => { + const id = req.body.id + const duration = req.body.duration + const limit = req.body.limit + const message = req.body.message + + const result = await getDataSource() + .getRepository(ChatFlow) + .createQueryBuilder() + .update(ChatFlow) + .set({ rateLimit: limit, rateLimitDuration: duration, rateLimitMsg: message }) + .where('id = :id', { id: id }) + .execute() + + await createRateLimiter(id, Number(duration), Number(limit), message) + + res.send({ result }) + }) + // ---------------------------------------- // Serve UI static // ---------------------------------------- @@ -1012,6 +1034,10 @@ export async function getChatId(chatflowid: string) { let serverApp: App | undefined +export async function getAllChatFlow(): Promise { + return await getDataSource().getRepository(ChatFlow).find() +} + export async function start(): Promise { serverApp = new App() diff --git a/packages/server/src/utils/rateLimit.ts b/packages/server/src/utils/rateLimit.ts index 0bd5be98..882964c4 100644 --- a/packages/server/src/utils/rateLimit.ts +++ b/packages/server/src/utils/rateLimit.ts @@ -1,56 +1,39 @@ import { NextFunction, Request, Response } from 'express' import { rateLimit, RateLimitRequestHandler } from 'express-rate-limit' +import { IChatFlow } from '../Interface' +import { Mutex } from 'async-mutex' -interface RateLimit { - id: string - rateLimitObj: RateLimitRequestHandler -} +let rateLimiters: Record = {} +const rateLimiterMutex = new Mutex() -export const specificRouteLimiter: RateLimitRequestHandler = rateLimit({ - windowMs: 1 * 60 * 1000, // 15 minutes - max: 1, // Limit each IP to 100 requests per windowMs - message: 'Too many requests, please try again later.' -}) - -let rateLimiters: RateLimit[] = [] - -export function createRateLimiter(req: Request) { - const id = req.body.id - const duration = req.body.duration - const limit = req.body.limit - const message = req.body.message - - const rateLimitObj: RateLimitRequestHandler = rateLimit({ - windowMs: Number(duration), - max: limit, - handler: (req, res) => { - res.status(429).json({ error: message }) - } - }) - - const existingIndex: number = rateLimiters.findIndex((rateLimit) => rateLimit.id === id) - - if (existingIndex === -1) { - rateLimiters.push({ - id, - rateLimitObj +export async function createRateLimiter(id: string, duration: number, limit: number, message: string) { + const release = await rateLimiterMutex.acquire() + try { + rateLimiters[id] = rateLimit({ + windowMs: duration, + max: limit, + handler: (req, res) => { + res.status(429).json({ error: message }) + } }) - } else { - rateLimiters[existingIndex] = { - id, - rateLimitObj - } + } finally { + release() } } export function getRateLimiter(req: Request, res: Response, next: NextFunction) { const id = req.params.id - const ratelimiter = rateLimiters.find((rateLimit) => rateLimit.id === id) + if (!rateLimiters[id]) return next() - if (!ratelimiter) return next() - - const idRateLimiter = ratelimiter.rateLimitObj + const idRateLimiter = rateLimiters[id] return idRateLimiter(req, res, next) } + +export async function initializeRateLimiter(ChatFlowPool: IChatFlow[]) { + await ChatFlowPool.map(async (ChatFlow) => { + if (ChatFlow.rateLimitDuration && ChatFlow.rateLimit && ChatFlow.rateLimitMsg) + await createRateLimiter(ChatFlow.id, ChatFlow.rateLimitDuration, ChatFlow.rateLimit, ChatFlow.rateLimitMsg) + }) +}