feat: Add support for TEI API key authentication (#11006)

Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
Co-authored-by: crazywoola <427733928@qq.com>
kenwoodjw 2024-11-23 23:55:35 +08:00 committed by GitHub
parent 16c41585e1
commit 096c0ad564
7 changed files with 63 additions and 26 deletions

View File

@@ -34,3 +34,11 @@ model_credential_schema:
       placeholder:
         zh_Hans: 在此输入Text Embedding Inference的服务器地址，如 http://192.168.1.100:8080
         en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: false
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
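
Since `api_key` is declared with `required: false`, every call site in this commit guards the Authorization header. A minimal sketch of that shared pattern (a hypothetical standalone helper, not part of the diff), assuming a `credentials` dict shaped like the schema above:

    def build_headers(credentials: dict) -> dict:
        # Hypothetical helper mirroring the pattern repeated throughout this commit.
        headers = {"Content-Type": "application/json"}
        api_key = credentials.get("api_key")  # optional, per required: false above
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers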

View File

@@ -51,8 +51,13 @@ class HuggingfaceTeiRerankModel(RerankModel):
         server_url = server_url.removesuffix("/")
+        headers = {"Content-Type": "application/json"}
+        api_key = credentials.get("api_key")
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
         try:
-            results = TeiHelper.invoke_rerank(server_url, query, docs)
+            results = TeiHelper.invoke_rerank(server_url, query, docs, headers)
             rerank_documents = []
             for result in results:
@@ -80,7 +85,11 @@ class HuggingfaceTeiRerankModel(RerankModel):
         """
         try:
             server_url = credentials["server_url"]
-            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
+            headers = {"Content-Type": "application/json"}
+            api_key = credentials.get("api_key")
+            if api_key:
+                headers["Authorization"] = f"Bearer {api_key}"
+            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
             if extra_args.model_type != "reranker":
                 raise CredentialsValidateFailedError("Current model is not a rerank model")
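
On the wire, the header plumbed through here is a standard Bearer-authenticated POST to TEI's /rerank endpoint. A hedged sketch of the equivalent raw request (placeholder URL and key; the payload shape matches TeiHelper.invoke_rerank below):

    import httpx

    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer <your-tei-api-key>",  # placeholder
    }
    params = {"query": "Who is Kasumi?", "texts": ["doc one", "doc two"], "return_text": True}
    response = httpx.post("http://192.168.1.100:8080/rerank", json=params, headers=headers)
    response.raise_for_status()
    print(response.json())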

View File

@@ -26,13 +26,15 @@ cache_lock = Lock()
 class TeiHelper:
     @staticmethod
-    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
+    def get_tei_extra_parameter(
+        server_url: str, model_name: str, headers: Optional[dict] = None
+    ) -> TeiModelExtraParameter:
         TeiHelper._clean_cache()
         with cache_lock:
             if model_name not in cache:
                 cache[model_name] = {
                     "expires": time() + 300,
-                    "value": TeiHelper._get_tei_extra_parameter(server_url),
+                    "value": TeiHelper._get_tei_extra_parameter(server_url, headers),
                 }
             return cache[model_name]["value"]
@ -47,7 +49,7 @@ class TeiHelper:
pass
@staticmethod
def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
def _get_tei_extra_parameter(server_url: str, headers: Optional[dict] = None) -> TeiModelExtraParameter:
"""
get tei model extra parameter like model_type, max_input_length, max_batch_requests
"""
@@ -61,7 +63,7 @@ class TeiHelper:
         session.mount("https://", HTTPAdapter(max_retries=3))
         try:
-            response = session.get(url, timeout=10)
+            response = session.get(url, headers=headers, timeout=10)
         except (MissingSchema, ConnectionError, Timeout) as e:
             raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}")
         if response.status_code != 200:
@@ -86,7 +88,7 @@ class TeiHelper:
         )
     @staticmethod
-    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
+    def invoke_tokenize(server_url: str, texts: list[str], headers: Optional[dict] = None) -> list[list[dict]]:
         """
         Invoke tokenize endpoint
@@ -114,15 +116,15 @@ class TeiHelper:
         :param server_url: server url
         :param texts: texts to tokenize
         """
-        resp = httpx.post(
-            f"{server_url}/tokenize",
-            json={"inputs": texts},
-        )
+        url = f"{server_url}/tokenize"
+        json_data = {"inputs": texts}
+        resp = httpx.post(url, json=json_data, headers=headers)
         resp.raise_for_status()
         return resp.json()
     @staticmethod
-    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
+    def invoke_embeddings(server_url: str, texts: list[str], headers: Optional[dict] = None) -> dict:
         """
         Invoke embeddings endpoint
@@ -147,15 +149,14 @@ class TeiHelper:
         :param texts: texts to embed
         """
         # Use OpenAI compatible API here, which has usage tracking
-        resp = httpx.post(
-            f"{server_url}/v1/embeddings",
-            json={"input": texts},
-        )
+        url = f"{server_url}/v1/embeddings"
+        json_data = {"input": texts}
+        resp = httpx.post(url, json=json_data, headers=headers)
         resp.raise_for_status()
         return resp.json()
     @staticmethod
-    def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
+    def invoke_rerank(server_url: str, query: str, docs: list[str], headers: Optional[dict] = None) -> list[dict]:
         """
         Invoke rerank endpoint
@@ -173,10 +174,7 @@ class TeiHelper:
         :param candidates: candidates to rerank
         """
         params = {"query": query, "texts": docs, "return_text": True}
-        response = httpx.post(
-            server_url + "/rerank",
-            json=params,
-        )
+        url = f"{server_url}/rerank"
+        response = httpx.post(url, json=params, headers=headers)
         response.raise_for_status()
         return response.json()
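
Taken together, the helper now accepts an optional headers dict at every entry point, so unauthenticated servers keep working when it is omitted. A usage sketch (placeholder URL and key; identifiers as in the diff above):

    server_url = "http://192.168.1.100:8080"  # placeholder
    headers = {"Content-Type": "application/json", "Authorization": "Bearer <key>"}  # placeholder key

    extra = TeiHelper.get_tei_extra_parameter(server_url, "my-model", headers)
    if extra.model_type == "embedding":
        result = TeiHelper.invoke_embeddings(server_url, ["hello", "world"], headers)
        vectors = [item["embedding"] for item in result["data"]]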

View File

@@ -51,6 +51,10 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         server_url = server_url.removesuffix("/")
+        headers = {"Content-Type": "application/json"}
+        api_key = credentials.get("api_key")
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
         # get model properties
         context_size = self._get_context_size(model, credentials)
         max_chunks = self._get_max_chunks(model, credentials)
@@ -60,7 +64,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         used_tokens = 0
         # get tokenized results from TEI
-        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
+        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts, headers)
         for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
             # Check if the number of tokens is larger than the context size
@@ -97,7 +101,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         used_tokens = 0
         for i in _iter:
             iter_texts = inputs[i : i + max_chunks]
-            results = TeiHelper.invoke_embeddings(server_url, iter_texts)
+            results = TeiHelper.invoke_embeddings(server_url, iter_texts, headers)
             embeddings = results["data"]
             embeddings = [embedding["embedding"] for embedding in embeddings]
             batched_embeddings.extend(embeddings)
@@ -127,7 +131,11 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         server_url = server_url.removesuffix("/")
-        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
+        headers = {
+            "Authorization": f"Bearer {credentials.get('api_key')}",
+        }
+        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts, headers)
         num_tokens = sum(len(tokens) for tokens in batch_tokens)
         return num_tokens
@@ -141,7 +149,14 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         """
         try:
             server_url = credentials["server_url"]
-            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
+            headers = {"Content-Type": "application/json"}
+            api_key = credentials.get("api_key")
+            if api_key:
+                headers["Authorization"] = f"Bearer {api_key}"
+            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
+            print(extra_args)
             if extra_args.model_type != "embedding":
                 raise CredentialsValidateFailedError("Current model is not an embedding model")

View File

@@ -20,6 +20,7 @@ env =
     OPENAI_API_KEY = sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
     TEI_EMBEDDING_SERVER_URL = http://a.abc.com:11451
     TEI_RERANK_SERVER_URL = http://a.abc.com:11451
+    TEI_API_KEY = ttttttttttttttt
     UPSTAGE_API_KEY = up-aaaaaaaaaaaaaaaaaaaa
     VOYAGE_API_KEY = va-aaaaaaaaaaaaaaaaaaaa
     XINFERENCE_CHAT_MODEL_UID = chat

View File

@@ -40,6 +40,7 @@ def test_validate_credentials(setup_tei_mock):
             model="reranker",
             credentials={
                 "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+                "api_key": os.environ.get("TEI_API_KEY", ""),
             },
         )
@@ -47,6 +48,7 @@ def test_validate_credentials(setup_tei_mock):
         model=model_name,
         credentials={
             "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+            "api_key": os.environ.get("TEI_API_KEY", ""),
         },
     )
@@ -60,6 +62,7 @@ def test_invoke_model(setup_tei_mock):
         model=model_name,
         credentials={
             "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+            "api_key": os.environ.get("TEI_API_KEY", ""),
         },
         texts=["hello", "world"],
         user="abc-123",

View File

@@ -40,6 +40,7 @@ def test_validate_credentials(setup_tei_mock):
             model="embedding",
             credentials={
                 "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+                "api_key": os.environ.get("TEI_API_KEY", ""),
             },
         )
@@ -47,6 +48,7 @@ def test_validate_credentials(setup_tei_mock):
         model=model_name,
         credentials={
             "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+            "api_key": os.environ.get("TEI_API_KEY", ""),
         },
     )
@@ -61,6 +63,7 @@ def test_invoke_model(setup_tei_mock):
         model=model_name,
         credentials={
             "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+            "api_key": os.environ.get("TEI_API_KEY", ""),
         },
         query="Who is Kasumi?",
         docs=[