Merge dev postgresql to master #1208

Merged
merged 67 commits into master from dev-postgresql on Aug 17, 2024
Changes from all commits
67 commits
90357af
Add the distributed version of the model
GY-GitCode May 24, 2024
79dd9f6
Create mnist_cnn.py
jinyangturbo May 26, 2024
d46f338
Merge pull request #1174 from GY-GitCode/24-5-24-dev
lzjpaul May 31, 2024
af09a0f
Add sparsification version of the model
GY-GitCode Jun 13, 2024
f52e35f
Merge pull request #1175 from jinyangturbo/dev-postgresql
lzjpaul Jun 13, 2024
2730269
Merge pull request #1176 from GY-GitCode/24-6-13-dev
chrishkchris Jun 13, 2024
fb9106c
Create mnist_dist.py
solopku Jun 14, 2024
bdbb7a2
Add multiprocess version of the model
NLGithubWP Jun 14, 2024
0ee8cb3
Merge pull request #1178 from NLGithubWP/dev-postgresql
lzjpaul Jun 14, 2024
edfac8f
Add implementations for the autograd resnet cifar-10
GY-GitCode Jun 15, 2024
b9eb59b
Add the benchmark dataset for testing the model selection
NLGithubWP Jun 15, 2024
cd90fa1
Merge pull request #1181 from NLGithubWP/dataset
lzjpaul Jun 15, 2024
44b1c9a
Add the implementations of alexnet model in the autograd
dcslin Jun 15, 2024
dbb194b
Merge pull request #1182 from dcslin/feature/alexnet
lzjpaul Jun 15, 2024
bad4181
Add the download implementations for the benchmark dataset
lemonviv Jun 16, 2024
7cf604e
Add the implementations of resnet model in the autograd
Jun 16, 2024
7b4b658
Merge pull request #1179 from GY-GitCode/24-6-16-dev
chrishkchris Jun 16, 2024
15640e7
Merge pull request #1177 from solopku/patch-2
nudles Jun 17, 2024
cdd8ece
add the implementations of xceptionnet model in the autograd
Zrealshadow Jun 17, 2024
80760ce
Merge pull request #1183 from lemonviv/add-mnist-download
chrishkchris Jun 17, 2024
a1416c2
Create cifar100.py
jinyangturbo Jun 17, 2024
9caa25d
Merge pull request #1184 from zlheui/add-resnet-in-autograd
nudles Jun 17, 2024
69f5167
Create cifar10.py
solopku Jun 19, 2024
ef7b196
Merge pull request #1185 from Zrealshadow/dev-postgresql-patch
lzjpaul Jul 17, 2024
a97941f
Create download_cifar100.py
solopku Jul 17, 2024
2fd2aa6
Merge pull request #1188 from solopku/patch-4
nudles Jul 18, 2024
c2ffece
Merge pull request #1187 from solopku/patch-3
lzjpaul Jul 18, 2024
d842c74
Merge pull request #1186 from jinyangturbo/dev-postgresql
chrishkchris Jul 19, 2024
133e106
Add the implementation for the native model in the cnn ms example
NLGithubWP Jul 25, 2024
10086e0
Add implementations for data downloading in cnn ms example
dcslin Jul 25, 2024
7281a74
Merge pull request #1189 from NLGithubWP/dev-postgresql
chrishkchris Jul 26, 2024
b34cfd7
Merge pull request #1190 from dcslin/dev-postgresql
lzjpaul Jul 26, 2024
6f9051c
Add the msmlp model implementation for the cnn ms example
GY-GitCode Jul 27, 2024
9c1a05e
Commit the benchmark model for the cnn ms example
moazreyad Jul 29, 2024
924f352
Merge pull request #1192 from moazreyad/dev-postgresql
nudles Jul 30, 2024
d917f09
Merge pull request #1191 from GY-GitCode/24-7-27-dev
lzjpaul Aug 2, 2024
20ffa82
Add the Sum Error Loss for Synflow
moazreyad Aug 2, 2024
8735f49
Merge pull request #1193 from moazreyad/dev-postgresql
chrishkchris Aug 5, 2024
c01551e
Add the implementation for the SumError class
GY-GitCode Aug 6, 2024
e9d1cc9
Merge pull request #1194 from GY-GitCode/24-8-6-dev
lzjpaul Aug 7, 2024
d745905
Update the RELEASE_NOTES for v4.3.0
lzjpaul Aug 9, 2024
8ec9585
Update the CMakeList file for V 4.3.0
lemonviv Aug 10, 2024
7237a87
Merge pull request #1196 from lemonviv/cmakefile-v4.3
lzjpaul Aug 10, 2024
9a87f9e
Merge pull request #1195 from lzjpaul/24-8-9-dev
chrishkchris Aug 11, 2024
04859e6
Update setup.py for v4.3.0
GY-GitCode Aug 11, 2024
827f22d
Update meta.yaml for v4.3.0
GY-GitCode Aug 11, 2024
8f81095
Merge pull request #1197 from GY-GitCode/24-8-11-dev
chrishkchris Aug 11, 2024
d4b11ab
Add the SumErrorLayer class for the msmlp model
zmeihui Aug 11, 2024
77fbc63
Merge pull request #1198 from zmeihui/24-8-11-dev
lzjpaul Aug 11, 2024
09df965
Create train.py
solopku Aug 12, 2024
c594e78
Merge pull request #1199 from solopku/patch-5
lzjpaul Aug 13, 2024
e397d7c
Add the running script for the Transformer example
moazreyad Aug 13, 2024
55b6de2
Add the readme file for the Transformer example
NLGithubWP Aug 13, 2024
1544f14
Add Machine translation model using Transformer Example
NLGithubWP Aug 13, 2024
e1200aa
Merge pull request #1200 from moazreyad/dev-postgresql
nudles Aug 14, 2024
b5857b3
Merge pull request #1201 from NLGithubWP/dev-postgresql
lzjpaul Aug 14, 2024
2677bdd
Add implementations of the Transformer model
gzrp Aug 14, 2024
9ce6f8f
Merge pull request #1202 from gzrp/dev-postgresql
chrishkchris Aug 14, 2024
4b721fb
Add the implementation of data processing for the transformer example
gzrp Aug 14, 2024
292acb2
Merge branch 'apache:dev-postgresql' into dev-postgresql
gzrp Aug 14, 2024
26664e8
add benchmark dataset for the transformer dataset
Zrealshadow Aug 16, 2024
b7d3a3b
Merge pull request #1203 from gzrp/dev-postgresql
lzjpaul Aug 16, 2024
a204dfd
Merge pull request #1204 from Zrealshadow/dev-postgresql-patch-2
lzjpaul Aug 17, 2024
0096e45
Update Transformer example README.md
zmeihui Aug 17, 2024
9b0bbf1
Update the license header for README.md
zmeihui Aug 17, 2024
294a719
Update pom.xml
zmeihui Aug 17, 2024
d2108b7
Merge pull request #1207 from zmeihui/24-8-17-dev
lzjpaul Aug 17, 2024
6 changes: 3 additions & 3 deletions CMakeLists.txt
@@ -29,10 +29,10 @@ LIST(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Thirdparty)
 #string(REGEX REPLACE "^[0-9]+\\.[0-9]+\\.([0-9]+).*" "\\1" VERSION_PATCH "${VERSION}")


-SET(PACKAGE_VERSION 4.2.0) # ${VERSION})
-SET(VERSION 4.2.0)
+SET(PACKAGE_VERSION 4.3.0) # ${VERSION})
+SET(VERSION 4.3.0)
 SET(SINGA_MAJOR_VERSION 4)
-SET(SINGA_MINOR_VERSION 2)
+SET(SINGA_MINOR_VERSION 3)
 SET(SINGA_PATCH_VERSION 0)
 #SET(SINGA_MAJOR_VERSION ${VERSION_MAJOR}) # 0 -
 #SET(SINGA_MINOR_VERSION ${VERSION_MINOR}) # 0 - 9
28 changes: 28 additions & 0 deletions RELEASE_NOTES
@@ -1,3 +1,31 @@
Release Notes - SINGA - Version singa-4.3.0

SINGA is a distributed deep learning library.

This release includes the following changes:

* Add the implementation for the Transformer example.

* Enhance examples
* Update the readme file for the dynamic model slicing example.
* Update the HFL example by setting the maximum number of epochs.
* Add the multiprocess training implementation for the cnn ms example.
* Add the sparsification version of the model for the cnn ms example.

* Extend the matrix multiplication operator to more dimensions.

* Update the data types and tensor operations for model training.

* Add the implementation for the new sum error loss.

* Update the website
* Add the news for the SIGMOD Systems Award.

* Fix bugs
  * Fix the GitHub Actions for online code testing.

----------------------------------------------------------------------------------------------

Release Notes - SINGA - Version singa-4.2.0

SINGA is a distributed deep learning library.
304 changes: 304 additions & 0 deletions examples/cnn_ms/autograd/mnist_cnn.py
@@ -0,0 +1,304 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from singa import singa_wrap as singa
from singa import autograd
from singa import layer
from singa import tensor
from singa import device
from singa import opt
import numpy as np
import os
import sys
import gzip
import codecs
import time


class CNN:

    def __init__(self):
        self.conv1 = layer.Conv2d(1, 20, 5, padding=0)
        self.conv2 = layer.Conv2d(20, 50, 5, padding=0)
        self.linear1 = layer.Linear(4 * 4 * 50, 500)
        self.linear2 = layer.Linear(500, 10)
        self.pooling1 = layer.MaxPool2d(2, 2, padding=0)
        self.pooling2 = layer.MaxPool2d(2, 2, padding=0)
        self.relu1 = layer.ReLU()
        self.relu2 = layer.ReLU()
        self.relu3 = layer.ReLU()
        self.flatten = layer.Flatten()

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pooling1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pooling2(y)
        y = self.flatten(y)
        y = self.linear1(y)
        y = self.relu3(y)
        y = self.linear2(y)
        return y
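
# Shape trace for a 1x28x28 MNIST input: conv1 (5x5, no padding) -> 20x24x24,
# pool1 (2x2, stride 2) -> 20x12x12, conv2 -> 50x8x8, pool2 -> 50x4x4,
# hence the 4 * 4 * 50 input size of linear1.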


def check_dataset_exist(dirpath):
    if not os.path.exists(dirpath):
        print(
            'The MNIST dataset does not exist. Please download the mnist dataset using download_mnist.py (e.g. python3 download_mnist.py)'
        )
        sys.exit(0)
    return dirpath


def load_dataset():
    train_x_path = '/tmp/train-images-idx3-ubyte.gz'
    train_y_path = '/tmp/train-labels-idx1-ubyte.gz'
    valid_x_path = '/tmp/t10k-images-idx3-ubyte.gz'
    valid_y_path = '/tmp/t10k-labels-idx1-ubyte.gz'

    train_x = read_image_file(check_dataset_exist(train_x_path)).astype(
        np.float32)
    train_y = read_label_file(check_dataset_exist(train_y_path)).astype(
        np.float32)
    valid_x = read_image_file(check_dataset_exist(valid_x_path)).astype(
        np.float32)
    valid_y = read_label_file(check_dataset_exist(valid_y_path)).astype(
        np.float32)
    return train_x, train_y, valid_x, valid_y


def read_label_file(path):
    with gzip.open(path, 'rb') as f:
        data = f.read()
        assert get_int(data[:4]) == 2049  # magic number of IDX label files
        length = get_int(data[4:8])
        parsed = np.frombuffer(data, dtype=np.uint8, offset=8).reshape((length))
    return parsed


# Parse a big-endian integer from the raw bytes of an IDX header field
def get_int(b):
    return int(codecs.encode(b, 'hex'), 16)


def read_image_file(path):
    with gzip.open(path, 'rb') as f:
        data = f.read()
        assert get_int(data[:4]) == 2051  # magic number of IDX image files
        length = get_int(data[4:8])
        num_rows = get_int(data[8:12])
        num_cols = get_int(data[12:16])
        parsed = np.frombuffer(data, dtype=np.uint8, offset=16).reshape(
            (length, 1, num_rows, num_cols))
    return parsed


def to_categorical(y, num_classes):
    y = np.array(y, dtype="int")
    n = y.shape[0]
    categorical = np.zeros((n, num_classes))
    categorical[np.arange(n), y] = 1
    categorical = categorical.astype(np.float32)
    return categorical


# Count the correct predictions; callers divide by the sample count
def accuracy(pred, target):
    y = np.argmax(pred, axis=1)
    t = np.argmax(target, axis=1)
    a = y == t
    return np.array(a, "int").sum()


# Function to all-reduce a NumPy accuracy or loss value across multiple devices
def reduce_variable(variable, dist_opt, reducer):
    reducer.copy_from_numpy(variable)
    dist_opt.all_reduce(reducer.data)
    dist_opt.wait()
    output = tensor.to_numpy(reducer)
    return output


# Function to synchronize the initial model parameters (SINGA tensors)
def synchronize(tensor, dist_opt):
    dist_opt.all_reduce(tensor.data)
    dist_opt.wait()
    tensor /= dist_opt.world_size


# Data augmentation: pad each 28x28 image to 36x36, take a random
# 28x28 crop, and flip it horizontally with probability 0.5
def augmentation(x, batch_size):
    xpad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'symmetric')
    for data_num in range(0, batch_size):
        offset = np.random.randint(8, size=2)
        x[data_num, :, :, :] = xpad[data_num, :, offset[0]:offset[0] + 28,
                                    offset[1]:offset[1] + 28]
        if_flip = np.random.randint(2)
        if (if_flip):
            x[data_num, :, :, :] = x[data_num, :, :, ::-1]
    return x


# Data partition
def data_partition(dataset_x, dataset_y, global_rank, world_size):
    data_per_rank = dataset_x.shape[0] // world_size
    idx_start = global_rank * data_per_rank
    idx_end = (global_rank + 1) * data_per_rank
    return dataset_x[idx_start:idx_end], dataset_y[idx_start:idx_end]
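
# For example, with the 60,000 MNIST training images and world_size=2,
# rank 0 receives rows [0, 30000) and rank 1 rows [30000, 60000);
# any remainder (dataset_x.shape[0] % world_size) is dropped.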


def train_mnist_cnn(DIST=False,
                    local_rank=None,
                    world_size=None,
                    nccl_id=None,
                    spars=0,
                    topK=False,
                    corr=True):

    # Define the hyperparameters for the mnist_cnn
    max_epoch = 10
    batch_size = 64
    sgd = opt.SGD(lr=0.005, momentum=0.9, weight_decay=1e-5)

    # Prepare the training and validation data
    train_x, train_y, test_x, test_y = load_dataset()
    IMG_SIZE = 28
    num_classes = 10
    train_y = to_categorical(train_y, num_classes)
    test_y = to_categorical(test_y, num_classes)

    # Normalization
    train_x = train_x / 255
    test_x = test_x / 255

    if DIST:
        # For distributed GPU training
        sgd = opt.DistOpt(sgd,
                          nccl_id=nccl_id,
                          local_rank=local_rank,
                          world_size=world_size)
        dev = device.create_cuda_gpu_on(sgd.local_rank)

        # Dataset partition for distributed training
        train_x, train_y = data_partition(train_x, train_y, sgd.global_rank,
                                          sgd.world_size)
        test_x, test_y = data_partition(test_x, test_y, sgd.global_rank,
                                        sgd.world_size)
        world_size = sgd.world_size
    else:
        # For single GPU
        dev = device.create_cuda_gpu()
        world_size = 1

    # Create model
    model = CNN()

    tx = tensor.Tensor((batch_size, 1, IMG_SIZE, IMG_SIZE), dev, tensor.float32)
    ty = tensor.Tensor((batch_size, num_classes), dev, tensor.int32)
    num_train_batch = train_x.shape[0] // batch_size
    num_test_batch = test_x.shape[0] // batch_size
    idx = np.arange(train_x.shape[0], dtype=np.int32)

    if DIST:
        # Synchronize the initial parameters
        autograd.training = True
        x = np.random.randn(batch_size, 1, IMG_SIZE,
                            IMG_SIZE).astype(np.float32)
        y = np.zeros(shape=(batch_size, num_classes), dtype=np.int32)
        tx.copy_from_numpy(x)
        ty.copy_from_numpy(y)
        out = model.forward(tx)
        loss = autograd.softmax_cross_entropy(out, ty)
        for p, g in autograd.backward(loss):
            synchronize(p, sgd)

    # Training and evaluation loop
    for epoch in range(max_epoch):
        start_time = time.time()
        np.random.shuffle(idx)

        if ((DIST == False) or (sgd.global_rank == 0)):
            print('Starting Epoch %d:' % (epoch))

        # Training phase
        autograd.training = True
        train_correct = np.zeros(shape=[1], dtype=np.float32)
        test_correct = np.zeros(shape=[1], dtype=np.float32)
        train_loss = np.zeros(shape=[1], dtype=np.float32)

        for b in range(num_train_batch):
            x = train_x[idx[b * batch_size:(b + 1) * batch_size]]
            x = augmentation(x, batch_size)
            y = train_y[idx[b * batch_size:(b + 1) * batch_size]]
            tx.copy_from_numpy(x)
            ty.copy_from_numpy(y)
            out = model.forward(tx)
            loss = autograd.softmax_cross_entropy(out, ty)
            train_correct += accuracy(tensor.to_numpy(out), y)
            train_loss += tensor.to_numpy(loss)[0]
            if DIST:
                if (spars == 0):
                    sgd.backward_and_update(loss, threshold=50000)
                else:
                    sgd.backward_and_sparse_update(loss,
                                                   spars=spars,
                                                   topK=topK,
                                                   corr=corr)
            else:
                sgd(loss)

        if DIST:
            # Reduce the training accuracy and loss from multiple devices
            reducer = tensor.Tensor((1,), dev, tensor.float32)
            train_correct = reduce_variable(train_correct, sgd, reducer)
            train_loss = reduce_variable(train_loss, sgd, reducer)

        # Output the training loss and accuracy
        if ((DIST == False) or (sgd.global_rank == 0)):
            print('Training loss = %f, training accuracy = %f' %
                  (train_loss, train_correct /
                   (num_train_batch * batch_size * world_size)),
                  flush=True)

        # Evaluation phase
        autograd.training = False
        for b in range(num_test_batch):
            x = test_x[b * batch_size:(b + 1) * batch_size]
            y = test_y[b * batch_size:(b + 1) * batch_size]
            tx.copy_from_numpy(x)
            ty.copy_from_numpy(y)
            out_test = model.forward(tx)
            test_correct += accuracy(tensor.to_numpy(out_test), y)

        if DIST:
            # Reduce the evaluation accuracy from multiple devices
            test_correct = reduce_variable(test_correct, sgd, reducer)

        # Output the evaluation accuracy
        if ((DIST == False) or (sgd.global_rank == 0)):
            print('Evaluation accuracy = %f, Elapsed Time = %fs' %
                  (test_correct / (num_test_batch * batch_size * world_size),
                   time.time() - start_time),
                  flush=True)


if __name__ == '__main__':

    DIST = False
    train_mnist_cnn(DIST=DIST)
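
For readers trying the example locally, a minimal usage sketch (assuming the gzipped MNIST files already sit under /tmp, e.g. fetched with the download_mnist.py script that check_dataset_exist() points to, and that a CUDA GPU is available):

# Minimal single-GPU run with the defaults defined above
# (10 epochs, batch size 64, SGD with momentum 0.9).
from mnist_cnn import train_mnist_cnn

train_mnist_cnn(DIST=False)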
25 changes: 25 additions & 0 deletions examples/cnn_ms/autograd/mnist_dist.py
@@ -0,0 +1,25 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from mnist_cnn import *

if __name__ == '__main__':

    DIST = True
    train_mnist_cnn(DIST=DIST)
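
mnist_dist.py reuses train_mnist_cnn from mnist_cnn.py with DIST=True, so every participating process must call it with its own rank and a shared NCCL id. A hedged launcher sketch for two GPUs on a single node, modeled on SINGA's existing multiprocess examples (singa.NcclIdHolder is taken from those examples and is an assumption here, not something this diff adds):

# Hypothetical launcher: one process per GPU, sharing a NCCL id.
# Assumes two CUDA GPUs; singa.NcclIdHolder follows SINGA's
# multiprocess examples and is not defined in this PR.
import multiprocessing

from singa import singa_wrap as singa
from mnist_cnn import train_mnist_cnn

if __name__ == '__main__':
    world_size = 2
    nccl_id = singa.NcclIdHolder()  # shared by all ranks
    procs = []
    for rank in range(world_size):
        p = multiprocessing.Process(target=train_mnist_cnn,
                                    kwargs={'DIST': True,
                                            'local_rank': rank,
                                            'world_size': world_size,
                                            'nccl_id': nccl_id})
        p.start()
        procs.append(p)
    for p in procs:
        p.join()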