A Quick Start Guide to JAX: High-Performance Computing for Machine Learning


Ramya Ravi, AI Software Marketing Engineer, Intel
Chandan Damannagari, Director, AI Software, Intel

AI frameworks have transformed the way machine learning and deep learning applications are developed, enabling faster training and deployment. Among these frameworks, PyTorch and TensorFlow have become the most widely used thanks to their vast feature sets, extensive ecosystems, and community support. While these frameworks dominate the AI landscape, the Just After eXecution (JAX) framework was developed for array-oriented numerical computation and provides a NumPy-like interface. JAX is designed for high-performance computing and machine learning research, leveraging hardware accelerators such as CPUs, GPUs, and TPUs to speed up computations in deep learning models. JAX complements both PyTorch and TensorFlow by empowering users to tackle cutting-edge challenges and research in AI development.

Features

JAX provides the following features:

  • Unified NumPy-like interface - To perform computations on CPU, GPU, or TPU.
  • Built-in Just-In-Time (JIT) compilation - Optimize computations on hardware accelerators with jit via OpenXLA (an open-source machine learning compiler ecosystem), which results in faster training times for deep learning models. jit can be used either as an @jit decorator or as a higher-order function.
  • Automatic differentiation transformations - To calculate gradients efficiently, which is essential for many deep learning algorithms. The most popular function is grad for reverse-mode gradients.
  • Automatic vectorization - To efficiently map JAX functions over arrays representing batches of inputs. vmap is the vectorizing map which has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance.

Compared to other frameworks like PyTorch and TensorFlow, JAX’s NumPy-like interface makes it uniquely suitable for scientific computing and optimization tasks, and its JIT integration requires minimal additional code from the user, as the short sketch below illustrates.
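
For a sense of how little extra code these transformations require, here is a minimal, self-contained sketch; the function f and the sample inputs are arbitrary illustrations, not part of the code sample later in this article.

import jax
import jax.numpy as jnp

# A simple scalar function, chosen only for illustration.
def f(x):
    return jnp.sin(x) * x**2

f_jit = jax.jit(f)            # compile f with XLA
df = jax.grad(f)              # reverse-mode gradient of f
df_batched = jax.vmap(df)     # map the gradient over a batch of inputs

print(f_jit(2.0))                        # compiled evaluation at x = 2.0
print(df(2.0))                           # scalar gradient at x = 2.0
print(df_batched(jnp.arange(1., 4.)))    # gradients at x = 1.0, 2.0, 3.0

The three transformations compose freely, so jax.jit(jax.vmap(jax.grad(f))) is also valid and compiles the batched gradient as a single XLA program.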

Getting Started

Installation

  1. For CPU (on Linux, Windows, and macOS) - Install directly from the Python Package Index.
    pip install -U jax
  2. For GPU -
    pip install -U "jax[cuda12]"
  3. For TPU -
    pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html


For more information on platform-specific installation and supported platforms, check out the documentation.

Code Sample

The code sample shows how to implement a simple neural network that will be trained on the MNIST dataset using JAX for parallel computations across multiple CPU cores. It focuses on how to use JAX's ‘pmap’ to execute single-program multiple-data (SPMD) programs for data parallelism along a batch dimension, while minimizing dependencies. The following steps and functions are implemented in the code sample.
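
Note that, by default, JAX exposes the host CPU as a single device, so pmap would see only one core. One common way to expose multiple CPU devices for this example is to set the XLA_FLAGS environment variable before JAX is first imported; the device count of 8 below is an assumption, so adjust it to your machine.

import os

# Ask XLA to present the host CPU as 8 separate devices so that pmap can
# parallelize across them; this must run before jax is first imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
print(jax.devices())             # should now list 8 CPU devices
print(jax.local_device_count())  # 8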

1. Import the necessary packages and libraries.

from functools import partial
import time

import numpy as np
import numpy.random as npr

import jax
from jax import jit, grad, pmap
from jax.scipy.special import logsumexp
from jax.tree_util import tree_map
from jax import lax
import jax.numpy as jnp
from examples import datasets

2. Define init_random_params function – to initialize the weights and biases for each layer in the neural network.

def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

3. Define predict function – to compute the forward pass of the network by applying weights, biases, and activations to inputs.

def predict(params, inputs):
  activations = inputs
  for w, b in params[:-1]:
    outputs = jnp.dot(activations, w) + b
    activations = jnp.tanh(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(activations, final_w) + final_b
  return logits - logsumexp(logits, axis=1, keepdims=True)

4. Define loss function – to calculate the cross-entropy loss between predictions and target labels.

def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -jnp.mean(jnp.sum(preds * targets, axis=1))

5. Define accuracy function – to compute the accuracy of the model by predicting the class of each input in the batch and comparing it to the true target class. It uses the jnp.argmax function to find the predicted class and then computes the mean of correct predictions.

def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return jnp.mean(predicted_class == target_class)

6. Define data_stream function – to generate batches of shuffled training data. It reshapes the data so that it can be split across multiple cores, ensuring that the batch size is divisible by the number of cores for parallel processing.

def data_stream():
  rng = npr.RandomState(0)
  while True:
    perm = rng.permutation(num_train)
    for i in range(num_batches):
      batch_idx = perm[i * batch_size:(i + 1) * batch_size]
      images, labels = train_images[batch_idx], train_labels[batch_idx]
      # For this SPMD example, we reshape the data batch dimension into two
      # batch dimensions, one of which is mapped over parallel devices.
      batch_size_per_device, ragged = divmod(images.shape[0], num_devices)
      if ragged:
        msg = "batch size must be divisible by device count, got {} and {}."
        raise ValueError(msg.format(batch_size, num_devices))
      shape_prefix = (num_devices, batch_size_per_device)
      images = images.reshape(shape_prefix + images.shape[1:])
      labels = labels.reshape(shape_prefix + labels.shape[1:])
      yield images, labels

7. Define spmd_update function – to perform parallel gradient updates across multiple devices using JAX’s pmap and lax.psum.

@partial(pmap, axis_name='batch')
def spmd_update(params, batch):
  grads = grad(loss)(params, batch)
  # The `lax.psum` SPMD primitive does a fast all-reduce-sum across devices.
  grads = [(lax.psum(dw, 'batch'), lax.psum(db, 'batch')) for dw, db in grads]
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

8. Create a training loop - train the model for a number of epochs by updating parameters and printing training/test accuracy after each epoch. The parameters are replicated across devices and updated in parallel using spmd_update. After each epoch, the model’s accuracy is evaluated on both training and test data using accuracy.
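
The loop itself is not reproduced in the steps above; the following is a minimal sketch of how it might look, in the spirit of the SPMD MNIST example in the JAX repository. The hyperparameter values (layer_sizes, param_scale, step_size, num_epochs, and batch_size) are illustrative choices, and the sketch assumes the imports and functions defined in the previous steps.

layer_sizes = [784, 1024, 1024, 10]   # MNIST images are 28 x 28 = 784 pixels, 10 classes
param_scale = 0.1
step_size = 0.001
num_epochs = 10
batch_size = 128

train_images, train_labels, test_images, test_labels = datasets.mnist()
num_train = train_images.shape[0]
num_batches = num_train // batch_size   # drop the ragged final batch for simplicity
num_devices = jax.local_device_count()

batches = data_stream()
params = init_random_params(param_scale, layer_sizes)

# Replicate the initial parameters so every array gains a leading axis of size
# num_devices; the pmap-ed spmd_update then receives one copy per device.
replicate_array = lambda x: np.broadcast_to(x, (num_devices,) + x.shape)
replicated_params = tree_map(replicate_array, params)

for epoch in range(num_epochs):
  start_time = time.time()
  for _ in range(num_batches):
    replicated_params = spmd_update(replicated_params, next(batches))
  epoch_time = time.time() - start_time

  # Evaluate with a single copy of the parameters taken from device 0.
  params = tree_map(lambda x: x[0], replicated_params)
  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))
  print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
  print(f"Training set accuracy {train_acc}")
  print(f"Test set accuracy {test_acc}")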

Try out and run the code sample above to implement training and inference of a simple neural network on MNIST images using JAX on CPU. The network is trained over multiple epochs, evaluating accuracy and adjusting parameters using stochastic gradient descent.

What’s Next

Get started with JAX on other datasets, such as the Sentiment140 dataset for sentiment analysis, and perform complex numerical computations on high-performance devices. Also, check out the Intel® Extension for TensorFlow*, through which Intel optimizes the open source TensorFlow framework for Intel hardware and releases its newest optimizations and features; it includes a PJRT (Pluggable Device Runtime) plugin implementation that seamlessly runs JAX models on Intel GPUs.

Access and try the AI Tools for yourself to build additional end-to-end AI applications. We also encourage you to check out and incorporate Intel’s other AI/ML framework optimizations and tools into your AI workflow, and to learn about the unified, open, standards-based oneAPI programming model that forms the foundation of Intel’s AI Software Portfolio, helping you prepare, build, deploy, and scale your AI solutions.

Useful resources