feat: support Ernie model parameters (#647)
This commit is contained in:
parent
de2c9cb679
commit
65e5fdecca
|
@ -26,13 +26,23 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type ErnieModelProvider struct {
|
type ErnieModelProvider struct {
|
||||||
subType string
|
subType string
|
||||||
apiKey string
|
apiKey string
|
||||||
secretKey string
|
secretKey string
|
||||||
|
temperature float32
|
||||||
|
topP float32
|
||||||
|
presencePenalty float32
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewErnieModelProvider(subType string, apiKey string, secretKey string) (*ErnieModelProvider, error) {
|
func NewErnieModelProvider(subType string, apiKey string, secretKey string, temperature float32, topP float32, presencePenalty float32) (*ErnieModelProvider, error) {
|
||||||
return &ErnieModelProvider{subType: subType, apiKey: apiKey, secretKey: secretKey}, nil
|
return &ErnieModelProvider{
|
||||||
|
subType: subType,
|
||||||
|
apiKey: apiKey,
|
||||||
|
secretKey: secretKey,
|
||||||
|
temperature: temperature,
|
||||||
|
topP: topP,
|
||||||
|
presencePenalty: presencePenalty,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
|
func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
|
||||||
|
@ -59,8 +69,18 @@ func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builde
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
temperature := p.temperature
|
||||||
|
topP := p.topP
|
||||||
|
presencePenalty := p.presencePenalty
|
||||||
|
|
||||||
if p.subType == "ERNIE-Bot" {
|
if p.subType == "ERNIE-Bot" {
|
||||||
stream, err := client.CreateErnieBotChatCompletionStream(ctx, ernie.ErnieBotRequest{Messages: messages})
|
stream, err := client.CreateErnieBotChatCompletionStream(ctx,
|
||||||
|
ernie.ErnieBotRequest{
|
||||||
|
Messages: messages,
|
||||||
|
Temperature: temperature,
|
||||||
|
TopP: topP,
|
||||||
|
PresencePenalty: presencePenalty,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -82,7 +102,13 @@ func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builde
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if p.subType == "ERNIE-Bot-turbo" {
|
} else if p.subType == "ERNIE-Bot-turbo" {
|
||||||
stream, err := client.CreateErnieBotTurboChatCompletionStream(ctx, ernie.ErnieBotTurboRequest{Messages: messages})
|
stream, err := client.CreateErnieBotTurboChatCompletionStream(ctx,
|
||||||
|
ernie.ErnieBotTurboRequest{
|
||||||
|
Messages: messages,
|
||||||
|
Temperature: temperature,
|
||||||
|
TopP: topP,
|
||||||
|
PresencePenalty: presencePenalty,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,7 +35,7 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret
|
||||||
} else if typ == "OpenRouter" {
|
} else if typ == "OpenRouter" {
|
||||||
p, err = NewOpenRouterModelProvider(subType, clientSecret)
|
p, err = NewOpenRouterModelProvider(subType, clientSecret)
|
||||||
} else if typ == "Ernie" {
|
} else if typ == "Ernie" {
|
||||||
p, err = NewErnieModelProvider(subType, clientId, clientSecret)
|
p, err = NewErnieModelProvider(subType, clientId, clientSecret, temperature, topP, presencePenalty)
|
||||||
} else if typ == "iFlytek" {
|
} else if typ == "iFlytek" {
|
||||||
p, err = NewiFlytekModelProvider(subType, clientSecret, temperature, topK)
|
p, err = NewiFlytekModelProvider(subType, clientSecret, temperature, topK)
|
||||||
} else if typ == "ChatGLM" {
|
} else if typ == "ChatGLM" {
|
||||||
|
|
|
@ -315,6 +315,58 @@ class ProviderEditPage extends React.Component {
|
||||||
</>
|
</>
|
||||||
) : null
|
) : null
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
(this.state.provider.category === "Model" && this.state.provider.type === "Ernie") ? (
|
||||||
|
<>
|
||||||
|
<Row style={{marginTop: "20px"}}>
|
||||||
|
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
|
||||||
|
{i18next.t("provider:Temperature")}:
|
||||||
|
</Col>
|
||||||
|
<this.InputSlider
|
||||||
|
min={0.01}
|
||||||
|
max={1}
|
||||||
|
step={0.01}
|
||||||
|
value={this.state.provider.temperature}
|
||||||
|
onChange={(value) => {
|
||||||
|
this.updateProviderField("temperature", value);
|
||||||
|
}}
|
||||||
|
isMobile={Setting.isMobile()}
|
||||||
|
/>
|
||||||
|
</Row>
|
||||||
|
<Row style={{marginTop: "20px"}}>
|
||||||
|
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
|
||||||
|
{i18next.t("provider:Top P")}:
|
||||||
|
</Col>
|
||||||
|
<this.InputSlider
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
step={0.01}
|
||||||
|
value={this.state.provider.topP}
|
||||||
|
onChange={(value) => {
|
||||||
|
this.updateProviderField("topP", value);
|
||||||
|
}}
|
||||||
|
isMobile={Setting.isMobile()}
|
||||||
|
/>
|
||||||
|
</Row>
|
||||||
|
<Row style={{marginTop: "20px"}}>
|
||||||
|
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
|
||||||
|
{i18next.t("provider:Presence penalty")}:
|
||||||
|
</Col>
|
||||||
|
<this.InputSlider
|
||||||
|
label={i18next.t("provider:Presence penalty")}
|
||||||
|
min={1}
|
||||||
|
max={2}
|
||||||
|
step={0.01}
|
||||||
|
value={this.state.provider.presencePenalty}
|
||||||
|
onChange={(value) => {
|
||||||
|
this.updateProviderField("presencePenalty", value);
|
||||||
|
}}
|
||||||
|
isMobile={Setting.isMobile()}
|
||||||
|
/>
|
||||||
|
</Row>
|
||||||
|
</>
|
||||||
|
) : null
|
||||||
|
}
|
||||||
<Row style={{marginTop: "20px"}} >
|
<Row style={{marginTop: "20px"}} >
|
||||||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
|
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
|
||||||
{i18next.t("general:Provider URL")}:
|
{i18next.t("general:Provider URL")}:
|
||||||
|
|
Loading…
Reference in New Issue