add top K to vector stores

This commit is contained in:
Henry
2023-06-10 17:21:27 +01:00
parent 617b3bea96
commit d11cb5f4b4
36 changed files with 925 additions and 238 deletions
@@ -45,8 +45,9 @@ class BabyAGI_Agents implements INode {
const model = nodeData.inputs?.model as BaseChatModel
const vectorStore = nodeData.inputs?.vectorStore as VectorStore
const taskLoop = nodeData.inputs?.taskLoop as string
const k = (vectorStore as any)?.k ?? 4
const babyAgi = BabyAGI.fromLLM(model, vectorStore, parseInt(taskLoop, 10))
const babyAgi = BabyAGI.fromLLM(model, vectorStore, parseInt(taskLoop, 10), k)
return babyAgi
}
@@ -154,18 +154,22 @@ export class BabyAGI {
maxIterations = 3
topK = 4
constructor(
taskCreationChain: TaskCreationChain,
taskPrioritizationChain: TaskPrioritizationChain,
executionChain: ExecutionChain,
vectorStore: VectorStore,
maxIterations: number
maxIterations: number,
topK: number
) {
this.taskCreationChain = taskCreationChain
this.taskPrioritizationChain = taskPrioritizationChain
this.executionChain = executionChain
this.vectorStore = vectorStore
this.maxIterations = maxIterations
this.topK = topK
}
addTask(task: Task) {
@@ -219,7 +223,7 @@ export class BabyAGI {
this.printNextTask(task)
// Step 2: Execute the task
const result = await executeTask(this.vectorStore, this.executionChain, objective, task.task_name)
const result = await executeTask(this.vectorStore, this.executionChain, objective, task.task_name, this.topK)
const thisTaskId = task.task_id
finalResult = result
this.printTaskResult(result)
@@ -257,10 +261,10 @@ export class BabyAGI {
return finalResult
}
static fromLLM(llm: BaseChatModel, vectorstore: VectorStore, maxIterations = 3): BabyAGI {
static fromLLM(llm: BaseChatModel, vectorstore: VectorStore, maxIterations = 3, topK = 4): BabyAGI {
const taskCreationChain = TaskCreationChain.from_llm(llm)
const taskPrioritizationChain = TaskPrioritizationChain.from_llm(llm)
const executionChain = ExecutionChain.from_llm(llm)
return new BabyAGI(taskCreationChain, taskPrioritizationChain, executionChain, vectorstore, maxIterations)
return new BabyAGI(taskCreationChain, taskPrioritizationChain, executionChain, vectorstore, maxIterations, topK)
}
}