Add getModelProviderFromContext()
This commit is contained in:
parent
8c79294a7b
commit
bf22093cde
|
@ -78,6 +78,37 @@ func (c *ApiController) ResponseErrorStream(errorText string) {
|
|||
}
|
||||
}
|
||||
|
||||
func getModelProviderFromContext(owner string, name string) (*object.Provider, error) {
|
||||
var providerName string
|
||||
if name != "" {
|
||||
providerName = name
|
||||
} else {
|
||||
store, err := object.GetDefaultStore(owner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if store.ModelProvider != "" {
|
||||
providerName = store.ModelProvider
|
||||
}
|
||||
}
|
||||
|
||||
var provider *object.Provider
|
||||
var err error
|
||||
if providerName != "" {
|
||||
providerId := util.GetIdFromOwnerAndName(owner, providerName)
|
||||
provider, err = object.GetProvider(providerId)
|
||||
} else {
|
||||
provider, err = object.GetDefaultModelProvider()
|
||||
}
|
||||
|
||||
if provider == nil && err == nil {
|
||||
return nil, fmt.Errorf("The provider: %s is not found", providerName)
|
||||
} else {
|
||||
return provider, err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ApiController) GetMessageAnswer() {
|
||||
id := c.Input().Get("id")
|
||||
|
||||
|
@ -124,20 +155,13 @@ func (c *ApiController) GetMessageAnswer() {
|
|||
return
|
||||
}
|
||||
|
||||
providerId := util.GetIdFromOwnerAndName(chat.Owner, chat.User2)
|
||||
provider, err := object.GetProvider(providerId)
|
||||
provider, err := getModelProviderFromContext(chat.Owner, chat.User2)
|
||||
if err != nil {
|
||||
c.ResponseError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if provider == nil {
|
||||
c.ResponseErrorStream(fmt.Sprintf("The provider: %s is not found", providerId))
|
||||
return
|
||||
}
|
||||
|
||||
if provider.Category != "Model" || provider.ClientSecret == "" {
|
||||
c.ResponseErrorStream(fmt.Sprintf("The provider: %s is invalid", providerId))
|
||||
c.ResponseErrorStream(fmt.Sprintf("The provider: %s is invalid", provider.GetId()))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -97,7 +97,7 @@ func UpdateChat(id string, chat *Chat) (bool, error) {
|
|||
|
||||
func AddChat(chat *Chat) (bool, error) {
|
||||
if chat.Type == "AI" && chat.User2 == "" {
|
||||
provider, err := getDefaultModelProvider()
|
||||
provider, err := GetDefaultModelProvider()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
|
|
@ -100,7 +100,7 @@ func GetProvider(id string) (*Provider, error) {
|
|||
return getProvider(owner, name)
|
||||
}
|
||||
|
||||
func getDefaultModelProvider() (*Provider, error) {
|
||||
func GetDefaultModelProvider() (*Provider, error) {
|
||||
provider := Provider{Owner: "admin", Category: "Model"}
|
||||
existed, err := adapter.engine.Get(&provider)
|
||||
if err != nil {
|
||||
|
|
|
@ -151,7 +151,7 @@ func (store *Store) GetId() string {
|
|||
}
|
||||
|
||||
func RefreshStoreVectors(store *Store) (bool, error) {
|
||||
provider, err := getDefaultModelProvider()
|
||||
provider, err := GetDefaultModelProvider()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
|
|
@ -55,7 +55,7 @@ class ChatListPage extends React.Component {
|
|||
updatedTime: moment().format(),
|
||||
displayName: `New Chat - ${randomName}`,
|
||||
category: "Chat Category - 1",
|
||||
type: "Single",
|
||||
type: "AI",
|
||||
user1: `${this.props.account.owner}/${this.props.account.name}`,
|
||||
user2: "",
|
||||
users: [`${this.props.account.owner}/${this.props.account.name}`],
|
||||
|
|
Loading…
Reference in New Issue