diff --git a/CHANGELOG.md b/CHANGELOG.md
index f061397..91cdd79 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,7 @@
 ## 0.15.0 (unreleased)
 
 - Updated LibTorch to 2.2.0
+- Fixed error with `inspect` for MPS tensors
 
 ## 0.14.1 (2023-12-26)
 
diff --git a/ext/torch/tensor.cpp b/ext/torch/tensor.cpp
index 8cc2677..4f3e735 100644
--- a/ext/torch/tensor.cpp
+++ b/ext/torch/tensor.cpp
@@ -103,6 +103,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
   rb_cTensor
     .define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
+    .define_method("mps?", [](Tensor& self) { return self.is_mps(); })
     .define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
     .define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
     .define_method("dim", [](Tensor& self) { return self.dim(); })
diff --git a/lib/torch/inspector.rb b/lib/torch/inspector.rb
index d91e9ea..c72026c 100644
--- a/lib/torch/inspector.rb
+++ b/lib/torch/inspector.rb
@@ -31,9 +31,9 @@ def initialize(tensor)
         return if nonzero_finite_vals.numel == 0
 
         # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
-        nonzero_finite_abs = nonzero_finite_vals.abs.double
-        nonzero_finite_min = nonzero_finite_abs.min.double
-        nonzero_finite_max = nonzero_finite_abs.max.double
+        nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs)
+        nonzero_finite_min = tensor_totype(nonzero_finite_abs.min)
+        nonzero_finite_max = tensor_totype(nonzero_finite_abs.max)
 
         nonzero_finite_vals.each do |value|
           if value.item != value.item.ceil
@@ -107,6 +107,11 @@ def format(value)
         # Ruby throws error when negative, Python doesn't
         " " * [@max_width - ret.size, 0].max + ret
       end
+
+      def tensor_totype(t)
+        dtype = t.mps? ? :float : :double
+        t.to(dtype: dtype)
+      end
     end
 
     def inspect
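
A minimal usage sketch of the change, assuming torch.rb is built against a LibTorch with MPS support on macOS (the tensor shape and device string here are only illustrative):

    require "torch"

    x = Torch.rand(3, device: "mps")  # tensor stored on the Apple MPS device
    x.mps?  # => true, via the new `mps?` method bound in tensor.cpp
    p x     # `inspect` now converts MPS tensors to float32 (MPS has no float64), so this no longer raises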