Skip to content

Commit c04216f

Browse files
dataangelOceania2018
authored andcommitted
update:Keras
1 parent bf4f65b commit c04216f

File tree

6 files changed

+258
-5
lines changed

6 files changed

+258
-5
lines changed

src/TensorFlowNET.Core/Keras/Layers/ILayer.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ public interface ILayer
1313
List<INode> OutboundNodes { get; }
1414
Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false);
1515
List<IVariableV1> trainable_variables { get; }
16+
List<IVariableV1> trainable_weights { get; }
17+
List<IVariableV1> non_trainable_weights { get; }
1618
TensorShape output_shape { get; }
1719
int count_params();
1820
LayerArgs get_config();

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ public abstract class RnnCell : ILayer
6767
public bool Trainable => throw new NotImplementedException();
6868

6969
public List<IVariableV1> trainable_variables => throw new NotImplementedException();
70+
public List<IVariableV1> trainable_weights => throw new NotImplementedException();
71+
public List<IVariableV1> non_trainable_weights => throw new NotImplementedException();
7072

7173
public TensorShape output_shape => throw new NotImplementedException();
7274

src/TensorFlowNET.Keras/Engine/Layer.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,21 @@ public int count_params()
239239
return layer_utils.count_params(this, weights);
240240
return 0;
241241
}
242+
List<IVariableV1> ILayer.trainable_weights
243+
{
244+
get
245+
{
246+
return trainable_weights;
247+
}
248+
}
249+
250+
List<IVariableV1> ILayer.non_trainable_weights
251+
{
252+
get
253+
{
254+
return non_trainable_weights;
255+
}
256+
}
242257

243258
public List<IVariableV1> weights
244259
{
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using HDF.PInvoke;
5+
using HDF5CSharp;
6+
using Tensorflow.Keras.Saving;
7+
8+
namespace Tensorflow.Keras.Engine
9+
{
10+
public partial class Model
11+
{
12+
private long fileId = -1;
13+
private long f = -1;
14+
public void load_weights(string filepath ="",bool by_name= false, bool skip_mismatch=false, object options = null)
15+
{
16+
long root = Hdf5.OpenFile(filepath, true);
17+
18+
long fileId = root;
19+
//try
20+
//{
21+
22+
bool msuccess = Hdf5.GroupExists(fileId, "model_weights");
23+
bool lsuccess = Hdf5.GroupExists(fileId, "layer_names");
24+
25+
if (!lsuccess && msuccess)
26+
{
27+
f = H5G.open(fileId, "model_weights");
28+
29+
}
30+
if (by_name)
31+
{
32+
//fdf5_format.load_weights_from_hdf5_group_by_name();
33+
}
34+
else
35+
{
36+
fdf5_format.load_weights_from_hdf5_group(f, this);
37+
}
38+
H5G.close(f);
39+
//}
40+
//catch (Exception ex)
41+
//{
42+
// if (fileId != -1)
43+
// {
44+
// Hdf5.CloseFile(fileId);
45+
// }
46+
// if (f != -1)
47+
// {
48+
// H5G.close(f);
49+
// }
50+
// throw new Exception(ex.ToString());
51+
//}
52+
}
53+
54+
}
55+
}
56+
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using HDF.PInvoke;
5+
using NumSharp;
6+
using Tensorflow.Keras.Engine;
7+
using HDF5CSharp;
8+
using static Tensorflow.Binding;
9+
using static Tensorflow.KerasApi;
10+
namespace Tensorflow.Keras.Saving
11+
{
12+
public class fdf5_format
13+
{
14+
15+
public static void load_model_from_hdf5(string filepath = "", Dictionary<string, object> custom_objects = null, bool compile = false)
16+
{
17+
long root = Hdf5.OpenFile(filepath,true);
18+
load_model_from_hdf5(root, custom_objects, compile);
19+
}
20+
public static void load_model_from_hdf5(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
21+
{
22+
//long fileId = filepath;
23+
//try
24+
//{
25+
// groupId = H5G.open(fileId, "/");
26+
// (bool success, string[] attrId) = Hdf5.ReadStringAttributes(groupId, "model_config", "");
27+
// H5G.close(groupId);
28+
// if (success == true) {
29+
// Console.WriteLine(attrId[0]);
30+
// }
31+
//}
32+
//catch (Exception ex)
33+
//{
34+
// if (filepath != -1) {
35+
// Hdf5.CloseFile(filepath);
36+
// }
37+
// if (groupId != -1) {
38+
// H5G.close(groupId);
39+
// }
40+
// throw new Exception(ex.ToString());
41+
//}
42+
43+
}
44+
public static void save_model_to_hdf5(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
45+
{
46+
47+
}
48+
public static void preprocess_weights_for_loading(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
49+
{
50+
51+
}
52+
public static void _convert_rnn_weights(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
53+
{
54+
55+
}
56+
public static void save_optimizer_weights_to_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
57+
{
58+
59+
}
60+
public static void load_optimizer_weights_from_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
61+
{
62+
63+
}
64+
public static void save_weights_to_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
65+
{
66+
67+
}
68+
public static void load_weights_from_hdf5_group(long f=-1,Model model=null)
69+
{
70+
string original_keras_version = "1";
71+
string original_backend = null;
72+
if (Hdf5.AttributeExists(f, "keras_version"))
73+
{
74+
(bool success, string[] attr) = Hdf5.ReadStringAttributes(f, "keras_version", "");
75+
if (success)
76+
{
77+
original_keras_version = attr[0];
78+
}
79+
}
80+
if (Hdf5.AttributeExists(f, "backend"))
81+
{
82+
(bool success, string[] attr) = Hdf5.ReadStringAttributes(f, "backend", "");
83+
if (success)
84+
{
85+
original_backend = attr[0];
86+
}
87+
}
88+
List<ILayer> filtered_layers = new List<ILayer>();
89+
List<Tensor> weights;
90+
foreach (var layer in model.Layers)
91+
{
92+
weights = _legacy_weights(layer);
93+
if (weights.Count>0)
94+
{
95+
filtered_layers.append(layer);
96+
}
97+
}
98+
string[] layer_names = load_attributes_from_hdf5_group(f,"layer_names");
99+
List<NDArray> weight_values=new List<NDArray>();
100+
foreach (var i in filtered_layers) {
101+
long g = H5G.open(f, i.Name);
102+
string[] weight_names = null;
103+
if (g != -1)
104+
{
105+
weight_names = load_attributes_from_hdf5_group(g, "weight_names");
106+
}
107+
if (weight_names != null)
108+
{
109+
foreach (var i_ in weight_names) {
110+
(bool success, Array result) = Hdf5.ReadDataset<float>(g, i_);
111+
//
112+
weight_values.Add(np.array(result));
113+
}
114+
}
115+
H5G.close(g);
116+
}
117+
118+
}
119+
public static void toarrayf4(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
120+
{
121+
122+
}
123+
public static void load_weights_from_hdf5_group_by_name(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
124+
{
125+
126+
}
127+
public static void save_attributes_to_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
128+
{
129+
130+
}
131+
public static string[] load_attributes_from_hdf5_group(long f = -1, string name = "")
132+
{
133+
if (Hdf5.AttributeExists(f, name))
134+
{
135+
(bool success, string[] attr) = Hdf5.ReadStringAttributes(f, name, "");
136+
if (success)
137+
{
138+
return attr;
139+
}
140+
}
141+
return null;
142+
}
143+
public static void load_attributes_from_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
144+
{
145+
146+
}
147+
148+
public static List<Tensor> _legacy_weights(ILayer layer)
149+
{
150+
151+
List<Tensor> weights= new List<Tensor>();
152+
if (layer.trainable_weights.Count != 0)
153+
{
154+
Tensor[] trainable_weights = Array.ConvertAll<IVariableV1, Tensor>(layer.trainable_weights.ToArray(), s => s.AsTensor());
155+
Tensor[] non_trainable_weights =null;
156+
if (layer.non_trainable_weights.Count != 0)
157+
{
158+
non_trainable_weights = Array.ConvertAll<IVariableV1, Tensor>(layer.non_trainable_weights.ToArray(), s => s.AsTensor());
159+
}
160+
foreach (var i in trainable_weights) {
161+
if (non_trainable_weights != null)
162+
{
163+
foreach (var i_ in non_trainable_weights)
164+
{
165+
weights.Add(i + i_);
166+
}
167+
}
168+
else {
169+
weights.Add(i);
170+
};
171+
172+
173+
}
174+
}
175+
return weights;
176+
}
177+
}
178+
}
179+

src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent &amp; simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear &amp; actionable error messages.</Description>
2727
<Company>SciSharp STACK</Company>
2828
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
29-
<PackageTags>tensorflow, keras, deep learning, machine learning, scisharp</PackageTags>
29+
<PackageTags>tensorflow, keras, deep learning, machine learning</PackageTags>
3030
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
3131
<RepositoryType>Git</RepositoryType>
3232
<SignAssembly>true</SignAssembly>
@@ -46,15 +46,14 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
4646
</PropertyGroup>
4747

4848
<ItemGroup>
49+
<PackageReference Include="HDF.PInvoke.1.10" Version="1.10.500" />
4950
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
5051
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
52+
<PackageReference Include="NumSharp.Lite" Version="0.1.10" />
53+
<PackageReference Include="SciSharp.Keras.HDF5" Version="1.1.10.500" />
5154
<PackageReference Include="SharpZipLib" Version="1.3.1" />
5255
</ItemGroup>
5356

54-
<ItemGroup>
55-
<Folder Include="Saving\" />
56-
</ItemGroup>
57-
5857
<ItemGroup>
5958
<ProjectReference Include="..\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
6059
</ItemGroup>

0 commit comments

Comments
 (0)