Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 56 additions & 49 deletions inst/Classification/ClassificationGAM.m
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## Copyright (C) 2024 Ruchika Sonagote <ruchikasonagote2003@gmail.com>
## Copyright (C) 2024-2025 Andreas Bertsatos <abertsatos@biol.uoa.gr>
## Copyright (C) 2025 Swayam Shah <swayamshah66@gmail.com>
## Copyright (C) 2026 Jayant Chauhan <0001jayant@gmail.com>
##
## This file is part of the statistics package for GNU Octave.
##
Expand Down Expand Up @@ -222,9 +223,12 @@
##
## Model specification formula
##
## A character vector specifying the model formula in the form
## @qcode{"Y ~ terms"} where @qcode{Y} represents the response variable and
## @qcode{terms} specifies the predictor variables and interaction terms.
## A character vector specifying the model formula using standard Wilkinson
## notation in the form @qcode{"Y ~ terms"}. In addition to basic main
## effects and interactions (@code{+}, @code{:}), it fully supports
## advanced operators including crossing (@code{*}), nesting (@code{/}),
## power/limits (@code{^}), and deletion (@code{-}). The formula is
## evaluated internally via @code{parseWilkinsonFormula}.
## This property is read-only.
##
## @end deftp
Expand Down Expand Up @@ -518,9 +522,10 @@ function disp (this)
## @qcode{'symmetriclogit'}.
##
## @item @qcode{'Formula'} @tab @tab A character vector specifying the model
## formula in the form @qcode{"Y ~ terms"} where @qcode{Y} represents the
## response variable and @qcode{terms} specifies the predictor variables and
## interaction terms.
## formula using standard Wilkinson notation in the form @qcode{"Y ~ terms"}.
## In addition to basic main effects and interactions (@code{+}, @code{:}),
## it supports advanced operators including crossing (@code{*}), nesting
## (@code{/}), power/limits (@code{^}), and deletion (@code{-}).
##
## @item @qcode{'Interactions'} @tab @tab A logical matrix, a positive
## integer scalar, or the string @qcode{"all"} for defining the interactions
Expand Down Expand Up @@ -1276,52 +1281,38 @@ function savemodel (this, fname)

## Determine interactions from formula
function intMat = parseFormula (this)
intMat = [];
## Check formula for syntax
if (isempty (strfind (this.Formula, '~')))
error ("ClassificationGAM: invalid syntax in 'Formula'.");
try
schema = parseWilkinsonFormula (this.Formula, 'matrix');
catch ME
error ("ClassificationGAM: Invalid formula. %s", ME.message);
end_try_catch

termMat = schema.Terms;
varNames = schema.VariableNames;

if (! isempty (schema.ResponseIdx))
respIdx = schema.ResponseIdx;
varNames(respIdx) = [];
termMat(:, respIdx) = [];
endif
## Split formula and keep predictor terms
formulaParts = strsplit (this.Formula, '~');
## Check there is some string after '~'
if (numel (formulaParts) < 2)
error ("ClassificationGAM: no predictor terms in 'Formula'.");
endif
predictorString = strtrim (formulaParts{2});
if (isempty (predictorString))
error ("ClassificationGAM: no predictor terms in 'Formula'.");
endif
## Split additive terms (between + sign)
aterms = strtrim (strsplit (predictorString, '+'));
## Process all terms
for i = 1:numel (aterms)
## Find individual terms (string missing ':')
if (isempty (strfind (aterms(i), ':'){:}))
## Search PredictorNames to associate with column in X
sterms = strcmp (this.PredictorNames, aterms(i));
## Append to interactions matrix
intMat = [intMat; sterms];
else
## Split interaction terms (string contains ':')
mterms = strsplit (aterms{i}, ':');
## Add each individual predictor to interaction term vector
iterms = logical (zeros (1, this.NumPredictors));
for t = 1:numel (mterms)
iterms = iterms | strcmp (this.PredictorNames, mterms(t));
endfor
## Check that all predictors have been identified
if (sum (iterms) != t)
error (strcat ("ClassificationGAM: some predictors", ...
" have not been identified."));
endif
## Append to interactions matrix
intMat = [intMat; iterms];

intMat = zeros (rows (termMat), this.NumPredictors);

for i = 1:numel (varNames)
colIdx = find (strcmp (this.PredictorNames, varNames{i}));
if (isempty (colIdx))
error ("ClassificationGAM: Formula contains unknown predictor '%s'.", varNames{i});
endif
intMat(:, colIdx) = termMat(:, i);
endfor
## Check that all terms have been identified
if (! all (sum (intMat, 2) > 0))
error ("ClassificationGAM: some terms have not been identified.");
endif

## Remove intercept row (all zeros)
## ClassificationGAM handles the intercept internally during fitting,
## so we strip out the explicit intercept term if the parser included it.
interceptRow = (sum (intMat, 2) == 0);
intMat(interceptRow, :) = [];

intMat = logical (intMat);
endfunction

## Fit the model
Expand Down Expand Up @@ -1634,6 +1625,22 @@ function savemodel (this, fname)
%! assert (CVMdl.KFold == 3)
%! assert (class (CVMdl.Trained{1}), "CompactClassificationGAM")
%! assert (CVMdl.CrossValidatedModel, "ClassificationGAM")
%!test
%! ## Test advanced Wilkinson notation parsing in ClassificationGAM
%! X = [1, 2, 3; 4, 5, 6; 7, 8, 9; 3, 2, 1];
%! Y = [0; 0; 1; 1];
%!
%! ## The * operator should automatically expand to main effects + interaction
%! a = ClassificationGAM (X, Y, 'Formula', 'Y ~ x1 * x2');
%!
%! assert (class (a), "ClassificationGAM");
%! assert (a.NumPredictors, 3);
%!
%! ## Verify the IntMatrix correctly captured x1, x2, and x1:x2 in MATLAB order
%! expected_int = logical ([1, 0, 0;
%! 0, 1, 0;
%! 1, 1, 0]);
%! assert (a.IntMatrix, expected_int);

## Test input validation for crossval method
%!error<ClassificationGAM.crossval: Name-Value arguments must be in pairs.> ...
Expand Down
70 changes: 47 additions & 23 deletions inst/parseWilkinsonFormula.m
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@
else
lhs_vars = resolve_lhs_symbolic (lhs_str);
endif

## build the required output.
varargout{1} = run_equation_builder (lhs_vars, rhs_terms);

Expand Down Expand Up @@ -519,14 +519,14 @@
args_str_parts = {};
for k = 1:length (node.args)
arg_res = run_expander (node.args{k}, mode);

if (! isempty (arg_res) && ! isempty (arg_res{1}))
args_str_parts{end+1} = arg_res{1}{1};
args_str_parts{end+1} = arg_res{1}{1};
else
args_str_parts{end+1} = '';
endif
endfor

full_term = sprintf ("%s(%s)", node.name, strjoin (args_str_parts, ','));
result = {{full_term}};
else
Expand Down Expand Up @@ -829,14 +829,20 @@
terms_mat(i, idx) = 1;
endfor

## sorting : order by order.
## sorting : order by order (ascending), then by variable sequence (descending)
term_orders = sum (terms_mat, 2);
M = [term_orders, terms_mat];

## Create unique rows first
[~, unique_idx] = unique (M, 'rows');
terms_mat = terms_mat (unique_idx, :);
M = M (unique_idx, :);

## Create the direction vector: [1, -2, -3, -4, ...]
sort_dirs = [1, -(2:size(M, 2))];

[~, sort_idx] = sortrows ([sum(terms_mat, 2), terms_mat]);
## Sort using the direction vector
[~, sort_idx] = sortrows (M, sort_dirs);
schema.Terms = terms_mat (sort_idx, :);

endfunction
Expand Down Expand Up @@ -1140,14 +1146,14 @@
## process RHS
rhs_tokens = run_lexer (rhs_str);
[rhs_tree, ~] = run_parser (rhs_tokens);

wrapper.type = 'OPERATOR';
wrapper.value = '~';
wrapper.left = [];
wrapper.right = rhs_tree;

expanded = run_expander (wrapper, mode);

## extract the terms.
if (isstruct (expanded) && isfield (expanded, 'model'))
rhs_terms = expanded.model;
Expand All @@ -1164,22 +1170,22 @@
for i = 1:length (parts)
p = strtrim (parts{i});
if (isempty (p)), continue; endif

range_parts = strsplit (p, '-');

if (length (range_parts) == 2)
s_str = strtrim (range_parts{1});
e_str = strtrim (range_parts{2});

[s_tok] = regexp (s_str, '^([a-zA-Z_]\w*)(\d+)$', 'tokens');
[e_tok] = regexp (e_str, '^([a-zA-Z_]\w*)(\d+)$', 'tokens');

if (! isempty (s_tok) && ! isempty (e_tok))
prefix = s_tok{1}{1};
s_num = str2double (s_tok{1}{2});
e_prefix = e_tok{1}{1};
e_num = str2double (e_tok{1}{2});

if (strcmp (prefix, e_prefix) && s_num <= e_num)
for n = s_num:e_num
vars{end+1} = sprintf ("%s%d", prefix, n);
Expand All @@ -1203,7 +1209,7 @@
for i = 1:length (rhs_terms)
t = rhs_terms{i};
if (isempty (t))
term_strs{end+1} = '';
term_strs{end+1} = '';
else
if (length (t) == 1 && any (strfind (t{1}, "(")))
term_strs{end+1} = t{1};
Expand All @@ -1229,13 +1235,13 @@
rhs_parts{end+1} = sprintf ("%s*%s", coeff, t_str);
endif
endfor

full_rhs = strjoin (rhs_parts, ' + ');
if (isempty (full_rhs)), full_rhs = '0'; endif
lines{end+1} = sprintf ("%s = %s", lhs_vars{k}, full_rhs);
endfor

eq_list = string (lines');
eq_list = string (lines');
endfunction

%!demo
Expand Down Expand Up @@ -1271,7 +1277,7 @@

%!demo
%!
%! ## Interaction Effects :
%! ## Interaction Effects :
%! ## We analyze Relief Score based on Drug Type and Dosage Level.
%! ## The '*' operator expands to the main effects PLUS the interaction term.
%! ## Categorical variables are automatically created.
Expand All @@ -1287,11 +1293,11 @@

%!demo
%!
%! ## Polynomial Regression :
%! ## Polynomial Regression :
%! ## Uses the power operator (^) to model non-linear relationships.
%! Distance = [20; 45; 80; 125];
%! Speed = [30; 50; 70; 90];
%! Speed_2 = Speed .^ 2;
%! Speed_2 = Speed .^ 2;
%! t = table (Distance, Speed, Speed_2, 'VariableNames', {'Distance', 'Speed', 'Speed^2'});
%!
%! formula = 'Distance ~ Speed^2';
Expand All @@ -1316,7 +1322,7 @@

%!demo
%!
%! ## Explicit Nesting :
%! ## Explicit Nesting :
%! ## The parser also supports the explicit 'B(A)' syntax, which means
%! ## 'B is nested within A'. This is equivalent to the interaction 'A:B'
%! ## but often used to denote random effects or specific hierarchy.
Expand All @@ -1327,7 +1333,7 @@

%!demo
%!
%! ## Excluding Terms :
%! ## Excluding Terms :
%! ## Demonstrates building a complex model and then simplifying it.
%! ## We define a full 3-way interaction (A*B*C) but explicitly remove the
%! ## three-way term (A:B:C) using the minus operator.
Expand All @@ -1338,7 +1344,7 @@

%!demo
%!
%! ## Repeated Measures :
%! ## Repeated Measures :
%! ## This allows predicting multiple outcomes simultaneously.
%! ## The range operator '-' selects all variables between 'T1' and 'T3'
%! ## as the response matrix Y.
Expand Down Expand Up @@ -1659,6 +1665,24 @@
%! eq = parseWilkinsonFormula ('y ~ A - A', 'equation');
%! expected = string('y = c1');
%! assert (isequal (eq, expected));
%!test
%! ## Verify parseWilkinsonFormula schema matches MATLAB fitlm sorting
%! formula = 'Y ~ x1 * x2 * x3';
%! schema = parseWilkinsonFormula(formula, 'matrix');
%!
%! ## Expected Octave Binary Matrix (8 rows, 4 columns)
%! ## Columns are sorted alphabetically by the parser: {'Y', 'x1', 'x2', 'x3'}
%! expected_terms = [0, 0, 0, 0; ## (Intercept)
%! 0, 1, 0, 0; ## x1
%! 0, 0, 1, 0; ## x2
%! 0, 0, 0, 1; ## x3
%! 0, 1, 1, 0; ## x1:x2
%! 0, 1, 0, 1; ## x1:x3
%! 0, 0, 1, 1; ## x2:x3
%! 0, 1, 1, 1]; ## x1:x2:x3
%!
%! assert(schema.VariableNames, {'Y', 'x1', 'x2', 'x3'});
%! assert(schema.Terms, expected_terms);
%!error <Input formula string is required> parseWilkinsonFormula ()
%!error <Unknown mode> parseWilkinsonFormula ('y ~ x', 'invalid_mode')
%!error <Unexpected End Of Formula> parseWilkinsonFormula ('', 'parse')
Expand Down