Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new credential entry point discovery #15685

Open
wants to merge 8 commits into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions awx/main/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def _load_credential_types_feature(self):

@bypass_in_test
def load_credential_types_feature(self):
from awx.main.models.credential import load_credentials

load_credentials()
return self._load_credential_types_feature()

def load_inventory_plugins(self):
Expand Down
80 changes: 48 additions & 32 deletions awx/main/models/credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# All Rights Reserved.
from contextlib import nullcontext
import functools

import inspect
import logging
from importlib.metadata import entry_points
Expand Down Expand Up @@ -45,6 +46,8 @@
)
from awx.main.models import Team, Organization
from awx.main.utils import encrypt_field
from awx_plugins.interfaces._temporary_private_licensing_api import detect_server_product_name


# DAB
from ansible_base.resource_registry.tasks.sync import get_resource_server_client
Expand All @@ -54,7 +57,6 @@
__all__ = ['Credential', 'CredentialType', 'CredentialInputSource', 'build_safe_env']

logger = logging.getLogger('awx.main.models.credential')
credential_plugins = {entry_point.name: entry_point.load() for entry_point in entry_points(group='awx_plugins.credentials')}

HIDDEN_PASSWORD = '**********'

Expand Down Expand Up @@ -472,7 +474,7 @@ def default_for_field(self, field_id):

@classproperty
def defaults(cls):
return dict((k, functools.partial(v.create)) for k, v in ManagedCredentialType.registry.items())
return dict((k, functools.partial(CredentialTypeHelper.create, v)) for k, v in ManagedCredentialType.registry.items())

@classmethod
def _get_credential_type_class(cls, apps: Apps = None, app_config: AppConfig = None):
Expand Down Expand Up @@ -507,7 +509,7 @@ def _setup_tower_managed_defaults(cls, apps: Apps = None, app_config: AppConfig
existing.save()
continue
logger.debug(_("adding %s credential type" % default.name))
params = default.get_creation_params()
params = CredentialTypeHelper.get_creation_params(default)
if 'managed' not in [f.name for f in ct_class._meta.get_fields()]:
params['managed_by_tower'] = params.pop('managed')
params['created'] = params['modified'] = now() # CreatedModifiedModel service
Expand Down Expand Up @@ -541,46 +543,37 @@ def setup_tower_managed_defaults(cls, apps: Apps = None, app_config: AppConfig =
@classmethod
def load_plugin(cls, ns, plugin):
# TODO: User "side-loaded" credential custom_injectors isn't supported
ManagedCredentialType(namespace=ns, name=plugin.name, kind='external', inputs=plugin.inputs)
ManagedCredentialType.registry[ns] = ManagedCredentialType(namespace=ns, name=plugin.name, kind='external', inputs=plugin.inputs, injectors={})

def inject_credential(self, credential, env, safe_env, args, private_data_dir):
from awx_plugins.interfaces._temporary_private_inject_api import inject_credential

inject_credential(self, credential, env, safe_env, args, private_data_dir)


class ManagedCredentialType(SimpleNamespace):
registry = {}

def __init__(self, namespace, **kwargs):
for k in ('inputs', 'injectors'):
if k not in kwargs:
kwargs[k] = {}
super(ManagedCredentialType, self).__init__(namespace=namespace, **kwargs)
if namespace in ManagedCredentialType.registry:
raise ValueError(
'a ManagedCredentialType with namespace={} is already defined in {}'.format(
namespace, inspect.getsourcefile(ManagedCredentialType.registry[namespace].__class__)
)
)
ManagedCredentialType.registry[namespace] = self

def get_creation_params(self):
class CredentialTypeHelper:
@classmethod
def get_creation_params(cls, cred_type):
return dict(
namespace=self.namespace,
kind=self.kind,
name=self.name,
namespace=cred_type.namespace,
kind=cred_type.kind,
name=cred_type.name,
managed=True,
inputs=self.inputs,
injectors=self.injectors,
inputs=cred_type.inputs,
injectors=cred_type.injectors,
)

def create(self):
res = CredentialType(**self.get_creation_params())
res.custom_injectors = getattr(self, 'custom_injectors', None)
@classmethod
def create(cls, cred_type):
res = CredentialType(**CredentialTypeHelper.get_creation_params(cred_type))
res.custom_injectors = getattr(cred_type, "custom_injectors", None)
return res


class ManagedCredentialType(SimpleNamespace):
registry = {}


class CredentialInputSource(PrimordialModel):
class Meta:
app_label = 'main'
Expand Down Expand Up @@ -645,7 +638,30 @@ def get_absolute_url(self, request=None):
return reverse(view_name, kwargs={'pk': self.pk}, request=request)


from awx_plugins.credentials.plugins import * # noqa
def load_credentials():

awx_entry_points = {ep.name: ep for ep in entry_points(group='awx_plugins.managed_credentials')}
supported_entry_points = {ep.name: ep for ep in entry_points(group='awx_plugins.managed_credentials.supported')}
plugin_entry_points = awx_entry_points if detect_server_product_name() == 'AWX' else {**awx_entry_points, **supported_entry_points}

for ns, ep in plugin_entry_points.items():
cred_plugin = ep.load()
if not hasattr(cred_plugin, 'inputs'):
setattr(cred_plugin, 'inputs', {})
if not hasattr(cred_plugin, 'injectors'):
setattr(cred_plugin, 'injectors', {})
if ns in ManagedCredentialType.registry:
raise ValueError(
'a ManagedCredentialType with namespace={} is already defined in {}'.format(
ns, inspect.getsourcefile(ManagedCredentialType.registry[ns].__class__)
)
)
ManagedCredentialType.registry[ns] = cred_plugin

credential_plugins = {ep.name: ep for ep in entry_points(group='awx_plugins.credentials')}
if detect_server_product_name() == 'AWX':
credential_plugins = {}

for ns, plugin in credential_plugins.items():
CredentialType.load_plugin(ns, plugin)
for ns, ep in credential_plugins.items():
plugin = ep.load()
CredentialType.load_plugin(ns, plugin)
9 changes: 9 additions & 0 deletions awx/main/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,12 @@ def mock_me():
me_mock = mock.MagicMock(return_value=Instance(id=1, hostname=settings.CLUSTER_HOST_ID, uuid='00000000-0000-0000-0000-000000000000'))
with mock.patch.object(Instance.objects, 'me', me_mock):
yield


@pytest.fixture(scope="session", autouse=True)
def load_all_credentials():
with mock.patch('awx.main.models.credential.detect_server_product_name', return_value='NOT_AWX'):
from awx.main.models.credential import load_credentials

load_credentials()
yield
16 changes: 12 additions & 4 deletions awx/main/tests/functional/test_inventory_source_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,24 @@ def cleanup_cloudforms():
assert 'cloudforms' not in CredentialType.defaults


@pytest.mark.django_db
def test_cloudforms_inventory_removal(request, inventory):
request.addfinalizer(cleanup_cloudforms)
ManagedCredentialType(
@pytest.fixture
def cloudforms_mct():
ManagedCredentialType.registry['cloudforms'] = ManagedCredentialType(
name='Red Hat CloudForms',
namespace='cloudforms',
kind='cloud',
managed=True,
inputs={},
injectors={},
)
yield
ManagedCredentialType.registry.pop('cloudforms', None)


@pytest.mark.django_db
def test_cloudforms_inventory_removal(request, inventory, cloudforms_mct):
request.addfinalizer(cleanup_cloudforms)

CredentialType.defaults['cloudforms']().save()
cloudforms = CredentialType.objects.get(namespace='cloudforms')
Credential.objects.create(
Expand Down
Loading