Fix HuggingFaceModelProvider

This commit is contained in:
Yang Luo 2023-09-08 00:51:24 +08:00
parent c65baeba0b
commit 3a27f56531
6 changed files with 41 additions and 74 deletions

View File

@ -1,3 +1,17 @@
// Copyright 2023 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ai package ai
import ( import (
@ -11,45 +25,29 @@ import (
) )
type HuggingFaceModelProvider struct { type HuggingFaceModelProvider struct {
SubType string
SecretKey string SecretKey string
} }
func NewHuggingFaceModelProvider(secretKey string) (*HuggingFaceModelProvider, error) { func NewHuggingFaceModelProvider(subType string, secretKey string) (*HuggingFaceModelProvider, error) {
p := &HuggingFaceModelProvider{ p := &HuggingFaceModelProvider{
SubType: subType,
SecretKey: secretKey, SecretKey: secretKey,
} }
return p, nil return p, nil
} }
func (p *HuggingFaceModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { func (p *HuggingFaceModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
resp, err := getHuggingFaceResp(question, p.SecretKey) client := huggingface.New(p.SubType, 1, false).WithToken(p.SecretKey).WithHTTPClient(proxy.ProxyHttpClient).WithMode(huggingface.HuggingFaceModeTextGeneration)
if err != nil {
return err
}
fmt.Println(resp)
return nil
}
func getHuggingFaceResp(question string, secretKey string) (string, error) {
client := huggingface.New("gpt2", 1, false).WithToken(secretKey).WithHTTPClient(proxy.ProxyHttpClient).WithMode(huggingface.HuggingFaceModeTextGeneration)
ctx := context.Background() ctx := context.Background()
resp, err := client.Completion(ctx, question) resp, err := client.Completion(ctx, question)
if err != nil { if err != nil {
return "", err return err
} }
return fixHuggingFaceResp(resp), nil
}
func GetHuggingFaceResp(question string, secretKey string) (string, error) {
return getHuggingFaceResp(question, secretKey)
}
func fixHuggingFaceResp(resp string) string {
resp = strings.Split(resp, "\n")[0] resp = strings.Split(resp, "\n")[0]
return resp fmt.Println(resp)
return nil
} }

View File

@ -1,49 +0,0 @@
package ai_test
import (
"testing"
"github.com/casbin/casibase/ai"
"github.com/casbin/casibase/proxy"
)
func TestGetHuggingFaceResp(t *testing.T) {
proxy.InitHttpClient()
prompt := "Hello AI. Who are you?"
secretKey := "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"
resp, err := ai.GetHuggingFaceResp(prompt, secretKey)
if err != nil {
t.Errorf("GetHuggingFaceResp err: %v", err)
return
}
if resp == "" {
t.Error("GetHuggingFaceResp err: resp is nil")
return
}
t.Logf("resp: %v", resp)
}
func TestHuggingFaceModelProvider_QueryText(t *testing.T) {
proxy.InitHttpClient()
prompt := "Hello AI. Who are you?"
secretKey := "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"
p, err := ai.NewHuggingFaceModelProvider(secretKey)
if err != nil {
t.Errorf("NewHuggingFaceModelProvider err: %v", err)
return
}
err = p.QueryText(prompt, nil, nil)
if err != nil {
t.Errorf("QueryText err: %v", err)
return
}
t.Logf("QueryText success")
}

View File

@ -30,6 +30,12 @@ func GetModelProvider(typ string, subType string, secretKey string) (ModelProvid
return nil, err return nil, err
} }
return p, nil return p, nil
} else if typ == "Hugging Face" {
p, err := NewHuggingFaceModelProvider(subType, secretKey)
if err != nil {
return nil, err
}
return p, nil
} }
return nil, nil return nil, nil

View File

@ -118,6 +118,7 @@ class ProviderEditPage extends React.Component {
{ {
[ [
{id: "OpenAI API", name: "OpenAI API"}, {id: "OpenAI API", name: "OpenAI API"},
{id: "Hugging Face", name: "Hugging Face"},
].map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>) ].map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>)
} }
</Select> </Select>

View File

@ -124,7 +124,7 @@ class ProviderListPage extends React.Component {
title: i18next.t("provider:Category"), title: i18next.t("provider:Category"),
dataIndex: "category", dataIndex: "category",
key: "category", key: "category",
width: "160px", width: "140px",
sorter: (a, b) => a.category.localeCompare(b.category), sorter: (a, b) => a.category.localeCompare(b.category),
}, },
{ {
@ -138,7 +138,7 @@ class ProviderListPage extends React.Component {
title: i18next.t("provider:Sub type"), title: i18next.t("provider:Sub type"),
dataIndex: "subType", dataIndex: "subType",
key: "subType", key: "subType",
width: "160px", width: "200px",
sorter: (a, b) => a.subType.localeCompare(b.subType), sorter: (a, b) => a.subType.localeCompare(b.subType),
}, },
{ {

View File

@ -682,6 +682,17 @@ export function getProviderSubTypeOptions(type) {
{id: "babbage", name: "babbage"}, {id: "babbage", name: "babbage"},
] ]
); );
} else if (type === "Hugging Face") {
return (
[
{id: "meta-llama/Llama-2-7b", name: "meta-llama/Llama-2-7b"},
{id: "tiiuae/falcon-180B", name: "tiiuae/falcon-180B"},
{id: "bigscience/bloom", name: "bigscience/bloom"},
{id: "gpt2", name: "gpt2"},
{id: "baichuan-inc/Baichuan2-13B-Chat", name: "baichuan-inc/Baichuan2-13B-Chat"},
{id: "THUDM/chatglm2-6b", name: "THUDM/chatglm2-6b"},
]
);
} else { } else {
return []; return [];
} }