diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs new file mode 100644 index 0000000000..e64710e46e --- /dev/null +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -0,0 +1,594 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Data.Common; +using System.Text.Json; +using Azure.DataApiBuilder.Auth; +using Azure.DataApiBuilder.Config.DatabasePrimitives; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Authorization; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Models; +using Azure.DataApiBuilder.Core.Parsers; +using Azure.DataApiBuilder.Core.Resolvers; +using Azure.DataApiBuilder.Core.Resolvers.Factories; +using Azure.DataApiBuilder.Core.Services; +using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Mcp.Utils; +using Azure.DataApiBuilder.Service.Exceptions; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; +using static Azure.DataApiBuilder.Mcp.Model.McpEnums; + +namespace Azure.DataApiBuilder.Mcp.BuiltInTools +{ + /// + /// Tool to aggregate records from a table/view entity configured in DAB. + /// Supports count, avg, sum, min, max with optional distinct, filter, groupby, having, orderby. + /// + public class AggregateRecordsTool : IMcpTool + { + public ToolType ToolType { get; } = ToolType.BuiltIn; + + private static readonly HashSet ValidFunctions = new(StringComparer.OrdinalIgnoreCase) { "count", "avg", "sum", "min", "max" }; + + public Tool GetToolMetadata() + { + return new Tool + { + Name = "aggregate_records", + Description = "STEP 1: describe_entities -> find entities with READ permission and their fields. STEP 2: call this tool to compute aggregations (count, avg, sum, min, max) with optional filter, groupby, having, and orderby.", + InputSchema = JsonSerializer.Deserialize( + @"{ + ""type"": ""object"", + ""properties"": { + ""entity"": { + ""type"": ""string"", + ""description"": ""Entity name with READ permission."" + }, + ""function"": { + ""type"": ""string"", + ""enum"": [""count"", ""avg"", ""sum"", ""min"", ""max""], + ""description"": ""Aggregation function to apply."" + }, + ""field"": { + ""type"": ""string"", + ""description"": ""Field to aggregate. Use '*' for count."" + }, + ""distinct"": { + ""type"": ""boolean"", + ""description"": ""Apply DISTINCT before aggregating."", + ""default"": false + }, + ""filter"": { + ""type"": ""string"", + ""description"": ""OData filter applied before aggregating (WHERE). Example: 'unitPrice lt 10'"", + ""default"": """" + }, + ""groupby"": { + ""type"": ""array"", + ""items"": { ""type"": ""string"" }, + ""description"": ""Fields to group by, e.g., ['category', 'region']. Grouped field values are included in the response."", + ""default"": [] + }, + ""orderby"": { + ""type"": ""string"", + ""enum"": [""asc"", ""desc""], + ""description"": ""Sort aggregated results by the computed value. Only applies with groupby."", + ""default"": ""desc"" + }, + ""having"": { + ""type"": ""object"", + ""description"": ""Filter applied after aggregating on the result (HAVING). Operators are AND-ed together."", + ""properties"": { + ""eq"": { ""type"": ""number"", ""description"": ""Aggregated value equals."" }, + ""neq"": { ""type"": ""number"", ""description"": ""Aggregated value not equals."" }, + ""gt"": { ""type"": ""number"", ""description"": ""Aggregated value greater than."" }, + ""gte"": { ""type"": ""number"", ""description"": ""Aggregated value greater than or equal."" }, + ""lt"": { ""type"": ""number"", ""description"": ""Aggregated value less than."" }, + ""lte"": { ""type"": ""number"", ""description"": ""Aggregated value less than or equal."" }, + ""in"": { + ""type"": ""array"", + ""items"": { ""type"": ""number"" }, + ""description"": ""Aggregated value is in the given list."" + } + } + } + }, + ""required"": [""entity"", ""function"", ""field""] + }" + ) + }; + } + + public async Task ExecuteAsync( + JsonDocument? arguments, + IServiceProvider serviceProvider, + CancellationToken cancellationToken = default) + { + ILogger? logger = serviceProvider.GetService>(); + string toolName = GetToolMetadata().Name; + + RuntimeConfigProvider runtimeConfigProvider = serviceProvider.GetRequiredService(); + RuntimeConfig runtimeConfig = runtimeConfigProvider.GetConfig(); + + if (runtimeConfig.McpDmlTools?.AggregateRecords is not true) + { + return McpErrorHelpers.ToolDisabled(toolName, logger); + } + + try + { + cancellationToken.ThrowIfCancellationRequested(); + + if (arguments == null) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger); + } + + JsonElement root = arguments.RootElement; + + // Parse required arguments + if (!McpArgumentParser.TryParseEntity(root, out string entityName, out string parseError)) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); + } + + if (runtimeConfig.Entities?.TryGetValue(entityName, out Entity? entity) == true && + entity.Mcp?.DmlToolEnabled == false) + { + return McpErrorHelpers.ToolDisabled(toolName, logger, $"DML tools are disabled for entity '{entityName}'."); + } + + if (!root.TryGetProperty("function", out JsonElement funcEl) || string.IsNullOrWhiteSpace(funcEl.GetString())) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Missing required argument 'function'.", logger); + } + + string function = funcEl.GetString()!.ToLowerInvariant(); + if (!ValidFunctions.Contains(function)) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", $"Invalid function '{function}'. Must be one of: count, avg, sum, min, max.", logger); + } + + if (!root.TryGetProperty("field", out JsonElement fieldEl) || string.IsNullOrWhiteSpace(fieldEl.GetString())) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Missing required argument 'field'.", logger); + } + + string field = fieldEl.GetString()!; + bool distinct = root.TryGetProperty("distinct", out JsonElement distinctEl) && distinctEl.GetBoolean(); + string? filter = root.TryGetProperty("filter", out JsonElement filterEl) ? filterEl.GetString() : null; + string orderby = root.TryGetProperty("orderby", out JsonElement orderbyEl) ? (orderbyEl.GetString() ?? "desc") : "desc"; + + List groupby = new(); + if (root.TryGetProperty("groupby", out JsonElement groupbyEl) && groupbyEl.ValueKind == JsonValueKind.Array) + { + foreach (JsonElement g in groupbyEl.EnumerateArray()) + { + string? gVal = g.GetString(); + if (!string.IsNullOrWhiteSpace(gVal)) + { + groupby.Add(gVal); + } + } + } + + Dictionary? havingOps = null; + List? havingIn = null; + if (root.TryGetProperty("having", out JsonElement havingEl) && havingEl.ValueKind == JsonValueKind.Object) + { + havingOps = new Dictionary(StringComparer.OrdinalIgnoreCase); + foreach (JsonProperty prop in havingEl.EnumerateObject()) + { + if (prop.Name.Equals("in", StringComparison.OrdinalIgnoreCase) && prop.Value.ValueKind == JsonValueKind.Array) + { + havingIn = new List(); + foreach (JsonElement item in prop.Value.EnumerateArray()) + { + havingIn.Add(item.GetDouble()); + } + } + else if (prop.Value.ValueKind == JsonValueKind.Number) + { + havingOps[prop.Name] = prop.Value.GetDouble(); + } + } + } + + // Resolve metadata + if (!McpMetadataHelper.TryResolveMetadata( + entityName, + runtimeConfig, + serviceProvider, + out ISqlMetadataProvider sqlMetadataProvider, + out DatabaseObject dbObject, + out string dataSourceName, + out string metadataError)) + { + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger); + } + + // Authorization + IAuthorizationResolver authResolver = serviceProvider.GetRequiredService(); + IAuthorizationService authorizationService = serviceProvider.GetRequiredService(); + IHttpContextAccessor httpContextAccessor = serviceProvider.GetRequiredService(); + HttpContext? httpContext = httpContextAccessor.HttpContext; + + if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authResolver, out string roleCtxError)) + { + return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", roleCtxError, logger); + } + + if (!McpAuthorizationHelper.TryResolveAuthorizedRole( + httpContext!, + authResolver, + entityName, + EntityActionOperation.Read, + out string? effectiveRole, + out string readAuthError)) + { + string finalError = readAuthError.StartsWith("You do not have permission", StringComparison.OrdinalIgnoreCase) + ? $"You do not have permission to read records for entity '{entityName}'." + : readAuthError; + return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", finalError, logger); + } + + // Build select list: groupby fields + aggregation field + List selectFields = new(groupby); + bool isCountStar = function == "count" && field == "*"; + if (!isCountStar && !selectFields.Contains(field, StringComparer.OrdinalIgnoreCase)) + { + selectFields.Add(field); + } + + // Build and validate Find context + RequestValidator requestValidator = new(serviceProvider.GetRequiredService(), runtimeConfigProvider); + FindRequestContext context = new(entityName, dbObject, true); + httpContext!.Request.Method = "GET"; + + requestValidator.ValidateEntity(entityName); + + if (selectFields.Count > 0) + { + context.UpdateReturnFields(selectFields); + } + + if (!string.IsNullOrWhiteSpace(filter)) + { + string filterQueryString = $"?{RequestParser.FILTER_URL}={filter}"; + context.FilterClauseInUrl = sqlMetadataProvider.GetODataParser().GetFilterClause(filterQueryString, $"{context.EntityName}.{context.DatabaseObject.FullName}"); + } + + requestValidator.ValidateRequestContext(context); + + AuthorizationResult authorizationResult = await authorizationService.AuthorizeAsync( + user: httpContext.User, + resource: context, + requirements: new[] { new ColumnsPermissionsRequirement() }); + if (!authorizationResult.Succeeded) + { + return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", DataApiBuilderException.AUTHORIZATION_FAILURE, logger); + } + + // Execute query to get records + IQueryEngineFactory queryEngineFactory = serviceProvider.GetRequiredService(); + IQueryEngine queryEngine = queryEngineFactory.GetQueryEngine(sqlMetadataProvider.GetDatabaseType()); + JsonDocument? queryResult = await queryEngine.ExecuteAsync(context); + + IActionResult actionResult = queryResult is null + ? SqlResponseHelpers.FormatFindResult(JsonDocument.Parse("[]").RootElement.Clone(), context, sqlMetadataProvider, runtimeConfig, httpContext, true) + : SqlResponseHelpers.FormatFindResult(queryResult.RootElement.Clone(), context, sqlMetadataProvider, runtimeConfig, httpContext, true); + + string rawPayloadJson = McpResponseBuilder.ExtractResultJson(actionResult); + using JsonDocument resultDoc = JsonDocument.Parse(rawPayloadJson); + JsonElement resultRoot = resultDoc.RootElement; + + // Extract the records array from the response + JsonElement records; + if (resultRoot.TryGetProperty("value", out JsonElement valueArray)) + { + records = valueArray; + } + else if (resultRoot.ValueKind == JsonValueKind.Array) + { + records = resultRoot; + } + else + { + records = resultRoot; + } + + // Compute alias for the response + string alias = ComputeAlias(function, field); + + // Perform in-memory aggregation + List> aggregatedResults = PerformAggregation( + records, function, field, distinct, groupby, havingOps, havingIn, orderby, alias); + + return McpResponseBuilder.BuildSuccessResult( + new Dictionary + { + ["entity"] = entityName, + ["result"] = aggregatedResults, + ["message"] = $"Successfully aggregated records for entity '{entityName}'" + }, + logger, + $"AggregateRecordsTool success for entity {entityName}."); + } + catch (OperationCanceledException) + { + return McpResponseBuilder.BuildErrorResult(toolName, "OperationCanceled", "The aggregate operation was canceled.", logger); + } + catch (DbException argEx) + { + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseOperationFailed", argEx.Message, logger); + } + catch (ArgumentException argEx) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", argEx.Message, logger); + } + catch (DataApiBuilderException argEx) + { + return McpResponseBuilder.BuildErrorResult(toolName, argEx.StatusCode.ToString(), argEx.Message, logger); + } + catch (Exception ex) + { + logger?.LogError(ex, "Unexpected error in AggregateRecordsTool."); + return McpResponseBuilder.BuildErrorResult(toolName, "UnexpectedError", "Unexpected error occurred in AggregateRecordsTool.", logger); + } + } + + /// + /// Computes the response alias for the aggregation result. + /// For count with "*", the alias is "count". Otherwise it's "{function}_{field}". + /// + internal static string ComputeAlias(string function, string field) + { + if (function == "count" && field == "*") + { + return "count"; + } + + return $"{function}_{field}"; + } + + /// + /// Performs in-memory aggregation over a JSON array of records. + /// + internal static List> PerformAggregation( + JsonElement records, + string function, + string field, + bool distinct, + List groupby, + Dictionary? havingOps, + List? havingIn, + string orderby, + string alias) + { + if (records.ValueKind != JsonValueKind.Array) + { + return new List> { new() { [alias] = null } }; + } + + bool isCountStar = function == "count" && field == "*"; + + if (groupby.Count == 0) + { + // No groupby - single result + List items = new(); + foreach (JsonElement record in records.EnumerateArray()) + { + items.Add(record); + } + + double? aggregatedValue = ComputeAggregateValue(items, function, field, distinct, isCountStar); + + // Apply having + if (!PassesHavingFilter(aggregatedValue, havingOps, havingIn)) + { + return new List>(); + } + + return new List> + { + new() { [alias] = aggregatedValue } + }; + } + else + { + // Group by + Dictionary> groups = new(); + Dictionary> groupKeys = new(); + + foreach (JsonElement record in records.EnumerateArray()) + { + string key = BuildGroupKey(record, groupby); + if (!groups.ContainsKey(key)) + { + groups[key] = new List(); + groupKeys[key] = ExtractGroupFields(record, groupby); + } + + groups[key].Add(record); + } + + List> results = new(); + foreach (KeyValuePair> group in groups) + { + double? aggregatedValue = ComputeAggregateValue(group.Value, function, field, distinct, isCountStar); + + if (!PassesHavingFilter(aggregatedValue, havingOps, havingIn)) + { + continue; + } + + Dictionary row = new(groupKeys[group.Key]) + { + [alias] = aggregatedValue + }; + results.Add(row); + } + + // Apply orderby + if (orderby.Equals("asc", StringComparison.OrdinalIgnoreCase)) + { + results.Sort((a, b) => CompareNullableDoubles(a[alias] as double?, b[alias] as double?)); + } + else + { + results.Sort((a, b) => CompareNullableDoubles(b[alias] as double?, a[alias] as double?)); + } + + return results; + } + } + + private static double? ComputeAggregateValue(List records, string function, string field, bool distinct, bool isCountStar) + { + if (isCountStar) + { + return distinct ? 0 : records.Count; + } + + List values = new(); + foreach (JsonElement record in records) + { + if (record.TryGetProperty(field, out JsonElement val) && val.ValueKind == JsonValueKind.Number) + { + values.Add(val.GetDouble()); + } + } + + if (distinct) + { + values = values.Distinct().ToList(); + } + + if (function == "count") + { + return values.Count; + } + + if (values.Count == 0) + { + return null; + } + + return function switch + { + "avg" => Math.Round(values.Average(), 2), + "sum" => values.Sum(), + "min" => values.Min(), + "max" => values.Max(), + _ => null + }; + } + + private static bool PassesHavingFilter(double? value, Dictionary? havingOps, List? havingIn) + { + if (havingOps == null && havingIn == null) + { + return true; + } + + if (value == null) + { + return false; + } + + double v = value.Value; + + if (havingOps != null) + { + foreach (KeyValuePair op in havingOps) + { + bool passes = op.Key.ToLowerInvariant() switch + { + "eq" => v == op.Value, + "neq" => v != op.Value, + "gt" => v > op.Value, + "gte" => v >= op.Value, + "lt" => v < op.Value, + "lte" => v <= op.Value, + _ => true + }; + + if (!passes) + { + return false; + } + } + } + + if (havingIn != null && !havingIn.Contains(v)) + { + return false; + } + + return true; + } + + private static string BuildGroupKey(JsonElement record, List groupby) + { + List parts = new(); + foreach (string g in groupby) + { + if (record.TryGetProperty(g, out JsonElement val)) + { + parts.Add(val.ToString()); + } + else + { + parts.Add("__null__"); + } + } + + return string.Join("|", parts); + } + + private static Dictionary ExtractGroupFields(JsonElement record, List groupby) + { + Dictionary result = new(); + foreach (string g in groupby) + { + if (record.TryGetProperty(g, out JsonElement val)) + { + result[g] = McpResponseBuilder.GetJsonValue(val); + } + else + { + result[g] = null; + } + } + + return result; + } + + private static int CompareNullableDoubles(double? a, double? b) + { + if (a == null && b == null) + { + return 0; + } + + if (a == null) + { + return -1; + } + + if (b == null) + { + return 1; + } + + return a.Value.CompareTo(b.Value); + } + } +} diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index 262cbc9145..ecd5ecd185 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -224,6 +224,9 @@ public ConfigureOptions( [Option("runtime.mcp.dml-tools.execute-entity.enabled", Required = false, HelpText = "Enable DAB's MCP execute entity tool. Default: true (boolean).")] public bool? RuntimeMcpDmlToolsExecuteEntityEnabled { get; } + [Option("runtime.mcp.dml-tools.aggregate-records.enabled", Required = false, HelpText = "Enable DAB's MCP aggregate records tool. Default: true (boolean).")] + public bool? RuntimeMcpDmlToolsAggregateRecordsEnabled { get; } + [Option("runtime.cache.enabled", Required = false, HelpText = "Enable DAB's cache globally. (You must also enable each entity's cache separately.). Default: false (boolean).")] public bool? RuntimeCacheEnabled { get; } diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index 6c51f002b7..2eaf50a822 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -1181,6 +1181,7 @@ private static bool TryUpdateConfiguredMcpValues( bool? updateRecord = currentDmlTools?.UpdateRecord; bool? deleteRecord = currentDmlTools?.DeleteRecord; bool? executeEntity = currentDmlTools?.ExecuteEntity; + bool? aggregateRecords = currentDmlTools?.AggregateRecords; updatedValue = options?.RuntimeMcpDmlToolsDescribeEntitiesEnabled; if (updatedValue != null) @@ -1230,6 +1231,14 @@ private static bool TryUpdateConfiguredMcpValues( _logger.LogInformation("Updated RuntimeConfig with runtime.mcp.dml-tools.execute-entity as '{updatedValue}'", updatedValue); } + updatedValue = options?.RuntimeMcpDmlToolsAggregateRecordsEnabled; + if (updatedValue != null) + { + aggregateRecords = (bool)updatedValue; + hasToolUpdates = true; + _logger.LogInformation("Updated RuntimeConfig with runtime.mcp.dml-tools.aggregate-records as '{updatedValue}'", updatedValue); + } + if (hasToolUpdates) { updatedMcpOptions = updatedMcpOptions! with @@ -1242,7 +1251,8 @@ private static bool TryUpdateConfiguredMcpValues( ReadRecords = readRecord, UpdateRecord = updateRecord, DeleteRecord = deleteRecord, - ExecuteEntity = executeEntity + ExecuteEntity = executeEntity, + AggregateRecords = aggregateRecords } }; } diff --git a/src/Config/Converters/DmlToolsConfigConverter.cs b/src/Config/Converters/DmlToolsConfigConverter.cs index 82ac3f6069..7e049c7926 100644 --- a/src/Config/Converters/DmlToolsConfigConverter.cs +++ b/src/Config/Converters/DmlToolsConfigConverter.cs @@ -44,6 +44,7 @@ internal class DmlToolsConfigConverter : JsonConverter bool? updateRecord = null; bool? deleteRecord = null; bool? executeEntity = null; + bool? aggregateRecords = null; while (reader.Read()) { @@ -82,6 +83,9 @@ internal class DmlToolsConfigConverter : JsonConverter case "execute-entity": executeEntity = value; break; + case "aggregate-records": + aggregateRecords = value; + break; default: // Skip unknown properties break; @@ -91,7 +95,8 @@ internal class DmlToolsConfigConverter : JsonConverter { // Error on non-boolean values for known properties if (property?.ToLowerInvariant() is "describe-entities" or "create-record" - or "read-records" or "update-record" or "delete-record" or "execute-entity") + or "read-records" or "update-record" or "delete-record" or "execute-entity" + or "aggregate-records") { throw new JsonException($"Property '{property}' must be a boolean value."); } @@ -110,7 +115,8 @@ internal class DmlToolsConfigConverter : JsonConverter readRecords: readRecords, updateRecord: updateRecord, deleteRecord: deleteRecord, - executeEntity: executeEntity); + executeEntity: executeEntity, + aggregateRecords: aggregateRecords); } // For any other unexpected token type, return default (all enabled) @@ -135,7 +141,8 @@ public override void Write(Utf8JsonWriter writer, DmlToolsConfig? value, JsonSer value.UserProvidedReadRecords || value.UserProvidedUpdateRecord || value.UserProvidedDeleteRecord || - value.UserProvidedExecuteEntity; + value.UserProvidedExecuteEntity || + value.UserProvidedAggregateRecords; // Only write the boolean value if it's provided by user // This prevents writing "dml-tools": true when it's the default @@ -181,6 +188,11 @@ public override void Write(Utf8JsonWriter writer, DmlToolsConfig? value, JsonSer writer.WriteBoolean("execute-entity", value.ExecuteEntity.Value); } + if (value.UserProvidedAggregateRecords && value.AggregateRecords.HasValue) + { + writer.WriteBoolean("aggregate-records", value.AggregateRecords.Value); + } + writer.WriteEndObject(); } } diff --git a/src/Config/ObjectModel/DmlToolsConfig.cs b/src/Config/ObjectModel/DmlToolsConfig.cs index 2a09e9d53c..c1f8b278cd 100644 --- a/src/Config/ObjectModel/DmlToolsConfig.cs +++ b/src/Config/ObjectModel/DmlToolsConfig.cs @@ -51,6 +51,11 @@ public record DmlToolsConfig /// public bool? ExecuteEntity { get; init; } + /// + /// Whether aggregate-records tool is enabled + /// + public bool? AggregateRecords { get; init; } + [JsonConstructor] public DmlToolsConfig( bool? allToolsEnabled = null, @@ -59,7 +64,8 @@ public DmlToolsConfig( bool? readRecords = null, bool? updateRecord = null, bool? deleteRecord = null, - bool? executeEntity = null) + bool? executeEntity = null, + bool? aggregateRecords = null) { if (allToolsEnabled is not null) { @@ -75,6 +81,7 @@ public DmlToolsConfig( UpdateRecord = updateRecord ?? toolDefault; DeleteRecord = deleteRecord ?? toolDefault; ExecuteEntity = executeEntity ?? toolDefault; + AggregateRecords = aggregateRecords ?? toolDefault; } else { @@ -87,6 +94,7 @@ public DmlToolsConfig( UpdateRecord = updateRecord ?? DEFAULT_ENABLED; DeleteRecord = deleteRecord ?? DEFAULT_ENABLED; ExecuteEntity = executeEntity ?? DEFAULT_ENABLED; + AggregateRecords = aggregateRecords ?? DEFAULT_ENABLED; } // Track user-provided status - only true if the parameter was not null @@ -96,6 +104,7 @@ public DmlToolsConfig( UserProvidedUpdateRecord = updateRecord is not null; UserProvidedDeleteRecord = deleteRecord is not null; UserProvidedExecuteEntity = executeEntity is not null; + UserProvidedAggregateRecords = aggregateRecords is not null; } /// @@ -112,7 +121,8 @@ public static DmlToolsConfig FromBoolean(bool enabled) readRecords: null, updateRecord: null, deleteRecord: null, - executeEntity: null + executeEntity: null, + aggregateRecords: null ); } @@ -127,7 +137,8 @@ public static DmlToolsConfig FromBoolean(bool enabled) readRecords: null, updateRecord: null, deleteRecord: null, - executeEntity: null + executeEntity: null, + aggregateRecords: null ); /// @@ -185,4 +196,12 @@ public static DmlToolsConfig FromBoolean(bool enabled) [JsonIgnore(Condition = JsonIgnoreCondition.Always)] [MemberNotNullWhen(true, nameof(ExecuteEntity))] public bool UserProvidedExecuteEntity { get; init; } = false; + + /// + /// Flag which informs CLI and JSON serializer whether to write aggregate-records + /// property/value to the runtime config file. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(AggregateRecords))] + public bool UserProvidedAggregateRecords { get; init; } = false; } diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs new file mode 100644 index 0000000000..a1fb2b691c --- /dev/null +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -0,0 +1,596 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Auth; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Authorization; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Mcp.BuiltInTools; +using Azure.DataApiBuilder.Mcp.Model; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using ModelContextProtocol.Protocol; +using Moq; + +namespace Azure.DataApiBuilder.Service.Tests.Mcp +{ + /// + /// Tests for the AggregateRecordsTool MCP tool. + /// Covers: + /// - Tool metadata and schema validation + /// - Runtime-level enabled/disabled configuration + /// - Entity-level DML tool configuration + /// - Input validation (missing/invalid arguments) + /// - In-memory aggregation logic (count, avg, sum, min, max) + /// - distinct, groupby, having, orderby + /// - Alias convention + /// + [TestClass] + public class AggregateRecordsToolTests + { + #region Tool Metadata Tests + + [TestMethod] + public void GetToolMetadata_ReturnsCorrectName() + { + AggregateRecordsTool tool = new(); + Tool metadata = tool.GetToolMetadata(); + Assert.AreEqual("aggregate_records", metadata.Name); + } + + [TestMethod] + public void GetToolMetadata_ReturnsCorrectToolType() + { + AggregateRecordsTool tool = new(); + Assert.AreEqual(McpEnums.ToolType.BuiltIn, tool.ToolType); + } + + [TestMethod] + public void GetToolMetadata_HasInputSchema() + { + AggregateRecordsTool tool = new(); + Tool metadata = tool.GetToolMetadata(); + Assert.AreEqual(JsonValueKind.Object, metadata.InputSchema.ValueKind); + Assert.IsTrue(metadata.InputSchema.TryGetProperty("properties", out _)); + Assert.IsTrue(metadata.InputSchema.TryGetProperty("required", out JsonElement required)); + + List requiredFields = new(); + foreach (JsonElement r in required.EnumerateArray()) + { + requiredFields.Add(r.GetString()!); + } + + CollectionAssert.Contains(requiredFields, "entity"); + CollectionAssert.Contains(requiredFields, "function"); + CollectionAssert.Contains(requiredFields, "field"); + } + + #endregion + + #region Configuration Tests + + [TestMethod] + public async Task AggregateRecords_DisabledAtRuntimeLevel_ReturnsToolDisabledError() + { + RuntimeConfig config = CreateConfig(aggregateRecordsEnabled: false); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + AssertToolDisabledError(content); + } + + [TestMethod] + public async Task AggregateRecords_DisabledAtEntityLevel_ReturnsToolDisabledError() + { + RuntimeConfig config = CreateConfigWithEntityDmlDisabled(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + AssertToolDisabledError(content); + } + + #endregion + + #region Input Validation Tests + + [TestMethod] + public async Task AggregateRecords_NullArguments_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + CallToolResult result = await tool.ExecuteAsync(null, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.IsTrue(content.TryGetProperty("error", out JsonElement error)); + Assert.AreEqual("InvalidArguments", error.GetProperty("type").GetString()); + } + + [TestMethod] + public async Task AggregateRecords_MissingEntity_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"function\": \"count\", \"field\": \"*\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + } + + [TestMethod] + public async Task AggregateRecords_MissingFunction_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"field\": \"*\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + } + + [TestMethod] + public async Task AggregateRecords_MissingField_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + } + + [TestMethod] + public async Task AggregateRecords_InvalidFunction_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"median\", \"field\": \"price\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("median")); + } + + #endregion + + #region Alias Convention Tests + + [TestMethod] + public void ComputeAlias_CountStar_ReturnsCount() + { + Assert.AreEqual("count", AggregateRecordsTool.ComputeAlias("count", "*")); + } + + [TestMethod] + public void ComputeAlias_CountField_ReturnsFunctionField() + { + Assert.AreEqual("count_supplierId", AggregateRecordsTool.ComputeAlias("count", "supplierId")); + } + + [TestMethod] + public void ComputeAlias_AvgField_ReturnsFunctionField() + { + Assert.AreEqual("avg_unitPrice", AggregateRecordsTool.ComputeAlias("avg", "unitPrice")); + } + + [TestMethod] + public void ComputeAlias_SumField_ReturnsFunctionField() + { + Assert.AreEqual("sum_unitPrice", AggregateRecordsTool.ComputeAlias("sum", "unitPrice")); + } + + [TestMethod] + public void ComputeAlias_MinField_ReturnsFunctionField() + { + Assert.AreEqual("min_price", AggregateRecordsTool.ComputeAlias("min", "price")); + } + + [TestMethod] + public void ComputeAlias_MaxField_ReturnsFunctionField() + { + Assert.AreEqual("max_price", AggregateRecordsTool.ComputeAlias("max", "price")); + } + + #endregion + + #region In-Memory Aggregation Tests + + [TestMethod] + public void PerformAggregation_CountStar_ReturnsCount() + { + JsonElement records = ParseArray("[{\"id\":1},{\"id\":2},{\"id\":3}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new(), null, null, "desc", "count"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(3.0, result[0]["count"]); + } + + [TestMethod] + public void PerformAggregation_Avg_ReturnsAverage() + { + JsonElement records = ParseArray("[{\"price\":10},{\"price\":20},{\"price\":30}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "avg", "price", false, new(), null, null, "desc", "avg_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(20.0, result[0]["avg_price"]); + } + + [TestMethod] + public void PerformAggregation_Sum_ReturnsSum() + { + JsonElement records = ParseArray("[{\"price\":10},{\"price\":20},{\"price\":30}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new(), null, null, "desc", "sum_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(60.0, result[0]["sum_price"]); + } + + [TestMethod] + public void PerformAggregation_Min_ReturnsMin() + { + JsonElement records = ParseArray("[{\"price\":10},{\"price\":20},{\"price\":5}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "min", "price", false, new(), null, null, "desc", "min_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(5.0, result[0]["min_price"]); + } + + [TestMethod] + public void PerformAggregation_Max_ReturnsMax() + { + JsonElement records = ParseArray("[{\"price\":10},{\"price\":20},{\"price\":5}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "max", "price", false, new(), null, null, "desc", "max_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(20.0, result[0]["max_price"]); + } + + [TestMethod] + public void PerformAggregation_CountDistinct_ReturnsDistinctCount() + { + JsonElement records = ParseArray("[{\"supplierId\":1},{\"supplierId\":2},{\"supplierId\":1},{\"supplierId\":3}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "count", "supplierId", true, new(), null, null, "desc", "count_supplierId"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(3.0, result[0]["count_supplierId"]); + } + + [TestMethod] + public void PerformAggregation_AvgDistinct_ReturnsDistinctAvg() + { + JsonElement records = ParseArray("[{\"price\":10},{\"price\":10},{\"price\":20},{\"price\":30}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "avg", "price", true, new(), null, null, "desc", "avg_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(20.0, result[0]["avg_price"]); + } + + [TestMethod] + public void PerformAggregation_GroupBy_ReturnsGroupedResults() + { + JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"A\",\"price\":20},{\"category\":\"B\",\"price\":50}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, null, null, "desc", "sum_price"); + + Assert.AreEqual(2, result.Count); + // Desc order: B(50) first, then A(30) + Assert.AreEqual("B", result[0]["category"]?.ToString()); + Assert.AreEqual(50.0, result[0]["sum_price"]); + Assert.AreEqual("A", result[1]["category"]?.ToString()); + Assert.AreEqual(30.0, result[1]["sum_price"]); + } + + [TestMethod] + public void PerformAggregation_GroupBy_Asc_ReturnsSortedAsc() + { + JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"B\",\"price\":30},{\"category\":\"A\",\"price\":20}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, null, null, "asc", "sum_price"); + + Assert.AreEqual(2, result.Count); + Assert.AreEqual("A", result[0]["category"]?.ToString()); + Assert.AreEqual(30.0, result[0]["sum_price"]); + Assert.AreEqual("B", result[1]["category"]?.ToString()); + Assert.AreEqual(30.0, result[1]["sum_price"]); + } + + [TestMethod] + public void PerformAggregation_CountStar_GroupBy_ReturnsGroupCounts() + { + JsonElement records = ParseArray("[{\"category\":\"A\"},{\"category\":\"A\"},{\"category\":\"B\"}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "category" }, null, null, "desc", "count"); + + Assert.AreEqual(2, result.Count); + Assert.AreEqual("A", result[0]["category"]?.ToString()); + Assert.AreEqual(2.0, result[0]["count"]); + Assert.AreEqual("B", result[1]["category"]?.ToString()); + Assert.AreEqual(1.0, result[1]["count"]); + } + + [TestMethod] + public void PerformAggregation_HavingGt_FiltersResults() + { + JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"A\",\"price\":20},{\"category\":\"B\",\"price\":5}]"); + var having = new Dictionary { ["gt"] = 10 }; + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual("A", result[0]["category"]?.ToString()); + Assert.AreEqual(30.0, result[0]["sum_price"]); + } + + [TestMethod] + public void PerformAggregation_HavingGteLte_FiltersRange() + { + JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":100},{\"category\":\"B\",\"price\":20},{\"category\":\"C\",\"price\":1}]"); + var having = new Dictionary { ["gte"] = 10, ["lte"] = 50 }; + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual("B", result[0]["category"]?.ToString()); + } + + [TestMethod] + public void PerformAggregation_HavingIn_FiltersExactValues() + { + JsonElement records = ParseArray("[{\"category\":\"A\"},{\"category\":\"A\"},{\"category\":\"B\"},{\"category\":\"C\"},{\"category\":\"C\"},{\"category\":\"C\"}]"); + var havingIn = new List { 2, 3 }; + var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "category" }, null, havingIn, "desc", "count"); + + Assert.AreEqual(2, result.Count); + // C(3) desc, A(2) + Assert.AreEqual("C", result[0]["category"]?.ToString()); + Assert.AreEqual(3.0, result[0]["count"]); + Assert.AreEqual("A", result[1]["category"]?.ToString()); + Assert.AreEqual(2.0, result[1]["count"]); + } + + [TestMethod] + public void PerformAggregation_HavingEq_FiltersSingleValue() + { + JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"B\",\"price\":20}]"); + var having = new Dictionary { ["eq"] = 10 }; + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual("A", result[0]["category"]?.ToString()); + } + + [TestMethod] + public void PerformAggregation_HavingNeq_FiltersOutValue() + { + JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"B\",\"price\":20}]"); + var having = new Dictionary { ["neq"] = 10 }; + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual("B", result[0]["category"]?.ToString()); + } + + [TestMethod] + public void PerformAggregation_EmptyRecords_ReturnsNull() + { + JsonElement records = ParseArray("[]"); + var result = AggregateRecordsTool.PerformAggregation(records, "avg", "price", false, new(), null, null, "desc", "avg_price"); + + Assert.AreEqual(1, result.Count); + Assert.IsNull(result[0]["avg_price"]); + } + + [TestMethod] + public void PerformAggregation_EmptyRecordsCountStar_ReturnsZero() + { + JsonElement records = ParseArray("[]"); + var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new(), null, null, "desc", "count"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(0.0, result[0]["count"]); + } + + [TestMethod] + public void PerformAggregation_MultipleGroupByFields_ReturnsCorrectGroups() + { + JsonElement records = ParseArray("[{\"cat\":\"A\",\"region\":\"East\",\"price\":10},{\"cat\":\"A\",\"region\":\"East\",\"price\":20},{\"cat\":\"A\",\"region\":\"West\",\"price\":5}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "cat", "region" }, null, null, "desc", "sum_price"); + + Assert.AreEqual(2, result.Count); + // (A,East)=30 desc, (A,West)=5 + Assert.AreEqual("A", result[0]["cat"]?.ToString()); + Assert.AreEqual("East", result[0]["region"]?.ToString()); + Assert.AreEqual(30.0, result[0]["sum_price"]); + } + + [TestMethod] + public void PerformAggregation_HavingNoResults_ReturnsEmpty() + { + JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10}]"); + var having = new Dictionary { ["gt"] = 100 }; + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price"); + + Assert.AreEqual(0, result.Count); + } + + [TestMethod] + public void PerformAggregation_HavingOnSingleResult_Passes() + { + JsonElement records = ParseArray("[{\"price\":50},{\"price\":60}]"); + var having = new Dictionary { ["gte"] = 100 }; + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new(), having, null, "desc", "sum_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(110.0, result[0]["sum_price"]); + } + + [TestMethod] + public void PerformAggregation_HavingOnSingleResult_Fails() + { + JsonElement records = ParseArray("[{\"price\":50},{\"price\":60}]"); + var having = new Dictionary { ["gt"] = 200 }; + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new(), having, null, "desc", "sum_price"); + + Assert.AreEqual(0, result.Count); + } + + #endregion + + #region Helper Methods + + private static JsonElement ParseArray(string json) + { + return JsonDocument.Parse(json).RootElement; + } + + private static JsonElement ParseContent(CallToolResult result) + { + TextContentBlock firstContent = (TextContentBlock)result.Content[0]; + return JsonDocument.Parse(firstContent.Text).RootElement; + } + + private static void AssertToolDisabledError(JsonElement content) + { + Assert.IsTrue(content.TryGetProperty("error", out JsonElement error)); + Assert.IsTrue(error.TryGetProperty("type", out JsonElement errorType)); + Assert.AreEqual("ToolDisabled", errorType.GetString()); + } + + private static RuntimeConfig CreateConfig(bool aggregateRecordsEnabled = true) + { + Dictionary entities = new() + { + ["Book"] = new Entity( + Source: new("books", EntitySourceType.Table, null, null), + GraphQL: new("Book", "Books"), + Fields: null, + Rest: new(Enabled: true), + Permissions: new[] { new EntityPermission(Role: "anonymous", Actions: new[] { + new EntityAction(Action: EntityActionOperation.Read, Fields: null, Policy: null) + }) }, + Mappings: null, + Relationships: null, + Mcp: null + ) + }; + + return new RuntimeConfig( + Schema: "test-schema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new( + Enabled: true, + Path: "/mcp", + DmlTools: new( + describeEntities: true, + readRecords: true, + createRecord: true, + updateRecord: true, + deleteRecord: true, + executeEntity: true, + aggregateRecords: aggregateRecordsEnabled + ) + ), + Host: new(Cors: null, Authentication: null, Mode: HostMode.Development) + ), + Entities: new(entities) + ); + } + + private static RuntimeConfig CreateConfigWithEntityDmlDisabled() + { + Dictionary entities = new() + { + ["Book"] = new Entity( + Source: new("books", EntitySourceType.Table, null, null), + GraphQL: new("Book", "Books"), + Fields: null, + Rest: new(Enabled: true), + Permissions: new[] { new EntityPermission(Role: "anonymous", Actions: new[] { + new EntityAction(Action: EntityActionOperation.Read, Fields: null, Policy: null) + }) }, + Mappings: null, + Relationships: null, + Mcp: new EntityMcpOptions(customToolEnabled: false, dmlToolsEnabled: false) + ) + }; + + return new RuntimeConfig( + Schema: "test-schema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new( + Enabled: true, + Path: "/mcp", + DmlTools: new( + describeEntities: true, + readRecords: true, + createRecord: true, + updateRecord: true, + deleteRecord: true, + executeEntity: true, + aggregateRecords: true + ) + ), + Host: new(Cors: null, Authentication: null, Mode: HostMode.Development) + ), + Entities: new(entities) + ); + } + + private static IServiceProvider CreateServiceProvider(RuntimeConfig config) + { + ServiceCollection services = new(); + + RuntimeConfigProvider configProvider = TestHelper.GenerateInMemoryRuntimeConfigProvider(config); + services.AddSingleton(configProvider); + + Mock mockAuthResolver = new(); + mockAuthResolver.Setup(x => x.IsValidRoleContext(It.IsAny())).Returns(true); + services.AddSingleton(mockAuthResolver.Object); + + Mock mockHttpContext = new(); + Mock mockRequest = new(); + mockRequest.Setup(x => x.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER]).Returns("anonymous"); + mockHttpContext.Setup(x => x.Request).Returns(mockRequest.Object); + + Mock mockHttpContextAccessor = new(); + mockHttpContextAccessor.Setup(x => x.HttpContext).Returns(mockHttpContext.Object); + services.AddSingleton(mockHttpContextAccessor.Object); + + services.AddLogging(); + + return services.BuildServiceProvider(); + } + + #endregion + } +} diff --git a/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs b/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs index d2f6554cd3..b4ae074207 100644 --- a/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs +++ b/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs @@ -48,6 +48,7 @@ public class EntityLevelDmlToolConfigurationTests [DataRow("UpdateRecord", "{\"entity\": \"Book\", \"keys\": {\"id\": 1}, \"fields\": {\"title\": \"Updated\"}}", false, DisplayName = "UpdateRecord respects entity-level DmlToolEnabled=false")] [DataRow("DeleteRecord", "{\"entity\": \"Book\", \"keys\": {\"id\": 1}}", false, DisplayName = "DeleteRecord respects entity-level DmlToolEnabled=false")] [DataRow("ExecuteEntity", "{\"entity\": \"GetBook\"}", true, DisplayName = "ExecuteEntity respects entity-level DmlToolEnabled=false")] + [DataRow("AggregateRecords", "{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\"}", false, DisplayName = "AggregateRecords respects entity-level DmlToolEnabled=false")] public async Task DmlTool_RespectsEntityLevelDmlToolDisabled(string toolType, string jsonArguments, bool isStoredProcedure) { // Arrange @@ -238,6 +239,7 @@ private static IMcpTool CreateTool(string toolType) "UpdateRecord" => new UpdateRecordTool(), "DeleteRecord" => new DeleteRecordTool(), "ExecuteEntity" => new ExecuteEntityTool(), + "AggregateRecords" => new AggregateRecordsTool(), _ => throw new ArgumentException($"Unknown tool type: {toolType}", nameof(toolType)) }; }