Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
85d982f
Initial plan
CodersAcademy006 Nov 28, 2025
742aca0
Add fallback mechanism for large tensors in LaunchConvOpImpl
CodersAcademy006 Nov 28, 2025
42d90cf
Address code review feedback: add bounds checking and validation
CodersAcademy006 Nov 28, 2025
d01cf50
Address remaining code review feedback: improve recursion safety comm…
CodersAcademy006 Nov 28, 2025
be78b4a
Fix XLA JIT compilation with Keras initializers and dynamic shapes (#…
CodersAcademy006 Nov 30, 2025
498997c
Add comprehensive documentation for XLA initializers fix
CodersAcademy006 Nov 30, 2025
2d72f24
Fix XLA JIT compilation with mixed-type dictionary keys (#105333)
CodersAcademy006 Nov 30, 2025
a6f9c64
Delete FIX_SUMMARY_105334.md
CodersAcademy006 Nov 30, 2025
7f75df3
Delete tensorflow/python/ops/demo_xla_initializers_fix.py
CodersAcademy006 Dec 2, 2025
07dedf5
Rename test file for Keras initializers
CodersAcademy006 Dec 2, 2025
3c7a9c3
Delete tensorflow/python/util/demo_mixed_dict_keys.py
CodersAcademy006 Dec 2, 2025
57d3852
Rename test_mixed_dict_keys.py to mixed_dict_keys_test.py
CodersAcademy006 Dec 2, 2025
7d43351
Remove unrelated cuDNN batch-splitting fallback from conv_ops_impl.h
CodersAcademy006 Dec 2, 2025
27e7c38
Revert conv_ops_impl.h changes (remove unrelated cuDNN fallback)
CodersAcademy006 Dec 2, 2025
6b91e7a
Fix _compute_fans: robust XLA-safe _to_int conversion and correct rec…
CodersAcademy006 Dec 2, 2025
30789ad
Move keras_initializers_dynamic_shapes_test to tensorflow/python/kera…
CodersAcademy006 Dec 2, 2025
ab52378
Fix imports in mixed_dict_keys_test to use internal TensorFlow APIs
CodersAcademy006 Dec 2, 2025
004b483
Clean PR: remove accidental keras initializer test from mixed-dict-ke…
CodersAcademy006 Dec 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions tensorflow/core/kernels/conv_ops_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,41 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Maximum tensor size (in bytes) that cuDNN can handle safely.
// cuDNN has internal limits around 2GB for certain operations.
// We use a conservative threshold to avoid CUDA invalid resource handle errors.
constexpr int64_t kMaxCudnnTensorSizeBytes = 2LL * 1024 * 1024 * 1024; // 2GB

// Helper function to check if the tensor size exceeds the safe limit for
// cuDNN. Returns true if the tensor is too large and needs fallback
// processing.
//
// The comparison is done in element space (limit_bytes / sizeof(T)) rather
// than as NumElements() * sizeof(T): that product mixes int64_t with size_t
// and could wrap for very large element counts, silently skipping the
// fallback path exactly when it is needed most. For all non-overflowing
// inputs the two formulations are equivalent:
//   n * s > limit  <=>  n > limit / s   (integer division, s > 0).
template <typename T>
inline bool IsTensorTooLargeForCudnn(const Tensor& tensor) {
  constexpr int64_t kMaxElements =
      kMaxCudnnTensorSizeBytes / static_cast<int64_t>(sizeof(T));
  return tensor.NumElements() > kMaxElements;
}

// Computes the largest batch size for `tensor` whose per-batch slice stays
// under the cuDNN size limit. The result is always clamped to the range
// [1, current_batch].
// NOTE(review): `data_format` is accepted for symmetry with the call site
// but is not consulted by the computation.
template <typename T>
inline int64_t ComputeSafeBatchSize(const Tensor& tensor, int64_t current_batch,
                                    TensorFormat data_format) {
  if (current_batch <= 0) return 1;
  const int64_t num_elements = tensor.NumElements();
  if (num_elements <= 0) return 1;
  // Fewer elements than batches means less than one element per batch on
  // average; process one batch at a time.
  if (num_elements < current_batch) return 1;
  const int64_t per_batch_elements = num_elements / current_batch;
  if (per_batch_elements <= 0) return 1;
  // How many batches fit inside the cuDNN element budget?
  const int64_t element_budget = kMaxCudnnTensorSizeBytes / sizeof(T);
  const int64_t fitting_batches = element_budget / per_batch_elements;
  // Clamp to [1, current_batch].
  if (fitting_batches < 1) return 1;
  return fitting_batches < current_batch ? fitting_batches : current_batch;
}

template <typename Device, typename T>
struct LaunchGeneric {
void operator()(OpKernelContext* ctx, const Tensor& input,
Expand Down Expand Up @@ -773,6 +808,123 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune,
absl::InvalidArgumentError("filter must not have zero elements "
"(i.e. all dimensions must be non-zero)"));

// Check if input tensor is too large for cuDNN and needs batch splitting.
// This addresses CUDA invalid resource handle errors with large tensors.
if (IsTensorTooLargeForCudnn<T>(input) && in_batch > 1) {
int64_t safe_batch = ComputeSafeBatchSize<T>(input, in_batch, data_format);
if (safe_batch < in_batch && safe_batch > 0) {
VLOG(2) << "Input tensor too large for cuDNN, splitting batch from "
<< in_batch << " to chunks of " << safe_batch;

// Process in batches to avoid cuDNN memory limits
int64_t batch_idx = GetTensorDimIndex(data_format, 'N', input.dims());

// Validate batch dimension before proceeding
OP_REQUIRES(context, batch_idx >= 0 && batch_idx < input.dims(),
absl::InternalError("Invalid batch dimension index"));
OP_REQUIRES(context, input.dim_size(batch_idx) > 0,
absl::InternalError("Input batch dimension is zero"));
OP_REQUIRES(context, output->dim_size(batch_idx) > 0,
absl::InternalError("Output batch dimension is zero"));

for (int64_t start = 0; start < in_batch; start += safe_batch) {
int64_t chunk_size = std::min(safe_batch, in_batch - start);

// Create sliced input tensor
std::vector<int64_t> input_slice_shape;
for (int i = 0; i < input.dims(); ++i) {
if (i == batch_idx) {
input_slice_shape.push_back(chunk_size);
} else {
input_slice_shape.push_back(input.dim_size(i));
}
}
TensorShape input_slice_ts(input_slice_shape);
Tensor input_slice;
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
input_slice_ts,
&input_slice));

// Create sliced output tensor
std::vector<int64_t> output_slice_shape;
for (int i = 0; i < output->dims(); ++i) {
if (i == batch_idx) {
output_slice_shape.push_back(chunk_size);
} else {
output_slice_shape.push_back(output->dim_size(i));
}
}
TensorShape output_slice_ts(output_slice_shape);
Tensor output_slice;
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
output_slice_ts,
&output_slice));

// Calculate elements per batch with validated dimensions
int64_t input_batch_dim = input.dim_size(batch_idx);
int64_t elements_per_batch = input.NumElements() / input_batch_dim;

// Validate bounds before pointer arithmetic
int64_t input_offset = start * elements_per_batch;
OP_REQUIRES(context, input_offset + chunk_size * elements_per_batch <=
input.NumElements(),
absl::InternalError("Input slice bounds check failed"));

// Copy input slice from input tensor (device to device)
int64_t copy_size_bytes = chunk_size * elements_per_batch * sizeof(T);
auto src_ptr = se::DeviceMemoryBase(
const_cast<T*>(input.template flat<T>().data() + input_offset),
copy_size_bytes);
auto dst_ptr = se::DeviceMemoryBase(
const_cast<T*>(input_slice.template flat<T>().data()),
copy_size_bytes);
OP_REQUIRES_OK(context,
stream->MemcpyD2D(&dst_ptr, src_ptr, copy_size_bytes));

// Recursively call LaunchConvOpImpl with the smaller batch.
// Safety note: The recursive call is guaranteed not to re-enter this
// batch-splitting code path because:
// 1. safe_batch is computed to keep sliced tensors under the size limit
// 2. IsTensorTooLargeForCudnn will return false for the sliced tensor
// 3. Even if it were to trigger, in_batch would equal chunk_size,
// and safe_batch would equal chunk_size, so the condition
// "safe_batch < in_batch" would be false
LaunchConvOpImpl<T>(context, cudnn_use_autotune, input_slice, filter,
dilations, strides, padding, explicit_paddings,
data_format, &output_slice);

// Check for errors from recursive call
if (!context->status().ok()) return;

// Calculate output elements per batch with validated dimensions
int64_t output_batch_dim = output->dim_size(batch_idx);
int64_t output_elements_per_batch =
output->NumElements() / output_batch_dim;

// Validate bounds before pointer arithmetic
int64_t output_offset = start * output_elements_per_batch;
OP_REQUIRES(
context,
output_offset + chunk_size * output_elements_per_batch <=
output->NumElements(),
absl::InternalError("Output slice bounds check failed"));

// Copy output slice to output tensor (device to device)
int64_t output_copy_size_bytes =
chunk_size * output_elements_per_batch * sizeof(T);
auto out_src_ptr = se::DeviceMemoryBase(
const_cast<T*>(output_slice.template flat<T>().data()),
output_copy_size_bytes);
auto out_dst_ptr = se::DeviceMemoryBase(
const_cast<T*>(output->template flat<T>().data() + output_offset),
output_copy_size_bytes);
OP_REQUIRES_OK(context, stream->MemcpyD2D(&out_dst_ptr, out_src_ptr,
output_copy_size_bytes));
}
return;
}
}

bool is_grouped_convolution = filter_depth != in_depth;
// check if filter is 1x1 and stride/dilation are all ones
bool one_filter = true;
Expand Down
182 changes: 182 additions & 0 deletions tensorflow/python/util/mixed_dict_keys_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.util import nest_util
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for XLA JIT compilation with mixed-type dictionary keys.

This test validates the fix for issue #105333 where @tf.function(jit_compile=True)
fails when returning dictionaries with mixed key types (e.g., strings and integers).
"""

import tensorflow as tf

from tensorflow.python.platform import test
from tensorflow.python.util import nest


class XLAMixedDictKeysTest(test.TestCase):
  """Tests XLA JIT compilation with mixed-type dictionary keys.

  Validates the fix for issue #105333: `tf.function(jit_compile=True)` used
  to fail when a traced function returned a dict whose keys mixed types
  (e.g. strings and integers), because `nest` could not sort the keys. With
  the fallback, keys sort by `(type(key).__name__, key)`, so int keys order
  before str keys and the flatten order is deterministic.
  """

  def test_mixed_string_int_keys_flatten(self):
    """nest.flatten succeeds on a dict with both str and int keys."""
    mixed_dict = {'string_key': 1, 123: 2, 'another': 3, 456: 4}
    flattened = nest.flatten(mixed_dict)
    # Keys sort by type name first ('int' < 'str'), then by value, so
    # flattening succeeds with a deterministic order.
    self.assertEqual(len(flattened), 4)
    self.assertIn(1, flattened)
    self.assertIn(2, flattened)
    self.assertIn(3, flattened)
    self.assertIn(4, flattened)

  def test_mixed_keys_with_xla_simple(self):
    """A jit-compiled function may return a dict with mixed key types."""

    @tf.function(jit_compile=True)
    def simple_mixed_dict(x):
      results = {}
      results['string_key'] = x
      results[123] = x + 1
      return results

    input_tensor = constant_op.constant([1.0, 2.0, 3.0])
    output = simple_mixed_dict(input_tensor)

    self.assertIn('string_key', output)
    self.assertIn(123, output)
    self.assertAllClose(output['string_key'], [1.0, 2.0, 3.0])
    self.assertAllClose(output[123], [2.0, 3.0, 4.0])

  def test_mixed_keys_with_xla_in_model(self):
    """Reproduces the original issue #105333 inside a Keras model."""

    class SimpleModel(tf.keras.Model):

      @tf.function(jit_compile=True)
      def call(self, x):
        results = {}
        results['string_key'] = x
        results[123] = x + 1
        return x, results

    model = SimpleModel()
    input_tensor = tf.random.normal([2, 16, 16, 16, 32])
    output_tensor, output_dict = model(input_tensor)

    self.assertEqual(output_tensor.shape, (2, 16, 16, 16, 32))
    self.assertIn('string_key', output_dict)
    self.assertIn(123, output_dict)

  def test_multiple_mixed_types(self):
    """Several interleaved str and int keys all round-trip correctly."""

    @tf.function(jit_compile=True)
    def multi_type_dict(x):
      results = {}
      results['str1'] = x
      results[1] = x + 1
      results['str2'] = x + 2
      results[2] = x + 3
      results[3] = x + 4
      results['str3'] = x + 5
      return results

    input_tensor = constant_op.constant(10.0)
    output = multi_type_dict(input_tensor)

    # Verify all keys are present.
    self.assertIn('str1', output)
    self.assertIn('str2', output)
    self.assertIn('str3', output)
    self.assertIn(1, output)
    self.assertIn(2, output)
    self.assertIn(3, output)

    # Verify every value, not just a subset (the original test skipped
    # output[3] and output['str3']).
    self.assertAlmostEqual(output['str1'].numpy(), 10.0)
    self.assertAlmostEqual(output[1].numpy(), 11.0)
    self.assertAlmostEqual(output['str2'].numpy(), 12.0)
    self.assertAlmostEqual(output[2].numpy(), 13.0)
    self.assertAlmostEqual(output[3].numpy(), 14.0)
    self.assertAlmostEqual(output['str3'].numpy(), 15.0)

  def test_nested_mixed_keys(self):
    """Nested dicts with mixed keys flatten and repack correctly."""

    @tf.function(jit_compile=True)
    def nested_mixed_dict(x):
      inner = {
          'inner_str': x,
          100: x + 1
      }
      outer = {
          'outer': inner,
          200: x + 2
      }
      return outer

    input_tensor = constant_op.constant(5.0)
    output = nested_mixed_dict(input_tensor)

    self.assertIn('outer', output)
    self.assertIn(200, output)
    self.assertIn('inner_str', output['outer'])
    self.assertIn(100, output['outer'])

  def test_pack_sequence_as_with_mixed_keys(self):
    """pack_sequence_as assigns flat values in sorted mixed-key order."""
    structure = {'a': 1, 10: 2, 'b': 3, 20: 4}
    flat_sequence = [100, 200, 300, 400]

    packed = nest.pack_sequence_as(structure, flat_sequence)

    self.assertEqual(len(packed), 4)
    # Keys sort as (type name, key): int keys (10, 20) precede str keys
    # ('a', 'b'), so the flat values are assigned in exactly that order.
    # The original test stated this in a comment but never asserted it.
    self.assertEqual(packed[10], 100)
    self.assertEqual(packed[20], 200)
    self.assertEqual(packed['a'], 300)
    self.assertEqual(packed['b'], 400)

  def test_without_xla_still_works(self):
    """Mixed keys must also keep working without jit compilation."""

    @tf.function(jit_compile=False)
    def no_xla_mixed_dict(x):
      results = {}
      results['string_key'] = x
      results[123] = x + 1
      return results

    input_tensor = constant_op.constant([1.0, 2.0])
    output = no_xla_mixed_dict(input_tensor)

    self.assertIn('string_key', output)
    self.assertIn(123, output)

  def test_consistent_ordering(self):
    """Repeated calls produce the same key ordering every time."""

    @tf.function(jit_compile=True)
    def consistent_dict(x):
      results = {}
      results['z'] = x
      results[3] = x + 1
      results['a'] = x + 2
      results[1] = x + 3
      return results

    input_tensor = constant_op.constant(1.0)

    # Call multiple times and verify the sorted key view is stable.
    output1 = consistent_dict(input_tensor)
    output2 = consistent_dict(input_tensor)
    output3 = consistent_dict(input_tensor)

    keys1 = sorted(output1.keys(), key=lambda x: (type(x).__name__, x))
    keys2 = sorted(output2.keys(), key=lambda x: (type(x).__name__, x))
    keys3 = sorted(output3.keys(), key=lambda x: (type(x).__name__, x))

    self.assertEqual(keys1, keys2)
    self.assertEqual(keys2, keys3)


# Standalone entry point: run this file's tests with the TF test runner.
if __name__ == '__main__':
  test.main()
22 changes: 16 additions & 6 deletions tensorflow/python/util/nest_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,19 +272,29 @@ def _tf_core_sorted(dict_):
try:
return sorted(dict_.keys())
except TypeError:
# pylint: disable=raise-missing-from
raise TypeError("nest only supports dicts with sortable keys.")
# If direct sorting fails (e.g., mixed types like int and str),
# try sorting by (type name, key) to group by type first, then by value
try:
return sorted(dict_.keys(), key=lambda x: (type(x).__name__, x))
except TypeError:
# If that still fails, fall back to sorting by string representation
# This ensures deterministic ordering even with complex mixed types
return sorted(dict_.keys(), key=lambda x: (type(x).__name__, str(x)))


def _tf_data_sorted(dict_):
"""Returns a sorted list of the dict keys, with error if keys not sortable."""
try:
return sorted(list(dict_))
except TypeError as e:
# pylint: disable=raise-missing-from
raise TypeError(
f"nest only supports dicts with sortable keys. Error: {e.message}"
)
# If direct sorting fails (e.g., mixed types like int and str),
# try sorting by (type name, key) to group by type first, then by value
try:
return sorted(list(dict_), key=lambda x: (type(x).__name__, x))
except TypeError:
# If that still fails, fall back to sorting by string representation
# This ensures deterministic ordering even with complex mixed types
return sorted(list(dict_), key=lambda x: (type(x).__name__, str(x)))


def yield_value(modality, iterable):
Expand Down
Loading