Get top 5 knowlesge
This commit is contained in:
parent
cdf5fcbfd8
commit
2c1a616ad6
|
@ -132,16 +132,16 @@ func (c *ApiController) GetMessageAnswer() {
|
|||
|
||||
question := questionMessage.Text
|
||||
|
||||
nearestText, err := object.GetNearestVectorText(embeddingProviderObj, chat.Owner, question)
|
||||
knowledge, vectorScores, err := object.GetNearestKnowledge(embeddingProviderObj, chat.Owner, question)
|
||||
if err != nil && err.Error() != "no knowledge vectors found" {
|
||||
c.ResponseErrorStream(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
realQuestion := object.GetRefinedQuestion(nearestText, question)
|
||||
realQuestion := object.GetRefinedQuestion(knowledge, question)
|
||||
|
||||
fmt.Printf("Question: [%s]\n", question)
|
||||
fmt.Printf("Context: [%s]\n", nearestText)
|
||||
fmt.Printf("Knowledge: [%s]\n", knowledge)
|
||||
// fmt.Printf("Refined Question: [%s]\n", realQuestion)
|
||||
fmt.Printf("Answer: [")
|
||||
|
||||
|
@ -165,6 +165,7 @@ func (c *ApiController) GetMessageAnswer() {
|
|||
answer := writer.String()
|
||||
|
||||
message.Text = answer
|
||||
message.VectorScores = vectorScores
|
||||
_, err = object.UpdateMessage(message.GetId(), message)
|
||||
if err != nil {
|
||||
c.ResponseErrorStream(err.Error())
|
||||
|
@ -227,10 +228,11 @@ func (c *ApiController) AddMessage() {
|
|||
Name: fmt.Sprintf("message_%s", util.GetRandomName()),
|
||||
CreatedTime: util.GetCurrentTimeEx(message.CreatedTime),
|
||||
// Organization: message.Organization,
|
||||
Chat: message.Chat,
|
||||
ReplyTo: message.GetId(),
|
||||
Author: "AI",
|
||||
Text: "",
|
||||
Chat: message.Chat,
|
||||
ReplyTo: message.GetId(),
|
||||
Author: "AI",
|
||||
Text: "",
|
||||
VectorScores: []object.VectorScore{},
|
||||
}
|
||||
_, err = object.AddMessage(answerMessage)
|
||||
if err != nil {
|
||||
|
|
|
@ -21,16 +21,22 @@ import (
|
|||
"xorm.io/core"
|
||||
)
|
||||
|
||||
type VectorScore struct {
|
||||
Vector string `xorm:"varchar(100)" json:"vector"`
|
||||
Score float32 `json:"score"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Owner string `xorm:"varchar(100) notnull pk" json:"owner"`
|
||||
Name string `xorm:"varchar(100) notnull pk" json:"name"`
|
||||
CreatedTime string `xorm:"varchar(100)" json:"createdTime"`
|
||||
|
||||
// Organization string `xorm:"varchar(100)" json:"organization"`
|
||||
Chat string `xorm:"varchar(100) index" json:"chat"`
|
||||
ReplyTo string `xorm:"varchar(100) index" json:"replyTo"`
|
||||
Author string `xorm:"varchar(100)" json:"author"`
|
||||
Text string `xorm:"mediumtext" json:"text"`
|
||||
Chat string `xorm:"varchar(100) index" json:"chat"`
|
||||
ReplyTo string `xorm:"varchar(100) index" json:"replyTo"`
|
||||
Author string `xorm:"varchar(100)" json:"author"`
|
||||
Text string `xorm:"mediumtext" json:"text"`
|
||||
VectorScores []VectorScore `xorm:"mediumtext" json:"vectorScores"`
|
||||
}
|
||||
|
||||
func GetGlobalMessages() ([]*Message, error) {
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
package object
|
||||
|
||||
type SearchProvider interface {
|
||||
Search(qVector []float32) (string, error)
|
||||
Search(qVector []float32) ([]Vector, error)
|
||||
}
|
||||
|
||||
func GetSearchProvider(typ string, owner string) (SearchProvider, error) {
|
||||
|
|
|
@ -22,17 +22,24 @@ func NewDefaultSearchProvider(owner string) (*DefaultSearchProvider, error) {
|
|||
return &DefaultSearchProvider{owner: owner}, nil
|
||||
}
|
||||
|
||||
func (p *DefaultSearchProvider) Search(qVector []float32) (string, error) {
|
||||
func (p *DefaultSearchProvider) Search(qVector []float32) ([]Vector, error) {
|
||||
vectors, err := getRelatedVectors(p.owner)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var nVectors [][]float32
|
||||
var vectorData [][]float32
|
||||
for _, candidate := range vectors {
|
||||
nVectors = append(nVectors, candidate.Data)
|
||||
vectorData = append(vectorData, candidate.Data)
|
||||
}
|
||||
|
||||
i := getNearestVectorIndex(qVector, nVectors)
|
||||
return vectors[i].Text, nil
|
||||
res := []Vector{}
|
||||
similarities := getNearestVectors(qVector, vectorData, 5)
|
||||
for _, similarity := range similarities {
|
||||
vector := vectors[similarity.Index]
|
||||
vector.Score = similarity.Similarity
|
||||
res = append(res, *vector)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
|
|
@ -14,7 +14,10 @@
|
|||
|
||||
package object
|
||||
|
||||
import "math"
|
||||
import (
|
||||
"math"
|
||||
"sort"
|
||||
)
|
||||
|
||||
func dot(vec1, vec2 []float32) float32 {
|
||||
if len(vec1) != len(vec2) {
|
||||
|
@ -45,17 +48,27 @@ func cosineSimilarity(vec1, vec2 []float32, vec1Norm float32) float32 {
|
|||
return dotProduct / (vec1Norm * vec2Norm)
|
||||
}
|
||||
|
||||
func getNearestVectorIndex(target []float32, vectors [][]float32) int {
|
||||
type SimilarityIndex struct {
|
||||
Similarity float32
|
||||
Index int
|
||||
}
|
||||
|
||||
func getNearestVectors(target []float32, vectors [][]float32, n int) []SimilarityIndex {
|
||||
targetNorm := norm(target)
|
||||
|
||||
var res int
|
||||
max := float32(-1.0)
|
||||
similarities := []SimilarityIndex{}
|
||||
for i, vector := range vectors {
|
||||
similarity := cosineSimilarity(target, vector, targetNorm)
|
||||
if similarity > max {
|
||||
max = similarity
|
||||
res = i
|
||||
}
|
||||
similarities = append(similarities, SimilarityIndex{similarity, i})
|
||||
}
|
||||
return res
|
||||
|
||||
sort.Slice(similarities, func(i, j int) bool {
|
||||
return similarities[i].Similarity > similarities[j].Similarity
|
||||
})
|
||||
|
||||
if len(vectors) < n {
|
||||
n = len(vectors)
|
||||
}
|
||||
|
||||
return similarities
|
||||
}
|
|
@ -29,13 +29,8 @@ func NewHnswSearchProvider() (*HnswSearchProvider, error) {
|
|||
return &HnswSearchProvider{}, nil
|
||||
}
|
||||
|
||||
func (p *HnswSearchProvider) Search(qVector []float32) (string, error) {
|
||||
search, err := Index.Search(qVector)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return search.Text, nil
|
||||
func (p *HnswSearchProvider) Search(qVector []float32) ([]Vector, error) {
|
||||
return Index.Search(qVector)
|
||||
}
|
||||
|
||||
var Index *HNSWIndex
|
||||
|
@ -75,11 +70,16 @@ func (h *HNSWIndex) Add(name string, vector []float32) error {
|
|||
return h.save()
|
||||
}
|
||||
|
||||
func (h *HNSWIndex) Search(vector []float32) (*Vector, error) {
|
||||
func (h *HNSWIndex) Search(vector []float32) ([]Vector, error) {
|
||||
result := h.Hnsw.Search(vector, 100, 4)
|
||||
item := result.Pop()
|
||||
|
||||
owner, name := util.GetOwnerAndNameFromId(h.IdToStr[item.ID])
|
||||
return getVector(owner, name)
|
||||
v, err := getVector(owner, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []Vector{*v}, nil
|
||||
}
|
||||
|
||||
func (h *HNSWIndex) save() error {
|
||||
|
|
|
@ -26,12 +26,13 @@ type Vector struct {
|
|||
Name string `xorm:"varchar(100) notnull pk" json:"name"`
|
||||
CreatedTime string `xorm:"varchar(100)" json:"createdTime"`
|
||||
|
||||
DisplayName string `xorm:"varchar(100)" json:"displayName"`
|
||||
Store string `xorm:"varchar(100)" json:"store"`
|
||||
Provider string `xorm:"varchar(100)" json:"provider"`
|
||||
File string `xorm:"varchar(100)" json:"file"`
|
||||
Index int `json:"index"`
|
||||
Text string `xorm:"mediumtext" json:"text"`
|
||||
DisplayName string `xorm:"varchar(100)" json:"displayName"`
|
||||
Store string `xorm:"varchar(100)" json:"store"`
|
||||
Provider string `xorm:"varchar(100)" json:"provider"`
|
||||
File string `xorm:"varchar(100)" json:"file"`
|
||||
Index int `json:"index"`
|
||||
Text string `xorm:"mediumtext" json:"text"`
|
||||
Score float32 `json:"score"`
|
||||
|
||||
Data []float32 `xorm:"mediumtext" json:"data"`
|
||||
Dimension int `json:"dimension"`
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/casbin/casibase/embedding"
|
||||
|
@ -149,19 +150,35 @@ func queryVectorSafe(embeddingProvider embedding.EmbeddingProvider, text string)
|
|||
}
|
||||
}
|
||||
|
||||
func GetNearestVectorText(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, error) {
|
||||
func GetNearestKnowledge(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, []VectorScore, error) {
|
||||
qVector, err := queryVectorSafe(embeddingProvider, text)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", nil, err
|
||||
}
|
||||
if qVector == nil {
|
||||
return "", fmt.Errorf("no qVector found")
|
||||
return "", nil, fmt.Errorf("no qVector found")
|
||||
}
|
||||
|
||||
searchProvider, err := GetSearchProvider("Default", owner)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return searchProvider.Search(qVector)
|
||||
vectors, err := searchProvider.Search(qVector)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
vectorScores := []VectorScore{}
|
||||
texts := []string{}
|
||||
for _, vector := range vectors {
|
||||
vectorScores = append(vectorScores, VectorScore{
|
||||
Vector: vector.Name,
|
||||
Score: vector.Score,
|
||||
})
|
||||
texts = append(texts, vector.Text)
|
||||
}
|
||||
|
||||
res := strings.Join(texts, "\n\n")
|
||||
return res, vectorScores, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue