mirror of https://github.com/langgenius/dify.git
feat: Add caching mechanism for plugin model schemas (#14898)
This commit is contained in:
parent
330dc2fd44
commit
4668c4996a
|
@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
|
||||||
from contexts.wrapper import RecyclableContextVar
|
from contexts.wrapper import RecyclableContextVar
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
@ -20,11 +21,19 @@ To avoid race-conditions caused by gunicorn thread recycling, using RecyclableCo
|
||||||
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
|
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
|
||||||
ContextVar("plugin_tool_providers")
|
ContextVar("plugin_tool_providers")
|
||||||
)
|
)
|
||||||
|
|
||||||
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))
|
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))
|
||||||
|
|
||||||
plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
|
plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
|
||||||
ContextVar("plugin_model_providers")
|
ContextVar("plugin_model_providers")
|
||||||
)
|
)
|
||||||
|
|
||||||
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||||
ContextVar("plugin_model_providers_lock")
|
ContextVar("plugin_model_providers_lock")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock"))
|
||||||
|
|
||||||
|
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
|
||||||
|
ContextVar("plugin_model_schemas")
|
||||||
|
)
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
import decimal
|
import decimal
|
||||||
|
import hashlib
|
||||||
|
from threading import Lock
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
import contexts
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
||||||
from core.model_runtime.entities.model_entities import (
|
from core.model_runtime.entities.model_entities import (
|
||||||
|
@ -139,15 +142,35 @@ class AIModel(BaseModel):
|
||||||
:return: model schema
|
:return: model schema
|
||||||
"""
|
"""
|
||||||
plugin_model_manager = PluginModelManager()
|
plugin_model_manager = PluginModelManager()
|
||||||
return plugin_model_manager.get_model_schema(
|
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
|
||||||
tenant_id=self.tenant_id,
|
# sort credentials
|
||||||
user_id="unknown",
|
sorted_credentials = sorted(credentials.items()) if credentials else []
|
||||||
plugin_id=self.plugin_id,
|
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
|
||||||
provider=self.provider_name,
|
|
||||||
model_type=self.model_type.value,
|
try:
|
||||||
model=model,
|
contexts.plugin_model_schemas.get()
|
||||||
credentials=credentials or {},
|
except LookupError:
|
||||||
)
|
contexts.plugin_model_schemas.set({})
|
||||||
|
contexts.plugin_model_schema_lock.set(Lock())
|
||||||
|
|
||||||
|
with contexts.plugin_model_schema_lock.get():
|
||||||
|
if cache_key in contexts.plugin_model_schemas.get():
|
||||||
|
return contexts.plugin_model_schemas.get()[cache_key]
|
||||||
|
|
||||||
|
schema = plugin_model_manager.get_model_schema(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id="unknown",
|
||||||
|
plugin_id=self.plugin_id,
|
||||||
|
provider=self.provider_name,
|
||||||
|
model_type=self.model_type.value,
|
||||||
|
model=model,
|
||||||
|
credentials=credentials or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
if schema:
|
||||||
|
contexts.plugin_model_schemas.get()[cache_key] = schema
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
@ -206,17 +207,35 @@ class ModelProviderFactory:
|
||||||
Get model schema
|
Get model schema
|
||||||
"""
|
"""
|
||||||
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
|
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
|
||||||
model_schema = self.plugin_model_manager.get_model_schema(
|
cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
|
||||||
tenant_id=self.tenant_id,
|
# sort credentials
|
||||||
user_id="unknown",
|
sorted_credentials = sorted(credentials.items()) if credentials else []
|
||||||
plugin_id=plugin_id,
|
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
|
||||||
provider=provider_name,
|
|
||||||
model_type=model_type.value,
|
|
||||||
model=model,
|
|
||||||
credentials=credentials,
|
|
||||||
)
|
|
||||||
|
|
||||||
return model_schema
|
try:
|
||||||
|
contexts.plugin_model_schemas.get()
|
||||||
|
except LookupError:
|
||||||
|
contexts.plugin_model_schemas.set({})
|
||||||
|
contexts.plugin_model_schema_lock.set(Lock())
|
||||||
|
|
||||||
|
with contexts.plugin_model_schema_lock.get():
|
||||||
|
if cache_key in contexts.plugin_model_schemas.get():
|
||||||
|
return contexts.plugin_model_schemas.get()[cache_key]
|
||||||
|
|
||||||
|
schema = self.plugin_model_manager.get_model_schema(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id="unknown",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
provider=provider_name,
|
||||||
|
model_type=model_type.value,
|
||||||
|
model=model,
|
||||||
|
credentials=credentials or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
if schema:
|
||||||
|
contexts.plugin_model_schemas.get()[cache_key] = schema
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
def get_models(
|
def get_models(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in New Issue