From 016294d33fe7a8600561fa14e659f7651bbf7fde Mon Sep 17 00:00:00 2001 From: Haiping Chen Date: Fri, 3 Mar 2023 10:35:58 -0600 Subject: [PATCH] Seperate SafeCheckpointReaderHandle. --- .../Checkpoint/CheckpointReader.cs | 137 +++++++----------- .../Checkpoint/SafeCheckpointReaderHandle.cs | 21 +++ 2 files changed, 74 insertions(+), 84 deletions(-) create mode 100644 src/TensorFlowNET.Core/Checkpoint/SafeCheckpointReaderHandle.cs diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs index 0cc8e5fbd..ffefe3128 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs @@ -1,100 +1,69 @@ -using Tensorflow.Util; +namespace Tensorflow.Checkpoint; -namespace Tensorflow.Checkpoint +public class CheckpointReader { - sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle - { - public SafeCheckpointReaderHandle(): base() - { - - } - public SafeCheckpointReaderHandle(IntPtr handle): base(handle) - { + private SafeCheckpointReaderHandle _handle; + public Dictionary VariableToDataTypeMap { get; set; } + public Dictionary VariableToShapeMap { get; set; } - } - - protected override bool ReleaseHandle() - { - c_api.TF_DeleteCheckpointReader(handle); - SetHandle(IntPtr.Zero); - return true; - } - } - public class CheckpointReader + public CheckpointReader(string filename) { - private SafeCheckpointReaderHandle _handle; - public Dictionary VariableToDataTypeMap { get; set; } - public Dictionary VariableToShapeMap { get; set; } - - public CheckpointReader(string filename) - { - Status status = new Status(); - _handle = c_api.TF_NewCheckpointReader(filename, status.Handle); - status.Check(true); - ReadAllShapeAndType(); - } + Status status = new Status(); + VariableToDataTypeMap = new Dictionary(); + VariableToShapeMap = new Dictionary(); + _handle = c_api.TF_NewCheckpointReader(filename, status.Handle); + status.Check(true); + ReadAllShapeAndType(); + } - public int HasTensor(string name) - { - return c_api.TF_CheckpointReaderHasTensor(_handle, name); - } + public int HasTensor(string name) + => c_api.TF_CheckpointReaderHasTensor(_handle, name); - /// - /// Get the variable name. - /// - /// - /// - public string GetVariable(int index) - { - return c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index)); - } + /// + /// Get the variable name. + /// + /// + /// + public string GetVariable(int index) + => c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index)); - public int Size() - { - return c_api.TF_CheckpointReaderSize(_handle); - } + public int Size() + => c_api.TF_CheckpointReaderSize(_handle); - public TF_DataType GetVariableDataType(string name) - { - return c_api.TF_CheckpointReaderGetVariableDataType(_handle, name); - } + public TF_DataType GetVariableDataType(string name) + => c_api.TF_CheckpointReaderGetVariableDataType(_handle, name); - public Shape GetVariableShape(string name) - { - int num_dims = GetVariableNumDims(name); - long[] dims = new long[num_dims]; - Status status = new Status(); - c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle); - status.Check(true); - return new Shape(dims); - } + public Shape GetVariableShape(string name) + { + int num_dims = GetVariableNumDims(name); + long[] dims = new long[num_dims]; + Status status = new Status(); + c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle); + status.Check(true); + return new Shape(dims); + } - public int GetVariableNumDims(string name) - { - return c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name); - } + public int GetVariableNumDims(string name) + => c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name); - public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) - { - Status status = new Status(); - var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle); - status.Check(true); - return new Tensor(tensor); - } + public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) + { + Status status = new Status(); + var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle); + status.Check(true); + return new Tensor(tensor); + } - private void ReadAllShapeAndType() + private void ReadAllShapeAndType() + { + int size = Size(); + for(int i = 0; i < size; i++) { - VariableToDataTypeMap = new Dictionary(); - VariableToShapeMap = new Dictionary(); - int size = Size(); - for(int i = 0; i < size; i++) - { - var name = GetVariable(i); - var shape = GetVariableShape(name); - var dtype = GetVariableDataType(name); - VariableToDataTypeMap[name] = dtype; - VariableToShapeMap[name] = shape; - } + var name = GetVariable(i); + var shape = GetVariableShape(name); + var dtype = GetVariableDataType(name); + VariableToDataTypeMap[name] = dtype; + VariableToShapeMap[name] = shape; } } } diff --git a/src/TensorFlowNET.Core/Checkpoint/SafeCheckpointReaderHandle.cs b/src/TensorFlowNET.Core/Checkpoint/SafeCheckpointReaderHandle.cs new file mode 100644 index 000000000..674e83512 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/SafeCheckpointReaderHandle.cs @@ -0,0 +1,21 @@ +using Tensorflow.Util; + +namespace Tensorflow.Checkpoint; + +public sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle +{ + private SafeCheckpointReaderHandle() : base () + { + } + + public SafeCheckpointReaderHandle(IntPtr handle) : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TF_DeleteCheckpointReader(handle); + SetHandle(IntPtr.Zero); + return true; + } +}