diff --git a/ai/embedding.go b/ai/embedding.go index ff508f6..b92fa0d 100644 --- a/ai/embedding.go +++ b/ai/embedding.go @@ -25,7 +25,7 @@ import ( "github.com/sashabaranov/go-openai" ) -func readCloserToString(f io.ReadCloser, fileName string) (string, error) { +func ReadFileToString(f io.ReadCloser, fileName string) (string, error) { fileType := docconv.MimeTypeByExtension(fileName) res, err := docconv.Convert(f, fileType, true) if err != nil { @@ -38,7 +38,7 @@ func readCloserToString(f io.ReadCloser, fileName string) (string, error) { return res.Body, nil } -func splitText(text string) []string { +func SplitText(text string) []string { const maxLength = 210 * 3 var res []string var temp string @@ -59,15 +59,6 @@ func splitText(text string) []string { return res } -func GetSplitTxt(f io.ReadCloser, fileName string) []string { - text, err := readCloserToString(f, fileName) - if err != nil { - return nil - } - - return splitText(text) -} - func getEmbedding(authToken string, text string, timeout int) ([]float32, error) { client := getProxyClientFromToken(authToken) diff --git a/controllers/store.go b/controllers/store.go index 28928b4..118d173 100644 --- a/controllers/store.go +++ b/controllers/store.go @@ -132,11 +132,11 @@ func (c *ApiController) RefreshStoreVectors() { return } - success, err := object.RefreshStoreVectors(&store) + ok, err := object.RefreshStoreVectors(&store) if err != nil { c.ResponseError(err.Error()) return } - c.ResponseOk(success) + c.ResponseOk(ok) } diff --git a/object/store.go b/object/store.go index 3f2b2c8..9236183 100644 --- a/object/store.go +++ b/object/store.go @@ -157,12 +157,6 @@ func RefreshStoreVectors(store *Store) (bool, error) { } authToken := provider.ClientSecret - success, err := setTextObjectVector(authToken, store.StorageProvider, "", store.Name) - if err != nil { - return false, err - } - if !success { - return false, nil - } - return true, nil + ok, err := addVectorsForStore(authToken, store.StorageProvider, "", store.Name) + return ok, err } diff --git a/object/vector_embedding.go b/object/vector_embedding.go index 5f4a493..7c380d4 100644 --- a/object/vector_embedding.go +++ b/object/vector_embedding.go @@ -48,7 +48,7 @@ func filterTextFiles(files []*storage.Object) []*storage.Object { return res } -func getTextFiles(provider string, prefix string) ([]*storage.Object, error) { +func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object, error) { files, err := storage.ListObjects(provider, prefix) if err != nil { return nil, err @@ -57,7 +57,7 @@ func getTextFiles(provider string, prefix string) ([]*storage.Object, error) { return filterTextFiles(files), nil } -func getObjectReadCloser(object *storage.Object) (io.ReadCloser, error) { +func getObjectFile(object *storage.Object) (io.ReadCloser, error) { resp, err := http.Get(object.Url) if err != nil { return nil, err @@ -94,44 +94,51 @@ func addEmbeddedVector(authToken string, text string, storeName string, fileName return AddVector(vector) } -func setTextObjectVector(authToken string, provider string, key string, storeName string) (bool, error) { - lb := rate.NewLimiter(rate.Every(time.Minute), 3) +func addVectorsForStore(authToken string, provider string, key string, storeName string) (bool, error) { + timeLimiter := rate.NewLimiter(rate.Every(time.Minute), 3) - textObjects, err := getTextFiles(provider, key) + objs, err := getFilteredFileObjects(provider, key) if err != nil { return false, err } - if len(textObjects) == 0 { + if len(objs) == 0 { return false, nil } - for _, textObject := range textObjects { - readCloser, err := getObjectReadCloser(textObject) + for _, obj := range objs { + f, err := getObjectFile(obj) if err != nil { return false, err } - defer readCloser.Close() + defer f.Close() - splitTxts := ai.GetSplitTxt(readCloser, textObject.Key) - for _, splitTxt := range splitTxts { - if lb.Allow() { - success, err := addEmbeddedVector(authToken, splitTxt, storeName, textObject.Key) + filename := obj.Key + text, err := ai.ReadFileToString(f, filename) + if err != nil { + return false, err + } + + textSections := ai.SplitText(text) + for _, textSection := range textSections { + if timeLimiter.Allow() { + ok, err := addEmbeddedVector(authToken, textSection, storeName, obj.Key) if err != nil { return false, err } - if !success { + if !ok { return false, nil } } else { - err := lb.Wait(context.Background()) + err := timeLimiter.Wait(context.Background()) if err != nil { return false, err } - success, err := addEmbeddedVector(authToken, splitTxt, storeName, textObject.Key) + + ok, err := addEmbeddedVector(authToken, textSection, storeName, obj.Key) if err != nil { return false, err } - if !success { + if !ok { return false, nil } }