Skip to content

Commit

Permalink
Ensure sibling files in toil-wdl-runner (#4610)
Browse files Browse the repository at this point in the history
* Ensure sibling files stay sibling files when downloaded

* Fix incorrect argument order

* Fix directory collisions with sibling files
  • Loading branch information
stxue1 authored Oct 16, 2023
1 parent b73b9ef commit 6c0fe1e
Showing 1 changed file with 77 additions and 26 deletions.
103 changes: 77 additions & 26 deletions src/toil/wdl/wdltoil.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@
from contextlib import ExitStack, contextmanager
from graphlib import TopologicalSorter
from typing import cast, Any, Callable, Union, Dict, List, Optional, Set, Sequence, Tuple, Type, TypeVar, Iterator, \
Generator
Iterable, Generator
from urllib.parse import urlsplit, urljoin, quote, unquote

import WDL
from WDL import Error
from configargparse import ArgParser
from WDL._util import byte_size_units
from WDL.runtime.task_container import TaskContainer
Expand All @@ -50,6 +50,7 @@
from toil.fileStores import FileID
from toil.fileStores.abstractFileStore import AbstractFileStore
from toil.jobStores.abstractJobStore import AbstractJobStore, UnimplementedURLException
from toil.lib.memoize import memoize
from toil.lib.conversions import convert_units, human2bytes
from toil.lib.misc import get_user_name
from toil.lib.threading import global_mutex
Expand Down Expand Up @@ -305,17 +306,17 @@ def recursive_dependencies(root: WDL.Tree.WorkflowNode) -> Set[str]:

TOIL_URI_SCHEME = 'toilfile:'

def pack_toil_uri(file_id: FileID, file_basename: str) -> str:
def pack_toil_uri(file_id: FileID, dir_id: uuid.UUID, file_basename: str) -> str:
"""
Encode a Toil file ID and its source path in a URI that starts with the scheme in TOIL_URI_SCHEME.
"""

# We urlencode everything, including any slashes. We need to use a slash to
# set off the actual filename, so the WDL standard library basename
# function works correctly.
return f"{TOIL_URI_SCHEME}{quote(file_id.pack(), safe='')}/{quote(file_basename, safe='')}"
return f"{TOIL_URI_SCHEME}{quote(file_id.pack(), safe='')}/{quote(str(dir_id))}/{quote(file_basename, safe='')}"

def unpack_toil_uri(toil_uri: str) -> Tuple[FileID, str]:
def unpack_toil_uri(toil_uri: str) -> Tuple[FileID, str, str]:
"""
Unpack a URI made by make_toil_uri to retrieve the FileID and the basename
(no path prefix) that the file is supposed to have.
Expand All @@ -329,12 +330,13 @@ def unpack_toil_uri(toil_uri: str) -> Tuple[FileID, str]:
raise ValueError(f"URI doesn't start with {TOIL_URI_SCHEME} and should: {toil_uri}")
# Split encoded file ID from filename
parts = parts[1].split('/')
if len(parts) != 2:
if len(parts) != 3:
raise ValueError(f"Wrong number of path segments in URI: {toil_uri}")
file_id = FileID.unpack(unquote(parts[0]))
file_basename = unquote(parts[1])
parent_id = unquote(parts[1])
file_basename = unquote(parts[2])

return file_id, file_basename
return file_id, parent_id, file_basename

def evaluate_output_decls(output_decls: List[WDL.Tree.Decl], all_bindings: WDL.Env.Bindings[WDL.Value.Base], standard_library: WDL.StdLib.Base) -> WDL.Env.Bindings[WDL.Value.Base]:
"""
Expand Down Expand Up @@ -380,7 +382,7 @@ def _call_eager(self, expr: "WDL.Expr.Apply", arguments: List[WDL.Value.Base]) -
if uri.startswith(TOIL_URI_SCHEME):
# This is a Toil File ID we encoded; we have the size
# available.
file_id, _ = unpack_toil_uri(uri)
file_id, _, _ = unpack_toil_uri(uri)
# Use the encoded size
total_size += file_id.size
else:
Expand Down Expand Up @@ -429,7 +431,6 @@ class ToilWDLStdLibBase(WDL.StdLib.Base):
"""
Standard library implementation for WDL as run on Toil.
"""

def __init__(self, file_store: AbstractFileStore, execution_dir: Optional[str] = None):
"""
Set up the standard library.
Expand All @@ -448,6 +449,9 @@ def __init__(self, file_store: AbstractFileStore, execution_dir: Optional[str] =
# Keep the file store around so we can access files.
self._file_store = file_store

# UUID to differentiate which node files are virtualized from
self._parent_dir_to_ids: Dict[str, uuid.UUID] = dict()

self._execution_dir = execution_dir

def _is_url(self, filename: str, schemes: List[str] = ['http:', 'https:', 's3:', 'gs:', TOIL_URI_SCHEME]) -> bool:
Expand All @@ -459,6 +463,7 @@ def _is_url(self, filename: str, schemes: List[str] = ['http:', 'https:', 's3:',
return True
return False

@memoize
def _devirtualize_filename(self, filename: str) -> str:
"""
'devirtualize' filename passed to a read_* function: return a filename that can be open()ed
Expand All @@ -470,11 +475,16 @@ def _devirtualize_filename(self, filename: str) -> str:
if filename.startswith(TOIL_URI_SCHEME):
# This is a reference to the Toil filestore.
# Deserialize the FileID
file_id, file_basename = unpack_toil_uri(filename)
file_id, parent_id, file_basename = unpack_toil_uri(filename)

# Decide where it should be put
file_dir = self._file_store.getLocalTempDir()
dest_path = os.path.join(file_dir, file_basename)
# This is a URI with the "parent" UUID attached to the filename
# Use UUID as folder name rather than a new temp folder to reduce internal clutter
dir_path = os.path.join(self._file_store.localTempDir, parent_id)
if not os.path.exists(parent_id):
os.mkdir(dir_path)
# Put the UUID in the destination path in order for tasks to see where to put files depending on their parents
dest_path = os.path.join(dir_path, file_basename)

# And get a local path to the file
result = self._file_store.readGlobalFile(file_id, dest_path)
Expand Down Expand Up @@ -506,8 +516,6 @@ def _virtualize_filename(self, filename: str) -> str:
from a local path in write_dir, 'virtualize' into the filename as it should present in a
File value
"""


if self._is_url(filename):
# Already virtual
logger.debug('Already virtualized %s as WDL file %s', filename, filename)
Expand All @@ -521,7 +529,9 @@ def _virtualize_filename(self, filename: str) -> str:
file_id = self._file_store.writeGlobalFile(os.path.join(self._execution_dir, filename))
else:
file_id = self._file_store.writeGlobalFile(filename)
result = pack_toil_uri(file_id, os.path.basename(filename))
dir = os.path.dirname(os.path.abspath(filename)) # is filename always an abspath?
parent_id = self._parent_dir_to_ids.setdefault(dir, uuid.uuid4())
result = pack_toil_uri(file_id, parent_id, os.path.basename(filename))
logger.debug('Virtualized %s as WDL file %s', filename, result)
return result

Expand All @@ -543,6 +553,7 @@ def __init__(self, file_store: AbstractFileStore, container: TaskContainer):
super().__init__(file_store)
self.container = container

@memoize
def _devirtualize_filename(self, filename: str) -> str:
"""
Go from a virtualized WDL-side filename to a local disk filename.
Expand Down Expand Up @@ -681,6 +692,7 @@ def _glob(self, pattern: WDL.Value.String) -> WDL.Value.Array:
# Just turn them all into WDL File objects with local disk out-of-container names.
return WDL.Value.Array(WDL.Type.File(), [WDL.Value.File(x) for x in results])

@memoize
def _devirtualize_filename(self, filename: str) -> str:
"""
Go from a virtualized WDL-side filename to a local disk filename.
Expand Down Expand Up @@ -797,8 +809,8 @@ def devirtualize_files(environment: WDLBindings, stdlib: WDL.StdLib.Base) -> WDL
"""
Make sure all the File values embedded in the given bindings point to files
that are actually available to command line commands.
The same virtual file always maps to the same devirtualized filename even with duplicates
"""

return map_over_files_in_bindings(environment, stdlib._devirtualize_filename)

def virtualize_files(environment: WDLBindings, stdlib: WDL.StdLib.Base) -> WDLBindings:
Expand All @@ -809,6 +821,39 @@ def virtualize_files(environment: WDLBindings, stdlib: WDL.StdLib.Base) -> WDLBi

return map_over_files_in_bindings(environment, stdlib._virtualize_filename)

def add_paths(task_container: TaskContainer, host_paths: Iterable[str]) -> None:
"""
Based off of WDL.runtime.task_container.add_paths from miniwdl
Maps the host path to the container paths
"""
# partition the files by host directory
host_paths_by_dir: Dict[str, Set[str]] = {}
for host_path in host_paths:
host_path_strip = host_path.rstrip("/")
if host_path not in task_container.input_path_map and host_path_strip not in task_container.input_path_map:
if not os.path.exists(host_path_strip):
raise Error.InputError("input path not found: " + host_path)
host_paths_by_dir.setdefault(os.path.dirname(host_path_strip), set()).add(host_path)
# for each such partition of files
# - if there are no basename collisions under input subdirectory 0, then mount them there.
# - otherwise, mount them in a fresh subdirectory
subd = 0
id_to_subd: Dict[str, str] = {}
for paths in host_paths_by_dir.values():
based = os.path.join(task_container.container_dir, "work/_miniwdl_inputs")
for host_path in paths:
parent_id = os.path.basename(os.path.dirname(host_path))
if id_to_subd.get(parent_id, None) is None:
id_to_subd[parent_id] = str(subd)
subd += 1
host_path_subd = id_to_subd[parent_id]
container_path = os.path.join(based, host_path_subd, os.path.basename(host_path.rstrip("/")))
if host_path.endswith("/"):
container_path += "/"
assert container_path not in task_container.input_path_map_rev, f"{container_path}, {task_container.input_path_map_rev}"
task_container.input_path_map[host_path] = container_path
task_container.input_path_map_rev[container_path] = host_path

def import_files(environment: WDLBindings, toil: Toil, path: Optional[List[str]] = None) -> WDLBindings:
"""
Make sure all File values embedded in the given bindings are imported,
Expand All @@ -817,7 +862,8 @@ def import_files(environment: WDLBindings, toil: Toil, path: Optional[List[str]]
:param path: If set, try resolving input location relative to the URLs or
directories in this list.
"""

path_to_id: Dict[str, uuid.UUID] = {}
@memoize
def import_file_from_uri(uri: str) -> str:
"""
Import a file from a URI and return a virtualized filename for it.
Expand Down Expand Up @@ -855,7 +901,10 @@ def import_file_from_uri(uri: str) -> str:
raise RuntimeError(f"File {candidate_uri} has no basename and so cannot be a WDL File")

# Was actually found
return pack_toil_uri(imported, file_basename)
# Pack a UUID of the parent directory
dir_id = path_to_id.setdefault(os.path.dirname(candidate_uri), uuid.uuid4())

return pack_toil_uri(imported, dir_id, file_basename)

# If we get here we tried all the candidates
raise RuntimeError(f"Could not find {uri} at any of: {tried}")
Expand Down Expand Up @@ -1170,6 +1219,7 @@ def run(self, file_store: AbstractFileStore) -> Promised[WDLBindings]:
# For a task we are only passed the inside-the-task namespace.
bindings = combine_bindings(unwrap_all(self._prev_node_results))
# Set up the WDL standard library
# UUID to use for virtualizing files
standard_library = ToilWDLStdLibBase(file_store)

if self._task.inputs:
Expand Down Expand Up @@ -1427,7 +1477,7 @@ def patched_run_invocation(*args: Any, **kwargs: Any) -> List[str]:
# Tell the container to take up all these files. It will assign
# them all new paths in task_container.input_path_map which we can
# read. We also get a task_container.host_path() to go the other way.
task_container.add_paths(get_file_paths_in_bindings(bindings))
add_paths(task_container, get_file_paths_in_bindings(bindings))
logger.debug("Using container path map: %s", task_container.input_path_map)

# Replace everything with in-container paths for the command.
Expand Down Expand Up @@ -1529,7 +1579,7 @@ def run(self, file_store: AbstractFileStore) -> Promised[WDLBindings]:
# Combine the bindings we get from previous jobs
incoming_bindings = combine_bindings(unwrap_all(self._prev_node_results))
# Set up the WDL standard library
standard_library = ToilWDLStdLibBase(file_store, self._execution_dir)
standard_library = ToilWDLStdLibBase(file_store, execution_dir=self._execution_dir)
with monkeypatch_coerce(standard_library):
if isinstance(self._node, WDL.Tree.Decl):
# This is a variable assignment
Expand Down Expand Up @@ -1618,7 +1668,7 @@ def run(self, file_store: AbstractFileStore) -> Promised[WDLBindings]:
# Combine the bindings we get from previous jobs
current_bindings = combine_bindings(unwrap_all(self._prev_node_results))
# Set up the WDL standard library
standard_library = ToilWDLStdLibBase(file_store, self._execution_dir)
standard_library = ToilWDLStdLibBase(file_store, execution_dir=self._execution_dir)

with monkeypatch_coerce(standard_library):
for node in self._nodes:
Expand Down Expand Up @@ -2273,7 +2323,7 @@ def run(self, file_store: AbstractFileStore) -> Promised[WDLBindings]:
# For a task we only see the insode-the-task namespace.
bindings = combine_bindings(unwrap_all(self._prev_node_results))
# Set up the WDL standard library
standard_library = ToilWDLStdLibBase(file_store, self._execution_dir)
standard_library = ToilWDLStdLibBase(file_store, execution_dir=self._execution_dir)

if self._workflow.inputs:
with monkeypatch_coerce(standard_library):
Expand Down Expand Up @@ -2335,7 +2385,7 @@ def run(self, file_store: AbstractFileStore) -> WDLBindings:
else:
# Output section is declared and is nonempty, so evaluate normally
# Evaluate all the outputs in the normal, non-task-outputs library context
standard_library = ToilWDLStdLibBase(file_store, self._execution_dir)
standard_library = ToilWDLStdLibBase(file_store, execution_dir=self._execution_dir)
# Combine the bindings from the previous job
output_bindings = evaluate_output_decls(self._workflow.outputs, unwrap(self._bindings), standard_library)
return self.postprocess(output_bindings)
Expand Down Expand Up @@ -2514,10 +2564,11 @@ def devirtualize_output(filename: str) -> str:
if filename.startswith(TOIL_URI_SCHEME):
# This is a reference to the Toil filestore.
# Deserialize the FileID and required basename
file_id, file_basename = unpack_toil_uri(filename)
file_id, parent_id, file_basename = unpack_toil_uri(filename)
# Figure out where it should go.
# TODO: Deal with name collisions
# If a UUID is included, it will be omitted
dest_name = os.path.join(output_directory, file_basename)
# TODO: Deal with name collisions
# Export the file
toil.export_file(file_id, dest_name)
# And return where we put it
Expand Down

0 comments on commit 6c0fe1e

Please sign in to comment.