diff --git a/go.mod b/go.mod index 146def4..ee1cc60 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/google/uuid v1.3.0 github.com/henomis/lingoose v0.0.11-alpha1 github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80 + github.com/leverly/ChatGLM v1.2.0 github.com/lib/pq v1.10.2 github.com/madebywelch/anthropic-go v1.0.1 github.com/muesli/clusters v0.0.0-20200529215643-2700303c1762 @@ -36,6 +37,7 @@ require ( github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/dgrijalva/jwt-go v3.2.0+incompatible // indirect github.com/dlclark/regexp2 v1.8.1 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/goccy/go-json v0.9.11 // indirect diff --git a/go.sum b/go.sum index 53f4d17..5f6ee56 100644 --- a/go.sum +++ b/go.sum @@ -118,6 +118,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waNNZfHBM8= github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0= github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= @@ -378,6 +379,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6/go.mod h1:n931TsDuKuq+uX4v1fulaMbA/7ZLLhjc85h7chZGBCQ= github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80 h1:6Yzfa6GP0rIo/kULo2bwGEkFvCePZ3qHDDTC3/J9Swo= github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs= +github.com/leverly/ChatGLM v1.2.0 h1:CLXD0+sqUabADMwEDoZP22qoLInY9BjQl/jZ2hZG0M0= +github.com/leverly/ChatGLM v1.2.0/go.mod h1:DoxwOIyOup0Ct+dhm2FxKrakCB5AscJ0N0jkW+XuU8Q= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= diff --git a/model/chatglm.go b/model/chatglm.go new file mode 100644 index 0000000..06ff632 --- /dev/null +++ b/model/chatglm.go @@ -0,0 +1,65 @@ +// Copyright 2023 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package model + +import ( + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/leverly/ChatGLM/client" +) + +type ChatGLMModelProvider struct { + subType string + clientSecret string +} + +func NewChatGLMModelProvider(subType string, clientSecret string) (*ChatGLMModelProvider, error) { + return &ChatGLMModelProvider{subType: subType, clientSecret: clientSecret}, nil +} + +func (p *ChatGLMModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { + proxy := client.NewChatGLMClient(p.clientSecret, 30*time.Second) + prompt := []client.Message{{Role: "user", Content: question}} + taskId, err := proxy.AsyncInvoke(p.subType, 0.2, prompt) + if err != nil { + return err + } + flusher, ok := writer.(http.Flusher) + if !ok { + return fmt.Errorf("writer does not implement http.Flusher") + } + flushData := func(data string) error { + if _, err := fmt.Fprintf(writer, "event: message\ndata: %s\n\n", data); err != nil { + return err + } + flusher.Flush() + builder.WriteString(data) + return nil + } + response, err := proxy.AsyncInvokeTask(p.subType, taskId) + if err != nil { + return err + } + content := (*response.Choices)[0].Content + err = flushData(content) + if err != nil { + return err + } + return nil +} diff --git a/model/provider.go b/model/provider.go index 07c33b4..96fbd50 100644 --- a/model/provider.go +++ b/model/provider.go @@ -38,6 +38,8 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret p, err = NewErnieModelProvider(subType, clientId, clientSecret) } else if typ == "iFlytek" { p, err = NewiFlytekModelProvider(subType, clientSecret) + } else if typ == "ChatGLM" { + p, err = NewChatGLMModelProvider(subType, clientSecret) } if err != nil { diff --git a/web/src/Setting.js b/web/src/Setting.js index b32d6d9..397a141 100644 --- a/web/src/Setting.js +++ b/web/src/Setting.js @@ -645,6 +645,7 @@ export function getProviderTypeOptions(category) { {id: "OpenRouter", name: "OpenRouter"}, {id: "Ernie", name: "Ernie"}, {id: "iFlytek", name: "iFlytek"}, + {id: "ChatGLM", name: "ChatGLM"}, ] ); } else if (category === "Embedding") { @@ -813,6 +814,12 @@ export function getProviderSubTypeOptions(category, type) { {id: "spark-v2.0", name: "spark-v2.0"}, ] ); + } else if (type === "ChatGLM") { + return ( + [ + {id: "chatglm2-6b", name: "chatglm2-6b"}, + ] + ); } else { return []; }