-
Notifications
You must be signed in to change notification settings - Fork 68
/
TestUpd.py
executable file
·73 lines (49 loc) · 2.26 KB
/
TestUpd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import tensorflow as tf
import numpy as np
import time
def omniglot():
    """Demonstrate a per-row element update of a TensorFlow variable.

    Builds a 6x6 float32 variable named ``Matrix`` holding 0..35, a vector of
    indices 0..5 and a vector of values 100..105. It then assigns to the
    matrix the result of ``update_tensor``, which replaces entry ``ix[i]`` of
    row ``i`` with ``val[i]``. Intermediate values are printed and the graph
    is written to ``/tmp/tensorflow``.

    The interactive session and the summary writer are closed before the
    function returns, even if one of the steps raises.
    """
    sess = tf.InteractiveSession()
    writer = None

    def update_tensor(V, dim2, val):
        """Update tensor V, with index (:, dim2[:]) by val[:].

        Args:
            V: tensor/variable with a statically known shape; it is scanned
                along its first axis. Written for a 2-D ``V`` as used below.
            dim2: one position per slice of ``V``; cast to int32 before use.
            val: one replacement value per slice of ``V``; cast to V's dtype.

        Returns:
            The tensor produced by ``tf.scan``: each slice of ``V`` with the
            element at its ``dim2`` position replaced by its ``val`` entry.
        """
        val = tf.cast(val, V.dtype)

        def body(_, elem):
            # The accumulator is ignored: every slice is rebuilt independently.
            v, d2, chg = elem
            d2_int = tf.cast(d2, tf.int32)
            row_len = v.get_shape().as_list()[0]
            patched = tf.concat_v2(
                [v[:d2_int], [chg], v[d2_int + 1:]], axis=0)
            # NOTE(review): the slice to the static length is presumably here
            # to restore a fully defined shape after the dynamic concat.
            return tf.slice(patched, [0], [row_len])

        # The initializer only fixes the shape and dtype of the scan output;
        # its value never reaches the result because body ignores it.
        # Use V's own (non-reference) dtype instead of a fixed float32 so the
        # helper also works for matrices of other dtypes.
        initializer = tf.constant(
            1, shape=V.get_shape().as_list()[1:], dtype=V.dtype.base_dtype)
        return tf.scan(body, elems=(V, dim2, val), initializer=initializer,
                       name="Scan_Update")

    try:
        print('Compiling the Model')
        tt1 = tf.Variable(initial_value=np.arange(0, 36).reshape((6, 6)),
                          dtype=tf.float32, name='Matrix')
        ix = tf.Variable(initial_value=np.arange(0, 6), name='Indices')
        val = tf.Variable(initial_value=np.arange(100, 106), name='Values',
                          dtype=tf.float32)
        # tt1 with an extra row 0..5 inserted after its third row.
        inserted_row = tf.reshape(tf.range(0, 6, dtype=tf.float32),
                                  shape=(1, 6))
        tt = tf.concat_v2([tt1[:3], inserted_row, tt1[3:]], axis=0)
        print(tt1[:3].get_shape().as_list())

        op = tt1.assign(update_tensor(tt1, ix, val))
        # val is rebound only after op was built, so op itself does not go
        # through the Print node; only the val.eval() below does.
        val = tf.Print(val, [val], "This works fine")
        sess.run(tf.global_variables_initializer())

        print('Training the model')
        print(tt.eval())
        writer = tf.summary.FileWriter('/tmp/tensorflow',
                                       graph=tf.get_default_graph())
        # '%s %s' reproduces the single space the Python 2 print statement
        # puts between its items, while also being valid on Python 3.
        print('%s %s' % ('tt1: ', tt1.eval()))
        print('%s %s' % ('ix: ', ix.eval()))
        print('%s %s' % ('val: ', val.eval()))
        sess.run(op)
        # No separator here: the Python 2 print statement emits none after an
        # item that ends with a newline.
        print('After run\n' + str(tt1.eval()))
    finally:
        # Release the event file handle and the session resources.
        if writer is not None:
            writer.close()
        sess.close()
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    omniglot()