From 1bd3b5d0ee9395ed5bce48eadc5c4e58f5385099 Mon Sep 17 00:00:00 2001 From: vinodkiran Date: Sat, 30 Dec 2023 08:07:15 +0530 Subject: [PATCH] Compression Retriever: Addition of constant to RRF Retriever --- .../retrievers/RRFRetriever/RRFRetriever.ts | 16 +++++++++++++++- .../RRFRetriever/ReciprocalRankFusion.ts | 6 ++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/packages/components/nodes/retrievers/RRFRetriever/RRFRetriever.ts b/packages/components/nodes/retrievers/RRFRetriever/RRFRetriever.ts index 8d6d9d6f..3229b3a8 100644 --- a/packages/components/nodes/retrievers/RRFRetriever/RRFRetriever.ts +++ b/packages/components/nodes/retrievers/RRFRetriever/RRFRetriever.ts @@ -57,6 +57,18 @@ class RRFRetriever_Retrievers implements INode { default: 0, additionalParams: true, optional: true + }, + { + label: 'Constant', + name: 'c', + description: + 'A constant added to the rank, controlling the balance between the importance of high-ranked items and the consideration given to lower-ranked items.\n' + + 'Default is 60', + placeholder: '60', + type: 'number', + default: 60, + additionalParams: true, + optional: true } ] } @@ -68,12 +80,14 @@ class RRFRetriever_Retrievers implements INode { const q = queryCount ? parseFloat(queryCount) : 4 const topK = nodeData.inputs?.topK as string let k = topK ? parseFloat(topK) : 4 + const constantC = nodeData.inputs?.c as string + let c = topK ? parseFloat(constantC) : 60 if (k <= 0) { k = (baseRetriever as VectorStoreRetriever).k } - const ragFusion = new ReciprocalRankFusion(llm, baseRetriever as VectorStoreRetriever, q, k) + const ragFusion = new ReciprocalRankFusion(llm, baseRetriever as VectorStoreRetriever, q, k, c) return new ContextualCompressionRetriever({ baseCompressor: ragFusion, baseRetriever: baseRetriever diff --git a/packages/components/nodes/retrievers/RRFRetriever/ReciprocalRankFusion.ts b/packages/components/nodes/retrievers/RRFRetriever/ReciprocalRankFusion.ts index 134d7c8a..b14608fe 100644 --- a/packages/components/nodes/retrievers/RRFRetriever/ReciprocalRankFusion.ts +++ b/packages/components/nodes/retrievers/RRFRetriever/ReciprocalRankFusion.ts @@ -10,13 +10,15 @@ export class ReciprocalRankFusion extends BaseDocumentCompressor { private readonly llm: BaseLanguageModel private readonly queryCount: number private readonly topK: number + private readonly c: number private baseRetriever: VectorStoreRetriever - constructor(llm: BaseLanguageModel, baseRetriever: VectorStoreRetriever, queryCount: number, topK: number) { + constructor(llm: BaseLanguageModel, baseRetriever: VectorStoreRetriever, queryCount: number, topK: number, c: number) { super() this.queryCount = queryCount this.llm = llm this.baseRetriever = baseRetriever this.topK = topK + this.c = c } async compressDocuments( documents: Document>[], @@ -57,7 +59,7 @@ export class ReciprocalRankFusion extends BaseDocumentCompressor { docList.push(docs) } - return this.reciprocalRankFunction(docList, 60) + return this.reciprocalRankFunction(docList, this.c) } reciprocalRankFunction(docList: Document>[][], k: number): Document>[] {