AI frameworks have transformed the way machine learning and deep learning applications are developed, enabling faster training and deployment. Among these frameworks, PyTorch and TensorFlow have become the most widely used due to their vast feature sets, extensive ecosystems, and community support. While these frameworks dominate the AI landscape, the Just After eXecution (JAX) framework was developed for array-oriented numerical computation and provides a NumPy-like interface. It is designed for high-performance computing and machine learning research, leveraging hardware such as CPUs, GPUs, and TPUs to speed up computations in deep learning models. JAX complements both PyTorch and TensorFlow by empowering users to tackle cutting-edge challenges and research in AI development.
Features
JAX provides the following features:
- Unified NumPy-like interface - To perform computations on CPU, GPU, or TPU.
- Built-in Just-In-Time (JIT) compilation - Optimize computations on hardware accelerators with jit via OpenXLA (an open-source machine learning compiler ecosystem), which results in faster training times for deep learning models. jit can be used either as an @jit decorator or as a higher-order function.
- Automatic differentiation transformations - To calculate gradients efficiently, which is essential for many deep learning algorithms. The most popular function is grad for reverse-mode gradients.
- Automatic vectorization - To efficiently map JAX functions over arrays representing batches of inputs. vmap is the vectorizing map which has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance.
Compared to other frameworks like PyTorch and TensorFlow, JAX's NumPy-like interface makes it uniquely suitable for scientific computing and optimization tasks, and its JIT integration requires minimal additional code from the user. The short sketch below illustrates these core transformations.
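To make these transformations concrete, here is a minimal sketch on a toy quadratic; the function f and its values are illustrative assumptions and are not part of the MNIST code sample discussed later.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

@jit                      # compile with XLA for the available backend (CPU, GPU, or TPU)
def f(x):
    return 3.0 * x ** 2 + 2.0 * x

df = grad(f)              # reverse-mode derivative: df/dx = 6x + 2
print(df(2.0))            # 14.0

batched_df = vmap(df)     # map the gradient over a batch without a Python loop
print(batched_df(jnp.arange(4.0)))   # [ 2.  8. 14. 20.]
```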
Getting Started
Installation
- For CPU (on Linux, Windows, and macOS) - Install directly from the Python Package Index.
- For GPU -
- For TPU -
For more information on platform-specific installation and supported platforms, check out the documentation.
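As a quick reference, typical installation commands from the Python Package Index look like the following sketch; the exact package extras (such as cuda12 or tpu) and supported hardware change over time, so confirm them against the JAX documentation before installing.

```bash
pip install -U jax              # CPU-only wheel
pip install -U "jax[cuda12]"    # NVIDIA GPU (CUDA 12) support
pip install -U "jax[tpu]"       # Google Cloud TPU support
```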
Code Sample
The code sample shows how to implement a simple neural network that is trained on the MNIST dataset using JAX for parallel computations across multiple CPU cores. It focuses on how to use JAX's pmap to execute single-program multiple-data (SPMD) programs for data parallelism along a batch dimension, while minimizing dependencies. The following steps and functions are implemented in the code sample; condensed sketches of the key functions follow the list.
1. Import the necessary packages and libraries.
2. Define init_random_params function – to initialize the weights and biases for each layer in the neural network.
3. Define predict function – to compute the forward pass of the network by applying weights, biases, and activations to inputs.
4. Define loss function – to calculate the cross-entropy loss between predictions and target labels.
5. Define accuracy function – to compute the accuracy of the model by predicting the class of each input in the batch and comparing it to the true target class. It uses the jnp.argmax function to find the predicted class and then computes the mean of correct predictions.
6. Define data_stream function – to generate batches of shuffled training data. It reshapes the data so that it can be split across multiple cores, ensuring that the batch size is divisible by the number of cores for parallel processing.
7. Define spmd_update function – to perform parallel gradient updates across multiple devices using JAX’s pmap and lax.psum.
8. Create a training loop – to train the model for a number of epochs by updating parameters and printing training/test accuracy after each epoch. The parameters are replicated across devices and updated in parallel using spmd_update. After each epoch, the model’s accuracy is evaluated on both training and test data using accuracy.
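To make steps 2 through 5 concrete, below is a condensed sketch of how the model functions might look; the layer sizes, weight scale, and tanh activation here are illustrative assumptions rather than the exact code sample.

```python
import numpy.random as npr
import jax.numpy as jnp

def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
    # One (weights, biases) pair per layer, drawn from a scaled normal distribution.
    return [(scale * rng.randn(m, n), scale * rng.randn(n))
            for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

def predict(params, inputs):
    # Forward pass: affine transform + tanh for hidden layers, log-softmax output.
    activations = inputs
    for w, b in params[:-1]:
        activations = jnp.tanh(jnp.dot(activations, w) + b)
    final_w, final_b = params[-1]
    logits = jnp.dot(activations, final_w) + final_b
    return logits - jnp.log(jnp.sum(jnp.exp(logits), axis=1, keepdims=True))

def loss(params, batch):
    # Cross-entropy between predicted log-probabilities and one-hot targets.
    inputs, targets = batch
    preds = predict(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))

def accuracy(params, batch):
    # Fraction of examples whose argmax prediction matches the target class.
    inputs, targets = batch
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
    return jnp.mean(predicted_class == target_class)
```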
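Step 6's batching logic can be sketched as follows; train_images, train_labels, and batch_size are assumed to be provided by the surrounding script, and the per-device reshape is what allows pmap to assign one shard of each batch to every device.

```python
import numpy.random as npr
import jax

def data_stream(train_images, train_labels, batch_size):
    # Yield an endless stream of shuffled batches, reshaped to
    # (num_devices, batch_size // num_devices, ...) so pmap gives one shard per device.
    num_devices = jax.device_count()
    assert batch_size % num_devices == 0, "batch size must divide evenly across devices"
    num_train = train_images.shape[0]
    num_batches = num_train // batch_size
    rng = npr.RandomState(0)
    while True:
        perm = rng.permutation(num_train)
        for i in range(num_batches):
            idx = perm[i * batch_size:(i + 1) * batch_size]
            images, labels = train_images[idx], train_labels[idx]
            shard = (num_devices, batch_size // num_devices)
            yield (images.reshape(shard + images.shape[1:]),
                   labels.reshape(shard + labels.shape[1:]))
```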
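Steps 7 and 8 revolve around pmap and lax.psum. A condensed sketch, building on the functions above and assuming the flattened, one-hot-encoded MNIST arrays (train_images, train_labels, test_images, test_labels) are already loaded, could look like this; the learning rate, batch size, and epoch count are illustrative.

```python
import functools
import numpy as np
import jax
from jax import grad, lax, pmap

# To simulate multiple devices on a multi-core CPU, set
# XLA_FLAGS="--xla_force_host_platform_device_count=8" before importing jax.
num_devices = jax.device_count()
step_size = 0.001  # illustrative learning rate

@functools.partial(pmap, axis_name='batch')
def spmd_update(params, batch):
    # Each device computes gradients on its own shard; lax.psum sums them
    # across the 'batch' axis so every device applies the same SGD update.
    grads = grad(loss)(params, batch)
    return [(w - step_size * lax.psum(dw, 'batch'),
             b - step_size * lax.psum(db, 'batch'))
            for (w, b), (dw, db) in zip(params, grads)]

# Replicate the initial parameters so every device holds an identical copy.
params = init_random_params(scale=0.1, layer_sizes=[784, 1024, 1024, 10])
replicated_params = jax.tree_util.tree_map(
    lambda x: np.broadcast_to(x, (num_devices,) + x.shape), params)

batch_size = 128
num_batches = train_images.shape[0] // batch_size
batches = data_stream(train_images, train_labels, batch_size)

for epoch in range(10):
    for _ in range(num_batches):
        replicated_params = spmd_update(replicated_params, next(batches))
    # The per-device copies stay identical, so evaluate with any single copy.
    params = jax.tree_util.tree_map(lambda x: x[0], replicated_params)
    print(epoch,
          accuracy(params, (train_images, train_labels)),
          accuracy(params, (test_images, test_labels)))
```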
Try out and run the above code sample to implement a simple neural network's training and inference for MNIST images using JAX on CPU. The network is trained over multiple epochs, evaluating accuracy and adjusting parameters using stochastic gradient descent.
What’s Next
Get started with JAX on different datasets, such as the Sentiment140 dataset for sentiment analysis, and perform complex numerical computations on high-performance devices. Also check out the Intel® Extension for TensorFlow*, through which Intel optimizes the open source TensorFlow framework for Intel hardware and releases its newest optimizations and features. The extension includes a PJRT (Pluggable Device Runtime) plugin implementation that seamlessly runs JAX models on Intel GPUs.
Access and try the AI Tools for yourself to build additional end-to-end AI applications. We also encourage you to check out and incorporate Intel's other AI/ML framework optimizations and tools into your AI workflow, and to learn about the unified, open, standards-based oneAPI programming model that forms the foundation of Intel's AI Software Portfolio, helping you prepare, build, deploy, and scale your AI solutions.