Add getModelProviderFromContext()

This commit is contained in:
Yang Luo 2023-08-15 16:39:16 +08:00
parent 8c79294a7b
commit bf22093cde
5 changed files with 37 additions and 13 deletions

View File

@ -78,6 +78,37 @@ func (c *ApiController) ResponseErrorStream(errorText string) {
}
}
func getModelProviderFromContext(owner string, name string) (*object.Provider, error) {
var providerName string
if name != "" {
providerName = name
} else {
store, err := object.GetDefaultStore(owner)
if err != nil {
return nil, err
}
if store.ModelProvider != "" {
providerName = store.ModelProvider
}
}
var provider *object.Provider
var err error
if providerName != "" {
providerId := util.GetIdFromOwnerAndName(owner, providerName)
provider, err = object.GetProvider(providerId)
} else {
provider, err = object.GetDefaultModelProvider()
}
if provider == nil && err == nil {
return nil, fmt.Errorf("The provider: %s is not found", providerName)
} else {
return provider, err
}
}
func (c *ApiController) GetMessageAnswer() {
id := c.Input().Get("id")
@ -124,20 +155,13 @@ func (c *ApiController) GetMessageAnswer() {
return
}
providerId := util.GetIdFromOwnerAndName(chat.Owner, chat.User2)
provider, err := object.GetProvider(providerId)
provider, err := getModelProviderFromContext(chat.Owner, chat.User2)
if err != nil {
c.ResponseError(err.Error())
return
}
if provider == nil {
c.ResponseErrorStream(fmt.Sprintf("The provider: %s is not found", providerId))
return
}
if provider.Category != "Model" || provider.ClientSecret == "" {
c.ResponseErrorStream(fmt.Sprintf("The provider: %s is invalid", providerId))
c.ResponseErrorStream(fmt.Sprintf("The provider: %s is invalid", provider.GetId()))
return
}

View File

@ -97,7 +97,7 @@ func UpdateChat(id string, chat *Chat) (bool, error) {
func AddChat(chat *Chat) (bool, error) {
if chat.Type == "AI" && chat.User2 == "" {
provider, err := getDefaultModelProvider()
provider, err := GetDefaultModelProvider()
if err != nil {
return false, err
}

View File

@ -100,7 +100,7 @@ func GetProvider(id string) (*Provider, error) {
return getProvider(owner, name)
}
func getDefaultModelProvider() (*Provider, error) {
func GetDefaultModelProvider() (*Provider, error) {
provider := Provider{Owner: "admin", Category: "Model"}
existed, err := adapter.engine.Get(&provider)
if err != nil {

View File

@ -151,7 +151,7 @@ func (store *Store) GetId() string {
}
func RefreshStoreVectors(store *Store) (bool, error) {
provider, err := getDefaultModelProvider()
provider, err := GetDefaultModelProvider()
if err != nil {
return false, err
}

View File

@ -55,7 +55,7 @@ class ChatListPage extends React.Component {
updatedTime: moment().format(),
displayName: `New Chat - ${randomName}`,
category: "Chat Category - 1",
type: "Single",
type: "AI",
user1: `${this.props.account.owner}/${this.props.account.name}`,
user2: "",
users: [`${this.props.account.owner}/${this.props.account.name}`],