Add queryVectorSafe()

This commit is contained in:
Yang Luo 2023-09-10 00:50:32 +08:00
parent 405eceb738
commit 6d7921a669
4 changed files with 32 additions and 53 deletions

View File

@ -16,7 +16,6 @@ package embedding
import (
"context"
"time"
"github.com/casbin/casibase/proxy"
"github.com/casbin/casibase/util"
@ -40,12 +39,9 @@ func getProxyClientFromToken(authToken string) *openai.Client {
return c
}
func (p *OpenAiEmbeddingProvider) QueryVector(text string, timeout int) ([]float32, error) {
func (p *OpenAiEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) {
client := getProxyClientFromToken(p.secretKey)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second)
defer cancel()
resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
Input: []string{text},
Model: openai.EmbeddingModel(util.ParseInt(p.subType)),

View File

@ -16,14 +16,10 @@ package embedding
import (
"context"
"fmt"
"time"
"github.com/sashabaranov/go-openai"
)
// EmbeddingProvider is the interface through which callers obtain a vector
// ([]float32) for a piece of text.
// NOTE(review): this listing is a diff view, so two QueryVector signatures
// appear side by side — the timeout-based one looks like the version being
// replaced by the context-based one; confirm against the actual file.
type EmbeddingProvider interface {
// QueryVector taking an int timeout (presumably the pre-change signature).
QueryVector(text string, timeout int) ([]float32, error)
// QueryVector taking a caller-supplied context (presumably the post-change signature).
QueryVector(text string, ctx context.Context) ([]float32, error)
}
func GetEmbeddingProvider(typ string, subType string, clientSecret string) (EmbeddingProvider, error) {
@ -38,41 +34,3 @@ func GetEmbeddingProvider(typ string, subType string, clientSecret string) (Embe
}
return p, nil
}
// getEmbedding asks the client built by getProxyClientFromToken for an
// embedding of text using the openai.AdaEmbeddingV2 model and returns the
// first vector in the response.
//
// timeout is not a duration: the request deadline is 30+timeout*2 seconds.
// GetEmbeddingSafe passes its retry index here, so later attempts get more
// time.
//
// An error is returned when the request fails or when the response carries
// no embedding data.
func getEmbedding(authToken string, text string, timeout int) ([]float32, error) {
	client := getProxyClientFromToken(authToken)

	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second)
	defer cancel()

	resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
		Input: []string{text},
		Model: openai.AdaEmbeddingV2,
	})
	if err != nil {
		return nil, err
	}

	// Guard against an empty result set: indexing Data[0] unchecked would
	// panic instead of giving the retry loop an error to act on.
	if len(resp.Data) == 0 {
		return nil, fmt.Errorf("embedding response for %d-byte input contains no data", len(text))
	}
	return resp.Data[0].Embedding, nil
}
// GetEmbeddingSafe fetches the embedding for text, making up to 10 attempts
// through getEmbedding. The attempt index is forwarded as getEmbedding's
// timeout argument. The first failure is retried silently; every later
// failure is printed to stdout. It returns the first successful vector, or
// the error from the final attempt when all of them fail.
func GetEmbeddingSafe(authToken string, text string) ([]float32, error) {
	const maxAttempts = 10

	var lastErr error
	for attempt := 0; attempt < maxAttempts; attempt++ {
		vector, err := getEmbedding(authToken, text, attempt)
		if err == nil {
			return vector, nil
		}

		lastErr = err
		if attempt > 0 {
			fmt.Printf("\tFailed (%d): %s\n", attempt+1, err.Error())
		}
	}
	return nil, lastErr
}

View File

@ -24,5 +24,5 @@ func GetRefinedQuestion(knowledge string, question string) string {
return fmt.Sprintf(`paragraph: %s
You are a reading comprehension expert. Please answer the following questions based on the provided content. The content may be in a different language from the questions, so you need to understand the content according to the language of the questions and ensure that your answers are translated into the same language as the questions:
Q1: %s`, knowledge, question)
Question: %s`, knowledge, question)
}

View File

@ -54,8 +54,7 @@ func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object,
}
func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string) (bool, error) {
data, err := embeddingProviderObj.QueryVector(text, 5)
// data, err := model.GetEmbeddingSafe(authToken, text)
data, err := queryVectorSafe(embeddingProviderObj, text)
if err != nil {
return false, err
}
@ -128,9 +127,35 @@ func getRelatedVectors(owner string) ([]*Vector, error) {
return vectors, nil
}
// queryVectorWithContext calls embeddingProvider.QueryVector under a context
// that times out after 30+timeout*2 seconds and is cancelled when this
// function returns. The provider's result and error are passed back as-is.
func queryVectorWithContext(embeddingProvider embedding.EmbeddingProvider, text string, timeout int) ([]float32, error) {
	seconds := 30 + timeout*2
	limit := time.Duration(seconds) * time.Second

	ctx, cancel := context.WithTimeout(context.Background(), limit)
	defer cancel()

	vector, err := embeddingProvider.QueryVector(text, ctx)
	return vector, err
}
// queryVectorSafe queries embeddingProvider for the vector of text, trying up
// to 10 times via queryVectorWithContext. The try index is passed as the
// timeout argument, so each retry runs with a longer deadline. It returns the
// first successful vector, or nil and the last error once every try failed.
func queryVectorSafe(embeddingProvider embedding.EmbeddingProvider, text string) ([]float32, error) {
	const maxTries = 10

	var failure error
	for try := 0; try < maxTries; try++ {
		vector, err := queryVectorWithContext(embeddingProvider, text, try)
		if err == nil {
			return vector, nil
		}
		failure = err

		// The very first failure is retried quietly; later ones are reported.
		if try == 0 {
			continue
		}
		fmt.Printf("\tFailed (%d): %s\n", try+1, err.Error())
	}
	return nil, failure
}
func GetNearestVectorText(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, error) {
qVector, err := embeddingProvider.QueryVector(text, 5)
// qVector, err := embedding.GetEmbeddingSafe(authToken, question)
qVector, err := queryVectorSafe(embeddingProvider, text)
if err != nil {
return "", err
}