Skip to content

Commit

Permalink
Add nat type to nat rule resource
Browse files Browse the repository at this point in the history
Signed-off-by: Anna Khmelnitsky <[email protected]>
  • Loading branch information
annakhm committed Dec 11, 2024
1 parent 35f6088 commit aee1f8b
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 39 deletions.
66 changes: 45 additions & 21 deletions nsxt/resource_nsxt_policy_nat_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ var policyNATRulePolicyBasedVpnModeTypeValues = []string{
model.PolicyNatRule_POLICY_BASED_VPN_MODE_MATCH,
}

var policyNATRuleTypeValues = []string{
model.PolicyNat_NAT_TYPE_INTERNAL,
model.PolicyNat_NAT_TYPE_USER,
model.PolicyNat_NAT_TYPE_DEFAULT,
model.PolicyNat_NAT_TYPE_NAT64,
}

func resourceNsxtPolicyNATRule() *schema.Resource {
return &schema.Resource{
Create: resourceNsxtPolicyNATRuleCreate,
Expand Down Expand Up @@ -144,6 +151,14 @@ func resourceNsxtPolicyNATRule() *schema.Resource {
Computed: true,
ValidateFunc: validation.StringInSlice(policyNATRulePolicyBasedVpnModeTypeValues, false),
},
"type": {
Type: schema.TypeString,
Description: "NAT Type",
Optional: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(policyNATRuleTypeValues, false),
Computed: true,
},
},
}
}
Expand Down Expand Up @@ -180,7 +195,8 @@ func resourceNsxtPolicyNATRuleDelete(d *schema.ResourceData, m interface{}) erro
}

action := d.Get("action").(string)
natType := getNatTypeByAction(action)
natType := d.Get("type").(string)
natType = getNatTypeByAction(natType, action)
err := deleteNsxtPolicyNATRule(context, getPolicyConnector(m), gwID, isT0, natType, id)
if err != nil {
return handleDeleteError("NAT Rule", id, err)
Expand All @@ -204,8 +220,7 @@ func getNsxtPolicyNATRuleByID(sessionContext utl.SessionContext, connector clien
return client.Get(gwID, natType, ruleID)
}

func patchNsxtPolicyNATRule(sessionContext utl.SessionContext, connector client.Connector, gwID string, rule model.PolicyNatRule, isT0 bool) error {
natType := getNatTypeByAction(*rule.Action)
func patchNsxtPolicyNATRule(sessionContext utl.SessionContext, connector client.Connector, gwID string, rule model.PolicyNatRule, isT0 bool, natType string) error {
_, err := getTranslatedNetworks(rule)
if err != nil {
return err
Expand All @@ -231,12 +246,15 @@ func patchNsxtPolicyNATRule(sessionContext utl.SessionContext, connector client.
return client.Patch(gwID, natType, *rule.Id, rule)
}

func getNatTypeByAction(action string) string {
func getNatTypeByAction(natType string, action string) string {
if action == model.PolicyNatRule_ACTION_NAT64 {
return model.PolicyNat_NAT_TYPE_NAT64
}
if natType == "" {
return model.PolicyNat_NAT_TYPE_USER
}

return model.PolicyNat_NAT_TYPE_USER
return natType
}

func translatedNetworksNeeded(action string) bool {
Expand Down Expand Up @@ -284,8 +302,14 @@ func resourceNsxtPolicyNATRuleRead(d *schema.ResourceData, m interface{}) error
return handleMultitenancyTier0Error()
}

action := d.Get("action").(string)
natType := getNatTypeByAction(action)
natType := d.Get("type").(string)
if natType == "" {
// This can happen when provider was upgraded and we're refreshing an existing resource
// This is not an import case, so action should be set
action := d.Get("action").(string)
natType = getNatTypeByAction(natType, action)
d.Set("type", natType)
}
obj, err := getNsxtPolicyNATRuleByID(context, connector, gwID, isT0, natType, id)
if err != nil {
return handleReadError(d, "NAT Rule", id, err)
Expand Down Expand Up @@ -327,7 +351,11 @@ func resourceNsxtPolicyNATRuleCreate(d *schema.ResourceData, m interface{}) erro

gwPolicyPath := d.Get("gateway_path").(string)
action := d.Get("action").(string)
natType := getNatTypeByAction(action)
natType := d.Get("type").(string)
// nat type attribute was introduced as explicit attribute when existing deployments
// were calculating it based on action
// for backward compatibility, we allow the type to be overridden by NAT64 action
natType = getNatTypeByAction(natType, action)
isT0, gwID := parseGatewayPolicyPath(gwPolicyPath)
if gwID == "" {
return fmt.Errorf("gateway_path is not valid")
Expand Down Expand Up @@ -394,13 +422,15 @@ func resourceNsxtPolicyNATRuleCreate(d *schema.ResourceData, m interface{}) erro

log.Printf("[INFO] Creating NAT Rule with ID %s", id)

err := patchNsxtPolicyNATRule(getSessionContext(d, m), connector, gwID, ruleStruct, isT0)
err := patchNsxtPolicyNATRule(getSessionContext(d, m), connector, gwID, ruleStruct, isT0, natType)
if err != nil {
return handleCreateError("NAT Rule", id, err)
}

d.SetId(id)
d.Set("nsx_id", id)
// In case nat type was not specified or got overridden by action
d.Set("type", natType)

return resourceNsxtPolicyNATRuleRead(d, m)
}
Expand All @@ -426,6 +456,8 @@ func resourceNsxtPolicyNATRuleUpdate(d *schema.ResourceData, m interface{}) erro
displayName := d.Get("display_name").(string)
description := d.Get("description").(string)
action := d.Get("action").(string)
natType := d.Get("type").(string)
natType = getNatTypeByAction(natType, action)
enabled := d.Get("enabled").(bool)
logging := d.Get("logging").(bool)
priority := int64(d.Get("rule_priority").(int))
Expand Down Expand Up @@ -467,7 +499,7 @@ func resourceNsxtPolicyNATRuleUpdate(d *schema.ResourceData, m interface{}) erro
}

log.Printf("[INFO] Updating NAT Rule with ID %s", id)
err := patchNsxtPolicyNATRule(context, connector, gwID, ruleStruct, isT0)
err := patchNsxtPolicyNATRule(context, connector, gwID, ruleStruct, isT0, natType)
if err != nil {
return handleUpdateError("NAT Rule", id, err)
}
Expand All @@ -492,24 +524,16 @@ func resourceNsxtPolicyNATRuleImport(d *schema.ResourceData, m interface{}) ([]*
if err != nil {
return nil, err
}
if natType == model.PolicyNat_NAT_TYPE_USER {
// Value will be overwritten by resourceNsxtPolicyNATRuleRead()
d.Set("action", model.PolicyNatRule_ACTION_DNAT)
} else {
d.Set("action", model.PolicyNatRule_ACTION_NAT64)
}
d.Set("type", natType)
return rd, nil
} else if !errors.Is(err, ErrNotAPolicyPath) {
return rd, err
}
if len(s) < 2 || len(s) > 3 {
return nil, fmt.Errorf("Please provide <gateway-id>/<nat-rule-id>/[nat-type] as an input")
}
if len(s) == 3 {
// take care of NAT64 nat-type via action
if s[2] == model.PolicyNat_NAT_TYPE_NAT64 {
d.Set("action", model.PolicyNatRule_ACTION_NAT64)
}
if len(s) < 3 {
d.Set("type", model.PolicyNat_NAT_TYPE_USER)
}

gwID := s[0]
Expand Down
94 changes: 76 additions & 18 deletions nsxt/resource_nsxt_policy_nat_rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestAccResourceNsxtPolicyNATRule_minimalT0(t *testing.T) {
{
Config: testAccNsxtPolicyNATRuleTier0MinimalCreateTemplate(name, testAccResourcePolicyNATRuleSourceNet, testAccResourcePolicyNATRuleTransNet),
Check: resource.ComposeTestCheckFunc(
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, false),
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, "USER"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "display_name", name),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "source_networks.#", "1"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "translated_networks.#", "1"),
Expand All @@ -48,7 +48,7 @@ func TestAccResourceNsxtPolicyNATRule_minimalT0(t *testing.T) {
})
}

func TestAccResourceNsxtPolicyNATRule_basicT1(t *testing.T) {
func TestAccResourceNsxtPolicyNATRule_basic_T1(t *testing.T) {
testAccResourceNsxtPolicyNATRuleBasicT1(t, false, func() {
testAccPreCheck(t)
})
Expand Down Expand Up @@ -79,7 +79,7 @@ func testAccResourceNsxtPolicyNATRuleBasicT1(t *testing.T, withContext bool, pre
{
Config: testAccNsxtPolicyNATRuleTier1CreateTemplate(name, action, testAccResourcePolicyNATRuleSourceNet, testAccResourcePolicyNATRuleDestNet, testAccResourcePolicyNATRuleTransNet, withContext),
Check: resource.ComposeTestCheckFunc(
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, false),
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, "USER"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "display_name", name),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "description", "Acceptance Test"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "destination_networks.#", "1"),
Expand All @@ -99,7 +99,7 @@ func testAccResourceNsxtPolicyNATRuleBasicT1(t *testing.T, withContext bool, pre
{
Config: testAccNsxtPolicyNATRuleTier1CreateTemplate(updateName, action, snet, dnet, tnet, withContext),
Check: resource.ComposeTestCheckFunc(
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, false),
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, "USER"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "display_name", updateName),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "description", "Acceptance Test"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "destination_networks.#", "1"),
Expand All @@ -119,7 +119,7 @@ func testAccResourceNsxtPolicyNATRuleBasicT1(t *testing.T, withContext bool, pre
{
Config: testAccNsxtPolicyNATRuleTier1UpdateMultipleSourceNetworksTemplate(name, action, testAccResourcePolicyNATRuleSourceNet, snet, dnet, tnet, withContext),
Check: resource.ComposeTestCheckFunc(
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, false),
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, "USER"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "display_name", name),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "description", "Acceptance Test"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "destination_networks.#", "1"),
Expand All @@ -139,6 +139,47 @@ func testAccResourceNsxtPolicyNATRuleBasicT1(t *testing.T, withContext bool, pre
})
}

func TestAccResourceNsxtPolicyNATRuleT1_natType(t *testing.T) {
name := getAccTestResourceName()

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Providers: testAccProviders,
CheckDestroy: func(state *terraform.State) error {
return testAccNsxtPolicyNATRuleCheckDestroy(state, name, false)
},
Steps: []resource.TestStep{
{
Config: testAccNsxtPolicyNATRuleNatTypeTemplate(name, "DEFAULT", model.PolicyNatRule_ACTION_DNAT, "22.1.1.14", "33.1.1.14"),
Check: resource.ComposeTestCheckFunc(
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, "DEFAULT"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "display_name", name),
resource.TestCheckResourceAttrSet(testAccResourcePolicyNATRuleName, "path"),
resource.TestCheckResourceAttrSet(testAccResourcePolicyNATRuleName, "revision"),
),
},
{
Config: testAccNsxtPolicyNATRuleNatTypeTemplate(name, "USER", model.PolicyNatRule_ACTION_DNAT, "22.1.1.14", "33.1.1.14"),
Check: resource.ComposeTestCheckFunc(
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, "USER"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "display_name", name),
resource.TestCheckResourceAttrSet(testAccResourcePolicyNATRuleName, "path"),
resource.TestCheckResourceAttrSet(testAccResourcePolicyNATRuleName, "revision"),
),
},
{
Config: testAccNsxtPolicyNATRuleNatTypeTemplate(name, "NAT64", model.PolicyNatRule_ACTION_NAT64, "2201::0014", "3301:1122::1280:0014"),
Check: resource.ComposeTestCheckFunc(
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, "NAT64"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "display_name", name),
resource.TestCheckResourceAttrSet(testAccResourcePolicyNATRuleName, "path"),
resource.TestCheckResourceAttrSet(testAccResourcePolicyNATRuleName, "revision"),
),
},
},
})
}

func TestAccResourceNsxtPolicyNATRule_withPolicyBasedVpnMode(t *testing.T) {
name := getAccTestResourceName()
updateName := getAccTestResourceName()
Expand All @@ -157,7 +198,7 @@ func TestAccResourceNsxtPolicyNATRule_withPolicyBasedVpnMode(t *testing.T) {
{
Config: testAccNsxtPolicyNATRuleTier1CreateTemplateWithPolicyBasedVpnMode(name, action, testAccResourcePolicyNATRuleSourceNet, testAccResourcePolicyNATRuleDestNet, testAccResourcePolicyNATRuleTransNet, model.PolicyNatRule_POLICY_BASED_VPN_MODE_BYPASS, false),
Check: resource.ComposeTestCheckFunc(
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, false),
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, "USER"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "display_name", name),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "description", "Acceptance Test"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "destination_networks.#", "1"),
Expand All @@ -178,7 +219,7 @@ func TestAccResourceNsxtPolicyNATRule_withPolicyBasedVpnMode(t *testing.T) {
{
Config: testAccNsxtPolicyNATRuleTier1CreateTemplateWithPolicyBasedVpnMode(updateName, action, snet, dnet, tnet, model.PolicyNatRule_POLICY_BASED_VPN_MODE_MATCH, false),
Check: resource.ComposeTestCheckFunc(
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, false),
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, "USER"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "display_name", updateName),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "description", "Acceptance Test"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "destination_networks.#", "1"),
Expand Down Expand Up @@ -219,7 +260,7 @@ func TestAccResourceNsxtPolicyNATRule_basicT0(t *testing.T) {
{
Config: testAccNsxtPolicyNATRuleTier0CreateTemplate(name, action, testAccResourcePolicyNATRuleSourceNet, testAccResourcePolicyNATRuleTransNet),
Check: resource.ComposeTestCheckFunc(
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, false),
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, "USER"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "display_name", name),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "description", "Acceptance Test"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "source_networks.#", "1"),
Expand All @@ -238,7 +279,7 @@ func TestAccResourceNsxtPolicyNATRule_basicT0(t *testing.T) {
{
Config: testAccNsxtPolicyNATRuleTier0CreateTemplate(updateName, action, snet, tnet),
Check: resource.ComposeTestCheckFunc(
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, false),
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, "USER"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "display_name", updateName),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "description", "Acceptance Test"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "source_networks.#", "1"),
Expand Down Expand Up @@ -325,7 +366,7 @@ func TestAccResourceNsxtPolicyNATRule_nat64T1(t *testing.T) {
{
Config: testAccNsxtPolicyNATRuleTier1CreateTemplate(name, action, snet, dnet, tnet, false),
Check: resource.ComposeTestCheckFunc(
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, true),
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, "NAT64"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "display_name", name),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "description", "Acceptance Test"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "destination_networks.#", "1"),
Expand All @@ -345,7 +386,7 @@ func TestAccResourceNsxtPolicyNATRule_nat64T1(t *testing.T) {
{
Config: testAccNsxtPolicyNATRuleTier1CreateTemplate(updateName, action, snet, dnet, tnet1, false),
Check: resource.ComposeTestCheckFunc(
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, true),
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, "NAT64"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "display_name", updateName),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "description", "Acceptance Test"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "destination_networks.#", "1"),
Expand Down Expand Up @@ -384,7 +425,7 @@ func TestAccResourceNsxtPolicyNATRuleNoSnatWithoutTNet(t *testing.T) {
{
Config: testAccNsxPolicyNatRuleNoTranslatedNetworkTemplate(name, action, testAccResourcePolicyNATRuleSourceNet, testAccResourcePolicyNATRuleDestNet),
Check: resource.ComposeTestCheckFunc(
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, false),
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, "USER"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "display_name", name),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "description", "Acceptance Test"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "destination_networks.#", "1"),
Expand All @@ -402,7 +443,7 @@ func TestAccResourceNsxtPolicyNATRuleNoSnatWithoutTNet(t *testing.T) {
{
Config: testAccNsxPolicyNatRuleNoTranslatedNetworkTemplate(updateName, action, snet, dnet),
Check: resource.ComposeTestCheckFunc(
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, false),
testAccNsxtPolicyNATRuleExists(testAccResourcePolicyNATRuleName, "USER"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "display_name", updateName),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "description", "Acceptance Test"),
resource.TestCheckResourceAttr(testAccResourcePolicyNATRuleName, "destination_networks.#", "1"),
Expand Down Expand Up @@ -438,7 +479,7 @@ func testAccNSXPolicyNATRuleImporterGetID(s *terraform.State) (string, error) {
return fmt.Sprintf("%s/%s", gwID, resourceID), nil
}

func testAccNsxtPolicyNATRuleExists(resourceName string, isNat bool) resource.TestCheckFunc {
func testAccNsxtPolicyNATRuleExists(resourceName string, natType string) resource.TestCheckFunc {
return func(state *terraform.State) error {
connector := getPolicyConnector(testAccProvider.Meta().(nsxtClients))

Expand All @@ -453,10 +494,6 @@ func testAccNsxtPolicyNATRuleExists(resourceName string, isNat bool) resource.Te
}

gwPath := rs.Primary.Attributes["gateway_path"]
natType := model.PolicyNat_NAT_TYPE_USER
if isNat {
natType = model.PolicyNat_NAT_TYPE_NAT64
}
isT0, gwID := parseGatewayPolicyPath(gwPath)
_, err := getNsxtPolicyNATRuleByID(testAccGetSessionContext(), connector, gwID, isT0, natType, resourceID)
if err != nil {
Expand Down Expand Up @@ -503,6 +540,27 @@ resource "nsxt_policy_nat_rule" "test" {
`, name, model.PolicyNatRule_ACTION_REFLEXIVE, sourceNet, translatedNet)
}

func testAccNsxtPolicyNATRuleNatTypeTemplate(name string, natType string, action string, srcIP string, dstIP string) string {
return testAccNsxtPolicyEdgeClusterReadTemplate(getEdgeClusterName()) +
testAccNsxtPolicyTier1WithEdgeClusterTemplate("test", false, false) + fmt.Sprintf(`
data "nsxt_policy_service" "test" {
display_name = "DNS-UDP"
}
resource "nsxt_policy_nat_rule" "test" {
display_name = "%s"
type = "%s"
description = "Acceptance Test"
gateway_path = nsxt_policy_tier1_gateway.test.path
action = "%s"
source_networks = ["%s"]
destination_networks = ["%s"]
translated_networks = ["44.11.11.2"]
service = data.nsxt_policy_service.test.path
}
`, name, natType, action, srcIP, dstIP)
}

func testAccNsxtPolicyNATRuleTier1CreateTemplate(name string, action string, sourceNet string, destNet string, translatedNet string, withContext bool) string {
context := ""
if withContext {
Expand Down
Loading

0 comments on commit aee1f8b

Please sign in to comment.