Skip to content

Commit

Permalink
Merge branch 'viamrobotics:main' into dataWrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
jckras authored Nov 18, 2024
2 parents e8c6499 + 8f47d2c commit 649ab76
Show file tree
Hide file tree
Showing 44 changed files with 1,087 additions and 2,783 deletions.
68 changes: 68 additions & 0 deletions app/viam_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Package app contains all logic needed for communication and interaction with app.
package app

import (
"context"
"errors"
"net/url"
"strings"

"go.viam.com/utils/rpc"

"go.viam.com/rdk/logging"
)

// ViamClient is a gRPC client for method calls to Viam app.
type ViamClient struct {
conn rpc.ClientConn
}

// Options has the options necessary to connect through gRPC.
type Options struct {
baseURL string
entity string
credentials rpc.Credentials
}

var dialDirectGRPC = rpc.DialDirectGRPC

// CreateViamClientWithOptions creates a ViamClient with an Options struct.
func CreateViamClientWithOptions(ctx context.Context, options Options, logger logging.Logger) (*ViamClient, error) {
if options.baseURL == "" {
options.baseURL = "https://app.viam.com"
} else if !strings.HasPrefix(options.baseURL, "http://") && !strings.HasPrefix(options.baseURL, "https://") {
return nil, errors.New("use valid URL")
}
serviceHost, err := url.Parse(options.baseURL + ":443")
if err != nil {
return nil, err
}

if options.credentials.Payload == "" || options.entity == "" {
return nil, errors.New("entity and payload cannot be empty")
}
opts := rpc.WithEntityCredentials(options.entity, options.credentials)

conn, err := dialDirectGRPC(ctx, serviceHost.Host, logger, opts)
if err != nil {
return nil, err
}
return &ViamClient{conn: conn}, nil
}

// CreateViamClientWithAPIKey creates a ViamClient with an API key.
func CreateViamClientWithAPIKey(
ctx context.Context, options Options, apiKey, apiKeyID string, logger logging.Logger,
) (*ViamClient, error) {
options.entity = apiKeyID
options.credentials = rpc.Credentials{
Type: rpc.CredentialsTypeAPIKey,
Payload: apiKey,
}
return CreateViamClientWithOptions(ctx, options, logger)
}

// Close closes the gRPC connection.
func (c *ViamClient) Close() error {
return c.conn.Close()
}
118 changes: 118 additions & 0 deletions app/viam_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package app

import (
"context"
"testing"

"github.com/viamrobotics/webrtc/v3"
"go.viam.com/utils"
"go.viam.com/utils/rpc"
"google.golang.org/grpc"

"go.viam.com/rdk/logging"
)

var (
logger = logging.NewLogger("test")
defaultURL = "https://app.viam.com"
testAPIKey = "abcdefghijklmnopqrstuv0123456789"
testAPIKeyID = "abcd0123-ef45-gh67-ij89-klmnopqr01234567"
)

type MockConn struct{}

func (m *MockConn) NewStream(
ctx context.Context,
desc *grpc.StreamDesc,
method string,
opts ...grpc.CallOption,
) (grpc.ClientStream, error) {
return nil, nil
}

func (m *MockConn) Invoke(ctx context.Context, method string, args, reply any, opts ...grpc.CallOption) error {
return nil
}
func (m *MockConn) PeerConn() *webrtc.PeerConnection { return nil }
func (m *MockConn) Close() error { return nil }
func mockDialDirectGRPC(
ctx context.Context,
address string,
logger utils.ZapCompatibleLogger,
opts ...rpc.DialOption,
) (rpc.ClientConn, error) {
return &MockConn{}, nil
}

func TestCreateViamClientWithOptions(t *testing.T) {
urlTests := []struct {
name string
baseURL string
entity string
payload string
expectErr bool
}{
{"Default URL", defaultURL, testAPIKeyID, testAPIKey, false},
{"Default URL", defaultURL, "", "", true},
{"Default URL", defaultURL, "", testAPIKey, true},
{"Default URL", defaultURL, testAPIKeyID, "", true},
{name: "No URL", entity: testAPIKey, payload: testAPIKey, expectErr: false},
{"Empty URL", "", testAPIKeyID, testAPIKey, false},
{"Valid URL", "https://test.com", testAPIKeyID, testAPIKey, false},
{"Invalid URL", "test", testAPIKey, testAPIKey, true},
}
originalDialDirectGRPC := dialDirectGRPC
dialDirectGRPC = mockDialDirectGRPC
defer func() { dialDirectGRPC = originalDialDirectGRPC }()
for _, tt := range urlTests {
t.Run(tt.name, func(t *testing.T) {
opts := Options{
baseURL: tt.baseURL,
entity: tt.entity,
credentials: rpc.Credentials{
Type: rpc.CredentialsTypeAPIKey,
Payload: tt.payload,
},
}
client, err := CreateViamClientWithOptions(context.Background(), opts, logger)
if (err != nil) != tt.expectErr {
t.Errorf("Expected error: %v, got: %v", tt.expectErr, err)
}
if !tt.expectErr {
if client == nil {
t.Error("Expected a valid client, got nil")
} else {
client.Close()
}
}
})
}
}

func TestCreateViamClientWithAPIKeyTests(t *testing.T) {
apiKeyTests := []struct {
name string
apiKey string
apiKeyID string
expectErr bool
}{
{"Valid API Key", testAPIKey, testAPIKeyID, false},
{"Empty API Key", "", testAPIKeyID, true},
{"Empty API Key ID", testAPIKey, "", true},
}
for _, tt := range apiKeyTests {
t.Run(tt.name, func(t *testing.T) {
client, err := CreateViamClientWithAPIKey(context.Background(), Options{}, tt.apiKey, tt.apiKeyID, logger)
if (err != nil) != tt.expectErr {
t.Errorf("Expected error: %v, got: %v", tt.expectErr, err)
}
if !tt.expectErr {
if client == nil {
t.Error("Expected a valid client, got nil")
} else {
client.Close()
}
}
})
}
}
9 changes: 9 additions & 0 deletions cli/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ const (
packageFlagType = "type"
packageFlagDestination = "destination"
packageFlagPath = "path"
packageFlagFramework = "model-framework"

packageMetadataFlagFramework = "model_framework"

authApplicationFlagName = "application-name"
authApplicationFlagApplicationID = "application-id"
Expand Down Expand Up @@ -1913,6 +1916,12 @@ This won't work unless you have an existing installation of our GitHub app on yo
Required: true,
Usage: "type of the requested package, can be: " + strings.Join(packageTypes, ", "),
},
&cli.StringFlag{
Name: packageFlagFramework,
Required: false,
Usage: "framework for an ml_model being uploaded, can be: " +
strings.Join(modelFrameworks, ", ") + ", Required if packages if of type `ml_model`",
},
},
Action: PackageUploadAction,
},
Expand Down
31 changes: 30 additions & 1 deletion cli/packages.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"path"
"path/filepath"
"slices"
"strings"
"time"

Expand Down Expand Up @@ -183,13 +184,25 @@ func PackageUploadAction(c *cli.Context) error {
return err
}

if err := validatePackageUploadRequest(c); err != nil {
return err
}

resp, err := client.uploadPackage(
c.String(generalFlagOrgID),
c.String(packageFlagName),
c.String(packageFlagVersion),
c.String(packageFlagType),
c.Path(packageFlagPath),
nil,
&structpb.Struct{
Fields: map[string]*structpb.Value{
packageMetadataFlagFramework: {
Kind: &structpb.Value_StringValue{
StringValue: c.String(packageFlagFramework),
},
},
},
},
)
if err != nil {
return err
Expand Down Expand Up @@ -274,3 +287,19 @@ func getNextPackageUploadRequest(file *os.File) (*packagespb.CreatePackageReques
func (m *moduleID) ToDetailURL(baseURL string, packageType PackageType) string {
return fmt.Sprintf("https://%s/%s/%s/%s", baseURL, strings.ReplaceAll(string(packageType), "_", "-"), m.prefix, m.name)
}

func validatePackageUploadRequest(c *cli.Context) error {
packageType := c.String(packageFlagType)

if packageType == "ml_model" {
if c.String(packageFlagFramework) == "" {
return errors.New("must pass in a model-framework if package is of type `ml_model`")
}

if !slices.Contains(modelFrameworks, c.String(packageFlagFramework)) {
return errors.New("framework must be of type " + strings.Join(modelFrameworks, ", "))
}
}

return nil
}
File renamed without changes.
61 changes: 34 additions & 27 deletions components/arm/collectors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ import (
"testing"
"time"

clk "github.com/benbjohnson/clock"
"github.com/benbjohnson/clock"
"github.com/golang/geo/r3"
v1 "go.viam.com/api/common/v1"
datasyncpb "go.viam.com/api/app/datasync/v1"
pb "go.viam.com/api/component/arm/v1"
"go.viam.com/test"

Expand All @@ -22,8 +22,7 @@ import (

const (
componentName = "arm"
captureInterval = time.Second
numRetries = 5
captureInterval = time.Millisecond
)

var floatList = &pb.JointPositions{Values: []float64{1.0, 2.0, 3.0}}
Expand All @@ -32,40 +31,50 @@ func TestCollectors(t *testing.T) {
tests := []struct {
name string
collector data.CollectorConstructor
expected map[string]any
expected *datasyncpb.SensorData
}{
{
name: "End position collector should write a pose",
collector: arm.NewEndPositionCollector,
expected: tu.ToProtoMapIgnoreOmitEmpty(pb.GetEndPositionResponse{
Pose: &v1.Pose{
OX: 0,
OY: 0,
OZ: 1,
Theta: 0,
X: 1,
Y: 2,
Z: 3,
},
}),
expected: &datasyncpb.SensorData{
Metadata: &datasyncpb.SensorMetadata{},
Data: &datasyncpb.SensorData_Struct{Struct: tu.ToStructPBStruct(t, map[string]any{
"pose": map[string]any{
"o_x": 0,
"o_y": 0,
"o_z": 1,
"theta": 0,
"x": 1,
"y": 2,
"z": 3,
},
})},
},
},
{
name: "Joint positions collector should write a list of positions",
collector: arm.NewJointPositionsCollector,
expected: tu.ToProtoMapIgnoreOmitEmpty(pb.GetJointPositionsResponse{Positions: floatList}),
expected: &datasyncpb.SensorData{
Metadata: &datasyncpb.SensorMetadata{},
Data: &datasyncpb.SensorData_Struct{Struct: tu.ToStructPBStruct(t, map[string]any{
"positions": map[string]any{
"values": []any{1.0, 2.0, 3.0},
},
})},
},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockClock := clk.NewMock()
buf := tu.MockBuffer{}
start := time.Now()
buf := tu.NewMockBuffer()
params := data.CollectorParams{
ComponentName: componentName,
Interval: captureInterval,
Logger: logging.NewTestLogger(t),
Clock: mockClock,
Target: &buf,
Clock: clock.New(),
Target: buf,
}

arm := newArm()
Expand All @@ -74,13 +83,11 @@ func TestCollectors(t *testing.T) {

defer col.Close()
col.Collect()
mockClock.Add(captureInterval)

tu.Retry(func() bool {
return buf.Length() != 0
}, numRetries)
test.That(t, buf.Length(), test.ShouldBeGreaterThan, 0)
test.That(t, buf.Writes[0].GetStruct().AsMap(), test.ShouldResemble, tc.expected)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
tu.CheckMockBufferWrites(t, ctx, start, buf.Writes, tc.expected)
buf.Close()
})
}
}
Expand Down
File renamed without changes.
Loading

0 comments on commit 649ab76

Please sign in to comment.