Skip to content

Commit 8e97c39

Browse files
committed
Fix variable naming in decision tree to pass pre-commit hooks
- Changed all X, X_train, X_test, X_val variables to lowercase
- Updated function parameters and variable references
- Decision tree now passes all ruff checks
- Follows TheAlgorithms/Python strict naming conventions
1 parent dde10ae commit 8e97c39

1 file changed

Lines changed: 71 additions & 71 deletions

File tree

machine_learning/decision_tree_pruning.py

Lines changed: 71 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -125,13 +125,13 @@ def _entropy(self, y: np.ndarray) -> float:
125125
return -np.sum(probabilities * np.log2(probabilities))
126126

127127
def _find_best_split(
128-
self, X: np.ndarray, y: np.ndarray, task_type: str
128+
self, x: np.ndarray, y: np.ndarray, task_type: str
129129
) -> tuple[int, float, float]:
130130
"""
131131
Find the best split for the given data.
132132
133133
Args:
134-
X: Feature matrix
134+
x: Feature matrix
135135
y: Target values
136136
task_type: 'regression' or 'classification'
137137
@@ -142,16 +142,16 @@ def _find_best_split(
142142
best_threshold = 0.0
143143
best_impurity = float('inf')
144144

145-
n_features = X.shape[1]
145+
n_features = x.shape[1]
146146
current_impurity = self._mse(y) if task_type == "regression" else self._gini(y)
147147

148148
for feature_idx in range(n_features):
149149
# Get unique values for this feature
150-
feature_values = np.unique(X[:, feature_idx])
150+
feature_values = np.unique(x[:, feature_idx])
151151

152152
for threshold in feature_values[:-1]: # Exclude the last value
153153
# Split the data
154-
left_mask = X[:, feature_idx] <= threshold
154+
left_mask = x[:, feature_idx] <= threshold
155155
right_mask = ~left_mask
156156

157157
if (
@@ -191,7 +191,7 @@ def _find_best_split(
191191

192192
def _build_tree(
193193
self,
194-
X: np.ndarray,
194+
x: np.ndarray,
195195
y: np.ndarray,
196196
depth: int = 0,
197197
task_type: str = "regression"
@@ -200,7 +200,7 @@ def _build_tree(
200200
Recursively build the decision tree.
201201
202202
Args:
203-
X: Feature matrix
203+
x: Feature matrix
204204
y: Target values
205205
depth: Current depth
206206
task_type: 'regression' or 'classification'
@@ -223,7 +223,7 @@ def _build_tree(
223223

224224
# Find best split
225225
best_feature, best_threshold, best_impurity = self._find_best_split(
226-
X, y, task_type
226+
x, y, task_type
227227
)
228228

229229
# If no good split found, make it a leaf
@@ -236,7 +236,7 @@ def _build_tree(
236236
return node
237237

238238
# Split the data
239-
left_mask = X[:, best_feature] <= best_threshold
239+
left_mask = x[:, best_feature] <= best_threshold
240240
right_mask = ~left_mask
241241

242242
# Create internal node
@@ -248,10 +248,10 @@ def _build_tree(
248248

249249
# Recursively build left and right subtrees
250250
node.left = self._build_tree(
251-
X[left_mask], y[left_mask], depth + 1, task_type
251+
x[left_mask], y[left_mask], depth + 1, task_type
252252
)
253253
node.right = self._build_tree(
254-
X[right_mask], y[right_mask], depth + 1, task_type
254+
x[right_mask], y[right_mask], depth + 1, task_type
255255
)
256256

257257
return node
@@ -269,12 +269,12 @@ def _most_common(self, y: np.ndarray) -> int | float:
269269
values, counts = np.unique(y, return_counts=True)
270270
return values[np.argmax(counts)]
271271

272-
def _reduced_error_pruning(self, X_val: np.ndarray, y_val: np.ndarray) -> None:
272+
def _reduced_error_pruning(self, x_val: np.ndarray, y_val: np.ndarray) -> None:
273273
"""
274274
Perform reduced error pruning on the tree.
275275
276276
Args:
277-
X_val: Validation feature matrix
277+
x_val: Validation feature matrix
278278
y_val: Validation target values
279279
"""
280280
if self.root_ is None:
@@ -295,7 +295,7 @@ def _reduced_error_pruning(self, X_val: np.ndarray, y_val: np.ndarray) -> None:
295295
continue
296296

297297
# Calculate validation error before pruning
298-
predictions_before = self._predict_batch(X_val)
298+
predictions_before = self._predict_batch(x_val)
299299
error_before = self._calculate_error(y_val, predictions_before)
300300

301301
# Temporarily prune the node
@@ -310,7 +310,7 @@ def _reduced_error_pruning(self, X_val: np.ndarray, y_val: np.ndarray) -> None:
310310
node.value = self._most_common(y_val) # Use validation set majority
311311

312312
# Calculate validation error after pruning
313-
predictions_after = self._predict_batch(X_val)
313+
predictions_after = self._predict_batch(x_val)
314314
error_after = self._calculate_error(y_val, predictions_after)
315315

316316
# Calculate improvement
@@ -417,18 +417,18 @@ def _get_internal_nodes(self, node: "TreeNode") -> list["TreeNode"]:
417417
nodes.extend(self._get_internal_nodes(node.right))
418418
return nodes
419419

420-
def _predict_batch(self, X: np.ndarray) -> np.ndarray:
420+
def _predict_batch(self, x: np.ndarray) -> np.ndarray:
421421
"""
422422
Make predictions for a batch of samples.
423423
424424
Args:
425-
X: Feature matrix
425+
x: Feature matrix
426426
427427
Returns:
428428
Predictions
429429
"""
430-
predictions = np.zeros(len(X))
431-
for i, sample in enumerate(X):
430+
predictions = np.zeros(len(x))
431+
for i, sample in enumerate(x):
432432
predictions[i] = self._predict_single(sample, self.root_)
433433
return predictions
434434

@@ -466,75 +466,75 @@ def _calculate_error(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
466466

467467
def fit(
468468
self,
469-
X: np.ndarray,
469+
x: np.ndarray,
470470
y: np.ndarray,
471-
X_val: np.ndarray | None = None,
471+
x_val: np.ndarray | None = None,
472472
y_val: np.ndarray | None = None,
473473
) -> "DecisionTreePruning":
474474
"""
475475
Fit the decision tree with optional pruning.
476476
477477
Args:
478-
X: Training feature matrix
478+
x: Training feature matrix
479479
y: Training target values
480-
X_val: Validation feature matrix (for pruning)
480+
x_val: Validation feature matrix (for pruning)
481481
y_val: Validation target values (for pruning)
482482
483483
Returns:
484484
Self for method chaining
485485
"""
486-
if X.ndim != 2:
487-
raise ValueError("X must be 2-dimensional")
488-
if len(X) != len(y):
489-
raise ValueError("X and y must have the same length")
486+
if x.ndim != 2:
487+
raise ValueError("x must be 2-dimensional")
488+
if len(x) != len(y):
489+
raise ValueError("x and y must have the same length")
490490

491-
self.n_features_ = X.shape[1]
491+
self.n_features_ = x.shape[1]
492492

493493
# Determine task type
494494
task_type = (
495495
"classification" if np.issubdtype(y.dtype, np.integer) else "regression"
496496
)
497497

498498
# Build the tree
499-
self.root_ = self._build_tree(X, y, task_type=task_type)
499+
self.root_ = self._build_tree(x, y, task_type=task_type)
500500

501501
# Apply pruning if specified
502502
if self.pruning_method == "reduced_error":
503-
if X_val is None or y_val is None:
503+
if x_val is None or y_val is None:
504504
raise ValueError("Validation data required for reduced error pruning")
505-
self._reduced_error_pruning(X_val, y_val)
505+
self._reduced_error_pruning(x_val, y_val)
506506
elif self.pruning_method == "cost_complexity":
507507
self._cost_complexity_pruning()
508508

509509
return self
510510

511-
def predict(self, X: np.ndarray) -> np.ndarray:
511+
def predict(self, x: np.ndarray) -> np.ndarray:
512512
"""
513513
Make predictions.
514514
515515
Args:
516-
X: Feature matrix
516+
x: Feature matrix
517517
518518
Returns:
519519
Predictions
520520
"""
521521
if self.root_ is None:
522522
raise ValueError("Tree must be fitted before prediction")
523523

524-
return self._predict_batch(X)
524+
return self._predict_batch(x)
525525

526-
def score(self, X: np.ndarray, y: np.ndarray) -> float:
526+
def score(self, x: np.ndarray, y: np.ndarray) -> float:
527527
"""
528528
Calculate accuracy (for classification) or R² (for regression).
529529
530530
Args:
531-
X: Feature matrix
531+
x: Feature matrix
532532
y: True values
533533
534534
Returns:
535535
Score
536536
"""
537-
predictions = self.predict(X)
537+
predictions = self.predict(x)
538538

539539
if np.issubdtype(y.dtype, np.integer):
540540
# Classification: accuracy
@@ -576,12 +576,12 @@ def generate_regression_data(
576576
random_state: Random seed
577577
578578
Returns:
579-
Tuple of (X, y)
579+
Tuple of (x, y)
580580
"""
581581
rng = np.random.default_rng(random_state)
582-
X = rng.standard_normal((n_samples, 2))
583-
y = X[:, 0] ** 2 + X[:, 1] ** 2 + noise * rng.standard_normal(n_samples)
584-
return X, y
582+
x = rng.standard_normal((n_samples, 2))
583+
y = x[:, 0] ** 2 + x[:, 1] ** 2 + noise * rng.standard_normal(n_samples)
584+
return x, y
585585

586586

587587
def generate_classification_data(
@@ -595,34 +595,34 @@ def generate_classification_data(
595595
random_state: Random seed
596596
597597
Returns:
598-
Tuple of (X, y)
598+
Tuple of (x, y)
599599
"""
600600
rng = np.random.default_rng(random_state)
601-
X = rng.standard_normal((n_samples, 2))
602-
y = ((X[:, 0] + X[:, 1]) > 0).astype(int)
603-
return X, y
601+
x = rng.standard_normal((n_samples, 2))
602+
y = ((x[:, 0] + x[:, 1]) > 0).astype(int)
603+
return x, y
604604

605605

606606
def compare_pruning_methods() -> None:
607607
"""
608608
Compare different pruning methods.
609609
"""
610610
# Generate data
611-
X, y = generate_regression_data(n_samples=200)
611+
x, y = generate_regression_data(n_samples=200)
612612

613613
# Split data
614-
split_idx = int(0.7 * len(X))
615-
X_train, X_test = X[:split_idx], X[split_idx:]
614+
split_idx = int(0.7 * len(x))
615+
x_train, x_test = x[:split_idx], x[split_idx:]
616616
y_train, y_test = y[:split_idx], y[split_idx:]
617617

618618
# Further split training data for validation
619-
val_split = int(0.5 * len(X_train))
620-
X_val, X_train = X_train[:val_split], X_train[val_split:]
619+
val_split = int(0.5 * len(x_train))
620+
x_val, x_train = x_train[:val_split], x_train[val_split:]
621621
y_val, y_train = y_train[:val_split], y_train[val_split:]
622622

623-
print(f"Training set size: {len(X_train)}")
624-
print(f"Validation set size: {len(X_val)}")
625-
print(f"Test set size: {len(X_test)}")
623+
print(f"Training set size: {len(x_train)}")
624+
print(f"Validation set size: {len(x_val)}")
625+
print(f"Test set size: {len(x_test)}")
626626

627627
# Test different pruning methods
628628
methods = [
@@ -642,12 +642,12 @@ def compare_pruning_methods() -> None:
642642
)
643643

644644
if method == "reduced_error":
645-
tree.fit(X_train, y_train, X_val, y_val)
645+
tree.fit(x_train, y_train, x_val, y_val)
646646
else:
647-
tree.fit(X_train, y_train)
647+
tree.fit(x_train, y_train)
648648

649-
train_score = tree.score(X_train, y_train)
650-
test_score = tree.score(X_test, y_test)
649+
train_score = tree.score(x_train, y_train)
650+
test_score = tree.score(x_test, y_test)
651651

652652
print(f"Training R²: {train_score:.4f}")
653653
print(f"Test R²: {test_score:.4f}")
@@ -661,11 +661,11 @@ def main() -> None:
661661
print("=== Regression Example ===")
662662

663663
# Generate regression data
664-
X_reg, y_reg = generate_regression_data(n_samples=200, noise=0.1)
664+
x_reg, y_reg = generate_regression_data(n_samples=200, noise=0.1)
665665

666666
# Split data
667-
split_idx = int(0.8 * len(X_reg))
668-
X_train, X_test = X_reg[:split_idx], X_reg[split_idx:]
667+
split_idx = int(0.8 * len(x_reg))
668+
x_train, x_test = x_reg[:split_idx], x_reg[split_idx:]
669669
y_train, y_test = y_reg[:split_idx], y_reg[split_idx:]
670670

671671
# Train tree with cost-complexity pruning
@@ -675,40 +675,40 @@ def main() -> None:
675675
pruning_method="cost_complexity",
676676
ccp_alpha=0.01
677677
)
678-
tree_reg.fit(X_train, y_train)
678+
tree_reg.fit(x_train, y_train)
679679

680680
# Make predictions
681-
train_score = tree_reg.score(X_train, y_train)
682-
test_score = tree_reg.score(X_test, y_test)
681+
train_score = tree_reg.score(x_train, y_train)
682+
test_score = tree_reg.score(x_test, y_test)
683683

684684
print(f"Training R²: {train_score:.4f}")
685685
print(f"Test R²: {test_score:.4f}")
686686

687687
print("\n=== Classification Example ===")
688688

689689
# Generate classification data
690-
X_cls, y_cls = generate_classification_data(n_samples=200)
690+
x_cls, y_cls = generate_classification_data(n_samples=200)
691691

692692
# Split data
693-
split_idx = int(0.8 * len(X_cls))
694-
X_train, X_test = X_cls[:split_idx], X_cls[split_idx:]
693+
split_idx = int(0.8 * len(x_cls))
694+
x_train, x_test = x_cls[:split_idx], x_cls[split_idx:]
695695
y_train, y_test = y_cls[:split_idx], y_cls[split_idx:]
696696

697697
# Train tree with reduced error pruning
698-
val_split = int(0.5 * len(X_train))
699-
X_val, X_train = X_train[:val_split], X_train[val_split:]
698+
val_split = int(0.5 * len(x_train))
699+
x_val, x_train = x_train[:val_split], x_train[val_split:]
700700
y_val, y_train = y_train[:val_split], y_train[val_split:]
701701

702702
tree_cls = DecisionTreePruning(
703703
max_depth=10,
704704
min_samples_leaf=2,
705705
pruning_method="reduced_error"
706706
)
707-
tree_cls.fit(X_train, y_train, X_val, y_val)
707+
tree_cls.fit(x_train, y_train, x_val, y_val)
708708

709709
# Make predictions
710-
train_accuracy = tree_cls.score(X_train, y_train)
711-
test_accuracy = tree_cls.score(X_test, y_test)
710+
train_accuracy = tree_cls.score(x_train, y_train)
711+
test_accuracy = tree_cls.score(x_test, y_test)
712712

713713
print(f"Training accuracy: {train_accuracy:.4f}")
714714
print(f"Test accuracy: {test_accuracy:.4f}")

0 commit comments

Comments (0)