Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement KaHyPar optimizer #35

Merged
merged 15 commits into from
Oct 30, 2023
Merged
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
ImmutableArrays = "667c17eb-ab9b-4487-935f-1c621bb82497"
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
109 changes: 109 additions & 0 deletions src/Optimizers/Hyper.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
using SparseArrays
using KaHyPar

@kwdef struct HyperGraphPartitioning <: Optimizer
metric::Function = removedsize
outer::Bool = false
end


function get_hypergraph(path)

Check warning on line 10 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L10

Added line #L10 was not covered by tests
# assume unique indices across all tensors
all_indices = mapreduce(head, ∪, path)

Check warning on line 12 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L12

Added line #L12 was not covered by tests
mofeing marked this conversation as resolved.
Show resolved Hide resolved

# Create incidence matrix
incidence_matrix = spzeros(Int64, length(path.args), length(all_indices))

Check warning on line 15 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L15

Added line #L15 was not covered by tests

for (i, tensor) in enumerate(path.args)
for idx in tensor.head
j = findfirst(==(idx), all_indices)
incidence_matrix[i, j] = 1
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
end
end

Check warning on line 22 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L17-L22

Added lines #L17 - L22 were not covered by tests

# Vertex weights (assuming equal weight for all tensors)
vertex_weights = ones(Int64, length(path.args))

Check warning on line 25 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L25

Added line #L25 was not covered by tests

# Hyperedge weights set to the size of the index
edge_weights = [size(path, idx) for idx in all_indices]

Check warning on line 28 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L28

Added line #L28 was not covered by tests

# Create hypergraph
hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights)

Check warning on line 31 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L31

Added line #L31 was not covered by tests

return hypergraph

Check warning on line 33 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L33

Added line #L33 was not covered by tests
end

function partition_hypergraph(hypergraph::KaHyPar.HyperGraph; parts=2, parts_decay=0.5, random_strength=0.01)

Check warning on line 36 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L36

Added line #L36 was not covered by tests
# Create a context for KaHyPar
context = KaHyPar.kahypar_context_new()

Check warning on line 38 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L38

Added line #L38 was not covered by tests

# Load default configuration
config_file_path = KaHyPar.default_configuration
KaHyPar.kahypar_configure_context_from_file(context, config_file_path)

Check warning on line 42 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L41-L42

Added lines #L41 - L42 were not covered by tests

# Calculate the relative subgraph size
subsize = hypergraph.n_vertices
N = hypergraph.n_vertices # Assuming N is the total number of vertices in the entire graph
s = subsize / N

Check warning on line 47 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L45-L47

Added lines #L45 - L47 were not covered by tests
mofeing marked this conversation as resolved.
Show resolved Hide resolved

# Determine the number of partitions based on the relative subgraph size
kparts = max(Int(s^parts_decay * parts), 2)

Check warning on line 50 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L50

Added line #L50 was not covered by tests

# Perform the partitioning
partitioning_result = KaHyPar.partition(hypergraph, kparts; configuration=config_file_path)

Check warning on line 53 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L53

Added line #L53 was not covered by tests

# Clean up the context
KaHyPar.kahypar_context_free(context)

Check warning on line 56 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L56

Added line #L56 was not covered by tests

return partitioning_result

Check warning on line 58 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L58

Added line #L58 was not covered by tests
end

function recursive_partition(expr::EinExpr; parts=2, parts_decay=0.5, random_strength=0.01, cutoff=2, max_iterations=10)

Check warning on line 61 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L61

Added line #L61 was not covered by tests
# Convert the EinExpr to a hypergraph
hypergraph = get_hypergraph(expr)
println("Hypergraph size: ", hypergraph.n_vertices)

Check warning on line 64 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L63-L64

Added lines #L63 - L64 were not covered by tests

# Base case: if the hypergraph is small enough, we stop the recursion
if hypergraph.n_vertices <= cutoff
println("Base case")
return expr

Check warning on line 69 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L67-L69

Added lines #L67 - L69 were not covered by tests
end

if max_iterations == 0
println("Max iterations reached")
mofeing marked this conversation as resolved.
Show resolved Hide resolved
return expr

Check warning on line 74 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L72-L74

Added lines #L72 - L74 were not covered by tests
end

# Partition the hypergraph
partitioning_result = partition_hypergraph(hypergraph, parts=parts, parts_decay=parts_decay, random_strength=random_strength)

Check warning on line 78 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L78

Added line #L78 was not covered by tests

# Get unique partitions
unique_partitions = unique(partitioning_result)

Check warning on line 81 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L81

Added line #L81 was not covered by tests

# Contract nodes based on the partitioning
new_exprs = [partition_to_einexpr(part_id, partitioning_result, expr) for part_id in unique_partitions]
combined_expr = sum(new_exprs)

Check warning on line 85 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L84-L85

Added lines #L84 - L85 were not covered by tests

# Recursively call the function on the combined EinExpr
return recursive_partition(combined_expr, parts=parts, parts_decay=parts_decay, random_strength=random_strength, cutoff=cutoff, max_iterations=max_iterations-1)

Check warning on line 88 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L88

Added line #L88 was not covered by tests
end

function partition_to_einexpr(partition_id::Int, partition_result::Vector{Int}, original_path::EinExpr)

Check warning on line 91 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L91

Added line #L91 was not covered by tests
# Identify the tensor indices that belong to the specific partition_id
tensor_indices_in_partition = [idx for (idx, part) in enumerate(partition_result) if part == partition_id]

Check warning on line 93 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L93

Added line #L93 was not covered by tests

# Extract the corresponding tensors from the original EinExpr
tensors_in_partition = [original_path.args[idx] for idx in tensor_indices_in_partition]

Check warning on line 96 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L96

Added line #L96 was not covered by tests

# Sum (contract) these tensors to create a new tensor
new_tensor = sum(tensors_in_partition)

Check warning on line 99 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L99

Added line #L99 was not covered by tests

# @show tensors_in_partition
# @show tensor_indices_in_partition
# @show new_tensor

# Create a new EinExpr with the new tensor's indices in the head, the original tensors in the args, and the size dictionary
new_einexpr = EinExpr(new_tensor.head, tensors_in_partition)

Check warning on line 106 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L106

Added line #L106 was not covered by tests

return new_einexpr

Check warning on line 108 in src/Optimizers/Hyper.jl

View check run for this annotation

Codecov / codecov/patch

src/Optimizers/Hyper.jl#L108

Added line #L108 was not covered by tests
end