-
Notifications
You must be signed in to change notification settings - Fork 523
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Seperate SafeCheckpointReaderHandle.
- Loading branch information
1 parent
559d471
commit 016294d
Showing
2 changed files
with
74 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,100 +1,69 @@ | ||
using Tensorflow.Util; | ||
namespace Tensorflow.Checkpoint; | ||
|
||
namespace Tensorflow.Checkpoint | ||
public class CheckpointReader | ||
{ | ||
sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle | ||
{ | ||
public SafeCheckpointReaderHandle(): base() | ||
{ | ||
|
||
} | ||
public SafeCheckpointReaderHandle(IntPtr handle): base(handle) | ||
{ | ||
private SafeCheckpointReaderHandle _handle; | ||
public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; } | ||
public Dictionary<string, Shape> VariableToShapeMap { get; set; } | ||
|
||
} | ||
|
||
protected override bool ReleaseHandle() | ||
{ | ||
c_api.TF_DeleteCheckpointReader(handle); | ||
SetHandle(IntPtr.Zero); | ||
return true; | ||
} | ||
} | ||
public class CheckpointReader | ||
public CheckpointReader(string filename) | ||
{ | ||
private SafeCheckpointReaderHandle _handle; | ||
public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; } | ||
public Dictionary<string, Shape> VariableToShapeMap { get; set; } | ||
|
||
public CheckpointReader(string filename) | ||
{ | ||
Status status = new Status(); | ||
_handle = c_api.TF_NewCheckpointReader(filename, status.Handle); | ||
status.Check(true); | ||
ReadAllShapeAndType(); | ||
} | ||
Status status = new Status(); | ||
VariableToDataTypeMap = new Dictionary<string, TF_DataType>(); | ||
VariableToShapeMap = new Dictionary<string, Shape>(); | ||
_handle = c_api.TF_NewCheckpointReader(filename, status.Handle); | ||
status.Check(true); | ||
ReadAllShapeAndType(); | ||
} | ||
|
||
public int HasTensor(string name) | ||
{ | ||
return c_api.TF_CheckpointReaderHasTensor(_handle, name); | ||
} | ||
public int HasTensor(string name) | ||
=> c_api.TF_CheckpointReaderHasTensor(_handle, name); | ||
|
||
/// <summary> | ||
/// Get the variable name. | ||
/// </summary> | ||
/// <param name="index"></param> | ||
/// <returns></returns> | ||
public string GetVariable(int index) | ||
{ | ||
return c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index)); | ||
} | ||
/// <summary> | ||
/// Get the variable name. | ||
/// </summary> | ||
/// <param name="index"></param> | ||
/// <returns></returns> | ||
public string GetVariable(int index) | ||
=> c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index)); | ||
|
||
public int Size() | ||
{ | ||
return c_api.TF_CheckpointReaderSize(_handle); | ||
} | ||
public int Size() | ||
=> c_api.TF_CheckpointReaderSize(_handle); | ||
|
||
public TF_DataType GetVariableDataType(string name) | ||
{ | ||
return c_api.TF_CheckpointReaderGetVariableDataType(_handle, name); | ||
} | ||
public TF_DataType GetVariableDataType(string name) | ||
=> c_api.TF_CheckpointReaderGetVariableDataType(_handle, name); | ||
|
||
public Shape GetVariableShape(string name) | ||
{ | ||
int num_dims = GetVariableNumDims(name); | ||
long[] dims = new long[num_dims]; | ||
Status status = new Status(); | ||
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle); | ||
status.Check(true); | ||
return new Shape(dims); | ||
} | ||
public Shape GetVariableShape(string name) | ||
{ | ||
int num_dims = GetVariableNumDims(name); | ||
long[] dims = new long[num_dims]; | ||
Status status = new Status(); | ||
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle); | ||
status.Check(true); | ||
return new Shape(dims); | ||
} | ||
|
||
public int GetVariableNumDims(string name) | ||
{ | ||
return c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name); | ||
} | ||
public int GetVariableNumDims(string name) | ||
=> c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name); | ||
|
||
public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) | ||
{ | ||
Status status = new Status(); | ||
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle); | ||
status.Check(true); | ||
return new Tensor(tensor); | ||
} | ||
public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) | ||
{ | ||
Status status = new Status(); | ||
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle); | ||
status.Check(true); | ||
return new Tensor(tensor); | ||
} | ||
|
||
private void ReadAllShapeAndType() | ||
private void ReadAllShapeAndType() | ||
{ | ||
int size = Size(); | ||
for(int i = 0; i < size; i++) | ||
{ | ||
VariableToDataTypeMap = new Dictionary<string, TF_DataType>(); | ||
VariableToShapeMap = new Dictionary<string, Shape>(); | ||
int size = Size(); | ||
for(int i = 0; i < size; i++) | ||
{ | ||
var name = GetVariable(i); | ||
var shape = GetVariableShape(name); | ||
var dtype = GetVariableDataType(name); | ||
VariableToDataTypeMap[name] = dtype; | ||
VariableToShapeMap[name] = shape; | ||
} | ||
var name = GetVariable(i); | ||
var shape = GetVariableShape(name); | ||
var dtype = GetVariableDataType(name); | ||
VariableToDataTypeMap[name] = dtype; | ||
VariableToShapeMap[name] = shape; | ||
} | ||
} | ||
} |
21 changes: 21 additions & 0 deletions
21
src/TensorFlowNET.Core/Checkpoint/SafeCheckpointReaderHandle.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
using Tensorflow.Util; | ||
|
||
namespace Tensorflow.Checkpoint; | ||
|
||
public sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle | ||
{ | ||
private SafeCheckpointReaderHandle() : base () | ||
{ | ||
} | ||
|
||
public SafeCheckpointReaderHandle(IntPtr handle) : base(handle) | ||
{ | ||
} | ||
|
||
protected override bool ReleaseHandle() | ||
{ | ||
c_api.TF_DeleteCheckpointReader(handle); | ||
SetHandle(IntPtr.Zero); | ||
return true; | ||
} | ||
} |