diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index 6a8b16f9..fa3d017a 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -64,10 +64,9 @@ def _get_grpc_metadata( extensions: list[str] | None = None, ) -> list[tuple[str, str]] | None: """Creates gRPC metadata for extensions.""" - if extensions is not None: - return [(HTTP_EXTENSION_HEADER, ','.join(extensions))] - if self.extensions is not None: - return [(HTTP_EXTENSION_HEADER, ','.join(self.extensions))] + ext_to_use = extensions if extensions is not None else self.extensions + if ext_to_use is not None: + return [(HTTP_EXTENSION_HEADER.lower(), ','.join(ext_to_use))] return None @classmethod diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 105b9947..d4a6bbd4 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -55,10 +55,10 @@ def _get_metadata_value( ) -> list[str]: md = context.invocation_metadata raw_values: list[str | bytes] = [] + lower_key = key.lower() if isinstance(md, Metadata): - raw_values = md.get_all(key) + raw_values = md.get_all(lower_key) elif isinstance(md, Sequence): - lower_key = key.lower() raw_values = [e for (k, e) in md if k.lower() == lower_key] return [e if isinstance(e, str) else e.decode('utf-8') for e in raw_values] @@ -417,7 +417,7 @@ def _set_extension_metadata( if server_context.activated_extensions: context.set_trailing_metadata( [ - (HTTP_EXTENSION_HEADER, e) + (HTTP_EXTENSION_HEADER.lower(), e) for e in sorted(server_context.activated_extensions) ] ) diff --git a/tests/client/transports/test_grpc_client.py b/tests/client/transports/test_grpc_client.py index 111e44ba..3be7b3f0 100644 --- a/tests/client/transports/test_grpc_client.py +++ b/tests/client/transports/test_grpc_client.py @@ -202,7 +202,7 @@ async def test_send_message_task_response( _, kwargs = mock_grpc_stub.SendMessage.call_args assert kwargs['metadata'] == [ ( - HTTP_EXTENSION_HEADER, + HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v3', ) ] @@ -228,7 +228,7 @@ async def test_send_message_message_response( _, kwargs = mock_grpc_stub.SendMessage.call_args assert kwargs['metadata'] == [ ( - HTTP_EXTENSION_HEADER, + HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', ) ] @@ -283,7 +283,7 @@ async def test_send_message_streaming( # noqa: PLR0913 _, kwargs = mock_grpc_stub.SendStreamingMessage.call_args assert kwargs['metadata'] == [ ( - HTTP_EXTENSION_HEADER, + HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', ) ] @@ -313,7 +313,7 @@ async def test_get_task( ), metadata=[ ( - HTTP_EXTENSION_HEADER, + HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', ) ], @@ -338,7 +338,7 @@ async def test_get_task_with_history( ), metadata=[ ( - HTTP_EXTENSION_HEADER, + HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', ) ], @@ -363,7 +363,9 @@ async def test_cancel_task( mock_grpc_stub.CancelTask.assert_awaited_once_with( a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}'), - metadata=[(HTTP_EXTENSION_HEADER, 'https://example.com/test-ext/v3')], + metadata=[ + (HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v3') + ], ) assert response.status.state == TaskState.canceled @@ -395,7 +397,7 @@ async def test_set_task_callback_with_valid_task( ), metadata=[ ( - HTTP_EXTENSION_HEADER, + HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', ) ], @@ -458,7 +460,7 @@ async def test_get_task_callback_with_valid_task( ), metadata=[ ( - HTTP_EXTENSION_HEADER, + HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', ) ], @@ -506,27 +508,27 @@ async def test_get_task_callback_with_invalid_task( ( ['ext1'], None, - [(HTTP_EXTENSION_HEADER, 'ext1')], + [(HTTP_EXTENSION_HEADER.lower(), 'ext1')], ), # Case 2: Initial, No input ( None, ['ext2'], - [(HTTP_EXTENSION_HEADER, 'ext2')], + [(HTTP_EXTENSION_HEADER.lower(), 'ext2')], ), # Case 3: No initial, Input ( ['ext1'], ['ext2'], - [(HTTP_EXTENSION_HEADER, 'ext2')], + [(HTTP_EXTENSION_HEADER.lower(), 'ext2')], ), # Case 4: Initial, Input (override) ( ['ext1'], ['ext2', 'ext3'], - [(HTTP_EXTENSION_HEADER, 'ext2,ext3')], + [(HTTP_EXTENSION_HEADER.lower(), 'ext2,ext3')], ), # Case 5: Initial, Multiple inputs (override) ( ['ext1', 'ext2'], ['ext3'], - [(HTTP_EXTENSION_HEADER, 'ext3')], + [(HTTP_EXTENSION_HEADER.lower(), 'ext3')], ), # Case 6: Multiple initial, Single input (override) ], ) @@ -540,3 +542,22 @@ def test_get_grpc_metadata( grpc_transport.extensions = initial_extensions metadata = grpc_transport._get_grpc_metadata(input_extensions) assert metadata == expected_metadata + + +@pytest.mark.parametrize( + 'test_extensions', + [ + (['ext1']), # Test with explicit extensions + (None), # Test with transport's default extensions + ], +) +def test_get_grpc_metadata_uses_lowercase_header_key( + grpc_transport: GrpcTransport, + test_extensions: list[str] | None, +) -> None: + """Test gRPC metadata header key is always lowercase.""" + # Regression: gRPC rejects non-lowercase metadata keys + metadata = grpc_transport._get_grpc_metadata(test_extensions) + if metadata: + key, _ = metadata[0] + assert key == key.lower() diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 647d9e86..e9d8cb26 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -350,8 +350,8 @@ async def test_send_message_with_extensions( mock_grpc_context: AsyncMock, ) -> None: mock_grpc_context.invocation_metadata = grpc.aio.Metadata( - (HTTP_EXTENSION_HEADER, 'foo'), - (HTTP_EXTENSION_HEADER, 'bar'), + (HTTP_EXTENSION_HEADER.lower(), 'foo'), + (HTTP_EXTENSION_HEADER.lower(), 'bar'), ) def side_effect(request, context: ServerCallContext): @@ -379,8 +379,8 @@ def side_effect(request, context: ServerCallContext): mock_grpc_context.set_trailing_metadata.call_args.args[0] ) assert set(called_metadata) == { - (HTTP_EXTENSION_HEADER, 'foo'), - (HTTP_EXTENSION_HEADER, 'baz'), + (HTTP_EXTENSION_HEADER.lower(), 'foo'), + (HTTP_EXTENSION_HEADER.lower(), 'baz'), } async def test_send_message_with_comma_separated_extensions( @@ -390,8 +390,8 @@ async def test_send_message_with_comma_separated_extensions( mock_grpc_context: AsyncMock, ) -> None: mock_grpc_context.invocation_metadata = grpc.aio.Metadata( - (HTTP_EXTENSION_HEADER, 'foo ,, bar,'), - (HTTP_EXTENSION_HEADER, 'baz , bar'), + (HTTP_EXTENSION_HEADER.lower(), 'foo ,, bar,'), + (HTTP_EXTENSION_HEADER.lower(), 'baz , bar'), ) mock_request_handler.on_message_send.return_value = types.Message( message_id='1', @@ -415,8 +415,8 @@ async def test_send_streaming_message_with_extensions( mock_grpc_context: AsyncMock, ) -> None: mock_grpc_context.invocation_metadata = grpc.aio.Metadata( - (HTTP_EXTENSION_HEADER, 'foo'), - (HTTP_EXTENSION_HEADER, 'bar'), + (HTTP_EXTENSION_HEADER.lower(), 'foo'), + (HTTP_EXTENSION_HEADER.lower(), 'bar'), ) async def side_effect(request, context: ServerCallContext): @@ -450,6 +450,6 @@ async def side_effect(request, context: ServerCallContext): mock_grpc_context.set_trailing_metadata.call_args.args[0] ) assert set(called_metadata) == { - (HTTP_EXTENSION_HEADER, 'foo'), - (HTTP_EXTENSION_HEADER, 'baz'), + (HTTP_EXTENSION_HEADER.lower(), 'foo'), + (HTTP_EXTENSION_HEADER.lower(), 'baz'), }