Get top 5 knowledge

This commit is contained in:
Yang Luo 2023-09-30 16:54:57 +08:00
parent cdf5fcbfd8
commit 2c1a616ad6
8 changed files with 93 additions and 47 deletions

View File

@ -132,16 +132,16 @@ func (c *ApiController) GetMessageAnswer() {
question := questionMessage.Text
nearestText, err := object.GetNearestVectorText(embeddingProviderObj, chat.Owner, question)
knowledge, vectorScores, err := object.GetNearestKnowledge(embeddingProviderObj, chat.Owner, question)
if err != nil && err.Error() != "no knowledge vectors found" {
c.ResponseErrorStream(err.Error())
return
}
realQuestion := object.GetRefinedQuestion(nearestText, question)
realQuestion := object.GetRefinedQuestion(knowledge, question)
fmt.Printf("Question: [%s]\n", question)
fmt.Printf("Context: [%s]\n", nearestText)
fmt.Printf("Knowledge: [%s]\n", knowledge)
// fmt.Printf("Refined Question: [%s]\n", realQuestion)
fmt.Printf("Answer: [")
@ -165,6 +165,7 @@ func (c *ApiController) GetMessageAnswer() {
answer := writer.String()
message.Text = answer
message.VectorScores = vectorScores
_, err = object.UpdateMessage(message.GetId(), message)
if err != nil {
c.ResponseErrorStream(err.Error())
@ -227,10 +228,11 @@ func (c *ApiController) AddMessage() {
Name: fmt.Sprintf("message_%s", util.GetRandomName()),
CreatedTime: util.GetCurrentTimeEx(message.CreatedTime),
// Organization: message.Organization,
Chat: message.Chat,
ReplyTo: message.GetId(),
Author: "AI",
Text: "",
Chat: message.Chat,
ReplyTo: message.GetId(),
Author: "AI",
Text: "",
VectorScores: []object.VectorScore{},
}
_, err = object.AddMessage(answerMessage)
if err != nil {

View File

@ -21,16 +21,22 @@ import (
"xorm.io/core"
)
// VectorScore pairs the name of a knowledge vector with the score it received
// for a query. GetNearestKnowledge builds one entry per search result from
// Vector.Name and Vector.Score, and the list is stored on Message.VectorScores.
type VectorScore struct {
// Vector is the Name of the matched Vector record.
Vector string `xorm:"varchar(100)" json:"vector"`
// Score is the Score reported for that Vector by the search provider.
Score float32 `json:"score"`
}
type Message struct {
Owner string `xorm:"varchar(100) notnull pk" json:"owner"`
Name string `xorm:"varchar(100) notnull pk" json:"name"`
CreatedTime string `xorm:"varchar(100)" json:"createdTime"`
// Organization string `xorm:"varchar(100)" json:"organization"`
Chat string `xorm:"varchar(100) index" json:"chat"`
ReplyTo string `xorm:"varchar(100) index" json:"replyTo"`
Author string `xorm:"varchar(100)" json:"author"`
Text string `xorm:"mediumtext" json:"text"`
Chat string `xorm:"varchar(100) index" json:"chat"`
ReplyTo string `xorm:"varchar(100) index" json:"replyTo"`
Author string `xorm:"varchar(100)" json:"author"`
Text string `xorm:"mediumtext" json:"text"`
VectorScores []VectorScore `xorm:"mediumtext" json:"vectorScores"`
}
func GetGlobalMessages() ([]*Message, error) {

View File

@ -15,7 +15,7 @@
package object
type SearchProvider interface {
Search(qVector []float32) (string, error)
Search(qVector []float32) ([]Vector, error)
}
func GetSearchProvider(typ string, owner string) (SearchProvider, error) {

View File

@ -22,17 +22,24 @@ func NewDefaultSearchProvider(owner string) (*DefaultSearchProvider, error) {
return &DefaultSearchProvider{owner: owner}, nil
}
func (p *DefaultSearchProvider) Search(qVector []float32) (string, error) {
func (p *DefaultSearchProvider) Search(qVector []float32) ([]Vector, error) {
vectors, err := getRelatedVectors(p.owner)
if err != nil {
return "", err
return nil, err
}
var nVectors [][]float32
var vectorData [][]float32
for _, candidate := range vectors {
nVectors = append(nVectors, candidate.Data)
vectorData = append(vectorData, candidate.Data)
}
i := getNearestVectorIndex(qVector, nVectors)
return vectors[i].Text, nil
res := []Vector{}
similarities := getNearestVectors(qVector, vectorData, 5)
for _, similarity := range similarities {
vector := vectors[similarity.Index]
vector.Score = similarity.Similarity
res = append(res, *vector)
}
return res, nil
}

View File

@ -14,7 +14,10 @@
package object
import "math"
import (
"math"
"sort"
)
func dot(vec1, vec2 []float32) float32 {
if len(vec1) != len(vec2) {
@ -45,17 +48,27 @@ func cosineSimilarity(vec1, vec2 []float32, vec1Norm float32) float32 {
return dotProduct / (vec1Norm * vec2Norm)
}
func getNearestVectorIndex(target []float32, vectors [][]float32) int {
// SimilarityIndex is one ranking entry produced by getNearestVectors: the
// similarity of a candidate vector to the target, together with where that
// candidate sits in the input slice.
type SimilarityIndex struct {
// Similarity is the value returned by cosineSimilarity for this candidate.
Similarity float32
// Index is the candidate's position in the vectors slice that was ranked.
Index int
}
// getNearestVectors ranks vectors by cosine similarity to target and returns
// at most n of the best matches, ordered from most to least similar.
//
// Each returned entry carries the similarity value and the Index of the
// matching candidate within vectors, so callers can map results back to their
// own records. When fewer than n candidates are supplied, all of them are
// returned; a non-positive n yields an empty (non-nil) slice.
func getNearestVectors(target []float32, vectors [][]float32, n int) []SimilarityIndex {
	// The target norm is constant for every comparison, so compute it once.
	targetNorm := norm(target)

	similarities := make([]SimilarityIndex, 0, len(vectors))
	for i, vector := range vectors {
		similarity := cosineSimilarity(target, vector, targetNorm)
		similarities = append(similarities, SimilarityIndex{Similarity: similarity, Index: i})
	}

	// Highest similarity first.
	sort.Slice(similarities, func(i, j int) bool {
		return similarities[i].Similarity > similarities[j].Similarity
	})

	// Clamp n into [0, len(similarities)] so the slice expression below is
	// always in range.
	if n < 0 {
		n = 0
	}
	if len(similarities) < n {
		n = len(similarities)
	}

	// Return only the top n entries. Previously n was clamped but the full
	// ranking was returned, so the caller's limit had no effect.
	return similarities[:n]
}

View File

@ -29,13 +29,8 @@ func NewHnswSearchProvider() (*HnswSearchProvider, error) {
return &HnswSearchProvider{}, nil
}
func (p *HnswSearchProvider) Search(qVector []float32) (string, error) {
search, err := Index.Search(qVector)
if err != nil {
return "", err
}
return search.Text, nil
// Search looks up qVector in the package-level HNSW Index and returns the
// vectors and error from Index.Search unchanged.
//
// NOTE(review): Index is a package-level *HNSWIndex; presumably it must be
// initialised before this method is called — confirm against the code that
// sets it up.
func (p *HnswSearchProvider) Search(qVector []float32) ([]Vector, error) {
return Index.Search(qVector)
}
var Index *HNSWIndex
@ -75,11 +70,16 @@ func (h *HNSWIndex) Add(name string, vector []float32) error {
return h.save()
}
func (h *HNSWIndex) Search(vector []float32) (*Vector, error) {
func (h *HNSWIndex) Search(vector []float32) ([]Vector, error) {
result := h.Hnsw.Search(vector, 100, 4)
item := result.Pop()
owner, name := util.GetOwnerAndNameFromId(h.IdToStr[item.ID])
return getVector(owner, name)
v, err := getVector(owner, name)
if err != nil {
return nil, err
}
return []Vector{*v}, nil
}
func (h *HNSWIndex) save() error {

View File

@ -26,12 +26,13 @@ type Vector struct {
Name string `xorm:"varchar(100) notnull pk" json:"name"`
CreatedTime string `xorm:"varchar(100)" json:"createdTime"`
DisplayName string `xorm:"varchar(100)" json:"displayName"`
Store string `xorm:"varchar(100)" json:"store"`
Provider string `xorm:"varchar(100)" json:"provider"`
File string `xorm:"varchar(100)" json:"file"`
Index int `json:"index"`
Text string `xorm:"mediumtext" json:"text"`
DisplayName string `xorm:"varchar(100)" json:"displayName"`
Store string `xorm:"varchar(100)" json:"store"`
Provider string `xorm:"varchar(100)" json:"provider"`
File string `xorm:"varchar(100)" json:"file"`
Index int `json:"index"`
Text string `xorm:"mediumtext" json:"text"`
Score float32 `json:"score"`
Data []float32 `xorm:"mediumtext" json:"data"`
Dimension int `json:"dimension"`

View File

@ -18,6 +18,7 @@ import (
"context"
"fmt"
"path/filepath"
"strings"
"time"
"github.com/casbin/casibase/embedding"
@ -149,19 +150,35 @@ func queryVectorSafe(embeddingProvider embedding.EmbeddingProvider, text string)
}
}
func GetNearestVectorText(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, error) {
func GetNearestKnowledge(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, []VectorScore, error) {
qVector, err := queryVectorSafe(embeddingProvider, text)
if err != nil {
return "", err
return "", nil, err
}
if qVector == nil {
return "", fmt.Errorf("no qVector found")
return "", nil, fmt.Errorf("no qVector found")
}
searchProvider, err := GetSearchProvider("Default", owner)
if err != nil {
return "", err
return "", nil, err
}
return searchProvider.Search(qVector)
vectors, err := searchProvider.Search(qVector)
if err != nil {
return "", nil, err
}
vectorScores := []VectorScore{}
texts := []string{}
for _, vector := range vectors {
vectorScores = append(vectorScores, VectorScore{
Vector: vector.Name,
Score: vector.Score,
})
texts = append(texts, vector.Text)
}
res := strings.Join(texts, "\n\n")
return res, vectorScores, nil
}