Flax
A neural network ecosystem for JAX designed for flexibility
- What makes Flax a rising star in machine learning? 🤩
- It's a fast, lightweight, and highly customizable ML framework
- Doc
- GitHub
- Google’s Approach To Flexibility In Machine Learning
- Shakespeare Meets Google’s Flax
Overview
- Neural network API (`flax.linen`): Dense, Conv, {Batch|Layer|Group} Norm, Attention, Pooling, {LSTM|GRU} Cell, Dropout
- Optimizers (`flax.optim`): SGD, Momentum, Adam, LARS, Adagrad, LAMB, RMSprop
- Utilities and patterns: replicated training, serialization and checkpointing, metrics, prefetching on device
- Educational examples that work out of the box: MNIST, LSTM seq2seq, Graph Neural Networks, Sequence Tagging
- Fast, tuned large-scale end-to-end examples: CIFAR10, ResNet on ImageNet, Transformer LM1b
What does Flax look like?
Here are two examples using the Flax API: a simple multi-layer perceptron and a CNN. To learn more about the `Module` abstraction, see our docs.
```python
from typing import Sequence

import flax.linen as nn

class SimpleMLP(nn.Module):
  """An MLP model."""
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat)(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
    return x
```
```python
class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x
```