feat: add more embedding providers (#632)
This commit is contained in:
parent
6d84245786
commit
d62d609456
|
@ -0,0 +1,40 @@
|
|||
// Copyright 2023 The casbin Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
cohereembedder "github.com/henomis/lingoose/embedder/cohere"
|
||||
)
|
||||
|
||||
// CohereEmbeddingProvider carries the configuration used to request text
// embeddings from Cohere: subType is handed to the embedder as the model
// name and secretKey as the API key.
type CohereEmbeddingProvider struct {
	subType, secretKey string
}
|
||||
|
||||
func NewCohereEmbeddingProvider(subType string, secretKey string) (*CohereEmbeddingProvider, error) {
|
||||
return &CohereEmbeddingProvider{subType: subType, secretKey: secretKey}, nil
|
||||
}
|
||||
|
||||
func (c *CohereEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) {
|
||||
client := cohereembedder.New().WithModel(cohereembedder.EmbedderModel(c.subType)).WithAPIKey(c.secretKey)
|
||||
embed, err := client.Embed(ctx, []string{text})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return float64ToFloat32(embed[0]), nil
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
// Copyright 2023 The casbin Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
ernie "github.com/anhao/go-ernie"
|
||||
)
|
||||
|
||||
// ErnieEmbeddingProvider carries the configuration used to request text
// embeddings from Ernie: apiKey and secretKey are the credentials handed to
// the Ernie client, and subType is the selected sub-type stored alongside
// them.
type ErnieEmbeddingProvider struct {
	subType, apiKey, secretKey string
}
|
||||
|
||||
func NewErnieEmbeddingProvider(subType string, apiKey string, secretKey string) (*ErnieEmbeddingProvider, error) {
|
||||
return &ErnieEmbeddingProvider{subType: subType, apiKey: apiKey, secretKey: secretKey}, nil
|
||||
}
|
||||
|
||||
func (e *ErnieEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) {
|
||||
client := ernie.NewDefaultClient(e.apiKey, e.secretKey)
|
||||
request := ernie.EmbeddingRequest{Input: []string{text}}
|
||||
embeddings, err := client.CreateEmbeddings(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return float64ToFloat32(embeddings.Data[0].Embedding), nil
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
// Copyright 2023 The casbin Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/casbin/casibase/proxy"
|
||||
"github.com/henomis/lingoose/embedder/huggingface"
|
||||
)
|
||||
|
||||
// HuggingFaceEmbeddingProvider carries the configuration used to request
// text embeddings from Hugging Face: subType is handed to the embedder as the
// model name and secretKey as the access token.
type HuggingFaceEmbeddingProvider struct {
	subType, secretKey string
}
|
||||
|
||||
func NewHuggingFaceEmbeddingProvider(subType string, secretKey string) (*HuggingFaceEmbeddingProvider, error) {
|
||||
return &HuggingFaceEmbeddingProvider{subType: subType, secretKey: secretKey}, nil
|
||||
}
|
||||
|
||||
func (h *HuggingFaceEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) {
|
||||
client := huggingfaceembedder.New().WithToken(h.secretKey).WithModel(h.subType).WithHTTPClient(proxy.ProxyHttpClient)
|
||||
embed, err := client.Embed(ctx, []string{text})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return float64ToFloat32(embed[0]), nil
|
||||
}
|
||||
|
||||
// float64ToFloat32 returns a new slice holding every element of slice
// narrowed to float32, in the same order. The result is always non-nil and
// has the same length as the input, even when the input is nil or empty.
func float64ToFloat32(slice []float64) []float32 {
	converted := make([]float32, 0, len(slice))
	for _, value := range slice {
		converted = append(converted, float32(value))
	}
	return converted
}
|
|
@ -22,11 +22,17 @@ type EmbeddingProvider interface {
|
|||
QueryVector(text string, ctx context.Context) ([]float32, error)
|
||||
}
|
||||
|
||||
func GetEmbeddingProvider(typ string, subType string, clientSecret string) (EmbeddingProvider, error) {
|
||||
func GetEmbeddingProvider(typ string, subType string, clientId string, clientSecret string) (EmbeddingProvider, error) {
|
||||
var p EmbeddingProvider
|
||||
var err error
|
||||
if typ == "OpenAI" {
|
||||
p, err = NewOpenAiEmbeddingProvider(subType, clientSecret)
|
||||
} else if typ == "Hugging Face" {
|
||||
p, err = NewHuggingFaceEmbeddingProvider(subType, clientSecret)
|
||||
} else if typ == "Cohere" {
|
||||
p, err = NewCohereEmbeddingProvider(subType, clientSecret)
|
||||
} else if typ == "Ernie" {
|
||||
p, err = NewErnieEmbeddingProvider(subType, clientId, clientSecret)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
|
2
go.mod
2
go.mod
|
@ -44,6 +44,8 @@ require (
|
|||
github.com/gomodule/redigo v2.0.0+incompatible // indirect
|
||||
github.com/google/go-cmp v0.5.9 // indirect
|
||||
github.com/hashicorp/golang-lru v0.5.4 // indirect
|
||||
github.com/henomis/cohere-go v1.0.1 // indirect
|
||||
github.com/henomis/restclientgo v1.0.5 // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
|
||||
|
|
4
go.sum
4
go.sum
|
@ -278,8 +278,12 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO
|
|||
github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ=
|
||||
github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I=
|
||||
github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc=
|
||||
github.com/henomis/cohere-go v1.0.1 h1:a47gIN29tqAl4yBTAT+BzQMjsWG94Fz07u9AE4Md+a8=
|
||||
github.com/henomis/cohere-go v1.0.1/go.mod h1:F6D33jlWle6pbGdf9Fm2bteaOOQOO1cQtFnlFfj+ZXY=
|
||||
github.com/henomis/lingoose v0.0.11-alpha1 h1:6iXcdewIdTDJCNg7AxZF6onobLEh0BPFyHYTKSV8bAw=
|
||||
github.com/henomis/lingoose v0.0.11-alpha1/go.mod h1:hOfRJswe3sA17uZSUJHJNrBiqPxEt2FM9wUFqFFOHSE=
|
||||
github.com/henomis/restclientgo v1.0.5 h1:xMuznJLagE8nGrmFPyBkzsDztJm2A7uMLNGMBY5iWSg=
|
||||
github.com/henomis/restclientgo v1.0.5/go.mod h1:xIeTCu2ZstvRn0fCukNpzXLN3m/kRTU0i0RwAbv7Zug=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg=
|
||||
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
|
||||
|
|
|
@ -190,7 +190,7 @@ func (p *Provider) GetModelProvider() (model.ModelProvider, error) {
|
|||
}
|
||||
|
||||
func (p *Provider) GetEmbeddingProvider() (embedding.EmbeddingProvider, error) {
|
||||
pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientSecret)
|
||||
pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -638,6 +638,9 @@ export function getProviderTypeOptions(category) {
|
|||
return (
|
||||
[
|
||||
{id: "OpenAI", name: "OpenAI"},
|
||||
{id: "Hugging Face", name: "Hugging Face"},
|
||||
{id: "Cohere", name: "Cohere"},
|
||||
{id: "Ernie", name: "Ernie"},
|
||||
]
|
||||
);
|
||||
} else {
|
||||
|
@ -701,16 +704,26 @@ export function getProviderSubTypeOptions(category, type) {
|
|||
return [];
|
||||
}
|
||||
} else if (type === "Hugging Face") {
|
||||
return (
|
||||
[
|
||||
{id: "meta-llama/Llama-2-7b", name: "meta-llama/Llama-2-7b"},
|
||||
{id: "tiiuae/falcon-180B", name: "tiiuae/falcon-180B"},
|
||||
{id: "bigscience/bloom", name: "bigscience/bloom"},
|
||||
{id: "gpt2", name: "gpt2"},
|
||||
{id: "baichuan-inc/Baichuan2-13B-Chat", name: "baichuan-inc/Baichuan2-13B-Chat"},
|
||||
{id: "THUDM/chatglm2-6b", name: "THUDM/chatglm2-6b"},
|
||||
]
|
||||
);
|
||||
if (category === "Model") {
|
||||
return (
|
||||
[
|
||||
{id: "meta-llama/Llama-2-7b", name: "meta-llama/Llama-2-7b"},
|
||||
{id: "tiiuae/falcon-180B", name: "tiiuae/falcon-180B"},
|
||||
{id: "bigscience/bloom", name: "bigscience/bloom"},
|
||||
{id: "gpt2", name: "gpt2"},
|
||||
{id: "baichuan-inc/Baichuan2-13B-Chat", name: "baichuan-inc/Baichuan2-13B-Chat"},
|
||||
{id: "THUDM/chatglm2-6b", name: "THUDM/chatglm2-6b"},
|
||||
]
|
||||
);
|
||||
} else if (category === "Embedding") {
|
||||
return (
|
||||
[
|
||||
{id: "sentence-transformers/all-MiniLM-L6-v2", name: "sentence-transformers/all-MiniLM-L6-v2"},
|
||||
]
|
||||
);
|
||||
} else {
|
||||
return [];
|
||||
}
|
||||
} else if (type === "OpenRouter") {
|
||||
return (
|
||||
[
|
||||
|
@ -737,12 +750,30 @@ export function getProviderSubTypeOptions(category, type) {
|
|||
]
|
||||
);
|
||||
} else if (type === "Ernie") {
|
||||
if (category === "Model") {
|
||||
return (
|
||||
[
|
||||
{id: "ERNIE-Bot", name: "ERNIE-Bot"},
|
||||
{id: "ERNIE-Bot-turbo", name: "ERNIE-Bot-turbo"},
|
||||
{id: "BLOOMZ-7B", name: "BLOOMZ-7B"},
|
||||
{id: "Llama-2", name: "Llama-2"},
|
||||
]
|
||||
);
|
||||
} else if (category === "Embedding") {
|
||||
return (
|
||||
[
|
||||
{id: "default", name: "default"},
|
||||
]
|
||||
);
|
||||
} else {
|
||||
return [];
|
||||
}
|
||||
} else if (type === "Cohere") {
|
||||
return (
|
||||
[
|
||||
{id: "ERNIE-Bot", name: "ERNIE-Bot"},
|
||||
{id: "ERNIE-Bot-turbo", name: "ERNIE-Bot-turbo"},
|
||||
{id: "BLOOMZ-7B", name: "BLOOMZ-7B"},
|
||||
{id: "Llama-2", name: "Llama-2"},
|
||||
{id: "embed-english-v2.0", name: "embed-english-v2.0"},
|
||||
{id: "embed-english-light-v2.0", name: "embed-english-light-v2.0"},
|
||||
{id: "embed-multilingual-v2.0", name: "embed-multilingual-v2.0"},
|
||||
]
|
||||
);
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue