diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..15100c8 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,80 @@ +name: release + +on: + push: + tags: ["v*"] + workflow_dispatch: + inputs: + publish: + description: "Publish to PyPI" + required: true + default: "false" + type: choice + options: ["false", "true"] + +jobs: + quality: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -e ".[dev]" + + - name: Run tests + run: python -m pytest -q + + - name: Run security checks + run: python -m bandit -q -r src/predicate_secure/ + + - name: Run pre-commit checks + run: | + python -m pip install pre-commit + pre-commit run --all-files + + publish: + runs-on: ubuntu-latest + needs: [quality] + if: (github.event_name == 'workflow_dispatch' && inputs.publish == 'true') || startsWith(github.ref, 'refs/tags/v') + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install build tooling + run: python -m pip install --upgrade pip build twine + + - name: Validate version matches tag + if: startsWith(github.ref, 'refs/tags/v') + run: | + TAG_VERSION="${GITHUB_REF_NAME#v}" + PKG_VERSION=$(python -c "import tomllib; print(tomllib.load(open('pyproject.toml', 'rb'))['project']['version'])") + if [ "$TAG_VERSION" != "$PKG_VERSION" ]; then + echo "Tag version ($TAG_VERSION) does not match package version ($PKG_VERSION)" + exit 1 + fi + echo "Version validated: $PKG_VERSION" + + - name: Build package + run: python -m build + + - name: Validate distribution metadata + run: twine check dist/* + + - name: Publish to PyPI + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ 
secrets.PYPI_TOKEN_PREDICATE_SECURE }} + run: twine upload dist/* diff --git a/README.md b/README.md index 019e01d..22f18eb 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ # predicate-secure +[![License](https://img.shields.io/badge/License-MIT%2FApache--2.0-blue.svg)](LICENSE) +[![PyPI - predicate-secure](https://img.shields.io/pypi/v/predicate-secure.svg)](https://pypi.org/project/predicate-secure/) + Drop-in security wrapper for AI agents. Adds authorization, verification, and audit to any agent framework in 3 lines of code. ## Features @@ -67,8 +70,8 @@ secure_agent.run() `predicate-secure` is a thin orchestration layer that combines: -- **predicate** (sdk-python) - Snapshot engine, DOM pruning, verification predicates -- **predicate-authority** (AgentIdentity) - Policy engine, mandate signing, audit logging +- **[predicate-runtime](https://github.com/PredicateSystems/sdk-python)** - Snapshot engine, DOM pruning, verification predicates +- **[predicate-authority](https://github.com/PredicateSystems/predicate-authority)** - Policy engine, mandate signing, audit logging ``` SecureAgent diff --git a/src/predicate_secure/__init__.py b/src/predicate_secure/__init__.py index 086118d..be7bbfa 100644 --- a/src/predicate_secure/__init__.py +++ b/src/predicate_secure/__init__.py @@ -36,6 +36,15 @@ ) from .config import SecureAgentConfig, WrappedAgent from .detection import DetectionResult, Framework, FrameworkDetector, UnsupportedFrameworkError +from .tracing import ( + DebugTracer, + PolicyDecision, + SnapshotDiff, + TraceEvent, + TraceFormat, + VerificationResult, + create_debug_tracer, +) __version__ = "0.1.0" @@ -57,6 +66,14 @@ "create_playwright_adapter", "create_langchain_adapter", "create_pydantic_ai_adapter", + # Tracing + "DebugTracer", + "TraceEvent", + "TraceFormat", + "PolicyDecision", + "SnapshotDiff", + "VerificationResult", + "create_debug_tracer", # Modes "MODE_STRICT", "MODE_PERMISSIVE", @@ -133,6 +150,10 @@ def __init__( sidecar_url: str | None = 
None, signing_key: str | None = None, mandate_ttl_seconds: int = 300, + trace_format: str = "console", + trace_file: str | Path | None = None, + trace_colors: bool = True, + trace_verbose: bool = True, ): """ Initialize SecureAgent wrapper. @@ -147,6 +168,10 @@ def __init__( sidecar_url: Sidecar URL (None for embedded mode) signing_key: Secret key for mandate signing mandate_ttl_seconds: TTL for issued mandates + trace_format: Format for debug trace output ("console" or "json") + trace_file: Path to trace output file (None for stderr) + trace_colors: Whether to use ANSI colors in console output + trace_verbose: Whether to output verbose trace information """ # Build config from kwargs self._config = SecureAgentConfig.from_kwargs( @@ -158,6 +183,10 @@ def __init__( sidecar_url=sidecar_url, signing_key=signing_key, mandate_ttl_seconds=mandate_ttl_seconds, + trace_format=trace_format, + trace_file=trace_file, + trace_colors=trace_colors, + trace_verbose=trace_verbose, ) # Detect framework and wrap agent @@ -166,6 +195,16 @@ def __init__( # Lazy-initialized authority context self._authority_context: Any = None + # Debug tracer (initialized when mode="debug") + self._tracer: DebugTracer | None = None + if self._config.is_debug_mode: + self._tracer = create_debug_tracer( + format=self._config.trace_format, + file_path=self._config.effective_trace_file, + use_colors=self._config.trace_colors, + verbose=self._config.trace_verbose, + ) + # Legacy attribute access (for backward compat with tests) self._agent = agent self._policy = policy @@ -188,6 +227,11 @@ def framework(self) -> Framework: """Get the detected framework.""" return Framework(self._wrapped.framework) + @property + def tracer(self) -> DebugTracer | None: + """Get the debug tracer (available when mode='debug').""" + return self._tracer + def _wrap_agent(self, agent: Any) -> WrappedAgent: """ Detect framework and wrap agent. 
@@ -285,10 +329,38 @@ def _create_pre_action_authorizer(self) -> Any: def authorizer(request: Any) -> Any: """Pre-action authorization callback.""" + # Trace authorization request + if self._tracer: + action = getattr(request, "action", str(request)) + resource = getattr(request, "resource", "") + self._tracer.trace_authorization_request( + action=action, + resource=resource, + principal=self._config.effective_principal_id, + ) + decision = context.client.authorize(request) - if self._config.mode == "debug": - print(f"[predicate-secure] authorize({request.action}): {decision}") + # Trace policy decision + if self._tracer: + action = getattr(request, "action", str(request)) + resource = getattr(request, "resource", "") + reason = None + if hasattr(decision, "reason") and decision.reason: + reason = ( + decision.reason.value + if hasattr(decision.reason, "value") + else str(decision.reason) + ) + self._tracer.trace_policy_decision( + PolicyDecision( + action=action, + resource=resource, + allowed=decision.allowed, + reason=reason, + principal=self._config.effective_principal_id, + ) + ) if not decision.allowed and self._config.fail_closed: raise AuthorizationDenied( @@ -319,20 +391,41 @@ def run(self, task: str | None = None) -> Any: detection = FrameworkDetector.detect(self._wrapped.original) raise UnsupportedFrameworkError(detection) - # Framework-specific execution - if self._wrapped.framework == Framework.BROWSER_USE.value: - return self._run_browser_use(task) + # Trace session start + if self._tracer: + self._tracer.trace_session_start( + framework=self._wrapped.framework, + mode=self._config.mode, + policy=self._config.effective_policy_path, + principal_id=self._config.effective_principal_id, + ) - if self._wrapped.framework == Framework.PLAYWRIGHT.value: - return self._run_playwright(task) + try: + # Framework-specific execution + if self._wrapped.framework == Framework.BROWSER_USE.value: + result = self._run_browser_use(task) + elif self._wrapped.framework == 
Framework.PLAYWRIGHT.value: + result = self._run_playwright(task) + elif self._wrapped.framework == Framework.LANGCHAIN.value: + result = self._run_langchain(task) + elif self._wrapped.framework == Framework.PYDANTIC_AI.value: + result = self._run_pydantic_ai(task) + else: + raise NotImplementedError( + f"run() not implemented for framework: {self._wrapped.framework}" + ) - if self._wrapped.framework == Framework.LANGCHAIN.value: - return self._run_langchain(task) + # Trace session end (success) + if self._tracer: + self._tracer.trace_session_end(success=True) - if self._wrapped.framework == Framework.PYDANTIC_AI.value: - return self._run_pydantic_ai(task) + return result - raise NotImplementedError(f"run() not implemented for framework: {self._wrapped.framework}") + except Exception as e: + # Trace session end (failure) + if self._tracer: + self._tracer.trace_session_end(success=False, error=str(e)) + raise def _run_browser_use(self, task: str | None) -> Any: """Run browser-use agent with authorization.""" @@ -388,6 +481,123 @@ def _run_pydantic_ai(self, task: str | None) -> Any: """Run PydanticAI agent with authorization.""" raise NotImplementedError("PydanticAI integration not yet implemented.") + def trace_step( + self, + action: str, + resource: str = "", + metadata: dict | None = None, + ) -> int | None: + """ + Trace a step start (for manual step tracking). + + Args: + action: Action being performed + resource: Resource being acted upon + metadata: Additional metadata + + Returns: + Step number (None if not in debug mode) + + Example: + step = secure.trace_step("click", "button#submit") + # ... perform action ... 
+ secure.trace_step_end(step, success=True) + """ + if self._tracer: + return self._tracer.trace_step_start( + action=action, + resource=resource, + metadata=metadata, + ) + return None + + def trace_step_end( + self, + step_number: int | None, + success: bool = True, + result: Any = None, + error: str | None = None, + ) -> None: + """ + Trace a step end (for manual step tracking). + + Args: + step_number: Step number from trace_step() + success: Whether the step succeeded + result: Step result (optional) + error: Error message if failed + """ + if self._tracer and step_number is not None: + self._tracer.trace_step_end( + step_number=step_number, + success=success, + result=result, + error=error, + ) + + def trace_snapshot_diff( + self, + before: dict | None = None, + after: dict | None = None, + diff: dict | None = None, + label: str = "State Change", + ) -> None: + """ + Trace a snapshot diff (before/after state change). + + Args: + before: Before snapshot (for computing diff) + after: After snapshot (for computing diff) + diff: Pre-computed diff (if before/after not provided) + label: Label for the diff + """ + if not self._tracer: + return + + if diff: + self._tracer.trace_snapshot_diff(SnapshotDiff(**diff), label=label) + elif before is not None and after is not None: + # Compute simple diff + computed_diff = SnapshotDiff( + added=[k for k in after if k not in before], + removed=[k for k in before if k not in after], + changed=[ + {"element": k, "before": before[k], "after": after[k]} + for k in before + if k in after and before[k] != after[k] + ], + ) + self._tracer.trace_snapshot_diff(computed_diff, label=label) + + def trace_verification( + self, + predicate: str, + passed: bool, + message: str | None = None, + expected: Any = None, + actual: Any = None, + ) -> None: + """ + Trace a verification predicate result. 
+ + Args: + predicate: Predicate name or expression + passed: Whether verification passed + message: Optional message + expected: Expected value (for failed verifications) + actual: Actual value (for failed verifications) + """ + if self._tracer: + self._tracer.trace_verification_result( + VerificationResult( + predicate=predicate, + passed=passed, + message=message, + expected=expected, + actual=actual, + ) + ) + @classmethod def attach(cls, agent: Any, **kwargs: Any) -> SecureAgent: """ diff --git a/src/predicate_secure/adapters.py b/src/predicate_secure/adapters.py index e25b100..727a574 100644 --- a/src/predicate_secure/adapters.py +++ b/src/predicate_secure/adapters.py @@ -134,7 +134,8 @@ async def create_browser_use_runtime( try: from predicate.agent_runtime import AgentRuntime from predicate.backends.browser_use_adapter import BrowserUseAdapter - from predicate.tracing import JsonlTraceSink, Tracer as PredicateTracer + from predicate.tracing import JsonlTraceSink + from predicate.tracing import Tracer as PredicateTracer except ImportError as e: raise AdapterError( f"browser-use adapter requires predicate. Error: {e}", @@ -201,7 +202,8 @@ def create_playwright_adapter( """ try: from predicate.agent_runtime import AgentRuntime - from predicate.tracing import JsonlTraceSink, Tracer as PredicateTracer + from predicate.tracing import JsonlTraceSink + from predicate.tracing import Tracer as PredicateTracer except ImportError as e: raise AdapterError( f"Playwright adapter requires predicate. 
Error: {e}", @@ -368,20 +370,14 @@ def create_adapter( AdapterError: If framework is not supported """ if framework == Framework.BROWSER_USE: - return create_browser_use_adapter( - agent, tracer, snapshot_options, predicate_api_key - ) + return create_browser_use_adapter(agent, tracer, snapshot_options, predicate_api_key) if framework == Framework.PLAYWRIGHT: - return create_playwright_adapter( - agent, tracer, snapshot_options, predicate_api_key - ) + return create_playwright_adapter(agent, tracer, snapshot_options, predicate_api_key) if framework == Framework.LANGCHAIN: browser = kwargs.get("browser") - return create_langchain_adapter( - agent, browser, tracer, snapshot_options, predicate_api_key - ) + return create_langchain_adapter(agent, browser, tracer, snapshot_options, predicate_api_key) if framework == Framework.PYDANTIC_AI: return create_pydantic_ai_adapter(agent, tracer) diff --git a/src/predicate_secure/config.py b/src/predicate_secure/config.py index e6a9ee4..96092dc 100644 --- a/src/predicate_secure/config.py +++ b/src/predicate_secure/config.py @@ -10,6 +10,9 @@ # Mode type alias Mode = Literal["strict", "permissive", "debug", "audit"] +# Trace format type alias +TraceFormatType = Literal["console", "json"] + @dataclass(frozen=True) class SecureAgentConfig: @@ -26,6 +29,10 @@ class SecureAgentConfig: signing_key: Secret key for mandate signing (auto-detect from env if not provided) mandate_ttl_seconds: TTL for issued mandates fail_closed: Whether to fail closed on authorization errors (based on mode) + trace_format: Format for debug trace output ("console" or "json") + trace_file: Path to trace output file (None for stderr) + trace_colors: Whether to use ANSI colors in console output + trace_verbose: Whether to output verbose trace information """ policy: str | Path | None = None @@ -36,6 +43,11 @@ class SecureAgentConfig: sidecar_url: str | None = None signing_key: str | None = None mandate_ttl_seconds: int = 300 + # Debug trace configuration + 
trace_format: TraceFormatType = "console" + trace_file: str | Path | None = None + trace_colors: bool = True + trace_verbose: bool = True @property def fail_closed(self) -> bool: @@ -68,6 +80,20 @@ def effective_policy_path(self) -> str | None: return str(self.policy) return self.policy + @property + def is_debug_mode(self) -> bool: + """Whether debug mode is enabled.""" + return self.mode == "debug" + + @property + def effective_trace_file(self) -> str | None: + """Get trace file path as string.""" + if self.trace_file is None: + return None + if isinstance(self.trace_file, Path): + return str(self.trace_file) + return self.trace_file + @classmethod def from_kwargs( cls, @@ -79,12 +105,22 @@ def from_kwargs( sidecar_url: str | None = None, signing_key: str | None = None, mandate_ttl_seconds: int = 300, + trace_format: str = "console", + trace_file: str | Path | None = None, + trace_colors: bool = True, + trace_verbose: bool = True, ) -> SecureAgentConfig: """Create config from keyword arguments with validation.""" valid_modes = ("strict", "permissive", "debug", "audit") if mode not in valid_modes: raise ValueError(f"Invalid mode '{mode}'. Must be one of: {valid_modes}") + valid_formats = ("console", "json") + if trace_format not in valid_formats: + raise ValueError( + f"Invalid trace_format '{trace_format}'. Must be one of: {valid_formats}" + ) + return cls( policy=policy, mode=mode, # type: ignore[arg-type] @@ -94,6 +130,10 @@ def from_kwargs( sidecar_url=sidecar_url, signing_key=signing_key, mandate_ttl_seconds=mandate_ttl_seconds, + trace_format=trace_format, # type: ignore[arg-type] + trace_file=trace_file, + trace_colors=trace_colors, + trace_verbose=trace_verbose, ) diff --git a/src/predicate_secure/tracing.py b/src/predicate_secure/tracing.py new file mode 100644 index 0000000..910351a --- /dev/null +++ b/src/predicate_secure/tracing.py @@ -0,0 +1,534 @@ +"""Debug tracing for predicate-secure. 
class TraceFormat(str, Enum):
    """Output format for trace events."""

    CONSOLE = "console"
    JSON = "json"


@dataclass
class TraceEvent:
    """A single trace event in the execution flow."""

    # Event kind, e.g. "session_start", "step_end", "policy_decision".
    event_type: str
    # ISO-8601 UTC timestamp captured when the event object is created.
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    # Event payload; values are caller-supplied and may be arbitrary objects.
    data: dict = field(default_factory=dict)
    step_number: int | None = None
    duration_ms: float | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dictionary, omitting optional fields that are None."""
        result: dict[str, Any] = {
            "event_type": self.event_type,
            "timestamp": self.timestamp,
            "data": self.data,
        }
        if self.step_number is not None:
            result["step_number"] = self.step_number
        if self.duration_ms is not None:
            result["duration_ms"] = self.duration_ms
        return result

    def to_json(self) -> str:
        """Convert to a JSON string.

        Payload values that JSON cannot represent natively (paths, datetimes,
        sets, arbitrary objects) are rendered with ``str()``. This is a
        deliberate choice for a debug trace: emitting a trace line must never
        raise and abort the run that is being traced.
        """
        return json.dumps(self.to_dict(), default=str)


@dataclass
class SnapshotDiff:
    """Represents a diff between two snapshots."""

    added: list[str] = field(default_factory=list)
    removed: list[str] = field(default_factory=list)
    changed: list[dict] = field(default_factory=list)  # {"element": str, "before": str, "after": str}

    def is_empty(self) -> bool:
        """Check if diff is empty (no changes)."""
        return not (self.added or self.removed or self.changed)

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return asdict(self)


@dataclass
class PolicyDecision:
    """Represents a policy decision with explanation."""

    action: str
    resource: str
    allowed: bool
    reason: str | None = None
    policy_rule: str | None = None
    principal: str | None = None

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return asdict(self)


@dataclass
class VerificationResult:
    """Represents a verification predicate result."""

    predicate: str
    passed: bool
    message: str | None = None
    expected: Any = None
    actual: Any = None

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return asdict(self)


class DebugTracer:
    """
    Tracer for debug mode output.

    Outputs human-readable trace information to console or file,
    with optional JSON format for machine parsing. Every emitted event is
    also kept in memory and can be read back with ``get_events()``.

    Example:
        tracer = DebugTracer(format="console")
        step = tracer.trace_step_start(action="click", resource="button#submit")
        tracer.trace_policy_decision(decision)
        tracer.trace_step_end(step, success=True)
    """

    # ANSI color codes
    COLORS = {
        "reset": "\033[0m",
        "bold": "\033[1m",
        "dim": "\033[2m",
        "red": "\033[31m",
        "green": "\033[32m",
        "yellow": "\033[33m",
        "blue": "\033[34m",
        "magenta": "\033[35m",
        "cyan": "\033[36m",
    }

    def __init__(
        self,
        format: Literal["console", "json"] = "console",
        output: IO[str] | None = None,
        file_path: str | Path | None = None,
        use_colors: bool = True,
        verbose: bool = True,
    ):
        """
        Initialize the debug tracer.

        Args:
            format: Output format ("console" or "json")
            output: Output stream, used when no file_path is given
                (defaults to sys.stderr)
            file_path: Path to a trace file, opened in append mode; takes
                precedence over ``output``
            use_colors: Whether to use ANSI colors (console format only)
            verbose: Whether to output verbose information

        Raises:
            ValueError: If ``format`` is not a valid TraceFormat value
        """
        self.format = TraceFormat(format)
        self.use_colors = use_colors and self.format == TraceFormat.CONSOLE
        self.verbose = verbose
        self._step_count = 0
        self._start_time: float | None = None
        self._step_start_times: dict[int, float] = {}
        self._events: list[TraceEvent] = []

        # Set up output stream
        self._file_handle: IO[str] | None = None
        self.output: IO[str]
        if file_path:
            # Console output contains non-ASCII glyphs ("→", "├─", "│", "└─"),
            # so the encoding is pinned instead of using the platform default,
            # which cannot encode them on some systems.
            self._file_handle = open(file_path, "a", encoding="utf-8")
            self.output = self._file_handle
        elif output is not None:
            # Identity check: a stream object must not be discarded just
            # because it happens to evaluate as falsy.
            self.output = output
        else:
            self.output = sys.stderr

    def close(self) -> None:
        """Close the file handle if this tracer opened one."""
        if self._file_handle:
            self._file_handle.close()
            self._file_handle = None

    def __enter__(self) -> DebugTracer:
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.close()

    def _color(self, text: str, color: str) -> str:
        """Apply color to text if colors are enabled."""
        if not self.use_colors:
            return text
        return f"{self.COLORS.get(color, '')}{text}{self.COLORS['reset']}"

    def _emit(self, event: TraceEvent) -> None:
        """Record a trace event and, in JSON format, write it as one line."""
        self._events.append(event)

        if self.format == TraceFormat.JSON:
            self.output.write(event.to_json() + "\n")
            self.output.flush()
        # Console format is handled by specific trace methods

    def trace_session_start(
        self,
        framework: str,
        mode: str,
        policy: str | None = None,
        principal_id: str | None = None,
    ) -> None:
        """Trace session start and reset the session clock and step counter."""
        # perf_counter is monotonic, so durations cannot go negative or jump
        # when the wall clock is adjusted mid-session.
        self._start_time = time.perf_counter()
        self._step_count = 0

        event = TraceEvent(
            event_type="session_start",
            data={
                "framework": framework,
                "mode": mode,
                "policy": policy,
                "principal_id": principal_id,
            },
        )
        self._emit(event)

        if self.format == TraceFormat.CONSOLE:
            self.output.write("\n")
            self.output.write(self._color("=" * 60, "bold") + "\n")
            self.output.write(self._color("[predicate-secure]", "cyan") + " Session Start\n")
            self.output.write(f" Framework: {self._color(framework, 'blue')}\n")
            self.output.write(f" Mode: {self._color(mode, 'yellow')}\n")
            if policy:
                self.output.write(f" Policy: {policy}\n")
            if principal_id:
                self.output.write(f" Principal: {principal_id}\n")
            self.output.write(self._color("=" * 60, "bold") + "\n\n")
            self.output.flush()

    def trace_session_end(self, success: bool = True, error: str | None = None) -> None:
        """Trace session end.

        The duration is reported only when ``trace_session_start`` was called
        on this tracer beforehand.
        """
        duration_ms = None
        if self._start_time is not None:
            duration_ms = (time.perf_counter() - self._start_time) * 1000

        event = TraceEvent(
            event_type="session_end",
            data={
                "success": success,
                "error": error,
                "total_steps": self._step_count,
            },
            duration_ms=duration_ms,
        )
        self._emit(event)

        if self.format == TraceFormat.CONSOLE:
            self.output.write("\n")
            self.output.write(self._color("=" * 60, "bold") + "\n")
            status = self._color("SUCCESS", "green") if success else self._color("FAILED", "red")
            self.output.write(self._color("[predicate-secure]", "cyan") + f" Session End: {status}\n")
            self.output.write(f" Total Steps: {self._step_count}\n")
            # A measured duration of 0.0 is still a measurement; only skip
            # the line when no session start was recorded.
            if duration_ms is not None:
                self.output.write(f" Duration: {duration_ms:.1f}ms\n")
            if error:
                self.output.write(f" Error: {self._color(error, 'red')}\n")
            self.output.write(self._color("=" * 60, "bold") + "\n")
            self.output.flush()

    def trace_step_start(
        self,
        step_number: int | None = None,
        action: str = "",
        resource: str = "",
        metadata: dict | None = None,
    ) -> int:
        """
        Trace step start.

        Args:
            step_number: Step number (auto-incremented if None)
            action: Action being performed
            resource: Resource being acted upon
            metadata: Additional metadata, merged into the event data

        Returns:
            The step number
        """
        if step_number is None:
            self._step_count += 1
            step_number = self._step_count
        else:
            # Keep the counter ahead of explicit numbers so later
            # auto-numbered steps do not reuse them.
            self._step_count = max(self._step_count, step_number)

        self._step_start_times[step_number] = time.perf_counter()

        event = TraceEvent(
            event_type="step_start",
            step_number=step_number,
            data={
                "action": action,
                "resource": resource,
                **(metadata or {}),
            },
        )
        self._emit(event)

        if self.format == TraceFormat.CONSOLE:
            self.output.write(
                self._color(f"[Step {step_number}]", "bold") + f" {self._color(action, 'magenta')}"
            )
            if resource:
                self.output.write(f" → {self._color(resource, 'blue')}")
            self.output.write("\n")
            self.output.flush()

        return step_number

    def trace_step_end(
        self,
        step_number: int,
        success: bool = True,
        result: Any = None,
        error: str | None = None,
    ) -> None:
        """Trace step end.

        Args:
            step_number: Step number returned by trace_step_start()
            success: Whether the step succeeded
            result: Step result; recorded as ``str(result)`` unless None
            error: Error message if failed
        """
        # The duration is known only if a matching step start was traced.
        started_at = self._step_start_times.pop(step_number, None)
        duration_ms = None
        if started_at is not None:
            duration_ms = (time.perf_counter() - started_at) * 1000

        event = TraceEvent(
            event_type="step_end",
            step_number=step_number,
            duration_ms=duration_ms,
            data={
                "success": success,
                # Only None means "no result"; falsy results such as 0,
                # False or "" are real values and must be recorded.
                "result": str(result) if result is not None else None,
                "error": error,
            },
        )
        self._emit(event)

        if self.format == TraceFormat.CONSOLE and self.verbose:
            status = self._color("OK", "green") if success else self._color("FAILED", "red")
            duration_str = f" ({duration_ms:.1f}ms)" if duration_ms is not None else ""
            self.output.write(f" └─ {status}{duration_str}\n")
            if error:
                self.output.write(f" Error: {self._color(error, 'red')}\n")
            self.output.write("\n")
            self.output.flush()

    def trace_policy_decision(
        self,
        decision: PolicyDecision | dict,
    ) -> None:
        """Trace a policy decision (a PolicyDecision or its field dict)."""
        if isinstance(decision, dict):
            decision = PolicyDecision(**decision)

        event = TraceEvent(
            event_type="policy_decision",
            data=decision.to_dict(),
        )
        self._emit(event)

        if self.format == TraceFormat.CONSOLE:
            status = (
                self._color("ALLOWED", "green")
                if decision.allowed
                else self._color("DENIED", "red")
            )
            self.output.write(f" ├─ Policy: {status}\n")
            self.output.write(f" │ Action: {decision.action}\n")
            self.output.write(f" │ Resource: {decision.resource}\n")
            if decision.reason:
                self.output.write(f" │ Reason: {decision.reason}\n")
            if decision.policy_rule:
                self.output.write(f" │ Rule: {self._color(decision.policy_rule, 'dim')}\n")
            self.output.flush()

    def trace_snapshot_diff(
        self,
        diff: SnapshotDiff | dict,
        label: str = "State Change",
    ) -> None:
        """Trace snapshot diff (before/after state change)."""
        if isinstance(diff, dict):
            diff = SnapshotDiff(**diff)

        event = TraceEvent(
            event_type="snapshot_diff",
            data={
                "label": label,
                **diff.to_dict(),
            },
        )
        self._emit(event)

        if self.format == TraceFormat.CONSOLE:
            if diff.is_empty():
                self.output.write(f" ├─ {label}: {self._color('(no changes)', 'dim')}\n")
            else:
                self.output.write(f" ├─ {label}:\n")
                for added in diff.added:
                    self.output.write(f" │ {self._color('+', 'green')} {added}\n")
                for removed in diff.removed:
                    self.output.write(f" │ {self._color('-', 'red')} {removed}\n")
                for changed in diff.changed:
                    self.output.write(
                        f" │ {self._color('~', 'yellow')} {changed.get('element', 'unknown')}\n"
                    )
                    if self.verbose:
                        # Skip only missing/None values: a change from or to
                        # 0, False or "" is exactly what a diff must show.
                        before = changed.get("before")
                        after = changed.get("after")
                        if before is not None:
                            # Values are truncated to keep lines readable.
                            self.output.write(
                                f" │ Before: {self._color(str(before)[:50], 'dim')}\n"
                            )
                        if after is not None:
                            self.output.write(
                                f" │ After: {self._color(str(after)[:50], 'dim')}\n"
                            )
            self.output.flush()

    def trace_verification_result(
        self,
        result: VerificationResult | dict,
    ) -> None:
        """Trace verification predicate result."""
        if isinstance(result, dict):
            result = VerificationResult(**result)

        event = TraceEvent(
            event_type="verification_result",
            data=result.to_dict(),
        )
        self._emit(event)

        if self.format == TraceFormat.CONSOLE:
            status = self._color("PASS", "green") if result.passed else self._color("FAIL", "red")
            self.output.write(f" ├─ Verification: {status}\n")
            self.output.write(f" │ Predicate: {result.predicate}\n")
            if result.message:
                self.output.write(f" │ Message: {result.message}\n")
            # Expected/actual are only shown for failures in verbose mode.
            if not result.passed and self.verbose:
                if result.expected is not None:
                    self.output.write(f" │ Expected: {result.expected}\n")
                if result.actual is not None:
                    self.output.write(f" │ Actual: {result.actual}\n")
            self.output.flush()

    def trace_authorization_request(
        self,
        action: str,
        resource: str,
        principal: str | None = None,
        context: dict | None = None,
    ) -> None:
        """Trace an authorization request."""
        event = TraceEvent(
            event_type="authorization_request",
            data={
                "action": action,
                "resource": resource,
                "principal": principal,
                "context": context,
            },
        )
        self._emit(event)

        if self.format == TraceFormat.CONSOLE:
            self.output.write(f" ├─ Authorize: {action} on {resource}\n")
            if principal:
                self.output.write(f" │ Principal: {principal}\n")
            self.output.flush()

    def trace_custom(self, event_type: str, data: dict) -> None:
        """Trace a custom event with a caller-defined type and payload."""
        event = TraceEvent(event_type=event_type, data=data)
        self._emit(event)

        if self.format == TraceFormat.CONSOLE and self.verbose:
            # default=str: a custom payload may hold non-JSON-native values;
            # printing a debug line must not raise TypeError.
            self.output.write(f" ├─ {event_type}: {json.dumps(data, default=str)}\n")
            self.output.flush()

    def get_events(self) -> list[TraceEvent]:
        """Get a copy of the list of recorded trace events."""
        return self._events.copy()

    def clear_events(self) -> None:
        """Clear recorded events."""
        self._events.clear()
use_colors: bool = True, + verbose: bool = True, +) -> DebugTracer: + """ + Factory function to create a debug tracer. + + Args: + format: Output format ("console" or "json") + file_path: Path to trace file (optional) + use_colors: Whether to use ANSI colors + verbose: Whether to output verbose information + + Returns: + Configured DebugTracer instance + """ + return DebugTracer( + format=format, + file_path=file_path, + use_colors=use_colors, + verbose=verbose, + ) diff --git a/tests/test_adapters.py b/tests/test_adapters.py index 8abec6b..29fe8f6 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -2,13 +2,7 @@ import pytest -from predicate_secure import ( - AdapterError, - AdapterResult, - Framework, - SecureAgent, - create_adapter, -) +from predicate_secure import AdapterError, AdapterResult, Framework, SecureAgent, create_adapter class TestAdapterResult: diff --git a/tests/test_adapters_integration.py b/tests/test_adapters_integration.py index 9eb21a0..0e75aa0 100644 --- a/tests/test_adapters_integration.py +++ b/tests/test_adapters_integration.py @@ -18,7 +18,6 @@ create_pydantic_ai_adapter, ) - # Check if predicate is available try: from predicate.integrations.browser_use.plugin import PredicateBrowserUsePlugin diff --git a/tests/test_tracing.py b/tests/test_tracing.py new file mode 100644 index 0000000..74d8733 --- /dev/null +++ b/tests/test_tracing.py @@ -0,0 +1,605 @@ +"""Tests for debug tracing functionality.""" + +import io +import json +import tempfile +from pathlib import Path + +from predicate_secure import ( + DebugTracer, + PolicyDecision, + SecureAgent, + SnapshotDiff, + TraceEvent, + TraceFormat, + VerificationResult, + create_debug_tracer, +) + + +class TestTraceEvent: + """Tests for TraceEvent dataclass.""" + + def test_trace_event_to_dict(self): + """TraceEvent converts to dict correctly.""" + event = TraceEvent( + event_type="test_event", + timestamp="2024-01-01T00:00:00Z", + data={"key": "value"}, + step_number=1, + 
duration_ms=100.5, + ) + d = event.to_dict() + + assert d["event_type"] == "test_event" + assert d["timestamp"] == "2024-01-01T00:00:00Z" + assert d["data"] == {"key": "value"} + assert d["step_number"] == 1 + assert d["duration_ms"] == 100.5 + + def test_trace_event_to_dict_optional_fields(self): + """TraceEvent omits optional fields when None.""" + event = TraceEvent(event_type="minimal") + d = event.to_dict() + + assert "step_number" not in d + assert "duration_ms" not in d + + def test_trace_event_to_json(self): + """TraceEvent converts to valid JSON.""" + event = TraceEvent(event_type="test", data={"foo": "bar"}) + json_str = event.to_json() + parsed = json.loads(json_str) + + assert parsed["event_type"] == "test" + assert parsed["data"]["foo"] == "bar" + + +class TestSnapshotDiff: + """Tests for SnapshotDiff dataclass.""" + + def test_snapshot_diff_is_empty(self): + """SnapshotDiff.is_empty() works correctly.""" + empty_diff = SnapshotDiff() + assert empty_diff.is_empty() is True + + diff_with_added = SnapshotDiff(added=["element1"]) + assert diff_with_added.is_empty() is False + + diff_with_removed = SnapshotDiff(removed=["element1"]) + assert diff_with_removed.is_empty() is False + + diff_with_changed = SnapshotDiff(changed=[{"element": "x", "before": "a", "after": "b"}]) + assert diff_with_changed.is_empty() is False + + def test_snapshot_diff_to_dict(self): + """SnapshotDiff converts to dict correctly.""" + diff = SnapshotDiff( + added=["new_element"], + removed=["old_element"], + changed=[{"element": "modified", "before": "x", "after": "y"}], + ) + d = diff.to_dict() + + assert d["added"] == ["new_element"] + assert d["removed"] == ["old_element"] + assert len(d["changed"]) == 1 + + +class TestPolicyDecision: + """Tests for PolicyDecision dataclass.""" + + def test_policy_decision_to_dict(self): + """PolicyDecision converts to dict correctly.""" + decision = PolicyDecision( + action="click", + resource="button#submit", + allowed=True, + 
reason="policy_allowed", + policy_rule="allow_submit_buttons", + principal="agent:test", + ) + d = decision.to_dict() + + assert d["action"] == "click" + assert d["resource"] == "button#submit" + assert d["allowed"] is True + assert d["reason"] == "policy_allowed" + assert d["policy_rule"] == "allow_submit_buttons" + assert d["principal"] == "agent:test" + + +class TestVerificationResult: + """Tests for VerificationResult dataclass.""" + + def test_verification_result_to_dict(self): + """VerificationResult converts to dict correctly.""" + result = VerificationResult( + predicate="cart_updated", + passed=False, + message="Cart count mismatch", + expected=1, + actual=0, + ) + d = result.to_dict() + + assert d["predicate"] == "cart_updated" + assert d["passed"] is False + assert d["message"] == "Cart count mismatch" + assert d["expected"] == 1 + assert d["actual"] == 0 + + +class TestDebugTracerConsole: + """Tests for DebugTracer console output.""" + + def test_tracer_session_start(self): + """Tracer outputs session start correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="console", output=output, use_colors=False) + + tracer.trace_session_start( + framework="browser_use", + mode="debug", + policy="test.yaml", + principal_id="agent:test", + ) + + out = output.getvalue() + assert "Session Start" in out + assert "browser_use" in out + assert "debug" in out + assert "test.yaml" in out + assert "agent:test" in out + + def test_tracer_session_end(self): + """Tracer outputs session end correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="console", output=output, use_colors=False) + + tracer.trace_session_start(framework="test", mode="debug") + tracer.trace_session_end(success=True) + + out = output.getvalue() + assert "Session End" in out + assert "SUCCESS" in out + + def test_tracer_session_end_failure(self): + """Tracer outputs failure correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="console", output=output, 
use_colors=False) + + tracer.trace_session_end(success=False, error="Test error") + + out = output.getvalue() + assert "FAILED" in out + assert "Test error" in out + + def test_tracer_step_start(self): + """Tracer outputs step start correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="console", output=output, use_colors=False) + + step = tracer.trace_step_start(action="click", resource="button#submit") + + out = output.getvalue() + assert "[Step 1]" in out + assert "click" in out + assert "button#submit" in out + assert step == 1 + + def test_tracer_step_auto_increment(self): + """Tracer auto-increments step numbers.""" + output = io.StringIO() + tracer = DebugTracer(format="console", output=output, use_colors=False) + + step1 = tracer.trace_step_start(action="action1") + step2 = tracer.trace_step_start(action="action2") + step3 = tracer.trace_step_start(action="action3") + + assert step1 == 1 + assert step2 == 2 + assert step3 == 3 + + def test_tracer_step_end(self): + """Tracer outputs step end correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="console", output=output, use_colors=False) + + step = tracer.trace_step_start(action="click") + tracer.trace_step_end(step, success=True) + + out = output.getvalue() + assert "OK" in out + + def test_tracer_step_end_failure(self): + """Tracer outputs step failure correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="console", output=output, use_colors=False) + + step = tracer.trace_step_start(action="click") + tracer.trace_step_end(step, success=False, error="Element not found") + + out = output.getvalue() + assert "FAILED" in out + assert "Element not found" in out + + def test_tracer_policy_decision_allowed(self): + """Tracer outputs allowed policy decision correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="console", output=output, use_colors=False) + + tracer.trace_policy_decision( + PolicyDecision(action="click", resource="button", allowed=True) + ) 
+ + out = output.getvalue() + assert "ALLOWED" in out + assert "click" in out + assert "button" in out + + def test_tracer_policy_decision_denied(self): + """Tracer outputs denied policy decision correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="console", output=output, use_colors=False) + + tracer.trace_policy_decision( + PolicyDecision( + action="delete", + resource="database", + allowed=False, + reason="policy_denied", + ) + ) + + out = output.getvalue() + assert "DENIED" in out + assert "delete" in out + assert "database" in out + assert "policy_denied" in out + + def test_tracer_snapshot_diff(self): + """Tracer outputs snapshot diff correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="console", output=output, use_colors=False) + + tracer.trace_snapshot_diff( + SnapshotDiff( + added=["new_element"], + removed=["old_element"], + changed=[{"element": "modified", "before": "x", "after": "y"}], + ) + ) + + out = output.getvalue() + assert "+" in out + assert "new_element" in out + assert "-" in out + assert "old_element" in out + assert "~" in out + assert "modified" in out + + def test_tracer_snapshot_diff_empty(self): + """Tracer handles empty diff correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="console", output=output, use_colors=False) + + tracer.trace_snapshot_diff(SnapshotDiff()) + + out = output.getvalue() + assert "(no changes)" in out + + def test_tracer_verification_pass(self): + """Tracer outputs passed verification correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="console", output=output, use_colors=False) + + tracer.trace_verification_result( + VerificationResult(predicate="item_in_cart", passed=True) + ) + + out = output.getvalue() + assert "PASS" in out + assert "item_in_cart" in out + + def test_tracer_verification_fail(self): + """Tracer outputs failed verification correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="console", output=output, 
use_colors=False) + + tracer.trace_verification_result( + VerificationResult( + predicate="cart_count", + passed=False, + message="Count mismatch", + expected=1, + actual=0, + ) + ) + + out = output.getvalue() + assert "FAIL" in out + assert "cart_count" in out + assert "Count mismatch" in out + assert "Expected: 1" in out + assert "Actual: 0" in out + + +class TestDebugTracerJson: + """Tests for DebugTracer JSON output.""" + + def test_tracer_json_session_start(self): + """Tracer outputs JSON session start correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="json", output=output) + + tracer.trace_session_start(framework="test", mode="debug") + + lines = output.getvalue().strip().split("\n") + event = json.loads(lines[0]) + + assert event["event_type"] == "session_start" + assert event["data"]["framework"] == "test" + assert event["data"]["mode"] == "debug" + + def test_tracer_json_step(self): + """Tracer outputs JSON step events correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="json", output=output) + + tracer.trace_step_start(action="click", resource="button") + tracer.trace_step_end(1, success=True) + + lines = output.getvalue().strip().split("\n") + start_event = json.loads(lines[0]) + end_event = json.loads(lines[1]) + + assert start_event["event_type"] == "step_start" + assert start_event["step_number"] == 1 + assert end_event["event_type"] == "step_end" + assert end_event["step_number"] == 1 + + def test_tracer_json_policy_decision(self): + """Tracer outputs JSON policy decision correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="json", output=output) + + tracer.trace_policy_decision( + PolicyDecision(action="click", resource="button", allowed=True) + ) + + lines = output.getvalue().strip().split("\n") + event = json.loads(lines[0]) + + assert event["event_type"] == "policy_decision" + assert event["data"]["action"] == "click" + assert event["data"]["allowed"] is True + + +class TestDebugTracerFile: + 
"""Tests for DebugTracer file output.""" + + def test_tracer_file_output(self): + """Tracer writes to file correctly.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + file_path = f.name + + try: + tracer = DebugTracer(format="json", file_path=file_path) + tracer.trace_session_start(framework="test", mode="debug") + tracer.trace_step_start(action="click") + tracer.close() + + with open(file_path) as f: + lines = f.readlines() + + assert len(lines) == 2 + + event1 = json.loads(lines[0]) + assert event1["event_type"] == "session_start" + + event2 = json.loads(lines[1]) + assert event2["event_type"] == "step_start" + + finally: + Path(file_path).unlink(missing_ok=True) + + def test_tracer_context_manager(self): + """Tracer works as context manager.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + file_path = f.name + + try: + with DebugTracer(format="json", file_path=file_path) as tracer: + tracer.trace_session_start(framework="test", mode="debug") + + with open(file_path) as f: + lines = f.readlines() + + assert len(lines) == 1 + + finally: + Path(file_path).unlink(missing_ok=True) + + +class TestDebugTracerEvents: + """Tests for DebugTracer event collection.""" + + def test_tracer_get_events(self): + """Tracer collects events correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="console", output=output) + + tracer.trace_session_start(framework="test", mode="debug") + tracer.trace_step_start(action="click") + tracer.trace_session_end(success=True) + + events = tracer.get_events() + assert len(events) == 3 + assert events[0].event_type == "session_start" + assert events[1].event_type == "step_start" + assert events[2].event_type == "session_end" + + def test_tracer_clear_events(self): + """Tracer clears events correctly.""" + output = io.StringIO() + tracer = DebugTracer(format="console", output=output) + + tracer.trace_session_start(framework="test", mode="debug") + assert 
len(tracer.get_events()) == 1 + + tracer.clear_events() + assert len(tracer.get_events()) == 0 + + +class TestCreateDebugTracer: + """Tests for create_debug_tracer factory function.""" + + def test_create_debug_tracer_console(self): + """create_debug_tracer creates console tracer.""" + tracer = create_debug_tracer(format="console") + assert tracer.format == TraceFormat.CONSOLE + + def test_create_debug_tracer_json(self): + """create_debug_tracer creates JSON tracer.""" + tracer = create_debug_tracer(format="json") + assert tracer.format == TraceFormat.JSON + + +class TestSecureAgentDebugMode: + """Tests for SecureAgent debug mode integration.""" + + def test_secure_agent_debug_mode_creates_tracer(self): + """SecureAgent creates tracer in debug mode.""" + + class MockAgent: + __module__ = "pydantic_ai.agent" + model = "test" + + secure = SecureAgent(agent=MockAgent(), mode="debug") + assert secure.tracer is not None + assert isinstance(secure.tracer, DebugTracer) + + def test_secure_agent_non_debug_mode_no_tracer(self): + """SecureAgent doesn't create tracer in non-debug modes.""" + + class MockAgent: + __module__ = "pydantic_ai.agent" + model = "test" + + secure_strict = SecureAgent(agent=MockAgent(), mode="strict") + assert secure_strict.tracer is None + + secure_permissive = SecureAgent(agent=MockAgent(), mode="permissive") + assert secure_permissive.tracer is None + + secure_audit = SecureAgent(agent=MockAgent(), mode="audit") + assert secure_audit.tracer is None + + def test_secure_agent_debug_mode_json_format(self): + """SecureAgent respects trace_format parameter.""" + + class MockAgent: + __module__ = "pydantic_ai.agent" + + secure = SecureAgent(agent=MockAgent(), mode="debug", trace_format="json") + assert secure.tracer is not None + assert secure.tracer.format == TraceFormat.JSON + + def test_secure_agent_trace_step(self): + """SecureAgent.trace_step() works correctly.""" + + class MockAgent: + __module__ = "pydantic_ai.agent" + + secure = 
SecureAgent(agent=MockAgent(), mode="debug") + step = secure.trace_step("click", "button#submit") + + assert step == 1 + + def test_secure_agent_trace_step_non_debug(self): + """SecureAgent.trace_step() returns None in non-debug mode.""" + + class MockAgent: + __module__ = "pydantic_ai.agent" + + secure = SecureAgent(agent=MockAgent(), mode="strict") + step = secure.trace_step("click", "button#submit") + + assert step is None + + def test_secure_agent_trace_snapshot_diff(self): + """SecureAgent.trace_snapshot_diff() works correctly.""" + + class MockAgent: + __module__ = "pydantic_ai.agent" + + secure = SecureAgent(agent=MockAgent(), mode="debug") + + # Should not raise + secure.trace_snapshot_diff( + before={"element1": "value1"}, + after={"element1": "value2", "element2": "new"}, + ) + + events = secure.tracer.get_events() + diff_events = [e for e in events if e.event_type == "snapshot_diff"] + assert len(diff_events) == 1 + + def test_secure_agent_trace_verification(self): + """SecureAgent.trace_verification() works correctly.""" + + class MockAgent: + __module__ = "pydantic_ai.agent" + + secure = SecureAgent(agent=MockAgent(), mode="debug") + + secure.trace_verification( + predicate="cart_updated", + passed=True, + message="Cart has 1 item", + ) + + events = secure.tracer.get_events() + verification_events = [e for e in events if e.event_type == "verification_result"] + assert len(verification_events) == 1 + assert verification_events[0].data["passed"] is True + + +class TestSecureAgentDebugModeFile: + """Tests for SecureAgent debug mode with file output.""" + + def test_secure_agent_debug_mode_file_output(self): + """SecureAgent writes trace to file in debug mode.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + file_path = f.name + + try: + + class MockAgent: + __module__ = "pydantic_ai.agent" + model = "test" + + secure = SecureAgent( + agent=MockAgent(), + mode="debug", + trace_format="json", + trace_file=file_path, + ) 
+ + # Write some trace events + secure.trace_step("click", "button") + + # Close the tracer to flush + secure.tracer.close() + + # Read the file + with open(file_path) as f: + lines = f.readlines() + + assert len(lines) >= 1 + event = json.loads(lines[0]) + assert event["event_type"] == "step_start" + + finally: + Path(file_path).unlink(missing_ok=True)