Add GetSupportedFileTypes()

This commit is contained in:
Yang Luo 2023-09-09 22:37:07 +08:00
parent c5828138cc
commit c47736c11c
2 changed files with 14 additions and 28 deletions

View File

@@ -17,8 +17,6 @@ package object
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"net/http"
"path/filepath" "path/filepath"
"time" "time"
@@ -30,22 +28,19 @@ import (
) )
func filterTextFiles(files []*storage.Object) []*storage.Object { func filterTextFiles(files []*storage.Object) []*storage.Object {
extSet := map[string]bool{ fileTypes := txt.GetSupportedFileTypes()
".txt": true, fileTypeMap := map[string]bool{}
".md": true, for _, fileType := range fileTypes {
".docx": true, fileTypeMap[fileType] = true
".doc": false,
".pdf": true,
} }
var res []*storage.Object res := []*storage.Object{}
for _, file := range files { for _, file := range files {
ext := filepath.Ext(file.Key) ext := filepath.Ext(file.Key)
if extSet[ext] { if fileTypeMap[ext] {
res = append(res, file) res = append(res, file)
} }
} }
return res return res
} }
@@ -58,19 +53,6 @@ func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object,
return filterTextFiles(files), nil return filterTextFiles(files), nil
} }
func getObjectFile(object *storage.Object) (io.ReadCloser, error) {
resp, err := http.Get(object.Url)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
resp.Body.Close()
return nil, fmt.Errorf("HTTP request failed with status code: %d", resp.StatusCode)
}
return resp.Body, nil
}
func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) { func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) {
embedding, err := ai.GetEmbeddingSafe(authToken, text) embedding, err := ai.GetEmbeddingSafe(authToken, text)
if err != nil { if err != nil {

View File

@@ -41,6 +41,10 @@ func GetTextSections(text string) []string {
return res return res
} }
func GetSupportedFileTypes() []string {
return []string{".txt", ".md", ".docx", ".pdf"}
}
func GetParsedTextFromUrl(url string, ext string) (string, error) { func GetParsedTextFromUrl(url string, ext string) (string, error) {
path, err := getTempFilePathFromUrl(url) path, err := getTempFilePathFromUrl(url)
if err != nil { if err != nil {
@@ -54,12 +58,12 @@ func GetParsedTextFromUrl(url string, ext string) (string, error) {
}() }()
var res string var res string
if ext == ".pdf" { if ext == ".txt" || ext == ".md" {
res, err = getTextFromPdf(path) res, err = getTextFromPlain(path)
} else if ext == ".docx" { } else if ext == ".docx" {
res, err = getTextFromDocx(path) res, err = getTextFromDocx(path)
} else if ext == ".md" || ext == ".txt" { } else if ext == ".pdf" {
res, err = getTextFromPlain(path) res, err = getTextFromPdf(path)
} else { } else {
return "", fmt.Errorf("unsupported file type: %s", ext) return "", fmt.Errorf("unsupported file type: %s", ext)
} }