Add GetSupportedFileTypes()

This commit is contained in:
Yang Luo 2023-09-09 22:37:07 +08:00
parent c5828138cc
commit c47736c11c
2 changed files with 14 additions and 28 deletions

View File

@ -17,8 +17,6 @@ package object
import (
"context"
"fmt"
"io"
"net/http"
"path/filepath"
"time"
@ -30,22 +28,19 @@ import (
)
func filterTextFiles(files []*storage.Object) []*storage.Object {
extSet := map[string]bool{
".txt": true,
".md": true,
".docx": true,
".doc": false,
".pdf": true,
fileTypes := txt.GetSupportedFileTypes()
fileTypeMap := map[string]bool{}
for _, fileType := range fileTypes {
fileTypeMap[fileType] = true
}
var res []*storage.Object
res := []*storage.Object{}
for _, file := range files {
ext := filepath.Ext(file.Key)
if extSet[ext] {
if fileTypeMap[ext] {
res = append(res, file)
}
}
return res
}
@ -58,19 +53,6 @@ func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object,
return filterTextFiles(files), nil
}
func getObjectFile(object *storage.Object) (io.ReadCloser, error) {
resp, err := http.Get(object.Url)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
resp.Body.Close()
return nil, fmt.Errorf("HTTP request failed with status code: %d", resp.StatusCode)
}
return resp.Body, nil
}
func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) {
embedding, err := ai.GetEmbeddingSafe(authToken, text)
if err != nil {

View File

@ -41,6 +41,10 @@ func GetTextSections(text string) []string {
return res
}
func GetSupportedFileTypes() []string {
return []string{".txt", ".md", ".docx", ".pdf"}
}
func GetParsedTextFromUrl(url string, ext string) (string, error) {
path, err := getTempFilePathFromUrl(url)
if err != nil {
@ -54,12 +58,12 @@ func GetParsedTextFromUrl(url string, ext string) (string, error) {
}()
var res string
if ext == ".pdf" {
res, err = getTextFromPdf(path)
if ext == ".txt" || ext == ".md" {
res, err = getTextFromPlain(path)
} else if ext == ".docx" {
res, err = getTextFromDocx(path)
} else if ext == ".md" || ext == ".txt" {
res, err = getTextFromPlain(path)
} else if ext == ".pdf" {
res, err = getTextFromPdf(path)
} else {
return "", fmt.Errorf("unsupported file type: %s", ext)
}