diff --git a/packages/components/nodes/retrievers/CohereRerankRetriever/CohereRerank.ts b/packages/components/nodes/retrievers/CohereRerankRetriever/CohereRerank.ts new file mode 100644 index 00000000..612581ed --- /dev/null +++ b/packages/components/nodes/retrievers/CohereRerankRetriever/CohereRerank.ts @@ -0,0 +1,51 @@ +import { Callbacks } from 'langchain/callbacks' +import { Document } from 'langchain/document' +import { BaseDocumentCompressor } from 'langchain/retrievers/document_compressors' +import axios from 'axios' +export class CohereRerank extends BaseDocumentCompressor { + private cohereAPIKey: any + private COHERE_API_URL = 'https://api.cohere.ai/v1/rerank' + private model: string + + constructor(cohereAPIKey: string, model: string) { + super() + this.cohereAPIKey = cohereAPIKey + this.model = model + } + async compressDocuments( + documents: Document>[], + query: string, + _?: Callbacks | undefined + ): Promise>[]> { + // avoid empty api call + if (documents.length === 0) { + return [] + } + const config = { + headers: { + Authorization: `Bearer ${this.cohereAPIKey}`, + 'Content-Type': 'application/json', + Accept: 'application/json' + } + } + const data = { + model: this.model, + max_chunks_per_doc: 10, + query: query, + return_documents: false, + documents: documents.map((doc) => doc.pageContent) + } + try { + let returnedDocs = await axios.post(this.COHERE_API_URL, data, config) + const finalResults: Document>[] = [] + returnedDocs.data.results.forEach((result: any) => { + const doc = documents[result.index] + doc.metadata.relevance_score = result.relevance_score + finalResults.push(doc) + }) + return finalResults + } catch (error) { + return documents + } + } +} diff --git a/packages/components/nodes/retrievers/CohereRerankRetriever/CohereRerankRetriever.ts b/packages/components/nodes/retrievers/CohereRerankRetriever/CohereRerankRetriever.ts new file mode 100644 index 00000000..2e7090bc --- /dev/null +++ b/packages/components/nodes/retrievers/CohereRerankRetriever/CohereRerankRetriever.ts @@ -0,0 +1,77 @@ +import { ICommonObject, INode, INodeData, INodeOutputsValue, INodeParams } from '../../../src/Interface' +import { BaseRetriever } from 'langchain/schema/retriever' +import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression' +import { getCredentialData, getCredentialParam } from '../../../src' +import { CohereRerank } from './CohereRerank' + +class CohereRerankRetriever_Retrievers implements INode { + label: string + name: string + version: number + description: string + type: string + icon: string + category: string + baseClasses: string[] + inputs: INodeParams[] + outputs: INodeOutputsValue[] + credential: INodeParams + badge: string + + constructor() { + this.label = 'Cohere Rerank Retriever' + this.name = 'cohereRerankRetriever' + this.version = 1.0 + this.type = 'Cohere Rerank Retriever' + this.icon = 'compressionRetriever.svg' + this.category = 'Retrievers' + this.badge = 'NEW' + this.description = 'Cohere Rerank indexes the documents from most to least semantically relevant to the query.' + this.baseClasses = [this.type, 'BaseRetriever'] + this.credential = { + label: 'Connect Credential', + name: 'credential', + type: 'credential', + credentialNames: ['cohereApi'] + } + this.inputs = [ + { + label: 'Base Retriever', + name: 'baseRetriever', + type: 'VectorStoreRetriever' + }, + { + label: 'Model Name', + name: 'model', + type: 'options', + options: [ + { + label: 'rerank-english-v2.0', + name: 'rerank-english-v2.0' + }, + { + label: 'rerank-multilingual-v2.0', + name: 'rerank-multilingual-v2.0' + } + ], + default: 'rerank-english-v2.0', + optional: true + } + ] + } + + async init(nodeData: INodeData, _: string, options: ICommonObject): Promise { + const baseRetriever = nodeData.inputs?.baseRetriever as BaseRetriever + const model = nodeData.inputs?.model as string + const credentialData = await getCredentialData(nodeData.credential ?? '', options) + const cohereApiKey = getCredentialParam('cohereApiKey', credentialData, nodeData) + + const cohereCompressor = new CohereRerank(cohereApiKey, model) + return new ContextualCompressionRetriever({ + baseCompressor: cohereCompressor, + baseRetriever: baseRetriever + }) + } +} + +module.exports = { nodeClass: CohereRerankRetriever_Retrievers } diff --git a/packages/components/nodes/retrievers/CohereRerankRetriever/compressionRetriever.svg b/packages/components/nodes/retrievers/CohereRerankRetriever/compressionRetriever.svg new file mode 100644 index 00000000..23c52d25 --- /dev/null +++ b/packages/components/nodes/retrievers/CohereRerankRetriever/compressionRetriever.svg @@ -0,0 +1,7 @@ + + + + + + + \ No newline at end of file