fix: update knowledge retrieval function (#609)

* fix: fix redirect not working error

* feat: update knowledge retrieval function

* fix: optimized the logic of retrieve function
This commit is contained in:
Kelvin Chiu 2023-08-15 15:32:22 +08:00 committed by Yang Luo
parent 2ad4e9b5ba
commit 83a4697c9d
5 changed files with 108 additions and 9 deletions

View File

@ -139,3 +139,10 @@ func QueryAnswerStream(authToken string, question string, writer io.Writer, buil
return nil
}
// GetQuestionWithKnowledge builds the prompt sent to the model when a knowledge
// passage accompanies the user's question.
//
// The result has three lines: "paragraph: " followed by knowledge, a fixed
// instruction telling the model to answer from that passage in the language of
// the question, and "Q1: " followed by question. Both arguments are inserted
// verbatim with %s.
func GetQuestionWithKnowledge(knowledge string, question string) string {
	const promptTemplate = "paragraph: %s\n" +
		"You are a reading comprehension expert. Please answer the following questions based on the provided content. The content may be in a different language from the questions, so you need to understand the content according to the language of the questions and ensure that your answers are translated into the same language as the questions:\n" +
		"Q1: %s"

	return fmt.Sprintf(promptTemplate, knowledge, question)
}

View File

@ -25,7 +25,7 @@ import (
)
func splitTxt(f io.ReadCloser) []string {
const maxLength = 512 * 3
const maxLength = 210 * 3
scanner := bufio.NewScanner(f)
var res []string
var temp string
@ -51,14 +51,14 @@ func GetSplitTxt(f io.ReadCloser) []string {
return splitTxt(f)
}
func getEmbedding(authToken string, input []string, timeout int) ([]float32, error) {
func getEmbedding(authToken string, text string, timeout int) ([]float32, error) {
client := getProxyClientFromToken(authToken)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second)
defer cancel()
resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
Input: input,
Input: []string{text},
Model: openai.AdaEmbeddingV2,
})
if err != nil {
@ -68,11 +68,11 @@ func getEmbedding(authToken string, input []string, timeout int) ([]float32, err
return resp.Data[0].Embedding, nil
}
func GetEmbeddingSafe(authToken string, input []string) ([]float32, error) {
func GetEmbeddingSafe(authToken string, text string) ([]float32, error) {
var embedding []float32
var err error
for i := 0; i < 10; i++ {
embedding, err = getEmbedding(authToken, input, i)
embedding, err = getEmbedding(authToken, text, i)
if err != nil {
if i > 0 {
fmt.Printf("\tFailed (%d): %s\n", i+1, err.Error())
@ -88,3 +88,18 @@ func GetEmbeddingSafe(authToken string, input []string) ([]float32, error) {
return embedding, nil
}
}
// GetNearestVectorIndex returns the index of the entry in vectors whose cosine
// similarity to target is the highest.
//
// The norm of target is computed once and reused for every comparison. The
// running best score starts at -1 and is replaced only by a strictly greater
// similarity, so the earliest of several equally similar entries wins, and 0
// is returned when no entry scores above -1 (including when vectors is empty).
func GetNearestVectorIndex(target []float32, vectors [][]float32) int {
	var (
		nearest   int
		bestScore = float32(-1.0)
	)

	targetNorm := norm(target)
	for idx := range vectors {
		score := cosineSimilarity(target, vectors[idx], targetNorm)
		// Keep the strict ">" form: a NaN score must never replace the best.
		if score > bestScore {
			bestScore, nearest = score, idx
		}
	}

	return nearest
}

View File

@ -14,7 +14,11 @@
package ai
import "github.com/pkoukk/tiktoken-go"
import (
"math"
"github.com/pkoukk/tiktoken-go"
)
func GetTokenSize(model string, prompt string) (int, error) {
tkm, err := tiktoken.EncodingForModel(model)
@ -26,3 +30,32 @@ func GetTokenSize(model string, prompt string) (int, error) {
res := len(token)
return res, nil
}
// cosineSimilarity returns the cosine of the angle between vec1 and vec2, i.e.
// their dot product divided by the product of their norms.
//
// vec1Norm is the caller-supplied norm of vec1, so that a caller comparing one
// vector against many only has to compute it once; the norm of vec2 is
// computed here.
//
// When either norm is zero the similarity is undefined, and 0 is returned
// rather than dividing by zero (which would yield NaN or an infinity). The dot
// product is still evaluated first, so vectors of different lengths panic
// exactly as dot does.
func cosineSimilarity(vec1, vec2 []float32, vec1Norm float32) float32 {
	dotProduct := dot(vec1, vec2)
	vec2Norm := norm(vec2)

	if vec1Norm == 0 || vec2Norm == 0 {
		return 0.0
	}

	return dotProduct / (vec1Norm * vec2Norm)
}
// dot returns the dot product of vec1 and vec2, accumulated in float32 from the
// first element to the last.
//
// It panics with "Vector lengths do not match" when the two slices differ in
// length.
func dot(vec1, vec2 []float32) float32 {
	if len(vec1) != len(vec2) {
		panic("Vector lengths do not match")
	}

	var sum float32
	for i, a := range vec1 {
		sum += a * vec2[i]
	}
	return sum
}
// norm returns the Euclidean (L2) length of vec.
//
// The squares are summed in float32; the square root is taken in float64 via
// math.Sqrt and converted back to float32. An empty or nil slice yields 0.
func norm(vec []float32) float32 {
	var sumSquares float32
	for i := range vec {
		sumSquares += vec[i] * vec[i]
	}

	root := math.Sqrt(float64(sumSquares))
	return float32(root)
}

View File

@ -149,10 +149,19 @@ func (c *ApiController) GetMessageAnswer() {
question := questionMessage.Text
var stringBuilder strings.Builder
fmt.Printf("Question: [%s]\n", questionMessage.Text)
nearestText, err := object.GetNearestVectorText(authToken, chat.Owner, question)
if err != nil {
c.ResponseErrorStream(err.Error())
return
}
realQuestion := ai.GetQuestionWithKnowledge(nearestText, question)
fmt.Printf("Question: [%s]\n", question)
fmt.Printf("Context: [%s]\n", nearestText)
fmt.Printf("Answer: [")
err = ai.QueryAnswerStream(authToken, question, c.Ctx.ResponseWriter, &stringBuilder)
err = ai.QueryAnswerStream(authToken, realQuestion, c.Ctx.ResponseWriter, &stringBuilder)
if err != nil {
c.ResponseErrorStream(err.Error())
return

View File

@ -62,7 +62,7 @@ func getObjectReadCloser(object *storage.Object) (io.ReadCloser, error) {
}
func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) {
embedding, err := ai.GetEmbeddingSafe(authToken, []string{text})
embedding, err := ai.GetEmbeddingSafe(authToken, text)
if err != nil {
return false, err
}
@ -131,3 +131,38 @@ func setTxtObjectVector(authToken string, provider string, key string, storeName
return true, nil
}
// getRelatedVectors returns the vectors that GetVectors yields for owner.
//
// A lookup error is passed through unchanged. An empty result is reported as
// the error "no knowledge vectors found", so a nil error always comes with at
// least one vector.
func getRelatedVectors(owner string) ([]*Vector, error) {
	found, err := GetVectors(owner)
	switch {
	case err != nil:
		return nil, err
	case len(found) == 0:
		return nil, fmt.Errorf("no knowledge vectors found")
	}
	return found, nil
}
// GetNearestVectorText returns the text of the stored knowledge vector that is
// most similar to question.
//
// The question is embedded with ai.GetEmbeddingSafe using authToken, the
// vectors returned by getRelatedVectors(owner) serve as candidates, and
// ai.GetNearestVectorIndex selects the best match among them.
//
// An error is returned when the embedding call fails or produces no vector,
// when getRelatedVectors fails (including when owner has no vectors), or when
// none of the stored vectors has the same dimension as the question embedding.
func GetNearestVectorText(authToken string, owner string, question string) (string, error) {
	qVector, err := ai.GetEmbeddingSafe(authToken, question)
	if err != nil {
		return "", err
	}
	if qVector == nil {
		return "", fmt.Errorf("no qVector found")
	}

	vectors, err := getRelatedVectors(owner)
	if err != nil {
		return "", err
	}

	// Only candidates with the query's dimension are comparable: the dot
	// product behind the similarity search panics on a length mismatch, so a
	// single stored vector with missing or differently sized data would
	// otherwise take the whole request down. indices maps each comparable
	// vector back to its position in vectors.
	indices := make([]int, 0, len(vectors))
	nVectors := make([][]float32, 0, len(vectors))
	for idx, candidate := range vectors {
		if candidate == nil || len(candidate.Data) != len(qVector) {
			continue
		}
		indices = append(indices, idx)
		nVectors = append(nVectors, candidate.Data)
	}
	if len(nVectors) == 0 {
		return "", fmt.Errorf("no knowledge vectors match the question embedding dimension")
	}

	i := ai.GetNearestVectorIndex(qVector, nVectors)
	return vectors[indices[i]].Text, nil
}