Skip to content

Commit 18e78e4

Browse files
authored
fix: used_extensions should include transitive requirements (#2891)
When gathering the extensions required to define a Hugr, we only included the ones that were directly referenced by the Hugr definition (inside types or operations). To correctly load a Hugr, however, we also need to have access to any extension referenced by the _extension definitions_ themselves (in the signature of their operation definitions). Otherwise we wouldn't be able to load the bundled extensions, as we saw happen with guppy programs after Quantinuum/guppylang#1449 got merged. This equally affected the rust and python computation of used extensions.
1 parent 25c625d commit 18e78e4

File tree

11 files changed

+173
-7
lines changed

11 files changed

+173
-7
lines changed

hugr-core/src/extension/op_def.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,11 @@ impl CustomValidator {
136136
}
137137
}
138138

139+
/// Return a reference to the `PolyFuncTypeRV` used by this validator.
140+
pub(crate) fn poly_func(&self) -> &PolyFuncTypeRV {
141+
&self.poly_func
142+
}
143+
139144
/// Return a mutable reference to the `PolyFuncType`.
140145
pub(super) fn poly_func_mut(&mut self) -> &mut PolyFuncTypeRV {
141146
&mut self.poly_func
@@ -212,6 +217,15 @@ impl SignatureFunc {
212217
}
213218
}
214219

220+
/// Return the underlying poly function type when available.
221+
pub(crate) fn poly_func_type(&self) -> Option<&PolyFuncTypeRV> {
222+
match self {
223+
SignatureFunc::PolyFuncType(ts) | SignatureFunc::MissingValidateFunc(ts) => Some(ts),
224+
SignatureFunc::CustomValidator(custom) => Some(custom.poly_func()),
225+
SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => None,
226+
}
227+
}
228+
215229
/// Compute the concrete signature ([`FuncValueType`]).
216230
///
217231
/// # Panics

hugr-core/src/extension/resolution.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,10 @@ pub enum ExtensionResolutionError<N: HugrNode = Node> {
151151
/// The missing extension.
152152
missing_extensions: ExtensionSet,
153153
},
154+
/// Error while collecting extension dependencies.
155+
#[display("Error collecting extension dependencies: {_0}")]
156+
#[from]
157+
ExtensionDependencyError(ExtensionCollectionError<N>),
154158
}
155159

156160
impl<N: HugrNode> ExtensionResolutionError<N> {
@@ -225,6 +229,17 @@ pub enum ExtensionCollectionError<N: HugrNode = Node> {
225229
/// The missing extensions.
226230
missing_extensions: Vec<ExtensionId>,
227231
},
232+
/// An extension definition references an extension that is not in the given registry.
233+
#[display(
234+
"Extension {extension} depends on dropped extensions {}",
235+
missing_extensions.join(", ")
236+
)]
237+
DroppedTransitiveExtensions {
238+
/// The extension that is missing dependencies.
239+
extension: String,
240+
/// The missing extensions.
241+
missing_extensions: Vec<ExtensionId>,
242+
},
228243
}
229244

230245
impl<N: HugrNode> ExtensionCollectionError<N> {

hugr-core/src/extension/resolution/extension.rs

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@
77
use std::mem;
88
use std::sync::Arc;
99

10-
use crate::extension::{Extension, ExtensionId, ExtensionRegistry, OpDef, SignatureFunc, TypeDef};
10+
use crate::extension::{
11+
Extension, ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureFunc, TypeDef,
12+
};
1113

14+
use super::types::collect_signature_exts;
1215
use super::types_mut::resolve_signature_exts;
13-
use super::{ExtensionResolutionError, WeakExtensionRegistry};
16+
use super::{ExtensionCollectionError, ExtensionResolutionError, WeakExtensionRegistry};
1417

1518
impl ExtensionRegistry {
1619
/// Given a list of extensions that has been deserialized, create a new
@@ -39,6 +42,55 @@ impl ExtensionRegistry {
3942
Ok(exts)
4043
})
4144
}
45+
46+
/// Expand the registry with transitive extension dependencies.
47+
///
48+
/// This includes all extensions required to define the types in the
49+
/// operation signatures.
50+
pub fn extend_with_dependencies(&mut self) -> Result<(), ExtensionCollectionError> {
51+
let mut queue: Vec<Arc<Extension>> = self.exts.values().cloned().collect();
52+
let mut seen: std::collections::BTreeSet<ExtensionId> = self.exts.keys().cloned().collect();
53+
54+
while let Some(ext) = queue.pop() {
55+
let deps = collect_extension_deps(&ext)?;
56+
for dep in deps {
57+
let dep_id = dep.name().clone();
58+
if seen.insert(dep_id.clone()) {
59+
self.register_updated(dep.clone());
60+
queue.push(dep);
61+
}
62+
}
63+
}
64+
65+
Ok(())
66+
}
67+
}
68+
69+
/// Collect extensions referenced by an extension's operation signatures.
70+
fn collect_extension_deps(
71+
extension: &Extension,
72+
) -> Result<ExtensionRegistry, ExtensionCollectionError> {
73+
let mut used = WeakExtensionRegistry::default();
74+
let mut missing = ExtensionSet::new();
75+
76+
for (_, op_def) in extension.operations() {
77+
if let Some(signature) = op_def.signature_func().poly_func_type() {
78+
let mut local_missing = ExtensionSet::new();
79+
collect_signature_exts(signature.body(), &mut used, &mut local_missing);
80+
for ext in local_missing {
81+
missing.insert(ext);
82+
}
83+
}
84+
}
85+
86+
if missing.is_empty() {
87+
Ok(used.try_into().expect("All extensions are valid"))
88+
} else {
89+
Err(ExtensionCollectionError::DroppedTransitiveExtensions {
90+
extension: extension.name().to_string(),
91+
missing_extensions: missing.into_iter().collect(),
92+
})
93+
}
4294
}
4395

4496
impl Extension {

hugr-core/src/extension/resolution/test.rs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use crate::extension::{
1919
use crate::ops::constant::CustomConst;
2020
use crate::ops::constant::test::CustomTestValue;
2121
use crate::ops::{CallIndirect, ExtensionOp, Input, OpType, Tag, Value};
22+
use crate::std_extensions::arithmetic::conversions::{self, ConvertOpDef};
2223
use crate::std_extensions::arithmetic::float_types::{self, ConstF64, float64_type};
2324
use crate::std_extensions::arithmetic::int_ops;
2425
use crate::std_extensions::arithmetic::int_types::{self, int_type};
@@ -32,7 +33,7 @@ use crate::{Extension, Hugr, HugrView, type_row};
3233
#[case::empty(Input { types: type_row![]}, ExtensionRegistry::default())]
3334
// A type with extra extensions in its instantiated type arguments.
3435
#[case::parametric_op(int_ops::IntOpDef::ieq.with_log_width(4),
35-
ExtensionRegistry::new([int_ops::EXTENSION.to_owned(), int_types::EXTENSION.to_owned()]
36+
ExtensionRegistry::new([int_ops::EXTENSION.to_owned(), int_types::EXTENSION.to_owned(), PRELUDE.to_owned()]
3637
))]
3738
fn collect_type_extensions(#[case] op: impl Into<OpType>, #[case] extensions: ExtensionRegistry) {
3839
let op = op.into();
@@ -44,7 +45,7 @@ fn collect_type_extensions(#[case] op: impl Into<OpType>, #[case] extensions: Ex
4445
#[case::empty(Input { types: type_row![]}, ExtensionRegistry::default())]
4546
// A type with extra extensions in its instantiated type arguments.
4647
#[case::parametric_op(int_ops::IntOpDef::ieq.with_log_width(4),
47-
ExtensionRegistry::new([int_ops::EXTENSION.to_owned(), int_types::EXTENSION.to_owned()]
48+
ExtensionRegistry::new([int_ops::EXTENSION.to_owned(), int_types::EXTENSION.to_owned(), PRELUDE.to_owned()]
4849
))]
4950
fn resolve_type_extensions(#[case] op: impl Into<OpType>, #[case] extensions: ExtensionRegistry) {
5051
let op = op.into();
@@ -365,6 +366,31 @@ fn resolve_call() {
365366
check_extension_resolution(hugr);
366367
}
367368

369+
/// Test that extension resolution is transitive across extension dependencies.
370+
///
371+
/// `arithmetic.conversions` depends on `arithmetic.int_types` and
372+
/// `arithmetic.float_types`, so using an op from the former should cause all
373+
/// three extensions to be resolved even if the operation itself doesn't
374+
/// directly use floats.
375+
#[rstest]
376+
fn resolve_transitive_extension_deps() {
377+
let mut build = DFGBuilder::new(Signature::new(vec![int_type(6)], vec![usize_t()])).unwrap();
378+
let [input] = build.input_wires_arr();
379+
380+
let out = build
381+
.add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [input])
382+
.unwrap();
383+
384+
let hugr = build
385+
.finish_hugr_with_outputs(out.outputs())
386+
.unwrap_or_else(|e| panic!("{e}"));
387+
388+
assert!(hugr.extensions().contains(&conversions::EXTENSION_ID));
389+
assert!(hugr.extensions().contains(&float_types::EXTENSION_ID));
390+
391+
check_extension_resolution(hugr);
392+
}
393+
368394
/// Test the [`ExtensionRegistry::new_cyclic`] and [`ExtensionRegistry::new_with_extension_resolution`] methods.
369395
#[test]
370396
fn register_new_cyclic() {

hugr-core/src/hugr.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,8 @@ impl Hugr {
308308
);
309309
}
310310

311+
used_extensions.extend_with_dependencies()?;
312+
311313
self.extensions = used_extensions;
312314
Ok(())
313315
}

hugr-core/src/ops.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@ impl OpType {
412412
if let Some(ext) = collect_op_extension(None, self)? {
413413
reg.register_updated(ext);
414414
}
415+
reg.extend_with_dependencies()?;
415416
Ok(reg)
416417
}
417418
}

hugr-py/src/hugr/ext.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,23 @@ def add_type_def(self, type_def: TypeDef) -> TypeDef:
336336
self.types[type_def.name] = type_def
337337
return self.types[type_def.name]
338338

339+
def _resolve_used_extensions(
340+
self, registry: ExtensionRegistry | None = None
341+
) -> ExtensionResolutionResult:
342+
"""Collect extension dependencies from this extension's op signatures."""
343+
if registry is not None and self.name not in registry:
344+
registry.register_updated(self)
345+
346+
result = ExtensionResolutionResult()
347+
for op_def in self.operations.values():
348+
poly_func = op_def.signature.poly_func
349+
if poly_func is None:
350+
continue
351+
_, sig_result = poly_func._resolve_used_extensions(registry)
352+
result.extend(sig_result)
353+
354+
return result
355+
339356
@dataclass
340357
class OperationNotFound(NotFound):
341358
"""Operation not found in extension."""
@@ -564,3 +581,22 @@ def extend(self, other: ExtensionResolutionResult) -> None:
564581
self.unresolved_extensions.update(other.unresolved_extensions)
565582
self.unresolved_ops.update(other.unresolved_ops)
566583
self.unresolved_types.update(other.unresolved_types)
584+
585+
def _extend_with_transitive_ops(
586+
self, registry: ExtensionRegistry | None = None
587+
) -> None:
588+
"""Extend the set of extensions with transitive dependencies required by
589+
the OpDefs in each extension definition.
590+
"""
591+
queue: list[Extension] = list(self.used_extensions.extensions.values())
592+
593+
while queue:
594+
ext = queue.pop()
595+
op_result = ext._resolve_used_extensions(registry)
596+
597+
self.unresolved_extensions.update(op_result.unresolved_extensions)
598+
599+
for new_ext in op_result.used_extensions.extensions.values():
600+
if new_ext.name not in self.used_extensions:
601+
self.used_extensions.register_updated(new_ext)
602+
queue.append(new_ext)

hugr-py/src/hugr/hugr/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,6 +1307,8 @@ def used_extensions(
13071307
self[node].op = resolved_op
13081308
result.extend(op_result)
13091309

1310+
result._extend_with_transitive_ops(resolve_from)
1311+
13101312
return result
13111313

13121314
def resolve_extensions(self, registry: ext.ExtensionRegistry) -> Hugr:

hugr-py/src/hugr/ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ def _resolve_used_extensions(
8585
) -> tuple[Op, ExtensionResolutionResult]:
8686
"""Resolve the extensions required to define this operation.
8787
88+
Does not include transitive dependencies required by the returned
89+
extension definitions, to avoid infinite recursion.
90+
8891
Args:
8992
registry: A registry to resolve unresolved extensions from.
9093
@@ -103,6 +106,9 @@ def used_extensions(
103106
"""Get the extensions used by this operation, optionally resolving
104107
unresolved types and operations.
105108
109+
Includes any extension transitively required by the returned extension
110+
definitions.
111+
106112
Args:
107113
resolve_from: Optional extension registry to resolve against.
108114
If None, opaque types and Custom ops will not be resolved.
@@ -111,6 +117,7 @@ def used_extensions(
111117
The result containing used and unresolved extensions.
112118
"""
113119
_, result = self._resolve_used_extensions(resolve_from)
120+
result._extend_with_transitive_ops(resolve_from)
114121
return result
115122

116123

hugr-py/src/hugr/tys.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ def _resolve_used_extensions(
146146
) -> tuple[Type, ExtensionResolutionResult]:
147147
"""Resolve the extensions required to define this type.
148148
149+
Does not include transitive dependencies required by the returned
150+
extension definitions, to avoid infinite recursion.
151+
149152
Args:
150153
registry: A registry to resolve unresolved extensions from.
151154
If None, opaque types will not be resolved.

0 commit comments

Comments
 (0)