diff --git a/pkg/blocks/iocapture.go b/pkg/blocks/iocapture.go index 58589c2e..7bb2bdbe 100644 --- a/pkg/blocks/iocapture.go +++ b/pkg/blocks/iocapture.go @@ -27,39 +27,80 @@ import ( "github.com/facebookincubator/ttpforge/pkg/logging" ) -type zapWriter struct { - prefix string +type bufferedWriter struct { + buff bytes.Buffer + writer io.Writer } -func (z *zapWriter) Write(b []byte) (int, error) { - n := len(b) - // extra-defensive programming :P - if n <= 0 { - return 0, nil +func (bw *bufferedWriter) Write(b []byte) (n int, err error) { + n = len(b) + for len(b) > 0 { + b = bw.writeLine(b) } + return n, nil +} - // strip trailing newline - if b[n-1] == '\n' { - b = b[:n-1] +func (bw *bufferedWriter) writeLine(line []byte) (remaining []byte) { + idx := bytes.IndexByte(line, '\n') + if idx < 0 { + // If there are no newlines, buffer the entire string. + bw.buff.Write(line) + return nil } - // split lines - lines := bytes.Split(b, []byte{'\n'}) - for _, line := range lines { - logging.L().Info(z.prefix, string(line)) + // Split on the newline, buffer and flush the left. + line, remaining = line[:idx], line[idx+1:] + + // Fast path: if we don't have a partial message from a previous write + // in the buffer, skip the buffer and log directly. + if bw.buff.Len() == 0 { + bw.log(line) + return remaining + } + + bw.buff.Write(line) + bw.log(bw.buff.Bytes()) + bw.buff.Reset() + return remaining +} + +func (bw *bufferedWriter) log(line []byte) { + _, err := bw.writer.Write(line) + if err != nil { + logging.L().Errorw("failed to log", "err", err) } - return n, nil +} + +func (bw *bufferedWriter) Close() error { + if bw.buff.Len() != 0 { + bw.log(bw.buff.Bytes()) + bw.buff.Reset() + } + return nil +} + +type zapWriter struct { + prefix string +} + +func (zw *zapWriter) Write(p []byte) (n int, err error) { + logging.L().Info(zw.prefix, string(p)) + return len(p), nil } func streamAndCapture(cmd exec.Cmd, stdout, stderr io.Writer) (*ActResult, error) { if stdout == nil { - stdout = &zapWriter{ - prefix: "[STDOUT] ", + stdout = &bufferedWriter{ + writer: &zapWriter{ + prefix: "[STDOUT] ", + }, } } if stderr == nil { - stderr = &zapWriter{ - prefix: "[STDERR] ", + stderr = &bufferedWriter{ + writer: &zapWriter{ + prefix: "[STDERR] ", + }, } } @@ -75,8 +116,5 @@ func streamAndCapture(cmd exec.Cmd, stdout, stderr io.Writer) (*ActResult, error result := ActResult{} result.Stdout = outStr result.Stderr = errStr - if err != nil { - return nil, err - } return &result, nil } diff --git a/pkg/blocks/iocapture_test.go b/pkg/blocks/iocapture_test.go new file mode 100644 index 00000000..e79a6a56 --- /dev/null +++ b/pkg/blocks/iocapture_test.go @@ -0,0 +1,94 @@ +/* +Copyright © 2024-present, Meta Platforms, Inc. and affiliates +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +package blocks + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +type mockWriter struct { + memory []byte +} + +func (mw *mockWriter) Write(p []byte) (n int, err error) { + mw.memory = append(mw.memory, p...) + return len(p), nil +} + +func (mw *mockWriter) GetMemory() []byte { + return mw.memory +} + +func TestBufferedWriter(t *testing.T) { + testCases := []struct { + name string + textChunks []string + wantError bool + }{ + { + name: "Finished line", + textChunks: []string{"Hello\n"}, + wantError: false, + }, { + name: "Unfinished line", + textChunks: []string{"Hello, world", "!\n"}, + wantError: false, + }, { + name: "No last newline", + textChunks: []string{"Hello,\nworld", "!\nfoobar"}, + wantError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockWriter := &mockWriter{} + zw := &bufferedWriter{ + writer: mockWriter, + } + var totalBytesWritten int + for _, text := range tc.textChunks { + bytesWritten, err := zw.Write([]byte(text)) + totalBytesWritten += bytesWritten + if tc.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + } + zw.Close() + + expectedTotalBytesWritten := 0 + for _, text := range tc.textChunks { + expectedTotalBytesWritten += len(text) + } + assert.Equal(t, expectedTotalBytesWritten, totalBytesWritten) + expectedBytesWritten := []byte{} + for _, text := range tc.textChunks { + text = strings.ReplaceAll(text, "\n", "") + expectedBytesWritten = append(expectedBytesWritten, []byte(text)...) + } + assert.Equal(t, expectedBytesWritten, mockWriter.GetMemory()) + }) + } +}