diff --git a/tests/integration-tests/conftest.py b/tests/integration-tests/conftest.py
index fd9b4ad20a..29236902fc 100644
--- a/tests/integration-tests/conftest.py
+++ b/tests/integration-tests/conftest.py
@@ -82,10 +82,10 @@
     generate_stack_name,
     get_architecture_supported_by_instance_type,
     get_arn_partition,
+    get_flexible_instance_types,
     get_instance_info,
     get_metadata,
     get_network_interfaces_count,
-    get_similar_instance_types,
     get_vpc_snakecase_value,
     random_alphanumeric,
     to_pascal_case,
@@ -698,7 +698,7 @@ def inject_placement_group_settings(vpc_stack, instance, region, kwargs):
 
 
 def inject_flexible_instance_types_settings(instance, region, kwargs):
-    kwargs["flexible_instance_types"] = list({instance, *get_similar_instance_types(instance, region, 5)})
+    kwargs["flexible_instance_types"] = get_flexible_instance_types(instance, region)
 
 
 def inject_additional_image_configs_settings(image_config, request):
diff --git a/tests/integration-tests/framework/file_cache.py b/tests/integration-tests/framework/file_cache.py
new file mode 100644
index 0000000000..0ed17cf00d
--- /dev/null
+++ b/tests/integration-tests/framework/file_cache.py
@@ -0,0 +1,118 @@
+# Copyright 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file.
+# This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied.
+# See the License for the specific language governing permissions and limitations under the License.
+"""Cross-process file-backed memoization decorator.
+
+Drop-in replacement for ``functools.cache`` that persists results to a file
+guarded by a :class:`filelock.FileLock`, so that callers running in separate
+processes (e.g. pytest-xdist workers) share cached values instead of each
+recomputing the same result.
+
+NOTE(security): entries are stored with :mod:`pickle`, and unpickling can run
+arbitrary code. The default location under :func:`tempfile.gettempdir` may be
+writable by other local users; only use a cache path that you trust.
+"""
+
+import functools
+import os
+import pickle
+import tempfile
+
+from filelock import FileLock
+
+
+def file_cache(filename: str):
+    """Decorator providing cross-process memoization backed by a file.
+
+    Works like ``functools.cache`` but persists results across processes via a
+    pickle file. All positional and keyword arguments must be hashable and
+    picklable, and return values must be picklable. Entries never expire; call
+    ``cache_clear()`` on the decorated function to drop them.
+
+    Parameters
+    ----------
+    filename:
+        Path to the cache file. If a relative path is given, it is resolved
+        under :func:`tempfile.gettempdir` so the cache persists on a single
+        machine across pytest sessions and is shared by all workers.
+    """
+    cache_path = filename if os.path.isabs(filename) else os.path.join(tempfile.gettempdir(), filename)
+    lock_path = cache_path + ".lock"
+
+    def decorator(func):
+        in_memory = {}
+        # Namespace keys by function so that several functions decorated with
+        # the same ``filename`` cannot read each other's entries.
+        func_id = (getattr(func, "__module__", None), getattr(func, "__qualname__", None))
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            key = (func_id, args, tuple(sorted(kwargs.items())))
+            if key in in_memory:
+                return in_memory[key]
+
+            # ``func`` runs while the lock is held, so concurrent workers wait
+            # for the first result instead of computing it again.
+            with FileLock(lock_path):
+                disk_cache = _load(cache_path)
+                if key in disk_cache:
+                    in_memory[key] = disk_cache[key]
+                    return disk_cache[key]
+
+                result = func(*args, **kwargs)
+                disk_cache[key] = result
+                _dump(cache_path, disk_cache)
+                in_memory[key] = result
+                return result
+
+        def cache_clear():
+            """Drop this process's in-memory entries and delete the cache file."""
+            in_memory.clear()
+            with FileLock(lock_path):
+                if os.path.exists(cache_path):
+                    os.remove(cache_path)
+
+        # ``functools.wraps`` already sets ``wrapper.__wrapped__``.
+        wrapper.cache_clear = cache_clear
+        return wrapper
+
+    return decorator
+
+
+def _load(path):
+    """Return the cached dict stored at ``path``, or ``{}`` if missing or unreadable."""
+    try:
+        with open(path, "rb") as f:
+            data = pickle.load(f)
+    except FileNotFoundError:
+        return {}
+    except (EOFError, pickle.UnpicklingError, AttributeError, ImportError, IndexError):
+        # Corrupted or incompatible cache file — start fresh.
+        return {}
+    # The wrapper only ever stores a dict; treat anything else as corrupt.
+    return data if isinstance(data, dict) else {}
+
+
+def _dump(path, data):
+    """Atomically write ``data`` to ``path`` as a pickle."""
+    # Atomic write: dump to a temp file in the same directory, then rename.
+    directory = os.path.dirname(path) or "."
+    os.makedirs(directory, exist_ok=True)
+    fd, tmp_path = tempfile.mkstemp(prefix=".file_cache_", dir=directory)
+    try:
+        with os.fdopen(fd, "wb") as f:
+            pickle.dump(data, f)
+        os.replace(tmp_path, path)
+    except BaseException:
+        # Also clean up on KeyboardInterrupt/SystemExit so no temp file is left behind.
+        if os.path.exists(tmp_path):
+            os.remove(tmp_path)
+        raise
diff --git a/tests/integration-tests/tests/common/capacity_helpers.py b/tests/integration-tests/tests/common/capacity_helpers.py
index 1fb351986a..e2e475e5f8 100644
--- a/tests/integration-tests/tests/common/capacity_helpers.py
+++ b/tests/integration-tests/tests/common/capacity_helpers.py
@@ -35,7 +35,7 @@ def resolve_instance_with_capacity(region, az_id, instance_type, os, minutes=50,
     if instance_type not in DEFAULT_INSTANCE_TYPES:
         return instance_type
 
-    candidates = [instance_type] + get_similar_instance_types(instance_type)
+    candidates = [instance_type] + get_similar_instance_types(instance_type, region)
 
     ec2_client = boto3.client("ec2", region_name=region)
     instance_platform = "Red Hat Enterprise Linux" if "rhel" in os else "Linux/UNIX"
diff --git a/tests/integration-tests/utils.py b/tests/integration-tests/utils.py
index 18d14eec8d..361400591d 100644
--- a/tests/integration-tests/utils.py
+++ b/tests/integration-tests/utils.py
@@ -19,12 +19,12 @@
 import string
 import subprocess
 from datetime import datetime, timedelta
-from functools import cache
 from hashlib import sha1
 
 import boto3
 import requests
 from assertpy import assert_that
+from framework.file_cache import file_cache
 from jinja2 import FileSystemLoader
 from jinja2.sandbox import SandboxedEnvironment
 from retrying import retry
@@ -1026,9 +1026,10 @@ def _get_gpu_spec(instance_type_data):
     return frozenset((gpu.get("Manufacturer", ""), gpu.get("Count", 0)) for gpu in gpu_info.get("Gpus", []))
 
 
+@file_cache("pcluster_similar_instance_types.cache")
 def get_similar_instance_types(instance_type: str, region: str = None, max_items: int = None):
+    """Return instance types compatible with ``instance_type`` in ``region``."""
     ec2 = boto3.client("ec2", region_name=region)
-
     # First, get the target instance details to use as filter criteria
     target_response = ec2.describe_instance_types(InstanceTypes=[instance_type])
 
@@ -1046,6 +1047,7 @@ def get_similar_instance_types(instance_type: str, region: str = None, max_items
     # Now query for similar instances using filters
     paginator = ec2.get_paginator("describe_instance_types")
     similar_instances = []
+    reached_max_items = False
 
     for page in paginator.paginate(
         Filters=[
@@ -1069,17 +1071,26 @@ def get_similar_instance_types(instance_type: str, region: str = None, max_items
             ):
                 similar_instances.append(instance["InstanceType"])
                 if max_items and len(similar_instances) >= max_items:
-                    return similar_instances
+                    reached_max_items = True
+                    break
+        if reached_max_items:
+            break
+
+    logging.info(f"Retrieved instance types equivalent to {instance_type} in {region}: {similar_instances}")
 
     return similar_instances
 
 
-@cache
+def get_flexible_instance_types(instance, region):
+    """Return ``instance`` plus up to 5 similar instance types available in ``region``."""
+    return list({instance, *get_similar_instance_types(instance, region)[:5]})
+
+
 def get_flexible_gpu_instance_types(instance, region):
     """Return a list of NVIDIA GPU instance types compatible with ``instance``'s architecture."""
     architecture = get_architecture_supported_by_instance_type(instance, region)
     gpu_instance_type = "g4dn.2xlarge" if architecture == "x86_64" else "g5g.2xlarge"
-    return list({gpu_instance_type, *get_similar_instance_types(gpu_instance_type, region, 5)})
+    return list({gpu_instance_type, *get_similar_instance_types(gpu_instance_type, region)[:5]})
 
 
 def verify_cluster_node_config_version_in_ddb(region, cluster_name, instance_id, expected_version):