From ff878bb22a339eb5f04ec35b1774f78f3157cfd9 Mon Sep 17 00:00:00 2001 From: Ajay Date: Tue, 12 May 2026 21:52:02 +0000 Subject: [PATCH 01/20] device: scaffold UMTensor type + is_device_tile specialization Adds the entry point for native TA::Tensor support on devices, paralleling the existing btas_um_tensor path. UMTensor is a TA::Tensor instantiated on device_um_allocator, which already maps to MemorySpace::Device_UM via platform.h:90. * fwd.h: UMTensor typedef under TILEDARRAY_HAS_DEVICE. * device/tensor.h: is_device_tile> partial specialization so the existing pass-through specs for Tile and LazyArrayTile (tensor/type_traits.h) classify UMTensor-backed tiles as device tiles -- this is the gate that routes the expression engine through madness::add_device_task at binary_eval/unary_eval/contraction_eval and the 6 sites in expressions/expr.h. Also adds detail::to_device / detail::to_host prefetch helpers that go directly through device::memPrefetchAsync (TA::Tensor stores data via shared_ptr rather than a varray, so we cannot route through to_execution_space). * device/tensor.cpp: static_asserts pinning down that the trait fires for UMTensor<{double,float,complex}>, propagates through Tile<>, and does not misclassify plain Tensor. Placeholder for the explicit instantiations that will land in Phase 4. * src/CMakeLists.txt: hook the two files into the TILEDARRAY_HAS_HIP OR TILEDARRAY_HAS_CUDA source list, alongside btas_um_tensor.{h,cpp}. No tile-op overloads yet -- Phase 1 is compile-only. --- src/CMakeLists.txt | 2 + src/TiledArray/device/tensor.cpp | 46 ++++++++++++++++++++ src/TiledArray/device/tensor.h | 75 ++++++++++++++++++++++++++++++++ src/TiledArray/fwd.h | 4 ++ 4 files changed, 127 insertions(+) create mode 100644 src/TiledArray/device/tensor.cpp create mode 100644 src/TiledArray/device/tensor.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 397fdc7a9a..23750eefd0 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -234,6 +234,7 @@ if(TILEDARRAY_HAS_HIP OR TILEDARRAY_HAS_CUDA) TiledArray/device/blas.h TiledArray/device/btas.h TiledArray/device/btas_um_tensor.h + TiledArray/device/tensor.h TiledArray/device/device_task_fn.h TiledArray/device/kernel/mult_kernel.h TiledArray/device/kernel/reduce_kernel.h @@ -267,6 +268,7 @@ if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP) set(TILEDARRAY_DEVICE_SOURCE_FILES TiledArray/device/btas_um_tensor.cpp + TiledArray/device/tensor.cpp ) if(TILEDARRAY_HAS_CUDA) diff --git a/src/TiledArray/device/tensor.cpp b/src/TiledArray/device/tensor.cpp new file mode 100644 index 0000000000..d2272628c2 --- /dev/null +++ b/src/TiledArray/device/tensor.cpp @@ -0,0 +1,46 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2026 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * Ajay Melekamburath + * Department of Chemistry, Virginia Tech + */ + +#include + +#include + +namespace TiledArray::detail { + +// Phase 1 sanity: confirm the is_device_tile specialization fires for the +// allocator alias and propagates through Tile<>. +static_assert(is_device_tile_v>, + "UMTensor must be tagged as a device tile"); +static_assert(is_device_tile_v>, + "UMTensor must be tagged as a device tile"); +static_assert( + is_device_tile_v>>, + "UMTensor> must be tagged as a device tile"); +static_assert(is_device_tile_v>>, + "Tile> must propagate the device-tile tag"); +static_assert(!is_device_tile_v>, + "Plain Tensor must not be tagged as a device tile"); + +} // namespace TiledArray::detail + +// Explicit instantiations of UMTensor and its tile-op overloads land here in +// Phase 4 once the overload set is in place. + diff --git a/src/TiledArray/device/tensor.h b/src/TiledArray/device/tensor.h new file mode 100644 index 0000000000..e864f0fcc6 --- /dev/null +++ b/src/TiledArray/device/tensor.h @@ -0,0 +1,75 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2026 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * Ajay Melekamburath + * Department of Chemistry, Virginia Tech + */ + +#ifndef TILEDARRAY_DEVICE_TENSOR_H +#define TILEDARRAY_DEVICE_TENSOR_H + +#include + +#ifdef TILEDARRAY_HAS_DEVICE + +#include +#include +#include +#include + +namespace TiledArray { +namespace detail { + +/// `UMTensor` lives in unified memory; the expression engine must route its +/// tile ops through `madness::add_device_task`. The pass-through specs for +/// `Tile` and `LazyArrayTile` in tensor/type_traits.h pick this up. +template +struct is_device_tile> : public std::true_type {}; + +/// Prefetch a UMTensor's storage to the device associated with its tile range. +/// Mirrors the pattern in device/btas_um_tensor.h but reaches the storage via +/// `.data()` + `.total_size()` since `TA::Tensor`'s buffer is a +/// `shared_ptr` rather than a varray-like container. +template +inline void to_device(const TiledArray::UMTensor& tile) { + if (tile.empty()) return; + auto stream = device::stream_for(tile.range()); + if (deviceEnv::instance()->concurrent_managed_access()) { + DeviceSafeCall(device::memPrefetchAsync(tile.data(), + tile.total_size() * sizeof(T), + stream.device, stream.stream)); + } +} + +/// Prefetch a UMTensor's storage back to the host. +template +inline void to_host(const TiledArray::UMTensor& tile) { + if (tile.empty()) return; + auto stream = device::stream_for(tile.range()); + if (deviceEnv::instance()->concurrent_managed_access()) { + DeviceSafeCall(device::memPrefetchAsync(tile.data(), + tile.total_size() * sizeof(T), + device::CpuDeviceId, stream.stream)); + } +} + +} // namespace detail +} // namespace TiledArray + +#endif // TILEDARRAY_HAS_DEVICE + +#endif // TILEDARRAY_DEVICE_TENSOR_H diff --git a/src/TiledArray/fwd.h b/src/TiledArray/fwd.h index 00c36a5092..70aee401bf 100644 --- a/src/TiledArray/fwd.h +++ b/src/TiledArray/fwd.h @@ -142,6 +142,10 @@ template using btasUMTensorVarray = ::btas::Tensor>; +/// TA::Tensor backed by the unified-memory allocator, usable as a device tile +template +using UMTensor = Tensor>; + #endif // TILEDARRAY_HAS_DEVICE template From 20a35bbc991f2c9d4fc2fbe083836b4f2b242237 Mon Sep 17 00:00:00 2001 From: Ajay Date: Tue, 12 May 2026 21:55:32 +0000 Subject: [PATCH 02/20] device: tier-1 tile-op overloads for UMTensor Adds the smallest tile-op set that the expression engine needs in order to evaluate a `C("ij") = a*A("ik")*B("kj")`-style expression plus a norm on UMTensor: clone, scale/scale_to, neg/neg_to, add/add_to (+ scaled forms), subt/subt_to (+ scaled forms), dot, squared_norm, norm, and gemm (returning + accumulating). Element-wise mult, permute, shift, and batched paths are deliberately not included yet -- they need librett / custom kernels and a clear nbatch story, and dragging them in here would obscure whether the dispatch surface alone is correct. Each overload is a concrete-type free function in `namespace TiledArray` so ADL prefers it over the generic templated forwarders in `tile_op/tile_interface.h`. No constraint relaxation, no member-function revert -- the dispatch falls out of overload partial ordering. The kernel pattern mirrors `device/btas.h`: resolve a BLAS++ queue from `blasqueue_for(range)`, prefetch every input + result to the device, call into BLAS++, then `sync_madness_task_with(stream)` so the wrapping MADNESS device task waits for the queue to drain. We do NOT thread an explicit `blas::Queue&` through composite ops -- `stream_for(range)` in `external/device.h` already returns the current task's stream when invoked inside a device task, which is what we want. For now batched tiles (`nbatch_ > 1`) are asserted away; the expression engine doesn't currently route batched UMTensor through these ops and silently miscomputing would be worse than a clear assert. `device/tensor.cpp` grows a tiny instantiation probe that exercises the full Tier-1 surface for `double` and `float`, so anything that doesn't type-check breaks the build immediately rather than waiting on Phase 5 tests. Real explicit instantiations land in Phase 4. --- src/TiledArray/device/tensor.cpp | 38 ++- src/TiledArray/device/tensor.h | 405 +++++++++++++++++++++++++++++++ 2 files changed, 441 insertions(+), 2 deletions(-) diff --git a/src/TiledArray/device/tensor.cpp b/src/TiledArray/device/tensor.cpp index d2272628c2..962e13c03e 100644 --- a/src/TiledArray/device/tensor.cpp +++ b/src/TiledArray/device/tensor.cpp @@ -41,6 +41,40 @@ static_assert(!is_device_tile_v>, } // namespace TiledArray::detail -// Explicit instantiations of UMTensor and its tile-op overloads land here in -// Phase 4 once the overload set is in place. +// Phase 2 instantiation probes: force the compiler to type-check the +// device-tile overloads. Real explicit instantiations land in Phase 4. +namespace { + +template +void compile_test_tier1() { + using TA::UMTensor; + using helper_t = TiledArray::math::GemmHelper; + UMTensor a, b, c; + helper_t h(TiledArray::math::blas::Op::NoTrans, + TiledArray::math::blas::Op::NoTrans, 2u, 2u, 2u); + + (void)TiledArray::clone(a); + (void)TiledArray::scale(a, T(2)); + (void)TiledArray::scale_to(a, T(2)); + (void)TiledArray::neg(a); + (void)TiledArray::neg_to(a); + (void)TiledArray::add(a, b); + (void)TiledArray::add(a, b, T(2)); + (void)TiledArray::add_to(a, b); + (void)TiledArray::add_to(a, b, T(2)); + (void)TiledArray::subt(a, b); + (void)TiledArray::subt(a, b, T(2)); + (void)TiledArray::subt_to(a, b); + (void)TiledArray::subt_to(a, b, T(2)); + (void)TiledArray::dot(a, b); + (void)TiledArray::squared_norm(a); + (void)TiledArray::norm(a); + (void)TiledArray::gemm(a, b, T(1), h); + TiledArray::gemm(c, a, b, T(1), h); +} + +[[maybe_unused]] auto instantiate_tier1_double = &compile_test_tier1; +[[maybe_unused]] auto instantiate_tier1_float = &compile_test_tier1; + +} // namespace diff --git a/src/TiledArray/device/tensor.h b/src/TiledArray/device/tensor.h index e864f0fcc6..e42e90e73e 100644 --- a/src/TiledArray/device/tensor.h +++ b/src/TiledArray/device/tensor.h @@ -26,11 +26,17 @@ #ifdef TILEDARRAY_HAS_DEVICE +#include #include +#include +#include +#include #include #include #include +#include + namespace TiledArray { namespace detail { @@ -68,6 +74,405 @@ inline void to_host(const TiledArray::UMTensor& tile) { } } // namespace detail + +// --------------------------------------------------------------------------- +// Tile-op overloads for UMTensor. +// +// Each overload sits in `namespace TiledArray` so ADL finds it from the +// expression engine and from the tile_op layer's free-function defaults. +// More-specialized concrete-type overloads win against the generic +// `template ... add(left, right) { return +// left.add(right); }` forwarders in `tile_op/tile_interface.h`, so we never +// fall back to the CPU member functions for UMTensor. +// +// All overloads follow the stream/queue contract: +// 1. Resolve a queue via `blasqueue_for(range)`. Inside a device task this +// is the same queue everyone else in the task uses (see +// `external/device.h:899-907`); outside one, it round-robins. +// 2. Prefetch every input + the result to the device. +// 3. Call into BLAS++ / device kernels on that queue. +// 4. `sync_madness_task_with(stream)` so the enclosing MADNESS device task +// waits for the queue to drain before completing. +// +// For Phase 2 batched tiles (`nbatch_ > 1`) are not yet supported -- the +// expression engine doesn't currently feed batched UMTensor through these +// paths, and dropping the assertion now would silently miscompute. +// --------------------------------------------------------------------------- + +/// result[i] = arg[i] +template +inline UMTensor clone(const UMTensor& arg) { + TA_ASSERT(!arg.empty()); + TA_ASSERT(arg.nbatch() == 1); + + auto& queue = blasqueue_for(arg.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + UMTensor result(arg.range()); + + detail::to_device(arg); + detail::to_device(result); + + blas::copy(result.size(), arg.data(), 1, result.data(), 1, queue); + + device::sync_madness_task_with(stream); + return result; +} + +namespace detail { + +/// Apply a scaling factor in-place on the device, replicating the +/// ComplexConjugate handling from device/btas.h::scale. Real-valued kernels +/// reduce to a single `blas::scal`; conjugation+scale on complex tiles +/// requires a custom kernel that we have not implemented yet. +template +inline void apply_scale_factor(T* data, std::size_t n, const Scalar factor, + ::blas::Queue& queue) { + if constexpr (TiledArray::detail::is_blas_numeric_v || + std::is_arithmetic_v) { + ::blas::scal(n, factor, data, 1, queue); + } else { + if constexpr (TiledArray::detail::is_complex_v) { + abort(); // fused conjugation requires custom kernels, not yet supported + } else { + if constexpr (std::is_same_v< + Scalar, TiledArray::detail::ComplexConjugate>) { + // conjugation on a real tensor is a no-op + } else if constexpr (std::is_same_v< + Scalar, + TiledArray::detail::ComplexConjugate< + TiledArray::detail::ComplexNegTag>>) { + ::blas::scal(n, static_cast(-1), data, 1, queue); + } + } + } +} + +} // namespace detail + +/// result[i] = arg[i] * factor +template >> +inline UMTensor scale(const UMTensor& arg, const Scalar factor) { + auto result = clone(arg); + auto& queue = blasqueue_for(result.range()); + const device::Stream stream(queue.device(), queue.stream()); + detail::apply_scale_factor(result.data(), result.size(), factor, queue); + device::sync_madness_task_with(stream); + return result; +} + +/// result[i] *= factor (in-place) +template >> +inline UMTensor& scale_to(UMTensor& result, const Scalar factor) { + TA_ASSERT(!result.empty()); + TA_ASSERT(result.nbatch() == 1); + auto& queue = blasqueue_for(result.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + detail::to_device(result); + detail::apply_scale_factor(result.data(), result.size(), factor, queue); + device::sync_madness_task_with(stream); + return result; +} + +/// result[i] = -arg[i] +template +inline UMTensor neg(const UMTensor& arg) { + return scale(arg, T(-1)); +} + +/// arg[i] = -arg[i] (in-place) +template +inline UMTensor& neg_to(UMTensor& arg) { + return scale_to(arg, T(-1)); +} + +/// result[i] = arg1[i] + arg2[i] +template +inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2) { + TA_ASSERT(!arg1.empty()); + TA_ASSERT(!arg2.empty()); + TA_ASSERT(arg1.nbatch() == 1 && arg2.nbatch() == 1); + + auto& queue = blasqueue_for(arg1.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + UMTensor result(arg1.range()); + + detail::to_device(arg1); + detail::to_device(arg2); + detail::to_device(result); + + ::blas::copy(result.size(), arg1.data(), 1, result.data(), 1, queue); + ::blas::axpy(result.size(), T(1), arg2.data(), 1, result.data(), 1, queue); + + device::sync_madness_task_with(stream); + return result; +} + +/// result[i] = (arg1[i] + arg2[i]) * factor +template >> +inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2, + const Scalar factor) { + auto result = add(arg1, arg2); + return scale_to(result, factor); +} + +/// result[i] += arg[i] +template +inline UMTensor& add_to(UMTensor& result, const UMTensor& arg) { + TA_ASSERT(!result.empty()); + TA_ASSERT(!arg.empty()); + TA_ASSERT(result.nbatch() == 1 && arg.nbatch() == 1); + + auto& queue = blasqueue_for(result.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + detail::to_device(result); + detail::to_device(arg); + + ::blas::axpy(result.size(), T(1), arg.data(), 1, result.data(), 1, queue); + + device::sync_madness_task_with(stream); + return result; +} + +/// result[i] += arg[i] * factor +template >> +inline UMTensor& add_to(UMTensor& result, const UMTensor& arg, + const Scalar factor) { + TA_ASSERT(!result.empty()); + TA_ASSERT(!arg.empty()); + TA_ASSERT(result.nbatch() == 1 && arg.nbatch() == 1); + + auto& queue = blasqueue_for(result.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + detail::to_device(result); + detail::to_device(arg); + + ::blas::axpy(result.size(), T(factor), arg.data(), 1, result.data(), 1, + queue); + + device::sync_madness_task_with(stream); + return result; +} + +/// result[i] = arg1[i] - arg2[i] +template +inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2) { + TA_ASSERT(!arg1.empty()); + TA_ASSERT(!arg2.empty()); + TA_ASSERT(arg1.nbatch() == 1 && arg2.nbatch() == 1); + + auto& queue = blasqueue_for(arg1.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + UMTensor result(arg1.range()); + + detail::to_device(arg1); + detail::to_device(arg2); + detail::to_device(result); + + ::blas::copy(result.size(), arg1.data(), 1, result.data(), 1, queue); + ::blas::axpy(result.size(), T(-1), arg2.data(), 1, result.data(), 1, queue); + + device::sync_madness_task_with(stream); + return result; +} + +/// result[i] = (arg1[i] - arg2[i]) * factor +template >> +inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2, + const Scalar factor) { + auto result = subt(arg1, arg2); + return scale_to(result, factor); +} + +/// result[i] -= arg[i] +template +inline UMTensor& subt_to(UMTensor& result, const UMTensor& arg) { + TA_ASSERT(!result.empty()); + TA_ASSERT(!arg.empty()); + TA_ASSERT(result.nbatch() == 1 && arg.nbatch() == 1); + + auto& queue = blasqueue_for(result.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + detail::to_device(result); + detail::to_device(arg); + + ::blas::axpy(result.size(), T(-1), arg.data(), 1, result.data(), 1, queue); + + device::sync_madness_task_with(stream); + return result; +} + +/// result[i] -= arg[i] * factor +template >> +inline UMTensor& subt_to(UMTensor& result, const UMTensor& arg, + const Scalar factor) { + TA_ASSERT(!result.empty()); + TA_ASSERT(!arg.empty()); + TA_ASSERT(result.nbatch() == 1 && arg.nbatch() == 1); + + auto& queue = blasqueue_for(result.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + detail::to_device(result); + detail::to_device(arg); + + ::blas::axpy(result.size(), T(-factor), arg.data(), 1, result.data(), 1, + queue); + + device::sync_madness_task_with(stream); + return result; +} + +/// dot product: scalar = sum_i arg1[i] * arg2[i] +template +inline T dot(const UMTensor& arg1, const UMTensor& arg2) { + TA_ASSERT(!arg1.empty()); + TA_ASSERT(!arg2.empty()); + TA_ASSERT(arg1.nbatch() == 1 && arg2.nbatch() == 1); + TA_ASSERT(arg1.size() == arg2.size()); + + auto& queue = blasqueue_for(arg1.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + detail::to_device(arg1); + detail::to_device(arg2); + + T result(0); + ::blas::dot(arg1.size(), arg1.data(), 1, arg2.data(), 1, &result, queue); + + device::sync_madness_task_with(stream); + return result; +} + +/// scalar = sum_i arg[i] * arg[i] +template +inline auto squared_norm(const UMTensor& arg) { + return dot(arg, arg); +} + +/// scalar = sqrt(squared_norm(arg)) +template +inline auto norm(const UMTensor& arg) { + using std::sqrt; + using ResultType = TiledArray::detail::scalar_t; + return static_cast(sqrt(squared_norm(arg))); +} + +/// gemm: returning form. result = factor * left * right +template >> +inline UMTensor gemm(const UMTensor& left, const UMTensor& right, + const Scalar factor, + const TiledArray::math::GemmHelper& gemm_helper) { + TA_ASSERT(!left.empty()); + TA_ASSERT(!right.empty()); + TA_ASSERT(left.range().rank() == gemm_helper.left_rank()); + TA_ASSERT(right.range().rank() == gemm_helper.right_rank()); + TA_ASSERT(left.nbatch() == 1 && right.nbatch() == 1); + + auto result_range = gemm_helper.template make_result_range( + left.range(), right.range()); + + auto& queue = blasqueue_for(result_range); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + UMTensor result(result_range); + TA_ASSERT(result.nbatch() == 1); + + detail::to_device(left); + detail::to_device(right); + detail::to_device(result); + + using TiledArray::math::blas::integer; + integer m = 1, n = 1, k = 1; + gemm_helper.compute_matrix_sizes(m, n, k, left.range(), right.range()); + + const integer lda = std::max( + integer{1}, + gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m); + const integer ldb = std::max( + integer{1}, + gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k); + const integer ldc = std::max(integer{1}, n); + + const T factor_t = T(factor); + const T zero(0); + + // Match btas device gemm (device/btas.h): col-major view with right/left + // swapped reproduces TA::Tensor's row-major layout under cublas. + ::blas::gemm(::blas::Layout::ColMajor, gemm_helper.right_op(), + gemm_helper.left_op(), n, m, k, factor_t, right.data(), ldb, + left.data(), lda, zero, result.data(), ldc, queue); + + device::sync_madness_task_with(stream); + return result; +} + +/// gemm: accumulating form. result += factor * left * right +template >> +inline void gemm(UMTensor& result, const UMTensor& left, + const UMTensor& right, const Scalar factor, + const TiledArray::math::GemmHelper& gemm_helper) { + TA_ASSERT(!result.empty()); + TA_ASSERT(!left.empty()); + TA_ASSERT(!right.empty()); + TA_ASSERT(result.range().rank() == gemm_helper.result_rank()); + TA_ASSERT(left.range().rank() == gemm_helper.left_rank()); + TA_ASSERT(right.range().rank() == gemm_helper.right_rank()); + TA_ASSERT(left.nbatch() == 1 && right.nbatch() == 1 && result.nbatch() == 1); + + auto& queue = blasqueue_for(result.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + detail::to_device(left); + detail::to_device(right); + detail::to_device(result); + + using TiledArray::math::blas::integer; + integer m = 1, n = 1, k = 1; + gemm_helper.compute_matrix_sizes(m, n, k, left.range(), right.range()); + + const integer lda = std::max( + integer{1}, + gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m); + const integer ldb = std::max( + integer{1}, + gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k); + const integer ldc = std::max(integer{1}, n); + + const T factor_t = T(factor); + const T one(1); + + ::blas::gemm(::blas::Layout::ColMajor, gemm_helper.right_op(), + gemm_helper.left_op(), n, m, k, factor_t, right.data(), ldb, + left.data(), lda, one, result.data(), ldc, queue); + + device::sync_madness_task_with(stream); +} + } // namespace TiledArray #endif // TILEDARRAY_HAS_DEVICE From f8e870ead52832cf86bcb958bb10e9e0a07a3cab Mon Sep 17 00:00:00 2001 From: Ajay Date: Tue, 12 May 2026 21:57:09 +0000 Subject: [PATCH 03/20] device: add permute / shift / mult + perm-variants for UMTensor Rounds out Tier 1 with the ops that require librett or the device element-wise kernel: * permute(arg, Permutation) -- librett_permute on the tile data * permute(arg, BipartitePermutation)-- forwards to outer(perm); required to win ADL against the generic member-delegating overload (the analogous comment lives in device/btas_um_tensor.h:193) * shift(arg, bound_shift) -- copy + Range::inplace_shift * shift_to(arg, bound_shift) -- in-place range shift, no copy * mult / mult_to (+ scaled + permuted variants) via device::mult_kernel and device::mult_to_kernel * scale(a, f, perm), neg(a, perm), add(..., perm), subt(..., perm), and their scaled-with-perm forms -- thin compositions over the non-permuted core and the new permute All sit in `namespace TiledArray`; ADL wins over the `tile_op/tile_interface.h` defaults the same way it does for the non-permuted core. The instantiation probe in device/tensor.cpp now exercises the full Tier-1 surface (Phase 2a + 2b). --- src/TiledArray/device/tensor.cpp | 21 ++++ src/TiledArray/device/tensor.h | 203 +++++++++++++++++++++++++++++++ 2 files changed, 224 insertions(+) diff --git a/src/TiledArray/device/tensor.cpp b/src/TiledArray/device/tensor.cpp index 962e13c03e..25c643c5a4 100644 --- a/src/TiledArray/device/tensor.cpp +++ b/src/TiledArray/device/tensor.cpp @@ -71,6 +71,27 @@ void compile_test_tier1() { (void)TiledArray::norm(a); (void)TiledArray::gemm(a, b, T(1), h); TiledArray::gemm(c, a, b, T(1), h); + + // Phase 2b: permute / shift / mult and the perm-variants. + TiledArray::Permutation perm(std::vector{1, 0}); + TiledArray::BipartitePermutation bperm(perm); + std::vector shift{0, 0}; + (void)TiledArray::permute(a, perm); + (void)TiledArray::permute(a, bperm); + (void)TiledArray::shift(a, shift); + (void)TiledArray::shift_to(a, shift); + (void)TiledArray::scale(a, T(2), perm); + (void)TiledArray::neg(a, perm); + (void)TiledArray::add(a, b, perm); + (void)TiledArray::add(a, b, T(2), perm); + (void)TiledArray::subt(a, b, perm); + (void)TiledArray::subt(a, b, T(2), perm); + (void)TiledArray::mult(a, b); + (void)TiledArray::mult(a, b, T(2)); + (void)TiledArray::mult(a, b, perm); + (void)TiledArray::mult(a, b, T(2), perm); + (void)TiledArray::mult_to(a, b); + (void)TiledArray::mult_to(a, b, T(2)); } [[maybe_unused]] auto instantiate_tier1_double = &compile_test_tier1; diff --git a/src/TiledArray/device/tensor.h b/src/TiledArray/device/tensor.h index e42e90e73e..1c0eb226bb 100644 --- a/src/TiledArray/device/tensor.h +++ b/src/TiledArray/device/tensor.h @@ -27,9 +27,12 @@ #ifdef TILEDARRAY_HAS_DEVICE #include +#include #include +#include #include #include +#include #include #include #include @@ -378,6 +381,206 @@ inline auto norm(const UMTensor& arg) { return static_cast(sqrt(squared_norm(arg))); } +/// result[perm(i)] = arg[i] +template +inline UMTensor permute(const UMTensor& arg, + const TiledArray::Permutation& perm) { + TA_ASSERT(!arg.empty()); + TA_ASSERT(arg.nbatch() == 1); + TA_ASSERT(perm.size() == arg.range().rank()); + + auto result_range = perm * arg.range(); + auto& queue = blasqueue_for(result_range); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + UMTensor result(result_range); + + detail::to_device(arg); + detail::to_device(result); + + // librett operates on the original (unpermuted) range and writes into the + // permuted layout; pointers go in as-is. + librett_permute(const_cast(arg.data()), result.data(), arg.range(), perm, + stream.stream); + + device::sync_madness_task_with(stream); + return result; +} + +/// BipartitePermutation -> plain Permutation forward. +/// Required to win ADL against the generic CPU member-delegating overload; +/// see the matching warning in device/btas_um_tensor.h:193. +template +inline UMTensor permute(const UMTensor& arg, + const TiledArray::BipartitePermutation& perm) { + TA_ASSERT(inner_size(perm) == 0); // UMTensor is a non-nested tile + return permute(arg, outer(perm)); +} + +/// result[perm(i)] = arg[i] * factor +template && + TiledArray::detail::is_permutation_v>> +inline UMTensor scale(const UMTensor& arg, const Scalar factor, + const Perm& perm) { + auto scaled = scale(arg, factor); + return permute(scaled, perm); +} + +/// result[perm(i)] = -arg[i] +template >> +inline UMTensor neg(const UMTensor& arg, const Perm& perm) { + return permute(neg(arg), perm); +} + +/// result[perm(i)] = arg1[i] + arg2[i] +template >> +inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2, + const Perm& perm) { + return permute(add(arg1, arg2), perm); +} + +/// result[perm(i)] = (arg1[i] + arg2[i]) * factor +template && + TiledArray::detail::is_permutation_v>> +inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2, + const Scalar factor, const Perm& perm) { + return permute(add(arg1, arg2, factor), perm); +} + +/// result[perm(i)] = arg1[i] - arg2[i] +template >> +inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2, + const Perm& perm) { + return permute(subt(arg1, arg2), perm); +} + +/// result[perm(i)] = (arg1[i] - arg2[i]) * factor +template && + TiledArray::detail::is_permutation_v>> +inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2, + const Scalar factor, const Perm& perm) { + return permute(subt(arg1, arg2, factor), perm); +} + +/// shift: result has arg's data, range shifted by bound_shift. +template +inline UMTensor shift(const UMTensor& arg, const Index& bound_shift) { + TA_ASSERT(!arg.empty()); + TA_ASSERT(arg.nbatch() == 1); + + TiledArray::Range result_range(arg.range()); + result_range.inplace_shift(bound_shift); + + auto& queue = blasqueue_for(result_range); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + UMTensor result(result_range); + + detail::to_device(arg); + detail::to_device(result); + + ::blas::copy(result.size(), arg.data(), 1, result.data(), 1, queue); + + device::sync_madness_task_with(stream); + return result; +} + +/// shift_to: in-place range shift, no data movement. +template +inline UMTensor& shift_to(UMTensor& arg, const Index& bound_shift) { + const_cast(arg.range()).inplace_shift(bound_shift); + return arg; +} + +/// result[i] = arg1[i] * arg2[i] (element-wise / Hadamard) +template +inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2) { + TA_ASSERT(!arg1.empty()); + TA_ASSERT(!arg2.empty()); + TA_ASSERT(arg1.size() == arg2.size()); + TA_ASSERT(arg1.nbatch() == 1 && arg2.nbatch() == 1); + + auto& queue = blasqueue_for(arg1.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + UMTensor result(arg1.range()); + + detail::to_device(arg1); + detail::to_device(arg2); + detail::to_device(result); + + device::mult_kernel(result.data(), arg1.data(), arg2.data(), arg1.size(), + stream); + + device::sync_madness_task_with(stream); + return result; +} + +/// result[i] = arg1[i] * arg2[i] * factor +template >> +inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2, + const Scalar factor) { + auto result = mult(arg1, arg2); + return scale_to(result, factor); +} + +/// result[perm(i)] = arg1[i] * arg2[i] +template >> +inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2, + const Perm& perm) { + return permute(mult(arg1, arg2), perm); +} + +/// result[perm(i)] = arg1[i] * arg2[i] * factor +template && + TiledArray::detail::is_permutation_v>> +inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2, + const Scalar factor, const Perm& perm) { + return permute(mult(arg1, arg2, factor), perm); +} + +/// result[i] *= arg[i] +template +inline UMTensor& mult_to(UMTensor& result, const UMTensor& arg) { + TA_ASSERT(!result.empty()); + TA_ASSERT(!arg.empty()); + TA_ASSERT(result.size() == arg.size()); + TA_ASSERT(result.nbatch() == 1 && arg.nbatch() == 1); + + auto& queue = blasqueue_for(result.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + detail::to_device(result); + detail::to_device(arg); + + device::mult_to_kernel(result.data(), arg.data(), result.size(), stream); + + device::sync_madness_task_with(stream); + return result; +} + +/// result[i] *= arg[i] * factor +template >> +inline UMTensor& mult_to(UMTensor& result, const UMTensor& arg, + const Scalar factor) { + mult_to(result, arg); + return scale_to(result, factor); +} + /// gemm: returning form. result = factor * left * right template >> From ca5c9df0ffb5001495e877a94796b9fff34d4910 Mon Sep 17 00:00:00 2001 From: Ajay Date: Tue, 12 May 2026 22:45:18 +0000 Subject: [PATCH 04/20] device: fix in-place op dispatch + correct scale semantics, add expression tests Three bug categories surfaced once expressions actually ran end-to-end: 1. **Scale-factor semantics on `*_to(arg, factor)` were inverted.** `TA::Tensor::add_to(r, factor)` is `(this += r) *= factor`, and `TA::Tensor::subt_to(r, factor)` is `(this -= r) *= factor` -- the scaling applies to the result of the in-place op, not to `arg` alone. My initial implementations computed `result += factor * arg` and `result -= factor * arg`, which gives wrong values whenever the engine calls `subt_to(std::move(second), first, -1)` from `tile_op/subt.h:117` (the "right is consumable" branch -- expected to compute `(second - first) * -1 = first - second`). The new implementations chain `add_to/subt_to(no factor)` + `scale_to`, matching TA::Tensor's `(l += r) *= factor` pattern. 2. **Engine passes consumable operands as rvalues.** Subt::eval and ScalSubt::eval at `tile_op/subt.h:117,302` pass the result tile via `std::move(...)`. A plain `UMTensor&` overload is not a viable candidate for an rvalue argument, so overload resolution falls through to the generic forwarder in `tile_op/tile_interface.h` (and `tile_interface/scale.h`) that delegates to TA::Tensor's CPU member function. The CPU member reads UM memory while the prior device kernel is still in flight on the queue -- silently miscomputing. Every in-place op (scale_to, neg_to, add_to[x2], subt_to[x2], mult_to[x2], shift_to) now has two concrete-type overloads: `UMTensor&` and `UMTensor&&`. Concrete types beat the templated forwarding reference in partial ordering regardless of constraint shape -- a constrained forwarding-ref version would in principle also win (constrained > unconstrained per [temp.constr. order]) but g++ does not consistently treat tile_interface's `enable_if`-only forwarders as unconstrained for this purpose; the two-concrete-overload form is robust. 3. **No correctness coverage existed for the dispatch path.** The previous `tensor_device.cpp` exercised tile-level ops directly from the main thread, which is the wrong contract: tile ops cooperate with the enclosing `madness::add_device_task` for stream sync and never run that way in production. Replaced with `tests/expressions_device_tensor.cpp`, mirroring the structure of `expressions_device_um.cpp` (the existing btas device test). 18 cases including: trait classification, direct assign, permute, scale, neg, add/subt (+ with-permute / -to / -with-factor variants), scaled-subt isolations on left/right, mixed linear combination (catches bug #1), Hadamard, contraction, norm2 / dot reductions, and a `reuse_stress` case repeating `dot(a, a)` 8x (catches the LazyArrayTile conversion race in MPQC's pattern -- expected to be a master-branch baseline failure not introduced here). All 18 pass. Also documents the `UMTensorArg` concept inline as the marker for "this is a UMTensor (any cv/ref qualifier)" -- kept around as documentation of intent even where the dispatch tiebreak forced us to use concrete-type overloads instead. --- src/TiledArray/device/tensor.h | 167 +++++++++++++----- tests/CMakeLists.txt | 2 +- tests/expressions_device_tensor.cpp | 264 ++++++++++++++++++++++++++++ 3 files changed, 390 insertions(+), 43 deletions(-) create mode 100644 tests/expressions_device_tensor.cpp diff --git a/src/TiledArray/device/tensor.h b/src/TiledArray/device/tensor.h index 1c0eb226bb..185db47813 100644 --- a/src/TiledArray/device/tensor.h +++ b/src/TiledArray/device/tensor.h @@ -78,6 +78,51 @@ inline void to_host(const TiledArray::UMTensor& tile) { } // namespace detail +// --------------------------------------------------------------------------- +// In-place tile ops are tricky to dispatch correctly. +// +// `tile_op/{subt,add,mult,...}.h::Op::eval` passes the result via +// `std::move(...)` when the operand is consumable -- so the engine calls our +// `subt_to`, `add_to`, etc. with an rvalue. A plain `UMTensor&` overload +// is not a viable candidate for an rvalue, so overload resolution falls +// through to the generic forwarder in `tile_op/tile_interface.h` (and +// `tile_interface/scale.h`). That forwarder delegates to TA::Tensor's CPU +// member function, which then reads UM memory while the previous device +// kernel is still in flight on the queue -- silently miscomputing. +// +// To win the dispatch we provide two concrete-type overloads per in-place +// op: one taking `UMTensor&` and one taking `UMTensor&&`. Concrete +// types beat the templated forwarding reference `Result&&` in partial +// ordering regardless of the SFINAE / `requires` constraint shape, so this +// is robust against compiler differences. (A single forwarding-ref overload +// constrained with a `requires UMTensorArg<...>` concept would in principle +// also win because a constrained template subsumes an unconstrained one, +// but g++ does not consistently treat tile_interface's `enable_if`-only +// templates as unconstrained for this purpose -- the result is an ambiguous +// overload error. The two-concrete-overload form sidesteps the question.) +// +// The lvalue overload forwards to the rvalue overload to keep a single +// implementation per op. Value-returning overloads (e.g. +// `add(const UMTensor&, const UMTensor&)`) don't need this because +// reference-to-const binds to both lvalues and rvalues. +// +// The `UMTensorArg` concept is kept around as documentation of intent and +// as a clean handle for any future helper that genuinely wants forwarding +// references (e.g. a `to_device` overload set). +// --------------------------------------------------------------------------- +namespace detail { +template +struct is_um_tensor : std::false_type {}; +template +struct is_um_tensor> : std::true_type {}; +template +inline constexpr bool is_um_tensor_v = + is_um_tensor>>::value; +} // namespace detail + +template +concept UMTensorArg = detail::is_um_tensor_v; + // --------------------------------------------------------------------------- // Tile-op overloads for UMTensor. // @@ -166,9 +211,12 @@ inline UMTensor scale(const UMTensor& arg, const Scalar factor) { return result; } -/// result[i] *= factor (in-place) -template >> +/// result[i] *= factor (in-place). Forwarding-reference form so the engine's +/// `scale_to(std::move(tile), factor)` (from `tile_op/scal.h:82`) dispatches +/// here rather than to the tile_interface forwarder that would call the CPU +/// member function on UM memory. +template + requires TiledArray::detail::is_numeric_v inline UMTensor& scale_to(UMTensor& result, const Scalar factor) { TA_ASSERT(!result.empty()); TA_ASSERT(result.nbatch() == 1); @@ -181,6 +229,12 @@ inline UMTensor& scale_to(UMTensor& result, const Scalar factor) { return result; } +template + requires TiledArray::detail::is_numeric_v +inline UMTensor& scale_to(UMTensor&& result, const Scalar factor) { + return scale_to(result, factor); +} + /// result[i] = -arg[i] template inline UMTensor neg(const UMTensor& arg) { @@ -193,6 +247,11 @@ inline UMTensor& neg_to(UMTensor& arg) { return scale_to(arg, T(-1)); } +template +inline UMTensor& neg_to(UMTensor&& arg) { + return scale_to(arg, T(-1)); +} + /// result[i] = arg1[i] + arg2[i] template inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2) { @@ -246,27 +305,26 @@ inline UMTensor& add_to(UMTensor& result, const UMTensor& arg) { return result; } -/// result[i] += arg[i] * factor -template >> +template +inline UMTensor& add_to(UMTensor&& result, const UMTensor& arg) { + return add_to(result, arg); +} + +/// result[i] = (result[i] + arg[i]) * factor +/// Matches TA::Tensor::add_to(right, factor) semantics: `(l += r) *= factor`. +template + requires TiledArray::detail::is_numeric_v inline UMTensor& add_to(UMTensor& result, const UMTensor& arg, const Scalar factor) { - TA_ASSERT(!result.empty()); - TA_ASSERT(!arg.empty()); - TA_ASSERT(result.nbatch() == 1 && arg.nbatch() == 1); - - auto& queue = blasqueue_for(result.range()); - const device::Stream stream(queue.device(), queue.stream()); - DeviceSafeCall(device::setDevice(stream.device)); - - detail::to_device(result); - detail::to_device(arg); - - ::blas::axpy(result.size(), T(factor), arg.data(), 1, result.data(), 1, - queue); + add_to(result, arg); + return scale_to(result, factor); +} - device::sync_madness_task_with(stream); - return result; +template + requires TiledArray::detail::is_numeric_v +inline UMTensor& add_to(UMTensor&& result, const UMTensor& arg, + const Scalar factor) { + return add_to(result, arg, factor); } /// result[i] = arg1[i] - arg2[i] @@ -322,27 +380,34 @@ inline UMTensor& subt_to(UMTensor& result, const UMTensor& arg) { return result; } -/// result[i] -= arg[i] * factor -template >> +template +inline UMTensor& subt_to(UMTensor&& result, const UMTensor& arg) { + return subt_to(result, arg); +} + +/// result[i] = (result[i] - arg[i]) * factor +/// Matches TA::Tensor::subt_to(right, factor) semantics: `(l -= r) *= factor`. +/// This convention is load-bearing for `tile_op/subt.h::Subt::eval` -- when +/// the engine reuses the right operand's storage, it calls +/// `subt_to(std::move(second), first, -1)` and relies on the result being +/// `(second - first) * -1 = first - second`. Hence the forwarding reference +/// on `result`: lvalue-only signatures lose overload resolution to the +/// templated forwarder in tile_op/tile_interface.h, which then dispatches to +/// TA::Tensor's CPU member function and races with any in-flight device +/// kernel on UM memory. +template + requires TiledArray::detail::is_numeric_v inline UMTensor& subt_to(UMTensor& result, const UMTensor& arg, const Scalar factor) { - TA_ASSERT(!result.empty()); - TA_ASSERT(!arg.empty()); - TA_ASSERT(result.nbatch() == 1 && arg.nbatch() == 1); - - auto& queue = blasqueue_for(result.range()); - const device::Stream stream(queue.device(), queue.stream()); - DeviceSafeCall(device::setDevice(stream.device)); - - detail::to_device(result); - detail::to_device(arg); - - ::blas::axpy(result.size(), T(-factor), arg.data(), 1, result.data(), 1, - queue); + subt_to(result, arg); + return scale_to(result, factor); +} - device::sync_madness_task_with(stream); - return result; +template + requires TiledArray::detail::is_numeric_v +inline UMTensor& subt_to(UMTensor&& result, const UMTensor& arg, + const Scalar factor) { + return subt_to(result, arg, factor); } /// dot product: scalar = sum_i arg1[i] * arg2[i] @@ -500,6 +565,11 @@ inline UMTensor& shift_to(UMTensor& arg, const Index& bound_shift) { return arg; } +template +inline UMTensor& shift_to(UMTensor&& arg, const Index& bound_shift) { + return shift_to(arg, bound_shift); +} + /// result[i] = arg1[i] * arg2[i] (element-wise / Hadamard) template inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2) { @@ -572,15 +642,28 @@ inline UMTensor& mult_to(UMTensor& result, const UMTensor& arg) { return result; } -/// result[i] *= arg[i] * factor -template >> +template +inline UMTensor& mult_to(UMTensor&& result, const UMTensor& arg) { + return mult_to(result, arg); +} + +/// result[i] = (result[i] * arg[i]) * factor +/// Matches TA::Tensor::mult_to(right, factor) semantics: `(l *= r) *= factor`. +template + requires TiledArray::detail::is_numeric_v inline UMTensor& mult_to(UMTensor& result, const UMTensor& arg, const Scalar factor) { mult_to(result, arg); return scale_to(result, factor); } +template + requires TiledArray::detail::is_numeric_v +inline UMTensor& mult_to(UMTensor&& result, const UMTensor& arg, + const Scalar factor) { + return mult_to(result, arg, factor); +} + /// gemm: returning form. result = factor * left * right template >> diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index a30770fb18..ee3df0196a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -105,7 +105,7 @@ set(ta_test_src_files ta_test.cpp ) if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP) - list(APPEND ta_test_src_files librett.cpp expressions_device_um.cpp tensor_um.cpp) + list(APPEND ta_test_src_files librett.cpp expressions_device_um.cpp tensor_um.cpp expressions_device_tensor.cpp) endif() # if using C++20 must use Boost 1.74 or later: diff --git a/tests/expressions_device_tensor.cpp b/tests/expressions_device_tensor.cpp new file mode 100644 index 0000000000..204a7a34f1 --- /dev/null +++ b/tests/expressions_device_tensor.cpp @@ -0,0 +1,264 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2026 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * Ajay Melekamburath + * Department of Chemistry, Virginia Tech + */ + +#include + +#ifdef TILEDARRAY_HAS_DEVICE + +#include +#include +#include +#include "unit_test_config.h" + +using namespace TiledArray; + +// Expression-engine tests for the native UMTensor tile type (TA::Tensor +// backed by device_um_allocator). The pattern follows expressions_device_um.cpp +// but uses the bare TA::Tensor specialization -- TA::Tensor is already +// shallow-copy, so we do not wrap it in TA::Tile<> (per CLAUDE.md guidance). +// +// All correctness checks use a CPU-side TiledArray::Tensor mirror +// of the input arrays (built from `find().get()` on the device side and a +// flat std::vector for reference). The expression runs through the engine +// for both sides; we then compare elements after `gop.fence()` to make sure +// the device kernels have actually completed. + +struct DeviceTensorExpressionsFixture : public TiledRangeFixture { + using TileD = UMTensor; + using TArrayD = TiledArray::DistArray; + + using HostTile = TiledArray::Tensor; + using HostArray = TiledArray::DistArray; + + static constexpr double tolerance = 5.0e-14; + + DeviceTensorExpressionsFixture() + : a(*GlobalFixture::world, tr), + b(*GlobalFixture::world, tr), + c(*GlobalFixture::world, tr), + a_h(*GlobalFixture::world, tr), + b_h(*GlobalFixture::world, tr), + c_h(*GlobalFixture::world, tr) { + fill_with_seed(a, a_h, 7); + fill_with_seed(b, b_h, 11); + GlobalFixture::world->gop.fence(); + } + + ~DeviceTensorExpressionsFixture() { GlobalFixture::world->gop.fence(); } + + // Fill paired device + host arrays with the same deterministic data so + // the host array is an exact reference for the device expression result. + template + static void fill_with_seed(DeviceArray& d, HostArrayT& h, int seed) { + auto pmap_d = d.pmap(); + for (auto it = pmap_d->begin(); it != pmap_d->end(); ++it) { + const auto tile_range = d.trange().make_tile_range(*it); + const auto vol = tile_range.volume(); + + // Build deterministic data so seeds match across allocators. + const auto ord = *it; + typename DeviceArray::value_type d_tile(tile_range); + typename HostArrayT::value_type h_tile(tile_range); + for (std::size_t k = 0; k < vol; ++k) { + // 1000-element period is plenty for unit testing; division keeps + // values in [-5, 5] so dot products stay representable. + const double v = + static_cast(((ord + 1) * 1664525u + seed + k) % 1000) / + 100.0 - + 5.0; + d_tile.data()[k] = v; + h_tile.data()[k] = v; + } + d.set(*it, d_tile); + h.set(*it, h_tile); + } + } + + // Compare every element of two DistArrays with matching tiles. + template + static void check_close(const DeviceArrayT& d, const HostArrayT& h_ref, + double tol) { + GlobalFixture::world->gop.fence(); + for (auto it = d.begin(); it != d.end(); ++it) { + auto d_tile = it->get(); + auto h_tile = h_ref.find(it.index()).get(); + BOOST_REQUIRE_EQUAL(d_tile.range(), h_tile.range()); + for (std::size_t k = 0; k < d_tile.size(); ++k) { + BOOST_CHECK_CLOSE_FRACTION(d_tile.data()[k], h_tile.data()[k], tol); + } + } + } + + TArrayD a, b, c; + HostArray a_h, b_h, c_h; +}; + +BOOST_FIXTURE_TEST_SUITE(device_tensor_expressions_suite, + DeviceTensorExpressionsFixture) + +BOOST_AUTO_TEST_CASE(is_device_tile_classification) { + using detail::is_device_tile_v; + BOOST_CHECK(is_device_tile_v>); + BOOST_CHECK(is_device_tile_v>); + BOOST_CHECK(is_device_tile_v); + BOOST_CHECK(!is_device_tile_v); +} + +BOOST_AUTO_TEST_CASE(direct_assign) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c")); + c_h("a,b,c") = a_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(permute) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("c,b,a")); + c_h("a,b,c") = a_h("c,b,a"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scale) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 2.5 * a("a,b,c")); + c_h("a,b,c") = 2.5 * a_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(neg) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = -a("a,b,c")); + c_h("a,b,c") = -a_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(add) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") + b("a,b,c")); + c_h("a,b,c") = a_h("a,b,c") + b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(add_with_permute) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("c,b,a") + b("a,b,c")); + c_h("a,b,c") = a_h("c,b,a") + b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(add_to) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c")); + c_h("a,b,c") = a_h("a,b,c"); + BOOST_REQUIRE_NO_THROW(c("a,b,c") += b("a,b,c")); + c_h("a,b,c") += b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(subt) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") - b("a,b,c")); + c_h("a,b,c") = a_h("a,b,c") - b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(subt_to) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c")); + c_h("a,b,c") = a_h("a,b,c"); + BOOST_REQUIRE_NO_THROW(c("a,b,c") -= b("a,b,c")); + c_h("a,b,c") -= b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scaled_subt_right) { + // Isolate: scale-on-right only. `c = a - 3*b`. + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") - 3.0 * b("a,b,c")); + c_h("a,b,c") = a_h("a,b,c") - 3.0 * b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scaled_subt_left) { + // Isolate: scale-on-left only. `c = 2*a - b`. + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 2.0 * a("a,b,c") - b("a,b,c")); + c_h("a,b,c") = 2.0 * a_h("a,b,c") - b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(mixed_linear_combination) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 2.0 * a("a,b,c") - 3.0 * b("a,b,c")); + c_h("a,b,c") = 2.0 * a_h("a,b,c") - 3.0 * b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(hadamard) { + // C(ijk) = A(ijk) .* B(ijk), element-wise multiplication + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") * b("a,b,c")); + c_h("a,b,c") = a_h("a,b,c") * b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(contraction) { + // C(i,k) = A(i,j) * B(j,k) requires rank-2 arrays; build them on the fly + // using the first slice of `tr` so the fixture data is reusable. + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + TArrayD c2; + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + HostArray c2_h; + fill_with_seed(a2, a2_h, 13); + fill_with_seed(b2, b2_h, 17); + GlobalFixture::world->gop.fence(); + + BOOST_REQUIRE_NO_THROW(c2("i,k") = a2("i,j") * b2("j,k")); + c2_h("i,k") = a2_h("i,j") * b2_h("j,k"); + // GEMM tolerance: float-add reordering between BLAS and CPU Eigen path. + check_close(c2, c2_h, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(norm2_value) { + // Scalar reduction across all tiles. Compare device-computed value against + // CPU-computed value from the mirror array. + const double dev_norm = TA::norm2(a); + const double host_norm = TA::norm2(a_h); + GlobalFixture::world->gop.fence(); + BOOST_CHECK_CLOSE_FRACTION(dev_norm, host_norm, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(dot_value) { + // dot expression: scalar = a . b + double dev_dot = static_cast(a("a,b,c") * b("a,b,c")); + double host_dot = static_cast(a_h("a,b,c") * b_h("a,b,c")); + GlobalFixture::world->gop.fence(); + BOOST_CHECK_CLOSE_FRACTION(dev_dot, host_dot, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(reuse_stress) { + // MPQC-pattern stress: same input tile referenced multiple times in one + // expression, then again across iterations. Catches the LazyArrayTile + // conversion race if it surfaces (it should be a known master-branch + // baseline failure -- not introduced by this branch). + const double host_ref = + static_cast(a_h("a,b,c") * a_h("a,b,c")); + GlobalFixture::world->gop.fence(); + for (int iter = 0; iter < 8; ++iter) { + const double d = static_cast(a("a,b,c") * a("a,b,c")); + GlobalFixture::world->gop.fence(); + BOOST_CHECK_CLOSE_FRACTION(d, host_ref, 1.0e-12); + } +} + +BOOST_AUTO_TEST_SUITE_END() + +#endif // TILEDARRAY_HAS_DEVICE From d7add874d8c82384e1be420edeed04cbefc10cfe Mon Sep 17 00:00:00 2001 From: Ajay Date: Tue, 12 May 2026 23:01:24 +0000 Subject: [PATCH 05/20] device: expand UMTensor expression tests (in-place ops, blocks, einsum, contraction variants) Goes from 18 to 59 test cases, covering the patterns the expression engine uses in production but the original cut didn't exercise. New coverage: Elementary ops + permutations: * In-place expression operators (+=, -=, *=, += with permute). * Negation in compounds: -(2*(a+b)), -a("c,b,a"). * All scale/permute combinations for add/subt/mult: scale_add, scale_add_permute, subt_permute, scale_subt, scale_subt_permute, mult_permute, scale_mult, scale_mult_permute. Dataflow + reduction: * Multi-step dataflow chain (t = a + b; c = 2*t - a) -- engine wires t's Future into the next dist-eval without an intervening fence (per CLAUDE.md's synchronization hierarchy). * Contraction-plus-reduce (norm2 of a contraction result). * no_alias() + reduce -- exercises the LHS-doesn't-alias-RHS optimization through to a reduction. Block expressions (PR 531 trouble area): * Basic block assign / scaled-sum / accumulate. * const_block: block from a const reference. * scal_block: 2 * a.block(...). * permute_block: a("c,b,a").block(...). * assign_sub_block: write into a tile sub-block of an existing array. * block_contract, block_permute_contract: blocks fed into GEMM. Contraction variants: * Outer product (rank-changing GEMM, no shared contraction index). * Permuted result (c("k,i") = a("i,j") * b("j,k")). * Transpose-on-right input (c("i,k") = a("i,j") * b("k,j")). * CC-style rank-4: r("a,c") = t("a,b,k,l") * v("c,b,k,l"). * scale_cont, scale_cont_permute, scale_cont_with_input_transpose -- scale-fuse-into-GEMM paths. * cont_non_uniform_split_inner / _two_inner -- catches GEMM kernels that silently assume uniform tile shapes. TA::einsum entry point: * Matmul, Hadamard, two-index contraction. Documents the one pattern not covered (`einsum("ij,jk->ijk")` with an index in both inputs and output) -- it segfaults inside einsum's internals on master regardless of allocator, out of scope for this branch. Dot variants: * dot_permute: dot of permuted arrays. * dot_contr: dot of two contraction expressions (one tier deeper than btas-device's NO_THROW-only version; we validate the scalar value). Tolerances are 5e-14 for non-GEMM ops, 1e-10 to 1e-9 for GEMM-bearing paths to absorb summation-order differences between BLAS++ on the device and the Eigen-based CPU reference. Validation: * All 59 device-tensor cases pass. * Full np=1 ta_test suite (1880 cases, 12.56M assertions) still green -- no regressions from this branch. --- tests/expressions_device_tensor.cpp | 661 ++++++++++++++++++++++++++++ 1 file changed, 661 insertions(+) diff --git a/tests/expressions_device_tensor.cpp b/tests/expressions_device_tensor.cpp index 204a7a34f1..16deb01100 100644 --- a/tests/expressions_device_tensor.cpp +++ b/tests/expressions_device_tensor.cpp @@ -24,6 +24,7 @@ #ifdef TILEDARRAY_HAS_DEVICE #include +#include #include #include #include "unit_test_config.h" @@ -259,6 +260,666 @@ BOOST_AUTO_TEST_CASE(reuse_stress) { } } +// --------------------------------------------------------------------------- +// In-place expression operators (+=, -=, *=). These exercise the engine's +// "result is consumable" paths that surfaced the dispatch + sign-flip bugs; +// here we want broad coverage of compound assignment forms beyond the +// `add_to` / `subt_to` cases already tested. +// --------------------------------------------------------------------------- +BOOST_AUTO_TEST_CASE(plus_equal_expr) { + c("a,b,c") = a("a,b,c"); + c_h("a,b,c") = a_h("a,b,c"); + c("a,b,c") += b("a,b,c"); + c_h("a,b,c") += b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(plus_equal_with_permute) { + c("a,b,c") = a("a,b,c"); + c_h("a,b,c") = a_h("a,b,c"); + c("a,b,c") += b("c,b,a"); + c_h("a,b,c") += b_h("c,b,a"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(minus_equal_expr) { + c("a,b,c") = a("a,b,c"); + c_h("a,b,c") = a_h("a,b,c"); + c("a,b,c") -= b("a,b,c"); + c_h("a,b,c") -= b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(times_equal_expr) { + // Hadamard, in place. + c("a,b,c") = a("a,b,c"); + c_h("a,b,c") = a_h("a,b,c"); + c("a,b,c") *= b("a,b,c"); + c_h("a,b,c") *= b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +// --------------------------------------------------------------------------- +// Negated and scaled-then-negated forms. These force the engine to combine +// scaling with sign-flip across different operand positions. +// --------------------------------------------------------------------------- +BOOST_AUTO_TEST_CASE(neg_scaled_sum) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = -(2.0 * (a("a,b,c") + b("a,b,c")))); + c_h("a,b,c") = -(2.0 * (a_h("a,b,c") + b_h("a,b,c"))); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(neg_permuted) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = -a("c,b,a")); + c_h("a,b,c") = -a_h("c,b,a"); + check_close(c, c_h, tolerance); +} + +// --------------------------------------------------------------------------- +// Multi-step chains: results of one expression feed the next. Validates +// dataflow handoff between dist-evals without an intervening fence (per +// CLAUDE.md's synchronization-hierarchy section). +// --------------------------------------------------------------------------- +BOOST_AUTO_TEST_CASE(multi_step_chain) { + TArrayD t(*GlobalFixture::world, tr); + HostArray t_h(*GlobalFixture::world, tr); + t("a,b,c") = a("a,b,c") + b("a,b,c"); + t_h("a,b,c") = a_h("a,b,c") + b_h("a,b,c"); + c("a,b,c") = 2.0 * t("a,b,c") - a("a,b,c"); + c_h("a,b,c") = 2.0 * t_h("a,b,c") - a_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +// --------------------------------------------------------------------------- +// Block expressions. PR 531 hit known issues in this area; we cover the +// common patterns: read-only block, block in a sum, block on the RHS of an +// accumulating assignment. +// --------------------------------------------------------------------------- +BOOST_AUTO_TEST_CASE(block_assign) { + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + + // Result range matches the block's element range; build small companion + // arrays to receive the result. + const TiledRange ctr{TiledRange1{lo[0], up[0]}, + TiledRange1{lo[1], up[1]}, + TiledRange1{lo[2], up[2]}}; + TArrayD blk_d(*GlobalFixture::world, ctr); + HostArray blk_h(*GlobalFixture::world, ctr); + blk_d("a,b,c") = a("a,b,c").block(lo, up); + blk_h("a,b,c") = a_h("a,b,c").block(lo, up); + check_close(blk_d, blk_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(block_add_then_scale) { + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + const TiledRange ctr{TiledRange1{lo[0], up[0]}, + TiledRange1{lo[1], up[1]}, + TiledRange1{lo[2], up[2]}}; + TArrayD blk_d(*GlobalFixture::world, ctr); + HostArray blk_h(*GlobalFixture::world, ctr); + blk_d("a,b,c") = 2.0 * (a("a,b,c").block(lo, up) + b("a,b,c").block(lo, up)); + blk_h("a,b,c") = + 2.0 * (a_h("a,b,c").block(lo, up) + b_h("a,b,c").block(lo, up)); + check_close(blk_d, blk_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(block_accumulate) { + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + const TiledRange ctr{TiledRange1{lo[0], up[0]}, + TiledRange1{lo[1], up[1]}, + TiledRange1{lo[2], up[2]}}; + TArrayD blk_d(*GlobalFixture::world, ctr); + HostArray blk_h(*GlobalFixture::world, ctr); + blk_d("a,b,c") = a("a,b,c").block(lo, up); + blk_h("a,b,c") = a_h("a,b,c").block(lo, up); + blk_d("a,b,c") += b("a,b,c").block(lo, up); + blk_h("a,b,c") += b_h("a,b,c").block(lo, up); + check_close(blk_d, blk_h, tolerance); +} + +// --------------------------------------------------------------------------- +// Outer product: c(i,j) = u(i) * v(j). Exercises the rank-changing GEMM +// path (different left / right / result ranks) without going through a +// shared contraction index. +// --------------------------------------------------------------------------- +BOOST_AUTO_TEST_CASE(outer_product) { + const TiledRange tr_u{tr.data()[0]}; + const TiledRange tr_v{tr.data()[1]}; + const TiledRange tr_w{tr.data()[0], tr.data()[1]}; + + TArrayD u(*GlobalFixture::world, tr_u); + TArrayD v(*GlobalFixture::world, tr_v); + TArrayD w; + HostArray u_h(*GlobalFixture::world, tr_u); + HostArray v_h(*GlobalFixture::world, tr_v); + HostArray w_h; + fill_with_seed(u, u_h, 31); + fill_with_seed(v, v_h, 37); + GlobalFixture::world->gop.fence(); + + BOOST_REQUIRE_NO_THROW(w("i,j") = u("i") * v("j")); + w_h("i,j") = u_h("i") * v_h("j"); + check_close(w, w_h, 1.0e-12); +} + +// --------------------------------------------------------------------------- +// Contraction shape variants. Different output ranks and contraction +// patterns. CC-style: result is rank-4, contraction index is multi-d. +// --------------------------------------------------------------------------- +BOOST_AUTO_TEST_CASE(contraction_permuted_result) { + // c(k,i) = a(i,j) * b(j,k) -- the same contraction as `contraction` but + // with the result indices swapped; checks that the engine fuses a final + // permutation into the GEMM as CLAUDE.md describes. + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + TArrayD c2; + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + HostArray c2_h; + fill_with_seed(a2, a2_h, 41); + fill_with_seed(b2, b2_h, 43); + GlobalFixture::world->gop.fence(); + BOOST_REQUIRE_NO_THROW(c2("k,i") = a2("i,j") * b2("j,k")); + c2_h("k,i") = a2_h("i,j") * b2_h("j,k"); + // Looser tolerance for permuted GEMM: BLAS sums in different + // tile-internal order than the CPU reference path. + check_close(c2, c2_h, 1.0e-10); +} + +BOOST_AUTO_TEST_CASE(contraction_with_transpose_on_right) { + // c(i,k) = a(i,j) * b(k,j) -- right operand needs transposing to align + // the contraction index. + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + TArrayD c2; + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + HostArray c2_h; + fill_with_seed(a2, a2_h, 47); + fill_with_seed(b2, b2_h, 53); + GlobalFixture::world->gop.fence(); + BOOST_REQUIRE_NO_THROW(c2("i,k") = a2("i,j") * b2("k,j")); + c2_h("i,k") = a2_h("i,j") * b2_h("k,j"); + check_close(c2, c2_h, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(contraction_rank4_via_two_indices) { + // r(a,c) = t(a,b,k,l) * v(c,b,k,l) -- pattern that shows up in CC-style + // intermediates; contraction is over (b,k,l), free indices are (a) on + // the left and (c) on the right. + const TiledRange tr4{tr.data()[0], tr.data()[1], tr.data()[2], tr.data()[2]}; + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD t(*GlobalFixture::world, tr4); + TArrayD v(*GlobalFixture::world, tr4); + TArrayD r; + HostArray t_h(*GlobalFixture::world, tr4); + HostArray v_h(*GlobalFixture::world, tr4); + HostArray r_h; + fill_with_seed(t, t_h, 59); + fill_with_seed(v, v_h, 61); + GlobalFixture::world->gop.fence(); + BOOST_REQUIRE_NO_THROW(r("a,c") = t("a,b,k,l") * v("c,b,k,l")); + r_h("a,c") = t_h("a,b,k,l") * v_h("c,b,k,l"); + check_close(r, r_h, 1.0e-12); +} + +// --------------------------------------------------------------------------- +// TA::einsum entry point. The fully-typed einsum API is the documented way +// to express patterns the regular `*` operator can't capture (general +// contraction with explicit output indices, Hadamard with permutation, +// etc.). For UMTensor we test that einsum dispatches through the same tile +// ops we already validated above and produces matching results vs. the +// host-tensor reference. +// --------------------------------------------------------------------------- +BOOST_AUTO_TEST_CASE(einsum_matmul) { + // c(i,k) = a(i,j) * b(j,k) via einsum + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + fill_with_seed(a2, a2_h, 67); + fill_with_seed(b2, b2_h, 71); + GlobalFixture::world->gop.fence(); + + auto c2 = TiledArray::einsum(a2("i,j"), b2("j,k"), "i,k"); + auto c2_h = TiledArray::einsum(a2_h("i,j"), b2_h("j,k"), "i,k"); + check_close(c2, c2_h, 1.0e-11); +} + +BOOST_AUTO_TEST_CASE(einsum_hadamard) { + // c(i,j) = a(i,j) * b(i,j) via einsum -- Hadamard / element-wise multiply + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + fill_with_seed(a2, a2_h, 73); + fill_with_seed(b2, b2_h, 79); + GlobalFixture::world->gop.fence(); + + auto c2 = TiledArray::einsum(a2("i,j"), b2("i,j"), "i,j"); + auto c2_h = TiledArray::einsum(a2_h("i,j"), b2_h("i,j"), "i,j"); + check_close(c2, c2_h, tolerance); +} + +// Note: einsum patterns where an index appears in both inputs *and* the +// output (e.g. `einsum("ij,jk->ijk")`, an outer-product-with-broadcast +// over `j`) are not yet supported for plain (non-ToT) tile types -- they +// segfault inside einsum's internals on master regardless of allocator. +// We don't cover that case here. + +BOOST_AUTO_TEST_CASE(einsum_contraction_over_two_indices) { + // c(a,c) = t(a,b,k) * v(c,b,k) via einsum -- contraction over (b, k), + // free indices (a) on the left and (c) on the right. CC-intermediate + // shape, fully expressible with the regular `*` operator but still + // worth covering through the einsum entry point. + const TiledRange tr3{tr.data()[0], tr.data()[1], tr.data()[2]}; + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD t(*GlobalFixture::world, tr3); + TArrayD v(*GlobalFixture::world, tr3); + HostArray t_h(*GlobalFixture::world, tr3); + HostArray v_h(*GlobalFixture::world, tr3); + fill_with_seed(t, t_h, 83); + fill_with_seed(v, v_h, 89); + GlobalFixture::world->gop.fence(); + + auto r = TiledArray::einsum(t("a,b,k"), v("c,b,k"), "a,c"); + auto r_h = TiledArray::einsum(t_h("a,b,k"), v_h("c,b,k"), "a,c"); + check_close(r, r_h, 1.0e-11); +} + +BOOST_AUTO_TEST_CASE(einsum_permuted_result) { + // c(j,i) = a(i,j) -- one-operand reshape; not a true einsum binary, but + // also useful: verify einsum handles single-input permutation. + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + HostArray a2_h(*GlobalFixture::world, tr2); + fill_with_seed(a2, a2_h, 97); + GlobalFixture::world->gop.fence(); + + // For permutation alone we just use the expression DSL, which einsum + // delegates to; this verifies that path still works for UMTensor. + TArrayD a2T; + HostArray a2T_h; + a2T("j,i") = a2("i,j"); + a2T_h("j,i") = a2_h("i,j"); + check_close(a2T, a2T_h, tolerance); +} + +// --------------------------------------------------------------------------- +// Scaled / permuted variants of the elementary arithmetic ops. The first +// commit of these tests covered the bare forms; the engine fuses scaling +// and permutation differently across these combinations, so each one is +// a distinct dispatch path worth validating numerically. +// --------------------------------------------------------------------------- + +BOOST_AUTO_TEST_CASE(scale_add) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5.0 * (a("a,b,c") + b("a,b,c"))); + c_h("a,b,c") = 5.0 * (a_h("a,b,c") + b_h("a,b,c")); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scale_add_permute) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 5.0 * (2.0 * a("c,b,a")) + (3.0 * b("a,b,c"))); + c_h("a,b,c") = 5.0 * (2.0 * a_h("c,b,a")) + (3.0 * b_h("a,b,c")); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(subt_permute) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("c,b,a") - b("a,b,c")); + c_h("a,b,c") = a_h("c,b,a") - b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scale_subt) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5.0 * (a("a,b,c") - b("a,b,c"))); + c_h("a,b,c") = 5.0 * (a_h("a,b,c") - b_h("a,b,c")); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scale_subt_permute) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 5.0 * (2.0 * a("c,b,a")) - (3.0 * b("a,b,c"))); + c_h("a,b,c") = 5.0 * (2.0 * a_h("c,b,a")) - (3.0 * b_h("a,b,c")); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(mult_permute) { + // Hadamard with permutation on left operand. + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("c,b,a") * b("a,b,c")); + c_h("a,b,c") = a_h("c,b,a") * b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scale_mult) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5.0 * (a("a,b,c") * b("a,b,c"))); + c_h("a,b,c") = 5.0 * (a_h("a,b,c") * b_h("a,b,c")); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scale_mult_permute) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 5.0 * (2.0 * a("c,b,a")) * (3.0 * b("a,b,c"))); + c_h("a,b,c") = 5.0 * (2.0 * a_h("c,b,a")) * (3.0 * b_h("a,b,c")); + check_close(c, c_h, tolerance); +} + +// --------------------------------------------------------------------------- +// Scaled contraction variants. These exercise the engine's scale-fuse- +// into-GEMM path that PR 531 stumbled on. Tolerance is 1e-10 for GEMM +// paths to absorb summation-order differences between BLAS and Eigen. +// --------------------------------------------------------------------------- +BOOST_AUTO_TEST_CASE(scale_cont) { + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + TArrayD c2; + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + HostArray c2_h; + fill_with_seed(a2, a2_h, 101); + fill_with_seed(b2, b2_h, 103); + GlobalFixture::world->gop.fence(); + + BOOST_REQUIRE_NO_THROW(c2("i,k") = 5.0 * (a2("i,j") * b2("j,k"))); + c2_h("i,k") = 5.0 * (a2_h("i,j") * b2_h("j,k")); + check_close(c2, c2_h, 1.0e-10); +} + +BOOST_AUTO_TEST_CASE(scale_cont_permute) { + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + TArrayD c2; + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + HostArray c2_h; + fill_with_seed(a2, a2_h, 107); + fill_with_seed(b2, b2_h, 109); + GlobalFixture::world->gop.fence(); + + // c(k,i) = 5 * a(i,j) * b(j,k): scaled, result-permuted contraction. + BOOST_REQUIRE_NO_THROW(c2("k,i") = 5.0 * (a2("i,j") * b2("j,k"))); + c2_h("k,i") = 5.0 * (a2_h("i,j") * b2_h("j,k")); + check_close(c2, c2_h, 1.0e-10); +} + +BOOST_AUTO_TEST_CASE(scale_cont_with_input_transpose) { + // 5 * a(i,j) * b(k,j) -- contraction needs to transpose b before GEMM. + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + TArrayD c2; + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + HostArray c2_h; + fill_with_seed(a2, a2_h, 113); + fill_with_seed(b2, b2_h, 127); + GlobalFixture::world->gop.fence(); + + BOOST_REQUIRE_NO_THROW(c2("i,k") = 5.0 * (a2("i,j") * b2("k,j"))); + c2_h("i,k") = 5.0 * (a2_h("i,j") * b2_h("k,j")); + check_close(c2, c2_h, 1.0e-10); +} + +// --------------------------------------------------------------------------- +// Non-uniform tile sizes for contraction. Mirrors btas-device's +// cont_non_uniform1/2: the rank-4 inputs use one tiny tiling on the +// outer dimensions and one wide tiling on an inner dimension, so the +// GEMM has irregular per-tile k blocks. Catches GEMM kernels that +// silently assume uniform tile shapes. +// --------------------------------------------------------------------------- +BOOST_AUTO_TEST_CASE(cont_non_uniform_split_inner) { + std::array tiling1 = {{0, 1, 2, 3, 4, 5}}; + std::array tiling2 = {{0, 40}}; + TiledRange1 tr1_1(tiling1.begin(), tiling1.end()); + TiledRange1 tr1_2(tiling2.begin(), tiling2.end()); + std::array tiling4 = {{tr1_1, tr1_2, tr1_1, tr1_1}}; + TiledRange tr_irr(tiling4.begin(), tiling4.end()); + + TArrayD lhs(*GlobalFixture::world, tr_irr); + TArrayD rhs(*GlobalFixture::world, tr_irr); + TArrayD out; + HostArray lhs_h(*GlobalFixture::world, tr_irr); + HostArray rhs_h(*GlobalFixture::world, tr_irr); + HostArray out_h; + fill_with_seed(lhs, lhs_h, 131); + fill_with_seed(rhs, rhs_h, 137); + GlobalFixture::world->gop.fence(); + + BOOST_REQUIRE_NO_THROW(out("x,y") = + 5.0 * (lhs("x,i,j,k") * rhs("y,i,j,k"))); + out_h("x,y") = 5.0 * (lhs_h("x,i,j,k") * rhs_h("y,i,j,k")); + check_close(out, out_h, 1.0e-9); +} + +BOOST_AUTO_TEST_CASE(cont_non_uniform_split_two_inner) { + std::array tiling1 = {{0, 1, 2, 3, 4, 5}}; + std::array tiling2 = {{0, 40}}; + TiledRange1 tr1_1(tiling1.begin(), tiling1.end()); + TiledRange1 tr1_2(tiling2.begin(), tiling2.end()); + std::array tiling4 = {{tr1_1, tr1_1, tr1_2, tr1_2}}; + TiledRange tr_irr(tiling4.begin(), tiling4.end()); + + TArrayD lhs(*GlobalFixture::world, tr_irr); + TArrayD rhs(*GlobalFixture::world, tr_irr); + TArrayD out; + HostArray lhs_h(*GlobalFixture::world, tr_irr); + HostArray rhs_h(*GlobalFixture::world, tr_irr); + HostArray out_h; + fill_with_seed(lhs, lhs_h, 139); + fill_with_seed(rhs, rhs_h, 149); + GlobalFixture::world->gop.fence(); + + BOOST_REQUIRE_NO_THROW(out("x,y") = + 5.0 * (lhs("x,i,j,k") * rhs("y,i,j,k"))); + out_h("x,y") = 5.0 * (lhs_h("x,i,j,k") * rhs_h("y,i,j,k")); + check_close(out, out_h, 1.0e-9); +} + +// --------------------------------------------------------------------------- +// Contraction-plus-reduction (norm2 of a contraction). Exercises the +// dataflow handoff from a binary dist-eval to a reduction without an +// intervening fence. +// --------------------------------------------------------------------------- +BOOST_AUTO_TEST_CASE(cont_plus_reduce) { + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + fill_with_seed(a2, a2_h, 151); + fill_with_seed(b2, b2_h, 157); + GlobalFixture::world->gop.fence(); + + TArrayD c2; + HostArray c2_h; + c2("i,k") = a2("i,j") * b2("j,k"); + c2_h("i,k") = a2_h("i,j") * b2_h("j,k"); + const double dev_n = TA::norm2(c2); + const double host_n = TA::norm2(c2_h); + GlobalFixture::world->gop.fence(); + BOOST_CHECK_CLOSE_FRACTION(dev_n, host_n, 1.0e-10); +} + +BOOST_AUTO_TEST_CASE(no_alias_plus_reduce) { + // `no_alias()` tells the engine the LHS does not alias any RHS operand, + // permitting an extra in-place optimization. Validate that path + // produces correct values. + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + TArrayD c2(*GlobalFixture::world, + TiledRange{tr.data()[0], tr.data()[1]}); + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + HostArray c2_h(*GlobalFixture::world, + TiledRange{tr.data()[0], tr.data()[1]}); + fill_with_seed(a2, a2_h, 163); + fill_with_seed(b2, b2_h, 167); + c2.fill_local(0.0); + c2_h.fill_local(0.0); + GlobalFixture::world->gop.fence(); + + BOOST_REQUIRE_NO_THROW(c2("i,k").no_alias() = a2("i,j") * b2("j,k")); + c2_h("i,k").no_alias() = a2_h("i,j") * b2_h("j,k"); + check_close(c2, c2_h, 1.0e-10); + const double dev_n = TA::norm2(c2); + const double host_n = TA::norm2(c2_h); + GlobalFixture::world->gop.fence(); + BOOST_CHECK_CLOSE_FRACTION(dev_n, host_n, 1.0e-10); +} + +// --------------------------------------------------------------------------- +// Block-expression variants beyond the basic three already covered. +// Block bounds are TILE coordinates; a {3,3,3} -> {5,5,5} block selects +// the 2x2x2 corner tiles of `tr` (5 tiles per dim). +// --------------------------------------------------------------------------- +BOOST_AUTO_TEST_CASE(const_block) { + const auto& ca = a; + const auto& ca_h = a_h; + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + const TiledRange ctr{TiledRange1{lo[0], up[0]}, + TiledRange1{lo[1], up[1]}, + TiledRange1{lo[2], up[2]}}; + TArrayD blk_d(*GlobalFixture::world, ctr); + HostArray blk_h(*GlobalFixture::world, ctr); + blk_d("a,b,c") = ca("a,b,c").block(lo, up); + blk_h("a,b,c") = ca_h("a,b,c").block(lo, up); + check_close(blk_d, blk_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scal_block) { + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + const TiledRange ctr{TiledRange1{lo[0], up[0]}, + TiledRange1{lo[1], up[1]}, + TiledRange1{lo[2], up[2]}}; + TArrayD blk_d(*GlobalFixture::world, ctr); + HostArray blk_h(*GlobalFixture::world, ctr); + blk_d("a,b,c") = 2.0 * a("a,b,c").block(lo, up); + blk_h("a,b,c") = 2.0 * a_h("a,b,c").block(lo, up); + check_close(blk_d, blk_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(permute_block) { + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + const TiledRange ctr{TiledRange1{lo[0], up[0]}, + TiledRange1{lo[1], up[1]}, + TiledRange1{lo[2], up[2]}}; + TArrayD blk_d(*GlobalFixture::world, ctr); + HostArray blk_h(*GlobalFixture::world, ctr); + // Permute the source annotation before slicing. + blk_d("a,b,c") = a("c,b,a").block(lo, up); + blk_h("a,b,c") = a_h("c,b,a").block(lo, up); + check_close(blk_d, blk_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(assign_sub_block) { + // Write into a tile sub-block of an existing array. Tiles outside the + // block keep their original contents -- so we initialize both sides + // identically with a known value before the block assignment. + c.fill_local(0.0); + c_h.fill_local(0.0); + GlobalFixture::world->gop.fence(); + + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + BOOST_REQUIRE_NO_THROW(c("a,b,c").block(lo, up) = a("a,b,c").block(lo, up)); + c_h("a,b,c").block(lo, up) = a_h("a,b,c").block(lo, up); + check_close(c, c_h, tolerance); +} + +// --------------------------------------------------------------------------- +// Block-fed-into-contraction. PR 531 had known issues here. The result +// array has rank 2 (carved out of the rank-3 fixture by contracting two +// indices). +// --------------------------------------------------------------------------- +BOOST_AUTO_TEST_CASE(block_contract) { + const TiledRange tr_w{tr.data()[0], tr.data()[1]}; + TArrayD w(*GlobalFixture::world, tr_w); + HostArray w_h(*GlobalFixture::world, tr_w); + w.fill_local(0.0); + w_h.fill_local(0.0); + GlobalFixture::world->gop.fence(); + + const std::array alo{3, 2, 3}; + const std::array aup{5, 5, 5}; + const std::array blo{2, 3, 3}; + const std::array bup{5, 5, 5}; + + BOOST_REQUIRE_NO_THROW( + w("a,b") = a("a,c,d").block(alo, aup) * b("c,d,b").block(blo, bup)); + w_h("a,b") = a_h("a,c,d").block(alo, aup) * b_h("c,d,b").block(blo, bup); + check_close(w, w_h, 1.0e-10); +} + +BOOST_AUTO_TEST_CASE(block_permute_contract) { + // Same as block_contract but with a permuted left-operand annotation: + // `a("a,d,c")` instead of `a("a,c,d")` -- forces a permutation of the + // sliced block before GEMM. + const TiledRange tr_w{tr.data()[0], tr.data()[1]}; + TArrayD w(*GlobalFixture::world, tr_w); + HostArray w_h(*GlobalFixture::world, tr_w); + w.fill_local(0.0); + w_h.fill_local(0.0); + GlobalFixture::world->gop.fence(); + + const std::array alo{3, 3, 2}; + const std::array aup{5, 5, 5}; + const std::array blo{2, 3, 3}; + const std::array bup{5, 5, 5}; + + BOOST_REQUIRE_NO_THROW( + w("a,b") = a("a,d,c").block(alo, aup) * b("c,d,b").block(blo, bup)); + w_h("a,b") = a_h("a,d,c").block(alo, aup) * b_h("c,d,b").block(blo, bup); + check_close(w, w_h, 1.0e-10); +} + +// --------------------------------------------------------------------------- +// Dot-product variants beyond the basic case. +// --------------------------------------------------------------------------- +BOOST_AUTO_TEST_CASE(dot_permute) { + const double dev_d = + static_cast(a("a,b,c") * b("c,b,a")); + const double host_d = + static_cast(a_h("a,b,c") * b_h("c,b,a")); + GlobalFixture::world->gop.fence(); + // Looser tolerance because permuted dot reads tiles in a different + // order, so the partial-sum accumulation order differs. + BOOST_CHECK_CLOSE_FRACTION(dev_d, host_d, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(dot_contr) { + // Dot of two contraction expressions: scalar = (a*b) . (b*a). + // This is a NO_THROW-only check in the btas-device suite; we go one + // step further and validate the scalar value against the CPU mirror. + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + fill_with_seed(a2, a2_h, 173); + fill_with_seed(b2, b2_h, 179); + GlobalFixture::world->gop.fence(); + + const double dev_d = static_cast( + (a2("i,j") * b2("j,k")) * (a2("i,j") * b2("j,k"))); + const double host_d = static_cast( + (a2_h("i,j") * b2_h("j,k")) * (a2_h("i,j") * b2_h("j,k"))); + GlobalFixture::world->gop.fence(); + BOOST_CHECK_CLOSE_FRACTION(dev_d, host_d, 1.0e-10); +} + BOOST_AUTO_TEST_SUITE_END() #endif // TILEDARRAY_HAS_DEVICE From c53eab48afca333cb46fe22938ea29e3882c270d Mon Sep 17 00:00:00 2001 From: Ajay Date: Tue, 12 May 2026 23:29:43 +0000 Subject: [PATCH 06/20] device: add UMTensor archive support + array-level helpers Phase 3 surface: serialization plus the bulk DistArray-level helpers (prefetch and host<->device conversion) needed by downstream code that wants to round-trip UMTensor data through MADNESS archives or shuttle DistArrays between memory spaces. src/TiledArray/device/tensor.h: * `to_host(DistArray, P>&)` and `to_device(DistArray, P>&)`: bulk prefetch every local tile, fence the world, deviceSynchronize on exit. Mirrors the btas helpers in btas_um_tensor.h:567-617 but takes the bare UMTensor (no `TA::Tile<>` wrapper, per CLAUDE.md). * `um_tensor_to_ta_tensor(DistArray, P>)` and `ta_tensor_to_um_tensor(DistArray, P>)`: tile-by-tile conversion via `to_new_tile_type`. The per-tile lambda allocates a result of the target tile type, prefetches the source as needed, and memcpys -- since both sides are TA::Tensor and only the allocator differs, no per-element conversion is required. * `madness::archive::ArchiveStoreImpl>`: prefetches the tile to host before serializing, then writes the same fields TA::Tensor::serialize would (empty/range/nbatch/wrap(data)). The default member serialize is not safe to use as-is because UM data may be stale on the host while a device kernel is in flight. * `madness::archive::ArchiveLoadImpl>`: reconstructs the tensor in UM (writes go through host pages of UM -- if downstream code wants the data on the device, it should call `to_device` explicitly). tests/expressions_device_tensor.cpp: Five new cases (now 64 total): * serialize_um_tensor: single-tile round-trip through BufferOutput/ Input archives. * serialize_um_tensor_empty: empty branch in Store/Load. * um_to_ta_round_trip: device array -> host array -> device array, values preserved across both legs. * um_to_ta_then_expression: a host expression on a converted-from- device array matches the same expression on the host mirror. * bulk_prefetch_round_trip: to_host/to_device on a DistArray are no-ops for correctness (they only adjust page residency hints). Validation: * All 64 device-tensor cases pass. * Full np=1 ta_test suite: 1885 cases, 12.64M assertions -- still green; no regressions from this branch. --- src/TiledArray/device/tensor.h | 141 ++++++++++++++++++++++++++++ tests/expressions_device_tensor.cpp | 89 ++++++++++++++++++ 2 files changed, 230 insertions(+) diff --git a/src/TiledArray/device/tensor.h b/src/TiledArray/device/tensor.h index 185db47813..dda40ed1ae 100644 --- a/src/TiledArray/device/tensor.h +++ b/src/TiledArray/device/tensor.h @@ -26,6 +26,7 @@ #ifdef TILEDARRAY_HAS_DEVICE +#include #include #include #include @@ -39,6 +40,7 @@ #include #include +#include namespace TiledArray { namespace detail { @@ -759,8 +761,147 @@ inline void gemm(UMTensor& result, const UMTensor& left, device::sync_madness_task_with(stream); } +// --------------------------------------------------------------------------- +// Array-level helpers: bulk to-host / to-device prefetch and conversions +// between UMTensor-backed and host-Tensor-backed DistArrays. Mirrors the +// btas-device helpers in btas_um_tensor.h:567-617 but for the bare +// TA::Tensor specialization -- so the tile type is `UMTensor` directly, +// not wrapped in `TA::Tile<...>`. +// +// `to_host` / `to_device` are oneshot bulk-prefetch routines: they walk the +// pmap, dispatch one prefetch task per local tile, fence, then issue a +// `deviceSynchronize` to make sure every stream has drained. They're +// "stop the world" by design -- intended for explicit synchronization +// points (before a host read, after a load, etc.), not for inner loops. +// --------------------------------------------------------------------------- + +/// Prefetch every local tile of `array` to the host. Fences on the +/// containing world and globally synchronizes the device on exit. +template +inline void to_host(TiledArray::DistArray, Policy>& array) { + auto prefetch = [](UMTensor& tile) { + auto stream = device::stream_for(tile.range()); + detail::to_host(tile); + device::sync_madness_task_with(stream); + }; + auto& world = array.world(); + for (auto it = array.pmap()->begin(); it != array.pmap()->end(); ++it) { + if (!array.is_zero(*it)) world.taskq.add(prefetch, array.find(*it)); + } + world.gop.fence(); + DeviceSafeCall(device::deviceSynchronize()); +} + +/// Prefetch every local tile of `array` to the device. Fences on the +/// containing world and globally synchronizes the device on exit. +template +inline void to_device(TiledArray::DistArray, Policy>& array) { + auto prefetch = [](UMTensor& tile) { + auto stream = device::stream_for(tile.range()); + detail::to_device(tile); + device::sync_madness_task_with(stream); + }; + auto& world = array.world(); + for (auto it = array.pmap()->begin(); it != array.pmap()->end(); ++it) { + if (!array.is_zero(*it)) world.taskq.add(prefetch, array.find(*it)); + } + world.gop.fence(); + DeviceSafeCall(device::deviceSynchronize()); +} + +/// Convert a UMTensor-backed `DistArray` to one backed by host +/// `TA::Tensor`. Tile-by-tile copy through `to_new_tile_type` -- the +/// per-tile lambda allocates a host result, prefetches the source UM +/// buffer to host, and memcpys. +template +inline TiledArray::DistArray, Policy> +um_tensor_to_ta_tensor( + const TiledArray::DistArray, Policy>& um_array) { + auto convert_tile = [](const UMTensor& tile) { + detail::to_host(tile); + TiledArray::Tensor result(tile.range()); + std::copy_n(tile.data(), tile.total_size(), result.data()); + return result; + }; + auto out = to_new_tile_type>(um_array, convert_tile); + um_array.world().gop.fence(); + return out; +} + +/// Convert a host `TA::Tensor`-backed `DistArray` to a UMTensor-backed +/// one. Tile-by-tile copy: allocate UM, memcpy, prefetch to device. +template +inline TiledArray::DistArray, Policy> ta_tensor_to_um_tensor( + const TiledArray::DistArray, Policy>& host_array) { + auto convert_tile = [](const TiledArray::Tensor& tile) { + UMTensor result(tile.range()); + std::copy_n(tile.data(), tile.total_size(), result.data()); + detail::to_device(result); + return result; + }; + auto out = to_new_tile_type>(host_array, convert_tile); + host_array.world().gop.fence(); + return out; +} + } // namespace TiledArray +// --------------------------------------------------------------------------- +// MADNESS archive specializations for UMTensor. +// +// `TA::Tensor::serialize(ar)` works on any allocator (the member just walks +// `data() + range().volume() * nbatch()`), but UM data may be stale on the +// host if a device kernel is in flight. The Store specialization prefetches +// the tile back to the host before reading. Load goes through the default +// member -- the freshly constructed UM-allocated tile is host-writable, so +// no additional prefetch is needed (downstream code that wants the data on +// the device should call `to_device` explicitly). +// --------------------------------------------------------------------------- +namespace madness { +namespace archive { + +template +struct ArchiveStoreImpl> { + static inline void store(const Archive& ar, + const TiledArray::UMTensor& t) { + TiledArray::detail::to_host(t); + // Mirror TA::Tensor::serialize's store side; we cannot call the member + // because it is non-const and we want to keep the input parameter + // const-correct. + const bool empty = t.empty(); + ar & empty; + if (!empty) { + ar & t.range(); + ar & t.nbatch(); + ar & madness::archive::wrap(t.data(), + t.range().volume() * t.nbatch()); + } + } +}; + +template +struct ArchiveLoadImpl> { + static inline void load(const Archive& ar, TiledArray::UMTensor& t) { + bool empty = false; + ar & empty; + if (!empty) { + TiledArray::Range range; + std::size_t nbatch = 1; + ar & range; + ar & nbatch; + t = TiledArray::UMTensor( + std::move(range), typename TiledArray::UMTensor::nbatches(nbatch)); + ar & madness::archive::wrap(t.data(), + t.range().volume() * t.nbatch()); + } else { + t = TiledArray::UMTensor(); + } + } +}; + +} // namespace archive +} // namespace madness + #endif // TILEDARRAY_HAS_DEVICE #endif // TILEDARRAY_DEVICE_TENSOR_H diff --git a/tests/expressions_device_tensor.cpp b/tests/expressions_device_tensor.cpp index 16deb01100..68ac90248d 100644 --- a/tests/expressions_device_tensor.cpp +++ b/tests/expressions_device_tensor.cpp @@ -920,6 +920,95 @@ BOOST_AUTO_TEST_CASE(dot_contr) { BOOST_CHECK_CLOSE_FRACTION(dev_d, host_d, 1.0e-10); } +// --------------------------------------------------------------------------- +// Phase 3 surface: archive round-trip, host/device array conversions, and +// bulk to_host / to_device. Smoke + correctness for the helpers in +// device/tensor.h that are not in the expression-engine path. +// --------------------------------------------------------------------------- +BOOST_AUTO_TEST_CASE(serialize_um_tensor) { + // Single-tile round-trip: build a UMTensor, write to a buffer archive, + // read into a fresh UMTensor, compare element-wise. The Store side + // forces a host prefetch so the archive sees coherent data. + TileD t(TiledArray::Range{4, 4}); + for (std::size_t k = 0; k < t.size(); ++k) + t.data()[k] = static_cast(k) - 5.0; + + const std::size_t buf_size = + (t.range().volume() * sizeof(double) + + sizeof(std::size_t) * (t.range().rank() * 4 + 4)) * + 2; + std::vector buf(buf_size); + + madness::archive::BufferOutputArchive oar(buf.data(), buf.size()); + BOOST_REQUIRE_NO_THROW(oar & t); + const std::size_t nbyte = oar.size(); + oar.close(); + + TileD u; + madness::archive::BufferInputArchive iar(buf.data(), nbyte); + BOOST_REQUIRE_NO_THROW(iar & u); + iar.close(); + + BOOST_REQUIRE_EQUAL(t.range(), u.range()); + for (std::size_t k = 0; k < t.size(); ++k) + BOOST_CHECK_CLOSE(t.data()[k], u.data()[k], 1.0e-15); +} + +BOOST_AUTO_TEST_CASE(serialize_um_tensor_empty) { + // Empty-tensor round-trip: the empty branch in Store/Load. + TileD t; + std::vector buf(1024); + madness::archive::BufferOutputArchive oar(buf.data(), buf.size()); + BOOST_REQUIRE_NO_THROW(oar & t); + const std::size_t nbyte = oar.size(); + oar.close(); + + TileD u(TiledArray::Range{2, 2}); // start non-empty so load has work + madness::archive::BufferInputArchive iar(buf.data(), nbyte); + BOOST_REQUIRE_NO_THROW(iar & u); + iar.close(); + BOOST_CHECK(u.empty()); +} + +BOOST_AUTO_TEST_CASE(um_to_ta_round_trip) { + // UMTensor array -> host array -> UMTensor array, verify element-wise + // against the original both at the host and device endpoints. + HostArray host_view = TiledArray::um_tensor_to_ta_tensor(a); + GlobalFixture::world->gop.fence(); + // Compare host_view directly against a_h (same data was used for both). + check_close(host_view, a_h, tolerance); + + TArrayD device_view = TiledArray::ta_tensor_to_um_tensor(host_view); + GlobalFixture::world->gop.fence(); + // Round-trip: device_view should match the original `a`. + check_close(device_view, a_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(um_to_ta_then_expression) { + // After converting a device array to host, plain CPU expressions on the + // host array should produce the same values as their device counterpart. + HostArray sum_h(*GlobalFixture::world, tr); + sum_h("a,b,c") = a_h("a,b,c") + b_h("a,b,c"); + + HostArray converted = TiledArray::um_tensor_to_ta_tensor(a); + HostArray sum_from_device(*GlobalFixture::world, tr); + sum_from_device("a,b,c") = converted("a,b,c") + b_h("a,b,c"); + + GlobalFixture::world->gop.fence(); + check_close(sum_from_device, sum_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(bulk_prefetch_round_trip) { + // to_host / to_device on a DistArray should be no-ops for correctness: + // the array contents are unchanged, only the page residency hints are + // adjusted. We just verify the array is still equal to its host mirror + // after bouncing through both directions. + TiledArray::to_host(a); + TiledArray::to_device(a); + GlobalFixture::world->gop.fence(); + check_close(a, a_h, tolerance); +} + BOOST_AUTO_TEST_SUITE_END() #endif // TILEDARRAY_HAS_DEVICE From 32f61315e4b7c34835f172653082ac5017d22605 Mon Sep 17 00:00:00 2001 From: Ajay Date: Tue, 12 May 2026 23:38:20 +0000 Subject: [PATCH 07/20] device: explicit instantiations of UMTensor for standard numeric types Phase 4: instantiate `Tensor>` once in device/tensor.cpp and `extern template`-declare them in device/tensor.h, so each TU including device/tensor.h does not re-instantiate the full ~3000-line Tensor class body. Mirrors the pattern at the bottom of src/TiledArray/tensor/tensor.h + src/TiledArray/tensor/tensor.cpp for the host-side instantiations. The instantiated set is double, float, complex<{double,float}>, int, long -- a superset of the host set (which omits int/long) for parity with btas_um_tensor.cpp. BLAS-bearing free functions (gemm, scale, axpy-driven add/subt, ...) are left as header-defined templates; explicitly instantiating them would pull the full BLAS++/librett surface into device/tensor.cpp, and the build-time saving from extern-templating them does not justify it. They get instantiated lazily in whichever TU actually calls them (typically the test or example TU). Validation: * All 64 device-tensor cases still pass. * Full np=1 ta_test suite: 1885 cases, 12.63M assertions -- still green; no regressions from this change. --- src/TiledArray/device/tensor.cpp | 29 +++++++++++++++++++++++++++++ src/TiledArray/device/tensor.h | 20 ++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/src/TiledArray/device/tensor.cpp b/src/TiledArray/device/tensor.cpp index 25c643c5a4..4b5e128fea 100644 --- a/src/TiledArray/device/tensor.cpp +++ b/src/TiledArray/device/tensor.cpp @@ -23,6 +23,35 @@ #include +namespace TiledArray { + +// Explicit instantiations of the UMTensor class for the standard numeric +// types. Without these, every TU including device/tensor.h would instantiate +// the full TA::Tensor> class body (~3000 lines of +// templated members) -- the matching `extern template` declarations in +// device/tensor.h suppress that per-TU work and route consumers to the +// symbols defined here. +// +// The list mirrors `src/TiledArray/tensor/tensor.cpp`'s host-side set +// (double, float, complex variants), plus int/long which are cheap to +// instantiate and useful for index-tile use cases. BLAS-bearing free +// functions (`gemm`, `scale`, ...) are still header-defined templates -- +// instantiating those for each numeric type would pull in the full +// BLAS++/librett surface here, and the compile-time saving from +// extern-templating them does not justify it. They get instantiated lazily +// in whichever TU actually calls them (typically the test or example TU). + +template class Tensor>; +template class Tensor>; +template class Tensor, + device_um_allocator>>; +template class Tensor, + device_um_allocator>>; +template class Tensor>; +template class Tensor>; + +} // namespace TiledArray + namespace TiledArray::detail { // Phase 1 sanity: confirm the is_device_tile specialization fires for the diff --git a/src/TiledArray/device/tensor.h b/src/TiledArray/device/tensor.h index dda40ed1ae..e486b8784d 100644 --- a/src/TiledArray/device/tensor.h +++ b/src/TiledArray/device/tensor.h @@ -902,6 +902,26 @@ struct ArchiveLoadImpl> { } // namespace archive } // namespace madness +// --------------------------------------------------------------------------- +// `extern template` declarations for the UMTensor class. Match the explicit +// instantiations in src/TiledArray/device/tensor.cpp so that consumers do +// not re-instantiate the full Tensor> class body +// in each TU. (Mirrors the analogous pattern at the bottom of +// src/TiledArray/tensor/tensor.h for the host-side instantiations.) +// --------------------------------------------------------------------------- +namespace TiledArray { + +extern template class Tensor>; +extern template class Tensor>; +extern template class Tensor, + device_um_allocator>>; +extern template class Tensor, + device_um_allocator>>; +extern template class Tensor>; +extern template class Tensor>; + +} // namespace TiledArray + #endif // TILEDARRAY_HAS_DEVICE #endif // TILEDARRAY_DEVICE_TENSOR_H From 4dc2eeb8a2ae5143b03a49136d2871279d68252d Mon Sep 17 00:00:00 2001 From: Ajay Date: Tue, 12 May 2026 23:55:28 +0000 Subject: [PATCH 08/20] device: add UMTensor dense + vector example programs Phase 6: two example programs that exercise the UMTensor surface end-to-end through real timing loops, mirroring the existing btas-based ta_dense_device.cpp and ta_vector_device.cpp but using the bare UMTensor tile type (no `TA::Tile<>` wrapper). * examples/device/ta_dense_um_tensor.cpp: c(Nm,Nn) = a(Nm,Nk) * b(Nk,Nn) on UMTensor tiles, blocked by user-supplied Bm/Bn/Bk. Reports per-iteration wall time and GFLOPS, then verifies every result element equals the analytic Nk * val_a * val_b. Honors `cudaProfilerStart/Stop` when built with CUDA so the timed loop is profilable separately from setup. * examples/device/ta_vector_um_tensor.cpp: Element-wise op benchmark -- add, subt, scale, Hadamard, permute, in-place axpy -- on Nm x Nn UMTensor matrices. Reports per-op average wall time and effective bandwidth (counting one read + one write per element for unary ops, two reads + one write for binary ops). Both examples follow the convention in CLAUDE.md / the rest of TA's device examples: TA_SCOPED_INITIALIZE for runtime setup, world fence before reading device-side results, exception-catching `main` wrapper. Hooked into examples/device/CMakeLists.txt's `foreach(_exec ...)` list so they build alongside the existing examples whenever TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP. Smoke-validated locally: ta_dense_um_tensor 256 64 256 64 256 64 3 -> Verification PASSED ta_vector_um_tensor 512 128 512 128 3 -> all six ops reported timings without error --- examples/device/CMakeLists.txt | 2 +- examples/device/ta_dense_um_tensor.cpp | 205 ++++++++++++++++++++++++ examples/device/ta_vector_um_tensor.cpp | 157 ++++++++++++++++++ 3 files changed, 363 insertions(+), 1 deletion(-) create mode 100644 examples/device/ta_dense_um_tensor.cpp create mode 100644 examples/device/ta_vector_um_tensor.cpp diff --git a/examples/device/CMakeLists.txt b/examples/device/CMakeLists.txt index 14da71efae..febe78c57a 100644 --- a/examples/device/CMakeLists.txt +++ b/examples/device/CMakeLists.txt @@ -25,7 +25,7 @@ if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP) - foreach(_exec device_task ta_dense_device ta_cc_abcd_device ta_vector_device ta_reduce_device) + foreach(_exec device_task ta_dense_device ta_cc_abcd_device ta_vector_device ta_reduce_device ta_dense_um_tensor ta_vector_um_tensor) # Add executable add_ta_executable(${_exec} "${_exec}.cpp" "tiledarray") diff --git a/examples/device/ta_dense_um_tensor.cpp b/examples/device/ta_dense_um_tensor.cpp new file mode 100644 index 0000000000..1902824379 --- /dev/null +++ b/examples/device/ta_dense_um_tensor.cpp @@ -0,0 +1,205 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2026 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * Ajay Melekamburath + * Department of Chemistry, Virginia Tech + */ + +// Dense matrix-multiply benchmark using the native UMTensor tile type +// (TA::Tensor backed by device_um_allocator). Companion to the btas-based +// ta_dense_device.cpp; same shape + reporting, but the tile is bare +// `UMTensor` -- no `TA::Tile<>` wrapper -- and the data flows through +// the device tile-op overloads in src/TiledArray/device/tensor.h. +// +// Usage: +// ta_dense_um_tensor Nm Bm Nn Bn Nk Bk [nrepeat=5] +// +// Computes c(Nm,Nn) = a(Nm,Nk) * b(Nk,Nn) with each dimension blocked by +// Bm/Bn/Bk. Default scalar type is double; nrepeat iterations are timed +// for an average GFLOPS reading. + +#include +#include + +#ifdef TILEDARRAY_HAS_CUDA +#include +#endif + +#include +#include +#include + +namespace { + +template +void run(TiledArray::World &world, long Nm, long Bm, long Nn, long Bn, long Nk, + long Bk, long nrepeat) { + using TA::DistArray; + using TA::TiledRange; + using TA::TiledRange1; + using TA::UMTensor; + using TileT = UMTensor; + using ArrayT = DistArray; + + constexpr bool complex_T = TA::detail::is_complex_v; + // GEMM flops: 2 * M * N * K (8 * for complex). + const std::int64_t nflops = (complex_T ? 8 : 2) * + static_cast(Nm) * + static_cast(Nn) * + static_cast(Nk); + + auto blocking = [](long N, long B) { + std::vector v; + for (long i = 0; i <= N; i += B) v.push_back(static_cast(i)); + return v; + }; + auto blk_m = blocking(Nm, Bm); + auto blk_n = blocking(Nn, Bn); + auto blk_k = blocking(Nk, Bk); + + TiledRange trange_a({TiledRange1(blk_m.begin(), blk_m.end()), + TiledRange1(blk_k.begin(), blk_k.end())}); + TiledRange trange_b({TiledRange1(blk_k.begin(), blk_k.end()), + TiledRange1(blk_n.begin(), blk_n.end())}); + TiledRange trange_c({TiledRange1(blk_m.begin(), blk_m.end()), + TiledRange1(blk_n.begin(), blk_n.end())}); + + if (world.rank() == 0) + std::cout << "TiledArray UMTensor dense matrix multiply\n" + << " Nodes = " << world.size() << "\n" + << " A = " << Nm << " x " << Nk << " (" + << double(Nm * Nk * sizeof(T)) / 1.0e9 << " GB)\n" + << " B = " << Nk << " x " << Nn << " (" + << double(Nk * Nn * sizeof(T)) / 1.0e9 << " GB)\n" + << " C = " << Nm << " x " << Nn << " (" + << double(Nm * Nn * sizeof(T)) / 1.0e9 << " GB)\n" + << " Tile A,B,C = " << Bm << "x" << Bk << ", " << Bk << "x" + << Bn << ", " << Bm << "x" << Bn << "\n" + << " Iterations = " << nrepeat << "\n"; + + ArrayT a(world, trange_a); + ArrayT b(world, trange_b); + ArrayT c(world, trange_c); + + const T val_a = T(0.03); + const T val_b = T(0.02); + a.fill(val_a); + b.fill(val_b); + world.gop.fence(); + + // Prefetch inputs to the device once before the timed loop -- the per-tile + // ops will also prefetch lazily, but doing it up front keeps the timing + // focused on the GEMM kernel cost. + TA::to_device(a); + TA::to_device(b); + +#ifdef TILEDARRAY_HAS_CUDA + cudaProfilerStart(); +#endif + + double total_time = 0.0; + double total_gflops = 0.0; + for (long i = 0; i < nrepeat; ++i) { + const double t0 = madness::wall_time(); + c("m,n") = a("m,k") * b("k,n"); + world.gop.fence(); + const double t1 = madness::wall_time(); + const double dt = t1 - t0; + const double gflops = static_cast(nflops) / (dt * 1.0e9); + total_time += dt; + total_gflops += gflops; + if (world.rank() == 0) + std::cout << " iter " << (i + 1) << " time=" << dt + << " s gflops=" << gflops << "\n"; + } + +#ifdef TILEDARRAY_HAS_CUDA + cudaProfilerStop(); +#endif + + if (world.rank() == 0) + std::cout << " Average time = " << (total_time / double(nrepeat)) + << " s\n Average gflops = " + << (total_gflops / double(nrepeat)) << "\n"; + + // Verify: every result element should be Nk * val_a * val_b. + const T expected = T(Nk) * val_a * val_b; + const auto eps = std::numeric_limits>::epsilon(); + const auto tolerance = std::abs(expected) * static_cast(Nk) * + static_cast(8) * eps; + TA::to_host(c); + bool ok = true; + for (auto it = c.begin(); it != c.end(); ++it) { + const auto tile = it->get(); + for (std::size_t k = 0; k < tile.size(); ++k) { + if (std::abs(tile.data()[k] - expected) > tolerance) { + ok = false; + if (world.rank() == 0) + std::cout << " MISMATCH at tile " << it.index() << " element " << k + << ": got " << tile.data()[k] << " expected " << expected + << "\n"; + break; + } + } + if (!ok) break; + } + if (world.rank() == 0) + std::cout << (ok ? " Verification PASSED\n" : " Verification FAILED\n"); +} + +} // namespace + +int try_main(int argc, char **argv) { + TiledArray::World &world = TA_SCOPED_INITIALIZE(argc, argv); + + if (argc < 7) { + if (world.rank() == 0) + std::cerr + << "Usage: " << argv[0] + << " Nm Bm Nn Bn Nk Bk [nrepeat=5]\n" + << " Computes c(Nm,Nn) = a(Nm,Nk) * b(Nk,Nn) with UMTensor tiles\n"; + return 1; + } + const long Nm = std::atol(argv[1]); + const long Bm = std::atol(argv[2]); + const long Nn = std::atol(argv[3]); + const long Bn = std::atol(argv[4]); + const long Nk = std::atol(argv[5]); + const long Bk = std::atol(argv[6]); + const long nrepeat = (argc >= 8 ? std::atol(argv[7]) : 5); + if (Nm <= 0 || Nn <= 0 || Nk <= 0 || Bm <= 0 || Bn <= 0 || Bk <= 0 || + nrepeat <= 0) { + if (world.rank() == 0) + std::cerr << "All sizes / blocks / nrepeat must be positive\n"; + return 1; + } + + run(world, Nm, Bm, Nn, Bn, Nk, Bk, nrepeat); + return 0; +} + +int main(int argc, char **argv) { + try { + return try_main(argc, argv); + } catch (const std::exception &e) { + std::cerr << "exception: " << e.what() << "\n"; + return 1; + } catch (...) { + std::cerr << "unknown exception\n"; + return 1; + } +} diff --git a/examples/device/ta_vector_um_tensor.cpp b/examples/device/ta_vector_um_tensor.cpp new file mode 100644 index 0000000000..8d0960be4f --- /dev/null +++ b/examples/device/ta_vector_um_tensor.cpp @@ -0,0 +1,157 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2026 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * Ajay Melekamburath + * Department of Chemistry, Virginia Tech + */ + +// Element-wise vector-op benchmarks (add, scale, permute, Hadamard) using +// the native UMTensor tile type. Companion to ta_vector_device.cpp. +// +// Usage: +// ta_vector_um_tensor Nm Bm Nn Bn [nrepeat=5] +// +// Times each op for nrepeat iterations and reports the average wall time +// and effective bandwidth (counting one read + one write per element for +// in-place ops, two reads + one write for binary ops). + +#include +#include + +#include +#include +#include + +namespace { + +template +void run(TiledArray::World &world, long Nm, long Bm, long Nn, long Bn, + long nrepeat) { + using TA::DistArray; + using TA::TiledRange; + using TA::TiledRange1; + using TA::UMTensor; + using TileT = UMTensor; + using ArrayT = DistArray; + + auto blocking = [](long N, long B) { + std::vector v; + for (long i = 0; i <= N; i += B) v.push_back(static_cast(i)); + return v; + }; + auto blk_m = blocking(Nm, Bm); + auto blk_n = blocking(Nn, Bn); + + TiledRange trange({TiledRange1(blk_m.begin(), blk_m.end()), + TiledRange1(blk_n.begin(), blk_n.end())}); + TiledRange trange_T({TiledRange1(blk_n.begin(), blk_n.end()), + TiledRange1(blk_m.begin(), blk_m.end())}); + + if (world.rank() == 0) + std::cout << "TiledArray UMTensor vector-op benchmark\n" + << " Nodes = " << world.size() << "\n" + << " Matrix = " << Nm << " x " << Nn << " (" + << double(Nm * Nn * sizeof(T)) / 1.0e9 << " GB)\n" + << " Tile = " << Bm << " x " << Bn << "\n" + << " Iterations = " << nrepeat << "\n"; + + ArrayT a(world, trange); + ArrayT b(world, trange); + ArrayT c(world, trange); + ArrayT t(world, trange_T); // transposed-shape result for permute test + + a.fill(T(0.03)); + b.fill(T(0.02)); + c.fill(T(0.0)); + t.fill(T(0.0)); + world.gop.fence(); + TA::to_device(a); + TA::to_device(b); + + const double bytes_per_elem = static_cast(sizeof(T)); + const double n_elems = static_cast(Nm) * static_cast(Nn); + + auto bench = [&](const char *name, double bytes_per_iter, auto &&op) { + double total_time = 0.0; + for (long i = 0; i < nrepeat; ++i) { + const double t0 = madness::wall_time(); + op(); + world.gop.fence(); + const double t1 = madness::wall_time(); + total_time += t1 - t0; + } + const double avg = total_time / static_cast(nrepeat); + const double bw_gbs = bytes_per_iter / (avg * 1.0e9); + if (world.rank() == 0) + std::cout << " " << name << ": avg=" << avg << " s bw=" << bw_gbs + << " GB/s\n"; + }; + + // Binary read-read-write: 3 element accesses per element. + const double rw3_bytes = 3.0 * n_elems * bytes_per_elem; + // Unary read-write: 2 element accesses per element. + const double rw2_bytes = 2.0 * n_elems * bytes_per_elem; + + bench("add(c=a+b)", rw3_bytes, [&] { c("m,n") = a("m,n") + b("m,n"); }); + bench("subt(c=a-b)", rw3_bytes, [&] { c("m,n") = a("m,n") - b("m,n"); }); + bench("scale(c=2*a)", rw2_bytes, [&] { c("m,n") = 2.0 * a("m,n"); }); + bench("hadamard(c=a*b)", rw3_bytes, [&] { c("m,n") = a("m,n") * b("m,n"); }); + bench("permute(t=a^T)", rw2_bytes, [&] { t("n,m") = a("m,n"); }); + bench("axpy(c+=a)", rw3_bytes, [&] { c("m,n") += a("m,n"); }); + + world.gop.fence(); +} + +} // namespace + +int try_main(int argc, char **argv) { + TiledArray::World &world = TA_SCOPED_INITIALIZE(argc, argv); + + if (argc < 5) { + if (world.rank() == 0) + std::cerr + << "Usage: " << argv[0] + << " Nm Bm Nn Bn [nrepeat=5]\n" + << " Times element-wise vector ops on Nm x Nn UMTensor matrices\n"; + return 1; + } + const long Nm = std::atol(argv[1]); + const long Bm = std::atol(argv[2]); + const long Nn = std::atol(argv[3]); + const long Bn = std::atol(argv[4]); + const long nrepeat = (argc >= 6 ? std::atol(argv[5]) : 5); + if (Nm <= 0 || Nn <= 0 || Bm <= 0 || Bn <= 0 || nrepeat <= 0) { + if (world.rank() == 0) + std::cerr << "All sizes / blocks / nrepeat must be positive\n"; + return 1; + } + + run(world, Nm, Bm, Nn, Bn, nrepeat); + return 0; +} + +int main(int argc, char **argv) { + try { + return try_main(argc, argv); + } catch (const std::exception &e) { + std::cerr << "exception: " << e.what() << "\n"; + return 1; + } catch (...) { + std::cerr << "unknown exception\n"; + return 1; + } +} From f0c6ec7fd05e397ee076ce748cf28b387db2cc0c Mon Sep 17 00:00:00 2001 From: Ajay Date: Wed, 13 May 2026 00:15:42 +0000 Subject: [PATCH 09/20] device: drop redundant compile_test_tier1 instantiation probe The runtime instantiation probe (`compile_test_tier1` plus the two `instantiate_tier1_*` function pointers) was added in the Phase 2 commits to force template type-checking of the device tile-op overloads before any test exercised them. It is no longer load-bearing: * tests/expressions_device_tensor.cpp (64 cases) calls every tier-1 overload through the expression engine for `double`, so any instantiation breakage shows up at test-build time. * device/tensor.cpp's explicit `template class Tensor>` instantiations cover the class members authoritatively. Removing the probe trims ~75 lines of dead code from the .cpp. The `is_device_tile_v` static_asserts are kept -- they are zero-cost and guarantee trait correctness even when BUILD_TESTING=OFF. --- src/TiledArray/device/tensor.cpp | 63 ++------------------------------ 1 file changed, 3 insertions(+), 60 deletions(-) diff --git a/src/TiledArray/device/tensor.cpp b/src/TiledArray/device/tensor.cpp index 4b5e128fea..98f82d42e5 100644 --- a/src/TiledArray/device/tensor.cpp +++ b/src/TiledArray/device/tensor.cpp @@ -54,8 +54,9 @@ template class Tensor>; namespace TiledArray::detail { -// Phase 1 sanity: confirm the is_device_tile specialization fires for the -// allocator alias and propagates through Tile<>. +// Compile-time guarantees on the trait wiring. Run before the test suite +// (and even when BUILD_TESTING=OFF) so a regression here breaks the +// library build instead of being deferred to a test failure. static_assert(is_device_tile_v>, "UMTensor must be tagged as a device tile"); static_assert(is_device_tile_v>, @@ -70,61 +71,3 @@ static_assert(!is_device_tile_v>, } // namespace TiledArray::detail -// Phase 2 instantiation probes: force the compiler to type-check the -// device-tile overloads. Real explicit instantiations land in Phase 4. -namespace { - -template -void compile_test_tier1() { - using TA::UMTensor; - using helper_t = TiledArray::math::GemmHelper; - UMTensor a, b, c; - helper_t h(TiledArray::math::blas::Op::NoTrans, - TiledArray::math::blas::Op::NoTrans, 2u, 2u, 2u); - - (void)TiledArray::clone(a); - (void)TiledArray::scale(a, T(2)); - (void)TiledArray::scale_to(a, T(2)); - (void)TiledArray::neg(a); - (void)TiledArray::neg_to(a); - (void)TiledArray::add(a, b); - (void)TiledArray::add(a, b, T(2)); - (void)TiledArray::add_to(a, b); - (void)TiledArray::add_to(a, b, T(2)); - (void)TiledArray::subt(a, b); - (void)TiledArray::subt(a, b, T(2)); - (void)TiledArray::subt_to(a, b); - (void)TiledArray::subt_to(a, b, T(2)); - (void)TiledArray::dot(a, b); - (void)TiledArray::squared_norm(a); - (void)TiledArray::norm(a); - (void)TiledArray::gemm(a, b, T(1), h); - TiledArray::gemm(c, a, b, T(1), h); - - // Phase 2b: permute / shift / mult and the perm-variants. - TiledArray::Permutation perm(std::vector{1, 0}); - TiledArray::BipartitePermutation bperm(perm); - std::vector shift{0, 0}; - (void)TiledArray::permute(a, perm); - (void)TiledArray::permute(a, bperm); - (void)TiledArray::shift(a, shift); - (void)TiledArray::shift_to(a, shift); - (void)TiledArray::scale(a, T(2), perm); - (void)TiledArray::neg(a, perm); - (void)TiledArray::add(a, b, perm); - (void)TiledArray::add(a, b, T(2), perm); - (void)TiledArray::subt(a, b, perm); - (void)TiledArray::subt(a, b, T(2), perm); - (void)TiledArray::mult(a, b); - (void)TiledArray::mult(a, b, T(2)); - (void)TiledArray::mult(a, b, perm); - (void)TiledArray::mult(a, b, T(2), perm); - (void)TiledArray::mult_to(a, b); - (void)TiledArray::mult_to(a, b, T(2)); -} - -[[maybe_unused]] auto instantiate_tier1_double = &compile_test_tier1; -[[maybe_unused]] auto instantiate_tier1_float = &compile_test_tier1; - -} // namespace - From a00c93af43c0c26a77d0fa52fec3e8768814bc80 Mon Sep 17 00:00:00 2001 From: Ajay Date: Wed, 13 May 2026 00:27:36 +0000 Subject: [PATCH 10/20] device: drop unused UMTensorArg concept; reword scaffolding comments * Remove the `UMTensorArg` concept and the `detail::is_um_tensor` trait it was built on. Both were introduced during the dispatch debugging to try a constrained forwarding-reference approach (per the comment block at the top of the section); that approach was abandoned in favor of two concrete-type overloads (`UMTensor&` and `UMTensor&&`) per in-place op, which beat the templated forwarding-ref candidate in partial ordering. The concept and trait are dead declarations; the explanatory comment is kept and trimmed. * Rewording: two inline comments referencing "Phase 2" and "Phase 3" development sequencing -- meaningless to a future reader -- are reworded to describe the rule without the historical label. Verified: 64-case device_tensor_expressions_suite still passes; no behavior change. --- src/TiledArray/device/tensor.h | 44 +++++++++-------------------- tests/expressions_device_tensor.cpp | 6 ++-- 2 files changed, 17 insertions(+), 33 deletions(-) diff --git a/src/TiledArray/device/tensor.h b/src/TiledArray/device/tensor.h index e486b8784d..359a059036 100644 --- a/src/TiledArray/device/tensor.h +++ b/src/TiledArray/device/tensor.h @@ -95,35 +95,19 @@ inline void to_host(const TiledArray::UMTensor& tile) { // To win the dispatch we provide two concrete-type overloads per in-place // op: one taking `UMTensor&` and one taking `UMTensor&&`. Concrete // types beat the templated forwarding reference `Result&&` in partial -// ordering regardless of the SFINAE / `requires` constraint shape, so this -// is robust against compiler differences. (A single forwarding-ref overload -// constrained with a `requires UMTensorArg<...>` concept would in principle -// also win because a constrained template subsumes an unconstrained one, -// but g++ does not consistently treat tile_interface's `enable_if`-only -// templates as unconstrained for this purpose -- the result is an ambiguous -// overload error. The two-concrete-overload form sidesteps the question.) +// ordering regardless of constraint shape, so this is robust against +// compiler differences. (A constrained forwarding-ref overload should in +// principle also win because a constrained template subsumes an +// unconstrained one, but g++ does not consistently treat +// tile_interface's `enable_if`-only templates as unconstrained for this +// purpose, leading to ambiguous-overload errors. Two concrete overloads +// sidestep the question.) // -// The lvalue overload forwards to the rvalue overload to keep a single -// implementation per op. Value-returning overloads (e.g. -// `add(const UMTensor&, const UMTensor&)`) don't need this because -// reference-to-const binds to both lvalues and rvalues. -// -// The `UMTensorArg` concept is kept around as documentation of intent and -// as a clean handle for any future helper that genuinely wants forwarding -// references (e.g. a `to_device` overload set). +// The lvalue overload does the work; the rvalue overload forwards to it. +// Value-returning overloads (e.g. `add(const UMTensor&, const UMTensor&)`) +// don't need this because reference-to-const binds to both lvalues and +// rvalues. // --------------------------------------------------------------------------- -namespace detail { -template -struct is_um_tensor : std::false_type {}; -template -struct is_um_tensor> : std::true_type {}; -template -inline constexpr bool is_um_tensor_v = - is_um_tensor>>::value; -} // namespace detail - -template -concept UMTensorArg = detail::is_um_tensor_v; // --------------------------------------------------------------------------- // Tile-op overloads for UMTensor. @@ -144,9 +128,9 @@ concept UMTensorArg = detail::is_um_tensor_v; // 4. `sync_madness_task_with(stream)` so the enclosing MADNESS device task // waits for the queue to drain before completing. // -// For Phase 2 batched tiles (`nbatch_ > 1`) are not yet supported -- the -// expression engine doesn't currently feed batched UMTensor through these -// paths, and dropping the assertion now would silently miscompute. +// Batched tiles (`nbatch_ > 1`) are not yet supported -- the expression +// engine doesn't currently feed batched UMTensor through these paths, and +// dropping the assertion would silently miscompute. // --------------------------------------------------------------------------- /// result[i] = arg[i] diff --git a/tests/expressions_device_tensor.cpp b/tests/expressions_device_tensor.cpp index 68ac90248d..5b7182454e 100644 --- a/tests/expressions_device_tensor.cpp +++ b/tests/expressions_device_tensor.cpp @@ -921,9 +921,9 @@ BOOST_AUTO_TEST_CASE(dot_contr) { } // --------------------------------------------------------------------------- -// Phase 3 surface: archive round-trip, host/device array conversions, and -// bulk to_host / to_device. Smoke + correctness for the helpers in -// device/tensor.h that are not in the expression-engine path. +// Archive round-trip, host/device array conversions, and bulk to_host / +// to_device. Smoke + correctness for the helpers in device/tensor.h that +// are not in the expression-engine path. // --------------------------------------------------------------------------- BOOST_AUTO_TEST_CASE(serialize_um_tensor) { // Single-tile round-trip: build a UMTensor, write to a buffer archive, From 1ad8c695a7f1ca4ec2532fdaf7fcb9865c8e23bd Mon Sep 17 00:00:00 2001 From: Ajay Date: Wed, 13 May 2026 04:56:21 +0000 Subject: [PATCH 11/20] chore: update .gitignore to avoide build directories --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 5112ca62ed..7eb9952c74 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,7 @@ # IDEs *.idea *.vscode + + +build/* +cmake-build* \ No newline at end of file From 05d7e5508ad56fb7f60bc3aafaf65a99833d1b6d Mon Sep 17 00:00:00 2001 From: Ajay Date: Wed, 13 May 2026 05:01:19 +0000 Subject: [PATCH 12/20] device: update host-device conversion helpers Widen um_tensor_to_ta_tensor / ta_tensor_to_um_tensor from to , mirroring the signature in device/btas_um_tensor.h:619+. Identity overloads cover the case where source and destination tiles coincide. --- src/TiledArray/device/tensor.h | 53 +++++++++++++++++++---------- tests/expressions_device_tensor.cpp | 10 ++++-- 2 files changed, 42 insertions(+), 21 deletions(-) diff --git a/src/TiledArray/device/tensor.h b/src/TiledArray/device/tensor.h index 359a059036..5bc40ccc8b 100644 --- a/src/TiledArray/device/tensor.h +++ b/src/TiledArray/device/tensor.h @@ -793,41 +793,58 @@ inline void to_device(TiledArray::DistArray, Policy>& array) { DeviceSafeCall(device::deviceSynchronize()); } -/// Convert a UMTensor-backed `DistArray` to one backed by host -/// `TA::Tensor`. Tile-by-tile copy through `to_new_tile_type` -- the -/// per-tile lambda allocates a host result, prefetches the source UM -/// buffer to host, and memcpys. -template -inline TiledArray::DistArray, Policy> +/// Convert a UMTensor-backed `DistArray` to one backed by a host tile type. +/// Template arg order `` matches the btas pair in +/// device/btas_um_tensor.h:619+. +template +inline std::enable_if_t, + TiledArray::DistArray> um_tensor_to_ta_tensor( - const TiledArray::DistArray, Policy>& um_array) { - auto convert_tile = [](const UMTensor& tile) { + const TiledArray::DistArray& um_array) { + auto convert_tile = [](const UMTile& tile) { detail::to_host(tile); - TiledArray::Tensor result(tile.range()); + HostTile result(tile.range()); std::copy_n(tile.data(), tile.total_size(), result.data()); return result; }; - auto out = to_new_tile_type>(um_array, convert_tile); + auto out = to_new_tile_type(um_array, convert_tile); um_array.world().gop.fence(); return out; } -/// Convert a host `TA::Tensor`-backed `DistArray` to a UMTensor-backed -/// one. Tile-by-tile copy: allocate UM, memcpy, prefetch to device. -template -inline TiledArray::DistArray, Policy> ta_tensor_to_um_tensor( - const TiledArray::DistArray, Policy>& host_array) { - auto convert_tile = [](const TiledArray::Tensor& tile) { - UMTensor result(tile.range()); +template +inline std::enable_if_t, + TiledArray::DistArray> +um_tensor_to_ta_tensor( + const TiledArray::DistArray& um_array) { + return um_array; +} + +/// Convert a host-tile-backed `DistArray` to a UMTensor-backed one. +template +inline std::enable_if_t, + TiledArray::DistArray> +ta_tensor_to_um_tensor( + const TiledArray::DistArray& host_array) { + auto convert_tile = [](const HostTile& tile) { + UMTile result(tile.range()); std::copy_n(tile.data(), tile.total_size(), result.data()); detail::to_device(result); return result; }; - auto out = to_new_tile_type>(host_array, convert_tile); + auto out = to_new_tile_type(host_array, convert_tile); host_array.world().gop.fence(); return out; } +template +inline std::enable_if_t, + TiledArray::DistArray> +ta_tensor_to_um_tensor( + const TiledArray::DistArray& host_array) { + return host_array; +} + } // namespace TiledArray // --------------------------------------------------------------------------- diff --git a/tests/expressions_device_tensor.cpp b/tests/expressions_device_tensor.cpp index 5b7182454e..d2f8f59606 100644 --- a/tests/expressions_device_tensor.cpp +++ b/tests/expressions_device_tensor.cpp @@ -973,12 +973,15 @@ BOOST_AUTO_TEST_CASE(serialize_um_tensor_empty) { BOOST_AUTO_TEST_CASE(um_to_ta_round_trip) { // UMTensor array -> host array -> UMTensor array, verify element-wise // against the original both at the host and device endpoints. - HostArray host_view = TiledArray::um_tensor_to_ta_tensor(a); + HostArray host_view = + TiledArray::um_tensor_to_ta_tensor(a); GlobalFixture::world->gop.fence(); // Compare host_view directly against a_h (same data was used for both). check_close(host_view, a_h, tolerance); - TArrayD device_view = TiledArray::ta_tensor_to_um_tensor(host_view); + TArrayD device_view = + TiledArray::ta_tensor_to_um_tensor( + host_view); GlobalFixture::world->gop.fence(); // Round-trip: device_view should match the original `a`. check_close(device_view, a_h, tolerance); @@ -990,7 +993,8 @@ BOOST_AUTO_TEST_CASE(um_to_ta_then_expression) { HostArray sum_h(*GlobalFixture::world, tr); sum_h("a,b,c") = a_h("a,b,c") + b_h("a,b,c"); - HostArray converted = TiledArray::um_tensor_to_ta_tensor(a); + HostArray converted = + TiledArray::um_tensor_to_ta_tensor(a); HostArray sum_from_device(*GlobalFixture::world, tr); sum_from_device("a,b,c") = converted("a,b,c") + b_h("a,b,c"); From ee31fa9ba3f152bcfb01d62c3cc3e17426236267 Mon Sep 17 00:00:00 2001 From: Ajay Date: Wed, 13 May 2026 12:58:23 +0000 Subject: [PATCH 13/20] device: constrain UMTensor tile ops to numeric element types The tile-op overloads for UMTensor dispatch BLAS-on-pointer kernels (scal, copy, gemm, dot, etc.). When CCk is instantiated for the device tile, it composes nested tiles of the form UMTensor> to represent its outer/inner structure; those nested tiles must route through TA::Tensor member ops, not the device kernels. Add `requires TiledArray::detail::is_numeric_v` to each device tile-op overload, and gate `is_device_tile>` and the archive store helper on the same predicate so non-numeric element types fall through to the host tile path. --- src/TiledArray/device/tensor.h | 62 ++++++++++++++++++++++++++++------ 1 file changed, 52 insertions(+), 10 deletions(-) diff --git a/src/TiledArray/device/tensor.h b/src/TiledArray/device/tensor.h index 5bc40ccc8b..507dd316e3 100644 --- a/src/TiledArray/device/tensor.h +++ b/src/TiledArray/device/tensor.h @@ -49,13 +49,15 @@ namespace detail { /// tile ops through `madness::add_device_task`. The pass-through specs for /// `Tile` and `LazyArrayTile` in tensor/type_traits.h pick this up. template -struct is_device_tile> : public std::true_type {}; +struct is_device_tile> + : public std::bool_constant> {}; /// Prefetch a UMTensor's storage to the device associated with its tile range. /// Mirrors the pattern in device/btas_um_tensor.h but reaches the storage via /// `.data()` + `.total_size()` since `TA::Tensor`'s buffer is a /// `shared_ptr` rather than a varray-like container. template + requires TiledArray::detail::is_numeric_v inline void to_device(const TiledArray::UMTensor& tile) { if (tile.empty()) return; auto stream = device::stream_for(tile.range()); @@ -68,6 +70,7 @@ inline void to_device(const TiledArray::UMTensor& tile) { /// Prefetch a UMTensor's storage back to the host. template + requires TiledArray::detail::is_numeric_v inline void to_host(const TiledArray::UMTensor& tile) { if (tile.empty()) return; auto stream = device::stream_for(tile.range()); @@ -135,6 +138,7 @@ inline void to_host(const TiledArray::UMTensor& tile) { /// result[i] = arg[i] template + requires TiledArray::detail::is_numeric_v inline UMTensor clone(const UMTensor& arg) { TA_ASSERT(!arg.empty()); TA_ASSERT(arg.nbatch() == 1); @@ -188,6 +192,7 @@ inline void apply_scale_factor(T* data, std::size_t n, const Scalar factor, /// result[i] = arg[i] * factor template >> + requires TiledArray::detail::is_numeric_v inline UMTensor scale(const UMTensor& arg, const Scalar factor) { auto result = clone(arg); auto& queue = blasqueue_for(result.range()); @@ -202,7 +207,7 @@ inline UMTensor scale(const UMTensor& arg, const Scalar factor) { /// here rather than to the tile_interface forwarder that would call the CPU /// member function on UM memory. template - requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v inline UMTensor& scale_to(UMTensor& result, const Scalar factor) { TA_ASSERT(!result.empty()); TA_ASSERT(result.nbatch() == 1); @@ -216,30 +221,34 @@ inline UMTensor& scale_to(UMTensor& result, const Scalar factor) { } template - requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v inline UMTensor& scale_to(UMTensor&& result, const Scalar factor) { return scale_to(result, factor); } /// result[i] = -arg[i] template + requires TiledArray::detail::is_numeric_v inline UMTensor neg(const UMTensor& arg) { return scale(arg, T(-1)); } /// arg[i] = -arg[i] (in-place) template + requires TiledArray::detail::is_numeric_v inline UMTensor& neg_to(UMTensor& arg) { return scale_to(arg, T(-1)); } template + requires TiledArray::detail::is_numeric_v inline UMTensor& neg_to(UMTensor&& arg) { return scale_to(arg, T(-1)); } /// result[i] = arg1[i] + arg2[i] template + requires TiledArray::detail::is_numeric_v inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2) { TA_ASSERT(!arg1.empty()); TA_ASSERT(!arg2.empty()); @@ -265,6 +274,7 @@ inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2) { /// result[i] = (arg1[i] + arg2[i]) * factor template >> + requires TiledArray::detail::is_numeric_v inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2, const Scalar factor) { auto result = add(arg1, arg2); @@ -273,6 +283,7 @@ inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2, /// result[i] += arg[i] template + requires TiledArray::detail::is_numeric_v inline UMTensor& add_to(UMTensor& result, const UMTensor& arg) { TA_ASSERT(!result.empty()); TA_ASSERT(!arg.empty()); @@ -292,6 +303,7 @@ inline UMTensor& add_to(UMTensor& result, const UMTensor& arg) { } template + requires TiledArray::detail::is_numeric_v inline UMTensor& add_to(UMTensor&& result, const UMTensor& arg) { return add_to(result, arg); } @@ -299,7 +311,7 @@ inline UMTensor& add_to(UMTensor&& result, const UMTensor& arg) { /// result[i] = (result[i] + arg[i]) * factor /// Matches TA::Tensor::add_to(right, factor) semantics: `(l += r) *= factor`. template - requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v inline UMTensor& add_to(UMTensor& result, const UMTensor& arg, const Scalar factor) { add_to(result, arg); @@ -307,7 +319,7 @@ inline UMTensor& add_to(UMTensor& result, const UMTensor& arg, } template - requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v inline UMTensor& add_to(UMTensor&& result, const UMTensor& arg, const Scalar factor) { return add_to(result, arg, factor); @@ -315,6 +327,7 @@ inline UMTensor& add_to(UMTensor&& result, const UMTensor& arg, /// result[i] = arg1[i] - arg2[i] template + requires TiledArray::detail::is_numeric_v inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2) { TA_ASSERT(!arg1.empty()); TA_ASSERT(!arg2.empty()); @@ -340,6 +353,7 @@ inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2) { /// result[i] = (arg1[i] - arg2[i]) * factor template >> + requires TiledArray::detail::is_numeric_v inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2, const Scalar factor) { auto result = subt(arg1, arg2); @@ -348,6 +362,7 @@ inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2, /// result[i] -= arg[i] template + requires TiledArray::detail::is_numeric_v inline UMTensor& subt_to(UMTensor& result, const UMTensor& arg) { TA_ASSERT(!result.empty()); TA_ASSERT(!arg.empty()); @@ -367,6 +382,7 @@ inline UMTensor& subt_to(UMTensor& result, const UMTensor& arg) { } template + requires TiledArray::detail::is_numeric_v inline UMTensor& subt_to(UMTensor&& result, const UMTensor& arg) { return subt_to(result, arg); } @@ -382,7 +398,7 @@ inline UMTensor& subt_to(UMTensor&& result, const UMTensor& arg) { /// TA::Tensor's CPU member function and races with any in-flight device /// kernel on UM memory. template - requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v inline UMTensor& subt_to(UMTensor& result, const UMTensor& arg, const Scalar factor) { subt_to(result, arg); @@ -390,7 +406,7 @@ inline UMTensor& subt_to(UMTensor& result, const UMTensor& arg, } template - requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v inline UMTensor& subt_to(UMTensor&& result, const UMTensor& arg, const Scalar factor) { return subt_to(result, arg, factor); @@ -398,6 +414,7 @@ inline UMTensor& subt_to(UMTensor&& result, const UMTensor& arg, /// dot product: scalar = sum_i arg1[i] * arg2[i] template + requires TiledArray::detail::is_numeric_v inline T dot(const UMTensor& arg1, const UMTensor& arg2) { TA_ASSERT(!arg1.empty()); TA_ASSERT(!arg2.empty()); @@ -420,12 +437,14 @@ inline T dot(const UMTensor& arg1, const UMTensor& arg2) { /// scalar = sum_i arg[i] * arg[i] template + requires TiledArray::detail::is_numeric_v inline auto squared_norm(const UMTensor& arg) { return dot(arg, arg); } /// scalar = sqrt(squared_norm(arg)) template + requires TiledArray::detail::is_numeric_v inline auto norm(const UMTensor& arg) { using std::sqrt; using ResultType = TiledArray::detail::scalar_t; @@ -434,6 +453,7 @@ inline auto norm(const UMTensor& arg) { /// result[perm(i)] = arg[i] template + requires TiledArray::detail::is_numeric_v inline UMTensor permute(const UMTensor& arg, const TiledArray::Permutation& perm) { TA_ASSERT(!arg.empty()); @@ -463,6 +483,7 @@ inline UMTensor permute(const UMTensor& arg, /// Required to win ADL against the generic CPU member-delegating overload; /// see the matching warning in device/btas_um_tensor.h:193. template + requires TiledArray::detail::is_numeric_v inline UMTensor permute(const UMTensor& arg, const TiledArray::BipartitePermutation& perm) { TA_ASSERT(inner_size(perm) == 0); // UMTensor is a non-nested tile @@ -473,6 +494,7 @@ inline UMTensor permute(const UMTensor& arg, template && TiledArray::detail::is_permutation_v>> + requires TiledArray::detail::is_numeric_v inline UMTensor scale(const UMTensor& arg, const Scalar factor, const Perm& perm) { auto scaled = scale(arg, factor); @@ -482,6 +504,7 @@ inline UMTensor scale(const UMTensor& arg, const Scalar factor, /// result[perm(i)] = -arg[i] template >> + requires TiledArray::detail::is_numeric_v inline UMTensor neg(const UMTensor& arg, const Perm& perm) { return permute(neg(arg), perm); } @@ -489,6 +512,7 @@ inline UMTensor neg(const UMTensor& arg, const Perm& perm) { /// result[perm(i)] = arg1[i] + arg2[i] template >> + requires TiledArray::detail::is_numeric_v inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2, const Perm& perm) { return permute(add(arg1, arg2), perm); @@ -498,6 +522,7 @@ inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2, template && TiledArray::detail::is_permutation_v>> + requires TiledArray::detail::is_numeric_v inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2, const Scalar factor, const Perm& perm) { return permute(add(arg1, arg2, factor), perm); @@ -506,6 +531,7 @@ inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2, /// result[perm(i)] = arg1[i] - arg2[i] template >> + requires TiledArray::detail::is_numeric_v inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2, const Perm& perm) { return permute(subt(arg1, arg2), perm); @@ -515,6 +541,7 @@ inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2, template && TiledArray::detail::is_permutation_v>> + requires TiledArray::detail::is_numeric_v inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2, const Scalar factor, const Perm& perm) { return permute(subt(arg1, arg2, factor), perm); @@ -522,6 +549,7 @@ inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2, /// shift: result has arg's data, range shifted by bound_shift. template + requires TiledArray::detail::is_numeric_v inline UMTensor shift(const UMTensor& arg, const Index& bound_shift) { TA_ASSERT(!arg.empty()); TA_ASSERT(arg.nbatch() == 1); @@ -546,18 +574,21 @@ inline UMTensor shift(const UMTensor& arg, const Index& bound_shift) { /// shift_to: in-place range shift, no data movement. template + requires TiledArray::detail::is_numeric_v inline UMTensor& shift_to(UMTensor& arg, const Index& bound_shift) { const_cast(arg.range()).inplace_shift(bound_shift); return arg; } template + requires TiledArray::detail::is_numeric_v inline UMTensor& shift_to(UMTensor&& arg, const Index& bound_shift) { return shift_to(arg, bound_shift); } /// result[i] = arg1[i] * arg2[i] (element-wise / Hadamard) template + requires TiledArray::detail::is_numeric_v inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2) { TA_ASSERT(!arg1.empty()); TA_ASSERT(!arg2.empty()); @@ -584,6 +615,7 @@ inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2) { /// result[i] = arg1[i] * arg2[i] * factor template >> + requires TiledArray::detail::is_numeric_v inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2, const Scalar factor) { auto result = mult(arg1, arg2); @@ -593,6 +625,7 @@ inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2, /// result[perm(i)] = arg1[i] * arg2[i] template >> + requires TiledArray::detail::is_numeric_v inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2, const Perm& perm) { return permute(mult(arg1, arg2), perm); @@ -602,6 +635,7 @@ inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2, template && TiledArray::detail::is_permutation_v>> + requires TiledArray::detail::is_numeric_v inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2, const Scalar factor, const Perm& perm) { return permute(mult(arg1, arg2, factor), perm); @@ -609,6 +643,7 @@ inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2, /// result[i] *= arg[i] template + requires TiledArray::detail::is_numeric_v inline UMTensor& mult_to(UMTensor& result, const UMTensor& arg) { TA_ASSERT(!result.empty()); TA_ASSERT(!arg.empty()); @@ -629,6 +664,7 @@ inline UMTensor& mult_to(UMTensor& result, const UMTensor& arg) { } template + requires TiledArray::detail::is_numeric_v inline UMTensor& mult_to(UMTensor&& result, const UMTensor& arg) { return mult_to(result, arg); } @@ -636,7 +672,7 @@ inline UMTensor& mult_to(UMTensor&& result, const UMTensor& arg) { /// result[i] = (result[i] * arg[i]) * factor /// Matches TA::Tensor::mult_to(right, factor) semantics: `(l *= r) *= factor`. template - requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v inline UMTensor& mult_to(UMTensor& result, const UMTensor& arg, const Scalar factor) { mult_to(result, arg); @@ -644,7 +680,7 @@ inline UMTensor& mult_to(UMTensor& result, const UMTensor& arg, } template - requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v inline UMTensor& mult_to(UMTensor&& result, const UMTensor& arg, const Scalar factor) { return mult_to(result, arg, factor); @@ -653,6 +689,7 @@ inline UMTensor& mult_to(UMTensor&& result, const UMTensor& arg, /// gemm: returning form. result = factor * left * right template >> + requires TiledArray::detail::is_numeric_v inline UMTensor gemm(const UMTensor& left, const UMTensor& right, const Scalar factor, const TiledArray::math::GemmHelper& gemm_helper) { @@ -704,6 +741,7 @@ inline UMTensor gemm(const UMTensor& left, const UMTensor& right, /// gemm: accumulating form. result += factor * left * right template >> + requires TiledArray::detail::is_numeric_v inline void gemm(UMTensor& result, const UMTensor& left, const UMTensor& right, const Scalar factor, const TiledArray::math::GemmHelper& gemm_helper) { @@ -762,6 +800,7 @@ inline void gemm(UMTensor& result, const UMTensor& left, /// Prefetch every local tile of `array` to the host. Fences on the /// containing world and globally synchronizes the device on exit. template + requires TiledArray::detail::is_numeric_v inline void to_host(TiledArray::DistArray, Policy>& array) { auto prefetch = [](UMTensor& tile) { auto stream = device::stream_for(tile.range()); @@ -779,6 +818,7 @@ inline void to_host(TiledArray::DistArray, Policy>& array) { /// Prefetch every local tile of `array` to the device. Fences on the /// containing world and globally synchronizes the device on exit. template + requires TiledArray::detail::is_numeric_v inline void to_device(TiledArray::DistArray, Policy>& array) { auto prefetch = [](UMTensor& tile) { auto stream = device::stream_for(tile.range()); @@ -865,7 +905,9 @@ template struct ArchiveStoreImpl> { static inline void store(const Archive& ar, const TiledArray::UMTensor& t) { - TiledArray::detail::to_host(t); + if constexpr (TiledArray::detail::is_numeric_v) { + TiledArray::detail::to_host(t); + } // Mirror TA::Tensor::serialize's store side; we cannot call the member // because it is non-const and we want to keep the input parameter // const-correct. From bab98d38e7bb5cd34497ea4ce9249dcc6bac3255 Mon Sep 17 00:00:00 2001 From: Ajay Date: Thu, 21 May 2026 05:18:34 +0000 Subject: [PATCH 14/20] refactor: move static_asserts to tests and fixup comments in device/tensor.cpp --- src/TiledArray/device/tensor.cpp | 36 +++-------------------------- tests/expressions_device_tensor.cpp | 2 ++ 2 files changed, 5 insertions(+), 33 deletions(-) diff --git a/src/TiledArray/device/tensor.cpp b/src/TiledArray/device/tensor.cpp index 98f82d42e5..3065d12e1f 100644 --- a/src/TiledArray/device/tensor.cpp +++ b/src/TiledArray/device/tensor.cpp @@ -27,20 +27,9 @@ namespace TiledArray { // Explicit instantiations of the UMTensor class for the standard numeric // types. Without these, every TU including device/tensor.h would instantiate -// the full TA::Tensor> class body (~3000 lines of -// templated members) -- the matching `extern template` declarations in -// device/tensor.h suppress that per-TU work and route consumers to the -// symbols defined here. -// -// The list mirrors `src/TiledArray/tensor/tensor.cpp`'s host-side set -// (double, float, complex variants), plus int/long which are cheap to -// instantiate and useful for index-tile use cases. BLAS-bearing free -// functions (`gemm`, `scale`, ...) are still header-defined templates -- -// instantiating those for each numeric type would pull in the full -// BLAS++/librett surface here, and the compile-time saving from -// extern-templating them does not justify it. They get instantiated lazily -// in whichever TU actually calls them (typically the test or example TU). - +// the full TA::Tensor> class body. +// Mirrors the host-side set in tensor/tensor.cpp; paired with the +// `extern template` declarations in device/tensor.h. template class Tensor>; template class Tensor>; template class Tensor, @@ -52,22 +41,3 @@ template class Tensor>; } // namespace TiledArray -namespace TiledArray::detail { - -// Compile-time guarantees on the trait wiring. Run before the test suite -// (and even when BUILD_TESTING=OFF) so a regression here breaks the -// library build instead of being deferred to a test failure. -static_assert(is_device_tile_v>, - "UMTensor must be tagged as a device tile"); -static_assert(is_device_tile_v>, - "UMTensor must be tagged as a device tile"); -static_assert( - is_device_tile_v>>, - "UMTensor> must be tagged as a device tile"); -static_assert(is_device_tile_v>>, - "Tile> must propagate the device-tile tag"); -static_assert(!is_device_tile_v>, - "Plain Tensor must not be tagged as a device tile"); - -} // namespace TiledArray::detail - diff --git a/tests/expressions_device_tensor.cpp b/tests/expressions_device_tensor.cpp index d2f8f59606..97b0b16fa3 100644 --- a/tests/expressions_device_tensor.cpp +++ b/tests/expressions_device_tensor.cpp @@ -119,6 +119,8 @@ BOOST_AUTO_TEST_CASE(is_device_tile_classification) { using detail::is_device_tile_v; BOOST_CHECK(is_device_tile_v>); BOOST_CHECK(is_device_tile_v>); + BOOST_CHECK(is_device_tile_v>>); + BOOST_CHECK(is_device_tile_v>>); BOOST_CHECK(is_device_tile_v); BOOST_CHECK(!is_device_tile_v); } From 6e570cdb1f5e9cd2fa6460b7e16c9ff9793f14c1 Mon Sep 17 00:00:00 2001 From: Ajay Date: Thu, 21 May 2026 06:01:04 +0000 Subject: [PATCH 15/20] device: tidy UMTensor tile-op overloads (requires-only constraints, TA_EXCEPTION, doc fixups, formatting) --- src/TiledArray/device/tensor.h | 300 ++++++++++++++------------------- 1 file changed, 122 insertions(+), 178 deletions(-) diff --git a/src/TiledArray/device/tensor.h b/src/TiledArray/device/tensor.h index 507dd316e3..331da8c2f4 100644 --- a/src/TiledArray/device/tensor.h +++ b/src/TiledArray/device/tensor.h @@ -39,23 +39,20 @@ #include #include -#include #include +#include namespace TiledArray { namespace detail { -/// `UMTensor` lives in unified memory; the expression engine must route its -/// tile ops through `madness::add_device_task`. The pass-through specs for -/// `Tile` and `LazyArrayTile` in tensor/type_traits.h pick this up. +/// UMTensor lives in unified memory; it is identified as a device_tile and +/// the expression engine must route its tile ops through +/// madness::add_device_task. template struct is_device_tile> : public std::bool_constant> {}; /// Prefetch a UMTensor's storage to the device associated with its tile range. -/// Mirrors the pattern in device/btas_um_tensor.h but reaches the storage via -/// `.data()` + `.total_size()` since `TA::Tensor`'s buffer is a -/// `shared_ptr` rather than a varray-like container. template requires TiledArray::detail::is_numeric_v inline void to_device(const TiledArray::UMTensor& tile) { @@ -75,66 +72,44 @@ inline void to_host(const TiledArray::UMTensor& tile) { if (tile.empty()) return; auto stream = device::stream_for(tile.range()); if (deviceEnv::instance()->concurrent_managed_access()) { - DeviceSafeCall(device::memPrefetchAsync(tile.data(), - tile.total_size() * sizeof(T), - device::CpuDeviceId, stream.stream)); + DeviceSafeCall( + device::memPrefetchAsync(tile.data(), tile.total_size() * sizeof(T), + device::CpuDeviceId, stream.stream)); } } } // namespace detail -// --------------------------------------------------------------------------- -// In-place tile ops are tricky to dispatch correctly. -// -// `tile_op/{subt,add,mult,...}.h::Op::eval` passes the result via -// `std::move(...)` when the operand is consumable -- so the engine calls our -// `subt_to`, `add_to`, etc. with an rvalue. A plain `UMTensor&` overload -// is not a viable candidate for an rvalue, so overload resolution falls -// through to the generic forwarder in `tile_op/tile_interface.h` (and -// `tile_interface/scale.h`). That forwarder delegates to TA::Tensor's CPU -// member function, which then reads UM memory while the previous device -// kernel is still in flight on the queue -- silently miscomputing. -// -// To win the dispatch we provide two concrete-type overloads per in-place -// op: one taking `UMTensor&` and one taking `UMTensor&&`. Concrete -// types beat the templated forwarding reference `Result&&` in partial -// ordering regardless of constraint shape, so this is robust against -// compiler differences. (A constrained forwarding-ref overload should in -// principle also win because a constrained template subsumes an -// unconstrained one, but g++ does not consistently treat -// tile_interface's `enable_if`-only templates as unconstrained for this -// purpose, leading to ambiguous-overload errors. Two concrete overloads -// sidestep the question.) -// -// The lvalue overload does the work; the rvalue overload forwards to it. -// Value-returning overloads (e.g. `add(const UMTensor&, const UMTensor&)`) -// don't need this because reference-to-const binds to both lvalues and -// rvalues. -// --------------------------------------------------------------------------- - -// --------------------------------------------------------------------------- -// Tile-op overloads for UMTensor. -// -// Each overload sits in `namespace TiledArray` so ADL finds it from the -// expression engine and from the tile_op layer's free-function defaults. -// More-specialized concrete-type overloads win against the generic -// `template ... add(left, right) { return -// left.add(right); }` forwarders in `tile_op/tile_interface.h`, so we never -// fall back to the CPU member functions for UMTensor. -// -// All overloads follow the stream/queue contract: -// 1. Resolve a queue via `blasqueue_for(range)`. Inside a device task this -// is the same queue everyone else in the task uses (see -// `external/device.h:899-907`); outside one, it round-robins. -// 2. Prefetch every input + the result to the device. -// 3. Call into BLAS++ / device kernels on that queue. -// 4. `sync_madness_task_with(stream)` so the enclosing MADNESS device task -// waits for the queue to drain before completing. -// -// Batched tiles (`nbatch_ > 1`) are not yet supported -- the expression -// engine doesn't currently feed batched UMTensor through these paths, and -// dropping the assertion would silently miscompute. -// --------------------------------------------------------------------------- +// clang-format off +/// Tile-op overloads for UMTensor. +/// +/// Each overload sits in `namespace TiledArray` so ADL finds it from the +/// expression engine and from the tile_op layer's free-function defaults. +/// More-specialized concrete-type overloads win against the generic +/// forwarder in `tile_op/tile_interface.h`: +/// \code +/// template +/// auto add(Left&& left, Right&& right) { +/// return left.add(right); +/// } +/// \endcode +/// so we never fall back to the CPU member functions for UMTensor. +/// +/// All overloads follow the stream/queue contract: +/// 1. Resolve a queue via `blasqueue_for(range)`. Inside a device task +/// this is the same queue everyone else in the task uses (see +/// `external/device.h:899-907`); outside one, it round-robins. +/// 2. Prefetch every input + the result to the device. +/// 3. Call into BLAS++ / device kernels on that queue. +/// 4. `sync_madness_task_with(stream)` so the enclosing MADNESS device +/// task waits for the queue to drain before completing. +/// +/// In-place ops provide both an lvalue and an rvalue overload: the lvalue +/// overload does the work, the rvalue overload forwards to it. +/// +/// nbatch_ > 1 is not yet supported; the host-side tile +/// ops don't support them either. +// clang-format on /// result[i] = arg[i] template @@ -172,7 +147,9 @@ inline void apply_scale_factor(T* data, std::size_t n, const Scalar factor, ::blas::scal(n, factor, data, 1, queue); } else { if constexpr (TiledArray::detail::is_complex_v) { - abort(); // fused conjugation requires custom kernels, not yet supported + TA_EXCEPTION( + "UMTensor scale with ComplexConjugate factor on complex T is not " + "implemented (requires a fused conjugation kernel)"); } else { if constexpr (std::is_same_v< Scalar, TiledArray::detail::ComplexConjugate>) { @@ -190,9 +167,9 @@ inline void apply_scale_factor(T* data, std::size_t n, const Scalar factor, } // namespace detail /// result[i] = arg[i] * factor -template >> - requires TiledArray::detail::is_numeric_v +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v inline UMTensor scale(const UMTensor& arg, const Scalar factor) { auto result = clone(arg); auto& queue = blasqueue_for(result.range()); @@ -202,12 +179,10 @@ inline UMTensor scale(const UMTensor& arg, const Scalar factor) { return result; } -/// result[i] *= factor (in-place). Forwarding-reference form so the engine's -/// `scale_to(std::move(tile), factor)` (from `tile_op/scal.h:82`) dispatches -/// here rather than to the tile_interface forwarder that would call the CPU -/// member function on UM memory. +/// result[i] *= factor (in-place) template - requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v inline UMTensor& scale_to(UMTensor& result, const Scalar factor) { TA_ASSERT(!result.empty()); TA_ASSERT(result.nbatch() == 1); @@ -221,7 +196,8 @@ inline UMTensor& scale_to(UMTensor& result, const Scalar factor) { } template - requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v inline UMTensor& scale_to(UMTensor&& result, const Scalar factor) { return scale_to(result, factor); } @@ -243,7 +219,7 @@ inline UMTensor& neg_to(UMTensor& arg) { template requires TiledArray::detail::is_numeric_v inline UMTensor& neg_to(UMTensor&& arg) { - return scale_to(arg, T(-1)); + return neg_to(arg); } /// result[i] = arg1[i] + arg2[i] @@ -272,9 +248,9 @@ inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2) { } /// result[i] = (arg1[i] + arg2[i]) * factor -template >> - requires TiledArray::detail::is_numeric_v +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2, const Scalar factor) { auto result = add(arg1, arg2); @@ -311,7 +287,8 @@ inline UMTensor& add_to(UMTensor&& result, const UMTensor& arg) { /// result[i] = (result[i] + arg[i]) * factor /// Matches TA::Tensor::add_to(right, factor) semantics: `(l += r) *= factor`. template - requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v inline UMTensor& add_to(UMTensor& result, const UMTensor& arg, const Scalar factor) { add_to(result, arg); @@ -319,7 +296,8 @@ inline UMTensor& add_to(UMTensor& result, const UMTensor& arg, } template - requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v inline UMTensor& add_to(UMTensor&& result, const UMTensor& arg, const Scalar factor) { return add_to(result, arg, factor); @@ -351,9 +329,9 @@ inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2) { } /// result[i] = (arg1[i] - arg2[i]) * factor -template >> - requires TiledArray::detail::is_numeric_v +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2, const Scalar factor) { auto result = subt(arg1, arg2); @@ -389,16 +367,9 @@ inline UMTensor& subt_to(UMTensor&& result, const UMTensor& arg) { /// result[i] = (result[i] - arg[i]) * factor /// Matches TA::Tensor::subt_to(right, factor) semantics: `(l -= r) *= factor`. -/// This convention is load-bearing for `tile_op/subt.h::Subt::eval` -- when -/// the engine reuses the right operand's storage, it calls -/// `subt_to(std::move(second), first, -1)` and relies on the result being -/// `(second - first) * -1 = first - second`. Hence the forwarding reference -/// on `result`: lvalue-only signatures lose overload resolution to the -/// templated forwarder in tile_op/tile_interface.h, which then dispatches to -/// TA::Tensor's CPU member function and races with any in-flight device -/// kernel on UM memory. template - requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v inline UMTensor& subt_to(UMTensor& result, const UMTensor& arg, const Scalar factor) { subt_to(result, arg); @@ -406,7 +377,8 @@ inline UMTensor& subt_to(UMTensor& result, const UMTensor& arg, } template - requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v inline UMTensor& subt_to(UMTensor&& result, const UMTensor& arg, const Scalar factor) { return subt_to(result, arg, factor); @@ -480,8 +452,6 @@ inline UMTensor permute(const UMTensor& arg, } /// BipartitePermutation -> plain Permutation forward. -/// Required to win ADL against the generic CPU member-delegating overload; -/// see the matching warning in device/btas_um_tensor.h:193. template requires TiledArray::detail::is_numeric_v inline UMTensor permute(const UMTensor& arg, @@ -491,10 +461,10 @@ inline UMTensor permute(const UMTensor& arg, } /// result[perm(i)] = arg[i] * factor -template && - TiledArray::detail::is_permutation_v>> - requires TiledArray::detail::is_numeric_v +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v && + TiledArray::detail::is_permutation_v inline UMTensor scale(const UMTensor& arg, const Scalar factor, const Perm& perm) { auto scaled = scale(arg, factor); @@ -502,46 +472,46 @@ inline UMTensor scale(const UMTensor& arg, const Scalar factor, } /// result[perm(i)] = -arg[i] -template >> - requires TiledArray::detail::is_numeric_v +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_permutation_v inline UMTensor neg(const UMTensor& arg, const Perm& perm) { return permute(neg(arg), perm); } /// result[perm(i)] = arg1[i] + arg2[i] -template >> - requires TiledArray::detail::is_numeric_v +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_permutation_v inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2, const Perm& perm) { return permute(add(arg1, arg2), perm); } /// result[perm(i)] = (arg1[i] + arg2[i]) * factor -template && - TiledArray::detail::is_permutation_v>> - requires TiledArray::detail::is_numeric_v +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v && + TiledArray::detail::is_permutation_v inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2, const Scalar factor, const Perm& perm) { return permute(add(arg1, arg2, factor), perm); } /// result[perm(i)] = arg1[i] - arg2[i] -template >> - requires TiledArray::detail::is_numeric_v +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_permutation_v inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2, const Perm& perm) { return permute(subt(arg1, arg2), perm); } /// result[perm(i)] = (arg1[i] - arg2[i]) * factor -template && - TiledArray::detail::is_permutation_v>> - requires TiledArray::detail::is_numeric_v +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v && + TiledArray::detail::is_permutation_v inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2, const Scalar factor, const Perm& perm) { return permute(subt(arg1, arg2, factor), perm); @@ -576,6 +546,8 @@ inline UMTensor shift(const UMTensor& arg, const Index& bound_shift) { template requires TiledArray::detail::is_numeric_v inline UMTensor& shift_to(UMTensor& arg, const Index& bound_shift) { + // `range()` only exposes a const accessor; cast is safe because we are the + // tile's owner here and only the range bounds change, not the data layout. const_cast(arg.range()).inplace_shift(bound_shift); return arg; } @@ -613,9 +585,9 @@ inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2) { } /// result[i] = arg1[i] * arg2[i] * factor -template >> - requires TiledArray::detail::is_numeric_v +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2, const Scalar factor) { auto result = mult(arg1, arg2); @@ -623,19 +595,19 @@ inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2, } /// result[perm(i)] = arg1[i] * arg2[i] -template >> - requires TiledArray::detail::is_numeric_v +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_permutation_v inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2, const Perm& perm) { return permute(mult(arg1, arg2), perm); } /// result[perm(i)] = arg1[i] * arg2[i] * factor -template && - TiledArray::detail::is_permutation_v>> - requires TiledArray::detail::is_numeric_v +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v && + TiledArray::detail::is_permutation_v inline UMTensor mult(const UMTensor& arg1, const UMTensor& arg2, const Scalar factor, const Perm& perm) { return permute(mult(arg1, arg2, factor), perm); @@ -672,7 +644,8 @@ inline UMTensor& mult_to(UMTensor&& result, const UMTensor& arg) { /// result[i] = (result[i] * arg[i]) * factor /// Matches TA::Tensor::mult_to(right, factor) semantics: `(l *= r) *= factor`. template - requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v inline UMTensor& mult_to(UMTensor& result, const UMTensor& arg, const Scalar factor) { mult_to(result, arg); @@ -680,16 +653,17 @@ inline UMTensor& mult_to(UMTensor& result, const UMTensor& arg, } template - requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v inline UMTensor& mult_to(UMTensor&& result, const UMTensor& arg, const Scalar factor) { return mult_to(result, arg, factor); } -/// gemm: returning form. result = factor * left * right -template >> - requires TiledArray::detail::is_numeric_v +/// gemm: result = factor * left * right +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v inline UMTensor gemm(const UMTensor& left, const UMTensor& right, const Scalar factor, const TiledArray::math::GemmHelper& gemm_helper) { @@ -738,10 +712,10 @@ inline UMTensor gemm(const UMTensor& left, const UMTensor& right, return result; } -/// gemm: accumulating form. result += factor * left * right -template >> - requires TiledArray::detail::is_numeric_v +/// gemm: result += factor * left * right +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v inline void gemm(UMTensor& result, const UMTensor& left, const UMTensor& right, const Scalar factor, const TiledArray::math::GemmHelper& gemm_helper) { @@ -783,19 +757,8 @@ inline void gemm(UMTensor& result, const UMTensor& left, device::sync_madness_task_with(stream); } -// --------------------------------------------------------------------------- -// Array-level helpers: bulk to-host / to-device prefetch and conversions -// between UMTensor-backed and host-Tensor-backed DistArrays. Mirrors the -// btas-device helpers in btas_um_tensor.h:567-617 but for the bare -// TA::Tensor specialization -- so the tile type is `UMTensor` directly, -// not wrapped in `TA::Tile<...>`. -// -// `to_host` / `to_device` are oneshot bulk-prefetch routines: they walk the -// pmap, dispatch one prefetch task per local tile, fence, then issue a -// `deviceSynchronize` to make sure every stream has drained. They're -// "stop the world" by design -- intended for explicit synchronization -// points (before a host read, after a load, etc.), not for inner loops. -// --------------------------------------------------------------------------- +/// Array-level helpers: bulk to-host / to-device prefetch and conversions +/// between UMTensor-backed and host-Tensor-backed DistArrays. /// Prefetch every local tile of `array` to the host. Fences on the /// containing world and globally synchronizes the device on exit. @@ -834,13 +797,10 @@ inline void to_device(TiledArray::DistArray, Policy>& array) { } /// Convert a UMTensor-backed `DistArray` to one backed by a host tile type. -/// Template arg order `` matches the btas pair in -/// device/btas_um_tensor.h:619+. template inline std::enable_if_t, TiledArray::DistArray> -um_tensor_to_ta_tensor( - const TiledArray::DistArray& um_array) { +um_tensor_to_ta_tensor(const TiledArray::DistArray& um_array) { auto convert_tile = [](const UMTile& tile) { detail::to_host(tile); HostTile result(tile.range()); @@ -855,8 +815,7 @@ um_tensor_to_ta_tensor( template inline std::enable_if_t, TiledArray::DistArray> -um_tensor_to_ta_tensor( - const TiledArray::DistArray& um_array) { +um_tensor_to_ta_tensor(const TiledArray::DistArray& um_array) { return um_array; } @@ -887,17 +846,12 @@ ta_tensor_to_um_tensor( } // namespace TiledArray -// --------------------------------------------------------------------------- -// MADNESS archive specializations for UMTensor. -// -// `TA::Tensor::serialize(ar)` works on any allocator (the member just walks -// `data() + range().volume() * nbatch()`), but UM data may be stale on the -// host if a device kernel is in flight. The Store specialization prefetches -// the tile back to the host before reading. Load goes through the default -// member -- the freshly constructed UM-allocated tile is host-writable, so -// no additional prefetch is needed (downstream code that wants the data on -// the device should call `to_device` explicitly). -// --------------------------------------------------------------------------- +/// MADNESS archive specializations for UMTensor. +/// +/// `TA::Tensor::serialize(ar)` works on any allocator (the member just walks +/// `data() + range().volume() * nbatch()`), but UM data may be stale on the +/// host if a device kernel is in flight. The Store specialization prefetches +/// the tile back to the host before reading. namespace madness { namespace archive { @@ -908,16 +862,13 @@ struct ArchiveStoreImpl> { if constexpr (TiledArray::detail::is_numeric_v) { TiledArray::detail::to_host(t); } - // Mirror TA::Tensor::serialize's store side; we cannot call the member - // because it is non-const and we want to keep the input parameter - // const-correct. + // Mirror TA::Tensor::serialize's store side const bool empty = t.empty(); ar & empty; if (!empty) { ar & t.range(); ar & t.nbatch(); - ar & madness::archive::wrap(t.data(), - t.range().volume() * t.nbatch()); + ar& madness::archive::wrap(t.data(), t.range().volume() * t.nbatch()); } } }; @@ -934,8 +885,7 @@ struct ArchiveLoadImpl> { ar & nbatch; t = TiledArray::UMTensor( std::move(range), typename TiledArray::UMTensor::nbatches(nbatch)); - ar & madness::archive::wrap(t.data(), - t.range().volume() * t.nbatch()); + ar& madness::archive::wrap(t.data(), t.range().volume() * t.nbatch()); } else { t = TiledArray::UMTensor(); } @@ -945,13 +895,7 @@ struct ArchiveLoadImpl> { } // namespace archive } // namespace madness -// --------------------------------------------------------------------------- -// `extern template` declarations for the UMTensor class. Match the explicit -// instantiations in src/TiledArray/device/tensor.cpp so that consumers do -// not re-instantiate the full Tensor> class body -// in each TU. (Mirrors the analogous pattern at the bottom of -// src/TiledArray/tensor/tensor.h for the host-side instantiations.) -// --------------------------------------------------------------------------- +/// extern template declarations for the UMTensor class. namespace TiledArray { extern template class Tensor>; From 3a6039a526d410ac8e674faf5389f68ac6193f7b Mon Sep 17 00:00:00 2001 From: Ajay Date: Thu, 21 May 2026 06:52:20 +0000 Subject: [PATCH 16/20] test: expand and cleanup UMTensor expression coverage; tighten tolerances to at least 1e-12 --- tests/expressions_device_tensor.cpp | 435 ++++++++++++---------------- 1 file changed, 184 insertions(+), 251 deletions(-) diff --git a/tests/expressions_device_tensor.cpp b/tests/expressions_device_tensor.cpp index 97b0b16fa3..a914bfbdda 100644 --- a/tests/expressions_device_tensor.cpp +++ b/tests/expressions_device_tensor.cpp @@ -25,22 +25,21 @@ #include #include -#include #include + +#include +#include + #include "unit_test_config.h" using namespace TiledArray; // Expression-engine tests for the native UMTensor tile type (TA::Tensor -// backed by device_um_allocator). The pattern follows expressions_device_um.cpp -// but uses the bare TA::Tensor specialization -- TA::Tensor is already -// shallow-copy, so we do not wrap it in TA::Tile<> (per CLAUDE.md guidance). +// backed by device_um_allocator). // -// All correctness checks use a CPU-side TiledArray::Tensor mirror -// of the input arrays (built from `find().get()` on the device side and a -// flat std::vector for reference). The expression runs through the engine -// for both sides; we then compare elements after `gop.fence()` to make sure -// the device kernels have actually completed. +// All correctness checks use a host side TiledArray::Tensor mirror +// of the input arrays. The expression runs through the engine for both sides; +// we then compare elements after. struct DeviceTensorExpressionsFixture : public TiledRangeFixture { using TileD = UMTensor; @@ -74,17 +73,13 @@ struct DeviceTensorExpressionsFixture : public TiledRangeFixture { const auto tile_range = d.trange().make_tile_range(*it); const auto vol = tile_range.volume(); - // Build deterministic data so seeds match across allocators. - const auto ord = *it; + std::mt19937 rng(seed + static_cast(*it)); + std::uniform_real_distribution dist(-5.0, 5.0); + typename DeviceArray::value_type d_tile(tile_range); typename HostArrayT::value_type h_tile(tile_range); for (std::size_t k = 0; k < vol; ++k) { - // 1000-element period is plenty for unit testing; division keeps - // values in [-5, 5] so dot products stay representable. - const double v = - static_cast(((ord + 1) * 1664525u + seed + k) % 1000) / - 100.0 - - 5.0; + const double v = dist(rng); d_tile.data()[k] = v; h_tile.data()[k] = v; } @@ -212,7 +207,6 @@ BOOST_AUTO_TEST_CASE(hadamard) { BOOST_AUTO_TEST_CASE(contraction) { // C(i,k) = A(i,j) * B(j,k) requires rank-2 arrays; build them on the fly - // using the first slice of `tr` so the fixture data is reusable. const TiledRange tr2{tr.data()[0], tr.data()[1]}; TArrayD a2(*GlobalFixture::world, tr2); TArrayD b2(*GlobalFixture::world, tr2); @@ -231,8 +225,6 @@ BOOST_AUTO_TEST_CASE(contraction) { } BOOST_AUTO_TEST_CASE(norm2_value) { - // Scalar reduction across all tiles. Compare device-computed value against - // CPU-computed value from the mirror array. const double dev_norm = TA::norm2(a); const double host_norm = TA::norm2(a_h); GlobalFixture::world->gop.fence(); @@ -247,11 +239,24 @@ BOOST_AUTO_TEST_CASE(dot_value) { BOOST_CHECK_CLOSE_FRACTION(dev_dot, host_dot, 1.0e-12); } +BOOST_AUTO_TEST_CASE(reduce_factories) { + // Expression-level reductions through the DSL: scalar = a("...").reduce() + GlobalFixture::world->gop.fence(); + + BOOST_CHECK_CLOSE_FRACTION(a("a,b,c").sum().get(), + a_h("a,b,c").sum().get(), 1.0e-12); + BOOST_CHECK_CLOSE_FRACTION(a("a,b,c").squared_norm().get(), + a_h("a,b,c").squared_norm().get(), 1.0e-12); + BOOST_CHECK_CLOSE_FRACTION(a("a,b,c").norm().get(), + a_h("a,b,c").norm().get(), 1.0e-12); + BOOST_CHECK_EQUAL(a("a,b,c").min().get(), a_h("a,b,c").min().get()); + BOOST_CHECK_EQUAL(a("a,b,c").max().get(), a_h("a,b,c").max().get()); + BOOST_CHECK_EQUAL(a("a,b,c").abs_min().get(), a_h("a,b,c").abs_min().get()); + BOOST_CHECK_EQUAL(a("a,b,c").abs_max().get(), a_h("a,b,c").abs_max().get()); + BOOST_CHECK_NO_THROW(auto _ = a("a,b,c").product().get()); +} + BOOST_AUTO_TEST_CASE(reuse_stress) { - // MPQC-pattern stress: same input tile referenced multiple times in one - // expression, then again across iterations. Catches the LazyArrayTile - // conversion race if it surfaces (it should be a known master-branch - // baseline failure -- not introduced by this branch). const double host_ref = static_cast(a_h("a,b,c") * a_h("a,b,c")); GlobalFixture::world->gop.fence(); @@ -262,12 +267,8 @@ BOOST_AUTO_TEST_CASE(reuse_stress) { } } -// --------------------------------------------------------------------------- -// In-place expression operators (+=, -=, *=). These exercise the engine's -// "result is consumable" paths that surfaced the dispatch + sign-flip bugs; -// here we want broad coverage of compound assignment forms beyond the -// `add_to` / `subt_to` cases already tested. -// --------------------------------------------------------------------------- + +/// In-place expression operators (+=, -=, *=) BOOST_AUTO_TEST_CASE(plus_equal_expr) { c("a,b,c") = a("a,b,c"); c_h("a,b,c") = a_h("a,b,c"); @@ -293,7 +294,6 @@ BOOST_AUTO_TEST_CASE(minus_equal_expr) { } BOOST_AUTO_TEST_CASE(times_equal_expr) { - // Hadamard, in place. c("a,b,c") = a("a,b,c"); c_h("a,b,c") = a_h("a,b,c"); c("a,b,c") *= b("a,b,c"); @@ -301,10 +301,6 @@ BOOST_AUTO_TEST_CASE(times_equal_expr) { check_close(c, c_h, tolerance); } -// --------------------------------------------------------------------------- -// Negated and scaled-then-negated forms. These force the engine to combine -// scaling with sign-flip across different operand positions. -// --------------------------------------------------------------------------- BOOST_AUTO_TEST_CASE(neg_scaled_sum) { BOOST_REQUIRE_NO_THROW(c("a,b,c") = -(2.0 * (a("a,b,c") + b("a,b,c")))); c_h("a,b,c") = -(2.0 * (a_h("a,b,c") + b_h("a,b,c"))); @@ -317,11 +313,36 @@ BOOST_AUTO_TEST_CASE(neg_permuted) { check_close(c, c_h, tolerance); } -// --------------------------------------------------------------------------- -// Multi-step chains: results of one expression feed the next. Validates -// dataflow handoff between dist-evals without an intervening fence (per -// CLAUDE.md's synchronization-hierarchy section). -// --------------------------------------------------------------------------- +BOOST_AUTO_TEST_CASE(scale_with_permute) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 2.5 * a("c,b,a")); + c_h("a,b,c") = 2.5 * a_h("c,b,a"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(subt_with_permute) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("c,b,a") - b("a,b,c")); + c_h("a,b,c") = a_h("c,b,a") - b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(mult_with_permute) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("c,b,a") * b("a,b,c")); + c_h("a,b,c") = a_h("c,b,a") * b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +// .conj() on a real-valued tensor is a compile-time no-op in +// `apply_scale_factor`, but the expression DSL still has to parse and route +// through the conjugation factory. This verifies it does. Complex-typed +// fixtures are out of scope here. +BOOST_AUTO_TEST_CASE(conj_real) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c").conj()); + c_h("a,b,c") = a_h("a,b,c").conj(); + check_close(c, c_h, tolerance); +} + + +/// Multi-step chains BOOST_AUTO_TEST_CASE(multi_step_chain) { TArrayD t(*GlobalFixture::world, tr); HostArray t_h(*GlobalFixture::world, tr); @@ -332,11 +353,7 @@ BOOST_AUTO_TEST_CASE(multi_step_chain) { check_close(c, c_h, tolerance); } -// --------------------------------------------------------------------------- -// Block expressions. PR 531 hit known issues in this area; we cover the -// common patterns: read-only block, block in a sum, block on the RHS of an -// accumulating assignment. -// --------------------------------------------------------------------------- +/// Block expressions BOOST_AUTO_TEST_CASE(block_assign) { const std::array lo{3, 3, 3}; const std::array up{5, 5, 5}; @@ -382,11 +399,98 @@ BOOST_AUTO_TEST_CASE(block_accumulate) { check_close(blk_d, blk_h, tolerance); } -// --------------------------------------------------------------------------- -// Outer product: c(i,j) = u(i) * v(j). Exercises the rank-changing GEMM -// path (different left / right / result ranks) without going through a -// shared contraction index. -// --------------------------------------------------------------------------- +BOOST_AUTO_TEST_CASE(const_block) { + const auto& ca = a; + const auto& ca_h = a_h; + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + const TiledRange ctr{TiledRange1{lo[0], up[0]}, + TiledRange1{lo[1], up[1]}, + TiledRange1{lo[2], up[2]}}; + TArrayD blk_d(*GlobalFixture::world, ctr); + HostArray blk_h(*GlobalFixture::world, ctr); + blk_d("a,b,c") = ca("a,b,c").block(lo, up); + blk_h("a,b,c") = ca_h("a,b,c").block(lo, up); + check_close(blk_d, blk_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scal_block) { + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + const TiledRange ctr{TiledRange1{lo[0], up[0]}, + TiledRange1{lo[1], up[1]}, + TiledRange1{lo[2], up[2]}}; + TArrayD blk_d(*GlobalFixture::world, ctr); + HostArray blk_h(*GlobalFixture::world, ctr); + blk_d("a,b,c") = 2.0 * a("a,b,c").block(lo, up); + blk_h("a,b,c") = 2.0 * a_h("a,b,c").block(lo, up); + check_close(blk_d, blk_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(permute_block) { + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + const TiledRange ctr{TiledRange1{lo[0], up[0]}, + TiledRange1{lo[1], up[1]}, + TiledRange1{lo[2], up[2]}}; + TArrayD blk_d(*GlobalFixture::world, ctr); + HostArray blk_h(*GlobalFixture::world, ctr); + // Permute the source annotation before slicing. + blk_d("a,b,c") = a("c,b,a").block(lo, up); + blk_h("a,b,c") = a_h("c,b,a").block(lo, up); + check_close(blk_d, blk_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(assign_sub_block) { + c.fill_local(0.0); + c_h.fill_local(0.0); + GlobalFixture::world->gop.fence(); + + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + BOOST_REQUIRE_NO_THROW(c("a,b,c").block(lo, up) = a("a,b,c").block(lo, up)); + c_h("a,b,c").block(lo, up) = a_h("a,b,c").block(lo, up); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(block_contract) { + const TiledRange tr_w{tr.data()[0], tr.data()[1]}; + TArrayD w(*GlobalFixture::world, tr_w); + HostArray w_h(*GlobalFixture::world, tr_w); + w.fill_local(0.0); + w_h.fill_local(0.0); + GlobalFixture::world->gop.fence(); + + const std::array alo{3, 2, 3}; + const std::array aup{5, 5, 5}; + const std::array blo{2, 3, 3}; + const std::array bup{5, 5, 5}; + + BOOST_REQUIRE_NO_THROW( + w("a,b") = a("a,c,d").block(alo, aup) * b("c,d,b").block(blo, bup)); + w_h("a,b") = a_h("a,c,d").block(alo, aup) * b_h("c,d,b").block(blo, bup); + check_close(w, w_h, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(block_permute_contract) { + const TiledRange tr_w{tr.data()[0], tr.data()[1]}; + TArrayD w(*GlobalFixture::world, tr_w); + HostArray w_h(*GlobalFixture::world, tr_w); + w.fill_local(0.0); + w_h.fill_local(0.0); + GlobalFixture::world->gop.fence(); + + const std::array alo{3, 3, 2}; + const std::array aup{5, 5, 5}; + const std::array blo{2, 3, 3}; + const std::array bup{5, 5, 5}; + + BOOST_REQUIRE_NO_THROW( + w("a,b") = a("a,d,c").block(alo, aup) * b("c,d,b").block(blo, bup)); + w_h("a,b") = a_h("a,d,c").block(alo, aup) * b_h("c,d,b").block(blo, bup); + check_close(w, w_h, 1.0e-12); +} + BOOST_AUTO_TEST_CASE(outer_product) { const TiledRange tr_u{tr.data()[0]}; const TiledRange tr_v{tr.data()[1]}; @@ -407,14 +511,8 @@ BOOST_AUTO_TEST_CASE(outer_product) { check_close(w, w_h, 1.0e-12); } -// --------------------------------------------------------------------------- -// Contraction shape variants. Different output ranks and contraction -// patterns. CC-style: result is rank-4, contraction index is multi-d. -// --------------------------------------------------------------------------- BOOST_AUTO_TEST_CASE(contraction_permuted_result) { - // c(k,i) = a(i,j) * b(j,k) -- the same contraction as `contraction` but - // with the result indices swapped; checks that the engine fuses a final - // permutation into the GEMM as CLAUDE.md describes. + // c(k,i) = a(i,j) * b(j,k) const TiledRange tr2{tr.data()[0], tr.data()[1]}; TArrayD a2(*GlobalFixture::world, tr2); TArrayD b2(*GlobalFixture::world, tr2); @@ -429,12 +527,11 @@ BOOST_AUTO_TEST_CASE(contraction_permuted_result) { c2_h("k,i") = a2_h("i,j") * b2_h("j,k"); // Looser tolerance for permuted GEMM: BLAS sums in different // tile-internal order than the CPU reference path. - check_close(c2, c2_h, 1.0e-10); + check_close(c2, c2_h, 1.0e-12); } BOOST_AUTO_TEST_CASE(contraction_with_transpose_on_right) { - // c(i,k) = a(i,j) * b(k,j) -- right operand needs transposing to align - // the contraction index. + // c(i,k) = a(i,j) * b(k,j) const TiledRange tr2{tr.data()[0], tr.data()[1]}; TArrayD a2(*GlobalFixture::world, tr2); TArrayD b2(*GlobalFixture::world, tr2); @@ -451,9 +548,7 @@ BOOST_AUTO_TEST_CASE(contraction_with_transpose_on_right) { } BOOST_AUTO_TEST_CASE(contraction_rank4_via_two_indices) { - // r(a,c) = t(a,b,k,l) * v(c,b,k,l) -- pattern that shows up in CC-style - // intermediates; contraction is over (b,k,l), free indices are (a) on - // the left and (c) on the right. + // r(a,c) = t(a,b,k,l) * v(c,b,k,l) const TiledRange tr4{tr.data()[0], tr.data()[1], tr.data()[2], tr.data()[2]}; const TiledRange tr2{tr.data()[0], tr.data()[1]}; TArrayD t(*GlobalFixture::world, tr4); @@ -470,14 +565,8 @@ BOOST_AUTO_TEST_CASE(contraction_rank4_via_two_indices) { check_close(r, r_h, 1.0e-12); } -// --------------------------------------------------------------------------- -// TA::einsum entry point. The fully-typed einsum API is the documented way -// to express patterns the regular `*` operator can't capture (general -// contraction with explicit output indices, Hadamard with permutation, -// etc.). For UMTensor we test that einsum dispatches through the same tile -// ops we already validated above and produces matching results vs. the -// host-tensor reference. -// --------------------------------------------------------------------------- + +/// TA::einsum BOOST_AUTO_TEST_CASE(einsum_matmul) { // c(i,k) = a(i,j) * b(j,k) via einsum const TiledRange tr2{tr.data()[0], tr.data()[1]}; @@ -491,11 +580,11 @@ BOOST_AUTO_TEST_CASE(einsum_matmul) { auto c2 = TiledArray::einsum(a2("i,j"), b2("j,k"), "i,k"); auto c2_h = TiledArray::einsum(a2_h("i,j"), b2_h("j,k"), "i,k"); - check_close(c2, c2_h, 1.0e-11); + check_close(c2, c2_h, 1.0e-12); } BOOST_AUTO_TEST_CASE(einsum_hadamard) { - // c(i,j) = a(i,j) * b(i,j) via einsum -- Hadamard / element-wise multiply + // c(i,j) = a(i,j) * b(i,j) via einsum: Hadamard / element-wise multiply const TiledRange tr2{tr.data()[0], tr.data()[1]}; TArrayD a2(*GlobalFixture::world, tr2); TArrayD b2(*GlobalFixture::world, tr2); @@ -510,17 +599,8 @@ BOOST_AUTO_TEST_CASE(einsum_hadamard) { check_close(c2, c2_h, tolerance); } -// Note: einsum patterns where an index appears in both inputs *and* the -// output (e.g. `einsum("ij,jk->ijk")`, an outer-product-with-broadcast -// over `j`) are not yet supported for plain (non-ToT) tile types -- they -// segfault inside einsum's internals on master regardless of allocator. -// We don't cover that case here. - BOOST_AUTO_TEST_CASE(einsum_contraction_over_two_indices) { - // c(a,c) = t(a,b,k) * v(c,b,k) via einsum -- contraction over (b, k), - // free indices (a) on the left and (c) on the right. CC-intermediate - // shape, fully expressible with the regular `*` operator but still - // worth covering through the einsum entry point. + // c(a,c) = t(a,b,k) * v(c,b,k) via einsum: contraction over (b, k), const TiledRange tr3{tr.data()[0], tr.data()[1], tr.data()[2]}; const TiledRange tr2{tr.data()[0], tr.data()[1]}; TArrayD t(*GlobalFixture::world, tr3); @@ -533,20 +613,17 @@ BOOST_AUTO_TEST_CASE(einsum_contraction_over_two_indices) { auto r = TiledArray::einsum(t("a,b,k"), v("c,b,k"), "a,c"); auto r_h = TiledArray::einsum(t_h("a,b,k"), v_h("c,b,k"), "a,c"); - check_close(r, r_h, 1.0e-11); + check_close(r, r_h, 1.0e-12); } -BOOST_AUTO_TEST_CASE(einsum_permuted_result) { - // c(j,i) = a(i,j) -- one-operand reshape; not a true einsum binary, but - // also useful: verify einsum handles single-input permutation. +BOOST_AUTO_TEST_CASE(rank2_transpose) { + // c(j,i) = a(i,j) const TiledRange tr2{tr.data()[0], tr.data()[1]}; TArrayD a2(*GlobalFixture::world, tr2); HostArray a2_h(*GlobalFixture::world, tr2); fill_with_seed(a2, a2_h, 97); GlobalFixture::world->gop.fence(); - // For permutation alone we just use the expression DSL, which einsum - // delegates to; this verifies that path still works for UMTensor. TArrayD a2T; HostArray a2T_h; a2T("j,i") = a2("i,j"); @@ -554,13 +631,7 @@ BOOST_AUTO_TEST_CASE(einsum_permuted_result) { check_close(a2T, a2T_h, tolerance); } -// --------------------------------------------------------------------------- -// Scaled / permuted variants of the elementary arithmetic ops. The first -// commit of these tests covered the bare forms; the engine fuses scaling -// and permutation differently across these combinations, so each one is -// a distinct dispatch path worth validating numerically. -// --------------------------------------------------------------------------- - +/// Scaled and permuted variants of the elementary arithmetic ops. BOOST_AUTO_TEST_CASE(scale_add) { BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5.0 * (a("a,b,c") + b("a,b,c"))); c_h("a,b,c") = 5.0 * (a_h("a,b,c") + b_h("a,b,c")); @@ -613,11 +684,7 @@ BOOST_AUTO_TEST_CASE(scale_mult_permute) { check_close(c, c_h, tolerance); } -// --------------------------------------------------------------------------- -// Scaled contraction variants. These exercise the engine's scale-fuse- -// into-GEMM path that PR 531 stumbled on. Tolerance is 1e-10 for GEMM -// paths to absorb summation-order differences between BLAS and Eigen. -// --------------------------------------------------------------------------- +/// Scaled contraction variants BOOST_AUTO_TEST_CASE(scale_cont) { const TiledRange tr2{tr.data()[0], tr.data()[1]}; TArrayD a2(*GlobalFixture::world, tr2); @@ -632,7 +699,7 @@ BOOST_AUTO_TEST_CASE(scale_cont) { BOOST_REQUIRE_NO_THROW(c2("i,k") = 5.0 * (a2("i,j") * b2("j,k"))); c2_h("i,k") = 5.0 * (a2_h("i,j") * b2_h("j,k")); - check_close(c2, c2_h, 1.0e-10); + check_close(c2, c2_h, 1.0e-12); } BOOST_AUTO_TEST_CASE(scale_cont_permute) { @@ -647,14 +714,14 @@ BOOST_AUTO_TEST_CASE(scale_cont_permute) { fill_with_seed(b2, b2_h, 109); GlobalFixture::world->gop.fence(); - // c(k,i) = 5 * a(i,j) * b(j,k): scaled, result-permuted contraction. + // c(k,i) = 5 * a(i,j) * b(j,k) BOOST_REQUIRE_NO_THROW(c2("k,i") = 5.0 * (a2("i,j") * b2("j,k"))); c2_h("k,i") = 5.0 * (a2_h("i,j") * b2_h("j,k")); - check_close(c2, c2_h, 1.0e-10); + check_close(c2, c2_h, 1.0e-12); } BOOST_AUTO_TEST_CASE(scale_cont_with_input_transpose) { - // 5 * a(i,j) * b(k,j) -- contraction needs to transpose b before GEMM. + // 5 * a(i,j) * b(k,j) const TiledRange tr2{tr.data()[0], tr.data()[1]}; TArrayD a2(*GlobalFixture::world, tr2); TArrayD b2(*GlobalFixture::world, tr2); @@ -668,16 +735,10 @@ BOOST_AUTO_TEST_CASE(scale_cont_with_input_transpose) { BOOST_REQUIRE_NO_THROW(c2("i,k") = 5.0 * (a2("i,j") * b2("k,j"))); c2_h("i,k") = 5.0 * (a2_h("i,j") * b2_h("k,j")); - check_close(c2, c2_h, 1.0e-10); + check_close(c2, c2_h, 1.0e-12); } -// --------------------------------------------------------------------------- -// Non-uniform tile sizes for contraction. Mirrors btas-device's -// cont_non_uniform1/2: the rank-4 inputs use one tiny tiling on the -// outer dimensions and one wide tiling on an inner dimension, so the -// GEMM has irregular per-tile k blocks. Catches GEMM kernels that -// silently assume uniform tile shapes. -// --------------------------------------------------------------------------- +/// Non-uniform tile sizes for contraction. BOOST_AUTO_TEST_CASE(cont_non_uniform_split_inner) { std::array tiling1 = {{0, 1, 2, 3, 4, 5}}; std::array tiling2 = {{0, 40}}; @@ -699,7 +760,7 @@ BOOST_AUTO_TEST_CASE(cont_non_uniform_split_inner) { BOOST_REQUIRE_NO_THROW(out("x,y") = 5.0 * (lhs("x,i,j,k") * rhs("y,i,j,k"))); out_h("x,y") = 5.0 * (lhs_h("x,i,j,k") * rhs_h("y,i,j,k")); - check_close(out, out_h, 1.0e-9); + check_close(out, out_h, 1.0e-12); } BOOST_AUTO_TEST_CASE(cont_non_uniform_split_two_inner) { @@ -723,14 +784,10 @@ BOOST_AUTO_TEST_CASE(cont_non_uniform_split_two_inner) { BOOST_REQUIRE_NO_THROW(out("x,y") = 5.0 * (lhs("x,i,j,k") * rhs("y,i,j,k"))); out_h("x,y") = 5.0 * (lhs_h("x,i,j,k") * rhs_h("y,i,j,k")); - check_close(out, out_h, 1.0e-9); + check_close(out, out_h, 1.0e-12); } -// --------------------------------------------------------------------------- -// Contraction-plus-reduction (norm2 of a contraction). Exercises the -// dataflow handoff from a binary dist-eval to a reduction without an -// intervening fence. -// --------------------------------------------------------------------------- +/// Contraction-plus-reduction BOOST_AUTO_TEST_CASE(cont_plus_reduce) { const TiledRange tr2{tr.data()[0], tr.data()[1]}; TArrayD a2(*GlobalFixture::world, tr2); @@ -748,13 +805,10 @@ BOOST_AUTO_TEST_CASE(cont_plus_reduce) { const double dev_n = TA::norm2(c2); const double host_n = TA::norm2(c2_h); GlobalFixture::world->gop.fence(); - BOOST_CHECK_CLOSE_FRACTION(dev_n, host_n, 1.0e-10); + BOOST_CHECK_CLOSE_FRACTION(dev_n, host_n, 1.0e-12); } BOOST_AUTO_TEST_CASE(no_alias_plus_reduce) { - // `no_alias()` tells the engine the LHS does not alias any RHS operand, - // permitting an extra in-place optimization. Validate that path - // produces correct values. const TiledRange tr2{tr.data()[0], tr.data()[1]}; TArrayD a2(*GlobalFixture::world, tr2); TArrayD b2(*GlobalFixture::world, tr2); @@ -772,124 +826,14 @@ BOOST_AUTO_TEST_CASE(no_alias_plus_reduce) { BOOST_REQUIRE_NO_THROW(c2("i,k").no_alias() = a2("i,j") * b2("j,k")); c2_h("i,k").no_alias() = a2_h("i,j") * b2_h("j,k"); - check_close(c2, c2_h, 1.0e-10); + check_close(c2, c2_h, 1.0e-12); const double dev_n = TA::norm2(c2); const double host_n = TA::norm2(c2_h); GlobalFixture::world->gop.fence(); - BOOST_CHECK_CLOSE_FRACTION(dev_n, host_n, 1.0e-10); -} - -// --------------------------------------------------------------------------- -// Block-expression variants beyond the basic three already covered. -// Block bounds are TILE coordinates; a {3,3,3} -> {5,5,5} block selects -// the 2x2x2 corner tiles of `tr` (5 tiles per dim). -// --------------------------------------------------------------------------- -BOOST_AUTO_TEST_CASE(const_block) { - const auto& ca = a; - const auto& ca_h = a_h; - const std::array lo{3, 3, 3}; - const std::array up{5, 5, 5}; - const TiledRange ctr{TiledRange1{lo[0], up[0]}, - TiledRange1{lo[1], up[1]}, - TiledRange1{lo[2], up[2]}}; - TArrayD blk_d(*GlobalFixture::world, ctr); - HostArray blk_h(*GlobalFixture::world, ctr); - blk_d("a,b,c") = ca("a,b,c").block(lo, up); - blk_h("a,b,c") = ca_h("a,b,c").block(lo, up); - check_close(blk_d, blk_h, tolerance); -} - -BOOST_AUTO_TEST_CASE(scal_block) { - const std::array lo{3, 3, 3}; - const std::array up{5, 5, 5}; - const TiledRange ctr{TiledRange1{lo[0], up[0]}, - TiledRange1{lo[1], up[1]}, - TiledRange1{lo[2], up[2]}}; - TArrayD blk_d(*GlobalFixture::world, ctr); - HostArray blk_h(*GlobalFixture::world, ctr); - blk_d("a,b,c") = 2.0 * a("a,b,c").block(lo, up); - blk_h("a,b,c") = 2.0 * a_h("a,b,c").block(lo, up); - check_close(blk_d, blk_h, tolerance); -} - -BOOST_AUTO_TEST_CASE(permute_block) { - const std::array lo{3, 3, 3}; - const std::array up{5, 5, 5}; - const TiledRange ctr{TiledRange1{lo[0], up[0]}, - TiledRange1{lo[1], up[1]}, - TiledRange1{lo[2], up[2]}}; - TArrayD blk_d(*GlobalFixture::world, ctr); - HostArray blk_h(*GlobalFixture::world, ctr); - // Permute the source annotation before slicing. - blk_d("a,b,c") = a("c,b,a").block(lo, up); - blk_h("a,b,c") = a_h("c,b,a").block(lo, up); - check_close(blk_d, blk_h, tolerance); -} - -BOOST_AUTO_TEST_CASE(assign_sub_block) { - // Write into a tile sub-block of an existing array. Tiles outside the - // block keep their original contents -- so we initialize both sides - // identically with a known value before the block assignment. - c.fill_local(0.0); - c_h.fill_local(0.0); - GlobalFixture::world->gop.fence(); - - const std::array lo{3, 3, 3}; - const std::array up{5, 5, 5}; - BOOST_REQUIRE_NO_THROW(c("a,b,c").block(lo, up) = a("a,b,c").block(lo, up)); - c_h("a,b,c").block(lo, up) = a_h("a,b,c").block(lo, up); - check_close(c, c_h, tolerance); + BOOST_CHECK_CLOSE_FRACTION(dev_n, host_n, 1.0e-12); } -// --------------------------------------------------------------------------- -// Block-fed-into-contraction. PR 531 had known issues here. The result -// array has rank 2 (carved out of the rank-3 fixture by contracting two -// indices). -// --------------------------------------------------------------------------- -BOOST_AUTO_TEST_CASE(block_contract) { - const TiledRange tr_w{tr.data()[0], tr.data()[1]}; - TArrayD w(*GlobalFixture::world, tr_w); - HostArray w_h(*GlobalFixture::world, tr_w); - w.fill_local(0.0); - w_h.fill_local(0.0); - GlobalFixture::world->gop.fence(); - - const std::array alo{3, 2, 3}; - const std::array aup{5, 5, 5}; - const std::array blo{2, 3, 3}; - const std::array bup{5, 5, 5}; - - BOOST_REQUIRE_NO_THROW( - w("a,b") = a("a,c,d").block(alo, aup) * b("c,d,b").block(blo, bup)); - w_h("a,b") = a_h("a,c,d").block(alo, aup) * b_h("c,d,b").block(blo, bup); - check_close(w, w_h, 1.0e-10); -} - -BOOST_AUTO_TEST_CASE(block_permute_contract) { - // Same as block_contract but with a permuted left-operand annotation: - // `a("a,d,c")` instead of `a("a,c,d")` -- forces a permutation of the - // sliced block before GEMM. - const TiledRange tr_w{tr.data()[0], tr.data()[1]}; - TArrayD w(*GlobalFixture::world, tr_w); - HostArray w_h(*GlobalFixture::world, tr_w); - w.fill_local(0.0); - w_h.fill_local(0.0); - GlobalFixture::world->gop.fence(); - - const std::array alo{3, 3, 2}; - const std::array aup{5, 5, 5}; - const std::array blo{2, 3, 3}; - const std::array bup{5, 5, 5}; - - BOOST_REQUIRE_NO_THROW( - w("a,b") = a("a,d,c").block(alo, aup) * b("c,d,b").block(blo, bup)); - w_h("a,b") = a_h("a,d,c").block(alo, aup) * b_h("c,d,b").block(blo, bup); - check_close(w, w_h, 1.0e-10); -} - -// --------------------------------------------------------------------------- -// Dot-product variants beyond the basic case. -// --------------------------------------------------------------------------- +/// Dot-product variants BOOST_AUTO_TEST_CASE(dot_permute) { const double dev_d = static_cast(a("a,b,c") * b("c,b,a")); @@ -903,8 +847,6 @@ BOOST_AUTO_TEST_CASE(dot_permute) { BOOST_AUTO_TEST_CASE(dot_contr) { // Dot of two contraction expressions: scalar = (a*b) . (b*a). - // This is a NO_THROW-only check in the btas-device suite; we go one - // step further and validate the scalar value against the CPU mirror. const TiledRange tr2{tr.data()[0], tr.data()[1]}; TArrayD a2(*GlobalFixture::world, tr2); TArrayD b2(*GlobalFixture::world, tr2); @@ -919,14 +861,10 @@ BOOST_AUTO_TEST_CASE(dot_contr) { const double host_d = static_cast( (a2_h("i,j") * b2_h("j,k")) * (a2_h("i,j") * b2_h("j,k"))); GlobalFixture::world->gop.fence(); - BOOST_CHECK_CLOSE_FRACTION(dev_d, host_d, 1.0e-10); + BOOST_CHECK_CLOSE_FRACTION(dev_d, host_d, 1.0e-12); } -// --------------------------------------------------------------------------- -// Archive round-trip, host/device array conversions, and bulk to_host / -// to_device. Smoke + correctness for the helpers in device/tensor.h that -// are not in the expression-engine path. -// --------------------------------------------------------------------------- +/// Archive round-trip, host/device array conversions BOOST_AUTO_TEST_CASE(serialize_um_tensor) { // Single-tile round-trip: build a UMTensor, write to a buffer archive, // read into a fresh UMTensor, compare element-wise. The Store side @@ -978,7 +916,7 @@ BOOST_AUTO_TEST_CASE(um_to_ta_round_trip) { HostArray host_view = TiledArray::um_tensor_to_ta_tensor(a); GlobalFixture::world->gop.fence(); - // Compare host_view directly against a_h (same data was used for both). + // Compare host_view directly against a_h check_close(host_view, a_h, tolerance); TArrayD device_view = @@ -990,8 +928,6 @@ BOOST_AUTO_TEST_CASE(um_to_ta_round_trip) { } BOOST_AUTO_TEST_CASE(um_to_ta_then_expression) { - // After converting a device array to host, plain CPU expressions on the - // host array should produce the same values as their device counterpart. HostArray sum_h(*GlobalFixture::world, tr); sum_h("a,b,c") = a_h("a,b,c") + b_h("a,b,c"); @@ -1006,9 +942,6 @@ BOOST_AUTO_TEST_CASE(um_to_ta_then_expression) { BOOST_AUTO_TEST_CASE(bulk_prefetch_round_trip) { // to_host / to_device on a DistArray should be no-ops for correctness: - // the array contents are unchanged, only the page residency hints are - // adjusted. We just verify the array is still equal to its host mirror - // after bouncing through both directions. TiledArray::to_host(a); TiledArray::to_device(a); GlobalFixture::world->gop.fence(); From 2b5137ff375dd25afdb6ea7d90134ec2ceab2195 Mon Sep 17 00:00:00 2001 From: Ajay Date: Thu, 21 May 2026 20:42:46 +0000 Subject: [PATCH 17/20] device: add missing congruence asserts to UMTensor gemm overloads The returning overload was missing left_right_congruent checks, and the in-place overload was missing the full left_result/right_result/left_right congruence set. Silent geometry mismatches would have produced wrong results instead of asserting. Mirrors the asserts in tensor/kernels.h's host-side detail::gemm worker. --- src/TiledArray/device/tensor.h | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/TiledArray/device/tensor.h b/src/TiledArray/device/tensor.h index 331da8c2f4..0f76a380f3 100644 --- a/src/TiledArray/device/tensor.h +++ b/src/TiledArray/device/tensor.h @@ -673,6 +673,15 @@ inline UMTensor gemm(const UMTensor& left, const UMTensor& right, TA_ASSERT(right.range().rank() == gemm_helper.right_rank()); TA_ASSERT(left.nbatch() == 1 && right.nbatch() == 1); + TA_ASSERT(gemm_helper.left_right_congruent(left.range().extent_data(), + right.range().extent_data())); + TA_ASSERT(ignore_tile_position() || + gemm_helper.left_right_congruent(left.range().lobound_data(), + right.range().lobound_data())); + TA_ASSERT(ignore_tile_position() || + gemm_helper.left_right_congruent(left.range().upbound_data(), + right.range().upbound_data())); + auto result_range = gemm_helper.template make_result_range( left.range(), right.range()); @@ -727,6 +736,31 @@ inline void gemm(UMTensor& result, const UMTensor& left, TA_ASSERT(right.range().rank() == gemm_helper.right_rank()); TA_ASSERT(left.nbatch() == 1 && right.nbatch() == 1 && result.nbatch() == 1); + TA_ASSERT(gemm_helper.left_result_congruent(left.range().extent_data(), + result.range().extent_data())); + TA_ASSERT(ignore_tile_position() || + gemm_helper.left_result_congruent(left.range().lobound_data(), + result.range().lobound_data())); + TA_ASSERT(ignore_tile_position() || + gemm_helper.left_result_congruent(left.range().upbound_data(), + result.range().upbound_data())); + TA_ASSERT(gemm_helper.right_result_congruent(right.range().extent_data(), + result.range().extent_data())); + TA_ASSERT(ignore_tile_position() || + gemm_helper.right_result_congruent(right.range().lobound_data(), + result.range().lobound_data())); + TA_ASSERT(ignore_tile_position() || + gemm_helper.right_result_congruent(right.range().upbound_data(), + result.range().upbound_data())); + TA_ASSERT(gemm_helper.left_right_congruent(left.range().extent_data(), + right.range().extent_data())); + TA_ASSERT(ignore_tile_position() || + gemm_helper.left_right_congruent(left.range().lobound_data(), + right.range().lobound_data())); + TA_ASSERT(ignore_tile_position() || + gemm_helper.left_right_congruent(left.range().upbound_data(), + right.range().upbound_data())); + auto& queue = blasqueue_for(result.range()); const device::Stream stream(queue.device(), queue.stream()); DeviceSafeCall(device::setDevice(stream.device)); From 3fabd116b0e23ed945181ed936fadfa2503bfddf Mon Sep 17 00:00:00 2001 From: Ajay Date: Thu, 21 May 2026 20:43:18 +0000 Subject: [PATCH 18/20] device: sync after async to_host prefetch in UM->host paths detail::to_host wraps memPrefetchAsync, so the immediate host reads in um_tensor_to_ta_tensor's convert_tile lambda and in ArchiveStoreImpl::store were racing the prefetch on Pascal+ devices with concurrent_managed_access. Drain the stream via sync_madness_task_with before the host walks the tile. --- src/TiledArray/device/tensor.h | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/TiledArray/device/tensor.h b/src/TiledArray/device/tensor.h index 0f76a380f3..a2691f3a68 100644 --- a/src/TiledArray/device/tensor.h +++ b/src/TiledArray/device/tensor.h @@ -836,7 +836,9 @@ inline std::enable_if_t, TiledArray::DistArray> um_tensor_to_ta_tensor(const TiledArray::DistArray& um_array) { auto convert_tile = [](const UMTile& tile) { + auto stream = device::stream_for(tile.range()); detail::to_host(tile); + device::sync_madness_task_with(stream); HostTile result(tile.range()); std::copy_n(tile.data(), tile.total_size(), result.data()); return result; @@ -894,9 +896,12 @@ struct ArchiveStoreImpl> { static inline void store(const Archive& ar, const TiledArray::UMTensor& t) { if constexpr (TiledArray::detail::is_numeric_v) { - TiledArray::detail::to_host(t); + if (!t.empty()) { + auto stream = TiledArray::device::stream_for(t.range()); + TiledArray::detail::to_host(t); + TiledArray::device::sync_madness_task_with(stream); + } } - // Mirror TA::Tensor::serialize's store side const bool empty = t.empty(); ar & empty; if (!empty) { From ba5efa21648397fdca8f6aab84ecc29f9d307e6a Mon Sep 17 00:00:00 2001 From: Ajay Date: Thu, 21 May 2026 20:43:47 +0000 Subject: [PATCH 19/20] device: small UMTensor cleanups - shift_to: call Tensor::shift_to instead of const_cast'ing the range. TA::Tensor exposes a public shift_to member (unlike btas::Tensor), so the const_cast inherited from btas_um_tensor.h is unnecessary here. - apply_scale_factor: flatten 3-level nested if constexpr into one else-if-constexpr cascade. --- src/TiledArray/device/tensor.h | 33 +++++++++++++-------------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/src/TiledArray/device/tensor.h b/src/TiledArray/device/tensor.h index a2691f3a68..9a03f1f768 100644 --- a/src/TiledArray/device/tensor.h +++ b/src/TiledArray/device/tensor.h @@ -145,22 +145,18 @@ inline void apply_scale_factor(T* data, std::size_t n, const Scalar factor, if constexpr (TiledArray::detail::is_blas_numeric_v || std::is_arithmetic_v) { ::blas::scal(n, factor, data, 1, queue); - } else { - if constexpr (TiledArray::detail::is_complex_v) { - TA_EXCEPTION( - "UMTensor scale with ComplexConjugate factor on complex T is not " - "implemented (requires a fused conjugation kernel)"); - } else { - if constexpr (std::is_same_v< - Scalar, TiledArray::detail::ComplexConjugate>) { - // conjugation on a real tensor is a no-op - } else if constexpr (std::is_same_v< - Scalar, - TiledArray::detail::ComplexConjugate< - TiledArray::detail::ComplexNegTag>>) { - ::blas::scal(n, static_cast(-1), data, 1, queue); - } - } + } else if constexpr (TiledArray::detail::is_complex_v) { + TA_EXCEPTION( + "UMTensor scale with ComplexConjugate factor on complex T is not " + "implemented (requires a fused conjugation kernel)"); + } else if constexpr (std::is_same_v< + Scalar, + TiledArray::detail::ComplexConjugate>) { + // conjugation on a real tensor is a no-op + } else if constexpr (std::is_same_v< + Scalar, TiledArray::detail::ComplexConjugate< + TiledArray::detail::ComplexNegTag>>) { + ::blas::scal(n, static_cast(-1), data, 1, queue); } } @@ -546,10 +542,7 @@ inline UMTensor shift(const UMTensor& arg, const Index& bound_shift) { template requires TiledArray::detail::is_numeric_v inline UMTensor& shift_to(UMTensor& arg, const Index& bound_shift) { - // `range()` only exposes a const accessor; cast is safe because we are the - // tile's owner here and only the range bounds change, not the data layout. - const_cast(arg.range()).inplace_shift(bound_shift); - return arg; + return arg.shift_to(bound_shift); } template From 7730bb4272fe9023aa95873ce9423261793fdea8 Mon Sep 17 00:00:00 2001 From: Ajay Date: Fri, 22 May 2026 10:21:33 -0400 Subject: [PATCH 20/20] chore: reformat source files --- .gitignore | 2 +- examples/device/ta_dense_um_tensor.cpp | 14 +++-- examples/device/ta_vector_um_tensor.cpp | 3 +- src/TiledArray/device/tensor.cpp | 1 - src/TiledArray/device/tensor.h | 6 +-- tests/expressions_device_tensor.cpp | 68 ++++++++++--------------- 6 files changed, 37 insertions(+), 57 deletions(-) diff --git a/.gitignore b/.gitignore index 7eb9952c74..09398e2bf3 100644 --- a/.gitignore +++ b/.gitignore @@ -46,4 +46,4 @@ build/* -cmake-build* \ No newline at end of file +cmake-build* diff --git a/examples/device/ta_dense_um_tensor.cpp b/examples/device/ta_dense_um_tensor.cpp index 1902824379..6b94e315f3 100644 --- a/examples/device/ta_dense_um_tensor.cpp +++ b/examples/device/ta_dense_um_tensor.cpp @@ -57,10 +57,9 @@ void run(TiledArray::World &world, long Nm, long Bm, long Nn, long Bn, long Nk, constexpr bool complex_T = TA::detail::is_complex_v; // GEMM flops: 2 * M * N * K (8 * for complex). - const std::int64_t nflops = (complex_T ? 8 : 2) * - static_cast(Nm) * - static_cast(Nn) * - static_cast(Nk); + const std::int64_t nflops = + (complex_T ? 8 : 2) * static_cast(Nm) * + static_cast(Nn) * static_cast(Nk); auto blocking = [](long N, long B) { std::vector v; @@ -133,8 +132,8 @@ void run(TiledArray::World &world, long Nm, long Bm, long Nn, long Bn, long Nk, if (world.rank() == 0) std::cout << " Average time = " << (total_time / double(nrepeat)) - << " s\n Average gflops = " - << (total_gflops / double(nrepeat)) << "\n"; + << " s\n Average gflops = " << (total_gflops / double(nrepeat)) + << "\n"; // Verify: every result element should be Nk * val_a * val_b. const T expected = T(Nk) * val_a * val_b; @@ -169,8 +168,7 @@ int try_main(int argc, char **argv) { if (argc < 7) { if (world.rank() == 0) std::cerr - << "Usage: " << argv[0] - << " Nm Bm Nn Bn Nk Bk [nrepeat=5]\n" + << "Usage: " << argv[0] << " Nm Bm Nn Bn Nk Bk [nrepeat=5]\n" << " Computes c(Nm,Nn) = a(Nm,Nk) * b(Nk,Nn) with UMTensor tiles\n"; return 1; } diff --git a/examples/device/ta_vector_um_tensor.cpp b/examples/device/ta_vector_um_tensor.cpp index 8d0960be4f..603a3bd15f 100644 --- a/examples/device/ta_vector_um_tensor.cpp +++ b/examples/device/ta_vector_um_tensor.cpp @@ -124,8 +124,7 @@ int try_main(int argc, char **argv) { if (argc < 5) { if (world.rank() == 0) std::cerr - << "Usage: " << argv[0] - << " Nm Bm Nn Bn [nrepeat=5]\n" + << "Usage: " << argv[0] << " Nm Bm Nn Bn [nrepeat=5]\n" << " Times element-wise vector ops on Nm x Nn UMTensor matrices\n"; return 1; } diff --git a/src/TiledArray/device/tensor.cpp b/src/TiledArray/device/tensor.cpp index 3065d12e1f..94dc4d3cc0 100644 --- a/src/TiledArray/device/tensor.cpp +++ b/src/TiledArray/device/tensor.cpp @@ -40,4 +40,3 @@ template class Tensor>; template class Tensor>; } // namespace TiledArray - diff --git a/src/TiledArray/device/tensor.h b/src/TiledArray/device/tensor.h index 9a03f1f768..4488734191 100644 --- a/src/TiledArray/device/tensor.h +++ b/src/TiledArray/device/tensor.h @@ -153,9 +153,9 @@ inline void apply_scale_factor(T* data, std::size_t n, const Scalar factor, Scalar, TiledArray::detail::ComplexConjugate>) { // conjugation on a real tensor is a no-op - } else if constexpr (std::is_same_v< - Scalar, TiledArray::detail::ComplexConjugate< - TiledArray::detail::ComplexNegTag>>) { + } else if constexpr (std::is_same_v>) { ::blas::scal(n, static_cast(-1), data, 1, queue); } } diff --git a/tests/expressions_device_tensor.cpp b/tests/expressions_device_tensor.cpp index a914bfbdda..126e17c3b4 100644 --- a/tests/expressions_device_tensor.cpp +++ b/tests/expressions_device_tensor.cpp @@ -27,8 +27,8 @@ #include #include -#include #include +#include #include "unit_test_config.h" @@ -243,12 +243,12 @@ BOOST_AUTO_TEST_CASE(reduce_factories) { // Expression-level reductions through the DSL: scalar = a("...").reduce() GlobalFixture::world->gop.fence(); - BOOST_CHECK_CLOSE_FRACTION(a("a,b,c").sum().get(), - a_h("a,b,c").sum().get(), 1.0e-12); + BOOST_CHECK_CLOSE_FRACTION(a("a,b,c").sum().get(), a_h("a,b,c").sum().get(), + 1.0e-12); BOOST_CHECK_CLOSE_FRACTION(a("a,b,c").squared_norm().get(), a_h("a,b,c").squared_norm().get(), 1.0e-12); - BOOST_CHECK_CLOSE_FRACTION(a("a,b,c").norm().get(), - a_h("a,b,c").norm().get(), 1.0e-12); + BOOST_CHECK_CLOSE_FRACTION(a("a,b,c").norm().get(), a_h("a,b,c").norm().get(), + 1.0e-12); BOOST_CHECK_EQUAL(a("a,b,c").min().get(), a_h("a,b,c").min().get()); BOOST_CHECK_EQUAL(a("a,b,c").max().get(), a_h("a,b,c").max().get()); BOOST_CHECK_EQUAL(a("a,b,c").abs_min().get(), a_h("a,b,c").abs_min().get()); @@ -257,8 +257,7 @@ BOOST_AUTO_TEST_CASE(reduce_factories) { } BOOST_AUTO_TEST_CASE(reuse_stress) { - const double host_ref = - static_cast(a_h("a,b,c") * a_h("a,b,c")); + const double host_ref = static_cast(a_h("a,b,c") * a_h("a,b,c")); GlobalFixture::world->gop.fence(); for (int iter = 0; iter < 8; ++iter) { const double d = static_cast(a("a,b,c") * a("a,b,c")); @@ -267,7 +266,6 @@ BOOST_AUTO_TEST_CASE(reuse_stress) { } } - /// In-place expression operators (+=, -=, *=) BOOST_AUTO_TEST_CASE(plus_equal_expr) { c("a,b,c") = a("a,b,c"); @@ -341,7 +339,6 @@ BOOST_AUTO_TEST_CASE(conj_real) { check_close(c, c_h, tolerance); } - /// Multi-step chains BOOST_AUTO_TEST_CASE(multi_step_chain) { TArrayD t(*GlobalFixture::world, tr); @@ -360,8 +357,7 @@ BOOST_AUTO_TEST_CASE(block_assign) { // Result range matches the block's element range; build small companion // arrays to receive the result. - const TiledRange ctr{TiledRange1{lo[0], up[0]}, - TiledRange1{lo[1], up[1]}, + const TiledRange ctr{TiledRange1{lo[0], up[0]}, TiledRange1{lo[1], up[1]}, TiledRange1{lo[2], up[2]}}; TArrayD blk_d(*GlobalFixture::world, ctr); HostArray blk_h(*GlobalFixture::world, ctr); @@ -373,8 +369,7 @@ BOOST_AUTO_TEST_CASE(block_assign) { BOOST_AUTO_TEST_CASE(block_add_then_scale) { const std::array lo{3, 3, 3}; const std::array up{5, 5, 5}; - const TiledRange ctr{TiledRange1{lo[0], up[0]}, - TiledRange1{lo[1], up[1]}, + const TiledRange ctr{TiledRange1{lo[0], up[0]}, TiledRange1{lo[1], up[1]}, TiledRange1{lo[2], up[2]}}; TArrayD blk_d(*GlobalFixture::world, ctr); HostArray blk_h(*GlobalFixture::world, ctr); @@ -387,8 +382,7 @@ BOOST_AUTO_TEST_CASE(block_add_then_scale) { BOOST_AUTO_TEST_CASE(block_accumulate) { const std::array lo{3, 3, 3}; const std::array up{5, 5, 5}; - const TiledRange ctr{TiledRange1{lo[0], up[0]}, - TiledRange1{lo[1], up[1]}, + const TiledRange ctr{TiledRange1{lo[0], up[0]}, TiledRange1{lo[1], up[1]}, TiledRange1{lo[2], up[2]}}; TArrayD blk_d(*GlobalFixture::world, ctr); HostArray blk_h(*GlobalFixture::world, ctr); @@ -404,8 +398,7 @@ BOOST_AUTO_TEST_CASE(const_block) { const auto& ca_h = a_h; const std::array lo{3, 3, 3}; const std::array up{5, 5, 5}; - const TiledRange ctr{TiledRange1{lo[0], up[0]}, - TiledRange1{lo[1], up[1]}, + const TiledRange ctr{TiledRange1{lo[0], up[0]}, TiledRange1{lo[1], up[1]}, TiledRange1{lo[2], up[2]}}; TArrayD blk_d(*GlobalFixture::world, ctr); HostArray blk_h(*GlobalFixture::world, ctr); @@ -417,8 +410,7 @@ BOOST_AUTO_TEST_CASE(const_block) { BOOST_AUTO_TEST_CASE(scal_block) { const std::array lo{3, 3, 3}; const std::array up{5, 5, 5}; - const TiledRange ctr{TiledRange1{lo[0], up[0]}, - TiledRange1{lo[1], up[1]}, + const TiledRange ctr{TiledRange1{lo[0], up[0]}, TiledRange1{lo[1], up[1]}, TiledRange1{lo[2], up[2]}}; TArrayD blk_d(*GlobalFixture::world, ctr); HostArray blk_h(*GlobalFixture::world, ctr); @@ -430,8 +422,7 @@ BOOST_AUTO_TEST_CASE(scal_block) { BOOST_AUTO_TEST_CASE(permute_block) { const std::array lo{3, 3, 3}; const std::array up{5, 5, 5}; - const TiledRange ctr{TiledRange1{lo[0], up[0]}, - TiledRange1{lo[1], up[1]}, + const TiledRange ctr{TiledRange1{lo[0], up[0]}, TiledRange1{lo[1], up[1]}, TiledRange1{lo[2], up[2]}}; TArrayD blk_d(*GlobalFixture::world, ctr); HostArray blk_h(*GlobalFixture::world, ctr); @@ -466,8 +457,8 @@ BOOST_AUTO_TEST_CASE(block_contract) { const std::array blo{2, 3, 3}; const std::array bup{5, 5, 5}; - BOOST_REQUIRE_NO_THROW( - w("a,b") = a("a,c,d").block(alo, aup) * b("c,d,b").block(blo, bup)); + BOOST_REQUIRE_NO_THROW(w("a,b") = a("a,c,d").block(alo, aup) * + b("c,d,b").block(blo, bup)); w_h("a,b") = a_h("a,c,d").block(alo, aup) * b_h("c,d,b").block(blo, bup); check_close(w, w_h, 1.0e-12); } @@ -485,8 +476,8 @@ BOOST_AUTO_TEST_CASE(block_permute_contract) { const std::array blo{2, 3, 3}; const std::array bup{5, 5, 5}; - BOOST_REQUIRE_NO_THROW( - w("a,b") = a("a,d,c").block(alo, aup) * b("c,d,b").block(blo, bup)); + BOOST_REQUIRE_NO_THROW(w("a,b") = a("a,d,c").block(alo, aup) * + b("c,d,b").block(blo, bup)); w_h("a,b") = a_h("a,d,c").block(alo, aup) * b_h("c,d,b").block(blo, bup); check_close(w, w_h, 1.0e-12); } @@ -565,7 +556,6 @@ BOOST_AUTO_TEST_CASE(contraction_rank4_via_two_indices) { check_close(r, r_h, 1.0e-12); } - /// TA::einsum BOOST_AUTO_TEST_CASE(einsum_matmul) { // c(i,k) = a(i,j) * b(j,k) via einsum @@ -757,8 +747,7 @@ BOOST_AUTO_TEST_CASE(cont_non_uniform_split_inner) { fill_with_seed(rhs, rhs_h, 137); GlobalFixture::world->gop.fence(); - BOOST_REQUIRE_NO_THROW(out("x,y") = - 5.0 * (lhs("x,i,j,k") * rhs("y,i,j,k"))); + BOOST_REQUIRE_NO_THROW(out("x,y") = 5.0 * (lhs("x,i,j,k") * rhs("y,i,j,k"))); out_h("x,y") = 5.0 * (lhs_h("x,i,j,k") * rhs_h("y,i,j,k")); check_close(out, out_h, 1.0e-12); } @@ -781,8 +770,7 @@ BOOST_AUTO_TEST_CASE(cont_non_uniform_split_two_inner) { fill_with_seed(rhs, rhs_h, 149); GlobalFixture::world->gop.fence(); - BOOST_REQUIRE_NO_THROW(out("x,y") = - 5.0 * (lhs("x,i,j,k") * rhs("y,i,j,k"))); + BOOST_REQUIRE_NO_THROW(out("x,y") = 5.0 * (lhs("x,i,j,k") * rhs("y,i,j,k"))); out_h("x,y") = 5.0 * (lhs_h("x,i,j,k") * rhs_h("y,i,j,k")); check_close(out, out_h, 1.0e-12); } @@ -812,12 +800,10 @@ BOOST_AUTO_TEST_CASE(no_alias_plus_reduce) { const TiledRange tr2{tr.data()[0], tr.data()[1]}; TArrayD a2(*GlobalFixture::world, tr2); TArrayD b2(*GlobalFixture::world, tr2); - TArrayD c2(*GlobalFixture::world, - TiledRange{tr.data()[0], tr.data()[1]}); + TArrayD c2(*GlobalFixture::world, TiledRange{tr.data()[0], tr.data()[1]}); HostArray a2_h(*GlobalFixture::world, tr2); HostArray b2_h(*GlobalFixture::world, tr2); - HostArray c2_h(*GlobalFixture::world, - TiledRange{tr.data()[0], tr.data()[1]}); + HostArray c2_h(*GlobalFixture::world, TiledRange{tr.data()[0], tr.data()[1]}); fill_with_seed(a2, a2_h, 163); fill_with_seed(b2, b2_h, 167); c2.fill_local(0.0); @@ -835,10 +821,8 @@ BOOST_AUTO_TEST_CASE(no_alias_plus_reduce) { /// Dot-product variants BOOST_AUTO_TEST_CASE(dot_permute) { - const double dev_d = - static_cast(a("a,b,c") * b("c,b,a")); - const double host_d = - static_cast(a_h("a,b,c") * b_h("c,b,a")); + const double dev_d = static_cast(a("a,b,c") * b("c,b,a")); + const double host_d = static_cast(a_h("a,b,c") * b_h("c,b,a")); GlobalFixture::world->gop.fence(); // Looser tolerance because permuted dot reads tiles in a different // order, so the partial-sum accumulation order differs. @@ -856,10 +840,10 @@ BOOST_AUTO_TEST_CASE(dot_contr) { fill_with_seed(b2, b2_h, 179); GlobalFixture::world->gop.fence(); - const double dev_d = static_cast( - (a2("i,j") * b2("j,k")) * (a2("i,j") * b2("j,k"))); - const double host_d = static_cast( - (a2_h("i,j") * b2_h("j,k")) * (a2_h("i,j") * b2_h("j,k"))); + const double dev_d = + static_cast((a2("i,j") * b2("j,k")) * (a2("i,j") * b2("j,k"))); + const double host_d = static_cast((a2_h("i,j") * b2_h("j,k")) * + (a2_h("i,j") * b2_h("j,k"))); GlobalFixture::world->gop.fence(); BOOST_CHECK_CLOSE_FRACTION(dev_d, host_d, 1.0e-12); }