Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/a2a/client/transports/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,9 @@ def _get_grpc_metadata(
extensions: list[str] | None = None,
) -> list[tuple[str, str]] | None:
"""Creates gRPC metadata for extensions."""
if extensions is not None:
return [(HTTP_EXTENSION_HEADER, ','.join(extensions))]
if self.extensions is not None:
return [(HTTP_EXTENSION_HEADER, ','.join(self.extensions))]
ext_to_use = extensions if extensions is not None else self.extensions
if ext_to_use is not None:
return [(HTTP_EXTENSION_HEADER.lower(), ','.join(ext_to_use))]
return None

@classmethod
Expand Down
6 changes: 3 additions & 3 deletions src/a2a/server/request_handlers/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ def _get_metadata_value(
) -> list[str]:
md = context.invocation_metadata
raw_values: list[str | bytes] = []
lower_key = key.lower()
if isinstance(md, Metadata):
raw_values = md.get_all(key)
raw_values = md.get_all(lower_key)
elif isinstance(md, Sequence):
lower_key = key.lower()
raw_values = [e for (k, e) in md if k.lower() == lower_key]
return [e if isinstance(e, str) else e.decode('utf-8') for e in raw_values]

Expand Down Expand Up @@ -417,7 +417,7 @@ def _set_extension_metadata(
if server_context.activated_extensions:
context.set_trailing_metadata(
[
(HTTP_EXTENSION_HEADER, e)
(HTTP_EXTENSION_HEADER.lower(), e)
for e in sorted(server_context.activated_extensions)
]
)
47 changes: 34 additions & 13 deletions tests/client/transports/test_grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ async def test_send_message_task_response(
_, kwargs = mock_grpc_stub.SendMessage.call_args
assert kwargs['metadata'] == [
(
HTTP_EXTENSION_HEADER,
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v3',
)
]
Expand All @@ -228,7 +228,7 @@ async def test_send_message_message_response(
_, kwargs = mock_grpc_stub.SendMessage.call_args
assert kwargs['metadata'] == [
(
HTTP_EXTENSION_HEADER,
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
]
Expand Down Expand Up @@ -283,7 +283,7 @@ async def test_send_message_streaming( # noqa: PLR0913
_, kwargs = mock_grpc_stub.SendStreamingMessage.call_args
assert kwargs['metadata'] == [
(
HTTP_EXTENSION_HEADER,
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
]
Expand Down Expand Up @@ -313,7 +313,7 @@ async def test_get_task(
),
metadata=[
(
HTTP_EXTENSION_HEADER,
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
],
Expand All @@ -338,7 +338,7 @@ async def test_get_task_with_history(
),
metadata=[
(
HTTP_EXTENSION_HEADER,
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
],
Expand All @@ -363,7 +363,9 @@ async def test_cancel_task(

mock_grpc_stub.CancelTask.assert_awaited_once_with(
a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}'),
metadata=[(HTTP_EXTENSION_HEADER, 'https://example.com/test-ext/v3')],
metadata=[
(HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v3')
],
)
assert response.status.state == TaskState.canceled

Expand Down Expand Up @@ -395,7 +397,7 @@ async def test_set_task_callback_with_valid_task(
),
metadata=[
(
HTTP_EXTENSION_HEADER,
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
],
Expand Down Expand Up @@ -458,7 +460,7 @@ async def test_get_task_callback_with_valid_task(
),
metadata=[
(
HTTP_EXTENSION_HEADER,
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
],
Expand Down Expand Up @@ -506,27 +508,27 @@ async def test_get_task_callback_with_invalid_task(
(
['ext1'],
None,
[(HTTP_EXTENSION_HEADER, 'ext1')],
[(HTTP_EXTENSION_HEADER.lower(), 'ext1')],
), # Case 2: Initial, No input
(
None,
['ext2'],
[(HTTP_EXTENSION_HEADER, 'ext2')],
[(HTTP_EXTENSION_HEADER.lower(), 'ext2')],
), # Case 3: No initial, Input
(
['ext1'],
['ext2'],
[(HTTP_EXTENSION_HEADER, 'ext2')],
[(HTTP_EXTENSION_HEADER.lower(), 'ext2')],
), # Case 4: Initial, Input (override)
(
['ext1'],
['ext2', 'ext3'],
[(HTTP_EXTENSION_HEADER, 'ext2,ext3')],
[(HTTP_EXTENSION_HEADER.lower(), 'ext2,ext3')],
), # Case 5: Initial, Multiple inputs (override)
(
['ext1', 'ext2'],
['ext3'],
[(HTTP_EXTENSION_HEADER, 'ext3')],
[(HTTP_EXTENSION_HEADER.lower(), 'ext3')],
), # Case 6: Multiple initial, Single input (override)
],
)
Expand All @@ -540,3 +542,22 @@ def test_get_grpc_metadata(
grpc_transport.extensions = initial_extensions
metadata = grpc_transport._get_grpc_metadata(input_extensions)
assert metadata == expected_metadata


@pytest.mark.parametrize(
'test_extensions',
[
(['ext1']), # Test with explicit extensions
(None), # Test with transport's default extensions
],
)
def test_get_grpc_metadata_uses_lowercase_header_key(
grpc_transport: GrpcTransport,
test_extensions: list[str] | None,
) -> None:
"""Test gRPC metadata header key is always lowercase."""
# Regression: gRPC rejects non-lowercase metadata keys
metadata = grpc_transport._get_grpc_metadata(test_extensions)
if metadata:
key, _ = metadata[0]
assert key == key.lower()
20 changes: 10 additions & 10 deletions tests/server/request_handlers/test_grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,8 @@ async def test_send_message_with_extensions(
mock_grpc_context: AsyncMock,
) -> None:
mock_grpc_context.invocation_metadata = grpc.aio.Metadata(
(HTTP_EXTENSION_HEADER, 'foo'),
(HTTP_EXTENSION_HEADER, 'bar'),
(HTTP_EXTENSION_HEADER.lower(), 'foo'),
(HTTP_EXTENSION_HEADER.lower(), 'bar'),
)

def side_effect(request, context: ServerCallContext):
Expand Down Expand Up @@ -379,8 +379,8 @@ def side_effect(request, context: ServerCallContext):
mock_grpc_context.set_trailing_metadata.call_args.args[0]
)
assert set(called_metadata) == {
(HTTP_EXTENSION_HEADER, 'foo'),
(HTTP_EXTENSION_HEADER, 'baz'),
(HTTP_EXTENSION_HEADER.lower(), 'foo'),
(HTTP_EXTENSION_HEADER.lower(), 'baz'),
}

async def test_send_message_with_comma_separated_extensions(
Expand All @@ -390,8 +390,8 @@ async def test_send_message_with_comma_separated_extensions(
mock_grpc_context: AsyncMock,
) -> None:
mock_grpc_context.invocation_metadata = grpc.aio.Metadata(
(HTTP_EXTENSION_HEADER, 'foo ,, bar,'),
(HTTP_EXTENSION_HEADER, 'baz , bar'),
(HTTP_EXTENSION_HEADER.lower(), 'foo ,, bar,'),
(HTTP_EXTENSION_HEADER.lower(), 'baz , bar'),
)
mock_request_handler.on_message_send.return_value = types.Message(
message_id='1',
Expand All @@ -415,8 +415,8 @@ async def test_send_streaming_message_with_extensions(
mock_grpc_context: AsyncMock,
) -> None:
mock_grpc_context.invocation_metadata = grpc.aio.Metadata(
(HTTP_EXTENSION_HEADER, 'foo'),
(HTTP_EXTENSION_HEADER, 'bar'),
(HTTP_EXTENSION_HEADER.lower(), 'foo'),
(HTTP_EXTENSION_HEADER.lower(), 'bar'),
)

async def side_effect(request, context: ServerCallContext):
Expand Down Expand Up @@ -450,6 +450,6 @@ async def side_effect(request, context: ServerCallContext):
mock_grpc_context.set_trailing_metadata.call_args.args[0]
)
assert set(called_metadata) == {
(HTTP_EXTENSION_HEADER, 'foo'),
(HTTP_EXTENSION_HEADER, 'baz'),
(HTTP_EXTENSION_HEADER.lower(), 'foo'),
(HTTP_EXTENSION_HEADER.lower(), 'baz'),
}
Loading