Skip to content

Commit

Permalink
Minor fix to AMDs device id tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
codetalker7 committed Aug 20, 2023
1 parent b47a6f4 commit f1ab569
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions test/ext_amdgpu/get_devices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ if AMDGPU.functional() && AMDGPU.functional(:MIOpen)
for ordinal in 0:(length(AMDGPU.devices()) - 1)
amd_device = Flux.get_device("AMD", ordinal)
@test typeof(amd_device.deviceID) <: AMDGPU.HIPDevice
@test AMDGPU.device_id(amd_device.deviceID) == ordinal
@test AMDGPU.device_id(amd_device.deviceID) == ordinal + 1

dense_model = dense_model |> amd_device
@test dense_model.weight isa AMDGPU.ROCArray
@test dense_model.bias isa AMDGPU.ROCArray
@test ADMGPU.device_id(AMDGPU.device(dense_model.weight)) == ordinal
@test ADMGPU.device_id(AMDGPU.device(dense_model.bias)) == ordinal
@test ADMGPU.device_id(AMDGPU.device(dense_model.weight)) == ordinal + 1
@test ADMGPU.device_id(AMDGPU.device(dense_model.bias)) == ordinal + 1
@test isequal(Flux.cpu(dense_model.weight), weight)
@test isequal(Flux.cpu(dense_model.bias), bias)
end
Expand Down

0 comments on commit f1ab569

Please sign in to comment.