Approximation's sample method uses model contexts#7940
Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## v6 #7940 +/- ##
=====================================
Coverage ? 91.71%
=====================================
Files ? 124
Lines ? 19967
Branches ? 0
=====================================
Hits ? 18313
Misses ? 1654
Partials ? 0
🚀 New features to boost your workflow:
|
|
This still needs some tests that the Deprecation warnings fire |
| alias_names = frozenset(["mf"]) | ||
|
|
||
| @node_property | ||
| @cached_property |
There was a problem hiding this comment.
these are all just dictionary lookups (or construction of symbolic variables). I don't think it's necessary to cache them
| # Ensure start is a 1D array and matches ddim | ||
| start = np.asarray(start).flatten() | ||
| if start.size != self.ddim: | ||
| raise ValueError( | ||
| f"Start array size mismatch: got {start.size}, expected {self.ddim}. " | ||
| f"Start shape: {start.shape if hasattr(start, 'shape') else 'unknown'}" | ||
| ) |
There was a problem hiding this comment.
Move this check to _prepare_start
| datalogp_norm = property(lambda self: self.approx.datalogp_norm) | ||
| logq_norm = property(lambda self: self.approx.logq_norm) | ||
| model = property(lambda self: self.approx.model) | ||
| model = property(lambda self: modelcontext(None)) |
There was a problem hiding this comment.
If you never refer to self.model now there's no need to store this
| # Clear old replacements/ordering before rebuilding | ||
| self.replacements = collections.OrderedDict() | ||
| self.ordering = collections.OrderedDict() |
There was a problem hiding this comment.
Is this necessary? Was there a bug?
| def collect(self, item): | ||
| return [getattr(g, item) for g in self.groups] | ||
|
|
||
| def _variational_orderings(self, model): |
There was a problem hiding this comment.
The next few functions are new but don't seem to be related to the refactor. Where did these come from?
Also they all take model as an argument, but you should use modelcontext here as well
There was a problem hiding this comment.
These are all used in sample to make the logic easier to follow there. We can probably tweak things to have less new functions but that's what they are for.
| samples: dict = self.sample_dict_fn(draws, model=model, random_seed=random_seed) | ||
| spec = self._build_trace_spec(model, samples) | ||
|
|
||
| from collections import OrderedDict |
b4495b7 to
735d8d1
Compare
Documentation build overview
221 files changed ·
|
e287859 to
e41bb9f
Compare
|
@jessegrabowski I've tried to incorporate all requested changes. Anything else we need to get this over the line? |
ea840de to
a64292c
Compare
a64292c to
fe359ec
Compare
| with model: | ||
| x_test = [5, 6, 9, 12, 15] | ||
| pm.set_data(new_data={"x": x_test}, coords={"obs_id": list(range(len(x_test)))}) | ||
| y_test = pm.sample_posterior_predictive(trace, predictions=True, progressbar=False) | ||
|
|
||
| assert y_test.predictions["obs"].shape == (1, 500, 5) |
There was a problem hiding this comment.
wouldn't this test have had the same output before your PR? thought you were interested in changing approx.sample() alongside the model (so show how it changes after you call set_data?
ricardoV94
left a comment
There was a problem hiding this comment.
LGTM, can you confirm the test works as regression? Without the changes it failed
The test wasn't a regression. Now it is. Though this did require making more changes |
|
@ricardoV94 does this still look alright? |
| assert data["mu"].shape == () | ||
|
|
||
|
|
||
| def test_sample_posterior_predictive_after_set_data(): |
There was a problem hiding this comment.
name is not so great anymore?
78b809a to
b86f265
Compare
|
Thanks @zaxtax |
This change makes it so calling
sampleon fitted approximations uses the model context or an explicitly provided model. Previously, this had a hard assumption that we only want to sample from the same model as the one we fitted.This also deprecates
self.model.Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7940.org.readthedocs.build/en/7940/