
Commit 02ce65b

loading of hdf5 files from keras #554
1 parent c04216f commit 02ce65b

File tree: 5 files changed, +119 / -103 lines

src/TensorFlowNET.Core/ops.cs

Lines changed: 8 additions & 0 deletions
@@ -532,5 +532,13 @@ public static string get_name_scope()
             var g = get_default_graph();
             return g.get_name_scope();
         }
+
+        public static bool executing_eagerly_outside_functions()
+        {
+            if (tf.Context.executing_eagerly())
+                return true;
+            else
+                throw new NotImplementedException("");
+        }
     }
 }
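The new ops.executing_eagerly_outside_functions() is the gate the Keras backend below checks before assigning values directly to variables. A minimal hedged sketch of what it reports, assuming an eager context is enabled (the graph branch still throws):

    tf.enable_eager_execution();                            // assumption: eager mode is on
    var eager = ops.executing_eagerly_outside_functions();  // true here; otherwise NotImplementedException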

src/TensorFlowNET.Keras/BackendImpl.cs

Lines changed: 13 additions & 0 deletions
@@ -142,6 +142,19 @@ public void set_learning_phase(bool value)
             _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0);
         }
 
+        public void batch_set_value(List<(IVariableV1, NDArray)> tuples)
+        {
+            if (ops.executing_eagerly_outside_functions())
+            {
+                foreach (var (x, value) in tuples)
+                    x.assign(value);
+            }
+            else
+            {
+                throw new NotImplementedException("");
+            }
+        }
+
         /// <summary>
         /// Pads the 2nd and 3rd dimensions of a 4D tensor.
         /// </summary>
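batch_set_value is what the HDF5 loader further down passes its (variable, value) pairs to. A hedged sketch of a direct call, where kernel is assumed to be an IVariableV1 taken from an already-built layer (e.g. layer.trainable_weights[0]) and value an NDArray of the same shape:

    // illustrative only: one (variable, value) pair pushed through the eager path,
    // which ends up calling kernel.assign(value)
    var tuples = new List<(IVariableV1, NDArray)> { (kernel, value) };
    keras.backend.batch_set_value(tuples);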

src/TensorFlowNET.Keras/Engine/Model.Training.cs

Lines changed: 18 additions & 38 deletions
@@ -9,48 +9,28 @@ namespace Tensorflow.Keras.Engine
 {
     public partial class Model
     {
-        private long fileId = -1;
-        private long f = -1;
-        public void load_weights(string filepath ="",bool by_name= false, bool skip_mismatch=false, object options = null)
+        public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null)
         {
-            long root = Hdf5.OpenFile(filepath, true);
+            long fileId = Hdf5.OpenFile(filepath, true);
 
-            long fileId = root;
-            //try
-            //{
+            bool msuccess = Hdf5.GroupExists(fileId, "model_weights");
+            bool lsuccess = Hdf5.GroupExists(fileId, "layer_names");
 
-            bool msuccess = Hdf5.GroupExists(fileId, "model_weights");
-            bool lsuccess = Hdf5.GroupExists(fileId, "layer_names");
-
-            if (!lsuccess && msuccess)
-            {
-                f = H5G.open(fileId, "model_weights");
-
-            }
-            if (by_name)
-            {
-                //fdf5_format.load_weights_from_hdf5_group_by_name();
-            }
-            else
-            {
-                fdf5_format.load_weights_from_hdf5_group(f, this);
-            }
-            H5G.close(f);
-            //}
-            //catch (Exception ex)
-            //{
-            //    if (fileId != -1)
-            //    {
-            //        Hdf5.CloseFile(fileId);
-            //    }
-            //    if (f != -1)
-            //    {
-            //        H5G.close(f);
-            //    }
-            //    throw new Exception(ex.ToString());
-            //}
+            if (!lsuccess && msuccess)
+            {
+                fileId = H5G.open(fileId, "model_weights");
+            }
+            if (by_name)
+            {
+                //fdf5_format.load_weights_from_hdf5_group_by_name();
+                throw new NotImplementedException("");
+            }
+            else
+            {
+                fdf5_format.load_weights_from_hdf5_group(fileId, Layers);
+            }
+            H5G.close(fileId);
         }
-
     }
 }
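A hedged usage sketch of the reworked entry point; the file name and model-building helper are illustrative, and the model is assumed to be built with the same architecture that produced the HDF5 file (the by_name path still throws NotImplementedException):

    // illustrative only: the .h5 file is expected to expose either a "model_weights"
    // group or a root-level "layer_names" attribute, as written by Keras.
    Model model = BuildMyModel();        // hypothetical helper that builds the layers
    model.load_weights("weights.h5");    // by_name defaults to false, the only supported path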

src/TensorFlowNET.Keras/Saving/fdf5_format.cs

Lines changed: 80 additions & 64 deletions
@@ -7,6 +7,8 @@
 using HDF5CSharp;
 using static Tensorflow.Binding;
 using static Tensorflow.KerasApi;
+using System.Linq;
+
 namespace Tensorflow.Keras.Saving
 {
     public class fdf5_format
@@ -45,13 +47,29 @@ public static void save_model_to_hdf5(long filepath = -1, Dictionary<string, obj
         {
 
         }
-        public static void preprocess_weights_for_loading(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
-        {
 
-        }
-        public static void _convert_rnn_weights(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
+        /// <summary>
+        /// Preprocess layer weights between different Keras formats.
+        /// </summary>
+        /// <param name="layer"></param>
+        /// <param name="weights"></param>
+        /// <param name="original_keras_version"></param>
+        /// <param name="original_backend"></param>
+        public static List<NDArray> preprocess_weights_for_loading(ILayer layer, List<NDArray> weights, string original_keras_version = null, string original_backend = null)
         {
+            // convert CuDNN layers
+            return _convert_rnn_weights(layer, weights);
+        }
 
+        /// <summary>
+        /// Converts weights for RNN layers between native and CuDNN format.
+        /// </summary>
+        /// <param name="layer"></param>
+        /// <param name="weights"></param>
+        static List<NDArray> _convert_rnn_weights(ILayer layer, List<NDArray> weights)
+        {
+            var target_class = layer.GetType().Name;
+            return weights;
         }
         public static void save_optimizer_weights_to_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
         {
@@ -65,56 +83,79 @@ public static void save_weights_to_hdf5_group(long filepath = -1, Dictionary<str
         {
 
         }
-        public static void load_weights_from_hdf5_group(long f=-1,Model model=null)
+        public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
         {
-            string original_keras_version = "1";
+            string original_keras_version = "2.4.0";
             string original_backend = null;
             if (Hdf5.AttributeExists(f, "keras_version"))
             {
-                (bool success, string[] attr) = Hdf5.ReadStringAttributes(f, "keras_version", "");
+                var (success, attr) = Hdf5.ReadStringAttributes(f, "keras_version", "");
                 if (success)
-                {
-                    original_keras_version = attr[0];
-                }
+                    original_keras_version = attr.First();
+                // keras version should be 2.5.0+
+                var ver_major = int.Parse(original_keras_version.Split('.')[0]);
+                var ver_minor = int.Parse(original_keras_version.Split('.')[1]);
+                if (ver_major < 2 || (ver_major == 2 && ver_minor < 5))
+                    throw new ValueError("keras version should be 2.5.0 or later.");
             }
             if (Hdf5.AttributeExists(f, "backend"))
             {
-                (bool success, string[] attr) = Hdf5.ReadStringAttributes(f, "backend", "");
+                var (success, attr) = Hdf5.ReadStringAttributes(f, "backend", "");
                 if (success)
-                {
-                    original_backend = attr[0];
-                }
+                    original_backend = attr.First();
             }
             List<ILayer> filtered_layers = new List<ILayer>();
-            List<Tensor> weights;
-            foreach (var layer in model.Layers)
+            List<IVariableV1> weights;
+            foreach (var layer in layers)
             {
                 weights = _legacy_weights(layer);
-                if (weights.Count>0)
+                if (weights.Count > 0)
                 {
                     filtered_layers.append(layer);
                 }
             }
-            string[] layer_names = load_attributes_from_hdf5_group(f,"layer_names");
-            List<NDArray> weight_values=new List<NDArray>();
-            foreach (var i in filtered_layers) {
-                long g = H5G.open(f, i.Name);
-                string[] weight_names = null;
-                if (g != -1)
-                {
-                    weight_names = load_attributes_from_hdf5_group(g, "weight_names");
-                }
-                if (weight_names != null)
+            string[] layer_names = load_attributes_from_hdf5_group(f, "layer_names");
+            var filtered_layer_names = new List<string>();
+            foreach(var name in layer_names)
+            {
+                long g = H5G.open(f, name);
+                var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
+                if (weight_names.Count() > 0)
+                    filtered_layer_names.Add(name);
+                H5G.close(g);
+            }
+            layer_names = filtered_layer_names.ToArray();
+            if (layer_names.Length != filtered_layers.Count())
+                throw new ValueError("You are trying to load a weight file " +
+                    $"containing {layer_names}" +
+                    $" layers into a model with {filtered_layers.Count} layers.");
+
+            var weight_value_tuples = new List<(IVariableV1, NDArray)>();
+            foreach (var (k, name) in enumerate(layer_names))
+            {
+                var weight_values = new List<NDArray>();
+                long g = H5G.open(f, name);
+                var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
+                foreach (var i_ in weight_names)
                 {
-                    foreach (var i_ in weight_names) {
-                        (bool success, Array result) = Hdf5.ReadDataset<float>(g, i_);
-                        //
+                    (bool success, Array result) = Hdf5.ReadDataset<float>(g, i_);
+                    if (success)
                         weight_values.Add(np.array(result));
-                    }
                 }
                 H5G.close(g);
+                var layer = filtered_layers[k];
+                var symbolic_weights = _legacy_weights(layer);
+                preprocess_weights_for_loading(layer, weight_values, original_keras_version, original_backend);
+                if (weight_values.Count() != symbolic_weights.Count())
+                    throw new ValueError($"Layer #{k} (named {layer.Name}" +
+                        "in the current model) was found to " +
+                        $"correspond to layer {name} in the save file." +
+                        $"However the new layer {layer.Name} expects " +
+                        $"{symbolic_weights.Count()} weights, but the saved weights have " +
+                        $"{weight_values.Count()} elements.");
+                weight_value_tuples.AddRange(zip(symbolic_weights, weight_values));
             }
-
+            keras.backend.batch_set_value(weight_value_tuples);
         }
         public static void toarrayf4(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
         {
@@ -128,15 +169,13 @@ public static void save_attributes_to_hdf5_group(long filepath = -1, Dictionary<
         {
 
         }
-        public static string[] load_attributes_from_hdf5_group(long f = -1, string name = "")
+        public static string[] load_attributes_from_hdf5_group(long group, string name)
        {
-            if (Hdf5.AttributeExists(f, name))
+            if (Hdf5.AttributeExists(group, name))
             {
-                (bool success, string[] attr) = Hdf5.ReadStringAttributes(f, name, "");
+                var (success, attr) = Hdf5.ReadStringAttributes(group, name, "");
                 if (success)
-                {
-                    return attr;
-                }
+                    return attr.ToArray();
             }
             return null;
         }
@@ -145,33 +184,10 @@ public static void load_attributes_from_hdf5_group(long filepath = -1, Dictionar
 
         }
 
-        public static List<Tensor> _legacy_weights(ILayer layer)
+        public static List<IVariableV1> _legacy_weights(ILayer layer)
         {
-
-            List<Tensor> weights= new List<Tensor>();
-            if (layer.trainable_weights.Count != 0)
-            {
-                Tensor[] trainable_weights = Array.ConvertAll<IVariableV1, Tensor>(layer.trainable_weights.ToArray(), s => s.AsTensor());
-                Tensor[] non_trainable_weights =null;
-                if (layer.non_trainable_weights.Count != 0)
-                {
-                    non_trainable_weights = Array.ConvertAll<IVariableV1, Tensor>(layer.non_trainable_weights.ToArray(), s => s.AsTensor());
-                }
-                foreach (var i in trainable_weights) {
-                    if (non_trainable_weights != null)
-                    {
-                        foreach (var i_ in non_trainable_weights)
-                        {
-                            weights.Add(i + i_);
-                        }
-                    }
-                    else {
-                        weights.Add(i);
-                    };
-
-
-                }
-            }
+            var weights = layer.trainable_weights.Select(x => x).ToList();
+            weights.AddRange(layer.non_trainable_weights);
             return weights;
         }
     }
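For reference, a hedged sketch of the HDF5 layout load_weights_from_hdf5_group expects, read back with the same HDF5CSharp calls used above: a layer_names string attribute at the root (or under a model_weights group), one group per layer carrying a weight_names attribute, and one float dataset per weight name. The file name is illustrative:

    long fileId = Hdf5.OpenFile("weights.h5", true);
    long root = Hdf5.GroupExists(fileId, "model_weights")
        ? H5G.open(fileId, "model_weights")
        : fileId;

    var (_, layerNames) = Hdf5.ReadStringAttributes(root, "layer_names", "");
    foreach (var layerName in layerNames)
    {
        long g = H5G.open(root, layerName);
        var (_, weightNames) = Hdf5.ReadStringAttributes(g, "weight_names", "");
        foreach (var weightName in weightNames)
        {
            // each dataset is the raw float buffer the loader wraps with np.array(...)
            (bool ok, Array data) = Hdf5.ReadDataset<float>(g, weightName);
        }
        H5G.close(g);
    }
    Hdf5.CloseFile(fileId);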

src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

Lines changed: 0 additions & 1 deletion
@@ -46,7 +46,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
   </PropertyGroup>
 
   <ItemGroup>
-    <PackageReference Include="HDF.PInvoke.1.10" Version="1.10.500" />
     <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
     <PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
     <PackageReference Include="NumSharp.Lite" Version="0.1.10" />
