Fix HuggingFaceModelProvider
This commit is contained in:
parent
c65baeba0b
commit
3a27f56531
|
@ -1,3 +1,17 @@
|
|||
// Copyright 2023 The casbin Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ai
|
||||
|
||||
import (
|
||||
|
@ -11,45 +25,29 @@ import (
|
|||
)
|
||||
|
||||
type HuggingFaceModelProvider struct {
|
||||
SubType string
|
||||
SecretKey string
|
||||
}
|
||||
|
||||
func NewHuggingFaceModelProvider(secretKey string) (*HuggingFaceModelProvider, error) {
|
||||
func NewHuggingFaceModelProvider(subType string, secretKey string) (*HuggingFaceModelProvider, error) {
|
||||
p := &HuggingFaceModelProvider{
|
||||
SubType: subType,
|
||||
SecretKey: secretKey,
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *HuggingFaceModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
|
||||
resp, err := getHuggingFaceResp(question, p.SecretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println(resp)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHuggingFaceResp(question string, secretKey string) (string, error) {
|
||||
client := huggingface.New("gpt2", 1, false).WithToken(secretKey).WithHTTPClient(proxy.ProxyHttpClient).WithMode(huggingface.HuggingFaceModeTextGeneration)
|
||||
client := huggingface.New(p.SubType, 1, false).WithToken(p.SecretKey).WithHTTPClient(proxy.ProxyHttpClient).WithMode(huggingface.HuggingFaceModeTextGeneration)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := client.Completion(ctx, question)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return err
|
||||
}
|
||||
|
||||
return fixHuggingFaceResp(resp), nil
|
||||
}
|
||||
|
||||
func GetHuggingFaceResp(question string, secretKey string) (string, error) {
|
||||
return getHuggingFaceResp(question, secretKey)
|
||||
}
|
||||
|
||||
func fixHuggingFaceResp(resp string) string {
|
||||
resp = strings.Split(resp, "\n")[0]
|
||||
return resp
|
||||
fmt.Println(resp)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,49 +0,0 @@
|
|||
package ai_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/casbin/casibase/ai"
|
||||
"github.com/casbin/casibase/proxy"
|
||||
)
|
||||
|
||||
func TestGetHuggingFaceResp(t *testing.T) {
|
||||
proxy.InitHttpClient()
|
||||
|
||||
prompt := "Hello AI. Who are you?"
|
||||
secretKey := "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"
|
||||
|
||||
resp, err := ai.GetHuggingFaceResp(prompt, secretKey)
|
||||
if err != nil {
|
||||
t.Errorf("GetHuggingFaceResp err: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if resp == "" {
|
||||
t.Error("GetHuggingFaceResp err: resp is nil")
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("resp: %v", resp)
|
||||
}
|
||||
|
||||
func TestHuggingFaceModelProvider_QueryText(t *testing.T) {
|
||||
proxy.InitHttpClient()
|
||||
|
||||
prompt := "Hello AI. Who are you?"
|
||||
secretKey := "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"
|
||||
|
||||
p, err := ai.NewHuggingFaceModelProvider(secretKey)
|
||||
if err != nil {
|
||||
t.Errorf("NewHuggingFaceModelProvider err: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = p.QueryText(prompt, nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("QueryText err: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("QueryText success")
|
||||
}
|
|
@ -30,6 +30,12 @@ func GetModelProvider(typ string, subType string, secretKey string) (ModelProvid
|
|||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
} else if typ == "Hugging Face" {
|
||||
p, err := NewHuggingFaceModelProvider(subType, secretKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
|
|
|
@ -118,6 +118,7 @@ class ProviderEditPage extends React.Component {
|
|||
{
|
||||
[
|
||||
{id: "OpenAI API", name: "OpenAI API"},
|
||||
{id: "Hugging Face", name: "Hugging Face"},
|
||||
].map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>)
|
||||
}
|
||||
</Select>
|
||||
|
|
|
@ -124,7 +124,7 @@ class ProviderListPage extends React.Component {
|
|||
title: i18next.t("provider:Category"),
|
||||
dataIndex: "category",
|
||||
key: "category",
|
||||
width: "160px",
|
||||
width: "140px",
|
||||
sorter: (a, b) => a.category.localeCompare(b.category),
|
||||
},
|
||||
{
|
||||
|
@ -138,7 +138,7 @@ class ProviderListPage extends React.Component {
|
|||
title: i18next.t("provider:Sub type"),
|
||||
dataIndex: "subType",
|
||||
key: "subType",
|
||||
width: "160px",
|
||||
width: "200px",
|
||||
sorter: (a, b) => a.subType.localeCompare(b.subType),
|
||||
},
|
||||
{
|
||||
|
|
|
@ -682,6 +682,17 @@ export function getProviderSubTypeOptions(type) {
|
|||
{id: "babbage", name: "babbage"},
|
||||
]
|
||||
);
|
||||
} else if (type === "Hugging Face") {
|
||||
return (
|
||||
[
|
||||
{id: "meta-llama/Llama-2-7b", name: "meta-llama/Llama-2-7b"},
|
||||
{id: "tiiuae/falcon-180B", name: "tiiuae/falcon-180B"},
|
||||
{id: "bigscience/bloom", name: "bigscience/bloom"},
|
||||
{id: "gpt2", name: "gpt2"},
|
||||
{id: "baichuan-inc/Baichuan2-13B-Chat", name: "baichuan-inc/Baichuan2-13B-Chat"},
|
||||
{id: "THUDM/chatglm2-6b", name: "THUDM/chatglm2-6b"},
|
||||
]
|
||||
);
|
||||
} else {
|
||||
return [];
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue