[.Net] Add Goolge gemini (#2868)

* update

* add vertex gemini test

* remove DTO

* add test for vertexGeminiAgent

* update test name

* update IGeminiClient interface

* add test for streaming

* add message connector

* add gemini message extension

* add tests

* update

* add gemnini sample

* update examples

* add test for iamge

* fix test

* add more tests

* add streaming message test

* add comment

* remove unused json

* implement google gemini client

* update

* fix comment
This commit is contained in:
Xiaoyun Zhang 2024-06-10 10:31:45 -07:00 committed by GitHub
parent 7d057a93b2
commit a16b307dc0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
44 changed files with 2530 additions and 109 deletions

View File

@ -53,6 +53,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Anthropic.Tests", "
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Anthropic.Samples", "sample\AutoGen.Anthropic.Samples\AutoGen.Anthropic.Samples.csproj", "{834B4E85-64E5-4382-8465-548F332E5298}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Gemini", "src\AutoGen.Gemini\AutoGen.Gemini.csproj", "{EFE0DC86-80FC-4D52-95B7-07654BA1A769}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Gemini.Tests", "test\AutoGen.Gemini.Tests\AutoGen.Gemini.Tests.csproj", "{8EA16BAB-465A-4C07-ABC4-1070D40067E9}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.Gemini.Sample", "sample\AutoGen.Gemini.Sample\AutoGen.Gemini.Sample.csproj", "{19679B75-CE3A-4DF0-A3F0-CA369D2760A4}"
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.AotCompatibility.Tests", "test\AutoGen.AotCompatibility.Tests\AutoGen.AotCompatibility.Tests.csproj", "{6B82F26D-5040-4453-B21B-C8D1F913CE4C}"
EndProject
Global
@ -149,6 +154,18 @@ Global
{834B4E85-64E5-4382-8465-548F332E5298}.Debug|Any CPU.Build.0 = Debug|Any CPU
{834B4E85-64E5-4382-8465-548F332E5298}.Release|Any CPU.ActiveCfg = Release|Any CPU
{834B4E85-64E5-4382-8465-548F332E5298}.Release|Any CPU.Build.0 = Release|Any CPU
{EFE0DC86-80FC-4D52-95B7-07654BA1A769}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{EFE0DC86-80FC-4D52-95B7-07654BA1A769}.Debug|Any CPU.Build.0 = Debug|Any CPU
{EFE0DC86-80FC-4D52-95B7-07654BA1A769}.Release|Any CPU.ActiveCfg = Release|Any CPU
{EFE0DC86-80FC-4D52-95B7-07654BA1A769}.Release|Any CPU.Build.0 = Release|Any CPU
{8EA16BAB-465A-4C07-ABC4-1070D40067E9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{8EA16BAB-465A-4C07-ABC4-1070D40067E9}.Debug|Any CPU.Build.0 = Debug|Any CPU
{8EA16BAB-465A-4C07-ABC4-1070D40067E9}.Release|Any CPU.ActiveCfg = Release|Any CPU
{8EA16BAB-465A-4C07-ABC4-1070D40067E9}.Release|Any CPU.Build.0 = Release|Any CPU
{19679B75-CE3A-4DF0-A3F0-CA369D2760A4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{19679B75-CE3A-4DF0-A3F0-CA369D2760A4}.Debug|Any CPU.Build.0 = Debug|Any CPU
{19679B75-CE3A-4DF0-A3F0-CA369D2760A4}.Release|Any CPU.ActiveCfg = Release|Any CPU
{19679B75-CE3A-4DF0-A3F0-CA369D2760A4}.Release|Any CPU.Build.0 = Release|Any CPU
{6B82F26D-5040-4453-B21B-C8D1F913CE4C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{6B82F26D-5040-4453-B21B-C8D1F913CE4C}.Debug|Any CPU.Build.0 = Debug|Any CPU
{6B82F26D-5040-4453-B21B-C8D1F913CE4C}.Release|Any CPU.ActiveCfg = Release|Any CPU
@ -180,6 +197,9 @@ Global
{6A95E113-B824-4524-8F13-CD0C3E1C8804} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{815E937E-86D6-4476-9EC6-B7FBCBBB5DB6} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{834B4E85-64E5-4382-8465-548F332E5298} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
{EFE0DC86-80FC-4D52-95B7-07654BA1A769} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{8EA16BAB-465A-4C07-ABC4-1070D40067E9} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{19679B75-CE3A-4DF0-A3F0-CA369D2760A4} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
{6B82F26D-5040-4453-B21B-C8D1F913CE4C} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution

View File

@ -13,12 +13,37 @@
<CSNoWarn>CS1998;CS1591</CSNoWarn>
<NoWarn>$(NoWarn);$(CSNoWarn);NU5104</NoWarn>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<IsPackable>false</IsPackable>
<EnableNetAnalyzers>true</EnableNetAnalyzers>
<EnforceCodeStyleInBuild>true</EnforceCodeStyleInBuild>
<IsTestProject>false</IsTestProject>
</PropertyGroup>
<PropertyGroup>
<RepoRoot>$(MSBuildThisFileDirectory)</RepoRoot>
</PropertyGroup>
<ItemGroup Condition="'$(IsTestProject)' == 'true'">
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestVersion)" />
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
</ItemGroup>
<ItemGroup Condition="'$(IsTestProject)' == 'true'">
<Content Include="$(RepoRoot)resource/**/*.*">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
<Link>testData/%(RecursiveDir)%(Filename)%(Extension)</Link>
</Content>
</ItemGroup>
<ItemGroup Condition="'$(IncludeResourceFolder)' == 'true'">
<Content Include="$(RepoRoot)resource/**/*.*">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
<Link>resource/%(RecursiveDir)%(Filename)%(Extension)</Link>
</Content>
</ItemGroup>
</Project>

View File

@ -12,6 +12,7 @@
<MicrosoftNETTestSdkVersion>17.7.0</MicrosoftNETTestSdkVersion>
<MicrosoftDotnetInteractive>1.0.0-beta.24229.4</MicrosoftDotnetInteractive>
<MicrosoftSourceLinkGitHubVersion>8.0.0</MicrosoftSourceLinkGitHubVersion>
<GoogleCloudAPIPlatformVersion>3.0.0</GoogleCloudAPIPlatformVersion>
<JsonSchemaVersion>4.3.0.2</JsonSchemaVersion>
</PropertyGroup>
</Project>

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8323d0b8eceb752e14c29543b2e28bb2fc648ed9719095c31b7708867a4dc918
size 491

View File

@ -93,7 +93,7 @@ The image is generated from prompt {prompt}
if (reply.GetContent() is string content && content.Contains("IMAGE_GENERATION"))
{
var imageUrl = content.Split("\n").Last();
var imageMessage = new ImageMessage(Role.Assistant, imageUrl, from: reply.From);
var imageMessage = new ImageMessage(Role.Assistant, imageUrl, from: reply.From, mimeType: "image/png");
Console.WriteLine($"download image from {imageUrl} to {imagePath}");
var httpClient = new HttpClient();

View File

@ -0,0 +1,19 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<IncludeResourceFolder>true</IncludeResourceFolder>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\AutoGen\AutoGen.csproj" />
<ProjectReference Include="..\..\src\AutoGen.Gemini\AutoGen.Gemini.csproj" />
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Chat_With_Google_Gemini.cs
using AutoGen.Core;
using AutoGen.Gemini.Middleware;
using FluentAssertions;
namespace AutoGen.Gemini.Sample;
public class Chat_With_Google_Gemini
{
public static async Task RunAsync()
{
var apiKey = Environment.GetEnvironmentVariable("GOOGLE_GEMINI_API_KEY");
if (apiKey is null)
{
Console.WriteLine("Please set GOOGLE_GEMINI_API_KEY environment variable.");
return;
}
#region Create_Gemini_Agent
var geminiAgent = new GeminiChatAgent(
name: "gemini",
model: "gemini-1.5-flash-001",
apiKey: apiKey,
systemMessage: "You are a helpful C# engineer, put your code between ```csharp and ```, don't explain the code")
.RegisterMessageConnector()
.RegisterPrintMessage();
#endregion Create_Gemini_Agent
var reply = await geminiAgent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
#region verify_reply
reply.Should().BeOfType<TextMessage>();
#endregion verify_reply
}
}

View File

@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Chat_With_Vertex_Gemini.cs
using AutoGen.Core;
using AutoGen.Gemini.Middleware;
using FluentAssertions;
namespace AutoGen.Gemini.Sample;
public class Chat_With_Vertex_Gemini
{
public static async Task RunAsync()
{
var projectID = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
if (projectID is null)
{
Console.WriteLine("Please set GCP_VERTEX_PROJECT_ID environment variable.");
return;
}
#region Create_Gemini_Agent
var geminiAgent = new GeminiChatAgent(
name: "gemini",
model: "gemini-1.5-flash-001",
location: "us-east1",
project: projectID,
systemMessage: "You are a helpful C# engineer, put your code between ```csharp and ```, don't explain the code")
.RegisterMessageConnector()
.RegisterPrintMessage();
#endregion Create_Gemini_Agent
var reply = await geminiAgent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
#region verify_reply
reply.Should().BeOfType<TextMessage>();
#endregion verify_reply
}
}

View File

@ -0,0 +1,129 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Function_Call_With_Gemini.cs
using AutoGen.Core;
using AutoGen.Gemini.Middleware;
using FluentAssertions;
using Google.Cloud.AIPlatform.V1;
namespace AutoGen.Gemini.Sample;
public partial class MovieFunction
{
/// <summary>
/// find movie titles currently playing in theaters based on any description, genre, title words, etc.
/// </summary>
/// <param name="location">The city and state, e.g. San Francisco, CA or a zip code e.g. 95616</param>
/// <param name="description">Any kind of description including category or genre, title words, attributes, etc.</param>
/// <returns></returns>
[Function]
public async Task<string> FindMovies(string location, string description)
{
// dummy implementation
var movies = new List<string> { "Barbie", "Spiderman", "Batman" };
var result = $"Movies playing in {location} based on {description} are: {string.Join(", ", movies)}";
return result;
}
/// <summary>
/// find theaters based on location and optionally movie title which is currently playing in theaters
/// </summary>
/// <param name="location">The city and state, e.g. San Francisco, CA or a zip code e.g. 95616</param>
/// <param name="movie">Any movie title</param>
[Function]
public async Task<string> FindTheaters(string location, string movie)
{
// dummy implementation
var theaters = new List<string> { "AMC", "Regal", "Cinemark" };
var result = $"Theaters playing {movie} in {location} are: {string.Join(", ", theaters)}";
return result;
}
/// <summary>
/// Find the start times for movies playing in a specific theater
/// </summary>
/// <param name="location">The city and state, e.g. San Francisco, CA or a zip code e.g. 95616</param>
/// <param name="movie">Any movie title</param>
/// <param name="theater">Name of the theater</param>
/// <param name="date">Date for requested showtime</param>
/// <returns></returns>
[Function]
public async Task<string> GetShowtimes(string location, string movie, string theater, string date)
{
// dummy implementation
var showtimes = new List<string> { "10:00 AM", "12:00 PM", "2:00 PM", "4:00 PM", "6:00 PM", "8:00 PM" };
var result = $"Showtimes for {movie} at {theater} in {location} are: {string.Join(", ", showtimes)}";
return result;
}
}
/// <summary>
/// Modified from https://ai.google.dev/gemini-api/docs/function-calling
/// </summary>
public partial class Function_Call_With_Gemini
{
public static async Task RunAsync()
{
var projectID = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
if (projectID is null)
{
Console.WriteLine("Please set GCP_VERTEX_PROJECT_ID environment variable.");
return;
}
var movieFunction = new MovieFunction();
var functionMiddleware = new FunctionCallMiddleware(
functions: [
movieFunction.FindMoviesFunctionContract,
movieFunction.FindTheatersFunctionContract,
movieFunction.GetShowtimesFunctionContract
],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ movieFunction.FindMoviesFunctionContract.Name!, movieFunction.FindMoviesWrapper },
{ movieFunction.FindTheatersFunctionContract.Name!, movieFunction.FindTheatersWrapper },
{ movieFunction.GetShowtimesFunctionContract.Name!, movieFunction.GetShowtimesWrapper },
});
#region Create_Gemini_Agent
var geminiAgent = new GeminiChatAgent(
name: "gemini",
model: "gemini-1.5-flash-001",
location: "us-central1",
project: projectID,
systemMessage: "You are a helpful AI assistant",
toolConfig: new ToolConfig()
{
FunctionCallingConfig = new FunctionCallingConfig()
{
Mode = FunctionCallingConfig.Types.Mode.Auto,
}
})
.RegisterMessageConnector()
.RegisterPrintMessage()
.RegisterStreamingMiddleware(functionMiddleware);
#endregion Create_Gemini_Agent
#region Single_turn
var question = new TextMessage(Role.User, "What movies are showing in North Seattle tonight?");
var functionCallReply = await geminiAgent.SendAsync(question);
#endregion Single_turn
#region Single_turn_verify_reply
functionCallReply.Should().BeOfType<ToolCallAggregateMessage>();
#endregion Single_turn_verify_reply
#region Multi_turn
var finalReply = await geminiAgent.SendAsync(chatHistory: [question, functionCallReply]);
#endregion Multi_turn
#region Multi_turn_verify_reply
finalReply.Should().BeOfType<TextMessage>();
#endregion Multi_turn_verify_reply
}
}

View File

@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Image_Chat_With_Vertex_Gemini.cs
using AutoGen.Core;
using AutoGen.Gemini.Middleware;
using FluentAssertions;
namespace AutoGen.Gemini.Sample;
public class Image_Chat_With_Vertex_Gemini
{
public static async Task RunAsync()
{
var projectID = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
if (projectID is null)
{
Console.WriteLine("Please set GCP_VERTEX_PROJECT_ID environment variable.");
return;
}
#region Create_Gemini_Agent
var geminiAgent = new GeminiChatAgent(
name: "gemini",
model: "gemini-1.5-flash-001",
location: "us-east4",
project: projectID,
systemMessage: "You explain image content to user")
.RegisterMessageConnector()
.RegisterPrintMessage();
#endregion Create_Gemini_Agent
#region Send_Image_Request
var imagePath = Path.Combine("resource", "images", "background.png");
var image = await File.ReadAllBytesAsync(imagePath);
var imageMessage = new ImageMessage(Role.User, BinaryData.FromBytes(image, "image/png"));
var reply = await geminiAgent.SendAsync("what's in the image", [imageMessage]);
#endregion Send_Image_Request
#region Verify_Reply
reply.Should().BeOfType<TextMessage>();
#endregion Verify_Reply
}
}

View File

@ -0,0 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Program.cs
using AutoGen.Gemini.Sample;
Image_Chat_With_Vertex_Gemini.RunAsync().Wait();

View File

@ -5,6 +5,7 @@
<ImplicitUsings>enable</ImplicitUsings>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
<NoWarn>$(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110</NoWarn>
<IncludeResourceFolder>true</IncludeResourceFolder>
</PropertyGroup>
<ItemGroup>
@ -15,10 +16,4 @@
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
</ItemGroup>
<ItemGroup>
<None Update="images\*.png">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>
</Project>

View File

@ -28,7 +28,7 @@ public class Chat_With_LLaVA
#endregion Create_Ollama_Agent
#region Send_Message
var image = Path.Combine("images", "background.png");
var image = Path.Combine("resource", "images", "background.png");
var binaryData = BinaryData.FromBytes(File.ReadAllBytes(image), "image/png");
var imageMessage = new ImageMessage(Role.User, binaryData);
var textMessage = new TextMessage(Role.User, "what's in this image?");

View File

@ -1,11 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatCompletionRequest.cs
using System.Text.Json.Serialization;
using System.Collections.Generic;
namespace AutoGen.Anthropic.DTO;
using System.Collections.Generic;
public class ChatCompletionRequest
{
[JsonPropertyName("model")]

View File

@ -7,18 +7,34 @@ namespace AutoGen.Core;
public class ImageMessage : IMessage
{
public ImageMessage(Role role, string url, string? from = null)
public ImageMessage(Role role, string url, string? from = null, string? mimeType = null)
: this(role, new Uri(url), from, mimeType)
{
this.Role = role;
this.From = from;
this.Url = url;
}
public ImageMessage(Role role, Uri uri, string? from = null)
public ImageMessage(Role role, Uri uri, string? from = null, string? mimeType = null)
{
this.Role = role;
this.From = from;
this.Url = uri.ToString();
// try infer mimeType from uri extension if not provided
if (mimeType is null)
{
mimeType = uri switch
{
_ when uri.AbsoluteUri.EndsWith(".png", StringComparison.OrdinalIgnoreCase) => "image/png",
_ when uri.AbsoluteUri.EndsWith(".jpg", StringComparison.OrdinalIgnoreCase) => "image/jpeg",
_ when uri.AbsoluteUri.EndsWith(".jpeg", StringComparison.OrdinalIgnoreCase) => "image/jpeg",
_ when uri.AbsoluteUri.EndsWith(".gif", StringComparison.OrdinalIgnoreCase) => "image/gif",
_ when uri.AbsoluteUri.EndsWith(".bmp", StringComparison.OrdinalIgnoreCase) => "image/bmp",
_ when uri.AbsoluteUri.EndsWith(".webp", StringComparison.OrdinalIgnoreCase) => "image/webp",
_ when uri.AbsoluteUri.EndsWith(".svg", StringComparison.OrdinalIgnoreCase) => "image/svg+xml",
_ => throw new ArgumentException("MimeType is required for ImageMessage", nameof(mimeType))
};
}
this.MimeType = mimeType;
}
public ImageMessage(Role role, BinaryData data, string? from = null)
@ -28,7 +44,7 @@ public class ImageMessage : IMessage
throw new ArgumentException("Data cannot be empty", nameof(data));
}
if (string.IsNullOrWhiteSpace(data.MediaType))
if (data.MediaType is null)
{
throw new ArgumentException("MediaType is needed for DataUri Images", nameof(data));
}
@ -36,15 +52,18 @@ public class ImageMessage : IMessage
this.Role = role;
this.From = from;
this.Data = data;
this.MimeType = data.MediaType;
}
public Role Role { get; set; }
public Role Role { get; }
public string? Url { get; set; }
public string? Url { get; }
public string? From { get; set; }
public BinaryData? Data { get; set; }
public BinaryData? Data { get; }
public string MimeType { get; }
public string BuildDataUri()
{
@ -53,7 +72,7 @@ public class ImageMessage : IMessage
throw new NullReferenceException($"{nameof(Data)}");
}
return $"data:{this.Data.MediaType};base64,{Convert.ToBase64String(this.Data.ToArray())}";
return $"data:{this.MimeType};base64,{Convert.ToBase64String(this.Data.ToArray())}";
}
public override string ToString()

View File

@ -0,0 +1,18 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Google.Cloud.AIPlatform.V1" Version="$(GoogleCloudAPIPlatformVersion)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\AutoGen.Core\AutoGen.Core.csproj" />
</ItemGroup>
<ItemGroup>
<InternalsVisibleTo Include="AutoGen.Gemini.Tests" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,90 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionContractExtension.cs
using System.Collections.Generic;
using System.Linq;
using AutoGen.Core;
using Google.Cloud.AIPlatform.V1;
using Json.Schema;
using Json.Schema.Generation;
using OpenAPISchemaType = Google.Cloud.AIPlatform.V1.Type;
using Type = System.Type;
namespace AutoGen.Gemini.Extension;
public static class FunctionContractExtension
{
/// <summary>
/// Convert a <see cref="FunctionContract"/> to a <see cref="FunctionDeclaration"/> that can be used in gpt funciton call.
/// </summary>
public static FunctionDeclaration ToFunctionDeclaration(this FunctionContract function)
{
var required = function.Parameters!.Where(p => p.IsRequired)
.Select(p => p.Name)
.ToList();
var parameterProperties = new Dictionary<string, OpenApiSchema>();
foreach (var parameter in function.Parameters ?? Enumerable.Empty<FunctionParameterContract>())
{
var schema = ToOpenApiSchema(parameter.ParameterType);
schema.Description = parameter.Description;
schema.Title = parameter.Name;
schema.Nullable = !parameter.IsRequired;
parameterProperties.Add(parameter.Name!, schema);
}
return new FunctionDeclaration
{
Name = function.Name,
Description = function.Description,
Parameters = new OpenApiSchema
{
Required =
{
required,
},
Properties =
{
parameterProperties,
},
Type = OpenAPISchemaType.Object,
},
};
}
private static OpenApiSchema ToOpenApiSchema(Type? type)
{
if (type == null)
{
return new OpenApiSchema
{
Type = OpenAPISchemaType.Unspecified
};
}
var schema = new JsonSchemaBuilder().FromType(type).Build();
var openApiSchema = new OpenApiSchema
{
Type = schema.GetJsonType() switch
{
SchemaValueType.Array => OpenAPISchemaType.Array,
SchemaValueType.Boolean => OpenAPISchemaType.Boolean,
SchemaValueType.Integer => OpenAPISchemaType.Integer,
SchemaValueType.Number => OpenAPISchemaType.Number,
SchemaValueType.Object => OpenAPISchemaType.Object,
SchemaValueType.String => OpenAPISchemaType.String,
_ => OpenAPISchemaType.Unspecified
},
};
if (schema.GetJsonType() == SchemaValueType.Object && schema.GetProperties() is var properties && properties != null)
{
foreach (var property in properties)
{
openApiSchema.Properties.Add(property.Key, ToOpenApiSchema(property.Value.GetType()));
}
}
return openApiSchema;
}
}

View File

@ -0,0 +1,268 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GeminiChatAgent.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Core;
using AutoGen.Gemini.Extension;
using Google.Cloud.AIPlatform.V1;
using Google.Protobuf.Collections;
namespace AutoGen.Gemini;
public class GeminiChatAgent : IStreamingAgent
{
private readonly IGeminiClient client;
private readonly string? systemMessage;
private readonly string model;
private readonly ToolConfig? toolConfig;
private readonly RepeatedField<SafetySetting>? safetySettings;
private readonly string responseMimeType;
private readonly Tool[]? tools;
/// <summary>
/// Create <see cref="GeminiChatAgent"/> that connects to Gemini.
/// </summary>
/// <param name="client">the gemini client to use. e.g. <see cref="VertexGeminiClient"/> </param>
/// <param name="name">agent name</param>
/// <param name="model">the model id. It needs to be in the format of
/// 'projects/{project}/locations/{location}/publishers/{provider}/models/{model}' if the <paramref name="client"/> is <see cref="VertexGeminiClient"/></param>
/// <param name="systemMessage">system message</param>
/// <param name="toolConfig">tool config</param>
/// <param name="tools">tools</param>
/// <param name="safetySettings">safety settings</param>
/// <param name="responseMimeType">response mime type, available values are ['application/json', 'text/plain'], default is 'text/plain'</param>
public GeminiChatAgent(
IGeminiClient client,
string name,
string model,
string? systemMessage = null,
ToolConfig? toolConfig = null,
Tool[]? tools = null,
RepeatedField<SafetySetting>? safetySettings = null,
string responseMimeType = "text/plain")
{
this.client = client;
this.Name = name;
this.systemMessage = systemMessage;
this.model = model;
this.toolConfig = toolConfig;
this.safetySettings = safetySettings;
this.responseMimeType = responseMimeType;
this.tools = tools;
}
/// <summary>
/// Create <see cref="GeminiChatAgent"/> that connects to Gemini using <see cref="GoogleGeminiClient"/>
/// </summary>
/// <param name="name">agent name</param>
/// <param name="model">the name of gemini model, e.g. gemini-1.5-flash-001</param>
/// <param name="apiKey">google gemini api key</param>
/// <param name="systemMessage">system message</param>
/// <param name="toolConfig">tool config</param>
/// <param name="tools">tools</param>
/// <param name="safetySettings"></param>
/// <param name="responseMimeType">response mime type, available values are ['application/json', 'text/plain'], default is 'text/plain'</param>
/// /// <example>
/// <![CDATA[
/// [!code-csharp[Chat_With_Google_Gemini](~/../sample/AutoGen.Gemini.Sample/Chat_With_Google_Gemini.cs?name=Create_Gemini_Agent)]
/// ]]>
/// </example>
public GeminiChatAgent(
string name,
string model,
string apiKey,
string systemMessage = "You are a helpful AI assistant",
ToolConfig? toolConfig = null,
Tool[]? tools = null,
RepeatedField<SafetySetting>? safetySettings = null,
string responseMimeType = "text/plain")
: this(
client: new GoogleGeminiClient(apiKey),
name: name,
model: model,
systemMessage: systemMessage,
toolConfig: toolConfig,
tools: tools,
safetySettings: safetySettings,
responseMimeType: responseMimeType)
{
}
/// <summary>
/// Create <see cref="GeminiChatAgent"/> that connects to Vertex AI.
/// </summary>
/// <param name="name">agent name</param>
/// <param name="systemMessage">system message</param>
/// <param name="model">the name of gemini model, e.g. gemini-1.5-flash-001</param>
/// <param name="project">project id</param>
/// <param name="location">model location</param>
/// <param name="provider">model provider, default is 'google'</param>
/// <param name="toolConfig">tool config</param>
/// <param name="tools">tools</param>
/// <param name="safetySettings"></param>
/// <param name="responseMimeType">response mime type, available values are ['application/json', 'text/plain'], default is 'text/plain'</param>
/// <example>
/// <![CDATA[
/// [!code-csharp[Chat_With_Vertex_Gemini](~/../sample/AutoGen.Gemini.Sample/Chat_With_Vertex_Gemini.cs?name=Create_Gemini_Agent)]
/// ]]>
/// </example>
public GeminiChatAgent(
string name,
string model,
string project,
string location,
string provider = "google",
string? systemMessage = null,
ToolConfig? toolConfig = null,
Tool[]? tools = null,
RepeatedField<SafetySetting>? safetySettings = null,
string responseMimeType = "text/plain")
: this(
client: new VertexGeminiClient(location),
name: name,
model: $"projects/{project}/locations/{location}/publishers/{provider}/models/{model}",
systemMessage: systemMessage,
toolConfig: toolConfig,
tools: tools,
safetySettings: safetySettings,
responseMimeType: responseMimeType)
{
}
public string Name { get; }
public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
var request = BuildChatRequest(messages, options);
var response = await this.client.GenerateContentAsync(request, cancellationToken: cancellationToken).ConfigureAwait(false);
return MessageEnvelope.Create(response, this.Name);
}
public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var request = BuildChatRequest(messages, options);
var response = this.client.GenerateContentStreamAsync(request);
await foreach (var item in response.WithCancellation(cancellationToken).ConfigureAwait(false))
{
yield return MessageEnvelope.Create(item, this.Name);
}
}
private GenerateContentRequest BuildChatRequest(IEnumerable<IMessage> messages, GenerateReplyOptions? options)
{
var geminiMessages = messages.Select(m => m switch
{
IMessage<Content> contentMessage => contentMessage.Content,
_ => throw new NotSupportedException($"Message type {m.GetType()} is not supported.")
});
// there are several rules applies to the messages that can be sent to Gemini in a multi-turn chat
// - The first message must be from the user or function
// - The (user|model) roles must alternate e.g. (user, model, user, model, ...)
// - The last message must be from the user or function
// check if the first message is from the user
if (geminiMessages.FirstOrDefault()?.Role != "user" && geminiMessages.FirstOrDefault()?.Role != "function")
{
throw new ArgumentException("The first message must be from the user or function", nameof(messages));
}
// check if the last message is from the user
if (geminiMessages.LastOrDefault()?.Role != "user" && geminiMessages.LastOrDefault()?.Role != "function")
{
throw new ArgumentException("The last message must be from the user or function", nameof(messages));
}
// merge continuous messages with the same role into one message
var mergedMessages = geminiMessages.Aggregate(new List<Content>(), (acc, message) =>
{
if (acc.Count == 0 || acc.Last().Role != message.Role)
{
acc.Add(message);
}
else
{
acc.Last().Parts.AddRange(message.Parts);
}
return acc;
});
var systemMessage = this.systemMessage switch
{
null => null,
string message => new Content
{
Parts = { new[] { new Part { Text = message } } },
Role = "system_instruction"
}
};
List<Tool> tools = this.tools?.ToList() ?? new List<Tool>();
var request = new GenerateContentRequest()
{
Contents = { mergedMessages },
SystemInstruction = systemMessage,
Model = this.model,
GenerationConfig = new GenerationConfig
{
StopSequences = { options?.StopSequence ?? Enumerable.Empty<string>() },
ResponseMimeType = this.responseMimeType,
CandidateCount = 1,
},
};
if (this.toolConfig is not null)
{
request.ToolConfig = this.toolConfig;
}
if (this.safetySettings is not null)
{
request.SafetySettings.Add(this.safetySettings);
}
if (options?.MaxToken.HasValue is true)
{
request.GenerationConfig.MaxOutputTokens = options.MaxToken.Value;
}
if (options?.Temperature.HasValue is true)
{
request.GenerationConfig.Temperature = options.Temperature.Value;
}
if (options?.Functions is { Length: > 0 })
{
foreach (var function in options.Functions)
{
tools.Add(new Tool
{
FunctionDeclarations = { function.ToFunctionDeclaration() },
});
}
}
// merge tools into one tool
// because multipe tools are currently not supported by Gemini
// see https://github.com/googleapis/python-aiplatform/issues/3771
var aggregatedTool = new Tool
{
FunctionDeclarations = { tools.SelectMany(t => t.FunctionDeclarations) },
};
if (aggregatedTool is { FunctionDeclarations: { Count: > 0 } })
{
request.Tools.Add(aggregatedTool);
}
return request;
}
}

View File

@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GoogleGeminiClient.cs
using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Google.Cloud.AIPlatform.V1;
using Google.Protobuf;
namespace AutoGen.Gemini;
public class GoogleGeminiClient : IGeminiClient
{
private readonly string apiKey;
private const string endpoint = "https://generativelanguage.googleapis.com/v1beta";
private readonly HttpClient httpClient = new();
private const string generateContentPath = "models/{0}:generateContent";
private const string generateContentStreamPath = "models/{0}:streamGenerateContent";
public GoogleGeminiClient(HttpClient httpClient, string apiKey)
{
this.apiKey = apiKey;
this.httpClient = httpClient;
}
public GoogleGeminiClient(string apiKey)
{
this.apiKey = apiKey;
}
public async Task<GenerateContentResponse> GenerateContentAsync(GenerateContentRequest request, CancellationToken cancellationToken = default)
{
var path = string.Format(generateContentPath, request.Model);
var url = $"{endpoint}/{path}?key={apiKey}";
var httpContent = new StringContent(JsonFormatter.Default.Format(request), System.Text.Encoding.UTF8, "application/json");
var response = await httpClient.PostAsync(url, httpContent, cancellationToken);
if (!response.IsSuccessStatusCode)
{
throw new Exception($"Failed to generate content. Status code: {response.StatusCode}");
}
var json = await response.Content.ReadAsStringAsync();
return GenerateContentResponse.Parser.ParseJson(json);
}
public async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(GenerateContentRequest request)
{
var path = string.Format(generateContentStreamPath, request.Model);
var url = $"{endpoint}/{path}?key={apiKey}&alt=sse";
var httpContent = new StringContent(JsonFormatter.Default.Format(request), System.Text.Encoding.UTF8, "application/json");
var requestMessage = new HttpRequestMessage(HttpMethod.Post, url)
{
Content = httpContent
};
var response = await httpClient.SendAsync(requestMessage, HttpCompletionOption.ResponseHeadersRead);
if (!response.IsSuccessStatusCode)
{
throw new Exception($"Failed to generate content. Status code: {response.StatusCode}");
}
var stream = await response.Content.ReadAsStreamAsync();
var jp = new JsonParser(JsonParser.Settings.Default.WithIgnoreUnknownFields(true));
using var streamReader = new System.IO.StreamReader(stream);
while (!streamReader.EndOfStream)
{
var json = await streamReader.ReadLineAsync();
if (string.IsNullOrWhiteSpace(json))
{
continue;
}
json = json.Substring("data:".Length).Trim();
yield return jp.Parse<GenerateContentResponse>(json);
}
}
}

View File

@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IVertexGeminiClient.cs
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Google.Cloud.AIPlatform.V1;
namespace AutoGen.Gemini;
public interface IGeminiClient
{
Task<GenerateContentResponse> GenerateContentAsync(GenerateContentRequest request, CancellationToken cancellationToken = default);
IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(GenerateContentRequest request);
}

View File

@ -0,0 +1,40 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GeminiAgentExtension.cs
using AutoGen.Core;
namespace AutoGen.Gemini.Middleware;
public static class GeminiAgentExtension
{
/// <summary>
/// Register an <see cref="GeminiMessageConnector"/> to the <see cref="GeminiChatAgent"/>
/// </summary>
/// <param name="connector">the connector to use. If null, a new instance of <see cref="GeminiMessageConnector"/> will be created.</param>
public static MiddlewareStreamingAgent<GeminiChatAgent> RegisterMessageConnector(
this GeminiChatAgent agent, GeminiMessageConnector? connector = null)
{
if (connector == null)
{
connector = new GeminiMessageConnector();
}
return agent.RegisterStreamingMiddleware(connector);
}
/// <summary>
/// Register an <see cref="GeminiMessageConnector"/> to the <see cref="MiddlewareAgent{T}"/> where T is <see cref="GeminiChatAgent"/>
/// </summary>
/// <param name="connector">the connector to use. If null, a new instance of <see cref="GeminiMessageConnector"/> will be created.</param>
public static MiddlewareStreamingAgent<GeminiChatAgent> RegisterMessageConnector(
this MiddlewareStreamingAgent<GeminiChatAgent> agent, GeminiMessageConnector? connector = null)
{
if (connector == null)
{
connector = new GeminiMessageConnector();
}
return agent.RegisterStreamingMiddleware(connector);
}
}

View File

@ -0,0 +1,483 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GeminiMessageConnector.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Core;
using Google.Cloud.AIPlatform.V1;
using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;
using static Google.Cloud.AIPlatform.V1.Candidate.Types;
using IMessage = AutoGen.Core.IMessage;
namespace AutoGen.Gemini.Middleware;
public class GeminiMessageConnector : IStreamingMiddleware
{
/// <summary>
/// if true, the connector will throw an exception if it encounters an unsupport message type.
/// Otherwise, it will ignore processing the message and return the message as is.
/// </summary>
private readonly bool strictMode;
/// <summary>
/// Initializes a new instance of the <see cref="GeminiMessageConnector"/> class.
/// </summary>
/// <param name="strictMode">whether to throw an exception if it encounters an unsupport message type.
/// If true, the connector will throw an exception if it encounters an unsupport message type.
/// If false, it will ignore processing the message and return the message as is.</param>
public GeminiMessageConnector(bool strictMode = false)
{
this.strictMode = strictMode;
}
public string Name => nameof(GeminiMessageConnector);
public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = ProcessMessage(context.Messages, agent);
var bucket = new List<GenerateContentResponse>();
await foreach (var reply in agent.GenerateStreamingReplyAsync(messages, context.Options, cancellationToken))
{
if (reply is Core.IMessage<GenerateContentResponse> m)
{
// if m.Content is empty and stop reason is Stop, ignore the message
if (m.Content.Candidates.Count == 1 && m.Content.Candidates[0].Content.Parts.Count == 1 && m.Content.Candidates[0].Content.Parts[0].DataCase == Part.DataOneofCase.Text)
{
var text = m.Content.Candidates[0].Content.Parts[0].Text;
var stopReason = m.Content.Candidates[0].FinishReason;
if (string.IsNullOrEmpty(text) && stopReason == FinishReason.Stop)
{
continue;
}
}
bucket.Add(m.Content);
yield return PostProcessStreamingMessage(m.Content, agent);
}
else if (strictMode)
{
throw new InvalidOperationException($"Unsupported message type: {reply.GetType()}");
}
else
{
yield return reply;
}
// aggregate the message updates from bucket into a single message
if (bucket is { Count: > 0 })
{
var isTextMessageUpdates = bucket.All(m => m.Candidates.Count == 1 && m.Candidates[0].Content.Parts.Count == 1 && m.Candidates[0].Content.Parts[0].DataCase == Part.DataOneofCase.Text);
var isFunctionCallUpdates = bucket.Any(m => m.Candidates.Count == 1 && m.Candidates[0].Content.Parts.Count == 1 && m.Candidates[0].Content.Parts[0].DataCase == Part.DataOneofCase.FunctionCall);
if (isTextMessageUpdates)
{
var text = string.Join(string.Empty, bucket.Select(m => m.Candidates[0].Content.Parts[0].Text));
var textMessage = new TextMessage(Role.Assistant, text, agent.Name);
yield return textMessage;
}
else if (isFunctionCallUpdates)
{
var functionCallParts = bucket.Where(m => m.Candidates.Count == 1 && m.Candidates[0].Content.Parts.Count == 1 && m.Candidates[0].Content.Parts[0].DataCase == Part.DataOneofCase.FunctionCall)
.Select(m => m.Candidates[0].Content.Parts[0]).ToList();
var toolCalls = new List<ToolCall>();
foreach (var part in functionCallParts)
{
var fc = part.FunctionCall;
var toolCall = new ToolCall(fc.Name, fc.Args.ToString());
toolCalls.Add(toolCall);
}
var toolCallMessage = new ToolCallMessage(toolCalls, agent.Name);
yield return toolCallMessage;
}
else
{
throw new InvalidOperationException("The response should contain either text or tool calls.");
}
}
}
}
public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
{
var messages = ProcessMessage(context.Messages, agent);
var reply = await agent.GenerateReplyAsync(messages, context.Options, cancellationToken);
return reply switch
{
Core.IMessage<GenerateContentResponse> m => PostProcessMessage(m.Content, agent),
_ when strictMode => throw new InvalidOperationException($"Unsupported message type: {reply.GetType()}"),
_ => reply,
};
}
private IMessage PostProcessStreamingMessage(GenerateContentResponse m, IAgent agent)
{
this.ValidateGenerateContentResponse(m);
var candidate = m.Candidates[0];
var parts = candidate.Content.Parts;
if (parts.Count == 1 && parts[0].DataCase == Part.DataOneofCase.Text)
{
var content = parts[0].Text;
return new TextMessageUpdate(Role.Assistant, content, agent.Name);
}
else
{
var toolCalls = new List<ToolCall>();
foreach (var part in parts)
{
if (part.DataCase == Part.DataOneofCase.FunctionCall)
{
var fc = part.FunctionCall;
var toolCall = new ToolCall(fc.Name, fc.Args.ToString());
toolCalls.Add(toolCall);
}
}
if (toolCalls.Count > 0)
{
var toolCallMessage = new ToolCallMessage(toolCalls, agent.Name);
return toolCallMessage;
}
else
{
throw new InvalidOperationException("The response should contain either text or tool calls.");
}
}
}
private IMessage PostProcessMessage(GenerateContentResponse m, IAgent agent)
{
this.ValidateGenerateContentResponse(m);
var candidate = m.Candidates[0];
var parts = candidate.Content.Parts;
if (parts.Count == 1 && parts[0].DataCase == Part.DataOneofCase.Text)
{
var content = parts[0].Text;
return new TextMessage(Role.Assistant, content, agent.Name);
}
else
{
var toolCalls = new List<ToolCall>();
foreach (var part in parts)
{
if (part.DataCase == Part.DataOneofCase.FunctionCall)
{
var fc = part.FunctionCall;
var toolCall = new ToolCall(fc.Name, fc.Args.ToString());
toolCalls.Add(toolCall);
}
}
if (toolCalls.Count > 0)
{
var toolCallMessage = new ToolCallMessage(toolCalls, agent.Name);
return toolCallMessage;
}
else
{
throw new InvalidOperationException("The response should contain either text or tool calls.");
}
}
}
private IEnumerable<IMessage> ProcessMessage(IEnumerable<IMessage> messages, IAgent agent)
{
return messages.SelectMany(m =>
{
if (m is Core.IMessage<Content> messageEnvelope)
{
return [m];
}
else
{
return m switch
{
TextMessage textMessage => ProcessTextMessage(textMessage, agent),
ImageMessage imageMessage => ProcessImageMessage(imageMessage, agent),
MultiModalMessage multiModalMessage => ProcessMultiModalMessage(multiModalMessage, agent),
ToolCallMessage toolCallMessage => ProcessToolCallMessage(toolCallMessage, agent),
ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage, agent),
ToolCallAggregateMessage toolCallAggregateMessage => ProcessToolCallAggregateMessage(toolCallAggregateMessage, agent),
_ when strictMode => throw new InvalidOperationException($"Unsupported message type: {m.GetType()}"),
_ => [m],
};
}
});
}
private IEnumerable<IMessage> ProcessToolCallAggregateMessage(ToolCallAggregateMessage toolCallAggregateMessage, IAgent agent)
{
var parseAsUser = ShouldParseAsUser(toolCallAggregateMessage, agent);
if (parseAsUser)
{
var content = toolCallAggregateMessage.GetContent();
if (content is string str)
{
var textMessage = new TextMessage(Role.User, str, toolCallAggregateMessage.From);
return ProcessTextMessage(textMessage, agent);
}
return [];
}
else
{
var toolCallContents = ProcessToolCallMessage(toolCallAggregateMessage.Message1, agent);
var toolCallResultContents = ProcessToolCallResultMessage(toolCallAggregateMessage.Message2, agent);
return toolCallContents.Concat(toolCallResultContents);
}
}
private void ValidateGenerateContentResponse(GenerateContentResponse response)
{
if (response.Candidates.Count != 1)
{
throw new InvalidOperationException("The response should contain exactly one candidate.");
}
var candidate = response.Candidates[0];
if (candidate.Content is null)
{
var finishReason = candidate.FinishReason;
var finishMessage = candidate.FinishMessage;
throw new InvalidOperationException($"The response should contain content but the content is empty. FinishReason: {finishReason}, FinishMessage: {finishMessage}");
}
}
private IEnumerable<IMessage> ProcessToolCallResultMessage(ToolCallResultMessage toolCallResultMessage, IAgent agent)
{
var functionCallResultParts = new List<Part>();
foreach (var toolCallResult in toolCallResultMessage.ToolCalls)
{
if (toolCallResult.Result is null)
{
continue;
}
// if result is already a json object, use it as is
var json = toolCallResult.Result;
try
{
JsonNode.Parse(json);
}
catch (JsonException)
{
// if the result is not a json object, wrap it in a json object
var result = new { result = json };
json = JsonSerializer.Serialize(result);
}
var part = new Part
{
FunctionResponse = new FunctionResponse
{
Name = toolCallResult.FunctionName,
Response = Struct.Parser.ParseJson(json),
}
};
functionCallResultParts.Add(part);
}
var content = new Content
{
Parts = { functionCallResultParts },
Role = "function",
};
return [MessageEnvelope.Create(content, toolCallResultMessage.From)];
}
private IEnumerable<IMessage> ProcessToolCallMessage(ToolCallMessage toolCallMessage, IAgent agent)
{
var shouldParseAsUser = ShouldParseAsUser(toolCallMessage, agent);
if (strictMode && shouldParseAsUser)
{
throw new InvalidOperationException("ToolCallMessage is not supported as user role in Gemini.");
}
var functionCallParts = new List<Part>();
foreach (var toolCall in toolCallMessage.ToolCalls)
{
var part = new Part
{
FunctionCall = new FunctionCall
{
Name = toolCall.FunctionName,
Args = Struct.Parser.ParseJson(toolCall.FunctionArguments),
}
};
functionCallParts.Add(part);
}
var content = new Content
{
Parts = { functionCallParts },
Role = "model"
};
return [MessageEnvelope.Create(content, toolCallMessage.From)];
}
private IEnumerable<IMessage> ProcessMultiModalMessage(MultiModalMessage multiModalMessage, IAgent agent)
{
var parts = new List<Part>();
foreach (var message in multiModalMessage.Content)
{
if (message is TextMessage textMessage)
{
parts.Add(new Part { Text = textMessage.Content });
}
else if (message is ImageMessage imageMessage)
{
parts.Add(CreateImagePart(imageMessage));
}
else
{
throw new InvalidOperationException($"Unsupported message type: {message.GetType()}");
}
}
var shouldParseAsUser = ShouldParseAsUser(multiModalMessage, agent);
if (strictMode && !shouldParseAsUser)
{
// image message is not supported as model role in Gemini
throw new InvalidOperationException("Image message is not supported as model role in Gemini.");
}
var content = new Content
{
Parts = { parts },
Role = shouldParseAsUser ? "user" : "model",
};
return [MessageEnvelope.Create(content, multiModalMessage.From)];
}
private IEnumerable<IMessage> ProcessTextMessage(TextMessage textMessage, IAgent agent)
{
if (textMessage.Role == Role.System)
{
// there are only user | model role in Gemini
// if the role is system and the strict mode is enabled, throw an exception
if (strictMode)
{
throw new InvalidOperationException("System role is not supported in Gemini.");
}
// if strict mode is not enabled, parse the message as a user message
var content = new Content
{
Parts = { new[] { new Part { Text = textMessage.Content } } },
Role = "user",
};
return [MessageEnvelope.Create(content, textMessage.From)];
}
var shouldParseAsUser = ShouldParseAsUser(textMessage, agent);
if (shouldParseAsUser)
{
var content = new Content
{
Parts = { new[] { new Part { Text = textMessage.Content } } },
Role = "user",
};
return [MessageEnvelope.Create(content, textMessage.From)];
}
else
{
var content = new Content
{
Parts = { new[] { new Part { Text = textMessage.Content } } },
Role = "model",
};
return [MessageEnvelope.Create(content, textMessage.From)];
}
}
private IEnumerable<IMessage> ProcessImageMessage(ImageMessage imageMessage, IAgent agent)
{
var imagePart = CreateImagePart(imageMessage);
var shouldParseAsUser = ShouldParseAsUser(imageMessage, agent);
if (strictMode && !shouldParseAsUser)
{
// image message is not supported as model role in Gemini
throw new InvalidOperationException("Image message is not supported as model role in Gemini.");
}
var content = new Content
{
Parts = { imagePart },
Role = shouldParseAsUser ? "user" : "model",
};
return [MessageEnvelope.Create(content, imageMessage.From)];
}
private Part CreateImagePart(ImageMessage message)
{
if (message.Url is string url)
{
return new Part
{
FileData = new FileData
{
FileUri = url,
MimeType = message.MimeType
}
};
}
else if (message.Data is BinaryData data)
{
return new Part
{
InlineData = new Blob
{
MimeType = message.MimeType,
Data = ByteString.CopyFrom(data.ToArray()),
}
};
}
else
{
throw new InvalidOperationException("Invalid ImageMessage, the data or url must be provided");
}
}
private bool ShouldParseAsUser(IMessage message, IAgent agent)
{
return message switch
{
TextMessage textMessage => (textMessage.Role == Role.User && textMessage.From is null)
|| (textMessage.From != agent.Name),
_ => message.From != agent.Name,
};
}
}

View File

@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IGeminiClient.cs
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Google.Cloud.AIPlatform.V1;
namespace AutoGen.Gemini;
internal class VertexGeminiClient : IGeminiClient
{
private readonly PredictionServiceClient client;
public VertexGeminiClient(PredictionServiceClient client)
{
this.client = client;
}
public VertexGeminiClient(string location)
{
PredictionServiceClientBuilder builder = new()
{
Endpoint = $"{location}-aiplatform.googleapis.com",
};
this.client = builder.Build();
}
public Task<GenerateContentResponse> GenerateContentAsync(GenerateContentRequest request, CancellationToken cancellationToken = default)
{
return client.GenerateContentAsync(request, cancellationToken);
}
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(GenerateContentRequest request)
{
return client.StreamGenerateContent(request).GetResponseStream();
}
}

View File

@ -1,4 +1,4 @@
using System.Text;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using AutoGen.Anthropic.DTO;
@ -43,7 +43,7 @@ public class AnthropicClientTests
request.Model = AnthropicConstants.Claude3Haiku;
request.Stream = true;
request.MaxTokens = 500;
request.SystemMessage = "You are a helpful assistant that convert input to json object";
request.SystemMessage = "You are a helpful assistant that convert input to json object, use JSON format.";
request.Messages = new List<ChatMessage>()
{
new("user", "name: John, age: 41, email: g123456@gmail.com")

View File

@ -6,16 +6,9 @@
<IsPackable>false</IsPackable>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
<RootNamespace>AutoGen.Anthropic.Tests</RootNamespace>
<IsTestProject>True</IsTestProject>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\AutoGen.Anthropic\AutoGen.Anthropic.csproj" />
<ProjectReference Include="..\AutoGen.Tests\AutoGen.Tests.csproj" />

View File

@ -4,18 +4,10 @@
<TargetFramework>$(TestTargetFramework)</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<IsPackable>false</IsPackable>
<IsTestProject>True</IsTestProject>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestVersion)" />
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
<ProjectReference Include="..\AutoGen.Tests\AutoGen.Tests.csproj" />

View File

@ -0,0 +1,17 @@
{
"name": "GetWeatherAsync",
"description": "Get weather for a city.",
"parameters": {
"type": "OBJECT",
"properties": {
"city": {
"type": "STRING",
"description": "city",
"title": "city"
}
},
"required": [
"city"
]
}
}

View File

@ -0,0 +1,19 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>$(TestTargetFramework)</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
<IsTestProject>True</IsTestProject>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\..\sample\AutoGen.Gemini.Sample\AutoGen.Gemini.Sample.csproj" />
<ProjectReference Include="..\..\src\AutoGen.Gemini\AutoGen.Gemini.csproj" />
<ProjectReference Include="..\AutoGen.Tests\AutoGen.Tests.csproj" />
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionContractExtensionTests.cs
using ApprovalTests;
using ApprovalTests.Namers;
using ApprovalTests.Reporters;
using AutoGen.Gemini.Extension;
using Google.Protobuf;
using Xunit;
namespace AutoGen.Gemini.Tests;
public class FunctionContractExtensionTests
{
private readonly Functions functions = new Functions();
[Fact]
[UseReporter(typeof(DiffReporter))]
[UseApprovalSubdirectory("ApprovalTests")]
public void ItGenerateGetWeatherToolTest()
{
var contract = functions.GetWeatherAsyncFunctionContract;
var tool = contract.ToFunctionDeclaration();
var formatter = new JsonFormatter(JsonFormatter.Settings.Default.WithIndentation(" "));
var json = formatter.Format(tool);
Approvals.Verify(json);
}
}

View File

@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Functions.cs
using AutoGen.Core;
namespace AutoGen.Gemini.Tests;
public partial class Functions
{
/// <summary>
/// Get weather for a city.
/// </summary>
/// <param name="city">city</param>
/// <returns>weather</returns>
[Function]
public async Task<string> GetWeatherAsync(string city)
{
return await Task.FromResult($"The weather in {city} is sunny.");
}
[Function]
public async Task<string> GetMovies(string location, string description)
{
var movies = new List<string> { "Barbie", "Spiderman", "Batman" };
return await Task.FromResult($"Movies playing in {location} based on {description} are: {string.Join(", ", movies)}");
}
}

View File

@ -0,0 +1,311 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GeminiAgentTests.cs
using AutoGen.Tests;
using Google.Cloud.AIPlatform.V1;
using AutoGen.Core;
using FluentAssertions;
using AutoGen.Gemini.Extension;
using static Google.Cloud.AIPlatform.V1.Part;
using Xunit.Abstractions;
using AutoGen.Gemini.Middleware;
namespace AutoGen.Gemini.Tests;
public class GeminiAgentTests
{
private readonly Functions functions = new Functions();
private readonly ITestOutputHelper _output;
public GeminiAgentTests(ITestOutputHelper output)
{
_output = output;
}
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
public async Task VertexGeminiAgentGenerateReplyForTextContentAsync()
{
var location = "us-central1";
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ?? throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
var model = "gemini-1.5-flash-001";
var textContent = new Content
{
Role = "user",
Parts =
{
new Part
{
Text = "Hello",
}
}
};
var agent = new GeminiChatAgent(
name: "assistant",
model: model,
project: project,
location: location,
systemMessage: "You are a helpful AI assistant");
var message = MessageEnvelope.Create(textContent, from: agent.Name);
var completion = await agent.SendAsync(message);
completion.Should().BeOfType<MessageEnvelope<GenerateContentResponse>>();
completion.From.Should().Be(agent.Name);
var response = ((MessageEnvelope<GenerateContentResponse>)completion).Content;
response.Should().NotBeNull();
response.Candidates.Count.Should().BeGreaterThan(0);
response.Candidates[0].Content.Parts[0].Text.Should().NotBeNullOrEmpty();
}
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
public async Task VertexGeminiAgentGenerateStreamingReplyForTextContentAsync()
{
var location = "us-central1";
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ?? throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
var model = "gemini-1.5-flash-001";
var textContent = new Content
{
Role = "user",
Parts =
{
new Part
{
Text = "Hello",
}
}
};
var agent = new GeminiChatAgent(
name: "assistant",
model: model,
project: project,
location: location,
systemMessage: "You are a helpful AI assistant");
var message = MessageEnvelope.Create(textContent, from: agent.Name);
var completion = agent.GenerateStreamingReplyAsync([message]);
var chunks = new List<IStreamingMessage>();
IStreamingMessage finalReply = null!;
await foreach (var item in completion)
{
item.Should().NotBeNull();
item.From.Should().Be(agent.Name);
var streamingMessage = (IMessage<GenerateContentResponse>)item;
streamingMessage.Content.Candidates.Should().NotBeNullOrEmpty();
chunks.Add(item);
finalReply = item;
}
chunks.Count.Should().BeGreaterThan(0);
finalReply.Should().NotBeNull();
finalReply.Should().BeOfType<MessageEnvelope<GenerateContentResponse>>();
var response = ((MessageEnvelope<GenerateContentResponse>)finalReply).Content;
response.UsageMetadata.CandidatesTokenCount.Should().BeGreaterThan(0);
}
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
public async Task VertexGeminiAgentGenerateReplyWithToolsAsync()
{
var location = "us-central1";
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ?? throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
var model = "gemini-1.5-flash-001";
var tools = new Tool[]
{
new Tool
{
FunctionDeclarations = {
functions.GetWeatherAsyncFunctionContract.ToFunctionDeclaration(),
},
},
new Tool
{
FunctionDeclarations =
{
functions.GetMoviesFunctionContract.ToFunctionDeclaration(),
},
},
};
var textContent = new Content
{
Role = "user",
Parts =
{
new Part
{
Text = "what's the weather in seattle",
}
}
};
var agent = new GeminiChatAgent(
name: "assistant",
model: model,
project: project,
location: location,
systemMessage: "You are a helpful AI assistant",
tools: tools,
toolConfig: new ToolConfig()
{
FunctionCallingConfig = new FunctionCallingConfig()
{
Mode = FunctionCallingConfig.Types.Mode.Auto,
}
});
var message = MessageEnvelope.Create(textContent, from: agent.Name);
var completion = await agent.SendAsync(message);
completion.Should().BeOfType<MessageEnvelope<GenerateContentResponse>>();
completion.From.Should().Be(agent.Name);
var response = ((MessageEnvelope<GenerateContentResponse>)completion).Content;
response.Should().NotBeNull();
response.Candidates.Count.Should().BeGreaterThan(0);
response.Candidates[0].Content.Parts[0].DataCase.Should().Be(DataOneofCase.FunctionCall);
}
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
public async Task VertexGeminiAgentGenerateStreamingReplyWithToolsAsync()
{
var location = "us-central1";
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ?? throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
var model = "gemini-1.5-flash-001";
var tools = new Tool[]
{
new Tool
{
FunctionDeclarations = { functions.GetWeatherAsyncFunctionContract.ToFunctionDeclaration() },
},
};
var textContent = new Content
{
Role = "user",
Parts =
{
new Part
{
Text = "what's the weather in seattle",
}
}
};
var agent = new GeminiChatAgent(
name: "assistant",
model: model,
project: project,
location: location,
systemMessage: "You are a helpful AI assistant",
tools: tools,
toolConfig: new ToolConfig()
{
FunctionCallingConfig = new FunctionCallingConfig()
{
Mode = FunctionCallingConfig.Types.Mode.Auto,
}
});
var message = MessageEnvelope.Create(textContent, from: agent.Name);
var chunks = new List<IStreamingMessage>();
IStreamingMessage finalReply = null!;
var completion = agent.GenerateStreamingReplyAsync([message]);
await foreach (var item in completion)
{
item.Should().NotBeNull();
item.From.Should().Be(agent.Name);
var streamingMessage = (IMessage<GenerateContentResponse>)item;
streamingMessage.Content.Candidates.Should().NotBeNullOrEmpty();
if (streamingMessage.Content.Candidates[0].FinishReason != Candidate.Types.FinishReason.Stop)
{
streamingMessage.Content.Candidates[0].Content.Parts[0].DataCase.Should().Be(DataOneofCase.FunctionCall);
}
chunks.Add(item);
finalReply = item;
}
chunks.Count.Should().BeGreaterThan(0);
finalReply.Should().NotBeNull();
finalReply.Should().BeOfType<MessageEnvelope<GenerateContentResponse>>();
var response = ((MessageEnvelope<GenerateContentResponse>)finalReply).Content;
response.UsageMetadata.CandidatesTokenCount.Should().BeGreaterThan(0);
}
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
public async Task GeminiAgentUpperCaseTestAsync()
{
var location = "us-central1";
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ?? throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
var model = "gemini-1.5-flash-001";
var agent = new GeminiChatAgent(
name: "assistant",
model: model,
project: project,
location: location)
.RegisterMessageConnector();
var singleAgentTest = new SingleAgentTest(_output);
await singleAgentTest.UpperCaseStreamingTestAsync(agent);
await singleAgentTest.UpperCaseTestAsync(agent);
}
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
public async Task GeminiAgentEchoFunctionCallTestAsync()
{
var location = "us-central1";
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ?? throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
var model = "gemini-1.5-flash-001";
var singleAgentTest = new SingleAgentTest(_output);
var echoFunctionContract = singleAgentTest.EchoAsyncFunctionContract;
var agent = new GeminiChatAgent(
name: "assistant",
model: model,
project: project,
location: location,
tools:
[
new Tool
{
FunctionDeclarations = { echoFunctionContract.ToFunctionDeclaration() },
},
])
.RegisterMessageConnector();
await singleAgentTest.EchoFunctionCallTestAsync(agent);
}
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
public async Task GeminiAgentEchoFunctionCallExecutionTestAsync()
{
var location = "us-central1";
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ?? throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
var model = "gemini-1.5-flash-001";
var singleAgentTest = new SingleAgentTest(_output);
var echoFunctionContract = singleAgentTest.EchoAsyncFunctionContract;
var functionMiddleware = new FunctionCallMiddleware(
functions: [echoFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>()
{
{ echoFunctionContract.Name!, singleAgentTest.EchoAsyncWrapper },
});
var agent = new GeminiChatAgent(
name: "assistant",
model: model,
project: project,
location: location)
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionMiddleware);
await singleAgentTest.EchoFunctionCallExecutionStreamingTestAsync(agent);
await singleAgentTest.EchoFunctionCallExecutionTestAsync(agent);
}
}

View File

@ -0,0 +1,380 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GeminiMessageTests.cs
using AutoGen.Core;
using AutoGen.Gemini.Middleware;
using AutoGen.Tests;
using FluentAssertions;
using Google.Cloud.AIPlatform.V1;
using Xunit;
namespace AutoGen.Gemini.Tests;
public class GeminiMessageTests
{
[Fact]
public async Task ItProcessUserTextMessageAsync()
{
var messageConnector = new GeminiMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
{
msgs.Count().Should().Be(1);
var innerMessage = msgs.First();
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
var message = (IMessage<Content>)innerMessage;
message.Content.Parts.Count.Should().Be(1);
message.Content.Role.Should().Be("user");
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(messageConnector);
// when from is null and role is user
await agent.SendAsync("Hello");
// when from is user and role is user
var userMessage = new TextMessage(Role.User, "Hello", from: "user");
await agent.SendAsync(userMessage);
// when from is user but role is assistant
userMessage = new TextMessage(Role.Assistant, "Hello", from: "user");
await agent.SendAsync(userMessage);
}
[Fact]
public async Task ItProcessAssistantTextMessageAsync()
{
var messageConnector = new GeminiMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
{
msgs.Count().Should().Be(1);
var innerMessage = msgs.First();
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
var message = (IMessage<Content>)innerMessage;
message.Content.Parts.Count.Should().Be(1);
message.Content.Role.Should().Be("model");
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(messageConnector);
// when from is user and role is assistant
var message = new TextMessage(Role.User, "Hello", from: agent.Name);
await agent.SendAsync(message);
// when from is assistant and role is assistant
message = new TextMessage(Role.Assistant, "Hello", from: agent.Name);
await agent.SendAsync(message);
}
[Fact]
public async Task ItProcessSystemTextMessageAsUserMessageWhenStrictModeIsFalseAsync()
{
var messageConnector = new GeminiMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
{
msgs.Count().Should().Be(1);
var innerMessage = msgs.First();
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
var message = (IMessage<Content>)innerMessage;
message.Content.Parts.Count.Should().Be(1);
message.Content.Role.Should().Be("user");
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(messageConnector);
var message = new TextMessage(Role.System, "Hello", from: agent.Name);
await agent.SendAsync(message);
}
[Fact]
public async Task ItThrowExceptionOnSystemMessageWhenStrictModeIsTrueAsync()
{
var messageConnector = new GeminiMessageConnector(true);
var agent = new EchoAgent("assistant")
.RegisterMiddleware(messageConnector);
var message = new TextMessage(Role.System, "Hello", from: agent.Name);
var action = new Func<Task>(async () => await agent.SendAsync(message));
await action.Should().ThrowAsync<InvalidOperationException>();
}
[Fact]
public async Task ItProcessUserImageMessageAsInlineDataAsync()
{
var messageConnector = new GeminiMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
{
msgs.Count().Should().Be(1);
var innerMessage = msgs.First();
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
var message = (IMessage<Content>)innerMessage;
message.Content.Parts.Count.Should().Be(1);
message.Content.Role.Should().Be("user");
message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.InlineData);
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(messageConnector);
var imagePath = Path.Combine("testData", "images", "background.png");
var image = File.ReadAllBytes(imagePath);
var message = new ImageMessage(Role.User, BinaryData.FromBytes(image, "image/png"));
message.MimeType.Should().Be("image/png");
await agent.SendAsync(message);
}
[Fact]
public async Task ItProcessUserImageMessageAsFileDataAsync()
{
var messageConnector = new GeminiMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
{
msgs.Count().Should().Be(1);
var innerMessage = msgs.First();
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
var message = (IMessage<Content>)innerMessage;
message.Content.Parts.Count.Should().Be(1);
message.Content.Role.Should().Be("user");
message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.FileData);
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(messageConnector);
var imagePath = Path.Combine("testData", "images", "image.png");
var url = new Uri(Path.GetFullPath(imagePath)).AbsoluteUri;
var message = new ImageMessage(Role.User, url);
message.MimeType.Should().Be("image/png");
await agent.SendAsync(message);
}
[Fact]
public async Task ItProcessMultiModalMessageAsync()
{
var messageConnector = new GeminiMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
{
msgs.Count().Should().Be(1);
var innerMessage = msgs.First();
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
var message = (IMessage<Content>)innerMessage;
message.Content.Parts.Count.Should().Be(2);
message.Content.Role.Should().Be("user");
message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.Text);
message.Content.Parts.Last().DataCase.Should().Be(Part.DataOneofCase.FileData);
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(messageConnector);
var imagePath = Path.Combine("testData", "images", "image.png");
var url = new Uri(Path.GetFullPath(imagePath)).AbsoluteUri;
var message = new ImageMessage(Role.User, url);
message.MimeType.Should().Be("image/png");
var textMessage = new TextMessage(Role.User, "What's in this image?");
var multiModalMessage = new MultiModalMessage(Role.User, [textMessage, message]);
await agent.SendAsync(multiModalMessage);
}
[Fact]
public async Task ItProcessToolCallMessageAsync()
{
var messageConnector = new GeminiMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
{
msgs.Count().Should().Be(1);
var innerMessage = msgs.First();
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
var message = (IMessage<Content>)innerMessage;
message.Content.Role.Should().Be("model");
message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.FunctionCall);
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(messageConnector);
var toolCallMessage = new ToolCallMessage("test", "{}", "user");
await agent.SendAsync(toolCallMessage);
}
[Fact]
public async Task ItProcessStreamingTextMessageAsync()
{
var messageConnector = new GeminiMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterStreamingMiddleware(messageConnector);
var messageChunks = Enumerable.Range(0, 10)
.Select(i => new GenerateContentResponse()
{
Candidates =
{
new Candidate()
{
Content = new Content()
{
Role = "user",
Parts = { new Part { Text = i.ToString() } },
}
}
}
})
.Select(m => MessageEnvelope.Create(m));
IStreamingMessage? finalReply = null;
await foreach (var reply in agent.GenerateStreamingReplyAsync(messageChunks))
{
reply.Should().BeAssignableTo<IStreamingMessage>();
finalReply = reply;
}
finalReply.Should().BeOfType<TextMessage>();
var textMessage = (TextMessage)finalReply!;
textMessage.GetContent().Should().Be("0123456789");
}
[Fact]
public async Task ItProcessToolCallResultMessageAsync()
{
var messageConnector = new GeminiMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
{
msgs.Count().Should().Be(1);
var innerMessage = msgs.First();
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
var message = (IMessage<Content>)innerMessage;
message.Content.Role.Should().Be("function");
message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.FunctionResponse);
message.Content.Parts.First().FunctionResponse.Response.ToString().Should().Be("{ \"result\": \"result\" }");
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(messageConnector);
var message = new ToolCallResultMessage("result", "test", "{}", "user");
await agent.SendAsync(message);
// when the result is already a json object string
message = new ToolCallResultMessage("{ \"result\": \"result\" }", "test", "{}", "user");
await agent.SendAsync(message);
}
[Fact]
public async Task ItProcessToolCallAggregateMessageAsTextContentAsync()
{
var messageConnector = new GeminiMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
{
msgs.Count().Should().Be(1);
var innerMessage = msgs.First();
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
var message = (IMessage<Content>)innerMessage;
message.Content.Role.Should().Be("user");
message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.Text);
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(messageConnector);
var toolCallMessage = new ToolCallMessage("test", "{}", "user");
var toolCallResultMessage = new ToolCallResultMessage("result", "test", "{}", "user");
var message = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage, from: "user");
await agent.SendAsync(message);
}
[Fact]
public async Task ItProcessToolCallAggregateMessageAsFunctionContentAsync()
{
var messageConnector = new GeminiMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
{
msgs.Count().Should().Be(2);
var functionCallMessage = msgs.First();
functionCallMessage.Should().BeOfType<MessageEnvelope<Content>>();
var message = (IMessage<Content>)functionCallMessage;
message.Content.Role.Should().Be("model");
message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.FunctionCall);
var functionResultMessage = msgs.Last();
functionResultMessage.Should().BeOfType<MessageEnvelope<Content>>();
message = (IMessage<Content>)functionResultMessage;
message.Content.Role.Should().Be("function");
message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.FunctionResponse);
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(messageConnector);
var toolCallMessage = new ToolCallMessage("test", "{}", agent.Name);
var toolCallResultMessage = new ToolCallResultMessage("result", "test", "{}", agent.Name);
var message = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage, from: agent.Name);
await agent.SendAsync(message);
}
[Fact]
public async Task ItThrowExceptionWhenProcessingUnknownMessageTypeInStrictModeAsync()
{
var messageConnector = new GeminiMessageConnector(true);
var agent = new EchoAgent("assistant")
.RegisterMiddleware(messageConnector);
var unknownMessage = new
{
text = "Hello",
};
var message = MessageEnvelope.Create(unknownMessage, from: agent.Name);
var action = new Func<Task>(async () => await agent.SendAsync(message));
await action.Should().ThrowAsync<InvalidOperationException>();
}
[Fact]
public async Task ItReturnUnknownMessageTypeInNonStrictModeAsync()
{
var messageConnector = new GeminiMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
{
var message = msgs.First();
message.Should().BeAssignableTo<IMessage>();
return message;
})
.RegisterMiddleware(messageConnector);
var unknownMessage = new
{
text = "Hello",
};
var message = MessageEnvelope.Create(unknownMessage, from: agent.Name);
await agent.SendAsync(message);
}
[Fact]
public async Task ItShortcircuitContentTypeAsync()
{
var messageConnector = new GeminiMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
{
var message = msgs.First();
message.Should().BeOfType<MessageEnvelope<Content>>();
return message;
})
.RegisterMiddleware(messageConnector);
var message = new Content()
{
Parts = { new Part { Text = "Hello" } },
Role = "user",
};
await agent.SendAsync(MessageEnvelope.Create(message, from: agent.Name));
}
}

View File

@ -0,0 +1,132 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GoogleGeminiClientTests.cs
using AutoGen.Tests;
using FluentAssertions;
using Google.Cloud.AIPlatform.V1;
using Google.Protobuf;
using static Google.Cloud.AIPlatform.V1.Candidate.Types;
namespace AutoGen.Gemini.Tests;
public class GoogleGeminiClientTests
{
[ApiKeyFact("GOOGLE_GEMINI_API_KEY")]
public async Task ItGenerateContentAsync()
{
var apiKey = Environment.GetEnvironmentVariable("GOOGLE_GEMINI_API_KEY") ?? throw new InvalidOperationException("GOOGLE_GEMINI_API_KEY is not set");
var client = new GoogleGeminiClient(apiKey);
var model = "gemini-1.5-flash-001";
var text = "Write a long, tedious story";
var request = new GenerateContentRequest
{
Model = model,
Contents =
{
new Content
{
Role = "user",
Parts =
{
new Part
{
Text = text,
}
}
}
}
};
var completion = await client.GenerateContentAsync(request);
completion.Should().NotBeNull();
completion.Candidates.Count.Should().BeGreaterThan(0);
completion.Candidates[0].Content.Parts[0].Text.Should().NotBeNullOrEmpty();
}
[ApiKeyFact("GOOGLE_GEMINI_API_KEY")]
public async Task ItGenerateContentWithImageAsync()
{
var apiKey = Environment.GetEnvironmentVariable("GOOGLE_GEMINI_API_KEY") ?? throw new InvalidOperationException("GOOGLE_GEMINI_API_KEY is not set");
var client = new GoogleGeminiClient(apiKey);
var model = "gemini-1.5-flash-001";
var text = "what's in the image";
var imagePath = Path.Combine("testData", "images", "background.png");
var image = File.ReadAllBytes(imagePath);
var request = new GenerateContentRequest
{
Model = model,
Contents =
{
new Content
{
Role = "user",
Parts =
{
new Part
{
Text = text,
},
new Part
{
InlineData = new ()
{
MimeType = "image/png",
Data = ByteString.CopyFrom(image),
},
}
}
}
}
};
var completion = await client.GenerateContentAsync(request);
completion.Should().NotBeNull();
completion.Candidates.Count.Should().BeGreaterThan(0);
completion.Candidates[0].Content.Parts[0].Text.Should().NotBeNullOrEmpty();
}
[ApiKeyFact("GOOGLE_GEMINI_API_KEY")]
public async Task ItStreamingGenerateContentTestAsync()
{
var apiKey = Environment.GetEnvironmentVariable("GOOGLE_GEMINI_API_KEY") ?? throw new InvalidOperationException("GOOGLE_GEMINI_API_KEY is not set");
var client = new GoogleGeminiClient(apiKey);
var model = "gemini-1.5-flash-001";
var text = "Tell me a long tedious joke";
var request = new GenerateContentRequest
{
Model = model,
Contents =
{
new Content
{
Role = "user",
Parts =
{
new Part
{
Text = text,
}
}
}
}
};
var response = client.GenerateContentStreamAsync(request);
var chunks = new List<GenerateContentResponse>();
GenerateContentResponse? final = null;
await foreach (var item in response)
{
item.Candidates.Count.Should().BeGreaterThan(0);
final = item;
chunks.Add(final);
}
chunks.Should().NotBeEmpty();
final.Should().NotBeNull();
final!.UsageMetadata.CandidatesTokenCount.Should().BeGreaterThan(0);
final!.Candidates[0].FinishReason.Should().Be(FinishReason.Stop);
}
}

View File

@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SampleTests.cs
using AutoGen.Gemini.Sample;
using AutoGen.Tests;
namespace AutoGen.Gemini.Tests;
public class SampleTests
{
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
public async Task TestChatWithVertexGeminiAsync()
{
await Chat_With_Vertex_Gemini.RunAsync();
}
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
public async Task TestFunctionCallWithGeminiAsync()
{
await Function_Call_With_Gemini.RunAsync();
}
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
public async Task TestImageChatWithVertexGeminiAsync()
{
await Image_Chat_With_Vertex_Gemini.RunAsync();
}
}

View File

@ -0,0 +1,134 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GeminiVertexClientTests.cs
using AutoGen.Tests;
using FluentAssertions;
using Google.Cloud.AIPlatform.V1;
using Google.Protobuf;
using static Google.Cloud.AIPlatform.V1.Candidate.Types;
namespace AutoGen.Gemini.Tests;
public class VertexGeminiClientTests
{
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
public async Task ItGenerateContentAsync()
{
var location = "us-central1";
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
var client = new VertexGeminiClient(location);
var model = "gemini-1.5-flash-001";
var text = "Hello";
var request = new GenerateContentRequest
{
Model = $"projects/{project}/locations/{location}/publishers/google/models/{model}",
Contents =
{
new Content
{
Role = "user",
Parts =
{
new Part
{
Text = text,
}
}
}
}
};
var completion = await client.GenerateContentAsync(request);
completion.Should().NotBeNull();
completion.Candidates.Count.Should().BeGreaterThan(0);
completion.Candidates[0].Content.Parts[0].Text.Should().NotBeNullOrEmpty();
}
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
public async Task ItGenerateContentWithImageAsync()
{
var location = "us-central1";
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
var client = new VertexGeminiClient(location);
var model = "gemini-1.5-flash-001";
var text = "what's in the image";
var imagePath = Path.Combine("testData", "images", "image.png");
var image = File.ReadAllBytes(imagePath);
var request = new GenerateContentRequest
{
Model = $"projects/{project}/locations/{location}/publishers/google/models/{model}",
Contents =
{
new Content
{
Role = "user",
Parts =
{
new Part
{
Text = text,
},
new Part
{
InlineData = new ()
{
MimeType = "image/png",
Data = ByteString.CopyFrom(image),
},
}
}
}
}
};
var completion = await client.GenerateContentAsync(request);
completion.Should().NotBeNull();
completion.Candidates.Count.Should().BeGreaterThan(0);
completion.Candidates[0].Content.Parts[0].Text.Should().NotBeNullOrEmpty();
}
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
public async Task ItStreamingGenerateContentTestAsync()
{
var location = "us-central1";
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
var client = new VertexGeminiClient(location);
var model = "gemini-1.5-flash-001";
var text = "Hello, write a long tedious joke";
var request = new GenerateContentRequest
{
Model = $"projects/{project}/locations/{location}/publishers/google/models/{model}",
Contents =
{
new Content
{
Role = "user",
Parts =
{
new Part
{
Text = text,
}
}
}
}
};
var response = client.GenerateContentStreamAsync(request);
var chunks = new List<GenerateContentResponse>();
GenerateContentResponse? final = null;
await foreach (var item in response)
{
item.Candidates.Count.Should().BeGreaterThan(0);
final = item;
chunks.Add(final);
}
chunks.Should().NotBeEmpty();
final.Should().NotBeNull();
final!.UsageMetadata.CandidatesTokenCount.Should().BeGreaterThan(0);
final!.Candidates[0].FinishReason.Should().Be(FinishReason.Stop);
}
}

View File

@ -4,18 +4,10 @@
<TargetFramework>$(TestTargetFramework)</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<IsPackable>false</IsPackable>
<IsTestProject>True</IsTestProject>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestVersion)" />
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\AutoGen.Mistral\AutoGen.Mistral.csproj" />
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />

View File

@ -4,18 +4,10 @@
<TargetFramework>$(TestTargetFramework)</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<IsPackable>false</IsPackable>
<IsTestProject>True</IsTestProject>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestVersion)" />
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\AutoGen.Ollama\AutoGen.Ollama.csproj" />
<ProjectReference Include="..\AutoGen.Tests\AutoGen.Tests.csproj" />

View File

@ -3,18 +3,10 @@
<PropertyGroup>
<TargetFramework>$(TestTargetFramework)</TargetFramework>
<IsPackable>false</IsPackable>
<IsTestProject>True</IsTestProject>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestVersion)" />
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
<ProjectReference Include="..\..\src\AutoGen\AutoGen.csproj" />

View File

@ -5,18 +5,10 @@
<ImplicitUsings>enable</ImplicitUsings>
<IsPackable>false</IsPackable>
<NoWarn>$(NoWarn);SKEXP0110</NoWarn>
<IsTestProject>True</IsTestProject>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestVersion)" />
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\AutoGen.SemanticKernel\AutoGen.SemanticKernel.csproj" />
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />

View File

@ -4,18 +4,10 @@
<TargetFramework>$(TestTargetFramework)</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<IsPackable>false</IsPackable>
<IsTestProject>True</IsTestProject>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestVersion)" />
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="true" />
<ProjectReference Include="..\..\src\AutoGen\AutoGen.csproj" />

View File

@ -3,18 +3,10 @@
<PropertyGroup>
<TargetFramework>$(TestTargetFramework)</TargetFramework>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
<IsTestProject>True</IsTestProject>
<NoWarn>$(NoWarn);xUnit1013;SKEXP0110</NoWarn>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestVersion)" />
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\sample\AutoGen.BasicSamples\AutoGen.BasicSample.csproj" />
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />

View File

@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ImageMessageTests.cs
using System;
using System.IO;
using System.Threading.Tasks;
using FluentAssertions;
using Xunit;
namespace AutoGen.Tests;
public class ImageMessageTests
{
[Fact]
public async Task ItCreateFromLocalImage()
{
var image = Path.Combine("testData", "images", "background.png");
var binary = File.ReadAllBytes(image);
var base64 = Convert.ToBase64String(binary);
var imageMessage = new ImageMessage(Role.User, BinaryData.FromBytes(binary, "image/png"));
imageMessage.MimeType.Should().Be("image/png");
imageMessage.BuildDataUri().Should().Be($"data:image/png;base64,{base64}");
}
[Fact]
public async Task ItCreateFromUrl()
{
var image = Path.Combine("testData", "images", "background.png");
var fullPath = Path.GetFullPath(image);
var localUrl = new Uri(fullPath).AbsoluteUri;
var imageMessage = new ImageMessage(Role.User, localUrl);
imageMessage.Url.Should().Be(localUrl);
imageMessage.MimeType.Should().Be("image/png");
imageMessage.Data.Should().BeNull();
}
}

View File

@ -266,10 +266,10 @@ namespace AutoGen.Tests
public async Task EchoFunctionCallTestAsync(IAgent agent)
{
var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function");
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var reply = await agent.SendAsync(chatHistory: new[] { message, helloWorld });
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
reply.From.Should().Be(agent.Name);
reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync));
@ -277,10 +277,10 @@ namespace AutoGen.Tests
public async Task EchoFunctionCallExecutionTestAsync(IAgent agent)
{
var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var reply = await agent.SendAsync(chatHistory: new[] { message, helloWorld });
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
reply.GetContent().Should().Be("[ECHO] Hello world");
reply.From.Should().Be(agent.Name);
@ -289,13 +289,13 @@ namespace AutoGen.Tests
public async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent)
{
var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var option = new GenerateReplyOptions
{
Temperature = 0,
};
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option);
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { helloWorld }, option);
var answer = "[ECHO] Hello world";
IStreamingMessage? finalReply = default;
await foreach (var reply in replyStream)
@ -319,25 +319,23 @@ namespace AutoGen.Tests
public async Task UpperCaseTestAsync(IAgent agent)
{
var message = new TextMessage(Role.System, "You are a helpful AI assistant that convert user message to upper case");
var uppCaseMessage = new TextMessage(Role.User, "abcdefg");
var message = new TextMessage(Role.User, "Please convert abcde to upper case.");
var reply = await agent.SendAsync(chatHistory: new[] { message, uppCaseMessage });
var reply = await agent.SendAsync(chatHistory: new[] { message });
reply.GetContent().Should().Contain("ABCDEFG");
reply.GetContent().Should().Contain("ABCDE");
reply.From.Should().Be(agent.Name);
}
public async Task UpperCaseStreamingTestAsync(IStreamingAgent agent)
{
var message = new TextMessage(Role.System, "You are a helpful AI assistant that convert user message to upper case");
var helloWorld = new TextMessage(Role.User, "a b c d e f g h i j k l m n");
var message = new TextMessage(Role.User, "Please convert 'hello world' to upper case");
var option = new GenerateReplyOptions
{
Temperature = 0,
};
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option);
var answer = "A B C D E F G H I J K L M N";
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message }, option);
var answer = "HELLO WORLD";
TextMessage? finalReply = default;
await foreach (var reply in replyStream)
{