# PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX

Posted January 2022

## Table of contents

1. Introducing PGMax!
2. Background: factor graphs and Restricted Boltzmann Machines
3. Tutorial: implementing LBP inference for RBMs with PGMax
4. Specifying complex factor graphs
5. Exciting new possibilities from implementing LBP in JAX
   - 5.1 Processing a batch of samples/models with `jax.vmap`
   - 5.2 End-to-end differentiable LBP
6. Conclusion

Probabilistic graphical models (PGMs) aim to provide a compact and complete statistical description of data, and include many popular models like Bayesian networks, Markov random fields, and energy-based models whose energy is a sum of partial energy functions. PGMs have enjoyed successful applications to problems in a wide range of fields such as computer vision, natural language processing and biology. They are naturally equipped to reason under uncertainty, and are able to handle arbitrary inference queries (in which we decide which variables to predict at test time) and missing inputs. Their generality and interpretability make them an indispensable tool for probabilistic modeling.

Despite the desirable properties of PGMs, most AI models today combine a somewhat simplistic form of probabilistic modeling (such as a cross-entropy loss or a Gaussian likelihood) with deep neural networks (DNNs). Often this approach is followed out of convenience: training boils down to minimizing the loss function of the DNN on the training set with respect to its weights, using stochastic gradient descent. This allows training massive architectures with little structure on even more massive training sets. In theory, any DNN architecture one could imagine can be trained right away using this technique, as long as its loss is a differentiable function of the weights. Two decades ago this required manually deriving the gradients for each new architecture and then coding them. Fast iteration in the trial-and-error loop of creating useful new architectures has since been enabled by software packages for automatic differentiation, such as TensorFlow and PyTorch. These tools eliminate the time-consuming and error-prone process of manual gradient computation, allowing researchers to experiment with architectures while their gradients remain correct by construction.

However, this PGM-light, DNN-heavy approach has limitations. Typically, the PGM advantages mentioned above are not available: the model has trouble reasoning about uncertainty, and cannot accept arbitrary inference queries or missing inputs. We believe that one reason current AI models mostly rely on simple probabilistic modeling, relegating model complexity to deterministic functions in the form of DNNs, is the lack of automated tooling for probabilistic inference, particularly in the discrete domain. Existing probabilistic programming tools usually do not play well with undirected PGMs and struggle when the random variables are discrete, while existing tools for discrete inference have problems with quality, efficiency, and scalability. In this blog post we introduce a new tool to fill this gap.

# 1 Introducing PGMax!

Today, we’re excited to announce PGMax, a new open-source Python package designed for the express purpose of flexibly specifying arbitrary discrete PGMs using standard factor graph representations, and automatically deriving efficient and scalable loopy belief propagation (LBP) for both marginal and maximum-a-posteriori (MAP) inference.

At its core, PGMax is designed to have the following key features:

- Easy specification of general factor graphs: PGMax supports general factor graphs with arbitrary graph topology, factor definitions, and discrete variables with a varying number of states. PGMax has built-in primitives for common repetitive variable and factor structures (e.g. grids of variables), and enables users to specify complex discrete PGMs with ease.
- Efficient, scalable loopy belief propagation (LBP) implementation: PGMax adopts an LBP implementation that has been extensively tested in our recent works [^{1}] [^{2}] [^{3}] [^{4}] with good performance on a wide range of discrete PGMs. PGMax implements LBP in JAX, and is able to leverage just-in-time compilation and modern accelerators like GPUs/TPUs. This results in an inference engine that is several orders of magnitude more efficient than existing alternative Python packages, and enables PGMax to practically scale to discrete PGMs with a large number of variables and factors.
- Seamless interaction with the JAX ecosystem: PGMax implements LBP as pure functions with no side effects. This functional design means users can easily apply JAX transformations like batching and automatic differentiation, and additionally allows PGMax to seamlessly interact with the rapidly growing JAX ecosystem. This opens up exciting new possibilities like gradient-based discrete PGM learning, as we show in Section 5.

PGMax is a specialized probabilistic programming system: it provides an interface for specifying probabilistic models (discrete PGMs as factor graphs) and automatically derives inference methods (LBP). However, PGMax differs from existing probabilistic programming systems, which typically focus on probabilistic models with continuous variables, use Markov chain Monte Carlo methods like Hamiltonian Monte Carlo as the core inference engine, and have limited support for undirected PGMs and discrete variables.

While there are several open-source Python packages that target discrete PGMs (e.g. pgmpy, pomegranate, py-factorgraph, fglib, to name a few), many of these packages are limited in the PGMs they can support, and they additionally suffer from the aforementioned issues with quality, speed, and scalability, as shown in the companion paper [^{5}].

# 2 Background: factor graphs and Restricted Boltzmann Machines

PGMax specifies discrete PGMs using standard factor graphs, which are bipartite graphs containing variable nodes and factor nodes, with edges connecting the two types of nodes. In a factor graph, the variable nodes represent the random variables in the corresponding PGM, while the factor nodes specify the underlying probabilistic model.

As a concrete example, consider the Restricted Boltzmann Machine (RBM), a well-known and widely used PGM for learning probability distributions over binary data. A standard RBM consists of \(n_h\) binary hidden variables \(h \in \{0, 1\}^{n_h}\) and \(n_v\) binary visible variables \(v \in \{0, 1\}^{n_v}\), and is usually specified using an energy function

\begin{equation}
E(h, v) = -b_h^T h - b_v^T v - h^T W v
\end{equation}

where \(b_h \in \mathbb{R}^{n_h}\) and \(b_v \in \mathbb{R}^{n_v}\) are the biases on the hidden and visible variables, respectively, and \(W \in \mathbb{R}^{n_h \times n_v}\) is a weight matrix. Given the energy function \(E\), the joint probability distribution for the hidden and visible variables can be defined as

\begin{equation*}
P(h, v) = \frac{1}{Z} e^{-E(h, v)}, \qquad Z = \sum_{h, v} e^{-E(h, v)},
\end{equation*}

where \(Z\) is the so-called partition function, which normalizes the distribution.
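For a tiny RBM, this definition can be checked directly by enumerating every joint state. The following is a minimal brute-force sketch in NumPy; the helper names `rbm_energy` and `rbm_joint` are our own and not part of PGMax. It is only feasible for very small models, since there are \(2^{n_h + n_v}\) joint states, which is exactly why approximate inference such as LBP is needed at realistic sizes.

```python
import itertools

import numpy as np


def rbm_energy(h, v, bh, bv, W):
    """Energy E(h, v) = -bh^T h - bv^T v - h^T W v of binary vectors h and v."""
    return -bh @ h - bv @ v - h @ W @ v


def rbm_joint(bh, bv, W):
    """Enumerate every (h, v) pair and return P(h, v) = exp(-E(h, v)) / Z.

    Returns the list of joint configurations (h followed by v), their
    probabilities, and log Z. The cost grows as 2 ** (n_h + n_v).
    """
    n_h, n_v = len(bh), len(bv)
    configs = [np.array(c) for c in itertools.product([0, 1], repeat=n_h + n_v)]
    neg_energies = np.array(
        [-rbm_energy(c[:n_h], c[n_h:], bh, bv, W) for c in configs]
    )
    # Shift by the maximum before exponentiating to avoid overflow.
    shift = neg_energies.max()
    unnormalized = np.exp(neg_energies - shift)
    log_Z = shift + np.log(unnormalized.sum())
    return configs, unnormalized / unnormalized.sum(), log_Z


rng = np.random.default_rng(0)
bh, bv, W = rng.normal(size=3), rng.normal(size=5), rng.normal(size=(3, 5))
configs, probs, log_Z = rbm_joint(bh, bv, W)
print(len(configs))  # 256 joint states for n_h = 3, n_v = 5
```

Because \(Z\) cancels in ratios, \(P(h, v) / P(h', v') = e^{E(h', v') - E(h, v)}\) holds regardless of how the partition function is computed.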

In Figure 1, we show the factor graph topology for an RBM with \(n_h = 3\) hidden variables and \(n_v = 5\) visible variables. The circles represent variable nodes, while the shaded squares represent factor nodes. The RBM factor graph has one unary factor for each of the hidden and visible variables (representing the biases \(b_h, b_v\)), and a pairwise factor for each pair of hidden and visible variables (representing the weight matrix \(W\)). The factor graph provides a visual illustration of the factorized definition of the energy function \(E\) as the sum of local terms, each of which involves only a small number of random variables, and can be used to easily derive the message updates for inference with LBP.
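To make the message updates concrete, here is a minimal sum-product sketch for the simplest tree-structured factor graph: a chain of discrete variables with unary and pairwise factors, working in log space. It is an illustration under our own assumptions (a chain topology and a single forward/backward schedule), not PGMax's implementation. On a tree these updates give exact marginals, which the brute-force comparison confirms; on a loopy graph like the RBM, LBP iterates the same kind of local variable-to-factor and factor-to-variable updates, and the resulting beliefs are approximations.

```python
import itertools

import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along one axis."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def chain_sum_product(unary, pairwise):
    """Exact marginals of a chain factor graph via sum-product messages.

    unary:    (n, K) log potentials of the unary factors.
    pairwise: (n - 1, K, K) log potentials; pairwise[i][a, b] scores
              x_i = a together with x_{i+1} = b.
    """
    n, K = unary.shape
    from_left = np.zeros((n, K))   # log message reaching x_i from factor (i-1, i)
    from_right = np.zeros((n, K))  # log message reaching x_i from factor (i, i+1)
    for i in range(1, n):
        # Variable-to-factor message of x_{i-1}: its unary plus what it got from the left.
        var_to_factor = unary[i - 1] + from_left[i - 1]
        # Factor-to-variable message: sum out x_{i-1} (log-sum-exp over axis 0).
        from_left[i] = logsumexp(var_to_factor[:, None] + pairwise[i - 1], axis=0)
    for i in range(n - 2, -1, -1):
        var_to_factor = unary[i + 1] + from_right[i + 1]
        from_right[i] = logsumexp(pairwise[i] + var_to_factor[None, :], axis=1)
    # Beliefs combine the unary factor with all incoming factor-to-variable messages.
    beliefs = unary + from_left + from_right
    return np.exp(beliefs - logsumexp(beliefs, axis=1)[:, None])


def brute_force_marginals(unary, pairwise):
    """Reference marginals by enumerating all K ** n joint configurations."""
    n, K = unary.shape
    marginals = np.zeros((n, K))
    for x in itertools.product(range(K), repeat=n):
        score = sum(unary[i, x[i]] for i in range(n))
        score += sum(pairwise[i, x[i], x[i + 1]] for i in range(n - 1))
        for i in range(n):
            marginals[i, x[i]] += np.exp(score)
    return marginals / marginals.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
unary, pairwise = rng.normal(size=(5, 3)), rng.normal(size=(4, 3, 3))
print(np.allclose(chain_sum_product(unary, pairwise), brute_force_marginals(unary, pairwise)))  # True
```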

To fully define the RBM, we need to additionally define the factors in the factor graph using \(b_h, b_v, W\). A factor involves its connected variables in the factor graph, and can be defined by specifying a list of valid joint configurations of the involved variables and their corresponding log potentials. In Table 1 and Table 2 we illustrate the factor definitions for a given RBM.

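Table 1: the unary factor on \(v_3\).

| \(v_3\) | Log potentials |
|---|---|
| 0 | 0 |
| 1 | \(b_{v3}\) |

Table 2: the pairwise factor on \(h_1\) and \(v_2\).

| \(h_1\) | \(v_2\) | Log potentials |
|---|---|---|
| 0 | 0 | 0 |
| 0 | 1 | 0 |
| 1 | 0 | 0 |
| 1 | 1 | \(W_{12}\) |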
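As a sanity check on Tables 1 and 2, summing the log potentials that every unary and pairwise factor assigns to a configuration \((h, v)\) should recover \(-E(h, v) = b_h^T h + b_v^T v + h^T W v\). A small NumPy sketch with hypothetical helper names:

```python
import numpy as np


def unary_table(b):
    """Table 1 pattern: log potential 0 for state 0 and b for state 1."""
    return np.array([0.0, b])


def pairwise_table(w):
    """Table 2 pattern: only the configuration (1, 1) has a nonzero log potential."""
    return np.array([[0.0, 0.0], [0.0, w]])


def total_log_potential(h, v, bh, bv, W):
    """Sum the log potentials that all unary and pairwise factors assign to (h, v)."""
    total = sum(unary_table(bh[i])[h[i]] for i in range(len(bh)))
    total += sum(unary_table(bv[j])[v[j]] for j in range(len(bv)))
    total += sum(
        pairwise_table(W[i, j])[h[i], v[j]]
        for i in range(len(bh))
        for j in range(len(bv))
    )
    return total


rng = np.random.default_rng(1)
bh, bv, W = rng.normal(size=3), rng.normal(size=5), rng.normal(size=(3, 5))
h, v = rng.integers(0, 2, size=3), rng.integers(0, 2, size=5)
# Compare against minus the RBM energy: bh^T h + bv^T v + h^T W v.
print(np.isclose(total_log_potential(h, v, bh, bv, W), bh @ h + bv @ v + h @ W @ v))  # True
```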

PGMax works in log space for numerical stability, and defines factors using log potentials, which correspond to minus the energy in the typical energy-based formulation. In factor definitions like Table 2, we explicitly enumerate all possible joint configurations, but PGMax more generally supports enumerating only a sparse subset of valid joint configurations, with the assumption that all other joint configurations have log potentials of \(-\infty\) (i.e. these other joint configurations are invalid/disallowed). As we demonstrate and discuss in Section 4, this ability to enumerate a sparse subset of valid joint configurations allows PGMax to flexibly and efficiently support non-standard, structured factors (including higher-order factors), something existing alternatives struggle to do.
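The semantics of a sparse enumeration can be illustrated by expanding it into the equivalent dense table. The sketch below uses a hypothetical helper (not PGMax code) that fills every unlisted configuration with \(-\infty\), so that configuration receives zero probability after exponentiation. The example is an "equality" factor on two 3-state variables, which lists only 3 of the 9 joint configurations.

```python
import numpy as np


def dense_log_potentials(factor_configs, log_potentials, num_states):
    """Expand a sparse list of valid joint configurations into a dense table.

    factor_configs: (n_valid, n_vars) integer array, one valid configuration per row.
    log_potentials: (n_valid,) log potential of each listed configuration.
    num_states: tuple with the number of states of each variable.
    Every configuration that is not listed gets a log potential of -inf.
    """
    table = np.full(num_states, -np.inf)
    # Each column of factor_configs indexes one axis of the dense table.
    table[tuple(np.asarray(factor_configs).T)] = log_potentials
    return table


# An "equality" factor on two 3-state variables: only (0, 0), (1, 1), (2, 2) are valid.
configs = np.array([[0, 0], [1, 1], [2, 2]])
table = dense_log_potentials(configs, np.zeros(3), (3, 3))
probs = np.exp(table) / np.exp(table).sum()
print(probs)  # 1/3 on the diagonal, 0 everywhere else
```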

# 3 Tutorial: implementing LBP inference for RBMs with PGMax

Following the above review of factor graphs and RBMs, in this section we demonstrate how to use PGMax to implement LBP inference for RBMs. See here for an interactive Colab notebook containing the following code snippets.

We start by making some necessary imports.

```
import functools
import itertools

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from pgmax.fg import graph, groups
```

The `pgmax.fg.graph` module contains core classes for specifying factor graphs and implementing LBP, while the `pgmax.fg.groups` module contains classes for specifying groups of variables/factors.

Assuming we have 1D numpy arrays `bh` and `bv` for the RBM biases and a 2D numpy array `W` for the RBM weight matrix, we can initialize the factor graph for the RBM with

```
hidden_variables = groups.NDVariableArray(num_states=2, shape=bh.shape)
visible_variables = groups.NDVariableArray(num_states=2, shape=bv.shape)
fg = graph.FactorGraph(
    variables=dict(hidden=hidden_variables, visible=visible_variables),
)
```

`NDVariableArray` is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with `numpy.ndarray`. The `FactorGraph` object `fg` is initialized with a set of variables, which can be either a single `VariableGroup` (e.g. an `NDVariableArray`), or a list/dictionary of `VariableGroup`s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.

After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a `FactorGraph`. We can add the unary and pairwise factors one at a time to `fg` by

```
# Add unary factors
for ii in range(bh.shape[0]):
    fg.add_factor(
        variable_names=[("hidden", ii)],
        factor_configs=np.arange(2)[:, None],
        log_potentials=np.array([0, bh[ii]]),
    )
for jj in range(bv.shape[0]):
    fg.add_factor(
        variable_names=[("visible", jj)],
        factor_configs=np.arange(2)[:, None],
        log_potentials=np.array([0, bv[jj]]),
    )
# Add pairwise factors: configurations (0, 0), (0, 1), (1, 0), (1, 1), as in Table 2
factor_configs = np.array(list(itertools.product(np.arange(2), repeat=2)))
for ii in range(bh.shape[0]):
    for jj in range(bv.shape[0]):
        fg.add_factor(
            variable_names=[("hidden", ii), ("visible", jj)],
            factor_configs=factor_configs,
            log_potentials=np.array([0, 0, 0, W[ii, jj]]),
        )
```

`fg.add_factor` takes 3 arguments, `variable_names`, `factor_configs` and `log_potentials`, and is a literal translation of the factor definitions shown in Table 1 and Table 2: `factor_configs` has one row per joint configuration of the involved variables (one column per variable), and `log_potentials` gives the log potential of each row. `variable_names` is a list containing the names of the involved variables. In this example, since we construct `fg` with variables `dict(hidden=hidden_variables, visible=visible_variables)`, where `hidden_variables` and `visible_variables` are `NDVariableArray`s, we can refer to the `ii`th hidden variable as `("hidden", ii)` and the `jj`th visible variable as `("visible", jj)`. In general, PGMax implements an intuitive scheme for automatically assigning names to the variables in a `FactorGraph`.

PGMax also implements `FactorGroup`s for adding groups of similar factors. An alternative way of adding the above factors using `FactorGroup`s is by calling `fg.add_factor_group`:

```
# Add unary factors
fg.add_factor_group(
    factory=groups.EnumerationFactorGroup,
    variable_names_for_factors=[[("hidden", ii)] for ii in range(bh.shape[0])],
    factor_configs=np.arange(2)[:, None],
    log_potentials=np.stack([np.zeros_like(bh), bh], axis=1),
)
fg.add_factor_group(
    factory=groups.EnumerationFactorGroup,
    variable_names_for_factors=[[("visible", jj)] for jj in range(bv.shape[0])],
    factor_configs=np.arange(2)[:, None],
    log_potentials=np.stack([np.zeros_like(bv), bv], axis=1),
)
# Add pairwise factors
log_potential_matrix = np.zeros(W.shape + (2, 2)).reshape((-1, 2, 2))
log_potential_matrix[:, 1, 1] = W.ravel()
fg.add_factor_group(
    factory=groups.PairwiseFactorGroup,
    variable_names_for_factors=[
        [("hidden", ii), ("visible", jj)]
        for ii in range(bh.shape[0])
        for jj in range(bv.shape[0])
    ],
    log_potential_matrix=log_potential_matrix,
)
```

This makes use of `EnumerationFactorGroup` and `PairwiseFactorGroup`, two `FactorGroup`s implemented in the `pgmax.fg.groups` module. Despite containing more structural information, at the moment `FactorGroup`s expand into individual factors using regular Python for loops. We plan to make full use of the additional structural information in `FactorGroup`s to speed up model construction in future developments.
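The vectorized arrays above only define the intended model if their k-th entry lines up with the k-th entry of `variable_names_for_factors` (we assume `PairwiseFactorGroup` pairs them in this order). `W.ravel()` flattens in row-major order, i.e. `ii` outer and `jj` inner, exactly like the nested list comprehension. A quick NumPy check, with function names of our own:

```python
import numpy as np


def unary_log_potentials(b):
    """Vectorized unary log potentials: row i is [0, b[i]], as in Table 1."""
    return np.stack([np.zeros_like(b), b], axis=1)


def pairwise_log_potential_matrix(W):
    """Vectorized construction from the snippet above: one 2x2 table per (ii, jj)."""
    log_potential_matrix = np.zeros(W.shape + (2, 2)).reshape((-1, 2, 2))
    log_potential_matrix[:, 1, 1] = W.ravel()
    return log_potential_matrix


def pairwise_tables_by_loop(W):
    """Per-factor tables in the same order as variable_names_for_factors."""
    return np.stack(
        [
            np.array([[0.0, 0.0], [0.0, W[ii, jj]]])
            for ii in range(W.shape[0])
            for jj in range(W.shape[1])
        ]
    )


W = np.arange(6, dtype=float).reshape(2, 3)
print(np.array_equal(pairwise_log_potential_matrix(W), pairwise_tables_by_loop(W)))  # True
```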

Once we have added the factors, we can run max-product LBP and get MAP decoding by

```
run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, num_iters=100, temperature=0.0)
bp_arrays = run_bp(damping=0.5)
beliefs = get_beliefs(bp_arrays)
map_states = graph.decode_map_states(beliefs)
```

and run sum-product LBP and get estimated marginals by

```
run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, num_iters=100, temperature=1.0)
bp_arrays = run_bp(damping=0.5)
beliefs = get_beliefs(bp_arrays)
marginals = graph.get_marginals(beliefs)
```

More generally, PGMax implements LBP with temperature, with `temperature=0.0` and `temperature=1.0` corresponding to the commonly used max-product and sum-product LBP, respectively.
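One standard way to interpolate between the two is to replace the sum or max inside the factor-to-variable messages by a tempered log-sum-exp, \(T \log \sum_i e^{x_i / T}\), which is the ordinary log-sum-exp (sum-product) at \(T = 1\) and tends to \(\max_i x_i\) (max-product) as \(T \to 0\). The NumPy sketch below illustrates this operation; it is not necessarily how PGMax computes it internally.

```python
import numpy as np


def tempered_max(x, temperature):
    """Return T * log(sum(exp(x / T))), which reduces to max(x) as T -> 0.

    temperature=1.0 gives the log-sum-exp used by sum-product messages;
    temperature=0.0 gives the max used by max-product messages.
    """
    x = np.asarray(x, dtype=float)
    m = x.max()
    if temperature == 0.0:
        return m
    # Factor out the maximum so exp never overflows for small temperatures.
    return m + temperature * np.log(np.sum(np.exp((x - m) / temperature)))


x = [1.0, 2.0, 3.0]
for t in [1.0, 0.5, 0.1, 0.0]:
    print(t, tempered_max(x, t))  # decreases toward max(x) = 3.0 as t shrinks
```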

To see PGMax in action, we demonstrate how we can easily implement perturb-and-max-product (PMP) sampling [^{3}] from an RBM trained on MNIST digits using PGMax. PMP is a recently proposed method for approximately sampling from a PGM by computing the maximum-a-posteriori (MAP) configuration (using max-product LBP) of a perturbed version of the model. We first load the \(b_h, b_v\) and \(W\) for an RBM trained in Sec. 5.5 of [^{3}] on MNIST digits (the `rbm_mnist.npz` file is available for download here):

```
params = np.load("rbm_mnist.npz")
bh = params["bh"]
bv = params["bv"]
W = params["W"]
```

Once we load the trained RBM parameters, we can construct the RBM factor graph as shown above. PMP perturbs the model with Gumbel unary potentials, and draws a sample from the RBM as the MAP decoding from running max-product LBP on the perturbed model:

```
run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, num_iters=100, temperature=0.0)
bp_arrays = run_bp(
    evidence_updates={
        "hidden": np.random.gumbel(size=(bh.shape[0], 2)),
        "visible": np.random.gumbel(size=(bv.shape[0], 2)),
    },
    damping=0.5,
)
beliefs = get_beliefs(bp_arrays)
map_states = graph.decode_map_states(beliefs)
```

Here we use the `evidence_updates` argument of `run_bp` to perturb the model with Gumbel unary potentials. In general, `evidence_updates` can be used to incorporate evidence in the form of externally applied unary potentials in PGM inference.
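Gumbel perturbations produce samples because of the Gumbel-max trick: for a single discrete variable with log potentials \(\theta\), \(\arg\max_i (\theta_i + g_i)\) with i.i.d. standard Gumbel noise \(g_i\) is an exact sample from the distribution proportional to \(e^{\theta}\). PMP applies such perturbations to the unary potentials of the whole model and uses max-product LBP to approximately compute the perturbed MAP [^{3}]. A small NumPy check of the single-variable case (the helper name is ours):

```python
import numpy as np


def gumbel_max_samples(log_potentials, n_samples, rng):
    """Sample states by adding i.i.d. Gumbel(0, 1) noise and taking the argmax."""
    log_potentials = np.asarray(log_potentials, dtype=float)
    noise = rng.gumbel(size=(n_samples, log_potentials.shape[0]))
    return np.argmax(log_potentials + noise, axis=1)


rng = np.random.default_rng(0)
probs = np.array([0.2, 0.5, 0.3])
samples = gumbel_max_samples(np.log(probs), n_samples=200_000, rng=rng)
empirical = np.bincount(samples, minlength=3) / samples.size
print(empirical)  # close to [0.2, 0.5, 0.3]
```

Since adding a constant to every log potential does not change the argmax, the log potentials do not need to be normalized.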

Visualizing the MAP decoding (Figure 2), we see that we have sampled an MNIST digit!

```
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(map_states["visible"].copy().reshape((28, 28)), cmap='gray')
ax.axis('off')
```

Despite the many past attempts at implementing factor graphs and LBP inference in Python (e.g. pyfac, py-factorgraph, factorflow, fglib, sumproduct, pgmpy, pomegranate, to name a few), there exists no well-established open-source Python package with efficient and scalable LBP implementations for general factor graphs. To put things into perspective with a concrete example: the above RBM contains 500 hidden and 784 visible variables (the latter representing a \(28\times 28\) MNIST image), i.e. a total of 1284 variables, and \(500 + 784 + 500 \times 784 = 393284\) factors in the factor graph. PGMax can handle the factor graph building and LBP inference with ease, yet as we demonstrate in our companion paper [^{5}], the size of this RBM is already far beyond the capabilities^{1} of two relatively active and well-established alternative Python packages, pgmpy and pomegranate. Additional experiments on smaller RBMs also show that, in addition to being qualitatively more efficient and scalable, PGMax’s LBP implementation can obtain higher-quality inference results (e.g. lower-energy MAP decodings) than existing alternatives.
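The size figures above follow from simple counting. As a rough illustration, the sketch below also counts message entries under our own assumption of a naive scheme that stores one message vector per factor-variable edge and per direction; PGMax's flat layout, described in the companion paper, may store less.

```python
def rbm_factor_graph_size(n_hidden, n_visible, num_states=2):
    """Count variables, factors, edges and stored message entries for an RBM.

    Assumes one unary factor per variable and one pairwise factor per
    (hidden, visible) pair, and a naive scheme that stores one message vector
    per edge and per direction.
    """
    num_variables = n_hidden + n_visible
    num_pairwise = n_hidden * n_visible
    num_factors = num_variables + num_pairwise
    # A unary factor has 1 edge to a variable; a pairwise factor has 2.
    num_edges = num_variables + 2 * num_pairwise
    num_message_entries = 2 * num_edges * num_states
    return num_variables, num_factors, num_edges, num_message_entries


print(rbm_factor_graph_size(500, 784))  # (1284, 393284, 785284, 3141136)
```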

# 4 Specifying complex factor graphs

The RBM example in Section 3 is fairly simple, with binary variables and standard unary/pairwise factors. To showcase PGMax’s support for complex factor graphs, we next demonstrate how we can easily implement max-product LBP inference for the Recursive Cortical Network (RCN) [^{1}] in PGMax. RCN is a neuroscience-inspired probabilistic generative model for computer vision that can learn from very little training data and handles recognition, segmentation, and reasoning in a unified manner. RCN is formulated as a PGM, and uses max-product LBP for inference. Our example in this section adapts the two-level RCN model for MNIST classification in the RCN reference implementation. A self-contained interactive Colab notebook implementing the example in this section is available here.

The two-level RCN model consists of a set of templates, one for each training image (an upsampled \(200\times 200\) MNIST image). In Figure 3 we show an example training image and its corresponding template. An RCN template can be represented as a graph. A node represents a salient feature in the training image, and corresponds to a variable in the PGM. Each node is allowed to move within a \(25\times 25\) square region centered at its original location, giving rise to a discrete variable with \(625\) states. An edge represents a constraint on how much the relative distance of two connected nodes can vary with respect to their original relative distance, and corresponds to a factor in the PGM. An edge with an associated number \(r\) (the *perturb radius*) means the relative distance of the two connected nodes cannot deviate from their original relative distance by more than \(r\). Classification is achieved via a template matching process, where we run max-product LBP to identify the template with the highest score.

The factors in an RCN template are highly non-standard: depending on the perturb radius \(r\), only a very small number of joint configurations out of the \(625^2=390625\) possibilities are valid. Most of the existing alternative packages (except the now unmaintained OpenGM 2) require specifying all \(390625\) log potentials in order to define a factor, which quickly becomes infeasible. However, PGMax naturally supports such non-standard factors. Given two nodes `NODE1` and `NODE2`, we can add factors to the RCN factor graph `fg` by

```
fg.add_factor(
    variable_names=[NODE1, NODE2],
    factor_configs=VALID_CONFIGS,
)
```

Here `VALID_CONFIGS` is a numpy array of shape `(n_valid_configs, 2)` (with `n_valid_configs` \(\ll 390625\)) which explicitly enumerates all valid joint configurations of `NODE1` and `NODE2`, and `log_potentials` defaults to an all-zero array when not explicitly specified.
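For illustration, the sketch below generates a `VALID_CONFIGS`-style array. The state encoding and the exact form of the constraint are our assumptions, not taken from the RCN reference implementation: state \(s\) of a node is read as the displacement `(s // 25, s % 25)` inside its \(25\times 25\) window, and the relative displacement of the two nodes may change by at most \(r\) along each axis. Even under these assumptions, the number of valid configurations stays far below \(390625\).

```python
import itertools

import numpy as np


def rcn_valid_configs(perturb_radius, side=25):
    """Hypothetical VALID_CONFIGS for one RCN-style pairwise factor.

    Assumed encoding: state s of a node is the displacement (s // side, s % side)
    inside its side x side window. Assumed constraint: the relative displacement
    of the two nodes changes by at most `perturb_radius` along each axis.
    """
    configs = []
    for s1, (r1, c1) in enumerate(itertools.product(range(side), repeat=2)):
        rows = range(max(r1 - perturb_radius, 0), min(r1 + perturb_radius, side - 1) + 1)
        cols = range(max(c1 - perturb_radius, 0), min(c1 + perturb_radius, side - 1) + 1)
        for r2, c2 in itertools.product(rows, cols):
            configs.append((s1, r2 * side + c2))
    return np.array(configs)


VALID_CONFIGS = rcn_valid_configs(perturb_radius=1)
print(VALID_CONFIGS.shape)  # (5329, 2), versus 625 ** 2 = 390625 possible pairs
# The array could then be passed as in the snippet above:
# fg.add_factor(variable_names=[NODE1, NODE2], factor_configs=VALID_CONFIGS)
```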

We refer interested readers to the self-contained interactive Colab notebook to see PGMax in action. Using an RCN model trained on 20 images, we achieve a classification accuracy of 80% on a (very small, for-illustration-only) test set with 20 images, with easily interpretable MAP decodings (see Figure 4).

Note that in addition to supporting non-standard pairwise factors like those in RCN, PGMax’s explicit enumeration of valid configurations also makes implementation of many structured higher-order factors (e.g. those used in the recently proposed Markov Attention Models [^{4}]) feasible. Moreover, PGMax’s fully flat LBP implementation [^{5}] allows it to support factor graphs with arbitrary topology and discrete variables of varying sizes without compromising speed or memory usage. These are all things existing alternative packages struggle to do.
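As a simple instance of a structured higher-order factor (our own illustrative example, not one of the factors from [^{4}]), consider a factor over \(k\) binary variables that requires exactly one of them to be on. Only \(k\) of the \(2^k\) joint configurations are valid, so enumerating them is cheap while a dense table grows exponentially. The sketch compares a brute-force enumeration with a direct construction; both helper names are hypothetical.

```python
import itertools

import numpy as np


def enumerate_valid_configs(num_variables, num_states, is_valid):
    """Brute-force list of the joint configurations accepted by `is_valid`.

    Only practical for small factors, since it visits num_states ** num_variables
    configurations.
    """
    return np.array(
        [c for c in itertools.product(range(num_states), repeat=num_variables) if is_valid(c)]
    )


def exactly_one_configs(num_variables):
    """Valid configurations of an 'exactly one binary variable is on' factor, built directly."""
    return np.eye(num_variables, dtype=int)


brute = enumerate_valid_configs(4, 2, lambda c: sum(c) == 1)
print(len(brute), 2 ** 4)  # 4 valid configurations out of 16
print(exactly_one_configs(20).shape, 2 ** 20)  # (20, 20) versus 1048576 dense entries
```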

# 5 Exciting new possibilities from implementing LBP in JAX

PGMax adopts a functional interface for implementing LBP: running LBP in PGMax starts with

`run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, num_iters=NUM_ITERS, temperature=T)`

where `run_bp` and `get_beliefs` are pure functions with no side effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see here and here). In what follows we demonstrate how this opens up some exciting new possibilities.

## 5.1 Processing a batch of samples/models with `jax.vmap`

`jax.vmap` is a convenient transformation for automatically vectorizing functions. Since we implement `run_bp`/`get_beliefs` as pure functions, we can apply `jax.vmap` to `run_bp`/`get_beliefs` to process a batch of samples/models in parallel. As an example, consider the PGMax implementation of PMP sampling from the RBM trained on MNIST images in Section 3. Instead of drawing one sample at a time

```
bp_arrays = run_bp(
    evidence_updates={
        "hidden": np.random.gumbel(size=(bh.shape[0], 2)),
        "visible": np.random.gumbel(size=(bv.shape[0], 2)),
    },
    damping=0.5,
)
beliefs = get_beliefs(bp_arrays)
map_states = graph.decode_map_states(beliefs)
```

we can draw a batch of samples in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap`:

```
n_samples = 10
# Keyword arguments are mapped over their leading (batch) axis by jax.vmap
bp_arrays = jax.vmap(functools.partial(run_bp, damping=0.5), in_axes=0, out_axes=0)(
    evidence_updates={
        "hidden": np.random.gumbel(size=(n_samples, bh.shape[0], 2)),
        "visible": np.random.gumbel(size=(n_samples, bv.shape[0], 2)),
    },
)
beliefs = jax.vmap(get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
map_states = graph.decode_map_states(beliefs)
```

Visualizing the MAP decodings (Figure 5), we see that we have sampled 10 MNIST digits in parallel!

```
fig, ax = plt.subplots(2, 5, figsize=(20, 8))
for ii in range(10):
    ax[np.unravel_index(ii, (2, 5))].imshow(
        map_states["visible"][ii].copy().reshape((28, 28)), cmap='gray'
    )
    ax[np.unravel_index(ii, (2, 5))].axis("off")
fig.tight_layout()
```

The above is only one example of the possibilities that implementing `run_bp` as a pure function opens up. To give some additional examples, we can vectorize over the `log_potentials_updates` argument of `run_bp` to process a batch of models with the same structure but different log potentials in parallel, or even apply `jax.pmap` to effortlessly scale LBP inference to multiple GPUs/TPUs!

## 5.2 End-to-end differentiable LBP

The operations involved in LBP are all differentiable, and PGMax’s implementation of LBP in JAX means we can in fact use LBP as part of a larger end-to-end differentiable system, and easily differentiate through the whole LBP inference process.

To showcase this exciting new possibility, in this section we demonstrate how we can use PGMax to train a Grid Markov Random Field (GMRF) on the border ownership dataset using query training, originally proposed in [^{2}]. Roughly speaking, the task is to learn the parameters of the involved PGM (the GMRF) so that running LBP on the learned GMRF results in a specific kind of image denoising. Here we focus on how we can leverage PGMax’s end-to-end differentiable LBP implementation to easily learn the parameters with gradient-based training, and refer interested readers to the original paper [^{2}] and the self-contained interactive Colab notebook for more details on the dataset, the model and the learning approach.

Given the GMRF factor graph `fg`, we can generate functions for running sum-product LBP inference

`run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, num_iters=15, temperature=1.0)`

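Implementing gradient-based learning is as simple as setting up the loss function (`p_contour` and `prototype_targets` are dataset-specific quantities assumed to be defined beforehand; see the Colab notebook)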

```
@jax.jit
def loss(noisy_image, target_image, log_potentials):
    evidence = jnp.log(jnp.where(noisy_image[..., None] == 0, p_contour, 1 - p_contour))
    target = prototype_targets[target_image]
    marginals = graph.get_marginals(
        get_beliefs(
            run_bp(
                evidence_updates={None: evidence},
                log_potentials_updates=log_potentials,
                damping=0.0,
            )
        )
    )
    logp = jnp.mean(jnp.log(jnp.sum(target * marginals, axis=-1)))
    return -logp


@jax.jit
def batch_loss(noisy_images, target_images, log_potentials):
    return jnp.mean(
        jax.vmap(loss, in_axes=(0, 0, None), out_axes=0)(
            noisy_images, target_images, log_potentials
        )
    )


value_and_grad = jax.jit(jax.value_and_grad(batch_loss, argnums=2))
```

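defining the optimizer update step (here `init_fun`, `opt_update` and `get_params` are assumed to be the three functions returned by a JAX optimizer constructor such as `jax.example_libraries.optimizers.adam`)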

```
@jax.jit
def update(step, batch_noisy_images, batch_target_images, opt_state):
    value, grad = value_and_grad(
        batch_noisy_images, batch_target_images, get_params(opt_state)
    )
    opt_state = opt_update(step, grad, opt_state)
    return value, opt_state
```

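initializing the `log_potentials` we want to learn (one `variable_size` by `variable_size` matrix per type of pairwise factor, where `variable_size` is the number of states of each GMRF variable)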

```
opt_state = init_fun(
    {
        "top_down": np.random.randn(variable_size, variable_size),
        "left_right": np.random.randn(variable_size, variable_size),
        "diagonal0": np.random.randn(variable_size, variable_size),
        "diagonal1": np.random.randn(variable_size, variable_size),
    }
)
```

and going through the actual optimization process

```
for epoch in range(n_epochs):
    indices = np.random.permutation(noisy_images_train.shape[0])
    for idx in range(n_batches):
        batch_indices = indices[idx * batch_size : (idx + 1) * batch_size]
        batch_noisy_images, batch_target_images = (
            noisy_images_train[batch_indices],
            target_images_train[batch_indices],
        )
        step = epoch * n_batches + idx
        value, opt_state = update(
            step, batch_noisy_images, batch_target_images, opt_state
        )
```

Figure 6 visualizes image denoising on the border ownership dataset using sum-product LBP inference on the trained GMRF. Notice how we can flexibly apply (and compose) various JAX transformations (e.g. `jit`, `vmap`, `grad`, `value_and_grad`) and how we use sum-product LBP inference (implemented in `run_bp`) as a regular component in an end-to-end differentiable system.

# 6 Conclusion

We’ve developed PGMax: a new, open-source framework for specifying and performing inference on discrete PGMs. We designed PGMax with a focus on efficiency and usability, enabling users to easily specify complex models with a large number of variables and factors and perform inference in a practical amount of time. Moreover, PGMax’s tight integration with JAX makes it extremely extensible: users can leverage both advanced features within JAX, as well as features from other libraries in the JAX ecosystem, to develop new features and capabilities for PGMax.

PGMax is the product of many months of active development and iteration. We’ve been using it internally during that time and have already seen significant utility for rapidly prototyping new PGMs for novel research projects. Given how much open-source frameworks have accelerated research in related fields like deep learning, we’re extremely excited by the potential for PGMax to enable and accelerate research that uses PGMs. So please do consider using, and perhaps even contributing to, our framework: we can’t wait to see what you’ll build!

^{1}A key challenge in implementing efficient and scalable LBP inference for general factor graphs in Python lies in the design and manipulation of Python data structures containing LBP messages for factor graphs with potentially complicated topology and factor definitions and discrete variables with varying numbers of states. We refer interested readers to our companion paper [^{5}] for a detailed description of our approach for addressing this challenge and more discussion of related work.

[1] George, Dileep, Wolfgang Lehrach, Ken Kansky, Miguel Lázaro-Gredilla, Christopher Laan, Bhaskara Marthi, Xinghua Lou, et al. “A generative vision model that trains with high data efficiency and breaks text-based CAPTCHAs.” Science 358, no. 6368 (2017).

[2] Lázaro-Gredilla, Miguel, Wolfgang Lehrach, Nishad Gothoskar, Guangyao Zhou, Antoine Dedieu, and Dileep George. “Query Training: Learning a Worse Model to Infer Better Marginals in Undirected Graphical Models with Hidden Variables.” In Proceedings of the AAAI Conference on Artificial Intelligence, vol. 35, no. 9, pp. 8252-8260. 2021.

[3] Lázaro-Gredilla, Miguel, Antoine Dedieu, and Dileep George. “Perturb-and-max-product: Sampling and learning in discrete energy-based models.” Advances in Neural Information Processing Systems 34 (2021).