Skip to content

Commit

Permalink
Remove indices on selectdim to 1 dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Sep 28, 2023
1 parent fc0cd69 commit 7bd9fb0
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion src/Slicing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,27 @@ function Base.selectdim(path::EinExpr, index::Symbol, i)
path = deepcopy(path)

for leave in Iterators.filter((index) head, Leaves(path))
leave.size[index] = length(i)
leave.size[index] = length(i)
end

return path
end

function Base.selectdim(path::EinExpr, index::Symbol, _::Integer)
path = deepcopy(path)

index head(path) && (path = EinExpr(filter(!=(index), path.head), path.args))

for branch in Branches(path)
for arg in Iterators.filter((index) head, branch.args)
replace!(
branch.args,
arg => EinExpr(
filter(!=(index), arg.head),
isempty(arg.args) ? filter(p -> p.first != index, arg.size) : arg.args,
),
)
end
end

return path
Expand Down

0 comments on commit 7bd9fb0

Please sign in to comment.