feat: support Azure model provider (#668)

* feat: support Azure model provider

* fix: convert apiVersion to string
This commit is contained in:
Kelvin Chiu 2023-10-08 15:57:52 +08:00 committed by GitHub
parent 0cdc87732e
commit e247f85dda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 188 additions and 54 deletions

55
model/azure_openai.go Normal file
View File

@ -0,0 +1,55 @@
// Copyright 2023 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package model
import (
"github.com/casbin/casibase/proxy"
"github.com/sashabaranov/go-openai"
)
func NewAzureModelProvider(typ string, subType string, deploymentName string, secretKey string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32, providerUrl string, apiVersion string) (*LocalModelProvider, error) {
p := &LocalModelProvider{
typ: typ,
subType: subType,
deploymentName: deploymentName,
secretKey: secretKey,
temperature: temperature,
topP: topP,
frequencyPenalty: frequencyPenalty,
presencePenalty: presencePenalty,
providerUrl: providerUrl,
apiVersion: apiVersion,
}
return p, nil
}
func getAzureClientFromToken(subtype string, deploymentName string, authToken string, url string, apiVersion string) *openai.Client {
config := openai.DefaultAzureConfig(authToken, url)
config.HTTPClient = proxy.ProxyHttpClient
if apiVersion != "" {
config.APIVersion = apiVersion
}
if deploymentName != "" {
config.AzureModelMapperFunc = func(model string) string {
azureModelMapping := map[string]string{
subtype: deploymentName,
}
return azureModelMapping[model]
}
}
c := openai.NewClientWithConfig(config)
return c
}

View File

@ -25,17 +25,21 @@ import (
) )
type LocalModelProvider struct { type LocalModelProvider struct {
typ string
subType string subType string
deploymentName string
secretKey string secretKey string
temperature float32 temperature float32
topP float32 topP float32
frequencyPenalty float32 frequencyPenalty float32
presencePenalty float32 presencePenalty float32
providerUrl string providerUrl string
apiVersion string
} }
func NewLocalModelProvider(subType string, secretKey string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32, providerUrl string) (*LocalModelProvider, error) { func NewLocalModelProvider(typ string, subType string, secretKey string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32, providerUrl string) (*LocalModelProvider, error) {
p := &LocalModelProvider{ p := &LocalModelProvider{
typ: typ,
subType: subType, subType: subType,
secretKey: secretKey, secretKey: secretKey,
temperature: temperature, temperature: temperature,
@ -56,7 +60,12 @@ func getLocalClientFromUrl(authToken string, url string) *openai.Client {
} }
func (p *LocalModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { func (p *LocalModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
client := getLocalClientFromUrl(p.secretKey, p.providerUrl) var client *openai.Client
if p.typ == "Local" {
client = getLocalClientFromUrl(p.secretKey, p.providerUrl)
} else if p.typ == "Azure" {
client = getAzureClientFromToken(p.subType, p.deploymentName, p.secretKey, p.providerUrl, p.apiVersion)
}
ctx := context.Background() ctx := context.Background()
flusher, ok := writer.(http.Flusher) flusher, ok := writer.(http.Flusher)
@ -70,11 +79,16 @@ func (p *LocalModelProvider) QueryText(question string, writer io.Writer, builde
frequencyPenalty := p.frequencyPenalty frequencyPenalty := p.frequencyPenalty
presencePenalty := p.presencePenalty presencePenalty := p.presencePenalty
respStream, err := client.CreateCompletionStream( respStream, err := client.CreateChatCompletionStream(
ctx, ctx,
openai.CompletionRequest{ openai.ChatCompletionRequest{
Model: model, Model: model,
Prompt: question, Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: question,
},
},
Stream: true, Stream: true,
Temperature: temperature, Temperature: temperature,
TopP: topP, TopP: topP,
@ -97,7 +111,7 @@ func (p *LocalModelProvider) QueryText(question string, writer io.Writer, builde
return streamErr return streamErr
} }
data := completion.Choices[0].Text data := completion.Choices[0].Delta.Content
if isLeadingReturn && len(data) != 0 { if isLeadingReturn && len(data) != 0 {
if strings.Count(data, "\n") == len(data) { if strings.Count(data, "\n") == len(data) {
continue continue

View File

@ -23,7 +23,7 @@ type ModelProvider interface {
QueryText(question string, writer io.Writer, builder *strings.Builder) error QueryText(question string, writer io.Writer, builder *strings.Builder) error
} }
func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, topK int, frequencyPenalty float32, presencePenalty float32, providerUrl string) (ModelProvider, error) { func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, topK int, frequencyPenalty float32, presencePenalty float32, providerUrl string, apiVersion string) (ModelProvider, error) {
var p ModelProvider var p ModelProvider
var err error var err error
if typ == "OpenAI" { if typ == "OpenAI" {
@ -43,7 +43,9 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret
} else if typ == "MiniMax" { } else if typ == "MiniMax" {
p, err = NewMiniMaxModelProvider(subType, clientId, clientSecret, temperature) p, err = NewMiniMaxModelProvider(subType, clientId, clientSecret, temperature)
} else if typ == "Local" { } else if typ == "Local" {
p, err = NewLocalModelProvider(subType, clientSecret, temperature, topP, frequencyPenalty, presencePenalty, providerUrl) p, err = NewLocalModelProvider(typ, subType, clientSecret, temperature, topP, frequencyPenalty, presencePenalty, providerUrl)
} else if typ == "Azure" {
p, err = NewAzureModelProvider(typ, subType, clientId, clientSecret, temperature, topP, frequencyPenalty, presencePenalty, providerUrl, apiVersion)
} }
if err != nil { if err != nil {

View File

@ -36,6 +36,7 @@ type Provider struct {
ClientId string `xorm:"varchar(100)" json:"clientId"` ClientId string `xorm:"varchar(100)" json:"clientId"`
ClientSecret string `xorm:"varchar(2000)" json:"clientSecret"` ClientSecret string `xorm:"varchar(2000)" json:"clientSecret"`
ProviderUrl string `xorm:"varchar(200)" json:"providerUrl"` ProviderUrl string `xorm:"varchar(200)" json:"providerUrl"`
ApiVersion string `xorm:"varchar(100)" json:"apiVersion"`
Temperature float32 `xorm:"float" json:"temperature"` Temperature float32 `xorm:"float" json:"temperature"`
TopP float32 `xorm:"float" json:"topP"` TopP float32 `xorm:"float" json:"topP"`
@ -211,7 +212,7 @@ func (p *Provider) GetStorageProviderObj() (storage.StorageProvider, error) {
} }
func (p *Provider) GetModelProvider() (model.ModelProvider, error) { func (p *Provider) GetModelProvider() (model.ModelProvider, error) {
pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty, p.ProviderUrl) pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty, p.ProviderUrl, p.ApiVersion)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -105,6 +105,14 @@ class ProviderEditPage extends React.Component {
); );
} }
handleTagChange = (key, value) => {
if (Array.isArray(value) && value.length > 0) {
this.updateProviderField(key, value[value.length - 1]);
} else {
this.updateProviderField(key, value);
}
};
renderProvider() { renderProvider() {
return ( return (
<Card size="small" title={ <Card size="small" title={
@ -186,6 +194,8 @@ class ProviderEditPage extends React.Component {
this.updateProviderField("subType", "chatglm2-6b"); this.updateProviderField("subType", "chatglm2-6b");
} else if (value === "Local") { } else if (value === "Local") {
this.updateProviderField("subType", "custom-model"); this.updateProviderField("subType", "custom-model");
} else if (value === "Azure") {
this.updateProviderField("subType", "gpt-4");
} }
} else if (this.state.provider.category === "Embedding") { } else if (this.state.provider.category === "Embedding") {
if (value === "OpenAI") { if (value === "OpenAI") {
@ -480,6 +490,35 @@ class ProviderEditPage extends React.Component {
</> </>
) : null ) : null
} }
{
((this.state.provider.category === "Model") && this.state.provider.type === "Azure") ? (
<>
<Row style={{marginTop: "20px"}}>
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
{i18next.t("provider:Deployment Name")}:
</Col>
<Col span={22} >
<Input value={this.state.provider.clientId} onChange={e => {
this.updateProviderField("clientId", e.target.value);
}} />
</Col>
</Row>
<Row style={{marginTop: "20px"}}>
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
{i18next.t("provider:API Version")}:
</Col>
<Col span={22} >
<Select virtual={false} mode="tags" style={{width: "100%"}}
value={this.state.provider.apiVersion}
onSelect={(value) => {this.handleTagChange("apiVersion", value);}}
onChange={(value) => {this.handleTagChange("apiVersion", value);}}
options={Setting.getProviderAzureApiVersionOptions().map((item) => Setting.getOption(item.name, item.id))}
/>
</Col>
</Row>
</>
) : null
}
<Row style={{marginTop: "20px"}} > <Row style={{marginTop: "20px"}} >
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
{i18next.t("general:Provider URL")}: {i18next.t("general:Provider URL")}:

View File

@ -64,6 +64,7 @@ class ProviderListPage extends React.Component {
frequencyPenalty: 0, frequencyPenalty: 0,
presencePenalty: 0, presencePenalty: 0,
providerUrl: "https://platform.openai.com/account/api-keys", providerUrl: "https://platform.openai.com/account/api-keys",
apiVersion: "",
}; };
} }

View File

@ -648,6 +648,7 @@ export function getProviderTypeOptions(category) {
{id: "ChatGLM", name: "ChatGLM"}, {id: "ChatGLM", name: "ChatGLM"},
{id: "MiniMax", name: "MiniMax"}, {id: "MiniMax", name: "MiniMax"},
{id: "Local", name: "Local"}, {id: "Local", name: "Local"},
{id: "Azure", name: "Azure"},
] ]
); );
} else if (category === "Embedding") { } else if (category === "Embedding") {
@ -665,57 +666,61 @@ export function getProviderTypeOptions(category) {
} }
} }
const openaiModels = [
{id: "gpt-4-32k-0613", name: "gpt-4-32k-0613"},
{id: "gpt-4-32k-0314", name: "gpt-4-32k-0314"},
{id: "gpt-4-32k", name: "gpt-4-32k"},
{id: "gpt-4-0613", name: "gpt-4-0613"},
{id: "gpt-4-0314", name: "gpt-4-0314"},
{id: "gpt-4", name: "gpt-4"},
{id: "gpt-3.5-turbo-0613", name: "gpt-3.5-turbo-0613"},
{id: "gpt-3.5-turbo-0301", name: "gpt-3.5-turbo-0301"},
{id: "gpt-3.5-turbo-16k", name: "gpt-3.5-turbo-16k"},
{id: "gpt-3.5-turbo-16k-0613", name: "gpt-3.5-turbo-16k-0613"},
{id: "gpt-3.5-turbo", name: "gpt-3.5-turbo"},
{id: "text-davinci-003", name: "text-davinci-003"},
{id: "text-davinci-002", name: "text-davinci-002"},
{id: "text-curie-001", name: "text-curie-001"},
{id: "text-babbage-001", name: "text-babbage-001"},
{id: "text-ada-001", name: "text-ada-001"},
{id: "text-davinci-001", name: "text-davinci-001"},
{id: "davinci-instruct-beta", name: "davinci-instruct-beta"},
{id: "davinci", name: "davinci"},
{id: "curie-instruct-beta", name: "curie-instruct-beta"},
{id: "curie", name: "curie"},
{id: "ada", name: "ada"},
{id: "babbage", name: "babbage"},
];
const openaiEmbeddings = [
{id: "1", name: "AdaSimilarity"},
{id: "2", name: "BabbageSimilarity"},
{id: "3", name: "CurieSimilarity"},
{id: "4", name: "DavinciSimilarity"},
{id: "5", name: "AdaSearchDocument"},
{id: "6", name: "AdaSearchQuery"},
{id: "7", name: "BabbageSearchDocument"},
{id: "8", name: "BabbageSearchQuery"},
{id: "9", name: "CurieSearchDocument"},
{id: "10", name: "CurieSearchQuery"},
{id: "11", name: "DavinciSearchDocument"},
{id: "12", name: "DavinciSearchQuery"},
{id: "13", name: "AdaCodeSearchCode"},
{id: "14", name: "AdaCodeSearchText"},
{id: "15", name: "BabbageCodeSearchCode"},
{id: "16", name: "BabbageCodeSearchText"},
{id: "17", name: "AdaEmbeddingV2"},
];
export function getProviderSubTypeOptions(category, type) { export function getProviderSubTypeOptions(category, type) {
if (type === "OpenAI") { if (type === "OpenAI") {
if (category === "Model") { if (category === "Model") {
return ( return (
[ openaiModels
{id: "gpt-4-32k-0613", name: "gpt-4-32k-0613"},
{id: "gpt-4-32k-0314", name: "gpt-4-32k-0314"},
{id: "gpt-4-32k", name: "gpt-4-32k"},
{id: "gpt-4-0613", name: "gpt-4-0613"},
{id: "gpt-4-0314", name: "gpt-4-0314"},
{id: "gpt-4", name: "gpt-4"},
{id: "gpt-3.5-turbo-0613", name: "gpt-3.5-turbo-0613"},
{id: "gpt-3.5-turbo-0301", name: "gpt-3.5-turbo-0301"},
{id: "gpt-3.5-turbo-16k", name: "gpt-3.5-turbo-16k"},
{id: "gpt-3.5-turbo-16k-0613", name: "gpt-3.5-turbo-16k-0613"},
{id: "gpt-3.5-turbo", name: "gpt-3.5-turbo"},
{id: "text-davinci-003", name: "text-davinci-003"},
{id: "text-davinci-002", name: "text-davinci-002"},
{id: "text-curie-001", name: "text-curie-001"},
{id: "text-babbage-001", name: "text-babbage-001"},
{id: "text-ada-001", name: "text-ada-001"},
{id: "text-davinci-001", name: "text-davinci-001"},
{id: "davinci-instruct-beta", name: "davinci-instruct-beta"},
{id: "davinci", name: "davinci"},
{id: "curie-instruct-beta", name: "curie-instruct-beta"},
{id: "curie", name: "curie"},
{id: "ada", name: "ada"},
{id: "babbage", name: "babbage"},
]
); );
} else if (category === "Embedding") { } else if (category === "Embedding") {
return ( return (
[ openaiEmbeddings
{id: "1", name: "AdaSimilarity"},
{id: "2", name: "BabbageSimilarity"},
{id: "3", name: "CurieSimilarity"},
{id: "4", name: "DavinciSimilarity"},
{id: "5", name: "AdaSearchDocument"},
{id: "6", name: "AdaSearchQuery"},
{id: "7", name: "BabbageSearchDocument"},
{id: "8", name: "BabbageSearchQuery"},
{id: "9", name: "CurieSearchDocument"},
{id: "10", name: "CurieSearchQuery"},
{id: "11", name: "DavinciSearchDocument"},
{id: "12", name: "DavinciSearchQuery"},
{id: "13", name: "AdaCodeSearchCode"},
{id: "14", name: "AdaCodeSearchText"},
{id: "15", name: "BabbageCodeSearchCode"},
{id: "16", name: "BabbageCodeSearchText"},
{id: "17", name: "AdaEmbeddingV2"},
]
); );
} else { } else {
return []; return [];
@ -845,7 +850,24 @@ export function getProviderSubTypeOptions(category, type) {
} else { } else {
return []; return [];
} }
} else if (type === "Azure") {
if (category === "Model") {
return (
openaiModels
);
}
} else { } else {
return []; return [];
} }
} }
export function getProviderAzureApiVersionOptions() {
return ([
{id: "", name: ""},
{id: "2023-03-15-preview", name: "2023-03-15-preview"},
{id: "2023-05-15", name: "2023-05-15"},
{id: "2023-06-01-preview", name: "2023-06-01-preview"},
{id: "2023-07-01-preview", name: "2023-07-01-preview"},
{id: "2023-08-01-preview", name: "2023-08-01-preview"},
]);
}