-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
248 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
package client | ||
|
||
import ( | ||
"context" | ||
"log" | ||
"time" | ||
|
||
"github.com/xmtp/xmtpd/pkg/authn" | ||
"github.com/xmtp/xmtpd/pkg/constants" | ||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/codes" | ||
"google.golang.org/grpc/metadata" | ||
"google.golang.org/grpc/status" | ||
) | ||
|
||
// AuthInterceptor is a struct for holding the token and adding it to each request. | ||
type AuthInterceptor struct { | ||
tokenFactory *authn.TokenFactory | ||
targetNodeID uint32 | ||
currentToken *authn.Token | ||
} | ||
|
||
func NewAuthInterceptor( | ||
tokenFactory *authn.TokenFactory, | ||
targetNodeID uint32, | ||
) *AuthInterceptor { | ||
// This should never happen | ||
if tokenFactory == nil { | ||
log.Fatal("tokenFactory is required") | ||
} | ||
return &AuthInterceptor{ | ||
tokenFactory: tokenFactory, | ||
targetNodeID: targetNodeID, | ||
} | ||
} | ||
|
||
func (i *AuthInterceptor) getToken() (*authn.Token, error) { | ||
// If we have a token that is not expired (or nearing expiry) then return it | ||
if i.currentToken != nil && | ||
i.currentToken.ExpiresAt.After(time.Now().Add(authn.MAX_CLOCK_SKEW)) { | ||
return i.currentToken, nil | ||
} | ||
token, err := i.tokenFactory.CreateToken(i.targetNodeID) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
i.currentToken = token | ||
return token, nil | ||
} | ||
|
||
// Unary method to intercept requests and inject the token into headers. | ||
func (i *AuthInterceptor) Unary() grpc.UnaryClientInterceptor { | ||
return func( | ||
ctx context.Context, | ||
method string, | ||
req interface{}, | ||
reply interface{}, | ||
cc *grpc.ClientConn, | ||
invoker grpc.UnaryInvoker, | ||
opts ...grpc.CallOption, | ||
) error { | ||
token, err := i.getToken() | ||
if err != nil { | ||
return status.Errorf(codes.Unauthenticated, "failed to get token: %v", err) | ||
} | ||
// Create the metadata with the token | ||
md := metadata.Pairs(constants.NODE_AUTHORIZATION_HEADER_NAME, token.SignedString) | ||
// Attach metadata to the outgoing context | ||
ctx = metadata.NewOutgoingContext(ctx, md) | ||
|
||
// Proceed with the request | ||
return invoker(ctx, method, req, reply, cc, opts...) | ||
} | ||
} | ||
|
||
func (i *AuthInterceptor) Stream() grpc.StreamClientInterceptor { | ||
return func( | ||
ctx context.Context, | ||
desc *grpc.StreamDesc, | ||
cc *grpc.ClientConn, | ||
method string, | ||
streamer grpc.Streamer, | ||
opts ...grpc.CallOption, | ||
) (grpc.ClientStream, error) { | ||
token, err := i.getToken() | ||
if err != nil { | ||
return nil, status.Errorf(codes.Unauthenticated, "failed to get token: %v", err) | ||
} | ||
// Create the metadata with the token | ||
md := metadata.Pairs(constants.NODE_AUTHORIZATION_HEADER_NAME, token.SignedString) | ||
// Attach the metadata to the outgoing context | ||
ctx = metadata.NewOutgoingContext(ctx, md) | ||
|
||
// Proceed with the stream | ||
return streamer(ctx, desc, cc, method, opts...) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
package client | ||
|
||
import ( | ||
"context" | ||
"net" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
"github.com/xmtp/xmtpd/pkg/api" | ||
"github.com/xmtp/xmtpd/pkg/authn" | ||
"github.com/xmtp/xmtpd/pkg/constants" | ||
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" | ||
"github.com/xmtp/xmtpd/pkg/testutils" | ||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/codes" | ||
"google.golang.org/grpc/credentials/insecure" | ||
"google.golang.org/grpc/metadata" | ||
"google.golang.org/grpc/status" | ||
"google.golang.org/grpc/test/bufconn" | ||
) | ||
|
||
// Create a mock implementation of the ReplicationApiServer interface | ||
// but that embeds `UnimplementedReplicationApiServer` (which mockery won't do for us) | ||
type mockReplicationApiServer struct { | ||
api.Service | ||
expectedToken string | ||
} | ||
|
||
func (s *mockReplicationApiServer) QueryEnvelopes( | ||
ctx context.Context, | ||
req *message_api.QueryEnvelopesRequest, | ||
) (*message_api.QueryEnvelopesResponse, error) { | ||
// Get metadata from the context | ||
md, ok := metadata.FromIncomingContext(ctx) | ||
if !ok { | ||
return nil, status.Error(codes.Unauthenticated, "metadata is not provided") | ||
} | ||
|
||
// Extract and verify the token | ||
tokens := md.Get(constants.NODE_AUTHORIZATION_HEADER_NAME) | ||
if len(tokens) == 0 { | ||
return nil, status.Error(codes.Unauthenticated, "authorization token is not provided") | ||
} | ||
token := tokens[0] | ||
if token != s.expectedToken { | ||
return nil, status.Error(codes.Unauthenticated, "invalid authorization token") | ||
} | ||
|
||
// You can add more assertions here to verify the token's content | ||
// For example, you might want to decode the token and check its claims | ||
return &message_api.QueryEnvelopesResponse{}, nil | ||
} | ||
|
||
func TestAuthInterceptor(t *testing.T) { | ||
privateKey := testutils.RandomPrivateKey(t) | ||
myNodeID := uint32(100) | ||
targetNodeID := uint32(200) | ||
tokenFactory := authn.NewTokenFactory(privateKey, myNodeID) | ||
interceptor := NewAuthInterceptor(tokenFactory, targetNodeID) | ||
token, err := interceptor.getToken() | ||
require.NoError(t, err) | ||
|
||
// Use a bufconn listener to simulate a gRPC connection without actually dialing | ||
listener := bufconn.Listen(1024 * 1024) | ||
|
||
// Register the mock service on the server | ||
server := grpc.NewServer() | ||
message_api.RegisterReplicationApiServer( | ||
server, | ||
&mockReplicationApiServer{expectedToken: token.SignedString}, | ||
) | ||
|
||
// Start the gRPC server in a goroutine | ||
go func() { | ||
if err := server.Serve(listener); err != nil { | ||
t.Fail() | ||
} | ||
}() | ||
|
||
t.Cleanup(func() { | ||
server.Stop() | ||
listener.Close() | ||
}) | ||
|
||
// Connect to the fake server and set the right interceptors | ||
conn, err := grpc.NewClient( | ||
"passthrough://bufnet", | ||
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { | ||
return listener.Dial() | ||
}), | ||
grpc.WithTransportCredentials(insecure.NewCredentials()), | ||
grpc.WithUnaryInterceptor(interceptor.Unary()), | ||
) | ||
require.NoError(t, err) | ||
defer conn.Close() | ||
|
||
// Create a client with the connection | ||
client := message_api.NewReplicationApiClient(conn) | ||
|
||
// Call the unary method and check the response | ||
_, err = client.QueryEnvelopes(context.Background(), &message_api.QueryEnvelopesRequest{}) | ||
require.NoError(t, err) | ||
|
||
// Create another client without the interceptor | ||
connWithoutInterceptor, err := grpc.NewClient( | ||
"passthrough://bufnet", | ||
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { | ||
return listener.Dial() | ||
}), | ||
grpc.WithTransportCredentials(insecure.NewCredentials()), | ||
) | ||
require.NoError(t, err) | ||
defer connWithoutInterceptor.Close() | ||
|
||
client = message_api.NewReplicationApiClient(connWithoutInterceptor) | ||
|
||
// Call the unary method and check the response | ||
_, err = client.QueryEnvelopes(context.Background(), &message_api.QueryEnvelopesRequest{}) | ||
require.Error(t, err) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters