From d186a41ed1330d5abe6a262fd93bda770f9e594c Mon Sep 17 00:00:00 2001 From: Diego Colombo Date: Mon, 25 Nov 2024 00:32:56 +0000 Subject: [PATCH] ensure that cancellation token is passed in InvokeWithActivityAsync (#4329) * ensure that cancellation token is passed in InvokeWithActivityAsync * add comments and baggange is not nullable * store ncrunch settings * shange signature to have nullable activity at the end of Update * correct spelling case * primary contructor * add docs and make async interface accept cancellation tokens * address code ql error --- dotnet/AutoGen.v3.ncrunchsolution | 8 ++++ .../Abstractions/IAgentRuntime.cs | 6 +-- .../Abstractions/IAgentState.cs | 20 +++++++++- .../src/Microsoft.AutoGen/Agents/AgentBase.cs | 20 +++++----- .../Agents/AgentBaseExtensions.cs | 37 ++++++++++++++----- .../Microsoft.AutoGen/Agents/AgentRuntime.cs | 6 +-- .../Agents/Agents/AIAgent/InferenceAgent.cs | 16 ++++---- .../Services/Orleans/AgentStateGrain.cs | 6 ++- 8 files changed, 81 insertions(+), 38 deletions(-) create mode 100644 dotnet/AutoGen.v3.ncrunchsolution diff --git a/dotnet/AutoGen.v3.ncrunchsolution b/dotnet/AutoGen.v3.ncrunchsolution new file mode 100644 index 000000000..13107d394 --- /dev/null +++ b/dotnet/AutoGen.v3.ncrunchsolution @@ -0,0 +1,8 @@ + + + True + True + True + True + + \ No newline at end of file diff --git a/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentRuntime.cs index aa5b5a13a..6b3d4f98c 100644 --- a/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentRuntime.cs @@ -15,8 +15,8 @@ public interface IAgentRuntime ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken = default); ValueTask SendMessageAsync(Message message, CancellationToken cancellationToken = default); ValueTask PublishEventAsync(CloudEvent @event, CancellationToken cancellationToken = default); - void Update(Activity? activity, RpcRequest request); - void Update(Activity? activity, CloudEvent cloudEvent); - (string?, string?) GetTraceIDandState(IDictionary metadata); + void Update(RpcRequest request, Activity? activity); + void Update(CloudEvent cloudEvent, Activity? activity); + (string?, string?) GetTraceIdAndState(IDictionary metadata); IDictionary ExtractMetadata(IDictionary metadata); } diff --git a/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentState.cs b/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentState.cs index 0a6784b54..1b816b4ef 100644 --- a/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentState.cs +++ b/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentState.cs @@ -3,8 +3,24 @@ namespace Microsoft.AutoGen.Abstractions; +/// +/// Interface for managing the state of an agent. +/// public interface IAgentState { - ValueTask ReadStateAsync(); - ValueTask WriteStateAsync(AgentState state, string eTag); + /// + /// Reads the current state of the agent asynchronously. + /// + /// A token to cancel the operation. + /// A task that represents the asynchronous read operation. The task result contains the current state of the agent. + ValueTask ReadStateAsync(CancellationToken cancellationToken = default); + + /// + /// Writes the specified state of the agent asynchronously. + /// + /// The state to write. + /// The ETag for concurrency control. + /// A token to cancel the operation. + /// A task that represents the asynchronous write operation. The task result contains the ETag of the written state. + ValueTask WriteStateAsync(AgentState state, string eTag, CancellationToken cancellationToken = default); } diff --git a/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs b/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs index 13b2e8519..345e6d34c 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs @@ -93,7 +93,7 @@ public abstract class AgentBase : IAgentBase, IHandle { var activity = this.ExtractActivity(msg.CloudEvent.Type, msg.CloudEvent.Metadata); await this.InvokeWithActivityAsync( - static ((AgentBase Agent, CloudEvent Item) state) => state.Agent.CallHandler(state.Item), + static ((AgentBase Agent, CloudEvent Item) state, CancellationToken _) => state.Agent.CallHandler(state.Item), (this, msg.CloudEvent), activity, msg.CloudEvent.Type, cancellationToken).ConfigureAwait(false); @@ -103,7 +103,7 @@ public abstract class AgentBase : IAgentBase, IHandle { var activity = this.ExtractActivity(msg.Request.Method, msg.Request.Metadata); await this.InvokeWithActivityAsync( - static ((AgentBase Agent, RpcRequest Request) state) => state.Agent.OnRequestCoreAsync(state.Request), + static ((AgentBase Agent, RpcRequest Request) state, CancellationToken ct) => state.Agent.OnRequestCoreAsync(state.Request, ct), (this, msg.Request), activity, msg.Request.Method, cancellationToken).ConfigureAwait(false); @@ -142,8 +142,8 @@ public abstract class AgentBase : IAgentBase, IHandle } public async Task ReadAsync(AgentId agentId, CancellationToken cancellationToken = default) where T : IMessage, new() { - var agentstate = await _context.ReadAsync(agentId, cancellationToken).ConfigureAwait(false); - return agentstate.FromAgentState(); + var agentState = await _context.ReadAsync(agentId, cancellationToken).ConfigureAwait(false); + return agentState.FromAgentState(); } private void OnResponseCore(RpcResponse response) { @@ -195,9 +195,9 @@ public abstract class AgentBase : IAgentBase, IHandle activity?.SetTag("peer.service", target.ToString()); var completion = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _context.Update(activity, request); + _context.Update(request, activity); await this.InvokeWithActivityAsync( - static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource) state) => + static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource) state, CancellationToken ct) => { var (self, request, completion) = state; @@ -206,7 +206,7 @@ public abstract class AgentBase : IAgentBase, IHandle self._pendingRequests[request.RequestId] = completion; } - await state.Agent._context.SendRequestAsync(state.Agent, state.Request).ConfigureAwait(false); + await state.Agent._context.SendRequestAsync(state.Agent, state.Request, ct).ConfigureAwait(false); await completion.Task.ConfigureAwait(false); }, @@ -231,11 +231,11 @@ public abstract class AgentBase : IAgentBase, IHandle activity?.SetTag("peer.service", $"{item.Type}/{item.Source}"); // TODO: fix activity - _context.Update(activity, item); + _context.Update(item, activity); await this.InvokeWithActivityAsync( - static async ((AgentBase Agent, CloudEvent Event) state) => + static async ((AgentBase Agent, CloudEvent Event) state, CancellationToken ct) => { - await state.Agent._context.PublishEventAsync(state.Event).ConfigureAwait(false); + await state.Agent._context.PublishEventAsync(state.Event, ct).ConfigureAwait(false); }, (this, item), activity, diff --git a/dotnet/src/Microsoft.AutoGen/Agents/AgentBaseExtensions.cs b/dotnet/src/Microsoft.AutoGen/Agents/AgentBaseExtensions.cs index ce1318a0d..5d738e5fc 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/AgentBaseExtensions.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/AgentBaseExtensions.cs @@ -5,15 +5,25 @@ using System.Diagnostics; namespace Microsoft.AutoGen.Agents; +/// +/// Provides extension methods for the class. +/// public static class AgentBaseExtensions { + /// + /// Extracts an from the given agent and metadata. + /// + /// The agent from which to extract the activity. + /// The name of the activity. + /// The metadata containing trace information. + /// The extracted or null if extraction fails. public static Activity? ExtractActivity(this AgentBase agent, string activityName, IDictionary metadata) { Activity? activity; - (var traceParent, var traceState) = agent.Context.GetTraceIDandState(metadata); + var (traceParent, traceState) = agent.Context.GetTraceIdAndState(metadata); if (!string.IsNullOrEmpty(traceParent)) { - if (ActivityContext.TryParse(traceParent, traceState, isRemote: true, out ActivityContext parentContext)) + if (ActivityContext.TryParse(traceParent, traceState, isRemote: true, out var parentContext)) { // traceParent is a W3CId activity = AgentBase.s_source.CreateActivity(activityName, ActivityKind.Server, parentContext); @@ -33,12 +43,9 @@ public static class AgentBaseExtensions var baggage = agent.Context.ExtractMetadata(metadata); - if (baggage is not null) + foreach (var baggageItem in baggage) { - foreach (var baggageItem in baggage) - { - activity.AddBaggage(baggageItem.Key, baggageItem.Value); - } + activity.AddBaggage(baggageItem.Key, baggageItem.Value); } } } @@ -49,7 +56,19 @@ public static class AgentBaseExtensions return activity; } - public static async Task InvokeWithActivityAsync(this AgentBase agent, Func func, TState state, Activity? activity, string methodName, CancellationToken cancellationToken = default) + + /// + /// Invokes a function asynchronously within the context of an . + /// + /// The type of the state parameter. + /// The agent invoking the function. + /// The function to invoke. + /// The state parameter to pass to the function. + /// The activity within which to invoke the function. + /// The name of the method being invoked. + /// A token to monitor for cancellation requests. + /// A task representing the asynchronous operation. + public static async Task InvokeWithActivityAsync(this AgentBase agent, Func func, TState state, Activity? activity, string methodName, CancellationToken cancellationToken = default) { if (activity is not null && activity.StartTimeUtc == default) { @@ -63,7 +82,7 @@ public static class AgentBaseExtensions try { - await func(state).ConfigureAwait(false); + await func(state, cancellationToken).ConfigureAwait(false); if (activity is not null && activity.IsAllDataRequested) { activity.SetStatus(ActivityStatusCode.Ok); diff --git a/dotnet/src/Microsoft.AutoGen/Agents/AgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Agents/AgentRuntime.cs index fad372ce2..c36d456af 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/AgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/AgentRuntime.cs @@ -15,7 +15,7 @@ internal sealed class AgentRuntime(AgentId agentId, IAgentWorker worker, ILogger public ILogger Logger { get; } = logger; public IAgentBase? AgentInstance { get; set; } private DistributedContextPropagator DistributedContextPropagator { get; } = distributedContextPropagator; - public (string?, string?) GetTraceIDandState(IDictionary metadata) + public (string?, string?) GetTraceIdAndState(IDictionary metadata) { DistributedContextPropagator.ExtractTraceIdAndState(metadata, static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) => @@ -28,11 +28,11 @@ internal sealed class AgentRuntime(AgentId agentId, IAgentWorker worker, ILogger out var traceState); return (traceParent, traceState); } - public void Update(Activity? activity, RpcRequest request) + public void Update(RpcRequest request, Activity? activity = null) { DistributedContextPropagator.Inject(activity, request.Metadata, static (carrier, key, value) => ((IDictionary)carrier!)[key] = value); } - public void Update(Activity? activity, CloudEvent cloudEvent) + public void Update(CloudEvent cloudEvent, Activity? activity = null) { DistributedContextPropagator.Inject(activity, cloudEvent.Metadata, static (carrier, key, value) => ((IDictionary)carrier!)[key] = value); } diff --git a/dotnet/src/Microsoft.AutoGen/Agents/Agents/AIAgent/InferenceAgent.cs b/dotnet/src/Microsoft.AutoGen/Agents/Agents/AIAgent/InferenceAgent.cs index a0383a3c2..bf68467e3 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/Agents/AIAgent/InferenceAgent.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/Agents/AIAgent/InferenceAgent.cs @@ -5,16 +5,14 @@ using Google.Protobuf; using Microsoft.AutoGen.Abstractions; using Microsoft.Extensions.AI; namespace Microsoft.AutoGen.Agents; -public abstract class InferenceAgent : AgentBase where T : IMessage, new() +public abstract class InferenceAgent( + IAgentRuntime context, + EventTypes typeRegistry, + IChatClient client) + : AgentBase(context, typeRegistry) + where T : IMessage, new() { - protected IChatClient ChatClient { get; } - public InferenceAgent( - IAgentRuntime context, - EventTypes typeRegistry, IChatClient client - ) : base(context, typeRegistry) - { - ChatClient = client; - } + protected IChatClient ChatClient { get; } = client; private Task CompleteAsync( IList chatMessages, diff --git a/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/AgentStateGrain.cs b/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/AgentStateGrain.cs index 50d8c3ad4..9905f6aeb 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/AgentStateGrain.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/AgentStateGrain.cs @@ -7,7 +7,8 @@ namespace Microsoft.AutoGen.Agents; internal sealed class AgentStateGrain([PersistentState("state", "AgentStateStore")] IPersistentState state) : Grain, IAgentState { - public async ValueTask WriteStateAsync(AgentState newState, string eTag) + /// + public async ValueTask WriteStateAsync(AgentState newState, string eTag, CancellationToken cancellationToken = default) { // etags for optimistic concurrency control // if the Etag is null, its a new state @@ -27,7 +28,8 @@ internal sealed class AgentStateGrain([PersistentState("state", "AgentStateStore return state.Etag; } - public ValueTask ReadStateAsync() + /// + public ValueTask ReadStateAsync(CancellationToken cancellationToken = default) { return ValueTask.FromResult(state.State); }