Skip to content

Commit

Permalink
Seperate SafeCheckpointReaderHandle.
Browse files Browse the repository at this point in the history
  • Loading branch information
Oceania2018 committed Mar 3, 2023
1 parent 559d471 commit 016294d
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 84 deletions.
137 changes: 53 additions & 84 deletions src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
Original file line number Diff line number Diff line change
@@ -1,100 +1,69 @@
using Tensorflow.Util;
namespace Tensorflow.Checkpoint;

namespace Tensorflow.Checkpoint
public class CheckpointReader
{
sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle
{
public SafeCheckpointReaderHandle(): base()
{

}
public SafeCheckpointReaderHandle(IntPtr handle): base(handle)
{
private SafeCheckpointReaderHandle _handle;
public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; }
public Dictionary<string, Shape> VariableToShapeMap { get; set; }

}

protected override bool ReleaseHandle()
{
c_api.TF_DeleteCheckpointReader(handle);
SetHandle(IntPtr.Zero);
return true;
}
}
public class CheckpointReader
public CheckpointReader(string filename)
{
private SafeCheckpointReaderHandle _handle;
public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; }
public Dictionary<string, Shape> VariableToShapeMap { get; set; }

public CheckpointReader(string filename)
{
Status status = new Status();
_handle = c_api.TF_NewCheckpointReader(filename, status.Handle);
status.Check(true);
ReadAllShapeAndType();
}
Status status = new Status();
VariableToDataTypeMap = new Dictionary<string, TF_DataType>();
VariableToShapeMap = new Dictionary<string, Shape>();
_handle = c_api.TF_NewCheckpointReader(filename, status.Handle);
status.Check(true);
ReadAllShapeAndType();
}

public int HasTensor(string name)
{
return c_api.TF_CheckpointReaderHasTensor(_handle, name);
}
public int HasTensor(string name)
=> c_api.TF_CheckpointReaderHasTensor(_handle, name);

/// <summary>
/// Get the variable name.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
public string GetVariable(int index)
{
return c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index));
}
/// <summary>
/// Get the variable name.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
public string GetVariable(int index)
=> c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index));

public int Size()
{
return c_api.TF_CheckpointReaderSize(_handle);
}
public int Size()
=> c_api.TF_CheckpointReaderSize(_handle);

public TF_DataType GetVariableDataType(string name)
{
return c_api.TF_CheckpointReaderGetVariableDataType(_handle, name);
}
public TF_DataType GetVariableDataType(string name)
=> c_api.TF_CheckpointReaderGetVariableDataType(_handle, name);

public Shape GetVariableShape(string name)
{
int num_dims = GetVariableNumDims(name);
long[] dims = new long[num_dims];
Status status = new Status();
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle);
status.Check(true);
return new Shape(dims);
}
public Shape GetVariableShape(string name)
{
int num_dims = GetVariableNumDims(name);
long[] dims = new long[num_dims];
Status status = new Status();
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle);
status.Check(true);
return new Shape(dims);
}

public int GetVariableNumDims(string name)
{
return c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name);
}
public int GetVariableNumDims(string name)
=> c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name);

public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
{
Status status = new Status();
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle);
status.Check(true);
return new Tensor(tensor);
}
public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
{
Status status = new Status();
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle);
status.Check(true);
return new Tensor(tensor);
}

private void ReadAllShapeAndType()
private void ReadAllShapeAndType()
{
int size = Size();
for(int i = 0; i < size; i++)
{
VariableToDataTypeMap = new Dictionary<string, TF_DataType>();
VariableToShapeMap = new Dictionary<string, Shape>();
int size = Size();
for(int i = 0; i < size; i++)
{
var name = GetVariable(i);
var shape = GetVariableShape(name);
var dtype = GetVariableDataType(name);
VariableToDataTypeMap[name] = dtype;
VariableToShapeMap[name] = shape;
}
var name = GetVariable(i);
var shape = GetVariableShape(name);
var dtype = GetVariableDataType(name);
VariableToDataTypeMap[name] = dtype;
VariableToShapeMap[name] = shape;
}
}
}
21 changes: 21 additions & 0 deletions src/TensorFlowNET.Core/Checkpoint/SafeCheckpointReaderHandle.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using Tensorflow.Util;

namespace Tensorflow.Checkpoint;

public sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle
{
private SafeCheckpointReaderHandle() : base ()
{
}

public SafeCheckpointReaderHandle(IntPtr handle) : base(handle)
{
}

protected override bool ReleaseHandle()
{
c_api.TF_DeleteCheckpointReader(handle);
SetHandle(IntPtr.Zero);
return true;
}
}

0 comments on commit 016294d

Please sign in to comment.