Skip to content

Commit 34a5b32

Browse files
committed
Consolidate dataset map args.
1 parent 2cc6295 commit 34a5b32

File tree

7 files changed

+24
-27
lines changed

7 files changed

+24
-27
lines changed

src/TensorFlowNET.Core/Data/DatasetV2.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public IDatasetV2 map(Func<Tensor, Tensor> map_func,
6060
preserve_cardinality: preserve_cardinality,
6161
use_legacy_function: use_legacy_function);
6262

63-
public IDatasetV2 map(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func, int num_parallel_calls = -1)
63+
public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls = -1)
6464
=> new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);
6565

6666
public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func)

src/TensorFlowNET.Core/Data/IDatasetV2.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ IDatasetV2 map(Func<Tensor, Tensor> map_func,
6060
bool preserve_cardinality = true,
6161
bool use_legacy_function = false);
6262

63-
IDatasetV2 map(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func,
63+
IDatasetV2 map(Func<Tensors, Tensors> map_func,
6464
int num_parallel_calls = -1);
6565

6666
IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func);

src/TensorFlowNET.Core/Data/ParallelMapDataset.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace Tensorflow
99
public class ParallelMapDataset : UnaryDataset
1010
{
1111
public ParallelMapDataset(IDatasetV2 input_dataset,
12-
Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func,
12+
Func<Tensors, Tensors> map_func,
1313
int num_parallel_calls = -1,
1414
bool use_inter_op_parallelism = true,
1515
bool preserve_cardinality = false,

src/TensorFlowNET.Core/Data/TensorDataset.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ public class TensorDataset : DatasetSource
1212
public TensorDataset(Tensors elements)
1313
{
1414
_tensors = elements;
15-
var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray();
16-
structure = batched_spec.Select(x => x._unbatch()).ToArray();
15+
structure = _tensors.Select(x => x.ToTensorSpec()).ToArray();
1716

1817
variant_tensor = ops.tensor_dataset(_tensors, output_shapes);
1918
}

src/TensorFlowNET.Core/Functions/ConcreteFunction.cs

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype)
8383
graph.Exit();
8484
}
8585

86-
public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func,
86+
public ConcreteFunction(Func<Tensors, Tensors> func,
8787
TF_DataType[] dtypes, TensorShape[] shapes)
8888
{
8989
string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";
@@ -92,19 +92,14 @@ public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func,
9292
using var graph = new FuncGraph(func_name);
9393
graph.as_default();
9494

95-
var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args");
96-
var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args");
97-
var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args");
98-
var outputs = func(input1, (input2, input3));
99-
100-
Outputs = new[] { outputs.Item1, outputs.Item2 };
101-
OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() };
95+
var inputs = new Tensors();
96+
foreach(var (i, dtype) in enumerate(dtypes))
97+
inputs.Add(tf.placeholder(dtypes[i], shape: shapes[i], name: "args"));
98+
Outputs = func(inputs);
99+
OutputStructure = Outputs.Select(x => x.ToTensorSpec()).ToArray();
102100

103101
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
104-
_handle = graph.ToGraph(opers,
105-
new[] { input1, input2, input3 },
106-
new[] { outputs.Item1, outputs.Item2 },
107-
null);
102+
_handle = graph.ToGraph(opers, inputs, Outputs, null);
108103
graph.Exit();
109104
}
110105

src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Linq;
23
using Tensorflow.Keras.ArgsDefinition;
34
using static Tensorflow.Binding;
45

@@ -21,20 +22,20 @@ public TensorLikeDataAdapter(DataAdapterArgs args)
2122
num_samples = args.X.shape[0];
2223
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
2324
_batch_size = batch_size;
24-
_size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0f)));
25+
_size = Convert.ToInt32(Math.Floor(num_samples / (batch_size + 0f)));
2526
num_full_batches = num_samples / batch_size;
2627
var _partial_batch_size = num_samples % batch_size;
2728

2829
var indices_dataset = tf.data.Dataset.range(1);
2930
indices_dataset = indices_dataset.repeat(args.Epochs);
3031
indices_dataset = indices_dataset.map(permutation).prefetch(1);
3132
indices_dataset = indices_dataset.flat_map(slice_batch_indices);
32-
var elements = new Tensors();
33+
var inputs = new Tensors();
3334
if (args.X != null)
34-
elements.Add(args.X);
35+
inputs.Add(args.X);
3536
if (args.Y != null)
36-
elements.Add(args.Y);
37-
dataset = slice_inputs(indices_dataset, elements);
37+
inputs.Add(args.Y);
38+
dataset = slice_inputs(indices_dataset, inputs);
3839
}
3940

4041
Tensor permutation(Tensor tensor)
@@ -64,11 +65,13 @@ IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensors elements)
6465
var dataset2 = tf.data.Dataset.from_tensor(elements).repeat();
6566
var dataset = tf.data.Dataset.zip(indices_dataset, dataset2);
6667

67-
dataset = dataset.map((batch, data) =>
68+
dataset = dataset.map(inputs =>
6869
{
69-
var x = gen_array_ops.gather_v2(data.Item1, batch, 0);
70-
var y = gen_array_ops.gather_v2(data.Item2, batch, 0);
71-
return (x, y);
70+
var indices = inputs[0];
71+
var results = inputs.Skip(1)
72+
.Select(x => gen_array_ops.gather_v2(x, indices, 0))
73+
.ToArray();
74+
return new Tensors(results);
7275
});
7376

7477
dataset = dataset.with_options(new DatasetOptions { });

src/TensorFlowNET.Keras/Engine/Model.Fit.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ void FitInternal(int epochs)
9797
// callbacks.on_train_batch_begin(step)
9898
var results = step_function(iterator);
9999
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));
100-
Console.WriteLine($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]");
100+
Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}");
101101
}
102102
}
103103
}

0 commit comments

Comments
 (0)