This repository has been archived by the owner on Jan 13, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 13
/
init.lua
67 lines (59 loc) · 1.68 KB
/
init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
--
-- Copyright (c) 2015, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
require 'torch'
local ffi = require 'ffi'
local argcheck = require 'argcheck'
local C = require 'ztorch.THZ'
local ztorch = require 'ztorch.env'
ztorch.complex = require 'ztorch.complex'
ztorch.fcomplex = require 'ztorch.fcomplex'
--- Tell whether a value is a ztorch complex scalar.
-- @param v any Lua value
-- @treturn boolean true if `v` is an instance of the ctype stored in
--   `ztorch.complex.type` or in `ztorch.fcomplex.type`, false otherwise
function ztorch.isComplex(v)
   if ffi.istype(ztorch.complex.type, v) then
      return true
   end
   return ffi.istype(ztorch.fcomplex.type, v)
end
-- Unchecked pure-imaginary constructor: ztorch.im(y) returns an ffi
-- 'complex' cdata with real part 0 and imaginary part y.
-- NOTE(review): ztorch.im is assigned again further down in this file with
-- an argcheck-validated version. This definition is only what ztorch.im
-- refers to until that reassignment (e.g. while ztorch.Storage and
-- ztorch.Tensor are being required). Consider keeping just one definition.
ztorch.im = function(imag)
   return ffi.new('complex', 0, imag)
end
local argcheckenv = require 'argcheck.env'

-- Replace argcheck's type predicate. Because `require` caches modules, this
-- affects every argcheck user in the process. argcheck presumably consults
-- env.istype when validating `type=` declarations (confirm against the
-- argcheck version in use).
--
-- A value matches `typename` when its torch typename is exactly `typename`,
-- or failing that, when its plain Lua type() equals `typename`. Only exact
-- name equality is checked; there is no class-hierarchy lookup.
argcheckenv.istype = function(obj, typename)
   local tname = torch.typename(obj)
   if tname and tname == typename then
      return true
   end
   return type(obj) == typename
end
require 'ztorch.Storage'
require 'ztorch.Tensor'
-- Real-part constructor: ztorch.re(x) returns an ffi 'complex' cdata x + 0i.
-- argcheck validates that the single argument is a number; nonamed=true
-- disables argcheck's named-argument (table) call form.
local function complexFromReal(x)
   return ffi.new('complex', x, 0)
end

ztorch.re = argcheck{
   {name='value', type='number'},
   nonamed=true,
   call=complexFromReal,
}
-- Imaginary-part constructor: ztorch.im(y) returns an ffi 'complex' cdata
-- 0 + y*i. This replaces the unchecked ztorch.im assigned earlier in the
-- file: argcheck now validates that the single argument is a number, and
-- nonamed=true disables argcheck's named-argument (table) call form.
local function complexFromImag(y)
   return ffi.new('complex', 0, y)
end

ztorch.im = argcheck{
   {name='value', type='number'},
   nonamed=true,
   call=complexFromImag,
}
-- HACK: until we get torch.isTypeOf to work with complex tensors and storages
-- objects are classified by pattern-matching their torch typename instead.

-- Returns true when `obj` has a torch typename containing a match for the
-- Lua pattern `pattern`; returns false otherwise, including for objects
-- without a torch typename.
local function typenameMatches(obj, pattern)
   local name = torch.typename(obj)
   if not name then
      return false
   end
   return name:find(pattern) ~= nil
end

--- True if obj's torch typename matches the Lua pattern 'torch.*Tensor'.
-- The pattern is unanchored and '.' matches any character, so any typename
-- containing "torch" followed later by "Tensor" qualifies (e.g.
-- 'torch.FloatTensor').
function torch.isTensor(obj)
   return typenameMatches(obj, 'torch.*Tensor')
end

--- True if obj's torch typename matches the Lua pattern 'torch.*Storage'
-- (unanchored, same matching rules as torch.isTensor).
function torch.isStorage(obj)
   return typenameMatches(obj, 'torch.*Storage')
end
return ztorch