-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ba2bcb7
commit 19a1fbc
Showing
10 changed files
with
937 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
############ config ############# | ||
|
||
import os | ||
|
||
dataset_path = './saved_datasets' # path to save the process dataset for quick loading | ||
if not os.path.isdir(dataset_path): | ||
os.mkdir(dataset_path) | ||
save_path = './saved_models/' # path for saving models | ||
if not os.path.isdir(save_path): | ||
os.mkdir(save_path) | ||
|
||
###################################### | ||
|
||
dataset='IAM' | ||
|
||
fixed_size = (64, 256) | ||
max_epochs = 240 | ||
batch_size = 64 | ||
display=100 | ||
|
||
# ctc auxiliary loss | ||
ctc_aux = True | ||
|
||
###################################### | ||
|
||
classes = '_*0123456789abcdefghijklmnopqrstuvwxyz ' | ||
|
||
cdict = {c:i for i,c in enumerate(classes)} | ||
icdict = {i:c for i,c in enumerate(classes)} | ||
|
||
def reduced(istr): | ||
return ''.join([c if (c.isalnum() or c=='_' or c==' ') else '*' for c in istr.lower()]) | ||
|
||
###################################### | ||
|
||
from utils.phoc import build_phoc_descriptor | ||
levels = [1, 2, 3, 4] | ||
phoc = lambda x: build_phoc_descriptor(x, classes, levels) | ||
phoc_size = sum(levels) * len(classes) | ||
|
||
###################################### | ||
|
||
# architecture configuration | ||
cnn_cfg = [(2, 64), 'M', (4, 128), 'M', (4, 256)] | ||
|
||
|
||
###################################### | ||
|
||
load_code = 'phoc_kws_a_iam' | ||
name_code = 'phoc_kws_a' | ||
name_code += '_iam' | ||
|
||
###################################### | ||
from iam_config import stopwords_path | ||
stopwords = [] | ||
if dataset=='IAM': | ||
for line in open(stopwords_path): | ||
stopwords.append(line.strip().split(',')) | ||
stopwords = stopwords[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# iam configuration | ||
|
||
trainset_file = '/media/ncsr/bee4cbda-e313-4acf-9bc8-817a69ad98ae/IAM/set_split/trainset.txt' | ||
testset_file = '/media/ncsr/bee4cbda-e313-4acf-9bc8-817a69ad98ae/IAM/set_split/testset.txt' | ||
valset_file = '/media/ncsr/bee4cbda-e313-4acf-9bc8-817a69ad98ae/IAM/set_split/validationset1.txt' | ||
|
||
#line_file = '/media/ncsr/bee4cbda-e313-4acf-9bc8-817a69ad98ae/IAM/ascii/lines.txt' | ||
word_file = '/media/ncsr/bee4cbda-e313-4acf-9bc8-817a69ad98ae/IAM/ascii/words.txt' | ||
|
||
word_path = '/media/ncsr/bee4cbda-e313-4acf-9bc8-817a69ad98ae/IAM/words' | ||
#line_path = '/media/ncsr/bee4cbda-e313-4acf-9bc8-817a69ad98ae/IAM/lines' | ||
|
||
stopwords_path = '/media/ncsr/bee4cbda-e313-4acf-9bc8-817a69ad98ae/IAM/iam-stopwords' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
''' | ||
@author: georgeretsi | ||
''' | ||
|
||
import numpy as np | ||
from skimage import io as img_io | ||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
from os.path import isfile | ||
|
||
from utils.auxilary_functions import image_resize, centered | ||
|
||
from iam_config import * | ||
|
||
def gather_iam_info(set='train', level='word'): | ||
|
||
# train/test file | ||
if set == 'train': | ||
valid_set = np.loadtxt(trainset_file, dtype=str) | ||
elif set == 'test': | ||
valid_set = np.loadtxt(testset_file, dtype=str) | ||
elif set == 'val': | ||
valid_set = np.loadtxt(valset_file, dtype=str) | ||
else: | ||
print('split name not found. Valid splits: [train, val, test]') | ||
return | ||
|
||
|
||
if level == 'word': | ||
gtfile= word_file | ||
root_path = word_path | ||
elif level == 'line': | ||
gtfile = line_file | ||
root_path = line_path | ||
else: | ||
print('segmentation level not found. Valid segmentation level: [word, line[') | ||
return | ||
|
||
gt = [] | ||
for line in open(gtfile): | ||
if not line.startswith("#"): | ||
info = line.strip().split() | ||
name = info[0] | ||
|
||
name_parts = name.split('-') | ||
pathlist = [root_path] + ['-'.join(name_parts[:i+1]) for i in range(len(name_parts))] | ||
if level == 'word': | ||
tname = pathlist[-2] | ||
del pathlist[-2] | ||
elif level == 'line': | ||
tname = pathlist[-1] | ||
|
||
if (info[1] != 'ok') or (tname not in valid_set): | ||
#if (line_name not in valid_set): | ||
continue | ||
|
||
img_path = '/'.join(pathlist) | ||
|
||
transcr = ' '.join(info[8:]) | ||
gt.append((img_path, transcr)) | ||
|
||
return gt | ||
|
||
def main_loader(set, level): | ||
|
||
info = gather_iam_info(set, level) | ||
|
||
data = [] | ||
for i, (img_path, transcr) in enumerate(info): | ||
|
||
if i % 1000 == 0: | ||
print('imgs: [{}/{} ({:.0f}%)]'.format(i, len(info), 100. * i / len(info))) | ||
|
||
try: | ||
img = img_io.imread(img_path + '.png') | ||
img = 1 - img.astype(np.float32) / 255.0 | ||
img = image_resize(img, height=img.shape[0] // 2) | ||
except: | ||
continue | ||
|
||
data += [(img, transcr.replace("|", " "))] | ||
|
||
return data | ||
|
||
class IAMLoader(Dataset): | ||
|
||
def __init__(self, dataset_path, set, level='line', fixed_size=(128, None), transforms=None): | ||
|
||
self.transforms = transforms | ||
self.set = set | ||
self.fixed_size = fixed_size | ||
|
||
save_file = dataset_path + '/' + set + '_' + level + '.pt' | ||
|
||
if isfile(save_file) is False: | ||
data = main_loader(set=set, level=level) | ||
torch.save(data, save_file) | ||
else: | ||
data = torch.load(save_file) | ||
|
||
self.data = data | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
def __getitem__(self, index): | ||
|
||
img = self.data[index][0] | ||
|
||
transcr = self.data[index][1] | ||
fheight, fwidth = self.fixed_size[0], self.fixed_size[1] | ||
|
||
if self.set == 'train': | ||
# random resize at training !!! | ||
nwidth = int(np.random.uniform(.5, 1.5) * img.shape[1]) | ||
nheight = int((np.random.uniform(.8, 1.2) * img.shape[0] / img.shape[1]) * nwidth) | ||
else: | ||
nheight, nwidth = img.shape[0], img.shape[1] | ||
|
||
nheight, nwidth = max(4, min(fheight-16, nheight)), max(8, min(fwidth-32, nwidth)) | ||
img = image_resize(img, height=int(1.0 * nheight), width=int(1.0 * nwidth)) | ||
|
||
img = centered(img, (fheight, fwidth), border_value=None) | ||
if self.transforms is not None: | ||
for tr in self.transforms: | ||
if np.random.rand() < .5: | ||
img = tr(img) | ||
|
||
img = torch.Tensor(img).float().unsqueeze(0) | ||
|
||
return img, transcr |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch | ||
|
||
class BasicBlock(nn.Module): | ||
expansion = 1 | ||
|
||
def __init__(self, in_planes, planes, stride=1): | ||
super(BasicBlock, self).__init__() | ||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) | ||
self.bn1 = nn.BatchNorm2d(planes) | ||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) | ||
self.bn2 = nn.BatchNorm2d(planes) | ||
|
||
self.shortcut = nn.Sequential() | ||
if stride != 1 or in_planes != self.expansion*planes: | ||
self.shortcut = nn.Sequential( | ||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), | ||
nn.BatchNorm2d(self.expansion*planes) | ||
) | ||
|
||
def forward(self, x): | ||
out = F.relu(self.bn1(self.conv1(x))) | ||
out = self.bn2(self.conv2(out)) | ||
out += self.shortcut(x) | ||
out = F.relu(out) | ||
return out | ||
|
||
|
||
class CNN(nn.Module): | ||
def __init__(self, cnn_cfg): | ||
super(CNN, self).__init__() | ||
|
||
self.features = nn.ModuleList([nn.Conv2d(1, 32, 7, 2, 3), nn.ReLU()]) | ||
in_channels = 32 | ||
cntm = 0 | ||
cnt = 1 | ||
for m in cnn_cfg: | ||
if m == 'M': | ||
self.features.add_module('mxp' + str(cntm), nn.MaxPool2d(kernel_size=2, stride=2)) | ||
cntm += 1 | ||
else: | ||
for i in range(m[0]): | ||
x = m[1] | ||
self.features.add_module('cnv' + str(cnt), BasicBlock(in_channels, x)) | ||
in_channels = x | ||
cnt += 1 | ||
|
||
def forward(self, x): | ||
|
||
y = x | ||
for nn_module in self.features: | ||
y = nn_module(y) | ||
|
||
y = F.max_pool2d(y, [y.size(2), 1], stride=[y.size(2), 1]) | ||
|
||
return y | ||
|
||
|
||
class KWSNet(nn.Module): | ||
def __init__(self, cnn_cfg, phoc_size, nclasses=None): | ||
super(KWSNet, self).__init__() | ||
|
||
self.cnn = CNN(cnn_cfg) | ||
|
||
hidden_size = cnn_cfg[-1][-1] | ||
|
||
self.termporal = nn.Sequential( | ||
nn.Conv2d(hidden_size, nclasses, kernel_size=(1, 5), stride=(1, 1), padding=(0, 2)), | ||
) | ||
|
||
hidden_size = cnn_cfg[-1][-1] | ||
|
||
|
||
self.enc = nn.Sequential( | ||
nn.Conv2d(hidden_size, hidden_size, kernel_size=(1, 5), stride=(1, 2), padding=(0, 2)), | ||
nn.BatchNorm2d(hidden_size), nn.ReLU(), nn.Dropout(.1), | ||
nn.Conv2d(hidden_size, hidden_size, kernel_size=(1, 5), stride=(1, 2), padding=(0, 2)), | ||
nn.BatchNorm2d(hidden_size), nn.ReLU(), nn.Dropout(.1), | ||
nn.Conv2d(hidden_size, hidden_size, kernel_size=(1, 5), stride=(1, 2), padding=(0, 2)), | ||
nn.BatchNorm2d(hidden_size), | ||
) | ||
|
||
self.fnl = nn.Sequential( | ||
nn.ReLU(), nn.Dropout(.1), | ||
nn.Linear(4 * hidden_size, 4 * hidden_size), nn.ReLU(), nn.Dropout(.1), | ||
nn.Linear(4 * hidden_size, phoc_size) | ||
) | ||
|
||
|
||
def forward(self, x): | ||
y = self.cnn(x) | ||
|
||
y_ctc = self.termporal(y) | ||
|
||
y_feat = self.enc(y) | ||
|
||
if self.training: | ||
return y_ctc.permute(2, 3, 0, 1)[0], self.fnl(y_feat.view(x.size(0), -1)) | ||
else: | ||
return y_ctc.permute(2, 3, 0, 1)[0], y_feat.view(x.size(0), -1) | ||
|
||
def load_my_state_dict(self, state_dict): | ||
|
||
own_state = self.state_dict() | ||
for name, param in state_dict.items(): | ||
if name not in own_state: | ||
continue | ||
if isinstance(param, nn.Parameter): | ||
# backwards compatibility for serialized parameters | ||
param = param.data | ||
try: | ||
own_state[name].copy_(param) | ||
except: | ||
print("parameter missmatch @ " + name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
imageio==2.16.1 | ||
networkx==2.6.3 | ||
numpy==1.21.5 | ||
opencv-python==4.5.5.62 | ||
packaging==21.3 | ||
Pillow==9.0.1 | ||
pyparsing==3.0.7 | ||
PyWavelets==1.2.0 | ||
scikit-image==0.19.2 | ||
scipy==1.7.3 | ||
tifffile==2021.11.2 | ||
torch==1.10.2 | ||
typing-extensions==4.1.1 |
Oops, something went wrong.