Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 71 additions & 40 deletions src/executorlib/task_scheduler/file/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,39 +90,22 @@ def execute_tasks_h5(
if task_dict is not None and "shutdown" in task_dict and task_dict["shutdown"]:
if task_dict["wait"] and wait:
while len(memory_dict) > 0:
memory_dict = {
key: _check_task_output(
task_key=key,
future_obj=value,
cache_directory=cache_dir_dict[key],
)
for key, value in memory_dict.items()
if not value.done()
}
memory_dict = _refresh_memory_dict(
memory_dict=memory_dict,
cache_dir_dict=cache_dir_dict,
)
if not task_dict["cancel_futures"] and wait:
if (
terminate_function is not None
and terminate_function == terminate_subprocess
):
for task in process_dict.values():
terminate_function(task=task)
elif terminate_function is not None:
for queue_id in process_dict.values():
terminate_function(
queue_id=queue_id,
config_directory=pysqa_config_directory,
backend=backend,
)
_cancel_processes(
terminate_function=terminate_function,
process_dict=process_dict,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
)
else:
memory_dict = {
key: _check_task_output(
task_key=key,
future_obj=value,
cache_directory=cache_dir_dict[key],
)
for key, value in memory_dict.items()
if not value.done()
}
memory_dict = _refresh_memory_dict(
memory_dict=memory_dict,
cache_dir_dict=cache_dir_dict,
)
for value in memory_dict.values():
if not value.done():
value.cancel()
Expand Down Expand Up @@ -193,15 +176,10 @@ def execute_tasks_h5(
cache_dir_dict[task_key] = cache_directory
future_queue.task_done()
else:
memory_dict = {
key: _check_task_output(
task_key=key,
future_obj=value,
cache_directory=cache_dir_dict[key],
)
for key, value in memory_dict.items()
if not value.done()
}
memory_dict = _refresh_memory_dict(
memory_dict=memory_dict,
cache_dir_dict=cache_dir_dict,
)


def _check_task_output(
Expand Down Expand Up @@ -275,3 +253,56 @@ def _convert_args_and_kwargs(
else:
task_kwargs[key] = arg
return task_args, task_kwargs, future_wait_key_lst


def _refresh_memory_dict(memory_dict: dict, cache_dir_dict: dict) -> dict:
"""
Refresh memory dictionary

Args:
memory_dict (dict): dictionary with task keys and future objects
cache_dir_dict (dict): dictionary with task keys and cache directories

Returns:
dict: Updated memory dictionary
"""
return {
key: _check_task_output(
task_key=key,
future_obj=value,
cache_directory=cache_dir_dict[key],
)
for key, value in memory_dict.items()
if not value.done()
}


def _cancel_processes(
process_dict: dict,
terminate_function: Optional[Callable] = None,
pysqa_config_directory: Optional[str] = None,
backend: Optional[str] = None,
):
"""
Cancel processes

Args:
process_dict (dict): dictionary with task keys and process reference.
terminate_function (callable): The function to terminate the tasks.
pysqa_config_directory (str): path to the pysqa config directory (only for pysqa based backend).
backend (str): name of the backend used to spawn tasks.
"""
if terminate_function is not None and terminate_function == terminate_subprocess:
for task in process_dict.values():
terminate_function(task=task)
elif (
terminate_function is not None
and backend is not None
and pysqa_config_directory is not None
):
for queue_id in process_dict.values():
terminate_function(
queue_id=queue_id,
config_directory=pysqa_config_directory,
backend=backend,
)
Loading