Mypy checks for NNCF common pruning code #2613

Open · wants to merge 14 commits into base: develop
.mypy.ini (6 changes: 1 addition & 5 deletions)
@@ -1,9 +1,5 @@
[mypy]
files = nncf/common/sparsity, nncf/common/graph, nncf/common/accuracy_aware_training/, nncf/common/utils/
files = nncf/common/sparsity, nncf/common/graph, nncf/common/accuracy_aware_training/ , nncf/common/pruning, nncf/common/utils/
follow_imports = silent
strict = True

# should be removed later
# mypy recommends the following tool as an autofix:
# https://github.com/hauntsaninja/no_implicit_optional
implicit_optional = True
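The `files` list now also covers `nncf/common/pruning`, so that package is checked with `strict = True`. Strict mode drives most of the edits below; among other things it requires every function, including ones that return nothing, to carry an explicit return annotation. A minimal sketch of what that looks like, with hypothetical names rather than code from the repository:

```python
from typing import Dict


class Counter:
    # Under strict mypy, __init__ and other procedures need an explicit -> None.
    def __init__(self) -> None:
        self._counts: Dict[str, int] = {}

    def bump(self, key: str) -> None:
        self._counts[key] = self._counts.get(key, 0) + 1

    def value(self, key: str) -> int:
        return self._counts.get(key, 0)
```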
nncf/common/pruning/clusterization.py (18 changes: 9 additions & 9 deletions)
@@ -19,16 +19,16 @@ class Cluster(Generic[T]):
Represents element of Clusterization. Groups together elements.
"""

def __init__(self, cluster_id: int, elements: List[T], nodes_orders: List[int]):
def __init__(self, cluster_id: int, elements: List[T], nodes_orders: List[int]) -> None:
self.id = cluster_id
self.elements = list(elements)
self.importance = max(nodes_orders)

def clean_cluster(self):
def clean_cluster(self) -> None:
self.elements = []
self.importance = 0

def add_elements(self, elements: List[T], importance: int):
def add_elements(self, elements: List[T], importance: int) -> None:
self.elements.extend(elements)
self.importance = max(self.importance, importance)

@@ -39,11 +39,11 @@ class Clusterization(Generic[T]):
delete existing one or merge existing clusters.
"""

def __init__(self, id_fn: Callable[[T], Hashable] = None):
def __init__(self, id_fn: Callable[[T], Hashable] = None) -> None:
self.clusters: Dict[int, Cluster[T]] = {}
self._element_to_cluster: Dict[Hashable, int] = {}
if id_fn is None:
self._id_fn = lambda x: x.id
self._id_fn: Callable[[T], Hashable] = lambda x: x.id # type:ignore
else:
self._id_fn = id_fn

@@ -78,7 +78,7 @@ def is_node_in_clusterization(self, node_id: int) -> bool:
"""
return node_id in self._element_to_cluster

def add_cluster(self, cluster: Cluster[T]):
def add_cluster(self, cluster: Cluster[T]) -> None:
"""
Adds provided cluster to clusterization.

@@ -91,7 +91,7 @@ def add_cluster(self, cluster: Cluster[T]):
for elt in cluster.elements:
self._element_to_cluster[self._id_fn(elt)] = cluster_id

def delete_cluster(self, cluster_id: int):
def delete_cluster(self, cluster_id: int) -> None:
"""
Removes cluster with `cluster_id` from clusterization.

@@ -123,7 +123,7 @@ def get_all_nodes(self) -> List[T]:
all_elements.extend(cluster.elements)
return all_elements

def merge_clusters(self, first_id: int, second_id: int):
def merge_clusters(self, first_id: int, second_id: int) -> None:
"""
Merges two clusters with provided ids.

@@ -143,7 +143,7 @@ def merge_clusters(self, first_id: int, second_id: int):
self._element_to_cluster[self._id_fn(elt)] = second_id
self.clusters.pop(first_id)

def merge_list_of_clusters(self, clusters: List[int]):
def merge_list_of_clusters(self, clusters: List[int]) -> None:
"""
Merges provided clusters.

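Most of this file's changes simply add `-> None` to procedures. The one `# type:ignore` guards the lambda fallback in `__init__`: a bare `TypeVar` gives mypy no evidence that `x.id` exists. As a standalone sketch, one way to type that pattern without the ignore is a hypothetical `Protocol` bound (not what the PR does, just an illustration):

```python
from typing import Callable, Dict, Generic, Hashable, Optional, Protocol, TypeVar


class HasId(Protocol):
    id: Hashable


T = TypeVar("T", bound=HasId)


class Groups(Generic[T]):
    """Hypothetical stand-in for Clusterization, not the NNCF class."""

    def __init__(self, id_fn: Optional[Callable[[T], Hashable]] = None) -> None:
        self._by_id: Dict[Hashable, T] = {}
        # With T bound to HasId, the lambda fallback type-checks without an ignore.
        self._id_fn: Callable[[T], Hashable] = id_fn if id_fn is not None else (lambda x: x.id)

    def add(self, element: T) -> None:
        self._by_id[self._id_fn(element)] = element
```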
nncf/common/pruning/mask_propagation.py (35 changes: 21 additions & 14 deletions)
@@ -51,7 +51,7 @@ def __init__(
self._pruning_operator_metatypes = pruning_operator_metatypes
self._tensor_processor = tensor_processor

def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp:
def get_meta_operation_by_type_name(self, type_name: str) -> Type[BasePruningOp]:
"""
Returns class of metaop that corresponds to `type_name` type.

@@ -61,16 +61,16 @@ def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp:
cls = self._pruning_operator_metatypes.get_operator_metatype_by_op_name(type_name)
if cls is None:
cls = self._pruning_operator_metatypes.registry_dict["stop_propagation_ops"]
return cls
return cls # type:ignore

def mask_propagation(self):
def mask_propagation(self) -> None:
"""
Mask propagation in graph:
to propagate masks run method mask_propagation (of metaop of current node) on all nodes in topological order.
"""
for node in self._graph.topological_sort():
cls = self.get_meta_operation_by_type_name(node.node_type)
cls.mask_propagation(node, self._graph, self._tensor_processor)
cls.mask_propagation(node, self._graph, self._tensor_processor) # type:ignore

def symbolic_mask_propagation(
self, prunable_layers_types: List[str], can_prune_after_analysis: Dict[int, PruningAnalysisDecision]
@@ -96,24 +96,31 @@ def symbolic_mask_propagation(
"""

can_be_closing_convs = self._get_can_closing_convs(prunable_layers_types)
can_prune_by_dim = {k: None for k in can_be_closing_convs}
can_prune_by_dim: Dict[int, PruningAnalysisDecision] = {
k: PruningAnalysisDecision(False, PruningAnalysisReason.CLOSING_CONV_MISSING) for k in can_be_closing_convs
}
for node in self._graph.topological_sort():
if node.node_id in can_be_closing_convs and can_prune_after_analysis[node.node_id]:
if node.node_id in can_be_closing_convs and can_prune_after_analysis.get(node.node_id, None): # type:ignore
# Set output mask
node.attributes["output_mask"] = SymbolicMask(get_output_channels(node), node.node_id)
# Propagate masks
cls = self.get_meta_operation_by_type_name(node.node_type)
cls.mask_propagation(node, self._graph, SymbolicMaskProcessor)
if cls is not None:
cls.mask_propagation(node, self._graph, SymbolicMaskProcessor)
if node.node_id in can_be_closing_convs:
# Check input mask producers out channel dimension
input_masks = get_input_masks(node, self._graph)
if any(input_masks):
assert len(input_masks) == 1
input_mask: SymbolicMask = input_masks[0]
input_mask: SymbolicMask = input_masks[0] # type:ignore

for producer in input_mask.mask_producers:
previously_dims_equal = (
True if can_prune_by_dim[producer.id] is None else can_prune_by_dim[producer.id]
True
if can_prune_by_dim[producer.id].decision
else False
if can_prune_by_dim[producer.id] is not None
else False
)

is_dims_equal = get_input_channels(node) == input_mask.shape[0]
@@ -124,13 +131,13 @@ def symbolic_mask_propagation(
# Remove all convolutions with masks
# that were propagated to output node
for out_node in self._graph.get_output_nodes():
for input_mask in get_input_masks(out_node, self._graph):
for input_mask in get_input_masks(out_node, self._graph): # type:ignore
if input_mask:
for producer in input_mask.mask_producers:
can_prune_by_dim[producer.id] = PruningAnalysisDecision(False, PruningAnalysisReason.LAST_CONV)
# Update decision for nodes which
# have no closing convolution
convs_without_closing_conv = {}
convs_without_closing_conv = {} # type:ignore
for k, v in can_prune_by_dim.items():
if v is None:
convs_without_closing_conv[k] = PruningAnalysisDecision(
@@ -144,11 +151,11 @@ def symbolic_mask_propagation(

return can_prune_by_dim

def _get_can_closing_convs(self, prunable_layers_types) -> Dict:
retval = set()
def _get_can_closing_convs(self, prunable_layers_types: List[str]) -> Dict[int, bool]:
retval = {}
for node in self._graph.get_all_nodes():
if node.node_type in prunable_layers_types and not (
is_grouped_conv(node) or is_batched_linear(node, self._graph)
):
retval.add(node.node_id)
retval[node.node_id] = True
return retval
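The signature change from `-> BasePruningOp` to `-> Type[BasePruningOp]` matches how the result is used: the registry yields classes, and classmethods such as `mask_propagation` are invoked on them directly. A self-contained sketch of that shape, with hypothetical names rather than the NNCF registry API:

```python
from typing import Dict, Optional, Type


class BaseOp:
    @classmethod
    def apply(cls) -> str:
        return cls.__name__


class StopPropagationOp(BaseOp):
    pass


_REGISTRY: Dict[str, Type[BaseOp]] = {"stop_propagation_ops": StopPropagationOp}


def lookup(type_name: str) -> Type[BaseOp]:
    """Return the registered class, falling back to the stop-propagation op."""
    cls: Optional[Type[BaseOp]] = _REGISTRY.get(type_name)
    if cls is None:
        cls = _REGISTRY["stop_propagation_ops"]
    return cls  # a class object, not an instance; callers invoke cls.apply()
```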
nncf/common/pruning/model_analysis.py (18 changes: 9 additions & 9 deletions)
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Dict, List, Type, cast

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
@@ -23,14 +23,14 @@
from nncf.common.pruning.utils import is_prunable_depthwise_conv


def get_position(nodes_list: List[NNCFNode], idx: int):
def get_position(nodes_list: List[NNCFNode], idx: int) -> int:
for i, node in enumerate(nodes_list):
if node.node_id == idx:
return i
return None
return -1


def merge_clusters_for_nodes(nodes_to_merge: List[NNCFNode], clusterization: Clusterization):
def merge_clusters_for_nodes(nodes_to_merge: List[NNCFNode], clusterization: Clusterization[NNCFNode]) -> None:
"""
Merges clusters to which nodes from nodes_to_merge belongs.

@@ -151,7 +151,7 @@ def node_accept_different_inputs(self, nncf_node: NNCFNode) -> bool:
"""
return nncf_node.node_type in self._concat_op_metatype.get_all_op_aliases()

def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp:
def get_meta_operation_by_type_name(self, type_name: str) -> Type[BasePruningOp]:
"""
Returns class of metaop that corresponds to `type_name` type.

@@ -160,9 +160,9 @@ def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp:
cls = self._pruning_operator_metatypes.get_operator_metatype_by_op_name(type_name)
if cls is None:
cls = self._stop_propagation_op_metatype
return cls
return cast(Type[BasePruningOp], cls)

def propagate_can_prune_attr_up(self):
def propagate_can_prune_attr_up(self) -> None:
"""
Propagating can_prune attribute in reversed topological order.
This attribute depends on accept_pruned_input and can_prune attributes of output nodes.
@@ -181,7 +181,7 @@ def propagate_can_prune_attr_up(self):
)
self.can_prune[node.node_id] = outputs_accept_pruned_input and outputs_will_be_pruned

def propagate_can_prune_attr_down(self):
def propagate_can_prune_attr_down(self) -> None:
"""
Propagating can_prune attribute down to fix all branching cases with one pruned and one not pruned
branches.
@@ -199,7 +199,7 @@ def propagate_can_prune_attr_down(self):
):
self.can_prune[node.node_id] = can_prune

def set_accept_pruned_input_attr(self):
def set_accept_pruned_input_attr(self) -> None:
for nncf_node in self.graph.get_all_nodes():
cls = self.get_meta_operation_by_type_name(nncf_node.node_type)
self.accept_pruned_input[nncf_node.node_id] = cls.accept_pruned_input(nncf_node)
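Where mask_propagation.py silences the registry lookup with `# type:ignore`, this file narrows it with `cast(Type[BasePruningOp], cls)`. `cast` is purely a static hint; nothing is verified at runtime. A small illustration with hypothetical values:

```python
from typing import Any, cast


def as_int(value: Any) -> int:
    # cast() is a runtime no-op: it only tells the type checker to trust the annotation.
    return cast(int, value)


print(as_int(41) + 1)  # 42
# as_int("oops") would also satisfy mypy; the mistake surfaces only when the value is used.
```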
nncf/common/pruning/node_selector.py (20 changes: 12 additions & 8 deletions)
@@ -116,7 +116,9 @@ def create_pruning_groups(self, graph: NNCFGraph) -> Clusterization[NNCFNode]:
all_pruned_inputs[source_node.node_id] = source_node

if all_pruned_inputs:
cluster = Cluster[NNCFNode](i, all_pruned_inputs.values(), all_pruned_inputs.keys())
all_pruned_nodes = list(all_pruned_inputs.values())
all_pruned_node_ids = list(all_pruned_inputs.keys())
cluster = Cluster[NNCFNode](i, all_pruned_nodes, all_pruned_node_ids)
clusters_to_merge.append(cluster.id)
pruned_nodes_clusterization.add_cluster(cluster)

@@ -202,8 +204,8 @@ def _get_multiforward_nodes(self, graph: NNCFGraph) -> List[List[NNCFNode]]:
def _pruning_dimensions_analysis(
self,
graph: NNCFGraph,
pruned_nodes_clusterization: Clusterization,
can_prune_after_check: Dict[int, PruningAnalysisDecision],
pruned_nodes_clusterization: Clusterization[NNCFNode] = Clusterization[NNCFNode](lambda x: x.node_id),
can_prune_after_check: Dict[int, PruningAnalysisDecision] = {},
) -> Dict[int, PruningAnalysisDecision]:
"""
Checks:
@@ -251,7 +253,7 @@ def _check_all_closing_nodes_are_feasible(
return can_prune_updated

def _check_internal_groups_dim(
self, pruned_nodes_clusterization: Clusterization
self, pruned_nodes_clusterization: Clusterization[NNCFNode] = Clusterization[NNCFNode](lambda x: x.node_id)
) -> Dict[int, PruningAnalysisDecision]:
"""
Checks pruning dimensions of all nodes in each cluster group are equal and
@@ -278,8 +280,8 @@ def _check_internal_groups_dim(
def _should_prune_groups_analysis(
self,
graph: NNCFGraph,
pruned_nodes_clusterization: Clusterization,
can_prune: Dict[int, PruningAnalysisDecision],
pruned_nodes_clusterization: Clusterization[NNCFNode] = Clusterization[NNCFNode](lambda x: x.node_id),
can_prune: Dict[int, PruningAnalysisDecision] = {},
) -> Dict[int, PruningAnalysisDecision]:
"""
Check whether all nodes in group can be pruned based on user-defined constraints and
@@ -312,7 +314,9 @@ def _should_prune_groups_analysis(
return can_prune_updated

def _filter_groups(
self, pruned_nodes_clusterization: Clusterization, can_prune: Dict[int, PruningAnalysisDecision]
self,
pruned_nodes_clusterization: Clusterization[NNCFNode] = Clusterization[NNCFNode](lambda x: x.node_id),
can_prune: Dict[int, PruningAnalysisDecision] = {},
) -> None:
"""
Check whether all nodes in group can be pruned based on user-defined constraints and
@@ -355,7 +359,7 @@ def _is_module_prunable(self, graph: NNCFGraph, node: NNCFNode) -> PruningAnalys
input_non_pruned_nodes = get_first_nodes_of_type(graph, types_to_track)
node_name = node.node_name

if not should_consider_scope(node_name, self._ignored_scopes, self._target_scopes):
if not should_consider_scope(node_name, self._ignored_scopes or [], self._target_scopes):
return PruningAnalysisDecision(False, PruningAnalysisReason.IGNORED_SCOPE)

if not self._prune_first and node in input_non_pruned_nodes:
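The `self._ignored_scopes or []` added to the `should_consider_scope` call is a common way to hand a possibly-`None` list to a parameter typed as `List[str]`. A small sketch of the pattern with a hypothetical function:

```python
from typing import List, Optional


def should_consider(name: str, ignored: List[str]) -> bool:
    """Hypothetical stand-in: keep the name unless it is explicitly ignored."""
    return name not in ignored


ignored_scopes: Optional[List[str]] = None
# `ignored_scopes or []` yields a plain List[str] even when the attribute is None.
print(should_consider("conv1", ignored_scopes or []))  # True
```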