forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
/
BinaryTextClassification.cs
164 lines (134 loc) · 5.68 KB
/
BinaryTextClassification.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
using System;
using System.Collections.Generic;
using System.IO;
using Tensorflow;
using Newtonsoft.Json;
using System.Linq;
using NumSharp;
namespace TensorFlowNET.Examples
{
/// <summary>
/// This example classifies movie reviews as positive or negative using the text of the review.
/// This is a binary—or two-class—classification, an important and widely applicable kind of machine learning problem.
/// https://github.com/tensorflow/docs/blob/master/site/en/tutorials/keras/basic_text_classification.ipynb
/// </summary>
public class BinaryTextClassification : IExample
{
public bool Enabled { get; set; } = false;
public string Name => "Binary Text Classification";
public bool IsImportingGraph { get; set; } = true;
string dir = "binary_text_classification";
string dataFile = "imdb.zip";
NDArray train_data, train_labels, test_data, test_labels;
public bool Run()
{
PrepareData();
Console.WriteLine($"Training entries: {train_data.shape[0]}, labels: {train_labels.shape[0]}");
// A dictionary mapping words to an integer index
var word_index = GetWordIndex();
/*train_data = keras.preprocessing.sequence.pad_sequences(train_data,
value: word_index["<PAD>"],
padding: "post",
maxlen: 256);
test_data = keras.preprocessing.sequence.pad_sequences(test_data,
value: word_index["<PAD>"],
padding: "post",
maxlen: 256);*/
// input shape is the vocabulary count used for the movie reviews (10,000 words)
int vocab_size = 10000;
var model = keras.Sequential();
var layer = keras.layers.Embedding(vocab_size, 16);
model.add(layer);
return false;
}
public void PrepareData()
{
Directory.CreateDirectory(dir);
// get model file
string url = $"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/{dataFile}";
Utility.Web.Download(url, dir, "imdb.zip");
Utility.Compress.UnZip(Path.Join(dir, $"imdb.zip"), dir);
// prepare training dataset
var x_train = ReadData(Path.Join(dir, "x_train.txt"));
var labels_train = ReadData(Path.Join(dir, "y_train.txt"));
var indices_train = ReadData(Path.Join(dir, "indices_train.txt"));
x_train = x_train[indices_train];
labels_train = labels_train[indices_train];
var x_test = ReadData(Path.Join(dir, "x_test.txt"));
var labels_test = ReadData(Path.Join(dir, "y_test.txt"));
var indices_test = ReadData(Path.Join(dir, "indices_test.txt"));
x_test = x_test[indices_test];
labels_test = labels_test[indices_test];
// not completed
var xs = x_train.hstack(x_test);
var labels = labels_train.hstack(labels_test);
var idx = x_train.size;
var y_train = labels_train;
var y_test = labels_test;
// convert x_train
train_data = new NDArray(np.int32, (x_train.size, 256));
/*for (int i = 0; i < x_train.size; i++)
train_data[i] = x_train[i].Data<string>()[1].Split(',').Select(x => int.Parse(x)).ToArray();*/
test_data = new NDArray(np.int32, (x_test.size, 256));
/*for (int i = 0; i < x_test.size; i++)
test_data[i] = x_test[i].Data<string>()[1].Split(',').Select(x => int.Parse(x)).ToArray();*/
train_labels = y_train;
test_labels = y_test;
}
private NDArray ReadData(string file)
{
var lines = File.ReadAllLines(file);
var nd = new NDArray(lines[0].StartsWith("[") ? typeof(string) : np.int32, new Shape(lines.Length));
if (lines[0].StartsWith("["))
{
for (int i = 0; i < lines.Length; i++)
{
/*var matches = Regex.Matches(lines[i], @"\d+\s*");
var data = new int[matches.Count];
for (int j = 0; j < data.Length; j++)
data[j] = Convert.ToInt32(matches[j].Value);
nd[i] = data.ToArray();*/
nd[i] = lines[i].Substring(1, lines[i].Length - 2).Replace(" ", string.Empty);
}
}
else
{
for (int i = 0; i < lines.Length; i++)
nd[i] = Convert.ToInt32(lines[i]);
}
return nd;
}
private Dictionary<string, int> GetWordIndex()
{
var result = new Dictionary<string, int>();
var json = File.ReadAllText(Path.Join(dir, "imdb_word_index.json"));
var dict = JsonConvert.DeserializeObject<Dictionary<string, int>>(json);
dict.Keys.Select(k => result[k] = dict[k] + 3).ToList();
result["<PAD>"] = 0;
result["<START>"] = 1;
result["<UNK>"] = 2; // unknown
result["<UNUSED>"] = 3;
return result;
}
public Graph ImportGraph()
{
throw new NotImplementedException();
}
public Graph BuildGraph()
{
throw new NotImplementedException();
}
public void Train(Session sess)
{
throw new NotImplementedException();
}
public void Predict(Session sess)
{
throw new NotImplementedException();
}
public void Test(Session sess)
{
throw new NotImplementedException();
}
}
}