Refactor to addVectorsForStore()

This commit is contained in:
Yang Luo 2023-09-07 23:05:20 +08:00
parent 9020043669
commit ec9bcb6d78
4 changed files with 30 additions and 38 deletions

View File

@ -25,7 +25,7 @@ import (
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
) )
func readCloserToString(f io.ReadCloser, fileName string) (string, error) { func ReadFileToString(f io.ReadCloser, fileName string) (string, error) {
fileType := docconv.MimeTypeByExtension(fileName) fileType := docconv.MimeTypeByExtension(fileName)
res, err := docconv.Convert(f, fileType, true) res, err := docconv.Convert(f, fileType, true)
if err != nil { if err != nil {
@ -38,7 +38,7 @@ func readCloserToString(f io.ReadCloser, fileName string) (string, error) {
return res.Body, nil return res.Body, nil
} }
func splitText(text string) []string { func SplitText(text string) []string {
const maxLength = 210 * 3 const maxLength = 210 * 3
var res []string var res []string
var temp string var temp string
@ -59,15 +59,6 @@ func splitText(text string) []string {
return res return res
} }
func GetSplitTxt(f io.ReadCloser, fileName string) []string {
text, err := readCloserToString(f, fileName)
if err != nil {
return nil
}
return splitText(text)
}
func getEmbedding(authToken string, text string, timeout int) ([]float32, error) { func getEmbedding(authToken string, text string, timeout int) ([]float32, error) {
client := getProxyClientFromToken(authToken) client := getProxyClientFromToken(authToken)

View File

@ -132,11 +132,11 @@ func (c *ApiController) RefreshStoreVectors() {
return return
} }
success, err := object.RefreshStoreVectors(&store) ok, err := object.RefreshStoreVectors(&store)
if err != nil { if err != nil {
c.ResponseError(err.Error()) c.ResponseError(err.Error())
return return
} }
c.ResponseOk(success) c.ResponseOk(ok)
} }

View File

@ -157,12 +157,6 @@ func RefreshStoreVectors(store *Store) (bool, error) {
} }
authToken := provider.ClientSecret authToken := provider.ClientSecret
success, err := setTextObjectVector(authToken, store.StorageProvider, "", store.Name) ok, err := addVectorsForStore(authToken, store.StorageProvider, "", store.Name)
if err != nil { return ok, err
return false, err
}
if !success {
return false, nil
}
return true, nil
} }

View File

@ -48,7 +48,7 @@ func filterTextFiles(files []*storage.Object) []*storage.Object {
return res return res
} }
func getTextFiles(provider string, prefix string) ([]*storage.Object, error) { func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object, error) {
files, err := storage.ListObjects(provider, prefix) files, err := storage.ListObjects(provider, prefix)
if err != nil { if err != nil {
return nil, err return nil, err
@ -57,7 +57,7 @@ func getTextFiles(provider string, prefix string) ([]*storage.Object, error) {
return filterTextFiles(files), nil return filterTextFiles(files), nil
} }
func getObjectReadCloser(object *storage.Object) (io.ReadCloser, error) { func getObjectFile(object *storage.Object) (io.ReadCloser, error) {
resp, err := http.Get(object.Url) resp, err := http.Get(object.Url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -94,44 +94,51 @@ func addEmbeddedVector(authToken string, text string, storeName string, fileName
return AddVector(vector) return AddVector(vector)
} }
func setTextObjectVector(authToken string, provider string, key string, storeName string) (bool, error) { func addVectorsForStore(authToken string, provider string, key string, storeName string) (bool, error) {
lb := rate.NewLimiter(rate.Every(time.Minute), 3) timeLimiter := rate.NewLimiter(rate.Every(time.Minute), 3)
textObjects, err := getTextFiles(provider, key) objs, err := getFilteredFileObjects(provider, key)
if err != nil { if err != nil {
return false, err return false, err
} }
if len(textObjects) == 0 { if len(objs) == 0 {
return false, nil return false, nil
} }
for _, textObject := range textObjects { for _, obj := range objs {
readCloser, err := getObjectReadCloser(textObject) f, err := getObjectFile(obj)
if err != nil { if err != nil {
return false, err return false, err
} }
defer readCloser.Close() defer f.Close()
splitTxts := ai.GetSplitTxt(readCloser, textObject.Key) filename := obj.Key
for _, splitTxt := range splitTxts { text, err := ai.ReadFileToString(f, filename)
if lb.Allow() { if err != nil {
success, err := addEmbeddedVector(authToken, splitTxt, storeName, textObject.Key) return false, err
}
textSections := ai.SplitText(text)
for _, textSection := range textSections {
if timeLimiter.Allow() {
ok, err := addEmbeddedVector(authToken, textSection, storeName, obj.Key)
if err != nil { if err != nil {
return false, err return false, err
} }
if !success { if !ok {
return false, nil return false, nil
} }
} else { } else {
err := lb.Wait(context.Background()) err := timeLimiter.Wait(context.Background())
if err != nil { if err != nil {
return false, err return false, err
} }
success, err := addEmbeddedVector(authToken, splitTxt, storeName, textObject.Key)
ok, err := addEmbeddedVector(authToken, textSection, storeName, obj.Key)
if err != nil { if err != nil {
return false, err return false, err
} }
if !success { if !ok {
return false, nil return false, nil
} }
} }