From 1b308a8b54d02c37efb20dd2071a39be0e704216 Mon Sep 17 00:00:00 2001 From: vinodkiran Date: Sat, 9 Dec 2023 22:07:16 +0530 Subject: [PATCH] making the chain multi-modal. now we accept audio and image uploads and can run inference --- .../nodes/multimodal/OpenAI/AudioWhisper.ts | 13 ++-- .../multimodal/OpenAI/OpenAIVisionChain.ts | 6 +- .../nodes/multimodal/OpenAI/VLLMChain.ts | 53 ++++++++++++++-- .../ui/src/views/chatmessage/ChatMessage.js | 62 +++++++++++-------- 4 files changed, 96 insertions(+), 38 deletions(-) diff --git a/packages/components/nodes/multimodal/OpenAI/AudioWhisper.ts b/packages/components/nodes/multimodal/OpenAI/AudioWhisper.ts index b308a7c5..aa2c71e1 100644 --- a/packages/components/nodes/multimodal/OpenAI/AudioWhisper.ts +++ b/packages/components/nodes/multimodal/OpenAI/AudioWhisper.ts @@ -7,6 +7,7 @@ class OpenAIAudioWhisper implements INode { description: string type: string icon: string + badge: string category: string baseClasses: string[] inputs: INodeParams[] @@ -18,6 +19,7 @@ class OpenAIAudioWhisper implements INode { this.type = 'OpenAIWhisper' this.description = 'Speech to text using OpenAI Whisper API' this.icon = 'audio.svg' + this.badge = 'BETA' this.category = 'MultiModal' this.baseClasses = [this.type] this.inputs = [ @@ -27,14 +29,15 @@ class OpenAIAudioWhisper implements INode { type: 'options', options: [ { - label: 'transcription', + label: 'Transcription', name: 'transcription' }, { - label: 'translation', + label: 'Translation', name: 'translation' } - ] + ], + default: 'transcription' }, { label: 'Accepted Upload Types', @@ -54,7 +57,9 @@ class OpenAIAudioWhisper implements INode { } async init(nodeData: INodeData): Promise { - return {} + const purpose = nodeData.inputs?.purpose as string + + return { purpose } } } diff --git a/packages/components/nodes/multimodal/OpenAI/OpenAIVisionChain.ts b/packages/components/nodes/multimodal/OpenAI/OpenAIVisionChain.ts index 4151b4b0..dcaa96e2 100644 --- a/packages/components/nodes/multimodal/OpenAI/OpenAIVisionChain.ts +++ b/packages/components/nodes/multimodal/OpenAI/OpenAIVisionChain.ts @@ -132,7 +132,7 @@ class OpenAIVisionChain_Chains implements INode { this.outputs = [ { label: 'Open AI MultiModal Chain', - name: 'OpenAIMultiModalChain', + name: 'openAIMultiModalChain', baseClasses: [this.type, ...getBaseClasses(VLLMChain)] }, { @@ -154,6 +154,8 @@ class OpenAIVisionChain_Chains implements INode { const modelName = nodeData.inputs?.modelName as string const maxTokens = nodeData.inputs?.maxTokens as string const topP = nodeData.inputs?.topP as string + const whisperConfig = nodeData.inputs?.audioInput + const fields: OpenAIVisionChainInput = { openAIApiKey: openAIApiKey, imageResolution: imageResolution, @@ -164,6 +166,8 @@ class OpenAIVisionChain_Chains implements INode { if (temperature) fields.temperature = parseFloat(temperature) if (maxTokens) fields.maxTokens = parseInt(maxTokens, 10) if (topP) fields.topP = parseFloat(topP) + if (whisperConfig) fields.whisperConfig = whisperConfig + if (output === this.name) { const chain = new VLLMChain({ ...fields, diff --git a/packages/components/nodes/multimodal/OpenAI/VLLMChain.ts b/packages/components/nodes/multimodal/OpenAI/VLLMChain.ts index 2849cf63..dd44ebb5 100644 --- a/packages/components/nodes/multimodal/OpenAI/VLLMChain.ts +++ b/packages/components/nodes/multimodal/OpenAI/VLLMChain.ts @@ -21,6 +21,7 @@ export interface OpenAIVisionChainInput extends ChainInputs { modelName?: string maxTokens?: number topP?: number + whisperConfig?: any } /** @@ -48,6 +49,8 @@ export class VLLMChain extends BaseChain implements OpenAIVisionChainInput { maxTokens?: number topP?: number + whisperConfig?: any + constructor(fields: OpenAIVisionChainInput) { super(fields) this.throwError = fields?.throwError ?? false @@ -59,6 +62,7 @@ export class VLLMChain extends BaseChain implements OpenAIVisionChainInput { this.maxTokens = fields?.maxTokens this.topP = fields?.topP this.imageUrls = fields?.imageUrls ?? [] + this.whisperConfig = fields?.whisperConfig ?? {} if (!this.openAIApiKey) { throw new Error('OpenAI API key not found') } @@ -92,15 +96,44 @@ export class VLLMChain extends BaseChain implements OpenAIVisionChainInput { type: 'text', text: userInput }) + if (this.whisperConfig && this.imageUrls && this.imageUrls.length > 0) { + const audioUploads = this.getAudioUploads(this.imageUrls) + for (const url of audioUploads) { + const filePath = path.join(getUserHome(), '.flowise', 'gptvision', url.data, url.name) + + // as the image is stored in the server, read the file and convert it to base64 + const audio_file = fs.createReadStream(filePath) + if (this.whisperConfig.purpose === 'transcription') { + const transcription = await this.client.audio.transcriptions.create({ + file: audio_file, + model: 'whisper-1' + }) + userRole.content.push({ + type: 'text', + text: transcription.text + }) + } else if (this.whisperConfig.purpose === 'translation') { + const translation = await this.client.audio.translations.create({ + file: audio_file, + model: 'whisper-1' + }) + userRole.content.push({ + type: 'text', + text: translation.text + }) + } + } + } if (this.imageUrls && this.imageUrls.length > 0) { - this.imageUrls.forEach((imageUrl: any) => { - let bf = imageUrl?.data - if (imageUrl.type == 'stored-file') { - const filePath = path.join(getUserHome(), '.flowise', 'gptvision', imageUrl.data, imageUrl.name) + const imageUploads = this.getImageUploads(this.imageUrls) + for (const url of imageUploads) { + let bf = url.data + if (url.type == 'stored-file') { + const filePath = path.join(getUserHome(), '.flowise', 'gptvision', url.data, url.name) // as the image is stored in the server, read the file and convert it to base64 const contents = fs.readFileSync(filePath) - bf = 'data:' + imageUrl.mime + ';base64,' + contents.toString('base64') + bf = 'data:' + url.mime + ';base64,' + contents.toString('base64') } userRole.content.push({ type: 'image_url', @@ -109,7 +142,7 @@ export class VLLMChain extends BaseChain implements OpenAIVisionChainInput { detail: this.imageResolution } }) - }) + } } vRequest.messages.push(userRole) if (this.prompt && this.prompt instanceof ChatPromptTemplate) { @@ -146,6 +179,14 @@ export class VLLMChain extends BaseChain implements OpenAIVisionChainInput { } } + getAudioUploads = (urls: any[]) => { + return urls.filter((url: any) => url.mime.startsWith('audio/')) + } + + getImageUploads = (urls: any[]) => { + return urls.filter((url: any) => url.mime.startsWith('image/')) + } + _chainType() { return 'vision_chain' } diff --git a/packages/ui/src/views/chatmessage/ChatMessage.js b/packages/ui/src/views/chatmessage/ChatMessage.js index 79a9b6e0..37b45bd5 100644 --- a/packages/ui/src/views/chatmessage/ChatMessage.js +++ b/packages/ui/src/views/chatmessage/ChatMessage.js @@ -14,7 +14,6 @@ import { Box, Button, Card, - CardActions, CardMedia, Chip, CircularProgress, @@ -48,7 +47,6 @@ import { baseURL, maxScroll } from 'store/constant' import robotPNG from 'assets/images/robot.png' import userPNG from 'assets/images/account.png' import { isValidURL, removeDuplicateURL, setLocalStorageChatflow } from 'utils/genericHelper' -import DeleteIcon from '@mui/icons-material/Delete' export const ChatMessage = ({ open, chatflowid, isDialog }) => { const theme = useTheme() @@ -628,15 +626,25 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => { {message.fileUploads && message.fileUploads.map((item, index) => { return ( - - - + <> + {item.mime.startsWith('image/') ? ( + + + + ) : ( + // eslint-disable-next-line jsx-a11y/media-has-caption + + )} + ) })} {message.sourceDocuments && ( @@ -738,23 +746,23 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => { {previews.map((item, index) => ( - - - -