Skip to content

Commit 08e28ba

Browse files
committed
fix: Confusion with Overriding input After Snakemake Modularization
1. use name to map dependence 2. restrict cases when rule can be overwritten
1 parent 26fcd38 commit 08e28ba

5 files changed

Lines changed: 44 additions & 26 deletions

File tree

src/snakemake/modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ def __init__(
177177
self.allow_rule_overwrite = parent_modifier.allow_rule_overwrite
178178
self.path_modifier = parent_modifier.path_modifier
179179
self.replace_wrapper_tag = parent_modifier.replace_wrapper_tag
180-
self.namespace = parent_modifier.namespace
181180
self.wildcard_constraints = parent_modifier.wildcard_constraints
182181
self.rules = parent_modifier.rules
183182
self.rule_proxies = parent_modifier.rule_proxies
@@ -193,6 +192,7 @@ def __init__(
193192
self.globals["checkpoints"] = self.globals[
194193
"checkpoints"
195194
].spawn_new_namespace()
195+
self.globals["__name__"] = namespace
196196

197197
self.workflow = workflow
198198
self.base_snakefile = base_snakefile

src/snakemake/rules.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pathlib import Path
1313
from itertools import chain
1414
from functools import partial
15+
from typing import Union
1516

1617
try:
1718
import re._constants as sre_constants
@@ -74,7 +75,7 @@
7475

7576

7677
class Rule(RuleInterface):
77-
def __init__(self, name, workflow, lineno=None, snakefile=None):
78+
def __init__(self, name: str, workflow, lineno=None, snakefile=None):
7879
"""
7980
Create a rule
8081
@@ -124,7 +125,7 @@ def __init__(self, name, workflow, lineno=None, snakefile=None):
124125
self.log_modifier = None
125126
self.benchmark_modifier = None
126127
self.ruleinfo = None
127-
self.module_globals = None
128+
self.module_globals: dict
128129

129130
@property
130131
def name(self):
@@ -462,7 +463,7 @@ def _set_inoutput_item(self, item, output=False, name=None, mark_ancient=False):
462463

463464
rule_dependency = None
464465
if isinstance(item, _IOFile) and item.rule and item in item.rule.output:
465-
rule_dependency = item.rule
466+
rule_dependency = item.rule.name
466467

467468
if output:
468469
path_modifier = self.output_modifier
@@ -895,20 +896,23 @@ def handle_incomplete_checkpoint(exception):
895896
)
896897

897898
if self.dependencies:
898-
dependencies = {
899+
rule_depends = {
899900
f: self.dependencies[f_]
900901
for f, f_ in mapping.items()
901902
if f_ in self.dependencies
902903
}
903904
if None in self.dependencies:
904-
dependencies[None] = self.dependencies[None]
905+
rule_depends[None] = self.dependencies[None]
906+
job_depends = {
907+
f: self.workflow.get_rule(d) for f, d in rule_depends.items()
908+
}
905909
else:
906-
dependencies = self.dependencies
910+
job_depends = {}
907911

908912
for f in input:
909913
f.check()
910914

911-
return input, mapping, dependencies, incomplete
915+
return input, mapping, job_depends, incomplete
912916

913917
@classmethod
914918
def _is_deriving_function(cls, func):

src/snakemake/workflow.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def __post_init__(self):
174174
self.global_resources["_cores"] = self.resource_settings.cores
175175
self.global_resources["_nodes"] = self.resource_settings.nodes
176176

177-
self._rules = OrderedDict()
177+
self._rules: OrderedDict[str, Rule] = OrderedDict()
178178
self.default_target = None
179179
self._workdir_init = os.path.abspath(os.curdir)
180180
self._ruleorder = Ruleorder()
@@ -634,24 +634,34 @@ def add_rule(
634634
checkpoint=False,
635635
allow_overwrite=False,
636636
):
637+
"""Add a rule.
638+
Check if the rule can be overwritten.
639+
640+
> Specific rules may even be modified before using them,
641+
> via a final with: followed by a block that lists items to overwrite.
642+
> This modification can be performed after a general import,
643+
> and will overwrite any unmodified import of the same rule.
637644
"""
638-
Add a rule.
639-
"""
640-
is_overwrite = self.is_rule(name)
641-
if not allow_overwrite and is_overwrite:
642-
raise CreateRuleException(
643-
f"The name {name} is already used by another rule",
644-
lineno=lineno,
645-
snakefile=snakefile,
645+
if self.is_rule(name):
646+
is_overwrite = allow_overwrite and (
647+
self.get_rule(name).module_globals["__name__"]
648+
== self.modifier.namespace
646649
)
650+
if not is_overwrite:
651+
raise CreateRuleException(
652+
f"The name {name} is already used by another rule",
653+
lineno=lineno,
654+
snakefile=snakefile,
655+
)
656+
else:
657+
is_overwrite = False
658+
self.rule_count += 1
659+
if not self.default_target:
660+
self.default_target = name
647661
rule = Rule(name, self, lineno=lineno, snakefile=snakefile)
648662
self._rules[rule.name] = rule
649663
self.modifier.rules.add(rule)
650-
if not is_overwrite:
651-
self.rule_count += 1
652-
if not self.default_target:
653-
self.default_target = rule.name
654-
return name
664+
return is_overwrite
655665

656666
def is_rule(self, name):
657667
"""
@@ -1766,7 +1776,7 @@ def decorate(ruleinfo):
17661776
orig_name = name
17671777
name = self.modifier.modify_rulename(name)
17681778

1769-
name = self.add_rule(
1779+
is_overwrite = self.add_rule(
17701780
name,
17711781
lineno,
17721782
snakefile,
@@ -1776,6 +1786,8 @@ def decorate(ruleinfo):
17761786
rule = self.get_rule(name)
17771787
rule.is_checkpoint = checkpoint
17781788
rule.module_globals = self.modifier.globals
1789+
if is_overwrite:
1790+
rule.module_globals["__name__"] = None
17791791

17801792
def decorate(ruleinfo): # type: ignore[no-redef]
17811793
nonlocal name
@@ -1789,10 +1801,9 @@ def decorate(ruleinfo): # type: ignore[no-redef]
17891801
**ruleinfo.wildcard_constraints[1],
17901802
)
17911803
if ruleinfo.name:
1792-
rule.name = ruleinfo.name
17931804
del self._rules[name]
1794-
self._rules[ruleinfo.name] = rule
1795-
name = rule.name
1805+
name = rule.name = ruleinfo.name
1806+
self._rules[name] = rule
17961807
if ruleinfo.input:
17971808
rule.input_modifier = ruleinfo.input.modifier
17981809
rule.set_input(*ruleinfo.input.paths, **ruleinfo.input.kwpaths)

tests/test_modules_all/Snakefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ module test:
1515

1616
use rule * from test as test_*
1717

18+
use rule a from test as test_* with:
19+
input: "config/config.yaml"
1820

1921
rule all:
2022
input:

tests/test_modules_all/module-test/Snakefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ def some_func():
66

77

88
rule a:
9+
input: "not exist"
910
output:
1011
temp("results/a/{name}.out"),
1112
shell:

0 commit comments

Comments
 (0)