fix: update knowledge retrieval function (#609)
* fix: fix redirect not working error * feat: update knowledge retrieval function * fix: optimized the logic of retrieve function
This commit is contained in:
parent
2ad4e9b5ba
commit
83a4697c9d
7
ai/ai.go
7
ai/ai.go
|
@ -139,3 +139,10 @@ func QueryAnswerStream(authToken string, question string, writer io.Writer, buil
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetQuestionWithKnowledge(knowledge string, question string) string {
|
||||
return fmt.Sprintf(`paragraph: %s
|
||||
|
||||
You are a reading comprehension expert. Please answer the following questions based on the provided content. The content may be in a different language from the questions, so you need to understand the content according to the language of the questions and ensure that your answers are translated into the same language as the questions:
|
||||
Q1: %s`, knowledge, question)
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ import (
|
|||
)
|
||||
|
||||
func splitTxt(f io.ReadCloser) []string {
|
||||
const maxLength = 512 * 3
|
||||
const maxLength = 210 * 3
|
||||
scanner := bufio.NewScanner(f)
|
||||
var res []string
|
||||
var temp string
|
||||
|
@ -51,14 +51,14 @@ func GetSplitTxt(f io.ReadCloser) []string {
|
|||
return splitTxt(f)
|
||||
}
|
||||
|
||||
func getEmbedding(authToken string, input []string, timeout int) ([]float32, error) {
|
||||
func getEmbedding(authToken string, text string, timeout int) ([]float32, error) {
|
||||
client := getProxyClientFromToken(authToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
|
||||
Input: input,
|
||||
Input: []string{text},
|
||||
Model: openai.AdaEmbeddingV2,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -68,11 +68,11 @@ func getEmbedding(authToken string, input []string, timeout int) ([]float32, err
|
|||
return resp.Data[0].Embedding, nil
|
||||
}
|
||||
|
||||
func GetEmbeddingSafe(authToken string, input []string) ([]float32, error) {
|
||||
func GetEmbeddingSafe(authToken string, text string) ([]float32, error) {
|
||||
var embedding []float32
|
||||
var err error
|
||||
for i := 0; i < 10; i++ {
|
||||
embedding, err = getEmbedding(authToken, input, i)
|
||||
embedding, err = getEmbedding(authToken, text, i)
|
||||
if err != nil {
|
||||
if i > 0 {
|
||||
fmt.Printf("\tFailed (%d): %s\n", i+1, err.Error())
|
||||
|
@ -88,3 +88,18 @@ func GetEmbeddingSafe(authToken string, input []string) ([]float32, error) {
|
|||
return embedding, nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetNearestVectorIndex(target []float32, vectors [][]float32) int {
|
||||
targetNorm := norm(target)
|
||||
|
||||
var res int
|
||||
max := float32(-1.0)
|
||||
for i, vector := range vectors {
|
||||
similarity := cosineSimilarity(target, vector, targetNorm)
|
||||
if similarity > max {
|
||||
max = similarity
|
||||
res = i
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
|
35
ai/util.go
35
ai/util.go
|
@ -14,7 +14,11 @@
|
|||
|
||||
package ai
|
||||
|
||||
import "github.com/pkoukk/tiktoken-go"
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
)
|
||||
|
||||
func GetTokenSize(model string, prompt string) (int, error) {
|
||||
tkm, err := tiktoken.EncodingForModel(model)
|
||||
|
@ -26,3 +30,32 @@ func GetTokenSize(model string, prompt string) (int, error) {
|
|||
res := len(token)
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func cosineSimilarity(vec1, vec2 []float32, vec1Norm float32) float32 {
|
||||
dotProduct := dot(vec1, vec2)
|
||||
vec2Norm := norm(vec2)
|
||||
if vec2Norm == 0 {
|
||||
return 0.0
|
||||
}
|
||||
return dotProduct / (vec1Norm * vec2Norm)
|
||||
}
|
||||
|
||||
func dot(vec1, vec2 []float32) float32 {
|
||||
if len(vec1) != len(vec2) {
|
||||
panic("Vector lengths do not match")
|
||||
}
|
||||
|
||||
dotProduct := float32(0.0)
|
||||
for i := range vec1 {
|
||||
dotProduct += vec1[i] * vec2[i]
|
||||
}
|
||||
return dotProduct
|
||||
}
|
||||
|
||||
func norm(vec []float32) float32 {
|
||||
normSquared := float32(0.0)
|
||||
for _, val := range vec {
|
||||
normSquared += val * val
|
||||
}
|
||||
return float32(math.Sqrt(float64(normSquared)))
|
||||
}
|
||||
|
|
|
@ -149,10 +149,19 @@ func (c *ApiController) GetMessageAnswer() {
|
|||
question := questionMessage.Text
|
||||
var stringBuilder strings.Builder
|
||||
|
||||
fmt.Printf("Question: [%s]\n", questionMessage.Text)
|
||||
nearestText, err := object.GetNearestVectorText(authToken, chat.Owner, question)
|
||||
if err != nil {
|
||||
c.ResponseErrorStream(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
realQuestion := ai.GetQuestionWithKnowledge(nearestText, question)
|
||||
|
||||
fmt.Printf("Question: [%s]\n", question)
|
||||
fmt.Printf("Context: [%s]\n", nearestText)
|
||||
fmt.Printf("Answer: [")
|
||||
|
||||
err = ai.QueryAnswerStream(authToken, question, c.Ctx.ResponseWriter, &stringBuilder)
|
||||
err = ai.QueryAnswerStream(authToken, realQuestion, c.Ctx.ResponseWriter, &stringBuilder)
|
||||
if err != nil {
|
||||
c.ResponseErrorStream(err.Error())
|
||||
return
|
||||
|
|
|
@ -62,7 +62,7 @@ func getObjectReadCloser(object *storage.Object) (io.ReadCloser, error) {
|
|||
}
|
||||
|
||||
func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) {
|
||||
embedding, err := ai.GetEmbeddingSafe(authToken, []string{text})
|
||||
embedding, err := ai.GetEmbeddingSafe(authToken, text)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -131,3 +131,38 @@ func setTxtObjectVector(authToken string, provider string, key string, storeName
|
|||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func getRelatedVectors(owner string) ([]*Vector, error) {
|
||||
vectors, err := GetVectors(owner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(vectors) == 0 {
|
||||
return nil, fmt.Errorf("no knowledge vectors found")
|
||||
}
|
||||
|
||||
return vectors, nil
|
||||
}
|
||||
|
||||
func GetNearestVectorText(authToken string, owner string, question string) (string, error) {
|
||||
qVector, err := ai.GetEmbeddingSafe(authToken, question)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if qVector == nil {
|
||||
return "", fmt.Errorf("no qVector found")
|
||||
}
|
||||
|
||||
vectors, err := getRelatedVectors(owner)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var nVectors [][]float32
|
||||
for _, candidate := range vectors {
|
||||
nVectors = append(nVectors, candidate.Data)
|
||||
}
|
||||
|
||||
i := ai.GetNearestVectorIndex(qVector, nVectors)
|
||||
return vectors[i].Text, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue