-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathClassification.cs
More file actions
56 lines (47 loc) · 2.5 KB
/
Classification.cs
File metadata and controls
56 lines (47 loc) · 2.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// Copyright © Benjamin Abt 2021, all rights reserved
using System;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.FastTree;
using Spectre.Console;
namespace ToxicSentimentSample
{
public static class BinaryClassification
{
    // Name of the intermediate column that carries the featurized text into the trainer.
    private const string FeaturesColumnName = "Features";

    /// <summary>
    /// Builds the training pipeline for the toxic-comment classifier: the
    /// <c>TrainInput.Text</c> column is featurized into a "Features" column, which is
    /// then consumed by a FastTree binary trainer using <c>TrainInput.IsToxic</c> as the label.
    /// </summary>
    /// <param name="mlContext">The ML.NET context used to create the transforms and the trainer.</param>
    /// <returns>
    /// An estimator chain (not yet fitted) whose final transformer wraps a
    /// Platt-calibrated FastTree binary model.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="mlContext"/> is <see langword="null"/>.</exception>
    public static EstimatorChain<BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>>
        CreatePipeline(MLContext mlContext)
    {
        if (mlContext is null)
        {
            throw new ArgumentNullException(nameof(mlContext));
        }

        // Step 1: featurize the raw comment text into a numeric feature column.
        // Step 2: append a FastTree binary classifier trained on that column.
        return mlContext.Transforms.Text.FeaturizeText(
                outputColumnName: FeaturesColumnName,
                inputColumnName: nameof(TrainInput.Text))
            .Append(mlContext.BinaryClassification.Trainers.FastTree(
                labelColumnName: nameof(TrainInput.IsToxic),
                featureColumnName: FeaturesColumnName));
    }

    /// <summary>
    /// Evaluates scored data against the <c>IsToxic</c> label column.
    /// </summary>
    /// <param name="mlContext">The ML.NET context that performs the evaluation.</param>
    /// <param name="predictions">
    /// Scored data to evaluate. Presumably this is the output of the fitted pipeline applied
    /// to a held-out set (TODO confirm against callers).
    /// </param>
    /// <returns>The calibrated binary-classification metrics computed from <paramref name="predictions"/>.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="mlContext"/> or <paramref name="predictions"/> is <see langword="null"/>.
    /// </exception>
    public static CalibratedBinaryClassificationMetrics Validate(MLContext mlContext, IDataView predictions)
    {
        if (mlContext is null)
        {
            throw new ArgumentNullException(nameof(mlContext));
        }

        if (predictions is null)
        {
            throw new ArgumentNullException(nameof(predictions));
        }

        return mlContext.BinaryClassification.Evaluate(predictions, labelColumnName: nameof(TrainInput.IsToxic));
    }

    /// <summary>
    /// Writes a heading followed by a two-column (Name / Value) table of
    /// <paramref name="metrics"/> to the console.
    /// </summary>
    /// <remarks>
    /// Accuracy, the two AUC values, F1, and positive/negative precision and recall are all
    /// formatted as percentages with two decimals so the table reads consistently. Log loss
    /// and log-loss reduction are printed as plain numbers with up to two decimals.
    /// </remarks>
    /// <param name="metrics">The metrics to display.</param>
    /// <exception cref="ArgumentNullException"><paramref name="metrics"/> is <see langword="null"/>.</exception>
    public static void PrintMetrics(CalibratedBinaryClassificationMetrics metrics)
    {
        if (metrics is null)
        {
            throw new ArgumentNullException(nameof(metrics));
        }

        AnsiConsole.MarkupLine("=> metrics for binary classification");

        Table table = new();
        table.AddColumn(new TableColumn("Name").RightAligned());
        table.AddColumn("Value");

        table.AddRow("Accuracy", $"{metrics.Accuracy:P2}");
        table.AddRow("Area Under Precision Recall Curve", $"{metrics.AreaUnderPrecisionRecallCurve:P2}");
        table.AddRow("Area Under Roc Curve", $"{metrics.AreaUnderRocCurve:P2}");
        table.AddRow("F1 Score", $"{metrics.F1Score:P2}");
        table.AddRow("Log Loss", $"{metrics.LogLoss:0.##}");
        table.AddRow("Log Loss Reduction", $"{metrics.LogLossReduction:0.##}");
        table.AddRow("Positive Precision", $"{metrics.PositivePrecision:P2}");
        table.AddRow("Positive Recall", $"{metrics.PositiveRecall:P2}");
        table.AddRow("Negative Precision", $"{metrics.NegativePrecision:P2}");
        table.AddRow("Negative Recall", $"{metrics.NegativeRecall:P2}");

        // NOTE(review): newer Spectre.Console releases mark AnsiConsole.Render as obsolete in favour
        // of AnsiConsole.Write; switch once the referenced package version is confirmed.
        AnsiConsole.Render(table);
    }
}
}