
PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX

Posted January 2022




Probabilistic graphical models (PGMs) aim to provide a compact and complete statistical description of data, and include many popular models like Bayesian networks, Markov random fields, and energy-based models whose energy is a sum of partial energy functions. PGMs have enjoyed successful applications to problems in a wide range of fields such as computer vision, natural language processing and biology. They are naturally equipped to reason under uncertainty, and are able to handle arbitrary inference queries (in which we decide which variables to predict at test time) and missing inputs. Their generality and interpretability make them an indispensable tool for probabilistic modeling.

Despite the desirable properties of PGMs, most AI models today combine a somewhat simplistic form of probabilistic modeling (such as a cross-entropy loss or a Gaussian likelihood) with deep neural networks (DNNs). Often this approach is followed out of convenience: training boils down to minimizing the loss function of the DNN on the training set with respect to its weights, using stochastic gradient descent. This allows training massive architectures with little structure on even more massive training sets. In theory, any DNN architecture that one could imagine can be trained right away using this technique, as long as its loss is a differentiable function of the weights. Two decades ago this was done by manually computing the gradients for each new architecture and then coding them. What has enabled fast iteration in the trial-and-error loop of creating new useful architectures are software packages for automatic differentiation, such as TensorFlow and PyTorch. These tools eliminate the time-consuming and error-prone process of manual gradient computation, allowing researchers to play around with architectures while keeping their gradients instantly correct.

However, this PGM-light, DNN-heavy approach has limitations. Typically, the PGM advantages that we mentioned above are not available: the model has trouble reasoning about uncertainty, and cannot accept arbitrary inference queries or missing inputs. We believe that one of the reasons why current AI models use for the most part simple probabilistic modeling and relegate the model complexity to deterministic functions in the form of DNNs is the lack of automated tooling for probabilistic inference, particularly in the discrete domain. Existing tools in probabilistic programming usually do not play well with undirected PGMs and struggle when the random variables are discrete. Existing tools for discrete inference have problems with quality, efficiency, and scalability. In this blog post we introduce a new tool to fill this gap.

1 Introducing PGMax!

Today, we’re excited to announce PGMax, a new open-source Python package designed for the express purpose of flexibly specifying arbitrary discrete PGMs using standard factor graph representations, and automated derivation of efficient and scalable loopy belief propagation (LBP) for both marginal and maximum-a-posteriori (MAP) inference.

At its core, PGMax is designed to have the following key features:

  • Easy specification of general factor graphs: PGMax supports general factor graphs with arbitrary graph topology, factor definitions, and discrete variables with a varying number of states. PGMax has built-in primitives for common repetitive variable and factor structures (e.g. grid of variables), and enables users to specify complex discrete PGMs with ease.
  • Efficient, scalable loopy belief propagation (LBP) implementation: PGMax adopts an LBP implementation that has been extensively tested in our recent works [1] [2] [3] [4] with good performance on a wide range of discrete PGMs. PGMax implements LBP in JAX, and is able to leverage just-in-time compilation and modern accelerators like GPUs/TPUs. This results in an inference engine that is several orders-of-magnitude more efficient than existing alternative Python packages, and enables PGMax to practically scale to discrete PGMs with a large number of variables and factors.
  • Seamless interaction with the JAX ecosystem: PGMax implements LBP as pure functions with no side effects. This functional design means users can easily apply JAX transformations like batching and automatic differentiation, and additionally allows PGMax to seamlessly interact with the rapidly growing JAX ecosystem. This opens up exciting new possibilities like gradient-based discrete PGM learning, as we show in Section 5.

PGMax is a specialized probabilistic programming system, which provides an interface for specifying probabilistic models (discrete PGMs as factor graphs) and automatically derives inference methods (LBP). However, PGMax differs from existing probabilistic programming systems as those systems typically focus on probabilistic models with continuous variables, with Markov Chain Monte Carlo methods like Hamiltonian Monte Carlo as the core inference engine, and have limited support for undirected PGMs and discrete variables.

While there are several open-source Python packages that target discrete PGMs (e.g. pgmpy, pomegranate, py-factorgraph, and fglib, to name a few), many of these packages are limited in the PGMs they can support, and they additionally suffer from the aforementioned issues with quality, speed, and scalability, as shown in the companion paper [5].

2 Background: factor graphs and Restricted Boltzmann Machines

PGMax specifies discrete PGMs using the standard factor graph representation: a bipartite graph containing variable nodes and factor nodes, with edges connecting the two types of nodes. In a factor graph, the variable nodes represent the random variables in the corresponding PGM, while the factor nodes specify the underlying probabilistic model.

As a concrete example, consider the Restricted Boltzmann Machine (RBM), a well-known and widely used PGM for learning probability distributions over binary data. A standard RBM consists of \(n_h\) binary hidden variables \(h \in \{0, 1\}^{n_h}\) and \(n_v\) binary visible variables \(v \in \{0, 1\}^{n_v}\), and is usually specified using an energy function \begin{equation*} E(h, v) = -b_h^T h - b_v^T v - h^T W v \end{equation*} where \(b_h \in \mathbb{R}^{n_h}\) and \(b_v \in \mathbb{R}^{n_v}\) are the biases on the hidden and visible variables, respectively, and \(W \in \mathbb{R}^{n_h \times n_v}\) is a weight matrix. Given the energy function \(E\), the joint probability distribution over the hidden and visible variables is defined as \begin{equation*} P(h, v) = \frac{1}{Z} e^{-E(h, v)} \end{equation*} where \(Z\) is the so-called partition function.
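
As a quick illustration (a minimal numpy sketch with made-up parameter values, not part of PGMax), the energy and the unnormalized probability of a particular configuration can be computed as:

    import numpy as np

    n_h, n_v = 3, 5
    rng = np.random.default_rng(0)
    bh = rng.normal(size=n_h)        # hidden biases b_h
    bv = rng.normal(size=n_v)        # visible biases b_v
    W = rng.normal(size=(n_h, n_v))  # weight matrix

    h = np.array([1, 0, 1])          # an example hidden configuration
    v = np.array([0, 1, 1, 0, 1])    # an example visible configuration

    energy = -bh @ h - bv @ v - h @ W @ v
    unnormalized_prob = np.exp(-energy)  # equals Z * P(h, v)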

In Figure 1, we show the factor graph topology for an RBM with \(n_h = 3\) hidden variables and \(n_v = 5\) visible variables. The circles represent variable nodes, while the shaded squares represent factor nodes. The RBM factor graph has one unary factor for each of the hidden and visible variables (representing the biases \(b_h, b_v\)), and a pairwise factor for each pair of hidden and visible variables (representing the weight matrix \(W\)). The factor graph provides a visual illustration of the factorized definition of the energy function \(E\) as the sum of local terms, each of which involves only a small number of random variables, and can be used to easily derive the message updates for inference with LBP.

 Figure 1: Factor graph topology for an RBM with \(n_h = 3\) hidden variables and \(n_v = 5\) visible variables
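
For completeness, the standard sum-product message updates that LBP iterates on a factor graph are \begin{equation*} m_{x \to f}(x) = \prod_{g \in N(x) \setminus \{f\}} m_{g \to x}(x), \qquad m_{f \to x}(x) = \sum_{\mathbf{x}_f \setminus x} \psi_f(\mathbf{x}_f) \prod_{y \in N(f) \setminus \{x\}} m_{y \to f}(y) \end{equation*} where \(N(\cdot)\) denotes the neighbors of a node and \(\psi_f\) is the potential of factor \(f\); max-product LBP replaces the sum over \(\mathbf{x}_f \setminus x\) with a maximization (or, working in log space, replaces the products with sums and the sum with a max). PGMax derives and runs these updates automatically from the factor graph specification.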

To fully define the RBM, we need to additionally define the factors in the factor graph using \(b_h, b_v, W\). A factor involves its connected variables in the factor graph, and can be defined by specifying a list of valid joint configurations of the involved variables and their corresponding log potentials. In Table 1 and Table 2 we illustrate the factor definitions for a given RBM.

\(v_3\) | Log potentials
0 | 0
1 | \(b_{v3}\)

Table 1: Definition of the unary factor involving \(v_3\) in Figure 1

\(h_1\) | \(v_2\) | Log potentials
0 | 0 | 0
0 | 1 | 0
1 | 0 | 0
1 | 1 | \(W_{12}\)

Table 2: Definition of the pairwise factor involving \(h_1, v_2\) in Figure 1

PGMax works in log space for numerical stability, and defines factors using log potentials, which correspond to the negative energies in the typical energy-based formulation. In factor definitions like Table 2, we explicitly enumerate all possible joint configurations, but PGMax more generally supports enumerating only a sparse subset of valid joint configurations, with the assumption that all other joint configurations have log potentials of \(-\infty\) (i.e. these other joint configurations are invalid/disallowed). As we demonstrate and discuss in Section 4, this ability to enumerate a sparse subset of valid joint configurations allows PGMax to flexibly and efficiently support non-standard, structured factors (including higher-order factors), something existing alternatives struggle to do.
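
As a hypothetical illustration of this sparse enumeration (the constraint and arrays below are made up for this example and are not tied to the RBM), a factor enforcing that exactly one of three binary variables is on only needs to list 3 of the \(2^3 = 8\) joint configurations:

    import numpy as np

    # Valid configurations of a hypothetical "exactly one of three is on" factor.
    # Every joint configuration not listed here implicitly gets log potential -inf.
    factor_configs = np.array(
        [
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
        ]
    )
    log_potentials = np.zeros(3)  # the three valid configurations are equally likely

Arrays like these are passed as the factor_configs and log_potentials arguments when adding factors in the next section.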

3 Tutorial: implementing LBP inference for RBMs with PGMax

Following the above review on factor graphs and RBMs, in this section we demonstrate how to use PGMax to implement LBP inference for RBMs. See here for an interactive colab notebook containing the following code snippets.

We start by making some necessary imports.


          import functools
          import itertools
          import jax
          import matplotlib.pyplot as plt
          import numpy as np
          
          from pgmax.fg import graph, groups
        

The pgmax.fg.graph module contains core classes for specifying factor graphs and implementing LBP, while the pgmax.fg.groups module contains classes for specifying groups of variables/factors.

Assuming we have 1D numpy arrays bh and bv for the RBM biases and a 2D numpy array W for the RBM weight matrix, we can initialize the factor graph for the RBM with


          hidden_variables = groups.NDVariableArray(num_states=2, shape=bh.shape)
          visible_variables = groups.NDVariableArray(num_states=2, shape=bv.shape)
          fg = graph.FactorGraph(
              variables=dict(hidden=hidden_variables, visible=visible_variables),
          )
        

NDVariableArray is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with numpy.ndarray. The FactorGraph fg is initialized with a set of variables, which can be either a single VariableGroup (e.g. an NDVariableArray), or a list/dictionary of VariableGroups. Once initialized, the set of variables in fg is fixed and cannot be changed.
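
For instance, a factor graph over a single variable group can be constructed without wrapping the group in a dictionary (a minimal sketch, with a made-up group of 10 binary variables):

    single_group = groups.NDVariableArray(num_states=2, shape=(10,))
    fg_single = graph.FactorGraph(variables=single_group)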

After initialization, fg does not have any factors. PGMax supports imperatively adding factors to a FactorGraph. We can add the unary and pairwise factors one at a time to fg by

    # Add unary factors
    for ii in range(bh.shape[0]):
        fg.add_factor(
            variable_names=[("hidden", ii)],
            factor_configs=np.arange(2)[:, None],
            log_potentials=np.array([0, bh[ii]]),
        )
    
    for jj in range(bv.shape[0]):
        fg.add_factor(
            variable_names=[("visible", jj)],
            factor_configs=np.arange(2)[:, None],
            log_potentials=np.array([0, bv[jj]]),
        )
    
    
    # Add pairwise factors
    factor_configs = np.array(list(itertools.product(np.arange(2), repeat=2)))
    for ii in range(bh.shape[0]):
        for jj in range(bv.shape[0]):
            fg.add_factor(
                variable_names=[("hidden", ii), ("visible", jj)],
                factor_configs=factor_configs,
                log_potentials=np.array([0, 0, 0, W[ii, jj]]),
            )

fg.add_factor takes 3 arguments, variable_names, factor_configs and log_potentials, and is a literal translation of the factor definitions shown in Table 1 and Table 2. variable_names is a list containing the names of the involved variables. In this example, since we construct fg with variables dict(hidden=hidden_variables, visible=visible_variables), where hidden_variables and visible_variables are NDVariableArrays, we can refer to the iith hidden variable as ("hidden", ii) and the jjth visible variable as ("visible", jj). In general, PGMax implements an intuitive scheme for automatically assigning names to the variables in a FactorGraph.

PGMax also implements FactorGroups for adding groups of similar factors. An alternative way of adding the above factors using FactorGroups is by calling fg.add_factor_group

    # Add unary factors
    fg.add_factor_group(
        factory=groups.EnumerationFactorGroup,
        variable_names_for_factors=[[("hidden", ii)] for ii in range(bh.shape[0])],
        factor_configs=np.arange(2)[:, None],
        log_potentials=np.stack([np.zeros_like(bh), bh], axis=1),
    )
    fg.add_factor_group(
        factory=groups.EnumerationFactorGroup,
        variable_names_for_factors=[[("visible", jj)] for jj in range(bv.shape[0])],
        factor_configs=np.arange(2)[:, None],
        log_potentials=np.stack([np.zeros_like(bv), bv], axis=1),
    )
    
    # Add pairwise factors
    log_potential_matrix = np.zeros(W.shape + (2, 2)).reshape((-1, 2, 2))
    log_potential_matrix[:, 1, 1] = W.ravel()
    fg.add_factor_group(
        factory=groups.PairwiseFactorGroup,
        variable_names_for_factors=[
            [("hidden", ii), ("visible", jj)]
            for ii in range(bh.shape[0])
            for jj in range(bv.shape[0])
        ],
        log_potential_matrix=log_potential_matrix,
    )

This makes use of EnumerationFactorGroup and PairwiseFactorGroup, two FactorGroups implemented in the pgmax.fg.groups module. Although FactorGroups carry more structural information, at the moment they expand into individual factors using regular Python for loops. We plan to make full use of this additional structural information to speed up model construction in future versions.

Once we have added the factors, we can run max-product LBP and get MAP decoding by

    run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, num_iters=100, temperature=0.0)
    bp_arrays = run_bp(damping=0.5)
    beliefs = get_beliefs(bp_arrays)
    map_states = graph.decode_map_states(beliefs)

and run sum-product LBP and get estimated marginals by

    run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, num_iters=100, temperature=1.0)
    bp_arrays = run_bp(damping=0.5)
    beliefs = get_beliefs(bp_arrays)
    marginals = graph.get_marginals(beliefs)

More generally, PGMax implements LBP with temperature, with temperature=0.0 and temperature=1.0 corresponding to the commonly used max/sum-product LBP respectively.
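
For instance, running LBP at an intermediate temperature follows exactly the same pattern as above (only the temperature argument changes; the value 0.5 here is just for illustration):

    run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, num_iters=100, temperature=0.5)
    bp_arrays = run_bp(damping=0.5)
    beliefs = get_beliefs(bp_arrays)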

To see PGMax in action, we demonstrate how we can easily implement perturb-and-max-product (PMP) [3] sampling from an RBM trained on MNIST digits using PGMax. PMP is a recently proposed method for approximately sampling from a PGM by computing the maximum-a-posteriori (MAP) configuration (using max-product LBP) of a perturbed version of the model. We first load the \(b_h, b_v\) and \(W\) for an RBM trained in Sec. 5.5 of [3] on MNIST digits (the rbm_mnist.npz file is available for download here):

    params = np.load("rbm_mnist.npz")
    bh = params["bh"]
    bv = params["bv"]
    W = params["W"]

Once we load the trained RBM parameters, we can construct the RBM factor graph as shown above. PMP perturbs the model with Gumbel unary potentials, and draws a sample from the RBM as the MAP decoding from running max-product LBP on the perturbed model

    run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, num_iters=100, temperature=0.0)
    bp_arrays = run_bp(
        evidence_updates={
            "hidden": np.random.gumbel(size=(bh.shape[0], 2)),
            "visible": np.random.gumbel(size=(bv.shape[0], 2)),
        },
        damping=0.5,
    )
    beliefs = get_beliefs(bp_arrays)
    map_states = graph.decode_map_states(beliefs)

Here we use the evidence_updates argument of run_bp to perturb the model with Gumbel unary potentials. In general, evidence_updates can be used to incorporate evidence in the form of externally applied unary potentials in PGM inference.
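
For instance, here is a sketch of how one could use evidence_updates to (softly) clamp the visible variables of the RBM to an observed binary image v_obs while perturbing only the hidden variables (the clamping-by-large-unary-potential trick and the value 100.0 are our own illustrative choices, not part of the PGMax API):

    v_obs = np.zeros(bv.shape[0], dtype=int)  # a hypothetical observed visible configuration
    visible_evidence = np.zeros((bv.shape[0], 2))
    visible_evidence[np.arange(bv.shape[0]), v_obs] = 100.0  # strong evidence for the observed states

    bp_arrays = run_bp(
        evidence_updates={
            "hidden": np.random.gumbel(size=(bh.shape[0], 2)),
            "visible": visible_evidence,
        },
        damping=0.5,
    )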

Visualizing the MAP decoding (Figure 2), we see that we have sampled an MNIST digit!

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(map_states["visible"].copy().reshape((28, 28)), cmap='gray')
    ax.axis('off')

 

 Figure 2: A sampled MNIST digit using PMP

 

Despite the many past attempts at implementing factor graphs and LBP inference in Python (e.g. pyfac, py-factorgraph, factorflow, fglib, sumproduct, pgmpy, pomegranate to name a few), there exists no well-established open-source Python package with efficient and scalable LBP implementations for general factor graphs. To put things into perspective with a concrete example: the above RBM contains 500 hidden and 784 visible variables (representing a \(28\times 28\) MNIST image), i.e. a total of 1284 variables, and \(500 + 784 + 500 \times 784 = 393284\) factors in the factor graph. PGMax can handle the factor graph building and LBP inference with ease, yet as we demonstrate in our companion paper [5], the size of this RBM is already far beyond the capabilities (see footnote 1) of two relatively active and well-established alternative Python packages, pgmpy and pomegranate. Additional experiments on smaller RBMs show that, in addition to being qualitatively more efficient and scalable, PGMax’s LBP implementation can obtain higher quality inference results (e.g. lower energy MAP decodings) than existing alternatives.

4 Specifying complex factor graphs

The RBM example in Section 3 is fairly simple, with binary variables and standard unary/pairwise factors. To showcase PGMax’s support for complex factor graphs, we next demonstrate how we can easily implement max-product LBP inference for the Recursive Cortical Network (RCN) [1] in PGMax. RCN is a neuroscience-inspired probabilistic generative model for computer vision that can learn with very little training data and handle recognition, segmentation, and reasoning in a unified manner. RCN is formulated as a PGM, and uses max-product LBP for inference. Our example in this section adapts the two-level RCN model for MNIST classification from the RCN reference implementation. A self-contained interactive colab notebook implementing the example in this section is available here.

The two-level RCN model consists of a set of templates, one for each training image (an upsampled \(200\times 200\) MNIST image). In Figure 3 we show an example training image and its corresponding template. An RCN template can be represented as a graph. A node represents a salient feature in the training image, and corresponds to a variable in the PGM. Each node is allowed to move within a \(25\times 25\) square region centered at its original location, giving rise to a discrete variable with \(625\) states. An edge represents a constraint on how much the relative distance of two connected nodes can vary with respect to their original relative distance, and corresponds to a factor in the PGM. An edge with an associated perturb radius \(r\) means that the relative distance of the two connected nodes cannot deviate from their original relative distance by more than \(r\). Classification is achieved via a template matching process, where we run max-product LBP to identify the template with the highest score.

 

 Figure 3: An example template in a two-level RCN model

 

The factors in an RCN template are highly non-standard: depending on the perturb radius \(r\), only a very small number of joint configurations out of the \(625^2=390625\) possibilities are valid. Most of the existing alternative packages (except the now unmaintained OpenGM 2) require specifying all \(390625\) log potentials in order to define a factor, which quickly becomes infeasible. However, PGMax naturally supports such non-standard factors. Given two connected nodes NODE1 and NODE2, we can add the corresponding factor to the RCN factor graph fg by

    fg.add_factor(
        variable_names=[NODE1, NODE2], 
        factor_configs=VALID_CONFIGS
    )

Here VALID_CONFIGS is a numpy array of shape (n_valid_configs, 2) (with n_valid_configs \(\ll 390625\)) which explicitly enumerates all valid joint configurations of NODE1 and NODE2; log_potentials defaults to an all-zero array when not explicitly specified.
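
As a rough sketch of how such an array could be built (illustrative only: we index each node's 625 states by its displacement on the \(25\times 25\) grid and, for simplicity, read the constraint as bounding each coordinate of the change in relative displacement by \(r\)):

    import itertools
    import numpy as np

    M = 25  # each node moves within a 25 x 25 region, i.e. 625 states per variable
    r = 2   # hypothetical perturb radius for this edge

    # State s of a node corresponds to displacement (s // M, s % M) within its region.
    displacements = np.array(list(itertools.product(np.arange(M), repeat=2)))

    valid_configs = []
    for s1, s2 in itertools.product(range(M * M), repeat=2):
        # Change in the relative position of the two nodes w.r.t. their original layout
        delta = displacements[s1] - displacements[s2]
        if np.all(np.abs(delta) <= r):
            valid_configs.append([s1, s2])

    VALID_CONFIGS = np.array(valid_configs)  # shape (n_valid_configs, 2), far fewer than 390625 rows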

We refer the interested readers to the self-contained interactive colab notebook to see PGMax in action. Using an RCN model trained on 20 images, we achieve a classification accuracy of 80% on a (very small, for-illustration-only) test set with 20 images, with easily interpretable MAP decodings (see Figure 4).

 

 Figure 4: Visualizations of MAP decodings from max-product LBP inference with RCN

 

Note that in addition to supporting non-standard pairwise factors like those in RCN, PGMax’s explicit enumeration of valid configurations also makes implementation of many structured higher-order factors (e.g. those used in the recently proposed Markov Attention Models [4]) feasible. Moreover, PGMax’s fully flat LBP implementation [5] allows it to support factor graphs with arbitrary topology and discrete variables of varying sizes without compromising speed or memory usage. These are all things existing alternative packages struggle to do.

5 Exciting new possibilities from implementing LBP in JAX

PGMax adopts a functional interface for implementing LBP: running LBP in PGMax starts with

    run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, num_iters=NUM_ITERS, temperature=T)

where run_bp and get_beliefs are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like jit/vmap/grad, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see here and here). In what follows we demonstrate how this opens up some exciting new possibilities.

5.1 Processing a batch of samples/models with jax.vmap

jax.vmap is a convenient transformation for automatically vectorizing functions. Since we implement run_bp/get_beliefs as pure functions, we can apply jax.vmap to them to process a batch of samples/models in parallel. As an example, consider the PGMax implementation of PMP sampling from the RBM trained on MNIST images in Section 3. Instead of drawing one sample at a time

    bp_arrays = run_bp(
        evidence_updates={
            "hidden": np.random.gumbel(size=(bh.shape[0], 2)),
            "visible": np.random.gumbel(size=(bv.shape[0], 2)),
        },
        damping=0.5,
    )
    beliefs = get_beliefs(bp_arrays)
    map_states = graph.decode_map_states(beliefs)

we can draw a batch of samples in parallel by transforming run_bp/get_beliefs with jax.vmap

    n_samples = 10
    bp_arrays = jax.vmap(functools.partial(run_bp, damping=0.5), in_axes=0, out_axes=0)(
        evidence_updates={
            "hidden": np.random.gumbel(size=(n_samples, bh.shape[0], 2)),
            "visible": np.random.gumbel(size=(n_samples, bv.shape[0], 2)),
        },
    )
    beliefs = jax.vmap(get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
    map_states = graph.decode_map_states(beliefs)

Visualizing the MAP decodings (Figure 5), we see that we have sampled 10 MNIST digits in parallel!

    fig, ax = plt.subplots(2, 5, figsize=(20, 8))
    for ii in range(10):
        ax[np.unravel_index(ii, (2, 5))].imshow(
            map_states["visible"][ii].copy().reshape((28, 28)), cmap='gray'
        )
        ax[np.unravel_index(ii, (2, 5))].axis("off")
    
    fig.tight_layout()

 

 Figure 5: Sampling multiple MNIST digits in parallel using jax.vmap

 

The above is only an example of the possibilities that implementing run_bp as a pure function opens up. To give some additional examples, we can vectorize over the log_potentials_updates argument of run_bp to process a batch of models with the same structure but different log potentials in parallel, or even apply jax.pmap to effortlessly scale LBP inference to multiple GPUs/TPUs!
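
For example, a sketch of the first idea, written by analogy with the evidence_updates example above (the factor group names and shapes are taken from the GMRF example in Section 5.2, and the batched-update behavior is assumed here rather than verified):

    # Hypothetical sketch: run LBP on a batch of 10 models that share the same
    # structure but have different log potentials, vmapping over the leading axis.
    n_models = 10
    batched_log_potentials = {
        name: np.random.randn(n_models, variable_size, variable_size)
        for name in ["top_down", "left_right", "diagonal0", "diagonal1"]
    }
    bp_arrays = jax.vmap(functools.partial(run_bp, damping=0.5), in_axes=0, out_axes=0)(
        log_potentials_updates=batched_log_potentials,
    )
    beliefs = jax.vmap(get_beliefs, in_axes=0, out_axes=0)(bp_arrays)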

5.2 End-to-end differentiable LBP

The operations involved in LBP are all differentiable, and PGMax’s implementation of LBP in JAX means we can in fact use LBP as part of a larger end-to-end differentiable system, and easily differentiate through the whole LBP inference process.

To showcase this exciting new possibility, in this section we demonstrate how we can use PGMax to train a Grid Markov Random Field (GMRF) on the border ownership dataset using query training, originally proposed in [2]. Roughly speaking, the task is to learn the parameters of the involved PGM (the GMRF) so that running LBP on the learned GMRF results in a specific kind of image denoising. Here we focus on how we can leverage PGMax’s end-to-end differentiable LBP implementation to easily learn the parameters with gradient-based training, and refer the interested readers to the original paper [2] and the self-contained interactive colab notebook for more details on the dataset, the model and the learning approach.

Given the GMRF factor graph fg, we can generate functions for running sum-product LBP inference

    run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, num_iters=15, temperature=1.0)

Implementing gradient-based learning is as simple as setting up the loss function


        @jax.jit
        def loss(noisy_image, target_image, log_potentials):
            evidence = jnp.log(jnp.where(noisy_image[..., None] == 0, p_contour, 1 - p_contour))
            target = prototype_targets[target_image]
            marginals = graph.get_marginals(
                get_beliefs(
                    run_bp(
                        evidence_updates={None: evidence},
                        log_potentials_updates=log_potentials,
                        damping=0.0,
                    )
                )
            )
            logp = jnp.mean(jnp.log(jnp.sum(target * marginals, axis=-1)))
            return -logp
        
        @jax.jit
        def batch_loss(noisy_images, target_images, log_potentials):
            return jnp.mean(
                jax.vmap(loss, in_axes=(0, 0, None), out_axes=0)(
                    noisy_images, target_images, log_potentials
                )
            )
        
        value_and_grad = jax.jit(jax.value_and_grad(batch_loss, argnums=2))
      

initializing the optimizer


      @jax.jit
      def update(step, batch_noisy_images, batch_target_images, opt_state):
          value, grad = value_and_grad(
              batch_noisy_images, batch_target_images, get_params(opt_state)
          )
          opt_state = opt_update(step, grad, opt_state)
          return value, opt_state
  

initializing the log_potentials we want to learn


      opt_state = init_fun(
          {
              "top_down": np.random.randn(variable_size, variable_size),
              "left_right": np.random.randn(variable_size, variable_size),
              "diagonal0": np.random.randn(variable_size, variable_size),
              "diagonal1": np.random.randn(variable_size, variable_size),
          }
      )
    

and going through the actual optimization process


      for epoch in range(n_epochs):
          indices = np.random.permutation(noisy_images_train.shape[0])
          for idx in range(n_batches):
              batch_indices = indices[idx * batch_size : (idx + 1) * batch_size]
              batch_noisy_images, batch_target_images = (
                  noisy_images_train[batch_indices],
                  target_images_train[batch_indices],
              )
              step = epoch * n_batches + idx
              value, opt_state = update(
                  step, batch_noisy_images, batch_target_images, opt_state
              )
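
A note on the snippets above: the update function relies on an optimizer triple init_fun, opt_update, get_params that is set up in the notebook. A minimal sketch of one way to obtain it, assuming the Adam optimizer from jax.example_libraries.optimizers (the step size is a made-up value):

      from jax.example_libraries import optimizers

      # init_fun builds the optimizer state, opt_update applies one update step,
      # and get_params extracts the current parameters from the optimizer state.
      init_fun, opt_update, get_params = optimizers.adam(step_size=1e-3)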
    

Figure 6 visualizes image denoising on the border ownership dataset using sum-product LBP inference on the trained GMRF. Notice how we can flexibly apply (and compose) various JAX transformations (e.g. jit, vmap, grad, value_and_grad) and how we use sum-product LBP inference (implemented in run_bp) as a regular component in an end-to-end differentiable system.

 

 Figure 6: Image denoising on the border ownership dataset using sum-product LBP inference on the trained GMRF

 

6 Conclusion

We’ve developed PGMax: a new, open-source framework for specifying and performing inference on discrete PGMs. We designed PGMax with a focus on efficiency and usability, enabling users to easily specify complex models with a large number of variables and factors and perform inference in a practical amount of time. Moreover, PGMax’s tight integration with JAX makes it extremely extensible: users can leverage both advanced features within JAX and features from other libraries in the JAX ecosystem to develop new features and capabilities for PGMax.

PGMax is the product of many months of active development and iteration. We’ve been using it internally during that time and have already seen significant utility for rapidly prototyping new PGMs for novel research projects. Given how much open-source frameworks have accelerated research in related fields like deep learning, we’re extremely excited by the potential for PGMax to enable and accelerate research that uses PGMs. So please do consider using, and perhaps even contributing to, our framework: we can’t wait to see what you’ll build!

Footnote 1: A key challenge in implementing efficient and scalable LBP inference for general factor graphs in Python lies in the design and manipulation of Python data structures containing LBP messages, for factor graphs with potentially complicated topology, factor definitions, and discrete variables with varying numbers of states. We refer interested readers to our companion paper [5] for a detailed description of our approach for addressing this challenge and more discussion of related work.
References
