diff --git a/embedding/openai.go b/embedding/openai.go index de70991..22a67de 100644 --- a/embedding/openai.go +++ b/embedding/openai.go @@ -16,7 +16,6 @@ package embedding import ( "context" - "time" "github.com/casbin/casibase/proxy" "github.com/casbin/casibase/util" @@ -40,12 +39,9 @@ func getProxyClientFromToken(authToken string) *openai.Client { return c } -func (p *OpenAiEmbeddingProvider) QueryVector(text string, timeout int) ([]float32, error) { +func (p *OpenAiEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) { client := getProxyClientFromToken(p.secretKey) - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second) - defer cancel() - resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{ Input: []string{text}, Model: openai.EmbeddingModel(util.ParseInt(p.subType)), diff --git a/embedding/provider.go b/embedding/provider.go index f643b71..a1a73e0 100644 --- a/embedding/provider.go +++ b/embedding/provider.go @@ -16,14 +16,10 @@ package embedding import ( "context" - "fmt" - "time" - - "github.com/sashabaranov/go-openai" ) type EmbeddingProvider interface { - QueryVector(text string, timeout int) ([]float32, error) + QueryVector(text string, ctx context.Context) ([]float32, error) } func GetEmbeddingProvider(typ string, subType string, clientSecret string) (EmbeddingProvider, error) { @@ -38,41 +34,3 @@ func GetEmbeddingProvider(typ string, subType string, clientSecret string) (Embe } return p, nil } - -func getEmbedding(authToken string, text string, timeout int) ([]float32, error) { - client := getProxyClientFromToken(authToken) - - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second) - defer cancel() - - resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{ - Input: []string{text}, - Model: openai.AdaEmbeddingV2, - }) - if err != nil { - return nil, err - } - - return resp.Data[0].Embedding, nil -} - -func GetEmbeddingSafe(authToken string, text string) ([]float32, error) { - var embedding []float32 - var err error - for i := 0; i < 10; i++ { - embedding, err = getEmbedding(authToken, text, i) - if err != nil { - if i > 0 { - fmt.Printf("\tFailed (%d): %s\n", i+1, err.Error()) - } - } else { - break - } - } - - if err != nil { - return nil, err - } else { - return embedding, nil - } -} diff --git a/object/message_ai.go b/object/message_ai.go index c06ff7b..e0c6229 100644 --- a/object/message_ai.go +++ b/object/message_ai.go @@ -24,5 +24,5 @@ func GetRefinedQuestion(knowledge string, question string) string { return fmt.Sprintf(`paragraph: %s You are a reading comprehension expert. Please answer the following questions based on the provided content. The content may be in a different language from the questions, so you need to understand the content according to the language of the questions and ensure that your answers are translated into the same language as the questions: -Q1: %s`, knowledge, question) +Question: %s`, knowledge, question) } diff --git a/object/vector_embedding.go b/object/vector_embedding.go index 32a2128..17dc540 100644 --- a/object/vector_embedding.go +++ b/object/vector_embedding.go @@ -54,8 +54,7 @@ func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object, } func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string) (bool, error) { - data, err := embeddingProviderObj.QueryVector(text, 5) - // data, err := model.GetEmbeddingSafe(authToken, text) + data, err := queryVectorSafe(embeddingProviderObj, text) if err != nil { return false, err } @@ -128,9 +127,35 @@ func getRelatedVectors(owner string) ([]*Vector, error) { return vectors, nil } +func queryVectorWithContext(embeddingProvider embedding.EmbeddingProvider, text string, timeout int) ([]float32, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second) + defer cancel() + return embeddingProvider.QueryVector(text, ctx) +} + +func queryVectorSafe(embeddingProvider embedding.EmbeddingProvider, text string) ([]float32, error) { + var res []float32 + var err error + for i := 0; i < 10; i++ { + res, err = queryVectorWithContext(embeddingProvider, text, i) + if err != nil { + if i > 0 { + fmt.Printf("\tFailed (%d): %s\n", i+1, err.Error()) + } + } else { + break + } + } + + if err != nil { + return nil, err + } else { + return res, nil + } +} + func GetNearestVectorText(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, error) { - qVector, err := embeddingProvider.QueryVector(text, 5) - // qVector, err := embedding.GetEmbeddingSafe(authToken, question) + qVector, err := queryVectorSafe(embeddingProvider, text) if err != nil { return "", err }