Skip to content

Approximation's sample method uses model contexts#7940

Merged
ricardoV94 merged 5 commits into
pymc-devs:v6from
zaxtax:removing_model_field_in_fit
Apr 17, 2026
Merged

Approximation's sample method uses model contexts#7940
ricardoV94 merged 5 commits into
pymc-devs:v6from
zaxtax:removing_model_field_in_fit

Conversation

@zaxtax
Copy link
Copy Markdown
Contributor

@zaxtax zaxtax commented Oct 29, 2025

This change makes it so calling sample on fitted approximations uses the model context or an explicitly provided model. Previously, this had a hard assumption that we only want to sample from the same model as the one we fitted.

This also deprecates self.model.

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7940.org.readthedocs.build/en/7940/

@codecov
Copy link
Copy Markdown

codecov Bot commented Oct 31, 2025

Codecov Report

❌ Patch coverage is 87.50000% with 2 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (v6@4e28043). Learn more about missing BASE report.

Files with missing lines Patch % Lines
pymc/variational/approximations.py 0.00% 2 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@          Coverage Diff          @@
##             v6    #7940   +/-   ##
=====================================
  Coverage      ?   91.71%           
=====================================
  Files         ?      124           
  Lines         ?    19967           
  Branches      ?        0           
=====================================
  Hits          ?    18313           
  Misses        ?     1654           
  Partials      ?        0           
Files with missing lines Coverage Δ
pymc/sampling/mcmc.py 87.99% <100.00%> (ø)
pymc/variational/opvi.py 87.42% <100.00%> (ø)
pymc/variational/approximations.py 90.36% <0.00%> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@zaxtax zaxtax requested a review from ricardoV94 October 31, 2025 01:48
@zaxtax
Copy link
Copy Markdown
Contributor Author

zaxtax commented Oct 31, 2025

This still needs some tests that the Deprecation warnings fire

@zaxtax zaxtax requested a review from jessegrabowski November 3, 2025 15:39
Comment thread pymc/variational/approximations.py Outdated
alias_names = frozenset(["mf"])

@node_property
@cached_property
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these are all just dictionary lookups (or construction of symbolic variables). I don't think it's necessary to cache them

Comment thread pymc/variational/approximations.py Outdated
Comment on lines +89 to +95
# Ensure start is a 1D array and matches ddim
start = np.asarray(start).flatten()
if start.size != self.ddim:
raise ValueError(
f"Start array size mismatch: got {start.size}, expected {self.ddim}. "
f"Start shape: {start.shape if hasattr(start, 'shape') else 'unknown'}"
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this check to _prepare_start

Comment thread pymc/variational/approximations.py Outdated
Comment thread pymc/variational/opvi.py Outdated
datalogp_norm = property(lambda self: self.approx.datalogp_norm)
logq_norm = property(lambda self: self.approx.logq_norm)
model = property(lambda self: self.approx.model)
model = property(lambda self: modelcontext(None))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you never refer to self.model now there's no need to store this

Comment thread pymc/variational/opvi.py Outdated
Comment on lines +891 to +893
# Clear old replacements/ordering before rebuilding
self.replacements = collections.OrderedDict()
self.ordering = collections.OrderedDict()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary? Was there a bug?

Comment thread pymc/variational/opvi.py Outdated
def collect(self, item):
return [getattr(g, item) for g in self.groups]

def _variational_orderings(self, model):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The next few functions are new but don't seem to be related to the refactor. Where did these come from?

Also they all take model as an argument, but you should use modelcontext here as well

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are all used in sample to make the logic easier to follow there. We can probably tweak things to have less new functions but that's what they are for.

Comment thread pymc/variational/opvi.py Outdated
Comment thread pymc/variational/opvi.py Outdated
samples: dict = self.sample_dict_fn(draws, model=model, random_seed=random_seed)
spec = self._build_trace_spec(model, samples)

from collections import OrderedDict
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move to top

@zaxtax zaxtax force-pushed the removing_model_field_in_fit branch from b4495b7 to 735d8d1 Compare March 20, 2026 00:26
@zaxtax zaxtax changed the base branch from main to v6 March 21, 2026 21:44
@zaxtax zaxtax force-pushed the removing_model_field_in_fit branch from e287859 to e41bb9f Compare March 21, 2026 21:45
@zaxtax
Copy link
Copy Markdown
Contributor Author

zaxtax commented Mar 21, 2026

@jessegrabowski I've tried to incorporate all requested changes. Anything else we need to get this over the line?

@zaxtax zaxtax force-pushed the removing_model_field_in_fit branch 4 times, most recently from ea840de to a64292c Compare April 16, 2026 02:31
Copy link
Copy Markdown
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dang you weren't kidding when you said you slimmed the PR down. Love it.

I approved but pending tests passing ofc

@zaxtax zaxtax force-pushed the removing_model_field_in_fit branch from a64292c to fe359ec Compare April 16, 2026 03:24
Comment thread tests/variational/test_inference.py Outdated
Comment on lines +460 to +465
with model:
x_test = [5, 6, 9, 12, 15]
pm.set_data(new_data={"x": x_test}, coords={"obs_id": list(range(len(x_test)))})
y_test = pm.sample_posterior_predictive(trace, predictions=True, progressbar=False)

assert y_test.predictions["obs"].shape == (1, 500, 5)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wouldn't this test have had the same output before your PR? thought you were interested in changing approx.sample() alongside the model (so show how it changes after you call set_data?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, can you confirm the test works as regression? Without the changes it failed

@zaxtax
Copy link
Copy Markdown
Contributor Author

zaxtax commented Apr 17, 2026

LGTM, can you confirm the test works as regression? Without the changes it failed

The test wasn't a regression. Now it is. Though this did require making more changes

@zaxtax
Copy link
Copy Markdown
Contributor Author

zaxtax commented Apr 17, 2026

@ricardoV94 does this still look alright?

Comment thread tests/variational/test_inference.py Outdated
assert data["mu"].shape == ()


def test_sample_posterior_predictive_after_set_data():
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

name is not so great anymore?

@zaxtax zaxtax force-pushed the removing_model_field_in_fit branch from 78b809a to b86f265 Compare April 17, 2026 16:19
@ricardoV94 ricardoV94 added maintenance VI Variational Inference labels Apr 17, 2026
@ricardoV94 ricardoV94 merged commit 6182801 into pymc-devs:v6 Apr 17, 2026
42 checks passed
@ricardoV94
Copy link
Copy Markdown
Member

Thanks @zaxtax

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

maintenance VI Variational Inference

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants