diff --git a/README.md b/README.md index b2856e407..33b18b248 100644 --- a/README.md +++ b/README.md @@ -42,8 +42,8 @@ DataJoint is a framework for scientific data pipelines based on the **Relational Citation - - DOI + + DOI Coverage @@ -80,7 +80,7 @@ conda install -c conda-forge datajoint - [How-To Guides](https://docs.datajoint.com/how-to/) — Task-oriented guides - [API Reference](https://docs.datajoint.com/api/) — Complete API documentation - [Migration Guide](https://docs.datajoint.com/how-to/migrate-to-v20/) — Upgrade from legacy versions -- **[DataJoint Elements](https://datajoint.com/docs/elements/)** — Example pipelines for neuroscience +- **[DataJoint Elements](https://docs.datajoint.com/elements/)** — Example pipelines for neuroscience - **[GitHub Discussions](https://github.com/datajoint/datajoint-python/discussions)** — Community support ## Contributing diff --git a/docs/design/thread-safe-mode.md b/docs/design/thread-safe-mode.md new file mode 100644 index 000000000..5d7472667 --- /dev/null +++ b/docs/design/thread-safe-mode.md @@ -0,0 +1,387 @@ +# Thread-Safe Mode Specification + +## Problem + +DataJoint uses global state (`dj.config`, `dj.conn()`) that is not thread-safe. Multi-tenant applications (web servers, async workers) need isolated connections per request/task. + +## Solution + +Introduce **Instance** objects that encapsulate config and connection. The `dj` module provides a global config that can be modified before connecting, and a lazily-loaded singleton connection. New isolated instances are created with `dj.Instance()`. + +## API + +### Legacy API (global config + singleton connection) + +```python +import datajoint as dj + +# Configure credentials (no connection yet) +dj.config.database.user = "user" +dj.config.database.password = "password" + +# First call to conn() or Schema() creates the singleton connection +dj.conn() # Creates connection using dj.config credentials +schema = dj.Schema("my_schema") + +@schema +class Mouse(dj.Manual): + definition = "..." +``` + +Alternatively, pass credentials directly to `conn()`: +```python +dj.conn(host="localhost", user="user", password="password") +``` + +Internally: +- `dj.config` → delegates to `_global_config` (with thread-safety check) +- `dj.conn()` → returns `_singleton_connection` (created lazily) +- `dj.Schema()` → uses `_singleton_connection` +- `dj.FreeTable()` → uses `_singleton_connection` + +### New API (isolated instance) + +```python +import datajoint as dj + +inst = dj.Instance( + host="localhost", + user="user", + password="password", +) +schema = inst.Schema("my_schema") + +@schema +class Mouse(dj.Manual): + definition = "..." +``` + +### Instance structure + +Each instance has: +- `inst.config` - Config (created fresh at instance creation) +- `inst.connection` - Connection (created at instance creation) +- `inst.Schema()` - Schema factory using instance's connection +- `inst.FreeTable()` - FreeTable factory using instance's connection + +```python +inst = dj.Instance(host="localhost", user="u", password="p") +inst.config # Config instance +inst.connection # Connection instance +inst.Schema("name") # Creates schema using inst.connection +inst.FreeTable("db.tbl") # Access table using inst.connection +``` + +### Table base classes vs instance methods + +**Base classes** (`dj.Manual`, `dj.Lookup`, etc.) - Used with `@schema` decorator: +```python +@schema +class Mouse(dj.Manual): # dj.Manual - schema links to connection + definition = "..." 
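+    # no explicit connection here: the class reaches the database through the schema it is bound to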
+``` + +**Instance methods** (`inst.Schema()`, `inst.FreeTable()`) - Need connection directly: +```python +schema = inst.Schema("my_schema") # Uses inst.connection +table = inst.FreeTable("db.table") # Uses inst.connection +``` + +### Thread-safe mode + +```bash +export DJ_THREAD_SAFE=true +``` + +`thread_safe` is checked dynamically on each access to global state. + +When `thread_safe=True`, accessing global state raises `ThreadSafetyError`: +- `dj.config` raises `ThreadSafetyError` +- `dj.conn()` raises `ThreadSafetyError` +- `dj.Schema()` raises `ThreadSafetyError` (without explicit connection) +- `dj.FreeTable()` raises `ThreadSafetyError` (without explicit connection) +- `dj.Instance()` works - isolated instances are always allowed + +```python +# thread_safe=True + +dj.config # ThreadSafetyError +dj.conn() # ThreadSafetyError +dj.Schema("name") # ThreadSafetyError + +inst = dj.Instance(host="h", user="u", password="p") # OK +inst.Schema("name") # OK +``` + +## Behavior Summary + +| Operation | `thread_safe=False` | `thread_safe=True` | +|-----------|--------------------|--------------------| +| `dj.config` | `_global_config` | `ThreadSafetyError` | +| `dj.conn()` | `_singleton_connection` | `ThreadSafetyError` | +| `dj.Schema()` | Uses singleton | `ThreadSafetyError` | +| `dj.FreeTable()` | Uses singleton | `ThreadSafetyError` | +| `dj.Instance()` | Works | Works | +| `inst.config` | Works | Works | +| `inst.connection` | Works | Works | +| `inst.Schema()` | Works | Works | + +## Lazy Loading + +The global config is created at module import time. The singleton connection is created lazily on first access: + +```python +dj.config.database.user = "user" # Modifies global config (no connection yet) +dj.config.database.password = "pw" +dj.conn() # Creates singleton connection using global config +dj.Schema("name") # Uses existing singleton connection +``` + +## Usage Example + +```python +import datajoint as dj + +# Create isolated instance +inst = dj.Instance( + host="localhost", + user="user", + password="password", +) + +# Create schema +schema = inst.Schema("my_schema") + +@schema +class Mouse(dj.Manual): + definition = """ + mouse_id: int + """ + +# Use tables +Mouse().insert1({"mouse_id": 1}) +Mouse().fetch() +``` + +## Architecture + +### Object graph + +There is exactly **one** global `Config` object created at import time in `settings.py`. Both the legacy API and the `Instance` API hang off `Connection` objects, each of which carries a `_config` reference. + +``` +settings.py + config = _create_config() ← THE single global Config + +instance.py + _global_config = settings.config ← same object (not a copy) + _singleton_connection = None ← lazily created Connection + +__init__.py + dj.config = _ConfigProxy() ← proxy → _global_config (with thread-safety check) + dj.conn() ← returns _singleton_connection + dj.Schema() ← uses _singleton_connection + dj.FreeTable() ← uses _singleton_connection + +Connection (singleton) + _config → _global_config ← same Config that dj.config writes to + +Connection (Instance) + _config → fresh Config ← isolated per-instance +``` + +### Config flow: singleton path + +``` +dj.config["safemode"] = False + ↓ _ConfigProxy.__setitem__ +_global_config["safemode"] = False (same object as settings.config) + ↓ +Connection._config["safemode"] (points to _global_config) + ↓ +schema.drop() reads self.connection._config["safemode"] → False ✓ +``` + +### Config flow: Instance path + +``` +inst = dj.Instance(host=..., user=..., password=...) 
+ ↓ +inst.config = _create_config() (fresh Config, independent) +inst.connection._config = inst.config + ↓ +inst.config["safemode"] = False + ↓ +schema.drop() reads self.connection._config["safemode"] → False ✓ +``` + +### Key invariant + +**All runtime config reads go through `self.connection._config`**, never through the global `config` directly. This ensures both the singleton and Instance paths read the correct config. + +### Connection-scoped config reads + +Every module that previously imported `from .settings import config` now reads config from the connection: + +| Module | What was read | How it's read now | +|--------|--------------|-------------------| +| `schemas.py` | `config["safemode"]`, `config.database.create_tables` | `self.connection._config[...]` | +| `table.py` | `config["safemode"]` in `delete()`, `drop()` | `self.connection._config["safemode"]` | +| `expression.py` | `config["loglevel"]` in `__repr__()` | `self.connection._config["loglevel"]` | +| `preview.py` | `config["display.*"]` (8 reads) | `query_expression.connection._config[...]` | +| `autopopulate.py` | `config.jobs.allow_new_pk_fields`, `auto_refresh` | `self.connection._config.jobs.*` | +| `jobs.py` | `config.jobs.default_priority`, `stale_timeout`, `keep_completed` | `self.connection._config.jobs.*` | +| `declare.py` | `config.jobs.add_job_metadata` | `config` param (threaded from `table.py`) | +| `diagram.py` | `config.display.diagram_direction` | `self._connection._config.display.*` | +| `staged_insert.py` | `config.get_store_spec()` | `self._table.connection._config.get_store_spec()` | +| `hash_registry.py` | `config.get_store_spec()` in 5 functions | `config` kwarg (falls back to `settings.config`) | +| `builtin_codecs/hash.py` | `config` via hash_registry | `_config` from key dict → `config` kwarg to hash_registry | +| `builtin_codecs/attach.py` | `config.get("download_path")` | `_config` from key dict (falls back to `settings.config`) | +| `builtin_codecs/filepath.py` | `config.get_store_spec()` | `_config` from key dict (falls back to `settings.config`) | +| `builtin_codecs/schema.py` | `config.get_store_spec()` in helpers | `config` kwarg to `_build_path()`, `_get_backend()` | +| `builtin_codecs/npy.py` | `config` via schema helpers | `_config` from key dict → `config` kwarg to helpers | +| `builtin_codecs/object.py` | `config` via schema helpers | `_config` from key dict → `config` kwarg to helpers | +| `gc.py` | `config` via hash_registry | `schemas[0].connection._config` → `config` kwarg | + +### Functions that receive config as a parameter + +Some module-level functions cannot access `self.connection`. 
Config is threaded through: + +| Function | Caller | How config arrives | +|----------|--------|--------------------| +| `declare()` in `declare.py` | `Table.declare()` in `table.py` | `config=self.connection._config` kwarg | +| `_get_job_version()` in `jobs.py` | `AutoPopulate._make_tuples()`, `Job.reserve()` | `config=self.connection._config` positional arg | +| `get_store_backend()` in `hash_registry.py` | codecs, gc.py | `config` kwarg from key dict or schema connection | +| `get_store_subfolding()` in `hash_registry.py` | `put_hash()` | `config` kwarg chained from caller | +| `put_hash()` in `hash_registry.py` | `HashCodec.encode()` | `config` kwarg from `_config` in key dict | +| `get_hash()` in `hash_registry.py` | `HashCodec.decode()` | `config` kwarg from `_config` in key dict | +| `delete_path()` in `hash_registry.py` | `gc.collect()` | `config` kwarg from `schemas[0].connection._config` | +| `decode_attribute()` in `codecs.py` | `expression.py` fetch methods | `connection` kwarg → extracts `connection._config` | + +All functions accept `config=None` and fall back to the global `settings.config` for backward compatibility. + +## Implementation + +### 1. Create Instance class + +```python +class Instance: + def __init__(self, host, user, password, port=3306, **kwargs): + self.config = _create_config() # Fresh config with defaults + # Apply any config overrides from kwargs + self.connection = Connection(host, user, password, port, ...) + self.connection._config = self.config + + def Schema(self, name, **kwargs): + return Schema(name, connection=self.connection, **kwargs) + + def FreeTable(self, full_table_name): + return FreeTable(self.connection, full_table_name) +``` + +### 2. Global config and singleton connection + +```python +# settings.py - THE single global config +config = _create_config() # Created at import time + +# instance.py - reuses the same config object +_global_config = settings.config # Same reference, not a copy +_singleton_connection = None # Created lazily + +def _check_thread_safe(): + if _load_thread_safe(): + raise ThreadSafetyError( + "Global DataJoint state is disabled in thread-safe mode. " + "Use dj.Instance() to create an isolated instance." + ) + +def _get_singleton_connection(): + _check_thread_safe() + global _singleton_connection + if _singleton_connection is None: + _singleton_connection = Connection( + host=_global_config.database.host, + user=_global_config.database.user, + password=_global_config.database.password, + ... + ) + _singleton_connection._config = _global_config + return _singleton_connection +``` + +### 3. Legacy API with thread-safety checks + +```python +# dj.config -> global config with thread-safety check +class _ConfigProxy: + def __getattr__(self, name): + _check_thread_safe() + return getattr(_global_config, name) + def __setattr__(self, name, value): + _check_thread_safe() + setattr(_global_config, name, value) + +config = _ConfigProxy() + +# dj.conn() -> singleton connection (persistent across calls) +def conn(host=None, user=None, password=None, *, reset=False): + _check_thread_safe() + if reset or (_singleton_connection is None and credentials_provided): + _singleton_connection = Connection(...) 
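+        # point the new connection's _config at the shared global config, so later runtime reads (e.g. safemode) see dj.config updates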
+ _singleton_connection._config = _global_config + return _get_singleton_connection() + +# dj.Schema() -> uses singleton connection +def Schema(name, connection=None, **kwargs): + if connection is None: + _check_thread_safe() + connection = _get_singleton_connection() + return _Schema(name, connection=connection, **kwargs) + +# dj.FreeTable() -> uses singleton connection +def FreeTable(conn_or_name, full_table_name=None): + if full_table_name is None: + _check_thread_safe() + return _FreeTable(_get_singleton_connection(), conn_or_name) + else: + return _FreeTable(conn_or_name, full_table_name) +``` + +## Global State Audit + +All module-level mutable state was reviewed for thread-safety implications. + +### Guarded (blocked in thread-safe mode) + +| State | Location | Mechanism | +|-------|----------|-----------| +| `config` singleton | `settings.py:979` | `_ConfigProxy` raises `ThreadSafetyError`; use `inst.config` instead | +| `conn()` singleton | `connection.py:108` | `_check_thread_safe()` guard; use `inst.connection` instead | + +These are the two globals that carry connection-scoped state (credentials, database settings) and are the primary source of cross-tenant interference. + +### Safe by design (no guard needed) + +| State | Location | Rationale | +|-------|----------|-----------| +| `_codec_registry` | `codecs.py:47` | Effectively immutable after import. Registration runs in `__init_subclass__` under Python's import lock. Runtime mutation (`_load_entry_points`) is idempotent under the GIL. Codecs are part of the type system, not connection-scoped. | +| `_entry_points_loaded` | `codecs.py:48` | Bool flag for idempotent lazy loading; worst case under concurrent access is redundant work, not corruption. | + +### Low risk (no guard needed) + +| State | Location | Rationale | +|-------|----------|-----------| +| Logging side effects | `logging.py:8,17,40-45,56` | Standard Python logging configuration. Monkey-patches `Logger` and replaces `sys.excepthook` at import time. Not DataJoint-specific mutable state. | +| `use_32bit_dims` | `blob.py:65` | Runtime flag affecting deserialization. Rarely changed; not connection-scoped. | +| `compression` dict | `blob.py:61` | Decompressor function registry. Populated at import time, effectively read-only thereafter. | +| `_lazy_modules` | `__init__.py:92` | Import caching via `globals()` mutation. Protected by Python's import lock. | +| `ADAPTERS` dict | `adapters/__init__.py:16` | Backend registry. Populated at import time, read-only in practice. | + +### Design principle + +Only state that is **connection-scoped** (credentials, database settings, connection objects) needs thread-safe guards. State that is **code-scoped** (type registries, import caches, logging configuration) is shared across all threads by design and does not vary between tenants. + +## Error Messages + +- Singleton access: `"Global DataJoint state is disabled in thread-safe mode. 
Use dj.Instance() to create an isolated instance."` diff --git a/src/datajoint/__init__.py b/src/datajoint/__init__.py index 7f809487d..7704ec1bc 100644 --- a/src/datajoint/__init__.py +++ b/src/datajoint/__init__.py @@ -23,6 +23,7 @@ "config", "conn", "Connection", + "Instance", "Schema", "VirtualModule", "virtual_schema", @@ -52,6 +53,7 @@ "errors", "migrate", "DataJointError", + "ThreadSafetyError", "logger", "cli", "ValidationResult", @@ -72,17 +74,191 @@ NpyRef, ) from .blob import MatCell, MatStruct -from .connection import Connection, conn -from .errors import DataJointError +from .connection import Connection +from .errors import DataJointError, ThreadSafetyError from .expression import AndList, Not, Top, U +from .instance import Instance, _ConfigProxy, _get_singleton_connection, _global_config, _check_thread_safe from .logging import logger from .objectref import ObjectRef -from .schemas import Schema, VirtualModule, list_schemas, virtual_schema -from .settings import config -from .table import FreeTable, Table, ValidationResult +from .schemas import _Schema, VirtualModule, list_schemas, virtual_schema +from .table import FreeTable as _FreeTable, Table, ValidationResult from .user_tables import Computed, Imported, Lookup, Manual, Part from .version import __version__ +# ============================================================================= +# Singleton-aware API +# ============================================================================= +# config is a proxy that delegates to the singleton instance's config +config = _ConfigProxy() + + +def conn( + host: str | None = None, + user: str | None = None, + password: str | None = None, + *, + reset: bool = False, + use_tls: bool | dict | None = None, +) -> Connection: + """ + Return a persistent connection object. + + When called without arguments, returns the singleton connection using + credentials from dj.config. When connection parameters are provided, + updates the singleton connection with the new credentials. + + Parameters + ---------- + host : str, optional + Database hostname. If provided, updates singleton. + user : str, optional + Database username. If provided, updates singleton. + password : str, optional + Database password. If provided, updates singleton. + reset : bool, optional + If True, reset existing connection. Default False. + use_tls : bool or dict, optional + TLS encryption option. + + Returns + ------- + Connection + Database connection. + + Raises + ------ + ThreadSafetyError + If thread_safe mode is enabled. 
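+
+    Examples
+    --------
+    Minimal sketch, assuming database credentials are set via ``dj.config``:
+
+    >>> dj.config.database.user = "user"
+    >>> dj.config.database.password = "password"
+    >>> connection = dj.conn()  # first call creates the singleton connection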
+ """ + import datajoint.instance as instance_module + from pydantic import SecretStr + + _check_thread_safe() + + # If reset requested, always recreate + # If credentials provided and no singleton exists, create one + # If credentials provided and singleton exists, return existing singleton + if reset or ( + instance_module._singleton_connection is None and (host is not None or user is not None or password is not None) + ): + # Use provided values or fall back to config + host = host if host is not None else _global_config.database.host + user = user if user is not None else _global_config.database.user + raw_password = password if password is not None else _global_config.database.password + password = raw_password.get_secret_value() if isinstance(raw_password, SecretStr) else raw_password + port = _global_config.database.port + use_tls = use_tls if use_tls is not None else _global_config.database.use_tls + + if user is None: + from .errors import DataJointError + + raise DataJointError("Database user not configured. Set dj.config['database.user'] or pass user= argument.") + if password is None: + from .errors import DataJointError + + raise DataJointError( + "Database password not configured. Set dj.config['database.password'] or pass password= argument." + ) + + instance_module._singleton_connection = Connection(host, user, password, port, use_tls, config_override=_global_config) + + return _get_singleton_connection() + + +class Schema(_Schema): + """ + Decorator that binds table classes to a database schema. + + When connection is not provided, uses the singleton connection. + In thread-safe mode (``DJ_THREAD_SAFE=true``), a connection must be + provided explicitly or use ``dj.Instance().Schema()`` instead. + + Parameters + ---------- + schema_name : str, optional + Database schema name. If omitted, call ``activate()`` later. + context : dict, optional + Namespace for foreign key lookup. None uses caller's context. + connection : Connection, optional + Database connection. Defaults to singleton connection. + create_schema : bool, optional + If False, raise error if schema doesn't exist. Default True. + create_tables : bool, optional + If False, raise error when accessing missing tables. + add_objects : dict, optional + Additional objects for declaration context. + + Raises + ------ + ThreadSafetyError + If thread_safe mode is enabled and no connection is provided. + + Examples + -------- + >>> schema = dj.Schema('my_schema') + >>> @schema + ... class Session(dj.Manual): + ... definition = ''' + ... session_id : int + ... ''' + """ + + def __init__( + self, + schema_name: str | None = None, + context: dict | None = None, + *, + connection: Connection | None = None, + create_schema: bool = True, + create_tables: bool | None = None, + add_objects: dict | None = None, + ) -> None: + if connection is None: + _check_thread_safe() + super().__init__( + schema_name, + context=context, + connection=connection, + create_schema=create_schema, + create_tables=create_tables, + add_objects=add_objects, + ) + + +def FreeTable(conn_or_name, full_table_name: str | None = None) -> _FreeTable: + """ + Create a FreeTable for accessing a table without a dedicated class. + + Can be called in two ways: + - ``FreeTable("schema.table")`` - uses singleton connection + - ``FreeTable(connection, "schema.table")`` - uses provided connection + + Parameters + ---------- + conn_or_name : Connection or str + Either a Connection object, or the full table name if using singleton. 
+ full_table_name : str, optional + Full table name when first argument is a connection. + + Returns + ------- + FreeTable + A FreeTable instance for the specified table. + + Raises + ------ + ThreadSafetyError + If thread_safe mode is enabled and using singleton. + """ + if full_table_name is None: + # Called as FreeTable("db.table") - use singleton connection + _check_thread_safe() + return _FreeTable(_get_singleton_connection(), conn_or_name) + else: + # Called as FreeTable(conn, "db.table") - use provided connection + return _FreeTable(conn_or_name, full_table_name) + + # ============================================================================= # Lazy imports — heavy dependencies loaded on first access # ============================================================================= diff --git a/src/datajoint/adapters/base.py b/src/datajoint/adapters/base.py index 35b32ed5f..011f306ab 100644 --- a/src/datajoint/adapters/base.py +++ b/src/datajoint/adapters/base.py @@ -169,6 +169,27 @@ def quote_identifier(self, name: str) -> str: """ ... + @abstractmethod + def split_full_table_name(self, full_table_name: str) -> tuple[str, str]: + """ + Split a fully-qualified table name into schema and table components. + + Inverse of quoting: strips backend-specific identifier quotes + and splits into (schema, table). + + Parameters + ---------- + full_table_name : str + Quoted full table name (e.g., ```\\`schema\\`.\\`table\\` ``` or + ``"schema"."table"``). + + Returns + ------- + tuple[str, str] + (schema_name, table_name) with quotes stripped. + """ + ... + @abstractmethod def quote_string(self, value: str) -> str: """ @@ -615,6 +636,23 @@ def list_schemas_sql(self) -> str: """ ... + @abstractmethod + def schema_exists_sql(self, schema_name: str) -> str: + """ + Generate query to check if a schema exists. + + Parameters + ---------- + schema_name : str + Name of schema to check. + + Returns + ------- + str + SQL query that returns a row if the schema exists. + """ + ... + @abstractmethod def list_tables_sql(self, schema_name: str, pattern: str | None = None) -> str: """ @@ -710,6 +748,55 @@ def get_foreign_keys_sql(self, schema_name: str, table_name: str) -> str: """ ... + @abstractmethod + def load_primary_keys_sql(self, schemas_list: str, like_pattern: str) -> str: + """ + Generate query to load primary key columns for all tables across schemas. + + Used by the dependency graph to build the schema graph. + + Parameters + ---------- + schemas_list : str + Comma-separated, quoted schema names for an IN clause. + like_pattern : str + SQL LIKE pattern to exclude (e.g., "'~%%'" for internal tables). + + Returns + ------- + str + SQL query returning rows with columns: + - tab: fully qualified table name (quoted) + - column_name: primary key column name + """ + ... + + @abstractmethod + def load_foreign_keys_sql(self, schemas_list: str, like_pattern: str) -> str: + """ + Generate query to load foreign key relationships across schemas. + + Used by the dependency graph to build the schema graph. + + Parameters + ---------- + schemas_list : str + Comma-separated, quoted schema names for an IN clause. + like_pattern : str + SQL LIKE pattern to exclude (e.g., "'~%%'" for internal tables). 
+ + Returns + ------- + str + SQL query returning rows (as dicts) with columns: + - constraint_name: FK constraint name + - referencing_table: fully qualified child table name (quoted) + - referenced_table: fully qualified parent table name (quoted) + - column_name: FK column in child table + - referenced_column_name: referenced column in parent table + """ + ... + @abstractmethod def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: """ diff --git a/src/datajoint/adapters/mysql.py b/src/datajoint/adapters/mysql.py index 88339335f..3c28a85e6 100644 --- a/src/datajoint/adapters/mysql.py +++ b/src/datajoint/adapters/mysql.py @@ -75,7 +75,6 @@ def connect( Password for authentication. **kwargs : Any Additional MySQL-specific parameters: - - init_command: SQL initialization command - ssl: TLS/SSL configuration dict (deprecated, use use_tls) - use_tls: bool or dict - DataJoint's SSL parameter (preferred) - charset: Character set (default from kwargs) @@ -85,7 +84,6 @@ def connect( pymysql.Connection MySQL connection object. """ - init_command = kwargs.get("init_command") # Handle both ssl (old) and use_tls (new) parameter names ssl_config = kwargs.get("use_tls", kwargs.get("ssl")) # Convert boolean True to dict for PyMySQL (PyMySQL expects dict or SSLContext) @@ -99,7 +97,6 @@ def connect( "port": port, "user": user, "passwd": password, - "init_command": init_command, "sql_mode": "NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO," "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY", "charset": charset, @@ -203,6 +200,11 @@ def quote_identifier(self, name: str) -> str: """ return f"`{name}`" + def split_full_table_name(self, full_table_name: str) -> tuple[str, str]: + """Split ```\\`schema\\`.\\`table\\` ``` into ``('schema', 'table')``.""" + schema, table = full_table_name.replace("`", "").split(".") + return schema, table + def quote_string(self, value: str) -> str: """ Quote string literal for MySQL with escaping. 
@@ -614,6 +616,10 @@ def list_schemas_sql(self) -> str: """Query to list all databases in MySQL.""" return "SELECT schema_name FROM information_schema.schemata" + def schema_exists_sql(self, schema_name: str) -> str: + """Query to check if a database exists in MySQL.""" + return f"SELECT schema_name FROM information_schema.schemata WHERE schema_name = {self.quote_string(schema_name)}" + def list_tables_sql(self, schema_name: str, pattern: str | None = None) -> str: """Query to list tables in a database.""" sql = f"SHOW TABLES IN {self.quote_identifier(schema_name)}" @@ -655,6 +661,32 @@ def get_foreign_keys_sql(self, schema_name: str, table_name: str) -> str: f"ORDER BY constraint_name, ordinal_position" ) + def load_primary_keys_sql(self, schemas_list: str, like_pattern: str) -> str: + """Query to load all primary key columns across schemas.""" + tab_expr = "concat('`', table_schema, '`.`', table_name, '`')" + return ( + f"SELECT {tab_expr} as tab, column_name " + f"FROM information_schema.key_column_usage " + f"WHERE table_name NOT LIKE {like_pattern} " + f"AND table_schema in ({schemas_list}) " + f"AND constraint_name='PRIMARY'" + ) + + def load_foreign_keys_sql(self, schemas_list: str, like_pattern: str) -> str: + """Query to load all foreign key relationships across schemas.""" + tab_expr = "concat('`', table_schema, '`.`', table_name, '`')" + ref_tab_expr = "concat('`', referenced_table_schema, '`.`', referenced_table_name, '`')" + return ( + f"SELECT constraint_name, " + f"{tab_expr} as referencing_table, " + f"{ref_tab_expr} as referenced_table, " + f"column_name, referenced_column_name " + f"FROM information_schema.key_column_usage " + f"WHERE referenced_table_name NOT LIKE {like_pattern} " + f"AND (referenced_table_schema in ({schemas_list}) " + f"OR referenced_table_schema is not NULL AND table_schema in ({schemas_list}))" + ) + def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: """Query to get FK constraint details from information_schema.""" return ( diff --git a/src/datajoint/adapters/postgres.py b/src/datajoint/adapters/postgres.py index 12fecae6a..2caebef75 100644 --- a/src/datajoint/adapters/postgres.py +++ b/src/datajoint/adapters/postgres.py @@ -249,6 +249,11 @@ def quote_identifier(self, name: str) -> str: """ return f'"{name}"' + def split_full_table_name(self, full_table_name: str) -> tuple[str, str]: + """Split ``"schema"."table"`` into ``('schema', 'table')``.""" + schema, table = full_table_name.replace('"', "").split(".") + return schema, table + def quote_string(self, value: str) -> str: """ Quote string literal for PostgreSQL with escaping. 
@@ -721,6 +726,10 @@ def list_schemas_sql(self) -> str: "WHERE schema_name NOT IN ('pg_catalog', 'information_schema')" ) + def schema_exists_sql(self, schema_name: str) -> str: + """Query to check if a schema exists in PostgreSQL.""" + return f"SELECT schema_name FROM information_schema.schemata WHERE schema_name = {self.quote_string(schema_name)}" + def list_tables_sql(self, schema_name: str, pattern: str | None = None) -> str: """Query to list tables in a schema.""" sql = ( @@ -795,6 +804,44 @@ def get_foreign_keys_sql(self, schema_name: str, table_name: str) -> str: f"ORDER BY kcu.constraint_name, kcu.ordinal_position" ) + def load_primary_keys_sql(self, schemas_list: str, like_pattern: str) -> str: + """Query to load all primary key columns across schemas.""" + tab_expr = "'\"' || kcu.table_schema || '\".\"' || kcu.table_name || '\"'" + return ( + f"SELECT {tab_expr} as tab, kcu.column_name " + f"FROM information_schema.key_column_usage kcu " + f"JOIN information_schema.table_constraints tc " + f"ON kcu.constraint_name = tc.constraint_name " + f"AND kcu.table_schema = tc.table_schema " + f"WHERE kcu.table_name NOT LIKE {like_pattern} " + f"AND kcu.table_schema in ({schemas_list}) " + f"AND tc.constraint_type = 'PRIMARY KEY'" + ) + + def load_foreign_keys_sql(self, schemas_list: str, like_pattern: str) -> str: + """Query to load all foreign key relationships across schemas.""" + return ( + f"SELECT " + f"c.conname as constraint_name, " + f"'\"' || ns1.nspname || '\".\"' || cl1.relname || '\"' as referencing_table, " + f"'\"' || ns2.nspname || '\".\"' || cl2.relname || '\"' as referenced_table, " + f"a1.attname as column_name, " + f"a2.attname as referenced_column_name " + f"FROM pg_constraint c " + f"JOIN pg_class cl1 ON c.conrelid = cl1.oid " + f"JOIN pg_namespace ns1 ON cl1.relnamespace = ns1.oid " + f"JOIN pg_class cl2 ON c.confrelid = cl2.oid " + f"JOIN pg_namespace ns2 ON cl2.relnamespace = ns2.oid " + f"CROSS JOIN LATERAL unnest(c.conkey, c.confkey) WITH ORDINALITY AS cols(conkey, confkey, ord) " + f"JOIN pg_attribute a1 ON a1.attrelid = cl1.oid AND a1.attnum = cols.conkey " + f"JOIN pg_attribute a2 ON a2.attrelid = cl2.oid AND a2.attnum = cols.confkey " + f"WHERE c.contype = 'f' " + f"AND cl1.relname NOT LIKE {like_pattern} " + f"AND (ns2.nspname in ({schemas_list}) " + f"OR ns1.nspname in ({schemas_list})) " + f"ORDER BY c.conname, cols.ord" + ) + def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: """ Query to get FK constraint details from information_schema. diff --git a/src/datajoint/autopopulate.py b/src/datajoint/autopopulate.py index 7660e43ec..ae8be3b82 100644 --- a/src/datajoint/autopopulate.py +++ b/src/datajoint/autopopulate.py @@ -146,10 +146,8 @@ def _declare_check(self, primary_key: list[str], fk_attribute_map: dict[str, tup If native (non-FK) PK attributes are found, unless bypassed via ``dj.config.jobs.allow_new_pk_fields_in_computed_tables = True``. 
""" - from .settings import config - # Check if validation is bypassed - if config.jobs.allow_new_pk_fields_in_computed_tables: + if self.connection._config.jobs.allow_new_pk_fields_in_computed_tables: return # Check for native (non-FK) primary key attributes @@ -477,8 +475,6 @@ def _populate_distributed( """ from tqdm import tqdm - from .settings import config - # Define a signal handler for SIGTERM def handler(signum, frame): logger.info("Populate terminated by SIGTERM") @@ -489,7 +485,7 @@ def handler(signum, frame): try: # Refresh job queue if configured if refresh is None: - refresh = config.jobs.auto_refresh + refresh = self.connection._config.jobs.auto_refresh if refresh: # Use delay=-1 to ensure jobs are immediately schedulable # (avoids race condition with scheduled_time <= CURRENT_TIMESTAMP(3) check) @@ -659,7 +655,7 @@ def _populate1( key, start_time=datetime.datetime.fromtimestamp(start_time), duration=duration, - version=_get_job_version(), + version=_get_job_version(self.connection._config), ) if jobs is not None: diff --git a/src/datajoint/builtin_codecs/attach.py b/src/datajoint/builtin_codecs/attach.py index f9a454b1a..aa10f2424 100644 --- a/src/datajoint/builtin_codecs/attach.py +++ b/src/datajoint/builtin_codecs/attach.py @@ -98,14 +98,15 @@ def decode(self, stored: bytes, *, key: dict | None = None) -> str: """ from pathlib import Path - from ..settings import config - # Split on first null byte null_pos = stored.index(b"\x00") filename = stored[:null_pos].decode("utf-8") contents = stored[null_pos + 1 :] # Write to download path + config = (key or {}).get("_config") + if config is None: + from ..settings import config download_path = Path(config.get("download_path", ".")) download_path.mkdir(parents=True, exist_ok=True) local_path = download_path / filename diff --git a/src/datajoint/builtin_codecs/filepath.py b/src/datajoint/builtin_codecs/filepath.py index 9c05b2385..a0400499b 100644 --- a/src/datajoint/builtin_codecs/filepath.py +++ b/src/datajoint/builtin_codecs/filepath.py @@ -98,9 +98,12 @@ def encode(self, value: Any, *, key: dict | None = None, store_name: str | None """ from datetime import datetime, timezone - from .. import config from ..hash_registry import get_store_backend + config = (key or {}).get("_config") + if config is None: + from ..settings import config + path = str(value) # Get store spec to check prefix configuration @@ -137,7 +140,7 @@ def encode(self, value: Any, *, key: dict | None = None, store_name: str | None raise ValueError(f" must use prefix '{filepath_prefix}' (filepath_prefix). 
Got path: {path}") # Verify file exists - backend = get_store_backend(store_name) + backend = get_store_backend(store_name, config=config) if not backend.exists(path): raise FileNotFoundError(f"File not found in store '{store_name or 'default'}': {path}") @@ -174,8 +177,9 @@ def decode(self, stored: dict, *, key: dict | None = None) -> Any: from ..objectref import ObjectRef from ..hash_registry import get_store_backend + config = (key or {}).get("_config") store_name = stored.get("store") - backend = get_store_backend(store_name) + backend = get_store_backend(store_name, config=config) return ObjectRef.from_json(stored, backend=backend) def validate(self, value: Any) -> None: diff --git a/src/datajoint/builtin_codecs/hash.py b/src/datajoint/builtin_codecs/hash.py index 676c1916f..bb3a3852f 100644 --- a/src/datajoint/builtin_codecs/hash.py +++ b/src/datajoint/builtin_codecs/hash.py @@ -76,7 +76,8 @@ def encode(self, value: bytes, *, key: dict | None = None, store_name: str | Non from ..hash_registry import put_hash schema_name = (key or {}).get("_schema", "unknown") - return put_hash(value, schema_name=schema_name, store_name=store_name) + config = (key or {}).get("_config") + return put_hash(value, schema_name=schema_name, store_name=store_name, config=config) def decode(self, stored: dict, *, key: dict | None = None) -> bytes: """ @@ -96,7 +97,8 @@ def decode(self, stored: dict, *, key: dict | None = None) -> bytes: """ from ..hash_registry import get_hash - return get_hash(stored) + config = (key or {}).get("_config") + return get_hash(stored, config=config) def validate(self, value: Any) -> None: """Validate that value is bytes.""" diff --git a/src/datajoint/builtin_codecs/npy.py b/src/datajoint/builtin_codecs/npy.py index 51c5731ee..54853437b 100644 --- a/src/datajoint/builtin_codecs/npy.py +++ b/src/datajoint/builtin_codecs/npy.py @@ -336,9 +336,10 @@ def encode( # Extract context using inherited helper schema, table, field, primary_key = self._extract_context(key) + config = (key or {}).get("_config") # Build schema-addressed storage path - path, _ = self._build_path(schema, table, field, primary_key, ext=".npy", store_name=store_name) + path, _ = self._build_path(schema, table, field, primary_key, ext=".npy", store_name=store_name, config=config) # Serialize to .npy format buffer = io.BytesIO() @@ -346,7 +347,7 @@ def encode( npy_bytes = buffer.getvalue() # Upload to storage using inherited helper - backend = self._get_backend(store_name) + backend = self._get_backend(store_name, config=config) backend.put_buffer(npy_bytes, path) # Return metadata (includes numpy-specific shape/dtype) @@ -373,5 +374,6 @@ def decode(self, stored: dict, *, key: dict | None = None) -> NpyRef: NpyRef Lazy array reference with metadata access and numpy integration. 
""" - backend = self._get_backend(stored.get("store")) + config = (key or {}).get("_config") + backend = self._get_backend(stored.get("store"), config=config) return NpyRef(stored, backend) diff --git a/src/datajoint/builtin_codecs/object.py b/src/datajoint/builtin_codecs/object.py index 268651aea..1c0d8c673 100644 --- a/src/datajoint/builtin_codecs/object.py +++ b/src/datajoint/builtin_codecs/object.py @@ -104,6 +104,7 @@ def encode( # Extract context using inherited helper schema, table, field, primary_key = self._extract_context(key) + config = (key or {}).get("_config") # Check for pre-computed metadata (from staged insert) if isinstance(value, dict) and "path" in value: @@ -145,10 +146,10 @@ def encode( raise TypeError(f" expects bytes or path, got {type(value).__name__}") # Build storage path using inherited helper - path, token = self._build_path(schema, table, field, primary_key, ext=ext, store_name=store_name) + path, token = self._build_path(schema, table, field, primary_key, ext=ext, store_name=store_name, config=config) # Get storage backend using inherited helper - backend = self._get_backend(store_name) + backend = self._get_backend(store_name, config=config) # Upload content if is_dir: @@ -192,7 +193,8 @@ def decode(self, stored: dict, *, key: dict | None = None) -> Any: """ from ..objectref import ObjectRef - backend = self._get_backend(stored.get("store")) + config = (key or {}).get("_config") + backend = self._get_backend(stored.get("store"), config=config) return ObjectRef.from_json(stored, backend=backend) def validate(self, value: Any) -> None: diff --git a/src/datajoint/builtin_codecs/schema.py b/src/datajoint/builtin_codecs/schema.py index 18bd62d00..c8cc0759d 100644 --- a/src/datajoint/builtin_codecs/schema.py +++ b/src/datajoint/builtin_codecs/schema.py @@ -108,6 +108,7 @@ def _build_path( primary_key: dict, ext: str | None = None, store_name: str | None = None, + config=None, ) -> tuple[str, str]: """ Build schema-addressed storage path. @@ -131,6 +132,8 @@ def _build_path( File extension (e.g., ".npy", ".zarr"). store_name : str, optional Store name for retrieving partition configuration. + config : Config, optional + Config instance. If None, falls back to global settings.config. Returns ------- @@ -139,7 +142,9 @@ def _build_path( is a unique identifier. """ from ..storage import build_object_path - from .. import config + + if config is None: + from ..settings import config # Get store configuration for partition_pattern and token_length spec = config.get_store_spec(store_name) @@ -156,7 +161,7 @@ def _build_path( token_length=token_length, ) - def _get_backend(self, store_name: str | None = None): + def _get_backend(self, store_name: str | None = None, config=None): """ Get storage backend by name. @@ -164,6 +169,8 @@ def _get_backend(self, store_name: str | None = None): ---------- store_name : str, optional Store name. If None, returns default store. + config : Config, optional + Config instance. If None, falls back to global settings.config. 
Returns ------- @@ -172,4 +179,4 @@ def _get_backend(self, store_name: str | None = None): """ from ..hash_registry import get_store_backend - return get_store_backend(store_name) + return get_store_backend(store_name, config=config) diff --git a/src/datajoint/codecs.py b/src/datajoint/codecs.py index 5c192d46e..d7fbaf42d 100644 --- a/src/datajoint/codecs.py +++ b/src/datajoint/codecs.py @@ -43,7 +43,15 @@ class MyTable(dj.Manual): logger = logging.getLogger(__name__.split(".")[0]) -# Global codec registry - maps name to Codec instance +# Global codec registry - maps name to Codec instance. +# +# Thread safety: This registry is effectively immutable after import. +# Registration happens in __init_subclass__ during class definition, which is +# serialized by Python's import lock. The only runtime mutation is +# _load_entry_points(), which is idempotent and guarded by a bool flag; +# under CPython's GIL, concurrent calls may do redundant work but cannot +# corrupt the dict. Codecs are part of the type system (tied to code, not to +# any particular connection or tenant), so per-instance isolation is unnecessary. _codec_registry: dict[str, Codec] = {} _entry_points_loaded: bool = False @@ -507,7 +515,7 @@ def lookup_codec(codec_spec: str) -> tuple[Codec, str | None]: # ============================================================================= -def decode_attribute(attr, data, squeeze: bool = False): +def decode_attribute(attr, data, squeeze: bool = False, connection=None): """ Decode raw database value using attribute's codec or native type handling. @@ -520,6 +528,8 @@ def decode_attribute(attr, data, squeeze: bool = False): attr: Attribute from the table's heading. data: Raw value fetched from the database. squeeze: If True, remove singleton dimensions from numpy arrays. + connection: Connection instance for config access. If provided, + ``connection._config`` is passed to codecs via the key dict. Returns: Decoded Python value. @@ -552,9 +562,14 @@ def decode_attribute(attr, data, squeeze: bool = False): elif final_dtype.lower() == "binary(16)": data = uuid_module.UUID(bytes=data) + # Build decode key with config if connection is available + decode_key = None + if connection is not None: + decode_key = {"_config": connection._config} + # Apply decoders in reverse order: innermost first, then outermost for codec in reversed(type_chain): - data = codec.decode(data, key=None) + data = codec.decode(data, key=decode_key) # Squeeze arrays if requested if squeeze and isinstance(data, np.ndarray): diff --git a/src/datajoint/condition.py b/src/datajoint/condition.py index 0335d6adb..55f095246 100644 --- a/src/datajoint/condition.py +++ b/src/datajoint/condition.py @@ -244,6 +244,13 @@ def assert_join_compatibility( if isinstance(expr1, U) or isinstance(expr2, U): return + # Check that both expressions use the same connection + if expr1.connection is not expr2.connection: + raise DataJointError( + "Cannot operate on expressions from different connections. " + "Ensure both operands use the same dj.Instance or global connection." 
+ ) + if semantic_check: # Check if lineage tracking is available for both expressions if not expr1.heading.lineage_available or not expr2.heading.lineage_available: diff --git a/src/datajoint/connection.py b/src/datajoint/connection.py index 21b48e638..e9eab0921 100644 --- a/src/datajoint/connection.py +++ b/src/datajoint/connection.py @@ -11,13 +11,16 @@ import re import warnings from contextlib import contextmanager -from typing import Callable +from typing import TYPE_CHECKING from . import errors from .adapters import get_adapter from .blob import pack, unpack from .dependencies import Dependencies from .settings import config + +if TYPE_CHECKING: + from .settings import Config from .version import __version__ logger = logging.getLogger(__name__.split(".")[0]) @@ -55,7 +58,6 @@ def conn( user: str | None = None, password: str | None = None, *, - init_fun: Callable | None = None, reset: bool = False, use_tls: bool | dict | None = None, ) -> Connection: @@ -73,8 +75,6 @@ def conn( Database username. Required if not set in config. password : str, optional Database password. Required if not set in config. - init_fun : callable, optional - Initialization function called after connection. reset : bool, optional If True, reset existing connection. Default False. use_tls : bool or dict, optional @@ -103,9 +103,8 @@ def conn( raise errors.DataJointError( "Database password not configured. Set datajoint.config['database.password'] or pass password= argument." ) - init_fun = init_fun if init_fun is not None else config["connection.init_function"] use_tls = use_tls if use_tls is not None else config["database.use_tls"] - conn.connection = Connection(host, user, password, None, init_fun, use_tls) + conn.connection = Connection(host, user, password, None, use_tls) return conn.connection @@ -150,8 +149,6 @@ class Connection: Database password. port : int, optional Port number. Overridden if specified in host. - init_fun : str, optional - SQL initialization command. use_tls : bool or dict, optional TLS encryption option. 
@@ -169,15 +166,20 @@ def __init__( user: str, password: str, port: int | None = None, - init_fun: str | None = None, use_tls: bool | dict | None = None, + *, + backend: str | None = None, + config_override: "Config | None" = None, ) -> None: + # Config reference — use override if provided, else global config + self._config = config_override if config_override is not None else config + if ":" in host: # the port in the hostname overrides the port argument host, port = host.split(":") port = int(port) elif port is None: - port = config["database.port"] + port = self._config["database.port"] self.conn_info = dict(host=host, port=port, user=user, passwd=password) if use_tls is not False: # use_tls can be: None (auto-detect), True (enable), False (disable), or dict (custom config) @@ -190,13 +192,13 @@ def __init__( # use_tls=True: enable SSL with default settings self.conn_info["ssl"] = True self.conn_info["ssl_input"] = use_tls - self.init_fun = init_fun self._conn = None self._query_cache = None self._is_closed = True # Mark as closed until connect() succeeds - # Select adapter based on configured backend - backend = config["database.backend"] + # Select adapter: explicit backend > config backend + if backend is None: + backend = self._config["database.backend"] self.adapter = get_adapter(backend) self.connect() @@ -227,8 +229,7 @@ def connect(self) -> None: port=self.conn_info["port"], user=self.conn_info["user"], password=self.conn_info["passwd"], - init_command=self.init_fun, - charset=config["connection.charset"], + charset=self._config["connection.charset"], use_tls=self.conn_info.get("ssl"), ) except Exception as ssl_error: @@ -244,8 +245,7 @@ def connect(self) -> None: port=self.conn_info["port"], user=self.conn_info["user"], password=self.conn_info["passwd"], - init_command=self.init_fun, - charset=config["connection.charset"], + charset=self._config["connection.charset"], use_tls=False, # Explicitly disable SSL for fallback ) else: @@ -271,8 +271,8 @@ def set_query_cache(self, query_cache: str | None = None) -> None: def purge_query_cache(self) -> None: """Delete all cached query results.""" - if isinstance(config.get(cache_key), str) and pathlib.Path(config[cache_key]).is_dir(): - for path in pathlib.Path(config[cache_key]).iterdir(): + if isinstance(self._config.get(cache_key), str) and pathlib.Path(self._config[cache_key]).is_dir(): + for path in pathlib.Path(self._config[cache_key]).iterdir(): if not path.is_dir(): path.unlink() @@ -413,11 +413,11 @@ def query( if use_query_cache and not re.match(r"\s*(SELECT|SHOW)", query): raise errors.DataJointError("Only SELECT queries are allowed when query caching is on.") if use_query_cache: - if not config[cache_key]: + if not self._config[cache_key]: raise errors.DataJointError(f"Provide filepath dj.config['{cache_key}'] when using query caching.") # Cache key is backend-specific (no identifier normalization needed) hash_ = hashlib.md5((str(self._query_cache)).encode() + pack(args) + query.encode()).hexdigest() - cache_path = pathlib.Path(config[cache_key]) / str(hash_) + cache_path = pathlib.Path(self._config[cache_key]) / str(hash_) try: buffer = cache_path.read_bytes() except FileNotFoundError: @@ -426,7 +426,7 @@ def query( return EmulatedCursor(unpack(buffer)) if reconnect is None: - reconnect = config["database.reconnect"] + reconnect = self._config["database.reconnect"] logger.debug("Executing SQL:" + query[:query_log_max_length]) cursor = self.adapter.get_cursor(self._conn, as_dict=as_dict) try: diff --git 
a/src/datajoint/declare.py b/src/datajoint/declare.py index 375daa07e..6af24ae55 100644 --- a/src/datajoint/declare.py +++ b/src/datajoint/declare.py @@ -15,7 +15,6 @@ from .codecs import lookup_codec from .condition import translate_attribute from .errors import DataJointError -from .settings import config # Core DataJoint types - scientist-friendly names that are fully supported # These are recorded in field comments using :type: syntax for reconstruction @@ -295,12 +294,9 @@ def compile_foreign_key( # ref.support[0] may have cached quoting from a different backend # Extract database and table name and rebuild with current adapter parent_full_name = ref.support[0] - # Try to parse as database.table (with or without quotes) - parts = parent_full_name.replace('"', "").replace("`", "").split(".") - if len(parts) == 2: - ref_table_name = f"{adapter.quote_identifier(parts[0])}.{adapter.quote_identifier(parts[1])}" - else: - ref_table_name = adapter.quote_identifier(parts[0]) + # Parse as database.table using the adapter's quoting convention + parts = adapter.split_full_table_name(parent_full_name) + ref_table_name = f"{adapter.quote_identifier(parts[0])}.{adapter.quote_identifier(parts[1])}" foreign_key_sql.append( f"FOREIGN KEY ({fk_cols}) REFERENCES {ref_table_name} ({pk_cols}) ON UPDATE CASCADE ON DELETE RESTRICT" @@ -401,7 +397,7 @@ def prepare_declare( def declare( - full_table_name: str, definition: str, context: dict, adapter + full_table_name: str, definition: str, context: dict, adapter, *, config=None ) -> tuple[str, list[str], list[str], dict[str, tuple[str, str]], list[str], list[str]]: r""" Parse a definition and generate SQL CREATE TABLE statement. @@ -416,6 +412,8 @@ def declare( Namespace for resolving foreign key references. adapter : DatabaseAdapter Database adapter for backend-specific SQL generation. + config : Config, optional + Configuration object. If None, falls back to global config. 
Returns ------- @@ -464,6 +462,10 @@ def declare( ) = prepare_declare(definition, context, adapter) # Add hidden job metadata for Computed/Imported tables (not parts) + if config is None: + from .settings import config as _config + + config = _config if config.jobs.add_job_metadata: # Check if this is a Computed (__) or Imported (_) table, but not a Part (contains __ in middle) is_computed = table_name.startswith("__") and "__" not in table_name[2:] diff --git a/src/datajoint/dependencies.py b/src/datajoint/dependencies.py index 83162a112..99556345e 100644 --- a/src/datajoint/dependencies.py +++ b/src/datajoint/dependencies.py @@ -164,92 +164,21 @@ def load(self, force: bool = True) -> None: # Build schema list for IN clause schemas_list = ", ".join(adapter.quote_string(s) for s in self._conn.schemas) - # Backend-specific queries for primary keys and foreign keys - # Note: Both PyMySQL and psycopg2 use %s placeholders, so escape % as %% + # Load primary keys and foreign keys via adapter methods + # Note: Both PyMySQL and psycopg use %s placeholders, so escape % as %% like_pattern = "'~%%'" - if adapter.backend == "mysql": - # MySQL: use concat() and MySQL-specific information_schema columns - tab_expr = "concat('`', table_schema, '`.`', table_name, '`')" - - # load primary key info (MySQL uses constraint_name='PRIMARY') - keys = self._conn.query( - f""" - SELECT {tab_expr} as tab, column_name - FROM information_schema.key_column_usage - WHERE table_name NOT LIKE {like_pattern} - AND table_schema in ({schemas_list}) - AND constraint_name='PRIMARY' - """ - ) - pks = defaultdict(set) - for key in keys: - pks[key[0]].add(key[1]) - - # load foreign keys (MySQL has referenced_* columns) - ref_tab_expr = "concat('`', referenced_table_schema, '`.`', referenced_table_name, '`')" - fk_keys = self._conn.query( - f""" - SELECT constraint_name, - {tab_expr} as referencing_table, - {ref_tab_expr} as referenced_table, - column_name, referenced_column_name - FROM information_schema.key_column_usage - WHERE referenced_table_name NOT LIKE {like_pattern} - AND (referenced_table_schema in ({schemas_list}) - OR referenced_table_schema is not NULL AND table_schema in ({schemas_list})) - """, - as_dict=True, - ) - else: - # PostgreSQL: use || concatenation and different query structure - tab_expr = "'\"' || kcu.table_schema || '\".\"' || kcu.table_name || '\"'" - - # load primary key info (PostgreSQL uses constraint_type='PRIMARY KEY') - keys = self._conn.query( - f""" - SELECT {tab_expr} as tab, kcu.column_name - FROM information_schema.key_column_usage kcu - JOIN information_schema.table_constraints tc - ON kcu.constraint_name = tc.constraint_name - AND kcu.table_schema = tc.table_schema - WHERE kcu.table_name NOT LIKE {like_pattern} - AND kcu.table_schema in ({schemas_list}) - AND tc.constraint_type = 'PRIMARY KEY' - """ - ) - pks = defaultdict(set) - for key in keys: - pks[key[0]].add(key[1]) - - # load foreign keys using pg_constraint system catalogs - # The information_schema approach creates a Cartesian product for composite FKs - # because constraint_column_usage doesn't have ordinal_position. - # Using pg_constraint with unnest(conkey, confkey) WITH ORDINALITY gives correct mapping. 
- fk_keys = self._conn.query( - f""" - SELECT - c.conname as constraint_name, - '"' || ns1.nspname || '"."' || cl1.relname || '"' as referencing_table, - '"' || ns2.nspname || '"."' || cl2.relname || '"' as referenced_table, - a1.attname as column_name, - a2.attname as referenced_column_name - FROM pg_constraint c - JOIN pg_class cl1 ON c.conrelid = cl1.oid - JOIN pg_namespace ns1 ON cl1.relnamespace = ns1.oid - JOIN pg_class cl2 ON c.confrelid = cl2.oid - JOIN pg_namespace ns2 ON cl2.relnamespace = ns2.oid - CROSS JOIN LATERAL unnest(c.conkey, c.confkey) WITH ORDINALITY AS cols(conkey, confkey, ord) - JOIN pg_attribute a1 ON a1.attrelid = cl1.oid AND a1.attnum = cols.conkey - JOIN pg_attribute a2 ON a2.attrelid = cl2.oid AND a2.attnum = cols.confkey - WHERE c.contype = 'f' - AND cl1.relname NOT LIKE {like_pattern} - AND (ns2.nspname in ({schemas_list}) - OR ns1.nspname in ({schemas_list})) - ORDER BY c.conname, cols.ord - """, - as_dict=True, - ) + # load primary key info + keys = self._conn.query(adapter.load_primary_keys_sql(schemas_list, like_pattern)) + pks = defaultdict(set) + for key in keys: + pks[key[0]].add(key[1]) + + # load foreign keys + fk_keys = self._conn.query( + adapter.load_foreign_keys_sql(schemas_list, like_pattern), + as_dict=True, + ) # add nodes to the graph for n, pk in pks.items(): diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index 7034d122b..75e00c21c 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -16,7 +16,6 @@ from .dependencies import topo_sort from .errors import DataJointError -from .settings import config from .table import Table, lookup_class_name from .user_tables import Computed, Imported, Lookup, Manual, Part, _AliasNode, _get_tier @@ -105,6 +104,7 @@ def __init__(self, source, context=None) -> None: self.nodes_to_show = set(source.nodes_to_show) self._expanded_nodes = set(source._expanded_nodes) self.context = source.context + self._connection = source._connection super().__init__(source) return @@ -126,6 +126,7 @@ def __init__(self, source, context=None) -> None: raise DataJointError("Could not find database connection in %s" % repr(source[0])) # initialize graph from dependencies + self._connection = connection connection.dependencies.load() super().__init__(connection.dependencies) @@ -584,7 +585,7 @@ def make_dot(self): Tables are grouped by schema, with the Python module name shown as the group label when available. 
""" - direction = config.display.diagram_direction + direction = self._connection._config.display.diagram_direction graph = self._make_graph() # Apply collapse logic if needed @@ -857,7 +858,7 @@ def make_mermaid(self) -> str: Session --> Neuron """ graph = self._make_graph() - direction = config.display.diagram_direction + direction = self._connection._config.display.diagram_direction # Apply collapse logic if needed graph, collapsed_counts = self._apply_collapse(graph) diff --git a/src/datajoint/errors.py b/src/datajoint/errors.py index 7e10f021d..bba032b23 100644 --- a/src/datajoint/errors.py +++ b/src/datajoint/errors.py @@ -72,3 +72,7 @@ class MissingExternalFile(DataJointError): class BucketInaccessible(DataJointError): """S3 bucket is inaccessible.""" + + +class ThreadSafetyError(DataJointError): + """Global DataJoint state is disabled in thread-safe mode.""" diff --git a/src/datajoint/expression.py b/src/datajoint/expression.py index 883853cd3..1b5f5ac9e 100644 --- a/src/datajoint/expression.py +++ b/src/datajoint/expression.py @@ -20,7 +20,6 @@ from .errors import DataJointError from .codecs import decode_attribute from .preview import preview, repr_html -from .settings import config logger = logging.getLogger(__name__.split(".")[0]) @@ -716,7 +715,7 @@ def fetch( import warnings warnings.warn( - "fetch() is deprecated in DataJoint 2.0. " "Use to_dicts(), to_pandas(), to_arrays(), or keys() instead.", + "fetch() is deprecated in DataJoint 2.0. Use to_dicts(), to_pandas(), to_arrays(), or keys() instead.", DeprecationWarning, stacklevel=2, ) @@ -818,7 +817,10 @@ def fetch1(self, *attrs, squeeze=False): row = cursor.fetchone() if not row or cursor.fetchone(): raise DataJointError("fetch1 requires exactly one tuple in the input set.") - return {name: decode_attribute(heading[name], row[name], squeeze=squeeze) for name in heading.names} + return { + name: decode_attribute(heading[name], row[name], squeeze=squeeze, connection=self.connection) + for name in heading.names + } else: # Handle "KEY" specially - it means primary key columns def is_key(attr): @@ -893,7 +895,10 @@ def to_dicts(self, order_by=None, limit=None, offset=None, squeeze=False): expr = self._apply_top(order_by, limit, offset) cursor = expr.cursor(as_dict=True) heading = expr.heading - return [{name: decode_attribute(heading[name], row[name], squeeze) for name in heading.names} for row in cursor] + return [ + {name: decode_attribute(heading[name], row[name], squeeze, connection=expr.connection) for name in heading.names} + for row in cursor + ] def to_pandas(self, order_by=None, limit=None, offset=None, squeeze=False): """ @@ -1064,7 +1069,7 @@ def to_arrays(self, *attrs, include_key=False, order_by=None, limit=None, offset return result_arrays[0] if len(attrs) == 1 else tuple(result_arrays) else: # Fetch all columns as structured array - get = partial(decode_attribute, squeeze=squeeze) + get = partial(decode_attribute, squeeze=squeeze, connection=expr.connection) cursor = expr.cursor(as_dict=False) rows = list(cursor.fetchall()) @@ -1218,7 +1223,10 @@ def __iter__(self): cursor = self.cursor(as_dict=True) heading = self.heading for row in cursor: - yield {name: decode_attribute(heading[name], row[name], squeeze=False) for name in heading.names} + yield { + name: decode_attribute(heading[name], row[name], squeeze=False, connection=self.connection) + for name in heading.names + } def cursor(self, as_dict=False): """ @@ -1247,7 +1255,7 @@ def __repr__(self): str String representation of the QueryExpression. 
""" - return super().__repr__() if config["loglevel"].lower() == "debug" else self.preview() + return super().__repr__() if self.connection._config["loglevel"].lower() == "debug" else self.preview() def preview(self, limit=None, width=None): """ @@ -1406,8 +1414,11 @@ def create(cls, arg1, arg2): arg2 = arg2() # instantiate if a class if not isinstance(arg2, QueryExpression): raise DataJointError("A QueryExpression can only be unioned with another QueryExpression") - if arg1.connection != arg2.connection: - raise DataJointError("Cannot operate on QueryExpressions originating from different connections.") + if arg1.connection is not arg2.connection: + raise DataJointError( + "Cannot operate on expressions from different connections. " + "Ensure both operands use the same dj.Instance or global connection." + ) if set(arg1.primary_key) != set(arg2.primary_key): raise DataJointError("The operands of a union must share the same primary key.") if set(arg1.heading.secondary_attributes) & set(arg2.heading.secondary_attributes): diff --git a/src/datajoint/gc.py b/src/datajoint/gc.py index 71a4e8d08..7f083416b 100644 --- a/src/datajoint/gc.py +++ b/src/datajoint/gc.py @@ -44,7 +44,7 @@ from .errors import DataJointError if TYPE_CHECKING: - from .schemas import Schema + from .schemas import _Schema as Schema logger = logging.getLogger(__name__.split(".")[0]) @@ -308,7 +308,7 @@ def scan_schema_references( return referenced -def list_stored_hashes(store_name: str | None = None) -> dict[str, int]: +def list_stored_hashes(store_name: str | None = None, config=None) -> dict[str, int]: """ List all hash-addressed items in storage. @@ -320,6 +320,8 @@ def list_stored_hashes(store_name: str | None = None) -> dict[str, int]: ---------- store_name : str, optional Store to scan (None = default store). + config : Config, optional + Config instance. If None, falls back to global settings.config. Returns ------- @@ -328,7 +330,7 @@ def list_stored_hashes(store_name: str | None = None) -> dict[str, int]: """ import re - backend = get_store_backend(store_name) + backend = get_store_backend(store_name, config=config) stored: dict[str, int] = {} # Hash-addressed storage: _hash/{schema}/{subfolders...}/{hash} @@ -369,7 +371,7 @@ def list_stored_hashes(store_name: str | None = None) -> dict[str, int]: return stored -def list_schema_paths(store_name: str | None = None) -> dict[str, int]: +def list_schema_paths(store_name: str | None = None, config=None) -> dict[str, int]: """ List all schema-addressed items in storage. @@ -380,13 +382,15 @@ def list_schema_paths(store_name: str | None = None) -> dict[str, int]: ---------- store_name : str, optional Store to scan (None = default store). + config : Config, optional + Config instance. If None, falls back to global settings.config. Returns ------- dict[str, int] Dict mapping storage path to size in bytes. """ - backend = get_store_backend(store_name) + backend = get_store_backend(store_name, config=config) stored: dict[str, int] = {} try: @@ -427,7 +431,7 @@ def list_schema_paths(store_name: str | None = None) -> dict[str, int]: return stored -def delete_schema_path(path: str, store_name: str | None = None) -> bool: +def delete_schema_path(path: str, store_name: str | None = None, config=None) -> bool: """ Delete a schema-addressed directory from storage. @@ -437,13 +441,15 @@ def delete_schema_path(path: str, store_name: str | None = None) -> bool: Storage path (relative to store root). store_name : str, optional Store name (None = default store). 
+ config : Config, optional + Config instance. If None, falls back to global settings.config. Returns ------- bool True if deleted, False if not found. """ - backend = get_store_backend(store_name) + backend = get_store_backend(store_name, config=config) try: full_path = backend._full_path(path) @@ -497,15 +503,18 @@ def scan( if not schemas: raise DataJointError("At least one schema must be provided") + # Extract config from the first schema's connection + _config = schemas[0].connection._config if schemas else None + # --- Hash-addressed storage --- hash_referenced = scan_hash_references(*schemas, store_name=store_name, verbose=verbose) - hash_stored = list_stored_hashes(store_name) + hash_stored = list_stored_hashes(store_name, config=_config) orphaned_hashes = set(hash_stored.keys()) - hash_referenced hash_orphaned_bytes = sum(hash_stored.get(h, 0) for h in orphaned_hashes) # --- Schema-addressed storage --- schema_paths_referenced = scan_schema_references(*schemas, store_name=store_name, verbose=verbose) - schema_paths_stored = list_schema_paths(store_name) + schema_paths_stored = list_schema_paths(store_name, config=_config) orphaned_paths = set(schema_paths_stored.keys()) - schema_paths_referenced schema_paths_orphaned_bytes = sum(schema_paths_stored.get(p, 0) for p in orphaned_paths) @@ -570,6 +579,9 @@ def collect( # First scan to find orphaned items stats = scan(*schemas, store_name=store_name, verbose=verbose) + # Extract config from the first schema's connection + _config = schemas[0].connection._config if schemas else None + hash_deleted = 0 schema_paths_deleted = 0 bytes_freed = 0 @@ -578,12 +590,12 @@ def collect( if not dry_run: # Delete orphaned hashes if stats["hash_orphaned"] > 0: - hash_stored = list_stored_hashes(store_name) + hash_stored = list_stored_hashes(store_name, config=_config) for path in stats["orphaned_hashes"]: try: size = hash_stored.get(path, 0) - if delete_path(path, store_name): + if delete_path(path, store_name, config=_config): hash_deleted += 1 bytes_freed += size if verbose: @@ -594,12 +606,12 @@ def collect( # Delete orphaned schema paths if stats["schema_paths_orphaned"] > 0: - schema_paths_stored = list_schema_paths(store_name) + schema_paths_stored = list_schema_paths(store_name, config=_config) for path in stats["orphaned_paths"]: try: size = schema_paths_stored.get(path, 0) - if delete_schema_path(path, store_name): + if delete_schema_path(path, store_name, config=_config): schema_paths_deleted += 1 bytes_freed += size if verbose: diff --git a/src/datajoint/hash_registry.py b/src/datajoint/hash_registry.py index a285e5df1..331c836cd 100644 --- a/src/datajoint/hash_registry.py +++ b/src/datajoint/hash_registry.py @@ -38,7 +38,6 @@ from typing import Any from .errors import DataJointError -from .settings import config from .storage import StorageBackend logger = logging.getLogger(__name__.split(".")[0]) @@ -131,7 +130,7 @@ def build_hash_path( return f"_hash/{schema_name}/{content_hash}" -def get_store_backend(store_name: str | None = None) -> StorageBackend: +def get_store_backend(store_name: str | None = None, config=None) -> StorageBackend: """ Get a StorageBackend for hash-addressed storage. @@ -139,18 +138,22 @@ def get_store_backend(store_name: str | None = None) -> StorageBackend: ---------- store_name : str, optional Name of the store to use. If None, uses stores.default. + config : Config, optional + Config instance. If None, falls back to global settings.config. Returns ------- StorageBackend StorageBackend instance. 
""" + if config is None: + from .settings import config # get_store_spec handles None by using stores.default spec = config.get_store_spec(store_name) return StorageBackend(spec) -def get_store_subfolding(store_name: str | None = None) -> tuple[int, ...] | None: +def get_store_subfolding(store_name: str | None = None, config=None) -> tuple[int, ...] | None: """ Get the subfolding configuration for a store. @@ -158,12 +161,16 @@ def get_store_subfolding(store_name: str | None = None) -> tuple[int, ...] | Non ---------- store_name : str, optional Name of the store. If None, uses stores.default. + config : Config, optional + Config instance. If None, falls back to global settings.config. Returns ------- tuple[int, ...] | None Subfolding pattern (e.g., (2, 2)) or None for flat storage. """ + if config is None: + from .settings import config spec = config.get_store_spec(store_name) subfolding = spec.get("subfolding") if subfolding is not None: @@ -175,6 +182,7 @@ def put_hash( data: bytes, schema_name: str, store_name: str | None = None, + config=None, ) -> dict[str, Any]: """ Store content using hash-addressed storage. @@ -193,6 +201,8 @@ def put_hash( Database/schema name for path isolation. store_name : str, optional Name of the store. If None, uses default store. + config : Config, optional + Config instance. If None, falls back to global settings.config. Returns ------- @@ -200,10 +210,10 @@ def put_hash( Metadata dict with keys: hash, path, schema, store, size. """ content_hash = compute_hash(data) - subfolding = get_store_subfolding(store_name) + subfolding = get_store_subfolding(store_name, config=config) path = build_hash_path(content_hash, schema_name, subfolding) - backend = get_store_backend(store_name) + backend = get_store_backend(store_name, config=config) # Check if content already exists (deduplication within schema) if not backend.exists(path): @@ -221,7 +231,7 @@ def put_hash( } -def get_hash(metadata: dict[str, Any]) -> bytes: +def get_hash(metadata: dict[str, Any], config=None) -> bytes: """ Retrieve content using stored metadata. @@ -232,6 +242,8 @@ def get_hash(metadata: dict[str, Any]) -> bytes: ---------- metadata : dict Metadata dict with keys: path, hash, store (optional). + config : Config, optional + Config instance. If None, falls back to global settings.config. Returns ------- @@ -249,15 +261,13 @@ def get_hash(metadata: dict[str, Any]) -> bytes: expected_hash = metadata["hash"] store_name = metadata.get("store") - backend = get_store_backend(store_name) + backend = get_store_backend(store_name, config=config) data = backend.get_buffer(path) # Verify hash for integrity actual_hash = compute_hash(data) if actual_hash != expected_hash: - raise DataJointError( - f"Hash mismatch: expected {expected_hash}, got {actual_hash}. " f"Data at {path} may be corrupted." - ) + raise DataJointError(f"Hash mismatch: expected {expected_hash}, got {actual_hash}. Data at {path} may be corrupted.") return data @@ -265,6 +275,7 @@ def get_hash(metadata: dict[str, Any]) -> bytes: def delete_path( path: str, store_name: str | None = None, + config=None, ) -> bool: """ Delete content at the specified path from storage. @@ -278,6 +289,8 @@ def delete_path( Storage path (as stored in metadata). store_name : str, optional Name of the store. If None, uses default store. + config : Config, optional + Config instance. If None, falls back to global settings.config. Returns ------- @@ -288,7 +301,7 @@ def delete_path( -------- This permanently deletes content. Ensure no references exist first. 
""" - backend = get_store_backend(store_name) + backend = get_store_backend(store_name, config=config) if backend.exists(path): backend.remove(path) diff --git a/src/datajoint/instance.py b/src/datajoint/instance.py new file mode 100644 index 000000000..455336a7c --- /dev/null +++ b/src/datajoint/instance.py @@ -0,0 +1,311 @@ +""" +DataJoint Instance for thread-safe operation. + +An Instance encapsulates a config and connection pair, providing isolated +database contexts for multi-tenant applications. +""" + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Any, Literal + +from .connection import Connection +from .errors import ThreadSafetyError +from .settings import Config, _create_config, config as _settings_config + +if TYPE_CHECKING: + from .schemas import _Schema as SchemaClass + from .table import FreeTable as FreeTableClass + + +def _load_thread_safe() -> bool: + """ + Check if thread-safe mode is enabled. + + Thread-safe mode is controlled by the ``DJ_THREAD_SAFE`` environment + variable, which must be set before the process starts. + + Returns + ------- + bool + True if thread-safe mode is enabled. + """ + env_val = os.environ.get("DJ_THREAD_SAFE", "").lower() + if env_val in ("true", "1", "yes"): + return True + return False + + +class Instance: + """ + Encapsulates a DataJoint configuration and connection. + + Each Instance has its own Config and Connection, providing isolation + for multi-tenant applications. Use ``dj.Instance()`` to create isolated + instances, or access the singleton via ``dj.config``, ``dj.conn()``, etc. + + Parameters + ---------- + host : str + Database hostname. + user : str + Database username. + password : str + Database password. + port : int, optional + Database port. Defaults to 3306 for MySQL, 5432 for PostgreSQL. + use_tls : bool or dict, optional + TLS configuration. + backend : str, optional + Database backend: ``"mysql"`` or ``"postgresql"``. Default from config. + **kwargs : Any + Additional config overrides applied to this instance's config. + + Attributes + ---------- + config : Config + Configuration for this instance. + connection : Connection + Database connection for this instance. 
+ + Examples + -------- + >>> inst = dj.Instance(host="localhost", user="root", password="secret") + >>> inst.config.safemode = False + >>> schema = inst.Schema("my_schema") + """ + + def __init__( + self, + host: str, + user: str, + password: str, + port: int | None = None, + use_tls: bool | dict | None = None, + backend: Literal["mysql", "postgresql"] | None = None, + **kwargs: Any, + ) -> None: + # Create fresh config with defaults loaded from env/file + self.config = _create_config() + + # Apply backend override before other kwargs (port default depends on it) + if backend is not None: + self.config.database.backend = backend + # Re-derive port default since _create_config resolved it before backend was set + if port is None and "database__port" not in kwargs: + self.config.database.port = 5432 if backend == "postgresql" else 3306 + + # Apply any config overrides from kwargs + for key, value in kwargs.items(): + if hasattr(self.config, key): + setattr(self.config, key, value) + elif "__" in key: + # Handle nested keys like database__reconnect + parts = key.split("__") + obj = self.config + for part in parts[:-1]: + obj = getattr(obj, part) + setattr(obj, parts[-1], value) + + # Determine port + if port is None: + port = self.config.database.port + + # Create connection with this instance's config and backend + self.connection = Connection( + host, + user, + password, + port, + use_tls, + backend=self.config.database.backend, + config_override=self.config, + ) + + def Schema( + self, + schema_name: str, + *, + context: dict[str, Any] | None = None, + create_schema: bool = True, + create_tables: bool | None = None, + add_objects: dict[str, Any] | None = None, + ) -> "SchemaClass": + """ + Create a Schema bound to this instance's connection. + + Parameters + ---------- + schema_name : str + Database schema name. + context : dict, optional + Namespace for foreign key lookup. + create_schema : bool, optional + If False, raise error if schema doesn't exist. Default True. + create_tables : bool, optional + If False, raise error when accessing missing tables. + add_objects : dict, optional + Additional objects for declaration context. + + Returns + ------- + Schema + A Schema using this instance's connection. + """ + from .schemas import _Schema + + return _Schema( + schema_name, + context=context, + connection=self.connection, + create_schema=create_schema, + create_tables=create_tables, + add_objects=add_objects, + ) + + def FreeTable(self, full_table_name: str) -> "FreeTableClass": + """ + Create a FreeTable bound to this instance's connection. + + Parameters + ---------- + full_table_name : str + Full table name as ``'schema.table'`` or ```schema`.`table```. + + Returns + ------- + FreeTable + A FreeTable using this instance's connection. 
+ """ + from .table import FreeTable + + return FreeTable(self.connection, full_table_name) + + def __repr__(self) -> str: + return f"Instance({self.connection!r})" + + +# ============================================================================= +# Singleton management +# ============================================================================= +# The global config is created at module load time and can be modified +# The singleton connection is created lazily when conn() or Schema() is called + +# Reuse the config created in settings.py — there must be exactly one global config +_global_config: Config = _settings_config +_singleton_connection: Connection | None = None + + +def _check_thread_safe() -> None: + """ + Check if thread-safe mode is enabled and raise if so. + + Raises + ------ + ThreadSafetyError + If thread_safe mode is enabled. + """ + if _load_thread_safe(): + raise ThreadSafetyError( + "Global DataJoint state is disabled in thread-safe mode. " "Use dj.Instance() to create an isolated instance." + ) + + +def _get_singleton_connection() -> Connection: + """ + Get or create the singleton Connection. + + Uses credentials from the global config. + + Raises + ------ + ThreadSafetyError + If thread_safe mode is enabled. + DataJointError + If credentials are not configured. + """ + global _singleton_connection + + _check_thread_safe() + + if _singleton_connection is None: + from .errors import DataJointError + + host = _global_config.database.host + user = _global_config.database.user + raw_password = _global_config.database.password + password = raw_password.get_secret_value() if raw_password is not None else None + port = _global_config.database.port + use_tls = _global_config.database.use_tls + + if user is None: + raise DataJointError( + "Database user not configured. Set dj.config['database.user'] or DJ_USER environment variable." + ) + if password is None: + raise DataJointError( + "Database password not configured. Set dj.config['database.password'] or DJ_PASS environment variable." + ) + + _singleton_connection = Connection(host, user, password, port, use_tls, config_override=_global_config) + + return _singleton_connection + + +class _ConfigProxy: + """ + Proxy that delegates to the global config, with thread-safety checks. + + In thread-safe mode, all access raises ThreadSafetyError. 
+ """ + + def __getattr__(self, name: str) -> Any: + _check_thread_safe() + return getattr(_global_config, name) + + def __setattr__(self, name: str, value: Any) -> None: + _check_thread_safe() + setattr(_global_config, name, value) + + def __getitem__(self, key: str) -> Any: + _check_thread_safe() + return _global_config[key] + + def __setitem__(self, key: str, value: Any) -> None: + _check_thread_safe() + _global_config[key] = value + + def __delitem__(self, key: str) -> None: + _check_thread_safe() + del _global_config[key] + + def get(self, key: str, default: Any = None) -> Any: + _check_thread_safe() + return _global_config.get(key, default) + + def override(self, **kwargs: Any): + _check_thread_safe() + return _global_config.override(**kwargs) + + def load(self, filename: str) -> None: + _check_thread_safe() + return _global_config.load(filename) + + def get_store_spec(self, store: str | None = None, *, use_filepath_default: bool = False) -> dict[str, Any]: + _check_thread_safe() + return _global_config.get_store_spec(store, use_filepath_default=use_filepath_default) + + @staticmethod + def save_template( + path: str = "datajoint.json", + minimal: bool = True, + create_secrets_dir: bool = True, + ): + # save_template is a static method, no thread-safety check needed + return Config.save_template(path, minimal, create_secrets_dir) + + def __repr__(self) -> str: + if _load_thread_safe(): + return "ConfigProxy (thread-safe mode - use dj.Instance())" + return repr(_global_config) diff --git a/src/datajoint/jobs.py b/src/datajoint/jobs.py index 5a0eb2a86..6da8377dd 100644 --- a/src/datajoint/jobs.py +++ b/src/datajoint/jobs.py @@ -24,16 +24,22 @@ logger = logging.getLogger(__name__.split(".")[0]) -def _get_job_version() -> str: +def _get_job_version(config=None) -> str: """ Get version string based on config settings. + Parameters + ---------- + config : Config, optional + Configuration object. If None, falls back to global config. + Returns ------- str Version string, or empty string if version tracking disabled. """ - from .settings import config + if config is None: + from .settings import config method = config.jobs.version_method if method is None or method == "none": @@ -349,17 +355,15 @@ def refresh( 3. Remove stale jobs: jobs older than stale_timeout whose keys not in key_source 4. Remove orphaned jobs: reserved jobs older than orphan_timeout (if specified) """ - from .settings import config - # Ensure jobs table exists if not self.is_declared: self.declare() # Get defaults from config if priority is None: - priority = config.jobs.default_priority + priority = self.connection._config.jobs.default_priority if stale_timeout is None: - stale_timeout = config.jobs.stale_timeout + stale_timeout = self.connection._config.jobs.stale_timeout result = {"added": 0, "removed": 0, "orphaned": 0, "re_pended": 0} @@ -392,7 +396,7 @@ def refresh( pass # Job already exists # 2. 
Re-pend success jobs if keep_completed=True - if config.jobs.keep_completed: + if self.connection._config.jobs.keep_completed: # Success jobs whose keys are in key_source but not in target # Disable semantic_check for Job table operations (job table PK has different lineage than target) success_to_repend = self.completed.restrict(key_source, semantic_check=False).restrict( @@ -462,7 +466,7 @@ def reserve(self, key: dict) -> bool: os.getpid(), self.connection.connection_id, self.connection.get_user(), - _get_job_version(), + _get_job_version(self.connection._config), ] cursor = self.connection.query(query, args=args) return cursor.rowcount == 1 @@ -485,9 +489,7 @@ def complete(self, key: dict, duration: float | None = None) -> None: - If True: updates status to ``'success'`` with completion time and duration - If False: deletes the job entry """ - from .settings import config - - if config.jobs.keep_completed: + if self.connection._config.jobs.keep_completed: # Use server time for completed_time server_now = self.connection.query("SELECT CURRENT_TIMESTAMP").fetchone()[0] pk = self._get_pk(key) @@ -545,13 +547,11 @@ def ignore(self, key: dict) -> None: key : dict Primary key dict of the job. """ - from .settings import config - pk = self._get_pk(key) if pk in self: self.update1({**pk, "status": "ignore"}) else: - priority = config.jobs.default_priority + priority = self.connection._config.jobs.default_priority self.insert1({**pk, "status": "ignore", "priority": priority}) def progress(self) -> dict: diff --git a/src/datajoint/migrate.py b/src/datajoint/migrate.py index d48afae62..2ff0dfcb8 100644 --- a/src/datajoint/migrate.py +++ b/src/datajoint/migrate.py @@ -7,7 +7,7 @@ .. note:: This module is provided temporarily to assist with migration from pre-2.0. - It will be deprecated in DataJoint 2.1 and removed in 2.2. + It will be deprecated in DataJoint 2.1 and removed in 2.3. Complete your migrations while on DataJoint 2.0. Note on Terminology @@ -32,14 +32,14 @@ # Show deprecation warning starting in 2.1 if Version(__version__) >= Version("2.1"): warnings.warn( - "datajoint.migrate is deprecated and will be removed in DataJoint 2.2. " + "datajoint.migrate is deprecated and will be removed in DataJoint 2.3. 
" "Complete your schema migrations before upgrading.", DeprecationWarning, stacklevel=2, ) if TYPE_CHECKING: - from .schemas import Schema + from .schemas import _Schema as Schema logger = logging.getLogger(__name__.split(".")[0]) @@ -653,7 +653,7 @@ def add_job_metadata_columns(target, dry_run: bool = True) -> dict: - Future populate() calls will fill in metadata for new rows - This does NOT retroactively populate metadata for existing rows """ - from .schemas import Schema + from .schemas import _Schema from .table import Table result = { @@ -664,7 +664,7 @@ def add_job_metadata_columns(target, dry_run: bool = True) -> dict: } # Determine tables to process - if isinstance(target, Schema): + if isinstance(target, _Schema): schema = target # Get all user tables in the schema tables_query = """ diff --git a/src/datajoint/preview.py b/src/datajoint/preview.py index 92d09d874..0b80ad15f 100644 --- a/src/datajoint/preview.py +++ b/src/datajoint/preview.py @@ -2,8 +2,6 @@ import json -from .settings import config - def _format_object_display(json_data): """Format object metadata for display in query results.""" @@ -44,6 +42,7 @@ def _get_blob_placeholder(heading, field_name, html_escape=False): def preview(query_expression, limit, width): heading = query_expression.heading rel = query_expression.proj(*heading.non_blobs) + config = query_expression.connection._config # Object fields use codecs - not specially handled in simplified model object_fields = [] if limit is None: @@ -105,6 +104,7 @@ def get_display_value(tup, f, idx): def repr_html(query_expression): heading = query_expression.heading rel = query_expression.proj(*heading.non_blobs) + config = query_expression.connection._config # Object fields use codecs - not specially handled in simplified model object_fields = [] tuples = rel.to_arrays(limit=config["display.limit"] + 1) diff --git a/src/datajoint/schemas.py b/src/datajoint/schemas.py index 1780bbaaf..8747cdbf2 100644 --- a/src/datajoint/schemas.py +++ b/src/datajoint/schemas.py @@ -7,23 +7,20 @@ from __future__ import annotations -import collections import inspect -import itertools import logging import re import types import warnings from typing import TYPE_CHECKING, Any -from .connection import conn from .errors import AccessError, DataJointError +from .instance import _get_singleton_connection if TYPE_CHECKING: from .connection import Connection from .heading import Heading from .jobs import Job -from .settings import config from .table import FreeTable, lookup_class_name from .user_tables import Computed, Imported, Lookup, Manual, Part, _get_tier from .utils import to_camel_case, user_choice @@ -54,7 +51,7 @@ def ordered_dir(class_: type) -> list[str]: return attr_list -class Schema: +class _Schema: """ Decorator that binds table classes to a database schema. 
@@ -120,7 +117,7 @@ def __init__( self.database = None self.context = context self.create_schema = create_schema - self.create_tables = create_tables if create_tables is not None else config.database.create_tables + self.create_tables = create_tables # None means "use connection config default" self.add_objects = add_objects self.declare_list = [] if schema_name: @@ -174,7 +171,7 @@ def activate( if connection is not None: self.connection = connection if self.connection is None: - self.connection = conn() + self.connection = _get_singleton_connection() self.database = schema_name if create_schema is not None: self.create_schema = create_schema @@ -293,7 +290,10 @@ def _decorate_table(self, table_class: type, context: dict[str, Any], assert_dec # instantiate the class, declare the table if not already instance = table_class() is_declared = instance.is_declared - if not is_declared and not assert_declared and self.create_tables: + create_tables = ( + self.create_tables if self.create_tables is not None else self.connection._config.database.create_tables + ) + if not is_declared and not assert_declared and create_tables: instance.declare(context) self.connection.dependencies.clear() is_declared = is_declared or instance.is_declared @@ -343,7 +343,7 @@ def make_classes(self, into: dict[str, Any] | None = None) -> None: del frame tables = [ row[0] - for row in self.connection.query("SHOW TABLES in `%s`" % self.database) + for row in self.connection.query(self.connection.adapter.list_tables_sql(self.database)) if lookup_class_name("`{db}`.`{tab}`".format(db=self.database, tab=row[0]), into, 0) is None ] master_classes = (Lookup, Manual, Imported, Computed) @@ -389,7 +389,7 @@ def drop(self, prompt: bool | None = None) -> None: AccessError If insufficient permissions to drop the schema. """ - prompt = config["safemode"] if prompt is None else prompt + prompt = self.connection._config["safemode"] if prompt is None else prompt if not self.exists: logger.info("Schema named `{database}` does not exist. Doing nothing.".format(database=self.database)) @@ -421,13 +421,7 @@ def exists(self) -> bool: """ if self.database is None: raise DataJointError("Schema must be activated first.") - return bool( - self.connection.query( - "SELECT schema_name FROM information_schema.schemata WHERE schema_name = '{database}'".format( - database=self.database - ) - ).rowcount - ) + return bool(self.connection.query(self.connection.adapter.schema_exists_sql(self.database)).rowcount) @property def lineage_table_exists(self) -> bool: @@ -520,83 +514,6 @@ def jobs(self) -> list[Job]: return jobs_list - @property - def code(self): - self._assert_exists() - return self.save() - - def save(self, python_filename: str | None = None) -> str: - """ - Generate Python code that recreates this schema. - - Parameters - ---------- - python_filename : str, optional - If provided, write the code to this file. - - Returns - ------- - str - Python module source code defining this schema. - - Notes - ----- - This method is in preparation for a future release and is not - officially supported. - """ - self.connection.dependencies.load() - self._assert_exists() - module_count = itertools.count() - # add virtual modules for referenced modules with names vmod0, vmod1, ... 
- module_lookup = collections.defaultdict(lambda: "vmod" + str(next(module_count))) - db = self.database - - def make_class_definition(table): - tier = _get_tier(table).__name__ - class_name = table.split(".")[1].strip("`") - indent = "" - if tier == "Part": - class_name = class_name.split("__")[-1] - indent += " " - class_name = to_camel_case(class_name) - - def replace(s): - d, tabs = s.group(1), s.group(2) - return ("" if d == db else (module_lookup[d] + ".")) + ".".join( - to_camel_case(tab) for tab in tabs.lstrip("__").split("__") - ) - - return ("" if tier == "Part" else "\n@schema\n") + ( - '{indent}class {class_name}(dj.{tier}):\n{indent} definition = """\n{indent} {defi}"""' - ).format( - class_name=class_name, - indent=indent, - tier=tier, - defi=re.sub( - r"`([^`]+)`.`([^`]+)`", - replace, - FreeTable(self.connection, table).describe(), - ).replace("\n", "\n " + indent), - ) - - tables = self.connection.dependencies.topo_sort() - body = "\n\n".join(make_class_definition(table) for table in tables) - python_code = "\n\n".join( - ( - '"""This module was auto-generated by datajoint from an existing schema"""', - "import datajoint as dj\n\nschema = dj.Schema('{db}')".format(db=db), - "\n".join( - "{module} = dj.VirtualModule('{module}', '{schema_name}')".format(module=v, schema_name=k) - for k, v in module_lookup.items() - ), - body, - ) - ) - if python_filename is None: - return python_code - with open(python_filename, "wt") as f: - f.write(python_code) - def list_tables(self) -> list[str]: """ Return all user tables in the schema. @@ -612,7 +529,10 @@ def list_tables(self) -> list[str]: self.connection.dependencies.load() return [ t - for d, t in (table_name.replace("`", "").split(".") for table_name in self.connection.dependencies.topo_sort()) + for d, t in ( + self.connection.adapter.split_full_table_name(table_name) + for table_name in self.connection.dependencies.topo_sort() + ) if d == self.database ] @@ -810,7 +730,7 @@ def __init__( Additional objects to add to the module namespace. """ super(VirtualModule, self).__init__(name=module_name) - _schema = Schema( + _schema = _Schema( schema_name, create_schema=create_schema, create_tables=create_tables, @@ -836,12 +756,8 @@ def list_schemas(connection: Connection | None = None) -> list[str]: list[str] Names of all accessible schemas. 
""" - return [ - r[0] - for r in (connection or conn()).query( - 'SELECT schema_name FROM information_schema.schemata WHERE schema_name <> "information_schema"' - ) - ] + conn = connection or _get_singleton_connection() + return [r[0] for r in conn.query(conn.adapter.list_schemas_sql())] def virtual_schema( diff --git a/src/datajoint/settings.py b/src/datajoint/settings.py index e373ca38f..7019d8345 100644 --- a/src/datajoint/settings.py +++ b/src/datajoint/settings.py @@ -224,7 +224,6 @@ class ConnectionSettings(BaseSettings): model_config = SettingsConfigDict(extra="forbid", validate_assignment=True) - init_function: str | None = None charset: str = "" # pymysql uses '' as default @@ -341,11 +340,8 @@ class Config(BaseSettings): # Top-level settings loglevel: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(default="INFO", validation_alias="DJ_LOG_LEVEL") safemode: bool = True - enable_python_native_blobs: bool = True - filepath_checksum_size_limit: int | None = None - # Cache paths - cache: Path | None = None + # Cache path for query results query_cache: Path | None = None # Download path for attachments and filepaths @@ -362,7 +358,7 @@ def set_logger_level(cls, v: str) -> str: logger.setLevel(v) return v - @field_validator("cache", "query_cache", mode="before") + @field_validator("query_cache", mode="before") @classmethod def convert_path(cls, v: Any) -> Path | None: """Convert string paths to Path objects.""" @@ -819,7 +815,6 @@ def save_template( "use_tls": None, }, "connection": { - "init_function": None, "charset": "", }, "display": { @@ -844,8 +839,6 @@ def save_template( }, "loglevel": "INFO", "safemode": True, - "enable_python_native_blobs": True, - "cache": None, "query_cache": None, "download_path": ".", } diff --git a/src/datajoint/staged_insert.py b/src/datajoint/staged_insert.py index 6ac3819e4..1f6ee7afb 100644 --- a/src/datajoint/staged_insert.py +++ b/src/datajoint/staged_insert.py @@ -14,7 +14,6 @@ import fsspec from .errors import DataJointError -from .settings import config from .storage import StorageBackend, build_object_path @@ -69,7 +68,7 @@ def _ensure_backend(self): """Ensure storage backend is initialized.""" if self._backend is None: try: - spec = config.get_store_spec() # Uses stores.default + spec = self._table.connection._config.get_store_spec() # Uses stores.default self._backend = StorageBackend(spec) except DataJointError: raise DataJointError( @@ -110,7 +109,7 @@ def _get_storage_path(self, field: str, ext: str = "") -> str: ) # Get storage spec (uses stores.default) - spec = config.get_store_spec() + spec = self._table.connection._config.get_store_spec() partition_pattern = spec.get("partition_pattern") token_length = spec.get("token_length", 8) diff --git a/src/datajoint/table.py b/src/datajoint/table.py index 5fd8c3087..256fab6e9 100644 --- a/src/datajoint/table.py +++ b/src/datajoint/table.py @@ -23,7 +23,6 @@ ) from .expression import QueryExpression from .heading import Heading -from .settings import config from .staged_insert import staged_insert1 as _staged_insert1 from .utils import get_master, is_camel_case, user_choice @@ -141,8 +140,7 @@ def declare(self, context=None): class_name = self.class_name if "_" in class_name: warnings.warn( - f"Table class name `{class_name}` contains underscores. " - "CamelCase names without underscores are recommended.", + f"Table class name `{class_name}` contains underscores. 
CamelCase names without underscores are recommended.", UserWarning, stacklevel=2, ) @@ -153,7 +151,7 @@ def declare(self, context=None): "Class names must be in CamelCase, starting with a capital letter." ) sql, _external_stores, primary_key, fk_attribute_map, pre_ddl, post_ddl = declare( - self.full_table_name, self.definition, context, self.connection.adapter + self.full_table_name, self.definition, context, self.connection.adapter, config=self.connection._config ) # Call declaration hook for validation (subclasses like AutoPopulate can override) @@ -235,12 +233,7 @@ def _populate_lineage(self, primary_key, fk_attribute_map): # FK attributes: copy lineage from parent (whether in PK or not) for attr, (parent_table, parent_attr) in fk_attribute_map.items(): # Parse parent table name: `schema`.`table` or "schema"."table" -> (schema, table) - parent_clean = parent_table.replace("`", "").replace('"', "") - if "." in parent_clean: - parent_db, parent_tbl = parent_clean.split(".", 1) - else: - parent_db = self.database - parent_tbl = parent_clean + parent_db, parent_tbl = self.connection.adapter.split_full_table_name(parent_table) # Get parent's lineage for this attribute parent_lineage = get_lineage(self.connection, parent_db, parent_tbl, parent_attr) @@ -1119,7 +1112,7 @@ def strip_quotes(s): raise DataJointError("Exceeded maximum number of delete attempts.") return delete_count - prompt = config["safemode"] if prompt is None else prompt + prompt = self.connection._config["safemode"] if prompt is None else prompt # Start transaction if transaction: @@ -1227,7 +1220,7 @@ def drop(self, prompt: bool | None = None): raise DataJointError( "A table with an applied restriction cannot be dropped. Call drop() on the unrestricted Table." ) - prompt = config["safemode"] if prompt is None else prompt + prompt = self.connection._config["safemode"] if prompt is None else prompt self.connection.dependencies.load() do_drop = True @@ -1398,6 +1391,7 @@ def __make_placeholder(self, name, value, ignore_extra_fields=False, row=None): "_schema": self.database, "_table": self.table_name, "_field": name, + "_config": self.connection._config, } # Add primary key values from row if available if row is not None: diff --git a/tests/conftest.py b/tests/conftest.py index 4d6adf09c..8efaab745 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -536,13 +536,13 @@ def mock_stores(stores_config): @pytest.fixture def mock_cache(tmpdir_factory): - og_cache = dj.config.get("cache") - dj.config["cache"] = tmpdir_factory.mktemp("cache") + og_cache = dj.config.get("download_path") + dj.config["download_path"] = str(tmpdir_factory.mktemp("cache")) yield if og_cache is None: - del dj.config["cache"] + del dj.config["download_path"] else: - dj.config["cache"] = og_cache + dj.config["download_path"] = og_cache @pytest.fixture(scope="session") diff --git a/tests/integration/test_gc.py b/tests/integration/test_gc.py index 7eca79f37..47ca0a96d 100644 --- a/tests/integration/test_gc.py +++ b/tests/integration/test_gc.py @@ -251,7 +251,7 @@ def test_deletes_orphaned_hashes(self, mock_scan, mock_list_stored, mock_delete) assert stats["hash_deleted"] == 1 assert stats["bytes_freed"] == 100 assert stats["dry_run"] is False - mock_delete.assert_called_once_with("_hash/schema/orphan_path", "test_store") + mock_delete.assert_called_once_with("_hash/schema/orphan_path", "test_store", config=mock_schema.connection._config) @patch("datajoint.gc.delete_schema_path") @patch("datajoint.gc.list_schema_paths") @@ -278,7 +278,7 @@ def 
test_deletes_orphaned_schemas(self, mock_scan, mock_list_schemas, mock_delet assert stats["schema_paths_deleted"] == 1 assert stats["bytes_freed"] == 500 assert stats["dry_run"] is False - mock_delete.assert_called_once_with("schema/table/pk/field", "test_store") + mock_delete.assert_called_once_with("schema/table/pk/field", "test_store", config=mock_schema.connection._config) class TestFormatStats: diff --git a/tests/integration/test_jobs.py b/tests/integration/test_jobs.py index 20fa3233d..5a9203dca 100644 --- a/tests/integration/test_jobs.py +++ b/tests/integration/test_jobs.py @@ -108,10 +108,9 @@ def test_sigterm(clean_jobs, schema_any): def test_suppress_dj_errors(clean_jobs, schema_any): - """Test that DataJoint errors are suppressible without native py blobs.""" + """Test that DataJoint errors are suppressible.""" error_class = schema.ErrorClass() - with dj.config.override(enable_python_native_blobs=False): - error_class.populate(reserve_jobs=True, suppress_errors=True) + error_class.populate(reserve_jobs=True, suppress_errors=True) assert len(schema.DjExceptionName()) == len(error_class.jobs.errors) > 0 diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py index 9a4144bbb..1211ecd1e 100644 --- a/tests/integration/test_schema.py +++ b/tests/integration/test_schema.py @@ -260,5 +260,5 @@ class Recording(dj.Manual): id: smallint """ - schema2.drop() - schema1.drop() + schema2.drop(prompt=False) + schema1.drop(prompt=False) diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index af5718503..475d96df9 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -504,23 +504,23 @@ def test_display_limit(self): class TestCachePaths: """Test cache path settings.""" - def test_cache_path_string(self): - """Test setting cache path as string.""" - original = dj.config.cache + def test_query_cache_path_string(self): + """Test setting query_cache path as string.""" + original = dj.config.query_cache try: - dj.config.cache = "/tmp/cache" - assert dj.config.cache == Path("/tmp/cache") + dj.config.query_cache = "/tmp/cache" + assert dj.config.query_cache == Path("/tmp/cache") finally: - dj.config.cache = original + dj.config.query_cache = original - def test_cache_path_none(self): - """Test cache path can be None.""" - original = dj.config.cache + def test_query_cache_path_none(self): + """Test query_cache path can be None.""" + original = dj.config.query_cache try: - dj.config.cache = None - assert dj.config.cache is None + dj.config.query_cache = None + assert dj.config.query_cache is None finally: - dj.config.cache = original + dj.config.query_cache = original class TestSaveTemplate: diff --git a/tests/unit/test_thread_safe.py b/tests/unit/test_thread_safe.py new file mode 100644 index 000000000..aba1b686b --- /dev/null +++ b/tests/unit/test_thread_safe.py @@ -0,0 +1,294 @@ +"""Tests for thread-safe mode functionality.""" + +import pytest + + +class TestThreadSafeMode: + """Test thread-safe mode behavior.""" + + def test_thread_safe_env_var_true(self, monkeypatch): + """DJ_THREAD_SAFE=true enables thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + + # Re-import to pick up the new env var + from datajoint.instance import _load_thread_safe + + assert _load_thread_safe() is True + + def test_thread_safe_env_var_false(self, monkeypatch): + """DJ_THREAD_SAFE=false disables thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "false") + + from datajoint.instance import _load_thread_safe + + assert 
_load_thread_safe() is False + + def test_thread_safe_env_var_1(self, monkeypatch): + """DJ_THREAD_SAFE=1 enables thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "1") + + from datajoint.instance import _load_thread_safe + + assert _load_thread_safe() is True + + def test_thread_safe_env_var_yes(self, monkeypatch): + """DJ_THREAD_SAFE=yes enables thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "yes") + + from datajoint.instance import _load_thread_safe + + assert _load_thread_safe() is True + + def test_thread_safe_default_false(self, monkeypatch): + """Thread-safe mode defaults to False.""" + monkeypatch.delenv("DJ_THREAD_SAFE", raising=False) + + from datajoint.instance import _load_thread_safe + + assert _load_thread_safe() is False + + +class TestConfigProxyThreadSafe: + """Test ConfigProxy behavior in thread-safe mode.""" + + def test_config_access_raises_in_thread_safe_mode(self, monkeypatch): + """Accessing config raises ThreadSafetyError in thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + + import datajoint as dj + from datajoint.errors import ThreadSafetyError + + with pytest.raises(ThreadSafetyError): + _ = dj.config.database + + def test_config_access_works_in_normal_mode(self, monkeypatch): + """Accessing config works in normal mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "false") + + import datajoint as dj + + # Should not raise + host = dj.config.database.host + assert isinstance(host, str) + + def test_config_set_raises_in_thread_safe_mode(self, monkeypatch): + """Setting config raises ThreadSafetyError in thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + + import datajoint as dj + from datajoint.errors import ThreadSafetyError + + with pytest.raises(ThreadSafetyError): + dj.config.safemode = False + + def test_save_template_works_in_thread_safe_mode(self, monkeypatch, tmp_path): + """save_template is a static method and works in thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + + import datajoint as dj + + # Should not raise - save_template is static + config_file = tmp_path / "datajoint.json" + dj.config.save_template(str(config_file), create_secrets_dir=False) + assert config_file.exists() + + +class TestConnThreadSafe: + """Test conn() behavior in thread-safe mode.""" + + def test_conn_raises_in_thread_safe_mode(self, monkeypatch): + """conn() raises ThreadSafetyError in thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + + import datajoint as dj + from datajoint.errors import ThreadSafetyError + + with pytest.raises(ThreadSafetyError): + dj.conn() + + +class TestSchemaThreadSafe: + """Test Schema behavior in thread-safe mode.""" + + def test_schema_raises_in_thread_safe_mode(self, monkeypatch): + """Schema() raises ThreadSafetyError in thread-safe mode without connection.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + + import datajoint as dj + from datajoint.errors import ThreadSafetyError + + with pytest.raises(ThreadSafetyError): + dj.Schema("test_schema") + + +class TestFreeTableThreadSafe: + """Test FreeTable behavior in thread-safe mode.""" + + def test_freetable_raises_in_thread_safe_mode(self, monkeypatch): + """FreeTable() raises ThreadSafetyError in thread-safe mode without connection.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + + import datajoint as dj + from datajoint.errors import ThreadSafetyError + + with pytest.raises(ThreadSafetyError): + dj.FreeTable("test.table") + + +class TestInstance: + """Test Instance class.""" + + def 
test_instance_import(self): + """Instance class is importable.""" + from datajoint import Instance + + assert Instance is not None + + def test_instance_always_allowed_in_thread_safe_mode(self, monkeypatch): + """Instance() is allowed even in thread-safe mode.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + + from datajoint import Instance + + # Instance class should be accessible + # (actual creation requires valid credentials) + assert callable(Instance) + + +class TestInstanceBackend: + """Test Instance backend parameter.""" + + def test_instance_backend_sets_config(self, monkeypatch): + """Instance(backend=...) sets config.database.backend.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "false") + from datajoint.instance import Instance + from unittest.mock import patch + + with patch("datajoint.instance.Connection"): + inst = Instance( + host="localhost", + user="root", + password="secret", + backend="postgresql", + ) + assert inst.config.database.backend == "postgresql" + + def test_instance_backend_default_from_config(self, monkeypatch): + """Instance without backend uses config default.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "false") + from datajoint.instance import Instance + from unittest.mock import patch + + with patch("datajoint.instance.Connection"): + inst = Instance( + host="localhost", + user="root", + password="secret", + ) + assert inst.config.database.backend == "mysql" + + def test_instance_backend_affects_port_default(self, monkeypatch): + """Instance(backend='postgresql') uses port 5432 by default.""" + monkeypatch.setenv("DJ_THREAD_SAFE", "false") + from datajoint.instance import Instance + from unittest.mock import patch + + with patch("datajoint.instance.Connection") as MockConn: + Instance( + host="localhost", + user="root", + password="secret", + backend="postgresql", + ) + # Connection should be called with port 5432 (PostgreSQL default) + args, kwargs = MockConn.call_args + assert args[3] == 5432 # port is the 4th positional arg + + +class TestCrossConnectionValidation: + """Test that cross-connection operations are rejected.""" + + def test_join_different_connections_raises(self): + """Join of expressions from different connections raises DataJointError.""" + from datajoint.expression import QueryExpression + from datajoint.errors import DataJointError + from unittest.mock import MagicMock + + expr1 = QueryExpression() + expr1._connection = MagicMock() + expr1._heading = MagicMock() + expr1._heading.names = [] + + expr2 = QueryExpression() + expr2._connection = MagicMock() # different connection object + expr2._heading = MagicMock() + expr2._heading.names = [] + + with pytest.raises(DataJointError, match="different connections"): + expr1 * expr2 + + def test_join_same_connection_allowed(self): + """Join of expressions from the same connection does not raise.""" + from datajoint.condition import assert_join_compatibility + from datajoint.expression import QueryExpression + from unittest.mock import MagicMock + + shared_conn = MagicMock() + + expr1 = QueryExpression() + expr1._connection = shared_conn + expr1._heading = MagicMock() + expr1._heading.names = [] + expr1._heading.lineage_available = False + + expr2 = QueryExpression() + expr2._connection = shared_conn + expr2._heading = MagicMock() + expr2._heading.names = [] + expr2._heading.lineage_available = False + + # Should not raise + assert_join_compatibility(expr1, expr2) + + def test_restriction_different_connections_raises(self): + """Restriction by expression from different connection raises 
DataJointError.""" + from datajoint.expression import QueryExpression + from datajoint.errors import DataJointError + from unittest.mock import MagicMock + + expr1 = QueryExpression() + expr1._connection = MagicMock() + expr1._heading = MagicMock() + expr1._heading.names = ["a"] + expr1._heading.__getitem__ = MagicMock() + expr1._heading.new_attributes = set() + expr1._support = ["`db`.`t1`"] + expr1._restriction = [] + expr1._restriction_attributes = set() + expr1._joins = [] + expr1._top = None + expr1._original_heading = expr1._heading + + expr2 = QueryExpression() + expr2._connection = MagicMock() # different connection + expr2._heading = MagicMock() + expr2._heading.names = ["a"] + + with pytest.raises(DataJointError, match="different connections"): + expr1 & expr2 + + +class TestThreadSafetyError: + """Test ThreadSafetyError exception.""" + + def test_error_is_datajoint_error(self): + """ThreadSafetyError is a subclass of DataJointError.""" + from datajoint.errors import DataJointError, ThreadSafetyError + + assert issubclass(ThreadSafetyError, DataJointError) + + def test_error_in_exports(self): + """ThreadSafetyError is exported from datajoint.""" + import datajoint as dj + + assert hasattr(dj, "ThreadSafetyError")