Add queryVectorSafe()
This commit is contained in:
parent
405eceb738
commit
6d7921a669
|
@ -16,7 +16,6 @@ package embedding
|
|||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/casbin/casibase/proxy"
|
||||
"github.com/casbin/casibase/util"
|
||||
|
@ -40,12 +39,9 @@ func getProxyClientFromToken(authToken string) *openai.Client {
|
|||
return c
|
||||
}
|
||||
|
||||
func (p *OpenAiEmbeddingProvider) QueryVector(text string, timeout int) ([]float32, error) {
|
||||
func (p *OpenAiEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) {
|
||||
client := getProxyClientFromToken(p.secretKey)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
|
||||
Input: []string{text},
|
||||
Model: openai.EmbeddingModel(util.ParseInt(p.subType)),
|
||||
|
|
|
@ -16,14 +16,10 @@ package embedding
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type EmbeddingProvider interface {
|
||||
QueryVector(text string, timeout int) ([]float32, error)
|
||||
QueryVector(text string, ctx context.Context) ([]float32, error)
|
||||
}
|
||||
|
||||
func GetEmbeddingProvider(typ string, subType string, clientSecret string) (EmbeddingProvider, error) {
|
||||
|
@ -38,41 +34,3 @@ func GetEmbeddingProvider(typ string, subType string, clientSecret string) (Embe
|
|||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func getEmbedding(authToken string, text string, timeout int) ([]float32, error) {
|
||||
client := getProxyClientFromToken(authToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
|
||||
Input: []string{text},
|
||||
Model: openai.AdaEmbeddingV2,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp.Data[0].Embedding, nil
|
||||
}
|
||||
|
||||
func GetEmbeddingSafe(authToken string, text string) ([]float32, error) {
|
||||
var embedding []float32
|
||||
var err error
|
||||
for i := 0; i < 10; i++ {
|
||||
embedding, err = getEmbedding(authToken, text, i)
|
||||
if err != nil {
|
||||
if i > 0 {
|
||||
fmt.Printf("\tFailed (%d): %s\n", i+1, err.Error())
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return embedding, nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,5 +24,5 @@ func GetRefinedQuestion(knowledge string, question string) string {
|
|||
return fmt.Sprintf(`paragraph: %s
|
||||
|
||||
You are a reading comprehension expert. Please answer the following questions based on the provided content. The content may be in a different language from the questions, so you need to understand the content according to the language of the questions and ensure that your answers are translated into the same language as the questions:
|
||||
Q1: %s`, knowledge, question)
|
||||
Question: %s`, knowledge, question)
|
||||
}
|
||||
|
|
|
@ -54,8 +54,7 @@ func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object,
|
|||
}
|
||||
|
||||
func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string) (bool, error) {
|
||||
data, err := embeddingProviderObj.QueryVector(text, 5)
|
||||
// data, err := model.GetEmbeddingSafe(authToken, text)
|
||||
data, err := queryVectorSafe(embeddingProviderObj, text)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -128,9 +127,35 @@ func getRelatedVectors(owner string) ([]*Vector, error) {
|
|||
return vectors, nil
|
||||
}
|
||||
|
||||
func queryVectorWithContext(embeddingProvider embedding.EmbeddingProvider, text string, timeout int) ([]float32, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second)
|
||||
defer cancel()
|
||||
return embeddingProvider.QueryVector(text, ctx)
|
||||
}
|
||||
|
||||
func queryVectorSafe(embeddingProvider embedding.EmbeddingProvider, text string) ([]float32, error) {
|
||||
var res []float32
|
||||
var err error
|
||||
for i := 0; i < 10; i++ {
|
||||
res, err = queryVectorWithContext(embeddingProvider, text, i)
|
||||
if err != nil {
|
||||
if i > 0 {
|
||||
fmt.Printf("\tFailed (%d): %s\n", i+1, err.Error())
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetNearestVectorText(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, error) {
|
||||
qVector, err := embeddingProvider.QueryVector(text, 5)
|
||||
// qVector, err := embedding.GetEmbeddingSafe(authToken, question)
|
||||
qVector, err := queryVectorSafe(embeddingProvider, text)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue