From c96572e10ff8fa954d973f10748c603a00d49ef9 Mon Sep 17 00:00:00 2001
From: vinodkiran
Date: Sat, 25 Nov 2023 16:39:02 +0530
Subject: [PATCH] GPT Vision - OpenAIVisionChain

---
 .../chains/VisionChain/OpenAIVisionChain.ts   | 206 ++++++++++++++++++
 .../nodes/chains/VisionChain/VLLMChain.ts     | 157 +++++++++++++
 packages/server/src/index.ts                  |  35 +++
 3 files changed, 398 insertions(+)
 create mode 100644 packages/components/nodes/chains/VisionChain/OpenAIVisionChain.ts
 create mode 100644 packages/components/nodes/chains/VisionChain/VLLMChain.ts

diff --git a/packages/components/nodes/chains/VisionChain/OpenAIVisionChain.ts b/packages/components/nodes/chains/VisionChain/OpenAIVisionChain.ts
new file mode 100644
index 00000000..f2260a76
--- /dev/null
+++ b/packages/components/nodes/chains/VisionChain/OpenAIVisionChain.ts
@@ -0,0 +1,206 @@
+import { ICommonObject, INode, INodeData, INodeOutputsValue, INodeParams } from '../../../src/Interface'
+import { getBaseClasses, handleEscapeCharacters } from '../../../src/utils'
+import { VLLMChain } from './VLLMChain'
+import { BaseLanguageModel } from 'langchain/base_language'
+import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
+import { formatResponse } from '../../outputparsers/OutputParserHelpers'
+import { ChatOpenAI } from 'langchain/chat_models/openai'
+
+/**
+ * Flowise node that validates the connected ChatOpenAI model is `gpt-4-vision-preview`
+ * and wraps it in a VLLMChain, exposing either the chain or its prediction as output.
+ */
+class OpenAIVisionChain_Chains implements INode {
+    label: string
+    name: string
+    version: number
+    type: string
+    icon: string
+    category: string
+    baseClasses: string[]
+    description: string
+    inputs: INodeParams[]
+    outputs: INodeOutputsValue[]
+
+    constructor() {
+        this.label = 'Open AI Vision Chain'
+        this.name = 'openAIVisionChain'
+        this.version = 3.0
+        this.type = 'OpenAIVisionChain'
+        this.icon = 'chain.svg'
+        this.category = 'Chains'
+        this.description = 'Chain to run queries against OpenAI (GPT-4) Vision .'
+        this.baseClasses = [this.type, ...getBaseClasses(VLLMChain)]
+        this.inputs = [
+            {
+                label: 'Language Model (Works only with Open AI [gpt-4-vision-preview])',
+                name: 'model',
+                type: 'BaseLanguageModel'
+            },
+            {
+                label: 'Prompt',
+                name: 'prompt',
+                type: 'BasePromptTemplate',
+                optional: true
+            },
+            {
+                label: 'Image Resolution',
+                description: 'This parameter controls the resolution in which the model views the image.',
+                name: 'imageResolution',
+                type: 'options',
+                options: [
+                    {
+                        label: 'Low',
+                        name: 'low'
+                    },
+                    {
+                        label: 'High',
+                        name: 'high'
+                    }
+                ],
+                default: 'low',
+                optional: false
+            },
+            {
+                label: 'Chain Name',
+                name: 'chainName',
+                type: 'string',
+                placeholder: 'Name Your Chain',
+                optional: true
+            }
+        ]
+        this.outputs = [
+            {
+                label: 'Open AI Vision Chain',
+                name: 'openAIVisionChain',
+                baseClasses: [this.type, ...getBaseClasses(VLLMChain)]
+            },
+            {
+                label: 'Output Prediction',
+                name: 'outputPrediction',
+                baseClasses: ['string', 'json']
+            }
+        ]
+    }
+
+    /**
+     * Builds the vision chain and returns it, or, when the `outputPrediction` output is
+     * selected, runs it against `input` and returns the escaped prediction.
+     * @throws Error when the connected model is not an OpenAI `gpt-4-vision-preview` model.
+     */
+    async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
+        const model = nodeData.inputs?.model as BaseLanguageModel
+        const prompt = nodeData.inputs?.prompt
+        const output = nodeData.outputs?.output as string
+        const imageResolution = nodeData.inputs?.imageResolution
+        if (!(model as any)?.openAIApiKey || (model as any).modelName !== 'gpt-4-vision-preview') {
+            throw new Error('Chain works with OpenAI Vision model only')
+        }
+        const openAIModel = model as ChatOpenAI
+        // The prompt is optional; both outputs receive it so system messages apply either way
+        const chain = new VLLMChain({
+            openAIApiKey: openAIModel.openAIApiKey,
+            imageResolution: imageResolution,
+            verbose: process.env.DEBUG === 'true',
+            imageUrls: options.url,
+            openAIModel: openAIModel,
+            prompt: prompt
+        })
+        if (output === this.name) {
+            return chain
+        } else if (output === 'outputPrediction') {
+            const inputVariables: string[] = prompt?.inputVariables ?? [] // ["product"]
+            const promptValues = prompt?.promptValues as ICommonObject | undefined
+            const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData)
+            // eslint-disable-next-line no-console
+            console.log('\x1b[92m\x1b[1m\n*****OUTPUT PREDICTION*****\n\x1b[0m\x1b[0m')
+            // eslint-disable-next-line no-console
+            console.log(res)
+            /**
+             * Apply string transformation to convert special chars:
+             * FROM: hello i am ben\n\n\thow are you?
+             * TO: hello i am benFLOWISE_NEWLINEFLOWISE_NEWLINEFLOWISE_TABhow are you?
+             */
+            return handleEscapeCharacters(res, false)
+        }
+    }
+
+    /**
+     * Runs the chain stored on `nodeData.instance` against `input` and returns the formatted result.
+     */
+    async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
+        const prompt = nodeData.inputs?.prompt
+        const inputVariables: string[] = prompt?.inputVariables ?? [] // ["product"]
+        const chain = nodeData.instance as VLLMChain
+        const promptValues = prompt?.promptValues as ICommonObject | undefined
+        const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData)
+        // eslint-disable-next-line no-console
+        console.log('\x1b[93m\x1b[1m\n*****FINAL RESULT*****\n\x1b[0m\x1b[0m')
+        // eslint-disable-next-line no-console
+        console.log(res)
+        return res
+    }
+}
+
+/**
+ * Executes the chain, filling prompt variables from the fixed prompt values and, when
+ * exactly one variable has no value, from the user's `input`.
+ * @throws Error when more than one prompt variable has no value.
+ */
+const runPrediction = async (
+    inputVariables: string[],
+    chain: VLLMChain,
+    input: string,
+    promptValuesRaw: ICommonObject | undefined,
+    options: ICommonObject,
+    nodeData: INodeData
+) => {
+    const loggerHandler = new ConsoleCallbackHandler(options.logger)
+    const callbacks = await additionalCallbacks(nodeData, options)
+
+    // Tokens are streamed to the client only when a socket connection is available
+    const isStreaming = options.socketIO && options.socketIOClientId
+    const handlers = isStreaming
+        ? [loggerHandler, new CustomChainHandler(options.socketIO, options.socketIOClientId), ...callbacks]
+        : [loggerHandler, ...callbacks]
+
+    /**
+     * Apply string transformation to reverse converted special chars:
+     * FROM: { "value": "hello i am benFLOWISE_NEWLINEFLOWISE_NEWLINEFLOWISE_TABhow are you?" }
+     * TO: { "value": "hello i am ben\n\n\thow are you?" }
+     */
+    const promptValues = promptValuesRaw ? handleEscapeCharacters(promptValuesRaw, true) : undefined
+    if (options?.url) {
+        chain.imageUrls = options.url
+    }
+
+    if (!promptValues || inputVariables.length === 0) {
+        const res = await chain.run(input, handlers)
+        return formatResponse(res)
+    }
+
+    // Collect the variables that have no fixed value; the last variable that does have
+    // one becomes the chain's input key
+    const missing: string[] = []
+    for (const variable of inputVariables) {
+        if (promptValues[variable]) {
+            chain.inputKey = variable
+        } else {
+            missing.push(variable)
+        }
+    }
+    if (missing.length > 1) {
+        throw new Error(`Please provide Prompt Values for: ${missing.join(', ')}`)
+    }
+
+    const chainValues: ICommonObject = { ...promptValues }
+    if (missing.length === 1) {
+        // If one inputVariable is not specified, use input (user's question) as value
+        chain.inputKey = missing[0]
+        chainValues[missing[0]] = input
+    }
+    const res = await chain.call(chainValues, handlers)
+    return formatResponse(res?.text)
+}
+
+module.exports = { nodeClass: OpenAIVisionChain_Chains }
diff --git a/packages/components/nodes/chains/VisionChain/VLLMChain.ts b/packages/components/nodes/chains/VisionChain/VLLMChain.ts
new file mode 100644
index 00000000..17260be2
--- /dev/null
+++ b/packages/components/nodes/chains/VisionChain/VLLMChain.ts
@@ -0,0 +1,157 @@
+import { OpenAI as OpenAIClient, ClientOptions } from 'openai'
+import { BaseChain, ChainInputs } from 'langchain/chains'
+import { ChainValues } from 'langchain/schema'
+import { BasePromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate } from 'langchain/prompts'
+import { ChatOpenAI } from 'langchain/chat_models/openai'
+
+/** Model that every request of this chain is sent to. */
+const VISION_MODEL_NAME = 'gpt-4-vision-preview'
+
+/**
+ * An image passed to the chain. Only `data` is read; it is sent as the image URL.
+ */
+export interface VisionImage {
+    data?: string
+}
+
+/**
+ * Interface for the input parameters of the VLLMChain class.
+ */
+export interface OpenAIVisionChainInput extends ChainInputs {
+    openAIApiKey?: string
+    openAIOrganization?: string
+    throwError?: boolean
+    prompt?: BasePromptTemplate
+    configuration?: ClientOptions
+    imageUrls?: VisionImage[]
+    imageResolution?: string
+    openAIModel: ChatOpenAI
+}
+
+/**
+ * Class representing a chain for generating text from an image using the OpenAI
+ * Vision API. It extends the BaseChain class and implements the
+ * OpenAIVisionChainInput interface.
+ */
+export class VLLMChain extends BaseChain implements OpenAIVisionChainInput {
+    static lc_name() {
+        return 'VLLMChain'
+    }
+
+    get lc_secrets(): { [key: string]: string } | undefined {
+        return {
+            openAIApiKey: 'OPENAI_API_KEY'
+        }
+    }
+    prompt: BasePromptTemplate | undefined
+
+    inputKey = 'input'
+    outputKey = 'text'
+    imageUrls?: VisionImage[]
+    imageResolution: string = 'low'
+    openAIApiKey?: string
+    openAIOrganization?: string
+    openAIModel: ChatOpenAI
+    clientConfig: ClientOptions
+    client: OpenAIClient
+    throwError: boolean
+
+    /**
+     * @throws Error when `fields.openAIApiKey` is missing.
+     */
+    constructor(fields: OpenAIVisionChainInput) {
+        super(fields)
+        this.throwError = fields?.throwError ?? false
+        this.imageResolution = fields?.imageResolution ?? 'low'
+        this.openAIApiKey = fields?.openAIApiKey
+        this.prompt = fields?.prompt
+        this.imageUrls = fields?.imageUrls ?? []
+        if (!this.openAIApiKey) {
+            throw new Error('OpenAI API key not found')
+        }
+
+        this.openAIOrganization = fields?.openAIOrganization
+        this.openAIModel = fields.openAIModel
+
+        this.clientConfig = {
+            ...fields?.configuration,
+            apiKey: this.openAIApiKey,
+            organization: this.openAIOrganization
+        }
+
+        this.client = new OpenAIClient(this.clientConfig)
+    }
+
+    /**
+     * Sends the system messages of the prompt (if any), the user's text and the images
+     * to the chat completions endpoint and returns the reply under `outputKey`.
+     * @throws Error when the request fails or the response contains no choices.
+     */
+    async _call(values: ChainValues): Promise<ChainValues> {
+        // Fall back to the first declared input key when the configured one has no value
+        const userInput = values[this.inputKey] ?? values[this.inputKeys[0]]
+
+        const vRequest: any = {
+            model: VISION_MODEL_NAME,
+            temperature: this.openAIModel.temperature,
+            top_p: this.openAIModel.topP,
+            messages: []
+        }
+        if (this.openAIModel.maxTokens) vRequest.max_tokens = this.openAIModel.maxTokens
+
+        // System instructions go first so they precede the user turn they apply to
+        const prompt = this.prompt
+        if (prompt instanceof ChatPromptTemplate) {
+            for (const message of prompt.promptMessages) {
+                if (message instanceof SystemMessagePromptTemplate) {
+                    vRequest.messages.push({
+                        role: 'system',
+                        content: (message.prompt as any).template
+                    })
+                }
+            }
+        }
+
+        const userContent: any[] = [{ type: 'text', text: userInput }]
+        for (const imageUrl of this.imageUrls ?? []) {
+            userContent.push({
+                type: 'image_url',
+                image_url: {
+                    url: imageUrl?.data,
+                    detail: this.imageResolution
+                }
+            })
+        }
+        vRequest.messages.push({ role: 'user', content: userContent })
+
+        let response
+        try {
+            // @ts-ignore
+            response = await this.client.chat.completions.create(vRequest)
+        } catch (error) {
+            // Always surface a real Error, even if the client rejected with something else
+            throw error instanceof Error ? error : new Error(String(error))
+        }
+        const output = response.choices?.[0]
+        if (!output) {
+            throw new Error('OpenAI Vision API returned no choices')
+        }
+        return {
+            [this.outputKey]: output.message.content
+        }
+    }
+
+    _chainType() {
+        return 'vision_chain'
+    }
+
+    get inputKeys() {
+        // A prompt without variables must still expose one key to carry the user's text
+        const promptVariables = this.prompt?.inputVariables
+        return promptVariables && promptVariables.length > 0 ? promptVariables : [this.inputKey]
+    }
+
+    get outputKeys(): string[] {
+        return [this.outputKey]
+    }
+}
diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts
index 91de4f4c..9bc3eb3a 100644
--- a/packages/server/src/index.ts
+++ b/packages/server/src/index.ts
@@ -403,6 +403,19 @@ export class App {
             return res.json(obj)
         })
 
+        // Check if chatflow valid for uploads
+        this.app.get('/api/v1/chatflows-uploads/:id', async (req: Request, res: Response) => {
+            const chatflow = await this.AppDataSource.getRepository(ChatFlow).findOneBy({
+                id: req.params.id
+            })
+            if (!chatflow) return res.status(404).send(`Chatflow ${req.params.id} not found`)
+
+            const obj = {
+                allowUploads: this.shouldAllowUploads(chatflow)
+            }
+            return res.json(obj)
+        })
+
         // ----------------------------------------
         // ChatMessage
         // ----------------------------------------
@@ -1241,6 +1254,28 @@ export class App {
         })
     }
 
+    private readonly uploadAllowedNodes = ['OpenAIVisionChain']
+
+    /**
+     * Returns true when the chatflow contains at least one node whose type accepts uploads.
+     * Flow data that cannot be parsed, or has no node list, never allows uploads.
+     */
+    private shouldAllowUploads(result: ChatFlow): boolean {
+        let flowObj: any
+        try {
+            flowObj = JSON.parse(result.flowData)
+        } catch (error) {
+            logger.error(`[server]: Invalid flowData in chatflow ${result.id}: ${error}`)
+            return false
+        }
+        const nodes: IReactFlowNode[] = Array.isArray(flowObj?.nodes) ? flowObj.nodes : []
+        return nodes.some((node) => {
+            const isEligible = this.uploadAllowedNodes.includes(node?.data?.type)
+            if (isEligible) logger.debug(`[server]: Found Eligible Node ${node.data.type}, Allowing Uploads.`)
+            return isEligible
+        })
+    }
+
     /**
      * Validate API Key
      * @param {Request} req