-
Notifications
You must be signed in to change notification settings - Fork 2
/
output_projection.py
79 lines (57 loc) · 3.84 KB
/
output_projection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.python.ops import variable_scope
def output_projection_layer(num_units, num_symbols, num_samples=None, name="output_projection"):
    """Build output-projection closures for a decoder RNN with entity copying.

    Args:
        num_units: dimensionality of the decoder hidden state.
        num_symbols: vocabulary size of the output projection.
        num_samples: number of negatives for sampled softmax (only used by
            `sampled_sequence_loss`).
        name: variable-scope name under which the projection weights live.

    Returns:
        A 5-tuple of closures:
        (output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss).
    """

    def output_fn(outputs):
        """Project decoder hidden states to vocabulary logits."""
        return layers.linear(outputs, num_symbols, scope=name)

    def selector_fn(outputs):
        """Per-step scalar gate in (0, 1): probability of copying an entity."""
        selector = tf.sigmoid(layers.linear(outputs, 1, scope='selector'))
        return selector

    def sampled_sequence_loss(outputs, targets, masks):
        """Masked sampled-softmax loss, averaged over unmasked positions.

        Reuses the projection variables created by `output_fn` by opening the
        same 'decoder_rnn/<name>' scope and fetching them via get_variable.
        """
        with variable_scope.variable_scope('decoder_rnn/%s' % name):
            # sampled_softmax_loss expects weights of shape [num_classes, dim],
            # but the linear layer stores [dim, num_classes] -> transpose.
            weights = tf.transpose(tf.get_variable("weights", [num_units, num_symbols]))
            bias = tf.get_variable("biases", [num_symbols])

            local_labels = tf.reshape(targets, [-1, 1])
            local_outputs = tf.reshape(outputs, [-1, num_units])
            local_masks = tf.reshape(masks, [-1])

            local_loss = tf.nn.sampled_softmax_loss(weights, bias, local_labels,
                                                    local_outputs, num_samples, num_symbols)
            local_loss = local_loss * local_masks

            loss = tf.reduce_sum(local_loss)
            total_size = tf.reduce_sum(local_masks)
            total_size += 1e-12  # avoid division by zero when all masks are 0
            return loss / total_size

    def sequence_loss(outputs, targets, masks):
        """Masked full-softmax cross-entropy, averaged over unmasked positions."""
        with variable_scope.variable_scope('decoder_rnn'):
            logits = layers.linear(outputs, num_symbols, scope=name)
            logits = tf.reshape(logits, [-1, num_symbols])
            local_labels = tf.reshape(targets, [-1])
            local_masks = tf.reshape(masks, [-1])

            local_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=local_labels, logits=logits)
            local_loss = local_loss * local_masks

            loss = tf.reduce_sum(local_loss)
            total_size = tf.reduce_sum(local_masks)
            total_size += 1e-12  # avoid division by zero when all masks are 0
            return loss / total_size

    def total_loss(outputs, targets, masks, alignments, triples_embedding,
                   use_entities, entity_targets):
        """Combined word/entity loss for the entity-copying decoder.

        `triples_embedding` is unused here but kept for caller compatibility.
        Assumes `outputs` is (batch, time, num_units) and `masks`,
        `use_entities` are float (batch, time) — TODO confirm against caller.

        Returns:
            loss: training objective = normalized ppx_loss + 0.1 * selector_loss.
            ppx_loss: teacher-forced negative log-likelihood per unmasked step.
            sentence_ppx: per-sentence average NLL under the learned gate.
        """
        # alpha / beta: counts of entity vs. non-entity steps; alpha_mask
        # re-weights the selector loss so both classes contribute equally.
        alpha = tf.maximum(tf.reduce_sum(use_entities), 1e-12)
        beta = tf.maximum(tf.reduce_sum(masks) - alpha, 1e-12)
        alpha_mask = (masks - use_entities) / beta + use_entities / alpha

        batch_size = tf.shape(outputs)[0]
        local_masks = tf.reshape(masks, [-1])

        logits = layers.linear(outputs, num_symbols, scope='decoder_rnn/%s' % name)
        one_hot_targets = tf.one_hot(targets, num_symbols)
        selector = tf.squeeze(tf.sigmoid(layers.linear(outputs, 1, scope='decoder_rnn/selector')))

        # Probability assigned to the gold word / gold entity at each step.
        word_prob = tf.reduce_sum(tf.nn.softmax(logits) * one_hot_targets, axis=2)
        triple_prob = tf.reduce_sum(alignments * entity_targets, axis=2)

        # ppx_prob: oracle-gated mixture (uses the gold use_entities labels);
        # final_prob: mixture under the learned selector gate.
        ppx_prob = word_prob * (1 - use_entities) + triple_prob * use_entities
        final_prob = word_prob * (1 - selector) + triple_prob * selector

        # (A dead `final_loss` sum over final_prob existed here; it was never
        # returned or fetched, so it has been removed.)
        ppx_loss = tf.reduce_sum(
            tf.reshape(-tf.log(tf.maximum(1e-12, ppx_prob)), [-1]) * local_masks)
        sentence_ppx = tf.reduce_sum(
            tf.reshape(tf.reshape(-tf.log(tf.maximum(1e-12, final_prob)), [-1]) * local_masks,
                       [batch_size, -1]), axis=1)
        # Binary cross-entropy pushing the selector towards the oracle gate,
        # class-balanced via alpha_mask.
        selector_loss = tf.reduce_sum(
            tf.reshape(-alpha_mask * tf.log(tf.maximum(
                1e-12, selector * use_entities + (1 - selector) * (1 - use_entities))), [-1])
            * local_masks)

        total_size = tf.maximum(tf.reduce_sum(local_masks), 1e-12)
        # NOTE(review): selector_loss is deliberately NOT divided by total_size
        # (it is already class-normalized by alpha_mask) — confirm intended.
        loss = ppx_loss / total_size + 0.1 * selector_loss
        ppx_loss = ppx_loss / total_size
        return loss, ppx_loss, sentence_ppx / tf.maximum(tf.reduce_sum(masks, axis=1), 1e-12)

    return output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss