Check vector's embedding provider
This commit is contained in:
parent
e6ff917c40
commit
b7acf39c17
|
@ -114,13 +114,13 @@ func (c *ApiController) GetMessageAnswer() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
modelProviderObj, err := getModelProviderFromContext(chat.Owner, chat.User2)
|
_, modelProviderObj, err := getModelProviderFromContext(chat.Owner, chat.User2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.ResponseErrorStream(err.Error())
|
c.ResponseErrorStream(err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
embeddingProviderObj, err := getEmbeddingProviderFromContext(chat.Owner, chat.User2)
|
embeddingProvider, embeddingProviderObj, err := getEmbeddingProviderFromContext(chat.Owner, chat.User2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.ResponseErrorStream(err.Error())
|
c.ResponseErrorStream(err.Error())
|
||||||
return
|
return
|
||||||
|
@ -132,7 +132,7 @@ func (c *ApiController) GetMessageAnswer() {
|
||||||
|
|
||||||
question := questionMessage.Text
|
question := questionMessage.Text
|
||||||
|
|
||||||
knowledge, vectorScores, err := object.GetNearestKnowledge(embeddingProviderObj, chat.Owner, question)
|
knowledge, vectorScores, err := object.GetNearestKnowledge(embeddingProvider, embeddingProviderObj, chat.Owner, question)
|
||||||
if err != nil && err.Error() != "no knowledge vectors found" {
|
if err != nil && err.Error() != "no knowledge vectors found" {
|
||||||
c.ResponseErrorStream(err.Error())
|
c.ResponseErrorStream(err.Error())
|
||||||
return
|
return
|
||||||
|
|
|
@ -32,14 +32,14 @@ func (c *ApiController) ResponseErrorStream(errorText string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getModelProviderFromContext(owner string, name string) (model.ModelProvider, error) {
|
func getModelProviderFromContext(owner string, name string) (*object.Provider, model.ModelProvider, error) {
|
||||||
var providerName string
|
var providerName string
|
||||||
if name != "" {
|
if name != "" {
|
||||||
providerName = name
|
providerName = name
|
||||||
} else {
|
} else {
|
||||||
store, err := object.GetDefaultStore(owner)
|
store, err := object.GetDefaultStore(owner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if store != nil && store.ModelProvider != "" {
|
if store != nil && store.ModelProvider != "" {
|
||||||
|
@ -57,28 +57,28 @@ func getModelProviderFromContext(owner string, name string) (model.ModelProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
if provider == nil && err == nil {
|
if provider == nil && err == nil {
|
||||||
return nil, fmt.Errorf("The model provider: %s is not found", providerName)
|
return nil, nil, fmt.Errorf("The model provider: %s is not found", providerName)
|
||||||
}
|
}
|
||||||
if provider.Category != "Model" || provider.ClientSecret == "" {
|
if provider.Category != "Model" || provider.ClientSecret == "" {
|
||||||
return nil, fmt.Errorf("The model provider: %s is invalid", providerName)
|
return nil, nil, fmt.Errorf("The model provider: %s is invalid", providerName)
|
||||||
}
|
}
|
||||||
|
|
||||||
providerObj, err := provider.GetModelProvider()
|
providerObj, err := provider.GetModelProvider()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return providerObj, err
|
return provider, providerObj, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getEmbeddingProviderFromContext(owner string, name string) (embedding.EmbeddingProvider, error) {
|
func getEmbeddingProviderFromContext(owner string, name string) (*object.Provider, embedding.EmbeddingProvider, error) {
|
||||||
var providerName string
|
var providerName string
|
||||||
if name != "" {
|
if name != "" {
|
||||||
providerName = name
|
providerName = name
|
||||||
} else {
|
} else {
|
||||||
store, err := object.GetDefaultStore(owner)
|
store, err := object.GetDefaultStore(owner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if store != nil && store.EmbeddingProvider != "" {
|
if store != nil && store.EmbeddingProvider != "" {
|
||||||
|
@ -96,16 +96,16 @@ func getEmbeddingProviderFromContext(owner string, name string) (embedding.Embed
|
||||||
}
|
}
|
||||||
|
|
||||||
if provider == nil && err == nil {
|
if provider == nil && err == nil {
|
||||||
return nil, fmt.Errorf("The embedding provider: %s is not found", providerName)
|
return nil, nil, fmt.Errorf("The embedding provider: %s is not found", providerName)
|
||||||
}
|
}
|
||||||
if provider.Category != "Embedding" || provider.ClientSecret == "" {
|
if provider.Category != "Embedding" || provider.ClientSecret == "" {
|
||||||
return nil, fmt.Errorf("The embedding provider: %s is invalid", providerName)
|
return nil, nil, fmt.Errorf("The embedding provider: %s is invalid", providerName)
|
||||||
}
|
}
|
||||||
|
|
||||||
providerObj, err := provider.GetEmbeddingProvider()
|
providerObj, err := provider.GetEmbeddingProvider()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return providerObj, err
|
return provider, providerObj, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -161,8 +161,8 @@ func queryVectorSafe(embeddingProvider embedding.EmbeddingProvider, text string)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetNearestKnowledge(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, []VectorScore, error) {
|
func GetNearestKnowledge(embeddingProvider *Provider, embeddingProviderObj embedding.EmbeddingProvider, owner string, text string) (string, []VectorScore, error) {
|
||||||
qVector, err := queryVectorSafe(embeddingProvider, text)
|
qVector, err := queryVectorSafe(embeddingProviderObj, text)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
@ -183,6 +183,10 @@ func GetNearestKnowledge(embeddingProvider embedding.EmbeddingProvider, owner st
|
||||||
vectorScores := []VectorScore{}
|
vectorScores := []VectorScore{}
|
||||||
texts := []string{}
|
texts := []string{}
|
||||||
for _, vector := range vectors {
|
for _, vector := range vectors {
|
||||||
|
if embeddingProvider.Name != vector.Provider {
|
||||||
|
return "", nil, fmt.Errorf("The store's embedding provider: [%s] should equal to vector's embedding provider: [%s], vector = %v", embeddingProvider.Name, vector.Provider, vector)
|
||||||
|
}
|
||||||
|
|
||||||
vectorScores = append(vectorScores, VectorScore{
|
vectorScores = append(vectorScores, VectorScore{
|
||||||
Vector: vector.Name,
|
Vector: vector.Name,
|
||||||
Score: vector.Score,
|
Score: vector.Score,
|
||||||
|
|
Loading…
Reference in New Issue