Add size to vector
This commit is contained in:
parent
4426466d4f
commit
abae92b963
|
@ -172,6 +172,15 @@ func (store *Store) GetStorageProviderObj() (storage.StorageProvider, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetModelProvider returns the model provider configured for this store.
// When store.ModelProvider is empty it returns the result of
// GetDefaultModelProvider(); otherwise it builds a provider id from
// store.Owner and store.ModelProvider with util.GetIdFromOwnerAndName and
// returns the result of GetProvider for that id. The error from whichever
// lookup is used is passed back unchanged.
func (store *Store) GetModelProvider() (*Provider, error) {
|
||||||
|
if store.ModelProvider == "" {
|
||||||
|
return GetDefaultModelProvider()
|
||||||
|
}
|
||||||
|
|
||||||
|
providerId := util.GetIdFromOwnerAndName(store.Owner, store.ModelProvider)
|
||||||
|
return GetProvider(providerId)
|
||||||
|
}
|
||||||
|
|
||||||
func (store *Store) GetEmbeddingProvider() (*Provider, error) {
|
func (store *Store) GetEmbeddingProvider() (*Provider, error) {
|
||||||
if store.EmbeddingProvider == "" {
|
if store.EmbeddingProvider == "" {
|
||||||
return GetDefaultEmbeddingProvider()
|
return GetDefaultEmbeddingProvider()
|
||||||
|
@ -187,6 +196,11 @@ func RefreshStoreVectors(store *Store) (bool, error) {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
modelProvider, err := store.GetModelProvider()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
embeddingProvider, err := store.GetEmbeddingProvider()
|
embeddingProvider, err := store.GetEmbeddingProvider()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
@ -197,6 +211,6 @@ func RefreshStoreVectors(store *Store) (bool, error) {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err := addVectorsForStore(storageProviderObj, embeddingProviderObj, "", store.Name, embeddingProvider.Name)
|
ok, err := addVectorsForStore(storageProviderObj, embeddingProviderObj, "", store.Name, embeddingProvider.Name, modelProvider.SubType)
|
||||||
return ok, err
|
return ok, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,6 +32,7 @@ type Vector struct {
|
||||||
File string `xorm:"varchar(100)" json:"file"`
|
File string `xorm:"varchar(100)" json:"file"`
|
||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
Text string `xorm:"mediumtext" json:"text"`
|
Text string `xorm:"mediumtext" json:"text"`
|
||||||
|
Size int `json:"size"`
|
||||||
Score float32 `json:"score"`
|
Score float32 `json:"score"`
|
||||||
|
|
||||||
Data []float32 `xorm:"mediumtext" json:"data"`
|
Data []float32 `xorm:"mediumtext" json:"data"`
|
||||||
|
|
|
@ -22,6 +22,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/casbin/casibase/embedding"
|
"github.com/casbin/casibase/embedding"
|
||||||
|
"github.com/casbin/casibase/model"
|
||||||
"github.com/casbin/casibase/storage"
|
"github.com/casbin/casibase/storage"
|
||||||
"github.com/casbin/casibase/txt"
|
"github.com/casbin/casibase/txt"
|
||||||
"github.com/casbin/casibase/util"
|
"github.com/casbin/casibase/util"
|
||||||
|
@ -45,7 +46,7 @@ func filterTextFiles(files []*storage.Object) []*storage.Object {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string, index int, embeddingProviderName string) (bool, error) {
|
func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string, index int, embeddingProviderName string, modelSubType string) (bool, error) {
|
||||||
data, err := queryVectorSafe(embeddingProviderObj, text)
|
data, err := queryVectorSafe(embeddingProviderObj, text)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
@ -56,6 +57,11 @@ func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text st
|
||||||
displayName = text[:25]
|
displayName = text[:25]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size, err := model.GetTokenSize(modelSubType, text)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
vector := &Vector{
|
vector := &Vector{
|
||||||
Owner: "admin",
|
Owner: "admin",
|
||||||
Name: fmt.Sprintf("vector_%s", util.GetRandomName()),
|
Name: fmt.Sprintf("vector_%s", util.GetRandomName()),
|
||||||
|
@ -66,13 +72,14 @@ func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text st
|
||||||
File: fileName,
|
File: fileName,
|
||||||
Index: index,
|
Index: index,
|
||||||
Text: text,
|
Text: text,
|
||||||
|
Size: size,
|
||||||
Data: data,
|
Data: data,
|
||||||
Dimension: len(data),
|
Dimension: len(data),
|
||||||
}
|
}
|
||||||
return AddVector(vector)
|
return AddVector(vector)
|
||||||
}
|
}
|
||||||
|
|
||||||
func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingProviderObj embedding.EmbeddingProvider, prefix string, storeName string, embeddingProviderName string) (bool, error) {
|
func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingProviderObj embedding.EmbeddingProvider, prefix string, storeName string, embeddingProviderName string, modelSubType string) (bool, error) {
|
||||||
var affected bool
|
var affected bool
|
||||||
|
|
||||||
files, err := storageProviderObj.ListObjects(prefix)
|
files, err := storageProviderObj.ListObjects(prefix)
|
||||||
|
@ -106,7 +113,7 @@ func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingPro
|
||||||
|
|
||||||
if timeLimiter.Allow() {
|
if timeLimiter.Allow() {
|
||||||
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection)
|
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection)
|
||||||
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName)
|
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName, modelSubType)
|
||||||
} else {
|
} else {
|
||||||
err = timeLimiter.Wait(context.Background())
|
err = timeLimiter.Wait(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -114,7 +121,7 @@ func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingPro
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection)
|
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection)
|
||||||
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName)
|
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName, modelSubType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
// Copyright 2023 The casbin Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package object
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/casbin/casibase/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestUpdateVectors fills in the Size field of vectors that have text but a
// zero size. It calls InitConfig, loads vectors with GetGlobalVectors, and for
// every vector whose Text is non-empty and whose Size is 0 it sets Size from
// model.GetTokenSize using the hard-coded model name "text-davinci-003", then
// saves the vector with UpdateVector. It makes no assertions; any error
// panics.
// NOTE(review): presumably a one-off backfill that needs a configured
// database - confirm before running it as part of a normal test suite.
// NOTE(review): panic(err) is used instead of t.Fatal, so the t parameter is
// never used - confirm whether that is intentional.
func TestUpdateVectors(t *testing.T) {
|
||||||
|
InitConfig()
|
||||||
|
|
||||||
|
vectors, err := GetGlobalVectors()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, vector := range vectors {
|
||||||
|
if vector.Text != "" && vector.Size == 0 { // only vectors with text and no size yet
|
||||||
|
vector.Size, err = model.GetTokenSize("text-davinci-003", vector.Text)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = UpdateVector(vector.GetId(), vector)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -132,6 +132,16 @@ class VectorEditPage extends React.Component {
|
||||||
}} />
|
}} />
|
||||||
</Col>
|
</Col>
|
||||||
</Row>
|
</Row>
|
||||||
|
<Row style={{marginTop: "20px"}} >
|
||||||
|
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
|
||||||
|
{i18next.t("vector:Size")}:
|
||||||
|
</Col>
|
||||||
|
<Col span={22} >
|
||||||
|
<InputNumber disabled={true} value={this.state.vector.size} onChange={value => {
|
||||||
|
this.updateVectorField("size", value);
|
||||||
|
}} />
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
<Row style={{marginTop: "20px"}} >
|
<Row style={{marginTop: "20px"}} >
|
||||||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
|
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
|
||||||
{i18next.t("vector:Dimension")}:
|
{i18next.t("vector:Dimension")}:
|
||||||
|
|
|
@ -176,6 +176,13 @@ class VectorListPage extends React.Component {
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
title: i18next.t("vector:Size"),
|
||||||
|
dataIndex: "size",
|
||||||
|
key: "size",
|
||||||
|
width: "80px",
|
||||||
|
sorter: (a, b) => a.size - b.size,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
title: i18next.t("vector:Data"),
|
title: i18next.t("vector:Data"),
|
||||||
dataIndex: "data",
|
dataIndex: "data",
|
||||||
|
|
Loading…
Reference in New Issue