Skip to content

Commit 415445f

Browse files
AangbaeckEsther2013
authored andcommitted
Fixed Keras c# example so it works
1 parent 061ee22 commit 415445f

File tree

1 file changed

+12
-18
lines changed

1 file changed

+12
-18
lines changed

README.md

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -112,46 +112,40 @@ Run this example in [Jupyter Notebook](https://github.com/SciSharp/SciSharpCube)
112112
Toy version of `ResNet` in `Keras` functional API:
113113

114114
```csharp
115+
var layers = new LayersApi();
115116
// input layer
116117
var inputs = keras.Input(shape: (32, 32, 3), name: "img");
117-
118118
// convolutional layer
119119
var x = layers.Conv2D(32, 3, activation: "relu").Apply(inputs);
120120
x = layers.Conv2D(64, 3, activation: "relu").Apply(x);
121121
var block_1_output = layers.MaxPooling2D(3).Apply(x);
122-
123122
x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_1_output);
124123
x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x);
125-
var block_2_output = layers.add(x, block_1_output);
126-
124+
var block_2_output = layers.Add().Apply(new Tensors(x, block_1_output));
127125
x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_2_output);
128126
x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x);
129-
var block_3_output = layers.add(x, block_2_output);
130-
127+
var block_3_output = layers.Add().Apply(new Tensors(x, block_2_output));
131128
x = layers.Conv2D(64, 3, activation: "relu").Apply(block_3_output);
132129
x = layers.GlobalAveragePooling2D().Apply(x);
133130
x = layers.Dense(256, activation: "relu").Apply(x);
134131
x = layers.Dropout(0.5f).Apply(x);
135-
136132
// output layer
137133
var outputs = layers.Dense(10).Apply(x);
138-
139134
// build keras model
140-
model = keras.Model(inputs, outputs, name: "toy_resnet");
135+
var model = keras.Model(inputs, outputs, name: "toy_resnet");
141136
model.summary();
142-
143137
// compile keras model in tensorflow static graph
144138
model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
145-
loss: keras.losses.CategoricalCrossentropy(from_logits: true),
146-
metrics: new[] { "acc" });
147-
139+
loss: keras.losses.CategoricalCrossentropy(from_logits: true),
140+
metrics: new[] { "acc" });
148141
// prepare dataset
149142
var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
150-
143+
x_train = x_train / 255.0f;
144+
y_train = np_utils.to_categorical(y_train, 10);
151145
// training
152-
model.fit(x_train[new Slice(0, 1000)], y_train[new Slice(0, 1000)],
153-
batch_size: 64,
154-
epochs: 10,
146+
model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)],
147+
batch_size: 64,
148+
epochs: 10,
155149
validation_split: 0.2f);
156150
```
157151

@@ -260,4 +254,4 @@ WeChat Sponsor 微信打赏:
260254

261255
TensorFlow.NET is a part of [SciSharp STACK](https://scisharp.github.io/SciSharp/)
262256
<br>
263-
<a href="http://scisharpstack.org"><img src="https://github.com/SciSharp/SciSharp/blob/master/art/scisharp-stack.png" width="391" height="100" /></a>
257+
<a href="http://scisharpstack.org"><img src="https://github.com/SciSharp/SciSharp/blob/master/art/scisharp-stack.png" width="391" height="100" /></a>

0 commit comments

Comments
 (0)