Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions concat_fix_repro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import tensorflow as tf

class TestModel(tf.keras.Model):
    """Toy Keras model whose forward pass produces data-dependent shapes.

    The call chain (unique -> gather -> dynamic_partition -> concat -> top_k)
    deliberately creates dynamic leading dimensions; used to reproduce an
    XLA dynamic-shape issue in `tf.concat`.
    """

    def __init__(self):
        super().__init__()
        self.d1 = tf.keras.layers.Dense(64, activation='relu')
        self.d2 = tf.keras.layers.Dense(32)
        self.d3 = tf.keras.layers.Dense(16)

    def call(self, x, indices=None):
        hidden = self.d1(x)
        if indices is None:
            hidden = tf.nn.relu(hidden)
        else:
            # Deduplicating indices yields a data-dependent row count, so
            # the gathered activations have a dynamic leading dimension.
            unique_vals = tf.unique(indices)[0]
            hidden = tf.nn.relu(tf.gather(hidden, unique_vals))
        # Split rows by the sign of their per-row sum, then stitch back
        # together; each partition's size is data-dependent.
        labels = tf.cast(tf.reduce_sum(hidden, axis=1) > 0, tf.int32)
        parts = tf.dynamic_partition(hidden, labels, num_partitions=2)
        hidden = tf.concat(parts, axis=0)
        # k is derived from the runtime shape, keeping the dim dynamic.
        top_vals = tf.nn.top_k(hidden, k=tf.shape(hidden)[0] // 2)[0]
        return self.d3(tf.nn.relu(self.d2(top_vals)))


def get_default_model():
    """Build and return a fresh, unweighted :class:`TestModel`."""
    model = TestModel()
    return model


def get_sample_inputs():
    """Return a random ``(features, indices)`` pair for the repro.

    ``features`` is a float tensor of shape [10, 64]; ``indices`` is an
    int32 tensor of shape [10] with values in [0, 5).
    """
    features = tf.random.normal([10, 64])
    gather_indices = tf.random.uniform([10], maxval=5, dtype=tf.int32)
    return features, gather_indices


def main():
    """Run the repro model eagerly and under XLA, printing output shapes."""
    model = get_default_model()
    sample = get_sample_inputs()

    # Eager baseline.
    eager_out = model(*sample)
    print('Eager Input shape:', sample[0].shape)
    print('Eager Output shape:', eager_out.shape)

    # Same forward pass, JIT-compiled with XLA.
    @tf.function(jit_compile=True)
    def compiled_forward(*args):
        return model(*args)

    compiled_out = compiled_forward(*sample)
    print('XLA Output shape:', compiled_out.shape)


# Script entry point: compare eager vs. XLA-compiled output shapes.
if __name__ == '__main__':
    main()
83 changes: 83 additions & 0 deletions tensorflow/compiler/tf2xla/kernels/concat_dynamic_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Test for dynamic shape preservation in concat operation with XLA compilation."""

import tensorflow as tf
import unittest


class ConcatDynamicShapeTest(unittest.TestCase):
    """Checks that tf.concat keeps dynamic-dimension info under jit_compile."""

    def test_dynamic_partition_concat_topk_matmul_xla(self):
        """Test that XLA preserves dynamic shapes through the full pipeline."""

        class TestModel(tf.keras.Model):
            # Small model whose forward pass chains ops that produce
            # data-dependent (dynamic) leading dimensions.
            def __init__(self):
                super().__init__()
                self.d1 = tf.keras.layers.Dense(64, activation='relu')
                self.d2 = tf.keras.layers.Dense(32)
                self.d3 = tf.keras.layers.Dense(16)

            def call(self, x, indices=None):
                x = self.d1(x)
                if indices is not None:
                    # tf.unique returns a data-dependent number of values,
                    # so the gathered tensor's leading dim is dynamic.
                    (unique_vals, _) = tf.unique(indices)
                    x = tf.nn.relu(tf.gather(x, unique_vals))
                else:
                    x = tf.nn.relu(x)

                # This chain should preserve dynamic dimensions
                partitioned = tf.dynamic_partition(
                    x, tf.cast(tf.reduce_sum(x, axis=1) > 0, tf.int32),
                    num_partitions=2)
                x = tf.concat(partitioned, axis=0)  # Critical: must preserve dynamic dim
                # k derives from the runtime shape -> top_k output dynamic too.
                (top_k_values, _) = tf.nn.top_k(x, k=tf.shape(x)[0] // 2)
                x = tf.nn.relu(self.d2(top_k_values))
                return self.d3(x)

        model = TestModel()
        x = tf.random.normal([10, 64])
        indices = tf.random.uniform([10], maxval=5, dtype=tf.int32)

        # Test eager execution
        eager_out = model(x, indices)
        self.assertEqual(len(eager_out.shape), 2)
        self.assertEqual(eager_out.shape[1], 16)

        # Test XLA compilation - this should not fail with shape errors
        @tf.function(jit_compile=True)
        def compiled_forward(x_input, indices_input):
            return model(x_input, indices_input)

        # This should succeed without Matrix size-incompatible error
        compiled_out = compiled_forward(x, indices)

        # Verify output shapes match between eager and compiled
        self.assertEqual(eager_out.shape[1], compiled_out.shape[1])

    def test_concat_preserves_dynamic_dimensions(self):
        """Direct test of concat with dynamic partition outputs."""

        @tf.function(jit_compile=True)
        def test_concat_dynamic():
            x = tf.random.normal([8, 32])
            partitions = tf.cast(tf.reduce_sum(x, axis=1) > 0, tf.int32)

            # Create dynamic partition outputs
            partitioned = tf.dynamic_partition(x, partitions, num_partitions=2)

            # Concat should preserve dynamic dimension
            result = tf.concat(partitioned, axis=0)

            return result, tf.shape(result)

        result, shape = test_concat_dynamic()

        # Should succeed without errors and have reasonable output shape
        self.assertEqual(len(result.shape), 2)
        self.assertEqual(result.shape[1], 32)
        self.assertGreater(shape[0], 0)  # Dynamic first dimension


if __name__ == '__main__':
    # NOTE(review): False is the default and keeps tf.function tracing
    # enabled, so jit_compile=True actually exercises the XLA path under
    # test. (The previous comment claiming "eager execution" was wrong.)
    tf.config.run_functions_eagerly(False)
    unittest.main()
39 changes: 38 additions & 1 deletion tensorflow/compiler/tf2xla/kernels/concat_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,44 @@ class ConcatBaseOp : public XlaOpKernel {
}

VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis;
ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis));
xla::XlaOp result = xla::ConcatInDim(ctx->builder(), input_data, axis);

// Preserve dynamic dimension information for the concat axis.
// If any input has a dynamic dimension along `axis`, compute the
// summed size (mixing dynamic GetDimensionSize() and static constants)
// and set it on the concat result using SetDimensionSize.
bool any_input_dynamic = false;
xla::XlaOp dynamic_size_sum;
bool size_sum_initialized = false;

for (int i = 0; i < N; ++i) {
// Query the input XLA shape to see if this dimension is dynamic.
auto input_shape_or = ctx->InputXlaShape(i);
OP_REQUIRES_OK(ctx, input_shape_or.status());
const xla::Shape& input_shape = *input_shape_or;

xla::XlaOp input_size;
if (input_shape.is_dynamic_dimension(axis)) {
any_input_dynamic = true;
input_size = xla::GetDimensionSize(input_data[i], axis);
} else {
input_size = xla::ConstantR0<int32_t>(ctx->builder(),
input_shape.dimensions(axis));
}

if (!size_sum_initialized) {
dynamic_size_sum = input_size;
size_sum_initialized = true;
} else {
dynamic_size_sum = xla::Add(dynamic_size_sum, input_size);
}
}

if (any_input_dynamic && N > 0) {
result = xla::SetDimensionSize(result, dynamic_size_sum, axis);
}

ctx->SetOutput(0, result);
}

private:
Expand Down
Loading
Loading