Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function management #62

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions pgbedrock/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import click
import psycopg2
import re


logger = logging.getLogger(__name__)
Expand All @@ -19,6 +20,7 @@
FAILED_QUERY_MSG = 'Failed to execute query "{}": {}'
UNSUPPORTED_CHAR_MSG = 'Role "{}" contains an unsupported character: \' or "'
PROGRESS_TEMPLATE = '%(label)s [%(bar)s] %(info)s'
FUNCTION_PARSING_RE = re.compile('^("[^"]+"|[^(]+)((.*))$')


def check_name(name):
Expand Down Expand Up @@ -76,24 +78,29 @@ class ObjectName(object):
* Be sure that when we get the fully-qualified name it will be double quoted
properly, i.e. "myschema"."mytable"
"""
def __init__(self, schema, unqualified_name=None):
def __init__(self, schema, unqualified_name=None, object_args=None):
# Make sure schema and table are both stored without double quotes around
# them; we add these when ObjectName.qualified_name is called
self._schema = self._unquoted_item(schema)
self._unqualified_name = self._unquoted_item(unqualified_name)
self._object_args = object_args or ''

if self._unqualified_name and self._unqualified_name == '*':
self._qualified_name = '{}.{}'.format(self.schema, self.unqualified_name)
elif self._unqualified_name and self._unqualified_name != '*':
# Note that if we decide to support "schema"."table" within YAML that we'll need to
# add a custom constructor since otherwise YAML gets confused unless you do
# '"schema"."table"'
self._qualified_name = '{}."{}"'.format(self.schema, self.unqualified_name)
self._qualified_name = '{}."{}"{}'.format(self.schema,
self.unqualified_name,
self.object_args)
else:
self._qualified_name = '{}'.format(self.schema)

def __eq__(self, other):
return (self.schema == other.schema) and (self.unqualified_name == other.unqualified_name)
return ((self.schema == other.schema) and
(self.unqualified_name == other.unqualified_name) and
(self.object_args == other.object_args))

def __hash__(self):
return hash(self.qualified_name)
Expand All @@ -103,12 +110,17 @@ def __lt__(self, other):

def __repr__(self):
if self.unqualified_name:
return "ObjectName('{}', '{}')".format(self.schema, self.unqualified_name)

if self.object_args:
return "ObjectName('{}', '{}', '{}')".format(self.schema,
self.unqualified_name,
self.object_args)
else:
return "ObjectName('{}', '{}')".format(self.schema,
self.unqualified_name)
return "ObjectName('{}')".format(self.schema)

@classmethod
def from_str(cls, text):
def from_str(cls, text, kind=None):
""" Convert a text representation of a qualified object name into an ObjectName instance

For example, 'foo.bar', '"foo".bar', '"foo"."bar"', etc. will be converted an object with
Expand All @@ -123,8 +135,14 @@ def from_str(cls, text):
# If there are multiple periods we assume that the first one delineates the schema from
# the rest of the object, i.e. foo.bar.baz means schema foo and object "bar.baz"
schema, unqualified_name = text.split('.', 1)
object_args = None
if kind == 'functions' and unqualified_name != '*':
groups = FUNCTION_PARSING_RE.match(unqualified_name).groups()
if len(groups) == 2:
unqualified_name, object_args = groups
# Don't worry about removing double quotes as that happens in __init__
return cls(schema=schema, unqualified_name=unqualified_name)
return cls(schema=schema, unqualified_name=unqualified_name,
object_args=object_args)

def only_schema(self):
""" Return an ObjectName instance for the schema associated with the current object """
Expand All @@ -134,6 +152,10 @@ def only_schema(self):
def schema(self):
return self._schema

@property
def object_args(self):
return self._object_args

@property
def unqualified_name(self):
return self._unqualified_name
Expand Down
59 changes: 50 additions & 9 deletions pgbedrock/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
SELECT
nsp.nspname AS schema,
c.relname AS unqualified_name,
NULL::text AS object_args,
map.objkind,
(aclexplode(c.relacl)).grantee AS grantee_oid,
t_owner.rolname AS owner,
Expand All @@ -78,25 +79,42 @@
SELECT
nsp.nspname AS schema,
NULL::TEXT AS unqualified_name,
NULL::text AS object_args,
'schemas'::TEXT AS objkind,
(aclexplode(nsp.nspacl)).grantee AS grantee_oid,
t_owner.rolname AS owner,
(aclexplode(nsp.nspacl)).privilege_type
FROM pg_namespace nsp
JOIN pg_authid t_owner
ON nsp.nspowner = t_owner.OID
), functions AS (
SELECT
nsp.nspname AS schema,
proname AS unqualified_name,
'(' || pg_get_function_identity_arguments(p.oid) || ')' AS object_args,
'functions'::TEXT as objkind,
(aclexplode(p.proacl)).grantee AS grantee_oid,
t_owner.rolname AS owner,
(aclexplode(p.proacl)).privilege_type
FROM pg_proc p
JOIN pg_namespace nsp on nsp.oid = pronamespace
JOIN pg_authid t_owner ON p.proowner = t_owner.oid
), combined AS (
SELECT *
FROM tables_and_sequences
UNION ALL
SELECT *
FROM schemas
UNION ALL
SELECT *
FROM functions
)
SELECT
t_grantee.rolname AS grantee,
combined.objkind,
combined.schema,
combined.unqualified_name,
combined.object_args,
combined.privilege_type
FROM
combined
Expand Down Expand Up @@ -149,6 +167,7 @@
map.kind,
nsp.nspname AS schema,
c.relname AS unqualified_name,
NULL::TEXT AS object_args,
c.relowner AS owner_id,
-- Auto-dependency means that a sequence is linked to a table. Ownership of
-- that sequence automatically derives from the table's ownership
Expand All @@ -174,20 +193,35 @@
'schemas'::TEXT AS kind,
nsp.nspname AS schema,
NULL::TEXT AS unqualified_name,
NULL::TEXT AS object_args,
nsp.nspowner AS owner_id,
FALSE AS is_dependent
FROM pg_namespace nsp
), functions AS (
SELECT
'functions'::TEXT as kind,
nsp.nspname AS schema,
proname as unqualified_name,
'(' || pg_get_function_identity_arguments(p.oid) || ')' as object_args,
p.proowner as owner_id,
FALSE AS is_dependent
FROM pg_proc p
JOIN pg_namespace nsp on nsp.oid = pronamespace
), combined AS (
SELECT *
FROM tables_and_sequences
UNION ALL
SELECT *
FROM schemas
UNION ALL
SELECT *
FROM functions
)
SELECT
co.kind,
co.schema,
co.unqualified_name,
co.object_args,
t_owner.rolname AS owner,
co.is_dependent
FROM combined AS co
Expand All @@ -209,10 +243,9 @@

Q_GET_VERSIONS = """
SELECT
substring(version from 'PostgreSQL ([0-9.]*) ') AS postgres_version,
substring(version from 'Redshift ([0-9.]*)') AS redshift_version,
version LIKE '%Redshift%' AS is_redshift
FROM version()
current_setting('server_version_num')::int as postgres_version,
substring(version() from 'Redshift ([0-9.]*)') AS redshift_version,
version() LIKE '%Redshift%' AS is_redshift
;
"""

Expand All @@ -231,6 +264,9 @@
{'read': ('USAGE', ),
'write': ('CREATE', )
},
'functions':
{'read': ('EXECUTE', ),
'write': ()}
}

ObjectInfo = namedtuple('ObjectInfo', ['kind', 'objname', 'owner', 'is_dependent'])
Expand Down Expand Up @@ -376,7 +412,8 @@ def get_all_current_nondefaults(self):
This will not include privileges granted by this role to itself
"""
NamedRow = namedtuple('NamedRow',
['grantee', 'objkind', 'schema', 'unqualified_name', 'privilege'])
['grantee', 'objkind', 'schema',
'unqualified_name', 'object_args', 'privilege'])
common.run_query(self.cursor, self.verbose, Q_GET_ALL_CURRENT_NONDEFAULTS)
current_nondefaults = defaultdict(dict)

Expand All @@ -392,8 +429,9 @@ def get_all_current_nondefaults(self):
'read': set(),
'write': set(),
}

objname = common.ObjectName(schema=row.schema, unqualified_name=row.unqualified_name)
objname = common.ObjectName(schema=row.schema,
unqualified_name=row.unqualified_name,
object_args=row.object_args)
entry = (objname, row.privilege)
role_nondefaults[row.objkind][access_key].add(entry)

Expand Down Expand Up @@ -436,10 +474,13 @@ def get_all_raw_object_attributes(self):
"""
common.run_query(self.cursor, self.verbose, Q_GET_ALL_RAW_OBJECT_ATTRIBUTES)
results = []
NamedRow = namedtuple('NamedRow', ['kind', 'schema', 'unqualified_name', 'owner', 'is_dependent'])
NamedRow = namedtuple('NamedRow', ['kind', 'schema', 'unqualified_name',
'object_args', 'owner', 'is_dependent'])
for i in self.cursor.fetchall():
row = NamedRow(*i)
objname = common.ObjectName(schema=row.schema, unqualified_name=row.unqualified_name)
objname = common.ObjectName(schema=row.schema,
unqualified_name=row.unqualified_name,
object_args=row.object_args)
entry = ObjectAttributes(row.kind, row.schema, objname, row.owner, row.is_dependent)
results.append(entry)
return results
Expand Down
6 changes: 4 additions & 2 deletions pgbedrock/spec_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
- schemas
- tables
- sequences
- functions
valueschema:
type: list
schema:
Expand All @@ -74,6 +75,7 @@
- schemas
- sequences
- tables
- functions
valueschema:
type: dict
allowed:
Expand Down Expand Up @@ -103,14 +105,14 @@ def convert_spec_to_objectnames(spec):
for objkind, owned_items in config.get('owns', {}).items():
if not owned_items:
continue
converted = [common.ObjectName.from_str(item) for item in owned_items]
converted = [common.ObjectName.from_str(item, objkind) for item in owned_items]
config['owns'][objkind] = converted

for objkind, perm_dicts in config.get('privileges', {}).items():
for priv_kind, granted_items in perm_dicts.items():
if not granted_items:
continue
converted = [common.ObjectName.from_str(item) for item in granted_items]
converted = [common.ObjectName.from_str(item, objkind) for item in granted_items]
config['privileges'][objkind][priv_kind] = converted

return output_spec
Expand Down
18 changes: 15 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def drop_users_and_objects(cursor):
WHERE rolname NOT IN (
'test_user', 'postgres', 'pg_signal_backend',
-- Roles introduced in Postgres 10:
'pg_monitor', 'pg_read_all_settings', 'pg_read_all_stats', 'pg_stat_scan_tables'
'pg_monitor', 'pg_read_all_settings', 'pg_read_all_stats', 'pg_stat_scan_tables',
-- Roles introduced in Postgres 11:
'pg_execute_server_program', 'pg_read_server_files', 'pg_write_server_files'
);
""")
users = [u[0] for u in cursor.fetchall()]
Expand Down Expand Up @@ -96,6 +98,9 @@ def base_spec(cursor):
tables:
- information_schema.*
- pg_catalog.*
functions:
- pg_catalog.*
- information_schema.*
privileges:
schemas:
write:
Expand All @@ -111,9 +116,9 @@ def base_spec(cursor):
""")

# Postgres 10 introduces several new roles that we have to account for
cursor.execute("SELECT substring(version from 'PostgreSQL ([0-9.]*) ') FROM version()")
cursor.execute("SELECT current_setting('server_version_num')::int")
pg_version = cursor.fetchone()[0]
if pg_version.startswith('10.'):
if pg_version >= 100000:
spec += dedent("""

pg_read_all_settings:
Expand All @@ -128,7 +133,14 @@ def base_spec(cursor):
- pg_stat_scan_tables
- pg_read_all_stats
""")
if pg_version >= 110000:
spec += dedent("""
pg_execute_server_program:

pg_read_server_files:

pg_write_server_files:
""")
return spec


Expand Down
15 changes: 12 additions & 3 deletions tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@

Q_CREATE_TABLE = 'SET ROLE {}; CREATE TABLE {}.{} AS (SELECT 1+1); RESET ROLE;'
Q_CREATE_SEQUENCE = 'SET ROLE {}; CREATE SEQUENCE {}.{}; RESET ROLE;'
Q_CREATE_FUNCTION = 'SET ROLE {}; CREATE FUNCTION {}.{} RETURNS VOID AS $$$$ language SQL; RESET ROLE;'
Q_HAS_PRIVILEGE = "SELECT has_table_privilege('{}', '{}', 'SELECT');"

SCHEMAS = tuple('schema{}'.format(i) for i in range(4))
ROLES = tuple('role{}'.format(i) for i in range(4))
TABLES = tuple('table{}'.format(i) for i in range(6))
SEQUENCES = tuple('seq{}'.format(i) for i in range(6))
FUNCTIONS = tuple(['func()', 'func(int, text)', 'func2(int)'])
DUMMY = 'foo'


Expand All @@ -32,6 +34,10 @@
# Grant default privileges to role0 from role3 for this schema; these should get
# revoked in our test
privs.Q_GRANT_DEFAULT.format(ROLES[3], SCHEMAS[0], 'SELECT', 'TABLES', ROLES[0]),

# Add some functions
Q_CREATE_FUNCTION.format(ROLES[2], SCHEMAS[0], FUNCTIONS[0]),
Q_CREATE_FUNCTION.format(ROLES[2], SCHEMAS[0], FUNCTIONS[1]),
]
)
def test_get_all_current_defaults(cursor):
Expand Down Expand Up @@ -325,11 +331,14 @@ def test_get_all_role_attributes(cursor):
expected = set(['test_user', 'postgres', ROLES[0], ROLES[1]])
pg_version = dbcontext.get_version_info().postgres_version
# Postgres 10 introduces several new roles that we have to account for
if pg_version.startswith('10.'):
if pg_version >= 100000:
expected.update(set([
'pg_read_all_settings', 'pg_stat_scan_tables', 'pg_read_all_stats', 'pg_monitor']
))

if pg_version >= 110000:
expected.update(set([
'pg_execute_server_program', 'pg_read_server_files', 'pg_write_server_files']
))
actual = dbcontext.get_all_role_attributes()
assert set(actual.keys()) == expected

Expand Down Expand Up @@ -422,7 +431,7 @@ def test_get_all_memberships(cursor):
expected = set([('role1', 'role0'), ('role2', 'role1')])
pg_version = dbcontext.get_version_info().postgres_version
# Postgres 10 introduces several new roles and memberships that we have to account for
if pg_version.startswith('10.'):
if pg_version >= 100000:
expected.update(set([
('pg_monitor', 'pg_stat_scan_tables'),
('pg_monitor', 'pg_read_all_stats'),
Expand Down
Loading