diff --git a/lib/torch.rb b/lib/torch.rb
index 7d8aa2d..bc5246f 100644
--- a/lib/torch.rb
+++ b/lib/torch.rb
@@ -124,6 +124,7 @@
 
 # nn activations
 require_relative "torch/nn/elu"
+require_relative "torch/nn/gelu"
 require_relative "torch/nn/hardshrink"
 require_relative "torch/nn/leaky_relu"
 require_relative "torch/nn/log_sigmoid"
diff --git a/lib/torch/nn/functional.rb b/lib/torch/nn/functional.rb
index 3c85d75..472ced8 100644
--- a/lib/torch/nn/functional.rb
+++ b/lib/torch/nn/functional.rb
@@ -182,6 +182,10 @@ def elu(input, alpha: 1, inplace: false)
           end
         end
 
+        def gelu(input, approximate: 'none')
+          NN.gelu(input, approximate: approximate)
+        end
+
         def hardshrink(input, lambd = 0.5)
           Torch.hardshrink(input, lambd)
         end
diff --git a/lib/torch/nn/gelu.rb b/lib/torch/nn/gelu.rb
new file mode 100644
index 0000000..0b502a6
--- /dev/null
+++ b/lib/torch/nn/gelu.rb
@@ -0,0 +1,18 @@
+module Torch
+  module NN
+    class GELU < Module
+      def initialize(approximate: 'none')
+        super()
+        @approximate = approximate
+      end
+
+      def forward(input)
+        F.gelu(input, approximate: @approximate)
+      end
+
+      def extra_inspect
+        format("approximate: '%s'", @approximate)
+      end
+    end
+  end
+end
diff --git a/test/nn/activations_test.rb b/test/nn/activations_test.rb
index 73d3652..d866863 100644
--- a/test/nn/activations_test.rb
+++ b/test/nn/activations_test.rb
@@ -72,4 +72,10 @@ def test_tanhshrink
     input = Torch.randn(2)
     _output = m.call(input)
   end
+
+  def test_gelu
+    m = Torch::NN::GELU.new
+    input = Torch.randn(2)
+    _output = m.call(input)
+  end
 end
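For reference, a minimal usage sketch of what this diff enables. Assumptions: the functional form is reachable as Torch::NN::Functional.gelu (as with the other activations in functional.rb), and the underlying libtorch build is recent enough to accept 'none' and 'tanh' for approximate:

    require "torch"

    input = Torch.randn(2)

    # Module form, mirroring the new test: tanh approximation
    m = Torch::NN::GELU.new(approximate: "tanh")
    output = m.call(input)
    puts m.inspect  # should include the approximate setting via extra_inspect

    # Functional form; approximate defaults to 'none' (exact GELU)
    output = Torch::NN::Functional.gelu(input)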