fix: add preset max tokens for OpenAI model provider (#637)
This commit is contained in:
parent
13d1659151
commit
1572e75bb9
|
@ -25,6 +25,31 @@ import (
|
|||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// https://pkg.go.dev/github.com/sashabaranov/go-openai@v1.12.0#pkg-constants
|
||||
// https://platform.openai.com/docs/models/overview
|
||||
var __maxTokens = map[string]int{
|
||||
openai.GPT4: 8192,
|
||||
openai.GPT40613: 8192,
|
||||
openai.GPT432K: 32768,
|
||||
openai.GPT432K0613: 32768,
|
||||
openai.GPT40314: 8192,
|
||||
openai.GPT432K0314: 32768,
|
||||
openai.GPT3Dot5Turbo: 4097,
|
||||
openai.GPT3Dot5Turbo16K: 16385,
|
||||
openai.GPT3Dot5Turbo0613: 4097,
|
||||
openai.GPT3Dot5Turbo16K0613: 16385,
|
||||
openai.GPT3Dot5Turbo0301: 4097,
|
||||
openai.GPT3TextDavinci003: 4097,
|
||||
openai.GPT3TextDavinci002: 4097,
|
||||
openai.GPT3TextCurie001: 2049,
|
||||
openai.GPT3TextBabbage001: 2049,
|
||||
openai.GPT3TextAda001: 2049,
|
||||
openai.GPT3Davinci: 2049,
|
||||
openai.GPT3Curie: 2049,
|
||||
openai.GPT3Ada: 2049,
|
||||
openai.GPT3Babbage: 2049,
|
||||
}
|
||||
|
||||
type OpenAiModelProvider struct {
|
||||
subType string
|
||||
secretKey string
|
||||
|
@ -42,6 +67,15 @@ func getProxyClientFromToken(authToken string) *openai.Client {
|
|||
return c
|
||||
}
|
||||
|
||||
// GetMaxTokens returns the max tokens for a given openai model.
|
||||
func (p *OpenAiModelProvider) GetMaxTokens() int {
|
||||
res, ok := __maxTokens[p.subType]
|
||||
if !ok {
|
||||
return 4097
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (p *OpenAiModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
|
||||
client := getProxyClientFromToken(p.secretKey)
|
||||
|
||||
|
@ -63,8 +97,7 @@ func (p *OpenAiModelProvider) QueryText(question string, writer io.Writer, build
|
|||
return err
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/models/gpt-3-5
|
||||
maxTokens := 4097 - promptTokens
|
||||
maxTokens := p.GetMaxTokens() - promptTokens
|
||||
|
||||
respStream, err := client.CreateCompletionStream(
|
||||
ctx,
|
||||
|
|
Loading…
Reference in New Issue