Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(oomstore/join): return the result instead of writing directly to the file #1177

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions oomagent/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package main

import (
"context"
"encoding/csv"
"fmt"
"io"
"log"
"os"
"time"

"google.golang.org/grpc/codes"
Expand All @@ -14,6 +16,7 @@ import (
"github.com/oom-ai/oomstore/pkg/errdefs"
"github.com/oom-ai/oomstore/pkg/oomstore"
"github.com/oom-ai/oomstore/pkg/oomstore/types"
"github.com/spf13/cast"
)

type server struct {
Expand Down Expand Up @@ -292,18 +295,55 @@ func (s *server) ChannelJoin(stream codegen.OomAgent_ChannelJoinServer) error {
}

func (s *server) Join(ctx context.Context, req *codegen.JoinRequest) (*codegen.JoinResponse, error) {
err := s.oomstore.Join(ctx, types.JoinOpt{
FeatureNames: req.Features,
InputFilePath: req.InputFile,
OutputFilePath: req.OutputFile,
ctx, cancel := context.WithCancel(ctx)
defer cancel()

joinResult, err := s.oomstore.Join(ctx, types.JoinOpt{
FeatureNames: req.Features,
InputFilePath: req.InputFile,
})
if err != nil {
return nil, internalError(err.Error())
}

if err := writeJoinResultToFile(req.OutputFile, joinResult); err != nil {
return nil, wrapErr(err)
}

return &codegen.JoinResponse{}, nil
}

func writeJoinResultToFile(outputFilePath string, joinResult *types.JoinResult) error {
lianxmfor marked this conversation as resolved.
Show resolved Hide resolved
file, err := os.Create(outputFilePath)
if err != nil {
return err
}
defer file.Close()
w := csv.NewWriter(file)
defer w.Flush()

if err := w.Write(joinResult.Header); err != nil {
return err
}
for row := range joinResult.Data {
if row.Error != nil {
return row.Error
}
if err := w.Write(joinRecord(row.Record)); err != nil {
return err
}
}
return nil
}

// joinRecord converts each value of row to its string form with
// cast.ToString and returns the results in the same order, ready to be
// passed to a csv.Writer.
func joinRecord(row []interface{}) []string {
	fields := make([]string, len(row))
	for i := range row {
		fields[i] = cast.ToString(row[i])
	}
	return fields
}

func (s *server) ChannelExport(req *codegen.ChannelExportRequest, stream codegen.OomAgent_ChannelExportServer) error {
if len(req.Features) == 0 {
return nil
Expand Down
12 changes: 3 additions & 9 deletions oomcli/cmd/join_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,9 @@ func join(ctx context.Context, store *oomstore.OomStore, opt JoinOpt, output str
ctx, cancel := context.WithCancel(ctx)
defer cancel()

entityRows, header, err := oomstore.GetEntityRowsFromInputFile(ctx, opt.InputFilePath)
if err != nil {
return err
}

joinResult, err := store.ChannelJoin(ctx, types.ChannelJoinOpt{
JoinFeatureNames: opt.FeatureNames,
EntityRows: entityRows,
ExistedFeatureNames: header[2:],
joinResult, err := store.Join(ctx, types.JoinOpt{
FeatureNames: opt.FeatureNames,
InputFilePath: opt.InputFilePath,
})
if err != nil {
return err
Expand Down
54 changes: 7 additions & 47 deletions pkg/oomstore/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ import (
"sort"
"strconv"

"github.com/spf13/cast"

"github.com/oom-ai/oomstore/internal/database/offline"
"github.com/oom-ai/oomstore/pkg/errdefs"
"github.com/oom-ai/oomstore/pkg/oomstore/types"
Expand Down Expand Up @@ -73,31 +71,24 @@ func (s *OomStore) ChannelJoin(ctx context.Context, opt types.ChannelJoinOpt) (*
}

// Join gets point-in-time correct feature values for each entity row.
// The method is similar to Join, except that both input and output are files on disk.
// The method is similar to ChannelJoin, except that the input is a file on disk.
// Input File should contain header, the first two columns of Input File should be
// entity_key, unix_milli, then followed by other real-time feature values.
func (s *OomStore) Join(ctx context.Context, opt types.JoinOpt) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

func (s *OomStore) Join(ctx context.Context, opt types.JoinOpt) (*types.JoinResult, error) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This way, the Go SDK will behave differently from the other SDKs. Is that expected?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes the Go SDK more flexible, because other SDKs need to go through the oomagent wrapper layer.
As for the inconsistent behavior, I'm not quite sure of the scope of the impact; we can discuss it.

if err := util.ValidateFullFeatureNames(opt.FeatureNames...); err != nil {
return err
return nil, err
}

entityRows, header, err := GetEntityRowsFromInputFile(ctx, opt.InputFilePath)
entityRows, header, err := getEntityRowsFromInputFile(ctx, opt.InputFilePath)
if err != nil {
return err
return nil, err
}

joinResult, err := s.ChannelJoin(ctx, types.ChannelJoinOpt{
return s.ChannelJoin(ctx, types.ChannelJoinOpt{
JoinFeatureNames: opt.FeatureNames,
EntityRows: entityRows,
ExistedFeatureNames: header[2:],
})
if err != nil {
return err
}
return writeJoinResultToFile(opt.OutputFilePath, joinResult)
}

func (s *OomStore) buildRevisionRanges(ctx context.Context, group *types.Group) ([]*offline.RevisionRange, error) {
Expand Down Expand Up @@ -142,7 +133,7 @@ func (s *OomStore) buildRevisionRanges(ctx context.Context, group *types.Group)
return ranges, nil
}

func GetEntityRowsFromInputFile(ctx context.Context, inputFilePath string) (<-chan types.EntityRow, []string, error) {
func getEntityRowsFromInputFile(ctx context.Context, inputFilePath string) (<-chan types.EntityRow, []string, error) {
input, err := os.Open(inputFilePath)
if err != nil {
return nil, nil, errdefs.WithStack(err)
Expand Down Expand Up @@ -207,34 +198,3 @@ func GetEntityRowsFromInputFile(ctx context.Context, inputFilePath string) (<-ch
}()
return entityRows, header, nil
}

// writeJoinResultToFile writes joinResult as CSV to the file at
// outputFilePath; os.Create creates the file or truncates an existing one.
//
// The header is written first, followed by one record per row received from
// joinResult.Data. Writing stops at the first row that carries a non-nil
// Error, and that error is returned; records written before it remain in the
// file.
//
// Flush and close failures are reported to the caller instead of being
// silently dropped, so a short or failed write is never returned as success.
// When several failures occur, the first one wins.
func writeJoinResultToFile(outputFilePath string, joinResult *types.JoinResult) (err error) {
	file, err := os.Create(outputFilePath)
	if err != nil {
		return err
	}
	w := csv.NewWriter(file)

	// Runs on every return path: push buffered records to the file, then
	// close it. csv.Writer.Flush returns nothing, so its outcome has to be
	// read back through Error. An earlier error is never overwritten.
	defer func() {
		w.Flush()
		if flushErr := w.Error(); err == nil {
			err = flushErr
		}
		if closeErr := file.Close(); err == nil {
			err = closeErr
		}
	}()

	if err := w.Write(joinResult.Header); err != nil {
		return err
	}
	for row := range joinResult.Data {
		if row.Error != nil {
			return row.Error
		}
		if err := w.Write(joinRecord(row.Record)); err != nil {
			return err
		}
	}
	return nil
}

// joinRecord returns the string form of every value in row, produced by
// cast.ToString, preserving the original column order.
func joinRecord(row []interface{}) []string {
	columns := make([]string, len(row))
	for idx := 0; idx < len(row); idx++ {
		columns[idx] = cast.ToString(row[idx])
	}
	return columns
}
5 changes: 2 additions & 3 deletions pkg/oomstore/types/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@ type ChannelJoinOpt struct {
}

type JoinOpt struct {
FeatureNames []string
InputFilePath string
OutputFilePath string
FeatureNames []string
InputFilePath string
}

type UpdateEntityOpt struct {
Expand Down