feat: add api support of HuggingFace (#620)
This commit is contained in:
parent
ec9bcb6d78
commit
6edd86276b
|
@ -0,0 +1,55 @@
|
|||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/casbin/casibase/proxy"
|
||||
"github.com/henomis/lingoose/llm/huggingface"
|
||||
)
|
||||
|
||||
type HuggingFaceModelProvider struct {
|
||||
SecretKey string
|
||||
}
|
||||
|
||||
func NewHuggingFaceModelProvider(secretKey string) (*HuggingFaceModelProvider, error) {
|
||||
p := &HuggingFaceModelProvider{
|
||||
SecretKey: secretKey,
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *HuggingFaceModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
|
||||
resp, err := getHuggingFaceResp(question, p.SecretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println(resp)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHuggingFaceResp(question string, secretKey string) (string, error) {
|
||||
client := huggingface.New("gpt2", 1, false).WithToken(secretKey).WithHTTPClient(proxy.ProxyHttpClient).WithMode(huggingface.HuggingFaceModeTextGeneration)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := client.Completion(ctx, question)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return fixHuggingFaceResp(resp), nil
|
||||
}
|
||||
|
||||
func GetHuggingFaceResp(question string, secretKey string) (string, error) {
|
||||
return getHuggingFaceResp(question, secretKey)
|
||||
}
|
||||
|
||||
func fixHuggingFaceResp(resp string) string {
|
||||
resp = strings.Split(resp, "\n")[0]
|
||||
return resp
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
package ai_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/casbin/casibase/ai"
|
||||
"github.com/casbin/casibase/proxy"
|
||||
)
|
||||
|
||||
func TestGetHuggingFaceResp(t *testing.T) {
|
||||
proxy.InitHttpClient()
|
||||
|
||||
prompt := "Hello AI. Who are you?"
|
||||
secretKey := "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"
|
||||
|
||||
resp, err := ai.GetHuggingFaceResp(prompt, secretKey)
|
||||
if err != nil {
|
||||
t.Errorf("GetHuggingFaceResp err: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if resp == "" {
|
||||
t.Error("GetHuggingFaceResp err: resp is nil")
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("resp: %v", resp)
|
||||
}
|
||||
|
||||
func TestHuggingFaceModelProvider_QueryText(t *testing.T) {
|
||||
proxy.InitHttpClient()
|
||||
|
||||
prompt := "Hello AI. Who are you?"
|
||||
secretKey := "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"
|
||||
|
||||
p, err := ai.NewHuggingFaceModelProvider(secretKey)
|
||||
if err != nil {
|
||||
t.Errorf("NewHuggingFaceModelProvider err: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = p.QueryText(prompt, nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("QueryText err: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("QueryText success")
|
||||
}
|
1
go.mod
1
go.mod
|
@ -48,6 +48,7 @@ require (
|
|||
github.com/gomodule/redigo v2.0.0+incompatible // indirect
|
||||
github.com/google/go-cmp v0.5.9 // indirect
|
||||
github.com/hashicorp/golang-lru v0.5.4 // indirect
|
||||
github.com/henomis/lingoose v0.0.11-alpha1 // indirect
|
||||
github.com/jaytaylor/html2text v0.0.0-20200412013138-3577fbdbcff7 // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
|
|
3
go.sum
3
go.sum
|
@ -349,6 +349,8 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO
|
|||
github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ=
|
||||
github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I=
|
||||
github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc=
|
||||
github.com/henomis/lingoose v0.0.11-alpha1 h1:6iXcdewIdTDJCNg7AxZF6onobLEh0BPFyHYTKSV8bAw=
|
||||
github.com/henomis/lingoose v0.0.11-alpha1/go.mod h1:hOfRJswe3sA17uZSUJHJNrBiqPxEt2FM9wUFqFFOHSE=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg=
|
||||
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
|
||||
|
@ -1256,6 +1258,7 @@ modernc.org/token v1.0.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
|||
modernc.org/z v1.0.1/go.mod h1:8/SRk5C/HgiQWCgXdfpb+1RvhORdkz5sw72d3jjtyqA=
|
||||
modernc.org/z v1.5.1 h1:RTNHdsrOpeoSeOF4FbzTo8gBYByaJ5xT7NgZ9ZqRiJM=
|
||||
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
|
||||
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
|
||||
sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o=
|
||||
|
|
Loading…
Reference in New Issue