{"title":"Brandon T. Willard","link":[{"@attributes":{"href":"https:\/\/brandonwillard.github.io\/","rel":"alternate"}},{"@attributes":{"href":"https:\/\/brandonwillard.github.io\/feeds\/all.atom.xml","rel":"self"}}],"id":"https:\/\/brandonwillard.github.io\/","updated":"2022-01-13T00:00:00-06:00","entry":[{"title":"Dynamic Linear Models in Theano","link":{"@attributes":{"href":"https:\/\/brandonwillard.github.io\/dynamic-linear-models-in-theano.html","rel":"alternate"}},"published":"2020-03-18T00:00:00-05:00","updated":"2020-05-08T00:00:00-05:00","author":{"name":"Brandon T. Willard"},"id":"tag:brandonwillard.github.io,2020-03-18:\/dynamic-linear-models-in-theano.html","summary":{"@attributes":{"type":"html"}},"content":"<!DOCTYPE html PUBLIC \"-\/\/W3C\/\/DTD XHTML 1.0 Transitional\/\/EN\" \"http:\/\/www.w3.org\/TR\/xhtml1\/DTD\/xhtml1-transitional.dtd\">\n<html xmlns=\"http:\/\/www.w3.org\/1999\/xhtml\">\n<head>\n  <meta http-equiv=\"Content-Type\" content=\"text\/html; charset=utf-8\" \/>\n  <meta http-equiv=\"Content-Style-Type\" content=\"text\/css\" \/>\n  <meta name=\"generator\" content=\"pandoc\" \/>\n  <meta name=\"author\" content=\"Brandon T. 
Willard\" \/>\n  <title>Dynamic Linear Models in Theano<\/title>\n  <style type=\"text\/css\">code{white-space: pre;}<\/style>\n  <style type=\"text\/css\">\npre > code.sourceCode { white-space: pre; position: relative; }\npre > code.sourceCode > span { display: inline-block; line-height: 1.25; }\npre > code.sourceCode > span:empty { height: 1.2em; }\ncode.sourceCode > span { color: inherit; text-decoration: inherit; }\ndiv.sourceCode { margin: 1em 0; }\npre.sourceCode { margin: 0; }\n@media screen {\ndiv.sourceCode { overflow: auto; }\n}\n@media print {\npre > code.sourceCode { white-space: pre-wrap; }\npre > code.sourceCode > span { text-indent: -5em; padding-left: 5em; }\n}\npre.numberSource code\n  { counter-reset: source-line 0; }\npre.numberSource code > span\n  { position: relative; left: -4em; counter-increment: source-line; }\npre.numberSource code > span > a:first-child::before\n  { content: counter(source-line);\n    position: relative; left: -1em; text-align: right; vertical-align: baseline;\n    border: none; display: inline-block;\n    -webkit-touch-callout: none; -webkit-user-select: none;\n    -khtml-user-select: none; -moz-user-select: none;\n    -ms-user-select: none; user-select: none;\n    padding: 0 4px; width: 4em;\n    color: #aaaaaa;\n  }\npre.numberSource { margin-left: 3em; border-left: 1px solid #aaaaaa;  padding-left: 4px; }\ndiv.sourceCode\n  {   }\n@media screen {\npre > code.sourceCode > span > a:first-child::before { text-decoration: underline; }\n}\ncode span.al { color: #ff0000; font-weight: bold; } \/* Alert *\/\ncode span.an { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Annotation *\/\ncode span.at { color: #7d9029; } \/* Attribute *\/\ncode span.bn { color: #40a070; } \/* BaseN *\/\ncode span.bu { } \/* BuiltIn *\/\ncode span.cf { color: #007020; font-weight: bold; } \/* ControlFlow *\/\ncode span.ch { color: #4070a0; } \/* Char *\/\ncode span.cn { color: #880000; } \/* Constant *\/\ncode span.co { color: 
#60a0b0; font-style: italic; } \/* Comment *\/\ncode span.cv { color: #60a0b0; font-weight: bold; font-style: italic; } \/* CommentVar *\/\ncode span.do { color: #ba2121; font-style: italic; } \/* Documentation *\/\ncode span.dt { color: #902000; } \/* DataType *\/\ncode span.dv { color: #40a070; } \/* DecVal *\/\ncode span.er { color: #ff0000; font-weight: bold; } \/* Error *\/\ncode span.ex { } \/* Extension *\/\ncode span.fl { color: #40a070; } \/* Float *\/\ncode span.fu { color: #06287e; } \/* Function *\/\ncode span.im { } \/* Import *\/\ncode span.in { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Information *\/\ncode span.kw { color: #007020; font-weight: bold; } \/* Keyword *\/\ncode span.op { color: #666666; } \/* Operator *\/\ncode span.ot { color: #007020; } \/* Other *\/\ncode span.pp { color: #bc7a00; } \/* Preprocessor *\/\ncode span.sc { color: #4070a0; } \/* SpecialChar *\/\ncode span.ss { color: #bb6688; } \/* SpecialString *\/\ncode span.st { color: #4070a0; } \/* String *\/\ncode span.va { color: #19177c; } \/* Variable *\/\ncode span.vs { color: #4070a0; } \/* VerbatimString *\/\ncode span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Warning *\/\n  <\/style>\n  <!--        <script src=\"https:\/\/cdn.jsdelivr.net\/npm\/mathjax@3\/es5\/tex-mml-chtml.js\" type=\"text\/javascript\"><\/script> -->\n  <script src=\"https:\/\/cdnjs.cloudflare.com\/ajax\/libs\/mathjax\/2.7.0\/MathJax.js?config=TeX-AMS_HTML\" id=\"MathJax-script\"><\/script>\n  <script>\n   MathJax.Hub.Config({\n       tex2jax: {\n           processEnvironments: true,\n           processRefs: false\n       },\n       TeX: {\n           equationNumbers: { autoNumber: \"AMS\" },\n           extensions: [\"AMSmath.js\",\"AMSsymbols.js\",\"noErrors.js\",\"noUndefined.js\"]\n       }\n   });\n  <\/script>\n<\/head>\n<body>\n<!--  -->\n<!-- <div id=\"header\"> -->\n<!-- <h1 class=\"title\">Dynamic Linear Models in Theano<\/h1> -->\n<!--  -->\n<!--  
-->\n<!-- <h2 class=\"author\">Brandon T. Willard<\/h2> -->\n<!--  -->\n<!--  -->\n<!-- <h3 class=\"date\">2020\u201303\u201318<\/h3> -->\n<!--  -->\n<!-- <\/div> -->\n<!--  -->\n<ul>\n<li><a href=\"#org8ae0857\">Introduction<\/a><\/li>\n<li><a href=\"#org8873a63\">Analytic Posteriors<\/a><\/li>\n<li><a href=\"#orgefb3ef8\">Posterior Estimation<\/a>\n<ul>\n<li><a href=\"#org1d9986d\">SVD-based Filtering<\/a><\/li>\n<li><a href=\"#org350fa8d\">SVD-based Smoothing<\/a><\/li>\n<li><a href=\"#org9beb158\">Example<\/a><\/li>\n<\/ul><\/li>\n<li><a href=\"#orgaa30fc5\">Forward-filtering Backward-sampling<\/a>\n<ul>\n<li><a href=\"#org68ce824\">Simulation Example<\/a><\/li>\n<\/ul><\/li>\n<li><a href=\"#orge6290b1\">Non-Gaussian Extension<\/a>\n<ul>\n<li><a href=\"#orgeef4d98\">Simulation Example<\/a><\/li>\n<\/ul><\/li>\n<li><a href=\"#orgd705951\">Augmented FFBS Sampler<\/a>\n<ul>\n<li><a href=\"#org1057f5a\">Simulation Example<\/a><\/li>\n<li><a href=\"#orgf74c6fe\">COVID-19 Example<\/a><\/li>\n<\/ul><\/li>\n<li><a href=\"#org24da3be\">Discussion<\/a><\/li>\n<\/ul>\n<div class=\"abstract\">\n<p>In this document, we detail how dynamic linear models (DLMs) can be implemented in Theano (or similar tensor libraries), along with a complementary Rao-Blackwellized sampler tailored to the structure of DLMs. Furthermore, we provide an example of how DLMs can be extended, via Gaussian scale mixtures, to model non-Gaussian observations. 
The example is a Po\u0301lya-Gamma extension to forward-filtering backward-sampling for negative-binomial observations.<\/p>\n<\/div>\n<p><a id=\"org8ae0857\"><\/a><\/p>\n<section id=\"introduction\" class=\"level1\">\n<h1>Introduction<\/h1>\n<p>For a proper introduction to dynamic linear models and their estimation, see <a id=\"4bbd465b4e78e5c5151b0cbba54d984e\"><a href=\"#harrison_bayesian_1999\">Harrison &amp; West (1999)<\/a><\/a>, <a id=\"1c3f471fd137724bd53b01eb4d6534fe\"><a href=\"#PetrisDynamicLinearModels2009\">Petris, Petrone &amp; Campagnoli (2009)<\/a><\/a>, and <a id=\"8247a823e73eba801aad2942a49b03be\"><a href=\"#GamermanMarkovchainMonte2006\">Gamerman &amp; Lopes (2006)<\/a><\/a>, <a id=\"c6a5d83a82b5963f42264d85625b5153\"><a href=\"#pole_applied_1994\">Pole, West &amp; Harrison (1994)<\/a><\/a>. This document focuses on practical details of DLM estimation in Python, particularly with libraries like Theano <a id=\"25fdcab375e353d7bb32eec7b13064c6\"><a href=\"#bergstra_theano:_2010\">(Bergstra, Breuleux, Bastien, Lamblin, Pascanu, Desjardins, Turian, Warde-Farley &amp; Bengio 2010)<\/a><\/a>. Throughout, we use some custom random variable objects from <code>symbolic-pymc<\/code> <a id=\"e8fa26b92264fca946be25e6b617fc56\"><a href=\"#Willardsymbolicpymc2019\">(Willard 2019)<\/a><\/a>. This is largely because <code>symbolic-pymc<\/code> provides a non-standard distribution that we need, and its random variables work well within the <code>scan<\/code> operations we employ.<\/p>\n<p>In the future, we plan to build optimizations that automate much of the work done here (and more). This exposition lays the foundation for such work by first motivating the use and generality of Bayesian frameworks like DLMs and then demonstrating the analytic steps that produce customized, efficient samplers. 
These are the steps we intend to automate.<\/p>\n<p>A DLM with prior <span class=\"math inline\">\\(\\theta_0 \\sim \\operatorname{N}\\left( m_0, C_0 \\right)\\)<\/span> is defined as follows:<\/p>\n<p><span class=\"math display\">\\[\\begin{align}\n  y_t &amp;= F_t^{\\top} \\theta_{t} + \\epsilon_t, \\quad \\epsilon_t \\sim \\operatorname{N}\\left( 0, V_t \\right)\n  \\label{eq:basic-dlm-obs}\n  \\\\\n  \\theta_t &amp;= G_t \\theta_{t-1} + \\nu_t, \\quad \\nu_t \\sim \\operatorname{N}\\left( 0, W_t \\right)\n  \\label{eq:basic-dlm-state}\n\\end{align}\\]<\/span><\/p>\n<p>for <span class=\"math inline\">\\(t \\in \\{1, \\dots, T\\}\\)<\/span>, <span class=\"math inline\">\\(y_t \\in \\mathbb{R}\\)<\/span>, and <span class=\"math inline\">\\(\\theta_t \\in \\mathbb{R}^{M}\\)<\/span>.<\/p>\n<p>The most \u201cnotationally\u201d faithful Theano representation of the time-series model in <span class=\"math inline\">\\(\\eqref{eq:basic-dlm-obs}\\)<\/span> and <span class=\"math inline\">\\(\\eqref{eq:basic-dlm-state}\\)<\/span> is provided in Listing <a href=\"#orge6ada13\">2<\/a>. 
It represents the notion of a recursion\u2013to the best of Theano\u2019s ability\u2013by way of the <code>scan<\/code> operator.<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb1\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb1-1\"><a href=\"#cb1-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> numpy <span class=\"im\">as<\/span> np<\/span>\n<span id=\"cb1-2\"><a href=\"#cb1-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-3\"><a href=\"#cb1-3\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> theano<\/span>\n<span id=\"cb1-4\"><a href=\"#cb1-4\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> theano.tensor <span class=\"im\">as<\/span> tt<\/span>\n<span id=\"cb1-5\"><a href=\"#cb1-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-6\"><a href=\"#cb1-6\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> matplotlib.pyplot <span class=\"im\">as<\/span> plt<\/span>\n<span id=\"cb1-7\"><a href=\"#cb1-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-8\"><a href=\"#cb1-8\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> cycler <span class=\"im\">import<\/span> cycler<\/span>\n<span id=\"cb1-9\"><a href=\"#cb1-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-10\"><a href=\"#cb1-10\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> matplotlib.collections <span class=\"im\">import<\/span> LineCollection<\/span>\n<span id=\"cb1-11\"><a href=\"#cb1-11\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-12\"><a href=\"#cb1-12\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano.printing <span class=\"im\">import<\/span> debugprint <span class=\"im\">as<\/span> tt_dprint<\/span>\n<span id=\"cb1-13\"><a href=\"#cb1-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-14\"><a href=\"#cb1-14\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> symbolic_pymc.theano.random_variables <span class=\"im\">import<\/span> NormalRV, 
MvNormalRV, GammaRV<\/span>\n<span id=\"cb1-15\"><a href=\"#cb1-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-16\"><a href=\"#cb1-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-17\"><a href=\"#cb1-17\" aria-hidden=\"true\"><\/a>plt.style.use(<span class=\"st\">&#39;ggplot&#39;<\/span>)<\/span>\n<span id=\"cb1-18\"><a href=\"#cb1-18\" aria-hidden=\"true\"><\/a>plt_orig_cycler <span class=\"op\">=<\/span> plt.rcParams[<span class=\"st\">&#39;axes.prop_cycle&#39;<\/span>]<\/span>\n<span id=\"cb1-19\"><a href=\"#cb1-19\" aria-hidden=\"true\"><\/a>plt.rc(<span class=\"st\">&#39;text&#39;<\/span>, usetex<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb1-20\"><a href=\"#cb1-20\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-21\"><a href=\"#cb1-21\" aria-hidden=\"true\"><\/a><span class=\"co\"># theano.config.cxx = &quot;&quot;<\/span><\/span>\n<span id=\"cb1-22\"><a href=\"#cb1-22\" aria-hidden=\"true\"><\/a><span class=\"co\"># theano.config.mode = &quot;FAST_COMPILE&quot;<\/span><\/span>\n<span id=\"cb1-23\"><a href=\"#cb1-23\" aria-hidden=\"true\"><\/a>tt.config.compute_test_value <span class=\"op\">=<\/span> <span class=\"st\">&#39;ignore&#39;<\/span><\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure id=\"orge6ada13\">\n<div class=\"sourceCode\" id=\"cb2\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb2-1\"><a href=\"#cb2-1\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-2\"><a href=\"#cb2-2\" aria-hidden=\"true\"><\/a>N_obs_tt <span class=\"op\">=<\/span> tt.iscalar(<span class=\"st\">&quot;N_obs&quot;<\/span>)<\/span>\n<span id=\"cb2-3\"><a href=\"#cb2-3\" aria-hidden=\"true\"><\/a>N_theta_tt <span class=\"op\">=<\/span> tt.iscalar(<span class=\"st\">&quot;N_theta&quot;<\/span>)<\/span>\n<span id=\"cb2-4\"><a href=\"#cb2-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-5\"><a href=\"#cb2-5\" aria-hidden=\"true\"><\/a>G_tt <span class=\"op\">=<\/span> 
tt.specify_shape(tt.matrix(), [N_theta_tt, N_theta_tt])<\/span>\n<span id=\"cb2-6\"><a href=\"#cb2-6\" aria-hidden=\"true\"><\/a>G_tt.name <span class=\"op\">=<\/span> <span class=\"st\">&#39;G_t&#39;<\/span><\/span>\n<span id=\"cb2-7\"><a href=\"#cb2-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-8\"><a href=\"#cb2-8\" aria-hidden=\"true\"><\/a>F_tt <span class=\"op\">=<\/span> tt.specify_shape(tt.col(), [N_theta_tt, <span class=\"dv\">1<\/span>])<\/span>\n<span id=\"cb2-9\"><a href=\"#cb2-9\" aria-hidden=\"true\"><\/a>F_tt.name <span class=\"op\">=<\/span> <span class=\"st\">&#39;F_t&#39;<\/span><\/span>\n<span id=\"cb2-10\"><a href=\"#cb2-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-11\"><a href=\"#cb2-11\" aria-hidden=\"true\"><\/a>rng_state <span class=\"op\">=<\/span> np.random.RandomState(np.random.MT19937(np.random.SeedSequence(<span class=\"dv\">1234<\/span>)))<\/span>\n<span id=\"cb2-12\"><a href=\"#cb2-12\" aria-hidden=\"true\"><\/a>rng_init_state <span class=\"op\">=<\/span> rng_state.get_state()<\/span>\n<span id=\"cb2-13\"><a href=\"#cb2-13\" aria-hidden=\"true\"><\/a>rng_tt <span class=\"op\">=<\/span> theano.shared(rng_state, name<span class=\"op\">=<\/span><span class=\"st\">&#39;rng&#39;<\/span>, borrow<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb2-14\"><a href=\"#cb2-14\" aria-hidden=\"true\"><\/a>rng_tt.tag.is_rng <span class=\"op\">=<\/span> <span class=\"va\">True<\/span><\/span>\n<span id=\"cb2-15\"><a href=\"#cb2-15\" aria-hidden=\"true\"><\/a>rng_tt.default_update <span class=\"op\">=<\/span> rng_tt<\/span>\n<span id=\"cb2-16\"><a href=\"#cb2-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-17\"><a href=\"#cb2-17\" aria-hidden=\"true\"><\/a>m_0_tt <span class=\"op\">=<\/span> tt.zeros([N_theta_tt])<\/span>\n<span id=\"cb2-18\"><a href=\"#cb2-18\" aria-hidden=\"true\"><\/a>m_0_tt.name <span class=\"op\">=<\/span> <span class=\"st\">&quot;m_0&quot;<\/span><\/span>\n<span 
id=\"cb2-19\"><a href=\"#cb2-19\" aria-hidden=\"true\"><\/a>C_0_tt <span class=\"op\">=<\/span> <span class=\"fl\">1000.0<\/span> <span class=\"op\">*<\/span> tt.eye(N_theta_tt)<\/span>\n<span id=\"cb2-20\"><a href=\"#cb2-20\" aria-hidden=\"true\"><\/a>C_0_tt.name <span class=\"op\">=<\/span> <span class=\"st\">&quot;C_0&quot;<\/span><\/span>\n<span id=\"cb2-21\"><a href=\"#cb2-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-22\"><a href=\"#cb2-22\" aria-hidden=\"true\"><\/a>theta_0_rv <span class=\"op\">=<\/span> MvNormalRV(m_0_tt, C_0_tt, rng<span class=\"op\">=<\/span>rng_tt, name<span class=\"op\">=<\/span><span class=\"st\">&#39;theta_0&#39;<\/span>)<\/span>\n<span id=\"cb2-23\"><a href=\"#cb2-23\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-24\"><a href=\"#cb2-24\" aria-hidden=\"true\"><\/a>phi_W_true <span class=\"op\">=<\/span> np.r_[<span class=\"fl\">5.0<\/span>, <span class=\"fl\">10.0<\/span>]<\/span>\n<span id=\"cb2-25\"><a href=\"#cb2-25\" aria-hidden=\"true\"><\/a>phi_W_tt <span class=\"op\">=<\/span> theano.shared(phi_W_true, name<span class=\"op\">=<\/span><span class=\"st\">&#39;phi_W&#39;<\/span>)<\/span>\n<span id=\"cb2-26\"><a href=\"#cb2-26\" aria-hidden=\"true\"><\/a>W_tt <span class=\"op\">=<\/span> tt.eye(N_theta_tt) <span class=\"op\">*<\/span> tt.inv(phi_W_tt)<\/span>\n<span id=\"cb2-27\"><a href=\"#cb2-27\" aria-hidden=\"true\"><\/a>W_tt.name <span class=\"op\">=<\/span> <span class=\"st\">&quot;W_t&quot;<\/span><\/span>\n<span id=\"cb2-28\"><a href=\"#cb2-28\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-29\"><a href=\"#cb2-29\" aria-hidden=\"true\"><\/a>phi_V_true <span class=\"op\">=<\/span> <span class=\"fl\">0.1<\/span><\/span>\n<span id=\"cb2-30\"><a href=\"#cb2-30\" aria-hidden=\"true\"><\/a>phi_V_tt <span class=\"op\">=<\/span> theano.shared(phi_V_true, name<span class=\"op\">=<\/span><span class=\"st\">&#39;phi_V&#39;<\/span>)<\/span>\n<span id=\"cb2-31\"><a href=\"#cb2-31\" aria-hidden=\"true\"><\/a>V_tt 
<span class=\"op\">=<\/span> tt.inv(phi_V_tt)<\/span>\n<span id=\"cb2-32\"><a href=\"#cb2-32\" aria-hidden=\"true\"><\/a>V_tt.name <span class=\"op\">=<\/span> <span class=\"st\">&quot;V_t&quot;<\/span><\/span>\n<span id=\"cb2-33\"><a href=\"#cb2-33\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-34\"><a href=\"#cb2-34\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> state_step(theta_tm1, G_t, W_t, N_theta, rng):<\/span>\n<span id=\"cb2-35\"><a href=\"#cb2-35\" aria-hidden=\"true\"><\/a>    nu_rv <span class=\"op\">=<\/span> MvNormalRV(tt.zeros([N_theta]), W_t, rng<span class=\"op\">=<\/span>rng, name<span class=\"op\">=<\/span><span class=\"st\">&#39;nu&#39;<\/span>)<\/span>\n<span id=\"cb2-36\"><a href=\"#cb2-36\" aria-hidden=\"true\"><\/a>    theta_t <span class=\"op\">=<\/span> G_t.dot(theta_tm1) <span class=\"op\">+<\/span> nu_rv<\/span>\n<span id=\"cb2-37\"><a href=\"#cb2-37\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> theta_t<\/span>\n<span id=\"cb2-38\"><a href=\"#cb2-38\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-39\"><a href=\"#cb2-39\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-40\"><a href=\"#cb2-40\" aria-hidden=\"true\"><\/a>theta_t_rv, theta_t_updates <span class=\"op\">=<\/span> theano.scan(fn<span class=\"op\">=<\/span>state_step,<\/span>\n<span id=\"cb2-41\"><a href=\"#cb2-41\" aria-hidden=\"true\"><\/a>                                          outputs_info<span class=\"op\">=<\/span>{<span class=\"st\">&quot;initial&quot;<\/span>: theta_0_rv, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb2-42\"><a href=\"#cb2-42\" aria-hidden=\"true\"><\/a>                                          non_sequences<span class=\"op\">=<\/span>[G_tt, W_tt, N_theta_tt, rng_tt],<\/span>\n<span id=\"cb2-43\"><a href=\"#cb2-43\" aria-hidden=\"true\"><\/a>                                          n_steps<span 
class=\"op\">=<\/span>N_obs_tt,<\/span>\n<span id=\"cb2-44\"><a href=\"#cb2-44\" aria-hidden=\"true\"><\/a>                                          strict<span class=\"op\">=<\/span><span class=\"va\">True<\/span>,<\/span>\n<span id=\"cb2-45\"><a href=\"#cb2-45\" aria-hidden=\"true\"><\/a>                                          name<span class=\"op\">=<\/span><span class=\"st\">&#39;theta_t&#39;<\/span>)<\/span>\n<span id=\"cb2-46\"><a href=\"#cb2-46\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-47\"><a href=\"#cb2-47\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> obs_step(theta_t, F_t, V_t, rng):<\/span>\n<span id=\"cb2-48\"><a href=\"#cb2-48\" aria-hidden=\"true\"><\/a>    eps_rv <span class=\"op\">=<\/span> NormalRV(<span class=\"fl\">0.0<\/span>, tt.sqrt(V_t), rng<span class=\"op\">=<\/span>rng, name<span class=\"op\">=<\/span><span class=\"st\">&#39;eps&#39;<\/span>)<\/span>\n<span id=\"cb2-49\"><a href=\"#cb2-49\" aria-hidden=\"true\"><\/a>    y_t <span class=\"op\">=<\/span> F_t.T.dot(theta_t) <span class=\"op\">+<\/span> eps_rv<\/span>\n<span id=\"cb2-50\"><a href=\"#cb2-50\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> y_t<\/span>\n<span id=\"cb2-51\"><a href=\"#cb2-51\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-52\"><a href=\"#cb2-52\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-53\"><a href=\"#cb2-53\" aria-hidden=\"true\"><\/a>Y_t_rv, Y_t_updates <span class=\"op\">=<\/span> theano.scan(fn<span class=\"op\">=<\/span>obs_step,<\/span>\n<span id=\"cb2-54\"><a href=\"#cb2-54\" aria-hidden=\"true\"><\/a>                                  sequences<span class=\"op\">=<\/span>[theta_t_rv],<\/span>\n<span id=\"cb2-55\"><a href=\"#cb2-55\" aria-hidden=\"true\"><\/a>                                  non_sequences<span class=\"op\">=<\/span>[F_tt, V_tt, rng_tt],<\/span>\n<span id=\"cb2-56\"><a href=\"#cb2-56\" aria-hidden=\"true\"><\/a>                                  strict<span class=\"op\">=<\/span><span 
class=\"va\">True<\/span>,<\/span>\n<span id=\"cb2-57\"><a href=\"#cb2-57\" aria-hidden=\"true\"><\/a>                                  name<span class=\"op\">=<\/span><span class=\"st\">&#39;Y_t&#39;<\/span>)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 2\n<\/figcaption>\n<\/figure>\n<p>Throughout, we\u2019ll use data sampled from the DLM in <span class=\"math inline\">\\(\\eqref{eq:basic-dlm-obs}\\)<\/span> and <span class=\"math inline\">\\(\\eqref{eq:basic-dlm-state}\\)<\/span> for demonstration purposes. Our simulation has the following model parameter values, where <span class=\"math inline\">\\(\\phi_W\\)<\/span> and <span class=\"math inline\">\\(\\phi_V\\)<\/span> are precisions (i.e. inverse variances), matching Listing <a href=\"#orge6ada13\">2<\/a>:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{gathered}\n    T = 200,\\quad M = 2\n    \\\\\n    W_t = \\operatorname{diag}\\left( \\phi_W \\right)^{-1}\n    ,\\quad\n    V_t = \\phi_V^{-1}\n    \\\\\n    \\phi_W = \\left(5, 10\\right),\\quad \\phi_V = 0.1\n    \\\\\n    G_t = \\begin{pmatrix}\n    1 &amp; 0.1 \\\\\n    0 &amp; 1 \\\\\n    \\end{pmatrix},\\quad\n    F_t = \\begin{pmatrix}\n    1 \\\\\n    0\n    \\end{pmatrix}\n    \\\\\n    \\theta_0 = \\begin{pmatrix}\n    0 \\\\\n    0\n    \\end{pmatrix}\n  \\end{gathered}\n  \\label{eq:sim-settings}\n\\end{equation}\\]<\/span><\/p>\n<p>A sample from <span class=\"math inline\">\\(\\eqref{eq:sim-settings}\\)<\/span> is generated in Listing <a href=\"#org721f0b8\">3<\/a>.<\/p>\n<figure id=\"org721f0b8\">\n<div class=\"sourceCode\" id=\"cb3\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb3-1\"><a href=\"#cb3-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano <span class=\"im\">import<\/span> function <span class=\"im\">as<\/span> tt_function<\/span>\n<span id=\"cb3-2\"><a href=\"#cb3-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-3\"><a href=\"#cb3-3\" aria-hidden=\"true\"><\/a>dlm_sim_values <span class=\"op\">=<\/span> {<\/span>\n<span id=\"cb3-4\"><a href=\"#cb3-4\" aria-hidden=\"true\"><\/a>    N_obs_tt: <span class=\"dv\">200<\/span>,<\/span>\n<span id=\"cb3-5\"><a href=\"#cb3-5\" aria-hidden=\"true\"><\/a>    
N_theta_tt: <span class=\"dv\">2<\/span>,<\/span>\n<span id=\"cb3-6\"><a href=\"#cb3-6\" aria-hidden=\"true\"><\/a>    G_tt: np.r_[<span class=\"st\">&#39;0,2&#39;<\/span>,<\/span>\n<span id=\"cb3-7\"><a href=\"#cb3-7\" aria-hidden=\"true\"><\/a>                [<span class=\"fl\">1.0<\/span>, <span class=\"fl\">0.1<\/span>],<\/span>\n<span id=\"cb3-8\"><a href=\"#cb3-8\" aria-hidden=\"true\"><\/a>                [<span class=\"fl\">0.0<\/span>, <span class=\"fl\">1.0<\/span>]].astype(tt.config.floatX),<\/span>\n<span id=\"cb3-9\"><a href=\"#cb3-9\" aria-hidden=\"true\"><\/a>    F_tt: np.r_[[[<span class=\"fl\">1.0<\/span>],<\/span>\n<span id=\"cb3-10\"><a href=\"#cb3-10\" aria-hidden=\"true\"><\/a>                 [<span class=\"fl\">0.0<\/span>]]].astype(tt.config.floatX)<\/span>\n<span id=\"cb3-11\"><a href=\"#cb3-11\" aria-hidden=\"true\"><\/a>}<\/span>\n<span id=\"cb3-12\"><a href=\"#cb3-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-13\"><a href=\"#cb3-13\" aria-hidden=\"true\"><\/a>rng_tt.get_value(borrow<span class=\"op\">=<\/span><span class=\"va\">True<\/span>).set_state(rng_init_state)<\/span>\n<span id=\"cb3-14\"><a href=\"#cb3-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-15\"><a href=\"#cb3-15\" aria-hidden=\"true\"><\/a>simulate_dlm <span class=\"op\">=<\/span> tt_function([N_obs_tt, N_theta_tt, G_tt, F_tt],<\/span>\n<span id=\"cb3-16\"><a href=\"#cb3-16\" aria-hidden=\"true\"><\/a>                           [Y_t_rv, theta_t_rv],<\/span>\n<span id=\"cb3-17\"><a href=\"#cb3-17\" aria-hidden=\"true\"><\/a>                           givens<span class=\"op\">=<\/span>{theta_0_rv: np.r_[<span class=\"fl\">0.0<\/span>, <span class=\"fl\">0.0<\/span>]},<\/span>\n<span id=\"cb3-18\"><a href=\"#cb3-18\" aria-hidden=\"true\"><\/a>                           updates<span class=\"op\">=<\/span>Y_t_updates)<\/span>\n<span id=\"cb3-19\"><a href=\"#cb3-19\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-20\"><a href=\"#cb3-20\" 
aria-hidden=\"true\"><\/a>y_sim, theta_t_sim <span class=\"op\">=<\/span> simulate_dlm(dlm_sim_values[N_obs_tt], dlm_sim_values[N_theta_tt], dlm_sim_values[G_tt], dlm_sim_values[F_tt])<\/span>\n<span id=\"cb3-21\"><a href=\"#cb3-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-22\"><a href=\"#cb3-22\" aria-hidden=\"true\"><\/a>rng_sim_state <span class=\"op\">=<\/span> rng_tt.get_value(borrow<span class=\"op\">=<\/span><span class=\"va\">True<\/span>).get_state()<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 3\n<\/figcaption>\n<\/figure>\n<p>In Figure <a href=\"#org04355d9\">4<\/a> we plot a sample from the model in Listing <a href=\"#orge6ada13\">2<\/a> for a fixed RNG seed.<\/p>\n<figure id=\"org04355d9\">\n<div class=\"sourceCode\" id=\"cb4\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb4-1\"><a href=\"#cb4-1\" aria-hidden=\"true\"><\/a>plt.clf()<\/span>\n<span id=\"cb4-2\"><a href=\"#cb4-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb4-3\"><a href=\"#cb4-3\" aria-hidden=\"true\"><\/a>fig, ax <span class=\"op\">=<\/span> plt.subplots(figsize<span class=\"op\">=<\/span>(<span class=\"dv\">8<\/span>, <span class=\"fl\">4.8<\/span>))<\/span>\n<span id=\"cb4-4\"><a href=\"#cb4-4\" aria-hidden=\"true\"><\/a>_ <span class=\"op\">=<\/span> ax.plot(y_sim, label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$y_t$&#39;<\/span>, color<span class=\"op\">=<\/span><span class=\"st\">&#39;black&#39;<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">0.7<\/span>)<\/span>\n<span id=\"cb4-5\"><a href=\"#cb4-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb4-6\"><a href=\"#cb4-6\" aria-hidden=\"true\"><\/a>plt.tight_layout()<\/span>\n<span id=\"cb4-7\"><a href=\"#cb4-7\" aria-hidden=\"true\"><\/a>plt.legend()<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 4\n<\/figcaption>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img 
src=\"https:\/\/brandonwillard.github.io\/figures\/basic-dlm-sim-plot.png\" title=\"fig:\" alt=\"\" \/>\n<figcaption>\n<\/figcaption>\n<\/figure>\n<p><a id=\"org8873a63\"><\/a><\/p>\n<\/section>\n<section id=\"analytic-posteriors\" class=\"level1\">\n<h1>Analytic Posteriors<\/h1>\n<p>The standard DLM is essentially a <a href=\"https:\/\/en.wikipedia.org\/wiki\/Kalman_filter\">Kalman filter<\/a> and enjoys many well-documented closed-form results. In the following, we simply state the relevant prior predictive and posterior results.<\/p>\n<p>Writing <span class=\"math inline\">\\(D_t\\)<\/span> for all the prior information and observed data up to time <span class=\"math inline\">\\(t\\)<\/span>, the one-step-ahead distributions are given by the following:<\/p>\n<p><span class=\"math display\">\\[\\begin{align}\n  \\theta_{t} \\mid D_{t-1} &amp;\\sim \\operatorname{N}\\left( a_{t}, R_{t} \\right)\n  \\\\\n  y_{t} \\mid D_{t-1} &amp;\\sim \\operatorname{N}\\left( f_{t}, Q_{t} \\right)\n\\end{align}\\]<\/span><\/p>\n<p>The prior predictive moments are as follows:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{gathered}\n    a_t = G_t m_{t-1}, \\quad R_t = G_t C_{t-1} G_t^\\top + W_t\n    \\\\\n    f_t = F_t^\\top a_{t}, \\quad Q_t = F_t^\\top R_t F_t + V_t\n  \\end{gathered}\n  \\label{eq:dlm-prior-predictive}\n\\end{equation}\\]<\/span><\/p>\n<p>We\u2019ll also want to compute the posterior moments for <span class=\"math inline\">\\(\\theta_t \\mid D_t\\)<\/span>, which are as follows:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{gathered}\n    m_t = a_{t} + R_t F_t Q_t^{-1} \\left(y_t - f_t\\right),\n    \\quad C_t = R_t - R_t F_t Q_t^{-1} F_t^\\top R_t\n  \\end{gathered}\n  \\label{eq:dlm-post-moments}\n\\end{equation}\\]<\/span><\/p>\n<p>These \u201cfiltered\u201d moments\/distributions are one \u201ckind\u201d of posterior result for a DLM, and they only take into account the data <strong>up to<\/strong> time <span class=\"math 
inline\">\\(t\\)<\/span>. The other kind consists of the \u201csmoothed\u201d distributions, which provide posterior distributions for each time <span class=\"math inline\">\\(t\\)<\/span> given all observations preceding <strong>and<\/strong> following <span class=\"math inline\">\\(t\\)<\/span>.<\/p>\n<p>Notationally, we\u2019ve used <span class=\"math inline\">\\(D_t\\)<\/span> to denote all observations and parameters conditioned on up to time <span class=\"math inline\">\\(t\\)<\/span>, so the smoothed distributions are given by <span class=\"math inline\">\\(\\theta_t \\mid D_T\\)<\/span>, in contrast to the filtered distributions <span class=\"math inline\">\\(\\theta_t \\mid D_t\\)<\/span>.<\/p>\n<p>The smoothed <span class=\"math inline\">\\(\\theta_t\\)<\/span> distributions are still Gaussian, i.e.\u00a0<span class=\"math inline\">\\(\\left(\\theta_t \\mid D_T\\right) \\sim \\operatorname{N}\\left(s_t, S_t\\right)\\)<\/span>, and their moments are as follows:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{aligned}\n    s_t &amp;= m_t + C_t G_{t+1}^\\top R_{t+1}^{-1} \\left( s_{t+1} - a_{t+1} \\right)\n    \\\\\n    S_t &amp;= C_t - C_t G_{t+1}^\\top R_{t+1}^{-1} \\left( R_{t+1} - S_{t+1} \\right) R_{t+1}^{-1} G_{t+1} C_t\n  \\end{aligned}\n  \\label{eq:dlm-smooth-moments}\n\\end{equation}\\]<\/span><\/p>\n<div class=\"remark\" data-markdown=\"\">\n<p>In most cases, models will not be as simple as the standard DLM. Even so, these basic closed-form solutions can still be relevant. For instance, efficient MCMC algorithms can be constructed using these closed-form results for <strong>conditionally linear<\/strong> models. 
In those cases, we can compute the posterior moments in closed form, conditional on samples generated by other means.<\/p>\n<\/div>\n<p>The standard sampler of this kind is forward-filtering backward-sampling (FFBS), which draws from the smoothed conditional posteriors <span class=\"math inline\">\\(\\theta_t \\mid \\theta_{t+1}, D_T\\)<\/span>, conditioned on all other parameters.<\/p>\n<p>We\u2019ll build up to forward-backward sampling in what follows, but first we need to establish how the requisite quantities can be computed symbolically.<\/p>\n<p><a id=\"orgefb3ef8\"><\/a><\/p>\n<\/section>\n<section id=\"posterior-estimation\" class=\"level1\">\n<h1>Posterior Estimation<\/h1>\n<p>In Listings <a href=\"#org224c0e5\">6<\/a> and <a href=\"#org911f8da\">7<\/a>, we demonstrate how the posterior moments in <span class=\"math inline\">\\(\\eqref{eq:dlm-post-moments}\\)<\/span> and <span class=\"math inline\">\\(\\eqref{eq:dlm-smooth-moments}\\)<\/span> can be computed in Theano.<\/p>\n<p>Unfortunately, if we implement the exact closed-form updates in <span class=\"math inline\">\\(\\eqref{eq:dlm-post-moments}\\)<\/span> or <span class=\"math inline\">\\(\\eqref{eq:dlm-smooth-moments}\\)<\/span> directly, our results will be prone to numerical errors. This is a well-known issue with naively implemented Kalman filters: the covariance updates subtract one positive semi-definite matrix from another, so floating-point round-off can yield covariance estimates that are asymmetric or not positive semi-definite. The solution usually involves analytic reformulations that compensate for these covariance matrix subtractions. The standard approaches generally use some form of matrix decomposition that directly accounts for the positive semi-definite nature of the covariance matrices.<\/p>\n<p>The approach taken here is based on the singular value decomposition (SVD) and effectively computes only one symmetric \u201chalf\u201d of the updated covariances. The SVD also makes inversions easy, since inverting a matrix in SVD form only requires transposing its orthogonal factors and taking reciprocals of its singular values. 
See <a id=\"0ae04c048b20d07f32d7f0f75bb51483\"><a href=\"#ZhangFixedintervalsmoothingalgorithm1996\">Zhang &amp; Li (1996)<\/a><\/a> for more details, or <a id=\"3a4d89388a434d7b1b91dc8690f3a03b\"><a href=\"#PetrisDynamiclinearmodels2009\">Petris, Petrone &amp; Campagnoli (2009)<\/a><\/a> for a concise overview of the procedure in the context of DLMs.<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb5\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb5-1\"><a href=\"#cb5-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> warnings<\/span>\n<span id=\"cb5-2\"><a href=\"#cb5-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-3\"><a href=\"#cb5-3\" aria-hidden=\"true\"><\/a>warnings.filterwarnings(<span class=\"st\">&quot;ignore&quot;<\/span>, category<span class=\"op\">=<\/span><span class=\"pp\">FutureWarning<\/span>, message<span class=\"op\">=<\/span><span class=\"st\">&quot;Using a non-tuple sequence&quot;<\/span>)<\/span>\n<span id=\"cb5-4\"><a href=\"#cb5-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-5\"><a href=\"#cb5-5\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano.tensor.nlinalg <span class=\"im\">import<\/span> matrix_dot<\/span>\n<span id=\"cb5-6\"><a href=\"#cb5-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-7\"><a href=\"#cb5-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-8\"><a href=\"#cb5-8\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> tt_finite_inv(x, eps_truncate<span class=\"op\">=<\/span><span class=\"va\">False<\/span>):<\/span>\n<span id=\"cb5-9\"><a href=\"#cb5-9\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;Compute the element-wise reciprocal with special handling for small inputs.<\/span><\/span>\n<span id=\"cb5-10\"><a href=\"#cb5-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-11\"><a href=\"#cb5-11\" aria-hidden=\"true\"><\/a><span class=\"co\">    Parameters<\/span><\/span>\n<span id=\"cb5-12\"><a 
href=\"#cb5-12\" aria-hidden=\"true\"><\/a><span class=\"co\">    ==========<\/span><\/span>\n<span id=\"cb5-13\"><a href=\"#cb5-13\" aria-hidden=\"true\"><\/a><span class=\"co\">    x: Tensor-like<\/span><\/span>\n<span id=\"cb5-14\"><a href=\"#cb5-14\" aria-hidden=\"true\"><\/a><span class=\"co\">        The value for which the reciprocal, i.e. `1\/x`, is computed.<\/span><\/span>\n<span id=\"cb5-15\"><a href=\"#cb5-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-16\"><a href=\"#cb5-16\" aria-hidden=\"true\"><\/a><span class=\"co\">    eps_truncate: bool (optional)<\/span><\/span>\n<span id=\"cb5-17\"><a href=\"#cb5-17\" aria-hidden=\"true\"><\/a><span class=\"co\">        Determines whether or not a floating-point epsilon truncation is used to<\/span><\/span>\n<span id=\"cb5-18\"><a href=\"#cb5-18\" aria-hidden=\"true\"><\/a><span class=\"co\">        upper-bound the returned values.<\/span><\/span>\n<span id=\"cb5-19\"><a href=\"#cb5-19\" aria-hidden=\"true\"><\/a><span class=\"co\">        If not (the default), infinite values are simply set to zero.<\/span><\/span>\n<span id=\"cb5-20\"><a href=\"#cb5-20\" aria-hidden=\"true\"><\/a><span class=\"co\">    &quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb5-21\"><a href=\"#cb5-21\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> eps_truncate:<\/span>\n<span id=\"cb5-22\"><a href=\"#cb5-22\" aria-hidden=\"true\"><\/a>        eps <span class=\"op\">=<\/span> np.finfo(<span class=\"bu\">getattr<\/span>(x, <span class=\"st\">&#39;dtype&#39;<\/span>, <span class=\"va\">None<\/span>) <span class=\"kw\">or<\/span> theano.config.floatX).eps<\/span>\n<span id=\"cb5-23\"><a href=\"#cb5-23\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> tt.minimum(tt.inv(x), np.reciprocal(np.sqrt(eps)))<\/span>\n<span id=\"cb5-24\"><a href=\"#cb5-24\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"cb5-25\"><a href=\"#cb5-25\" aria-hidden=\"true\"><\/a>        y 
<span class=\"op\">=<\/span> tt.inv(x)<\/span>\n<span id=\"cb5-26\"><a href=\"#cb5-26\" aria-hidden=\"true\"><\/a>        res_subtensor <span class=\"op\">=<\/span> y[tt.isinf(y)]<\/span>\n<span id=\"cb5-27\"><a href=\"#cb5-27\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> tt.set_subtensor(res_subtensor, <span class=\"fl\">0.0<\/span>)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<p><a id=\"org1d9986d\"><\/a><\/p>\n<section id=\"svd-based-filtering\" class=\"level2\">\n<h2>SVD-based Filtering<\/h2>\n<p>The SVD forms of the filtering equations in <span class=\"math inline\">\\(\\eqref{eq:dlm-post-moments}\\)<\/span> are produced through creative use of the SVDs of its component matrices. Using a slightly modified version of the formulation established in <a id=\"3a4d89388a434d7b1b91dc8690f3a03b\"><a href=\"#PetrisDynamiclinearmodels2009\">Petris, Petrone &amp; Campagnoli (2009)<\/a><\/a>, the SVD for a matrix <span class=\"math inline\">\\(M\\)<\/span> is given by <span class=\"math inline\">\\(M = U_{M} D_{M} V_{M}^\\top\\)<\/span>. A symmetric matrix then takes the form <span class=\"math inline\">\\(M = U_{M} D_{M} U_{M}^\\top\\)<\/span> and its \u201csquare-root\u201d is given by <span class=\"math inline\">\\(M = N_M^\\top N_M\\)<\/span> with <span class=\"math inline\">\\(N_M = S_{M} U_{M}^\\top\\)<\/span> and <span class=\"math inline\">\\(S_{M} = D_{M}^{1\/2}\\)<\/span>. 
Likewise, matrix (generalized) inverses take the form <span class=\"math inline\">\\(M^{-1} = U_{M} D_{M}^{-1} U_{M}^\\top = U_{M} S_{M}^{-2} U_{M}^\\top\\)<\/span>.<\/p>\n<p>The idea is to combine these SVD identities to derive a square-root relationship between the SVD of <span class=\"math inline\">\\(C_t^{-1}\\)<\/span> and the SVDs of <span class=\"math inline\">\\(C_{t-1}\\)<\/span>, <span class=\"math inline\">\\(W_t\\)<\/span>, <span class=\"math inline\">\\(V_t\\)<\/span>, and <span class=\"math inline\">\\(R_t\\)<\/span>; from there, we can easily invert <span class=\"math inline\">\\(C_t^{-1}\\)<\/span> to arrive at the desired numerically stable SVD of <span class=\"math inline\">\\(C_t\\)<\/span>.<\/p>\n<p>First, note that <span class=\"math inline\">\\(N_{R_t}^\\top N_{R_t} = G_t C_{t-1} G_t^\\top + W_t = R_t\\)<\/span> for<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{aligned}\n    N_{R_t} &amp;=\n      \\begin{pmatrix}\n        S_{C_{t-1}} U_{C_{t-1}}^\\top G_t^\\top\n        \\\\\n        N_{W_t}\n      \\end{pmatrix}\n  \\end{aligned}\n  .\n  \\label{eq:N_R_t}\n\\end{equation}\\]<\/span><\/p>\n<p>From this, we know that the SVD of <span class=\"math inline\">\\(R_t\\)<\/span> can be easily derived from the SVD of its square root, <span class=\"math inline\">\\(N_{R_t}\\)<\/span>, i.e.\u00a0<span class=\"math inline\">\\(U_{R_t} = V_{N_{R_t}}\\)<\/span> and <span class=\"math inline\">\\(S_{R_t} = D_{N_{R_t}}\\)<\/span>. In other words, we can obtain a matrix\u2019s SVD by computing the SVD of its \u201chalf\u201d, which is itself composed entirely of previous SVD components. 
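To see why this works, write the SVD of the square-root factor as <span class=\"math inline\">\\(N_{R_t} = U_{N_{R_t}} D_{N_{R_t}} V_{N_{R_t}}^\\top\\)<\/span>. Here <span class=\"math inline\">\\(U_{N_{R_t}}\\)<\/span> has orthonormal columns (i.e.\u00a0<span class=\"math inline\">\\(U_{N_{R_t}}^\\top U_{N_{R_t}} = I\\)<\/span>) and <span class=\"math inline\">\\(V_{N_{R_t}}\\)<\/span> is orthogonal. Then <span class=\"math display\">\\[\\begin{aligned}\n    R_t = N_{R_t}^\\top N_{R_t}\n    &amp;= V_{N_{R_t}} D_{N_{R_t}} U_{N_{R_t}}^\\top U_{N_{R_t}} D_{N_{R_t}} V_{N_{R_t}}^\\top\n    \\\\\n    &amp;= V_{N_{R_t}} D_{N_{R_t}}^2 V_{N_{R_t}}^\\top\n    ,\n  \\end{aligned}\\]<\/span> which is exactly the symmetric SVD <span class=\"math inline\">\\(R_t = U_{R_t} S_{R_t}^2 U_{R_t}^\\top\\)<\/span> with <span class=\"math inline\">\\(U_{R_t} = V_{N_{R_t}}\\)<\/span> and <span class=\"math inline\">\\(S_{R_t} = D_{N_{R_t}}\\)<\/span>. 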
The inherent symmetry of our covariance matrices is nicely preserved because we\u2019re only ever using and computing one \u201chalf\u201d of these matrices.<\/p>\n<p>With the updated SVD of <span class=\"math inline\">\\(R_t\\)<\/span>, we can use the identity <span class=\"math inline\">\\(C_t^{-1} = F_t V_t^{-1} F_t^\\top + R_t^{-1}\\)<\/span>, which follows from the classic <a href=\"https:\/\/en.wikipedia.org\/wiki\/Woodbury_matrix_identity\">Sherman-Morrison-Woodbury matrix inverse identity<\/a>. This lets us employ the same technique as before and produce the SVD of <span class=\"math inline\">\\(C_t^{-1}\\)<\/span> by way of the SVD of yet another block square-root matrix,<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{aligned}\n    N_{C_t^{-1}} &amp;=\n      \\begin{pmatrix}\n        N_{V_t^{-1}} F_t^\\top U_{R_t}\n        \\\\\n        S_{R_t}^{-1}\n      \\end{pmatrix}\n  \\end{aligned}\n  .\n  \\label{eq:N_C_t_inv}\n\\end{equation}\\]<\/span><\/p>\n<p>Again, we compute the SVD of <span class=\"math inline\">\\(N_{C_t^{-1}}\\)<\/span> at this step and obtain <span class=\"math inline\">\\(V_{N_{C_t^{-1}}}\\)<\/span> and <span class=\"math inline\">\\(D_{N_{C_t^{-1}}}\\)<\/span>.<\/p>\n<p>This time, the block square-root matrix relationship isn\u2019t so direct, and we have to multiply by <span class=\"math inline\">\\(U_{R_t}\\)<\/span> on both sides: <span class=\"math inline\">\\(U_{R_t} N_{C_t^{-1}}^\\top N_{C_t^{-1}} U_{R_t}^\\top = C_t^{-1}\\)<\/span>. 
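Expanding the product shows why this holds. The factor <span class=\"math inline\">\\(U_{R_t}\\)<\/span> is orthogonal, <span class=\"math inline\">\\(N_{V_t^{-1}}^\\top N_{V_t^{-1}} = V_t^{-1}\\)<\/span>, and <span class=\"math inline\">\\(R_t^{-1} = U_{R_t} S_{R_t}^{-2} U_{R_t}^\\top\\)<\/span>, so <span class=\"math display\">\\[\\begin{aligned}\n    U_{R_t} N_{C_t^{-1}}^\\top N_{C_t^{-1}} U_{R_t}^\\top\n    &amp;= U_{R_t} \\left( U_{R_t}^\\top F_t N_{V_t^{-1}}^\\top N_{V_t^{-1}} F_t^\\top U_{R_t} + S_{R_t}^{-2} \\right) U_{R_t}^\\top\n    \\\\\n    &amp;= F_t V_t^{-1} F_t^\\top + R_t^{-1}\n    = C_t^{-1}\n    .\n  \\end{aligned}\\]<\/span> 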
However, since the additional <span class=\"math inline\">\\(U_{R_t}\\)<\/span> terms are orthogonal, we are able to derive the SVD of <span class=\"math inline\">\\(C_t\\)<\/span> as <span class=\"math inline\">\\(U_{C_t} = U_{R_t} V_{N_{C_t^{-1}}}\\)<\/span> and <span class=\"math inline\">\\(S_{C_t} = D_{N_{C_t^{-1}}}^{-1}\\)<\/span>.<\/p>\n<p>These quantities are computed in Listing <a href=\"#org224c0e5\">6<\/a>.<\/p>\n<figure id=\"org224c0e5\">\n<div class=\"sourceCode\" id=\"cb6\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb6-1\"><a href=\"#cb6-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano.tensor.nlinalg <span class=\"im\">import<\/span> svd<\/span>\n<span id=\"cb6-2\"><a href=\"#cb6-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-3\"><a href=\"#cb6-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-4\"><a href=\"#cb6-4\" aria-hidden=\"true\"><\/a>y_tt <span class=\"op\">=<\/span> tt.specify_shape(tt.col(), [N_obs_tt, <span class=\"dv\">1<\/span>])<\/span>\n<span id=\"cb6-5\"><a href=\"#cb6-5\" aria-hidden=\"true\"><\/a>y_tt.name <span class=\"op\">=<\/span> <span class=\"st\">&#39;y_t&#39;<\/span><\/span>\n<span id=\"cb6-6\"><a href=\"#cb6-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-7\"><a href=\"#cb6-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-8\"><a href=\"#cb6-8\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> filtering_step(y_t, m_tm1, U_C_tm1, S_C_tm1, F_t, G_t, N_W_t, N_V_t_inv):<\/span>\n<span id=\"cb6-9\"><a href=\"#cb6-9\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;Compute the sequential posterior state and prior predictive parameters.&quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb6-10\"><a href=\"#cb6-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-11\"><a href=\"#cb6-11\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># R_t = N_R.T.dot(N_R)<\/span><\/span>\n<span id=\"cb6-12\"><a href=\"#cb6-12\" 
aria-hidden=\"true\"><\/a>    N_R <span class=\"op\">=<\/span> tt.join(<span class=\"dv\">0<\/span>,<\/span>\n<span id=\"cb6-13\"><a href=\"#cb6-13\" aria-hidden=\"true\"><\/a>                  matrix_dot(S_C_tm1, U_C_tm1.T, G_t.T),<\/span>\n<span id=\"cb6-14\"><a href=\"#cb6-14\" aria-hidden=\"true\"><\/a>                  N_W_t)<\/span>\n<span id=\"cb6-15\"><a href=\"#cb6-15\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># <\/span><span class=\"al\">TODO<\/span><span class=\"co\">: All this could be much more efficient if we only computed *one* set of singular<\/span><\/span>\n<span id=\"cb6-16\"><a href=\"#cb6-16\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># vectors for these non-square matrices.<\/span><\/span>\n<span id=\"cb6-17\"><a href=\"#cb6-17\" aria-hidden=\"true\"><\/a>    _, d_N_R_t, V_N_R_t_T <span class=\"op\">=<\/span> svd(N_R)<\/span>\n<span id=\"cb6-18\"><a href=\"#cb6-18\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-19\"><a href=\"#cb6-19\" aria-hidden=\"true\"><\/a>    U_R_t <span class=\"op\">=<\/span> V_N_R_t_T.T<\/span>\n<span id=\"cb6-20\"><a href=\"#cb6-20\" aria-hidden=\"true\"><\/a>    S_R_t <span class=\"op\">=<\/span> tt.diag(d_N_R_t)<\/span>\n<span id=\"cb6-21\"><a href=\"#cb6-21\" aria-hidden=\"true\"><\/a>    S_R_t_inv <span class=\"op\">=<\/span> tt.diag(tt_finite_inv(d_N_R_t))<\/span>\n<span id=\"cb6-22\"><a href=\"#cb6-22\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-23\"><a href=\"#cb6-23\" aria-hidden=\"true\"><\/a>    N_C_t_inv <span class=\"op\">=<\/span> tt.join(<span class=\"dv\">0<\/span>,<\/span>\n<span id=\"cb6-24\"><a href=\"#cb6-24\" aria-hidden=\"true\"><\/a>                        matrix_dot(N_V_t_inv, F_t.T, U_R_t),<\/span>\n<span id=\"cb6-25\"><a href=\"#cb6-25\" aria-hidden=\"true\"><\/a>                        S_R_t_inv)<\/span>\n<span id=\"cb6-26\"><a href=\"#cb6-26\" aria-hidden=\"true\"><\/a>    _, d_N_C_t_inv, V_N_C_t_inv_T <span class=\"op\">=<\/span> svd(N_C_t_inv)<\/span>\n<span 
id=\"cb6-27\"><a href=\"#cb6-27\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-28\"><a href=\"#cb6-28\" aria-hidden=\"true\"><\/a>    U_C_t <span class=\"op\">=<\/span> U_R_t.dot(V_N_C_t_inv_T.T)<\/span>\n<span id=\"cb6-29\"><a href=\"#cb6-29\" aria-hidden=\"true\"><\/a>    d_C_t <span class=\"op\">=<\/span> tt_finite_inv(tt.square(d_N_C_t_inv))<\/span>\n<span id=\"cb6-30\"><a href=\"#cb6-30\" aria-hidden=\"true\"><\/a>    D_C_t <span class=\"op\">=<\/span> tt.diag(d_C_t)<\/span>\n<span id=\"cb6-31\"><a href=\"#cb6-31\" aria-hidden=\"true\"><\/a>    S_C_t <span class=\"op\">=<\/span> tt.diag(tt.sqrt(d_C_t))<\/span>\n<span id=\"cb6-32\"><a href=\"#cb6-32\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-33\"><a href=\"#cb6-33\" aria-hidden=\"true\"><\/a>    C_t <span class=\"op\">=<\/span> matrix_dot(U_C_t, D_C_t, U_C_t.T)<\/span>\n<span id=\"cb6-34\"><a href=\"#cb6-34\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-35\"><a href=\"#cb6-35\" aria-hidden=\"true\"><\/a>    a_t <span class=\"op\">=<\/span> G_t.dot(m_tm1)<\/span>\n<span id=\"cb6-36\"><a href=\"#cb6-36\" aria-hidden=\"true\"><\/a>    f_t <span class=\"op\">=<\/span> F_t.T.dot(a_t)<\/span>\n<span id=\"cb6-37\"><a href=\"#cb6-37\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># A_t = R_t @ F_t @ inv(Q_t) = C_t @ F_t @ inv(V_t)<\/span><\/span>\n<span id=\"cb6-38\"><a href=\"#cb6-38\" aria-hidden=\"true\"><\/a>    m_t <span class=\"op\">=<\/span> a_t <span class=\"op\">+<\/span> matrix_dot(C_t, F_t, N_V_t_inv.T, N_V_t_inv, y_t <span class=\"op\">-<\/span> f_t)<\/span>\n<span id=\"cb6-39\"><a href=\"#cb6-39\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-40\"><a href=\"#cb6-40\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> [m_t, U_C_t, S_C_t, a_t, U_R_t, S_R_t]<\/span>\n<span id=\"cb6-41\"><a href=\"#cb6-41\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-42\"><a href=\"#cb6-42\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-43\"><a href=\"#cb6-43\" 
aria-hidden=\"true\"><\/a>_, d_C_0_tt, Vt_C_0_tt <span class=\"op\">=<\/span> svd(C_0_tt)<\/span>\n<span id=\"cb6-44\"><a href=\"#cb6-44\" aria-hidden=\"true\"><\/a>U_C_0_tt <span class=\"op\">=<\/span> Vt_C_0_tt.T<\/span>\n<span id=\"cb6-45\"><a href=\"#cb6-45\" aria-hidden=\"true\"><\/a>S_C_0_tt <span class=\"op\">=<\/span> tt.diag(tt.sqrt(d_C_0_tt))<\/span>\n<span id=\"cb6-46\"><a href=\"#cb6-46\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-47\"><a href=\"#cb6-47\" aria-hidden=\"true\"><\/a>_, d_W_tt, Vt_W_tt <span class=\"op\">=<\/span> svd(W_tt)<\/span>\n<span id=\"cb6-48\"><a href=\"#cb6-48\" aria-hidden=\"true\"><\/a>U_W_tt <span class=\"op\">=<\/span> Vt_W_tt.T<\/span>\n<span id=\"cb6-49\"><a href=\"#cb6-49\" aria-hidden=\"true\"><\/a>s_W_tt <span class=\"op\">=<\/span> tt.sqrt(d_W_tt)<\/span>\n<span id=\"cb6-50\"><a href=\"#cb6-50\" aria-hidden=\"true\"><\/a>N_W_tt <span class=\"op\">=<\/span> tt.diag(s_W_tt).dot(U_W_tt.T)<\/span>\n<span id=\"cb6-51\"><a href=\"#cb6-51\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-52\"><a href=\"#cb6-52\" aria-hidden=\"true\"><\/a>_, D_V_tt, Vt_V_tt <span class=\"op\">=<\/span> svd(tt.as_tensor_variable(V_tt, ndim<span class=\"op\">=<\/span><span class=\"dv\">2<\/span>) <span class=\"cf\">if<\/span> V_tt.ndim <span class=\"op\">&lt;<\/span> <span class=\"dv\">2<\/span> <span class=\"cf\">else<\/span> V_tt)<\/span>\n<span id=\"cb6-53\"><a href=\"#cb6-53\" aria-hidden=\"true\"><\/a>U_V_tt <span class=\"op\">=<\/span> Vt_V_tt.T<\/span>\n<span id=\"cb6-54\"><a href=\"#cb6-54\" aria-hidden=\"true\"><\/a>S_V_inv_tt <span class=\"op\">=<\/span> tt.diag(tt.sqrt(tt_finite_inv(D_V_tt, eps_truncate<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)))<\/span>\n<span id=\"cb6-55\"><a href=\"#cb6-55\" aria-hidden=\"true\"><\/a>N_V_inv_tt <span class=\"op\">=<\/span> S_V_inv_tt.dot(U_V_tt.T)<\/span>\n<span id=\"cb6-56\"><a href=\"#cb6-56\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-57\"><a href=\"#cb6-57\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-58\"><a href=\"#cb6-58\" aria-hidden=\"true\"><\/a>filter_res, filter_updates <span class=\"op\">=<\/span> theano.scan(fn<span class=\"op\">=<\/span>filtering_step,<\/span>\n<span id=\"cb6-59\"><a href=\"#cb6-59\" aria-hidden=\"true\"><\/a>                                         sequences<span class=\"op\">=<\/span>y_tt,<\/span>\n<span id=\"cb6-60\"><a href=\"#cb6-60\" aria-hidden=\"true\"><\/a>                                         outputs_info<span class=\"op\">=<\/span>[<\/span>\n<span id=\"cb6-61\"><a href=\"#cb6-61\" aria-hidden=\"true\"><\/a>                                             {<span class=\"st\">&quot;initial&quot;<\/span>: m_0_tt, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb6-62\"><a href=\"#cb6-62\" aria-hidden=\"true\"><\/a>                                             {<span class=\"st\">&quot;initial&quot;<\/span>: U_C_0_tt, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb6-63\"><a href=\"#cb6-63\" aria-hidden=\"true\"><\/a>                                             {<span class=\"st\">&quot;initial&quot;<\/span>: S_C_0_tt, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb6-64\"><a href=\"#cb6-64\" aria-hidden=\"true\"><\/a>                                             {}, {}, {}, <span class=\"co\"># a_t, U_R_t, S_R_t<\/span><\/span>\n<span id=\"cb6-65\"><a href=\"#cb6-65\" aria-hidden=\"true\"><\/a>                                         ],<\/span>\n<span id=\"cb6-66\"><a href=\"#cb6-66\" aria-hidden=\"true\"><\/a>                                         non_sequences<span class=\"op\">=<\/span>[F_tt, G_tt, N_W_tt, N_V_inv_tt],<\/span>\n<span id=\"cb6-67\"><a href=\"#cb6-67\" aria-hidden=\"true\"><\/a>                                         
strict<span class=\"op\">=<\/span><span class=\"va\">True<\/span>,<\/span>\n<span id=\"cb6-68\"><a href=\"#cb6-68\" aria-hidden=\"true\"><\/a>                                         name<span class=\"op\">=<\/span><span class=\"st\">&#39;theta_filtered&#39;<\/span>)<\/span>\n<span id=\"cb6-69\"><a href=\"#cb6-69\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-70\"><a href=\"#cb6-70\" aria-hidden=\"true\"><\/a>(m_t, U_C_t, S_C_t, a_t, U_R_t, S_R_t) <span class=\"op\">=<\/span> filter_res<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 6\n<\/figcaption>\n<\/figure>\n<p><a id=\"org350fa8d\"><\/a><\/p>\n<\/section>\n<section id=\"svd-based-smoothing\" class=\"level2\">\n<h2>SVD-based Smoothing<\/h2>\n<p>We can use the techniques above to produce SVD versions of the smoothing equations in <span class=\"math inline\">\\(\\eqref{eq:dlm-smooth-moments}\\)<\/span>. In this case, some extra steps are required in order to SVD-decompose <span class=\"math inline\">\\(S_t\\)<\/span> in the same manner as <span class=\"math inline\">\\(R_t\\)<\/span> and <span class=\"math inline\">\\(C_t^{-1}\\)<\/span> were.<\/p>\n<p>First, notice that our target, <span class=\"math inline\">\\(S_t\\)<\/span>, is a difference of matrices, unlike the matrix sums that comprised <span class=\"math inline\">\\(R_t\\)<\/span> and <span class=\"math inline\">\\(C_t^{-1}\\)<\/span> above. Furthermore, <span class=\"math inline\">\\(S_t\\)<\/span> is given as a difference of a (transformed) difference. 
To address the latter, we start by expanding <span class=\"math inline\">\\(S_t\\)<\/span> and setting <span class=\"math inline\">\\(B_t = C_t G_{t+1}^\\top R_{t+1}^{-1}\\)<\/span> to obtain<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{aligned}\n    S_t &amp;= C_t - B_t R_{t+1} B_t^\\top + B_t S_{t+1} B_t^\\top\n      \\\\\n      &amp;= H_t + B_t S_{t+1} B_t^\\top\n  \\end{aligned}\n  \\label{eq:S_t_decomp}\n\\end{equation}\\]<\/span><\/p>\n<p>Having turned <span class=\"math inline\">\\(S_t\\)<\/span> into a sum of two terms, we can now construct another blocked SVD-based square-root reformulation, starting with <span class=\"math inline\">\\(H_t\\)<\/span>.<\/p>\n<p>We can use the definition <span class=\"math inline\">\\(R_{t+1} = G_{t+1} C_t G_{t+1}^\\top + W_{t+1}\\)<\/span> to get<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{aligned}\n    H_t &amp;= C_t - B_t R_{t+1} B_t^\\top\n    \\\\\n    &amp;= C_t - C_t G_{t+1}^\\top R_{t+1}^{-1} G_{t+1} C_t\n    \\\\\n    &amp;= C_t - C_t G_{t+1}^\\top \\left(G_{t+1} C_t G_{t+1}^\\top + W_{t+1}\\right)^{-1} G_{t+1} C_t\n    .\n  \\end{aligned}\n\\end{equation}\\]<\/span><\/p>\n<p>This form of <span class=\"math inline\">\\(H_t\\)<\/span> fits the Woodbury identity and results in <span class=\"math inline\">\\(H_t^{-1} = G_{t+1}^\\top W_{t+1}^{-1} G_{t+1} + C_t^{-1}\\)<\/span>, which is amenable to our square-root formulation.<\/p>\n<p>Specifically, <span class=\"math inline\">\\(H_t^{-1} = U_{C_t} N_{H_t}^{-\\top} N_{H_t}^{-1} U_{C_t}^\\top\\)<\/span>, where<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{aligned}\n    N_{H_t}^{-1} &amp;=\n      \\begin{pmatrix}\n        N_{W_{t+1}^{-1}} G_{t+1} U_{C_t}\n        \\\\\n        S_{C_t}^{-1}\n      \\end{pmatrix}\n  \\end{aligned}\n  .\n  \\label{eq:N_H_t_inv}\n\\end{equation}\\]<\/span><\/p>\n<p>By inverting the SVD of <span class=\"math inline\">\\(N_{H_t}^{-1}\\)<\/span> we 
obtain the SVD of <span class=\"math inline\">\\(H_t\\)<\/span> as <span class=\"math inline\">\\(U_{H_t} = U_{C_t} V_{N_{H_t}^{-1}}\\)<\/span> and <span class=\"math inline\">\\(D_{H_t} = {D_{N_{H_t}^{-1}}}^{-2} = S_{H_t}^2\\)<\/span>.<\/p>\n<p>Finally, using <span class=\"math inline\">\\(\\eqref{eq:S_t_decomp}\\)<\/span> and <span class=\"math inline\">\\(\\eqref{eq:N_H_t_inv}\\)<\/span> we can derive the last blocked square-root decomposition <span class=\"math inline\">\\(S_t = N_{S_t}^\\top N_{S_t}\\)<\/span>:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{aligned}\n    N_{S_t} &amp;=\n      \\begin{pmatrix}\n        S_{H_t} U_{H_t}^\\top\n        \\\\\n        S_{S_{t+1}} U_{S_{t+1}}^\\top B_t^\\top\n      \\end{pmatrix}\n  \\end{aligned}\n  .\n  \\label{eq:N_S_t}\n\\end{equation}\\]<\/span><\/p>\n<p>Again, we take the SVD of <span class=\"math inline\">\\(N_{S_t}\\)<\/span> and derive the SVD of <span class=\"math inline\">\\(S_t\\)<\/span> as <span class=\"math inline\">\\(U_{S_t} = V_{N_{S_t}}\\)<\/span> and <span class=\"math inline\">\\(D_{S_t} = D_{N_{S_t}}^2 = S_{S_t}^2\\)<\/span>.<\/p>\n<figure id=\"org911f8da\">\n<div class=\"sourceCode\" id=\"cb7\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb7-1\"><a href=\"#cb7-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> smoother_step(m_t, U_C_t, S_C_t, a_tp1, U_R_tp1, S_R_tp1, s_tp1, U_S_tp1, S_S_tp1, G_tp1, N_W_tp1_inv):<\/span>\n<span id=\"cb7-2\"><a href=\"#cb7-2\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;Smooth a series starting from the &quot;forward&quot;\/sequentially computed posterior moments.&quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb7-3\"><a href=\"#cb7-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-4\"><a href=\"#cb7-4\" aria-hidden=\"true\"><\/a>    N_C_t <span class=\"op\">=<\/span> S_C_t.dot(U_C_t.T)<\/span>\n<span id=\"cb7-5\"><a href=\"#cb7-5\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-6\"><a href=\"#cb7-6\" aria-hidden=\"true\"><\/a>    S_R_tp1_inv <span class=\"op\">=<\/span> tt_finite_inv(S_R_tp1)<\/span>\n<span id=\"cb7-7\"><a href=\"#cb7-7\" aria-hidden=\"true\"><\/a>    N_R_tp1_inv <span class=\"op\">=<\/span> S_R_tp1_inv.dot(U_R_tp1.T)<\/span>\n<span id=\"cb7-8\"><a href=\"#cb7-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-9\"><a href=\"#cb7-9\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># B_t = C_t @ G_tp1.T @ inv(R_tp1)<\/span><\/span>\n<span id=\"cb7-10\"><a href=\"#cb7-10\" aria-hidden=\"true\"><\/a>    B_t <span class=\"op\">=<\/span> matrix_dot(N_C_t.T, N_C_t, G_tp1.T, N_R_tp1_inv.T, N_R_tp1_inv)<\/span>\n<span id=\"cb7-11\"><a href=\"#cb7-11\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-12\"><a href=\"#cb7-12\" aria-hidden=\"true\"><\/a>    S_C_t_inv <span class=\"op\">=<\/span> tt_finite_inv(S_C_t)<\/span>\n<span id=\"cb7-13\"><a href=\"#cb7-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-14\"><a href=\"#cb7-14\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># U_C_t @ N_H_t_inv.T @ N_H_t_inv @ U_C_t.T = G_tp1.T @ W_tp1_inv @ G_tp1 + C_t_inv<\/span><\/span>\n<span id=\"cb7-15\"><a href=\"#cb7-15\" aria-hidden=\"true\"><\/a>    N_H_t_inv <span class=\"op\">=<\/span> tt.join(<span class=\"dv\">0<\/span>,<\/span>\n<span id=\"cb7-16\"><a href=\"#cb7-16\" aria-hidden=\"true\"><\/a>                        matrix_dot(N_W_tp1_inv, G_tp1, U_C_t),<\/span>\n<span id=\"cb7-17\"><a href=\"#cb7-17\" aria-hidden=\"true\"><\/a>                        S_C_t_inv)<\/span>\n<span id=\"cb7-18\"><a href=\"#cb7-18\" aria-hidden=\"true\"><\/a>    _, d_N_H_t_inv, V_N_H_t_T <span class=\"op\">=<\/span> svd(N_H_t_inv)<\/span>\n<span id=\"cb7-19\"><a href=\"#cb7-19\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-20\"><a href=\"#cb7-20\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># H_t = inv(U_C_t @ N_H_t_inv.T @ N_H_t_inv @ U_C_t.T) = C_t - B_t @ R_tp1 @ 
B_t.T<\/span><\/span>\n<span id=\"cb7-21\"><a href=\"#cb7-21\" aria-hidden=\"true\"><\/a>    U_H_t <span class=\"op\">=<\/span> U_C_t.dot(V_N_H_t_T.T)<\/span>\n<span id=\"cb7-22\"><a href=\"#cb7-22\" aria-hidden=\"true\"><\/a>    S_H_t <span class=\"op\">=<\/span> tt.diag(tt_finite_inv(d_N_H_t_inv))<\/span>\n<span id=\"cb7-23\"><a href=\"#cb7-23\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-24\"><a href=\"#cb7-24\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># S_t = N_S_t.T.dot(N_S_t) = C_t - matrix_dot(B_t, R_tp1 - S_tp1, B_t.T)<\/span><\/span>\n<span id=\"cb7-25\"><a href=\"#cb7-25\" aria-hidden=\"true\"><\/a>    N_S_t <span class=\"op\">=<\/span> tt.join(<span class=\"dv\">0<\/span>,<\/span>\n<span id=\"cb7-26\"><a href=\"#cb7-26\" aria-hidden=\"true\"><\/a>                     S_H_t.dot(U_H_t.T),<\/span>\n<span id=\"cb7-27\"><a href=\"#cb7-27\" aria-hidden=\"true\"><\/a>                     matrix_dot(S_S_tp1, U_S_tp1.T, B_t.T))<\/span>\n<span id=\"cb7-28\"><a href=\"#cb7-28\" aria-hidden=\"true\"><\/a>    _, d_N_S_t, V_N_S_t_T <span class=\"op\">=<\/span> svd(N_S_t)<\/span>\n<span id=\"cb7-29\"><a href=\"#cb7-29\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-30\"><a href=\"#cb7-30\" aria-hidden=\"true\"><\/a>    U_S_t <span class=\"op\">=<\/span> V_N_S_t_T.T<\/span>\n<span id=\"cb7-31\"><a href=\"#cb7-31\" aria-hidden=\"true\"><\/a>    S_S_t <span class=\"op\">=<\/span> tt.diag(d_N_S_t)<\/span>\n<span id=\"cb7-32\"><a href=\"#cb7-32\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-33\"><a href=\"#cb7-33\" aria-hidden=\"true\"><\/a>    s_t <span class=\"op\">=<\/span> m_t <span class=\"op\">+<\/span> B_t.dot(s_tp1 <span class=\"op\">-<\/span> a_tp1)<\/span>\n<span id=\"cb7-34\"><a href=\"#cb7-34\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-35\"><a href=\"#cb7-35\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> [s_t, U_S_t, S_S_t]<\/span>\n<span id=\"cb7-36\"><a href=\"#cb7-36\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-37\"><a href=\"#cb7-37\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-38\"><a href=\"#cb7-38\" aria-hidden=\"true\"><\/a>N_W_inv_tt <span class=\"op\">=<\/span> tt.diag(tt_finite_inv(s_W_tt, eps_truncate<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)).dot(U_W_tt.T)<\/span>\n<span id=\"cb7-39\"><a href=\"#cb7-39\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-40\"><a href=\"#cb7-40\" aria-hidden=\"true\"><\/a>m_T <span class=\"op\">=<\/span> m_t[<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]<\/span>\n<span id=\"cb7-41\"><a href=\"#cb7-41\" aria-hidden=\"true\"><\/a>U_C_T <span class=\"op\">=<\/span> U_C_t[<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]<\/span>\n<span id=\"cb7-42\"><a href=\"#cb7-42\" aria-hidden=\"true\"><\/a>S_C_T <span class=\"op\">=<\/span> S_C_t[<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]<\/span>\n<span id=\"cb7-43\"><a href=\"#cb7-43\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-44\"><a href=\"#cb7-44\" aria-hidden=\"true\"><\/a><span class=\"co\"># These series only go from N_obs - 1 to 1<\/span><\/span>\n<span id=\"cb7-45\"><a href=\"#cb7-45\" aria-hidden=\"true\"><\/a>smoother_res, _ <span class=\"op\">=<\/span> theano.scan(fn<span class=\"op\">=<\/span>smoother_step,<\/span>\n<span id=\"cb7-46\"><a href=\"#cb7-46\" aria-hidden=\"true\"><\/a>                              sequences<span class=\"op\">=<\/span>[<\/span>\n<span id=\"cb7-47\"><a href=\"#cb7-47\" aria-hidden=\"true\"><\/a>                                  {<span class=\"st\">&quot;input&quot;<\/span>: m_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb7-48\"><a href=\"#cb7-48\" aria-hidden=\"true\"><\/a>                                  {<span class=\"st\">&quot;input&quot;<\/span>: U_C_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span 
class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb7-49\"><a href=\"#cb7-49\" aria-hidden=\"true\"><\/a>                                  {<span class=\"st\">&quot;input&quot;<\/span>: S_C_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb7-50\"><a href=\"#cb7-50\" aria-hidden=\"true\"><\/a>                                  {<span class=\"st\">&quot;input&quot;<\/span>: a_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb7-51\"><a href=\"#cb7-51\" aria-hidden=\"true\"><\/a>                                  {<span class=\"st\">&quot;input&quot;<\/span>: U_R_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb7-52\"><a href=\"#cb7-52\" aria-hidden=\"true\"><\/a>                                  {<span class=\"st\">&quot;input&quot;<\/span>: S_R_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"dv\">1<\/span>]}<\/span>\n<span id=\"cb7-53\"><a href=\"#cb7-53\" aria-hidden=\"true\"><\/a>                              ],<\/span>\n<span id=\"cb7-54\"><a href=\"#cb7-54\" aria-hidden=\"true\"><\/a>                              outputs_info<span class=\"op\">=<\/span>[<\/span>\n<span id=\"cb7-55\"><a href=\"#cb7-55\" aria-hidden=\"true\"><\/a>                                  {<span class=\"st\">&quot;initial&quot;<\/span>: m_T, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb7-56\"><a href=\"#cb7-56\" aria-hidden=\"true\"><\/a>                                  {<span class=\"st\">&quot;initial&quot;<\/span>: U_C_T, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb7-57\"><a href=\"#cb7-57\" aria-hidden=\"true\"><\/a>                                  {<span class=\"st\">&quot;initial&quot;<\/span>: S_C_T, <span 
class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb7-58\"><a href=\"#cb7-58\" aria-hidden=\"true\"><\/a>                              ],<\/span>\n<span id=\"cb7-59\"><a href=\"#cb7-59\" aria-hidden=\"true\"><\/a>                              non_sequences<span class=\"op\">=<\/span>[G_tt, N_W_inv_tt],<\/span>\n<span id=\"cb7-60\"><a href=\"#cb7-60\" aria-hidden=\"true\"><\/a>                              go_backwards<span class=\"op\">=<\/span><span class=\"va\">True<\/span>,<\/span>\n<span id=\"cb7-61\"><a href=\"#cb7-61\" aria-hidden=\"true\"><\/a>                              strict<span class=\"op\">=<\/span><span class=\"va\">True<\/span>,<\/span>\n<span id=\"cb7-62\"><a href=\"#cb7-62\" aria-hidden=\"true\"><\/a>                              name<span class=\"op\">=<\/span><span class=\"st\">&#39;theta_smoothed_obs&#39;<\/span>)<\/span>\n<span id=\"cb7-63\"><a href=\"#cb7-63\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-64\"><a href=\"#cb7-64\" aria-hidden=\"true\"><\/a>(s_t_rev, U_S_t_rev, S_S_t_rev) <span class=\"op\">=<\/span> smoother_res<\/span>\n<span id=\"cb7-65\"><a href=\"#cb7-65\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-66\"><a href=\"#cb7-66\" aria-hidden=\"true\"><\/a>s_t <span class=\"op\">=<\/span> s_t_rev[::<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]<\/span>\n<span id=\"cb7-67\"><a href=\"#cb7-67\" aria-hidden=\"true\"><\/a>U_S_t <span class=\"op\">=<\/span> U_S_t_rev[::<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]<\/span>\n<span id=\"cb7-68\"><a href=\"#cb7-68\" aria-hidden=\"true\"><\/a>S_S_t <span class=\"op\">=<\/span> S_S_t_rev[::<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]<\/span>\n<span id=\"cb7-69\"><a href=\"#cb7-69\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-70\"><a href=\"#cb7-70\" aria-hidden=\"true\"><\/a>s_t <span class=\"op\">=<\/span> tt.join(<span class=\"dv\">0<\/span>, s_t, 
[m_T])<\/span>\n<span id=\"cb7-71\"><a href=\"#cb7-71\" aria-hidden=\"true\"><\/a>U_S_t <span class=\"op\">=<\/span> tt.join(<span class=\"dv\">0<\/span>, U_S_t, [U_C_T])<\/span>\n<span id=\"cb7-72\"><a href=\"#cb7-72\" aria-hidden=\"true\"><\/a>S_S_t <span class=\"op\">=<\/span> tt.join(<span class=\"dv\">0<\/span>, S_S_t, [S_C_T])<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 7\n<\/figcaption>\n<\/figure>\n<p><a id=\"org9beb158\"><\/a><\/p>\n<\/section>\n<section id=\"example\" class=\"level2\">\n<h2>Example<\/h2>\n<p>Listing <a href=\"#org6c20a11\">8<\/a> computes the filtered and smoothed means for our simulated series, and Figure <a href=\"#org822c0f2\">9<\/a> shows the results.<\/p>\n<figure id=\"org6c20a11\">\n<div class=\"sourceCode\" id=\"cb8\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb8-1\"><a href=\"#cb8-1\" aria-hidden=\"true\"><\/a>filter_smooth_dlm <span class=\"op\">=<\/span> tt_function([y_tt, N_theta_tt, G_tt, F_tt], [m_t, s_t])<\/span>\n<span id=\"cb8-2\"><a href=\"#cb8-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb8-3\"><a href=\"#cb8-3\" aria-hidden=\"true\"><\/a><span class=\"co\"># phi_W_tt.set_value(phi_W_true)<\/span><\/span>\n<span id=\"cb8-4\"><a href=\"#cb8-4\" aria-hidden=\"true\"><\/a><span class=\"co\"># phi_V_tt.set_value(phi_V_true)<\/span><\/span>\n<span id=\"cb8-5\"><a href=\"#cb8-5\" aria-hidden=\"true\"><\/a>phi_W_tt.set_value(np.r_[<span class=\"fl\">100.0<\/span>, <span class=\"fl\">100.0<\/span>])<\/span>\n<span id=\"cb8-6\"><a href=\"#cb8-6\" aria-hidden=\"true\"><\/a>phi_V_tt.set_value(<span class=\"fl\">1.5<\/span>)<\/span>\n<span id=\"cb8-7\"><a href=\"#cb8-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb8-8\"><a href=\"#cb8-8\" aria-hidden=\"true\"><\/a>(m_t_sim, s_t_sim) <span class=\"op\">=<\/span> filter_smooth_dlm(y_sim, dlm_sim_values[N_theta_tt], dlm_sim_values[G_tt], dlm_sim_values[F_tt])<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 
8\n<\/figcaption>\n<\/figure>\n<figure id=\"org822c0f2\">\n<div class=\"sourceCode\" id=\"cb9\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb9-1\"><a href=\"#cb9-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> cycler <span class=\"im\">import<\/span> cycler<\/span>\n<span id=\"cb9-2\"><a href=\"#cb9-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb9-3\"><a href=\"#cb9-3\" aria-hidden=\"true\"><\/a>bivariate_cycler <span class=\"op\">=<\/span> plt_orig_cycler <span class=\"op\">*<\/span> cycler(<span class=\"st\">&#39;linestyle&#39;<\/span>, [<span class=\"st\">&#39;-&#39;<\/span>, <span class=\"st\">&#39;--&#39;<\/span>])<\/span>\n<span id=\"cb9-4\"><a href=\"#cb9-4\" aria-hidden=\"true\"><\/a>plt.close(fig<span class=\"op\">=<\/span><span class=\"st\">&#39;all&#39;<\/span>)<\/span>\n<span id=\"cb9-5\"><a href=\"#cb9-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb9-6\"><a href=\"#cb9-6\" aria-hidden=\"true\"><\/a>fig, ax <span class=\"op\">=<\/span> plt.subplots(figsize<span class=\"op\">=<\/span>(<span class=\"dv\">8<\/span>, <span class=\"fl\">4.8<\/span>))<\/span>\n<span id=\"cb9-7\"><a href=\"#cb9-7\" aria-hidden=\"true\"><\/a>ax.set_prop_cycle(bivariate_cycler)<\/span>\n<span id=\"cb9-8\"><a href=\"#cb9-8\" aria-hidden=\"true\"><\/a>ax.plot(theta_t_sim, label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$\\theta_t$&#39;<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">0.8<\/span>, color<span class=\"op\">=<\/span><span class=\"st\">&#39;black&#39;<\/span>)<\/span>\n<span id=\"cb9-9\"><a href=\"#cb9-9\" aria-hidden=\"true\"><\/a>ax.autoscale(enable<span class=\"op\">=<\/span><span class=\"va\">False<\/span>)<\/span>\n<span id=\"cb9-10\"><a href=\"#cb9-10\" aria-hidden=\"true\"><\/a>ax.plot(m_t_sim, label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$E[\\theta_t \\mid D_<\/span><span class=\"sc\">{t}<\/span><span class=\"vs\">]$&#39;<\/span>, alpha<span 
class=\"op\">=<\/span><span class=\"fl\">0.9<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">0.8<\/span>)<\/span>\n<span id=\"cb9-11\"><a href=\"#cb9-11\" aria-hidden=\"true\"><\/a>ax.plot(s_t_sim, label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$E[\\theta_t \\mid D_<\/span><span class=\"sc\">{T}<\/span><span class=\"vs\">]$&#39;<\/span>, alpha<span class=\"op\">=<\/span><span class=\"fl\">0.9<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">0.8<\/span>)<\/span>\n<span id=\"cb9-12\"><a href=\"#cb9-12\" aria-hidden=\"true\"><\/a>plt.legend(framealpha<span class=\"op\">=<\/span><span class=\"fl\">0.4<\/span>)<\/span>\n<span id=\"cb9-13\"><a href=\"#cb9-13\" aria-hidden=\"true\"><\/a>plt.tight_layout()<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 9\n<\/figcaption>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/svd-steps-sim-plot.png\" title=\"fig:\" alt=\"Filtered and smoothed \\theta_t\u2013against the true \\theta_t\u2013computed using the SVD approach. 
\" \/>\n<figcaption>\nFiltered and smoothed means of <span class=\"math inline\">\\(\\theta_t\\)<\/span>, plotted against the true <span class=\"math inline\">\\(\\theta_t\\)<\/span> and computed using the SVD approach.\n<\/figcaption>\n<\/figure>\n<p><a id=\"orgaa30fc5\"><\/a><\/p>\n<\/section>\n<\/section>\n<section id=\"forward-filtering-backward-sampling\" class=\"level1\">\n<h1>Forward-filtering Backward-sampling<\/h1>\n<p>We can reuse the filtering and smoothing steps from the previous section for MCMC estimation; the Rao-Blackwellization inherent to both steps makes this estimation more efficient than it would otherwise be.<\/p>\n<p>Forward-filtering backward-sampling <a id=\"66fb99775f308e808a193bd7bb2d2038\"><a href=\"#Fruhwirth-SchnatterDataaugmentationdynamic1994\">(Fr\u00fchwirth-Schnatter 1994)<\/a><\/a> works by first computing the forward filtered moments, which allows one to draw <span class=\"math inline\">\\(\\theta_T\\)<\/span> from <span class=\"math inline\">\\(\\left(\\theta_T \\mid D_T\\right) \\sim \\operatorname{N}\\left(m_T, C_T\\right)\\)<\/span> and, subsequently, each <span class=\"math inline\">\\(\\theta_t\\)<\/span> from <span class=\"math inline\">\\(\\left(\\theta_t \\mid \\theta_{t+1}, D_T \\right) \\sim \\operatorname{N}\\left(h_t, H_t\\right)\\)<\/span>, moving backward in time.<\/p>\n<p>The latter distribution\u2019s moments are easily derived from the filtered and smoothed moments:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{gathered}\n    h_t = m_t + B_t \\left(\\theta_{t+1} - a_{t+1}\\right)\n    \\\\\n    H_t = C_t - B_t R_{t+1} B^\\top_t\n  \\end{gathered}\n  \\label{eq:ffbs-moments}\n\\end{equation}\\]<\/span><\/p>\n<p>Here, <span class=\"math inline\">\\(B_t = C_t G^\\top_{t+1} R^{-1}_{t+1}\\)<\/span>. Since all the quantities in <span class=\"math inline\">\\(\\eqref{eq:ffbs-moments}\\)<\/span> appear in the filtering and smoothing moments, we can use the SVD-based approach described earlier to perform the updates and sampling. 
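<\/p>\n<p>The following sketch shows why the SVD-based approach applies to <span class=\"math inline\">\\(H_t\\)<\/span>. It assumes the standard prior covariance <span class=\"math inline\">\\(R_{t+1} = G_{t+1} C_t G^\\top_{t+1} + W_{t+1}\\)<\/span> with invertible <span class=\"math inline\">\\(C_t\\)<\/span> and <span class=\"math inline\">\\(W_{t+1}\\)<\/span>. It also writes <span class=\"math inline\">\\(C_t = U_{C_t} S_{C_t}^2 U_{C_t}^\\top\\)<\/span> for the SVD of <span class=\"math inline\">\\(C_t\\)<\/span> and uses a square-root factor <span class=\"math inline\">\\(N_{W_{t+1}}\\)<\/span> with <span class=\"math inline\">\\(N_{W_{t+1}}^\\top N_{W_{t+1}} = W_{t+1}^{-1}\\)<\/span>. These symbols are our own labels for the objects named <code>U_C_t<\/code>, <code>S_C_t<\/code>, and <code>N_W_tp1_inv<\/code> in Listing <a href=\"#org6d942a3\">10<\/a>. Substituting <span class=\"math inline\">\\(B_t = C_t G^\\top_{t+1} R^{-1}_{t+1}\\)<\/span> into <span class=\"math inline\">\\(\\eqref{eq:ffbs-moments}\\)<\/span> and applying the Woodbury identity puts <span class=\"math inline\">\\(H_t\\)<\/span> in information form. That form factors through a stacked matrix <span class=\"math inline\">\\(N_{H_t}\\)<\/span>:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{aligned}\n    H_t &amp;= C_t - C_t G^\\top_{t+1} R^{-1}_{t+1} G_{t+1} C_t\n    = \\left( C_t^{-1} + G^\\top_{t+1} W^{-1}_{t+1} G_{t+1} \\right)^{-1}\n    \\\\\n    H_t^{-1} &amp;= U_{C_t} \\left( U_{C_t}^\\top G^\\top_{t+1} W^{-1}_{t+1} G_{t+1} U_{C_t} + S_{C_t}^{-2} \\right) U_{C_t}^\\top\n    = U_{C_t} N_{H_t}^\\top N_{H_t} U_{C_t}^\\top ,\n    \\quad\n    N_{H_t} = \\begin{bmatrix} N_{W_{t+1}} G_{t+1} U_{C_t} \\\\ S_{C_t}^{-1} \\end{bmatrix}\n    \\\\\n    H_t &amp;= \\left( U_{C_t} V \\right) D^{-2} \\left( U_{C_t} V \\right)^\\top ,\n    \\quad \\text{where} \\quad N_{H_t} = U D V^\\top\n  \\end{aligned}\n  \\label{eq:ffbs-precision}\n\\end{equation}\\]<\/span><\/p>\n<p>Under these assumptions, <span class=\"math inline\">\\(U_{C_t} V\\)<\/span> and <span class=\"math inline\">\\(D^{-1}\\)<\/span> correspond to <code>U_H_t<\/code> and <code>s_H_t<\/code> in Listing <a href=\"#org6d942a3\">10<\/a>. The draw <span class=\"math inline\">\\(\\theta_t = h_t + U_{C_t} V D^{-1} z\\)<\/span>, with <span class=\"math inline\">\\(z \\sim \\operatorname{N}\\left(0, I\\right)\\)<\/span>, therefore has covariance <span class=\"math inline\">\\(H_t\\)<\/span> without <span class=\"math inline\">\\(H_t\\)<\/span> ever being formed or inverted explicitly.<\/p>\n<p>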
We reproduce the relevant subset of calculations in Listing <a href=\"#org6d942a3\">10<\/a>.<\/p>\n<figure id=\"org6d942a3\">\n<div class=\"sourceCode\" id=\"cb10\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb10-1\"><a href=\"#cb10-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> ffbs_step(m_t, U_C_t, S_C_t, a_tp1, U_R_tp1, S_R_tp1, theta_tp1, F_tp1, G_tp1, N_W_tp1_inv, rng):<\/span>\n<span id=\"cb10-2\"><a href=\"#cb10-2\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;Sample theta_t given theta_tp1 (one backward step of FFBS).&quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb10-3\"><a href=\"#cb10-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-4\"><a href=\"#cb10-4\" aria-hidden=\"true\"><\/a>    S_C_t_inv <span class=\"op\">=<\/span> tt_finite_inv(S_C_t)<\/span>\n<span id=\"cb10-5\"><a href=\"#cb10-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-6\"><a href=\"#cb10-6\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># H_t_inv = U_C_t @ N_H_t_inv.T @ N_H_t_inv @ U_C_t.T = G_tp1.T @ W_tp1_inv @ G_tp1 + C_t_inv<\/span><\/span>\n<span id=\"cb10-7\"><a href=\"#cb10-7\" aria-hidden=\"true\"><\/a>    N_H_t_inv <span class=\"op\">=<\/span> tt.join(<span class=\"dv\">0<\/span>,<\/span>\n<span id=\"cb10-8\"><a href=\"#cb10-8\" aria-hidden=\"true\"><\/a>                        matrix_dot(N_W_tp1_inv, G_tp1, U_C_t),<\/span>\n<span id=\"cb10-9\"><a href=\"#cb10-9\" aria-hidden=\"true\"><\/a>                        S_C_t_inv)<\/span>\n<span id=\"cb10-10\"><a href=\"#cb10-10\" aria-hidden=\"true\"><\/a>    _, d_N_H_t_inv, V_N_H_t_inv_T <span class=\"op\">=<\/span> svd(N_H_t_inv)<\/span>\n<span id=\"cb10-11\"><a href=\"#cb10-11\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-12\"><a href=\"#cb10-12\" aria-hidden=\"true\"><\/a>    U_H_t <span class=\"op\">=<\/span> U_C_t.dot(V_N_H_t_inv_T.T)<\/span>\n<span id=\"cb10-13\"><a href=\"#cb10-13\" aria-hidden=\"true\"><\/a>    s_H_t <span 
class=\"op\">=<\/span> tt_finite_inv(d_N_H_t_inv)<\/span>\n<span id=\"cb10-14\"><a href=\"#cb10-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-15\"><a href=\"#cb10-15\" aria-hidden=\"true\"><\/a>    N_C_t <span class=\"op\">=<\/span> S_C_t.dot(U_C_t.T)<\/span>\n<span id=\"cb10-16\"><a href=\"#cb10-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-17\"><a href=\"#cb10-17\" aria-hidden=\"true\"><\/a>    S_R_tp1_inv <span class=\"op\">=<\/span> tt_finite_inv(S_R_tp1)<\/span>\n<span id=\"cb10-18\"><a href=\"#cb10-18\" aria-hidden=\"true\"><\/a>    N_R_tp1_inv <span class=\"op\">=<\/span> S_R_tp1_inv.dot(U_R_tp1.T)<\/span>\n<span id=\"cb10-19\"><a href=\"#cb10-19\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-20\"><a href=\"#cb10-20\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># B_t = C_t @ G_tp1.T @ inv(R_tp1)<\/span><\/span>\n<span id=\"cb10-21\"><a href=\"#cb10-21\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># B_t = matrix_dot(U_H_t * s_H_t, s_H_t * U_H_t.T,<\/span><\/span>\n<span id=\"cb10-22\"><a href=\"#cb10-22\" aria-hidden=\"true\"><\/a>    <span class=\"co\">#                  G_tp1.T, N_W_tp1_inv.T, N_W_tp1_inv)<\/span><\/span>\n<span id=\"cb10-23\"><a href=\"#cb10-23\" aria-hidden=\"true\"><\/a>    B_t <span class=\"op\">=<\/span> matrix_dot(N_C_t.T, N_C_t, G_tp1.T, N_R_tp1_inv.T, N_R_tp1_inv)<\/span>\n<span id=\"cb10-24\"><a href=\"#cb10-24\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-25\"><a href=\"#cb10-25\" aria-hidden=\"true\"><\/a>    h_t <span class=\"op\">=<\/span> m_t <span class=\"op\">+<\/span> B_t.dot(theta_tp1 <span class=\"op\">-<\/span> a_tp1)<\/span>\n<span id=\"cb10-26\"><a href=\"#cb10-26\" aria-hidden=\"true\"><\/a>    h_t.name <span class=\"op\">=<\/span> <span class=\"st\">&#39;h_t&#39;<\/span><\/span>\n<span id=\"cb10-27\"><a href=\"#cb10-27\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-28\"><a href=\"#cb10-28\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># <\/span><span 
class=\"al\">TODO<\/span><span class=\"co\">: Add an option or optimization to use the SVD to sample in<\/span><\/span>\n<span id=\"cb10-29\"><a href=\"#cb10-29\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># `MvNormalRV`.<\/span><\/span>\n<span id=\"cb10-30\"><a href=\"#cb10-30\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># theta_t = MvNormalRV(h_t, H_t, rng=rng, name=&#39;theta_t_ffbs&#39;)<\/span><\/span>\n<span id=\"cb10-31\"><a href=\"#cb10-31\" aria-hidden=\"true\"><\/a>    theta_t <span class=\"op\">=<\/span> h_t <span class=\"op\">+<\/span> tt.dot(U_H_t, s_H_t <span class=\"op\">*<\/span><\/span>\n<span id=\"cb10-32\"><a href=\"#cb10-32\" aria-hidden=\"true\"><\/a>                           MvNormalRV(tt.zeros_like(h_t),<\/span>\n<span id=\"cb10-33\"><a href=\"#cb10-33\" aria-hidden=\"true\"><\/a>                                      tt.eye(h_t.shape[<span class=\"dv\">0<\/span>]),<\/span>\n<span id=\"cb10-34\"><a href=\"#cb10-34\" aria-hidden=\"true\"><\/a>                                      rng<span class=\"op\">=<\/span>rng))<\/span>\n<span id=\"cb10-35\"><a href=\"#cb10-35\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-36\"><a href=\"#cb10-36\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># These are statistics we&#39;re gathering for other posterior updates<\/span><\/span>\n<span id=\"cb10-37\"><a href=\"#cb10-37\" aria-hidden=\"true\"><\/a>    theta_tp1_diff <span class=\"op\">=<\/span> theta_tp1 <span class=\"op\">-<\/span> G_tp1.dot(theta_t)<\/span>\n<span id=\"cb10-38\"><a href=\"#cb10-38\" aria-hidden=\"true\"><\/a>    f_tp1 <span class=\"op\">=<\/span> F_tp1.T.dot(theta_t)<\/span>\n<span id=\"cb10-39\"><a href=\"#cb10-39\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-40\"><a href=\"#cb10-40\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Sequentially sample\/update quantities conditional on `theta_t` here...<\/span><\/span>\n<span id=\"cb10-41\"><a href=\"#cb10-41\" aria-hidden=\"true\"><\/a><\/span>\n<span 
id=\"cb10-42\"><a href=\"#cb10-42\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> [theta_t, theta_tp1_diff, f_tp1]<\/span>\n<span id=\"cb10-43\"><a href=\"#cb10-43\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-44\"><a href=\"#cb10-44\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-45\"><a href=\"#cb10-45\" aria-hidden=\"true\"><\/a><span class=\"co\"># C_T = matrix_dot(U_C_T, tt.square(S_C_T), U_C_T.T)<\/span><\/span>\n<span id=\"cb10-46\"><a href=\"#cb10-46\" aria-hidden=\"true\"><\/a><span class=\"co\"># theta_T_post = MvNormalRV(m_T, C_T, rng=rng_tt)<\/span><\/span>\n<span id=\"cb10-47\"><a href=\"#cb10-47\" aria-hidden=\"true\"><\/a>theta_T_post <span class=\"op\">=<\/span> m_T <span class=\"op\">+<\/span> matrix_dot(U_C_T, S_C_T,<\/span>\n<span id=\"cb10-48\"><a href=\"#cb10-48\" aria-hidden=\"true\"><\/a>                                MvNormalRV(tt.zeros_like(m_T),<\/span>\n<span id=\"cb10-49\"><a href=\"#cb10-49\" aria-hidden=\"true\"><\/a>                                           tt.eye(m_T.shape[<span class=\"dv\">0<\/span>]),<\/span>\n<span id=\"cb10-50\"><a href=\"#cb10-50\" aria-hidden=\"true\"><\/a>                                           rng<span class=\"op\">=<\/span>rng_tt))<\/span>\n<span id=\"cb10-51\"><a href=\"#cb10-51\" aria-hidden=\"true\"><\/a>theta_T_post.name <span class=\"op\">=<\/span> <span class=\"st\">&quot;theta_T_post&quot;<\/span><\/span>\n<span id=\"cb10-52\"><a href=\"#cb10-52\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-53\"><a href=\"#cb10-53\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-54\"><a href=\"#cb10-54\" aria-hidden=\"true\"><\/a>ffbs_output, ffbs_updates <span class=\"op\">=<\/span> theano.scan(fn<span class=\"op\">=<\/span>ffbs_step,<\/span>\n<span id=\"cb10-55\"><a href=\"#cb10-55\" aria-hidden=\"true\"><\/a>                                        sequences<span class=\"op\">=<\/span>[<\/span>\n<span id=\"cb10-56\"><a href=\"#cb10-56\" aria-hidden=\"true\"><\/a>    
                                        {<span class=\"st\">&quot;input&quot;<\/span>: m_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb10-57\"><a href=\"#cb10-57\" aria-hidden=\"true\"><\/a>                                            {<span class=\"st\">&quot;input&quot;<\/span>: U_C_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb10-58\"><a href=\"#cb10-58\" aria-hidden=\"true\"><\/a>                                            {<span class=\"st\">&quot;input&quot;<\/span>: S_C_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb10-59\"><a href=\"#cb10-59\" aria-hidden=\"true\"><\/a>                                            {<span class=\"st\">&quot;input&quot;<\/span>: a_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb10-60\"><a href=\"#cb10-60\" aria-hidden=\"true\"><\/a>                                            {<span class=\"st\">&quot;input&quot;<\/span>: U_R_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb10-61\"><a href=\"#cb10-61\" aria-hidden=\"true\"><\/a>                                            {<span class=\"st\">&quot;input&quot;<\/span>: S_R_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"dv\">1<\/span>]}<\/span>\n<span id=\"cb10-62\"><a href=\"#cb10-62\" aria-hidden=\"true\"><\/a>                                        ],<\/span>\n<span id=\"cb10-63\"><a href=\"#cb10-63\" aria-hidden=\"true\"><\/a>                                        outputs_info<span class=\"op\">=<\/span>[<\/span>\n<span id=\"cb10-64\"><a href=\"#cb10-64\" aria-hidden=\"true\"><\/a>                                            {<span class=\"st\">&quot;initial&quot;<\/span>: theta_T_post, <span 
class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb10-65\"><a href=\"#cb10-65\" aria-hidden=\"true\"><\/a>                                            {}, {}, <span class=\"co\"># theta_tp1_diff, f_tp1<\/span><\/span>\n<span id=\"cb10-66\"><a href=\"#cb10-66\" aria-hidden=\"true\"><\/a>                                        ],<\/span>\n<span id=\"cb10-67\"><a href=\"#cb10-67\" aria-hidden=\"true\"><\/a>                                        non_sequences<span class=\"op\">=<\/span>[F_tt, G_tt, N_W_inv_tt, rng_tt],<\/span>\n<span id=\"cb10-68\"><a href=\"#cb10-68\" aria-hidden=\"true\"><\/a>                                        go_backwards<span class=\"op\">=<\/span><span class=\"va\">True<\/span>,<\/span>\n<span id=\"cb10-69\"><a href=\"#cb10-69\" aria-hidden=\"true\"><\/a>                                        strict<span class=\"op\">=<\/span><span class=\"va\">True<\/span>,<\/span>\n<span id=\"cb10-70\"><a href=\"#cb10-70\" aria-hidden=\"true\"><\/a>                                        name<span class=\"op\">=<\/span><span class=\"st\">&#39;ffbs_samples&#39;<\/span>)<\/span>\n<span id=\"cb10-71\"><a href=\"#cb10-71\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-72\"><a href=\"#cb10-72\" aria-hidden=\"true\"><\/a>(theta_t_post_rev, theta_t_diff_rev, f_t_rev) <span class=\"op\">=<\/span> ffbs_output<\/span>\n<span id=\"cb10-73\"><a href=\"#cb10-73\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-74\"><a href=\"#cb10-74\" aria-hidden=\"true\"><\/a>theta_t_post <span class=\"op\">=<\/span> tt.join(<span class=\"dv\">0<\/span>, theta_t_post_rev[::<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>], [theta_T_post])<\/span>\n<span id=\"cb10-75\"><a href=\"#cb10-75\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-76\"><a href=\"#cb10-76\" aria-hidden=\"true\"><\/a><span class=\"co\"># We need to add the missing end-points onto these statistics...<\/span><\/span>\n<span 
id=\"cb10-77\"><a href=\"#cb10-77\" aria-hidden=\"true\"><\/a>f_t_post <span class=\"op\">=<\/span> tt.join(<span class=\"dv\">0<\/span>, f_t_rev[::<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>], [F_tt.T.dot(theta_T_post)])<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 10\n<\/figcaption>\n<\/figure>\n<p>Quantities other than the state values, <span class=\"math inline\">\\(\\theta_t\\)<\/span>, can be sampled either sequentially (i.e.\u00a0within the function <code>ffbs_step<\/code> in Listing <a href=\"#org6d942a3\">10<\/a>) or after FFBS, once all <span class=\"math inline\">\\(\\theta_t \\mid D_T\\)<\/span> have been sampled. Their samplers can exploit the conditionally Gaussian form of <span class=\"math inline\">\\(\\left(\\theta_t \\mid \\theta_{t+1}, D_T \\right)\\)<\/span> in several ways. They can use it to derive Gibbs steps, to further Rao-Blackwellize hierarchical quantities, or to produce posterior samples by any other method that conditions on these draws.<\/p>\n<p>In our simulation example, we further augment the original model by placing the classic conjugate gamma priors on the previously fixed state and observation precision parameters, <span class=\"math inline\">\\(\\phi_W\\)<\/span> and <span class=\"math inline\">\\(\\phi_V\\)<\/span>, respectively:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{gathered}\n    \\phi_{W} \\sim \\operatorname{Gamma}\\left( a_W, b_W \\right)\n    ,\\quad\n    \\phi_{V} \\sim \\operatorname{Gamma}\\left( a_V, b_V \\right)\n  \\end{gathered}\n  \\label{eq:phi-gamma-priors}\n\\end{equation}\\]<\/span><\/p>\n<p>Here, <span class=\"math inline\">\\(\\operatorname{Gamma}\\left(a, b\\right)\\)<\/span> denotes the shape-rate parameterization. These conjugate priors yield simple closed-form conditional posteriors for a Gibbs sampler, given <span class=\"math inline\">\\(y_t\\)<\/span>, <span class=\"math inline\">\\(\\theta_t\\)<\/span>, and <span class=\"math inline\">\\(\\theta_{t-1}\\)<\/span>:<\/p>\n<p><span class=\"math 
display\">\\[\\begin{equation}\n  \\begin{gathered}\n    \\phi_{W} \\mid \\theta_{0:N}, D_T \\sim\n    \\operatorname{Gamma}\\left( a_W + \\frac{N}{2}, b_W + \\frac{1}{2} \\sum_{t=1}^N \\left( \\theta_t - G_t \\theta_{t-1} \\right)^2 \\right)\n    ,\\quad\n    \\phi_{V} \\mid \\theta_{0:N}, D_T \\sim\n    \\operatorname{Gamma}\\left( a_V + \\frac{N}{2}, b_V + \\frac{1}{2} \\sum_{t=1}^N \\left( y_t - F_t^\\top \\theta_t \\right)^2 \\right)\n  \\end{gathered}\n  \\label{eq:phi-gamma-posteriors}\n\\end{equation}\\]<\/span><\/p>\n<p>These posterior computations are implemented in Listings <a href=\"#org4e5c0a0\">11<\/a> and <a href=\"#org1156a41\">12<\/a>. Listing <a href=\"#org005b0cb\">13<\/a> then uses them to update the shared Theano variables for <span class=\"math inline\">\\(\\phi_W\\)<\/span> and <span class=\"math inline\">\\(\\phi_V\\)<\/span> within a Gibbs sampling loop.<\/p>\n<p><a id=\"org68ce824\"><\/a><\/p>\n<section id=\"simulation-example\" class=\"level2\">\n<h2>Simulation Example<\/h2>\n<figure id=\"org4e5c0a0\">\n<div class=\"sourceCode\" id=\"cb11\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb11-1\"><a href=\"#cb11-1\" aria-hidden=\"true\"><\/a>phi_W_a, phi_W_b <span class=\"op\">=<\/span> theano.shared(np.r_[<span class=\"fl\">2.5<\/span>, <span class=\"fl\">2.5<\/span>]), theano.shared(np.r_[<span class=\"fl\">0.5<\/span>, <span class=\"fl\">0.5<\/span>])<\/span>\n<span id=\"cb11-2\"><a href=\"#cb11-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-3\"><a href=\"#cb11-3\" aria-hidden=\"true\"><\/a>phi_W_a_post_tt <span class=\"op\">=<\/span> phi_W_a <span class=\"op\">+<\/span> N_obs_tt <span class=\"op\">*<\/span> <span class=\"fl\">0.5<\/span><\/span>\n<span id=\"cb11-4\"><a href=\"#cb11-4\" aria-hidden=\"true\"><\/a>phi_W_SS_tt <span class=\"op\">=<\/span> tt.square(theta_t_diff_rev).<span class=\"bu\">sum<\/span>(<span class=\"dv\">0<\/span>)<\/span>\n<span id=\"cb11-5\"><a href=\"#cb11-5\" 
aria-hidden=\"true\"><\/a>phi_W_b_post_tt <span class=\"op\">=<\/span> phi_W_b <span class=\"op\">+<\/span> <span class=\"fl\">0.5<\/span> <span class=\"op\">*<\/span> phi_W_SS_tt<\/span>\n<span id=\"cb11-6\"><a href=\"#cb11-6\" aria-hidden=\"true\"><\/a>phi_W_post_tt <span class=\"op\">=<\/span> GammaRV(phi_W_a_post_tt, phi_W_b_post_tt, rng<span class=\"op\">=<\/span>rng_tt, name<span class=\"op\">=<\/span><span class=\"st\">&#39;phi_W_post&#39;<\/span>)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 11\n<\/figcaption>\n<\/figure>\n<figure id=\"org1156a41\">\n<div class=\"sourceCode\" id=\"cb12\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb12-1\"><a href=\"#cb12-1\" aria-hidden=\"true\"><\/a>phi_V_a, phi_V_b <span class=\"op\">=<\/span> theano.shared(<span class=\"fl\">0.125<\/span>), theano.shared(<span class=\"fl\">0.25<\/span>)<\/span>\n<span id=\"cb12-2\"><a href=\"#cb12-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb12-3\"><a href=\"#cb12-3\" aria-hidden=\"true\"><\/a>phi_V_a_post_tt <span class=\"op\">=<\/span> phi_V_a <span class=\"op\">+<\/span> N_obs_tt <span class=\"op\">*<\/span> <span class=\"fl\">0.5<\/span><\/span>\n<span id=\"cb12-4\"><a href=\"#cb12-4\" aria-hidden=\"true\"><\/a>phi_V_SS_tt <span class=\"op\">=<\/span> tt.square(y_tt <span class=\"op\">-<\/span> f_t_post).<span class=\"bu\">sum<\/span>()<\/span>\n<span id=\"cb12-5\"><a href=\"#cb12-5\" aria-hidden=\"true\"><\/a>phi_V_b_post_tt <span class=\"op\">=<\/span> phi_V_b <span class=\"op\">+<\/span> <span class=\"fl\">0.5<\/span> <span class=\"op\">*<\/span> phi_V_SS_tt<\/span>\n<span id=\"cb12-6\"><a href=\"#cb12-6\" aria-hidden=\"true\"><\/a>phi_V_post_tt <span class=\"op\">=<\/span> GammaRV(phi_V_a_post_tt, phi_V_b_post_tt, rng<span class=\"op\">=<\/span>rng_tt, name<span class=\"op\">=<\/span><span class=\"st\">&#39;phi_V_post&#39;<\/span>)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 12\n<\/figcaption>\n<\/figure>\n<figure 
id=\"org005b0cb\">\n<div class=\"sourceCode\" id=\"cb13\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb13-1\"><a href=\"#cb13-1\" aria-hidden=\"true\"><\/a>ffbs_dlm <span class=\"op\">=<\/span> tt_function([y_tt, N_obs_tt, N_theta_tt, G_tt, F_tt],<\/span>\n<span id=\"cb13-2\"><a href=\"#cb13-2\" aria-hidden=\"true\"><\/a>                       [theta_t_post, phi_W_post_tt, phi_V_post_tt,<\/span>\n<span id=\"cb13-3\"><a href=\"#cb13-3\" aria-hidden=\"true\"><\/a>                        phi_W_SS_tt, phi_V_SS_tt],<\/span>\n<span id=\"cb13-4\"><a href=\"#cb13-4\" aria-hidden=\"true\"><\/a>                       updates<span class=\"op\">=<\/span>ffbs_updates)<\/span>\n<span id=\"cb13-5\"><a href=\"#cb13-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-6\"><a href=\"#cb13-6\" aria-hidden=\"true\"><\/a>rng_tt.get_value(borrow<span class=\"op\">=<\/span><span class=\"va\">True<\/span>).set_state(rng_sim_state)<\/span>\n<span id=\"cb13-7\"><a href=\"#cb13-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-8\"><a href=\"#cb13-8\" aria-hidden=\"true\"><\/a>phi_W_0 <span class=\"op\">=<\/span> phi_W_a.get_value()<span class=\"op\">\/<\/span>phi_W_b.get_value()<\/span>\n<span id=\"cb13-9\"><a href=\"#cb13-9\" aria-hidden=\"true\"><\/a>phi_V_0 <span class=\"op\">=<\/span> phi_V_a.get_value()<span class=\"op\">\/<\/span>phi_V_b.get_value()<\/span>\n<span id=\"cb13-10\"><a href=\"#cb13-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-11\"><a href=\"#cb13-11\" aria-hidden=\"true\"><\/a>phi_W_tt.set_value(phi_W_0)<\/span>\n<span id=\"cb13-12\"><a href=\"#cb13-12\" aria-hidden=\"true\"><\/a>phi_V_tt.set_value(phi_V_0)<\/span>\n<span id=\"cb13-13\"><a href=\"#cb13-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-14\"><a href=\"#cb13-14\" aria-hidden=\"true\"><\/a>chain <span class=\"op\">=<\/span> <span class=\"dv\">0<\/span><\/span>\n<span id=\"cb13-15\"><a href=\"#cb13-15\" aria-hidden=\"true\"><\/a>theta_label <span 
class=\"op\">=<\/span> <span class=\"vs\">r&#39;$\\theta_t \\mid D_T$&#39;<\/span><\/span>\n<span id=\"cb13-16\"><a href=\"#cb13-16\" aria-hidden=\"true\"><\/a>phi_W_label <span class=\"op\">=<\/span> <span class=\"vs\">r&#39;$\\phi_W \\mid D_T$&#39;<\/span><\/span>\n<span id=\"cb13-17\"><a href=\"#cb13-17\" aria-hidden=\"true\"><\/a>phi_V_label <span class=\"op\">=<\/span> <span class=\"vs\">r&#39;$\\phi_V \\mid D_T$&#39;<\/span><\/span>\n<span id=\"cb13-18\"><a href=\"#cb13-18\" aria-hidden=\"true\"><\/a>theta_t_post_sim, phi_W_post_sim, phi_V_post_sim <span class=\"op\">=<\/span> <span class=\"va\">None<\/span>, <span class=\"va\">None<\/span>, <span class=\"va\">None<\/span><\/span>\n<span id=\"cb13-19\"><a href=\"#cb13-19\" aria-hidden=\"true\"><\/a>posterior_samples <span class=\"op\">=<\/span> {theta_label: [[]], phi_W_label: [[]], phi_V_label: [[]]}<\/span>\n<span id=\"cb13-20\"><a href=\"#cb13-20\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-21\"><a href=\"#cb13-21\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> i <span class=\"kw\">in<\/span> <span class=\"bu\">range<\/span>(<span class=\"dv\">1000<\/span>):<\/span>\n<span id=\"cb13-22\"><a href=\"#cb13-22\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-23\"><a href=\"#cb13-23\" aria-hidden=\"true\"><\/a>    theta_t_post_sim, phi_W_post_sim, phi_V_post_sim, phi_W_SS_sim, phi_V_SS_sim <span class=\"op\">=<\/span> ffbs_dlm(<\/span>\n<span id=\"cb13-24\"><a href=\"#cb13-24\" aria-hidden=\"true\"><\/a>        y_sim,<\/span>\n<span id=\"cb13-25\"><a href=\"#cb13-25\" aria-hidden=\"true\"><\/a>        dlm_sim_values[N_obs_tt], dlm_sim_values[N_theta_tt],<\/span>\n<span id=\"cb13-26\"><a href=\"#cb13-26\" aria-hidden=\"true\"><\/a>        dlm_sim_values[G_tt], dlm_sim_values[F_tt])<\/span>\n<span id=\"cb13-27\"><a href=\"#cb13-27\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-28\"><a href=\"#cb13-28\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Update the state and observation precision 
parameters<\/span><\/span>\n<span id=\"cb13-29\"><a href=\"#cb13-29\" aria-hidden=\"true\"><\/a>    phi_W_tt.set_value(phi_W_post_sim)<\/span>\n<span id=\"cb13-30\"><a href=\"#cb13-30\" aria-hidden=\"true\"><\/a>    phi_V_tt.set_value(phi_V_post_sim)<\/span>\n<span id=\"cb13-31\"><a href=\"#cb13-31\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-32\"><a href=\"#cb13-32\" aria-hidden=\"true\"><\/a>    posterior_samples[theta_label][chain].append(theta_t_post_sim)<\/span>\n<span id=\"cb13-33\"><a href=\"#cb13-33\" aria-hidden=\"true\"><\/a>    posterior_samples[phi_W_label][chain].append(phi_W_post_sim)<\/span>\n<span id=\"cb13-34\"><a href=\"#cb13-34\" aria-hidden=\"true\"><\/a>    posterior_samples[phi_V_label][chain].append(phi_V_post_sim)<\/span>\n<span id=\"cb13-35\"><a href=\"#cb13-35\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-36\"><a href=\"#cb13-36\" aria-hidden=\"true\"><\/a>    <span class=\"bu\">print<\/span>(<span class=\"ss\">f&#39;i=<\/span><span class=\"sc\">{i}<\/span><span class=\"ss\">,<\/span><span class=\"ch\">\\t<\/span><span class=\"ss\">phi_W=<\/span><span class=\"sc\">{<\/span>phi_W_post_sim<span class=\"sc\">}<\/span><span class=\"ch\">\\t<\/span><span class=\"ss\">(<\/span><span class=\"sc\">{<\/span>phi_W_SS_sim<span class=\"sc\">}<\/span><span class=\"ss\">),<\/span><span class=\"ch\">\\t<\/span><span class=\"ss\">phi_V=<\/span><span class=\"sc\">{<\/span>phi_V_post_sim<span class=\"sc\">}<\/span><span class=\"ss\"> (<\/span><span class=\"sc\">{<\/span>phi_V_SS_sim<span class=\"sc\">}<\/span><span class=\"ss\">)&#39;<\/span>)<\/span>\n<span id=\"cb13-37\"><a href=\"#cb13-37\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-38\"><a href=\"#cb13-38\" aria-hidden=\"true\"><\/a>posterior_samples <span class=\"op\">=<\/span> {k: np.asarray(v) <span class=\"cf\">for<\/span> k,v <span class=\"kw\">in<\/span> posterior_samples.items()}<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 
13\n<\/figcaption>\n<\/figure>\n<p>Figure <a href=\"#org010a945\">14<\/a> shows the posterior <span class=\"math inline\">\\(\\theta_t\\)<\/span> samples and Figure <a href=\"#org968df9e\">15<\/a> plots the posterior sample traces.<\/p>\n<figure id=\"org010a945\">\n<div class=\"sourceCode\" id=\"cb14\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb14-1\"><a href=\"#cb14-1\" aria-hidden=\"true\"><\/a>plt.clf()<\/span>\n<span id=\"cb14-2\"><a href=\"#cb14-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-3\"><a href=\"#cb14-3\" aria-hidden=\"true\"><\/a>fig, ax <span class=\"op\">=<\/span> plt.subplots(figsize<span class=\"op\">=<\/span>(<span class=\"dv\">8<\/span>, <span class=\"fl\">4.8<\/span>))<\/span>\n<span id=\"cb14-4\"><a href=\"#cb14-4\" aria-hidden=\"true\"><\/a>ax.autoscale(enable<span class=\"op\">=<\/span><span class=\"va\">False<\/span>)<\/span>\n<span id=\"cb14-5\"><a href=\"#cb14-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-6\"><a href=\"#cb14-6\" aria-hidden=\"true\"><\/a><span class=\"co\"># bivariate_cycler =  cycler(&#39;linestyle&#39;, [&#39;-&#39;, &#39;--&#39;]) * plt_orig_cycler<\/span><\/span>\n<span id=\"cb14-7\"><a href=\"#cb14-7\" aria-hidden=\"true\"><\/a><span class=\"co\"># ax.set_prop_cycle(bivariate_cycler)<\/span><\/span>\n<span id=\"cb14-8\"><a href=\"#cb14-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-9\"><a href=\"#cb14-9\" aria-hidden=\"true\"><\/a>thetas_shape <span class=\"op\">=<\/span> posterior_samples[theta_label][<span class=\"dv\">0<\/span>].shape<\/span>\n<span id=\"cb14-10\"><a href=\"#cb14-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-11\"><a href=\"#cb14-11\" aria-hidden=\"true\"><\/a>cycle <span class=\"op\">=<\/span> ax._get_lines.prop_cycler<\/span>\n<span id=\"cb14-12\"><a href=\"#cb14-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-13\"><a href=\"#cb14-13\" aria-hidden=\"true\"><\/a>bivariate_obs_cycler <span class=\"op\">=<\/span>  
cycler(<span class=\"st\">&#39;linestyle&#39;<\/span>, [<span class=\"st\">&#39;-&#39;<\/span>, <span class=\"st\">&#39;--&#39;<\/span>]) <span class=\"op\">*<\/span> cycler(<span class=\"st\">&#39;color&#39;<\/span>, [<span class=\"st\">&#39;black&#39;<\/span>])<\/span>\n<span id=\"cb14-14\"><a href=\"#cb14-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-15\"><a href=\"#cb14-15\" aria-hidden=\"true\"><\/a>ax.set_prop_cycle(bivariate_obs_cycler)<\/span>\n<span id=\"cb14-16\"><a href=\"#cb14-16\" aria-hidden=\"true\"><\/a>ax.plot(theta_t_sim, label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$\\theta_t$&#39;<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">1.0<\/span>)<\/span>\n<span id=\"cb14-17\"><a href=\"#cb14-17\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-18\"><a href=\"#cb14-18\" aria-hidden=\"true\"><\/a>ax.autoscale(enable<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb14-19\"><a href=\"#cb14-19\" aria-hidden=\"true\"><\/a>ax.autoscale(enable<span class=\"op\">=<\/span><span class=\"va\">False<\/span>)<\/span>\n<span id=\"cb14-20\"><a href=\"#cb14-20\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-21\"><a href=\"#cb14-21\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> d <span class=\"kw\">in<\/span> <span class=\"bu\">range<\/span>(thetas_shape[<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]):<\/span>\n<span id=\"cb14-22\"><a href=\"#cb14-22\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-23\"><a href=\"#cb14-23\" aria-hidden=\"true\"><\/a>    styles <span class=\"op\">=<\/span> <span class=\"bu\">next<\/span>(cycle)<\/span>\n<span id=\"cb14-24\"><a href=\"#cb14-24\" aria-hidden=\"true\"><\/a>    thetas <span class=\"op\">=<\/span> posterior_samples[theta_label][<span class=\"dv\">0<\/span>].T[d].T<\/span>\n<span id=\"cb14-25\"><a href=\"#cb14-25\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-26\"><a href=\"#cb14-26\" 
aria-hidden=\"true\"><\/a>    theta_lines <span class=\"op\">=<\/span> np.empty(thetas_shape[:<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>] <span class=\"op\">+<\/span> (<span class=\"dv\">2<\/span>,))<\/span>\n<span id=\"cb14-27\"><a href=\"#cb14-27\" aria-hidden=\"true\"><\/a>    theta_lines.T[<span class=\"dv\">0<\/span>] <span class=\"op\">=<\/span> np.tile(np.arange(thetas_shape[<span class=\"op\">-<\/span><span class=\"dv\">2<\/span>]), [thetas_shape[<span class=\"op\">-<\/span><span class=\"dv\">3<\/span>], <span class=\"dv\">1<\/span>]).T<\/span>\n<span id=\"cb14-28\"><a href=\"#cb14-28\" aria-hidden=\"true\"><\/a>    theta_lines.T[<span class=\"dv\">1<\/span>] <span class=\"op\">=<\/span> thetas.T<\/span>\n<span id=\"cb14-29\"><a href=\"#cb14-29\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-30\"><a href=\"#cb14-30\" aria-hidden=\"true\"><\/a>    ax.add_collection(<\/span>\n<span id=\"cb14-31\"><a href=\"#cb14-31\" aria-hidden=\"true\"><\/a>        LineCollection(theta_lines,<\/span>\n<span id=\"cb14-32\"><a href=\"#cb14-32\" aria-hidden=\"true\"><\/a>                       label<span class=\"op\">=<\/span>theta_label,<\/span>\n<span id=\"cb14-33\"><a href=\"#cb14-33\" aria-hidden=\"true\"><\/a>                       alpha<span class=\"op\">=<\/span><span class=\"fl\">0.3<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">0.9<\/span>,<\/span>\n<span id=\"cb14-34\"><a href=\"#cb14-34\" aria-hidden=\"true\"><\/a>                       <span class=\"op\">**<\/span>styles)<\/span>\n<span id=\"cb14-35\"><a href=\"#cb14-35\" aria-hidden=\"true\"><\/a>    )<\/span>\n<span id=\"cb14-36\"><a href=\"#cb14-36\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-37\"><a href=\"#cb14-37\" aria-hidden=\"true\"><\/a>plt.tight_layout()<\/span>\n<span id=\"cb14-38\"><a href=\"#cb14-38\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-39\"><a href=\"#cb14-39\" aria-hidden=\"true\"><\/a>plt.legend(framealpha<span 
class=\"op\">=<\/span><span class=\"fl\">0.4<\/span>)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 14\n<\/figcaption>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/ffbs-sim-theta-plot.png\" title=\"fig:\" alt=\"Posterior \\theta_t samples generated by a FFBS-based Gibbs sampler. \" \/>\n<figcaption>\nPosterior <span class=\"math inline\">\\(\\theta_t\\)<\/span> samples generated by a FFBS-based Gibbs sampler.\n<\/figcaption>\n<\/figure>\n<figure id=\"org968df9e\">\n<div class=\"sourceCode\" id=\"cb15\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb15-1\"><a href=\"#cb15-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> arviz <span class=\"im\">as<\/span> az<\/span>\n<span id=\"cb15-2\"><a href=\"#cb15-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb15-3\"><a href=\"#cb15-3\" aria-hidden=\"true\"><\/a>az_trace <span class=\"op\">=<\/span> az.from_dict(posterior<span class=\"op\">=<\/span>posterior_samples)<\/span>\n<span id=\"cb15-4\"><a href=\"#cb15-4\" aria-hidden=\"true\"><\/a>az.plot_trace(az_trace, compact<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 15\n<\/figcaption>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/ffbs-sim-trace-plot.png\" title=\"fig:\" alt=\"Posterior sample traces for the FFBS-based Gibbs sampler. 
\" \/>\n<figcaption>\nPosterior sample traces for the FFBS-based Gibbs sampler.\n<\/figcaption>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" id=\"cb16\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb16-1\"><a href=\"#cb16-1\" aria-hidden=\"true\"><\/a>plt.close(<span class=\"st\">&#39;all&#39;<\/span>)<\/span>\n<span id=\"cb16-2\"><a href=\"#cb16-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb16-3\"><a href=\"#cb16-3\" aria-hidden=\"true\"><\/a>fig, ax <span class=\"op\">=<\/span> plt.subplots(figsize<span class=\"op\">=<\/span>(<span class=\"dv\">8<\/span>, <span class=\"fl\">4.8<\/span>))<\/span>\n<span id=\"cb16-4\"><a href=\"#cb16-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb16-5\"><a href=\"#cb16-5\" aria-hidden=\"true\"><\/a>ax.plot(y_sim, label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$y_t$&#39;<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">1.0<\/span>, color<span class=\"op\">=<\/span><span class=\"st\">&#39;black&#39;<\/span>)<\/span>\n<span id=\"cb16-6\"><a href=\"#cb16-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb16-7\"><a href=\"#cb16-7\" aria-hidden=\"true\"><\/a>ax.autoscale(enable<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb16-8\"><a href=\"#cb16-8\" aria-hidden=\"true\"><\/a>ax.autoscale(enable<span class=\"op\">=<\/span><span class=\"va\">False<\/span>)<\/span>\n<span id=\"cb16-9\"><a href=\"#cb16-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb16-10\"><a href=\"#cb16-10\" aria-hidden=\"true\"><\/a>f_t_ordinates <span class=\"op\">=<\/span> np.dot(posterior_samples[theta_label][<span class=\"dv\">0<\/span>], dlm_sim_values[F_tt].squeeze())<\/span>\n<span id=\"cb16-11\"><a href=\"#cb16-11\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb16-12\"><a href=\"#cb16-12\" aria-hidden=\"true\"><\/a>f_t_lines <span class=\"op\">=<\/span> np.empty(f_t_ordinates.shape <span class=\"op\">+<\/span> (<span 
class=\"dv\">2<\/span>,))<\/span>\n<span id=\"cb16-13\"><a href=\"#cb16-13\" aria-hidden=\"true\"><\/a>f_t_lines.T[<span class=\"dv\">0<\/span>] <span class=\"op\">=<\/span> np.tile(np.arange(f_t_ordinates.shape[<span class=\"dv\">1<\/span>]), [f_t_ordinates.shape[<span class=\"dv\">0<\/span>], <span class=\"dv\">1<\/span>]).T<\/span>\n<span id=\"cb16-14\"><a href=\"#cb16-14\" aria-hidden=\"true\"><\/a>f_t_lines.T[<span class=\"dv\">1<\/span>] <span class=\"op\">=<\/span> f_t_ordinates.T<\/span>\n<span id=\"cb16-15\"><a href=\"#cb16-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb16-16\"><a href=\"#cb16-16\" aria-hidden=\"true\"><\/a>ax.add_collection(<\/span>\n<span id=\"cb16-17\"><a href=\"#cb16-17\" aria-hidden=\"true\"><\/a>    LineCollection(f_t_lines,<\/span>\n<span id=\"cb16-18\"><a href=\"#cb16-18\" aria-hidden=\"true\"><\/a>                   label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$E[y_t \\mid \\theta_t, D_T]$&#39;<\/span>,<\/span>\n<span id=\"cb16-19\"><a href=\"#cb16-19\" aria-hidden=\"true\"><\/a>                   alpha<span class=\"op\">=<\/span><span class=\"fl\">0.3<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">0.9<\/span>, color<span class=\"op\">=<\/span><span class=\"st\">&#39;red&#39;<\/span>)<\/span>\n<span id=\"cb16-20\"><a href=\"#cb16-20\" aria-hidden=\"true\"><\/a>)<\/span>\n<span id=\"cb16-21\"><a href=\"#cb16-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb16-22\"><a href=\"#cb16-22\" aria-hidden=\"true\"><\/a>plt.tight_layout()<\/span>\n<span id=\"cb16-23\"><a href=\"#cb16-23\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb16-24\"><a href=\"#cb16-24\" aria-hidden=\"true\"><\/a>plt.legend(framealpha<span class=\"op\">=<\/span><span class=\"fl\">0.4<\/span>)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/ffbs-sim-pred-plot.png\" title=\"fig:\" alt=\"Posterior predicive sample means generated by a 
FFBS-based Gibbs sampler. \" \/>\n<figcaption>\nPosterior predicive sample means generated by a FFBS-based Gibbs sampler.\n<\/figcaption>\n<\/figure>\n<p><a id=\"orge6290b1\"><\/a><\/p>\n<\/section>\n<\/section>\n<section id=\"non-gaussian-extension\" class=\"level1\">\n<h1>Non-Gaussian Extension<\/h1>\n<p>Let\u2019s say we want to model count observations that are driven by a smooth time-varying process. We can assume a negative-binomial observation model with a log-link function\u2013in standard GLM fashion <a id=\"820d466dbb5494bd5a5cb080d0b82638\"><a href=\"#mccullagh_generalized_1989\">(McCullagh &amp; Nelder 1989)<\/a><\/a>\u2013and connect it to the same state and observation dynamics as the basic DLM in <span class=\"math inline\">\\(\\eqref{eq:basic-dlm-state}\\)<\/span> via its mean <span class=\"math inline\">\\(\\mu_t = \\exp\\left(\\eta_t\\right)\\)<\/span>:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{aligned}\n    Y_t &amp;\\sim \\operatorname{NB}\\left(r, p_t\\right)\n    \\\\\n    E[Y_t \\mid \\theta_t] &amp;= \\mu_t = \\exp\\left( F_t^\\top \\theta_t \\right)\n    \\\\\n    \\theta_t &amp;= G_t \\theta_{t-1} + \\nu_t, \\quad \\nu_t \\sim \\operatorname{N}\\left( 0, W \\right)\n  \\end{aligned}\n\\label{eq:nb-dlm}\n\\end{equation}\\]<\/span><\/p>\n<p>Under the parameterization <span class=\"math inline\">\\(p_t = \\frac{\\mu_t}{\\mu_t + r}\\)<\/span>, the negative-binomial density function takes the following form:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{aligned}\n    \\operatorname{p}\\left(Y_t = y_t \\mid r, p_t\\right) &amp;=\n    \\frac{\\Gamma\\left( y_t + r \\right)}{y_t!\\,\\Gamma(r)}\n    \\left( 1 - p_t \\right)^r \\left( p_t \\right)^{y_t}\n    \\\\\n    &amp;=\n    \\frac{\\Gamma\\left( y_t + r \\right)}{y_t!\\,\\Gamma(r)}\n    \\left( \\frac{r}{r + \\mu_t} \\right)^r \\left( \\frac{\\mu_t}{r + \\mu_t} \\right)^{y_t}\n    \\\\\n    &amp;=\n    \\frac{\\Gamma\\left( y_t + r 
\\right)}{y_t!\\,\\Gamma(r)}\n    \\frac{\\left( \\mu_t \/ r \\right)^{y_t}}{\\left( 1 + \\mu_t \/ r \\right)^{r + y_t}}\n    \\\\\n    &amp;=\n    \\frac{\\Gamma\\left( y_t + r \\right)}{y_t!\\,\\Gamma(r)}\n    \\frac{\\left( e^{\\eta_t - \\log r} \\right)^{y_t}}{\\left( 1 + e^{\\eta_t - \\log r} \\right)^{r + y_t}}\n    .\n  \\end{aligned}\n\\label{eq:nb-pmf}\n\\end{equation}\\]<\/span><\/p>\n<p>The inverse-logit form in <span class=\"math inline\">\\(\\eqref{eq:nb-pmf}\\)<\/span> admits a Gaussian scale-mixture representation with respect to the Po\u0301lya-Gamma distribution <a id=\"2776c69c558410bc65a3a0a0f1ff5bf4\"><a href=\"#polson_bayesian_2013\">(Polson, Scott &amp; Windle 2013)<\/a><\/a>. For <span class=\"math inline\">\\(b &gt; 0\\)<\/span>, the scale mixture is as follows:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{aligned}\n    \\frac{e^{\\psi a}}{\\left(1 + e^{\\psi}\\right)^b} &amp;=\n    2^{-b} e^{\\kappa \\psi} \\int^{\\infty}_{0} e^{-\\frac{\\omega}{2} \\psi^2} \\operatorname{p}(\\omega) d\\omega\n    \\\\\n    &amp;=\n    2^{-b} \\int^{\\infty}_{0} e^{-\\frac{\\omega}{2} \\left( \\psi - \\frac{\\kappa}{\\omega} \\right)^2}\n      e^{\\frac{\\kappa^2}{2 \\omega} } \\operatorname{p}(\\omega) d\\omega\n    ,\n  \\end{aligned}\n\\label{eq:pg-identity}\n\\end{equation}\\]<\/span><\/p>\n<p>where <span class=\"math inline\">\\(\\kappa = a - b \/ 2\\)<\/span> and <span class=\"math inline\">\\(\\operatorname{p}(\\omega)\\)<\/span> is the density of <span class=\"math inline\">\\(\\omega \\sim \\operatorname{PG}\\left(b, 0\\right)\\)<\/span>.<\/p>\n<p>When the Gaussian scale-mixture identity of <span class=\"math inline\">\\(\\eqref{eq:pg-identity}\\)<\/span> is applied to our observation model in <span class=\"math inline\">\\(\\eqref{eq:nb-pmf}\\)<\/span>, we have <span class=\"math inline\">\\(a = y_t\\)<\/span>, <span class=\"math inline\">\\(b = r + y_t\\)<\/span>, <span class=\"math inline\">\\(\\kappa = (y_t - r)\/2\\)<\/span>, and <span class=\"math inline\">\\(\\psi = \\eta_t - \\log r\\)<\/span>.<\/p>\n<p>From this scale-mixture formulation, we obtain the following augmented 
conditional density for <span class=\"math inline\">\\(\\theta_t\\)<\/span>:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{aligned}\n    \\operatorname{p}(\\theta_t \\mid \\omega, y_t, r, D_{t-1}) &amp;\\propto\n    \\exp\\left(-\\frac{\\omega}{2} \\left( F_t^\\top \\theta_t - \\left( \\log r + \\frac{y_t - r}{2 \\omega} \\right) \\right)^2 \\right)\n    \\operatorname{p}(\\theta_t \\mid D_{t-1})\n    \\\\\n    &amp;=\n    e^{-\\frac{\\omega}{2} \\left(y^*_t - F_t^\\top \\theta_t \\right)^2}\n    \\operatorname{p}(\\theta_t \\mid D_{t-1})\n    \\\\\n    &amp;\\propto \\operatorname{p}\\left( y^*_t \\mid F_t^\\top \\theta_t, \\omega^{-1} \\right)\n    \\operatorname{p}(\\theta_t \\mid D_{t-1})\n  \\end{aligned}\n\\label{eq:nb-aug-obs}\n\\end{equation}\\]<\/span><\/p>\n<p>where <span class=\"math inline\">\\(y^*_t = \\log r + \\frac{y_t - r}{2 \\omega}\\)<\/span>.<\/p>\n<p>The reformulation at the end of <span class=\"math inline\">\\(\\eqref{eq:nb-aug-obs}\\)<\/span> characterizes a distribution of \u201cvirtual\u201d observations,<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  y^*_t = F_t^\\top \\theta_t + \\epsilon^*_t,\\quad\n  \\epsilon^*_t \\sim \\operatorname{N}\\left(0, V^*_t \\right)\n  ,\n  \\label{eq:nb-virtual-obs}\n\\end{equation}\\]<\/span><\/p>\n<p>with <span class=\"math inline\">\\(V^*_t = \\operatorname{diag}\\left( \\omega^{-1}_t \\right)\\)<\/span>.<\/p>\n<p>The augmented observation equation in <span class=\"math inline\">\\(\\eqref{eq:nb-virtual-obs}\\)<\/span> has the same form as our original DLM observation equation in <span class=\"math inline\">\\(\\eqref{eq:basic-dlm-obs}\\)<\/span> and, by comparison, details the changes needed to use the Po\u0301lya-Gamma scale-mixture. 
Specifically, we need to sample <span class=\"math inline\">\\(\\omega_t \\sim \\operatorname{PG}\\left(r + y_t, F_t^\\top \\theta_t - \\log r \\right)\\)<\/span> and perform the following substitutions: <span class=\"math inline\">\\(y_t \\to y^*_t\\)<\/span> and <span class=\"math inline\">\\(V_t \\to V^*_t\\)<\/span>.<\/p>\n<p><a id=\"orgeef4d98\"><\/a><\/p>\n<section id=\"simulation-example-1\" class=\"level2\">\n<h2>Simulation Example<\/h2>\n<p>Listing <a href=\"#orga5ba480\">17<\/a> creates a Theano graph for negative-binomial model defined in <span class=\"math inline\">\\(\\eqref{eq:nb-dlm}\\)<\/span>.<\/p>\n<figure id=\"orga5ba480\">\n<div class=\"sourceCode\" id=\"cb17\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb17-1\"><a href=\"#cb17-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> symbolic_pymc.theano.random_variables <span class=\"im\">import<\/span> NegBinomialRV<\/span>\n<span id=\"cb17-2\"><a href=\"#cb17-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-3\"><a href=\"#cb17-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-4\"><a href=\"#cb17-4\" aria-hidden=\"true\"><\/a>r_tt <span class=\"op\">=<\/span> tt.iscalar(<span class=\"st\">&#39;r&#39;<\/span>)<\/span>\n<span id=\"cb17-5\"><a href=\"#cb17-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-6\"><a href=\"#cb17-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-7\"><a href=\"#cb17-7\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> nb_obs_step(theta_t, F_t, r, rng):<\/span>\n<span id=\"cb17-8\"><a href=\"#cb17-8\" aria-hidden=\"true\"><\/a>    mu_t <span class=\"op\">=<\/span> tt.exp(F_t.T.dot(theta_t))<\/span>\n<span id=\"cb17-9\"><a href=\"#cb17-9\" aria-hidden=\"true\"><\/a>    p_t <span class=\"op\">=<\/span> mu_t <span class=\"op\">\/<\/span> (mu_t <span class=\"op\">+<\/span> r)<\/span>\n<span id=\"cb17-10\"><a href=\"#cb17-10\" aria-hidden=\"true\"><\/a>    y_t <span class=\"op\">=<\/span> 
NegBinomialRV(r, (<span class=\"fl\">1.<\/span> <span class=\"op\">-<\/span> p_t), rng<span class=\"op\">=<\/span>rng, name<span class=\"op\">=<\/span><span class=\"st\">&#39;y_t&#39;<\/span>)<\/span>\n<span id=\"cb17-11\"><a href=\"#cb17-11\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> y_t, p_t<\/span>\n<span id=\"cb17-12\"><a href=\"#cb17-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-13\"><a href=\"#cb17-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-14\"><a href=\"#cb17-14\" aria-hidden=\"true\"><\/a>nb_obs_res, nb_Y_t_updates <span class=\"op\">=<\/span> theano.scan(fn<span class=\"op\">=<\/span>nb_obs_step,<\/span>\n<span id=\"cb17-15\"><a href=\"#cb17-15\" aria-hidden=\"true\"><\/a>                                         sequences<span class=\"op\">=<\/span>[theta_t_rv],<\/span>\n<span id=\"cb17-16\"><a href=\"#cb17-16\" aria-hidden=\"true\"><\/a>                                         non_sequences<span class=\"op\">=<\/span>[F_tt, r_tt, rng_tt],<\/span>\n<span id=\"cb17-17\"><a href=\"#cb17-17\" aria-hidden=\"true\"><\/a>                                         outputs_info<span class=\"op\">=<\/span>[<\/span>\n<span id=\"cb17-18\"><a href=\"#cb17-18\" aria-hidden=\"true\"><\/a>                                             {}, {}, <span class=\"co\"># y_t, p_t<\/span><\/span>\n<span id=\"cb17-19\"><a href=\"#cb17-19\" aria-hidden=\"true\"><\/a>                                         ],<\/span>\n<span id=\"cb17-20\"><a href=\"#cb17-20\" aria-hidden=\"true\"><\/a>                                         strict<span class=\"op\">=<\/span><span class=\"va\">True<\/span>,<\/span>\n<span id=\"cb17-21\"><a href=\"#cb17-21\" aria-hidden=\"true\"><\/a>                                         name<span class=\"op\">=<\/span><span class=\"st\">&#39;Y_t&#39;<\/span>)<\/span>\n<span id=\"cb17-22\"><a href=\"#cb17-22\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-23\"><a href=\"#cb17-23\" 
aria-hidden=\"true\"><\/a>nb_Y_t_rv, nb_p_t_tt <span class=\"op\">=<\/span> nb_obs_res<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 17\n<\/figcaption>\n<\/figure>\n<p>Listing <a href=\"#org1bd6385\">18<\/a> specifies parameters for a simulation from <span class=\"math inline\">\\(\\eqref{eq:nb-dlm}\\)<\/span> and samples a series.<\/p>\n<figure id=\"org1bd6385\">\n<div class=\"sourceCode\" id=\"cb18\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb18-1\"><a href=\"#cb18-1\" aria-hidden=\"true\"><\/a>nb_dlm_sim_values <span class=\"op\">=<\/span> dlm_sim_values.copy()<\/span>\n<span id=\"cb18-2\"><a href=\"#cb18-2\" aria-hidden=\"true\"><\/a>nb_dlm_sim_values[F_tt] <span class=\"op\">=<\/span> np.array([[<span class=\"fl\">1.0<\/span>],<\/span>\n<span id=\"cb18-3\"><a href=\"#cb18-3\" aria-hidden=\"true\"><\/a>                                    [<span class=\"fl\">0.0<\/span>]], dtype<span class=\"op\">=<\/span>theano.config.floatX)<\/span>\n<span id=\"cb18-4\"><a href=\"#cb18-4\" aria-hidden=\"true\"><\/a>nb_dlm_sim_values[G_tt] <span class=\"op\">=<\/span> np.array([[<span class=\"fl\">1.0<\/span>, <span class=\"fl\">0.1<\/span>],<\/span>\n<span id=\"cb18-5\"><a href=\"#cb18-5\" aria-hidden=\"true\"><\/a>                                    [<span class=\"fl\">0.0<\/span>, <span class=\"fl\">0.8<\/span>]], dtype<span class=\"op\">=<\/span>theano.config.floatX)<\/span>\n<span id=\"cb18-6\"><a href=\"#cb18-6\" aria-hidden=\"true\"><\/a>nb_dlm_sim_values[r_tt] <span class=\"op\">=<\/span> <span class=\"dv\">1000<\/span><\/span>\n<span id=\"cb18-7\"><a href=\"#cb18-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb18-8\"><a href=\"#cb18-8\" aria-hidden=\"true\"><\/a>phi_W_tt.set_value(np.r_[<span class=\"fl\">10.0<\/span>, <span class=\"fl\">10.0<\/span>])<\/span>\n<span id=\"cb18-9\"><a href=\"#cb18-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb18-10\"><a href=\"#cb18-10\" 
aria-hidden=\"true\"><\/a>rng_tt.get_value(borrow<span class=\"op\">=<\/span><span class=\"va\">True<\/span>).set_state(rng_init_state)<\/span>\n<span id=\"cb18-11\"><a href=\"#cb18-11\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb18-12\"><a href=\"#cb18-12\" aria-hidden=\"true\"><\/a>simulate_nb_dlm <span class=\"op\">=<\/span> tt_function([N_obs_tt, N_theta_tt, G_tt, F_tt, r_tt],<\/span>\n<span id=\"cb18-13\"><a href=\"#cb18-13\" aria-hidden=\"true\"><\/a>                              [nb_Y_t_rv, theta_t_rv, nb_p_t_tt],<\/span>\n<span id=\"cb18-14\"><a href=\"#cb18-14\" aria-hidden=\"true\"><\/a>                              givens<span class=\"op\">=<\/span>{theta_0_rv: np.r_[<span class=\"fl\">1.0<\/span>, <span class=\"fl\">0.5<\/span>]},<\/span>\n<span id=\"cb18-15\"><a href=\"#cb18-15\" aria-hidden=\"true\"><\/a>                              updates<span class=\"op\">=<\/span>nb_Y_t_updates)<\/span>\n<span id=\"cb18-16\"><a href=\"#cb18-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb18-17\"><a href=\"#cb18-17\" aria-hidden=\"true\"><\/a>sim_nb_res <span class=\"op\">=<\/span> simulate_nb_dlm(nb_dlm_sim_values[N_obs_tt],<\/span>\n<span id=\"cb18-18\"><a href=\"#cb18-18\" aria-hidden=\"true\"><\/a>                             nb_dlm_sim_values[N_theta_tt],<\/span>\n<span id=\"cb18-19\"><a href=\"#cb18-19\" aria-hidden=\"true\"><\/a>                             nb_dlm_sim_values[G_tt],<\/span>\n<span id=\"cb18-20\"><a href=\"#cb18-20\" aria-hidden=\"true\"><\/a>                             nb_dlm_sim_values[F_tt],<\/span>\n<span id=\"cb18-21\"><a href=\"#cb18-21\" aria-hidden=\"true\"><\/a>                             nb_dlm_sim_values[r_tt])<\/span>\n<span id=\"cb18-22\"><a href=\"#cb18-22\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb18-23\"><a href=\"#cb18-23\" aria-hidden=\"true\"><\/a>nb_y_sim, nb_theta_t_sim, nb_p_t_sim <span class=\"op\">=<\/span> sim_nb_res<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 
18\n<\/figcaption>\n<\/figure>\n<p>In Figure <a href=\"#orgf31ba4c\">19<\/a> we plot the series sampled in Listing <a href=\"#org1bd6385\">18<\/a>.<\/p>\n<figure id=\"orgf31ba4c\">\n<div class=\"sourceCode\" id=\"cb19\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb19-1\"><a href=\"#cb19-1\" aria-hidden=\"true\"><\/a>plt.clf()<\/span>\n<span id=\"cb19-2\"><a href=\"#cb19-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb19-3\"><a href=\"#cb19-3\" aria-hidden=\"true\"><\/a>fig, ax <span class=\"op\">=<\/span> plt.subplots(figsize<span class=\"op\">=<\/span>(<span class=\"dv\">8<\/span>, <span class=\"fl\">4.8<\/span>))<\/span>\n<span id=\"cb19-4\"><a href=\"#cb19-4\" aria-hidden=\"true\"><\/a>_ <span class=\"op\">=<\/span> ax.plot(nb_y_sim, label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$y_t$&#39;<\/span>, color<span class=\"op\">=<\/span><span class=\"st\">&#39;black&#39;<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">1.2<\/span>, drawstyle<span class=\"op\">=<\/span><span class=\"st\">&#39;steps-pre&#39;<\/span>)<\/span>\n<span id=\"cb19-5\"><a href=\"#cb19-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb19-6\"><a href=\"#cb19-6\" aria-hidden=\"true\"><\/a>plt.tight_layout()<\/span>\n<span id=\"cb19-7\"><a href=\"#cb19-7\" aria-hidden=\"true\"><\/a>plt.legend()<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 19\n<\/figcaption>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/nb-dlm-sim-plot.png\" title=\"fig:\" alt=\"A series sampled from our negative-binomial model defined in \\eqref{eq:nb-dlm}. 
\" \/>\n<figcaption>\nA series sampled from our negative-binomial model defined in <span class=\"math inline\">\\(\\eqref{eq:nb-dlm}\\)<\/span>.\n<\/figcaption>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" id=\"cb20\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb20-1\"><a href=\"#cb20-1\" aria-hidden=\"true\"><\/a>plt.clf()<\/span>\n<span id=\"cb20-2\"><a href=\"#cb20-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb20-3\"><a href=\"#cb20-3\" aria-hidden=\"true\"><\/a>fig, ax <span class=\"op\">=<\/span> plt.subplots(figsize<span class=\"op\">=<\/span>(<span class=\"dv\">8<\/span>, <span class=\"fl\">4.8<\/span>))<\/span>\n<span id=\"cb20-4\"><a href=\"#cb20-4\" aria-hidden=\"true\"><\/a>ax.autoscale(enable<span class=\"op\">=<\/span><span class=\"va\">False<\/span>)<\/span>\n<span id=\"cb20-5\"><a href=\"#cb20-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb20-6\"><a href=\"#cb20-6\" aria-hidden=\"true\"><\/a>bivariate_obs_cycler <span class=\"op\">=<\/span>  cycler(<span class=\"st\">&#39;linestyle&#39;<\/span>, [<span class=\"st\">&#39;-&#39;<\/span>, <span class=\"st\">&#39;--&#39;<\/span>]) <span class=\"op\">*<\/span> cycler(<span class=\"st\">&#39;color&#39;<\/span>, [<span class=\"st\">&#39;black&#39;<\/span>])<\/span>\n<span id=\"cb20-7\"><a href=\"#cb20-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb20-8\"><a href=\"#cb20-8\" aria-hidden=\"true\"><\/a>ax.set_prop_cycle(bivariate_obs_cycler)<\/span>\n<span id=\"cb20-9\"><a href=\"#cb20-9\" aria-hidden=\"true\"><\/a>ax.plot(nb_theta_t_sim, label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$\\theta_t$&#39;<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">1.0<\/span>)<\/span>\n<span id=\"cb20-10\"><a href=\"#cb20-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb20-11\"><a href=\"#cb20-11\" aria-hidden=\"true\"><\/a>ax.autoscale(enable<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb20-12\"><a 
href=\"#cb20-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb20-13\"><a href=\"#cb20-13\" aria-hidden=\"true\"><\/a>plt.tight_layout()<\/span>\n<span id=\"cb20-14\"><a href=\"#cb20-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb20-15\"><a href=\"#cb20-15\" aria-hidden=\"true\"><\/a>plt.legend(framealpha<span class=\"op\">=<\/span><span class=\"fl\">0.4<\/span>)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/nb-theta-sim-plot.png\" title=\"fig:\" alt=\"Simulated \\theta_t values from \\eqref{eq:nb-dlm}. \" \/>\n<figcaption>\nSimulated <span class=\"math inline\">\\(\\theta_t\\)<\/span> values from <span class=\"math inline\">\\(\\eqref{eq:nb-dlm}\\)<\/span>.\n<\/figcaption>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" id=\"cb21\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb21-1\"><a href=\"#cb21-1\" aria-hidden=\"true\"><\/a>plt.clf()<\/span>\n<span id=\"cb21-2\"><a href=\"#cb21-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb21-3\"><a href=\"#cb21-3\" aria-hidden=\"true\"><\/a>fig, ax <span class=\"op\">=<\/span> plt.subplots(figsize<span class=\"op\">=<\/span>(<span class=\"dv\">8<\/span>, <span class=\"fl\">4.8<\/span>))<\/span>\n<span id=\"cb21-4\"><a href=\"#cb21-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb21-5\"><a href=\"#cb21-5\" aria-hidden=\"true\"><\/a>ax.plot(nb_p_t_sim, label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$p_t$&#39;<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">1.0<\/span>, color<span class=\"op\">=<\/span><span class=\"st\">&#39;black&#39;<\/span>)<\/span>\n<span id=\"cb21-6\"><a href=\"#cb21-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb21-7\"><a href=\"#cb21-7\" aria-hidden=\"true\"><\/a>plt.tight_layout()<\/span>\n<span id=\"cb21-8\"><a href=\"#cb21-8\" aria-hidden=\"true\"><\/a>plt.legend(framealpha<span class=\"op\">=<\/span><span 
class=\"fl\">0.4<\/span>)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/nb-p-sim-plot.png\" title=\"fig:\" alt=\"Simulated p_t \\mid \\theta_t values from \\eqref{eq:nb-dlm}. \" \/>\n<figcaption>\nSimulated <span class=\"math inline\">\\(p_t \\mid \\theta_t\\)<\/span> values from <span class=\"math inline\">\\(\\eqref{eq:nb-dlm}\\)<\/span>.\n<\/figcaption>\n<\/figure>\n<p><a id=\"orgd705951\"><\/a><\/p>\n<\/section>\n<\/section>\n<section id=\"augmented-ffbs-sampler\" class=\"level1\">\n<h1>Augmented FFBS Sampler<\/h1>\n<p>In order to create a FFBS sampler for our Po\u0301lya-Gamma DLM in <span class=\"math inline\">\\(\\eqref{eq:nb-aug-obs}\\)<\/span>, we need to update the filtering code to use time-varying \u201cvirtual\u201d observation variances, <span class=\"math inline\">\\(V^*_t\\)<\/span>. After this change is made, all Theano graphs that depend on the resulting objects need to be recreated, as well. 
This is done in Listing <a href=\"#org6e0be06\">22<\/a>.<\/p>\n<figure id=\"org6e0be06\">\n<div class=\"sourceCode\" id=\"cb22\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb22-1\"><a href=\"#cb22-1\" aria-hidden=\"true\"><\/a>V_t_tt <span class=\"op\">=<\/span> tt.specify_shape(tt.col(), [N_obs_tt, <span class=\"dv\">1<\/span>])<\/span>\n<span id=\"cb22-2\"><a href=\"#cb22-2\" aria-hidden=\"true\"><\/a>V_t_tt.name <span class=\"op\">=<\/span> <span class=\"st\">&#39;V_t&#39;<\/span><\/span>\n<span id=\"cb22-3\"><a href=\"#cb22-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb22-4\"><a href=\"#cb22-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb22-5\"><a href=\"#cb22-5\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> nb_filtering_step(y_t, V_t, m_tm1, U_C_tm1, S_C_tm1, F_t, G_t, N_W_t):<\/span>\n<span id=\"cb22-6\"><a href=\"#cb22-6\" aria-hidden=\"true\"><\/a>    N_V_t_inv <span class=\"op\">=<\/span> tt.diag(tt_finite_inv(tt.sqrt(V_t), eps_truncate<span class=\"op\">=<\/span><span class=\"va\">True<\/span>))<\/span>\n<span id=\"cb22-7\"><a href=\"#cb22-7\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> filtering_step(y_t, m_tm1, U_C_tm1, S_C_tm1, F_t, G_t, N_W_t, N_V_t_inv)<\/span>\n<span id=\"cb22-8\"><a href=\"#cb22-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb22-9\"><a href=\"#cb22-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb22-10\"><a href=\"#cb22-10\" aria-hidden=\"true\"><\/a>filter_res, filter_updates <span class=\"op\">=<\/span> theano.scan(fn<span class=\"op\">=<\/span>nb_filtering_step,<\/span>\n<span id=\"cb22-11\"><a href=\"#cb22-11\" aria-hidden=\"true\"><\/a>                                         sequences<span class=\"op\">=<\/span>[y_tt, V_t_tt],<\/span>\n<span id=\"cb22-12\"><a href=\"#cb22-12\" aria-hidden=\"true\"><\/a>                                         outputs_info<span class=\"op\">=<\/span>[<\/span>\n<span id=\"cb22-13\"><a href=\"#cb22-13\" 
aria-hidden=\"true\"><\/a>                                             {<span class=\"st\">&quot;initial&quot;<\/span>: m_0_tt, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb22-14\"><a href=\"#cb22-14\" aria-hidden=\"true\"><\/a>                                             {<span class=\"st\">&quot;initial&quot;<\/span>: U_C_0_tt, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb22-15\"><a href=\"#cb22-15\" aria-hidden=\"true\"><\/a>                                             {<span class=\"st\">&quot;initial&quot;<\/span>: S_C_0_tt, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb22-16\"><a href=\"#cb22-16\" aria-hidden=\"true\"><\/a>                                             {}, {}, {}  <span class=\"co\"># a_t, U_R_t, S_R_t<\/span><\/span>\n<span id=\"cb22-17\"><a href=\"#cb22-17\" aria-hidden=\"true\"><\/a>                                         ],<\/span>\n<span id=\"cb22-18\"><a href=\"#cb22-18\" aria-hidden=\"true\"><\/a>                                         non_sequences<span class=\"op\">=<\/span>[F_tt, G_tt, N_W_tt],<\/span>\n<span id=\"cb22-19\"><a href=\"#cb22-19\" aria-hidden=\"true\"><\/a>                                         strict<span class=\"op\">=<\/span><span class=\"va\">True<\/span>,<\/span>\n<span id=\"cb22-20\"><a href=\"#cb22-20\" aria-hidden=\"true\"><\/a>                                         name<span class=\"op\">=<\/span><span class=\"st\">&#39;theta_filtered&#39;<\/span>)<\/span>\n<span id=\"cb22-21\"><a href=\"#cb22-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb22-22\"><a href=\"#cb22-22\" aria-hidden=\"true\"><\/a>(m_t, U_C_t, S_C_t, a_t, U_R_t, S_R_t) <span class=\"op\">=<\/span> filter_res<\/span>\n<span id=\"cb22-23\"><a href=\"#cb22-23\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb22-24\"><a href=\"#cb22-24\" aria-hidden=\"true\"><\/a>m_T <span class=\"op\">=<\/span> m_t[<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]<\/span>\n<span id=\"cb22-25\"><a href=\"#cb22-25\" aria-hidden=\"true\"><\/a>U_C_T <span class=\"op\">=<\/span> U_C_t[<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]<\/span>\n<span id=\"cb22-26\"><a href=\"#cb22-26\" aria-hidden=\"true\"><\/a>S_C_T <span class=\"op\">=<\/span> S_C_t[<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]<\/span>\n<span id=\"cb22-27\"><a href=\"#cb22-27\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb22-28\"><a href=\"#cb22-28\" aria-hidden=\"true\"><\/a>C_T <span class=\"op\">=<\/span> matrix_dot(U_C_T, tt.square(S_C_T), U_C_T.T)<\/span>\n<span id=\"cb22-29\"><a href=\"#cb22-29\" aria-hidden=\"true\"><\/a>theta_T_post <span class=\"op\">=<\/span> MvNormalRV(m_T, C_T, rng<span class=\"op\">=<\/span>rng_tt)<\/span>\n<span id=\"cb22-30\"><a href=\"#cb22-30\" aria-hidden=\"true\"><\/a>theta_T_post.name <span class=\"op\">=<\/span> <span class=\"st\">&quot;theta_T_post&quot;<\/span><\/span>\n<span id=\"cb22-31\"><a href=\"#cb22-31\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb22-32\"><a href=\"#cb22-32\" aria-hidden=\"true\"><\/a>ffbs_output, ffbs_updates <span class=\"op\">=<\/span> theano.scan(fn<span class=\"op\">=<\/span>ffbs_step,<\/span>\n<span id=\"cb22-33\"><a href=\"#cb22-33\" aria-hidden=\"true\"><\/a>                                        sequences<span class=\"op\">=<\/span>[<\/span>\n<span id=\"cb22-34\"><a href=\"#cb22-34\" aria-hidden=\"true\"><\/a>                                            {<span class=\"st\">&quot;input&quot;<\/span>: m_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb22-35\"><a href=\"#cb22-35\" aria-hidden=\"true\"><\/a>                                            {<span 
class=\"st\">&quot;input&quot;<\/span>: U_C_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb22-36\"><a href=\"#cb22-36\" aria-hidden=\"true\"><\/a>                                            {<span class=\"st\">&quot;input&quot;<\/span>: S_C_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb22-37\"><a href=\"#cb22-37\" aria-hidden=\"true\"><\/a>                                            {<span class=\"st\">&quot;input&quot;<\/span>: a_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb22-38\"><a href=\"#cb22-38\" aria-hidden=\"true\"><\/a>                                            {<span class=\"st\">&quot;input&quot;<\/span>: U_R_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb22-39\"><a href=\"#cb22-39\" aria-hidden=\"true\"><\/a>                                            {<span class=\"st\">&quot;input&quot;<\/span>: S_R_t, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"dv\">1<\/span>]}<\/span>\n<span id=\"cb22-40\"><a href=\"#cb22-40\" aria-hidden=\"true\"><\/a>                                        ],<\/span>\n<span id=\"cb22-41\"><a href=\"#cb22-41\" aria-hidden=\"true\"><\/a>                                        outputs_info<span class=\"op\">=<\/span>[<\/span>\n<span id=\"cb22-42\"><a href=\"#cb22-42\" aria-hidden=\"true\"><\/a>                                            {<span class=\"st\">&quot;initial&quot;<\/span>: theta_T_post, <span class=\"st\">&quot;taps&quot;<\/span>: [<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb22-43\"><a href=\"#cb22-43\" aria-hidden=\"true\"><\/a>                                            {}, {}, <span class=\"co\"># theta_tp1_diff, f_tp1<\/span><\/span>\n<span id=\"cb22-44\"><a href=\"#cb22-44\" 
aria-hidden=\"true\"><\/a>                                        ],<\/span>\n<span id=\"cb22-45\"><a href=\"#cb22-45\" aria-hidden=\"true\"><\/a>                                        non_sequences<span class=\"op\">=<\/span>[F_tt, G_tt, N_W_inv_tt, rng_tt],<\/span>\n<span id=\"cb22-46\"><a href=\"#cb22-46\" aria-hidden=\"true\"><\/a>                                        go_backwards<span class=\"op\">=<\/span><span class=\"va\">True<\/span>,<\/span>\n<span id=\"cb22-47\"><a href=\"#cb22-47\" aria-hidden=\"true\"><\/a>                                        strict<span class=\"op\">=<\/span><span class=\"va\">True<\/span>,<\/span>\n<span id=\"cb22-48\"><a href=\"#cb22-48\" aria-hidden=\"true\"><\/a>                                        name<span class=\"op\">=<\/span><span class=\"st\">&#39;ffbs_samples&#39;<\/span>)<\/span>\n<span id=\"cb22-49\"><a href=\"#cb22-49\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb22-50\"><a href=\"#cb22-50\" aria-hidden=\"true\"><\/a>(theta_t_post_rev, theta_t_diff_rev, f_t_rev) <span class=\"op\">=<\/span> ffbs_output<\/span>\n<span id=\"cb22-51\"><a href=\"#cb22-51\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb22-52\"><a href=\"#cb22-52\" aria-hidden=\"true\"><\/a>theta_t_post <span class=\"op\">=<\/span> tt.join(<span class=\"dv\">0<\/span>, theta_t_post_rev[::<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>], [theta_T_post])<\/span>\n<span id=\"cb22-53\"><a href=\"#cb22-53\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb22-54\"><a href=\"#cb22-54\" aria-hidden=\"true\"><\/a>f_t_post <span class=\"op\">=<\/span> tt.join(<span class=\"dv\">0<\/span>, f_t_rev[::<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>], [F_tt.T.dot(theta_T_post)])<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 22\n<\/figcaption>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" id=\"cb23\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb23-1\"><a href=\"#cb23-1\" 
aria-hidden=\"true\"><\/a>phi_W_a_post_tt <span class=\"op\">=<\/span> phi_W_a <span class=\"op\">+<\/span> N_obs_tt <span class=\"op\">*<\/span> <span class=\"fl\">0.5<\/span><\/span>\n<span id=\"cb23-2\"><a href=\"#cb23-2\" aria-hidden=\"true\"><\/a>phi_W_b_post_tt <span class=\"op\">=<\/span> phi_W_b <span class=\"op\">+<\/span> <span class=\"fl\">0.5<\/span> <span class=\"op\">*<\/span> tt.square(theta_t_diff_rev).<span class=\"bu\">sum<\/span>(<span class=\"dv\">0<\/span>)<\/span>\n<span id=\"cb23-3\"><a href=\"#cb23-3\" aria-hidden=\"true\"><\/a>phi_W_post_tt <span class=\"op\">=<\/span> GammaRV(phi_W_a_post_tt, phi_W_b_post_tt, rng<span class=\"op\">=<\/span>rng_tt, name<span class=\"op\">=<\/span><span class=\"st\">&#39;phi_W_post&#39;<\/span>)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<p><a id=\"org1057f5a\"><\/a><\/p>\n<section id=\"simulation-example-2\" class=\"level2\">\n<h2>Simulation Example<\/h2>\n<p>In Listing <a href=\"#orgdf7d0e4\">24<\/a> we sample the initial <span class=\"math inline\">\\(\\omega_t\\)<\/span> values, compute the corresponding initial augmented observations, and create Theano terms for their posterior (i.e.\u00a0updated) values.<\/p>\n<figure id=\"orgdf7d0e4\">\n<div class=\"sourceCode\" id=\"cb24\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb24-1\"><a href=\"#cb24-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> pypolyagamma <span class=\"im\">import<\/span> PyPolyaGamma<\/span>\n<span id=\"cb24-2\"><a href=\"#cb24-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-3\"><a href=\"#cb24-3\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> symbolic_pymc.theano.random_variables <span class=\"im\">import<\/span> PolyaGammaRV<\/span>\n<span id=\"cb24-4\"><a href=\"#cb24-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-5\"><a href=\"#cb24-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-6\"><a href=\"#cb24-6\" aria-hidden=\"true\"><\/a>y_raw_tt <span class=\"op\">=<\/span> 
theano.shared(nb_y_sim.astype(theano.config.floatX), name<span class=\"op\">=<\/span><span class=\"st\">&#39;y_raw_t&#39;<\/span>)<\/span>\n<span id=\"cb24-7\"><a href=\"#cb24-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-8\"><a href=\"#cb24-8\" aria-hidden=\"true\"><\/a>r_sim <span class=\"op\">=<\/span> np.array(nb_dlm_sim_values[r_tt], dtype<span class=\"op\">=<\/span><span class=\"st\">&#39;double&#39;<\/span>)<\/span>\n<span id=\"cb24-9\"><a href=\"#cb24-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-10\"><a href=\"#cb24-10\" aria-hidden=\"true\"><\/a><span class=\"co\"># XXX: testing; initialize omega_t using the simulated (true) theta_t values<\/span><\/span>\n<span id=\"cb24-11\"><a href=\"#cb24-11\" aria-hidden=\"true\"><\/a>F_t_theta_0 <span class=\"op\">=<\/span> np.dot(nb_theta_t_sim, nb_dlm_sim_values[F_tt])<\/span>\n<span id=\"cb24-12\"><a href=\"#cb24-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-13\"><a href=\"#cb24-13\" aria-hidden=\"true\"><\/a>omega_0 <span class=\"op\">=<\/span> np.empty(nb_y_sim.shape[<span class=\"dv\">0<\/span>], dtype<span class=\"op\">=<\/span><span class=\"st\">&#39;double&#39;<\/span>)<\/span>\n<span id=\"cb24-14\"><a href=\"#cb24-14\" aria-hidden=\"true\"><\/a>PyPolyaGamma(<span class=\"dv\">12344<\/span>).pgdrawv(r_sim <span class=\"op\">+<\/span> nb_y_sim.squeeze(),<\/span>\n<span id=\"cb24-15\"><a href=\"#cb24-15\" aria-hidden=\"true\"><\/a>                            F_t_theta_0.squeeze() <span class=\"op\">-<\/span> np.log(r_sim),<\/span>\n<span id=\"cb24-16\"><a href=\"#cb24-16\" aria-hidden=\"true\"><\/a>                            omega_0)<\/span>\n<span id=\"cb24-17\"><a href=\"#cb24-17\" aria-hidden=\"true\"><\/a>omega_0 <span class=\"op\">=<\/span> np.expand_dims(omega_0, <span class=\"op\">-<\/span><span class=\"dv\">1<\/span>)<\/span>\n<span id=\"cb24-18\"><a href=\"#cb24-18\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-19\"><a href=\"#cb24-19\" aria-hidden=\"true\"><\/a>y_aug_0 <span class=\"op\">=<\/span> np.log(r_sim) <span 
class=\"op\">+<\/span> (nb_y_sim <span class=\"op\">-<\/span> r_sim) <span class=\"op\">\/<\/span> (<span class=\"fl\">2.0<\/span> <span class=\"op\">*<\/span> omega_0)<\/span>\n<span id=\"cb24-20\"><a href=\"#cb24-20\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-21\"><a href=\"#cb24-21\" aria-hidden=\"true\"><\/a>omega_t_tt <span class=\"op\">=<\/span> theano.shared(omega_0, name<span class=\"op\">=<\/span><span class=\"st\">&#39;omega_t&#39;<\/span>)<\/span>\n<span id=\"cb24-22\"><a href=\"#cb24-22\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-23\"><a href=\"#cb24-23\" aria-hidden=\"true\"><\/a>omega_post_tt <span class=\"op\">=<\/span> PolyaGammaRV(r_tt <span class=\"op\">+<\/span> y_raw_tt, theta_t_post.dot(F_tt) <span class=\"op\">-<\/span> tt.log(r_tt), rng<span class=\"op\">=<\/span>rng_tt, name<span class=\"op\">=<\/span><span class=\"st\">&#39;omega_post&#39;<\/span>)<\/span>\n<span id=\"cb24-24\"><a href=\"#cb24-24\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-25\"><a href=\"#cb24-25\" aria-hidden=\"true\"><\/a>y_aug_post_tt <span class=\"op\">=<\/span> tt.log(r_tt) <span class=\"op\">+<\/span> (y_raw_tt <span class=\"op\">-<\/span> r_tt) <span class=\"op\">\/<\/span> (<span class=\"fl\">2.0<\/span> <span class=\"op\">*<\/span> omega_post_tt)<\/span>\n<span id=\"cb24-26\"><a href=\"#cb24-26\" aria-hidden=\"true\"><\/a>y_aug_post_tt.name <span class=\"op\">=<\/span> <span class=\"st\">&#39;y_aug_post&#39;<\/span><\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 24\n<\/figcaption>\n<\/figure>\n<p>Finally, the sampler steps are defined and executed in Listing <a href=\"#org2dd74e9\">25<\/a>.<\/p>\n<figure id=\"org2dd74e9\">\n<div class=\"sourceCode\" id=\"cb25\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb25-1\"><a href=\"#cb25-1\" aria-hidden=\"true\"><\/a>nb_ffbs_dlm <span class=\"op\">=<\/span> tt_function([N_obs_tt, N_theta_tt, y_tt, G_tt, F_tt, V_t_tt, r_tt],<\/span>\n<span id=\"cb25-2\"><a 
href=\"#cb25-2\" aria-hidden=\"true\"><\/a>                          [theta_t_post, phi_W_post_tt, omega_post_tt, y_aug_post_tt],<\/span>\n<span id=\"cb25-3\"><a href=\"#cb25-3\" aria-hidden=\"true\"><\/a>                          updates<span class=\"op\">=<\/span>ffbs_updates)<\/span>\n<span id=\"cb25-4\"><a href=\"#cb25-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-5\"><a href=\"#cb25-5\" aria-hidden=\"true\"><\/a>rng_tt.get_value(borrow<span class=\"op\">=<\/span><span class=\"va\">True<\/span>).set_state(rng_sim_state)<\/span>\n<span id=\"cb25-6\"><a href=\"#cb25-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-7\"><a href=\"#cb25-7\" aria-hidden=\"true\"><\/a>phi_W_tt.set_value(np.r_[<span class=\"fl\">10.0<\/span>, <span class=\"fl\">10.0<\/span>])<\/span>\n<span id=\"cb25-8\"><a href=\"#cb25-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-9\"><a href=\"#cb25-9\" aria-hidden=\"true\"><\/a>chain <span class=\"op\">=<\/span> <span class=\"dv\">0<\/span><\/span>\n<span id=\"cb25-10\"><a href=\"#cb25-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-11\"><a href=\"#cb25-11\" aria-hidden=\"true\"><\/a>theta_label <span class=\"op\">=<\/span> <span class=\"vs\">r&#39;$\\theta_t \\mid D_T$&#39;<\/span><\/span>\n<span id=\"cb25-12\"><a href=\"#cb25-12\" aria-hidden=\"true\"><\/a>phi_W_label <span class=\"op\">=<\/span> <span class=\"vs\">r&#39;$\\phi_W \\mid D_T$&#39;<\/span><\/span>\n<span id=\"cb25-13\"><a href=\"#cb25-13\" aria-hidden=\"true\"><\/a>omega_label <span class=\"op\">=<\/span> <span class=\"vs\">r&#39;$\\omega_t \\mid D_T$&#39;<\/span><\/span>\n<span id=\"cb25-14\"><a href=\"#cb25-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-15\"><a href=\"#cb25-15\" aria-hidden=\"true\"><\/a>theta_t_post_sim, phi_W_post_sim, omega_post_sim, y_aug_post_sim <span class=\"op\">=<\/span> <span class=\"va\">None<\/span>, <span class=\"va\">None<\/span>, <span class=\"va\">None<\/span>, <span 
class=\"va\">None<\/span><\/span>\n<span id=\"cb25-16\"><a href=\"#cb25-16\" aria-hidden=\"true\"><\/a>nb_posterior_samples <span class=\"op\">=<\/span> {theta_label: [[]], phi_W_label: [[]], omega_label: [[]]}<\/span>\n<span id=\"cb25-17\"><a href=\"#cb25-17\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-18\"><a href=\"#cb25-18\" aria-hidden=\"true\"><\/a>V_t_sim <span class=\"op\">=<\/span> np.reciprocal(omega_0)<\/span>\n<span id=\"cb25-19\"><a href=\"#cb25-19\" aria-hidden=\"true\"><\/a>y_aug_sim <span class=\"op\">=<\/span> y_aug_0<\/span>\n<span id=\"cb25-20\"><a href=\"#cb25-20\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-21\"><a href=\"#cb25-21\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> i <span class=\"kw\">in<\/span> <span class=\"bu\">range<\/span>(<span class=\"dv\">1000<\/span>):<\/span>\n<span id=\"cb25-22\"><a href=\"#cb25-22\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-23\"><a href=\"#cb25-23\" aria-hidden=\"true\"><\/a>    nb_ffbs_res <span class=\"op\">=<\/span> nb_ffbs_dlm(<\/span>\n<span id=\"cb25-24\"><a href=\"#cb25-24\" aria-hidden=\"true\"><\/a>        nb_dlm_sim_values[N_obs_tt],<\/span>\n<span id=\"cb25-25\"><a href=\"#cb25-25\" aria-hidden=\"true\"><\/a>        nb_dlm_sim_values[N_theta_tt],<\/span>\n<span id=\"cb25-26\"><a href=\"#cb25-26\" aria-hidden=\"true\"><\/a>        y_aug_sim,<\/span>\n<span id=\"cb25-27\"><a href=\"#cb25-27\" aria-hidden=\"true\"><\/a>        nb_dlm_sim_values[G_tt],<\/span>\n<span id=\"cb25-28\"><a href=\"#cb25-28\" aria-hidden=\"true\"><\/a>        nb_dlm_sim_values[F_tt],<\/span>\n<span id=\"cb25-29\"><a href=\"#cb25-29\" aria-hidden=\"true\"><\/a>        V_t_sim,<\/span>\n<span id=\"cb25-30\"><a href=\"#cb25-30\" aria-hidden=\"true\"><\/a>        nb_dlm_sim_values[r_tt])<\/span>\n<span id=\"cb25-31\"><a href=\"#cb25-31\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-32\"><a href=\"#cb25-32\" aria-hidden=\"true\"><\/a>    theta_t_post_sim, phi_W_post_sim, 
omega_post_sim, y_aug_post_sim <span class=\"op\">=<\/span> nb_ffbs_res<\/span>\n<span id=\"cb25-33\"><a href=\"#cb25-33\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-34\"><a href=\"#cb25-34\" aria-hidden=\"true\"><\/a>    phi_W_tt.set_value(phi_W_post_sim)<\/span>\n<span id=\"cb25-35\"><a href=\"#cb25-35\" aria-hidden=\"true\"><\/a>    omega_t_tt.set_value(omega_post_sim)<\/span>\n<span id=\"cb25-36\"><a href=\"#cb25-36\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-37\"><a href=\"#cb25-37\" aria-hidden=\"true\"><\/a>    V_t_sim <span class=\"op\">=<\/span> np.reciprocal(omega_post_sim)<\/span>\n<span id=\"cb25-38\"><a href=\"#cb25-38\" aria-hidden=\"true\"><\/a>    y_aug_sim <span class=\"op\">=<\/span> y_aug_post_sim<\/span>\n<span id=\"cb25-39\"><a href=\"#cb25-39\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-40\"><a href=\"#cb25-40\" aria-hidden=\"true\"><\/a>    nb_posterior_samples[theta_label][chain].append(theta_t_post_sim)<\/span>\n<span id=\"cb25-41\"><a href=\"#cb25-41\" aria-hidden=\"true\"><\/a>    nb_posterior_samples[phi_W_label][chain].append(phi_W_post_sim)<\/span>\n<span id=\"cb25-42\"><a href=\"#cb25-42\" aria-hidden=\"true\"><\/a>    nb_posterior_samples[omega_label][chain].append(omega_post_sim)<\/span>\n<span id=\"cb25-43\"><a href=\"#cb25-43\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-44\"><a href=\"#cb25-44\" aria-hidden=\"true\"><\/a>    <span class=\"bu\">print<\/span>(<span class=\"ss\">f&#39;i=<\/span><span class=\"sc\">{i}<\/span><span class=\"ss\">,<\/span><span class=\"ch\">\\t<\/span><span class=\"ss\">phi_W=<\/span><span class=\"sc\">{<\/span>phi_W_post_sim<span class=\"sc\">}<\/span><span class=\"ss\">&#39;<\/span>)<\/span>\n<span id=\"cb25-45\"><a href=\"#cb25-45\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-46\"><a href=\"#cb25-46\" aria-hidden=\"true\"><\/a>nb_posterior_samples <span class=\"op\">=<\/span> {k: np.asarray(v) <span class=\"cf\">for<\/span> k,v <span class=\"kw\">in<\/span> 
nb_posterior_samples.items()}<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 25\n<\/figcaption>\n<\/figure>\n<p>Figure <a href=\"#orgf9a4406\">26<\/a> shows the posterior <span class=\"math inline\">\\(\\theta_t\\)<\/span> samples, Figure <a href=\"#orga6f3d67\">27<\/a> plots the posterior sample traces, and Figure <a href=\"#org528d6ec\">28<\/a> shows the resulting <span class=\"math inline\">\\(p_t \\mid \\theta_t, D_T\\)<\/span> samples.<\/p>\n<figure id=\"orgf9a4406\">\n<div class=\"sourceCode\" id=\"cb26\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb26-1\"><a href=\"#cb26-1\" aria-hidden=\"true\"><\/a>plt.clf()<\/span>\n<span id=\"cb26-2\"><a href=\"#cb26-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-3\"><a href=\"#cb26-3\" aria-hidden=\"true\"><\/a>fig, ax <span class=\"op\">=<\/span> plt.subplots(figsize<span class=\"op\">=<\/span>(<span class=\"dv\">8<\/span>, <span class=\"fl\">4.8<\/span>))<\/span>\n<span id=\"cb26-4\"><a href=\"#cb26-4\" aria-hidden=\"true\"><\/a>ax.autoscale(enable<span class=\"op\">=<\/span><span class=\"va\">False<\/span>)<\/span>\n<span id=\"cb26-5\"><a href=\"#cb26-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-6\"><a href=\"#cb26-6\" aria-hidden=\"true\"><\/a>thetas_shape <span class=\"op\">=<\/span> nb_posterior_samples[theta_label][<span class=\"dv\">0<\/span>].shape<\/span>\n<span id=\"cb26-7\"><a href=\"#cb26-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-8\"><a href=\"#cb26-8\" aria-hidden=\"true\"><\/a>cycle <span class=\"op\">=<\/span> plt.rcParams[<span class=\"st\">&#39;axes.prop_cycle&#39;<\/span>]()  <span class=\"co\"># infinite iterator over the default color cycle<\/span><\/span>\n<span id=\"cb26-9\"><a href=\"#cb26-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-10\"><a href=\"#cb26-10\" aria-hidden=\"true\"><\/a>bivariate_obs_cycler <span class=\"op\">=<\/span> cycler(<span class=\"st\">&#39;linestyle&#39;<\/span>, [<span class=\"st\">&#39;-&#39;<\/span>, <span class=\"st\">&#39;--&#39;<\/span>]) <span class=\"op\">*<\/span> cycler(<span class=\"st\">&#39;color&#39;<\/span>, [<span 
class=\"st\">&#39;black&#39;<\/span>])<\/span>\n<span id=\"cb26-11\"><a href=\"#cb26-11\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-12\"><a href=\"#cb26-12\" aria-hidden=\"true\"><\/a>ax.set_prop_cycle(bivariate_obs_cycler)<\/span>\n<span id=\"cb26-13\"><a href=\"#cb26-13\" aria-hidden=\"true\"><\/a>ax.plot(nb_theta_t_sim, label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$\\theta_t$&#39;<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">1.0<\/span>)<\/span>\n<span id=\"cb26-14\"><a href=\"#cb26-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-15\"><a href=\"#cb26-15\" aria-hidden=\"true\"><\/a>ax.autoscale(enable<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb26-16\"><a href=\"#cb26-16\" aria-hidden=\"true\"><\/a>ax.autoscale(enable<span class=\"op\">=<\/span><span class=\"va\">False<\/span>)<\/span>\n<span id=\"cb26-17\"><a href=\"#cb26-17\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-18\"><a href=\"#cb26-18\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> d <span class=\"kw\">in<\/span> <span class=\"bu\">range<\/span>(thetas_shape[<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]):<\/span>\n<span id=\"cb26-19\"><a href=\"#cb26-19\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-20\"><a href=\"#cb26-20\" aria-hidden=\"true\"><\/a>    styles <span class=\"op\">=<\/span> <span class=\"bu\">next<\/span>(cycle)<\/span>\n<span id=\"cb26-21\"><a href=\"#cb26-21\" aria-hidden=\"true\"><\/a>    thetas <span class=\"op\">=<\/span> nb_posterior_samples[theta_label][<span class=\"dv\">0<\/span>].T[d].T<\/span>\n<span id=\"cb26-22\"><a href=\"#cb26-22\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-23\"><a href=\"#cb26-23\" aria-hidden=\"true\"><\/a>    theta_lines <span class=\"op\">=<\/span> np.empty(thetas_shape[:<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>] <span class=\"op\">+<\/span> (<span class=\"dv\">2<\/span>,))<\/span>\n<span 
id=\"cb26-24\"><a href=\"#cb26-24\" aria-hidden=\"true\"><\/a>    theta_lines.T[<span class=\"dv\">0<\/span>] <span class=\"op\">=<\/span> np.tile(np.arange(thetas_shape[<span class=\"op\">-<\/span><span class=\"dv\">2<\/span>]), [thetas_shape[<span class=\"op\">-<\/span><span class=\"dv\">3<\/span>], <span class=\"dv\">1<\/span>]).T<\/span>\n<span id=\"cb26-25\"><a href=\"#cb26-25\" aria-hidden=\"true\"><\/a>    theta_lines.T[<span class=\"dv\">1<\/span>] <span class=\"op\">=<\/span> thetas.T<\/span>\n<span id=\"cb26-26\"><a href=\"#cb26-26\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-27\"><a href=\"#cb26-27\" aria-hidden=\"true\"><\/a>    ax.add_collection(<\/span>\n<span id=\"cb26-28\"><a href=\"#cb26-28\" aria-hidden=\"true\"><\/a>        LineCollection(theta_lines,<\/span>\n<span id=\"cb26-29\"><a href=\"#cb26-29\" aria-hidden=\"true\"><\/a>                       label<span class=\"op\">=<\/span>theta_label,<\/span>\n<span id=\"cb26-30\"><a href=\"#cb26-30\" aria-hidden=\"true\"><\/a>                       alpha<span class=\"op\">=<\/span><span class=\"fl\">0.05<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">0.9<\/span>,<\/span>\n<span id=\"cb26-31\"><a href=\"#cb26-31\" aria-hidden=\"true\"><\/a>                       <span class=\"op\">**<\/span>styles)<\/span>\n<span id=\"cb26-32\"><a href=\"#cb26-32\" aria-hidden=\"true\"><\/a>    )<\/span>\n<span id=\"cb26-33\"><a href=\"#cb26-33\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-34\"><a href=\"#cb26-34\" aria-hidden=\"true\"><\/a>plt.tight_layout()<\/span>\n<span id=\"cb26-35\"><a href=\"#cb26-35\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-36\"><a href=\"#cb26-36\" aria-hidden=\"true\"><\/a>plt.legend(framealpha<span class=\"op\">=<\/span><span class=\"fl\">0.4<\/span>)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 26\n<\/figcaption>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img 
src=\"https:\/\/brandonwillard.github.io\/figures\/nb-ffbs-sim-plot.png\" title=\"fig:\" alt=\"Posterior \\theta_t samples generated by our Po\u0301lya-Gamma FFBS sampler. \" \/>\n<figcaption>\nPosterior <span class=\"math inline\">\\(\\theta_t\\)<\/span> samples generated by our Po\u0301lya-Gamma FFBS sampler.\n<\/figcaption>\n<\/figure>\n<figure id=\"orga6f3d67\">\n<div class=\"sourceCode\" id=\"cb27\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb27-1\"><a href=\"#cb27-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> arviz <span class=\"im\">as<\/span> az<\/span>\n<span id=\"cb27-2\"><a href=\"#cb27-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb27-3\"><a href=\"#cb27-3\" aria-hidden=\"true\"><\/a>az_trace <span class=\"op\">=<\/span> az.from_dict(posterior<span class=\"op\">=<\/span>nb_posterior_samples)<\/span>\n<span id=\"cb27-4\"><a href=\"#cb27-4\" aria-hidden=\"true\"><\/a>az.plot_trace(az_trace, compact<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 27\n<\/figcaption>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/nb-ffbs-trace-plot.png\" title=\"fig:\" alt=\"Posterior sample traces for our Po\u0301lya-Gamma FFBS Gibbs sampler. 
\" \/>\n<figcaption>\nPosterior sample traces for our Po\u0301lya-Gamma FFBS Gibbs sampler.\n<\/figcaption>\n<\/figure>\n<figure id=\"org528d6ec\">\n<div class=\"sourceCode\" id=\"cb28\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb28-1\"><a href=\"#cb28-1\" aria-hidden=\"true\"><\/a>plt.close(<span class=\"st\">&#39;all&#39;<\/span>)<\/span>\n<span id=\"cb28-2\"><a href=\"#cb28-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb28-3\"><a href=\"#cb28-3\" aria-hidden=\"true\"><\/a>fig, ax <span class=\"op\">=<\/span> plt.subplots(figsize<span class=\"op\">=<\/span>(<span class=\"dv\">8<\/span>, <span class=\"fl\">4.8<\/span>))<\/span>\n<span id=\"cb28-4\"><a href=\"#cb28-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb28-5\"><a href=\"#cb28-5\" aria-hidden=\"true\"><\/a>ax.plot(nb_p_t_sim, label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$p_t$&#39;<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">1.0<\/span>, color<span class=\"op\">=<\/span><span class=\"st\">&#39;black&#39;<\/span>)<\/span>\n<span id=\"cb28-6\"><a href=\"#cb28-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb28-7\"><a href=\"#cb28-7\" aria-hidden=\"true\"><\/a>ax.autoscale(enable<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb28-8\"><a href=\"#cb28-8\" aria-hidden=\"true\"><\/a>ax.autoscale(enable<span class=\"op\">=<\/span><span class=\"va\">False<\/span>)<\/span>\n<span id=\"cb28-9\"><a href=\"#cb28-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb28-10\"><a href=\"#cb28-10\" aria-hidden=\"true\"><\/a>mu_t_sim <span class=\"op\">=<\/span> np.exp(np.dot(nb_posterior_samples[theta_label][<span class=\"dv\">0<\/span>], nb_dlm_sim_values[F_tt].squeeze()))<\/span>\n<span id=\"cb28-11\"><a href=\"#cb28-11\" aria-hidden=\"true\"><\/a>p_t_sim <span class=\"op\">=<\/span> mu_t_sim <span class=\"op\">\/<\/span> (mu_t_sim <span class=\"op\">+<\/span> nb_dlm_sim_values[r_tt])<\/span>\n<span 
id=\"cb28-12\"><a href=\"#cb28-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb28-13\"><a href=\"#cb28-13\" aria-hidden=\"true\"><\/a>p_t_lines <span class=\"op\">=<\/span> np.empty(p_t_sim.shape <span class=\"op\">+<\/span> (<span class=\"dv\">2<\/span>,))<\/span>\n<span id=\"cb28-14\"><a href=\"#cb28-14\" aria-hidden=\"true\"><\/a>p_t_lines.T[<span class=\"dv\">0<\/span>] <span class=\"op\">=<\/span> np.tile(np.arange(p_t_sim.shape[<span class=\"dv\">1<\/span>]), [p_t_sim.shape[<span class=\"dv\">0<\/span>], <span class=\"dv\">1<\/span>]).T<\/span>\n<span id=\"cb28-15\"><a href=\"#cb28-15\" aria-hidden=\"true\"><\/a>p_t_lines.T[<span class=\"dv\">1<\/span>] <span class=\"op\">=<\/span> p_t_sim.T<\/span>\n<span id=\"cb28-16\"><a href=\"#cb28-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb28-17\"><a href=\"#cb28-17\" aria-hidden=\"true\"><\/a>ax.add_collection(<\/span>\n<span id=\"cb28-18\"><a href=\"#cb28-18\" aria-hidden=\"true\"><\/a>    LineCollection(p_t_lines,<\/span>\n<span id=\"cb28-19\"><a href=\"#cb28-19\" aria-hidden=\"true\"><\/a>                   label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$p_t \\mid \\theta_t, D_T$&#39;<\/span>,<\/span>\n<span id=\"cb28-20\"><a href=\"#cb28-20\" aria-hidden=\"true\"><\/a>                   alpha<span class=\"op\">=<\/span><span class=\"fl\">0.3<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">0.9<\/span>, color<span class=\"op\">=<\/span><span class=\"st\">&#39;red&#39;<\/span>)<\/span>\n<span id=\"cb28-21\"><a href=\"#cb28-21\" aria-hidden=\"true\"><\/a>)<\/span>\n<span id=\"cb28-22\"><a href=\"#cb28-22\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb28-23\"><a href=\"#cb28-23\" aria-hidden=\"true\"><\/a>plt.tight_layout()<\/span>\n<span id=\"cb28-24\"><a href=\"#cb28-24\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb28-25\"><a href=\"#cb28-25\" aria-hidden=\"true\"><\/a>plt.legend(framealpha<span class=\"op\">=<\/span><span 
class=\"fl\">0.4<\/span>)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 28\n<\/figcaption>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/nb-ffbs-sim-pred-plot.png\" title=\"fig:\" alt=\"p_t \\mid \\theta_t, D_T samples generated by our Po\u0301lya-Gamma FFBS sampler. \" \/>\n<figcaption>\n<span class=\"math inline\">\\(p_t \\mid \\theta_t, D_T\\)<\/span> samples generated by our Po\u0301lya-Gamma FFBS sampler.\n<\/figcaption>\n<\/figure>\n<p><a id=\"orgf74c6fe\"><\/a><\/p>\n<\/section>\n<section id=\"covid19-example\" class=\"level2\">\n<h2>COVID\u201319 Example<\/h2>\n<p>As another example, we\u2019ll use data from the COVID\u201319 outbreak in New York and concentrate only on the positive test counts. We make no assertions about the results; we simply want to apply our implementation to a real-life dataset. If anything, we would like to highlight the potential of these flexible and constructive Bayesian frameworks and offer a concrete starting point for motivated statisticians and Python developers.<\/p>\n<p>To start, Listing <a href=\"#org703a0b2\">29<\/a> pulls and formats the COVID\u201319 data and Figure <a href=\"#org9c1b4d3\">30<\/a> plots it.<\/p>\n<figure id=\"org703a0b2\">\n<div class=\"sourceCode\" id=\"cb29\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb29-1\"><a href=\"#cb29-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> pandas <span class=\"im\">as<\/span> pd<\/span>\n<span id=\"cb29-2\"><a href=\"#cb29-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb29-3\"><a href=\"#cb29-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb29-4\"><a href=\"#cb29-4\" aria-hidden=\"true\"><\/a>url <span class=\"op\">=<\/span> <span class=\"st\">&#39;https:\/\/covidtracking.com\/api\/v1\/states\/daily.csv&#39;<\/span><\/span>\n<span id=\"cb29-5\"><a href=\"#cb29-5\" aria-hidden=\"true\"><\/a>states <span 
class=\"op\">=<\/span> pd.read_csv(url, parse_dates<span class=\"op\">=<\/span>[<span class=\"st\">&#39;date&#39;<\/span>], index_col<span class=\"op\">=<\/span>[<span class=\"st\">&#39;state&#39;<\/span>,<span class=\"st\">&#39;date&#39;<\/span>]).sort_index()<\/span>\n<span id=\"cb29-6\"><a href=\"#cb29-6\" aria-hidden=\"true\"><\/a>state <span class=\"op\">=<\/span> <span class=\"st\">&#39;NY&#39;<\/span><\/span>\n<span id=\"cb29-7\"><a href=\"#cb29-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb29-8\"><a href=\"#cb29-8\" aria-hidden=\"true\"><\/a>counts <span class=\"op\">=<\/span> states.loc[state, [<span class=\"st\">&#39;positive&#39;<\/span>, <span class=\"st\">&#39;total&#39;<\/span>]].diff().loc[<span class=\"st\">&#39;2020-03-20&#39;<\/span>:]<\/span>\n<span id=\"cb29-9\"><a href=\"#cb29-9\" aria-hidden=\"true\"><\/a>counts.columns <span class=\"op\">=<\/span> [<span class=\"st\">&#39;new_positive_tests&#39;<\/span>, <span class=\"st\">&#39;new_total_tests&#39;<\/span>]<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 29\n<\/figcaption>\n<\/figure>\n<figure id=\"org9c1b4d3\">\n<div class=\"sourceCode\" id=\"cb30\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb30-1\"><a href=\"#cb30-1\" aria-hidden=\"true\"><\/a>plt.clf()<\/span>\n<span id=\"cb30-2\"><a href=\"#cb30-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb30-3\"><a href=\"#cb30-3\" aria-hidden=\"true\"><\/a>fig, ax <span class=\"op\">=<\/span> plt.subplots(figsize<span class=\"op\">=<\/span>(<span class=\"dv\">8<\/span>, <span class=\"fl\">4.8<\/span>))<\/span>\n<span id=\"cb30-4\"><a href=\"#cb30-4\" aria-hidden=\"true\"><\/a>_ <span class=\"op\">=<\/span> ax.plot(counts.new_positive_tests, label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$y_t$&#39;<\/span>, color<span class=\"op\">=<\/span><span class=\"st\">&#39;black&#39;<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">1.2<\/span>, drawstyle<span 
class=\"op\">=<\/span><span class=\"st\">&#39;steps-pre&#39;<\/span>)<\/span>\n<span id=\"cb30-5\"><a href=\"#cb30-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb30-6\"><a href=\"#cb30-6\" aria-hidden=\"true\"><\/a>plt.tight_layout()<\/span>\n<span id=\"cb30-7\"><a href=\"#cb30-7\" aria-hidden=\"true\"><\/a>plt.legend()<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 30\n<\/figcaption>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/nb-dlm-ny-plot.png\" title=\"fig:\" alt=\"COVID\u201319 positive test counts in NY \" \/>\n<figcaption>\nCOVID\u201319 positive test counts in NY\n<\/figcaption>\n<\/figure>\n<p>Our example model is a simple Gaussian random walk with an unknown evolution variance (again modeled by a gamma prior on the precision <span class=\"math inline\">\\(\\phi_W\\)<\/span>), given by the following settings, where <span class=\"math inline\">\\(\\phi_W = 10\\)<\/span> is the value used to initialize the sampler:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{gathered}\n    F_t = \\begin{pmatrix}\n      1\n    \\end{pmatrix}\n    ,\\quad\n    G_t = \\begin{pmatrix}\n      1\n    \\end{pmatrix}\n    \\\\\n    W_t = \\begin{pmatrix}\n      \\phi_{W}^{-1}\n    \\end{pmatrix}\n    , \\quad\n    \\phi_{W} \\sim \\operatorname{Gamma}\\left(2.5, 0.5\\right)\n    , \\quad\n    \\phi_{W} = 10\n  \\end{gathered}\n  \\label{eq:ny-model-settings}\n  \\;.\n\\end{equation}\\]<\/span><\/p>\n<p>Again, we set <span class=\"math inline\">\\(r = 1000\\)<\/span>, so that our negative-binomial observation model is effectively a close approximation to a Poisson model. 
Listing <a href=\"#org2352b66\">31<\/a> sets the initial values in <span class=\"math inline\">\\(\\eqref{eq:ny-model-settings}\\)<\/span> and Listing <a href=\"#orgde4fb94\">32<\/a> specifies the sampler loop.<\/p>\n<figure id=\"org2352b66\">\n<div class=\"sourceCode\" id=\"cb31\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb31-1\"><a href=\"#cb31-1\" aria-hidden=\"true\"><\/a>nb_dlm_ny_values <span class=\"op\">=<\/span> {<\/span>\n<span id=\"cb31-2\"><a href=\"#cb31-2\" aria-hidden=\"true\"><\/a>    N_obs_tt: counts.new_positive_tests.shape[<span class=\"dv\">0<\/span>],<\/span>\n<span id=\"cb31-3\"><a href=\"#cb31-3\" aria-hidden=\"true\"><\/a>    N_theta_tt: <span class=\"dv\">1<\/span><\/span>\n<span id=\"cb31-4\"><a href=\"#cb31-4\" aria-hidden=\"true\"><\/a>}<\/span>\n<span id=\"cb31-5\"><a href=\"#cb31-5\" aria-hidden=\"true\"><\/a>nb_dlm_ny_values[F_tt] <span class=\"op\">=<\/span> np.array([[<span class=\"fl\">1.0<\/span>]], dtype<span class=\"op\">=<\/span>theano.config.floatX)<\/span>\n<span id=\"cb31-6\"><a href=\"#cb31-6\" aria-hidden=\"true\"><\/a>nb_dlm_ny_values[G_tt] <span class=\"op\">=<\/span> np.array([[<span class=\"fl\">1.0<\/span>]], dtype<span class=\"op\">=<\/span>theano.config.floatX)<\/span>\n<span id=\"cb31-7\"><a href=\"#cb31-7\" aria-hidden=\"true\"><\/a>nb_dlm_ny_values[r_tt] <span class=\"op\">=<\/span> <span class=\"dv\">1000<\/span><\/span>\n<span id=\"cb31-8\"><a href=\"#cb31-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb31-9\"><a href=\"#cb31-9\" aria-hidden=\"true\"><\/a>phi_W_a.set_value(np.r_[<span class=\"fl\">2.5<\/span>])<\/span>\n<span id=\"cb31-10\"><a href=\"#cb31-10\" aria-hidden=\"true\"><\/a>phi_W_b.set_value(np.r_[<span class=\"fl\">0.5<\/span>])<\/span>\n<span id=\"cb31-11\"><a href=\"#cb31-11\" aria-hidden=\"true\"><\/a>phi_W_tt.set_value(np.r_[<span class=\"fl\">10.0<\/span>])<\/span>\n<span id=\"cb31-12\"><a href=\"#cb31-12\" aria-hidden=\"true\"><\/a><\/span>\n<span 
id=\"cb31-13\"><a href=\"#cb31-13\" aria-hidden=\"true\"><\/a>rng_tt.get_value(borrow<span class=\"op\">=<\/span><span class=\"va\">True<\/span>).set_state(rng_init_state)<\/span>\n<span id=\"cb31-14\"><a href=\"#cb31-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb31-15\"><a href=\"#cb31-15\" aria-hidden=\"true\"><\/a>y_raw_tt.set_value(np.expand_dims(counts.new_positive_tests, <span class=\"op\">-<\/span><span class=\"dv\">1<\/span>).astype(theano.config.floatX))<\/span>\n<span id=\"cb31-16\"><a href=\"#cb31-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb31-17\"><a href=\"#cb31-17\" aria-hidden=\"true\"><\/a>r_ny <span class=\"op\">=<\/span> np.array(nb_dlm_ny_values[r_tt], dtype<span class=\"op\">=<\/span><span class=\"st\">&#39;double&#39;<\/span>)<\/span>\n<span id=\"cb31-18\"><a href=\"#cb31-18\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb31-19\"><a href=\"#cb31-19\" aria-hidden=\"true\"><\/a>omega_0 <span class=\"op\">=<\/span> np.empty(counts.new_positive_tests.shape[<span class=\"dv\">0<\/span>], dtype<span class=\"op\">=<\/span><span class=\"st\">&#39;double&#39;<\/span>)<\/span>\n<span id=\"cb31-20\"><a href=\"#cb31-20\" aria-hidden=\"true\"><\/a>PyPolyaGamma(<span class=\"dv\">12344<\/span>).pgdrawv(r_ny <span class=\"op\">+<\/span> counts.new_positive_tests.values,<\/span>\n<span id=\"cb31-21\"><a href=\"#cb31-21\" aria-hidden=\"true\"><\/a>                            F_t_theta_0.squeeze() <span class=\"op\">-<\/span> np.log(r_ny),<\/span>\n<span id=\"cb31-22\"><a href=\"#cb31-22\" aria-hidden=\"true\"><\/a>                            omega_0)<\/span>\n<span id=\"cb31-23\"><a href=\"#cb31-23\" aria-hidden=\"true\"><\/a>omega_0 <span class=\"op\">=<\/span> np.expand_dims(omega_0, <span class=\"op\">-<\/span><span class=\"dv\">1<\/span>)<\/span>\n<span id=\"cb31-24\"><a href=\"#cb31-24\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb31-25\"><a href=\"#cb31-25\" aria-hidden=\"true\"><\/a>y_aug_0 <span class=\"op\">=<\/span> 
np.log(r_ny) <span class=\"op\">+<\/span> (counts.new_positive_tests.values[:, np.newaxis] <span class=\"op\">-<\/span> r_ny) <span class=\"op\">\/<\/span> (<span class=\"fl\">2.0<\/span> <span class=\"op\">*<\/span> omega_0)<\/span>\n<span id=\"cb31-26\"><a href=\"#cb31-26\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb31-27\"><a href=\"#cb31-27\" aria-hidden=\"true\"><\/a>omega_t_tt.set_value(omega_0)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 31\n<\/figcaption>\n<\/figure>\n<figure id=\"orgde4fb94\">\n<div class=\"sourceCode\" id=\"cb32\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb32-1\"><a href=\"#cb32-1\" aria-hidden=\"true\"><\/a>nb_ffbs_dlm <span class=\"op\">=<\/span> tt_function([N_obs_tt, N_theta_tt, y_tt, G_tt, F_tt, V_t_tt, r_tt],<\/span>\n<span id=\"cb32-2\"><a href=\"#cb32-2\" aria-hidden=\"true\"><\/a>                          [theta_t_post, phi_W_post_tt, omega_post_tt, y_aug_post_tt],<\/span>\n<span id=\"cb32-3\"><a href=\"#cb32-3\" aria-hidden=\"true\"><\/a>                          <span class=\"co\"># mode=&#39;FAST_COMPILE&#39;,<\/span><\/span>\n<span id=\"cb32-4\"><a href=\"#cb32-4\" aria-hidden=\"true\"><\/a>                          updates<span class=\"op\">=<\/span>ffbs_updates)<\/span>\n<span id=\"cb32-5\"><a href=\"#cb32-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb32-6\"><a href=\"#cb32-6\" aria-hidden=\"true\"><\/a>rng_tt.get_value(borrow<span class=\"op\">=<\/span><span class=\"va\">True<\/span>).set_state(rng_sim_state)<\/span>\n<span id=\"cb32-7\"><a href=\"#cb32-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb32-8\"><a href=\"#cb32-8\" aria-hidden=\"true\"><\/a>chain <span class=\"op\">=<\/span> <span class=\"dv\">0<\/span><\/span>\n<span id=\"cb32-9\"><a href=\"#cb32-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb32-10\"><a href=\"#cb32-10\" aria-hidden=\"true\"><\/a>theta_label <span class=\"op\">=<\/span> <span class=\"vs\">r&#39;$\\theta_t \\mid 
D_T$&#39;<\/span><\/span>\n<span id=\"cb32-11\"><a href=\"#cb32-11\" aria-hidden=\"true\"><\/a>phi_W_label <span class=\"op\">=<\/span> <span class=\"vs\">r&#39;$\\phi_W \\mid D_T$&#39;<\/span><\/span>\n<span id=\"cb32-12\"><a href=\"#cb32-12\" aria-hidden=\"true\"><\/a>omega_label <span class=\"op\">=<\/span> <span class=\"vs\">r&#39;$\\omega_t \\mid D_T$&#39;<\/span><\/span>\n<span id=\"cb32-13\"><a href=\"#cb32-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb32-14\"><a href=\"#cb32-14\" aria-hidden=\"true\"><\/a>theta_t_post_sim, phi_W_post_sim, omega_post_sim, y_aug_post_sim <span class=\"op\">=<\/span> <span class=\"va\">None<\/span>, <span class=\"va\">None<\/span>, <span class=\"va\">None<\/span>, <span class=\"va\">None<\/span><\/span>\n<span id=\"cb32-15\"><a href=\"#cb32-15\" aria-hidden=\"true\"><\/a>nb_ny_posterior_samples <span class=\"op\">=<\/span> {theta_label: [[]], phi_W_label: [[]], omega_label: [[]]}<\/span>\n<span id=\"cb32-16\"><a href=\"#cb32-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb32-17\"><a href=\"#cb32-17\" aria-hidden=\"true\"><\/a>V_t_sim <span class=\"op\">=<\/span> np.reciprocal(omega_0)<\/span>\n<span id=\"cb32-18\"><a href=\"#cb32-18\" aria-hidden=\"true\"><\/a>y_aug_sim <span class=\"op\">=<\/span> y_aug_0<\/span>\n<span id=\"cb32-19\"><a href=\"#cb32-19\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb32-20\"><a href=\"#cb32-20\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> i <span class=\"kw\">in<\/span> <span class=\"bu\">range<\/span>(<span class=\"dv\">5000<\/span>):<\/span>\n<span id=\"cb32-21\"><a href=\"#cb32-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb32-22\"><a href=\"#cb32-22\" aria-hidden=\"true\"><\/a>    nb_ffbs_res <span class=\"op\">=<\/span> nb_ffbs_dlm(<\/span>\n<span id=\"cb32-23\"><a href=\"#cb32-23\" aria-hidden=\"true\"><\/a>        nb_dlm_ny_values[N_obs_tt],<\/span>\n<span id=\"cb32-24\"><a href=\"#cb32-24\" aria-hidden=\"true\"><\/a>        
nb_dlm_ny_values[N_theta_tt],<\/span>\n<span id=\"cb32-25\"><a href=\"#cb32-25\" aria-hidden=\"true\"><\/a>        y_aug_sim,<\/span>\n<span id=\"cb32-26\"><a href=\"#cb32-26\" aria-hidden=\"true\"><\/a>        nb_dlm_ny_values[G_tt],<\/span>\n<span id=\"cb32-27\"><a href=\"#cb32-27\" aria-hidden=\"true\"><\/a>        nb_dlm_ny_values[F_tt],<\/span>\n<span id=\"cb32-28\"><a href=\"#cb32-28\" aria-hidden=\"true\"><\/a>        V_t_sim,<\/span>\n<span id=\"cb32-29\"><a href=\"#cb32-29\" aria-hidden=\"true\"><\/a>        nb_dlm_ny_values[r_tt])<\/span>\n<span id=\"cb32-30\"><a href=\"#cb32-30\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb32-31\"><a href=\"#cb32-31\" aria-hidden=\"true\"><\/a>    theta_t_post_sim, phi_W_post_sim, omega_post_sim, y_aug_post_sim <span class=\"op\">=<\/span> nb_ffbs_res<\/span>\n<span id=\"cb32-32\"><a href=\"#cb32-32\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb32-33\"><a href=\"#cb32-33\" aria-hidden=\"true\"><\/a>    phi_W_tt.set_value(phi_W_post_sim)<\/span>\n<span id=\"cb32-34\"><a href=\"#cb32-34\" aria-hidden=\"true\"><\/a>    omega_t_tt.set_value(omega_post_sim)<\/span>\n<span id=\"cb32-35\"><a href=\"#cb32-35\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb32-36\"><a href=\"#cb32-36\" aria-hidden=\"true\"><\/a>    V_t_sim <span class=\"op\">=<\/span> np.reciprocal(omega_post_sim)<\/span>\n<span id=\"cb32-37\"><a href=\"#cb32-37\" aria-hidden=\"true\"><\/a>    y_aug_sim <span class=\"op\">=<\/span> y_aug_post_sim<\/span>\n<span id=\"cb32-38\"><a href=\"#cb32-38\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb32-39\"><a href=\"#cb32-39\" aria-hidden=\"true\"><\/a>    nb_ny_posterior_samples[theta_label][chain].append(theta_t_post_sim)<\/span>\n<span id=\"cb32-40\"><a href=\"#cb32-40\" aria-hidden=\"true\"><\/a>    nb_ny_posterior_samples[phi_W_label][chain].append(phi_W_post_sim)<\/span>\n<span id=\"cb32-41\"><a href=\"#cb32-41\" aria-hidden=\"true\"><\/a>    
nb_ny_posterior_samples[omega_label][chain].append(omega_post_sim)<\/span>\n<span id=\"cb32-42\"><a href=\"#cb32-42\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb32-43\"><a href=\"#cb32-43\" aria-hidden=\"true\"><\/a>    <span class=\"bu\">print<\/span>(<span class=\"ss\">f&#39;i=<\/span><span class=\"sc\">{i}<\/span><span class=\"ss\">,<\/span><span class=\"ch\">\\t<\/span><span class=\"ss\">phi_W=<\/span><span class=\"sc\">{<\/span>phi_W_post_sim<span class=\"sc\">}<\/span><span class=\"ss\">&#39;<\/span>)<\/span>\n<span id=\"cb32-44\"><a href=\"#cb32-44\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb32-45\"><a href=\"#cb32-45\" aria-hidden=\"true\"><\/a><span class=\"co\"># Convert to arrays and discard the first 1000 samples as burn-in<\/span><\/span>\n<span id=\"cb32-46\"><a href=\"#cb32-46\" aria-hidden=\"true\"><\/a>nb_ny_posterior_samples <span class=\"op\">=<\/span> {k: np.asarray(v)[:, <span class=\"dv\">1000<\/span>:] <span class=\"cf\">for<\/span> k,v <span class=\"kw\">in<\/span> nb_ny_posterior_samples.items()}<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 32\n<\/figcaption>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" id=\"cb33\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb33-1\"><a href=\"#cb33-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> arviz <span class=\"im\">as<\/span> az<\/span>\n<span id=\"cb33-2\"><a href=\"#cb33-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb33-3\"><a href=\"#cb33-3\" aria-hidden=\"true\"><\/a>az_trace <span class=\"op\">=<\/span> az.from_dict(posterior<span class=\"op\">=<\/span>nb_ny_posterior_samples)<\/span>\n<span id=\"cb33-4\"><a href=\"#cb33-4\" aria-hidden=\"true\"><\/a>az.plot_trace(az_trace, compact<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/nb-ffbs-ny-trace-plot.png\" title=\"fig:\" alt=\"Posterior sample traces for the NY 
data. \" \/>\n<figcaption>\nPosterior sample traces for the NY data.\n<\/figcaption>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" id=\"cb34\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb34-1\"><a href=\"#cb34-1\" aria-hidden=\"true\"><\/a>fig, ax <span class=\"op\">=<\/span> plt.subplots(figsize<span class=\"op\">=<\/span>(<span class=\"dv\">8<\/span>, <span class=\"fl\">4.8<\/span>))<\/span>\n<span id=\"cb34-2\"><a href=\"#cb34-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-3\"><a href=\"#cb34-3\" aria-hidden=\"true\"><\/a>thetas_shape <span class=\"op\">=<\/span> nb_ny_posterior_samples[theta_label][<span class=\"dv\">0<\/span>].shape<\/span>\n<span id=\"cb34-4\"><a href=\"#cb34-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-5\"><a href=\"#cb34-5\" aria-hidden=\"true\"><\/a><span class=\"co\"># One black line style (solid, then dashed) per state dimension<\/span><\/span>\n<span id=\"cb34-6\"><a href=\"#cb34-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-7\"><a href=\"#cb34-7\" aria-hidden=\"true\"><\/a>bivariate_obs_cycler <span class=\"op\">=<\/span> cycler(<span class=\"st\">&#39;linestyle&#39;<\/span>, [<span class=\"st\">&#39;-&#39;<\/span>, <span class=\"st\">&#39;--&#39;<\/span>]) <span class=\"op\">*<\/span> cycler(<span class=\"st\">&#39;color&#39;<\/span>, [<span class=\"st\">&#39;black&#39;<\/span>])<\/span>\n<span id=\"cb34-8\"><a href=\"#cb34-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-9\"><a href=\"#cb34-9\" aria-hidden=\"true\"><\/a>cycle <span class=\"op\">=<\/span> bivariate_obs_cycler()  <span class=\"co\"># calling a Cycler returns an infinite iterator of style dicts<\/span><\/span>\n<span id=\"cb34-10\"><a href=\"#cb34-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-11\"><a href=\"#cb34-11\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> d <span class=\"kw\">in<\/span> <span class=\"bu\">range<\/span>(thetas_shape[<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>]):<\/span>\n<span id=\"cb34-12\"><a href=\"#cb34-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-13\"><a 
href=\"#cb34-13\" aria-hidden=\"true\"><\/a>    styles <span class=\"op\">=<\/span> <span class=\"bu\">next<\/span>(cycle)<\/span>\n<span id=\"cb34-14\"><a href=\"#cb34-14\" aria-hidden=\"true\"><\/a>    thetas <span class=\"op\">=<\/span> nb_ny_posterior_samples[theta_label][<span class=\"dv\">0<\/span>].T[d].T<\/span>\n<span id=\"cb34-15\"><a href=\"#cb34-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-16\"><a href=\"#cb34-16\" aria-hidden=\"true\"><\/a>    theta_lines <span class=\"op\">=<\/span> np.empty(thetas_shape[:<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>] <span class=\"op\">+<\/span> (<span class=\"dv\">2<\/span>,))<\/span>\n<span id=\"cb34-17\"><a href=\"#cb34-17\" aria-hidden=\"true\"><\/a>    theta_lines.T[<span class=\"dv\">0<\/span>] <span class=\"op\">=<\/span> np.tile(np.arange(thetas_shape[<span class=\"op\">-<\/span><span class=\"dv\">2<\/span>]), [thetas_shape[<span class=\"op\">-<\/span><span class=\"dv\">3<\/span>], <span class=\"dv\">1<\/span>]).T<\/span>\n<span id=\"cb34-18\"><a href=\"#cb34-18\" aria-hidden=\"true\"><\/a>    theta_lines.T[<span class=\"dv\">1<\/span>] <span class=\"op\">=<\/span> thetas.T<\/span>\n<span id=\"cb34-19\"><a href=\"#cb34-19\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-20\"><a href=\"#cb34-20\" aria-hidden=\"true\"><\/a>    ax.add_collection(<\/span>\n<span id=\"cb34-21\"><a href=\"#cb34-21\" aria-hidden=\"true\"><\/a>        LineCollection(theta_lines,<\/span>\n<span id=\"cb34-22\"><a href=\"#cb34-22\" aria-hidden=\"true\"><\/a>                       label<span class=\"op\">=<\/span>theta_label,<\/span>\n<span id=\"cb34-23\"><a href=\"#cb34-23\" aria-hidden=\"true\"><\/a>                       alpha<span class=\"op\">=<\/span><span class=\"fl\">0.05<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">0.9<\/span>,<\/span>\n<span id=\"cb34-24\"><a href=\"#cb34-24\" aria-hidden=\"true\"><\/a>                       <span class=\"op\">**<\/span>styles)<\/span>\n<span 
id=\"cb34-25\"><a href=\"#cb34-25\" aria-hidden=\"true\"><\/a>    )<\/span>\n<span id=\"cb34-26\"><a href=\"#cb34-26\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-27\"><a href=\"#cb34-27\" aria-hidden=\"true\"><\/a>ax.autoscale(enable<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb34-28\"><a href=\"#cb34-28\" aria-hidden=\"true\"><\/a>plt.tight_layout()<\/span>\n<span id=\"cb34-29\"><a href=\"#cb34-29\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-30\"><a href=\"#cb34-30\" aria-hidden=\"true\"><\/a>plt.legend(framealpha<span class=\"op\">=<\/span><span class=\"fl\">0.4<\/span>)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/nb-ffbs-ny-sim-plot.png\" title=\"fig:\" alt=\"Posterior \\theta_t samples for the negative-binomial NY data model. \" \/>\n<figcaption>\nPosterior <span class=\"math inline\">\\(\\theta_t\\)<\/span> samples for the negative-binomial NY data model.\n<\/figcaption>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" id=\"cb35\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb35-1\"><a href=\"#cb35-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> scipy.stats<\/span>\n<span id=\"cb35-2\"><a href=\"#cb35-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb35-3\"><a href=\"#cb35-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb35-4\"><a href=\"#cb35-4\" aria-hidden=\"true\"><\/a>plt.close(<span class=\"st\">&#39;all&#39;<\/span>)<\/span>\n<span id=\"cb35-5\"><a href=\"#cb35-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb35-6\"><a href=\"#cb35-6\" aria-hidden=\"true\"><\/a>fig, ax <span class=\"op\">=<\/span> plt.subplots(figsize<span class=\"op\">=<\/span>(<span class=\"dv\">8<\/span>, <span class=\"fl\">4.8<\/span>))<\/span>\n<span id=\"cb35-7\"><a href=\"#cb35-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb35-8\"><a href=\"#cb35-8\" 
aria-hidden=\"true\"><\/a>ax.plot(counts.new_positive_tests.values, label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$y_t$&#39;<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">1.0<\/span>, color<span class=\"op\">=<\/span><span class=\"st\">&#39;black&#39;<\/span>)<\/span>\n<span id=\"cb35-9\"><a href=\"#cb35-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb35-10\"><a href=\"#cb35-10\" aria-hidden=\"true\"><\/a>mu_t_sim <span class=\"op\">=<\/span> np.exp(np.dot(nb_ny_posterior_samples[theta_label][<span class=\"dv\">0<\/span>], nb_dlm_ny_values[F_tt].squeeze()))<\/span>\n<span id=\"cb35-11\"><a href=\"#cb35-11\" aria-hidden=\"true\"><\/a>p_t_sim <span class=\"op\">=<\/span> mu_t_sim <span class=\"op\">\/<\/span> (mu_t_sim <span class=\"op\">+<\/span> nb_dlm_ny_values[r_tt])<\/span>\n<span id=\"cb35-12\"><a href=\"#cb35-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb35-13\"><a href=\"#cb35-13\" aria-hidden=\"true\"><\/a>y_t_sim <span class=\"op\">=<\/span> scipy.stats.nbinom.rvs(nb_dlm_ny_values[r_tt], (<span class=\"fl\">1.<\/span> <span class=\"op\">-<\/span> p_t_sim)).squeeze()<\/span>\n<span id=\"cb35-14\"><a href=\"#cb35-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb35-15\"><a href=\"#cb35-15\" aria-hidden=\"true\"><\/a>y_t_lines <span class=\"op\">=<\/span> np.empty(y_t_sim.shape <span class=\"op\">+<\/span> (<span class=\"dv\">2<\/span>,))<\/span>\n<span id=\"cb35-16\"><a href=\"#cb35-16\" aria-hidden=\"true\"><\/a>y_t_lines.T[<span class=\"dv\">0<\/span>] <span class=\"op\">=<\/span> np.tile(np.arange(y_t_sim.shape[<span class=\"dv\">1<\/span>]), [y_t_sim.shape[<span class=\"dv\">0<\/span>], <span class=\"dv\">1<\/span>]).T<\/span>\n<span id=\"cb35-17\"><a href=\"#cb35-17\" aria-hidden=\"true\"><\/a>y_t_lines.T[<span class=\"dv\">1<\/span>] <span class=\"op\">=<\/span> y_t_sim.T<\/span>\n<span id=\"cb35-18\"><a href=\"#cb35-18\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb35-19\"><a href=\"#cb35-19\" 
aria-hidden=\"true\"><\/a>ax.add_collection(<\/span>\n<span id=\"cb35-20\"><a href=\"#cb35-20\" aria-hidden=\"true\"><\/a>    LineCollection(y_t_lines,<\/span>\n<span id=\"cb35-21\"><a href=\"#cb35-21\" aria-hidden=\"true\"><\/a>                   label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$y_t \\mid \\theta_t, D_T$&#39;<\/span>,<\/span>\n<span id=\"cb35-22\"><a href=\"#cb35-22\" aria-hidden=\"true\"><\/a>                   alpha<span class=\"op\">=<\/span><span class=\"fl\">0.3<\/span>, linewidth<span class=\"op\">=<\/span><span class=\"fl\">0.9<\/span>, color<span class=\"op\">=<\/span><span class=\"st\">&#39;red&#39;<\/span>),<\/span>\n<span id=\"cb35-23\"><a href=\"#cb35-23\" aria-hidden=\"true\"><\/a>)<\/span>\n<span id=\"cb35-24\"><a href=\"#cb35-24\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb35-25\"><a href=\"#cb35-25\" aria-hidden=\"true\"><\/a>ax.autoscale(enable<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb35-26\"><a href=\"#cb35-26\" aria-hidden=\"true\"><\/a>plt.tight_layout()<\/span>\n<span id=\"cb35-27\"><a href=\"#cb35-27\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb35-28\"><a href=\"#cb35-28\" aria-hidden=\"true\"><\/a>plt.legend(framealpha<span class=\"op\">=<\/span><span class=\"fl\">0.4<\/span>)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/nb-ffbs-ny-pred-plot.png\" title=\"fig:\" alt=\"Posterior predictive values for the negative-binomial NY data model. 
\" \/>\n<figcaption>\nPosterior predictive values for the negative-binomial NY data model.\n<\/figcaption>\n<\/figure>\n<p><a id=\"org24da3be\"><\/a><\/p>\n<\/section>\n<\/section>\n<section id=\"discussion\" class=\"level1\">\n<h1>Discussion<\/h1>\n<p>We\u2019ve implemented a highly extensible Bayesian time series framework with broad model coverage (e.g.\u00a0ARIMA models, polynomial trends, Fourier seasonality, and dynamic regression) and shown how this basic class of models can be extended to non-Gaussian observations. We also showed how an existing model-specialized sampler, FFBS, can be adapted to such an extension by exploiting the conditional linearity provided by a Gaussian scale-mixture.<\/p>\n<p>In general, the same scale-mixture and conditional linearity techniques can be employed for other observation distributions (e.g.\u00a0Beta, Binomial), as well as state-of-the-art shrinkage and sparsity priors <a id=\"4902d622d6ba9f93a92f53f2df8a726b\"><a href=\"#BhadraDefaultBayesiananalysis2016\">(Bhadra, Datta, Polson &amp; Willard 2016)<\/a><\/a>,<a id=\"4d6738c60064f8c0e0d27f8cd67cff8c\"><a href=\"#datta2016bayesian\">(Datta &amp; Dunson 2016)<\/a><\/a>.<\/p>\n<p>Although we used Gibbs sampling, these techniques do not require it; Metropolis steps, or any other mathematically sound steps, can be used instead.<\/p>\n<p>Bayesian frameworks like these are very flexible and leave considerable room for analytic creativity; unfortunately, we rarely see that flexibility exploited directly in practice. The average applied statistician or data science developer faces clear roadblocks when implementing these models straight from published papers. 
Those roadblocks often involve subtle numerical stability issues and a level of mathematical proficiency that is not always standard in the field.<\/p>\n<p>We hope that future automation will address the more straightforward numerical stability issues by, for instance, automating the SVD-based reformulations used here. The scale-mixture manipulations are likewise amenable to some automation; at the very least, the Bayesian community, and anyone who wants to build non-trivial, principled statistical models, would be well served by low-level tools that codify and generate such scale-mixture variations. Such tools would make even experimentation by knowledgeable developers much easier and less error-prone.<\/p>\n<\/section>\n<section id=\"bibliography\" class=\"level1\">\n<h1>Bibliography<\/h1>\n<p><a id=\"harrison_bayesian_1999\"><\/a> Harrison &amp; West, Bayesian Forecasting &amp; Dynamic Models, Springer (1999). <a href=\"#4bbd465b4e78e5c5151b0cbba54d984e\">\u21a9\ufe0e<\/a><\/p>\n<p><a id=\"PetrisDynamicLinearModels2009\"><\/a> Petris, Petrone &amp; Campagnoli, Dynamic Linear Models, Springer (2009). <a href=\"#1c3f471fd137724bd53b01eb4d6534fe\">\u21a9\ufe0e<\/a><\/p>\n<p><a id=\"GamermanMarkovchainMonte2006\"><\/a> Gamerman &amp; Lopes, Markov Chain Monte Carlo: Stochastic Simulation for Bayesian Inference, CRC Press (2006). <a href=\"#8247a823e73eba801aad2942a49b03be\">\u21a9\ufe0e<\/a><\/p>\n<p><a id=\"pole_applied_1994\"><\/a> Pole, West &amp; Harrison, Applied Bayesian Forecasting and Time Series Analysis, CRC Press (1994). 
<a href=\"#c6a5d83a82b5963f42264d85625b5153\">\u21a9\ufe0e<\/a><\/p>\n<p><a id=\"bergstra_theano:_2010\"><\/a> Bergstra, Breuleux, Bastien, Lamblin, Pascanu, Desjardins, Turian, Warde-Farley &amp; Bengio, Theano: A CPU and GPU Math Expression Compiler, in: Proceedings of the Python for Scientific Computing Conference (SciPy) (2010). <a href=\"#25fdcab375e353d7bb32eec7b13064c6\">\u21a9\ufe0e<\/a><\/p>\n<p><a id=\"Willardsymbolicpymc2019\"><\/a> Willard, Symbolic PyMC (2019). <a href=\"https:\/\/github.com\/pymc-devs\/symbolic-pymc\">link<\/a>. <a href=\"#e8fa26b92264fca946be25e6b617fc56\">\u21a9\ufe0e<\/a><\/p>\n<p><a id=\"ZhangFixedintervalsmoothingalgorithm1996\"><\/a> Zhang &amp; Li, Fixed-Interval Smoothing Algorithm Based on Singular Value Decomposition, in: Proceedings of the 1996 IEEE International Conference on Control Applications, 916\u2013921 (1996). <a href=\"#0ae04c048b20d07f32d7f0f75bb51483\">\u21a9\ufe0e<\/a><\/p>\n<p><a id=\"PetrisDynamiclinearmodels2009\"><\/a> Petris, Petrone &amp; Campagnoli, Dynamic Linear Models, Springer (2009). <a href=\"#3a4d89388a434d7b1b91dc8690f3a03b\">\u21a9\ufe0e<\/a><\/p>\n<p><a id=\"Fruhwirth-SchnatterDataaugmentationdynamic1994\"><\/a> Fr\u00fchwirth-Schnatter, Data Augmentation and Dynamic Linear Models, <i>Journal of Time Series Analysis<\/i>, <b>15(2)<\/b>, 183\u2013202 (1994). <a href=\"http:\/\/onlinelibrary.wiley.com\/doi\/10.1111\/j.1467-9892.1994.tb00184.x\/abstract\">link<\/a>. <a href=\"#66fb99775f308e808a193bd7bb2d2038\">\u21a9\ufe0e<\/a><\/p>\n<p><a id=\"mccullagh_generalized_1989\"><\/a> McCullagh &amp; Nelder, Generalized Linear Models, CRC Press (1989). <a href=\"#820d466dbb5494bd5a5cb080d0b82638\">\u21a9\ufe0e<\/a><\/p>\n<p><a id=\"polson_bayesian_2013\"><\/a> Polson, Scott &amp; Windle, Bayesian Inference for Logistic Models Using P\u00f3lya-Gamma Latent Variables, <i>Journal of the American Statistical Association<\/i>, <b>108(504)<\/b>, 1339\u20131349 (2013). 
<a href=\"http:\/\/www.tandfonline.com\/doi\/abs\/10.1080\/01621459.2013.829001\">link<\/a>. <a href=\"#2776c69c558410bc65a3a0a0f1ff5bf4\">\u21a9\ufe0e<\/a><\/p>\n<p><a id=\"BhadraDefaultBayesiananalysis2016\"><\/a> Bhadra, Datta, Polson &amp; Willard, Default Bayesian analysis with global-local shrinkage priors, <i>Biometrika<\/i>, <b>103(4)<\/b>, 955\u2013969 (2016). <a href=\"http:\/\/dx.doi.org\/10.1093\/biomet\/asw041\">doi<\/a>. <a href=\"#4902d622d6ba9f93a92f53f2df8a726b\">\u21a9\ufe0e<\/a><\/p>\n<p><a id=\"datta2016bayesian\"><\/a> Datta &amp; Dunson, Bayesian Inference on Quasi-Sparse Count Data, <i>Biometrika<\/i>, <b>103(4)<\/b>, 971\u2013983 (2016). <a href=\"#4d6738c60064f8c0e0d27f8cd67cff8c\">\u21a9\ufe0e<\/a><\/p>\n<\/section>\n<\/body>\n<\/html>\n","category":[{"@attributes":{"term":"articles"}},{"@attributes":{"term":"symbolic-pymc"}},{"@attributes":{"term":"theano"}},{"@attributes":{"term":"statistics"}},{"@attributes":{"term":"timeseries"}},{"@attributes":{"term":"dlm"}},{"@attributes":{"term":"ffbs"}},{"@attributes":{"term":"gibbs"}}]},{"title":"Symbolic PyMC Radon Example in PyMC4","link":{"@attributes":{"href":"https:\/\/brandonwillard.github.io\/symbolic-pymc-radon-example-in-pymc4.html","rel":"alternate"}},"published":"2019-09-08T00:00:00-05:00","updated":"2019-10-24T00:00:00-05:00","author":{"name":"Brandon T. Willard"},"id":"tag:brandonwillard.github.io,2019-09-08:\/symbolic-pymc-radon-example-in-pymc4.html","summary":{"@attributes":{"type":"html"}},"content":"<!DOCTYPE html PUBLIC \"-\/\/W3C\/\/DTD XHTML 1.0 Transitional\/\/EN\" \"http:\/\/www.w3.org\/TR\/xhtml1\/DTD\/xhtml1-transitional.dtd\">\n<html xmlns=\"http:\/\/www.w3.org\/1999\/xhtml\">\n<head>\n  <meta http-equiv=\"Content-Type\" content=\"text\/html; charset=utf-8\" \/>\n  <meta http-equiv=\"Content-Style-Type\" content=\"text\/css\" \/>\n  <meta name=\"generator\" content=\"pandoc\" \/>\n  <meta name=\"author\" content=\"Brandon T. 
Willard\" \/>\n  <title>Symbolic PyMC Radon Example in PyMC4<\/title>\n  <style type=\"text\/css\">code{white-space: pre;}<\/style>\n  <style type=\"text\/css\">\npre > code.sourceCode { white-space: pre; position: relative; }\npre > code.sourceCode > span { display: inline-block; line-height: 1.25; }\npre > code.sourceCode > span:empty { height: 1.2em; }\ncode.sourceCode > span { color: inherit; text-decoration: inherit; }\ndiv.sourceCode { margin: 1em 0; }\npre.sourceCode { margin: 0; }\n@media screen {\ndiv.sourceCode { overflow: auto; }\n}\n@media print {\npre > code.sourceCode { white-space: pre-wrap; }\npre > code.sourceCode > span { text-indent: -5em; padding-left: 5em; }\n}\npre.numberSource code\n  { counter-reset: source-line 0; }\npre.numberSource code > span\n  { position: relative; left: -4em; counter-increment: source-line; }\npre.numberSource code > span > a:first-child::before\n  { content: counter(source-line);\n    position: relative; left: -1em; text-align: right; vertical-align: baseline;\n    border: none; display: inline-block;\n    -webkit-touch-callout: none; -webkit-user-select: none;\n    -khtml-user-select: none; -moz-user-select: none;\n    -ms-user-select: none; user-select: none;\n    padding: 0 4px; width: 4em;\n    color: #aaaaaa;\n  }\npre.numberSource { margin-left: 3em; border-left: 1px solid #aaaaaa;  padding-left: 4px; }\ndiv.sourceCode\n  {   }\n@media screen {\npre > code.sourceCode > span > a:first-child::before { text-decoration: underline; }\n}\ncode span.al { color: #ff0000; font-weight: bold; } \/* Alert *\/\ncode span.an { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Annotation *\/\ncode span.at { color: #7d9029; } \/* Attribute *\/\ncode span.bn { color: #40a070; } \/* BaseN *\/\ncode span.bu { } \/* BuiltIn *\/\ncode span.cf { color: #007020; font-weight: bold; } \/* ControlFlow *\/\ncode span.ch { color: #4070a0; } \/* Char *\/\ncode span.cn { color: #880000; } \/* Constant *\/\ncode span.co { color: 
#60a0b0; font-style: italic; } \/* Comment *\/\ncode span.cv { color: #60a0b0; font-weight: bold; font-style: italic; } \/* CommentVar *\/\ncode span.do { color: #ba2121; font-style: italic; } \/* Documentation *\/\ncode span.dt { color: #902000; } \/* DataType *\/\ncode span.dv { color: #40a070; } \/* DecVal *\/\ncode span.er { color: #ff0000; font-weight: bold; } \/* Error *\/\ncode span.ex { } \/* Extension *\/\ncode span.fl { color: #40a070; } \/* Float *\/\ncode span.fu { color: #06287e; } \/* Function *\/\ncode span.im { } \/* Import *\/\ncode span.in { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Information *\/\ncode span.kw { color: #007020; font-weight: bold; } \/* Keyword *\/\ncode span.op { color: #666666; } \/* Operator *\/\ncode span.ot { color: #007020; } \/* Other *\/\ncode span.pp { color: #bc7a00; } \/* Preprocessor *\/\ncode span.sc { color: #4070a0; } \/* SpecialChar *\/\ncode span.ss { color: #bb6688; } \/* SpecialString *\/\ncode span.st { color: #4070a0; } \/* String *\/\ncode span.va { color: #19177c; } \/* Variable *\/\ncode span.vs { color: #4070a0; } \/* VerbatimString *\/\ncode span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Warning *\/\n  <\/style>\n  <!--        <script src=\"https:\/\/cdn.jsdelivr.net\/npm\/mathjax@3\/es5\/tex-mml-chtml.js\" type=\"text\/javascript\"><\/script> -->\n  <script src=\"https:\/\/cdnjs.cloudflare.com\/ajax\/libs\/mathjax\/2.7.0\/MathJax.js?config=TeX-AMS_HTML\" id=\"MathJax-script\"><\/script>\n  <script>\n   MathJax.Hub.Config({\n       tex2jax: {\n           processEnvironments: true,\n           processRefs: false\n       },\n       TeX: {\n           equationNumbers: { autoNumber: \"AMS\" },\n           extensions: [\"AMSmath.js\",\"AMSsymbols.js\",\"noErrors.js\",\"noUndefined.js\"]\n       }\n   });\n  <\/script>\n<\/head>\n<body>\n<!--  -->\n<!-- <div id=\"header\"> -->\n<!-- <h1 class=\"title\">Symbolic PyMC Radon Example in PyMC4<\/h1> -->\n<!--  -->\n<!--  
-->\n<!-- <h2 class=\"author\">Brandon T. Willard<\/h2> -->\n<!--  -->\n<!--  -->\n<!-- <h3 class=\"date\">2019\u201309\u201308<\/h3> -->\n<!--  -->\n<!-- <\/div> -->\n<!--  -->\n<section id=\"introduction\" class=\"level1\">\n<h1>Introduction<\/h1>\n<p><a href=\"https:\/\/github.com\/pymc-devs\/symbolic-pymc\">Symbolic PyMC<\/a> is a library that provides tools for symbolic manipulation of Tensor library models in TensorFlow (TF) and Theano. Over time, we plan to add tools that are mostly specialized toward Bayesian model manipulation and mathematical identities relevant to MCMC.<\/p>\n<p>The main approach taken by <code>symbolic-pymc<\/code> is relational\/logic programming powered by a <a href=\"http:\/\/minikanren.org\/\">miniKanren<\/a> implementation in pure Python (based on the <a href=\"https:\/\/github.com\/pymc-devs\/kanren\"><code>kanren<\/code><\/a> package).<\/p>\n<p>As an example of how <code>symbolic-pymc<\/code>\u2019s offerings can be used, we\u2019ll create a model \u201coptimizer\u201d that approximates the re-centering and re-scaling commonly demonstrated on a hierarchical normal model for the radon dataset. This optimization is <strong>symbolic<\/strong> and effectively produces another equivalent model with better sampling properties.<\/p>\n<p>A similar example already exists in Theano and PyMC3; it can be found in the <a href=\"https:\/\/github.com\/pymc-devs\/symbolic-pymc#automatic-re-centering-and-re-scaling\">project README<\/a>. 
In this case, we will operate on TF graphs via PyMC4 and approximate the same optimization using a very different approach targeted toward the log-likelihood graph.<\/p>\n<p>To get started, we download the radon dataset and define the initial model in Listings <a href=\"#org15f2d13\">1<\/a>, <a href=\"#org84a4bd7\">2<\/a>, and <a href=\"#orgf0b697c\">3<\/a>.<\/p>\n<figure id=\"org15f2d13\">\n<div class=\"sourceCode\" id=\"cb1\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb1-1\"><a href=\"#cb1-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> numpy <span class=\"im\">as<\/span> np<\/span>\n<span id=\"cb1-2\"><a href=\"#cb1-2\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> pandas <span class=\"im\">as<\/span> pd<\/span>\n<span id=\"cb1-3\"><a href=\"#cb1-3\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> tensorflow <span class=\"im\">as<\/span> tf<\/span>\n<span id=\"cb1-4\"><a href=\"#cb1-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-5\"><a href=\"#cb1-5\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> pymc4 <span class=\"im\">as<\/span> pm<\/span>\n<span id=\"cb1-6\"><a href=\"#cb1-6\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> arviz <span class=\"im\">as<\/span> az<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 1\n<\/figcaption>\n<\/figure>\n<figure id=\"org84a4bd7\">\n<div class=\"sourceCode\" id=\"cb2\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb2-1\"><a href=\"#cb2-1\" aria-hidden=\"true\"><\/a>data <span class=\"op\">=<\/span> pd.read_csv(<span class=\"st\">&#39;https:\/\/github.com\/pymc-devs\/pymc3\/raw\/master\/pymc3\/examples\/data\/radon.csv&#39;<\/span>)<\/span>\n<span id=\"cb2-2\"><a href=\"#cb2-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-3\"><a href=\"#cb2-3\" aria-hidden=\"true\"><\/a>county_names <span class=\"op\">=<\/span> data.county.unique()<\/span>\n<span id=\"cb2-4\"><a 
href=\"#cb2-4\" aria-hidden=\"true\"><\/a>county_idx <span class=\"op\">=<\/span> data[<span class=\"st\">&#39;county_code&#39;<\/span>].values.astype(np.int32)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 2\n<\/figcaption>\n<\/figure>\n<figure id=\"orgf0b697c\">\n<div class=\"sourceCode\" id=\"cb3\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb3-1\"><a href=\"#cb3-1\" aria-hidden=\"true\"><\/a><span class=\"at\">@pm.model<\/span><\/span>\n<span id=\"cb3-2\"><a href=\"#cb3-2\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> hierarchical_model(data, county_idx):<\/span>\n<span id=\"cb3-3\"><a href=\"#cb3-3\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Hyperpriors<\/span><\/span>\n<span id=\"cb3-4\"><a href=\"#cb3-4\" aria-hidden=\"true\"><\/a>    mu_a <span class=\"op\">=<\/span> <span class=\"cf\">yield<\/span> pm.Normal(<span class=\"st\">&#39;mu_alpha&#39;<\/span>, mu<span class=\"op\">=<\/span><span class=\"fl\">0.<\/span>, sigma<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>)<\/span>\n<span id=\"cb3-5\"><a href=\"#cb3-5\" aria-hidden=\"true\"><\/a>    sigma_a <span class=\"op\">=<\/span> <span class=\"cf\">yield<\/span> pm.HalfCauchy(<span class=\"st\">&#39;sigma_alpha&#39;<\/span>, beta<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>)<\/span>\n<span id=\"cb3-6\"><a href=\"#cb3-6\" aria-hidden=\"true\"><\/a>    mu_b <span class=\"op\">=<\/span> <span class=\"cf\">yield<\/span> pm.Normal(<span class=\"st\">&#39;mu_beta&#39;<\/span>, mu<span class=\"op\">=<\/span><span class=\"fl\">0.<\/span>, sigma<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>)<\/span>\n<span id=\"cb3-7\"><a href=\"#cb3-7\" aria-hidden=\"true\"><\/a>    sigma_b <span class=\"op\">=<\/span> <span class=\"cf\">yield<\/span> pm.HalfCauchy(<span class=\"st\">&#39;sigma_beta&#39;<\/span>, beta<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>)<\/span>\n<span id=\"cb3-8\"><a href=\"#cb3-8\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-9\"><a href=\"#cb3-9\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Intercept for each county, distributed around group mean mu_a<\/span><\/span>\n<span id=\"cb3-10\"><a href=\"#cb3-10\" aria-hidden=\"true\"><\/a>    a <span class=\"op\">=<\/span> <span class=\"cf\">yield<\/span> pm.Normal(<span class=\"st\">&#39;alpha&#39;<\/span>, mu<span class=\"op\">=<\/span>mu_a, sigma<span class=\"op\">=<\/span>sigma_a, plate<span class=\"op\">=<\/span><span class=\"bu\">len<\/span>(data.county.unique()))<\/span>\n<span id=\"cb3-11\"><a href=\"#cb3-11\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Slope (floor effect) for each county, distributed around group mean mu_b<\/span><\/span>\n<span id=\"cb3-12\"><a href=\"#cb3-12\" aria-hidden=\"true\"><\/a>    b <span class=\"op\">=<\/span> <span class=\"cf\">yield<\/span> pm.Normal(<span class=\"st\">&#39;beta&#39;<\/span>, mu<span class=\"op\">=<\/span>mu_b, sigma<span class=\"op\">=<\/span>sigma_b, plate<span class=\"op\">=<\/span><span class=\"bu\">len<\/span>(data.county.unique()))<\/span>\n<span id=\"cb3-13\"><a href=\"#cb3-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-14\"><a href=\"#cb3-14\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Model error<\/span><\/span>\n<span id=\"cb3-15\"><a href=\"#cb3-15\" aria-hidden=\"true\"><\/a>    eps <span class=\"op\">=<\/span> <span class=\"cf\">yield<\/span> pm.HalfCauchy(<span class=\"st\">&#39;eps&#39;<\/span>, beta<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>)<\/span>\n<span id=\"cb3-16\"><a href=\"#cb3-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-17\"><a href=\"#cb3-17\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Expected value<\/span><\/span>\n<span id=\"cb3-18\"><a href=\"#cb3-18\" aria-hidden=\"true\"><\/a>    <span class=\"co\">#radon_est = a[county_idx] + b[county_idx] * data.floor.values<\/span><\/span>\n<span id=\"cb3-19\"><a href=\"#cb3-19\" aria-hidden=\"true\"><\/a>    
radon_est <span class=\"op\">=<\/span> tf.gather(a, county_idx) <span class=\"op\">+<\/span> tf.gather(<\/span>\n<span id=\"cb3-20\"><a href=\"#cb3-20\" aria-hidden=\"true\"><\/a>        b, county_idx) <span class=\"op\">*<\/span> data.floor.values<\/span>\n<span id=\"cb3-21\"><a href=\"#cb3-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-22\"><a href=\"#cb3-22\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Data likelihood<\/span><\/span>\n<span id=\"cb3-23\"><a href=\"#cb3-23\" aria-hidden=\"true\"><\/a>    y_like <span class=\"op\">=<\/span> <span class=\"cf\">yield<\/span> pm.Normal(<span class=\"st\">&#39;y_like&#39;<\/span>, mu<span class=\"op\">=<\/span>radon_est, sigma<span class=\"op\">=<\/span>eps, observed<span class=\"op\">=<\/span>data.log_radon)<\/span>\n<span id=\"cb3-24\"><a href=\"#cb3-24\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-25\"><a href=\"#cb3-25\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-26\"><a href=\"#cb3-26\" aria-hidden=\"true\"><\/a>init_num_chains <span class=\"op\">=<\/span> <span class=\"dv\">50<\/span><\/span>\n<span id=\"cb3-27\"><a href=\"#cb3-27\" aria-hidden=\"true\"><\/a>model <span class=\"op\">=<\/span> hierarchical_model(data, county_idx)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 3\n<\/figcaption>\n<\/figure>\n<p>In Listing <a href=\"#org76f5220\">5<\/a>, we estimate the model using the sample routine from <a href=\"https:\/\/github.com\/pymc-devs\/pymc4\/blob\/master\/notebooks\/radon_hierarchical.ipynb\">PyMC4\u2019s Radon example Notebook<\/a> (reproduced in Listing <a href=\"#org9df08c0\">4<\/a>). 
The same plots from the aforementioned notebook are also reproduced here in Figures <a href=\"#fig:pymc4-radon-plot-energy\">7<\/a> and <a href=\"#fig:pymc4-radon-plot-trace\">8<\/a>.<\/p>\n<figure id=\"org9df08c0\">\n<div class=\"sourceCode\" id=\"cb4\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb4-1\"><a href=\"#cb4-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> sample(model, init_num_chains<span class=\"op\">=<\/span><span class=\"dv\">50<\/span>, num_samples<span class=\"op\">=<\/span><span class=\"dv\">500<\/span>, burn_in<span class=\"op\">=<\/span><span class=\"dv\">500<\/span>):<\/span>\n<span id=\"cb4-3\"><a href=\"#cb4-3\" aria-hidden=\"true\"><\/a>    pm4_trace, _ <span class=\"op\">=<\/span> pm.inference.sampling.sample(<\/span>\n<span id=\"cb4-4\"><a href=\"#cb4-4\" aria-hidden=\"true\"><\/a>        model, num_chains<span class=\"op\">=<\/span>init_num_chains, num_samples<span class=\"op\">=<\/span><span class=\"dv\">10<\/span>, burn_in<span class=\"op\">=<\/span><span class=\"dv\">10<\/span>, step_size<span class=\"op\">=<\/span><span class=\"fl\">1.<\/span>, xla<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb4-5\"><a href=\"#cb4-5\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">for<\/span> i <span class=\"kw\">in<\/span> <span class=\"bu\">range<\/span>(<span class=\"dv\">3<\/span>):<\/span>\n<span id=\"cb4-6\"><a href=\"#cb4-6\" aria-hidden=\"true\"><\/a>        step_size_ <span class=\"op\">=<\/span> []<\/span>\n<span id=\"cb4-7\"><a href=\"#cb4-7\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">for<\/span> _, x <span class=\"kw\">in<\/span> pm4_trace.items():<\/span>\n<span id=\"cb4-8\"><a href=\"#cb4-8\" aria-hidden=\"true\"><\/a>            std <span class=\"op\">=<\/span> tf.math.reduce_std(x, axis<span class=\"op\">=<\/span>[<span 
class=\"dv\">0<\/span>, <span class=\"dv\">1<\/span>])<\/span>\n<span id=\"cb4-9\"><a href=\"#cb4-9\" aria-hidden=\"true\"><\/a>            step_size_.append(<\/span>\n<span id=\"cb4-10\"><a href=\"#cb4-10\" aria-hidden=\"true\"><\/a>                std[tf.newaxis, ...] <span class=\"op\">*<\/span> tf.ones([init_num_chains] <span class=\"op\">+<\/span> std.shape, dtype<span class=\"op\">=<\/span>std.dtype))<\/span>\n<span id=\"cb4-11\"><a href=\"#cb4-11\" aria-hidden=\"true\"><\/a>        pm4_trace, _ <span class=\"op\">=<\/span> pm.inference.sampling.sample(<\/span>\n<span id=\"cb4-12\"><a href=\"#cb4-12\" aria-hidden=\"true\"><\/a>            model, num_chains<span class=\"op\">=<\/span>init_num_chains, num_samples<span class=\"op\">=<\/span><span class=\"dv\">10<\/span> <span class=\"op\">+<\/span> <span class=\"dv\">10<\/span><span class=\"op\">*<\/span>i, burn_in<span class=\"op\">=<\/span><span class=\"dv\">10<\/span> <span class=\"op\">+<\/span> <span class=\"dv\">10<\/span><span class=\"op\">*<\/span>i,<\/span>\n<span id=\"cb4-13\"><a href=\"#cb4-13\" aria-hidden=\"true\"><\/a>            step_size<span class=\"op\">=<\/span>step_size_, xla<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb4-14\"><a href=\"#cb4-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb4-15\"><a href=\"#cb4-15\" aria-hidden=\"true\"><\/a>    num_chains <span class=\"op\">=<\/span> <span class=\"dv\">5<\/span><\/span>\n<span id=\"cb4-16\"><a href=\"#cb4-16\" aria-hidden=\"true\"><\/a>    step_size_ <span class=\"op\">=<\/span> []<\/span>\n<span id=\"cb4-17\"><a href=\"#cb4-17\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">for<\/span> _, x <span class=\"kw\">in<\/span> pm4_trace.items():<\/span>\n<span id=\"cb4-18\"><a href=\"#cb4-18\" aria-hidden=\"true\"><\/a>        std <span class=\"op\">=<\/span> tf.math.reduce_std(x, axis<span class=\"op\">=<\/span>[<span class=\"dv\">0<\/span>, <span class=\"dv\">1<\/span>])<\/span>\n<span 
id=\"cb4-19\"><a href=\"#cb4-19\" aria-hidden=\"true\"><\/a>        step_size_.append(<\/span>\n<span id=\"cb4-20\"><a href=\"#cb4-20\" aria-hidden=\"true\"><\/a>            std[tf.newaxis, ...] <span class=\"op\">*<\/span> tf.ones([num_chains]<span class=\"op\">+<\/span>std.shape, dtype<span class=\"op\">=<\/span>std.dtype))<\/span>\n<span id=\"cb4-21\"><a href=\"#cb4-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb4-22\"><a href=\"#cb4-22\" aria-hidden=\"true\"><\/a>    pm4_trace, sample_stat <span class=\"op\">=<\/span> pm.inference.sampling.sample(<\/span>\n<span id=\"cb4-23\"><a href=\"#cb4-23\" aria-hidden=\"true\"><\/a>        model, num_chains<span class=\"op\">=<\/span>num_chains, num_samples<span class=\"op\">=<\/span>num_samples, burn_in<span class=\"op\">=<\/span>burn_in,<\/span>\n<span id=\"cb4-24\"><a href=\"#cb4-24\" aria-hidden=\"true\"><\/a>        step_size<span class=\"op\">=<\/span>step_size_, xla<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb4-25\"><a href=\"#cb4-25\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb4-26\"><a href=\"#cb4-26\" aria-hidden=\"true\"><\/a>    az_trace <span class=\"op\">=<\/span> pm.inference.utils.trace_to_arviz(pm4_trace, sample_stat)<\/span>\n<span id=\"cb4-27\"><a href=\"#cb4-27\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb4-28\"><a href=\"#cb4-28\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> az_trace<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 4\n<\/figcaption>\n<\/figure>\n<figure id=\"org76f5220\">\n<div class=\"sourceCode\" id=\"cb5\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb5-1\"><a href=\"#cb5-1\" aria-hidden=\"true\"><\/a>az_trace <span class=\"op\">=<\/span> sample(model)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 5\n<\/figcaption>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" id=\"cb6\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb6-1\"><a 
href=\"#cb6-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> matplotlib.pyplot <span class=\"im\">as<\/span> plt<\/span>\n<span id=\"cb6-2\"><a href=\"#cb6-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-3\"><a href=\"#cb6-3\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> seaborn <span class=\"im\">as<\/span> sns<\/span>\n<span id=\"cb6-4\"><a href=\"#cb6-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-5\"><a href=\"#cb6-5\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> matplotlib <span class=\"im\">import<\/span> rcParams<\/span>\n<span id=\"cb6-6\"><a href=\"#cb6-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-7\"><a href=\"#cb6-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-8\"><a href=\"#cb6-8\" aria-hidden=\"true\"><\/a>rcParams[<span class=\"st\">&#39;figure.figsize&#39;<\/span>] <span class=\"op\">=<\/span> (<span class=\"fl\">11.7<\/span>, <span class=\"fl\">8.27<\/span>)<\/span>\n<span id=\"cb6-9\"><a href=\"#cb6-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-10\"><a href=\"#cb6-10\" aria-hidden=\"true\"><\/a><span class=\"co\"># plt.rc(&#39;text&#39;, usetex=True)<\/span><\/span>\n<span id=\"cb6-11\"><a href=\"#cb6-11\" aria-hidden=\"true\"><\/a>sns.set_style(<span class=\"st\">&quot;whitegrid&quot;<\/span>)<\/span>\n<span id=\"cb6-12\"><a href=\"#cb6-12\" aria-hidden=\"true\"><\/a>sns.set_context(<span class=\"st\">&quot;paper&quot;<\/span>)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" id=\"cb7\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb7-1\"><a href=\"#cb7-1\" aria-hidden=\"true\"><\/a>_ <span class=\"op\">=<\/span> az.plot_energy(az_trace)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure id=\"fig:pymc4-radon-plot-energy\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/pymc4-radon-plot-energy.png\" title=\"fig:\" alt=\"\" \/>\n<figcaption>\n<\/figcaption>\n<\/figure>\n<figure 
id=\"fig:pymc4-radon-plot-trace\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/pymc4-radon-plot-trace.png\" title=\"fig:\" alt=\"\" \/>\n<figcaption>\n<\/figcaption>\n<\/figure>\n<\/section>\n<section id=\"the-models-log-likelihood-graph\" class=\"level1\">\n<h1>The Model\u2019s Log-likelihood Graph<\/h1>\n<p>In order to apply our optimization, we need to obtain a graph of the log-likelihood function generated by the model in Listing <a href=\"#orgf0b697c\">3<\/a>. With the graph in-hand, we can perform the re-centering and re-scaling transform\u2013in log-space\u2013and produce a new log-likelihood graph that improves sampling.<\/p>\n<p>This exercise introduces the TensorFlow function-graph backed by the class <code>tensorflow.python.framework.func_graph.FuncGraph<\/code>. <code>FuncGraph<\/code> is a subclass of the regular <code>Graph<\/code> objects upon which <code>symbolic-pymc<\/code> indirectly operates. Just like Theano\u2019s <code>FunctionGraph<\/code>s, <code>FuncGraph<\/code> simply specializes a generic graph by specifying which constituent tensors are considered inputs and outputs.<\/p>\n<p>In Listing <a href=\"#orgc606190\">8<\/a>, we use PyMC4\u2019s internal mechanisms to build the log-likelihood function for our model and a corresponding list of initial values for the parameters.<\/p>\n<figure id=\"orgc606190\">\n<div class=\"sourceCode\" id=\"cb8\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb8-1\"><a href=\"#cb8-1\" aria-hidden=\"true\"><\/a>state <span class=\"op\">=<\/span> <span class=\"va\">None<\/span><\/span>\n<span id=\"cb8-2\"><a href=\"#cb8-2\" aria-hidden=\"true\"><\/a>observed <span class=\"op\">=<\/span> <span class=\"va\">None<\/span><\/span>\n<span id=\"cb8-3\"><a href=\"#cb8-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb8-4\"><a href=\"#cb8-4\" aria-hidden=\"true\"><\/a>logpfn, init <span class=\"op\">=<\/span> 
pm.inference.sampling.build_logp_function(model,<\/span>\n<span id=\"cb8-5\"><a href=\"#cb8-5\" aria-hidden=\"true\"><\/a>                                                         state<span class=\"op\">=<\/span>state,<\/span>\n<span id=\"cb8-6\"><a href=\"#cb8-6\" aria-hidden=\"true\"><\/a>                                                         observed<span class=\"op\">=<\/span>observed)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 8\n<\/figcaption>\n<\/figure>\n<p>From here we need <code>FuncGraph<\/code>s for each input to <code>logpfn<\/code>. Since <code>logpfn<\/code> is a <code>tensorflow.python.eager.def_function.Function<\/code> instance, every time it\u2019s called with a specific tensor it may create a new function-object with its own <code>FuncGraph<\/code>. In other words, it dynamically generates function objects based on the inputs it\u2019s given.<\/p>\n<p>This specialization process can be performed manually using <code>logpfn.get_concrete_function(*args)<\/code>, which necessarily produces a <code>tensorflow.python.eager.function.ConcreteFunction<\/code> with the desired <code>FuncGraph<\/code>. 
Listing <a href=\"#org966cccf\">9<\/a> creates and extracts these two objects.<\/p>\n<figure id=\"org966cccf\">\n<div class=\"sourceCode\" id=\"cb9\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb9-1\"><a href=\"#cb9-1\" aria-hidden=\"true\"><\/a>logpfn_cf <span class=\"op\">=<\/span> logpfn.get_concrete_function(<span class=\"op\">*<\/span>init.values())<\/span>\n<span id=\"cb9-2\"><a href=\"#cb9-2\" aria-hidden=\"true\"><\/a>logpfn_fg <span class=\"op\">=<\/span> logpfn_cf.graph<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 9\n<\/figcaption>\n<\/figure>\n<p>The outputs are now available in graph form as <code>logpfn_fg.outputs<\/code>.<\/p>\n<\/section>\n<section id=\"the-log-space-transform\" class=\"level1\">\n<h1>The Log-space Transform<\/h1>\n<p>Consider the following two equivalent hierarchical models,<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{gathered}\n    Y = X + \\epsilon, \\quad\n    \\epsilon \\sim \\operatorname{N}\\left(0, \\sigma^2\\right)\n    \\\\\n    X \\sim \\operatorname{N}\\left(\\mu, \\tau^2\\right)\n  \\end{gathered}\n\\label{eq:model-1}\n\\end{equation}\\]<\/span><\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\begin{gathered}\n    Y = \\mu + \\tau \\cdot \\tilde{X} + \\epsilon, \\quad\n    \\epsilon \\sim \\operatorname{N}\\left(0, \\sigma^2\\right)\n    \\\\\n    \\tilde{X} \\sim \\operatorname{N}\\left(0, 1\\right)\n  \\;.\n  \\end{gathered}\n\\label{eq:model-2}\n\\end{equation}\\]<\/span><\/p>\n<p>Models <span class=\"math inline\">\\(\\eqref{eq:model-1}\\)<\/span> and <span class=\"math inline\">\\(\\eqref{eq:model-2}\\)<\/span> are represented in (log) measure space, respectively, as follows:<\/p>\n<p><span class=\"math display\">\\[\\begin{align}\n    \\log p(Y, X) &amp;= \\log P(Y\\mid X) + \\log P(X)\n    \\nonumber\n    \\\\\n    &amp;= C - \\frac{1}{2} \\left(\\frac{y}{\\sigma} - \\frac{x}{\\sigma}\\right)^2 -\n       \\frac{1}{2} 
\\left(\\frac{x}{\\tau} - \\frac{\\mu}{\\tau}\\right)^2\n    \\label{eq:log-model-1}\n    \\\\\n    &amp;= \\tilde{C} - \\frac{1}{2} \\left(\\frac{y}{\\sigma} - \\frac{\\mu + \\tau \\cdot \\tilde{x}}{\\sigma}\\right)^2 - \\frac{1}{2} \\tilde{x}^2\n  \\label{eq:log-model-2}\n  \\;.\n\\end{align}\\]<\/span><\/p>\n<p>Via term rewriting, Equation <span class=\"math inline\">\\(\\eqref{eq:log-model-2}\\)<\/span> is produced\u2013in part\u2013by applying the replacement rule <span class=\"math inline\">\\(x \\to \\mu + \\tau \\cdot \\tilde{x}\\)<\/span> to Equation <span class=\"math inline\">\\(\\eqref{eq:log-model-1}\\)<\/span>, i.e.<\/p>\n<p><span class=\"math display\">\\[\\begin{align*}\n\\tilde{C} - \\frac{1}{2} \\left(\\frac{y}{\\sigma} - \\frac{\\mu + \\tau \\cdot \\tilde{x}}{\\sigma}\\right)^2 -\n  \\frac{1}{2} \\left(\\frac{\\mu + \\tau \\cdot \\tilde{x}}{\\tau} - \\frac{\\mu}{\\tau}\\right)^2\n\\;.\n\\end{align*}\\]<\/span><\/p>\n<p>For consistency, the transform must also be applied to the <span class=\"math inline\">\\(dx\\)<\/span> term wherever and whenever it is considered.<\/p>\n<p>After a few algebraic simplifications, such as <span class=\"math inline\">\\(\\frac{\\mu + \\tau \\cdot \\tilde{x}}{\\tau} - \\frac{\\mu}{\\tau} = \\tilde{x}\\)<\/span>, one obtains the exact form of Equation <span class=\"math inline\">\\(\\eqref{eq:log-model-2}\\)<\/span>.<\/p>\n<\/section>\n<section id=\"creating-the-minikanren-goals\" class=\"level1\">\n<h1>Creating the miniKanren Goals<\/h1>\n<p><code>symbolic-pymc<\/code> is designed to use miniKanren as a means of specifying mathematical relations. The degree to which an implementation of a mathematical relation upholds its known characteristics is\u2013of course\u2013always up to the developer. 
For the needs of PPLs like PyMC4, we can\u2019t reasonably expect\u2013or provide\u2013capabilities at the level of automatic theorem proving or every relevant state-of-the-art symbolic math routine.<\/p>\n<p>Even so, we <strong>do<\/strong> expect that some capabilities from within those more advanced areas of symbolic computing will eventually be required\u2013or necessary\u2013and we want to build on a foundation that allows them to be integrated and\/or simply expressed. We believe that miniKanren is a great foundation for such work due to the core concepts it shares with symbolic computation, as well as its immense flexibility. It also maintains an elegant simplicity and is amenable to developer intervention at nearly all levels\u2013often without the need for low- or DSL-level rewrites.<\/p>\n<p>User-level development in miniKanren occurs within its DSL, which is a succinct relational\/logic programming paradigm that\u2013in our case\u2013is entirely written in Python. This DSL provides primitive <strong>goals<\/strong> that can be composed and eventually evaluated by the <code>run<\/code> function. We refer the reader to any one of the many great introductions to miniKanren available at <a href=\"http:\/\/minikanren.org\" class=\"uri\">http:\/\/minikanren.org<\/a>, or, for the specific Python package used here: <a href=\"https:\/\/github.com\/logpy\/logpy\/blob\/master\/doc\/basic.md\">this simple introduction<\/a>.<\/p>\n<p>For the matter at hand, we need to create goals that implement the substitution described above. The first step is to understand the exact TF graphs involved, and the best way to do that is to construct the relevant graph objects, observe them directly, and build \u201cpatterns\u201d that match their general forms. Patterns are built with <code>symbolic-pymc<\/code> meta objects obtained from the <code>mt<\/code> helper \u201cnamespace\u201d. 
Wherever we want to leave room for variation\/ambiguity, we use a \u201clogic variable\u201d instead of an explicit TF (meta) object. Logic variables are created with <code>var()<\/code> and can optionally be given a string \u201cname\u201d argument that identifies them globally as a singleton-like object.<\/p>\n<section id=\"inspecting-the-tf-graphs\" class=\"level2\">\n<h2>Inspecting the TF Graphs<\/h2>\n<p>In our case, the log-density returned by PyMC4\u2013via the TensorFlow Probability library (TFP)\u2013uses <code>tf.math.squared_difference<\/code> to construct the \u201csquared error\u201d term in the exponent of a normal density. This term contains everything we need to construct the substitution as a pair of TF graph objects.<\/p>\n<p>Listing <a href=\"#orgc47d26d\">10<\/a> shows the graph produced by a normal distribution in TFP.<\/p>\n<figure id=\"orgc47d26d\">\n<div class=\"sourceCode\" id=\"cb10\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb10-1\"><a href=\"#cb10-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> tensorflow_probability <span class=\"im\">as<\/span> tfp<\/span>\n<span id=\"cb10-2\"><a href=\"#cb10-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-3\"><a href=\"#cb10-3\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> tensorflow.python.eager.context <span class=\"im\">import<\/span> graph_mode<\/span>\n<span id=\"cb10-4\"><a href=\"#cb10-4\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> tensorflow.python.framework.ops <span class=\"im\">import<\/span> disable_tensor_equality<\/span>\n<span id=\"cb10-5\"><a href=\"#cb10-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-6\"><a href=\"#cb10-6\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> symbolic_pymc.tensorflow.printing <span class=\"im\">import<\/span> tf_dprint<\/span>\n<span id=\"cb10-7\"><a href=\"#cb10-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-8\"><a href=\"#cb10-8\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-9\"><a href=\"#cb10-9\" aria-hidden=\"true\"><\/a>disable_tensor_equality()<\/span>\n<span id=\"cb10-10\"><a href=\"#cb10-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-11\"><a href=\"#cb10-11\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> graph_mode(), tf.Graph().as_default() <span class=\"im\">as<\/span> test_graph:<\/span>\n<span id=\"cb10-12\"><a href=\"#cb10-12\" aria-hidden=\"true\"><\/a>    mu_tf <span class=\"op\">=<\/span> tf.compat.v1.placeholder(tf.float32, name<span class=\"op\">=<\/span><span class=\"st\">&#39;mu&#39;<\/span>,<\/span>\n<span id=\"cb10-13\"><a href=\"#cb10-13\" aria-hidden=\"true\"><\/a>                                     shape<span class=\"op\">=<\/span>tf.TensorShape([<span class=\"va\">None<\/span>]))<\/span>\n<span id=\"cb10-14\"><a href=\"#cb10-14\" aria-hidden=\"true\"><\/a>    tau_tf <span class=\"op\">=<\/span> tf.compat.v1.placeholder(tf.float32, name<span class=\"op\">=<\/span><span class=\"st\">&#39;tau&#39;<\/span>,<\/span>\n<span id=\"cb10-15\"><a href=\"#cb10-15\" aria-hidden=\"true\"><\/a>                                      shape<span class=\"op\">=<\/span>tf.TensorShape([<span class=\"va\">None<\/span>]))<\/span>\n<span id=\"cb10-16\"><a href=\"#cb10-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-17\"><a href=\"#cb10-17\" aria-hidden=\"true\"><\/a>    normal_tfp <span class=\"op\">=<\/span> tfp.distributions.normal.Normal(mu_tf, tau_tf)<\/span>\n<span id=\"cb10-18\"><a href=\"#cb10-18\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-19\"><a href=\"#cb10-19\" aria-hidden=\"true\"><\/a>    value_tf <span class=\"op\">=<\/span> tf.compat.v1.placeholder(tf.float32, name<span class=\"op\">=<\/span><span class=\"st\">&#39;value&#39;<\/span>,<\/span>\n<span id=\"cb10-20\"><a href=\"#cb10-20\" aria-hidden=\"true\"><\/a>                                        shape<span class=\"op\">=<\/span>tf.TensorShape([<span 
class=\"va\">None<\/span>]))<\/span>\n<span id=\"cb10-21\"><a href=\"#cb10-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-22\"><a href=\"#cb10-22\" aria-hidden=\"true\"><\/a>    normal_log_lik <span class=\"op\">=<\/span> normal_tfp.log_prob(value_tf)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 10\n<\/figcaption>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" id=\"cb11\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb11-1\"><a href=\"#cb11-1\" aria-hidden=\"true\"><\/a>tf_dprint(normal_log_lik)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure>\n<pre class=\"text\"><code>Tensor(Sub):0,  shape=[None]    &quot;Normal_1\/log_prob\/sub:0&quot;\n|  Tensor(Mul):0,   shape=[None]    &quot;Normal_1\/log_prob\/mul:0&quot;\n|  |  Tensor(Const):0,  shape=[]    &quot;Normal_1\/log_prob\/mul\/x:0&quot;\n|  |  |  -0.5\n|  |  Tensor(SquaredDifference):0,  shape=[None]    &quot;Normal_1\/log_prob\/SquaredDifference:0&quot;\n|  |  |  Tensor(RealDiv):0, shape=[None]    &quot;Normal_1\/log_prob\/truediv:0&quot;\n|  |  |  |  Tensor(Placeholder):0,  shape=[None]    &quot;value:0&quot;\n|  |  |  |  Tensor(Placeholder):0,  shape=[None]    &quot;tau:0&quot;\n|  |  |  Tensor(RealDiv):0, shape=[None]    &quot;Normal_1\/log_prob\/truediv_1:0&quot;\n|  |  |  |  Tensor(Placeholder):0,  shape=[None]    &quot;mu:0&quot;\n|  |  |  |  Tensor(Placeholder):0,  shape=[None]    &quot;tau:0&quot;\n|  Tensor(AddV2):0, shape=[None]    &quot;Normal_1\/log_prob\/add:0&quot;\n|  |  Tensor(Const):0,  shape=[]    &quot;Normal_1\/log_prob\/add\/x:0&quot;\n|  |  |  0.9189385\n|  |  Tensor(Log):0,    shape=[None]    &quot;Normal_1\/log_prob\/Log:0&quot;\n|  |  |  Tensor(Placeholder):0, shape=[None]    &quot;tau:0&quot;\n\n<\/code><\/pre>\n<\/figure>\n<p>Instead of looking for the entire log-likelihood graph for a distribution, we can focus on only the <code>SquaredDifference<\/code> operators, since they contain all the relevant terms for our 
transformation.<\/p>\n<p>More specifically, if we can identify \u201cchains\u201d of such terms, i.e.\u00a0pairs like <code>SquaredDifference(y, x)<\/code> and <code>SquaredDifference(x, mu)<\/code> that share a term, then we might be able to assume that the corresponding subgraph was formed from such a hierarchical normal model.<\/p>\n<p>Listing <a href=\"#orgacfcea8\">13<\/a> prints the <code>SquaredDifference<\/code> sub-graphs (along with the <code>GatherV2<\/code> sub-graphs) in the log-likelihood graph for our radon model. The output contains two instances of the aforementioned <code>SquaredDifference<\/code> \u201cchains\u201d: one passes through the tensor named <code>values_1<\/code> and the other through <code>values_4<\/code>.<\/p>\n<figure id=\"orgacfcea8\">\n<div class=\"sourceCode\" id=\"cb13\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb13-1\"><a href=\"#cb13-1\" aria-hidden=\"true\"><\/a>square_diff_outs <span class=\"op\">=<\/span> [o.outputs[<span class=\"dv\">0<\/span>] <span class=\"cf\">for<\/span> o <span class=\"kw\">in<\/span> logpfn_fg.get_operations()<\/span>\n<span id=\"cb13-2\"><a href=\"#cb13-2\" aria-hidden=\"true\"><\/a>                    <span class=\"cf\">if<\/span> o.<span class=\"bu\">type<\/span> <span class=\"op\">==<\/span> <span class=\"st\">&#39;SquaredDifference&#39;<\/span> <span class=\"kw\">or<\/span> o.<span class=\"bu\">type<\/span>.startswith(<span class=\"st\">&#39;Gather&#39;<\/span>)]<\/span>\n<span id=\"cb13-3\"><a href=\"#cb13-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-4\"><a href=\"#cb13-4\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> t <span class=\"kw\">in<\/span> square_diff_outs:<\/span>\n<span id=\"cb13-5\"><a href=\"#cb13-5\" aria-hidden=\"true\"><\/a>    tf_dprint(t)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 13\n<\/figcaption>\n<\/figure>\n<figure>\n<pre class=\"text\"><code>Tensor(GatherV2):0, shape=[919] &quot;GatherV2:0&quot;\n|  Tensor(Placeholder):0,   shape=[85]  &quot;values_3:0&quot;\n|  Tensor(Const):0, shape=[919] 
&quot;GatherV2\/indices:0&quot;\n|  |  [ 0  0  0 ... 83 84 84]\n|  Tensor(Const):0, shape=[]    &quot;GatherV2\/axis:0&quot;\n|  |  0\nTensor(GatherV2):0, shape=[919] &quot;GatherV2_1:0&quot;\n|  Tensor(Placeholder):0,   shape=[85]  &quot;values_2:0&quot;\n|  Tensor(Const):0, shape=[919] &quot;GatherV2_1\/indices:0&quot;\n|  |  [ 0  0  0 ... 83 84 84]\n|  Tensor(Const):0, shape=[]    &quot;GatherV2_1\/axis:0&quot;\n|  |  0\nTensor(SquaredDifference):0,    shape=[]    &quot;Normal_5\/log_prob\/SquaredDifference:0&quot;\n|  Tensor(RealDiv):0,   shape=[]    &quot;Normal_5\/log_prob\/truediv:0&quot;\n|  |  Tensor(Placeholder):0,    shape=[]    &quot;values_1:0&quot;\n|  |  Tensor(Const):0,  shape=[]    &quot;Normal\/scale:0&quot;\n|  |  |  1.\n|  Tensor(RealDiv):0,   shape=[]    &quot;Normal_5\/log_prob\/truediv_1:0&quot;\n|  |  Tensor(Const):0,  shape=[]    &quot;Normal\/loc:0&quot;\n|  |  |  0.\n|  |  Tensor(Const):0,  shape=[]    &quot;Normal\/scale:0&quot;\n|  |  |  1.\nTensor(SquaredDifference):0,    shape=[]    &quot;Normal_1_1\/log_prob\/SquaredDifference:0&quot;\n|  Tensor(RealDiv):0,   shape=[]    &quot;Normal_1_1\/log_prob\/truediv:0&quot;\n|  |  Tensor(Placeholder):0,    shape=[]    &quot;values_4:0&quot;\n|  |  Tensor(Const):0,  shape=[]    &quot;Normal_1\/scale:0&quot;\n|  |  |  1.\n|  Tensor(RealDiv):0,   shape=[]    &quot;Normal_1_1\/log_prob\/truediv_1:0&quot;\n|  |  Tensor(Const):0,  shape=[]    &quot;Normal_1\/loc:0&quot;\n|  |  |  0.\n|  |  Tensor(Const):0,  shape=[]    &quot;Normal_1\/scale:0&quot;\n|  |  |  1.\nTensor(SquaredDifference):0,    shape=[85]  &quot;SampleNormal_2_1\/log_prob\/Normal_2\/log_prob\/SquaredDifference:0&quot;\n|  Tensor(RealDiv):0,   shape=[85]  &quot;SampleNormal_2_1\/log_prob\/Normal_2\/log_prob\/truediv:0&quot;\n|  |  Tensor(Transpose):0,  shape=[85]  &quot;SampleNormal_2_1\/log_prob\/transpose:0&quot;\n|  |  |  Tensor(Reshape):0, shape=[85]  &quot;SampleNormal_2_1\/log_prob\/Reshape:0&quot;\n|  |  |  |  
Tensor(Placeholder):0,  shape=[85]  &quot;values_3:0&quot;\n|  |  |  |  Tensor(Const):0,    shape=[1]   &quot;SampleNormal_2_1\/log_prob\/Reshape\/shape:0&quot;\n|  |  |  |  |  [85]\n|  |  |  Tensor(Const):0,   shape=[1]   &quot;SampleNormal_2_1\/log_prob\/transpose\/perm:0&quot;\n|  |  |  |  [0]\n|  |  Tensor(Exp):0,    shape=[]    &quot;exp_1\/forward\/Exp:0&quot;\n|  |  |  Tensor(Placeholder):0, shape=[]    &quot;values_0:0&quot;\n|  Tensor(RealDiv):0,   shape=[]    &quot;SampleNormal_2_1\/log_prob\/Normal_2\/log_prob\/truediv_1:0&quot;\n|  |  Tensor(Placeholder):0,    shape=[]    &quot;values_1:0&quot;\n|  |  Tensor(Exp):0,    shape=[]    &quot;exp_1\/forward\/Exp:0&quot;\n|  |  |  ...\nTensor(SquaredDifference):0,    shape=[85]  &quot;SampleNormal_3_1\/log_prob\/Normal_3\/log_prob\/SquaredDifference:0&quot;\n|  Tensor(RealDiv):0,   shape=[85]  &quot;SampleNormal_3_1\/log_prob\/Normal_3\/log_prob\/truediv:0&quot;\n|  |  Tensor(Transpose):0,  shape=[85]  &quot;SampleNormal_3_1\/log_prob\/transpose:0&quot;\n|  |  |  Tensor(Reshape):0, shape=[85]  &quot;SampleNormal_3_1\/log_prob\/Reshape:0&quot;\n|  |  |  |  Tensor(Placeholder):0,  shape=[85]  &quot;values_2:0&quot;\n|  |  |  |  Tensor(Const):0,    shape=[1]   &quot;SampleNormal_3_1\/log_prob\/Reshape\/shape:0&quot;\n|  |  |  |  |  [85]\n|  |  |  Tensor(Const):0,   shape=[1]   &quot;SampleNormal_3_1\/log_prob\/transpose\/perm:0&quot;\n|  |  |  |  [0]\n|  |  Tensor(Exp):0,    shape=[]    &quot;exp_2_1\/forward\/Exp:0&quot;\n|  |  |  Tensor(Placeholder):0, shape=[]    &quot;values_5:0&quot;\n|  Tensor(RealDiv):0,   shape=[]    &quot;SampleNormal_3_1\/log_prob\/Normal_3\/log_prob\/truediv_1:0&quot;\n|  |  Tensor(Placeholder):0,    shape=[]    &quot;values_4:0&quot;\n|  |  Tensor(Exp):0,    shape=[]    &quot;exp_2_1\/forward\/Exp:0&quot;\n|  |  |  ...\nTensor(SquaredDifference):0,    shape=[919] &quot;Normal_4_1\/log_prob\/SquaredDifference:0&quot;\n|  Tensor(RealDiv):0,   shape=[919] 
&quot;Normal_4_1\/log_prob\/truediv:0&quot;\n|  |  Tensor(Const):0,  shape=[919] &quot;Normal_4_1\/log_prob\/value:0&quot;\n|  |  |  [0.8329091 0.8329091 1.0986123 ... 1.6292405 1.3350011 1.0986123]\n|  |  Tensor(Exp):0,    shape=[]    &quot;exp_3_1\/forward\/Exp:0&quot;\n|  |  |  Tensor(Placeholder):0, shape=[]    &quot;values_6:0&quot;\n|  Tensor(RealDiv):0,   shape=[919] &quot;Normal_4_1\/log_prob\/truediv_1:0&quot;\n|  |  Tensor(AddV2):0,  shape=[919] &quot;add:0&quot;\n|  |  |  Tensor(GatherV2):0,    shape=[919] &quot;GatherV2:0&quot;\n|  |  |  |  Tensor(Placeholder):0,  shape=[85]  &quot;values_3:0&quot;\n|  |  |  |  Tensor(Const):0,    shape=[919] &quot;GatherV2\/indices:0&quot;\n|  |  |  |  |  [ 0  0  0 ... 83 84 84]\n|  |  |  |  Tensor(Const):0,    shape=[]    &quot;GatherV2\/axis:0&quot;\n|  |  |  |  |  0\n|  |  |  Tensor(Mul):0, shape=[919] &quot;mul:0&quot;\n|  |  |  |  Tensor(GatherV2):0, shape=[919] &quot;GatherV2_1:0&quot;\n|  |  |  |  |  Tensor(Placeholder):0,   shape=[85]  &quot;values_2:0&quot;\n|  |  |  |  |  Tensor(Const):0, shape=[919] &quot;GatherV2_1\/indices:0&quot;\n|  |  |  |  |  |  [ 0  0  0 ... 83 84 84]\n|  |  |  |  |  Tensor(Const):0, shape=[]    &quot;GatherV2_1\/axis:0&quot;\n|  |  |  |  |  |  0\n|  |  |  |  Tensor(Const):0,    shape=[919] &quot;mul\/y:0&quot;\n|  |  |  |  |  [1. 0. 0. ... 0. 0. 
0.]\n|  |  Tensor(Exp):0,    shape=[]    &quot;exp_3_1\/forward\/Exp:0&quot;\n|  |  |  ...\n\n<\/code><\/pre>\n<\/figure>\n<p>The names in the TFP graph are not based on the PyMC4 model objects, so, to make the graph output slightly more interpretable, Listing <a href=\"#orgc9910cb\">15<\/a> re-associates the TFP names with the PyMC4 variable names.<\/p>\n<figure id=\"orgc9910cb\">\n<div class=\"sourceCode\" id=\"cb15\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb15-1\"><a href=\"#cb15-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> pprint <span class=\"im\">import<\/span> pprint<\/span>\n<span id=\"cb15-2\"><a href=\"#cb15-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb15-3\"><a href=\"#cb15-3\" aria-hidden=\"true\"><\/a>tfp_names_to_pymc <span class=\"op\">=<\/span> {i.name: k <span class=\"cf\">for<\/span> i, k <span class=\"kw\">in<\/span> <span class=\"bu\">zip<\/span>(logpfn_cf.structured_input_signature[<span class=\"dv\">0<\/span>], init.keys())}<\/span>\n<span id=\"cb15-4\"><a href=\"#cb15-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb15-5\"><a href=\"#cb15-5\" aria-hidden=\"true\"><\/a>pprint(tfp_names_to_pymc)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 15\n<\/figcaption>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" id=\"cb16\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb16-1\"><a href=\"#cb16-1\" aria-hidden=\"true\"><\/a>{<span class=\"st\">&#39;values_0&#39;<\/span>: <span class=\"st\">&#39;hierarchical_model\/__log_sigma_alpha&#39;<\/span>,<\/span>\n<span id=\"cb16-2\"><a href=\"#cb16-2\" aria-hidden=\"true\"><\/a> <span class=\"st\">&#39;values_1&#39;<\/span>: <span class=\"st\">&#39;hierarchical_model\/mu_alpha&#39;<\/span>,<\/span>\n<span id=\"cb16-3\"><a href=\"#cb16-3\" aria-hidden=\"true\"><\/a> <span class=\"st\">&#39;values_2&#39;<\/span>: <span class=\"st\">&#39;hierarchical_model\/beta&#39;<\/span>,<\/span>\n<span id=\"cb16-4\"><a href=\"#cb16-4\" 
aria-hidden=\"true\"><\/a> <span class=\"st\">&#39;values_3&#39;<\/span>: <span class=\"st\">&#39;hierarchical_model\/alpha&#39;<\/span>,<\/span>\n<span id=\"cb16-5\"><a href=\"#cb16-5\" aria-hidden=\"true\"><\/a> <span class=\"st\">&#39;values_4&#39;<\/span>: <span class=\"st\">&#39;hierarchical_model\/mu_beta&#39;<\/span>,<\/span>\n<span id=\"cb16-6\"><a href=\"#cb16-6\" aria-hidden=\"true\"><\/a> <span class=\"st\">&#39;values_5&#39;<\/span>: <span class=\"st\">&#39;hierarchical_model\/__log_sigma_beta&#39;<\/span>,<\/span>\n<span id=\"cb16-7\"><a href=\"#cb16-7\" aria-hidden=\"true\"><\/a> <span class=\"st\">&#39;values_6&#39;<\/span>: <span class=\"st\">&#39;hierarchical_model\/__log_eps&#39;<\/span>}<\/span>\n<span id=\"cb16-8\"><a href=\"#cb16-8\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<\/figure>\n<\/section>\n<section id=\"graph-normalization\" class=\"level2\">\n<h2>Graph Normalization<\/h2>\n<p>In general, we don\u2019t want our \u201cpatterns\u201d to be \u201cbrittle\u201d, e.g.\u00a0rely on explicit\u2013yet variable\u2013term orderings in commutative operators (e.g.\u00a0a pattern that exclusively targets <code>mt.add(x_lv, y_lv)<\/code> and won\u2019t match the equivalent <code>mt.add(y_lv, x_lv)<\/code>).<\/p>\n<p>The <code>grappler<\/code> library in TensorFlow provides a subset of graph pruning\/optimization steps. 
Ideally, a library like <code>grappler<\/code> would provide full-fledged graph normalization\/canonicalization upon which we could base the subgraphs used in our relations.<\/p>\n<div class=\"remark\" data-markdown=\"\">\n<p>While <code>grappler<\/code> does appear to provide some minimal algebraic normalizations, it isn\u2019t clear how thoroughly they are applied or which operators they cover; still, the normalizations it does provide are worth using, so we\u2019ll apply them throughout.<\/p>\n<\/div>\n<p>Listing <a href=\"#org0c58115\">17<\/a> provides a simple means of applying <code>grappler<\/code>.<\/p>\n<figure id=\"org0c58115\">\n<div class=\"sourceCode\" id=\"cb17\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb17-1\"><a href=\"#cb17-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> tensorflow.core.protobuf <span class=\"im\">import<\/span> config_pb2<\/span>\n<span id=\"cb17-2\"><a href=\"#cb17-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-3\"><a href=\"#cb17-3\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> tensorflow.python.framework <span class=\"im\">import<\/span> ops<\/span>\n<span id=\"cb17-4\"><a href=\"#cb17-4\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> tensorflow.python.framework <span class=\"im\">import<\/span> importer<\/span>\n<span id=\"cb17-5\"><a href=\"#cb17-5\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> tensorflow.python.framework <span class=\"im\">import<\/span> meta_graph<\/span>\n<span id=\"cb17-6\"><a href=\"#cb17-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-7\"><a href=\"#cb17-7\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> tensorflow.python.grappler <span class=\"im\">import<\/span> cluster<\/span>\n<span id=\"cb17-8\"><a href=\"#cb17-8\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> tensorflow.python.grappler <span class=\"im\">import<\/span> 
tf_optimizer<\/span>\n<span id=\"cb17-9\"><a href=\"#cb17-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-10\"><a href=\"#cb17-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-11\"><a href=\"#cb17-11\" aria-hidden=\"true\"><\/a><span class=\"cf\">try<\/span>:<\/span>\n<span id=\"cb17-12\"><a href=\"#cb17-12\" aria-hidden=\"true\"><\/a>    gcluster <span class=\"op\">=<\/span> cluster.Cluster()<\/span>\n<span id=\"cb17-13\"><a href=\"#cb17-13\" aria-hidden=\"true\"><\/a><span class=\"cf\">except<\/span> tf.errors.UnavailableError:<\/span>\n<span id=\"cb17-14\"><a href=\"#cb17-14\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">pass<\/span><\/span>\n<span id=\"cb17-15\"><a href=\"#cb17-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-16\"><a href=\"#cb17-16\" aria-hidden=\"true\"><\/a>config <span class=\"op\">=<\/span> config_pb2.ConfigProto()<\/span>\n<span id=\"cb17-17\"><a href=\"#cb17-17\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-18\"><a href=\"#cb17-18\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-19\"><a href=\"#cb17-19\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> normalize_tf_graph(graph_output, graph_inputs<span class=\"op\">=<\/span>[]):<\/span>\n<span id=\"cb17-20\"><a href=\"#cb17-20\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;Use grappler to normalize a graph.<\/span><\/span>\n<span id=\"cb17-21\"><a href=\"#cb17-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-22\"><a href=\"#cb17-22\" aria-hidden=\"true\"><\/a><span class=\"co\">    Arguments<\/span><\/span>\n<span id=\"cb17-23\"><a href=\"#cb17-23\" aria-hidden=\"true\"><\/a><span class=\"co\">    =========<\/span><\/span>\n<span id=\"cb17-24\"><a href=\"#cb17-24\" aria-hidden=\"true\"><\/a><span class=\"co\">    graph_output: Tensor<\/span><\/span>\n<span id=\"cb17-25\"><a href=\"#cb17-25\" aria-hidden=\"true\"><\/a><span class=\"co\">      A tensor we want to consider as &quot;output&quot; of a 
FuncGraph.<\/span><\/span>\n<span id=\"cb17-26\"><a href=\"#cb17-26\" aria-hidden=\"true\"><\/a><span class=\"co\">    graph_inputs: list of Tensor (optional)<\/span><\/span>\n<span id=\"cb17-27\"><a href=\"#cb17-27\" aria-hidden=\"true\"><\/a><span class=\"co\">      Any tensors that correspond to inputs for the given output node.<\/span><\/span>\n<span id=\"cb17-28\"><a href=\"#cb17-28\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-29\"><a href=\"#cb17-29\" aria-hidden=\"true\"><\/a><span class=\"co\">    Returns<\/span><\/span>\n<span id=\"cb17-30\"><a href=\"#cb17-30\" aria-hidden=\"true\"><\/a><span class=\"co\">    =======<\/span><\/span>\n<span id=\"cb17-31\"><a href=\"#cb17-31\" aria-hidden=\"true\"><\/a><span class=\"co\">    The simplified graph.<\/span><\/span>\n<span id=\"cb17-32\"><a href=\"#cb17-32\" aria-hidden=\"true\"><\/a><span class=\"co\">    &quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb17-33\"><a href=\"#cb17-33\" aria-hidden=\"true\"><\/a>    train_op <span class=\"op\">=<\/span> graph_output.graph.get_collection_ref(ops.GraphKeys.TRAIN_OP)<\/span>\n<span id=\"cb17-34\"><a href=\"#cb17-34\" aria-hidden=\"true\"><\/a>    train_op.clear()<\/span>\n<span id=\"cb17-35\"><a href=\"#cb17-35\" aria-hidden=\"true\"><\/a>    train_op.extend([graph_output] <span class=\"op\">+<\/span> graph_inputs)<\/span>\n<span id=\"cb17-36\"><a href=\"#cb17-36\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-37\"><a href=\"#cb17-37\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># if graph_inputs is not None:<\/span><\/span>\n<span id=\"cb17-38\"><a href=\"#cb17-38\" aria-hidden=\"true\"><\/a>    <span class=\"co\">#     # ops.GraphKeys.MODEL_VARIABLES?<\/span><\/span>\n<span id=\"cb17-39\"><a href=\"#cb17-39\" aria-hidden=\"true\"><\/a>    <span class=\"co\">#     train_vars = graph_output.graph.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES),<\/span><\/span>\n<span id=\"cb17-40\"><a href=\"#cb17-40\" aria-hidden=\"true\"><\/a>    <span 
class=\"co\">#     train_vars.clear()<\/span><\/span>\n<span id=\"cb17-41\"><a href=\"#cb17-41\" aria-hidden=\"true\"><\/a>    <span class=\"co\">#     train_vars.extend(graph_inputs)<\/span><\/span>\n<span id=\"cb17-42\"><a href=\"#cb17-42\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-43\"><a href=\"#cb17-43\" aria-hidden=\"true\"><\/a>    metagraph <span class=\"op\">=<\/span> meta_graph.create_meta_graph_def(graph<span class=\"op\">=<\/span>graph_output.graph)<\/span>\n<span id=\"cb17-44\"><a href=\"#cb17-44\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-45\"><a href=\"#cb17-45\" aria-hidden=\"true\"><\/a>    optimized_graphdef <span class=\"op\">=<\/span> tf_optimizer.OptimizeGraph(<\/span>\n<span id=\"cb17-46\"><a href=\"#cb17-46\" aria-hidden=\"true\"><\/a>        config, metagraph, verbose<span class=\"op\">=<\/span><span class=\"va\">True<\/span>, cluster<span class=\"op\">=<\/span>gcluster)<\/span>\n<span id=\"cb17-47\"><a href=\"#cb17-47\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-48\"><a href=\"#cb17-48\" aria-hidden=\"true\"><\/a>    optimized_graph <span class=\"op\">=<\/span> ops.Graph()<\/span>\n<span id=\"cb17-49\"><a href=\"#cb17-49\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">with<\/span> optimized_graph.as_default():<\/span>\n<span id=\"cb17-50\"><a href=\"#cb17-50\" aria-hidden=\"true\"><\/a>        importer.import_graph_def(optimized_graphdef, name<span class=\"op\">=<\/span><span class=\"st\">&quot;&quot;<\/span>)<\/span>\n<span id=\"cb17-51\"><a href=\"#cb17-51\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-52\"><a href=\"#cb17-52\" aria-hidden=\"true\"><\/a>    opt_graph_output <span class=\"op\">=<\/span> optimized_graph.get_tensor_by_name(graph_output.name)<\/span>\n<span id=\"cb17-53\"><a href=\"#cb17-53\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb17-54\"><a href=\"#cb17-54\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> 
opt_graph_output<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 17\n<\/figcaption>\n<\/figure>\n<p>Below, we use <code>normalize_tf_graph<\/code> from Listing <a href=\"#org0c58115\">17<\/a> to run <code>grappler<\/code> on the log-likelihood graph for a normal random variable from Listing <a href=\"#orgc47d26d\">10<\/a>.<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb18\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb18-1\"><a href=\"#cb18-1\" aria-hidden=\"true\"><\/a>normal_log_lik_opt <span class=\"op\">=<\/span> normalize_tf_graph(normal_log_lik)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<p>Listing <a href=\"#org04c54ca\">19<\/a> compares the computed outputs for the original and normalized graphs, given identical inputs.<\/p>\n<figure id=\"org04c54ca\">\n<div class=\"sourceCode\" id=\"cb19\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb19-1\"><a href=\"#cb19-1\" aria-hidden=\"true\"><\/a>res_unopt <span class=\"op\">=<\/span> normal_log_lik.<span class=\"bu\">eval<\/span>({<span class=\"st\">&#39;mu:0&#39;<\/span>: np.r_[<span class=\"dv\">3<\/span>], <span class=\"st\">&#39;tau:0&#39;<\/span>: np.r_[<span class=\"dv\">1<\/span>], <span class=\"st\">&#39;value:0&#39;<\/span>: np.r_[<span class=\"dv\">1<\/span>]},<\/span>\n<span id=\"cb19-2\"><a href=\"#cb19-2\" aria-hidden=\"true\"><\/a>                                 session<span class=\"op\">=<\/span>tf.compat.v1.Session(graph<span class=\"op\">=<\/span>normal_log_lik.graph))<\/span>\n<span id=\"cb19-3\"><a href=\"#cb19-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb19-4\"><a href=\"#cb19-4\" aria-hidden=\"true\"><\/a>res_opt <span class=\"op\">=<\/span> normal_log_lik_opt.<span class=\"bu\">eval<\/span>({<span class=\"st\">&#39;mu:0&#39;<\/span>: np.r_[<span class=\"dv\">3<\/span>], <span class=\"st\">&#39;tau:0&#39;<\/span>: np.r_[<span class=\"dv\">1<\/span>], <span class=\"st\">&#39;value:0&#39;<\/span>: np.r_[<span class=\"dv\">1<\/span>]},<\/span>\n<span 
id=\"cb19-5\"><a href=\"#cb19-5\" aria-hidden=\"true\"><\/a>                                  session<span class=\"op\">=<\/span>tf.compat.v1.Session(graph<span class=\"op\">=<\/span>normal_log_lik_opt.graph))<\/span>\n<span id=\"cb19-6\"><a href=\"#cb19-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb19-7\"><a href=\"#cb19-7\" aria-hidden=\"true\"><\/a><span class=\"co\"># They should be equal, naturally<\/span><\/span>\n<span id=\"cb19-8\"><a href=\"#cb19-8\" aria-hidden=\"true\"><\/a><span class=\"cf\">assert<\/span> np.array_equal(res_unopt, res_opt)<\/span>\n<span id=\"cb19-9\"><a href=\"#cb19-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb19-10\"><a href=\"#cb19-10\" aria-hidden=\"true\"><\/a>_ <span class=\"op\">=<\/span> [res_unopt, res_opt]<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 19\n<\/figcaption>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" id=\"cb20\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb20-1\"><a href=\"#cb20-1\" aria-hidden=\"true\"><\/a>[array([<span class=\"op\">-<\/span><span class=\"fl\">2.9189386<\/span>], dtype<span class=\"op\">=<\/span>float32), array([<span class=\"op\">-<\/span><span class=\"fl\">2.9189386<\/span>], dtype<span class=\"op\">=<\/span>float32)]<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure id=\"orge1da777\">\n<div class=\"sourceCode\" id=\"cb21\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb21-1\"><a href=\"#cb21-1\" aria-hidden=\"true\"><\/a>tf_dprint(normal_log_lik_opt)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 21\n<\/figcaption>\n<\/figure>\n<figure>\n<pre class=\"text\"><code>Tensor(Sub):0,  shape=[None]    &quot;Normal_1\/log_prob\/sub:0&quot;\n|  Tensor(Mul):0,   shape=[None]    &quot;Normal_1\/log_prob\/mul:0&quot;\n|  |  Tensor(SquaredDifference):0,  shape=[None]    &quot;Normal_1\/log_prob\/SquaredDifference:0&quot;\n|  |  |  Tensor(RealDiv):0, shape=[None]    
&quot;Normal_1\/log_prob\/truediv:0&quot;\n|  |  |  |  Tensor(Placeholder):0,  shape=[None]    &quot;value:0&quot;\n|  |  |  |  Tensor(Placeholder):0,  shape=[None]    &quot;tau:0&quot;\n|  |  |  Tensor(RealDiv):0, shape=[None]    &quot;Normal_1\/log_prob\/truediv_1:0&quot;\n|  |  |  |  Tensor(Placeholder):0,  shape=[None]    &quot;mu:0&quot;\n|  |  |  |  Tensor(Placeholder):0,  shape=[None]    &quot;tau:0&quot;\n|  |  Tensor(Const):0,  shape=[]    &quot;Normal_1\/log_prob\/mul\/x:0&quot;\n|  |  |  -0.5\n|  Tensor(AddV2):0, shape=[None]    &quot;Normal_1\/log_prob\/add:0&quot;\n|  |  Tensor(Log):0,    shape=[None]    &quot;Normal_1\/log_prob\/Log:0&quot;\n|  |  |  Tensor(Placeholder):0, shape=[None]    &quot;tau:0&quot;\n|  |  Tensor(Const):0,  shape=[]    &quot;Normal_1\/log_prob\/add\/x:0&quot;\n|  |  |  0.9189385\n\n<\/code><\/pre>\n<\/figure>\n<p>From the output of Listing <a href=\"#orge1da777\">21<\/a>, we can see that <code>grappler<\/code> has reordered the inputs of the commutative <code>Mul<\/code> and <code>AddV2<\/code> operators (<code>\"Normal_1\/log_prob\/mul\"<\/code> and <code>\"Normal_1\/log_prob\/add\"<\/code>) so that their constant terms come last.<\/p>\n<\/section>\n<section id=\"minikanren-transform-relations\" class=\"level2\">\n<h2>miniKanren Transform Relations<\/h2>\n<p>In Listing <a href=\"#org0ad3a96\">23<\/a> and the listing that follows it, we create miniKanren functions that identify the aforementioned <code>SquaredDifference<\/code> \u201cchains\u201d and perform the re-centering\/scaling substitutions.<\/p>\n<figure id=\"org0ad3a96\">\n<div class=\"sourceCode\" id=\"cb23\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb23-1\"><a href=\"#cb23-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> itertools <span class=\"im\">import<\/span> chain<\/span>\n<span id=\"cb23-2\"><a href=\"#cb23-2\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> functools <span class=\"im\">import<\/span> partial<\/span>\n<span id=\"cb23-3\"><a href=\"#cb23-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-4\"><a href=\"#cb23-4\" 
aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> unification <span class=\"im\">import<\/span> var, reify, unify<\/span>\n<span id=\"cb23-5\"><a href=\"#cb23-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-6\"><a href=\"#cb23-6\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> kanren <span class=\"im\">import<\/span> run, eq, lall, conde<\/span>\n<span id=\"cb23-7\"><a href=\"#cb23-7\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> kanren.goals <span class=\"im\">import<\/span> not_equalo<\/span>\n<span id=\"cb23-8\"><a href=\"#cb23-8\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> kanren.core <span class=\"im\">import<\/span> goaleval<\/span>\n<span id=\"cb23-9\"><a href=\"#cb23-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-10\"><a href=\"#cb23-10\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> symbolic_pymc.tensorflow.meta <span class=\"im\">import<\/span> mt<\/span>\n<span id=\"cb23-11\"><a href=\"#cb23-11\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> symbolic_pymc.relations <span class=\"im\">import<\/span> buildo<\/span>\n<span id=\"cb23-12\"><a href=\"#cb23-12\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> symbolic_pymc.relations.graph <span class=\"im\">import<\/span> graph_applyo, reduceo<\/span>\n<span id=\"cb23-13\"><a href=\"#cb23-13\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> symbolic_pymc.etuple <span class=\"im\">import<\/span> ExpressionTuple, etuple<\/span>\n<span id=\"cb23-14\"><a href=\"#cb23-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-15\"><a href=\"#cb23-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-16\"><a href=\"#cb23-16\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> onceo(goal):<\/span>\n<span id=\"cb23-17\"><a href=\"#cb23-17\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;A non-relational operator that yields only the first result from a 
relation.&quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb23-18\"><a href=\"#cb23-18\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> onceo_goal(s):<\/span>\n<span id=\"cb23-19\"><a href=\"#cb23-19\" aria-hidden=\"true\"><\/a>        <span class=\"kw\">nonlocal<\/span> goal<\/span>\n<span id=\"cb23-20\"><a href=\"#cb23-20\" aria-hidden=\"true\"><\/a>        g <span class=\"op\">=<\/span> reify(goal, s)<\/span>\n<span id=\"cb23-21\"><a href=\"#cb23-21\" aria-hidden=\"true\"><\/a>        g_stream <span class=\"op\">=<\/span> goaleval(g)(s)<\/span>\n<span id=\"cb23-22\"><a href=\"#cb23-22\" aria-hidden=\"true\"><\/a>        s <span class=\"op\">=<\/span> <span class=\"bu\">next<\/span>(g_stream)<\/span>\n<span id=\"cb23-23\"><a href=\"#cb23-23\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">yield<\/span> s<\/span>\n<span id=\"cb23-24\"><a href=\"#cb23-24\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-25\"><a href=\"#cb23-25\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> onceo_goal<\/span>\n<span id=\"cb23-26\"><a href=\"#cb23-26\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-27\"><a href=\"#cb23-27\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-28\"><a href=\"#cb23-28\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> tf_graph_applyo(relation, a, b):<\/span>\n<span id=\"cb23-29\"><a href=\"#cb23-29\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;Construct a `graph_applyo` goal that evaluates a relation only at tensor nodes in a meta graph.<\/span><\/span>\n<span id=\"cb23-30\"><a href=\"#cb23-30\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-31\"><a href=\"#cb23-31\" aria-hidden=\"true\"><\/a><span class=\"co\">    Parameters<\/span><\/span>\n<span id=\"cb23-32\"><a href=\"#cb23-32\" aria-hidden=\"true\"><\/a><span class=\"co\">    ----------<\/span><\/span>\n<span id=\"cb23-33\"><a href=\"#cb23-33\" aria-hidden=\"true\"><\/a><span class=\"co\">    relation: 
function<\/span><\/span>\n<span id=\"cb23-34\"><a href=\"#cb23-34\" aria-hidden=\"true\"><\/a><span class=\"co\">      A binary relation\/goal constructor function<\/span><\/span>\n<span id=\"cb23-35\"><a href=\"#cb23-35\" aria-hidden=\"true\"><\/a><span class=\"co\">    a: lvar, meta graph, or etuple<\/span><\/span>\n<span id=\"cb23-36\"><a href=\"#cb23-36\" aria-hidden=\"true\"><\/a><span class=\"co\">      The left-hand side of the relation.<\/span><\/span>\n<span id=\"cb23-37\"><a href=\"#cb23-37\" aria-hidden=\"true\"><\/a><span class=\"co\">    b: lvar, meta graph, or etuple<\/span><\/span>\n<span id=\"cb23-38\"><a href=\"#cb23-38\" aria-hidden=\"true\"><\/a><span class=\"co\">      The right-hand side of the relation<\/span><\/span>\n<span id=\"cb23-39\"><a href=\"#cb23-39\" aria-hidden=\"true\"><\/a><span class=\"co\">    &quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb23-40\"><a href=\"#cb23-40\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-41\"><a href=\"#cb23-41\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> _expand_some_nodes(node):<\/span>\n<span id=\"cb23-42\"><a href=\"#cb23-42\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> <span class=\"bu\">isinstance<\/span>(node, mt.Tensor) <span class=\"kw\">and<\/span> node.op <span class=\"kw\">is<\/span> <span class=\"kw\">not<\/span> <span class=\"va\">None<\/span>:<\/span>\n<span id=\"cb23-43\"><a href=\"#cb23-43\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">return<\/span> etuple(node.operator, <span class=\"op\">*<\/span>node.inputs, eval_obj<span class=\"op\">=<\/span>node)<\/span>\n<span id=\"cb23-44\"><a href=\"#cb23-44\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> <span class=\"va\">None<\/span><\/span>\n<span id=\"cb23-45\"><a href=\"#cb23-45\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-46\"><a href=\"#cb23-46\" aria-hidden=\"true\"><\/a>    gapplyo <span class=\"op\">=<\/span> partial(graph_applyo, relation, 
preprocess_graph<span class=\"op\">=<\/span>_expand_some_nodes)<\/span>\n<span id=\"cb23-47\"><a href=\"#cb23-47\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> gapplyo(a, b)<\/span>\n<span id=\"cb23-48\"><a href=\"#cb23-48\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-49\"><a href=\"#cb23-49\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-50\"><a href=\"#cb23-50\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> tfp_normal_log_prob(loc, scale):<\/span>\n<span id=\"cb23-51\"><a href=\"#cb23-51\" aria-hidden=\"true\"><\/a>    log_unnormalized <span class=\"op\">=<\/span> <span class=\"op\">-<\/span><span class=\"fl\">0.5<\/span> <span class=\"op\">*<\/span> tf.math.squared_difference(<\/span>\n<span id=\"cb23-52\"><a href=\"#cb23-52\" aria-hidden=\"true\"><\/a>        x <span class=\"op\">\/<\/span> scale, loc <span class=\"op\">\/<\/span> scale)<\/span>\n<span id=\"cb23-53\"><a href=\"#cb23-53\" aria-hidden=\"true\"><\/a>    log_normalization <span class=\"op\">=<\/span> <span class=\"fl\">0.5<\/span> <span class=\"op\">*<\/span> np.log(<span class=\"fl\">2.<\/span> <span class=\"op\">*<\/span> np.pi) <span class=\"op\">+<\/span> tf.math.log(scale)<\/span>\n<span id=\"cb23-54\"><a href=\"#cb23-54\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> log_unnormalized <span class=\"op\">-<\/span> log_normalization<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 23\n<\/figcaption>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" id=\"cb24\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb24-1\"><a href=\"#cb24-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> shift_squared_subso(in_graph, out_subs):<\/span>\n<span id=\"cb24-2\"><a href=\"#cb24-2\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;Construct a goal that produces transforms for chains like (y + x)**2, (x + z)**2.&quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb24-3\"><a href=\"#cb24-3\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-4\"><a href=\"#cb24-4\" aria-hidden=\"true\"><\/a>    Y_lv, X_lv, mu_X_lv <span class=\"op\">=<\/span> var(), var(), var()<\/span>\n<span id=\"cb24-5\"><a href=\"#cb24-5\" aria-hidden=\"true\"><\/a>    scale_Y_lv <span class=\"op\">=<\/span> var()<\/span>\n<span id=\"cb24-6\"><a href=\"#cb24-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-7\"><a href=\"#cb24-7\" aria-hidden=\"true\"><\/a>    X_form_lv <span class=\"op\">=<\/span> mt.Placeholder(dtype<span class=\"op\">=<\/span>var(), shape<span class=\"op\">=<\/span>var(), name<span class=\"op\">=<\/span>var())<\/span>\n<span id=\"cb24-8\"><a href=\"#cb24-8\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># The actual base object&#39;s placeholder might have `_user_specified_name` as<\/span><\/span>\n<span id=\"cb24-9\"><a href=\"#cb24-9\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># an extra `op.node_def.attr`, so let&#39;s just make the entire NodeDef a<\/span><\/span>\n<span id=\"cb24-10\"><a href=\"#cb24-10\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># logic variable.<\/span><\/span>\n<span id=\"cb24-11\"><a href=\"#cb24-11\" aria-hidden=\"true\"><\/a>    X_form_lv.op.node_def <span class=\"op\">=<\/span> var()<\/span>\n<span id=\"cb24-12\"><a href=\"#cb24-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-13\"><a href=\"#cb24-13\" aria-hidden=\"true\"><\/a>    mu_Y_lv <span class=\"op\">=<\/span> mt.realdiv(X_lv, scale_Y_lv, name<span class=\"op\">=<\/span>var())<\/span>\n<span id=\"cb24-14\"><a href=\"#cb24-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-15\"><a href=\"#cb24-15\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Y_T_reshaped_lv = mt.Transpose(mt.reshape(Y_lv, var(), name=var()), var())<\/span><\/span>\n<span id=\"cb24-16\"><a href=\"#cb24-16\" aria-hidden=\"true\"><\/a>    Y_reshaped_lv <span class=\"op\">=<\/span> mt.reshape(Y_lv, var(), name<span class=\"op\">=<\/span>var())<\/span>\n<span 
id=\"cb24-17\"><a href=\"#cb24-17\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-18\"><a href=\"#cb24-18\" aria-hidden=\"true\"><\/a>    sqr_diff_Y_lv <span class=\"op\">=<\/span> mt.SquaredDifference(<\/span>\n<span id=\"cb24-19\"><a href=\"#cb24-19\" aria-hidden=\"true\"><\/a>        mt.realdiv(Y_reshaped_lv,<\/span>\n<span id=\"cb24-20\"><a href=\"#cb24-20\" aria-hidden=\"true\"><\/a>                   scale_Y_lv,<\/span>\n<span id=\"cb24-21\"><a href=\"#cb24-21\" aria-hidden=\"true\"><\/a>                   name<span class=\"op\">=<\/span>var()),<\/span>\n<span id=\"cb24-22\"><a href=\"#cb24-22\" aria-hidden=\"true\"><\/a>        mu_Y_lv,<\/span>\n<span id=\"cb24-23\"><a href=\"#cb24-23\" aria-hidden=\"true\"><\/a>        name<span class=\"op\">=<\/span>var())<\/span>\n<span id=\"cb24-24\"><a href=\"#cb24-24\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-25\"><a href=\"#cb24-25\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> Y_sqrdiffo(in_g, out_g):<\/span>\n<span id=\"cb24-26\"><a href=\"#cb24-26\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> lall(eq(in_g, sqr_diff_Y_lv),<\/span>\n<span id=\"cb24-27\"><a href=\"#cb24-27\" aria-hidden=\"true\"><\/a>                    <span class=\"co\"># This just makes sure that we&#39;re only considering X&#39;s<\/span><\/span>\n<span id=\"cb24-28\"><a href=\"#cb24-28\" aria-hidden=\"true\"><\/a>                    <span class=\"co\"># that are Placeholders.<\/span><\/span>\n<span id=\"cb24-29\"><a href=\"#cb24-29\" aria-hidden=\"true\"><\/a>                    eq(X_lv, X_form_lv))<\/span>\n<span id=\"cb24-30\"><a href=\"#cb24-30\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-31\"><a href=\"#cb24-31\" aria-hidden=\"true\"><\/a>    scale_X_lv <span class=\"op\">=<\/span> var()<\/span>\n<span id=\"cb24-32\"><a href=\"#cb24-32\" aria-hidden=\"true\"><\/a>    sqr_diff_X_lv <span class=\"op\">=<\/span> mt.SquaredDifference(<\/span>\n<span id=\"cb24-33\"><a 
href=\"#cb24-33\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># Mul is only used because RealDiv with 1 is changed by grappler<\/span><\/span>\n<span id=\"cb24-34\"><a href=\"#cb24-34\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># mt.realdiv(X_lv, X_denom_lv, name=var()),<\/span><\/span>\n<span id=\"cb24-35\"><a href=\"#cb24-35\" aria-hidden=\"true\"><\/a>        mt.mul(scale_X_lv, X_lv, name<span class=\"op\">=<\/span>var()),<\/span>\n<span id=\"cb24-36\"><a href=\"#cb24-36\" aria-hidden=\"true\"><\/a>        mu_X_lv,<\/span>\n<span id=\"cb24-37\"><a href=\"#cb24-37\" aria-hidden=\"true\"><\/a>        name<span class=\"op\">=<\/span>var())<\/span>\n<span id=\"cb24-38\"><a href=\"#cb24-38\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-39\"><a href=\"#cb24-39\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> X_sqrdiffo(in_g, out_g):<\/span>\n<span id=\"cb24-40\"><a href=\"#cb24-40\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> eq(in_g, sqr_diff_X_lv)<\/span>\n<span id=\"cb24-41\"><a href=\"#cb24-41\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-42\"><a href=\"#cb24-42\" aria-hidden=\"true\"><\/a>    Y_new_mt <span class=\"op\">=<\/span> mt.addv2(X_lv, mt.mul(scale_Y_lv, Y_lv))<\/span>\n<span id=\"cb24-43\"><a href=\"#cb24-43\" aria-hidden=\"true\"><\/a>    Y_log_scale <span class=\"op\">=<\/span> mt.log(scale_Y_lv, name<span class=\"op\">=<\/span>var())<\/span>\n<span id=\"cb24-44\"><a href=\"#cb24-44\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-45\"><a href=\"#cb24-45\" aria-hidden=\"true\"><\/a>    res <span class=\"op\">=<\/span> lall(<\/span>\n<span id=\"cb24-46\"><a href=\"#cb24-46\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># The first (y - x\/a)**2 (anywhere in the graph)<\/span><\/span>\n<span id=\"cb24-47\"><a href=\"#cb24-47\" aria-hidden=\"true\"><\/a>        tf_graph_applyo(Y_sqrdiffo, in_graph, in_graph),<\/span>\n<span id=\"cb24-48\"><a href=\"#cb24-48\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-49\"><a href=\"#cb24-49\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># The corresponding (x\/b - z)**2 (also anywhere else in the graph)<\/span><\/span>\n<span id=\"cb24-50\"><a href=\"#cb24-50\" aria-hidden=\"true\"><\/a>        tf_graph_applyo(X_sqrdiffo, in_graph, in_graph),<\/span>\n<span id=\"cb24-51\"><a href=\"#cb24-51\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-52\"><a href=\"#cb24-52\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># Find the log-scale factor (at this point, we might as well match an<\/span><\/span>\n<span id=\"cb24-53\"><a href=\"#cb24-53\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># entire normal log-likelihood!)<\/span><\/span>\n<span id=\"cb24-54\"><a href=\"#cb24-54\" aria-hidden=\"true\"><\/a>        tf_graph_applyo(<span class=\"kw\">lambda<\/span> x, y: eq(x, Y_log_scale), in_graph, in_graph),<\/span>\n<span id=\"cb24-55\"><a href=\"#cb24-55\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-56\"><a href=\"#cb24-56\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># Not sure if we need this, but we definitely don&#39;t want X == Y<\/span><\/span>\n<span id=\"cb24-57\"><a href=\"#cb24-57\" aria-hidden=\"true\"><\/a>        (not_equalo, [Y_lv, X_lv], <span class=\"va\">True<\/span>),<\/span>\n<span id=\"cb24-58\"><a href=\"#cb24-58\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-59\"><a href=\"#cb24-59\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># Create replacement rule pairs<\/span><\/span>\n<span id=\"cb24-60\"><a href=\"#cb24-60\" aria-hidden=\"true\"><\/a>        eq(out_subs, [[Y_lv, Y_new_mt],<\/span>\n<span id=\"cb24-61\"><a href=\"#cb24-61\" aria-hidden=\"true\"><\/a>                      [Y_log_scale, <span class=\"fl\">0.0<\/span>]]))<\/span>\n<span id=\"cb24-62\"><a href=\"#cb24-62\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb24-63\"><a href=\"#cb24-63\" aria-hidden=\"true\"><\/a>    <span 
class=\"cf\">return<\/span> res<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" id=\"cb25\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb25-1\"><a href=\"#cb25-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> shift_squared_terms(in_obj, graph_inputs<span class=\"op\">=<\/span>[]):<\/span>\n<span id=\"cb25-2\"><a href=\"#cb25-2\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;Re-center\/scale SquaredDifference terms corresponding to hierarchical normals.&quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb25-3\"><a href=\"#cb25-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-4\"><a href=\"#cb25-4\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Normalize and convert to a meta graph<\/span><\/span>\n<span id=\"cb25-5\"><a href=\"#cb25-5\" aria-hidden=\"true\"><\/a>    in_obj <span class=\"op\">=<\/span> mt(normalize_tf_graph(in_obj, graph_inputs<span class=\"op\">=<\/span>graph_inputs))<\/span>\n<span id=\"cb25-6\"><a href=\"#cb25-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-7\"><a href=\"#cb25-7\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># This run returns all the substitutions found in the graph<\/span><\/span>\n<span id=\"cb25-8\"><a href=\"#cb25-8\" aria-hidden=\"true\"><\/a>    subs_lv <span class=\"op\">=<\/span> var()<\/span>\n<span id=\"cb25-9\"><a href=\"#cb25-9\" aria-hidden=\"true\"><\/a>    subs_res <span class=\"op\">=<\/span> run(<span class=\"dv\">0<\/span>, subs_lv, shift_squared_subso(in_obj, subs_lv))<\/span>\n<span id=\"cb25-10\"><a href=\"#cb25-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-11\"><a href=\"#cb25-11\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> <span class=\"kw\">not<\/span> subs_res:<\/span>\n<span id=\"cb25-12\"><a href=\"#cb25-12\" aria-hidden=\"true\"><\/a>        <span class=\"bu\">print<\/span>(<span class=\"st\">&quot;Failed to find the required forms within the 
graph.&quot;<\/span>)<\/span>\n<span id=\"cb25-13\"><a href=\"#cb25-13\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span><\/span>\n<span id=\"cb25-14\"><a href=\"#cb25-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-15\"><a href=\"#cb25-15\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># <\/span><span class=\"al\">NOTE<\/span><span class=\"co\">: We&#39;re only going to apply the first transformation pair for now.<\/span><\/span>\n<span id=\"cb25-16\"><a href=\"#cb25-16\" aria-hidden=\"true\"><\/a>    subs_res <span class=\"op\">=<\/span> [subs_res[<span class=\"dv\">0<\/span>]]<\/span>\n<span id=\"cb25-17\"><a href=\"#cb25-17\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-18\"><a href=\"#cb25-18\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> subs_replaceo(in_g, out_g):<\/span>\n<span id=\"cb25-19\"><a href=\"#cb25-19\" aria-hidden=\"true\"><\/a>        <span class=\"co\">&quot;&quot;&quot;Create a goal that applies substitutions to a graph.&quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb25-20\"><a href=\"#cb25-20\" aria-hidden=\"true\"><\/a>        <span class=\"kw\">def<\/span> _subs_replaceo(in_g, out_g):<\/span>\n<span id=\"cb25-21\"><a href=\"#cb25-21\" aria-hidden=\"true\"><\/a>            <span class=\"kw\">nonlocal<\/span> subs_res<\/span>\n<span id=\"cb25-22\"><a href=\"#cb25-22\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># Each result is a pair of replacement pairs:<\/span><\/span>\n<span id=\"cb25-23\"><a href=\"#cb25-23\" aria-hidden=\"true\"><\/a>            <span class=\"co\">#   the first pair is the re-center\/scale transform,<\/span><\/span>\n<span id=\"cb25-24\"><a href=\"#cb25-24\" aria-hidden=\"true\"><\/a>            <span class=\"co\">#   the second pair is the cancellation of the log differential scale term.<\/span><\/span>\n<span id=\"cb25-25\"><a href=\"#cb25-25\" aria-hidden=\"true\"><\/a>            subs_goals <span class=\"op\">=<\/span> [[eq(in_g, x), eq(out_g, 
y)]<\/span>\n<span id=\"cb25-26\"><a href=\"#cb25-26\" aria-hidden=\"true\"><\/a>                          <span class=\"cf\">for<\/span> x, y <span class=\"kw\">in<\/span> chain.from_iterable(subs_res)]<\/span>\n<span id=\"cb25-27\"><a href=\"#cb25-27\" aria-hidden=\"true\"><\/a>            x_g <span class=\"op\">=<\/span> conde(<span class=\"op\">*<\/span>subs_goals)<\/span>\n<span id=\"cb25-28\"><a href=\"#cb25-28\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">return<\/span> x_g<\/span>\n<span id=\"cb25-29\"><a href=\"#cb25-29\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-30\"><a href=\"#cb25-30\" aria-hidden=\"true\"><\/a>        g <span class=\"op\">=<\/span> onceo(tf_graph_applyo(_subs_replaceo, in_g, out_g))<\/span>\n<span id=\"cb25-31\"><a href=\"#cb25-31\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> g<\/span>\n<span id=\"cb25-32\"><a href=\"#cb25-32\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-33\"><a href=\"#cb25-33\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Apply each substitution once<\/span><\/span>\n<span id=\"cb25-34\"><a href=\"#cb25-34\" aria-hidden=\"true\"><\/a>    out_graph_lv <span class=\"op\">=<\/span> var()<\/span>\n<span id=\"cb25-35\"><a href=\"#cb25-35\" aria-hidden=\"true\"><\/a>    res <span class=\"op\">=<\/span> run(<span class=\"dv\">1<\/span>, out_graph_lv, reduceo(subs_replaceo, in_obj, out_graph_lv))<\/span>\n<span id=\"cb25-36\"><a href=\"#cb25-36\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-37\"><a href=\"#cb25-37\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> res:<\/span>\n<span id=\"cb25-38\"><a href=\"#cb25-38\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-39\"><a href=\"#cb25-39\" aria-hidden=\"true\"><\/a>        <span class=\"kw\">def<\/span> reify_res(graph_res):<\/span>\n<span id=\"cb25-40\"><a href=\"#cb25-40\" aria-hidden=\"true\"><\/a>            <span class=\"co\">&quot;&quot;&quot;Reconstruct and\/or reify meta object 
results.&quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb25-41\"><a href=\"#cb25-41\" aria-hidden=\"true\"><\/a>            from_etuple <span class=\"op\">=<\/span> graph_res.eval_obj <span class=\"cf\">if<\/span> <span class=\"bu\">isinstance<\/span>(graph_res, ExpressionTuple) <span class=\"cf\">else<\/span> graph_res<\/span>\n<span id=\"cb25-42\"><a href=\"#cb25-42\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">if<\/span> <span class=\"bu\">hasattr<\/span>(from_etuple, <span class=\"st\">&#39;reify&#39;<\/span>):<\/span>\n<span id=\"cb25-43\"><a href=\"#cb25-43\" aria-hidden=\"true\"><\/a>                <span class=\"cf\">return<\/span> from_etuple.reify()<\/span>\n<span id=\"cb25-44\"><a href=\"#cb25-44\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"cb25-45\"><a href=\"#cb25-45\" aria-hidden=\"true\"><\/a>                <span class=\"cf\">return<\/span> from_etuple<\/span>\n<span id=\"cb25-46\"><a href=\"#cb25-46\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-47\"><a href=\"#cb25-47\" aria-hidden=\"true\"><\/a>        res <span class=\"op\">=<\/span> [reify_res(r) <span class=\"cf\">for<\/span> r <span class=\"kw\">in<\/span> res]<\/span>\n<span id=\"cb25-48\"><a href=\"#cb25-48\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-49\"><a href=\"#cb25-49\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> <span class=\"bu\">len<\/span>(res) <span class=\"op\">==<\/span> <span class=\"dv\">1<\/span> <span class=\"kw\">and<\/span> <span class=\"bu\">isinstance<\/span>(res[<span class=\"dv\">0<\/span>], tf.Tensor):<\/span>\n<span id=\"cb25-50\"><a href=\"#cb25-50\" aria-hidden=\"true\"><\/a>        graph_res <span class=\"op\">=<\/span> res[<span class=\"dv\">0<\/span>]<\/span>\n<span id=\"cb25-51\"><a href=\"#cb25-51\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> normalize_tf_graph(graph_res, graph_inputs<span class=\"op\">=<\/span>graph_inputs), 
subs_res<\/span><\/code><\/pre><\/div>\n<\/figure>\n<p>As a test, we will run our miniKanren relations on the log-likelihood graph for a normal-normal hierarchical model in Listing <a href=\"#org29e93d9\">26<\/a>.<\/p>\n<figure id=\"org29e93d9\">\n<div class=\"sourceCode\" id=\"cb26\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb26-1\"><a href=\"#cb26-1\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> graph_mode(), tf.Graph().as_default() <span class=\"im\">as<\/span> demo_graph:<\/span>\n<span id=\"cb26-2\"><a href=\"#cb26-2\" aria-hidden=\"true\"><\/a>    X_tfp <span class=\"op\">=<\/span> tfp.distributions.normal.Normal(<span class=\"fl\">0.0<\/span>, <span class=\"fl\">1.0<\/span>, name<span class=\"op\">=<\/span><span class=\"st\">&#39;X&#39;<\/span>)<\/span>\n<span id=\"cb26-3\"><a href=\"#cb26-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-4\"><a href=\"#cb26-4\" aria-hidden=\"true\"><\/a>    x_tf <span class=\"op\">=<\/span> tf.compat.v1.placeholder(tf.float32, name<span class=\"op\">=<\/span><span class=\"st\">&#39;value_x&#39;<\/span>,<\/span>\n<span id=\"cb26-5\"><a href=\"#cb26-5\" aria-hidden=\"true\"><\/a>                                    shape<span class=\"op\">=<\/span>tf.TensorShape([<span class=\"va\">None<\/span>]))<\/span>\n<span id=\"cb26-6\"><a href=\"#cb26-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-7\"><a href=\"#cb26-7\" aria-hidden=\"true\"><\/a>    tau_tf <span class=\"op\">=<\/span> tf.compat.v1.placeholder(tf.float32, name<span class=\"op\">=<\/span><span class=\"st\">&#39;tau&#39;<\/span>,<\/span>\n<span id=\"cb26-8\"><a href=\"#cb26-8\" aria-hidden=\"true\"><\/a>                                      shape<span class=\"op\">=<\/span>tf.TensorShape([<span class=\"va\">None<\/span>]))<\/span>\n<span id=\"cb26-9\"><a href=\"#cb26-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-10\"><a href=\"#cb26-10\" aria-hidden=\"true\"><\/a>    Y_tfp <span 
class=\"op\">=<\/span> tfp.distributions.normal.Normal(x_tf, tau_tf, name<span class=\"op\">=<\/span><span class=\"st\">&#39;Y&#39;<\/span>)<\/span>\n<span id=\"cb26-11\"><a href=\"#cb26-11\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-12\"><a href=\"#cb26-12\" aria-hidden=\"true\"><\/a>    y_tf <span class=\"op\">=<\/span> tf.compat.v1.placeholder(tf.float32, name<span class=\"op\">=<\/span><span class=\"st\">&#39;value_y&#39;<\/span>,<\/span>\n<span id=\"cb26-13\"><a href=\"#cb26-13\" aria-hidden=\"true\"><\/a>                                    shape<span class=\"op\">=<\/span>tf.TensorShape([<span class=\"va\">None<\/span>]))<\/span>\n<span id=\"cb26-14\"><a href=\"#cb26-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-15\"><a href=\"#cb26-15\" aria-hidden=\"true\"><\/a>    y_T_reshaped <span class=\"op\">=<\/span> tf.transpose(tf.reshape(y_tf, []))<\/span>\n<span id=\"cb26-16\"><a href=\"#cb26-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-17\"><a href=\"#cb26-17\" aria-hidden=\"true\"><\/a>    hier_norm_lik <span class=\"op\">=<\/span> tf.math.log(y_tf) <span class=\"op\">+<\/span> Y_tfp.log_prob(y_T_reshaped) <span class=\"op\">+<\/span> X_tfp.log_prob(x_tf)<\/span>\n<span id=\"cb26-18\"><a href=\"#cb26-18\" aria-hidden=\"true\"><\/a>    hier_norm_lik <span class=\"op\">=<\/span> normalize_tf_graph(hier_norm_lik)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 26\n<\/figcaption>\n<\/figure>\n<p>Listing <a href=\"#org59b1e29\">27<\/a> shows the form that a graph representing a hierarchical normal-normal model will generally take in TFP.<\/p>\n<figure id=\"org59b1e29\">\n<div class=\"sourceCode\" id=\"cb27\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb27-1\"><a href=\"#cb27-1\" aria-hidden=\"true\"><\/a>tf_dprint(hier_norm_lik)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 27\n<\/figcaption>\n<\/figure>\n<figure>\n<pre class=\"text\"><code>Tensor(AddV2):0,    shape=[None]    
&quot;add_1:0&quot;\n|  Tensor(Sub):0,   shape=[None]    &quot;X_1\/log_prob\/sub:0&quot;\n|  |  Tensor(Mul):0,    shape=[None]    &quot;X_1\/log_prob\/mul:0&quot;\n|  |  |  Tensor(SquaredDifference):0,   shape=[None]    &quot;X_1\/log_prob\/SquaredDifference:0&quot;\n|  |  |  |  Tensor(Mul):0,  shape=[None]    &quot;X_1\/log_prob\/truediv:0&quot;\n|  |  |  |  |  Tensor(Const):0, shape=[]    &quot;ConstantFolding\/X_1\/log_prob\/truediv_recip:0&quot;\n|  |  |  |  |  |  1.\n|  |  |  |  |  Tensor(Placeholder):0,   shape=[None]    &quot;value_x:0&quot;\n|  |  |  |  Tensor(Const):0,    shape=[]    &quot;X_1\/log_prob\/truediv_1:0&quot;\n|  |  |  |  |  0.\n|  |  |  Tensor(Const):0,   shape=[]    &quot;Y_1\/log_prob\/mul\/x:0&quot;\n|  |  |  |  -0.5\n|  |  Tensor(Const):0,  shape=[]    &quot;Y_1\/log_prob\/add\/x:0&quot;\n|  |  |  0.9189385\n|  Tensor(AddV2):0, shape=[None]    &quot;add:0&quot;\n|  |  Tensor(Log):0,    shape=[None]    &quot;Log:0&quot;\n|  |  |  Tensor(Placeholder):0, shape=[None]    &quot;value_y:0&quot;\n|  |  Tensor(Sub):0,    shape=[None]    &quot;Y_1\/log_prob\/sub:0&quot;\n|  |  |  Tensor(Mul):0, shape=[None]    &quot;Y_1\/log_prob\/mul:0&quot;\n|  |  |  |  Tensor(SquaredDifference):0,    shape=[None]    &quot;Y_1\/log_prob\/SquaredDifference:0&quot;\n|  |  |  |  |  Tensor(RealDiv):0,   shape=[None]    &quot;Y_1\/log_prob\/truediv:0&quot;\n|  |  |  |  |  |  Tensor(Reshape):0,    shape=[]    &quot;Reshape:0&quot;\n|  |  |  |  |  |  |  Tensor(Placeholder):0, shape=[None]    &quot;value_y:0&quot;\n|  |  |  |  |  |  |  Tensor(Const):0,   shape=[0]   &quot;Reshape\/shape:0&quot;\n|  |  |  |  |  |  |  |  []\n|  |  |  |  |  |  Tensor(Placeholder):0,    shape=[None]    &quot;tau:0&quot;\n|  |  |  |  |  Tensor(RealDiv):0,   shape=[None]    &quot;Y_1\/log_prob\/truediv_1:0&quot;\n|  |  |  |  |  |  Tensor(Placeholder):0,    shape=[None]    &quot;value_x:0&quot;\n|  |  |  |  |  |  Tensor(Placeholder):0,    shape=[None]    &quot;tau:0&quot;\n|  |  |  |  
Tensor(Const):0,    shape=[]    &quot;Y_1\/log_prob\/mul\/x:0&quot;\n|  |  |  |  |  -0.5\n|  |  |  Tensor(AddV2):0,   shape=[None]    &quot;Y_1\/log_prob\/add:0&quot;\n|  |  |  |  Tensor(Log):0,  shape=[None]    &quot;Y_1\/log_prob\/Log:0&quot;\n|  |  |  |  |  Tensor(Placeholder):0,   shape=[None]    &quot;tau:0&quot;\n|  |  |  |  Tensor(Const):0,    shape=[]    &quot;Y_1\/log_prob\/add\/x:0&quot;\n|  |  |  |  |  0.9189385\n\n<\/code><\/pre>\n<\/figure>\n<p>Listing <a href=\"#orgf81b6e5\">29<\/a> runs our transformation and Listing <a href=\"#orga761bbe\">32<\/a> prints the resulting graph.<\/p>\n<figure id=\"orgf81b6e5\">\n<div class=\"sourceCode\" id=\"cb29\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb29-1\"><a href=\"#cb29-1\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> graph_mode(), demo_graph.as_default():<\/span>\n<span id=\"cb29-2\"><a href=\"#cb29-2\" aria-hidden=\"true\"><\/a>    test_output_res, test_remaps <span class=\"op\">=<\/span> shift_squared_terms(hier_norm_lik, graph_inputs<span class=\"op\">=<\/span>[x_tf, y_tf])<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 29\n<\/figcaption>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" id=\"cb30\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb30-1\"><a href=\"#cb30-1\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> rm <span class=\"kw\">in<\/span> test_remaps:<\/span>\n<span id=\"cb30-2\"><a href=\"#cb30-2\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">for<\/span> r <span class=\"kw\">in<\/span> rm:<\/span>\n<span id=\"cb30-3\"><a href=\"#cb30-3\" aria-hidden=\"true\"><\/a>      tf_dprint(r[<span class=\"dv\">0<\/span>])<\/span>\n<span id=\"cb30-4\"><a href=\"#cb30-4\" aria-hidden=\"true\"><\/a>      <span class=\"bu\">print<\/span>(<span class=\"st\">&quot;-&gt;&quot;<\/span>)<\/span>\n<span id=\"cb30-5\"><a href=\"#cb30-5\" aria-hidden=\"true\"><\/a>      tf_dprint(r[<span 
class=\"dv\">1<\/span>])<\/span>\n<span id=\"cb30-6\"><a href=\"#cb30-6\" aria-hidden=\"true\"><\/a>      <span class=\"bu\">print<\/span>(<span class=\"st\">&quot;------&quot;<\/span>)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure>\n<pre class=\"text\"><code>Tensor(Placeholder):0,  shape=[None]    &quot;value_y:0&quot;\n-&gt;\nTensor(AddV2):0,    shape=[None]    &quot;AddV2:0&quot;\n|  Tensor(Placeholder):0,   shape=[None]    &quot;value_x:0&quot;\n|  Tensor(Mul):0,   shape=[None]    &quot;Mul:0&quot;\n|  |  Tensor(Placeholder):0,    shape=[None]    &quot;tau:0&quot;\n|  |  Tensor(Placeholder):0,    shape=[None]    &quot;value_y:0&quot;\n------\nTensor(Log):0,  shape=~_12312   &quot;Y_1\/log_prob\/Log:0&quot;\n|  Tensor(Placeholder):0,   shape=[None]    &quot;tau:0&quot;\n-&gt;\n0.0\n------\n\n<\/code><\/pre>\n<\/figure>\n<figure id=\"orga761bbe\">\n<div class=\"sourceCode\" id=\"cb32\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb32-1\"><a href=\"#cb32-1\" aria-hidden=\"true\"><\/a>tf_dprint(test_output_res)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 32\n<\/figcaption>\n<\/figure>\n<figure>\n<pre class=\"text\"><code>Tensor(AddV2):0,    shape=[None]    &quot;add_1_1:0&quot;\n|  Tensor(Sub):0,   shape=[None]    &quot;X_1\/log_prob\/sub:0&quot;\n|  |  Tensor(Mul):0,    shape=[None]    &quot;X_1\/log_prob\/mul:0&quot;\n|  |  |  Tensor(SquaredDifference):0,   shape=[None]    &quot;X_1\/log_prob\/SquaredDifference:0&quot;\n|  |  |  |  Tensor(Mul):0,  shape=[None]    &quot;X_1\/log_prob\/truediv:0&quot;\n|  |  |  |  |  Tensor(Const):0, shape=[]    &quot;ConstantFolding\/X_1\/log_prob\/truediv_recip:0&quot;\n|  |  |  |  |  |  1.\n|  |  |  |  |  Tensor(Placeholder):0,   shape=[None]    &quot;value_x:0&quot;\n|  |  |  |  Tensor(Const):0,    shape=[]    &quot;X_1\/log_prob\/truediv_1:0&quot;\n|  |  |  |  |  0.\n|  |  |  Tensor(Const):0,   shape=[]    &quot;Y_1\/log_prob\/mul\/x:0&quot;\n|  |  |  |  -0.5\n|  |  
Tensor(Const):0,  shape=[]    &quot;Y_1\/log_prob\/add\/x:0&quot;\n|  |  |  0.9189385\n|  Tensor(AddV2):0, shape=[None]    &quot;add_2:0&quot;\n|  |  Tensor(Log):0,    shape=[None]    &quot;Log_1:0&quot;\n|  |  |  Tensor(AddV2):0,   shape=[None]    &quot;AddV2:0&quot;\n|  |  |  |  Tensor(Mul):0,  shape=[None]    &quot;Mul:0&quot;\n|  |  |  |  |  Tensor(Placeholder):0,   shape=[None]    &quot;tau:0&quot;\n|  |  |  |  |  Tensor(Placeholder):0,   shape=[None]    &quot;value_y:0&quot;\n|  |  |  |  Tensor(Placeholder):0,  shape=[None]    &quot;value_x:0&quot;\n|  |  Tensor(Sub):0,    shape=[None]    &quot;Y_1\/log_prob\/sub_1:0&quot;\n|  |  |  Tensor(Mul):0, shape=[None]    &quot;Y_1\/log_prob\/mul_1:0&quot;\n|  |  |  |  Tensor(SquaredDifference):0,    shape=[None]    &quot;Y_1\/log_prob\/SquaredDifference_1:0&quot;\n|  |  |  |  |  Tensor(RealDiv):0,   shape=[None]    &quot;Y_1\/log_prob\/truediv_1:0&quot;\n|  |  |  |  |  |  Tensor(Placeholder):0,    shape=[None]    &quot;value_x:0&quot;\n|  |  |  |  |  |  Tensor(Placeholder):0,    shape=[None]    &quot;tau:0&quot;\n|  |  |  |  |  Tensor(RealDiv):0,   shape=[None]    &quot;Y_1\/log_prob\/truediv_2:0&quot;\n|  |  |  |  |  |  Tensor(Reshape):0,    shape=[]    &quot;Reshape_1:0&quot;\n|  |  |  |  |  |  |  Tensor(AddV2):0,   shape=[None]    &quot;AddV2:0&quot;\n|  |  |  |  |  |  |  |  ...\n|  |  |  |  |  |  |  Tensor(Const):0,   shape=[0]   &quot;Reshape\/shape:0&quot;\n|  |  |  |  |  |  |  |  []\n|  |  |  |  |  |  Tensor(Placeholder):0,    shape=[None]    &quot;tau:0&quot;\n|  |  |  |  Tensor(Const):0,    shape=[]    &quot;Y_1\/log_prob\/mul\/x:0&quot;\n|  |  |  |  |  -0.5\n|  |  |  Tensor(Const):0,   shape=[]    &quot;Y_1\/log_prob\/add\/x:0&quot;\n|  |  |  |  0.9189385\n\n<\/code><\/pre>\n<\/figure>\n<\/section>\n<section id=\"missing-graph-simplifications\" class=\"level2\">\n<h2>Missing Graph Simplifications<\/h2>\n<p>From Listing <a href=\"#orga761bbe\">32<\/a> we can see that <code>grappler<\/code> is not applying 
enough algebraic simplifications (e.g.\u00a0it doesn\u2019t remove multiplications with 1 or reduce the <span class=\"math inline\">\\(\\left(\\mu + x - \\mu \\right)^2\\)<\/span> term in <code>SquaredDifference<\/code>).<\/p>\n<p>Does missing this simplification amount to anything practical? Listing <a href=\"#orga71aafb\">34<\/a> demonstrates the difference between our model without the simplification and a manually constructed model without the redundancy in <code>SquaredDifference<\/code>.<\/p>\n<figure id=\"orga71aafb\">\n<div class=\"sourceCode\" id=\"cb34\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb34-1\"><a href=\"#cb34-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> compute_point_diff():<\/span>\n<span id=\"cb34-2\"><a href=\"#cb34-2\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">with<\/span> graph_mode(), demo_graph.as_default():<\/span>\n<span id=\"cb34-3\"><a href=\"#cb34-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-4\"><a href=\"#cb34-4\" aria-hidden=\"true\"><\/a>        Y_trans_tfp <span class=\"op\">=<\/span> tfp.distributions.normal.Normal(<span class=\"fl\">0.0<\/span>, <span class=\"fl\">1.0<\/span>, name<span class=\"op\">=<\/span><span class=\"st\">&#39;Y_trans&#39;<\/span>)<\/span>\n<span id=\"cb34-5\"><a href=\"#cb34-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-6\"><a href=\"#cb34-6\" aria-hidden=\"true\"><\/a>        y_shifted_tf <span class=\"op\">=<\/span> x_tf <span class=\"op\">+<\/span> tau_tf <span class=\"op\">*<\/span> y_tf<\/span>\n<span id=\"cb34-7\"><a href=\"#cb34-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-8\"><a href=\"#cb34-8\" aria-hidden=\"true\"><\/a>        hier_norm_trans_lik <span class=\"op\">=<\/span> tf.math.log(y_shifted_tf) <span class=\"op\">+<\/span> Y_trans_tfp.log_prob(y_T_reshaped) <span class=\"op\">+<\/span> X_tfp.log_prob(x_tf)<\/span>\n<span id=\"cb34-9\"><a href=\"#cb34-9\" aria-hidden=\"true\"><\/a>        
hier_norm_trans_lik <span class=\"op\">=<\/span> normalize_tf_graph(hier_norm_trans_lik)<\/span>\n<span id=\"cb34-10\"><a href=\"#cb34-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-11\"><a href=\"#cb34-11\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-12\"><a href=\"#cb34-12\" aria-hidden=\"true\"><\/a>    test_point <span class=\"op\">=<\/span> {x_tf.name: np.r_[<span class=\"fl\">1.0<\/span>],<\/span>\n<span id=\"cb34-13\"><a href=\"#cb34-13\" aria-hidden=\"true\"><\/a>                  tau_tf.name: np.r_[<span class=\"fl\">1e-20<\/span>],<\/span>\n<span id=\"cb34-14\"><a href=\"#cb34-14\" aria-hidden=\"true\"><\/a>                  y_tf.name: np.r_[<span class=\"fl\">1000.1<\/span>]}<\/span>\n<span id=\"cb34-15\"><a href=\"#cb34-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-16\"><a href=\"#cb34-16\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">with<\/span> tf.compat.v1.Session(graph<span class=\"op\">=<\/span>test_output_res.graph).as_default():<\/span>\n<span id=\"cb34-17\"><a href=\"#cb34-17\" aria-hidden=\"true\"><\/a>        val <span class=\"op\">=<\/span> test_output_res.<span class=\"bu\">eval<\/span>(test_point)<\/span>\n<span id=\"cb34-18\"><a href=\"#cb34-18\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-19\"><a href=\"#cb34-19\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">with<\/span> tf.compat.v1.Session(graph<span class=\"op\">=<\/span>hier_norm_trans_lik.graph).as_default():<\/span>\n<span id=\"cb34-20\"><a href=\"#cb34-20\" aria-hidden=\"true\"><\/a>        val_2 <span class=\"op\">=<\/span> hier_norm_trans_lik.<span class=\"bu\">eval<\/span>(test_point)<\/span>\n<span id=\"cb34-21\"><a href=\"#cb34-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb34-22\"><a href=\"#cb34-22\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> val, val_2<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 34\n<\/figcaption>\n<\/figure>\n<figure id=\"org7e4367b\">\n<div class=\"sourceCode\" 
id=\"cb35\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb35-1\"><a href=\"#cb35-1\" aria-hidden=\"true\"><\/a>_ <span class=\"op\">=<\/span> np.subtract(<span class=\"op\">*<\/span>compute_point_diff())<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 35\n<\/figcaption>\n<\/figure>\n<figure>\n<pre class=\"text\"><code>[500099.94]<\/code><\/pre>\n<\/figure>\n<p>The output of Listing <a href=\"#org7e4367b\">35<\/a> shows how large the discrepancy can be for carefully chosen parameter values. More specifically, as <code>tau_tf<\/code> gets smaller and the magnitude of the difference <code>x_tf - y_tf<\/code> gets larger, the discrepancy can grow. Since such parameter values are likely to be visited during sampling, we should address this missing simplification.<\/p>\n<p>In Listing <a href=\"#orge313efe\">38<\/a> we create a goal that performs the aforementioned simplification for <code>SquaredDifference<\/code>.<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb37\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb37-1\"><a href=\"#cb37-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> recenter_sqrdiffo(in_g, out_g):<\/span>\n<span id=\"cb37-2\"><a href=\"#cb37-2\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;Create a goal that essentially reduces `(a \/ d - (a + d * c) \/ d)**2` to `c**2`&quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb37-3\"><a href=\"#cb37-3\" aria-hidden=\"true\"><\/a>    a_sqd_lv, b_sqd_lv, d_sqd_lv <span class=\"op\">=<\/span> var(), var(), var()<\/span>\n<span id=\"cb37-4\"><a href=\"#cb37-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb37-5\"><a href=\"#cb37-5\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Pattern: (a \/ d - b \/ d)**2<\/span><\/span>\n<span id=\"cb37-6\"><a href=\"#cb37-6\" aria-hidden=\"true\"><\/a>    target_sqrdiff_lv <span class=\"op\">=<\/span> mt.SquaredDifference(<\/span>\n<span id=\"cb37-7\"><a 
href=\"#cb37-7\" aria-hidden=\"true\"><\/a>        mt.realdiv(a_sqd_lv, d_sqd_lv, name<span class=\"op\">=<\/span>var()),<\/span>\n<span id=\"cb37-8\"><a href=\"#cb37-8\" aria-hidden=\"true\"><\/a>        mt.realdiv(b_sqd_lv, d_sqd_lv, name<span class=\"op\">=<\/span>var()),<\/span>\n<span id=\"cb37-9\"><a href=\"#cb37-9\" aria-hidden=\"true\"><\/a>        name<span class=\"op\">=<\/span>var()<\/span>\n<span id=\"cb37-10\"><a href=\"#cb37-10\" aria-hidden=\"true\"><\/a>    )<\/span>\n<span id=\"cb37-11\"><a href=\"#cb37-11\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb37-12\"><a href=\"#cb37-12\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Pattern: d * c + a<\/span><\/span>\n<span id=\"cb37-13\"><a href=\"#cb37-13\" aria-hidden=\"true\"><\/a>    c_sqd_lv <span class=\"op\">=<\/span> var()<\/span>\n<span id=\"cb37-14\"><a href=\"#cb37-14\" aria-hidden=\"true\"><\/a>    b_part_lv <span class=\"op\">=<\/span> mt.addv2(mt.mul(d_sqd_lv, c_sqd_lv, name<span class=\"op\">=<\/span>var()), a_sqd_lv, name<span class=\"op\">=<\/span>var())<\/span>\n<span id=\"cb37-15\"><a href=\"#cb37-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb37-16\"><a href=\"#cb37-16\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Replacement: c**2<\/span><\/span>\n<span id=\"cb37-17\"><a href=\"#cb37-17\" aria-hidden=\"true\"><\/a>    simplified_sqrdiff_lv <span class=\"op\">=<\/span> mt.SquaredDifference(<\/span>\n<span id=\"cb37-18\"><a href=\"#cb37-18\" aria-hidden=\"true\"><\/a>        c_sqd_lv,<\/span>\n<span id=\"cb37-19\"><a href=\"#cb37-19\" aria-hidden=\"true\"><\/a>        <span class=\"fl\">0.0<\/span><\/span>\n<span id=\"cb37-20\"><a href=\"#cb37-20\" aria-hidden=\"true\"><\/a>    )<\/span>\n<span id=\"cb37-21\"><a href=\"#cb37-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb37-22\"><a href=\"#cb37-22\" aria-hidden=\"true\"><\/a>    reshape_lv <span class=\"op\">=<\/span> var()<\/span>\n<span id=\"cb37-23\"><a href=\"#cb37-23\" aria-hidden=\"true\"><\/a>    
simplified_sqrdiff_reshaped_lv <span class=\"op\">=<\/span> mt.SquaredDifference(<\/span>\n<span id=\"cb37-24\"><a href=\"#cb37-24\" aria-hidden=\"true\"><\/a>        mt.reshape(c_sqd_lv, reshape_lv),<\/span>\n<span id=\"cb37-25\"><a href=\"#cb37-25\" aria-hidden=\"true\"><\/a>        <span class=\"fl\">0.0<\/span><\/span>\n<span id=\"cb37-26\"><a href=\"#cb37-26\" aria-hidden=\"true\"><\/a>    )<\/span>\n<span id=\"cb37-27\"><a href=\"#cb37-27\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb37-28\"><a href=\"#cb37-28\" aria-hidden=\"true\"><\/a>    res <span class=\"op\">=<\/span> lall(<\/span>\n<span id=\"cb37-29\"><a href=\"#cb37-29\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># input == (a \/ d - b \/ d)**2 must be &quot;true&quot;<\/span><\/span>\n<span id=\"cb37-30\"><a href=\"#cb37-30\" aria-hidden=\"true\"><\/a>        eq(in_g, target_sqrdiff_lv),<\/span>\n<span id=\"cb37-31\"><a href=\"#cb37-31\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># &quot;and&quot;<\/span><\/span>\n<span id=\"cb37-32\"><a href=\"#cb37-32\" aria-hidden=\"true\"><\/a>        conde([<\/span>\n<span id=\"cb37-33\"><a href=\"#cb37-33\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># &quot;if&quot; b == d * c + a is &quot;true&quot;<\/span><\/span>\n<span id=\"cb37-34\"><a href=\"#cb37-34\" aria-hidden=\"true\"><\/a>            eq(b_sqd_lv, b_part_lv),<\/span>\n<span id=\"cb37-35\"><a href=\"#cb37-35\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># &quot;then&quot; output ==  (c - 0)**2 is also &quot;true&quot;<\/span><\/span>\n<span id=\"cb37-36\"><a href=\"#cb37-36\" aria-hidden=\"true\"><\/a>            eq(out_g, simplified_sqrdiff_lv)<\/span>\n<span id=\"cb37-37\"><a href=\"#cb37-37\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb37-38\"><a href=\"#cb37-38\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># &quot;or&quot;<\/span><\/span>\n<span id=\"cb37-39\"><a href=\"#cb37-39\" aria-hidden=\"true\"><\/a>        ], 
[<\/span>\n<span id=\"cb37-40\"><a href=\"#cb37-40\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># We have to use this to cover some variation also not<\/span><\/span>\n<span id=\"cb37-41\"><a href=\"#cb37-41\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># sufficiently\/consistently &quot;normalized&quot; by `grappler`.<\/span><\/span>\n<span id=\"cb37-42\"><a href=\"#cb37-42\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb37-43\"><a href=\"#cb37-43\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># &quot;if&quot; b == reshape(d * c + a, ?) is &quot;true&quot;<\/span><\/span>\n<span id=\"cb37-44\"><a href=\"#cb37-44\" aria-hidden=\"true\"><\/a>            eq(b_sqd_lv, mt.reshape(b_part_lv, reshape_lv, name<span class=\"op\">=<\/span>var())),<\/span>\n<span id=\"cb37-45\"><a href=\"#cb37-45\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># &quot;then&quot; output == (reshape(c, ?) - 0)**2 is also &quot;true&quot;<\/span><\/span>\n<span id=\"cb37-46\"><a href=\"#cb37-46\" aria-hidden=\"true\"><\/a>            eq(out_g, simplified_sqrdiff_reshaped_lv)<\/span>\n<span id=\"cb37-47\"><a href=\"#cb37-47\" aria-hidden=\"true\"><\/a>        ]))<\/span>\n<span id=\"cb37-48\"><a href=\"#cb37-48\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> res<\/span><\/code><\/pre><\/div>\n<\/figure>\n<p>We apply the simplification in Listing <a href=\"#orge313efe\">38<\/a> and print the results in Listing <a href=\"#orga55e147\">39<\/a>.<\/p>\n<figure id=\"orge313efe\">\n<div class=\"sourceCode\" id=\"cb38\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb38-1\"><a href=\"#cb38-1\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> graph_mode(), test_output_res.graph.as_default():<\/span>\n<span id=\"cb38-2\"><a href=\"#cb38-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb38-3\"><a href=\"#cb38-3\" aria-hidden=\"true\"><\/a>    res <span class=\"op\">=<\/span> run(<span 
class=\"dv\">1<\/span>, var(<span class=\"st\">&#39;q&#39;<\/span>),<\/span>\n<span id=\"cb38-4\"><a href=\"#cb38-4\" aria-hidden=\"true\"><\/a>              reduceo(<span class=\"kw\">lambda<\/span> x, y: tf_graph_applyo(recenter_sqrdiffo, x, y),<\/span>\n<span id=\"cb38-5\"><a href=\"#cb38-5\" aria-hidden=\"true\"><\/a>                      test_output_res, var(<span class=\"st\">&#39;q&#39;<\/span>)))<\/span>\n<span id=\"cb38-6\"><a href=\"#cb38-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb38-7\"><a href=\"#cb38-7\" aria-hidden=\"true\"><\/a>    test_output_res <span class=\"op\">=<\/span> normalize_tf_graph(res[<span class=\"dv\">0<\/span>].eval_obj.reify())<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 38\n<\/figcaption>\n<\/figure>\n<figure id=\"orga55e147\">\n<div class=\"sourceCode\" id=\"cb39\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb39-1\"><a href=\"#cb39-1\" aria-hidden=\"true\"><\/a>tf_dprint(test_output_res.graph.get_tensor_by_name(<span class=\"st\">&#39;SquaredDifference:0&#39;<\/span>))<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 39\n<\/figcaption>\n<\/figure>\n<figure>\n<pre class=\"text\"><code>Tensor(SquaredDifference):0,    shape=[None]    &quot;SquaredDifference:0&quot;\n|  Tensor(Const):0, shape=[]    &quot;X_1\/log_prob\/truediv_1:0&quot;\n|  |  0.\n|  Tensor(Placeholder):0,   shape=[None]    &quot;value_y:0&quot;\n\n<\/code><\/pre>\n<\/figure>\n<p>After simplification, the difference is now gone.<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb41\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb41-1\"><a href=\"#cb41-1\" aria-hidden=\"true\"><\/a>_ <span class=\"op\">=<\/span> np.subtract(<span class=\"op\">*<\/span>compute_point_diff())<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure>\n<pre class=\"text\"><code>[0.]<\/code><\/pre>\n<\/figure>\n<\/section>\n<\/section>\n<section id=\"transforming-the-log-likelihood-graph\" 
class=\"level1\">\n<h1>Transforming the Log-likelihood Graph<\/h1>\n<p>Now, we\u2019re ready to apply the transform to the radon model log-likelihood graph.<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb43\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb43-1\"><a href=\"#cb43-1\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> graph_mode(), tf.Graph().as_default() <span class=\"im\">as<\/span> trans_graph:<\/span>\n<span id=\"cb43-2\"><a href=\"#cb43-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb43-3\"><a href=\"#cb43-3\" aria-hidden=\"true\"><\/a>    graph_inputs <span class=\"op\">=<\/span> [logpfn_fg.get_operation_by_name(i.name).outputs[<span class=\"dv\">0<\/span>]<\/span>\n<span id=\"cb43-4\"><a href=\"#cb43-4\" aria-hidden=\"true\"><\/a>                    <span class=\"cf\">for<\/span> i <span class=\"kw\">in<\/span> logpfn_cf.structured_input_signature[<span class=\"dv\">0<\/span>]]<\/span>\n<span id=\"cb43-5\"><a href=\"#cb43-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb43-6\"><a href=\"#cb43-6\" aria-hidden=\"true\"><\/a>    logpfn_trans_tf, logpfn_remaps <span class=\"op\">=<\/span> shift_squared_terms(logpfn_fg.outputs[<span class=\"dv\">0<\/span>], graph_inputs<span class=\"op\">=<\/span>graph_inputs)<\/span>\n<span id=\"cb43-7\"><a href=\"#cb43-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb43-8\"><a href=\"#cb43-8\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> graph_mode(), logpfn_trans_tf.graph.as_default():<\/span>\n<span id=\"cb43-9\"><a href=\"#cb43-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb43-10\"><a href=\"#cb43-10\" aria-hidden=\"true\"><\/a>    res <span class=\"op\">=<\/span> run(<span class=\"dv\">1<\/span>, var(<span class=\"st\">&#39;q&#39;<\/span>),<\/span>\n<span id=\"cb43-11\"><a href=\"#cb43-11\" aria-hidden=\"true\"><\/a>              reduceo(<span class=\"kw\">lambda<\/span> x, y: tf_graph_applyo(recenter_sqrdiffo, x, y),<\/span>\n<span 
id=\"cb43-12\"><a href=\"#cb43-12\" aria-hidden=\"true\"><\/a>                      logpfn_trans_tf, var(<span class=\"st\">&#39;q&#39;<\/span>)))<\/span>\n<span id=\"cb43-13\"><a href=\"#cb43-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb43-14\"><a href=\"#cb43-14\" aria-hidden=\"true\"><\/a>    logpfn_trans_tf <span class=\"op\">=<\/span> normalize_tf_graph(res[<span class=\"dv\">0<\/span>].eval_obj.reify())<\/span><\/code><\/pre><\/div>\n<\/figure>\n<p>Listing <a href=\"#org73bcbee\">44<\/a> shows the replacements that were made throughout the graph. Two replacements were found, and they appear to correspond to the un-centered normal distribution terms <code>a<\/code> and <code>b<\/code> in our model, as intended.<\/p>\n<figure id=\"org73bcbee\">\n<div class=\"sourceCode\" id=\"cb44\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb44-1\"><a href=\"#cb44-1\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> rm <span class=\"kw\">in<\/span> logpfn_remaps:<\/span>\n<span id=\"cb44-2\"><a href=\"#cb44-2\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">for<\/span> r <span class=\"kw\">in<\/span> rm:<\/span>\n<span id=\"cb44-3\"><a href=\"#cb44-3\" aria-hidden=\"true\"><\/a>        tf_dprint(r[<span class=\"dv\">0<\/span>])<\/span>\n<span id=\"cb44-4\"><a href=\"#cb44-4\" aria-hidden=\"true\"><\/a>        <span class=\"bu\">print<\/span>(<span class=\"st\">&quot;-&gt;&quot;<\/span>)<\/span>\n<span id=\"cb44-5\"><a href=\"#cb44-5\" aria-hidden=\"true\"><\/a>        tf_dprint(r[<span class=\"dv\">1<\/span>])<\/span>\n<span id=\"cb44-6\"><a href=\"#cb44-6\" aria-hidden=\"true\"><\/a>        <span class=\"bu\">print<\/span>(<span class=\"st\">&quot;------&quot;<\/span>)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 44\n<\/figcaption>\n<\/figure>\n<figure>\n<pre class=\"text\"><code>Tensor(Placeholder):0,  shape=[85]  &quot;values_2:0&quot;\n-&gt;\nTensor(AddV2):0,    shape=[85]  &quot;AddV2:0&quot;\n|  
Tensor(Placeholder):0,   shape=[]    &quot;values_4:0&quot;\n|  Tensor(Mul):0,   shape=[85]  &quot;Mul_4:0&quot;\n|  |  Tensor(Exp):0,    shape=[]    &quot;exp_2_1\/forward\/Exp:0&quot;\n|  |  |  Tensor(Placeholder):0, shape=[]    &quot;values_5:0&quot;\n|  |  Tensor(Placeholder):0,    shape=[85]  &quot;values_2:0&quot;\n------\nTensor(Log):0,  shape=~_175065  &quot;SampleNormal_3_1\/log_prob\/Normal_3\/log_prob\/Log:0&quot;\n|  Tensor(Exp):0,   shape=[]    &quot;exp_2_1\/forward\/Exp:0&quot;\n|  |  Tensor(Placeholder):0,    shape=[]    &quot;values_5:0&quot;\n-&gt;\n0.0\n------\n\n<\/code><\/pre>\n<\/figure>\n<p>Likewise, Listing <a href=\"#org0ce0bba\">46<\/a> prints the <code>SquaredDifference<\/code>, <code>Gather<\/code>-type, and <code>Log<\/code> subgraphs that appear in the transformed log-likelihood.<\/p>\n<figure id=\"org0ce0bba\">\n<div class=\"sourceCode\" id=\"cb46\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb46-1\"><a href=\"#cb46-1\" aria-hidden=\"true\"><\/a>square_diff_outs <span class=\"op\">=<\/span> [o.outputs[<span class=\"dv\">0<\/span>] <span class=\"cf\">for<\/span> o <span class=\"kw\">in<\/span> logpfn_trans_tf.graph.get_operations()<\/span>\n<span id=\"cb46-2\"><a href=\"#cb46-2\" aria-hidden=\"true\"><\/a>                    <span class=\"cf\">if<\/span> o.<span class=\"bu\">type<\/span> <span class=\"op\">==<\/span> <span class=\"st\">&#39;SquaredDifference&#39;<\/span> <span class=\"kw\">or<\/span><\/span>\n<span id=\"cb46-3\"><a href=\"#cb46-3\" aria-hidden=\"true\"><\/a>                    o.<span class=\"bu\">type<\/span>.startswith(<span class=\"st\">&#39;Gather&#39;<\/span>) <span class=\"kw\">or<\/span> o.<span class=\"bu\">type<\/span> <span class=\"op\">==<\/span> <span class=\"st\">&#39;Log&#39;<\/span>]<\/span>\n<span id=\"cb46-4\"><a href=\"#cb46-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb46-5\"><a href=\"#cb46-5\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> t <span class=\"kw\">in<\/span> 
square_diff_outs:<\/span>\n<span id=\"cb46-6\"><a href=\"#cb46-6\" aria-hidden=\"true\"><\/a>    tf_dprint(t)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 46\n<\/figcaption>\n<\/figure>\n<figure>\n<pre class=\"text\"><code>Tensor(GatherV2):0, shape=[919] &quot;GatherV2:0&quot;\n|  Tensor(Placeholder):0,   shape=[85]  &quot;values_3:0&quot;\n|  Tensor(Const):0, shape=[919] &quot;GatherV2\/indices:0&quot;\n|  |  [ 0  0  0 ... 83 84 84]\n|  Tensor(Const):0, shape=[]    &quot;GatherV2\/axis:0&quot;\n|  |  0\nTensor(Log):0,  shape=[]    &quot;SampleNormal_2_1\/log_prob\/Normal_2\/log_prob\/Log:0&quot;\n|  Tensor(Exp):0,   shape=[]    &quot;exp_1\/forward\/Exp:0&quot;\n|  |  Tensor(Placeholder):0,    shape=[]    &quot;values_0:0&quot;\nTensor(SquaredDifference):0,    shape=[]    &quot;Normal_5\/log_prob\/SquaredDifference:0&quot;\n|  Tensor(Const):0, shape=[]    &quot;Const_723:0&quot;\n|  |  0.\n|  Tensor(Mul):0,   shape=[]    &quot;Normal_5\/log_prob\/truediv:0&quot;\n|  |  Tensor(Const):0,  shape=[]    &quot;exp_3_2\/inverse_log_det_jacobian\/mul_1:0&quot;\n|  |  |  1.\n|  |  Tensor(Placeholder):0,    shape=[]    &quot;values_1:0&quot;\nTensor(SquaredDifference):0,    shape=[85]  &quot;SquaredDifference:0&quot;\n|  Tensor(Const):0, shape=[]    &quot;Const_723:0&quot;\n|  |  0.\n|  Tensor(Reshape):0,   shape=[85]  &quot;Reshape:0&quot;\n|  |  Tensor(Placeholder):0,    shape=[85]  &quot;values_2:0&quot;\n|  |  Tensor(Const):0,  shape=[1]   &quot;SampleNormal_2_1\/log_prob\/Reshape\/shape:0&quot;\n|  |  |  [85]\nTensor(SquaredDifference):0,    shape=[]    &quot;Normal_1_1\/log_prob\/SquaredDifference:0&quot;\n|  Tensor(Const):0, shape=[]    &quot;Const_723:0&quot;\n|  |  0.\n|  Tensor(Mul):0,   shape=[]    &quot;Normal_1_1\/log_prob\/truediv:0&quot;\n|  |  Tensor(Const):0,  shape=[]    &quot;exp_3_2\/inverse_log_det_jacobian\/mul_1:0&quot;\n|  |  |  1.\n|  |  Tensor(Placeholder):0,    shape=[]    &quot;values_4:0&quot;\nTensor(Log):0,  shape=[]    
&quot;Normal_4_1\/log_prob\/Log:0&quot;\n|  Tensor(Exp):0,   shape=[]    &quot;exp_3_1\/forward\/Exp:0&quot;\n|  |  Tensor(Placeholder):0,    shape=[]    &quot;values_6:0&quot;\nTensor(SquaredDifference):0,    shape=[85]  &quot;SampleNormal_2_1\/log_prob\/Normal_2\/log_prob\/SquaredDifference:0&quot;\n|  Tensor(RealDiv):0,   shape=[85]  &quot;SampleNormal_2_1\/log_prob\/Normal_2\/log_prob\/truediv:0&quot;\n|  |  Tensor(Reshape):0,    shape=[85]  &quot;SampleNormal_2_1\/log_prob\/Reshape:0&quot;\n|  |  |  Tensor(Placeholder):0, shape=[85]  &quot;values_3:0&quot;\n|  |  |  Tensor(Const):0,   shape=[1]   &quot;SampleNormal_2_1\/log_prob\/Reshape\/shape:0&quot;\n|  |  |  |  [85]\n|  |  Tensor(Exp):0,    shape=[]    &quot;exp_1\/forward\/Exp:0&quot;\n|  |  |  Tensor(Placeholder):0, shape=[]    &quot;values_0:0&quot;\n|  Tensor(RealDiv):0,   shape=[]    &quot;SampleNormal_2_1\/log_prob\/Normal_2\/log_prob\/truediv_1:0&quot;\n|  |  Tensor(Placeholder):0,    shape=[]    &quot;values_1:0&quot;\n|  |  Tensor(Exp):0,    shape=[]    &quot;exp_1\/forward\/Exp:0&quot;\n|  |  |  ...\nTensor(GatherV2):0, shape=[919] &quot;GatherV2_1_1:0&quot;\n|  Tensor(AddV2):0, shape=[85]  &quot;AddV2:0&quot;\n|  |  Tensor(Mul):0,    shape=[85]  &quot;Mul_4:0&quot;\n|  |  |  Tensor(Exp):0, shape=[]    &quot;exp_2_1\/forward\/Exp:0&quot;\n|  |  |  |  Tensor(Placeholder):0,  shape=[]    &quot;values_5:0&quot;\n|  |  |  Tensor(Placeholder):0, shape=[85]  &quot;values_2:0&quot;\n|  |  Tensor(Placeholder):0,    shape=[]    &quot;values_4:0&quot;\n|  Tensor(Const):0, shape=[919] &quot;GatherV2\/indices:0&quot;\n|  |  [ 0  0  0 ... 
83 84 84]\n|  Tensor(Const):0, shape=[]    &quot;GatherV2\/axis:0&quot;\n|  |  0\nTensor(SquaredDifference):0,    shape=[919] &quot;Normal_4_1\/log_prob\/SquaredDifference_1:0&quot;\n|  Tensor(RealDiv):0,   shape=[919] &quot;Normal_4_1\/log_prob\/truediv:0&quot;\n|  |  Tensor(Const):0,  shape=[919] &quot;Normal_4_1\/log_prob\/value:0&quot;\n|  |  |  [0.8329091 0.8329091 1.0986123 ... 1.6292405 1.3350011 1.0986123]\n|  |  Tensor(Exp):0,    shape=[]    &quot;exp_3_1\/forward\/Exp:0&quot;\n|  |  |  Tensor(Placeholder):0, shape=[]    &quot;values_6:0&quot;\n|  Tensor(RealDiv):0,   shape=[919] &quot;Normal_4_1\/log_prob\/truediv_1_1:0&quot;\n|  |  Tensor(AddV2):0,  shape=[919] &quot;add_12:0&quot;\n|  |  |  Tensor(GatherV2):0,    shape=[919] &quot;GatherV2:0&quot;\n|  |  |  |  Tensor(Placeholder):0,  shape=[85]  &quot;values_3:0&quot;\n|  |  |  |  Tensor(Const):0,    shape=[919] &quot;GatherV2\/indices:0&quot;\n|  |  |  |  |  [ 0  0  0 ... 83 84 84]\n|  |  |  |  Tensor(Const):0,    shape=[]    &quot;GatherV2\/axis:0&quot;\n|  |  |  |  |  0\n|  |  |  Tensor(Mul):0, shape=[919] &quot;mul_5:0&quot;\n|  |  |  |  Tensor(GatherV2):0, shape=[919] &quot;GatherV2_1_1:0&quot;\n|  |  |  |  |  Tensor(AddV2):0, shape=[85]  &quot;AddV2:0&quot;\n|  |  |  |  |  |  Tensor(Mul):0,    shape=[85]  &quot;Mul_4:0&quot;\n|  |  |  |  |  |  |  Tensor(Exp):0, shape=[]    &quot;exp_2_1\/forward\/Exp:0&quot;\n|  |  |  |  |  |  |  |  Tensor(Placeholder):0,  shape=[]    &quot;values_5:0&quot;\n|  |  |  |  |  |  |  Tensor(Placeholder):0, shape=[85]  &quot;values_2:0&quot;\n|  |  |  |  |  |  Tensor(Placeholder):0,    shape=[]    &quot;values_4:0&quot;\n|  |  |  |  |  Tensor(Const):0, shape=[919] &quot;GatherV2\/indices:0&quot;\n|  |  |  |  |  |  [ 0  0  0 ... 83 84 84]\n|  |  |  |  |  Tensor(Const):0, shape=[]    &quot;GatherV2\/axis:0&quot;\n|  |  |  |  |  |  0\n|  |  |  |  Tensor(Const):0,    shape=[919] &quot;mul\/y:0&quot;\n|  |  |  |  |  [1. 0. 0. ... 0. 0. 
0.]\n|  |  Tensor(Exp):0,    shape=[]    &quot;exp_3_1\/forward\/Exp:0&quot;\n|  |  |  ...\n\n<\/code><\/pre>\n<\/figure>\n<\/section>\n<section id=\"creating-a-new-log-likelihood-function\" class=\"level1\">\n<h1>Creating a new Log-likelihood Function<\/h1>\n<p>Now that we have a transformed version of the original log-likelihood graph (i.e.\u00a0<code>logpfn_trans_tf<\/code>), we need to create a new <code>FuncGraph<\/code> from it. Listing <a href=\"#org051a679\">48<\/a> provides a simple function that creates a new <code>ConcreteFunction<\/code> from an updated output node.<\/p>\n<figure id=\"org051a679\">\n<div class=\"sourceCode\" id=\"cb48\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb48-1\"><a href=\"#cb48-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> tensorflow.python.framework.func_graph <span class=\"im\">import<\/span> FuncGraph<\/span>\n<span id=\"cb48-2\"><a href=\"#cb48-2\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> tensorflow.python.eager.function <span class=\"im\">import<\/span> ConcreteFunction<\/span>\n<span id=\"cb48-3\"><a href=\"#cb48-3\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> tensorflow.python.eager.lift_to_graph <span class=\"im\">import<\/span> lift_to_graph<\/span>\n<span id=\"cb48-4\"><a href=\"#cb48-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb48-5\"><a href=\"#cb48-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb48-6\"><a href=\"#cb48-6\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> new_tf_function(output, orig_cf):<\/span>\n<span id=\"cb48-7\"><a href=\"#cb48-7\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;Create a new ConcreteFunction by replacing a single output in an existing FuncGraph.<\/span><\/span>\n<span id=\"cb48-8\"><a href=\"#cb48-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb48-9\"><a href=\"#cb48-9\" aria-hidden=\"true\"><\/a><span class=\"co\">    
&quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb48-10\"><a href=\"#cb48-10\" aria-hidden=\"true\"><\/a>    orig_fg <span class=\"op\">=<\/span> orig_cf.graph<\/span>\n<span id=\"cb48-11\"><a href=\"#cb48-11\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># with trans_graph.as_default(): #orig_fg.as_default():<\/span><\/span>\n<span id=\"cb48-12\"><a href=\"#cb48-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb48-13\"><a href=\"#cb48-13\" aria-hidden=\"true\"><\/a>    logpfn_fg_new <span class=\"op\">=<\/span> FuncGraph(<span class=\"st\">&#39;logpfn_new&#39;<\/span>, orig_fg.collections, orig_fg.capture_by_value)<\/span>\n<span id=\"cb48-14\"><a href=\"#cb48-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb48-15\"><a href=\"#cb48-15\" aria-hidden=\"true\"><\/a>    old_to_new_ops <span class=\"op\">=<\/span> lift_to_graph([output],<\/span>\n<span id=\"cb48-16\"><a href=\"#cb48-16\" aria-hidden=\"true\"><\/a>                                    logpfn_fg_new,<\/span>\n<span id=\"cb48-17\"><a href=\"#cb48-17\" aria-hidden=\"true\"><\/a>                                    add_sources<span class=\"op\">=<\/span><span class=\"va\">True<\/span>,<\/span>\n<span id=\"cb48-18\"><a href=\"#cb48-18\" aria-hidden=\"true\"><\/a>                                    handle_captures<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb48-19\"><a href=\"#cb48-19\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb48-20\"><a href=\"#cb48-20\" aria-hidden=\"true\"><\/a>    logpfn_fg_new.structured_input_signature <span class=\"op\">=<\/span> orig_fg.structured_input_signature<\/span>\n<span id=\"cb48-21\"><a href=\"#cb48-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb48-22\"><a href=\"#cb48-22\" aria-hidden=\"true\"><\/a>    new_inputs <span class=\"op\">=<\/span> [old_to_new_ops.get(output.graph.get_operation_by_name(i.name).outputs[<span class=\"dv\">0<\/span>])<\/span>\n<span id=\"cb48-23\"><a href=\"#cb48-23\" 
aria-hidden=\"true\"><\/a>                  <span class=\"cf\">for<\/span> i <span class=\"kw\">in<\/span> orig_cf.structured_input_signature[<span class=\"dv\">0<\/span>]]<\/span>\n<span id=\"cb48-24\"><a href=\"#cb48-24\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb48-25\"><a href=\"#cb48-25\" aria-hidden=\"true\"><\/a>    logpfn_fg_new.inputs <span class=\"op\">=<\/span> new_inputs<\/span>\n<span id=\"cb48-26\"><a href=\"#cb48-26\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb48-27\"><a href=\"#cb48-27\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">assert<\/span> <span class=\"bu\">all<\/span>(i <span class=\"kw\">is<\/span> <span class=\"kw\">not<\/span> <span class=\"va\">None<\/span> <span class=\"cf\">for<\/span> i <span class=\"kw\">in<\/span> logpfn_fg_new.inputs)<\/span>\n<span id=\"cb48-28\"><a href=\"#cb48-28\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb48-29\"><a href=\"#cb48-29\" aria-hidden=\"true\"><\/a>    logpfn_fg_new.outputs <span class=\"op\">=<\/span> [old_to_new_ops[output]]<\/span>\n<span id=\"cb48-30\"><a href=\"#cb48-30\" aria-hidden=\"true\"><\/a>    logpfn_fg_new.structured_outputs <span class=\"op\">=<\/span> logpfn_fg_new.outputs[<span class=\"dv\">0<\/span>]<\/span>\n<span id=\"cb48-31\"><a href=\"#cb48-31\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb48-32\"><a href=\"#cb48-32\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">assert<\/span> logpfn_fg_new.as_graph_element(logpfn_fg_new.outputs[<span class=\"dv\">0<\/span>]) <span class=\"kw\">is<\/span> <span class=\"kw\">not<\/span> <span class=\"va\">None<\/span><\/span>\n<span id=\"cb48-33\"><a href=\"#cb48-33\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb48-34\"><a href=\"#cb48-34\" aria-hidden=\"true\"><\/a>    logpfn_new_cf <span class=\"op\">=<\/span> ConcreteFunction(logpfn_fg_new)<\/span>\n<span id=\"cb48-35\"><a href=\"#cb48-35\" aria-hidden=\"true\"><\/a>    logpfn_new_cf._arg_keywords <span class=\"op\">=<\/span> 
orig_cf._arg_keywords<\/span>\n<span id=\"cb48-36\"><a href=\"#cb48-36\" aria-hidden=\"true\"><\/a>    logpfn_new_cf._num_positional_args <span class=\"op\">=<\/span> <span class=\"bu\">len<\/span>(logpfn_fg_new.inputs)<\/span>\n<span id=\"cb48-37\"><a href=\"#cb48-37\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb48-38\"><a href=\"#cb48-38\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> logpfn_new_cf<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 48\n<\/figcaption>\n<\/figure>\n<figure id=\"orge94e0bc\">\n<div class=\"sourceCode\" id=\"cb49\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb49-1\"><a href=\"#cb49-1\" aria-hidden=\"true\"><\/a>logpfn_new_cf <span class=\"op\">=<\/span> new_tf_function(logpfn_trans_tf, logpfn_cf)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 49\n<\/figcaption>\n<\/figure>\n<p>The new TF function, <code>logpfn_new_cf<\/code>, created in Listing <a href=\"#orge94e0bc\">49<\/a>, is the function we will use to sample from the new log-likelihood.<\/p>\n<figure id=\"org7bec3c9\">\n<div class=\"sourceCode\" id=\"cb50\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb50-1\"><a href=\"#cb50-1\" aria-hidden=\"true\"><\/a>_ <span class=\"op\">=<\/span> logpfn_cf(<span class=\"op\">*<\/span>init.values()) <span class=\"op\">-<\/span> logpfn_new_cf(<span class=\"op\">*<\/span>init.values())<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 50\n<\/figcaption>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" id=\"cb51\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb51-1\"><a href=\"#cb51-1\" aria-hidden=\"true\"><\/a>tf.Tensor(<span class=\"fl\">153.41016<\/span>, shape<span class=\"op\">=<\/span>(), dtype<span class=\"op\">=<\/span>float32)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<p>Listing <a href=\"#org7bec3c9\">50<\/a> shows the difference between the original and transformed log-likelihood values 
given the same inputs.<\/p>\n<\/section>\n<section id=\"sampling-from-the-new-log-likelihood\" class=\"level1\">\n<h1>Sampling from the new Log-likelihood<\/h1>\n<p>In Listing <a href=\"#org4a79807\">53<\/a>, we reproduce the remaining steps of <code>pm.inference.sampling.sample<\/code> and\u2013unnaturally\u2013force the PyMC4 machinery to draw samples from our new transformed log-likelihood function.<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb52\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb52-1\"><a href=\"#cb52-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> contextlib <span class=\"im\">import<\/span> contextmanager<\/span>\n<span id=\"cb52-2\"><a href=\"#cb52-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb52-3\"><a href=\"#cb52-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb52-4\"><a href=\"#cb52-4\" aria-hidden=\"true\"><\/a><span class=\"co\"># We need to create new initial values for our transformed variables.<\/span><\/span>\n<span id=\"cb52-5\"><a href=\"#cb52-5\" aria-hidden=\"true\"><\/a>new_val_map <span class=\"op\">=<\/span> {}<\/span>\n<span id=\"cb52-6\"><a href=\"#cb52-6\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> logpfn_remap <span class=\"kw\">in<\/span> logpfn_remaps:<\/span>\n<span id=\"cb52-7\"><a href=\"#cb52-7\" aria-hidden=\"true\"><\/a>    transed_var <span class=\"op\">=<\/span> logpfn_remap[<span class=\"dv\">0<\/span>][<span class=\"dv\">0<\/span>].reify()<\/span>\n<span id=\"cb52-8\"><a href=\"#cb52-8\" aria-hidden=\"true\"><\/a>    transed_var_pymc_name <span class=\"op\">=<\/span> tfp_names_to_pymc[transed_var.op.name]<\/span>\n<span id=\"cb52-9\"><a href=\"#cb52-9\" aria-hidden=\"true\"><\/a>    old_val_np <span class=\"op\">=<\/span> init[transed_var_pymc_name].numpy()<\/span>\n<span id=\"cb52-10\"><a href=\"#cb52-10\" aria-hidden=\"true\"><\/a>    new_val_np <span class=\"op\">=<\/span> 
np.random.standard_normal(old_val_np.shape).astype(old_val_np.dtype)<\/span>\n<span id=\"cb52-11\"><a href=\"#cb52-11\" aria-hidden=\"true\"><\/a>    new_val_map[transed_var_pymc_name] <span class=\"op\">=<\/span> tf.convert_to_tensor(new_val_np)<\/span>\n<span id=\"cb52-12\"><a href=\"#cb52-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb52-13\"><a href=\"#cb52-13\" aria-hidden=\"true\"><\/a>new_init <span class=\"op\">=<\/span> init.copy()<\/span>\n<span id=\"cb52-14\"><a href=\"#cb52-14\" aria-hidden=\"true\"><\/a>new_init.update(new_val_map)<\/span>\n<span id=\"cb52-15\"><a href=\"#cb52-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb52-16\"><a href=\"#cb52-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb52-17\"><a href=\"#cb52-17\" aria-hidden=\"true\"><\/a><span class=\"at\">@contextmanager<\/span><\/span>\n<span id=\"cb52-18\"><a href=\"#cb52-18\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> pymc4_force_logp(logpfn_new_cf, new_init):<\/span>\n<span id=\"cb52-19\"><a href=\"#cb52-19\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;Temporarily fix the logp function and init values used by PyMC4&#39;s sampler.&quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb52-20\"><a href=\"#cb52-20\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb52-21\"><a href=\"#cb52-21\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> _new_build_logp_function(<span class=\"op\">*<\/span>args, <span class=\"op\">**<\/span>kwargs):<\/span>\n<span id=\"cb52-22\"><a href=\"#cb52-22\" aria-hidden=\"true\"><\/a>        <span class=\"kw\">nonlocal<\/span> logpfn_new_cf, new_init<\/span>\n<span id=\"cb52-23\"><a href=\"#cb52-23\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> logpfn_new_cf, new_init<\/span>\n<span id=\"cb52-24\"><a href=\"#cb52-24\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb52-25\"><a href=\"#cb52-25\" aria-hidden=\"true\"><\/a>    _old_fn <span class=\"op\">=<\/span> 
pm.inference.sampling.build_logp_function<\/span>\n<span id=\"cb52-26\"><a href=\"#cb52-26\" aria-hidden=\"true\"><\/a>    pm.inference.sampling.build_logp_function <span class=\"op\">=<\/span> _new_build_logp_function<\/span>\n<span id=\"cb52-27\"><a href=\"#cb52-27\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb52-28\"><a href=\"#cb52-28\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">try<\/span>:<\/span>\n<span id=\"cb52-29\"><a href=\"#cb52-29\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">yield<\/span><\/span>\n<span id=\"cb52-30\"><a href=\"#cb52-30\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">finally<\/span>:<\/span>\n<span id=\"cb52-31\"><a href=\"#cb52-31\" aria-hidden=\"true\"><\/a>        pm.inference.sampling.build_logp_function <span class=\"op\">=<\/span> _old_fn<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure id=\"org4a79807\">\n<div class=\"sourceCode\" id=\"cb53\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb53-1\"><a href=\"#cb53-1\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> pymc4_force_logp(logpfn_new_cf, new_init):<\/span>\n<span id=\"cb53-2\"><a href=\"#cb53-2\" aria-hidden=\"true\"><\/a>    az_trace <span class=\"op\">=<\/span> sample(model)<\/span><\/code><\/pre><\/div>\n<figcaption>\nListing 53\n<\/figcaption>\n<\/figure>\n<figure id=\"fig:transformed-model-plot-energy\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/transformed-model-plot-energy.png\" title=\"fig:\" alt=\"\" \/>\n<figcaption>\n<\/figcaption>\n<\/figure>\n<figure id=\"fig:transformed-model-plot-trace\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/transformed-model-plot-trace.png\" title=\"fig:\" alt=\"\" \/>\n<figcaption>\n<\/figcaption>\n<\/figure>\n<\/section>\n<section id=\"discussion\" class=\"level1\">\n<h1>Discussion<\/h1>\n<p>The goals in the two separate <code>run<\/code> calls we used in Listing <a href=\"#org0ad3a96\">23<\/a> could have 
been combined into a single <code>run<\/code>. This could\u2019ve been accomplished using some \u201cmeta\u201d steps (e.g.\u00a0constructing and evaluating a goal on-the-fly within miniKanren) or special goals for reading from miniKanren-generated <code>dict<\/code>s or association lists. Goals of this nature are not uncommon (e.g.\u00a0type inference and inhabitation examples), and they demonstrate the great breadth of activity possible within the relational context of miniKanren.<\/p>\n<p>However, the point we want to make doesn\u2019t require much sophistication. Instead, we wanted to demonstrate how a non-trivial \u201cpattern\u201d can be specified and matched using <code>symbolic-pymc<\/code>, and how easily those results could be used to transform a graph.<\/p>\n<p>More specifically, our goal <code>shift_squared_subso<\/code> in Listing <a href=\"#org0ad3a96\">23<\/a> demonstrates <strong>the way in which we were able to specify desired structure(s) within a graph<\/strong>. We defined one pattern, <code>Y_sqrdiffo<\/code>, to match anywhere in the graph, then another pattern, <code>X_sqrdiffo<\/code>, that relied on matched terms from <code>Y_sqrdiffo<\/code> and could also be matched\/found anywhere else in the same graph.<\/p>\n<p>Furthermore, our substitutions needed information from both \u201cmatched\u201d subgraphs; specifically, they required substitution pairs similar to <code>(x, z + x)<\/code>. Within this framework, we could just as easily have included <code>y<\/code>, or any terms from either successfully matched subgraph, in the substitution expressions.<\/p>\n<p>In sample-space, the search patterns and substitutions are much easier to specify, exactly because they\u2019re single-subgraph patterns that are themselves the subgraphs to be replaced (i.e.\u00a0if we find a non-standard normal, we replace it with a shifted\/scaled standard normal). 
In log-space, we chose to find distinct subgraph \u201cchains\u201d, i.e.\u00a0all <code>(y - x)**2<\/code> and <code>(x - z)**2<\/code> pairs (i.e.\u00a0\u201cconnected\u201d by an \u201cunknown\u201d term <code>x<\/code>), since these are produced by the log-likelihood form of hierarchical normal distributions.<\/p>\n<p>As a result, we had a non-trivial structure\/\u201cpattern\u201d to express, and execute. Using conventional graph search-and-replace functionality would\u2019ve required much more orchestration and resulted in considerably less flexible code with little-to-no reusability. In our case, the goals <code>onceo<\/code> and <code>tf_graph_applyo<\/code> are universal, and the forms in <code>shift_squared_subso<\/code> can be easily changed to account for more sophisticated (or entirely distinct) patterns and substitutions.<\/p>\n<p>Most related graph manipulation offerings make it easy to find a single subgraph that matches a pattern, but not potentially \u201cco-dependent\u201d and\/or distinct subgraphs. In the end, the developer will often have to manually implement a \u201cglobal\u201d state and orchestrate multiple single-subgraph searches and their results.<\/p>\n<p>For single search-and-replace objectives, this amount of manual developer intervention\/orchestration might be excusable; however, for objectives requiring the evaluation of multiple graph transformations, this approach is mostly unmaintainable and extremely difficult to compartmentalize.<\/p>\n<p>This demonstration barely scratches the surface of what\u2019s possible using miniKanren and relational programming for graph manipulation and symbolic statistical model optimization. 
As the <code>symbolic-pymc<\/code> project advances, we\u2019ll cover examples in which miniKanren\u2019s more distinct offerings are demonstrated.<\/p>\n<\/section>\n<\/body>\n<\/html>\n","category":[{"@attributes":{"term":"articles"}},{"@attributes":{"term":"pymc4"}},{"@attributes":{"term":"tensorflow"}},{"@attributes":{"term":"symbolic computation"}},{"@attributes":{"term":"python"}},{"@attributes":{"term":"symbolic-pymc"}}]},{"title":"Random Variables in Theano","link":{"@attributes":{"href":"https:\/\/brandonwillard.github.io\/random-variables-in-theano.html","rel":"alternate"}},"published":"2018-12-28T00:00:00-06:00","updated":"2019-02-04T00:00:00-06:00","author":{"name":"Brandon T. Willard"},"id":"tag:brandonwillard.github.io,2018-12-28:\/random-variables-in-theano.html","summary":{"@attributes":{"type":"html"}},"content":"<!DOCTYPE html PUBLIC \"-\/\/W3C\/\/DTD XHTML 1.0 Transitional\/\/EN\" \"http:\/\/www.w3.org\/TR\/xhtml1\/DTD\/xhtml1-transitional.dtd\">\n<html xmlns=\"http:\/\/www.w3.org\/1999\/xhtml\">\n<head>\n  <meta http-equiv=\"Content-Type\" content=\"text\/html; charset=utf-8\" \/>\n  <meta http-equiv=\"Content-Style-Type\" content=\"text\/css\" \/>\n  <meta name=\"generator\" content=\"pandoc\" \/>\n  <meta name=\"author\" content=\"Brandon T. 
Willard\" \/>\n  <title>Random Variables in Theano<\/title>\n  <style type=\"text\/css\">code{white-space: pre;}<\/style>\n  <style type=\"text\/css\">\npre > code.sourceCode { white-space: pre; position: relative; }\npre > code.sourceCode > span { display: inline-block; line-height: 1.25; }\npre > code.sourceCode > span:empty { height: 1.2em; }\ncode.sourceCode > span { color: inherit; text-decoration: inherit; }\ndiv.sourceCode { margin: 1em 0; }\npre.sourceCode { margin: 0; }\n@media screen {\ndiv.sourceCode { overflow: auto; }\n}\n@media print {\npre > code.sourceCode { white-space: pre-wrap; }\npre > code.sourceCode > span { text-indent: -5em; padding-left: 5em; }\n}\npre.numberSource code\n  { counter-reset: source-line 0; }\npre.numberSource code > span\n  { position: relative; left: -4em; counter-increment: source-line; }\npre.numberSource code > span > a:first-child::before\n  { content: counter(source-line);\n    position: relative; left: -1em; text-align: right; vertical-align: baseline;\n    border: none; display: inline-block;\n    -webkit-touch-callout: none; -webkit-user-select: none;\n    -khtml-user-select: none; -moz-user-select: none;\n    -ms-user-select: none; user-select: none;\n    padding: 0 4px; width: 4em;\n    color: #aaaaaa;\n  }\npre.numberSource { margin-left: 3em; border-left: 1px solid #aaaaaa;  padding-left: 4px; }\ndiv.sourceCode\n  {   }\n@media screen {\npre > code.sourceCode > span > a:first-child::before { text-decoration: underline; }\n}\ncode span.al { color: #ff0000; font-weight: bold; } \/* Alert *\/\ncode span.an { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Annotation *\/\ncode span.at { color: #7d9029; } \/* Attribute *\/\ncode span.bn { color: #40a070; } \/* BaseN *\/\ncode span.bu { } \/* BuiltIn *\/\ncode span.cf { color: #007020; font-weight: bold; } \/* ControlFlow *\/\ncode span.ch { color: #4070a0; } \/* Char *\/\ncode span.cn { color: #880000; } \/* Constant *\/\ncode span.co { color: #60a0b0; 
font-style: italic; } \/* Comment *\/\ncode span.cv { color: #60a0b0; font-weight: bold; font-style: italic; } \/* CommentVar *\/\ncode span.do { color: #ba2121; font-style: italic; } \/* Documentation *\/\ncode span.dt { color: #902000; } \/* DataType *\/\ncode span.dv { color: #40a070; } \/* DecVal *\/\ncode span.er { color: #ff0000; font-weight: bold; } \/* Error *\/\ncode span.ex { } \/* Extension *\/\ncode span.fl { color: #40a070; } \/* Float *\/\ncode span.fu { color: #06287e; } \/* Function *\/\ncode span.im { } \/* Import *\/\ncode span.in { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Information *\/\ncode span.kw { color: #007020; font-weight: bold; } \/* Keyword *\/\ncode span.op { color: #666666; } \/* Operator *\/\ncode span.ot { color: #007020; } \/* Other *\/\ncode span.pp { color: #bc7a00; } \/* Preprocessor *\/\ncode span.sc { color: #4070a0; } \/* SpecialChar *\/\ncode span.ss { color: #bb6688; } \/* SpecialString *\/\ncode span.st { color: #4070a0; } \/* String *\/\ncode span.va { color: #19177c; } \/* Variable *\/\ncode span.vs { color: #4070a0; } \/* VerbatimString *\/\ncode span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Warning *\/\n  <\/style>\n  <!--        <script src=\"https:\/\/cdn.jsdelivr.net\/npm\/mathjax@3\/es5\/tex-mml-chtml.js\" type=\"text\/javascript\"><\/script> -->\n  <script src=\"https:\/\/cdnjs.cloudflare.com\/ajax\/libs\/mathjax\/2.7.0\/MathJax.js?config=TeX-AMS_HTML\" id=\"MathJax-script\"><\/script>\n  <script>\n   MathJax.Hub.Config({\n       tex2jax: {\n           processEnvironments: true,\n           processRefs: false\n       },\n       TeX: {\n           equationNumbers: { autoNumber: \"AMS\" },\n           extensions: [\"AMSmath.js\",\"AMSsymbols.js\",\"noErrors.js\",\"noUndefined.js\"]\n       }\n   });\n  <\/script>\n<\/head>\n<body>\n<!--  -->\n<!-- <div id=\"header\"> -->\n<!-- <h1 class=\"title\">Random Variables in Theano<\/h1> -->\n<!--  -->\n<!--  -->\n<!-- <h2 
class=\"author\">Brandon T. Willard<\/h2> -->\n<!--  -->\n<!--  -->\n<!-- <h3 class=\"date\">2018\u201312\u201328<\/h3> -->\n<!--  -->\n<!-- <\/div> -->\n<!--  -->\n<div class=\"abstract\">\n<p>Continuing from <a href=\"#24875a2c31fa7f94ce562adddedc0bf8\">Willard, Brandon T. (2018)<\/a>, we\u2019ll attempt to improve upon <code>RandomFunction<\/code> and make a case for a similar <code>Op<\/code> in PyMC3.<\/p>\n<\/div>\n<section id=\"introduction\" class=\"level1\">\n<h1>Introduction<\/h1>\n<p>We\u2019ll call the new <code>Op<\/code> developed here <code>RandomVariable<\/code>, since random variables are the abstraction we\u2019re primarily targeting. <code>RandomVariable<\/code> will provide the functionality of <code>Distribution<\/code>, <code>FreeRV<\/code> and <code>ObservedRV<\/code>, and, by working at the <code>Op<\/code> level, it will be much more capable of leveraging existing Theano functionality.<\/p>\n<p>Specifically, by using the <code>Op<\/code> interface, we\u2019re able to do the following:<\/p>\n<ol type=\"1\">\n<li><p>Remove the need for an explicitly specified shape parameter.<\/p>\n<div class=\"example\" data-markdown=\"\">\n<p>For example, definitions like<\/p>\n<div class=\"sourceCode\" id=\"org9f36fc6\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org9f36fc6-1\"><a href=\"#org9f36fc6-1\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> pm.Model():<\/span>\n<span id=\"org9f36fc6-2\"><a href=\"#org9f36fc6-2\" aria-hidden=\"true\"><\/a>    X_rv <span class=\"op\">=<\/span> pm.Normal(<span class=\"st\">&#39;X_rv&#39;<\/span>, mu_X, sd<span class=\"op\">=<\/span>sd_X, shape<span class=\"op\">=<\/span>(<span class=\"dv\">1<\/span>,))<\/span><\/code><\/pre><\/div>\n<p>reduce to<\/p>\n<div class=\"sourceCode\" id=\"orgf62e7a4\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgf62e7a4-1\"><a href=\"#orgf62e7a4-1\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> 
pm.Model():<\/span>\n<span id=\"orgf62e7a4-2\"><a href=\"#orgf62e7a4-2\" aria-hidden=\"true\"><\/a>    X_rv <span class=\"op\">=<\/span> pm.Normal(<span class=\"st\">&#39;X_rv&#39;<\/span>, mu_X, sd<span class=\"op\">=<\/span>sd_X)<\/span><\/code><\/pre><\/div>\n<\/div><\/li>\n<li><p>Random variable nodes created by an <code>Op<\/code> automatically implement <code>Distribution.default<\/code>\/<code>Distribution.get_test_val<\/code> functionality and remove the reliance on initial values during random variable instantiation. <code>Op<\/code> automatically uses <code>Op.perform<\/code>, which will draw a sample as a test value <strong>and<\/strong> propagate it throughout the graph to down-stream tensor variables.<\/p><\/li>\n<li><p>Log-densities can be generated as secondary outputs of <code>Op.make_node<\/code>, which removes the need for <code>Distribution.logp*<\/code> methods.<\/p><\/li>\n<li><p><code>pymc.distribution.draw_values<\/code> and related methods are no longer necessary; their functionality is already covered within Theano\u2019s existing graph machinery\u2013in the same way as <code>pymc.distribution.Distribution.default\/get_test_val<\/code>.<\/p><\/li>\n<\/ol>\n<p>The main points of entry in our <code>Op<\/code>, are <code>Op.make_node<\/code> and <code>Op.perform<\/code>. <code>Op.make_node<\/code> is used during symbolic graph creation and provides immediate access to the <code>Op<\/code>\u2019s symbolic inputs\u2013serving a purpose similar to <code>Distribution.__init__<\/code>. 
<code>Op.make_node<\/code> is where shape inference tasks (e.g.\u00a0<a href=\"https:\/\/github.com\/pymc-devs\/pymc3\/pull\/1125\">PyMC3 PR 1125<\/a>) are more suitably addressed; however, <code>Op<\/code> provides additional means of shape inference and management (e.g.\u00a0<code>Op.infer_shape<\/code>) occurring at different phases of graph compilation that aren\u2019t readily accessible outside of the <code>Op<\/code> framework.<\/p>\n<\/section>\n<section id=\"a-new-random-variable-op\" class=\"level1\">\n<h1>A <strong>new<\/strong> Random Variable <code>Op<\/code><\/h1>\n<div class=\"sourceCode\" id=\"org9c584cc\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org9c584cc-1\"><a href=\"#org9c584cc-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> sys<\/span>\n<span id=\"org9c584cc-2\"><a href=\"#org9c584cc-2\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> os<\/span>\n<span id=\"org9c584cc-3\"><a href=\"#org9c584cc-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org9c584cc-4\"><a href=\"#org9c584cc-4\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> pprint <span class=\"im\">import<\/span> pprint<\/span>\n<span id=\"org9c584cc-5\"><a href=\"#org9c584cc-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org9c584cc-6\"><a href=\"#org9c584cc-6\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> numpy <span class=\"im\">as<\/span> np<\/span>\n<span id=\"org9c584cc-7\"><a href=\"#org9c584cc-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org9c584cc-8\"><a href=\"#org9c584cc-8\" aria-hidden=\"true\"><\/a>os.environ[<span class=\"st\">&#39;MKL_THREADING_LAYER&#39;<\/span>] <span class=\"op\">=<\/span> <span class=\"st\">&#39;GNU&#39;<\/span><\/span>\n<span id=\"org9c584cc-9\"><a href=\"#org9c584cc-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org9c584cc-10\"><a href=\"#org9c584cc-10\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> theano<\/span>\n<span 
id=\"org9c584cc-11\"><a href=\"#org9c584cc-11\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> theano.tensor <span class=\"im\">as<\/span> tt<\/span>\n<span id=\"org9c584cc-12\"><a href=\"#org9c584cc-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org9c584cc-13\"><a href=\"#org9c584cc-13\" aria-hidden=\"true\"><\/a>theano.config.mode <span class=\"op\">=<\/span> <span class=\"st\">&#39;FAST_COMPILE&#39;<\/span><\/span>\n<span id=\"org9c584cc-14\"><a href=\"#org9c584cc-14\" aria-hidden=\"true\"><\/a>theano.config.exception_verbosity <span class=\"op\">=<\/span> <span class=\"st\">&#39;high&#39;<\/span><\/span>\n<span id=\"org9c584cc-15\"><a href=\"#org9c584cc-15\" aria-hidden=\"true\"><\/a><span class=\"co\"># <\/span><span class=\"al\">NOTE<\/span><span class=\"co\">: pymc3 requires test values<\/span><\/span>\n<span id=\"org9c584cc-16\"><a href=\"#org9c584cc-16\" aria-hidden=\"true\"><\/a>theano.config.compute_test_value <span class=\"op\">=<\/span> <span class=\"st\">&#39;warn&#39;<\/span><\/span>\n<span id=\"org9c584cc-17\"><a href=\"#org9c584cc-17\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org9c584cc-18\"><a href=\"#org9c584cc-18\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> pymc3 <span class=\"im\">as<\/span> pm<\/span><\/code><\/pre><\/div>\n<p>Most of the work involved in generalizing <code>RandomFunction<\/code> has to do with symbolic shape handling and inference. We need to bridge the gaps between symbolic array\/tensor broadcasting parameters and the way Numpy random variable functions allow distribution parameters to be specified.<\/p>\n<div class=\"example\" data-markdown=\"\">\n<p>Scalar normal random variates have a support and parameters with dimension zero. In Listing <a href=\"#org352d751\">4<\/a> we create a scalar normal random variate in Numpy and inspect its shape. 
The length of the shape corresponds to the dimension of the distribution\u2019s support (i.e.\u00a0zero).<\/p>\n<div class=\"sourceCode\" id=\"org352d751\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org352d751-1\"><a href=\"#org352d751-1\" aria-hidden=\"true\"><\/a>np.shape(np.random.normal(loc<span class=\"op\">=<\/span><span class=\"dv\">0<\/span>, scale<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>, size<span class=\"op\">=<\/span><span class=\"va\">None<\/span>))<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"org96612e2\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org96612e2-1\"><a href=\"#org96612e2-1\" aria-hidden=\"true\"><\/a>()<\/span><\/code><\/pre><\/div>\n<p>Numpy also allows one to specify <strong>independent<\/strong> normal variates using one function call with each variate\u2019s parameters spanning dimensions higher than the variate\u2019s. In Listing <a href=\"#orgbabaa48\">6<\/a> we specify three independent scalar normal variates, each with a different mean and scale parameter. 
This time, the result\u2019s shape reflects <strong>the number of independent random variates<\/strong>, and not the dimension of the underlying distribution\u2019s support.<\/p>\n<div class=\"sourceCode\" id=\"orgbabaa48\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgbabaa48-1\"><a href=\"#orgbabaa48-1\" aria-hidden=\"true\"><\/a>np.shape(np.random.normal(loc<span class=\"op\">=<\/span>[<span class=\"dv\">0<\/span>, <span class=\"dv\">1<\/span>, <span class=\"dv\">2<\/span>], scale<span class=\"op\">=<\/span>[<span class=\"dv\">1<\/span>, <span class=\"dv\">2<\/span>, <span class=\"dv\">3<\/span>], size<span class=\"op\">=<\/span><span class=\"va\">None<\/span>))<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"orgef69d1e\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgef69d1e-1\"><a href=\"#orgef69d1e-1\" aria-hidden=\"true\"><\/a>(<span class=\"dv\">3<\/span>,)<\/span><\/code><\/pre><\/div>\n<p>Distribution parameters can also be broadcast, as in Listing <a href=\"#org6643ee7\">8<\/a>. 
Now, each independent variate has the same scale value.<\/p>\n<div class=\"sourceCode\" id=\"org6643ee7\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org6643ee7-1\"><a href=\"#org6643ee7-1\" aria-hidden=\"true\"><\/a>np.shape(np.random.normal(loc<span class=\"op\">=<\/span>[<span class=\"dv\">0<\/span>, <span class=\"dv\">1<\/span>, <span class=\"dv\">2<\/span>], scale<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>, size<span class=\"op\">=<\/span><span class=\"va\">None<\/span>))<\/span><\/code><\/pre><\/div>\n<p>The <code>size<\/code> parameter effectively replicates variates, in line with the (potentially broadcast) distribution parameters.<\/p>\n<p>When bridging these Numpy functions with Theano, we have to adapt the underlying parameter\/shape logic of functions like <code>np.random.normal<\/code> to a scenario involving symbolic parameters and their symbolic shapes.<\/p>\n<p>For instance, in Theano a <strong>symbolic<\/strong> scalar\u2019s shape is represented in nearly the same way.<\/p>\n<div class=\"sourceCode\" id=\"orga6116d4\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orga6116d4-1\"><a href=\"#orga6116d4-1\" aria-hidden=\"true\"><\/a>test_scalar <span class=\"op\">=<\/span> tt.scalar()<\/span>\n<span id=\"orga6116d4-2\"><a href=\"#orga6116d4-2\" aria-hidden=\"true\"><\/a>test_scalar.shape.<span class=\"bu\">eval<\/span>({test_scalar: <span class=\"dv\">1<\/span>})<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"orgb9018eb\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgb9018eb-1\"><a href=\"#orgb9018eb-1\" aria-hidden=\"true\"><\/a>[]<\/span><\/code><\/pre><\/div>\n<p>This means that our proposed Theano adaptation of <code>np.random.normal<\/code>, let\u2019s call it <code>tt_normal<\/code>, should return the same result as Numpy in the case of scalars.<\/p>\n<p>What about <code>tt_normal(loc=tt.vector(), 
scale=tt.vector(), size=None)<\/code>? Since the inputs are purely symbolic, the resulting object\u2019s shape should be symbolic, too, but we should still know that this symbolic shape has length one. Just as in Listing <a href=\"#orgbabaa48\">6<\/a>, each corresponding element in the vector arguments of <code>tt_normal<\/code> is an independent variate; in the symbolic case, we might not know exactly how many of them there are yet, but we know that there\u2019s a vector\u2019s worth of them.<\/p>\n<p>How exactly do we get that information from Theano, though? The type produced by <code>tt.vector<\/code> has an <code>ndim<\/code> parameter that provides this. Furthermore, there is some (intermittent) functionality that allows one to iterate over shapes. Listing <a href=\"#org0d70518\">11<\/a> demonstrates this.<\/p>\n<div class=\"sourceCode\" id=\"org0d70518\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org0d70518-1\"><a href=\"#org0d70518-1\" aria-hidden=\"true\"><\/a>test_matrix <span class=\"op\">=<\/span> tt.matrix()<\/span>\n<span id=\"org0d70518-2\"><a href=\"#org0d70518-2\" aria-hidden=\"true\"><\/a>shape_parts <span class=\"op\">=<\/span> <span class=\"bu\">tuple<\/span>(test_matrix.shape)<\/span>\n<span id=\"org0d70518-3\"><a href=\"#org0d70518-3\" aria-hidden=\"true\"><\/a>shape_parts<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"orgd865db6\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgd865db6-1\"><a href=\"#orgd865db6-1\" aria-hidden=\"true\"><\/a>(Subtensor{int64}<span class=\"fl\">.0<\/span>, Subtensor{int64}<span class=\"fl\">.0<\/span>)<\/span><\/code><\/pre><\/div>\n<p>When the matrix in Listing <a href=\"#org0d70518\">11<\/a> is \u201cmaterialized\u201d (i.e.\u00a0given a value), its corresponding shape object, and its components, will take their respective values.<\/p>\n<div class=\"sourceCode\" id=\"org4d51e03\"><pre 
class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org4d51e03-1\"><a href=\"#org4d51e03-1\" aria-hidden=\"true\"><\/a><span class=\"bu\">tuple<\/span>(p.<span class=\"bu\">eval<\/span>({test_matrix: np.diag([<span class=\"dv\">1<\/span>, <span class=\"dv\">2<\/span>])}) <span class=\"cf\">for<\/span> p <span class=\"kw\">in<\/span> shape_parts)<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"orgc53dece\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgc53dece-1\"><a href=\"#orgc53dece-1\" aria-hidden=\"true\"><\/a>(array(<span class=\"dv\">2<\/span>), array(<span class=\"dv\">2<\/span>))<\/span><\/code><\/pre><\/div>\n<p>If we knew that the support of this distribution was a scalar\/vector\/matrix, then these <code>ndim<\/code>-related results\u2013obtained from the symbolic parameters\u2013would tell us that we have multiple, independent variates and we could reliably extract the symbolic variables corresponding to those actual dimension sizes.<\/p>\n<\/div>\n<p>To determine the shape parts (i.e.\u00a0support, number of independent and replicated variates) of the symbolic random variables, we mimic the corresponding Numpy logic and use the Theano <code>ndim<\/code> shape information described above. 
The following function generalizes that work for many simple distributions.<\/p>\n<div class=\"sourceCode\" id=\"org03297c0\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org03297c0-1\"><a href=\"#org03297c0-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> collections.abc <span class=\"im\">import<\/span> Iterable, ByteString<\/span>\n<span id=\"org03297c0-2\"><a href=\"#org03297c0-2\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> warnings <span class=\"im\">import<\/span> warn<\/span>\n<span id=\"org03297c0-3\"><a href=\"#org03297c0-3\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> copy <span class=\"im\">import<\/span> copy<\/span>\n<span id=\"org03297c0-4\"><a href=\"#org03297c0-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org03297c0-5\"><a href=\"#org03297c0-5\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano.tensor.raw_random <span class=\"im\">import<\/span> (RandomFunction, RandomStateType,<\/span>\n<span id=\"org03297c0-6\"><a href=\"#org03297c0-6\" aria-hidden=\"true\"><\/a>                                      _infer_ndim_bcast)<\/span>\n<span id=\"org03297c0-7\"><a href=\"#org03297c0-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org03297c0-8\"><a href=\"#org03297c0-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org03297c0-9\"><a href=\"#org03297c0-9\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> param_supp_shape_fn(ndim_supp, ndims_params, dist_params,<\/span>\n<span id=\"org03297c0-10\"><a href=\"#org03297c0-10\" aria-hidden=\"true\"><\/a>                        rep_param_idx<span class=\"op\">=<\/span><span class=\"dv\">0<\/span>, param_shapes<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"org03297c0-11\"><a href=\"#org03297c0-11\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;A function for deriving a random variable&#39;s support shape\/dimensions<\/span><\/span>\n<span 
id=\"org03297c0-12\"><a href=\"#org03297c0-12\" aria-hidden=\"true\"><\/a><span class=\"co\">    from one of its parameters.<\/span><\/span>\n<span id=\"org03297c0-13\"><a href=\"#org03297c0-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org03297c0-14\"><a href=\"#org03297c0-14\" aria-hidden=\"true\"><\/a><span class=\"co\">    XXX: It&#39;s not always possible to determine a random variable&#39;s support<\/span><\/span>\n<span id=\"org03297c0-15\"><a href=\"#org03297c0-15\" aria-hidden=\"true\"><\/a><span class=\"co\">    shape from its parameters, so this function has fundamentally limited<\/span><\/span>\n<span id=\"org03297c0-16\"><a href=\"#org03297c0-16\" aria-hidden=\"true\"><\/a><span class=\"co\">    applicability.<\/span><\/span>\n<span id=\"org03297c0-17\"><a href=\"#org03297c0-17\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org03297c0-18\"><a href=\"#org03297c0-18\" aria-hidden=\"true\"><\/a><span class=\"co\">    XXX: This function is not expected to handle `ndim_supp = 0` (i.e.<\/span><\/span>\n<span id=\"org03297c0-19\"><a href=\"#org03297c0-19\" aria-hidden=\"true\"><\/a><span class=\"co\">    scalars), since that is already definitively handled in the `Op` that<\/span><\/span>\n<span id=\"org03297c0-20\"><a href=\"#org03297c0-20\" aria-hidden=\"true\"><\/a><span class=\"co\">    calls this.<\/span><\/span>\n<span id=\"org03297c0-21\"><a href=\"#org03297c0-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org03297c0-22\"><a href=\"#org03297c0-22\" aria-hidden=\"true\"><\/a><span class=\"co\">    <\/span><span class=\"al\">TODO<\/span><span class=\"co\">: Consider using `theano.compile.ops.shape_i` alongside `ShapeFeature`.<\/span><\/span>\n<span id=\"org03297c0-23\"><a href=\"#org03297c0-23\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org03297c0-24\"><a href=\"#org03297c0-24\" aria-hidden=\"true\"><\/a><span class=\"co\">    Parameters<\/span><\/span>\n<span id=\"org03297c0-25\"><a href=\"#org03297c0-25\" 
aria-hidden=\"true\"><\/a><span class=\"co\">    ==========<\/span><\/span>\n<span id=\"org03297c0-26\"><a href=\"#org03297c0-26\" aria-hidden=\"true\"><\/a><span class=\"co\">    ndim_supp: int<\/span><\/span>\n<span id=\"org03297c0-27\"><a href=\"#org03297c0-27\" aria-hidden=\"true\"><\/a><span class=\"co\">        Total number of dimensions in the support (assumedly &gt; 0).<\/span><\/span>\n<span id=\"org03297c0-28\"><a href=\"#org03297c0-28\" aria-hidden=\"true\"><\/a><span class=\"co\">    ndims_params: list of int<\/span><\/span>\n<span id=\"org03297c0-29\"><a href=\"#org03297c0-29\" aria-hidden=\"true\"><\/a><span class=\"co\">        Number of dimensions for each distribution parameter.<\/span><\/span>\n<span id=\"org03297c0-30\"><a href=\"#org03297c0-30\" aria-hidden=\"true\"><\/a><span class=\"co\">    dist_params: list of `theano.gof.graph.Variable`<\/span><\/span>\n<span id=\"org03297c0-31\"><a href=\"#org03297c0-31\" aria-hidden=\"true\"><\/a><span class=\"co\">        The distribution parameters.<\/span><\/span>\n<span id=\"org03297c0-32\"><a href=\"#org03297c0-32\" aria-hidden=\"true\"><\/a><span class=\"co\">    param_shapes: list of `theano.compile.ops.Shape` (optional)<\/span><\/span>\n<span id=\"org03297c0-33\"><a href=\"#org03297c0-33\" aria-hidden=\"true\"><\/a><span class=\"co\">        Symbolic shapes for each distribution parameter.<\/span><\/span>\n<span id=\"org03297c0-34\"><a href=\"#org03297c0-34\" aria-hidden=\"true\"><\/a><span class=\"co\">        Providing this value prevents us from reproducing the requisite<\/span><\/span>\n<span id=\"org03297c0-35\"><a href=\"#org03297c0-35\" aria-hidden=\"true\"><\/a><span class=\"co\">        `theano.compile.ops.Shape` object (e.g. 
when it&#39;s already available to<\/span><\/span>\n<span id=\"org03297c0-36\"><a href=\"#org03297c0-36\" aria-hidden=\"true\"><\/a><span class=\"co\">        the caller).<\/span><\/span>\n<span id=\"org03297c0-37\"><a href=\"#org03297c0-37\" aria-hidden=\"true\"><\/a><span class=\"co\">    rep_param_idx: int (optional)<\/span><\/span>\n<span id=\"org03297c0-38\"><a href=\"#org03297c0-38\" aria-hidden=\"true\"><\/a><span class=\"co\">        The index of the distribution parameter to use as a reference.<\/span><\/span>\n<span id=\"org03297c0-39\"><a href=\"#org03297c0-39\" aria-hidden=\"true\"><\/a><span class=\"co\">        In other words, a parameter in `dist_params` with a shape corresponding<\/span><\/span>\n<span id=\"org03297c0-40\"><a href=\"#org03297c0-40\" aria-hidden=\"true\"><\/a><span class=\"co\">        to the support&#39;s shape.<\/span><\/span>\n<span id=\"org03297c0-41\"><a href=\"#org03297c0-41\" aria-hidden=\"true\"><\/a><span class=\"co\">        The default is the first parameter (i.e. 
the value 0).<\/span><\/span>\n<span id=\"org03297c0-42\"><a href=\"#org03297c0-42\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org03297c0-43\"><a href=\"#org03297c0-43\" aria-hidden=\"true\"><\/a><span class=\"co\">    Results<\/span><\/span>\n<span id=\"org03297c0-44\"><a href=\"#org03297c0-44\" aria-hidden=\"true\"><\/a><span class=\"co\">    =======<\/span><\/span>\n<span id=\"org03297c0-45\"><a href=\"#org03297c0-45\" aria-hidden=\"true\"><\/a><span class=\"co\">    out: a tuple representing the support shape for a distribution with the<\/span><\/span>\n<span id=\"org03297c0-46\"><a href=\"#org03297c0-46\" aria-hidden=\"true\"><\/a><span class=\"co\">    given `dist_params`.<\/span><\/span>\n<span id=\"org03297c0-47\"><a href=\"#org03297c0-47\" aria-hidden=\"true\"><\/a><span class=\"co\">    &quot;&quot;&quot;<\/span><\/span>\n<span id=\"org03297c0-48\"><a href=\"#org03297c0-48\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># XXX: Gotta be careful slicing Theano variables, the `Subtensor` Op isn&#39;t<\/span><\/span>\n<span id=\"org03297c0-49\"><a href=\"#org03297c0-49\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># handled by `tensor.get_scalar_constant_value`!<\/span><\/span>\n<span id=\"org03297c0-50\"><a href=\"#org03297c0-50\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># E.g.<\/span><\/span>\n<span id=\"org03297c0-51\"><a href=\"#org03297c0-51\" aria-hidden=\"true\"><\/a>    <span class=\"co\">#     test_val = tt.as_tensor_variable([[1], [4]])<\/span><\/span>\n<span id=\"org03297c0-52\"><a href=\"#org03297c0-52\" aria-hidden=\"true\"><\/a>    <span class=\"co\">#     tt.get_scalar_constant_value(test_val.shape[-1]) # works<\/span><\/span>\n<span id=\"org03297c0-53\"><a href=\"#org03297c0-53\" aria-hidden=\"true\"><\/a>    <span class=\"co\">#     tt.get_scalar_constant_value(test_val.shape[0]) # doesn&#39;t<\/span><\/span>\n<span id=\"org03297c0-54\"><a href=\"#org03297c0-54\" aria-hidden=\"true\"><\/a>    <span class=\"co\">#     
tt.get_scalar_constant_value(test_val.shape[:-1]) # doesn&#39;t<\/span><\/span>\n<span id=\"org03297c0-55\"><a href=\"#org03297c0-55\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> param_shapes <span class=\"kw\">is<\/span> <span class=\"kw\">not<\/span> <span class=\"va\">None<\/span>:<\/span>\n<span id=\"org03297c0-56\"><a href=\"#org03297c0-56\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># return param_shapes[0][-self.ndim_supp:]<\/span><\/span>\n<span id=\"org03297c0-57\"><a href=\"#org03297c0-57\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> (param_shapes[rep_param_idx][<span class=\"op\">-<\/span>ndim_supp],)<\/span>\n<span id=\"org03297c0-58\"><a href=\"#org03297c0-58\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"org03297c0-59\"><a href=\"#org03297c0-59\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># return dist_params[rep_param_idx].shape[-ndim_supp]<\/span><\/span>\n<span id=\"org03297c0-60\"><a href=\"#org03297c0-60\" aria-hidden=\"true\"><\/a>        ref_shape <span class=\"op\">=<\/span> tt.shape(dist_params[rep_param_idx])<\/span>\n<span id=\"org03297c0-61\"><a href=\"#org03297c0-61\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> (ref_shape[<span class=\"op\">-<\/span>ndim_supp],)<\/span><\/code><\/pre><\/div>\n<p>Finally, we put everything together in a new random variable <code>Op<\/code> called <code>RandomVariable<\/code>.<\/p>\n<div class=\"sourceCode\" id=\"orgb7517e4\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgb7517e4-1\"><a href=\"#orgb7517e4-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">class<\/span> RandomVariable(tt.gof.Op):<\/span>\n<span id=\"orgb7517e4-2\"><a href=\"#orgb7517e4-2\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;This is essentially `RandomFunction`, except that it removes the `outtype`<\/span><\/span>\n<span id=\"orgb7517e4-3\"><a 
href=\"#orgb7517e4-3\" aria-hidden=\"true\"><\/a><span class=\"co\">    dependency and handles shape dimension information more directly.<\/span><\/span>\n<span id=\"orgb7517e4-4\"><a href=\"#orgb7517e4-4\" aria-hidden=\"true\"><\/a><span class=\"co\">    &quot;&quot;&quot;<\/span><\/span>\n<span id=\"orgb7517e4-5\"><a href=\"#orgb7517e4-5\" aria-hidden=\"true\"><\/a>    __props__ <span class=\"op\">=<\/span> (<span class=\"st\">&#39;name&#39;<\/span>, <span class=\"st\">&#39;dtype&#39;<\/span>, <span class=\"st\">&#39;ndim_supp&#39;<\/span>, <span class=\"st\">&#39;inplace&#39;<\/span>, <span class=\"st\">&#39;ndims_params&#39;<\/span>)<\/span>\n<span id=\"orgb7517e4-6\"><a href=\"#orgb7517e4-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-7\"><a href=\"#orgb7517e4-7\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__init__<\/span>(<span class=\"va\">self<\/span>, name, dtype, ndim_supp, ndims_params, rng_fn,<\/span>\n<span id=\"orgb7517e4-8\"><a href=\"#orgb7517e4-8\" aria-hidden=\"true\"><\/a>                 <span class=\"op\">*<\/span>args,<\/span>\n<span id=\"orgb7517e4-9\"><a href=\"#orgb7517e4-9\" aria-hidden=\"true\"><\/a>                 supp_shape_fn<span class=\"op\">=<\/span>param_supp_shape_fn,<\/span>\n<span id=\"orgb7517e4-10\"><a href=\"#orgb7517e4-10\" aria-hidden=\"true\"><\/a>                 inplace<span class=\"op\">=<\/span><span class=\"va\">False<\/span>,<\/span>\n<span id=\"orgb7517e4-11\"><a href=\"#orgb7517e4-11\" aria-hidden=\"true\"><\/a>                 <span class=\"op\">**<\/span>kwargs):<\/span>\n<span id=\"orgb7517e4-12\"><a href=\"#orgb7517e4-12\" aria-hidden=\"true\"><\/a>        <span class=\"co\">&quot;&quot;&quot;Create a random variable `Op`.<\/span><\/span>\n<span id=\"orgb7517e4-13\"><a href=\"#orgb7517e4-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-14\"><a href=\"#orgb7517e4-14\" aria-hidden=\"true\"><\/a><span class=\"co\">        
Parameters<\/span><\/span>\n<span id=\"orgb7517e4-15\"><a href=\"#orgb7517e4-15\" aria-hidden=\"true\"><\/a><span class=\"co\">        ==========<\/span><\/span>\n<span id=\"orgb7517e4-16\"><a href=\"#orgb7517e4-16\" aria-hidden=\"true\"><\/a><span class=\"co\">        name: str<\/span><\/span>\n<span id=\"orgb7517e4-17\"><a href=\"#orgb7517e4-17\" aria-hidden=\"true\"><\/a><span class=\"co\">            The `Op`&#39;s display name.<\/span><\/span>\n<span id=\"orgb7517e4-18\"><a href=\"#orgb7517e4-18\" aria-hidden=\"true\"><\/a><span class=\"co\">        dtype: Theano dtype<\/span><\/span>\n<span id=\"orgb7517e4-19\"><a href=\"#orgb7517e4-19\" aria-hidden=\"true\"><\/a><span class=\"co\">            The underlying dtype.<\/span><\/span>\n<span id=\"orgb7517e4-20\"><a href=\"#orgb7517e4-20\" aria-hidden=\"true\"><\/a><span class=\"co\">        ndim_supp: int<\/span><\/span>\n<span id=\"orgb7517e4-21\"><a href=\"#orgb7517e4-21\" aria-hidden=\"true\"><\/a><span class=\"co\">            Dimension of the support.  
This value is used to infer the exact<\/span><\/span>\n<span id=\"orgb7517e4-22\"><a href=\"#orgb7517e4-22\" aria-hidden=\"true\"><\/a><span class=\"co\">            shape of the support and independent terms from ``dist_params``.<\/span><\/span>\n<span id=\"orgb7517e4-23\"><a href=\"#orgb7517e4-23\" aria-hidden=\"true\"><\/a><span class=\"co\">        ndims_params: tuple (int)<\/span><\/span>\n<span id=\"orgb7517e4-24\"><a href=\"#orgb7517e4-24\" aria-hidden=\"true\"><\/a><span class=\"co\">            Number of dimensions of each parameter in ``dist_params``.<\/span><\/span>\n<span id=\"orgb7517e4-25\"><a href=\"#orgb7517e4-25\" aria-hidden=\"true\"><\/a><span class=\"co\">        rng_fn: function or str<\/span><\/span>\n<span id=\"orgb7517e4-26\"><a href=\"#orgb7517e4-26\" aria-hidden=\"true\"><\/a><span class=\"co\">            The non-symbolic random variate sampling function.<\/span><\/span>\n<span id=\"orgb7517e4-27\"><a href=\"#orgb7517e4-27\" aria-hidden=\"true\"><\/a><span class=\"co\">            Can be the string name of a method provided by<\/span><\/span>\n<span id=\"orgb7517e4-28\"><a href=\"#orgb7517e4-28\" aria-hidden=\"true\"><\/a><span class=\"co\">            `numpy.random.RandomState`.<\/span><\/span>\n<span id=\"orgb7517e4-29\"><a href=\"#orgb7517e4-29\" aria-hidden=\"true\"><\/a><span class=\"co\">        supp_shape_fn: callable (optional)<\/span><\/span>\n<span id=\"orgb7517e4-30\"><a href=\"#orgb7517e4-30\" aria-hidden=\"true\"><\/a><span class=\"co\">            Function used to determine the exact shape of the distribution&#39;s<\/span><\/span>\n<span id=\"orgb7517e4-31\"><a href=\"#orgb7517e4-31\" aria-hidden=\"true\"><\/a><span class=\"co\">            support.<\/span><\/span>\n<span id=\"orgb7517e4-32\"><a href=\"#orgb7517e4-32\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-33\"><a href=\"#orgb7517e4-33\" aria-hidden=\"true\"><\/a><span class=\"co\">            It must take arguments ndim_supp, ndims_params, 
dist_params<\/span><\/span>\n<span id=\"orgb7517e4-34\"><a href=\"#orgb7517e4-34\" aria-hidden=\"true\"><\/a><span class=\"co\">            (i.e. an collection of the distribution parameters) and an<\/span><\/span>\n<span id=\"orgb7517e4-35\"><a href=\"#orgb7517e4-35\" aria-hidden=\"true\"><\/a><span class=\"co\">            optional param_shapes (i.e. tuples containing the size of each<\/span><\/span>\n<span id=\"orgb7517e4-36\"><a href=\"#orgb7517e4-36\" aria-hidden=\"true\"><\/a><span class=\"co\">            dimension for each distribution parameter).<\/span><\/span>\n<span id=\"orgb7517e4-37\"><a href=\"#orgb7517e4-37\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-38\"><a href=\"#orgb7517e4-38\" aria-hidden=\"true\"><\/a><span class=\"co\">            Defaults to `param_supp_shape_fn`.<\/span><\/span>\n<span id=\"orgb7517e4-39\"><a href=\"#orgb7517e4-39\" aria-hidden=\"true\"><\/a><span class=\"co\">        inplace: boolean<\/span><\/span>\n<span id=\"orgb7517e4-40\"><a href=\"#orgb7517e4-40\" aria-hidden=\"true\"><\/a><span class=\"co\">            Determine whether or not the underlying rng state is updated in-place or<\/span><\/span>\n<span id=\"orgb7517e4-41\"><a href=\"#orgb7517e4-41\" aria-hidden=\"true\"><\/a><span class=\"co\">            not (i.e. 
copied).<\/span><\/span>\n<span id=\"orgb7517e4-42\"><a href=\"#orgb7517e4-42\" aria-hidden=\"true\"><\/a><span class=\"co\">        &quot;&quot;&quot;<\/span><\/span>\n<span id=\"orgb7517e4-43\"><a href=\"#orgb7517e4-43\" aria-hidden=\"true\"><\/a>        <span class=\"bu\">super<\/span>().<span class=\"fu\">__init__<\/span>(<span class=\"op\">*<\/span>args, <span class=\"op\">**<\/span>kwargs)<\/span>\n<span id=\"orgb7517e4-44\"><a href=\"#orgb7517e4-44\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-45\"><a href=\"#orgb7517e4-45\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.name <span class=\"op\">=<\/span> name<\/span>\n<span id=\"orgb7517e4-46\"><a href=\"#orgb7517e4-46\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.ndim_supp <span class=\"op\">=<\/span> ndim_supp<\/span>\n<span id=\"orgb7517e4-47\"><a href=\"#orgb7517e4-47\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.dtype <span class=\"op\">=<\/span> dtype<\/span>\n<span id=\"orgb7517e4-48\"><a href=\"#orgb7517e4-48\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.supp_shape_fn <span class=\"op\">=<\/span> supp_shape_fn<\/span>\n<span id=\"orgb7517e4-49\"><a href=\"#orgb7517e4-49\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.inplace <span class=\"op\">=<\/span> inplace<\/span>\n<span id=\"orgb7517e4-50\"><a href=\"#orgb7517e4-50\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-51\"><a href=\"#orgb7517e4-51\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> <span class=\"kw\">not<\/span> <span class=\"bu\">isinstance<\/span>(ndims_params, Iterable):<\/span>\n<span id=\"orgb7517e4-52\"><a href=\"#orgb7517e4-52\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">raise<\/span> <span class=\"pp\">ValueError<\/span>(<span class=\"st\">&#39;Parameter ndims_params must be iterable.&#39;<\/span>)<\/span>\n<span id=\"orgb7517e4-53\"><a href=\"#orgb7517e4-53\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-54\"><a href=\"#orgb7517e4-54\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.ndims_params <span class=\"op\">=<\/span> <span class=\"bu\">tuple<\/span>(ndims_params)<\/span>\n<span id=\"orgb7517e4-55\"><a href=\"#orgb7517e4-55\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-56\"><a href=\"#orgb7517e4-56\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.default_output <span class=\"op\">=<\/span> <span class=\"dv\">1<\/span><\/span>\n<span id=\"orgb7517e4-57\"><a href=\"#orgb7517e4-57\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-58\"><a href=\"#orgb7517e4-58\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> <span class=\"bu\">isinstance<\/span>(rng_fn, (<span class=\"bu\">str<\/span>, ByteString)):<\/span>\n<span id=\"orgb7517e4-59\"><a href=\"#orgb7517e4-59\" aria-hidden=\"true\"><\/a>            <span class=\"va\">self<\/span>.rng_fn <span class=\"op\">=<\/span> <span class=\"bu\">getattr<\/span>(np.random.RandomState, rng_fn)<\/span>\n<span id=\"orgb7517e4-60\"><a href=\"#orgb7517e4-60\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"orgb7517e4-61\"><a href=\"#orgb7517e4-61\" aria-hidden=\"true\"><\/a>            <span class=\"va\">self<\/span>.rng_fn <span class=\"op\">=<\/span> rng_fn<\/span>\n<span id=\"orgb7517e4-62\"><a href=\"#orgb7517e4-62\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-63\"><a href=\"#orgb7517e4-63\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__str__<\/span>(<span class=\"va\">self<\/span>):<\/span>\n<span id=\"orgb7517e4-64\"><a href=\"#orgb7517e4-64\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> <span class=\"st\">&#39;<\/span><span class=\"sc\">{}<\/span><span class=\"st\">_rv&#39;<\/span>.<span class=\"bu\">format<\/span>(<span class=\"va\">self<\/span>.name)<\/span>\n<span 
id=\"orgb7517e4-65\"><a href=\"#orgb7517e4-65\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-66\"><a href=\"#orgb7517e4-66\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> _infer_shape(<span class=\"va\">self<\/span>, size, dist_params, param_shapes<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"orgb7517e4-67\"><a href=\"#orgb7517e4-67\" aria-hidden=\"true\"><\/a>        <span class=\"co\">&quot;&quot;&quot;Compute shapes and broadcasts properties.<\/span><\/span>\n<span id=\"orgb7517e4-68\"><a href=\"#orgb7517e4-68\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-69\"><a href=\"#orgb7517e4-69\" aria-hidden=\"true\"><\/a><span class=\"co\">        Inspired by `tt.add.get_output_info`.<\/span><\/span>\n<span id=\"orgb7517e4-70\"><a href=\"#orgb7517e4-70\" aria-hidden=\"true\"><\/a><span class=\"co\">        &quot;&quot;&quot;<\/span><\/span>\n<span id=\"orgb7517e4-71\"><a href=\"#orgb7517e4-71\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-72\"><a href=\"#orgb7517e4-72\" aria-hidden=\"true\"><\/a>        size_len <span class=\"op\">=<\/span> tt.get_vector_length(size)<\/span>\n<span id=\"orgb7517e4-73\"><a href=\"#orgb7517e4-73\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-74\"><a href=\"#orgb7517e4-74\" aria-hidden=\"true\"><\/a>        dummy_params <span class=\"op\">=<\/span> <span class=\"bu\">tuple<\/span>(p <span class=\"cf\">if<\/span> n <span class=\"op\">==<\/span> <span class=\"dv\">0<\/span> <span class=\"cf\">else<\/span> tt.ones(<span class=\"bu\">tuple<\/span>(p.shape)[:<span class=\"op\">-<\/span>n])<\/span>\n<span id=\"orgb7517e4-75\"><a href=\"#orgb7517e4-75\" aria-hidden=\"true\"><\/a>                             <span class=\"cf\">for<\/span> p, n <span class=\"kw\">in<\/span> <span class=\"bu\">zip<\/span>(dist_params, <span class=\"va\">self<\/span>.ndims_params))<\/span>\n<span id=\"orgb7517e4-76\"><a href=\"#orgb7517e4-76\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-77\"><a href=\"#orgb7517e4-77\" aria-hidden=\"true\"><\/a>        _, out_bcasts, bcastd_inputs <span class=\"op\">=<\/span> tt.add.get_output_info(<\/span>\n<span id=\"orgb7517e4-78\"><a href=\"#orgb7517e4-78\" aria-hidden=\"true\"><\/a>            tt.DimShuffle, <span class=\"op\">*<\/span>dummy_params)<\/span>\n<span id=\"orgb7517e4-79\"><a href=\"#orgb7517e4-79\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-80\"><a href=\"#orgb7517e4-80\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># _, out_bcasts, bcastd_inputs = tt.add.get_output_info(tt.DimShuffle, *dist_params)<\/span><\/span>\n<span id=\"orgb7517e4-81\"><a href=\"#orgb7517e4-81\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-82\"><a href=\"#orgb7517e4-82\" aria-hidden=\"true\"><\/a>        bcast_ind, <span class=\"op\">=<\/span> out_bcasts<\/span>\n<span id=\"orgb7517e4-83\"><a href=\"#orgb7517e4-83\" aria-hidden=\"true\"><\/a>        ndim_ind <span class=\"op\">=<\/span> <span class=\"bu\">len<\/span>(bcast_ind)<\/span>\n<span id=\"orgb7517e4-84\"><a href=\"#orgb7517e4-84\" aria-hidden=\"true\"><\/a>        shape_ind <span class=\"op\">=<\/span> bcastd_inputs[<span class=\"dv\">0<\/span>].shape<\/span>\n<span id=\"orgb7517e4-85\"><a href=\"#orgb7517e4-85\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-86\"><a href=\"#orgb7517e4-86\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> <span class=\"va\">self<\/span>.ndim_supp <span class=\"op\">==<\/span> <span class=\"dv\">0<\/span>:<\/span>\n<span id=\"orgb7517e4-87\"><a href=\"#orgb7517e4-87\" aria-hidden=\"true\"><\/a>            shape_supp <span class=\"op\">=<\/span> <span class=\"bu\">tuple<\/span>()<\/span>\n<span id=\"orgb7517e4-88\"><a href=\"#orgb7517e4-88\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-89\"><a href=\"#orgb7517e4-89\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># In the scalar case, `size` 
corresponds to the entire result&#39;s<\/span><\/span>\n<span id=\"orgb7517e4-90\"><a href=\"#orgb7517e4-90\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># shape. This implies the following:<\/span><\/span>\n<span id=\"orgb7517e4-91\"><a href=\"#orgb7517e4-91\" aria-hidden=\"true\"><\/a>            <span class=\"co\">#     shape_ind[-ndim_ind] == size[:ndim_ind]<\/span><\/span>\n<span id=\"orgb7517e4-92\"><a href=\"#orgb7517e4-92\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># <\/span><span class=\"al\">TODO<\/span><span class=\"co\">: How do we add this constraint\/check symbolically?<\/span><\/span>\n<span id=\"orgb7517e4-93\"><a href=\"#orgb7517e4-93\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-94\"><a href=\"#orgb7517e4-94\" aria-hidden=\"true\"><\/a>            ndim_reps <span class=\"op\">=<\/span> <span class=\"bu\">max<\/span>(size_len <span class=\"op\">-<\/span> ndim_ind, <span class=\"dv\">0<\/span>)<\/span>\n<span id=\"orgb7517e4-95\"><a href=\"#orgb7517e4-95\" aria-hidden=\"true\"><\/a>            shape_reps <span class=\"op\">=<\/span> <span class=\"bu\">tuple<\/span>(size)[ndim_ind:]<\/span>\n<span id=\"orgb7517e4-96\"><a href=\"#orgb7517e4-96\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"orgb7517e4-97\"><a href=\"#orgb7517e4-97\" aria-hidden=\"true\"><\/a>            shape_supp <span class=\"op\">=<\/span> <span class=\"va\">self<\/span>.supp_shape_fn(<span class=\"va\">self<\/span>.ndim_supp,<\/span>\n<span id=\"orgb7517e4-98\"><a href=\"#orgb7517e4-98\" aria-hidden=\"true\"><\/a>                                            <span class=\"va\">self<\/span>.ndims_params,<\/span>\n<span id=\"orgb7517e4-99\"><a href=\"#orgb7517e4-99\" aria-hidden=\"true\"><\/a>                                            dist_params,<\/span>\n<span id=\"orgb7517e4-100\"><a href=\"#orgb7517e4-100\" aria-hidden=\"true\"><\/a>                                            param_shapes<span 
class=\"op\">=<\/span>param_shapes)<\/span>\n<span id=\"orgb7517e4-101\"><a href=\"#orgb7517e4-101\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-102\"><a href=\"#orgb7517e4-102\" aria-hidden=\"true\"><\/a>            ndim_reps <span class=\"op\">=<\/span> size_len<\/span>\n<span id=\"orgb7517e4-103\"><a href=\"#orgb7517e4-103\" aria-hidden=\"true\"><\/a>            shape_reps <span class=\"op\">=<\/span> size<\/span>\n<span id=\"orgb7517e4-104\"><a href=\"#orgb7517e4-104\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-105\"><a href=\"#orgb7517e4-105\" aria-hidden=\"true\"><\/a>        ndim_shape <span class=\"op\">=<\/span> <span class=\"va\">self<\/span>.ndim_supp <span class=\"op\">+<\/span> ndim_ind <span class=\"op\">+<\/span> ndim_reps<\/span>\n<span id=\"orgb7517e4-106\"><a href=\"#orgb7517e4-106\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-107\"><a href=\"#orgb7517e4-107\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> ndim_shape <span class=\"op\">==<\/span> <span class=\"dv\">0<\/span>:<\/span>\n<span id=\"orgb7517e4-108\"><a href=\"#orgb7517e4-108\" aria-hidden=\"true\"><\/a>            shape <span class=\"op\">=<\/span> tt.constant([], dtype<span class=\"op\">=<\/span><span class=\"st\">&#39;int64&#39;<\/span>)<\/span>\n<span id=\"orgb7517e4-109\"><a href=\"#orgb7517e4-109\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"orgb7517e4-110\"><a href=\"#orgb7517e4-110\" aria-hidden=\"true\"><\/a>            shape <span class=\"op\">=<\/span> <span class=\"bu\">tuple<\/span>(shape_reps) <span class=\"op\">+<\/span> <span class=\"bu\">tuple<\/span>(shape_ind) <span class=\"op\">+<\/span> <span class=\"bu\">tuple<\/span>(shape_supp)<\/span>\n<span id=\"orgb7517e4-111\"><a href=\"#orgb7517e4-111\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-112\"><a href=\"#orgb7517e4-112\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># if shape is 
None:<\/span><\/span>\n<span id=\"orgb7517e4-113\"><a href=\"#orgb7517e4-113\" aria-hidden=\"true\"><\/a>        <span class=\"co\">#     raise tt.ShapeError()<\/span><\/span>\n<span id=\"orgb7517e4-114\"><a href=\"#orgb7517e4-114\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-115\"><a href=\"#orgb7517e4-115\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> shape<\/span>\n<span id=\"orgb7517e4-116\"><a href=\"#orgb7517e4-116\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-117\"><a href=\"#orgb7517e4-117\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> compute_bcast(<span class=\"va\">self<\/span>, dist_params, size):<\/span>\n<span id=\"orgb7517e4-118\"><a href=\"#orgb7517e4-118\" aria-hidden=\"true\"><\/a>        <span class=\"co\">&quot;&quot;&quot;Compute the broadcast array for this distribution&#39;s `TensorType`.<\/span><\/span>\n<span id=\"orgb7517e4-119\"><a href=\"#orgb7517e4-119\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-120\"><a href=\"#orgb7517e4-120\" aria-hidden=\"true\"><\/a><span class=\"co\">        Parameters<\/span><\/span>\n<span id=\"orgb7517e4-121\"><a href=\"#orgb7517e4-121\" aria-hidden=\"true\"><\/a><span class=\"co\">        ==========<\/span><\/span>\n<span id=\"orgb7517e4-122\"><a href=\"#orgb7517e4-122\" aria-hidden=\"true\"><\/a><span class=\"co\">        dist_params: list<\/span><\/span>\n<span id=\"orgb7517e4-123\"><a href=\"#orgb7517e4-123\" aria-hidden=\"true\"><\/a><span class=\"co\">            Distribution parameters.<\/span><\/span>\n<span id=\"orgb7517e4-124\"><a href=\"#orgb7517e4-124\" aria-hidden=\"true\"><\/a><span class=\"co\">        size: int or Iterable (optional)<\/span><\/span>\n<span id=\"orgb7517e4-125\"><a href=\"#orgb7517e4-125\" aria-hidden=\"true\"><\/a><span class=\"co\">            Numpy-like size of the output (i.e. 
replications).<\/span><\/span>\n<span id=\"orgb7517e4-126\"><a href=\"#orgb7517e4-126\" aria-hidden=\"true\"><\/a><span class=\"co\">        &quot;&quot;&quot;<\/span><\/span>\n<span id=\"orgb7517e4-127\"><a href=\"#orgb7517e4-127\" aria-hidden=\"true\"><\/a>        shape <span class=\"op\">=<\/span> <span class=\"va\">self<\/span>._infer_shape(size, dist_params)<\/span>\n<span id=\"orgb7517e4-128\"><a href=\"#orgb7517e4-128\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-129\"><a href=\"#orgb7517e4-129\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># Let&#39;s try to do a better job than `_infer_ndim_bcast` when<\/span><\/span>\n<span id=\"orgb7517e4-130\"><a href=\"#orgb7517e4-130\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># dimension sizes are symbolic.<\/span><\/span>\n<span id=\"orgb7517e4-131\"><a href=\"#orgb7517e4-131\" aria-hidden=\"true\"><\/a>        bcast <span class=\"op\">=<\/span> []<\/span>\n<span id=\"orgb7517e4-132\"><a href=\"#orgb7517e4-132\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">for<\/span> s <span class=\"kw\">in<\/span> shape:<\/span>\n<span id=\"orgb7517e4-133\"><a href=\"#orgb7517e4-133\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">try<\/span>:<\/span>\n<span id=\"orgb7517e4-134\"><a href=\"#orgb7517e4-134\" aria-hidden=\"true\"><\/a>                <span class=\"cf\">if<\/span> <span class=\"bu\">isinstance<\/span>(s.owner.op, tt.Subtensor) <span class=\"kw\">and<\/span> <span class=\"op\">\\<\/span><\/span>\n<span id=\"orgb7517e4-135\"><a href=\"#orgb7517e4-135\" aria-hidden=\"true\"><\/a>                   s.owner.inputs[<span class=\"dv\">0<\/span>].owner <span class=\"kw\">is<\/span> <span class=\"kw\">not<\/span> <span class=\"va\">None<\/span>:<\/span>\n<span id=\"orgb7517e4-136\"><a href=\"#orgb7517e4-136\" aria-hidden=\"true\"><\/a>                    <span class=\"co\"># Handle a special case in which<\/span><\/span>\n<span id=\"orgb7517e4-137\"><a 
href=\"#orgb7517e4-137\" aria-hidden=\"true\"><\/a>                    <span class=\"co\"># `tensor.get_scalar_constant_value` doesn&#39;t really work.<\/span><\/span>\n<span id=\"orgb7517e4-138\"><a href=\"#orgb7517e4-138\" aria-hidden=\"true\"><\/a>                    s_x, s_idx <span class=\"op\">=<\/span> s.owner.inputs<\/span>\n<span id=\"orgb7517e4-139\"><a href=\"#orgb7517e4-139\" aria-hidden=\"true\"><\/a>                    s_idx <span class=\"op\">=<\/span> tt.get_scalar_constant_value(s_idx)<\/span>\n<span id=\"orgb7517e4-140\"><a href=\"#orgb7517e4-140\" aria-hidden=\"true\"><\/a>                    <span class=\"cf\">if<\/span> <span class=\"bu\">isinstance<\/span>(s_x.owner.op, tt.Shape):<\/span>\n<span id=\"orgb7517e4-141\"><a href=\"#orgb7517e4-141\" aria-hidden=\"true\"><\/a>                        x_obj, <span class=\"op\">=<\/span> s_x.owner.inputs<\/span>\n<span id=\"orgb7517e4-142\"><a href=\"#orgb7517e4-142\" aria-hidden=\"true\"><\/a>                        s_val <span class=\"op\">=<\/span> x_obj.<span class=\"bu\">type<\/span>.broadcastable[s_idx]<\/span>\n<span id=\"orgb7517e4-143\"><a href=\"#orgb7517e4-143\" aria-hidden=\"true\"><\/a>                    <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"orgb7517e4-144\"><a href=\"#orgb7517e4-144\" aria-hidden=\"true\"><\/a>                        <span class=\"co\"># <\/span><span class=\"al\">TODO<\/span><span class=\"co\">: Could go for an existing broadcastable here, too, no?<\/span><\/span>\n<span id=\"orgb7517e4-145\"><a href=\"#orgb7517e4-145\" aria-hidden=\"true\"><\/a>                        s_val <span class=\"op\">=<\/span> <span class=\"va\">False<\/span><\/span>\n<span id=\"orgb7517e4-146\"><a href=\"#orgb7517e4-146\" aria-hidden=\"true\"><\/a>                <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"orgb7517e4-147\"><a href=\"#orgb7517e4-147\" aria-hidden=\"true\"><\/a>                    s_val <span class=\"op\">=<\/span> 
tt.get_scalar_constant_value(s)<\/span>\n<span id=\"orgb7517e4-148\"><a href=\"#orgb7517e4-148\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">except<\/span> tt.NotScalarConstantError:<\/span>\n<span id=\"orgb7517e4-149\"><a href=\"#orgb7517e4-149\" aria-hidden=\"true\"><\/a>                s_val <span class=\"op\">=<\/span> <span class=\"va\">False<\/span><\/span>\n<span id=\"orgb7517e4-150\"><a href=\"#orgb7517e4-150\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-151\"><a href=\"#orgb7517e4-151\" aria-hidden=\"true\"><\/a>            bcast <span class=\"op\">+=<\/span> [s_val <span class=\"op\">==<\/span> <span class=\"dv\">1<\/span>]<\/span>\n<span id=\"orgb7517e4-152\"><a href=\"#orgb7517e4-152\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> bcast<\/span>\n<span id=\"orgb7517e4-153\"><a href=\"#orgb7517e4-153\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-154\"><a href=\"#orgb7517e4-154\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> infer_shape(<span class=\"va\">self<\/span>, node, input_shapes):<\/span>\n<span id=\"orgb7517e4-155\"><a href=\"#orgb7517e4-155\" aria-hidden=\"true\"><\/a>        size <span class=\"op\">=<\/span> node.inputs[<span class=\"op\">-<\/span><span class=\"dv\">2<\/span>]<\/span>\n<span id=\"orgb7517e4-156\"><a href=\"#orgb7517e4-156\" aria-hidden=\"true\"><\/a>        dist_params <span class=\"op\">=<\/span> <span class=\"bu\">tuple<\/span>(node.inputs[:<span class=\"op\">-<\/span><span class=\"dv\">2<\/span>])<\/span>\n<span id=\"orgb7517e4-157\"><a href=\"#orgb7517e4-157\" aria-hidden=\"true\"><\/a>        shape <span class=\"op\">=<\/span> <span class=\"va\">self<\/span>._infer_shape(size, dist_params,<\/span>\n<span id=\"orgb7517e4-158\"><a href=\"#orgb7517e4-158\" aria-hidden=\"true\"><\/a>                                  param_shapes<span class=\"op\">=<\/span>input_shapes[:<span class=\"op\">-<\/span><span class=\"dv\">2<\/span>])<\/span>\n<span 
id=\"orgb7517e4-159\"><a href=\"#orgb7517e4-159\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-160\"><a href=\"#orgb7517e4-160\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> [<span class=\"va\">None<\/span>, [s <span class=\"cf\">for<\/span> s <span class=\"kw\">in<\/span> shape]]<\/span>\n<span id=\"orgb7517e4-161\"><a href=\"#orgb7517e4-161\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-162\"><a href=\"#orgb7517e4-162\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> make_node(<span class=\"va\">self<\/span>, <span class=\"op\">*<\/span>dist_params, size<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, rng<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, name<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"orgb7517e4-163\"><a href=\"#orgb7517e4-163\" aria-hidden=\"true\"><\/a>        <span class=\"co\">&quot;&quot;&quot;Create a random variable node.<\/span><\/span>\n<span id=\"orgb7517e4-164\"><a href=\"#orgb7517e4-164\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-165\"><a href=\"#orgb7517e4-165\" aria-hidden=\"true\"><\/a><span class=\"co\">        XXX: Unnamed\/non-keyword arguments are considered distribution<\/span><\/span>\n<span id=\"orgb7517e4-166\"><a href=\"#orgb7517e4-166\" aria-hidden=\"true\"><\/a><span class=\"co\">        parameters!  
If you want to set `size`, `rng`, and\/or `name`, use their<\/span><\/span>\n<span id=\"orgb7517e4-167\"><a href=\"#orgb7517e4-167\" aria-hidden=\"true\"><\/a><span class=\"co\">        keywords.<\/span><\/span>\n<span id=\"orgb7517e4-168\"><a href=\"#orgb7517e4-168\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-169\"><a href=\"#orgb7517e4-169\" aria-hidden=\"true\"><\/a><span class=\"co\">        Parameters<\/span><\/span>\n<span id=\"orgb7517e4-170\"><a href=\"#orgb7517e4-170\" aria-hidden=\"true\"><\/a><span class=\"co\">        ==========<\/span><\/span>\n<span id=\"orgb7517e4-171\"><a href=\"#orgb7517e4-171\" aria-hidden=\"true\"><\/a><span class=\"co\">        dist_params: list<\/span><\/span>\n<span id=\"orgb7517e4-172\"><a href=\"#orgb7517e4-172\" aria-hidden=\"true\"><\/a><span class=\"co\">            Distribution parameters.<\/span><\/span>\n<span id=\"orgb7517e4-173\"><a href=\"#orgb7517e4-173\" aria-hidden=\"true\"><\/a><span class=\"co\">        size: int or Iterable (optional)<\/span><\/span>\n<span id=\"orgb7517e4-174\"><a href=\"#orgb7517e4-174\" aria-hidden=\"true\"><\/a><span class=\"co\">            Numpy-like size of the output (i.e. replications).<\/span><\/span>\n<span id=\"orgb7517e4-175\"><a href=\"#orgb7517e4-175\" aria-hidden=\"true\"><\/a><span class=\"co\">        rng: RandomState (optional)<\/span><\/span>\n<span id=\"orgb7517e4-176\"><a href=\"#orgb7517e4-176\" aria-hidden=\"true\"><\/a><span class=\"co\">            Existing Theano `RandomState` object to be used.  
Creates a<\/span><\/span>\n<span id=\"orgb7517e4-177\"><a href=\"#orgb7517e4-177\" aria-hidden=\"true\"><\/a><span class=\"co\">            new one, if `None`.<\/span><\/span>\n<span id=\"orgb7517e4-178\"><a href=\"#orgb7517e4-178\" aria-hidden=\"true\"><\/a><span class=\"co\">        name: str (optional)<\/span><\/span>\n<span id=\"orgb7517e4-179\"><a href=\"#orgb7517e4-179\" aria-hidden=\"true\"><\/a><span class=\"co\">            Label for the resulting node.<\/span><\/span>\n<span id=\"orgb7517e4-180\"><a href=\"#orgb7517e4-180\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-181\"><a href=\"#orgb7517e4-181\" aria-hidden=\"true\"><\/a><span class=\"co\">        Results<\/span><\/span>\n<span id=\"orgb7517e4-182\"><a href=\"#orgb7517e4-182\" aria-hidden=\"true\"><\/a><span class=\"co\">        =======<\/span><\/span>\n<span id=\"orgb7517e4-183\"><a href=\"#orgb7517e4-183\" aria-hidden=\"true\"><\/a><span class=\"co\">        out: `Apply`<\/span><\/span>\n<span id=\"orgb7517e4-184\"><a href=\"#orgb7517e4-184\" aria-hidden=\"true\"><\/a><span class=\"co\">            A node with inputs `dist_args + (size, in_rng, name)` and outputs<\/span><\/span>\n<span id=\"orgb7517e4-185\"><a href=\"#orgb7517e4-185\" aria-hidden=\"true\"><\/a><span class=\"co\">            `(out_rng, sample_tensorvar)`.<\/span><\/span>\n<span id=\"orgb7517e4-186\"><a href=\"#orgb7517e4-186\" aria-hidden=\"true\"><\/a><span class=\"co\">        &quot;&quot;&quot;<\/span><\/span>\n<span id=\"orgb7517e4-187\"><a href=\"#orgb7517e4-187\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> size <span class=\"kw\">is<\/span> <span class=\"va\">None<\/span>:<\/span>\n<span id=\"orgb7517e4-188\"><a href=\"#orgb7517e4-188\" aria-hidden=\"true\"><\/a>            size <span class=\"op\">=<\/span> tt.constant([], dtype<span class=\"op\">=<\/span><span class=\"st\">&#39;int64&#39;<\/span>)<\/span>\n<span id=\"orgb7517e4-189\"><a href=\"#orgb7517e4-189\" aria-hidden=\"true\"><\/a>      
  <span class=\"cf\">elif<\/span> <span class=\"bu\">isinstance<\/span>(size, <span class=\"bu\">int<\/span>):<\/span>\n<span id=\"orgb7517e4-190\"><a href=\"#orgb7517e4-190\" aria-hidden=\"true\"><\/a>            size <span class=\"op\">=<\/span> tt.as_tensor_variable([size], ndim<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>)<\/span>\n<span id=\"orgb7517e4-191\"><a href=\"#orgb7517e4-191\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">elif<\/span> <span class=\"kw\">not<\/span> <span class=\"bu\">isinstance<\/span>(size, Iterable):<\/span>\n<span id=\"orgb7517e4-192\"><a href=\"#orgb7517e4-192\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">raise<\/span> <span class=\"pp\">ValueError<\/span>(<span class=\"st\">&#39;Parameter size must be None, int, or an iterable with ints.&#39;<\/span>)<\/span>\n<span id=\"orgb7517e4-193\"><a href=\"#orgb7517e4-193\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"orgb7517e4-194\"><a href=\"#orgb7517e4-194\" aria-hidden=\"true\"><\/a>            size <span class=\"op\">=<\/span> tt.as_tensor_variable(size, ndim<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>)<\/span>\n<span id=\"orgb7517e4-195\"><a href=\"#orgb7517e4-195\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-196\"><a href=\"#orgb7517e4-196\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">assert<\/span> size.dtype <span class=\"kw\">in<\/span> tt.int_dtypes<\/span>\n<span id=\"orgb7517e4-197\"><a href=\"#orgb7517e4-197\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-198\"><a href=\"#orgb7517e4-198\" aria-hidden=\"true\"><\/a>        dist_params <span class=\"op\">=<\/span> <span class=\"bu\">tuple<\/span>(tt.as_tensor_variable(p)<\/span>\n<span id=\"orgb7517e4-199\"><a href=\"#orgb7517e4-199\" aria-hidden=\"true\"><\/a>                            <span class=\"cf\">for<\/span> p <span class=\"kw\">in<\/span> dist_params)<\/span>\n<span id=\"orgb7517e4-200\"><a 
href=\"#orgb7517e4-200\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-201\"><a href=\"#orgb7517e4-201\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> rng <span class=\"kw\">is<\/span> <span class=\"va\">None<\/span>:<\/span>\n<span id=\"orgb7517e4-202\"><a href=\"#orgb7517e4-202\" aria-hidden=\"true\"><\/a>            rng <span class=\"op\">=<\/span> theano.shared(np.random.RandomState())<\/span>\n<span id=\"orgb7517e4-203\"><a href=\"#orgb7517e4-203\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">elif<\/span> <span class=\"kw\">not<\/span> <span class=\"bu\">isinstance<\/span>(rng.<span class=\"bu\">type<\/span>, RandomStateType):<\/span>\n<span id=\"orgb7517e4-204\"><a href=\"#orgb7517e4-204\" aria-hidden=\"true\"><\/a>            warn(<span class=\"st\">&#39;The type of rng should be an instance of RandomStateType&#39;<\/span>)<\/span>\n<span id=\"orgb7517e4-205\"><a href=\"#orgb7517e4-205\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-206\"><a href=\"#orgb7517e4-206\" aria-hidden=\"true\"><\/a>        bcast <span class=\"op\">=<\/span> <span class=\"va\">self<\/span>.compute_bcast(dist_params, size)<\/span>\n<span id=\"orgb7517e4-207\"><a href=\"#orgb7517e4-207\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-208\"><a href=\"#orgb7517e4-208\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># dtype = tt.scal.upcast(self.dtype, *[p.dtype for p in dist_params])<\/span><\/span>\n<span id=\"orgb7517e4-209\"><a href=\"#orgb7517e4-209\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-210\"><a href=\"#orgb7517e4-210\" aria-hidden=\"true\"><\/a>        outtype <span class=\"op\">=<\/span> tt.TensorType(dtype<span class=\"op\">=<\/span><span class=\"va\">self<\/span>.dtype, broadcastable<span class=\"op\">=<\/span>bcast)<\/span>\n<span id=\"orgb7517e4-211\"><a href=\"#orgb7517e4-211\" aria-hidden=\"true\"><\/a>        out_var <span class=\"op\">=<\/span> outtype(name<span 
class=\"op\">=<\/span>name)<\/span>\n<span id=\"orgb7517e4-212\"><a href=\"#orgb7517e4-212\" aria-hidden=\"true\"><\/a>        inputs <span class=\"op\">=<\/span> dist_params <span class=\"op\">+<\/span> (size, rng)<\/span>\n<span id=\"orgb7517e4-213\"><a href=\"#orgb7517e4-213\" aria-hidden=\"true\"><\/a>        outputs <span class=\"op\">=<\/span> (rng.<span class=\"bu\">type<\/span>(), out_var)<\/span>\n<span id=\"orgb7517e4-214\"><a href=\"#orgb7517e4-214\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-215\"><a href=\"#orgb7517e4-215\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> theano.gof.Apply(<span class=\"va\">self<\/span>, inputs, outputs)<\/span>\n<span id=\"orgb7517e4-216\"><a href=\"#orgb7517e4-216\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-217\"><a href=\"#orgb7517e4-217\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> perform(<span class=\"va\">self<\/span>, node, inputs, outputs):<\/span>\n<span id=\"orgb7517e4-218\"><a href=\"#orgb7517e4-218\" aria-hidden=\"true\"><\/a>        <span class=\"co\">&quot;&quot;&quot;Draw samples using Numpy\/SciPy.&quot;&quot;&quot;<\/span><\/span>\n<span id=\"orgb7517e4-219\"><a href=\"#orgb7517e4-219\" aria-hidden=\"true\"><\/a>        rng_out, smpl_out <span class=\"op\">=<\/span> outputs<\/span>\n<span id=\"orgb7517e4-220\"><a href=\"#orgb7517e4-220\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-221\"><a href=\"#orgb7517e4-221\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng`<\/span><\/span>\n<span id=\"orgb7517e4-222\"><a href=\"#orgb7517e4-222\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># otherwise.<\/span><\/span>\n<span id=\"orgb7517e4-223\"><a href=\"#orgb7517e4-223\" aria-hidden=\"true\"><\/a>        args <span class=\"op\">=<\/span> <span class=\"bu\">list<\/span>(inputs)<\/span>\n<span id=\"orgb7517e4-224\"><a href=\"#orgb7517e4-224\" 
aria-hidden=\"true\"><\/a>        rng <span class=\"op\">=<\/span> args.pop()<\/span>\n<span id=\"orgb7517e4-225\"><a href=\"#orgb7517e4-225\" aria-hidden=\"true\"><\/a>        size <span class=\"op\">=<\/span> args.pop()<\/span>\n<span id=\"orgb7517e4-226\"><a href=\"#orgb7517e4-226\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-227\"><a href=\"#orgb7517e4-227\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">assert<\/span> <span class=\"bu\">isinstance<\/span>(rng, np.random.RandomState), (<span class=\"bu\">type<\/span>(rng), rng)<\/span>\n<span id=\"orgb7517e4-228\"><a href=\"#orgb7517e4-228\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-229\"><a href=\"#orgb7517e4-229\" aria-hidden=\"true\"><\/a>        rng_out[<span class=\"dv\">0<\/span>] <span class=\"op\">=<\/span> rng<\/span>\n<span id=\"orgb7517e4-230\"><a href=\"#orgb7517e4-230\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-231\"><a href=\"#orgb7517e4-231\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># The symbolic output variable corresponding to value produced here.<\/span><\/span>\n<span id=\"orgb7517e4-232\"><a href=\"#orgb7517e4-232\" aria-hidden=\"true\"><\/a>        out_var <span class=\"op\">=<\/span> node.outputs[<span class=\"dv\">1<\/span>]<\/span>\n<span id=\"orgb7517e4-233\"><a href=\"#orgb7517e4-233\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-234\"><a href=\"#orgb7517e4-234\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># If `size == []`, that means no size is enforced, and NumPy is<\/span><\/span>\n<span id=\"orgb7517e4-235\"><a href=\"#orgb7517e4-235\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># trusted to draw the appropriate number of samples, NumPy uses<\/span><\/span>\n<span id=\"orgb7517e4-236\"><a href=\"#orgb7517e4-236\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># `size=None` to represent that.  
Otherwise, NumPy expects a tuple.<\/span><\/span>\n<span id=\"orgb7517e4-237\"><a href=\"#orgb7517e4-237\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> np.size(size) <span class=\"op\">==<\/span> <span class=\"dv\">0<\/span>:<\/span>\n<span id=\"orgb7517e4-238\"><a href=\"#orgb7517e4-238\" aria-hidden=\"true\"><\/a>            size <span class=\"op\">=<\/span> <span class=\"va\">None<\/span><\/span>\n<span id=\"orgb7517e4-239\"><a href=\"#orgb7517e4-239\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"orgb7517e4-240\"><a href=\"#orgb7517e4-240\" aria-hidden=\"true\"><\/a>            size <span class=\"op\">=<\/span> <span class=\"bu\">tuple<\/span>(size)<\/span>\n<span id=\"orgb7517e4-241\"><a href=\"#orgb7517e4-241\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-242\"><a href=\"#orgb7517e4-242\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> <span class=\"kw\">not<\/span> <span class=\"va\">self<\/span>.inplace:<\/span>\n<span id=\"orgb7517e4-243\"><a href=\"#orgb7517e4-243\" aria-hidden=\"true\"><\/a>            rng <span class=\"op\">=<\/span> copy(rng)<\/span>\n<span id=\"orgb7517e4-244\"><a href=\"#orgb7517e4-244\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-245\"><a href=\"#orgb7517e4-245\" aria-hidden=\"true\"><\/a>        smpl_val <span class=\"op\">=<\/span> <span class=\"va\">self<\/span>.rng_fn(rng, <span class=\"op\">*<\/span>(args <span class=\"op\">+<\/span> [size]))<\/span>\n<span id=\"orgb7517e4-246\"><a href=\"#orgb7517e4-246\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-247\"><a href=\"#orgb7517e4-247\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> (<span class=\"kw\">not<\/span> <span class=\"bu\">isinstance<\/span>(smpl_val, np.ndarray) <span class=\"kw\">or<\/span><\/span>\n<span id=\"orgb7517e4-248\"><a href=\"#orgb7517e4-248\" aria-hidden=\"true\"><\/a>            <span class=\"bu\">str<\/span>(smpl_val.dtype) 
<span class=\"op\">!=<\/span> out_var.<span class=\"bu\">type<\/span>.dtype):<\/span>\n<span id=\"orgb7517e4-249\"><a href=\"#orgb7517e4-249\" aria-hidden=\"true\"><\/a>            smpl_val <span class=\"op\">=<\/span> theano._asarray(smpl_val, dtype<span class=\"op\">=<\/span>out_var.<span class=\"bu\">type<\/span>.dtype)<\/span>\n<span id=\"orgb7517e4-250\"><a href=\"#orgb7517e4-250\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-251\"><a href=\"#orgb7517e4-251\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># When `size` is `None`, NumPy has a tendency to unexpectedly<\/span><\/span>\n<span id=\"orgb7517e4-252\"><a href=\"#orgb7517e4-252\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># return a scalar instead of a higher-dimension array containing<\/span><\/span>\n<span id=\"orgb7517e4-253\"><a href=\"#orgb7517e4-253\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># only one element. This value should be reshaped<\/span><\/span>\n<span id=\"orgb7517e4-254\"><a href=\"#orgb7517e4-254\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># <\/span><span class=\"al\">TODO<\/span><span class=\"co\">: Really?  Why shouldn&#39;t the output correctly correspond to<\/span><\/span>\n<span id=\"orgb7517e4-255\"><a href=\"#orgb7517e4-255\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># the returned NumPy value?  
Sounds more like a mis-specification of<\/span><\/span>\n<span id=\"orgb7517e4-256\"><a href=\"#orgb7517e4-256\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># the symbolic output variable.<\/span><\/span>\n<span id=\"orgb7517e4-257\"><a href=\"#orgb7517e4-257\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> size <span class=\"kw\">is<\/span> <span class=\"va\">None<\/span> <span class=\"kw\">and<\/span> smpl_val.ndim <span class=\"op\">==<\/span> <span class=\"dv\">0<\/span> <span class=\"kw\">and<\/span> out_var.ndim <span class=\"op\">&gt;<\/span> <span class=\"dv\">0<\/span>:<\/span>\n<span id=\"orgb7517e4-258\"><a href=\"#orgb7517e4-258\" aria-hidden=\"true\"><\/a>            smpl_val <span class=\"op\">=<\/span> smpl_val.reshape([<span class=\"dv\">1<\/span>] <span class=\"op\">*<\/span> out_var.ndim)<\/span>\n<span id=\"orgb7517e4-259\"><a href=\"#orgb7517e4-259\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-260\"><a href=\"#orgb7517e4-260\" aria-hidden=\"true\"><\/a>        smpl_out[<span class=\"dv\">0<\/span>] <span class=\"op\">=<\/span> smpl_val<\/span>\n<span id=\"orgb7517e4-261\"><a href=\"#orgb7517e4-261\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-262\"><a href=\"#orgb7517e4-262\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> grad(<span class=\"va\">self<\/span>, inputs, outputs):<\/span>\n<span id=\"orgb7517e4-263\"><a href=\"#orgb7517e4-263\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> [theano.gradient.grad_undefined(<span class=\"va\">self<\/span>, k, inp,<\/span>\n<span id=\"orgb7517e4-264\"><a href=\"#orgb7517e4-264\" aria-hidden=\"true\"><\/a>                                               <span class=\"st\">&#39;No gradient defined through raw random numbers op&#39;<\/span>)<\/span>\n<span id=\"orgb7517e4-265\"><a href=\"#orgb7517e4-265\" aria-hidden=\"true\"><\/a>                <span class=\"cf\">for<\/span> k, inp <span class=\"kw\">in<\/span> <span 
class=\"bu\">enumerate<\/span>(inputs)]<\/span>\n<span id=\"orgb7517e4-266\"><a href=\"#orgb7517e4-266\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgb7517e4-267\"><a href=\"#orgb7517e4-267\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> R_op(<span class=\"va\">self<\/span>, inputs, eval_points):<\/span>\n<span id=\"orgb7517e4-268\"><a href=\"#orgb7517e4-268\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> [<span class=\"va\">None<\/span> <span class=\"cf\">for<\/span> i <span class=\"kw\">in<\/span> eval_points]<\/span><\/code><\/pre><\/div>\n<\/section>\n<section id=\"using-randomvariable\" class=\"level1\">\n<h1>Using <code>RandomVariable<\/code><\/h1>\n<p>In Listing <a href=\"#orgf494cec\">17<\/a> we create some <code>RandomVariable<\/code> <code>Op<\/code>s.<\/p>\n<div class=\"sourceCode\" id=\"orgf494cec\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgf494cec-1\"><a href=\"#orgf494cec-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> scipy<\/span>\n<span id=\"orgf494cec-2\"><a href=\"#orgf494cec-2\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> functools <span class=\"im\">import<\/span> partial<\/span>\n<span id=\"orgf494cec-3\"><a href=\"#orgf494cec-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-4\"><a href=\"#orgf494cec-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-5\"><a href=\"#orgf494cec-5\" aria-hidden=\"true\"><\/a><span class=\"co\"># Continuous Numpy-generated variates<\/span><\/span>\n<span id=\"orgf494cec-6\"><a href=\"#orgf494cec-6\" aria-hidden=\"true\"><\/a><span class=\"kw\">class<\/span> UniformRVType(RandomVariable):<\/span>\n<span id=\"orgf494cec-7\"><a href=\"#orgf494cec-7\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__init__<\/span>(<span class=\"va\">self<\/span>):<\/span>\n<span id=\"orgf494cec-8\"><a href=\"#orgf494cec-8\" aria-hidden=\"true\"><\/a>        <span 
class=\"bu\">super<\/span>().<span class=\"fu\">__init__<\/span>(<span class=\"st\">&#39;uniform&#39;<\/span>, theano.config.floatX, <span class=\"dv\">0<\/span>, [<span class=\"dv\">0<\/span>, <span class=\"dv\">0<\/span>], <span class=\"st\">&#39;uniform&#39;<\/span>, inplace<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"orgf494cec-9\"><a href=\"#orgf494cec-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-10\"><a href=\"#orgf494cec-10\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> make_node(<span class=\"va\">self<\/span>, lower, upper, size<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, rng<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, name<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"orgf494cec-11\"><a href=\"#orgf494cec-11\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> <span class=\"bu\">super<\/span>().make_node(lower, upper, size<span class=\"op\">=<\/span>size, rng<span class=\"op\">=<\/span>rng, name<span class=\"op\">=<\/span>name)<\/span>\n<span id=\"orgf494cec-12\"><a href=\"#orgf494cec-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-13\"><a href=\"#orgf494cec-13\" aria-hidden=\"true\"><\/a>UniformRV <span class=\"op\">=<\/span> UniformRVType()<\/span>\n<span id=\"orgf494cec-14\"><a href=\"#orgf494cec-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-15\"><a href=\"#orgf494cec-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-16\"><a href=\"#orgf494cec-16\" aria-hidden=\"true\"><\/a><span class=\"kw\">class<\/span> NormalRVType(RandomVariable):<\/span>\n<span id=\"orgf494cec-17\"><a href=\"#orgf494cec-17\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__init__<\/span>(<span class=\"va\">self<\/span>):<\/span>\n<span id=\"orgf494cec-18\"><a href=\"#orgf494cec-18\" aria-hidden=\"true\"><\/a>        <span 
class=\"bu\">super<\/span>().<span class=\"fu\">__init__<\/span>(<span class=\"st\">&#39;normal&#39;<\/span>, theano.config.floatX, <span class=\"dv\">0<\/span>, [<span class=\"dv\">0<\/span>, <span class=\"dv\">0<\/span>], <span class=\"st\">&#39;normal&#39;<\/span>, inplace<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"orgf494cec-19\"><a href=\"#orgf494cec-19\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-20\"><a href=\"#orgf494cec-20\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> make_node(<span class=\"va\">self<\/span>, mu, sigma, size<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, rng<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, name<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"orgf494cec-21\"><a href=\"#orgf494cec-21\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> <span class=\"bu\">super<\/span>().make_node(mu, sigma, size<span class=\"op\">=<\/span>size, rng<span class=\"op\">=<\/span>rng, name<span class=\"op\">=<\/span>name)<\/span>\n<span id=\"orgf494cec-22\"><a href=\"#orgf494cec-22\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-23\"><a href=\"#orgf494cec-23\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-24\"><a href=\"#orgf494cec-24\" aria-hidden=\"true\"><\/a>NormalRV <span class=\"op\">=<\/span> NormalRVType()<\/span>\n<span id=\"orgf494cec-25\"><a href=\"#orgf494cec-25\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-26\"><a href=\"#orgf494cec-26\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-27\"><a href=\"#orgf494cec-27\" aria-hidden=\"true\"><\/a><span class=\"kw\">class<\/span> GammaRVType(RandomVariable):<\/span>\n<span id=\"orgf494cec-28\"><a href=\"#orgf494cec-28\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__init__<\/span>(<span class=\"va\">self<\/span>):<\/span>\n<span id=\"orgf494cec-29\"><a 
href=\"#orgf494cec-29\" aria-hidden=\"true\"><\/a>        <span class=\"bu\">super<\/span>().<span class=\"fu\">__init__<\/span>(<span class=\"st\">&#39;gamma&#39;<\/span>, theano.config.floatX, <span class=\"dv\">0<\/span>, [<span class=\"dv\">0<\/span>, <span class=\"dv\">0<\/span>], <span class=\"st\">&#39;gamma&#39;<\/span>, inplace<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"orgf494cec-30\"><a href=\"#orgf494cec-30\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-31\"><a href=\"#orgf494cec-31\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> make_node(<span class=\"va\">self<\/span>, shape, scale, size<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, rng<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, name<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"orgf494cec-32\"><a href=\"#orgf494cec-32\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> <span class=\"bu\">super<\/span>().make_node(shape, scale, size<span class=\"op\">=<\/span>size, rng<span class=\"op\">=<\/span>rng, name<span class=\"op\">=<\/span>name)<\/span>\n<span id=\"orgf494cec-33\"><a href=\"#orgf494cec-33\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-34\"><a href=\"#orgf494cec-34\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-35\"><a href=\"#orgf494cec-35\" aria-hidden=\"true\"><\/a>GammaRV <span class=\"op\">=<\/span> GammaRVType()<\/span>\n<span id=\"orgf494cec-36\"><a href=\"#orgf494cec-36\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-37\"><a href=\"#orgf494cec-37\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-38\"><a href=\"#orgf494cec-38\" aria-hidden=\"true\"><\/a><span class=\"kw\">class<\/span> ExponentialRVType(RandomVariable):<\/span>\n<span id=\"orgf494cec-39\"><a href=\"#orgf494cec-39\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__init__<\/span>(<span 
class=\"va\">self<\/span>):<\/span>\n<span id=\"orgf494cec-40\"><a href=\"#orgf494cec-40\" aria-hidden=\"true\"><\/a>        <span class=\"bu\">super<\/span>().<span class=\"fu\">__init__<\/span>(<span class=\"st\">&#39;exponential&#39;<\/span>, theano.config.floatX, <span class=\"dv\">0<\/span>, [<span class=\"dv\">0<\/span>], <span class=\"st\">&#39;exponential&#39;<\/span>, inplace<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"orgf494cec-41\"><a href=\"#orgf494cec-41\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-42\"><a href=\"#orgf494cec-42\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> make_node(<span class=\"va\">self<\/span>, scale, size<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, rng<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, name<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"orgf494cec-43\"><a href=\"#orgf494cec-43\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> <span class=\"bu\">super<\/span>().make_node(scale, size<span class=\"op\">=<\/span>size, rng<span class=\"op\">=<\/span>rng, name<span class=\"op\">=<\/span>name)<\/span>\n<span id=\"orgf494cec-44\"><a href=\"#orgf494cec-44\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-45\"><a href=\"#orgf494cec-45\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-46\"><a href=\"#orgf494cec-46\" aria-hidden=\"true\"><\/a>ExponentialRV <span class=\"op\">=<\/span> ExponentialRVType()<\/span>\n<span id=\"orgf494cec-47\"><a href=\"#orgf494cec-47\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-48\"><a href=\"#orgf494cec-48\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-49\"><a href=\"#orgf494cec-49\" aria-hidden=\"true\"><\/a><span class=\"co\"># One with multivariate support<\/span><\/span>\n<span id=\"orgf494cec-50\"><a href=\"#orgf494cec-50\" aria-hidden=\"true\"><\/a><span class=\"kw\">class<\/span> 
MvNormalRVType(RandomVariable):<\/span>\n<span id=\"orgf494cec-51\"><a href=\"#orgf494cec-51\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__init__<\/span>(<span class=\"va\">self<\/span>):<\/span>\n<span id=\"orgf494cec-52\"><a href=\"#orgf494cec-52\" aria-hidden=\"true\"><\/a>        <span class=\"bu\">super<\/span>().<span class=\"fu\">__init__<\/span>(<span class=\"st\">&#39;multivariate_normal&#39;<\/span>, theano.config.floatX, <span class=\"dv\">1<\/span>, [<span class=\"dv\">1<\/span>, <span class=\"dv\">2<\/span>], <span class=\"st\">&#39;multivariate_normal&#39;<\/span>, inplace<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"orgf494cec-53\"><a href=\"#orgf494cec-53\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-54\"><a href=\"#orgf494cec-54\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> make_node(<span class=\"va\">self<\/span>, mean, cov, size<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, rng<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, name<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"orgf494cec-55\"><a href=\"#orgf494cec-55\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> <span class=\"bu\">super<\/span>().make_node(mean, cov, size<span class=\"op\">=<\/span>size, rng<span class=\"op\">=<\/span>rng, name<span class=\"op\">=<\/span>name)<\/span>\n<span id=\"orgf494cec-56\"><a href=\"#orgf494cec-56\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-57\"><a href=\"#orgf494cec-57\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-58\"><a href=\"#orgf494cec-58\" aria-hidden=\"true\"><\/a>MvNormalRV <span class=\"op\">=<\/span> MvNormalRVType()<\/span>\n<span id=\"orgf494cec-59\"><a href=\"#orgf494cec-59\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-60\"><a href=\"#orgf494cec-60\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-61\"><a 
href=\"#orgf494cec-61\" aria-hidden=\"true\"><\/a><span class=\"kw\">class<\/span> DirichletRVType(RandomVariable):<\/span>\n<span id=\"orgf494cec-62\"><a href=\"#orgf494cec-62\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__init__<\/span>(<span class=\"va\">self<\/span>):<\/span>\n<span id=\"orgf494cec-63\"><a href=\"#orgf494cec-63\" aria-hidden=\"true\"><\/a>        <span class=\"bu\">super<\/span>().<span class=\"fu\">__init__<\/span>(<span class=\"st\">&#39;dirichlet&#39;<\/span>, theano.config.floatX, <span class=\"dv\">1<\/span>, [<span class=\"dv\">1<\/span>], <span class=\"st\">&#39;dirichlet&#39;<\/span>, inplace<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"orgf494cec-64\"><a href=\"#orgf494cec-64\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-65\"><a href=\"#orgf494cec-65\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> make_node(<span class=\"va\">self<\/span>, alpha, size<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, rng<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, name<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"orgf494cec-66\"><a href=\"#orgf494cec-66\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> <span class=\"bu\">super<\/span>().make_node(alpha, size<span class=\"op\">=<\/span>size, rng<span class=\"op\">=<\/span>rng, name<span class=\"op\">=<\/span>name)<\/span>\n<span id=\"orgf494cec-67\"><a href=\"#orgf494cec-67\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-68\"><a href=\"#orgf494cec-68\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-69\"><a href=\"#orgf494cec-69\" aria-hidden=\"true\"><\/a>DirichletRV <span class=\"op\">=<\/span> DirichletRVType()<\/span>\n<span id=\"orgf494cec-70\"><a href=\"#orgf494cec-70\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-71\"><a href=\"#orgf494cec-71\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-72\"><a href=\"#orgf494cec-72\" aria-hidden=\"true\"><\/a><span class=\"co\"># A discrete Numpy-generated variate<\/span><\/span>\n<span id=\"orgf494cec-73\"><a href=\"#orgf494cec-73\" aria-hidden=\"true\"><\/a><span class=\"kw\">class<\/span> PoissonRVType(RandomVariable):<\/span>\n<span id=\"orgf494cec-74\"><a href=\"#orgf494cec-74\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__init__<\/span>(<span class=\"va\">self<\/span>):<\/span>\n<span id=\"orgf494cec-75\"><a href=\"#orgf494cec-75\" aria-hidden=\"true\"><\/a>        <span class=\"bu\">super<\/span>().<span class=\"fu\">__init__<\/span>(<span class=\"st\">&#39;poisson&#39;<\/span>, <span class=\"st\">&#39;int64&#39;<\/span>, <span class=\"dv\">0<\/span>, [<span class=\"dv\">0<\/span>], <span class=\"st\">&#39;poisson&#39;<\/span>, inplace<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"orgf494cec-76\"><a href=\"#orgf494cec-76\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-77\"><a href=\"#orgf494cec-77\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> make_node(<span class=\"va\">self<\/span>, rate, size<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, rng<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, name<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"orgf494cec-78\"><a href=\"#orgf494cec-78\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> <span class=\"bu\">super<\/span>().make_node(rate, size<span class=\"op\">=<\/span>size, rng<span class=\"op\">=<\/span>rng, name<span class=\"op\">=<\/span>name)<\/span>\n<span id=\"orgf494cec-79\"><a href=\"#orgf494cec-79\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-80\"><a href=\"#orgf494cec-80\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-81\"><a href=\"#orgf494cec-81\" aria-hidden=\"true\"><\/a>PoissonRV <span 
class=\"op\">=<\/span> PoissonRVType()<\/span>\n<span id=\"orgf494cec-82\"><a href=\"#orgf494cec-82\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-83\"><a href=\"#orgf494cec-83\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-84\"><a href=\"#orgf494cec-84\" aria-hidden=\"true\"><\/a><span class=\"co\"># A SciPy-generated variate<\/span><\/span>\n<span id=\"orgf494cec-85\"><a href=\"#orgf494cec-85\" aria-hidden=\"true\"><\/a><span class=\"kw\">class<\/span> CauchyRVType(RandomVariable):<\/span>\n<span id=\"orgf494cec-86\"><a href=\"#orgf494cec-86\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__init__<\/span>(<span class=\"va\">self<\/span>):<\/span>\n<span id=\"orgf494cec-87\"><a href=\"#orgf494cec-87\" aria-hidden=\"true\"><\/a>        <span class=\"bu\">super<\/span>().<span class=\"fu\">__init__<\/span>(<span class=\"st\">&#39;cauchy&#39;<\/span>, theano.config.floatX, <span class=\"dv\">0<\/span>, [<span class=\"dv\">0<\/span>, <span class=\"dv\">0<\/span>],<\/span>\n<span id=\"orgf494cec-88\"><a href=\"#orgf494cec-88\" aria-hidden=\"true\"><\/a>                         <span class=\"kw\">lambda<\/span> rng, <span class=\"op\">*<\/span>args: scipy.stats.cauchy.rvs(<span class=\"op\">*<\/span>args, random_state<span class=\"op\">=<\/span>rng),<\/span>\n<span id=\"orgf494cec-89\"><a href=\"#orgf494cec-89\" aria-hidden=\"true\"><\/a>                         inplace<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"orgf494cec-90\"><a href=\"#orgf494cec-90\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-91\"><a href=\"#orgf494cec-91\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> make_node(<span class=\"va\">self<\/span>, loc, scale, size<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, rng<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, name<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span 
id=\"orgf494cec-92\"><a href=\"#orgf494cec-92\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> <span class=\"bu\">super<\/span>().make_node(loc, scale, size<span class=\"op\">=<\/span>size, rng<span class=\"op\">=<\/span>rng, name<span class=\"op\">=<\/span>name)<\/span>\n<span id=\"orgf494cec-93\"><a href=\"#orgf494cec-93\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-94\"><a href=\"#orgf494cec-94\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-95\"><a href=\"#orgf494cec-95\" aria-hidden=\"true\"><\/a>CauchyRV <span class=\"op\">=<\/span> CauchyRVType()<\/span>\n<span id=\"orgf494cec-96\"><a href=\"#orgf494cec-96\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-97\"><a href=\"#orgf494cec-97\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-98\"><a href=\"#orgf494cec-98\" aria-hidden=\"true\"><\/a><span class=\"co\"># Support shape is determined by the first dimension in the *second* parameter (i.e.<\/span><\/span>\n<span id=\"orgf494cec-99\"><a href=\"#orgf494cec-99\" aria-hidden=\"true\"><\/a><span class=\"co\"># the probabilities vector)<\/span><\/span>\n<span id=\"orgf494cec-100\"><a href=\"#orgf494cec-100\" aria-hidden=\"true\"><\/a><span class=\"kw\">class<\/span> MultinomialRVType(RandomVariable):<\/span>\n<span id=\"orgf494cec-101\"><a href=\"#orgf494cec-101\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__init__<\/span>(<span class=\"va\">self<\/span>):<\/span>\n<span id=\"orgf494cec-102\"><a href=\"#orgf494cec-102\" aria-hidden=\"true\"><\/a>        <span class=\"bu\">super<\/span>().<span class=\"fu\">__init__<\/span>(<span class=\"st\">&#39;multinomial&#39;<\/span>, <span class=\"st\">&#39;int64&#39;<\/span>, <span class=\"dv\">1<\/span>, [<span class=\"dv\">0<\/span>, <span class=\"dv\">1<\/span>], <span class=\"st\">&#39;multinomial&#39;<\/span>,<\/span>\n<span id=\"orgf494cec-103\"><a href=\"#orgf494cec-103\" aria-hidden=\"true\"><\/a>                
         supp_shape_fn<span class=\"op\">=<\/span>partial(param_supp_shape_fn, rep_param_idx<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>),<\/span>\n<span id=\"orgf494cec-104\"><a href=\"#orgf494cec-104\" aria-hidden=\"true\"><\/a>                         inplace<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"orgf494cec-105\"><a href=\"#orgf494cec-105\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-106\"><a href=\"#orgf494cec-106\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> make_node(<span class=\"va\">self<\/span>, n, pvals, size<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, rng<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, name<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"orgf494cec-107\"><a href=\"#orgf494cec-107\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> <span class=\"bu\">super<\/span>().make_node(n, pvals, size<span class=\"op\">=<\/span>size, rng<span class=\"op\">=<\/span>rng, name<span class=\"op\">=<\/span>name)<\/span>\n<span id=\"orgf494cec-108\"><a href=\"#orgf494cec-108\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-109\"><a href=\"#orgf494cec-109\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgf494cec-110\"><a href=\"#orgf494cec-110\" aria-hidden=\"true\"><\/a>MultinomialRV <span class=\"op\">=<\/span> MultinomialRVType()<\/span><\/code><\/pre><\/div>\n<div class=\"example\" data-markdown=\"\">\n<p>In Listing <a href=\"#orged59f09\">18<\/a> we draw samples from instances of <code>RandomVariable<\/code>s.<\/p>\n<div class=\"sourceCode\" id=\"orged59f09\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orged59f09-1\"><a href=\"#orged59f09-1\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;UniformRV(0., 30., size=[10]):<\/span><span class=\"ch\">\\n<\/span><span class=\"sc\">{}<\/span><span 
class=\"ch\">\\n<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"orged59f09-2\"><a href=\"#orged59f09-2\" aria-hidden=\"true\"><\/a>    UniformRV(<span class=\"fl\">0.<\/span>, <span class=\"fl\">30.<\/span>, size<span class=\"op\">=<\/span>[<span class=\"dv\">10<\/span>]).<span class=\"bu\">eval<\/span>()<\/span>\n<span id=\"orged59f09-3\"><a href=\"#orged59f09-3\" aria-hidden=\"true\"><\/a>))<\/span>\n<span id=\"orged59f09-4\"><a href=\"#orged59f09-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orged59f09-5\"><a href=\"#orged59f09-5\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;NormalRV([0., 100.], 30, size=[4, 2]):<\/span><span class=\"ch\">\\n<\/span><span class=\"sc\">{}<\/span><span class=\"ch\">\\n<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"orged59f09-6\"><a href=\"#orged59f09-6\" aria-hidden=\"true\"><\/a>    NormalRV([<span class=\"fl\">0.<\/span>, <span class=\"fl\">100.<\/span>], <span class=\"dv\">30<\/span>, size<span class=\"op\">=<\/span>[<span class=\"dv\">4<\/span>, <span class=\"dv\">2<\/span>]).<span class=\"bu\">eval<\/span>()))<\/span>\n<span id=\"orged59f09-7\"><a href=\"#orged59f09-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orged59f09-8\"><a href=\"#orged59f09-8\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;GammaRV([2., 1.], 2., size=[4, 2]):<\/span><span class=\"ch\">\\n<\/span><span class=\"sc\">{}<\/span><span class=\"ch\">\\n<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"orged59f09-9\"><a href=\"#orged59f09-9\" aria-hidden=\"true\"><\/a>    GammaRV([<span class=\"fl\">2.<\/span>, <span class=\"fl\">1.<\/span>], <span class=\"fl\">2.<\/span>, size<span class=\"op\">=<\/span>[<span class=\"dv\">4<\/span>, <span class=\"dv\">2<\/span>]).<span class=\"bu\">eval<\/span>()))<\/span>\n<span 
id=\"orged59f09-10\"><a href=\"#orged59f09-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orged59f09-11\"><a href=\"#orged59f09-11\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;ExponentialRV([2., 50.], size=[4, 2]):<\/span><span class=\"ch\">\\n<\/span><span class=\"sc\">{}<\/span><span class=\"ch\">\\n<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"orged59f09-12\"><a href=\"#orged59f09-12\" aria-hidden=\"true\"><\/a>    ExponentialRV([<span class=\"fl\">2.<\/span>, <span class=\"fl\">50.<\/span>], size<span class=\"op\">=<\/span>[<span class=\"dv\">4<\/span>, <span class=\"dv\">2<\/span>]).<span class=\"bu\">eval<\/span>()))<\/span>\n<span id=\"orged59f09-13\"><a href=\"#orged59f09-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orged59f09-14\"><a href=\"#orged59f09-14\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;MvNormalRV([0, 1e2, 2e3], np.diag([1, 1, 1]), size=[3, 2, 3]):<\/span><span class=\"ch\">\\n<\/span><span class=\"sc\">{}<\/span><span class=\"ch\">\\n<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"orged59f09-15\"><a href=\"#orged59f09-15\" aria-hidden=\"true\"><\/a>    MvNormalRV([<span class=\"dv\">0<\/span>, <span class=\"fl\">1e2<\/span>, <span class=\"fl\">2e3<\/span>], np.diag([<span class=\"dv\">1<\/span>, <span class=\"dv\">1<\/span>, <span class=\"dv\">1<\/span>]), size<span class=\"op\">=<\/span>[<span class=\"dv\">2<\/span>, <span class=\"dv\">3<\/span>]).<span class=\"bu\">eval<\/span>()))<\/span>\n<span id=\"orged59f09-16\"><a href=\"#orged59f09-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orged59f09-17\"><a href=\"#orged59f09-17\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;DirichletRV([0.1, 10, 0.5], size=[3, 2, 3]):<\/span><span class=\"ch\">\\n<\/span><span class=\"sc\">{}<\/span><span 
class=\"ch\">\\n<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"orged59f09-18\"><a href=\"#orged59f09-18\" aria-hidden=\"true\"><\/a>    DirichletRV([<span class=\"fl\">0.1<\/span>, <span class=\"dv\">10<\/span>, <span class=\"fl\">0.5<\/span>], size<span class=\"op\">=<\/span>[<span class=\"dv\">2<\/span>, <span class=\"dv\">3<\/span>]).<span class=\"bu\">eval<\/span>()))<\/span>\n<span id=\"orged59f09-19\"><a href=\"#orged59f09-19\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orged59f09-20\"><a href=\"#orged59f09-20\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;PoissonRV([2., 1.], size=[4, 2]):<\/span><span class=\"ch\">\\n<\/span><span class=\"sc\">{}<\/span><span class=\"ch\">\\n<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"orged59f09-21\"><a href=\"#orged59f09-21\" aria-hidden=\"true\"><\/a>    PoissonRV([<span class=\"fl\">2.<\/span>, <span class=\"fl\">15.<\/span>], size<span class=\"op\">=<\/span>[<span class=\"dv\">4<\/span>, <span class=\"dv\">2<\/span>]).<span class=\"bu\">eval<\/span>()))<\/span>\n<span id=\"orged59f09-22\"><a href=\"#orged59f09-22\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orged59f09-23\"><a href=\"#orged59f09-23\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;CauchyRV([1., 100.], 30, size=[4, 2]):<\/span><span class=\"ch\">\\n<\/span><span class=\"sc\">{}<\/span><span class=\"ch\">\\n<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"orged59f09-24\"><a href=\"#orged59f09-24\" aria-hidden=\"true\"><\/a>    CauchyRV([<span class=\"fl\">1.<\/span>, <span class=\"fl\">100.<\/span>], <span class=\"dv\">30<\/span>, size<span class=\"op\">=<\/span>[<span class=\"dv\">4<\/span>, <span class=\"dv\">2<\/span>]).<span class=\"bu\">eval<\/span>()))<\/span>\n<span id=\"orged59f09-25\"><a href=\"#orged59f09-25\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orged59f09-26\"><a href=\"#orged59f09-26\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;MultinomialRV(20, [1\/6.]*6, size=[6, 2]):<\/span><span class=\"ch\">\\n<\/span><span class=\"sc\">{}<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"orged59f09-27\"><a href=\"#orged59f09-27\" aria-hidden=\"true\"><\/a>    MultinomialRV(<span class=\"dv\">20<\/span>, [<span class=\"dv\">1<\/span> <span class=\"op\">\/<\/span> <span class=\"fl\">6.<\/span>] <span class=\"op\">*<\/span> <span class=\"dv\">6<\/span>, size<span class=\"op\">=<\/span>[<span class=\"dv\">3<\/span>, <span class=\"dv\">2<\/span>]).<span class=\"bu\">eval<\/span>()))<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"orga89dd83\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orga89dd83-1\"><a href=\"#orga89dd83-1\" aria-hidden=\"true\"><\/a>UniformRV(<span class=\"fl\">0.<\/span>, <span class=\"fl\">30.<\/span>, size<span class=\"op\">=<\/span>[<span class=\"dv\">10<\/span>]):<\/span>\n<span id=\"orga89dd83-2\"><a href=\"#orga89dd83-2\" aria-hidden=\"true\"><\/a>[ <span class=\"fl\">5.83131933<\/span> <span class=\"fl\">28.56231204<\/span> <span class=\"fl\">20.73018065<\/span> <span class=\"fl\">17.21042461<\/span> <span class=\"fl\">25.53140341<\/span> <span class=\"fl\">23.76268637<\/span><\/span>\n<span id=\"orga89dd83-3\"><a href=\"#orga89dd83-3\" aria-hidden=\"true\"><\/a> <span class=\"fl\">28.27629994<\/span>  <span class=\"fl\">7.10457399<\/span> <span class=\"fl\">19.88378878<\/span> <span class=\"fl\">26.62382369<\/span>]<\/span>\n<span id=\"orga89dd83-4\"><a href=\"#orga89dd83-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orga89dd83-5\"><a href=\"#orga89dd83-5\" aria-hidden=\"true\"><\/a>NormalRV([<span class=\"fl\">0.<\/span>, <span class=\"fl\">100.<\/span>], <span class=\"dv\">30<\/span>, size<span 
class=\"op\">=<\/span>[<span class=\"dv\">4<\/span>, <span class=\"dv\">2<\/span>]):<\/span>\n<span id=\"orga89dd83-6\"><a href=\"#orga89dd83-6\" aria-hidden=\"true\"><\/a>[[  <span class=\"fl\">0.73277898<\/span>  <span class=\"fl\">98.26041204<\/span>]<\/span>\n<span id=\"orga89dd83-7\"><a href=\"#orga89dd83-7\" aria-hidden=\"true\"><\/a> [<span class=\"op\">-<\/span><span class=\"fl\">25.9810085<\/span>   <span class=\"fl\">79.13385495<\/span>]<\/span>\n<span id=\"orga89dd83-8\"><a href=\"#orga89dd83-8\" aria-hidden=\"true\"><\/a> [<span class=\"op\">-<\/span><span class=\"fl\">23.17013683<\/span> <span class=\"fl\">130.86966242<\/span>]<\/span>\n<span id=\"orga89dd83-9\"><a href=\"#orga89dd83-9\" aria-hidden=\"true\"><\/a> [<span class=\"op\">-<\/span><span class=\"fl\">52.83756722<\/span>  <span class=\"fl\">95.21829178<\/span>]]<\/span>\n<span id=\"orga89dd83-10\"><a href=\"#orga89dd83-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orga89dd83-11\"><a href=\"#orga89dd83-11\" aria-hidden=\"true\"><\/a>GammaRV([<span class=\"fl\">2.<\/span>, <span class=\"fl\">1.<\/span>], <span class=\"fl\">2.<\/span>, size<span class=\"op\">=<\/span>[<span class=\"dv\">4<\/span>, <span class=\"dv\">2<\/span>]):<\/span>\n<span id=\"orga89dd83-12\"><a href=\"#orga89dd83-12\" aria-hidden=\"true\"><\/a>[[<span class=\"fl\">5.09679154<\/span> <span class=\"fl\">0.6149213<\/span> ]<\/span>\n<span id=\"orga89dd83-13\"><a href=\"#orga89dd83-13\" aria-hidden=\"true\"><\/a> [<span class=\"fl\">2.64231927<\/span> <span class=\"fl\">0.7277265<\/span> ]<\/span>\n<span id=\"orga89dd83-14\"><a href=\"#orga89dd83-14\" aria-hidden=\"true\"><\/a> [<span class=\"fl\">5.98877316<\/span> <span class=\"fl\">0.41751667<\/span>]<\/span>\n<span id=\"orga89dd83-15\"><a href=\"#orga89dd83-15\" aria-hidden=\"true\"><\/a> [<span class=\"fl\">3.77525439<\/span> <span class=\"fl\">1.11561567<\/span>]]<\/span>\n<span id=\"orga89dd83-16\"><a href=\"#orga89dd83-16\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orga89dd83-17\"><a href=\"#orga89dd83-17\" aria-hidden=\"true\"><\/a>ExponentialRV([<span class=\"fl\">2.<\/span>, <span class=\"fl\">50.<\/span>], size<span class=\"op\">=<\/span>[<span class=\"dv\">4<\/span>, <span class=\"dv\">2<\/span>]):<\/span>\n<span id=\"orga89dd83-18\"><a href=\"#orga89dd83-18\" aria-hidden=\"true\"><\/a>[[ <span class=\"fl\">2.29684191<\/span>  <span class=\"fl\">7.12084933<\/span>]<\/span>\n<span id=\"orga89dd83-19\"><a href=\"#orga89dd83-19\" aria-hidden=\"true\"><\/a> [ <span class=\"fl\">0.39386731<\/span> <span class=\"fl\">38.79158981<\/span>]<\/span>\n<span id=\"orga89dd83-20\"><a href=\"#orga89dd83-20\" aria-hidden=\"true\"><\/a> [ <span class=\"fl\">1.11400165<\/span>  <span class=\"fl\">4.31175303<\/span>]<\/span>\n<span id=\"orga89dd83-21\"><a href=\"#orga89dd83-21\" aria-hidden=\"true\"><\/a> [ <span class=\"fl\">1.50499115<\/span>  <span class=\"fl\">9.65667649<\/span>]]<\/span>\n<span id=\"orga89dd83-22\"><a href=\"#orga89dd83-22\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orga89dd83-23\"><a href=\"#orga89dd83-23\" aria-hidden=\"true\"><\/a>MvNormalRV([<span class=\"dv\">0<\/span>, <span class=\"fl\">1e2<\/span>, <span class=\"fl\">2e3<\/span>], np.diag([<span class=\"dv\">1<\/span>, <span class=\"dv\">1<\/span>, <span class=\"dv\">1<\/span>]), size<span class=\"op\">=<\/span>[<span class=\"dv\">3<\/span>, <span class=\"dv\">2<\/span>, <span class=\"dv\">3<\/span>]):<\/span>\n<span id=\"orga89dd83-24\"><a href=\"#orga89dd83-24\" aria-hidden=\"true\"><\/a>[[[<span class=\"op\">-<\/span><span class=\"fl\">6.67447019e-01<\/span>  <span class=\"fl\">9.88636435e+01<\/span>  <span class=\"fl\">1.99973471e+03<\/span>]<\/span>\n<span id=\"orga89dd83-25\"><a href=\"#orga89dd83-25\" aria-hidden=\"true\"><\/a>  [ <span class=\"fl\">6.06351715e-01<\/span>  <span class=\"fl\">9.96429347e+01<\/span>  <span class=\"fl\">1.99915978e+03<\/span>]<\/span>\n<span id=\"orga89dd83-26\"><a 
href=\"#orga89dd83-26\" aria-hidden=\"true\"><\/a>  [ <span class=\"fl\">1.12246741e+00<\/span>  <span class=\"fl\">9.96807860e+01<\/span>  <span class=\"fl\">2.00201859e+03<\/span>]]<\/span>\n<span id=\"orga89dd83-27\"><a href=\"#orga89dd83-27\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orga89dd83-28\"><a href=\"#orga89dd83-28\" aria-hidden=\"true\"><\/a> [[ <span class=\"fl\">3.61931404e-02<\/span>  <span class=\"fl\">9.89907880e+01<\/span>  <span class=\"fl\">2.00036910e+03<\/span>]<\/span>\n<span id=\"orga89dd83-29\"><a href=\"#orga89dd83-29\" aria-hidden=\"true\"><\/a>  [<span class=\"op\">-<\/span><span class=\"fl\">1.61077330e+00<\/span>  <span class=\"fl\">1.01905479e+02<\/span>  <span class=\"fl\">2.00134565e+03<\/span>]<\/span>\n<span id=\"orga89dd83-30\"><a href=\"#orga89dd83-30\" aria-hidden=\"true\"><\/a>  [ <span class=\"fl\">9.45854243e-01<\/span>  <span class=\"fl\">1.00877071e+02<\/span>  <span class=\"fl\">1.99914438e+03<\/span>]]]<\/span>\n<span id=\"orga89dd83-31\"><a href=\"#orga89dd83-31\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orga89dd83-32\"><a href=\"#orga89dd83-32\" aria-hidden=\"true\"><\/a>DirichletRV([<span class=\"fl\">0.1<\/span>, <span class=\"dv\">10<\/span>, <span class=\"fl\">0.5<\/span>], size<span class=\"op\">=<\/span>[<span class=\"dv\">3<\/span>, <span class=\"dv\">2<\/span>, <span class=\"dv\">3<\/span>]):<\/span>\n<span id=\"orga89dd83-33\"><a href=\"#orga89dd83-33\" aria-hidden=\"true\"><\/a>[[[<span class=\"fl\">1.41863953e-06<\/span> <span class=\"fl\">9.35392908e-01<\/span> <span class=\"fl\">6.46056738e-02<\/span>]<\/span>\n<span id=\"orga89dd83-34\"><a href=\"#orga89dd83-34\" aria-hidden=\"true\"><\/a>  [<span class=\"fl\">4.50961569e-15<\/span> <span class=\"fl\">9.71338820e-01<\/span> <span class=\"fl\">2.86611803e-02<\/span>]<\/span>\n<span id=\"orga89dd83-35\"><a href=\"#orga89dd83-35\" aria-hidden=\"true\"><\/a>  [<span class=\"fl\">2.41299980e-05<\/span> <span class=\"fl\">9.94566812e-01<\/span> 
<span class=\"fl\">5.40905817e-03<\/span>]]<\/span>\n<span id=\"orga89dd83-36\"><a href=\"#orga89dd83-36\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orga89dd83-37\"><a href=\"#orga89dd83-37\" aria-hidden=\"true\"><\/a> [[<span class=\"fl\">5.79850503e-08<\/span> <span class=\"fl\">9.73090671e-01<\/span> <span class=\"fl\">2.69092713e-02<\/span>]<\/span>\n<span id=\"orga89dd83-38\"><a href=\"#orga89dd83-38\" aria-hidden=\"true\"><\/a>  [<span class=\"fl\">4.17758767e-09<\/span> <span class=\"fl\">9.61671733e-01<\/span> <span class=\"fl\">3.83282630e-02<\/span>]<\/span>\n<span id=\"orga89dd83-39\"><a href=\"#orga89dd83-39\" aria-hidden=\"true\"><\/a>  [<span class=\"fl\">8.78921782e-03<\/span> <span class=\"fl\">9.54146972e-01<\/span> <span class=\"fl\">3.70638103e-02<\/span>]]]<\/span>\n<span id=\"orga89dd83-40\"><a href=\"#orga89dd83-40\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orga89dd83-41\"><a href=\"#orga89dd83-41\" aria-hidden=\"true\"><\/a>PoissonRV([<span class=\"fl\">2.<\/span>, <span class=\"fl\">1.<\/span>], size<span class=\"op\">=<\/span>[<span class=\"dv\">4<\/span>, <span class=\"dv\">2<\/span>]):<\/span>\n<span id=\"orga89dd83-42\"><a href=\"#orga89dd83-42\" aria-hidden=\"true\"><\/a>[[ <span class=\"dv\">1<\/span> <span class=\"dv\">15<\/span>]<\/span>\n<span id=\"orga89dd83-43\"><a href=\"#orga89dd83-43\" aria-hidden=\"true\"><\/a> [ <span class=\"dv\">1<\/span> <span class=\"dv\">12<\/span>]<\/span>\n<span id=\"orga89dd83-44\"><a href=\"#orga89dd83-44\" aria-hidden=\"true\"><\/a> [ <span class=\"dv\">2<\/span> <span class=\"dv\">21<\/span>]<\/span>\n<span id=\"orga89dd83-45\"><a href=\"#orga89dd83-45\" aria-hidden=\"true\"><\/a> [ <span class=\"dv\">1<\/span> <span class=\"dv\">14<\/span>]]<\/span>\n<span id=\"orga89dd83-46\"><a href=\"#orga89dd83-46\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orga89dd83-47\"><a href=\"#orga89dd83-47\" aria-hidden=\"true\"><\/a>CauchyRV([<span class=\"fl\">1.<\/span>, <span 
class=\"fl\">100.<\/span>], <span class=\"dv\">30<\/span>, size<span class=\"op\">=<\/span>[<span class=\"dv\">4<\/span>, <span class=\"dv\">2<\/span>]):<\/span>\n<span id=\"orga89dd83-48\"><a href=\"#orga89dd83-48\" aria-hidden=\"true\"><\/a>[[ <span class=\"op\">-<\/span><span class=\"fl\">86.93222925<\/span>   <span class=\"fl\">79.9758127<\/span> ]<\/span>\n<span id=\"orga89dd83-49\"><a href=\"#orga89dd83-49\" aria-hidden=\"true\"><\/a> [  <span class=\"fl\">13.41882831<\/span> <span class=\"op\">-<\/span><span class=\"fl\">374.41779179<\/span>]<\/span>\n<span id=\"orga89dd83-50\"><a href=\"#orga89dd83-50\" aria-hidden=\"true\"><\/a> [  <span class=\"fl\">75.74505567<\/span>   <span class=\"fl\">93.2944822<\/span> ]<\/span>\n<span id=\"orga89dd83-51\"><a href=\"#orga89dd83-51\" aria-hidden=\"true\"><\/a> [  <span class=\"fl\">30.0824262<\/span>   <span class=\"fl\">130.40873511<\/span>]]<\/span>\n<span id=\"orga89dd83-52\"><a href=\"#orga89dd83-52\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orga89dd83-53\"><a href=\"#orga89dd83-53\" aria-hidden=\"true\"><\/a>MultinomialRV(<span class=\"dv\">20<\/span>, [<span class=\"dv\">1<\/span><span class=\"op\">\/<\/span><span class=\"fl\">6.<\/span>]<span class=\"op\">*<\/span><span class=\"dv\">6<\/span>, size<span class=\"op\">=<\/span>[<span class=\"dv\">6<\/span>, <span class=\"dv\">2<\/span>]):<\/span>\n<span id=\"orga89dd83-54\"><a href=\"#orga89dd83-54\" aria-hidden=\"true\"><\/a>[[[<span class=\"dv\">2<\/span> <span class=\"dv\">4<\/span> <span class=\"dv\">4<\/span> <span class=\"dv\">2<\/span> <span class=\"dv\">4<\/span> <span class=\"dv\">4<\/span>]<\/span>\n<span id=\"orga89dd83-55\"><a href=\"#orga89dd83-55\" aria-hidden=\"true\"><\/a>  [<span class=\"dv\">2<\/span> <span class=\"dv\">5<\/span> <span class=\"dv\">2<\/span> <span class=\"dv\">4<\/span> <span class=\"dv\">3<\/span> <span class=\"dv\">4<\/span>]]<\/span>\n<span id=\"orga89dd83-56\"><a href=\"#orga89dd83-56\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orga89dd83-57\"><a href=\"#orga89dd83-57\" aria-hidden=\"true\"><\/a> [[<span class=\"dv\">2<\/span> <span class=\"dv\">5<\/span> <span class=\"dv\">6<\/span> <span class=\"dv\">2<\/span> <span class=\"dv\">4<\/span> <span class=\"dv\">1<\/span>]<\/span>\n<span id=\"orga89dd83-58\"><a href=\"#orga89dd83-58\" aria-hidden=\"true\"><\/a>  [<span class=\"dv\">0<\/span> <span class=\"dv\">4<\/span> <span class=\"dv\">4<\/span> <span class=\"dv\">3<\/span> <span class=\"dv\">5<\/span> <span class=\"dv\">4<\/span>]]<\/span>\n<span id=\"orga89dd83-59\"><a href=\"#orga89dd83-59\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orga89dd83-60\"><a href=\"#orga89dd83-60\" aria-hidden=\"true\"><\/a> [[<span class=\"dv\">6<\/span> <span class=\"dv\">1<\/span> <span class=\"dv\">1<\/span> <span class=\"dv\">4<\/span> <span class=\"dv\">4<\/span> <span class=\"dv\">4<\/span>]<\/span>\n<span id=\"orga89dd83-61\"><a href=\"#orga89dd83-61\" aria-hidden=\"true\"><\/a>  [<span class=\"dv\">3<\/span> <span class=\"dv\">4<\/span> <span class=\"dv\">3<\/span> <span class=\"dv\">2<\/span> <span class=\"dv\">3<\/span> <span class=\"dv\">5<\/span>]]]<\/span>\n<span id=\"orga89dd83-62\"><a href=\"#orga89dd83-62\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<\/div>\n<p>As noted, there are a few long-standing difficulties surrounding the use and determination of shape information in PyMC3. 
<code>RandomVariable<\/code> doesn\u2019t suffer the same limitations.<\/p>\n<div class=\"example\" data-markdown=\"\">\n<p>In Listing <a href=\"#org77fdfa2\">20<\/a>, we see that a multivariate normal random variable cannot be created in PyMC3 without explicit shape information.<\/p>\n<div class=\"sourceCode\" id=\"org77fdfa2\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org77fdfa2-1\"><a href=\"#org77fdfa2-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> traceback<\/span>\n<span id=\"org77fdfa2-2\"><a href=\"#org77fdfa2-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org77fdfa2-3\"><a href=\"#org77fdfa2-3\" aria-hidden=\"true\"><\/a>test_mean <span class=\"op\">=<\/span> tt.vector(<span class=\"st\">&#39;test_mean&#39;<\/span>)<\/span>\n<span id=\"org77fdfa2-4\"><a href=\"#org77fdfa2-4\" aria-hidden=\"true\"><\/a>test_cov <span class=\"op\">=<\/span> tt.matrix(<span class=\"st\">&#39;test_cov&#39;<\/span>, dtype<span class=\"op\">=<\/span><span class=\"st\">&#39;int64&#39;<\/span>)<\/span>\n<span id=\"org77fdfa2-5\"><a href=\"#org77fdfa2-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org77fdfa2-6\"><a href=\"#org77fdfa2-6\" aria-hidden=\"true\"><\/a>test_mean.tag.test_value <span class=\"op\">=<\/span> np.asarray([<span class=\"dv\">1<\/span>])<\/span>\n<span id=\"org77fdfa2-7\"><a href=\"#org77fdfa2-7\" aria-hidden=\"true\"><\/a>test_cov.tag.test_value <span class=\"op\">=<\/span> np.asarray([[<span class=\"dv\">1<\/span>]])<\/span>\n<span id=\"org77fdfa2-8\"><a href=\"#org77fdfa2-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org77fdfa2-9\"><a href=\"#org77fdfa2-9\" aria-hidden=\"true\"><\/a><span class=\"cf\">try<\/span>:<\/span>\n<span id=\"org77fdfa2-10\"><a href=\"#org77fdfa2-10\" aria-hidden=\"true\"><\/a>  <span class=\"cf\">with<\/span> pm.Model():<\/span>\n<span id=\"org77fdfa2-11\"><a href=\"#org77fdfa2-11\" aria-hidden=\"true\"><\/a>    test_rv <span class=\"op\">=<\/span> pm.MvNormal(<span 
class=\"st\">&#39;test_rv&#39;<\/span>, test_mean, test_cov)<\/span>\n<span id=\"org77fdfa2-12\"><a href=\"#org77fdfa2-12\" aria-hidden=\"true\"><\/a><span class=\"cf\">except<\/span> <span class=\"pp\">Exception<\/span> <span class=\"im\">as<\/span> e:<\/span>\n<span id=\"org77fdfa2-13\"><a href=\"#org77fdfa2-13\" aria-hidden=\"true\"><\/a>  <span class=\"bu\">print<\/span>(<span class=\"st\">&quot;&quot;<\/span>.join(traceback.format_exception_only(<span class=\"bu\">type<\/span>(e), e)))<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"org9a55c25\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org9a55c25-1\"><a href=\"#org9a55c25-1\" aria-hidden=\"true\"><\/a><span class=\"pp\">ValueError<\/span>: Invalid dimension <span class=\"cf\">for<\/span> value: <span class=\"dv\">0<\/span><\/span>\n<span id=\"org9a55c25-2\"><a href=\"#org9a55c25-2\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<p>As Listing <a href=\"#orgb7414f6\">22<\/a> demonstrates, the same construction is possible when one specifies an explicit size\/shape.<\/p>\n<div class=\"sourceCode\" id=\"orgb7414f6\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgb7414f6-1\"><a href=\"#orgb7414f6-1\" aria-hidden=\"true\"><\/a><span class=\"cf\">try<\/span>:<\/span>\n<span id=\"orgb7414f6-2\"><a href=\"#orgb7414f6-2\" aria-hidden=\"true\"><\/a>  <span class=\"cf\">with<\/span> pm.Model():<\/span>\n<span id=\"orgb7414f6-3\"><a href=\"#orgb7414f6-3\" aria-hidden=\"true\"><\/a>    test_rv <span class=\"op\">=<\/span> pm.MvNormal(<span class=\"st\">&#39;test_rv&#39;<\/span>, test_mean, test_cov, shape<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>)<\/span>\n<span id=\"orgb7414f6-4\"><a href=\"#orgb7414f6-4\" aria-hidden=\"true\"><\/a>    <span class=\"bu\">print<\/span>(<span class=\"st\">&quot;test_rv.distribution.shape = <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&quot;<\/span>.<span 
class=\"bu\">format<\/span>(test_rv.distribution.shape))<\/span>\n<span id=\"orgb7414f6-5\"><a href=\"#orgb7414f6-5\" aria-hidden=\"true\"><\/a>    <span class=\"bu\">print<\/span>(<span class=\"st\">&quot;test_rv.tag.test_value = <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(test_rv.tag.test_value))<\/span>\n<span id=\"orgb7414f6-6\"><a href=\"#orgb7414f6-6\" aria-hidden=\"true\"><\/a><span class=\"cf\">except<\/span> <span class=\"pp\">Exception<\/span> <span class=\"im\">as<\/span> e:<\/span>\n<span id=\"orgb7414f6-7\"><a href=\"#orgb7414f6-7\" aria-hidden=\"true\"><\/a>  <span class=\"bu\">print<\/span>(<span class=\"st\">&quot;&quot;<\/span>.join(traceback.format_exception_only(<span class=\"bu\">type<\/span>(e), e)))<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"orga8cc629\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orga8cc629-1\"><a href=\"#orga8cc629-1\" aria-hidden=\"true\"><\/a>test_rv.distribution.shape <span class=\"op\">=<\/span> [<span class=\"dv\">1<\/span>]<\/span>\n<span id=\"orga8cc629-2\"><a href=\"#orga8cc629-2\" aria-hidden=\"true\"><\/a>test_rv.tag.test_value <span class=\"op\">=<\/span> [<span class=\"fl\">1.<\/span>]<\/span>\n<span id=\"orga8cc629-3\"><a href=\"#orga8cc629-3\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<\/div>\n<p>Using <code>RandomVariable<\/code>, we do not have to specify a shape, nor implement any sampling code outside of <code>RandomVariable.perform<\/code> to draw random variables and generate valid test values.<\/p>\n<div class=\"example\" data-markdown=\"\">\n<p>Listings <a href=\"#org67b2727\">24<\/a> and <a href=\"#org56bde38\">26<\/a> demonstrate how easy it is to create dependencies between random variates using <code>RandomVariable<\/code>, and how sampling and test values are automatic. 
Both listings use a multivariate normal random variable as the mean of another multivariate normal.<\/p>\n<div class=\"sourceCode\" id=\"org67b2727\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org67b2727-1\"><a href=\"#org67b2727-1\" aria-hidden=\"true\"><\/a>theano.config.compute_test_value <span class=\"op\">=<\/span> <span class=\"st\">&#39;ignore&#39;<\/span><\/span>\n<span id=\"org67b2727-2\"><a href=\"#org67b2727-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org67b2727-3\"><a href=\"#org67b2727-3\" aria-hidden=\"true\"><\/a>mu_tt <span class=\"op\">=<\/span> tt.vector(<span class=\"st\">&#39;mu&#39;<\/span>)<\/span>\n<span id=\"org67b2727-4\"><a href=\"#org67b2727-4\" aria-hidden=\"true\"><\/a>C_tt <span class=\"op\">=<\/span> tt.matrix(<span class=\"st\">&#39;C&#39;<\/span>)<\/span>\n<span id=\"org67b2727-5\"><a href=\"#org67b2727-5\" aria-hidden=\"true\"><\/a>D_tt <span class=\"op\">=<\/span> tt.matrix(<span class=\"st\">&#39;D&#39;<\/span>)<\/span>\n<span id=\"org67b2727-6\"><a href=\"#org67b2727-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org67b2727-7\"><a href=\"#org67b2727-7\" aria-hidden=\"true\"><\/a>X_rv <span class=\"op\">=<\/span> MvNormalRV(mu_tt, C_tt)<\/span>\n<span id=\"org67b2727-8\"><a href=\"#org67b2727-8\" aria-hidden=\"true\"><\/a>Y_rv <span class=\"op\">=<\/span> MvNormalRV(X_rv, D_tt)<\/span>\n<span id=\"org67b2727-9\"><a href=\"#org67b2727-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org67b2727-10\"><a href=\"#org67b2727-10\" aria-hidden=\"true\"><\/a><span class=\"co\"># Sample some values under specific parameter values<\/span><\/span>\n<span id=\"org67b2727-11\"><a href=\"#org67b2727-11\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;<\/span><span class=\"sc\">{}<\/span><span class=\"st\"> ~ X<\/span><span class=\"ch\">\\n<\/span><span class=\"sc\">{}<\/span><span class=\"st\"> ~ Y&quot;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"org67b2727-12\"><a 
href=\"#org67b2727-12\" aria-hidden=\"true\"><\/a>    X_rv.<span class=\"bu\">eval<\/span>({mu_tt: [<span class=\"dv\">1<\/span>, <span class=\"dv\">2<\/span>], C_tt: np.diag([<span class=\"dv\">1<\/span>, <span class=\"dv\">2<\/span>])}),<\/span>\n<span id=\"org67b2727-13\"><a href=\"#org67b2727-13\" aria-hidden=\"true\"><\/a>    Y_rv.<span class=\"bu\">eval<\/span>({mu_tt: [<span class=\"dv\">1<\/span>, <span class=\"dv\">2<\/span>], C_tt: np.diag([<span class=\"dv\">1<\/span>, <span class=\"dv\">2<\/span>]), D_tt: np.diag([<span class=\"dv\">10<\/span>, <span class=\"dv\">20<\/span>])})))<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"orgd1cac3d\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgd1cac3d-1\"><a href=\"#orgd1cac3d-1\" aria-hidden=\"true\"><\/a>[<span class=\"op\">-<\/span><span class=\"fl\">1.25047147<\/span>  <span class=\"fl\">4.87459955<\/span>] <span class=\"op\">~<\/span> X<\/span>\n<span id=\"orgd1cac3d-2\"><a href=\"#orgd1cac3d-2\" aria-hidden=\"true\"><\/a>[ <span class=\"fl\">2.15486205<\/span> <span class=\"op\">-<\/span><span class=\"fl\">3.3066946<\/span> ] <span class=\"op\">~<\/span> Y<\/span>\n<span id=\"orgd1cac3d-3\"><a href=\"#orgd1cac3d-3\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"org56bde38\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org56bde38-1\"><a href=\"#org56bde38-1\" aria-hidden=\"true\"><\/a>theano.config.compute_test_value <span class=\"op\">=<\/span> <span class=\"st\">&#39;warn&#39;<\/span><\/span>\n<span id=\"org56bde38-2\"><a href=\"#org56bde38-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org56bde38-3\"><a href=\"#org56bde38-3\" aria-hidden=\"true\"><\/a>mu_tt.tag.test_value <span class=\"op\">=<\/span> np.array([<span class=\"dv\">0<\/span>, <span class=\"dv\">30<\/span>, <span class=\"dv\">40<\/span>])<\/span>\n<span id=\"org56bde38-4\"><a href=\"#org56bde38-4\" 
aria-hidden=\"true\"><\/a>C_tt.tag.test_value <span class=\"op\">=<\/span> np.diag([<span class=\"dv\">100<\/span>, <span class=\"dv\">10<\/span>, <span class=\"dv\">1<\/span>])<\/span>\n<span id=\"org56bde38-5\"><a href=\"#org56bde38-5\" aria-hidden=\"true\"><\/a>D_tt.tag.test_value <span class=\"op\">=<\/span> np.diag([<span class=\"dv\">100<\/span>, <span class=\"dv\">10<\/span>, <span class=\"dv\">1<\/span>])<\/span>\n<span id=\"org56bde38-6\"><a href=\"#org56bde38-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org56bde38-7\"><a href=\"#org56bde38-7\" aria-hidden=\"true\"><\/a>X_rv <span class=\"op\">=<\/span> MvNormalRV(mu_tt, C_tt)<\/span>\n<span id=\"org56bde38-8\"><a href=\"#org56bde38-8\" aria-hidden=\"true\"><\/a>Y_rv <span class=\"op\">=<\/span> MvNormalRV(X_rv, D_tt)<\/span>\n<span id=\"org56bde38-9\"><a href=\"#org56bde38-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org56bde38-10\"><a href=\"#org56bde38-10\" aria-hidden=\"true\"><\/a><span class=\"co\"># Observe the automatically generated test values<\/span><\/span>\n<span id=\"org56bde38-11\"><a href=\"#org56bde38-11\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;X test value: <\/span><span class=\"sc\">{}<\/span><span class=\"ch\">\\n<\/span><span class=\"st\">Y test value: <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"org56bde38-12\"><a href=\"#org56bde38-12\" aria-hidden=\"true\"><\/a>    X_rv.tag.test_value,<\/span>\n<span id=\"org56bde38-13\"><a href=\"#org56bde38-13\" aria-hidden=\"true\"><\/a>    Y_rv.tag.test_value))<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"orgecacfb5\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgecacfb5-1\"><a href=\"#orgecacfb5-1\" aria-hidden=\"true\"><\/a>X test value: [ <span class=\"fl\">1.78826967<\/span> <span class=\"fl\">28.73266332<\/span> <span 
class=\"fl\">38.57297111<\/span>]<\/span>\n<span id=\"orgecacfb5-2\"><a href=\"#orgecacfb5-2\" aria-hidden=\"true\"><\/a>Y test value: [<span class=\"fl\">33.93703352<\/span> <span class=\"fl\">27.48925582<\/span> <span class=\"fl\">38.21563854<\/span>]<\/span>\n<span id=\"orgecacfb5-3\"><a href=\"#orgecacfb5-3\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<\/div>\n<div class=\"example\" data-markdown=\"\">\n<p>In Listing <a href=\"#orge489ad8\">28<\/a>, we specify the following hierarchical model:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation*}\n  \\begin{aligned}\n    M &amp;\\sim \\text{Poisson}\\left(10\\right)\n    \\\\\n    \\alpha_i &amp;\\sim \\text{Uniform}\\left(0, 1\\right),\n    \\quad i \\in \\left\\{1, \\dots, M\\right\\}\n    \\\\\n    \\pi &amp;\\sim \\text{Dirichlet}\\left(\\alpha\\right)\n    \\\\\n    Y &amp;\\sim \\text{Multinomial}\\left(M, \\pi\\right)\n  \\end{aligned}\n  \\;.\n\\end{equation*}\\]<\/span><\/p>\n<p>This toy model is particularly interesting in how it specifies symbolic dependencies between continuous and discrete distributions and uses random variables to determine the shapes of other random variables.<\/p>\n<div class=\"sourceCode\" id=\"orge489ad8\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orge489ad8-1\"><a href=\"#orge489ad8-1\" aria-hidden=\"true\"><\/a>theano.config.compute_test_value <span class=\"op\">=<\/span> <span class=\"st\">&#39;ignore&#39;<\/span><\/span>\n<span id=\"orge489ad8-2\"><a href=\"#orge489ad8-2\" aria-hidden=\"true\"><\/a>pois_rate <span class=\"op\">=<\/span> tt.dscalar(<span class=\"st\">&#39;rate&#39;<\/span>)<\/span>\n<span id=\"orge489ad8-3\"><a href=\"#orge489ad8-3\" aria-hidden=\"true\"><\/a>test_pois_rv <span class=\"op\">=<\/span> PoissonRV(pois_rate)<\/span>\n<span id=\"orge489ad8-4\"><a href=\"#orge489ad8-4\" aria-hidden=\"true\"><\/a>test_alpha <span class=\"op\">=<\/span> UniformRV(<span class=\"dv\">0<\/span>, <span 
class=\"dv\">1<\/span>, size<span class=\"op\">=<\/span>test_pois_rv)<\/span>\n<span id=\"orge489ad8-5\"><a href=\"#orge489ad8-5\" aria-hidden=\"true\"><\/a>test_dirichlet_rv <span class=\"op\">=<\/span> DirichletRV(test_alpha)<\/span>\n<span id=\"orge489ad8-6\"><a href=\"#orge489ad8-6\" aria-hidden=\"true\"><\/a>test_multinom_rv <span class=\"op\">=<\/span> MultinomialRV(test_pois_rv, test_dirichlet_rv)<\/span>\n<span id=\"orge489ad8-7\"><a href=\"#orge489ad8-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orge489ad8-8\"><a href=\"#orge489ad8-8\" aria-hidden=\"true\"><\/a>test_multinom_draw <span class=\"op\">=<\/span> theano.function(inputs<span class=\"op\">=<\/span>[], outputs<span class=\"op\">=<\/span>test_multinom_rv,<\/span>\n<span id=\"orge489ad8-9\"><a href=\"#orge489ad8-9\" aria-hidden=\"true\"><\/a>                                     givens<span class=\"op\">=<\/span>{pois_rate: <span class=\"fl\">10.<\/span>})<\/span>\n<span id=\"orge489ad8-10\"><a href=\"#orge489ad8-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orge489ad8-11\"><a href=\"#orge489ad8-11\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;test_multinom_rv draw 1: <\/span><span class=\"sc\">{}<\/span><span class=\"ch\">\\n<\/span><span class=\"st\">test_multinom_rv draw 2: <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"orge489ad8-12\"><a href=\"#orge489ad8-12\" aria-hidden=\"true\"><\/a>    test_multinom_draw(), test_multinom_draw()))<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"orgff8bd09\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgff8bd09-1\"><a href=\"#orgff8bd09-1\" aria-hidden=\"true\"><\/a>test_multinom_rv draw <span class=\"dv\">1<\/span>: [<span class=\"dv\">0<\/span> <span class=\"dv\">2<\/span> <span class=\"dv\">0<\/span> <span class=\"dv\">0<\/span> <span class=\"dv\">1<\/span> <span 
class=\"dv\">0<\/span> <span class=\"dv\">2<\/span> <span class=\"dv\">1<\/span> <span class=\"dv\">0<\/span> <span class=\"dv\">0<\/span>]<\/span>\n<span id=\"orgff8bd09-2\"><a href=\"#orgff8bd09-2\" aria-hidden=\"true\"><\/a>test_multinom_rv draw <span class=\"dv\">2<\/span>: [<span class=\"dv\">5<\/span> <span class=\"dv\">2<\/span> <span class=\"dv\">1<\/span> <span class=\"dv\">0<\/span> <span class=\"dv\">0<\/span> <span class=\"dv\">0<\/span> <span class=\"dv\">1<\/span> <span class=\"dv\">0<\/span> <span class=\"dv\">1<\/span> <span class=\"dv\">1<\/span> <span class=\"dv\">0<\/span> <span class=\"dv\">1<\/span> <span class=\"dv\">0<\/span>]<\/span>\n<span id=\"orgff8bd09-3\"><a href=\"#orgff8bd09-3\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<\/div>\n<section id=\"random-variable-pretty-printing\" class=\"level2\">\n<h2>Random Variable Pretty Printing<\/h2>\n<p>In Listing <a href=\"#org52ecc27\">30<\/a>, we implement a pretty printer that produces more readable forms of Theano graphs containing <code>RandomVariable<\/code> nodes.<\/p>\n<div class=\"sourceCode\" id=\"org52ecc27\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org52ecc27-1\"><a href=\"#org52ecc27-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">class<\/span> RandomVariablePrinter:<\/span>\n<span id=\"org52ecc27-2\"><a href=\"#org52ecc27-2\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;Pretty print random variables.<\/span><\/span>\n<span id=\"org52ecc27-3\"><a href=\"#org52ecc27-3\" aria-hidden=\"true\"><\/a><span class=\"co\">    &quot;&quot;&quot;<\/span><\/span>\n<span id=\"org52ecc27-4\"><a href=\"#org52ecc27-4\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__init__<\/span>(<span class=\"va\">self<\/span>, name<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"org52ecc27-5\"><a href=\"#org52ecc27-5\" aria-hidden=\"true\"><\/a>        <span 
class=\"co\">&quot;&quot;&quot;<\/span><\/span>\n<span id=\"org52ecc27-6\"><a href=\"#org52ecc27-6\" aria-hidden=\"true\"><\/a><span class=\"co\">        Parameters<\/span><\/span>\n<span id=\"org52ecc27-7\"><a href=\"#org52ecc27-7\" aria-hidden=\"true\"><\/a><span class=\"co\">        ==========<\/span><\/span>\n<span id=\"org52ecc27-8\"><a href=\"#org52ecc27-8\" aria-hidden=\"true\"><\/a><span class=\"co\">        name: str (optional)<\/span><\/span>\n<span id=\"org52ecc27-9\"><a href=\"#org52ecc27-9\" aria-hidden=\"true\"><\/a><span class=\"co\">            A fixed name to use for the random variables printed by this<\/span><\/span>\n<span id=\"org52ecc27-10\"><a href=\"#org52ecc27-10\" aria-hidden=\"true\"><\/a><span class=\"co\">            printer.  If not specified, use `RandomVariable.name`.<\/span><\/span>\n<span id=\"org52ecc27-11\"><a href=\"#org52ecc27-11\" aria-hidden=\"true\"><\/a><span class=\"co\">        &quot;&quot;&quot;<\/span><\/span>\n<span id=\"org52ecc27-12\"><a href=\"#org52ecc27-12\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.name <span class=\"op\">=<\/span> name<\/span>\n<span id=\"org52ecc27-13\"><a href=\"#org52ecc27-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org52ecc27-14\"><a href=\"#org52ecc27-14\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> process_param(<span class=\"va\">self<\/span>, idx, sform, pstate):<\/span>\n<span id=\"org52ecc27-15\"><a href=\"#org52ecc27-15\" aria-hidden=\"true\"><\/a>        <span class=\"co\">&quot;&quot;&quot;Special per-parameter post-formatting.<\/span><\/span>\n<span id=\"org52ecc27-16\"><a href=\"#org52ecc27-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org52ecc27-17\"><a href=\"#org52ecc27-17\" aria-hidden=\"true\"><\/a><span class=\"co\">        This can be used, for instance, to change a std. dev. 
into a variance.<\/span><\/span>\n<span id=\"org52ecc27-18\"><a href=\"#org52ecc27-18\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org52ecc27-19\"><a href=\"#org52ecc27-19\" aria-hidden=\"true\"><\/a><span class=\"co\">        Parameters<\/span><\/span>\n<span id=\"org52ecc27-20\"><a href=\"#org52ecc27-20\" aria-hidden=\"true\"><\/a><span class=\"co\">        ==========<\/span><\/span>\n<span id=\"org52ecc27-21\"><a href=\"#org52ecc27-21\" aria-hidden=\"true\"><\/a><span class=\"co\">        idx: int<\/span><\/span>\n<span id=\"org52ecc27-22\"><a href=\"#org52ecc27-22\" aria-hidden=\"true\"><\/a><span class=\"co\">            The index value of the parameter.<\/span><\/span>\n<span id=\"org52ecc27-23\"><a href=\"#org52ecc27-23\" aria-hidden=\"true\"><\/a><span class=\"co\">        sform: str<\/span><\/span>\n<span id=\"org52ecc27-24\"><a href=\"#org52ecc27-24\" aria-hidden=\"true\"><\/a><span class=\"co\">            The pre-formatted string form of the parameter.<\/span><\/span>\n<span id=\"org52ecc27-25\"><a href=\"#org52ecc27-25\" aria-hidden=\"true\"><\/a><span class=\"co\">        pstate: object<\/span><\/span>\n<span id=\"org52ecc27-26\"><a href=\"#org52ecc27-26\" aria-hidden=\"true\"><\/a><span class=\"co\">            The printer state.<\/span><\/span>\n<span id=\"org52ecc27-27\"><a href=\"#org52ecc27-27\" aria-hidden=\"true\"><\/a><span class=\"co\">        &quot;&quot;&quot;<\/span><\/span>\n<span id=\"org52ecc27-28\"><a href=\"#org52ecc27-28\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> sform<\/span>\n<span id=\"org52ecc27-29\"><a href=\"#org52ecc27-29\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org52ecc27-30\"><a href=\"#org52ecc27-30\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> process(<span class=\"va\">self<\/span>, output, pstate):<\/span>\n<span id=\"org52ecc27-31\"><a href=\"#org52ecc27-31\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> output <span class=\"kw\">in<\/span> 
pstate.memo:<\/span>\n<span id=\"org52ecc27-32\"><a href=\"#org52ecc27-32\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">return<\/span> pstate.memo[output]<\/span>\n<span id=\"org52ecc27-33\"><a href=\"#org52ecc27-33\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org52ecc27-34\"><a href=\"#org52ecc27-34\" aria-hidden=\"true\"><\/a>        pprinter <span class=\"op\">=<\/span> pstate.pprinter<\/span>\n<span id=\"org52ecc27-35\"><a href=\"#org52ecc27-35\" aria-hidden=\"true\"><\/a>        node <span class=\"op\">=<\/span> output.owner<\/span>\n<span id=\"org52ecc27-36\"><a href=\"#org52ecc27-36\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org52ecc27-37\"><a href=\"#org52ecc27-37\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> node <span class=\"kw\">is<\/span> <span class=\"va\">None<\/span> <span class=\"kw\">or<\/span> <span class=\"kw\">not<\/span> <span class=\"bu\">isinstance<\/span>(node.op, RandomVariable):<\/span>\n<span id=\"org52ecc27-38\"><a href=\"#org52ecc27-38\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">raise<\/span> <span class=\"pp\">TypeError<\/span>(<span class=\"st\">&quot;function <\/span><span class=\"sc\">%s<\/span><span class=\"st\"> cannot represent a variable that is &quot;<\/span><\/span>\n<span id=\"org52ecc27-39\"><a href=\"#org52ecc27-39\" aria-hidden=\"true\"><\/a>                            <span class=\"st\">&quot;not the result of a RandomVariable operation&quot;<\/span> <span class=\"op\">%<\/span><\/span>\n<span id=\"org52ecc27-40\"><a href=\"#org52ecc27-40\" aria-hidden=\"true\"><\/a>                            <span class=\"va\">self<\/span>.name)<\/span>\n<span id=\"org52ecc27-41\"><a href=\"#org52ecc27-41\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org52ecc27-42\"><a href=\"#org52ecc27-42\" aria-hidden=\"true\"><\/a>        new_precedence <span class=\"op\">=<\/span> <span class=\"op\">-<\/span><span class=\"dv\">1000<\/span><\/span>\n<span id=\"org52ecc27-43\"><a 
href=\"#org52ecc27-43\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">try<\/span>:<\/span>\n<span id=\"org52ecc27-44\"><a href=\"#org52ecc27-44\" aria-hidden=\"true\"><\/a>            old_precedence <span class=\"op\">=<\/span> <span class=\"bu\">getattr<\/span>(pstate, <span class=\"st\">&#39;precedence&#39;<\/span>, <span class=\"va\">None<\/span>)<\/span>\n<span id=\"org52ecc27-45\"><a href=\"#org52ecc27-45\" aria-hidden=\"true\"><\/a>            pstate.precedence <span class=\"op\">=<\/span> new_precedence<\/span>\n<span id=\"org52ecc27-46\"><a href=\"#org52ecc27-46\" aria-hidden=\"true\"><\/a>            out_name <span class=\"op\">=<\/span> VariableWithShapePrinter.process_variable_name(<\/span>\n<span id=\"org52ecc27-47\"><a href=\"#org52ecc27-47\" aria-hidden=\"true\"><\/a>                output, pstate)<\/span>\n<span id=\"org52ecc27-48\"><a href=\"#org52ecc27-48\" aria-hidden=\"true\"><\/a>            shape_info_str <span class=\"op\">=<\/span> VariableWithShapePrinter.process_shape_info(<\/span>\n<span id=\"org52ecc27-49\"><a href=\"#org52ecc27-49\" aria-hidden=\"true\"><\/a>                output, pstate)<\/span>\n<span id=\"org52ecc27-50\"><a href=\"#org52ecc27-50\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">if<\/span> <span class=\"bu\">getattr<\/span>(pstate, <span class=\"st\">&#39;latex&#39;<\/span>, <span class=\"va\">False<\/span>):<\/span>\n<span id=\"org52ecc27-51\"><a href=\"#org52ecc27-51\" aria-hidden=\"true\"><\/a>                dist_format <span class=\"op\">=<\/span> <span class=\"st\">&quot;<\/span><span class=\"sc\">%s<\/span><span class=\"st\"> <\/span><span class=\"ch\">\\\\<\/span><span class=\"st\">sim <\/span><span class=\"ch\">\\\\<\/span><span class=\"st\">operatorname{<\/span><span class=\"sc\">%s<\/span><span class=\"st\">}<\/span><span class=\"ch\">\\\\<\/span><span class=\"st\">left(<\/span><span class=\"sc\">%s<\/span><span class=\"ch\">\\\\<\/span><span class=\"st\">right)&quot;<\/span><\/span>\n<span 
id=\"org52ecc27-52\"><a href=\"#org52ecc27-52\" aria-hidden=\"true\"><\/a>                dist_format <span class=\"op\">+=<\/span> <span class=\"st\">&#39;, <\/span><span class=\"ch\">\\\\<\/span><span class=\"st\">quad <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&#39;<\/span>.<span class=\"bu\">format<\/span>(shape_info_str)<\/span>\n<span id=\"org52ecc27-53\"><a href=\"#org52ecc27-53\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"org52ecc27-54\"><a href=\"#org52ecc27-54\" aria-hidden=\"true\"><\/a>                dist_format <span class=\"op\">=<\/span> <span class=\"st\">&quot;<\/span><span class=\"sc\">%s<\/span><span class=\"st\"> ~ <\/span><span class=\"sc\">%s<\/span><span class=\"st\">(<\/span><span class=\"sc\">%s<\/span><span class=\"st\">)&quot;<\/span><\/span>\n<span id=\"org52ecc27-55\"><a href=\"#org52ecc27-55\" aria-hidden=\"true\"><\/a>                dist_format <span class=\"op\">+=<\/span> <span class=\"st\">&#39;,  <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&#39;<\/span>.<span class=\"bu\">format<\/span>(shape_info_str)<\/span>\n<span id=\"org52ecc27-56\"><a href=\"#org52ecc27-56\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org52ecc27-57\"><a href=\"#org52ecc27-57\" aria-hidden=\"true\"><\/a>            op_name <span class=\"op\">=<\/span> <span class=\"va\">self<\/span>.name <span class=\"kw\">or<\/span> node.op.name<\/span>\n<span id=\"org52ecc27-58\"><a href=\"#org52ecc27-58\" aria-hidden=\"true\"><\/a>            dist_params <span class=\"op\">=<\/span> node.inputs[:<span class=\"op\">-<\/span><span class=\"dv\">2<\/span>]<\/span>\n<span id=\"org52ecc27-59\"><a href=\"#org52ecc27-59\" aria-hidden=\"true\"><\/a>            formatted_params <span class=\"op\">=<\/span> [<\/span>\n<span id=\"org52ecc27-60\"><a href=\"#org52ecc27-60\" aria-hidden=\"true\"><\/a>                <span class=\"va\">self<\/span>.process_param(i, pprinter.process(p, pstate), 
pstate)<\/span>\n<span id=\"org52ecc27-61\"><a href=\"#org52ecc27-61\" aria-hidden=\"true\"><\/a>                <span class=\"cf\">for<\/span> i, p <span class=\"kw\">in<\/span> <span class=\"bu\">enumerate<\/span>(dist_params)<\/span>\n<span id=\"org52ecc27-62\"><a href=\"#org52ecc27-62\" aria-hidden=\"true\"><\/a>            ]<\/span>\n<span id=\"org52ecc27-63\"><a href=\"#org52ecc27-63\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org52ecc27-64\"><a href=\"#org52ecc27-64\" aria-hidden=\"true\"><\/a>            dist_params_r <span class=\"op\">=<\/span> dist_format <span class=\"op\">%<\/span> (out_name,<\/span>\n<span id=\"org52ecc27-65\"><a href=\"#org52ecc27-65\" aria-hidden=\"true\"><\/a>                                           op_name,<\/span>\n<span id=\"org52ecc27-66\"><a href=\"#org52ecc27-66\" aria-hidden=\"true\"><\/a>                                           <span class=\"st\">&quot;, &quot;<\/span>.join(formatted_params))<\/span>\n<span id=\"org52ecc27-67\"><a href=\"#org52ecc27-67\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">finally<\/span>:<\/span>\n<span id=\"org52ecc27-68\"><a href=\"#org52ecc27-68\" aria-hidden=\"true\"><\/a>            pstate.precedence <span class=\"op\">=<\/span> old_precedence<\/span>\n<span id=\"org52ecc27-69\"><a href=\"#org52ecc27-69\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org52ecc27-70\"><a href=\"#org52ecc27-70\" aria-hidden=\"true\"><\/a>        pstate.preamble_lines <span class=\"op\">+=<\/span> [dist_params_r]<\/span>\n<span id=\"org52ecc27-71\"><a href=\"#org52ecc27-71\" aria-hidden=\"true\"><\/a>        pstate.memo[output] <span class=\"op\">=<\/span> out_name<\/span>\n<span id=\"org52ecc27-72\"><a href=\"#org52ecc27-72\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org52ecc27-73\"><a href=\"#org52ecc27-73\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> out_name<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"org414cfc6\"><pre class=\"sourceCode 
python\"><code class=\"sourceCode python\"><span id=\"org414cfc6-1\"><a href=\"#org414cfc6-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> string<\/span>\n<span id=\"org414cfc6-2\"><a href=\"#org414cfc6-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-3\"><a href=\"#org414cfc6-3\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> copy <span class=\"im\">import<\/span> copy<\/span>\n<span id=\"org414cfc6-4\"><a href=\"#org414cfc6-4\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> collections <span class=\"im\">import<\/span> OrderedDict<\/span>\n<span id=\"org414cfc6-5\"><a href=\"#org414cfc6-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-6\"><a href=\"#org414cfc6-6\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> sympy <span class=\"im\">import<\/span> Array <span class=\"im\">as<\/span> SympyArray<\/span>\n<span id=\"org414cfc6-7\"><a href=\"#org414cfc6-7\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> sympy.printing <span class=\"im\">import<\/span> latex <span class=\"im\">as<\/span> sympy_latex<\/span>\n<span id=\"org414cfc6-8\"><a href=\"#org414cfc6-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-9\"><a href=\"#org414cfc6-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-10\"><a href=\"#org414cfc6-10\" aria-hidden=\"true\"><\/a><span class=\"kw\">class<\/span> VariableWithShapePrinter:<\/span>\n<span id=\"org414cfc6-11\"><a href=\"#org414cfc6-11\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;Print variable shape info in the preamble and use readable character<\/span><\/span>\n<span id=\"org414cfc6-12\"><a href=\"#org414cfc6-12\" aria-hidden=\"true\"><\/a><span class=\"co\">    names for unnamed variables.<\/span><\/span>\n<span id=\"org414cfc6-13\"><a href=\"#org414cfc6-13\" aria-hidden=\"true\"><\/a><span class=\"co\">    &quot;&quot;&quot;<\/span><\/span>\n<span id=\"org414cfc6-14\"><a href=\"#org414cfc6-14\" 
aria-hidden=\"true\"><\/a>    available_names <span class=\"op\">=<\/span> OrderedDict.fromkeys(string.ascii_letters)<\/span>\n<span id=\"org414cfc6-15\"><a href=\"#org414cfc6-15\" aria-hidden=\"true\"><\/a>    default_printer <span class=\"op\">=<\/span> theano.printing.default_printer<\/span>\n<span id=\"org414cfc6-16\"><a href=\"#org414cfc6-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-17\"><a href=\"#org414cfc6-17\" aria-hidden=\"true\"><\/a>    <span class=\"at\">@classmethod<\/span><\/span>\n<span id=\"org414cfc6-18\"><a href=\"#org414cfc6-18\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> process(cls, output, pstate):<\/span>\n<span id=\"org414cfc6-19\"><a href=\"#org414cfc6-19\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> output <span class=\"kw\">in<\/span> pstate.memo:<\/span>\n<span id=\"org414cfc6-20\"><a href=\"#org414cfc6-20\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">return<\/span> pstate.memo[output]<\/span>\n<span id=\"org414cfc6-21\"><a href=\"#org414cfc6-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-22\"><a href=\"#org414cfc6-22\" aria-hidden=\"true\"><\/a>        using_latex <span class=\"op\">=<\/span> <span class=\"bu\">getattr<\/span>(pstate, <span class=\"st\">&#39;latex&#39;<\/span>, <span class=\"va\">False<\/span>)<\/span>\n<span id=\"org414cfc6-23\"><a href=\"#org414cfc6-23\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-24\"><a href=\"#org414cfc6-24\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> <span class=\"bu\">isinstance<\/span>(output, tt.gof.Constant):<\/span>\n<span id=\"org414cfc6-25\"><a href=\"#org414cfc6-25\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">if<\/span> output.ndim <span class=\"op\">&gt;<\/span> <span class=\"dv\">0<\/span> <span class=\"kw\">and<\/span> using_latex:<\/span>\n<span id=\"org414cfc6-26\"><a href=\"#org414cfc6-26\" aria-hidden=\"true\"><\/a>                out_name <span 
class=\"op\">=<\/span> sympy_latex(SympyArray(output.data))<\/span>\n<span id=\"org414cfc6-27\"><a href=\"#org414cfc6-27\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"org414cfc6-28\"><a href=\"#org414cfc6-28\" aria-hidden=\"true\"><\/a>                out_name <span class=\"op\">=<\/span> <span class=\"bu\">str<\/span>(output.data)<\/span>\n<span id=\"org414cfc6-29\"><a href=\"#org414cfc6-29\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">elif<\/span> <span class=\"bu\">isinstance<\/span>(output, tt.TensorVariable):<\/span>\n<span id=\"org414cfc6-30\"><a href=\"#org414cfc6-30\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># Process name and shape<\/span><\/span>\n<span id=\"org414cfc6-31\"><a href=\"#org414cfc6-31\" aria-hidden=\"true\"><\/a>            out_name <span class=\"op\">=<\/span> cls.process_variable_name(output, pstate)<\/span>\n<span id=\"org414cfc6-32\"><a href=\"#org414cfc6-32\" aria-hidden=\"true\"><\/a>            shape_info <span class=\"op\">=<\/span> cls.process_shape_info(output, pstate)<\/span>\n<span id=\"org414cfc6-33\"><a href=\"#org414cfc6-33\" aria-hidden=\"true\"><\/a>            pstate.preamble_lines <span class=\"op\">+=<\/span> [shape_info]<\/span>\n<span id=\"org414cfc6-34\"><a href=\"#org414cfc6-34\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">elif<\/span> output.name:<\/span>\n<span id=\"org414cfc6-35\"><a href=\"#org414cfc6-35\" aria-hidden=\"true\"><\/a>            out_name <span class=\"op\">=<\/span> output.name<\/span>\n<span id=\"org414cfc6-36\"><a href=\"#org414cfc6-36\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"org414cfc6-37\"><a href=\"#org414cfc6-37\" aria-hidden=\"true\"><\/a>            out_name <span class=\"op\">=<\/span> cls.default_printer.process(output, pstate)<\/span>\n<span id=\"org414cfc6-38\"><a href=\"#org414cfc6-38\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-39\"><a 
href=\"#org414cfc6-39\" aria-hidden=\"true\"><\/a>        pstate.memo[output] <span class=\"op\">=<\/span> out_name<\/span>\n<span id=\"org414cfc6-40\"><a href=\"#org414cfc6-40\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> out_name<\/span>\n<span id=\"org414cfc6-41\"><a href=\"#org414cfc6-41\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-42\"><a href=\"#org414cfc6-42\" aria-hidden=\"true\"><\/a>    <span class=\"at\">@classmethod<\/span><\/span>\n<span id=\"org414cfc6-43\"><a href=\"#org414cfc6-43\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> process_shape_name(cls, output, pstate):<\/span>\n<span id=\"org414cfc6-44\"><a href=\"#org414cfc6-44\" aria-hidden=\"true\"><\/a>        shape_of_var <span class=\"op\">=<\/span> output.owner.inputs[<span class=\"dv\">0<\/span>]<\/span>\n<span id=\"org414cfc6-45\"><a href=\"#org414cfc6-45\" aria-hidden=\"true\"><\/a>        shape_names <span class=\"op\">=<\/span> pstate.memo.setdefault(<span class=\"st\">&#39;shape_names&#39;<\/span>, {})<\/span>\n<span id=\"org414cfc6-46\"><a href=\"#org414cfc6-46\" aria-hidden=\"true\"><\/a>        out_name <span class=\"op\">=<\/span> shape_names.setdefault(<\/span>\n<span id=\"org414cfc6-47\"><a href=\"#org414cfc6-47\" aria-hidden=\"true\"><\/a>            shape_of_var, cls.process_variable_name(output, pstate))<\/span>\n<span id=\"org414cfc6-48\"><a href=\"#org414cfc6-48\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> out_name<\/span>\n<span id=\"org414cfc6-49\"><a href=\"#org414cfc6-49\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-50\"><a href=\"#org414cfc6-50\" aria-hidden=\"true\"><\/a>    <span class=\"at\">@classmethod<\/span><\/span>\n<span id=\"org414cfc6-51\"><a href=\"#org414cfc6-51\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> process_variable_name(cls, output, pstate):<\/span>\n<span id=\"org414cfc6-52\"><a href=\"#org414cfc6-52\" aria-hidden=\"true\"><\/a>        <span 
class=\"cf\">if<\/span> output <span class=\"kw\">in<\/span> pstate.memo:<\/span>\n<span id=\"org414cfc6-53\"><a href=\"#org414cfc6-53\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">return<\/span> pstate.memo[output]<\/span>\n<span id=\"org414cfc6-54\"><a href=\"#org414cfc6-54\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-55\"><a href=\"#org414cfc6-55\" aria-hidden=\"true\"><\/a>        available_names <span class=\"op\">=<\/span> <span class=\"bu\">getattr<\/span>(pstate, <span class=\"st\">&#39;available_names&#39;<\/span>, <span class=\"va\">None<\/span>)<\/span>\n<span id=\"org414cfc6-56\"><a href=\"#org414cfc6-56\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> available_names <span class=\"kw\">is<\/span> <span class=\"va\">None<\/span>:<\/span>\n<span id=\"org414cfc6-57\"><a href=\"#org414cfc6-57\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># Initialize this state&#39;s available names<\/span><\/span>\n<span id=\"org414cfc6-58\"><a href=\"#org414cfc6-58\" aria-hidden=\"true\"><\/a>            available_names <span class=\"op\">=<\/span> copy(cls.available_names)<\/span>\n<span id=\"org414cfc6-59\"><a href=\"#org414cfc6-59\" aria-hidden=\"true\"><\/a>            fgraph <span class=\"op\">=<\/span> <span class=\"bu\">getattr<\/span>(output, <span class=\"st\">&#39;fgraph&#39;<\/span>, <span class=\"va\">None<\/span>)<\/span>\n<span id=\"org414cfc6-60\"><a href=\"#org414cfc6-60\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">if<\/span> fgraph:<\/span>\n<span id=\"org414cfc6-61\"><a href=\"#org414cfc6-61\" aria-hidden=\"true\"><\/a>                <span class=\"co\"># Remove known names in the graph.<\/span><\/span>\n<span id=\"org414cfc6-62\"><a href=\"#org414cfc6-62\" aria-hidden=\"true\"><\/a>                _ <span class=\"op\">=<\/span> [available_names.pop(v.name, <span class=\"va\">None<\/span>)<\/span>\n<span id=\"org414cfc6-63\"><a href=\"#org414cfc6-63\" aria-hidden=\"true\"><\/a>      
               <span class=\"cf\">for<\/span> v <span class=\"kw\">in<\/span> fgraph.variables]<\/span>\n<span id=\"org414cfc6-64\"><a href=\"#org414cfc6-64\" aria-hidden=\"true\"><\/a>            <span class=\"bu\">setattr<\/span>(pstate, <span class=\"st\">&#39;available_names&#39;<\/span>, available_names)<\/span>\n<span id=\"org414cfc6-65\"><a href=\"#org414cfc6-65\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-66\"><a href=\"#org414cfc6-66\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> output.name:<\/span>\n<span id=\"org414cfc6-67\"><a href=\"#org414cfc6-67\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># Observed an existing name; remove it.<\/span><\/span>\n<span id=\"org414cfc6-68\"><a href=\"#org414cfc6-68\" aria-hidden=\"true\"><\/a>            out_name <span class=\"op\">=<\/span> output.name<\/span>\n<span id=\"org414cfc6-69\"><a href=\"#org414cfc6-69\" aria-hidden=\"true\"><\/a>            available_names.pop(out_name, <span class=\"va\">None<\/span>)<\/span>\n<span id=\"org414cfc6-70\"><a href=\"#org414cfc6-70\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"org414cfc6-71\"><a href=\"#org414cfc6-71\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># Take an unused name.<\/span><\/span>\n<span id=\"org414cfc6-72\"><a href=\"#org414cfc6-72\" aria-hidden=\"true\"><\/a>            out_name, _ <span class=\"op\">=<\/span> available_names.popitem(last<span class=\"op\">=<\/span><span class=\"va\">False<\/span>)<\/span>\n<span id=\"org414cfc6-73\"><a href=\"#org414cfc6-73\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-74\"><a href=\"#org414cfc6-74\" aria-hidden=\"true\"><\/a>        pstate.memo[output] <span class=\"op\">=<\/span> out_name<\/span>\n<span id=\"org414cfc6-75\"><a href=\"#org414cfc6-75\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> out_name<\/span>\n<span id=\"org414cfc6-76\"><a href=\"#org414cfc6-76\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-77\"><a href=\"#org414cfc6-77\" aria-hidden=\"true\"><\/a>    <span class=\"at\">@classmethod<\/span><\/span>\n<span id=\"org414cfc6-78\"><a href=\"#org414cfc6-78\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> process_shape_info(cls, output, pstate):<\/span>\n<span id=\"org414cfc6-79\"><a href=\"#org414cfc6-79\" aria-hidden=\"true\"><\/a>        using_latex <span class=\"op\">=<\/span> <span class=\"bu\">getattr<\/span>(pstate, <span class=\"st\">&#39;latex&#39;<\/span>, <span class=\"va\">False<\/span>)<\/span>\n<span id=\"org414cfc6-80\"><a href=\"#org414cfc6-80\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-81\"><a href=\"#org414cfc6-81\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> output.dtype <span class=\"kw\">in<\/span> tt.int_dtypes:<\/span>\n<span id=\"org414cfc6-82\"><a href=\"#org414cfc6-82\" aria-hidden=\"true\"><\/a>            sspace_char <span class=\"op\">=<\/span> <span class=\"st\">&#39;Z&#39;<\/span><\/span>\n<span id=\"org414cfc6-83\"><a href=\"#org414cfc6-83\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">elif<\/span> output.dtype <span class=\"kw\">in<\/span> tt.uint_dtypes:<\/span>\n<span id=\"org414cfc6-84\"><a href=\"#org414cfc6-84\" aria-hidden=\"true\"><\/a>            sspace_char <span class=\"op\">=<\/span> <span class=\"st\">&#39;N&#39;<\/span><\/span>\n<span id=\"org414cfc6-85\"><a href=\"#org414cfc6-85\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">elif<\/span> output.dtype <span class=\"kw\">in<\/span> tt.float_dtypes:<\/span>\n<span id=\"org414cfc6-86\"><a href=\"#org414cfc6-86\" aria-hidden=\"true\"><\/a>            sspace_char <span class=\"op\">=<\/span> <span class=\"st\">&#39;R&#39;<\/span><\/span>\n<span id=\"org414cfc6-87\"><a href=\"#org414cfc6-87\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"org414cfc6-88\"><a href=\"#org414cfc6-88\" aria-hidden=\"true\"><\/a>  
          sspace_char <span class=\"op\">=<\/span> <span class=\"st\">&#39;?&#39;<\/span><\/span>\n<span id=\"org414cfc6-89\"><a href=\"#org414cfc6-89\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-90\"><a href=\"#org414cfc6-90\" aria-hidden=\"true\"><\/a>        fgraph <span class=\"op\">=<\/span> <span class=\"bu\">getattr<\/span>(output, <span class=\"st\">&#39;fgraph&#39;<\/span>, <span class=\"va\">None<\/span>)<\/span>\n<span id=\"org414cfc6-91\"><a href=\"#org414cfc6-91\" aria-hidden=\"true\"><\/a>        shape_feature <span class=\"op\">=<\/span> <span class=\"va\">None<\/span><\/span>\n<span id=\"org414cfc6-92\"><a href=\"#org414cfc6-92\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> fgraph:<\/span>\n<span id=\"org414cfc6-93\"><a href=\"#org414cfc6-93\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">if<\/span> <span class=\"kw\">not<\/span> <span class=\"bu\">hasattr<\/span>(fgraph, <span class=\"st\">&#39;shape_feature&#39;<\/span>):<\/span>\n<span id=\"org414cfc6-94\"><a href=\"#org414cfc6-94\" aria-hidden=\"true\"><\/a>                fgraph.attach_feature(tt.opt.ShapeFeature())<\/span>\n<span id=\"org414cfc6-95\"><a href=\"#org414cfc6-95\" aria-hidden=\"true\"><\/a>            shape_feature <span class=\"op\">=<\/span> fgraph.shape_feature<\/span>\n<span id=\"org414cfc6-96\"><a href=\"#org414cfc6-96\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-97\"><a href=\"#org414cfc6-97\" aria-hidden=\"true\"><\/a>        shape_dims <span class=\"op\">=<\/span> []<\/span>\n<span id=\"org414cfc6-98\"><a href=\"#org414cfc6-98\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">for<\/span> i <span class=\"kw\">in<\/span> <span class=\"bu\">range<\/span>(output.ndim):<\/span>\n<span id=\"org414cfc6-99\"><a href=\"#org414cfc6-99\" aria-hidden=\"true\"><\/a>            s_i_out <span class=\"op\">=<\/span> <span class=\"va\">None<\/span><\/span>\n<span id=\"org414cfc6-100\"><a href=\"#org414cfc6-100\" 
aria-hidden=\"true\"><\/a>            <span class=\"cf\">if<\/span> using_latex:<\/span>\n<span id=\"org414cfc6-101\"><a href=\"#org414cfc6-101\" aria-hidden=\"true\"><\/a>                s_i_pat <span class=\"op\">=<\/span> <span class=\"st\">&#39;{n^{<\/span><span class=\"sc\">%s}}<\/span><span class=\"st\">&#39;<\/span> <span class=\"op\">+<\/span> (<span class=\"st\">&#39;_{<\/span><span class=\"sc\">%s<\/span><span class=\"st\">}&#39;<\/span> <span class=\"op\">%<\/span> i)<\/span>\n<span id=\"org414cfc6-102\"><a href=\"#org414cfc6-102\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"org414cfc6-103\"><a href=\"#org414cfc6-103\" aria-hidden=\"true\"><\/a>                s_i_pat <span class=\"op\">=<\/span> <span class=\"st\">&#39;n^<\/span><span class=\"sc\">%s<\/span><span class=\"st\">&#39;<\/span> <span class=\"op\">+<\/span> (<span class=\"st\">&#39;_<\/span><span class=\"sc\">%s<\/span><span class=\"st\">&#39;<\/span> <span class=\"op\">%<\/span> i)<\/span>\n<span id=\"org414cfc6-104\"><a href=\"#org414cfc6-104\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">if<\/span> shape_feature:<\/span>\n<span id=\"org414cfc6-105\"><a href=\"#org414cfc6-105\" aria-hidden=\"true\"><\/a>                new_precedence <span class=\"op\">=<\/span> <span class=\"op\">-<\/span><span class=\"dv\">1000<\/span><\/span>\n<span id=\"org414cfc6-106\"><a href=\"#org414cfc6-106\" aria-hidden=\"true\"><\/a>                <span class=\"cf\">try<\/span>:<\/span>\n<span id=\"org414cfc6-107\"><a href=\"#org414cfc6-107\" aria-hidden=\"true\"><\/a>                    old_precedence <span class=\"op\">=<\/span> <span class=\"bu\">getattr<\/span>(pstate, <span class=\"st\">&#39;precedence&#39;<\/span>, <span class=\"va\">None<\/span>)<\/span>\n<span id=\"org414cfc6-108\"><a href=\"#org414cfc6-108\" aria-hidden=\"true\"><\/a>                    pstate.precedence <span class=\"op\">=<\/span> new_precedence<\/span>\n<span 
id=\"org414cfc6-109\"><a href=\"#org414cfc6-109\" aria-hidden=\"true\"><\/a>                    _s_i_out <span class=\"op\">=<\/span> shape_feature.get_shape(output, i)<\/span>\n<span id=\"org414cfc6-110\"><a href=\"#org414cfc6-110\" aria-hidden=\"true\"><\/a>                    <span class=\"cf\">if<\/span> _s_i_out.owner:<\/span>\n<span id=\"org414cfc6-111\"><a href=\"#org414cfc6-111\" aria-hidden=\"true\"><\/a>                        <span class=\"cf\">if<\/span> (<span class=\"bu\">isinstance<\/span>(_s_i_out.owner.op, tt.Subtensor) <span class=\"kw\">and<\/span><\/span>\n<span id=\"org414cfc6-112\"><a href=\"#org414cfc6-112\" aria-hidden=\"true\"><\/a>                            <span class=\"bu\">all<\/span>(<span class=\"bu\">isinstance<\/span>(i, tt.Constant)<\/span>\n<span id=\"org414cfc6-113\"><a href=\"#org414cfc6-113\" aria-hidden=\"true\"><\/a>                                <span class=\"cf\">for<\/span> i <span class=\"kw\">in<\/span> _s_i_out.owner.inputs)):<\/span>\n<span id=\"org414cfc6-114\"><a href=\"#org414cfc6-114\" aria-hidden=\"true\"><\/a>                            s_i_out <span class=\"op\">=<\/span> <span class=\"bu\">str<\/span>(_s_i_out.owner.inputs[<span class=\"dv\">0<\/span>].data[<\/span>\n<span id=\"org414cfc6-115\"><a href=\"#org414cfc6-115\" aria-hidden=\"true\"><\/a>                                _s_i_out.owner.inputs[<span class=\"dv\">1<\/span>].data])<\/span>\n<span id=\"org414cfc6-116\"><a href=\"#org414cfc6-116\" aria-hidden=\"true\"><\/a>                        <span class=\"cf\">elif<\/span> <span class=\"kw\">not<\/span> <span class=\"bu\">isinstance<\/span>(_s_i_out, tt.TensorVariable):<\/span>\n<span id=\"org414cfc6-117\"><a href=\"#org414cfc6-117\" aria-hidden=\"true\"><\/a>                            s_i_out <span class=\"op\">=<\/span> pstate.pprinter.process(_s_i_out, pstate)<\/span>\n<span id=\"org414cfc6-118\"><a href=\"#org414cfc6-118\" aria-hidden=\"true\"><\/a>                <span 
class=\"cf\">except<\/span> <span class=\"pp\">KeyError<\/span>:<\/span>\n<span id=\"org414cfc6-119\"><a href=\"#org414cfc6-119\" aria-hidden=\"true\"><\/a>                    <span class=\"cf\">pass<\/span><\/span>\n<span id=\"org414cfc6-120\"><a href=\"#org414cfc6-120\" aria-hidden=\"true\"><\/a>                <span class=\"cf\">finally<\/span>:<\/span>\n<span id=\"org414cfc6-121\"><a href=\"#org414cfc6-121\" aria-hidden=\"true\"><\/a>                    pstate.precedence <span class=\"op\">=<\/span> old_precedence<\/span>\n<span id=\"org414cfc6-122\"><a href=\"#org414cfc6-122\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-123\"><a href=\"#org414cfc6-123\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">if<\/span> <span class=\"kw\">not<\/span> s_i_out:<\/span>\n<span id=\"org414cfc6-124\"><a href=\"#org414cfc6-124\" aria-hidden=\"true\"><\/a>                s_i_out <span class=\"op\">=<\/span> cls.process_variable_name(output, pstate)<\/span>\n<span id=\"org414cfc6-125\"><a href=\"#org414cfc6-125\" aria-hidden=\"true\"><\/a>                s_i_out <span class=\"op\">=<\/span> s_i_pat <span class=\"op\">%<\/span> s_i_out<\/span>\n<span id=\"org414cfc6-126\"><a href=\"#org414cfc6-126\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-127\"><a href=\"#org414cfc6-127\" aria-hidden=\"true\"><\/a>            shape_dims <span class=\"op\">+=<\/span> [s_i_out]<\/span>\n<span id=\"org414cfc6-128\"><a href=\"#org414cfc6-128\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-129\"><a href=\"#org414cfc6-129\" aria-hidden=\"true\"><\/a>        shape_info <span class=\"op\">=<\/span> cls.process_variable_name(output, pstate)<\/span>\n<span id=\"org414cfc6-130\"><a href=\"#org414cfc6-130\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> using_latex:<\/span>\n<span id=\"org414cfc6-131\"><a href=\"#org414cfc6-131\" aria-hidden=\"true\"><\/a>            shape_info <span class=\"op\">+=<\/span> <span class=\"st\">&#39; 
<\/span><span class=\"ch\">\\\\<\/span><span class=\"st\">in <\/span><span class=\"ch\">\\\\<\/span><span class=\"st\">mathbb{<\/span><span class=\"sc\">%s<\/span><span class=\"st\">}&#39;<\/span> <span class=\"op\">%<\/span> sspace_char<\/span>\n<span id=\"org414cfc6-132\"><a href=\"#org414cfc6-132\" aria-hidden=\"true\"><\/a>            shape_dims <span class=\"op\">=<\/span> <span class=\"st\">&#39; <\/span><span class=\"ch\">\\\\<\/span><span class=\"st\">times &#39;<\/span>.join(shape_dims)<\/span>\n<span id=\"org414cfc6-133\"><a href=\"#org414cfc6-133\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">if<\/span> shape_dims:<\/span>\n<span id=\"org414cfc6-134\"><a href=\"#org414cfc6-134\" aria-hidden=\"true\"><\/a>                shape_info <span class=\"op\">+=<\/span> <span class=\"st\">&#39;^{<\/span><span class=\"sc\">%s<\/span><span class=\"st\">}&#39;<\/span> <span class=\"op\">%<\/span> shape_dims<\/span>\n<span id=\"org414cfc6-135\"><a href=\"#org414cfc6-135\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"org414cfc6-136\"><a href=\"#org414cfc6-136\" aria-hidden=\"true\"><\/a>            shape_info <span class=\"op\">+=<\/span> <span class=\"st\">&#39; in <\/span><span class=\"sc\">%s<\/span><span class=\"st\">&#39;<\/span> <span class=\"op\">%<\/span> sspace_char<\/span>\n<span id=\"org414cfc6-137\"><a href=\"#org414cfc6-137\" aria-hidden=\"true\"><\/a>            shape_dims <span class=\"op\">=<\/span> <span class=\"st\">&#39; x &#39;<\/span>.join(shape_dims)<\/span>\n<span id=\"org414cfc6-138\"><a href=\"#org414cfc6-138\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">if<\/span> shape_dims:<\/span>\n<span id=\"org414cfc6-139\"><a href=\"#org414cfc6-139\" aria-hidden=\"true\"><\/a>                shape_info <span class=\"op\">+=<\/span> <span class=\"st\">&#39;**(<\/span><span class=\"sc\">%s<\/span><span class=\"st\">)&#39;<\/span> <span class=\"op\">%<\/span> shape_dims<\/span>\n<span 
id=\"org414cfc6-140\"><a href=\"#org414cfc6-140\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org414cfc6-141\"><a href=\"#org414cfc6-141\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> shape_info<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"orgccd5273\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgccd5273-1\"><a href=\"#orgccd5273-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> textwrap<\/span>\n<span id=\"orgccd5273-2\"><a href=\"#orgccd5273-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-3\"><a href=\"#orgccd5273-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-4\"><a href=\"#orgccd5273-4\" aria-hidden=\"true\"><\/a><span class=\"kw\">class<\/span> PreamblePPrinter(theano.printing.PPrinter):<\/span>\n<span id=\"orgccd5273-5\"><a href=\"#orgccd5273-5\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;Pretty printer that displays a preamble.<\/span><\/span>\n<span id=\"orgccd5273-6\"><a href=\"#orgccd5273-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-7\"><a href=\"#orgccd5273-7\" aria-hidden=\"true\"><\/a><span class=\"co\">    For example,<\/span><\/span>\n<span id=\"orgccd5273-8\"><a href=\"#orgccd5273-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-9\"><a href=\"#orgccd5273-9\" aria-hidden=\"true\"><\/a><span class=\"co\">        X ~ N(\\mu, \\sigma)<\/span><\/span>\n<span id=\"orgccd5273-10\"><a href=\"#orgccd5273-10\" aria-hidden=\"true\"><\/a><span class=\"co\">        (b * X)<\/span><\/span>\n<span id=\"orgccd5273-11\"><a href=\"#orgccd5273-11\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-12\"><a href=\"#orgccd5273-12\" aria-hidden=\"true\"><\/a><span class=\"co\">    XXX: Not thread-safe!<\/span><\/span>\n<span id=\"orgccd5273-13\"><a href=\"#orgccd5273-13\" aria-hidden=\"true\"><\/a><span class=\"co\">    &quot;&quot;&quot;<\/span><\/span>\n<span id=\"orgccd5273-14\"><a 
href=\"#orgccd5273-14\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__init__<\/span>(<span class=\"va\">self<\/span>, <span class=\"op\">*<\/span>args, pstate_defaults<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, <span class=\"op\">**<\/span>kwargs):<\/span>\n<span id=\"orgccd5273-15\"><a href=\"#orgccd5273-15\" aria-hidden=\"true\"><\/a>        <span class=\"co\">&quot;&quot;&quot;<\/span><\/span>\n<span id=\"orgccd5273-16\"><a href=\"#orgccd5273-16\" aria-hidden=\"true\"><\/a><span class=\"co\">        Parameters<\/span><\/span>\n<span id=\"orgccd5273-17\"><a href=\"#orgccd5273-17\" aria-hidden=\"true\"><\/a><span class=\"co\">        ==========<\/span><\/span>\n<span id=\"orgccd5273-18\"><a href=\"#orgccd5273-18\" aria-hidden=\"true\"><\/a><span class=\"co\">        pstate_defaults: dict (optional)<\/span><\/span>\n<span id=\"orgccd5273-19\"><a href=\"#orgccd5273-19\" aria-hidden=\"true\"><\/a><span class=\"co\">            Default printer state parameters.<\/span><\/span>\n<span id=\"orgccd5273-20\"><a href=\"#orgccd5273-20\" aria-hidden=\"true\"><\/a><span class=\"co\">        &quot;&quot;&quot;<\/span><\/span>\n<span id=\"orgccd5273-21\"><a href=\"#orgccd5273-21\" aria-hidden=\"true\"><\/a>        <span class=\"bu\">super<\/span>().<span class=\"fu\">__init__<\/span>(<span class=\"op\">*<\/span>args, <span class=\"op\">**<\/span>kwargs)<\/span>\n<span id=\"orgccd5273-22\"><a href=\"#orgccd5273-22\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.pstate_defaults <span class=\"op\">=<\/span> pstate_defaults <span class=\"kw\">or<\/span> {}<\/span>\n<span id=\"orgccd5273-23\"><a href=\"#orgccd5273-23\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.printers_dict <span class=\"op\">=<\/span> <span class=\"bu\">dict<\/span>(tt.pprint.printers_dict)<\/span>\n<span id=\"orgccd5273-24\"><a href=\"#orgccd5273-24\" aria-hidden=\"true\"><\/a>        <span 
class=\"va\">self<\/span>.printers <span class=\"op\">=<\/span> copy(tt.pprint.printers)<\/span>\n<span id=\"orgccd5273-25\"><a href=\"#orgccd5273-25\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>._pstate <span class=\"op\">=<\/span> <span class=\"va\">None<\/span><\/span>\n<span id=\"orgccd5273-26\"><a href=\"#orgccd5273-26\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-27\"><a href=\"#orgccd5273-27\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> create_state(<span class=\"va\">self<\/span>, pstate):<\/span>\n<span id=\"orgccd5273-28\"><a href=\"#orgccd5273-28\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># <\/span><span class=\"al\">FIXME<\/span><span class=\"co\">: Find all the user-defined node names and make the tag<\/span><\/span>\n<span id=\"orgccd5273-29\"><a href=\"#orgccd5273-29\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># generator aware of them.<\/span><\/span>\n<span id=\"orgccd5273-30\"><a href=\"#orgccd5273-30\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> pstate <span class=\"kw\">is<\/span> <span class=\"va\">None<\/span>:<\/span>\n<span id=\"orgccd5273-31\"><a href=\"#orgccd5273-31\" aria-hidden=\"true\"><\/a>            pstate <span class=\"op\">=<\/span> theano.printing.PrinterState(<\/span>\n<span id=\"orgccd5273-32\"><a href=\"#orgccd5273-32\" aria-hidden=\"true\"><\/a>                pprinter<span class=\"op\">=<\/span><span class=\"va\">self<\/span>,<\/span>\n<span id=\"orgccd5273-33\"><a href=\"#orgccd5273-33\" aria-hidden=\"true\"><\/a>                preamble_lines<span class=\"op\">=<\/span>[],<\/span>\n<span id=\"orgccd5273-34\"><a href=\"#orgccd5273-34\" aria-hidden=\"true\"><\/a>                <span class=\"op\">**<\/span><span class=\"va\">self<\/span>.pstate_defaults)<\/span>\n<span id=\"orgccd5273-35\"><a href=\"#orgccd5273-35\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">elif<\/span> <span class=\"bu\">isinstance<\/span>(pstate, 
<span class=\"bu\">dict<\/span>):<\/span>\n<span id=\"orgccd5273-36\"><a href=\"#orgccd5273-36\" aria-hidden=\"true\"><\/a>            pstate.setdefault(<span class=\"st\">&#39;preamble_lines&#39;<\/span>, [])<\/span>\n<span id=\"orgccd5273-37\"><a href=\"#orgccd5273-37\" aria-hidden=\"true\"><\/a>            pstate.update(<span class=\"va\">self<\/span>.pstate_defaults)<\/span>\n<span id=\"orgccd5273-38\"><a href=\"#orgccd5273-38\" aria-hidden=\"true\"><\/a>            pstate <span class=\"op\">=<\/span> theano.printing.PrinterState(pprinter<span class=\"op\">=<\/span><span class=\"va\">self<\/span>, <span class=\"op\">**<\/span>pstate)<\/span>\n<span id=\"orgccd5273-39\"><a href=\"#orgccd5273-39\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-40\"><a href=\"#orgccd5273-40\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># <\/span><span class=\"al\">FIXME<\/span><span class=\"co\">: Good old-fashioned circular references...<\/span><\/span>\n<span id=\"orgccd5273-41\"><a href=\"#orgccd5273-41\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># We&#39;re doing this so that `self.process` will be called correctly<\/span><\/span>\n<span id=\"orgccd5273-42\"><a href=\"#orgccd5273-42\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># across all code.  
(I&#39;m lookin&#39; at you, `DimShufflePrinter`; get<\/span><\/span>\n<span id=\"orgccd5273-43\"><a href=\"#orgccd5273-43\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># your act together.)<\/span><\/span>\n<span id=\"orgccd5273-44\"><a href=\"#orgccd5273-44\" aria-hidden=\"true\"><\/a>        pstate.pprinter._pstate <span class=\"op\">=<\/span> pstate<\/span>\n<span id=\"orgccd5273-45\"><a href=\"#orgccd5273-45\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-46\"><a href=\"#orgccd5273-46\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> pstate<\/span>\n<span id=\"orgccd5273-47\"><a href=\"#orgccd5273-47\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-48\"><a href=\"#orgccd5273-48\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> process(<span class=\"va\">self<\/span>, r, pstate<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"orgccd5273-49\"><a href=\"#orgccd5273-49\" aria-hidden=\"true\"><\/a>        pstate <span class=\"op\">=<\/span> <span class=\"va\">self<\/span>._pstate<\/span>\n<span id=\"orgccd5273-50\"><a href=\"#orgccd5273-50\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">assert<\/span> pstate<\/span>\n<span id=\"orgccd5273-51\"><a href=\"#orgccd5273-51\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> <span class=\"bu\">super<\/span>().process(r, pstate)<\/span>\n<span id=\"orgccd5273-52\"><a href=\"#orgccd5273-52\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-53\"><a href=\"#orgccd5273-53\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> process_graph(<span class=\"va\">self<\/span>, inputs, outputs, updates<span class=\"op\">=<\/span><span class=\"va\">None<\/span>,<\/span>\n<span id=\"orgccd5273-54\"><a href=\"#orgccd5273-54\" aria-hidden=\"true\"><\/a>                      display_inputs<span class=\"op\">=<\/span><span class=\"va\">False<\/span>):<\/span>\n<span id=\"orgccd5273-55\"><a 
href=\"#orgccd5273-55\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">raise<\/span> <span class=\"pp\">NotImplementedError<\/span>()<\/span>\n<span id=\"orgccd5273-56\"><a href=\"#orgccd5273-56\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-57\"><a href=\"#orgccd5273-57\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__call__<\/span>(<span class=\"va\">self<\/span>, <span class=\"op\">*<\/span>args, latex_env<span class=\"op\">=<\/span><span class=\"st\">&#39;equation&#39;<\/span>, latex_label<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"orgccd5273-58\"><a href=\"#orgccd5273-58\" aria-hidden=\"true\"><\/a>        var <span class=\"op\">=<\/span> args[<span class=\"dv\">0<\/span>]<\/span>\n<span id=\"orgccd5273-59\"><a href=\"#orgccd5273-59\" aria-hidden=\"true\"><\/a>        pstate <span class=\"op\">=<\/span> <span class=\"bu\">next<\/span>(<span class=\"bu\">iter<\/span>(args[<span class=\"dv\">1<\/span>:]), <span class=\"va\">None<\/span>)<\/span>\n<span id=\"orgccd5273-60\"><a href=\"#orgccd5273-60\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> <span class=\"bu\">isinstance<\/span>(pstate, (theano.printing.PrinterState, <span class=\"bu\">dict<\/span>)):<\/span>\n<span id=\"orgccd5273-61\"><a href=\"#orgccd5273-61\" aria-hidden=\"true\"><\/a>            pstate <span class=\"op\">=<\/span> <span class=\"va\">self<\/span>.create_state(args[<span class=\"dv\">1<\/span>])<\/span>\n<span id=\"orgccd5273-62\"><a href=\"#orgccd5273-62\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">elif<\/span> pstate <span class=\"kw\">is<\/span> <span class=\"va\">None<\/span>:<\/span>\n<span id=\"orgccd5273-63\"><a href=\"#orgccd5273-63\" aria-hidden=\"true\"><\/a>            pstate <span class=\"op\">=<\/span> <span class=\"va\">self<\/span>.create_state(<span class=\"va\">None<\/span>)<\/span>\n<span id=\"orgccd5273-64\"><a href=\"#orgccd5273-64\" 
aria-hidden=\"true\"><\/a>        <span class=\"co\"># else:<\/span><\/span>\n<span id=\"orgccd5273-65\"><a href=\"#orgccd5273-65\" aria-hidden=\"true\"><\/a>        <span class=\"co\">#     # XXX: The graph processing doesn&#39;t pass around the printer state!<\/span><\/span>\n<span id=\"orgccd5273-66\"><a href=\"#orgccd5273-66\" aria-hidden=\"true\"><\/a>        <span class=\"co\">#     # <\/span><span class=\"al\">TODO<\/span><span class=\"co\">: We&#39;ll have to copy the code and fix it...<\/span><\/span>\n<span id=\"orgccd5273-67\"><a href=\"#orgccd5273-67\" aria-hidden=\"true\"><\/a>        <span class=\"co\">#     raise NotImplemented(&#39;No preambles for graph printing, yet.&#39;)<\/span><\/span>\n<span id=\"orgccd5273-68\"><a href=\"#orgccd5273-68\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-69\"><a href=\"#orgccd5273-69\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># This pretty printer needs more information about shapes and inputs,<\/span><\/span>\n<span id=\"orgccd5273-70\"><a href=\"#orgccd5273-70\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># which it gets from a `FunctionGraph`.  
Create one, if `var` isn&#39;t<\/span><\/span>\n<span id=\"orgccd5273-71\"><a href=\"#orgccd5273-71\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># already assigned one.<\/span><\/span>\n<span id=\"orgccd5273-72\"><a href=\"#orgccd5273-72\" aria-hidden=\"true\"><\/a>        fgraph <span class=\"op\">=<\/span> <span class=\"bu\">getattr<\/span>(var, <span class=\"st\">&#39;fgraph&#39;<\/span>, <span class=\"va\">None<\/span>)<\/span>\n<span id=\"orgccd5273-73\"><a href=\"#orgccd5273-73\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> <span class=\"kw\">not<\/span> fgraph:<\/span>\n<span id=\"orgccd5273-74\"><a href=\"#orgccd5273-74\" aria-hidden=\"true\"><\/a>            fgraph <span class=\"op\">=<\/span> tt.gof.fg.FunctionGraph(<\/span>\n<span id=\"orgccd5273-75\"><a href=\"#orgccd5273-75\" aria-hidden=\"true\"><\/a>                tt.gof.graph.inputs([var]), [var])<\/span>\n<span id=\"orgccd5273-76\"><a href=\"#orgccd5273-76\" aria-hidden=\"true\"><\/a>            var <span class=\"op\">=<\/span> fgraph.outputs[<span class=\"dv\">0<\/span>]<\/span>\n<span id=\"orgccd5273-77\"><a href=\"#orgccd5273-77\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-78\"><a href=\"#orgccd5273-78\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># Use this to get better shape info<\/span><\/span>\n<span id=\"orgccd5273-79\"><a href=\"#orgccd5273-79\" aria-hidden=\"true\"><\/a>            shape_feature <span class=\"op\">=<\/span> tt.opt.ShapeFeature()<\/span>\n<span id=\"orgccd5273-80\"><a href=\"#orgccd5273-80\" aria-hidden=\"true\"><\/a>            fgraph.attach_feature(shape_feature)<\/span>\n<span id=\"orgccd5273-81\"><a href=\"#orgccd5273-81\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-82\"><a href=\"#orgccd5273-82\" aria-hidden=\"true\"><\/a>        body_str <span class=\"op\">=<\/span> <span class=\"bu\">super<\/span>().<span class=\"fu\">__call__<\/span>(var, pstate)<\/span>\n<span id=\"orgccd5273-83\"><a 
href=\"#orgccd5273-83\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-84\"><a href=\"#orgccd5273-84\" aria-hidden=\"true\"><\/a>        latex_out <span class=\"op\">=<\/span> <span class=\"bu\">getattr<\/span>(pstate, <span class=\"st\">&#39;latex&#39;<\/span>, <span class=\"va\">False<\/span>)<\/span>\n<span id=\"orgccd5273-85\"><a href=\"#orgccd5273-85\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> pstate.preamble_lines <span class=\"kw\">and<\/span> latex_out:<\/span>\n<span id=\"orgccd5273-86\"><a href=\"#orgccd5273-86\" aria-hidden=\"true\"><\/a>            preamble_str <span class=\"op\">=<\/span> <span class=\"st\">&quot;<\/span><span class=\"ch\">\\n\\\\\\\\\\n<\/span><span class=\"st\">&quot;<\/span>.join(pstate.preamble_lines)<\/span>\n<span id=\"orgccd5273-87\"><a href=\"#orgccd5273-87\" aria-hidden=\"true\"><\/a>            preamble_str <span class=\"op\">=<\/span> <span class=\"st\">&quot;<\/span><span class=\"ch\">\\\\<\/span><span class=\"st\">begin<\/span><span class=\"sc\">{gathered}<\/span><span class=\"ch\">\\n<\/span><span class=\"sc\">%s<\/span><span class=\"ch\">\\n\\\\<\/span><span class=\"st\">end<\/span><span class=\"sc\">{gathered}<\/span><span class=\"st\">&quot;<\/span> <span class=\"op\">%<\/span> (preamble_str)<\/span>\n<span id=\"orgccd5273-88\"><a href=\"#orgccd5273-88\" aria-hidden=\"true\"><\/a>            res <span class=\"op\">=<\/span> <span class=\"st\">&quot;<\/span><span class=\"ch\">\\n\\\\\\\\\\n<\/span><span class=\"st\">&quot;<\/span>.join([preamble_str, body_str])<\/span>\n<span id=\"orgccd5273-89\"><a href=\"#orgccd5273-89\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"orgccd5273-90\"><a href=\"#orgccd5273-90\" aria-hidden=\"true\"><\/a>            res <span class=\"op\">=<\/span> <span class=\"st\">&quot;<\/span><span class=\"ch\">\\n<\/span><span class=\"st\">&quot;<\/span>.join(pstate.preamble_lines <span class=\"op\">+<\/span> 
[body_str])<\/span>\n<span id=\"orgccd5273-91\"><a href=\"#orgccd5273-91\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-92\"><a href=\"#orgccd5273-92\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> latex_out <span class=\"kw\">and<\/span> latex_env:<\/span>\n<span id=\"orgccd5273-93\"><a href=\"#orgccd5273-93\" aria-hidden=\"true\"><\/a>            label_out <span class=\"op\">=<\/span> <span class=\"ss\">f&#39;<\/span><span class=\"ch\">\\\\<\/span><span class=\"ss\">label<\/span><span class=\"ch\">{{<\/span><span class=\"sc\">{<\/span>latex_label<span class=\"sc\">}<\/span><span class=\"ss\">}}<\/span><span class=\"ch\">\\n<\/span><span class=\"ss\">&#39;<\/span> <span class=\"cf\">if<\/span> latex_label <span class=\"cf\">else<\/span> <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"orgccd5273-94\"><a href=\"#orgccd5273-94\" aria-hidden=\"true\"><\/a>            res <span class=\"op\">=<\/span> textwrap.indent(res, <span class=\"st\">&#39;<\/span><span class=\"ch\">\\t\\t<\/span><span class=\"st\">&#39;<\/span>)<\/span>\n<span id=\"orgccd5273-95\"><a href=\"#orgccd5273-95\" aria-hidden=\"true\"><\/a>            res <span class=\"op\">=<\/span> (<span class=\"ss\">f&quot;<\/span><span class=\"ch\">\\\\<\/span><span class=\"ss\">begin<\/span><span class=\"ch\">{{<\/span><span class=\"sc\">{<\/span>latex_env<span class=\"sc\">}<\/span><span class=\"ss\">}}<\/span><span class=\"ch\">\\n<\/span><span class=\"ss\">&quot;<\/span><\/span>\n<span id=\"orgccd5273-96\"><a href=\"#orgccd5273-96\" aria-hidden=\"true\"><\/a>                   <span class=\"ss\">f&quot;<\/span><span class=\"sc\">{<\/span>res<span class=\"sc\">}<\/span><span class=\"ch\">\\n<\/span><span class=\"ss\">&quot;<\/span><\/span>\n<span id=\"orgccd5273-97\"><a href=\"#orgccd5273-97\" aria-hidden=\"true\"><\/a>                   <span class=\"ss\">f&quot;<\/span><span class=\"sc\">{<\/span>label_out<span class=\"sc\">}<\/span><span 
class=\"ss\">&quot;<\/span><\/span>\n<span id=\"orgccd5273-98\"><a href=\"#orgccd5273-98\" aria-hidden=\"true\"><\/a>                   <span class=\"ss\">f&quot;<\/span><span class=\"ch\">\\\\<\/span><span class=\"ss\">end<\/span><span class=\"ch\">{{<\/span><span class=\"sc\">{<\/span>latex_env<span class=\"sc\">}<\/span><span class=\"ss\">}}&quot;<\/span>)<\/span>\n<span id=\"orgccd5273-99\"><a href=\"#orgccd5273-99\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgccd5273-100\"><a href=\"#orgccd5273-100\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> res<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"orgfc82717\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgfc82717-1\"><a href=\"#orgfc82717-1\" aria-hidden=\"true\"><\/a>tt_pprint <span class=\"op\">=<\/span> PreamblePPrinter()<\/span>\n<span id=\"orgfc82717-2\"><a href=\"#orgfc82717-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgfc82717-3\"><a href=\"#orgfc82717-3\" aria-hidden=\"true\"><\/a>tt_pprint.assign(<span class=\"kw\">lambda<\/span> pstate, r: <span class=\"va\">True<\/span>, VariableWithShapePrinter)<\/span>\n<span id=\"orgfc82717-4\"><a href=\"#orgfc82717-4\" aria-hidden=\"true\"><\/a>tt_pprint.assign(UniformRV, RandomVariablePrinter(<span class=\"st\">&#39;U&#39;<\/span>))<\/span>\n<span id=\"orgfc82717-5\"><a href=\"#orgfc82717-5\" aria-hidden=\"true\"><\/a>tt_pprint.assign(GammaRV, RandomVariablePrinter(<span class=\"st\">&#39;Gamma&#39;<\/span>))<\/span>\n<span id=\"orgfc82717-6\"><a href=\"#orgfc82717-6\" aria-hidden=\"true\"><\/a>tt_pprint.assign(ExponentialRV, RandomVariablePrinter(<span class=\"st\">&#39;Exp&#39;<\/span>))<\/span>\n<span id=\"orgfc82717-7\"><a href=\"#orgfc82717-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgfc82717-8\"><a href=\"#orgfc82717-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgfc82717-9\"><a href=\"#orgfc82717-9\" aria-hidden=\"true\"><\/a><span 
class=\"kw\">class<\/span> NormalRVPrinter(RandomVariablePrinter):<\/span>\n<span id=\"orgfc82717-10\"><a href=\"#orgfc82717-10\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__init__<\/span>(<span class=\"va\">self<\/span>):<\/span>\n<span id=\"orgfc82717-11\"><a href=\"#orgfc82717-11\" aria-hidden=\"true\"><\/a>        <span class=\"bu\">super<\/span>().<span class=\"fu\">__init__<\/span>(<span class=\"st\">&#39;N&#39;<\/span>)<\/span>\n<span id=\"orgfc82717-12\"><a href=\"#orgfc82717-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgfc82717-13\"><a href=\"#orgfc82717-13\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> process_param(<span class=\"va\">self<\/span>, idx, sform, pstate):<\/span>\n<span id=\"orgfc82717-14\"><a href=\"#orgfc82717-14\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> idx <span class=\"op\">==<\/span> <span class=\"dv\">1<\/span>:<\/span>\n<span id=\"orgfc82717-15\"><a href=\"#orgfc82717-15\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">if<\/span> <span class=\"bu\">getattr<\/span>(pstate, <span class=\"st\">&#39;latex&#39;<\/span>, <span class=\"va\">False<\/span>):<\/span>\n<span id=\"orgfc82717-16\"><a href=\"#orgfc82717-16\" aria-hidden=\"true\"><\/a>                <span class=\"cf\">return<\/span> <span class=\"ss\">f&#39;<\/span><span class=\"ch\">{{<\/span><span class=\"sc\">{<\/span>sform<span class=\"sc\">}<\/span><span class=\"ss\">}}^<\/span><span class=\"ch\">{{<\/span><span class=\"ss\">2}}&#39;<\/span><\/span>\n<span id=\"orgfc82717-17\"><a href=\"#orgfc82717-17\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"orgfc82717-18\"><a href=\"#orgfc82717-18\" aria-hidden=\"true\"><\/a>                <span class=\"cf\">return<\/span> <span class=\"ss\">f&#39;<\/span><span class=\"sc\">{<\/span>sform<span class=\"sc\">}<\/span><span class=\"ss\">**2&#39;<\/span><\/span>\n<span id=\"orgfc82717-19\"><a 
href=\"#orgfc82717-19\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"orgfc82717-20\"><a href=\"#orgfc82717-20\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">return<\/span> sform<\/span>\n<span id=\"orgfc82717-21\"><a href=\"#orgfc82717-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgfc82717-22\"><a href=\"#orgfc82717-22\" aria-hidden=\"true\"><\/a>tt_pprint.assign(NormalRV, NormalRVPrinter())<\/span>\n<span id=\"orgfc82717-23\"><a href=\"#orgfc82717-23\" aria-hidden=\"true\"><\/a>tt_pprint.assign(MvNormalRV, NormalRVPrinter())<\/span>\n<span id=\"orgfc82717-24\"><a href=\"#orgfc82717-24\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgfc82717-25\"><a href=\"#orgfc82717-25\" aria-hidden=\"true\"><\/a>tt_pprint.assign(DirichletRV, RandomVariablePrinter(<span class=\"st\">&#39;Dir&#39;<\/span>))<\/span>\n<span id=\"orgfc82717-26\"><a href=\"#orgfc82717-26\" aria-hidden=\"true\"><\/a>tt_pprint.assign(PoissonRV, RandomVariablePrinter(<span class=\"st\">&#39;Pois&#39;<\/span>))<\/span>\n<span id=\"orgfc82717-27\"><a href=\"#orgfc82717-27\" aria-hidden=\"true\"><\/a>tt_pprint.assign(CauchyRV, RandomVariablePrinter(<span class=\"st\">&#39;C&#39;<\/span>))<\/span>\n<span id=\"orgfc82717-28\"><a href=\"#orgfc82717-28\" aria-hidden=\"true\"><\/a>tt_pprint.assign(MultinomialRV, RandomVariablePrinter(<span class=\"st\">&#39;MN&#39;<\/span>))<\/span>\n<span id=\"orgfc82717-29\"><a href=\"#orgfc82717-29\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgfc82717-30\"><a href=\"#orgfc82717-30\" aria-hidden=\"true\"><\/a>tt_tex_pprint <span class=\"op\">=<\/span> PreamblePPrinter(pstate_defaults<span class=\"op\">=<\/span>{<span class=\"st\">&#39;latex&#39;<\/span>: <span class=\"va\">True<\/span>})<\/span>\n<span id=\"orgfc82717-31\"><a href=\"#orgfc82717-31\" aria-hidden=\"true\"><\/a>tt_tex_pprint.printers <span class=\"op\">=<\/span> copy(tt_pprint.printers)<\/span>\n<span id=\"orgfc82717-32\"><a 
href=\"#orgfc82717-32\" aria-hidden=\"true\"><\/a>tt_tex_pprint.printers_dict <span class=\"op\">=<\/span> <span class=\"bu\">dict<\/span>(tt_pprint.printers_dict)<\/span>\n<span id=\"orgfc82717-33\"><a href=\"#orgfc82717-33\" aria-hidden=\"true\"><\/a>tt_tex_pprint.assign(tt.mul, theano.printing.OperatorPrinter(<span class=\"st\">&#39;<\/span><span class=\"ch\">\\\\<\/span><span class=\"st\">odot&#39;<\/span>, <span class=\"op\">-<\/span><span class=\"dv\">1<\/span>, <span class=\"st\">&#39;either&#39;<\/span>))<\/span>\n<span id=\"orgfc82717-34\"><a href=\"#orgfc82717-34\" aria-hidden=\"true\"><\/a>tt_tex_pprint.assign(tt.true_div, theano.printing.PatternPrinter((<span class=\"st\">&#39;<\/span><span class=\"ch\">\\\\<\/span><span class=\"st\">frac{<\/span><span class=\"sc\">%(0)s<\/span><span class=\"st\">}{<\/span><span class=\"sc\">%(1)s<\/span><span class=\"st\">}&#39;<\/span>, <span class=\"op\">-<\/span><span class=\"dv\">1000<\/span>)))<\/span>\n<span id=\"orgfc82717-35\"><a href=\"#orgfc82717-35\" aria-hidden=\"true\"><\/a>tt_tex_pprint.assign(tt.<span class=\"bu\">pow<\/span>, theano.printing.PatternPrinter((<span class=\"st\">&#39;{<\/span><span class=\"sc\">%(0)s<\/span><span class=\"st\">}^{<\/span><span class=\"sc\">%(1)s<\/span><span class=\"st\">}&#39;<\/span>, <span class=\"op\">-<\/span><span class=\"dv\">1000<\/span>)))<\/span><\/code><\/pre><\/div>\n<div class=\"example\" data-markdown=\"\">\n<p>The code below creates a graph with three random variables, and Listing <a href=\"#org1ed4a46\">35<\/a> prints it with our LaTeX pretty printer, <code>tt_tex_pprint<\/code>, as Equation <span class=\"math inline\">\\(\\eqref{eq:rv-pprinter-exa}\\)<\/span>.<\/p>\n<div class=\"sourceCode\" id=\"org0e19f4e\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org0e19f4e-1\"><a href=\"#org0e19f4e-1\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org0e19f4e-2\"><a href=\"#org0e19f4e-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org0e19f4e-3\"><a 
href=\"#org0e19f4e-3\" aria-hidden=\"true\"><\/a>tt.config.compute_test_value <span class=\"op\">=<\/span> <span class=\"st\">&#39;ignore&#39;<\/span><\/span>\n<span id=\"org0e19f4e-4\"><a href=\"#org0e19f4e-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org0e19f4e-5\"><a href=\"#org0e19f4e-5\" aria-hidden=\"true\"><\/a>Z_tt <span class=\"op\">=<\/span> UniformRV(tt.scalar(<span class=\"st\">&#39;l_0&#39;<\/span>), tt.scalar(<span class=\"st\">&#39;l_1&#39;<\/span>), name<span class=\"op\">=<\/span><span class=\"st\">&#39;Z&#39;<\/span>)<\/span>\n<span id=\"org0e19f4e-6\"><a href=\"#org0e19f4e-6\" aria-hidden=\"true\"><\/a>X_tt <span class=\"op\">=<\/span> NormalRV(Z_tt, tt.scalar(<span class=\"st\">&#39;\\sigma_1&#39;<\/span>), name<span class=\"op\">=<\/span><span class=\"st\">&#39;X&#39;<\/span>)<\/span>\n<span id=\"org0e19f4e-7\"><a href=\"#org0e19f4e-7\" aria-hidden=\"true\"><\/a>Y_tt <span class=\"op\">=<\/span> MvNormalRV(tt.vector(<span class=\"st\">&#39;\\mu&#39;<\/span>), tt.abs_(X_tt) <span class=\"op\">*<\/span> tt.constant(np.diag([<span class=\"dv\">1<\/span>, <span class=\"dv\">2<\/span>])), name<span class=\"op\">=<\/span><span class=\"st\">&#39;Y&#39;<\/span>)<\/span>\n<span id=\"org0e19f4e-8\"><a href=\"#org0e19f4e-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org0e19f4e-9\"><a href=\"#org0e19f4e-9\" aria-hidden=\"true\"><\/a>W_tt <span class=\"op\">=<\/span> X_tt <span class=\"op\">*<\/span> (tt.scalar(<span class=\"st\">&#39;b&#39;<\/span>) <span class=\"op\">*<\/span> Y_tt <span class=\"op\">+<\/span> tt.scalar(<span class=\"st\">&#39;c&#39;<\/span>))<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"org1ed4a46\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org1ed4a46-1\"><a href=\"#org1ed4a46-1\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(tt_tex_pprint(W_tt, latex_label<span class=\"op\">=<\/span><span 
class=\"st\">&#39;eq:rv-pprinter-exa&#39;<\/span>))<\/span><\/code><\/pre><\/div>\n<p><span class=\"math display\">\\[\\begin{equation}\n        \\begin{gathered}\n        l_0 \\in \\mathbb{R}\n        \\\\\n        l_1 \\in \\mathbb{R}\n        \\\\\n        Z \\sim \\operatorname{U}\\left(l_0, l_1\\right), \\quad Z \\in \\mathbb{R}\n        \\\\\n        \\sigma_1 \\in \\mathbb{R}\n        \\\\\n        X \\sim \\operatorname{N}\\left(Z, {\\sigma_1}^{2}\\right), \\quad X \\in \\mathbb{R}\n        \\\\\n        b \\in \\mathbb{R}\n        \\\\\n        \\mu \\in \\mathbb{R}^{{n^{\\mu}}_{0}}\n        \\\\\n        Y \\sim \\operatorname{N}\\left(\\mu, {(|X| \\odot \\left[\\begin{matrix}1 &amp; 0\\\\0 &amp; 2\\end{matrix}\\right])}^{2}\\right), \\quad Y \\in \\mathbb{R}^{{n^{Y}}_{0}}\n        \\\\\n        c \\in \\mathbb{R}\n        \\end{gathered}\n        \\\\\n        (X \\odot ((b \\odot Y) + c))\n\\label{eq:rv-pprinter-exa}\n\\end{equation}\\]<\/span><\/p>\n<\/div>\n<\/section>\n<\/section>\n<section id=\"algebraic-manipulations\" class=\"level1\">\n<h1>Algebraic Manipulations<\/h1>\n<p>With our new <code>RandomVariable<\/code>, we can alter the replacement patterns used by <code>tt.gof.opt.PatternSub<\/code> in <a href=\"#24875a2c31fa7f94ce562adddedc0bf8\">Willard, Brandon T. 
(2018)<\/a> and implement a slightly better parameter lifting for affine transforms of scalar normal random variables in Listing <a href=\"#orgc483b75\">36<\/a>.<\/p>\n<div class=\"sourceCode\" id=\"orgc483b75\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgc483b75-1\"><a href=\"#orgc483b75-1\" aria-hidden=\"true\"><\/a>norm_lift_pats <span class=\"op\">=<\/span> [<\/span>\n<span id=\"orgc483b75-2\"><a href=\"#orgc483b75-2\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Lift element-wise multiplication<\/span><\/span>\n<span id=\"orgc483b75-3\"><a href=\"#orgc483b75-3\" aria-hidden=\"true\"><\/a>    tt.gof.opt.PatternSub(<\/span>\n<span id=\"orgc483b75-4\"><a href=\"#orgc483b75-4\" aria-hidden=\"true\"><\/a>        (tt.mul,<\/span>\n<span id=\"orgc483b75-5\"><a href=\"#orgc483b75-5\" aria-hidden=\"true\"><\/a>         <span class=\"st\">&#39;a_x&#39;<\/span>,<\/span>\n<span id=\"orgc483b75-6\"><a href=\"#orgc483b75-6\" aria-hidden=\"true\"><\/a>         (NormalRV, <span class=\"st\">&#39;mu_x&#39;<\/span>, <span class=\"st\">&#39;sd_x&#39;<\/span>, <span class=\"st\">&#39;size_x&#39;<\/span>, <span class=\"st\">&#39;rs_x&#39;<\/span>)),<\/span>\n<span id=\"orgc483b75-7\"><a href=\"#orgc483b75-7\" aria-hidden=\"true\"><\/a>        (NormalRV,<\/span>\n<span id=\"orgc483b75-8\"><a href=\"#orgc483b75-8\" aria-hidden=\"true\"><\/a>         (tt.mul, <span class=\"st\">&#39;a_x&#39;<\/span>, <span class=\"st\">&#39;mu_x&#39;<\/span>),<\/span>\n<span id=\"orgc483b75-9\"><a href=\"#orgc483b75-9\" aria-hidden=\"true\"><\/a>         (tt.mul, <span class=\"st\">&#39;a_x&#39;<\/span>, <span class=\"st\">&#39;sd_x&#39;<\/span>),<\/span>\n<span id=\"orgc483b75-10\"><a href=\"#orgc483b75-10\" aria-hidden=\"true\"><\/a>         <span class=\"st\">&#39;size_x&#39;<\/span>,<\/span>\n<span id=\"orgc483b75-11\"><a href=\"#orgc483b75-11\" aria-hidden=\"true\"><\/a>         <span class=\"st\">&#39;rs_x&#39;<\/span>,<\/span>\n<span 
id=\"orgc483b75-12\"><a href=\"#orgc483b75-12\" aria-hidden=\"true\"><\/a>        )),<\/span>\n<span id=\"orgc483b75-13\"><a href=\"#orgc483b75-13\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Lift element-wise addition<\/span><\/span>\n<span id=\"orgc483b75-14\"><a href=\"#orgc483b75-14\" aria-hidden=\"true\"><\/a>    tt.gof.opt.PatternSub(<\/span>\n<span id=\"orgc483b75-15\"><a href=\"#orgc483b75-15\" aria-hidden=\"true\"><\/a>        (tt.add,<\/span>\n<span id=\"orgc483b75-16\"><a href=\"#orgc483b75-16\" aria-hidden=\"true\"><\/a>         (NormalRV, <span class=\"st\">&#39;mu_x&#39;<\/span>, <span class=\"st\">&#39;sd_x&#39;<\/span>, <span class=\"st\">&#39;size_x&#39;<\/span>, <span class=\"st\">&#39;rs_x&#39;<\/span>),<\/span>\n<span id=\"orgc483b75-17\"><a href=\"#orgc483b75-17\" aria-hidden=\"true\"><\/a>         <span class=\"st\">&#39;b_x&#39;<\/span>),<\/span>\n<span id=\"orgc483b75-18\"><a href=\"#orgc483b75-18\" aria-hidden=\"true\"><\/a>        (NormalRV,<\/span>\n<span id=\"orgc483b75-19\"><a href=\"#orgc483b75-19\" aria-hidden=\"true\"><\/a>         (tt.add, <span class=\"st\">&#39;mu_x&#39;<\/span>, <span class=\"st\">&#39;b_x&#39;<\/span>),<\/span>\n<span id=\"orgc483b75-20\"><a href=\"#orgc483b75-20\" aria-hidden=\"true\"><\/a>         <span class=\"st\">&#39;sd_x&#39;<\/span>,<\/span>\n<span id=\"orgc483b75-21\"><a href=\"#orgc483b75-21\" aria-hidden=\"true\"><\/a>         <span class=\"st\">&#39;size_x&#39;<\/span>,<\/span>\n<span id=\"orgc483b75-22\"><a href=\"#orgc483b75-22\" aria-hidden=\"true\"><\/a>         <span class=\"st\">&#39;rs_x&#39;<\/span>,<\/span>\n<span id=\"orgc483b75-23\"><a href=\"#orgc483b75-23\" aria-hidden=\"true\"><\/a>        )),<\/span>\n<span id=\"orgc483b75-24\"><a href=\"#orgc483b75-24\" aria-hidden=\"true\"><\/a>]<\/span>\n<span id=\"orgc483b75-25\"><a href=\"#orgc483b75-25\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgc483b75-26\"><a href=\"#orgc483b75-26\" aria-hidden=\"true\"><\/a>norm_lift_opts <span 
class=\"op\">=<\/span> tt.gof.opt.EquilibriumOptimizer(<\/span>\n<span id=\"orgc483b75-27\"><a href=\"#orgc483b75-27\" aria-hidden=\"true\"><\/a>    norm_lift_pats, max_use_ratio<span class=\"op\">=<\/span><span class=\"dv\">10<\/span>)<\/span><\/code><\/pre><\/div>\n<div class=\"example\" data-markdown=\"\">\n<div class=\"sourceCode\" id=\"orgc69f52b\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgc69f52b-1\"><a href=\"#orgc69f52b-1\" aria-hidden=\"true\"><\/a><span class=\"co\"># [[file:~\/projects\/websites\/brandonwillard.github.io\/content\/articles\/src\/org\/symbolic-math-in-pymc3-new-op.org::graph-manipulation-setup][graph-manipulation-setup]]<\/span><\/span>\n<span id=\"orgc69f52b-2\"><a href=\"#orgc69f52b-2\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano.gof <span class=\"im\">import<\/span> FunctionGraph, Feature, NodeFinder<\/span>\n<span id=\"orgc69f52b-3\"><a href=\"#orgc69f52b-3\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano.gof.graph <span class=\"im\">import<\/span> inputs <span class=\"im\">as<\/span> tt_inputs, clone_get_equiv<\/span>\n<span id=\"orgc69f52b-4\"><a href=\"#orgc69f52b-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgc69f52b-5\"><a href=\"#orgc69f52b-5\" aria-hidden=\"true\"><\/a>theano.config.compute_test_value <span class=\"op\">=<\/span> <span class=\"st\">&#39;ignore&#39;<\/span><\/span>\n<span id=\"orgc69f52b-6\"><a href=\"#orgc69f52b-6\" aria-hidden=\"true\"><\/a><span class=\"co\"># graph-manipulation-setup ends here<\/span><\/span>\n<span id=\"orgc69f52b-7\"><a href=\"#orgc69f52b-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgc69f52b-8\"><a href=\"#orgc69f52b-8\" aria-hidden=\"true\"><\/a>mu_X <span class=\"op\">=<\/span> tt.vector(<span class=\"st\">&#39;\\mu&#39;<\/span>)<\/span>\n<span id=\"orgc69f52b-9\"><a href=\"#orgc69f52b-9\" aria-hidden=\"true\"><\/a>sd_X <span class=\"op\">=<\/span> tt.vector(<span 
class=\"st\">&#39;\\sigma&#39;<\/span>)<\/span>\n<span id=\"orgc69f52b-10\"><a href=\"#orgc69f52b-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgc69f52b-11\"><a href=\"#orgc69f52b-11\" aria-hidden=\"true\"><\/a>a_tt <span class=\"op\">=<\/span> tt.fscalar(<span class=\"st\">&#39;a&#39;<\/span>)<\/span>\n<span id=\"orgc69f52b-12\"><a href=\"#orgc69f52b-12\" aria-hidden=\"true\"><\/a>b_tt <span class=\"op\">=<\/span> tt.fscalar(<span class=\"st\">&#39;b&#39;<\/span>)<\/span>\n<span id=\"orgc69f52b-13\"><a href=\"#orgc69f52b-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgc69f52b-14\"><a href=\"#orgc69f52b-14\" aria-hidden=\"true\"><\/a>X_rv <span class=\"op\">=<\/span> NormalRV(mu_X, sd_X, name<span class=\"op\">=<\/span><span class=\"st\">&#39;X&#39;<\/span>)<\/span>\n<span id=\"orgc69f52b-15\"><a href=\"#orgc69f52b-15\" aria-hidden=\"true\"><\/a>trans_X_rv <span class=\"op\">=<\/span> a_tt <span class=\"op\">*<\/span> X_rv <span class=\"op\">+<\/span> b_tt<\/span>\n<span id=\"orgc69f52b-16\"><a href=\"#orgc69f52b-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgc69f52b-17\"><a href=\"#orgc69f52b-17\" aria-hidden=\"true\"><\/a>trans_X_graph <span class=\"op\">=<\/span> FunctionGraph(tt_inputs([trans_X_rv]), [trans_X_rv])<\/span>\n<span id=\"orgc69f52b-18\"><a href=\"#orgc69f52b-18\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgc69f52b-19\"><a href=\"#orgc69f52b-19\" aria-hidden=\"true\"><\/a><span class=\"co\"># Create a copy and optimize that<\/span><\/span>\n<span id=\"orgc69f52b-20\"><a href=\"#orgc69f52b-20\" aria-hidden=\"true\"><\/a>trans_X_graph_opt <span class=\"op\">=<\/span> trans_X_graph.clone()<\/span>\n<span id=\"orgc69f52b-21\"><a href=\"#orgc69f52b-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgc69f52b-22\"><a href=\"#orgc69f52b-22\" aria-hidden=\"true\"><\/a>_ <span class=\"op\">=<\/span> norm_lift_opts.optimize(trans_X_graph_opt)<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"org2bd1258\"><pre 
class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org2bd1258-1\"><a href=\"#org2bd1258-1\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(tt_tex_pprint(trans_X_graph.outputs[<span class=\"dv\">0<\/span>], latex_env<span class=\"op\">=<\/span><span class=\"st\">&#39;equation*&#39;<\/span>))<\/span><\/code><\/pre><\/div>\n<p>Before applying the optimization:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation*}\n        \\begin{gathered}\n        a \\in \\mathbb{R}\n        \\\\\n        \\mu \\in \\mathbb{R}^{{n^{\\mu}}_{0}}\n        \\\\\n        \\sigma \\in \\mathbb{R}^{{n^{\\sigma}}_{0}}\n        \\\\\n        X \\sim \\operatorname{N}\\left(\\mu, {\\sigma}^{2}\\right), \\quad X \\in \\mathbb{R}^{{n^{X}}_{0}}\n        \\\\\n        b \\in \\mathbb{R}\n        \\end{gathered}\n        \\\\\n        ((a \\odot X) + b)\n\\end{equation*}\\]<\/span><\/p>\n<div class=\"sourceCode\" id=\"orge71b0ac\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orge71b0ac-1\"><a href=\"#orge71b0ac-1\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(tt_tex_pprint(trans_X_graph_opt.outputs[<span class=\"dv\">0<\/span>], latex_env<span class=\"op\">=<\/span><span class=\"st\">&#39;equation*&#39;<\/span>))<\/span><\/code><\/pre><\/div>\n<p>After applying the optimization:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation*}\n        \\begin{gathered}\n        a \\in \\mathbb{R}\n        \\\\\n        \\mu \\in \\mathbb{R}^{{n^{\\mu}}_{0}}\n        \\\\\n        b \\in \\mathbb{R}\n        \\\\\n        \\sigma \\in \\mathbb{R}^{{n^{\\sigma}}_{0}}\n        \\\\\n        c \\sim \\operatorname{N}\\left(((a \\odot \\mu) + b), {(a \\odot \\sigma)}^{2}\\right), \\quad c \\in \\mathbb{R}^{{n^{c}}_{0}}\n        \\end{gathered}\n        \\\\\n        c\n\\end{equation*}\\]<\/span><\/p>\n<\/div>\n<p>Now, what if we wanted to handle affine transformations of a multivariate normal random variable? 
Specifically, consider the following:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation*}\n  X \\sim N\\left(\\mu, \\Sigma \\right), \\quad\n  A X \\sim N\\left(A \\mu, A \\Sigma A^\\top \\right)\n \\;.\n\\end{equation*}\\]<\/span><\/p>\n<p>At first, the substitution pattern in Listing <a href=\"#org4a792de\">40<\/a> might seem reasonable.<\/p>\n<div class=\"sourceCode\" id=\"org4a792de\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org4a792de-1\"><a href=\"#org4a792de-1\" aria-hidden=\"true\"><\/a><span class=\"co\"># Vector multiplication<\/span><\/span>\n<span id=\"org4a792de-2\"><a href=\"#org4a792de-2\" aria-hidden=\"true\"><\/a>tt.gof.opt.PatternSub(<\/span>\n<span id=\"org4a792de-3\"><a href=\"#org4a792de-3\" aria-hidden=\"true\"><\/a>    (tt.dot, <span class=\"st\">&#39;A_x&#39;<\/span>,<\/span>\n<span id=\"org4a792de-4\"><a href=\"#org4a792de-4\" aria-hidden=\"true\"><\/a>     (MvNormalRV, <span class=\"st\">&#39;mu_x&#39;<\/span>, <span class=\"st\">&#39;cov_x&#39;<\/span>, <span class=\"st\">&#39;size_x&#39;<\/span>, <span class=\"st\">&#39;rs_x&#39;<\/span>)),<\/span>\n<span id=\"org4a792de-5\"><a href=\"#org4a792de-5\" aria-hidden=\"true\"><\/a>    (MvNormalRV,<\/span>\n<span id=\"org4a792de-6\"><a href=\"#org4a792de-6\" aria-hidden=\"true\"><\/a>     (tt.dot, <span class=\"st\">&#39;A_x&#39;<\/span>, <span class=\"st\">&#39;mu_x&#39;<\/span>),<\/span>\n<span id=\"org4a792de-7\"><a href=\"#org4a792de-7\" aria-hidden=\"true\"><\/a>     (tt.dot,<\/span>\n<span id=\"org4a792de-8\"><a href=\"#org4a792de-8\" aria-hidden=\"true\"><\/a>      (tt.dot, <span class=\"st\">&#39;A_x&#39;<\/span>, <span class=\"st\">&#39;cov_x&#39;<\/span>),<\/span>\n<span id=\"org4a792de-9\"><a href=\"#org4a792de-9\" aria-hidden=\"true\"><\/a>      (tt.transpose, <span class=\"st\">&#39;A_x&#39;<\/span>)),<\/span>\n<span id=\"org4a792de-10\"><a href=\"#org4a792de-10\" aria-hidden=\"true\"><\/a>     <span 
class=\"st\">&#39;size_x&#39;<\/span>,<\/span>\n<span id=\"org4a792de-11\"><a href=\"#org4a792de-11\" aria-hidden=\"true\"><\/a>     <span class=\"st\">&#39;rs_x&#39;<\/span>,<\/span>\n<span id=\"org4a792de-12\"><a href=\"#org4a792de-12\" aria-hidden=\"true\"><\/a>    ))<\/span><\/code><\/pre><\/div>\n<p>Unfortunately, the combination of the size parameter and broadcasting complicates matters. Both determine the shape of a random variable\u2019s samples without being reflected in the shapes of its distribution parameters, so a dot-product applied to the un-lifted random variable can be shape-consistent while the corresponding lifted products of its parameters are not.<\/p>\n<p>The following example demonstrates the lifting issues brought on by broadcasting.<\/p>\n<div class=\"example\" data-markdown=\"\">\n<p>We create a simple multivariate normal in Listing <a href=\"#org7446c7b\">41<\/a>.<\/p>\n<div class=\"sourceCode\" id=\"org7446c7b\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org7446c7b-1\"><a href=\"#org7446c7b-1\" aria-hidden=\"true\"><\/a>mu_X <span class=\"op\">=<\/span> [<span class=\"dv\">0<\/span>, <span class=\"dv\">10<\/span>]<\/span>\n<span id=\"org7446c7b-2\"><a href=\"#org7446c7b-2\" aria-hidden=\"true\"><\/a>cov_X <span class=\"op\">=<\/span> np.diag([<span class=\"dv\">1<\/span>, <span class=\"fl\">1e-2<\/span>])<\/span>\n<span id=\"org7446c7b-3\"><a href=\"#org7446c7b-3\" aria-hidden=\"true\"><\/a>size_X_rv <span class=\"op\">=<\/span> [<span class=\"dv\">2<\/span>, <span class=\"dv\">3<\/span>]<\/span>\n<span id=\"org7446c7b-4\"><a href=\"#org7446c7b-4\" aria-hidden=\"true\"><\/a>X_rv <span class=\"op\">=<\/span> MvNormalRV(mu_X, cov_X, size<span class=\"op\">=<\/span>size_X_rv)<\/span>\n<span id=\"org7446c7b-5\"><a href=\"#org7446c7b-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org7446c7b-6\"><a href=\"#org7446c7b-6\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&#39;<\/span><span class=\"sc\">{}<\/span><span class=\"st\"> ~ X_rv<\/span><span class=\"ch\">\\n<\/span><span 
class=\"st\">&#39;<\/span>.<span class=\"bu\">format<\/span>(X_rv.tag.test_value))<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"org9e35ff4\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org9e35ff4-1\"><a href=\"#org9e35ff4-1\" aria-hidden=\"true\"><\/a>[[[<span class=\"op\">-<\/span><span class=\"fl\">0.68284424<\/span>  <span class=\"fl\">9.95587926<\/span>]<\/span>\n<span id=\"org9e35ff4-2\"><a href=\"#org9e35ff4-2\" aria-hidden=\"true\"><\/a>  [ <span class=\"fl\">1.66236785<\/span>  <span class=\"fl\">9.87590909<\/span>]<\/span>\n<span id=\"org9e35ff4-3\"><a href=\"#org9e35ff4-3\" aria-hidden=\"true\"><\/a>  [ <span class=\"fl\">0.23449772<\/span> <span class=\"fl\">10.12455681<\/span>]]<\/span>\n<span id=\"org9e35ff4-4\"><a href=\"#org9e35ff4-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org9e35ff4-5\"><a href=\"#org9e35ff4-5\" aria-hidden=\"true\"><\/a> [[ <span class=\"fl\">0.3342739<\/span>  <span class=\"fl\">10.05580428<\/span>]<\/span>\n<span id=\"org9e35ff4-6\"><a href=\"#org9e35ff4-6\" aria-hidden=\"true\"><\/a>  [<span class=\"op\">-<\/span><span class=\"fl\">0.18913408<\/span> <span class=\"fl\">10.0359336<\/span> ]<\/span>\n<span id=\"org9e35ff4-7\"><a href=\"#org9e35ff4-7\" aria-hidden=\"true\"><\/a>  [<span class=\"op\">-<\/span><span class=\"fl\">1.2463576<\/span>   <span class=\"fl\">9.90671218<\/span>]]] <span class=\"op\">~<\/span> X_rv<\/span>\n<span id=\"org9e35ff4-8\"><a href=\"#org9e35ff4-8\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<p>Next, we create a simple matrix operator to apply to the multivariate normal.<\/p>\n<div class=\"sourceCode\" id=\"org0f84661\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org0f84661-1\"><a href=\"#org0f84661-1\" aria-hidden=\"true\"><\/a>A_tt <span class=\"op\">=<\/span> tt.as_tensor_variable([[<span class=\"dv\">2<\/span>, <span class=\"dv\">5<\/span>, <span class=\"dv\">8<\/span>], [<span 
class=\"dv\">3<\/span>, <span class=\"dv\">4<\/span>, <span class=\"dv\">9<\/span>]])<\/span>\n<span id=\"org0f84661-2\"><a href=\"#org0f84661-2\" aria-hidden=\"true\"><\/a><span class=\"co\"># or A_tt = tt.as_tensor_variable([[2, 5, 8]])<\/span><\/span>\n<span id=\"org0f84661-3\"><a href=\"#org0f84661-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org0f84661-4\"><a href=\"#org0f84661-4\" aria-hidden=\"true\"><\/a><span class=\"co\"># It&#39;s really just `mu_X`...<\/span><\/span>\n<span id=\"org0f84661-5\"><a href=\"#org0f84661-5\" aria-hidden=\"true\"><\/a>E_X_rv <span class=\"op\">=<\/span> X_rv.owner.inputs[<span class=\"dv\">2<\/span>]<\/span>\n<span id=\"org0f84661-6\"><a href=\"#org0f84661-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org0f84661-7\"><a href=\"#org0f84661-7\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&#39;A * X_rv =<\/span><span class=\"ch\">\\n<\/span><span class=\"sc\">{}<\/span><span class=\"ch\">\\n<\/span><span class=\"st\">&#39;<\/span>.<span class=\"bu\">format<\/span>(tt.dot(A_tt, X_rv).tag.test_value))<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"org05ebd11\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org05ebd11-1\"><a href=\"#org05ebd11-1\" aria-hidden=\"true\"><\/a>A <span class=\"op\">*<\/span> X_rv <span class=\"op\">=<\/span><\/span>\n<span id=\"org05ebd11-2\"><a href=\"#org05ebd11-2\" aria-hidden=\"true\"><\/a>[[[  <span class=\"fl\">1.18524621<\/span> <span class=\"fl\">150.31045062<\/span>]<\/span>\n<span id=\"org05ebd11-3\"><a href=\"#org05ebd11-3\" aria-hidden=\"true\"><\/a>  [  <span class=\"fl\">1.07000851<\/span> <span class=\"fl\">150.65771936<\/span>]]<\/span>\n<span id=\"org05ebd11-4\"><a href=\"#org05ebd11-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org05ebd11-5\"><a href=\"#org05ebd11-5\" aria-hidden=\"true\"><\/a> [[  <span class=\"fl\">1.31685497<\/span> <span class=\"fl\">160.33572146<\/span>]<\/span>\n<span 
id=\"org05ebd11-6\"><a href=\"#org05ebd11-6\" aria-hidden=\"true\"><\/a>  [  <span class=\"fl\">0.33506491<\/span> <span class=\"fl\">160.82202495<\/span>]]]<\/span>\n<span id=\"org05ebd11-7\"><a href=\"#org05ebd11-7\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<p>As we can see, the multivariate normal\u2019s test\/sampled value has the correct shape for our matrix operator.<\/p>\n<div class=\"sourceCode\" id=\"orgd95a139\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgd95a139-1\"><a href=\"#orgd95a139-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> traceback<\/span>\n<span id=\"orgd95a139-2\"><a href=\"#orgd95a139-2\" aria-hidden=\"true\"><\/a><span class=\"cf\">try<\/span>:<\/span>\n<span id=\"orgd95a139-3\"><a href=\"#orgd95a139-3\" aria-hidden=\"true\"><\/a>    <span class=\"bu\">print<\/span>(<span class=\"st\">&#39;A * E[X_rv] =<\/span><span class=\"ch\">\\n<\/span><span class=\"sc\">{}<\/span><span class=\"ch\">\\n<\/span><span class=\"st\">&#39;<\/span>.<span class=\"bu\">format<\/span>(tt.dot(A_tt, E_X_rv).tag.test_value))<\/span>\n<span id=\"orgd95a139-4\"><a href=\"#orgd95a139-4\" aria-hidden=\"true\"><\/a><span class=\"cf\">except<\/span> <span class=\"pp\">ValueError<\/span> <span class=\"im\">as<\/span> e:<\/span>\n<span id=\"orgd95a139-5\"><a href=\"#orgd95a139-5\" aria-hidden=\"true\"><\/a>    <span class=\"bu\">print<\/span>(<span class=\"st\">&quot;&quot;<\/span>.join(traceback.format_exception_only(<span class=\"bu\">type<\/span>(e), e)))<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"org51f4e3a\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org51f4e3a-1\"><a href=\"#org51f4e3a-1\" aria-hidden=\"true\"><\/a><span class=\"pp\">ValueError<\/span>: shapes (<span class=\"dv\">2<\/span>,<span class=\"dv\">3<\/span>) <span class=\"kw\">and<\/span> (<span class=\"dv\">2<\/span>,) <span class=\"kw\">not<\/span> aligned: <span 
class=\"dv\">3<\/span> (dim <span class=\"dv\">1<\/span>) <span class=\"op\">!=<\/span> <span class=\"dv\">2<\/span> (dim <span class=\"dv\">0<\/span>)<\/span>\n<span id=\"org51f4e3a-2\"><a href=\"#org51f4e3a-2\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<p>However, the error shows that the multivariate normal\u2019s inputs (i.e.\u00a0the <code>Op<\/code> inputs), and specifically the mean parameter, do not reflect the shape of its sampled values, as one might expect them to. Tiling the mean parameter to the sample shape makes the product well-defined:<\/p>\n<div class=\"sourceCode\" id=\"org6e32649\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org6e32649-1\"><a href=\"#org6e32649-1\" aria-hidden=\"true\"><\/a>size_tile <span class=\"op\">=<\/span> <span class=\"bu\">tuple<\/span>(size_X_rv) <span class=\"op\">+<\/span> (<span class=\"dv\">1<\/span>,)<\/span>\n<span id=\"org6e32649-2\"><a href=\"#org6e32649-2\" aria-hidden=\"true\"><\/a>E_X_rv_ <span class=\"op\">=<\/span> tt.tile(E_X_rv, size_tile, X_rv.ndim)<\/span>\n<span id=\"org6e32649-3\"><a href=\"#org6e32649-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org6e32649-4\"><a href=\"#org6e32649-4\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&#39;A * E[X_rv] =<\/span><span class=\"ch\">\\n<\/span><span class=\"sc\">{}<\/span><span class=\"ch\">\\n<\/span><span class=\"st\">&#39;<\/span>.<span class=\"bu\">format<\/span>(tt.dot(A_tt, E_X_rv_).tag.test_value))<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"org7f64727\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org7f64727-1\"><a href=\"#org7f64727-1\" aria-hidden=\"true\"><\/a>A <span class=\"op\">*<\/span> E[X_rv] <span class=\"op\">=<\/span><\/span>\n<span id=\"org7f64727-2\"><a href=\"#org7f64727-2\" aria-hidden=\"true\"><\/a>[[[  <span class=\"dv\">0<\/span> <span class=\"dv\">150<\/span>]<\/span>\n<span id=\"org7f64727-3\"><a href=\"#org7f64727-3\" aria-hidden=\"true\"><\/a>  [  <span class=\"dv\">0<\/span> 
<span class=\"dv\">150<\/span>]]<\/span>\n<span id=\"org7f64727-4\"><a href=\"#org7f64727-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org7f64727-5\"><a href=\"#org7f64727-5\" aria-hidden=\"true\"><\/a> [[  <span class=\"dv\">0<\/span> <span class=\"dv\">160<\/span>]<\/span>\n<span id=\"org7f64727-6\"><a href=\"#org7f64727-6\" aria-hidden=\"true\"><\/a>  [  <span class=\"dv\">0<\/span> <span class=\"dv\">160<\/span>]]]<\/span>\n<span id=\"org7f64727-7\"><a href=\"#org7f64727-7\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<p>We can manually replicate the inputs so that they match the output shape, but a solution to the general problem requires a more organized response.<\/p>\n<\/div>\n<\/section>\n<section id=\"a-problem-with-conversion-from-pymc3\" class=\"level1\">\n<h1>A Problem with Conversion from PyMC3<\/h1>\n<p>As in <a href=\"#24875a2c31fa7f94ce562adddedc0bf8\">Willard, Brandon T. (2018)<\/a>, we can create mappings between existing PyMC3 random variables and their new <code>RandomVariable<\/code> equivalents.<\/p>\n<div class=\"example\" data-markdown=\"\">\n<div class=\"sourceCode\" id=\"org8bffc80\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org8bffc80-1\"><a href=\"#org8bffc80-1\" aria-hidden=\"true\"><\/a>pymc_theano_rv_equivs <span class=\"op\">=<\/span> {<\/span>\n<span id=\"org8bffc80-2\"><a href=\"#org8bffc80-2\" aria-hidden=\"true\"><\/a>    pm.Normal:<\/span>\n<span id=\"org8bffc80-3\"><a href=\"#org8bffc80-3\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">lambda<\/span> dist, rand_state:<\/span>\n<span id=\"org8bffc80-4\"><a href=\"#org8bffc80-4\" aria-hidden=\"true\"><\/a>    (<span class=\"va\">None<\/span>,<\/span>\n<span id=\"org8bffc80-5\"><a href=\"#org8bffc80-5\" aria-hidden=\"true\"><\/a>     <span class=\"co\"># PyMC3 shapes aren&#39;t NumPy-like size parameters, so we attempt to<\/span><\/span>\n<span id=\"org8bffc80-6\"><a href=\"#org8bffc80-6\" aria-hidden=\"true\"><\/a>     
<span class=\"co\"># adjust for that.<\/span><\/span>\n<span id=\"org8bffc80-7\"><a href=\"#org8bffc80-7\" aria-hidden=\"true\"><\/a>     NormalRV(dist.mu, dist.sd, size<span class=\"op\">=<\/span>dist.shape[<span class=\"dv\">1<\/span>:], rng<span class=\"op\">=<\/span>rand_state)),<\/span>\n<span id=\"org8bffc80-8\"><a href=\"#org8bffc80-8\" aria-hidden=\"true\"><\/a>    pm.MvNormal:<\/span>\n<span id=\"org8bffc80-9\"><a href=\"#org8bffc80-9\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">lambda<\/span> dist, rand_state:<\/span>\n<span id=\"org8bffc80-10\"><a href=\"#org8bffc80-10\" aria-hidden=\"true\"><\/a>    (<span class=\"va\">None<\/span>, MvNormalRV(dist.mu, dist.cov, size<span class=\"op\">=<\/span>dist.shape[<span class=\"dv\">1<\/span>:], rng<span class=\"op\">=<\/span>rand_state)),<\/span>\n<span id=\"org8bffc80-11\"><a href=\"#org8bffc80-11\" aria-hidden=\"true\"><\/a>}<\/span><\/code><\/pre><\/div>\n<\/div>\n<p>However, if we attempt the same PyMC3 graph conversion approach as before (i.e.\u00a0convert a PyMC3 model to a Theano <code>FunctionGraph<\/code> using <code>model_graph<\/code>, then replace PyMC3 random variable nodes with our new random variable types using <code>create_theano_rvs<\/code>), we\u2019re likely to run into a problem involving mismatched broadcastable dimensions.<\/p>\n<p>The problem arises because <strong>PyMC3 \u201cknows\u201d more broadcast information than it should<\/strong>, since it uses the Theano variables\u2019 test values in order to obtain concrete shapes for the random variables it creates. 
Using these concrete, non-symbolic shapes, it can determine exactly what would otherwise be ambiguous <a href=\"http:\/\/deeplearning.net\/software\/theano\/library\/tensor\/basic.html?highlight=broadcastable#theano.tensor.TensorType.broadcastable\">broadcastable dimensions<\/a> at the symbolic level.<\/p>\n<p>More specifically, broadcast information is required when a Theano <code>TensorType<\/code> is constructed, so the types of PyMC3 random variables can be inconsistent with (really, unnecessarily more restrictive than) those of our new random variables, which causes Theano to complain when we try to construct a <code>FunctionGraph<\/code>.<\/p>\n<div class=\"example\" data-markdown=\"\">\n<p>Consider the following example; it constructs two purely symbolic Theano matrices: one with a broadcastable first dimension (a row) and one without any broadcasting.<\/p>\n<div class=\"sourceCode\" id=\"orgbd54927\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgbd54927-1\"><a href=\"#orgbd54927-1\" aria-hidden=\"true\"><\/a>y_tt <span class=\"op\">=<\/span> tt.row(<span class=\"st\">&#39;y&#39;<\/span>)<\/span>\n<span id=\"orgbd54927-2\"><a href=\"#orgbd54927-2\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;y_tt.broadcastable = <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(y_tt.broadcastable))<\/span>\n<span id=\"orgbd54927-3\"><a href=\"#orgbd54927-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgbd54927-4\"><a href=\"#orgbd54927-4\" aria-hidden=\"true\"><\/a>x_tt <span class=\"op\">=<\/span> tt.matrix(<span class=\"st\">&#39;x&#39;<\/span>)<\/span>\n<span id=\"orgbd54927-5\"><a href=\"#orgbd54927-5\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;x_tt.broadcastable = <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(x_tt.broadcastable))<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"orga8f9d4c\"><pre class=\"sourceCode python\"><code 
class=\"sourceCode python\"><span id=\"orga8f9d4c-1\"><a href=\"#orga8f9d4c-1\" aria-hidden=\"true\"><\/a>y_tt.broadcastable <span class=\"op\">=<\/span> (<span class=\"va\">True<\/span>, <span class=\"va\">False<\/span>)<\/span>\n<span id=\"orga8f9d4c-2\"><a href=\"#orga8f9d4c-2\" aria-hidden=\"true\"><\/a>x_tt.broadcastable <span class=\"op\">=<\/span> (<span class=\"va\">False<\/span>, <span class=\"va\">False<\/span>)<\/span>\n<span id=\"orga8f9d4c-3\"><a href=\"#orga8f9d4c-3\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<p>Notice that, by default, <code>tt.matrix<\/code> signifies no broadcasting on either of its dimensions, whereas <code>tt.row<\/code> marks its first dimension as broadcastable.<\/p>\n<p>If we wish (or if <a href=\"http:\/\/deeplearning.net\/software\/theano\/library\/config.html#config.compute_test_value\">Theano\u2019s configuration demands<\/a> it), we can assign these symbolic matrices arbitrary test values, as long as they\u2019re consistent with their types (i.e.\u00a0matrices, or 2-dimensional arrays).<\/p>\n<p>In the following, we assign both broadcastable (i.e.\u00a0every dimension has size 1) and non-broadcastable test values.<\/p>\n<p>Test value is broadcastable:<\/p>\n<div class=\"sourceCode\" id=\"orgaadc727\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgaadc727-1\"><a href=\"#orgaadc727-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> contextlib <span class=\"im\">import<\/span> contextmanager<\/span>\n<span id=\"orgaadc727-2\"><a href=\"#orgaadc727-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgaadc727-3\"><a href=\"#orgaadc727-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgaadc727-4\"><a href=\"#orgaadc727-4\" aria-hidden=\"true\"><\/a>x_tt.tag.test_value <span class=\"op\">=<\/span> np.array([[<span class=\"dv\">5<\/span>]])<\/span>\n<span id=\"orgaadc727-5\"><a href=\"#orgaadc727-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgaadc727-6\"><a href=\"#orgaadc727-6\" aria-hidden=\"true\"><\/a><span 
class=\"bu\">print<\/span>(<span class=\"st\">&quot;test_value.broadcastable = <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"orgaadc727-7\"><a href=\"#orgaadc727-7\" aria-hidden=\"true\"><\/a>    tt.as_tensor_variable(x_tt.tag.test_value).broadcastable))<\/span>\n<span id=\"orgaadc727-8\"><a href=\"#orgaadc727-8\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;x_tt.broadcastable = <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(x_tt.broadcastable))<\/span>\n<span id=\"orgaadc727-9\"><a href=\"#orgaadc727-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgaadc727-10\"><a href=\"#orgaadc727-10\" aria-hidden=\"true\"><\/a><span class=\"at\">@contextmanager<\/span><\/span>\n<span id=\"orgaadc727-11\"><a href=\"#orgaadc727-11\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> short_exception_msg(exc_type):<\/span>\n<span id=\"orgaadc727-12\"><a href=\"#orgaadc727-12\" aria-hidden=\"true\"><\/a>    _verbosity <span class=\"op\">=<\/span> theano.config.exception_verbosity<\/span>\n<span id=\"orgaadc727-13\"><a href=\"#orgaadc727-13\" aria-hidden=\"true\"><\/a>    theano.config.exception_verbosity <span class=\"op\">=<\/span> <span class=\"st\">&#39;low&#39;<\/span><\/span>\n<span id=\"orgaadc727-14\"><a href=\"#orgaadc727-14\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">try<\/span>:<\/span>\n<span id=\"orgaadc727-15\"><a href=\"#orgaadc727-15\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">yield<\/span><\/span>\n<span id=\"orgaadc727-16\"><a href=\"#orgaadc727-16\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">except<\/span> exc_type <span class=\"im\">as<\/span> e:<\/span>\n<span id=\"orgaadc727-17\"><a href=\"#orgaadc727-17\" aria-hidden=\"true\"><\/a>        <span class=\"im\">import<\/span> traceback<\/span>\n<span id=\"orgaadc727-18\"><a href=\"#orgaadc727-18\" 
aria-hidden=\"true\"><\/a>        <span class=\"bu\">print<\/span>(<span class=\"st\">&quot;&quot;<\/span>.join(traceback.format_exception_only(<span class=\"bu\">type<\/span>(e), e)))<\/span>\n<span id=\"orgaadc727-19\"><a href=\"#orgaadc727-19\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">finally<\/span>:<\/span>\n<span id=\"orgaadc727-20\"><a href=\"#orgaadc727-20\" aria-hidden=\"true\"><\/a>        theano.config.exception_verbosity <span class=\"op\">=<\/span> _verbosity<\/span>\n<span id=\"orgaadc727-21\"><a href=\"#orgaadc727-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgaadc727-22\"><a href=\"#orgaadc727-22\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orgaadc727-23\"><a href=\"#orgaadc727-23\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> short_exception_msg(<span class=\"pp\">TypeError<\/span>):<\/span>\n<span id=\"orgaadc727-24\"><a href=\"#orgaadc727-24\" aria-hidden=\"true\"><\/a>    x_tt.shape<\/span>\n<span id=\"orgaadc727-25\"><a href=\"#orgaadc727-25\" aria-hidden=\"true\"><\/a>    <span class=\"bu\">print<\/span>(<span class=\"st\">&quot;shape checks out!&quot;<\/span>)<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"org9c02ec5\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org9c02ec5-1\"><a href=\"#org9c02ec5-1\" aria-hidden=\"true\"><\/a>test_value.broadcastable <span class=\"op\">=<\/span> (<span class=\"va\">True<\/span>, <span class=\"va\">True<\/span>)<\/span>\n<span id=\"org9c02ec5-2\"><a href=\"#org9c02ec5-2\" aria-hidden=\"true\"><\/a>x_tt.broadcastable <span class=\"op\">=<\/span> (<span class=\"va\">False<\/span>, <span class=\"va\">False<\/span>)<\/span>\n<span id=\"org9c02ec5-3\"><a href=\"#org9c02ec5-3\" aria-hidden=\"true\"><\/a>shape checks out<span class=\"op\">!<\/span><\/span>\n<span id=\"org9c02ec5-4\"><a href=\"#org9c02ec5-4\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"org56323a0\"><pre class=\"sourceCode 
python\"><code class=\"sourceCode python\"><span id=\"org56323a0-1\"><a href=\"#org56323a0-1\" aria-hidden=\"true\"><\/a>y_tt.tag.test_value <span class=\"op\">=<\/span> np.array([[<span class=\"dv\">5<\/span>]])<\/span>\n<span id=\"org56323a0-2\"><a href=\"#org56323a0-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org56323a0-3\"><a href=\"#org56323a0-3\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;test_value.broadcastable = <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"org56323a0-4\"><a href=\"#org56323a0-4\" aria-hidden=\"true\"><\/a>    tt.as_tensor_variable(y_tt.tag.test_value).broadcastable))<\/span>\n<span id=\"org56323a0-5\"><a href=\"#org56323a0-5\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;y_tt.broadcastable = <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(y_tt.broadcastable))<\/span>\n<span id=\"org56323a0-6\"><a href=\"#org56323a0-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org56323a0-7\"><a href=\"#org56323a0-7\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> short_exception_msg(<span class=\"pp\">TypeError<\/span>):<\/span>\n<span id=\"org56323a0-8\"><a href=\"#org56323a0-8\" aria-hidden=\"true\"><\/a>    y_tt.shape<\/span>\n<span id=\"org56323a0-9\"><a href=\"#org56323a0-9\" aria-hidden=\"true\"><\/a>    <span class=\"bu\">print<\/span>(<span class=\"st\">&quot;shape checks out!&quot;<\/span>)<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"orga78b1b0\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orga78b1b0-1\"><a href=\"#orga78b1b0-1\" aria-hidden=\"true\"><\/a>test_value.broadcastable <span class=\"op\">=<\/span> (<span class=\"va\">True<\/span>, <span class=\"va\">True<\/span>)<\/span>\n<span id=\"orga78b1b0-2\"><a href=\"#orga78b1b0-2\" 
aria-hidden=\"true\"><\/a>y_tt.broadcastable <span class=\"op\">=<\/span> (<span class=\"va\">True<\/span>, <span class=\"va\">False<\/span>)<\/span>\n<span id=\"orga78b1b0-3\"><a href=\"#orga78b1b0-3\" aria-hidden=\"true\"><\/a>shape checks out<span class=\"op\">!<\/span><\/span>\n<span id=\"orga78b1b0-4\"><a href=\"#orga78b1b0-4\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<p>Test value is <strong>not<\/strong> broadcastable:<\/p>\n<div class=\"sourceCode\" id=\"orge2d0568\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orge2d0568-1\"><a href=\"#orge2d0568-1\" aria-hidden=\"true\"><\/a>x_tt.tag.test_value <span class=\"op\">=<\/span> np.array([[<span class=\"dv\">5<\/span>, <span class=\"dv\">4<\/span>]])<\/span>\n<span id=\"orge2d0568-2\"><a href=\"#orge2d0568-2\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;test_value.broadcastable = <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"orge2d0568-3\"><a href=\"#orge2d0568-3\" aria-hidden=\"true\"><\/a>    tt.as_tensor_variable(x_tt.tag.test_value).broadcastable))<\/span>\n<span id=\"orge2d0568-4\"><a href=\"#orge2d0568-4\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;x_tt.broadcastable = <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(x_tt.broadcastable))<\/span>\n<span id=\"orge2d0568-5\"><a href=\"#orge2d0568-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"orge2d0568-6\"><a href=\"#orge2d0568-6\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> short_exception_msg(<span class=\"pp\">TypeError<\/span>):<\/span>\n<span id=\"orge2d0568-7\"><a href=\"#orge2d0568-7\" aria-hidden=\"true\"><\/a>    x_tt.shape<\/span>\n<span id=\"orge2d0568-8\"><a href=\"#orge2d0568-8\" aria-hidden=\"true\"><\/a>    <span class=\"bu\">print<\/span>(<span 
class=\"st\">&quot;shape checks out!&quot;<\/span>)<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"org3ac23d5\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org3ac23d5-1\"><a href=\"#org3ac23d5-1\" aria-hidden=\"true\"><\/a>test_value.broadcastable <span class=\"op\">=<\/span> (<span class=\"va\">True<\/span>, <span class=\"va\">False<\/span>)<\/span>\n<span id=\"org3ac23d5-2\"><a href=\"#org3ac23d5-2\" aria-hidden=\"true\"><\/a>x_tt.broadcastable <span class=\"op\">=<\/span> (<span class=\"va\">False<\/span>, <span class=\"va\">False<\/span>)<\/span>\n<span id=\"org3ac23d5-3\"><a href=\"#org3ac23d5-3\" aria-hidden=\"true\"><\/a>shape checks out<span class=\"op\">!<\/span><\/span>\n<span id=\"org3ac23d5-4\"><a href=\"#org3ac23d5-4\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"org53e07ae\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org53e07ae-1\"><a href=\"#org53e07ae-1\" aria-hidden=\"true\"><\/a>y_tt.tag.test_value <span class=\"op\">=<\/span> np.array([[<span class=\"dv\">5<\/span>, <span class=\"dv\">4<\/span>], [<span class=\"dv\">3<\/span>, <span class=\"dv\">2<\/span>]])<\/span>\n<span id=\"org53e07ae-2\"><a href=\"#org53e07ae-2\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;test_value.broadcastable = <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"org53e07ae-3\"><a href=\"#org53e07ae-3\" aria-hidden=\"true\"><\/a>    tt.as_tensor_variable(y_tt.tag.test_value).broadcastable))<\/span>\n<span id=\"org53e07ae-4\"><a href=\"#org53e07ae-4\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&quot;y_tt.broadcastable = <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&quot;<\/span>.<span class=\"bu\">format<\/span>(y_tt.broadcastable))<\/span>\n<span id=\"org53e07ae-5\"><a 
href=\"#org53e07ae-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org53e07ae-6\"><a href=\"#org53e07ae-6\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> short_exception_msg(<span class=\"pp\">TypeError<\/span>):<\/span>\n<span id=\"org53e07ae-7\"><a href=\"#org53e07ae-7\" aria-hidden=\"true\"><\/a>    y_tt.shape<\/span>\n<span id=\"org53e07ae-8\"><a href=\"#org53e07ae-8\" aria-hidden=\"true\"><\/a>    <span class=\"bu\">print<\/span>(<span class=\"st\">&quot;shape checks out!&quot;<\/span>)<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"org60e8d8f\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org60e8d8f-1\"><a href=\"#org60e8d8f-1\" aria-hidden=\"true\"><\/a>test_value.broadcastable <span class=\"op\">=<\/span> (<span class=\"va\">False<\/span>, <span class=\"va\">False<\/span>)<\/span>\n<span id=\"org60e8d8f-2\"><a href=\"#org60e8d8f-2\" aria-hidden=\"true\"><\/a>y_tt.broadcastable <span class=\"op\">=<\/span> (<span class=\"va\">True<\/span>, <span class=\"va\">False<\/span>)<\/span>\n<span id=\"org60e8d8f-3\"><a href=\"#org60e8d8f-3\" aria-hidden=\"true\"><\/a><span class=\"pp\">TypeError<\/span>: For compute_test_value, one <span class=\"bu\">input<\/span> test value does <span class=\"kw\">not<\/span> have the requested <span class=\"bu\">type<\/span>.<\/span>\n<span id=\"org60e8d8f-4\"><a href=\"#org60e8d8f-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org60e8d8f-5\"><a href=\"#org60e8d8f-5\" aria-hidden=\"true\"><\/a>Backtrace when that variable <span class=\"kw\">is<\/span> created:<\/span>\n<span id=\"org60e8d8f-6\"><a href=\"#org60e8d8f-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org60e8d8f-7\"><a href=\"#org60e8d8f-7\" aria-hidden=\"true\"><\/a>  File <span class=\"st\">&quot;\/home\/bwillard\/apps\/anaconda3\/envs\/github-website\/lib\/python3.6\/site-packages\/IPython\/terminal\/interactiveshell.py&quot;<\/span>, line <span class=\"dv\">485<\/span>, <span 
class=\"kw\">in<\/span> mainloop<\/span>\n<span id=\"org60e8d8f-8\"><a href=\"#org60e8d8f-8\" aria-hidden=\"true\"><\/a>    <span class=\"va\">self<\/span>.interact()<\/span>\n<span id=\"org60e8d8f-9\"><a href=\"#org60e8d8f-9\" aria-hidden=\"true\"><\/a>  File <span class=\"st\">&quot;\/home\/bwillard\/apps\/anaconda3\/envs\/github-website\/lib\/python3.6\/site-packages\/IPython\/terminal\/interactiveshell.py&quot;<\/span>, line <span class=\"dv\">476<\/span>, <span class=\"kw\">in<\/span> interact<\/span>\n<span id=\"org60e8d8f-10\"><a href=\"#org60e8d8f-10\" aria-hidden=\"true\"><\/a>    <span class=\"va\">self<\/span>.run_cell(code, store_history<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"org60e8d8f-11\"><a href=\"#org60e8d8f-11\" aria-hidden=\"true\"><\/a>  File <span class=\"st\">&quot;\/home\/bwillard\/apps\/anaconda3\/envs\/github-website\/lib\/python3.6\/site-packages\/IPython\/core\/interactiveshell.py&quot;<\/span>, line <span class=\"dv\">2662<\/span>, <span class=\"kw\">in<\/span> run_cell<\/span>\n<span id=\"org60e8d8f-12\"><a href=\"#org60e8d8f-12\" aria-hidden=\"true\"><\/a>    raw_cell, store_history, silent, shell_futures)<\/span>\n<span id=\"org60e8d8f-13\"><a href=\"#org60e8d8f-13\" aria-hidden=\"true\"><\/a>  File <span class=\"st\">&quot;\/home\/bwillard\/apps\/anaconda3\/envs\/github-website\/lib\/python3.6\/site-packages\/IPython\/core\/interactiveshell.py&quot;<\/span>, line <span class=\"dv\">2785<\/span>, <span class=\"kw\">in<\/span> _run_cell<\/span>\n<span id=\"org60e8d8f-14\"><a href=\"#org60e8d8f-14\" aria-hidden=\"true\"><\/a>    interactivity<span class=\"op\">=<\/span>interactivity, compiler<span class=\"op\">=<\/span>compiler, result<span class=\"op\">=<\/span>result)<\/span>\n<span id=\"org60e8d8f-15\"><a href=\"#org60e8d8f-15\" aria-hidden=\"true\"><\/a>  File <span 
class=\"st\">&quot;\/home\/bwillard\/apps\/anaconda3\/envs\/github-website\/lib\/python3.6\/site-packages\/IPython\/core\/interactiveshell.py&quot;<\/span>, line <span class=\"dv\">2909<\/span>, <span class=\"kw\">in<\/span> run_ast_nodes<\/span>\n<span id=\"org60e8d8f-16\"><a href=\"#org60e8d8f-16\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> <span class=\"va\">self<\/span>.run_code(code, result):<\/span>\n<span id=\"org60e8d8f-17\"><a href=\"#org60e8d8f-17\" aria-hidden=\"true\"><\/a>  File <span class=\"st\">&quot;\/home\/bwillard\/apps\/anaconda3\/envs\/github-website\/lib\/python3.6\/site-packages\/IPython\/core\/interactiveshell.py&quot;<\/span>, line <span class=\"dv\">2963<\/span>, <span class=\"kw\">in<\/span> run_code<\/span>\n<span id=\"org60e8d8f-18\"><a href=\"#org60e8d8f-18\" aria-hidden=\"true\"><\/a>    <span class=\"bu\">exec<\/span>(code_obj, <span class=\"va\">self<\/span>.user_global_ns, <span class=\"va\">self<\/span>.user_ns)<\/span>\n<span id=\"org60e8d8f-19\"><a href=\"#org60e8d8f-19\" aria-hidden=\"true\"><\/a>  File <span class=\"st\">&quot;&lt;ipython-input-19-7427b1688530&gt;&quot;<\/span>, line <span class=\"dv\">1<\/span>, <span class=\"kw\">in<\/span> <span class=\"op\">&lt;<\/span>module<span class=\"op\">&gt;<\/span><\/span>\n<span id=\"org60e8d8f-20\"><a href=\"#org60e8d8f-20\" aria-hidden=\"true\"><\/a>    __org_babel_python_fname <span class=\"op\">=<\/span> <span class=\"st\">&#39;\/tmp\/user\/1000\/babel-fsZXPU\/python-cZypXi&#39;<\/span><span class=\"op\">;<\/span> __org_babel_python_fh <span class=\"op\">=<\/span> <span class=\"bu\">open<\/span>(__org_babel_python_fname)<span class=\"op\">;<\/span> <span class=\"bu\">exec<\/span>(<span class=\"bu\">compile<\/span>(__org_babel_python_fh.read(), __org_babel_python_fname, <span class=\"st\">&#39;exec&#39;<\/span>))<span class=\"op\">;<\/span> __org_babel_python_fh.close()<\/span>\n<span id=\"org60e8d8f-21\"><a href=\"#org60e8d8f-21\" aria-hidden=\"true\"><\/a>  
File <span class=\"st\">&quot;\/tmp\/user\/1000\/babel-fsZXPU\/python-cZypXi&quot;<\/span>, line <span class=\"dv\">1<\/span>, <span class=\"kw\">in<\/span> <span class=\"op\">&lt;<\/span>module<span class=\"op\">&gt;<\/span><\/span>\n<span id=\"org60e8d8f-22\"><a href=\"#org60e8d8f-22\" aria-hidden=\"true\"><\/a>    y_tt <span class=\"op\">=<\/span> tt.row(<span class=\"st\">&#39;y&#39;<\/span>)<\/span>\n<span id=\"org60e8d8f-23\"><a href=\"#org60e8d8f-23\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org60e8d8f-24\"><a href=\"#org60e8d8f-24\" aria-hidden=\"true\"><\/a>The error when converting the test value to that variable <span class=\"bu\">type<\/span>:<\/span>\n<span id=\"org60e8d8f-25\"><a href=\"#org60e8d8f-25\" aria-hidden=\"true\"><\/a>Non<span class=\"op\">-<\/span>unit value on shape on a broadcastable dimension.<\/span>\n<span id=\"org60e8d8f-26\"><a href=\"#org60e8d8f-26\" aria-hidden=\"true\"><\/a>(<span class=\"dv\">2<\/span>, <span class=\"dv\">2<\/span>)<\/span>\n<span id=\"org60e8d8f-27\"><a href=\"#org60e8d8f-27\" aria-hidden=\"true\"><\/a>(<span class=\"va\">True<\/span>, <span class=\"va\">False<\/span>)<\/span>\n<span id=\"org60e8d8f-28\"><a href=\"#org60e8d8f-28\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<p>Simply put: non-broadcastable Theano tensor variable types can take broadcastable and non-broadcastable values, while broadcastable types can only take broadcastable values.<\/p>\n<\/div>\n<p>What we can take from the example above is that if we determine that a vector has broadcastable dimensions using test values\u2013as PyMC3 does\u2013we unnecessarily introduce restrictions and potential inconsistencies down the line. 
One point of origin for such issues is <strong>shared variables<\/strong>.<\/p>\n<\/section>\n<section id=\"discussion\" class=\"level1\">\n<h1>Discussion<\/h1>\n<p>In follow-ups to this series, we\u2019ll address a few loose ends, such as<\/p>\n<ul>\n<li>the inclusion of density functions and likelihoods,<\/li>\n<li>decompositions\/reductions of overlapping multivariate types (e.g.\u00a0transforms between tensors of univariate normals and equivalent multivariate normals),<\/li>\n<li>canonicalization of graphs containing <code>RandomVariable<\/code> terms,<\/li>\n<li>and more optimizations that specifically target MCMC schemes (e.g.\u00a0automatic conversion to scale mixture decompositions).<\/li>\n<\/ul>\n<\/section>\n<\/body>\n<\/html>\n","category":[{"@attributes":{"term":"articles"}},{"@attributes":{"term":"pymc3"}},{"@attributes":{"term":"theano"}},{"@attributes":{"term":"statistics"}},{"@attributes":{"term":"symbolic computation"}},{"@attributes":{"term":"python"}},{"@attributes":{"term":"probability theory"}}]},{"title":"Readable Strings and Relational Programming in Hy","link":{"@attributes":{"href":"https:\/\/brandonwillard.github.io\/readable-strings-and-relational-programming-in-hy.html","rel":"alternate"}},"published":"2018-12-20T00:00:00-06:00","updated":"2019-01-07T00:00:00-06:00","author":{"name":"Brandon T. Willard"},"id":"tag:brandonwillard.github.io,2018-12-20:\/readable-strings-and-relational-programming-in-hy.html","summary":{"@attributes":{"type":"html"}},"content":"<!DOCTYPE html PUBLIC \"-\/\/W3C\/\/DTD XHTML 1.0 Transitional\/\/EN\" \"http:\/\/www.w3.org\/TR\/xhtml1\/DTD\/xhtml1-transitional.dtd\">\n<html xmlns=\"http:\/\/www.w3.org\/1999\/xhtml\">\n<head>\n  <meta http-equiv=\"Content-Type\" content=\"text\/html; charset=utf-8\" \/>\n  <meta http-equiv=\"Content-Style-Type\" content=\"text\/css\" \/>\n  <meta name=\"generator\" content=\"pandoc\" \/>\n  <meta name=\"author\" content=\"Brandon T. 
Willard\" \/>\n  <title>Readable Strings and Relational Programming in Hy<\/title>\n  <style type=\"text\/css\">code{white-space: pre;}<\/style>\n  <style type=\"text\/css\">\npre > code.sourceCode { white-space: pre; position: relative; }\npre > code.sourceCode > span { display: inline-block; line-height: 1.25; }\npre > code.sourceCode > span:empty { height: 1.2em; }\ncode.sourceCode > span { color: inherit; text-decoration: inherit; }\ndiv.sourceCode { margin: 1em 0; }\npre.sourceCode { margin: 0; }\n@media screen {\ndiv.sourceCode { overflow: auto; }\n}\n@media print {\npre > code.sourceCode { white-space: pre-wrap; }\npre > code.sourceCode > span { text-indent: -5em; padding-left: 5em; }\n}\npre.numberSource code\n  { counter-reset: source-line 0; }\npre.numberSource code > span\n  { position: relative; left: -4em; counter-increment: source-line; }\npre.numberSource code > span > a:first-child::before\n  { content: counter(source-line);\n    position: relative; left: -1em; text-align: right; vertical-align: baseline;\n    border: none; display: inline-block;\n    -webkit-touch-callout: none; -webkit-user-select: none;\n    -khtml-user-select: none; -moz-user-select: none;\n    -ms-user-select: none; user-select: none;\n    padding: 0 4px; width: 4em;\n    color: #aaaaaa;\n  }\npre.numberSource { margin-left: 3em; border-left: 1px solid #aaaaaa;  padding-left: 4px; }\ndiv.sourceCode\n  {   }\n@media screen {\npre > code.sourceCode > span > a:first-child::before { text-decoration: underline; }\n}\ncode span.al { color: #ff0000; font-weight: bold; } \/* Alert *\/\ncode span.an { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Annotation *\/\ncode span.at { color: #7d9029; } \/* Attribute *\/\ncode span.bn { color: #40a070; } \/* BaseN *\/\ncode span.bu { } \/* BuiltIn *\/\ncode span.cf { color: #007020; font-weight: bold; } \/* ControlFlow *\/\ncode span.ch { color: #4070a0; } \/* Char *\/\ncode span.cn { color: #880000; } \/* Constant *\/\ncode 
span.co { color: #60a0b0; font-style: italic; } \/* Comment *\/\ncode span.cv { color: #60a0b0; font-weight: bold; font-style: italic; } \/* CommentVar *\/\ncode span.do { color: #ba2121; font-style: italic; } \/* Documentation *\/\ncode span.dt { color: #902000; } \/* DataType *\/\ncode span.dv { color: #40a070; } \/* DecVal *\/\ncode span.er { color: #ff0000; font-weight: bold; } \/* Error *\/\ncode span.ex { } \/* Extension *\/\ncode span.fl { color: #40a070; } \/* Float *\/\ncode span.fu { color: #06287e; } \/* Function *\/\ncode span.im { } \/* Import *\/\ncode span.in { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Information *\/\ncode span.kw { color: #007020; font-weight: bold; } \/* Keyword *\/\ncode span.op { color: #666666; } \/* Operator *\/\ncode span.ot { color: #007020; } \/* Other *\/\ncode span.pp { color: #bc7a00; } \/* Preprocessor *\/\ncode span.sc { color: #4070a0; } \/* SpecialChar *\/\ncode span.ss { color: #bb6688; } \/* SpecialString *\/\ncode span.st { color: #4070a0; } \/* String *\/\ncode span.va { color: #19177c; } \/* Variable *\/\ncode span.vs { color: #4070a0; } \/* VerbatimString *\/\ncode span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Warning *\/\n  <\/style>\n  <!--        <script src=\"https:\/\/cdn.jsdelivr.net\/npm\/mathjax@3\/es5\/tex-mml-chtml.js\" type=\"text\/javascript\"><\/script> -->\n  <script src=\"https:\/\/cdnjs.cloudflare.com\/ajax\/libs\/mathjax\/2.7.0\/MathJax.js?config=TeX-AMS_HTML\" id=\"MathJax-script\"><\/script>\n  <script>\n   MathJax.Hub.Config({\n       tex2jax: {\n           processEnvironments: true,\n           processRefs: false\n       },\n       TeX: {\n           equationNumbers: { autoNumber: \"AMS\" },\n           extensions: [\"AMSmath.js\",\"AMSsymbols.js\",\"noErrors.js\",\"noUndefined.js\"]\n       }\n   });\n  <\/script>\n<\/head>\n<body>\n<!--  -->\n<!-- <div id=\"header\"> -->\n<!-- <h1 class=\"title\">Readable Strings and Relational Programming in 
Hy<\/h1> -->\n<!--  -->\n<!--  -->\n<!-- <h2 class=\"author\">Brandon T. Willard<\/h2> -->\n<!--  -->\n<!--  -->\n<!-- <h3 class=\"date\">2018\u201312\u201320<\/h3> -->\n<!--  -->\n<!-- <\/div> -->\n<!--  -->\n<div class=\"abstract\">\n<p>Just some thoughts on a generalized <code>repr<\/code> for Hy and some connections with relational programming.<\/p>\n<\/div>\n<section id=\"introduction\" class=\"level1\">\n<h1>Introduction<\/h1>\n<p>In the past few months, I\u2019ve been working on <a href=\"https:\/\/github.com\/hylang\/hy\">Hy<\/a> a lot. It\u2019s been great for translating symbolic computation ideas that originated in the Lisp community, or simply for the generic meta-programming inherent to the subject.<\/p>\n<p>The feature I\u2019ve missed most is \u201creadable\u201d print-outs from the REPL. In this case, \u201creadable\u201d means \u201ca string that can be <code>eval<\/code>\u2019ed to [re-]produce the object it\u2019s meant to represent\u201d. <a href=\"https:\/\/docs.python.org\/3\/library\/functions.html#repr\">Python calls the function(s) that produce these strings \u201c<code>repr<\/code>\u201ds<\/a>, and it provides a generic <code>repr<\/code> function (with limited Python \u201creadability\u201d guarantees) and a <code>__repr__<\/code> method for object\/class-level customization.<\/p>\n<div class=\"example\" data-markdown=\"\">\n<div class=\"sourceCode\" id=\"org1fcc4d3\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org1fcc4d3-1\"><a href=\"#org1fcc4d3-1\" aria-hidden=\"true\"><\/a>test_obj <span class=\"op\">=<\/span> {<span class=\"st\">&quot;a&quot;<\/span>: <span class=\"dv\">1<\/span>, <span class=\"st\">&quot;b&quot;<\/span>: [<span class=\"dv\">2<\/span>, <span class=\"dv\">3<\/span>]}<\/span>\n<span id=\"org1fcc4d3-2\"><a href=\"#org1fcc4d3-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org1fcc4d3-3\"><a href=\"#org1fcc4d3-3\" aria-hidden=\"true\"><\/a><span class=\"co\"># Produce 
a readable string using `repr`<\/span><\/span>\n<span id=\"org1fcc4d3-4\"><a href=\"#org1fcc4d3-4\" aria-hidden=\"true\"><\/a>obj_repr_str <span class=\"op\">=<\/span> <span class=\"bu\">repr<\/span>(test_obj)<\/span>\n<span id=\"org1fcc4d3-5\"><a href=\"#org1fcc4d3-5\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(obj_repr_str)<\/span>\n<span id=\"org1fcc4d3-6\"><a href=\"#org1fcc4d3-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org1fcc4d3-7\"><a href=\"#org1fcc4d3-7\" aria-hidden=\"true\"><\/a><span class=\"co\"># Re-create the object from its readable string form<\/span><\/span>\n<span id=\"org1fcc4d3-8\"><a href=\"#org1fcc4d3-8\" aria-hidden=\"true\"><\/a>obj_from_repr <span class=\"op\">=<\/span> <span class=\"bu\">eval<\/span>(obj_repr_str)<\/span>\n<span id=\"org1fcc4d3-9\"><a href=\"#org1fcc4d3-9\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(obj_from_repr)<\/span>\n<span id=\"org1fcc4d3-10\"><a href=\"#org1fcc4d3-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"org1fcc4d3-11\"><a href=\"#org1fcc4d3-11\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(test_obj <span class=\"op\">==<\/span> obj_from_repr)<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"org636a628\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org636a628-1\"><a href=\"#org636a628-1\" aria-hidden=\"true\"><\/a>{<span class=\"st\">&#39;a&#39;<\/span>: <span class=\"dv\">1<\/span>, <span class=\"st\">&#39;b&#39;<\/span>: [<span class=\"dv\">2<\/span>, <span class=\"dv\">3<\/span>]}<\/span>\n<span id=\"org636a628-2\"><a href=\"#org636a628-2\" aria-hidden=\"true\"><\/a>{<span class=\"st\">&#39;a&#39;<\/span>: <span class=\"dv\">1<\/span>, <span class=\"st\">&#39;b&#39;<\/span>: [<span class=\"dv\">2<\/span>, <span class=\"dv\">3<\/span>]}<\/span>\n<span id=\"org636a628-3\"><a href=\"#org636a628-3\" aria-hidden=\"true\"><\/a><span class=\"va\">True<\/span><\/span>\n<span id=\"org636a628-4\"><a 
href=\"#org636a628-4\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<\/div>\n<p>There\u2019s already a <code>hy.contrib.hy-repr<\/code> module that gets most of the way there, but it doesn\u2019t implement the Python standard library\u2019s <code>reprlib.Repr<\/code>. The class <code>reprlib.Repr<\/code> imposes limits on the display lengths of the strings it produces, and its source code provides a few standard library implementations of primitive object <code>repr<\/code>s, which require only trivial changes to produce the desired Hy syntax.<\/p>\n<p>For these reasons (and an overall interest in using and translating more of the Python standard library to Hy), I decided to try a quick refactoring of <code>hy.contrib.hy-repr<\/code> so that it implements <code>reprlib.Repr<\/code>.<\/p>\n<\/section>\n<section id=\"the-hy-repr-problems\" class=\"level1\">\n<h1>The Hy <code>repr<\/code> Problem(s)<\/h1>\n<p>The translation of Hy AST to string form is fairly straightforward. In most cases, one only needs to change the <code>repr<\/code>s for Python primitives and basic function calls (e.g.\u00a0from <code>func(1)<\/code> to <code>(func 1)<\/code>); however, changing a couple of lines in the <code>repr<\/code>\/<code>__repr__<\/code> functions of every Python builtin is tedious.<\/p>\n<p>Furthermore, what about custom object <code>__repr__<\/code> methods? While one might be able to manually patch most, if not all, of the (Python-implemented) standard library objects, there are far too many third-party library <code>__repr__<\/code>s with exactly the same trivial function-call form to patch them all reasonably.<\/p>\n<section id=\"some-approaches\" class=\"level2\">\n<h2>Some approaches<\/h2>\n<p>The first few things that come to mind when considering a more general approach to Python-to-Hy <code>__repr__<\/code> translation involve some use of the existing <code>repr<\/code> code. 
That might take the form of string manipulation of <code>repr<\/code> output, which <code>hy.contrib.hy-repr<\/code> already does in some cases, or some use of a <code>repr<\/code> function\u2019s source or code object.<\/p>\n<p>The latter has the potential to be more thorough and far-reaching, but it is also considerably more involved and computationally inefficient. The regex approach has its own drawbacks: although it seems easier to implement and, in limited cases, efficient enough for most purposes, it is also much more brittle.<\/p>\n<p>Fortunately, neither approach is necessary: when the existing <code>repr<\/code> output is Python-readable, it can be parsed by <code>ast.parse<\/code>. The function <code>ast.parse<\/code> effectively handles the regex work and yields the bulk of the information needed for a Hy <code>repr<\/code> string: the function name and its (positional and keyword) arguments.<\/p>\n<div class=\"example\" data-markdown=\"\">\n<p>Let\u2019s say we implement our own class and its <code>__repr__<\/code>.<\/p>\n<pre id=\"orgcb48770\" class=\"hy\"><code>(defclass TestClass [object]\n  (defn --init-- [self arg1 arg2 &amp;optional kwarg1 kwarg2]\n    (setv self.arg1 arg1\n          self.arg2 arg2\n          self.kwarg1 kwarg1\n          self.kwarg2 kwarg2))\n  (defn --repr-- [self]\n    (.format &quot;TestClass({}, {}, kwarg1={}, kwarg2={})&quot;\n             #* (lfor a [self.arg1 self.arg2\n                         self.kwarg1 self.kwarg2]\n                      (repr a)))))\n\n(setv test-obj (TestClass 1 {&quot;a&quot; 1 &quot;b&quot; 2} :kwarg1 1 :kwarg2 &quot;ok&quot;))\n(print (repr test-obj))<\/code><\/pre>\n<div class=\"sourceCode\" id=\"org2d96184\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org2d96184-1\"><a href=\"#org2d96184-1\" aria-hidden=\"true\"><\/a>TestClass(<span class=\"dv\">1<\/span>, {<span 
class=\"st\">&#39;a&#39;<\/span>: <span class=\"dv\">1<\/span>, <span class=\"st\">&#39;b&#39;<\/span>: <span class=\"dv\">2<\/span>}, kwarg1<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>, kwarg2<span class=\"op\">=<\/span><span class=\"st\">&#39;ok&#39;<\/span>)<\/span><\/code><\/pre><\/div>\n<p>Since the results are readable, we can do the following:<\/p>\n<pre id=\"org382ba33\" class=\"hy\"><code>(import ast astor)\n(setv repr-ast (ast.parse (repr test-obj) :mode &quot;eval&quot;))\n(print (astor.dump repr-ast))<\/code><\/pre>\n<div class=\"sourceCode\" id=\"orgf08b55b\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgf08b55b-1\"><a href=\"#orgf08b55b-1\" aria-hidden=\"true\"><\/a>Expression(<\/span>\n<span id=\"orgf08b55b-2\"><a href=\"#orgf08b55b-2\" aria-hidden=\"true\"><\/a>    body<span class=\"op\">=<\/span>Call(func<span class=\"op\">=<\/span>Name(<span class=\"bu\">id<\/span><span class=\"op\">=<\/span><span class=\"st\">&#39;TestClass&#39;<\/span>),<\/span>\n<span id=\"orgf08b55b-3\"><a href=\"#orgf08b55b-3\" aria-hidden=\"true\"><\/a>              args<span class=\"op\">=<\/span>[Num(n<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>),<\/span>\n<span id=\"orgf08b55b-4\"><a href=\"#orgf08b55b-4\" aria-hidden=\"true\"><\/a>                    Dict(keys<span class=\"op\">=<\/span>[Str(s<span class=\"op\">=<\/span><span class=\"st\">&#39;a&#39;<\/span>), Str(s<span class=\"op\">=<\/span><span class=\"st\">&#39;b&#39;<\/span>)],<\/span>\n<span id=\"orgf08b55b-5\"><a href=\"#orgf08b55b-5\" aria-hidden=\"true\"><\/a>                         values<span class=\"op\">=<\/span>[Num(n<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>), Num(n<span class=\"op\">=<\/span><span class=\"dv\">2<\/span>)])],<\/span>\n<span id=\"orgf08b55b-6\"><a href=\"#orgf08b55b-6\" aria-hidden=\"true\"><\/a>              keywords<span class=\"op\">=<\/span>[keyword(arg<span class=\"op\">=<\/span><span 
class=\"st\">&#39;kwarg1&#39;<\/span>, value<span class=\"op\">=<\/span>Num(n<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>)),<\/span>\n<span id=\"orgf08b55b-7\"><a href=\"#orgf08b55b-7\" aria-hidden=\"true\"><\/a>                        keyword(arg<span class=\"op\">=<\/span><span class=\"st\">&#39;kwarg2&#39;<\/span>, value<span class=\"op\">=<\/span>Str(s<span class=\"op\">=<\/span><span class=\"st\">&#39;ok&#39;<\/span>))]))<\/span><\/code><\/pre><\/div>\n<\/div>\n<\/section>\n<section id=\"a-generalized-hy-repr-prototype\" class=\"level2\">\n<h2>A Generalized Hy <code>repr<\/code> Prototype<\/h2>\n<p>With existing <code>repr<\/code> output converted to Python AST by Python itself (using <code>ast.parse<\/code>), we can produce readable Hy strings from the resulting AST objects.<\/p>\n<p>In this scenario, we need only be concerned with the conversion of Python AST into readable Hy strings. This is essentially the inverse of the Hy compiler: in other words, a Hy decompiler. For <code>repr<\/code> purposes, only function call expressions and their arguments need to be decompiled. Unfortunately, function arguments can consist of arbitrary Python\/Hy objects, and that\u2019s how the decompilation responsibilities start to expand. If we limit our scope to a reasonable subset of Python builtins\/primitives, the results can still be quite effective without requiring a complete decompiler.<\/p>\n<p>On the downside, if a Hy <code>repr<\/code> implementation overrides the built-in <code>repr<\/code>, then arguments in existing <code>repr<\/code>\/<code>__repr__<\/code>s might already be converted by the overridden <code>repr<\/code>; however, parsing with <code>ast.parse<\/code> will undo\/discard those conversions. 
Even so, custom class <code>__repr__<\/code>s aren\u2019t guaranteed to use the built-in <code>repr<\/code> on their arguments, so attempts to salvage already-converted <code>repr<\/code> output are undeniably fraught with complications.<\/p>\n<div class=\"example\" data-markdown=\"\">\n<p>Working from the <code>repr<\/code>-produced AST above, I mocked up a quick prototype for a generic Python-to-Hy conversion function.<\/p>\n<pre id=\"orga431a0a\" class=\"hy\"><code>(import ast)\n(import builtins)\n\n(import [hy.contrib.hy-repr [hy-repr :as -hy-repr]])\n\n(defn ast-funcall-to-hy [ast-obj repr1\n                         &amp;optional [level 1]]\n  &quot;Turn Python `ast.Call` expressions into Hy `repr` strings.\n\nXXX: Only a very minimal subset of Python-to-Hy AST is implemented.\n\nThis can be used to turn a \\&quot;readable\\&quot; `repr` result, via an actual \\&quot;read\\&quot; by\n`ast.parse`, to Python AST then Hy AST.\n&quot;\n  (assert (and (instance? ast.Expression ast-obj)\n               (instance? ast.Call ast-obj.body)))\n  (setv func-name (. ast-obj body func id))\n  ;; Names (e.g. variables) are emitted as bare symbols; literal arguments\n  ;; are evaluated and re-`repr`ed with `repr1`.\n  (setv eval-fn (fn [o]\n                  (if (instance? ast.Name o)\n                      o.id\n                      (repr1 (ast.literal-eval o) (dec level)))))\n  (setv func-args (lfor a (. ast-obj body args) (eval-fn a)))\n  (setv func-kwargs (lfor k (. 
ast-obj body keywords)\n                          (.format &quot;:{} {}&quot; k.arg (eval-fn k.value))))\n  (.format &quot;({})&quot; (.join &quot; &quot; (+ [func-name] func-args func-kwargs))))\n\n\n(setv test-ast (ast.parse &quot;range(x, y, blah=1, bloh=\\&quot;ok\\&quot;)&quot; :mode &quot;eval&quot;))\n(print (ast-funcall-to-hy test-ast (fn [x &amp;rest y] (-hy-repr x))))<\/code><\/pre>\n<pre id=\"org9d771a1\" class=\"hy\"><code>(range x y :blah 1 :bloh &quot;ok&quot;)<\/code><\/pre>\n<p><code>ast-funcall-to-hy<\/code> is an extremely narrow decompiler that only handles readable function calls (represented by <code>ast.Call<\/code> nodes), but, as part of a fallback sequence in a Hy <code>repr<\/code> implementation, it\u2019s still pretty useful.<\/p>\n<p>A function like <code>ast-funcall-to-hy<\/code> can be used in <code>repr<\/code> logic as follows:<\/p>\n<pre id=\"org80332ca\" class=\"hy\"><code>(defn hy-repr [x &amp;optional [level 1] [-repr (fn [x &amp;rest y] (-hy-repr x))]]\n  &quot;Use `builtins.repr` results to generate readable Hy `repr` strings for cases\nwe haven&#39;t covered explicitly.\n&quot;\n  (try\n    (setv s (builtins.repr x))\n    ;; Only parse `repr` output that looks readable (i.e. doesn&#39;t start with &quot;&lt;&quot;).\n    (when (not (.startswith s &quot;&lt;&quot;))\n      (setv repr-ast (ast.parse s :mode &quot;eval&quot;))\n      (setv s (ast-funcall-to-hy repr-ast -repr)))\n    s\n    ;; Fall back to a generic, non-readable representation.\n    (except [Exception]\n      (.format &quot;&lt;{} instance at {}&gt;&quot; x.__class__.__name__ (id x)))))<\/code><\/pre>\n<p>Now, for the example class, <code>TestClass<\/code>, we can demonstrate automatic conversion of its Python <code>__repr__<\/code> implementation.<\/p>\n<pre id=\"org21d5bb7\" class=\"hy\"><code>(setv test-obj (TestClass 1 {&quot;a&quot; 2 &quot;b&quot; 3} :kwarg1 1 :kwarg2 &quot;ok&quot;))\n(print (.format &quot;before: {}\\nafter: {}&quot;\n                (repr test-obj)\n                (hy-repr test-obj)))<\/code><\/pre>\n<pre id=\"orgb59b9c2\" class=\"text\"><code>before: TestClass(1, {&#39;a&#39;: 
2, &#39;b&#39;: 3}, kwarg1=1, kwarg2=&#39;ok&#39;)\nafter: (TestClass 1 {&quot;a&quot; 2  &quot;b&quot; 3} :kwarg1 1 :kwarg2 &quot;ok&quot;)<\/code><\/pre>\n<\/div>\n<\/section>\n<\/section>\n<section id=\"relational-programming\" class=\"level1\">\n<h1>Relational Programming<\/h1>\n<p>While considering all this, I kept thinking about how nice it would be to have a \u201cbijective\u201d compiler; in other words, the existing Hy compiler, which translates Hy to Python, <strong>and<\/strong> a Python-to-Hy (de)compiler. With a Python-to-Hy AST compiler, we could more broadly convert Python AST output\u2013like the kind in our example above\u2013to a <code>repr<\/code>\/readable string in Hy.<\/p>\n<p>The idea isn\u2019t too crazy, especially since one can easily work backward from a lot of the logic in the existing Hy compiler. There will be some edge cases that result in non-bijective translations (i.e.\u00a0some round-trip Hy\/Python translations might only be <strong>equivalent<\/strong> and not exactly <strong>equal<\/strong>), but this isn\u2019t necessarily a blocking issue. Decisions regarding \u201ccanonical\u201d or reduced forms of Hy\/Python AST might be necessary, especially if the resulting AST is intended to be human-readable.<\/p>\n<p>Perhaps what\u2019s more discouraging is the effort it would take to ensure that the compilation processes going both ways are\u2013and stay\u2013coherent as development continues. 
For instance, when changes are made to the standard compilation process (i.e.\u00a0Hy-to-Python), it\u2019s likely that changes and tests would also be needed for the other direction.<\/p>\n<p>This is where a paradigm like relational programming is particularly appealing: it provides a language for defining\u2013and a means for computing\u2013the maps<\/p>\n<p><span class=\"math display\">\\[\\begin{equation*}\n  \\text{Hy Syntax}\n  \\longleftrightarrow \\text{Python AST}\n  \\longleftrightarrow \\text{Python Syntax}\n  \\;\n\\end{equation*}\\]<\/span><\/p>\n<p>in a cohesive way.<\/p>\n<p>My relational programming DSL of choice, <a href=\"http:\/\/minikanren.org\">miniKanren<\/a>, already has an implementation in Hy: <a href=\"https:\/\/github.com\/algernon\/adderall\"><code>loghyc<\/code> (formerly known as <code>adderall<\/code>)<\/a>. We\u2019ve been using it to perform static code analysis and refactoring in the project <a href=\"https:\/\/github.com\/hylang\/hydiomatic\"><code>hydiomatic<\/code><\/a>, so there\u2019s also a precedent for parsing Hy syntax in a relational context.<\/p>\n<p>The missing\/next step would be to output Python AST (instead of more Hy forms, like <code>hydiomatic<\/code> produces, for example). In the following sections, we will construct a small relational Hy\/Python compiler as a proof of concept.<\/p>\n<section id=\"a-prototype-relational-compiler\" class=\"level2\">\n<h2>A Prototype Relational Compiler<\/h2>\n<p>Creating a bidirectional Hy\/Python AST compiler in miniKanren involves the construction of goals \u201crelating\u201d the two AST forms. 
For simplicity, we\u2019ll just consider function call expressions, like <code>func(args)<\/code> and <code>(func args)<\/code>.<\/p>\n<div class=\"remark\" data-markdown=\"\">\n<p>Also, since these kinds of relations are easier to specify using constraints and subtle unification adjustments, we\u2019ll use a prototype microKanren implementation in Hy that provides immediate access to those: <a href=\"https:\/\/github.com\/brandonwillard\/hypoKanren\"><code>hypoKanren<\/code><\/a>.<\/p>\n<p>Regardless, given the universality of miniKanren, the goals we construct should be directly translatable to other implementations of miniKanren (even in completely different host languages).<\/p>\n<p>The only obvious caveat to such translation is the availability of traditional <code>cons<\/code> semantics in the host language (i.e.\u00a0the standard Lisp behavior of <code>cons<\/code>, <code>car<\/code>, <code>cdr<\/code>, and improper lists\/<code>cons<\/code> pairs).<\/p>\n<\/div>\n<pre id=\"orgc53b310\" class=\"hy\"><code>(import ast)\n(import astor)\n(import types)\n(import [collections.abc [Callable]])\n\n(import hy.models)\n(import [hy.compiler [asty hy-eval hy-compile]])\n\n(import [hypoKanren.goals [*]])\n(import [hypoKanren.core [*]])\n\n\n(require [hy.contrib.walk [let]])\n(require [hypoKanren.goals [*]])\n(require [hypoKanren.core [*]])<\/code><\/pre>\n<p>First, let\u2019s examine the general structure of the Python AST output generated by the Hy compiler for the Hy function call given by <code>`(func x :y z)<\/code>.<\/p>\n<pre id=\"org7521929\" class=\"hy\"><code>(astor.dump (hy-compile `(func x :y z) &quot;__console__&quot;))<\/code><\/pre>\n<div class=\"sourceCode\" id=\"orgb137229\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgb137229-1\"><a href=\"#orgb137229-1\" aria-hidden=\"true\"><\/a>Module(<\/span>\n<span id=\"orgb137229-2\"><a href=\"#orgb137229-2\" aria-hidden=\"true\"><\/a>    body<span 
class=\"op\">=<\/span>[Expr(value<span class=\"op\">=<\/span>Call(func<span class=\"op\">=<\/span>Name(<span class=\"bu\">id<\/span><span class=\"op\">=<\/span><span class=\"st\">&#39;func&#39;<\/span>), args<span class=\"op\">=<\/span>[Name(<span class=\"bu\">id<\/span><span class=\"op\">=<\/span><span class=\"st\">&#39;x&#39;<\/span>)], keywords<span class=\"op\">=<\/span>[keyword(arg<span class=\"op\">=<\/span><span class=\"st\">&#39;y&#39;<\/span>, value<span class=\"op\">=<\/span>Name(<span class=\"bu\">id<\/span><span class=\"op\">=<\/span><span class=\"st\">&#39;z&#39;<\/span>))]))])<\/span><\/code><\/pre><\/div>\n<p>In what follows, we\u2019ll exclude the <code>ast.Module<\/code> and focus only on the <code>ast.Expr<\/code> and its children.<\/p>\n<section id=\"ast-object-unification\" class=\"level3\">\n<h3>AST Object Unification<\/h3>\n<p>To make existing Python AST objects amenable to the <a href=\"https:\/\/en.wikipedia.org\/wiki\/Unification_(computer_science)\">unification<\/a> used by miniKanren, we implement <code>unify<\/code> specializations for <code>ast.AST<\/code> types. Our implementation simply generates unevaluated Hy forms, or Hy AST, that\u2013when evaluated\u2013would (re)create the <code>ast.AST<\/code> objects.<\/p>\n<div class=\"REMARK\">\n<p>Alternatively, we could use only unevaluated Hy forms to represent Python AST. Providing unification for AST objects allows for more direct integration with existing Python code and the AST objects it is likely to produce.<\/p>\n<\/div>\n<p><code>hypoKanren<\/code> uses <a href=\"https:\/\/github.com\/mrocklin\/multipledispatch\"><code>multipledispatch<\/code><\/a>, so augmenting the unification process is easy. This is how we\u2019ll add support for AST objects.<\/p>\n<div class=\"REMARK\">\n<p>There\u2019s already a good pure-Python library for unification built upon <code>multipledispatch<\/code>: <a href=\"https:\/\/github.com\/mrocklin\/unification\"><code>unification<\/code><\/a>. 
At a later time, it might be worthwhile to simply add support for Hy objects and use that library instead.<\/p>\n<\/div>\n<pre id=\"org1468291\" class=\"hy\"><code>(import [multipledispatch [dispatch]])\n(import [hypoKanren.unify [*]])\n(import [hy.models [*]])\n(import [hy.contrib.walk [prewalk]])\n(import [itertools [chain]])\n\n\n;; Register `unify-post-walk` specializations that apply `trans-func` to any\n;; argument of type `disp-type` before unifying.\n(defmacro\/g! dispatch-unify-trans [disp-type trans-func &amp;optional [func &#39;unify]]\n  `(do\n     #@((dispatch ~disp-type object object)\n        (defn unify-post-walk [~g!u ~g!v ~g!s]\n          (~func (~trans-func ~g!u) ~g!v ~g!s)))\n     #@((dispatch object ~disp-type object)\n        (defn unify-post-walk [~g!u ~g!v ~g!s]\n          (~func ~g!u (~trans-func ~g!v) ~g!s)))\n     #@((dispatch ~disp-type ~disp-type object)\n        (defn unify-post-walk [~g!u ~g!v ~g!s]\n          (~func (~trans-func ~g!u) (~trans-func ~g!v) ~g!s)))))\n\n;; Replace every (nested) `ast.AST` object in `x` with a Hy form, such as\n;; `(ast.Name :id &quot;a&quot; :ctx ...)`, that would construct it when evaluated.\n(defn py-ast-to-expr [x]\n  (defn -py-ast-to-expr [u]\n    (setv ast-expr\n          `(~(HySymbol (+ &quot;ast.&quot; (name (type u))))\n            ~@(chain.from-iterable\n                (lfor f u.-fields\n                      :if (hasattr u f)\n                      [(HyKeyword f) (getattr u f)]))))\n    ast-expr)\n  (prewalk (fn [y] (if (instance? 
ast.AST y)\n                       (-py-ast-to-expr y)\n                       y))\n           x))\n\n;; Python AST expansion pre-unification\n(dispatch-unify-trans ast.AST (fn [x] (py-ast-to-expr x)))<\/code><\/pre>\n<div class=\"example\" data-markdown=\"\">\n<pre id=\"org70c1279\" class=\"hy\"><code>;; One is an `ast.AST` object, the other an unevaluated `ast.AST`\n;; object-generating form.\n(setv unify-exa-1 (unify (ast.Expr :value [])\n                         `(ast.Expr :value ~(var 0))\n                         {}))\n\n;; Both are `ast.AST` objects\n(setv unify-exa-2 (unify (ast.Expr :value [])\n                         (ast.Expr :value (var 0))\n                         {}))\n\n(= (.get unify-exa-1 (var 0))\n   (.get unify-exa-2 (var 0))\n   [])<\/code><\/pre>\n<div class=\"sourceCode\" id=\"orgc9b3e95\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgc9b3e95-1\"><a href=\"#orgc9b3e95-1\" aria-hidden=\"true\"><\/a><span class=\"va\">True<\/span><\/span><\/code><\/pre><\/div>\n<p>Listing <a href=\"#org70c1279\">16<\/a> illustrates unification of two <code>ast.AST<\/code> forms. The <code>(var 0)<\/code> objects are \u201clogic variables\u201d taking the value of sub-expressions that cause the two <code>unify<\/code> arguments to, well, unify. 
The third argument to <code>unify<\/code> is simply a <code>dict<\/code> that stores the logic variable\/sub-expression mappings.<\/p>\n<p>In other words, logic variables are like unknowns that <code>unify(u, v, s)<\/code> will \u201csolve\u201d in order to make <code>u<\/code> and <code>v<\/code> equal.<\/p>\n<\/div>\n<div class=\"example\" data-markdown=\"\">\n<pre id=\"orgd9478e0\" class=\"hy\"><code>(unify (cons &#39;ast.Expr (var 0))\n       (ast.Expr :value [(ast.Name :id &quot;a&quot;)])\n       {})<\/code><\/pre>\n<div class=\"sourceCode\" id=\"org3302cb2\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org3302cb2-1\"><a href=\"#org3302cb2-1\" aria-hidden=\"true\"><\/a>{(LVar <span class=\"dv\">0<\/span>): HyExpression([<\/span>\n<span id=\"org3302cb2-2\"><a href=\"#org3302cb2-2\" aria-hidden=\"true\"><\/a>  HyKeyword(<span class=\"st\">&#39;value&#39;<\/span>),<\/span>\n<span id=\"org3302cb2-3\"><a href=\"#org3302cb2-3\" aria-hidden=\"true\"><\/a>  [HyExpression([<\/span>\n<span id=\"org3302cb2-4\"><a href=\"#org3302cb2-4\" aria-hidden=\"true\"><\/a>    HySymbol(<span class=\"st\">&#39;ast.Name&#39;<\/span>),<\/span>\n<span id=\"org3302cb2-5\"><a href=\"#org3302cb2-5\" aria-hidden=\"true\"><\/a>    HyKeyword(<span class=\"st\">&#39;id&#39;<\/span>),<\/span>\n<span id=\"org3302cb2-6\"><a href=\"#org3302cb2-6\" aria-hidden=\"true\"><\/a>    <span class=\"st\">&#39;a&#39;<\/span>])]])}<\/span><\/code><\/pre><\/div>\n<p>Listing <a href=\"#orgd9478e0\">18<\/a> is a more interesting example that demonstrates partial\/improper list unification. Since <code>ast.AST<\/code> objects are expanded into equal object-instantiating Hy AST forms, <code>(cons 'ast.Expr (var 0))<\/code> is ultimately unified with a <code>HyExpression<\/code> (a subclass of <code>list<\/code>). 
Under the <code>cons<\/code> abstraction, <code>(var 0)<\/code> can be anything that\u2013when <code>cons<\/code>ed with the symbol <code>ast.Expr<\/code>\u2013will produce the expression <code>(ast.Expr :value [(ast.Name :id \"a\")])<\/code>. The result is the partial <code>HyExpression<\/code> comprising the arguments to the <code>ast.Expr<\/code> constructor\u2013in other words, the <code>cdr<\/code> of the <code>ast.AST<\/code> form.<\/p>\n<\/div>\n<p>We will also need to unify some limited Hy AST forms; specifically, <code>HySymbol<\/code>s. We want to extract only the name part of a Hy symbol and relate it to Python <code>ast.Name<\/code>s via one of the latter\u2019s constructor arguments.<\/p>\n<p>As with Python AST nodes, we will expand\/lift\/abstract <code>HySymbol<\/code>s to Hy expressions that\u2013when <code>eval<\/code>\u2019ed\u2013would construct them. We can only do this in very limited cases; otherwise, we could end up producing ever-expanding forms.<\/p>\n<pre id=\"org7b1a1a0\" class=\"hy\"><code>;; Hy AST expansion pre-unification\n(defn unify-hysymbol [u v s]\n  (cond\n    ;; `v` is a `(HySymbol ...)` form, so expand the symbol `u` into one, too.\n    [(= (first v) &#39;HySymbol)\n     (unify `(HySymbol ~(name u)) v s)]\n    [True\n     (unify u v s)]))\n\n#@((dispatch HySymbol HyExpression object)\n   (defn unify-post-walk [u v s]\n     (unify-hysymbol u v s)))\n\n#@((dispatch HyExpression HySymbol object)\n   (defn unify-post-walk [u v s]\n     (unify-hysymbol v u s)))<\/code><\/pre>\n<div class=\"example\" data-markdown=\"\">\n<pre id=\"orgbc4dbe9\" class=\"hy\"><code>(unify &#39;a `(HySymbol ~(var 0)) {})<\/code><\/pre>\n<div class=\"sourceCode\" id=\"orga00b5a1\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orga00b5a1-1\"><a href=\"#orga00b5a1-1\" aria-hidden=\"true\"><\/a>{(LVar <span class=\"dv\">0<\/span>): <span class=\"st\">&#39;a&#39;<\/span>}<\/span><\/code><\/pre><\/div>\n<p>Listing <a href=\"#org7b1a1a0\">20<\/a> demonstrates the expansion and 
unification of Hy AST symbols.<\/p>\n<\/div>\n<\/section>\n<section id=\"call-expression-goals\" class=\"level3\">\n<h3>Call-expression Goals<\/h3>\n<p>Next, we create the miniKanren goals that encapsulate the relationships between simple Hy and Python AST forms. In particular, we\u2019ll limit ourselves to variable references and function calls.<\/p>\n<pre id=\"orge442cd6\" class=\"hy\"><code>(defn listo [l]\n  &quot;A goal stating that `l` is a list.&quot;\n  (conde\n    ;; Base case: `l` is the empty list.\n    [(== l []) s#]\n    ;; Recursive case: `l` is a `cons` pair whose `cdr` is also a list.\n    [(fresh [lcar lcdr]\n            (== l (cons lcar lcdr))\n            (listo lcdr))]))<\/code><\/pre>\n<p>The first AST relation is a simple one between <code>HySymbol<\/code>s and <code>ast.Name<\/code>s. This is where the <code>HySymbol<\/code> unification implemented above is used.<\/p>\n<pre id=\"org9fe6aea\" class=\"hy\"><code>(defn hy-py-symbolo [hy-ast py-ast]\n  &quot;A goal relating Hy and Python AST symbol\/name objects (e.g. variable and\n function references).&quot;\n  (fresh [symbol-name]\n         ;; Both forms share the same symbol name.\n         (== hy-ast `(HySymbol ~symbol-name))\n         (== py-ast `(ast.Name :id ~symbol-name\n                               :ctx (ast.Load)))))<\/code><\/pre>\n<p>Some Python <code>ast.AST<\/code> types have fields consisting of lists containing other <code>ast.AST<\/code> objects (e.g.\u00a0the <code>ast.Call<\/code> expressions below). 
We need a goal that enforces a relation between the Hy and Python AST forms of each element in such lists.<\/p>\n<pre id=\"org988ece9\" class=\"hy\"><code>(defn lapplyo [func l-in l-out]\n  &quot;A goal that applies the goal `func` between all elements in lists `l-in` and\n `l-out`.&quot;\n  (conj+\n    (listo l-in)\n    (conde\n      [(fresh [lcar lcdr lout-car lout-cdr]\n              (== l-in (cons lcar lcdr))\n              (func lcar lout-car)\n              (lapplyo func lcdr lout-cdr)\n              (== l-out (cons lout-car lout-cdr)))]\n      [(== l-in [])\n       (== l-out l-in)])))<\/code><\/pre>\n<p>Finally, we create a goal for the AST of call expressions like <code>func(x, y, z)<\/code> and <code>(func x y z)<\/code>.<\/p>\n<pre id=\"org10bec84\" class=\"hy\"><code>(defn hy-py-callo [hy-ast py-ast]\n  &quot;A goal relating call expressions in Python and Hy AST.&quot;\n  (fresh [hy-op hy-args py-op py-args]\n         ;; Hy AST form\n         (== (cons hy-op hy-args) hy-ast)\n         ;; Py AST form\n         (== py-ast `(ast.Expr :value\n                               (ast.Call :func\n                                         ~py-op\n                                         :args\n                                         ~py-args\n                                         :keywords\n                                         [])))\n         ;; These two must be related symbols\n         (hy-py-symbolo hy-op py-op)\n         ;; The arguments are related lists containing more of each AST type.\n         (lapplyo hy-py-asto hy-args py-args)))\n\n(defn hy-py-asto [hy-ast py-ast]\n  &quot;A goal for a &#39;branching&#39; relation between multiple types of forms and their\n corresponding Python AST.&quot;\n  (conde\n    [(hy-py-symbolo hy-ast py-ast)]\n    [(hy-py-callo hy-ast py-ast)]))<\/code><\/pre>\n<div class=\"example\" data-markdown=\"\">\n<p>To demonstrate our [extremely] minimal relational compiler, we create a Hy function call expression and its 
corresponding Python AST.<\/p>\n<pre id=\"org7845670\" class=\"hy\"><code>(setv hy-ast-exa `(print x y z))\n(setv py-ast-exa (. (hy-compile hy-ast-exa &quot;__console__&quot;) body [0]))\n(.format &quot;hy_ast_exa = {}\\npy_ast_exa = {}&quot;\n         hy-ast-exa\n         (astor.dump py-ast-exa))<\/code><\/pre>\n<div class=\"sourceCode\" id=\"org09b193c\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org09b193c-1\"><a href=\"#org09b193c-1\" aria-hidden=\"true\"><\/a>hy_ast_exa <span class=\"op\">=<\/span> HyExpression([<\/span>\n<span id=\"org09b193c-2\"><a href=\"#org09b193c-2\" aria-hidden=\"true\"><\/a>  HySymbol(<span class=\"st\">&#39;print&#39;<\/span>),<\/span>\n<span id=\"org09b193c-3\"><a href=\"#org09b193c-3\" aria-hidden=\"true\"><\/a>  HySymbol(<span class=\"st\">&#39;x&#39;<\/span>),<\/span>\n<span id=\"org09b193c-4\"><a href=\"#org09b193c-4\" aria-hidden=\"true\"><\/a>  HySymbol(<span class=\"st\">&#39;y&#39;<\/span>),<\/span>\n<span id=\"org09b193c-5\"><a href=\"#org09b193c-5\" aria-hidden=\"true\"><\/a>  HySymbol(<span class=\"st\">&#39;z&#39;<\/span>)])<\/span>\n<span id=\"org09b193c-6\"><a href=\"#org09b193c-6\" aria-hidden=\"true\"><\/a>py_ast_exa <span class=\"op\">=<\/span> Expr(value<span class=\"op\">=<\/span>Call(func<span class=\"op\">=<\/span>Name(<span class=\"bu\">id<\/span><span class=\"op\">=<\/span><span class=\"st\">&#39;print&#39;<\/span>), args<span class=\"op\">=<\/span>[Name(<span class=\"bu\">id<\/span><span class=\"op\">=<\/span><span class=\"st\">&#39;x&#39;<\/span>), Name(<span class=\"bu\">id<\/span><span class=\"op\">=<\/span><span class=\"st\">&#39;y&#39;<\/span>), Name(<span class=\"bu\">id<\/span><span class=\"op\">=<\/span><span class=\"st\">&#39;z&#39;<\/span>)], keywords<span class=\"op\">=<\/span>[]))<\/span><\/code><\/pre><\/div>\n<p>We first run the Hy-to-Python direction by providing <code>hy-py-asto<\/code> with the <code>hy-ast-exa<\/code> value above and a logic variable (i.e.\u00a0an 
\u201cunknown\u201d) for the Python AST term.<\/p>\n<pre id=\"org5336211\" class=\"hy\"><code>(setv rel-res (run 1 [py-ast] (hy-py-asto hy-ast-exa py-ast)))\n(setv ast-res (get rel-res 0 0))\nast-res<\/code><\/pre>\n<div class=\"sourceCode\" id=\"org63b2bde\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"org63b2bde-1\"><a href=\"#org63b2bde-1\" aria-hidden=\"true\"><\/a>HyExpression([<\/span>\n<span id=\"org63b2bde-2\"><a href=\"#org63b2bde-2\" aria-hidden=\"true\"><\/a>  HySymbol(<span class=\"st\">&#39;ast.Expr&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-3\"><a href=\"#org63b2bde-3\" aria-hidden=\"true\"><\/a>  HyKeyword(<span class=\"st\">&#39;value&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-4\"><a href=\"#org63b2bde-4\" aria-hidden=\"true\"><\/a>  HyExpression([<\/span>\n<span id=\"org63b2bde-5\"><a href=\"#org63b2bde-5\" aria-hidden=\"true\"><\/a>    HySymbol(<span class=\"st\">&#39;ast.Call&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-6\"><a href=\"#org63b2bde-6\" aria-hidden=\"true\"><\/a>    HyKeyword(<span class=\"st\">&#39;func&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-7\"><a href=\"#org63b2bde-7\" aria-hidden=\"true\"><\/a>    HyExpression([<\/span>\n<span id=\"org63b2bde-8\"><a href=\"#org63b2bde-8\" aria-hidden=\"true\"><\/a>      HySymbol(<span class=\"st\">&#39;ast.Name&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-9\"><a href=\"#org63b2bde-9\" aria-hidden=\"true\"><\/a>      HyKeyword(<span class=\"st\">&#39;id&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-10\"><a href=\"#org63b2bde-10\" aria-hidden=\"true\"><\/a>      <span class=\"st\">&#39;print&#39;<\/span>,<\/span>\n<span id=\"org63b2bde-11\"><a href=\"#org63b2bde-11\" aria-hidden=\"true\"><\/a>      HyKeyword(<span class=\"st\">&#39;ctx&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-12\"><a href=\"#org63b2bde-12\" aria-hidden=\"true\"><\/a>      HyExpression([<\/span>\n<span id=\"org63b2bde-13\"><a href=\"#org63b2bde-13\" aria-hidden=\"true\"><\/a>      
  HySymbol(<span class=\"st\">&#39;ast.Load&#39;<\/span>)])]),<\/span>\n<span id=\"org63b2bde-14\"><a href=\"#org63b2bde-14\" aria-hidden=\"true\"><\/a>    HyKeyword(<span class=\"st\">&#39;args&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-15\"><a href=\"#org63b2bde-15\" aria-hidden=\"true\"><\/a>    HyExpression([<\/span>\n<span id=\"org63b2bde-16\"><a href=\"#org63b2bde-16\" aria-hidden=\"true\"><\/a>      HyExpression([<\/span>\n<span id=\"org63b2bde-17\"><a href=\"#org63b2bde-17\" aria-hidden=\"true\"><\/a>        HySymbol(<span class=\"st\">&#39;ast.Name&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-18\"><a href=\"#org63b2bde-18\" aria-hidden=\"true\"><\/a>        HyKeyword(<span class=\"st\">&#39;id&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-19\"><a href=\"#org63b2bde-19\" aria-hidden=\"true\"><\/a>        <span class=\"st\">&#39;x&#39;<\/span>,<\/span>\n<span id=\"org63b2bde-20\"><a href=\"#org63b2bde-20\" aria-hidden=\"true\"><\/a>        HyKeyword(<span class=\"st\">&#39;ctx&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-21\"><a href=\"#org63b2bde-21\" aria-hidden=\"true\"><\/a>        HyExpression([<\/span>\n<span id=\"org63b2bde-22\"><a href=\"#org63b2bde-22\" aria-hidden=\"true\"><\/a>          HySymbol(<span class=\"st\">&#39;ast.Load&#39;<\/span>)])]),<\/span>\n<span id=\"org63b2bde-23\"><a href=\"#org63b2bde-23\" aria-hidden=\"true\"><\/a>      HyExpression([<\/span>\n<span id=\"org63b2bde-24\"><a href=\"#org63b2bde-24\" aria-hidden=\"true\"><\/a>        HySymbol(<span class=\"st\">&#39;ast.Name&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-25\"><a href=\"#org63b2bde-25\" aria-hidden=\"true\"><\/a>        HyKeyword(<span class=\"st\">&#39;id&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-26\"><a href=\"#org63b2bde-26\" aria-hidden=\"true\"><\/a>        <span class=\"st\">&#39;y&#39;<\/span>,<\/span>\n<span id=\"org63b2bde-27\"><a href=\"#org63b2bde-27\" aria-hidden=\"true\"><\/a>        HyKeyword(<span 
class=\"st\">&#39;ctx&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-28\"><a href=\"#org63b2bde-28\" aria-hidden=\"true\"><\/a>        HyExpression([<\/span>\n<span id=\"org63b2bde-29\"><a href=\"#org63b2bde-29\" aria-hidden=\"true\"><\/a>          HySymbol(<span class=\"st\">&#39;ast.Load&#39;<\/span>)])]),<\/span>\n<span id=\"org63b2bde-30\"><a href=\"#org63b2bde-30\" aria-hidden=\"true\"><\/a>      HyExpression([<\/span>\n<span id=\"org63b2bde-31\"><a href=\"#org63b2bde-31\" aria-hidden=\"true\"><\/a>        HySymbol(<span class=\"st\">&#39;ast.Name&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-32\"><a href=\"#org63b2bde-32\" aria-hidden=\"true\"><\/a>        HyKeyword(<span class=\"st\">&#39;id&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-33\"><a href=\"#org63b2bde-33\" aria-hidden=\"true\"><\/a>        <span class=\"st\">&#39;z&#39;<\/span>,<\/span>\n<span id=\"org63b2bde-34\"><a href=\"#org63b2bde-34\" aria-hidden=\"true\"><\/a>        HyKeyword(<span class=\"st\">&#39;ctx&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-35\"><a href=\"#org63b2bde-35\" aria-hidden=\"true\"><\/a>        HyExpression([<\/span>\n<span id=\"org63b2bde-36\"><a href=\"#org63b2bde-36\" aria-hidden=\"true\"><\/a>          HySymbol(<span class=\"st\">&#39;ast.Load&#39;<\/span>)])])]),<\/span>\n<span id=\"org63b2bde-37\"><a href=\"#org63b2bde-37\" aria-hidden=\"true\"><\/a>    HyKeyword(<span class=\"st\">&#39;keywords&#39;<\/span>),<\/span>\n<span id=\"org63b2bde-38\"><a href=\"#org63b2bde-38\" aria-hidden=\"true\"><\/a>    HyList()])])<\/span><\/code><\/pre><\/div>\n<p>And, now, the other direction (i.e.\u00a0known Python AST, unknown Hy AST).<\/p>\n<pre id=\"org7148809\" class=\"hy\"><code>(setv rel-res (run 1 [hy-ast] (hy-py-asto hy-ast py-ast-exa)))\n(setv ast-res (get rel-res 0 0))\nast-res<\/code><\/pre>\n<div class=\"sourceCode\" id=\"orgbf40b7d\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"orgbf40b7d-1\"><a href=\"#orgbf40b7d-1\" 
aria-hidden=\"true\"><\/a>[HyExpression([<\/span>\n<span id=\"orgbf40b7d-2\"><a href=\"#orgbf40b7d-2\" aria-hidden=\"true\"><\/a>  HySymbol(<span class=\"st\">&#39;HySymbol&#39;<\/span>),<\/span>\n<span id=\"orgbf40b7d-3\"><a href=\"#orgbf40b7d-3\" aria-hidden=\"true\"><\/a>  <span class=\"st\">&#39;print&#39;<\/span>]), HyExpression([<\/span>\n<span id=\"orgbf40b7d-4\"><a href=\"#orgbf40b7d-4\" aria-hidden=\"true\"><\/a>  HySymbol(<span class=\"st\">&#39;HySymbol&#39;<\/span>),<\/span>\n<span id=\"orgbf40b7d-5\"><a href=\"#orgbf40b7d-5\" aria-hidden=\"true\"><\/a>  <span class=\"st\">&#39;x&#39;<\/span>]), HyExpression([<\/span>\n<span id=\"orgbf40b7d-6\"><a href=\"#orgbf40b7d-6\" aria-hidden=\"true\"><\/a>  HySymbol(<span class=\"st\">&#39;HySymbol&#39;<\/span>),<\/span>\n<span id=\"orgbf40b7d-7\"><a href=\"#orgbf40b7d-7\" aria-hidden=\"true\"><\/a>  <span class=\"st\">&#39;y&#39;<\/span>]), HyExpression([<\/span>\n<span id=\"orgbf40b7d-8\"><a href=\"#orgbf40b7d-8\" aria-hidden=\"true\"><\/a>  HySymbol(<span class=\"st\">&#39;HySymbol&#39;<\/span>),<\/span>\n<span id=\"orgbf40b7d-9\"><a href=\"#orgbf40b7d-9\" aria-hidden=\"true\"><\/a>  <span class=\"st\">&#39;z&#39;<\/span>])]<\/span><\/code><\/pre><\/div>\n<\/div>\n<\/section>\n<\/section>\n<\/section>\n<\/body>\n<\/html>\n","category":[{"@attributes":{"term":"articles"}},{"@attributes":{"term":"hy"}},{"@attributes":{"term":"relational programming"}},{"@attributes":{"term":"python"}}]},{"title":"Data Science at Citybase","link":{"@attributes":{"href":"https:\/\/brandonwillard.github.io\/data-science-at-citybase.html","rel":"alternate"}},"published":"2018-12-18T00:00:00-06:00","updated":"2019-10-22T00:00:00-05:00","author":{"name":"Brandon T. 
Willard"},"id":"tag:brandonwillard.github.io,2018-12-18:\/data-science-at-citybase.html","summary":{"@attributes":{"type":"html"}},"content":"<!DOCTYPE html PUBLIC \"-\/\/W3C\/\/DTD XHTML 1.0 Transitional\/\/EN\" \"http:\/\/www.w3.org\/TR\/xhtml1\/DTD\/xhtml1-transitional.dtd\">\n<html xmlns=\"http:\/\/www.w3.org\/1999\/xhtml\">\n<head>\n  <meta http-equiv=\"Content-Type\" content=\"text\/html; charset=utf-8\" \/>\n  <meta http-equiv=\"Content-Style-Type\" content=\"text\/css\" \/>\n  <meta name=\"generator\" content=\"pandoc\" \/>\n  <meta name=\"author\" content=\"Brandon T. Willard\" \/>\n  <title>Data Science at Citybase<\/title>\n  <style type=\"text\/css\">code{white-space: pre;}<\/style>\n  <!--   <script src=\"https:\/\/cdnjs.cloudflare.com\/ajax\/libs\/mathjax\/2.7.0\/MathJax.js?config=TeX-AMS_HTML\" id=\"MathJax-script\"><\/script>\n         <script>\n          MathJax.Hub.Config({\n              tex2jax: {\n                  processEnvironments: true,\n                  processRefs: false\n              },\n              TeX: {\n                  equationNumbers: { autoNumber: \"AMS\" },\n                  extensions: [\"AMSmath.js\",\"AMSsymbols.js\",\"noErrors.js\",\"noUndefined.js\"]\n              }\n          });\n         <\/script> -->\n<\/head>\n<body>\n<!--  -->\n<!-- <div id=\"header\"> -->\n<!-- <h1 class=\"title\">Data Science at Citybase<\/h1> -->\n<!--  -->\n<!--  -->\n<!-- <h2 class=\"author\">Brandon T. 
Willard<\/h2> -->\n<!--  -->\n<!--  -->\n<!-- <h3 class=\"date\">2018\u201312\u201318<\/h3> -->\n<!--  -->\n<!-- <\/div> -->\n<!--  -->\n<p>I recently wrote about data science at CityBase on the CityBase blog: <a id=\"c26e94c1b0b80ac3545371089d4f9936\" href=\"#WillardProgrammingIntelligentCity2018a\">Programming an Intelligent City: The Role of Data Science<\/a>.<\/p>\n<section id=\"bibliography\" class=\"level1\">\n<h1>Bibliography<\/h1>\n<p><a id=\"WillardProgrammingIntelligentCity2018a\"><\/a>[WillardProgrammingIntelligentCity2018a] Willard, Programming an Intelligent City: The Role of Data Science, <i>CityBase<\/i>, (2018). <a href=\"https:\/\/thecitybase.com\/programming-an-intelligent-city-the-role-of-data-science\/\">link<\/a>. <a href=\"#c26e94c1b0b80ac3545371089d4f9936\">\u21a9\ufe0e<\/a><\/p>\n<\/section>\n<\/body>\n<\/html>\n","category":[{"@attributes":{"term":"articles"}},{"@attributes":{"term":"statistics"}},{"@attributes":{"term":"symbolic computation"}},{"@attributes":{"term":"citybase"}}]},{"title":"Symbolic Math in PyMC3","link":{"@attributes":{"href":"https:\/\/brandonwillard.github.io\/symbolic-math-in-pymc3.html","rel":"alternate"}},"published":"2018-12-18T00:00:00-06:00","updated":"2018-12-26T00:00:00-06:00","author":{"name":"Brandon T. Willard"},"id":"tag:brandonwillard.github.io,2018-12-18:\/symbolic-math-in-pymc3.html","summary":{"@attributes":{"type":"html"}},"content":"<!DOCTYPE html PUBLIC \"-\/\/W3C\/\/DTD XHTML 1.0 Transitional\/\/EN\" \"http:\/\/www.w3.org\/TR\/xhtml1\/DTD\/xhtml1-transitional.dtd\">\n<html xmlns=\"http:\/\/www.w3.org\/1999\/xhtml\">\n<head>\n  <meta http-equiv=\"Content-Type\" content=\"text\/html; charset=utf-8\" \/>\n  <meta http-equiv=\"Content-Style-Type\" content=\"text\/css\" \/>\n  <meta name=\"generator\" content=\"pandoc\" \/>\n  <meta name=\"author\" content=\"Brandon T. 
Willard\" \/>\n  <title>Symbolic Math in PyMC3<\/title>\n  <style type=\"text\/css\">code{white-space: pre;}<\/style>\n  <style type=\"text\/css\">\npre > code.sourceCode { white-space: pre; position: relative; }\npre > code.sourceCode > span { display: inline-block; line-height: 1.25; }\npre > code.sourceCode > span:empty { height: 1.2em; }\ncode.sourceCode > span { color: inherit; text-decoration: inherit; }\ndiv.sourceCode { margin: 1em 0; }\npre.sourceCode { margin: 0; }\n@media screen {\ndiv.sourceCode { overflow: auto; }\n}\n@media print {\npre > code.sourceCode { white-space: pre-wrap; }\npre > code.sourceCode > span { text-indent: -5em; padding-left: 5em; }\n}\npre.numberSource code\n  { counter-reset: source-line 0; }\npre.numberSource code > span\n  { position: relative; left: -4em; counter-increment: source-line; }\npre.numberSource code > span > a:first-child::before\n  { content: counter(source-line);\n    position: relative; left: -1em; text-align: right; vertical-align: baseline;\n    border: none; display: inline-block;\n    -webkit-touch-callout: none; -webkit-user-select: none;\n    -khtml-user-select: none; -moz-user-select: none;\n    -ms-user-select: none; user-select: none;\n    padding: 0 4px; width: 4em;\n    color: #aaaaaa;\n  }\npre.numberSource { margin-left: 3em; border-left: 1px solid #aaaaaa;  padding-left: 4px; }\ndiv.sourceCode\n  {   }\n@media screen {\npre > code.sourceCode > span > a:first-child::before { text-decoration: underline; }\n}\ncode span.al { color: #ff0000; font-weight: bold; } \/* Alert *\/\ncode span.an { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Annotation *\/\ncode span.at { color: #7d9029; } \/* Attribute *\/\ncode span.bn { color: #40a070; } \/* BaseN *\/\ncode span.bu { } \/* BuiltIn *\/\ncode span.cf { color: #007020; font-weight: bold; } \/* ControlFlow *\/\ncode span.ch { color: #4070a0; } \/* Char *\/\ncode span.cn { color: #880000; } \/* Constant *\/\ncode span.co { color: #60a0b0; 
font-style: italic; } \/* Comment *\/\ncode span.cv { color: #60a0b0; font-weight: bold; font-style: italic; } \/* CommentVar *\/\ncode span.do { color: #ba2121; font-style: italic; } \/* Documentation *\/\ncode span.dt { color: #902000; } \/* DataType *\/\ncode span.dv { color: #40a070; } \/* DecVal *\/\ncode span.er { color: #ff0000; font-weight: bold; } \/* Error *\/\ncode span.ex { } \/* Extension *\/\ncode span.fl { color: #40a070; } \/* Float *\/\ncode span.fu { color: #06287e; } \/* Function *\/\ncode span.im { } \/* Import *\/\ncode span.in { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Information *\/\ncode span.kw { color: #007020; font-weight: bold; } \/* Keyword *\/\ncode span.op { color: #666666; } \/* Operator *\/\ncode span.ot { color: #007020; } \/* Other *\/\ncode span.pp { color: #bc7a00; } \/* Preprocessor *\/\ncode span.sc { color: #4070a0; } \/* SpecialChar *\/\ncode span.ss { color: #bb6688; } \/* SpecialString *\/\ncode span.st { color: #4070a0; } \/* String *\/\ncode span.va { color: #19177c; } \/* Variable *\/\ncode span.vs { color: #4070a0; } \/* VerbatimString *\/\ncode span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Warning *\/\n  <\/style>\n  <!--        <script src=\"https:\/\/cdn.jsdelivr.net\/npm\/mathjax@3\/es5\/tex-mml-chtml.js\" type=\"text\/javascript\"><\/script> -->\n  <script src=\"https:\/\/cdnjs.cloudflare.com\/ajax\/libs\/mathjax\/2.7.0\/MathJax.js?config=TeX-AMS_HTML\" id=\"MathJax-script\"><\/script>\n  <script>\n   MathJax.Hub.Config({\n       tex2jax: {\n           processEnvironments: true,\n           processRefs: false\n       },\n       TeX: {\n           equationNumbers: { autoNumber: \"AMS\" },\n           extensions: [\"AMSmath.js\",\"AMSsymbols.js\",\"noErrors.js\",\"noUndefined.js\"]\n       }\n   });\n  <\/script>\n<\/head>\n<body>\n<!--  -->\n<!-- <div id=\"header\"> -->\n<!-- <h1 class=\"title\">Symbolic Math in PyMC3<\/h1> -->\n<!--  -->\n<!--  -->\n<!-- <h2 
class=\"author\">Brandon T. Willard<\/h2> -->\n<!--  -->\n<!--  -->\n<!-- <h3 class=\"date\">2018\u201312\u201318<\/h3> -->\n<!--  -->\n<!-- <\/div> -->\n<!--  -->\n<section id=\"introduction\" class=\"level1\">\n<h1>Introduction<\/h1>\n<p>In <sup id=\"4407b21e48ab9ff17c017e8d62684725\"><a href=\"#WillardRoleSymbolicComputation2017\">A Role for Symbolic Computation in the General Estimation of Statistical Models<\/a><\/sup>, I described how symbolic computation is used by Bayesian modeling software like PyMC3 and some directions it could take. It closed with an example of automatic normal-normal convolution using PyMC3 objects and Theano\u2019s optimization framework. This article elaborates on the foundations of symbolic mathematics in Theano and PyMC3: their current state, some challenges, and potential improvements.<\/p>\n<p>Let\u2019s start by reconsidering the simple normal-normal convolution model. Mathematically, we can represent the model as follows:<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  X \\sim N(0, 1), \\quad\n  Y \\sim N\\left(1, \\frac12\\right), \\quad\n  Z = X + Y \\sim N\\left(1, \\frac32\\right)\n  \\label{eq:norm_conv_model}\n\\end{equation}\\]<\/span><\/p>\n<p>Using PyMC3, the model for Equation <span class=\"math inline\">\\(\\eqref{eq:norm_conv_model}\\)<\/span> is constructed as follows:<\/p>\n<div class=\"sourceCode\" id=\"cb1\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb1-1\"><a href=\"#cb1-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> sys<\/span>\n<span id=\"cb1-2\"><a href=\"#cb1-2\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> os<\/span>\n<span id=\"cb1-3\"><a href=\"#cb1-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-4\"><a href=\"#cb1-4\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> pprint <span class=\"im\">import<\/span> pprint<\/span>\n<span id=\"cb1-5\"><a href=\"#cb1-5\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-6\"><a href=\"#cb1-6\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> numpy <span class=\"im\">as<\/span> np<\/span>\n<span id=\"cb1-7\"><a href=\"#cb1-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-8\"><a href=\"#cb1-8\" aria-hidden=\"true\"><\/a>os.environ[<span class=\"st\">&#39;MKL_THREADING_LAYER&#39;<\/span>] <span class=\"op\">=<\/span> <span class=\"st\">&#39;GNU&#39;<\/span><\/span>\n<span id=\"cb1-9\"><a href=\"#cb1-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-10\"><a href=\"#cb1-10\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> theano<\/span>\n<span id=\"cb1-11\"><a href=\"#cb1-11\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> theano.tensor <span class=\"im\">as<\/span> tt<\/span>\n<span id=\"cb1-12\"><a href=\"#cb1-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-13\"><a href=\"#cb1-13\" aria-hidden=\"true\"><\/a>theano.config.mode <span class=\"op\">=<\/span> <span class=\"st\">&#39;FAST_COMPILE&#39;<\/span><\/span>\n<span id=\"cb1-14\"><a href=\"#cb1-14\" aria-hidden=\"true\"><\/a>theano.config.exception_verbosity <span class=\"op\">=<\/span> <span class=\"st\">&#39;high&#39;<\/span><\/span>\n<span id=\"cb1-15\"><a href=\"#cb1-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-16\"><a href=\"#cb1-16\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> pymc3 <span class=\"im\">as<\/span> pm<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"cb2\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb2-1\"><a href=\"#cb2-1\" aria-hidden=\"true\"><\/a>mu_X <span class=\"op\">=<\/span> tt.vector(<span class=\"st\">&#39;mu_X&#39;<\/span>)<\/span>\n<span id=\"cb2-2\"><a href=\"#cb2-2\" aria-hidden=\"true\"><\/a>mu_X.tag.test_value <span class=\"op\">=<\/span> np.array([<span class=\"fl\">0.<\/span>], dtype<span class=\"op\">=<\/span>tt.config.floatX)<\/span>\n<span id=\"cb2-3\"><a 
href=\"#cb2-3\" aria-hidden=\"true\"><\/a>sd_X <span class=\"op\">=<\/span> tt.vector(<span class=\"st\">&#39;sd_X&#39;<\/span>)<\/span>\n<span id=\"cb2-4\"><a href=\"#cb2-4\" aria-hidden=\"true\"><\/a>sd_X.tag.test_value <span class=\"op\">=<\/span> np.array([<span class=\"fl\">1.<\/span>], dtype<span class=\"op\">=<\/span>tt.config.floatX)<\/span>\n<span id=\"cb2-5\"><a href=\"#cb2-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-6\"><a href=\"#cb2-6\" aria-hidden=\"true\"><\/a>mu_Y <span class=\"op\">=<\/span> tt.vector(<span class=\"st\">&#39;mu_Y&#39;<\/span>)<\/span>\n<span id=\"cb2-7\"><a href=\"#cb2-7\" aria-hidden=\"true\"><\/a>mu_Y.tag.test_value <span class=\"op\">=<\/span> np.array([<span class=\"fl\">1.<\/span>], dtype<span class=\"op\">=<\/span>tt.config.floatX)<\/span>\n<span id=\"cb2-8\"><a href=\"#cb2-8\" aria-hidden=\"true\"><\/a>sd_Y <span class=\"op\">=<\/span> tt.vector(<span class=\"st\">&#39;sd_Y&#39;<\/span>)<\/span>\n<span id=\"cb2-9\"><a href=\"#cb2-9\" aria-hidden=\"true\"><\/a>sd_Y.tag.test_value <span class=\"op\">=<\/span> np.array([<span class=\"fl\">0.5<\/span>], dtype<span class=\"op\">=<\/span>tt.config.floatX)<\/span>\n<span id=\"cb2-10\"><a href=\"#cb2-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-11\"><a href=\"#cb2-11\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> pm.Model() <span class=\"im\">as<\/span> conv_model:<\/span>\n<span id=\"cb2-12\"><a href=\"#cb2-12\" aria-hidden=\"true\"><\/a>    X_rv <span class=\"op\">=<\/span> pm.Normal(<span class=\"st\">&#39;X_rv&#39;<\/span>, mu_X, sd<span class=\"op\">=<\/span>sd_X, shape<span class=\"op\">=<\/span>(<span class=\"dv\">1<\/span>,))<\/span>\n<span id=\"cb2-13\"><a href=\"#cb2-13\" aria-hidden=\"true\"><\/a>    Y_rv <span class=\"op\">=<\/span> pm.Normal(<span class=\"st\">&#39;Y_rv&#39;<\/span>, mu_Y, sd<span class=\"op\">=<\/span>sd_Y, shape<span class=\"op\">=<\/span>(<span class=\"dv\">1<\/span>,))<\/span>\n<span id=\"cb2-14\"><a 
href=\"#cb2-14\" aria-hidden=\"true\"><\/a>    Z_rv <span class=\"op\">=<\/span> X_rv <span class=\"op\">+<\/span> Y_rv<\/span><\/code><\/pre><\/div>\n<p>The Python objects representing terms in <span class=\"math inline\">\\(\\eqref{eq:norm_conv_model}\\)<\/span> are <code>X_rv<\/code>, <code>Y_rv<\/code>, and <code>Z_rv<\/code> in <a href=\"#pymc3_model\">pymc3_model<\/a>. Those terms together form a Theano graph for the entirety of <span class=\"math inline\">\\(\\eqref{eq:norm_conv_model}\\)<\/span>.<\/p>\n<p>Other aspects of the model are implicitly stored in the <a href=\"https:\/\/docs.python.org\/3.6\/reference\/compound_stmts.html#with\">Python context object<\/a> <code>conv_model<\/code>. For example, the context object tracks the model\u2019s log likelihood function when some variables are designated as \u201cobserved\u201d\u2013i.e.\u00a0associated with sample data. In this example, we haven\u2019t specified an observed variable, so the context object won\u2019t be immediately useful.<\/p>\n<div class=\"remark\" data-markdown=\"\">\n<p>In what follows, we\u2019ll briefly introduce the internal aspects of PyMC3 that are immediately relevant for the topics addressed here; otherwise, see <a href=\"https:\/\/docs.pymc.io\/developer_guide.html\">the PyMC3 developer\u2019s guide<\/a> for an explanation of its design and internal workings.<\/p>\n<\/div>\n<p>The terms <code>X_rv<\/code> and <code>Y_rv<\/code> are instances of classes that inherit from both the PyMC3 <a href=\"https:\/\/github.com\/pymc-devs\/pymc3\/blob\/v3.3\/pymc3\/model.py#L151\"><code>Factor<\/code><\/a> class and the standard Theano <code>TensorVariable<\/code> class, as illustrated in the output of <a href=\"#pymc3_mro\">pymc3_mro<\/a>. 
However, the convolution term <code>Z_rv<\/code> is not a PyMC3 random variable; in other words, it does <strong>not<\/strong> inherit from the PyMC3 <code>Factor<\/code> class, but it <strong>is<\/strong> a Theano <code>TensorVariable<\/code>.<\/p>\n<div class=\"sourceCode\" id=\"cb3\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb3-1\"><a href=\"#cb3-1\" aria-hidden=\"true\"><\/a>pprint({<span class=\"st\">&#39;Y_rv&#39;<\/span>: <span class=\"bu\">type<\/span>(Y_rv).mro()})<\/span>\n<span id=\"cb3-2\"><a href=\"#cb3-2\" aria-hidden=\"true\"><\/a>pprint({<span class=\"st\">&#39;Z_rv&#39;<\/span>: <span class=\"bu\">type<\/span>(Z_rv).mro()})<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"cb4\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb4-1\"><a href=\"#cb4-1\" aria-hidden=\"true\"><\/a>{<span class=\"st\">&#39;Y_rv&#39;<\/span>: [<span class=\"op\">&lt;<\/span><span class=\"kw\">class<\/span> <span class=\"st\">&#39;pymc3.model.FreeRV&#39;<\/span><span class=\"op\">&gt;<\/span>,<\/span>\n<span id=\"cb4-2\"><a href=\"#cb4-2\" aria-hidden=\"true\"><\/a>          <span class=\"op\">&lt;<\/span><span class=\"kw\">class<\/span> <span class=\"st\">&#39;pymc3.model.Factor&#39;<\/span><span class=\"op\">&gt;<\/span>,<\/span>\n<span id=\"cb4-3\"><a href=\"#cb4-3\" aria-hidden=\"true\"><\/a>          <span class=\"op\">&lt;<\/span><span class=\"kw\">class<\/span> <span class=\"st\">&#39;theano.tensor.var.TensorVariable&#39;<\/span><span class=\"op\">&gt;<\/span>,<\/span>\n<span id=\"cb4-4\"><a href=\"#cb4-4\" aria-hidden=\"true\"><\/a>          <span class=\"op\">&lt;<\/span><span class=\"kw\">class<\/span> <span class=\"st\">&#39;theano.tensor.var._tensor_py_operators&#39;<\/span><span class=\"op\">&gt;<\/span>,<\/span>\n<span id=\"cb4-5\"><a href=\"#cb4-5\" aria-hidden=\"true\"><\/a>          <span class=\"op\">&lt;<\/span><span class=\"kw\">class<\/span> <span 
class=\"st\">&#39;theano.gof.graph.Variable&#39;<\/span><span class=\"op\">&gt;<\/span>,<\/span>\n<span id=\"cb4-6\"><a href=\"#cb4-6\" aria-hidden=\"true\"><\/a>          <span class=\"op\">&lt;<\/span><span class=\"kw\">class<\/span> <span class=\"st\">&#39;theano.gof.graph.Node&#39;<\/span><span class=\"op\">&gt;<\/span>,<\/span>\n<span id=\"cb4-7\"><a href=\"#cb4-7\" aria-hidden=\"true\"><\/a>          <span class=\"op\">&lt;<\/span><span class=\"kw\">class<\/span> <span class=\"st\">&#39;theano.gof.utils.object2&#39;<\/span><span class=\"op\">&gt;<\/span>,<\/span>\n<span id=\"cb4-8\"><a href=\"#cb4-8\" aria-hidden=\"true\"><\/a>          <span class=\"op\">&lt;<\/span><span class=\"kw\">class<\/span> <span class=\"st\">&#39;object&#39;<\/span><span class=\"op\">&gt;<\/span>]}<\/span>\n<span id=\"cb4-9\"><a href=\"#cb4-9\" aria-hidden=\"true\"><\/a>{<span class=\"st\">&#39;Z_rv&#39;<\/span>: [<span class=\"op\">&lt;<\/span><span class=\"kw\">class<\/span> <span class=\"st\">&#39;theano.tensor.var.TensorVariable&#39;<\/span><span class=\"op\">&gt;<\/span>,<\/span>\n<span id=\"cb4-10\"><a href=\"#cb4-10\" aria-hidden=\"true\"><\/a>          <span class=\"op\">&lt;<\/span><span class=\"kw\">class<\/span> <span class=\"st\">&#39;theano.tensor.var._tensor_py_operators&#39;<\/span><span class=\"op\">&gt;<\/span>,<\/span>\n<span id=\"cb4-11\"><a href=\"#cb4-11\" aria-hidden=\"true\"><\/a>          <span class=\"op\">&lt;<\/span><span class=\"kw\">class<\/span> <span class=\"st\">&#39;theano.gof.graph.Variable&#39;<\/span><span class=\"op\">&gt;<\/span>,<\/span>\n<span id=\"cb4-12\"><a href=\"#cb4-12\" aria-hidden=\"true\"><\/a>          <span class=\"op\">&lt;<\/span><span class=\"kw\">class<\/span> <span class=\"st\">&#39;theano.gof.graph.Node&#39;<\/span><span class=\"op\">&gt;<\/span>,<\/span>\n<span id=\"cb4-13\"><a href=\"#cb4-13\" aria-hidden=\"true\"><\/a>          <span class=\"op\">&lt;<\/span><span class=\"kw\">class<\/span> <span 
class=\"st\">&#39;theano.gof.utils.object2&#39;<\/span><span class=\"op\">&gt;<\/span>,<\/span>\n<span id=\"cb4-14\"><a href=\"#cb4-14\" aria-hidden=\"true\"><\/a>          <span class=\"op\">&lt;<\/span><span class=\"kw\">class<\/span> <span class=\"st\">&#39;object&#39;<\/span><span class=\"op\">&gt;<\/span>]}<\/span>\n<span id=\"cb4-15\"><a href=\"#cb4-15\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<p>While PyMC3 doesn\u2019t <strong>need<\/strong> to support convolution, so much of Bayesian statistics, MCMC, and probabilistic programming relies on it in some way. It\u2019s an intrinsic part of the algebra(s) implied by probability theory and is essential to the implementation of more sophisticated models and sampler optimizations\u2013in at least the same way as symbolic differentiation. Here, the question isn\u2019t whether these algebraic properties are explicitly supported, but how easily they can be implemented when necessary.<\/p>\n<p>As it stands, all work related to probability theory or the algebra of random variables is performed implicitly within Theano and is mostly detached from the model-level meta information provided by the PyMC3 abstractions. This means that the linear\/tensor algebra supported by Theano is the primary level of abstraction.<\/p>\n<p>More specifically, one purpose of the PyMC3 probability theory abstractions (e.g.\u00a0random variable classes such as <code>FreeRV<\/code> and <code>ObservedRV<\/code>, distributions and their likelihoods, etc.) is to associate a PyMC3 <a href=\"https:\/\/github.com\/pymc-devs\/pymc3\/blob\/v3.3\/pymc3\/distributions\/distribution.py#L18\"><code>Distribution<\/code><\/a> object with a Theano <code>TensorVariable<\/code>. 
This connection is made through a <code>distribution<\/code> attribute:<\/p>\n<div class=\"sourceCode\" id=\"cb5\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb5-1\"><a href=\"#cb5-1\" aria-hidden=\"true\"><\/a>pprint(Y_rv.distribution)<\/span>\n<span id=\"cb5-2\"><a href=\"#cb5-2\" aria-hidden=\"true\"><\/a>pprint(X_rv.distribution)<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"cb6\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb6-1\"><a href=\"#cb6-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&lt;<\/span>pymc3.distributions.continuous.Normal <span class=\"bu\">object<\/span> at <span class=\"bn\">0x7f5d1796e908<\/span><span class=\"op\">&gt;<\/span><\/span>\n<span id=\"cb6-2\"><a href=\"#cb6-2\" aria-hidden=\"true\"><\/a><span class=\"op\">&lt;<\/span>pymc3.distributions.continuous.Normal <span class=\"bu\">object<\/span> at <span class=\"bn\">0x7f5d17b10208<\/span><span class=\"op\">&gt;<\/span><\/span>\n<span id=\"cb6-3\"><a href=\"#cb6-3\" aria-hidden=\"true\"><\/a><\/span><\/code><\/pre><\/div>\n<p><code>Distribution<\/code> objects loosely represent a measure, holding distribution parameters (e.g.\u00a0mean and standard deviation <code>mu_X<\/code>, <code>sd_X<\/code>) and constructing the appropriate conditional log likelihoods\u2013from which the model\u2019s total log likelihood is later derived. The distribution parameters and log-likelihoods are Theano <code>TensorVariable<\/code>s\u2013including other PyMC3-derived <code>TensorVariable<\/code>s corresponding to (the output of) random variables.<\/p>\n<p>Again, since objects derived via algebraic manipulation of random variables are not themselves random variables within the framework of PyMC3, objects like <code>Z_rv<\/code> do not have a <code>distribution<\/code> attribute. 
The mechanics described here provide a means for supporting terms like <code>Z_rv<\/code> with the appropriate \u201cderived\u201d distribution.<\/p>\n<p>To start, we\u2019ll have to dive deeper into the graph aspects of Theano.<\/p>\n<\/section>\n<section id=\"random-variables-in-graphs\" class=\"level1\">\n<h1>Random Variables in Graphs<\/h1>\n<p>The Theano graph representing <span class=\"math inline\">\\(\\eqref{eq:norm_conv_model}\\)<\/span> consists of linear\/tensor algebra operations\u2013under the interface of <code>theano.gof.op.Op<\/code>\u2013on <code>TensorVariable<\/code>s. For our example in <a href=\"#pymc3_model\">pymc3_model<\/a>, a textual representation is given in <a href=\"#Z_rv_debugprint\">Z_rv_debugprint<\/a> and a graphical form in <a href=\"#fig:norm_sum_graph\">fig:norm_sum_graph<\/a>.<\/p>\n<div class=\"sourceCode\" id=\"cb7\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb7-1\"><a href=\"#cb7-1\" aria-hidden=\"true\"><\/a>tt.printing.debugprint(Z_rv)<\/span><\/code><\/pre><\/div>\n<pre class=\"text\"><code>Elemwise{add,no_inplace} [id A] &#39;&#39;\n |X_rv [id B]\n |Y_rv [id C]\n\n<\/code><\/pre>\n<figure id=\"fig:norm_sum_graph\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/Z_rv.png\" title=\"fig:\" alt=\"Graph of Z_rv for the PyMC3 model in 2. \" \/>\n<figcaption>\nGraph of <code>Z_rv<\/code> for the PyMC3 model in <a href=\"#org7c56540\">2<\/a>.\n<\/figcaption>\n<\/figure>\n<p>At present, PyMC3 (version 3.3) does not make very consistent use of Theano\u2019s graph objects. For instance, notice how the dependent parameters <code>mu_X<\/code> and <code>sd_X<\/code> are not present in the model\u2019s graph (e.g.\u00a0<a href=\"#fig:norm_sum_graph\">fig:norm_sum_graph<\/a>). 
We know that <code>X_rv<\/code> and <code>Y_rv<\/code> are PyMC3 random variables, but what we see in the graph is only their representations as sampled scalar\/vector\/matrix\/tensor values. In other words, where <span class=\"math inline\">\\(X\\)<\/span>, <span class=\"math inline\">\\(Y\\)<\/span> symbolize random variables and <span class=\"math inline\">\\(x \\sim X\\)<\/span>, <span class=\"math inline\">\\(y \\sim Y\\)<\/span> their samples, we have a graph expressing only <span class=\"math inline\">\\(z = x + y\\)<\/span>.<\/p>\n<p>What we need for higher-level work is a graph of <span class=\"math inline\">\\(Z = X + Y\\)<\/span> that includes every term involved. This is true for graphs representing a model\u2019s measure\/log-likelihood <strong>and<\/strong> its sampled values. The former is essentially covered by the log-likelihood graphs we can already produce using the PyMC3 model objects. It\u2019s the latter that we\u2019ll establish here, since it sets the stage for applications of numerous techniques in statistics and probability theory.<\/p>\n<p>One way to produce graphs that represent the full probabilistic model is to formalize the notion of random variables using the Theano API. Basically, if we want to include the relationships between distribution parameters and sampled variables, <strong>we need an <code>Op<\/code> that represents random variables and\/or the act of sampling<\/strong>. 
<code>theano.tensor.raw_random.RandomFunction<\/code> does exactly this, although it represents the act of sampling rather than a random measure.<\/p>\n<p>Nonetheless, using <code>RandomFunction<\/code>, we can replace nodes corresponding to PyMC3 random variables with newly constructed <code>Op<\/code> nodes.<\/p>\n<div class=\"example\" data-markdown=\"\">\n<p>We can produce the types of graphs described above by converting existing PyMC3 models.<\/p>\n<p>To manipulate our model\u2019s graph, we need to create a Theano <code>theano.gof.FunctionGraph<\/code> object. We create a utility function in <a href=\"#model_graph_fn\">model_graph_fn<\/a> that constructs a <code>FunctionGraph<\/code> from a PyMC3 model.<\/p>\n<div class=\"sourceCode\" id=\"cb9\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb9-1\"><a href=\"#cb9-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano.gof <span class=\"im\">import<\/span> FunctionGraph, Feature, NodeFinder<\/span>\n<span id=\"cb9-2\"><a href=\"#cb9-2\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano.gof.graph <span class=\"im\">import<\/span> inputs <span class=\"im\">as<\/span> tt_inputs, clone_get_equiv<\/span>\n<span id=\"cb9-3\"><a href=\"#cb9-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb9-4\"><a href=\"#cb9-4\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> model_graph(pymc_model, derived_vars<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"cb9-5\"><a href=\"#cb9-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb9-6\"><a href=\"#cb9-6\" aria-hidden=\"true\"><\/a>    model <span class=\"op\">=<\/span> pm.modelcontext(pymc_model)<\/span>\n<span id=\"cb9-7\"><a href=\"#cb9-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb9-8\"><a href=\"#cb9-8\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> derived_vars <span class=\"kw\">is<\/span> <span 
class=\"kw\">not<\/span> <span class=\"va\">None<\/span>:<\/span>\n<span id=\"cb9-9\"><a href=\"#cb9-9\" aria-hidden=\"true\"><\/a>        model_outs <span class=\"op\">=<\/span> derived_vars<\/span>\n<span id=\"cb9-10\"><a href=\"#cb9-10\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"cb9-11\"><a href=\"#cb9-11\" aria-hidden=\"true\"><\/a>        model_outs <span class=\"op\">=<\/span> [o.logpt <span class=\"cf\">for<\/span> o <span class=\"kw\">in<\/span> model.observed_RVs]<\/span>\n<span id=\"cb9-12\"><a href=\"#cb9-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb9-13\"><a href=\"#cb9-13\" aria-hidden=\"true\"><\/a>    model_inputs <span class=\"op\">=<\/span> [inp <span class=\"cf\">for<\/span> inp <span class=\"kw\">in<\/span> tt_inputs(model_outs)]<\/span>\n<span id=\"cb9-14\"><a href=\"#cb9-14\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># if not isinstance(inp, theano.gof.graph.Constant)]<\/span><\/span>\n<span id=\"cb9-15\"><a href=\"#cb9-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb9-16\"><a href=\"#cb9-16\" aria-hidden=\"true\"><\/a>    model_memo <span class=\"op\">=<\/span> clone_get_equiv(model_inputs, model_outs,<\/span>\n<span id=\"cb9-17\"><a href=\"#cb9-17\" aria-hidden=\"true\"><\/a>                                 copy_orphans<span class=\"op\">=<\/span><span class=\"va\">False<\/span>)<\/span>\n<span id=\"cb9-18\"><a href=\"#cb9-18\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb9-19\"><a href=\"#cb9-19\" aria-hidden=\"true\"><\/a>    fg_features <span class=\"op\">=<\/span> [<\/span>\n<span id=\"cb9-20\"><a href=\"#cb9-20\" aria-hidden=\"true\"><\/a>        NodeFinder(),<\/span>\n<span id=\"cb9-21\"><a href=\"#cb9-21\" aria-hidden=\"true\"><\/a>    ]<\/span>\n<span id=\"cb9-22\"><a href=\"#cb9-22\" aria-hidden=\"true\"><\/a>    model_fg <span class=\"op\">=<\/span> FunctionGraph([model_memo[i] <span class=\"cf\">for<\/span> i <span class=\"kw\">in<\/span> 
model_inputs],<\/span>\n<span id=\"cb9-23\"><a href=\"#cb9-23\" aria-hidden=\"true\"><\/a>                             [model_memo[i] <span class=\"cf\">for<\/span> i <span class=\"kw\">in<\/span> model_outs],<\/span>\n<span id=\"cb9-24\"><a href=\"#cb9-24\" aria-hidden=\"true\"><\/a>                             clone<span class=\"op\">=<\/span><span class=\"va\">False<\/span>, features<span class=\"op\">=<\/span>fg_features)<\/span>\n<span id=\"cb9-25\"><a href=\"#cb9-25\" aria-hidden=\"true\"><\/a>    model_fg.memo <span class=\"op\">=<\/span> model_memo<\/span>\n<span id=\"cb9-26\"><a href=\"#cb9-26\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb9-27\"><a href=\"#cb9-27\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> model_fg<\/span><\/code><\/pre><\/div>\n<p>When cloning the graph with <code>theano.gof.graph.clone_get_equiv<\/code> in <code>model_graph<\/code>, we lose the <code>FreeRV.distribution<\/code> attribute\u2013among others. Since those attributes hold all the information required to construct our <code>RandomFunction<\/code> <code>Op<\/code>s, we\u2019ll need to find a way to preserve them.<\/p>\n<p>This can be accomplished by overriding the default Theano <code>clone<\/code> method inherited by the PyMC3 random variable classes.<\/p>\n<div class=\"sourceCode\" id=\"cb10\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb10-1\"><a href=\"#cb10-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> types<\/span>\n<span id=\"cb10-2\"><a href=\"#cb10-2\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> copy <span class=\"im\">import<\/span> copy<\/span>\n<span id=\"cb10-3\"><a href=\"#cb10-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-4\"><a href=\"#cb10-4\" aria-hidden=\"true\"><\/a>pymc_rv_types <span class=\"op\">=<\/span> (pm.model.FreeRV, pm.model.ObservedRV, pm.model.TransformedRV)<\/span>\n<span id=\"cb10-5\"><a href=\"#cb10-5\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-6\"><a href=\"#cb10-6\" aria-hidden=\"true\"><\/a>pymc_rv_attrs <span class=\"op\">=<\/span> [<span class=\"st\">&#39;dshape&#39;<\/span>, <span class=\"st\">&#39;dsize&#39;<\/span>, <span class=\"st\">&#39;distribution&#39;<\/span>, <span class=\"st\">&#39;logp_elemwiset&#39;<\/span>,<\/span>\n<span id=\"cb10-7\"><a href=\"#cb10-7\" aria-hidden=\"true\"><\/a>                 <span class=\"st\">&#39;logp_sum_unscaledt&#39;<\/span>, <span class=\"st\">&#39;logp_nojac_unscaledt&#39;<\/span>, <span class=\"st\">&#39;total_size&#39;<\/span>,<\/span>\n<span id=\"cb10-8\"><a href=\"#cb10-8\" aria-hidden=\"true\"><\/a>                 <span class=\"st\">&#39;scaling&#39;<\/span>, <span class=\"st\">&#39;missing_values&#39;<\/span>]<\/span>\n<span id=\"cb10-9\"><a href=\"#cb10-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-10\"><a href=\"#cb10-10\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> rv_type <span class=\"kw\">in<\/span> pymc_rv_types:<\/span>\n<span id=\"cb10-11\"><a href=\"#cb10-11\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-12\"><a href=\"#cb10-12\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> <span class=\"kw\">not<\/span> <span class=\"bu\">hasattr<\/span>(rv_type, <span class=\"st\">&#39;__clone&#39;<\/span>):<\/span>\n<span id=\"cb10-13\"><a href=\"#cb10-13\" aria-hidden=\"true\"><\/a>        rv_type.__clone <span class=\"op\">=<\/span> rv_type.clone<\/span>\n<span id=\"cb10-14\"><a href=\"#cb10-14\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Bind the current rv_type as a default argument; a plain closure would see only the last loop value<\/span><\/span>\n<span id=\"cb10-15\"><a href=\"#cb10-15\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> pymc_rv_clone(<span class=\"va\">self<\/span>, rv_type<span class=\"op\">=<\/span>rv_type):<\/span>\n<span id=\"cb10-16\"><a href=\"#cb10-16\" aria-hidden=\"true\"><\/a>        cp <span class=\"op\">=<\/span> rv_type.__clone(<span class=\"va\">self<\/span>)<\/span>\n<span id=\"cb10-17\"><a href=\"#cb10-17\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">for<\/span> 
attr <span class=\"kw\">in<\/span> pymc_rv_attrs:<\/span>\n<span id=\"cb10-18\"><a href=\"#cb10-18\" aria-hidden=\"true\"><\/a>            <span class=\"bu\">setattr<\/span>(cp, attr, copy(<span class=\"bu\">getattr<\/span>(<span class=\"va\">self<\/span>, attr, <span class=\"va\">None<\/span>)))<\/span>\n<span id=\"cb10-19\"><a href=\"#cb10-19\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-20\"><a href=\"#cb10-20\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># Allow a cloned rv to inherit the context&#39;s model?<\/span><\/span>\n<span id=\"cb10-21\"><a href=\"#cb10-21\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># try:<\/span><\/span>\n<span id=\"cb10-22\"><a href=\"#cb10-22\" aria-hidden=\"true\"><\/a>        <span class=\"co\">#     cp.model = pm.Model.get_context()<\/span><\/span>\n<span id=\"cb10-23\"><a href=\"#cb10-23\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># except TypeError:<\/span><\/span>\n<span id=\"cb10-24\"><a href=\"#cb10-24\" aria-hidden=\"true\"><\/a>        <span class=\"co\">#     pass<\/span><\/span>\n<span id=\"cb10-25\"><a href=\"#cb10-25\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-26\"><a href=\"#cb10-26\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> <span class=\"bu\">getattr<\/span>(cp, <span class=\"st\">&#39;model&#39;<\/span>, <span class=\"va\">None<\/span>) <span class=\"kw\">is<\/span> <span class=\"va\">None<\/span>:<\/span>\n<span id=\"cb10-27\"><a href=\"#cb10-27\" aria-hidden=\"true\"><\/a>            cp.model <span class=\"op\">=<\/span> <span class=\"bu\">getattr<\/span>(<span class=\"va\">self<\/span>, <span class=\"st\">&#39;model&#39;<\/span>, <span class=\"va\">None<\/span>)<\/span>\n<span id=\"cb10-28\"><a href=\"#cb10-28\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-29\"><a href=\"#cb10-29\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> cp<\/span>\n<span id=\"cb10-30\"><a href=\"#cb10-30\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-31\"><a href=\"#cb10-31\" aria-hidden=\"true\"><\/a>    rv_type.clone <span class=\"op\">=<\/span> pymc_rv_clone<\/span><\/code><\/pre><\/div>\n<p>Now, we can produce a proper <code>FunctionGraph<\/code> from our PyMC3 model.<\/p>\n<div class=\"sourceCode\" id=\"cb11\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb11-1\"><a href=\"#cb11-1\" aria-hidden=\"true\"><\/a>Z_fgraph_tt <span class=\"op\">=<\/span> model_graph(conv_model, derived_vars<span class=\"op\">=<\/span>[Z_rv])<\/span><\/code><\/pre><\/div>\n<p>With a <code>FunctionGraph<\/code> at our disposal, we can use the graph manipulation tools provided by Theano to replace the PyMC3 <code>TensorVariable<\/code>s used to represent random variables with corresponding Theano <code>RandomFunction<\/code>s that represent the <strong>act of sampling<\/strong> to produce said random variables.<\/p>\n<p>We can use a simple mapping between PyMC3 random variable nodes and <code>RandomFunction<\/code> to specify the desired replacements. Fortunately, this isn\u2019t too difficult, since <code>RandomFunction<\/code> already supports numerous NumPy-provided random distributions, which cover much of the same ground as the PyMC3 distributions. The remaining work consists of mapping distribution parameters.<\/p>\n<p>Also, <code>RandomFunction<\/code> requires a random state input (in the code below, a Theano shared variable holding a NumPy <code>RandomState<\/code>) to track the sampler state. For our purely symbolic purposes, the random state is not immediately useful, but, as a nice side effect, it ultimately gives us a graph we can sample from. 
We demonstrate the PyMC3 random variable-to-<code>RandomFunction<\/code> translation in <a href=\"#random_op_mapping\">random_op_mapping<\/a> using only a single mapping.<\/p>\n<div class=\"sourceCode\" id=\"cb12\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb12-1\"><a href=\"#cb12-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano.tensor.raw_random <span class=\"im\">import<\/span> RandomFunction<\/span>\n<span id=\"cb12-2\"><a href=\"#cb12-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb12-3\"><a href=\"#cb12-3\" aria-hidden=\"true\"><\/a>pymc_theano_rv_equivs <span class=\"op\">=<\/span> {<\/span>\n<span id=\"cb12-4\"><a href=\"#cb12-4\" aria-hidden=\"true\"><\/a>    pm.Normal:<\/span>\n<span id=\"cb12-5\"><a href=\"#cb12-5\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">lambda<\/span> dist, rand_state:<\/span>\n<span id=\"cb12-6\"><a href=\"#cb12-6\" aria-hidden=\"true\"><\/a>    tt.raw_random.normal(rand_state, dist.shape.tolist(), dist.mu, dist.sd),<\/span>\n<span id=\"cb12-7\"><a href=\"#cb12-7\" aria-hidden=\"true\"><\/a>}<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"cb13\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb13-1\"><a href=\"#cb13-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> create_theano_rvs(fgraph, clone<span class=\"op\">=<\/span><span class=\"va\">True<\/span>, rand_state<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"cb13-2\"><a href=\"#cb13-2\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot;Replace PyMC3 random variables with `RandomFunction` Ops.<\/span><\/span>\n<span id=\"cb13-3\"><a href=\"#cb13-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-4\"><a href=\"#cb13-4\" aria-hidden=\"true\"><\/a><span class=\"co\">    <\/span><span class=\"al\">TODO<\/span><span class=\"co\">: Could use a Theano graph `Feature` to trace--or even<\/span><\/span>\n<span 
id=\"cb13-5\"><a href=\"#cb13-5\" aria-hidden=\"true\"><\/a><span class=\"co\">    replace--random variables.<\/span><\/span>\n<span id=\"cb13-6\"><a href=\"#cb13-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-7\"><a href=\"#cb13-7\" aria-hidden=\"true\"><\/a><span class=\"co\">    Parameters<\/span><\/span>\n<span id=\"cb13-8\"><a href=\"#cb13-8\" aria-hidden=\"true\"><\/a><span class=\"co\">    ----------<\/span><\/span>\n<span id=\"cb13-9\"><a href=\"#cb13-9\" aria-hidden=\"true\"><\/a><span class=\"co\">    fgraph : FunctionGraph<\/span><\/span>\n<span id=\"cb13-10\"><a href=\"#cb13-10\" aria-hidden=\"true\"><\/a><span class=\"co\">    A graph containing PyMC3 random variables.<\/span><\/span>\n<span id=\"cb13-11\"><a href=\"#cb13-11\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-12\"><a href=\"#cb13-12\" aria-hidden=\"true\"><\/a><span class=\"co\">    clone: bool, optional<\/span><\/span>\n<span id=\"cb13-13\"><a href=\"#cb13-13\" aria-hidden=\"true\"><\/a><span class=\"co\">    Clone the original graph.<\/span><\/span>\n<span id=\"cb13-14\"><a href=\"#cb13-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-15\"><a href=\"#cb13-15\" aria-hidden=\"true\"><\/a><span class=\"co\">    rand_state : RandomStateType, optional<\/span><\/span>\n<span id=\"cb13-16\"><a href=\"#cb13-16\" aria-hidden=\"true\"><\/a><span class=\"co\">    The Theano random state.<\/span><\/span>\n<span id=\"cb13-17\"><a href=\"#cb13-17\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-18\"><a href=\"#cb13-18\" aria-hidden=\"true\"><\/a><span class=\"co\">    Returns<\/span><\/span>\n<span id=\"cb13-19\"><a href=\"#cb13-19\" aria-hidden=\"true\"><\/a><span class=\"co\">    -------<\/span><\/span>\n<span id=\"cb13-20\"><a href=\"#cb13-20\" aria-hidden=\"true\"><\/a><span class=\"co\">    out : A cloned graph with random variables replaced and a `memo` attribute.<\/span><\/span>\n<span id=\"cb13-21\"><a href=\"#cb13-21\" aria-hidden=\"true\"><\/a><\/span>\n<span 
id=\"cb13-22\"><a href=\"#cb13-22\" aria-hidden=\"true\"><\/a><span class=\"co\">    &quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb13-23\"><a href=\"#cb13-23\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> clone:<\/span>\n<span id=\"cb13-24\"><a href=\"#cb13-24\" aria-hidden=\"true\"><\/a>        fgraph_, fgraph_memo_ <span class=\"op\">=<\/span> fgraph.clone_get_equiv(attach_feature<span class=\"op\">=<\/span><span class=\"va\">False<\/span>)<\/span>\n<span id=\"cb13-25\"><a href=\"#cb13-25\" aria-hidden=\"true\"><\/a>        fgraph_.memo <span class=\"op\">=<\/span> fgraph_memo_<\/span>\n<span id=\"cb13-26\"><a href=\"#cb13-26\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"cb13-27\"><a href=\"#cb13-27\" aria-hidden=\"true\"><\/a>        fgraph_ <span class=\"op\">=<\/span> fgraph<\/span>\n<span id=\"cb13-28\"><a href=\"#cb13-28\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-29\"><a href=\"#cb13-29\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> rand_state <span class=\"kw\">is<\/span> <span class=\"va\">None<\/span>:<\/span>\n<span id=\"cb13-30\"><a href=\"#cb13-30\" aria-hidden=\"true\"><\/a>        rand_state <span class=\"op\">=<\/span> theano.shared(np.random.RandomState())<\/span>\n<span id=\"cb13-31\"><a href=\"#cb13-31\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-32\"><a href=\"#cb13-32\" aria-hidden=\"true\"><\/a>    fgraph_replacements <span class=\"op\">=<\/span> {}<\/span>\n<span id=\"cb13-33\"><a href=\"#cb13-33\" aria-hidden=\"true\"><\/a>    fgraph_new_inputs <span class=\"op\">=<\/span> <span class=\"bu\">set<\/span>()<\/span>\n<span id=\"cb13-34\"><a href=\"#cb13-34\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-35\"><a href=\"#cb13-35\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">for<\/span> old_rv_i, old_rv <span class=\"kw\">in<\/span> <span class=\"bu\">enumerate<\/span>(fgraph_.inputs):<\/span>\n<span id=\"cb13-36\"><a href=\"#cb13-36\" 
aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> <span class=\"bu\">isinstance<\/span>(old_rv, pymc_rv_types):<\/span>\n<span id=\"cb13-37\"><a href=\"#cb13-37\" aria-hidden=\"true\"><\/a>            dist <span class=\"op\">=<\/span> old_rv.distribution<\/span>\n<span id=\"cb13-38\"><a href=\"#cb13-38\" aria-hidden=\"true\"><\/a>            theano_rv_op <span class=\"op\">=<\/span> pymc_theano_rv_equivs.get(<span class=\"bu\">type<\/span>(dist), <span class=\"va\">None<\/span>)<\/span>\n<span id=\"cb13-39\"><a href=\"#cb13-39\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-40\"><a href=\"#cb13-40\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">if<\/span> theano_rv_op <span class=\"kw\">is<\/span> <span class=\"kw\">not<\/span> <span class=\"va\">None<\/span>:<\/span>\n<span id=\"cb13-41\"><a href=\"#cb13-41\" aria-hidden=\"true\"><\/a>                rng_tt, new_rv <span class=\"op\">=<\/span> theano_rv_op(dist, rand_state)<\/span>\n<span id=\"cb13-42\"><a href=\"#cb13-42\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-43\"><a href=\"#cb13-43\" aria-hidden=\"true\"><\/a>                <span class=\"co\"># Keep track of our replacements<\/span><\/span>\n<span id=\"cb13-44\"><a href=\"#cb13-44\" aria-hidden=\"true\"><\/a>                fgraph_replacements[old_rv] <span class=\"op\">=<\/span> new_rv<\/span>\n<span id=\"cb13-45\"><a href=\"#cb13-45\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-46\"><a href=\"#cb13-46\" aria-hidden=\"true\"><\/a>                new_rv.name <span class=\"op\">=<\/span> <span class=\"st\">&#39;~<\/span><span class=\"sc\">{}<\/span><span class=\"st\">&#39;<\/span>.<span class=\"bu\">format<\/span>(old_rv.name)<\/span>\n<span id=\"cb13-47\"><a href=\"#cb13-47\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-48\"><a href=\"#cb13-48\" aria-hidden=\"true\"><\/a>                new_rv_inputs <span class=\"op\">=<\/span> [i <span class=\"cf\">for<\/span> i <span class=\"kw\">in<\/span> 
tt_inputs([new_rv])]<\/span>\n<span id=\"cb13-49\"><a href=\"#cb13-49\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-50\"><a href=\"#cb13-50\" aria-hidden=\"true\"><\/a>                fgraph_new_inputs.update(new_rv_inputs)<\/span>\n<span id=\"cb13-51\"><a href=\"#cb13-51\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"cb13-52\"><a href=\"#cb13-52\" aria-hidden=\"true\"><\/a>                <span class=\"bu\">print<\/span>(<span class=\"st\">&#39;<\/span><span class=\"sc\">{}<\/span><span class=\"st\"> could not be mapped to a random function&#39;<\/span>.<span class=\"bu\">format<\/span>(old_rv))<\/span>\n<span id=\"cb13-53\"><a href=\"#cb13-53\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-54\"><a href=\"#cb13-54\" aria-hidden=\"true\"><\/a>    fgraph_new_inputs_memo <span class=\"op\">=<\/span> theano.gof.graph.clone_get_equiv(<\/span>\n<span id=\"cb13-55\"><a href=\"#cb13-55\" aria-hidden=\"true\"><\/a>        fgraph_new_inputs, <span class=\"bu\">list<\/span>(fgraph_replacements.values()),<\/span>\n<span id=\"cb13-56\"><a href=\"#cb13-56\" aria-hidden=\"true\"><\/a>        copy_orphans<span class=\"op\">=<\/span><span class=\"va\">False<\/span>)<\/span>\n<span id=\"cb13-57\"><a href=\"#cb13-57\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-58\"><a href=\"#cb13-58\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Update our maps and new inputs to use the cloned objects<\/span><\/span>\n<span id=\"cb13-59\"><a href=\"#cb13-59\" aria-hidden=\"true\"><\/a>    fgraph_replacements <span class=\"op\">=<\/span> {old_rv: fgraph_new_inputs_memo.pop(new_rv)<\/span>\n<span id=\"cb13-60\"><a href=\"#cb13-60\" aria-hidden=\"true\"><\/a>                           <span class=\"cf\">for<\/span> old_rv, new_rv <span class=\"kw\">in<\/span> fgraph_replacements.items()}<\/span>\n<span id=\"cb13-61\"><a href=\"#cb13-61\" aria-hidden=\"true\"><\/a>    fgraph_new_inputs <span class=\"op\">=<\/span> <span 
class=\"bu\">set<\/span>(<span class=\"bu\">map<\/span>(fgraph_new_inputs_memo.pop, fgraph_new_inputs))<\/span>\n<span id=\"cb13-62\"><a href=\"#cb13-62\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-63\"><a href=\"#cb13-63\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># What remains in `fgraph_new_inputs_memo` are the nodes between our desired<\/span><\/span>\n<span id=\"cb13-64\"><a href=\"#cb13-64\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># inputs (i.e. the random variables&#39; distribution parameters) and the old inputs<\/span><\/span>\n<span id=\"cb13-65\"><a href=\"#cb13-65\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># (i.e. Theano `Variable`s corresponding to a sample of said random variables).<\/span><\/span>\n<span id=\"cb13-66\"><a href=\"#cb13-66\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-67\"><a href=\"#cb13-67\" aria-hidden=\"true\"><\/a>    _ <span class=\"op\">=<\/span> [fgraph_.add_input(new_in) <span class=\"cf\">for<\/span> new_in <span class=\"kw\">in<\/span> fgraph_new_inputs<\/span>\n<span id=\"cb13-68\"><a href=\"#cb13-68\" aria-hidden=\"true\"><\/a>         <span class=\"cf\">if<\/span> <span class=\"kw\">not<\/span> <span class=\"bu\">isinstance<\/span>(new_in, theano.gof.graph.Constant)]<\/span>\n<span id=\"cb13-69\"><a href=\"#cb13-69\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-70\"><a href=\"#cb13-70\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># _ = [fgraph_.add_input(new_in) for new_in in fgraph_new_inputs_memo.values()]<\/span><\/span>\n<span id=\"cb13-71\"><a href=\"#cb13-71\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-72\"><a href=\"#cb13-72\" aria-hidden=\"true\"><\/a>    fgraph_.replace_all(fgraph_replacements.items())<\/span>\n<span id=\"cb13-73\"><a href=\"#cb13-73\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-74\"><a href=\"#cb13-74\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># The replace method apparently doesn&#39;t remove the old 
inputs...<\/span><\/span>\n<span id=\"cb13-75\"><a href=\"#cb13-75\" aria-hidden=\"true\"><\/a>    _ <span class=\"op\">=<\/span> [fgraph_.inputs.remove(old_rv) <span class=\"cf\">for<\/span> old_rv <span class=\"kw\">in<\/span> fgraph_replacements.keys()]<\/span>\n<span id=\"cb13-76\"><a href=\"#cb13-76\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-77\"><a href=\"#cb13-77\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> fgraph_<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"cb14\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb14-1\"><a href=\"#cb14-1\" aria-hidden=\"true\"><\/a>Z_fgraph_rv_tt <span class=\"op\">=<\/span> create_theano_rvs(Z_fgraph_tt)<\/span>\n<span id=\"cb14-2\"><a href=\"#cb14-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-3\"><a href=\"#cb14-3\" aria-hidden=\"true\"><\/a>tt.printing.debugprint(Z_fgraph_rv_tt)<\/span><\/code><\/pre><\/div>\n<pre class=\"text\"><code>Elemwise{add,no_inplace} [id A] &#39;&#39;   10\n |RandomFunction{normal}.1 [id B] &#39;~X_rv&#39;   9\n | |&lt;RandomStateType&gt; [id C]\n | |Elemwise{Cast{int64}} [id D] &#39;&#39;   8\n | | |MakeVector{dtype=&#39;int8&#39;} [id E] &#39;&#39;   7\n | |   |TensorConstant{1} [id F]\n | |mu_X [id G]\n | |Elemwise{mul,no_inplace} [id H] &#39;&#39;   6\n |   |InplaceDimShuffle{x} [id I] &#39;&#39;   5\n |   | |TensorConstant{1.0} [id J]\n |   |sd_X [id K]\n |RandomFunction{normal}.1 [id L] &#39;~Y_rv&#39;   4\n   |&lt;RandomStateType&gt; [id C]\n   |Elemwise{Cast{int64}} [id M] &#39;&#39;   3\n   | |MakeVector{dtype=&#39;int8&#39;} [id N] &#39;&#39;   2\n   |   |TensorConstant{1} [id F]\n   |mu_Y [id O]\n   |Elemwise{mul,no_inplace} [id P] &#39;&#39;   1\n     |InplaceDimShuffle{x} [id Q] &#39;&#39;   0\n     | |TensorConstant{1.0} [id J]\n     |sd_Y [id R]\n\n<\/code><\/pre>\n<figure id=\"fig:random_op_mapping_exa_graph\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/Z_fgraph_rv_tt.png\" 
title=\"fig:\" alt=\"Graph of Z = X + Y using an Op to represent sampling\/a random variable. \" \/>\n<figcaption>\nGraph of <span class=\"math inline\">\\(Z = X + Y\\)<\/span> using an <code>Op<\/code> to represent sampling\/a random variable.\n<\/figcaption>\n<\/figure>\n<\/div>\n<p>Illustrations of the transformed graphs given in <a href=\"#random_op_mapping_exa\">random_op_mapping_exa<\/a> and <a href=\"#fig:random_op_mapping_exa_graph\">fig:random_op_mapping_exa_graph<\/a> show the full extent of our simple example model and provide a context in which to perform higher-level manipulations.<\/p>\n<p>With a graph representing the relevant terms and relationships, we can implement the convolution simplification\/transformation\/optimization. For instance, as shown in <a href=\"#rv_find_nodes\">rv_find_nodes<\/a>, we can now easily query random function\/variable nodes in a graph.<\/p>\n<div class=\"sourceCode\" id=\"cb16\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb16-1\"><a href=\"#cb16-1\" aria-hidden=\"true\"><\/a><span class=\"co\"># Using a `FunctionGraph` &quot;feature&quot;<\/span><\/span>\n<span id=\"cb16-2\"><a href=\"#cb16-2\" aria-hidden=\"true\"><\/a>Z_fgraph_rv_tt.attach_feature(NodeFinder())<\/span>\n<span id=\"cb16-3\"><a href=\"#cb16-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb16-4\"><a href=\"#cb16-4\" aria-hidden=\"true\"><\/a><span class=\"co\"># The fixed `TensorType` is unnecessarily restrictive.<\/span><\/span>\n<span id=\"cb16-5\"><a href=\"#cb16-5\" aria-hidden=\"true\"><\/a>rf_normal_type <span class=\"op\">=<\/span> RandomFunction(<span class=\"st\">&#39;normal&#39;<\/span>, tt.TensorType(<span class=\"st\">&#39;float64&#39;<\/span>, (<span class=\"va\">True<\/span>,)))<\/span>\n<span id=\"cb16-6\"><a href=\"#cb16-6\" aria-hidden=\"true\"><\/a>rf_nodes <span class=\"op\">=<\/span> Z_fgraph_rv_tt.get_nodes(rf_normal_type)<\/span>\n<span id=\"cb16-7\"><a href=\"#cb16-7\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb16-8\"><a href=\"#cb16-8\" aria-hidden=\"true\"><\/a><span class=\"co\">#<\/span><\/span>\n<span id=\"cb16-9\"><a href=\"#cb16-9\" aria-hidden=\"true\"><\/a><span class=\"co\"># or, more generally,...<\/span><\/span>\n<span id=\"cb16-10\"><a href=\"#cb16-10\" aria-hidden=\"true\"><\/a><span class=\"co\">#<\/span><\/span>\n<span id=\"cb16-11\"><a href=\"#cb16-11\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> get_random_nodes(fgraph):<\/span>\n<span id=\"cb16-12\"><a href=\"#cb16-12\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> <span class=\"bu\">list<\/span>(<span class=\"bu\">filter<\/span>(<span class=\"kw\">lambda<\/span> x: <span class=\"bu\">isinstance<\/span>(x.op, RandomFunction), fgraph.apply_nodes))<\/span>\n<span id=\"cb16-13\"><a href=\"#cb16-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb16-14\"><a href=\"#cb16-14\" aria-hidden=\"true\"><\/a>rf_nodes <span class=\"op\">=<\/span> get_random_nodes(Z_fgraph_rv_tt)<\/span>\n<span id=\"cb16-15\"><a href=\"#cb16-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb16-16\"><a href=\"#cb16-16\" aria-hidden=\"true\"><\/a>tt.printing.debugprint(rf_nodes)<\/span><\/code><\/pre><\/div>\n<pre class=\"text\"><code>RandomFunction{normal}.0 [id A] &#39;&#39;\n |&lt;RandomStateType&gt; [id B]\n |Elemwise{Cast{int64}} [id C] &#39;&#39;\n | |MakeVector{dtype=&#39;int8&#39;} [id D] &#39;&#39;\n |   |TensorConstant{1} [id E]\n |mu_X [id F]\n |Elemwise{mul,no_inplace} [id G] &#39;&#39;\n   |InplaceDimShuffle{x} [id H] &#39;&#39;\n   | |TensorConstant{1.0} [id I]\n   |sd_X [id J]\nRandomFunction{normal}.1 [id A] &#39;~X_rv&#39;\nRandomFunction{normal}.0 [id K] &#39;&#39;\n |&lt;RandomStateType&gt; [id B]\n |Elemwise{Cast{int64}} [id L] &#39;&#39;\n | |MakeVector{dtype=&#39;int8&#39;} [id M] &#39;&#39;\n |   |TensorConstant{1} [id E]\n |mu_Y [id N]\n |Elemwise{mul,no_inplace} [id O] &#39;&#39;\n   |InplaceDimShuffle{x} [id P] &#39;&#39;\n   
| |TensorConstant{1.0} [id I]\n   |sd_Y [id Q]\nRandomFunction{normal}.1 [id K] &#39;~Y_rv&#39;\n\n<\/code><\/pre>\n<\/section>\n<section id=\"performing-high-level-simplifications\" class=\"level1\">\n<h1>Performing High-level Simplifications<\/h1>\n<p>To apply optimizations like our simple convolution, we need to first identify the appropriate circumstances for its application. This means finding all sub-graphs for which we are able to replace existing nodes with a convolution node.<\/p>\n<p>Theano provides some <a href=\"https:\/\/en.wikipedia.org\/wiki\/Unification_(computer_science)\">unification<\/a> tools that facilitate the search component. We\u2019ll use those to implement an extremely restrictive form of our convolution.<\/p>\n<div class=\"example\" data-markdown=\"\">\n<p>In <a href=\"#normal_conv_pattern\">normal_conv_pattern<\/a>, we create patterns for our expressions of interest that are unified against the elements in our graph and reified with a replacement expression. 
The patterns are expressed as tuples in a LISP-like fashion, e.g.\u00a0<code>(add, 1, 2)<\/code> corresponding to an unevaluated <code>add(1, 2)<\/code>.<\/p>\n<div class=\"sourceCode\" id=\"cb18\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb18-1\"><a href=\"#cb18-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> operator <span class=\"im\">import<\/span> attrgetter, itemgetter<\/span>\n<span id=\"cb18-2\"><a href=\"#cb18-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb18-3\"><a href=\"#cb18-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb18-4\"><a href=\"#cb18-4\" aria-hidden=\"true\"><\/a><span class=\"co\"># <\/span><span class=\"al\">FIXME<\/span><span class=\"co\">: This fixed `TensorType` specification is restrictive.<\/span><\/span>\n<span id=\"cb18-5\"><a href=\"#cb18-5\" aria-hidden=\"true\"><\/a>NormalRV <span class=\"op\">=<\/span> RandomFunction(<span class=\"st\">&#39;normal&#39;<\/span>, tt.TensorType(<span class=\"st\">&#39;float64&#39;<\/span>, (<span class=\"va\">True<\/span>,)))<\/span>\n<span id=\"cb18-6\"><a href=\"#cb18-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb18-7\"><a href=\"#cb18-7\" aria-hidden=\"true\"><\/a>norm_conv_pat_tt <span class=\"op\">=<\/span> [<\/span>\n<span id=\"cb18-8\"><a href=\"#cb18-8\" aria-hidden=\"true\"><\/a>    tt.gof.opt.PatternSub(<\/span>\n<span id=\"cb18-9\"><a href=\"#cb18-9\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># Search expression pattern<\/span><\/span>\n<span id=\"cb18-10\"><a href=\"#cb18-10\" aria-hidden=\"true\"><\/a>      (tt.add,<\/span>\n<span id=\"cb18-11\"><a href=\"#cb18-11\" aria-hidden=\"true\"><\/a>       (NormalRV, <span class=\"st\">&#39;rs_x&#39;<\/span>, <span class=\"st\">&#39;shp_x&#39;<\/span>, <span class=\"st\">&#39;mu_x&#39;<\/span>, <span class=\"st\">&#39;sd_x&#39;<\/span>),<\/span>\n<span id=\"cb18-12\"><a href=\"#cb18-12\" aria-hidden=\"true\"><\/a>       (NormalRV, <span 
class=\"st\">&#39;rs_y&#39;<\/span>, <span class=\"st\">&#39;shp_y&#39;<\/span>, <span class=\"st\">&#39;mu_y&#39;<\/span>, <span class=\"st\">&#39;sd_y&#39;<\/span>),<\/span>\n<span id=\"cb18-13\"><a href=\"#cb18-13\" aria-hidden=\"true\"><\/a>      ),<\/span>\n<span id=\"cb18-14\"><a href=\"#cb18-14\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># Replacement expression<\/span><\/span>\n<span id=\"cb18-15\"><a href=\"#cb18-15\" aria-hidden=\"true\"><\/a>      (itemgetter(<span class=\"dv\">1<\/span>), <span class=\"co\">#<\/span><\/span>\n<span id=\"cb18-16\"><a href=\"#cb18-16\" aria-hidden=\"true\"><\/a>       (NormalRV,<\/span>\n<span id=\"cb18-17\"><a href=\"#cb18-17\" aria-hidden=\"true\"><\/a>        <span class=\"st\">&#39;rs_x&#39;<\/span>,<\/span>\n<span id=\"cb18-18\"><a href=\"#cb18-18\" aria-hidden=\"true\"><\/a>        <span class=\"st\">&#39;shp_x&#39;<\/span>,<\/span>\n<span id=\"cb18-19\"><a href=\"#cb18-19\" aria-hidden=\"true\"><\/a>        (tt.add, <span class=\"st\">&#39;mu_x&#39;<\/span>, <span class=\"st\">&#39;mu_y&#39;<\/span>),<\/span>\n<span id=\"cb18-20\"><a href=\"#cb18-20\" aria-hidden=\"true\"><\/a>        (tt.sqrt, (tt.add, (tt.square, <span class=\"st\">&#39;sd_x&#39;<\/span>), (tt.square, <span class=\"st\">&#39;sd_y&#39;<\/span>))),<\/span>\n<span id=\"cb18-21\"><a href=\"#cb18-21\" aria-hidden=\"true\"><\/a>       )),<\/span>\n<span id=\"cb18-22\"><a href=\"#cb18-22\" aria-hidden=\"true\"><\/a>    ),<\/span>\n<span id=\"cb18-23\"><a href=\"#cb18-23\" aria-hidden=\"true\"><\/a>]<\/span><\/code><\/pre><\/div>\n<p>The <code>itemgetter(1)<\/code> applied to the replacement result is necessary because the <code>Op<\/code> <code>RandomFunction<\/code> returns two outputs and the second is the <code>TensorVariable<\/code> corresponding to a sample from that random variable.<\/p>\n<p>We also need to specify exactly how the pattern matching and replacement are to be performed for the entire graph. 
Do we match a single sum of normal distributions or all of them? What happens when a replacement creates yet another sum of normals that can be reduced?<\/p>\n<p>In this case, we choose to apply the operation until it reaches a fixed point, i.e.\u00a0until it produces no changes in the graph.<\/p>\n<div class=\"sourceCode\" id=\"cb19\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb19-1\"><a href=\"#cb19-1\" aria-hidden=\"true\"><\/a>norm_conv_opt_tt <span class=\"op\">=<\/span> tt.gof.opt.EquilibriumOptimizer(norm_conv_pat_tt,<\/span>\n<span id=\"cb19-2\"><a href=\"#cb19-2\" aria-hidden=\"true\"><\/a>                                                   max_use_ratio<span class=\"op\">=<\/span><span class=\"dv\">10<\/span>)<\/span><\/code><\/pre><\/div>\n<p>Finally, we manually perform our Theano optimization.<\/p>\n<div class=\"sourceCode\" id=\"cb20\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb20-1\"><a href=\"#cb20-1\" aria-hidden=\"true\"><\/a>_ <span class=\"op\">=<\/span> norm_conv_opt_tt.optimize(Z_fgraph_rv_tt)<\/span><\/code><\/pre><\/div>\n<\/div>\n<p>The optimization was applied within our graph, as evidenced by the single new <code>RandomFunction<\/code> node.<\/p>\n<div class=\"sourceCode\" id=\"cb21\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb21-1\"><a href=\"#cb21-1\" aria-hidden=\"true\"><\/a>tt.printing.debugprint(Z_fgraph_rv_tt)<\/span><\/code><\/pre><\/div>\n<pre class=\"text\"><code>RandomFunction{normal}.1 [id A] &#39;&#39;   11\n |&lt;RandomStateType&gt; [id B]\n |Elemwise{Cast{int64}} [id C] &#39;&#39;   10\n | |MakeVector{dtype=&#39;int8&#39;} [id D] &#39;&#39;   9\n |   |TensorConstant{1} [id E]\n |Elemwise{add,no_inplace} [id F] &#39;&#39;   8\n | |mu_X [id G]\n | |mu_Y [id H]\n |Elemwise{sqrt,no_inplace} [id I] &#39;&#39;   7\n   |Elemwise{add,no_inplace} [id J] &#39;&#39;   6\n     |Elemwise{sqr,no_inplace} [id K] &#39;&#39;   5\n    
 | |Elemwise{mul,no_inplace} [id L] &#39;&#39;   4\n     |   |InplaceDimShuffle{x} [id M] &#39;&#39;   3\n     |   | |TensorConstant{1.0} [id N]\n     |   |sd_X [id O]\n     |Elemwise{sqr,no_inplace} [id P] &#39;&#39;   2\n       |Elemwise{mul,no_inplace} [id Q] &#39;&#39;   1\n         |InplaceDimShuffle{x} [id R] &#39;&#39;   0\n         | |TensorConstant{1.0} [id N]\n         |sd_Y [id S]\n\n<\/code><\/pre>\n<p>Likewise, the resulting distribution terms in the optimized graph reflect the normal-normal random variable sum. Figure <a href=\"#fig:norm_sum_merge_graph\">fig:norm_sum_merge_graph<\/a> shows the graph under our optimization.<\/p>\n<div class=\"sourceCode\" id=\"cb23\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb23-1\"><a href=\"#cb23-1\" aria-hidden=\"true\"><\/a>conv_rv_tt <span class=\"op\">=<\/span> Z_fgraph_rv_tt.outputs[<span class=\"dv\">0<\/span>].owner<\/span>\n<span id=\"cb23-2\"><a href=\"#cb23-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-3\"><a href=\"#cb23-3\" aria-hidden=\"true\"><\/a>new_mu, new_sd <span class=\"op\">=<\/span> conv_rv_tt.inputs[<span class=\"dv\">2<\/span>:<span class=\"dv\">4<\/span>]<\/span>\n<span id=\"cb23-4\"><a href=\"#cb23-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-5\"><a href=\"#cb23-5\" aria-hidden=\"true\"><\/a><span class=\"co\"># Test values of the original means\/new moments&#39; inputs<\/span><\/span>\n<span id=\"cb23-6\"><a href=\"#cb23-6\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&#39;, &#39;<\/span>.join([<span class=\"st\">&#39;<\/span><span class=\"sc\">{}<\/span><span class=\"st\"> = <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&#39;<\/span>.<span class=\"bu\">format<\/span>(tt.pprint(o), o.tag.test_value)<\/span>\n<span id=\"cb23-7\"><a href=\"#cb23-7\" aria-hidden=\"true\"><\/a>                 <span class=\"cf\">for<\/span> o <span class=\"kw\">in<\/span> new_mu.owner.inputs]))<\/span>\n<span 
id=\"cb23-8\"><a href=\"#cb23-8\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(tt.pprint(new_mu))<\/span>\n<span id=\"cb23-9\"><a href=\"#cb23-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-10\"><a href=\"#cb23-10\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&#39;, &#39;<\/span>.join([<span class=\"st\">&#39;<\/span><span class=\"sc\">{}<\/span><span class=\"st\"> = <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&#39;<\/span>.<span class=\"bu\">format<\/span>(tt.pprint(o), o.tag.test_value)<\/span>\n<span id=\"cb23-11\"><a href=\"#cb23-11\" aria-hidden=\"true\"><\/a>                 <span class=\"cf\">for<\/span> o <span class=\"kw\">in<\/span> new_sd.owner.inputs]))<\/span>\n<span id=\"cb23-12\"><a href=\"#cb23-12\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(tt.pprint(new_sd))<\/span>\n<span id=\"cb23-13\"><a href=\"#cb23-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb23-14\"><a href=\"#cb23-14\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(<span class=\"st\">&#39;mean: <\/span><span class=\"sc\">{}<\/span><span class=\"ch\">\\n<\/span><span class=\"st\">std. dev.: <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&#39;<\/span>.<span class=\"bu\">format<\/span>(<\/span>\n<span id=\"cb23-15\"><a href=\"#cb23-15\" aria-hidden=\"true\"><\/a>    new_mu.tag.test_value,<\/span>\n<span id=\"cb23-16\"><a href=\"#cb23-16\" aria-hidden=\"true\"><\/a>    new_sd.tag.test_value))<\/span><\/code><\/pre><\/div>\n<pre class=\"text\"><code>mu_X = [0.], mu_Y = [1.]\n(mu_X + mu_Y)\n(sqr((TensorConstant{1.0} * sd_X)) + sqr((TensorConstant{1.0} * sd_Y))) = [1.25]\nsqrt((sqr((TensorConstant{1.0} * sd_X)) + sqr((TensorConstant{1.0} * sd_Y))))\nmean: [1.]\nstd. dev.: [1.11803399]\n\n<\/code><\/pre>\n<figure id=\"fig:norm_sum_merge_graph\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/Z_fgraph_opt_tt.png\" title=\"fig:\" alt=\"Graph of merged normal variables. 
\" \/>\n<figcaption>\nGraph of merged normal variables.\n<\/figcaption>\n<\/figure>\n<\/section>\n<section id=\"generalizing-operations\" class=\"level1\">\n<h1>Generalizing Operations<\/h1>\n<p>Our example above was admittedly too simple; for instance, what about scale- and location-transformed variables? Most models\/graphs will involve more elaborate manipulations of random variables, so we need to account for as many of the basic manipulations as possible.<\/p>\n<p>We start by adding an optimization that lifts scale parameters into the arguments\/parameters of a random variable. In other words,<\/p>\n<p><span class=\"math display\">\\[\\begin{gather*}\n  X \\sim N(\\mu, \\sigma^2) \\\\\n  Z = a X \\sim N\\left(a \\mu, (a \\sigma)^2\\right)\n  \\;.\n\\end{gather*}\\]<\/span><\/p>\n<div class=\"sourceCode\" id=\"cb25\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb25-1\"><a href=\"#cb25-1\" aria-hidden=\"true\"><\/a>norm_conv_pat_tt <span class=\"op\">+=<\/span> [<\/span>\n<span id=\"cb25-2\"><a href=\"#cb25-2\" aria-hidden=\"true\"><\/a>    tt.gof.opt.PatternSub(<\/span>\n<span id=\"cb25-3\"><a href=\"#cb25-3\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># Search expression pattern<\/span><\/span>\n<span id=\"cb25-4\"><a href=\"#cb25-4\" aria-hidden=\"true\"><\/a>        (tt.mul,<\/span>\n<span id=\"cb25-5\"><a href=\"#cb25-5\" aria-hidden=\"true\"><\/a>         <span class=\"st\">&#39;a_x&#39;<\/span>,<\/span>\n<span id=\"cb25-6\"><a href=\"#cb25-6\" aria-hidden=\"true\"><\/a>         (NormalRV, <span class=\"st\">&#39;rs_x&#39;<\/span>, <span class=\"st\">&#39;shp_x&#39;<\/span>, <span class=\"st\">&#39;mu_x&#39;<\/span>, <span class=\"st\">&#39;sd_x&#39;<\/span>)),<\/span>\n<span id=\"cb25-7\"><a href=\"#cb25-7\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># Replacement expression<\/span><\/span>\n<span id=\"cb25-8\"><a href=\"#cb25-8\" aria-hidden=\"true\"><\/a>        (itemgetter(<span 
class=\"dv\">1<\/span>),<\/span>\n<span id=\"cb25-9\"><a href=\"#cb25-9\" aria-hidden=\"true\"><\/a>         (NormalRV,<\/span>\n<span id=\"cb25-10\"><a href=\"#cb25-10\" aria-hidden=\"true\"><\/a>          <span class=\"co\"># RNG<\/span><\/span>\n<span id=\"cb25-11\"><a href=\"#cb25-11\" aria-hidden=\"true\"><\/a>                <span class=\"st\">&#39;rs_x&#39;<\/span>,<\/span>\n<span id=\"cb25-12\"><a href=\"#cb25-12\" aria-hidden=\"true\"><\/a>          <span class=\"co\"># Convolution shape<\/span><\/span>\n<span id=\"cb25-13\"><a href=\"#cb25-13\" aria-hidden=\"true\"><\/a>                <span class=\"st\">&#39;shp_x&#39;<\/span>,<\/span>\n<span id=\"cb25-14\"><a href=\"#cb25-14\" aria-hidden=\"true\"><\/a>          <span class=\"co\"># Convolution mean<\/span><\/span>\n<span id=\"cb25-15\"><a href=\"#cb25-15\" aria-hidden=\"true\"><\/a>                (tt.mul, <span class=\"st\">&#39;a_x&#39;<\/span>, <span class=\"st\">&#39;mu_x&#39;<\/span>),<\/span>\n<span id=\"cb25-16\"><a href=\"#cb25-16\" aria-hidden=\"true\"><\/a>          <span class=\"co\"># Convolution std. 
dev.<\/span><\/span>\n<span id=\"cb25-17\"><a href=\"#cb25-17\" aria-hidden=\"true\"><\/a>                (tt.mul, <span class=\"st\">&#39;a_x&#39;<\/span>, <span class=\"st\">&#39;sd_x&#39;<\/span>),<\/span>\n<span id=\"cb25-18\"><a href=\"#cb25-18\" aria-hidden=\"true\"><\/a>         )),<\/span>\n<span id=\"cb25-19\"><a href=\"#cb25-19\" aria-hidden=\"true\"><\/a>    )<\/span>\n<span id=\"cb25-20\"><a href=\"#cb25-20\" aria-hidden=\"true\"><\/a>]<\/span>\n<span id=\"cb25-21\"><a href=\"#cb25-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb25-22\"><a href=\"#cb25-22\" aria-hidden=\"true\"><\/a>norm_conv_opt_tt <span class=\"op\">=<\/span> tt.gof.opt.EquilibriumOptimizer(<\/span>\n<span id=\"cb25-23\"><a href=\"#cb25-23\" aria-hidden=\"true\"><\/a>    norm_conv_pat_tt, max_use_ratio<span class=\"op\">=<\/span><span class=\"dv\">10<\/span>)<\/span><\/code><\/pre><\/div>\n<p>The following example demonstrates the additional optimization.<\/p>\n<div class=\"sourceCode\" id=\"cb26\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb26-1\"><a href=\"#cb26-1\" aria-hidden=\"true\"><\/a>mu_X <span class=\"op\">=<\/span> tt.vector(<span class=\"st\">&#39;mu_X&#39;<\/span>)<\/span>\n<span id=\"cb26-2\"><a href=\"#cb26-2\" aria-hidden=\"true\"><\/a>mu_X.tag.test_value <span class=\"op\">=<\/span> np.array([<span class=\"fl\">0.<\/span>], dtype<span class=\"op\">=<\/span>tt.config.floatX)<\/span>\n<span id=\"cb26-3\"><a href=\"#cb26-3\" aria-hidden=\"true\"><\/a>sd_X <span class=\"op\">=<\/span> tt.vector(<span class=\"st\">&#39;sd_X&#39;<\/span>)<\/span>\n<span id=\"cb26-4\"><a href=\"#cb26-4\" aria-hidden=\"true\"><\/a>sd_X.tag.test_value <span class=\"op\">=<\/span> np.array([<span class=\"fl\">1.<\/span>], dtype<span class=\"op\">=<\/span>tt.config.floatX)<\/span>\n<span id=\"cb26-5\"><a href=\"#cb26-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-6\" 
aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> pm.Model() <span class=\"im\">as<\/span> conv_scale_model:<\/span>\n<span id=\"cb26-7\"><a href=\"#cb26-7\" aria-hidden=\"true\"><\/a>    X_rv <span class=\"op\">=<\/span> pm.Normal(<span class=\"st\">&#39;X_rv&#39;<\/span>, mu_X, sd<span class=\"op\">=<\/span>sd_X, shape<span class=\"op\">=<\/span>(<span class=\"dv\">1<\/span>,))<\/span>\n<span id=\"cb26-8\"><a href=\"#cb26-8\" aria-hidden=\"true\"><\/a>    Z_rv <span class=\"op\">=<\/span> <span class=\"dv\">5<\/span> <span class=\"op\">*<\/span> X_rv<\/span>\n<span id=\"cb26-9\"><a href=\"#cb26-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-10\"><a href=\"#cb26-10\" aria-hidden=\"true\"><\/a>Z_mul_tt <span class=\"op\">=<\/span> model_graph(conv_scale_model, derived_vars<span class=\"op\">=<\/span>[Z_rv])<\/span>\n<span id=\"cb26-11\"><a href=\"#cb26-11\" aria-hidden=\"true\"><\/a>Z_mul_rv <span class=\"op\">=<\/span> create_theano_rvs(Z_mul_tt)<\/span>\n<span id=\"cb26-12\"><a href=\"#cb26-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-13\"><a href=\"#cb26-13\" aria-hidden=\"true\"><\/a>Z_mul_rv_merged <span class=\"op\">=<\/span> Z_mul_rv.clone()<\/span>\n<span id=\"cb26-14\"><a href=\"#cb26-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb26-15\"><a href=\"#cb26-15\" aria-hidden=\"true\"><\/a>_ <span class=\"op\">=<\/span> norm_conv_opt_tt.optimize(Z_mul_rv_merged)<\/span><\/code><\/pre><\/div>\n<p><a href=\"#fig:scaled_random_sum_before\">fig:scaled_random_sum_before<\/a> and <a href=\"#fig:scaled_random_sum_after\">fig:scaled_random_sum_after<\/a> show the graph of a scaled normal random variable before and after the optimization, respectively.<\/p>\n<figure id=\"fig:scaled_random_sum_before\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/Z_mul_rv.png\" title=\"fig:\" alt=\"Graph of a single term scaled in a normal-normal convolution. 
\" \/>\n<figcaption>\nGraph of a single term scaled in a normal-normal convolution.\n<\/figcaption>\n<\/figure>\n<figure id=\"fig:scaled_random_sum_after\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/Z_mul_rv_merged.png\" title=\"fig:\" alt=\"Graph of a single term scaled in a normal-normal convolution after the convolution optimization. \" \/>\n<figcaption>\nGraph of a single term scaled in a normal-normal convolution after the convolution optimization.\n<\/figcaption>\n<\/figure>\n<\/section>\n<section id=\"challenges\" class=\"level1\">\n<h1>Challenges<\/h1>\n<p>If we change the dimensions of our example above, the pattern employed by our scaling optimization will not match. To fix this, we can generalize the form of our <code>RandomFunction<\/code> operator so that it covers more cases of broadcastable dimensions, instead of only <code>(True, )<\/code>.<\/p>\n<p>We could also extend the reach of our <code>PatternSub<\/code>s; however, this direction introduces more complexity into the process of writing optimizations and provides no foreseeable benefit elsewhere.<\/p>\n<p>More generally, one of the major challenges in this kind of work stems from the design of <code>RandomFunction<\/code>: its type depends on a <code>TensorType<\/code> parameter that requires an array of \u201cbroadcast\u201d dimensions.<\/p>\n<p>This situation arises, in part, from PyMC3, Theano, and NumPy\u2019s use of a \u201csize\u201d parameter in combination with random variable dimensions inferred from distribution parameters. A few outstanding <a href=\"https:\/\/github.com\/pymc-devs\/pymc3\/pull\/1125\">PyMC3 issues seem to revolve<\/a> around the interactions between these elements.<\/p>\n<p>The size parameter is like a sample size, but with all the samples considered together as a single tensor (e.g.\u00a0each sample of a multivariate normal random variable acting as a column in a matrix). 
The size parameter is independent of a random variable\u2019s parameters\u2019 sizes (e.g.\u00a0dimensions of a mean and covariance), but, together, the size and distribution parameters effectively compose the size\/dimension of a random variable\u2019s support (e.g.\u00a0the matrix in the above example is the resulting random variable).<\/p>\n<p>Needless to say, PyMC3 and Theano\u2019s terms (and their relation to mathematical notions) are a bit confusing, and they are likely driven more by software design choices than by the mathematical frameworks in use. However, those design choices significantly affect our ability to manipulate graphs and express common mathematical notions. For instance, these terms and design choices place greater demands on the graph manipulation steps, because the dimensions of the elements involved are ambiguous.<\/p>\n<\/section>\n<section id=\"next-steps\" class=\"level1\">\n<h1>Next Steps<\/h1>\n<p>In a follow-up, I\u2019ll introduce a new <code>Op<\/code> that overcomes some of the dimensionality issues and allows for much easier graph manipulation. 
It replaces <code>RandomFunction<\/code> with a single <code>Op<\/code> for each distribution type and [re]moves the type specifier from the definition of the <code>Op<\/code>.<\/p>\n<p>Essentially, the <code>TensorType<\/code> argument to the <code>RandomFunction<\/code> constructor is moved into <code>RandomFunction<\/code>\u2019s <code>make_node<\/code> method and is, thus, generated\/inferred from the symbolic inputs.<\/p>\n<p>To be clear, we\u2019re talking about two distinct aspects of <code>RandomFunction<\/code>: one is the <code>NormalRV = RandomFunction('normal', TensorType('float64', bcast))<\/code> step, in which we <strong>create the <code>Op<\/code><\/strong> corresponding to a specific type of normal random variable, and the other is the step in which we <strong>use the <code>Op<\/code><\/strong> (e.g.\u00a0<code>NormalRV(rng, 1, 2)<\/code>) to, say, produce a tensor variable corresponding to an instance of said random variable.<\/p>\n<p>This distinction is important for pattern matching: <code>NormalRV<\/code>, as defined above, isn\u2019t very general, mostly because <code>TensorType('float64', bcast)<\/code> covers only some Theano tensor types (i.e.\u00a0those that match the fixed broadcast dimensions specified by <code>bcast<\/code>).<\/p>\n<p>As stated previously, there have been real difficulties with the handling of shape and type information in PyMC3 (see <a href=\"https:\/\/github.com\/pymc-devs\/pymc3\/pull\/1125\">PyMC3 PR 1125<\/a>). These problems are related to the same concerns involving <code>TensorType<\/code>s. 
In refactoring the type information requirement for <code>RandomFunction<\/code>, we\u2019ll end up addressing those PyMC3 issues as well.<\/p>\n<\/section>\n<section id=\"bibliography\" class=\"level1\">\n<h1>Bibliography<\/h1>\n<p><a id=\"WillardRoleSymbolicComputation2017\"><\/a>[WillardRoleSymbolicComputation2017] Willard, A Role for Symbolic Computation in the General Estimation of Statistical Models, <i>Brandon T. Willard<\/i>, (2017). <a href=\"https:\/\/brandonwillard.github.io\/a-role-for-symbolic-computation-in-the-general-estimation-of-statistical-models.html\">link<\/a>. <a href=\"#4407b21e48ab9ff17c017e8d62684725\">\u21a9\ufe0e<\/a><\/p>\n<\/section>\n<\/body>\n<\/html>\n","category":[{"@attributes":{"term":"articles"}},{"@attributes":{"term":"pymc3"}},{"@attributes":{"term":"theano"}},{"@attributes":{"term":"statistics"}},{"@attributes":{"term":"symbolic computation"}},{"@attributes":{"term":"python"}},{"@attributes":{"term":"probability theory"}}]},{"title":"More Proximal Estimation","link":{"@attributes":{"href":"https:\/\/brandonwillard.github.io\/more-proximal-estimation.html","rel":"alternate"}},"published":"2017-03-06T00:00:00-06:00","updated":"2017-03-06T00:00:00-06:00","author":{"name":"Brandon T. Willard"},"id":"tag:brandonwillard.github.io,2017-03-06:\/more-proximal-estimation.html","summary":{"@attributes":{"type":"html"}},"content":"<!DOCTYPE html PUBLIC \"-\/\/W3C\/\/DTD XHTML 1.0 Transitional\/\/EN\" \"http:\/\/www.w3.org\/TR\/xhtml1\/DTD\/xhtml1-transitional.dtd\">\n<html xmlns=\"http:\/\/www.w3.org\/1999\/xhtml\">\n<head>\n  <meta http-equiv=\"Content-Type\" content=\"text\/html; charset=utf-8\" \/>\n  <meta http-equiv=\"Content-Style-Type\" content=\"text\/css\" \/>\n  <meta name=\"generator\" content=\"pandoc\" \/>\n  <meta name=\"author\" content=\"Brandon T. 
Willard\" \/>\n  <title>More Proximal Estimation<\/title>\n  <style type=\"text\/css\">code{white-space: pre;}<\/style>\n  <style type=\"text\/css\">\npre > code.sourceCode { white-space: pre; position: relative; }\npre > code.sourceCode > span { display: inline-block; line-height: 1.25; }\npre > code.sourceCode > span:empty { height: 1.2em; }\ncode.sourceCode > span { color: inherit; text-decoration: inherit; }\ndiv.sourceCode { margin: 1em 0; }\npre.sourceCode { margin: 0; }\n@media screen {\ndiv.sourceCode { overflow: auto; }\n}\n@media print {\npre > code.sourceCode { white-space: pre-wrap; }\npre > code.sourceCode > span { text-indent: -5em; padding-left: 5em; }\n}\npre.numberSource code\n  { counter-reset: source-line 0; }\npre.numberSource code > span\n  { position: relative; left: -4em; counter-increment: source-line; }\npre.numberSource code > span > a:first-child::before\n  { content: counter(source-line);\n    position: relative; left: -1em; text-align: right; vertical-align: baseline;\n    border: none; display: inline-block;\n    -webkit-touch-callout: none; -webkit-user-select: none;\n    -khtml-user-select: none; -moz-user-select: none;\n    -ms-user-select: none; user-select: none;\n    padding: 0 4px; width: 4em;\n    color: #aaaaaa;\n  }\npre.numberSource { margin-left: 3em; border-left: 1px solid #aaaaaa;  padding-left: 4px; }\ndiv.sourceCode\n  {   }\n@media screen {\npre > code.sourceCode > span > a:first-child::before { text-decoration: underline; }\n}\ncode span.al { color: #ff0000; font-weight: bold; } \/* Alert *\/\ncode span.an { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Annotation *\/\ncode span.at { color: #7d9029; } \/* Attribute *\/\ncode span.bn { color: #40a070; } \/* BaseN *\/\ncode span.bu { } \/* BuiltIn *\/\ncode span.cf { color: #007020; font-weight: bold; } \/* ControlFlow *\/\ncode span.ch { color: #4070a0; } \/* Char *\/\ncode span.cn { color: #880000; } \/* Constant *\/\ncode span.co { color: #60a0b0; 
font-style: italic; } \/* Comment *\/\ncode span.cv { color: #60a0b0; font-weight: bold; font-style: italic; } \/* CommentVar *\/\ncode span.do { color: #ba2121; font-style: italic; } \/* Documentation *\/\ncode span.dt { color: #902000; } \/* DataType *\/\ncode span.dv { color: #40a070; } \/* DecVal *\/\ncode span.er { color: #ff0000; font-weight: bold; } \/* Error *\/\ncode span.ex { } \/* Extension *\/\ncode span.fl { color: #40a070; } \/* Float *\/\ncode span.fu { color: #06287e; } \/* Function *\/\ncode span.im { } \/* Import *\/\ncode span.in { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Information *\/\ncode span.kw { color: #007020; font-weight: bold; } \/* Keyword *\/\ncode span.op { color: #666666; } \/* Operator *\/\ncode span.ot { color: #007020; } \/* Other *\/\ncode span.pp { color: #bc7a00; } \/* Preprocessor *\/\ncode span.sc { color: #4070a0; } \/* SpecialChar *\/\ncode span.ss { color: #bb6688; } \/* SpecialString *\/\ncode span.st { color: #4070a0; } \/* String *\/\ncode span.va { color: #19177c; } \/* Variable *\/\ncode span.vs { color: #4070a0; } \/* VerbatimString *\/\ncode span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Warning *\/\n  <\/style>\n  <!--        <script src=\"https:\/\/cdn.jsdelivr.net\/npm\/mathjax@3\/es5\/tex-mml-chtml.js\" type=\"text\/javascript\"><\/script> -->\n  <script src=\"https:\/\/cdnjs.cloudflare.com\/ajax\/libs\/mathjax\/2.7.0\/MathJax.js?config=TeX-AMS_HTML\" id=\"MathJax-script\"><\/script>\n  <script>\n   MathJax.Hub.Config({\n       tex2jax: {\n           processEnvironments: true,\n           processRefs: false\n       },\n       TeX: {\n           equationNumbers: { autoNumber: \"AMS\" },\n           extensions: [\"AMSmath.js\",\"AMSsymbols.js\",\"noErrors.js\",\"noUndefined.js\"]\n       }\n   });\n  <\/script>\n<\/head>\n<body>\n<!--  -->\n<!-- <div id=\"header\"> -->\n<!-- <h1 class=\"title\">More Proximal Estimation<\/h1> -->\n<!--  -->\n<!--  -->\n<!-- <h2 
class=\"author\">Brandon T. Willard<\/h2> -->\n<!--  -->\n<!--  -->\n<!-- <h3 class=\"date\">2017\u201303\u201306<\/h3> -->\n<!--  -->\n<!-- <\/div> -->\n<!--  -->\n<section id=\"introduction\" class=\"level1\">\n<h1>Introduction<\/h1>\n<p>The focal point of this short exposition will be an elaboration of the basic <span class=\"math inline\">\\(\\ell_1\\)<\/span> penalization problem discussed in <span class=\"citation\" data-cites=\"willard_role_2017\">Willard (2017)<\/span>, <span class=\"math display\">\\[\\begin{equation}\n\\operatorname*{argmin}_{\\beta} \\left\\{\n  \\frac{1}{2} \\|y - X \\beta\\|^2_2\n    + \\lambda \\|\\beta\\|_1\n  \\right\\}\n  \\;.\n  \\label{eq:lasso}\n\\end{equation}\\]<\/span> We continue our discussion on topics concerning automation and symbolic computation in Theano <span class=\"citation\" data-cites=\"bergstra_theano_2010\">(Bergstra et al. 2010)<\/span>, as well as the mathematical methodology we believe is suitable for such implementations. Again, our framing of the problem is in terms of \u201cproximal methods\u201d <span class=\"citation\" data-cites=\"parikh_proximal_2014 combettes_proximal_2011\">(Parikh and Boyd 2014; Combettes and Pesquet 2011)<\/span>. Along the way we propose one simple means of placing the well-known technique of coordinate descent within the scope of proximal methods via a general property of proximal operators. 
These efforts are a continued outgrowth of our work in <span class=\"citation\" data-cites=\"polson_proximal_2015\">Polson, Scott, and Willard (2015)<\/span>.<\/p>\n<\/section>\n<section id=\"proximal-and-computational-components\" class=\"level1\">\n<h1>Proximal and Computational Components<\/h1>\n<p>First, we [re]-introduce the workhorse of proximal methods: the <em>proximal operator<\/em>.<\/p>\n<div class=\"definition\" data-markdown=\"\" data-title-name=\"[Proximal Operator]\">\n<p><span class=\"math display\">\\[\\begin{equation*}\n\\operatorname*{prox}_{\\phi}(x) =\n    \\operatorname*{argmin}_{z} \\left\\{\n    \\frac{1}{2} \\left(z - x\\right)^2 + \\phi(z)\n    \\right\\}\n    \\;.\n\\end{equation*}\\]<\/span><\/p>\n<\/div>\n<p>Inspired by Equation\u00a0<span class=\"math inline\">\\(\\eqref{eq:lasso}\\)<\/span>, we produce a toy dataset as follows:<\/p>\n<div class=\"sourceCode\" id=\"cb1\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb1-1\"><a href=\"#cb1-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano <span class=\"im\">import<\/span> shared <span class=\"im\">as<\/span> tt_shared<\/span>\n<span id=\"cb1-2\"><a href=\"#cb1-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-3\"><a href=\"#cb1-3\" aria-hidden=\"true\"><\/a>M <span class=\"op\">=<\/span> <span class=\"dv\">50<\/span><\/span>\n<span id=\"cb1-4\"><a href=\"#cb1-4\" aria-hidden=\"true\"><\/a>M_nonzero <span class=\"op\">=<\/span> M <span class=\"op\">*<\/span> <span class=\"dv\">2<\/span> <span class=\"op\">\/\/<\/span> <span class=\"dv\">10<\/span><\/span>\n<span id=\"cb1-5\"><a href=\"#cb1-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-6\"><a href=\"#cb1-6\" aria-hidden=\"true\"><\/a>beta_true <span class=\"op\">=<\/span> np.zeros(M)<\/span>\n<span id=\"cb1-7\"><a href=\"#cb1-7\" aria-hidden=\"true\"><\/a>beta_true[:M_nonzero] <span class=\"op\">=<\/span> np.exp(<span class=\"op\">-<\/span>np.arange(M_nonzero)) <span 
class=\"op\">*<\/span> <span class=\"dv\">100<\/span><\/span>\n<span id=\"cb1-8\"><a href=\"#cb1-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-9\"><a href=\"#cb1-9\" aria-hidden=\"true\"><\/a>N <span class=\"op\">=<\/span> <span class=\"bu\">int<\/span>(<span class=\"bu\">len<\/span>(beta_true) <span class=\"op\">*<\/span> <span class=\"fl\">0.4<\/span>)<\/span>\n<span id=\"cb1-10\"><a href=\"#cb1-10\" aria-hidden=\"true\"><\/a>X <span class=\"op\">=<\/span> np.random.randn(N, M)<\/span>\n<span id=\"cb1-11\"><a href=\"#cb1-11\" aria-hidden=\"true\"><\/a>mu_true <span class=\"op\">=<\/span> X.dot(beta_true)<\/span>\n<span id=\"cb1-12\"><a href=\"#cb1-12\" aria-hidden=\"true\"><\/a>y <span class=\"op\">=<\/span> mu_true <span class=\"op\">+<\/span> sc.stats.norm.rvs(np.zeros(N), scale<span class=\"op\">=<\/span><span class=\"dv\">10<\/span>)<\/span>\n<span id=\"cb1-13\"><a href=\"#cb1-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-14\"><a href=\"#cb1-14\" aria-hidden=\"true\"><\/a>X_tt <span class=\"op\">=<\/span> tt_shared(X, name<span class=\"op\">=<\/span><span class=\"st\">&#39;X&#39;<\/span>, borrow<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb1-15\"><a href=\"#cb1-15\" aria-hidden=\"true\"><\/a>y_tt <span class=\"op\">=<\/span> tt_shared(y, name<span class=\"op\">=<\/span><span class=\"st\">&#39;y&#39;<\/span>, borrow<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb1-16\"><a href=\"#cb1-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-17\"><a href=\"#cb1-17\" aria-hidden=\"true\"><\/a><span class=\"co\"># Estimation starting parameters...<\/span><\/span>\n<span id=\"cb1-18\"><a href=\"#cb1-18\" aria-hidden=\"true\"><\/a>beta_0 <span class=\"op\">=<\/span> np.zeros(X.shape[<span class=\"dv\">1<\/span>]).astype(<span class=\"st\">&#39;float64&#39;<\/span>)<\/span>\n<span id=\"cb1-19\"><a href=\"#cb1-19\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-20\" 
aria-hidden=\"true\"><\/a><span class=\"co\"># Gradient [starting] step size<\/span><\/span>\n<span id=\"cb1-21\"><a href=\"#cb1-21\" aria-hidden=\"true\"><\/a>alpha_0 <span class=\"op\">=<\/span> <span class=\"fl\">1.<\/span> <span class=\"op\">\/<\/span> np.linalg.norm(X, <span class=\"dv\">2<\/span>)<span class=\"op\">**<\/span><span class=\"dv\">2<\/span><\/span>\n<span id=\"cb1-22\"><a href=\"#cb1-22\" aria-hidden=\"true\"><\/a><span class=\"co\"># np.linalg.matrix_rank(X)<\/span><\/span>\n<span id=\"cb1-23\"><a href=\"#cb1-23\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-24\"><a href=\"#cb1-24\" aria-hidden=\"true\"><\/a><span class=\"co\"># Regularization value heuristic<\/span><\/span>\n<span id=\"cb1-25\"><a href=\"#cb1-25\" aria-hidden=\"true\"><\/a><span class=\"co\"># beta_ols = np.linalg.lstsq(X, y)[0]<\/span><\/span>\n<span id=\"cb1-26\"><a href=\"#cb1-26\" aria-hidden=\"true\"><\/a><span class=\"co\"># lambda_max = 0.1 * np.linalg.norm(beta_ols, np.inf)<\/span><\/span>\n<span id=\"cb1-27\"><a href=\"#cb1-27\" aria-hidden=\"true\"><\/a>lambda_max <span class=\"op\">=<\/span> np.linalg.norm(X.T.dot(y), np.inf)<\/span><\/code><\/pre><\/div>\n<p>As in <span class=\"citation\" data-cites=\"willard_role_2017\">Willard (2017)<\/span>, we can start with a model defined within a system like PyMC3 <span class=\"citation\" data-cites=\"salvatier_probabilistic_2016\">(Salvatier, Wiecki, and Fonnesbeck 2016)<\/span>.<\/p>\n<div class=\"sourceCode\" id=\"cb2\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb2-1\"><a href=\"#cb2-1\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> pm.Model() <span class=\"im\">as<\/span> lasso_model:<\/span>\n<span id=\"cb2-2\"><a href=\"#cb2-2\" aria-hidden=\"true\"><\/a>    beta_rv <span class=\"op\">=<\/span> pm.Laplace(<span class=\"st\">&#39;beta&#39;<\/span>, mu<span class=\"op\">=<\/span><span class=\"dv\">0<\/span>, b<span class=\"op\">=<\/span><span 
class=\"dv\">1<\/span>,<\/span>\n<span id=\"cb2-3\"><a href=\"#cb2-3\" aria-hidden=\"true\"><\/a>                         shape<span class=\"op\">=<\/span>X.shape[<span class=\"dv\">1<\/span>])<\/span>\n<span id=\"cb2-4\"><a href=\"#cb2-4\" aria-hidden=\"true\"><\/a>    y_rv <span class=\"op\">=<\/span> pm.Normal(<span class=\"st\">&#39;y&#39;<\/span>, mu<span class=\"op\">=<\/span>X_tt.dot(beta_rv), sd<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>,<\/span>\n<span id=\"cb2-5\"><a href=\"#cb2-5\" aria-hidden=\"true\"><\/a>                     shape<span class=\"op\">=<\/span>y.shape[<span class=\"dv\">0<\/span>], observed<span class=\"op\">=<\/span>y_tt)<\/span><\/code><\/pre><\/div>\n<p>In this setting, one might then derive the necessary estimation steps automatically (i.e.\u00a0identify the underlying <span class=\"math inline\">\\(\\ell_1\\)<\/span> estimation problem). We discuss this more in <span class=\"citation\" data-cites=\"willard_role_2017\">Willard (2017)<\/span>.<\/p>\n<p>For simplicity, we\u2019ll just assume that all components of the estimation problem (i.e.\u00a0the loss and penalty functions) are known. The proximal operator that arises in this standard example is the <em>soft thresholding<\/em> operator. In Theano, it can be implemented as follows:<\/p>\n<div class=\"sourceCode\" id=\"cb3\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb3-1\"><a href=\"#cb3-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> tt_soft_threshold(beta_, lambda_):<\/span>\n<span id=\"cb3-2\"><a href=\"#cb3-2\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> tt.sgn(beta_) <span class=\"op\">*<\/span> tt.maximum(tt.abs_(beta_) <span class=\"op\">-<\/span> lambda_, <span class=\"dv\">0<\/span>)<\/span><\/code><\/pre><\/div>\n<div class=\"remark\" data-markdown=\"\" data-title-name=\"\">\n<p>This operator can take other forms, and the one used here is not particularly special. 
For instance, the <code>maximum<\/code> can be replaced by other conditional-like statements, such as <span class=\"math display\">\\[\\begin{equation*}\n\\operatorname{S}(\\beta, \\lambda) =\n    \\begin{cases}\n     {\\mathop{\\mathrm{sgn}}}(\\beta) (\\lvert \\beta \\rvert - \\lambda) &amp; \\lvert \\beta \\rvert &gt; \\lambda\n     \\\\\n     0 &amp; \\text{otherwise}\n    \\end{cases}\n    \\;.\n\\end{equation*}\\]<\/span> If we were to, say, multiply the output of this operator by another, more expensive result, then we might also wish to \u201coptimize\u201d this implementation by pushing the multiplication into the definition of the operator and avoiding its computation altogether in the <span class=\"math inline\">\\(\\lvert \\beta \\rvert \\leq \\lambda\\)<\/span> case.<\/p>\n<p>Barring any reuse of this quantity, or a need to preserve undefined results produced by an expensive product with zero, we would really like a \u201ccompiler\u201d to make such an optimization itself. It isn\u2019t clear how a standard compiler (or interpreter\/hybrid) could safely make this optimization, whereas it does seem more reasonable as a symbolic\/Theano optimization.<\/p>\n<p>Optimizations like this are, I think, a necessary step to enable expressive, generalized methods and truly rapid prototyping at the math level.<\/p>\n<\/div>\n<p>Now, assuming that we\u2019ve obtained the relevant loss and penalty functions (for example, from PyMC3), we can proceed by setting up the exact context of our proximal problem.<\/p>\n<div class=\"sourceCode\" id=\"cb4\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb4-1\"><a href=\"#cb4-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano <span class=\"im\">import<\/span> clone <span class=\"im\">as<\/span> tt_clone<\/span>\n<span id=\"cb4-2\"><a href=\"#cb4-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb4-3\"><a href=\"#cb4-3\" aria-hidden=\"true\"><\/a><span class=\"co\"># Clone the 
negative log-likelihood of our observation model.<\/span><\/span>\n<span id=\"cb4-4\"><a href=\"#cb4-4\" aria-hidden=\"true\"><\/a>nlogl_rv <span class=\"op\">=<\/span> <span class=\"op\">-<\/span>lasso_model.observed_RVs[<span class=\"dv\">0<\/span>].logpt<\/span>\n<span id=\"cb4-5\"><a href=\"#cb4-5\" aria-hidden=\"true\"><\/a>nlogl <span class=\"op\">=<\/span> tt_clone(nlogl_rv)<\/span>\n<span id=\"cb4-6\"><a href=\"#cb4-6\" aria-hidden=\"true\"><\/a>nlogl.name <span class=\"op\">=<\/span> <span class=\"st\">&quot;-logl&quot;<\/span><\/span>\n<span id=\"cb4-7\"><a href=\"#cb4-7\" aria-hidden=\"true\"><\/a>beta_tt <span class=\"op\">=<\/span> tt_inputs([nlogl])[<span class=\"dv\">4<\/span>]<\/span><\/code><\/pre><\/div>\n<\/section>\n<section id=\"proximal-gradient\" class=\"level1\">\n<h1>Proximal Gradient<\/h1>\n<p>In what follows, it will be convenient to generalize a bit and work in terms of arbitrary loss and penalty functions <span class=\"math inline\">\\(l\\)<\/span> and <span class=\"math inline\">\\(\\phi\\)<\/span>, respectively, which in our case are <span class=\"math display\">\\[\\begin{equation*}\n\\begin{gathered}\n  l(\\beta) = \\frac12 \\|y - X \\beta\\|^2_2, \\quad\n  \\text{and}\\;\n  \\phi(\\beta) = \\|\\beta\\|_1\n  \\;.\\end{gathered}\n\\end{equation*}\\]<\/span><\/p>\n<p>The proximal gradient algorithm <span class=\"citation\" data-cites=\"combettes_proximal_2011\">(Combettes and Pesquet 2011)<\/span> is a staple of the proximal framework that provides solutions to problems of the form <span class=\"math display\">\\[\\begin{equation*}\n\\operatorname*{argmin}_\\beta \\left\\{\n    l(\\beta) + \\lambda \\phi(\\beta)\n  \\right\\}\n  \\;,\n\\end{equation*}\\]<\/span> when both <span class=\"math inline\">\\(l\\)<\/span> and <span class=\"math inline\">\\(\\phi\\)<\/span> are lower semi-continuous convex functions, and <span class=\"math inline\">\\(l\\)<\/span> is differentiable with a Lipschitz continuous gradient.<\/p>\n<p>The solution is 
given by the following fixed point: <span class=\"math display\">\\[\\begin{equation}\n\\beta = \\operatorname*{prox}_{\\alpha \\lambda \\phi}(\\beta - \\alpha \\nabla l(\\beta))\n  \\;.\n  \\label{eq:forward-backward}\n\\end{equation}\\]<\/span> The constant step size <span class=\"math inline\">\\(\\alpha\\)<\/span> is related to the Lipschitz constant of <span class=\"math inline\">\\(\\nabla l\\)<\/span>, but can also be a sequence obeying certain constraints. Since our <span class=\"math inline\">\\(l\\)<\/span> is the squared <span class=\"math inline\">\\(\\ell_2\\)<\/span> loss, we have the standard gradient <span class=\"math inline\">\\(\\nabla l(\\beta) = X^\\top (X \\beta - y)\\)<\/span>, whose Lipschitz constant is <span class=\"math inline\">\\(\\|X\\|_2^2\\)<\/span> (the reciprocal of the starting step size <code>alpha_0<\/code> used above).<\/p>\n<section id=\"implementation\" class=\"level2\">\n<h2>Implementation<\/h2>\n<p>As in <span class=\"citation\" data-cites=\"willard_role_2017\">Willard (2017)<\/span>, we provide an implementation of a proximal gradient step.<\/p>\n<div class=\"sourceCode\" id=\"cb5\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb5-1\"><a href=\"#cb5-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano <span class=\"im\">import<\/span> function <span class=\"im\">as<\/span> tt_function<\/span>\n<span id=\"cb5-2\"><a href=\"#cb5-2\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano.<span class=\"bu\">compile<\/span>.nanguardmode <span class=\"im\">import<\/span> NanGuardMode<\/span>\n<span id=\"cb5-3\"><a href=\"#cb5-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-4\"><a href=\"#cb5-4\" aria-hidden=\"true\"><\/a>tt_func_mode <span class=\"op\">=<\/span> NanGuardMode(nan_is_error<span class=\"op\">=<\/span><span class=\"va\">True<\/span>,<\/span>\n<span id=\"cb5-5\"><a href=\"#cb5-5\" aria-hidden=\"true\"><\/a>                            inf_is_error<span class=\"op\">=<\/span><span class=\"va\">False<\/span>,<\/span>\n<span id=\"cb5-6\"><a href=\"#cb5-6\" aria-hidden=\"true\"><\/a>                          
  big_is_error<span class=\"op\">=<\/span><span class=\"va\">False<\/span>)<\/span>\n<span id=\"cb5-7\"><a href=\"#cb5-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-8\"><a href=\"#cb5-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-9\"><a href=\"#cb5-9\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> prox_gradient_step(loss, beta_tt, prox_func,<\/span>\n<span id=\"cb5-10\"><a href=\"#cb5-10\" aria-hidden=\"true\"><\/a>                       alpha_tt<span class=\"op\">=<\/span><span class=\"va\">None<\/span>, lambda_tt<span class=\"op\">=<\/span><span class=\"va\">None<\/span>,<\/span>\n<span id=\"cb5-11\"><a href=\"#cb5-11\" aria-hidden=\"true\"><\/a>                       return_loss_grad<span class=\"op\">=<\/span><span class=\"va\">False<\/span>,<\/span>\n<span id=\"cb5-12\"><a href=\"#cb5-12\" aria-hidden=\"true\"><\/a>                       tt_func_kwargs<span class=\"op\">=<\/span>{<span class=\"st\">&#39;mode&#39;<\/span>: tt_func_mode}<\/span>\n<span id=\"cb5-13\"><a href=\"#cb5-13\" aria-hidden=\"true\"><\/a>                       ):<\/span>\n<span id=\"cb5-14\"><a href=\"#cb5-14\" aria-hidden=\"true\"><\/a>    <span class=\"co\">r&quot;&quot;&quot; Creates a function that produces a proximal gradient step.<\/span><\/span>\n<span id=\"cb5-15\"><a href=\"#cb5-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-16\"><a href=\"#cb5-16\" aria-hidden=\"true\"><\/a><span class=\"co\">    Arguments<\/span><\/span>\n<span id=\"cb5-17\"><a href=\"#cb5-17\" aria-hidden=\"true\"><\/a><span class=\"co\">    =========<\/span><\/span>\n<span id=\"cb5-18\"><a href=\"#cb5-18\" aria-hidden=\"true\"><\/a><span class=\"co\">    loss: TensorVariable<\/span><\/span>\n<span id=\"cb5-19\"><a href=\"#cb5-19\" aria-hidden=\"true\"><\/a><span class=\"co\">        Continuously differentiable &quot;loss&quot; function in the objective<\/span><\/span>\n<span id=\"cb5-20\"><a href=\"#cb5-20\" aria-hidden=\"true\"><\/a><span class=\"co\">        
function.<\/span><\/span>\n<span id=\"cb5-21\"><a href=\"#cb5-21\" aria-hidden=\"true\"><\/a><span class=\"co\">    beta_tt: TensorVariable<\/span><\/span>\n<span id=\"cb5-22\"><a href=\"#cb5-22\" aria-hidden=\"true\"><\/a><span class=\"co\">        Variable argument of the loss function.<\/span><\/span>\n<span id=\"cb5-23\"><a href=\"#cb5-23\" aria-hidden=\"true\"><\/a><span class=\"co\">    prox_func: function<\/span><\/span>\n<span id=\"cb5-24\"><a href=\"#cb5-24\" aria-hidden=\"true\"><\/a><span class=\"co\">        Function that computes the proximal operator for the &quot;penalty&quot;<\/span><\/span>\n<span id=\"cb5-25\"><a href=\"#cb5-25\" aria-hidden=\"true\"><\/a><span class=\"co\">        function.  Must take two parameters: the first, a TensorVariable<\/span><\/span>\n<span id=\"cb5-26\"><a href=\"#cb5-26\" aria-hidden=\"true\"><\/a><span class=\"co\">        holding the result of the gradient step; the second, a float or<\/span><\/span>\n<span id=\"cb5-27\"><a href=\"#cb5-27\" aria-hidden=\"true\"><\/a><span class=\"co\">        Scalar value.<\/span><\/span>\n<span id=\"cb5-28\"><a href=\"#cb5-28\" aria-hidden=\"true\"><\/a><span class=\"co\">    alpha_tt: float, Scalar (optional)<\/span><\/span>\n<span id=\"cb5-29\"><a href=\"#cb5-29\" aria-hidden=\"true\"><\/a><span class=\"co\">        Gradient step size.<\/span><\/span>\n<span id=\"cb5-30\"><a href=\"#cb5-30\" aria-hidden=\"true\"><\/a><span class=\"co\">    lambda_tt: float, Scalar (optional)<\/span><\/span>\n<span id=\"cb5-31\"><a href=\"#cb5-31\" aria-hidden=\"true\"><\/a><span class=\"co\">        Additional scalar value passed to `prox_func`.<\/span><\/span>\n<span id=\"cb5-32\"><a href=\"#cb5-32\" aria-hidden=\"true\"><\/a><span class=\"co\">        <\/span><span class=\"al\">TODO<\/span><span class=\"co\">: Not sure if this should be here; is redundant.<\/span><\/span>\n<span id=\"cb5-33\"><a href=\"#cb5-33\" aria-hidden=\"true\"><\/a><span class=\"co\">    &quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb5-34\"><a 
href=\"#cb5-34\" aria-hidden=\"true\"><\/a>    loss_grad <span class=\"op\">=<\/span> tt.grad(loss, wrt<span class=\"op\">=<\/span>beta_tt)<\/span>\n<span id=\"cb5-35\"><a href=\"#cb5-35\" aria-hidden=\"true\"><\/a>    loss_grad.name <span class=\"op\">=<\/span> <span class=\"st\">&quot;loss_grad&quot;<\/span><\/span>\n<span id=\"cb5-36\"><a href=\"#cb5-36\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-37\"><a href=\"#cb5-37\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> alpha_tt <span class=\"kw\">is<\/span> <span class=\"va\">None<\/span>:<\/span>\n<span id=\"cb5-38\"><a href=\"#cb5-38\" aria-hidden=\"true\"><\/a>        alpha_tt <span class=\"op\">=<\/span> tt.scalar(name<span class=\"op\">=<\/span><span class=\"st\">&#39;alpha&#39;<\/span>)<\/span>\n<span id=\"cb5-39\"><a href=\"#cb5-39\" aria-hidden=\"true\"><\/a>        alpha_tt.tag.test_value <span class=\"op\">=<\/span> <span class=\"dv\">1<\/span><\/span>\n<span id=\"cb5-40\"><a href=\"#cb5-40\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> lambda_tt <span class=\"kw\">is<\/span> <span class=\"va\">None<\/span>:<\/span>\n<span id=\"cb5-41\"><a href=\"#cb5-41\" aria-hidden=\"true\"><\/a>        lambda_tt <span class=\"op\">=<\/span> tt.scalar(name<span class=\"op\">=<\/span><span class=\"st\">&#39;lambda&#39;<\/span>)<\/span>\n<span id=\"cb5-42\"><a href=\"#cb5-42\" aria-hidden=\"true\"><\/a>        lambda_tt.tag.test_value <span class=\"op\">=<\/span> <span class=\"dv\">1<\/span><\/span>\n<span id=\"cb5-43\"><a href=\"#cb5-43\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-44\"><a href=\"#cb5-44\" aria-hidden=\"true\"><\/a>    beta_grad_step <span class=\"op\">=<\/span> beta_tt <span class=\"op\">-<\/span> alpha_tt <span class=\"op\">*<\/span> loss_grad<\/span>\n<span id=\"cb5-45\"><a href=\"#cb5-45\" aria-hidden=\"true\"><\/a>    beta_grad_step.name <span class=\"op\">=<\/span> <span class=\"st\">&quot;beta_grad_step&quot;<\/span><\/span>\n<span id=\"cb5-46\"><a 
href=\"#cb5-46\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-47\"><a href=\"#cb5-47\" aria-hidden=\"true\"><\/a>    prox_grad_step <span class=\"op\">=<\/span> prox_func(beta_grad_step, lambda_tt <span class=\"op\">*<\/span> alpha_tt)<\/span>\n<span id=\"cb5-48\"><a href=\"#cb5-48\" aria-hidden=\"true\"><\/a>    prox_grad_step.name <span class=\"op\">=<\/span> <span class=\"st\">&quot;prox_grad_step&quot;<\/span><\/span>\n<span id=\"cb5-49\"><a href=\"#cb5-49\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-50\"><a href=\"#cb5-50\" aria-hidden=\"true\"><\/a>    inputs <span class=\"op\">=<\/span> []<\/span>\n<span id=\"cb5-51\"><a href=\"#cb5-51\" aria-hidden=\"true\"><\/a>    updates <span class=\"op\">=<\/span> <span class=\"va\">None<\/span><\/span>\n<span id=\"cb5-52\"><a href=\"#cb5-52\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> <span class=\"bu\">isinstance<\/span>(beta_tt, tt.sharedvar.SharedVariable):<\/span>\n<span id=\"cb5-53\"><a href=\"#cb5-53\" aria-hidden=\"true\"><\/a>        updates <span class=\"op\">=<\/span> [(beta_tt, prox_grad_step)]<\/span>\n<span id=\"cb5-54\"><a href=\"#cb5-54\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"cb5-55\"><a href=\"#cb5-55\" aria-hidden=\"true\"><\/a>        inputs <span class=\"op\">+=<\/span> [beta_tt]<\/span>\n<span id=\"cb5-56\"><a href=\"#cb5-56\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> <span class=\"kw\">not<\/span> <span class=\"bu\">isinstance<\/span>(alpha_tt, tt.sharedvar.SharedVariable):<\/span>\n<span id=\"cb5-57\"><a href=\"#cb5-57\" aria-hidden=\"true\"><\/a>        inputs <span class=\"op\">+=<\/span> [alpha_tt]<\/span>\n<span id=\"cb5-58\"><a href=\"#cb5-58\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> <span class=\"kw\">not<\/span> <span class=\"bu\">isinstance<\/span>(lambda_tt, tt.sharedvar.SharedVariable):<\/span>\n<span id=\"cb5-59\"><a href=\"#cb5-59\" aria-hidden=\"true\"><\/a>        inputs 
<span class=\"op\">+=<\/span> [lambda_tt]<\/span>\n<span id=\"cb5-60\"><a href=\"#cb5-60\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-61\"><a href=\"#cb5-61\" aria-hidden=\"true\"><\/a>    prox_grad_step_fn <span class=\"op\">=<\/span> tt_function(inputs,<\/span>\n<span id=\"cb5-62\"><a href=\"#cb5-62\" aria-hidden=\"true\"><\/a>                                    prox_grad_step,<\/span>\n<span id=\"cb5-63\"><a href=\"#cb5-63\" aria-hidden=\"true\"><\/a>                                    updates<span class=\"op\">=<\/span>updates,<\/span>\n<span id=\"cb5-64\"><a href=\"#cb5-64\" aria-hidden=\"true\"><\/a>                                    <span class=\"op\">**<\/span>tt_func_kwargs)<\/span>\n<span id=\"cb5-65\"><a href=\"#cb5-65\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-66\"><a href=\"#cb5-66\" aria-hidden=\"true\"><\/a>    res <span class=\"op\">=<\/span> (prox_grad_step_fn,)<\/span>\n<span id=\"cb5-67\"><a href=\"#cb5-67\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> return_loss_grad:<\/span>\n<span id=\"cb5-68\"><a href=\"#cb5-68\" aria-hidden=\"true\"><\/a>        res <span class=\"op\">+=<\/span> (loss_grad,)<\/span>\n<span id=\"cb5-69\"><a href=\"#cb5-69\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-70\"><a href=\"#cb5-70\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> res<\/span><\/code><\/pre><\/div>\n<\/section>\n<section id=\"step-sizes\" class=\"level2\">\n<h2>Step Sizes<\/h2>\n<p>A critical aspect of the proximal gradient approach\u2013and of most optimization methods\u2013is the use of an appropriate step size, <span class=\"math inline\">\\(\\alpha\\)<\/span>. Step sizes needn\u2019t be fixed values, so we can search for a suitable value during estimation. 
Furthermore, in some cases, step sizes can be sequences amenable to acceleration techniques <span class=\"citation\" data-cites=\"beck_fast_2014\">(Beck and Teboulle 2014)<\/span>.<\/p>\n<p>Step sizes, and the values that drive them, do more than ensure convergence: they are closely tied to the performance of an optimization method. As a result, the power of an implementation can depend on which types of step size sequences it supports and when it can use them.<\/p>\n<p>Often, acceptable ranges of step size values are derived from broad properties of the functions involved and their gradients (e.g.\u00a0Lipschitz continuity). When explicitly parameterized, these properties can give meaning to what some call \u201ctuning parameters\u201d. The connections between function-analytic properties and \u201ctuning parameters\u201d highlight the need for more mathematical coverage\/symbolic assessment within implementations. Currently, most tuning parameters act as stand-ins for information that could, in principle, be obtained from the known functions.<\/p>\n<p>In this spirit, one particularly relevant direction of work can be found in Theano\u2019s experimental matrix \u201cHints\u201d. 
Matrix-property hints like <code>theano.sandbox.linalg.ops.{psd, spectral_radius_bound}<\/code> are good examples of the machinery needed to automatically determine applicable and efficient <span class=\"math inline\">\\(\\alpha\\)<\/span> constants and sequences.<\/p>\n<p>For our example, we will simply use backtracking line-search.<\/p>\n<div class=\"sourceCode\" id=\"cb6\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb6-1\"><a href=\"#cb6-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> backtracking_search(beta_, alpha_,<\/span>\n<span id=\"cb6-2\"><a href=\"#cb6-2\" aria-hidden=\"true\"><\/a>                        prox_fn, loss_fn, loss_grad_fn,<\/span>\n<span id=\"cb6-3\"><a href=\"#cb6-3\" aria-hidden=\"true\"><\/a>                        lambda_<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>, bt_rate<span class=\"op\">=<\/span><span class=\"fl\">0.5<\/span>, obj_tol<span class=\"op\">=<\/span><span class=\"fl\">1e-5<\/span>):<\/span>\n<span id=\"cb6-4\"><a href=\"#cb6-4\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># alpha_start = alpha_<\/span><\/span>\n<span id=\"cb6-5\"><a href=\"#cb6-5\" aria-hidden=\"true\"><\/a>    z <span class=\"op\">=<\/span> beta_<\/span>\n<span id=\"cb6-6\"><a href=\"#cb6-6\" aria-hidden=\"true\"><\/a>    beta_start_ <span class=\"op\">=<\/span> beta_<\/span>\n<span id=\"cb6-7\"><a href=\"#cb6-7\" aria-hidden=\"true\"><\/a>    loss_start_ <span class=\"op\">=<\/span> loss_fn(beta_)<\/span>\n<span id=\"cb6-8\"><a href=\"#cb6-8\" aria-hidden=\"true\"><\/a>    loss_grad_start_ <span class=\"op\">=<\/span> loss_grad_fn(beta_)<\/span>\n<span id=\"cb6-9\"><a href=\"#cb6-9\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">while<\/span> <span class=\"va\">True<\/span>:<\/span>\n<span id=\"cb6-10\"><a href=\"#cb6-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-11\"><a href=\"#cb6-11\" aria-hidden=\"true\"><\/a>        beta_ <span class=\"op\">=<\/span> beta_start_ <span 
class=\"op\">-<\/span> alpha_ <span class=\"op\">*<\/span> loss_grad_start_<\/span>\n<span id=\"cb6-12\"><a href=\"#cb6-12\" aria-hidden=\"true\"><\/a>        z <span class=\"op\">=<\/span> prox_fn(beta_, alpha_ <span class=\"op\">*<\/span> lambda_)<\/span>\n<span id=\"cb6-13\"><a href=\"#cb6-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-14\"><a href=\"#cb6-14\" aria-hidden=\"true\"><\/a>        loss_z <span class=\"op\">=<\/span> loss_fn(z)<\/span>\n<span id=\"cb6-15\"><a href=\"#cb6-15\" aria-hidden=\"true\"><\/a>        step_diff <span class=\"op\">=<\/span> z <span class=\"op\">-<\/span> beta_start_<\/span>\n<span id=\"cb6-16\"><a href=\"#cb6-16\" aria-hidden=\"true\"><\/a>        loss_diff <span class=\"op\">=<\/span> loss_z <span class=\"op\">-<\/span> loss_start_<\/span>\n<span id=\"cb6-17\"><a href=\"#cb6-17\" aria-hidden=\"true\"><\/a>        line_diff <span class=\"op\">=<\/span> alpha_ <span class=\"op\">*<\/span> (loss_diff <span class=\"op\">-<\/span><\/span>\n<span id=\"cb6-18\"><a href=\"#cb6-18\" aria-hidden=\"true\"><\/a>                              loss_grad_start_.T.dot(step_diff))<\/span>\n<span id=\"cb6-19\"><a href=\"#cb6-19\" aria-hidden=\"true\"><\/a>        line_diff <span class=\"op\">-=<\/span> step_diff.T.dot(step_diff) <span class=\"op\">\/<\/span> <span class=\"fl\">2.<\/span><\/span>\n<span id=\"cb6-20\"><a href=\"#cb6-20\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># Sufficient decrease (scaled by alpha_): l(z) &lt;= l(beta_start_) + grad^T (z - beta_start_) + ||z - beta_start_||^2 \/ (2 alpha_)<\/span><\/span>\n<span id=\"cb6-21\"><a href=\"#cb6-21\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> line_diff <span class=\"op\">&lt;=<\/span> obj_tol:<\/span>\n<span id=\"cb6-22\"><a href=\"#cb6-22\" aria-hidden=\"true\"><\/a>            <span class=\"cf\">return<\/span> z, alpha_, loss_z<\/span>\n<span id=\"cb6-23\"><a href=\"#cb6-23\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb6-24\"><a href=\"#cb6-24\" aria-hidden=\"true\"><\/a>        alpha_ <span class=\"op\">*=<\/span> bt_rate<\/span>\n<span id=\"cb6-25\"><a href=\"#cb6-25\" aria-hidden=\"true\"><\/a>        <span 
class=\"cf\">assert<\/span> alpha_ <span class=\"op\">&gt;=<\/span> <span class=\"dv\">0<\/span>, <span class=\"st\">&#39;invalid step size: <\/span><span class=\"sc\">{}<\/span><span class=\"st\">&#39;<\/span>.<span class=\"bu\">format<\/span>(alpha_)<\/span><\/code><\/pre><\/div>\n<div class=\"remark\" data-markdown=\"\" data-title-name=\"\">\n<p>Routines like this one, which make use of the gradient and other quantities, might also be good candidates for execution in Theano, if only for the graph optimizations that can remove obviously redundant computations.<\/p>\n<p>In this vein, we could consider performing the line-search, and\/or the entire optimization loop, within a Theano <code>scan<\/code> operation. We could also create an <code>Op<\/code> that represents gradient and line-search steps. Either approach could make graph construction much simpler and fit better within the current optimization framework.<\/p>\n<p>Although there\u2019s no guarantee that <code>scan<\/code> and tighter Theano integrations will always produce better results than our current implementation, we wish to emphasize that such improvements are possible with further work in these symbolic directions.<\/p>\n<p>Likewise, an <code>Op<\/code> for the proximal operator might be needed to automatically find closed-form solutions for the many proximal operators that arise within a log-likelihood\/objective function graph. An effective implementation could be as simple as lookup tables combined with some algebraic relationships\/identities. 
State-of-the-art symbolic algebra libraries effectively use the same approach for symbolic integration.<\/p>\n<\/div>\n<\/section>\n<\/section>\n<section id=\"examples\" class=\"level1\">\n<h1>Examples<\/h1>\n<p>First, to compute anything from our Theano graphs, we need to compile them to Theano functions.<\/p>\n<div class=\"sourceCode\" id=\"cb7\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb7-1\"><a href=\"#cb7-1\" aria-hidden=\"true\"><\/a>lambda_tt <span class=\"op\">=<\/span> tt.scalar(<span class=\"st\">&#39;lambda&#39;<\/span>)<\/span>\n<span id=\"cb7-2\"><a href=\"#cb7-2\" aria-hidden=\"true\"><\/a>lambda_tt.tag.test_value <span class=\"op\">=<\/span> <span class=\"dv\">1<\/span><\/span>\n<span id=\"cb7-3\"><a href=\"#cb7-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-4\"><a href=\"#cb7-4\" aria-hidden=\"true\"><\/a>prox_fn <span class=\"op\">=<\/span> tt_function([beta_tt, lambda_tt],<\/span>\n<span id=\"cb7-5\"><a href=\"#cb7-5\" aria-hidden=\"true\"><\/a>                      tt_soft_threshold(beta_tt, lambda_tt))<\/span>\n<span id=\"cb7-6\"><a href=\"#cb7-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-7\"><a href=\"#cb7-7\" aria-hidden=\"true\"><\/a>prox_grad_step_fn, loss_grad <span class=\"op\">=<\/span> prox_gradient_step(<\/span>\n<span id=\"cb7-8\"><a href=\"#cb7-8\" aria-hidden=\"true\"><\/a>    nlogl, beta_tt, tt_soft_threshold,<\/span>\n<span id=\"cb7-9\"><a href=\"#cb7-9\" aria-hidden=\"true\"><\/a>    return_loss_grad<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb7-10\"><a href=\"#cb7-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-11\"><a href=\"#cb7-11\" aria-hidden=\"true\"><\/a>loss_fn <span class=\"op\">=<\/span> tt_function([beta_tt], nlogl)<\/span>\n<span id=\"cb7-12\"><a href=\"#cb7-12\" aria-hidden=\"true\"><\/a>loss_grad_fn <span class=\"op\">=<\/span> tt_function([beta_tt], loss_grad)<\/span>\n<span id=\"cb7-13\"><a href=\"#cb7-13\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-14\"><a href=\"#cb7-14\" aria-hidden=\"true\"><\/a>cols_fns <span class=\"op\">=<\/span> [<\/span>\n<span id=\"cb7-15\"><a href=\"#cb7-15\" aria-hidden=\"true\"><\/a>    (<span class=\"kw\">lambda<\/span> i, b: i, <span class=\"vs\">r&#39;$i$&#39;<\/span>),<\/span>\n<span id=\"cb7-16\"><a href=\"#cb7-16\" aria-hidden=\"true\"><\/a>    (<span class=\"kw\">lambda<\/span> i, b: <span class=\"bu\">float<\/span>(loss_fn(b)),<\/span>\n<span id=\"cb7-17\"><a href=\"#cb7-17\" aria-hidden=\"true\"><\/a>        <span class=\"vs\">r&#39;$l(\\beta^{(i)})$&#39;<\/span>),<\/span>\n<span id=\"cb7-18\"><a href=\"#cb7-18\" aria-hidden=\"true\"><\/a>    (<span class=\"kw\">lambda<\/span> i, b: np.linalg.norm(b <span class=\"op\">-<\/span> beta_true, <span class=\"dv\">2<\/span>),<\/span>\n<span id=\"cb7-19\"><a href=\"#cb7-19\" aria-hidden=\"true\"><\/a>        <span class=\"vs\">r&#39;$\\|\\beta^{(i)} - \\beta^*\\|_2$&#39;<\/span>)<\/span>\n<span id=\"cb7-20\"><a href=\"#cb7-20\" aria-hidden=\"true\"><\/a>]<\/span><\/code><\/pre><\/div>\n<p>For a baseline comparison\u2013and sanity check\u2013we\u2019ll use the <code>cvxpy<\/code> library <span class=\"citation\" data-cites=\"diamond_cvxpy:_2016\">(Diamond and Boyd 2016)<\/span>.<\/p>\n<div class=\"sourceCode\" id=\"cb8\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb8-1\"><a href=\"#cb8-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> cvxpy <span class=\"im\">as<\/span> cvx<\/span>\n<span id=\"cb8-2\"><a href=\"#cb8-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb8-3\"><a href=\"#cb8-3\" aria-hidden=\"true\"><\/a>beta_var_cvx <span class=\"op\">=<\/span> cvx.Variable(M, name<span class=\"op\">=<\/span><span class=\"st\">&#39;beta&#39;<\/span>)<\/span>\n<span id=\"cb8-4\"><a href=\"#cb8-4\" aria-hidden=\"true\"><\/a>lambda_cvx <span class=\"op\">=<\/span> <span class=\"fl\">1e-2<\/span> <span class=\"op\">*<\/span> lambda_max <span 
class=\"op\">*<\/span> N<\/span>\n<span id=\"cb8-5\"><a href=\"#cb8-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb8-6\"><a href=\"#cb8-6\" aria-hidden=\"true\"><\/a>cvx_obj <span class=\"op\">=<\/span> cvx.Minimize(<span class=\"fl\">0.5<\/span> <span class=\"op\">*<\/span> cvx.sum_squares(y <span class=\"op\">-<\/span> X <span class=\"op\">*<\/span> beta_var_cvx)<\/span>\n<span id=\"cb8-7\"><a href=\"#cb8-7\" aria-hidden=\"true\"><\/a>                       <span class=\"op\">+<\/span> lambda_cvx <span class=\"op\">*<\/span> cvx.norm(beta_var_cvx, <span class=\"dv\">1<\/span>) )<\/span>\n<span id=\"cb8-8\"><a href=\"#cb8-8\" aria-hidden=\"true\"><\/a>cvx_prob <span class=\"op\">=<\/span> cvx.Problem(cvx_obj)<\/span>\n<span id=\"cb8-9\"><a href=\"#cb8-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb8-10\"><a href=\"#cb8-10\" aria-hidden=\"true\"><\/a>_ <span class=\"op\">=<\/span> cvx_prob.solve(solver<span class=\"op\">=<\/span>cvx.CVXOPT, verbose<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb8-11\"><a href=\"#cb8-11\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb8-12\"><a href=\"#cb8-12\" aria-hidden=\"true\"><\/a>beta_cvx <span class=\"op\">=<\/span> np.asarray(beta_var_cvx.value).squeeze()<\/span>\n<span id=\"cb8-13\"><a href=\"#cb8-13\" aria-hidden=\"true\"><\/a>loss_cvx <span class=\"op\">=<\/span> loss_fn(beta_cvx)<\/span>\n<span id=\"cb8-14\"><a href=\"#cb8-14\" aria-hidden=\"true\"><\/a>beta_cvx_err <span class=\"op\">=<\/span> np.linalg.norm(beta_cvx <span class=\"op\">-<\/span> beta_true, <span class=\"dv\">2<\/span>)<\/span><\/code><\/pre><\/div>\n<p>We now have the necessary pieces to perform an example estimation. 
We\u2019ll start with an exceedingly large step size and let backtracking line-search find a good value.<\/p>\n<div class=\"sourceCode\" id=\"cb9\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb9-1\"><a href=\"#cb9-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">class<\/span> ProxGradient(<span class=\"bu\">object<\/span>):<\/span>\n<span id=\"cb9-2\"><a href=\"#cb9-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb9-3\"><a href=\"#cb9-3\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__init__<\/span>(<span class=\"va\">self<\/span>, y, X, beta_0,<\/span>\n<span id=\"cb9-4\"><a href=\"#cb9-4\" aria-hidden=\"true\"><\/a>                 prox_fn_, loss_fn_, loss_grad_fn_,<\/span>\n<span id=\"cb9-5\"><a href=\"#cb9-5\" aria-hidden=\"true\"><\/a>                 alpha_0):<\/span>\n<span id=\"cb9-6\"><a href=\"#cb9-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb9-7\"><a href=\"#cb9-7\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.y <span class=\"op\">=<\/span> y<\/span>\n<span id=\"cb9-8\"><a href=\"#cb9-8\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.X <span class=\"op\">=<\/span> X<\/span>\n<span id=\"cb9-9\"><a href=\"#cb9-9\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.alpha_val <span class=\"op\">=<\/span> alpha_0<\/span>\n<span id=\"cb9-10\"><a href=\"#cb9-10\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.beta_0 <span class=\"op\">=<\/span> beta_0<\/span>\n<span id=\"cb9-11\"><a href=\"#cb9-11\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.N, <span class=\"va\">self<\/span>.M <span class=\"op\">=<\/span> X.shape<\/span>\n<span id=\"cb9-12\"><a href=\"#cb9-12\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.prox_fn_ <span class=\"op\">=<\/span> prox_fn_<\/span>\n<span id=\"cb9-13\"><a href=\"#cb9-13\" aria-hidden=\"true\"><\/a>        <span 
class=\"va\">self<\/span>.loss_fn_ <span class=\"op\">=<\/span> loss_fn_<\/span>\n<span id=\"cb9-14\"><a href=\"#cb9-14\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.loss_grad_fn_ <span class=\"op\">=<\/span> loss_grad_fn_<\/span>\n<span id=\"cb9-15\"><a href=\"#cb9-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb9-16\"><a href=\"#cb9-16\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> step(<span class=\"va\">self<\/span>, beta):<\/span>\n<span id=\"cb9-17\"><a href=\"#cb9-17\" aria-hidden=\"true\"><\/a>        beta_val <span class=\"op\">=<\/span> np.copy(beta)<\/span>\n<span id=\"cb9-18\"><a href=\"#cb9-18\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb9-19\"><a href=\"#cb9-19\" aria-hidden=\"true\"><\/a>        beta_val, <span class=\"va\">self<\/span>.alpha_val, _ <span class=\"op\">=<\/span> backtracking_search(<\/span>\n<span id=\"cb9-20\"><a href=\"#cb9-20\" aria-hidden=\"true\"><\/a>            beta_val, <span class=\"va\">self<\/span>.alpha_val,<\/span>\n<span id=\"cb9-21\"><a href=\"#cb9-21\" aria-hidden=\"true\"><\/a>            <span class=\"va\">self<\/span>.prox_fn_, <span class=\"va\">self<\/span>.loss_fn_, <span class=\"va\">self<\/span>.loss_grad_fn_)<\/span>\n<span id=\"cb9-22\"><a href=\"#cb9-22\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb9-23\"><a href=\"#cb9-23\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> beta_val<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"cb10\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb10-1\"><a href=\"#cb10-1\" aria-hidden=\"true\"><\/a>beta_0 <span class=\"op\">=<\/span> np.zeros(M).astype(<span class=\"st\">&#39;float64&#39;<\/span>)<\/span>\n<span id=\"cb10-2\"><a href=\"#cb10-2\" aria-hidden=\"true\"><\/a>lambda_val <span class=\"op\">=<\/span> <span class=\"fl\">1e-2<\/span> <span class=\"op\">*<\/span> lambda_max<\/span>\n<span id=\"cb10-3\"><a href=\"#cb10-3\" 
aria-hidden=\"true\"><\/a>pg_step <span class=\"op\">=<\/span> ProxGradient(y, X, beta_0,<\/span>\n<span id=\"cb10-4\"><a href=\"#cb10-4\" aria-hidden=\"true\"><\/a>                       <span class=\"kw\">lambda<\/span> x, a: prox_fn(x, N <span class=\"op\">*<\/span> lambda_val <span class=\"op\">*<\/span> a),<\/span>\n<span id=\"cb10-5\"><a href=\"#cb10-5\" aria-hidden=\"true\"><\/a>                       loss_fn, loss_grad_fn, <span class=\"dv\">10<\/span>)<\/span>\n<span id=\"cb10-6\"><a href=\"#cb10-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-7\"><a href=\"#cb10-7\" aria-hidden=\"true\"><\/a>pg_cols_fns <span class=\"op\">=<\/span> cols_fns <span class=\"op\">+<\/span> [(<span class=\"kw\">lambda<\/span> <span class=\"op\">*<\/span>args, <span class=\"op\">**<\/span>kwargs: pg_step.alpha_val,<\/span>\n<span id=\"cb10-8\"><a href=\"#cb10-8\" aria-hidden=\"true\"><\/a>                           <span class=\"vs\">r&#39;$\\alpha$&#39;<\/span>)]<\/span>\n<span id=\"cb10-9\"><a href=\"#cb10-9\" aria-hidden=\"true\"><\/a>pg_est_data, _ <span class=\"op\">=<\/span> iterative_run(pg_step, loss_fn, pg_cols_fns)<\/span>\n<span id=\"cb10-10\"><a href=\"#cb10-10\" aria-hidden=\"true\"><\/a>pg_ls_data <span class=\"op\">=<\/span> pd.DataFrame(pg_est_data)<\/span>\n<span id=\"cb10-11\"><a href=\"#cb10-11\" aria-hidden=\"true\"><\/a><span class=\"co\"># pg_ls_data = pg_ls_data.append(pg_est_data, ignore_index=True)<\/span><\/span><\/code><\/pre><\/div>\n<p><span id=\"fig:pg_ls_plot\"><span id=\"fig:pg_ls_plot_span\" style=\"display:none;visibility:hidden\"><span class=\"math display\">\\[\\begin{equation}\\tag{1}\\label{fig:pg_ls_plot}\\end{equation}\\]<\/span><\/span><img src=\"https:\/\/brandonwillard.github.io\/figures\/more_proximal_estimation_pg_ls_plot_1.png\" title=\"fig:\" alt=\"Minimization by proximal gradient with backtracking line-search.\" \/><\/span><\/p>\n<p>Figure\u00a0<span class=\"math inline\">\\(\\ref{fig:pg_ls_plot}\\)<\/span> shows a couple of convergence measures for 
proximal gradient steps alongside the step size changes due to backtracking line-search. Regarding the latter, in our example a sufficient step size is found within the first few iterations, so the overall result isn\u2019t too interesting. Fortunately, this sort of behaviour isn\u2019t uncommon, which makes line-search quite effective in practice.<\/p>\n<\/section>\n<section id=\"coordinate-wise-estimation\" class=\"level1\">\n<h1>Coordinate-wise Estimation<\/h1>\n<p>Given that our loss is a composition of <span class=\"math inline\">\\(\\ell_2\\)<\/span> and a linear operator of finite dimension (i.e.\u00a0<span class=\"math inline\">\\(X\\)<\/span>), we can conveniently exploit conditional separability and obtain simple estimation steps in each coordinate. This is, effectively, what characterizes coordinate (or cyclic) descent. Since it is a common technique in the estimation of <span class=\"math inline\">\\(\\ell_1\\)<\/span> models <span class=\"citation\" data-cites=\"friedman_pathwise_2007 mazumder_regularization_2009 scikit-learn_sklearn.linear_model.elasticnet_2017\">(Friedman et al. 2007; Mazumder, Hastie, and Tibshirani 2009; scikit-learn 2017)<\/span>, it\u2019s worthwhile to consider how it can be viewed in terms of proximal operators.<\/p>\n<p>From a statistical perspective, coordinate-wise methods begin with the \u201cpartial residuals\u201d <span class=\"math inline\">\\(r_{-m} \\in {{\\mathbb{R}}}^{N}\\)<\/span>, discussed in <span class=\"citation\" data-cites=\"friedman_pathwise_2007\">Friedman et al. 
(2007)<\/span>, and implicitly defined by <span class=\"math display\">\\[\\begin{equation}\n\\begin{aligned}\n    \\beta^*\n    &amp;= \\operatorname*{argmin}_{\\beta} \\left\\{\n      \\frac12\n      \\|\n    y - X(\\beta - e_m \\beta_m)\n        - X e_m \\cdot \\beta_{m}\\|^2_2\n      + \\lambda \\left|\\beta_m\\right|\n      + \\lambda \\sum_{m^\\prime \\neq m} \\left|\\beta_{m^\\prime}\\right|\n      \\right\\}\n    \\\\\n    &amp;= \\operatorname*{argmin}_{\\beta} \\left\\{\n      \\frac12\n      \\|r_{-m} - X e_m \\cdot \\beta_{m}\\|^2_2\n      + \\lambda \\left|\\beta_m\\right|\n      + \\dots\n    \\right\\}\n  \\;.\n  \\end{aligned}\n  \\label{eq:partial_resid}\n\\end{equation}\\]<\/span> The last expression hints at the most basic idea behind the coordinate-wise approach: conditional minimization in each <span class=\"math inline\">\\(m\\)<\/span>. Its exact solution in each coordinate is given by the aforementioned soft-thresholding function, which\u2013as we\u2019ve already stated\u2013is a proximal operator. 
In symbols, <span class=\"math inline\">\\(\\operatorname*{prox}_{\\lambda \\left|\\cdot\\right|}(x) = \\operatorname{S}_\\lambda(x)\\)<\/span>, where the latter is the soft-thresholding operator.<\/p>\n<p>Now, we can relate Equation\u00a0<span class=\"math inline\">\\(\\eqref{eq:partial_resid}\\)<\/span> to proximal methods through the proximal gradient fixed-point solution\u2013i.e.\u00a0Equation\u00a0<span class=\"math inline\">\\(\\eqref{eq:forward-backward}\\)<\/span>\u2013and the following property of proximal operators:<\/p>\n<div id=\"lem:prox_ortho_basis\" class=\"lemma\" data-markdown=\"\" data-title-name=\"\">\n<p><span id=\"lem:prox_ortho_basis_span\" style=\"display:none;visibility:hidden\"><span class=\"math display\">\\[\\begin{equation}\\tag{1}\\label{lem:prox_ortho_basis}\\end{equation}\\]<\/span><\/span><\/p>\n<p><span class=\"math display\">\\[\\begin{equation*}\n\\operatorname*{prox}_{\\lambda \\sum^M_m \\left(\\phi \\circ e^\\top_m\\right)}(z) =\n    \\sum^M_m \\operatorname*{prox}_{\\lambda \\phi}\\left(e^\\top_m z\\right) e_m\n    \\;.\n\\end{equation*}\\]<\/span><\/p>\n<div class=\"proof\" data-markdown=\"\" data-title-name=\"\">\n<p>See <span class=\"citation\" data-cites=\"chaux_variational_2007\">Chaux et al. 
(2007)<\/span>.<\/p>\n<\/div>\n<\/div>\n<p>The next result yields our desired connection.<\/p>\n<div id=\"eq:prox_grad_descent\" class=\"proposition\" data-markdown=\"\" data-title-name=\"\">\n<p><span id=\"eq:prox_grad_descent_span\" style=\"display:none;visibility:hidden\"><span class=\"math display\">\\[\\begin{equation}\\tag{1}\\label{eq:prox_grad_descent}\\end{equation}\\]<\/span><\/span><\/p>\n<p>For <span class=\"math inline\">\\(X\\)<\/span> such that <span class=\"math inline\">\\({{\\bf 1}}^\\top X e_m = 0\\)<\/span> and <span class=\"math inline\">\\(e^\\top_m X^\\top X e_m = 1\\)<\/span>, <span class=\"math inline\">\\(m \\in \\{1, \\dots, M\\}\\)<\/span>, the coordinate-wise step of the Lasso in <span class=\"citation\" data-cites=\"friedman_pathwise_2007\">Friedman et al. (2007 Equation (9))<\/span>, <span class=\"math display\">\\[\\begin{equation*}\n\\beta_m = \\operatorname{S}_{\\lambda}\\left[\n      \\sum_{n}^N X_{n,m} \\left(\n      y_n - \\sum^M_{m^\\prime \\neq m} X_{n,m^\\prime} \\beta_{m^\\prime}\n      \\right)\n    \\right]\n    \\;,\n\\end{equation*}\\]<\/span> has a proximal gradient fixed-point solution under a Euclidean basis decomposition with the form <span class=\"math display\">\\[\\begin{equation*}\n\\beta =\n    \\sum^M_m \\operatorname*{prox}_{\\alpha \\lambda \\phi}\\left[\n      e^\\top_m \\left(\\beta - \\alpha \\nabla l(\\beta)\\right)\n    \\right] e_m\n    \\;.\n\\end{equation*}\\]<\/span><\/p>\n<div class=\"proof\" data-markdown=\"\" data-title-name=\"\">\n<p>We start with an expansion of the terms in <span class=\"math inline\">\\(\\operatorname*{prox}_{\\lambda \\phi} \\equiv \\operatorname{S}_\\lambda\\)<\/span>. 
After simplifying the notation with <span class=\"math display\">\\[\\begin{equation*}\n\\begin{gathered}\n    \\sum^N_{n} X_{n,m} z_n = e^\\top_m X^\\top z, \\quad \\text{and} \\quad\n    \\sum^M_{m^\\prime \\neq m} X_{n,m^\\prime} \\beta_{m^\\prime} =\n    \\left[X \\left(\\beta - \\beta_m e_m \\right)\\right]_n\n    \\;,\n  \\end{gathered}\n\\end{equation*}\\]<\/span> the expanded argument of <span class=\"math inline\">\\(\\operatorname{S}\\)<\/span> reduces to <span class=\"math display\">\\[\\begin{equation*}\n\\begin{aligned}\n      e^\\top_m X^\\top \\left(y - X\\left( \\beta - e_m \\beta_m\\right)\\right)\n      &amp;= e^\\top_m X^\\top X e_m \\beta_m + e^\\top_m X^\\top \\left(y - X \\beta\\right)\n      \\\\\n      &amp;= \\beta_m + e^\\top_m X^\\top \\left(y - X \\beta\\right)\n      \\\\\n      &amp;= e^\\top_m \\left(\\beta + X^\\top \\left(y - X \\beta\\right)\\right)\n    \\end{aligned}\n\\end{equation*}\\]<\/span> where the second equality follows from the standardization <span class=\"math inline\">\\(e^\\top_m X^\\top X e_m = 1\\)<\/span>. Since <span class=\"math inline\">\\(\\nabla l(\\beta) = -X^\\top \\left(y - X \\beta\\right)\\)<\/span>, the last line equals <span class=\"math inline\">\\(e^\\top_m \\left(\\beta - \\alpha \\nabla l(\\beta)\\right)\\)<\/span> with <span class=\"math inline\">\\(\\alpha = 1\\)<\/span>, which establishes the relationship with Equation\u00a0<span class=\"math inline\">\\(\\eqref{eq:forward-backward}\\)<\/span>, albeit only component-wise. 
Using Lemma\u00a0<span class=\"math inline\">\\(\\eqref{lem:prox_ortho_basis}\\)<\/span> together with <span class=\"math inline\">\\(z = \\beta - \\alpha \\nabla  l(\\beta)\\)<\/span> yields the proximal gradient fixed-point statement, i.e.\u00a0<span class=\"math display\">\\[\\begin{equation*}\n\\begin{aligned}\n      \\beta\n      &amp;=\n      \\sum^M_m \\operatorname*{prox}_{\\alpha \\lambda \\phi}\\left[\n    e^\\top_m \\left(\\beta - \\alpha \\nabla l(\\beta)\\right)\n      \\right] e_m\n      \\\\\n      &amp;=\n      \\sum^M_m \\operatorname*{prox}_{\\alpha \\lambda \\phi}\\left(\n      \\beta_m + \\alpha e_m^\\top X^\\top \\left(y - X \\beta \\right)\n      \\right) e_m\n      \\;.\n    \\end{aligned}\n\\end{equation*}\\]<\/span><\/p>\n<\/div>\n<\/div>\n<div id=\"rem:bases\" class=\"remark\" data-markdown=\"\" data-title-name=\"\">\n<p><span id=\"rem:bases_span\" style=\"display:none;visibility:hidden\"><span class=\"math display\">\\[\\begin{equation}\\tag{3}\\label{rem:bases}\\end{equation}\\]<\/span><\/span><\/p>\n<p>The property in Lemma\u00a0<span class=\"math inline\">\\(\\eqref{lem:prox_ortho_basis}\\)<\/span> can be used with other orthonormal bases\u2013providing yet another connection between proximal methods and established dimensionality reduction and sparse estimation techniques <span class=\"citation\" data-cites=\"chaux_variational_2007\">(Chaux et al. 2007)<\/span>. 
Also, this property provides a neat way to think about <span class=\"math inline\">\\(X\\)<\/span>-based orthogonalizations in estimation for regression and grouped-penalization problems.<\/p>\n<\/div>\n<section id=\"implementation-1\" class=\"level2\">\n<h2>Implementation<\/h2>\n<p>The following class implements a standard form of coordinate descent:<\/p>\n<div class=\"sourceCode\" id=\"cb11\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb11-1\"><a href=\"#cb11-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">class<\/span> CoordDescent(<span class=\"bu\">object<\/span>):<\/span>\n<span id=\"cb11-2\"><a href=\"#cb11-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-3\"><a href=\"#cb11-3\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> <span class=\"fu\">__init__<\/span>(<span class=\"va\">self<\/span>, y, X, beta_0, prox_fn_, col_seq<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"cb11-4\"><a href=\"#cb11-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-5\"><a href=\"#cb11-5\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.y <span class=\"op\">=<\/span> y<\/span>\n<span id=\"cb11-6\"><a href=\"#cb11-6\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.X <span class=\"op\">=<\/span> X<\/span>\n<span id=\"cb11-7\"><a href=\"#cb11-7\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.beta_0 <span class=\"op\">=<\/span> beta_0<\/span>\n<span id=\"cb11-8\"><a href=\"#cb11-8\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.N, <span class=\"va\">self<\/span>.M <span class=\"op\">=<\/span> X.shape<\/span>\n<span id=\"cb11-9\"><a href=\"#cb11-9\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.Xb <span class=\"op\">=<\/span> np.dot(<span class=\"va\">self<\/span>.X, <span class=\"va\">self<\/span>.beta_0)<\/span>\n<span id=\"cb11-10\"><a href=\"#cb11-10\" aria-hidden=\"true\"><\/a>        <span 
class=\"va\">self<\/span>.prox_fn_ <span class=\"op\">=<\/span> prox_fn_<\/span>\n<span id=\"cb11-11\"><a href=\"#cb11-11\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-12\"><a href=\"#cb11-12\" aria-hidden=\"true\"><\/a>        <span class=\"co\"># (Inverse) 2-norm of each column\/feature, i.e.<\/span><\/span>\n<span id=\"cb11-13\"><a href=\"#cb11-13\" aria-hidden=\"true\"><\/a>        <span class=\"co\">#   np.reciprocal(np.diag(np.dot(X.T, X)))<\/span><\/span>\n<span id=\"cb11-14\"><a href=\"#cb11-14\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.alpha_vals <span class=\"op\">=<\/span> np.reciprocal((<span class=\"va\">self<\/span>.X<span class=\"op\">**<\/span><span class=\"dv\">2<\/span>).<span class=\"bu\">sum<\/span>(axis<span class=\"op\">=<\/span><span class=\"dv\">0<\/span>))<\/span>\n<span id=\"cb11-15\"><a href=\"#cb11-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-16\"><a href=\"#cb11-16\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">if<\/span> col_seq <span class=\"kw\">is<\/span> <span class=\"va\">None<\/span>:<\/span>\n<span id=\"cb11-17\"><a href=\"#cb11-17\" aria-hidden=\"true\"><\/a>            <span class=\"va\">self<\/span>.col_seq <span class=\"op\">=<\/span> np.arange(<span class=\"va\">self<\/span>.M)<\/span>\n<span id=\"cb11-18\"><a href=\"#cb11-18\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-19\"><a href=\"#cb11-19\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> reset(<span class=\"va\">self<\/span>):<\/span>\n<span id=\"cb11-20\"><a href=\"#cb11-20\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.Xb <span class=\"op\">=<\/span> np.dot(<span class=\"va\">self<\/span>.X, <span class=\"va\">self<\/span>.beta_0)<\/span>\n<span id=\"cb11-21\"><a href=\"#cb11-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-22\"><a href=\"#cb11-22\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> step(<span class=\"va\">self<\/span>, 
beta):<\/span>\n<span id=\"cb11-23\"><a href=\"#cb11-23\" aria-hidden=\"true\"><\/a>        beta_val <span class=\"op\">=<\/span> np.copy(beta)<\/span>\n<span id=\"cb11-24\"><a href=\"#cb11-24\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-25\"><a href=\"#cb11-25\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">for<\/span> j <span class=\"kw\">in<\/span> <span class=\"va\">self<\/span>.col_seq:<\/span>\n<span id=\"cb11-26\"><a href=\"#cb11-26\" aria-hidden=\"true\"><\/a>            X_j <span class=\"op\">=<\/span> <span class=\"va\">self<\/span>.X[:, j]<\/span>\n<span id=\"cb11-27\"><a href=\"#cb11-27\" aria-hidden=\"true\"><\/a>            alpha_val <span class=\"op\">=<\/span> <span class=\"va\">self<\/span>.alpha_vals[j]<\/span>\n<span id=\"cb11-28\"><a href=\"#cb11-28\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-29\"><a href=\"#cb11-29\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># A little cheaper to just subtract the column&#39;s<\/span><\/span>\n<span id=\"cb11-30\"><a href=\"#cb11-30\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># contribution...<\/span><\/span>\n<span id=\"cb11-31\"><a href=\"#cb11-31\" aria-hidden=\"true\"><\/a>            <span class=\"va\">self<\/span>.Xb <span class=\"op\">-=<\/span> X_j <span class=\"op\">*<\/span> beta_val[j]<\/span>\n<span id=\"cb11-32\"><a href=\"#cb11-32\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-33\"><a href=\"#cb11-33\" aria-hidden=\"true\"><\/a>            Xt_r <span class=\"op\">=<\/span> np.dot(X_j.T, <span class=\"va\">self<\/span>.y <span class=\"op\">-<\/span> <span class=\"va\">self<\/span>.Xb) <span class=\"op\">*<\/span> alpha_val<\/span>\n<span id=\"cb11-34\"><a href=\"#cb11-34\" aria-hidden=\"true\"><\/a>            beta_val[j] <span class=\"op\">=<\/span> <span class=\"va\">self<\/span>.prox_fn_(np.atleast_1d(Xt_r),<\/span>\n<span id=\"cb11-35\"><a href=\"#cb11-35\" aria-hidden=\"true\"><\/a>                                        alpha_val)<\/span>\n<span id=\"cb11-36\"><a href=\"#cb11-36\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-37\"><a href=\"#cb11-37\" aria-hidden=\"true\"><\/a>            <span class=\"co\"># ...and add the updated column back.<\/span><\/span>\n<span id=\"cb11-38\"><a href=\"#cb11-38\" aria-hidden=\"true\"><\/a>            <span class=\"va\">self<\/span>.Xb <span class=\"op\">+=<\/span> X_j <span class=\"op\">*<\/span> beta_val[j]<\/span>\n<span id=\"cb11-39\"><a href=\"#cb11-39\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-40\"><a href=\"#cb11-40\" aria-hidden=\"true\"><\/a>        <span class=\"va\">self<\/span>.beta_last <span class=\"op\">=<\/span> beta_val<\/span>\n<span id=\"cb11-41\"><a href=\"#cb11-41\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-42\"><a href=\"#cb11-42\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> beta_val<\/span><\/code><\/pre><\/div>\n<p>Our example randomizes the order of coordinates to loosely demonstrate the range of efficiency possible in coordinate descent.<\/p>\n<div class=\"sourceCode\" id=\"cb12\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb12-1\"><a href=\"#cb12-1\" aria-hidden=\"true\"><\/a>beta_0 <span class=\"op\">=<\/span> np.zeros(M).astype(<span class=\"st\">&#39;float64&#39;<\/span>)<\/span>\n<span id=\"cb12-2\"><a href=\"#cb12-2\" aria-hidden=\"true\"><\/a>lambda_val <span class=\"op\">=<\/span> <span class=\"fl\">1e-2<\/span> <span class=\"op\">*<\/span> lambda_max<\/span>\n<span id=\"cb12-3\"><a href=\"#cb12-3\" aria-hidden=\"true\"><\/a>cd_step <span class=\"op\">=<\/span> CoordDescent(y, X, beta_0,<\/span>\n<span id=\"cb12-4\"><a href=\"#cb12-4\" aria-hidden=\"true\"><\/a>                       <span class=\"kw\">lambda<\/span> x, a: prox_fn(x, N <span class=\"op\">*<\/span> lambda_val <span class=\"op\">*<\/span> a))<\/span>\n<span id=\"cb12-5\"><a href=\"#cb12-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb12-6\"><a href=\"#cb12-6\" aria-hidden=\"true\"><\/a>cd_cols_fns <span 
class=\"op\">=<\/span> cols_fns <span class=\"op\">+<\/span> [(<span class=\"kw\">lambda<\/span> <span class=\"op\">*<\/span>args, <span class=\"op\">**<\/span>kwargs: j, <span class=\"st\">&quot;replication&quot;<\/span>)]<\/span>\n<span id=\"cb12-7\"><a href=\"#cb12-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb12-8\"><a href=\"#cb12-8\" aria-hidden=\"true\"><\/a>pg_coord_data <span class=\"op\">=<\/span> pd.DataFrame()<\/span>\n<span id=\"cb12-9\"><a href=\"#cb12-9\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> j <span class=\"kw\">in<\/span> <span class=\"bu\">range<\/span>(<span class=\"dv\">15<\/span>):<\/span>\n<span id=\"cb12-10\"><a href=\"#cb12-10\" aria-hidden=\"true\"><\/a>    est_data, _ <span class=\"op\">=<\/span> iterative_run(cd_step, loss_fn, cd_cols_fns)<\/span>\n<span id=\"cb12-11\"><a href=\"#cb12-11\" aria-hidden=\"true\"><\/a>    pg_coord_data <span class=\"op\">=<\/span> pg_coord_data.append(est_data,<\/span>\n<span id=\"cb12-12\"><a href=\"#cb12-12\" aria-hidden=\"true\"><\/a>                                         ignore_index<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb12-13\"><a href=\"#cb12-13\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Reset internal state of our step method, since we&#39;re<\/span><\/span>\n<span id=\"cb12-14\"><a href=\"#cb12-14\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># running multiple replications.<\/span><\/span>\n<span id=\"cb12-15\"><a href=\"#cb12-15\" aria-hidden=\"true\"><\/a>    cd_step.reset()<\/span>\n<span id=\"cb12-16\"><a href=\"#cb12-16\" aria-hidden=\"true\"><\/a>    np.random.shuffle(cd_step.col_seq)<\/span><\/code><\/pre><\/div>\n<p><span id=\"fig:pg_coord_plot\"><span id=\"fig:pg_coord_plot_span\" style=\"display:none;visibility:hidden\"><span class=\"math display\">\\[\\begin{equation}\\tag{2}\\label{fig:pg_coord_plot}\\end{equation}\\]<\/span><\/span><img 
src=\"https:\/\/brandonwillard.github.io\/figures\/more_proximal_estimation_pg_coord_plot_1.png\" title=\"fig:\" alt=\"Minimization by coordinate descent.\" \/><\/span><\/p>\n<p>Figure\u00a0<span class=\"math inline\">\\(\\ref{fig:pg_coord_plot}\\)<\/span> shows convergence measures for each randomized coordinate order. The [average] difference in the number of iterations required for coordinate descent and proximal gradient is fairly noticeable. Nonetheless, both reach effectively the same limits.<\/p>\n<div class=\"remark\" data-markdown=\"\" data-title-name=\"\">\n<p>Similar ideas behind batched vs.\u00a0non-batched steps and block sampling\u2013found within the Gibbs sampling literature <span class=\"citation\" data-cites=\"roberts_updating_1997\">(Roberts and Sahu 1997)<\/span>\u2013could explain the variation due to coordinate order and the relative efficiency of coordinate descent. There are also connections with our comments in Remark\u00a0<span class=\"math inline\">\\(\\ref{rem:bases}\\)<\/span> and, to some extent, stochastic gradient descent (SGD) <span class=\"citation\" data-cites=\"bertsekas_incremental_2010\">(Bertsekas 2010)<\/span>.<\/p>\n<p>In a woefully lacking over-generalization, let\u2019s say that it comes down to the [spectral] properties of the composite operator(s) <span class=\"math inline\">\\(l \\circ X\\)<\/span> and\/or <span class=\"math inline\">\\(\\nabla l \\circ X\\)<\/span>. These determine the bounds of efficiency for steps in certain directions and how blocking or partitioning the dimensions of <span class=\"math inline\">\\(\\beta\\)<\/span> nears or distances from those bounds.<\/p>\n<\/div>\n<section id=\"regularization-paths\" class=\"level3\">\n<h3>Regularization Paths<\/h3>\n<p>Also, due to the relatively fast convergence of coordinate descent, the method is a little more suitable for the computation of regularization paths\u2013 i.e.\u00a0varying <span class=\"math inline\">\\(\\lambda\\)<\/span> between iterations. 
There is much more to this topic, but for simplicity let\u2019s just note that each <span class=\"math inline\">\\(\\lambda\\)<\/span> step is \u201cwarm-started\u201d from the estimate obtained for the previous <span class=\"math inline\">\\(\\lambda\\)<\/span>\u2013which helps\u2013and that we\u2019re otherwise fine with the solution provided by this approach.<\/p>\n<p>Next, we make a small extension to demonstrate the computation of regularization paths\u2013using scikit-learn\u2019s <code>lasso_path<\/code> for comparison.<\/p>\n<div class=\"sourceCode\" id=\"cb13\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb13-1\"><a href=\"#cb13-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> sklearn.linear_model <span class=\"im\">import<\/span> lasso_path<\/span>\n<span id=\"cb13-2\"><a href=\"#cb13-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-3\"><a href=\"#cb13-3\" aria-hidden=\"true\"><\/a>beta_0 <span class=\"op\">=<\/span> np.zeros(M).astype(<span class=\"st\">&#39;float64&#39;<\/span>)<\/span>\n<span id=\"cb13-4\"><a href=\"#cb13-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-5\"><a href=\"#cb13-5\" aria-hidden=\"true\"><\/a>lambda_path, beta_path, _ <span class=\"op\">=<\/span> lasso_path(X, y)<\/span>\n<span id=\"cb13-6\"><a href=\"#cb13-6\" aria-hidden=\"true\"><\/a>path_len <span class=\"op\">=<\/span> <span class=\"bu\">len<\/span>(lambda_path)<\/span>\n<span id=\"cb13-7\"><a href=\"#cb13-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-8\"><a href=\"#cb13-8\" aria-hidden=\"true\"><\/a>beta_last <span class=\"op\">=<\/span> beta_0<\/span>\n<span id=\"cb13-9\"><a href=\"#cb13-9\" aria-hidden=\"true\"><\/a>pg_path_data <span class=\"op\">=<\/span> []<\/span>\n<span id=\"cb13-10\"><a href=\"#cb13-10\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> i, lambda_ <span class=\"kw\">in<\/span> <span class=\"bu\">enumerate<\/span>(lambda_path):<\/span>\n<span id=\"cb13-11\"><a href=\"#cb13-11\" aria-hidden=\"true\"><\/a>    cd_path_step <span 
class=\"op\">=<\/span> CoordDescent(y, X, beta_last,<\/span>\n<span id=\"cb13-12\"><a href=\"#cb13-12\" aria-hidden=\"true\"><\/a>                                <span class=\"kw\">lambda<\/span> x, a: prox_fn(x, N <span class=\"op\">*<\/span> lambda_ <span class=\"op\">*<\/span> a))<\/span>\n<span id=\"cb13-13\"><a href=\"#cb13-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-14\"><a href=\"#cb13-14\" aria-hidden=\"true\"><\/a>    cd_cols_fns <span class=\"op\">=<\/span> cols_fns[<span class=\"dv\">1<\/span>:] <span class=\"op\">+<\/span> [<\/span>\n<span id=\"cb13-15\"><a href=\"#cb13-15\" aria-hidden=\"true\"><\/a>        (<span class=\"kw\">lambda<\/span> <span class=\"op\">*<\/span>args, <span class=\"op\">**<\/span>kwargs: lambda_, <span class=\"vs\">r&#39;$\\lambda$&#39;<\/span>)]<\/span>\n<span id=\"cb13-16\"><a href=\"#cb13-16\" aria-hidden=\"true\"><\/a>    est_data, beta_last <span class=\"op\">=<\/span> iterative_run(cd_path_step, loss_fn,<\/span>\n<span id=\"cb13-17\"><a href=\"#cb13-17\" aria-hidden=\"true\"><\/a>                                        cd_cols_fns,<\/span>\n<span id=\"cb13-18\"><a href=\"#cb13-18\" aria-hidden=\"true\"><\/a>                                        stop_tol<span class=\"op\">=<\/span><span class=\"fl\">1e-4<\/span>,<\/span>\n<span id=\"cb13-19\"><a href=\"#cb13-19\" aria-hidden=\"true\"><\/a>                                        stop_loss<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb13-20\"><a href=\"#cb13-20\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-21\"><a href=\"#cb13-21\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Keep only the final iterate for this value of lambda.<\/span><\/span>\n<span id=\"cb13-22\"><a href=\"#cb13-22\" aria-hidden=\"true\"><\/a>    pg_path_data.append(est_data.iloc[[<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>], :])<\/span>\n<span id=\"cb13-23\"><a href=\"#cb13-23\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb13-24\"><a href=\"#cb13-24\" aria-hidden=\"true\"><\/a>pg_path_data <span class=\"op\">=<\/span> pd.concat(pg_path_data, ignore_index<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span><\/code><\/pre><\/div>\n<div 
class=\"sourceCode\" id=\"cb14\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb14-1\"><a href=\"#cb14-1\" aria-hidden=\"true\"><\/a>cd_cols_fns <span class=\"op\">=<\/span> cols_fns[<span class=\"dv\">1<\/span>:] <span class=\"op\">+<\/span> [<\/span>\n<span id=\"cb14-2\"><a href=\"#cb14-2\" aria-hidden=\"true\"><\/a>    (<span class=\"kw\">lambda<\/span> <span class=\"op\">*<\/span>args, <span class=\"op\">**<\/span>kwargs: lambda_path[args[<span class=\"dv\">0<\/span>]], <span class=\"vs\">r&#39;$\\lambda$&#39;<\/span>)]<\/span>\n<span id=\"cb14-3\"><a href=\"#cb14-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-4\"><a href=\"#cb14-4\" aria-hidden=\"true\"><\/a>iter_values <span class=\"op\">=<\/span> []<\/span>\n<span id=\"cb14-5\"><a href=\"#cb14-5\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> i, beta_ <span class=\"kw\">in<\/span> <span class=\"bu\">enumerate<\/span>(beta_path.T):<\/span>\n<span id=\"cb14-6\"><a href=\"#cb14-6\" aria-hidden=\"true\"><\/a>    iter_values.append([col_fn(i, beta_)<\/span>\n<span id=\"cb14-7\"><a href=\"#cb14-7\" aria-hidden=\"true\"><\/a>                        <span class=\"cf\">for<\/span> col_fn, _ <span class=\"kw\">in<\/span> cd_cols_fns])<\/span>\n<span id=\"cb14-8\"><a href=\"#cb14-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-9\"><a href=\"#cb14-9\" aria-hidden=\"true\"><\/a>sklearn_path_data <span class=\"op\">=<\/span> pd.DataFrame(iter_values,<\/span>\n<span id=\"cb14-10\"><a href=\"#cb14-10\" aria-hidden=\"true\"><\/a>                                 columns<span class=\"op\">=<\/span><span class=\"bu\">zip<\/span>(<span class=\"op\">*<\/span>cd_cols_fns)[<span class=\"dv\">1<\/span>])<\/span>\n<span id=\"cb14-11\"><a href=\"#cb14-11\" aria-hidden=\"true\"><\/a>sklearn_path_data <span class=\"op\">=<\/span> sklearn_path_data.assign(<\/span>\n<span id=\"cb14-12\"><a href=\"#cb14-12\" aria-hidden=\"true\"><\/a>    replication<span 
class=\"op\">=<\/span><span class=\"va\">None<\/span>, <span class=\"bu\">type<\/span><span class=\"op\">=<\/span><span class=\"st\">&#39;sklearn&#39;<\/span>)<\/span>\n<span id=\"cb14-13\"><a href=\"#cb14-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-14\"><a href=\"#cb14-14\" aria-hidden=\"true\"><\/a>pg_path_data <span class=\"op\">=<\/span> pg_path_data.assign(<span class=\"bu\">type<\/span><span class=\"op\">=<\/span><span class=\"st\">&#39;pg&#39;<\/span>)<\/span>\n<span id=\"cb14-15\"><a href=\"#cb14-15\" aria-hidden=\"true\"><\/a>pg_path_data <span class=\"op\">=<\/span> pg_path_data.append(sklearn_path_data,<\/span>\n<span id=\"cb14-16\"><a href=\"#cb14-16\" aria-hidden=\"true\"><\/a>                                   ignore_index<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span><\/code><\/pre><\/div>\n<p><span id=\"fig:pg_path_plot\"><span id=\"fig:pg_path_plot_span\" style=\"display:none;visibility:hidden\"><span class=\"math display\">\\[\\begin{equation}\\tag{3}\\label{fig:pg_path_plot}\\end{equation}\\]<\/span><\/span><img src=\"https:\/\/brandonwillard.github.io\/figures\/more_proximal_estimation_pg_path_plot_1.png\" title=\"fig:\" alt=\"Regularization paths via coordinate descent.\" \/><\/span><\/p>\n<\/section>\n<\/section>\n<\/section>\n<section id=\"discussion\" class=\"level1\">\n<h1>Discussion<\/h1>\n<p>Among the changes discussed earlier regarding Theano <code>Op<\/code>s for the proximal objects used here, we would also like to motivate much larger changes to the applied mathematician\/statistician\u2019s standard tools by demonstrating the relevance of less common\u2013yet increasingly useful\u2013abstractions. For instance, the proximal methods are neatly framed within operator theory and set-valued analysis, where concepts like the resolvent, sub-differential\/gradient and others are common. 
Abstractions like these provide a compact means of extending familiar ideas into new contexts\u2013such as non-differentiable functions.<\/p>\n<p>Unfortunately, numerical libraries offer little support for these abstractions. Most are founded strictly on the representation of point-valued mappings, which can require significant work-arounds to handle even the most basic non-differentiable functions (e.g.\u00a0the absolute value within our example problem). Our use of the proximal framework is, in part, motivated by its near seamless use <em>and<\/em> simultaneous bypassing of set-valued maps\u2013in implementation, at least.<\/p>\n<p>There is no fundamental restriction blocking support for set-valued maps, however\u2013aside from the necessary labor and community interest. Even minimal support could provide a context that makes frameworks like ours merely minor abstractions. A similar idea can be found in the symbolic calculation of limits via filters <span class=\"citation\" data-cites=\"beeson_meaning_2005\">(Beeson and Wiedijk 2005)<\/span>. Perhaps we can liken these changes to the modern evolution from linear algebra libraries to tensor libraries.<\/p>\n<p>We would also like to stress that the value provided by the symbolic tools discussed here (Theano, really) lies not <em>just<\/em> in their ability to act as compilers at a \u201cmath level\u201d, but more in their ability to concretely encode mathematical characterizations of optimization problems and methods. Work in this direction is not new by any means; however, the combination of open-source tools and industry interest in algorithms that fall under the broad class of proximal methods (e.g.\u00a0gradient descent, ADMM, EM, etc.) 
provides a more immediate reason to pursue these abstractions in code and automate their use.<\/p>\n<p>Regarding the proximal methods, we can consider Theano optimizations that make direct use of the orthonormal basis property in Lemma\u00a0<span class=\"math inline\">\\(\\eqref{lem:prox_ortho_basis}\\)<\/span>, or the Moreau-Fenchel theorem, and automate the consideration of various estimation methods via splitting (e.g.\u00a0ADMM, Douglas-Rachford, etc.)\u2013perhaps by making decisions based on inferred or specified tensor, function, and operator properties. In future installments we\u2019ll delve into the details of these ideas.<\/p>\n<p><span class=\"citation\" data-cites=\"wytock_new_2016\">Wytock et al. (2016)<\/span> also discuss similar ideas in an optimization setting, such as the use of symbolic graphs and a close coupling with useful mathematical abstractions\u2013including proximal operators. Additionally, there are many other good examples <span class=\"citation\" data-cites=\"diamond_cvxpy:_2016\">(Diamond and Boyd 2016)<\/span> of constructive mathematical abstractions applied in code.<\/p>\n<p>In most cases, libraries providing optimization tools and supporting model estimation do not attempt to root their implementations within an independently developed symbolic framework and then realize their relevant methodologies in that context. Too often the mathematical abstractions\u2013or the resulting methods alone\u2013are directly implemented at the highest levels of abstraction possible. This is what we see in popular libraries like <code>scikit-learn<\/code> and the body of <code>R<\/code> packages. The same approach can be found for proximal methods themselves\u2013e.g.\u00a0in <span class=\"citation\" data-cites=\"svaiter_pyprox_2017\">svaiter (2017)<\/span>, where individual functions for ADMM, forward-backward\/proximal gradient and Douglas-Rachford are the end result. 
This is the most common approach, and it makes sense for its simplicity, but it offers very little of the extensibility, generalization, or efficiency provided by shared efforts across related projects and fields.<\/p>\n<p>In the context of Theano, implementations immediately benefit from its code conversion, parallelization and relevant improvements to its basic graph optimizations. The latter covers both low-level computational efficiency\u2013such as the use of BLAS functions where applicable\u2013and high-level tensor algebra simplifications.<\/p>\n<p>In a development community that builds on these tools, related efficiency and performance gains can occur much more often, without necessarily sacrificing the specificity inherent to certain areas of application. For example, we can safely use the Rao-Blackwell theorem as the basis of a graph optimization in PyMC3, so it could be included among that project\u2019s default offerings; however, it would be far too cumbersome to use productively in a less specific context.<\/p>\n<\/section>\n<section id=\"bibliography\" class=\"level1 unnumbered\">\n<h1 class=\"unnumbered\">References<\/h1>\n<div id=\"refs\" class=\"references hanging-indent\" role=\"doc-bibliography\">\n<div id=\"ref-beck_fast_2014\">\n<p>Beck, Amir, and Marc Teboulle. 2014. \u201cA Fast Dual Proximal Gradient Algorithm for Convex Minimization and Applications.\u201d <em>Operations Research Letters<\/em> 42 (1): 1\u20136. <a href=\"http:\/\/www.sciencedirect.com\/science\/article\/pii\/S0167637713001454\">http:\/\/www.sciencedirect.com\/science\/article\/pii\/S0167637713001454<\/a>.<\/p>\n<\/div>\n<div id=\"ref-beeson_meaning_2005\">\n<p>Beeson, Michael, and Freek Wiedijk. 2005. \u201cThe Meaning of Infinity in Calculus and Computer Algebra Systems.\u201d <em>Journal of Symbolic Computation<\/em>, Automated Reasoning and Computer Algebra Systems (AR-CA), 39 (5): 523\u201338. 
<a href=\"https:\/\/www.sciencedirect.com\/science\/article\/pii\/S074771710500026X\">https:\/\/www.sciencedirect.com\/science\/article\/pii\/S074771710500026X<\/a>.<\/p>\n<\/div>\n<div id=\"ref-bergstra_theano_2010\">\n<p>Bergstra, James, Olivier Breuleux, Fr\u00e9d\u00e9ric Bastien, Pascal Lamblin, Razvan Pascanu, Guillaume Desjardins, Joseph Turian, David Warde-Farley, and Yoshua Bengio. 2010. \u201cTheano: A CPU and GPU Math Expression Compiler.\u201d In <em>Proceedings of the Python for Scientific Computing Conference (SciPy)<\/em>. Austin, TX.<\/p>\n<\/div>\n<div id=\"ref-bertsekas_incremental_2010\">\n<p>Bertsekas, Dimitri P. 2010. \u201cIncremental Gradient, Subgradient, and Proximal Methods for Convex Optimization: A Survey.\u201d <a href=\"http:\/\/web.mit.edu\/dimitrib\/www\/Incremental_Survey_LIDS.pdf\">http:\/\/web.mit.edu\/dimitrib\/www\/Incremental_Survey_LIDS.pdf<\/a>.<\/p>\n<\/div>\n<div id=\"ref-chaux_variational_2007\">\n<p>Chaux, Caroline, Patrick L Combettes, Jean-Christophe Pesquet, and Val\u00e9rie R Wajs. 2007. \u201cA Variational Formulation for Frame-Based Inverse Problems.\u201d <em>Inverse Problems<\/em> 23 (4): 1495.<\/p>\n<\/div>\n<div id=\"ref-combettes_proximal_2011\">\n<p>Combettes, Patrick L, and Jean-Christophe Pesquet. 2011. \u201cProximal Splitting Methods in Signal Processing.\u201d <em>Fixed-Point Algorithms for Inverse Problems in Science and Engineering<\/em>, 185\u2013212.<\/p>\n<\/div>\n<div id=\"ref-diamond_cvxpy:_2016\">\n<p>Diamond, Steven, and Stephen Boyd. 2016. \u201cCVXPY: A Python-Embedded Modeling Language for Convex Optimization.\u201d <em>Journal of Machine Learning Research<\/em> 17 (83): 1\u20135.<\/p>\n<\/div>\n<div id=\"ref-friedman_pathwise_2007\">\n<p>Friedman, Jerome, Trevor Hastie, Holger H\u00f6fling, Robert Tibshirani, and others. 2007. \u201cPathwise Coordinate Optimization.\u201d <em>The Annals of Applied Statistics<\/em> 1 (2): 302\u201332. 
<a href=\"http:\/\/projecteuclid.org\/euclid.aoas\/1196438020\">http:\/\/projecteuclid.org\/euclid.aoas\/1196438020<\/a>.<\/p>\n<\/div>\n<div id=\"ref-mazumder_regularization_2009\">\n<p>Mazumder, Rahul, Trevor Hastie, and Rob Tibshirani. 2009. \u201cRegularization Methods for Learning Incomplete Matrices.\u201d <em>arXiv Preprint arXiv:0906.2034<\/em>. <a href=\"https:\/\/arxiv.org\/abs\/0906.2034\">https:\/\/arxiv.org\/abs\/0906.2034<\/a>.<\/p>\n<\/div>\n<div id=\"ref-parikh_proximal_2014\">\n<p>Parikh, Neal, and Stephen Boyd. 2014. \u201cProximal Algorithms.\u201d <em>Foundations and Trends in Optimization<\/em> 1 (3): 123\u2013231. <a href=\"https:\/\/doi.org\/10.1561\/2400000003\">https:\/\/doi.org\/10.1561\/2400000003<\/a>.<\/p>\n<\/div>\n<div id=\"ref-polson_proximal_2015\">\n<p>Polson, Nicholas G., James G. Scott, and Brandon T. Willard. 2015. \u201cProximal Algorithms in Statistics and Machine Learning.\u201d <em>Statistical Science<\/em> 30 (4): 559\u201381. <a href=\"http:\/\/projecteuclid.org\/euclid.ss\/1449670858\">http:\/\/projecteuclid.org\/euclid.ss\/1449670858<\/a>.<\/p>\n<\/div>\n<div id=\"ref-roberts_updating_1997\">\n<p>Roberts, Gareth O., and Sujit K. Sahu. 1997. \u201cUpdating Schemes, Correlation Structure, Blocking and Parameterization for the Gibbs Sampler.\u201d <em>Journal of the Royal Statistical Society: Series B (Statistical Methodology)<\/em> 59 (2): 291\u2013317. <a href=\"http:\/\/onlinelibrary.wiley.com\/doi\/10.1111\/1467-9868.00070\/abstract\">http:\/\/onlinelibrary.wiley.com\/doi\/10.1111\/1467-9868.00070\/abstract<\/a>.<\/p>\n<\/div>\n<div id=\"ref-salvatier_probabilistic_2016\">\n<p>Salvatier, John, Thomas V. Wiecki, and Christopher Fonnesbeck. 2016. \u201cProbabilistic Programming in Python Using PyMC3.\u201d <em>PeerJ Computer Science<\/em> 2 (April): e55. 
<a href=\"https:\/\/peerj.com\/articles\/cs-55\">https:\/\/peerj.com\/articles\/cs-55<\/a>.<\/p>\n<\/div>\n<div id=\"ref-scikit-learn_sklearn.linear_model.elasticnet_2017\">\n<p>scikit-learn. 2017. \u201cSklearn.Linear_model.ElasticNet Scikit-Learn 0.19.Dev0 Documentation.\u201d <a href=\"http:\/\/scikit-learn.org\/dev\/modules\/generated\/sklearn.linear_model.ElasticNet.html\\#sklearn-linear-model-elasticnet\">http:\/\/scikit-learn.org\/dev\/modules\/generated\/sklearn.linear_model.ElasticNet.html\\#sklearn-linear-model-elasticnet<\/a>.<\/p>\n<\/div>\n<div id=\"ref-svaiter_pyprox_2017\">\n<p>svaiter. 2017. \u201cPyprox.\u201d <a href=\"https:\/\/github.com\/svaiter\/pyprox\">https:\/\/github.com\/svaiter\/pyprox<\/a>.<\/p>\n<\/div>\n<div id=\"ref-willard_role_2017\">\n<p>Willard, Brandon T. 2017. \u201cA Role for Symbolic Computation in the General Estimation of Statistical Models.\u201d <a href=\"https:\/\/brandonwillard.github.io\/a-role-for-symbolic-computation-in-the-general-estimation-of-statistical-models.html\">https:\/\/brandonwillard.github.io\/a-role-for-symbolic-computation-in-the-general-estimation-of-statistical-models.html<\/a>.<\/p>\n<\/div>\n<div id=\"ref-wytock_new_2016\">\n<p>Wytock, Matt, Steven Diamond, Felix Heide, and Stephen Boyd. 2016. \u201cA New Architecture for Optimization Modeling Frameworks.\u201d <em>arXiv Preprint arXiv:1609.03488<\/em>. <a href=\"https:\/\/arxiv.org\/abs\/1609.03488\">https:\/\/arxiv.org\/abs\/1609.03488<\/a>.<\/p>\n<\/div>\n<\/div>\n<\/section>\n<\/body>\n<\/html>\n","category":{"@attributes":{"term":"articles"}}},{"title":"A Role for Symbolic Computation in the General Estimation of Statistical Models","link":{"@attributes":{"href":"https:\/\/brandonwillard.github.io\/a-role-for-symbolic-computation-in-the-general-estimation-of-statistical-models.html","rel":"alternate"}},"published":"2017-01-18T00:00:00-06:00","updated":"2017-01-18T00:00:00-06:00","author":{"name":"Brandon T. 
Willard"},"id":"tag:brandonwillard.github.io,2017-01-18:\/a-role-for-symbolic-computation-in-the-general-estimation-of-statistical-models.html","summary":{"@attributes":{"type":"html"}},"content":"<!DOCTYPE html PUBLIC \"-\/\/W3C\/\/DTD XHTML 1.0 Transitional\/\/EN\" \"http:\/\/www.w3.org\/TR\/xhtml1\/DTD\/xhtml1-transitional.dtd\">\n<html xmlns=\"http:\/\/www.w3.org\/1999\/xhtml\">\n<head>\n  <meta http-equiv=\"Content-Type\" content=\"text\/html; charset=utf-8\" \/>\n  <meta http-equiv=\"Content-Style-Type\" content=\"text\/css\" \/>\n  <meta name=\"generator\" content=\"pandoc\" \/>\n  <meta name=\"author\" content=\"Brandon T. Willard\" \/>\n  <title>A Role for Symbolic Computation in the General Estimation of Statistical Models<\/title>\n  <style type=\"text\/css\">code{white-space: pre;}<\/style>\n  <style type=\"text\/css\">\npre > code.sourceCode { white-space: pre; position: relative; }\npre > code.sourceCode > span { display: inline-block; line-height: 1.25; }\npre > code.sourceCode > span:empty { height: 1.2em; }\ncode.sourceCode > span { color: inherit; text-decoration: inherit; }\ndiv.sourceCode { margin: 1em 0; }\npre.sourceCode { margin: 0; }\n@media screen {\ndiv.sourceCode { overflow: auto; }\n}\n@media print {\npre > code.sourceCode { white-space: pre-wrap; }\npre > code.sourceCode > span { text-indent: -5em; padding-left: 5em; }\n}\npre.numberSource code\n  { counter-reset: source-line 0; }\npre.numberSource code > span\n  { position: relative; left: -4em; counter-increment: source-line; }\npre.numberSource code > span > a:first-child::before\n  { content: counter(source-line);\n    position: relative; left: -1em; text-align: right; vertical-align: baseline;\n    border: none; display: inline-block;\n    -webkit-touch-callout: none; -webkit-user-select: none;\n    -khtml-user-select: none; -moz-user-select: none;\n    -ms-user-select: none; user-select: none;\n    padding: 0 4px; width: 4em;\n    color: #aaaaaa;\n  }\npre.numberSource { 
margin-left: 3em; border-left: 1px solid #aaaaaa;  padding-left: 4px; }\ndiv.sourceCode\n  {   }\n@media screen {\npre > code.sourceCode > span > a:first-child::before { text-decoration: underline; }\n}\ncode span.al { color: #ff0000; font-weight: bold; } \/* Alert *\/\ncode span.an { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Annotation *\/\ncode span.at { color: #7d9029; } \/* Attribute *\/\ncode span.bn { color: #40a070; } \/* BaseN *\/\ncode span.bu { } \/* BuiltIn *\/\ncode span.cf { color: #007020; font-weight: bold; } \/* ControlFlow *\/\ncode span.ch { color: #4070a0; } \/* Char *\/\ncode span.cn { color: #880000; } \/* Constant *\/\ncode span.co { color: #60a0b0; font-style: italic; } \/* Comment *\/\ncode span.cv { color: #60a0b0; font-weight: bold; font-style: italic; } \/* CommentVar *\/\ncode span.do { color: #ba2121; font-style: italic; } \/* Documentation *\/\ncode span.dt { color: #902000; } \/* DataType *\/\ncode span.dv { color: #40a070; } \/* DecVal *\/\ncode span.er { color: #ff0000; font-weight: bold; } \/* Error *\/\ncode span.ex { } \/* Extension *\/\ncode span.fl { color: #40a070; } \/* Float *\/\ncode span.fu { color: #06287e; } \/* Function *\/\ncode span.im { } \/* Import *\/\ncode span.in { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Information *\/\ncode span.kw { color: #007020; font-weight: bold; } \/* Keyword *\/\ncode span.op { color: #666666; } \/* Operator *\/\ncode span.ot { color: #007020; } \/* Other *\/\ncode span.pp { color: #bc7a00; } \/* Preprocessor *\/\ncode span.sc { color: #4070a0; } \/* SpecialChar *\/\ncode span.ss { color: #bb6688; } \/* SpecialString *\/\ncode span.st { color: #4070a0; } \/* String *\/\ncode span.va { color: #19177c; } \/* Variable *\/\ncode span.vs { color: #4070a0; } \/* VerbatimString *\/\ncode span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Warning *\/\n  <\/style>\n  <!--        <script 
src=\"https:\/\/cdn.jsdelivr.net\/npm\/mathjax@3\/es5\/tex-mml-chtml.js\" type=\"text\/javascript\"><\/script> -->\n  <script src=\"https:\/\/cdnjs.cloudflare.com\/ajax\/libs\/mathjax\/2.7.0\/MathJax.js?config=TeX-AMS_HTML\" id=\"MathJax-script\"><\/script>\n  <script>\n   MathJax.Hub.Config({\n       tex2jax: {\n           processEnvironments: true,\n           processRefs: false\n       },\n       TeX: {\n           equationNumbers: { autoNumber: \"AMS\" },\n           extensions: [\"AMSmath.js\",\"AMSsymbols.js\",\"noErrors.js\",\"noUndefined.js\"]\n       }\n   });\n  <\/script>\n<\/head>\n<body>\n<!--  -->\n<!-- <div id=\"header\"> -->\n<!-- <h1 class=\"title\">A Role for Symbolic Computation in the General Estimation of Statistical Models<\/h1> -->\n<!--  -->\n<!--  -->\n<!-- <h2 class=\"author\">Brandon T. Willard<\/h2> -->\n<!--  -->\n<!--  -->\n<!-- <h3 class=\"date\">2017\u201301\u201318<\/h3> -->\n<!--  -->\n<!-- <\/div> -->\n<!--  -->\n<section id=\"introduction\" class=\"level1\">\n<h1>Introduction<\/h1>\n<p>In this document we describe how symbolic computation can be used to provide generalizable statistical estimation through a combination of existing open source frameworks. Specifically, we will show how symbolic tools can be used to address the estimation of non-smooth functions that appear in models with parameter regularization, shrinkage and sparsity. We employ a mathematical framework that makes extensive use of <em>proximal operators<\/em> <span class=\"citation\" data-cites=\"parikh_proximal_2014 combettes_proximal_2011\">(Parikh and Boyd 2014; Combettes and Pesquet 2011)<\/span> and their properties for maximum a posteriori (MAP) estimation: i.e.\u00a0the <em>proximal framework<\/em>. 
This framework produces what we\u2019ll call <em>proximal methods<\/em> and their implementations as <em>proximal algorithms<\/em>.<\/p>\n<p>In <span class=\"citation\" data-cites=\"polson_proximal_2015\">Polson, Scott, and Willard (2015)<\/span> we outlined a set of seemingly disparate optimization techniques within the fields of statistics, computer vision, and machine learning (e.g.\u00a0gradient descent, ADMM, EM, Douglas-Rachford) that are unified by their various applications of proximal methods. These methods\u2013and the concepts behind them\u2013have seen considerable success in recent years and open quite a few interesting paths for research. In other words, the implementation of proximal methods alone merits discussion.<\/p>\n<p>Proximal operators also enjoy a breadth of closed-form solutions and useful properties that are amenable to symbolic computation. In more than a few cases, the work required to produce a proximal algorithm overlaps with well-established features of computer algebra systems and symbolic mathematics, such as symbolic differentiation and algebraic equation solving.<\/p>\n<p>Symbolic integration provides an excellent example of how proximal operators could be implemented in a symbolic system. In these systems, mappings between functions (as canonicalized graphs) and their generalized hypergeometric equivalents are used to exploit the latter\u2019s relevant convolution identities. In the same vein, it is possible to use tables of closed-form proximal operators and their properties to produce a wide array of estimation algorithms for many non-smooth functions. 
We outline how this might be done in the following sections.<\/p>\n<p>Otherwise, the ideas discussed here are part of a never-ending attempt to answer a question that arises naturally in both mathematics and programming\u2013at all levels: <em>How does one provide a means of generating robust solutions to as many problems as possible?<\/em> Instead of the common efforts to independently implement each model, method and\/or combination of the two\u2013followed by their placement in an API or library of functions\u2013implementations can be encoded in and organized by the very mathematics from which they were derived. This close coupling between mathematical principles and their implementations might be the only reasonable way to remove barriers between theory, research and practice.<\/p>\n<section id=\"a-context\" class=\"level2\">\n<h2>A Context<\/h2>\n<p>Much recent work in statistical modeling and estimation has had the goal of producing sparse results and\/or efficient, near automatic model selection. This objective is shared with other related practices\u2013such as Deep Learning and Compressed Sensing. In the former case, we can point to Dropout <span class=\"citation\" data-cites=\"srivastava_dropout_2014\">(Srivastava et al. 
2014)<\/span> and\u2013in the latter\u2013<span class=\"math inline\">\\(\\ell_p\\)<\/span> regularization <span class=\"citation\" data-cites=\"donoho_compressed_2006\">(Donoho 2006)<\/span> as basic examples.<\/p>\n<p>Here we\u2019ll simply assume that a practitioner intends to produce sparse estimates using the well-known LASSO\u2013or <span class=\"math inline\">\\(\\ell_1\\)<\/span> penalty.<\/p>\n<p>In PyMC3 <span class=\"citation\" data-cites=\"salvatier_probabilistic_2016\">(Salvatier, Wiecki, and Fonnesbeck 2016)<\/span>, the Bayes version of LASSO <span class=\"citation\" data-cites=\"park_bayesian_2008\">(Park and Casella 2008)<\/span> is easily specified.<\/p>\n<div class=\"sourceCode\" id=\"cb1\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb1-1\"><a href=\"#cb1-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> numpy <span class=\"im\">as<\/span> np<\/span>\n<span id=\"cb1-2\"><a href=\"#cb1-2\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> scipy <span class=\"im\">as<\/span> sc<\/span>\n<span id=\"cb1-3\"><a href=\"#cb1-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-4\"><a href=\"#cb1-4\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> pymc3 <span class=\"im\">as<\/span> pm<\/span>\n<span id=\"cb1-5\"><a href=\"#cb1-5\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> theano<\/span>\n<span id=\"cb1-6\"><a href=\"#cb1-6\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> theano.tensor <span class=\"im\">as<\/span> tt<\/span>\n<span id=\"cb1-7\"><a href=\"#cb1-7\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano <span class=\"im\">import<\/span> shared <span class=\"im\">as<\/span> tt_shared<\/span>\n<span id=\"cb1-8\"><a href=\"#cb1-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-9\"><a href=\"#cb1-9\" aria-hidden=\"true\"><\/a>theano.config.mode <span class=\"op\">=<\/span> <span 
class=\"st\">&#39;FAST_COMPILE&#39;<\/span><\/span>\n<span id=\"cb1-10\"><a href=\"#cb1-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-11\"><a href=\"#cb1-11\" aria-hidden=\"true\"><\/a>mu_true <span class=\"op\">=<\/span> np.zeros(<span class=\"dv\">100<\/span>)<\/span>\n<span id=\"cb1-12\"><a href=\"#cb1-12\" aria-hidden=\"true\"><\/a>mu_true[:<span class=\"dv\">20<\/span>] <span class=\"op\">=<\/span> np.exp(<span class=\"op\">-<\/span>np.arange(<span class=\"dv\">20<\/span>)) <span class=\"op\">*<\/span> <span class=\"dv\">100<\/span><\/span>\n<span id=\"cb1-13\"><a href=\"#cb1-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-14\"><a href=\"#cb1-14\" aria-hidden=\"true\"><\/a>X <span class=\"op\">=<\/span> np.random.randn(<span class=\"bu\">int<\/span>(<span class=\"bu\">len<\/span>(mu_true) <span class=\"op\">*<\/span> <span class=\"fl\">0.7<\/span>), <span class=\"bu\">len<\/span>(mu_true))<\/span>\n<span id=\"cb1-15\"><a href=\"#cb1-15\" aria-hidden=\"true\"><\/a>y <span class=\"op\">=<\/span> sc.stats.norm.rvs(loc<span class=\"op\">=<\/span>X.dot(mu_true), scale<span class=\"op\">=<\/span><span class=\"dv\">10<\/span>)<\/span>\n<span id=\"cb1-16\"><a href=\"#cb1-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-17\"><a href=\"#cb1-17\" aria-hidden=\"true\"><\/a>X_tt <span class=\"op\">=<\/span> tt_shared(X, name<span class=\"op\">=<\/span><span class=\"st\">&#39;X&#39;<\/span>, borrow<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb1-18\"><a href=\"#cb1-18\" aria-hidden=\"true\"><\/a>y_tt <span class=\"op\">=<\/span> tt_shared(y, name<span class=\"op\">=<\/span><span class=\"st\">&#39;y&#39;<\/span>, borrow<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb1-19\"><a href=\"#cb1-19\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-20\"><a href=\"#cb1-20\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> pm.Model() <span class=\"im\">as<\/span> lasso_model:<\/span>\n<span id=\"cb1-21\"><a href=\"#cb1-21\" 
aria-hidden=\"true\"><\/a>    <span class=\"co\"># Would be nice if we could pass the symbolic y_tt.shape, so<\/span><\/span>\n<span id=\"cb1-22\"><a href=\"#cb1-22\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># that our model would automatically conform to changes in<\/span><\/span>\n<span id=\"cb1-23\"><a href=\"#cb1-23\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># the shared variables X_tt.<\/span><\/span>\n<span id=\"cb1-24\"><a href=\"#cb1-24\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># See https:\/\/github.com\/pymc-devs\/pymc3\/pull\/1125<\/span><\/span>\n<span id=\"cb1-25\"><a href=\"#cb1-25\" aria-hidden=\"true\"><\/a>    beta_rv <span class=\"op\">=<\/span> pm.Laplace(<span class=\"st\">&#39;beta&#39;<\/span>, mu<span class=\"op\">=<\/span><span class=\"dv\">0<\/span>, b<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>, shape<span class=\"op\">=<\/span>X.shape[<span class=\"dv\">1<\/span>])<\/span>\n<span id=\"cb1-26\"><a href=\"#cb1-26\" aria-hidden=\"true\"><\/a>    y_rv <span class=\"op\">=<\/span> pm.Normal(<span class=\"st\">&#39;y&#39;<\/span>, mu<span class=\"op\">=<\/span>X_tt.dot(beta_rv), sd<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>,<\/span>\n<span id=\"cb1-27\"><a href=\"#cb1-27\" aria-hidden=\"true\"><\/a>                     shape<span class=\"op\">=<\/span>y.shape[<span class=\"dv\">0<\/span>], observed<span class=\"op\">=<\/span>y_tt)<\/span><\/code><\/pre><\/div>\n<p>Again, the negative total log likelihood in our example has a non-smooth <span class=\"math inline\">\\(\\ell_1\\)<\/span> term. Keeping this in mind, let\u2019s say we wanted to produce a MAP estimate using PyMC3. 
A function is already provided for this task: <code>find_MAP<\/code>.<\/p>\n<div class=\"sourceCode\" id=\"cb2\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb2-1\"><a href=\"#cb2-1\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> lasso_model:<\/span>\n<span id=\"cb2-2\"><a href=\"#cb2-2\" aria-hidden=\"true\"><\/a>    params_0 <span class=\"op\">=<\/span> pm.find_MAP(<span class=\"bu\">vars<\/span><span class=\"op\">=<\/span>[beta_rv])<\/span><\/code><\/pre><\/div>\n<p>In our run of the above, an exception was thrown due to <code>nan<\/code> values within the gradient evaluation. We can inspect the gradient at <span class=\"math inline\">\\(\\beta = 0, 1\\)<\/span> and reproduce the result.<\/p>\n<div class=\"sourceCode\" id=\"cb3\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb3-1\"><a href=\"#cb3-1\" aria-hidden=\"true\"><\/a>start <span class=\"op\">=<\/span> pm.Point({<span class=\"st\">&#39;beta&#39;<\/span>: np.zeros(X.shape[<span class=\"dv\">1<\/span>])}, model<span class=\"op\">=<\/span>lasso_model)<\/span>\n<span id=\"cb3-2\"><a href=\"#cb3-2\" aria-hidden=\"true\"><\/a>bij <span class=\"op\">=<\/span> pm.DictToArrayBijection(pm.ArrayOrdering(lasso_model.<span class=\"bu\">vars<\/span>),<\/span>\n<span id=\"cb3-3\"><a href=\"#cb3-3\" aria-hidden=\"true\"><\/a>start)<\/span>\n<span id=\"cb3-4\"><a href=\"#cb3-4\" aria-hidden=\"true\"><\/a>logp <span class=\"op\">=<\/span> bij.mapf(lasso_model.fastlogp)<\/span>\n<span id=\"cb3-5\"><a href=\"#cb3-5\" aria-hidden=\"true\"><\/a>dlogp <span class=\"op\">=<\/span> bij.mapf(lasso_model.fastdlogp(lasso_model.<span class=\"bu\">vars<\/span>))<\/span>\n<span id=\"cb3-6\"><a href=\"#cb3-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-7\"><a href=\"#cb3-7\" aria-hidden=\"true\"><\/a><span class=\"co\"># Could also inspect the log likelihood of the prior:<\/span><\/span>\n<span id=\"cb3-8\"><a href=\"#cb3-8\" 
aria-hidden=\"true\"><\/a><span class=\"co\"># beta_rv.dlogp().f(np.zeros_like(start[&#39;beta&#39;]))<\/span><\/span>\n<span id=\"cb3-9\"><a href=\"#cb3-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-10\"><a href=\"#cb3-10\" aria-hidden=\"true\"><\/a>grad_at_0 <span class=\"op\">=<\/span> dlogp(np.zeros_like(start[<span class=\"st\">&#39;beta&#39;<\/span>]))<\/span>\n<span id=\"cb3-11\"><a href=\"#cb3-11\" aria-hidden=\"true\"><\/a>grad_at_1 <span class=\"op\">=<\/span> dlogp(np.ones_like(start[<span class=\"st\">&#39;beta&#39;<\/span>]))<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"cb4\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb4-1\"><a href=\"#cb4-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> <span class=\"bu\">print<\/span>(np.<span class=\"bu\">sum<\/span>(np.isnan(grad_at_0)))<\/span>\n<span id=\"cb4-2\"><a href=\"#cb4-2\" aria-hidden=\"true\"><\/a><span class=\"dv\">100<\/span><\/span>\n<span id=\"cb4-3\"><a href=\"#cb4-3\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> <span class=\"bu\">print<\/span>(np.<span class=\"bu\">sum<\/span>(np.isnan(grad_at_1)))<\/span>\n<span id=\"cb4-4\"><a href=\"#cb4-4\" aria-hidden=\"true\"><\/a><span class=\"dv\">0<\/span><\/span><\/code><\/pre><\/div>\n<p>The <code>nan<\/code>s are not due to any shortcoming of PyMC3; they simply mark a suitable place for our ideas and improvements. Additionally, by working within PyMC3, we can readily apply certain mathematical results, for instance, theorems that apply only to distributions. 
This idea is more relevant to the graph optimizations we consider later, but is still very important.<\/p>\n<\/section>\n<\/section>\n<section id=\"the-proximal-context\" class=\"level1\">\n<h1>The Proximal Context<\/h1>\n<p>We start with the essential ingredient: the proximal operator.<\/p>\n<div class=\"Def\" data-markdown=\"\" data-title-name=\"[Proximal Operator]\">\n<p><span class=\"math display\">\\[\\begin{equation*}\n\\operatorname*{prox}_{\\phi}(x) =\n    \\operatorname*{argmin}_{z} \\left\\{\n    \\frac{1}{2} \\left(z - x\\right)^2 + \\phi(z)\n    \\right\\}\n    \\;.\n\\end{equation*}\\]<\/span><\/p>\n<\/div>\n<p>As we mentioned earlier, the proximal operator is the main tool of proximal algorithms. Exact solutions to proximal operators exist for many <span class=\"math inline\">\\(\\phi\\)<\/span>, and, since they\u2019re often quite simple in form, their computation is relatively cheap: a property that the proximal methods themselves can inherit.<\/p>\n<p>Consider the MAP estimation of a penalized likelihood, i.e.\u00a0<span class=\"math display\">\\[\\begin{equation}\n\\beta^* = \\operatorname*{argmin}_\\beta \\left\\{ l(\\beta) + \\gamma \\phi(\\beta) \\right\\}\n  \\;,\n  \\label{eq:prox_problem}\n\\end{equation}\\]<\/span> where functions <span class=\"math inline\">\\(l\\)<\/span> and <span class=\"math inline\">\\(\\phi\\)<\/span> are commonly referred to as likelihood and prior terms (or loss and penalty), respectively. The proximal framework usually assumes <span class=\"math inline\">\\(l\\)<\/span> and <span class=\"math inline\">\\(\\phi\\)<\/span> are at least lower semi-continuous and convex\u2013although quite a few useful results still hold for non-convex functions.<\/p>\n<p>Notice that Equation\u00a0<span class=\"math inline\">\\(\\eqref{eq:prox_problem}\\)<\/span> takes the form of a proximal operator when <span class=\"math inline\">\\(l(\\beta) = \\frac{1}{2} (y - \\beta)^2\\)<\/span>. 
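<\/p>\n<p>As a quick numerical check of this claim for the LASSO penalty <span class=\"math inline\">\\(\\phi = |\\cdot|\\)<\/span>, the following plain NumPy sketch (independent of Theano and PyMC3; the helper names <code>prox_numeric<\/code> and <code>soft_threshold_np<\/code> are our own) minimizes <span class=\"math inline\">\\(\\frac{1}{2}(z - y)^2 + \\gamma |z|\\)<\/span> by brute force over a dense grid and compares the minimizer with the closed-form solution <span class=\"math inline\">\\(\\operatorname{sgn}(y) \\max(|y| - \\gamma, 0)\\)<\/span>.<\/p>\n<div class=\"sourceCode\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\">import numpy as np


def prox_numeric(y, gamma, phi, grid):
    # Evaluate the proximal objective 0.5 * (z - y)**2 + gamma * phi(z)
    # at every grid point z and return the minimizing grid point.
    objective = 0.5 * (grid - y) ** 2 + gamma * phi(grid)
    return grid[np.argmin(objective)]


def soft_threshold_np(y, gamma):
    # Closed-form minimizer for phi = |.|: sgn(y) * max(|y| - gamma, 0).
    return np.sign(y) * np.maximum(np.abs(y) - gamma, 0.0)


grid = np.linspace(-20.0, 20.0, 400001)  # grid spacing of 1e-4
gamma = 0.5
ys = [-10.0, -1.0, -0.2, 0.0, 0.2, 1.0, 10.0]  # mirrors the beta_tt test values

results = [(y, prox_numeric(y, gamma, np.abs, grid), soft_threshold_np(y, gamma))
           for y in ys]
for y, numeric, closed in results:
    print(f'y={y:6.2f}  brute force={numeric:8.4f}  closed form={closed:8.4f}')
<\/code><\/pre><\/div>\n<p>The brute-force and closed-form values should agree to within the grid spacing, and with <span class=\"math inline\">\\(\\gamma = 0.5\\)<\/span> the closed-form column matches the soft-thresholding output of the Theano example in this section.<\/p>\n<p>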
Otherwise, in regression problems, we have <span class=\"math inline\">\\(l(\\beta) = \\frac{1}{2} \\|y - X \\beta\\|^2\\)<\/span>. In this case, properties of the proximal operator can be used to produce independent proximal operators in each dimension of <span class=\"math inline\">\\(\\beta\\)<\/span>. Since more than one property of the proximal operator can accomplish this\u2013and result in distinct approaches\u2013one might begin to see here a reason for the breadth of proximal methods.<\/p>\n<p>The proximal operator relevant to our example, <span class=\"math inline\">\\(\\operatorname*{prox}_{|\\cdot|}\\)<\/span>, is equivalent to the soft-thresholding operator. Its implementation in Theano is somewhat trivial, but\u2013for the sake of exposition\u2013we provide an example.<\/p>\n<div class=\"sourceCode\" id=\"cb5\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb5-1\"><a href=\"#cb5-1\" aria-hidden=\"true\"><\/a>beta_tt <span class=\"op\">=<\/span> tt.vector(<span class=\"st\">&#39;beta&#39;<\/span>, dtype<span class=\"op\">=<\/span>tt.config.floatX)<\/span>\n<span id=\"cb5-2\"><a href=\"#cb5-2\" aria-hidden=\"true\"><\/a>beta_tt.tag.test_value <span class=\"op\">=<\/span> np.r_[<span class=\"op\">-<\/span><span class=\"dv\">10<\/span>, <span class=\"op\">-<\/span><span class=\"dv\">1<\/span>, <span class=\"op\">-<\/span><span class=\"fl\">0.2<\/span>, <span class=\"dv\">0<\/span>, <span class=\"fl\">0.2<\/span>, <span class=\"dv\">1<\/span>,<\/span>\n<span id=\"cb5-3\"><a href=\"#cb5-3\" aria-hidden=\"true\"><\/a><span class=\"dv\">10<\/span>].astype(tt.config.floatX)<\/span>\n<span id=\"cb5-4\"><a href=\"#cb5-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-5\"><a href=\"#cb5-5\" aria-hidden=\"true\"><\/a>lambda_tt <span class=\"op\">=<\/span> tt.scalar(<span class=\"st\">&#39;lambda&#39;<\/span>, dtype<span class=\"op\">=<\/span>tt.config.floatX)<\/span>\n<span id=\"cb5-6\"><a href=\"#cb5-6\" 
aria-hidden=\"true\"><\/a>lambda_tt.tag.test_value <span class=\"op\">=<\/span> np.array(<span class=\"fl\">0.5<\/span>).astype(tt.config.floatX)<\/span>\n<span id=\"cb5-7\"><a href=\"#cb5-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-8\"><a href=\"#cb5-8\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> soft_threshold(beta_, lambda_):<\/span>\n<span id=\"cb5-9\"><a href=\"#cb5-9\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> tt.sgn(beta_) <span class=\"op\">*<\/span> tt.maximum(tt.abs_(beta_) <span class=\"op\">-<\/span> lambda_, <span class=\"dv\">0<\/span>)<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"cb6\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb6-1\"><a href=\"#cb6-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> <span class=\"bu\">print<\/span>(soft_threshold(beta_tt, lambda_tt).tag.test_value)<\/span>\n<span id=\"cb6-2\"><a href=\"#cb6-2\" aria-hidden=\"true\"><\/a>[<span class=\"op\">-<\/span><span class=\"fl\">9.5<\/span> <span class=\"op\">-<\/span><span class=\"fl\">0.5<\/span> <span class=\"op\">-<\/span><span class=\"fl\">0.<\/span>   <span class=\"fl\">0.<\/span>   <span class=\"fl\">0.<\/span>   <span class=\"fl\">0.5<\/span>  <span class=\"fl\">9.5<\/span>]<\/span><\/code><\/pre><\/div>\n<p>Composing a proximal operator with a gradient step yields the <em>proximal gradient<\/em> algorithm, which iterates <span class=\"math display\">\\[\\begin{equation}\n\\beta = \\operatorname*{prox}_{\\alpha \\gamma \\phi}(\\beta - \\alpha \\nabla l(\\beta))\n  \\;.\n  \\label{eq:forward-backward}\n\\end{equation}\\]<\/span><\/p>\n<p>Besides the proximal operator for <span class=\"math inline\">\\(\\phi\\)<\/span>, steps in the proximal gradient algorithm are straightforward and require only the gradient of <span class=\"math inline\">\\(l(\\beta)\\)<\/span>. 
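<\/p>\n<p>To make the update concrete, here is a minimal NumPy sketch of the proximal gradient iteration for the LASSO problem, with <span class=\"math inline\">\\(l(\\beta) = \\frac{1}{2} \\|y - X \\beta\\|^2\\)<\/span> and <span class=\"math inline\">\\(\\phi = \\|\\cdot\\|_1\\)<\/span>, so that the proximal step is elementwise soft-thresholding at level <span class=\"math inline\">\\(\\alpha \\gamma\\)<\/span>, where <span class=\"math inline\">\\(\\gamma\\)<\/span> is the penalty weight from the penalized problem above. It is not the PyMC3 or Theano machinery discussed here: the gradient is written out by hand, and, as a simplifying assumption, a fixed step <span class=\"math inline\">\\(\\alpha = 1\/\\|X\\|_2^2\\)<\/span> (the reciprocal of the Lipschitz constant of <span class=\"math inline\">\\(\\nabla l\\)<\/span>) replaces a line search. The function names and the value of <span class=\"math inline\">\\(\\gamma\\)<\/span> are hypothetical choices for illustration.<\/p>\n<div class=\"sourceCode\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\">import numpy as np


def soft_threshold_np(v, t):
    # Proximal operator of t * ||.||_1, applied elementwise.
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def proximal_gradient_lasso(X, y, gamma, n_iter=500):
    # Minimize 0.5 * ||y - X beta||^2 + gamma * ||beta||_1.
    # Assumption: a fixed step alpha = 1 / L replaces line search, where
    # L = ||X||_2^2 (largest squared singular value) is the Lipschitz constant of grad l.
    alpha = 1.0 / np.linalg.norm(X, 2) ** 2
    beta = np.zeros(X.shape[1])
    for _ in range(n_iter):
        grad = X.T @ (X @ beta - y)  # gradient of l at the current beta
        beta = soft_threshold_np(beta - alpha * grad, alpha * gamma)
    return beta


def objective(beta, X, y, gamma):
    return 0.5 * np.sum((y - X @ beta) ** 2) + gamma * np.sum(np.abs(beta))


# Synthetic data loosely mirroring the PyMC3 example above.
rng = np.random.default_rng(0)
beta_true = np.zeros(100)
beta_true[:20] = np.exp(-np.arange(20)) * 100
X = rng.standard_normal((70, 100))
y = X @ beta_true + rng.normal(scale=10.0, size=70)

gamma = 10.0
beta_hat = proximal_gradient_lasso(X, y, gamma)
print('objective at zero:', objective(np.zeros(100), X, y, gamma))
print('objective at estimate:', objective(beta_hat, X, y, gamma))
print('nonzero coefficients:', np.count_nonzero(beta_hat))
<\/code><\/pre><\/div>\n<p>With this fixed step, each iteration does not increase the objective; a backtracking line search would take its place when the Lipschitz constant is unknown or expensive to compute.<\/p>\n<p>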
This is where a tangible benefit of symbolic computation becomes apparent: <span class=\"math inline\">\\(\\nabla l(\\beta)\\)<\/span> can be computed automatically and efficiently. With [backtracking] line search to handle unknown step sizes, <span class=\"math inline\">\\(\\alpha\\)<\/span>, the proximal gradient algorithm provides a surprisingly general means of sparse estimation.<\/p>\n<\/section>\n<section id=\"the-symbolic-operations\" class=\"level1\">\n<h1>The Symbolic Operations<\/h1>\n<p>In order to identify a relevant, non-smooth problem, check that a given proximal method\u2019s conditions are satisfied (e.g.\u00a0convexity), and potentially solve the resulting proximal operators in closed-form, we need to obtain expressions for <span class=\"math inline\">\\(l\\)<\/span> and <span class=\"math inline\">\\(\\phi\\)<\/span>.<\/p>\n<p>In some cases, we\u2019re able to tease apart <span class=\"math inline\">\\(l\\)<\/span> and <span class=\"math inline\">\\(\\phi\\)<\/span> using only the interface provided by PyMC3. 
Specifically, we can use the <em>observed<\/em> and <em>unobserved<\/em> random variable fields of a PyMC3 model.<\/p>\n<div class=\"sourceCode\" id=\"cb7\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb7-1\"><a href=\"#cb7-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano <span class=\"im\">import<\/span> clone <span class=\"im\">as<\/span> tt_clone<\/span>\n<span id=\"cb7-2\"><a href=\"#cb7-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-3\"><a href=\"#cb7-3\" aria-hidden=\"true\"><\/a>logl <span class=\"op\">=<\/span> tt_clone(lasso_model.observed_RVs[<span class=\"dv\">0<\/span>].logpt,<\/span>\n<span id=\"cb7-4\"><a href=\"#cb7-4\" aria-hidden=\"true\"><\/a>                {beta_rv: beta_tt})<\/span>\n<span id=\"cb7-5\"><a href=\"#cb7-5\" aria-hidden=\"true\"><\/a>logl.name <span class=\"op\">=<\/span> <span class=\"st\">&quot;logl&quot;<\/span><\/span><\/code><\/pre><\/div>\n<p>Now suppose we are extending <code>find_MAP<\/code> with even more generality, so that we cannot determine <span class=\"math inline\">\\(l\\)<\/span> and <span class=\"math inline\">\\(\\phi\\)<\/span> in this way. This situation can occur when a user specifies custom distributions or potential functions. In such cases, we need to operate at a more symbolic level.<\/p>\n<div class=\"remark\" data-markdown=\"\" data-title-name=\"\">\n<p>At this point, it is worthwhile to browse the <a href=\"http:\/\/deeplearning.net\/software\/theano\/extending\/graphstructures.html\">Theano documentation<\/a> regarding graphs and their constituent objects.<\/p>\n<\/div>\n<p>The total log-likelihood is a good place to start. 
Let\u2019s look at the symbolic graph for the log-likelihood of our model.<\/p>\n<div class=\"sourceCode\" id=\"cb8\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb8-1\"><a href=\"#cb8-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano <span class=\"im\">import<\/span> pp <span class=\"im\">as<\/span> tt_pp<\/span>\n<span id=\"cb8-2\"><a href=\"#cb8-2\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> theano <span class=\"im\">import<\/span> pprint <span class=\"im\">as<\/span> tt_pprint<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"cb9\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb9-1\"><a href=\"#cb9-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> <span class=\"bu\">print<\/span>(tt_pp(lasso_model.logpt))<\/span>\n<span id=\"cb9-2\"><a href=\"#cb9-2\" aria-hidden=\"true\"><\/a>(Sum{acc_dtype<span class=\"op\">=<\/span>float64}(Sum{acc_dtype<span class=\"op\">=<\/span>float64}(((<span class=\"op\">-<\/span>log(TensorConstant{<span class=\"dv\">2<\/span>}))<\/span>\n<span id=\"cb9-3\"><a href=\"#cb9-3\" aria-hidden=\"true\"><\/a><span class=\"op\">-<\/span> (<span class=\"op\">|<\/span>(<span class=\"op\">\\<\/span>beta <span class=\"op\">-<\/span> TensorConstant{<span class=\"dv\">0<\/span>})<span class=\"op\">|<\/span> <span class=\"op\">\/<\/span> TensorConstant{<span class=\"dv\">1<\/span>})))) <span class=\"op\">+<\/span><\/span>\n<span id=\"cb9-4\"><a href=\"#cb9-4\" aria-hidden=\"true\"><\/a>Sum{acc_dtype<span class=\"op\">=<\/span>float64}(Sum{acc_dtype<span class=\"op\">=<\/span>float64}(switch(TensorConstant{<span class=\"dv\">1<\/span>},<\/span>\n<span id=\"cb9-5\"><a href=\"#cb9-5\" aria-hidden=\"true\"><\/a>(((TensorConstant{<span class=\"op\">-<\/span><span class=\"fl\">1.0<\/span>} <span class=\"op\">*<\/span> ((y <span class=\"op\">-<\/span> (X <span class=\"op\">\\<\/span>dot <span class=\"op\">\\<\/span>beta)) 
<span class=\"op\">**<\/span> TensorConstant{<span class=\"dv\">2<\/span>}))<\/span>\n<span id=\"cb9-6\"><a href=\"#cb9-6\" aria-hidden=\"true\"><\/a><span class=\"op\">+<\/span> log(TensorConstant{<span class=\"fl\">0.159154943092<\/span>})) <span class=\"op\">\/<\/span> TensorConstant{<span class=\"fl\">2.0<\/span>}),<\/span>\n<span id=\"cb9-7\"><a href=\"#cb9-7\" aria-hidden=\"true\"><\/a>TensorConstant{<span class=\"op\">-<\/span>inf}))))<\/span><\/code><\/pre><\/div>\n<p>The <a href=\"http:\/\/deeplearning.net\/software\/theano\/tutorial\/printing_drawing.html#pretty-printing\">pretty printed<\/a> Theano graph tells us\u2013among other things\u2013that we indeed have a sum of <span class=\"math inline\">\\(\\ell_2\\)<\/span> and <span class=\"math inline\">\\(\\ell_1\\)<\/span> terms, although they are found among other confusing results (such as a <code>switch<\/code> statement).<\/p>\n<p>As with most graphs produced by symbolic algebra systems, we need to understand how operations and objects are expressed in a graph and exactly which ones are relevant to us. After doing so, we can develop a means of finding what we want. 
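<\/p>\n<p>One simple starting point is to walk the top-level sum of the graph and ask which summands contain a non-smooth operation such as <code>abs<\/code>; those become candidates for <span class=\"math inline\">\\(\\phi\\)<\/span> and the rest for <span class=\"math inline\">\\(l\\)<\/span>. The sketch below does this on a hand-written, simplified stand-in for the graph printed above (the <code>switch<\/code> and some constant terms are omitted), encoded as nested Python tuples of the form <code>(op_name, *inputs)<\/code>. This toy encoding and the helper names are assumptions for illustration, not Theano\u2019s graph API; a real implementation would instead traverse each variable\u2019s <code>owner<\/code> (an <code>Apply<\/code> node) and inspect its <code>op<\/code> and <code>inputs<\/code>.<\/p>\n<div class=\"sourceCode\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"># Hypothetical toy encoding (not Theano's API): a node is a tuple
# (op_name, *inputs) and a leaf is a plain string.


def contains_op(node, op_name):
    # Depth-first search for an operation called op_name at or below node.
    if isinstance(node, tuple):
        op, *inputs = node
        return op == op_name or any(contains_op(i, op_name) for i in inputs)
    return False


def split_terms(expr, nonsmooth_ops=('abs',)):
    # Separate the top-level summands of expr into smooth and non-smooth groups.
    terms = list(expr[1:]) if expr[0] == 'add' else [expr]
    smooth, nonsmooth = [], []
    for term in terms:
        if any(contains_op(term, op) for op in nonsmooth_ops):
            nonsmooth.append(term)
        else:
            smooth.append(term)
    return smooth, nonsmooth


# Simplified stand-in for lasso_model.logpt: Laplace prior term + Normal term.
toy_logpt = (
    'add',
    ('sum', ('sub', ('neg', ('log', '2')),
             ('true_div', ('abs', ('sub', 'beta', '0')), '1'))),
    ('sum', ('true_div',
             ('mul', '-1.0', ('pow', ('sub', 'y', ('dot', 'X', 'beta')), '2')),
             '2.0')),
)

smooth, nonsmooth = split_terms(toy_logpt)
print('smooth terms:', len(smooth), 'non-smooth terms:', len(nonsmooth))
<\/code><\/pre><\/div>\n<p>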
The <a href=\"http:\/\/deeplearning.net\/software\/theano\/tutorial\/printing_drawing.html#debug-print\">debug printout<\/a> is often a better visual summary of graphs, since it expresses branches clearly.<\/p>\n<div class=\"sourceCode\" id=\"cb10\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb10-1\"><a href=\"#cb10-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> tt.printing.debugprint(lasso_model.logpt)<\/span>\n<span id=\"cb10-2\"><a href=\"#cb10-2\" aria-hidden=\"true\"><\/a>Elemwise{add,no_inplace} [<span class=\"bu\">id<\/span> A] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-3\"><a href=\"#cb10-3\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>Sum{acc_dtype<span class=\"op\">=<\/span>float64} [<span class=\"bu\">id<\/span> B] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-4\"><a href=\"#cb10-4\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>Sum{acc_dtype<span class=\"op\">=<\/span>float64} [<span class=\"bu\">id<\/span> C] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-5\"><a href=\"#cb10-5\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>Elemwise{sub,no_inplace} [<span class=\"bu\">id<\/span> D] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-6\"><a href=\"#cb10-6\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>     <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> E] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-7\"><a href=\"#cb10-7\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>Elemwise{neg,no_inplace} [<span class=\"bu\">id<\/span> F] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-8\"><a href=\"#cb10-8\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>     <span class=\"op\">|<\/span>   <span 
class=\"op\">|<\/span>Elemwise{log,no_inplace} [<span class=\"bu\">id<\/span> G] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-9\"><a href=\"#cb10-9\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>     <span class=\"op\">|<\/span>     <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">2<\/span>} [<span class=\"bu\">id<\/span> H]<\/span>\n<span id=\"cb10-10\"><a href=\"#cb10-10\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>     <span class=\"op\">|<\/span>Elemwise{true_div,no_inplace} [<span class=\"bu\">id<\/span> I] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-11\"><a href=\"#cb10-11\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>       <span class=\"op\">|<\/span>Elemwise{abs_,no_inplace} [<span class=\"bu\">id<\/span> J] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-12\"><a href=\"#cb10-12\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>Elemwise{sub,no_inplace} [<span class=\"bu\">id<\/span> K] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-13\"><a href=\"#cb10-13\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>       <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>beta [<span class=\"bu\">id<\/span> L]<\/span>\n<span id=\"cb10-14\"><a href=\"#cb10-14\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>       <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> M] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-15\"><a href=\"#cb10-15\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>       <span class=\"op\">|<\/span>     <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">0<\/span>} [<span class=\"bu\">id<\/span> N]<\/span>\n<span id=\"cb10-16\"><a href=\"#cb10-16\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>       <span class=\"op\">|<\/span>DimShuffle{x} [<span 
class=\"bu\">id<\/span> O] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-17\"><a href=\"#cb10-17\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>         <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">1<\/span>} [<span class=\"bu\">id<\/span> P]<\/span>\n<span id=\"cb10-18\"><a href=\"#cb10-18\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>Sum{acc_dtype<span class=\"op\">=<\/span>float64} [<span class=\"bu\">id<\/span> Q] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-19\"><a href=\"#cb10-19\" aria-hidden=\"true\"><\/a>   <span class=\"op\">|<\/span>Sum{acc_dtype<span class=\"op\">=<\/span>float64} [<span class=\"bu\">id<\/span> R] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-20\"><a href=\"#cb10-20\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span>Elemwise{switch,no_inplace} [<span class=\"bu\">id<\/span> S] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-21\"><a href=\"#cb10-21\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> T] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-22\"><a href=\"#cb10-22\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">1<\/span>} [<span class=\"bu\">id<\/span> P]<\/span>\n<span id=\"cb10-23\"><a href=\"#cb10-23\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span>Elemwise{true_div,no_inplace} [<span class=\"bu\">id<\/span> U] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-24\"><a href=\"#cb10-24\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>Elemwise{add,no_inplace} [<span class=\"bu\">id<\/span> V] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-25\"><a href=\"#cb10-25\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span 
class=\"op\">|<\/span>Elemwise{mul,no_inplace} [<span class=\"bu\">id<\/span> W] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-26\"><a href=\"#cb10-26\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> X] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-27\"><a href=\"#cb10-27\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>TensorConstant{<span class=\"op\">-<\/span><span class=\"fl\">1.0<\/span>} [<span class=\"bu\">id<\/span> Y]<\/span>\n<span id=\"cb10-28\"><a href=\"#cb10-28\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>Elemwise{<span class=\"bu\">pow<\/span>,no_inplace} [<span class=\"bu\">id<\/span> Z] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-29\"><a href=\"#cb10-29\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>Elemwise{sub,no_inplace} [<span class=\"bu\">id<\/span> BA] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-30\"><a href=\"#cb10-30\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>y [<span class=\"bu\">id<\/span> BB]<\/span>\n<span id=\"cb10-31\"><a href=\"#cb10-31\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>dot [<span class=\"bu\">id<\/span> BC] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-32\"><a href=\"#cb10-32\" 
aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>X [<span class=\"bu\">id<\/span> BD]<\/span>\n<span id=\"cb10-33\"><a href=\"#cb10-33\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>beta [<span class=\"bu\">id<\/span> L]<\/span>\n<span id=\"cb10-34\"><a href=\"#cb10-34\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> BE] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-35\"><a href=\"#cb10-35\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>     <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">2<\/span>} [<span class=\"bu\">id<\/span> H]<\/span>\n<span id=\"cb10-36\"><a href=\"#cb10-36\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> BF] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-37\"><a href=\"#cb10-37\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>Elemwise{log,no_inplace} [<span class=\"bu\">id<\/span> BG] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-38\"><a href=\"#cb10-38\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>     <span class=\"op\">|<\/span>TensorConstant{<span class=\"fl\">0.159154943092<\/span>} [<span class=\"bu\">id<\/span> BH]<\/span>\n<span id=\"cb10-39\"><a href=\"#cb10-39\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span 
class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> BI] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-40\"><a href=\"#cb10-40\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>TensorConstant{<span class=\"fl\">2.0<\/span>} [<span class=\"bu\">id<\/span> BJ]<\/span>\n<span id=\"cb10-41\"><a href=\"#cb10-41\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> BK] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb10-42\"><a href=\"#cb10-42\" aria-hidden=\"true\"><\/a>         <span class=\"op\">|<\/span>TensorConstant{<span class=\"op\">-<\/span>inf} [<span class=\"bu\">id<\/span> BL]<\/span><\/code><\/pre><\/div>\n<p>We see that the top-most operator is an <code>Elemwise<\/code> that applies the scalar <code>add<\/code> operation. This is the \u201c<span class=\"math inline\">\\(+\\)<\/span>\u201d in <span class=\"math inline\">\\(l + \\phi\\)<\/span>. If we were to consider the inputs of this operator as candidates for <span class=\"math inline\">\\(l\\)<\/span> and <span class=\"math inline\">\\(\\phi\\)<\/span>, then we could do the following:<\/p>\n<div class=\"sourceCode\" id=\"cb11\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb11-1\"><a href=\"#cb11-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> <span class=\"bu\">print<\/span>(lasso_model.logpt.owner.inputs)<\/span>\n<span id=\"cb11-2\"><a href=\"#cb11-2\" aria-hidden=\"true\"><\/a>[Sum{acc_dtype<span class=\"op\">=<\/span>float64}<span class=\"fl\">.0<\/span>, Sum{acc_dtype<span class=\"op\">=<\/span>float64}<span class=\"fl\">.0<\/span>]<\/span><\/code><\/pre><\/div>\n<p>Starting from the sub-graphs of each term, we could then search for any non-smooth functions that have known closed-form proximal operators. 
In our case, we only consider the absolute value function.<\/p>\n<div class=\"sourceCode\" id=\"cb12\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb12-1\"><a href=\"#cb12-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> get_abs_between(input_node):<\/span>\n<span id=\"cb12-2\"><a href=\"#cb12-2\" aria-hidden=\"true\"><\/a>    <span class=\"co\">&quot;&quot;&quot; Search for `abs` in the operations between our input and the<\/span><\/span>\n<span id=\"cb12-3\"><a href=\"#cb12-3\" aria-hidden=\"true\"><\/a><span class=\"co\">    log-likelihood output node.<\/span><\/span>\n<span id=\"cb12-4\"><a href=\"#cb12-4\" aria-hidden=\"true\"><\/a><span class=\"co\">    &quot;&quot;&quot;<\/span><\/span>\n<span id=\"cb12-5\"><a href=\"#cb12-5\" aria-hidden=\"true\"><\/a>    term_ops <span class=\"op\">=<\/span> <span class=\"bu\">list<\/span>(tt.gof.graph.ops([input_node],<\/span>\n<span id=\"cb12-6\"><a href=\"#cb12-6\" aria-hidden=\"true\"><\/a>                                     [lasso_model.logpt]))<\/span>\n<span id=\"cb12-7\"><a href=\"#cb12-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb12-8\"><a href=\"#cb12-8\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Is there an absolute value in there? Return a list (not a lazy `filter`) so `len` works.<\/span><\/span>\n<span id=\"cb12-9\"><a href=\"#cb12-9\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> <span class=\"bu\">list<\/span>(<span class=\"bu\">filter<\/span>(<span class=\"kw\">lambda<\/span> x: x.op <span class=\"kw\">is<\/span> tt.abs_, term_ops))<\/span>\n<span id=\"cb12-10\"><a href=\"#cb12-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb12-11\"><a href=\"#cb12-11\" aria-hidden=\"true\"><\/a>abs_res <span class=\"op\">=<\/span> [(get_abs_between(in_), in_)<\/span>\n<span id=\"cb12-12\"><a href=\"#cb12-12\" aria-hidden=\"true\"><\/a>           <span class=\"cf\">for<\/span> in_ <span class=\"kw\">in<\/span> lasso_model.logpt.owner.inputs]<\/span>\n<span id=\"cb12-13\"><a href=\"#cb12-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb12-14\"><a 
href=\"#cb12-14\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> r_ <span class=\"kw\">in<\/span> abs_res:<\/span>\n<span id=\"cb12-15\"><a href=\"#cb12-15\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> <span class=\"bu\">len<\/span>(r_[<span class=\"dv\">0<\/span>]) <span class=\"op\">==<\/span> <span class=\"dv\">0<\/span>:<\/span>\n<span id=\"cb12-16\"><a href=\"#cb12-16\" aria-hidden=\"true\"><\/a>        phi <span class=\"op\">=<\/span> r_[<span class=\"dv\">1<\/span>]<\/span>\n<span id=\"cb12-17\"><a href=\"#cb12-17\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">else<\/span>:<\/span>\n<span id=\"cb12-18\"><a href=\"#cb12-18\" aria-hidden=\"true\"><\/a>        logp <span class=\"op\">=<\/span> r_[<span class=\"dv\">1<\/span>]<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"cb13\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb13-1\"><a href=\"#cb13-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> tt.printing.debugprint(logp)<\/span>\n<span id=\"cb13-2\"><a href=\"#cb13-2\" aria-hidden=\"true\"><\/a>Sum{acc_dtype<span class=\"op\">=<\/span>float64} [<span class=\"bu\">id<\/span> A] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-3\"><a href=\"#cb13-3\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>Sum{acc_dtype<span class=\"op\">=<\/span>float64} [<span class=\"bu\">id<\/span> B] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-4\"><a href=\"#cb13-4\" aria-hidden=\"true\"><\/a>   <span class=\"op\">|<\/span>Elemwise{switch,no_inplace} [<span class=\"bu\">id<\/span> C] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-5\"><a href=\"#cb13-5\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> D] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-6\"><a href=\"#cb13-6\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span 
class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">1<\/span>} [<span class=\"bu\">id<\/span> E]<\/span>\n<span id=\"cb13-7\"><a href=\"#cb13-7\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span>Elemwise{true_div,no_inplace} [<span class=\"bu\">id<\/span> F] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-8\"><a href=\"#cb13-8\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>Elemwise{add,no_inplace} [<span class=\"bu\">id<\/span> G] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-9\"><a href=\"#cb13-9\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>Elemwise{mul,no_inplace} [<span class=\"bu\">id<\/span> H] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-10\"><a href=\"#cb13-10\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> I] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-11\"><a href=\"#cb13-11\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>TensorConstant{<span class=\"op\">-<\/span><span class=\"fl\">1.0<\/span>} [<span class=\"bu\">id<\/span> J]<\/span>\n<span id=\"cb13-12\"><a href=\"#cb13-12\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>Elemwise{<span class=\"bu\">pow<\/span>,no_inplace} [<span class=\"bu\">id<\/span> K] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-13\"><a href=\"#cb13-13\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>Elemwise{sub,no_inplace} [<span 
class=\"bu\">id<\/span> L] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-14\"><a href=\"#cb13-14\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>y [<span class=\"bu\">id<\/span> M]<\/span>\n<span id=\"cb13-15\"><a href=\"#cb13-15\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>dot [<span class=\"bu\">id<\/span> N] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-16\"><a href=\"#cb13-16\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>X [<span class=\"bu\">id<\/span> O]<\/span>\n<span id=\"cb13-17\"><a href=\"#cb13-17\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>beta [<span class=\"bu\">id<\/span> P]<\/span>\n<span id=\"cb13-18\"><a href=\"#cb13-18\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> Q] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-19\"><a href=\"#cb13-19\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>     <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">2<\/span>} [<span class=\"bu\">id<\/span> R]<\/span>\n<span id=\"cb13-20\"><a href=\"#cb13-20\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> S] <span 
class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-21\"><a href=\"#cb13-21\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>Elemwise{log,no_inplace} [<span class=\"bu\">id<\/span> T] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-22\"><a href=\"#cb13-22\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>     <span class=\"op\">|<\/span>TensorConstant{<span class=\"fl\">0.159154943092<\/span>} [<span class=\"bu\">id<\/span> U]<\/span>\n<span id=\"cb13-23\"><a href=\"#cb13-23\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> V] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-24\"><a href=\"#cb13-24\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>TensorConstant{<span class=\"fl\">2.0<\/span>} [<span class=\"bu\">id<\/span> W]<\/span>\n<span id=\"cb13-25\"><a href=\"#cb13-25\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> X] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-26\"><a href=\"#cb13-26\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span>TensorConstant{<span class=\"op\">-<\/span>inf} [<span class=\"bu\">id<\/span> Y]<\/span>\n<span id=\"cb13-27\"><a href=\"#cb13-27\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> tt.printing.debugprint(phi)<\/span>\n<span id=\"cb13-28\"><a href=\"#cb13-28\" aria-hidden=\"true\"><\/a>Sum{acc_dtype<span class=\"op\">=<\/span>float64} [<span class=\"bu\">id<\/span> A] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-29\"><a href=\"#cb13-29\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>Sum{acc_dtype<span class=\"op\">=<\/span>float64} [<span class=\"bu\">id<\/span> B] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span 
id=\"cb13-30\"><a href=\"#cb13-30\" aria-hidden=\"true\"><\/a>   <span class=\"op\">|<\/span>Elemwise{sub,no_inplace} [<span class=\"bu\">id<\/span> C] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-31\"><a href=\"#cb13-31\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> D] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-32\"><a href=\"#cb13-32\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>Elemwise{neg,no_inplace} [<span class=\"bu\">id<\/span> E] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-33\"><a href=\"#cb13-33\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>Elemwise{log,no_inplace} [<span class=\"bu\">id<\/span> F] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-34\"><a href=\"#cb13-34\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span>     <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">2<\/span>} [<span class=\"bu\">id<\/span> G]<\/span>\n<span id=\"cb13-35\"><a href=\"#cb13-35\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span>Elemwise{true_div,no_inplace} [<span class=\"bu\">id<\/span> H] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-36\"><a href=\"#cb13-36\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span>Elemwise{abs_,no_inplace} [<span class=\"bu\">id<\/span> I] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-37\"><a href=\"#cb13-37\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>Elemwise{sub,no_inplace} [<span class=\"bu\">id<\/span> J] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-38\"><a href=\"#cb13-38\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>beta [<span class=\"bu\">id<\/span> K]<\/span>\n<span id=\"cb13-39\"><a href=\"#cb13-39\" 
aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span>   <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> L] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-40\"><a href=\"#cb13-40\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span>     <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">0<\/span>} [<span class=\"bu\">id<\/span> M]<\/span>\n<span id=\"cb13-41\"><a href=\"#cb13-41\" aria-hidden=\"true\"><\/a>       <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> N] <span class=\"st\">&#39;&#39;<\/span><\/span>\n<span id=\"cb13-42\"><a href=\"#cb13-42\" aria-hidden=\"true\"><\/a>         <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">1<\/span>} [<span class=\"bu\">id<\/span> O]<\/span><\/code><\/pre><\/div>\n<p>The above approach is still too limiting; we need something more robust. For instance, our logic could fail on a graph for the expression <span class=\"math inline\">\\(\\eta (l + \\phi) + 1\\)<\/span>, because the inputs of its top-most <code>add<\/code> are <span class=\"math inline\">\\(\\eta (l + \\phi)\\)<\/span> and <span class=\"math inline\">\\(1\\)<\/span> rather than <span class=\"math inline\">\\(l\\)<\/span> and <span class=\"math inline\">\\(\\phi\\)<\/span>, whereas a graph for the equivalent expression <span class=\"math inline\">\\(\\eta l + \\eta \\phi + 1\\)<\/span> might succeed. Weaknesses like this are inherent to naive approaches such as ours. Furthermore, extending this approach to cover enough cases would likely produce complicated, hard-to-follow code.<\/p>\n<p>The appropriate computational tools come from graph unification and term rewriting, as well as from functional and logic programming. Luckily, Theano provides some basic unification capabilities through its <code>PatternSub<\/code> class.<\/p>\n<p><code>PatternSub<\/code> works within the context of Theano <a href=\"http:\/\/deeplearning.net\/software\/theano\/optimizations.html\">graph optimization<\/a>. Graph optimizations perform the common symbolic operations of reduction\/simplification and rewriting. 
Consider the <code>phi<\/code> variable: the print-outs show a subtraction of <span class=\"math inline\">\\(0\\)<\/span> from <code>beta<\/code>, which does nothing, so, in a basic sense, the graph clearly hasn\u2019t been simplified yet.<\/p>\n<p>Theano already includes many standard algebraic simplifications, and by creating our own graph optimizations we can provide the more advanced functionality we\u2019ve been alluding to.<\/p>\n<div class=\"example\" data-markdown=\"\" data-title-name=\"[Algebraic Graph Optimization]\">\n<p>As a quick demonstration, we\u2019ll make replacement patterns that distribute multiplication over two forms of addition: <code>sum<\/code> and <code>add<\/code>.<\/p>\n<div class=\"sourceCode\" id=\"cb14\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb14-1\"><a href=\"#cb14-1\" aria-hidden=\"true\"><\/a>test_a_tt <span class=\"op\">=<\/span> tt.as_tensor_variable(<span class=\"dv\">5<\/span>, name<span class=\"op\">=<\/span><span class=\"st\">&#39;a&#39;<\/span>)<\/span>\n<span id=\"cb14-2\"><a href=\"#cb14-2\" aria-hidden=\"true\"><\/a>test_b_tt <span class=\"op\">=<\/span> tt.as_tensor_variable(<span class=\"dv\">2<\/span>, name<span class=\"op\">=<\/span><span class=\"st\">&#39;b&#39;<\/span>)<\/span>\n<span id=\"cb14-3\"><a href=\"#cb14-3\" aria-hidden=\"true\"><\/a>test_c_tt <span class=\"op\">=<\/span> tt.as_tensor_variable(np.r_[<span class=\"dv\">1<\/span>, <span class=\"dv\">2<\/span>], name<span class=\"op\">=<\/span><span class=\"st\">&#39;c&#39;<\/span>)<\/span>\n<span id=\"cb14-4\"><a href=\"#cb14-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-5\"><a href=\"#cb14-5\" aria-hidden=\"true\"><\/a>test_exprs_tt <span class=\"op\">=<\/span> (test_a_tt <span class=\"op\">*<\/span> test_b_tt,)<\/span>\n<span id=\"cb14-6\"><a href=\"#cb14-6\" aria-hidden=\"true\"><\/a>test_exprs_tt <span class=\"op\">+=<\/span> (test_a_tt <span class=\"op\">*<\/span> (test_b_tt 
<span class=\"op\">+<\/span> test_a_tt),)<\/span>\n<span id=\"cb14-7\"><a href=\"#cb14-7\" aria-hidden=\"true\"><\/a>test_exprs_tt <span class=\"op\">+=<\/span> (test_a_tt <span class=\"op\">*<\/span> (test_c_tt <span class=\"op\">+<\/span> test_a_tt),)<\/span>\n<span id=\"cb14-8\"><a href=\"#cb14-8\" aria-hidden=\"true\"><\/a>test_exprs_tt <span class=\"op\">+=<\/span> (test_a_tt <span class=\"op\">*<\/span> (test_c_tt <span class=\"op\">+<\/span> test_c_tt),)<\/span>\n<span id=\"cb14-9\"><a href=\"#cb14-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb14-10\"><a href=\"#cb14-10\" aria-hidden=\"true\"><\/a>mul_dist_pat_tt <span class=\"op\">=<\/span> (tt.gof.opt.PatternSub(<\/span>\n<span id=\"cb14-11\"><a href=\"#cb14-11\" aria-hidden=\"true\"><\/a>    (tt.mul, <span class=\"st\">&#39;x&#39;<\/span>, (tt.<span class=\"bu\">sum<\/span>, <span class=\"st\">&#39;y&#39;<\/span>, <span class=\"st\">&#39;z&#39;<\/span>)),<\/span>\n<span id=\"cb14-12\"><a href=\"#cb14-12\" aria-hidden=\"true\"><\/a>    (tt.<span class=\"bu\">sum<\/span>, (tt.mul, <span class=\"st\">&#39;x&#39;<\/span>, <span class=\"st\">&#39;y&#39;<\/span>), (tt.mul, <span class=\"st\">&#39;x&#39;<\/span>, <span class=\"st\">&#39;z&#39;<\/span>))<\/span>\n<span id=\"cb14-13\"><a href=\"#cb14-13\" aria-hidden=\"true\"><\/a>),)<\/span>\n<span id=\"cb14-14\"><a href=\"#cb14-14\" aria-hidden=\"true\"><\/a>mul_dist_pat_tt <span class=\"op\">+=<\/span> (tt.gof.opt.PatternSub(<\/span>\n<span id=\"cb14-15\"><a href=\"#cb14-15\" aria-hidden=\"true\"><\/a>    (tt.mul, <span class=\"st\">&#39;x&#39;<\/span>, (tt.add, <span class=\"st\">&#39;y&#39;<\/span>, <span class=\"st\">&#39;z&#39;<\/span>)),<\/span>\n<span id=\"cb14-16\"><a href=\"#cb14-16\" aria-hidden=\"true\"><\/a>    (tt.add, (tt.mul, <span class=\"st\">&#39;x&#39;<\/span>, <span class=\"st\">&#39;y&#39;<\/span>), (tt.mul, <span class=\"st\">&#39;x&#39;<\/span>, <span class=\"st\">&#39;z&#39;<\/span>))<\/span>\n<span id=\"cb14-17\"><a 
href=\"#cb14-17\" aria-hidden=\"true\"><\/a>),)<\/span><\/code><\/pre><\/div>\n<p><code>EquilibriumOptimizer<\/code> applies a set of substitutions repeatedly until the graph stops changing, i.e. until it is in a fully reduced form; its <code>max_use_ratio<\/code> argument caps how many times each substitution may be applied, relative to the size of the graph. We also wrap our test expressions in a <code>FunctionGraph<\/code>, the graph container that optimizers operate on.<\/p>\n<div class=\"sourceCode\" id=\"cb15\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb15-1\"><a href=\"#cb15-1\" aria-hidden=\"true\"><\/a>test_sub_eqz_opt_tt <span class=\"op\">=<\/span> tt.gof.opt.EquilibriumOptimizer(<\/span>\n<span id=\"cb15-2\"><a href=\"#cb15-2\" aria-hidden=\"true\"><\/a>    mul_dist_pat_tt, max_use_ratio<span class=\"op\">=<\/span><span class=\"dv\">10<\/span>)<\/span>\n<span id=\"cb15-3\"><a href=\"#cb15-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb15-4\"><a href=\"#cb15-4\" aria-hidden=\"true\"><\/a>test_fgraph_tt <span class=\"op\">=<\/span> tt.gof.fg.FunctionGraph(<\/span>\n<span id=\"cb15-5\"><a href=\"#cb15-5\" aria-hidden=\"true\"><\/a>    tt.gof.graph.inputs(test_exprs_tt), test_exprs_tt)<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"cb16\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb16-1\"><a href=\"#cb16-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> tt.printing.debugprint(test_fgraph_tt)<\/span>\n<span id=\"cb16-2\"><a href=\"#cb16-2\" aria-hidden=\"true\"><\/a>Elemwise{mul,no_inplace} [<span class=\"bu\">id<\/span> A] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">5<\/span><\/span>\n<span id=\"cb16-3\"><a href=\"#cb16-3\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">5<\/span>} [<span class=\"bu\">id<\/span> B]<\/span>\n<span id=\"cb16-4\"><a href=\"#cb16-4\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">2<\/span>} [<span class=\"bu\">id<\/span> C]<\/span>\n<span id=\"cb16-5\"><a href=\"#cb16-5\" aria-hidden=\"true\"><\/a>Elemwise{mul,no_inplace} [<span class=\"bu\">id<\/span> D] 
<span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">8<\/span><\/span>\n<span id=\"cb16-6\"><a href=\"#cb16-6\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">5<\/span>} [<span class=\"bu\">id<\/span> B]<\/span>\n<span id=\"cb16-7\"><a href=\"#cb16-7\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>Elemwise{add,no_inplace} [<span class=\"bu\">id<\/span> E] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">4<\/span><\/span>\n<span id=\"cb16-8\"><a href=\"#cb16-8\" aria-hidden=\"true\"><\/a>   <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">2<\/span>} [<span class=\"bu\">id<\/span> C]<\/span>\n<span id=\"cb16-9\"><a href=\"#cb16-9\" aria-hidden=\"true\"><\/a>   <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">5<\/span>} [<span class=\"bu\">id<\/span> B]<\/span>\n<span id=\"cb16-10\"><a href=\"#cb16-10\" aria-hidden=\"true\"><\/a>Elemwise{mul,no_inplace} [<span class=\"bu\">id<\/span> F] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">9<\/span><\/span>\n<span id=\"cb16-11\"><a href=\"#cb16-11\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> G] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">3<\/span><\/span>\n<span id=\"cb16-12\"><a href=\"#cb16-12\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">5<\/span>} [<span class=\"bu\">id<\/span> B]<\/span>\n<span id=\"cb16-13\"><a href=\"#cb16-13\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>Elemwise{add,no_inplace} [<span class=\"bu\">id<\/span> H] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">7<\/span><\/span>\n<span id=\"cb16-14\"><a href=\"#cb16-14\" aria-hidden=\"true\"><\/a>   <span class=\"op\">|<\/span>TensorConstant{[<span class=\"dv\">1<\/span> <span class=\"dv\">2<\/span>]} [<span class=\"bu\">id<\/span> I]<\/span>\n<span id=\"cb16-15\"><a href=\"#cb16-15\" 
aria-hidden=\"true\"><\/a>   <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> J] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">2<\/span><\/span>\n<span id=\"cb16-16\"><a href=\"#cb16-16\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">5<\/span>} [<span class=\"bu\">id<\/span> B]<\/span>\n<span id=\"cb16-17\"><a href=\"#cb16-17\" aria-hidden=\"true\"><\/a>Elemwise{mul,no_inplace} [<span class=\"bu\">id<\/span> K] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">6<\/span><\/span>\n<span id=\"cb16-18\"><a href=\"#cb16-18\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> L] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">1<\/span><\/span>\n<span id=\"cb16-19\"><a href=\"#cb16-19\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">5<\/span>} [<span class=\"bu\">id<\/span> B]<\/span>\n<span id=\"cb16-20\"><a href=\"#cb16-20\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>Elemwise{add,no_inplace} [<span class=\"bu\">id<\/span> M] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">0<\/span><\/span>\n<span id=\"cb16-21\"><a href=\"#cb16-21\" aria-hidden=\"true\"><\/a>   <span class=\"op\">|<\/span>TensorConstant{[<span class=\"dv\">1<\/span> <span class=\"dv\">2<\/span>]} [<span class=\"bu\">id<\/span> I]<\/span>\n<span id=\"cb16-22\"><a href=\"#cb16-22\" aria-hidden=\"true\"><\/a>   <span class=\"op\">|<\/span>TensorConstant{[<span class=\"dv\">1<\/span> <span class=\"dv\">2<\/span>]} [<span class=\"bu\">id<\/span> I]<\/span><\/code><\/pre><\/div>\n<p>Applying the optimizer modifies the <code>FunctionGraph<\/code> in place, so afterward <code>test_fgraph_tt<\/code> itself should contain the replacements.<\/p>\n<div class=\"sourceCode\" id=\"cb17\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb17-1\"><a href=\"#cb17-1\" 
aria-hidden=\"true\"><\/a>test_fgraph_opt <span class=\"op\">=<\/span> test_sub_eqz_opt_tt.optimize(test_fgraph_tt)<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"cb18\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb18-1\"><a href=\"#cb18-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> tt.printing.debugprint(test_fgraph_tt)<\/span>\n<span id=\"cb18-2\"><a href=\"#cb18-2\" aria-hidden=\"true\"><\/a>Elemwise{mul,no_inplace} [<span class=\"bu\">id<\/span> A] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">5<\/span><\/span>\n<span id=\"cb18-3\"><a href=\"#cb18-3\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">5<\/span>} [<span class=\"bu\">id<\/span> B]<\/span>\n<span id=\"cb18-4\"><a href=\"#cb18-4\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">2<\/span>} [<span class=\"bu\">id<\/span> C]<\/span>\n<span id=\"cb18-5\"><a href=\"#cb18-5\" aria-hidden=\"true\"><\/a>Elemwise{add,no_inplace} [<span class=\"bu\">id<\/span> D] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">10<\/span><\/span>\n<span id=\"cb18-6\"><a href=\"#cb18-6\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>Elemwise{mul,no_inplace} [<span class=\"bu\">id<\/span> E] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">4<\/span><\/span>\n<span id=\"cb18-7\"><a href=\"#cb18-7\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">5<\/span>} [<span class=\"bu\">id<\/span> B]<\/span>\n<span id=\"cb18-8\"><a href=\"#cb18-8\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">2<\/span>} [<span class=\"bu\">id<\/span> C]<\/span>\n<span id=\"cb18-9\"><a href=\"#cb18-9\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>Elemwise{mul,no_inplace} [<span class=\"bu\">id<\/span> F] <span 
class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">3<\/span><\/span>\n<span id=\"cb18-10\"><a href=\"#cb18-10\" aria-hidden=\"true\"><\/a>   <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">5<\/span>} [<span class=\"bu\">id<\/span> B]<\/span>\n<span id=\"cb18-11\"><a href=\"#cb18-11\" aria-hidden=\"true\"><\/a>   <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">5<\/span>} [<span class=\"bu\">id<\/span> B]<\/span>\n<span id=\"cb18-12\"><a href=\"#cb18-12\" aria-hidden=\"true\"><\/a>Elemwise{add,no_inplace} [<span class=\"bu\">id<\/span> G] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">12<\/span><\/span>\n<span id=\"cb18-13\"><a href=\"#cb18-13\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>Elemwise{mul,no_inplace} [<span class=\"bu\">id<\/span> H] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">9<\/span><\/span>\n<span id=\"cb18-14\"><a href=\"#cb18-14\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> I] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">2<\/span><\/span>\n<span id=\"cb18-15\"><a href=\"#cb18-15\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">5<\/span>} [<span class=\"bu\">id<\/span> B]<\/span>\n<span id=\"cb18-16\"><a href=\"#cb18-16\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>TensorConstant{[<span class=\"dv\">1<\/span> <span class=\"dv\">2<\/span>]} [<span class=\"bu\">id<\/span> J]<\/span>\n<span id=\"cb18-17\"><a href=\"#cb18-17\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>Elemwise{mul,no_inplace} [<span class=\"bu\">id<\/span> K] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">8<\/span><\/span>\n<span id=\"cb18-18\"><a href=\"#cb18-18\" aria-hidden=\"true\"><\/a>   <span class=\"op\">|<\/span>DimShuffle{x} [<span 
class=\"bu\">id<\/span> I] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">2<\/span><\/span>\n<span id=\"cb18-19\"><a href=\"#cb18-19\" aria-hidden=\"true\"><\/a>   <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> L] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">1<\/span><\/span>\n<span id=\"cb18-20\"><a href=\"#cb18-20\" aria-hidden=\"true\"><\/a>     <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">5<\/span>} [<span class=\"bu\">id<\/span> B]<\/span>\n<span id=\"cb18-21\"><a href=\"#cb18-21\" aria-hidden=\"true\"><\/a>Elemwise{add,no_inplace} [<span class=\"bu\">id<\/span> M] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">11<\/span><\/span>\n<span id=\"cb18-22\"><a href=\"#cb18-22\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>Elemwise{mul,no_inplace} [<span class=\"bu\">id<\/span> N] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">7<\/span><\/span>\n<span id=\"cb18-23\"><a href=\"#cb18-23\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> O] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">0<\/span><\/span>\n<span id=\"cb18-24\"><a href=\"#cb18-24\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>TensorConstant{<span class=\"dv\">5<\/span>} [<span class=\"bu\">id<\/span> B]<\/span>\n<span id=\"cb18-25\"><a href=\"#cb18-25\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span> <span class=\"op\">|<\/span>TensorConstant{[<span class=\"dv\">1<\/span> <span class=\"dv\">2<\/span>]} [<span class=\"bu\">id<\/span> J]<\/span>\n<span id=\"cb18-26\"><a href=\"#cb18-26\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>Elemwise{mul,no_inplace} [<span class=\"bu\">id<\/span> P] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">6<\/span><\/span>\n<span id=\"cb18-27\"><a href=\"#cb18-27\" 
aria-hidden=\"true\"><\/a>   <span class=\"op\">|<\/span>DimShuffle{x} [<span class=\"bu\">id<\/span> O] <span class=\"st\">&#39;&#39;<\/span>   <span class=\"dv\">0<\/span><\/span>\n<span id=\"cb18-28\"><a href=\"#cb18-28\" aria-hidden=\"true\"><\/a>   <span class=\"op\">|<\/span>TensorConstant{[<span class=\"dv\">1<\/span> <span class=\"dv\">2<\/span>]} [<span class=\"bu\">id<\/span> J]<\/span><\/code><\/pre><\/div>\n<\/div>\n<p>Even more symbolic capabilities might be needed to [efficiently] achieve the functionality we desire. Standalone libraries like SymPy and <a href=\"https:\/\/github.com\/logpy\/logpy\/\">LogPy<\/a> can be adapted to Theano graphs to provide these capabilities, although a direct implementation in Theano may be preferable.<\/p>\n<p>Finally, let\u2019s briefly imagine how convexity could be determined symbolically. For twice-differentiable terms, we could start with a simple second-derivative test. Within Theano, a \u201csecond derivative\u201d can be obtained using the <code>hessian<\/code> function, and <code>theano.sandbox.linalg<\/code> contains <code>Optimizer<\/code> hints for matrix positivity and other properties relevant to determining convexity.<\/p>\n<div class=\"remark\" data-markdown=\"\" data-title-name=\"\">\n<p>Other great examples of linear-algebra-themed optimizations are in <code>theano.sandbox.linalg<\/code>: for instance, <code>no_transpose_symmetric<\/code>. Some of these demonstrate exactly how straightforward adding algebraic features can be.<\/p>\n<\/div>\n<p>Although our convexity testing idea is far too simple for some functions, the point is that the basic tools necessary for work in this direction are already in place. 
With the logic programming and symbolic libraries mentioned earlier, a robust implementation of the convex function calculus could well be within reach.<\/p>\n<\/section>\n<section id=\"discussion\" class=\"level1\">\n<h1>Discussion<\/h1>\n<p>We\u2019ve sketched out some ideas and tools with which one could develop a robust estimation platform guided by the more abstract mathematical frameworks from which new and efficient methods are produced.<\/p>\n<p>Some key steps may require the integration of a fully featured symbolic algebra system. Along these lines, connections between Theano, SymPy and LogPy, along with many other important aspects of the topics discussed here, have been explored in <span class=\"citation\" data-cites=\"rocklin_mathematically_2013\">Rocklin (2013)<\/span>.<\/p>\n<p>Besides the automation of proximal algorithms themselves, there are areas of application involving very large and complex models, such as those arising in deep learning. How might we consider the operator splitting of ADMM within deeply layered or hierarchical models <span class=\"citation\" data-cites=\"polson_statistical_2015\">(Polson, Willard, and Heidari 2015)<\/span>? At which levels and on which terms should the splitting be performed? In addition to attempting the potentially unwieldy mathematics arising from such questions, we can, by imbuing these symbolic tools with more mathematical awareness, at least experiment in these directions and quickly offer numerical solutions. This is, in part, an edge that modern machine learning has enjoyed and statistics has not.<\/p>\n<p>Before closing, a closely related and interesting set of ideas is worth mentioning: the possibility of encoding more symbolic knowledge into probabilistic programming platforms like PyMC3. Using the same optimization mechanisms as in the examples here, simple distributional relationships can be encoded. 
For instance, the convolution of normally distributed random variables:<\/p>\n<div class=\"sourceCode\" id=\"cb19\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb19-1\"><a href=\"#cb19-1\" aria-hidden=\"true\"><\/a>mu_X <span class=\"op\">=<\/span> tt.vector(<span class=\"st\">&#39;mu_X&#39;<\/span>)<\/span>\n<span id=\"cb19-2\"><a href=\"#cb19-2\" aria-hidden=\"true\"><\/a>mu_X.tag.test_value <span class=\"op\">=<\/span> np.array([<span class=\"fl\">1.<\/span>], dtype<span class=\"op\">=<\/span>tt.config.floatX)<\/span>\n<span id=\"cb19-3\"><a href=\"#cb19-3\" aria-hidden=\"true\"><\/a>sd_X <span class=\"op\">=<\/span> tt.vector(<span class=\"st\">&#39;sd_X&#39;<\/span>)<\/span>\n<span id=\"cb19-4\"><a href=\"#cb19-4\" aria-hidden=\"true\"><\/a>sd_X.tag.test_value <span class=\"op\">=<\/span> np.array([<span class=\"fl\">2.<\/span>], dtype<span class=\"op\">=<\/span>tt.config.floatX)<\/span>\n<span id=\"cb19-5\"><a href=\"#cb19-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb19-6\"><a href=\"#cb19-6\" aria-hidden=\"true\"><\/a>mu_Y <span class=\"op\">=<\/span> tt.vector(<span class=\"st\">&#39;mu_Y&#39;<\/span>)<\/span>\n<span id=\"cb19-7\"><a href=\"#cb19-7\" aria-hidden=\"true\"><\/a>mu_Y.tag.test_value <span class=\"op\">=<\/span> np.array([<span class=\"fl\">1.<\/span>], dtype<span class=\"op\">=<\/span>tt.config.floatX)<\/span>\n<span id=\"cb19-8\"><a href=\"#cb19-8\" aria-hidden=\"true\"><\/a>sd_Y <span class=\"op\">=<\/span> tt.vector(<span class=\"st\">&#39;sd_Y&#39;<\/span>)<\/span>\n<span id=\"cb19-9\"><a href=\"#cb19-9\" aria-hidden=\"true\"><\/a>sd_Y.tag.test_value <span class=\"op\">=<\/span> np.array([<span class=\"fl\">0.5<\/span>], dtype<span class=\"op\">=<\/span>tt.config.floatX)<\/span>\n<span id=\"cb19-10\"><a href=\"#cb19-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb19-11\"><a href=\"#cb19-11\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> pm.Model() <span class=\"im\">as<\/span> 
conv_model:<\/span>\n<span id=\"cb19-12\"><a href=\"#cb19-12\" aria-hidden=\"true\"><\/a>    X_rv <span class=\"op\">=<\/span> pm.Normal(<span class=\"st\">&#39;X&#39;<\/span>, mu_X, sd<span class=\"op\">=<\/span>sd_X, shape<span class=\"op\">=<\/span>(<span class=\"dv\">1<\/span>,))<\/span>\n<span id=\"cb19-13\"><a href=\"#cb19-13\" aria-hidden=\"true\"><\/a>    Y_rv <span class=\"op\">=<\/span> pm.Normal(<span class=\"st\">&#39;Y&#39;<\/span>, mu_Y, sd<span class=\"op\">=<\/span>sd_Y, shape<span class=\"op\">=<\/span>(<span class=\"dv\">1<\/span>,))<\/span>\n<span id=\"cb19-14\"><a href=\"#cb19-14\" aria-hidden=\"true\"><\/a>    Z_rv <span class=\"op\">=<\/span> X_rv <span class=\"op\">+<\/span> Y_rv<\/span><\/code><\/pre><\/div>\n<p>We create a Theano <code>Op<\/code> to handle the convolution.<\/p>\n<div class=\"sourceCode\" id=\"cb20\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb20-1\"><a href=\"#cb20-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">class<\/span> NormConvOp(tt.Op):<\/span>\n<span id=\"cb20-2\"><a href=\"#cb20-2\" aria-hidden=\"true\"><\/a>    __props__ <span class=\"op\">=<\/span> ()<\/span>\n<span id=\"cb20-3\"><a href=\"#cb20-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb20-4\"><a href=\"#cb20-4\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> make_node(<span class=\"va\">self<\/span>, <span class=\"op\">*<\/span>inputs):<\/span>\n<span id=\"cb20-5\"><a href=\"#cb20-5\" aria-hidden=\"true\"><\/a>        name_new <span class=\"op\">=<\/span> <span class=\"bu\">str<\/span>.join(<span class=\"st\">&#39;+&#39;<\/span>, [<span class=\"bu\">getattr<\/span>(in_, <span class=\"st\">&#39;name&#39;<\/span>, <span class=\"st\">&#39;&#39;<\/span>) <span class=\"cf\">for<\/span> in_ <span class=\"kw\">in<\/span><\/span>\n<span id=\"cb20-6\"><a href=\"#cb20-6\" aria-hidden=\"true\"><\/a>            inputs])<\/span>\n<span id=\"cb20-7\"><a href=\"#cb20-7\" aria-hidden=\"true\"><\/a>        mu_new <span 
class=\"op\">=<\/span> tt.add(<span class=\"op\">*<\/span>[in_.distribution.mu <span class=\"cf\">for<\/span> in_ <span class=\"kw\">in<\/span> inputs])<\/span>\n<span id=\"cb20-8\"><a href=\"#cb20-8\" aria-hidden=\"true\"><\/a>        sd_new <span class=\"op\">=<\/span> tt.sqrt(tt.add(<span class=\"op\">*<\/span>[in_.distribution.sd<span class=\"op\">**<\/span><span class=\"dv\">2<\/span> <span class=\"cf\">for<\/span> in_ <span class=\"kw\">in<\/span><\/span>\n<span id=\"cb20-9\"><a href=\"#cb20-9\" aria-hidden=\"true\"><\/a>            inputs]))<\/span>\n<span id=\"cb20-10\"><a href=\"#cb20-10\" aria-hidden=\"true\"><\/a>        conv_rv <span class=\"op\">=<\/span> pm.Normal(name_new, mu<span class=\"op\">=<\/span>mu_new, sd<span class=\"op\">=<\/span>sd_new,<\/span>\n<span id=\"cb20-11\"><a href=\"#cb20-11\" aria-hidden=\"true\"><\/a>                            <span class=\"co\"># Is this another place where<\/span><\/span>\n<span id=\"cb20-12\"><a href=\"#cb20-12\" aria-hidden=\"true\"><\/a>                            <span class=\"co\"># automatically\/Theano-managed shapes<\/span><\/span>\n<span id=\"cb20-13\"><a href=\"#cb20-13\" aria-hidden=\"true\"><\/a>                            <span class=\"co\"># are really needed?  
For now, we<\/span><\/span>\n<span id=\"cb20-14\"><a href=\"#cb20-14\" aria-hidden=\"true\"><\/a>                            <span class=\"co\"># hack it.<\/span><\/span>\n<span id=\"cb20-15\"><a href=\"#cb20-15\" aria-hidden=\"true\"><\/a>                            shape<span class=\"op\">=<\/span>(<span class=\"dv\">1<\/span>,))<\/span>\n<span id=\"cb20-16\"><a href=\"#cb20-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb20-17\"><a href=\"#cb20-17\" aria-hidden=\"true\"><\/a>        <span class=\"cf\">return<\/span> tt.Apply(<span class=\"va\">self<\/span>, inputs, [conv_rv])<\/span>\n<span id=\"cb20-18\"><a href=\"#cb20-18\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb20-19\"><a href=\"#cb20-19\" aria-hidden=\"true\"><\/a>    <span class=\"kw\">def<\/span> perform(<span class=\"va\">self<\/span>, node, inputs, output_storage):<\/span>\n<span id=\"cb20-20\"><a href=\"#cb20-20\" aria-hidden=\"true\"><\/a>        z <span class=\"op\">=<\/span> output_storage[<span class=\"dv\">0<\/span>]<\/span>\n<span id=\"cb20-21\"><a href=\"#cb20-21\" aria-hidden=\"true\"><\/a>        z[<span class=\"dv\">0<\/span>] <span class=\"op\">=<\/span> np.add(<span class=\"op\">*<\/span>inputs)<\/span><\/code><\/pre><\/div>\n<p>Now, all that\u2019s needed is a <code>PatternSub<\/code> like before.<\/p>\n<div class=\"sourceCode\" id=\"cb21\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb21-1\"><a href=\"#cb21-1\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> is_normal_dist(x):<\/span>\n<span id=\"cb21-2\"><a href=\"#cb21-2\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> <span class=\"bu\">hasattr<\/span>(x, <span class=\"st\">&#39;distribution&#39;<\/span>) <span class=\"kw\">and<\/span> <span class=\"bu\">isinstance<\/span>(x.distribution,<\/span>\n<span id=\"cb21-3\"><a href=\"#cb21-3\" aria-hidden=\"true\"><\/a>        pm.Normal)<\/span>\n<span id=\"cb21-4\"><a href=\"#cb21-4\" aria-hidden=\"true\"><\/a><\/span>\n<span 
id=\"cb21-5\"><a href=\"#cb21-5\" 
aria-hidden=\"true\"><\/a>norm_conv_pat_tt <span class=\"op\">=<\/span> (tt.gof.opt.PatternSub(<\/span>\n<span id=\"cb21-6\"><a href=\"#cb21-6\" aria-hidden=\"true\"><\/a>    (tt.add,<\/span>\n<span id=\"cb21-7\"><a href=\"#cb21-7\" aria-hidden=\"true\"><\/a>     {<span class=\"st\">&#39;pattern&#39;<\/span>: <span class=\"st\">&#39;x&#39;<\/span>,<\/span>\n<span id=\"cb21-8\"><a href=\"#cb21-8\" aria-hidden=\"true\"><\/a>      <span class=\"st\">&#39;constraint&#39;<\/span>: <span class=\"kw\">lambda<\/span> x: is_normal_dist(x)},<\/span>\n<span id=\"cb21-9\"><a href=\"#cb21-9\" aria-hidden=\"true\"><\/a>     {<span class=\"st\">&#39;pattern&#39;<\/span>: <span class=\"st\">&#39;y&#39;<\/span>,<\/span>\n<span id=\"cb21-10\"><a href=\"#cb21-10\" aria-hidden=\"true\"><\/a>      <span class=\"st\">&#39;constraint&#39;<\/span>: <span class=\"kw\">lambda<\/span> x: is_normal_dist(x)}<\/span>\n<span id=\"cb21-11\"><a href=\"#cb21-11\" aria-hidden=\"true\"><\/a>     ),<\/span>\n<span id=\"cb21-12\"><a href=\"#cb21-12\" aria-hidden=\"true\"><\/a>    (NormConvOp(), <span class=\"st\">&#39;x&#39;<\/span>, <span class=\"st\">&#39;y&#39;<\/span>)),)<\/span>\n<span id=\"cb21-13\"><a href=\"#cb21-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb21-14\"><a href=\"#cb21-14\" aria-hidden=\"true\"><\/a>norm_conv_opt_tt <span class=\"op\">=<\/span> tt.gof.opt.EquilibriumOptimizer(norm_conv_pat_tt,<\/span>\n<span id=\"cb21-15\"><a href=\"#cb21-15\" aria-hidden=\"true\"><\/a>                                                   max_use_ratio<span class=\"op\">=<\/span><span class=\"dv\">10<\/span>)<\/span>\n<span id=\"cb21-16\"><a href=\"#cb21-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb21-17\"><a href=\"#cb21-17\" aria-hidden=\"true\"><\/a>Z_fgraph_tt <span class=\"op\">=<\/span> tt.gof.fg.FunctionGraph([X_rv, Y_rv], [Z_rv])<\/span>\n<span id=\"cb21-18\"><a href=\"#cb21-18\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb21-19\"><a href=\"#cb21-19\" 
aria-hidden=\"true\"><\/a><span class=\"co\"># We lose the `FreeRV.distribution` attribute when cloning the graph<\/span><\/span>\n<span id=\"cb21-20\"><a href=\"#cb21-20\" aria-hidden=\"true\"><\/a><span class=\"co\"># with `theano.gof.graph.clone_get_equiv` in `FunctionGraph`, so this<\/span><\/span>\n<span id=\"cb21-21\"><a href=\"#cb21-21\" aria-hidden=\"true\"><\/a><span class=\"co\"># hackishly reattaches that information:<\/span><\/span>\n<span id=\"cb21-22\"><a href=\"#cb21-22\" aria-hidden=\"true\"><\/a>_ <span class=\"op\">=<\/span> [<span class=\"bu\">setattr<\/span>(g_in, <span class=\"st\">&#39;distribution&#39;<\/span>, s_in.distribution)<\/span>\n<span id=\"cb21-23\"><a href=\"#cb21-23\" aria-hidden=\"true\"><\/a>     <span class=\"cf\">for<\/span> s_in, g_in <span class=\"kw\">in<\/span> <span class=\"bu\">zip<\/span>([X_rv, Y_rv], Z_fgraph_tt.inputs)]<\/span><\/code><\/pre><\/div>\n<div class=\"sourceCode\" id=\"cb22\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb22-1\"><a href=\"#cb22-1\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> conv_model:<\/span>\n<span id=\"cb22-2\"><a href=\"#cb22-2\" aria-hidden=\"true\"><\/a>    _ <span class=\"op\">=<\/span> norm_conv_opt_tt.optimize(Z_fgraph_tt)<\/span>\n<span id=\"cb22-3\"><a href=\"#cb22-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb22-4\"><a href=\"#cb22-4\" aria-hidden=\"true\"><\/a>norm_conv_var_dist <span class=\"op\">=<\/span> Z_fgraph_tt.outputs[<span class=\"dv\">0<\/span>].distribution<\/span><\/code><\/pre><\/div>\n<p>The resulting graph:<\/p>\n<div class=\"sourceCode\" id=\"cb23\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb23-1\"><a href=\"#cb23-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> tt.printing.debugprint(Z_fgraph_tt)<\/span>\n<span id=\"cb23-2\"><a href=\"#cb23-2\" aria-hidden=\"true\"><\/a>NormConvOp [<span class=\"bu\">id<\/span> A] <span 
class=\"st\">&#39;X+Y&#39;<\/span>   <span class=\"dv\">0<\/span><\/span>\n<span id=\"cb23-3\"><a href=\"#cb23-3\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>X [<span class=\"bu\">id<\/span> B]<\/span>\n<span id=\"cb23-4\"><a href=\"#cb23-4\" aria-hidden=\"true\"><\/a> <span class=\"op\">|<\/span>Y [<span class=\"bu\">id<\/span> C]<\/span><\/code><\/pre><\/div>\n<p>and the convolution\u2019s parameters (for the test values):<\/p>\n<div class=\"sourceCode\" id=\"cb24\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb24-1\"><a href=\"#cb24-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> <span class=\"bu\">print<\/span>(norm_conv_var_dist.mu.tag.test_value)<\/span>\n<span id=\"cb24-2\"><a href=\"#cb24-2\" aria-hidden=\"true\"><\/a>[ <span class=\"fl\">2.<\/span>]<\/span>\n<span id=\"cb24-3\"><a href=\"#cb24-3\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> <span class=\"bu\">print<\/span>(norm_conv_var_dist.sd.tag.test_value)<\/span>\n<span id=\"cb24-4\"><a href=\"#cb24-4\" aria-hidden=\"true\"><\/a>[ <span class=\"fl\">2.06155281<\/span>]<\/span><\/code><\/pre><\/div>\n<p>More sophisticated routines along the lines of the example above could implement parameter expansions, efficient re-parameterizations and equivalent scale mixture forms in an effort to optimize a graph for sampling or point evaluation. Objectives for these optimizations could be straightforward and computational (e.g.\u00a0reducing the number of operations needed to compute the log-likelihood and other quantities) or more statistically focused (e.g.\u00a0highly efficient sampling or improved mixing). 
These ideas are certainly not new (see, for example, <span class=\"citation\" data-cites=\"mohasel_afshar_probabilistic_2016\">Mohasel Afshar (2016)<\/span> on symbolic Gibbs sampling), but we hope the examples given here show that the tools are readily available and quite accessible.<\/p>\n<p>We\u2019ll end on a more speculative note: this is a context in which we can start rapidly experimenting with objectives over the space of estimation routines. This space is generated by, but not limited to, the variety of symbolic representations, re-parameterizations, etc., mentioned above. Exploring it does not necessarily require the complete estimation of a model at each step, nor even the numeric values of quantities like the gradient or Hessian. It may involve those quantities without evaluating them; for instance, through symbolic comparisons of the competing gradients and Hessians that arise from different representations. What we\u2019re describing lies somewhere between the completely numeric assessments common today and the entirely symbolic work found within the theorems and manipulations of the mathematics we use to derive methods.<\/p>\n<\/section>\n<section id=\"bibliography\" class=\"level1 unnumbered\">\n<h1 class=\"unnumbered\">References<\/h1>\n<div id=\"refs\" class=\"references hanging-indent\" role=\"doc-bibliography\">\n<div id=\"ref-combettes_proximal_2011\">\n<p>Combettes, Patrick L., and Jean-Christophe Pesquet. 2011. \u201cProximal Splitting Methods in Signal Processing.\u201d <em>Fixed-Point Algorithms for Inverse Problems in Science and Engineering<\/em>, 185\u2013212.<\/p>\n<\/div>\n<div id=\"ref-donoho_compressed_2006\">\n<p>Donoho, David L. 2006. \u201cCompressed Sensing.\u201d <em>IEEE Transactions on Information Theory<\/em> 52 (4): 1289\u20131306. 
<a href=\"http:\/\/ieeexplore.ieee.org\/xpls\/abs_all.jsp?arnumber=1614066\">http:\/\/ieeexplore.ieee.org\/xpls\/abs_all.jsp?arnumber=1614066<\/a>.<\/p>\n<\/div>\n<div id=\"ref-mohasel_afshar_probabilistic_2016\">\n<p>Mohasel Afshar, Hadi. 2016. \u201cProbabilistic Inference in Piecewise Graphical Models.\u201d <a href=\"https:\/\/digitalcollections.anu.edu.au\/handle\/1885\/107386\">https:\/\/digitalcollections.anu.edu.au\/handle\/1885\/107386<\/a>.<\/p>\n<\/div>\n<div id=\"ref-parikh_proximal_2014\">\n<p>Parikh, Neal, and Stephen Boyd. 2014. \u201cProximal Algorithms.\u201d <em>Foundations and Trends in Optimization<\/em> 1 (3): 123\u2013231. <a href=\"https:\/\/doi.org\/10.1561\/2400000003\">https:\/\/doi.org\/10.1561\/2400000003<\/a>.<\/p>\n<\/div>\n<div id=\"ref-park_bayesian_2008\">\n<p>Park, Trevor, and George Casella. 2008. \u201cThe Bayesian Lasso.\u201d <em>Journal of the American Statistical Association<\/em> 103 (482): 681\u201386. <a href=\"http:\/\/amstat.tandfonline.com\/doi\/abs\/10.1198\/016214508000000337\">http:\/\/amstat.tandfonline.com\/doi\/abs\/10.1198\/016214508000000337<\/a>.<\/p>\n<\/div>\n<div id=\"ref-polson_proximal_2015\">\n<p>Polson, Nicholas G., James G. Scott, and Brandon T. Willard. 2015. \u201cProximal Algorithms in Statistics and Machine Learning.\u201d <em>Statistical Science<\/em> 30 (4): 559\u201381. <a href=\"http:\/\/projecteuclid.org\/euclid.ss\/1449670858\">http:\/\/projecteuclid.org\/euclid.ss\/1449670858<\/a>.<\/p>\n<\/div>\n<div id=\"ref-polson_statistical_2015\">\n<p>Polson, Nicholas G., Brandon T. Willard, and Massoud Heidari. 2015. \u201cA Statistical Theory of Deep Learning via Proximal Splitting.\u201d <em>arXiv Preprint arXiv:1509.06061<\/em>. <a href=\"http:\/\/arxiv.org\/abs\/1509.06061\">http:\/\/arxiv.org\/abs\/1509.06061<\/a>.<\/p>\n<\/div>\n<div id=\"ref-rocklin_mathematically_2013\">\n<p>Rocklin, Matthew. 2013. 
\u201cMathematically Informed Linear Algebra Codes Through Term Rewriting.\u201d PhD thesis, August. <a href=\"http:\/\/people.cs.uchicago.edu\/~mrocklin\/storage\/dissertation.pdf\">http:\/\/people.cs.uchicago.edu\/~mrocklin\/storage\/dissertation.pdf<\/a>.<\/p>\n<\/div>\n<div id=\"ref-salvatier_probabilistic_2016\">\n<p>Salvatier, John, Thomas V. Wiecki, and Christopher Fonnesbeck. 2016. \u201cProbabilistic Programming in Python Using PyMC3.\u201d <em>PeerJ Computer Science<\/em> 2 (April): e55. <a href=\"https:\/\/peerj.com\/articles\/cs-55\">https:\/\/peerj.com\/articles\/cs-55<\/a>.<\/p>\n<\/div>\n<div id=\"ref-srivastava_dropout_2014\">\n<p>Srivastava, Nitish, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. 2014. \u201cDropout: A Simple Way to Prevent Neural Networks from Overfitting.\u201d <em>The Journal of Machine Learning Research<\/em> 15 (1): 1929\u201358. <a href=\"http:\/\/dl.acm.org\/citation.cfm?id=2670313\">http:\/\/dl.acm.org\/citation.cfm?id=2670313<\/a>.<\/p>\n<\/div>\n<\/div>\n<\/section>\n<\/body>\n<\/html>\n","category":{"@attributes":{"term":"articles"}}},{"title":"Regarding Statistical Model Specification and Sample Results","link":{"@attributes":{"href":"https:\/\/brandonwillard.github.io\/regarding-statistical-model-specification-and-sample-results.html","rel":"alternate"}},"published":"2016-11-01T00:00:00-05:00","updated":"2022-01-13T00:00:00-06:00","author":{"name":"Brandon T. 
Willard"},"id":"tag:brandonwillard.github.io,2016-11-01:\/regarding-statistical-model-specification-and-sample-results.html","summary":{"@attributes":{"type":"html"}},"content":"<!DOCTYPE html PUBLIC \"-\/\/W3C\/\/DTD XHTML 1.0 Transitional\/\/EN\" \"http:\/\/www.w3.org\/TR\/xhtml1\/DTD\/xhtml1-transitional.dtd\">\n<html xmlns=\"http:\/\/www.w3.org\/1999\/xhtml\">\n<head>\n  <meta http-equiv=\"Content-Type\" content=\"text\/html; charset=utf-8\" \/>\n  <meta http-equiv=\"Content-Style-Type\" content=\"text\/css\" \/>\n  <meta name=\"generator\" content=\"pandoc\" \/>\n  <meta name=\"author\" content=\"Brandon T. Willard\" \/>\n  <title>Regarding Statistical Model Specification and Sample Results<\/title>\n  <style type=\"text\/css\">code{white-space: pre;}<\/style>\n  <style type=\"text\/css\">\npre > code.sourceCode { white-space: pre; position: relative; }\npre > code.sourceCode > span { display: inline-block; line-height: 1.25; }\npre > code.sourceCode > span:empty { height: 1.2em; }\ncode.sourceCode > span { color: inherit; text-decoration: inherit; }\ndiv.sourceCode { margin: 1em 0; }\npre.sourceCode { margin: 0; }\n@media screen {\ndiv.sourceCode { overflow: auto; }\n}\n@media print {\npre > code.sourceCode { white-space: pre-wrap; }\npre > code.sourceCode > span { text-indent: -5em; padding-left: 5em; }\n}\npre.numberSource code\n  { counter-reset: source-line 0; }\npre.numberSource code > span\n  { position: relative; left: -4em; counter-increment: source-line; }\npre.numberSource code > span > a:first-child::before\n  { content: counter(source-line);\n    position: relative; left: -1em; text-align: right; vertical-align: baseline;\n    border: none; display: inline-block;\n    -webkit-touch-callout: none; -webkit-user-select: none;\n    -khtml-user-select: none; -moz-user-select: none;\n    -ms-user-select: none; user-select: none;\n    padding: 0 4px; width: 4em;\n    color: #aaaaaa;\n  }\npre.numberSource { margin-left: 3em; border-left: 1px solid 
#aaaaaa;  padding-left: 4px; }\ndiv.sourceCode\n  {   }\n@media screen {\npre > code.sourceCode > span > a:first-child::before { text-decoration: underline; }\n}\ncode span.al { color: #ff0000; font-weight: bold; } \/* Alert *\/\ncode span.an { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Annotation *\/\ncode span.at { color: #7d9029; } \/* Attribute *\/\ncode span.bn { color: #40a070; } \/* BaseN *\/\ncode span.bu { } \/* BuiltIn *\/\ncode span.cf { color: #007020; font-weight: bold; } \/* ControlFlow *\/\ncode span.ch { color: #4070a0; } \/* Char *\/\ncode span.cn { color: #880000; } \/* Constant *\/\ncode span.co { color: #60a0b0; font-style: italic; } \/* Comment *\/\ncode span.cv { color: #60a0b0; font-weight: bold; font-style: italic; } \/* CommentVar *\/\ncode span.do { color: #ba2121; font-style: italic; } \/* Documentation *\/\ncode span.dt { color: #902000; } \/* DataType *\/\ncode span.dv { color: #40a070; } \/* DecVal *\/\ncode span.er { color: #ff0000; font-weight: bold; } \/* Error *\/\ncode span.ex { } \/* Extension *\/\ncode span.fl { color: #40a070; } \/* Float *\/\ncode span.fu { color: #06287e; } \/* Function *\/\ncode span.im { } \/* Import *\/\ncode span.in { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Information *\/\ncode span.kw { color: #007020; font-weight: bold; } \/* Keyword *\/\ncode span.op { color: #666666; } \/* Operator *\/\ncode span.ot { color: #007020; } \/* Other *\/\ncode span.pp { color: #bc7a00; } \/* Preprocessor *\/\ncode span.sc { color: #4070a0; } \/* SpecialChar *\/\ncode span.ss { color: #bb6688; } \/* SpecialString *\/\ncode span.st { color: #4070a0; } \/* String *\/\ncode span.va { color: #19177c; } \/* Variable *\/\ncode span.vs { color: #4070a0; } \/* VerbatimString *\/\ncode span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Warning *\/\n  <\/style>\n  <!--        <script src=\"https:\/\/cdn.jsdelivr.net\/npm\/mathjax@3\/es5\/tex-mml-chtml.js\" 
type=\"text\/javascript\"><\/script> -->\n  <script src=\"https:\/\/cdnjs.cloudflare.com\/ajax\/libs\/mathjax\/2.7.0\/MathJax.js?config=TeX-AMS_HTML\" id=\"MathJax-script\"><\/script>\n  <script>\n   MathJax.Hub.Config({\n       tex2jax: {\n           processEnvironments: true,\n           processRefs: false\n       },\n       TeX: {\n           equationNumbers: { autoNumber: \"AMS\" },\n           extensions: [\"AMSmath.js\",\"AMSsymbols.js\",\"noErrors.js\",\"noUndefined.js\"]\n       }\n   });\n  <\/script>\n<\/head>\n<body>\n<!--  -->\n<!-- <div id=\"header\"> -->\n<!-- <h1 class=\"title\">Regarding Statistical Model Specification and Sample Results<\/h1> -->\n<!--  -->\n<!--  -->\n<!-- <h2 class=\"author\">Brandon T. Willard<\/h2> -->\n<!--  -->\n<!--  -->\n<!-- <h3 class=\"date\">2016\u201311\u201301<\/h3> -->\n<!--  -->\n<!-- <\/div> -->\n<!--  -->\n<ul>\n<li><a href=\"#org080791d\">Introduction<\/a><\/li>\n<li><a href=\"#org1247f0c\">Notation<\/a><\/li>\n<li><a href=\"#org9a150d4\">A Simple Model<\/a><\/li>\n<li><a href=\"#org7489ef0\">Estimation (via MCMC)<\/a>\n<ul>\n<li><a href=\"#orgc99ff06\">The Situation on Implementation<\/a><\/li>\n<li><a href=\"#org82a5c5b\">The Costs<\/a><\/li>\n<\/ul><\/li>\n<li><a href=\"#orgfe79892\">Predictions<\/a><\/li>\n<li><a href=\"#org72d572e\">Hierarchical Extensions<\/a><\/li>\n<\/ul>\n<p><a id=\"org080791d\"><\/a><\/p>\n<section id=\"introduction\" class=\"level1\">\n<h1>Introduction<\/h1>\n<p>In this post I want to address some concepts regarding statistical model specification within the Bayesian paradigm, motivation for its use, and the utility of sample results (e.g.\u00a0empirical posterior distributions). This write-up isn\u2019t intended to be thorough or self-contained, especially since numerous quality introductions already exist for Bayesian modeling and MCMC <a id=\"6824e74293e673dfd3d897debac1bd36\"><a href=\"#gelman_bayesian_2013\">(Gelman, Carlin, Stern, Dunson, Vehtari &amp; Rubin 2013)<\/a><\/a>. 
Instead, its purpose is to illustrate some specific points in the context of a simple, evolving problem that mirrors some real-life objectives. Also, what\u2019s advocated here is in large part just <em>statistical<\/em> modeling and not exclusively <em>Bayesian<\/em>.<\/p>\n<p>The generality, applicability and relative simplicity of the core concepts within Bayesian modeling are sadly overlooked in practice. Bayes is too often conflated with MCMC and its associated computational costs, or is seen as needlessly \u201cmathy\u201d and technical. I argue that there is an often unacknowledged trade-off in the effort spent on mathematical modeling, and that Bayesian modeling helps one navigate it. In doing so, one can save effort in the long run.<\/p>\n<p>When a model is [fully] specified in a statistical or Bayesian way, the modeler has at their disposal distributions for the unknown quantities of interest; these distributions are often of primary interest. The desired estimates are found \u201cwithin\u201d the distributions: for instance, as summaries of a distribution (e.g.\u00a0its mean, mode or variance), which may correspond to certain \u201cbest\u201d estimates or measures of parameter uncertainty. The same goes for functions of these distributions (e.g.\u00a0rolling sums and averages).<\/p>\n<p>Typically, modeling objectives are specified in terms of <strong>point-estimates<\/strong> instead of distributions, like the aforementioned \u201cbest\u201d parameter estimates. This situation is also covered by the Bayesian paradigm, especially when the corresponding distributions have a closed form and are fully specified by a finite number of parameters. However, when this isn\u2019t the case, point-estimates provide only part of the picture. 
It\u2019s usually these missing parts that make model assessment and prediction largely separate and difficult endeavours down the road.<\/p>\n<p>Even so, modeling and estimation often proceed without much statistical consideration or context, making these distributions\u2013and the results they can provide\u2013more and more inaccessible. In a situation where modeling started with common machine learning\/statistical software and resulted in non-statistical extensions, the work needed for something like <em>uncertainty quantification or propagation<\/em> broadly equates to retrofitting and\/or defining the altered or missing statistical context of the problem. This sort of work necessarily requires a much rarer expertise, and its results are usually difficult for outsiders to vet. Considerations like this might be reason enough to\u2013at least minimally\u2013maintain clear statistical assumptions throughout the life of a non-trivial project. The Bayesian approach can be a more accessible means of providing this type of statistical coherence.<\/p>\n<p>As a starting point, one can find quite a few non-Bayes models with Bayesian interpretations and counterparts. Even finding a Bayesian interpretation for an existing non-Bayes model can itself advance one\u2019s understanding of the statistical assumptions and properties of the model. In some cases this understanding can inspire new forms of estimation or new non-Bayes variants of a model. Multiple examples arise from models defined by objective or loss functions with forms equivalent to the (negative) total log-likelihoods of Bayesian models. 
This, for instance, is one way that general point-wise estimates can be related to maximum a posteriori (MAP) estimates in the Bayesian context.<\/p>\n<p><a id=\"org1247f0c\"><\/a><\/p>\n<\/section>\n<section id=\"notation\" class=\"level1\">\n<h1>Notation<\/h1>\n<p>Before getting into the details, let\u2019s cover some preliminaries regarding notation.<\/p>\n<p>The symbol <span class=\"math inline\">\\(\\sim\\)<\/span> is overloaded to mean a couple of things. First, a statement like <span class=\"math inline\">\\(X \\sim \\operatorname{P}\\)<\/span> means \u201c<span class=\"math inline\">\\(X\\)<\/span> is distributed according to <span class=\"math inline\">\\(\\operatorname{P}\\)<\/span>\u201d, when <span class=\"math inline\">\\(X\\)<\/span> is understood to be a random variable (generally denoted by capital letter variables). Second, for a non-random variable <span class=\"math inline\">\\(x\\)<\/span>, <span class=\"math inline\">\\(x \\sim \\operatorname{P}\\)<\/span> and <span class=\"math inline\">\\(x \\sim X\\)<\/span> mean \u201c<span class=\"math inline\">\\(x\\)<\/span> is a sample from distribution <span class=\"math inline\">\\(\\operatorname{P}\\)<\/span>\u201d (or from the distribution of <span class=\"math inline\">\\(X\\)<\/span>, respectively). When <span class=\"math inline\">\\(\\operatorname{P}\\)<\/span> is not meant to signify a distribution but instead a generic function, like a probability density function <span class=\"math inline\">\\(p(X=x) \\equiv p(x)\\)<\/span>, then the distribution in question is the one arising from that function (interpreted as a probability density and\/or measure), when possible. See <a href=\"https:\/\/en.wikipedia.org\/wiki\/Notation_in_probability_and_statistics\">here<\/a> for a similar notation. 
Also, whenever indices are dropped, the resulting symbol is assumed to be a stacked matrix containing each entry, e.g.<\/p>\n<p><span class=\"math display\">\\[\\begin{gather*}\n  X^\\top = \\begin{pmatrix} X_1 &amp; \\dots &amp; X_N \\end{pmatrix} \\;.\n\\end{gather*}\\]<\/span><\/p>\n<p>When the indexed symbol is a vector, then it is customary to denote the row stacked matrix of each vector with the symbol\u2019s capital letter. E.g., for [column] vectors <span class=\"math inline\">\\(z_i\\)<\/span> over <span class=\"math inline\">\\(i \\in \\{1, \\dots, N\\}\\)<\/span>,<\/p>\n<p><span class=\"math display\">\\[\nZ = \\begin{pmatrix} z_1 \\\\ \\vdots \\\\ z_N \\end{pmatrix} \\;.\n\\]<\/span><\/p>\n<p><a id=\"org9a150d4\"><\/a><\/p>\n<\/section>\n<section id=\"a-simple-model\" class=\"level1\">\n<h1>A Simple Model<\/h1>\n<p>First, a simple normal-normal model<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  Y_t \\sim \\operatorname{N}(x^\\top_t \\theta, \\sigma^2), \\quad\n    \\theta \\sim \\operatorname{N}(\\mu, I \\tau^2)\n    \\label{eq:normal-normal}\n\\end{equation}\\]<\/span><\/p>\n<p>for an identity matrix <span class=\"math inline\">\\(I\\)<\/span>, observed random variable <span class=\"math inline\">\\(Y_t\\)<\/span> at time <span class=\"math inline\">\\(t \\in \\{1, \\dots, T\\}\\)<\/span>, and known constant values (of matching dimensions) <span class=\"math inline\">\\(x_t\\)<\/span>, <span class=\"math inline\">\\(\\sigma\\)<\/span>, <span class=\"math inline\">\\(\\mu\\)<\/span> and <span class=\"math inline\">\\(\\tau\\)<\/span>. 
The <span class=\"math inline\">\\(x_t\\)<\/span> play the role of predictors, or features, and we\u2019ll assume that the time dependencies arise primarily through them.<\/p>\n<p>In Bayes parlance, the model in <span class=\"math inline\">\\(\\eqref{eq:normal-normal}\\)<\/span> gives <span class=\"math inline\">\\(\\theta\\)<\/span> a normal prior distribution, and the primary goal involves estimating the \u201cposterior\u201d distribution <span class=\"math inline\">\\(p(\\theta \\mid y)\\)<\/span> for a vector of observations <span class=\"math inline\">\\(y\\)<\/span> under the assumption <span class=\"math inline\">\\(y \\sim Y\\)<\/span>.<\/p>\n<p>This simple example has the well-known closed-form posterior solution for <span class=\"math inline\">\\(\\theta\\)<\/span>,<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\left(\\theta \\mid y\\right) \\sim \\operatorname{N}(m, C)\n    \\;,\n    \\label{eq:theta-posterior}\n\\end{equation}\\]<\/span><\/p>\n<p>for<\/p>\n<p><span class=\"math display\">\\[\\begin{gather*}\n  m = C \\left(\\mu \\tau^{-2} + X^\\top y\\, \\sigma^{-2}\\right), \\quad\n  C = \\left(I \\tau^{-2} + X^\\top X \\sigma^{-2}\\right)^{-1}\n  \\;.\n\\end{gather*}\\]<\/span><\/p>\n<p>Results like this are easily obtained for the classical pairings of \u201cconjugate\u201d distributions. Detailed <a href=\"https:\/\/en.wikipedia.org\/wiki\/Conjugate_prior#Table_of_conjugate_distributions\">tables<\/a> and <a href=\"https:\/\/goo.gl\/UCL3pc\">tutorials<\/a> for conjugate distributions can be found online or in any standard text.<\/p>\n<p><a id=\"org7489ef0\"><\/a><\/p>\n<\/section>\n<section id=\"estimation-via-mcmc\" class=\"level1\">\n<h1>Estimation (via MCMC)<\/h1>\n<p>From here on let\u2019s assume we do not have the closed-form result in <span class=\"math inline\">\\(\\eqref{eq:theta-posterior}\\)<\/span>. 
Instead, we\u2019ll estimate the posterior numerically with <a href=\"https:\/\/en.wikipedia.org\/wiki\/Markov_chain_Monte_Carlo\">MCMC<\/a>. Again, MCMC is covered to varying degrees of detail all over the place (e.g.\u00a0<a href=\"https:\/\/goo.gl\/JNwfuo\">here<\/a>), so we\u2019ll skip most of those details. Let\u2019s say we\u2019ve decided to use <a href=\"https:\/\/en.wikipedia.org\/wiki\/Metropolis%E2%80%93Hastings_algorithm\">Metropolis-Hastings<\/a>.<\/p>\n<p>For demonstration purposes, we produce a simulation of some data we might observe and for which we would consider applying the model in <span class=\"math inline\">\\(\\eqref{eq:normal-normal}\\)<\/span>.<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb1\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb1-1\"><a href=\"#cb1-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> datetime <span class=\"im\">import<\/span> datetime<\/span>\n<span id=\"cb1-2\"><a href=\"#cb1-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-3\"><a href=\"#cb1-3\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> numpy <span class=\"im\">as<\/span> np<\/span>\n<span id=\"cb1-4\"><a href=\"#cb1-4\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> pandas <span class=\"im\">as<\/span> pd<\/span>\n<span id=\"cb1-5\"><a href=\"#cb1-5\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> scipy.stats <span class=\"im\">as<\/span> scs<\/span>\n<span id=\"cb1-6\"><a href=\"#cb1-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-7\"><a href=\"#cb1-7\" aria-hidden=\"true\"><\/a><span class=\"co\"># Unknown parameter<\/span><\/span>\n<span id=\"cb1-8\"><a href=\"#cb1-8\" aria-hidden=\"true\"><\/a>mu_true <span class=\"op\">=<\/span> <span class=\"fl\">1.5<\/span><\/span>\n<span id=\"cb1-9\"><a href=\"#cb1-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-10\"><a href=\"#cb1-10\" aria-hidden=\"true\"><\/a><span class=\"co\"># [Assumed] known 
parameter<\/span><\/span>\n<span id=\"cb1-11\"><a href=\"#cb1-11\" aria-hidden=\"true\"><\/a>sigma2 <span class=\"op\">=<\/span> <span class=\"fl\">0.05<\/span><\/span>\n<span id=\"cb1-12\"><a href=\"#cb1-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-13\"><a href=\"#cb1-13\" aria-hidden=\"true\"><\/a><span class=\"co\"># Prior parameters<\/span><\/span>\n<span id=\"cb1-14\"><a href=\"#cb1-14\" aria-hidden=\"true\"><\/a>tau2 <span class=\"op\">=<\/span> <span class=\"fl\">1e2<\/span><\/span>\n<span id=\"cb1-15\"><a href=\"#cb1-15\" aria-hidden=\"true\"><\/a>mu <span class=\"op\">=<\/span> <span class=\"dv\">1<\/span><\/span>\n<span id=\"cb1-16\"><a href=\"#cb1-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-17\"><a href=\"#cb1-17\" aria-hidden=\"true\"><\/a>start_datetime <span class=\"op\">=<\/span> pd.Timestamp(datetime.now())<\/span>\n<span id=\"cb1-18\"><a href=\"#cb1-18\" aria-hidden=\"true\"><\/a>sim_index <span class=\"op\">=<\/span> pd.date_range(<\/span>\n<span id=\"cb1-19\"><a href=\"#cb1-19\" aria-hidden=\"true\"><\/a>    start<span class=\"op\">=<\/span><span class=\"st\">&quot;2016-01-01 12:00:00&quot;<\/span>, end<span class=\"op\">=<\/span><span class=\"st\">&quot;2016-01-08 12:00:00&quot;<\/span>, freq<span class=\"op\">=<\/span><span class=\"st\">&quot;H&quot;<\/span><\/span>\n<span id=\"cb1-20\"><a href=\"#cb1-20\" aria-hidden=\"true\"><\/a>)<\/span>\n<span id=\"cb1-21\"><a href=\"#cb1-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-22\"><a href=\"#cb1-22\" aria-hidden=\"true\"><\/a><span class=\"co\"># Simulated observations<\/span><\/span>\n<span id=\"cb1-23\"><a href=\"#cb1-23\" aria-hidden=\"true\"><\/a>X <span class=\"op\">=<\/span> np.sin(np.linspace(<span class=\"dv\">0<\/span>, <span class=\"dv\">2<\/span> <span class=\"op\">*<\/span> np.pi, <span class=\"bu\">len<\/span>(sim_index)))<\/span>\n<span id=\"cb1-24\"><a href=\"#cb1-24\" aria-hidden=\"true\"><\/a>y_obs <span class=\"op\">=<\/span> scs.norm.rvs(loc<span class=\"op\">=<\/span>X 
<span class=\"op\">*<\/span> mu_true, scale<span class=\"op\">=<\/span>np.sqrt(sigma2))<\/span><\/code><\/pre><\/div>\n<\/figure>\n<p>A Metropolis-Hastings sampler would perform a simple loop that accepts or rejects samples from a proposal distribution, <span class=\"math inline\">\\(\\theta_i \\sim p(\\theta_i \\mid \\theta_{i-1})\\)<\/span>, where <span class=\"math inline\">\\(\\theta_i\\)<\/span> is the proposed value and <span class=\"math inline\">\\(\\theta_{i-1}\\)<\/span> the current one. A proposal is accepted with the following probability, in which <span class=\"math inline\">\\(p(\\theta)\\)<\/span> is the prior density:<\/p>\n<p><span class=\"math display\">\\[\n  \\min\\left\\{1,\n  \\frac{p(Y = y \\mid X, \\theta_i)\\, p(\\theta_i)}{p(Y = y \\mid X, \\theta_{i-1})\\, p(\\theta_{i-1})}\n  \\frac{p(\\theta_{i-1} \\mid \\theta_i)}{p(\\theta_i \\mid \\theta_{i-1})}\n  \\right\\}\n  \\;.\n\\]<\/span><\/p>\n<p>Let\u2019s say our proposal is a normal distribution with a mean equal to the previous sample and a variance given by <span class=\"math inline\">\\(\\lambda^2\\)<\/span>. The resulting sampling scheme is a random walk Metropolis-Hastings sampler, and since the proposal is a symmetric distribution, <span class=\"math inline\">\\(\\frac{p(\\theta_{i-1} \\mid \\theta_i)}{p(\\theta_i \\mid \\theta_{i-1})} = 1\\)<\/span>.<\/p>\n<p>In code, this could look like<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb2\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb2-1\"><a href=\"#cb2-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> functools <span class=\"im\">import<\/span> partial<\/span>\n<span id=\"cb2-2\"><a href=\"#cb2-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-3\"><a href=\"#cb2-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-4\"><a href=\"#cb2-4\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> model_logpdf(theta_):<\/span>\n<span id=\"cb2-5\"><a href=\"#cb2-5\" aria-hidden=\"true\"><\/a>    res <span class=\"op\">=<\/span> np.<span class=\"bu\">sum<\/span>(scs.norm.logpdf(y_obs, loc<span class=\"op\">=<\/span>X <span class=\"op\">*<\/span> theta_, scale<span class=\"op\">=<\/span>np.sqrt(sigma2)))<\/span>\n<span id=\"cb2-6\"><a href=\"#cb2-6\" aria-hidden=\"true\"><\/a>    
res <span class=\"op\">+=<\/span> scs.norm.logpdf(theta_, loc<span class=\"op\">=<\/span>mu, scale<span class=\"op\">=<\/span>np.sqrt(tau2))<\/span>\n<span id=\"cb2-7\"><a href=\"#cb2-7\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> res<\/span>\n<span id=\"cb2-8\"><a href=\"#cb2-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-9\"><a href=\"#cb2-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-10\"><a href=\"#cb2-10\" aria-hidden=\"true\"><\/a>N_samples <span class=\"op\">=<\/span> <span class=\"dv\">2000<\/span><\/span>\n<span id=\"cb2-11\"><a href=\"#cb2-11\" aria-hidden=\"true\"><\/a>theta_samples <span class=\"op\">=<\/span> []<\/span>\n<span id=\"cb2-12\"><a href=\"#cb2-12\" aria-hidden=\"true\"><\/a>lam <span class=\"op\">=<\/span> <span class=\"fl\">1.0<\/span><\/span>\n<span id=\"cb2-13\"><a href=\"#cb2-13\" aria-hidden=\"true\"><\/a>current_sample <span class=\"op\">=<\/span> np.random.normal(loc<span class=\"op\">=<\/span>mu, scale<span class=\"op\">=<\/span>lam)<\/span>\n<span id=\"cb2-14\"><a href=\"#cb2-14\" aria-hidden=\"true\"><\/a>proposal_logpdf <span class=\"op\">=<\/span> partial(scs.norm.logpdf, scale<span class=\"op\">=<\/span>lam)<\/span>\n<span id=\"cb2-15\"><a href=\"#cb2-15\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-16\"><a href=\"#cb2-16\" aria-hidden=\"true\"><\/a><span class=\"cf\">for<\/span> i <span class=\"kw\">in<\/span> <span class=\"bu\">range<\/span>(N_samples):<\/span>\n<span id=\"cb2-17\"><a href=\"#cb2-17\" aria-hidden=\"true\"><\/a>    proposal_sample <span class=\"op\">=<\/span> np.random.normal(loc<span class=\"op\">=<\/span>current_sample, scale<span class=\"op\">=<\/span>lam)<\/span>\n<span id=\"cb2-18\"><a href=\"#cb2-18\" aria-hidden=\"true\"><\/a>    l_ratio <span class=\"op\">=<\/span> np.<span class=\"bu\">sum<\/span>(model_logpdf(proposal_sample))<\/span>\n<span id=\"cb2-19\"><a href=\"#cb2-19\" aria-hidden=\"true\"><\/a>    l_ratio <span class=\"op\">-=<\/span> np.<span 
class=\"bu\">sum<\/span>(model_logpdf(current_sample))<\/span>\n<span id=\"cb2-20\"><a href=\"#cb2-20\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-21\"><a href=\"#cb2-21\" aria-hidden=\"true\"><\/a>    p_ratio <span class=\"op\">=<\/span> np.<span class=\"bu\">sum<\/span>(proposal_logpdf(current_sample, loc<span class=\"op\">=<\/span>proposal_sample))<\/span>\n<span id=\"cb2-22\"><a href=\"#cb2-22\" aria-hidden=\"true\"><\/a>    p_ratio <span class=\"op\">-=<\/span> np.<span class=\"bu\">sum<\/span>(proposal_logpdf(proposal_sample, loc<span class=\"op\">=<\/span>current_sample))<\/span>\n<span id=\"cb2-23\"><a href=\"#cb2-23\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-24\"><a href=\"#cb2-24\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> np.log(np.random.uniform()) <span class=\"op\">&lt;=<\/span> <span class=\"bu\">min<\/span>(<span class=\"dv\">0<\/span>, l_ratio <span class=\"op\">+<\/span> p_ratio):<\/span>\n<span id=\"cb2-25\"><a href=\"#cb2-25\" aria-hidden=\"true\"><\/a>        current_sample <span class=\"op\">=<\/span> proposal_sample<\/span>\n<span id=\"cb2-26\"><a href=\"#cb2-26\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-27\"><a href=\"#cb2-27\" aria-hidden=\"true\"><\/a>    theta_samples.append(current_sample)<\/span>\n<span id=\"cb2-28\"><a href=\"#cb2-28\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb2-29\"><a href=\"#cb2-29\" aria-hidden=\"true\"><\/a>theta_samples <span class=\"op\">=<\/span> np.asarray(theta_samples)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<p>The Metropolis-Hastings sampler itself does not rely on any prior information or Bayesian formulations. Although the prior enters through the target density (here, the likelihood times the prior), the concepts behind the sampler are still valid without it. Basically, Metropolis-Hastings\u2013like many other MCMC sampling routines\u2013is not specifically Bayesian. 
It\u2019s better to simply consider MCMC as just another estimation approach (or perhaps a type of stochastic optimization).<\/p>\n<p><a href=\"https:\/\/en.wikipedia.org\/wiki\/Gibbs_sampling\">Gibbs sampling<\/a> is arguably the other most ubiquitous MCMC technique. Since a model specified in a Bayesian way usually provides a clear joint distribution (or at least something proportional to it) and conditional probabilities, Gibbs sampling is well facilitated.<\/p>\n<p>The context of Bayesian modeling is, however, a good source of direction and motivation for improvements to a sampling procedure (and estimation in general). Under Bayesian assumptions, decompositions and reformulations for broad classes of distributions are often immediately available. Guiding generalities, like the <a href=\"https:\/\/en.wikipedia.org\/wiki\/Rao%E2%80%93Blackwell_theorem\">Rao-Blackwell<\/a> theorem, are also applicable, and\u2013more generally\u2013the same principles, tools and results that guide the model creation and assessment process can also feed into the estimation process.<\/p>\n<p><a id=\"orgc99ff06\"><\/a><\/p>\n<section id=\"the-situation-on-implementation\" class=\"level2\">\n<h2>The Situation on Implementation<\/h2>\n<p>MCMC sampling schemes like the above are fairly general and easily abstracted, giving rise to some generic frameworks that put more focus on model specification and attempt to automate the choice of estimation (or implement one robust technique). Some of the more common frameworks are Bayesian in nature: <a href=\"http:\/\/www.openbugs.net\/w\/FrontPage\">OpenBUGS<\/a>, <a href=\"http:\/\/mcmc-jags.sourceforge.net\/\">JAGS<\/a>, <a href=\"http:\/\/mc-stan.org\/\">Stan<\/a>, and <a href=\"https:\/\/pymc-devs.github.io\/pymc\/\">PyMC2<\/a> \/ <a href=\"https:\/\/pymc-devs.github.io\/pymc3\/\">PyMC3<\/a>. 
These libraries provide a sort of meta-language that facilitates the specification of a Bayesian model and mirrors the mathematical language of probability. They also implicitly implement the <a href=\"https:\/\/en.wikipedia.org\/wiki\/Algebra_of_random_variables\">algebra of random variables<\/a> and automatically handle the mechanics of variable transforms.<\/p>\n<p>Our model, estimated with a Metropolis-Hastings sampler, can be expressed in PyMC3 with the following code:<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb3\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb3-1\"><a href=\"#cb3-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> pymc3 <span class=\"im\">as<\/span> pm<\/span>\n<span id=\"cb3-2\"><a href=\"#cb3-2\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> theano<\/span>\n<span id=\"cb3-3\"><a href=\"#cb3-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-4\"><a href=\"#cb3-4\" aria-hidden=\"true\"><\/a>theano.config.mode <span class=\"op\">=<\/span> <span class=\"st\">&quot;FAST_COMPILE&quot;<\/span><\/span>\n<span id=\"cb3-5\"><a href=\"#cb3-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-6\"><a href=\"#cb3-6\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> pm.Model() <span class=\"im\">as<\/span> model:<\/span>\n<span id=\"cb3-7\"><a href=\"#cb3-7\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Model definition<\/span><\/span>\n<span id=\"cb3-8\"><a href=\"#cb3-8\" aria-hidden=\"true\"><\/a>    theta <span class=\"op\">=<\/span> pm.Normal(<span class=\"st\">&quot;theta&quot;<\/span>, mu<span class=\"op\">=<\/span>mu, tau<span class=\"op\">=<\/span><span class=\"fl\">1.0<\/span> <span class=\"op\">\/<\/span> tau2)<\/span>\n<span id=\"cb3-9\"><a href=\"#cb3-9\" aria-hidden=\"true\"><\/a>    Y <span class=\"op\">=<\/span> pm.Normal(<span class=\"st\">&quot;Y&quot;<\/span>, mu<span class=\"op\">=<\/span>X <span class=\"op\">*<\/span> theta, tau<span 
class=\"op\">=<\/span><span class=\"fl\">1.0<\/span> <span class=\"op\">\/<\/span> sigma2, observed<span class=\"op\">=<\/span>y_obs)<\/span>\n<span id=\"cb3-10\"><a href=\"#cb3-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-11\"><a href=\"#cb3-11\" aria-hidden=\"true\"><\/a>    <span class=\"co\"># Posterior sampling<\/span><\/span>\n<span id=\"cb3-12\"><a href=\"#cb3-12\" aria-hidden=\"true\"><\/a>    sample_steps <span class=\"op\">=<\/span> pm.Metropolis()<\/span>\n<span id=\"cb3-13\"><a href=\"#cb3-13\" aria-hidden=\"true\"><\/a>    sample_traces <span class=\"op\">=<\/span> pm.sample(<span class=\"dv\">2000<\/span>, sample_steps)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<p>As per the basic examples in the <a href=\"https:\/\/goo.gl\/WW3TO8\">PyMC3 notebooks<\/a>, the posterior samples are plotted below using the following code:<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb4\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb4-1\"><a href=\"#cb4-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> matplotlib.pyplot <span class=\"im\">as<\/span> plt<\/span>\n<span id=\"cb4-2\"><a href=\"#cb4-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb4-3\"><a href=\"#cb4-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb4-4\"><a href=\"#cb4-4\" aria-hidden=\"true\"><\/a>plt.style.use(<span class=\"st\">&quot;ggplot&quot;<\/span>)<\/span>\n<span id=\"cb4-5\"><a href=\"#cb4-5\" aria-hidden=\"true\"><\/a>plt.rc(<span class=\"st\">&quot;text&quot;<\/span>, usetex<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb4-6\"><a href=\"#cb4-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb4-7\"><a href=\"#cb4-7\" aria-hidden=\"true\"><\/a>tp_axes <span class=\"op\">=<\/span> pm.traceplot(sample_traces)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<p>We can also superimpose the true posterior density given by <span class=\"math inline\">\\(\\eqref{eq:theta-posterior}\\)<\/span> with the 
following:<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb5\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb5-1\"><a href=\"#cb5-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> seaborn <span class=\"im\">as<\/span> sns<\/span>\n<span id=\"cb5-2\"><a href=\"#cb5-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-3\"><a href=\"#cb5-3\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> matplotlib.pyplot <span class=\"im\">as<\/span> plt<\/span>\n<span id=\"cb5-4\"><a href=\"#cb5-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-5\"><a href=\"#cb5-5\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-6\"><a href=\"#cb5-6\" aria-hidden=\"true\"><\/a>plt.style.use(<span class=\"st\">&quot;ggplot&quot;<\/span>)<\/span>\n<span id=\"cb5-7\"><a href=\"#cb5-7\" aria-hidden=\"true\"><\/a>plt.rc(<span class=\"st\">&quot;text&quot;<\/span>, usetex<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb5-8\"><a href=\"#cb5-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-9\"><a href=\"#cb5-9\" aria-hidden=\"true\"><\/a>tp_axes <span class=\"op\">=<\/span> pm.traceplot(sample_traces)<\/span>\n<span id=\"cb5-10\"><a href=\"#cb5-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-11\"><a href=\"#cb5-11\" aria-hidden=\"true\"><\/a>_ <span class=\"op\">=<\/span> [a_.set_title(<span class=\"vs\">r&quot;Posterior $(\\theta \\mid y)$ Samples&quot;<\/span>) <span class=\"cf\">for<\/span> a_ <span class=\"kw\">in<\/span> tp_axes.ravel()]<\/span>\n<span id=\"cb5-12\"><a href=\"#cb5-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-13\"><a href=\"#cb5-13\" aria-hidden=\"true\"><\/a>freq_axis <span class=\"op\">=<\/span> tp_axes[<span class=\"dv\">0<\/span>][<span class=\"dv\">0<\/span>]<\/span>\n<span id=\"cb5-14\"><a href=\"#cb5-14\" aria-hidden=\"true\"><\/a>freq_axis.set_xlabel(<span class=\"vs\">r&quot;$\\theta$&quot;<\/span>)<\/span>\n<span id=\"cb5-15\"><a href=\"#cb5-15\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-16\"><a href=\"#cb5-16\" aria-hidden=\"true\"><\/a>sample_axis <span class=\"op\">=<\/span> tp_axes[<span class=\"dv\">0<\/span>][<span class=\"dv\">1<\/span>]<\/span>\n<span id=\"cb5-17\"><a href=\"#cb5-17\" aria-hidden=\"true\"><\/a>sample_axis.set_xlabel(<span class=\"vs\">r&quot;$i$&quot;<\/span>)<\/span>\n<span id=\"cb5-18\"><a href=\"#cb5-18\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-19\"><a href=\"#cb5-19\" aria-hidden=\"true\"><\/a>rhs <span class=\"op\">=<\/span> np.dot(<span class=\"fl\">1.0<\/span> <span class=\"op\">\/<\/span> tau2, mu) <span class=\"op\">+<\/span> np.dot(X.T <span class=\"op\">\/<\/span> sigma2, y_obs)<\/span>\n<span id=\"cb5-20\"><a href=\"#cb5-20\" aria-hidden=\"true\"><\/a>tau_post <span class=\"op\">=<\/span> <span class=\"fl\">1.0<\/span> <span class=\"op\">\/<\/span> tau2 <span class=\"op\">+<\/span> np.dot(X.T <span class=\"op\">\/<\/span> sigma2, X)<\/span>\n<span id=\"cb5-21\"><a href=\"#cb5-21\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-22\"><a href=\"#cb5-22\" aria-hidden=\"true\"><\/a>post_mean <span class=\"op\">=<\/span> rhs <span class=\"op\">\/<\/span> tau_post<\/span>\n<span id=\"cb5-23\"><a href=\"#cb5-23\" aria-hidden=\"true\"><\/a>post_var_inv <span class=\"op\">=<\/span> tau_post<\/span>\n<span id=\"cb5-24\"><a href=\"#cb5-24\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-25\"><a href=\"#cb5-25\" aria-hidden=\"true\"><\/a>post_pdf <span class=\"op\">=<\/span> partial(scs.norm.pdf, loc<span class=\"op\">=<\/span>post_mean, scale<span class=\"op\">=<\/span><span class=\"fl\">1.0<\/span> <span class=\"op\">\/<\/span> np.sqrt(post_var_inv))<\/span>\n<span id=\"cb5-26\"><a href=\"#cb5-26\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-27\"><a href=\"#cb5-27\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-28\"><a href=\"#cb5-28\" aria-hidden=\"true\"><\/a><span class=\"kw\">def<\/span> add_function_plot(func, ax, num<span 
class=\"op\">=<\/span><span class=\"fl\">1e2<\/span>, label<span class=\"op\">=<\/span><span class=\"va\">None<\/span>):<\/span>\n<span id=\"cb5-29\"><a href=\"#cb5-29\" aria-hidden=\"true\"><\/a>    post_range <span class=\"op\">=<\/span> np.linspace(<span class=\"op\">*<\/span>ax.get_xlim(), num<span class=\"op\">=<\/span><span class=\"bu\">int<\/span>(num), endpoint<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb5-30\"><a href=\"#cb5-30\" aria-hidden=\"true\"><\/a>    post_data <span class=\"op\">=<\/span> [func(v) <span class=\"cf\">for<\/span> v <span class=\"kw\">in<\/span> post_range]<\/span>\n<span id=\"cb5-31\"><a href=\"#cb5-31\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">return<\/span> ax.plot(post_range, post_data, label<span class=\"op\">=<\/span>label)<\/span>\n<span id=\"cb5-32\"><a href=\"#cb5-32\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-33\"><a href=\"#cb5-33\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-34\"><a href=\"#cb5-34\" aria-hidden=\"true\"><\/a><span class=\"co\"># Add true posterior pdf to the plot<\/span><\/span>\n<span id=\"cb5-35\"><a href=\"#cb5-35\" aria-hidden=\"true\"><\/a>add_function_plot(post_pdf, freq_axis, label<span class=\"op\">=<\/span><span class=\"vs\">r&quot;Exact&quot;<\/span>)<\/span>\n<span id=\"cb5-36\"><a href=\"#cb5-36\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-37\"><a href=\"#cb5-37\" aria-hidden=\"true\"><\/a><span class=\"co\"># Add manually produced MH samples to the plot<\/span><\/span>\n<span id=\"cb5-38\"><a href=\"#cb5-38\" aria-hidden=\"true\"><\/a>sns.distplot(theta_samples[:<span class=\"dv\">2000<\/span>], ax<span class=\"op\">=<\/span>freq_axis, hist<span class=\"op\">=<\/span><span class=\"va\">False<\/span>, label<span class=\"op\">=<\/span><span class=\"vs\">r&quot;Manual MH&quot;<\/span>)<\/span>\n<span id=\"cb5-39\"><a href=\"#cb5-39\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-40\"><a href=\"#cb5-40\" 
aria-hidden=\"true\"><\/a>sample_axis.plot(theta_samples[:<span class=\"dv\">2000<\/span>], label<span class=\"op\">=<\/span><span class=\"vs\">r&quot;Manual MH&quot;<\/span>)<\/span>\n<span id=\"cb5-41\"><a href=\"#cb5-41\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb5-42\"><a href=\"#cb5-42\" aria-hidden=\"true\"><\/a>freq_axis.legend()<\/span>\n<span id=\"cb5-43\"><a href=\"#cb5-43\" aria-hidden=\"true\"><\/a>sample_axis.legend()<\/span>\n<span id=\"cb5-44\"><a href=\"#cb5-44\" aria-hidden=\"true\"><\/a>plt.show()<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/theta_post_plot.png\" title=\"fig:\" alt=\"Posterior samples \" \/>\n<figcaption>\nPosterior samples\n<\/figcaption>\n<\/figure>\n<p><a id=\"org82a5c5b\"><\/a><\/p>\n<\/section>\n<section id=\"the-costs\" class=\"level2\">\n<h2>The Costs<\/h2>\n<p>MCMC, and specifically the Metropolis-Hastings approach used above, can look very simple and universally applicable, but\u2013of course\u2013there\u2019s a trade-off occurring somewhere. The trade-offs most often appear in relation to the complexity and cost of [intermediate] sampling steps and convergence rates. To oversimplify, the standard <span class=\"math inline\">\\(O(N^{-1\/2})\\)<\/span> error rate that follows from the <a href=\"https:\/\/en.wikipedia.org\/wiki\/Central_limit_theorem\">Central Limit Theorem<\/a> (with a constant inflated by the autocorrelation between draws) is the MCMC baseline, which isn\u2019t all that competitive with the convergence rates of some standard deterministic optimization methods.<\/p>\n<p>Even for conceptually simple models, the proposal distribution and its parameters are not always easy to choose or cheap to tune. 
The upfront computational costs can be quite high for the more generic MCMC approaches, but there are almost always paths toward efficient samplers\u2013in the context of a specific problem, at least.<\/p>\n<p>In practice, the generality and relative simplicity of the Bayes approach, combined with MCMC, can be somewhat misleading to newcomers. After some immediate success with simpler and\/or scaled-down problems, one is soon led to believe, once the problems grow, that the cost of direct computations and the effort and skill required to derive efficient methods are not worth the potential parsimony and extra information provided by sample results.<\/p>\n<p>The unfortunate outcome of this situation is sometimes an effective rejection of Bayes and MCMC altogether. Although the point hasn\u2019t been illustrated here, MCMC isn\u2019t the only option. <strong>Bayesian models are just as amenable to deterministic estimation as non-Bayesian ones<\/strong>, and a wide array of efficient deterministic estimation techniques is available, albeit not so common in standard practice <a id=\"1b02eb0eebf5ae001cbb5ff5b74acff1\"><a href=\"#polson_proximal_2015\">(Polson, Scott &amp; Willard 2015)<\/a><\/a>.<\/p>\n<p><a id=\"orgfe79892\"><\/a><\/p>\n<\/section>\n<\/section>\n<section id=\"predictions\" class=\"level1\">\n<h1>Predictions<\/h1>\n<p>The samples provided by MCMC (and the Bayesian setup) put one in a good position to make extensive use of predictions <em>and<\/em> obtain uncertainty measures (e.g.\u00a0variances, credible intervals, etc.).<\/p>\n<p>In general, posterior predictive samples are fairly easy to obtain. Once you have posterior samples of <span class=\"math inline\">\\(\\theta\\)<\/span>, say <span class=\"math inline\">\\(\\{\\theta_i\\}_{i=0}^M\\)<\/span>, simply plug those into the sampling\/observation distribution and sample <span class=\"math inline\">\\(Y\\)<\/span> values. 
Specifically,<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  \\{y_i \\sim p(Y \\mid X, \\theta_i) : \\theta_i \\sim p(\\theta \\mid y)\\}_{i=0}^M\n  \\label{eq:post_predict_samples}\n\\end{equation}\\]<\/span><\/p>\n<p>is a posterior predictive sample from <span class=\"math inline\">\\(p(Y \\mid X, y)\\)<\/span>.<\/p>\n<p>Procedurally, <span class=\"math inline\">\\(\\eqref{eq:post_predict_samples}\\)<\/span> says that, once we\u2019ve produced a posterior sample, we plug each <span class=\"math inline\">\\(\\theta_i\\)<\/span> into the observation distribution <span class=\"math inline\">\\(\\eqref{eq:normal-normal}\\)<\/span> and sample. The resulting draws are a Monte Carlo sample from the marginal<\/p>\n<p><span class=\"math display\">\\[\n  \\int p(Y_t \\mid x_t, \\theta) p(\\theta \\mid y) d\\theta = p(Y_t \\mid x_t, y)\n  \\;.\n\\]<\/span><\/p>\n<p>The posterior predictive sample in <span class=\"math inline\">\\(\\eqref{eq:post_predict_samples}\\)<\/span> contains much of the information a modeler desires. Take the variance of this sample and one has a common measure of prediction error; compute quantiles of the sample and one has <a href=\"https:\/\/en.wikipedia.org\/wiki\/Credible_interval\">\u201ccredible\u201d<\/a> prediction intervals. Applying an arbitrary function to each posterior predictive draw produces a new sample that admits the same summaries, so one can easily produce error estimates for complicated uses of predicted quantities. 
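<\/p>\n<p>As a minimal sketch of this recipe, the following NumPy code assumes a scalar observation model <span class=\"math inline\">\\(Y_t \\sim \\operatorname{N}(x_t \\theta, \\sigma^2)\\)<\/span> and uses hypothetical stand-ins in place of real MCMC output: <code>theta_draws<\/code>, <code>x<\/code> and <code>sigma<\/code> are made up for illustration and are not part of the article\u2019s simulation.<\/p>\n<figure>\n<div class=\"sourceCode\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\">import numpy as np

rng = np.random.default_rng(42)

# Hypothetical stand-ins: posterior draws of theta (e.g. from MCMC),
# exogenous values x_t, and a known observation noise scale sigma
theta_draws = rng.normal(loc=2.0, scale=0.1, size=1000)
x = np.linspace(1.0, 3.0, num=24)
sigma = 0.5

# One predictive draw y_i ~ N(x * theta_i, sigma^2) per posterior draw;
# broadcasting gives an array of shape (n_draws, n_times)
y_pred = rng.normal(loc=theta_draws[:, None] * x[None, :], scale=sigma)

# Per-time summaries of the posterior predictive sample
pred_var = y_pred.var(axis=0)
lower, upper = np.quantile(y_pred, [0.025, 0.975], axis=0)

# Any function of the predictions is summarized the same way, e.g. the
# average over all time points, computed separately for each draw
f_draws = y_pred.mean(axis=1)
f_lower, f_upper = np.quantile(f_draws, [0.025, 0.975])<\/code><\/pre><\/div>\n<\/figure>\n<p>Since each predictive draw uses its own <span class=\"math inline\">\\(\\theta_i\\)<\/span>, the spread of the posterior sample propagates into <code>pred_var<\/code>, which should exceed <span class=\"math inline\">\\(\\sigma^2\\)<\/span>. 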
We illustrate these use cases below.<\/p>\n<p>Using our previous simulation and PyMC3, the posterior predictive samples are obtained with<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb6\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb6-1\"><a href=\"#cb6-1\" aria-hidden=\"true\"><\/a>ppc_samples <span class=\"op\">=<\/span> pm.sample_posterior_predictive(sample_traces, model<span class=\"op\">=<\/span>model)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<p>and plotted with<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb7\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb7-1\"><a href=\"#cb7-1\" aria-hidden=\"true\"><\/a>y_obs_h <span class=\"op\">=<\/span> pd.Series(y_obs, index<span class=\"op\">=<\/span>sim_index)<\/span>\n<span id=\"cb7-2\"><a href=\"#cb7-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-3\"><a href=\"#cb7-3\" aria-hidden=\"true\"><\/a>ppc_hpd <span class=\"op\">=<\/span> pm.hpd(ppc_samples[<span class=\"st\">&quot;Y&quot;<\/span>], <span class=\"fl\">0.95<\/span>)<\/span>\n<span id=\"cb7-4\"><a href=\"#cb7-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-5\"><a href=\"#cb7-5\" aria-hidden=\"true\"><\/a>y_obs_h.plot(label<span class=\"op\">=<\/span><span class=\"st\">&quot;$y$&quot;<\/span>, color<span class=\"op\">=<\/span><span class=\"st\">&quot;black&quot;<\/span>)<\/span>\n<span id=\"cb7-6\"><a href=\"#cb7-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-7\"><a href=\"#cb7-7\" aria-hidden=\"true\"><\/a>y_obs_mean <span class=\"op\">=<\/span> pd.Series(ppc_samples[<span class=\"st\">&quot;Y&quot;<\/span>].mean(axis<span class=\"op\">=<\/span><span class=\"dv\">0<\/span>), index<span class=\"op\">=<\/span>sim_index)<\/span>\n<span id=\"cb7-8\"><a href=\"#cb7-8\" aria-hidden=\"true\"><\/a>y_obs_mean.plot(label<span class=\"op\">=<\/span><span class=\"vs\">r&quot;$E[Y \\mid X, y]$&quot;<\/span>, alpha<span class=\"op\">=<\/span><span 
class=\"fl\">0.7<\/span>)<\/span>\n<span id=\"cb7-9\"><a href=\"#cb7-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-10\"><a href=\"#cb7-10\" aria-hidden=\"true\"><\/a>plt.fill_between(<\/span>\n<span id=\"cb7-11\"><a href=\"#cb7-11\" aria-hidden=\"true\"><\/a>    sim_index,<\/span>\n<span id=\"cb7-12\"><a href=\"#cb7-12\" aria-hidden=\"true\"><\/a>    ppc_hpd[:, <span class=\"dv\">0<\/span>],<\/span>\n<span id=\"cb7-13\"><a href=\"#cb7-13\" aria-hidden=\"true\"><\/a>    ppc_hpd[:, <span class=\"dv\">1<\/span>],<\/span>\n<span id=\"cb7-14\"><a href=\"#cb7-14\" aria-hidden=\"true\"><\/a>    label<span class=\"op\">=<\/span><span class=\"vs\">r&quot;$(Y \\mid X, y)$ 95\\<\/span><span class=\"sc\">% i<\/span><span class=\"vs\">nterval&quot;<\/span>,<\/span>\n<span id=\"cb7-15\"><a href=\"#cb7-15\" aria-hidden=\"true\"><\/a>    alpha<span class=\"op\">=<\/span><span class=\"fl\">0.5<\/span>,<\/span>\n<span id=\"cb7-16\"><a href=\"#cb7-16\" aria-hidden=\"true\"><\/a>)<\/span>\n<span id=\"cb7-17\"><a href=\"#cb7-17\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb7-18\"><a href=\"#cb7-18\" aria-hidden=\"true\"><\/a>plt.legend()<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/hourly_ppc_plot.png\" title=\"fig:\" alt=\"Posterior predictive samples \" \/>\n<figcaption>\nPosterior predictive samples\n<\/figcaption>\n<\/figure>\n<div class=\"example\" data-markdown=\"\">\n<p>Let\u2019s say we\u2019re interested in daily, monthly, or yearly averages of <span class=\"math inline\">\\(Y_t\\)<\/span>, which is observed at a higher frequency (e.g.\u00a0minutes or hours). Similarly, we might want to consider functions of differences between the outputs of different models, <span class=\"math inline\">\\(f(Y^{(j)} - Y^{(k)})\\)<\/span> for <span class=\"math inline\">\\(j, k \\in \\{1, 2\\}\\)<\/span>, or more generally <span class=\"math inline\">\\(f(Y^{(j)}, Y^{(k)})\\)<\/span>. 
These quantities are easily derived from simple manipulations of the posterior predictive samples in <code>ppc_samples<\/code>.<\/p>\n<\/div>\n<p>Next, we produce predictions for daily averages, along with (credible) intervals.<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb8\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb8-1\"><a href=\"#cb8-1\" aria-hidden=\"true\"><\/a><span class=\"co\"># One column per posterior predictive draw, indexed by hour<\/span><\/span>\n<span id=\"cb8-2\"><a href=\"#cb8-2\" aria-hidden=\"true\"><\/a>ppc_samples_h <span class=\"op\">=<\/span> pd.DataFrame(ppc_samples[<span class=\"st\">&quot;Y&quot;<\/span>].T, index<span class=\"op\">=<\/span>sim_index)<\/span>\n<span id=\"cb8-3\"><a href=\"#cb8-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb8-4\"><a href=\"#cb8-4\" aria-hidden=\"true\"><\/a><span class=\"co\"># f(Y): the daily average of each draw<\/span><\/span>\n<span id=\"cb8-5\"><a href=\"#cb8-5\" aria-hidden=\"true\"><\/a>ppc_samples_d <span class=\"op\">=<\/span> ppc_samples_h.resample(<span class=\"st\">&quot;D&quot;<\/span>).mean()<\/span>\n<span id=\"cb8-6\"><a href=\"#cb8-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb8-7\"><a href=\"#cb8-7\" aria-hidden=\"true\"><\/a><span class=\"co\"># Quantiles of the daily averages across draws (one column per quantile)<\/span><\/span>\n<span id=\"cb8-8\"><a href=\"#cb8-8\" aria-hidden=\"true\"><\/a>ppc_quantiles_d <span class=\"op\">=<\/span> ppc_samples_d.quantile(q<span class=\"op\">=<\/span>[<span class=\"fl\">0.05<\/span>, <span class=\"fl\">0.5<\/span>, <span class=\"fl\">0.95<\/span>], axis<span class=\"op\">=<\/span><span class=\"dv\">1<\/span>).T<\/span>\n<span id=\"cb8-9\"><a href=\"#cb8-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb8-10\"><a href=\"#cb8-10\" aria-hidden=\"true\"><\/a>y_obs_d <span class=\"op\">=<\/span> y_obs_h.resample(<span class=\"st\">&quot;D&quot;<\/span>).mean()<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure>\n<div class=\"sourceCode\" 
id=\"cb9\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb9-1\"><a href=\"#cb9-1\" aria-hidden=\"true\"><\/a>plt.clf()<\/span>\n<span id=\"cb9-2\"><a href=\"#cb9-2\" aria-hidden=\"true\"><\/a>y_obs_d.plot(label<span class=\"op\">=<\/span><span class=\"st\">&#39;$f(y)$&#39;<\/span>, color<span class=\"op\">=<\/span><span class=\"st\">&#39;black&#39;<\/span>)<\/span>\n<span id=\"cb9-3\"><a href=\"#cb9-3\" aria-hidden=\"true\"><\/a>plt.fill_between(ppc_quantiles_d.index,<\/span>\n<span id=\"cb9-4\"><a href=\"#cb9-4\" aria-hidden=\"true\"><\/a>                 ppc_quantiles_d[<span class=\"fl\">0.05<\/span>],<\/span>\n<span id=\"cb9-5\"><a href=\"#cb9-5\" aria-hidden=\"true\"><\/a>                 ppc_quantiles_d[<span class=\"fl\">0.95<\/span>],<\/span>\n<span id=\"cb9-6\"><a href=\"#cb9-6\" aria-hidden=\"true\"><\/a>                 label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$(f(Y) \\mid X, y)$ 90\\<\/span><span class=\"sc\">% i<\/span><span class=\"vs\">nterval&#39;<\/span>,<\/span>\n<span id=\"cb9-7\"><a href=\"#cb9-7\" aria-hidden=\"true\"><\/a>                 alpha<span class=\"op\">=<\/span><span class=\"fl\">0.5<\/span>)<\/span>\n<span id=\"cb9-8\"><a href=\"#cb9-8\" aria-hidden=\"true\"><\/a>ppc_quantiles_d[<span class=\"fl\">0.5<\/span>].plot(label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$\\mathrm{median}[f(Y) \\mid X, y]$&#39;<\/span>, alpha<span class=\"op\">=<\/span><span class=\"fl\">.7<\/span>)<\/span>\n<span id=\"cb9-9\"><a href=\"#cb9-9\" aria-hidden=\"true\"><\/a>plt.legend()<\/span>\n<span id=\"cb9-10\"><a href=\"#cb9-10\" aria-hidden=\"true\"><\/a>plt.show()<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/daily_ppc_plot.png\" title=\"fig:\" alt=\"Daily posterior predictive results from the hourly posterior. 
\" \/>\n<figcaption>\nDaily posterior predictive results from the hourly posterior.\n<\/figcaption>\n<\/figure>\n<p><a id=\"org72d572e\"><\/a><\/p>\n<\/section>\n<section id=\"hierarchical-extensions\" class=\"level1\">\n<h1>Hierarchical Extensions<\/h1>\n<p>Even though we only considered \u201cin-sample\u201d predictions in the previous section, out-of-sample and missing values are covered by exactly the same process (neatly simplified by PyMC3\u2019s <code>sample_posterior_predictive<\/code>). In our example we needed an exogenous variable <span class=\"math inline\">\\(x_t\\)<\/span> in order to sample a point from the observation model <span class=\"math inline\">\\((Y_t \\mid x_t)\\)<\/span>. When the values in <span class=\"math inline\">\\(X\\)<\/span> cannot be obtained (e.g.\u00a0future values of a non-deterministic quantity), clever, context-specific imputations are usually proposed.<\/p>\n<p>Nearly every instance of such imputations gives rise to an implicit model. Going back to our preference for transparent statistical specification, it behooves us to formally specify the model. If we do so in a well-defined Bayes way, then we immediately get the same conveniences as above.<\/p>\n<div class=\"example\" data-markdown=\"\">\n<p><a id=\"org729ed4b\"><\/a> If the <span class=\"math inline\">\\(X\\)<\/span> values in our sample now correspond to, say, temperature, and today is the last day in our time-indexed observations <code>y_obs<\/code>, then predicting forward in time will require future temperatures.<\/p>\n<\/div>\n<p>One answer to this situation is a model for <span class=\"math inline\">\\(x_t\\)<\/span>. If we specify some <span class=\"math inline\">\\(X_t \\sim P\\)<\/span>, then we can apply the same principles as above via the posterior predictive <span class=\"math inline\">\\(p(X_t \\mid X)\\)<\/span>. 
This posterior predictive will have no exogenous dependencies (unless we want it to), and the posterior of its parameters can be estimated from our observed <span class=\"math inline\">\\(X\\)<\/span> values. All this occurs in exactly the same fashion as our model for <span class=\"math inline\">\\(Y_t\\)<\/span>.<\/p>\n<p>In practice, one often sees summary statistics of previous <span class=\"math inline\">\\(x_t\\)<\/span> observations, taken over intervals representative of the desired prediction period, used as imputations. For instance, in the context of Example <a href=\"#org729ed4b\">2<\/a>, one might use the average temperatures in previous years over the months corresponding to the prediction interval (e.g.\u00a0January-February averages from 2010 to 2016 as imputations for January-February 2017).<\/p>\n<p>This isn\u2019t a bad idea, per se, but it is a needlessly indirect, and often insufficient, approach to defining a statistical model for <span class=\"math inline\">\\(X\\)<\/span>. It leaves out critical distributional details, the same details needed to determine how anything using our new <span class=\"math inline\">\\(x_t\\)<\/span> estimates might be affected (through <a href=\"https:\/\/en.wikipedia.org\/wiki\/Propagation_of_uncertainty\">propagation of uncertainty<\/a>). Eventually one comes around to specifying these details, but, in situations of sufficient complexity, this practice doesn\u2019t produce a very clean, manageable or easily extensible model.<\/p>\n<p>The kinds of complicated models arising in these situations are both conceptually and technically difficult to use, and, as a result, it can be very hard to produce anything other than naive asymptotic approximations for errors and intervals. 
Sadly, these approximations are generally insufficient for all but the simplest scenarios.<\/p>\n<p>In contrast, we can model the <span class=\"math inline\">\\(x_t\\)<\/span> values directly and have a clear-cut path toward out-of-sample predictions and their distributional properties. Even if we believe that the previous average values are a reasonable imputation, a number of simple models can encode that assumption.<\/p>\n<div class=\"example\" data-markdown=\"\">\n<p><a id=\"orge850896\"><\/a> Let\u2019s consider a normal regression model for <span class=\"math inline\">\\(x_t\\)<\/span> with seasonal factors, i.e.<\/p>\n<p><span class=\"math display\">\\[\\begin{equation}\n  X_t \\sim \\operatorname{N}(d(t)^\\top \\beta, I \\sigma_x^2)\n  \\label{eq:exogenous_model}\n\\end{equation}\\]<\/span><\/p>\n<p>where <span class=\"math inline\">\\(d(t)\\)<\/span> is an indicator vector containing the seasonal factors and <span class=\"math inline\">\\(I\\)<\/span> is an identity matrix.<\/p>\n<p>Keep in mind that we\u2019ve stretched the notation a bit by letting <span class=\"math inline\">\\(X_t\\)<\/span> be a random vector at time <span class=\"math inline\">\\(t\\)<\/span>, while <span class=\"math inline\">\\(X\\)<\/span> is still the stacked matrix of observed <span class=\"math inline\">\\(x_t\\)<\/span> values. Now, we\u2019re simply adding the assumption <span class=\"math inline\">\\(x_t \\sim X_t\\)<\/span>.<\/p>\n<p>Let\u2019s say that our new <span class=\"math inline\">\\(\\beta\\)<\/span> vector has terms for each day of the week; this means the matrix of stacked <span class=\"math inline\">\\(d(t)\\)<\/span> values, <span class=\"math inline\">\\(D\\)<\/span>, is a classical factor design matrix with levels for each day. 
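<\/p>\n<p>A minimal NumPy sketch of such a design matrix, and of a single draw from <span class=\"math inline\">\\(\\eqref{eq:exogenous_model}\\)<\/span>, follows. The two-week hourly index, <code>beta<\/code> and <code>sigma_x<\/code> are hypothetical values chosen for illustration. The sketch uses one indicator column per day and no intercept; Patsy\u2019s <code>C(weekday)<\/code> formula, used later in the article, defaults to an intercept plus treatment-coded columns, which spans the same space.<\/p>\n<figure>\n<div class=\"sourceCode\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\">import numpy as np
import pandas as pd

# Hypothetical hourly index covering two weeks, starting Monday 2016-01-04
times = pd.date_range('2016-01-04', periods=14 * 24, freq=pd.offsets.Hour())
weekday = np.asarray(times.weekday)  # 0 = Monday, ..., 6 = Sunday

# Indicator design matrix D: row t is d(t), with a single 1 in the
# column for t's weekday (one column per day, no intercept)
D = np.zeros((len(times), 7))
D[np.arange(len(times)), weekday] = 1.0

# Hypothetical per-weekday means (beta) and exogenous noise scale (sigma_x)
beta = np.array([10.0, 11.0, 12.0, 11.5, 10.5, 8.0, 7.5])
sigma_x = 1.0

# Row-wise products d(t)^T beta
x_mean = D @ beta

# One simulated exogenous series, X_t ~ N(d(t)^T beta, sigma_x^2)
rng = np.random.default_rng(0)
x_sim = rng.normal(loc=x_mean, scale=sigma_x)<\/code><\/pre><\/div>\n<\/figure>\n<p>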
The product <span class=\"math inline\">\\(d(t)^\\top \\beta\\)<\/span> is then some scalar mean for the day corresponding to <span class=\"math inline\">\\(t\\)<\/span>.<\/p>\n<p>A simple substitution of this model for our previously constant <span class=\"math inline\">\\(X\\)<\/span> matrix results in a sort of hierarchical model, which we can now coherently marginalize to obtain the desired posterior predictive, <span class=\"math inline\">\\(p(Y \\mid y)\\)<\/span>. This time, the posterior predictive does not depend on observed values of <span class=\"math inline\">\\(X_t\\)<\/span>, so we can produce results for any <span class=\"math inline\">\\(t\\)<\/span>.<\/p>\n<p>The change in our complete model is relatively minimal. Writing <span class=\"math inline\">\\(X_t = d(t)^\\top \\beta + \\eta_t\\)<\/span> with <span class=\"math inline\">\\(\\eta_t \\sim \\operatorname{N}(0, \\sigma_x^2)\\)<\/span>, the observation model gives <span class=\"math inline\">\\(Y_t = X_t \\theta + \\epsilon_t = d(t)^\\top \\beta \\cdot \\theta + \\theta \\eta_t + \\epsilon_t\\)<\/span> with <span class=\"math inline\">\\(\\epsilon_t \\sim \\operatorname{N}(0, \\sigma^2)\\)<\/span>, so the model above for <span class=\"math inline\">\\(X\\)<\/span> results in the following marginal observation model:<\/p>\n<p><span class=\"math display\">\\[\\begin{align*}\n  p(Y_t \\mid \\beta, \\theta) &amp;=\n  \\int p(Y_t \\mid X_t, \\theta) p(X_t \\mid \\beta) dX_t\n  \\\\\n  \\left(Y_t \\mid \\beta, \\theta \\right) &amp;\\sim \\operatorname{N}\\left(\n  d(t)^\\top \\beta \\cdot \\theta,\n  \\sigma^2 + \\sigma_x^2 \\theta^2 \\right)\n  \\;.\n\\end{align*}\\]<\/span><\/p>\n<\/div>\n<p>The reduction in Example <a href=\"#orge850896\">3<\/a> could be considered an entire re-definition of our initial observation model in <span class=\"math inline\">\\(\\eqref{eq:normal-normal}\\)<\/span>. A change like this is a natural part of the standard model development cycle. However, this is not the only way to look at it. In the Bayesian setting we can keep the observation model fixed and iterate on the prior\u2019s specification. 
The resulting marginal distribution could effectively be the same under both approaches (if desired), but the latter has the advantage of preserving our earlier work, at least conditionally.<\/p>\n<div class=\"example\" data-markdown=\"\">\n<p>We haven\u2019t given a prior to <span class=\"math inline\">\\(\\beta\\)<\/span>, but if we did, in the absence of conflicting assumptions, we might want the product <span class=\"math inline\">\\(\\beta \\cdot \\theta\\)<\/span> to simplify to a single unknown variable of its own, so that we\u2019re not estimating two \u201centangled\u201d variables. This idea might be inspired by an understanding of the classical <a href=\"https:\/\/en.wikipedia.org\/wiki\/Parameter_identification_problem\">identification<\/a> issue arising from such products.<\/p>\n<p>With <span class=\"math inline\">\\(\\beta\\)<\/span> constant, the form of our marginal observation model is basically unchanged from our initial <span class=\"math inline\">\\(\\eqref{eq:normal-normal}\\)<\/span> under <span class=\"math inline\">\\(x_t \\to d(t)^\\top \\beta\\)<\/span> and <span class=\"math inline\">\\(\\sigma^2 \\to \\sigma^2 + \\sigma_x^2 \\theta^2\\)<\/span>, except that the observation variance now also depends on <span class=\"math inline\">\\(\\theta\\)<\/span>.<\/p>\n<\/div>\n<p>Adherence to established models or industry standards is not uncommon. Outside of hierarchical model development, it can be very difficult to make these connections and coherently propagate statistical assumptions.<\/p>\n<p>This model development process expands in complexity and applicability through natural and compartmental extensions of existing terms. Simpler, \u201cbase\u201d models are found as marginalizations of the new terms, and all the same estimation techniques apply.<\/p>\n<p>We\u2019ll close with an illustration of the piecewise exogenous variable model described in Example <a href=\"#orge850896\">3<\/a>. 
To demonstrate predictions for missing values, we mark the observations on one weekday as missing (the commented-out <code>reindex<\/code> call would instead extend the index with out-of-sample days), and the design matrix, <span class=\"math inline\">\\(D\\)<\/span>, for <span class=\"math inline\">\\(\\eqref{eq:exogenous_model}\\)<\/span> is produced using <a href=\"https:\/\/patsy.readthedocs.io\/en\/latest\/\">Patsy<\/a>.<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb10\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb10-1\"><a href=\"#cb10-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> patsy<\/span>\n<span id=\"cb10-2\"><a href=\"#cb10-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-3\"><a href=\"#cb10-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-4\"><a href=\"#cb10-4\" aria-hidden=\"true\"><\/a>ext_sim_index <span class=\"op\">=<\/span> pd.date_range(<\/span>\n<span id=\"cb10-5\"><a href=\"#cb10-5\" aria-hidden=\"true\"><\/a>    start<span class=\"op\">=<\/span><span class=\"st\">&quot;2016-01-01 12:00:00&quot;<\/span>, end<span class=\"op\">=<\/span><span class=\"st\">&quot;2016-01-16 12:00:00&quot;<\/span>, freq<span class=\"op\">=<\/span><span class=\"st\">&quot;H&quot;<\/span><\/span>\n<span id=\"cb10-6\"><a href=\"#cb10-6\" aria-hidden=\"true\"><\/a>)<\/span>\n<span id=\"cb10-7\"><a href=\"#cb10-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-8\"><a href=\"#cb10-8\" aria-hidden=\"true\"><\/a>y_obs_df <span class=\"op\">=<\/span> pd.DataFrame(y_obs, index<span class=\"op\">=<\/span>sim_index, columns<span class=\"op\">=<\/span>[<span class=\"vs\">r&quot;y&quot;<\/span>])<\/span>\n<span id=\"cb10-9\"><a href=\"#cb10-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-10\"><a href=\"#cb10-10\" aria-hidden=\"true\"><\/a><span class=\"co\"># The extra out-of-sample days are set to NaN<\/span><\/span>\n<span id=\"cb10-11\"><a href=\"#cb10-11\" aria-hidden=\"true\"><\/a><span class=\"co\"># y_obs_df = y_obs_df.reindex(ext_sim_index)<\/span><\/span>\n<span id=\"cb10-12\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-13\"><a href=\"#cb10-13\" aria-hidden=\"true\"><\/a>y_obs_df <span class=\"op\">=<\/span> y_obs_df.assign(weekday<span class=\"op\">=<\/span>y_obs_df.index.weekday)<\/span>\n<span id=\"cb10-14\"><a href=\"#cb10-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-15\"><a href=\"#cb10-15\" aria-hidden=\"true\"><\/a>y_df, D_df <span class=\"op\">=<\/span> patsy.dmatrices(<span class=\"st\">&quot;y ~ C(weekday)&quot;<\/span>, y_obs_df, return_type<span class=\"op\">=<\/span><span class=\"st\">&quot;dataframe&quot;<\/span>)<\/span>\n<span id=\"cb10-16\"><a href=\"#cb10-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-17\"><a href=\"#cb10-17\" aria-hidden=\"true\"><\/a><span class=\"co\"># Mark every Monday (weekday 0) as missing; done after building D_df, since Patsy drops rows with missing values by default<\/span><\/span>\n<span id=\"cb10-18\"><a href=\"#cb10-18\" aria-hidden=\"true\"><\/a>y_df.iloc[y_df.index.weekday <span class=\"op\">==<\/span> <span class=\"dv\">0<\/span>, <span class=\"dv\">0<\/span>] <span class=\"op\">=<\/span> np.nan<\/span><\/code><\/pre><\/div>\n<\/figure>\n<p>Again, with PyMC3 our model and its extension are easily expressed, and the missing observations will be sampled automatically.<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb11\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb11-1\"><a href=\"#cb11-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> theano.tensor <span class=\"im\">as<\/span> tt<\/span>\n<span id=\"cb11-2\"><a href=\"#cb11-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-3\"><a href=\"#cb11-3\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-4\"><a href=\"#cb11-4\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> pm.Model() <span class=\"im\">as<\/span> model:<\/span>\n<span id=\"cb11-5\"><a href=\"#cb11-5\" aria-hidden=\"true\"><\/a>    theta <span class=\"op\">=<\/span> pm.Normal(<span class=\"st\">&quot;theta&quot;<\/span>, mu<span class=\"op\">=<\/span>mu, tau<span 
class=\"op\">=<\/span><span class=\"fl\">1.0<\/span> <span class=\"op\">\/<\/span> tau2)<\/span>\n<span id=\"cb11-6\"><a href=\"#cb11-6\" aria-hidden=\"true\"><\/a>    beta <span class=\"op\">=<\/span> pm.Normal(<span class=\"st\">&quot;beta&quot;<\/span>, mu<span class=\"op\">=<\/span><span class=\"dv\">0<\/span>, sd<span class=\"op\">=<\/span><span class=\"fl\">1e1<\/span>, shape<span class=\"op\">=<\/span>(D_df.shape[<span class=\"op\">-<\/span><span class=\"dv\">1<\/span>],))<\/span>\n<span id=\"cb11-7\"><a href=\"#cb11-7\" aria-hidden=\"true\"><\/a>    mu_y <span class=\"op\">=<\/span> tt.transpose(tt.dot(D_df, beta)) <span class=\"op\">*<\/span> theta<\/span>\n<span id=\"cb11-8\"><a href=\"#cb11-8\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-9\"><a href=\"#cb11-9\" aria-hidden=\"true\"><\/a>    Y <span class=\"op\">=<\/span> pm.Normal(<span class=\"st\">&quot;Y&quot;<\/span>, mu<span class=\"op\">=<\/span>mu_y, tau<span class=\"op\">=<\/span><span class=\"fl\">1.0<\/span> <span class=\"op\">\/<\/span> sigma2, observed<span class=\"op\">=<\/span>y_df.y)<\/span>\n<span id=\"cb11-10\"><a href=\"#cb11-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-11\"><a href=\"#cb11-11\" aria-hidden=\"true\"><\/a><span class=\"cf\">with<\/span> model:<\/span>\n<span id=\"cb11-12\"><a href=\"#cb11-12\" aria-hidden=\"true\"><\/a>    sample_steps <span class=\"op\">=<\/span> [pm.Metropolis([theta]), pm.Metropolis([beta])]<\/span>\n<span id=\"cb11-13\"><a href=\"#cb11-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-14\"><a href=\"#cb11-14\" aria-hidden=\"true\"><\/a>    <span class=\"cf\">if<\/span> Y.missing_values <span class=\"kw\">is<\/span> <span class=\"kw\">not<\/span> <span class=\"va\">None<\/span>:<\/span>\n<span id=\"cb11-15\"><a href=\"#cb11-15\" aria-hidden=\"true\"><\/a>        sample_steps <span class=\"op\">+=<\/span> [pm.Metropolis(Y.missing_values)]<\/span>\n<span id=\"cb11-16\"><a href=\"#cb11-16\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-17\"><a href=\"#cb11-17\" aria-hidden=\"true\"><\/a>    sample_traces <span class=\"op\">=<\/span> pm.sample(<span class=\"dv\">2000<\/span>, sample_steps)<\/span>\n<span id=\"cb11-18\"><a href=\"#cb11-18\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb11-19\"><a href=\"#cb11-19\" aria-hidden=\"true\"><\/a>    ppc_samples <span class=\"op\">=<\/span> pm.sample_posterior_predictive(sample_traces)<\/span><\/code><\/pre><\/div>\n<\/figure>\n<p>The posterior predictive results are plotted below.<\/p>\n<figure>\n<div class=\"sourceCode\" id=\"cb12\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb12-1\"><a href=\"#cb12-1\" aria-hidden=\"true\"><\/a>ppc_y_samples <span class=\"op\">=<\/span> ppc_samples[<span class=\"st\">&#39;Y&#39;<\/span>]<\/span>\n<span id=\"cb12-2\"><a href=\"#cb12-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb12-3\"><a href=\"#cb12-3\" aria-hidden=\"true\"><\/a>ppc_mean_df <span class=\"op\">=<\/span> pd.DataFrame(ppc_y_samples.mean(axis<span class=\"op\">=<\/span><span class=\"dv\">0<\/span>),<\/span>\n<span id=\"cb12-4\"><a href=\"#cb12-4\" aria-hidden=\"true\"><\/a>                           index<span class=\"op\">=<\/span>sim_index,<\/span>\n<span id=\"cb12-5\"><a href=\"#cb12-5\" aria-hidden=\"true\"><\/a>                           columns<span class=\"op\">=<\/span>[<span class=\"vs\">r&#39;$E[Y \\mid y]$&#39;<\/span>])<\/span>\n<span id=\"cb12-6\"><a href=\"#cb12-6\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb12-7\"><a href=\"#cb12-7\" aria-hidden=\"true\"><\/a>ppc_hpd <span class=\"op\">=<\/span> pd.DataFrame(pm.hpd(ppc_y_samples, <span class=\"fl\">0.95<\/span>),<\/span>\n<span id=\"cb12-8\"><a href=\"#cb12-8\" aria-hidden=\"true\"><\/a>                       index<span class=\"op\">=<\/span>sim_index)<\/span>\n<span id=\"cb12-9\"><a href=\"#cb12-9\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb12-10\"><a href=\"#cb12-10\" 
aria-hidden=\"true\"><\/a>y_obs_df.y.plot(color<span class=\"op\">=<\/span><span class=\"st\">&#39;black&#39;<\/span>, subplots<span class=\"op\">=<\/span><span class=\"va\">False<\/span>)<\/span>\n<span id=\"cb12-11\"><a href=\"#cb12-11\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb12-12\"><a href=\"#cb12-12\" aria-hidden=\"true\"><\/a>missing_ins_range <span class=\"op\">=<\/span> sim_index[sim_index.weekday <span class=\"op\">==<\/span> <span class=\"dv\">0<\/span>]<\/span>\n<span id=\"cb12-13\"><a href=\"#cb12-13\" aria-hidden=\"true\"><\/a>plt.axvspan(missing_ins_range.<span class=\"bu\">min<\/span>(), missing_ins_range.<span class=\"bu\">max<\/span>(), alpha<span class=\"op\">=<\/span><span class=\"fl\">0.1<\/span>)<\/span>\n<span id=\"cb12-14\"><a href=\"#cb12-14\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb12-15\"><a href=\"#cb12-15\" aria-hidden=\"true\"><\/a>plt.fill_between(sim_index,<\/span>\n<span id=\"cb12-16\"><a href=\"#cb12-16\" aria-hidden=\"true\"><\/a>                 ppc_hpd[<span class=\"dv\">0<\/span>].values,<\/span>\n<span id=\"cb12-17\"><a href=\"#cb12-17\" aria-hidden=\"true\"><\/a>                 ppc_hpd[<span class=\"dv\">1<\/span>].values,<\/span>\n<span id=\"cb12-18\"><a href=\"#cb12-18\" aria-hidden=\"true\"><\/a>                 label<span class=\"op\">=<\/span><span class=\"vs\">r&#39;$(Y \\mid y)$ 95\\<\/span><span class=\"sc\">% i<\/span><span class=\"vs\">nterval&#39;<\/span>,<\/span>\n<span id=\"cb12-19\"><a href=\"#cb12-19\" aria-hidden=\"true\"><\/a>                 alpha<span class=\"op\">=<\/span><span class=\"fl\">0.5<\/span>)<\/span>\n<span id=\"cb12-20\"><a href=\"#cb12-20\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb12-21\"><a href=\"#cb12-21\" aria-hidden=\"true\"><\/a>ppc_mean_df.plot(ax<span class=\"op\">=<\/span>plt.axes(), alpha<span class=\"op\">=<\/span><span class=\"fl\">0.7<\/span>)<\/span>\n<span id=\"cb12-22\"><a href=\"#cb12-22\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb12-23\"><a 
href=\"#cb12-23\" aria-hidden=\"true\"><\/a>plt.legend()<\/span><\/code><\/pre><\/div>\n<\/figure>\n<figure id=\"nil\" class=\"plot\">\n<img src=\"https:\/\/brandonwillard.github.io\/figures\/temp_ppc_plot.png\" title=\"fig:\" alt=\"Posterior predictive results for the stochastic X model \" \/>\n<figcaption>\nPosterior predictive results for the stochastic <span class=\"math inline\">\\(X\\)<\/span> model\n<\/figcaption>\n<\/figure>\n<\/section>\n<section id=\"bibliography\" class=\"level1\">\n<h1>Bibliography<\/h1>\n<p><a id=\"gelman_bayesian_2013\"><\/a> Gelman, Carlin, Stern, Dunson, Vehtari &amp; Rubin, Bayesian Data Analysis, CRC Press (2013). <a href=\"#6824e74293e673dfd3d897debac1bd36\">\u21a9\ufe0e<\/a><\/p>\n<p><a id=\"polson_proximal_2015\"><\/a> Polson, Scott &amp; Willard, Proximal Algorithms in Statistics and Machine Learning, <i>Statistical Science<\/i>, <b>30(4)<\/b>, 559\u2013581 (2015). <a href=\"http:\/\/projecteuclid.org\/euclid.ss\/1449670858\">link<\/a>. <a href=\"#1b02eb0eebf5ae001cbb5ff5b74acff1\">\u21a9\ufe0e<\/a><\/p>\n<\/section>\n<\/body>\n<\/html>\n","category":[{"@attributes":{"term":"articles"}},{"@attributes":{"term":"pymc3"}},{"@attributes":{"term":"bayes"}}]},{"title":"SymPy Expression Tree Manipulation","link":{"@attributes":{"href":"https:\/\/brandonwillard.github.io\/sympy-expression-tree-manipulation.html","rel":"alternate"}},"published":"2016-10-27T00:00:00-05:00","updated":"2016-10-27T00:00:00-05:00","author":{"name":"Brandon Willard"},"id":"tag:brandonwillard.github.io,2016-10-27:\/sympy-expression-tree-manipulation.html","summary":{"@attributes":{"type":"html"}},"content":"<!DOCTYPE html PUBLIC \"-\/\/W3C\/\/DTD XHTML 1.0 Transitional\/\/EN\" \"http:\/\/www.w3.org\/TR\/xhtml1\/DTD\/xhtml1-transitional.dtd\">\n<html xmlns=\"http:\/\/www.w3.org\/1999\/xhtml\">\n<head>\n  <meta http-equiv=\"Content-Type\" content=\"text\/html; charset=utf-8\" \/>\n  <meta http-equiv=\"Content-Style-Type\" content=\"text\/css\" \/>\n  <meta 
name=\"generator\" content=\"pandoc\" \/>\n  <meta name=\"author\" content=\"Brandon Willard\" \/>\n  <title>SymPy Expression Tree Manipulation<\/title>\n  <style type=\"text\/css\">code{white-space: pre;}<\/style>\n  <style type=\"text\/css\">\npre > code.sourceCode { white-space: pre; position: relative; }\npre > code.sourceCode > span { display: inline-block; line-height: 1.25; }\npre > code.sourceCode > span:empty { height: 1.2em; }\ncode.sourceCode > span { color: inherit; text-decoration: inherit; }\ndiv.sourceCode { margin: 1em 0; }\npre.sourceCode { margin: 0; }\n@media screen {\ndiv.sourceCode { overflow: auto; }\n}\n@media print {\npre > code.sourceCode { white-space: pre-wrap; }\npre > code.sourceCode > span { text-indent: -5em; padding-left: 5em; }\n}\npre.numberSource code\n  { counter-reset: source-line 0; }\npre.numberSource code > span\n  { position: relative; left: -4em; counter-increment: source-line; }\npre.numberSource code > span > a:first-child::before\n  { content: counter(source-line);\n    position: relative; left: -1em; text-align: right; vertical-align: baseline;\n    border: none; display: inline-block;\n    -webkit-touch-callout: none; -webkit-user-select: none;\n    -khtml-user-select: none; -moz-user-select: none;\n    -ms-user-select: none; user-select: none;\n    padding: 0 4px; width: 4em;\n    color: #aaaaaa;\n  }\npre.numberSource { margin-left: 3em; border-left: 1px solid #aaaaaa;  padding-left: 4px; }\ndiv.sourceCode\n  {   }\n@media screen {\npre > code.sourceCode > span > a:first-child::before { text-decoration: underline; }\n}\ncode span.al { color: #ff0000; font-weight: bold; } \/* Alert *\/\ncode span.an { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Annotation *\/\ncode span.at { color: #7d9029; } \/* Attribute *\/\ncode span.bn { color: #40a070; } \/* BaseN *\/\ncode span.bu { } \/* BuiltIn *\/\ncode span.cf { color: #007020; font-weight: bold; } \/* ControlFlow *\/\ncode span.ch { color: #4070a0; } \/* 
Char *\/\ncode span.cn { color: #880000; } \/* Constant *\/\ncode span.co { color: #60a0b0; font-style: italic; } \/* Comment *\/\ncode span.cv { color: #60a0b0; font-weight: bold; font-style: italic; } \/* CommentVar *\/\ncode span.do { color: #ba2121; font-style: italic; } \/* Documentation *\/\ncode span.dt { color: #902000; } \/* DataType *\/\ncode span.dv { color: #40a070; } \/* DecVal *\/\ncode span.er { color: #ff0000; font-weight: bold; } \/* Error *\/\ncode span.ex { } \/* Extension *\/\ncode span.fl { color: #40a070; } \/* Float *\/\ncode span.fu { color: #06287e; } \/* Function *\/\ncode span.im { } \/* Import *\/\ncode span.in { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Information *\/\ncode span.kw { color: #007020; font-weight: bold; } \/* Keyword *\/\ncode span.op { color: #666666; } \/* Operator *\/\ncode span.ot { color: #007020; } \/* Other *\/\ncode span.pp { color: #bc7a00; } \/* Preprocessor *\/\ncode span.sc { color: #4070a0; } \/* SpecialChar *\/\ncode span.ss { color: #bb6688; } \/* SpecialString *\/\ncode span.st { color: #4070a0; } \/* String *\/\ncode span.va { color: #19177c; } \/* Variable *\/\ncode span.vs { color: #4070a0; } \/* VerbatimString *\/\ncode span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } \/* Warning *\/\n  <\/style>\n  <!--        <script src=\"https:\/\/cdn.jsdelivr.net\/npm\/mathjax@3\/es5\/tex-mml-chtml.js\" type=\"text\/javascript\"><\/script> -->\n  <script src=\"https:\/\/cdnjs.cloudflare.com\/ajax\/libs\/mathjax\/2.7.0\/MathJax.js?config=TeX-AMS_HTML\" id=\"MathJax-script\"><\/script>\n  <script>\n   MathJax.Hub.Config({\n       tex2jax: {\n           processEnvironments: true,\n           processRefs: false\n       },\n       TeX: {\n           equationNumbers: { autoNumber: \"AMS\" },\n           extensions: [\"AMSmath.js\",\"AMSsymbols.js\",\"noErrors.js\",\"noUndefined.js\"]\n       }\n   });\n  <\/script>\n<\/head>\n<body>\n<!--  -->\n<!-- <div id=\"header\"> -->\n<!-- <h1 
class=\"title\">SymPy Expression Tree Manipulation<\/h1> -->\n<!--  -->\n<!--  -->\n<!-- <h2 class=\"author\">Brandon Willard<\/h2> -->\n<!--  -->\n<!--  -->\n<!-- <h3 class=\"date\">2016\u201310\u201327<\/h3> -->\n<!--  -->\n<!-- <\/div> -->\n<!--  -->\n<p>I\u2019ve been working on some extensions to our special function computations in <a href=\"https:\/\/arxiv.org\/abs\/1605.04796\">Prediction risk for global-local shrinkage regression<\/a> and decided to employ <a href=\"https:\/\/github.com\/sympy\/sympy\">SymPy<\/a> as much as possible. Out of this came an <a href=\"https:\/\/bitbucket.org\/bayes-horseshoe-plus\/hsplus-python-pkg\/src\/master\/hsplus\/horn_symbolic.py\">implementation<\/a> of a bivariate confluent hypergeometric function: the <a href=\"https:\/\/en.wikipedia.org\/wiki\/Humbert_series\">Humbert<\/a> <span class=\"math inline\">\\(\\Phi_1\\)<\/span>. This implementation, along with some numeric ones, is available in a <a href=\"https:\/\/bitbucket.org\/bayes-horseshoe-plus\/hsplus-python-pkg\">Python package<\/a> and an <a href=\"https:\/\/bitbucket.org\/bayes-horseshoe-plus\/hsplus-r-pkg\">R package<\/a>.<\/p>\n<p>In the course of this work, some expectations appear as ratios of <span class=\"math inline\">\\(\\Phi_1\\)<\/span> functions, so it\u2019s helpful to have a symbolic replacement routine that identifies them. 
<a href=\"http:\/\/docs.sympy.org\/dev\/modules\/core.html#sympy.core.basic.Basic.match\">Pattern matching<\/a>, <a href=\"http:\/\/docs.sympy.org\/dev\/modules\/core.html#sympy.core.basic.Basic.find\">finding<\/a>, substitution and <a href=\"http:\/\/docs.sympy.org\/dev\/modules\/core.html#sympy.core.basic.Basic.replace\">replacement<\/a> are fairly standard in SymPy, so nothing special there; however, when you want something specific, it can get rather tricky.<\/p>\n<p>Personally, I\u2019ve found the approach offered by the <a href=\"https:\/\/github.com\/sympy\/sympy\/tree\/master\/sympy\/strategies\"><code>sympy.strategies<\/code><\/a> and <a href=\"https:\/\/github.com\/sympy\/sympy\/tree\/master\/sympy\/unify\"><code>sympy.unify<\/code><\/a> frameworks the most appealing. See the original discussion <a href=\"https:\/\/groups.google.com\/d\/msg\/sympy\/fspCavhbd9I\/vrzUitvgiuYJ\">here<\/a>. Their appeal comes mostly from how they organize the processes behind expression tree traversal and manipulation: with them, it\u2019s much easier to see how a very specific and non-trivial simplification or replacement could be accomplished and iteratively improved. These points are made very well in the posts <a href=\"http:\/\/matthewrocklin.com\/blog\/tags.html#SymPy-ref\">here<\/a>, so check them out.<\/p>\n<p>Let\u2019s say we want to write a function <code>as_expectations<\/code> that takes a <code>sympy.Expr<\/code> and replaces ratios of <span class=\"math inline\">\\(\\Phi_1\\)<\/span> functions according to the following pattern: <span class=\"math display\">\\[\\begin{equation}\nE[X^n] = \\frac{\\Phi_1(\\alpha, \\beta, \\gamma + n; x, y)}{\\Phi_1(\\alpha, \\beta, \\gamma; x, y)}\n\\;.\n\\label{eq:expectation}\n\\end{equation}\\]<\/span><\/p>\n<p>As an example, let\u2019s set up a situation in which <code>as_expectations<\/code> would be used, and, from there, attempt to construct our function. 
Naturally, this will involve a test expression with terms that we know match Equation\u00a0<span class=\"math inline\">\\(\\eqref{eq:expectation}\\)<\/span>:<\/p>\n<div class=\"sourceCode\" id=\"cb1\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb1-1\"><a href=\"#cb1-1\" aria-hidden=\"true\"><\/a><span class=\"im\">import<\/span> sympy <span class=\"im\">as<\/span> sp<\/span>\n<span id=\"cb1-2\"><a href=\"#cb1-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-3\"><a href=\"#cb1-3\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> hsplus.horn_symbolic <span class=\"im\">import<\/span> HornPhi1<\/span>\n<span id=\"cb1-4\"><a href=\"#cb1-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-5\"><a href=\"#cb1-5\" aria-hidden=\"true\"><\/a>a, b, g, z_1, z_2 <span class=\"op\">=<\/span> sp.symbols(<span class=\"st\">&#39;a, b, g, z_1, z_2&#39;<\/span>, real<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb1-6\"><a href=\"#cb1-6\" aria-hidden=\"true\"><\/a>phi1_1 <span class=\"op\">=<\/span> HornPhi1((a, b), (g,), z_1, z_2)<\/span>\n<span id=\"cb1-7\"><a href=\"#cb1-7\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-8\"><a href=\"#cb1-8\" aria-hidden=\"true\"><\/a>n <span class=\"op\">=<\/span> sp.Dummy(<span class=\"st\">&#39;n&#39;<\/span>, integer<span class=\"op\">=<\/span><span class=\"va\">True<\/span>, positive<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb1-9\"><a href=\"#cb1-9\" aria-hidden=\"true\"><\/a>i <span class=\"op\">=<\/span> sp.Dummy(<span class=\"st\">&#39;i&#39;<\/span>, integer<span class=\"op\">=<\/span><span class=\"va\">True<\/span>, nonnegative<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb1-10\"><a href=\"#cb1-10\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-11\"><a href=\"#cb1-11\" aria-hidden=\"true\"><\/a>phi1_2 <span class=\"op\">=<\/span> HornPhi1((a, b), (g <span 
class=\"op\">+<\/span> n,), z_1, z_2)<\/span>\n<span id=\"cb1-12\"><a href=\"#cb1-12\" aria-hidden=\"true\"><\/a>phi1_3 <span class=\"op\">=<\/span> HornPhi1((a, b), (g <span class=\"op\">+<\/span> n <span class=\"op\">-<\/span> i,), z_1, z_2)<\/span>\n<span id=\"cb1-13\"><a href=\"#cb1-13\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-14\"><a href=\"#cb1-14\" aria-hidden=\"true\"><\/a>r_1 <span class=\"op\">=<\/span> phi1_2<span class=\"op\">\/<\/span>phi1_1<\/span>\n<span id=\"cb1-15\"><a href=\"#cb1-15\" aria-hidden=\"true\"><\/a>r_2 <span class=\"op\">=<\/span> phi1_3<span class=\"op\">\/<\/span>phi1_1<\/span>\n<span id=\"cb1-16\"><a href=\"#cb1-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb1-17\"><a href=\"#cb1-17\" aria-hidden=\"true\"><\/a>expr <span class=\"op\">=<\/span> a <span class=\"op\">*<\/span> r_1 <span class=\"op\">-<\/span> b <span class=\"op\">*<\/span> r_1 <span class=\"op\">\/<\/span> g <span class=\"op\">+<\/span> sp.Sum(z_1<span class=\"op\">\/<\/span>z_2 <span class=\"op\">*<\/span> r_2 <span class=\"op\">-<\/span> <span class=\"dv\">3<\/span> <span class=\"op\">*<\/span> r_2, (i, <span class=\"dv\">0<\/span>,<\/span>\n<span id=\"cb1-18\"><a href=\"#cb1-18\" aria-hidden=\"true\"><\/a>n))<\/span><\/code><\/pre><\/div>\n<p>Our test expression <code>expr<\/code> looks like this<\/p>\n<div class=\"sourceCode\" id=\"cb2\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb2-1\"><a href=\"#cb2-1\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(sp.latex(expr, mode<span class=\"op\">=<\/span><span class=\"st\">&#39;equation*&#39;<\/span>, itex<span class=\"op\">=<\/span><span class=\"va\">True<\/span>))<\/span><\/code><\/pre><\/div>\n<p><span class=\"math display\">\\[\\begin{equation*}\n\\frac{a \\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad n + g,\n\\quad z_{1}, \\quad z_{2}\\right\n)\\right)}}{\\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad g,\n\\quad z_{1}, \\quad 
z_{2}\\right )\\right)}} + \\sum_{i=0}^{n}\n\\left(\\frac{z_{1} \\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b,\n\\quad - i + n + g, \\quad z_{1}, \\quad z_{2}\\right )\\right)}}{z_{2}\n\\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad g, \\quad z_{1},\n\\quad z_{2}\\right )\\right)}} - \\frac{3\n\\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad - i + n + g,\n\\quad z_{1}, \\quad z_{2}\\right\n)\\right)}}{\\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad g,\n\\quad z_{1}, \\quad z_{2}\\right )\\right)}}\\right) - \\frac{b\n\\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad n + g, \\quad\nz_{1}, \\quad z_{2}\\right )\\right)}}{g\n\\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad g, \\quad z_{1},\n\\quad z_{2}\\right )\\right)}}\n\\end{equation*}\\]<\/span><\/p>\n<p>The ratios <code>r_1<\/code> and <code>r_2<\/code> should both be replaced by a symbol for <span class=\"math inline\">\\(E[X^m]\\)<\/span>, for <span class=\"math inline\">\\(m = n\\)<\/span> and <span class=\"math inline\">\\(m = n - i\\)<\/span> when <span class=\"math inline\">\\(i &lt; n\\)<\/span> respectively. 
We could allow <span class=\"math inline\">\\(E[X^0]\\)<\/span>, I suppose, but\u2013for a more interesting discussion\u2013let\u2019s not.<\/p>\n<p>We start by creating a SymPy pattern that expresses the mathematical form of <span class=\"math inline\">\\(E[X^m]\\)<\/span> in Equation\u00a0<span class=\"math inline\">\\(\\eqref{eq:expectation}\\)<\/span>.<\/p>\n<div class=\"sourceCode\" id=\"cb3\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb3-1\"><a href=\"#cb3-1\" aria-hidden=\"true\"><\/a>pnames <span class=\"op\">=<\/span> (<span class=\"st\">&#39;a&#39;<\/span>, <span class=\"st\">&#39;b&#39;<\/span>, <span class=\"st\">&#39;g&#39;<\/span>, <span class=\"st\">&#39;z_1&#39;<\/span>, <span class=\"st\">&#39;z_2&#39;<\/span>)<\/span>\n<span id=\"cb3-2\"><a href=\"#cb3-2\" aria-hidden=\"true\"><\/a>phi1_wild_args_n <span class=\"op\">=<\/span> sp.symbols(<span class=\"st\">&#39;,&#39;<\/span>.join(n_ <span class=\"op\">+<\/span> <span class=\"st\">&#39;_w&#39;<\/span> <span class=\"cf\">for<\/span> n_ <span class=\"kw\">in<\/span> pnames),<\/span>\n<span id=\"cb3-3\"><a href=\"#cb3-3\" aria-hidden=\"true\"><\/a>                              cls<span class=\"op\">=<\/span>sp.Wild, real<span class=\"op\">=<\/span><span class=\"va\">True<\/span>)<\/span>\n<span id=\"cb3-4\"><a href=\"#cb3-4\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-5\"><a href=\"#cb3-5\" aria-hidden=\"true\"><\/a>n_w <span class=\"op\">=<\/span> sp.Wild(<span class=\"st\">&#39;n_w&#39;<\/span>,<\/span>\n<span id=\"cb3-6\"><a href=\"#cb3-6\" aria-hidden=\"true\"><\/a>              properties<span class=\"op\">=<\/span>(<span class=\"kw\">lambda<\/span> x: x.is_integer <span class=\"kw\">and<\/span> x.is_positive,),<\/span>\n<span id=\"cb3-7\"><a href=\"#cb3-7\" aria-hidden=\"true\"><\/a>              exclude<span class=\"op\">=<\/span>(phi1_wild_args_n[<span class=\"dv\">2<\/span>],))<\/span>\n<span id=\"cb3-8\"><a href=\"#cb3-8\" 
aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-9\"><a href=\"#cb3-9\" aria-hidden=\"true\"><\/a>phi1_wild_d <span class=\"op\">=<\/span> HornPhi1(phi1_wild_args_n[<span class=\"dv\">0<\/span>:<span class=\"dv\">2<\/span>],<\/span>\n<span id=\"cb3-10\"><a href=\"#cb3-10\" aria-hidden=\"true\"><\/a>                       phi1_wild_args_n[<span class=\"dv\">2<\/span>:<span class=\"dv\">3<\/span>],<\/span>\n<span id=\"cb3-11\"><a href=\"#cb3-11\" aria-hidden=\"true\"><\/a>                       <span class=\"op\">*<\/span>phi1_wild_args_n[<span class=\"dv\">3<\/span>:<span class=\"dv\">5<\/span>])<\/span>\n<span id=\"cb3-12\"><a href=\"#cb3-12\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-13\"><a href=\"#cb3-13\" aria-hidden=\"true\"><\/a>phi1_wild_n <span class=\"op\">=<\/span> HornPhi1(phi1_wild_args_n[<span class=\"dv\">0<\/span>:<span class=\"dv\">2<\/span>],<\/span>\n<span id=\"cb3-14\"><a href=\"#cb3-14\" aria-hidden=\"true\"><\/a>                       (phi1_wild_args_n[<span class=\"dv\">2<\/span>] <span class=\"op\">+<\/span> n_w,),<\/span>\n<span id=\"cb3-15\"><a href=\"#cb3-15\" aria-hidden=\"true\"><\/a>                       <span class=\"op\">*<\/span>phi1_wild_args_n[<span class=\"dv\">3<\/span>:<span class=\"dv\">5<\/span>])<\/span>\n<span id=\"cb3-16\"><a href=\"#cb3-16\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-17\"><a href=\"#cb3-17\" aria-hidden=\"true\"><\/a>C_w <span class=\"op\">=<\/span> sp.Wild(<span class=\"st\">&#39;C_w&#39;<\/span>, exclude<span class=\"op\">=<\/span>[sp.S.Zero])<\/span>\n<span id=\"cb3-18\"><a href=\"#cb3-18\" aria-hidden=\"true\"><\/a>E_pattern <span class=\"op\">=<\/span> phi1_wild_n <span class=\"op\">\/<\/span> phi1_wild_d<\/span>\n<span id=\"cb3-19\"><a href=\"#cb3-19\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb3-20\"><a href=\"#cb3-20\" aria-hidden=\"true\"><\/a>E_fn <span class=\"op\">=<\/span> sp.Function(<span class=\"st\">&quot;E&quot;<\/span>, real<span class=\"op\">=<\/span><span 
class=\"va\">True<\/span>)<\/span><\/code><\/pre><\/div>\n<p>When we find an <span class=\"math inline\">\\(E[X^m]\\)<\/span> we\u2019ll replace it with the symbolic function <code>E_fn<\/code>.<\/p>\n<p>If we focus on only one of the terms (one we know matches <code>E_pattern<\/code>), <code>r_1<\/code>, we should find that our pattern suffices:<\/p>\n<div class=\"sourceCode\" id=\"cb4\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb4-1\"><a href=\"#cb4-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> r_1.match(E_pattern)<\/span>\n<span id=\"cb4-2\"><a href=\"#cb4-2\" aria-hidden=\"true\"><\/a>{n_w_: _n, z_2_w_: z_2, z_1_w_: z_1, a_w_: a, g_w_: g, b_w_: b}<\/span><\/code><\/pre><\/div>\n<p>However, building up to the complexity of <code>expr<\/code>, we see that a simple product doesn\u2019t:<\/p>\n<div class=\"sourceCode\" id=\"cb5\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb5-1\"><a href=\"#cb5-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> (a <span class=\"op\">*<\/span> r_1).match(E_pattern)<\/span><\/code><\/pre><\/div>\n<p>Basically, the product has introduced some problems that arise from associativity. 
Here are the details for the root expression tree:<\/p>\n<div class=\"sourceCode\" id=\"cb6\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb6-1\"><a href=\"#cb6-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> (a <span class=\"op\">*<\/span> r_1).func<\/span>\n<span id=\"cb6-2\"><a href=\"#cb6-2\" aria-hidden=\"true\"><\/a><span class=\"op\">&lt;<\/span><span class=\"kw\">class<\/span> <span class=\"st\">&#39;sympy.core.mul.Mul&#39;<\/span><span class=\"op\">&gt;<\/span><\/span>\n<span id=\"cb6-3\"><a href=\"#cb6-3\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> (a <span class=\"op\">*<\/span> r_1).args<\/span>\n<span id=\"cb6-4\"><a href=\"#cb6-4\" aria-hidden=\"true\"><\/a>(a, <span class=\"dv\">1<\/span><span class=\"op\">\/<\/span>HornPhi1(a, b, g, z_1, z_2), HornPhi1(a, b, _n <span class=\"op\">+<\/span> g, z_1, z_2))<\/span><\/code><\/pre><\/div>\n<p>The root operation is multiplication and the operation\u2019s arguments are all terms in the product\/division.<\/p>\n<p>Any complete search for matches to <code>E_pattern<\/code> would have to consider all possible combinations of terms in <code>(a * r_1).args<\/code>, i.e.\u00a0all possible groupings that arise due to associativity. 
The simple inclusion of another <code>Wild<\/code> term causes the match to succeed, since SymPy\u2019s basic pattern matching does account for associativity in this case.<\/p>\n<p>Here are a few explicit ways to make the match work:<\/p>\n<div class=\"sourceCode\" id=\"cb7\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb7-1\"><a href=\"#cb7-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> (a <span class=\"op\">*<\/span> r_1).match(C_w <span class=\"op\">*<\/span> E_pattern)<\/span>\n<span id=\"cb7-2\"><a href=\"#cb7-2\" aria-hidden=\"true\"><\/a>{a_w_: a, n_w_: _n, g_w_: g, z_2_w_: z_2, C_w_: a, b_w_: b, z_1_w_:<\/span>\n<span id=\"cb7-3\"><a href=\"#cb7-3\" aria-hidden=\"true\"><\/a>z_1}<\/span><\/code><\/pre><\/div>\n<p>or as a replacement:<\/p>\n<div class=\"sourceCode\" id=\"cb8\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb8-1\"><a href=\"#cb8-1\" aria-hidden=\"true\"><\/a>res <span class=\"op\">=<\/span> (a <span class=\"op\">*<\/span> r_1).replace(C_w <span class=\"op\">*<\/span> E_pattern, C_w <span class=\"op\">*<\/span> E_fn(n_w,<\/span>\n<span id=\"cb8-2\"><a href=\"#cb8-2\" aria-hidden=\"true\"><\/a><span class=\"op\">*<\/span>phi1_wild_args_n))<\/span>\n<span id=\"cb8-3\"><a href=\"#cb8-3\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(sp.latex(res, mode<span class=\"op\">=<\/span><span class=\"st\">&#39;equation*&#39;<\/span>, itex<span class=\"op\">=<\/span><span class=\"va\">True<\/span>))<\/span><\/code><\/pre><\/div>\n<p><span class=\"math display\">\\[\\begin{equation*}\na E{\\left (n,a,b,g,z_{1},z_{2} \\right )}\n\\end{equation*}\\]<\/span><\/p>\n<p>and via <code>rewriterule<\/code>:<\/p>\n<div class=\"sourceCode\" id=\"cb9\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb9-1\"><a href=\"#cb9-1\" aria-hidden=\"true\"><\/a><span class=\"im\">from<\/span> sympy.unify.rewrite <span class=\"im\">import<\/span> 
rewriterule<\/span>\n<span id=\"cb9-2\"><a href=\"#cb9-2\" aria-hidden=\"true\"><\/a>rl <span class=\"op\">=<\/span> rewriterule(C_w <span class=\"op\">*<\/span> E_pattern,<\/span>\n<span id=\"cb9-3\"><a href=\"#cb9-3\" aria-hidden=\"true\"><\/a>                 C_w <span class=\"op\">*<\/span> E_fn(n_w, <span class=\"op\">*<\/span>phi1_wild_args_n),<\/span>\n<span id=\"cb9-4\"><a href=\"#cb9-4\" aria-hidden=\"true\"><\/a>                 phi1_wild_args_n <span class=\"op\">+<\/span> (n_w, C_w))<\/span>\n<span id=\"cb9-5\"><a href=\"#cb9-5\" aria-hidden=\"true\"><\/a>res <span class=\"op\">=<\/span> <span class=\"bu\">list<\/span>(rl(a <span class=\"op\">*<\/span> r_1))<\/span>\n<span id=\"cb9-6\"><a href=\"#cb9-6\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(sp.latex(res, mode<span class=\"op\">=<\/span><span class=\"st\">&#39;equation*&#39;<\/span>, itex<span class=\"op\">=<\/span><span class=\"va\">True<\/span>))<\/span><\/code><\/pre><\/div>\n<p><span class=\"math display\">\\[\\begin{equation*}\n\\left [ a E{\\left (n,a,b,g,z_{1},z_{2} \\right )}\\right ]\n\\end{equation*}\\]<\/span><\/p>\n<p>The advantage in using <code>rewriterule<\/code> is that multiple matches will be returned. 
If we add another <span class=\"math inline\">\\(\\Phi_1\\)<\/span> in the numerator, so there are multiple possible <span class=\"math inline\">\\(E[X^m]\\)<\/span>, we get<\/p>\n<div class=\"sourceCode\" id=\"cb10\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb10-1\"><a href=\"#cb10-1\" aria-hidden=\"true\"><\/a>phi1_4 <span class=\"op\">=<\/span> HornPhi1((a, b), (g <span class=\"op\">+<\/span> n <span class=\"op\">+<\/span> <span class=\"dv\">1<\/span>,), z_1, z_2)<\/span>\n<span id=\"cb10-2\"><a href=\"#cb10-2\" aria-hidden=\"true\"><\/a><\/span>\n<span id=\"cb10-3\"><a href=\"#cb10-3\" aria-hidden=\"true\"><\/a>res <span class=\"op\">=<\/span> <span class=\"bu\">list<\/span>(rl(a <span class=\"op\">*<\/span> r_1 <span class=\"op\">*<\/span> phi1_4))<\/span>\n<span id=\"cb10-4\"><a href=\"#cb10-4\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(sp.latex(res, mode<span class=\"op\">=<\/span><span class=\"st\">&#39;equation*&#39;<\/span>, itex<span class=\"op\">=<\/span><span class=\"va\">True<\/span>))<\/span><\/code><\/pre><\/div>\n<p><span class=\"math display\">\\[\\begin{equation*}\n\\left [ a \\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad n +\ng, \\quad z_{1}, \\quad z_{2}\\right )\\right)} E{\\left (n +\n1,a,b,g,z_{1},z_{2} \\right )}, \\quad a\n\\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad n + g, \\quad\nz_{1}, \\quad z_{2}\\right )\\right)} E{\\left (n + 1,a,b,g,z_{1},z_{2}\n\\right )}, \\quad a \\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b,\n\\quad n + g + 1, \\quad z_{1}, \\quad z_{2}\\right )\\right)} E{\\left\n(n,a,b,g,z_{1},z_{2} \\right )}, \\quad a\n\\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad n + g, \\quad\nz_{1}, \\quad z_{2}\\right )\\right)} E{\\left (n + 1,a,b,g,z_{1},z_{2}\n\\right )}, \\quad a \\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b,\n\\quad n + g, \\quad z_{1}, \\quad z_{2}\\right )\\right)} E{\\left (n +\n1,a,b,g,z_{1},z_{2} \\right 
)}, \\quad a\n\\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad n + g + 1, \\quad\nz_{1}, \\quad z_{2}\\right )\\right)} E{\\left (n,a,b,g,z_{1},z_{2} \\right\n)}\\right ]\n\\end{equation*}\\]<\/span><\/p>\n<p>FYI: the associativity of terms inside the function arguments is causing the seemingly duplicate results.<\/p>\n<p>Naive use of <code>Expr.replace<\/code> doesn\u2019t return all the matches; instead, it does something you might not expect:<\/p>\n<div class=\"sourceCode\" id=\"cb11\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb11-1\"><a href=\"#cb11-1\" aria-hidden=\"true\"><\/a>res <span class=\"op\">=<\/span> (a <span class=\"op\">*<\/span> r_1 <span class=\"op\">*<\/span> phi1_4).replace(C_w <span class=\"op\">*<\/span> E_pattern,<\/span>\n<span id=\"cb11-2\"><a href=\"#cb11-2\" aria-hidden=\"true\"><\/a>                                 C_w <span class=\"op\">*<\/span> E_fn(n_w, <span class=\"op\">*<\/span>phi1_wild_args_n))<\/span>\n<span id=\"cb11-3\"><a href=\"#cb11-3\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(sp.latex(res, mode<span class=\"op\">=<\/span><span class=\"st\">&#39;equation*&#39;<\/span>, itex<span class=\"op\">=<\/span><span class=\"va\">True<\/span>))<\/span><\/code><\/pre><\/div>\n<p><span class=\"math display\">\\[\\begin{equation*}\na E{\\left (n,a,b,g,z_{1},z_{2} \\right )} E{\\left (n +\n1,a,b,g,z_{1},z_{2} \\right )} \\operatorname{\\Phi_1}{\\left(\\left ( a,\n\\quad b, \\quad g, \\quad z_{1}, \\quad z_{2}\\right )\\right)}\n\\end{equation*}\\]<\/span><\/p>\n<p>Returning to our more complicated <code>expr<\/code>: just because we can match products doesn\u2019t mean we\u2019re finished, since we still need a good way to traverse the entire expression tree and match its sub-trees. 
More importantly, adding the multiplicative <code>Wild<\/code> term <code>C_w<\/code> is more of a hack than a direct solution, since we don\u2019t want the matched contents of <code>C_w<\/code>.<\/p>\n<p>Although <code>Expr.replace\/xreplace<\/code> will match sub-expressions, we found above that it produces some odd results. Those results persist when applied to more complicated expressions:<\/p>\n<div class=\"sourceCode\" id=\"cb12\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb12-1\"><a href=\"#cb12-1\" aria-hidden=\"true\"><\/a>res <span class=\"op\">=<\/span> expr.replace(C_w <span class=\"op\">*<\/span> E_pattern, C_w <span class=\"op\">*<\/span> E_fn(n_w,<\/span>\n<span id=\"cb12-2\"><a href=\"#cb12-2\" aria-hidden=\"true\"><\/a><span class=\"op\">*<\/span>phi1_wild_args_n))<\/span>\n<span id=\"cb12-3\"><a href=\"#cb12-3\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(sp.latex(res, mode<span class=\"op\">=<\/span><span class=\"st\">&#39;equation*&#39;<\/span>, itex<span class=\"op\">=<\/span><span class=\"va\">True<\/span>))<\/span><\/code><\/pre><\/div>\n<p><span class=\"math display\">\\[\\begin{equation*}\na E{\\left (n,a,b,g,z_{1},z_{2} \\right )} - \\frac{b}{g} E{\\left\n(n,a,b,g,z_{1},z_{2} \\right )} + \\sum_{i=0}^{n} \\left(\\frac{z_{1}\nE{\\left (n,a,b,- i + g,z_{1},z_{2} \\right )}\n\\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad - i + g, \\quad\nz_{1}, \\quad z_{2}\\right )\\right)}}{z_{2}\n\\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad g, \\quad z_{1},\n\\quad z_{2}\\right )\\right)}} - \\frac{3 E{\\left (n,a,b,- i +\ng,z_{1},z_{2} \\right )} \\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad\nb, \\quad - i + g, \\quad z_{1}, \\quad z_{2}\\right\n)\\right)}}{\\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad g,\n\\quad z_{1}, \\quad z_{2}\\right )\\right)}}\\right)\n\\end{equation*}\\]<\/span><\/p>\n<p>Again, it looks like the matching was a little too liberal and 
introduced extra <code>E<\/code> and <code>HornPhi1<\/code> terms. This is to be expected from the <code>Wild<\/code> matching in SymPy; it needs us to specify what <em>not<\/em> to match, as well. Our \u201cfix\u201d that introduced <code>C_w<\/code> is the exact source of the problem, but we can tell it not to match <code>HornPhi1<\/code> terms and get better results:<\/p>\n<div class=\"sourceCode\" id=\"cb13\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb13-1\"><a href=\"#cb13-1\" aria-hidden=\"true\"><\/a>C_w <span class=\"op\">=<\/span> sp.Wild(<span class=\"st\">&#39;C_w&#39;<\/span>, exclude<span class=\"op\">=<\/span>[sp.S.Zero, HornPhi1])<\/span>\n<span id=\"cb13-2\"><a href=\"#cb13-2\" aria-hidden=\"true\"><\/a>res <span class=\"op\">=<\/span> expr.replace(C_w <span class=\"op\">*<\/span> E_pattern, C_w <span class=\"op\">*<\/span> E_fn(n_w,<\/span>\n<span id=\"cb13-3\"><a href=\"#cb13-3\" aria-hidden=\"true\"><\/a><span class=\"op\">*<\/span>phi1_wild_args_n))<\/span>\n<span id=\"cb13-4\"><a href=\"#cb13-4\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(sp.latex(res, mode<span class=\"op\">=<\/span><span class=\"st\">&#39;equation*&#39;<\/span>, itex<span class=\"op\">=<\/span><span class=\"va\">True<\/span>))<\/span><\/code><\/pre><\/div>\n<p><span class=\"math display\">\\[\\begin{equation*}\na E{\\left (n,a,b,g,z_{1},z_{2} \\right )} - \\frac{b}{g} E{\\left\n(n,a,b,g,z_{1},z_{2} \\right )} + \\sum_{i=0}^{n} \\left(\\frac{z_{1}\n\\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad - i + n + g,\n\\quad z_{1}, \\quad z_{2}\\right )\\right)}}{z_{2}\n\\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad g, \\quad z_{1},\n\\quad z_{2}\\right )\\right)}} - \\frac{3\n\\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad - i + n + g,\n\\quad z_{1}, \\quad z_{2}\\right\n)\\right)}}{\\operatorname{\\Phi_1}{\\left(\\left ( a, \\quad b, \\quad g,\n\\quad z_{1}, \\quad z_{2}\\right 
)\\right)}}\\right)\n\\end{equation*}\\]<\/span><\/p>\n<p>We\u2019ve stopped it from introducing those superfluous <code>E<\/code> terms, but we\u2019re still not getting replacements for the <code>HornPhi1<\/code> ratios in the sums. Let\u2019s single out those terms and see what\u2019s going on:<\/p>\n<div class=\"sourceCode\" id=\"cb14\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb14-1\"><a href=\"#cb14-1\" aria-hidden=\"true\"><\/a>res <span class=\"op\">=<\/span> r_2.find(C_w <span class=\"op\">*<\/span> E_pattern)<\/span>\n<span id=\"cb14-2\"><a href=\"#cb14-2\" aria-hidden=\"true\"><\/a><span class=\"bu\">print<\/span>(sp.latex(res, mode<span class=\"op\">=<\/span><span class=\"st\">&#39;equation*&#39;<\/span>, itex<span class=\"op\">=<\/span><span class=\"va\">True<\/span>))<\/span><\/code><\/pre><\/div>\n<p><span class=\"math display\">\\[\\begin{equation*}\n\\left\\{\\right\\}\n\\end{equation*}\\]<\/span><\/p>\n<p>The constrained integer <code>Wild<\/code> term, <code>n_w<\/code>, probably isn\u2019t matching. 
Given the form of our pattern, <code>n_w<\/code> should match <code>n - i<\/code>, but SymPy can\u2019t establish that <code>n - i<\/code> is strictly positive, as required:<\/p>\n<div class=\"sourceCode\" id=\"cb15\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb15-1\"><a href=\"#cb15-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> (n <span class=\"op\">-<\/span> i).is_positive <span class=\"op\">==<\/span> <span class=\"va\">True<\/span><\/span>\n<span id=\"cb15-2\"><a href=\"#cb15-2\" aria-hidden=\"true\"><\/a><span class=\"va\">False<\/span><\/span>\n<span id=\"cb15-3\"><a href=\"#cb15-3\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> sp.ask(sp.Q.positive(n <span class=\"op\">-<\/span> i)) <span class=\"op\">==<\/span> <span class=\"va\">True<\/span><\/span>\n<span id=\"cb15-4\"><a href=\"#cb15-4\" aria-hidden=\"true\"><\/a><span class=\"va\">False<\/span><\/span><\/code><\/pre><\/div>\n<p>Since <span class=\"math inline\">\\(n &gt; 0\\)<\/span> and <span class=\"math inline\">\\(i \\geq 0\\)<\/span>, the only missing piece is the fact that <span class=\"math inline\">\\(n &gt; i\\)<\/span>. The most relevant SymPy mechanism for reasoning with this information is the <a href=\"http:\/\/docs.sympy.org\/dev\/modules\/assumptions\/index.html\"><code>sympy.assumptions<\/code><\/a> interface. We could add and retrieve the assumption <code>sympy.Q.is_true(n &gt; i)<\/code> via <code>sympy.assumptions.global_assumptions<\/code>, or perform these operations inside a Python <code>with<\/code> block, etc. 
This context management, via <code>sympy.assumptions.assume.AssumptionsContext<\/code>, would have to be performed manually, since I am not aware of any such mechanism offered by <code>Sum<\/code> and\/or <code>Basic.replace<\/code>.<\/p>\n<p>Unfortunately, while these ideas sound good, they aren\u2019t implemented:<\/p>\n<div class=\"sourceCode\" id=\"cb16\"><pre class=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb16-1\"><a href=\"#cb16-1\" aria-hidden=\"true\"><\/a><span class=\"op\">&gt;&gt;&gt;<\/span> sp.ask(sp.Q.positive(n <span class=\"op\">-<\/span> i), sp.Q.is_true(n <span class=\"op\">&gt;<\/span> i)) <span class=\"op\">==<\/span> <span class=\"va\">True<\/span><\/span>\n<span id=\"cb16-2\"><a href=\"#cb16-2\" aria-hidden=\"true\"><\/a><span class=\"va\">False<\/span><\/span><\/code><\/pre><\/div>\n<p>See the documentation for <code>sympy.assumptions.ask.ask<\/code>; it explicitly states that inequalities aren\u2019t handled yet.<\/p>\n<p>We could probably rework <code>sympy.Q.is_true(n &gt; i)<\/code> by hand into <code>sympy.Q.is_true(n - i &gt; 0)<\/code>, which is of course equivalent to <code>sympy.Q.positive(n - i)<\/code>: exactly the fact we need.<\/p>\n<p>If one were to provide this functionality, there\u2019s still the question of how the relevant <code>AssumptionsContext<\/code>s would be created and passed around\/nested during the subexpression replacements. There is no apparent means of adding this sort of functionality through the <code>Basic.replace<\/code> interface, so this path looks less appealing. However, nesting <code>with<\/code> blocks within the strategies in <code>sympy.strategies<\/code> does seem quite possible. 
For example, in <code>sympy.strategies.traverse.sall<\/code>, one could wrap the <code>return<\/code> statement after the <code>map(rule, ...)<\/code> call in a <code>with sympy.assuming(...):<\/code> block that contains the assumptions for any variables arising as, say, the index of a <code>Sum<\/code>, as in our case. In this scenario, code in the subexpressions would be able to ask questions like <code>sympy.Q.is_true(n &gt; i)<\/code> without altering the global assumptions context or the objects involved.<\/p>\n<p>Anyway, that\u2019s all I wanted to cover here. Perhaps later I\u2019ll post a hack for the assumptions approach, but, at the very least, I\u2019ll try to follow up with a more direct solution that uses <code>sympy.strategies<\/code>.<\/p>\n<\/body>\n<\/html>\n","category":{"@attributes":{"term":"articles"}}}]}