mirror of https://github.com/microsoft/autogen.git
[.Net] Add Goolge gemini (#2868)
* update * add vertex gemini test * remove DTO * add test for vertexGeminiAgent * update test name * update IGeminiClient interface * add test for streaming * add message connector * add gemini message extension * add tests * update * add gemnini sample * update examples * add test for iamge * fix test * add more tests * add streaming message test * add comment * remove unused json * implement google gemini client * update * fix comment
This commit is contained in:
parent
7d057a93b2
commit
a16b307dc0
|
@ -53,6 +53,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Anthropic.Tests", "
|
|||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Anthropic.Samples", "sample\AutoGen.Anthropic.Samples\AutoGen.Anthropic.Samples.csproj", "{834B4E85-64E5-4382-8465-548F332E5298}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Gemini", "src\AutoGen.Gemini\AutoGen.Gemini.csproj", "{EFE0DC86-80FC-4D52-95B7-07654BA1A769}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Gemini.Tests", "test\AutoGen.Gemini.Tests\AutoGen.Gemini.Tests.csproj", "{8EA16BAB-465A-4C07-ABC4-1070D40067E9}"
|
||||
EndProject
|
||||
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.Gemini.Sample", "sample\AutoGen.Gemini.Sample\AutoGen.Gemini.Sample.csproj", "{19679B75-CE3A-4DF0-A3F0-CA369D2760A4}"
|
||||
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.AotCompatibility.Tests", "test\AutoGen.AotCompatibility.Tests\AutoGen.AotCompatibility.Tests.csproj", "{6B82F26D-5040-4453-B21B-C8D1F913CE4C}"
|
||||
EndProject
|
||||
Global
|
||||
|
@ -149,6 +154,18 @@ Global
|
|||
{834B4E85-64E5-4382-8465-548F332E5298}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{834B4E85-64E5-4382-8465-548F332E5298}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{834B4E85-64E5-4382-8465-548F332E5298}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{EFE0DC86-80FC-4D52-95B7-07654BA1A769}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{EFE0DC86-80FC-4D52-95B7-07654BA1A769}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{EFE0DC86-80FC-4D52-95B7-07654BA1A769}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{EFE0DC86-80FC-4D52-95B7-07654BA1A769}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{8EA16BAB-465A-4C07-ABC4-1070D40067E9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{8EA16BAB-465A-4C07-ABC4-1070D40067E9}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{8EA16BAB-465A-4C07-ABC4-1070D40067E9}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{8EA16BAB-465A-4C07-ABC4-1070D40067E9}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{19679B75-CE3A-4DF0-A3F0-CA369D2760A4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{19679B75-CE3A-4DF0-A3F0-CA369D2760A4}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{19679B75-CE3A-4DF0-A3F0-CA369D2760A4}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{19679B75-CE3A-4DF0-A3F0-CA369D2760A4}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{6B82F26D-5040-4453-B21B-C8D1F913CE4C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{6B82F26D-5040-4453-B21B-C8D1F913CE4C}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{6B82F26D-5040-4453-B21B-C8D1F913CE4C}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
|
@ -180,6 +197,9 @@ Global
|
|||
{6A95E113-B824-4524-8F13-CD0C3E1C8804} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
|
||||
{815E937E-86D6-4476-9EC6-B7FBCBBB5DB6} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
|
||||
{834B4E85-64E5-4382-8465-548F332E5298} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
|
||||
{EFE0DC86-80FC-4D52-95B7-07654BA1A769} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
|
||||
{8EA16BAB-465A-4C07-ABC4-1070D40067E9} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
|
||||
{19679B75-CE3A-4DF0-A3F0-CA369D2760A4} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
|
||||
{6B82F26D-5040-4453-B21B-C8D1F913CE4C} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
|
||||
EndGlobalSection
|
||||
GlobalSection(ExtensibilityGlobals) = postSolution
|
||||
|
|
|
@ -13,12 +13,37 @@
|
|||
<CSNoWarn>CS1998;CS1591</CSNoWarn>
|
||||
<NoWarn>$(NoWarn);$(CSNoWarn);NU5104</NoWarn>
|
||||
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
|
||||
<GenerateDocumentationFile>true</GenerateDocumentationFile>
|
||||
<IsPackable>false</IsPackable>
|
||||
<EnableNetAnalyzers>true</EnableNetAnalyzers>
|
||||
<EnforceCodeStyleInBuild>true</EnforceCodeStyleInBuild>
|
||||
<IsTestProject>false</IsTestProject>
|
||||
</PropertyGroup>
|
||||
|
||||
<PropertyGroup>
|
||||
<RepoRoot>$(MSBuildThisFileDirectory)</RepoRoot>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup Condition="'$(IsTestProject)' == 'true'">
|
||||
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestVersion)" />
|
||||
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
|
||||
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup Condition="'$(IsTestProject)' == 'true'">
|
||||
<Content Include="$(RepoRoot)resource/**/*.*">
|
||||
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
|
||||
<Link>testData/%(RecursiveDir)%(Filename)%(Extension)</Link>
|
||||
</Content>
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup Condition="'$(IncludeResourceFolder)' == 'true'">
|
||||
<Content Include="$(RepoRoot)resource/**/*.*">
|
||||
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
|
||||
<Link>resource/%(RecursiveDir)%(Filename)%(Extension)</Link>
|
||||
</Content>
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
<MicrosoftNETTestSdkVersion>17.7.0</MicrosoftNETTestSdkVersion>
|
||||
<MicrosoftDotnetInteractive>1.0.0-beta.24229.4</MicrosoftDotnetInteractive>
|
||||
<MicrosoftSourceLinkGitHubVersion>8.0.0</MicrosoftSourceLinkGitHubVersion>
|
||||
<GoogleCloudAPIPlatformVersion>3.0.0</GoogleCloudAPIPlatformVersion>
|
||||
<JsonSchemaVersion>4.3.0.2</JsonSchemaVersion>
|
||||
</PropertyGroup>
|
||||
</Project>
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8323d0b8eceb752e14c29543b2e28bb2fc648ed9719095c31b7708867a4dc918
|
||||
size 491
|
|
@ -93,7 +93,7 @@ The image is generated from prompt {prompt}
|
|||
if (reply.GetContent() is string content && content.Contains("IMAGE_GENERATION"))
|
||||
{
|
||||
var imageUrl = content.Split("\n").Last();
|
||||
var imageMessage = new ImageMessage(Role.Assistant, imageUrl, from: reply.From);
|
||||
var imageMessage = new ImageMessage(Role.Assistant, imageUrl, from: reply.From, mimeType: "image/png");
|
||||
|
||||
Console.WriteLine($"download image from {imageUrl} to {imagePath}");
|
||||
var httpClient = new HttpClient();
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<OutputType>Exe</OutputType>
|
||||
<TargetFramework>net8.0</TargetFramework>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<Nullable>enable</Nullable>
|
||||
<IncludeResourceFolder>true</IncludeResourceFolder>
|
||||
<GenerateDocumentationFile>True</GenerateDocumentationFile>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\src\AutoGen\AutoGen.csproj" />
|
||||
<ProjectReference Include="..\..\src\AutoGen.Gemini\AutoGen.Gemini.csproj" />
|
||||
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
|
||||
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
|
@ -0,0 +1,38 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Chat_With_Google_Gemini.cs
|
||||
|
||||
using AutoGen.Core;
|
||||
using AutoGen.Gemini.Middleware;
|
||||
using FluentAssertions;
|
||||
|
||||
namespace AutoGen.Gemini.Sample;
|
||||
|
||||
public class Chat_With_Google_Gemini
|
||||
{
|
||||
public static async Task RunAsync()
|
||||
{
|
||||
var apiKey = Environment.GetEnvironmentVariable("GOOGLE_GEMINI_API_KEY");
|
||||
|
||||
if (apiKey is null)
|
||||
{
|
||||
Console.WriteLine("Please set GOOGLE_GEMINI_API_KEY environment variable.");
|
||||
return;
|
||||
}
|
||||
|
||||
#region Create_Gemini_Agent
|
||||
var geminiAgent = new GeminiChatAgent(
|
||||
name: "gemini",
|
||||
model: "gemini-1.5-flash-001",
|
||||
apiKey: apiKey,
|
||||
systemMessage: "You are a helpful C# engineer, put your code between ```csharp and ```, don't explain the code")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterPrintMessage();
|
||||
#endregion Create_Gemini_Agent
|
||||
|
||||
var reply = await geminiAgent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
|
||||
|
||||
#region verify_reply
|
||||
reply.Should().BeOfType<TextMessage>();
|
||||
#endregion verify_reply
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Chat_With_Vertex_Gemini.cs
|
||||
|
||||
using AutoGen.Core;
|
||||
using AutoGen.Gemini.Middleware;
|
||||
using FluentAssertions;
|
||||
|
||||
namespace AutoGen.Gemini.Sample;
|
||||
|
||||
public class Chat_With_Vertex_Gemini
|
||||
{
|
||||
public static async Task RunAsync()
|
||||
{
|
||||
var projectID = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
|
||||
|
||||
if (projectID is null)
|
||||
{
|
||||
Console.WriteLine("Please set GCP_VERTEX_PROJECT_ID environment variable.");
|
||||
return;
|
||||
}
|
||||
|
||||
#region Create_Gemini_Agent
|
||||
var geminiAgent = new GeminiChatAgent(
|
||||
name: "gemini",
|
||||
model: "gemini-1.5-flash-001",
|
||||
location: "us-east1",
|
||||
project: projectID,
|
||||
systemMessage: "You are a helpful C# engineer, put your code between ```csharp and ```, don't explain the code")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterPrintMessage();
|
||||
#endregion Create_Gemini_Agent
|
||||
|
||||
var reply = await geminiAgent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
|
||||
|
||||
#region verify_reply
|
||||
reply.Should().BeOfType<TextMessage>();
|
||||
#endregion verify_reply
|
||||
}
|
||||
}
|
|
@ -0,0 +1,129 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Function_Call_With_Gemini.cs
|
||||
|
||||
using AutoGen.Core;
|
||||
using AutoGen.Gemini.Middleware;
|
||||
using FluentAssertions;
|
||||
using Google.Cloud.AIPlatform.V1;
|
||||
|
||||
namespace AutoGen.Gemini.Sample;
|
||||
|
||||
public partial class MovieFunction
|
||||
{
|
||||
/// <summary>
|
||||
/// find movie titles currently playing in theaters based on any description, genre, title words, etc.
|
||||
/// </summary>
|
||||
/// <param name="location">The city and state, e.g. San Francisco, CA or a zip code e.g. 95616</param>
|
||||
/// <param name="description">Any kind of description including category or genre, title words, attributes, etc.</param>
|
||||
/// <returns></returns>
|
||||
[Function]
|
||||
public async Task<string> FindMovies(string location, string description)
|
||||
{
|
||||
// dummy implementation
|
||||
var movies = new List<string> { "Barbie", "Spiderman", "Batman" };
|
||||
var result = $"Movies playing in {location} based on {description} are: {string.Join(", ", movies)}";
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// find theaters based on location and optionally movie title which is currently playing in theaters
|
||||
/// </summary>
|
||||
/// <param name="location">The city and state, e.g. San Francisco, CA or a zip code e.g. 95616</param>
|
||||
/// <param name="movie">Any movie title</param>
|
||||
[Function]
|
||||
public async Task<string> FindTheaters(string location, string movie)
|
||||
{
|
||||
// dummy implementation
|
||||
var theaters = new List<string> { "AMC", "Regal", "Cinemark" };
|
||||
var result = $"Theaters playing {movie} in {location} are: {string.Join(", ", theaters)}";
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Find the start times for movies playing in a specific theater
|
||||
/// </summary>
|
||||
/// <param name="location">The city and state, e.g. San Francisco, CA or a zip code e.g. 95616</param>
|
||||
/// <param name="movie">Any movie title</param>
|
||||
/// <param name="theater">Name of the theater</param>
|
||||
/// <param name="date">Date for requested showtime</param>
|
||||
/// <returns></returns>
|
||||
[Function]
|
||||
public async Task<string> GetShowtimes(string location, string movie, string theater, string date)
|
||||
{
|
||||
// dummy implementation
|
||||
var showtimes = new List<string> { "10:00 AM", "12:00 PM", "2:00 PM", "4:00 PM", "6:00 PM", "8:00 PM" };
|
||||
var result = $"Showtimes for {movie} at {theater} in {location} are: {string.Join(", ", showtimes)}";
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Modified from https://ai.google.dev/gemini-api/docs/function-calling
|
||||
/// </summary>
|
||||
public partial class Function_Call_With_Gemini
|
||||
{
|
||||
public static async Task RunAsync()
|
||||
{
|
||||
var projectID = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
|
||||
|
||||
if (projectID is null)
|
||||
{
|
||||
Console.WriteLine("Please set GCP_VERTEX_PROJECT_ID environment variable.");
|
||||
return;
|
||||
}
|
||||
|
||||
var movieFunction = new MovieFunction();
|
||||
var functionMiddleware = new FunctionCallMiddleware(
|
||||
functions: [
|
||||
movieFunction.FindMoviesFunctionContract,
|
||||
movieFunction.FindTheatersFunctionContract,
|
||||
movieFunction.GetShowtimesFunctionContract
|
||||
],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ movieFunction.FindMoviesFunctionContract.Name!, movieFunction.FindMoviesWrapper },
|
||||
{ movieFunction.FindTheatersFunctionContract.Name!, movieFunction.FindTheatersWrapper },
|
||||
{ movieFunction.GetShowtimesFunctionContract.Name!, movieFunction.GetShowtimesWrapper },
|
||||
});
|
||||
|
||||
#region Create_Gemini_Agent
|
||||
var geminiAgent = new GeminiChatAgent(
|
||||
name: "gemini",
|
||||
model: "gemini-1.5-flash-001",
|
||||
location: "us-central1",
|
||||
project: projectID,
|
||||
systemMessage: "You are a helpful AI assistant",
|
||||
toolConfig: new ToolConfig()
|
||||
{
|
||||
FunctionCallingConfig = new FunctionCallingConfig()
|
||||
{
|
||||
Mode = FunctionCallingConfig.Types.Mode.Auto,
|
||||
}
|
||||
})
|
||||
.RegisterMessageConnector()
|
||||
.RegisterPrintMessage()
|
||||
.RegisterStreamingMiddleware(functionMiddleware);
|
||||
#endregion Create_Gemini_Agent
|
||||
|
||||
#region Single_turn
|
||||
var question = new TextMessage(Role.User, "What movies are showing in North Seattle tonight?");
|
||||
var functionCallReply = await geminiAgent.SendAsync(question);
|
||||
#endregion Single_turn
|
||||
|
||||
#region Single_turn_verify_reply
|
||||
functionCallReply.Should().BeOfType<ToolCallAggregateMessage>();
|
||||
#endregion Single_turn_verify_reply
|
||||
|
||||
#region Multi_turn
|
||||
var finalReply = await geminiAgent.SendAsync(chatHistory: [question, functionCallReply]);
|
||||
#endregion Multi_turn
|
||||
|
||||
#region Multi_turn_verify_reply
|
||||
finalReply.Should().BeOfType<TextMessage>();
|
||||
#endregion Multi_turn_verify_reply
|
||||
}
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Image_Chat_With_Vertex_Gemini.cs
|
||||
|
||||
using AutoGen.Core;
|
||||
using AutoGen.Gemini.Middleware;
|
||||
using FluentAssertions;
|
||||
|
||||
namespace AutoGen.Gemini.Sample;
|
||||
|
||||
public class Image_Chat_With_Vertex_Gemini
|
||||
{
|
||||
public static async Task RunAsync()
|
||||
{
|
||||
var projectID = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
|
||||
|
||||
if (projectID is null)
|
||||
{
|
||||
Console.WriteLine("Please set GCP_VERTEX_PROJECT_ID environment variable.");
|
||||
return;
|
||||
}
|
||||
|
||||
#region Create_Gemini_Agent
|
||||
var geminiAgent = new GeminiChatAgent(
|
||||
name: "gemini",
|
||||
model: "gemini-1.5-flash-001",
|
||||
location: "us-east4",
|
||||
project: projectID,
|
||||
systemMessage: "You explain image content to user")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterPrintMessage();
|
||||
#endregion Create_Gemini_Agent
|
||||
|
||||
#region Send_Image_Request
|
||||
var imagePath = Path.Combine("resource", "images", "background.png");
|
||||
var image = await File.ReadAllBytesAsync(imagePath);
|
||||
var imageMessage = new ImageMessage(Role.User, BinaryData.FromBytes(image, "image/png"));
|
||||
var reply = await geminiAgent.SendAsync("what's in the image", [imageMessage]);
|
||||
#endregion Send_Image_Request
|
||||
|
||||
#region Verify_Reply
|
||||
reply.Should().BeOfType<TextMessage>();
|
||||
#endregion Verify_Reply
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Program.cs
|
||||
|
||||
using AutoGen.Gemini.Sample;
|
||||
|
||||
Image_Chat_With_Vertex_Gemini.RunAsync().Wait();
|
|
@ -5,6 +5,7 @@
|
|||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<GenerateDocumentationFile>True</GenerateDocumentationFile>
|
||||
<NoWarn>$(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110</NoWarn>
|
||||
<IncludeResourceFolder>true</IncludeResourceFolder>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
|
@ -15,10 +16,4 @@
|
|||
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<None Update="images\*.png">
|
||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
|
||||
</None>
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
|
|
@ -28,7 +28,7 @@ public class Chat_With_LLaVA
|
|||
#endregion Create_Ollama_Agent
|
||||
|
||||
#region Send_Message
|
||||
var image = Path.Combine("images", "background.png");
|
||||
var image = Path.Combine("resource", "images", "background.png");
|
||||
var binaryData = BinaryData.FromBytes(File.ReadAllBytes(image), "image/png");
|
||||
var imageMessage = new ImageMessage(Role.User, binaryData);
|
||||
var textMessage = new TextMessage(Role.User, "what's in this image?");
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ChatCompletionRequest.cs
|
||||
using System.Text.Json.Serialization;
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace AutoGen.Anthropic.DTO;
|
||||
|
||||
using System.Collections.Generic;
|
||||
|
||||
public class ChatCompletionRequest
|
||||
{
|
||||
[JsonPropertyName("model")]
|
||||
|
|
|
@ -7,18 +7,34 @@ namespace AutoGen.Core;
|
|||
|
||||
public class ImageMessage : IMessage
|
||||
{
|
||||
public ImageMessage(Role role, string url, string? from = null)
|
||||
public ImageMessage(Role role, string url, string? from = null, string? mimeType = null)
|
||||
: this(role, new Uri(url), from, mimeType)
|
||||
{
|
||||
this.Role = role;
|
||||
this.From = from;
|
||||
this.Url = url;
|
||||
}
|
||||
|
||||
public ImageMessage(Role role, Uri uri, string? from = null)
|
||||
public ImageMessage(Role role, Uri uri, string? from = null, string? mimeType = null)
|
||||
{
|
||||
this.Role = role;
|
||||
this.From = from;
|
||||
this.Url = uri.ToString();
|
||||
|
||||
// try infer mimeType from uri extension if not provided
|
||||
if (mimeType is null)
|
||||
{
|
||||
mimeType = uri switch
|
||||
{
|
||||
_ when uri.AbsoluteUri.EndsWith(".png", StringComparison.OrdinalIgnoreCase) => "image/png",
|
||||
_ when uri.AbsoluteUri.EndsWith(".jpg", StringComparison.OrdinalIgnoreCase) => "image/jpeg",
|
||||
_ when uri.AbsoluteUri.EndsWith(".jpeg", StringComparison.OrdinalIgnoreCase) => "image/jpeg",
|
||||
_ when uri.AbsoluteUri.EndsWith(".gif", StringComparison.OrdinalIgnoreCase) => "image/gif",
|
||||
_ when uri.AbsoluteUri.EndsWith(".bmp", StringComparison.OrdinalIgnoreCase) => "image/bmp",
|
||||
_ when uri.AbsoluteUri.EndsWith(".webp", StringComparison.OrdinalIgnoreCase) => "image/webp",
|
||||
_ when uri.AbsoluteUri.EndsWith(".svg", StringComparison.OrdinalIgnoreCase) => "image/svg+xml",
|
||||
_ => throw new ArgumentException("MimeType is required for ImageMessage", nameof(mimeType))
|
||||
};
|
||||
}
|
||||
|
||||
this.MimeType = mimeType;
|
||||
}
|
||||
|
||||
public ImageMessage(Role role, BinaryData data, string? from = null)
|
||||
|
@ -28,7 +44,7 @@ public class ImageMessage : IMessage
|
|||
throw new ArgumentException("Data cannot be empty", nameof(data));
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(data.MediaType))
|
||||
if (data.MediaType is null)
|
||||
{
|
||||
throw new ArgumentException("MediaType is needed for DataUri Images", nameof(data));
|
||||
}
|
||||
|
@ -36,15 +52,18 @@ public class ImageMessage : IMessage
|
|||
this.Role = role;
|
||||
this.From = from;
|
||||
this.Data = data;
|
||||
this.MimeType = data.MediaType;
|
||||
}
|
||||
|
||||
public Role Role { get; set; }
|
||||
public Role Role { get; }
|
||||
|
||||
public string? Url { get; set; }
|
||||
public string? Url { get; }
|
||||
|
||||
public string? From { get; set; }
|
||||
|
||||
public BinaryData? Data { get; set; }
|
||||
public BinaryData? Data { get; }
|
||||
|
||||
public string MimeType { get; }
|
||||
|
||||
public string BuildDataUri()
|
||||
{
|
||||
|
@ -53,7 +72,7 @@ public class ImageMessage : IMessage
|
|||
throw new NullReferenceException($"{nameof(Data)}");
|
||||
}
|
||||
|
||||
return $"data:{this.Data.MediaType};base64,{Convert.ToBase64String(this.Data.ToArray())}";
|
||||
return $"data:{this.MimeType};base64,{Convert.ToBase64String(this.Data.ToArray())}";
|
||||
}
|
||||
|
||||
public override string ToString()
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<TargetFramework>netstandard2.0</TargetFramework>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Google.Cloud.AIPlatform.V1" Version="$(GoogleCloudAPIPlatformVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\AutoGen.Core\AutoGen.Core.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<InternalsVisibleTo Include="AutoGen.Gemini.Tests" />
|
||||
</ItemGroup>
|
||||
</Project>
|
|
@ -0,0 +1,90 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// FunctionContractExtension.cs
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using AutoGen.Core;
|
||||
using Google.Cloud.AIPlatform.V1;
|
||||
using Json.Schema;
|
||||
using Json.Schema.Generation;
|
||||
using OpenAPISchemaType = Google.Cloud.AIPlatform.V1.Type;
|
||||
using Type = System.Type;
|
||||
|
||||
namespace AutoGen.Gemini.Extension;
|
||||
|
||||
public static class FunctionContractExtension
|
||||
{
|
||||
/// <summary>
|
||||
/// Convert a <see cref="FunctionContract"/> to a <see cref="FunctionDeclaration"/> that can be used in gpt funciton call.
|
||||
/// </summary>
|
||||
public static FunctionDeclaration ToFunctionDeclaration(this FunctionContract function)
|
||||
{
|
||||
var required = function.Parameters!.Where(p => p.IsRequired)
|
||||
.Select(p => p.Name)
|
||||
.ToList();
|
||||
var parameterProperties = new Dictionary<string, OpenApiSchema>();
|
||||
|
||||
foreach (var parameter in function.Parameters ?? Enumerable.Empty<FunctionParameterContract>())
|
||||
{
|
||||
var schema = ToOpenApiSchema(parameter.ParameterType);
|
||||
schema.Description = parameter.Description;
|
||||
schema.Title = parameter.Name;
|
||||
schema.Nullable = !parameter.IsRequired;
|
||||
parameterProperties.Add(parameter.Name!, schema);
|
||||
}
|
||||
|
||||
return new FunctionDeclaration
|
||||
{
|
||||
Name = function.Name,
|
||||
Description = function.Description,
|
||||
Parameters = new OpenApiSchema
|
||||
{
|
||||
Required =
|
||||
{
|
||||
required,
|
||||
},
|
||||
Properties =
|
||||
{
|
||||
parameterProperties,
|
||||
},
|
||||
Type = OpenAPISchemaType.Object,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
private static OpenApiSchema ToOpenApiSchema(Type? type)
|
||||
{
|
||||
if (type == null)
|
||||
{
|
||||
return new OpenApiSchema
|
||||
{
|
||||
Type = OpenAPISchemaType.Unspecified
|
||||
};
|
||||
}
|
||||
|
||||
var schema = new JsonSchemaBuilder().FromType(type).Build();
|
||||
var openApiSchema = new OpenApiSchema
|
||||
{
|
||||
Type = schema.GetJsonType() switch
|
||||
{
|
||||
SchemaValueType.Array => OpenAPISchemaType.Array,
|
||||
SchemaValueType.Boolean => OpenAPISchemaType.Boolean,
|
||||
SchemaValueType.Integer => OpenAPISchemaType.Integer,
|
||||
SchemaValueType.Number => OpenAPISchemaType.Number,
|
||||
SchemaValueType.Object => OpenAPISchemaType.Object,
|
||||
SchemaValueType.String => OpenAPISchemaType.String,
|
||||
_ => OpenAPISchemaType.Unspecified
|
||||
},
|
||||
};
|
||||
|
||||
if (schema.GetJsonType() == SchemaValueType.Object && schema.GetProperties() is var properties && properties != null)
|
||||
{
|
||||
foreach (var property in properties)
|
||||
{
|
||||
openApiSchema.Properties.Add(property.Key, ToOpenApiSchema(property.Value.GetType()));
|
||||
}
|
||||
}
|
||||
|
||||
return openApiSchema;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,268 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GeminiChatAgent.cs
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using AutoGen.Core;
|
||||
using AutoGen.Gemini.Extension;
|
||||
using Google.Cloud.AIPlatform.V1;
|
||||
using Google.Protobuf.Collections;
|
||||
namespace AutoGen.Gemini;
|
||||
|
||||
public class GeminiChatAgent : IStreamingAgent
|
||||
{
|
||||
private readonly IGeminiClient client;
|
||||
private readonly string? systemMessage;
|
||||
private readonly string model;
|
||||
private readonly ToolConfig? toolConfig;
|
||||
private readonly RepeatedField<SafetySetting>? safetySettings;
|
||||
private readonly string responseMimeType;
|
||||
private readonly Tool[]? tools;
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="GeminiChatAgent"/> that connects to Gemini.
|
||||
/// </summary>
|
||||
/// <param name="client">the gemini client to use. e.g. <see cref="VertexGeminiClient"/> </param>
|
||||
/// <param name="name">agent name</param>
|
||||
/// <param name="model">the model id. It needs to be in the format of
|
||||
/// 'projects/{project}/locations/{location}/publishers/{provider}/models/{model}' if the <paramref name="client"/> is <see cref="VertexGeminiClient"/></param>
|
||||
/// <param name="systemMessage">system message</param>
|
||||
/// <param name="toolConfig">tool config</param>
|
||||
/// <param name="tools">tools</param>
|
||||
/// <param name="safetySettings">safety settings</param>
|
||||
/// <param name="responseMimeType">response mime type, available values are ['application/json', 'text/plain'], default is 'text/plain'</param>
|
||||
public GeminiChatAgent(
|
||||
IGeminiClient client,
|
||||
string name,
|
||||
string model,
|
||||
string? systemMessage = null,
|
||||
ToolConfig? toolConfig = null,
|
||||
Tool[]? tools = null,
|
||||
RepeatedField<SafetySetting>? safetySettings = null,
|
||||
string responseMimeType = "text/plain")
|
||||
{
|
||||
this.client = client;
|
||||
this.Name = name;
|
||||
this.systemMessage = systemMessage;
|
||||
this.model = model;
|
||||
this.toolConfig = toolConfig;
|
||||
this.safetySettings = safetySettings;
|
||||
this.responseMimeType = responseMimeType;
|
||||
this.tools = tools;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="GeminiChatAgent"/> that connects to Gemini using <see cref="GoogleGeminiClient"/>
|
||||
/// </summary>
|
||||
/// <param name="name">agent name</param>
|
||||
/// <param name="model">the name of gemini model, e.g. gemini-1.5-flash-001</param>
|
||||
/// <param name="apiKey">google gemini api key</param>
|
||||
/// <param name="systemMessage">system message</param>
|
||||
/// <param name="toolConfig">tool config</param>
|
||||
/// <param name="tools">tools</param>
|
||||
/// <param name="safetySettings"></param>
|
||||
/// <param name="responseMimeType">response mime type, available values are ['application/json', 'text/plain'], default is 'text/plain'</param>
|
||||
/// /// <example>
|
||||
/// <]
|
||||
/// ]]>
|
||||
/// </example>
|
||||
public GeminiChatAgent(
|
||||
string name,
|
||||
string model,
|
||||
string apiKey,
|
||||
string systemMessage = "You are a helpful AI assistant",
|
||||
ToolConfig? toolConfig = null,
|
||||
Tool[]? tools = null,
|
||||
RepeatedField<SafetySetting>? safetySettings = null,
|
||||
string responseMimeType = "text/plain")
|
||||
: this(
|
||||
client: new GoogleGeminiClient(apiKey),
|
||||
name: name,
|
||||
model: model,
|
||||
systemMessage: systemMessage,
|
||||
toolConfig: toolConfig,
|
||||
tools: tools,
|
||||
safetySettings: safetySettings,
|
||||
responseMimeType: responseMimeType)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="GeminiChatAgent"/> that connects to Vertex AI.
|
||||
/// </summary>
|
||||
/// <param name="name">agent name</param>
|
||||
/// <param name="systemMessage">system message</param>
|
||||
/// <param name="model">the name of gemini model, e.g. gemini-1.5-flash-001</param>
|
||||
/// <param name="project">project id</param>
|
||||
/// <param name="location">model location</param>
|
||||
/// <param name="provider">model provider, default is 'google'</param>
|
||||
/// <param name="toolConfig">tool config</param>
|
||||
/// <param name="tools">tools</param>
|
||||
/// <param name="safetySettings"></param>
|
||||
/// <param name="responseMimeType">response mime type, available values are ['application/json', 'text/plain'], default is 'text/plain'</param>
|
||||
/// <example>
|
||||
/// <]
|
||||
/// ]]>
|
||||
/// </example>
|
||||
public GeminiChatAgent(
|
||||
string name,
|
||||
string model,
|
||||
string project,
|
||||
string location,
|
||||
string provider = "google",
|
||||
string? systemMessage = null,
|
||||
ToolConfig? toolConfig = null,
|
||||
Tool[]? tools = null,
|
||||
RepeatedField<SafetySetting>? safetySettings = null,
|
||||
string responseMimeType = "text/plain")
|
||||
: this(
|
||||
client: new VertexGeminiClient(location),
|
||||
name: name,
|
||||
model: $"projects/{project}/locations/{location}/publishers/{provider}/models/{model}",
|
||||
systemMessage: systemMessage,
|
||||
toolConfig: toolConfig,
|
||||
tools: tools,
|
||||
safetySettings: safetySettings,
|
||||
responseMimeType: responseMimeType)
|
||||
{
|
||||
}
|
||||
|
||||
public string Name { get; }
|
||||
|
||||
public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var request = BuildChatRequest(messages, options);
|
||||
var response = await this.client.GenerateContentAsync(request, cancellationToken: cancellationToken).ConfigureAwait(false);
|
||||
|
||||
return MessageEnvelope.Create(response, this.Name);
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var request = BuildChatRequest(messages, options);
|
||||
var response = this.client.GenerateContentStreamAsync(request);
|
||||
|
||||
await foreach (var item in response.WithCancellation(cancellationToken).ConfigureAwait(false))
|
||||
{
|
||||
yield return MessageEnvelope.Create(item, this.Name);
|
||||
}
|
||||
}
|
||||
|
||||
private GenerateContentRequest BuildChatRequest(IEnumerable<IMessage> messages, GenerateReplyOptions? options)
|
||||
{
|
||||
var geminiMessages = messages.Select(m => m switch
|
||||
{
|
||||
IMessage<Content> contentMessage => contentMessage.Content,
|
||||
_ => throw new NotSupportedException($"Message type {m.GetType()} is not supported.")
|
||||
});
|
||||
|
||||
// there are several rules applies to the messages that can be sent to Gemini in a multi-turn chat
|
||||
// - The first message must be from the user or function
|
||||
// - The (user|model) roles must alternate e.g. (user, model, user, model, ...)
|
||||
// - The last message must be from the user or function
|
||||
|
||||
// check if the first message is from the user
|
||||
if (geminiMessages.FirstOrDefault()?.Role != "user" && geminiMessages.FirstOrDefault()?.Role != "function")
|
||||
{
|
||||
throw new ArgumentException("The first message must be from the user or function", nameof(messages));
|
||||
}
|
||||
|
||||
// check if the last message is from the user
|
||||
if (geminiMessages.LastOrDefault()?.Role != "user" && geminiMessages.LastOrDefault()?.Role != "function")
|
||||
{
|
||||
throw new ArgumentException("The last message must be from the user or function", nameof(messages));
|
||||
}
|
||||
|
||||
// merge continuous messages with the same role into one message
|
||||
var mergedMessages = geminiMessages.Aggregate(new List<Content>(), (acc, message) =>
|
||||
{
|
||||
if (acc.Count == 0 || acc.Last().Role != message.Role)
|
||||
{
|
||||
acc.Add(message);
|
||||
}
|
||||
else
|
||||
{
|
||||
acc.Last().Parts.AddRange(message.Parts);
|
||||
}
|
||||
|
||||
return acc;
|
||||
});
|
||||
|
||||
var systemMessage = this.systemMessage switch
|
||||
{
|
||||
null => null,
|
||||
string message => new Content
|
||||
{
|
||||
Parts = { new[] { new Part { Text = message } } },
|
||||
Role = "system_instruction"
|
||||
}
|
||||
};
|
||||
|
||||
List<Tool> tools = this.tools?.ToList() ?? new List<Tool>();
|
||||
|
||||
var request = new GenerateContentRequest()
|
||||
{
|
||||
Contents = { mergedMessages },
|
||||
SystemInstruction = systemMessage,
|
||||
Model = this.model,
|
||||
GenerationConfig = new GenerationConfig
|
||||
{
|
||||
StopSequences = { options?.StopSequence ?? Enumerable.Empty<string>() },
|
||||
ResponseMimeType = this.responseMimeType,
|
||||
CandidateCount = 1,
|
||||
},
|
||||
};
|
||||
|
||||
if (this.toolConfig is not null)
|
||||
{
|
||||
request.ToolConfig = this.toolConfig;
|
||||
}
|
||||
|
||||
if (this.safetySettings is not null)
|
||||
{
|
||||
request.SafetySettings.Add(this.safetySettings);
|
||||
}
|
||||
|
||||
if (options?.MaxToken.HasValue is true)
|
||||
{
|
||||
request.GenerationConfig.MaxOutputTokens = options.MaxToken.Value;
|
||||
}
|
||||
|
||||
if (options?.Temperature.HasValue is true)
|
||||
{
|
||||
request.GenerationConfig.Temperature = options.Temperature.Value;
|
||||
}
|
||||
|
||||
if (options?.Functions is { Length: > 0 })
|
||||
{
|
||||
foreach (var function in options.Functions)
|
||||
{
|
||||
tools.Add(new Tool
|
||||
{
|
||||
FunctionDeclarations = { function.ToFunctionDeclaration() },
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// merge tools into one tool
|
||||
// because multipe tools are currently not supported by Gemini
|
||||
// see https://github.com/googleapis/python-aiplatform/issues/3771
|
||||
var aggregatedTool = new Tool
|
||||
{
|
||||
FunctionDeclarations = { tools.SelectMany(t => t.FunctionDeclarations) },
|
||||
};
|
||||
|
||||
if (aggregatedTool is { FunctionDeclarations: { Count: > 0 } })
|
||||
{
|
||||
request.Tools.Add(aggregatedTool);
|
||||
}
|
||||
|
||||
return request;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GoogleGeminiClient.cs
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Net.Http;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Google.Cloud.AIPlatform.V1;
|
||||
using Google.Protobuf;
|
||||
|
||||
namespace AutoGen.Gemini;
|
||||
|
||||
public class GoogleGeminiClient : IGeminiClient
|
||||
{
|
||||
private readonly string apiKey;
|
||||
private const string endpoint = "https://generativelanguage.googleapis.com/v1beta";
|
||||
private readonly HttpClient httpClient = new();
|
||||
private const string generateContentPath = "models/{0}:generateContent";
|
||||
private const string generateContentStreamPath = "models/{0}:streamGenerateContent";
|
||||
|
||||
public GoogleGeminiClient(HttpClient httpClient, string apiKey)
|
||||
{
|
||||
this.apiKey = apiKey;
|
||||
this.httpClient = httpClient;
|
||||
}
|
||||
|
||||
public GoogleGeminiClient(string apiKey)
|
||||
{
|
||||
this.apiKey = apiKey;
|
||||
}
|
||||
|
||||
public async Task<GenerateContentResponse> GenerateContentAsync(GenerateContentRequest request, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var path = string.Format(generateContentPath, request.Model);
|
||||
var url = $"{endpoint}/{path}?key={apiKey}";
|
||||
|
||||
var httpContent = new StringContent(JsonFormatter.Default.Format(request), System.Text.Encoding.UTF8, "application/json");
|
||||
var response = await httpClient.PostAsync(url, httpContent, cancellationToken);
|
||||
|
||||
if (!response.IsSuccessStatusCode)
|
||||
{
|
||||
throw new Exception($"Failed to generate content. Status code: {response.StatusCode}");
|
||||
}
|
||||
|
||||
var json = await response.Content.ReadAsStringAsync();
|
||||
return GenerateContentResponse.Parser.ParseJson(json);
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(GenerateContentRequest request)
|
||||
{
|
||||
var path = string.Format(generateContentStreamPath, request.Model);
|
||||
var url = $"{endpoint}/{path}?key={apiKey}&alt=sse";
|
||||
|
||||
var httpContent = new StringContent(JsonFormatter.Default.Format(request), System.Text.Encoding.UTF8, "application/json");
|
||||
var requestMessage = new HttpRequestMessage(HttpMethod.Post, url)
|
||||
{
|
||||
Content = httpContent
|
||||
};
|
||||
|
||||
var response = await httpClient.SendAsync(requestMessage, HttpCompletionOption.ResponseHeadersRead);
|
||||
|
||||
if (!response.IsSuccessStatusCode)
|
||||
{
|
||||
throw new Exception($"Failed to generate content. Status code: {response.StatusCode}");
|
||||
}
|
||||
|
||||
var stream = await response.Content.ReadAsStreamAsync();
|
||||
var jp = new JsonParser(JsonParser.Settings.Default.WithIgnoreUnknownFields(true));
|
||||
using var streamReader = new System.IO.StreamReader(stream);
|
||||
while (!streamReader.EndOfStream)
|
||||
{
|
||||
var json = await streamReader.ReadLineAsync();
|
||||
if (string.IsNullOrWhiteSpace(json))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
json = json.Substring("data:".Length).Trim();
|
||||
yield return jp.Parse<GenerateContentResponse>(json);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// IVertexGeminiClient.cs
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Google.Cloud.AIPlatform.V1;
|
||||
|
||||
namespace AutoGen.Gemini;
|
||||
|
||||
public interface IGeminiClient
|
||||
{
|
||||
Task<GenerateContentResponse> GenerateContentAsync(GenerateContentRequest request, CancellationToken cancellationToken = default);
|
||||
IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(GenerateContentRequest request);
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GeminiAgentExtension.cs
|
||||
|
||||
using AutoGen.Core;
|
||||
|
||||
namespace AutoGen.Gemini.Middleware;
|
||||
|
||||
public static class GeminiAgentExtension
|
||||
{
|
||||
|
||||
/// <summary>
|
||||
/// Register an <see cref="GeminiMessageConnector"/> to the <see cref="GeminiChatAgent"/>
|
||||
/// </summary>
|
||||
/// <param name="connector">the connector to use. If null, a new instance of <see cref="GeminiMessageConnector"/> will be created.</param>
|
||||
public static MiddlewareStreamingAgent<GeminiChatAgent> RegisterMessageConnector(
|
||||
this GeminiChatAgent agent, GeminiMessageConnector? connector = null)
|
||||
{
|
||||
if (connector == null)
|
||||
{
|
||||
connector = new GeminiMessageConnector();
|
||||
}
|
||||
|
||||
return agent.RegisterStreamingMiddleware(connector);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Register an <see cref="GeminiMessageConnector"/> to the <see cref="MiddlewareAgent{T}"/> where T is <see cref="GeminiChatAgent"/>
|
||||
/// </summary>
|
||||
/// <param name="connector">the connector to use. If null, a new instance of <see cref="GeminiMessageConnector"/> will be created.</param>
|
||||
public static MiddlewareStreamingAgent<GeminiChatAgent> RegisterMessageConnector(
|
||||
this MiddlewareStreamingAgent<GeminiChatAgent> agent, GeminiMessageConnector? connector = null)
|
||||
{
|
||||
if (connector == null)
|
||||
{
|
||||
connector = new GeminiMessageConnector();
|
||||
}
|
||||
|
||||
return agent.RegisterStreamingMiddleware(connector);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,483 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GeminiMessageConnector.cs
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Nodes;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using AutoGen.Core;
|
||||
using Google.Cloud.AIPlatform.V1;
|
||||
using Google.Protobuf;
|
||||
using Google.Protobuf.WellKnownTypes;
|
||||
using static Google.Cloud.AIPlatform.V1.Candidate.Types;
|
||||
using IMessage = AutoGen.Core.IMessage;
|
||||
|
||||
namespace AutoGen.Gemini.Middleware;
|
||||
|
||||
public class GeminiMessageConnector : IStreamingMiddleware
|
||||
{
|
||||
/// <summary>
|
||||
/// if true, the connector will throw an exception if it encounters an unsupport message type.
|
||||
/// Otherwise, it will ignore processing the message and return the message as is.
|
||||
/// </summary>
|
||||
private readonly bool strictMode;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="GeminiMessageConnector"/> class.
|
||||
/// </summary>
|
||||
/// <param name="strictMode">whether to throw an exception if it encounters an unsupport message type.
|
||||
/// If true, the connector will throw an exception if it encounters an unsupport message type.
|
||||
/// If false, it will ignore processing the message and return the message as is.</param>
|
||||
public GeminiMessageConnector(bool strictMode = false)
|
||||
{
|
||||
this.strictMode = strictMode;
|
||||
}
|
||||
|
||||
public string Name => nameof(GeminiMessageConnector);
|
||||
|
||||
public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var messages = ProcessMessage(context.Messages, agent);
|
||||
|
||||
var bucket = new List<GenerateContentResponse>();
|
||||
|
||||
await foreach (var reply in agent.GenerateStreamingReplyAsync(messages, context.Options, cancellationToken))
|
||||
{
|
||||
if (reply is Core.IMessage<GenerateContentResponse> m)
|
||||
{
|
||||
// if m.Content is empty and stop reason is Stop, ignore the message
|
||||
if (m.Content.Candidates.Count == 1 && m.Content.Candidates[0].Content.Parts.Count == 1 && m.Content.Candidates[0].Content.Parts[0].DataCase == Part.DataOneofCase.Text)
|
||||
{
|
||||
var text = m.Content.Candidates[0].Content.Parts[0].Text;
|
||||
var stopReason = m.Content.Candidates[0].FinishReason;
|
||||
if (string.IsNullOrEmpty(text) && stopReason == FinishReason.Stop)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
bucket.Add(m.Content);
|
||||
|
||||
yield return PostProcessStreamingMessage(m.Content, agent);
|
||||
}
|
||||
else if (strictMode)
|
||||
{
|
||||
throw new InvalidOperationException($"Unsupported message type: {reply.GetType()}");
|
||||
}
|
||||
else
|
||||
{
|
||||
yield return reply;
|
||||
}
|
||||
|
||||
// aggregate the message updates from bucket into a single message
|
||||
if (bucket is { Count: > 0 })
|
||||
{
|
||||
var isTextMessageUpdates = bucket.All(m => m.Candidates.Count == 1 && m.Candidates[0].Content.Parts.Count == 1 && m.Candidates[0].Content.Parts[0].DataCase == Part.DataOneofCase.Text);
|
||||
var isFunctionCallUpdates = bucket.Any(m => m.Candidates.Count == 1 && m.Candidates[0].Content.Parts.Count == 1 && m.Candidates[0].Content.Parts[0].DataCase == Part.DataOneofCase.FunctionCall);
|
||||
if (isTextMessageUpdates)
|
||||
{
|
||||
var text = string.Join(string.Empty, bucket.Select(m => m.Candidates[0].Content.Parts[0].Text));
|
||||
var textMessage = new TextMessage(Role.Assistant, text, agent.Name);
|
||||
|
||||
yield return textMessage;
|
||||
}
|
||||
else if (isFunctionCallUpdates)
|
||||
{
|
||||
var functionCallParts = bucket.Where(m => m.Candidates.Count == 1 && m.Candidates[0].Content.Parts.Count == 1 && m.Candidates[0].Content.Parts[0].DataCase == Part.DataOneofCase.FunctionCall)
|
||||
.Select(m => m.Candidates[0].Content.Parts[0]).ToList();
|
||||
|
||||
var toolCalls = new List<ToolCall>();
|
||||
foreach (var part in functionCallParts)
|
||||
{
|
||||
var fc = part.FunctionCall;
|
||||
var toolCall = new ToolCall(fc.Name, fc.Args.ToString());
|
||||
|
||||
toolCalls.Add(toolCall);
|
||||
}
|
||||
|
||||
var toolCallMessage = new ToolCallMessage(toolCalls, agent.Name);
|
||||
|
||||
yield return toolCallMessage;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("The response should contain either text or tool calls.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var messages = ProcessMessage(context.Messages, agent);
|
||||
var reply = await agent.GenerateReplyAsync(messages, context.Options, cancellationToken);
|
||||
|
||||
return reply switch
|
||||
{
|
||||
Core.IMessage<GenerateContentResponse> m => PostProcessMessage(m.Content, agent),
|
||||
_ when strictMode => throw new InvalidOperationException($"Unsupported message type: {reply.GetType()}"),
|
||||
_ => reply,
|
||||
};
|
||||
}
|
||||
|
||||
private IMessage PostProcessStreamingMessage(GenerateContentResponse m, IAgent agent)
|
||||
{
|
||||
this.ValidateGenerateContentResponse(m);
|
||||
|
||||
var candidate = m.Candidates[0];
|
||||
var parts = candidate.Content.Parts;
|
||||
|
||||
if (parts.Count == 1 && parts[0].DataCase == Part.DataOneofCase.Text)
|
||||
{
|
||||
var content = parts[0].Text;
|
||||
return new TextMessageUpdate(Role.Assistant, content, agent.Name);
|
||||
}
|
||||
else
|
||||
{
|
||||
var toolCalls = new List<ToolCall>();
|
||||
foreach (var part in parts)
|
||||
{
|
||||
if (part.DataCase == Part.DataOneofCase.FunctionCall)
|
||||
{
|
||||
var fc = part.FunctionCall;
|
||||
var toolCall = new ToolCall(fc.Name, fc.Args.ToString());
|
||||
|
||||
toolCalls.Add(toolCall);
|
||||
}
|
||||
}
|
||||
|
||||
if (toolCalls.Count > 0)
|
||||
{
|
||||
var toolCallMessage = new ToolCallMessage(toolCalls, agent.Name);
|
||||
return toolCallMessage;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("The response should contain either text or tool calls.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private IMessage PostProcessMessage(GenerateContentResponse m, IAgent agent)
|
||||
{
|
||||
this.ValidateGenerateContentResponse(m);
|
||||
var candidate = m.Candidates[0];
|
||||
var parts = candidate.Content.Parts;
|
||||
|
||||
if (parts.Count == 1 && parts[0].DataCase == Part.DataOneofCase.Text)
|
||||
{
|
||||
var content = parts[0].Text;
|
||||
return new TextMessage(Role.Assistant, content, agent.Name);
|
||||
}
|
||||
else
|
||||
{
|
||||
var toolCalls = new List<ToolCall>();
|
||||
foreach (var part in parts)
|
||||
{
|
||||
if (part.DataCase == Part.DataOneofCase.FunctionCall)
|
||||
{
|
||||
var fc = part.FunctionCall;
|
||||
var toolCall = new ToolCall(fc.Name, fc.Args.ToString());
|
||||
|
||||
toolCalls.Add(toolCall);
|
||||
}
|
||||
}
|
||||
|
||||
if (toolCalls.Count > 0)
|
||||
{
|
||||
var toolCallMessage = new ToolCallMessage(toolCalls, agent.Name);
|
||||
return toolCallMessage;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("The response should contain either text or tool calls.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private IEnumerable<IMessage> ProcessMessage(IEnumerable<IMessage> messages, IAgent agent)
|
||||
{
|
||||
return messages.SelectMany(m =>
|
||||
{
|
||||
if (m is Core.IMessage<Content> messageEnvelope)
|
||||
{
|
||||
return [m];
|
||||
}
|
||||
else
|
||||
{
|
||||
return m switch
|
||||
{
|
||||
TextMessage textMessage => ProcessTextMessage(textMessage, agent),
|
||||
ImageMessage imageMessage => ProcessImageMessage(imageMessage, agent),
|
||||
MultiModalMessage multiModalMessage => ProcessMultiModalMessage(multiModalMessage, agent),
|
||||
ToolCallMessage toolCallMessage => ProcessToolCallMessage(toolCallMessage, agent),
|
||||
ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage, agent),
|
||||
ToolCallAggregateMessage toolCallAggregateMessage => ProcessToolCallAggregateMessage(toolCallAggregateMessage, agent),
|
||||
_ when strictMode => throw new InvalidOperationException($"Unsupported message type: {m.GetType()}"),
|
||||
_ => [m],
|
||||
};
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private IEnumerable<IMessage> ProcessToolCallAggregateMessage(ToolCallAggregateMessage toolCallAggregateMessage, IAgent agent)
|
||||
{
|
||||
var parseAsUser = ShouldParseAsUser(toolCallAggregateMessage, agent);
|
||||
if (parseAsUser)
|
||||
{
|
||||
var content = toolCallAggregateMessage.GetContent();
|
||||
|
||||
if (content is string str)
|
||||
{
|
||||
var textMessage = new TextMessage(Role.User, str, toolCallAggregateMessage.From);
|
||||
|
||||
return ProcessTextMessage(textMessage, agent);
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
else
|
||||
{
|
||||
var toolCallContents = ProcessToolCallMessage(toolCallAggregateMessage.Message1, agent);
|
||||
var toolCallResultContents = ProcessToolCallResultMessage(toolCallAggregateMessage.Message2, agent);
|
||||
|
||||
return toolCallContents.Concat(toolCallResultContents);
|
||||
}
|
||||
}
|
||||
|
||||
private void ValidateGenerateContentResponse(GenerateContentResponse response)
|
||||
{
|
||||
if (response.Candidates.Count != 1)
|
||||
{
|
||||
throw new InvalidOperationException("The response should contain exactly one candidate.");
|
||||
}
|
||||
|
||||
var candidate = response.Candidates[0];
|
||||
if (candidate.Content is null)
|
||||
{
|
||||
var finishReason = candidate.FinishReason;
|
||||
var finishMessage = candidate.FinishMessage;
|
||||
|
||||
throw new InvalidOperationException($"The response should contain content but the content is empty. FinishReason: {finishReason}, FinishMessage: {finishMessage}");
|
||||
}
|
||||
}
|
||||
|
||||
private IEnumerable<IMessage> ProcessToolCallResultMessage(ToolCallResultMessage toolCallResultMessage, IAgent agent)
|
||||
{
|
||||
var functionCallResultParts = new List<Part>();
|
||||
foreach (var toolCallResult in toolCallResultMessage.ToolCalls)
|
||||
{
|
||||
if (toolCallResult.Result is null)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// if result is already a json object, use it as is
|
||||
var json = toolCallResult.Result;
|
||||
try
|
||||
{
|
||||
JsonNode.Parse(json);
|
||||
}
|
||||
catch (JsonException)
|
||||
{
|
||||
// if the result is not a json object, wrap it in a json object
|
||||
var result = new { result = json };
|
||||
json = JsonSerializer.Serialize(result);
|
||||
}
|
||||
var part = new Part
|
||||
{
|
||||
FunctionResponse = new FunctionResponse
|
||||
{
|
||||
Name = toolCallResult.FunctionName,
|
||||
Response = Struct.Parser.ParseJson(json),
|
||||
}
|
||||
};
|
||||
|
||||
functionCallResultParts.Add(part);
|
||||
}
|
||||
|
||||
var content = new Content
|
||||
{
|
||||
Parts = { functionCallResultParts },
|
||||
Role = "function",
|
||||
};
|
||||
|
||||
return [MessageEnvelope.Create(content, toolCallResultMessage.From)];
|
||||
}
|
||||
|
||||
private IEnumerable<IMessage> ProcessToolCallMessage(ToolCallMessage toolCallMessage, IAgent agent)
|
||||
{
|
||||
var shouldParseAsUser = ShouldParseAsUser(toolCallMessage, agent);
|
||||
if (strictMode && shouldParseAsUser)
|
||||
{
|
||||
throw new InvalidOperationException("ToolCallMessage is not supported as user role in Gemini.");
|
||||
}
|
||||
|
||||
var functionCallParts = new List<Part>();
|
||||
foreach (var toolCall in toolCallMessage.ToolCalls)
|
||||
{
|
||||
var part = new Part
|
||||
{
|
||||
FunctionCall = new FunctionCall
|
||||
{
|
||||
Name = toolCall.FunctionName,
|
||||
Args = Struct.Parser.ParseJson(toolCall.FunctionArguments),
|
||||
}
|
||||
};
|
||||
|
||||
functionCallParts.Add(part);
|
||||
}
|
||||
var content = new Content
|
||||
{
|
||||
Parts = { functionCallParts },
|
||||
Role = "model"
|
||||
};
|
||||
|
||||
return [MessageEnvelope.Create(content, toolCallMessage.From)];
|
||||
}
|
||||
|
||||
private IEnumerable<IMessage> ProcessMultiModalMessage(MultiModalMessage multiModalMessage, IAgent agent)
|
||||
{
|
||||
var parts = new List<Part>();
|
||||
foreach (var message in multiModalMessage.Content)
|
||||
{
|
||||
if (message is TextMessage textMessage)
|
||||
{
|
||||
parts.Add(new Part { Text = textMessage.Content });
|
||||
}
|
||||
else if (message is ImageMessage imageMessage)
|
||||
{
|
||||
parts.Add(CreateImagePart(imageMessage));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException($"Unsupported message type: {message.GetType()}");
|
||||
}
|
||||
}
|
||||
|
||||
var shouldParseAsUser = ShouldParseAsUser(multiModalMessage, agent);
|
||||
|
||||
if (strictMode && !shouldParseAsUser)
|
||||
{
|
||||
// image message is not supported as model role in Gemini
|
||||
throw new InvalidOperationException("Image message is not supported as model role in Gemini.");
|
||||
}
|
||||
|
||||
var content = new Content
|
||||
{
|
||||
Parts = { parts },
|
||||
Role = shouldParseAsUser ? "user" : "model",
|
||||
};
|
||||
|
||||
return [MessageEnvelope.Create(content, multiModalMessage.From)];
|
||||
}
|
||||
|
||||
private IEnumerable<IMessage> ProcessTextMessage(TextMessage textMessage, IAgent agent)
|
||||
{
|
||||
if (textMessage.Role == Role.System)
|
||||
{
|
||||
// there are only user | model role in Gemini
|
||||
// if the role is system and the strict mode is enabled, throw an exception
|
||||
if (strictMode)
|
||||
{
|
||||
throw new InvalidOperationException("System role is not supported in Gemini.");
|
||||
}
|
||||
|
||||
// if strict mode is not enabled, parse the message as a user message
|
||||
var content = new Content
|
||||
{
|
||||
Parts = { new[] { new Part { Text = textMessage.Content } } },
|
||||
Role = "user",
|
||||
};
|
||||
|
||||
return [MessageEnvelope.Create(content, textMessage.From)];
|
||||
}
|
||||
|
||||
var shouldParseAsUser = ShouldParseAsUser(textMessage, agent);
|
||||
|
||||
if (shouldParseAsUser)
|
||||
{
|
||||
var content = new Content
|
||||
{
|
||||
Parts = { new[] { new Part { Text = textMessage.Content } } },
|
||||
Role = "user",
|
||||
};
|
||||
|
||||
return [MessageEnvelope.Create(content, textMessage.From)];
|
||||
}
|
||||
else
|
||||
{
|
||||
var content = new Content
|
||||
{
|
||||
Parts = { new[] { new Part { Text = textMessage.Content } } },
|
||||
Role = "model",
|
||||
};
|
||||
|
||||
return [MessageEnvelope.Create(content, textMessage.From)];
|
||||
}
|
||||
}
|
||||
|
||||
private IEnumerable<IMessage> ProcessImageMessage(ImageMessage imageMessage, IAgent agent)
|
||||
{
|
||||
var imagePart = CreateImagePart(imageMessage);
|
||||
var shouldParseAsUser = ShouldParseAsUser(imageMessage, agent);
|
||||
|
||||
if (strictMode && !shouldParseAsUser)
|
||||
{
|
||||
// image message is not supported as model role in Gemini
|
||||
throw new InvalidOperationException("Image message is not supported as model role in Gemini.");
|
||||
}
|
||||
|
||||
var content = new Content
|
||||
{
|
||||
Parts = { imagePart },
|
||||
Role = shouldParseAsUser ? "user" : "model",
|
||||
};
|
||||
|
||||
return [MessageEnvelope.Create(content, imageMessage.From)];
|
||||
}
|
||||
|
||||
private Part CreateImagePart(ImageMessage message)
|
||||
{
|
||||
if (message.Url is string url)
|
||||
{
|
||||
return new Part
|
||||
{
|
||||
FileData = new FileData
|
||||
{
|
||||
FileUri = url,
|
||||
MimeType = message.MimeType
|
||||
}
|
||||
};
|
||||
}
|
||||
else if (message.Data is BinaryData data)
|
||||
{
|
||||
return new Part
|
||||
{
|
||||
InlineData = new Blob
|
||||
{
|
||||
MimeType = message.MimeType,
|
||||
Data = ByteString.CopyFrom(data.ToArray()),
|
||||
}
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("Invalid ImageMessage, the data or url must be provided");
|
||||
}
|
||||
}
|
||||
|
||||
private bool ShouldParseAsUser(IMessage message, IAgent agent)
|
||||
{
|
||||
return message switch
|
||||
{
|
||||
TextMessage textMessage => (textMessage.Role == Role.User && textMessage.From is null)
|
||||
|| (textMessage.From != agent.Name),
|
||||
_ => message.From != agent.Name,
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// IGeminiClient.cs
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Google.Cloud.AIPlatform.V1;
|
||||
|
||||
namespace AutoGen.Gemini;
|
||||
|
||||
internal class VertexGeminiClient : IGeminiClient
|
||||
{
|
||||
private readonly PredictionServiceClient client;
|
||||
public VertexGeminiClient(PredictionServiceClient client)
|
||||
{
|
||||
this.client = client;
|
||||
}
|
||||
|
||||
public VertexGeminiClient(string location)
|
||||
{
|
||||
PredictionServiceClientBuilder builder = new()
|
||||
{
|
||||
Endpoint = $"{location}-aiplatform.googleapis.com",
|
||||
};
|
||||
|
||||
this.client = builder.Build();
|
||||
}
|
||||
|
||||
public Task<GenerateContentResponse> GenerateContentAsync(GenerateContentRequest request, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return client.GenerateContentAsync(request, cancellationToken);
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(GenerateContentRequest request)
|
||||
{
|
||||
return client.StreamGenerateContent(request).GetResponseStream();
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
using System.Text;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
using AutoGen.Anthropic.DTO;
|
||||
|
@ -43,7 +43,7 @@ public class AnthropicClientTests
|
|||
request.Model = AnthropicConstants.Claude3Haiku;
|
||||
request.Stream = true;
|
||||
request.MaxTokens = 500;
|
||||
request.SystemMessage = "You are a helpful assistant that convert input to json object";
|
||||
request.SystemMessage = "You are a helpful assistant that convert input to json object, use JSON format.";
|
||||
request.Messages = new List<ChatMessage>()
|
||||
{
|
||||
new("user", "name: John, age: 41, email: g123456@gmail.com")
|
||||
|
|
|
@ -6,16 +6,9 @@
|
|||
<IsPackable>false</IsPackable>
|
||||
<GenerateDocumentationFile>True</GenerateDocumentationFile>
|
||||
<RootNamespace>AutoGen.Anthropic.Tests</RootNamespace>
|
||||
<IsTestProject>True</IsTestProject>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
|
||||
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\src\AutoGen.Anthropic\AutoGen.Anthropic.csproj" />
|
||||
<ProjectReference Include="..\AutoGen.Tests\AutoGen.Tests.csproj" />
|
||||
|
|
|
@ -4,18 +4,10 @@
|
|||
<TargetFramework>$(TestTargetFramework)</TargetFramework>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<IsPackable>false</IsPackable>
|
||||
<IsTestProject>True</IsTestProject>
|
||||
<GenerateDocumentationFile>True</GenerateDocumentationFile>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestVersion)" />
|
||||
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
|
||||
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
|
||||
<ProjectReference Include="..\AutoGen.Tests\AutoGen.Tests.csproj" />
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
{
|
||||
"name": "GetWeatherAsync",
|
||||
"description": "Get weather for a city.",
|
||||
"parameters": {
|
||||
"type": "OBJECT",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "STRING",
|
||||
"description": "city",
|
||||
"title": "city"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"city"
|
||||
]
|
||||
}
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<OutputType>Exe</OutputType>
|
||||
<TargetFramework>$(TestTargetFramework)</TargetFramework>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<Nullable>enable</Nullable>
|
||||
<GenerateDocumentationFile>True</GenerateDocumentationFile>
|
||||
<IsTestProject>True</IsTestProject>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\sample\AutoGen.Gemini.Sample\AutoGen.Gemini.Sample.csproj" />
|
||||
<ProjectReference Include="..\..\src\AutoGen.Gemini\AutoGen.Gemini.csproj" />
|
||||
<ProjectReference Include="..\AutoGen.Tests\AutoGen.Tests.csproj" />
|
||||
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
|
@ -0,0 +1,27 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// FunctionContractExtensionTests.cs
|
||||
|
||||
using ApprovalTests;
|
||||
using ApprovalTests.Namers;
|
||||
using ApprovalTests.Reporters;
|
||||
using AutoGen.Gemini.Extension;
|
||||
using Google.Protobuf;
|
||||
using Xunit;
|
||||
|
||||
namespace AutoGen.Gemini.Tests;
|
||||
|
||||
public class FunctionContractExtensionTests
|
||||
{
|
||||
private readonly Functions functions = new Functions();
|
||||
[Fact]
|
||||
[UseReporter(typeof(DiffReporter))]
|
||||
[UseApprovalSubdirectory("ApprovalTests")]
|
||||
public void ItGenerateGetWeatherToolTest()
|
||||
{
|
||||
var contract = functions.GetWeatherAsyncFunctionContract;
|
||||
var tool = contract.ToFunctionDeclaration();
|
||||
var formatter = new JsonFormatter(JsonFormatter.Settings.Default.WithIndentation(" "));
|
||||
var json = formatter.Format(tool);
|
||||
Approvals.Verify(json);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Functions.cs
|
||||
|
||||
using AutoGen.Core;
|
||||
|
||||
namespace AutoGen.Gemini.Tests;
|
||||
|
||||
public partial class Functions
|
||||
{
|
||||
/// <summary>
|
||||
/// Get weather for a city.
|
||||
/// </summary>
|
||||
/// <param name="city">city</param>
|
||||
/// <returns>weather</returns>
|
||||
[Function]
|
||||
public async Task<string> GetWeatherAsync(string city)
|
||||
{
|
||||
return await Task.FromResult($"The weather in {city} is sunny.");
|
||||
}
|
||||
|
||||
[Function]
|
||||
public async Task<string> GetMovies(string location, string description)
|
||||
{
|
||||
var movies = new List<string> { "Barbie", "Spiderman", "Batman" };
|
||||
|
||||
return await Task.FromResult($"Movies playing in {location} based on {description} are: {string.Join(", ", movies)}");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,311 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GeminiAgentTests.cs
|
||||
|
||||
using AutoGen.Tests;
|
||||
using Google.Cloud.AIPlatform.V1;
|
||||
using AutoGen.Core;
|
||||
using FluentAssertions;
|
||||
using AutoGen.Gemini.Extension;
|
||||
using static Google.Cloud.AIPlatform.V1.Part;
|
||||
using Xunit.Abstractions;
|
||||
using AutoGen.Gemini.Middleware;
|
||||
namespace AutoGen.Gemini.Tests;
|
||||
|
||||
public class GeminiAgentTests
|
||||
{
|
||||
private readonly Functions functions = new Functions();
|
||||
private readonly ITestOutputHelper _output;
|
||||
|
||||
public GeminiAgentTests(ITestOutputHelper output)
|
||||
{
|
||||
_output = output;
|
||||
}
|
||||
|
||||
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
|
||||
public async Task VertexGeminiAgentGenerateReplyForTextContentAsync()
|
||||
{
|
||||
var location = "us-central1";
|
||||
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ?? throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
|
||||
var model = "gemini-1.5-flash-001";
|
||||
|
||||
var textContent = new Content
|
||||
{
|
||||
Role = "user",
|
||||
Parts =
|
||||
{
|
||||
new Part
|
||||
{
|
||||
Text = "Hello",
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
var agent = new GeminiChatAgent(
|
||||
name: "assistant",
|
||||
model: model,
|
||||
project: project,
|
||||
location: location,
|
||||
systemMessage: "You are a helpful AI assistant");
|
||||
var message = MessageEnvelope.Create(textContent, from: agent.Name);
|
||||
|
||||
var completion = await agent.SendAsync(message);
|
||||
|
||||
completion.Should().BeOfType<MessageEnvelope<GenerateContentResponse>>();
|
||||
completion.From.Should().Be(agent.Name);
|
||||
|
||||
var response = ((MessageEnvelope<GenerateContentResponse>)completion).Content;
|
||||
response.Should().NotBeNull();
|
||||
response.Candidates.Count.Should().BeGreaterThan(0);
|
||||
response.Candidates[0].Content.Parts[0].Text.Should().NotBeNullOrEmpty();
|
||||
}
|
||||
|
||||
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
|
||||
public async Task VertexGeminiAgentGenerateStreamingReplyForTextContentAsync()
|
||||
{
|
||||
var location = "us-central1";
|
||||
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ?? throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
|
||||
var model = "gemini-1.5-flash-001";
|
||||
|
||||
var textContent = new Content
|
||||
{
|
||||
Role = "user",
|
||||
Parts =
|
||||
{
|
||||
new Part
|
||||
{
|
||||
Text = "Hello",
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
var agent = new GeminiChatAgent(
|
||||
name: "assistant",
|
||||
model: model,
|
||||
project: project,
|
||||
location: location,
|
||||
systemMessage: "You are a helpful AI assistant");
|
||||
var message = MessageEnvelope.Create(textContent, from: agent.Name);
|
||||
|
||||
var completion = agent.GenerateStreamingReplyAsync([message]);
|
||||
var chunks = new List<IStreamingMessage>();
|
||||
IStreamingMessage finalReply = null!;
|
||||
|
||||
await foreach (var item in completion)
|
||||
{
|
||||
item.Should().NotBeNull();
|
||||
item.From.Should().Be(agent.Name);
|
||||
var streamingMessage = (IMessage<GenerateContentResponse>)item;
|
||||
streamingMessage.Content.Candidates.Should().NotBeNullOrEmpty();
|
||||
chunks.Add(item);
|
||||
finalReply = item;
|
||||
}
|
||||
|
||||
chunks.Count.Should().BeGreaterThan(0);
|
||||
finalReply.Should().NotBeNull();
|
||||
finalReply.Should().BeOfType<MessageEnvelope<GenerateContentResponse>>();
|
||||
var response = ((MessageEnvelope<GenerateContentResponse>)finalReply).Content;
|
||||
response.UsageMetadata.CandidatesTokenCount.Should().BeGreaterThan(0);
|
||||
}
|
||||
|
||||
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
|
||||
public async Task VertexGeminiAgentGenerateReplyWithToolsAsync()
|
||||
{
|
||||
var location = "us-central1";
|
||||
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ?? throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
|
||||
var model = "gemini-1.5-flash-001";
|
||||
var tools = new Tool[]
|
||||
{
|
||||
new Tool
|
||||
{
|
||||
FunctionDeclarations = {
|
||||
functions.GetWeatherAsyncFunctionContract.ToFunctionDeclaration(),
|
||||
},
|
||||
},
|
||||
new Tool
|
||||
{
|
||||
FunctionDeclarations =
|
||||
{
|
||||
functions.GetMoviesFunctionContract.ToFunctionDeclaration(),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
var textContent = new Content
|
||||
{
|
||||
Role = "user",
|
||||
Parts =
|
||||
{
|
||||
new Part
|
||||
{
|
||||
Text = "what's the weather in seattle",
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
var agent = new GeminiChatAgent(
|
||||
name: "assistant",
|
||||
model: model,
|
||||
project: project,
|
||||
location: location,
|
||||
systemMessage: "You are a helpful AI assistant",
|
||||
tools: tools,
|
||||
toolConfig: new ToolConfig()
|
||||
{
|
||||
FunctionCallingConfig = new FunctionCallingConfig()
|
||||
{
|
||||
Mode = FunctionCallingConfig.Types.Mode.Auto,
|
||||
}
|
||||
});
|
||||
|
||||
var message = MessageEnvelope.Create(textContent, from: agent.Name);
|
||||
|
||||
var completion = await agent.SendAsync(message);
|
||||
|
||||
completion.Should().BeOfType<MessageEnvelope<GenerateContentResponse>>();
|
||||
completion.From.Should().Be(agent.Name);
|
||||
|
||||
var response = ((MessageEnvelope<GenerateContentResponse>)completion).Content;
|
||||
response.Should().NotBeNull();
|
||||
response.Candidates.Count.Should().BeGreaterThan(0);
|
||||
response.Candidates[0].Content.Parts[0].DataCase.Should().Be(DataOneofCase.FunctionCall);
|
||||
}
|
||||
|
||||
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
|
||||
public async Task VertexGeminiAgentGenerateStreamingReplyWithToolsAsync()
|
||||
{
|
||||
var location = "us-central1";
|
||||
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ?? throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
|
||||
var model = "gemini-1.5-flash-001";
|
||||
var tools = new Tool[]
|
||||
{
|
||||
new Tool
|
||||
{
|
||||
FunctionDeclarations = { functions.GetWeatherAsyncFunctionContract.ToFunctionDeclaration() },
|
||||
},
|
||||
};
|
||||
|
||||
var textContent = new Content
|
||||
{
|
||||
Role = "user",
|
||||
Parts =
|
||||
{
|
||||
new Part
|
||||
{
|
||||
Text = "what's the weather in seattle",
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
var agent = new GeminiChatAgent(
|
||||
name: "assistant",
|
||||
model: model,
|
||||
project: project,
|
||||
location: location,
|
||||
systemMessage: "You are a helpful AI assistant",
|
||||
tools: tools,
|
||||
toolConfig: new ToolConfig()
|
||||
{
|
||||
FunctionCallingConfig = new FunctionCallingConfig()
|
||||
{
|
||||
Mode = FunctionCallingConfig.Types.Mode.Auto,
|
||||
}
|
||||
});
|
||||
|
||||
var message = MessageEnvelope.Create(textContent, from: agent.Name);
|
||||
|
||||
var chunks = new List<IStreamingMessage>();
|
||||
IStreamingMessage finalReply = null!;
|
||||
|
||||
var completion = agent.GenerateStreamingReplyAsync([message]);
|
||||
|
||||
await foreach (var item in completion)
|
||||
{
|
||||
item.Should().NotBeNull();
|
||||
item.From.Should().Be(agent.Name);
|
||||
var streamingMessage = (IMessage<GenerateContentResponse>)item;
|
||||
streamingMessage.Content.Candidates.Should().NotBeNullOrEmpty();
|
||||
if (streamingMessage.Content.Candidates[0].FinishReason != Candidate.Types.FinishReason.Stop)
|
||||
{
|
||||
streamingMessage.Content.Candidates[0].Content.Parts[0].DataCase.Should().Be(DataOneofCase.FunctionCall);
|
||||
}
|
||||
chunks.Add(item);
|
||||
finalReply = item;
|
||||
}
|
||||
|
||||
chunks.Count.Should().BeGreaterThan(0);
|
||||
finalReply.Should().NotBeNull();
|
||||
finalReply.Should().BeOfType<MessageEnvelope<GenerateContentResponse>>();
|
||||
var response = ((MessageEnvelope<GenerateContentResponse>)finalReply).Content;
|
||||
response.UsageMetadata.CandidatesTokenCount.Should().BeGreaterThan(0);
|
||||
}
|
||||
|
||||
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
|
||||
public async Task GeminiAgentUpperCaseTestAsync()
|
||||
{
|
||||
var location = "us-central1";
|
||||
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ?? throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
|
||||
var model = "gemini-1.5-flash-001";
|
||||
var agent = new GeminiChatAgent(
|
||||
name: "assistant",
|
||||
model: model,
|
||||
project: project,
|
||||
location: location)
|
||||
.RegisterMessageConnector();
|
||||
|
||||
var singleAgentTest = new SingleAgentTest(_output);
|
||||
await singleAgentTest.UpperCaseStreamingTestAsync(agent);
|
||||
await singleAgentTest.UpperCaseTestAsync(agent);
|
||||
}
|
||||
|
||||
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
|
||||
public async Task GeminiAgentEchoFunctionCallTestAsync()
|
||||
{
|
||||
var location = "us-central1";
|
||||
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ?? throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
|
||||
var model = "gemini-1.5-flash-001";
|
||||
var singleAgentTest = new SingleAgentTest(_output);
|
||||
var echoFunctionContract = singleAgentTest.EchoAsyncFunctionContract;
|
||||
var agent = new GeminiChatAgent(
|
||||
name: "assistant",
|
||||
model: model,
|
||||
project: project,
|
||||
location: location,
|
||||
tools:
|
||||
[
|
||||
new Tool
|
||||
{
|
||||
FunctionDeclarations = { echoFunctionContract.ToFunctionDeclaration() },
|
||||
},
|
||||
])
|
||||
.RegisterMessageConnector();
|
||||
|
||||
await singleAgentTest.EchoFunctionCallTestAsync(agent);
|
||||
}
|
||||
|
||||
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
|
||||
public async Task GeminiAgentEchoFunctionCallExecutionTestAsync()
|
||||
{
|
||||
var location = "us-central1";
|
||||
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID") ?? throw new InvalidOperationException("GCP_VERTEX_PROJECT_ID is not set.");
|
||||
var model = "gemini-1.5-flash-001";
|
||||
var singleAgentTest = new SingleAgentTest(_output);
|
||||
var echoFunctionContract = singleAgentTest.EchoAsyncFunctionContract;
|
||||
var functionMiddleware = new FunctionCallMiddleware(
|
||||
functions: [echoFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>()
|
||||
{
|
||||
{ echoFunctionContract.Name!, singleAgentTest.EchoAsyncWrapper },
|
||||
});
|
||||
|
||||
var agent = new GeminiChatAgent(
|
||||
name: "assistant",
|
||||
model: model,
|
||||
project: project,
|
||||
location: location)
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(functionMiddleware);
|
||||
|
||||
await singleAgentTest.EchoFunctionCallExecutionStreamingTestAsync(agent);
|
||||
await singleAgentTest.EchoFunctionCallExecutionTestAsync(agent);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,380 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GeminiMessageTests.cs
|
||||
|
||||
using AutoGen.Core;
|
||||
using AutoGen.Gemini.Middleware;
|
||||
using AutoGen.Tests;
|
||||
using FluentAssertions;
|
||||
using Google.Cloud.AIPlatform.V1;
|
||||
using Xunit;
|
||||
|
||||
namespace AutoGen.Gemini.Tests;
|
||||
|
||||
public class GeminiMessageTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task ItProcessUserTextMessageAsync()
|
||||
{
|
||||
var messageConnector = new GeminiMessageConnector();
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
|
||||
{
|
||||
msgs.Count().Should().Be(1);
|
||||
var innerMessage = msgs.First();
|
||||
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
|
||||
var message = (IMessage<Content>)innerMessage;
|
||||
message.Content.Parts.Count.Should().Be(1);
|
||||
message.Content.Role.Should().Be("user");
|
||||
return await innerAgent.GenerateReplyAsync(msgs);
|
||||
})
|
||||
.RegisterMiddleware(messageConnector);
|
||||
|
||||
// when from is null and role is user
|
||||
await agent.SendAsync("Hello");
|
||||
|
||||
// when from is user and role is user
|
||||
var userMessage = new TextMessage(Role.User, "Hello", from: "user");
|
||||
await agent.SendAsync(userMessage);
|
||||
|
||||
// when from is user but role is assistant
|
||||
userMessage = new TextMessage(Role.Assistant, "Hello", from: "user");
|
||||
await agent.SendAsync(userMessage);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItProcessAssistantTextMessageAsync()
|
||||
{
|
||||
var messageConnector = new GeminiMessageConnector();
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
|
||||
{
|
||||
msgs.Count().Should().Be(1);
|
||||
var innerMessage = msgs.First();
|
||||
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
|
||||
var message = (IMessage<Content>)innerMessage;
|
||||
message.Content.Parts.Count.Should().Be(1);
|
||||
message.Content.Role.Should().Be("model");
|
||||
return await innerAgent.GenerateReplyAsync(msgs);
|
||||
})
|
||||
.RegisterMiddleware(messageConnector);
|
||||
|
||||
// when from is user and role is assistant
|
||||
var message = new TextMessage(Role.User, "Hello", from: agent.Name);
|
||||
await agent.SendAsync(message);
|
||||
|
||||
// when from is assistant and role is assistant
|
||||
message = new TextMessage(Role.Assistant, "Hello", from: agent.Name);
|
||||
await agent.SendAsync(message);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItProcessSystemTextMessageAsUserMessageWhenStrictModeIsFalseAsync()
|
||||
{
|
||||
var messageConnector = new GeminiMessageConnector();
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
|
||||
{
|
||||
msgs.Count().Should().Be(1);
|
||||
var innerMessage = msgs.First();
|
||||
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
|
||||
var message = (IMessage<Content>)innerMessage;
|
||||
message.Content.Parts.Count.Should().Be(1);
|
||||
message.Content.Role.Should().Be("user");
|
||||
return await innerAgent.GenerateReplyAsync(msgs);
|
||||
})
|
||||
.RegisterMiddleware(messageConnector);
|
||||
|
||||
var message = new TextMessage(Role.System, "Hello", from: agent.Name);
|
||||
await agent.SendAsync(message);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItThrowExceptionOnSystemMessageWhenStrictModeIsTrueAsync()
|
||||
{
|
||||
var messageConnector = new GeminiMessageConnector(true);
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(messageConnector);
|
||||
|
||||
var message = new TextMessage(Role.System, "Hello", from: agent.Name);
|
||||
var action = new Func<Task>(async () => await agent.SendAsync(message));
|
||||
await action.Should().ThrowAsync<InvalidOperationException>();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItProcessUserImageMessageAsInlineDataAsync()
|
||||
{
|
||||
var messageConnector = new GeminiMessageConnector();
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
|
||||
{
|
||||
msgs.Count().Should().Be(1);
|
||||
var innerMessage = msgs.First();
|
||||
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
|
||||
var message = (IMessage<Content>)innerMessage;
|
||||
message.Content.Parts.Count.Should().Be(1);
|
||||
message.Content.Role.Should().Be("user");
|
||||
message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.InlineData);
|
||||
return await innerAgent.GenerateReplyAsync(msgs);
|
||||
})
|
||||
.RegisterMiddleware(messageConnector);
|
||||
|
||||
var imagePath = Path.Combine("testData", "images", "background.png");
|
||||
var image = File.ReadAllBytes(imagePath);
|
||||
var message = new ImageMessage(Role.User, BinaryData.FromBytes(image, "image/png"));
|
||||
message.MimeType.Should().Be("image/png");
|
||||
|
||||
await agent.SendAsync(message);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItProcessUserImageMessageAsFileDataAsync()
|
||||
{
|
||||
var messageConnector = new GeminiMessageConnector();
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
|
||||
{
|
||||
msgs.Count().Should().Be(1);
|
||||
var innerMessage = msgs.First();
|
||||
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
|
||||
var message = (IMessage<Content>)innerMessage;
|
||||
message.Content.Parts.Count.Should().Be(1);
|
||||
message.Content.Role.Should().Be("user");
|
||||
message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.FileData);
|
||||
return await innerAgent.GenerateReplyAsync(msgs);
|
||||
})
|
||||
.RegisterMiddleware(messageConnector);
|
||||
|
||||
var imagePath = Path.Combine("testData", "images", "image.png");
|
||||
var url = new Uri(Path.GetFullPath(imagePath)).AbsoluteUri;
|
||||
var message = new ImageMessage(Role.User, url);
|
||||
message.MimeType.Should().Be("image/png");
|
||||
|
||||
await agent.SendAsync(message);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItProcessMultiModalMessageAsync()
|
||||
{
|
||||
var messageConnector = new GeminiMessageConnector();
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
|
||||
{
|
||||
msgs.Count().Should().Be(1);
|
||||
var innerMessage = msgs.First();
|
||||
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
|
||||
var message = (IMessage<Content>)innerMessage;
|
||||
message.Content.Parts.Count.Should().Be(2);
|
||||
message.Content.Role.Should().Be("user");
|
||||
message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.Text);
|
||||
message.Content.Parts.Last().DataCase.Should().Be(Part.DataOneofCase.FileData);
|
||||
return await innerAgent.GenerateReplyAsync(msgs);
|
||||
})
|
||||
.RegisterMiddleware(messageConnector);
|
||||
|
||||
var imagePath = Path.Combine("testData", "images", "image.png");
|
||||
var url = new Uri(Path.GetFullPath(imagePath)).AbsoluteUri;
|
||||
var message = new ImageMessage(Role.User, url);
|
||||
message.MimeType.Should().Be("image/png");
|
||||
var textMessage = new TextMessage(Role.User, "What's in this image?");
|
||||
var multiModalMessage = new MultiModalMessage(Role.User, [textMessage, message]);
|
||||
|
||||
await agent.SendAsync(multiModalMessage);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItProcessToolCallMessageAsync()
|
||||
{
|
||||
var messageConnector = new GeminiMessageConnector();
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
|
||||
{
|
||||
msgs.Count().Should().Be(1);
|
||||
var innerMessage = msgs.First();
|
||||
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
|
||||
var message = (IMessage<Content>)innerMessage;
|
||||
message.Content.Role.Should().Be("model");
|
||||
message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.FunctionCall);
|
||||
return await innerAgent.GenerateReplyAsync(msgs);
|
||||
})
|
||||
.RegisterMiddleware(messageConnector);
|
||||
|
||||
var toolCallMessage = new ToolCallMessage("test", "{}", "user");
|
||||
await agent.SendAsync(toolCallMessage);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItProcessStreamingTextMessageAsync()
|
||||
{
|
||||
var messageConnector = new GeminiMessageConnector();
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterStreamingMiddleware(messageConnector);
|
||||
|
||||
var messageChunks = Enumerable.Range(0, 10)
|
||||
.Select(i => new GenerateContentResponse()
|
||||
{
|
||||
Candidates =
|
||||
{
|
||||
new Candidate()
|
||||
{
|
||||
Content = new Content()
|
||||
{
|
||||
Role = "user",
|
||||
Parts = { new Part { Text = i.ToString() } },
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.Select(m => MessageEnvelope.Create(m));
|
||||
|
||||
IStreamingMessage? finalReply = null;
|
||||
await foreach (var reply in agent.GenerateStreamingReplyAsync(messageChunks))
|
||||
{
|
||||
reply.Should().BeAssignableTo<IStreamingMessage>();
|
||||
finalReply = reply;
|
||||
}
|
||||
|
||||
finalReply.Should().BeOfType<TextMessage>();
|
||||
var textMessage = (TextMessage)finalReply!;
|
||||
textMessage.GetContent().Should().Be("0123456789");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItProcessToolCallResultMessageAsync()
|
||||
{
|
||||
var messageConnector = new GeminiMessageConnector();
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
|
||||
{
|
||||
msgs.Count().Should().Be(1);
|
||||
var innerMessage = msgs.First();
|
||||
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
|
||||
var message = (IMessage<Content>)innerMessage;
|
||||
message.Content.Role.Should().Be("function");
|
||||
message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.FunctionResponse);
|
||||
message.Content.Parts.First().FunctionResponse.Response.ToString().Should().Be("{ \"result\": \"result\" }");
|
||||
return await innerAgent.GenerateReplyAsync(msgs);
|
||||
})
|
||||
.RegisterMiddleware(messageConnector);
|
||||
|
||||
var message = new ToolCallResultMessage("result", "test", "{}", "user");
|
||||
await agent.SendAsync(message);
|
||||
|
||||
// when the result is already a json object string
|
||||
message = new ToolCallResultMessage("{ \"result\": \"result\" }", "test", "{}", "user");
|
||||
await agent.SendAsync(message);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItProcessToolCallAggregateMessageAsTextContentAsync()
|
||||
{
|
||||
var messageConnector = new GeminiMessageConnector();
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
|
||||
{
|
||||
msgs.Count().Should().Be(1);
|
||||
var innerMessage = msgs.First();
|
||||
innerMessage.Should().BeOfType<MessageEnvelope<Content>>();
|
||||
var message = (IMessage<Content>)innerMessage;
|
||||
message.Content.Role.Should().Be("user");
|
||||
message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.Text);
|
||||
return await innerAgent.GenerateReplyAsync(msgs);
|
||||
})
|
||||
.RegisterMiddleware(messageConnector);
|
||||
var toolCallMessage = new ToolCallMessage("test", "{}", "user");
|
||||
var toolCallResultMessage = new ToolCallResultMessage("result", "test", "{}", "user");
|
||||
var message = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage, from: "user");
|
||||
await agent.SendAsync(message);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItProcessToolCallAggregateMessageAsFunctionContentAsync()
|
||||
{
|
||||
var messageConnector = new GeminiMessageConnector();
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
|
||||
{
|
||||
msgs.Count().Should().Be(2);
|
||||
var functionCallMessage = msgs.First();
|
||||
functionCallMessage.Should().BeOfType<MessageEnvelope<Content>>();
|
||||
var message = (IMessage<Content>)functionCallMessage;
|
||||
message.Content.Role.Should().Be("model");
|
||||
message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.FunctionCall);
|
||||
|
||||
var functionResultMessage = msgs.Last();
|
||||
functionResultMessage.Should().BeOfType<MessageEnvelope<Content>>();
|
||||
message = (IMessage<Content>)functionResultMessage;
|
||||
message.Content.Role.Should().Be("function");
|
||||
message.Content.Parts.First().DataCase.Should().Be(Part.DataOneofCase.FunctionResponse);
|
||||
|
||||
return await innerAgent.GenerateReplyAsync(msgs);
|
||||
})
|
||||
.RegisterMiddleware(messageConnector);
|
||||
var toolCallMessage = new ToolCallMessage("test", "{}", agent.Name);
|
||||
var toolCallResultMessage = new ToolCallResultMessage("result", "test", "{}", agent.Name);
|
||||
var message = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage, from: agent.Name);
|
||||
await agent.SendAsync(message);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItThrowExceptionWhenProcessingUnknownMessageTypeInStrictModeAsync()
|
||||
{
|
||||
var messageConnector = new GeminiMessageConnector(true);
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(messageConnector);
|
||||
|
||||
var unknownMessage = new
|
||||
{
|
||||
text = "Hello",
|
||||
};
|
||||
|
||||
var message = MessageEnvelope.Create(unknownMessage, from: agent.Name);
|
||||
var action = new Func<Task>(async () => await agent.SendAsync(message));
|
||||
|
||||
await action.Should().ThrowAsync<InvalidOperationException>();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItReturnUnknownMessageTypeInNonStrictModeAsync()
|
||||
{
|
||||
var messageConnector = new GeminiMessageConnector();
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
|
||||
{
|
||||
var message = msgs.First();
|
||||
message.Should().BeAssignableTo<IMessage>();
|
||||
return message;
|
||||
})
|
||||
.RegisterMiddleware(messageConnector);
|
||||
|
||||
var unknownMessage = new
|
||||
{
|
||||
text = "Hello",
|
||||
};
|
||||
|
||||
var message = MessageEnvelope.Create(unknownMessage, from: agent.Name);
|
||||
await agent.SendAsync(message);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItShortcircuitContentTypeAsync()
|
||||
{
|
||||
var messageConnector = new GeminiMessageConnector();
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
|
||||
{
|
||||
var message = msgs.First();
|
||||
message.Should().BeOfType<MessageEnvelope<Content>>();
|
||||
|
||||
return message;
|
||||
})
|
||||
.RegisterMiddleware(messageConnector);
|
||||
|
||||
var message = new Content()
|
||||
{
|
||||
Parts = { new Part { Text = "Hello" } },
|
||||
Role = "user",
|
||||
};
|
||||
|
||||
await agent.SendAsync(MessageEnvelope.Create(message, from: agent.Name));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,132 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GoogleGeminiClientTests.cs
|
||||
|
||||
using AutoGen.Tests;
|
||||
using FluentAssertions;
|
||||
using Google.Cloud.AIPlatform.V1;
|
||||
using Google.Protobuf;
|
||||
using static Google.Cloud.AIPlatform.V1.Candidate.Types;
|
||||
|
||||
namespace AutoGen.Gemini.Tests;
|
||||
|
||||
public class GoogleGeminiClientTests
|
||||
{
|
||||
[ApiKeyFact("GOOGLE_GEMINI_API_KEY")]
|
||||
public async Task ItGenerateContentAsync()
|
||||
{
|
||||
var apiKey = Environment.GetEnvironmentVariable("GOOGLE_GEMINI_API_KEY") ?? throw new InvalidOperationException("GOOGLE_GEMINI_API_KEY is not set");
|
||||
var client = new GoogleGeminiClient(apiKey);
|
||||
var model = "gemini-1.5-flash-001";
|
||||
|
||||
var text = "Write a long, tedious story";
|
||||
var request = new GenerateContentRequest
|
||||
{
|
||||
Model = model,
|
||||
Contents =
|
||||
{
|
||||
new Content
|
||||
{
|
||||
Role = "user",
|
||||
Parts =
|
||||
{
|
||||
new Part
|
||||
{
|
||||
Text = text,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
var completion = await client.GenerateContentAsync(request);
|
||||
|
||||
completion.Should().NotBeNull();
|
||||
completion.Candidates.Count.Should().BeGreaterThan(0);
|
||||
completion.Candidates[0].Content.Parts[0].Text.Should().NotBeNullOrEmpty();
|
||||
}
|
||||
|
||||
[ApiKeyFact("GOOGLE_GEMINI_API_KEY")]
|
||||
public async Task ItGenerateContentWithImageAsync()
|
||||
{
|
||||
var apiKey = Environment.GetEnvironmentVariable("GOOGLE_GEMINI_API_KEY") ?? throw new InvalidOperationException("GOOGLE_GEMINI_API_KEY is not set");
|
||||
var client = new GoogleGeminiClient(apiKey);
|
||||
var model = "gemini-1.5-flash-001";
|
||||
|
||||
var text = "what's in the image";
|
||||
var imagePath = Path.Combine("testData", "images", "background.png");
|
||||
var image = File.ReadAllBytes(imagePath);
|
||||
var request = new GenerateContentRequest
|
||||
{
|
||||
Model = model,
|
||||
Contents =
|
||||
{
|
||||
new Content
|
||||
{
|
||||
Role = "user",
|
||||
Parts =
|
||||
{
|
||||
new Part
|
||||
{
|
||||
Text = text,
|
||||
},
|
||||
new Part
|
||||
{
|
||||
InlineData = new ()
|
||||
{
|
||||
MimeType = "image/png",
|
||||
Data = ByteString.CopyFrom(image),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
var completion = await client.GenerateContentAsync(request);
|
||||
completion.Should().NotBeNull();
|
||||
completion.Candidates.Count.Should().BeGreaterThan(0);
|
||||
completion.Candidates[0].Content.Parts[0].Text.Should().NotBeNullOrEmpty();
|
||||
}
|
||||
|
||||
[ApiKeyFact("GOOGLE_GEMINI_API_KEY")]
|
||||
public async Task ItStreamingGenerateContentTestAsync()
|
||||
{
|
||||
var apiKey = Environment.GetEnvironmentVariable("GOOGLE_GEMINI_API_KEY") ?? throw new InvalidOperationException("GOOGLE_GEMINI_API_KEY is not set");
|
||||
var client = new GoogleGeminiClient(apiKey);
|
||||
var model = "gemini-1.5-flash-001";
|
||||
|
||||
var text = "Tell me a long tedious joke";
|
||||
var request = new GenerateContentRequest
|
||||
{
|
||||
Model = model,
|
||||
Contents =
|
||||
{
|
||||
new Content
|
||||
{
|
||||
Role = "user",
|
||||
Parts =
|
||||
{
|
||||
new Part
|
||||
{
|
||||
Text = text,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
var response = client.GenerateContentStreamAsync(request);
|
||||
var chunks = new List<GenerateContentResponse>();
|
||||
GenerateContentResponse? final = null;
|
||||
await foreach (var item in response)
|
||||
{
|
||||
item.Candidates.Count.Should().BeGreaterThan(0);
|
||||
final = item;
|
||||
chunks.Add(final);
|
||||
}
|
||||
|
||||
chunks.Should().NotBeEmpty();
|
||||
final.Should().NotBeNull();
|
||||
final!.UsageMetadata.CandidatesTokenCount.Should().BeGreaterThan(0);
|
||||
final!.Candidates[0].FinishReason.Should().Be(FinishReason.Stop);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// SampleTests.cs
|
||||
|
||||
using AutoGen.Gemini.Sample;
|
||||
using AutoGen.Tests;
|
||||
|
||||
namespace AutoGen.Gemini.Tests;
|
||||
|
||||
public class SampleTests
|
||||
{
|
||||
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
|
||||
public async Task TestChatWithVertexGeminiAsync()
|
||||
{
|
||||
await Chat_With_Vertex_Gemini.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
|
||||
public async Task TestFunctionCallWithGeminiAsync()
|
||||
{
|
||||
await Function_Call_With_Gemini.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
|
||||
public async Task TestImageChatWithVertexGeminiAsync()
|
||||
{
|
||||
await Image_Chat_With_Vertex_Gemini.RunAsync();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,134 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GeminiVertexClientTests.cs
|
||||
|
||||
using AutoGen.Tests;
|
||||
using FluentAssertions;
|
||||
using Google.Cloud.AIPlatform.V1;
|
||||
using Google.Protobuf;
|
||||
using static Google.Cloud.AIPlatform.V1.Candidate.Types;
|
||||
namespace AutoGen.Gemini.Tests;
|
||||
|
||||
public class VertexGeminiClientTests
|
||||
{
|
||||
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
|
||||
public async Task ItGenerateContentAsync()
|
||||
{
|
||||
var location = "us-central1";
|
||||
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
|
||||
var client = new VertexGeminiClient(location);
|
||||
var model = "gemini-1.5-flash-001";
|
||||
|
||||
var text = "Hello";
|
||||
var request = new GenerateContentRequest
|
||||
{
|
||||
Model = $"projects/{project}/locations/{location}/publishers/google/models/{model}",
|
||||
Contents =
|
||||
{
|
||||
new Content
|
||||
{
|
||||
Role = "user",
|
||||
Parts =
|
||||
{
|
||||
new Part
|
||||
{
|
||||
Text = text,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
var completion = await client.GenerateContentAsync(request);
|
||||
|
||||
completion.Should().NotBeNull();
|
||||
completion.Candidates.Count.Should().BeGreaterThan(0);
|
||||
completion.Candidates[0].Content.Parts[0].Text.Should().NotBeNullOrEmpty();
|
||||
}
|
||||
|
||||
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
|
||||
public async Task ItGenerateContentWithImageAsync()
|
||||
{
|
||||
var location = "us-central1";
|
||||
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
|
||||
var client = new VertexGeminiClient(location);
|
||||
var model = "gemini-1.5-flash-001";
|
||||
|
||||
var text = "what's in the image";
|
||||
var imagePath = Path.Combine("testData", "images", "image.png");
|
||||
var image = File.ReadAllBytes(imagePath);
|
||||
var request = new GenerateContentRequest
|
||||
{
|
||||
Model = $"projects/{project}/locations/{location}/publishers/google/models/{model}",
|
||||
Contents =
|
||||
{
|
||||
new Content
|
||||
{
|
||||
Role = "user",
|
||||
Parts =
|
||||
{
|
||||
new Part
|
||||
{
|
||||
Text = text,
|
||||
},
|
||||
new Part
|
||||
{
|
||||
InlineData = new ()
|
||||
{
|
||||
MimeType = "image/png",
|
||||
Data = ByteString.CopyFrom(image),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
var completion = await client.GenerateContentAsync(request);
|
||||
completion.Should().NotBeNull();
|
||||
completion.Candidates.Count.Should().BeGreaterThan(0);
|
||||
completion.Candidates[0].Content.Parts[0].Text.Should().NotBeNullOrEmpty();
|
||||
}
|
||||
|
||||
[ApiKeyFact("GCP_VERTEX_PROJECT_ID")]
|
||||
public async Task ItStreamingGenerateContentTestAsync()
|
||||
{
|
||||
var location = "us-central1";
|
||||
var project = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
|
||||
var client = new VertexGeminiClient(location);
|
||||
var model = "gemini-1.5-flash-001";
|
||||
|
||||
var text = "Hello, write a long tedious joke";
|
||||
var request = new GenerateContentRequest
|
||||
{
|
||||
Model = $"projects/{project}/locations/{location}/publishers/google/models/{model}",
|
||||
Contents =
|
||||
{
|
||||
new Content
|
||||
{
|
||||
Role = "user",
|
||||
Parts =
|
||||
{
|
||||
new Part
|
||||
{
|
||||
Text = text,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
var response = client.GenerateContentStreamAsync(request);
|
||||
var chunks = new List<GenerateContentResponse>();
|
||||
GenerateContentResponse? final = null;
|
||||
await foreach (var item in response)
|
||||
{
|
||||
item.Candidates.Count.Should().BeGreaterThan(0);
|
||||
final = item;
|
||||
chunks.Add(final);
|
||||
}
|
||||
|
||||
chunks.Should().NotBeEmpty();
|
||||
final.Should().NotBeNull();
|
||||
final!.UsageMetadata.CandidatesTokenCount.Should().BeGreaterThan(0);
|
||||
final!.Candidates[0].FinishReason.Should().Be(FinishReason.Stop);
|
||||
}
|
||||
}
|
|
@ -4,18 +4,10 @@
|
|||
<TargetFramework>$(TestTargetFramework)</TargetFramework>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<IsPackable>false</IsPackable>
|
||||
<IsTestProject>True</IsTestProject>
|
||||
<GenerateDocumentationFile>True</GenerateDocumentationFile>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestVersion)" />
|
||||
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
|
||||
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\src\AutoGen.Mistral\AutoGen.Mistral.csproj" />
|
||||
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
|
||||
|
|
|
@ -4,18 +4,10 @@
|
|||
<TargetFramework>$(TestTargetFramework)</TargetFramework>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<IsPackable>false</IsPackable>
|
||||
<IsTestProject>True</IsTestProject>
|
||||
<GenerateDocumentationFile>True</GenerateDocumentationFile>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestVersion)" />
|
||||
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
|
||||
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\src\AutoGen.Ollama\AutoGen.Ollama.csproj" />
|
||||
<ProjectReference Include="..\AutoGen.Tests\AutoGen.Tests.csproj" />
|
||||
|
|
|
@ -3,18 +3,10 @@
|
|||
<PropertyGroup>
|
||||
<TargetFramework>$(TestTargetFramework)</TargetFramework>
|
||||
<IsPackable>false</IsPackable>
|
||||
<IsTestProject>True</IsTestProject>
|
||||
<GenerateDocumentationFile>True</GenerateDocumentationFile>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestVersion)" />
|
||||
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
|
||||
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
|
||||
<ProjectReference Include="..\..\src\AutoGen\AutoGen.csproj" />
|
||||
|
|
|
@ -5,18 +5,10 @@
|
|||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<IsPackable>false</IsPackable>
|
||||
<NoWarn>$(NoWarn);SKEXP0110</NoWarn>
|
||||
<IsTestProject>True</IsTestProject>
|
||||
<GenerateDocumentationFile>True</GenerateDocumentationFile>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestVersion)" />
|
||||
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
|
||||
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\src\AutoGen.SemanticKernel\AutoGen.SemanticKernel.csproj" />
|
||||
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
|
||||
|
|
|
@ -4,18 +4,10 @@
|
|||
<TargetFramework>$(TestTargetFramework)</TargetFramework>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<IsPackable>false</IsPackable>
|
||||
<IsTestProject>True</IsTestProject>
|
||||
<GenerateDocumentationFile>True</GenerateDocumentationFile>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestVersion)" />
|
||||
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
|
||||
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="true" />
|
||||
<ProjectReference Include="..\..\src\AutoGen\AutoGen.csproj" />
|
||||
|
|
|
@ -3,18 +3,10 @@
|
|||
<PropertyGroup>
|
||||
<TargetFramework>$(TestTargetFramework)</TargetFramework>
|
||||
<GenerateDocumentationFile>True</GenerateDocumentationFile>
|
||||
<IsTestProject>True</IsTestProject>
|
||||
<NoWarn>$(NoWarn);xUnit1013;SKEXP0110</NoWarn>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="ApprovalTests" Version="$(ApprovalTestVersion)" />
|
||||
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkVersion)" />
|
||||
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\sample\AutoGen.BasicSamples\AutoGen.BasicSample.csproj" />
|
||||
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ImageMessageTests.cs
|
||||
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.Threading.Tasks;
|
||||
using FluentAssertions;
|
||||
using Xunit;
|
||||
|
||||
namespace AutoGen.Tests;
|
||||
|
||||
public class ImageMessageTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task ItCreateFromLocalImage()
|
||||
{
|
||||
var image = Path.Combine("testData", "images", "background.png");
|
||||
var binary = File.ReadAllBytes(image);
|
||||
var base64 = Convert.ToBase64String(binary);
|
||||
var imageMessage = new ImageMessage(Role.User, BinaryData.FromBytes(binary, "image/png"));
|
||||
|
||||
imageMessage.MimeType.Should().Be("image/png");
|
||||
imageMessage.BuildDataUri().Should().Be($"data:image/png;base64,{base64}");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItCreateFromUrl()
|
||||
{
|
||||
var image = Path.Combine("testData", "images", "background.png");
|
||||
var fullPath = Path.GetFullPath(image);
|
||||
var localUrl = new Uri(fullPath).AbsoluteUri;
|
||||
var imageMessage = new ImageMessage(Role.User, localUrl);
|
||||
|
||||
imageMessage.Url.Should().Be(localUrl);
|
||||
imageMessage.MimeType.Should().Be("image/png");
|
||||
imageMessage.Data.Should().BeNull();
|
||||
}
|
||||
}
|
|
@ -266,10 +266,10 @@ namespace AutoGen.Tests
|
|||
|
||||
public async Task EchoFunctionCallTestAsync(IAgent agent)
|
||||
{
|
||||
var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function");
|
||||
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function");
|
||||
var helloWorld = new TextMessage(Role.User, "echo Hello world");
|
||||
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { message, helloWorld });
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
|
||||
|
||||
reply.From.Should().Be(agent.Name);
|
||||
reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync));
|
||||
|
@ -277,10 +277,10 @@ namespace AutoGen.Tests
|
|||
|
||||
public async Task EchoFunctionCallExecutionTestAsync(IAgent agent)
|
||||
{
|
||||
var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
|
||||
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
|
||||
var helloWorld = new TextMessage(Role.User, "echo Hello world");
|
||||
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { message, helloWorld });
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
|
||||
|
||||
reply.GetContent().Should().Be("[ECHO] Hello world");
|
||||
reply.From.Should().Be(agent.Name);
|
||||
|
@ -289,13 +289,13 @@ namespace AutoGen.Tests
|
|||
|
||||
public async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent)
|
||||
{
|
||||
var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
|
||||
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
|
||||
var helloWorld = new TextMessage(Role.User, "echo Hello world");
|
||||
var option = new GenerateReplyOptions
|
||||
{
|
||||
Temperature = 0,
|
||||
};
|
||||
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option);
|
||||
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { helloWorld }, option);
|
||||
var answer = "[ECHO] Hello world";
|
||||
IStreamingMessage? finalReply = default;
|
||||
await foreach (var reply in replyStream)
|
||||
|
@ -319,25 +319,23 @@ namespace AutoGen.Tests
|
|||
|
||||
public async Task UpperCaseTestAsync(IAgent agent)
|
||||
{
|
||||
var message = new TextMessage(Role.System, "You are a helpful AI assistant that convert user message to upper case");
|
||||
var uppCaseMessage = new TextMessage(Role.User, "abcdefg");
|
||||
var message = new TextMessage(Role.User, "Please convert abcde to upper case.");
|
||||
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { message, uppCaseMessage });
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { message });
|
||||
|
||||
reply.GetContent().Should().Contain("ABCDEFG");
|
||||
reply.GetContent().Should().Contain("ABCDE");
|
||||
reply.From.Should().Be(agent.Name);
|
||||
}
|
||||
|
||||
public async Task UpperCaseStreamingTestAsync(IStreamingAgent agent)
|
||||
{
|
||||
var message = new TextMessage(Role.System, "You are a helpful AI assistant that convert user message to upper case");
|
||||
var helloWorld = new TextMessage(Role.User, "a b c d e f g h i j k l m n");
|
||||
var message = new TextMessage(Role.User, "Please convert 'hello world' to upper case");
|
||||
var option = new GenerateReplyOptions
|
||||
{
|
||||
Temperature = 0,
|
||||
};
|
||||
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option);
|
||||
var answer = "A B C D E F G H I J K L M N";
|
||||
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message }, option);
|
||||
var answer = "HELLO WORLD";
|
||||
TextMessage? finalReply = default;
|
||||
await foreach (var reply in replyStream)
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue