Add size to vector

This commit is contained in:
Yang Luo 2023-10-01 10:05:52 +08:00
parent 4426466d4f
commit abae92b963
6 changed files with 88 additions and 5 deletions

View File

@ -172,6 +172,15 @@ func (store *Store) GetStorageProviderObj() (storage.StorageProvider, error) {
}
}
// GetModelProvider returns the model provider for this store.
//
// When store.ModelProvider is set, the provider id is built from store.Owner
// and store.ModelProvider via util.GetIdFromOwnerAndName and looked up with
// GetProvider. When it is empty, the result of GetDefaultModelProvider is
// returned instead.
func (store *Store) GetModelProvider() (*Provider, error) {
	if name := store.ModelProvider; name != "" {
		return GetProvider(util.GetIdFromOwnerAndName(store.Owner, name))
	}

	// No provider named on the store: use the default one.
	return GetDefaultModelProvider()
}
func (store *Store) GetEmbeddingProvider() (*Provider, error) {
if store.EmbeddingProvider == "" {
return GetDefaultEmbeddingProvider()
@ -187,6 +196,11 @@ func RefreshStoreVectors(store *Store) (bool, error) {
return false, err
}
modelProvider, err := store.GetModelProvider()
if err != nil {
return false, err
}
embeddingProvider, err := store.GetEmbeddingProvider()
if err != nil {
return false, err
@ -197,6 +211,6 @@ func RefreshStoreVectors(store *Store) (bool, error) {
return false, err
}
ok, err := addVectorsForStore(storageProviderObj, embeddingProviderObj, "", store.Name, embeddingProvider.Name)
ok, err := addVectorsForStore(storageProviderObj, embeddingProviderObj, "", store.Name, embeddingProvider.Name, modelProvider.SubType)
return ok, err
}

View File

@ -32,6 +32,7 @@ type Vector struct {
File string `xorm:"varchar(100)" json:"file"`
Index int `json:"index"`
Text string `xorm:"mediumtext" json:"text"`
Size int `json:"size"`
Score float32 `json:"score"`
Data []float32 `xorm:"mediumtext" json:"data"`

View File

@ -22,6 +22,7 @@ import (
"time"
"github.com/casbin/casibase/embedding"
"github.com/casbin/casibase/model"
"github.com/casbin/casibase/storage"
"github.com/casbin/casibase/txt"
"github.com/casbin/casibase/util"
@ -45,7 +46,7 @@ func filterTextFiles(files []*storage.Object) []*storage.Object {
return res
}
func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string, index int, embeddingProviderName string) (bool, error) {
func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string, index int, embeddingProviderName string, modelSubType string) (bool, error) {
data, err := queryVectorSafe(embeddingProviderObj, text)
if err != nil {
return false, err
@ -56,6 +57,11 @@ func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text st
displayName = text[:25]
}
size, err := model.GetTokenSize(modelSubType, text)
if err != nil {
return false, err
}
vector := &Vector{
Owner: "admin",
Name: fmt.Sprintf("vector_%s", util.GetRandomName()),
@ -66,13 +72,14 @@ func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text st
File: fileName,
Index: index,
Text: text,
Size: size,
Data: data,
Dimension: len(data),
}
return AddVector(vector)
}
func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingProviderObj embedding.EmbeddingProvider, prefix string, storeName string, embeddingProviderName string) (bool, error) {
func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingProviderObj embedding.EmbeddingProvider, prefix string, storeName string, embeddingProviderName string, modelSubType string) (bool, error) {
var affected bool
files, err := storageProviderObj.ListObjects(prefix)
@ -106,7 +113,7 @@ func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingPro
if timeLimiter.Allow() {
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection)
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName)
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName, modelSubType)
} else {
err = timeLimiter.Wait(context.Background())
if err != nil {
@ -114,7 +121,7 @@ func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingPro
}
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection)
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName)
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName, modelSubType)
}
}
}

44
object/vector_test.go Normal file
View File

@ -0,0 +1,44 @@
// Copyright 2023 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package object
import (
"testing"
"github.com/casbin/casibase/model"
)
// TestUpdateVectors fills in the Size field of stored vectors that have text
// but a zero size. It loads all vectors with GetGlobalVectors, computes the
// token size of each qualifying vector's Text with model.GetTokenSize (using
// the fixed model name "text-davinci-003"), and writes the vector back with
// UpdateVector.
//
// NOTE(review): this looks like a one-off backfill for vectors created before
// the Size field existed, and it writes to whatever store InitConfig sets up —
// confirm before running it against real data.
//
// Failures are reported through t.Fatalf rather than panic so that the test
// is marked failed with context instead of aborting the whole test binary.
func TestUpdateVectors(t *testing.T) {
	InitConfig()

	vectors, err := GetGlobalVectors()
	if err != nil {
		t.Fatalf("getting global vectors: %v", err)
	}

	for _, vector := range vectors {
		// Skip vectors with nothing to measure or with a size already recorded.
		if vector.Text == "" || vector.Size != 0 {
			continue
		}

		size, err := model.GetTokenSize("text-davinci-003", vector.Text)
		if err != nil {
			t.Fatalf("computing token size for vector %v: %v", vector.GetId(), err)
		}
		// Assign only after a successful count so a failed call never leaves a
		// bogus size on the in-memory vector.
		vector.Size = size

		if _, err = UpdateVector(vector.GetId(), vector); err != nil {
			t.Fatalf("updating vector %v: %v", vector.GetId(), err)
		}
	}
}

View File

@ -132,6 +132,16 @@ class VectorEditPage extends React.Component {
}} />
</Col>
</Row>
<Row style={{marginTop: "20px"}} >
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
{i18next.t("vector:Size")}:
</Col>
<Col span={22} >
<InputNumber disabled={true} value={this.state.vector.size} onChange={value => {
this.updateVectorField("size", value);
}} />
</Col>
</Row>
<Row style={{marginTop: "20px"}} >
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
{i18next.t("vector:Dimension")}:

View File

@ -176,6 +176,13 @@ class VectorListPage extends React.Component {
);
},
},
{
title: i18next.t("vector:Size"),
dataIndex: "size",
key: "size",
width: "80px",
sorter: (a, b) => a.size - b.size,
},
{
title: i18next.t("vector:Data"),
dataIndex: "data",