feat: add api support of HuggingFace (#620)

This commit is contained in:
Kelvin Chiu 2023-09-07 23:17:59 +08:00 committed by GitHub
parent ec9bcb6d78
commit 6edd86276b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 108 additions and 0 deletions

55
ai/huggingface.go Normal file
View File

@ -0,0 +1,55 @@
package ai
import (
"context"
"fmt"
"io"
"strings"
"github.com/casbin/casibase/proxy"
"github.com/henomis/lingoose/llm/huggingface"
)
type HuggingFaceModelProvider struct {
SecretKey string
}
func NewHuggingFaceModelProvider(secretKey string) (*HuggingFaceModelProvider, error) {
p := &HuggingFaceModelProvider{
SecretKey: secretKey,
}
return p, nil
}
func (p *HuggingFaceModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
resp, err := getHuggingFaceResp(question, p.SecretKey)
if err != nil {
return err
}
fmt.Println(resp)
return nil
}
func getHuggingFaceResp(question string, secretKey string) (string, error) {
client := huggingface.New("gpt2", 1, false).WithToken(secretKey).WithHTTPClient(proxy.ProxyHttpClient).WithMode(huggingface.HuggingFaceModeTextGeneration)
ctx := context.Background()
resp, err := client.Completion(ctx, question)
if err != nil {
return "", err
}
return fixHuggingFaceResp(resp), nil
}
func GetHuggingFaceResp(question string, secretKey string) (string, error) {
return getHuggingFaceResp(question, secretKey)
}
func fixHuggingFaceResp(resp string) string {
resp = strings.Split(resp, "\n")[0]
return resp
}

49
ai/huggingface_test.go Normal file
View File

@ -0,0 +1,49 @@
package ai_test
import (
"testing"
"github.com/casbin/casibase/ai"
"github.com/casbin/casibase/proxy"
)
func TestGetHuggingFaceResp(t *testing.T) {
proxy.InitHttpClient()
prompt := "Hello AI. Who are you?"
secretKey := "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"
resp, err := ai.GetHuggingFaceResp(prompt, secretKey)
if err != nil {
t.Errorf("GetHuggingFaceResp err: %v", err)
return
}
if resp == "" {
t.Error("GetHuggingFaceResp err: resp is nil")
return
}
t.Logf("resp: %v", resp)
}
func TestHuggingFaceModelProvider_QueryText(t *testing.T) {
proxy.InitHttpClient()
prompt := "Hello AI. Who are you?"
secretKey := "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"
p, err := ai.NewHuggingFaceModelProvider(secretKey)
if err != nil {
t.Errorf("NewHuggingFaceModelProvider err: %v", err)
return
}
err = p.QueryText(prompt, nil, nil)
if err != nil {
t.Errorf("QueryText err: %v", err)
return
}
t.Logf("QueryText success")
}

1
go.mod
View File

@ -48,6 +48,7 @@ require (
github.com/gomodule/redigo v2.0.0+incompatible // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/hashicorp/golang-lru v0.5.4 // indirect
github.com/henomis/lingoose v0.0.11-alpha1 // indirect
github.com/jaytaylor/html2text v0.0.0-20200412013138-3577fbdbcff7 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect

3
go.sum
View File

@ -349,6 +349,8 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO
github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ=
github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I=
github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc=
github.com/henomis/lingoose v0.0.11-alpha1 h1:6iXcdewIdTDJCNg7AxZF6onobLEh0BPFyHYTKSV8bAw=
github.com/henomis/lingoose v0.0.11-alpha1/go.mod h1:hOfRJswe3sA17uZSUJHJNrBiqPxEt2FM9wUFqFFOHSE=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
@ -1256,6 +1258,7 @@ modernc.org/token v1.0.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
modernc.org/z v1.0.1/go.mod h1:8/SRk5C/HgiQWCgXdfpb+1RvhORdkz5sw72d3jjtyqA=
modernc.org/z v1.5.1 h1:RTNHdsrOpeoSeOF4FbzTo8gBYByaJ5xT7NgZ9ZqRiJM=
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o=