Add size to vector
This commit is contained in:
parent
4426466d4f
commit
abae92b963
|
@ -172,6 +172,15 @@ func (store *Store) GetStorageProviderObj() (storage.StorageProvider, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (store *Store) GetModelProvider() (*Provider, error) {
|
||||
if store.ModelProvider == "" {
|
||||
return GetDefaultModelProvider()
|
||||
}
|
||||
|
||||
providerId := util.GetIdFromOwnerAndName(store.Owner, store.ModelProvider)
|
||||
return GetProvider(providerId)
|
||||
}
|
||||
|
||||
func (store *Store) GetEmbeddingProvider() (*Provider, error) {
|
||||
if store.EmbeddingProvider == "" {
|
||||
return GetDefaultEmbeddingProvider()
|
||||
|
@ -187,6 +196,11 @@ func RefreshStoreVectors(store *Store) (bool, error) {
|
|||
return false, err
|
||||
}
|
||||
|
||||
modelProvider, err := store.GetModelProvider()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
embeddingProvider, err := store.GetEmbeddingProvider()
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
@ -197,6 +211,6 @@ func RefreshStoreVectors(store *Store) (bool, error) {
|
|||
return false, err
|
||||
}
|
||||
|
||||
ok, err := addVectorsForStore(storageProviderObj, embeddingProviderObj, "", store.Name, embeddingProvider.Name)
|
||||
ok, err := addVectorsForStore(storageProviderObj, embeddingProviderObj, "", store.Name, embeddingProvider.Name, modelProvider.SubType)
|
||||
return ok, err
|
||||
}
|
||||
|
|
|
@ -32,6 +32,7 @@ type Vector struct {
|
|||
File string `xorm:"varchar(100)" json:"file"`
|
||||
Index int `json:"index"`
|
||||
Text string `xorm:"mediumtext" json:"text"`
|
||||
Size int `json:"size"`
|
||||
Score float32 `json:"score"`
|
||||
|
||||
Data []float32 `xorm:"mediumtext" json:"data"`
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/casbin/casibase/embedding"
|
||||
"github.com/casbin/casibase/model"
|
||||
"github.com/casbin/casibase/storage"
|
||||
"github.com/casbin/casibase/txt"
|
||||
"github.com/casbin/casibase/util"
|
||||
|
@ -45,7 +46,7 @@ func filterTextFiles(files []*storage.Object) []*storage.Object {
|
|||
return res
|
||||
}
|
||||
|
||||
func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string, index int, embeddingProviderName string) (bool, error) {
|
||||
func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string, index int, embeddingProviderName string, modelSubType string) (bool, error) {
|
||||
data, err := queryVectorSafe(embeddingProviderObj, text)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
@ -56,6 +57,11 @@ func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text st
|
|||
displayName = text[:25]
|
||||
}
|
||||
|
||||
size, err := model.GetTokenSize(modelSubType, text)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
vector := &Vector{
|
||||
Owner: "admin",
|
||||
Name: fmt.Sprintf("vector_%s", util.GetRandomName()),
|
||||
|
@ -66,13 +72,14 @@ func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text st
|
|||
File: fileName,
|
||||
Index: index,
|
||||
Text: text,
|
||||
Size: size,
|
||||
Data: data,
|
||||
Dimension: len(data),
|
||||
}
|
||||
return AddVector(vector)
|
||||
}
|
||||
|
||||
func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingProviderObj embedding.EmbeddingProvider, prefix string, storeName string, embeddingProviderName string) (bool, error) {
|
||||
func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingProviderObj embedding.EmbeddingProvider, prefix string, storeName string, embeddingProviderName string, modelSubType string) (bool, error) {
|
||||
var affected bool
|
||||
|
||||
files, err := storageProviderObj.ListObjects(prefix)
|
||||
|
@ -106,7 +113,7 @@ func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingPro
|
|||
|
||||
if timeLimiter.Allow() {
|
||||
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection)
|
||||
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName)
|
||||
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName, modelSubType)
|
||||
} else {
|
||||
err = timeLimiter.Wait(context.Background())
|
||||
if err != nil {
|
||||
|
@ -114,7 +121,7 @@ func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingPro
|
|||
}
|
||||
|
||||
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection)
|
||||
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName)
|
||||
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName, modelSubType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright 2023 The casbin Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package object
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/casbin/casibase/model"
|
||||
)
|
||||
|
||||
func TestUpdateVectors(t *testing.T) {
|
||||
InitConfig()
|
||||
|
||||
vectors, err := GetGlobalVectors()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for _, vector := range vectors {
|
||||
if vector.Text != "" && vector.Size == 0 {
|
||||
vector.Size, err = model.GetTokenSize("text-davinci-003", vector.Text)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
_, err = UpdateVector(vector.GetId(), vector)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -132,6 +132,16 @@ class VectorEditPage extends React.Component {
|
|||
}} />
|
||||
</Col>
|
||||
</Row>
|
||||
<Row style={{marginTop: "20px"}} >
|
||||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
|
||||
{i18next.t("vector:Size")}:
|
||||
</Col>
|
||||
<Col span={22} >
|
||||
<InputNumber disabled={true} value={this.state.vector.size} onChange={value => {
|
||||
this.updateVectorField("size", value);
|
||||
}} />
|
||||
</Col>
|
||||
</Row>
|
||||
<Row style={{marginTop: "20px"}} >
|
||||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
|
||||
{i18next.t("vector:Dimension")}:
|
||||
|
|
|
@ -176,6 +176,13 @@ class VectorListPage extends React.Component {
|
|||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: i18next.t("vector:Size"),
|
||||
dataIndex: "size",
|
||||
key: "size",
|
||||
width: "80px",
|
||||
sorter: (a, b) => a.size - b.size,
|
||||
},
|
||||
{
|
||||
title: i18next.t("vector:Data"),
|
||||
dataIndex: "data",
|
||||
|
|
Loading…
Reference in New Issue