diff --git a/pyproject.toml b/pyproject.toml index aefaa35d1..326e5d740 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "mermaid-builder==0.0.3", "graphtty==0.1.8", "applicationinsights>=0.11.10", + "sqlparse>=0.4.4", ] classifiers = [ "Intended Audience :: Developers", diff --git a/src/uipath/agent/models/agent.py b/src/uipath/agent/models/agent.py index 0a61347a3..8ebec8ffd 100644 --- a/src/uipath/agent/models/agent.py +++ b/src/uipath/agent/models/agent.py @@ -110,6 +110,7 @@ class AgentContextRetrievalMode(str, Enum): STRUCTURED = "Structured" DEEP_RAG = "DeepRAG" BATCH_TRANSFORM = "BatchTransform" + DATA_FABRIC = "DataFabric" UNKNOWN = "Unknown" # fallback branch discriminator @@ -342,6 +343,7 @@ class AgentContextSettings(BaseCfg): AgentContextRetrievalMode.STRUCTURED, AgentContextRetrievalMode.DEEP_RAG, AgentContextRetrievalMode.BATCH_TRANSFORM, + AgentContextRetrievalMode.DATA_FABRIC, AgentContextRetrievalMode.UNKNOWN, ] = Field(alias="retrievalMode") threshold: float = Field(default=0) @@ -361,6 +363,10 @@ class AgentContextSettings(BaseCfg): output_columns: Optional[List[AgentContextOutputColumn]] = Field( None, alias="outputColumns" ) + # Data Fabric specific settings + entity_identifiers: Optional[List[str]] = Field( + None, alias="entityIdentifiers" + ) class AgentContextResourceConfig(BaseAgentResourceConfig): @@ -1198,6 +1204,7 @@ def _normalize_resources(v: Dict[str, Any]) -> None: "structured": "Structured", "deeprag": "DeepRAG", "batchtransform": "BatchTransform", + "datafabric": "DataFabric", "unknown": "Unknown", } diff --git a/src/uipath/platform/entities/_entities_service.py b/src/uipath/platform/entities/_entities_service.py index 2b31b830c..8b26f3d2b 100644 --- a/src/uipath/platform/entities/_entities_service.py +++ b/src/uipath/platform/entities/_entities_service.py @@ -1,6 +1,16 @@ -from typing import Any, List, Optional, Type +from typing import Any, Dict, List, Optional, Type +import sqlparse from httpx import Response +from sqlparse.sql import ( + IdentifierList, + Parenthesis, + Statement, + Token, + TokenList, + Where, +) +from sqlparse.tokens import DML, Comment, Keyword, Punctuation, Wildcard from ..._utils import Endpoint, RequestSpec from ...tracing import traced @@ -11,15 +21,39 @@ EntityRecordsBatchResponse, ) +_FORBIDDEN_SQL_KEYWORDS = { + "INSERT", + "UPDATE", + "DELETE", + "MERGE", + "DROP", + "ALTER", + "CREATE", + "TRUNCATE", + "REPLACE", +} +_DISALLOWED_SQL_OPERATORS = { + "WITH", + "UNION", + "INTERSECT", + "EXCEPT", + "OVER", + "ROLLUP", + "CUBE", + "GROUPING SETS", + "PARTITION BY", +} + class EntitiesService(BaseService): """Service for managing UiPath Data Service entities. - Entities are database tables in UiPath Data Service that can store - structured data for automation processes. + Entities represent business objects that provide structured data storage and access via the Data Service. + This service allows you to retrieve entity metadata, list entities, and query records using SQL. - See Also: - https://docs.uipath.com/data-service/automation-cloud/latest/user-guide/introduction + !!! warning "Preview Feature" + This function is currently experimental. + Behavior and parameters, request and response formats are subject to change in future versions. """ def __init__( @@ -389,6 +423,102 @@ class CustomerRecord: EntityRecord.from_data(data=record, model=schema) for record in records_data ] + @traced(name="query_multiple_entities", run_type="uipath") + def query_multiple_entities( + self, + sql_query: str, + schema: Optional[Type[Any]] = None, + ) -> List[Dict[str, Any]]: + """Query entity records using a SQL query. + + This method allows executing SQL queries directly against entity data + via the Data Fabric query endpoint. + + Args: + sql_query (str): The full SQL query to execute. Should be a valid + SELECT statement targeting the entity. + schema (Optional[Type[Any]]): Optional schema class for validation. + + Returns: + List[Dict[str, Any]]: A list of record dictionaries matching the query. + + Examples: + Basic query:: + + records = entities_service.query_multiple_entities( + "SELECT * FROM Customers WHERE Status = 'Active' LIMIT 100" + ) + + Query with specific fields:: + + records = entities_service.query_multiple_entities( + "SELECT OrderId, CustomerName, Amount FROM Orders WHERE Amount > 1000" + ) + """ + self._validate_sql_query(sql_query) + spec = self._query_multiple_entities_spec(sql_query) + headers = { + "X-UiPath-Internal-TenantName": self._url.tenant_name, + "X-UiPath-Internal-AccountName": self._url.org_name, + } + # Use absolute URL to bypass scoping since org/tenant are embedded in the path + full_url = f"{self._url.base_url}{spec.endpoint}" + response = self.request(spec.method, full_url, json=spec.json, headers=headers) + + if response.status_code == 200: + records_data = response.json().get("results", []) + return records_data + else: + response.raise_for_status() + + + @traced(name="query_multiple_entities_async", run_type="uipath") + async def query_multiple_entities_async( + self, + sql_query: str, + schema: Optional[Type[Any]] = None, + ) -> List[Dict[str, Any]]: + """Asynchronously query entity records using a SQL query. + + This method allows executing SQL queries directly against entity data + via the Data Fabric query endpoint. + + Args: + sql_query (str): The full SQL query to execute. Should be a valid + SELECT statement targeting the entity. + schema (Optional[Type[Any]]): Optional schema class for validation. + + Returns: + List[Dict[str, Any]]: A list of record dictionaries matching the query. + + Examples: + Basic query:: + + records = await entities_service.query_multiple_entities_async( + "SELECT * FROM Customers WHERE Status = 'Active' LIMIT 100" + ) + + Query with specific fields:: + + records = await entities_service.query_multiple_entities_async( + "SELECT OrderId, CustomerName, Amount FROM Orders WHERE Amount > 1000" + ) + """ + self._validate_sql_query(sql_query) + spec = self._query_multiple_entities_spec(sql_query) + headers = { + "X-UiPath-Internal-TenantName": self._url.tenant_name, + "X-UiPath-Internal-AccountName": self._url.org_name, + } + full_url = f"{self._url.base_url}{spec.endpoint}" + response = await self.request_async(spec.method, full_url, json=spec.json, headers=headers) + + if response.status_code == 200: + records_data = response.json().get("results", []) + return records_data + else: + response.raise_for_status() + @traced(name="entity_record_insert_batch", run_type="uipath") def insert_records( self, @@ -872,6 +1002,17 @@ def _list_records_spec( params=({"start": start, "limit": limit}), ) + def _query_multiple_entities_spec( + self, + sql_query: str, + ) -> RequestSpec: + endpoint = f"/dataservice_/{self._url.org_name}/{self._url.tenant_name}/datafabric_/api/v1/query/execute" + return RequestSpec( + method="POST", + endpoint=Endpoint(endpoint), + json={"query": sql_query}, + ) + def _insert_batch_spec(self, entity_key: str, records: List[Any]) -> RequestSpec: return RequestSpec( method="POST", @@ -900,3 +1041,113 @@ def _delete_batch_spec(self, entity_key: str, record_ids: List[str]) -> RequestS ), json=record_ids, ) + + def _validate_sql_query(self, sql_query: str) -> None: + query = sql_query.strip() + if not query: + raise ValueError("SQL query cannot be empty.") + + statements = [stmt for stmt in sqlparse.parse(query) if stmt.tokens] + if len(statements) != 1: + raise ValueError("Only a single SELECT statement is allowed.") + + statement = statements[0] + if statement.get_type() != "SELECT": + raise ValueError("Only SELECT statements are allowed.") + + normalized_keywords = { + token.normalized + for token in statement.flatten() + if token.ttype in Keyword or token.ttype is DML + } + + for keyword in _FORBIDDEN_SQL_KEYWORDS: + if keyword in normalized_keywords: + raise ValueError(f"SQL keyword '{keyword}' is not allowed.") + + for operator in _DISALLOWED_SQL_OPERATORS: + if operator in normalized_keywords: + raise ValueError( + f"SQL construct '{operator}' is not allowed in entity queries." + ) + + if self._contains_subquery(statement): + raise ValueError("Subqueries are not allowed.") + + has_where = any(isinstance(token, Where) for token in statement.tokens) + has_limit = any( + token.ttype in Keyword and token.normalized == "LIMIT" + for token in statement.flatten() + ) + if not has_where and not has_limit: + raise ValueError("Queries without WHERE must include a LIMIT clause.") + + projection_tokens = self._projection_tokens(statement) + has_wildcard_projection = any( + token.ttype is Wildcard + for projection_token in projection_tokens + for token in projection_token.flatten() + ) + if has_wildcard_projection and not has_where: + raise ValueError("SELECT * without filtering is not allowed.") + if not has_where and self._projection_column_count(projection_tokens) > 4: + raise ValueError( + "Selecting more than 4 columns without filtering is not allowed." + ) + + def _contains_subquery(self, token_list: TokenList) -> bool: + for token in token_list.tokens: + if isinstance(token, Parenthesis): + if any( + nested.ttype is DML and nested.normalized == "SELECT" + for nested in token.flatten() + ): + return True + if isinstance(token, TokenList) and self._contains_subquery(token): + return True + return False + + def _projection_tokens(self, statement: Statement) -> List[Token]: + projection: List[Token] = [] + found_select = False + + for token in statement.tokens: + if token.is_whitespace or token.ttype in Comment: + continue + if not found_select: + if token.ttype is DML and token.normalized == "SELECT": + found_select = True + continue + if token.ttype in Keyword and token.normalized == "FROM": + break + projection.append(token) + return projection + + def _projection_column_count(self, projection_tokens: List[Token]) -> int: + identifier_list = next( + ( + token + for token in projection_tokens + if isinstance(token, IdentifierList) + ), + None, + ) + if identifier_list is not None: + return sum(1 for _ in identifier_list.get_identifiers()) + + count = 0 + has_current_expression = False + + for token in projection_tokens: + if token.is_whitespace or token.ttype in Comment: + continue + if token.ttype is Punctuation and token.value == ",": + if has_current_expression: + count += 1 + has_current_expression = False + continue + has_current_expression = True + + if has_current_expression: + count += 1 + return count diff --git a/tests/agent/models/test_agent.py b/tests/agent/models/test_agent.py index f836cf4f2..fcd99d2c5 100644 --- a/tests/agent/models/test_agent.py +++ b/tests/agent/models/test_agent.py @@ -2761,6 +2761,56 @@ def test_is_conversational_false_by_default(self): assert config.is_conversational is False +class TestDataFabricContextConfig: + """Tests for Data Fabric context resource configuration.""" + + def test_datafabric_retrieval_mode_exists(self): + """Test that DATA_FABRIC retrieval mode is defined.""" + assert AgentContextRetrievalMode.DATA_FABRIC == "DataFabric" + + def test_datafabric_context_config_parses(self): + """Test that Data Fabric context config parses correctly.""" + config = { + "$resourceType": "context", + "name": "Customer Data", + "description": "Query customer and order data", + "isEnabled": True, + "folderPath": "Shared", + "indexName": "", + "settings": { + "retrievalMode": "DataFabric", + "resultCount": 100, + "entityIdentifiers": ["customers-key", "orders-key"], + }, + } + + parsed = AgentContextResourceConfig.model_validate(config) + + assert parsed.name == "Customer Data" + assert parsed.settings.retrieval_mode == AgentContextRetrievalMode.DATA_FABRIC + assert parsed.settings.entity_identifiers == ["customers-key", "orders-key"] + + def test_datafabric_context_config_without_entity_identifiers(self): + """Test that entity_identifiers is optional.""" + config = { + "$resourceType": "context", + "name": "Test", + "description": "Test", + "isEnabled": True, + "folderPath": "Shared", + "indexName": "", + "settings": { + "retrievalMode": "DataFabric", + "resultCount": 10, + }, + } + + parsed = AgentContextResourceConfig.model_validate(config) + + assert parsed.settings.retrieval_mode == AgentContextRetrievalMode.DATA_FABRIC + assert parsed.settings.entity_identifiers is None + + class TestAgentBuilderConfigResources: """Tests for AgentDefinition resource configuration parsing.""" diff --git a/tests/sdk/services/test_entities_service.py b/tests/sdk/services/test_entities_service.py index 4c6c85882..feb80a161 100644 --- a/tests/sdk/services/test_entities_service.py +++ b/tests/sdk/services/test_entities_service.py @@ -1,6 +1,8 @@ import uuid from dataclasses import make_dataclass from typing import Optional +import re +from unittest.mock import AsyncMock, MagicMock import pytest from pytest_httpx import HTTPXMock @@ -260,3 +262,111 @@ def test_retrieve_records_with_optional_fields( start=0, limit=1, ) + + @pytest.mark.parametrize( + "sql_query", + [ + "SELECT id FROM Customers WHERE id = 1", + "SELECT id, name FROM Customers LIMIT 10", + "SELECT * FROM Customers WHERE status = 'Active'", + "SELECT id, name, email, phone FROM Customers LIMIT 5", + "SELECT DISTINCT id FROM Customers WHERE id > 100", + ], + ) + def test_validate_sql_query_allows_supported_select_queries( + self, + sql_query: str, service: EntitiesService + ) -> None: + service._validate_sql_query(sql_query) + + + @pytest.mark.parametrize( + "sql_query,error_message", + [ + ("", "SQL query cannot be empty."), + (" ", "SQL query cannot be empty."), + ("SELECT id FROM Customers; SELECT id FROM Orders", "Only a single SELECT statement is allowed."), + ("INSERT INTO Customers VALUES (1)", "Only SELECT statements are allowed."), + ( + "WITH cte AS (SELECT id FROM Customers) SELECT id FROM cte", + "SQL construct 'WITH' is not allowed in entity queries.", + ), + ("SELECT id FROM Customers UNION SELECT id FROM Orders", "SQL construct 'UNION' is not allowed in entity queries."), + ("SELECT id, SUM(amount) OVER (PARTITION BY id) FROM Orders LIMIT 10", "SQL construct 'OVER' is not allowed in entity queries."), + ("SELECT id FROM (SELECT id FROM Customers) c", "Subqueries are not allowed."), + ("SELECT id FROM Customers", "Queries without WHERE must include a LIMIT clause."), + ("SELECT * FROM Customers LIMIT 10", "SELECT * without filtering is not allowed."), + ( + "SELECT id, name, email, phone, address FROM Customers LIMIT 10", + "Selecting more than 4 columns without filtering is not allowed.", + ), + ], + ) + def test_validate_sql_query_rejects_disallowed_queries( + self, + sql_query: str, error_message: str, service: EntitiesService + ) -> None: + with pytest.raises(ValueError, match=re.escape(error_message)): + service._validate_sql_query(sql_query) + + + def test_query_multiple_entities_rejects_invalid_sql_before_network_call( + self, + service: EntitiesService, + ) -> None: + service.request = MagicMock() # type: ignore[method-assign] + + with pytest.raises( + ValueError, match=re.escape("Only SELECT statements are allowed.") + ): + service.query_multiple_entities("UPDATE Customers SET name = 'X'") + + service.request.assert_not_called() # type: ignore[attr-defined] + + + def test_query_multiple_entities_calls_request_for_valid_sql( + self, + service: EntitiesService, + ) -> None: + response = MagicMock() + response.json.return_value = {"results": [{"id": 1}, {"id": 2}]} + + service.request = MagicMock(return_value=response) # type: ignore[method-assign] + + result = service.query_multiple_entities("SELECT id FROM Customers WHERE id > 0") + + assert result == [{"id": 1}, {"id": 2}] + service.request.assert_called_once() # type: ignore[attr-defined] + + + @pytest.mark.anyio + async def test_query_multiple_entities_async_rejects_invalid_sql_before_network_call( + self, + service: EntitiesService, + ) -> None: + service.request_async = AsyncMock() # type: ignore[method-assign] + + with pytest.raises(ValueError, match=re.escape("Subqueries are not allowed.")): + await service.query_multiple_entities_async( + "SELECT id FROM Customers WHERE id IN (SELECT id FROM Orders)" + ) + + service.request_async.assert_not_called() # type: ignore[attr-defined] + + + @pytest.mark.anyio + async def test_query_multiple_entities_async_calls_request_for_valid_sql( + self, + service: EntitiesService, + ) -> None: + response = MagicMock() + response.json.return_value = {"results": [{"id": "c1"}]} + + service.request_async = AsyncMock(return_value=response) # type: ignore[method-assign] + + result = await service.query_multiple_entities_async( + "SELECT id FROM Customers WHERE id = 'c1'" + ) + + assert result == [{"id": "c1"}] + service.request_async.assert_called_once() # type: ignore[attr-defined] diff --git a/uv.lock b/uv.lock index 19a37fc8e..acbeddcac 100644 --- a/uv.lock +++ b/uv.lock @@ -2368,6 +2368,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/46/2c/1462b1d0a634697ae9e55b3cecdcb64788e8b7d63f54d923fcd0bb140aed/soupsieve-2.8.3-py3-none-any.whl", hash = "sha256:ed64f2ba4eebeab06cc4962affce381647455978ffc1e36bb79a545b91f45a95", size = 37016, upload-time = "2026-01-20T04:27:01.012Z" }, ] +[[package]] +name = "sqlparse" +version = "0.5.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/90/76/437d71068094df0726366574cf3432a4ed754217b436eb7429415cf2d480/sqlparse-0.5.5.tar.gz", hash = "sha256:e20d4a9b0b8585fdf63b10d30066c7c94c5d7a7ec47c889a2d83a3caa93ff28e", size = 120815, upload-time = "2025-12-19T07:17:45.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/4b/359f28a903c13438ef59ebeee215fb25da53066db67b305c125f1c6d2a25/sqlparse-0.5.5-py3-none-any.whl", hash = "sha256:12a08b3bf3eec877c519589833aed092e2444e68240a3577e8e26148acc7b1ba", size = 46138, upload-time = "2025-12-19T07:17:46.573Z" }, +] + [[package]] name = "stevedore" version = "5.6.0" @@ -2548,6 +2557,7 @@ dependencies = [ { name = "python-dotenv" }, { name = "python-socketio" }, { name = "rich" }, + { name = "sqlparse" }, { name = "tenacity" }, { name = "truststore" }, { name = "uipath-core" }, @@ -2580,6 +2590,7 @@ dev = [ { name = "termynal" }, { name = "tomli-w" }, { name = "types-toml" }, + { name = "uipath" }, { name = "virtualenv" }, ] @@ -2599,6 +2610,7 @@ requires-dist = [ { name = "python-dotenv", specifier = ">=1.0.1" }, { name = "python-socketio", specifier = ">=5.15.0,<6.0.0" }, { name = "rich", specifier = ">=14.2.0" }, + { name = "sqlparse", specifier = ">=0.4.4" }, { name = "tenacity", specifier = ">=9.0.0" }, { name = "truststore", specifier = ">=0.10.1" }, { name = "uipath-core", specifier = ">=0.4.1,<0.5.0" }, @@ -2631,6 +2643,7 @@ dev = [ { name = "termynal", specifier = ">=0.13.1" }, { name = "tomli-w", specifier = ">=1.2.0" }, { name = "types-toml", specifier = ">=0.10.8" }, + { name = "uipath", editable = "." }, { name = "virtualenv", specifier = ">=20.36.1" }, ]