Implementation of the Deep Residual U-Net (ResUNet) for semantic segmentation using TensorFlow 2.0.
To install, run the following command from your target Python virtual environment:
python -m pip install git+https://github.com/JanPalasek/resunet-tensorflow
After installing the package into your project, you can import and instantiate the model. The model instance is a standard TensorFlow 2.0 (Keras) model, so it supports the usual compile, fit, evaluate and predict methods.
from resunet import ResUNet

# create a model for inputs of shape (128, 128, 1) for semantic segmentation into 2 classes
# the architecture will have 16 filters in the root and a depth of 3 blocks
model = ResUNet(input_shape=(128, 128, 1), classes=2, filters_root=16, depth=3)

# compile the model
# categorical cross-entropy is the preferred loss function (it expects one-hot encoded labels)
model.compile(loss="categorical_crossentropy", optimizer="adam",
              metrics=["categorical_accuracy"])  # append any other metrics you need

# obtain the dataset
x, y, validation_dataset = ...

# use model.fit, model.evaluate, etc. as with any other TF2 model
# epochs and batch_size below are example values; tune them for your data
model.fit(x=x, y=y, validation_data=validation_dataset, epochs=10, batch_size=32)
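The snippet above leaves the dataset elided. Because categorical cross-entropy expects one-hot targets, the arrays passed to `fit` need one channel per class. The sketch below shows one way to build such arrays with NumPy for the `(128, 128, 1)` input and 2 classes used above. The helpers `to_one_hot` and `to_label_mask` are hypothetical names, and the random images and masks are synthetic stand-ins, not part of this repository.

```python
import numpy as np


def to_one_hot(masks: np.ndarray, classes: int) -> np.ndarray:
    """Convert integer masks (N, H, W) into one-hot targets (N, H, W, classes)."""
    return np.eye(classes, dtype=np.float32)[masks]


def to_label_mask(scores: np.ndarray) -> np.ndarray:
    """Convert per-pixel class scores (N, H, W, classes) back into integer masks (N, H, W)."""
    return np.argmax(scores, axis=-1)


# Synthetic stand-in data (hypothetical); replace it with your own images and masks.
rng = np.random.default_rng(0)
images = rng.random((8, 128, 128, 1), dtype=np.float32)  # grayscale images scaled to [0, 1)
masks = rng.integers(0, 2, size=(8, 128, 128))           # integer class labels in {0, 1}
targets = to_one_hot(masks, classes=2)                   # shape (8, 128, 128, 2)

# Keep the first 6 samples for training and pass the last 2 to fit() as an (x, y) tuple.
x, y = images[:6], targets[:6]
validation_dataset = (images[6:], targets[6:])

print(x.shape, y.shape)  # (6, 128, 128, 1) (6, 128, 128, 2)
```

These `x`, `y` and `validation_dataset` values can be passed unchanged to the `model.fit` call above. After training, `to_label_mask(model.predict(x))` would turn the per-pixel outputs into a label mask, assuming the model's last axis holds one score per class.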
Alternatively, you can use the repository's train.py script, although this is not recommended.
Requirements:
- TensorFlow 2.0 or higher
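To confirm that an environment satisfies this requirement before installing, a small check on the version string is enough. The sketch below is a hypothetical helper (`meets_minimum_version` is not part of this package). It only compares the major and minor numbers and ignores suffixes such as `rc1`.

```python
import re


def meets_minimum_version(version: str, minimum: tuple[int, int] = (2, 0)) -> bool:
    """Return True if a version string such as "2.15.0" is at least `minimum` (major, minor)."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    # Compare (major, minor) tuples lexicographically; patch and pre-release parts are ignored.
    return (int(match.group(1)), int(match.group(2))) >= minimum


print(meets_minimum_version("1.15.0"))  # False
print(meets_minimum_version("2.0.0"))   # True
```

In a real environment it would be called as `meets_minimum_version(tf.__version__)` after `import tensorflow as tf`.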