diff --git a/README.md b/README.md
index b2856e407..33b18b248 100644
--- a/README.md
+++ b/README.md
@@ -42,8 +42,8 @@ DataJoint is a framework for scientific data pipelines based on the **Relational
Citation
-
-
+
+
Coverage
@@ -80,7 +80,7 @@ conda install -c conda-forge datajoint
- [How-To Guides](https://docs.datajoint.com/how-to/) — Task-oriented guides
- [API Reference](https://docs.datajoint.com/api/) — Complete API documentation
- [Migration Guide](https://docs.datajoint.com/how-to/migrate-to-v20/) — Upgrade from legacy versions
-- **[DataJoint Elements](https://datajoint.com/docs/elements/)** — Example pipelines for neuroscience
+- **[DataJoint Elements](https://docs.datajoint.com/elements/)** — Example pipelines for neuroscience
- **[GitHub Discussions](https://github.com/datajoint/datajoint-python/discussions)** — Community support
## Contributing
diff --git a/docs/design/thread-safe-mode.md b/docs/design/thread-safe-mode.md
new file mode 100644
index 000000000..5d7472667
--- /dev/null
+++ b/docs/design/thread-safe-mode.md
@@ -0,0 +1,387 @@
+# Thread-Safe Mode Specification
+
+## Problem
+
+DataJoint uses global state (`dj.config`, `dj.conn()`) that is not thread-safe. Multi-tenant applications (web servers, async workers) need isolated connections per request/task.
+
+## Solution
+
+Introduce **Instance** objects that encapsulate config and connection. The `dj` module provides a global config that can be modified before connecting, and a lazily-loaded singleton connection. New isolated instances are created with `dj.Instance()`.
+
+## API
+
+### Legacy API (global config + singleton connection)
+
+```python
+import datajoint as dj
+
+# Configure credentials (no connection yet)
+dj.config.database.user = "user"
+dj.config.database.password = "password"
+
+# First call to conn() or Schema() creates the singleton connection
+dj.conn() # Creates connection using dj.config credentials
+schema = dj.Schema("my_schema")
+
+@schema
+class Mouse(dj.Manual):
+ definition = "..."
+```
+
+Alternatively, pass credentials directly to `conn()`:
+```python
+dj.conn(host="localhost", user="user", password="password")
+```
+
+Internally:
+- `dj.config` → delegates to `_global_config` (with thread-safety check)
+- `dj.conn()` → returns `_singleton_connection` (created lazily)
+- `dj.Schema()` → uses `_singleton_connection`
+- `dj.FreeTable()` → uses `_singleton_connection`
+
+### New API (isolated instance)
+
+```python
+import datajoint as dj
+
+inst = dj.Instance(
+ host="localhost",
+ user="user",
+ password="password",
+)
+schema = inst.Schema("my_schema")
+
+@schema
+class Mouse(dj.Manual):
+ definition = "..."
+```
+
+### Instance structure
+
+Each instance has:
+- `inst.config` - Config (created fresh at instance creation)
+- `inst.connection` - Connection (created at instance creation)
+- `inst.Schema()` - Schema factory using instance's connection
+- `inst.FreeTable()` - FreeTable factory using instance's connection
+
+```python
+inst = dj.Instance(host="localhost", user="u", password="p")
+inst.config # Config instance
+inst.connection # Connection instance
+inst.Schema("name") # Creates schema using inst.connection
+inst.FreeTable("db.tbl") # Access table using inst.connection
+```
+
+### Table base classes vs instance methods
+
+**Base classes** (`dj.Manual`, `dj.Lookup`, etc.) - used with the `@schema` decorator; the connection comes from the schema:
+```python
+@schema
+class Mouse(dj.Manual): # dj.Manual - schema links to connection
+ definition = "..."
+```
+
+**Instance methods** (`inst.Schema()`, `inst.FreeTable()`) - take the connection directly from the instance:
+```python
+schema = inst.Schema("my_schema") # Uses inst.connection
+table = inst.FreeTable("db.table") # Uses inst.connection
+```
+
+### Thread-safe mode
+
+```bash
+export DJ_THREAD_SAFE=true
+```
+
+`thread_safe` is checked dynamically on each access to global state.
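+
+A minimal sketch of the dynamic check, assuming the flag is parsed from the `DJ_THREAD_SAFE` environment variable (`_load_thread_safe` is the helper referenced in the Implementation section below; the exact parsing rules are illustrative):
+
+```python
+import os
+
+
+def _load_thread_safe() -> bool:
+    # Re-read the environment on every call so the flag takes effect dynamically
+    return os.environ.get("DJ_THREAD_SAFE", "").lower() in ("1", "true", "yes")
+```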
+
+When `thread_safe=True`, accessing global state raises `ThreadSafetyError`:
+- `dj.config` raises `ThreadSafetyError`
+- `dj.conn()` raises `ThreadSafetyError`
+- `dj.Schema()` raises `ThreadSafetyError` (without explicit connection)
+- `dj.FreeTable()` raises `ThreadSafetyError` (without explicit connection)
+- `dj.Instance()` works - isolated instances are always allowed
+
+```python
+# thread_safe=True
+
+dj.config # ThreadSafetyError
+dj.conn() # ThreadSafetyError
+dj.Schema("name") # ThreadSafetyError
+
+inst = dj.Instance(host="h", user="u", password="p") # OK
+inst.Schema("name") # OK
+```
+
+## Behavior Summary
+
+| Operation | `thread_safe=False` | `thread_safe=True` |
+|-----------|--------------------|--------------------|
+| `dj.config` | `_global_config` | `ThreadSafetyError` |
+| `dj.conn()` | `_singleton_connection` | `ThreadSafetyError` |
+| `dj.Schema()` | Uses singleton | `ThreadSafetyError` |
+| `dj.FreeTable()` | Uses singleton | `ThreadSafetyError` |
+| `dj.Instance()` | Works | Works |
+| `inst.config` | Works | Works |
+| `inst.connection` | Works | Works |
+| `inst.Schema()` | Works | Works |
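+
+With thread-safe mode enabled, the intended pattern is one isolated `Instance` per request or worker thread, matching the multi-tenant scenario in the Problem section. A minimal sketch (host, credentials, and schema name are placeholders):
+
+```python
+import threading
+
+import datajoint as dj
+
+
+def handle_request(credentials: dict) -> None:
+    # Each worker creates its own isolated instance; no global state is touched
+    inst = dj.Instance(**credentials)
+    schema = inst.Schema("my_schema")
+    # ... declare and use tables bound to this schema ...
+
+
+creds = {"host": "localhost", "user": "user", "password": "password"}
+workers = [threading.Thread(target=handle_request, args=(creds,)) for _ in range(4)]
+for w in workers:
+    w.start()
+for w in workers:
+    w.join()
+```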
+
+## Lazy Loading
+
+The global config is created at module import time. The singleton connection is created lazily on first access:
+
+```python
+dj.config.database.user = "user" # Modifies global config (no connection yet)
+dj.config.database.password = "pw"
+dj.conn() # Creates singleton connection using global config
+dj.Schema("name") # Uses existing singleton connection
+```
+
+## Usage Example
+
+```python
+import datajoint as dj
+
+# Create isolated instance
+inst = dj.Instance(
+ host="localhost",
+ user="user",
+ password="password",
+)
+
+# Create schema
+schema = inst.Schema("my_schema")
+
+@schema
+class Mouse(dj.Manual):
+ definition = """
+ mouse_id: int
+ """
+
+# Use tables
+Mouse().insert1({"mouse_id": 1})
+Mouse().fetch()
+```
+
+## Architecture
+
+### Object graph
+
+There is exactly **one** global `Config` object created at import time in `settings.py`. Both the legacy API and the `Instance` API hang off `Connection` objects, each of which carries a `_config` reference.
+
+```
+settings.py
+ config = _create_config() ← THE single global Config
+
+instance.py
+ _global_config = settings.config ← same object (not a copy)
+ _singleton_connection = None ← lazily created Connection
+
+__init__.py
+ dj.config = _ConfigProxy() ← proxy → _global_config (with thread-safety check)
+ dj.conn() ← returns _singleton_connection
+ dj.Schema() ← uses _singleton_connection
+ dj.FreeTable() ← uses _singleton_connection
+
+Connection (singleton)
+ _config → _global_config ← same Config that dj.config writes to
+
+Connection (Instance)
+ _config → fresh Config ← isolated per-instance
+```
+
+### Config flow: singleton path
+
+```
+dj.config["safemode"] = False
+ ↓ _ConfigProxy.__setitem__
+_global_config["safemode"] = False (same object as settings.config)
+ ↓
+Connection._config["safemode"] (points to _global_config)
+ ↓
+schema.drop() reads self.connection._config["safemode"] → False ✓
+```
+
+### Config flow: Instance path
+
+```
+inst = dj.Instance(host=..., user=..., password=...)
+ ↓
+inst.config = _create_config() (fresh Config, independent)
+inst.connection._config = inst.config
+ ↓
+inst.config["safemode"] = False
+ ↓
+schema.drop() reads self.connection._config["safemode"] → False ✓
+```
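+
+The same flow expressed with the API above; a minimal sketch (host, credentials, and schema name are placeholders):
+
+```python
+inst = dj.Instance(host="localhost", user="user", password="password")
+inst.config["safemode"] = False   # writes only to this instance's Config
+schema = inst.Schema("my_schema")
+schema.drop()                     # reads safemode via schema.connection._config
+```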
+
+### Key invariant
+
+**All runtime config reads go through `self.connection._config`**, never through the global `config` directly. This ensures both the singleton and Instance paths read the correct config.
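+
+For example, `schema.drop()` reads the safemode flag through its own connection; a simplified sketch (confirmation prompt and SQL generation omitted):
+
+```python
+class Schema:
+    def drop(self, force: bool = False) -> None:
+        # Key invariant: read config via the connection, never via settings.config
+        if not force and self.connection._config["safemode"]:
+            ...  # ask the user to confirm before dropping
+        ...  # issue DROP statements through self.connection
+```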
+
+### Connection-scoped config reads
+
+Every module that previously imported `from .settings import config` now reads config from the connection:
+
+| Module | What was read | How it's read now |
+|--------|--------------|-------------------|
+| `schemas.py` | `config["safemode"]`, `config.database.create_tables` | `self.connection._config[...]` |
+| `table.py` | `config["safemode"]` in `delete()`, `drop()` | `self.connection._config["safemode"]` |
+| `expression.py` | `config["loglevel"]` in `__repr__()` | `self.connection._config["loglevel"]` |
+| `preview.py` | `config["display.*"]` (8 reads) | `query_expression.connection._config[...]` |
+| `autopopulate.py` | `config.jobs.allow_new_pk_fields`, `auto_refresh` | `self.connection._config.jobs.*` |
+| `jobs.py` | `config.jobs.default_priority`, `stale_timeout`, `keep_completed` | `self.connection._config.jobs.*` |
+| `declare.py` | `config.jobs.add_job_metadata` | `config` param (threaded from `table.py`) |
+| `diagram.py` | `config.display.diagram_direction` | `self._connection._config.display.*` |
+| `staged_insert.py` | `config.get_store_spec()` | `self._table.connection._config.get_store_spec()` |
+| `hash_registry.py` | `config.get_store_spec()` in 5 functions | `config` kwarg (falls back to `settings.config`) |
+| `builtin_codecs/hash.py` | `config` via hash_registry | `_config` from key dict → `config` kwarg to hash_registry |
+| `builtin_codecs/attach.py` | `config.get("download_path")` | `_config` from key dict (falls back to `settings.config`) |
+| `builtin_codecs/filepath.py` | `config.get_store_spec()` | `_config` from key dict (falls back to `settings.config`) |
+| `builtin_codecs/schema.py` | `config.get_store_spec()` in helpers | `config` kwarg to `_build_path()`, `_get_backend()` |
+| `builtin_codecs/npy.py` | `config` via schema helpers | `_config` from key dict → `config` kwarg to helpers |
+| `builtin_codecs/object.py` | `config` via schema helpers | `_config` from key dict → `config` kwarg to helpers |
+| `gc.py` | `config` via hash_registry | `schemas[0].connection._config` → `config` kwarg |
+
+### Functions that receive config as a parameter
+
+Some module-level functions cannot access `self.connection`. Config is threaded through:
+
+| Function | Caller | How config arrives |
+|----------|--------|--------------------|
+| `declare()` in `declare.py` | `Table.declare()` in `table.py` | `config=self.connection._config` kwarg |
+| `_get_job_version()` in `jobs.py` | `AutoPopulate._make_tuples()`, `Job.reserve()` | `config=self.connection._config` positional arg |
+| `get_store_backend()` in `hash_registry.py` | codecs, gc.py | `config` kwarg from key dict or schema connection |
+| `get_store_subfolding()` in `hash_registry.py` | `put_hash()` | `config` kwarg chained from caller |
+| `put_hash()` in `hash_registry.py` | `HashCodec.encode()` | `config` kwarg from `_config` in key dict |
+| `get_hash()` in `hash_registry.py` | `HashCodec.decode()` | `config` kwarg from `_config` in key dict |
+| `delete_path()` in `hash_registry.py` | `gc.collect()` | `config` kwarg from `schemas[0].connection._config` |
+| `decode_attribute()` in `codecs.py` | `expression.py` fetch methods | `connection` kwarg → extracts `connection._config` |
+
+All functions accept `config=None` and fall back to the global `settings.config` for backward compatibility.
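+
+The fallback pattern is the same in each case; a simplified sketch for one such function (signature abridged from `hash_registry.py`):
+
+```python
+def get_store_backend(store_name=None, *, config=None):
+    """Resolve a storage backend, preferring a connection-scoped config."""
+    if config is None:
+        # Backward-compatible fallback to the module-level global
+        from datajoint import settings
+        config = settings.config
+    spec = config.get_store_spec(store_name)
+    ...  # construct and return the backend described by spec
+```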
+
+## Implementation
+
+### 1. Create Instance class
+
+```python
+class Instance:
+ def __init__(self, host, user, password, port=3306, **kwargs):
+ self.config = _create_config() # Fresh config with defaults
+ # Apply any config overrides from kwargs
+ self.connection = Connection(host, user, password, port, ...)
+ self.connection._config = self.config
+
+ def Schema(self, name, **kwargs):
+ return Schema(name, connection=self.connection, **kwargs)
+
+ def FreeTable(self, full_table_name):
+ return FreeTable(self.connection, full_table_name)
+```
+
+### 2. Global config and singleton connection
+
+```python
+# settings.py - THE single global config
+config = _create_config() # Created at import time
+
+# instance.py - reuses the same config object
+_global_config = settings.config # Same reference, not a copy
+_singleton_connection = None # Created lazily
+
+def _check_thread_safe():
+ if _load_thread_safe():
+ raise ThreadSafetyError(
+ "Global DataJoint state is disabled in thread-safe mode. "
+ "Use dj.Instance() to create an isolated instance."
+ )
+
+def _get_singleton_connection():
+ _check_thread_safe()
+ global _singleton_connection
+ if _singleton_connection is None:
+ _singleton_connection = Connection(
+ host=_global_config.database.host,
+ user=_global_config.database.user,
+ password=_global_config.database.password,
+ ...
+ )
+ _singleton_connection._config = _global_config
+ return _singleton_connection
+```
+
+### 3. Legacy API with thread-safety checks
+
+```python
+# dj.config -> global config with thread-safety check
+class _ConfigProxy:
+ def __getattr__(self, name):
+ _check_thread_safe()
+ return getattr(_global_config, name)
+ def __setattr__(self, name, value):
+ _check_thread_safe()
+ setattr(_global_config, name, value)
+
+config = _ConfigProxy()
+
+# dj.conn() -> singleton connection (persistent across calls)
+def conn(host=None, user=None, password=None, *, reset=False):
+    global _singleton_connection
+    _check_thread_safe()
+    if reset or (_singleton_connection is None and credentials_provided):
+        _singleton_connection = Connection(...)
+        _singleton_connection._config = _global_config
+    return _get_singleton_connection()
+
+# dj.Schema() -> uses singleton connection
+def Schema(name, connection=None, **kwargs):
+ if connection is None:
+ _check_thread_safe()
+ connection = _get_singleton_connection()
+ return _Schema(name, connection=connection, **kwargs)
+
+# dj.FreeTable() -> uses singleton connection
+def FreeTable(conn_or_name, full_table_name=None):
+ if full_table_name is None:
+ _check_thread_safe()
+ return _FreeTable(_get_singleton_connection(), conn_or_name)
+ else:
+ return _FreeTable(conn_or_name, full_table_name)
+```
+
+## Global State Audit
+
+All module-level mutable state was reviewed for thread-safety implications.
+
+### Guarded (blocked in thread-safe mode)
+
+| State | Location | Mechanism |
+|-------|----------|-----------|
+| `config` singleton | `settings.py:979` | `_ConfigProxy` raises `ThreadSafetyError`; use `inst.config` instead |
+| `conn()` singleton | `connection.py:108` | `_check_thread_safe()` guard; use `inst.connection` instead |
+
+These are the two globals that carry connection-scoped state (credentials, database settings) and are the primary source of cross-tenant interference.
+
+### Safe by design (no guard needed)
+
+| State | Location | Rationale |
+|-------|----------|-----------|
+| `_codec_registry` | `codecs.py:47` | Effectively immutable after import. Registration runs in `__init_subclass__` under Python's import lock. Runtime mutation (`_load_entry_points`) is idempotent under the GIL. Codecs are part of the type system, not connection-scoped. |
+| `_entry_points_loaded` | `codecs.py:48` | Bool flag for idempotent lazy loading; worst case under concurrent access is redundant work, not corruption. |
+
+### Low risk (no guard needed)
+
+| State | Location | Rationale |
+|-------|----------|-----------|
+| Logging side effects | `logging.py:8,17,40-45,56` | Standard Python logging configuration. Monkey-patches `Logger` and replaces `sys.excepthook` at import time. Not DataJoint-specific mutable state. |
+| `use_32bit_dims` | `blob.py:65` | Runtime flag affecting deserialization. Rarely changed; not connection-scoped. |
+| `compression` dict | `blob.py:61` | Decompressor function registry. Populated at import time, effectively read-only thereafter. |
+| `_lazy_modules` | `__init__.py:92` | Import caching via `globals()` mutation. Protected by Python's import lock. |
+| `ADAPTERS` dict | `adapters/__init__.py:16` | Backend registry. Populated at import time, read-only in practice. |
+
+### Design principle
+
+Only state that is **connection-scoped** (credentials, database settings, connection objects) needs thread-safe guards. State that is **code-scoped** (type registries, import caches, logging configuration) is shared across all threads by design and does not vary between tenants.
+
+## Error Messages
+
+- Singleton access: `"Global DataJoint state is disabled in thread-safe mode. Use dj.Instance() to create an isolated instance."`
diff --git a/src/datajoint/__init__.py b/src/datajoint/__init__.py
index 7f809487d..7704ec1bc 100644
--- a/src/datajoint/__init__.py
+++ b/src/datajoint/__init__.py
@@ -23,6 +23,7 @@
"config",
"conn",
"Connection",
+ "Instance",
"Schema",
"VirtualModule",
"virtual_schema",
@@ -52,6 +53,7 @@
"errors",
"migrate",
"DataJointError",
+ "ThreadSafetyError",
"logger",
"cli",
"ValidationResult",
@@ -72,17 +74,191 @@
NpyRef,
)
from .blob import MatCell, MatStruct
-from .connection import Connection, conn
-from .errors import DataJointError
+from .connection import Connection
+from .errors import DataJointError, ThreadSafetyError
from .expression import AndList, Not, Top, U
+from .instance import Instance, _ConfigProxy, _get_singleton_connection, _global_config, _check_thread_safe
from .logging import logger
from .objectref import ObjectRef
-from .schemas import Schema, VirtualModule, list_schemas, virtual_schema
-from .settings import config
-from .table import FreeTable, Table, ValidationResult
+from .schemas import _Schema, VirtualModule, list_schemas, virtual_schema
+from .table import FreeTable as _FreeTable, Table, ValidationResult
from .user_tables import Computed, Imported, Lookup, Manual, Part
from .version import __version__
+# =============================================================================
+# Singleton-aware API
+# =============================================================================
+# config is a proxy that delegates to the global config (_global_config) with a thread-safety check
+config = _ConfigProxy()
+
+
+def conn(
+ host: str | None = None,
+ user: str | None = None,
+ password: str | None = None,
+ *,
+ reset: bool = False,
+ use_tls: bool | dict | None = None,
+) -> Connection:
+ """
+ Return a persistent connection object.
+
+    When called without arguments, returns the singleton connection using
+    credentials from dj.config. When connection parameters are provided and no
+    singleton exists yet (or ``reset=True``), they are used to create the
+    singleton; otherwise the existing singleton is returned.
+
+ Parameters
+ ----------
+    host : str, optional
+        Database hostname. Used when creating or resetting the singleton;
+        defaults to dj.config.
+    user : str, optional
+        Database username. Used when creating or resetting the singleton;
+        defaults to dj.config.
+    password : str, optional
+        Database password. Used when creating or resetting the singleton;
+        defaults to dj.config.
+ reset : bool, optional
+ If True, reset existing connection. Default False.
+ use_tls : bool or dict, optional
+ TLS encryption option.
+
+ Returns
+ -------
+ Connection
+ Database connection.
+
+ Raises
+ ------
+ ThreadSafetyError
+ If thread_safe mode is enabled.
+ """
+ import datajoint.instance as instance_module
+ from pydantic import SecretStr
+
+ _check_thread_safe()
+
+ # If reset requested, always recreate
+ # If credentials provided and no singleton exists, create one
+ # If credentials provided and singleton exists, return existing singleton
+ if reset or (
+ instance_module._singleton_connection is None and (host is not None or user is not None or password is not None)
+ ):
+ # Use provided values or fall back to config
+ host = host if host is not None else _global_config.database.host
+ user = user if user is not None else _global_config.database.user
+ raw_password = password if password is not None else _global_config.database.password
+ password = raw_password.get_secret_value() if isinstance(raw_password, SecretStr) else raw_password
+ port = _global_config.database.port
+ use_tls = use_tls if use_tls is not None else _global_config.database.use_tls
+
+ if user is None:
+ from .errors import DataJointError
+
+ raise DataJointError("Database user not configured. Set dj.config['database.user'] or pass user= argument.")
+ if password is None:
+ from .errors import DataJointError
+
+ raise DataJointError(
+ "Database password not configured. Set dj.config['database.password'] or pass password= argument."
+ )
+
+ instance_module._singleton_connection = Connection(host, user, password, port, use_tls, config_override=_global_config)
+
+ return _get_singleton_connection()
+
+
+class Schema(_Schema):
+ """
+ Decorator that binds table classes to a database schema.
+
+ When connection is not provided, uses the singleton connection.
+    In thread-safe mode (``DJ_THREAD_SAFE=true``), a connection must be
+    provided explicitly; alternatively, use ``dj.Instance().Schema()``.
+
+ Parameters
+ ----------
+ schema_name : str, optional
+ Database schema name. If omitted, call ``activate()`` later.
+ context : dict, optional
+ Namespace for foreign key lookup. None uses caller's context.
+ connection : Connection, optional
+ Database connection. Defaults to singleton connection.
+ create_schema : bool, optional
+ If False, raise error if schema doesn't exist. Default True.
+ create_tables : bool, optional
+ If False, raise error when accessing missing tables.
+ add_objects : dict, optional
+ Additional objects for declaration context.
+
+ Raises
+ ------
+ ThreadSafetyError
+ If thread_safe mode is enabled and no connection is provided.
+
+ Examples
+ --------
+ >>> schema = dj.Schema('my_schema')
+ >>> @schema
+ ... class Session(dj.Manual):
+ ... definition = '''
+ ... session_id : int
+ ... '''
+ """
+
+ def __init__(
+ self,
+ schema_name: str | None = None,
+ context: dict | None = None,
+ *,
+ connection: Connection | None = None,
+ create_schema: bool = True,
+ create_tables: bool | None = None,
+ add_objects: dict | None = None,
+ ) -> None:
+ if connection is None:
+ _check_thread_safe()
+ super().__init__(
+ schema_name,
+ context=context,
+ connection=connection,
+ create_schema=create_schema,
+ create_tables=create_tables,
+ add_objects=add_objects,
+ )
+
+
+def FreeTable(conn_or_name, full_table_name: str | None = None) -> _FreeTable:
+ """
+ Create a FreeTable for accessing a table without a dedicated class.
+
+ Can be called in two ways:
+ - ``FreeTable("schema.table")`` - uses singleton connection
+ - ``FreeTable(connection, "schema.table")`` - uses provided connection
+
+ Parameters
+ ----------
+ conn_or_name : Connection or str
+ Either a Connection object, or the full table name if using singleton.
+ full_table_name : str, optional
+ Full table name when first argument is a connection.
+
+ Returns
+ -------
+ FreeTable
+ A FreeTable instance for the specified table.
+
+ Raises
+ ------
+ ThreadSafetyError
+ If thread_safe mode is enabled and using singleton.
+ """
+ if full_table_name is None:
+ # Called as FreeTable("db.table") - use singleton connection
+ _check_thread_safe()
+ return _FreeTable(_get_singleton_connection(), conn_or_name)
+ else:
+ # Called as FreeTable(conn, "db.table") - use provided connection
+ return _FreeTable(conn_or_name, full_table_name)
+
+
# =============================================================================
# Lazy imports — heavy dependencies loaded on first access
# =============================================================================
diff --git a/src/datajoint/adapters/base.py b/src/datajoint/adapters/base.py
index 35b32ed5f..011f306ab 100644
--- a/src/datajoint/adapters/base.py
+++ b/src/datajoint/adapters/base.py
@@ -169,6 +169,27 @@ def quote_identifier(self, name: str) -> str:
"""
...
+ @abstractmethod
+ def split_full_table_name(self, full_table_name: str) -> tuple[str, str]:
+ """
+ Split a fully-qualified table name into schema and table components.
+
+ Inverse of quoting: strips backend-specific identifier quotes
+ and splits into (schema, table).
+
+ Parameters
+ ----------
+ full_table_name : str
+ Quoted full table name (e.g., ```\\`schema\\`.\\`table\\` ``` or
+ ``"schema"."table"``).
+
+ Returns
+ -------
+ tuple[str, str]
+ (schema_name, table_name) with quotes stripped.
+ """
+ ...
+
@abstractmethod
def quote_string(self, value: str) -> str:
"""
@@ -615,6 +636,23 @@ def list_schemas_sql(self) -> str:
"""
...
+ @abstractmethod
+ def schema_exists_sql(self, schema_name: str) -> str:
+ """
+ Generate query to check if a schema exists.
+
+ Parameters
+ ----------
+ schema_name : str
+ Name of schema to check.
+
+ Returns
+ -------
+ str
+ SQL query that returns a row if the schema exists.
+ """
+ ...
+
@abstractmethod
def list_tables_sql(self, schema_name: str, pattern: str | None = None) -> str:
"""
@@ -710,6 +748,55 @@ def get_foreign_keys_sql(self, schema_name: str, table_name: str) -> str:
"""
...
+ @abstractmethod
+ def load_primary_keys_sql(self, schemas_list: str, like_pattern: str) -> str:
+ """
+ Generate query to load primary key columns for all tables across schemas.
+
+ Used by the dependency graph to build the schema graph.
+
+ Parameters
+ ----------
+ schemas_list : str
+ Comma-separated, quoted schema names for an IN clause.
+ like_pattern : str
+ SQL LIKE pattern to exclude (e.g., "'~%%'" for internal tables).
+
+ Returns
+ -------
+ str
+ SQL query returning rows with columns:
+ - tab: fully qualified table name (quoted)
+ - column_name: primary key column name
+ """
+ ...
+
+ @abstractmethod
+ def load_foreign_keys_sql(self, schemas_list: str, like_pattern: str) -> str:
+ """
+ Generate query to load foreign key relationships across schemas.
+
+ Used by the dependency graph to build the schema graph.
+
+ Parameters
+ ----------
+ schemas_list : str
+ Comma-separated, quoted schema names for an IN clause.
+ like_pattern : str
+ SQL LIKE pattern to exclude (e.g., "'~%%'" for internal tables).
+
+ Returns
+ -------
+ str
+ SQL query returning rows (as dicts) with columns:
+ - constraint_name: FK constraint name
+ - referencing_table: fully qualified child table name (quoted)
+ - referenced_table: fully qualified parent table name (quoted)
+ - column_name: FK column in child table
+ - referenced_column_name: referenced column in parent table
+ """
+ ...
+
@abstractmethod
def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str:
"""
diff --git a/src/datajoint/adapters/mysql.py b/src/datajoint/adapters/mysql.py
index 88339335f..3c28a85e6 100644
--- a/src/datajoint/adapters/mysql.py
+++ b/src/datajoint/adapters/mysql.py
@@ -75,7 +75,6 @@ def connect(
Password for authentication.
**kwargs : Any
Additional MySQL-specific parameters:
- - init_command: SQL initialization command
- ssl: TLS/SSL configuration dict (deprecated, use use_tls)
- use_tls: bool or dict - DataJoint's SSL parameter (preferred)
- charset: Character set (default from kwargs)
@@ -85,7 +84,6 @@ def connect(
pymysql.Connection
MySQL connection object.
"""
- init_command = kwargs.get("init_command")
# Handle both ssl (old) and use_tls (new) parameter names
ssl_config = kwargs.get("use_tls", kwargs.get("ssl"))
# Convert boolean True to dict for PyMySQL (PyMySQL expects dict or SSLContext)
@@ -99,7 +97,6 @@ def connect(
"port": port,
"user": user,
"passwd": password,
- "init_command": init_command,
"sql_mode": "NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY",
"charset": charset,
@@ -203,6 +200,11 @@ def quote_identifier(self, name: str) -> str:
"""
return f"`{name}`"
+ def split_full_table_name(self, full_table_name: str) -> tuple[str, str]:
+ """Split ```\\`schema\\`.\\`table\\` ``` into ``('schema', 'table')``."""
+ schema, table = full_table_name.replace("`", "").split(".")
+ return schema, table
+
def quote_string(self, value: str) -> str:
"""
Quote string literal for MySQL with escaping.
@@ -614,6 +616,10 @@ def list_schemas_sql(self) -> str:
"""Query to list all databases in MySQL."""
return "SELECT schema_name FROM information_schema.schemata"
+ def schema_exists_sql(self, schema_name: str) -> str:
+ """Query to check if a database exists in MySQL."""
+ return f"SELECT schema_name FROM information_schema.schemata WHERE schema_name = {self.quote_string(schema_name)}"
+
def list_tables_sql(self, schema_name: str, pattern: str | None = None) -> str:
"""Query to list tables in a database."""
sql = f"SHOW TABLES IN {self.quote_identifier(schema_name)}"
@@ -655,6 +661,32 @@ def get_foreign_keys_sql(self, schema_name: str, table_name: str) -> str:
f"ORDER BY constraint_name, ordinal_position"
)
+ def load_primary_keys_sql(self, schemas_list: str, like_pattern: str) -> str:
+ """Query to load all primary key columns across schemas."""
+ tab_expr = "concat('`', table_schema, '`.`', table_name, '`')"
+ return (
+ f"SELECT {tab_expr} as tab, column_name "
+ f"FROM information_schema.key_column_usage "
+ f"WHERE table_name NOT LIKE {like_pattern} "
+ f"AND table_schema in ({schemas_list}) "
+ f"AND constraint_name='PRIMARY'"
+ )
+
+ def load_foreign_keys_sql(self, schemas_list: str, like_pattern: str) -> str:
+ """Query to load all foreign key relationships across schemas."""
+ tab_expr = "concat('`', table_schema, '`.`', table_name, '`')"
+ ref_tab_expr = "concat('`', referenced_table_schema, '`.`', referenced_table_name, '`')"
+ return (
+ f"SELECT constraint_name, "
+ f"{tab_expr} as referencing_table, "
+ f"{ref_tab_expr} as referenced_table, "
+ f"column_name, referenced_column_name "
+ f"FROM information_schema.key_column_usage "
+ f"WHERE referenced_table_name NOT LIKE {like_pattern} "
+ f"AND (referenced_table_schema in ({schemas_list}) "
+ f"OR referenced_table_schema is not NULL AND table_schema in ({schemas_list}))"
+ )
+
def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str:
"""Query to get FK constraint details from information_schema."""
return (
diff --git a/src/datajoint/adapters/postgres.py b/src/datajoint/adapters/postgres.py
index 12fecae6a..2caebef75 100644
--- a/src/datajoint/adapters/postgres.py
+++ b/src/datajoint/adapters/postgres.py
@@ -249,6 +249,11 @@ def quote_identifier(self, name: str) -> str:
"""
return f'"{name}"'
+ def split_full_table_name(self, full_table_name: str) -> tuple[str, str]:
+ """Split ``"schema"."table"`` into ``('schema', 'table')``."""
+ schema, table = full_table_name.replace('"', "").split(".")
+ return schema, table
+
def quote_string(self, value: str) -> str:
"""
Quote string literal for PostgreSQL with escaping.
@@ -721,6 +726,10 @@ def list_schemas_sql(self) -> str:
"WHERE schema_name NOT IN ('pg_catalog', 'information_schema')"
)
+ def schema_exists_sql(self, schema_name: str) -> str:
+ """Query to check if a schema exists in PostgreSQL."""
+ return f"SELECT schema_name FROM information_schema.schemata WHERE schema_name = {self.quote_string(schema_name)}"
+
def list_tables_sql(self, schema_name: str, pattern: str | None = None) -> str:
"""Query to list tables in a schema."""
sql = (
@@ -795,6 +804,44 @@ def get_foreign_keys_sql(self, schema_name: str, table_name: str) -> str:
f"ORDER BY kcu.constraint_name, kcu.ordinal_position"
)
+ def load_primary_keys_sql(self, schemas_list: str, like_pattern: str) -> str:
+ """Query to load all primary key columns across schemas."""
+ tab_expr = "'\"' || kcu.table_schema || '\".\"' || kcu.table_name || '\"'"
+ return (
+ f"SELECT {tab_expr} as tab, kcu.column_name "
+ f"FROM information_schema.key_column_usage kcu "
+ f"JOIN information_schema.table_constraints tc "
+ f"ON kcu.constraint_name = tc.constraint_name "
+ f"AND kcu.table_schema = tc.table_schema "
+ f"WHERE kcu.table_name NOT LIKE {like_pattern} "
+ f"AND kcu.table_schema in ({schemas_list}) "
+ f"AND tc.constraint_type = 'PRIMARY KEY'"
+ )
+
+ def load_foreign_keys_sql(self, schemas_list: str, like_pattern: str) -> str:
+ """Query to load all foreign key relationships across schemas."""
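+        # Uses pg_constraint system catalogs rather than information_schema:
+        # constraint_column_usage lacks ordinal_position, so composite FKs would
+        # produce a Cartesian product. unnest(conkey, confkey) WITH ORDINALITY
+        # preserves the correct column pairing.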
+ return (
+ f"SELECT "
+ f"c.conname as constraint_name, "
+ f"'\"' || ns1.nspname || '\".\"' || cl1.relname || '\"' as referencing_table, "
+ f"'\"' || ns2.nspname || '\".\"' || cl2.relname || '\"' as referenced_table, "
+ f"a1.attname as column_name, "
+ f"a2.attname as referenced_column_name "
+ f"FROM pg_constraint c "
+ f"JOIN pg_class cl1 ON c.conrelid = cl1.oid "
+ f"JOIN pg_namespace ns1 ON cl1.relnamespace = ns1.oid "
+ f"JOIN pg_class cl2 ON c.confrelid = cl2.oid "
+ f"JOIN pg_namespace ns2 ON cl2.relnamespace = ns2.oid "
+ f"CROSS JOIN LATERAL unnest(c.conkey, c.confkey) WITH ORDINALITY AS cols(conkey, confkey, ord) "
+ f"JOIN pg_attribute a1 ON a1.attrelid = cl1.oid AND a1.attnum = cols.conkey "
+ f"JOIN pg_attribute a2 ON a2.attrelid = cl2.oid AND a2.attnum = cols.confkey "
+ f"WHERE c.contype = 'f' "
+ f"AND cl1.relname NOT LIKE {like_pattern} "
+ f"AND (ns2.nspname in ({schemas_list}) "
+ f"OR ns1.nspname in ({schemas_list})) "
+ f"ORDER BY c.conname, cols.ord"
+ )
+
def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str:
"""
Query to get FK constraint details from information_schema.
diff --git a/src/datajoint/autopopulate.py b/src/datajoint/autopopulate.py
index 7660e43ec..ae8be3b82 100644
--- a/src/datajoint/autopopulate.py
+++ b/src/datajoint/autopopulate.py
@@ -146,10 +146,8 @@ def _declare_check(self, primary_key: list[str], fk_attribute_map: dict[str, tup
If native (non-FK) PK attributes are found, unless bypassed via
``dj.config.jobs.allow_new_pk_fields_in_computed_tables = True``.
"""
- from .settings import config
-
# Check if validation is bypassed
- if config.jobs.allow_new_pk_fields_in_computed_tables:
+ if self.connection._config.jobs.allow_new_pk_fields_in_computed_tables:
return
# Check for native (non-FK) primary key attributes
@@ -477,8 +475,6 @@ def _populate_distributed(
"""
from tqdm import tqdm
- from .settings import config
-
# Define a signal handler for SIGTERM
def handler(signum, frame):
logger.info("Populate terminated by SIGTERM")
@@ -489,7 +485,7 @@ def handler(signum, frame):
try:
# Refresh job queue if configured
if refresh is None:
- refresh = config.jobs.auto_refresh
+ refresh = self.connection._config.jobs.auto_refresh
if refresh:
# Use delay=-1 to ensure jobs are immediately schedulable
# (avoids race condition with scheduled_time <= CURRENT_TIMESTAMP(3) check)
@@ -659,7 +655,7 @@ def _populate1(
key,
start_time=datetime.datetime.fromtimestamp(start_time),
duration=duration,
- version=_get_job_version(),
+ version=_get_job_version(self.connection._config),
)
if jobs is not None:
diff --git a/src/datajoint/builtin_codecs/attach.py b/src/datajoint/builtin_codecs/attach.py
index f9a454b1a..aa10f2424 100644
--- a/src/datajoint/builtin_codecs/attach.py
+++ b/src/datajoint/builtin_codecs/attach.py
@@ -98,14 +98,15 @@ def decode(self, stored: bytes, *, key: dict | None = None) -> str:
"""
from pathlib import Path
- from ..settings import config
-
# Split on first null byte
null_pos = stored.index(b"\x00")
filename = stored[:null_pos].decode("utf-8")
contents = stored[null_pos + 1 :]
# Write to download path
+ config = (key or {}).get("_config")
+ if config is None:
+ from ..settings import config
download_path = Path(config.get("download_path", "."))
download_path.mkdir(parents=True, exist_ok=True)
local_path = download_path / filename
diff --git a/src/datajoint/builtin_codecs/filepath.py b/src/datajoint/builtin_codecs/filepath.py
index 9c05b2385..a0400499b 100644
--- a/src/datajoint/builtin_codecs/filepath.py
+++ b/src/datajoint/builtin_codecs/filepath.py
@@ -98,9 +98,12 @@ def encode(self, value: Any, *, key: dict | None = None, store_name: str | None
"""
from datetime import datetime, timezone
- from .. import config
from ..hash_registry import get_store_backend
+ config = (key or {}).get("_config")
+ if config is None:
+ from ..settings import config
+
path = str(value)
# Get store spec to check prefix configuration
@@ -137,7 +140,7 @@ def encode(self, value: Any, *, key: dict | None = None, store_name: str | None
raise ValueError(f" must use prefix '{filepath_prefix}' (filepath_prefix). Got path: {path}")
# Verify file exists
- backend = get_store_backend(store_name)
+ backend = get_store_backend(store_name, config=config)
if not backend.exists(path):
raise FileNotFoundError(f"File not found in store '{store_name or 'default'}': {path}")
@@ -174,8 +177,9 @@ def decode(self, stored: dict, *, key: dict | None = None) -> Any:
from ..objectref import ObjectRef
from ..hash_registry import get_store_backend
+ config = (key or {}).get("_config")
store_name = stored.get("store")
- backend = get_store_backend(store_name)
+ backend = get_store_backend(store_name, config=config)
return ObjectRef.from_json(stored, backend=backend)
def validate(self, value: Any) -> None:
diff --git a/src/datajoint/builtin_codecs/hash.py b/src/datajoint/builtin_codecs/hash.py
index 676c1916f..bb3a3852f 100644
--- a/src/datajoint/builtin_codecs/hash.py
+++ b/src/datajoint/builtin_codecs/hash.py
@@ -76,7 +76,8 @@ def encode(self, value: bytes, *, key: dict | None = None, store_name: str | Non
from ..hash_registry import put_hash
schema_name = (key or {}).get("_schema", "unknown")
- return put_hash(value, schema_name=schema_name, store_name=store_name)
+ config = (key or {}).get("_config")
+ return put_hash(value, schema_name=schema_name, store_name=store_name, config=config)
def decode(self, stored: dict, *, key: dict | None = None) -> bytes:
"""
@@ -96,7 +97,8 @@ def decode(self, stored: dict, *, key: dict | None = None) -> bytes:
"""
from ..hash_registry import get_hash
- return get_hash(stored)
+ config = (key or {}).get("_config")
+ return get_hash(stored, config=config)
def validate(self, value: Any) -> None:
"""Validate that value is bytes."""
diff --git a/src/datajoint/builtin_codecs/npy.py b/src/datajoint/builtin_codecs/npy.py
index 51c5731ee..54853437b 100644
--- a/src/datajoint/builtin_codecs/npy.py
+++ b/src/datajoint/builtin_codecs/npy.py
@@ -336,9 +336,10 @@ def encode(
# Extract context using inherited helper
schema, table, field, primary_key = self._extract_context(key)
+ config = (key or {}).get("_config")
# Build schema-addressed storage path
- path, _ = self._build_path(schema, table, field, primary_key, ext=".npy", store_name=store_name)
+ path, _ = self._build_path(schema, table, field, primary_key, ext=".npy", store_name=store_name, config=config)
# Serialize to .npy format
buffer = io.BytesIO()
@@ -346,7 +347,7 @@ def encode(
npy_bytes = buffer.getvalue()
# Upload to storage using inherited helper
- backend = self._get_backend(store_name)
+ backend = self._get_backend(store_name, config=config)
backend.put_buffer(npy_bytes, path)
# Return metadata (includes numpy-specific shape/dtype)
@@ -373,5 +374,6 @@ def decode(self, stored: dict, *, key: dict | None = None) -> NpyRef:
NpyRef
Lazy array reference with metadata access and numpy integration.
"""
- backend = self._get_backend(stored.get("store"))
+ config = (key or {}).get("_config")
+ backend = self._get_backend(stored.get("store"), config=config)
return NpyRef(stored, backend)
diff --git a/src/datajoint/builtin_codecs/object.py b/src/datajoint/builtin_codecs/object.py
index 268651aea..1c0d8c673 100644
--- a/src/datajoint/builtin_codecs/object.py
+++ b/src/datajoint/builtin_codecs/object.py
@@ -104,6 +104,7 @@ def encode(
# Extract context using inherited helper
schema, table, field, primary_key = self._extract_context(key)
+ config = (key or {}).get("_config")
# Check for pre-computed metadata (from staged insert)
if isinstance(value, dict) and "path" in value:
@@ -145,10 +146,10 @@ def encode(
raise TypeError(f" expects bytes or path, got {type(value).__name__}")
# Build storage path using inherited helper
- path, token = self._build_path(schema, table, field, primary_key, ext=ext, store_name=store_name)
+ path, token = self._build_path(schema, table, field, primary_key, ext=ext, store_name=store_name, config=config)
# Get storage backend using inherited helper
- backend = self._get_backend(store_name)
+ backend = self._get_backend(store_name, config=config)
# Upload content
if is_dir:
@@ -192,7 +193,8 @@ def decode(self, stored: dict, *, key: dict | None = None) -> Any:
"""
from ..objectref import ObjectRef
- backend = self._get_backend(stored.get("store"))
+ config = (key or {}).get("_config")
+ backend = self._get_backend(stored.get("store"), config=config)
return ObjectRef.from_json(stored, backend=backend)
def validate(self, value: Any) -> None:
diff --git a/src/datajoint/builtin_codecs/schema.py b/src/datajoint/builtin_codecs/schema.py
index 18bd62d00..c8cc0759d 100644
--- a/src/datajoint/builtin_codecs/schema.py
+++ b/src/datajoint/builtin_codecs/schema.py
@@ -108,6 +108,7 @@ def _build_path(
primary_key: dict,
ext: str | None = None,
store_name: str | None = None,
+ config=None,
) -> tuple[str, str]:
"""
Build schema-addressed storage path.
@@ -131,6 +132,8 @@ def _build_path(
File extension (e.g., ".npy", ".zarr").
store_name : str, optional
Store name for retrieving partition configuration.
+ config : Config, optional
+ Config instance. If None, falls back to global settings.config.
Returns
-------
@@ -139,7 +142,9 @@ def _build_path(
is a unique identifier.
"""
from ..storage import build_object_path
- from .. import config
+
+ if config is None:
+ from ..settings import config
# Get store configuration for partition_pattern and token_length
spec = config.get_store_spec(store_name)
@@ -156,7 +161,7 @@ def _build_path(
token_length=token_length,
)
- def _get_backend(self, store_name: str | None = None):
+ def _get_backend(self, store_name: str | None = None, config=None):
"""
Get storage backend by name.
@@ -164,6 +169,8 @@ def _get_backend(self, store_name: str | None = None):
----------
store_name : str, optional
Store name. If None, returns default store.
+ config : Config, optional
+ Config instance. If None, falls back to global settings.config.
Returns
-------
@@ -172,4 +179,4 @@ def _get_backend(self, store_name: str | None = None):
"""
from ..hash_registry import get_store_backend
- return get_store_backend(store_name)
+ return get_store_backend(store_name, config=config)
diff --git a/src/datajoint/codecs.py b/src/datajoint/codecs.py
index 5c192d46e..d7fbaf42d 100644
--- a/src/datajoint/codecs.py
+++ b/src/datajoint/codecs.py
@@ -43,7 +43,15 @@ class MyTable(dj.Manual):
logger = logging.getLogger(__name__.split(".")[0])
-# Global codec registry - maps name to Codec instance
+# Global codec registry - maps name to Codec instance.
+#
+# Thread safety: This registry is effectively immutable after import.
+# Registration happens in __init_subclass__ during class definition, which is
+# serialized by Python's import lock. The only runtime mutation is
+# _load_entry_points(), which is idempotent and guarded by a bool flag;
+# under CPython's GIL, concurrent calls may do redundant work but cannot
+# corrupt the dict. Codecs are part of the type system (tied to code, not to
+# any particular connection or tenant), so per-instance isolation is unnecessary.
_codec_registry: dict[str, Codec] = {}
_entry_points_loaded: bool = False
@@ -507,7 +515,7 @@ def lookup_codec(codec_spec: str) -> tuple[Codec, str | None]:
# =============================================================================
-def decode_attribute(attr, data, squeeze: bool = False):
+def decode_attribute(attr, data, squeeze: bool = False, connection=None):
"""
Decode raw database value using attribute's codec or native type handling.
@@ -520,6 +528,8 @@ def decode_attribute(attr, data, squeeze: bool = False):
attr: Attribute from the table's heading.
data: Raw value fetched from the database.
squeeze: If True, remove singleton dimensions from numpy arrays.
+ connection: Connection instance for config access. If provided,
+ ``connection._config`` is passed to codecs via the key dict.
Returns:
Decoded Python value.
@@ -552,9 +562,14 @@ def decode_attribute(attr, data, squeeze: bool = False):
elif final_dtype.lower() == "binary(16)":
data = uuid_module.UUID(bytes=data)
+ # Build decode key with config if connection is available
+ decode_key = None
+ if connection is not None:
+ decode_key = {"_config": connection._config}
+
# Apply decoders in reverse order: innermost first, then outermost
for codec in reversed(type_chain):
- data = codec.decode(data, key=None)
+ data = codec.decode(data, key=decode_key)
# Squeeze arrays if requested
if squeeze and isinstance(data, np.ndarray):
diff --git a/src/datajoint/condition.py b/src/datajoint/condition.py
index 0335d6adb..55f095246 100644
--- a/src/datajoint/condition.py
+++ b/src/datajoint/condition.py
@@ -244,6 +244,13 @@ def assert_join_compatibility(
if isinstance(expr1, U) or isinstance(expr2, U):
return
+ # Check that both expressions use the same connection
+ if expr1.connection is not expr2.connection:
+ raise DataJointError(
+ "Cannot operate on expressions from different connections. "
+ "Ensure both operands use the same dj.Instance or global connection."
+ )
+
if semantic_check:
# Check if lineage tracking is available for both expressions
if not expr1.heading.lineage_available or not expr2.heading.lineage_available:
diff --git a/src/datajoint/connection.py b/src/datajoint/connection.py
index 21b48e638..e9eab0921 100644
--- a/src/datajoint/connection.py
+++ b/src/datajoint/connection.py
@@ -11,13 +11,16 @@
import re
import warnings
from contextlib import contextmanager
-from typing import Callable
+from typing import TYPE_CHECKING
from . import errors
from .adapters import get_adapter
from .blob import pack, unpack
from .dependencies import Dependencies
from .settings import config
+
+if TYPE_CHECKING:
+ from .settings import Config
from .version import __version__
logger = logging.getLogger(__name__.split(".")[0])
@@ -55,7 +58,6 @@ def conn(
user: str | None = None,
password: str | None = None,
*,
- init_fun: Callable | None = None,
reset: bool = False,
use_tls: bool | dict | None = None,
) -> Connection:
@@ -73,8 +75,6 @@ def conn(
Database username. Required if not set in config.
password : str, optional
Database password. Required if not set in config.
- init_fun : callable, optional
- Initialization function called after connection.
reset : bool, optional
If True, reset existing connection. Default False.
use_tls : bool or dict, optional
@@ -103,9 +103,8 @@ def conn(
raise errors.DataJointError(
"Database password not configured. Set datajoint.config['database.password'] or pass password= argument."
)
- init_fun = init_fun if init_fun is not None else config["connection.init_function"]
use_tls = use_tls if use_tls is not None else config["database.use_tls"]
- conn.connection = Connection(host, user, password, None, init_fun, use_tls)
+ conn.connection = Connection(host, user, password, None, use_tls)
return conn.connection
@@ -150,8 +149,6 @@ class Connection:
Database password.
port : int, optional
Port number. Overridden if specified in host.
- init_fun : str, optional
- SQL initialization command.
use_tls : bool or dict, optional
TLS encryption option.
@@ -169,15 +166,20 @@ def __init__(
user: str,
password: str,
port: int | None = None,
- init_fun: str | None = None,
use_tls: bool | dict | None = None,
+ *,
+ backend: str | None = None,
+ config_override: "Config | None" = None,
) -> None:
+ # Config reference — use override if provided, else global config
+ self._config = config_override if config_override is not None else config
+
if ":" in host:
# the port in the hostname overrides the port argument
host, port = host.split(":")
port = int(port)
elif port is None:
- port = config["database.port"]
+ port = self._config["database.port"]
self.conn_info = dict(host=host, port=port, user=user, passwd=password)
if use_tls is not False:
# use_tls can be: None (auto-detect), True (enable), False (disable), or dict (custom config)
@@ -190,13 +192,13 @@ def __init__(
# use_tls=True: enable SSL with default settings
self.conn_info["ssl"] = True
self.conn_info["ssl_input"] = use_tls
- self.init_fun = init_fun
self._conn = None
self._query_cache = None
self._is_closed = True # Mark as closed until connect() succeeds
- # Select adapter based on configured backend
- backend = config["database.backend"]
+ # Select adapter: explicit backend > config backend
+ if backend is None:
+ backend = self._config["database.backend"]
self.adapter = get_adapter(backend)
self.connect()
@@ -227,8 +229,7 @@ def connect(self) -> None:
port=self.conn_info["port"],
user=self.conn_info["user"],
password=self.conn_info["passwd"],
- init_command=self.init_fun,
- charset=config["connection.charset"],
+ charset=self._config["connection.charset"],
use_tls=self.conn_info.get("ssl"),
)
except Exception as ssl_error:
@@ -244,8 +245,7 @@ def connect(self) -> None:
port=self.conn_info["port"],
user=self.conn_info["user"],
password=self.conn_info["passwd"],
- init_command=self.init_fun,
- charset=config["connection.charset"],
+ charset=self._config["connection.charset"],
use_tls=False, # Explicitly disable SSL for fallback
)
else:
@@ -271,8 +271,8 @@ def set_query_cache(self, query_cache: str | None = None) -> None:
def purge_query_cache(self) -> None:
"""Delete all cached query results."""
- if isinstance(config.get(cache_key), str) and pathlib.Path(config[cache_key]).is_dir():
- for path in pathlib.Path(config[cache_key]).iterdir():
+ if isinstance(self._config.get(cache_key), str) and pathlib.Path(self._config[cache_key]).is_dir():
+ for path in pathlib.Path(self._config[cache_key]).iterdir():
if not path.is_dir():
path.unlink()
@@ -413,11 +413,11 @@ def query(
if use_query_cache and not re.match(r"\s*(SELECT|SHOW)", query):
raise errors.DataJointError("Only SELECT queries are allowed when query caching is on.")
if use_query_cache:
- if not config[cache_key]:
+ if not self._config[cache_key]:
raise errors.DataJointError(f"Provide filepath dj.config['{cache_key}'] when using query caching.")
# Cache key is backend-specific (no identifier normalization needed)
hash_ = hashlib.md5((str(self._query_cache)).encode() + pack(args) + query.encode()).hexdigest()
- cache_path = pathlib.Path(config[cache_key]) / str(hash_)
+ cache_path = pathlib.Path(self._config[cache_key]) / str(hash_)
try:
buffer = cache_path.read_bytes()
except FileNotFoundError:
@@ -426,7 +426,7 @@ def query(
return EmulatedCursor(unpack(buffer))
if reconnect is None:
- reconnect = config["database.reconnect"]
+ reconnect = self._config["database.reconnect"]
logger.debug("Executing SQL:" + query[:query_log_max_length])
cursor = self.adapter.get_cursor(self._conn, as_dict=as_dict)
try:
diff --git a/src/datajoint/declare.py b/src/datajoint/declare.py
index 375daa07e..6af24ae55 100644
--- a/src/datajoint/declare.py
+++ b/src/datajoint/declare.py
@@ -15,7 +15,6 @@
from .codecs import lookup_codec
from .condition import translate_attribute
from .errors import DataJointError
-from .settings import config
# Core DataJoint types - scientist-friendly names that are fully supported
# These are recorded in field comments using :type: syntax for reconstruction
@@ -295,12 +294,9 @@ def compile_foreign_key(
# ref.support[0] may have cached quoting from a different backend
# Extract database and table name and rebuild with current adapter
parent_full_name = ref.support[0]
- # Try to parse as database.table (with or without quotes)
- parts = parent_full_name.replace('"', "").replace("`", "").split(".")
- if len(parts) == 2:
- ref_table_name = f"{adapter.quote_identifier(parts[0])}.{adapter.quote_identifier(parts[1])}"
- else:
- ref_table_name = adapter.quote_identifier(parts[0])
+ # Parse as database.table using the adapter's quoting convention
+ parts = adapter.split_full_table_name(parent_full_name)
+ ref_table_name = f"{adapter.quote_identifier(parts[0])}.{adapter.quote_identifier(parts[1])}"
foreign_key_sql.append(
f"FOREIGN KEY ({fk_cols}) REFERENCES {ref_table_name} ({pk_cols}) ON UPDATE CASCADE ON DELETE RESTRICT"
@@ -401,7 +397,7 @@ def prepare_declare(
def declare(
- full_table_name: str, definition: str, context: dict, adapter
+ full_table_name: str, definition: str, context: dict, adapter, *, config=None
) -> tuple[str, list[str], list[str], dict[str, tuple[str, str]], list[str], list[str]]:
r"""
Parse a definition and generate SQL CREATE TABLE statement.
@@ -416,6 +412,8 @@ def declare(
Namespace for resolving foreign key references.
adapter : DatabaseAdapter
Database adapter for backend-specific SQL generation.
+ config : Config, optional
+ Configuration object. If None, falls back to global config.
Returns
-------
@@ -464,6 +462,10 @@ def declare(
) = prepare_declare(definition, context, adapter)
# Add hidden job metadata for Computed/Imported tables (not parts)
+ if config is None:
+ from .settings import config as _config
+
+ config = _config
if config.jobs.add_job_metadata:
# Check if this is a Computed (__) or Imported (_) table, but not a Part (contains __ in middle)
is_computed = table_name.startswith("__") and "__" not in table_name[2:]
diff --git a/src/datajoint/dependencies.py b/src/datajoint/dependencies.py
index 83162a112..99556345e 100644
--- a/src/datajoint/dependencies.py
+++ b/src/datajoint/dependencies.py
@@ -164,92 +164,21 @@ def load(self, force: bool = True) -> None:
# Build schema list for IN clause
schemas_list = ", ".join(adapter.quote_string(s) for s in self._conn.schemas)
- # Backend-specific queries for primary keys and foreign keys
- # Note: Both PyMySQL and psycopg2 use %s placeholders, so escape % as %%
+ # Load primary keys and foreign keys via adapter methods
+ # Note: Both PyMySQL and psycopg use %s placeholders, so escape % as %%
like_pattern = "'~%%'"
- if adapter.backend == "mysql":
- # MySQL: use concat() and MySQL-specific information_schema columns
- tab_expr = "concat('`', table_schema, '`.`', table_name, '`')"
-
- # load primary key info (MySQL uses constraint_name='PRIMARY')
- keys = self._conn.query(
- f"""
- SELECT {tab_expr} as tab, column_name
- FROM information_schema.key_column_usage
- WHERE table_name NOT LIKE {like_pattern}
- AND table_schema in ({schemas_list})
- AND constraint_name='PRIMARY'
- """
- )
- pks = defaultdict(set)
- for key in keys:
- pks[key[0]].add(key[1])
-
- # load foreign keys (MySQL has referenced_* columns)
- ref_tab_expr = "concat('`', referenced_table_schema, '`.`', referenced_table_name, '`')"
- fk_keys = self._conn.query(
- f"""
- SELECT constraint_name,
- {tab_expr} as referencing_table,
- {ref_tab_expr} as referenced_table,
- column_name, referenced_column_name
- FROM information_schema.key_column_usage
- WHERE referenced_table_name NOT LIKE {like_pattern}
- AND (referenced_table_schema in ({schemas_list})
- OR referenced_table_schema is not NULL AND table_schema in ({schemas_list}))
- """,
- as_dict=True,
- )
- else:
- # PostgreSQL: use || concatenation and different query structure
- tab_expr = "'\"' || kcu.table_schema || '\".\"' || kcu.table_name || '\"'"
-
- # load primary key info (PostgreSQL uses constraint_type='PRIMARY KEY')
- keys = self._conn.query(
- f"""
- SELECT {tab_expr} as tab, kcu.column_name
- FROM information_schema.key_column_usage kcu
- JOIN information_schema.table_constraints tc
- ON kcu.constraint_name = tc.constraint_name
- AND kcu.table_schema = tc.table_schema
- WHERE kcu.table_name NOT LIKE {like_pattern}
- AND kcu.table_schema in ({schemas_list})
- AND tc.constraint_type = 'PRIMARY KEY'
- """
- )
- pks = defaultdict(set)
- for key in keys:
- pks[key[0]].add(key[1])
-
- # load foreign keys using pg_constraint system catalogs
- # The information_schema approach creates a Cartesian product for composite FKs
- # because constraint_column_usage doesn't have ordinal_position.
- # Using pg_constraint with unnest(conkey, confkey) WITH ORDINALITY gives correct mapping.
- fk_keys = self._conn.query(
- f"""
- SELECT
- c.conname as constraint_name,
- '"' || ns1.nspname || '"."' || cl1.relname || '"' as referencing_table,
- '"' || ns2.nspname || '"."' || cl2.relname || '"' as referenced_table,
- a1.attname as column_name,
- a2.attname as referenced_column_name
- FROM pg_constraint c
- JOIN pg_class cl1 ON c.conrelid = cl1.oid
- JOIN pg_namespace ns1 ON cl1.relnamespace = ns1.oid
- JOIN pg_class cl2 ON c.confrelid = cl2.oid
- JOIN pg_namespace ns2 ON cl2.relnamespace = ns2.oid
- CROSS JOIN LATERAL unnest(c.conkey, c.confkey) WITH ORDINALITY AS cols(conkey, confkey, ord)
- JOIN pg_attribute a1 ON a1.attrelid = cl1.oid AND a1.attnum = cols.conkey
- JOIN pg_attribute a2 ON a2.attrelid = cl2.oid AND a2.attnum = cols.confkey
- WHERE c.contype = 'f'
- AND cl1.relname NOT LIKE {like_pattern}
- AND (ns2.nspname in ({schemas_list})
- OR ns1.nspname in ({schemas_list}))
- ORDER BY c.conname, cols.ord
- """,
- as_dict=True,
- )
+ # load primary key info
+ keys = self._conn.query(adapter.load_primary_keys_sql(schemas_list, like_pattern))
+ pks = defaultdict(set)
+ for key in keys:
+ pks[key[0]].add(key[1])
+
+ # load foreign keys
+ fk_keys = self._conn.query(
+ adapter.load_foreign_keys_sql(schemas_list, like_pattern),
+ as_dict=True,
+ )
# add nodes to the graph
for n, pk in pks.items():
diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py
index 7034d122b..75e00c21c 100644
--- a/src/datajoint/diagram.py
+++ b/src/datajoint/diagram.py
@@ -16,7 +16,6 @@
from .dependencies import topo_sort
from .errors import DataJointError
-from .settings import config
from .table import Table, lookup_class_name
from .user_tables import Computed, Imported, Lookup, Manual, Part, _AliasNode, _get_tier
@@ -105,6 +104,7 @@ def __init__(self, source, context=None) -> None:
self.nodes_to_show = set(source.nodes_to_show)
self._expanded_nodes = set(source._expanded_nodes)
self.context = source.context
+ self._connection = source._connection
super().__init__(source)
return
@@ -126,6 +126,7 @@ def __init__(self, source, context=None) -> None:
raise DataJointError("Could not find database connection in %s" % repr(source[0]))
# initialize graph from dependencies
+ self._connection = connection
connection.dependencies.load()
super().__init__(connection.dependencies)
@@ -584,7 +585,7 @@ def make_dot(self):
Tables are grouped by schema, with the Python module name shown as the
group label when available.
"""
- direction = config.display.diagram_direction
+ direction = self._connection._config.display.diagram_direction
graph = self._make_graph()
# Apply collapse logic if needed
@@ -857,7 +858,7 @@ def make_mermaid(self) -> str:
Session --> Neuron
"""
graph = self._make_graph()
- direction = config.display.diagram_direction
+ direction = self._connection._config.display.diagram_direction
# Apply collapse logic if needed
graph, collapsed_counts = self._apply_collapse(graph)
diff --git a/src/datajoint/errors.py b/src/datajoint/errors.py
index 7e10f021d..bba032b23 100644
--- a/src/datajoint/errors.py
+++ b/src/datajoint/errors.py
@@ -72,3 +72,7 @@ class MissingExternalFile(DataJointError):
class BucketInaccessible(DataJointError):
"""S3 bucket is inaccessible."""
+
+
+class ThreadSafetyError(DataJointError):
+ """Global DataJoint state is disabled in thread-safe mode."""
diff --git a/src/datajoint/expression.py b/src/datajoint/expression.py
index 883853cd3..1b5f5ac9e 100644
--- a/src/datajoint/expression.py
+++ b/src/datajoint/expression.py
@@ -20,7 +20,6 @@
from .errors import DataJointError
from .codecs import decode_attribute
from .preview import preview, repr_html
-from .settings import config
logger = logging.getLogger(__name__.split(".")[0])
@@ -716,7 +715,7 @@ def fetch(
import warnings
warnings.warn(
- "fetch() is deprecated in DataJoint 2.0. " "Use to_dicts(), to_pandas(), to_arrays(), or keys() instead.",
+ "fetch() is deprecated in DataJoint 2.0. Use to_dicts(), to_pandas(), to_arrays(), or keys() instead.",
DeprecationWarning,
stacklevel=2,
)
@@ -818,7 +817,10 @@ def fetch1(self, *attrs, squeeze=False):
row = cursor.fetchone()
if not row or cursor.fetchone():
raise DataJointError("fetch1 requires exactly one tuple in the input set.")
- return {name: decode_attribute(heading[name], row[name], squeeze=squeeze) for name in heading.names}
+ return {
+ name: decode_attribute(heading[name], row[name], squeeze=squeeze, connection=self.connection)
+ for name in heading.names
+ }
else:
# Handle "KEY" specially - it means primary key columns
def is_key(attr):
@@ -893,7 +895,10 @@ def to_dicts(self, order_by=None, limit=None, offset=None, squeeze=False):
expr = self._apply_top(order_by, limit, offset)
cursor = expr.cursor(as_dict=True)
heading = expr.heading
- return [{name: decode_attribute(heading[name], row[name], squeeze) for name in heading.names} for row in cursor]
+ return [
+ {name: decode_attribute(heading[name], row[name], squeeze, connection=expr.connection) for name in heading.names}
+ for row in cursor
+ ]
def to_pandas(self, order_by=None, limit=None, offset=None, squeeze=False):
"""
@@ -1064,7 +1069,7 @@ def to_arrays(self, *attrs, include_key=False, order_by=None, limit=None, offset
return result_arrays[0] if len(attrs) == 1 else tuple(result_arrays)
else:
# Fetch all columns as structured array
- get = partial(decode_attribute, squeeze=squeeze)
+ get = partial(decode_attribute, squeeze=squeeze, connection=expr.connection)
cursor = expr.cursor(as_dict=False)
rows = list(cursor.fetchall())
@@ -1218,7 +1223,10 @@ def __iter__(self):
cursor = self.cursor(as_dict=True)
heading = self.heading
for row in cursor:
- yield {name: decode_attribute(heading[name], row[name], squeeze=False) for name in heading.names}
+ yield {
+ name: decode_attribute(heading[name], row[name], squeeze=False, connection=self.connection)
+ for name in heading.names
+ }
def cursor(self, as_dict=False):
"""
@@ -1247,7 +1255,7 @@ def __repr__(self):
str
String representation of the QueryExpression.
"""
- return super().__repr__() if config["loglevel"].lower() == "debug" else self.preview()
+ return super().__repr__() if self.connection._config["loglevel"].lower() == "debug" else self.preview()
def preview(self, limit=None, width=None):
"""
@@ -1406,8 +1414,11 @@ def create(cls, arg1, arg2):
arg2 = arg2() # instantiate if a class
if not isinstance(arg2, QueryExpression):
raise DataJointError("A QueryExpression can only be unioned with another QueryExpression")
- if arg1.connection != arg2.connection:
- raise DataJointError("Cannot operate on QueryExpressions originating from different connections.")
+ if arg1.connection is not arg2.connection:
+ raise DataJointError(
+ "Cannot operate on expressions from different connections. "
+ "Ensure both operands use the same dj.Instance or global connection."
+ )
if set(arg1.primary_key) != set(arg2.primary_key):
raise DataJointError("The operands of a union must share the same primary key.")
if set(arg1.heading.secondary_attributes) & set(arg2.heading.secondary_attributes):
diff --git a/src/datajoint/gc.py b/src/datajoint/gc.py
index 71a4e8d08..7f083416b 100644
--- a/src/datajoint/gc.py
+++ b/src/datajoint/gc.py
@@ -44,7 +44,7 @@
from .errors import DataJointError
if TYPE_CHECKING:
- from .schemas import Schema
+ from .schemas import _Schema as Schema
logger = logging.getLogger(__name__.split(".")[0])
@@ -308,7 +308,7 @@ def scan_schema_references(
return referenced
-def list_stored_hashes(store_name: str | None = None) -> dict[str, int]:
+def list_stored_hashes(store_name: str | None = None, config=None) -> dict[str, int]:
"""
List all hash-addressed items in storage.
@@ -320,6 +320,8 @@ def list_stored_hashes(store_name: str | None = None) -> dict[str, int]:
----------
store_name : str, optional
Store to scan (None = default store).
+ config : Config, optional
+ Config instance. If None, falls back to global settings.config.
Returns
-------
@@ -328,7 +330,7 @@ def list_stored_hashes(store_name: str | None = None) -> dict[str, int]:
"""
import re
- backend = get_store_backend(store_name)
+ backend = get_store_backend(store_name, config=config)
stored: dict[str, int] = {}
# Hash-addressed storage: _hash/{schema}/{subfolders...}/{hash}
@@ -369,7 +371,7 @@ def list_stored_hashes(store_name: str | None = None) -> dict[str, int]:
return stored
-def list_schema_paths(store_name: str | None = None) -> dict[str, int]:
+def list_schema_paths(store_name: str | None = None, config=None) -> dict[str, int]:
"""
List all schema-addressed items in storage.
@@ -380,13 +382,15 @@ def list_schema_paths(store_name: str | None = None) -> dict[str, int]:
----------
store_name : str, optional
Store to scan (None = default store).
+ config : Config, optional
+ Config instance. If None, falls back to global settings.config.
Returns
-------
dict[str, int]
Dict mapping storage path to size in bytes.
"""
- backend = get_store_backend(store_name)
+ backend = get_store_backend(store_name, config=config)
stored: dict[str, int] = {}
try:
@@ -427,7 +431,7 @@ def list_schema_paths(store_name: str | None = None) -> dict[str, int]:
return stored
-def delete_schema_path(path: str, store_name: str | None = None) -> bool:
+def delete_schema_path(path: str, store_name: str | None = None, config=None) -> bool:
"""
Delete a schema-addressed directory from storage.
@@ -437,13 +441,15 @@ def delete_schema_path(path: str, store_name: str | None = None) -> bool:
Storage path (relative to store root).
store_name : str, optional
Store name (None = default store).
+ config : Config, optional
+ Config instance. If None, falls back to global settings.config.
Returns
-------
bool
True if deleted, False if not found.
"""
- backend = get_store_backend(store_name)
+ backend = get_store_backend(store_name, config=config)
try:
full_path = backend._full_path(path)
@@ -497,15 +503,18 @@ def scan(
if not schemas:
raise DataJointError("At least one schema must be provided")
+ # Extract config from the first schema's connection
+ _config = schemas[0].connection._config if schemas else None
+
# --- Hash-addressed storage ---
hash_referenced = scan_hash_references(*schemas, store_name=store_name, verbose=verbose)
- hash_stored = list_stored_hashes(store_name)
+ hash_stored = list_stored_hashes(store_name, config=_config)
orphaned_hashes = set(hash_stored.keys()) - hash_referenced
hash_orphaned_bytes = sum(hash_stored.get(h, 0) for h in orphaned_hashes)
# --- Schema-addressed storage ---
schema_paths_referenced = scan_schema_references(*schemas, store_name=store_name, verbose=verbose)
- schema_paths_stored = list_schema_paths(store_name)
+ schema_paths_stored = list_schema_paths(store_name, config=_config)
orphaned_paths = set(schema_paths_stored.keys()) - schema_paths_referenced
schema_paths_orphaned_bytes = sum(schema_paths_stored.get(p, 0) for p in orphaned_paths)
@@ -570,6 +579,9 @@ def collect(
# First scan to find orphaned items
stats = scan(*schemas, store_name=store_name, verbose=verbose)
+ # Extract config from the first schema's connection
+ _config = schemas[0].connection._config if schemas else None
+
hash_deleted = 0
schema_paths_deleted = 0
bytes_freed = 0
@@ -578,12 +590,12 @@ def collect(
if not dry_run:
# Delete orphaned hashes
if stats["hash_orphaned"] > 0:
- hash_stored = list_stored_hashes(store_name)
+ hash_stored = list_stored_hashes(store_name, config=_config)
for path in stats["orphaned_hashes"]:
try:
size = hash_stored.get(path, 0)
- if delete_path(path, store_name):
+ if delete_path(path, store_name, config=_config):
hash_deleted += 1
bytes_freed += size
if verbose:
@@ -594,12 +606,12 @@ def collect(
# Delete orphaned schema paths
if stats["schema_paths_orphaned"] > 0:
- schema_paths_stored = list_schema_paths(store_name)
+ schema_paths_stored = list_schema_paths(store_name, config=_config)
for path in stats["orphaned_paths"]:
try:
size = schema_paths_stored.get(path, 0)
- if delete_schema_path(path, store_name):
+ if delete_schema_path(path, store_name, config=_config):
schema_paths_deleted += 1
bytes_freed += size
if verbose:
diff --git a/src/datajoint/hash_registry.py b/src/datajoint/hash_registry.py
index a285e5df1..331c836cd 100644
--- a/src/datajoint/hash_registry.py
+++ b/src/datajoint/hash_registry.py
@@ -38,7 +38,6 @@
from typing import Any
from .errors import DataJointError
-from .settings import config
from .storage import StorageBackend
logger = logging.getLogger(__name__.split(".")[0])
@@ -131,7 +130,7 @@ def build_hash_path(
return f"_hash/{schema_name}/{content_hash}"
-def get_store_backend(store_name: str | None = None) -> StorageBackend:
+def get_store_backend(store_name: str | None = None, config=None) -> StorageBackend:
"""
Get a StorageBackend for hash-addressed storage.
@@ -139,18 +138,22 @@ def get_store_backend(store_name: str | None = None) -> StorageBackend:
----------
store_name : str, optional
Name of the store to use. If None, uses stores.default.
+ config : Config, optional
+ Config instance. If None, falls back to global settings.config.
Returns
-------
StorageBackend
StorageBackend instance.
"""
+ if config is None:
+ from .settings import config
# get_store_spec handles None by using stores.default
spec = config.get_store_spec(store_name)
return StorageBackend(spec)
-def get_store_subfolding(store_name: str | None = None) -> tuple[int, ...] | None:
+def get_store_subfolding(store_name: str | None = None, config=None) -> tuple[int, ...] | None:
"""
Get the subfolding configuration for a store.
@@ -158,12 +161,16 @@ def get_store_subfolding(store_name: str | None = None) -> tuple[int, ...] | Non
----------
store_name : str, optional
Name of the store. If None, uses stores.default.
+ config : Config, optional
+ Config instance. If None, falls back to global settings.config.
Returns
-------
tuple[int, ...] | None
Subfolding pattern (e.g., (2, 2)) or None for flat storage.
"""
+ if config is None:
+ from .settings import config
spec = config.get_store_spec(store_name)
subfolding = spec.get("subfolding")
if subfolding is not None:
@@ -175,6 +182,7 @@ def put_hash(
data: bytes,
schema_name: str,
store_name: str | None = None,
+ config=None,
) -> dict[str, Any]:
"""
Store content using hash-addressed storage.
@@ -193,6 +201,8 @@ def put_hash(
Database/schema name for path isolation.
store_name : str, optional
Name of the store. If None, uses default store.
+ config : Config, optional
+ Config instance. If None, falls back to global settings.config.
Returns
-------
@@ -200,10 +210,10 @@ def put_hash(
Metadata dict with keys: hash, path, schema, store, size.
"""
content_hash = compute_hash(data)
- subfolding = get_store_subfolding(store_name)
+ subfolding = get_store_subfolding(store_name, config=config)
path = build_hash_path(content_hash, schema_name, subfolding)
- backend = get_store_backend(store_name)
+ backend = get_store_backend(store_name, config=config)
# Check if content already exists (deduplication within schema)
if not backend.exists(path):
@@ -221,7 +231,7 @@ def put_hash(
}
-def get_hash(metadata: dict[str, Any]) -> bytes:
+def get_hash(metadata: dict[str, Any], config=None) -> bytes:
"""
Retrieve content using stored metadata.
@@ -232,6 +242,8 @@ def get_hash(metadata: dict[str, Any]) -> bytes:
----------
metadata : dict
Metadata dict with keys: path, hash, store (optional).
+ config : Config, optional
+ Config instance. If None, falls back to global settings.config.
Returns
-------
@@ -249,15 +261,13 @@ def get_hash(metadata: dict[str, Any]) -> bytes:
expected_hash = metadata["hash"]
store_name = metadata.get("store")
- backend = get_store_backend(store_name)
+ backend = get_store_backend(store_name, config=config)
data = backend.get_buffer(path)
# Verify hash for integrity
actual_hash = compute_hash(data)
if actual_hash != expected_hash:
- raise DataJointError(
- f"Hash mismatch: expected {expected_hash}, got {actual_hash}. " f"Data at {path} may be corrupted."
- )
+ raise DataJointError(f"Hash mismatch: expected {expected_hash}, got {actual_hash}. Data at {path} may be corrupted.")
return data
@@ -265,6 +275,7 @@ def get_hash(metadata: dict[str, Any]) -> bytes:
def delete_path(
path: str,
store_name: str | None = None,
+ config=None,
) -> bool:
"""
Delete content at the specified path from storage.
@@ -278,6 +289,8 @@ def delete_path(
Storage path (as stored in metadata).
store_name : str, optional
Name of the store. If None, uses default store.
+ config : Config, optional
+ Config instance. If None, falls back to global settings.config.
Returns
-------
@@ -288,7 +301,7 @@ def delete_path(
--------
This permanently deletes content. Ensure no references exist first.
"""
- backend = get_store_backend(store_name)
+ backend = get_store_backend(store_name, config=config)
if backend.exists(path):
backend.remove(path)
diff --git a/src/datajoint/instance.py b/src/datajoint/instance.py
new file mode 100644
index 000000000..455336a7c
--- /dev/null
+++ b/src/datajoint/instance.py
@@ -0,0 +1,311 @@
+"""
+DataJoint Instance for thread-safe operation.
+
+An Instance encapsulates a config and connection pair, providing isolated
+database contexts for multi-tenant applications.
+"""
+
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, Any, Literal
+
+from .connection import Connection
+from .errors import ThreadSafetyError
+from .settings import Config, _create_config, config as _settings_config
+
+if TYPE_CHECKING:
+ from .schemas import _Schema as SchemaClass
+ from .table import FreeTable as FreeTableClass
+
+
+def _load_thread_safe() -> bool:
+ """
+ Check if thread-safe mode is enabled.
+
+ Thread-safe mode is controlled by the ``DJ_THREAD_SAFE`` environment
+ variable, which must be set before the process starts.
+
+ Returns
+ -------
+ bool
+ True if thread-safe mode is enabled.
+ """
+ env_val = os.environ.get("DJ_THREAD_SAFE", "").lower()
+ if env_val in ("true", "1", "yes"):
+ return True
+ return False
+
+
+class Instance:
+ """
+ Encapsulates a DataJoint configuration and connection.
+
+ Each Instance has its own Config and Connection, providing isolation
+ for multi-tenant applications. Use ``dj.Instance()`` to create isolated
+ instances, or access the singleton via ``dj.config``, ``dj.conn()``, etc.
+
+ Parameters
+ ----------
+ host : str
+ Database hostname.
+ user : str
+ Database username.
+ password : str
+ Database password.
+ port : int, optional
+ Database port. Defaults to 3306 for MySQL, 5432 for PostgreSQL.
+ use_tls : bool or dict, optional
+ TLS configuration.
+ backend : str, optional
+ Database backend: ``"mysql"`` or ``"postgresql"``. Default from config.
+ **kwargs : Any
+ Additional config overrides applied to this instance's config.
+
+ Attributes
+ ----------
+ config : Config
+ Configuration for this instance.
+ connection : Connection
+ Database connection for this instance.
+
+ Examples
+ --------
+ >>> inst = dj.Instance(host="localhost", user="root", password="secret")
+ >>> inst.config.safemode = False
+ >>> schema = inst.Schema("my_schema")
+ """
+
+ def __init__(
+ self,
+ host: str,
+ user: str,
+ password: str,
+ port: int | None = None,
+ use_tls: bool | dict | None = None,
+ backend: Literal["mysql", "postgresql"] | None = None,
+ **kwargs: Any,
+ ) -> None:
+ # Create fresh config with defaults loaded from env/file
+ self.config = _create_config()
+
+ # Apply backend override before other kwargs (port default depends on it)
+ if backend is not None:
+ self.config.database.backend = backend
+ # Re-derive port default since _create_config resolved it before backend was set
+ if port is None and "database__port" not in kwargs:
+ self.config.database.port = 5432 if backend == "postgresql" else 3306
+
+ # Apply any config overrides from kwargs
+ for key, value in kwargs.items():
+ if hasattr(self.config, key):
+ setattr(self.config, key, value)
+ elif "__" in key:
+ # Handle nested keys like database__reconnect
+ parts = key.split("__")
+ obj = self.config
+ for part in parts[:-1]:
+ obj = getattr(obj, part)
+ setattr(obj, parts[-1], value)
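+            # Keys matching neither a top-level field nor a "__" nested path are silently ignored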
+
+ # Determine port
+ if port is None:
+ port = self.config.database.port
+
+ # Create connection with this instance's config and backend
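+        # config_override binds this instance's config to the connection (read elsewhere as connection._config)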
+ self.connection = Connection(
+ host,
+ user,
+ password,
+ port,
+ use_tls,
+ backend=self.config.database.backend,
+ config_override=self.config,
+ )
+
+ def Schema(
+ self,
+ schema_name: str,
+ *,
+ context: dict[str, Any] | None = None,
+ create_schema: bool = True,
+ create_tables: bool | None = None,
+ add_objects: dict[str, Any] | None = None,
+ ) -> "SchemaClass":
+ """
+ Create a Schema bound to this instance's connection.
+
+ Parameters
+ ----------
+ schema_name : str
+ Database schema name.
+ context : dict, optional
+ Namespace for foreign key lookup.
+ create_schema : bool, optional
+ If False, raise error if schema doesn't exist. Default True.
+ create_tables : bool, optional
+ If False, raise error when accessing missing tables.
+ add_objects : dict, optional
+ Additional objects for declaration context.
+
+ Returns
+ -------
+ Schema
+ A Schema using this instance's connection.
+ """
+ from .schemas import _Schema
+
+ return _Schema(
+ schema_name,
+ context=context,
+ connection=self.connection,
+ create_schema=create_schema,
+ create_tables=create_tables,
+ add_objects=add_objects,
+ )
+
+ def FreeTable(self, full_table_name: str) -> "FreeTableClass":
+ """
+ Create a FreeTable bound to this instance's connection.
+
+ Parameters
+ ----------
+ full_table_name : str
+ Full table name as ``'schema.table'`` or ```schema`.`table```.
+
+ Returns
+ -------
+ FreeTable
+ A FreeTable using this instance's connection.
+ """
+ from .table import FreeTable
+
+ return FreeTable(self.connection, full_table_name)
+
+ def __repr__(self) -> str:
+ return f"Instance({self.connection!r})"
+
+
+# =============================================================================
+# Singleton management
+# =============================================================================
+# The global config is created at module load time and can be modified before connecting.
+# The singleton connection is created lazily on the first call to conn() or Schema().
+
+# Reuse the config created in settings.py — there must be exactly one global config
+_global_config: Config = _settings_config
+_singleton_connection: Connection | None = None
+
+
+def _check_thread_safe() -> None:
+ """
+ Check if thread-safe mode is enabled and raise if so.
+
+ Raises
+ ------
+ ThreadSafetyError
+ If thread_safe mode is enabled.
+ """
+ if _load_thread_safe():
+ raise ThreadSafetyError(
+ "Global DataJoint state is disabled in thread-safe mode. " "Use dj.Instance() to create an isolated instance."
+ )
+
+
+def _get_singleton_connection() -> Connection:
+ """
+ Get or create the singleton Connection.
+
+ Uses credentials from the global config.
+
+ Raises
+ ------
+ ThreadSafetyError
+ If thread_safe mode is enabled.
+ DataJointError
+ If credentials are not configured.
+ """
+ global _singleton_connection
+
+ _check_thread_safe()
+
+ if _singleton_connection is None:
+ from .errors import DataJointError
+
+ host = _global_config.database.host
+ user = _global_config.database.user
+ raw_password = _global_config.database.password
+ password = raw_password.get_secret_value() if raw_password is not None else None
+ port = _global_config.database.port
+ use_tls = _global_config.database.use_tls
+
+ if user is None:
+ raise DataJointError(
+ "Database user not configured. Set dj.config['database.user'] or DJ_USER environment variable."
+ )
+ if password is None:
+ raise DataJointError(
+ "Database password not configured. Set dj.config['database.password'] or DJ_PASS environment variable."
+ )
+
+ _singleton_connection = Connection(host, user, password, port, use_tls, config_override=_global_config)
+
+ return _singleton_connection
+
+
+class _ConfigProxy:
+ """
+ Proxy that delegates to the global config, with thread-safety checks.
+
+    In thread-safe mode, every access raises ThreadSafetyError, except the static ``save_template`` and ``__repr__``.
+ """
+
+ def __getattr__(self, name: str) -> Any:
+ _check_thread_safe()
+ return getattr(_global_config, name)
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ _check_thread_safe()
+ setattr(_global_config, name, value)
+
+ def __getitem__(self, key: str) -> Any:
+ _check_thread_safe()
+ return _global_config[key]
+
+ def __setitem__(self, key: str, value: Any) -> None:
+ _check_thread_safe()
+ _global_config[key] = value
+
+ def __delitem__(self, key: str) -> None:
+ _check_thread_safe()
+ del _global_config[key]
+
+ def get(self, key: str, default: Any = None) -> Any:
+ _check_thread_safe()
+ return _global_config.get(key, default)
+
+ def override(self, **kwargs: Any):
+ _check_thread_safe()
+ return _global_config.override(**kwargs)
+
+ def load(self, filename: str) -> None:
+ _check_thread_safe()
+ return _global_config.load(filename)
+
+ def get_store_spec(self, store: str | None = None, *, use_filepath_default: bool = False) -> dict[str, Any]:
+ _check_thread_safe()
+ return _global_config.get_store_spec(store, use_filepath_default=use_filepath_default)
+
+ @staticmethod
+ def save_template(
+ path: str = "datajoint.json",
+ minimal: bool = True,
+ create_secrets_dir: bool = True,
+ ):
+ # save_template is a static method, no thread-safety check needed
+ return Config.save_template(path, minimal, create_secrets_dir)
+
+ def __repr__(self) -> str:
+ if _load_thread_safe():
+ return "ConfigProxy (thread-safe mode - use dj.Instance())"
+ return repr(_global_config)
diff --git a/src/datajoint/jobs.py b/src/datajoint/jobs.py
index 5a0eb2a86..6da8377dd 100644
--- a/src/datajoint/jobs.py
+++ b/src/datajoint/jobs.py
@@ -24,16 +24,22 @@
logger = logging.getLogger(__name__.split(".")[0])
-def _get_job_version() -> str:
+def _get_job_version(config=None) -> str:
"""
Get version string based on config settings.
+ Parameters
+ ----------
+ config : Config, optional
+ Configuration object. If None, falls back to global config.
+
Returns
-------
str
Version string, or empty string if version tracking disabled.
"""
- from .settings import config
+ if config is None:
+ from .settings import config
method = config.jobs.version_method
if method is None or method == "none":
@@ -349,17 +355,15 @@ def refresh(
         3. Remove stale jobs: jobs older than stale_timeout whose keys are not in key_source
4. Remove orphaned jobs: reserved jobs older than orphan_timeout (if specified)
"""
- from .settings import config
-
# Ensure jobs table exists
if not self.is_declared:
self.declare()
# Get defaults from config
if priority is None:
- priority = config.jobs.default_priority
+ priority = self.connection._config.jobs.default_priority
if stale_timeout is None:
- stale_timeout = config.jobs.stale_timeout
+ stale_timeout = self.connection._config.jobs.stale_timeout
result = {"added": 0, "removed": 0, "orphaned": 0, "re_pended": 0}
@@ -392,7 +396,7 @@ def refresh(
pass # Job already exists
# 2. Re-pend success jobs if keep_completed=True
- if config.jobs.keep_completed:
+ if self.connection._config.jobs.keep_completed:
# Success jobs whose keys are in key_source but not in target
# Disable semantic_check for Job table operations (job table PK has different lineage than target)
success_to_repend = self.completed.restrict(key_source, semantic_check=False).restrict(
@@ -462,7 +466,7 @@ def reserve(self, key: dict) -> bool:
os.getpid(),
self.connection.connection_id,
self.connection.get_user(),
- _get_job_version(),
+ _get_job_version(self.connection._config),
]
cursor = self.connection.query(query, args=args)
return cursor.rowcount == 1
@@ -485,9 +489,7 @@ def complete(self, key: dict, duration: float | None = None) -> None:
- If True: updates status to ``'success'`` with completion time and duration
- If False: deletes the job entry
"""
- from .settings import config
-
- if config.jobs.keep_completed:
+ if self.connection._config.jobs.keep_completed:
# Use server time for completed_time
server_now = self.connection.query("SELECT CURRENT_TIMESTAMP").fetchone()[0]
pk = self._get_pk(key)
@@ -545,13 +547,11 @@ def ignore(self, key: dict) -> None:
key : dict
Primary key dict of the job.
"""
- from .settings import config
-
pk = self._get_pk(key)
if pk in self:
self.update1({**pk, "status": "ignore"})
else:
- priority = config.jobs.default_priority
+ priority = self.connection._config.jobs.default_priority
self.insert1({**pk, "status": "ignore", "priority": priority})
def progress(self) -> dict:
diff --git a/src/datajoint/migrate.py b/src/datajoint/migrate.py
index d48afae62..2ff0dfcb8 100644
--- a/src/datajoint/migrate.py
+++ b/src/datajoint/migrate.py
@@ -7,7 +7,7 @@
.. note::
This module is provided temporarily to assist with migration from pre-2.0.
- It will be deprecated in DataJoint 2.1 and removed in 2.2.
+ It will be deprecated in DataJoint 2.1 and removed in 2.3.
Complete your migrations while on DataJoint 2.0.
Note on Terminology
@@ -32,14 +32,14 @@
# Show deprecation warning starting in 2.1
if Version(__version__) >= Version("2.1"):
warnings.warn(
- "datajoint.migrate is deprecated and will be removed in DataJoint 2.2. "
+ "datajoint.migrate is deprecated and will be removed in DataJoint 2.3. "
"Complete your schema migrations before upgrading.",
DeprecationWarning,
stacklevel=2,
)
if TYPE_CHECKING:
- from .schemas import Schema
+ from .schemas import _Schema as Schema
logger = logging.getLogger(__name__.split(".")[0])
@@ -653,7 +653,7 @@ def add_job_metadata_columns(target, dry_run: bool = True) -> dict:
- Future populate() calls will fill in metadata for new rows
- This does NOT retroactively populate metadata for existing rows
"""
- from .schemas import Schema
+ from .schemas import _Schema
from .table import Table
result = {
@@ -664,7 +664,7 @@ def add_job_metadata_columns(target, dry_run: bool = True) -> dict:
}
# Determine tables to process
- if isinstance(target, Schema):
+ if isinstance(target, _Schema):
schema = target
# Get all user tables in the schema
tables_query = """
diff --git a/src/datajoint/preview.py b/src/datajoint/preview.py
index 92d09d874..0b80ad15f 100644
--- a/src/datajoint/preview.py
+++ b/src/datajoint/preview.py
@@ -2,8 +2,6 @@
import json
-from .settings import config
-
def _format_object_display(json_data):
"""Format object metadata for display in query results."""
@@ -44,6 +42,7 @@ def _get_blob_placeholder(heading, field_name, html_escape=False):
def preview(query_expression, limit, width):
heading = query_expression.heading
rel = query_expression.proj(*heading.non_blobs)
+ config = query_expression.connection._config
# Object fields use codecs - not specially handled in simplified model
object_fields = []
if limit is None:
@@ -105,6 +104,7 @@ def get_display_value(tup, f, idx):
def repr_html(query_expression):
heading = query_expression.heading
rel = query_expression.proj(*heading.non_blobs)
+ config = query_expression.connection._config
# Object fields use codecs - not specially handled in simplified model
object_fields = []
tuples = rel.to_arrays(limit=config["display.limit"] + 1)
diff --git a/src/datajoint/schemas.py b/src/datajoint/schemas.py
index 1780bbaaf..8747cdbf2 100644
--- a/src/datajoint/schemas.py
+++ b/src/datajoint/schemas.py
@@ -7,23 +7,20 @@
from __future__ import annotations
-import collections
import inspect
-import itertools
import logging
import re
import types
import warnings
from typing import TYPE_CHECKING, Any
-from .connection import conn
from .errors import AccessError, DataJointError
+from .instance import _get_singleton_connection
if TYPE_CHECKING:
from .connection import Connection
from .heading import Heading
from .jobs import Job
-from .settings import config
from .table import FreeTable, lookup_class_name
from .user_tables import Computed, Imported, Lookup, Manual, Part, _get_tier
from .utils import to_camel_case, user_choice
@@ -54,7 +51,7 @@ def ordered_dir(class_: type) -> list[str]:
return attr_list
-class Schema:
+class _Schema:
"""
Decorator that binds table classes to a database schema.
@@ -120,7 +117,7 @@ def __init__(
self.database = None
self.context = context
self.create_schema = create_schema
- self.create_tables = create_tables if create_tables is not None else config.database.create_tables
+ self.create_tables = create_tables # None means "use connection config default"
self.add_objects = add_objects
self.declare_list = []
if schema_name:
@@ -174,7 +171,7 @@ def activate(
if connection is not None:
self.connection = connection
if self.connection is None:
- self.connection = conn()
+ self.connection = _get_singleton_connection()
self.database = schema_name
if create_schema is not None:
self.create_schema = create_schema
@@ -293,7 +290,10 @@ def _decorate_table(self, table_class: type, context: dict[str, Any], assert_dec
# instantiate the class, declare the table if not already
instance = table_class()
is_declared = instance.is_declared
- if not is_declared and not assert_declared and self.create_tables:
+ create_tables = (
+ self.create_tables if self.create_tables is not None else self.connection._config.database.create_tables
+ )
+ if not is_declared and not assert_declared and create_tables:
instance.declare(context)
self.connection.dependencies.clear()
is_declared = is_declared or instance.is_declared
@@ -343,7 +343,7 @@ def make_classes(self, into: dict[str, Any] | None = None) -> None:
del frame
tables = [
row[0]
- for row in self.connection.query("SHOW TABLES in `%s`" % self.database)
+ for row in self.connection.query(self.connection.adapter.list_tables_sql(self.database))
if lookup_class_name("`{db}`.`{tab}`".format(db=self.database, tab=row[0]), into, 0) is None
]
master_classes = (Lookup, Manual, Imported, Computed)
@@ -389,7 +389,7 @@ def drop(self, prompt: bool | None = None) -> None:
AccessError
If insufficient permissions to drop the schema.
"""
- prompt = config["safemode"] if prompt is None else prompt
+ prompt = self.connection._config["safemode"] if prompt is None else prompt
if not self.exists:
logger.info("Schema named `{database}` does not exist. Doing nothing.".format(database=self.database))
@@ -421,13 +421,7 @@ def exists(self) -> bool:
"""
if self.database is None:
raise DataJointError("Schema must be activated first.")
- return bool(
- self.connection.query(
- "SELECT schema_name FROM information_schema.schemata WHERE schema_name = '{database}'".format(
- database=self.database
- )
- ).rowcount
- )
+ return bool(self.connection.query(self.connection.adapter.schema_exists_sql(self.database)).rowcount)
@property
def lineage_table_exists(self) -> bool:
@@ -520,83 +514,6 @@ def jobs(self) -> list[Job]:
return jobs_list
- @property
- def code(self):
- self._assert_exists()
- return self.save()
-
- def save(self, python_filename: str | None = None) -> str:
- """
- Generate Python code that recreates this schema.
-
- Parameters
- ----------
- python_filename : str, optional
- If provided, write the code to this file.
-
- Returns
- -------
- str
- Python module source code defining this schema.
-
- Notes
- -----
- This method is in preparation for a future release and is not
- officially supported.
- """
- self.connection.dependencies.load()
- self._assert_exists()
- module_count = itertools.count()
- # add virtual modules for referenced modules with names vmod0, vmod1, ...
- module_lookup = collections.defaultdict(lambda: "vmod" + str(next(module_count)))
- db = self.database
-
- def make_class_definition(table):
- tier = _get_tier(table).__name__
- class_name = table.split(".")[1].strip("`")
- indent = ""
- if tier == "Part":
- class_name = class_name.split("__")[-1]
- indent += " "
- class_name = to_camel_case(class_name)
-
- def replace(s):
- d, tabs = s.group(1), s.group(2)
- return ("" if d == db else (module_lookup[d] + ".")) + ".".join(
- to_camel_case(tab) for tab in tabs.lstrip("__").split("__")
- )
-
- return ("" if tier == "Part" else "\n@schema\n") + (
- '{indent}class {class_name}(dj.{tier}):\n{indent} definition = """\n{indent} {defi}"""'
- ).format(
- class_name=class_name,
- indent=indent,
- tier=tier,
- defi=re.sub(
- r"`([^`]+)`.`([^`]+)`",
- replace,
- FreeTable(self.connection, table).describe(),
- ).replace("\n", "\n " + indent),
- )
-
- tables = self.connection.dependencies.topo_sort()
- body = "\n\n".join(make_class_definition(table) for table in tables)
- python_code = "\n\n".join(
- (
- '"""This module was auto-generated by datajoint from an existing schema"""',
- "import datajoint as dj\n\nschema = dj.Schema('{db}')".format(db=db),
- "\n".join(
- "{module} = dj.VirtualModule('{module}', '{schema_name}')".format(module=v, schema_name=k)
- for k, v in module_lookup.items()
- ),
- body,
- )
- )
- if python_filename is None:
- return python_code
- with open(python_filename, "wt") as f:
- f.write(python_code)
-
def list_tables(self) -> list[str]:
"""
Return all user tables in the schema.
@@ -612,7 +529,10 @@ def list_tables(self) -> list[str]:
self.connection.dependencies.load()
return [
t
- for d, t in (table_name.replace("`", "").split(".") for table_name in self.connection.dependencies.topo_sort())
+ for d, t in (
+ self.connection.adapter.split_full_table_name(table_name)
+ for table_name in self.connection.dependencies.topo_sort()
+ )
if d == self.database
]
@@ -810,7 +730,7 @@ def __init__(
Additional objects to add to the module namespace.
"""
super(VirtualModule, self).__init__(name=module_name)
- _schema = Schema(
+ _schema = _Schema(
schema_name,
create_schema=create_schema,
create_tables=create_tables,
@@ -836,12 +756,8 @@ def list_schemas(connection: Connection | None = None) -> list[str]:
list[str]
Names of all accessible schemas.
"""
- return [
- r[0]
- for r in (connection or conn()).query(
- 'SELECT schema_name FROM information_schema.schemata WHERE schema_name <> "information_schema"'
- )
- ]
+ conn = connection or _get_singleton_connection()
+ return [r[0] for r in conn.query(conn.adapter.list_schemas_sql())]
def virtual_schema(
diff --git a/src/datajoint/settings.py b/src/datajoint/settings.py
index e373ca38f..7019d8345 100644
--- a/src/datajoint/settings.py
+++ b/src/datajoint/settings.py
@@ -224,7 +224,6 @@ class ConnectionSettings(BaseSettings):
model_config = SettingsConfigDict(extra="forbid", validate_assignment=True)
- init_function: str | None = None
charset: str = "" # pymysql uses '' as default
@@ -341,11 +340,8 @@ class Config(BaseSettings):
# Top-level settings
loglevel: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(default="INFO", validation_alias="DJ_LOG_LEVEL")
safemode: bool = True
- enable_python_native_blobs: bool = True
- filepath_checksum_size_limit: int | None = None
- # Cache paths
- cache: Path | None = None
+ # Cache path for query results
query_cache: Path | None = None
# Download path for attachments and filepaths
@@ -362,7 +358,7 @@ def set_logger_level(cls, v: str) -> str:
logger.setLevel(v)
return v
- @field_validator("cache", "query_cache", mode="before")
+ @field_validator("query_cache", mode="before")
@classmethod
def convert_path(cls, v: Any) -> Path | None:
"""Convert string paths to Path objects."""
@@ -819,7 +815,6 @@ def save_template(
"use_tls": None,
},
"connection": {
- "init_function": None,
"charset": "",
},
"display": {
@@ -844,8 +839,6 @@ def save_template(
},
"loglevel": "INFO",
"safemode": True,
- "enable_python_native_blobs": True,
- "cache": None,
"query_cache": None,
"download_path": ".",
}
diff --git a/src/datajoint/staged_insert.py b/src/datajoint/staged_insert.py
index 6ac3819e4..1f6ee7afb 100644
--- a/src/datajoint/staged_insert.py
+++ b/src/datajoint/staged_insert.py
@@ -14,7 +14,6 @@
import fsspec
from .errors import DataJointError
-from .settings import config
from .storage import StorageBackend, build_object_path
@@ -69,7 +68,7 @@ def _ensure_backend(self):
"""Ensure storage backend is initialized."""
if self._backend is None:
try:
- spec = config.get_store_spec() # Uses stores.default
+ spec = self._table.connection._config.get_store_spec() # Uses stores.default
self._backend = StorageBackend(spec)
except DataJointError:
raise DataJointError(
@@ -110,7 +109,7 @@ def _get_storage_path(self, field: str, ext: str = "") -> str:
)
# Get storage spec (uses stores.default)
- spec = config.get_store_spec()
+ spec = self._table.connection._config.get_store_spec()
partition_pattern = spec.get("partition_pattern")
token_length = spec.get("token_length", 8)
diff --git a/src/datajoint/table.py b/src/datajoint/table.py
index 5fd8c3087..256fab6e9 100644
--- a/src/datajoint/table.py
+++ b/src/datajoint/table.py
@@ -23,7 +23,6 @@
)
from .expression import QueryExpression
from .heading import Heading
-from .settings import config
from .staged_insert import staged_insert1 as _staged_insert1
from .utils import get_master, is_camel_case, user_choice
@@ -141,8 +140,7 @@ def declare(self, context=None):
class_name = self.class_name
if "_" in class_name:
warnings.warn(
- f"Table class name `{class_name}` contains underscores. "
- "CamelCase names without underscores are recommended.",
+ f"Table class name `{class_name}` contains underscores. CamelCase names without underscores are recommended.",
UserWarning,
stacklevel=2,
)
@@ -153,7 +151,7 @@ def declare(self, context=None):
"Class names must be in CamelCase, starting with a capital letter."
)
sql, _external_stores, primary_key, fk_attribute_map, pre_ddl, post_ddl = declare(
- self.full_table_name, self.definition, context, self.connection.adapter
+ self.full_table_name, self.definition, context, self.connection.adapter, config=self.connection._config
)
# Call declaration hook for validation (subclasses like AutoPopulate can override)
@@ -235,12 +233,7 @@ def _populate_lineage(self, primary_key, fk_attribute_map):
# FK attributes: copy lineage from parent (whether in PK or not)
for attr, (parent_table, parent_attr) in fk_attribute_map.items():
# Parse parent table name: `schema`.`table` or "schema"."table" -> (schema, table)
- parent_clean = parent_table.replace("`", "").replace('"', "")
- if "." in parent_clean:
- parent_db, parent_tbl = parent_clean.split(".", 1)
- else:
- parent_db = self.database
- parent_tbl = parent_clean
+ parent_db, parent_tbl = self.connection.adapter.split_full_table_name(parent_table)
# Get parent's lineage for this attribute
parent_lineage = get_lineage(self.connection, parent_db, parent_tbl, parent_attr)
@@ -1119,7 +1112,7 @@ def strip_quotes(s):
raise DataJointError("Exceeded maximum number of delete attempts.")
return delete_count
- prompt = config["safemode"] if prompt is None else prompt
+ prompt = self.connection._config["safemode"] if prompt is None else prompt
# Start transaction
if transaction:
@@ -1227,7 +1220,7 @@ def drop(self, prompt: bool | None = None):
raise DataJointError(
"A table with an applied restriction cannot be dropped. Call drop() on the unrestricted Table."
)
- prompt = config["safemode"] if prompt is None else prompt
+ prompt = self.connection._config["safemode"] if prompt is None else prompt
self.connection.dependencies.load()
do_drop = True
@@ -1398,6 +1391,7 @@ def __make_placeholder(self, name, value, ignore_extra_fields=False, row=None):
"_schema": self.database,
"_table": self.table_name,
"_field": name,
+ "_config": self.connection._config,
}
# Add primary key values from row if available
if row is not None:
diff --git a/tests/conftest.py b/tests/conftest.py
index 4d6adf09c..8efaab745 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -536,13 +536,13 @@ def mock_stores(stores_config):
@pytest.fixture
def mock_cache(tmpdir_factory):
- og_cache = dj.config.get("cache")
- dj.config["cache"] = tmpdir_factory.mktemp("cache")
+ og_cache = dj.config.get("download_path")
+ dj.config["download_path"] = str(tmpdir_factory.mktemp("cache"))
yield
if og_cache is None:
- del dj.config["cache"]
+ del dj.config["download_path"]
else:
- dj.config["cache"] = og_cache
+ dj.config["download_path"] = og_cache
@pytest.fixture(scope="session")
diff --git a/tests/integration/test_gc.py b/tests/integration/test_gc.py
index 7eca79f37..47ca0a96d 100644
--- a/tests/integration/test_gc.py
+++ b/tests/integration/test_gc.py
@@ -251,7 +251,7 @@ def test_deletes_orphaned_hashes(self, mock_scan, mock_list_stored, mock_delete)
assert stats["hash_deleted"] == 1
assert stats["bytes_freed"] == 100
assert stats["dry_run"] is False
- mock_delete.assert_called_once_with("_hash/schema/orphan_path", "test_store")
+ mock_delete.assert_called_once_with("_hash/schema/orphan_path", "test_store", config=mock_schema.connection._config)
@patch("datajoint.gc.delete_schema_path")
@patch("datajoint.gc.list_schema_paths")
@@ -278,7 +278,7 @@ def test_deletes_orphaned_schemas(self, mock_scan, mock_list_schemas, mock_delet
assert stats["schema_paths_deleted"] == 1
assert stats["bytes_freed"] == 500
assert stats["dry_run"] is False
- mock_delete.assert_called_once_with("schema/table/pk/field", "test_store")
+ mock_delete.assert_called_once_with("schema/table/pk/field", "test_store", config=mock_schema.connection._config)
class TestFormatStats:
diff --git a/tests/integration/test_jobs.py b/tests/integration/test_jobs.py
index 20fa3233d..5a9203dca 100644
--- a/tests/integration/test_jobs.py
+++ b/tests/integration/test_jobs.py
@@ -108,10 +108,9 @@ def test_sigterm(clean_jobs, schema_any):
def test_suppress_dj_errors(clean_jobs, schema_any):
- """Test that DataJoint errors are suppressible without native py blobs."""
+ """Test that DataJoint errors are suppressible."""
error_class = schema.ErrorClass()
- with dj.config.override(enable_python_native_blobs=False):
- error_class.populate(reserve_jobs=True, suppress_errors=True)
+ error_class.populate(reserve_jobs=True, suppress_errors=True)
assert len(schema.DjExceptionName()) == len(error_class.jobs.errors) > 0
diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py
index 9a4144bbb..1211ecd1e 100644
--- a/tests/integration/test_schema.py
+++ b/tests/integration/test_schema.py
@@ -260,5 +260,5 @@ class Recording(dj.Manual):
id: smallint
"""
- schema2.drop()
- schema1.drop()
+ schema2.drop(prompt=False)
+ schema1.drop(prompt=False)
diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py
index af5718503..475d96df9 100644
--- a/tests/unit/test_settings.py
+++ b/tests/unit/test_settings.py
@@ -504,23 +504,23 @@ def test_display_limit(self):
class TestCachePaths:
"""Test cache path settings."""
- def test_cache_path_string(self):
- """Test setting cache path as string."""
- original = dj.config.cache
+ def test_query_cache_path_string(self):
+ """Test setting query_cache path as string."""
+ original = dj.config.query_cache
try:
- dj.config.cache = "/tmp/cache"
- assert dj.config.cache == Path("/tmp/cache")
+ dj.config.query_cache = "/tmp/cache"
+ assert dj.config.query_cache == Path("/tmp/cache")
finally:
- dj.config.cache = original
+ dj.config.query_cache = original
- def test_cache_path_none(self):
- """Test cache path can be None."""
- original = dj.config.cache
+ def test_query_cache_path_none(self):
+ """Test query_cache path can be None."""
+ original = dj.config.query_cache
try:
- dj.config.cache = None
- assert dj.config.cache is None
+ dj.config.query_cache = None
+ assert dj.config.query_cache is None
finally:
- dj.config.cache = original
+ dj.config.query_cache = original
class TestSaveTemplate:
diff --git a/tests/unit/test_thread_safe.py b/tests/unit/test_thread_safe.py
new file mode 100644
index 000000000..aba1b686b
--- /dev/null
+++ b/tests/unit/test_thread_safe.py
@@ -0,0 +1,294 @@
+"""Tests for thread-safe mode functionality."""
+
+import pytest
+
+
+class TestThreadSafeMode:
+ """Test thread-safe mode behavior."""
+
+ def test_thread_safe_env_var_true(self, monkeypatch):
+ """DJ_THREAD_SAFE=true enables thread-safe mode."""
+ monkeypatch.setenv("DJ_THREAD_SAFE", "true")
+
+        # _load_thread_safe reads the environment variable at call time
+ from datajoint.instance import _load_thread_safe
+
+ assert _load_thread_safe() is True
+
+ def test_thread_safe_env_var_false(self, monkeypatch):
+ """DJ_THREAD_SAFE=false disables thread-safe mode."""
+ monkeypatch.setenv("DJ_THREAD_SAFE", "false")
+
+ from datajoint.instance import _load_thread_safe
+
+ assert _load_thread_safe() is False
+
+ def test_thread_safe_env_var_1(self, monkeypatch):
+ """DJ_THREAD_SAFE=1 enables thread-safe mode."""
+ monkeypatch.setenv("DJ_THREAD_SAFE", "1")
+
+ from datajoint.instance import _load_thread_safe
+
+ assert _load_thread_safe() is True
+
+ def test_thread_safe_env_var_yes(self, monkeypatch):
+ """DJ_THREAD_SAFE=yes enables thread-safe mode."""
+ monkeypatch.setenv("DJ_THREAD_SAFE", "yes")
+
+ from datajoint.instance import _load_thread_safe
+
+ assert _load_thread_safe() is True
+
+ def test_thread_safe_default_false(self, monkeypatch):
+ """Thread-safe mode defaults to False."""
+ monkeypatch.delenv("DJ_THREAD_SAFE", raising=False)
+
+ from datajoint.instance import _load_thread_safe
+
+ assert _load_thread_safe() is False
+
+
+class TestConfigProxyThreadSafe:
+ """Test ConfigProxy behavior in thread-safe mode."""
+
+ def test_config_access_raises_in_thread_safe_mode(self, monkeypatch):
+ """Accessing config raises ThreadSafetyError in thread-safe mode."""
+ monkeypatch.setenv("DJ_THREAD_SAFE", "true")
+
+ import datajoint as dj
+ from datajoint.errors import ThreadSafetyError
+
+ with pytest.raises(ThreadSafetyError):
+ _ = dj.config.database
+
+ def test_config_access_works_in_normal_mode(self, monkeypatch):
+ """Accessing config works in normal mode."""
+ monkeypatch.setenv("DJ_THREAD_SAFE", "false")
+
+ import datajoint as dj
+
+ # Should not raise
+ host = dj.config.database.host
+ assert isinstance(host, str)
+
+ def test_config_set_raises_in_thread_safe_mode(self, monkeypatch):
+ """Setting config raises ThreadSafetyError in thread-safe mode."""
+ monkeypatch.setenv("DJ_THREAD_SAFE", "true")
+
+ import datajoint as dj
+ from datajoint.errors import ThreadSafetyError
+
+ with pytest.raises(ThreadSafetyError):
+ dj.config.safemode = False
+
+ def test_save_template_works_in_thread_safe_mode(self, monkeypatch, tmp_path):
+ """save_template is a static method and works in thread-safe mode."""
+ monkeypatch.setenv("DJ_THREAD_SAFE", "true")
+
+ import datajoint as dj
+
+ # Should not raise - save_template is static
+ config_file = tmp_path / "datajoint.json"
+ dj.config.save_template(str(config_file), create_secrets_dir=False)
+ assert config_file.exists()
+
+
+class TestConnThreadSafe:
+ """Test conn() behavior in thread-safe mode."""
+
+ def test_conn_raises_in_thread_safe_mode(self, monkeypatch):
+ """conn() raises ThreadSafetyError in thread-safe mode."""
+ monkeypatch.setenv("DJ_THREAD_SAFE", "true")
+
+ import datajoint as dj
+ from datajoint.errors import ThreadSafetyError
+
+ with pytest.raises(ThreadSafetyError):
+ dj.conn()
+
+
+class TestSchemaThreadSafe:
+ """Test Schema behavior in thread-safe mode."""
+
+ def test_schema_raises_in_thread_safe_mode(self, monkeypatch):
+ """Schema() raises ThreadSafetyError in thread-safe mode without connection."""
+ monkeypatch.setenv("DJ_THREAD_SAFE", "true")
+
+ import datajoint as dj
+ from datajoint.errors import ThreadSafetyError
+
+ with pytest.raises(ThreadSafetyError):
+ dj.Schema("test_schema")
+
+
+class TestFreeTableThreadSafe:
+ """Test FreeTable behavior in thread-safe mode."""
+
+ def test_freetable_raises_in_thread_safe_mode(self, monkeypatch):
+ """FreeTable() raises ThreadSafetyError in thread-safe mode without connection."""
+ monkeypatch.setenv("DJ_THREAD_SAFE", "true")
+
+ import datajoint as dj
+ from datajoint.errors import ThreadSafetyError
+
+ with pytest.raises(ThreadSafetyError):
+ dj.FreeTable("test.table")
+
+
+class TestInstance:
+ """Test Instance class."""
+
+ def test_instance_import(self):
+ """Instance class is importable."""
+ from datajoint import Instance
+
+ assert Instance is not None
+
+ def test_instance_always_allowed_in_thread_safe_mode(self, monkeypatch):
+ """Instance() is allowed even in thread-safe mode."""
+ monkeypatch.setenv("DJ_THREAD_SAFE", "true")
+
+ from datajoint import Instance
+
+ # Instance class should be accessible
+ # (actual creation requires valid credentials)
+ assert callable(Instance)
+
+
+class TestInstanceBackend:
+ """Test Instance backend parameter."""
+
+ def test_instance_backend_sets_config(self, monkeypatch):
+ """Instance(backend=...) sets config.database.backend."""
+ monkeypatch.setenv("DJ_THREAD_SAFE", "false")
+ from datajoint.instance import Instance
+ from unittest.mock import patch
+
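+        # Patch Connection so no real database connection is attempted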
+ with patch("datajoint.instance.Connection"):
+ inst = Instance(
+ host="localhost",
+ user="root",
+ password="secret",
+ backend="postgresql",
+ )
+ assert inst.config.database.backend == "postgresql"
+
+ def test_instance_backend_default_from_config(self, monkeypatch):
+ """Instance without backend uses config default."""
+ monkeypatch.setenv("DJ_THREAD_SAFE", "false")
+ from datajoint.instance import Instance
+ from unittest.mock import patch
+
+ with patch("datajoint.instance.Connection"):
+ inst = Instance(
+ host="localhost",
+ user="root",
+ password="secret",
+ )
+ assert inst.config.database.backend == "mysql"
+
+ def test_instance_backend_affects_port_default(self, monkeypatch):
+ """Instance(backend='postgresql') uses port 5432 by default."""
+ monkeypatch.setenv("DJ_THREAD_SAFE", "false")
+ from datajoint.instance import Instance
+ from unittest.mock import patch
+
+ with patch("datajoint.instance.Connection") as MockConn:
+ Instance(
+ host="localhost",
+ user="root",
+ password="secret",
+ backend="postgresql",
+ )
+ # Connection should be called with port 5432 (PostgreSQL default)
+ args, kwargs = MockConn.call_args
+ assert args[3] == 5432 # port is the 4th positional arg
+
+
+class TestCrossConnectionValidation:
+ """Test that cross-connection operations are rejected."""
+
+ def test_join_different_connections_raises(self):
+ """Join of expressions from different connections raises DataJointError."""
+ from datajoint.expression import QueryExpression
+ from datajoint.errors import DataJointError
+ from unittest.mock import MagicMock
+
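+        # Each expression gets its own mock connection, so the join must be rejected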
+ expr1 = QueryExpression()
+ expr1._connection = MagicMock()
+ expr1._heading = MagicMock()
+ expr1._heading.names = []
+
+ expr2 = QueryExpression()
+ expr2._connection = MagicMock() # different connection object
+ expr2._heading = MagicMock()
+ expr2._heading.names = []
+
+ with pytest.raises(DataJointError, match="different connections"):
+ expr1 * expr2
+
+ def test_join_same_connection_allowed(self):
+ """Join of expressions from the same connection does not raise."""
+ from datajoint.condition import assert_join_compatibility
+ from datajoint.expression import QueryExpression
+ from unittest.mock import MagicMock
+
+ shared_conn = MagicMock()
+
+ expr1 = QueryExpression()
+ expr1._connection = shared_conn
+ expr1._heading = MagicMock()
+ expr1._heading.names = []
+ expr1._heading.lineage_available = False
+
+ expr2 = QueryExpression()
+ expr2._connection = shared_conn
+ expr2._heading = MagicMock()
+ expr2._heading.names = []
+ expr2._heading.lineage_available = False
+
+ # Should not raise
+ assert_join_compatibility(expr1, expr2)
+
+ def test_restriction_different_connections_raises(self):
+ """Restriction by expression from different connection raises DataJointError."""
+ from datajoint.expression import QueryExpression
+ from datajoint.errors import DataJointError
+ from unittest.mock import MagicMock
+
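+        # Stub the private attributes needed to build a restriction without a real connection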
+ expr1 = QueryExpression()
+ expr1._connection = MagicMock()
+ expr1._heading = MagicMock()
+ expr1._heading.names = ["a"]
+ expr1._heading.__getitem__ = MagicMock()
+ expr1._heading.new_attributes = set()
+ expr1._support = ["`db`.`t1`"]
+ expr1._restriction = []
+ expr1._restriction_attributes = set()
+ expr1._joins = []
+ expr1._top = None
+ expr1._original_heading = expr1._heading
+
+ expr2 = QueryExpression()
+ expr2._connection = MagicMock() # different connection
+ expr2._heading = MagicMock()
+ expr2._heading.names = ["a"]
+
+ with pytest.raises(DataJointError, match="different connections"):
+ expr1 & expr2
+
+
+class TestThreadSafetyError:
+ """Test ThreadSafetyError exception."""
+
+ def test_error_is_datajoint_error(self):
+ """ThreadSafetyError is a subclass of DataJointError."""
+ from datajoint.errors import DataJointError, ThreadSafetyError
+
+ assert issubclass(ThreadSafetyError, DataJointError)
+
+ def test_error_in_exports(self):
+ """ThreadSafetyError is exported from datajoint."""
+ import datajoint as dj
+
+ assert hasattr(dj, "ThreadSafetyError")