Simplest implementation of ResNet in Keras for R

GreenEric/resnet-rkeras

 
 


ResNet in Keras for R

This is the simplest implementation of ResNet in Keras for R you can think of. It's quite short and limited for now, but I'll try to add more features in the future. It's also missing some auxiliary functions I was using to plot confidence intervals and similar; I'll upload a Jupyter notebook soon.

The implementation is based on this one, written in Lua with the Torch framework. It also implements a small tweak: the ReLU activation applied after the addition at the end of each residual block is removed, as described here.
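The tweak is easiest to see on a toy example. The following is a minimal sketch in plain base R rather than Keras. It uses dense weight matrices in place of the convolution and batch-normalization layers a real ResNet block contains. All function names are hypothetical and not taken from resnet.R; the sketch only shows where the final ReLU sits.

```r
# Toy residual blocks in plain base R (no Keras): dense weight matrices stand
# in for the convolution + batch-normalization layers of a real ResNet block.
# All names here are hypothetical, for illustration only.
relu <- function(z) pmax(z, 0)

# Residual branch F(x) = W2 %*% relu(W1 %*% x)
residual_branch <- function(x, W1, W2) {
  W2 %*% relu(W1 %*% x)
}

# Original ResNet block: ReLU applied after the addition
block_with_final_relu <- function(x, W1, W2) {
  relu(x + residual_branch(x, W1, W2))
}

# Tweaked block: identity path plus F(x), with no ReLU at the end
block_without_final_relu <- function(x, W1, W2) {
  x + residual_branch(x, W1, W2)
}

x <- matrix(c(-1, 2, -3), ncol = 1)  # a 3-dimensional input column
W1 <- diag(3)                        # identity weights keep the arithmetic easy to follow
W2 <- diag(3)

print(as.vector(block_with_final_relu(x, W1, W2)))     # 0 4 0
print(as.vector(block_without_final_relu(x, W1, W2)))  # -1 4 -3
```

With the final ReLU, negative components of the sum x + F(x) are clipped to zero. Without it, the block output can keep negative values, so the identity path passes through unchanged.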

Note that you have to call install_keras() in an R session after installing the environment.

A simple example of how to use the code is shown below.

library(keras)
source('resnet.R')

# Taking a subset of the CIFAR-10 dataset
cifar10 <- dataset_cifar10()
cifar10.orig <- cifar10

x_train <- cifar10$train$x[1:100,,,]
y_train <- cifar10$train$y[1:100,]
x_test <- cifar10$test$x[1:10,,,]
y_test <- cifar10$test$y[1:10,]

# Integer labels (0-9), used by the cross-validation helper
y_tags <- y_train
# num_classes is fixed to 10 because a small subset may not contain every
# class, while the network always outputs 10 scores
y_train <- to_categorical(y_train, num_classes = 10)
y_test <- to_categorical(y_test, num_classes = 10)

model <- build_resnet_cifar10(20)

# Doing cross validation (it concatenates all the results)
model.cv <- do.cross.validation.resnet(20,
    x_train, y_train, batch_size=5,
    epochs=10, y_tags=y_tags, k=5,
    loss='categorical_crossentropy',
    metrics=c('accuracy')
  )

# Compiling and training the model
model %>%
  compile(
    # `lr` and `decay` are argument names from older keras versions; newer
    # releases renamed `lr` to `learning_rate` and may reject `decay`
    optimizer=optimizer_sgd(lr=0.1, momentum=0.9, decay=0.0001),
    loss='categorical_crossentropy', metrics=c('accuracy')
  ) %>%
  fit(
    x_train, y_train, validation_split=0.2,
    verbose=0, batch_size=5, epochs=10,
    # With patience=10 and only 10 epochs, this callback cannot fire in this run
    callbacks = list(callback_reduce_lr_on_plateau(verbose=0, patience=10, factor=0.1))
  )

# Getting and printing the predictions
# (max.col() returns the 1-based column of the highest score, i.e. label + 1)
predictions <- predict(model, x_test)
print(paste('Predictions:', paste0(max.col(predictions), collapse=' ')))
print(paste('Real values:', paste0(max.col(y_test), collapse=' ')))
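The do.cross.validation.resnet() helper lives in resnet.R, and its splitting logic is not shown here. Because the helper receives the integer labels through y_tags, one plausible use is stratified k-fold splitting, where every fold gets roughly the same class mix. The following base-R sketch illustrates that idea under this assumption only. make_stratified_folds() and toy_tags are hypothetical names, not part of the repository.

```r
# Hypothetical helper (not from resnet.R): assigns each sample a fold number
# in 1..k so that every class in `tags` is spread evenly across the folds.
make_stratified_folds <- function(tags, k = 5, seed = 42) {
  set.seed(seed)  # note: this resets the global random number generator
  folds <- integer(length(tags))
  for (cls in unique(tags)) {
    idx <- which(tags == cls)
    idx <- idx[sample.int(length(idx))]             # shuffle within the class
    folds[idx] <- rep_len(seq_len(k), length(idx))  # deal out round-robin
  }
  folds
}

# Toy labels standing in for y_tags: 10 samples of each class 0-9
toy_tags <- rep(0:9, each = 10)
folds <- make_stratified_folds(toy_tags, k = 5)
print(table(folds))  # 20 samples per fold, 2 of each class

for (i in seq_len(5)) {
  val_idx <- which(folds == i)    # held-out fold
  train_idx <- which(folds != i)  # remaining k - 1 folds
  # ...train on x[train_idx,,,] and evaluate on x[val_idx,,,] here
}
```

Dealing indices round-robin within each class keeps the folds balanced even on a 100-image subset. A purely random split of that size could easily leave a fold with very few examples of some class.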
