-
Notifications
You must be signed in to change notification settings - Fork 13
/
Sampler.lua
59 lines (43 loc) · 1.67 KB
/
Sampler.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
------------------------------------------------------------------------
--[[ Sampler ]]--
-- Hyper-parameter sampling distributions. Each method validates its
-- arguments and returns a single value drawn from the named distribution.
------------------------------------------------------------------------
local SAMPLER_CLASS_NAME = "hypero.Sampler"
local Sampler = torch.class(SAMPLER_CLASS_NAME)
-- Sample from a categorical distribution.
-- probs : table of non-negative weights, one per category. They are handed to
--         torch.multinomial, which treats them as weights (they need not sum
--         to 1).
-- vals  : optional table of values, one per category.
-- Returns vals[idx] for the sampled category idx when vals is given,
-- otherwise the (1-based) index idx itself.
function Sampler:categorical(probs, vals)
   assert(torch.type(probs) == 'table', "Expecting table of probabilities, got: "..tostring(probs))
   assert(#probs > 0, "Expecting at least one probability")
   local weights = torch.Tensor(probs)
   -- multinomial returns a 1-element LongTensor; [1] extracts the index
   local idx = torch.multinomial(weights, 1)[1]
   if not vals then
      return idx
   end
   -- Explicit branch instead of `vals and vals[idx] or idx`: that idiom
   -- returned the index whenever the chosen value was false (e.g. vals =
   -- {true, false}), and silently returned an index when vals was too short.
   local val = vals[idx]
   assert(val ~= nil, "No value in vals for sampled category "..idx)
   return val
end
-- Draw one sample from a Gaussian with the given mean and standard
-- deviation (both must be plain Lua numbers), using torch's RNG.
function Sampler:normal(mean, std)
   local numericArgs = torch.type(mean) == 'number' and torch.type(std) == 'number'
   assert(numericArgs)
   return (torch.normal(mean, std))
end
-- Draw one real number uniformly between minval and maxval (both plain Lua
-- numbers) via torch.uniform; see its documentation for endpoint handling.
function Sampler:uniform(minval, maxval)
   local numericArgs = torch.type(minval) == 'number' and torch.type(maxval) == 'number'
   assert(numericArgs)
   return (torch.uniform(minval, maxval))
end
-- Sample x such that log(x) is uniformly distributed: draw an exponent u
-- with torch.uniform(minval, maxval) and return exp(u). Note that minval and
-- maxval bound log(x), not x; the result lies between exp(minval) and
-- exp(maxval).
function Sampler:logUniform(minval, maxval)
   local numericArgs = torch.type(minval) == 'number' and torch.type(maxval) == 'number'
   assert(numericArgs)
   local exponent = torch.uniform(minval, maxval)
   return (torch.exp(exponent))
end
-- Sample an integer uniformly from the inclusive range [minval, maxval].
-- Uses torch.random rather than math.random so that, like the other
-- samplers in this class, the draw comes from torch's RNG and is therefore
-- reproducible under torch.manualSeed (math.random has its own, separately
-- seeded generator).
function Sampler:randint(minval, maxval)
   assert(torch.type(minval) == 'number')
   assert(torch.type(maxval) == 'number')
   assert(minval <= maxval, "randint: expecting minval <= maxval, got "..minval.." > "..maxval)
   local val = torch.random(minval, maxval)
   return val
end