add endpoint to HF

This commit is contained in:
Henry
2023-07-07 17:36:23 +01:00
parent 4dd43fb2c4
commit 0923a35683
8 changed files with 316 additions and 4 deletions
@@ -1,6 +1,6 @@
import { INode, INodeData, INodeParams } from '../../../src/Interface'
import { getBaseClasses } from '../../../src/utils'
import { HFInput, HuggingFaceInference } from 'langchain/llms/hf'
import { HFInput, HuggingFaceInference } from './core'
class ChatHuggingFace_ChatModels implements INode {
label: string
@@ -71,6 +71,15 @@ class ChatHuggingFace_ChatModels implements INode {
description: 'Frequency Penalty parameter may not apply to certain model. Please check available model parameters',
optional: true,
additionalParams: true
},
{
label: 'Endpoint',
name: 'endpoint',
type: 'string',
placeholder: 'https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2',
description: 'Using your own inference endpoint',
optional: true,
additionalParams: true
}
]
}
@@ -83,6 +92,7 @@ class ChatHuggingFace_ChatModels implements INode {
const topP = nodeData.inputs?.topP as string
const hfTopK = nodeData.inputs?.hfTopK as string
const frequencyPenalty = nodeData.inputs?.frequencyPenalty as string
const endpoint = nodeData.inputs?.endpoint as string
const obj: Partial<HFInput> = {
model,
@@ -94,6 +104,7 @@ class ChatHuggingFace_ChatModels implements INode {
if (topP) obj.topP = parseInt(topP, 10)
if (hfTopK) obj.topK = parseInt(hfTopK, 10)
if (frequencyPenalty) obj.frequencyPenalty = parseInt(frequencyPenalty, 10)
if (endpoint) obj.endpoint = endpoint
const huggingFace = new HuggingFaceInference(obj)
return huggingFace