diff --git a/packages/server/src/routes/predictions/index.ts b/packages/server/src/routes/predictions/index.ts index 40f37eef..077d10be 100644 --- a/packages/server/src/routes/predictions/index.ts +++ b/packages/server/src/routes/predictions/index.ts @@ -4,6 +4,8 @@ import { getMulterStorage } from '../../utils' const router = express.Router() +// NOTE: extractChatflowId function in XSS.ts extracts the chatflow ID from the prediction URL. +// It assumes the URL format is /prediction/{chatflowId}. Make sure to update the function if the URL format changes. // CREATE router.post( ['/', '/:id'], diff --git a/packages/server/src/services/chatflows/index.ts b/packages/server/src/services/chatflows/index.ts index 17456bc1..913eb533 100644 --- a/packages/server/src/services/chatflows/index.ts +++ b/packages/server/src/services/chatflows/index.ts @@ -378,6 +378,8 @@ const getSinglePublicChatbotConfig = async (chatflowId: string): Promise => } }) } + delete parsedConfig.allowedOrigins + delete parsedConfig.allowedOriginsError return { ...parsedConfig, uploads: uploadsConfig, flowData: dbResponse.flowData, isTTSEnabled } } catch (e) { throw new InternalFlowiseError(StatusCodes.INTERNAL_SERVER_ERROR, `Error parsing Chatbot Config for Chatflow ${chatflowId}`) diff --git a/packages/server/src/utils/XSS.ts b/packages/server/src/utils/XSS.ts index 96bbab57..f7c44686 100644 --- a/packages/server/src/utils/XSS.ts +++ b/packages/server/src/utils/XSS.ts @@ -1,5 +1,6 @@ import { Request, Response, NextFunction } from 'express' import sanitizeHtml from 'sanitize-html' +import { isPredictionRequest, extractChatflowId, validateChatflowDomain } from './domainValidation' export function sanitizeMiddleware(req: Request, res: Response, next: NextFunction): void { // decoding is necessary as the url is encoded by the browser @@ -20,22 +21,60 @@ export function sanitizeMiddleware(req: Request, res: Response, next: NextFuncti } export function getAllowedCorsOrigins(): string { - // Expects FQDN separated by commas, otherwise nothing or * for all. - return process.env.CORS_ORIGINS ?? '*' + // Expects FQDN separated by commas, otherwise nothing. + return process.env.CORS_ORIGINS ?? '' +} + +function parseAllowedOrigins(allowedOrigins: string): string[] { + if (!allowedOrigins) { + return [] + } + if (allowedOrigins === '*') { + return ['*'] + } + return allowedOrigins + .split(',') + .map((origin) => origin.trim().toLowerCase()) + .filter((origin) => origin.length > 0) } export function getCorsOptions(): any { - const corsOptions = { - origin: function (origin: string | undefined, callback: (err: Error | null, allow?: boolean) => void) { - const allowedOrigins = getAllowedCorsOrigins() - if (!origin || allowedOrigins == '*' || allowedOrigins.indexOf(origin) !== -1) { - callback(null, true) - } else { - callback(null, false) + return (req: any, callback: (err: Error | null, options?: any) => void) => { + const corsOptions = { + origin: async (origin: string | undefined, originCallback: (err: Error | null, allow?: boolean) => void) => { + const allowedOrigins = getAllowedCorsOrigins() + const isPredictionReq = isPredictionRequest(req.url) + const allowedList = parseAllowedOrigins(allowedOrigins) + const originLc = origin?.toLowerCase() + + // Always allow no-Origin requests (same-origin, server-to-server) + if (!originLc) return originCallback(null, true) + + // Global allow: '*' or exact match + const globallyAllowed = allowedOrigins === '*' || allowedList.includes(originLc) + + if (isPredictionReq) { + // Per-chatflow allowlist OR globally allowed + const chatflowId = extractChatflowId(req.url) + let chatflowAllowed = false + if (chatflowId) { + try { + chatflowAllowed = await validateChatflowDomain(chatflowId, originLc, req.user?.activeWorkspaceId) + } catch (error) { + // Log error and deny on failure + console.error('Domain validation error:', error) + chatflowAllowed = false + } + } + return originCallback(null, globallyAllowed || chatflowAllowed) + } + + // Non-prediction: rely on global policy only + return originCallback(null, globallyAllowed) } } + callback(null, corsOptions) } - return corsOptions } export function getAllowedIframeOrigins(): string { diff --git a/packages/server/src/utils/domainValidation.ts b/packages/server/src/utils/domainValidation.ts new file mode 100644 index 00000000..a2482d8b --- /dev/null +++ b/packages/server/src/utils/domainValidation.ts @@ -0,0 +1,109 @@ +import { isValidUUID } from 'flowise-components' +import chatflowsService from '../services/chatflows' +import logger from './logger' + +/** + * Validates if the origin is allowed for a specific chatflow + * @param chatflowId - The chatflow ID to validate against + * @param origin - The origin URL to validate + * @param workspaceId - Optional workspace ID for enterprise features + * @returns Promise - True if domain is allowed, false otherwise + */ +async function validateChatflowDomain(chatflowId: string, origin: string, workspaceId?: string): Promise { + try { + if (!chatflowId || !isValidUUID(chatflowId)) { + throw new Error('Invalid chatflowId format - must be a valid UUID') + } + + const chatflow = workspaceId + ? await chatflowsService.getChatflowById(chatflowId, workspaceId) + : await chatflowsService.getChatflowById(chatflowId) + + if (!chatflow?.chatbotConfig) { + return true + } + + const config = JSON.parse(chatflow.chatbotConfig) + + // If no allowed origins configured or first entry is empty, allow all + if (!config.allowedOrigins?.length || config.allowedOrigins[0] === '') { + return true + } + + const originHost = new URL(origin).host + const isAllowed = config.allowedOrigins.some((domain: string) => { + try { + const allowedOrigin = new URL(domain).host + return originHost === allowedOrigin + } catch (error) { + logger.warn(`Invalid domain format in allowedOrigins: ${domain}`) + return false + } + }) + + return isAllowed + } catch (error) { + logger.error(`Error validating domain for chatflow ${chatflowId}:`, error) + return false + } +} + +// NOTE: This function extracts the chatflow ID from a prediction URL. +// It assumes the URL format is /prediction/{chatflowId}. +/** + * Extracts chatflow ID from prediction URL + * @param url - The request URL + * @returns string | null - The chatflow ID or null if not found + */ +function extractChatflowId(url: string): string | null { + try { + const urlParts = url.split('/') + const predictionIndex = urlParts.indexOf('prediction') + + if (predictionIndex !== -1 && urlParts.length > predictionIndex + 1) { + const chatflowId = urlParts[predictionIndex + 1] + // Remove query parameters if present + return chatflowId.split('?')[0] + } + + return null + } catch (error) { + logger.error('Error extracting chatflow ID from URL:', error) + return null + } +} + +/** + * Validates if a request is a prediction request + * @param url - The request URL + * @returns boolean - True if it's a prediction request + */ +function isPredictionRequest(url: string): boolean { + return url.includes('/prediction/') +} + +/** + * Get the custom error message for unauthorized origin + * @param chatflowId - The chatflow ID + * @param workspaceId - Optional workspace ID + * @returns Promise - Custom error message or default + */ +async function getUnauthorizedOriginError(chatflowId: string, workspaceId?: string): Promise { + try { + const chatflow = workspaceId + ? await chatflowsService.getChatflowById(chatflowId, workspaceId) + : await chatflowsService.getChatflowById(chatflowId) + + if (chatflow?.chatbotConfig) { + const config = JSON.parse(chatflow.chatbotConfig) + return config.allowedOriginsError || 'This site is not allowed to access this chatbot' + } + + return 'This site is not allowed to access this chatbot' + } catch (error) { + logger.error(`Error getting unauthorized origin error for chatflow ${chatflowId}:`, error) + return 'This site is not allowed to access this chatbot' + } +} + +export { isPredictionRequest, extractChatflowId, validateChatflowDomain, getUnauthorizedOriginError }