Add more nodes for agents, loaders

This commit is contained in:
Henry
2023-04-10 13:56:44 +01:00
parent 05c86ff9c5
commit 58e06718d1
57 changed files with 1584 additions and 89 deletions
@@ -0,0 +1,74 @@
import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
import { getBaseClasses } from '../../../src/utils'
class ConversationalRetrievalQAChain_Chains implements INode {
label: string
name: string
type: string
icon: string
category: string
baseClasses: string[]
description: string
inputs: INodeParams[]
constructor() {
this.label = 'Conversational Retrieval QA Chain'
this.name = 'conversationalRetrievalQAChain'
this.type = 'ConversationalRetrievalQAChain'
this.icon = 'chain.svg'
this.category = 'Chains'
this.description = 'Document QA - built on RetrievalQAChain to provide a chat history component'
this.inputs = [
{
label: 'LLM',
name: 'llm',
type: 'BaseLanguageModel'
},
{
label: 'Vector Store Retriever',
name: 'vectorStoreRetriever',
type: 'BaseRetriever'
}
]
}
async getBaseClasses(): Promise<string[]> {
const { ConversationalRetrievalQAChain } = await import('langchain/chains')
return getBaseClasses(ConversationalRetrievalQAChain)
}
async init(nodeData: INodeData): Promise<any> {
const { ConversationalRetrievalQAChain } = await import('langchain/chains')
const llm = nodeData.inputs?.llm
const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever
const chain = ConversationalRetrievalQAChain.fromLLM(llm, vectorStoreRetriever)
return chain
}
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
const chain = nodeData.instance
let chatHistory = ''
if (options && options.chatHistory) {
const histories: IMessage[] = options.chatHistory
chatHistory = histories
.map((item) => {
return item.message
})
.join('')
}
const obj = {
question: input,
chat_history: chatHistory ? chatHistory : []
}
const res = await chain.call(obj)
return res?.text
}
}
module.exports = { nodeClass: ConversationalRetrievalQAChain_Chains }
@@ -0,0 +1,6 @@
<svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-dna" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
<path stroke="none" d="M0 0h24v24H0z" fill="none"></path>
<path d="M14.828 14.828a4 4 0 1 0 -5.656 -5.656a4 4 0 0 0 5.656 5.656z"></path>
<path d="M9.172 20.485a4 4 0 1 0 -5.657 -5.657"></path>
<path d="M14.828 3.515a4 4 0 0 0 5.657 5.657"></path>
</svg>

After

Width:  |  Height:  |  Size: 489 B

@@ -28,6 +28,17 @@ class LLMChain_Chains implements INode {
label: 'Prompt',
name: 'prompt',
type: 'BasePromptTemplate'
},
{
label: 'Format Prompt Values',
name: 'promptValues',
type: 'string',
rows: 5,
placeholder: `{
"input_language": "English",
"output_language": "French"
}`,
optional: true
}
]
}
@@ -48,13 +59,39 @@ class LLMChain_Chains implements INode {
}
async run(nodeData: INodeData, input: string): Promise<string> {
const prompt = nodeData.instance.prompt.inputVariables // ["product"]
if (prompt.length > 1) throw new Error('Prompt can only contains 1 literal string {}. Multiples are found')
const inputVariables = nodeData.instance.prompt.inputVariables // ["product"]
const chain = nodeData.instance
const res = await chain.run(input)
return res
if (inputVariables.length === 1) {
const res = await chain.run(input)
return res
} else if (inputVariables.length > 1) {
const promptValuesStr = nodeData.inputs?.promptValues as string
if (!promptValuesStr) throw new Error('Please provide Prompt Values')
const promptValues = JSON.parse(promptValuesStr.replace(/\s/g, ''))
let seen = []
for (const variable of inputVariables) {
seen.push(variable)
if (promptValues[variable]) {
seen.pop()
}
}
if (seen.length === 1) {
const options = {
...promptValues,
[seen.pop()]: input
}
const res = await chain.call(options)
return res?.text
} else throw new Error('Please provide Prompt Values')
} else {
const res = await chain.run(input)
return res
}
}
}
@@ -0,0 +1,57 @@
import { INode, INodeData, INodeParams } from '../../../src/Interface'
class RetrievalQAChain_Chains implements INode {
label: string
name: string
type: string
icon: string
category: string
baseClasses: string[]
description: string
inputs: INodeParams[]
constructor() {
this.label = 'RetrievalQA Chain'
this.name = 'retrievalQAChain'
this.type = 'RetrievalQAChain'
this.icon = 'chain.svg'
this.category = 'Chains'
this.description = 'QA chain to answer a question based on the retrieved documents'
this.inputs = [
{
label: 'LLM',
name: 'llm',
type: 'BaseLanguageModel'
},
{
label: 'Vector Store Retriever',
name: 'vectorStoreRetriever',
type: 'BaseRetriever'
}
]
}
async getBaseClasses(): Promise<string[]> {
return ['BaseChain']
}
async init(nodeData: INodeData): Promise<any> {
const { RetrievalQAChain } = await import('langchain/chains')
const llm = nodeData.inputs?.llm
const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever
const chain = RetrievalQAChain.fromLLM(llm, vectorStoreRetriever)
return chain
}
async run(nodeData: INodeData, input: string): Promise<string> {
const chain = nodeData.instance
const obj = {
query: input
}
const res = await chain.call(obj)
return res?.text
}
}
module.exports = { nodeClass: RetrievalQAChain_Chains }
@@ -0,0 +1,6 @@
<svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-dna" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
<path stroke="none" d="M0 0h24v24H0z" fill="none"></path>
<path d="M14.828 14.828a4 4 0 1 0 -5.656 -5.656a4 4 0 0 0 5.656 5.656z"></path>
<path d="M9.172 20.485a4 4 0 1 0 -5.657 -5.657"></path>
<path d="M14.828 3.515a4 4 0 0 0 5.657 5.657"></path>
</svg>

After

Width:  |  Height:  |  Size: 489 B