Feature/Mistral FunctionAgent (#1912)

* add Mistral AI agent, add streaming of used tools

* fix AWS Bedrock imports

* update pnpm lock
This commit is contained in:
Henry Heng
2024-03-18 13:17:00 +08:00
committed by GitHub
parent 58122e985c
commit cd4c659009
13 changed files with 30546 additions and 29817 deletions
@@ -1,8 +1,36 @@
import { ChatCompletionResponse, ToolCalls as MistralAIToolCalls } from '@mistralai/mistralai'
import { BaseCache } from '@langchain/core/caches'
import { ChatMistralAI, ChatMistralAIInput } from '@langchain/mistralai'
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
import { NewTokenIndices } from '@langchain/core/callbacks/base'
import { ChatGeneration, ChatGenerationChunk, ChatResult } from '@langchain/core/outputs'
import {
MessageType,
type BaseMessage,
MessageContent,
AIMessage,
HumanMessage,
HumanMessageChunk,
AIMessageChunk,
ToolMessageChunk,
ChatMessageChunk
} from '@langchain/core/messages'
import { ChatMistralAI as LangchainChatMistralAI, ChatMistralAIInput } from '@langchain/mistralai'
import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'
/**
 * Token counters reported alongside a chat result.
 * Every field is optional: in the non-streaming path of `_generate` a field is
 * only set when the API response's `usage` object carries a truthy value for it,
 * and in the streaming path the object is returned empty.
 */
interface TokenUsage {
/** Taken from `usage.completion_tokens` of the API response. */
completionTokens?: number
/** Taken from `usage.prompt_tokens` of the API response. */
promptTokens?: number
/** Taken from `usage.total_tokens` of the API response. */
totalTokens?: number
}
/**
 * Shape of a single message as produced by `convertMessagesToMistralMessages`
 * and placed in the `messages` field of the completion request.
 */
type MistralAIInputMessage = {
// Role string sent to the API ('user', 'assistant', 'tool', 'system', ...).
role: string
// Only set by the converter for LangChain 'function' messages (copied from `msg.name`).
name?: string
// The converter always supplies a plain string here; the array form is allowed by the type only.
content: string | string[]
// Only set by the converter for LangChain 'tool' messages (copied from `additional_kwargs.tool_calls`).
tool_calls?: MistralAIToolCalls[] | any[]
}
class ChatMistral_ChatModels implements INode {
label: string
name: string
@@ -135,4 +163,243 @@ class ChatMistral_ChatModels implements INode {
}
}
class ChatMistralAI extends LangchainChatMistralAI {
/**
 * Produces a chat result for the given conversation.
 *
 * When `this.streaming` is set, the chunks yielded by `_streamResponseChunks`
 * are merged per completion index and returned with an (empty)
 * `estimatedTokenUsage`. Otherwise a single completion request is issued and
 * its choices and `usage` counters are mapped into the result.
 *
 * @param messages - Conversation to send to the model.
 * @param options - Parsed call options, forwarded to `invocationParams`.
 * @param runManager - Optional callback manager, forwarded to the streaming path.
 * @returns The generations plus `llmOutput` carrying the token usage.
 * @throws Error if a non-streaming choice contains a `delta` or has no `message`.
 */
async _generate(
    messages: BaseMessage[],
    options?: this['ParsedCallOptions'],
    runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
    const usage: TokenUsage = {}
    // Built up front (before the streaming check), exactly like the request
    // used by the non-streaming call below.
    const request = {
        ...this.invocationParams(options),
        messages: this.convertMessagesToMistralMessages(messages)
    }

    if (this.streaming) {
        // Fold every streamed chunk into the slot of its completion index.
        const mergedByIndex: Record<number, ChatGenerationChunk> = {}
        for await (const piece of this._streamResponseChunks(messages, options, runManager)) {
            const slot = (piece.generationInfo as NewTokenIndices)?.completion ?? 0
            const previous = mergedByIndex[slot]
            mergedByIndex[slot] = previous === undefined ? piece : previous.concat(piece)
        }
        // Return the merged chunks ordered by ascending completion index.
        const generations = Object.entries(mergedByIndex)
            .sort((left, right) => parseInt(left[0], 10) - parseInt(right[0], 10))
            .map((entry) => entry[1])
        return { generations, llmOutput: { estimatedTokenUsage: usage } }
    }

    // Non-streaming: one request, one response.
    const response = await this.completionWithRetry(request, false)

    // `usage` starts out empty, so each reported (truthy) counter is simply copied over.
    const { completion_tokens: completed, prompt_tokens: prompted, total_tokens: overall } = response?.usage ?? {}
    if (completed) {
        usage.completionTokens = completed
    }
    if (prompted) {
        usage.promptTokens = prompted
    }
    if (overall) {
        usage.totalTokens = overall
    }

    const generations: ChatGeneration[] = []
    for (const choice of response?.choices ?? []) {
        if ('delta' in choice) {
            throw new Error('Delta not supported in non-streaming mode.')
        }
        if (!('message' in choice)) {
            throw new Error('No message found in the choice.')
        }
        generations.push({
            text: choice.message?.content ?? '',
            message: this.mistralAIResponseToChatMessage(choice),
            // generationInfo is only attached when the API reported a finish reason.
            ...(choice.finish_reason ? { generationInfo: { finish_reason: choice.finish_reason } } : {})
        })
    }

    return { generations, llmOutput: { tokenUsage: usage } }
}
/**
 * Streams a chat completion and yields one `ChatGenerationChunk` per
 * non-empty delta of the first choice of each stream event.
 *
 * After each chunk is yielded, `runManager.handleLLMNewToken` is invoked
 * (not awaited) with the chunk text and its prompt/completion indices.
 *
 * @param messages - Conversation to send to the model.
 * @param options - Parsed call options, forwarded to `invocationParams`.
 * @param runManager - Optional callback manager notified of every new token.
 * @throws Error('AbortError') if `options.signal` is aborted once the stream has been consumed.
 */
async *_streamResponseChunks(
    messages: BaseMessage[],
    options?: this['ParsedCallOptions'],
    runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
    const mistralMessages = this.convertMessagesToMistralMessages(messages)
    const params = this.invocationParams(options)
    const input = {
        ...params,
        messages: mistralMessages
    }
    const streamIterable = await this.completionWithRetry(input, true)
    for await (const data of streamIterable) {
        // Guard the event AND its `choices` list: an event without a `choices`
        // array is skipped like any other empty event instead of raising a
        // TypeError that would abort the whole stream.
        const choice = data?.choices?.[0]
        if (!choice || !('delta' in choice)) {
            continue
        }
        const { delta } = choice
        if (!delta) {
            continue
        }
        const newTokenIndices = {
            prompt: 0,
            completion: choice.index ?? 0
        }
        const message = this._convertDeltaToMessageChunk(delta)
        if (message === null) {
            // Do not yield a chunk if the message is empty
            continue
        }
        const generationChunk = new ChatGenerationChunk({
            message,
            text: delta.content ?? '',
            generationInfo: newTokenIndices
        })
        yield generationChunk
        // Fire-and-forget: the callback is intentionally not awaited.
        void runManager?.handleLLMNewToken(generationChunk.text ?? '', newTokenIndices, undefined, undefined, undefined, {
            chunk: generationChunk
        })
    }
    // The abort signal is only inspected after the stream has ended.
    if (options?.signal?.aborted) {
        throw new Error('AbortError')
    }
}
/**
 * Turns one streamed delta into the matching LangChain message chunk.
 *
 * The role is taken from the delta when present; otherwise it is 'tool' if the
 * delta carries tool calls and 'assistant' if it does not.
 *
 * @param delta - Streamed fragment with optional role, content and tool calls.
 * @returns A message chunk, or `null` when the delta has neither content nor tool calls.
 */
_convertDeltaToMessageChunk(delta: {
    role?: string | undefined
    content?: string | undefined
    tool_calls?: MistralAIToolCalls[] | undefined
}) {
    const carriesPayload = Boolean(delta.content) || Boolean(delta.tool_calls)
    if (!carriesPayload) {
        return null
    }

    // Each tool call gets its array position as an `index` key (as OpenAI's
    // tool calls have), because the utility that merges additional kwargs
    // throws when that key is missing.
    const rawToolCalls = delta.tool_calls ?? []
    const indexedToolCalls = rawToolCalls.length ? rawToolCalls.map((call, position) => ({ ...call, index: position })) : undefined

    const role = delta.role ? delta.role : indexedToolCalls ? 'tool' : 'assistant'
    const content = delta.content ?? ''
    const additional_kwargs = indexedToolCalls ? { tool_calls: indexedToolCalls } : {}

    switch (role) {
        case 'user':
            return new HumanMessageChunk({ content })
        case 'assistant':
            return new AIMessageChunk({ content, additional_kwargs })
        case 'tool':
            return new ToolMessageChunk({
                content,
                additional_kwargs,
                tool_call_id: indexedToolCalls?.[0].id ?? ''
            })
        default:
            return new ChatMessageChunk({ content, role })
    }
}
/**
 * Maps LangChain messages onto the message objects sent to the Mistral API.
 *
 * Role mapping: human -> user, ai -> assistant, system -> system.
 * Two types are re-mapped after the lookup:
 *  - 'tool' messages are sent as 'assistant' messages carrying the
 *    `tool_calls` found in `additional_kwargs` (or an empty list);
 *  - 'function' messages are sent as 'tool' messages carrying the message name.
 *
 * @param messages - LangChain messages to convert.
 * @returns One converted message per input message, in the same order.
 * @throws Error for an unknown message type or for non-string message content.
 */
convertMessagesToMistralMessages(messages: Array<BaseMessage>): Array<MistralAIInputMessage> {
    const roleByType = new Map<string, string>([
        ['human', 'user'],
        ['ai', 'assistant'],
        ['tool', 'tool'],
        ['function', 'function'],
        ['system', 'system']
    ])

    const resolveRole = (role: MessageType): string => {
        const mapped = roleByType.get(role)
        if (mapped === undefined) {
            throw new Error(`Unknown message type: ${role}`)
        }
        return mapped
    }

    const requireText = (content: MessageContent): string => {
        if (typeof content !== 'string') {
            throw new Error(`ChatMistralAI does not support non text message content. Received: ${JSON.stringify(content, null, 2)}`)
        }
        return content
    }

    const converted: MistralAIInputMessage[] = []
    for (const msg of messages) {
        // Role is resolved (and validated) before the content, once per message.
        const resolvedRole = resolveRole(msg._getType())
        const entry: MistralAIInputMessage = {
            role: resolvedRole,
            content: requireText(msg.content)
        }
        switch (resolvedRole) {
            case 'tool':
                entry.role = 'assistant'
                entry.tool_calls = msg.additional_kwargs?.tool_calls ?? []
                break
            case 'function':
                entry.role = 'tool'
                entry.name = msg.name
                break
        }
        converted.push(entry)
    }
    return converted
}
/**
 * Converts one non-streaming completion choice into a LangChain message.
 *
 * 'assistant' messages become an `AIMessage` whose `additional_kwargs.tool_calls`
 * is always an array (empty when the response carries no tool calls); any other
 * role becomes a `HumanMessage` with the message content.
 *
 * @param choice - A single entry of the completion response's `choices`.
 * @returns The equivalent LangChain message.
 */
mistralAIResponseToChatMessage(choice: ChatCompletionResponse['choices'][0]): BaseMessage {
    const { message } = choice
    // MistralAI SDK does not include tool_calls in the non
    // streaming return type, so we need to extract it like this
    // to satisfy typescript.
    // The `in` check alone is also true when the key is present with a
    // null/undefined value, which would leak a non-array into
    // `additional_kwargs.tool_calls`; only accept a real array and otherwise
    // keep the empty-list default.
    let toolCalls: MistralAIToolCalls[] = []
    if ('tool_calls' in message && Array.isArray(message.tool_calls)) {
        toolCalls = message.tool_calls as MistralAIToolCalls[]
    }
    switch (message.role) {
        case 'assistant':
            return new AIMessage({
                content: message.content ?? '',
                additional_kwargs: {
                    tool_calls: toolCalls
                }
            })
        default:
            return new HumanMessage(message.content ?? '')
    }
}
}
// CommonJS export exposing the node class under the `nodeClass` key.
// NOTE(review): presumably this is the shape the node loader expects — confirm against the loader.
module.exports = { nodeClass: ChatMistral_ChatModels }