diff --git a/packages/components/nodes/embeddings/AWSBedrockEmbedding/AWSBedrockEmbedding.ts b/packages/components/nodes/embeddings/AWSBedrockEmbedding/AWSBedrockEmbedding.ts index 5c5f6352..4946fa8b 100644 --- a/packages/components/nodes/embeddings/AWSBedrockEmbedding/AWSBedrockEmbedding.ts +++ b/packages/components/nodes/embeddings/AWSBedrockEmbedding/AWSBedrockEmbedding.ts @@ -83,6 +83,24 @@ class AWSBedrockEmbedding_Embeddings implements INode { } ], optional: true + }, + { + label: 'Batch Size', + name: 'batchSize', + description: 'Documents batch size to send to AWS API for Titan model embeddings. Used to avoid throttling.', + type: 'number', + optional: true, + default: 50, + additionalParams: true + }, + { + label: 'Max AWS API retries', + name: 'maxRetries', + description: 'This will limit the nubmer of AWS API for Titan model embeddings call retries. Used to avoid throttling.', + type: 'number', + optional: true, + default: 5, + additionalParams: true } ] } @@ -144,7 +162,9 @@ class AWSBedrockEmbedding_Embeddings implements INode { if (iModel.startsWith('cohere')) { return await embedTextCohere(documents, client, iModel, inputType) } else { - return Promise.all(documents.map((document) => embedTextTitan(document, client, iModel))) + const batchSize = nodeData.inputs?.batchSize as number + const maxRetries = nodeData.inputs?.maxRetries as number + return processInBatches(documents, batchSize, maxRetries, (document) => embedTextTitan(document, client, iModel)) } } return model @@ -195,4 +215,38 @@ const embedTextCohere = async (texts: string[], client: BedrockRuntimeClient, mo } } +const processInBatches = async ( + documents: string[], + batchSize: number, + maxRetries: number, + processFunc: (document: string) => Promise +): Promise => { + let sleepTime = 0 + let retryCounter = 0 + let result: number[][] = [] + for (let i = 0; i < documents.length; i += batchSize) { + let chunk = documents.slice(i, i + batchSize) + try { + let chunkResult = await Promise.all(chunk.map(processFunc)) + result.push(...chunkResult) + retryCounter = 0 + } catch (e) { + if (retryCounter < maxRetries && e.name.includes('ThrottlingException')) { + retryCounter = retryCounter + 1 + i = i - batchSize + sleepTime = sleepTime + 100 + } else { + // Split to distinguish between throttling retry error and other errors in trance + if (e.name.includes('ThrottlingException')) { + throw new Error('AWS Bedrock retry limit reached: ' + e) + } else { + throw new Error(e) + } + } + } + await new Promise((resolve) => setTimeout(resolve, sleepTime)) + } + return result +} + module.exports = { nodeClass: AWSBedrockEmbedding_Embeddings }