-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix bug with dataset iterator processing (#201)
- Fix bug with iterating over datasets whose elements may be tuples. - This happens when a tensor dataset generator outputs tensors in a specific format. - Previously, Ariadne was not inferring tensors that were being picked out of the tuple during dataset iteration. - Added corresponding tests. - Don't consider exceptions as tensor dataflow sources. - This was happening in the old code. I've added it to the new code as well. - Since invocation instructions are processed in multiple places, I've extracted a common method. - The two different call sites differ on the `src` points-to variable to be added. - This guarantees that the order in which the points-to variables are processed doesn't matter.
- Loading branch information
Showing
6 changed files
with
235 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import tensorflow as tf | ||
|
||
|
||
class C: | ||
|
||
def __init__(self, some_iter): | ||
self.some_iter = some_iter | ||
|
||
def __str__(self): | ||
return str(self.some_iter) | ||
|
||
|
||
def add(a, b): | ||
return a + b | ||
|
||
|
||
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) | ||
my_iter = iter(dataset) | ||
c = C(my_iter) | ||
length = len(dataset) | ||
|
||
for _ in range(length): | ||
element = next(c.some_iter) | ||
add(element, element) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import tensorflow as tf | ||
|
||
|
||
def add(a, b): | ||
return a + b | ||
|
||
|
||
def gen_iter(ds): | ||
return iter(ds) | ||
|
||
|
||
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) | ||
|
||
my_iter = gen_iter(dataset) | ||
length = len(dataset) | ||
|
||
for _ in range(length): | ||
element = next(my_iter) | ||
add(element, element) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import tensorflow as tf | ||
|
||
|
||
class C: | ||
|
||
def __init__(self, some_iter): | ||
self.some_iter = some_iter | ||
|
||
def __str__(self): | ||
return str(self.some_iter) | ||
|
||
|
||
def id1(a): | ||
return a | ||
|
||
|
||
def id2(a): | ||
return a | ||
|
||
|
||
def gen(): | ||
yield "42", tf.constant("43") | ||
|
||
|
||
dataset = tf.data.Dataset.from_generator(gen, output_types=(tf.string, tf.string)) | ||
|
||
my_iter = iter(dataset) | ||
c = C(my_iter) | ||
length = 1 | ||
|
||
for _ in range(length): | ||
x, y = next(c.some_iter) | ||
id1(x) | ||
id2(y) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import tensorflow as tf | ||
|
||
|
||
class C: | ||
|
||
def __init__(self, some_iter): | ||
self.some_iter = some_iter | ||
|
||
def __str__(self): | ||
return str(self.some_iter) | ||
|
||
|
||
def add(a, b): | ||
return a + b | ||
|
||
|
||
def gen_iter(dataset): | ||
my_iter = iter(dataset) | ||
return C(my_iter) | ||
|
||
|
||
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) | ||
c = gen_iter(dataset) | ||
length = len(dataset) | ||
|
||
for _ in range(length): | ||
element = next(c.some_iter) | ||
add(element, element) |