Make the chain multi-modal: it now accepts audio and image uploads and can run inference on them.

This commit is contained in:
vinodkiran
2023-12-09 22:07:16 +05:30
parent 32575828cd
commit 1b308a8b54
4 changed files with 96 additions and 38 deletions
@@ -7,6 +7,7 @@ class OpenAIAudioWhisper implements INode {
description: string description: string
type: string type: string
icon: string icon: string
badge: string
category: string category: string
baseClasses: string[] baseClasses: string[]
inputs: INodeParams[] inputs: INodeParams[]
@@ -18,6 +19,7 @@ class OpenAIAudioWhisper implements INode {
this.type = 'OpenAIWhisper' this.type = 'OpenAIWhisper'
this.description = 'Speech to text using OpenAI Whisper API' this.description = 'Speech to text using OpenAI Whisper API'
this.icon = 'audio.svg' this.icon = 'audio.svg'
this.badge = 'BETA'
this.category = 'MultiModal' this.category = 'MultiModal'
this.baseClasses = [this.type] this.baseClasses = [this.type]
this.inputs = [ this.inputs = [
@@ -27,14 +29,15 @@ class OpenAIAudioWhisper implements INode {
type: 'options', type: 'options',
options: [ options: [
{ {
label: 'transcription', label: 'Transcription',
name: 'transcription' name: 'transcription'
}, },
{ {
label: 'translation', label: 'Translation',
name: 'translation' name: 'translation'
} }
] ],
default: 'transcription'
}, },
{ {
label: 'Accepted Upload Types', label: 'Accepted Upload Types',
@@ -54,7 +57,9 @@ class OpenAIAudioWhisper implements INode {
} }
async init(nodeData: INodeData): Promise<any> { async init(nodeData: INodeData): Promise<any> {
return {} const purpose = nodeData.inputs?.purpose as string
return { purpose }
} }
} }
@@ -132,7 +132,7 @@ class OpenAIVisionChain_Chains implements INode {
this.outputs = [ this.outputs = [
{ {
label: 'Open AI MultiModal Chain', label: 'Open AI MultiModal Chain',
name: 'OpenAIMultiModalChain', name: 'openAIMultiModalChain',
baseClasses: [this.type, ...getBaseClasses(VLLMChain)] baseClasses: [this.type, ...getBaseClasses(VLLMChain)]
}, },
{ {
@@ -154,6 +154,8 @@ class OpenAIVisionChain_Chains implements INode {
const modelName = nodeData.inputs?.modelName as string const modelName = nodeData.inputs?.modelName as string
const maxTokens = nodeData.inputs?.maxTokens as string const maxTokens = nodeData.inputs?.maxTokens as string
const topP = nodeData.inputs?.topP as string const topP = nodeData.inputs?.topP as string
const whisperConfig = nodeData.inputs?.audioInput
const fields: OpenAIVisionChainInput = { const fields: OpenAIVisionChainInput = {
openAIApiKey: openAIApiKey, openAIApiKey: openAIApiKey,
imageResolution: imageResolution, imageResolution: imageResolution,
@@ -164,6 +166,8 @@ class OpenAIVisionChain_Chains implements INode {
if (temperature) fields.temperature = parseFloat(temperature) if (temperature) fields.temperature = parseFloat(temperature)
if (maxTokens) fields.maxTokens = parseInt(maxTokens, 10) if (maxTokens) fields.maxTokens = parseInt(maxTokens, 10)
if (topP) fields.topP = parseFloat(topP) if (topP) fields.topP = parseFloat(topP)
if (whisperConfig) fields.whisperConfig = whisperConfig
if (output === this.name) { if (output === this.name) {
const chain = new VLLMChain({ const chain = new VLLMChain({
...fields, ...fields,
@@ -21,6 +21,7 @@ export interface OpenAIVisionChainInput extends ChainInputs {
modelName?: string modelName?: string
maxTokens?: number maxTokens?: number
topP?: number topP?: number
whisperConfig?: any
} }
/** /**
@@ -48,6 +49,8 @@ export class VLLMChain extends BaseChain implements OpenAIVisionChainInput {
maxTokens?: number maxTokens?: number
topP?: number topP?: number
whisperConfig?: any
constructor(fields: OpenAIVisionChainInput) { constructor(fields: OpenAIVisionChainInput) {
super(fields) super(fields)
this.throwError = fields?.throwError ?? false this.throwError = fields?.throwError ?? false
@@ -59,6 +62,7 @@ export class VLLMChain extends BaseChain implements OpenAIVisionChainInput {
this.maxTokens = fields?.maxTokens this.maxTokens = fields?.maxTokens
this.topP = fields?.topP this.topP = fields?.topP
this.imageUrls = fields?.imageUrls ?? [] this.imageUrls = fields?.imageUrls ?? []
this.whisperConfig = fields?.whisperConfig ?? {}
if (!this.openAIApiKey) { if (!this.openAIApiKey) {
throw new Error('OpenAI API key not found') throw new Error('OpenAI API key not found')
} }
@@ -92,15 +96,44 @@ export class VLLMChain extends BaseChain implements OpenAIVisionChainInput {
type: 'text', type: 'text',
text: userInput text: userInput
}) })
if (this.whisperConfig && this.imageUrls && this.imageUrls.length > 0) {
const audioUploads = this.getAudioUploads(this.imageUrls)
for (const url of audioUploads) {
const filePath = path.join(getUserHome(), '.flowise', 'gptvision', url.data, url.name)
// the audio file is stored on the server; open it as a read stream to pass to the Whisper API
const audio_file = fs.createReadStream(filePath)
if (this.whisperConfig.purpose === 'transcription') {
const transcription = await this.client.audio.transcriptions.create({
file: audio_file,
model: 'whisper-1'
})
userRole.content.push({
type: 'text',
text: transcription.text
})
} else if (this.whisperConfig.purpose === 'translation') {
const translation = await this.client.audio.translations.create({
file: audio_file,
model: 'whisper-1'
})
userRole.content.push({
type: 'text',
text: translation.text
})
}
}
}
if (this.imageUrls && this.imageUrls.length > 0) { if (this.imageUrls && this.imageUrls.length > 0) {
this.imageUrls.forEach((imageUrl: any) => { const imageUploads = this.getImageUploads(this.imageUrls)
let bf = imageUrl?.data for (const url of imageUploads) {
if (imageUrl.type == 'stored-file') { let bf = url.data
const filePath = path.join(getUserHome(), '.flowise', 'gptvision', imageUrl.data, imageUrl.name) if (url.type == 'stored-file') {
const filePath = path.join(getUserHome(), '.flowise', 'gptvision', url.data, url.name)
// as the image is stored in the server, read the file and convert it to base64 // as the image is stored in the server, read the file and convert it to base64
const contents = fs.readFileSync(filePath) const contents = fs.readFileSync(filePath)
bf = 'data:' + imageUrl.mime + ';base64,' + contents.toString('base64') bf = 'data:' + url.mime + ';base64,' + contents.toString('base64')
} }
userRole.content.push({ userRole.content.push({
type: 'image_url', type: 'image_url',
@@ -109,7 +142,7 @@ export class VLLMChain extends BaseChain implements OpenAIVisionChainInput {
detail: this.imageResolution detail: this.imageResolution
} }
}) })
}) }
} }
vRequest.messages.push(userRole) vRequest.messages.push(userRole)
if (this.prompt && this.prompt instanceof ChatPromptTemplate) { if (this.prompt && this.prompt instanceof ChatPromptTemplate) {
@@ -146,6 +179,14 @@ export class VLLMChain extends BaseChain implements OpenAIVisionChainInput {
} }
} }
/**
 * Selects the uploads whose MIME type marks them as audio.
 *
 * @param urls - upload descriptors; each is expected to carry a `mime` string
 * @returns the entries whose `mime` starts with `audio/`, in their original order
 */
getAudioUploads = (urls: any[]) => {
    // Guard the access: an entry without a string `mime` is skipped rather than
    // raising a TypeError from `undefined.startsWith`.
    return urls.filter((url: any) => typeof url?.mime === 'string' && url.mime.startsWith('audio/'))
}
/**
 * Selects the uploads whose MIME type marks them as images.
 *
 * @param urls - upload descriptors; each is expected to carry a `mime` string
 * @returns the entries whose `mime` starts with `image/`, in their original order
 */
getImageUploads = (urls: any[]) => {
    // Guard the access: an entry without a string `mime` is skipped rather than
    // raising a TypeError from `undefined.startsWith`.
    return urls.filter((url: any) => typeof url?.mime === 'string' && url.mime.startsWith('image/'))
}
_chainType() { _chainType() {
return 'vision_chain' return 'vision_chain'
} }
@@ -14,7 +14,6 @@ import {
Box, Box,
Button, Button,
Card, Card,
CardActions,
CardMedia, CardMedia,
Chip, Chip,
CircularProgress, CircularProgress,
@@ -48,7 +47,6 @@ import { baseURL, maxScroll } from 'store/constant'
import robotPNG from 'assets/images/robot.png' import robotPNG from 'assets/images/robot.png'
import userPNG from 'assets/images/account.png' import userPNG from 'assets/images/account.png'
import { isValidURL, removeDuplicateURL, setLocalStorageChatflow } from 'utils/genericHelper' import { isValidURL, removeDuplicateURL, setLocalStorageChatflow } from 'utils/genericHelper'
import DeleteIcon from '@mui/icons-material/Delete'
export const ChatMessage = ({ open, chatflowid, isDialog }) => { export const ChatMessage = ({ open, chatflowid, isDialog }) => {
const theme = useTheme() const theme = useTheme()
@@ -628,15 +626,25 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
{message.fileUploads && {message.fileUploads &&
message.fileUploads.map((item, index) => { message.fileUploads.map((item, index) => {
return ( return (
<Card key={index} sx={{ maxWidth: 128, margin: 5 }}> <>
<CardMedia {item.mime.startsWith('image/') ? (
component='img' <Card key={index} sx={{ maxWidth: 128, margin: 5 }}>
image={item.data} <CardMedia
sx={{ height: 64 }} component='img'
alt={'preview'} image={item.data}
style={messageImageStyle} sx={{ height: 64 }}
/> alt={'preview'}
</Card> style={messageImageStyle}
/>
</Card>
) : (
// eslint-disable-next-line jsx-a11y/media-has-caption
<audio controls='controls'>
Your browser does not support the &lt;audio&gt; tag.
<source src={item.data} type={item.mime} />
</audio>
)}
</>
) )
})} })}
{message.sourceDocuments && ( {message.sourceDocuments && (
@@ -738,23 +746,23 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
<Grid container spacing={2} sx={{ p: 1, mt: '5px', ml: '1px' }}> <Grid container spacing={2} sx={{ p: 1, mt: '5px', ml: '1px' }}>
{previews.map((item, index) => ( {previews.map((item, index) => (
<Grid item xs={12} sm={6} md={3} key={index}> <Grid item xs={12} sm={6} md={3} key={index}>
<Card variant='outlined' sx={{ maxWidth: 128 }}> {item.mime.startsWith('image/') ? (
<CardMedia <Card key={index} sx={{ maxWidth: 128, margin: 5 }}>
component='img' <CardMedia
image={item.preview} component='img'
sx={{ height: 64 }} image={item.data}
alt={`preview ${index}`} sx={{ height: 64 }}
style={previewStyle} alt={'preview'}
/> style={previewStyle}
<CardActions className='center' sx={{ p: 0, m: 0 }}>
<Button
startIcon={<DeleteIcon />}
onClick={() => handleDeletePreview(item)}
size='small'
variant='text'
/> />
</CardActions> </Card>
</Card> ) : (
// eslint-disable-next-line jsx-a11y/media-has-caption
<audio controls='controls'>
Your browser does not support the &lt;audio&gt; tag.
<source src={item.data} type={item.mime} />
</audio>
)}
</Grid> </Grid>
))} ))}
</Grid> </Grid>