diff --git a/src/python_workflow_definition/models.py b/src/python_workflow_definition/models.py index 4980cfa..8b25be8 100644 --- a/src/python_workflow_definition/models.py +++ b/src/python_workflow_definition/models.py @@ -1,5 +1,15 @@ from pathlib import Path -from typing import List, Union, Optional, Literal, Any, Annotated, Type, TypeVar +from typing import ( + List, + Union, + Optional, + Literal, + Any, + Annotated, + Type, + TypeVar, +) +from typing_extensions import TypeAliasType from pydantic import BaseModel, Field, field_validator, field_serializer from pydantic import ValidationError import json @@ -19,6 +29,13 @@ ) +JsonPrimitive = Union[str, int, float, bool, None] +AllowableDefaults = TypeAliasType( + "AllowableDefaults", + "Union[JsonPrimitive, dict[str, AllowableDefaults], list[AllowableDefaults]]", +) + + class PythonWorkflowDefinitionBaseNode(BaseModel): """Base model for all node types, containing common fields.""" @@ -33,7 +50,7 @@ class PythonWorkflowDefinitionInputNode(PythonWorkflowDefinitionBaseNode): type: Literal["input"] name: str - value: Optional[Any] = None + value: Optional[AllowableDefaults] = None class PythonWorkflowDefinitionOutputNode(PythonWorkflowDefinitionBaseNode): diff --git a/tests/test_models.py b/tests/test_models.py index 83f6066..d353f46 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,6 +4,7 @@ from unittest import mock from pydantic import ValidationError from python_workflow_definition.models import ( + JsonPrimitive, PythonWorkflowDefinitionInputNode, PythonWorkflowDefinitionOutputNode, PythonWorkflowDefinitionFunctionNode, @@ -12,6 +13,11 @@ INTERNAL_DEFAULT_HANDLE, ) + +class _NoTrivialSerialization: + pass + + class TestModels(unittest.TestCase): def setUp(self): self.valid_workflow_dict = { @@ -40,6 +46,50 @@ def test_input_node(self): ) self.assertEqual(node_with_value.value, 42) + def test_input_node_valid_values(self): + good_values = ( + 1, + 1.1, + "string", + True, + None, + [1, 2], + [["recursive", "tuple"], [True, False]], + ) + for value in good_values: + with self.subTest(value=value): + model = PythonWorkflowDefinitionInputNode.model_validate( + { + "id": 0, + "type": "input", + "name": "x", + "value": value, + } + ) + self.assertEqual( + value, + PythonWorkflowDefinitionInputNode.model_validate( + model.model_dump(mode="json") + ).value + ) + + def test_input_node_invalid_value_raises(self): + bad_values = ( + {1: 2}, + _NoTrivialSerialization(), + ) + for value in bad_values: + with self.subTest(value=value): + with self.assertRaises(ValidationError): + PythonWorkflowDefinitionInputNode.model_validate( + { + "id": 0, + "type": "input", + "name": "x", + "value": value, + } + ) + def test_output_node(self): node = PythonWorkflowDefinitionOutputNode(id=1, type="output", name="test_output") self.assertEqual(node.id, 1)