Skip to content

Commit c3ffcf7

Browse files
committed
feat: nested modify
1 parent d108819 commit c3ffcf7

3 files changed

Lines changed: 46 additions & 31 deletions

File tree

src/snakemake/modules.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ def __init__(
6767
"Module definition contains both prefix and replace_prefix. "
6868
"Only one at a time is allowed."
6969
)
70-
71-
self.replace_prefix = replace_prefix
72-
self.prefix = prefix
70+
self.path_modifier = PathModifier(
71+
replace_prefix, prefix, workflow, self.parent_modifier.path_modifier
72+
)
7373

7474
def use_rules(
7575
self,
@@ -95,8 +95,7 @@ def use_rules(
9595
ruleinfo_overwrite=ruleinfo,
9696
allow_rule_overwrite=True,
9797
namespace=self.name,
98-
replace_prefix=self.replace_prefix,
99-
prefix=self.prefix,
98+
path_modifier=self.path_modifier,
10099
replace_wrapper_tag=self.get_wrapper_tag(),
101100
rule_proxies=self.rule_proxies,
102101
)
@@ -153,40 +152,46 @@ def __init__(
153152
rule_exclude_list=None,
154153
ruleinfo_overwrite=None,
155154
allow_rule_overwrite=False,
156-
replace_prefix=None,
157-
prefix=None,
155+
path_modifier: PathModifier | None = None,
158156
replace_wrapper_tag=None,
159157
namespace=None,
160-
rule_proxies=None,
158+
rule_proxies: Rules | None = None,
161159
):
162160
if parent_modifier is None:
163-
# default settings for globals if not inheriting from parent
164-
self.globals = (
165-
globals if globals is not None else dict(workflow.vanilla_globals)
166-
)
161+
if globals is None: # use rule from module with maybe_ruleinfo
162+
globals = dict(workflow.vanilla_globals)
163+
else:
164+
# the first module modifier of workflow
165+
rule_proxies = Rules()
166+
path_modifier = PathModifier(None, None, workflow, None)
167167
self.wildcard_constraints: dict = dict()
168168
self.rules: set = set()
169-
self.rule_proxies = rule_proxies or Rules()
170-
self.globals["rules"] = self.rule_proxies
169+
self.globals = globals
170+
self.globals["rules"] = self.rule_proxies = rule_proxies
171171
self.globals["checkpoints"] = self.globals[
172172
"checkpoints"
173173
].spawn_new_namespace()
174+
if config is not None:
175+
self.globals["config"] = config
174176
self.globals["__name__"] = namespace
175177
self.modules: dict = dict()
176-
else:
177-
# init with values from parent modifier
178+
self.parent_modifier = parent_modifier
179+
self.path_modifier = path_modifier
180+
elif parent_modifier is not None:
181+
# use rule (from same include) as ... with: init with values from parent modifier
178182
self.globals = parent_modifier.globals
179183
self.wildcard_constraints = parent_modifier.wildcard_constraints
180184
self.rules = parent_modifier.rules
181185
self.rule_proxies = parent_modifier.rule_proxies
182186
self.modules = parent_modifier.modules
187+
self.parent_modifier = parent_modifier.parent_modifier
188+
self.path_modifier = parent_modifier.path_modifier
189+
else:
190+
raise WorkflowError("Invalid workflow modifier configuration.")
183191

184192
self.workflow = workflow
185193
self.base_snakefile = base_snakefile
186194

187-
if config is not None:
188-
self.globals["config"] = config
189-
190195
self.skip_configfile = config is not None
191196
self.resolved_rulename_modifier = resolved_rulename_modifier
192197
self.local_rulename_modifier = local_rulename_modifier
@@ -196,13 +201,12 @@ def __init__(
196201
self.rule_exclude_list = rule_exclude_list
197202
self.ruleinfo_overwrite = ruleinfo_overwrite
198203
self.allow_rule_overwrite = allow_rule_overwrite
199-
self.path_modifier = PathModifier(replace_prefix, prefix, workflow) # type: ignore[reportArgumentType]
200204
self.replace_wrapper_tag = replace_wrapper_tag
201205
self.namespace = namespace
202206
self.default_input_flags: DefaultFlags = DefaultFlags()
203207
self.default_output_flags: DefaultFlags = DefaultFlags()
204208

205-
def inherit_rule_proxies(self, child_modifier):
209+
def inherit_rule_proxies(self, child_modifier: "WorkflowModifier"):
206210
for name, rule in child_modifier.rule_proxies._rules.items():
207211
if child_modifier.local_rulename_modifier is not None:
208212
name = child_modifier.local_rulename_modifier(name)

src/snakemake/path_modifier.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,24 @@
77

88
from snakemake.common.prefix_lookup import PrefixLookup
99
from snakemake.exceptions import WorkflowError
10-
from snakemake.io import (
11-
is_callable,
12-
is_flagged,
13-
AnnotatedString,
14-
flag,
15-
get_flag_value,
16-
)
10+
from snakemake.io import is_callable, is_flagged, AnnotatedString, flag, get_flag_value
1711
from snakemake.logging import logger
1812

1913
PATH_MODIFIER_FLAG = "path_modified"
2014

2115

2216
class PathModifier:
23-
def __init__(self, replace_prefix: dict, prefix: str, workflow):
24-
self.skip_properties = set()
17+
18+
def __init__(
19+
self,
20+
replace_prefix: dict | None,
21+
prefix: str | None,
22+
workflow,
23+
inner_modifier: "PathModifier | None" = None,
24+
):
25+
self.skip_properties: set = set()
2526
self.workflow = workflow
27+
self.inner_modifier = inner_modifier
2628

2729
self.prefix = None
2830
assert not (prefix and replace_prefix)
@@ -67,6 +69,12 @@ def modify(self, path, property=None):
6769
return modified_path
6870

6971
def replace_prefix(self, path, property=None):
72+
path = self._replace_prefix(path, property)
73+
if self.inner_modifier is not None:
74+
return self.inner_modifier.replace_prefix(path, property)
75+
return path
76+
77+
def _replace_prefix(self, path, property):
7078
if (self._prefix_replacements is None and self.prefix is None) or (
7179
property in self.skip_properties
7280
or os.path.isabs(path)

src/snakemake/ruleinfo.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def apply_modifier(
6666
prefix_replacables={"input", "output", "log", "benchmark"},
6767
):
6868
"""Update this ruleinfo with the given one (used for 'use rule' overrides)."""
69-
path_modifier = modifier.path_modifier
7069
skips = set()
7170

7271
if modifier.ruleinfo_overwrite:
@@ -97,6 +96,7 @@ def apply_modifier(
9796
if key in prefix_replacables:
9897
skips.add(key)
9998

99+
path_modifier = modifier.path_modifier
100100
if path_modifier.modifies_prefixes and skips:
101101
# use a specialized copy of the path modifier
102102
path_modifier = copy(path_modifier)
@@ -106,3 +106,6 @@ def apply_modifier(
106106

107107
# modify wrapper if requested
108108
self.wrapper = modifier.modify_wrapper_uri(self.wrapper)
109+
110+
if modifier.parent_modifier is not None:
111+
self.apply_modifier(modifier.parent_modifier, rulename=rulename)

0 commit comments

Comments
 (0)