Compression Retriever: Addition of constant to RRF Retriever

This commit is contained in:
vinodkiran
2023-12-30 08:07:15 +05:30
parent d0ab21e733
commit 1bd3b5d0ee
2 changed files with 19 additions and 3 deletions
@@ -57,6 +57,18 @@ class RRFRetriever_Retrievers implements INode {
default: 0,
additionalParams: true,
optional: true
},
{
label: 'Constant',
name: 'c',
description:
'A constant added to the rank, controlling the balance between the importance of high-ranked items and the consideration given to lower-ranked items.\n' +
'Default is 60',
placeholder: '60',
type: 'number',
default: 60,
additionalParams: true,
optional: true
}
]
}
@@ -68,12 +80,14 @@ class RRFRetriever_Retrievers implements INode {
const q = queryCount ? parseFloat(queryCount) : 4
const topK = nodeData.inputs?.topK as string
let k = topK ? parseFloat(topK) : 4
const constantC = nodeData.inputs?.c as string
let c = topK ? parseFloat(constantC) : 60
if (k <= 0) {
k = (baseRetriever as VectorStoreRetriever).k
}
const ragFusion = new ReciprocalRankFusion(llm, baseRetriever as VectorStoreRetriever, q, k)
const ragFusion = new ReciprocalRankFusion(llm, baseRetriever as VectorStoreRetriever, q, k, c)
return new ContextualCompressionRetriever({
baseCompressor: ragFusion,
baseRetriever: baseRetriever
@@ -10,13 +10,15 @@ export class ReciprocalRankFusion extends BaseDocumentCompressor {
private readonly llm: BaseLanguageModel
private readonly queryCount: number
private readonly topK: number
private readonly c: number
private baseRetriever: VectorStoreRetriever
constructor(llm: BaseLanguageModel, baseRetriever: VectorStoreRetriever, queryCount: number, topK: number) {
constructor(llm: BaseLanguageModel, baseRetriever: VectorStoreRetriever, queryCount: number, topK: number, c: number) {
super()
this.queryCount = queryCount
this.llm = llm
this.baseRetriever = baseRetriever
this.topK = topK
this.c = c
}
async compressDocuments(
documents: Document<Record<string, any>>[],
@@ -57,7 +59,7 @@ export class ReciprocalRankFusion extends BaseDocumentCompressor {
docList.push(docs)
}
return this.reciprocalRankFunction(docList, 60)
return this.reciprocalRankFunction(docList, this.c)
}
reciprocalRankFunction(docList: Document<Record<string, any>>[][], k: number): Document<Record<string, any>>[] {