diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index fa455cbac..580e4eb24 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -801,10 +801,13 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, channels = [] for n in node.all_input_nodes: channel_dim = find_srcs_channel_dim(graph_model, n) - if channel_dim is _UNSUPPORTED_OP: - state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP - continue channels.append(channel_dim) + + # If we found an unsupported op while walking up, we exit this branch and + # invalidate the region + if _UNSUPPORTED_OP in channels: + state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP + continue start = sum(channels[:index]) end = start + channels[index] new_state = WalkRegionState(offset=state.offset)