mirror of https://github.com/langgenius/dify.git
feat: Add support for TEI API key authentication (#11006)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com> Co-authored-by: crazywoola <427733928@qq.com>
This commit is contained in:
parent
16c41585e1
commit
096c0ad564
|
@ -34,3 +34,11 @@ model_credential_schema:
|
|||
placeholder:
|
||||
zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080
|
||||
en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
|
|
|
@ -51,8 +51,13 @@ class HuggingfaceTeiRerankModel(RerankModel):
|
|||
|
||||
server_url = server_url.removesuffix("/")
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
api_key = credentials.get("api_key")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
try:
|
||||
results = TeiHelper.invoke_rerank(server_url, query, docs)
|
||||
results = TeiHelper.invoke_rerank(server_url, query, docs, headers)
|
||||
|
||||
rerank_documents = []
|
||||
for result in results:
|
||||
|
@ -80,7 +85,11 @@ class HuggingfaceTeiRerankModel(RerankModel):
|
|||
"""
|
||||
try:
|
||||
server_url = credentials["server_url"]
|
||||
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
api_key = credentials.get("api_key")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
|
||||
if extra_args.model_type != "reranker":
|
||||
raise CredentialsValidateFailedError("Current model is not a rerank model")
|
||||
|
||||
|
|
|
@ -26,13 +26,15 @@ cache_lock = Lock()
|
|||
|
||||
class TeiHelper:
|
||||
@staticmethod
|
||||
def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
|
||||
def get_tei_extra_parameter(
|
||||
server_url: str, model_name: str, headers: Optional[dict] = None
|
||||
) -> TeiModelExtraParameter:
|
||||
TeiHelper._clean_cache()
|
||||
with cache_lock:
|
||||
if model_name not in cache:
|
||||
cache[model_name] = {
|
||||
"expires": time() + 300,
|
||||
"value": TeiHelper._get_tei_extra_parameter(server_url),
|
||||
"value": TeiHelper._get_tei_extra_parameter(server_url, headers),
|
||||
}
|
||||
return cache[model_name]["value"]
|
||||
|
||||
|
@ -47,7 +49,7 @@ class TeiHelper:
|
|||
pass
|
||||
|
||||
@staticmethod
|
||||
def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
|
||||
def _get_tei_extra_parameter(server_url: str, headers: Optional[dict] = None) -> TeiModelExtraParameter:
|
||||
"""
|
||||
get tei model extra parameter like model_type, max_input_length, max_batch_requests
|
||||
"""
|
||||
|
@ -61,7 +63,7 @@ class TeiHelper:
|
|||
session.mount("https://", HTTPAdapter(max_retries=3))
|
||||
|
||||
try:
|
||||
response = session.get(url, timeout=10)
|
||||
response = session.get(url, headers=headers, timeout=10)
|
||||
except (MissingSchema, ConnectionError, Timeout) as e:
|
||||
raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}")
|
||||
if response.status_code != 200:
|
||||
|
@ -86,7 +88,7 @@ class TeiHelper:
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
|
||||
def invoke_tokenize(server_url: str, texts: list[str], headers: Optional[dict] = None) -> list[list[dict]]:
|
||||
"""
|
||||
Invoke tokenize endpoint
|
||||
|
||||
|
@ -114,15 +116,15 @@ class TeiHelper:
|
|||
:param server_url: server url
|
||||
:param texts: texts to tokenize
|
||||
"""
|
||||
resp = httpx.post(
|
||||
f"{server_url}/tokenize",
|
||||
json={"inputs": texts},
|
||||
)
|
||||
url = f"{server_url}/tokenize"
|
||||
json_data = {"inputs": texts}
|
||||
resp = httpx.post(url, json=json_data, headers=headers)
|
||||
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
@staticmethod
|
||||
def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
|
||||
def invoke_embeddings(server_url: str, texts: list[str], headers: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Invoke embeddings endpoint
|
||||
|
||||
|
@ -147,15 +149,14 @@ class TeiHelper:
|
|||
:param texts: texts to embed
|
||||
"""
|
||||
# Use OpenAI compatible API here, which has usage tracking
|
||||
resp = httpx.post(
|
||||
f"{server_url}/v1/embeddings",
|
||||
json={"input": texts},
|
||||
)
|
||||
url = f"{server_url}/v1/embeddings"
|
||||
json_data = {"input": texts}
|
||||
resp = httpx.post(url, json=json_data, headers=headers)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
@staticmethod
|
||||
def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
|
||||
def invoke_rerank(server_url: str, query: str, docs: list[str], headers: Optional[dict] = None) -> list[dict]:
|
||||
"""
|
||||
Invoke rerank endpoint
|
||||
|
||||
|
@ -173,10 +174,7 @@ class TeiHelper:
|
|||
:param candidates: candidates to rerank
|
||||
"""
|
||||
params = {"query": query, "texts": docs, "return_text": True}
|
||||
|
||||
response = httpx.post(
|
||||
server_url + "/rerank",
|
||||
json=params,
|
||||
)
|
||||
url = f"{server_url}/rerank"
|
||||
response = httpx.post(url, json=params, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
|
|
@ -51,6 +51,10 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
|||
|
||||
server_url = server_url.removesuffix("/")
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
api_key = credentials["api_key"]
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
# get model properties
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
@ -60,7 +64,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
|||
used_tokens = 0
|
||||
|
||||
# get tokenized results from TEI
|
||||
batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
|
||||
batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts, headers)
|
||||
|
||||
for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
|
||||
# Check if the number of tokens is larger than the context size
|
||||
|
@ -97,7 +101,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
|||
used_tokens = 0
|
||||
for i in _iter:
|
||||
iter_texts = inputs[i : i + max_chunks]
|
||||
results = TeiHelper.invoke_embeddings(server_url, iter_texts)
|
||||
results = TeiHelper.invoke_embeddings(server_url, iter_texts, headers)
|
||||
embeddings = results["data"]
|
||||
embeddings = [embedding["embedding"] for embedding in embeddings]
|
||||
batched_embeddings.extend(embeddings)
|
||||
|
@ -127,7 +131,11 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
|||
|
||||
server_url = server_url.removesuffix("/")
|
||||
|
||||
batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {credentials.get('api_key')}",
|
||||
}
|
||||
|
||||
batch_tokens = TeiHelper.invoke_tokenize(server_url, texts, headers)
|
||||
num_tokens = sum(len(tokens) for tokens in batch_tokens)
|
||||
return num_tokens
|
||||
|
||||
|
@ -141,7 +149,14 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
|||
"""
|
||||
try:
|
||||
server_url = credentials["server_url"]
|
||||
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
api_key = credentials.get("api_key")
|
||||
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
|
||||
print(extra_args)
|
||||
if extra_args.model_type != "embedding":
|
||||
raise CredentialsValidateFailedError("Current model is not a embedding model")
|
||||
|
|
|
@ -20,6 +20,7 @@ env =
|
|||
OPENAI_API_KEY = sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
|
||||
TEI_EMBEDDING_SERVER_URL = http://a.abc.com:11451
|
||||
TEI_RERANK_SERVER_URL = http://a.abc.com:11451
|
||||
TEI_API_KEY = ttttttttttttttt
|
||||
UPSTAGE_API_KEY = up-aaaaaaaaaaaaaaaaaaaa
|
||||
VOYAGE_API_KEY = va-aaaaaaaaaaaaaaaaaaaa
|
||||
XINFERENCE_CHAT_MODEL_UID = chat
|
||||
|
|
|
@ -40,6 +40,7 @@ def test_validate_credentials(setup_tei_mock):
|
|||
model="reranker",
|
||||
credentials={
|
||||
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
|
||||
"api_key": os.environ.get("TEI_API_KEY", ""),
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -47,6 +48,7 @@ def test_validate_credentials(setup_tei_mock):
|
|||
model=model_name,
|
||||
credentials={
|
||||
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
|
||||
"api_key": os.environ.get("TEI_API_KEY", ""),
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -60,6 +62,7 @@ def test_invoke_model(setup_tei_mock):
|
|||
model=model_name,
|
||||
credentials={
|
||||
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
|
||||
"api_key": os.environ.get("TEI_API_KEY", ""),
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
user="abc-123",
|
||||
|
|
|
@ -40,6 +40,7 @@ def test_validate_credentials(setup_tei_mock):
|
|||
model="embedding",
|
||||
credentials={
|
||||
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
|
||||
"api_key": os.environ.get("TEI_API_KEY", ""),
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -47,6 +48,7 @@ def test_validate_credentials(setup_tei_mock):
|
|||
model=model_name,
|
||||
credentials={
|
||||
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
|
||||
"api_key": os.environ.get("TEI_API_KEY", ""),
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -61,6 +63,7 @@ def test_invoke_model(setup_tei_mock):
|
|||
model=model_name,
|
||||
credentials={
|
||||
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
|
||||
"api_key": os.environ.get("TEI_API_KEY", ""),
|
||||
},
|
||||
query="Who is Kasumi?",
|
||||
docs=[
|
||||
|
|
Loading…
Reference in New Issue