Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeretsi authored May 23, 2022
1 parent ba2bcb7 commit 19a1fbc
Show file tree
Hide file tree
Showing 10 changed files with 937 additions and 0 deletions.
59 changes: 59 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
############ config #############

import os

# Output directories. They are created at import time so that later code can
# write into them without checking first.
dataset_path = './saved_datasets'  # path to save the processed dataset for quick loading
# makedirs(exist_ok=True) is idempotent: no check-then-create race, and it
# also works when the configured path is nested.
os.makedirs(dataset_path, exist_ok=True)
save_path = './saved_models/'  # path for saving models
os.makedirs(save_path, exist_ok=True)

######################################

# Dataset selector; the stop-word loading at the end of this file only
# handles the value 'IAM'.
dataset = 'IAM'

# NOTE(review): presumably the (height, width) canvas every word image is
# padded to by the loader - confirm against the training script.
fixed_size = (64, 256)
max_epochs = 240
batch_size = 64
# NOTE(review): presumably the logging interval of the training loop - confirm.
display = 100

# ctc auxiliary loss
ctc_aux = True

######################################

# Character set. '*' is the replacement `reduced` emits for every character
# it does not keep.
# NOTE(review): '_' at index 0 looks like the CTC blank - confirm.
classes = '_*0123456789abcdefghijklmnopqrstuvwxyz '

# index -> character and character -> index lookup tables over `classes`
icdict = dict(enumerate(classes))
cdict = {char: index for index, char in icdict.items()}

def reduced(istr):
    """Normalize a transcription to the character set of ``classes``.

    The input is lower-cased. ASCII letters, ASCII digits, '_' and ' ' are
    kept; every other character is replaced by the wildcard '*'.

    Args:
        istr: the raw transcription string.

    Returns:
        The normalized string (same number of characters as
        ``istr.lower()``).
    """
    # str.isalnum() alone is also true for non-ASCII letters and digits such
    # as 'é' or '²', which have no entry in `classes`; the isascii() guard
    # maps them to '*' like any other unsupported character.
    return ''.join(
        c if (c.isascii() and c.isalnum()) or c in '_ ' else '*'
        for c in istr.lower()
    )

######################################

from utils.phoc import build_phoc_descriptor
# levels handed to build_phoc_descriptor
levels = [1, 2, 3, 4]


def phoc(x):
    """Forward ``x`` to ``build_phoc_descriptor`` with this module's ``classes`` and ``levels``."""
    return build_phoc_descriptor(x, classes, levels)


# sum(levels) * len(classes)
# NOTE(review): presumably the length of one descriptor returned by
# build_phoc_descriptor - confirm in utils/phoc.py.
phoc_size = len(classes) * sum(levels)

######################################

# architecture configuration
# An (n, c) tuple adds n residual blocks with c output channels; 'M' adds a
# 2x2 max-pooling layer (see models.CNN).
cnn_cfg = [
    (2, 64),
    'M',
    (4, 128),
    'M',
    (4, 256),
]


######################################

# NOTE(review): presumably the checkpoint names to load from / save to -
# confirm against the training script.
load_code = 'phoc_kws_a_iam'
name_code = '_'.join(('phoc_kws_a', 'iam'))

######################################
from iam_config import stopwords_path
# Stop words: only the first line of the file is used; it is split on commas.
stopwords = []
if dataset == 'IAM':
    # `with` closes the handle, and readline() avoids reading lines that were
    # never used.
    with open(stopwords_path) as stopwords_file:
        first_line = stopwords_file.readline().strip()
    # An empty or blank first line yields an empty list instead of an
    # IndexError (empty file) or a list holding one empty string.
    stopwords = first_line.split(',') if first_line else []
13 changes: 13 additions & 0 deletions iam_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# iam configuration

# Root of the local IAM copy. Every path below is derived from it, so moving
# the dataset means editing this one line. The leading underscore keeps it out
# of `from iam_config import *`.
_iam_root = '/media/ncsr/bee4cbda-e313-4acf-9bc8-817a69ad98ae/IAM'

# split definition files
trainset_file = _iam_root + '/set_split/trainset.txt'
testset_file = _iam_root + '/set_split/testset.txt'
valset_file = _iam_root + '/set_split/validationset1.txt'

# Ground-truth listings. line_file / line_path have to be defined:
# iam_loader.gather_iam_info reads them for level='line', which is the default
# level of IAMLoader, and would otherwise fail with a NameError.
line_file = _iam_root + '/ascii/lines.txt'
word_file = _iam_root + '/ascii/words.txt'

# image roots
word_path = _iam_root + '/words'
line_path = _iam_root + '/lines'

stopwords_path = _iam_root + '/iam-stopwords'
132 changes: 132 additions & 0 deletions iam_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
'''
@author: georgeretsi
'''

import numpy as np
from skimage import io as img_io
import torch
from torch.utils.data import Dataset

from os.path import isfile

from utils.auxilary_functions import image_resize, centered

from iam_config import *

def gather_iam_info(set='train', level='word'):
    """Collect (image path, transcription) pairs for one IAM split.

    Args:
        set: split name, one of 'train', 'val' or 'test'. The name shadows
            the builtin but is kept because callers pass it by keyword.
        level: segmentation level, 'word' or 'line'.

    Returns:
        A list of ``(img_path, transcr)`` tuples, where ``img_path`` carries
        no file extension, or ``None`` (after printing a message) when
        ``set`` or ``level`` is not one of the accepted values.
    """
    # train/test file
    if set == 'train':
        split_file = trainset_file
    elif set == 'test':
        split_file = testset_file
    elif set == 'val':
        split_file = valset_file
    else:
        print('split name not found. Valid splits: [train, val, test]')
        return None

    if level == 'word':
        gtfile, root_path = word_file, word_path
    elif level == 'line':
        gtfile, root_path = line_file, line_path
    else:
        print('segmentation level not found. Valid segmentation level: [word, line]')
        return None

    # A hashed container makes the per-line membership test O(1) instead of a
    # scan over the NumPy array. frozenset is used because the builtin `set`
    # is shadowed by the parameter; atleast_1d covers a one-entry split file,
    # for which loadtxt returns a 0-d array.
    valid_ids = frozenset(np.atleast_1d(np.loadtxt(split_file, dtype=str)).ravel().tolist())

    gt = []
    with open(gtfile) as gt_lines:  # closed even if parsing raises
        for line in gt_lines:
            if line.startswith("#"):
                continue

            info = line.strip().split()
            if len(info) < 2:
                # blank or truncated line: there is no id/status pair to inspect
                continue

            # Path components are the cumulative '-'-joined prefixes of the id,
            # e.g. a-b-c -> root/a/a-b/a-b-c.
            name_parts = info[0].split('-')
            pathlist = [root_path] + ['-'.join(name_parts[:i + 1]) for i in range(len(name_parts))]
            if level == 'word':
                # The second-to-last prefix is what the split file is matched
                # against; it is not a directory level, so it is dropped from
                # the path.
                tname = pathlist[-2]
                del pathlist[-2]
            else:
                tname = pathlist[-1]

            # keep only entries whose second field is 'ok' and that belong to the split
            if info[1] != 'ok' or tname not in valid_ids:
                continue

            # the fields from the ninth onward form the transcription
            gt.append(('/'.join(pathlist), ' '.join(info[8:])))

    return gt

def main_loader(set, level):
    """Read and preprocess every image of one IAM split.

    For each ``(img_path, transcr)`` pair from ``gather_iam_info(set, level)``
    the file ``img_path + '.png'`` is read, converted to float32 as
    ``1 - pixel / 255`` and passed to ``image_resize`` with half its original
    height. Every '|' in the transcription is replaced by a space.

    Args:
        set: split name forwarded to ``gather_iam_info``.
        level: segmentation level forwarded to ``gather_iam_info``.

    Returns:
        A list of ``(img, transcr)`` tuples. Images that cannot be read or
        resized are skipped and reported on stdout.

    Raises:
        ValueError: if ``gather_iam_info`` rejects ``set`` or ``level``.
    """
    info = gather_iam_info(set, level)
    if info is None:
        # gather_iam_info has already printed which argument was wrong; fail
        # here rather than with an opaque "NoneType is not iterable".
        raise ValueError('unknown split {!r} or segmentation level {!r}'.format(set, level))

    total = len(info)
    data = []
    skipped = 0
    for i, (img_path, transcr) in enumerate(info):

        if i % 1000 == 0:
            print('imgs: [{}/{} ({:.0f}%)]'.format(i, total, 100. * i / total))

        try:
            img = img_io.imread(img_path + '.png')
            img = 1 - img.astype(np.float32) / 255.0
            img = image_resize(img, height=img.shape[0] // 2)
        except Exception as err:
            # Best-effort loading is kept (a bad image must not abort the whole
            # split), but the skip is reported, and KeyboardInterrupt /
            # SystemExit are no longer swallowed as they were by a bare except.
            skipped += 1
            print('skipping {}.png: {!r}'.format(img_path, err))
            continue

        data.append((img, transcr.replace("|", " ")))

    if skipped:
        print('skipped {} of {} images'.format(skipped, total))

    return data

class IAMLoader(Dataset):
    """Torch dataset over preprocessed IAM images and their transcriptions.

    The first instantiation for a given split/level builds the data with
    ``main_loader`` and caches it at ``<dataset_path>/<set>_<level>.pt``;
    later instantiations load that cache.
    """

    def __init__(self, dataset_path, set, level='line', fixed_size=(128, None), transforms=None):
        """
        Args:
            dataset_path: directory holding the cached ``.pt`` files.
            set: split name; 'train' additionally enables the random resize
                in ``__getitem__``.
            level: segmentation level forwarded to ``main_loader``.
            fixed_size: ``(height, width)`` of the output canvas; a ``None``
                entry leaves that dimension at the size of the resized image.
            transforms: optional iterable of callables ``img -> img``; each
                one is applied with probability 0.5 per item.
        """
        self.transforms = transforms
        self.set = set
        self.fixed_size = fixed_size

        save_file = dataset_path + '/' + set + '_' + level + '.pt'

        if not isfile(save_file):
            data = main_loader(set=set, level=level)
            torch.save(data, save_file)
        else:
            # NOTE(review): torch.load unpickles the file - only load caches
            # that this code wrote itself.
            data = torch.load(save_file)

        self.data = data

    def __len__(self):
        """Return the number of cached (image, transcription) pairs."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(img, transcr)``: a float tensor with a leading singleton dim and its transcription."""
        img = self.data[index][0]
        transcr = self.data[index][1]
        fheight, fwidth = self.fixed_size[0], self.fixed_size[1]

        if self.set == 'train':
            # random resize at training !!!
            # (width factor is drawn first, then the aspect-ratio jitter)
            nwidth = int(np.random.uniform(.5, 1.5) * img.shape[1])
            nheight = int((np.random.uniform(.8, 1.2) * img.shape[0] / img.shape[1]) * nwidth)
        else:
            nheight, nwidth = img.shape[0], img.shape[1]

        # Keep the image inside the canvas with a 16 / 32 pixel margin. A None
        # canvas dimension imposes no upper bound; subtracting from None used
        # to raise a TypeError with the default fixed_size.
        if fheight is not None:
            nheight = min(fheight - 16, nheight)
        if fwidth is not None:
            nwidth = min(fwidth - 32, nwidth)
        nheight, nwidth = max(4, nheight), max(8, nwidth)
        img = image_resize(img, height=int(nheight), width=int(nwidth))

        # NOTE(review): for a None dimension the canvas takes the resized size,
        # on the assumption that `centered` then leaves that axis unchanged -
        # confirm in utils/auxilary_functions.py.
        canvas = (
            nheight if fheight is None else fheight,
            nwidth if fwidth is None else fwidth,
        )
        img = centered(img, canvas, border_value=None)
        if self.transforms is not None:
            for tr in self.transforms:
                if np.random.rand() < .5:
                    img = tr(img)

        img = torch.Tensor(img).float().unsqueeze(0)

        return img, transcr
115 changes: 115 additions & 0 deletions models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import torch.nn as nn
import torch.nn.functional as F
import torch

class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a shortcut branch.

    The shortcut is an empty ``nn.Sequential`` (identity) unless the stride
    or the channel count changes, in which case it is a strided 1x1
    convolution followed by batch-norm.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            # shapes already agree: identity shortcut
            self.shortcut = nn.Sequential()
        else:
            # project the input so it can be added to the main branch
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        summed = self.bn2(self.conv2(hidden))
        summed += self.shortcut(x)
        return F.relu(summed)


class CNN(nn.Module):
    """Convolutional backbone assembled from a configuration list.

    The stem is a 7x7, stride-2 convolution (1 -> 32 channels) followed by
    ReLU. Each ``'M'`` entry of ``cnn_cfg`` then adds a 2x2 max-pool, and each
    ``(n, c)`` entry adds ``n`` ``BasicBlock`` layers with ``c`` output
    channels. ``forward`` finishes with a max over the full feature-map
    height, so the output has height 1.
    """

    def __init__(self, cnn_cfg):
        super().__init__()

        self.features = nn.ModuleList([nn.Conv2d(1, 32, 7, 2, 3), nn.ReLU()])
        channels = 32
        pool_idx = 0
        block_idx = 1
        for entry in cnn_cfg:
            if entry == 'M':
                # module names are part of the state_dict keys
                self.features.add_module(f'mxp{pool_idx}', nn.MaxPool2d(kernel_size=2, stride=2))
                pool_idx += 1
                continue
            for _ in range(entry[0]):
                width = entry[1]
                self.features.add_module(f'cnv{block_idx}', BasicBlock(channels, width))
                channels = width
                block_idx += 1

    def forward(self, x):
        """Run all feature layers, then max-pool each column over the whole height."""
        out = x
        for layer in self.features:
            out = layer(out)

        height = out.size(2)
        return F.max_pool2d(out, [height, 1], stride=[height, 1])


class KWSNet(nn.Module):
    """CNN backbone with a per-column classification head and an embedding branch.

    * ``termporal`` (sic) maps every feature column to ``nclasses`` scores
      (presumably the input of the auxiliary CTC loss - confirm in the
      training script).
    * ``enc`` downsamples the feature columns by 8 with three stride-2
      convolutions; ``fnl`` maps the flattened result to ``phoc_size``
      outputs.

    ``fnl`` expects ``4 * hidden_size`` input features, i.e. exactly four
    columns must remain after ``enc``. That is the case for 256-pixel-wide
    inputs with a ``cnn_cfg`` that contains two 'M' stages.
    """

    def __init__(self, cnn_cfg, phoc_size, nclasses=None):
        """
        Args:
            cnn_cfg: backbone configuration; its last entry must be an
                ``(n, c)`` tuple, ``c`` being used as the hidden size.
            phoc_size: number of outputs of the final linear layer.
            nclasses: number of per-column classes. ``None`` (the default)
                builds the network without the per-column head; the first
                element returned by ``forward`` is then ``None``.
        """
        super().__init__()

        self.cnn = CNN(cnn_cfg)

        hidden_size = cnn_cfg[-1][-1]

        # The misspelled attribute name is kept on purpose: it is part of the
        # state_dict keys of existing checkpoints.
        if nclasses is None:
            # The default used to crash inside nn.Conv2d; treat it as "no head".
            self.termporal = None
        else:
            self.termporal = nn.Sequential(
                nn.Conv2d(hidden_size, nclasses, kernel_size=(1, 5), stride=(1, 1), padding=(0, 2)),
            )

        self.enc = nn.Sequential(
            nn.Conv2d(hidden_size, hidden_size, kernel_size=(1, 5), stride=(1, 2), padding=(0, 2)),
            nn.BatchNorm2d(hidden_size), nn.ReLU(), nn.Dropout(.1),
            nn.Conv2d(hidden_size, hidden_size, kernel_size=(1, 5), stride=(1, 2), padding=(0, 2)),
            nn.BatchNorm2d(hidden_size), nn.ReLU(), nn.Dropout(.1),
            nn.Conv2d(hidden_size, hidden_size, kernel_size=(1, 5), stride=(1, 2), padding=(0, 2)),
            nn.BatchNorm2d(hidden_size),
        )

        self.fnl = nn.Sequential(
            nn.ReLU(), nn.Dropout(.1),
            nn.Linear(4 * hidden_size, 4 * hidden_size), nn.ReLU(), nn.Dropout(.1),
            nn.Linear(4 * hidden_size, phoc_size)
        )

    def forward(self, x):
        """Return ``(y_ctc, y_out)`` for an input batch ``x``.

        ``y_ctc`` has shape ``(columns, batch, nclasses)``, or is ``None``
        when the network was built without ``nclasses``. ``y_out`` is
        ``fnl(features)`` in training mode and the flattened ``enc`` features
        themselves in eval mode.
        """
        y = self.cnn(x)

        y_ctc = None
        if self.termporal is not None:
            # (batch, classes, 1, columns) -> (columns, batch, classes)
            y_ctc = self.termporal(y).permute(2, 3, 0, 1)[0]

        y_feat = self.enc(y).view(x.size(0), -1)

        if self.training:
            return y_ctc, self.fnl(y_feat)
        return y_ctc, y_feat

    def load_my_state_dict(self, state_dict):
        """Copy matching entries of ``state_dict`` into this model, best effort.

        Entries whose name is unknown to this model are ignored; entries that
        cannot be copied (e.g. a shape mismatch) are reported on stdout and
        skipped.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            if isinstance(param, nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                # Still best-effort, but a bare except would also swallow
                # KeyboardInterrupt / SystemExit.
                print("parameter missmatch @ " + name)
13 changes: 13 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
imageio==2.16.1
networkx==2.6.3
numpy==1.21.5
opencv-python==4.5.5.62
packaging==21.3
Pillow==9.0.1
pyparsing==3.0.7
PyWavelets==1.2.0
scikit-image==0.19.2
scipy==1.7.3
tifffile==2021.11.2
torch==1.10.2
typing-extensions==4.1.1
Loading

0 comments on commit 19a1fbc

Please sign in to comment.