171 lines
3.7 KiB
Go
171 lines
3.7 KiB
Go
package planner
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/cloudwego/eino-ext/components/model/openai"
|
|
"github.com/joho/godotenv"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// defaultTimeout is the timeout GetModelConfig applies to the HTTP client it
// builds for the chat model.
const defaultTimeout = 60 * time.Second
|
|
|
|
// OpenAIInitConfig holds the optional settings that GetModelConfig decodes
// from the JSON stored in the MIDSCENE_OPENAI_INIT_CONFIG_JSON environment
// variable.
type OpenAIInitConfig struct {
	// ReportURL is populated from the "REPORT_SERVER_URL" JSON key.
	ReportURL string `json:"REPORT_SERVER_URL"`

	// Headers is populated from the "defaultHeaders" JSON key. GetModelConfig
	// hands these to the transport of the HTTP client it creates.
	Headers map[string]string `json:"defaultHeaders"`
}
|
|
|
|
// Names of the environment variables read by this package.
const (
	// EnvOpenAIBaseURL names the variable holding the API base URL.
	EnvOpenAIBaseURL = "OPENAI_BASE_URL"

	// EnvOpenAIAPIKey names the variable holding the API key.
	EnvOpenAIAPIKey = "OPENAI_API_KEY"

	// EnvModelName names the variable holding the model name.
	EnvModelName = "MIDSCENE_MODEL_NAME"

	// EnvOpenAIInitConfigJSON names the variable holding the optional JSON
	// init config (see OpenAIInitConfig).
	EnvOpenAIInitConfigJSON = "MIDSCENE_OPENAI_INIT_CONFIG_JSON"

	// EnvUseVLMUITars names the boolean variable queried by IsUseVLMUITars.
	EnvUseVLMUITars = "MIDSCENE_USE_VLM_UI_TARS"
)
|
|
|
|
// loadEnv loads environment variables from a file
|
|
func loadEnv(envPath string) error {
|
|
err := godotenv.Load(envPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Info().Str("path", envPath).Msg("load env success")
|
|
return nil
|
|
}
|
|
|
|
// GetEnvConfig returns the value of the environment variable named key, as
// reported by os.Getenv. An unset variable yields the empty string.
func GetEnvConfig(key string) string {
	value := os.Getenv(key)
	return value
}
|
|
|
|
func GetEnvConfigInJSON(key string) (map[string]interface{}, error) {
|
|
value := GetEnvConfig(key)
|
|
if value == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
var result map[string]interface{}
|
|
if err := json.Unmarshal([]byte(value), &result); err != nil {
|
|
return nil, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func GetEnvConfigInBool(key string) bool {
|
|
value := GetEnvConfig(key)
|
|
if value == "" {
|
|
return false
|
|
}
|
|
|
|
boolValue, _ := strconv.ParseBool(value)
|
|
return boolValue
|
|
}
|
|
|
|
// GetEnvConfigOrDefault get env config or default value
|
|
func GetEnvConfigOrDefault(key, defaultValue string) string {
|
|
value := GetEnvConfig(key)
|
|
if value == "" {
|
|
return defaultValue
|
|
}
|
|
return value
|
|
}
|
|
|
|
func GetEnvConfigInInt(key string, defaultValue int) int {
|
|
value := GetEnvConfig(key)
|
|
if value == "" {
|
|
return defaultValue
|
|
}
|
|
|
|
intValue, err := strconv.Atoi(value)
|
|
if err != nil {
|
|
return defaultValue
|
|
}
|
|
return intValue
|
|
}
|
|
|
|
// CustomTransport is an http.RoundTripper that sets a fixed set of headers on
// every request before delegating to another RoundTripper.
type CustomTransport struct {
	// Transport performs the actual round trip. When nil,
	// http.DefaultTransport is used.
	Transport http.RoundTripper

	// Headers are set on each outgoing request, replacing any existing
	// values for the same keys.
	Headers map[string]string
}

// RoundTrip executes a single HTTP transaction with the configured headers
// applied.
//
// The headers are set on a clone of req rather than on req itself, because the
// http.RoundTripper contract forbids an implementation from modifying the
// caller's request. The response and error of the underlying transport are
// returned unchanged.
func (c *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	base := c.Transport
	if base == nil {
		base = http.DefaultTransport
	}

	// Nothing to add: forward the request as-is and skip the clone.
	if len(c.Headers) == 0 {
		return base.RoundTrip(req)
	}

	outgoing := req.Clone(req.Context())
	if outgoing.Header == nil {
		// Cloning a request with a nil header map keeps it nil; Set would
		// panic on it.
		outgoing.Header = make(http.Header, len(c.Headers))
	}
	for key, value := range c.Headers {
		outgoing.Header.Set(key, value)
	}
	return base.RoundTrip(outgoing)
}
|
|
|
|
// GetModelConfig get OpenAI config
|
|
func GetModelConfig() (*openai.ChatModelConfig, error) {
|
|
envConfig := &OpenAIInitConfig{
|
|
Headers: make(map[string]string),
|
|
}
|
|
|
|
// read from JSON config first
|
|
jsonStr := GetEnvConfig(EnvOpenAIInitConfigJSON)
|
|
if jsonStr != "" {
|
|
if err := json.Unmarshal([]byte(jsonStr), envConfig); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
config := &openai.ChatModelConfig{
|
|
HTTPClient: &http.Client{
|
|
Timeout: defaultTimeout,
|
|
Transport: &CustomTransport{
|
|
Transport: http.DefaultTransport,
|
|
Headers: envConfig.Headers,
|
|
},
|
|
},
|
|
}
|
|
|
|
if baseURL := GetEnvConfig(EnvOpenAIBaseURL); baseURL != "" {
|
|
config.BaseURL = baseURL
|
|
} else {
|
|
return nil, fmt.Errorf("miss env %s", EnvOpenAIBaseURL)
|
|
}
|
|
|
|
if apiKey := GetEnvConfig(EnvOpenAIAPIKey); apiKey != "" {
|
|
config.APIKey = apiKey
|
|
} else {
|
|
return nil, fmt.Errorf("miss env %s", EnvOpenAIAPIKey)
|
|
}
|
|
|
|
if modelName := GetEnvConfig(EnvModelName); modelName != "" {
|
|
config.Model = modelName
|
|
} else {
|
|
return nil, fmt.Errorf("miss env %s", EnvModelName)
|
|
}
|
|
|
|
// log config info
|
|
log.Info().Str("model", config.Model).
|
|
Str("baseURL", config.BaseURL).
|
|
Str("apiKey", maskAPIKey(config.APIKey)).
|
|
Str("timeout", defaultTimeout.String()).
|
|
Msg("get model config")
|
|
|
|
return config, nil
|
|
}
|
|
|
|
// maskAPIKey returns a redacted form of key that is safe to log.
//
// Keys longer than eight bytes keep their first and last four bytes with
// "******" in between; shorter keys (including the empty string) are replaced
// entirely by "******".
func maskAPIKey(key string) string {
	const (
		visible = 4
		mask    = "******"
	)

	if n := len(key); n > 2*visible {
		return key[:visible] + mask + key[n-visible:]
	}
	return mask
}
|
|
|
|
func IsUseVLMUITars() bool {
|
|
return GetEnvConfigInBool(EnvUseVLMUITars)
|
|
}
|