diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5d59849..206f8b5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,9 +6,12 @@ on: pull_request: branches: [ main ] +permissions: + contents: read + jobs: test: - name: Test on Python ${{ matrix.python-version }} + name: Unit Tests (Python ${{ matrix.python-version }}) runs-on: ubuntu-latest strategy: fail-fast: false @@ -30,9 +33,46 @@ jobs: python -m pip install --upgrade pip pip install -e ".[dev]" - - name: Run tests + - name: Run unit tests + run: | + pytest -v --ignore=tests/accuracy --ignore=tests/benchmark -x + + - name: Run tests with coverage + if: matrix.python-version == '3.12' + run: | + pip install pytest-cov + pytest --cov=numta --cov-report=xml --ignore=tests/accuracy --ignore=tests/benchmark + + - name: Upload coverage reports + if: matrix.python-version == '3.12' + uses: codecov/codecov-action@v4 + with: + files: ./coverage.xml + fail_ci_if_error: false + continue-on-error: true + + benchmark: + name: Benchmark Tests + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Run benchmark tests run: | - pytest -v + pytest tests/benchmark/test_benchmark.py -v build: name: Build package diff --git a/pyproject.toml b/pyproject.toml index f185604..7b578be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ dev = [ "pytest>=7.0.0", "pytest-benchmark>=4.0.0", + "pytest-cov>=4.0.0", "pandas>=1.3.0", ] pandas = [ @@ -79,3 +80,33 @@ testpaths = ["tests"] python_files = ["test_*.py"] python_classes = ["Test*"] python_functions = ["test_*"] +markers = [ + "numba: marks tests requiring numba", + "pandas: marks tests requiring pandas", + "talib: marks tests requiring TA-Lib", + "pandas_ta: 
@dataclass
class AccuracyMetrics:
    """Summary statistics describing how closely two result arrays agree."""

    mae: float  # mean absolute error
    rmse: float  # root mean square error
    max_error: float  # largest absolute difference
    correlation: float  # Pearson correlation coefficient
    match_rate: float  # percentage of values within tolerance
    valid_count: int  # number of compared (non-NaN) positions
    total_count: int  # total number of elements

    def classification(self) -> str:
        """
        Map the metrics onto a coarse accuracy label.

        A tier is granted only when the MAE is strictly below the tier's
        limit AND the correlation is strictly above the tier's floor; tiers
        are tried from strictest to loosest.

        Returns
        -------
        str
            Classification: EXACT, NEAR-EXACT, VERY HIGH, HIGH, or MODERATE
        """
        # (label, exclusive upper bound on MAE, exclusive lower bound on correlation)
        tiers = (
            ("EXACT", 1e-10, 0.999999),
            ("NEAR-EXACT", 1e-6, 0.99999),
            ("VERY HIGH", 1e-3, 0.9999),
            ("HIGH", 0.01, 0.999),
        )
        for label, mae_limit, correlation_floor in tiers:
            if self.mae < mae_limit and self.correlation > correlation_floor:
                return label
        return "MODERATE"
def generate_test_data(size: int, data_type: str, seed: int = 42) -> np.ndarray:
    """
    Generate a deterministic synthetic price series of a specific type.

    The numbers are drawn from a private ``np.random.RandomState(seed)``.
    This yields exactly the values that ``np.random.seed(seed)`` followed by
    the module-level ``np.random`` functions would yield, but without
    reseeding the process-wide NumPy generator as a side effect.

    Parameters
    ----------
    size : int
        Number of data points
    data_type : str
        Type of data: 'random', 'trending', 'cyclical', 'volatile'
    seed : int
        Random seed

    Returns
    -------
    np.ndarray
        Generated data

    Raises
    ------
    ValueError
        If ``data_type`` is not one of the supported names.
    """
    rng = np.random.RandomState(seed)

    if data_type == 'random':
        # Random walk with small steps around 100.
        return 100 + np.cumsum(rng.randn(size) * 0.5)
    elif data_type == 'trending':
        # Linear ramp 100 -> 120 plus Gaussian noise.
        trend = np.linspace(100, 120, size)
        noise = rng.randn(size) * 0.5
        return trend + noise
    elif data_type == 'cyclical':
        # Five full sine periods (amplitude 10) around 100 plus noise.
        x = np.linspace(0, 10 * np.pi, size)
        cycle = 10 * np.sin(x) + 100
        noise = rng.randn(size) * 0.3
        return cycle + noise
    elif data_type == 'volatile':
        # Random walk with four times the step size of 'random'.
        return 100 + np.cumsum(rng.randn(size) * 2.0)
    else:
        raise ValueError(f"Unknown data type: {data_type}")
@pytest.mark.talib
class TestTaLibAccuracy:
    """Accuracy checks of numta indicators against their TA-Lib counterparts."""

    @pytest.fixture(autouse=True)
    def check_talib(self):
        """Skip every test in this class when TA-Lib could not be imported."""
        if HAS_TALIB:
            return
        pytest.skip("TA-Lib not installed")

    @pytest.fixture
    def test_data(self):
        """Build one 1000-point OHLCV tuple per entry of ``DATA_TYPES``."""
        datasets = {}
        for kind in DATA_TYPES:
            datasets[kind] = generate_ohlcv_data(1000, kind)
        return datasets

    @pytest.mark.parametrize("data_type", list(DATA_TYPES))
    def test_sma_accuracy(self, test_data, data_type):
        """SMA(20) must correlate with TA-Lib and have a tiny mean absolute error."""
        close = test_data[data_type][3]

        ours = numta.SMA(close, timeperiod=20)
        reference = talib.SMA(close, timeperiod=20)
        metrics = compare_results(ours, reference)

        assert metrics.correlation > 0.999, f"Low correlation: {metrics.correlation}"
        assert metrics.mae < 1e-6, f"High MAE: {metrics.mae}"

    @pytest.mark.parametrize("data_type", list(DATA_TYPES))
    def test_ema_accuracy(self, test_data, data_type):
        """EMA(20) must correlate with TA-Lib."""
        close = test_data[data_type][3]

        ours = numta.EMA(close, timeperiod=20)
        reference = talib.EMA(close, timeperiod=20)
        metrics = compare_results(ours, reference)

        assert metrics.correlation > 0.999, f"Low correlation: {metrics.correlation}"

    @pytest.mark.parametrize("data_type", list(DATA_TYPES))
    def test_rsi_accuracy(self, test_data, data_type):
        """RSI(14) must correlate with TA-Lib."""
        close = test_data[data_type][3]

        ours = numta.RSI(close, timeperiod=14)
        reference = talib.RSI(close, timeperiod=14)
        metrics = compare_results(ours, reference)

        assert metrics.correlation > 0.99, f"Low correlation: {metrics.correlation}"

    @pytest.mark.parametrize("data_type", list(DATA_TYPES))
    def test_macd_accuracy(self, test_data, data_type):
        """Each of the three MACD outputs must correlate with TA-Lib."""
        close = test_data[data_type][3]

        our_macd, our_signal, our_hist = numta.MACD(close)
        ref_macd, ref_signal, ref_hist = talib.MACD(close)

        # Compare the outputs pairwise, in the order the functions return them.
        pairs = {
            'MACD': (our_macd, ref_macd),
            'Signal': (our_signal, ref_signal),
            'Histogram': (our_hist, ref_hist),
        }
        for name, (numta_out, talib_out) in pairs.items():
            metrics = compare_results(numta_out, talib_out)
            assert metrics.correlation > 0.99, f"{name} low correlation: {metrics.correlation}"

    @pytest.mark.parametrize("data_type", list(DATA_TYPES))
    def test_bbands_accuracy(self, test_data, data_type):
        """Each Bollinger Band (period 20) must correlate with TA-Lib."""
        close = test_data[data_type][3]

        our_upper, our_middle, our_lower = numta.BBANDS(close, timeperiod=20)
        ref_upper, ref_middle, ref_lower = talib.BBANDS(close, timeperiod=20)

        pairs = {
            'Upper': (our_upper, ref_upper),
            'Middle': (our_middle, ref_middle),
            'Lower': (our_lower, ref_lower),
        }
        for name, (numta_out, talib_out) in pairs.items():
            metrics = compare_results(numta_out, talib_out)
            assert metrics.correlation > 0.999, f"{name} low correlation: {metrics.correlation}"

    @pytest.mark.parametrize("data_type", list(DATA_TYPES))
    def test_atr_accuracy(self, test_data, data_type):
        """ATR(14) must correlate with TA-Lib."""
        high, low, close = test_data[data_type][1:4]

        ours = numta.ATR(high, low, close, timeperiod=14)
        reference = talib.ATR(high, low, close, timeperiod=14)
        metrics = compare_results(ours, reference)

        assert metrics.correlation > 0.99, f"Low correlation: {metrics.correlation}"

    @pytest.mark.parametrize("data_type", list(DATA_TYPES))
    def test_adx_accuracy(self, test_data, data_type):
        """ADX(14) must correlate with TA-Lib."""
        high, low, close = test_data[data_type][1:4]

        ours = numta.ADX(high, low, close, timeperiod=14)
        reference = talib.ADX(high, low, close, timeperiod=14)
        metrics = compare_results(ours, reference)

        assert metrics.correlation > 0.99, f"Low correlation: {metrics.correlation}"

    @pytest.mark.parametrize("data_type", list(DATA_TYPES))
    def test_stoch_accuracy(self, test_data, data_type):
        """Both Stochastic outputs (default parameters) must correlate with TA-Lib."""
        high, low, close = test_data[data_type][1:4]

        our_slowk, our_slowd = numta.STOCH(high, low, close)
        ref_slowk, ref_slowd = talib.STOCH(high, low, close)

        pairs = {
            'SlowK': (our_slowk, ref_slowk),
            'SlowD': (our_slowd, ref_slowd),
        }
        for name, (numta_out, talib_out) in pairs.items():
            metrics = compare_results(numta_out, talib_out)
            assert metrics.correlation > 0.99, f"{name} low correlation: {metrics.correlation}"
def run_accuracy_report() -> str:
    """
    Build a markdown accuracy report comparing numta with TA-Lib.

    For every entry of ``DATA_TYPES`` a 1000-point series is generated and
    SMA(20), EMA(20) and RSI(14) are computed with both libraries; each
    comparison becomes one table row.

    Returns
    -------
    str
        Markdown formatted accuracy report, or a one-line notice when
        TA-Lib is not installed.
    """
    if not HAS_TALIB:
        return "TA-Lib not installed. Cannot generate accuracy report."

    # (indicator name, period passed as the second positional argument)
    indicators = (('SMA', 20), ('EMA', 20), ('RSI', 14))

    report = [
        "# numta Accuracy Report",
        "",
        "Comparison against TA-Lib reference implementation.",
        "",
    ]

    for data_type, description in DATA_TYPES.items():
        close = generate_ohlcv_data(1000, data_type)[3]

        report.extend([
            f"## {description}",
            "",
            "| Function | MAE | RMSE | Correlation | Classification |",
            "|----------|-----|------|-------------|----------------|",
        ])

        for name, period in indicators:
            ours = getattr(numta, name)(close, period)
            reference = getattr(talib, name)(close, period)
            metrics = compare_results(ours, reference)

            report.append(
                f"| {name} | {metrics.mae:.2e} | {metrics.rmse:.2e} | "
                f"{metrics.correlation:.6f} | {metrics.classification()} |"
            )

        report.append("")

    return "\n".join(report)


if __name__ == '__main__':
    print(run_accuracy_report())
@dataclass
class BenchmarkResult:
    """Timing statistics collected for one benchmarked callable."""

    name: str
    iterations: int
    mean_time: float  # seconds
    median_time: float  # seconds
    std_time: float  # seconds
    min_time: float  # seconds
    max_time: float  # seconds
    data_size: int
    # Derived in __post_init__; deliberately not a constructor argument.
    ops_per_second: float = field(init=False)

    def __post_init__(self):
        """Derive calls-per-second from the mean time (inf when it is not positive)."""
        self.ops_per_second = (
            1.0 / self.mean_time if self.mean_time > 0 else float('inf')
        )

    def to_dict(self) -> dict:
        """Return all fields, including ``ops_per_second``, as a plain dict."""
        return asdict(self)
def _generate_ohlcv_data(self, size: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate deterministic OHLCV data for benchmarking.

    The values come from a private ``np.random.RandomState(self.seed)``.
    They are identical to what seeding the global NumPy generator with
    ``self.seed`` would produce, but the process-wide random state is left
    untouched, so running a benchmark no longer reseeds unrelated code.

    Parameters
    ----------
    size : int
        Number of data points

    Returns
    -------
    Tuple
        (open, high, low, close, volume) arrays
    """
    rng = np.random.RandomState(self.seed)
    # The order of the draws below fixes the generated values; keep it stable.
    close = 100 + np.cumsum(rng.randn(size) * 0.5)
    high = close + np.abs(rng.randn(size) * 0.5)
    low = close - np.abs(rng.randn(size) * 0.5)
    open_ = close + rng.randn(size) * 0.3
    volume = rng.randint(1000, 10000, size).astype(np.float64)
    return open_, high, low, close, volume
def compare_implementations(
    self,
    func_name: str,
    numta_func: Callable,
    numta_args: tuple,
    numta_kwargs: dict,
    talib_func: Optional[Callable] = None,
    talib_args: Optional[tuple] = None,
    talib_kwargs: Optional[dict] = None,
    pandas_ta_func: Optional[Callable] = None,
    pandas_ta_args: Optional[tuple] = None,
    pandas_ta_kwargs: Optional[dict] = None,
    data_size: int = 10000,
    iterations: int = DEFAULT_ITERATIONS
) -> ComparisonResult:
    """
    Compare numta implementation against TA-Lib and/or pandas-ta.

    numta is always benchmarked. TA-Lib / pandas-ta are benchmarked only
    when the library was importable and a function was supplied. The
    comparison is appended to ``self.results`` and also returned.

    Parameters
    ----------
    func_name : str
        Name of the function being compared
    numta_func, talib_func, pandas_ta_func : Callable
        Functions to compare
    *_args, *_kwargs : tuple, dict
        Arguments for each function. ``talib_args=None`` reuses
        ``numta_args``; ``pandas_ta_args=None`` means no positional
        arguments; ``*_kwargs=None`` means no keyword arguments. An
        explicitly passed empty tuple/dict is honoured as given.
    data_size : int
        Size of input data
    iterations : int
        Number of timing iterations

    Returns
    -------
    ComparisonResult
        Comparison results
    """
    comparison = ComparisonResult(
        function_name=func_name,
        data_size=data_size,
    )

    # Test for None explicitly: `x or default` would also replace a
    # deliberately passed empty tuple/dict with the fallback.
    candidates = [
        ('numta', True, numta_func, numta_args, numta_kwargs),
        (
            'talib',
            HAS_TALIB and talib_func is not None,
            talib_func,
            numta_args if talib_args is None else talib_args,
            {} if talib_kwargs is None else talib_kwargs,
        ),
        (
            'pandas_ta',
            HAS_PANDAS_TA and pandas_ta_func is not None,
            pandas_ta_func,
            () if pandas_ta_args is None else pandas_ta_args,
            {} if pandas_ta_kwargs is None else pandas_ta_kwargs,
        ),
    ]

    for label, enabled, func, args, kwargs in candidates:
        if not enabled:
            continue
        result = self.benchmark_function(
            name=label,
            func=func,
            args=args,
            kwargs=kwargs,
            data_size=data_size,
            iterations=iterations,
        )
        comparison.add_result(label, result)

    comparison.calculate_speedups(baseline='numta')
    self.results.append(comparison)

    return comparison
+ + Parameters + ---------- + data_sizes : List[int], optional + Data sizes to test + iterations : int + Number of timing iterations + + Returns + ------- + List[ComparisonResult] + All comparison results + """ + if data_sizes is None: + data_sizes = self.DEFAULT_DATA_SIZES + + all_results = [] + + for size in data_sizes: + open_, high, low, close, volume = self._generate_ohlcv_data(size) + + # SMA + result = self.compare_implementations( + func_name='SMA', + numta_func=numta.SMA, + numta_args=(close,), + numta_kwargs={'timeperiod': 20}, + talib_func=talib.SMA if HAS_TALIB else None, + talib_args=(close,), + talib_kwargs={'timeperiod': 20}, + data_size=size, + iterations=iterations, + ) + all_results.append(result) + + # EMA + result = self.compare_implementations( + func_name='EMA', + numta_func=numta.EMA, + numta_args=(close,), + numta_kwargs={'timeperiod': 20}, + talib_func=talib.EMA if HAS_TALIB else None, + talib_args=(close,), + talib_kwargs={'timeperiod': 20}, + data_size=size, + iterations=iterations, + ) + all_results.append(result) + + # RSI + result = self.compare_implementations( + func_name='RSI', + numta_func=numta.RSI, + numta_args=(close,), + numta_kwargs={'timeperiod': 14}, + talib_func=talib.RSI if HAS_TALIB else None, + talib_args=(close,), + talib_kwargs={'timeperiod': 14}, + data_size=size, + iterations=iterations, + ) + all_results.append(result) + + # MACD + result = self.compare_implementations( + func_name='MACD', + numta_func=numta.MACD, + numta_args=(close,), + numta_kwargs={}, + talib_func=talib.MACD if HAS_TALIB else None, + talib_args=(close,), + talib_kwargs={}, + data_size=size, + iterations=iterations, + ) + all_results.append(result) + + # ATR + result = self.compare_implementations( + func_name='ATR', + numta_func=numta.ATR, + numta_args=(high, low, close), + numta_kwargs={'timeperiod': 14}, + talib_func=talib.ATR if HAS_TALIB else None, + talib_args=(high, low, close), + talib_kwargs={'timeperiod': 14}, + data_size=size, + 
def generate_report(self, results: "Optional[List[ComparisonResult]]" = None) -> str:
    """
    Generate a markdown report of benchmark results.

    Results are grouped by function name (in first-seen order) and each
    group is rendered as a table sorted by data size. The speedup column is
    TA-Lib mean time divided by numta mean time. A numta mean time of zero
    is reported as an infinite speedup, matching how ``BenchmarkResult``
    and ``ComparisonResult`` treat zero times, instead of raising
    ``ZeroDivisionError``.

    Parameters
    ----------
    results : List[ComparisonResult], optional
        Results to include in report. Uses stored results if not provided.

    Returns
    -------
    str
        Markdown formatted report
    """
    if results is None:
        results = self.results

    if not results:
        return "No benchmark results available."

    lines = [
        "# numta Benchmark Report",
        "",
        "## Summary",
        "",
    ]

    # Group by function; dict insertion order keeps first-seen ordering.
    by_function = {}
    for result in results:
        by_function.setdefault(result.function_name, []).append(result)

    for func_name, func_results in by_function.items():
        lines.append(f"### {func_name}")
        lines.append("")
        lines.append("| Data Size | numta (μs) | TA-Lib (μs) | Speedup vs TA-Lib |")
        lines.append("|-----------|------------|-------------|-------------------|")

        for result in sorted(func_results, key=lambda r: r.data_size):
            numta_res = result.results.get('numta')
            talib_res = result.results.get('talib')

            numta_us = f"{numta_res.mean_time * 1e6:.2f}" if numta_res is not None else 'N/A'
            talib_us = f"{talib_res.mean_time * 1e6:.2f}" if talib_res is not None else 'N/A'

            if numta_res is not None and talib_res is not None:
                if numta_res.mean_time > 0:
                    speedup = talib_res.mean_time / numta_res.mean_time
                else:
                    # Timer resolution (or a loaded file) can yield 0.0 seconds.
                    speedup = float('inf')
                speedup_str = f"{speedup:.2f}x"
            else:
                speedup_str = 'N/A'

            lines.append(f"| {result.data_size:,} | {numta_us} | {talib_us} | {speedup_str} |")

        lines.append("")

    return "\n".join(lines)
def load_results(self, filepath: str) -> "List[ComparisonResult]":
    """
    Load benchmark results from JSON file.

    The file is read as UTF-8 (the encoding JSON is defined in) rather than
    the platform's locale encoding, so a results file loads the same way on
    every machine. ``ops_per_second`` is not read back; it is recomputed by
    ``BenchmarkResult`` from ``mean_time``.

    Parameters
    ----------
    filepath : str
        Path to load results from

    Returns
    -------
    List[ComparisonResult]
        Loaded results (empty when the file has no ``results`` entry)

    Raises
    ------
    KeyError
        If a stored comparison or benchmark entry lacks a required field.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Constructor fields of BenchmarkResult that are persisted in the file.
    stored_fields = (
        'name', 'iterations', 'mean_time', 'median_time',
        'std_time', 'min_time', 'max_time', 'data_size',
    )

    results = []
    for entry in data.get('results', []):
        comparison = ComparisonResult(
            function_name=entry['function_name'],
            data_size=entry['data_size'],
        )
        for impl_name, result_data in entry.get('results', {}).items():
            benchmark_result = BenchmarkResult(
                **{key: result_data[key] for key in stored_fields}
            )
            comparison.add_result(impl_name, benchmark_result)
        comparison.speedup_vs_baseline = entry.get('speedup_vs_baseline', {})
        results.append(comparison)

    return results
def pytest_configure(config):
    """Register the project's custom markers on the pytest config object."""
    marker_specs = (
        "numba: marks tests requiring numba (skip if not available)",
        "pandas: marks tests requiring pandas (skip if not available)",
        "talib: marks tests requiring TA-Lib (skip if not available)",
        "pandas_ta: marks tests requiring pandas-ta (skip if not available)",
        "slow: marks tests as slow (deselect with '-m \"not slow\"')",
        "benchmark: marks tests as benchmark tests",
    )
    # One "markers" ini line per spec, in declaration order.
    for spec in marker_specs:
        config.addinivalue_line("markers", spec)
@pytest.fixture
def sample_ohlcv_data() -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Build a deterministic 100-bar synthetic OHLCV series.

    The global NumPy RNG is re-seeded with ``RANDOM_SEED`` on every call,
    so each test receives identical arrays. Close is a random walk starting
    near 100; high and low are close plus/minus a non-negative offset, so
    ``low <= close <= high`` always holds. Open is close plus independent
    noise and is not constrained to lie inside ``[low, high]``.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]
        (open, high, low, close, volume) float64 arrays of length 100
    """
    bar_count = 100
    np.random.seed(RANDOM_SEED)

    # The draw order (close, high, low, open, volume) fixes the generated values.
    close_steps = np.random.randn(bar_count) * 0.5
    close = 100 + np.cumsum(close_steps)
    upper_offset = np.abs(np.random.randn(bar_count) * 0.5)
    lower_offset = np.abs(np.random.randn(bar_count) * 0.5)
    open_noise = np.random.randn(bar_count) * 0.3
    raw_volume = np.random.randint(1000, 10000, bar_count)

    return (
        close + open_noise,
        close + upper_offset,
        close - lower_offset,
        close,
        raw_volume.astype(np.float64),
    )
@pytest.fixture
def sample_ohlcv_dataframe_with_datetime(sample_ohlcv_data):
    """
    Wrap the sample OHLCV arrays in a DataFrame indexed by calendar day.

    The index starts at 2020-01-01 with daily frequency and has one entry
    per close price. The requesting test is skipped when pandas is not
    installed.

    Returns
    -------
    pd.DataFrame
        Columns open, high, low, close, volume on a DatetimeIndex
    """
    if not HAS_PANDAS:
        pytest.skip("pandas not installed")

    open_, high, low, close, volume = sample_ohlcv_data
    daily_index = pd.date_range('2020-01-01', periods=len(close), freq='D')
    frame = pd.DataFrame(
        dict(open=open_, high=high, low=low, close=close, volume=volume),
        index=daily_index,
    )
    return frame
@pytest.fixture
def edge_case_ohlcv_data():
    """
    Provide degenerate OHLCV inputs keyed by scenario name.

    Returns
    -------
    dict
        Maps scenario name to an (open, high, low, close, volume) tuple:
        - 'empty': five references to one zero-length array
        - 'single': one bar at price 100
        - 'constant': fifty identical bars at price 100
        For 'single' and 'constant', high is base + 1, low is base - 1,
        open and close are the same base array, and volume is base * 10.
    """
    def bars_around(base):
        # Tuple order is (open, high, low, close, volume).
        return (base, base + 1, base - 1, base, base * 10)

    no_bars = np.array([], dtype=np.float64)
    one_bar = np.array([100.0], dtype=np.float64)
    flat_bars = np.full(50, 100.0, dtype=np.float64)

    return {
        'empty': (no_bars,) * 5,
        'single': bars_around(one_bar),
        'constant': bars_around(flat_bars),
    }
+ + Returns + ------- + dict + Dictionary of data generators: + - 'random': random walk data + - 'trending': upward trending data with noise + - 'cyclical': sine wave with noise + - 'volatile': high volatility data + """ + def generate_random(size: int = 1000, seed: int = 42) -> np.ndarray: + """Generate random walk data.""" + np.random.seed(seed) + return 100 + np.cumsum(np.random.randn(size) * 0.5) + + def generate_trending(size: int = 1000, seed: int = 42) -> np.ndarray: + """Generate trending data with noise.""" + np.random.seed(seed) + trend = np.linspace(100, 120, size) + noise = np.random.randn(size) * 0.5 + return trend + noise + + def generate_cyclical(size: int = 1000, seed: int = 42) -> np.ndarray: + """Generate cyclical data with noise.""" + np.random.seed(seed) + x = np.linspace(0, 10 * np.pi, size) + cycle = 10 * np.sin(x) + 100 + noise = np.random.randn(size) * 0.3 + return cycle + noise + + def generate_volatile(size: int = 1000, seed: int = 42) -> np.ndarray: + """Generate high volatility data.""" + np.random.seed(seed) + return 100 + np.cumsum(np.random.randn(size) * 2.0) + + return { + 'random': generate_random, + 'trending': generate_trending, + 'cyclical': generate_cyclical, + 'volatile': generate_volatile, + } + + +# ===================================================================== +# Helper Functions +# ===================================================================== + +def create_ohlcv_from_close(close: np.ndarray, seed: int = 42) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ + Create OHLCV data from close prices. 
def arrays_almost_equal(a: np.ndarray, b: np.ndarray, decimal: int = 10) -> bool:
    """
    Check if two arrays are almost equal, ignoring NaN positions.

    A position is compared only when it is non-NaN in *both* inputs; a NaN
    in either array removes that position from the comparison. As a
    consequence, inputs that share no comparable position (including two
    empty arrays, or one all-NaN array against any array of the same shape)
    are reported as equal.

    Parameters
    ----------
    a, b : np.ndarray
        Arrays (or array-likes convertible to float) to compare
    decimal : int
        Number of decimal places for comparison, forwarded to
        ``np.testing.assert_array_almost_equal``

    Returns
    -------
    bool
        True if the shapes match and all jointly valid positions are
        almost equal, False otherwise
    """
    # Accept lists/tuples as well as arrays; boolean-mask indexing below
    # requires real ndarrays.
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)

    # Compare full shapes, not just len(): arrays that agree only on the
    # first dimension would otherwise broadcast or raise instead of
    # reporting a mismatch.
    if a.shape != b.shape:
        return False

    # Positions that hold a real number in both arrays.
    valid_mask = ~(np.isnan(a) | np.isnan(b))

    if not valid_mask.any():
        return True  # Nothing comparable is left, so there is nothing to contradict

    try:
        np.testing.assert_array_almost_equal(a[valid_mask], b[valid_mask], decimal=decimal)
    except AssertionError:
        return False
    return True
@pytest.fixture
def sample_series():
    """Return a 100-point random-walk price Series named 'close' (seed 42)."""
    np.random.seed(42)
    increments = np.random.randn(100) * 0.5
    prices = 100 + np.cumsum(increments)
    return pd.Series(prices, name='close')
class TestAccessorRegistration:
    """Check that the ``.ta`` accessor is attached to DataFrames."""

    def test_ta_accessor_exists_on_dataframe(self, sample_df):
        """A DataFrame exposes a ``ta`` attribute."""
        assert hasattr(sample_df, 'ta')

    def test_ta_accessor_callable(self, sample_df):
        """The sma, ema and rsi accessor attributes are callables."""
        for indicator_name in ('sma', 'ema', 'rsi'):
            assert callable(getattr(sample_df.ta, indicator_name, None))

    def test_accessor_with_different_dtypes(self):
        """The accessor accepts integer-typed OHLCV columns."""
        int_frame = pd.DataFrame({
            'open': list(range(100, 105)),
            'high': list(range(105, 110)),
            'low': list(range(99, 104)),
            'close': list(range(104, 109)),
            'volume': list(range(1000, 1500, 100)),
        })

        # Integer columns must not prevent the accessor from computing.
        assert hasattr(int_frame, 'ta')
        sma = int_frame.ta.sma(timeperiod=2)
        assert isinstance(sma, pd.Series)
result.name == 'SMA_10' + + def test_ema_on_series(self, sample_df): + """Test EMA on Series via DataFrame accessor.""" + result = sample_df.ta.ema(timeperiod=10) + assert isinstance(result, pd.Series) + assert len(result) == len(sample_df) + assert result.name == 'EMA_10' + + def test_rsi_on_series(self, sample_df): + """Test RSI calculation.""" + result = sample_df.ta.rsi(timeperiod=14) + assert isinstance(result, pd.Series) + + # RSI should be bounded between 0 and 100 + valid_values = result.dropna() + assert (valid_values >= 0).all() and (valid_values <= 100).all() + + +class TestOHLCIndicators: + """Test indicators requiring OHLC data.""" + + def test_atr_calculation(self, sample_df): + """Test ATR calculation using high, low, close.""" + result = sample_df.ta.atr(timeperiod=14) + assert isinstance(result, pd.Series) + assert len(result) == len(sample_df) + + # ATR should be positive + valid_values = result.dropna() + assert (valid_values >= 0).all() + + def test_adx_calculation(self, sample_df): + """Test ADX calculation.""" + result = sample_df.ta.adx(timeperiod=14) + assert isinstance(result, pd.Series) + assert len(result) == len(sample_df) + + def test_stoch_calculation(self, sample_df): + """Test Stochastic calculation.""" + result = sample_df.ta.stoch() + assert isinstance(result, pd.DataFrame) + assert len(result) == len(sample_df) + assert 'STOCH_SLOWK_5_3_3' in result.columns + assert 'STOCH_SLOWD_5_3_3' in result.columns + + def test_bbands_calculation(self, sample_df): + """Test Bollinger Bands calculation.""" + result = sample_df.ta.bbands(timeperiod=20) + assert isinstance(result, pd.DataFrame) + assert 'BBU_20_2.0' in result.columns + assert 'BBM_20' in result.columns + assert 'BBL_20_2.0' in result.columns + + # Upper should be above middle, middle above lower + valid_mask = ~result.isna().any(axis=1) + assert (result.loc[valid_mask, 'BBU_20_2.0'] >= result.loc[valid_mask, 'BBM_20']).all() + assert (result.loc[valid_mask, 'BBM_20'] >= 
class TestIndexPreservation:
    """Check that accessor output carries the index of its input frame."""

    def test_datetime_index_preserved(self, datetime_df):
        """The SMA Series keeps the input's DatetimeIndex."""
        sma = datetime_df.ta.sma(timeperiod=10)
        pd.testing.assert_index_equal(sma.index, datetime_df.index)

    def test_custom_integer_index_preserved(self, sample_df):
        """The SMA Series keeps a non-zero-based integer index."""
        sample_df.index = range(100, 200)
        sma = sample_df.ta.sma(timeperiod=10)
        pd.testing.assert_index_equal(sma.index, sample_df.index)

    def test_string_index_preserved(self, sample_df):
        """The SMA Series keeps a string-label index."""
        row_count = len(sample_df)
        sample_df.index = [f'row_{i}' for i in range(row_count)]
        sma = sample_df.ta.sma(timeperiod=10)
        pd.testing.assert_index_equal(sma.index, sample_df.index)

    def test_multi_output_index_preserved(self, datetime_df):
        """The Bollinger Bands frame keeps the input's index."""
        bands = datetime_df.ta.bbands(timeperiod=10)
        pd.testing.assert_index_equal(bands.index, datetime_df.index)
class TestAppendBehavior:
    """Check what the accessor returns and mutates for each ``append`` value."""

    def test_append_false_returns_series(self, sample_df):
        """With append=False a Series comes back and the frame gains no column."""
        sma = sample_df.ta.sma(timeperiod=10, append=False)
        assert isinstance(sma, pd.Series)
        assert 'SMA_10' not in sample_df.columns

    def test_append_true_modifies_dataframe(self, sample_df):
        """With append=True exactly one new column, 'SMA_10', is added."""
        column_count_before = len(sample_df.columns)
        sample_df.ta.sma(timeperiod=10, append=True)
        assert 'SMA_10' in sample_df.columns
        assert len(sample_df.columns) == column_count_before + 1

    def test_append_returns_none(self, sample_df):
        """With append=True the call itself returns None."""
        outcome = sample_df.ta.sma(timeperiod=10, append=True)
        assert outcome is None

    def test_multiple_appends(self, sample_df):
        """Successive appends of different indicators each add their column."""
        sample_df.ta.sma(timeperiod=10, append=True)
        sample_df.ta.ema(timeperiod=10, append=True)
        sample_df.ta.rsi(timeperiod=14, append=True)

        for expected_column in ('SMA_10', 'EMA_10', 'RSI_14'):
            assert expected_column in sample_df.columns

    def test_append_multi_output(self, sample_df):
        """Appending Bollinger Bands adds the upper, middle and lower columns."""
        sample_df.ta.bbands(timeperiod=20, append=True)
        for expected_column in ('BBU_20_2.0', 'BBM_20', 'BBL_20_2.0'):
            assert expected_column in sample_df.columns
class TestCandlestickPatterns:
    """Check the candlestick pattern methods of the accessor."""

    def test_cdldoji_pattern(self, sample_df):
        """CDLDOJI returns a named Series whose values are -100, 0 or 100."""
        signal = sample_df.ta.cdldoji()
        assert isinstance(signal, pd.Series)
        assert signal.name == 'CDLDOJI'

        # Only the three pattern codes may appear among the non-NaN values.
        allowed_codes = {-100, 0, 100}
        observed_codes = set(signal.dropna().unique().astype(int))
        assert observed_codes <= allowed_codes

    def test_cdlengulfing_pattern(self, sample_df):
        """CDLENGULFING returns a Series named 'CDLENGULFING'."""
        signal = sample_df.ta.cdlengulfing()
        assert isinstance(signal, pd.Series)
        assert signal.name == 'CDLENGULFING'
that results match direct numta calls.""" + + def test_sma_matches_numta(self, sample_df): + """Test SMA matches direct numta call.""" + accessor_result = sample_df.ta.sma(timeperiod=10) + direct_result = numta.SMA(sample_df['close'].values, timeperiod=10) + np.testing.assert_array_almost_equal(accessor_result.values, direct_result) + + def test_ema_matches_numta(self, sample_df): + """Test EMA matches direct numta call.""" + accessor_result = sample_df.ta.ema(timeperiod=10) + direct_result = numta.EMA(sample_df['close'].values, timeperiod=10) + np.testing.assert_array_almost_equal(accessor_result.values, direct_result) + + def test_rsi_matches_numta(self, sample_df): + """Test RSI matches direct numta call.""" + accessor_result = sample_df.ta.rsi(timeperiod=14) + direct_result = numta.RSI(sample_df['close'].values, timeperiod=14) + np.testing.assert_array_almost_equal(accessor_result.values, direct_result) + + def test_atr_matches_numta(self, sample_df): + """Test ATR matches direct numta call.""" + accessor_result = sample_df.ta.atr(timeperiod=14) + direct_result = numta.ATR( + sample_df['high'].values, + sample_df['low'].values, + sample_df['close'].values, + timeperiod=14 + ) + np.testing.assert_array_almost_equal(accessor_result.values, direct_result) + + def test_bbands_matches_numta(self, sample_df): + """Test BBANDS matches direct numta call.""" + accessor_result = sample_df.ta.bbands(timeperiod=20) + upper, middle, lower = numta.BBANDS(sample_df['close'].values, timeperiod=20) + + np.testing.assert_array_almost_equal(accessor_result['BBU_20_2.0'].values, upper) + np.testing.assert_array_almost_equal(accessor_result['BBM_20'].values, middle) + np.testing.assert_array_almost_equal(accessor_result['BBL_20_2.0'].values, lower) diff --git a/tests/test_streaming_comprehensive.py b/tests/test_streaming_comprehensive.py new file mode 100644 index 0000000..ba4c5ce --- /dev/null +++ b/tests/test_streaming_comprehensive.py @@ -0,0 +1,721 @@ +""" +Comprehensive tests 
@pytest.fixture
def sample_prices():
    """Return a 100-point random-walk price array seeded with RANDOM_SEED."""
    np.random.seed(RANDOM_SEED)
    increments = np.random.randn(100) * 0.5
    return 100 + np.cumsum(increments)
class TestStreamingEMABasic:
    """Basic behaviour checks for StreamingEMA."""

    def test_initialization(self):
        """A new indicator stores its period and is not yet ready."""
        ema = StreamingEMA(timeperiod=10)
        assert ema.timeperiod == 10
        assert not ema.ready

    def test_first_value_is_sma(self):
        """The first emitted value equals the simple mean of the seed window."""
        ema = StreamingEMA(timeperiod=3)

        for price in (1.0, 2.0, 3.0):
            latest = ema.update(price)

        # Seed value = (1 + 2 + 3) / 3 = 2
        np.testing.assert_almost_equal(latest, 2.0)

    def test_smoothing_factor(self):
        """The value after the seed applies alpha = 2 / (period + 1)."""
        ema = StreamingEMA(timeperiod=3)

        for price in (1.0, 2.0, 3.0):
            ema.update(price)  # third update seeds the EMA at 2.0

        # alpha = 2 / (3 + 1) = 0.5, so next = 4 * 0.5 + 2 * 0.5 = 3.0
        smoothed = ema.update(4.0)
        np.testing.assert_almost_equal(smoothed, 3.0)
class TestBatchEquivalence:
    """
    Test that streaming results match batch calculations.

    Each test runs the batch function over the whole price array, replays
    the same prices one at a time through the streaming indicator, and
    compares the two wherever the batch output is not NaN.
    """

    @staticmethod
    def _stream(indicator, prices, n_outputs=1):
        """
        Replay ``prices`` through ``indicator.update`` and collect the outputs.

        ``None`` results (indicator not ready yet) become NaN. Returns a
        float64 array of shape ``(len(prices),)`` when ``n_outputs`` is 1,
        otherwise ``(len(prices), n_outputs)`` built from ``result[0]`` ..
        ``result[n_outputs - 1]``.
        """
        rows = []
        for price in prices:
            result = indicator.update(price)
            if n_outputs == 1:
                rows.append(np.nan if result is None else result)
            elif result is None:
                rows.append((np.nan,) * n_outputs)
            else:
                rows.append(tuple(result[i] for i in range(n_outputs)))

        streamed = np.array(rows, dtype=np.float64)
        if n_outputs > 1:
            # Keep a 2-D shape even if no prices were supplied.
            streamed = streamed.reshape(len(rows), n_outputs)
        return streamed

    @staticmethod
    def _assert_matches_batch(streamed, batch):
        """Assert ``streamed`` equals ``batch`` to 10 decimals where batch is not NaN."""
        valid_mask = ~np.isnan(batch)
        np.testing.assert_array_almost_equal(
            streamed[valid_mask],
            batch[valid_mask],
            decimal=10
        )

    def test_sma_matches_batch(self, sample_prices):
        """Test StreamingSMA matches batch SMA."""
        timeperiod = 10

        batch_result = SMA(sample_prices, timeperiod=timeperiod)
        streaming_result = self._stream(StreamingSMA(timeperiod=timeperiod), sample_prices)

        self._assert_matches_batch(streaming_result, batch_result)

    def test_ema_matches_batch(self, sample_prices):
        """Test StreamingEMA matches batch EMA."""
        timeperiod = 10

        batch_result = EMA(sample_prices, timeperiod=timeperiod)
        streaming_result = self._stream(StreamingEMA(timeperiod=timeperiod), sample_prices)

        self._assert_matches_batch(streaming_result, batch_result)

    def test_wma_matches_batch(self, sample_prices):
        """Test StreamingWMA matches batch WMA."""
        timeperiod = 10

        batch_result = WMA(sample_prices, timeperiod=timeperiod)
        streaming_result = self._stream(StreamingWMA(timeperiod=timeperiod), sample_prices)

        self._assert_matches_batch(streaming_result, batch_result)

    def test_bbands_matches_batch(self, sample_prices):
        """Test StreamingBBANDS matches batch BBANDS on all three bands."""
        timeperiod = 10

        batch_upper, batch_middle, batch_lower = BBANDS(sample_prices, timeperiod=timeperiod)
        streamed = self._stream(
            StreamingBBANDS(timeperiod=timeperiod), sample_prices, n_outputs=3
        )

        # Streaming results are indexed (upper, middle, lower). Check every
        # band, not just the middle one, so a wrong deviation width is caught.
        for column, batch_band in enumerate((batch_upper, batch_middle, batch_lower)):
            self._assert_matches_batch(streamed[:, column], batch_band)
class TestEdgeCases:
    """Test edge cases for streaming indicators."""

    def test_nan_input(self):
        """Smoke test: a NaN in the stream must not raise."""
        sma = StreamingSMA(timeperiod=5)

        # Send some values then NaN
        for i in range(4):
            sma.update(100 + i)

        # The output for a NaN input is intentionally not asserted; this
        # test only requires that the call completes.
        sma.update(np.nan)

        # Continue with valid values; these calls must not raise either.
        for i in range(5):
            sma.update(100 + i)

    def test_inf_input(self):
        """Smoke test: an infinite input must not raise."""
        sma = StreamingSMA(timeperiod=5)

        for i in range(4):
            sma.update(100 + i)

        # The output for an inf input is intentionally not asserted; this
        # test only requires that the call completes.
        sma.update(np.inf)

    def test_very_large_values(self):
        """Test handling of very large values."""
        sma = StreamingSMA(timeperiod=5)

        for _ in range(10):
            result = sma.update(1e100)

        assert result is not None
        # Use a relative tolerance: an absolute 1.5e-7 tolerance at this
        # magnitude would demand bit-exact equality.
        np.testing.assert_allclose(result, 1e100, rtol=1e-9, atol=0.0)

    def test_very_small_values(self):
        """Test handling of very small values."""
        sma = StreamingSMA(timeperiod=5)

        for _ in range(10):
            result = sma.update(1e-100)

        assert result is not None
        # Use a relative tolerance: an absolute 1.5e-7 tolerance would accept
        # any result near zero, including a wrong 0.0.
        np.testing.assert_allclose(result, 1e-100, rtol=1e-9, atol=0.0)

    def test_zero_values(self):
        """Test handling of zero values."""
        sma = StreamingSMA(timeperiod=5)

        for _ in range(10):
            result = sma.update(0.0)

        assert result is not None
        np.testing.assert_almost_equal(result, 0.0)

    def test_negative_values(self):
        """Test handling of negative values."""
        sma = StreamingSMA(timeperiod=5)

        values = [-100.0, -101.0, -102.0, -103.0, -104.0, -105.0, -106.0, -107.0, -108.0, -109.0]
        for val in values:
            result = sma.update(val)

        assert result is not None
        # SMA of last 5 values: -105, -106, -107, -108, -109 = -107
        np.testing.assert_almost_equal(result, -107.0)
len(buf) == 3 # Still 3 + assert buf[0] == 2.0 # Oldest is now 2 + assert buf[2] == 4.0 # Newest is 4 + + def test_sum_tracking(self): + """Test sum is correctly maintained.""" + buf = CircularBuffer(3) + + buf.append(10.0) + assert buf.sum == 10.0 + + buf.append(20.0) + assert buf.sum == 30.0 + + buf.append(30.0) + assert buf.sum == 60.0 + + # Overflow removes 10, adds 40 + buf.append(40.0) + assert buf.sum == 90.0 # 20 + 30 + 40 + + def test_values_property(self): + """Test values property returns correct array.""" + buf = CircularBuffer(4) + + buf.append(1.0) + buf.append(2.0) + np.testing.assert_array_equal(buf.values, [1.0, 2.0]) + + buf.append(3.0) + buf.append(4.0) + np.testing.assert_array_equal(buf.values, [1.0, 2.0, 3.0, 4.0]) + + buf.append(5.0) + np.testing.assert_array_equal(buf.values, [2.0, 3.0, 4.0, 5.0]) + + def test_negative_indexing(self): + """Test negative indexing.""" + buf = CircularBuffer(3) + + buf.append(1.0) + buf.append(2.0) + buf.append(3.0) + + assert buf[-1] == 3.0 # Most recent + assert buf[-2] == 2.0 + assert buf[-3] == 1.0 # Oldest + + def test_clear(self): + """Test clearing buffer.""" + buf = CircularBuffer(5) + + for i in range(5): + buf.append(float(i)) + + assert buf.full + + buf.clear() + + assert len(buf) == 0 + assert not buf.full + assert buf.sum == 0.0 + + +class TestMomentumIndicators: + """Test momentum streaming indicators.""" + + def test_streaming_mom(self): + """Test StreamingMOM.""" + mom = StreamingMOM(timeperiod=5) + + # Linearly increasing values + for i in range(10): + result = mom.update(100 + i) + + assert mom.ready + # MOM = current - past = 109 - 104 = 5 + np.testing.assert_almost_equal(mom.value, 5.0) + + def test_streaming_roc(self): + """Test StreamingROC.""" + roc = StreamingROC(timeperiod=5) + + # Double the price + for i in range(6): + roc.update(100.0) + + result = roc.update(200.0) + + # ROC = ((200 - 100) / 100) * 100 = 100% + assert roc.ready + np.testing.assert_almost_equal(roc.value, 100.0) + + 
+class TestVolatilityIndicators:
+    """Test volatility streaming indicators."""
+
+    def test_streaming_trange(self, sample_ohlcv):
+        """Test StreamingTRANGE is ready and non-negative after streaming all bars."""
+        open_, high, low, close, _ = sample_ohlcv
+
+        tr = StreamingTRANGE()
+
+        for bar in zip(open_, high, low, close):
+            tr.update_bar(*bar)
+
+        # Should have values after first bar
+        assert tr.ready
+        assert tr.value is not None
+        assert tr.value >= 0  # True range is always non-negative
+
+    def test_streaming_atr(self, sample_ohlcv):
+        """Test StreamingATR is ready and positive after streaming all bars."""
+        open_, high, low, close, _ = sample_ohlcv
+
+        atr = StreamingATR(timeperiod=14)
+
+        # Feed the fixture's real open rather than a synthetic one, matching
+        # the TRANGE test above.
+        for o, h, l, c in zip(open_, high, low, close):
+            atr.update_bar(open_=o, high=h, low=l, close=c)
+
+        assert atr.ready
+        assert atr.value is not None
+        assert atr.value > 0  # ATR should be positive
+
+
+class TestVolumeIndicators:
+    """Test volume streaming indicators."""
+
+    def test_streaming_obv(self, sample_ohlcv):
+        """Test StreamingOBV produces a value after streaming all bars."""
+        open_, high, low, close, volume = sample_ohlcv
+
+        obv = StreamingOBV()
+
+        for o, h, l, c, v in zip(open_, high, low, close, volume):
+            obv.update_bar(open_=o, high=h, low=l, close=c, volume=v)
+
+        # OBV should have a value
+        assert obv.value is not None
+
+    def test_obv_up_days(self):
+        """Test OBV accumulates on up days."""
+        obv = StreamingOBV()
+
+        # First bar
+        obv.update_bar(100, 105, 95, 100, 1000)
+        assert obv.value == 0  # First bar is always 0
+
+        # Up day (close 100 -> 105): volume should be added
+        obv.update_bar(100, 110, 98, 105, 2000)
+        assert obv.value == 2000
+
+        # Another up day (close 105 -> 110)
+        obv.update_bar(105, 115, 103, 110, 1500)
+        assert obv.value == 3500
+
+    def test_obv_down_days(self):
+        """Test OBV decrements on down days."""
+        obv = StreamingOBV()
+
+        # First bar
+        obv.update_bar(100, 105, 95, 100, 1000)
+
+        # Down day (close 100 -> 95): volume should be subtracted
+        obv.update_bar(100, 102, 92, 95, 2000)
+        assert obv.value == -2000
+
+
+class TestLongRunning:
+    """Test streaming indicators over long periods."""
+
+    def test_sma_long_running(self, large_sample_prices):
+        """Test StreamingSMA against batch SMA on the large price fixture."""
+        timeperiod = 50
+
+        batch_result = SMA(large_sample_prices, timeperiod=timeperiod)
+
+        sma = StreamingSMA(timeperiod=timeperiod)
+        # update() may return None; store NaN instead so both series align.
+        streaming_result = np.array([
+            np.nan if value is None else value
+            for value in map(sma.update, large_sample_prices)
+        ])
+
+        valid_mask = ~np.isnan(batch_result)
+        np.testing.assert_array_almost_equal(
+            streaming_result[valid_mask],
+            batch_result[valid_mask],
+            decimal=8
+        )
+
+    def test_ema_long_running(self, large_sample_prices):
+        """Test StreamingEMA against batch EMA on the large price fixture."""
+        timeperiod = 50
+
+        batch_result = EMA(large_sample_prices, timeperiod=timeperiod)
+
+        ema = StreamingEMA(timeperiod=timeperiod)
+        # update() may return None; store NaN instead so both series align.
+        streaming_result = np.array([
+            np.nan if value is None else value
+            for value in map(ema.update, large_sample_prices)
+        ])
+
+        valid_mask = ~np.isnan(batch_result)
+        np.testing.assert_array_almost_equal(
+            streaming_result[valid_mask],
+            batch_result[valid_mask],
+            decimal=8
+        )
diff --git a/tests/unit/test_all_functions.py b/tests/unit/test_all_functions.py
new file mode 100644
index 0000000..b418c68
--- /dev/null
+++ b/tests/unit/test_all_functions.py
@@ -0,0 +1,506 @@
+"""
+Comprehensive tests for all numta functions.
+
+This module auto-discovers all numta functions and tests them for:
+- Basic functionality (no crash on valid input)
+- Edge cases (empty, NaN, single value, constant data)
+- Consistent output shapes
+"""
+
+import pytest
+import numpy as np
+import numta
+
+
+# =====================================================================
+# Function Discovery and Signatures
+# =====================================================================
+
+def get_all_numta_functions():
+    """Return the names of public callables in numta that are not classes.
+
+    Names starting with an underscore are skipped; the order follows
+    ``dir(numta)``.
+    """
+    public_members = (
+        (name, getattr(numta, name))
+        for name in dir(numta)
+        if not name.startswith('_')
+    )
+    return [
+        name
+        for name, obj in public_members
+        if callable(obj) and not isinstance(obj, type)
+    ]
+
+
+# Function signatures define input requirements for each function
+# Format: {'function_name': {'inputs': ['close'] or ['high', 'low', 'close'], 'params': {param: default}}}
+FUNCTION_SIGNATURES = {
+    # Overlap Studies
+    'SMA': {'inputs': ['close'], 'params': {'timeperiod': 10}},
+    'EMA': {'inputs': ['close'], 'params': {'timeperiod': 10}},
+    'DEMA': {'inputs': ['close'], 'params': {'timeperiod': 10}},
+    'TEMA': {'inputs': ['close'], 'params': {'timeperiod': 10}},
+    'WMA': {'inputs': ['close'], 'params': {'timeperiod': 10}},
+    'TRIMA': {'inputs': ['close'], 'params': {'timeperiod': 10}},
+    'KAMA': {'inputs': ['close'], 'params': {'timeperiod': 10}},
+    'MA': {'inputs': ['close'], 'params': {'timeperiod': 10, 'matype': 0}},
+    'T3': {'inputs': ['close'], 'params': {'timeperiod': 5, 'vfactor': 0.7}},
+    'BBANDS': {'inputs': ['close'], 'params': {'timeperiod': 5, 'nbdevup': 2.0, 'nbdevdn': 2.0, 'matype': 0}},
+    'MAMA': {'inputs': ['close'], 'params': {'fastlimit': 0.5, 'slowlimit': 0.05}},
+    'SAR': {'inputs': ['high', 'low'], 'params': {'acceleration': 0.02, 'maximum': 0.2}},
+    'SAREXT': {'inputs': ['high', 'low'], 'params': {
+        'startvalue': 0.0, 'offsetonreverse': 0.0,
+        'accelerationinit_long': 0.02, 'accelerationlong': 0.02,
+
'accelerationmax_long': 0.2, 'accelerationinit_short': 0.02, + 'accelerationshort': 0.02, 'accelerationmax_short': 0.2 + }}, + + # Momentum Indicators + 'RSI': {'inputs': ['close'], 'params': {'timeperiod': 14}}, + 'MOM': {'inputs': ['close'], 'params': {'timeperiod': 10}}, + 'ROC': {'inputs': ['close'], 'params': {'timeperiod': 10}}, + 'ROCP': {'inputs': ['close'], 'params': {'timeperiod': 10}}, + 'ROCR': {'inputs': ['close'], 'params': {'timeperiod': 10}}, + 'ROCR100': {'inputs': ['close'], 'params': {'timeperiod': 10}}, + 'CMO': {'inputs': ['close'], 'params': {'timeperiod': 14}}, + 'TRIX': {'inputs': ['close'], 'params': {'timeperiod': 10}}, + 'PPO': {'inputs': ['close'], 'params': {'fastperiod': 12, 'slowperiod': 26, 'matype': 0}}, + 'APO': {'inputs': ['close'], 'params': {'fastperiod': 12, 'slowperiod': 26, 'matype': 0}}, + 'MACD': {'inputs': ['close'], 'params': {'fastperiod': 12, 'slowperiod': 26, 'signalperiod': 9}}, + 'MACDEXT': {'inputs': ['close'], 'params': { + 'fastperiod': 12, 'fastmatype': 0, + 'slowperiod': 26, 'slowmatype': 0, + 'signalperiod': 9, 'signalmatype': 0 + }}, + 'MACDFIX': {'inputs': ['close'], 'params': {'signalperiod': 9}}, + 'ADX': {'inputs': ['high', 'low', 'close'], 'params': {'timeperiod': 14}}, + 'ADXR': {'inputs': ['high', 'low', 'close'], 'params': {'timeperiod': 14}}, + 'DX': {'inputs': ['high', 'low', 'close'], 'params': {'timeperiod': 14}}, + 'PLUS_DI': {'inputs': ['high', 'low', 'close'], 'params': {'timeperiod': 14}}, + 'PLUS_DM': {'inputs': ['high', 'low'], 'params': {'timeperiod': 14}}, + 'MINUS_DI': {'inputs': ['high', 'low', 'close'], 'params': {'timeperiod': 14}}, + 'MINUS_DM': {'inputs': ['high', 'low'], 'params': {'timeperiod': 14}}, + 'ATR': {'inputs': ['high', 'low', 'close'], 'params': {'timeperiod': 14}}, + 'AROON': {'inputs': ['high', 'low'], 'params': {'timeperiod': 14}}, + 'AROONOSC': {'inputs': ['high', 'low'], 'params': {'timeperiod': 14}}, + 'BOP': {'inputs': ['open', 'high', 'low', 'close'], 'params': 
{}}, + 'CCI': {'inputs': ['high', 'low', 'close'], 'params': {'timeperiod': 14}}, + 'MFI': {'inputs': ['high', 'low', 'close', 'volume'], 'params': {'timeperiod': 14}}, + 'STOCH': {'inputs': ['high', 'low', 'close'], 'params': { + 'fastk_period': 5, 'slowk_period': 3, 'slowk_matype': 0, + 'slowd_period': 3, 'slowd_matype': 0 + }}, + 'STOCHF': {'inputs': ['high', 'low', 'close'], 'params': { + 'fastk_period': 5, 'fastd_period': 3, 'fastd_matype': 0 + }}, + 'STOCHRSI': {'inputs': ['close'], 'params': {'timeperiod': 14}}, + 'ULTOSC': {'inputs': ['high', 'low', 'close'], 'params': { + 'timeperiod1': 7, 'timeperiod2': 14, 'timeperiod3': 28 + }}, + 'WILLR': {'inputs': ['high', 'low', 'close'], 'params': {'timeperiod': 14}}, + + # Volume Indicators + 'OBV': {'inputs': ['close', 'volume'], 'params': {}}, + 'AD': {'inputs': ['high', 'low', 'close', 'volume'], 'params': {}}, + 'ADOSC': {'inputs': ['high', 'low', 'close', 'volume'], 'params': {'fastperiod': 3, 'slowperiod': 10}}, + + # Volatility Indicators + 'NATR': {'inputs': ['high', 'low', 'close'], 'params': {'timeperiod': 14}}, + 'TRANGE': {'inputs': ['high', 'low', 'close'], 'params': {}}, + + # Cycle Indicators + 'HT_DCPERIOD': {'inputs': ['close'], 'params': {}}, + 'HT_DCPHASE': {'inputs': ['close'], 'params': {}}, + 'HT_PHASOR': {'inputs': ['close'], 'params': {}}, + 'HT_SINE': {'inputs': ['close'], 'params': {}}, + 'HT_TRENDLINE': {'inputs': ['close'], 'params': {}}, + 'HT_TRENDMODE': {'inputs': ['close'], 'params': {}}, + + # Statistic Functions + 'STDDEV': {'inputs': ['close'], 'params': {'timeperiod': 5, 'nbdev': 1.0}}, + 'VAR': {'inputs': ['close'], 'params': {'timeperiod': 5, 'nbdev': 1.0}}, + 'TSF': {'inputs': ['close'], 'params': {'timeperiod': 14}}, + 'LINEARREG': {'inputs': ['close'], 'params': {'timeperiod': 14}}, + 'LINEARREG_ANGLE': {'inputs': ['close'], 'params': {'timeperiod': 14}}, + 'LINEARREG_INTERCEPT': {'inputs': ['close'], 'params': {'timeperiod': 14}}, + 'LINEARREG_SLOPE': {'inputs': ['close'], 
'params': {'timeperiod': 14}}, + 'BETA': {'inputs': ['high', 'low'], 'params': {'timeperiod': 5}}, + 'CORREL': {'inputs': ['high', 'low'], 'params': {'timeperiod': 30}}, + + # Math Operators + 'MAX': {'inputs': ['close'], 'params': {'timeperiod': 30}}, + 'MAXINDEX': {'inputs': ['close'], 'params': {'timeperiod': 30}}, + 'MIN': {'inputs': ['close'], 'params': {'timeperiod': 30}}, + 'MININDEX': {'inputs': ['close'], 'params': {'timeperiod': 30}}, + 'MINMAX': {'inputs': ['close'], 'params': {'timeperiod': 30}}, + 'MINMAXINDEX': {'inputs': ['close'], 'params': {'timeperiod': 30}}, + 'SUM': {'inputs': ['close'], 'params': {'timeperiod': 30}}, + + # Price Transforms + 'MEDPRICE': {'inputs': ['high', 'low'], 'params': {}}, + 'MIDPOINT': {'inputs': ['close'], 'params': {'timeperiod': 14}}, + 'MIDPRICE': {'inputs': ['high', 'low'], 'params': {'timeperiod': 14}}, + 'TYPPRICE': {'inputs': ['high', 'low', 'close'], 'params': {}}, + 'WCLPRICE': {'inputs': ['high', 'low', 'close'], 'params': {}}, + + # Candlestick Patterns (all take open, high, low, close) + 'CDLDOJI': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDL2CROWS': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDL3BLACKCROWS': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDL3INSIDE': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDL3OUTSIDE': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDL3STARSINSOUTH': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDL3WHITESOLDIERS': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLABANDONEDBABY': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLADVANCEBLOCK': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLBELTHOLD': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLBREAKAWAY': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLCLOSINGMARUBOZU': {'inputs': ['open', 'high', 'low', 'close'], 
'params': {}}, + 'CDLCONCEALBABYSWALL': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLCOUNTERATTACK': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLDARKCLOUDCOVER': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLDOJISTAR': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLDRAGONFLYDOJI': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLENGULFING': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLEVENINGDOJISTAR': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLEVENINGSTAR': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLGAPSIDESIDEWHITE': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLGRAVESTONEDOJI': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLHAMMER': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLHANGINGMAN': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLHARAMI': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLHARAMICROSS': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLHIGHWAVE': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLHIKKAKE': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLHIKKAKEMOD': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLHOMINGPIGEON': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLIDENTICAL3CROWS': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLINNECK': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLINVERTEDHAMMER': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLKICKING': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLKICKINGBYLENGTH': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLLADDERBOTTOM': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}}, + 'CDLLONGLEGGEDDOJI': {'inputs': ['open', 'high', 'low', 
'close'], 'params': {}},
+    'CDLLONGLINE': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLMARUBOZU': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLMATCHINGLOW': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLMATHOLD': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLMORNINGDOJISTAR': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLMORNINGSTAR': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLONNECK': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLPIERCING': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLRICKSHAWMAN': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLRISEFALL3METHODS': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLSEPARATINGLINES': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLSHOOTINGSTAR': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLSHORTLINE': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLSPINNINGTOP': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLSTALLEDPATTERN': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLSTICKSANDWICH': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLTAKURI': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLTASUKIGAP': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLTHRUSTING': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLTRISTAR': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLUNIQUE3RIVER': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLUPSIDEGAP2CROWS': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+    'CDLXSIDEGAP3METHODS': {'inputs': ['open', 'high', 'low', 'close'], 'params': {}},
+}
+
+
+def get_function_args(func_name: str, sample_ohlcv_data):
+    """
+    Get the appropriate arguments for a function based on its signature.
+
+    Parameters
+    ----------
+    func_name : str
+        Name of the function
+    sample_ohlcv_data : tuple
+        (open, high, low, close, volume) arrays
+
+    Returns
+    -------
+    tuple
+        (args, kwargs) to call the function: ``args`` is a list of input
+        arrays in the order given by the signature's 'inputs', ``kwargs`` is
+        a fresh copy of the signature's 'params'. Returns (None, None) when
+        ``func_name`` has no entry in FUNCTION_SIGNATURES.
+    """
+    open_, high, low, close, volume = sample_ohlcv_data
+
+    sig = FUNCTION_SIGNATURES.get(func_name)
+    if sig is None:
+        return None, None
+
+    data_map = {
+        'open': open_,
+        'high': high,
+        'low': low,
+        'close': close,
+        'volume': volume,
+    }
+
+    args = [data_map[inp] for inp in sig['inputs']]
+    # Hand out a copy so no caller can mutate the shared signature table.
+    return args, dict(sig['params'])
+
+
+# Get list of testable functions (those with known signatures)
+TESTABLE_FUNCTIONS = list(FUNCTION_SIGNATURES)
+
+
+# =====================================================================
+# Test Classes
+# =====================================================================
+
+class TestAllFunctionsNoCrash:
+    """Test that all functions execute without crashing on valid input."""
+
+    @pytest.fixture(autouse=True)
+    def setup_data(self, sample_ohlcv_data):
+        """Setup test data."""
+        self.sample_ohlcv_data = sample_ohlcv_data
+
+    def _call(self, func_name):
+        """Call numta.<func_name> with its table-defined arguments.
+
+        Returns (args, result); skips the test when no signature is defined.
+        """
+        func = getattr(numta, func_name)
+        args, kwargs = get_function_args(func_name, self.sample_ohlcv_data)
+
+        if args is None:
+            pytest.skip(f"No signature defined for {func_name}")
+
+        return args, func(*args, **kwargs)
+
+    @pytest.mark.parametrize("func_name", TESTABLE_FUNCTIONS)
+    def test_function_no_crash(self, func_name):
+        """Test that function executes without error on valid input."""
+        # Should not raise any exceptions
+        _, result = self._call(func_name)
+
+        # Basic validation
+        assert result is not None
+
+    @pytest.mark.parametrize("func_name", TESTABLE_FUNCTIONS)
+    def test_function_returns_array(self, func_name):
+        """Test that function returns numpy array(s)."""
+        _, result = self._call(func_name)
+
+        # Result should be array or tuple of arrays
+        if isinstance(result, tuple):
+            for r in result:
+                assert isinstance(r, np.ndarray), f"{func_name} returned non-array in tuple"
+        else:
+            assert isinstance(result, np.ndarray), f"{func_name} did not return array"
+
+    @pytest.mark.parametrize("func_name", TESTABLE_FUNCTIONS)
+    def test_function_output_length(self, func_name):
+        """Test that function returns output with correct length."""
+        args, result = self._call(func_name)
+        expected_len = len(args[0])  # First input determines output length
+
+        outputs = result if isinstance(result, tuple) else (result,)
+        for r in outputs:
+            assert len(r) == expected_len, f"{func_name} output length mismatch"
+
+
+class TestEdgeCases:
+    """Test edge cases for indicator functions."""
+
+    @pytest.fixture(autouse=True)
+    def setup_data(self):
+        """Setup edge case test data."""
+        self.empty = np.array([], dtype=np.float64)
+        # Use larger arrays to avoid Numba JIT issues with very small arrays
+        self.small = np.array([100.0, 101.0, 102.0, 103.0, 104.0], dtype=np.float64)
+        self.constant = np.full(50, 100.0, dtype=np.float64)
+        self.with_nan = np.array([100.0, np.nan, 101.0, 102.0, np.nan, 103.0, 104.0, 105.0, 106.0, 107.0,
+                                  108.0, 109.0, 110.0, 111.0, 112.0], dtype=np.float64)
+
+    # Test a representative subset of simple functions for edge cases
+    # Avoid complex JIT-compiled functions that may crash on edge case data
+    EDGE_CASE_FUNCTIONS = ['SMA', 'EMA', 'BBANDS']
+
+    @pytest.mark.parametrize("func_name", EDGE_CASE_FUNCTIONS)
+    def test_empty_input(self, func_name):
+        """Test function behavior with empty input."""
+        func = getattr(numta, func_name)
+        sig = FUNCTION_SIGNATURES[func_name]
+
+        # Create empty arrays matching input signature
+        args = [self.empty for _ in sig['inputs']]
+
+        # Only the call itself may raise; the assertions stay outside the try.
+        try:
+            result = func(*args, **sig['params'])
+        except (ValueError, IndexError):
+            # It's acceptable to raise an error for empty input
+            return
+
+        # If it doesn't raise, result should be empty or tuple of empty
+        outputs = result if isinstance(result, tuple) else (result,)
+        for r in outputs:
+            assert len(r) == 0
+
+    @pytest.mark.parametrize("func_name", EDGE_CASE_FUNCTIONS)
+    def test_small_input(self, func_name):
+        """Test function behavior with small input (5 values)."""
+        func = getattr(numta, func_name)
+        sig = FUNCTION_SIGNATURES[func_name]
+
+        args = [self.small for _ in sig['inputs']]
+
+        try:
+            result = func(*args, **sig['params'])
+        except ValueError:
+            # It's acceptable to raise an error for insufficient data
+            return
+
+        # Result should have length 5
+        outputs = result if isinstance(result, tuple) else (result,)
+        for r in outputs:
+            assert len(r) == 5
+
+    @pytest.mark.parametrize("func_name", EDGE_CASE_FUNCTIONS)
+    def test_constant_input(self, func_name):
+        """Test that constant input never produces +inf or -inf."""
+        func = getattr(numta, func_name)
+        sig = FUNCTION_SIGNATURES[func_name]
+
+        args = [self.constant for _ in sig['inputs']]
+
+        result = func(*args, **sig['params'])
+
+        # NaN is allowed; only infinities are rejected.
+        outputs = result if isinstance(result, tuple) else (result,)
+        for r in outputs:
+            assert not np.isinf(r).any(), f"{func_name} produced unexpected inf values"
+
+    @pytest.mark.parametrize("func_name", ['SMA', 'EMA'])
+    def test_nan_input_handling(self, func_name):
+        """Test function behavior with NaN values in input."""
+        func = getattr(numta, func_name)
+        sig = FUNCTION_SIGNATURES[func_name]
+
+        args = [self.with_nan for _ in sig['inputs']]
+
+        # Should not crash
+        result = func(*args, **sig['params'])
+
+        # Result should be array with correct length
+        assert isinstance(result, np.ndarray)
+        assert len(result) == len(self.with_nan)
+
+
+class TestFunctionOutputTypes:
+    """Test that functions return correct output types."""
+
+    @pytest.fixture(autouse=True)
+    def setup_data(self, sample_ohlcv_data):
+        """Setup test data."""
+        self.sample_ohlcv_data = sample_ohlcv_data
+
+    # Functions that return tuples (multiple outputs)
+    MULTI_OUTPUT_FUNCTIONS = [
+        ('BBANDS', 3),  # upper, middle, lower
+        ('MACD', 3),  # macd, signal, hist
+        ('MACDEXT', 3),
+        ('MACDFIX', 3),
+        ('STOCH', 2),  # slowk, slowd
+        ('STOCHF', 2),  # fastk, fastd
+        ('STOCHRSI', 2),
+        ('AROON', 2),  # aroondown, aroonup
+        ('MAMA', 2),  # mama, fama
+        ('HT_PHASOR', 2),  # inphase, quadrature
+        ('HT_SINE', 2),  # sine, leadsine
+        ('MINMAX', 2),  # min, max
+        ('MINMAXINDEX', 2),  # minidx, maxidx
+    ]
+
+    @pytest.mark.parametrize("func_name,expected_outputs", MULTI_OUTPUT_FUNCTIONS)
+    def test_multi_output_function(self, func_name, expected_outputs):
+        """Test that multi-output functions return correct number of outputs."""
+        func = getattr(numta, func_name)
+        args, kwargs = get_function_args(func_name, self.sample_ohlcv_data)
+
+        if args is None:
+            pytest.skip(f"No signature defined for {func_name}")
+
+        result = func(*args, **kwargs)
+
+        assert isinstance(result, tuple), f"{func_name} should return tuple"
+        assert len(result) == expected_outputs, \
+            f"{func_name} should return {expected_outputs} outputs, got {len(result)}"
+
+    # Functions that return single array
+    SINGLE_OUTPUT_FUNCTIONS = ['SMA', 'EMA', 'RSI', 'ATR', 'ADX', 'OBV']
+
+    @pytest.mark.parametrize("func_name", SINGLE_OUTPUT_FUNCTIONS)
+    def test_single_output_function(self, func_name):
+        """Test that single-output functions return array, not tuple."""
+        func = getattr(numta, func_name)
+        args, kwargs = get_function_args(func_name, self.sample_ohlcv_data)
+
+        if args is None:
+            pytest.skip(f"No signature defined for {func_name}")
+
+        result = func(*args, **kwargs)
+
+        assert isinstance(result, np.ndarray), f"{func_name} should return numpy array"
+        assert not isinstance(result, tuple), f"{func_name} should not return tuple"
+
+
+class TestCandlestickPatterns:
+    """Test candlestick pattern recognition functions."""
+
+    @pytest.fixture(autouse=True)
+    def setup_data(self, sample_ohlcv_data):
+        """Setup test data."""
+        self.sample_ohlcv_data = sample_ohlcv_data
+
+    # Get all candlestick pattern functions
+    CDL_FUNCTIONS = [f for f in TESTABLE_FUNCTIONS if f.startswith('CDL')]
+
+    @pytest.mark.parametrize("func_name", CDL_FUNCTIONS)
+    def test_pattern_returns_integer_array(self, func_name):
+        """Test that pattern functions return only -100, 0 or 100 (NaN ignored)."""
+        func = getattr(numta, func_name)
+        args, kwargs = get_function_args(func_name, self.sample_ohlcv_data)
+
+        if args is None:
+            pytest.skip(f"No signature defined for {func_name}")
+
+        result = func(*args, **kwargs)
+
+        assert isinstance(result, np.ndarray)
+        # Pattern results should be integers (typically -100, 0, or 100)
+        # Allow for float representation of integers, but compare the values
+        # untruncated: casting to int first would let e.g. 100.7 pass as 100.
+        valid_values = {-100, 0, 100}
+        unique_values = set(np.unique(result[~np.isnan(result)]).tolist())
+        assert unique_values.issubset(valid_values), \
+            f"{func_name} returned unexpected values: {unique_values}"
+
+    @pytest.mark.parametrize("func_name", CDL_FUNCTIONS)
+    def test_pattern_output_length(self, func_name):
+        """Test that pattern functions return correct output length."""
+        func = getattr(numta, func_name)
+        open_, high, low, close, _ = self.sample_ohlcv_data
+
+        result = func(open_, high, low, close)
+
+        assert len(result) == len(close), \
+            f"{func_name} output length mismatch"