diff --git a/ai/huggingface.go b/ai/huggingface.go index 1d9877e..5bce326 100644 --- a/ai/huggingface.go +++ b/ai/huggingface.go @@ -1,3 +1,17 @@ +// Copyright 2023 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package ai import ( @@ -11,45 +25,29 @@ import ( ) type HuggingFaceModelProvider struct { + SubType string SecretKey string } -func NewHuggingFaceModelProvider(secretKey string) (*HuggingFaceModelProvider, error) { +func NewHuggingFaceModelProvider(subType string, secretKey string) (*HuggingFaceModelProvider, error) { p := &HuggingFaceModelProvider{ + SubType: subType, SecretKey: secretKey, } return p, nil } func (p *HuggingFaceModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { - resp, err := getHuggingFaceResp(question, p.SecretKey) - if err != nil { - return err - } - - fmt.Println(resp) - - return nil -} - -func getHuggingFaceResp(question string, secretKey string) (string, error) { - client := huggingface.New("gpt2", 1, false).WithToken(secretKey).WithHTTPClient(proxy.ProxyHttpClient).WithMode(huggingface.HuggingFaceModeTextGeneration) + client := huggingface.New(p.SubType, 1, false).WithToken(p.SecretKey).WithHTTPClient(proxy.ProxyHttpClient).WithMode(huggingface.HuggingFaceModeTextGeneration) ctx := context.Background() resp, err := client.Completion(ctx, question) if err != nil { - return "", err + return err } - return fixHuggingFaceResp(resp), nil -} - -func GetHuggingFaceResp(question string, secretKey string) (string, error) { - return getHuggingFaceResp(question, secretKey) -} - -func fixHuggingFaceResp(resp string) string { resp = strings.Split(resp, "\n")[0] - return resp + fmt.Println(resp) + return nil } diff --git a/ai/huggingface_test.go b/ai/huggingface_test.go deleted file mode 100644 index bbd083a..0000000 --- a/ai/huggingface_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package ai_test - -import ( - "testing" - - "github.com/casbin/casibase/ai" - "github.com/casbin/casibase/proxy" -) - -func TestGetHuggingFaceResp(t *testing.T) { - proxy.InitHttpClient() - - prompt := "Hello AI. Who are you?" - secretKey := "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" - - resp, err := ai.GetHuggingFaceResp(prompt, secretKey) - if err != nil { - t.Errorf("GetHuggingFaceResp err: %v", err) - return - } - - if resp == "" { - t.Error("GetHuggingFaceResp err: resp is nil") - return - } - - t.Logf("resp: %v", resp) -} - -func TestHuggingFaceModelProvider_QueryText(t *testing.T) { - proxy.InitHttpClient() - - prompt := "Hello AI. Who are you?" - secretKey := "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" - - p, err := ai.NewHuggingFaceModelProvider(secretKey) - if err != nil { - t.Errorf("NewHuggingFaceModelProvider err: %v", err) - return - } - - err = p.QueryText(prompt, nil, nil) - if err != nil { - t.Errorf("QueryText err: %v", err) - return - } - - t.Logf("QueryText success") -} diff --git a/ai/model.go b/ai/model.go index 579c2bf..62f5963 100644 --- a/ai/model.go +++ b/ai/model.go @@ -30,6 +30,12 @@ func GetModelProvider(typ string, subType string, secretKey string) (ModelProvid return nil, err } return p, nil + } else if typ == "Hugging Face" { + p, err := NewHuggingFaceModelProvider(subType, secretKey) + if err != nil { + return nil, err + } + return p, nil } return nil, nil diff --git a/web/src/ProviderEditPage.js b/web/src/ProviderEditPage.js index 7a318af..89eecb4 100644 --- a/web/src/ProviderEditPage.js +++ b/web/src/ProviderEditPage.js @@ -118,6 +118,7 @@ class ProviderEditPage extends React.Component { { [ {id: "OpenAI API", name: "OpenAI API"}, + {id: "Hugging Face", name: "Hugging Face"}, ].map((item, index) => ) } diff --git a/web/src/ProviderListPage.js b/web/src/ProviderListPage.js index 0625df1..dce1477 100644 --- a/web/src/ProviderListPage.js +++ b/web/src/ProviderListPage.js @@ -124,7 +124,7 @@ class ProviderListPage extends React.Component { title: i18next.t("provider:Category"), dataIndex: "category", key: "category", - width: "160px", + width: "140px", sorter: (a, b) => a.category.localeCompare(b.category), }, { @@ -138,7 +138,7 @@ class ProviderListPage extends React.Component { title: i18next.t("provider:Sub type"), dataIndex: "subType", key: "subType", - width: "160px", + width: "200px", sorter: (a, b) => a.subType.localeCompare(b.subType), }, { diff --git a/web/src/Setting.js b/web/src/Setting.js index f78ae92..def6854 100644 --- a/web/src/Setting.js +++ b/web/src/Setting.js @@ -682,6 +682,17 @@ export function getProviderSubTypeOptions(type) { {id: "babbage", name: "babbage"}, ] ); + } else if (type === "Hugging Face") { + return ( + [ + {id: "meta-llama/Llama-2-7b", name: "meta-llama/Llama-2-7b"}, + {id: "tiiuae/falcon-180B", name: "tiiuae/falcon-180B"}, + {id: "bigscience/bloom", name: "bigscience/bloom"}, + {id: "gpt2", name: "gpt2"}, + {id: "baichuan-inc/Baichuan2-13B-Chat", name: "baichuan-inc/Baichuan2-13B-Chat"}, + {id: "THUDM/chatglm2-6b", name: "THUDM/chatglm2-6b"}, + ] + ); } else { return []; }