diff --git a/object/vector_embedding.go b/object/vector_embedding.go index e9ad5f1..da675fa 100644 --- a/object/vector_embedding.go +++ b/object/vector_embedding.go @@ -28,14 +28,10 @@ import ( "golang.org/x/time/rate" ) -func isTxt(filename string) bool { - return strings.HasSuffix(filename, ".txt") -} - -func filterTxtFiles(files []*storage.Object) []*storage.Object { +func filterTextFiles(files []*storage.Object) []*storage.Object { var res []*storage.Object for _, file := range files { - if isTxt(file.Key) { + if strings.HasSuffix(file.Key, ".txt") || strings.HasSuffix(file.Key, ".md") { res = append(res, file) } } @@ -43,13 +39,13 @@ func filterTxtFiles(files []*storage.Object) []*storage.Object { return res } -func getTxtFiles(provider string, prefix string) ([]*storage.Object, error) { +func getTextFiles(provider string, prefix string) ([]*storage.Object, error) { files, err := storage.ListObjects(provider, prefix) if err != nil { return nil, err } - return filterTxtFiles(files), nil + return filterTextFiles(files), nil } func getObjectReadCloser(object *storage.Object) (io.ReadCloser, error) { @@ -92,7 +88,7 @@ func addEmbeddedVector(authToken string, text string, storeName string, fileName func setTxtObjectVector(authToken string, provider string, key string, storeName string) (bool, error) { lb := rate.NewLimiter(rate.Every(time.Minute), 3) - txtObjects, err := getTxtFiles(provider, key) + txtObjects, err := getTextFiles(provider, key) if err != nil { return false, err }