From e0bfe203025b7a98ea4d626a98a940f9e45322d3 Mon Sep 17 00:00:00 2001
From: Priyank <15610225+priyankinfinnov@users.noreply.github.com>
Date: Mon, 19 Jan 2026 21:30:23 +0530
Subject: [PATCH 1/2] Portkey integration code and tests
---
.../tracing/portkey/portkey_tracing.ipynb | 151 ++++
src/openlayer/lib/__init__.py | 38 +
.../lib/integrations/portkey_tracer.py | 756 ++++++++++++++++++
tests/test_integration_conditional_imports.py | 1 +
tests/test_portkey_integration.py | 576 +++++++++++++
5 files changed, 1522 insertions(+)
create mode 100644 examples/tracing/portkey/portkey_tracing.ipynb
create mode 100644 src/openlayer/lib/integrations/portkey_tracer.py
create mode 100644 tests/test_portkey_integration.py
diff --git a/examples/tracing/portkey/portkey_tracing.ipynb b/examples/tracing/portkey/portkey_tracing.ipynb
new file mode 100644
index 00000000..425b9b7a
--- /dev/null
+++ b/examples/tracing/portkey/portkey_tracing.ipynb
@@ -0,0 +1,151 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[](https://colab.research.google.com/github/openlayer-ai/openlayer-python/blob/main/examples/tracing/portkey/portkey_tracing.ipynb)\n",
+ "\n",
+ "\n",
+ "# Portkey monitoring quickstart\n",
+ "\n",
+ "This notebook illustrates how to get started monitoring Portkey completions with Openlayer.\n",
+ "\n",
+ "Portkey provides a unified interface to call 100+ LLM APIs using the same input/output format. This integration allows you to trace and monitor completions across all supported providers through a single interface.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install openlayer portkey-ai"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Set the environment variables\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "from portkey_ai import Portkey\n",
+ "\n",
+ "# Set your Portkey API keys\n",
+ "os.environ['PORTKEY_API_KEY'] = \"YOUR_PORTKEY_API_HERE\"\n",
+ "\n",
+ "# Openlayer env variables\n",
+ "os.environ[\"OPENLAYER_API_KEY\"] = \"YOUR_OPENLAYER_API_KEY_HERE\"\n",
+ "os.environ[\"OPENLAYER_INFERENCE_PIPELINE_ID\"] = \"YOUR_OPENLAYER_INFERENCE_PIPELINE_ID_HERE\"\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Enable Portkey tracing\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from openlayer.lib import trace_portkey\n",
+ "\n",
+ "# Enable openlayer tracing for all Portkey completions\n",
+ "trace_portkey()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Use Portkey normally - tracing happens automatically!\n",
+ "\n",
+ "### Basic completion with OpenAI\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Basic portkey client initialization\n",
+ "portkey = Portkey(\n",
+ " api_key = os.environ['PORTKEY_API_KEY'],\n",
+ " config = \"YOUR_PORTKEY_CONFIG_ID_HERE\", # optional your portkey config id\n",
+ ")\n",
+ "\n",
+ "# Basic portkey LLM call\n",
+ "response = portkey.chat.completions.create(\n",
+ " #model = \"@YOUR_PORTKEY_SLUG/YOUR_MODEL_NAME\", # optional if giving config\n",
+ " messages = [\n",
+ " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
+ " {\"role\": \"user\", \"content\": \"Write a poem on Argentina, least 500 words.\"}\n",
+ " ]\n",
+ ")\n"
+ ]
+ },
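+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Streaming completions\n",
+ "\n",
+ "Streaming calls are traced as well: chunks are yielded to your code as usual, and the completion is traced once the stream is fully consumed. This is a minimal sketch that reuses the `portkey` client and config from the cell above.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Streaming Portkey LLM call - the full completion is traced once the stream ends\n",
+ "stream = portkey.chat.completions.create(\n",
+ " messages = [{\"role\": \"user\", \"content\": \"Say hello in three languages.\"}],\n",
+ " stream = True\n",
+ ")\n",
+ "\n",
+ "for chunk in stream:\n",
+ " if chunk.choices and chunk.choices[0].delta.content:\n",
+ " print(chunk.choices[0].delta.content, end=\"\")\n"
+ ]
+ },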
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. View your traces\n",
+ "\n",
+ "Once you've run the examples above, you can:\n",
+ "\n",
+ "1. **Visit your OpenLayer dashboard** to see all the traced completions\n",
+ "2. **Analyze performance** across different models and providers\n",
+ "3. **Monitor costs** and token usage\n",
+ "4. **Debug issues** with detailed request/response logs\n",
+ "5. **Compare models** side-by-side\n",
+ "\n",
+ "The traces will include:\n",
+ "- **Request details**: Model, parameters, messages\n",
+ "- **Response data**: Generated content, token counts, latency\n",
+ "- **Provider information**: Which underlying service was used\n",
+ "- **Custom metadata**: Any additional context you provide\n",
+ "\n",
+ "For more information, check out:\n",
+ "- [OpenLayer Documentation](https://docs.openlayer.com/)\n",
+ "- [Portkey Documentation](https://portkey.ai/docs)\n",
+ "- [Portkey AI Gateway](https://portkey.ai/docs/product/ai-gateway)\n",
+ "- [Portkey Supported Providers](https://portkey.ai/docs/api-reference/inference-api/supported-providers)\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/openlayer/lib/__init__.py b/src/openlayer/lib/__init__.py
index a22dbd98..7da581d2 100644
--- a/src/openlayer/lib/__init__.py
+++ b/src/openlayer/lib/__init__.py
@@ -14,6 +14,7 @@
"trace_oci_genai",
"trace_oci", # Alias for backward compatibility
"trace_litellm",
+ "trace_portkey",
"trace_google_adk",
"unpatch_google_adk",
"trace_gemini",
@@ -193,6 +194,43 @@ def trace_litellm():
return litellm_tracer.trace_litellm()
+# ---------------------------------- Portkey ---------------------------------- #
+def trace_portkey():
+ """Enable tracing for Portkey completions.
+
+ This function patches Portkey's chat.completions.create to automatically trace
+ all OpenAI-compatible completions routed via the Portkey AI Gateway.
+
+ Example:
+ >>> from portkey_ai import Portkey
+ >>> from openlayer.lib import trace_portkey
+ >>> # Enable openlayer tracing for all Portkey completions
+ >>> trace_portkey()
+ >>> # Basic portkey client initialization
+ >>> portkey = Portkey(
+ >>> api_key = os.environ['PORTKEY_API_KEY'],
+ >>> config = "YOUR_PORTKEY_CONFIG_ID", # optional your portkey config id
+ >>> )
+ >>> # use portkey normally - tracing happens automatically
+ >>> response = portkey.chat.completions.create(
+ >>> #model = "@YOUR_PORTKEY_SLUG/YOUR_MODEL_NAME", # optional if giving config
+ >>> messages = [
+ >>> {"role": "system", "content": "You are a helpful assistant."},
+ >>> {"role": "user", "content": "Write a poem on Argentina, least 100 words."}
+ >>> ]
+ >>> )
+ """
+ # pylint: disable=import-outside-toplevel
+ try:
+ from portkey_ai import Portkey # noqa: F401
+ except ImportError as exc:
+ raise ImportError(
+ "portkey-ai is required for Portkey tracing. Install with: pip install portkey-ai"
+ ) from exc
+
+ from .integrations import portkey_tracer
+
+ return portkey_tracer.trace_portkey()
+
+
# ------------------------------ Google ADK ---------------------------------- #
def trace_google_adk(disable_adk_otel: bool = False):
"""Enable tracing for Google Agent Development Kit (ADK).
diff --git a/src/openlayer/lib/integrations/portkey_tracer.py b/src/openlayer/lib/integrations/portkey_tracer.py
new file mode 100644
index 00000000..f9c6ca9e
--- /dev/null
+++ b/src/openlayer/lib/integrations/portkey_tracer.py
@@ -0,0 +1,756 @@
+"""Module with methods used to trace Portkey AI Gateway chat completions."""
+
+import json
+import logging
+import time
+from functools import wraps
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional, Union
+
+try:
+ from portkey_ai import Portkey
+ HAVE_PORTKEY = True
+except ImportError:
+ HAVE_PORTKEY = False
+
+if TYPE_CHECKING:
+ from portkey_ai import Portkey
+
+from ..tracing import tracer
+
+logger = logging.getLogger(__name__)
+
+
+def trace_portkey() -> None:
+ """Patch Portkey's chat.completions.create to trace completions.
+
+ The following information is collected for each completion:
+ - start_time: The time when the completion was requested.
+ - end_time: The time when the completion was received.
+ - latency: The time it took to generate the completion.
+ - tokens: The total number of tokens used to generate the completion.
+ - prompt_tokens: The number of tokens in the prompt.
+ - completion_tokens: The number of tokens in the completion.
+ - model: The model used to generate the completion.
+ - model_parameters: The parameters used to configure the model.
+ - raw_output: The raw output of the model.
+ - inputs: The inputs used to generate the completion.
+ - Portkey-specific metadata (base URL, x-portkey-* headers if available)
+
+ Returns
+ -------
+ None
+ This function patches portkey.chat.completions.create in place.
+
+ Example
+ -------
+ >>> from portkey_ai import Portkey
+ >>> from openlayer.lib import trace_portkey
+ >>>
+ >>> # Enable tracing
+ >>> trace_portkey()
+ >>>
+ >>> # Use Portkey normally - tracing happens automatically
+ >>> portkey = Portkey(api_key = "YOUR_PORTKEY_API_KEY")
+ >>> response = portkey.chat.completions.create(
+ ... model = "@YOUR_PROVIDER_SLUG/MODEL_NAME",
+ ... messages = [
+ ... {"role": "system", "content": "You are a helpful assistant."},
+ ... {"role": "user", "content": "What is Portkey"}
+ ... ],
+ ... inference_id="custom-id-123" # Optional Openlayer parameter
+ ... max_tokens = 512
+ ... )
+ """
+ if not HAVE_PORTKEY:
+ raise ImportError(
+ "Portkey library is not installed. Please install it with: pip install portkey-ai"
+ )
+
+ # Patch instances on initialization rather than class-level attributes.
+ # Some SDKs initialize 'chat' lazily on the instance.
+ original_init = Portkey.__init__
+
+ @wraps(original_init)
+ def traced_init(self, *args, **kwargs):
+ original_init(self, *args, **kwargs)
+ try:
+ # Avoid double-patching
+ if getattr(self, "_openlayer_portkey_patched", False):
+ return
+ # Access chat to ensure it's constructed, then wrap create
+ chat = getattr(self, "chat", None)
+ if chat is None or not hasattr(chat, "completions") or not hasattr(chat.completions, "create"):
+ # If the structure isn't present, skip gracefully and log diagnostics
+ logger.debug(
+ "Openlayer Portkey tracer: Portkey client missing expected attributes (chat/completions/create). "
+ "Tracing not applied for this instance."
+ )
+ return
+ original_create = chat.completions.create
+
+ @wraps(original_create)
+ def traced_create(*c_args, **c_kwargs):
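+ # "inference_id" is an Openlayer-only kwarg; pop it so it is not forwarded to the Portkey API.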
+ inference_id = c_kwargs.pop("inference_id", None)
+ stream = c_kwargs.get("stream", False)
+ if stream:
+ return handle_streaming_create(
+ self,
+ *c_args,
+ create_func=original_create,
+ inference_id=inference_id,
+ **c_kwargs,
+ )
+ return handle_non_streaming_create(
+ self,
+ *c_args,
+ create_func=original_create,
+ inference_id=inference_id,
+ **c_kwargs,
+ )
+
+ self.chat.completions.create = traced_create
+ setattr(self, "_openlayer_portkey_patched", True)
+ logger.debug("Openlayer Portkey tracer: successfully patched Portkey client instance for tracing.")
+ except Exception as e:
+ logger.debug("Failed to patch Portkey client instance for tracing: %s", e)
+
+ Portkey.__init__ = traced_init
+ logger.info("Openlayer Portkey tracer: tracing enabled (instance-level patch).")
+
+
+def handle_streaming_create(
+ client: "Portkey",
+ *args,
+ create_func: Callable,
+ inference_id: Optional[str] = None,
+ **kwargs,
+) -> Iterator[Any]:
+ """
+ Handles streaming chat.completions.create routed via Portkey.
+
+ Parameters
+ ----------
+ client : Portkey
+ The Portkey client instance making the request.
+ *args :
+ Positional arguments passed to the create function.
+ create_func : callable
+ The create function to call (typically chat.completions.create).
+ inference_id : Optional[str], default None
+ Optional inference ID for tracking this request.
+ **kwargs :
+ Additional keyword arguments forwarded to create_func.
+
+ Returns
+ -------
+ Iterator[Any]
+ A generator that yields the chunks of the completion.
+ """
+ # Portkey is OpenAI-compatible; request and chunks follow OpenAI spec
+ # create_func is a bound method; do not pass client again
+ chunks = create_func(*args, **kwargs)
+ return stream_chunks(
+ chunks=chunks,
+ kwargs=kwargs,
+ client=client,
+ inference_id=inference_id,
+ )
+
+
+def stream_chunks(
+ chunks: Iterator[Any],
+ kwargs: Dict[str, Any],
+ client: "Portkey",
+ inference_id: Optional[str] = None,
+):
+ """Streams the chunks of the completion and traces the completion."""
+ collected_output_data = []
+ collected_function_call = {"name": "", "arguments": ""}
+ raw_outputs = []
+ start_time = time.time()
+ end_time = None
+ first_token_time = None
+ num_of_completion_tokens = None
+ latency = None
+ model_name = kwargs.get("model", "unknown")
+ provider = "unknown"
+ latest_usage_data = {"total_tokens": None, "prompt_tokens": None, "completion_tokens": None}
+ latest_chunk_metadata: Dict[str, Any] = {}
+
+ try:
+ i = 0
+ for i, chunk in enumerate(chunks):
+ raw_outputs.append(chunk.model_dump() if hasattr(chunk, "model_dump") else str(chunk))
+
+ if i == 0:
+ first_token_time = time.time()
+ # Try to detect provider at first chunk
+ provider = detect_provider(chunk, client, model_name)
+ if i > 0:
+ num_of_completion_tokens = i + 1
+
+ # Extract usage from chunk if available
+ chunk_usage = extract_usage(chunk)
+ if any(v is not None for v in chunk_usage.values()):
+ latest_usage_data = chunk_usage
+
+ # Update metadata from latest chunk (headers/etc.)
+ chunk_metadata = extract_portkey_unit_metadata(chunk, model_name)
+ if chunk_metadata:
+ latest_chunk_metadata.update(chunk_metadata)
+
+ # Extract delta from chunk (OpenAI-compatible)
+ delta = get_delta_from_chunk(chunk)
+
+ if delta and getattr(delta, "content", None):
+ collected_output_data.append(delta.content)
+ elif delta and getattr(delta, "function_call", None):
+ if delta.function_call.name:
+ collected_function_call["name"] += delta.function_call.name
+ if delta.function_call.arguments:
+ collected_function_call["arguments"] += delta.function_call.arguments
+ elif delta and getattr(delta, "tool_calls", None):
+ tool_call = delta.tool_calls[0]
+ if getattr(tool_call.function, "name", None):
+ collected_function_call["name"] += tool_call.function.name
+ if getattr(tool_call.function, "arguments", None):
+ collected_function_call["arguments"] += tool_call.function.arguments
+
+ yield chunk
+ end_time = time.time()
+ latency = (end_time - start_time) * 1000
+ # pylint: disable=broad-except
+ except Exception as e:
+ logger.error("Failed to yield Portkey chunk. %s", e)
+ finally:
+ # Try to add step to the trace
+ try:
+ collected_output_data = [m for m in collected_output_data if m is not None]
+ if collected_output_data:
+ output_data = "".join(collected_output_data)
+ else:
+ if collected_function_call["arguments"]:
+ try:
+ collected_function_call["arguments"] = json.loads(collected_function_call["arguments"])
+ except json.JSONDecodeError:
+ pass
+ output_data = collected_function_call
+
+ # Calculate usage and cost at end of stream (prioritize actual usage if present)
+ completion_tokens_calculated, prompt_tokens_calculated, total_tokens_calculated, cost_calculated = calculate_streaming_usage_and_cost(
+ chunks=raw_outputs,
+ messages=kwargs.get("messages", []),
+ output_content=output_data,
+ model_name=model_name,
+ latest_usage_data=latest_usage_data,
+ latest_chunk_metadata=latest_chunk_metadata,
+ )
+
+ usage_data = latest_usage_data if any(v is not None for v in latest_usage_data.values()) else {}
+ final_prompt_tokens = prompt_tokens_calculated if prompt_tokens_calculated is not None else usage_data.get("prompt_tokens", 0)
+ final_completion_tokens = completion_tokens_calculated if completion_tokens_calculated is not None else usage_data.get("completion_tokens", num_of_completion_tokens)
+ final_total_tokens = total_tokens_calculated if total_tokens_calculated is not None else usage_data.get("total_tokens", (final_prompt_tokens or 0) + (final_completion_tokens or 0))
+ final_cost = cost_calculated if cost_calculated is not None else latest_chunk_metadata.get("cost", None)
+
+ trace_args = create_trace_args(
+ end_time=end_time,
+ inputs={"prompt": kwargs.get("messages", [])},
+ output=output_data,
+ latency=latency,
+ tokens=final_total_tokens,
+ prompt_tokens=final_prompt_tokens,
+ completion_tokens=final_completion_tokens,
+ model=model_name,
+ model_parameters=get_model_parameters(kwargs),
+ raw_output=raw_outputs,
+ id=inference_id,
+ cost=final_cost,
+ metadata={
+ "timeToFirstToken": ((first_token_time - start_time) * 1000 if first_token_time else None),
+ "provider": provider,
+ "portkey_model": model_name,
+ **extract_portkey_metadata(client),
+ **latest_chunk_metadata,
+ },
+ )
+ add_to_trace(**trace_args)
+ # pylint: disable=broad-except
+ except Exception as e:
+ logger.error("Failed to trace the Portkey streaming completion. %s", e)
+
+
+def handle_non_streaming_create(
+ client: "Portkey",
+ *args,
+ create_func: Callable,
+ inference_id: Optional[str] = None,
+ **kwargs,
+) -> Any:
+ """
+ Handles non-streaming chat.completions.create routed via Portkey.
+
+ Parameters
+ ----------
+ client : Portkey
+ The Portkey client instance used for routing the request.
+ *args :
+ Positional arguments for the create function.
+ create_func : callable
+ The function used to create the chat completion. This is a bound method, so do not pass client again.
+ inference_id : Optional[str], optional
+ A unique identifier for the inference or trace, by default None.
+ **kwargs :
+ Additional keyword arguments passed to the create function (e.g., "messages", "model", etc.).
+
+ Returns
+ -------
+ Any
+ The completion response as returned by the create function.
+ """
+ start_time = time.time()
+ # create_func is a bound method; do not pass client again
+ response = create_func(*args, **kwargs)
+ end_time = time.time()
+
+ # Try to add step to the trace
+ try:
+ output_data = parse_non_streaming_output_data(response)
+
+ # Usage (if provided by upstream provider via Portkey)
+ usage_data = extract_usage(response)
+ model_name = getattr(response, "model", kwargs.get("model", "unknown"))
+ provider = detect_provider(response, client, model_name)
+ extra_metadata = extract_portkey_unit_metadata(response, model_name)
+ cost = extra_metadata.get("cost", None)
+
+ trace_args = create_trace_args(
+ end_time=end_time,
+ inputs={"prompt": kwargs.get("messages", [])},
+ output=output_data,
+ latency=(end_time - start_time) * 1000,
+ tokens=usage_data.get("total_tokens"),
+ prompt_tokens=usage_data.get("prompt_tokens"),
+ completion_tokens=usage_data.get("completion_tokens"),
+ model=model_name,
+ model_parameters=get_model_parameters(kwargs),
+ raw_output=response.model_dump() if hasattr(response, "model_dump") else str(response),
+ id=inference_id,
+ cost=cost,
+ metadata={
+ "system_fingerprint": getattr(response, "system_fingerprint", None),
+ "provider": provider,
+ "portkey_model": model_name,
+ **extract_portkey_metadata(client),
+ **extra_metadata,
+ },
+ )
+ add_to_trace(**trace_args)
+ # pylint: disable=broad-except
+ except Exception as e:
+ logger.error("Failed to trace the Portkey non-streaming completion. %s", e)
+
+ return response
+
+
+def get_model_parameters(kwargs: Dict[str, Any]) -> Dict[str, Any]:
+ """Gets the model parameters from the kwargs (OpenAI-compatible)."""
+ return {
+ "temperature": kwargs.get("temperature", 1),
+ "top_p": kwargs.get("top_p", 1),
+ "max_tokens": kwargs.get("max_tokens", None),
+ "n": kwargs.get("n", 1),
+ "stream": kwargs.get("stream", False),
+ "stop": kwargs.get("stop", None),
+ "presence_penalty": kwargs.get("presence_penalty", 0),
+ "frequency_penalty": kwargs.get("frequency_penalty", 0),
+ "logit_bias": kwargs.get("logit_bias", None),
+ "logprobs": kwargs.get("logprobs", False),
+ "top_logprobs": kwargs.get("top_logprobs", None),
+ "parallel_tool_calls": kwargs.get("parallel_tool_calls", True),
+ "seed": kwargs.get("seed", None),
+ "response_format": kwargs.get("response_format", None),
+ "timeout": kwargs.get("timeout", None),
+ "api_base": kwargs.get("api_base", None),
+ "api_version": kwargs.get("api_version", None),
+ }
+
+
+def create_trace_args(
+ end_time: float,
+ inputs: Dict[str, Any],
+ output: Union[str, Dict[str, Any], None],
+ latency: float,
+ tokens: Optional[int],
+ prompt_tokens: Optional[int],
+ completion_tokens: Optional[int],
+ model: str,
+ model_parameters: Optional[Dict[str, Any]] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ raw_output: Optional[Union[str, Dict[str, Any]]] = None,
+ id: Optional[str] = None,
+ cost: Optional[float] = None,
+) -> Dict[str, Any]:
+ """Returns a dictionary with the trace arguments."""
+ trace_args = {
+ "end_time": end_time,
+ "inputs": inputs,
+ "output": output,
+ "latency": latency,
+ "tokens": tokens,
+ "prompt_tokens": prompt_tokens,
+ "completion_tokens": completion_tokens,
+ "model": model,
+ "model_parameters": model_parameters,
+ "raw_output": raw_output,
+ "metadata": metadata if metadata else {},
+ }
+ if id:
+ trace_args["id"] = id
+ if cost is not None:
+ trace_args["cost"] = cost
+ return trace_args
+
+
+def add_to_trace(**kwargs) -> None:
+ """Add a chat completion step to the trace."""
+ provider = kwargs.get("metadata", {}).get("provider", "Portkey")
+ tracer.add_chat_completion_step_to_trace(**kwargs, name="Portkey Chat Completion", provider=provider)
+
+
+def parse_non_streaming_output_data(response: Any) -> Union[str, Dict[str, Any], None]:
+ """Parses the output data from a non-streaming completion (OpenAI-compatible)."""
+ try:
+ if hasattr(response, "choices") and response.choices:
+ choice = response.choices[0]
+ message = getattr(choice, "message", None)
+ if message is None:
+ return None
+ content = getattr(message, "content", None)
+ function_call = getattr(message, "function_call", None)
+ tool_calls = getattr(message, "tool_calls", None)
+ if content:
+ return content.strip()
+ if function_call:
+ return {
+ "name": function_call.name,
+ "arguments": json.loads(function_call.arguments) if isinstance(function_call.arguments, str) else function_call.arguments,
+ }
+ if tool_calls:
+ return {
+ "name": tool_calls[0].function.name,
+ "arguments": json.loads(tool_calls[0].function.arguments) if isinstance(tool_calls[0].function.arguments, str) else tool_calls[0].function.arguments,
+ }
+ except Exception as e:
+ logger.debug("Error parsing Portkey output data: %s", e)
+ return None
+
+
+def extract_portkey_metadata(client: "Portkey") -> Dict[str, Any]:
+ """Extract Portkey-specific metadata from a Portkey client.
+
+ Attempts to read base URL and redact x-portkey-* headers if present.
+ Works defensively across SDK versions.
+ """
+ metadata: Dict[str, Any] = {"isPortkey": True}
+ # Base URL or host
+ for attr in ("base_url", "baseURL", "host", "custom_host"):
+ try:
+ val = getattr(client, attr, None)
+ if val:
+ metadata["portkeyBaseUrl"] = str(val)
+ break
+ except Exception:
+ continue
+
+ # Headers
+ possible_header_attrs = ("default_headers", "headers", "_default_headers", "_headers", "custom_headers", "allHeaders")
+ redacted: Dict[str, Any] = {}
+ for attr in possible_header_attrs:
+ try:
+ headers = getattr(client, attr, None)
+ if _is_dict_like(headers):
+ for k, v in headers.items():
+ if isinstance(k, str) and k.lower().startswith("x-portkey-"):
+ if k.lower() in {"x-portkey-api-key", "x-portkey-virtual-key"}:
+ redacted[k] = "***"
+ else:
+ redacted[k] = v
+ except Exception:
+ continue
+ if redacted:
+ metadata["portkeyHeaders"] = redacted
+ else:
+ logger.debug(
+ "Openlayer Portkey tracer: No x-portkey-* headers detected on client; provider/config metadata may be limited."
+ )
+ return metadata
+
+
+def extract_portkey_unit_metadata(unit: Any, model_name: str) -> Dict[str, Any]:
+ """Extract metadata from a response or chunk unit (headers, ids)."""
+ metadata: Dict[str, Any] = {}
+ try:
+ # Extract system fingerprint if available (OpenAI-compatible)
+ if hasattr(unit, "system_fingerprint"):
+ metadata["system_fingerprint"] = unit.system_fingerprint
+ if hasattr(unit, "service_tier"):
+ metadata["service_tier"] = unit.service_tier
+
+ # Response headers may be present on the object
+ headers_obj = None
+ if hasattr(unit, "_response_headers"):
+ headers_obj = getattr(unit, "_response_headers")
+ elif hasattr(unit, "response_headers"):
+ headers_obj = getattr(unit, "response_headers")
+ elif hasattr(unit, "_headers"):
+ headers_obj = getattr(unit, "_headers")
+
+ if _is_dict_like(headers_obj):
+ headers = {str(k): v for k, v in headers_obj.items()}
+ metadata["response_headers"] = headers
+ # Known Portkey header hints (names are lower-cased defensively)
+ lower = {k.lower(): v for k, v in headers.items()}
+ if "x-portkey-trace-id" in lower:
+ metadata["portkey_trace_id"] = lower["x-portkey-trace-id"]
+ if "x-portkey-cache-status" in lower:
+ metadata["portkey_cache_status"] = lower["x-portkey-cache-status"]
+ if "x-portkey-retry-attempt-count" in lower:
+ metadata["portkey_retry_attempt_count"] = lower["x-portkey-retry-attempt-count"]
+ if "x-portkey-last-used-option-index" in lower:
+ metadata["portkey_last_used_option_index"] = lower["x-portkey-last-used-option-index"]
+ except Exception:
+ pass
+ # Attach model for convenience
+ if model_name:
+ metadata["portkey_model"] = model_name
+ return metadata
+
+
+def extract_usage(obj: Any) -> Dict[str, Optional[int]]:
+ """Extract usage from a response or chunk object.
+
+ This function attempts to extract token usage information from various
+ locations where it might be stored, including:
+ - Direct `usage` attribute
+ - `model_dump()` dictionary (for streaming chunks)
+
+ Parameters
+ ----------
+ obj : Any
+ The response or chunk object to extract usage from.
+
+ Returns
+ -------
+ Dict[str, Optional[int]]
+ Dictionary with keys: total_tokens, prompt_tokens, completion_tokens.
+ Values are None if usage information is not found.
+ """
+ try:
+ # Check for direct usage attribute (works for both response and chunk)
+ if hasattr(obj, "usage") and obj.usage is not None:
+ usage = obj.usage
+ return {
+ "total_tokens": getattr(usage, "total_tokens", None),
+ "prompt_tokens": getattr(usage, "prompt_tokens", None),
+ "completion_tokens": getattr(usage, "completion_tokens", None),
+ }
+
+ # Check if object model dump has usage (primarily for streaming chunks)
+ if hasattr(obj, "model_dump"):
+ obj_dict = obj.model_dump()
+ if _supports_membership_check(obj_dict) and "usage" in obj_dict and obj_dict["usage"]:
+ usage = obj_dict["usage"]
+ return {
+ "total_tokens": usage.get("total_tokens", None),
+ "prompt_tokens": usage.get("prompt_tokens", None),
+ "completion_tokens": usage.get("completion_tokens", None),
+ }
+ except Exception:
+ pass
+ return {"total_tokens": None, "prompt_tokens": None, "completion_tokens": None}
+
+
+def calculate_streaming_usage_and_cost(
+ chunks: Any,
+ messages: Any,
+ output_content: Any,
+ model_name: str,
+ latest_usage_data: Dict[str, Optional[int]],
+ latest_chunk_metadata: Dict[str, Any],
+):
+ """Calculate usage and cost at the end of streaming."""
+ try:
+ # Priority 1: Actual usage provided in chunks
+ if latest_usage_data and latest_usage_data.get("total_tokens") and latest_usage_data.get("total_tokens") > 0:
+ return (
+ latest_usage_data.get("completion_tokens"),
+ latest_usage_data.get("prompt_tokens"),
+ latest_usage_data.get("total_tokens"),
+ latest_chunk_metadata.get("cost"),
+ )
+
+ # Priority 2: Look for usage embedded in final chunk dicts (if raw dicts)
+ if isinstance(chunks, list):
+ for chunk_data in reversed(chunks):
+ if _supports_membership_check(chunk_data) and "usage" in chunk_data and chunk_data["usage"]:
+ usage = chunk_data["usage"]
+ if usage.get("total_tokens", 0) > 0:
+ return (
+ usage.get("completion_tokens"),
+ usage.get("prompt_tokens"),
+ usage.get("total_tokens"),
+ latest_chunk_metadata.get("cost"),
+ )
+
+ # Priority 3: Estimate tokens
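+ # Rough heuristic: ~4 characters per token; an approximation, not a tokenizer-accurate count.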
+ completion_tokens = None
+ prompt_tokens = None
+ total_tokens = None
+ cost = None
+
+ # Estimate completion tokens
+ if isinstance(output_content, str):
+ completion_tokens = max(1, len(output_content) // 4)
+ elif _is_dict_like(output_content):
+ json_str = json.dumps(output_content) if output_content else "{}"
+ completion_tokens = max(1, len(json_str) // 4)
+ else:
+ # Fallback: count chunks present
+ try:
+ completion_tokens = len([c for c in chunks if c])
+ except Exception:
+ completion_tokens = None
+
+ # Estimate prompt tokens from messages
+ if messages:
+ total_chars = 0
+ try:
+ for message in messages:
+ if _supports_membership_check(message) and "content" in message:
+ total_chars += len(str(message["content"]))
+ except Exception:
+ total_chars = 0
+ prompt_tokens = max(1, total_chars // 4) if total_chars > 0 else 0
+ else:
+ prompt_tokens = 0
+
+ total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+
+ # Cost from metadata if present; otherwise simple heuristic for some models
+ cost = latest_chunk_metadata.get("cost")
+ if cost is None and total_tokens and model_name:
+ ml = model_name.lower()
+ if "gpt-3.5-turbo" in ml:
+ cost = (prompt_tokens * 0.0005 / 1000.0) + (completion_tokens * 0.0015 / 1000.0)
+
+ return completion_tokens, prompt_tokens, total_tokens, cost
+ except Exception as e:
+ logger.error("Error calculating streaming usage: %s", e)
+ return None, None, None, None
+
+
+def _extract_provider_from_object(obj: Any) -> Optional[str]:
+ """Extract provider from a response or chunk object.
+
+ Checks response_metadata for provider information.
+ Returns None if no provider is found.
+ """
+ try:
+ # Check response_metadata
+ if hasattr(obj, "response_metadata") and _is_dict_like(obj.response_metadata):
+ if "provider" in obj.response_metadata:
+ return obj.response_metadata["provider"]
+ except Exception:
+ pass
+ return None
+
+
+def detect_provider(obj: Any, client: "Portkey", model_name: str) -> str:
+ """Detect provider from a response or chunk object.
+
+ Parameters
+ ----------
+ obj : Any
+ The response or chunk object to extract provider information from.
+ client : Portkey
+ The Portkey client instance.
+ model_name : str
+ The model name to use as a fallback for provider detection.
+
+ Returns
+ -------
+ str
+ The detected provider name.
+ """
+ # First: check Portkey headers on the client (authoritative)
+ provider = _provider_from_portkey_headers(client)
+ if provider:
+ return provider
+ # Next: check object metadata if any
+ provider = _extract_provider_from_object(obj)
+ if provider:
+ return provider
+ # Fallback to model name heuristics
+ return detect_provider_from_model_name(model_name)
+
+
+def detect_provider_from_model_name(model_name: str) -> str:
+ """Detect provider from model name patterns."""
+ model_lower = (model_name or "").lower()
+ if model_lower.startswith(("gpt-", "o1-", "text-davinci", "text-curie", "text-babbage", "text-ada")):
+ return "OpenAI"
+ if model_lower.startswith(("claude-", "claude")):
+ return "Anthropic"
+ if "gemini" in model_lower or "palm" in model_lower:
+ return "Google"
+ if "llama" in model_lower or "meta-" in model_lower:
+ return "Meta"
+ if model_lower.startswith("mistral") or "mixtral" in model_lower:
+ return "Mistral"
+ if model_lower.startswith("command"):
+ return "Cohere"
+ return "Portkey"
+
+
+def get_delta_from_chunk(chunk: Any) -> Any:
+ """Extract delta from chunk, handling different response formats."""
+ try:
+ if hasattr(chunk, "choices") and chunk.choices:
+ choice = chunk.choices[0]
+ if hasattr(choice, "delta"):
+ return choice.delta
+ except Exception:
+ pass
+ return None
+
+
+def _provider_from_portkey_headers(client: "Portkey") -> Optional[str]:
+ """Get provider from Portkey headers on the client."""
+ header_sources = ("default_headers", "headers", "_default_headers", "_headers")
+ for attr in header_sources:
+ try:
+ headers = getattr(client, attr, None)
+ if _is_dict_like(headers):
+ for k, v in headers.items():
+ if isinstance(k, str) and k.lower() == "x-portkey-provider" and v:
+ return str(v)
+ except Exception:
+ continue
+ return None
+
+
+def _is_dict_like(obj: Any) -> bool:
+ """Check if an object is dict-like (has .items() method).
+
+ This is more robust than isinstance(obj, dict) as it handles
+ custom dict-like objects (e.g., CaseInsensitiveDict, custom headers).
+ """
+ return hasattr(obj, "items") and callable(getattr(obj, "items", None))
+
+
+def _supports_membership_check(obj: Any) -> bool:
+ """Check if an object supports membership testing (e.g., 'key in obj').
+
+ This checks for __contains__ method or if it's dict-like.
+ """
+ return hasattr(obj, "__contains__") or _is_dict_like(obj)
diff --git a/tests/test_integration_conditional_imports.py b/tests/test_integration_conditional_imports.py
index f673b480..adb5bf9d 100644
--- a/tests/test_integration_conditional_imports.py
+++ b/tests/test_integration_conditional_imports.py
@@ -34,6 +34,7 @@
"oci_tracer": ["oci"],
"langchain_callback": ["langchain", "langchain_core", "langchain_community"],
"litellm_tracer": ["litellm"],
+ "portkey_tracer": ["portkey_ai"],
}
# Expected patterns for integration modules
diff --git a/tests/test_portkey_integration.py b/tests/test_portkey_integration.py
new file mode 100644
index 00000000..97e2d50e
--- /dev/null
+++ b/tests/test_portkey_integration.py
@@ -0,0 +1,576 @@
+"""Test Portkey tracer integration."""
+
+import json
+from types import SimpleNamespace
+from typing import Any, Dict
+from unittest.mock import Mock, patch
+
+import pytest # type: ignore
+
+
+class TestPortkeyIntegration:
+ """Test Portkey integration functionality."""
+
+ def test_import_without_portkey(self) -> None:
+ """Module should import even when Portkey is unavailable."""
+ from openlayer.lib.integrations import portkey_tracer # noqa: F401
+
+ assert hasattr(portkey_tracer, "HAVE_PORTKEY")
+
+ def test_trace_portkey_raises_import_error_without_dependency(self) -> None:
+ """trace_portkey should raise ImportError when Portkey is missing."""
+ with patch("openlayer.lib.integrations.portkey_tracer.HAVE_PORTKEY", False):
+ from openlayer.lib.integrations.portkey_tracer import trace_portkey
+
+ with pytest.raises(ImportError) as exc_info: # type: ignore
+ trace_portkey()
+
+ message = str(exc_info.value) # type: ignore[attr-defined]
+ assert "Portkey library is not installed" in message
+ assert "pip install portkey-ai" in message
+
+ def test_trace_portkey_patches_portkey_client(self) -> None:
+ """trace_portkey should wrap Portkey chat completions for tracing."""
+
+ class DummyPortkey: # pylint: disable=too-few-public-methods
+ """Lightweight Portkey stand-in used for patching behavior."""
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: D401
+ completions = SimpleNamespace(create=Mock(name="original_create"))
+ self.chat = SimpleNamespace(completions=completions)
+ self._init_args = (args, kwargs)
+ self.original_create = completions.create
+
+ with patch("openlayer.lib.integrations.portkey_tracer.HAVE_PORTKEY", True), patch(
+ "openlayer.lib.integrations.portkey_tracer.Portkey", DummyPortkey, create=True
+ ), patch(
+ "openlayer.lib.integrations.portkey_tracer.handle_non_streaming_create",
+ autospec=True,
+ ) as mock_non_streaming, patch(
+ "openlayer.lib.integrations.portkey_tracer.handle_streaming_create",
+ autospec=True,
+ ) as mock_streaming:
+ mock_non_streaming.return_value = "non-stream-result"
+ mock_streaming.return_value = "stream-result"
+
+ from openlayer.lib.integrations.portkey_tracer import trace_portkey
+
+ trace_portkey()
+
+ client = DummyPortkey()
+ # Non-streaming
+ result_non_stream = client.chat.completions.create(messages=[{"role": "user", "content": "hi"}])
+ assert result_non_stream == "non-stream-result"
+ assert mock_non_streaming.call_count == 1
+ non_stream_kwargs = mock_non_streaming.call_args.kwargs
+ assert non_stream_kwargs["create_func"] is client.original_create
+ assert non_stream_kwargs["inference_id"] is None
+
+ # Streaming path
+ result_stream = client.chat.completions.create(
+ messages=[{"role": "user", "content": "hi"}], stream=True, inference_id="inference-123"
+ )
+ assert result_stream == "stream-result"
+ assert mock_streaming.call_count == 1
+ stream_kwargs = mock_streaming.call_args.kwargs
+ assert stream_kwargs["create_func"] is client.original_create
+ assert stream_kwargs["inference_id"] == "inference-123"
+
+ def test_detect_provider_from_model_name(self) -> None:
+ """Provider detection should match model naming heuristics."""
+ from openlayer.lib.integrations.portkey_tracer import detect_provider_from_model_name
+
+ test_cases = [
+ ("gpt-4", "OpenAI"),
+ ("Gpt-3.5-turbo", "OpenAI"),
+ ("claude-3-opus", "Anthropic"),
+ ("gemini-pro", "Google"),
+ ("meta-llama-3-70b", "Meta"),
+ ("mixtral-8x7b", "Mistral"),
+ ("command-r", "Cohere"),
+ ("unknown-model", "Portkey"),
+ ]
+
+ for model_name, expected in test_cases:
+ assert detect_provider_from_model_name(model_name) == expected
+
+ def test_get_model_parameters(self) -> None:
+ """Ensure OpenAI-compatible kwargs are extracted."""
+ from openlayer.lib.integrations.portkey_tracer import get_model_parameters
+
+ kwargs = {
+ "temperature": 0.5,
+ "top_p": 0.7,
+ "max_tokens": 256,
+ "n": 3,
+ "stream": True,
+ "stop": ["END"],
+ "presence_penalty": 0.1,
+ "frequency_penalty": -0.1,
+ "logit_bias": {"1": -1},
+ "logprobs": True,
+ "top_logprobs": 5,
+ "parallel_tool_calls": False,
+ "seed": 123,
+ "response_format": {"type": "json_object"},
+ "timeout": 42,
+ "api_base": "https://api.example.com",
+ "api_version": "2024-05-01",
+ }
+
+ params = get_model_parameters(kwargs)
+
+ expected = kwargs.copy()
+ assert params == expected
+
+ def test_extract_portkey_metadata(self) -> None:
+ """Portkey metadata should redact sensitive headers and include base URL."""
+ from openlayer.lib.integrations.portkey_tracer import extract_portkey_metadata
+
+ client = SimpleNamespace(
+ base_url="https://gateway.portkey.ai",
+ headers={
+ "X-Portkey-Api-Key": "secret",
+ "X-Portkey-Provider": "openai",
+ "Authorization": "Bearer ignored",
+ },
+ )
+
+ metadata = extract_portkey_metadata(client)
+
+ assert metadata["isPortkey"] is True
+ assert metadata["portkeyBaseUrl"] == "https://gateway.portkey.ai"
+ assert metadata["portkeyHeaders"]["X-Portkey-Api-Key"] == "***"
+ assert metadata["portkeyHeaders"]["X-Portkey-Provider"] == "openai"
+ assert "Authorization" not in metadata["portkeyHeaders"]
+
+ def test_extract_portkey_unit_metadata(self) -> None:
+ """Unit metadata should capture headers and retry/option index hints."""
+ from openlayer.lib.integrations.portkey_tracer import extract_portkey_unit_metadata
+
+ unit = SimpleNamespace(
+ system_fingerprint="fingerprint-123",
+ _response_headers={
+ "x-portkey-trace-id": "trace-1",
+ "x-portkey-cache-status": "HIT",
+ "x-portkey-retry-attempt-count": "2",
+ "x-portkey-last-used-option-index": "config.targets[1]",
+ "content-type": "application/json",
+ },
+ )
+
+ metadata = extract_portkey_unit_metadata(unit, "claude-3-opus")
+
+ assert metadata["system_fingerprint"] == "fingerprint-123"
+ assert metadata["portkey_trace_id"] == "trace-1"
+ assert metadata["portkey_cache_status"] == "HIT"
+ assert metadata["portkey_retry_attempt_count"] == "2"
+ assert metadata["portkey_last_used_option_index"] == "config.targets[1]"
+ assert metadata["portkey_model"] == "claude-3-opus"
+ assert metadata["response_headers"]["content-type"] == "application/json"
+
+ def test_extract_portkey_unit_metadata_with_dict_like_headers(self) -> None:
+ """Unit metadata should work with dict-like objects (not just dicts)."""
+ from openlayer.lib.integrations.portkey_tracer import extract_portkey_unit_metadata
+
+ # Create a dict-like object (has .items() but not isinstance(dict))
+ class DictLikeHeaders:
+ def __init__(self):
+ self._data = {
+ "x-portkey-trace-id": "trace-2",
+ "x-portkey-cache-status": "MISS",
+ "x-portkey-retry-attempt-count": "3",
+ "x-portkey-last-used-option-index": "config.targets[1]",
+ }
+
+ def items(self):
+ return self._data.items()
+
+ unit = SimpleNamespace(
+ _response_headers=DictLikeHeaders(),
+ )
+
+ metadata = extract_portkey_unit_metadata(unit, "gpt-4")
+
+ assert metadata["portkey_trace_id"] == "trace-2"
+ assert metadata["portkey_cache_status"] == "MISS"
+ assert metadata["portkey_retry_attempt_count"] == "3"
+ assert metadata["portkey_last_used_option_index"] == "config.targets[1]"
+
+ def test_extract_usage_from_response(self) -> None:
+ """Usage extraction should read OpenAI-style usage objects."""
+ from openlayer.lib.integrations.portkey_tracer import extract_usage
+
+ usage = SimpleNamespace(total_tokens=50, prompt_tokens=20, completion_tokens=30)
+ response = SimpleNamespace(usage=usage)
+
+ assert extract_usage(response) == {
+ "total_tokens": 50,
+ "prompt_tokens": 20,
+ "completion_tokens": 30,
+ }
+
+ response_no_usage = SimpleNamespace()
+ assert extract_usage(response_no_usage) == {
+ "total_tokens": None,
+ "prompt_tokens": None,
+ "completion_tokens": None,
+ }
+
+ def test_extract_usage_from_chunk(self) -> None:
+ """Usage data should be derived from multiple potential chunk attributes."""
+ from openlayer.lib.integrations.portkey_tracer import extract_usage
+
+ chunk_direct = SimpleNamespace(
+ usage=SimpleNamespace(total_tokens=120, prompt_tokens=40, completion_tokens=80)
+ )
+ assert extract_usage(chunk_direct) == {
+ "total_tokens": 120,
+ "prompt_tokens": 40,
+ "completion_tokens": 80,
+ }
+
+ class ChunkWithModelDump: # pylint: disable=too-few-public-methods
+ def model_dump(self) -> Dict[str, Any]:
+ return {"usage": {"total_tokens": 12, "prompt_tokens": 5, "completion_tokens": 7}}
+
+ assert extract_usage(ChunkWithModelDump()) == {
+ "total_tokens": 12,
+ "prompt_tokens": 5,
+ "completion_tokens": 7,
+ }
+
+ def test_calculate_streaming_usage_and_cost_with_actual_usage(self) -> None:
+ """Actual usage data should be returned when available."""
+ from openlayer.lib.integrations.portkey_tracer import calculate_streaming_usage_and_cost
+
+ latest_usage = {"total_tokens": 100, "prompt_tokens": 40, "completion_tokens": 60}
+ latest_metadata = {"cost": 0.99}
+
+ result = calculate_streaming_usage_and_cost(
+ chunks=[],
+ messages=[],
+ output_content="",
+ model_name="gpt-4",
+ latest_usage_data=latest_usage,
+ latest_chunk_metadata=latest_metadata,
+ )
+
+ assert result == (60, 40, 100, 0.99)
+
+ def test_calculate_streaming_usage_and_cost_fallback_estimation(self) -> None:
+ """Fallback estimation should approximate tokens and cost when usage is missing."""
+ from openlayer.lib.integrations.portkey_tracer import calculate_streaming_usage_and_cost
+
+ output_content = "Generated answer text."
+ messages = [
+ {"role": "system", "content": "You are helpful."},
+ {"role": "user", "content": "Tell me something interesting."},
+ ]
+
+ completion_tokens, prompt_tokens, total_tokens, cost = calculate_streaming_usage_and_cost(
+ chunks=[{"usage": None}],
+ messages=messages,
+ output_content=output_content,
+ model_name="gpt-3.5-turbo",
+ latest_usage_data={"total_tokens": None, "prompt_tokens": None, "completion_tokens": None},
+ latest_chunk_metadata={},
+ )
+
+ assert completion_tokens >= 1
+ assert prompt_tokens >= 1
+ assert total_tokens == (completion_tokens or 0) + (prompt_tokens or 0)
+ assert cost is not None
+ assert cost >= 0
+
+ def test_detect_provider_from_response_prefers_headers(self) -> None:
+ """Provider detection should prioritize Portkey headers."""
+ from openlayer.lib.integrations.portkey_tracer import detect_provider
+
+ client = SimpleNamespace()
+ response = SimpleNamespace()
+
+ with patch(
+ "openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value="header-provider"
+ ):
+ assert detect_provider(response, client, "gpt-4") == "header-provider"
+
+ def test_detect_provider_from_chunk_prefers_headers(self) -> None:
+ """Provider detection from chunk should prioritize header-derived values."""
+ from openlayer.lib.integrations.portkey_tracer import detect_provider
+
+ client = SimpleNamespace()
+ chunk = SimpleNamespace()
+
+ with patch(
+ "openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value="header-provider"
+ ):
+ assert detect_provider(chunk, client, "gpt-4") == "header-provider"
+
+ def test_detect_provider_from_response_fallback(self) -> None:
+ """Provider detection should fall back to response metadata or model name."""
+ from openlayer.lib.integrations.portkey_tracer import detect_provider
+
+ client = SimpleNamespace()
+ response = SimpleNamespace(
+ response_metadata={"provider": "anthropic"},
+ )
+
+ with patch(
+ "openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value=None
+ ):
+ assert detect_provider(response, client, "mistral-7b") == "anthropic"
+
+ def test_detect_provider_from_chunk_fallback(self) -> None:
+ """Chunk provider detection should fall back gracefully."""
+ from openlayer.lib.integrations.portkey_tracer import detect_provider
+
+ chunk = SimpleNamespace(
+ response_metadata={"provider": "cohere"},
+ )
+ client = SimpleNamespace()
+
+ with patch(
+ "openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value=None
+ ):
+ assert detect_provider(chunk, client, "command-r") == "cohere"
+
+ def test_provider_from_portkey_headers(self) -> None:
+ """Header helper should identify provider values on the client."""
+ from openlayer.lib.integrations.portkey_tracer import _provider_from_portkey_headers
+
+ client = SimpleNamespace(
+ default_headers={"X-Portkey-Provider": "openai"},
+ headers={"X-Portkey-Provider": "anthropic"},
+ )
+
+ assert _provider_from_portkey_headers(client) == "openai"
+
+ def test_parse_non_streaming_output_data(self) -> None:
+ """Output parsing should support content, function calls, and tool calls."""
+ from openlayer.lib.integrations.portkey_tracer import parse_non_streaming_output_data
+
+ # Content message
+ response_content = SimpleNamespace(
+ choices=[SimpleNamespace(message=SimpleNamespace(content="Hello!", function_call=None, tool_calls=None))]
+ )
+ assert parse_non_streaming_output_data(response_content) == "Hello!"
+
+ # Function call
+ response_function = SimpleNamespace(
+ choices=[
+ SimpleNamespace(
+ message=SimpleNamespace(
+ content=None,
+ function_call=SimpleNamespace(name="do_something", arguments=json.dumps({"value": 1})),
+ tool_calls=None,
+ )
+ )
+ ]
+ )
+ assert parse_non_streaming_output_data(response_function) == {"name": "do_something", "arguments": {"value": 1}}
+
+ # Tool call
+ response_tool = SimpleNamespace(
+ choices=[
+ SimpleNamespace(
+ message=SimpleNamespace(
+ content=None,
+ function_call=None,
+ tool_calls=[
+ SimpleNamespace(
+ function=SimpleNamespace(name="call_tool", arguments=json.dumps({"value": 2}))
+ )
+ ],
+ )
+ )
+ ]
+ )
+ assert parse_non_streaming_output_data(response_tool) == {"name": "call_tool", "arguments": {"value": 2}}
+
+ def test_create_trace_args(self) -> None:
+ """Trace argument helper should include optional id and cost."""
+ from openlayer.lib.integrations.portkey_tracer import create_trace_args
+
+ args = create_trace_args(
+ end_time=1.0,
+ inputs={"prompt": []},
+ output="response",
+ latency=123.0,
+ tokens=10,
+ prompt_tokens=4,
+ completion_tokens=6,
+ model="gpt-4",
+ id="trace-id",
+ cost=0.42,
+ )
+
+ assert args["id"] == "trace-id"
+ assert args["cost"] == 0.42
+ assert args["metadata"] == {}
+
+ def test_add_to_trace_uses_provider_metadata(self) -> None:
+ """add_to_trace should pass provider metadata through to tracer."""
+ from openlayer.lib.integrations.portkey_tracer import add_to_trace
+
+ with patch(
+ "openlayer.lib.integrations.portkey_tracer.tracer.add_chat_completion_step_to_trace"
+ ) as mock_add:
+ add_to_trace(
+ end_time=1.0,
+ inputs={},
+ output=None,
+ latency=10.0,
+ tokens=None,
+ prompt_tokens=None,
+ completion_tokens=None,
+ model="model",
+ metadata={},
+ )
+
+ _, kwargs = mock_add.call_args
+ assert kwargs["provider"] == "Portkey"
+ assert kwargs["name"] == "Portkey Chat Completion"
+
+ add_to_trace(
+ end_time=2.0,
+ inputs={},
+ output=None,
+ latency=5.0,
+ tokens=None,
+ prompt_tokens=None,
+ completion_tokens=None,
+ model="model",
+ metadata={"provider": "OpenAI"},
+ )
+
+ assert mock_add.call_count == 2
+ assert mock_add.call_args.kwargs["provider"] == "OpenAI"
+
+ def test_handle_streaming_create_delegates_to_stream_chunks(self) -> None:
+ """handle_streaming_create should call the original create and stream_chunks."""
+ from openlayer.lib.integrations.portkey_tracer import handle_streaming_create
+
+ client = SimpleNamespace()
+ create_func = Mock(return_value=iter(["chunk"]))
+
+ with patch(
+ "openlayer.lib.integrations.portkey_tracer.stream_chunks", return_value=iter(["chunk"])
+ ) as mock_stream_chunks:
+ result_iterator = handle_streaming_create(
+ client,
+ "arg-1",
+ create_func=create_func,
+ inference_id="stream-id",
+ foo="bar",
+ )
+
+ assert list(result_iterator) == ["chunk"]
+ create_func.assert_called_once_with("arg-1", foo="bar")
+ mock_stream_chunks.assert_called_once()
+ stream_kwargs = mock_stream_chunks.call_args.kwargs
+ assert stream_kwargs["client"] is client
+ assert stream_kwargs["inference_id"] == "stream-id"
+ assert stream_kwargs["kwargs"] == {"foo": "bar"}
+ assert stream_kwargs["chunks"] is create_func.return_value
+
+ def test_stream_chunks_traces_completion(self) -> None:
+ """stream_chunks should yield all chunks and record a traced step."""
+ from openlayer.lib.integrations.portkey_tracer import stream_chunks
+
+ chunk_a = object()
+ chunk_b = object()
+ chunks = [chunk_a, chunk_b]
+ kwargs = {"messages": [{"role": "user", "content": "hello"}]}
+ client = SimpleNamespace()
+
+ with patch(
+ "openlayer.lib.integrations.portkey_tracer.add_to_trace", autospec=True
+ ) as mock_add_to_trace, patch(
+ "openlayer.lib.integrations.portkey_tracer.extract_usage", autospec=True
+ ) as mock_usage, patch(
+ "openlayer.lib.integrations.portkey_tracer.extract_portkey_unit_metadata", autospec=True
+ ) as mock_unit_metadata, patch(
+ "openlayer.lib.integrations.portkey_tracer.detect_provider", autospec=True
+ ) as mock_detect_provider, patch(
+ "openlayer.lib.integrations.portkey_tracer.get_delta_from_chunk", autospec=True
+ ) as mock_delta, patch(
+ "openlayer.lib.integrations.portkey_tracer.calculate_streaming_usage_and_cost", autospec=True
+ ) as mock_calc, patch(
+ "openlayer.lib.integrations.portkey_tracer.extract_portkey_metadata", autospec=True
+ ) as mock_client_metadata, patch(
+ "openlayer.lib.integrations.portkey_tracer.time.time", side_effect=[100.0, 100.05, 100.2]
+ ):
+ mock_usage.side_effect = [
+ {"total_tokens": None, "prompt_tokens": None, "completion_tokens": None},
+ {"total_tokens": 10, "prompt_tokens": 4, "completion_tokens": 6},
+ ]
+ mock_unit_metadata.side_effect = [{}, {"cost": 0.1}]
+ mock_detect_provider.side_effect = ["OpenAI", "OpenAI"]
+ mock_delta.side_effect = [
+ SimpleNamespace(content="Hello ", function_call=None, tool_calls=None),
+ SimpleNamespace(content="World", function_call=None, tool_calls=None),
+ ]
+ mock_calc.return_value = (6, 4, 10, 0.1)
+ mock_client_metadata.return_value = {"portkeyBaseUrl": "https://gateway"}
+
+ yielded = list(
+ stream_chunks(
+ chunks=iter(chunks),
+ kwargs=kwargs,
+ client=client,
+ inference_id="trace-123",
+ )
+ )
+
+ assert yielded == chunks
+ mock_add_to_trace.assert_called_once()
+ trace_kwargs = mock_add_to_trace.call_args.kwargs
+ assert trace_kwargs["metadata"]["provider"] == "OpenAI"
+ assert trace_kwargs["metadata"]["portkeyBaseUrl"] == "https://gateway"
+ assert trace_kwargs["id"] == "trace-123"
+ assert trace_kwargs["tokens"] == 10
+ assert trace_kwargs["latency"] == pytest.approx(200.0)
+
+ def test_handle_non_streaming_create_traces_completion(self) -> None:
+ """handle_non_streaming_create should record a traced step for completions."""
+ from openlayer.lib.integrations.portkey_tracer import handle_non_streaming_create
+
+ response = SimpleNamespace(model="gpt-4", system_fingerprint="fp-1")
+ client = SimpleNamespace()
+ create_func = Mock(return_value=response)
+
+ with patch(
+ "openlayer.lib.integrations.portkey_tracer.parse_non_streaming_output_data", return_value="output"
+ ), patch(
+ "openlayer.lib.integrations.portkey_tracer.extract_usage",
+ return_value={"total_tokens": 10, "prompt_tokens": 4, "completion_tokens": 6},
+ ), patch(
+ "openlayer.lib.integrations.portkey_tracer.detect_provider", return_value="OpenAI"
+ ), patch(
+ "openlayer.lib.integrations.portkey_tracer.extract_portkey_unit_metadata",
+ return_value={"cost": 0.25},
+ ), patch(
+ "openlayer.lib.integrations.portkey_tracer.extract_portkey_metadata",
+ return_value={"portkeyHeaders": {"X-Portkey-Provider": "openai"}},
+ ), patch(
+ "openlayer.lib.integrations.portkey_tracer.add_to_trace"
+ ) as mock_add_to_trace, patch(
+ "openlayer.lib.integrations.portkey_tracer.time.time", side_effect=[10.0, 10.2]
+ ):
+ result = handle_non_streaming_create(
+ client,
+ create_func=create_func,
+ inference_id="trace-xyz",
+ messages=[{"role": "user", "content": "Hello"}],
+ )
+
+ assert result is response
+ mock_add_to_trace.assert_called_once()
+ trace_kwargs = mock_add_to_trace.call_args.kwargs
+ assert trace_kwargs["id"] == "trace-xyz"
+ assert trace_kwargs["metadata"]["provider"] == "OpenAI"
+ assert trace_kwargs["metadata"]["cost"] == 0.25
+ assert trace_kwargs["metadata"]["portkeyHeaders"]["X-Portkey-Provider"] == "openai"
+
+
From ba9c163685bfd2263e9213a738d49dad8e085b50 Mon Sep 17 00:00:00 2001
From: Siddhant Shah
Date: Fri, 20 Feb 2026 11:49:10 +0530
Subject: [PATCH 2/2] fix(test_portkey_integration.py): lint fixes
---
tests/test_portkey_integration.py | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
diff --git a/tests/test_portkey_integration.py b/tests/test_portkey_integration.py
index 97e2d50e..b0ff1d32 100644
--- a/tests/test_portkey_integration.py
+++ b/tests/test_portkey_integration.py
@@ -1,8 +1,12 @@
"""Test Portkey tracer integration."""
+# openlayer.lib.integrations is in pyright's ignore list, so imports from portkey_tracer
+# get unknown/partially unknown types; disable these diagnostics for this test file only.
+# pyright: reportUnknownVariableType=false, reportUnknownMemberType=false
+
import json
from types import SimpleNamespace
-from typing import Any, Dict
+from typing import Any, Dict, Optional
from unittest.mock import Mock, patch
import pytest # type: ignore
@@ -244,7 +248,11 @@ def test_calculate_streaming_usage_and_cost_with_actual_usage(self) -> None:
"""Actual usage data should be returned when available."""
from openlayer.lib.integrations.portkey_tracer import calculate_streaming_usage_and_cost
- latest_usage = {"total_tokens": 100, "prompt_tokens": 40, "completion_tokens": 60}
+ latest_usage: Dict[str, Optional[int]] = {
+ "total_tokens": 100,
+ "prompt_tokens": 40,
+ "completion_tokens": 60,
+ }
latest_metadata = {"cost": 0.99}
result = calculate_streaming_usage_and_cost(