Fix HuggingFaceModelProvider
This commit is contained in:
parent
c65baeba0b
commit
3a27f56531
|
@ -1,3 +1,17 @@
|
||||||
|
// Copyright 2023 The casbin Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
package ai
|
package ai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -11,45 +25,29 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type HuggingFaceModelProvider struct {
|
type HuggingFaceModelProvider struct {
|
||||||
|
SubType string
|
||||||
SecretKey string
|
SecretKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHuggingFaceModelProvider(secretKey string) (*HuggingFaceModelProvider, error) {
|
func NewHuggingFaceModelProvider(subType string, secretKey string) (*HuggingFaceModelProvider, error) {
|
||||||
p := &HuggingFaceModelProvider{
|
p := &HuggingFaceModelProvider{
|
||||||
|
SubType: subType,
|
||||||
SecretKey: secretKey,
|
SecretKey: secretKey,
|
||||||
}
|
}
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *HuggingFaceModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
|
func (p *HuggingFaceModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
|
||||||
resp, err := getHuggingFaceResp(question, p.SecretKey)
|
client := huggingface.New(p.SubType, 1, false).WithToken(p.SecretKey).WithHTTPClient(proxy.ProxyHttpClient).WithMode(huggingface.HuggingFaceModeTextGeneration)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println(resp)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getHuggingFaceResp(question string, secretKey string) (string, error) {
|
|
||||||
client := huggingface.New("gpt2", 1, false).WithToken(secretKey).WithHTTPClient(proxy.ProxyHttpClient).WithMode(huggingface.HuggingFaceModeTextGeneration)
|
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
resp, err := client.Completion(ctx, question)
|
resp, err := client.Completion(ctx, question)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return fixHuggingFaceResp(resp), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetHuggingFaceResp(question string, secretKey string) (string, error) {
|
|
||||||
return getHuggingFaceResp(question, secretKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func fixHuggingFaceResp(resp string) string {
|
|
||||||
resp = strings.Split(resp, "\n")[0]
|
resp = strings.Split(resp, "\n")[0]
|
||||||
return resp
|
fmt.Println(resp)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,49 +0,0 @@
|
||||||
package ai_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/casbin/casibase/ai"
|
|
||||||
"github.com/casbin/casibase/proxy"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetHuggingFaceResp(t *testing.T) {
|
|
||||||
proxy.InitHttpClient()
|
|
||||||
|
|
||||||
prompt := "Hello AI. Who are you?"
|
|
||||||
secretKey := "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"
|
|
||||||
|
|
||||||
resp, err := ai.GetHuggingFaceResp(prompt, secretKey)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("GetHuggingFaceResp err: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp == "" {
|
|
||||||
t.Error("GetHuggingFaceResp err: resp is nil")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Logf("resp: %v", resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHuggingFaceModelProvider_QueryText(t *testing.T) {
|
|
||||||
proxy.InitHttpClient()
|
|
||||||
|
|
||||||
prompt := "Hello AI. Who are you?"
|
|
||||||
secretKey := "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"
|
|
||||||
|
|
||||||
p, err := ai.NewHuggingFaceModelProvider(secretKey)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("NewHuggingFaceModelProvider err: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = p.QueryText(prompt, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("QueryText err: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Logf("QueryText success")
|
|
||||||
}
|
|
|
@ -30,6 +30,12 @@ func GetModelProvider(typ string, subType string, secretKey string) (ModelProvid
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return p, nil
|
return p, nil
|
||||||
|
} else if typ == "Hugging Face" {
|
||||||
|
p, err := NewHuggingFaceModelProvider(subType, secretKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
|
|
@ -118,6 +118,7 @@ class ProviderEditPage extends React.Component {
|
||||||
{
|
{
|
||||||
[
|
[
|
||||||
{id: "OpenAI API", name: "OpenAI API"},
|
{id: "OpenAI API", name: "OpenAI API"},
|
||||||
|
{id: "Hugging Face", name: "Hugging Face"},
|
||||||
].map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>)
|
].map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>)
|
||||||
}
|
}
|
||||||
</Select>
|
</Select>
|
||||||
|
|
|
@ -124,7 +124,7 @@ class ProviderListPage extends React.Component {
|
||||||
title: i18next.t("provider:Category"),
|
title: i18next.t("provider:Category"),
|
||||||
dataIndex: "category",
|
dataIndex: "category",
|
||||||
key: "category",
|
key: "category",
|
||||||
width: "160px",
|
width: "140px",
|
||||||
sorter: (a, b) => a.category.localeCompare(b.category),
|
sorter: (a, b) => a.category.localeCompare(b.category),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -138,7 +138,7 @@ class ProviderListPage extends React.Component {
|
||||||
title: i18next.t("provider:Sub type"),
|
title: i18next.t("provider:Sub type"),
|
||||||
dataIndex: "subType",
|
dataIndex: "subType",
|
||||||
key: "subType",
|
key: "subType",
|
||||||
width: "160px",
|
width: "200px",
|
||||||
sorter: (a, b) => a.subType.localeCompare(b.subType),
|
sorter: (a, b) => a.subType.localeCompare(b.subType),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -682,6 +682,17 @@ export function getProviderSubTypeOptions(type) {
|
||||||
{id: "babbage", name: "babbage"},
|
{id: "babbage", name: "babbage"},
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
} else if (type === "Hugging Face") {
|
||||||
|
return (
|
||||||
|
[
|
||||||
|
{id: "meta-llama/Llama-2-7b", name: "meta-llama/Llama-2-7b"},
|
||||||
|
{id: "tiiuae/falcon-180B", name: "tiiuae/falcon-180B"},
|
||||||
|
{id: "bigscience/bloom", name: "bigscience/bloom"},
|
||||||
|
{id: "gpt2", name: "gpt2"},
|
||||||
|
{id: "baichuan-inc/Baichuan2-13B-Chat", name: "baichuan-inc/Baichuan2-13B-Chat"},
|
||||||
|
{id: "THUDM/chatglm2-6b", name: "THUDM/chatglm2-6b"},
|
||||||
|
]
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue