feat: Add caching mechanism for plugin model schemas (#14898)

This commit is contained in:
Yeuoly 2025-03-04 18:02:06 +08:00 committed by GitHub
parent 330dc2fd44
commit 4668c4996a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 70 additions and 19 deletions

View File

@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
from contexts.wrapper import RecyclableContextVar from contexts.wrapper import RecyclableContextVar
if TYPE_CHECKING: if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
@ -20,11 +21,19 @@ To avoid race-conditions caused by gunicorn thread recycling, using RecyclableCo
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar( plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
ContextVar("plugin_tool_providers") ContextVar("plugin_tool_providers")
) )
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock")) plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))
plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar( plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
ContextVar("plugin_model_providers") ContextVar("plugin_model_providers")
) )
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_model_providers_lock") ContextVar("plugin_model_providers_lock")
) )
# Lock taken (via `with contexts.plugin_model_schema_lock.get():`) around every
# lookup/insert on `plugin_model_schemas` by the model-schema getters in this change.
# It is not given a default here: callers set it to a fresh `Lock()` together with
# an empty cache dict when `plugin_model_schemas.get()` raises LookupError.
plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock"))
# Cache of plugin model schemas, keyed by a string that callers build from
# tenant id, plugin id, provider name, model type, model name and the MD5 digests
# of the sorted credential items. Only truthy schemas are stored by the callers.
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
ContextVar("plugin_model_schemas")
)

View File

@ -1,8 +1,11 @@
import decimal import decimal
import hashlib
from threading import Lock
from typing import Optional from typing import Optional
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
import contexts
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.model_entities import (
@ -139,15 +142,35 @@ class AIModel(BaseModel):
:return: model schema :return: model schema
""" """
plugin_model_manager = PluginModelManager() plugin_model_manager = PluginModelManager()
return plugin_model_manager.get_model_schema( cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
tenant_id=self.tenant_id, # sort credentials
user_id="unknown", sorted_credentials = sorted(credentials.items()) if credentials else []
plugin_id=self.plugin_id, cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
provider=self.provider_name,
model_type=self.model_type.value, try:
model=model, contexts.plugin_model_schemas.get()
credentials=credentials or {}, except LookupError:
) contexts.plugin_model_schemas.set({})
contexts.plugin_model_schema_lock.set(Lock())
with contexts.plugin_model_schema_lock.get():
if cache_key in contexts.plugin_model_schemas.get():
return contexts.plugin_model_schemas.get()[cache_key]
schema = plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials or {},
)
if schema:
contexts.plugin_model_schemas.get()[cache_key] = schema
return schema
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
""" """

View File

@ -1,3 +1,4 @@
import hashlib
import logging import logging
import os import os
from collections.abc import Sequence from collections.abc import Sequence
@ -206,17 +207,35 @@ class ModelProviderFactory:
Get model schema Get model schema
""" """
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
model_schema = self.plugin_model_manager.get_model_schema( cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
tenant_id=self.tenant_id, # sort credentials
user_id="unknown", sorted_credentials = sorted(credentials.items()) if credentials else []
plugin_id=plugin_id, cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)
return model_schema try:
contexts.plugin_model_schemas.get()
except LookupError:
contexts.plugin_model_schemas.set({})
contexts.plugin_model_schema_lock.set(Lock())
with contexts.plugin_model_schema_lock.get():
if cache_key in contexts.plugin_model_schemas.get():
return contexts.plugin_model_schemas.get()[cache_key]
schema = self.plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials or {},
)
if schema:
contexts.plugin_model_schemas.get()[cache_key] = schema
return schema
def get_models( def get_models(
self, self,