使用 PyTorch 数据加载训练简单的神经网络#
版权所有 2018 JAX 作者。
根据 Apache 许可证 2.0 版(“许可证”)获得许可;除非符合许可证的规定,否则您不得使用此文件。您可以在以下网址获取许可证副本:
https://www.apache.org/licenses/LICENSE-2.0
除非适用法律要求或以书面形式达成协议,否则根据许可证分发的软件按“原样”分发,没有任何明示或暗示的保证或条件。请参阅许可证以了解管理权限和限制的具体语言。
让我们将我们在快速入门
中展示的所有内容结合起来,训练一个简单的神经网络。我们首先将使用 JAX 进行计算,在 MNIST 上指定和训练一个简单的 MLP。我们将使用 PyTorch 的数据加载 API 来加载图像和标签(因为它非常好用,而且世界不需要另一个数据加载库)。当然,您可以将 JAX 与任何与 NumPy 兼容的 API 结合使用,以使模型的指定更加即插即用。在这里,仅仅出于解释的目的,我们不会使用任何神经网络库或特殊 API 来构建我们的模型。
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
超参数#
让我们先解决一些簿记事项。
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
w_key, b_key = random.split(key)
return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
keys = random.split(key, len(sizes))
return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0))
自动批处理预测#
让我们首先定义我们的预测函数。请注意,我们正在为一个单个图像示例定义此函数。我们将使用 JAX 的 vmap
函数来自动处理小批量,而不会产生性能损失。
from jax.scipy.special import logsumexp
def relu(x):
return jnp.maximum(0, x)
def predict(params, image):
# per-example predictions
activations = image
for w, b in params[:-1]:
outputs = jnp.dot(w, activations) + b
activations = relu(outputs)
final_w, final_b = params[-1]
logits = jnp.dot(final_w, activations) + final_b
return logits - logsumexp(logits)
让我们检查一下我们的预测函数是否只适用于单个图像。
# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
(10,)
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
preds = predict(params, random_flattened_images)
except TypeError:
print('Invalid shapes!')
Invalid shapes!
# Let's upgrade it to handle batches using `vmap`
# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))
# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)
(10, 10)
至此,我们拥有了定义和训练神经网络所需的所有要素。我们构建了一个 predict
的自动批处理版本,我们应该能够在损失函数中使用它。我们应该能够使用 grad
来计算损失相对于神经网络参数的导数。最后,我们应该能够使用 jit
来加速所有操作。
实用程序和损失函数#
def one_hot(x, k, dtype=jnp.float32):
"""Create a one-hot encoding of x of size k."""
return jnp.array(x[:, None] == jnp.arange(k), dtype)
def accuracy(params, images, targets):
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
return jnp.mean(predicted_class == target_class)
def loss(params, images, targets):
preds = batched_predict(params, images)
return -jnp.mean(preds * targets)
@jit
def update(params, x, y):
grads = grad(loss)(params, x, y)
return [(w - step_size * dw, b - step_size * db)
for (w, b), (dw, db) in zip(params, grads)]
使用 PyTorch 加载数据#
JAX 专注于程序转换和加速支持的 NumPy,因此我们没有在 JAX 库中包含数据加载或处理。已经有许多很棒的数据加载器,所以让我们直接使用它们,而不是重新发明轮子。我们将使用 PyTorch 的数据加载器,并创建一个微小的适配器使其能够与 NumPy 数组一起工作。
!pip install torch torchvision
Requirement already satisfied: torch in /opt/anaconda3/lib/python3.7/site-packages (1.4.0)
Requirement already satisfied: torchvision in /opt/anaconda3/lib/python3.7/site-packages (0.5.0)
Requirement already satisfied: numpy in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.17.2)
Requirement already satisfied: six in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.12.0)
Requirement already satisfied: pillow>=4.1.1 in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (6.2.0)
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import MNIST
def numpy_collate(batch):
return tree_map(np.asarray, data.default_collate(batch))
class NumpyLoader(data.DataLoader):
def __init__(self, dataset, batch_size=1,
shuffle=False, sampler=None,
batch_sampler=None, num_workers=0,
pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
super(self.__class__, self).__init__(dataset,
batch_size=batch_size,
shuffle=shuffle,
sampler=sampler,
batch_sampler=batch_sampler,
num_workers=num_workers,
collate_fn=numpy_collate,
pin_memory=pin_memory,
drop_last=drop_last,
timeout=timeout,
worker_init_fn=worker_init_fn)
class FlattenAndCast(object):
def __call__(self, pic):
return np.ravel(np.array(pic, dtype=jnp.float32))
# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw
Processing...
Done!
# Get the full train dataset (for checking accuracy while training)
train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)
# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:55: UserWarning: train_data has been renamed data
warnings.warn("train_data has been renamed data")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:45: UserWarning: train_labels has been renamed targets
warnings.warn("train_labels has been renamed targets")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:60: UserWarning: test_data has been renamed data
warnings.warn("test_data has been renamed data")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:50: UserWarning: test_labels has been renamed targets
warnings.warn("test_labels has been renamed targets")
训练循环#
import time
for epoch in range(num_epochs):
start_time = time.time()
for x, y in training_generator:
y = one_hot(y, n_targets)
params = update(params, x, y)
epoch_time = time.time() - start_time
train_acc = accuracy(params, train_images, train_labels)
test_acc = accuracy(params, test_images, test_labels)
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
Epoch 0 in 55.15 sec
Training set accuracy 0.9157500267028809
Test set accuracy 0.9195000529289246
Epoch 1 in 42.26 sec
Training set accuracy 0.9372166991233826
Test set accuracy 0.9384000301361084
Epoch 2 in 44.37 sec
Training set accuracy 0.9491666555404663
Test set accuracy 0.9469000697135925
Epoch 3 in 41.75 sec
Training set accuracy 0.9568166732788086
Test set accuracy 0.9534000158309937
Epoch 4 in 41.16 sec
Training set accuracy 0.9631333351135254
Test set accuracy 0.9577000737190247
Epoch 5 in 38.89 sec
Training set accuracy 0.9675000309944153
Test set accuracy 0.9616000652313232
Epoch 6 in 40.68 sec
Training set accuracy 0.9708333611488342
Test set accuracy 0.9650000333786011
Epoch 7 in 41.50 sec
Training set accuracy 0.973716676235199
Test set accuracy 0.9672000408172607
我们现在已经使用了整个 JAX API:grad
用于求导,jit
用于加速,以及 vmap
用于自动向量化。我们使用 NumPy 指定了所有计算,借用了 PyTorch 的优秀数据加载器,并在 GPU 上运行了整个过程。