From c47736c11c708fd9882a441194cf4fcc8ba86c16 Mon Sep 17 00:00:00 2001
From: Yang Luo
Date: Sat, 9 Sep 2023 22:37:07 +0800
Subject: [PATCH] Add GetSupportedFileTypes()

---
 object/vector_embedding.go | 30 ++++++------------------------
 txt/text.go                | 12 ++++++++----
 2 files changed, 14 insertions(+), 28 deletions(-)

diff --git a/object/vector_embedding.go b/object/vector_embedding.go
index 7d928c0..6d8f428 100644
--- a/object/vector_embedding.go
+++ b/object/vector_embedding.go
@@ -17,8 +17,6 @@ package object
 import (
 	"context"
 	"fmt"
-	"io"
-	"net/http"
 	"path/filepath"
 	"time"
 
@@ -30,22 +28,19 @@ import (
 )
 
 func filterTextFiles(files []*storage.Object) []*storage.Object {
-	extSet := map[string]bool{
-		".txt":  true,
-		".md":   true,
-		".docx": true,
-		".doc":  false,
-		".pdf":  true,
+	fileTypes := txt.GetSupportedFileTypes()
+	fileTypeMap := map[string]bool{}
+	for _, fileType := range fileTypes {
+		fileTypeMap[fileType] = true
 	}
 
-	var res []*storage.Object
+	res := []*storage.Object{}
 	for _, file := range files {
 		ext := filepath.Ext(file.Key)
-		if extSet[ext] {
+		if fileTypeMap[ext] {
 			res = append(res, file)
 		}
 	}
-
 	return res
 }
 
@@ -58,19 +53,6 @@ func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object,
 	return filterTextFiles(files), nil
 }
 
-func getObjectFile(object *storage.Object) (io.ReadCloser, error) {
-	resp, err := http.Get(object.Url)
-	if err != nil {
-		return nil, err
-	}
-
-	if resp.StatusCode != http.StatusOK {
-		resp.Body.Close()
-		return nil, fmt.Errorf("HTTP request failed with status code: %d", resp.StatusCode)
-	}
-	return resp.Body, nil
-}
-
 func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) {
 	embedding, err := ai.GetEmbeddingSafe(authToken, text)
 	if err != nil {
diff --git a/txt/text.go b/txt/text.go
index 51c5365..cabf594 100644
--- a/txt/text.go
+++ b/txt/text.go
@@ -41,6 +41,10 @@ func GetTextSections(text string) []string {
 	return res
 }
 
+func GetSupportedFileTypes() []string {
+	return []string{".txt", ".md", ".docx", ".pdf"}
+}
+
 func GetParsedTextFromUrl(url string, ext string) (string, error) {
 	path, err := getTempFilePathFromUrl(url)
 	if err != nil {
@@ -54,12 +58,12 @@ func GetParsedTextFromUrl(url string, ext string) (string, error) {
 	}()
 
 	var res string
-	if ext == ".pdf" {
-		res, err = getTextFromPdf(path)
+	if ext == ".txt" || ext == ".md" {
+		res, err = getTextFromPlain(path)
 	} else if ext == ".docx" {
 		res, err = getTextFromDocx(path)
-	} else if ext == ".md" || ext == ".txt" {
-		res, err = getTextFromPlain(path)
+	} else if ext == ".pdf" {
+		res, err = getTextFromPdf(path)
 	} else {
 		return "", fmt.Errorf("unsupported file type: %s", ext)
 	}