Add session_id_param to ACADynamicSessionsCodeExecutor (#6171)

The initializer for ACADynamicSessionsCodeExecutor creates a new GUID to
use as the session ID for dynamic sessions.

In some scenarios it is desirable to be able to re-create the agent
group chat from saved state. In this case, the
ACADynamicSessionsCodeExecutor needs to be associated with a previous
instance (so that any execution state is still valid)

This PR adds a new argument to the initializer to allow a session ID to
be passed in (defaulting to the current behaviour of creating a GUID if
absent).

Closes #6119

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Stuart Leeks 2025-04-02 22:39:44 +01:00 committed by GitHub
parent 9de16d5f70
commit 9143e58ef1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 1 deletions

View File

@ -73,6 +73,7 @@ class ACADynamicSessionsCodeExecutor(CodeExecutor):
directory is a temporal directory.
functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list.
suppress_result_output bool: By default the executor will attach any result info in the execution response to the result outpu. Set this to True to prevent this.
session_id (str): The session id for the code execution (passed to Dynamic Sessions). If None, a new session id will be generated. Default is None. Note this value will be reset when calling `restart`
.. note::
Using the current directory (".") as working directory is deprecated. Using it will raise a deprecation warning.
@ -102,6 +103,7 @@ $functions"""
] = [],
functions_module: str = "functions",
suppress_result_output: bool = False,
session_id: Optional[str] = None,
):
if timeout < 1:
raise ValueError("Timeout must be greater than or equal to 1.")
@ -141,7 +143,7 @@ $functions"""
self._pool_management_endpoint = pool_management_endpoint
self._access_token: str | None = None
self._session_id: str = str(uuid4())
self._session_id: str = session_id or str(uuid4())
self._available_packages: set[str] | None = None
self._credential: TokenProvider = credential
# cwd needs to be set to /mnt/data to properly read uploaded files and download written files

View File

@ -22,6 +22,23 @@ ENVIRON_KEY_AZURE_POOL_ENDPOINT = "AZURE_POOL_ENDPOINT"
POOL_ENDPOINT = os.getenv(ENVIRON_KEY_AZURE_POOL_ENDPOINT)
def test_session_id_preserved_if_passed() -> None:
executor = ACADynamicSessionsCodeExecutor(
pool_management_endpoint="fake-endpoint", credential=DefaultAzureCredential()
)
session_id = "test_session_id"
executor._session_id = session_id # type: ignore[reportPrivateUsage]
assert executor._session_id == session_id # type: ignore[reportPrivateUsage]
def test_session_id_generated_if_not_passed() -> None:
executor = ACADynamicSessionsCodeExecutor(
pool_management_endpoint="fake-endpoint", credential=DefaultAzureCredential()
)
assert executor._session_id is not None # type: ignore[reportPrivateUsage]
assert len(executor._session_id) > 0 # type: ignore[reportPrivateUsage]
@pytest.mark.skipif(
not POOL_ENDPOINT,
reason="do not run if pool endpoint is not defined",