# mnist-sampler-1.3.1.py
'''
Demonstrates how to sample and plot MNIST digits
using Keras API
https://github.com/PacktPublishing/Advanced-Deep-Learning-with-Keras
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
# Load MNIST via the Keras dataset helper; the result unpacks into a
# (data, labels) pair for the training split and another for the test split.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Print, for each split, a mapping from every distinct label value to the
# number of times it occurs. `unique` and `counts` end up holding the
# test-split results, exactly as in a straight-line version of this code.
for _heading, _split_labels in (("Train labels: ", y_train),
                                ("Test labels: ", y_test)):
    unique, counts = np.unique(_split_labels, return_counts=True)
    print(_heading, dict(zip(unique, counts)))
# Pick 25 random positions along the first axis of the training data.
# np.random.randint draws independently, so the same position may be
# selected more than once.
indexes = np.random.randint(0, x_train.shape[0], size=25)
images = x_train[indexes]
labels = y_train[indexes]

# Lay the sampled digits out on a 5x5 grid, one grayscale image per cell.
plt.figure(figsize=(5, 5))
for i, image in enumerate(images):
    # Matplotlib subplot positions are 1-based, hence the +1.
    plt.subplot(5, 5, i + 1)
    plt.imshow(image, cmap='gray')
    plt.axis('off')

# Write the finished figure to disk before showing it, then release all
# figures.
plt.savefig("mnist-samples.png")
plt.show()
plt.close('all')