Support cache system instructs for Google GenAI (#4148)

* Support cache system instructs for Google GenAI

* format code

* Update FlowiseGoogleAICacheManager.ts

---------

Co-authored-by: Henry Heng <henryheng@flowiseai.com>
This commit is contained in:
Hans
2025-04-14 23:26:03 +08:00
committed by GitHub
parent 654bd48849
commit d3510d1054
5 changed files with 160 additions and 5 deletions
@@ -5,6 +5,7 @@ import { ICommonObject, IMultiModalOption, INode, INodeData, INodeOptionsValue,
import { convertMultiOptionsToStringArray, getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'
import { getModels, MODEL_TYPE } from '../../../src/modelLoader'
import { ChatGoogleGenerativeAI, GoogleGenerativeAIChatInput } from './FlowiseChatGoogleGenerativeAI'
import type FlowiseGoogleAICacheManager from '../../cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager'
class GoogleGenerativeAI_ChatModels implements INode {
label: string
@@ -42,6 +43,12 @@ class GoogleGenerativeAI_ChatModels implements INode {
type: 'BaseCache',
optional: true
},
{
label: 'Context Cache',
name: 'contextCache',
type: 'GoogleAICacheManager',
optional: true
},
{
label: 'Model Name',
name: 'modelName',
@@ -188,6 +195,7 @@ class GoogleGenerativeAI_ChatModels implements INode {
const harmCategory = nodeData.inputs?.harmCategory as string
const harmBlockThreshold = nodeData.inputs?.harmBlockThreshold as string
const cache = nodeData.inputs?.cache as BaseCache
const contextCache = nodeData.inputs?.contextCache as FlowiseGoogleAICacheManager
const streaming = nodeData.inputs?.streaming as boolean
const allowImageUploads = nodeData.inputs?.allowImageUploads as boolean
@@ -225,6 +233,7 @@ class GoogleGenerativeAI_ChatModels implements INode {
const model = new ChatGoogleGenerativeAI(nodeData.id, obj)
model.setMultiModalOption(multiModalOption)
if (contextCache) model.setContextCache(contextCache)
return model
}
@@ -25,6 +25,7 @@ import { StructuredToolInterface } from '@langchain/core/tools'
import { isStructuredTool } from '@langchain/core/utils/function_calling'
import { zodToJsonSchema } from 'zod-to-json-schema'
import { BaseLanguageModelCallOptions } from '@langchain/core/language_models/base'
import type FlowiseGoogleAICacheManager from '../../cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager'
const DEFAULT_IMAGE_MAX_TOKEN = 8192
const DEFAULT_IMAGE_MODEL = 'gemini-1.5-flash-latest'
@@ -86,6 +87,8 @@ class LangchainChatGoogleGenerativeAI
private client: GenerativeModel
private contextCache?: FlowiseGoogleAICacheManager
get _isMultimodalModel() {
return this.modelName.includes('vision') || this.modelName.startsWith('gemini-1.5')
}
@@ -147,7 +150,7 @@ class LangchainChatGoogleGenerativeAI
this.getClient()
}
getClient(tools?: Tool[]) {
async getClient(prompt?: Content[], tools?: Tool[]) {
this.client = new GenerativeAI(this.apiKey ?? '').getGenerativeModel({
model: this.modelName,
tools,
@@ -161,6 +164,14 @@ class LangchainChatGoogleGenerativeAI
topK: this.topK
}
})
if (this.contextCache) {
const cachedContent = await this.contextCache.lookup({
contents: prompt ? [{ ...prompt[0], parts: prompt[0].parts.slice(0, 1) }] : [],
model: this.modelName,
tools
})
this.client.cachedContent = cachedContent as any
}
}
_combineLLMOutput() {
@@ -209,6 +220,10 @@ class LangchainChatGoogleGenerativeAI
}
}
setContextCache(contextCache: FlowiseGoogleAICacheManager): void {
this.contextCache = contextCache
}
async getNumTokens(prompt: BaseMessage[]) {
const contents = convertBaseMessagesToContent(prompt, this._isMultimodalModel)
const { totalTokens } = await this.client.countTokens({ contents })
@@ -226,9 +241,9 @@ class LangchainChatGoogleGenerativeAI
this.convertFunctionResponse(prompt)
if (tools.length > 0) {
this.getClient(tools as Tool[])
await this.getClient(prompt, tools as Tool[])
} else {
this.getClient()
await this.getClient(prompt)
}
const res = await this.caller.callWithOptions({ signal: options?.signal }, async () => {
let output
@@ -296,9 +311,9 @@ class LangchainChatGoogleGenerativeAI
const tools = options.tools ?? []
if (tools.length > 0) {
this.getClient(tools as Tool[])
await this.getClient(prompt, tools as Tool[])
} else {
this.getClient()
await this.getClient(prompt)
}
const stream = await this.caller.callWithOptions({ signal: options?.signal }, async () => {