Remove all code related to loading and saving text file network formats. #181

Open
wants to merge 1 commit into base: master
tf/net.py: 251 changes (11 additions, 240 deletions)
@@ -78,19 +78,7 @@ def set_input(self, input_format):
elif input_format != pb.NetworkFormat.INPUT_CLASSICAL_112_PLANE:
self.pb.min_version.minor = LC0_MINOR_WITH_INPUT_TYPE_3

def get_weight_amounts(self):
value_weights = 8
policy_weights = 6
head_weights = value_weights + policy_weights
if self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT:
# Batch norm gammas in head convolutions.
head_weights += 2
if self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT:
return {"input": 5, "residual": 14, "head": head_weights}
else:
return {"input": 4, "residual": 8, "head": head_weights}

def fill_layer_v2(self, layer, params):
def fill_layer(self, layer, params):
"""Normalize and populate 16bit layer in protobuf"""
params = params.flatten().astype(np.float32)
layer.min_val = 0 if len(params) == 1 else float(np.min(params))
@@ -105,112 +93,12 @@ def fill_layer_v2(self, layer, params):
params = np.round(params)
layer.params = params.astype(np.uint16).tobytes()

def fill_layer(self, layer, weights):
"""Normalize and populate 16bit layer in protobuf"""
params = np.array(weights.pop(), dtype=np.float32)
layer.min_val = 0 if len(params) == 1 else float(np.min(params))
layer.max_val = 1 if len(params) == 1 and np.max(
params) == 0 else float(np.max(params))
if layer.max_val == layer.min_val:
# Avoid division by zero if max == min.
params = (params - layer.min_val)
else:
params = (params - layer.min_val) / (layer.max_val - layer.min_val)
params *= 0xffff
params = np.round(params)
layer.params = params.astype(np.uint16).tobytes()

def fill_conv_block(self, convblock, weights, gammas):
"""Normalize and populate 16bit convblock in protobuf"""
if gammas:
self.fill_layer(convblock.bn_stddivs, weights)
self.fill_layer(convblock.bn_means, weights)
self.fill_layer(convblock.bn_betas, weights)
self.fill_layer(convblock.bn_gammas, weights)
self.fill_layer(convblock.weights, weights)
else:
self.fill_layer(convblock.bn_stddivs, weights)
self.fill_layer(convblock.bn_means, weights)
self.fill_layer(convblock.biases, weights)
self.fill_layer(convblock.weights, weights)

def fill_plain_conv(self, convblock, weights):
"""Normalize and populate 16bit convblock in protobuf"""
self.fill_layer(convblock.biases, weights)
self.fill_layer(convblock.weights, weights)

def fill_se_unit(self, se_unit, weights):
self.fill_layer(se_unit.b2, weights)
self.fill_layer(se_unit.w2, weights)
self.fill_layer(se_unit.b1, weights)
self.fill_layer(se_unit.w1, weights)

def denorm_layer_v2(self, layer):
def denorm_layer(self, layer):
"""Denormalize a layer from protobuf"""
params = np.frombuffer(layer.params, np.uint16).astype(np.float32)
params /= 0xffff
return params * (layer.max_val - layer.min_val) + layer.min_val

def denorm_layer(self, layer, weights):
weights.insert(0, self.denorm_layer_v2(layer))

def denorm_conv_block(self, convblock, weights):
"""Denormalize a convblock from protobuf"""
se = self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT

if se:
self.denorm_layer(convblock.bn_stddivs, weights)
self.denorm_layer(convblock.bn_means, weights)
self.denorm_layer(convblock.bn_betas, weights)
self.denorm_layer(convblock.bn_gammas, weights)
self.denorm_layer(convblock.weights, weights)
else:
self.denorm_layer(convblock.bn_stddivs, weights)
self.denorm_layer(convblock.bn_means, weights)
self.denorm_layer(convblock.biases, weights)
self.denorm_layer(convblock.weights, weights)

def denorm_plain_conv(self, convblock, weights):
"""Denormalize a plain convolution from protobuf"""
self.denorm_layer(convblock.biases, weights)
self.denorm_layer(convblock.weights, weights)

def denorm_se_unit(self, convblock, weights):
"""Denormalize SE-unit from protobuf"""
se = self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT

assert se

self.denorm_layer(convblock.b2, weights)
self.denorm_layer(convblock.w2, weights)
self.denorm_layer(convblock.b1, weights)
self.denorm_layer(convblock.w1, weights)

def save_txt(self, filename):
"""Save weights as txt file"""
weights = self.get_weights()

if len(filename.split('.')) == 1:
filename += ".txt.gz"

# Legacy .txt files are version 2, SE is version 3.

version = 2
if self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT:
version = 3

if self.pb.format.network_format.policy == pb.NetworkFormat.POLICY_CONVOLUTION:
version = 4

with gzip.open(filename, 'wb') as f:
f.write("{}\n".format(version).encode('utf-8'))
for row in weights:
f.write(
(" ".join(map(str, row.tolist())) + "\n").encode('utf-8'))

size = os.path.getsize(filename) / 1024**2
print("saved as '{}' {}M".format(filename, round(size, 2)))

def save_proto(self, filename):
"""Save weights gzipped protobuf file"""
if len(filename.split('.')) == 1:
@@ -320,7 +208,7 @@ def moves_left_to_bp(l, w):

return (pb_name, block)

def get_weights_v2(self, names):
def get_weights(self, names):
# `names` is a list of Tensorflow tensor names to get from the protobuf.
# Returns a dict mapping tensor name to its weights.
tensors = {}
@@ -346,7 +234,7 @@ def get_weights_v2(self, names):
else:
pb_weights = self.pb.weights.residual[block]

w = self.denorm_layer_v2(nested_getattr(pb_weights, pb_name))
w = self.denorm_layer(nested_getattr(pb_weights, pb_name))

# Only variance is stored in the protobuf.
if 'stddev' in tf_name:
@@ -355,34 +243,6 @@ def get_weights_v2(self, names):
tensors[tf_name] = w
return tensors

def get_weights(self):
"""Returns the weights as floats per layer"""
se = self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT
if self.weights == []:
self.denorm_layer(self.pb.weights.ip2_val_b, self.weights)
self.denorm_layer(self.pb.weights.ip2_val_w, self.weights)
self.denorm_layer(self.pb.weights.ip1_val_b, self.weights)
self.denorm_layer(self.pb.weights.ip1_val_w, self.weights)
self.denorm_conv_block(self.pb.weights.value, self.weights)

if self.pb.format.network_format.policy == pb.NetworkFormat.POLICY_CONVOLUTION:
self.denorm_plain_conv(self.pb.weights.policy, self.weights)
self.denorm_conv_block(self.pb.weights.policy1, self.weights)
else:
self.denorm_layer(self.pb.weights.ip_pol_b, self.weights)
self.denorm_layer(self.pb.weights.ip_pol_w, self.weights)
self.denorm_conv_block(self.pb.weights.policy, self.weights)

for res in reversed(self.pb.weights.residual):
if se:
self.denorm_se_unit(res.se, self.weights)
self.denorm_conv_block(res.conv2, self.weights)
self.denorm_conv_block(res.conv1, self.weights)

self.denorm_conv_block(self.pb.weights.input, self.weights)

return self.weights

def filters(self):
layer = self.pb.weights.input.bn_means
params = np.frombuffer(layer.params, np.uint16).astype(np.float32)
@@ -414,27 +274,7 @@ def parse_proto(self, filename):
self.set_policyformat(pb.NetworkFormat.POLICY_CLASSICAL)
self.set_movesleftformat(pb.NetworkFormat.MOVES_LEFT_NONE)

def parse_txt(self, filename):
weights = []

with open(filename, 'r') as f:
try:
version = int(f.readline()[0])
except:
raise ValueError('Unable to read version.')
for e, line in enumerate(f):
weights.append(list(map(float, line.split(' '))))

if version == 3:
self.set_networkformat(pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT)

if version == 4:
self.set_networkformat(pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT)
self.set_policyformat(pb.NetworkFormat.POLICY_CONVOLUTION)

self.fill_net(weights)

def fill_net_v2(self, all_weights):
def fill_net(self, all_weights):
# all_weights is array of [name of weight, numpy array of weights].
self.pb.format.weights_encoding = pb.Format.LINEAR16

@@ -493,7 +333,7 @@ def fill_net_v2(self, all_weights):
self.pb.weights.residual.add()
pb_weights = self.pb.weights.residual[block]

self.fill_layer_v2(nested_getattr(pb_weights, pb_name), weights)
self.fill_layer(nested_getattr(pb_weights, pb_name), weights)

if pb_name.endswith('bn_betas'):
# Check if we need to add constant one gammas.
@@ -502,49 +342,7 @@ def fill_net_v2(self, all_weights):
continue
gamma = np.ones(weights.shape)
pb_gamma = pb_name.replace('bn_betas', 'bn_gammas')
self.fill_layer_v2(nested_getattr(pb_weights, pb_gamma), gamma)

def fill_net(self, weights):
self.weights = []
# Batchnorm gammas in ConvBlock?
se = self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT
gammas = se

ws = self.get_weight_amounts()

blocks = len(weights) - (ws['input'] + ws['head'])

if blocks % ws['residual'] != 0:
raise ValueError("Inconsistent number of weights in the file")
blocks //= ws['residual']

self.pb.format.weights_encoding = pb.Format.LINEAR16
self.fill_layer(self.pb.weights.ip2_val_b, weights)
self.fill_layer(self.pb.weights.ip2_val_w, weights)
self.fill_layer(self.pb.weights.ip1_val_b, weights)
self.fill_layer(self.pb.weights.ip1_val_w, weights)
self.fill_conv_block(self.pb.weights.value, weights, gammas)

if self.pb.format.network_format.policy == pb.NetworkFormat.POLICY_CONVOLUTION:
self.fill_plain_conv(self.pb.weights.policy, weights)
self.fill_conv_block(self.pb.weights.policy1, weights, gammas)
else:
self.fill_layer(self.pb.weights.ip_pol_b, weights)
self.fill_layer(self.pb.weights.ip_pol_w, weights)
self.fill_conv_block(self.pb.weights.policy, weights, gammas)

del self.pb.weights.residual[:]
tower = []
for i in range(blocks):
tower.append(self.pb.weights.residual.add())

for res in reversed(tower):
if se:
self.fill_se_unit(res.se, weights)
self.fill_conv_block(res.conv2, weights, gammas)
self.fill_conv_block(res.conv1, weights, gammas)

self.fill_conv_block(self.pb.weights.input, weights, gammas)
self.fill_layer(nested_getattr(pb_weights, pb_gamma), gamma)


def print_pb_stats(obj, parent=None):
@@ -566,42 +364,15 @@ def print_pb_stats(obj, parent=None):

def main(argv):
net = Net()

if argv.input.endswith(".txt"):
print('Found .txt network')
net.parse_txt(argv.input)
net.print_stats()
if argv.output == None:
argv.output = argv.input.replace('.txt', '.pb.gz')
assert argv.output.endswith('.pb.gz')
print('Writing output to: {}'.format(argv.output))
net.save_proto(argv.output)
elif argv.input.endswith(".pb.gz"):
print('Found .pb.gz network')
net.parse_proto(argv.input)
net.print_stats()
if argv.output == None:
argv.output = argv.input.replace('.pb.gz', '.txt.gz')
print('Writing output to: {}'.format(argv.output))
assert argv.output.endswith('.txt.gz')
if argv.output.endswith(".pb.gz"):
net.save_proto(argv.output)
else:
net.save_txt(argv.output)
else:
print('Unable to detect the network format. '
'Filename should end in ".txt" or ".pb.gz"')
net.parse_proto(argv.input)
net.print_stats()


if __name__ == "__main__":
argparser = argparse.ArgumentParser(
description='Convert network textfile to proto.')
description='Print net stats')
argparser.add_argument('-i',
'--input',
type=str,
help='input network weight text file')
argparser.add_argument('-o',
'--output',
type=str,
help='output filepath without extension')
help='input network weight file')
main(argparser.parse_args())
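
For reference, the LINEAR16 weights encoding that the surviving fill_layer/denorm_layer pair implements can be reproduced standalone. The following is a minimal sketch using plain NumPy with no protobuf layer object; the names quantize16/dequantize16 are illustrative, not part of net.py, and the single-element special cases in net.py are omitted:

```python
import numpy as np

def quantize16(params):
    """Scale float32 weights into uint16, as fill_layer does (sketch, no protobuf)."""
    params = params.flatten().astype(np.float32)
    min_val, max_val = float(np.min(params)), float(np.max(params))
    if max_val == min_val:
        scaled = params - min_val  # avoid division by zero when all values are equal
    else:
        scaled = (params - min_val) / (max_val - min_val)
    encoded = np.round(scaled * 0xffff).astype(np.uint16)
    return encoded.tobytes(), min_val, max_val

def dequantize16(raw, min_val, max_val):
    """Recover approximate float32 weights, as denorm_layer does."""
    params = np.frombuffer(raw, np.uint16).astype(np.float32) / 0xffff
    return params * (max_val - min_val) + min_val

w = np.random.randn(64).astype(np.float32)
raw, lo, hi = quantize16(w)
w2 = dequantize16(raw, lo, hi)
# Round-trip error is bounded by one quantization step.
assert np.allclose(w, w2, atol=(hi - lo) / 0xffff)
```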
tf/tfprocess.py: 4 changes (2 additions, 2 deletions)
@@ -493,7 +493,7 @@ def replace_weights(self, proto_filename, ignore_errors=False):
for weight in self.model.weights:
names.append(weight.name)

new_weights = self.net.get_weights_v2(names)
new_weights = self.net.get_weights(names)
for weight in self.model.weights:
if 'renorm' in weight.name:
# Renorm variables are not populated.
@@ -1007,7 +1007,7 @@ def save_leelaz_weights(self, filename):
numpy_weights = []
for weight in self.model.weights:
numpy_weights.append([weight.name, weight.numpy()])
self.net.fill_net_v2(numpy_weights)
self.net.fill_net(numpy_weights)
self.net.save_proto(filename)

def batch_norm(self, input, name, scale=False):
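
After this change the gzipped protobuf is the only interchange format, and the tf/tfprocess.py call sites above use the renamed methods. A rough usage sketch follows; the configured Net instance, the Keras model object, and the helper names are assumptions for illustration, not part of this diff:

```python
# Assumes `net` is a net.Net instance with its format fields already set
# (TFProcess configures these during init) and `model` is an existing Keras model.

def load_weights_for_model(net, model, proto_filename):
    """Mirrors replace_weights(): read a .pb.gz net, fetch tensors by TF variable name."""
    net.parse_proto(proto_filename)
    net.print_stats()
    return net.get_weights([w.name for w in model.weights])  # dict: tensor name -> numpy array

def save_model_as_proto(net, model, proto_filename):
    """Mirrors save_leelaz_weights(): push [name, array] pairs into the proto, then write it."""
    net.fill_net([[w.name, w.numpy()] for w in model.weights])
    net.save_proto(proto_filename)
```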