diff --git a/src/executorlib/task_scheduler/file/shared.py b/src/executorlib/task_scheduler/file/shared.py index c65409ca..20c0a55a 100644 --- a/src/executorlib/task_scheduler/file/shared.py +++ b/src/executorlib/task_scheduler/file/shared.py @@ -90,39 +90,22 @@ def execute_tasks_h5( if task_dict is not None and "shutdown" in task_dict and task_dict["shutdown"]: if task_dict["wait"] and wait: while len(memory_dict) > 0: - memory_dict = { - key: _check_task_output( - task_key=key, - future_obj=value, - cache_directory=cache_dir_dict[key], - ) - for key, value in memory_dict.items() - if not value.done() - } + memory_dict = _refresh_memory_dict( + memory_dict=memory_dict, + cache_dir_dict=cache_dir_dict, + ) if not task_dict["cancel_futures"] and wait: - if ( - terminate_function is not None - and terminate_function == terminate_subprocess - ): - for task in process_dict.values(): - terminate_function(task=task) - elif terminate_function is not None: - for queue_id in process_dict.values(): - terminate_function( - queue_id=queue_id, - config_directory=pysqa_config_directory, - backend=backend, - ) + _cancel_processes( + terminate_function=terminate_function, + process_dict=process_dict, + pysqa_config_directory=pysqa_config_directory, + backend=backend, + ) else: - memory_dict = { - key: _check_task_output( - task_key=key, - future_obj=value, - cache_directory=cache_dir_dict[key], - ) - for key, value in memory_dict.items() - if not value.done() - } + memory_dict = _refresh_memory_dict( + memory_dict=memory_dict, + cache_dir_dict=cache_dir_dict, + ) for value in memory_dict.values(): if not value.done(): value.cancel() @@ -193,15 +176,10 @@ def execute_tasks_h5( cache_dir_dict[task_key] = cache_directory future_queue.task_done() else: - memory_dict = { - key: _check_task_output( - task_key=key, - future_obj=value, - cache_directory=cache_dir_dict[key], - ) - for key, value in memory_dict.items() - if not value.done() - } + memory_dict = _refresh_memory_dict( + memory_dict=memory_dict, + cache_dir_dict=cache_dir_dict, + ) def _check_task_output( @@ -275,3 +253,56 @@ def _convert_args_and_kwargs( else: task_kwargs[key] = arg return task_args, task_kwargs, future_wait_key_lst + + +def _refresh_memory_dict(memory_dict: dict, cache_dir_dict: dict) -> dict: + """ + Refresh memory dictionary + + Args: + memory_dict (dict): dictionary with task keys and future objects + cache_dir_dict (dict): dictionary with task keys and cache directories + + Returns: + dict: Updated memory dictionary + """ + return { + key: _check_task_output( + task_key=key, + future_obj=value, + cache_directory=cache_dir_dict[key], + ) + for key, value in memory_dict.items() + if not value.done() + } + + +def _cancel_processes( + process_dict: dict, + terminate_function: Optional[Callable] = None, + pysqa_config_directory: Optional[str] = None, + backend: Optional[str] = None, +): + """ + Cancel processes + + Args: + process_dict (dict): dictionary with task keys and process reference. + terminate_function (callable): The function to terminate the tasks. + pysqa_config_directory (str): path to the pysqa config directory (only for pysqa based backend). + backend (str): name of the backend used to spawn tasks. + """ + if terminate_function is not None and terminate_function == terminate_subprocess: + for task in process_dict.values(): + terminate_function(task=task) + elif ( + terminate_function is not None + and backend is not None + and pysqa_config_directory is not None + ): + for queue_id in process_dict.values(): + terminate_function( + queue_id=queue_id, + config_directory=pysqa_config_directory, + backend=backend, + )