-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.cpp
More file actions
65 lines (48 loc) · 1.85 KB
/
main.cpp
File metadata and controls
65 lines (48 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include <iostream>
#include <vector>
#include <cassert>
#include <cstdlib>
#include "net.hpp"
#include "trainingDriver.hpp"
void showVectorValues(std::string label, std::vector<double> &v)
{
std::cout << label << " ";
for (unsigned i = 0; i < v.size(); ++i) {
std::cout << v[i] << " ";
}
std::cout << std::endl;
}
/// Training driver entry point: reads training samples from the file named by
/// the DATA_PATH environment variable, feeds each sample through the net,
/// back-propagates the target values, and reports the recent average error.
/// @return 0 on normal completion, nonzero on configuration errors.
int main() {
    // getenv returns nullptr when the variable is unset; passing nullptr into
    // TrainingDriver (which presumably builds a std::string from it) would be
    // undefined behavior, so fail fast with a clear message instead.
    const char* env_p = std::getenv("DATA_PATH");
    if (env_p == nullptr) {
        std::cerr << "Error: DATA_PATH environment variable is not set." << std::endl;
        return 1;
    }
    TrainingDriver trainData(env_p);

    // Topology describes the net shape: number of neurons per layer,
    // e.g. { 3, 2, 1 } = 3 inputs, one hidden layer of 2, 1 output.
    std::vector<unsigned> topology;
    trainData.getTopology(topology);
    if (topology.empty()) {
        // Guard before topology[0]/topology.back() below — an empty topology
        // means the training file had no usable header.
        std::cerr << "Error: no topology found in training data." << std::endl;
        return 1;
    }
    Net myNet(topology);

    std::vector<double> input_values, target_values, result_values;
    int trainingPass = 0;
    while (!trainData.isEof()) {
        ++trainingPass;
        std::cout << std::endl << "Pass " << trainingPass;

        // Get new input data and feed it forward; a short read means the
        // file ended mid-sample, so stop training.
        if (trainData.getNextInputs(input_values) != topology[0]) {
            break;
        }
        showVectorValues(": Inputs:", input_values);
        myNet.feedForward(input_values);

        // Collect the net's actual output results:
        myNet.getResult(result_values);
        showVectorValues("Outputs:", result_values);

        // Train the net what the outputs should have been:
        trainData.getTargetOutputs(target_values);
        showVectorValues("Targets:", target_values);
        assert(target_values.size() == topology.back());
        myNet.backPropagate(target_values);

        // Report how well the training is working, averaged over recent samples:
        std::cout << "Net recent average error: "
                  << myNet.getRecentAverageError() << std::endl;
    }
    std::cout << std::endl << "Done" << std::endl;
    return 0;
}