diff --git a/.mypy.ini b/.mypy.ini index 994dc9f2cad..e008ceb8f25 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,9 +1,5 @@ [mypy] -files = nncf/common/sparsity, nncf/common/graph, nncf/common/accuracy_aware_training/, nncf/common/utils/ +files = nncf/common/sparsity, nncf/common/graph, nncf/common/accuracy_aware_training/ , nncf/common/pruning, nncf/common/utils/ follow_imports = silent strict = True - -# should be removed later -# mypy recommends the following tool as an autofix: -# https://github.com/hauntsaninja/no_implicit_optional implicit_optional = True diff --git a/nncf/common/pruning/clusterization.py b/nncf/common/pruning/clusterization.py index 47457d01ca7..d723a9a8c07 100644 --- a/nncf/common/pruning/clusterization.py +++ b/nncf/common/pruning/clusterization.py @@ -19,16 +19,16 @@ class Cluster(Generic[T]): Represents element of Сlusterization. Groups together elements. """ - def __init__(self, cluster_id: int, elements: List[T], nodes_orders: List[int]): + def __init__(self, cluster_id: int, elements: List[T], nodes_orders: List[int]) -> None: self.id = cluster_id self.elements = list(elements) self.importance = max(nodes_orders) - def clean_cluster(self): + def clean_cluster(self) -> None: self.elements = [] self.importance = 0 - def add_elements(self, elements: List[T], importance: int): + def add_elements(self, elements: List[T], importance: int) -> None: self.elements.extend(elements) self.importance = max(self.importance, importance) @@ -39,11 +39,11 @@ class Clusterization(Generic[T]): delete existing one or merge existing clusters. """ - def __init__(self, id_fn: Callable[[T], Hashable] = None): + def __init__(self, id_fn: Callable[[T], Hashable] = None) -> None: self.clusters: Dict[int, Cluster[T]] = {} self._element_to_cluster: Dict[Hashable, int] = {} if id_fn is None: - self._id_fn = lambda x: x.id + self._id_fn: Callable[[T], Hashable] = lambda x: x.id # type:ignore else: self._id_fn = id_fn @@ -78,7 +78,7 @@ def is_node_in_clusterization(self, node_id: int) -> bool: """ return node_id in self._element_to_cluster - def add_cluster(self, cluster: Cluster[T]): + def add_cluster(self, cluster: Cluster[T]) -> None: """ Adds provided cluster to clusterization. @@ -91,7 +91,7 @@ def add_cluster(self, cluster: Cluster[T]): for elt in cluster.elements: self._element_to_cluster[self._id_fn(elt)] = cluster_id - def delete_cluster(self, cluster_id: int): + def delete_cluster(self, cluster_id: int) -> None: """ Removes cluster with `cluster_id` from clusterization. @@ -123,7 +123,7 @@ def get_all_nodes(self) -> List[T]: all_elements.extend(cluster.elements) return all_elements - def merge_clusters(self, first_id: int, second_id: int): + def merge_clusters(self, first_id: int, second_id: int) -> None: """ Merges two clusters with provided ids. @@ -143,7 +143,7 @@ def merge_clusters(self, first_id: int, second_id: int): self._element_to_cluster[self._id_fn(elt)] = second_id self.clusters.pop(first_id) - def merge_list_of_clusters(self, clusters: List[int]): + def merge_list_of_clusters(self, clusters: List[int]) -> None: """ Merges provided clusters. diff --git a/nncf/common/pruning/mask_propagation.py b/nncf/common/pruning/mask_propagation.py index dfa0c79da48..44f98b40977 100644 --- a/nncf/common/pruning/mask_propagation.py +++ b/nncf/common/pruning/mask_propagation.py @@ -51,7 +51,7 @@ def __init__( self._pruning_operator_metatypes = pruning_operator_metatypes self._tensor_processor = tensor_processor - def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp: + def get_meta_operation_by_type_name(self, type_name: str) -> Type[BasePruningOp]: """ Returns class of metaop that corresponds to `type_name` type. @@ -61,16 +61,16 @@ def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp: cls = self._pruning_operator_metatypes.get_operator_metatype_by_op_name(type_name) if cls is None: cls = self._pruning_operator_metatypes.registry_dict["stop_propagation_ops"] - return cls + return cls # type:ignore - def mask_propagation(self): + def mask_propagation(self) -> None: """ Mask propagation in graph: to propagate masks run method mask_propagation (of metaop of current node) on all nodes in topological order. """ for node in self._graph.topological_sort(): cls = self.get_meta_operation_by_type_name(node.node_type) - cls.mask_propagation(node, self._graph, self._tensor_processor) + cls.mask_propagation(node, self._graph, self._tensor_processor) # type:ignore def symbolic_mask_propagation( self, prunable_layers_types: List[str], can_prune_after_analysis: Dict[int, PruningAnalysisDecision] @@ -96,24 +96,31 @@ def symbolic_mask_propagation( """ can_be_closing_convs = self._get_can_closing_convs(prunable_layers_types) - can_prune_by_dim = {k: None for k in can_be_closing_convs} + can_prune_by_dim: Dict[int, PruningAnalysisDecision] = { + k: PruningAnalysisDecision(False, PruningAnalysisReason.CLOSING_CONV_MISSING) for k in can_be_closing_convs + } for node in self._graph.topological_sort(): - if node.node_id in can_be_closing_convs and can_prune_after_analysis[node.node_id]: + if node.node_id in can_be_closing_convs and can_prune_after_analysis.get(node.node_id, None): # type:ignore # Set output mask node.attributes["output_mask"] = SymbolicMask(get_output_channels(node), node.node_id) # Propagate masks cls = self.get_meta_operation_by_type_name(node.node_type) - cls.mask_propagation(node, self._graph, SymbolicMaskProcessor) + if cls is not None: + cls.mask_propagation(node, self._graph, SymbolicMaskProcessor) if node.node_id in can_be_closing_convs: # Check input mask producers out channel dimension input_masks = get_input_masks(node, self._graph) if any(input_masks): assert len(input_masks) == 1 - input_mask: SymbolicMask = input_masks[0] + input_mask: SymbolicMask = input_masks[0] # type:ignore for producer in input_mask.mask_producers: previously_dims_equal = ( - True if can_prune_by_dim[producer.id] is None else can_prune_by_dim[producer.id] + True + if can_prune_by_dim[producer.id].decision + else False + if can_prune_by_dim[producer.id] is not None + else False ) is_dims_equal = get_input_channels(node) == input_mask.shape[0] @@ -124,13 +131,13 @@ def symbolic_mask_propagation( # Remove all convolutions with masks # that were propagated to output node for out_node in self._graph.get_output_nodes(): - for input_mask in get_input_masks(out_node, self._graph): + for input_mask in get_input_masks(out_node, self._graph): # type:ignore if input_mask: for producer in input_mask.mask_producers: can_prune_by_dim[producer.id] = PruningAnalysisDecision(False, PruningAnalysisReason.LAST_CONV) # Update decision for nodes which # have no closing convolution - convs_without_closing_conv = {} + convs_without_closing_conv = {} # type:ignore for k, v in can_prune_by_dim.items(): if v is None: convs_without_closing_conv[k] = PruningAnalysisDecision( @@ -144,11 +151,11 @@ def symbolic_mask_propagation( return can_prune_by_dim - def _get_can_closing_convs(self, prunable_layers_types) -> Dict: - retval = set() + def _get_can_closing_convs(self, prunable_layers_types: List[str]) -> Dict[int, bool]: + retval = {} for node in self._graph.get_all_nodes(): if node.node_type in prunable_layers_types and not ( is_grouped_conv(node) or is_batched_linear(node, self._graph) ): - retval.add(node.node_id) + retval[node.node_id] = True return retval diff --git a/nncf/common/pruning/model_analysis.py b/nncf/common/pruning/model_analysis.py index 47d7e79d454..3a3fe9dbd05 100644 --- a/nncf/common/pruning/model_analysis.py +++ b/nncf/common/pruning/model_analysis.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Dict, List, Type, cast from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode @@ -23,14 +23,14 @@ from nncf.common.pruning.utils import is_prunable_depthwise_conv -def get_position(nodes_list: List[NNCFNode], idx: int): +def get_position(nodes_list: List[NNCFNode], idx: int) -> int: for i, node in enumerate(nodes_list): if node.node_id == idx: return i - return None + return -1 -def merge_clusters_for_nodes(nodes_to_merge: List[NNCFNode], clusterization: Clusterization): +def merge_clusters_for_nodes(nodes_to_merge: List[NNCFNode], clusterization: Clusterization[NNCFNode]) -> None: """ Merges clusters to which nodes from nodes_to_merge belongs. @@ -151,7 +151,7 @@ def node_accept_different_inputs(self, nncf_node: NNCFNode) -> bool: """ return nncf_node.node_type in self._concat_op_metatype.get_all_op_aliases() - def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp: + def get_meta_operation_by_type_name(self, type_name: str) -> Type[BasePruningOp]: """ Returns class of metaop that corresponds to `type_name` type. @@ -160,9 +160,9 @@ def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp: cls = self._pruning_operator_metatypes.get_operator_metatype_by_op_name(type_name) if cls is None: cls = self._stop_propagation_op_metatype - return cls + return cast(Type[BasePruningOp], cls) - def propagate_can_prune_attr_up(self): + def propagate_can_prune_attr_up(self) -> None: """ Propagating can_prune attribute in reversed topological order. This attribute depends on accept_pruned_input and can_prune attributes of output nodes. @@ -181,7 +181,7 @@ def propagate_can_prune_attr_up(self): ) self.can_prune[node.node_id] = outputs_accept_pruned_input and outputs_will_be_pruned - def propagate_can_prune_attr_down(self): + def propagate_can_prune_attr_down(self) -> None: """ Propagating can_prune attribute down to fix all branching cases with one pruned and one not pruned branches. @@ -199,7 +199,7 @@ def propagate_can_prune_attr_down(self): ): self.can_prune[node.node_id] = can_prune - def set_accept_pruned_input_attr(self): + def set_accept_pruned_input_attr(self) -> None: for nncf_node in self.graph.get_all_nodes(): cls = self.get_meta_operation_by_type_name(nncf_node.node_type) self.accept_pruned_input[nncf_node.node_id] = cls.accept_pruned_input(nncf_node) diff --git a/nncf/common/pruning/node_selector.py b/nncf/common/pruning/node_selector.py index a58ed7e1f37..327bd2c9c9d 100644 --- a/nncf/common/pruning/node_selector.py +++ b/nncf/common/pruning/node_selector.py @@ -116,7 +116,9 @@ def create_pruning_groups(self, graph: NNCFGraph) -> Clusterization[NNCFNode]: all_pruned_inputs[source_node.node_id] = source_node if all_pruned_inputs: - cluster = Cluster[NNCFNode](i, all_pruned_inputs.values(), all_pruned_inputs.keys()) + all_pruned_nodes = list(all_pruned_inputs.values()) + all_pruned_node_ids = list(all_pruned_inputs.keys()) + cluster = Cluster[NNCFNode](i, all_pruned_nodes, all_pruned_node_ids) clusters_to_merge.append(cluster.id) pruned_nodes_clusterization.add_cluster(cluster) @@ -202,8 +204,8 @@ def _get_multiforward_nodes(self, graph: NNCFGraph) -> List[List[NNCFNode]]: def _pruning_dimensions_analysis( self, graph: NNCFGraph, - pruned_nodes_clusterization: Clusterization, - can_prune_after_check: Dict[int, PruningAnalysisDecision], + pruned_nodes_clusterization: Clusterization[NNCFNode] = Clusterization[NNCFNode](lambda x: x.node_id), + can_prune_after_check: Dict[int, PruningAnalysisDecision] = {}, ) -> Dict[int, PruningAnalysisDecision]: """ Checks: @@ -251,7 +253,7 @@ def _check_all_closing_nodes_are_feasible( return can_prune_updated def _check_internal_groups_dim( - self, pruned_nodes_clusterization: Clusterization + self, pruned_nodes_clusterization: Clusterization[NNCFNode] = Clusterization[NNCFNode](lambda x: x.node_id) ) -> Dict[int, PruningAnalysisDecision]: """ Checks pruning dimensions of all nodes in each cluster group are equal and @@ -278,8 +280,8 @@ def _check_internal_groups_dim( def _should_prune_groups_analysis( self, graph: NNCFGraph, - pruned_nodes_clusterization: Clusterization, - can_prune: Dict[int, PruningAnalysisDecision], + pruned_nodes_clusterization: Clusterization[NNCFNode] = Clusterization[NNCFNode](lambda x: x.node_id), + can_prune: Dict[int, PruningAnalysisDecision] = {}, ) -> Dict[int, PruningAnalysisDecision]: """ Check whether all nodes in group can be pruned based on user-defined constraints and @@ -312,7 +314,9 @@ def _should_prune_groups_analysis( return can_prune_updated def _filter_groups( - self, pruned_nodes_clusterization: Clusterization, can_prune: Dict[int, PruningAnalysisDecision] + self, + pruned_nodes_clusterization: Clusterization[NNCFNode] = Clusterization[NNCFNode](lambda x: x.node_id), + can_prune: Dict[int, PruningAnalysisDecision] = {}, ) -> None: """ Check whether all nodes in group can be pruned based on user-defined constraints and @@ -355,7 +359,7 @@ def _is_module_prunable(self, graph: NNCFGraph, node: NNCFNode) -> PruningAnalys input_non_pruned_nodes = get_first_nodes_of_type(graph, types_to_track) node_name = node.node_name - if not should_consider_scope(node_name, self._ignored_scopes, self._target_scopes): + if not should_consider_scope(node_name, self._ignored_scopes or [], self._target_scopes): return PruningAnalysisDecision(False, PruningAnalysisReason.IGNORED_SCOPE) if not self._prune_first and node in input_non_pruned_nodes: diff --git a/nncf/common/pruning/operations.py b/nncf/common/pruning/operations.py index 022f4bd375e..9609fd7e79b 100644 --- a/nncf/common/pruning/operations.py +++ b/nncf/common/pruning/operations.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFGraphEdge @@ -30,8 +30,8 @@ class BasePruningOp: properties of interaction with pruning masks """ - subtypes = [] - additional_types = [] + subtypes: List[Any] = [] + additional_types: List[Any] = [] @classmethod def accept_pruned_input(cls, node: NNCFNode) -> bool: @@ -204,7 +204,7 @@ def mask_propagation( class ConcatPruningOp(BasePruningOp): @classmethod - def accept_pruned_input(cls, node: NNCFNode): + def accept_pruned_input(cls, node: NNCFNode) -> bool: return True @classmethod @@ -237,7 +237,8 @@ def generate_output_mask( filled_input_masks = [] for i, mask in enumerate(input_masks): if mask is None: - concat_axis = node.layer_attributes.axis + concat_axis = node.layer_attributes.axis if node.layer_attributes is not None else None # type:ignore + if concat_axis is not None: concat_dim = input_edges[i].tensor_shape[concat_axis] mask = tensor_processor.ones(concat_dim, device) filled_input_masks.append(mask) @@ -277,7 +278,7 @@ def match_multiple_output_masks( return result_masks @classmethod - def accept_pruned_input(cls, node: NNCFNode): + def accept_pruned_input(cls, node: NNCFNode) -> bool: if node.layer_attributes is not None: return True return False @@ -305,7 +306,7 @@ def generate_output_masks( if not input_mask: return None - chunk_axis = node.layer_attributes.axis + chunk_axis = node.layer_attributes.axis # type:ignore output_edges = graph.get_output_edges(node) output_shapes = [edge.tensor_shape[chunk_axis] for edge in output_edges] @@ -319,9 +320,9 @@ def generate_output_masks( return None split_masks = tensor_processor.split(input_mask, output_shapes) - result_masks = cls.match_multiple_output_masks(split_masks, output_edges, chunk_axis) + result_masks = cls.match_multiple_output_masks(split_masks, output_edges, chunk_axis) # type:ignore - return result_masks + return result_masks # type:ignore @classmethod def mask_propagation( @@ -334,7 +335,7 @@ def mask_propagation( class PadPruningOp(IdentityMaskForwardPruningOp): @classmethod def accept_pruned_input(cls, node: NNCFNode) -> bool: - mode, value = node.layer_attributes.mode, node.layer_attributes.value + mode, value = node.layer_attributes.mode, node.layer_attributes.value # type:ignore if mode == "constant" and value != 0: return False return True @@ -352,7 +353,7 @@ def mask_propagation( input_masks = get_input_masks(node, graph) output_mask = input_masks[0] if output_mask is not None: - output_mask = tensor_processor.elementwise_mask_propagation(input_masks) + output_mask = tensor_processor.elementwise_mask_propagation(input_masks) # type:ignore node.attributes["output_mask"] = output_mask @@ -360,16 +361,16 @@ def mask_propagation( class ReshapePruningOp(BasePruningOp): @staticmethod def _is_flatten(node: NNCFNode) -> bool: - return len(node.layer_attributes.output_shape) == 2 + return len(node.layer_attributes.output_shape) == 2 # type:ignore @staticmethod def _is_not_mixing_dim(node: NNCFNode) -> bool: - input_shape = node.layer_attributes.input_shape - output_shape = node.layer_attributes.output_shape + input_shape = node.layer_attributes.input_shape # type:ignore + output_shape = node.layer_attributes.output_shape # type:ignore # TODO(dlyakhov): Cover all corner cases that appear here (ticket 90976) if len(input_shape) == len(output_shape) and set(input_shape) == set(output_shape): - return input_shape == output_shape + return input_shape == output_shape # type:ignore return True @classmethod @@ -394,12 +395,12 @@ def mask_propagation( class FlattenPruningOp(BasePruningOp): @classmethod def accept_pruned_input(cls, node: NNCFNode) -> bool: - if node.layer_attributes is not None: - return True - return False + return node.layer_attributes is not None @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph, tensor_processor: Type[NNCFPruningBaseTensorProcessor]): + def mask_propagation( + cls, node: NNCFNode, graph: NNCFGraph, tensor_processor: Type[NNCFPruningBaseTensorProcessor] + ) -> None: output_mask = None input_masks = get_input_masks(node, graph) assert len(input_masks) == 1 @@ -410,7 +411,7 @@ def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph, tensor_processor: Ty # Besides, since input_mask is not None thus no StopMaskForwardOp operations # was in the path from mask producer node to this node. As all # known nodes have input/output batch dim == 0 previous has too. - flatten_channels = node.layer_attributes.output_shape[1] + flatten_channels = node.layer_attributes.output_shape[1] # type:ignore mask_len = input_mask.shape[0] assert flatten_channels % mask_len == 0 output_mask = tensor_processor.repeat(input_mask, repeats=flatten_channels // mask_len) diff --git a/nncf/common/pruning/schedulers.py b/nncf/common/pruning/schedulers.py index facd3795588..c9f3c34467e 100644 --- a/nncf/common/pruning/schedulers.py +++ b/nncf/common/pruning/schedulers.py @@ -9,10 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, Optional, Tuple, cast import numpy as np -import scipy.optimize +import scipy.optimize # type: ignore from nncf.api.compression import CompressionAlgorithmController from nncf.common.schedulers import BaseCompressionScheduler @@ -42,12 +42,12 @@ class PruningScheduler(BaseCompressionScheduler): section of the NNCF config file .json section (https://openvinotoolkit.github.io/nncf/schema). """ - def __init__(self, controller: CompressionAlgorithmController, params: dict): + def __init__(self, controller: CompressionAlgorithmController, params: Dict[str, Any]): super().__init__() self._controller = controller - self.initial_level = self._controller.pruning_init + self.initial_level = getattr(self._controller, "pruning_init", 0.0) - if self._controller.prune_flops: + if hasattr(self._controller, "prune_flops") and self._controller.prune_flops: self.target_level = params.get("pruning_flops_target") else: self.target_level = params.get("pruning_target", PRUNING_TARGET) @@ -74,9 +74,9 @@ def epoch_step(self, next_epoch: Optional[int] = None) -> None: will update the state of the pruning method. """ super().epoch_step(next_epoch) - self._controller.set_pruning_level(self.current_pruning_level) + self._controller.set_pruning_level(self.current_pruning_level) # type:ignore if self.current_epoch >= self.freeze_epoch: - self._controller.freeze() + self._controller.freeze() # type:ignore def step(self, next_step: Optional[int] = None) -> None: """ @@ -87,7 +87,7 @@ def step(self, next_step: Optional[int] = None) -> None: will update the state of the pruning method. """ super().step(next_step) - self._controller.step(next_step) + self._controller.step(next_step) # type:ignore @property def current_pruning_level(self) -> float: @@ -110,11 +110,11 @@ class BaselinePruningScheduler(PruningScheduler): Then scheduler sets `target_level` and freezes the algorithm. """ - def __init__(self, controller: CompressionAlgorithmController, params: dict): + def __init__(self, controller: CompressionAlgorithmController, params: Dict[str, Any]): super().__init__(controller, params) self.freeze_epoch = self.num_warmup_epochs - def _calculate_pruning_level(self) -> float: + def _calculate_pruning_level(self) -> Any: return self.target_level @@ -130,7 +130,7 @@ class ExponentialPruningScheduler(PruningScheduler): current_density = 1.0 - current_level """ - def __init__(self, controller: CompressionAlgorithmController, params: dict): + def __init__(self, controller: CompressionAlgorithmController, params: Dict[str, Any]): """ Initializes a pruning scheduler with an exponential decay schedule. @@ -138,15 +138,15 @@ def __init__(self, controller: CompressionAlgorithmController, params: dict): :param params: Parameters of the scheduler. """ super().__init__(controller, params) - initial_density = 1.0 - self.initial_level - target_density = 1.0 - self.target_level - target_epoch = self.num_pruning_epochs - 1 + initial_density = 1.0 - (self.initial_level or 0.0) + target_density = 1.0 - (self.target_level or 0.0) + target_epoch = (self.num_pruning_epochs or 0) - 1 self.schedule = ExponentialDecaySchedule(initial_density, target_density, target_epoch) def _calculate_pruning_level(self) -> float: current_density = self.schedule(self.current_epoch - self.num_warmup_epochs) current_level = 1.0 - current_density - return min(current_level, self.target_level) + return min(float(current_level or 0.0), float(self.target_level or 0.0)) @PRUNING_SCHEDULERS.register("exponential_with_bias") @@ -160,7 +160,7 @@ class ExponentialWithBiasPruningScheduler(PruningScheduler): where a, b, k is a params. """ - def __init__(self, controller: CompressionAlgorithmController, params: dict): + def __init__(self, controller: CompressionAlgorithmController, params: Dict[str, Any]): """ Initializes a pruning scheduler with an exponential (with bias) decay schedule. @@ -168,15 +168,17 @@ def __init__(self, controller: CompressionAlgorithmController, params: dict): :param params: Parameters of the scheduler. """ super().__init__(controller, params) - target_epoch = self.num_pruning_epochs - 1 - self.a, self.b, self.k = self._init_exp(target_epoch, self.initial_level, self.target_level) + target_epoch: int = int(self.num_pruning_epochs or 0) - 1 + initial_level: float = float(self.initial_level or 0.0) + target_level: float = float(self.target_level or 0.0) + self.a, self.b, self.k = self._init_exp(target_epoch, initial_level, target_level) def _calculate_pruning_level(self) -> float: - current_level = self.a * np.exp(-self.k * (self.current_epoch - self.num_warmup_epochs)) + self.b - return min(current_level, self.target_level) + current_level: float = self.a * np.exp(-self.k * (self.current_epoch - self.num_warmup_epochs)) + self.b + return min(current_level or 0.0, float(self.target_level or 0.0)) @staticmethod - def _init_exp(epoch_idx, p_min, p_max, factor=0.125): + def _init_exp(epoch_idx: int, p_min: float, p_max: float, factor: float = 0.125) -> Tuple[float, float, float]: """ Finds parameters a, b, k from the system: p_min = a + b @@ -190,18 +192,18 @@ def _init_exp(epoch_idx, p_min, p_max, factor=0.125): :param factor: Hyperparameter. """ - def get_b(a): - return p_min - a + def get_b(a: float) -> float: + return float(p_min - a) - def get_a(k): - return (p_max - p_min) / (np.exp(-k * epoch_idx) - 1) + def get_a(k: float) -> float: + return float((p_max - p_min) / (np.exp(-k * epoch_idx) - 1)) - def f_to_solve(x): - c = (0.75 * p_max - p_min) / (p_max - p_min) + def f_to_solve(x: Any) -> Any: + c: float = (0.75 * p_max - p_min) / (p_max - p_min) y = np.exp(-x * epoch_idx) - return y**factor - c * y + c - 1 + return cast(float, y**factor - c * y + c - 1) - k = scipy.optimize.fsolve(f_to_solve, [1])[0] - a = get_a(k) - b = get_b(a) + k: float = scipy.optimize.fsolve(f_to_solve, [1])[0] + a: float = get_a(k) + b: float = get_b(a) return a, b, k diff --git a/nncf/common/pruning/shape_pruning_processor.py b/nncf/common/pruning/shape_pruning_processor.py index 6214551bf49..c7e5465225a 100644 --- a/nncf/common/pruning/shape_pruning_processor.py +++ b/nncf/common/pruning/shape_pruning_processor.py @@ -21,6 +21,7 @@ from nncf.common.pruning.structs import PrunedLayerInfoBase from nncf.common.pruning.symbolic_mask import SymbolicMask from nncf.common.pruning.symbolic_mask import SymbolicMaskProcessor +from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry from nncf.common.pruning.utils import get_input_masks from nncf.common.pruning.utils import get_next_nodes_of_types from nncf.common.pruning.utils import get_output_channels @@ -35,7 +36,7 @@ class ShapePruningProcessor: compression algorithms execution. """ - def __init__(self, prunable_types: List[str], pruning_operations_metatype: List[str]): + def __init__(self, prunable_types: List[str], pruning_operations_metatype: "PruningOperationsMetatypeRegistry"): """ Constructor. @@ -48,7 +49,7 @@ def __init__(self, prunable_types: List[str], pruning_operations_metatype: List[ def calculate_in_out_channels_by_masks( self, graph: NNCFGraph, - pruning_groups: List[Cluster[PrunedLayerInfoBase]], + pruning_groups: "Clusterization[PrunedLayerInfoBase]", pruning_groups_next_nodes: Dict[int, List[Dict[str, Any]]], num_of_sparse_elements_by_node: Dict[NNCFNodeName, int], ) -> Tuple[Dict[str, int], Dict[str, int]]: @@ -63,7 +64,7 @@ def calculate_in_out_channels_by_masks( :return Dictionary of new input channels number {node_name: channels_num} """ - def get_sparser(full_output_channels): + def get_sparser(full_output_channels: Dict[NNCFNodeName, int]) -> Callable[[NNCFNodeName], int]: def get_num_of_sparse_elements_by_node(node_name: str) -> int: return num_of_sparse_elements_by_node[node_name] @@ -74,7 +75,7 @@ def get_num_of_sparse_elements_by_node(node_name: str) -> int: def calculate_in_out_channels_in_uniformly_pruned_model( self, graph: NNCFGraph, - pruning_groups: List[Cluster[PrunedLayerInfoBase]], + pruning_groups: "Clusterization[PrunedLayerInfoBase]", pruning_groups_next_nodes: Dict[int, List[Dict[str, Any]]], pruning_level: float, ) -> Tuple[Dict[str, int], Dict[str, int]]: @@ -89,7 +90,7 @@ def calculate_in_out_channels_in_uniformly_pruned_model( :return Tuple of dictionarise of new input and output channels number {node_name: channels_num} """ - def get_sparser(full_output_channels): + def get_sparser(full_output_channels: Dict[NNCFNodeName, int]) -> Callable[[NNCFNodeName], int]: def get_num_of_sparse_elements_by_node(node_name: str) -> int: old_out_channels = full_output_channels[node_name] return get_rounded_pruned_element_number(old_out_channels, pruning_level) @@ -135,7 +136,7 @@ def _calculate_in_out_channels( self, sparse_elements_counter_getter: Callable[[Dict[NNCFNodeName, int]], Callable[[NNCFNodeName], int]], graph: NNCFGraph, - pruning_groups: List[Cluster[PrunedLayerInfoBase]], + pruning_groups: "Clusterization[PrunedLayerInfoBase]", pruning_groups_next_nodes: Dict[int, List[Dict[str, Any]]], ) -> Tuple[Dict[str, int], Dict[str, int]]: full_input_channels, full_output_channels = get_prunable_layers_in_out_channels(graph) @@ -161,18 +162,18 @@ def _calculate_in_out_channels( def _get_next_node_sparse_multiplier( self, graph: NNCFGraph, next_node: NNCFNode, cluster: Clusterization[PrunedLayerInfoBase] ) -> int: - cluster_nodes_idxs = {node.nncf_node_id for node in cluster.elements} + cluster_nodes_idxs = {node.nncf_node_id for node in cluster.elements} # type:ignore for input_mask in get_input_masks(next_node, graph): if not input_mask: continue - for mask_producer in input_mask.mask_producers: + for mask_producer in input_mask.mask_producers: # type:ignore if mask_producer.id in cluster_nodes_idxs: - return mask_producer.sparse_multiplier + return mask_producer.sparse_multiplier # type:ignore - raise nncf.ValidationError(f"Next node for cluster {cluster.elements} doesn't have closing mask") + raise nncf.ValidationError(f"Next node for cluster {cluster.elements} doesn't have closing mask") # type:ignore def get_next_nodes( - self, graph: NNCFGraph, pruning_groups: Clusterization[PrunedLayerInfoBase] + self, graph: NNCFGraph, pruning_groups: "Clusterization[PrunedLayerInfoBase]" ) -> Dict[int, Dict[str, Any]]: """ Finds nodes of `prunable_types` types that receive the output of a pruned cluster as input @@ -190,9 +191,9 @@ def get_next_nodes( MaskPropagationAlgorithm(graph, self._pruning_operations_metatype, SymbolicMaskProcessor).mask_propagation() # 2. Find next nodes and correspondent sparse multipliers - next_nodes = {} + next_nodes = {} # type:ignore for cluster in pruning_groups.get_all_clusters(): - next_nodes_cluster = set() + next_nodes_cluster = set() # type:ignore cluster_nodes = set() for pruned_layer_info in cluster.elements: nncf_cluster_node = graph.get_node_by_id(pruned_layer_info.nncf_node_id) @@ -203,7 +204,7 @@ def get_next_nodes( next_nodes_cluster = next_nodes_cluster - cluster_nodes next_nodes[cluster.id] = [] for next_node in next_nodes_cluster: - sparse_multiplier = self._get_next_node_sparse_multiplier(graph, next_node, cluster) + sparse_multiplier = self._get_next_node_sparse_multiplier(graph, next_node, cluster) # type:ignore next_nodes[cluster.id].append( {"node_name": next_node.node_name, "sparse_multiplier": sparse_multiplier} ) diff --git a/nncf/common/pruning/symbolic_mask.py b/nncf/common/pruning/symbolic_mask.py index 6e67e31ceba..031e4ba0bc1 100644 --- a/nncf/common/pruning/symbolic_mask.py +++ b/nncf/common/pruning/symbolic_mask.py @@ -9,10 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Dict, List, Union import nncf from nncf.common.pruning.tensor_processor import NNCFPruningBaseTensorProcessor +from nncf.common.tensor import DeviceType from nncf.common.tensor import NNCFTensor @@ -37,7 +38,7 @@ def id(self) -> int: @classmethod def merge_producers(cls, masks: List["SymbolicMask"]) -> List["SymbolicMaskProducer"]: - merged_producers = {} + merged_producers: Dict[int, SymbolicMaskProducer] = {} for mask in masks: for mask_producer in mask.mask_producers: if mask_producer.id in merged_producers: @@ -74,10 +75,10 @@ def shape(self) -> List[int]: @property def mask_producers(self) -> List[SymbolicMaskProducer]: - return self._mask_producers + return self._mask_producers # type:ignore @property - def device(self) -> None: + def device(self) -> None: # type:ignore return None @@ -101,13 +102,13 @@ class SymbolicMaskProcessor(NNCFPruningBaseTensorProcessor): """ @classmethod - def concatenate(cls, tensors: List[SymbolicMask], axis: int) -> SymbolicMask: + def concatenate(cls, tensors: List[NNCFTensor], axis: int) -> SymbolicMask: ret_shape = sum([t.shape[0] for t in tensors]) - producers = SymbolicMaskProducer.merge_producers(tensors) + producers = SymbolicMaskProducer.merge_producers([t for t in tensors if isinstance(t, SymbolicMask)]) return SymbolicMask(ret_shape, producers) @classmethod - def ones(cls, shape: Union[int, List[int]], device) -> SymbolicMask: + def ones(cls, shape: Union[int, List[int]], device: DeviceType) -> SymbolicMask: if isinstance(shape, list): if len(shape) != 1: raise nncf.ValidationError(f"Unexpected shape = {shape} for 1D symbolic mask") @@ -116,12 +117,12 @@ def ones(cls, shape: Union[int, List[int]], device) -> SymbolicMask: return SymbolicMask(shape) @classmethod - def assert_allclose(cls, tensors: List[SymbolicMask]) -> None: + def assert_allclose(cls, tensors: List[NNCFTensor]) -> None: for input_mask in tensors[1:]: assert tensors[0].shape == input_mask.shape @classmethod - def repeat(cls, tensor: SymbolicMask, repeats: int) -> SymbolicMask: + def repeat(cls, tensor: SymbolicMask, repeats: int) -> SymbolicMask: # type:ignore updated_mask_producers = [] for mask_producer in tensor.mask_producers: updated_mask_producers.append( @@ -130,7 +131,7 @@ def repeat(cls, tensor: SymbolicMask, repeats: int) -> SymbolicMask: return SymbolicMask(tensor.shape[0] * repeats, updated_mask_producers) @classmethod - def elementwise_mask_propagation(cls, input_masks: List[SymbolicMask]) -> SymbolicMask: + def elementwise_mask_propagation(cls, input_masks: List[SymbolicMask]) -> SymbolicMask: # type:ignore """ Assemble output mask for elementwise pruning operation from given input masks. In case input_masks have different shape don't propagate any masks. @@ -146,7 +147,7 @@ def elementwise_mask_propagation(cls, input_masks: List[SymbolicMask]) -> Symbol return SymbolicMask(input_masks[0].shape[0], producers) @classmethod - def split(cls, tensor: SymbolicMask, output_shapes: List[int]) -> List[SymbolicMask]: + def split(cls, tensor: SymbolicMask, output_shapes: List[int]) -> List[SymbolicMask]: # type:ignore if any(shape <= 0 for shape in output_shapes) or tensor.shape[0] != sum(output_shapes): raise AssertionError( "Symbolic mask split was called with" diff --git a/nncf/common/pruning/utils.py b/nncf/common/pruning/utils.py index b67deff12a8..6b786c20c36 100644 --- a/nncf/common/pruning/utils.py +++ b/nncf/common/pruning/utils.py @@ -12,7 +12,7 @@ import math from enum import Enum from functools import partial -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import numpy as np @@ -67,7 +67,7 @@ def get_sources_of_node(nncf_node: NNCFNode, graph: NNCFGraph, sources_types: Li :param graph: NNCF graph to work with. :return: List of all sources nodes. """ - visited = {node_id: False for node_id in graph.get_all_node_ids()} + visited: Dict[str, bool] = {str(node_id): False for node_id in graph.get_all_node_ids()} partial_traverse_function = partial(traverse_function, type_check_fn=lambda x: x in sources_types, visited=visited) nncf_nodes = [nncf_node] if nncf_node.node_type in sources_types: @@ -142,7 +142,9 @@ def get_rounded_pruned_element_number(total: int, sparsity_rate: float, multiple return max(total - remaining_elems, 0) -def traverse_function(node: NNCFNode, output: List[NNCFNode], type_check_fn, visited) -> Tuple[bool, List[NNCFNode]]: +def traverse_function( + node: NNCFNode, output: List[NNCFNode], type_check_fn: Callable[[str], bool], visited: Dict[Any, bool] +) -> Tuple[bool, List[NNCFNode]]: if visited[node.node_id]: return True, output visited[node.node_id] = True @@ -166,7 +168,7 @@ def get_last_nodes_of_type(graph: NNCFGraph, op_types: List[str]) -> List[NNCFNo """ graph_outputs = graph.get_output_nodes() # NNCFNodes here - visited = {node_id: False for node_id in graph.get_all_node_ids()} + visited: Dict[str, bool] = {str(node_id): False for node_id in graph.get_all_node_ids()} partial_traverse_function = partial(traverse_function, type_check_fn=lambda x: x in op_types, visited=visited) last_nodes_of_type = [] for output in graph_outputs: @@ -210,15 +212,15 @@ def get_prunable_layers_in_out_channels(graph: NNCFGraph) -> Tuple[Dict[NNCFNode class PruningOperationsMetatypeRegistry(Registry): - def __init__(self, name): + def __init__(self, name: str) -> None: super().__init__(name) - self._op_name_to_op_class = {} + self._op_name_to_op_class: Dict[str, Type[object]] = {} - def register(self, name=None): - name_ = name + def register(self, name: Optional[str] = None) -> Callable[[Type[object]], Type[object]]: + name_: Optional[str] = name super_register = super()._register - def wrap(obj): + def wrap(obj): # type:ignore cls_name = name_ if cls_name is None: cls_name = obj.__name__ @@ -236,7 +238,7 @@ def wrap(obj): return wrap - def get_operator_metatype_by_op_name(self, op_name: str): + def get_operator_metatype_by_op_name(self, op_name: str) -> Optional[Type[object]]: if op_name in self._op_name_to_op_class: return self._op_name_to_op_class[op_name] return None @@ -269,7 +271,7 @@ def message(cls, node_name: str, decision: Optional["PruningAnalysisDecision"]) :return: Pruning analysis decision in a human-readable format. """ prefix = f"ignored adding Weight Pruner in: {node_name}" - reasons = decision.reasons + reasons = decision.reasons # type:ignore if not reasons: return prefix # Filter messages @@ -295,7 +297,7 @@ def __init__( ): self.decision = decision if not isinstance(possible_reasons, list): - possible_reasons = [possible_reasons] + possible_reasons = [possible_reasons] # type:ignore self._reasons: Optional[List[PruningAnalysisReason]] = ( possible_reasons if not decision and possible_reasons else None ) @@ -306,7 +308,7 @@ def __repr__(self) -> str: representation += "; Reasons: " + str(self._reasons) return representation - def __eq__(self, other: "PruningAnalysisDecision") -> bool: + def __eq__(self, other: "PruningAnalysisDecision") -> bool: # type:ignore eq = self.decision == other.decision if self._reasons is None: return eq and other._reasons is None @@ -368,7 +370,7 @@ def get_input_masks(node: NNCFNode, graph: NNCFGraph) -> List[Optional[NNCFTenso :return: Input masks. """ retval = [] - input_masks = [input_edge.from_node.attributes["output_mask"] for input_edge in graph.get_input_edges(node)] + input_masks = [input_edge.from_node.attributes.get("output_mask") for input_edge in graph.get_input_edges(node)] for input_mask in input_masks: retval.append(input_mask[node.node_name] if isinstance(input_mask, dict) else input_mask) return retval @@ -381,7 +383,7 @@ def get_input_channels(node: NNCFNode) -> int: :param node: Given prunable node. :return: Count of input channels of the given node. """ - layer_attrs: Union[ConvolutionLayerAttributes, LinearLayerAttributes] = node.layer_attributes + layer_attrs = node.layer_attributes if isinstance(layer_attrs, ConvolutionLayerAttributes): return layer_attrs.in_channels if isinstance(layer_attrs, LinearLayerAttributes): @@ -396,7 +398,7 @@ def get_output_channels(node: NNCFNode) -> int: :param node: Given prunable node. :return: Count of output channels of the given node. """ - layer_attrs: Union[ConvolutionLayerAttributes, LinearLayerAttributes] = node.layer_attributes + layer_attrs = node.layer_attributes if isinstance(layer_attrs, ConvolutionLayerAttributes): return layer_attrs.out_channels if isinstance(layer_attrs, LinearLayerAttributes): diff --git a/nncf/common/pruning/weights_flops_calculator.py b/nncf/common/pruning/weights_flops_calculator.py index 85554ab2649..5679e98bfb7 100644 --- a/nncf/common/pruning/weights_flops_calculator.py +++ b/nncf/common/pruning/weights_flops_calculator.py @@ -102,13 +102,13 @@ def count_flops_and_weights_per_node( output_channels = output_channels or {} kernel_sizes = kernel_sizes or {} op_addresses_to_skip = op_addresses_to_skip or [] - for node in graph.get_nodes_by_metatypes(self._conv_op_metatypes): + for node in graph.get_nodes_by_metatypes([type(op) for op in self._conv_op_metatypes]): name = node.node_name if name in op_addresses_to_skip: continue - num_in_channels = input_channels.get(name, node.layer_attributes.in_channels) - num_out_channels = output_channels.get(name, node.layer_attributes.out_channels) - kernel_size = kernel_sizes.get(name, node.layer_attributes.kernel_size) + num_in_channels = input_channels.get(name, getattr(node.layer_attributes, "in_channels", 0)) + num_out_channels = output_channels.get(name, getattr(node.layer_attributes, "out_channels", 0)) + kernel_size = kernel_sizes.get(name, getattr(node.layer_attributes, "kernel_size", (0, 0))) if is_prunable_depthwise_conv(node): # Prunable depthwise conv processed in special way # because common way to calculate filters per @@ -116,7 +116,7 @@ def count_flops_and_weights_per_node( # some of the output channels are pruned. filters_per_channel = 1 else: - filters_per_channel = num_out_channels // node.layer_attributes.groups + filters_per_channel = num_out_channels // getattr(node.layer_attributes, "groups", 1) flops_numpy = ( 2 * np.prod(kernel_size) * num_in_channels * filters_per_channel * np.prod(output_shapes[name]) @@ -126,15 +126,14 @@ def count_flops_and_weights_per_node( flops[name] = flops_numpy.astype(int).item() weights[name] = weights_numpy.astype(int).item() - for node in graph.get_nodes_by_metatypes(self._linear_op_metatypes): + for node in graph.get_nodes_by_metatypes([type(op) for op in self._linear_op_metatypes]): name = node.node_name if name in op_addresses_to_skip: continue - num_in_features = input_channels.get(name, node.layer_attributes.in_features) - num_out_features = output_channels.get(name, node.layer_attributes.out_features) - - flops_numpy = 2 * num_in_features * num_out_features * np.prod(output_shapes[name][:-1]) + num_in_features = input_channels.get(name, getattr(node.layer_attributes, "in_features", 0)) + num_out_features = output_channels.get(name, getattr(node.layer_attributes, "out_features", 0)) + flops_numpy = 2 * num_in_features * num_out_features * np.prod(output_shapes[name]) weights_numpy = num_in_features * num_out_features flops[name] = flops_numpy weights[name] = weights_numpy @@ -151,6 +150,6 @@ def count_filters_num(self, graph: NNCFGraph, output_channels: Dict[NNCFNodeName """ filters_num = 0 output_channels = output_channels or {} - for node in graph.get_nodes_by_metatypes(self._conv_op_metatypes + self._linear_op_metatypes): + for node in graph.get_nodes_by_metatypes([type(op) for op in self._linear_op_metatypes]): filters_num += output_channels.get(node.node_name, get_output_channels(node)) return filters_num