Skip to content

Commit

Permalink
Disallow multiple masks on a given column, trino trinodb/trino@bdd1cb5
Browse files Browse the repository at this point in the history
  • Loading branch information
BryanCutler committed Dec 20, 2024
1 parent 0e6e057 commit 1aefdf0
Show file tree
Hide file tree
Showing 12 changed files with 78 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ public class Analysis
private final Map<NodeRef<Table>, List<Expression>> rowFilters = new LinkedHashMap<>();

private final Multiset<ColumnMaskScopeEntry> columnMaskScopes = HashMultiset.create();
private final Map<NodeRef<Table>, Map<String, List<Expression>>> columnMasks = new LinkedHashMap<>();
private final Map<NodeRef<Table>, Map<String, Expression>> columnMasks = new LinkedHashMap<>();

// for create table
private Optional<QualifiedObjectName> createTableDestination = Optional.empty();
Expand Down Expand Up @@ -1047,12 +1047,12 @@ public void unregisterTableForColumnMasking(QualifiedObjectName table, String co

public void addColumnMask(Table table, String column, Expression mask)
{
Map<String, List<Expression>> masks = columnMasks.computeIfAbsent(NodeRef.of(table), node -> new LinkedHashMap<>());
masks.computeIfAbsent(column, name -> new ArrayList<>())
.add(mask);
Map<String, Expression> masks = columnMasks.computeIfAbsent(NodeRef.of(table), node -> new LinkedHashMap<>());
checkArgument(!masks.containsKey(column), "Mask already exists for column %s", column);
masks.put(column, mask);
}

public Map<String, List<Expression>> getColumnMasks(Table table)
public Map<String, Expression> getColumnMasks(Table table)
{
return columnMasks.getOrDefault(NodeRef.of(table), ImmutableMap.of());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.hive.security;

import com.facebook.presto.common.Subfield;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.hive.HiveColumnConverterProvider;
import com.facebook.presto.hive.HiveTransactionManager;
import com.facebook.presto.hive.TransactionalMetadata;
Expand Down Expand Up @@ -288,4 +289,10 @@ public Optional<ViewExpression> getRowFilter(ConnectorTransactionHandle transact
{
return Optional.empty();
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, String columnName, Type type)
{
return Optional.empty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.hive.security;

import com.facebook.presto.common.Subfield;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.hive.HiveColumnConverterProvider;
import com.facebook.presto.hive.HiveConnectorId;
import com.facebook.presto.hive.HiveTransactionManager;
Expand Down Expand Up @@ -688,6 +689,12 @@ public Optional<ViewExpression> getRowFilter(ConnectorTransactionHandle transact
return Optional.empty();
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, String columnName, Type type)
{
return Optional.empty();
}

private boolean isAdmin(ConnectorTransactionHandle transaction, ConnectorIdentity identity, MetastoreContext metastoreContext)
{
SemiTransactionalHiveMetastore metastore = getMetastore(transaction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import java.util.concurrent.atomic.AtomicReference;

import static com.facebook.presto.metadata.MetadataUtil.toSchemaTableName;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_COLUMN_MASK;
import static com.facebook.presto.spi.StandardErrorCode.SERVER_STARTING_UP;
import static com.facebook.presto.util.PropertiesUtil.loadProperties;
import static com.google.common.base.Preconditions.checkArgument;
Expand Down Expand Up @@ -805,7 +806,7 @@ public List<ViewExpression> getRowFilters(TransactionId transactionId, Identity
}

@Override
public List<ViewExpression> getColumnMasks(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName, String columnName, Type type)
public Optional<ViewExpression> getColumnMask(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName, String columnName, Type type)
{
requireNonNull(transactionId, "transactionId is null");
requireNonNull(identity, "identity is null");
Expand All @@ -823,7 +824,12 @@ public List<ViewExpression> getColumnMasks(TransactionId transactionId, Identity
systemAccessControl.get().getColumnMask(identity, context, toCatalogSchemaTableName(tableName), columnName, type)
.ifPresent(masks::add);

return masks.build();
List<ViewExpression> allMasks = masks.build();
if (allMasks.size() > 1) {
throw new PrestoException(INVALID_COLUMN_MASK, format("Column must have a single mask: %s", columnName));
}

return allMasks.stream().findFirst();
}

private CatalogAccessControlEntry getConnectorAccessControl(TransactionId transactionId, String catalogName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.security;

import com.facebook.presto.common.CatalogSchemaName;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.CatalogSchemaTableName;
import com.facebook.presto.spi.SchemaTableName;
import com.facebook.presto.spi.security.AccessControlContext;
Expand Down Expand Up @@ -233,4 +234,10 @@ public Optional<ViewExpression> getRowFilter(Identity identity, AccessControlCon
{
return Optional.empty();
}

@Override
public Optional<ViewExpression> getColumnMask(Identity identity, AccessControlContext context, CatalogSchemaTableName tableName, String columnName, Type type)
{
return Optional.empty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.airlift.log.Logger;
import com.facebook.presto.common.CatalogSchemaName;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.plugin.base.security.ForwardingSystemAccessControl;
import com.facebook.presto.plugin.base.security.SchemaAccessControlRule;
import com.facebook.presto.security.CatalogAccessControlRule.AccessMode;
Expand Down Expand Up @@ -448,6 +449,12 @@ public Optional<ViewExpression> getRowFilter(Identity identity, AccessControlCon
return Optional.empty();
}

@Override
public Optional<ViewExpression> getColumnMask(Identity identity, AccessControlContext context, CatalogSchemaTableName tableName, String columnName, Type type)
{
return Optional.empty();
}

private boolean isSchemaOwner(Identity identity, CatalogSchemaName schema)
{
if (!canAccessCatalog(identity, schema.getCatalogName(), ALL)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@
import com.facebook.presto.spi.PrestoWarning;
import com.facebook.presto.spi.SchemaTableName;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.TableMetadata;
import com.facebook.presto.spi.TableNotFoundException;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.analyzer.AccessControlInfoForTable;
import com.facebook.presto.spi.analyzer.MetadataResolver;
Expand Down Expand Up @@ -423,7 +421,7 @@ protected Scope visitInsert(Insert insert, Optional<Scope> scope)
List<ColumnMetadata> columnsMetadata = tableColumnsMetadata.getColumnsMetadata();

for (ColumnMetadata column : columnsMetadata) {
if (!accessControl.getColumnMasks(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), targetTable, column.getName(), column.getType()).isEmpty()) {
if (accessControl.getColumnMask(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), targetTable, column.getName(), column.getType()).isPresent()) {
throw new SemanticException(NOT_SUPPORTED, insert, "Insert into table with column masks is not supported");
}
}
Expand Down Expand Up @@ -618,11 +616,11 @@ protected Scope visitDelete(Delete node, Optional<Scope> scope)
if (!accessControl.getRowFilters(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), tableName).isEmpty()) {
throw new SemanticException(NOT_SUPPORTED, node, "Delete from table with row filter is not supported");
}

TableColumnMetadata tableColumnsMetadata = getTableColumnsMetadata(session, metadataResolver, analysis.getMetadataHandle(), tableName);
List<ColumnMetadata> columnsMetadata = tableColumnsMetadata.getColumnsMetadata();
for (ColumnMetadata columnMetadata : columnsMetadata) {
if (!accessControl.getColumnMasks(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), tableName, columnMetadata.getName(), columnMetadata.getType()).isEmpty()) {
if (accessControl.getColumnMask(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), tableName, columnMetadata.getName(), columnMetadata.getType()).isPresent()) {
throw new SemanticException(NOT_SUPPORTED, node, "Delete from table with column mask is not supported");
}
}
Expand Down Expand Up @@ -1394,13 +1392,15 @@ protected Scope visitTable(Table table, Optional<Scope> scope)
.build();

for (Field field : outputFields) {
accessControl.getColumnMasks(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), name, field.getName().get(), field.getType())
.forEach(mask -> analyzeColumnMask(session.getIdentity().getUser(), table, name, field, accessControlScope, mask));
Optional<ViewExpression> mask = accessControl.getColumnMask(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), name, field.getName().get(), field.getType());
mask.ifPresent(viewExpression -> analyzeColumnMask(session.getIdentity().getUser(), table, name, field, accessControlScope, viewExpression));
}

accessControl.getRowFilters(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), name)
.forEach(filter -> analyzeRowFilter(session.getIdentity().getUser(), table, name, accessControlScope, filter));

analysis.registerTable(table, tableHandle.get());

if (statement instanceof RefreshMaterializedView) {
Table view = ((RefreshMaterializedView) statement).getTarget();
if (!table.equals(view) && !analysis.hasTableRegisteredForMaterializedView(view, table)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,24 +246,24 @@ private RelationPlan addRowFilters(Table node, RelationPlan plan, SqlPlannerCont

private RelationPlan addColumnMasks(Table table, RelationPlan plan, SqlPlannerContext context)
{
Map<String, List<Expression>> columnMasks = analysis.getColumnMasks(table);
Map<String, Expression> columnMasks = analysis.getColumnMasks(table);

PlanNode root = plan.getRoot();
List<VariableReferenceExpression> mappings = plan.getFieldMappings();

TranslationMap translations = new TranslationMap(plan, analysis, lambdaDeclarationToVariableMap);
translations.setFieldMappings(mappings);

PlanBuilder planBuilder = new PlanBuilder(translations, root);
PlanBuilder planBuilder = new PlanBuilder(translations, plan.getRoot());

for (int i = 0; i < plan.getDescriptor().getAllFieldCount(); i++) {
Field field = plan.getDescriptor().getFieldByIndex(i);

for (Expression mask : columnMasks.getOrDefault(field.getName().get(), ImmutableList.of())) {
if (field.getName().isPresent() && columnMasks.containsKey(field.getName().get())) {
Expression mask = columnMasks.get(field.getName().get());

planBuilder = subqueryPlanner.handleSubqueries(planBuilder, mask, mask, context);

Map<VariableReferenceExpression, RowExpression> assignments = new LinkedHashMap<>();
for (VariableReferenceExpression variableReferenceExpression : root.getOutputVariables()) {
for (VariableReferenceExpression variableReferenceExpression : planBuilder.getRoot().getOutputVariables()) {
assignments.put(variableReferenceExpression, rowExpression(new SymbolReference(variableReferenceExpression.getName()), context));
}
assignments.put(mappings.get(i), rowExpression(translations.rewrite(mask), context));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public class TestingAccessControlManager
{
private final Set<TestingPrivilege> denyPrivileges = new HashSet<>();
private final Map<RowFilterKey, List<ViewExpression>> rowFilters = new HashMap<>();
private final Map<ColumnMaskKey, List<ViewExpression>> columnMasks = new HashMap<>();
private final Map<ColumnMaskKey, ViewExpression> columnMasks = new HashMap<>();

@Inject
public TestingAccessControlManager(TransactionManager transactionManager)
Expand Down Expand Up @@ -133,8 +133,7 @@ public void rowFilter(QualifiedObjectName table, String identity, ViewExpression

public void columnMask(QualifiedObjectName table, String column, String identity, ViewExpression mask)
{
columnMasks.computeIfAbsent(new ColumnMaskKey(identity, table, column), key -> new ArrayList<>())
.add(mask);
columnMasks.put(new ColumnMaskKey(identity, table, column), mask);
}

@Override
Expand Down Expand Up @@ -394,9 +393,11 @@ public List<ViewExpression> getRowFilters(TransactionId transactionId, Identity
}

@Override
public List<ViewExpression> getColumnMasks(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName, String column, Type type)
public Optional<ViewExpression> getColumnMask(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName, String column, Type type)
{
return columnMasks.getOrDefault(new ColumnMaskKey(identity.getUser(), tableName, column), ImmutableList.of());
return Optional.ofNullable(
columnMasks.getOrDefault(new ColumnMaskKey(identity.getUser(), tableName, column),
super.getColumnMask(transactionId, identity, context, tableName, column, type).orElse(null)));
}

private boolean shouldDenyPrivilege(String userName, String entityName, TestingPrivilegeType type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,10 @@

import java.security.Principal;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.spi.ConnectorId.createInformationSchemaConnectorId;
import static com.facebook.presto.spi.ConnectorId.createSystemTablesConnectorId;
import static com.facebook.presto.spi.security.AccessDeniedException.denyQueryIntegrityCheck;
Expand Down Expand Up @@ -314,14 +312,6 @@ public Optional<ViewExpression> getColumnMask(ConnectorTransactionHandle transac
return Optional.of(new ViewExpression("user", Optional.empty(), Optional.empty(), "connector mask"));
}
});

transaction(transactionManager, accessControlManager)
.execute(transactionId -> {
List<ViewExpression> masks = accessControlManager.getColumnMasks(transactionId, new Identity(USER_NAME, Optional.of(PRINCIPAL)),
new AccessControlContext(new QueryId(QUERY_ID), Optional.empty(), Collections.emptySet(), Optional.empty(), WarningCollector.NOOP, new RuntimeStats(), Optional.empty()), new QualifiedObjectName("catalog", "schema", "table"), "column", BIGINT);
assertEquals(masks.get(0).getExpression(), "connector mask");
assertEquals(masks.get(1).getExpression(), "system mask");
});
}

private static ConnectorId registerBogusConnector(CatalogManager catalogManager, TransactionManager transactionManager, AccessControl accessControl, String catalogName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,24 +87,22 @@ public void testSimpleMask()
}

@Test
public void testMultipleMasks()
public void testMultipleMasksOnDifferentColumns()
{
assertions.executeExclusively(() -> {
accessControl.reset();
accessControl.columnMask(
new QualifiedObjectName(CATALOG, "tiny", "orders"),
"custkey",
USER,
new ViewExpression(USER, Optional.empty(), Optional.empty(), "-custkey"));

accessControl.columnMask(
new QualifiedObjectName(CATALOG, "tiny", "orders"),
"custkey",
USER,
new ViewExpression(USER, Optional.empty(), Optional.empty(), "custkey * 2"));

assertions.assertQuery("SELECT custkey FROM orders WHERE orderkey = 1", "VALUES BIGINT '-740'");
});
accessControl.reset();
accessControl.columnMask(
new QualifiedObjectName(CATALOG, "tiny", "orders"),
"custkey",
USER,
new ViewExpression(USER, Optional.empty(), Optional.empty(), "-custkey"));

accessControl.columnMask(
new QualifiedObjectName(CATALOG, "tiny", "orders"),
"orderstatus",
USER,
new ViewExpression(USER, Optional.empty(), Optional.empty(), "'X'"));

assertions.assertQuery("SELECT custkey, orderstatus FROM orders WHERE orderkey = 1", "VALUES (BIGINT '-370', 'X')");
}

@Test
Expand Down Expand Up @@ -379,4 +377,4 @@ public void testDeleteWithColumnMasking()
assertions.assertFails("DELETE FROM orders", "\\Qline 1:1: Delete from table with column mask is not supported\\E");
});
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,8 @@ default List<ViewExpression> getRowFilters(TransactionId transactionId, Identity
return Collections.emptyList();
}

default List<ViewExpression> getColumnMasks(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName, String columnName, Type type)
default Optional<ViewExpression> getColumnMask(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName, String columnName, Type type)
{
return Collections.emptyList();
return Optional.empty();
}
}

0 comments on commit 1aefdf0

Please sign in to comment.