Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions docs/snakefiles/rules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1818,6 +1818,30 @@ Further, an output file marked as ``temp`` is deleted after all rules that use i

.. _snakefiles-directory_output:

Auto-grouping via temp files upon remote execution
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

For performance reasons, it is sometimes useful to write intermediate files on a faster storage, e.g., attached locally on the cluster compute node rather than shared over the network (and thus neither visible to the main snakemake process that submits jobs to the cluster, nor to other nodes of the cluster).
Snakemake (since version 9.0) allows files marked as ``temp`` to use the option ``group_jobs`` to indicate that rules creating and consuming them should be automatically :ref:`grouped <job_grouping>` together so Snakemake will schedule them to run on the same physical node:

.. code-block:: python

rule NAME1:
input:
"path/to/inputfile"
output:
temp("path/to/intermediatefile", group_jobs=True)
shell:
"somecommand {input} {output}"

rule NAME2:
input:
"path/to/intermediatefile"
output:
"path/to/outputfile"
shell:
"someothercommand {input} {output}"

Directories as outputs
----------------------

Expand Down
77 changes: 49 additions & 28 deletions src/snakemake/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,7 @@ async def handle_storage(self, job, store_in_storage=True, store_only_log=False)
and not f.should_not_be_retrieved_from_storage
and not is_flagged(f, "pipe")
and not is_flagged(f, "service")
and not is_flagged(f, "nodelocal")
):
await f.store_in_storage()
storage_mtime = (await f.mtime()).storage()
Expand Down Expand Up @@ -1772,11 +1773,13 @@ def handle_pipes_and_services(self):
user_groups.add(job.group)
all_depending = set()
has_pipe_or_service = False
has_nodelocal = False
for f in job.output:
is_pipe = is_flagged(f, "pipe")
is_service = is_flagged(f, "service")
if is_pipe or is_service:
if job.is_run:
is_nodelocal = is_flagged(f, "nodelocal")
if is_pipe or is_service or is_nodelocal:
if not is_nodelocal and job.is_run:
raise WorkflowError(
"Rule defines pipe output but "
"uses a 'run' directive. This is "
Expand All @@ -1786,7 +1789,11 @@ def handle_pipes_and_services(self):
rule=job.rule,
)

has_pipe_or_service = True
if is_pipe or is_service:
has_pipe_or_service = True
if is_nodelocal:
has_nodelocal = True

depending = [
j for j, files in self.depending[job].items() if f in files
]
Expand All @@ -1799,7 +1806,7 @@ def handle_pipes_and_services(self):
"job".format(f),
rule=job.rule,
)
elif len(depending) == 0:
elif not is_nodelocal and len(depending) == 0:
raise WorkflowError(
"Output file {} is marked as pipe or service "
"but it has no consumer. This is "
Expand All @@ -1815,7 +1822,7 @@ def handle_pipes_and_services(self):
)

for dep in depending:
if dep.is_run:
if not is_nodelocal and dep.is_run:
raise WorkflowError(
"Rule consumes pipe or service input but "
"uses a 'run' directive. This is "
Expand All @@ -1826,42 +1833,54 @@ def handle_pipes_and_services(self):
)

all_depending.add(dep)
if dep.pipe_group is not None:
if (is_pipe or is_service) and dep.pipe_group is not None:
candidate_groups.add(dep.pipe_group)
if dep.group is not None:
user_groups.add(dep.group)

if not has_pipe_or_service:
if not has_pipe_or_service and not has_nodelocal:
continue

# All pipe groups should be contained within one user-defined group
if len(user_groups) > 1:
raise WorkflowError(
"An output file is marked as "
"pipe or service, but consuming jobs "
"pipe, service or nodelocal, but consuming jobs "
"are part of conflicting "
"groups.",
f"groups. {user_groups}",
rule=job.rule,
)

if len(candidate_groups) > 1:
# Merge multiple pipe groups together
group = candidate_groups.pop()
for g in candidate_groups:
g.merge(group)
elif candidate_groups:
# extend the candidate group to all involved jobs
group = candidate_groups.pop()
else:
# generate a random unique group name
group = CandidateGroup() # str(uuid.uuid4())

# Assign the pipe group to all involved jobs.
job.pipe_group = group
visited.add(job)
for j in all_depending:
j.pipe_group = group
visited.add(j)
if has_pipe_or_service:
if len(candidate_groups) > 1:
# Merge multiple pipe groups together
group = candidate_groups.pop()
for g in candidate_groups:
g.merge(group)
elif candidate_groups:
# extend the candidate group to all involved jobs
group = candidate_groups.pop()
else:
# generate a random unique group name
group = CandidateGroup() # str(uuid.uuid4())

# Assign the pipe group to all involved jobs.
job.pipe_group = group
visited.add(job)
for j in all_depending:
j.pipe_group = group
visited.add(j)

if has_nodelocal and not self.workflow.local_exec:
                # put the dependencies in the same user group (not a pipe group, as pipes are run in parallel, whereas we want serial execution for nodelocal files)
                # NOTE do NOT create the user group in the node-local instance of snakemake, lest the resources get merged; see resources.py: class GroupResources
ugroup = user_groups.pop() if user_groups else str(uuid.uuid4())

job.group = ugroup
visited.add(job)
for j in all_depending:
j.group = ugroup
visited.add(j)

# convert candidate groups to plain string IDs
for job in visited:
Expand Down Expand Up @@ -2683,7 +2702,9 @@ def fmt_output(f):

status = "ok"
if not await f.exists():
if is_flagged(f, "temp"):
if is_flagged(f, "nodelocal"):
status = "file local to node"
elif is_flagged(f, "temp"):
status = "removed temp file"
elif is_flagged(f, "pipe"):
status = "pipe file"
Expand Down
12 changes: 11 additions & 1 deletion src/snakemake/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,14 +1129,22 @@ def directory(value):
return flag(value, "directory")


def temp(value, group_jobs=False):
    """
    A flag for an input or output file that shall be removed after usage.

    When ``group_jobs`` is true, the file is additionally flagged as
    "nodelocal": an intermediate file that only lives on the compute node
    executing the group's jobs and is not accessible from the main snakemake
    process (e.g. what some HPC systems call "local scratch"). This causes
    snakemake to automatically group the producing and consuming rules onto
    the same compute node.
    """
    # temp is incompatible with flags that require the file to outlive the job
    # or to live in remote storage.
    if is_flagged(value, "protected"):
        raise SyntaxError("Protected and temporary flags are mutually exclusive.")
    if is_flagged(value, "storage_object"):
        raise SyntaxError("Storage and temporary flags are mutually exclusive.")

    if group_jobs:
        value = flag(value, "nodelocal")
    return flag(value, "temp")


Expand Down Expand Up @@ -1216,6 +1224,8 @@ def protected(value):
raise SyntaxError("Protected and temporary flags are mutually exclusive.")
if is_flagged(value, "storage_object"):
raise SyntaxError("Storage and protected flags are mutually exclusive.")
if is_flagged(value, "nodelocal"):
raise SyntaxError("Protected and nodelocal flags are mutually exclusive.")
return flag(value, "protected")


Expand Down
5 changes: 4 additions & 1 deletion src/snakemake/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def format_file(f, is_input: bool):
return f"{f} (pipe)"
elif is_flagged(f, "service"):
return f"{f} (service)"
elif is_flagged(f, "nodelocal"):
return f"{f} (nodelocal)"
elif is_flagged(f, "update"):
return f"{f} (update)"
elif is_flagged(f, "before_update"):
Expand Down Expand Up @@ -1617,7 +1619,8 @@ def needed(job_, f):
f
for j in self.jobs
for f in j.output
if is_flagged(f, "temp") and not needed(j, f)
if is_flagged(f, "nodelocal")
or (is_flagged(f, "temp") and not needed(j, f))
]

# Iterate over jobs in toposorted order (see self.__iter__) to
Expand Down
1 change: 1 addition & 0 deletions src/snakemake/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ def _set_inoutput_item(self, item, output=False, name=None, mark_ancient=False):
if not output and item_flag in [
"protected",
"temp",
"nodelocal",
"temporary",
"directory",
"touch",
Expand Down
36 changes: 36 additions & 0 deletions tests/test_nodelocal/Snakefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Test workflow for nodelocal files: intermediate files that live only on the
# compute node ("local scratch") and are invisible to the main snakemake
# process. The simulated cluster (./qsub) bind-mounts "local" over "scratch"
# inside a private mount namespace.
rule all:
    input:
        "results/result.txt",


# Produces a nodelocal temp file via temp(..., group_jobs=True); the consumer
# must be auto-grouped onto the same node for the input to be visible.
rule create_nodelocal_temp:
    output:
        temp("scratch/temp.txt", group_jobs=True),
    shell:
        """
        sleep 4
        echo "test" > "{output}"
        """


# Produces a nodelocal file that is NOT temp (flagged directly), i.e. it
# persists on the node after the group finishes.
rule create_nodelocal_persist:
    output:
        flag("scratch/persist.txt", "nodelocal"),
    shell:
        """
        sleep 4
        echo "test" > "{output}"
        """


# Consumes both nodelocal files; only works if snakemake scheduled it on the
# same (simulated) node as the producers.
rule consume_nodelocal:
    input:
        tmp="scratch/temp.txt",
        persist="scratch/persist.txt",
    output:
        "results/result.txt",
    shell:
        """
        ls "{input.tmp}" > "{output}"
        ls "{input.persist}" >> "{output}"
        """
1 change: 1 addition & 0 deletions tests/test_nodelocal/expected-results/local/local
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
local
1 change: 1 addition & 0 deletions tests/test_nodelocal/expected-results/local/persist.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
test
2 changes: 2 additions & 0 deletions tests/test_nodelocal/expected-results/results/result.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
scratch/temp.txt
scratch/persist.txt
12 changes: 12 additions & 0 deletions tests/test_nodelocal/qsub
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/sh

# Simulates node-local storage (a.k.a. "local scratch"):
# - files and paths only accessible within a single compute node
# - NOT accessible on other nodes nor on the login node

# Create the directories used by the bind mount performed in qsub_stage2.
mkdir -p scratch local

# Unshare the mount namespace:
# - mounts within qsub_stage2 NOT shared with other processes
# - outer process will keep seeing "scratch" unaffected
unshare --mount --map-root-user ./qsub_stage2 "$@"
20 changes: 20 additions & 0 deletions tests/test_nodelocal/qsub_stage2
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/bin/bash

# Simulates node-local storage (a.k.a. "local scratch"):
# - files and paths only accessible within a single compute node
# - NOT accessible on other nodes nor on the login node
#
# Make the mount namespace private and bind 'local' to appear as 'scratch'
# within the namespace. Abort explicitly if the mount fails (insufficient
# permissions, missing directories, or namespace not unshared), instead of
# silently continuing with a broken simulation.
if ! mount --make-rprivate --bind local scratch; then
    echo "Error: failed to bind mount 'local' onto 'scratch'. Ensure both directories exist and the mount namespace was unshared." >&2
    exit 1
fi

# Create a marker file to verify the binding is limited
# - current script will see it in "scratch/" too
# - outer calling processes will only see it in "local/", and will never see content appearing in "scratch/"
echo "local" > scratch/local

# normal qsub like other cluster-simulated tests

echo `date` >> qsub.log
tail -n1 $1 >> qsub.log
# simulate printing of job id by a random number
echo $RANDOM
sh $1
15 changes: 15 additions & 0 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2151,3 +2151,18 @@ def test_issue3361_fail():
targets=["all"],
shouldfail=True,
)


@skip_on_windows
def test_nodelocal():
    """Run the nodelocal workflow on the simulated cluster (./qsub) and check
    that nodelocal files do not leak outside the simulated compute node.

    The qsub script bind-mounts ``local`` over ``scratch`` inside a private
    mount namespace, so files the jobs write to ``scratch/`` should end up in
    ``local/`` and never be visible in ``scratch/`` from the outside.
    """
    work_path = Path("test_nodelocal")
    run(
        dpath(work_path),
        cluster="./qsub",
        cores=1,
        resources={"mem_mb": 120},
        default_resources=DefaultResources(["mem_mb=120"]),
    )
    # Either the temp nodelocal file was cleaned up on the "node" (local/),
    # or nothing leaked into the outer scratch/ directory.
    # NOTE(review): the `or` makes this pass if only one of the two conditions
    # holds — confirm whether both should be asserted independently.
    assert not (work_path / "local/temp.txt").exists() or not any(
        (work_path / "scratch/").iterdir()
    )