Skip to content

Commit e661b82

Browse files
committed
Fix ones_like.
1 parent 901d574 commit e661b82

File tree

14 files changed

+96
-156
lines changed

14 files changed

+96
-156
lines changed

docs/RELEASE.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,25 @@
44

55
This release contains contributions from many people at SciSharp as well as the external contributors.
66

7+
**Release Date 02/06/2021**
8+
9+
### TensorFlow.Binding v0.33.0
10+
11+
* Improve memory usage
12+
* Fix minor bugs
13+
14+
### TensorFlow.Keras v0.4.0
15+
16+
* Add Subtract layer
17+
18+
* Add model.load_weights and model.save_weights
19+
20+
* Fix memory leak issue
21+
22+
* Support to build YOLOv3 object detection model
23+
24+
25+
726
**Release Date 01/09/2021**
827

928
### TensorFlow.Binding v0.32.0

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,9 @@ public Tensor stack(object values, int axis = 0, string name = "stack")
215215
public Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
216216
=> array_ops.ones_like(tensor, dtype: dtype, name: name, optimize: optimize);
217217

218+
public Tensor ones_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
219+
=> array_ops.ones_like(nd, dtype: dtype, name: name, optimize: optimize);
220+
218221
public Tensor one_hot(Tensor indices, int depth,
219222
Tensor on_value = null,
220223
Tensor off_value = null,
@@ -290,6 +293,9 @@ public Tensor[] unstack(Tensor value, int? num = null, int axis = 0, string name
290293
public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
291294
=> array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize);
292295

296+
public Tensor zeros_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
297+
=> array_ops.zeros_like(nd, dtype: dtype, name: name, optimize: optimize);
298+
293299
/// <summary>
294300
/// Stops gradient computation.
295301
/// </summary>

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ public static int len(object a)
137137
{
138138
switch (a)
139139
{
140+
case Tensors arr:
141+
return arr.Length;
140142
case Array arr:
141143
return arr.Length;
142144
case IList arr:

src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ namespace Tensorflow.Contexts
2828
/// </summary>
2929
public sealed partial class Context
3030
{
31+
// [DebuggerStepThrough]
3132
public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params object[] args)
3233
{
3334
if (tf.Context.has_graph_arg(args))

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -388,14 +388,12 @@ private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name
388388
if (dtype == TF_DataType.DtInvalid)
389389
dtype = tensor1.dtype;
390390
var ret = ones(ones_shape, dtype: dtype, name: name);
391-
ret.shape = tensor1.shape;
392391
return ret;
393392
});
394393
}
395394

396395
public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
397396
{
398-
dtype = dtype.as_base_dtype();
399397
return tf_with(ops.name_scope(name, "ones", new { shape }), scope =>
400398
{
401399
name = scope;
@@ -578,11 +576,10 @@ public static Tensor shape_internal(Tensor input, string name = null, bool optim
578576

579577
if (!tf.Context.executing_eagerly())
580578
{
581-
var input_tensor = ops.convert_to_tensor(input);
582-
var input_shape = input_tensor.TensorShape;
583-
if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined())
579+
var input_shape = input.TensorShape;
580+
if (optimize && input.NDims > -1 && input_shape.is_fully_defined())
584581
{
585-
var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype());
582+
var nd = np.array(input.shape).astype(out_type.as_numpy_dtype());
586583
return constant_op.constant(nd, name: name);
587584
}
588585
}

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ public static Tensor div_no_nan(Tensor x, Tensor y, string name = null)
124124
x, y).FirstOrDefault(),
125125
x, y);
126126

127+
public static Tensor mean(Tensor input, int axis, bool keep_dims = false, string name = null)
128+
=> mean(input, ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name);
129+
127130
/// <summary>
128131
/// Computes the mean of elements across dimensions of a tensor.
129132
/// Reduces `input` along the dimensions given in `axis`. Unless
@@ -137,23 +140,30 @@ public static Tensor div_no_nan(Tensor x, Tensor y, string name = null)
137140
/// <param name="keep_dims"> An optional `bool`. Defaults to `False`. If true, retain reduced dimensions with length 1.</param>
138141
/// <param name="name"> A name for the operation (optional).</param>
139142
/// <returns> A `Tensor`. Has the same type as `input`.</returns>
140-
public static Tensor mean<T1, T2>(T1 input, T2 axis, bool keep_dims = false, string name = null)
141-
{
142-
if (tf.Context.executing_eagerly())
143-
{
144-
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
143+
public static Tensor mean(Tensor input, Tensor axis, bool keep_dims = false, string name = null)
144+
=> tf.Context.RunInAutoMode2(
145+
() => tf.OpDefLib._apply_op_helper("Mean", name, new
146+
{
147+
input,
148+
reduction_indices = axis,
149+
keep_dims = keep_dims
150+
}).output,
151+
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
145152
"Mean", name,
146153
null,
147154
input, axis,
148-
"keep_dims", keep_dims);
149-
150-
return results[0];
151-
}
152-
153-
var _op = tf.OpDefLib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims });
154-
155-
return _op.output;
156-
}
155+
"keep_dims", keep_dims).FirstOrDefault(),
156+
(op) =>
157+
{
158+
var attrs = new object[]
159+
{
160+
"T", op.get_attr<TF_DataType>("T"),
161+
"Tidx", op.get_attr<TF_DataType>("Tidx"),
162+
"keep_dims", op.get_attr<bool>("keep_dims")
163+
};
164+
tf.Runner.RecordGradient("Mean", op.inputs, attrs, op.outputs);
165+
},
166+
new Tensors(input, axis));
157167

158168
public static Tensor mean(Tensor[] inputs, Tensor axis, bool keep_dims = false, string name = null)
159169
{
@@ -786,20 +796,21 @@ public static Tensor sqrt(Tensor x, string name = null)
786796
}
787797

788798
public static Tensor sub(Tensor x, Tensor y, string name = null)
789-
{
790-
if (tf.Context.executing_eagerly())
791-
{
792-
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
799+
=> tf.Context.RunInAutoMode2(
800+
() => tf.OpDefLib._apply_op_helper("Sub", name, new { x, y }).output,
801+
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
793802
"Sub", name,
794803
null,
795-
x, y);
796-
return results[0];
797-
}
798-
799-
var _op = tf.OpDefLib._apply_op_helper("Sub", name, args: new { x, y });
800-
801-
return _op.output;
802-
}
804+
x, y).FirstOrDefault(),
805+
(op) =>
806+
{
807+
var attrs = new object[]
808+
{
809+
"T", op.get_attr<TF_DataType>("T")
810+
};
811+
tf.Runner.RecordGradient("Sub", op.inputs, attrs, op.outputs);
812+
},
813+
new Tensors(x, y));
803814

804815
public static Tensor sub<Tx, Ty>(Tx x, Ty y, string name = null)
805816
{

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -327,31 +327,17 @@ public static Tensor real(Tensor input, string name = null)
327327
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
328328
{
329329
var r = _ReductionDims(input_tensor, axis);
330-
if (axis == null)
331-
{
332-
var m = gen_math_ops.mean(input_tensor, r, keepdims, name);
333-
return _may_reduce_to_scalar(keepdims, axis, m);
334-
}
335-
else
336-
{
337-
var m = gen_math_ops.mean(input_tensor, axis, keepdims, name);
338-
return _may_reduce_to_scalar(keepdims, axis, m);
339-
}
330+
var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis);
331+
var m = gen_math_ops.mean(input_tensor, axis_tensor, keepdims, name);
332+
return _may_reduce_to_scalar(keepdims, axis_tensor, m);
340333
}
341334

342335
public static Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null)
343336
{
344-
if (axis == null)
345-
{
346-
var r = _ReductionDims(input_tensors, axis);
347-
var m = gen_math_ops.mean(input_tensors, r, keepdims, name);
348-
return _may_reduce_to_scalar(keepdims, axis, m);
349-
}
350-
else
351-
{
352-
var m = gen_math_ops.mean(input_tensors, axis, keepdims, name);
353-
return _may_reduce_to_scalar(keepdims, axis, m);
354-
}
337+
var r = _ReductionDims(input_tensors, axis);
338+
var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis.Value);
339+
var m = gen_math_ops.mean(input_tensors, axis_tensor, keepdims, name);
340+
return _may_reduce_to_scalar(keepdims, axis, m);
355341
}
356342

357343
/// <summary>

src/TensorFlowNET.Core/Tensors/Tensor.String.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,17 @@ public unsafe byte[][] StringBytes()
9090
size *= s;
9191

9292
var buffer = new byte[size][];
93-
var src = c_api.TF_TensorData(_handle);
94-
src += (int)(size * 8);
93+
var data_start = c_api.TF_TensorData(_handle);
94+
data_start += (int)(size * sizeof(ulong));
9595
for (int i = 0; i < buffer.Length; i++)
9696
{
9797
IntPtr dst = IntPtr.Zero;
9898
ulong dstLen = 0;
99-
var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.Status.Handle);
99+
var read = c_api.TF_StringDecode((byte*)data_start, bytesize, (byte**)&dst, ref dstLen, tf.Status.Handle);
100100
tf.Status.Check(true);
101101
buffer[i] = new byte[(int)dstLen];
102102
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
103-
src += (int)read;
103+
data_start += (int)read;
104104
}
105105

106106
return buffer;

src/TensorFlowNET.Core/Tensors/Tensors.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,14 @@ public void Insert(int index, Tensor tensor)
6969
=> items.Insert(index, tensor);
7070

7171
IEnumerator IEnumerable.GetEnumerator()
72-
{
73-
throw new NotImplementedException();
74-
}
72+
=> GetEnumerator();
7573

7674
public static implicit operator Tensors(Tensor tensor)
7775
=> new Tensors(tensor);
7876

77+
public static implicit operator Tensors((Tensor, Tensor) tuple)
78+
=> new Tensors(tuple.Item1, tuple.Item2);
79+
7980
public static implicit operator Tensors(NDArray nd)
8081
=> new Tensors(nd);
8182

tensorflowlib/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\
5656

5757
1. Build static library
5858

59-
`bazel build --config=opt //tensorflow:tensorflow`
59+
`bazel build --output_base=C:/tmp/tfcompilation build --config=opt //tensorflow:tensorflow`
6060

6161
2. Build pip package
6262

0 commit comments

Comments
 (0)