diff --git a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml index f3a912d84d..e81da51048 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml +++ b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml @@ -34,3 +34,11 @@ model_credential_schema: placeholder: zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080 en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080 + - variable: api_key + label: + en_US: API Key + type: secret-input + required: false + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key diff --git a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py index 0bb9a9c8b5..06f76c2d85 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py @@ -51,8 +51,13 @@ class HuggingfaceTeiRerankModel(RerankModel): server_url = server_url.removesuffix("/") + headers = {"Content-Type": "application/json"} + api_key = credentials.get("api_key") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + try: - results = TeiHelper.invoke_rerank(server_url, query, docs) + results = TeiHelper.invoke_rerank(server_url, query, docs, headers) rerank_documents = [] for result in results: @@ -80,7 +85,11 @@ class HuggingfaceTeiRerankModel(RerankModel): """ try: server_url = credentials["server_url"] - extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) + headers = {"Content-Type": "application/json"} + api_key = credentials.get("api_key") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers) if extra_args.model_type != "reranker": raise 
CredentialsValidateFailedError("Current model is not a rerank model") diff --git a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py index 81ab249214..3ffcf4175e 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py @@ -26,13 +26,15 @@ cache_lock = Lock() class TeiHelper: @staticmethod - def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter: + def get_tei_extra_parameter( + server_url: str, model_name: str, headers: Optional[dict] = None + ) -> TeiModelExtraParameter: TeiHelper._clean_cache() with cache_lock: if model_name not in cache: cache[model_name] = { "expires": time() + 300, - "value": TeiHelper._get_tei_extra_parameter(server_url), + "value": TeiHelper._get_tei_extra_parameter(server_url, headers), } return cache[model_name]["value"] @@ -47,7 +49,7 @@ class TeiHelper: pass @staticmethod - def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter: + def _get_tei_extra_parameter(server_url: str, headers: Optional[dict] = None) -> TeiModelExtraParameter: """ get tei model extra parameter like model_type, max_input_length, max_batch_requests """ @@ -61,7 +63,7 @@ class TeiHelper: session.mount("https://", HTTPAdapter(max_retries=3)) try: - response = session.get(url, timeout=10) + response = session.get(url, headers=headers, timeout=10) except (MissingSchema, ConnectionError, Timeout) as e: raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}") if response.status_code != 200: @@ -86,7 +88,7 @@ class TeiHelper: ) @staticmethod - def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: + def invoke_tokenize(server_url: str, texts: list[str], headers: Optional[dict] = None) -> list[list[dict]]: """ Invoke tokenize endpoint @@ -114,15 +116,15 @@ class TeiHelper: :param server_url: 
server url :param texts: texts to tokenize """ - resp = httpx.post( - f"{server_url}/tokenize", - json={"inputs": texts}, - ) + url = f"{server_url}/tokenize" + json_data = {"inputs": texts} + resp = httpx.post(url, json=json_data, headers=headers) + resp.raise_for_status() return resp.json() @staticmethod - def invoke_embeddings(server_url: str, texts: list[str]) -> dict: + def invoke_embeddings(server_url: str, texts: list[str], headers: Optional[dict] = None) -> dict: """ Invoke embeddings endpoint @@ -147,15 +149,14 @@ class TeiHelper: :param texts: texts to embed """ # Use OpenAI compatible API here, which has usage tracking - resp = httpx.post( - f"{server_url}/v1/embeddings", - json={"input": texts}, - ) + url = f"{server_url}/v1/embeddings" + json_data = {"input": texts} + resp = httpx.post(url, json=json_data, headers=headers) resp.raise_for_status() return resp.json() @staticmethod - def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]: + def invoke_rerank(server_url: str, query: str, docs: list[str], headers: Optional[dict] = None) -> list[dict]: """ Invoke rerank endpoint @@ -173,10 +174,7 @@ class TeiHelper: :param candidates: candidates to rerank """ params = {"query": query, "texts": docs, "return_text": True} - - response = httpx.post( - server_url + "/rerank", - json=params, - ) + url = f"{server_url}/rerank" + response = httpx.post(url, json=params, headers=headers) response.raise_for_status() return response.json() diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py index a0917630a9..284429b741 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py @@ -51,6 +51,10 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): server_url = 
server_url.removesuffix("/") + headers = {"Content-Type": "application/json"} + api_key = credentials.get("api_key") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" # get model properties context_size = self._get_context_size(model, credentials) max_chunks = self._get_max_chunks(model, credentials) @@ -60,7 +64,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): used_tokens = 0 # get tokenized results from TEI - batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts) + batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts, headers) for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)): # Check if the number of tokens is larger than the context size @@ -97,7 +101,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): used_tokens = 0 for i in _iter: iter_texts = inputs[i : i + max_chunks] - results = TeiHelper.invoke_embeddings(server_url, iter_texts) + results = TeiHelper.invoke_embeddings(server_url, iter_texts, headers) embeddings = results["data"] embeddings = [embedding["embedding"] for embedding in embeddings] batched_embeddings.extend(embeddings) @@ -127,7 +131,11 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): server_url = server_url.removesuffix("/") - batch_tokens = TeiHelper.invoke_tokenize(server_url, texts) + headers = {"Content-Type": "application/json"} + api_key = credentials.get("api_key") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + batch_tokens = TeiHelper.invoke_tokenize(server_url, texts, headers) num_tokens = sum(len(tokens) for tokens in batch_tokens) return num_tokens @@ -141,7 +149,14 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): """ try: server_url = credentials["server_url"] - extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) + headers = {"Content-Type": "application/json"} + + api_key = credentials.get("api_key") + + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, 
headers) print(extra_args) if extra_args.model_type != "embedding": raise CredentialsValidateFailedError("Current model is not a embedding model") diff --git a/api/pytest.ini b/api/pytest.ini index a23a4b3f3d..993da4c9a7 100644 --- a/api/pytest.ini +++ b/api/pytest.ini @@ -20,6 +20,7 @@ env = OPENAI_API_KEY = sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii TEI_EMBEDDING_SERVER_URL = http://a.abc.com:11451 TEI_RERANK_SERVER_URL = http://a.abc.com:11451 + TEI_API_KEY = ttttttttttttttt UPSTAGE_API_KEY = up-aaaaaaaaaaaaaaaaaaaa VOYAGE_API_KEY = va-aaaaaaaaaaaaaaaaaaaa XINFERENCE_CHAT_MODEL_UID = chat diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py index b1fa9d5ca5..33160062e5 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py +++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py @@ -40,6 +40,7 @@ def test_validate_credentials(setup_tei_mock): model="reranker", credentials={ "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + "api_key": os.environ.get("TEI_API_KEY", ""), }, ) @@ -47,6 +48,7 @@ def test_validate_credentials(setup_tei_mock): model=model_name, credentials={ "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + "api_key": os.environ.get("TEI_API_KEY", ""), }, ) @@ -60,6 +62,7 @@ def test_invoke_model(setup_tei_mock): model=model_name, credentials={ "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + "api_key": os.environ.get("TEI_API_KEY", ""), }, texts=["hello", "world"], user="abc-123", diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py index cd1c20dd02..9777367063 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py @@ 
-40,6 +40,7 @@ def test_validate_credentials(setup_tei_mock): model="embedding", credentials={ "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + "api_key": os.environ.get("TEI_API_KEY", ""), }, ) @@ -47,6 +48,7 @@ def test_validate_credentials(setup_tei_mock): model=model_name, credentials={ "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + "api_key": os.environ.get("TEI_API_KEY", ""), }, ) @@ -61,6 +63,7 @@ def test_invoke_model(setup_tei_mock): model=model_name, credentials={ "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + "api_key": os.environ.get("TEI_API_KEY", ""), }, query="Who is Kasumi?", docs=[