diff --git a/concat_fix_repro.py b/concat_fix_repro.py new file mode 100644 index 00000000000000..53f77b2f762d05 --- /dev/null +++ b/concat_fix_repro.py @@ -0,0 +1,49 @@ +import tensorflow as tf + +class TestModel(tf.keras.Model): + + def __init__(self): + super().__init__() + self.d1 = tf.keras.layers.Dense(64, activation='relu') + self.d2 = tf.keras.layers.Dense(32) + self.d3 = tf.keras.layers.Dense(16) + + def call(self, x, indices=None): + x = self.d1(x) + if indices is not None: + (unique_vals, _) = tf.unique(indices) + x = tf.nn.relu(tf.gather(x, unique_vals)) + else: + x = tf.nn.relu(x) + partitioned = tf.dynamic_partition(x, tf.cast(tf.reduce_sum(x, axis=1) > 0, tf.int32), num_partitions=2) + x = tf.concat(partitioned, axis=0) + (top_k_values, _) = tf.nn.top_k(x, k=tf.shape(x)[0] // 2) + x = tf.nn.relu(self.d2(top_k_values)) + return self.d3(x) + + +def get_default_model(): + return TestModel() + + +def get_sample_inputs(): + x = tf.random.normal([10, 64]) + indices = tf.random.uniform([10], maxval=5, dtype=tf.int32) + return (x, indices) + + +def main(): + model = get_default_model() + inputs = get_sample_inputs() + eager_out = model(*inputs) + print('Eager Input shape:', inputs[0].shape) + print('Eager Output shape:', eager_out.shape) + @tf.function(jit_compile=True) + def compiled_forward(*args): + return model(*args) + compiled_out = compiled_forward(*inputs) + print('XLA Output shape:', compiled_out.shape) + + +if __name__ == '__main__': + main() diff --git a/tensorflow/compiler/tf2xla/kernels/concat_dynamic_test.py b/tensorflow/compiler/tf2xla/kernels/concat_dynamic_test.py new file mode 100644 index 00000000000000..60981dc367e0a4 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/concat_dynamic_test.py @@ -0,0 +1,83 @@ +"""Test for dynamic shape preservation in concat operation with XLA compilation.""" + +import tensorflow as tf +import unittest + + +class ConcatDynamicShapeTest(unittest.TestCase): + + def 
test_dynamic_partition_concat_topk_matmul_xla(self): + """Test that XLA preserves dynamic shapes through the full pipeline.""" + + class TestModel(tf.keras.Model): + def __init__(self): + super().__init__() + self.d1 = tf.keras.layers.Dense(64, activation='relu') + self.d2 = tf.keras.layers.Dense(32) + self.d3 = tf.keras.layers.Dense(16) + + def call(self, x, indices=None): + x = self.d1(x) + if indices is not None: + (unique_vals, _) = tf.unique(indices) + x = tf.nn.relu(tf.gather(x, unique_vals)) + else: + x = tf.nn.relu(x) + + # This chain should preserve dynamic dimensions + partitioned = tf.dynamic_partition( + x, tf.cast(tf.reduce_sum(x, axis=1) > 0, tf.int32), + num_partitions=2) + x = tf.concat(partitioned, axis=0) # Critical: must preserve dynamic dim + (top_k_values, _) = tf.nn.top_k(x, k=tf.shape(x)[0] // 2) + x = tf.nn.relu(self.d2(top_k_values)) + return self.d3(x) + + model = TestModel() + x = tf.random.normal([10, 64]) + indices = tf.random.uniform([10], maxval=5, dtype=tf.int32) + + # Test eager execution + eager_out = model(x, indices) + self.assertEqual(len(eager_out.shape), 2) + self.assertEqual(eager_out.shape[1], 16) + + # Test XLA compilation - this should not fail with shape errors + @tf.function(jit_compile=True) + def compiled_forward(x_input, indices_input): + return model(x_input, indices_input) + + # This should succeed without Matrix size-incompatible error + compiled_out = compiled_forward(x, indices) + + # Verify output shapes match between eager and compiled + self.assertEqual(eager_out.shape[1], compiled_out.shape[1]) + + def test_concat_preserves_dynamic_dimensions(self): + """Direct test of concat with dynamic partition outputs.""" + + @tf.function(jit_compile=True) + def test_concat_dynamic(): + x = tf.random.normal([8, 32]) + partitions = tf.cast(tf.reduce_sum(x, axis=1) > 0, tf.int32) + + # Create dynamic partition outputs + partitioned = tf.dynamic_partition(x, partitions, num_partitions=2) + + # Concat should preserve dynamic 
dimension + result = tf.concat(partitioned, axis=0) + + return result, tf.shape(result) + + result, shape = test_concat_dynamic() + + # Should succeed without errors and have reasonable output shape + self.assertEqual(len(result.shape), 2) + self.assertEqual(result.shape[1], 32) + self.assertGreater(shape[0], 0) # Dynamic first dimension + + +if __name__ == '__main__': + # Run with eager execution to verify test logic + tf.config.run_functions_eagerly(False) + unittest.main() diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index bed3479941ca41..f85d716a5661d1 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -94,7 +94,44 @@ class ConcatBaseOp : public XlaOpKernel { } VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; - ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis)); + xla::XlaOp result = xla::ConcatInDim(ctx->builder(), input_data, axis); + + // Preserve dynamic dimension information for the concat axis. + // If any input has a dynamic dimension along `axis`, compute the + // summed size (mixing dynamic GetDimensionSize() and static constants) + // and set it on the concat result using SetDimensionSize. + bool any_input_dynamic = false; + xla::XlaOp dynamic_size_sum; + bool size_sum_initialized = false; + + for (int i = 0; i < N; ++i) { + // Query the input XLA shape to see if this dimension is dynamic. 
+ auto input_shape_or = ctx->InputXlaShape(i); + OP_REQUIRES_OK(ctx, input_shape_or.status()); + const xla::Shape& input_shape = *input_shape_or; + + xla::XlaOp input_size; + if (input_shape.is_dynamic_dimension(axis)) { + any_input_dynamic = true; + input_size = xla::GetDimensionSize(input_data[i], axis); + } else { + input_size = xla::ConstantR0<int32_t>(ctx->builder(), + static_cast<int32_t>(input_shape.dimensions(axis))); + } + + if (!size_sum_initialized) { + dynamic_size_sum = input_size; + size_sum_initialized = true; + } else { + dynamic_size_sum = xla::Add(dynamic_size_sum, input_size); + } + } + + if (any_input_dynamic && N > 0) { + result = xla::SetDimensionSize(result, dynamic_size_sum, axis); + } + + ctx->SetOutput(0, result); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/issue_105133_fix_demo.py b/tensorflow/compiler/tf2xla/kernels/issue_105133_fix_demo.py new file mode 100644 index 00000000000000..9836e97a2d260d --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/issue_105133_fix_demo.py @@ -0,0 +1,232 @@ +"""Integration demo for GitHub issue #105133: XLA compatibility with conditional operations. + +This script reproduces the exact issue and demonstrates the fix. 
+ +REQUIREMENTS: +- TensorFlow 2.20.0+ with XLA/GPU support enabled +- Run on XLA-enabled build (GPU recommended) + +USAGE: + python issue_105133_fix_demo.py + +This is an integration test - for CI-safe unit tests, see xla_conditional_compatibility_test.py + +Issue: OperatorNotAllowedInGraphError when using tf.shape()[0] in Python conditionals with jit_compile=True +Fix: Replace Python 'if' statements with tf.cond() for XLA compatibility +""" + +import tensorflow as tf + + +class TestModelProblematic(tf.keras.Model): + """Original problematic model from issue #105133.""" + + def __init__(self): + super().__init__() + self.conv1 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)) + self.pool1 = tf.keras.layers.MaxPooling2D((2, 2)) + self.flatten = tf.keras.layers.Flatten() + self.dense1 = tf.keras.layers.Dense(64, activation='relu') + self.dense2 = tf.keras.layers.Dense(10, activation='softmax') + + def call(self, x): + # These Python if statements cause OperatorNotAllowedInGraphError in XLA + if tf.shape(x)[0] >= 1: # PROBLEMATIC + x = tf.stop_gradient(x) + (h, w) = (tf.shape(x)[1], tf.shape(x)[2]) + if h > 1 and w > 1: # PROBLEMATIC + x = self.conv1(x) + x = self.pool1(x) + else: + x = tf.nn.avg_pool2d(x, ksize=2, strides=2, padding='VALID') + x = self.flatten(x) + flat_size = tf.size(x) + if flat_size == 1024: # PROBLEMATIC + x = self.dense1(x) + else: + x = tf.nn.dropout(x, rate=0.5) + x = self.dense2(x) + return x + + +class TestModelFixed(tf.keras.Model): + """Fixed model using tf.cond for XLA compatibility.""" + + def __init__(self): + super().__init__() + self.conv1 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)) + self.pool1 = tf.keras.layers.MaxPooling2D((2, 2)) + self.flatten = tf.keras.layers.Flatten() + self.dense1 = tf.keras.layers.Dense(64, activation='relu') + self.dense2 = tf.keras.layers.Dense(10, activation='softmax') + + def call(self, x): + # Use tf.cond instead of Python if 
statements for XLA compatibility + x = tf.cond( + tf.shape(x)[0] >= 1, + lambda: tf.stop_gradient(x), + lambda: x + ) + + h, w = tf.shape(x)[1], tf.shape(x)[2] + x = tf.cond( + tf.logical_and(h > 1, w > 1), + lambda: self._conv_pool_branch(x), + lambda: tf.nn.avg_pool2d(x, ksize=2, strides=2, padding='VALID') + ) + + x = self.flatten(x) + flat_size = tf.size(x) + x = tf.cond( + tf.equal(flat_size, 1024), + lambda: self.dense1(x), + lambda: self.dense2(tf.nn.dropout(x, rate=0.5)) + ) + + return x + + def _conv_pool_branch(self, x): + """Helper method for convolution + pooling branch.""" + x = self.conv1(x) + return self.pool1(x) + + +def get_default_model_problematic(): + return TestModelProblematic() + + +def get_default_model_fixed(): + return TestModelFixed() + + +def get_sample_inputs(): + x = tf.random.normal([16, 28, 28, 1]) + return (x,) + + +def test_problematic_version(): + """Test the problematic version - this demonstrates the original issue.""" + print("Testing problematic version...") + model = get_default_model_problematic() + inputs = get_sample_inputs() + + # This works in eager execution + eager_out = model(*inputs) + print('Problematic - Eager Input shape:', inputs[0].shape) + print('Problematic - Eager Output shape:', eager_out.shape) + + # This should fail with OperatorNotAllowedInGraphError when jit_compile=True + @tf.function(jit_compile=True) + def compiled_forward(*args): + return model(*args) + + try: + compiled_out = compiled_forward(*inputs) + print('Problematic - XLA Output shape:', compiled_out.shape) + print("⚠️ WARNING: Expected OperatorNotAllowedInGraphError but compilation succeeded!") + print(" This may indicate XLA is not enabled or a different TF version.") + except Exception as e: + error_msg = str(e) + expected_keywords = ['symbolic', 'python bool', 'not allowed', 'operatornotallowed'] + + if any(keyword in error_msg.lower() for keyword in expected_keywords): + print(f"✓ Expected error caught: {type(e).__name__}") + print(f" 
Error indicates symbolic tensor used as Python bool") + else: + print(f"? Unexpected error type: {type(e).__name__}") + print(f" Error message: {error_msg}") + print(" This may be related to the symbolic tensor issue or XLA availability.") + + +def test_fixed_version(): + """Test the fixed version - this demonstrates the solution.""" + print("\nTesting fixed version...") + model = get_default_model_fixed() + inputs = get_sample_inputs() + + # Test eager execution + eager_out = model(*inputs) + print('Fixed - Eager Input shape:', inputs[0].shape) + print('Fixed - Eager Output shape:', eager_out.shape) + + # Test XLA compilation - this should now work + @tf.function(jit_compile=True) + def compiled_forward(*args): + return model(*args) + + try: + compiled_out = compiled_forward(*inputs) + print('Fixed - XLA Output shape:', compiled_out.shape) + print("✓ SUCCESS: XLA compilation worked with tf.cond!") + + # Verify deterministic behavior + if eager_out.shape == compiled_out.shape: + print("✓ Output shapes match between eager and XLA execution") + else: + print(f"✗ Shape mismatch: eager {eager_out.shape} vs XLA {compiled_out.shape}") + + # Test numerical consistency (shapes should be deterministic) + if eager_out.dtype == compiled_out.dtype: + print("✓ Output dtypes match between eager and XLA execution") + else: + print(f"✗ Dtype mismatch: eager {eager_out.dtype} vs XLA {compiled_out.dtype}") + + except Exception as e: + print(f"✗ Unexpected error in fixed version: {type(e).__name__}: {e}") + print(" This suggests an issue with the XLA environment or tf.cond implementation.") + + +def check_xla_availability(): + """Check if XLA/JIT compilation is available.""" + try: + @tf.function(jit_compile=True) + def test_xla(): + return tf.constant(1.0) + test_xla() + return True + except Exception: + return False + + +def main(): + """Run both test cases to demonstrate the issue and the fix.""" + print("=" * 70) + print("GitHub Issue #105133: XLA Conditional Compatibility Demo") 
+ print("OperatorNotAllowedInGraphError with tf.shape()[0] in conditionals") + print("=" * 70) + + # Check environment + xla_available = check_xla_availability() + print(f"TensorFlow version: {tf.__version__}") + print(f"XLA/JIT compilation available: {xla_available}") + + if not xla_available: + print("⚠️ WARNING: XLA/JIT compilation not available.") + print(" Some parts of this demo may not show the expected behavior.") + print(" Run on an XLA-enabled TensorFlow build for full demonstration.") + + print() + + # Test the problematic version (expected to fail with XLA) + test_problematic_version() + + # Test the fixed version (should succeed with XLA) + test_fixed_version() + + print("\n" + "=" * 70) + print("SOLUTION SUMMARY:") + print("Issue: Python 'if' with symbolic tensors fails under XLA compilation") + print("Solution: Replace with tf.cond() for XLA compatibility") + print() + print("Key Refactoring Patterns:") + print(" • if condition: -> tf.cond(condition, true_fn, false_fn)") + print(" • if a and b: -> tf.cond(tf.logical_and(a, b), ...)") + print(" • if a == b: -> tf.cond(tf.equal(a, b), ...)") + print(" • if tf.shape(x)[0] > n: -> tf.cond(tf.shape(x)[0] > n, ...)") + print() + print("See xla_conditional_compatibility_test.py for CI-safe unit tests.") + print("=" * 70) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conditional_compatibility_test.py b/tensorflow/compiler/tf2xla/kernels/xla_conditional_compatibility_test.py new file mode 100644 index 00000000000000..aa76d16e426154 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_conditional_compatibility_test.py @@ -0,0 +1,410 @@ +"""CI-safe unit tests for XLA compatibility with conditional operations. + +Regression test for GitHub issue #105133: +OperatorNotAllowedInGraphError when using tf.shape()[0] in conditional with jit_compile=True + +Tested with TF 2.20.0+ on CPU/GPU builds. 
+XLA-specific tests are skipped when JIT compilation is unavailable. + +This demonstrates user workarounds (tf.cond) vs problematic patterns (Python if with symbolic tensors). +Integration demo requiring XLA-enabled builds is in issue_105133_fix_demo.py. +""" + +import tensorflow as tf +import unittest +import os + + + +def _is_xla_available(): + """Check if XLA/JIT compilation is available in this build.""" + try: + # Try to compile a simple function to test XLA availability + @tf.function(jit_compile=True) + def test_fn(): + return tf.constant(1.0) + test_fn() + return True + except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError): + return False + except Exception: + return False + + +class XlaConditionalCompatibilityTest(tf.test.TestCase): + """CI-safe unit tests for XLA conditional compatibility. + + Tests eager execution behavior and tf.cond workarounds. + XLA-specific tests are skipped when JIT is unavailable. + """ + + def test_eager_execution_with_python_conditionals(self): + """Verify that Python conditionals work correctly in eager execution.""" + + class TestModel(tf.keras.Model): + def __init__(self): + super().__init__() + self.dense = tf.keras.layers.Dense(10) + + def call(self, x): + # These work fine in eager execution + if tf.shape(x)[0] >= 1: + x = tf.stop_gradient(x) + + h, w = tf.shape(x)[1], tf.shape(x)[2] + if h > 1 and w > 1: + x = tf.nn.relu(x) + else: + x = tf.nn.tanh(x) + + x = tf.reshape(x, [tf.shape(x)[0], -1]) + return self.dense(x) + + model = TestModel() + x = tf.random.normal([4, 8, 8, 3]) + + # Should work without errors in eager execution + result = model(x) + + # Verify correct shapes and behavior + self.assertEqual(result.shape[0], 4) # Batch size preserved + self.assertEqual(result.shape[1], 10) # Output dimension correct + + @unittest.skipUnless(_is_xla_available(), "XLA/JIT compilation not available") + def test_python_if_with_symbolic_shape_raises_under_xla(self): + """Test that Python if with symbolic tensors 
fails under XLA compilation.""" + + def problematic_function(x): + # This pattern causes OperatorNotAllowedInGraphError in XLA + if tf.shape(x)[0] >= 1: # PROBLEMATIC: symbolic tensor as Python bool + return tf.stop_gradient(x) + return x + + x = tf.random.normal([4, 8]) + + # Should work in eager execution + eager_result = problematic_function(x) + self.assertEqual(eager_result.shape, x.shape) + + # Should fail when compiled with XLA + @tf.function(jit_compile=True) + def compiled_function(inputs): + return problematic_function(inputs) + + # Check for expected error message patterns + with self.assertRaises(Exception) as context: + compiled_function(x) + + error_msg = str(context.exception) + # Verify it's the expected symbolic tensor error + self.assertTrue( + any(keyword in error_msg.lower() for keyword in + ["symbolic", "python bool", "not allowed", "operatornotallowed"]), + f"Expected symbolic tensor error, got: {error_msg}" + ) + + def test_tf_cond_replacement_allows_jit_compilation(self): + """Test that tf.cond replacement works in both eager and XLA modes.""" + + def tf_cond_function(x): + # XLA-compatible version using tf.cond + return tf.cond( + tf.shape(x)[0] >= 1, + lambda: tf.stop_gradient(x), + lambda: x + ) + + x = tf.random.normal([4, 8]) + + # Test eager execution + eager_result = tf_cond_function(x) + self.assertEqual(eager_result.shape, x.shape) + + # Test graph mode (without XLA first) + @tf.function + def graph_function(inputs): + return tf_cond_function(inputs) + + graph_result = graph_function(x) + self.assertEqual(graph_result.shape, x.shape) + + # Test deterministic behavior: stop_gradient should preserve values + self.assertAllClose(eager_result, x) + self.assertAllClose(graph_result, x) + + @unittest.skipUnless(_is_xla_available(), "XLA/JIT compilation not available") + def test_tf_cond_works_under_xla_compilation(self): + """Test that tf.cond works correctly under XLA compilation.""" + + def xla_compatible_conditional(x): + return 
tf.cond( + tf.shape(x)[0] > 2, + lambda: tf.nn.relu(x), # Non-negative for large batch + lambda: tf.nn.tanh(x) # Can be negative for small batch + ) + + # Test with different batch sizes to verify conditional logic + x_small = tf.random.normal([2, 4]) # batch_size <= 2, should use tanh + x_large = tf.random.normal([5, 4]) # batch_size > 2, should use relu + + @tf.function(jit_compile=True) + def compiled_function(inputs): + return xla_compatible_conditional(inputs) + + # Should compile and run without errors + result_small = compiled_function(x_small) + result_large = compiled_function(x_large) + + # Verify correct conditional behavior + self.assertEqual(result_small.shape, x_small.shape) + self.assertEqual(result_large.shape, x_large.shape) + + # relu output should be non-negative, tanh can be negative + self.assertTrue(tf.reduce_all(result_large >= 0).numpy(), + "relu output should be non-negative") + # Note: We can't guarantee tanh produces negative values with random input, + # but we can check it's different from relu behavior + eager_small = xla_compatible_conditional(x_small) + self.assertAllClose(result_small, eager_small, rtol=1e-5) + + def test_tf_where_alternative_for_simple_conditionals(self): + """Test tf.where as an alternative to Python conditionals.""" + + def where_based_function(x): + # Use tf.where for element-wise conditionals + return tf.where( + tf.reduce_sum(x, axis=-1, keepdims=True) > 0, + tf.stop_gradient(x), + x * 0.5 + ) + + x = tf.random.normal([4, 8]) + + # Test eager execution + eager_result = where_based_function(x) + self.assertEqual(eager_result.shape, x.shape) + + # Test graph mode + @tf.function + def graph_function(inputs): + return where_based_function(inputs) + + graph_result = graph_function(x) + self.assertEqual(graph_result.shape, x.shape) + self.assertAllClose(eager_result, graph_result) + + def test_mathematical_masking_alternative(self): + """Test mathematical operations as alternative to conditionals.""" + + def 
mask_based_function(x): + # Use mathematical operations to avoid explicit conditionals + batch_size = tf.cast(tf.shape(x)[0], tf.float32) + threshold = tf.constant(3.0) + + # Create mask: 1.0 if batch_size > threshold, 0.0 otherwise + mask = tf.cast(batch_size > threshold, tf.float32) + + # Apply different operations based on mask + large_batch_result = tf.nn.l2_normalize(x, axis=-1) + small_batch_result = x * 0.5 + + # Combine results using the mask + return mask * large_batch_result + (1.0 - mask) * small_batch_result + + # Test with different batch sizes + x_small = tf.random.normal([2, 4]) # batch_size <= 3 + x_large = tf.random.normal([5, 4]) # batch_size > 3 + + result_small = mask_based_function(x_small) + result_large = mask_based_function(x_large) + + # Verify shapes + self.assertEqual(result_small.shape, x_small.shape) + self.assertEqual(result_large.shape, x_large.shape) + + # Verify behavior: large batch should be normalized (norm ≈ 1) + norms_large = tf.norm(result_large, axis=-1) + self.assertAllClose(norms_large, tf.ones_like(norms_large), atol=1e-5) + + # Small batch should be scaled by 0.5 + expected_small = x_small * 0.5 + self.assertAllClose(result_small, expected_small) + + def test_safe_shape_operations_in_tensorflow_ops(self): + """Test that tf.shape operations are safe when used in TF ops (not Python conditionals).""" + + def safe_shape_operations(x): + # These are safe: tf.shape used in TensorFlow operations, not Python conditionals + current_shape = tf.shape(x) + batch_size = current_shape[0] + + # Safe: Using shape values in TensorFlow operations + half_batch = batch_size // 2 + + # Safe: Reshape using computed shapes + new_shape = tf.concat([current_shape[:1], [-1]], axis=0) + reshaped = tf.reshape(x, new_shape) + + # Safe: Gather operations with dynamic indices + indices = tf.range(half_batch) + gathered = tf.gather(reshaped, indices) + + return gathered, batch_size, half_batch + + x = tf.random.normal([10, 8, 4]) + + # Test eager 
execution + result, batch_size, half_batch = safe_shape_operations(x) + + # Verify correct behavior + self.assertEqual(len(result.shape), 2) + self.assertEqual(batch_size.numpy(), 10) + self.assertEqual(half_batch.numpy(), 5) + self.assertEqual(result.shape[0], 5) # Half the batch size + + # Test graph mode + @tf.function + def graph_function(inputs): + return safe_shape_operations(inputs) + + graph_result, graph_batch, graph_half = graph_function(x) + + # Verify consistency between eager and graph modes + self.assertAllClose(result, graph_result) + self.assertEqual(batch_size.numpy(), graph_batch.numpy()) + self.assertEqual(half_batch.numpy(), graph_half.numpy()) + + @unittest.skipUnless(_is_xla_available(), "XLA/JIT compilation not available") + def test_safe_shape_operations_under_xla(self): + """Test that safe shape operations work under XLA compilation.""" + + @tf.function(jit_compile=True) + def xla_safe_operations(x): + # These operations should work fine in XLA + batch_size = tf.shape(x)[0] + feature_dim = tf.shape(x)[-1] + + # Use shapes in TensorFlow operations (not Python conditionals) + half_batch = batch_size // 2 + new_shape = [half_batch, feature_dim * 2] + + # Reshape and return + x_subset = x[:half_batch] + x_doubled = tf.concat([x_subset, x_subset], axis=-1) + return tf.reshape(x_doubled, new_shape) + + x = tf.random.normal([8, 4]) + result = xla_safe_operations(x) + + # Verify expected output shape + self.assertEqual(result.shape[0], 4) # half_batch + self.assertEqual(result.shape[1], 8) # feature_dim * 2 + + +class XlaConditionalBestPracticesTest(tf.test.TestCase): + """Best practices for writing XLA-compatible conditional code.""" + + def test_refactoring_patterns_eager_execution(self): + """Test refactoring patterns work correctly in eager execution.""" + + def tf_cond_pattern(x): + # XLA-compatible version using tf.cond + return tf.cond( + tf.shape(x)[0] > 3, + true_fn=lambda: tf.nn.relu(x), + false_fn=lambda: tf.nn.tanh(x) + ) + + # Test 
with different batch sizes + x_small = tf.random.normal([2, 4]) # batch_size <= 3, should use tanh + x_large = tf.random.normal([5, 4]) # batch_size > 3, should use relu + + result_small = tf_cond_pattern(x_small) + result_large = tf_cond_pattern(x_large) + + # Verify shapes are preserved + self.assertEqual(result_small.shape, x_small.shape) + self.assertEqual(result_large.shape, x_large.shape) + + # Verify relu produces non-negative output + self.assertTrue(tf.reduce_all(result_large >= 0).numpy()) + + @unittest.skipUnless(_is_xla_available(), "XLA/JIT compilation not available") + def test_refactoring_patterns_under_xla(self): + """Test refactoring patterns work under XLA compilation.""" + + @tf.function(jit_compile=True) + def xla_conditional_pattern(x): + return tf.cond( + tf.shape(x)[0] > 3, + true_fn=lambda: tf.nn.relu(x), + false_fn=lambda: tf.nn.tanh(x) + ) + + # Test compilation and execution + x_small = tf.random.normal([2, 4]) + x_large = tf.random.normal([5, 4]) + + # Should compile and run without errors + result_small = xla_conditional_pattern(x_small) + result_large = xla_conditional_pattern(x_large) + + # Verify correct behavior + self.assertEqual(result_small.shape, x_small.shape) + self.assertEqual(result_large.shape, x_large.shape) + self.assertTrue(tf.reduce_all(result_large >= 0).numpy()) + + def test_nested_tf_cond_patterns(self): + """Test nested tf.cond patterns work in eager and graph modes.""" + + def nested_conditional_function(x, training=True): + """Complex conditional logic using nested tf.cond operations.""" + + batch_size = tf.shape(x)[0] + feature_dim = tf.shape(x)[-1] + + def training_path(): + return tf.cond( + batch_size > 8, + lambda: tf.nn.dropout(x, rate=0.3), # Large batch: dropout + lambda: x # Small batch: no dropout + ) + + def inference_path(): + return tf.cond( + feature_dim > 16, + lambda: tf.nn.l2_normalize(x, axis=-1), # High-dim: normalize + lambda: x * 0.5 # Low-dim: scale down + ) + + # Top-level condition + return 
tf.cond( + tf.constant(training), + training_path, + inference_path + ) + + x = tf.random.normal([10, 32]) + + # Test training mode + train_result = nested_conditional_function(x, training=True) + self.assertEqual(train_result.shape, x.shape) + + # Test inference mode + inference_result = nested_conditional_function(x, training=False) + self.assertEqual(inference_result.shape, x.shape) + + # For high-dim inference, should be normalized (feature_dim=32 > 16) + norms = tf.norm(inference_result, axis=-1) + self.assertAllClose(norms, tf.ones_like(norms), atol=1e-5) + + +if __name__ == '__main__': + # Set up test environment + tf.config.run_functions_eagerly(False) + + # Print XLA availability for debugging + print(f"XLA/JIT compilation available: {_is_xla_available()}") + + # Run the tests + tf.test.main() \ No newline at end of file