Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions src/executorlib/task_scheduler/file/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,26 @@ def execute_tasks_h5(
memory_dict = _refresh_memory_dict(
memory_dict=memory_dict,
cache_dir_dict=cache_dir_dict,
process_dict=process_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
)
if not task_dict["cancel_futures"] and wait:
_cancel_processes(
terminate_function=terminate_function,
process_dict=process_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
)
else:
memory_dict = _refresh_memory_dict(
memory_dict=memory_dict,
cache_dir_dict=cache_dir_dict,
process_dict=process_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
)
for value in memory_dict.values():
if not value.done():
Expand Down Expand Up @@ -179,6 +187,10 @@ def execute_tasks_h5(
memory_dict = _refresh_memory_dict(
memory_dict=memory_dict,
cache_dir_dict=cache_dir_dict,
process_dict=process_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
)


Expand Down Expand Up @@ -255,17 +267,37 @@ def _convert_args_and_kwargs(
return task_args, task_kwargs, future_wait_key_lst


def _refresh_memory_dict(memory_dict: dict, cache_dir_dict: dict) -> dict:
def _refresh_memory_dict(
memory_dict: dict,
cache_dir_dict: dict,
process_dict: dict,
terminate_function: Optional[Callable] = None,
pysqa_config_directory: Optional[str] = None,
backend: Optional[str] = None,
) -> dict:
"""
Refresh memory dictionary

Args:
memory_dict (dict): dictionary with task keys and future objects
cache_dir_dict (dict): dictionary with task keys and cache directories
process_dict (dict): dictionary with task keys and process reference.
terminate_function (callable): The function to terminate the tasks.
pysqa_config_directory (str): path to the pysqa config directory (only for pysqa based backend).
backend (str): name of the backend used to spawn tasks.

Returns:
dict: Updated memory dictionary
"""
cancelled_lst = [
key for key, value in memory_dict.items() if value.done() and value.cancelled()
]
_cancel_processes(
process_dict={k: v for k, v in process_dict.items() if k in cancelled_lst},
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
)
return {
key: _check_task_output(
task_key=key,
Expand Down