Skip to content

Commit

Permalink
Add go-tfdata integration (tests)
Browse files Browse the repository at this point in the history
  • Loading branch information
knopt committed May 11, 2020
1 parent b6b3cd0 commit fde2e02
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 10 deletions.
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ bin/tarp-full: $(cmds) $(datapipes)
bin/tarp -h

test:
cd datapipes && go test -v
cd dpipes && go test -v

test-tfdata:
cd dpipes && go test -v --tags=gitlabnvidia

dtest:
cd datapipes && debug=stdout go test -v | tee ../test.log
Expand Down
3 changes: 2 additions & 1 deletion dpipes/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/tmbdev/tarp/dpipes
go 1.14

require (
github.com/NVIDIA/go-tfdata v0.3.1
github.com/shamaton/msgpack v1.1.1
github.com/stretchr/testify v1.2.2
github.com/stretchr/testify v1.3.0
)
143 changes: 143 additions & 0 deletions dpipes/gotfdata_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package dpipes

import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"os"
"testing"

"github.com/NVIDIA/go-tfdata/tfdata/core"
"github.com/NVIDIA/go-tfdata/tfdata/transform"
"github.com/stretchr/testify/assert"
)

type (
SamplesReader struct {
pipe Pipe
}
)

func (r *SamplesReader) Read() (sample *core.Sample, err error) {
s, ok := <-r.pipe
if !ok {
return nil, io.EOF
}

return tarpSampleToTfDataSample(s), nil
}

func TFRecordSink(t *testing.T, writer io.Writer) Sink {
return func(pipe Pipe) {
w := core.NewTFRecordWriter(writer)
samplesReader := &SamplesReader{pipe}
tfExamplesReader := transform.SamplesToTFExample(samplesReader)
err := w.WriteMessages(tfExamplesReader)

assert.NoError(t, err)
}
}

func TFRecordSource(t *testing.T, reader io.Reader) Source {
return func(pipe Pipe) {
defer close(pipe)
var (
ex *core.TFExample
err error
r core.TFExampleReader
)
r = core.NewTFRecordReader(reader)
for ex, err = r.Read(); err == nil; ex, err = r.Read() {
pipe <- tfExampleTarpSample(ex)
}
if err != io.EOF {
assert.Fail(t, "expected to get io.EOF, got %v instead", err)
}
}
}

func SamplesChecker(t *testing.T, target int) Process {
return func(in, out Pipe) {
total := 0
for s := range in {
assert.Equal(t, s["txt"], Bytes(fmt.Sprintf("%d", total)))
assert.Equal(t, s["__key__"], Bytes(fmt.Sprintf("%06d", total)))
total++
out <- s
}
close(out)
assert.Equal(t, target, total)
}
}

func tarpSampleToTfDataSample(sample Sample) *core.Sample {
s := core.NewSample()
for k, v := range sample {
s.Entries[k] = v
}
return s
}

func tfExampleTarpSample(example *core.TFExample) Sample {
s := make(map[string]Bytes, len(example.GetFeatures().Feature))
for k, v := range example.GetFeatures().Feature {
var b Bytes
err := json.Unmarshal(v.GetBytesList().Value[0], &b)
if err != nil {
panic(err)
}
s[k] = b // assume that all TFExample features are just a list of bytes
}
return s
}

func PrepareTarSource() Source {
return func(pipe Pipe) {
for i := 0; i < 1; i++ {
pipe <- Sample{
"__key__": Bytes(fmt.Sprintf("%06d", i)),
"txt": Bytes(fmt.Sprintf("%d", i)),
}
}
close(pipe)
}
}

func prepareTar(t *testing.T) *os.File {
var (
sinkFd *os.File
err error
)
sinkFd, err = ioutil.TempFile("", "go-tfdata-*.tar")
assert.NoError(t, err)

sink := TarSink(sinkFd)
Processing(PrepareTarSource(), nil, sink)
return sinkFd
}

func TestGoTfData(t *testing.T) {
var (
sourceFd = prepareTar(t)
sinkFd *os.File
err error
)

defer os.RemoveAll(sourceFd.Name())
sourceFd, err = os.Open(sourceFd.Name())
assert.NoError(t, err)

sinkFd, err = ioutil.TempFile("", "go-tfdata-*.tfrecord")
assert.NoError(t, err)
defer os.RemoveAll(sinkFd.Name())

Processing(TarSource(sourceFd), nil, TFRecordSink(t, sinkFd))
sinkFd.Close()
sourceFd, err = os.Open(sinkFd.Name())
assert.NoError(t, err)
sinkFd, err = os.OpenFile(os.DevNull, os.O_RDWR, os.ModeAppend)
assert.NoError(t, err)

Processing(TFRecordSource(t, sourceFd), SamplesChecker(t, 1), TFRecordSink(t, sinkFd))
}
8 changes: 0 additions & 8 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,3 @@ module github.com/tmbdev/tarp
replace github.com/tmbdev/tarp/dpipes => ./dpipes

go 1.14

require (
github.com/bcicen/ctop v0.7.3 // indirect
github.com/jessevdk/go-flags v1.4.0
github.com/maruel/panicparse v1.3.0 // indirect
github.com/stretchr/testify v1.2.2
github.com/tmbdev/tarp/dpipes v0.0.0-20200330012711-53823ac810b9
)

0 comments on commit fde2e02

Please sign in to comment.