-
Notifications
You must be signed in to change notification settings - Fork 4
/
cnn.py
15 lines (15 loc) · 913 Bytes
/
cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class CNN(Sequential):
    """Compiled three-stage convolutional classifier.

    Layer stack:
      3 x [Conv2D(relu, padding='same') -> MaxPooling2D(pool 2x2, stride 2)]
      -> Flatten -> Dropout(0.3) -> Dense(512, relu) -> Dense(num_classes, softmax)

    The model is compiled in the constructor with the Adam optimizer,
    categorical cross-entropy loss (which expects one-hot encoded labels)
    and an accuracy metric, so it can be passed straight to ``fit``.
    """

    def __init__(self, nfilters, sfilters, input_shape=(112, 92, 1),
                 num_classes=40, seed=0):
        """Build and compile the network.

        Args:
            nfilters: Sequence of at least 3 ints giving the number of
                filters of the three Conv2D layers. Entries past the third
                are ignored.
            sfilters: Sequence of at least 3 ints giving the side length of
                the square kernel of each of the three Conv2D layers.
                Entries past the third are ignored.
            input_shape: Shape of a single input sample, without the batch
                axis. Defaults to ``(112, 92, 1)``, the value this class
                previously hard-coded.
            num_classes: Number of units in the final softmax layer
                (default 40, the previously hard-coded value).
            seed: If not ``None``, passed to ``tensorflow.random.set_seed``
                before any layer is created, so weight initialisation is
                reproducible. Note that this sets the process-global
                TensorFlow seed; pass ``None`` to leave it untouched.

        Raises:
            ValueError: If ``nfilters`` or ``sfilters`` has fewer than 3
                entries.
        """
        n_blocks = 3
        if len(nfilters) < n_blocks or len(sfilters) < n_blocks:
            raise ValueError(
                f"nfilters and sfilters need at least {n_blocks} entries, "
                f"got {len(nfilters)} and {len(sfilters)}"
            )

        super().__init__()
        if seed is not None:
            # Seed first so the random ops behind the layers' initial
            # weights (created below) are reproducible across runs.
            tensorflow.random.set_seed(seed)

        conv_specs = zip(nfilters[:n_blocks], sfilters[:n_blocks])
        for i, (n_filters, kernel_side) in enumerate(conv_specs):
            # Only the first layer declares the input shape; Keras infers
            # the shapes of every later layer from it.
            shape_kwargs = {"input_shape": tuple(input_shape)} if i == 0 else {}
            self.add(Conv2D(n_filters,
                            kernel_size=(kernel_side, kernel_side),
                            padding='same',
                            activation='relu',
                            **shape_kwargs))
            # Halve the spatial resolution after every convolution.
            self.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))

        self.add(Flatten())
        self.add(Dropout(0.3))
        self.add(Dense(512, activation='relu'))
        self.add(Dense(num_classes, activation='softmax'))
        self.compile(optimizer='adam',
                     loss='categorical_crossentropy',
                     metrics=['accuracy'])