77using HDF5CSharp ;
88using static Tensorflow . Binding ;
99using static Tensorflow . KerasApi ;
10+ using System . Linq ;
11+
1012namespace Tensorflow . Keras . Saving
1113{
1214 public class fdf5_format
@@ -45,13 +47,29 @@ public static void save_model_to_hdf5(long filepath = -1, Dictionary<string, obj
4547 {
4648
4749 }
48- public static void preprocess_weights_for_loading ( long filepath = - 1 , Dictionary < string , object > custom_objects = null , bool compile = false )
49- {
5050
51- }
52- public static void _convert_rnn_weights ( long filepath = - 1 , Dictionary < string , object > custom_objects = null , bool compile = false )
51+ /// <summary>
52+ /// Preprocess layer weights between different Keras formats.
53+ /// </summary>
54+ /// <param name="layer"></param>
55+ /// <param name="weights"></param>
56+ /// <param name="original_keras_version"></param>
57+ /// <param name="original_backend"></param>
58+ public static List < NDArray > preprocess_weights_for_loading ( ILayer layer , List < NDArray > weights , string original_keras_version = null , string original_backend = null )
5359 {
60+ // convert CuDNN layers
61+ return _convert_rnn_weights ( layer , weights ) ;
62+ }
5463
64+ /// <summary>
65+ /// Converts weights for RNN layers between native and CuDNN format.
66+ /// </summary>
67+ /// <param name="layer"></param>
68+ /// <param name="weights"></param>
69+ static List < NDArray > _convert_rnn_weights ( ILayer layer , List < NDArray > weights )
70+ {
71+ var target_class = layer . GetType ( ) . Name ;
72+ return weights ;
5573 }
5674 public static void save_optimizer_weights_to_hdf5_group ( long filepath = - 1 , Dictionary < string , object > custom_objects = null , bool compile = false )
5775 {
@@ -65,56 +83,79 @@ public static void save_weights_to_hdf5_group(long filepath = -1, Dictionary<str
6583 {
6684
6785 }
68- public static void load_weights_from_hdf5_group ( long f = - 1 , Model model = null )
86+ public static void load_weights_from_hdf5_group ( long f , List < ILayer > layers )
6987 {
70- string original_keras_version = "1 " ;
88+ string original_keras_version = "2.4.0 " ;
7189 string original_backend = null ;
7290 if ( Hdf5 . AttributeExists ( f , "keras_version" ) )
7391 {
74- ( bool success , string [ ] attr ) = Hdf5 . ReadStringAttributes ( f , "keras_version" , "" ) ;
92+ var ( success , attr ) = Hdf5 . ReadStringAttributes ( f , "keras_version" , "" ) ;
7593 if ( success )
76- {
77- original_keras_version = attr [ 0 ] ;
78- }
94+ original_keras_version = attr . First ( ) ;
95+ // keras version should be 2.5.0+
96+ var ver_major = int . Parse ( original_keras_version . Split ( '.' ) [ 0 ] ) ;
97+ var ver_minor = int . Parse ( original_keras_version . Split ( '.' ) [ 1 ] ) ;
98+ if ( ver_major < 2 || ( ver_major == 2 && ver_minor < 5 ) )
99+ throw new ValueError ( "keras version should be 2.5.0 or later." ) ;
79100 }
80101 if ( Hdf5 . AttributeExists ( f , "backend" ) )
81102 {
82- ( bool success , string [ ] attr ) = Hdf5 . ReadStringAttributes ( f , "backend" , "" ) ;
103+ var ( success , attr ) = Hdf5 . ReadStringAttributes ( f , "backend" , "" ) ;
83104 if ( success )
84- {
85- original_backend = attr [ 0 ] ;
86- }
105+ original_backend = attr . First ( ) ;
87106 }
88107 List < ILayer > filtered_layers = new List < ILayer > ( ) ;
89- List < Tensor > weights ;
90- foreach ( var layer in model . Layers )
108+ List < IVariableV1 > weights ;
109+ foreach ( var layer in layers )
91110 {
92111 weights = _legacy_weights ( layer ) ;
93- if ( weights . Count > 0 )
112+ if ( weights . Count > 0 )
94113 {
95114 filtered_layers . append ( layer ) ;
96115 }
97116 }
98- string [ ] layer_names = load_attributes_from_hdf5_group ( f , "layer_names" ) ;
99- List < NDArray > weight_values = new List < NDArray > ( ) ;
100- foreach ( var i in filtered_layers ) {
101- long g = H5G . open ( f , i . Name ) ;
102- string [ ] weight_names = null ;
103- if ( g != - 1 )
104- {
105- weight_names = load_attributes_from_hdf5_group ( g , "weight_names" ) ;
106- }
107- if ( weight_names != null )
117+ string [ ] layer_names = load_attributes_from_hdf5_group ( f , "layer_names" ) ;
118+ var filtered_layer_names = new List < string > ( ) ;
119+ foreach ( var name in layer_names )
120+ {
121+ long g = H5G . open ( f , name ) ;
122+ var weight_names = load_attributes_from_hdf5_group ( g , "weight_names" ) ;
123+ if ( weight_names . Count ( ) > 0 )
124+ filtered_layer_names . Add ( name ) ;
125+ H5G . close ( g ) ;
126+ }
127+ layer_names = filtered_layer_names . ToArray ( ) ;
128+ if ( layer_names . Length != filtered_layers . Count ( ) )
129+ throw new ValueError ( "You are trying to load a weight file " +
130+ $ "containing { layer_names } " +
131+ $ " layers into a model with { filtered_layers . Count } layers.") ;
132+
133+ var weight_value_tuples = new List < ( IVariableV1 , NDArray ) > ( ) ;
134+ foreach ( var ( k , name ) in enumerate ( layer_names ) )
135+ {
136+ var weight_values = new List < NDArray > ( ) ;
137+ long g = H5G . open ( f , name ) ;
138+ var weight_names = load_attributes_from_hdf5_group ( g , "weight_names" ) ;
139+ foreach ( var i_ in weight_names )
108140 {
109- foreach ( var i_ in weight_names ) {
110- ( bool success , Array result ) = Hdf5 . ReadDataset < float > ( g , i_ ) ;
111- //
141+ ( bool success , Array result ) = Hdf5 . ReadDataset < float > ( g , i_ ) ;
142+ if ( success )
112143 weight_values . Add ( np . array ( result ) ) ;
113- }
114144 }
115145 H5G . close ( g ) ;
146+ var layer = filtered_layers [ k ] ;
147+ var symbolic_weights = _legacy_weights ( layer ) ;
148+ preprocess_weights_for_loading ( layer , weight_values , original_keras_version , original_backend ) ;
149+ if ( weight_values . Count ( ) != symbolic_weights . Count ( ) )
150+ throw new ValueError ( $ "Layer #{ k } (named { layer . Name } " +
151+ "in the current model) was found to " +
152+ $ "correspond to layer { name } in the save file." +
153+ $ "However the new layer { layer . Name } expects " +
154+ $ "{ symbolic_weights . Count ( ) } weights, but the saved weights have " +
155+ $ "{ weight_values . Count ( ) } elements.") ;
156+ weight_value_tuples . AddRange ( zip ( symbolic_weights , weight_values ) ) ;
116157 }
117-
158+ keras . backend . batch_set_value ( weight_value_tuples ) ;
118159 }
119160 public static void toarrayf4 ( long filepath = - 1 , Dictionary < string , object > custom_objects = null , bool compile = false )
120161 {
@@ -128,15 +169,13 @@ public static void save_attributes_to_hdf5_group(long filepath = -1, Dictionary<
128169 {
129170
130171 }
131- public static string [ ] load_attributes_from_hdf5_group ( long f = - 1 , string name = "" )
172+ public static string [ ] load_attributes_from_hdf5_group ( long group , string name )
132173 {
133- if ( Hdf5 . AttributeExists ( f , name ) )
174+ if ( Hdf5 . AttributeExists ( group , name ) )
134175 {
135- ( bool success , string [ ] attr ) = Hdf5 . ReadStringAttributes ( f , name , "" ) ;
176+ var ( success , attr ) = Hdf5 . ReadStringAttributes ( group , name , "" ) ;
136177 if ( success )
137- {
138- return attr ;
139- }
178+ return attr . ToArray ( ) ;
140179 }
141180 return null ;
142181 }
@@ -145,33 +184,10 @@ public static void load_attributes_from_hdf5_group(long filepath = -1, Dictionar
145184
146185 }
147186
148- public static List < Tensor > _legacy_weights ( ILayer layer )
187+ public static List < IVariableV1 > _legacy_weights ( ILayer layer )
149188 {
150-
151- List < Tensor > weights = new List < Tensor > ( ) ;
152- if ( layer . trainable_weights . Count != 0 )
153- {
154- Tensor [ ] trainable_weights = Array . ConvertAll < IVariableV1 , Tensor > ( layer . trainable_weights . ToArray ( ) , s => s . AsTensor ( ) ) ;
155- Tensor [ ] non_trainable_weights = null ;
156- if ( layer . non_trainable_weights . Count != 0 )
157- {
158- non_trainable_weights = Array . ConvertAll < IVariableV1 , Tensor > ( layer . non_trainable_weights . ToArray ( ) , s => s . AsTensor ( ) ) ;
159- }
160- foreach ( var i in trainable_weights ) {
161- if ( non_trainable_weights != null )
162- {
163- foreach ( var i_ in non_trainable_weights )
164- {
165- weights . Add ( i + i_ ) ;
166- }
167- }
168- else {
169- weights . Add ( i ) ;
170- } ;
171-
172-
173- }
174- }
189+ var weights = layer . trainable_weights . Select ( x => x ) . ToList ( ) ;
190+ weights . AddRange ( layer . non_trainable_weights ) ;
175191 return weights ;
176192 }
177193 }
0 commit comments