版权所有 2018 JAX 作者。

根据 Apache 许可证 2.0 版(“许可证”)授权;

根据 Apache 许可证 2.0 版(“许可证”)授权;除非遵守许可证的规定,否则您不得使用此文件。您可以在以下位置获得许可证的副本:

https://apache.ac.cn/licenses/LICENSE-2.0

除非适用法律要求或书面同意,否则根据本许可证分发的软件将按“原样”分发,不提供任何形式的明示或暗示的保证或条件。请参阅许可证以了解有关特定语言的管辖权限和限制的详细信息。

训练一个简单的神经网络,使用 tensorflow/datasets 加载数据#

Open in Colab Open in Kaggle

派生自 neural_network_and_data_loading.ipynb

JAX

让我们将我们在 快速入门 中展示的所有内容结合起来,训练一个简单的神经网络。 我们首先使用 JAX 进行计算,在 MNIST 上指定并训练一个简单的 MLP。 我们将使用 tensorflow/datasets 数据加载 API 来加载图像和标签(因为它非常棒,而且世界不需要另一个数据加载库 :P)。

当然,您可以将 JAX 与任何与 NumPy 兼容的 API 一起使用,以使模型的指定更具即插即用性。 在这里,仅为了解释目的,我们将不使用任何神经网络库或特殊的 API 来构建我们的模型。

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

超参数#

让我们先处理一些簿记事项。

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0))

自动批量预测#

让我们首先定义我们的预测函数。请注意,我们正在为单个图像示例定义此函数。我们将使用 JAX 的 vmap 函数来自动处理小批量,而不会产生性能损失。

from jax.scipy.special import logsumexp

def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)

让我们检查一下我们的预测函数是否仅适用于单个图像。

# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
(10,)
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')
Invalid shapes!
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)
(10, 10)

至此,我们拥有定义和训练神经网络的所有要素。我们构建了一个自动批处理版本的 predict,我们应该可以在损失函数中使用它。我们应该可以使用 grad 来获取损失相对于神经网络参数的导数。最后,我们应该可以使用 jit 来加速一切。

实用工具和损失函数#

def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

使用 tensorflow/datasets 加载数据#

JAX 专注于程序转换和加速器支持的 NumPy,因此我们不在 JAX 库中包含数据加载或整理。已经有很多很棒的数据加载器,所以让我们直接使用它们,而不是重新发明任何东西。我们将使用 tensorflow/datasets 数据加载器。

import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)
Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)

训练循环#

import time

def get_train_batches():
  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  # You can build up an arbitrary tf.data input pipeline
  ds = ds.batch(batch_size).prefetch(1)
  # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
  return tfds.as_numpy(ds)

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in get_train_batches():
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))
Epoch 0 in 28.30 sec
Training set accuracy 0.8400499820709229
Test set accuracy 0.8469000458717346
Epoch 1 in 14.74 sec
Training set accuracy 0.8743667006492615
Test set accuracy 0.8803000450134277
Epoch 2 in 14.57 sec
Training set accuracy 0.8901500105857849
Test set accuracy 0.8957000374794006
Epoch 3 in 14.36 sec
Training set accuracy 0.8991333246231079
Test set accuracy 0.903700053691864
Epoch 4 in 14.20 sec
Training set accuracy 0.9061833620071411
Test set accuracy 0.9087000489234924
Epoch 5 in 14.89 sec
Training set accuracy 0.9113333225250244
Test set accuracy 0.912600040435791
Epoch 6 in 13.95 sec
Training set accuracy 0.9156833291053772
Test set accuracy 0.9176000356674194
Epoch 7 in 13.32 sec
Training set accuracy 0.9192000031471252
Test set accuracy 0.9214000701904297
Epoch 8 in 13.55 sec
Training set accuracy 0.9222500324249268
Test set accuracy 0.9241000413894653
Epoch 9 in 13.40 sec
Training set accuracy 0.9253666996955872
Test set accuracy 0.9269000291824341

我们现在已经使用了大部分 JAX API:用于求导的 grad,用于加速的 jit 和用于自动向量化的 vmap。我们使用 NumPy 来指定我们所有的计算,并从 tensorflow/datasets 借用了很棒的数据加载器,并在 GPU 上运行了整个过程。