Skip to content

Commit a87e2a7

Browse files
committed
Rework the encoding to introduce node(ID, Package) nested facts
So far the encoding has a single ID per package, i.e. all the facts will be node(0, Package). This will prepare the stage for extending this logic and having multiple nodes from the same package in a DAG.
1 parent f2455dd commit a87e2a7

File tree

3 files changed

+631
-537
lines changed

3 files changed

+631
-537
lines changed

lib/spack/spack/solver/asp.py

Lines changed: 96 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -516,15 +516,17 @@ def _compute_specs_from_answer_set(self):
516516
best = min(self.answers)
517517
opt, _, answer = best
518518
for input_spec in self.abstract_specs:
519-
key = input_spec.name
519+
node = SpecBuilder.root_node(pkg=input_spec.name)
520520
if input_spec.virtual:
521-
providers = [spec.name for spec in answer.values() if spec.package.provides(key)]
522-
key = providers[0]
523-
candidate = answer.get(key)
521+
providers = [
522+
spec.name for spec in answer.values() if spec.package.provides(input_spec.name)
523+
]
524+
node = SpecBuilder.root_node(pkg=providers[0])
525+
candidate = answer.get(node)
524526

525527
if candidate and candidate.satisfies(input_spec):
526-
self._concrete_specs.append(answer[key])
527-
self._concrete_specs_by_input[input_spec] = answer[key]
528+
self._concrete_specs.append(answer[node])
529+
self._concrete_specs_by_input[input_spec] = answer[node]
528530
else:
529531
self._unsolved_specs.append(input_spec)
530532

@@ -2426,6 +2428,18 @@ class SpecBuilder(object):
24262428
)
24272429
)
24282430

2431+
node_regex = re.compile(r"node\(\d,\"(.*)\"\)")
2432+
2433+
@staticmethod
2434+
def root_node(*, pkg: str) -> str:
2435+
"""Given a package name, returns the string representation of the root node in
2436+
the ASP encoding.
2437+
2438+
Args:
2439+
pkg: name of a package
2440+
"""
2441+
return f'node(0,"{pkg}")'
2442+
24292443
def __init__(self, specs, hash_lookup=None):
24302444
self._specs = {}
24312445
self._result = None
@@ -2438,100 +2452,121 @@ def __init__(self, specs, hash_lookup=None):
24382452
# from this dictionary during reconstruction
24392453
self._hash_lookup = hash_lookup or {}
24402454

2441-
def hash(self, pkg, h):
2442-
if pkg not in self._specs:
2443-
self._specs[pkg] = self._hash_lookup[h]
2444-
self._hash_specs.append(pkg)
2455+
@staticmethod
2456+
def extract_pkg(node: str) -> str:
2457+
"""Extracts the package name from a node fact, and returns it.
2458+
2459+
Args:
2460+
node: node from which the package name is to be extracted
2461+
"""
2462+
m = SpecBuilder.node_regex.match(node)
2463+
if m is None:
2464+
raise spack.error.SpackError(f"cannot extract package information from '{node}'")
2465+
2466+
return m.group(1)
2467+
2468+
def hash(self, node, h):
2469+
if node not in self._specs:
2470+
self._specs[node] = self._hash_lookup[h]
2471+
self._hash_specs.append(node)
24452472

2446-
def node(self, pkg):
2447-
if pkg not in self._specs:
2448-
self._specs[pkg] = spack.spec.Spec(pkg)
2473+
def node(self, node):
2474+
pkg = self.extract_pkg(node)
2475+
if node not in self._specs:
2476+
self._specs[node] = spack.spec.Spec(pkg)
24492477

2450-
def _arch(self, pkg):
2451-
arch = self._specs[pkg].architecture
2478+
def _arch(self, node):
2479+
arch = self._specs[node].architecture
24522480
if not arch:
24532481
arch = spack.spec.ArchSpec()
2454-
self._specs[pkg].architecture = arch
2482+
self._specs[node].architecture = arch
24552483
return arch
24562484

2457-
def node_platform(self, pkg, platform):
2458-
self._arch(pkg).platform = platform
2485+
def node_platform(self, node, platform):
2486+
self._arch(node).platform = platform
24592487

2460-
def node_os(self, pkg, os):
2461-
self._arch(pkg).os = os
2488+
def node_os(self, node, os):
2489+
self._arch(node).os = os
24622490

2463-
def node_target(self, pkg, target):
2464-
self._arch(pkg).target = target
2491+
def node_target(self, node, target):
2492+
self._arch(node).target = target
24652493

2466-
def variant_value(self, pkg, name, value):
2494+
def variant_value(self, node, name, value):
24672495
# FIXME: is there a way not to special case 'dev_path' everywhere?
24682496
if name == "dev_path":
2469-
self._specs[pkg].variants.setdefault(
2497+
self._specs[node].variants.setdefault(
24702498
name, spack.variant.SingleValuedVariant(name, value)
24712499
)
24722500
return
24732501

24742502
if name == "patches":
2475-
self._specs[pkg].variants.setdefault(
2503+
self._specs[node].variants.setdefault(
24762504
name, spack.variant.MultiValuedVariant(name, value)
24772505
)
24782506
return
24792507

2480-
self._specs[pkg].update_variant_validate(name, value)
2508+
self._specs[node].update_variant_validate(name, value)
24812509

2482-
def version(self, pkg, version):
2483-
self._specs[pkg].versions = vn.VersionList([vn.Version(version)])
2510+
def version(self, node, version):
2511+
self._specs[node].versions = vn.VersionList([vn.Version(version)])
24842512

2485-
def node_compiler_version(self, pkg, compiler, version):
2486-
self._specs[pkg].compiler = spack.spec.CompilerSpec(compiler)
2487-
self._specs[pkg].compiler.versions = vn.VersionList([vn.Version(version)])
2513+
def node_compiler_version(self, node, compiler, version):
2514+
self._specs[node].compiler = spack.spec.CompilerSpec(compiler)
2515+
self._specs[node].compiler.versions = vn.VersionList([vn.Version(version)])
24882516

2489-
def node_flag_compiler_default(self, pkg):
2490-
self._flag_compiler_defaults.add(pkg)
2517+
def node_flag_compiler_default(self, node):
2518+
self._flag_compiler_defaults.add(node)
24912519

2492-
def node_flag(self, pkg, flag_type, flag):
2493-
self._specs[pkg].compiler_flags.add_flag(flag_type, flag, False)
2520+
def node_flag(self, node, flag_type, flag):
2521+
self._specs[node].compiler_flags.add_flag(flag_type, flag, False)
24942522

2495-
def node_flag_source(self, pkg, flag_type, source):
2496-
self._flag_sources[(pkg, flag_type)].add(source)
2523+
def node_flag_source(self, node, flag_type, source):
2524+
self._flag_sources[(node, flag_type)].add(source)
24972525

2498-
def no_flags(self, pkg, flag_type):
2499-
self._specs[pkg].compiler_flags[flag_type] = []
2526+
def no_flags(self, node, flag_type):
2527+
self._specs[node].compiler_flags[flag_type] = []
25002528

2501-
def external_spec_selected(self, pkg, idx):
2529+
def external_spec_selected(self, node, idx):
25022530
"""This means that the external spec and index idx
25032531
has been selected for this package.
25042532
"""
2533+
25052534
packages_yaml = spack.config.get("packages")
25062535
packages_yaml = _normalize_packages_yaml(packages_yaml)
2536+
pkg = self.extract_pkg(node)
25072537
spec_info = packages_yaml[pkg]["externals"][int(idx)]
2508-
self._specs[pkg].external_path = spec_info.get("prefix", None)
2509-
self._specs[pkg].external_modules = spack.spec.Spec._format_module_list(
2538+
self._specs[node].external_path = spec_info.get("prefix", None)
2539+
self._specs[node].external_modules = spack.spec.Spec._format_module_list(
25102540
spec_info.get("modules", None)
25112541
)
2512-
self._specs[pkg].extra_attributes = spec_info.get("extra_attributes", {})
2542+
self._specs[node].extra_attributes = spec_info.get("extra_attributes", {})
25132543

25142544
# If this is an extension, update the dependencies to include the extendee
2515-
package = self._specs[pkg].package_class(self._specs[pkg])
2545+
package = self._specs[node].package_class(self._specs[node])
25162546
extendee_spec = package.extendee_spec
2547+
25172548
if extendee_spec:
2518-
package.update_external_dependencies(self._specs.get(extendee_spec.name, None))
2549+
extendee_node = SpecBuilder.root_node(pkg=extendee_spec.name)
2550+
package.update_external_dependencies(self._specs.get(extendee_node, None))
25192551

2520-
def depends_on(self, pkg, dep, type):
2521-
dependencies = self._specs[pkg].edges_to_dependencies(name=dep)
2552+
def depends_on(self, parent_node, dependency_node, type):
2553+
dependencies = self._specs[parent_node].edges_to_dependencies(name=dependency_node)
25222554

25232555
# TODO: assertion to be removed when cross-compilation is handled correctly
25242556
msg = "Current solver does not handle multiple dependency edges of the same name"
25252557
assert len(dependencies) < 2, msg
25262558

25272559
if not dependencies:
2528-
self._specs[pkg].add_dependency_edge(self._specs[dep], deptypes=(type,), virtuals=())
2560+
self._specs[parent_node].add_dependency_edge(
2561+
self._specs[dependency_node], deptypes=(type,), virtuals=()
2562+
)
25292563
else:
25302564
# TODO: This assumes that each solve unifies dependencies
25312565
dependencies[0].update_deptypes(deptypes=(type,))
25322566

2533-
def virtual_on_edge(self, pkg, provider, virtual):
2534-
dependencies = self._specs[pkg].edges_to_dependencies(name=provider)
2567+
def virtual_on_edge(self, parent_node, provider_node, virtual):
2568+
provider = self.extract_pkg(provider_node)
2569+
dependencies = self._specs[parent_node].edges_to_dependencies(name=provider)
25352570
assert len(dependencies) == 1
25362571
dependencies[0].update_virtuals((virtual,))
25372572

@@ -2562,17 +2597,22 @@ def reorder_flags(self):
25622597

25632598
# order is determined by the DAG. A spec's flags come after any of its ancestors
25642599
# on the compile line
2565-
source_key = (spec.name, flag_type)
2600+
node = SpecBuilder.root_node(pkg=spec.name)
2601+
source_key = (node, flag_type)
25662602
if source_key in self._flag_sources:
2567-
order = [s.name for s in spec.traverse(order="post", direction="parents")]
2603+
order = [
2604+
SpecBuilder.root_node(pkg=s.name)
2605+
for s in spec.traverse(order="post", direction="parents")
2606+
]
25682607
sorted_sources = sorted(
25692608
self._flag_sources[source_key], key=lambda s: order.index(s)
25702609
)
25712610

25722611
# add flags from each source, lowest to highest precedence
2573-
for name in sorted_sources:
2612+
for node in sorted_sources:
25742613
all_src_flags = list()
2575-
per_pkg_sources = [self._specs[name]]
2614+
per_pkg_sources = [self._specs[node]]
2615+
name = self.extract_pkg(node)
25762616
if name in cmd_specs:
25772617
per_pkg_sources.append(cmd_specs[name])
25782618
for source in per_pkg_sources:
@@ -2645,14 +2685,14 @@ def build_specs(self, function_tuples):
26452685
# solving but don't construct anything. Do not ignore error
26462686
# predicates on virtual packages.
26472687
if name != "error":
2648-
pkg = args[0]
2688+
pkg = self.extract_pkg(args[0])
26492689
if spack.repo.path.is_virtual(pkg):
26502690
continue
26512691

26522692
# if we've already gotten a concrete spec for this pkg,
26532693
# do not bother calling actions on it except for node_flag_source,
26542694
# since node_flag_source is tracking information not in the spec itself
2655-
spec = self._specs.get(pkg)
2695+
spec = self._specs.get(args[0])
26562696
if spec and spec.concrete:
26572697
if name != "node_flag_source":
26582698
continue

0 commit comments

Comments
 (0)