Support caching system instructions for Google GenAI (#4148)

* Support caching system instructions for Google GenAI

* format code

* Update FlowiseGoogleAICacheManager.ts

---------

Co-authored-by: Henry Heng <henryheng@flowiseai.com>
This commit is contained in:
Hans
2025-04-14 23:26:03 +08:00
committed by GitHub
parent 654bd48849
commit d3510d1054
5 changed files with 160 additions and 5 deletions
@@ -25,6 +25,7 @@ import { StructuredToolInterface } from '@langchain/core/tools'
import { isStructuredTool } from '@langchain/core/utils/function_calling'
import { zodToJsonSchema } from 'zod-to-json-schema'
import { BaseLanguageModelCallOptions } from '@langchain/core/language_models/base'
import type FlowiseGoogleAICacheManager from '../../cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager'
const DEFAULT_IMAGE_MAX_TOKEN = 8192
const DEFAULT_IMAGE_MODEL = 'gemini-1.5-flash-latest'
@@ -86,6 +87,8 @@ class LangchainChatGoogleGenerativeAI
private client: GenerativeModel
private contextCache?: FlowiseGoogleAICacheManager
/**
 * Reports whether the configured model name is treated as multimodal.
 *
 * True when `modelName` contains the substring 'vision' or begins with
 * 'gemini-1.5'; false otherwise.
 */
get _isMultimodalModel() {
    const name = this.modelName
    const isVisionModel = name.includes('vision')
    const isGemini15 = name.startsWith('gemini-1.5')
    return isVisionModel || isGemini15
}
@@ -147,7 +150,7 @@ class LangchainChatGoogleGenerativeAI
this.getClient()
}
getClient(tools?: Tool[]) {
async getClient(prompt?: Content[], tools?: Tool[]) {
this.client = new GenerativeAI(this.apiKey ?? '').getGenerativeModel({
model: this.modelName,
tools,
@@ -161,6 +164,14 @@ class LangchainChatGoogleGenerativeAI
topK: this.topK
}
})
if (this.contextCache) {
const cachedContent = await this.contextCache.lookup({
contents: prompt ? [{ ...prompt[0], parts: prompt[0].parts.slice(0, 1) }] : [],
model: this.modelName,
tools
})
this.client.cachedContent = cachedContent as any
}
}
_combineLLMOutput() {
@@ -209,6 +220,10 @@ class LangchainChatGoogleGenerativeAI
}
}
/**
 * Attaches a context cache manager to this chat model instance.
 *
 * Once set, `getClient` calls the manager's `lookup` and assigns the result
 * to the client's `cachedContent`; when no manager is set, that step is skipped.
 *
 * @param contextCache - cache manager stored on this instance
 */
setContextCache(contextCache: FlowiseGoogleAICacheManager): void {
this.contextCache = contextCache
}
async getNumTokens(prompt: BaseMessage[]) {
const contents = convertBaseMessagesToContent(prompt, this._isMultimodalModel)
const { totalTokens } = await this.client.countTokens({ contents })
@@ -226,9 +241,9 @@ class LangchainChatGoogleGenerativeAI
this.convertFunctionResponse(prompt)
if (tools.length > 0) {
this.getClient(tools as Tool[])
await this.getClient(prompt, tools as Tool[])
} else {
this.getClient()
await this.getClient(prompt)
}
const res = await this.caller.callWithOptions({ signal: options?.signal }, async () => {
let output
@@ -296,9 +311,9 @@ class LangchainChatGoogleGenerativeAI
const tools = options.tools ?? []
if (tools.length > 0) {
this.getClient(tools as Tool[])
await this.getClient(prompt, tools as Tool[])
} else {
this.getClient()
await this.getClient(prompt)
}
const stream = await this.caller.callWithOptions({ signal: options?.signal }, async () => {