-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathProgram.cs
More file actions
86 lines (68 loc) · 2.84 KB
/
Program.cs
File metadata and controls
86 lines (68 loc) · 2.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
// Copyright © Benjamin Abt 2021, all rights reserved
using Microsoft.ML;
using Spectre.Console;
using System;
using System.Collections.Generic;
using System.IO;
using ToxicSentimentSample;
// Settings: the training corpus is read from "train.csv" in the current working directory.
string trainDataFile = Path.Combine(Environment.CurrentDirectory, "train.csv");

// Create the ML context used by every ML.NET data/training/model call below.
MLContext mlContext = new();

AnsiConsole.Render(new FigletText("ML.NET Toxic Sentiment Sample").LeftAligned().Color(Color.Red));
AnsiConsole.MarkupLine("=> create model context...");

// Read the training data. Fail early with a download hint when the file is missing
// or yields no rows, instead of continuing into the pipeline with nothing to train on.
AnsiConsole.MarkupLine("=> loading train data...");
if (!File.Exists(trainDataFile))
{
    AnsiConsole.MarkupLine("[red]!!! No train data found. Have you downloaded the train.csv?[/]");
    return;
}

List<TrainInput> trainData = DataReader.ReadTrainData(trainDataFile);
if (trainData.Count == 0)
{
    AnsiConsole.MarkupLine("[red]!!! No train data found. Have you downloaded the train.csv?[/]");
    return;
}

// Convert the rows to an IDataView and partition them: 80% training, 20% testing.
DataOperationsCatalog.TrainTestData dataView;
{
    AnsiConsole.MarkupLine("=> parsing train data...");
    // When loading from an enumerable, the property information (and attribute settings)
    // of TrainInput are used for headers and features.
    IDataView data = mlContext.Data.LoadFromEnumerable(trainData);

    AnsiConsole.MarkupLine("=> create model partitions...");
    // No seed is passed, so the split is not pinned here.
    // NOTE(review): pass `seed:` if reproducible splits are needed.
    dataView = mlContext.Data.TrainTestSplit(data, testFraction: 0.2);
}

// Set up the binary classification pipeline (defined in BinaryClassification).
AnsiConsole.MarkupLine("=> create binary classification pipeline...");
var textPipeline = BinaryClassification.CreatePipeline(mlContext);

// Train the model on the training partition only.
AnsiConsole.MarkupLine("=> model training...");
ITransformer trainedModel = textPipeline.Fit(dataView.TrainSet);

// Validate the model against the held-out test partition and report the metrics.
AnsiConsole.MarkupLine("=> model validation...");
IDataView predictions = trainedModel.Transform(dataView.TestSet);
var metrics = BinaryClassification.Validate(mlContext, predictions);
BinaryClassification.PrintMetrics(metrics);
// Interactive prediction loop: classify each line the user types; a blank line ends the session.
// The prediction engine is disposable, so scope it to the rest of the program.
using PredictionEngine<TextInput, TextIntentBinaryPrediction> predictionEngine =
    mlContext.Model.CreatePredictionEngine<TextInput, TextIntentBinaryPrediction>(trainedModel);

// Spectre's TextPrompt re-prompts on blank input unless AllowEmpty is set.
// Without it, the exit check below could never be reached.
TextPrompt<string> textPrompt = new TextPrompt<string>("Text Input: ").AllowEmpty();

while (true)
{
    AnsiConsole.MarkupLine("[blue]--------------------------------------------------[/]");
    string predictInput = AnsiConsole.Prompt(textPrompt);
    if (string.IsNullOrWhiteSpace(predictInput))
    {
        break;
    }

    TextInput input = new() { Text = predictInput };
    TextIntentBinaryPrediction prediction = predictionEngine.Predict(input);

    AnsiConsole.MarkupLine(Environment.NewLine + "[green]Predict Results:[/]");

    Table prt = new();
    prt.AddColumn(new TableColumn("Name").RightAligned());
    prt.AddColumn("Value");

    // Table.AddRow parses each cell as Spectre markup. Escape the raw user text so
    // input containing '[' or ']' is shown literally instead of throwing or being styled.
    prt.AddRow("Text", Markup.Escape(predictInput))
        .AddRow("Prediction", prediction.Prediction ? "Toxic :-(" : "Friendly :-)")
        .AddRow("Probability", $"{prediction.Probability:P2}")
        .AddRow("Score", $"{prediction.Score}");

    AnsiConsole.Render(prt);
}