add multi prompt chain

This commit is contained in:
Henry
2023-05-25 00:36:09 +01:00
parent f2d8796920
commit 1d42456ba1
7 changed files with 609 additions and 1 deletions
@@ -56,7 +56,8 @@ class ConversationChain_Chains implements INode {
const obj: any = {
llm: model,
memory
memory,
verbose: process.env.DEBUG === 'true' ? true : false
}
const chatPrompt = ChatPromptTemplate.fromPromptMessages([
@@ -0,0 +1,68 @@
import { BaseLanguageModel } from 'langchain/base_language'
import { INode, INodeData, INodeParams, PromptRetriever } from '../../../src/Interface'
import { getBaseClasses } from '../../../src/utils'
import { MultiPromptChain } from 'langchain/chains'
class MultiPromptChain_Chains implements INode {
label: string
name: string
type: string
icon: string
category: string
baseClasses: string[]
description: string
inputs: INodeParams[]
constructor() {
this.label = 'Multi Prompt Chain'
this.name = 'multiPromptChain'
this.type = 'MultiPromptChain'
this.icon = 'chain.svg'
this.category = 'Chains'
this.description = 'Chain automatically picks an appropriate prompt from multiple prompt templates'
this.baseClasses = [this.type, ...getBaseClasses(MultiPromptChain)]
this.inputs = [
{
label: 'Language Model',
name: 'model',
type: 'BaseLanguageModel'
},
{
label: 'Prompt Retriever',
name: 'promptRetriever',
type: 'PromptRetriever',
list: true
}
]
}
async init(nodeData: INodeData): Promise<any> {
const model = nodeData.inputs?.model as BaseLanguageModel
const promptRetriever = nodeData.inputs?.promptRetriever as PromptRetriever[]
const promptNames = []
const promptDescriptions = []
const promptTemplates = []
for (const prompt of promptRetriever) {
promptNames.push(prompt.name)
promptDescriptions.push(prompt.description)
promptTemplates.push(prompt.systemMessage)
}
const chain = MultiPromptChain.fromPrompts(model, promptNames, promptDescriptions, promptTemplates, undefined, {
verbose: process.env.DEBUG === 'true' ? true : false
} as any)
return chain
}
async run(nodeData: INodeData, input: string): Promise<string> {
const chain = nodeData.instance as MultiPromptChain
const res = await chain.call({ input })
return res?.text
}
}
module.exports = { nodeClass: MultiPromptChain_Chains }
@@ -0,0 +1,6 @@
<svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-dna" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
<path stroke="none" d="M0 0h24v24H0z" fill="none"></path>
<path d="M14.828 14.828a4 4 0 1 0 -5.656 -5.656a4 4 0 0 0 5.656 5.656z"></path>
<path d="M9.172 20.485a4 4 0 1 0 -5.657 -5.657"></path>
<path d="M14.828 3.515a4 4 0 0 0 5.657 5.657"></path>
</svg>

After

Width:  |  Height:  |  Size: 489 B

@@ -0,0 +1,62 @@
import { INode, INodeData, INodeParams, PromptRetriever, PromptRetrieverInput } from '../../../src/Interface'
class PromptRetriever_Retrievers implements INode {
label: string
name: string
description: string
type: string
icon: string
category: string
baseClasses: string[]
inputs: INodeParams[]
constructor() {
this.label = 'Prompt Retriever'
this.name = 'promptRetriever'
this.type = 'PromptRetriever'
this.icon = 'promptretriever.svg'
this.category = 'Retrievers'
this.description = 'Store prompt template with name & description to be later queried by MultiPromptChain'
this.baseClasses = [this.type]
this.inputs = [
{
label: 'Prompt Name',
name: 'name',
type: 'string',
placeholder: 'physics-qa'
},
{
label: 'Prompt Description',
name: 'description',
type: 'string',
rows: 3,
description: 'Description of what the prompt does and when it should be used',
placeholder: 'Good for answering questions about physics'
},
{
label: 'Prompt System Message',
name: 'systemMessage',
type: 'string',
rows: 4,
placeholder: `You are a very smart physics professor. You are great at answering questions about physics in a concise and easy to understand manner. When you don't know the answer to a question you admit that you don't know.`
}
]
}
async init(nodeData: INodeData): Promise<any> {
const name = nodeData.inputs?.name as string
const description = nodeData.inputs?.description as string
const systemMessage = nodeData.inputs?.systemMessage as string
const obj = {
name,
description,
systemMessage
} as PromptRetrieverInput
const retriever = new PromptRetriever(obj)
return retriever
}
}
module.exports = { nodeClass: PromptRetriever_Retrievers }
@@ -0,0 +1,8 @@
<svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-message-down" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
<path stroke="none" d="M0 0h24v24H0z" fill="none"></path>
<path d="M8 9h8"></path>
<path d="M8 13h6"></path>
<path d="M11.998 18.601l-3.998 2.399v-3h-2a3 3 0 0 1 -3 -3v-8a3 3 0 0 1 3 -3h12a3 3 0 0 1 3 3v5.5"></path>
<path d="M19 16v6"></path>
<path d="M22 19l-3 3l-3 -3"></path>
</svg>

After

Width:  |  Height:  |  Size: 535 B

+21
View File
@@ -103,3 +103,24 @@ export class PromptTemplate extends LangchainPromptTemplate {
super(input)
}
}
export interface PromptRetrieverInput {
name: string
description: string
systemMessage: string
}
const fixedTemplate = `Here is a question:
{input}
`
export class PromptRetriever {
name: string
description: string
systemMessage: string
constructor(fields: PromptRetrieverInput) {
this.name = fields.name
this.description = fields.description
this.systemMessage = `${fields.systemMessage}\n${fixedTemplate}`
}
}