Added GELU activation. (#52)
mtortonesi authored Apr 25, 2024
1 parent 4b09248 commit 5b347b1
Showing 4 changed files with 29 additions and 0 deletions.
1 change: 1 addition & 0 deletions lib/torch.rb
@@ -124,6 +124,7 @@

# nn activations
require_relative "torch/nn/elu"
require_relative "torch/nn/gelu"
require_relative "torch/nn/hardshrink"
require_relative "torch/nn/leaky_relu"
require_relative "torch/nn/log_sigmoid"
4 changes: 4 additions & 0 deletions lib/torch/nn/functional.rb
@@ -182,6 +182,10 @@ def elu(input, alpha: 1, inplace: false)
  end
end

def gelu(input, approximate: 'none')
  NN.gelu(input, approximate: approximate)
end

def hardshrink(input, lambd = 0.5)
  Torch.hardshrink(input, lambd)
end
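A minimal usage sketch of the new functional entry point (not part of the commit). It assumes Torch::NN::Functional is callable directly, as with the other activation helpers, and that the approximate keyword accepts 'none' or 'tanh' as in PyTorch:

require "torch"

input = Torch.randn(2, 3)

# Exact GELU (the default, approximate: 'none').
exact = Torch::NN::Functional.gelu(input)

# Tanh approximation, mirroring PyTorch's `approximate` keyword.
approx = Torch::NN::Functional.gelu(input, approximate: 'tanh')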
18 changes: 18 additions & 0 deletions lib/torch/nn/gelu.rb
@@ -0,0 +1,18 @@
module Torch
  module NN
    class GELU < Module
      def initialize(approximate: 'none')
        super()
        @approximate = approximate
      end

      def forward(input)
        F.gelu(input, approximate: @approximate)
      end

      def extra_inspect
        "approximate: '#{@approximate}'"
      end
    end
  end
end
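For comparison, a short sketch of the module form added above (again, not part of the commit; approximate: 'tanh' is assumed to pass through to the same approximation PyTorch supports):

gelu = Torch::NN::GELU.new                            # exact GELU
gelu_tanh = Torch::NN::GELU.new(approximate: 'tanh')  # tanh approximation

input = Torch.randn(2, 3)
output = gelu.call(input)  # Module#call dispatches to #forward

# extra_inspect feeds the module's inspect output, e.g. GELU(approximate: 'none')
puts gelu.inspect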
6 changes: 6 additions & 0 deletions test/nn/activations_test.rb
@@ -72,4 +72,10 @@ def test_tanhshrink
    input = Torch.randn(2)
    _output = m.call(input)
  end

  def test_gelu
    m = Torch::NN::GELU.new
    input = Torch.randn(2)
    _output = m.call(input)
  end
end
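Beyond the smoke test above, a hedged numerical check: exact GELU should match 0.5 * x * (1 + erf(x / sqrt(2))). This sketch assumes Torch.erf and Torch.allclose are bound in torch.rb, as their PyTorch counterparts are:

x = Torch.randn(5)
expected = x * 0.5 * (Torch.erf(x / Math.sqrt(2)) + 1)
actual = Torch::NN::Functional.gelu(x)
puts Torch.allclose(actual, expected)  # => true within float tolerance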
