diff --git a/packages/components/nodes/retrievers/HydeRetriever/HydeRetriever.ts b/packages/components/nodes/retrievers/HydeRetriever/HydeRetriever.ts index 058f9886..a90f9870 100644 --- a/packages/components/nodes/retrievers/HydeRetriever/HydeRetriever.ts +++ b/packages/components/nodes/retrievers/HydeRetriever/HydeRetriever.ts @@ -1,8 +1,8 @@ import { VectorStore } from 'langchain/vectorstores/base' -import { INode, INodeData, INodeParams, HyDERetrieverInput } from '../../../src/Interface' -import { HydeRetriever } from 'langchain/retrievers/hyde' +import { INode, INodeData, INodeParams } from '../../../src/Interface' +import { HydeRetriever, HydeRetrieverOptions, PromptKey } from 'langchain/retrievers/hyde' import { BaseLanguageModel } from 'langchain/base_language' -import { Embeddings } from 'langchain/embeddings/base' +import { PromptTemplate } from 'langchain/prompts' class HydeRetriever_Retrievers implements INode { label: string @@ -33,11 +33,6 @@ class HydeRetriever_Retrievers implements INode { name: 'vectorStore', type: 'VectorStore' }, - { - label: 'Embeddings', - name: 'embeddings', - type: 'Embeddings' - }, { label: 'Prompt Key', name: 'promptKey', @@ -78,6 +73,16 @@ class HydeRetriever_Retrievers implements INode { ], default: 'websearch' }, + { + label: 'Custom Prompt', + name: 'customPrompt', + description: 'If custom prompt is used, this will override Prompt Key', + placeholder: 'Please write a passage to answer the question\nQuestion: {question}\nPassage:', + type: 'string', + rows: 4, + additionalParams: true, + optional: true + }, { label: 'Top K', name: 'topK', @@ -94,17 +99,19 @@ class HydeRetriever_Retrievers implements INode { async init(nodeData: INodeData): Promise { const llm = nodeData.inputs?.model as BaseLanguageModel const vectorStore = nodeData.inputs?.vectorStore as VectorStore - const embeddings = nodeData.inputs?.embeddings as Embeddings - const promptKey = nodeData.inputs?.promptKey as string - const topK = nodeData.inputs?.topK as number + const promptKey = nodeData.inputs?.promptKey as PromptKey + const customPrompt = nodeData.inputs?.customPrompt as string + const topK = nodeData.inputs?.topK as string + const k = topK ? parseInt(topK, 10) : 4 - const obj = { + const obj: HydeRetrieverOptions = { llm, vectorStore, - embeddings, - promptKey, - topK - } as HyDERetrieverInput + k + } + + if (customPrompt) obj.promptTemplate = PromptTemplate.fromTemplate(customPrompt) + else if (promptKey) obj.promptTemplate = promptKey const retriever = new HydeRetriever(obj) return retriever