diff --git a/inst/Classification/ClassificationGAM.m b/inst/Classification/ClassificationGAM.m index 087ceb2d..609b7597 100644 --- a/inst/Classification/ClassificationGAM.m +++ b/inst/Classification/ClassificationGAM.m @@ -1,6 +1,7 @@ ## Copyright (C) 2024 Ruchika Sonagote ## Copyright (C) 2024-2025 Andreas Bertsatos ## Copyright (C) 2025 Swayam Shah +## Copyright (C) 2026 Jayant Chauhan <0001jayant@gmail.com> ## ## This file is part of the statistics package for GNU Octave. ## @@ -222,9 +223,12 @@ ## ## Model specification formula ## - ## A character vector specifying the model formula in the form - ## @qcode{"Y ~ terms"} where @qcode{Y} represents the response variable and - ## @qcode{terms} specifies the predictor variables and interaction terms. + ## A character vector specifying the model formula using standard Wilkinson + ## notation in the form @qcode{"Y ~ terms"}. In addition to basic main + ## effects and interactions (@code{+}, @code{:}), it fully supports + ## advanced operators including crossing (@code{*}), nesting (@code{/}), + ## power/limits (@code{^}), and deletion (@code{-}). The formula is + ## evaluated internally via @code{parseWilkinsonFormula}. ## This property is read-only. ## ## @end deftp @@ -518,9 +522,10 @@ function disp (this) ## @qcode{'symmetriclogit'}. ## ## @item @qcode{'Formula'} @tab @tab A character vector specifying the model - ## formula in the form @qcode{"Y ~ terms"} where @qcode{Y} represents the - ## response variable and @qcode{terms} specifies the predictor variables and - ## interaction terms. + ## formula using standard Wilkinson notation in the form @qcode{"Y ~ terms"}. + ## In addition to basic main effects and interactions (@code{+}, @code{:}), + ## it supports advanced operators including crossing (@code{*}), nesting + ## (@code{/}), power/limits (@code{^}), and deletion (@code{-}). ## ## @item @qcode{'Interactions'} @tab @tab A logical matrix, a positive ## integer scalar, or the string @qcode{"all"} for defining the interactions @@ -1276,52 +1281,38 @@ function savemodel (this, fname) ## Determine interactions from formula function intMat = parseFormula (this) - intMat = []; - ## Check formula for syntax - if (isempty (strfind (this.Formula, '~'))) - error ("ClassificationGAM: invalid syntax in 'Formula'."); + try + schema = parseWilkinsonFormula (this.Formula, 'matrix'); + catch ME + error ("ClassificationGAM: Invalid formula. %s", ME.message); + end_try_catch + + termMat = schema.Terms; + varNames = schema.VariableNames; + + if (! isempty (schema.ResponseIdx)) + respIdx = schema.ResponseIdx; + varNames(respIdx) = []; + termMat(:, respIdx) = []; endif - ## Split formula and keep predictor terms - formulaParts = strsplit (this.Formula, '~'); - ## Check there is some string after '~' - if (numel (formulaParts) < 2) - error ("ClassificationGAM: no predictor terms in 'Formula'."); - endif - predictorString = strtrim (formulaParts{2}); - if (isempty (predictorString)) - error ("ClassificationGAM: no predictor terms in 'Formula'."); - endif - ## Split additive terms (between + sign) - aterms = strtrim (strsplit (predictorString, '+')); - ## Process all terms - for i = 1:numel (aterms) - ## Find individual terms (string missing ':') - if (isempty (strfind (aterms(i), ':'){:})) - ## Search PredictorNames to associate with column in X - sterms = strcmp (this.PredictorNames, aterms(i)); - ## Append to interactions matrix - intMat = [intMat; sterms]; - else - ## Split interaction terms (string contains ':') - mterms = strsplit (aterms{i}, ':'); - ## Add each individual predictor to interaction term vector - iterms = logical (zeros (1, this.NumPredictors)); - for t = 1:numel (mterms) - iterms = iterms | strcmp (this.PredictorNames, mterms(t)); - endfor - ## Check that all predictors have been identified - if (sum (iterms) != t) - error (strcat ("ClassificationGAM: some predictors", ... - " have not been identified.")); - endif - ## Append to interactions matrix - intMat = [intMat; iterms]; + + intMat = zeros (rows (termMat), this.NumPredictors); + + for i = 1:numel (varNames) + colIdx = find (strcmp (this.PredictorNames, varNames{i})); + if (isempty (colIdx)) + error ("ClassificationGAM: Formula contains unknown predictor '%s'.", varNames{i}); endif + intMat(:, colIdx) = termMat(:, i); endfor - ## Check that all terms have been identified - if (! all (sum (intMat, 2) > 0)) - error ("ClassificationGAM: some terms have not been identified."); - endif + + ## Remove intercept row (all zeros) + ## ClassificationGAM handles the intercept internally during fitting, + ## so we strip out the explicit intercept term if the parser included it. + interceptRow = (sum (intMat, 2) == 0); + intMat(interceptRow, :) = []; + + intMat = logical (intMat); endfunction ## Fit the model @@ -1634,6 +1625,22 @@ function savemodel (this, fname) %! assert (CVMdl.KFold == 3) %! assert (class (CVMdl.Trained{1}), "CompactClassificationGAM") %! assert (CVMdl.CrossValidatedModel, "ClassificationGAM") +%!test +%! ## Test advanced Wilkinson notation parsing in ClassificationGAM +%! X = [1, 2, 3; 4, 5, 6; 7, 8, 9; 3, 2, 1]; +%! Y = [0; 0; 1; 1]; +%! +%! ## The * operator should automatically expand to main effects + interaction +%! a = ClassificationGAM (X, Y, 'Formula', 'Y ~ x1 * x2'); +%! +%! assert (class (a), "ClassificationGAM"); +%! assert (a.NumPredictors, 3); +%! +%! ## Verify the IntMatrix correctly captured x1, x2, and x1:x2 in MATLAB order +%! expected_int = logical ([1, 0, 0; +%! 0, 1, 0; +%! 1, 1, 0]); +%! assert (a.IntMatrix, expected_int); ## Test input validation for crossval method %!error ... diff --git a/inst/parseWilkinsonFormula.m b/inst/parseWilkinsonFormula.m index 91315e40..3b066817 100644 --- a/inst/parseWilkinsonFormula.m +++ b/inst/parseWilkinsonFormula.m @@ -201,7 +201,7 @@ else lhs_vars = resolve_lhs_symbolic (lhs_str); endif - + ## build the required output. varargout{1} = run_equation_builder (lhs_vars, rhs_terms); @@ -519,14 +519,14 @@ args_str_parts = {}; for k = 1:length (node.args) arg_res = run_expander (node.args{k}, mode); - + if (! isempty (arg_res) && ! isempty (arg_res{1})) - args_str_parts{end+1} = arg_res{1}{1}; + args_str_parts{end+1} = arg_res{1}{1}; else args_str_parts{end+1} = ''; endif endfor - + full_term = sprintf ("%s(%s)", node.name, strjoin (args_str_parts, ',')); result = {{full_term}}; else @@ -829,14 +829,20 @@ terms_mat(i, idx) = 1; endfor - ## sorting : order by order. + ## sorting : order by order (ascending), then by variable sequence (descending) term_orders = sum (terms_mat, 2); M = [term_orders, terms_mat]; + ## Create unique rows first [~, unique_idx] = unique (M, 'rows'); terms_mat = terms_mat (unique_idx, :); + M = M (unique_idx, :); + + ## Create the direction vector: [1, -2, -3, -4, ...] + sort_dirs = [1, -(2:size(M, 2))]; - [~, sort_idx] = sortrows ([sum(terms_mat, 2), terms_mat]); + ## Sort using the direction vector + [~, sort_idx] = sortrows (M, sort_dirs); schema.Terms = terms_mat (sort_idx, :); endfunction @@ -1140,14 +1146,14 @@ ## process RHS rhs_tokens = run_lexer (rhs_str); [rhs_tree, ~] = run_parser (rhs_tokens); - + wrapper.type = 'OPERATOR'; wrapper.value = '~'; wrapper.left = []; wrapper.right = rhs_tree; - + expanded = run_expander (wrapper, mode); - + ## extract the terms. if (isstruct (expanded) && isfield (expanded, 'model')) rhs_terms = expanded.model; @@ -1164,22 +1170,22 @@ for i = 1:length (parts) p = strtrim (parts{i}); if (isempty (p)), continue; endif - + range_parts = strsplit (p, '-'); - + if (length (range_parts) == 2) s_str = strtrim (range_parts{1}); e_str = strtrim (range_parts{2}); - + [s_tok] = regexp (s_str, '^([a-zA-Z_]\w*)(\d+)$', 'tokens'); [e_tok] = regexp (e_str, '^([a-zA-Z_]\w*)(\d+)$', 'tokens'); - + if (! isempty (s_tok) && ! isempty (e_tok)) prefix = s_tok{1}{1}; s_num = str2double (s_tok{1}{2}); e_prefix = e_tok{1}{1}; e_num = str2double (e_tok{1}{2}); - + if (strcmp (prefix, e_prefix) && s_num <= e_num) for n = s_num:e_num vars{end+1} = sprintf ("%s%d", prefix, n); @@ -1203,7 +1209,7 @@ for i = 1:length (rhs_terms) t = rhs_terms{i}; if (isempty (t)) - term_strs{end+1} = ''; + term_strs{end+1} = ''; else if (length (t) == 1 && any (strfind (t{1}, "("))) term_strs{end+1} = t{1}; @@ -1229,13 +1235,13 @@ rhs_parts{end+1} = sprintf ("%s*%s", coeff, t_str); endif endfor - + full_rhs = strjoin (rhs_parts, ' + '); if (isempty (full_rhs)), full_rhs = '0'; endif lines{end+1} = sprintf ("%s = %s", lhs_vars{k}, full_rhs); endfor - eq_list = string (lines'); + eq_list = string (lines'); endfunction %!demo @@ -1271,7 +1277,7 @@ %!demo %! -%! ## Interaction Effects : +%! ## Interaction Effects : %! ## We analyze Relief Score based on Drug Type and Dosage Level. %! ## The '*' operator expands to the main effects PLUS the interaction term. %! ## Categorical variables are automatically created. @@ -1287,11 +1293,11 @@ %!demo %! -%! ## Polynomial Regression : +%! ## Polynomial Regression : %! ## Uses the power operator (^) to model non-linear relationships. %! Distance = [20; 45; 80; 125]; %! Speed = [30; 50; 70; 90]; -%! Speed_2 = Speed .^ 2; +%! Speed_2 = Speed .^ 2; %! t = table (Distance, Speed, Speed_2, 'VariableNames', {'Distance', 'Speed', 'Speed^2'}); %! %! formula = 'Distance ~ Speed^2'; @@ -1316,7 +1322,7 @@ %!demo %! -%! ## Explicit Nesting : +%! ## Explicit Nesting : %! ## The parser also supports the explicit 'B(A)' syntax, which means %! ## 'B is nested within A'. This is equivalent to the interaction 'A:B' %! ## but often used to denote random effects or specific hierarchy. @@ -1327,7 +1333,7 @@ %!demo %! -%! ## Excluding Terms : +%! ## Excluding Terms : %! ## Demonstrates building a complex model and then simplifying it. %! ## We define a full 3-way interaction (A*B*C) but explicitly remove the %! ## three-way term (A:B:C) using the minus operator. @@ -1338,7 +1344,7 @@ %!demo %! -%! ## Repeated Measures : +%! ## Repeated Measures : %! ## This allows predicting multiple outcomes simultaneously. %! ## The range operator '-' selects all variables between 'T1' and 'T3' %! ## as the response matrix Y. @@ -1659,6 +1665,24 @@ %! eq = parseWilkinsonFormula ('y ~ A - A', 'equation'); %! expected = string('y = c1'); %! assert (isequal (eq, expected)); +%!test +%! ## Verify parseWilkinsonFormula schema matches MATLAB fitlm sorting +%! formula = 'Y ~ x1 * x2 * x3'; +%! schema = parseWilkinsonFormula(formula, 'matrix'); +%! +%! ## Expected Octave Binary Matrix (8 rows, 4 columns) +%! ## Columns are sorted alphabetically by the parser: {'Y', 'x1', 'x2', 'x3'} +%! expected_terms = [0, 0, 0, 0; ## (Intercept) +%! 0, 1, 0, 0; ## x1 +%! 0, 0, 1, 0; ## x2 +%! 0, 0, 0, 1; ## x3 +%! 0, 1, 1, 0; ## x1:x2 +%! 0, 1, 0, 1; ## x1:x3 +%! 0, 0, 1, 1; ## x2:x3 +%! 0, 1, 1, 1]; ## x1:x2:x3 +%! +%! assert(schema.VariableNames, {'Y', 'x1', 'x2', 'x3'}); +%! assert(schema.Terms, expected_terms); %!error parseWilkinsonFormula () %!error parseWilkinsonFormula ('y ~ x', 'invalid_mode') %!error parseWilkinsonFormula ('', 'parse')