diff --git a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts index 737b0bf2..bd660b47 100644 --- a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts +++ b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts @@ -1,8 +1,9 @@ import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' -import { convertStringToArrayString, getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils' +import { convertMultiOptionsToStringArray, getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils' import { BaseCache } from 'langchain/schema' import { ChatGoogleGenerativeAI, GoogleGenerativeAIChatInput } from '@langchain/google-genai' import { HarmBlockThreshold, HarmCategory } from '@google/generative-ai' +import type { SafetySetting } from '@google/generative-ai' class GoogleGenerativeAI_ChatModels implements INode { label: string @@ -158,23 +159,10 @@ class GoogleGenerativeAI_ChatModels implements INode { const harmBlockThreshold = nodeData.inputs?.harmBlockThreshold as string const cache = nodeData.inputs?.cache as BaseCache - // safetySettings - let harmCategories: string[] = convertStringToArrayString(harmCategory) - let harmBlockThresholds: string[] = convertStringToArrayString(harmBlockThreshold) - if (harmCategories.length != harmBlockThresholds.length) - throw new Error(`Harm Category & Harm Block Threshold are not the same length`) - const safetySettings = harmCategories.map((value, index) => { - return { - category: value, - threshold: harmBlockThresholds[index] - } - }) - const obj: Partial = { apiKey: apiKey, modelName: modelName, - maxOutputTokens: 2048, - safetySettings: safetySettings.length > 0 ? safetySettings : undefined + maxOutputTokens: 2048 } if (maxOutputTokens) obj.maxOutputTokens = parseInt(maxOutputTokens, 10) @@ -185,8 +173,63 @@ class GoogleGenerativeAI_ChatModels implements INode { if (cache) model.cache = cache if (temperature) model.temperature = parseFloat(temperature) + // Safety Settings + let harmCategories: string[] = convertMultiOptionsToStringArray(harmCategory) + let harmBlockThresholds: string[] = convertMultiOptionsToStringArray(harmBlockThreshold) + if (harmCategories.length != harmBlockThresholds.length) + throw new Error(`Harm Category & Harm Block Threshold are not the same length`) + const safetySettings: SafetySetting[] = harmCategories.map((value, index) => { + return { + category: categoryInput(value), + threshold: thresholdInput(harmBlockThresholds[index]) + } + }) + if (safetySettings.length > 0) model.safetySettings = safetySettings + return model } } +const categoryInput = (categoryInput: string): HarmCategory => { + let categoryOutput: HarmCategory + switch (categoryInput) { + case HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: + categoryOutput = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT + break + case HarmCategory.HARM_CATEGORY_HATE_SPEECH: + categoryOutput = HarmCategory.HARM_CATEGORY_HATE_SPEECH + break + case HarmCategory.HARM_CATEGORY_HARASSMENT: + categoryOutput = HarmCategory.HARM_CATEGORY_HARASSMENT + break + case HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: + categoryOutput = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT + break + default: + categoryOutput = HarmCategory.HARM_CATEGORY_UNSPECIFIED + } + return categoryOutput +} + +const thresholdInput = (thresholdInput: string): HarmBlockThreshold => { + let thresholdOutput: HarmBlockThreshold + switch (thresholdInput) { + case HarmBlockThreshold.BLOCK_LOW_AND_ABOVE: + thresholdOutput = HarmBlockThreshold.BLOCK_LOW_AND_ABOVE + break + case HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: + thresholdOutput = HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE + break + case HarmBlockThreshold.BLOCK_NONE: + thresholdOutput = HarmBlockThreshold.BLOCK_NONE + break + case HarmBlockThreshold.BLOCK_ONLY_HIGH: + thresholdOutput = HarmBlockThreshold.BLOCK_ONLY_HIGH + break + default: + thresholdOutput = HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED + } + return thresholdOutput +} + module.exports = { nodeClass: GoogleGenerativeAI_ChatModels } diff --git a/packages/components/package.json b/packages/components/package.json index a2565430..55e84074 100644 --- a/packages/components/package.json +++ b/packages/components/package.json @@ -25,6 +25,7 @@ "@gomomento/sdk": "^1.51.1", "@gomomento/sdk-core": "^1.51.1", "@google-ai/generativelanguage": "^0.2.1", + "@google/generative-ai": "^0.1.3", "@huggingface/inference": "^2.6.1", "@langchain/google-genai": "^0.0.6", "@langchain/mistralai": "^0.0.6", diff --git a/packages/components/src/utils.ts b/packages/components/src/utils.ts index 88c553cf..2d983562 100644 --- a/packages/components/src/utils.ts +++ b/packages/components/src/utils.ts @@ -675,11 +675,11 @@ export const convertBaseMessagetoIMessage = (messages: BaseMessage[]): IMessage[ } /** - * Convert String to Array String + * Convert MultiOptions String to String Array * @param {string} inputString * @returns {string[]} */ -export const convertStringToArrayString = (inputString: string): string[] => { +export const convertMultiOptionsToStringArray = (inputString: string): string[] => { let ArrayString: string[] = [] if (inputString) { try {