diff --git a/src/executorlib/task_scheduler/file/shared.py b/src/executorlib/task_scheduler/file/shared.py index 20c0a55a..3f6aeeb6 100644 --- a/src/executorlib/task_scheduler/file/shared.py +++ b/src/executorlib/task_scheduler/file/shared.py @@ -93,11 +93,15 @@ def execute_tasks_h5( memory_dict = _refresh_memory_dict( memory_dict=memory_dict, cache_dir_dict=cache_dir_dict, + process_dict=process_dict, + terminate_function=terminate_function, + pysqa_config_directory=pysqa_config_directory, + backend=backend, ) if not task_dict["cancel_futures"] and wait: _cancel_processes( - terminate_function=terminate_function, process_dict=process_dict, + terminate_function=terminate_function, pysqa_config_directory=pysqa_config_directory, backend=backend, ) @@ -105,6 +109,10 @@ def execute_tasks_h5( memory_dict = _refresh_memory_dict( memory_dict=memory_dict, cache_dir_dict=cache_dir_dict, + process_dict=process_dict, + terminate_function=terminate_function, + pysqa_config_directory=pysqa_config_directory, + backend=backend, ) for value in memory_dict.values(): if not value.done(): @@ -179,6 +187,10 @@ def execute_tasks_h5( memory_dict = _refresh_memory_dict( memory_dict=memory_dict, cache_dir_dict=cache_dir_dict, + process_dict=process_dict, + terminate_function=terminate_function, + pysqa_config_directory=pysqa_config_directory, + backend=backend, ) @@ -255,17 +267,37 @@ def _convert_args_and_kwargs( return task_args, task_kwargs, future_wait_key_lst -def _refresh_memory_dict(memory_dict: dict, cache_dir_dict: dict) -> dict: +def _refresh_memory_dict( + memory_dict: dict, + cache_dir_dict: dict, + process_dict: dict, + terminate_function: Optional[Callable] = None, + pysqa_config_directory: Optional[str] = None, + backend: Optional[str] = None, +) -> dict: """ Refresh memory dictionary Args: memory_dict (dict): dictionary with task keys and future objects cache_dir_dict (dict): dictionary with task keys and cache directories + process_dict (dict): dictionary with task keys and process reference. + terminate_function (callable): The function to terminate the tasks. + pysqa_config_directory (str): path to the pysqa config directory (only for pysqa based backend). + backend (str): name of the backend used to spawn tasks. Returns: dict: Updated memory dictionary """ + cancelled_lst = [ + key for key, value in memory_dict.items() if value.done() and value.cancelled() + ] + _cancel_processes( + process_dict={k: v for k, v in process_dict.items() if k in cancelled_lst}, + terminate_function=terminate_function, + pysqa_config_directory=pysqa_config_directory, + backend=backend, + ) return { key: _check_task_output( task_key=key,